diff --git a/README.md b/README.md
index 0df3f8fc80..d1bd93c297 100644
--- a/README.md
+++ b/README.md
@@ -42,6 +42,7 @@ Trinity-RFT provides functionalities for users with different backgrounds and ob
## 🚀 News
+* [2025-12] Trinity-RFT now supports the [Tinker](https://thinkingmachines.ai/tinker/) training backend, which enables model training on devices **without GPUs**.
* [2025-12] Trinity-RFT powers the medical and health business of "Taobao Shangou", enabling the AI agent to understand vague symptoms, proactively ask follow-up questions, and provide precise recommendations ([News](https://tech.china.com.cn/sx/20251201/411376.shtml)).
* [2025-11] [[Release Notes](https://github.com/modelscope/Trinity-RFT/releases/tag/v0.3.3)] Trinity-RFT v0.3.3 released: bug fixes.
* [2025-11] Introducing [Learn-to-Ask](https://github.com/modelscope/Trinity-RFT/tree/main/examples/learn_to_ask): a framework for training proactive dialogue agents from offline expert data ([paper](https://arxiv.org/pdf/2510.25441)).
@@ -154,6 +155,10 @@ We list some algorithms supported by Trinity-RFT in the following table. For mor
> [!NOTE]
> This project is currently under active development. Comments and suggestions are welcome!
+>
+> **No GPU? No problem!** You can still try it out:
+> 1. Follow the installation steps below (you can skip GPU-specific packages such as `flash-attn`).
+> 2. Run the **[Tinker training example](https://github.com/modelscope/Trinity-RFT/tree/main/examples/tinker)**, which is specifically designed to work on CPU-only systems.
### Step 1: installation
@@ -186,6 +191,9 @@ Choose one of the following options:
conda create -n trinity python=3.12
conda activate trinity
+pip install -e ".[verl]"
+# If you have no GPU, use Tinker instead.
+# pip install -e ".[tinker]"
pip install -e ".[dev]"
pip install -e ".[flash_attn]"
# if you encounter issues when installing flash-attn, try:
@@ -198,6 +206,9 @@ pip install -e ".[flash_attn]"
python3.10 -m venv .venv
source .venv/bin/activate
+pip install -e ".[verl]"
+# If you have no GPU, use Tinker instead.
+# pip install -e ".[tinker]"
pip install -e ".[dev]"
pip install -e ".[flash_attn]"
# if you encounter issues when installing flash-attn, try:
diff --git a/README_zh.md b/README_zh.md
index a16063f30d..15a4a12a82 100644
--- a/README_zh.md
+++ b/README_zh.md
@@ -41,6 +41,7 @@ Trinity-RFT 面向不同背景和目标的用户提供相应功能:
## 🚀 新闻
+* [2025-12] Trinity-RFT 已支持 [tinker](https://thinkingmachines.ai/tinker/) 训练后端,可在**无 GPU 的设备**上进行模型训练。
* [2025-12] Trinity-RFT 助力淘宝闪购医药健康业务,让 AI 智能体能够理解模糊症状、主动询问后续问题,并提供精准推荐([新闻](https://tech.china.com.cn/sx/20251201/411376.shtml))。
* [2025-11] [[发布说明](https://github.com/modelscope/Trinity-RFT/releases/tag/v0.3.3)] Trinity-RFT v0.3.3 发布:修复若干 Bug。
* [2025-11] 推出 [Learn-to-Ask](https://github.com/modelscope/Trinity-RFT/tree/main/examples/learn_to_ask):利用离线专家数据,训练具备主动问询能力的对话智能体([论文](https://arxiv.org/pdf/2510.25441)).
@@ -79,6 +80,10 @@ Trinity-RFT 面向不同背景和目标的用户提供相应功能:
> [!NOTE]
> 更多教程请参考 [Trinity-RFT 文档](https://modelscope.github.io/Trinity-RFT/)。
+>
+> 没有 GPU?没问题!你仍然可以尝试以下方案:
+> 1. 按照安装步骤操作(可跳过 GPU 专用的软件包,例如 `flash-attn`)
+> 2. 运行 **[Tinker 训练示例](https://github.com/modelscope/Trinity-RFT/tree/main/examples/tinker)**,该示例专为仅使用 CPU 的系统设计。
@@ -185,6 +190,9 @@ cd Trinity-RFT
conda create -n trinity python=3.12
conda activate trinity
+pip install -e ".[verl]"
+# 如果没有 GPU,请改用 Tinker:
+# pip install -e ".[tinker]"
pip install -e ".[dev]"
pip install -e ".[flash_attn]"
# 如果安装 flash-attn 时遇到问题,可尝试:
@@ -197,6 +205,9 @@ pip install -e ".[flash_attn]"
python3.10 -m venv .venv
source .venv/bin/activate
+pip install -e ".[verl]"
+# 如果没有 GPU,请改用 Tinker:
+# pip install -e ".[tinker]"
pip install -e ".[dev]"
pip install -e ".[flash_attn]"
# 如果安装 flash-attn 时遇到问题,可尝试:
diff --git a/docs/sphinx_doc/assets/tinker-gsm8k.png b/docs/sphinx_doc/assets/tinker-gsm8k.png
new file mode 100644
index 0000000000..fe0d1145a3
Binary files /dev/null and b/docs/sphinx_doc/assets/tinker-gsm8k.png differ
diff --git a/docs/sphinx_doc/source/tutorial/trinity_configs.md b/docs/sphinx_doc/source/tutorial/trinity_configs.md
index b1d08f03b0..5dfc2de4dd 100644
--- a/docs/sphinx_doc/source/tutorial/trinity_configs.md
+++ b/docs/sphinx_doc/source/tutorial/trinity_configs.md
@@ -164,9 +164,21 @@ model:
max_response_tokens: 16384
min_response_tokens: 1
enable_prompt_truncation: true
+ repetition_penalty: 1.0
+ lora_configs: null
+ rope_scaling: null
+ rope_theta: null
+ tinker:
+ enable: false
+ base_model: null
+ rank: 32
+ seed: null
+ train_mlp: true
+ train_attn: true
+ train_unembed: true
```
-- `model_path`: Path to the model being trained.
+- `model_path`: Path to the model being trained. If `tinker` is enabled, this path is only used to load the local tokenizer.
- `critic_model_path`: Optional path to a separate critic model. If empty, defaults to `model_path`.
- `custom_chat_template`: Optional custom chat template in string format. If not specified, the system will use the default chat template from tokenizer.
- `chat_template_path`: Optional path to the chat template file in jinja2 type; overrides `custom_chat_template` if set. If not specified, the system will use the default chat template from tokenizer.
@@ -175,6 +187,25 @@ model:
- `max_prompt_tokens`: Maximum number of tokens allowed in prompts. Only for `chat` and `generate` methods in `InferenceModel`.
- `min_response_tokens`: Minimum number of tokens allowed in generated responses. Only for `chat` and `generate` methods in `InferenceModel`. Default is `1`. It must be less than `max_response_tokens`.
- `enable_prompt_truncation`: Whether to truncate the prompt. Default is `true`. If set to `true`, the prompt will be truncated to `max_prompt_tokens` tokens; if set to `false`, the prompt will not be truncated and there is a risk that the prompt length plus response length exceeds `max_model_len`. This function does not work with openai api mode.
+- `repetition_penalty`: Repetition penalty factor. Default is `1.0`.
+- `lora_configs`: Optional LoRA configuration. If not specified, defaults to `null`. Currently, only one LoRA configuration is supported, and this configuration will not be applied if `tinker` is enabled.
+ - `name`: Name of the LoRA. Default is `None`.
+ - `path`: Path to the LoRA. Default is `None`.
+ - `base_model_name`: Name of the base model for LoRA. If not specified, defaults to `None`.
+ - `lora_rank`: Rank of the LoRA. Default is `32`.
+ - `lora_alpha`: Alpha value of the LoRA. Default is `32`.
+ - `lora_dtype`: Data type of the LoRA. Default is `auto`.
+ - `target_modules`: List of target modules for LoRA. Default is `all-linear`.
+- `rope_scaling`: Optional RoPE scaling configuration in JSON format. If not specified, defaults to `null`.
+- `rope_theta`: Optional RoPE theta value. If not specified, defaults to `null`.
+- `tinker`: Optional Tinker configuration (see the example following this list). Note: LoRA configuration will be ignored if Tinker is enabled.
+ - `enable`: Whether to enable Tinker. Default is `false`.
+ - `base_model`: Path to the base model for Tinker. If not specified, defaults to `model_path`.
+ - `rank`: LoRA rank controlling the size of adaptation matrices. Default is `32`.
+ - `seed`: Random seed for Tinker. If not specified, defaults to `null`.
+ - `train_mlp`: Whether to train the MLP layer. Default is `true`.
+ - `train_attn`: Whether to train the attention layer. Default is `true`.
+ - `train_unembed`: Whether to train the unembedding layer. Default is `true`.
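+
+A minimal sketch of a `model` section that trains through Tinker (the model name below is taken from the Tinker example in this repository and is purely illustrative) might look like:
+
+```yaml
+model:
+  model_path: Qwen/Qwen3-4B-Instruct-2507  # also used to load the local tokenizer
+  max_prompt_tokens: 1024
+  max_response_tokens: 2048
+  # lora_configs is ignored once tinker.enable is true
+  tinker:
+    enable: true
+    base_model: Qwen/Qwen3-4B-Instruct-2507
+    rank: 32
+```
+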
```{tip}
If you are using the openai API provided by Explorer, only `max_model_len` will take effect, and the value of `max_response_tokens`, `max_prompt_tokens`, and `min_response_tokens` will be ignored. When `max_tokens` is not independently specified, each API call will generate up to `max_model_len - prompt_length` tokens. Therefore, please ensure that the prompt length is less than `max_model_len` when using the API.
diff --git a/docs/sphinx_doc/source_zh/tutorial/trinity_configs.md b/docs/sphinx_doc/source_zh/tutorial/trinity_configs.md
index 0fff6cd91f..ec524252e4 100644
--- a/docs/sphinx_doc/source_zh/tutorial/trinity_configs.md
+++ b/docs/sphinx_doc/source_zh/tutorial/trinity_configs.md
@@ -164,9 +164,21 @@ model:
max_response_tokens: 16384
min_response_tokens: 1
enable_prompt_truncation: true
+ repetition_penalty: 1.0
+ lora_configs: null
+ rope_scaling: null
+ rope_theta: null
+ tinker:
+ enable: false
+ base_model: null
+ rank: 32
+ seed: null
+ train_mlp: true
+ train_attn: true
+ train_unembed: true
```
-- `model_path`: 被训练模型的路径。
+- `model_path`: 被训练模型的路径。如果启用了 `tinker`,则该路径仅用于加载本地 tokenizer。
- `critic_model_path`: 可选的独立 critic 模型路径。若为空,则默认为 `model_path`。
- `custom_chat_template`: 可选的自定义 chat template 字符串格式。若未指定,系统会使用 tokenizer 的默认 chat template。
- `chat_template_path`: 可选的 chat template 文件路径,类型通常为 jinja2;若设置,则覆盖 `custom_chat_template`。若未指定,系统会使用 tokenizer 的默认 chat template。
@@ -175,6 +187,25 @@ model:
- `max_response_tokens`: 模型生成的回复中允许的最大 token 数。仅对 `InferenceModel` 中的 `chat` 和 `generate` 方法生效。
- `min_response_tokens`: 模型生成的回复中允许的最小 token 数。仅对 `InferenceModel` 中的 `chat` 和 `generate` 方法生效。
- `enable_prompt_truncation`: 是否截断 prompt。默认为 `true`。若设置为 `true`,则 prompt 将被截断为 `max_prompt_tokens` 个 token;若设置为 `false`,则 prompt 不会被截断,存在 prompt 和 response 长度之和超过 `max_model_len` 的风险。在 OpenAI API 模式下不生效。
+- `repetition_penalty`:重复惩罚因子。默认值为 `1.0`。
+- `lora_configs`:可选的 LoRA 配置。若未指定,则默认为 `null`。目前仅支持一个 LoRA 配置,并且如果启用了`tinker`,则不会使用此LoRA配置。
+ - `name`:LoRA 的名称。默认为 `None`。
+ - `path`:LoRA 的路径。默认为 `None`。
+ - `base_model_name`:LoRA 所基于的基础模型名称。若未指定,则默认为 `None`。
+ - `lora_rank`:LoRA 的秩(rank)。默认为 `32`。
+ - `lora_alpha`:LoRA 的 alpha 值。默认为 `32`。
+ - `lora_dtype`:LoRA 的数据类型。默认为 `auto`。
+ - `target_modules`:LoRA 的目标模块列表。默认为 `all-linear`。
+- `rope_scaling`:可选的 RoPE 缩放配置,采用 JSON 格式。若未指定,则默认为 `null`。
+- `rope_theta`:可选的 RoPE theta 值。若未指定,则默认为 `null`。
+- `tinker`:可选的 Tinker 配置(示例见列表后)。注意:若启用 Tinker,则 LoRA 配置将被忽略。
+ - `enable`:是否启用 Tinker。默认为 `false`。
+ - `base_model`:Tinker 所使用的基础模型路径。若未指定,则默认为 `model_path`。
+ - `rank`:控制适配矩阵大小的 LoRA 秩(rank)。默认为 `32`。
+ - `seed`:Tinker 使用的随机种子。若未指定,则默认为 `null`。
+ - `train_mlp`:是否训练 MLP 层。默认为 `true`。
+ - `train_attn`:是否训练注意力层。默认为 `true`。
+ - `train_unembed`:是否训练反嵌入(unembedding)层。默认为 `true`。
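+
+下面给出一个通过 Tinker 进行训练的 `model` 配置最小示例(示例中的模型名取自本仓库的 Tinker 示例,仅作演示):
+
+```yaml
+model:
+  model_path: Qwen/Qwen3-4B-Instruct-2507  # 同时用于加载本地 tokenizer
+  max_prompt_tokens: 1024
+  max_response_tokens: 2048
+  # 启用 tinker 后,lora_configs 将被忽略
+  tinker:
+    enable: true
+    base_model: Qwen/Qwen3-4B-Instruct-2507
+    rank: 32
+```
+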
```{tip}
如果使用的是 Explorer 提供的 openai API,则只有 `max_model_len` 会生效,而 `max_response_tokens`、`max_prompt_tokens` 和 `min_response_tokens` 的值将被忽略,在没有独立指定 `max_tokens` 时,每次 API 调用将生成最多 `max_model_len - prompt_length` 个 token,因此在使用时请确保 prompt 长度小于 `max_model_len`。
diff --git a/examples/tinker/README.md b/examples/tinker/README.md
new file mode 100644
index 0000000000..79de2c959f
--- /dev/null
+++ b/examples/tinker/README.md
@@ -0,0 +1,245 @@
+# Trinity with Tinker Backend
+
+> [!NOTE]
+> This example demonstrates how to use Trinity with the [Tinker](https://thinkingmachines.ai/tinker/) backend, which enables model training on devices **without GPUs**.
+
+## Setup Instructions
+
+### 1. API Key Configuration
+Before starting Ray, you must set the `TRINITY_API_KEY` environment variable to your Tinker API key to enable proper access to Tinker's API:
+
+```bash
+export TRINITY_API_KEY=your_tinker_api_key
+```
+
+### 2. Configuration File
+Configure the Tinker backend in your YAML configuration file by setting the `model.tinker` parameters as shown below:
+
+```yaml
+model:
+ tinker:
+ enable: true
+ base_model: null
+ rank: 32
+ seed: null
+ train_mlp: true
+ train_attn: true
+ train_unembed: true
+```
+
+### 3. Configuration Parameters Explained
+
+- **`tinker`**: Tinker-specific configuration section (a combined sketch follows this list). **Important**: When Tinker is enabled, any LoRA configuration settings (`model.lora_configs`) will be ignored.
+ - **`enable`**: Whether to activate the Tinker backend. Default: `false`
+ - **`base_model`**: Path to the base model for Tinker. If not specified (`null`), it defaults to the `model_path` defined elsewhere in your config
+ - **`rank`**: The LoRA rank that controls the size of the adaptation matrices. Default: `32`
+ - **`seed`**: Random seed for reproducible Tinker operations. If not specified (`null`), no specific seed is set
+ - **`train_mlp`**: Whether to train the MLP (feed-forward) layers. Default: `true`
+ - **`train_attn`**: Whether to train the attention layers. Default: `true`
+ - **`train_unembed`**: Whether to train the unembedding (output) layer. Default: `true`
+
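+Putting the parameters above together with the rest of the `model` section, a minimal sketch (mirroring the [`tinker.yaml`](tinker.yaml) example in this directory) looks like:
+
+```yaml
+model:
+  model_path: Qwen/Qwen3-4B-Instruct-2507  # local tokenizer; Tinker trains the remote base_model
+  max_prompt_tokens: 1024
+  max_response_tokens: 2048
+  tinker:
+    enable: true
+    base_model: Qwen/Qwen3-4B-Instruct-2507
+```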
+
+## Usage
+
+Once configured, Trinity works with the Tinker backend just like it does with the standard veRL backend. Start training with:
+
+```bash
+trinity run --config tinker.yaml # Replace with your actual config file path
+```
+
+### Important Limitations of the Tinker Backend
+
+1. **Entropy loss**: the entropy reported by the Tinker backend is only an estimate, so it is not directly comparable to the veRL backend.
+2. **Algorithms requiring `compute_advantage_in_trainer=true` are currently NOT supported**, including:
+ - PPO (`algorithm.algorithm_type=ppo`)
+ - Reinforce++ (`algorithm.algorithm_type=reinforceplusplus`)
+ - RLOO (`algorithm.algorithm_type=rloo`)
+ - On-policy distillation (`algorithm.algorithm_type=on_policy_distill`)
+
+   Algorithms such as GRPO (`algorithm.algorithm_type=grpo`) are supported; see the sketch after this list. Support for the algorithms listed above will be added in the future.
+3. **Multi-stage training** is not supported yet; support will be added in the future.
+
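+For reference, the GRPO settings used throughout this example (a sketch of the `algorithm` section from the configurations below) are:
+
+```yaml
+algorithm:
+  algorithm_type: grpo  # supported by the Tinker backend
+  repeat_times: 8
+  optimizer:
+    lr: 1.0e-05
+```
+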
+> 💡 A complete example configuration file is available at [`tinker.yaml`](tinker.yaml).
+
+
+## Results on the Llama-3.2-3B Model
+
+We trained the **Llama-3.2-3B** model on the **GSM8K** dataset using both the **Tinker** and **veRL** backends. Below are the full configuration files used in our experiments.
+
+
+<details>
+<summary>Click to expand: Tinker Backend Configuration</summary>
+
+```yaml
+mode: both
+project: Trinity-RFT-gsm8k
+group: alignment-tinker
+name: tinker-llama3.2-3B-off1
+checkpoint_root_dir: ${oc.env:TRINITY_CHECKPOINT_ROOT_DIR,./checkpoints}
+algorithm:
+ algorithm_type: grpo
+ repeat_times: 8
+ sample_strategy: default
+ kl_loss_fn_args:
+ kl_coef: 0.0
+ optimizer:
+ lr: 1.0e-05
+ lr_warmup_steps_ratio: 0.0
+ warmup_style: constant
+data_processor: {}
+model:
+ model_path: meta-llama/Llama-3.2-3B
+ max_prompt_tokens: 1024
+ max_response_tokens: 2048
+ custom_chat_template: "{{- bos_token }}\n{%- if custom_tools is defined %}\n {%- set tools = custom_tools %}\n{%- endif %}\n{%- if not tools_in_user_message is defined %}\n {%- set tools_in_user_message = true %}\n{%- endif %}\n{%- if not date_string is defined %}\n {%- if strftime_now is defined %}\n {%- set date_string = strftime_now(\"%d %b %Y\") %}\n {%- else %}\n {%- set date_string = \"26 Jul 2024\" %}\n {%- endif %}\n{%- endif %}\n{%- if not tools is defined %}\n {%- set tools = none %}\n{%- endif %}\n\n{#- This block extracts the system message, so we can slot it into the right place. #}\n{%- if messages[0]['role'] == 'system' %}\n {%- set system_message = messages[0]['content']|trim %}\n {%- set messages = messages[1:] %}\n{%- else %}\n {%- set system_message = \"\" %}\n{%- endif %}\n\n{#- System message #}\n{{- \"<|start_header_id|>system<|end_header_id|>\\n\\n\" }}\n{%- if tools is not none %}\n {{- \"Environment: ipython\\n\" }}\n{%- endif %}\n{{- \"Cutting Knowledge Date: December 2023\\n\" }}\n{{- \"Today Date: \" + date_string + \"\\n\\n\" }}\n{%- if tools is not none and not tools_in_user_message %}\n {{- \"You have access to the following functions. To call a function, please respond with JSON for a function call.\" }}\n {{- 'Respond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.' }}\n {{- \"Do not use variables.\\n\\n\" }}\n {%- for t in tools %}\n {{- t | tojson(indent=4) }}\n {{- \"\\n\\n\" }}\n {%- endfor %}\n{%- endif %}\n{{- system_message }}\n{{- \"<|eot_id|>\" }}\n\n{#- Custom tools are passed in a user message with some extra guidance #}\n{%- if tools_in_user_message and not tools is none %}\n {#- Extract the first user message so we can plug it in here #}\n {%- if messages | length != 0 %}\n {%- set first_user_message = messages[0]['content']|trim %}\n {%- set messages = messages[1:] %}\n {%- else %}\n {{- raise_exception(\"Cannot put tools in the first user message when there's no first user message!\") }}\n{%- endif %}\n {{- '<|start_header_id|>user<|end_header_id|>\\n\\n' -}}\n {{- \"Given the following functions, please respond with a JSON for a function call \" }}\n {{- \"with its proper arguments that best answers the given prompt.\\n\\n\" }}\n {{- 'Respond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.' 
}}\n {{- \"Do not use variables.\\n\\n\" }}\n {%- for t in tools %}\n {{- t | tojson(indent=4) }}\n {{- \"\\n\\n\" }}\n {%- endfor %}\n {{- first_user_message + \"<|eot_id|>\"}}\n{%- endif %}\n\n{%- for message in messages %}\n {%- if not (message.role == 'ipython' or message.role == 'tool' or 'tool_calls' in message) %}\n {{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\\n\\n'+ message['content'] | trim + '<|eot_id|>' }}\n {%- elif 'tool_calls' in message %}\n {%- if not message.tool_calls|length == 1 %}\n {{- raise_exception(\"This model only supports single tool-calls at once!\") }}\n {%- endif %}\n {%- set tool_call = message.tool_calls[0].function %}\n {{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' -}}\n {{- '{\"name\": \"' + tool_call.name + '\", ' }}\n {{- '\"parameters\": ' }}\n {{- tool_call.arguments | tojson }}\n {{- \"}\" }}\n {{- \"<|eot_id|>\" }}\n {%- elif message.role == \"tool\" or message.role == \"ipython\" %}\n {{- \"<|start_header_id|>ipython<|end_header_id|>\\n\\n\" }}\n {%- if message.content is mapping or message.content is iterable %}\n {{- message.content | tojson }}\n {%- else %}\n {{- message.content }}\n {%- endif %}\n {{- \"<|eot_id|>\" }}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' }}\n{%- endif %}\n"
+ tinker:
+ enable: true
+ base_model: meta-llama/Llama-3.2-3B
+cluster:
+ node_num: 1
+ gpu_per_node: 8
+buffer:
+ batch_size: 96
+ total_epochs: 1
+ explorer_input:
+ taskset:
+ name: taskset
+ storage_type: file
+ path: openai/gsm8k
+ split: train
+ subset_name: main
+ format:
+ prompt_key: question
+ response_key: answer
+ rollout_args:
+ temperature: 1.0
+ logprobs: 0
+ eval_tasksets: []
+ default_workflow_type: math_workflow
+ trainer_input:
+ experience_buffer:
+ name: experience_buffer
+ storage_type: queue
+ replay_buffer:
+ enable: false
+explorer:
+ runner_per_model: 16
+ rollout_model:
+ engine_num: 4
+ seed: 42
+ auxiliary_models: []
+ eval_interval: 1000
+trainer:
+ save_interval: 100
+ enable_preview: true
+ grad_clip: 1.0
+ max_token_len_per_gpu: 16384
+monitor:
+ monitor_type: wandb
+synchronizer:
+ sync_method: checkpoint
+ sync_style: fixed
+ sync_interval: 1
+ sync_offset: 1
+ sync_timeout: 1200
+```
+
+</details>
+
+<details>
+<summary>Click to expand: veRL Backend Configuration (LoRA)</summary>
+
+```yaml
+mode: both
+project: Trinity-RFT-gsm8k
+group: alignment-tinker
+name: verl-llama3.2-3B-lora-off1
+checkpoint_root_dir: ${oc.env:TRINITY_CHECKPOINT_ROOT_DIR,./checkpoints}
+algorithm:
+ algorithm_type: grpo
+ repeat_times: 8
+ sample_strategy: default
+ kl_loss_fn_args:
+ kl_coef: 0.0
+ optimizer:
+ lr: 1.0e-05
+ lr_warmup_steps_ratio: 0.0
+ warmup_style: constant
+data_processor: {}
+model:
+ model_path: meta-llama/Llama-3.2-3B
+ max_prompt_tokens: 1024
+ max_response_tokens: 2048
+ custom_chat_template: "{{- bos_token }}\n{%- if custom_tools is defined %}\n {%- set tools = custom_tools %}\n{%- endif %}\n{%- if not tools_in_user_message is defined %}\n {%- set tools_in_user_message = true %}\n{%- endif %}\n{%- if not date_string is defined %}\n {%- if strftime_now is defined %}\n {%- set date_string = strftime_now(\"%d %b %Y\") %}\n {%- else %}\n {%- set date_string = \"26 Jul 2024\" %}\n {%- endif %}\n{%- endif %}\n{%- if not tools is defined %}\n {%- set tools = none %}\n{%- endif %}\n\n{#- This block extracts the system message, so we can slot it into the right place. #}\n{%- if messages[0]['role'] == 'system' %}\n {%- set system_message = messages[0]['content']|trim %}\n {%- set messages = messages[1:] %}\n{%- else %}\n {%- set system_message = \"\" %}\n{%- endif %}\n\n{#- System message #}\n{{- \"<|start_header_id|>system<|end_header_id|>\\n\\n\" }}\n{%- if tools is not none %}\n {{- \"Environment: ipython\\n\" }}\n{%- endif %}\n{{- \"Cutting Knowledge Date: December 2023\\n\" }}\n{{- \"Today Date: \" + date_string + \"\\n\\n\" }}\n{%- if tools is not none and not tools_in_user_message %}\n {{- \"You have access to the following functions. To call a function, please respond with JSON for a function call.\" }}\n {{- 'Respond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.' }}\n {{- \"Do not use variables.\\n\\n\" }}\n {%- for t in tools %}\n {{- t | tojson(indent=4) }}\n {{- \"\\n\\n\" }}\n {%- endfor %}\n{%- endif %}\n{{- system_message }}\n{{- \"<|eot_id|>\" }}\n\n{#- Custom tools are passed in a user message with some extra guidance #}\n{%- if tools_in_user_message and not tools is none %}\n {#- Extract the first user message so we can plug it in here #}\n {%- if messages | length != 0 %}\n {%- set first_user_message = messages[0]['content']|trim %}\n {%- set messages = messages[1:] %}\n {%- else %}\n {{- raise_exception(\"Cannot put tools in the first user message when there's no first user message!\") }}\n{%- endif %}\n {{- '<|start_header_id|>user<|end_header_id|>\\n\\n' -}}\n {{- \"Given the following functions, please respond with a JSON for a function call \" }}\n {{- \"with its proper arguments that best answers the given prompt.\\n\\n\" }}\n {{- 'Respond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.' 
}}\n {{- \"Do not use variables.\\n\\n\" }}\n {%- for t in tools %}\n {{- t | tojson(indent=4) }}\n {{- \"\\n\\n\" }}\n {%- endfor %}\n {{- first_user_message + \"<|eot_id|>\"}}\n{%- endif %}\n\n{%- for message in messages %}\n {%- if not (message.role == 'ipython' or message.role == 'tool' or 'tool_calls' in message) %}\n {{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\\n\\n'+ message['content'] | trim + '<|eot_id|>' }}\n {%- elif 'tool_calls' in message %}\n {%- if not message.tool_calls|length == 1 %}\n {{- raise_exception(\"This model only supports single tool-calls at once!\") }}\n {%- endif %}\n {%- set tool_call = message.tool_calls[0].function %}\n {{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' -}}\n {{- '{\"name\": \"' + tool_call.name + '\", ' }}\n {{- '\"parameters\": ' }}\n {{- tool_call.arguments | tojson }}\n {{- \"}\" }}\n {{- \"<|eot_id|>\" }}\n {%- elif message.role == \"tool\" or message.role == \"ipython\" %}\n {{- \"<|start_header_id|>ipython<|end_header_id|>\\n\\n\" }}\n {%- if message.content is mapping or message.content is iterable %}\n {{- message.content | tojson }}\n {%- else %}\n {{- message.content }}\n {%- endif %}\n {{- \"<|eot_id|>\" }}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' }}\n{%- endif %}\n"
+ lora_configs:
+ - name: lora
+ lora_rank: 32
+ lora_alpha: 32
+cluster:
+ node_num: 1
+ gpu_per_node: 8
+buffer:
+ batch_size: 96
+ total_epochs: 1
+ explorer_input:
+ taskset:
+ name: taskset
+ storage_type: file
+ path: openai/gsm8k
+ split: train
+ subset_name: main
+ format:
+ prompt_key: question
+ response_key: answer
+ rollout_args:
+ temperature: 1.0
+ logprobs: 0
+ eval_tasksets: []
+ default_workflow_type: math_workflow
+ trainer_input:
+ experience_buffer:
+ name: experience_buffer
+ storage_type: queue
+ replay_buffer:
+ enable: false
+explorer:
+ runner_per_model: 16
+ rollout_model:
+ engine_num: 4
+ tensor_parallel_size: 1
+ enforce_eager: false
+ enable_prefix_caching: false
+ enable_chunked_prefill: false
+ gpu_memory_utilization: 0.9
+ dtype: bfloat16
+ seed: 42
+ enable_thinking: false
+ enable_history: false
+ enable_openai_api: false
+ enable_auto_tool_choice: false
+ tool_call_parser: null
+ reasoning_parser: null
+ auxiliary_models: []
+ eval_interval: 1000
+trainer:
+ trainer_type: verl
+ save_interval: 100
+ enable_preview: true
+ grad_clip: 1.0
+ max_token_len_per_gpu: 16384
+monitor:
+ monitor_type: wandb
+synchronizer:
+ sync_method: checkpoint
+ sync_style: fixed
+ sync_interval: 1
+ sync_offset: 1
+ sync_timeout: 1200
+```
+
+</details>
+
+### Observations
+
+Since Llama-3.2-3B is a base (non-instruct-tuned) model, it has limited ability to follow formatting instructions. Additionally, we trained for only **one epoch**. As a result, both backends achieved final rewards just slightly above 0.1. Nonetheless, the training curves show a clear upward trend in reward, indicating successful learning. The results are visualized below:
+![Reward curves of the Tinker and veRL backends on GSM8K](../../docs/sphinx_doc/assets/tinker-gsm8k.png)
+
diff --git a/examples/tinker/tinker.yaml b/examples/tinker/tinker.yaml
new file mode 100644
index 0000000000..744357e745
--- /dev/null
+++ b/examples/tinker/tinker.yaml
@@ -0,0 +1,67 @@
+mode: both
+project: Trinity-RFT-gsm8k
+name: tinker-Qwen3-4B
+checkpoint_root_dir: ${oc.env:TRINITY_CHECKPOINT_ROOT_DIR,./checkpoints}
+algorithm:
+ algorithm_type: grpo
+ repeat_times: 8
+ sample_strategy: default
+ kl_loss_fn_args:
+ kl_coef: 0.0
+ optimizer:
+ lr: 1.0e-05
+ lr_warmup_steps_ratio: 0.0
+ warmup_style: constant
+data_processor: {}
+model:
+ model_path: Qwen/Qwen3-4B-Instruct-2507
+ max_prompt_tokens: 1024
+ max_response_tokens: 2048
+ tinker:
+ enable: true
+ base_model: Qwen/Qwen3-4B-Instruct-2507
+buffer:
+ batch_size: 96
+ total_epochs: 1
+ explorer_input:
+ taskset:
+ name: taskset
+ storage_type: file
+ path: openai/gsm8k
+ split: train
+ subset_name: main
+ format:
+ prompt_key: question
+ response_key: answer
+ rollout_args:
+ temperature: 1.0
+ logprobs: 0
+ eval_tasksets: []
+ default_workflow_type: math_workflow
+ trainer_input:
+ experience_buffer:
+ name: experience_buffer
+ storage_type: queue
+ replay_buffer:
+ enable: false
+explorer:
+ runner_per_model: 8
+ rollout_model:
+ engine_num: 4
+ seed: 42
+ auxiliary_models: []
+ eval_interval: 1000
+trainer:
+ save_interval: 100
+ enable_preview: true
+ grad_clip: 1.0
+ max_token_len_per_gpu: 16384
+monitor:
+ monitor_type: tensorboard
+synchronizer:
+ sync_method: memory
+ sync_style: fixed
+ sync_interval: 1
+ sync_timeout: 1200
+log:
+ level: INFO
diff --git a/pyproject.toml b/pyproject.toml
index b7e3227a0a..8a8faae156 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -21,9 +21,7 @@ classifiers = [
]
requires-python = ">=3.10,<3.13"
dependencies = [
- "verl==0.5.0",
"ray[default]>=2.50.0",
- "vllm>=0.10.2,<=0.11.0",
"tensordict",
"wandb",
"omegaconf",
@@ -43,12 +41,17 @@ dependencies = [
"sortedcontainers",
"word2number",
"transformers",
+ "datasets",
]
[project.scripts]
trinity = "trinity.cli.launcher:main"
[project.optional-dependencies]
+verl = [
+ "verl==0.5.0",
+ "vllm>=0.10.2,<=0.11.0",
+]
data = [
"py-data-juicer>=1.4.3"
]
@@ -78,6 +81,9 @@ megatron = [
"transformer_engine[pytorch]==2.8.0",
"mbridge>=0.13.0",
]
+tinker = [
+ "tinker", # tinker requires python>=3.11
+]
doc = [
"sphinx",
diff --git a/tests/buffer/sample_strategy_test.py b/tests/buffer/sample_strategy_test.py
index 32ea84bdb7..c3a9af9179 100644
--- a/tests/buffer/sample_strategy_test.py
+++ b/tests/buffer/sample_strategy_test.py
@@ -58,7 +58,9 @@ def _init_buffer_writer_and_sample_strategy(self):
async def _verify_model_version(self, step, expected_versions):
batch, metrics, _ = await self.sample_strategy.sample(step=step)
self.assertEqual(
- batch.rewards.tolist(), expected_versions, f"Model versions mismatch at step {step}"
+ [exp.reward for exp in batch],
+ expected_versions,
+ f"Model versions mismatch at step {step}",
)
self.assertEqual(
metrics["sample/model_version/min"],
diff --git a/tests/common/vllm_test.py b/tests/common/vllm_test.py
index 61fd03e675..187c0eb544 100644
--- a/tests/common/vllm_test.py
+++ b/tests/common/vllm_test.py
@@ -125,17 +125,13 @@ def setUp(self):
self.config.algorithm.repeat_times = self.repeat_times
self.config.explorer.rollout_model.enable_history = self.enable_history
self.config.check_and_update()
- from pprint import pprint
- pprint(self.config)
self.engines, self.auxiliary_engines = create_inference_models(self.config)
self.model_wrapper = ModelWrapper(
self.engines[0], engine_type="vllm", enable_history=self.enable_history
)
- async def test_generate(
- self,
- ):
+ async def test_generate(self):
await prepare_engines(self.engines, self.auxiliary_engines)
await self.model_wrapper.prepare()
self.assertEqual(self.model_wrapper.model_path, self.config.model.model_path)
diff --git a/tests/trainer/trainer_test.py b/tests/trainer/trainer_test.py
index 972071b99c..5854a04f8f 100644
--- a/tests/trainer/trainer_test.py
+++ b/tests/trainer/trainer_test.py
@@ -1322,3 +1322,38 @@ def test_trainer(self):
def tearDown(self):
# remove dir only when the test passed
shutil.rmtree(self.config.checkpoint_job_dir)
+
+
+class TestTinkerTrainer(BaseTrainerCase):
+ @unittest.skip("Require tinker API key")
+ def test_trainer(self):
+ """Test GSM8K on tinker."""
+ # test both mode
+ self.config.algorithm.algorithm_type = "grpo"
+ self.config.algorithm.repeat_times = 4
+ self.config.algorithm.advantage_fn = "grpo"
+ self.config.algorithm.advantage_fn_args = {
+ "epsilon": 1e-6,
+ }
+ self.config.buffer.total_epochs = 1
+ self.config.buffer.explorer_input.taskset = get_unittest_dataset_config("gsm8k")
+ self.config.model.tinker.enable = True
+ self.config.model.tinker.base_model = "Qwen/Qwen3-4B-Instruct-2507"
+ self.config.check_and_update()
+ both(self.config)
+ parser = TensorBoardParser(os.path.join(self.config.monitor.cache_dir, "tensorboard"))
+ rollout_metrics = parser.metric_list("rollout")
+ self.assertTrue(len(rollout_metrics) > 0)
+ pipeline_metrics = parser.metric_list("experience_pipeline")
+ self.assertTrue(len(pipeline_metrics) > 0)
+ self.assertEqual(parser.metric_max_step(rollout_metrics[0]), 4)
+ actor_metrics = parser.metric_list("actor")
+ self.assertTrue(len(actor_metrics) > 0)
+ self.assertEqual(parser.metric_max_step(actor_metrics[0]), 4)
+ response_metrics = parser.metric_list("response_length")
+ self.assertTrue(len(response_metrics) > 0)
+ self.assertEqual(parser.metric_max_step(response_metrics[0]), 4)
+
+ def tearDown(self):
+ # remove dir only when the test passed
+ shutil.rmtree(self.config.checkpoint_job_dir)
diff --git a/trinity/algorithm/advantage_fn/asymre_advantage.py b/trinity/algorithm/advantage_fn/asymre_advantage.py
index 52f2314272..5c2ccf9932 100644
--- a/trinity/algorithm/advantage_fn/asymre_advantage.py
+++ b/trinity/algorithm/advantage_fn/asymre_advantage.py
@@ -1,10 +1,12 @@
"""AsymRE advantage computation"""
from collections import defaultdict
-from typing import Dict, List, Tuple
+from typing import TYPE_CHECKING, Dict, List, Tuple
import torch
-from verl import DataProto
+
+if TYPE_CHECKING:
+ from verl import DataProto
from trinity.algorithm.advantage_fn import AdvantageFn, GroupAdvantage
from trinity.common.experience import Experience, group_by
@@ -23,9 +25,9 @@ def __init__(
def __call__(
self,
- exps: DataProto,
+ exps: "DataProto",
**kwargs,
- ) -> Tuple[DataProto, Dict]:
+ ) -> Tuple["DataProto", Dict]:
"""Modified from compute_grpo_outcome_advantage
Compute advantage for AsymRE, operating only on Outcome reward
diff --git a/trinity/algorithm/advantage_fn/grpo_advantage.py b/trinity/algorithm/advantage_fn/grpo_advantage.py
index 7d1c58977d..4562a5a4b9 100644
--- a/trinity/algorithm/advantage_fn/grpo_advantage.py
+++ b/trinity/algorithm/advantage_fn/grpo_advantage.py
@@ -3,10 +3,12 @@
import copy
from collections import defaultdict
-from typing import Dict, List, Optional, Tuple
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
import torch
-from verl import DataProto
+
+if TYPE_CHECKING:
+ from verl import DataProto
from trinity.algorithm.advantage_fn.advantage_fn import AdvantageFn, GroupAdvantage
from trinity.common.experience import Experience, group_by
@@ -26,9 +28,9 @@ def __init__(
def __call__(
self,
- exps: DataProto,
+ exps: "DataProto",
**kwargs,
- ) -> Tuple[DataProto, Dict]:
+ ) -> Tuple["DataProto", Dict]:
"""
Compute advantage for GRPO, operating only on Outcome reward
(with only one scalar reward for each response).
diff --git a/trinity/algorithm/advantage_fn/opmd_advantage.py b/trinity/algorithm/advantage_fn/opmd_advantage.py
index d5e9203e3c..8c0a586986 100644
--- a/trinity/algorithm/advantage_fn/opmd_advantage.py
+++ b/trinity/algorithm/advantage_fn/opmd_advantage.py
@@ -1,10 +1,12 @@
"""OPMD advantage computation"""
from collections import defaultdict
-from typing import Dict, List, Tuple
+from typing import TYPE_CHECKING, Dict, List, Tuple
import torch
-from verl import DataProto
+
+if TYPE_CHECKING:
+ from verl import DataProto
from trinity.algorithm.advantage_fn.advantage_fn import AdvantageFn, GroupAdvantage
from trinity.common.experience import Experience, group_by
@@ -25,9 +27,9 @@ def __init__(
def __call__(
self,
- exps: DataProto,
+ exps: "DataProto",
**kwargs,
- ) -> Tuple[DataProto, Dict]:
+ ) -> Tuple["DataProto", Dict]:
"""Modified from compute_grpo_outcome_advantage
Compute advantage for OPMD, operating only on Outcome reward
diff --git a/trinity/algorithm/key_mapper.py b/trinity/algorithm/key_mapper.py
index 09c1f988a6..4813c18fb5 100644
--- a/trinity/algorithm/key_mapper.py
+++ b/trinity/algorithm/key_mapper.py
@@ -26,4 +26,5 @@ def from_trinity(self, key: str) -> str:
"advantages": "advantages",
}
),
+ "tinker": KeyMapper({}),
}
diff --git a/trinity/algorithm/sample_strategy/mix_sample_strategy.py b/trinity/algorithm/sample_strategy/mix_sample_strategy.py
index 65acc44d3c..8f0f6ed3d5 100644
--- a/trinity/algorithm/sample_strategy/mix_sample_strategy.py
+++ b/trinity/algorithm/sample_strategy/mix_sample_strategy.py
@@ -8,7 +8,7 @@
from trinity.algorithm.sample_strategy.utils import representative_sample
from trinity.buffer import get_buffer_reader
from trinity.common.config import BufferConfig
-from trinity.common.experience import CustomField, Experiences
+from trinity.common.experience import CustomField, Experience
from trinity.utils.timer import Timer
@@ -53,7 +53,7 @@ def __init__(self, buffer_config: BufferConfig, **kwargs):
expert_buffer_config,
)
- async def sample(self, step: int) -> Tuple[Experiences, Dict, List]:
+ async def sample(self, step: int) -> Tuple[List[Experience], Dict, List]:
metrics = {}
with Timer(metrics, "time/read_experience"):
usual_exp_list = await self.usual_exp_buffer.read_async()
@@ -82,24 +82,21 @@ async def sample(self, step: int) -> Tuple[Experiences, Dict, List]:
repr_samples = representative_sample(exp_list)
self.set_model_version_metric(exp_list, metrics)
- with Timer(metrics, "time/gather_experience"):
- exps = Experiences.gather_experiences(
- experiences=exp_list,
- pad_token_id=self.pad_token_id, # type: ignore [arg-type]
- custom_fields=[
- CustomField(
- source_field="is_expert",
- destination_field="expert_mask",
- data_type=torch.bool,
- ),
- CustomField(
- source_field="step",
- destination_field="step",
- data_type=torch.int32,
- ),
- ],
- ) # type: ignore
- return exps, metrics, repr_samples
+ custom_fields = [
+ CustomField(
+ source_field="is_expert",
+ destination_field="expert_mask",
+ data_type=torch.bool,
+ ),
+ CustomField(
+ source_field="step",
+ destination_field="step",
+ data_type=torch.int32,
+ ),
+ ]
+ for exp in exp_list:
+ exp.custom_fields = custom_fields
+ return exp_list, metrics, repr_samples
@classmethod
def default_args(cls) -> Dict:
diff --git a/trinity/algorithm/sample_strategy/sample_strategy.py b/trinity/algorithm/sample_strategy/sample_strategy.py
index 2ab63032cb..2398961ae6 100644
--- a/trinity/algorithm/sample_strategy/sample_strategy.py
+++ b/trinity/algorithm/sample_strategy/sample_strategy.py
@@ -4,7 +4,7 @@
from trinity.algorithm.sample_strategy.utils import representative_sample
from trinity.buffer import get_buffer_reader
from trinity.common.config import BufferConfig
-from trinity.common.experience import Experience, Experiences
+from trinity.common.experience import Experience
from trinity.utils.annotations import Deprecated
from trinity.utils.monitor import gather_metrics
from trinity.utils.timer import Timer
@@ -12,7 +12,7 @@
class SampleStrategy(ABC):
def __init__(self, buffer_config: BufferConfig, **kwargs) -> None:
- self.pad_token_id = buffer_config.pad_token_id
+ pass
def set_model_version_metric(self, exp_list: List[Experience], metrics: Dict):
metric_list = [
@@ -23,14 +23,14 @@ def set_model_version_metric(self, exp_list: List[Experience], metrics: Dict):
metrics.update(gather_metrics(metric_list, "sample"))
@abstractmethod
- async def sample(self, step: int) -> Tuple[Experiences, Dict, List]:
+ async def sample(self, step: int) -> Tuple[List[Experience], Dict, List]:
"""Sample data from buffer.
Args:
step (`int`): The step number of current step.
Returns:
- `Experiences`: The sampled Experiences data.
+            `List[Experience]`: The sampled list of experiences.
`Dict`: Metrics for logging.
`List`: Representative data for logging.
"""
@@ -54,15 +54,13 @@ def __init__(self, buffer_config: BufferConfig, **kwargs):
super().__init__(buffer_config)
self.exp_buffer = get_buffer_reader(buffer_config.trainer_input.experience_buffer) # type: ignore[arg-type]
- async def sample(self, step: int, **kwargs) -> Tuple[Experiences, Dict, List]:
+ async def sample(self, step: int, **kwargs) -> Tuple[List[Experience], Dict, List]:
metrics = {}
with Timer(metrics, "time/read_experience"):
exp_list = await self.exp_buffer.read_async()
repr_samples = representative_sample(exp_list)
self.set_model_version_metric(exp_list, metrics)
- with Timer(metrics, "time/gather_experience"):
- exps = Experiences.gather_experiences(exp_list, self.pad_token_id) # type: ignore
- return exps, metrics, repr_samples
+ return exp_list, metrics, repr_samples
@classmethod
def default_args(cls) -> dict:
@@ -81,16 +79,14 @@ def __init__(self, buffer_config: BufferConfig, **kwargs):
super().__init__(buffer_config)
self.max_staleness = kwargs.get("max_staleness", float("inf"))
- async def sample(self, step: int, **kwargs) -> Tuple[Experiences, Dict, List]:
+ async def sample(self, step: int, **kwargs) -> Tuple[List[Experience], Dict, List]:
min_model_version = max(step - self.max_staleness, 0)
metrics = {}
with Timer(metrics, "time/read_experience"):
exp_list = await self.exp_buffer.read_async(min_model_version=min_model_version)
repr_samples = representative_sample(exp_list)
self.set_model_version_metric(exp_list, metrics)
- with Timer(metrics, "time/gather_experience"):
- exps = Experiences.gather_experiences(exp_list, self.pad_token_id) # type: ignore
- return exps, metrics, repr_samples
+ return exp_list, metrics, repr_samples
@Deprecated
diff --git a/trinity/buffer/reader/queue_reader.py b/trinity/buffer/reader/queue_reader.py
index b3b1d14c12..4a9977dcc0 100644
--- a/trinity/buffer/reader/queue_reader.py
+++ b/trinity/buffer/reader/queue_reader.py
@@ -18,6 +18,7 @@ def __init__(self, config: StorageConfig):
self.timeout = config.max_read_timeout
self.read_batch_size = config.batch_size
self.queue = QueueStorage.get_wrapper(config)
+ ray.get(self.queue.acquire.remote())
def read(self, batch_size: Optional[int] = None, **kwargs) -> List:
try:
@@ -47,3 +48,6 @@ def state_dict(self) -> Dict:
def load_state_dict(self, state_dict):
# Queue Not supporting state dict yet
return None
+
+ def __del__(self):
+ ray.get(self.queue.release.remote())
diff --git a/trinity/buffer/reader/sql_reader.py b/trinity/buffer/reader/sql_reader.py
index f7572c628c..ea83ba2dc0 100644
--- a/trinity/buffer/reader/sql_reader.py
+++ b/trinity/buffer/reader/sql_reader.py
@@ -17,6 +17,7 @@ def __init__(self, config: StorageConfig) -> None:
assert config.storage_type == StorageType.SQL.value
self.wrap_in_ray = config.wrap_in_ray
self.storage = SQLStorage.get_wrapper(config)
+ ray.get(self.storage.acquire.remote())
def read(self, batch_size: Optional[int] = None, **kwargs) -> List:
if self.wrap_in_ray:
@@ -40,3 +41,6 @@ def state_dict(self) -> Dict:
def load_state_dict(self, state_dict):
# SQL Not supporting state dict yet
return None
+
+ def __del__(self):
+ ray.get(self.storage.release.remote())
diff --git a/trinity/cli/launcher.py b/trinity/cli/launcher.py
index d9c0d95771..28ba57ecc8 100644
--- a/trinity/cli/launcher.py
+++ b/trinity/cli/launcher.py
@@ -176,9 +176,9 @@ def run(config_path: str, dlc: bool = False, plugin_dir: str = None):
raise RuntimeError("Ray is not running, please start it by `ray start --head`.")
try:
- from trinity.trainer.verl.utils import get_latest_hf_checkpoint_path
-
if config.stages:
+ from trinity.trainer.verl.utils import get_latest_hf_checkpoint_path
+
state_manager = StateManager(
path=os.path.join(config.checkpoint_root_dir, config.project, config.name)
)
diff --git a/trinity/common/config.py b/trinity/common/config.py
index 2a43b2e235..8fe5f23740 100644
--- a/trinity/common/config.py
+++ b/trinity/common/config.py
@@ -428,6 +428,17 @@ class DataProcessorConfig:
)
+@dataclass
+class TinkerConfig:
+    enable: bool = False  # whether to route training and rollout through the Tinker backend
+    base_model: Optional[str] = None  # defaults to `model.model_path` when unset
+    rank: int = 32  # LoRA rank
+    seed: Optional[int] = None
+    train_mlp: bool = True  # train the MLP (feed-forward) layers
+    train_attn: bool = True  # train the attention layers
+    train_unembed: bool = True  # train the unembedding (output) layer
+
+
@dataclass
class ModelConfig:
# source model path
@@ -472,11 +483,15 @@ class ModelConfig:
rope_scaling: Optional[dict] = None
rope_theta: Optional[float] = None
+ # tinker config
+ tinker: TinkerConfig = field(default_factory=TinkerConfig)
+
@dataclass
class InferenceModelConfig:
# ! DO NOT SET in explorer.rollout_model, automatically set from config.model.model_path
model_path: Optional[str] = None
+ tinker_base_model: Optional[str] = None
engine_type: str = "vllm"
engine_num: int = 1
@@ -1149,6 +1164,9 @@ def _check_model(self) -> None:
if not model.critic_model_path:
model.critic_model_path = model.model_path
+ if model.tinker.enable:
+ self._check_tinker()
+
# check template
if model.chat_template_path is not None and model.custom_chat_template is None:
try:
@@ -1160,7 +1178,57 @@ def _check_model(self) -> None:
)
# check max_model_len, max_prompt_tokens, max_response_tokens
+ self._check_model_len()
+
+ def _check_tinker(self) -> None:
+ model = self.model
+ from trinity.algorithm import ALGORITHM_TYPE
+
+ algorithm = ALGORITHM_TYPE.get(self.algorithm.algorithm_type)
+ if algorithm.use_critic:
+ raise ValueError("Critic model is not supported when using tinker!")
+
+ set_if_none(model.tinker, "base_model", model.model_path)
+
+ import tinker
+
+ service_client = tinker.ServiceClient()
+ supported_models = {
+ item.model_name for item in service_client.get_server_capabilities().supported_models
+ }
+ if model.tinker.base_model not in supported_models:
+ logger.error(f"Supported models: {supported_models}")
+ raise ValueError(f"{model.tinker.base_model} is not supported by tinker!")
+ if model.tinker.base_model != model.model_path:
+ logger.warning(
+ f"The local tokenizer will use {model.model_path}, while tinker will use {model.tinker.base_model}"
+ )
+
+ if (
+ self.algorithm.entropy_loss_fn != "none"
+ and self.algorithm.entropy_loss_fn_args.get("entropy_coef", 0.0) != 0.0
+ ):
+ logger.warning(
+ "The entropy in Tinker trainer is an estimated value; "
+ "it is recommended to set `entropy_coef` to 0."
+ )
+
+ if self.explorer.rollout_model.engine_type != "tinker":
+ self.explorer.rollout_model.engine_type = "tinker"
+ logger.warning("Rollout model engine type is set to `tinker`.")
+
+ if self.trainer.trainer_type != "tinker":
+ self.trainer.trainer_type = "tinker"
+ logger.warning("Trainer type is set to `tinker`.")
+
+ if self.synchronizer.sync_method == SyncMethod.NCCL:
+ self.synchronizer.sync_method = SyncMethod.CHECKPOINT
+ logger.warning(
+                "Tinker does not support NCCL; `synchronizer.sync_method` is set to `checkpoint`."
+ )
+ def _check_model_len(self) -> None:
+ model = self.model
# if all three are set, check if they are valid
if (
model.max_model_len is not None
@@ -1225,6 +1293,107 @@ def _check_model(self) -> None:
"`enable_prompt_truncation` is set to False; please make sure the prompt is not too long and `max_model_len` is large enough, otherwise prompt length + response length may exceed `max_model_len`!"
)
+ def _check_explorer(self) -> None:
+ rollout_args = ["temperature", "top_p", "top_k", "logprobs", "repetition_penalty"]
+ length_args = [
+ "max_model_len",
+ "max_prompt_tokens",
+ "max_response_tokens",
+ "min_response_tokens",
+ "enable_prompt_truncation",
+ ]
+ rope_args = ["rope_scaling", "rope_theta"]
+ model_args = rollout_args + length_args + rope_args
+ set_if_none(self.explorer.rollout_model, "model_path", self.model.model_path)
+ for args in model_args:
+ set_if_none(self.explorer.rollout_model, args, getattr(self.model, args))
+ if (
+ self.explorer.rollout_model.chat_template is None
+ and self.model.custom_chat_template is not None
+ ):
+ self.explorer.rollout_model.chat_template = self.model.custom_chat_template
+ for aux_model in self.explorer.auxiliary_models:
+ if not aux_model.model_path:
+ raise ValueError("auxiliary model's model_path is required.")
+ for args in model_args:
+ set_if_none(aux_model, args, getattr(self.model, args))
+
+ if self.explorer.rollout_model.engine_type == "tinker":
+ set_if_none(
+ self.explorer.rollout_model, "tinker_base_model", self.model.tinker.base_model
+ )
+ else:
+ # check gpu number
+ rollout_gpu_num = (
+ self.explorer.rollout_model.tensor_parallel_size
+ * self.explorer.rollout_model.engine_num
+ + sum(
+ (
+ model.tensor_parallel_size * model.engine_num
+ for model in self.explorer.auxiliary_models
+ )
+ )
+ )
+ assert self.cluster.node_num is not None
+ assert self.cluster.gpu_per_node is not None
+ total_gpu_num = self.cluster.node_num * self.cluster.gpu_per_node
+ if self.mode in ["explore", "bench", "serve"] and rollout_gpu_num > total_gpu_num:
+ raise ValueError(
+ f"Total GPU number ({total_gpu_num}) is less than the number of GPUs required for rollout ({rollout_gpu_num})."
+ )
+ elif self.mode == "both" and rollout_gpu_num >= total_gpu_num:
+ raise ValueError(
+ f"Not enough GPUs for trainer in 'both' mode. Explorer requires {rollout_gpu_num} GPUs, but total available GPUs are {total_gpu_num}."
+ )
+
+ if self.explorer.over_rollout.ratio > 0.0:
+ if not (0.0 <= self.explorer.over_rollout.ratio < 1.0):
+ raise ValueError("over_rollout_ratio should be in [0.0, 1.0)")
+ if self.synchronizer.sync_style == SyncStyle.FIXED:
+ raise ValueError(
+ "over_rollout_ratio is not compatible with fixed sync_style, please set "
+ "`synchronizer.sync_style` to `dynamic_by_explorer` or `dynamic_by_trainer`."
+ )
+
+ # for lora configs
+ if not self.model.tinker.enable and self.model.lora_configs is not None:
+ self.explorer.rollout_model.enable_lora = True
+ if len(self.model.lora_configs) > 1:
+ raise ValueError("Only one lora adapter is supported for now.")
+ if self.model.lora_configs[0].path is None:
+ logger.info("Creating dummy lora, since no lora_path is provided.")
+ lora_path = create_dummy_lora(
+ model_path=self.model.model_path,
+ checkpoint_job_dir=self.checkpoint_job_dir,
+ lora_rank=self.model.lora_configs[0].lora_rank,
+ lora_alpha=self.model.lora_configs[0].lora_alpha,
+ target_modules=self.model.lora_configs[0].target_modules,
+ )
+ self.model.lora_configs[0].path = lora_path
+ self.explorer.rollout_model.lora_modules = [
+ {
+ "lora_int_id": i + 1,
+ "lora_name": cfg.name,
+ "lora_path": cfg.path,
+ "base_model_name": cfg.base_model_name,
+ }
+ for i, cfg in enumerate(self.model.lora_configs)
+ ]
+ self.explorer.rollout_model.lora_kwargs = {
+ "max_loras": len(self.model.lora_configs),
+ "max_lora_rank": max(
+ (
+ model_config.lora_rank
+ for model_config in self.model.lora_configs
+ if model_config.lora_rank > 0
+ ),
+ default=0,
+ ),
+ "default_lora_path": os.path.join(
+ self.checkpoint_job_dir, "global_step_0", "actor", "lora_adapter"
+                ),  # will be popped later
+ }
+
def __iter__(self):
"""Iterate over configs with each stage applied in order.
@@ -1291,99 +1460,7 @@ def check_and_update(self) -> Config: # noqa: C901
# check explorer
if self.explorer is not None:
- rollout_args = ["temperature", "top_p", "top_k", "logprobs", "repetition_penalty"]
- length_args = [
- "max_model_len",
- "max_prompt_tokens",
- "max_response_tokens",
- "min_response_tokens",
- "enable_prompt_truncation",
- ]
- rope_args = ["rope_scaling", "rope_theta"]
- model_args = rollout_args + length_args + rope_args
- for args in ["model_path"] + model_args:
- set_if_none(self.explorer.rollout_model, args, getattr(self.model, args))
- if (
- self.explorer.rollout_model.chat_template is None
- and self.model.custom_chat_template is not None
- ):
- self.explorer.rollout_model.chat_template = self.model.custom_chat_template
- for aux_model in self.explorer.auxiliary_models:
- if not aux_model.model_path:
- raise ValueError("auxiliary model's model_path is required.")
- for args in model_args:
- set_if_none(aux_model, args, getattr(self.model, args))
-
- # check gpu number
- rollout_gpu_num = (
- self.explorer.rollout_model.tensor_parallel_size
- * self.explorer.rollout_model.engine_num
- + sum(
- (
- model.tensor_parallel_size * model.engine_num
- for model in self.explorer.auxiliary_models
- )
- )
- )
- assert self.cluster.node_num is not None
- assert self.cluster.gpu_per_node is not None
- total_gpu_num = self.cluster.node_num * self.cluster.gpu_per_node
- if self.mode in ["explore", "bench", "serve"] and rollout_gpu_num > total_gpu_num:
- raise ValueError(
- f"Total GPU number ({total_gpu_num}) is less than the number of GPUs required for rollout ({rollout_gpu_num})."
- )
- elif self.mode == "both" and rollout_gpu_num >= total_gpu_num:
- raise ValueError(
- f"Not enough GPUs for trainer in 'both' mode. Explorer requires {rollout_gpu_num} GPUs, but total available GPUs are {total_gpu_num}."
- )
-
- if self.explorer.over_rollout.ratio > 0.0:
- if not (0.0 <= self.explorer.over_rollout.ratio < 1.0):
- raise ValueError("over_rollout_ratio should be in [0.0, 1.0)")
- if self.synchronizer.sync_style == SyncStyle.FIXED:
- raise ValueError(
- "over_rollout_ratio is not compatible with fixed sync_style, please set "
- "`synchronizer.sync_style` to `dynamic_by_explorer` or `dynamic_by_trainer`."
- )
-
- # for lora configs
- if self.model.lora_configs is not None:
- self.explorer.rollout_model.enable_lora = True
- if len(self.model.lora_configs) > 1:
- raise ValueError("Only one lora adapter is supported for now.")
- if self.model.lora_configs[0].path is None:
- logger.info("Creating dummy lora, since no lora_path is provided.")
- lora_path = create_dummy_lora(
- model_path=self.model.model_path,
- checkpoint_job_dir=self.checkpoint_job_dir,
- lora_rank=self.model.lora_configs[0].lora_rank,
- lora_alpha=self.model.lora_configs[0].lora_alpha,
- target_modules=self.model.lora_configs[0].target_modules,
- )
- self.model.lora_configs[0].path = lora_path
- self.explorer.rollout_model.lora_modules = [
- {
- "lora_int_id": i + 1,
- "lora_name": cfg.name,
- "lora_path": cfg.path,
- "base_model_name": cfg.base_model_name,
- }
- for i, cfg in enumerate(self.model.lora_configs)
- ]
- self.explorer.rollout_model.lora_kwargs = {
- "max_loras": len(self.model.lora_configs),
- "max_lora_rank": max(
- (
- model_config.lora_rank
- for model_config in self.model.lora_configs
- if model_config.lora_rank > 0
- ),
- default=0,
- ),
- "default_lora_path": os.path.join(
- self.checkpoint_job_dir, "global_step_0", "actor", "lora_adapter"
- ), # will be poped later
- }
+ self._check_explorer()
# check synchronizer
self.synchronizer.ray_namespace = self.ray_namespace
@@ -1391,14 +1468,17 @@ def check_and_update(self) -> Config: # noqa: C901
self.explorer.rollout_model.engine_num
* self.explorer.rollout_model.tensor_parallel_size
)
- if (
- self.mode in ["train", "explore", "bench", "serve"]
- and self.synchronizer.sync_method == SyncMethod.NCCL
- ):
- self.synchronizer.sync_method = SyncMethod.CHECKPOINT
- logger.warning(
- f"`{self.mode}` mode does not support NCCL synchronization, set `synchronizer.sync_method` to `checkpoint`."
- )
+ if self.synchronizer.sync_method == SyncMethod.NCCL:
+ if self.mode in ["train", "explore", "bench", "serve"]:
+ self.synchronizer.sync_method = SyncMethod.CHECKPOINT
+ logger.warning(
+ f"`{self.mode}` mode does not support NCCL synchronization, set `synchronizer.sync_method` to `checkpoint`."
+ )
+ if self.model.lora_configs is not None:
+ self.synchronizer.sync_method = SyncMethod.CHECKPOINT
+ logger.warning(
+ "LoRA is not supported with NCCL synchronization, set `synchronizer.sync_method` to `checkpoint`."
+ )
self._check_interval()
@@ -1450,9 +1530,11 @@ def check_and_update(self) -> Config: # noqa: C901
f"Invalid trainer.save_hf_checkpoint: {self.trainer.save_hf_checkpoint}, "
"must be one of 'last', 'always', or 'never'."
)
+ self.trainer.trainer_config.synchronize_config(self)
+ elif self.trainer.trainer_type == "tinker":
+ self.trainer.trainer_config = None
else:
raise ValueError(f"Invalid trainer type: {self.trainer_type}")
- self.trainer.trainer_config.synchronize_config(self)
# check service
if self.service.data_juicer is not None:
diff --git a/trinity/common/experience.py b/trinity/common/experience.py
index 0f94e0bdd5..9fa48a59ef 100644
--- a/trinity/common/experience.py
+++ b/trinity/common/experience.py
@@ -136,6 +136,8 @@ class Experience:
# for on-policy distillation
teacher_logprobs: Optional[Tensor] = None # [resp_length]
+ custom_fields: List[CustomField] = field(default_factory=list)
+
def __init__( # noqa: C901
self,
*,
@@ -161,6 +163,7 @@ def __init__( # noqa: C901
rejected_messages=None,
multi_modal_inputs=None,
teacher_logprobs=None,
+ custom_fields=None,
):
if action_mask is not None:
experience_type = "multi_turn"
@@ -250,6 +253,7 @@ def __init__( # noqa: C901
self.rejected = torch.tensor(self.rejected)
if self.teacher_logprobs is not None and not isinstance(self.teacher_logprobs, Tensor):
self.teacher_logprobs = torch.tensor(self.teacher_logprobs, dtype=torch.float32)
+ self.custom_fields = custom_fields or []
def serialize(self) -> bytes:
"""Serialize the experience to bytes."""
diff --git a/trinity/common/models/__init__.py b/trinity/common/models/__init__.py
index 190be581cb..46958faa6c 100644
--- a/trinity/common/models/__init__.py
+++ b/trinity/common/models/__init__.py
@@ -45,15 +45,44 @@ def create_inference_models(
from ray.util.placement_group import placement_group, placement_group_table
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
- from trinity.common.models.vllm_model import vLLMRolloutModel
-
logger = get_logger(__name__)
engine_num = config.explorer.rollout_model.engine_num
tensor_parallel_size = config.explorer.rollout_model.tensor_parallel_size
rollout_engines = []
if config.explorer.rollout_model.engine_type.startswith("vllm"):
+ from trinity.common.models.vllm_model import vLLMRolloutModel
+
engine_cls = vLLMRolloutModel
+ elif config.explorer.rollout_model.engine_type == "tinker":
+ from trinity.common.models.tinker_model import TinkerModel
+
+ engine_cls = TinkerModel
+ namespace = ray.get_runtime_context().namespace
+ rollout_engines = [
+ ray.remote(engine_cls)
+ .options(
+ name=f"{config.explorer.name}_rollout_model_{i}",
+ namespace=namespace,
+ )
+ .remote(
+ config=config.explorer.rollout_model,
+ )
+ for i in range(engine_num)
+ ]
+ auxiliary_engines = [
+ ray.remote(engine_cls)
+ .options(
+ name=f"{config.explorer.name}_auxiliary_model_{i}_{j}",
+ namespace=namespace,
+ )
+ .remote(
+ config=config.explorer.auxiliary_models[i],
+ )
+ for i, model_config in enumerate(config.explorer.auxiliary_models)
+ for j in range(model_config.engine_num)
+ ]
+ return rollout_engines, auxiliary_engines
else:
raise ValueError(f"Unknown engine type: {config.explorer.rollout_model.engine_type}")
@@ -124,7 +153,7 @@ def create_inference_models(
model_config.engine_type = "vllm"
model_config.bundle_indices = ",".join([str(bid) for bid in bundles_for_engine])
engines.append(
- ray.remote(vLLMRolloutModel)
+ ray.remote(engine_cls)
.options(
name=f"{config.explorer.name}_auxiliary_model_{i}_{j}",
num_cpus=0,
diff --git a/trinity/common/models/model.py b/trinity/common/models/model.py
index 3fe6f2bf37..00d27b5742 100644
--- a/trinity/common/models/model.py
+++ b/trinity/common/models/model.py
@@ -46,6 +46,10 @@ async def prepare(self) -> None:
"""Prepare the model before inference."""
pass
+ @abstractmethod
+ async def sync_model(self, model_version: int) -> int:
+ """Sync the model with the latest model_version."""
+
@abstractmethod
def get_model_version(self) -> int:
"""Get the checkpoint version."""
@@ -105,7 +109,9 @@ def __init__(
enable_history (bool): Whether to enable history recording. Default to False.
enable_thinking (Optional[bool]): Whether to enable thinking mode. Default to None. Only used for Qwen3 series models.
"""
- assert engine_type.startswith("vllm"), "Only vLLM model is supported for now."
+ assert (
+ engine_type.startswith("vllm") or engine_type == "tinker"
+    ), "Only vLLM and tinker models are supported for now."
self.model = model
self.api_address: str = None
self.openai_client: openai.OpenAI = None
@@ -205,13 +211,13 @@ async def generate_mm_async(
def chat(self, messages: List[dict], **kwargs) -> List[Experience]:
"""Generate a list of experiences from a list of messages."""
lora_request = self.get_lora_request()
- return ray.get(self.model.chat.remote(messages, lora_request, **kwargs))
+ return ray.get(self.model.chat.remote(messages, lora_request=lora_request, **kwargs))
@_history_recorder
async def chat_async(self, messages: List[dict], **kwargs) -> List[Experience]:
"""Generate a list of experiences from a list of messages in async."""
lora_request = await self.get_lora_request_async()
- return await self.model.chat.remote(messages, lora_request, **kwargs)
+ return await self.model.chat.remote(messages, lora_request=lora_request, **kwargs)
@_history_recorder
def chat_mm(
diff --git a/trinity/common/models/tinker_model.py b/trinity/common/models/tinker_model.py
new file mode 100644
index 0000000000..951a4cf322
--- /dev/null
+++ b/trinity/common/models/tinker_model.py
@@ -0,0 +1,203 @@
+from typing import List, Optional, Sequence
+
+import ray
+import tinker
+import torch
+import transformers
+from tinker import types
+from torch import Tensor
+
+from trinity.common.config import InferenceModelConfig
+from trinity.common.experience import Experience
+from trinity.common.models.model import InferenceModel
+from trinity.common.models.utils import get_action_mask_method
+from trinity.manager.synchronizer import Synchronizer
+from trinity.utils.log import get_logger
+
+
+class TinkerModel(InferenceModel):
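+    """Inference model backed by the tinker sampling service.
+
+    Tokenization and chat templating happen locally; sampling and logprob computation
+    are delegated to a remote tinker sampling client created in `prepare()`.
+    """
+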
+ def __init__(
+ self,
+ config: InferenceModelConfig,
+ ) -> None:
+ self.config = config
+ self.model_version = -1
+ self.synchronizer = Synchronizer.get_actor(namespace=ray.get_runtime_context().namespace)
+ self.logger = get_logger(__name__)
+ self.model = None
+ self.tokenizer = None
+ self.chat_template = None
+ if self.config.chat_template:
+ self.chat_template = self.config.chat_template
+ self.action_mask_method = get_action_mask_method(self.chat_template)
+ self.enable_thinking = config.enable_thinking
+
+ async def _initialize_tokenizer(self) -> None:
+ """Initialize the tokenizer."""
+ self.tokenizer = transformers.AutoTokenizer.from_pretrained(self.config.model_path)
+
+ async def _generate_internal(self, prompt: dict, **kwargs) -> types.SampleResponse:
+ assert self.model is not None
+ sampling_params = {
+ "max_tokens": kwargs.get("max_tokens", self.config.max_response_tokens),
+ "seed": kwargs.get("seed", self.config.seed),
+ "temperature": kwargs.get("temperature", 1.0),
+ "top_k": kwargs.get("top_k", -1),
+            "top_p": kwargs.get("top_p", 1.0),
+ }
+
+ return await self.model.sample_async(
+ prompt=types.ModelInput.from_ints(prompt["prompt_token_ids"]),
+ sampling_params=sampling_params,
+ num_samples=kwargs.get("n", 1),
+ include_prompt_logprobs=kwargs.get("include_prompt_logprobs", False),
+ topk_prompt_logprobs=kwargs.get("topk_prompt_logprobs", self.config.logprobs),
+ )
+
+ async def generate(self, prompt: str, **kwargs) -> Sequence[Experience]:
+        """Generate responses from a prompt in async."""
+ if self.tokenizer is None:
+ await self._initialize_tokenizer()
+
+ # Tokenize once without truncation to check if truncation is needed
+ token_ids = self.tokenizer( # type: ignore
+ prompt,
+ truncation=False,
+ return_tensors="pt",
+ )[
+ "input_ids"
+ ][0].tolist()
+
+ # Check if truncation is needed and apply it
+ if self.config.enable_prompt_truncation and self.config.max_prompt_tokens is not None:
+ if len(token_ids) > self.config.max_prompt_tokens:
+ self.logger.warning(
+ f"Prompt was truncated to {self.config.max_prompt_tokens} tokens"
+ )
+ token_ids = token_ids[: self.config.max_prompt_tokens + 1] # leave one for response
+ return [
+ Experience(
+ tokens=token_ids,
+ logprobs=torch.zeros(1, dtype=torch.float32),
+ prompt_length=len(token_ids) - 1,
+ prompt_text=self.tokenizer.decode(token_ids[:-1]),
+ response_text=self.tokenizer.decode(token_ids[-1]),
+ truncate_status="prompt_truncated",
+ reward=0.0,
+ )
+ for _ in range(kwargs.get("n", 1))
+ ]
+
+ output = await self._generate_internal(prompt={"prompt_token_ids": token_ids}, **kwargs)
+ experiences = [
+ Experience(
+ tokens=torch.tensor(token_ids + sequence.tokens, dtype=torch.int32),
+ logprobs=torch.tensor(sequence.logprobs, dtype=torch.float32),
+ prompt_length=len(token_ids),
+ prompt_text=self.tokenizer.decode(token_ids),
+ response_text=self.tokenizer.decode(sequence.tokens),
+ )
+ for sequence in output.sequences
+ ]
+
+ return experiences
+
+ async def chat(self, messages: List[dict], **kwargs) -> Sequence[Experience]:
+ """Generate experiences from a list of history chat messages in async."""
+ if self.tokenizer is None:
+ await self._initialize_tokenizer()
+ if self.chat_template is None:
+ self.chat_template = self.tokenizer.get_chat_template()
+ if messages[-1]["role"] == "assistant":
+ prompt = self.tokenizer.apply_chat_template(
+ messages,
+ tokenize=False,
+ continue_final_message=True,
+ chat_template=self.chat_template,
+ )
+ else:
+ prompt = self.tokenizer.apply_chat_template(
+ messages,
+ tokenize=False,
+ add_generation_prompt=True,
+ chat_template=self.chat_template,
+ enable_thinking=self.enable_thinking,
+ )
+ return await self.generate(prompt=prompt, **kwargs)
+
+ async def logprobs(self, token_ids: List[int], **kwargs) -> Tensor:
+ """Generate logprobs for a list of tokens in async."""
+        logprobs = await self.model.compute_logprobs_async(types.ModelInput.from_ints(token_ids))
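+        # Drop the first entry (the first token has no preceding context) to keep the
+        # (seq_length - 1,) convention used elsewhere.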
+ return torch.tensor(logprobs[1:], dtype=torch.float32)
+
+ async def convert_messages_to_experience(
+ self,
+ messages: List[dict],
+ tools: Optional[List[dict]] = None,
+ temperature: Optional[float] = None,
+ ) -> Experience:
+ """Convert a list of messages into an experience in async."""
+ if self.tokenizer is None:
+ await self._initialize_tokenizer()
+ if self.chat_template is None:
+ self.chat_template = self.tokenizer.get_chat_template()
+ token_ids, action_mask, prompt_length = self.action_mask_method(
+ tokenizer=self.tokenizer,
+ messages=messages,
+ tools=tools,
+ chat_template=self.chat_template,
+ enable_thinking=self.enable_thinking,
+ ) # (seq_length, ), (seq_length, )
+
+ # Truncate tokens if they exceed the length limit
+ assert token_ids is not None
+ truncate_status = None
+ if self.config.max_model_len is not None and self.config.max_model_len > 0:
+ if len(token_ids) > self.config.max_model_len - 1:
+ truncate_status = "response_truncated"
+ self.logger.warning(
+ f"Warning: {len(token_ids)=} exceeds the length limit {(self.config.max_model_len - 1)=}"
+ )
+ token_ids = token_ids[: self.config.max_model_len - 1]
+ action_mask = action_mask[: self.config.max_model_len - 1]
+
+ temperature = temperature if temperature is not None else self.config.temperature
+ logprobs = await self.logprobs(
+ token_ids=token_ids.tolist(), temperature=temperature
+ ) # (seq_length - 1,)
+ return Experience(
+ tokens=token_ids,
+ logprobs=logprobs[prompt_length - 1 :],
+ prompt_length=prompt_length,
+ action_mask=action_mask[prompt_length:], # Exclude the prompt tokens
+ messages=messages,
+ truncate_status=truncate_status,
+ )
+
+ async def prepare(self) -> None:
+ """Prepare the model before inference."""
+ self.service_client = tinker.ServiceClient()
+ self.model = await self.service_client.create_sampling_client_async(
+ base_model=self.config.tinker_base_model,
+ )
+
+ async def sync_model(self, model_version: int) -> int:
+ self.model_version = model_version
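+        # With tinker, weight sync means re-creating the sampling client from the sampler
+        # checkpoint path that the trainer published through the Synchronizer.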
+ remote_sampler_path, _ = await self.synchronizer.get_model_state_dict.remote()
+ self.model = await self.service_client.create_sampling_client_async(
+ model_path=remote_sampler_path,
+ )
+ return model_version
+
+ def get_model_version(self) -> int:
+ """Get the checkpoint version."""
+ return self.model_version
+
+ def get_api_server_url(self) -> Optional[str]:
+ """Get the API server URL if available."""
+ # TODO: tinker will support openai api later
+ return None
+
+ def get_model_path(self) -> Optional[str]:
+ """Get the model path"""
+ return self.config.model_path # type: ignore [return-value]
diff --git a/trinity/explorer/explorer.py b/trinity/explorer/explorer.py
index 767d30a5ed..5912d951c1 100644
--- a/trinity/explorer/explorer.py
+++ b/trinity/explorer/explorer.py
@@ -50,6 +50,7 @@ def __init__(self, config: Config):
self.last_monitored_step = self.explore_step_num
self.synchronizer = Synchronizer.get_actor(config)
self.config = config
+ self.model_type = config.explorer.rollout_model.engine_type
self.models, self.auxiliary_models = create_inference_models(config)
self.experience_pipeline = self._init_experience_pipeline()
self.taskset = (
@@ -149,7 +150,10 @@ async def _checkpoint_weights_update(self, step_num: Optional[int] = None) -> in
async def _pull_latest_weights(self):
self.logger.info("Start to pull latest model weights.")
- new_version = await self.synchronizer.wait_new_model_state_dict.remote(self.model_version)
+ new_version = await self.synchronizer.wait_new_model_state_dict.remote(
+ current_version=self.model_version,
+ no_wait=(self.config.synchronizer.sync_style != SyncStyle.FIXED),
+ )
if new_version > self.model_version:
if self.model_version != -1:
self.logger.info(f"New model weights version: {new_version}")
@@ -195,7 +199,7 @@ async def prepare(self) -> None:
await asyncio.gather(*run_api_ref)
self.logger.info("All models are ready.")
- if not self.use_nccl_sync:
+ if not self.use_nccl_sync and self.model_type != "tinker":
if self.config.mode == "serve":
# In serving mode, each engine will setup its own process group
await self.setup_model_level_weight_sync_group()
diff --git a/trinity/manager/synchronizer.py b/trinity/manager/synchronizer.py
index 8157ad088d..c0b913812e 100644
--- a/trinity/manager/synchronizer.py
+++ b/trinity/manager/synchronizer.py
@@ -79,7 +79,16 @@ async def _check_modules(self) -> None:
pass
async def _find_latest_state_dict(self) -> None:
- assert self.config.trainer.trainer_type == "verl"
+ if self.config.trainer.trainer_type == "verl":
+ await self._find_verl_latest_state_dict()
+ elif self.config.trainer.trainer_type == "tinker":
+ await self._find_tinker_latest_state_dict()
+ else:
+ self.logger.warning(
+                f"Unsupported trainer type `{self.config.trainer.trainer_type}` for Synchronizer. "
+                "Please use `verl` or `tinker`."
+ )
+
+ async def _find_verl_latest_state_dict(self) -> None:
default_local_dir = self.config.checkpoint_job_dir
local_latest_state_dict_iteration = os.path.join(
default_local_dir, "latest_state_dict_iteration.txt"
@@ -112,6 +121,33 @@ async def _find_latest_state_dict(self) -> None:
await self.set_model_state_dict(model_state_dict, latest_model_version)
await asyncio.sleep(1)
+ async def _find_tinker_latest_state_dict(self) -> None:
+ default_local_dir = self.config.checkpoint_job_dir
+ local_latest_state_dict_iteration = os.path.join(
+ default_local_dir, "latest_state_dict_iteration.txt"
+ )
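+        # The tinker trainer writes the newest step to `latest_state_dict_iteration.txt` and
+        # the matching sampler path to `global_step_{step}/remote_sampler_path.txt`.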
+ while True:
+ if os.path.exists(local_latest_state_dict_iteration):
+ try:
+ with open(local_latest_state_dict_iteration, "r") as f:
+ latest_model_version = int(f.read().strip())
+ except (IOError, ValueError) as e:
+ self.logger.warning(f"Failed to read or parse state dict iteration file: {e}")
+                    await asyncio.sleep(1)
+                    continue
+ if latest_model_version > self.model_version:
+ self.logger.info(
+ f"Synchronizer has found a new remote tinker sampler path at step {latest_model_version}."
+ )
+ remote_path_file = os.path.join(
+ default_local_dir,
+ f"global_step_{latest_model_version}",
+ "remote_sampler_path.txt",
+ )
+ with open(remote_path_file, "r") as f:
+ remote_sampler_path = f.read().strip()
+ await self.set_model_state_dict(remote_sampler_path, latest_model_version)
+ await asyncio.sleep(1)
+
async def set_trainer_status(self, status: RunningStatus):
"""Update the status of the trainer."""
async with self._ready_condition:
@@ -192,7 +228,7 @@ async def set_model_state_dict_with_step_num(
return checkpoint_step_num
async def set_model_state_dict(
- self, model_state_dict: Union[dict, None, Tuple[str, str]], trainer_step: int
+ self, model_state_dict: Union[dict, None, str, Tuple[str, str]], trainer_step: int
):
"""
Set the new model state and update the version.
diff --git a/trinity/trainer/tinker/__init__.py b/trinity/trainer/tinker/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/trinity/trainer/tinker/utils.py b/trinity/trainer/tinker/utils.py
new file mode 100644
index 0000000000..544ef89768
--- /dev/null
+++ b/trinity/trainer/tinker/utils.py
@@ -0,0 +1,243 @@
+from logging import Logger
+from typing import Any, List, Tuple
+
+import torch
+from tinker import types
+
+from trinity.common.experience import Experience, split_dpo_experience_to_single_turn
+
+
+def to_tinker_input(
+ experiences: List[Experience], logger: Logger
+) -> Tuple[List[types.Datum], List[types.ModelInput], List[dict]]:
+ assert len(experiences) > 0, "No experiences provided."
+ if experiences[0].experience_type == "dpo":
+ experiences = split_dpo_experience_to_single_turn(experiences)
+
+ batch = []
+ batch_input_tokens = []
+ model_inputs_list = []
+ for exp in experiences:
+ tokens = exp.tokens
+ input_tokens = tokens.long()
+ prompt_length = exp.prompt_length
+ total_length = len(tokens) # type: ignore
+ response_length = total_length - prompt_length
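+        # Align the loss inputs with the shifted targets (tokens[1:]): zero weight over the
+        # prompt positions, action-mask weight over the response positions.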
+ loss_fn_inputs = {
+ "weights": torch.concat(
+ [
+ torch.zeros(prompt_length - 1, dtype=torch.float32),
+ exp.action_mask.float(),
+ ]
+ ),
+ "target_tokens": input_tokens.tolist()[1:],
+ }
+ model_inputs = {
+ "total_length": total_length,
+ "action_mask": exp.action_mask,
+ }
+ if exp.reward is not None or exp.token_level_reward is not None:
+ assert exp.logprobs is not None
+ if exp.token_level_reward is not None:
+ if exp.reward is not None:
+ logger.warning(
+ "Both exp.rewards and exp.token_level_rewards are provided. "
+ "Using exp.token_level_rewards."
+ )
+ token_level_reward = exp.token_level_reward
+ else:
+ token_level_reward = torch.zeros(response_length, dtype=torch.float32)
+ token_level_reward[-1] = exp.reward
+ model_inputs.update(
+ {
+ "token_level_scores": token_level_reward,
+ "old_logprob": exp.logprobs,
+ }
+ )
+ for attr in ["advantages", "returns", "teacher_logprobs"]:
+ if getattr(exp, attr, None) is not None:
+ model_inputs[attr] = getattr(exp, attr)
+ # TODO: if tinker support multi-modal input, we can add it here
+ for custom_field in exp.custom_fields:
+ model_inputs[custom_field.destination_field] = torch.tensor(
+ exp.info[custom_field.source_field],
+ dtype=custom_field.data_type,
+ )
+
+ batch.append(
+ types.Datum(
+ model_input=types.ModelInput.from_ints(tokens=input_tokens.tolist()[:-1]),
+ loss_fn_inputs=loss_fn_inputs,
+ )
+ )
+ batch_input_tokens.append(types.ModelInput.from_ints(input_tokens.tolist()))
+ model_inputs_list.append(model_inputs)
+ return batch, batch_input_tokens, model_inputs_list
+
+
+def compute_data_metrics(batch: List[dict[str, torch.Tensor]]) -> dict:
+ """
+ Computes various metrics from a batch of data for PPO training.
+ Modified from `verl.trainer.ppo.metric_utils.compute_data_metrics`.
+
+    This function calculates metrics related to scores, rewards, advantages, returns,
+ and sequence lengths from a batch of data. It provides statistical information (mean, max, min)
+ for each metric category.
+
+ Args:
+        batch: A list of per-sample dicts (as produced by `to_tinker_input`) containing
+            token-level scores, rewards, advantages, action masks, and token lengths.
+
+ Returns:
+ A dictionary of metrics including:
+ - critic/score/mean, max, min: Statistics about sequence scores
+ - critic/rewards/mean, max, min: Statistics about sequence rewards
+ - critic/advantages/mean, max, min: Statistics about advantages
+ - critic/returns/mean, max, min: Statistics about returns
+ - response_length/mean, max, min, clip_ratio: Statistics about response lengths
+ - prompt_length/mean, max, min, clip_ratio: Statistics about prompt lengths
+ """
+ metrics = {}
+
+ assert len(batch) > 0, "Batch is empty"
+
+ if "token_level_rewards" in batch[0] and "token_level_scores" in batch[0]:
+ sequence_score = torch.tensor([data["token_level_scores"].sum() for data in batch])
+ sequence_reward = torch.tensor([data["token_level_rewards"].sum() for data in batch])
+ metrics.update(
+ {
+ # score
+ "critic/score/mean": torch.mean(sequence_score).detach().item(),
+ "critic/score/max": torch.max(sequence_score).detach().item(),
+ "critic/score/min": torch.min(sequence_score).detach().item(),
+ # reward
+ "critic/rewards/mean": torch.mean(sequence_reward).detach().item(),
+ "critic/rewards/max": torch.max(sequence_reward).detach().item(),
+ "critic/rewards/min": torch.min(sequence_reward).detach().item(),
+ }
+ )
+
+ response_length = torch.tensor([len(data["action_mask"]) for data in batch]).float()
+ token_length = torch.tensor([data["total_length"] for data in batch]).float()
+ prompt_length = token_length - response_length
+ max_response_length = max(response_length)
+ max_prompt_length = max(prompt_length)
+ metrics.update(
+ {
+ # response length
+ "response_length/mean": torch.mean(response_length).detach().item(),
+ "response_length/max": torch.max(response_length).detach().item(),
+ "response_length/min": torch.min(response_length).detach().item(),
+ "response_length/clip_ratio": torch.mean(
+ torch.eq(response_length, max_response_length).float()
+ )
+ .detach()
+ .item(),
+ # prompt length
+ "prompt_length/mean": torch.mean(prompt_length).detach().item(),
+ "prompt_length/max": torch.max(prompt_length).detach().item(),
+ "prompt_length/min": torch.min(prompt_length).detach().item(),
+ "prompt_length/clip_ratio": torch.mean(
+ torch.eq(prompt_length, max_prompt_length).float()
+ )
+ .detach()
+ .item(),
+ }
+ )
+
+ if "advantages" in batch[0]:
+ valid_adv = torch.concat([data["advantages"] for data in batch])
+ metrics.update(
+ {
+ "critic/advantages/mean": torch.mean(valid_adv).detach().item(),
+ "critic/advantages/max": torch.max(valid_adv).detach().item(),
+ "critic/advantages/min": torch.min(valid_adv).detach().item(),
+ }
+ )
+ if "returns" in batch[0]:
+ valid_returns = torch.concat([data["returns"] for data in batch])
+ metrics.update(
+ {
+ "critic/returns/mean": torch.mean(valid_returns).detach().item(),
+ "critic/returns/max": torch.max(valid_returns).detach().item(),
+ "critic/returns/min": torch.min(valid_returns).detach().item(),
+ }
+ )
+
+ return metrics
+
+
+def compute_timing_metrics(
+ batch: List[dict[str, torch.Tensor]], timing_raw: dict[str, float]
+) -> dict[str, Any]:
+ """
+ Computes timing metrics for different processing stages in PPO training.
+ Modified from `verl.trainer.ppo.metric_utils.compute_timing_metrics`.
+
+ This function calculates both raw timing metrics (in seconds) and per-token timing metrics
+ (in milliseconds) for various processing stages like generation, reference computation,
+ value computation, advantage computation, and model updates.
+
+ Args:
+        batch: A list of per-sample dicts containing action masks and total token lengths.
+ timing_raw: A dictionary mapping stage names to their execution times in seconds.
+
+ Returns:
+ A dictionary containing:
+ - timing_s/{name}: Raw timing in seconds for each stage
+ - timing_per_token_ms/{name}: Per-token timing in milliseconds for each stage
+
+ Note:
+ Different stages use different token counts for normalization:
+ - "gen" uses only response tokens
+ - Other stages ("ref", "values", "adv", "update_critic", "update_actor") use all tokens
+ (prompt + response)
+ """
+ num_overall_tokens = sum(data["total_length"] for data in batch)
+ num_response_tokens = sum(len(data["action_mask"]) for data in batch)
+
+ num_tokens_of_section = {
+ "gen": num_response_tokens,
+ **{
+ name: num_overall_tokens
+ for name in ["ref", "values", "adv", "update_critic", "update_actor"]
+ },
+ }
+
+ return {
+ **{f"timing_s/{name}": value for name, value in timing_raw.items()},
+ **{
+ f"timing_per_token_ms/{name}": timing_raw[name] * 1000 / num_tokens_of_section[name]
+ for name in set(num_tokens_of_section.keys()) & set(timing_raw.keys())
+ },
+ }
+
+
+def compute_throughout_metrics(
+ batch: List[dict[str, torch.Tensor]], timing_raw: dict[str, float]
+) -> dict[str, Any]:
+ """
+ Computes throughput metrics for PPO training.
+ Modified from `verl.trainer.ppo.metric_utils.compute_throughout_metrics`.
+
+ This function calculates performance metrics related to token processing speed,
+ including the total number of tokens processed and time per step.
+
+ Args:
+        batch: A list of per-sample dicts containing the total token length of each sample.
+ timing_raw: A dictionary mapping stage names to their execution times in seconds.
+ Must contain a "step" key with the total step time.
+
+ Returns:
+ A dictionary containing:
+ - perf/total_num_tokens: Total number of tokens processed in the batch
+ - perf/time_per_step: Time taken for the step in seconds
+ """
+ total_num_tokens = sum(data["total_length"] for data in batch)
+ time = timing_raw["step"]
+ return {
+ "perf/total_num_tokens": total_num_tokens,
+ "perf/time_per_step": time,
+ }
diff --git a/trinity/trainer/tinker_trainer.py b/trinity/trainer/tinker_trainer.py
new file mode 100644
index 0000000000..3c7468923f
--- /dev/null
+++ b/trinity/trainer/tinker_trainer.py
@@ -0,0 +1,324 @@
+import asyncio
+import os
+from typing import Dict, List
+
+import ray
+import tinker
+import torch
+from tinker import types
+
+from trinity.algorithm import ALGORITHM_TYPE
+from trinity.algorithm.advantage_fn import ADVANTAGE_FN
+from trinity.algorithm.entropy_loss_fn import ENTROPY_LOSS_FN
+from trinity.algorithm.entropy_loss_fn.entropy_loss_fn import DummyEntropyLossFn
+from trinity.algorithm.kl_fn import KL_FN
+from trinity.algorithm.policy_loss_fn import POLICY_LOSS_FN
+from trinity.algorithm.utils import prefix_metrics
+from trinity.common.config import Config
+from trinity.common.experience import Experience
+from trinity.manager.synchronizer import Synchronizer
+from trinity.trainer.tinker.utils import (
+ compute_data_metrics,
+ compute_throughout_metrics,
+ compute_timing_metrics,
+ to_tinker_input,
+)
+from trinity.trainer.trainer import TrainEngineWrapper
+from trinity.utils.log import get_logger
+from trinity.utils.timer import Timer
+
+
+class TinkerTrainerWrapper(TrainEngineWrapper):
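+    """Train engine wrapper backed by the tinker training service.
+
+    Trains a LoRA policy through a remote tinker training client; checkpoints and sampler
+    weights stay on the tinker side and are referenced locally via path files under
+    `checkpoint_job_dir`.
+    """
+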
+ def __init__(self, config: Config):
+ self.config = config
+ self.logger = get_logger("tinker_trainer")
+ self._init_algorithm()
+ self.synchronizer = Synchronizer.get_actor(namespace=self.config.synchronizer.ray_namespace)
+
+ def _init_algorithm(self):
+ self.algorithm = ALGORITHM_TYPE.get(self.config.algorithm.algorithm_type)
+ algorithm_config = self.config.algorithm
+ if self.algorithm.compute_advantage_in_trainer:
+ self.advantage_fn = ADVANTAGE_FN.get(algorithm_config.advantage_fn)(
+ **algorithm_config.advantage_fn_args
+ )
+ self.kl_fn = KL_FN.get(algorithm_config.kl_penalty_fn)(
+ **algorithm_config.kl_penalty_fn_args
+ )
+            # TODO: support computing advantages/returns inside the tinker trainer
+ raise NotImplementedError(
+ "`compute_advantage_in_trainer` is not implemented yet in tinker"
+ )
+ self.loss_agg_mode = algorithm_config.loss_agg_mode
+ self.policy_loss_fn = POLICY_LOSS_FN.get(algorithm_config.policy_loss_fn)(
+ backend="tinker", **algorithm_config.policy_loss_fn_args
+ )
+ self.kl_loss_fn = KL_FN.get(algorithm_config.kl_loss_fn)(**algorithm_config.kl_loss_fn_args)
+ self.entropy_loss_fn = ENTROPY_LOSS_FN.get(algorithm_config.entropy_loss_fn)(
+ **algorithm_config.entropy_loss_fn_args
+ )
+
+ # EXPERIMENTAL: apply loss scale fix
+ self.do_fix_actor_microbatch_loss_scale = (
+ self.config.trainer.fix_actor_microbatch_loss_scale
+ and (self.loss_agg_mode == "token-mean")
+ )
+
+ self.adam_params = types.AdamParams(
+ learning_rate=algorithm_config.optimizer.lr,
+ beta1=algorithm_config.optimizer.betas[0],
+ beta2=algorithm_config.optimizer.betas[1],
+ # eps is currently not in config
+ weight_decay=algorithm_config.optimizer.weight_decay,
+ grad_clip_norm=self.config.trainer.grad_clip,
+ )
+
+ async def prepare(self):
+ self.service_client = tinker.ServiceClient()
+
+ name_prefix_list = [self.config.project, self.config.group, self.config.name]
+ self.tinker_checkpoint_name_prefix = "-".join(
+ [prefix for prefix in name_prefix_list if prefix]
+ )
+ self.default_local_dir = self.config.checkpoint_job_dir
+
+ self.local_latest_checkpointed_iteration = os.path.join(
+ self.config.checkpoint_job_dir, "latest_checkpointed_iteration.txt"
+ )
+ self.local_latest_state_dict_iteration = os.path.join(
+ self.config.checkpoint_job_dir, "latest_state_dict_iteration.txt"
+ )
+
+ if os.path.exists(self.local_latest_checkpointed_iteration):
+ with open(self.local_latest_checkpointed_iteration, "r") as f:
+ self._train_step_num = self.latest_remote_checkpoint_step = int(f.read().strip())
+ checkpoint_file_path = os.path.join(
+ self.default_local_dir,
+ f"global_step_{self._train_step_num}",
+ "remote_checkpoint_path.txt",
+ )
+ with open(checkpoint_file_path, "r") as f:
+ self.latest_remote_checkpoint_path = f.read().strip()
+ self.actor_client = (
+ await self.service_client.create_training_client_from_state_with_optimizer_async(
+ path=self.latest_remote_checkpoint_path,
+ )
+ )
+ else:
+ self.actor_client = await self.service_client.create_lora_training_client_async(
+ base_model=self.config.model.tinker.base_model,
+ rank=self.config.model.tinker.rank,
+ seed=self.config.model.tinker.seed,
+ train_mlp=self.config.model.tinker.train_mlp,
+ train_attn=self.config.model.tinker.train_attn,
+ train_unembed=self.config.model.tinker.train_unembed,
+ )
+ self.latest_remote_checkpoint_step = 0
+ self.latest_remote_checkpoint_path = None
+ self._train_step_num = 0
+
+ if os.path.exists(self.local_latest_state_dict_iteration):
+ with open(self.local_latest_state_dict_iteration, "r") as f:
+ self.latest_remote_sampler_step = int(f.read().strip())
+ sampler_file_path = os.path.join(
+ self.default_local_dir,
+ f"global_step_{self.latest_remote_sampler_step}",
+ "remote_sampler_path.txt",
+ )
+ with open(sampler_file_path, "r") as f:
+ self.latest_remote_sampler_path = f.read().strip()
+ else:
+ self.latest_remote_sampler_step = 0
+ self.latest_remote_sampler_path = None
+
+ self.ref_client = await self.service_client.create_sampling_client_async(
+ base_model=self.config.model.tinker.base_model,
+ )
+
+ @property
+ def train_step_num(self) -> int:
+ """Get the current training step number."""
+ return self._train_step_num
+
+ def _loss_func(
+ self, batch: list[types.Datum], logprobs: list[torch.Tensor]
+ ) -> tuple[torch.Tensor, dict[str, float]]:
+ total_loss = 0.0
+ metrics = {}
+ assert len(self.model_inputs_list) == len(
+ logprobs
+        ), "len(self.model_inputs_list) must equal len(logprobs)"
+ for model_inputs, logprob in zip(self.model_inputs_list, logprobs):
+ micro_batch_metrics = {}
+ response_mask = model_inputs["action_mask"]
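+            # The per-datum logprobs cover all target positions; keep only the response part.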
+ logprob = logprob[-response_mask.shape[0] :]
+
+ pg_loss, pg_loss_metrics = self.policy_loss_fn(logprob=logprob, **model_inputs)
+ prefix_metrics(
+ src_metrics=pg_loss_metrics, prefix="actor", dst_metrics=micro_batch_metrics
+ )
+
+            if not isinstance(self.entropy_loss_fn, DummyEntropyLossFn):
+ entropy = -(logprob * logprob.exp())
+ else:
+ entropy = None
+ # compute entropy loss from entropy
+ entropy_loss, entropy_loss_metrics = self.entropy_loss_fn( # type: ignore
+ entropy=entropy,
+ **model_inputs,
+ loss_agg_mode=self.loss_agg_mode,
+ )
+ prefix_metrics(
+ src_metrics=entropy_loss_metrics,
+ prefix="actor",
+ dst_metrics=micro_batch_metrics,
+ )
+
+ # compute kl loss
+ kl_loss, kl_loss_metrics = self.kl_loss_fn.calculate_kl_loss(
+ logprob=logprob,
+ ref_logprob=model_inputs["ref_logprob"],
+ response_mask=response_mask,
+ loss_agg_mode=self.loss_agg_mode,
+ old_logprob=model_inputs["old_logprob"],
+ )
+ prefix_metrics(
+ src_metrics=kl_loss_metrics,
+ prefix="actor",
+ dst_metrics=micro_batch_metrics,
+ )
+
+ # compute policy loss
+ policy_loss = pg_loss - entropy_loss + kl_loss
+ loss_scale = 1.0
+ if not self.do_fix_actor_microbatch_loss_scale:
+ loss_scale /= len(logprobs)
+ loss = policy_loss * loss_scale
+ total_loss = total_loss + loss
+ micro_batch_metrics["actor/final_loss"] = loss.detach().item()
+
+ # update metrics
+ for key, val in micro_batch_metrics.items():
+ if key not in metrics:
+ metrics[key] = []
+ metrics[key].append(val)
+
+ avg_metrics = {k: sum(v) / len(v) for k, v in metrics.items()}
+ return total_loss, avg_metrics
+
+ async def train_step(self, batch_exps: List[Experience]) -> Dict:
+ """Training one step.
+
+ Args:
+            batch_exps (List[Experience]): A batch of experiences to train.
+
+ Returns:
+ Dict: Metrics of the training step.
+ """
+ batch, batch_input_tokens, model_inputs_list = to_tinker_input(batch_exps, self.logger)
+ self.model_inputs_list = model_inputs_list
+ timing_raw = {}
+ metrics = {}
+ self._train_step_num += 1
+
+ with Timer(timing_raw, "step"):
+ if self.algorithm.use_reference: # ref_logprob may not be used
+ ref_logprobs = await asyncio.gather(
+ *[
+ self.ref_client.compute_logprobs_async(input_tokens)
+ for input_tokens in batch_input_tokens
+ ]
+ )
+ for model_inputs, ref_logprob in zip(model_inputs_list, ref_logprobs):
+ response_length = model_inputs["action_mask"].shape[0]
+ model_inputs["ref_logprob"] = torch.tensor(ref_logprob[-response_length:])
+
+ if self.algorithm.compute_advantage_in_trainer:
+ # TODO: following is verl format, which is not compatible with tinker
+ raise NotImplementedError(
+ "`compute_advantage_in_trainer` is not implemented yet in tinker"
+ )
+ else:
+ # skip token_level_scores for sft/dpo
+ for model_inputs in model_inputs_list:
+ if "token_level_scores" in model_inputs:
+ assert "token_level_rewards" not in model_inputs
+ model_inputs["token_level_rewards"] = model_inputs["token_level_scores"]
+
+ # update actor
+ with Timer(timing_raw, "update_actor"):
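+                # Submit the forward/backward pass and the optimizer step first, then await
+                # both futures for their results (each *_async call returns an awaitable future).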
+ fwdbwd_future = await self.actor_client.forward_backward_custom_async(
+ batch, self._loss_func
+ )
+ optim_future = await self.actor_client.optim_step_async(self.adam_params)
+ fwdbwd_result = await fwdbwd_future
+ optim_result = await optim_future
+ metrics.update(fwdbwd_result.metrics)
+ if optim_result.metrics:
+ metrics.update(optim_result.metrics)
+
+ # collect metrics
+ metrics.update(compute_data_metrics(batch=self.model_inputs_list))
+ timing_metrics = compute_timing_metrics(batch=self.model_inputs_list, timing_raw=timing_raw)
+ metrics.update({k.replace("timing_s/", "time/"): v for k, v in timing_metrics.items()})
+ metrics.update(
+ compute_throughout_metrics(batch=self.model_inputs_list, timing_raw=timing_raw)
+ )
+
+ return metrics
+
+ def save_checkpoint(self, block_until_saved: bool = False, save_as_hf: bool = False) -> None:
+ """Save the checkpoint."""
+ if self.train_step_num == self.latest_remote_checkpoint_step:
+ return
+ self.latest_remote_checkpoint_step = self.train_step_num
+ checkpoint_name = f"{self.tinker_checkpoint_name_prefix}-state-{self.train_step_num}"
+ self.latest_remote_checkpoint_path = (
+ self.actor_client.save_state(checkpoint_name).result().path
+ )
+ local_path = os.path.join(
+ self.default_local_dir,
+ f"global_step_{self.train_step_num}",
+ )
+ os.makedirs(local_path, exist_ok=True)
+ remote_checkpoint_path = os.path.join(local_path, "remote_checkpoint_path.txt")
+ with open(remote_checkpoint_path, "w") as f:
+ f.write(self.latest_remote_checkpoint_path)
+
+ with open(self.local_latest_checkpointed_iteration, "w") as f:
+ f.write(str(self.train_step_num))
+
+ def sync_weight(self) -> None:
+ """Sync the model weight."""
+ raise NotImplementedError("Tinker trainer does not support NCCL sync")
+
+ def upload_state_dict(self) -> None:
+ """Upload the state dict to Synchronizer."""
+ self.save_state_dict()
+ ray.get(
+ self.synchronizer.set_model_state_dict.remote(
+ self.latest_remote_sampler_path, self.train_step_num
+ )
+ )
+
+ def save_state_dict(self) -> None:
+ """Only save the model state dict for Synchronizer."""
+ if self.train_step_num == self.latest_remote_sampler_step:
+ return
+ self.latest_remote_sampler_step = self.train_step_num
+ checkpoint_name = f"{self.tinker_checkpoint_name_prefix}-sampler-{self.train_step_num}"
+ self.latest_remote_sampler_path = (
+ self.actor_client.save_weights_for_sampler(checkpoint_name).result().path
+ )
+ local_path = os.path.join(
+ self.default_local_dir,
+ f"global_step_{self.train_step_num}",
+ )
+ os.makedirs(local_path, exist_ok=True)
+ remote_sampler_path = os.path.join(local_path, "remote_sampler_path.txt")
+ with open(remote_sampler_path, "w") as f:
+ f.write(self.latest_remote_sampler_path)
+
+ with open(self.local_latest_state_dict_iteration, "w") as f:
+ f.write(str(self.train_step_num))
diff --git a/trinity/trainer/trainer.py b/trinity/trainer/trainer.py
index c42fccfb26..c4901f3a2f 100644
--- a/trinity/trainer/trainer.py
+++ b/trinity/trainer/trainer.py
@@ -17,7 +17,7 @@
from trinity.algorithm.sample_strategy.sample_strategy import SampleStrategy
from trinity.common.config import Config
from trinity.common.constants import RunningStatus, SyncMethod, SyncStyle
-from trinity.common.experience import Experiences
+from trinity.common.experience import Experience
from trinity.manager.state_manager import StateManager
from trinity.manager.synchronizer import Synchronizer
from trinity.utils.log import get_logger
@@ -39,7 +39,6 @@ def __init__(self, config: Config) -> None:
path=config.checkpoint_job_dir, trainer_name=config.trainer.name, config=config
)
trainer_state = self.state.load_trainer()
- self.last_trainer_sync_step = 0
self.monitor = MONITOR.get(config.monitor.monitor_type)(
project=config.project,
group=self.config.group,
@@ -60,15 +59,15 @@ def __init__(self, config: Config) -> None:
sample_strategy_state = trainer_state.get("sample_strategy_state", {})
self.sample_strategy.load_state_dict(sample_strategy_state)
self.save_interval = config.trainer.save_interval
- self.last_sync_step = None
+ self.last_sync_step = 0
self.last_sync_time = None
self.total_steps = config.trainer.total_steps or float("inf")
self.save_hf_checkpoint = config.trainer.save_hf_checkpoint
async def prepare(self) -> None:
"""Prepare the trainer."""
- self.engine.prepare()
- self.last_trainer_sync_step = self.train_step_num
+ await self.engine.prepare()
+ self.last_sync_step = self.train_step_num
await self.synchronizer.set_trainer_status.remote(RunningStatus.RUNNING)
async def train(self) -> str:
@@ -109,7 +108,7 @@ async def train(self) -> str:
self.logger.info("--------------------\n> Trainer finished.\n--------------------")
return self.config.trainer.name
- async def train_step(self, exps: Experiences) -> Dict:
+ async def train_step(self, exps: List[Experience]) -> Dict:
"""Train one step.
Returns:
@@ -119,21 +118,21 @@ async def train_step(self, exps: Experiences) -> Dict:
self.logger.info(f"Training at step {self.train_step_num + 1} started.")
metrics = {}
with Timer(metrics, "time/train_step"):
- train_metrics = self.engine.train_step(exps)
+ train_metrics = await self.engine.train_step(exps)
self.logger.info(f"Training at step {self.train_step_num} finished.")
metrics.update(train_metrics)
return metrics
- async def _sample_data(self) -> Tuple[Experiences, Dict, List[Dict]]:
+ async def _sample_data(self) -> Tuple[List[Experience], Dict, List[Dict]]:
"""Sample a batch of experiences.
Returns:
- Experiences: A batch of experiences.
+ List[Experience]: A batch of experiences.
Dict: Metrics of the sampling step.
List[Dict]: A list of representative samples for logging.
"""
batch, metrics, repr_samples = await self.sample_strategy.sample(self.train_step_num + 1)
- metrics["sample/task_count"] = len(set(eid.tid for eid in batch.eids))
+ metrics["sample/task_count"] = len(set(exp.eid.tid for exp in batch))
return batch, metrics, repr_samples
async def need_sync(self) -> bool:
@@ -145,14 +144,17 @@ async def need_sync(self) -> bool:
)
else:
if self.config.synchronizer.sync_style == SyncStyle.DYNAMIC_BY_TRAINER:
- delta = self.train_step_num - self.last_trainer_sync_step
+ delta = self.train_step_num - self.last_sync_step
if delta >= self.config.synchronizer.sync_interval:
await self.synchronizer.set_trainer_status.remote(RunningStatus.REQUIRE_SYNC)
explorer_status_counts = await self.synchronizer.get_explorer_status_counts.remote()
if self.config.synchronizer.sync_method == SyncMethod.NCCL:
return explorer_status_counts[RunningStatus.WAITING_SYNC] > 0
else: # memory & checkpoint
- return explorer_status_counts[RunningStatus.REQUIRE_SYNC] > 0
+ return (
+ self.last_sync_step != self.train_step_num
+ and explorer_status_counts[RunningStatus.REQUIRE_SYNC] > 0
+ )
def need_save(self) -> bool:
"""Whether to save the checkpoint."""
@@ -173,7 +175,6 @@ async def sync_weight(self) -> Dict:
self.logger.error("Trainer sync_weights failed.")
else:
self.engine.sync_weight()
- self.last_trainer_sync_step = self.train_step_num
elif self.config.synchronizer.sync_method == SyncMethod.CHECKPOINT:
self.engine.save_state_dict()
elif self.config.synchronizer.sync_method == SyncMethod.MEMORY:
@@ -229,7 +230,7 @@ class TrainEngineWrapper(ABC):
"""A wrapper class to wrap various training engines."""
@abstractmethod
- def prepare(self) -> None:
+ async def prepare(self) -> None:
"""Do some preparation before training started."""
@property
@@ -238,11 +239,11 @@ def train_step_num(self) -> int:
"""Get the current training step number."""
@abstractmethod
- def train_step(self, batch: Experiences) -> Dict:
+ async def train_step(self, batch_exps: List[Experience]) -> Dict:
"""Training one step.
Args:
- batch (Experiences): A batch of experiences to train.
+ batch_exps (List[Experience]): A batch of experiences to train.
Returns:
Dict: Metrics of the training step.
@@ -271,5 +272,9 @@ def get_trainer_wrapper(config: Config) -> TrainEngineWrapper:
from trinity.trainer.verl_trainer import VerlPPOTrainerWrapper
return VerlPPOTrainerWrapper(config)
+ elif config.trainer.trainer_type == "tinker":
+ from trinity.trainer.tinker_trainer import TinkerTrainerWrapper
+
+ return TinkerTrainerWrapper(config)
else:
raise NotImplementedError
diff --git a/trinity/trainer/verl/utils.py b/trinity/trainer/verl/utils.py
index 922182e643..640ee2b748 100644
--- a/trinity/trainer/verl/utils.py
+++ b/trinity/trainer/verl/utils.py
@@ -2,6 +2,7 @@
import os
from logging import Logger
+from typing import List
import numpy as np
import torch
@@ -10,75 +11,98 @@
from verl.utils.checkpoint.checkpoint_manager import find_latest_ckpt_path
from trinity.common.config import Config
-from trinity.common.experience import Experiences
-
-
-def to_data_proto(experiences: Experiences, logger: Logger) -> DataProto: # noqa: C901
- """Convert Experiences to verl DataProto."""
- attention_mask = experiences.attention_masks
+from trinity.common.experience import (
+ Experience,
+ gather_action_masks,
+ gather_attention_masks,
+ gather_response_attrs,
+ gather_token_ids,
+ split_dpo_experience_to_single_turn,
+)
+
+
+def to_data_proto(
+ experiences: List[Experience], pad_token_id: int, logger: Logger
+) -> DataProto: # noqa: C901
+ """Convert List[Experience] to verl DataProto."""
+ assert len(experiences) > 0, "No experiences provided."
+ if experiences[0].experience_type == "dpo":
+ experiences = split_dpo_experience_to_single_turn(experiences)
+ max_prompt_length = max([exp.prompt_length for exp in experiences])
+ max_response_length = max([len(exp.tokens) - exp.prompt_length for exp in experiences]) # type: ignore
+
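+    # Experiences store unpadded per-sample tensors; the gather_* helpers pad prompts and
+    # responses to max_prompt_length / max_response_length and stack them into batch tensors.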
+ attention_mask = gather_attention_masks(
+ experiences, max_prompt_length, max_response_length
+ ).long()
cumsum = torch.cumsum(attention_mask, dim=-1)
position_ids = torch.clip(cumsum - 1, 0, None).long()
+ tokens = gather_token_ids(
+ experiences, max_prompt_length, max_response_length, pad_token_id
+ ).long()
batch_dict = {
- "uid": np.array([eid.tid for eid in experiences.eids]),
- "unique_ids": np.array([eid.uid for eid in experiences.eids]),
+ "uid": np.array([exp.eid.tid for exp in experiences]),
+ "unique_ids": np.array([exp.eid.uid for exp in experiences]),
"position_ids": position_ids,
- "input_ids": experiences.tokens.long(),
- "responses": experiences.tokens[:, experiences.prompt_length :].long(),
- "attention_mask": attention_mask.long(),
- "response_mask": (
- experiences.action_masks.long()
- if hasattr(experiences, "action_masks") and experiences.action_masks is not None
- else attention_mask[:, experiences.prompt_length :].long()
- ),
+ "input_ids": tokens,
+ "responses": tokens[:, max_prompt_length:],
+ "attention_mask": attention_mask,
+ "response_mask": gather_action_masks(experiences, max_response_length),
}
- if experiences.rewards is not None or experiences.token_level_rewards is not None:
- assert experiences.logprobs is not None
- if experiences.token_level_rewards is not None:
- if experiences.rewards is not None:
+ have_reward = all(exp.reward is not None for exp in experiences)
+ have_token_level_reward = all(exp.token_level_reward is not None for exp in experiences)
+ if have_reward or have_token_level_reward:
+ assert all(exp.logprobs is not None for exp in experiences), "No logprobs provided."
+ if have_token_level_reward:
+ if have_reward:
logger.warning(
"Both experiences.rewards and experiences.token_level_rewards are provided. "
"Using experiences.token_level_rewards."
)
- token_level_rewards = experiences.token_level_rewards
+ token_level_rewards = gather_response_attrs(
+ experiences, "token_level_reward", max_response_length
+ )
else:
- token_level_rewards = torch.zeros(attention_mask.shape, dtype=experiences.rewards.dtype)
+ token_level_rewards = torch.zeros(attention_mask.shape, dtype=torch.float32)
eos_mask_idx = cumsum.argmax(dim=-1)
- token_level_rewards[
- torch.arange(experiences.batch_size), eos_mask_idx
- ] = experiences.rewards
- token_level_rewards = token_level_rewards[:, experiences.prompt_length :]
+ token_level_rewards[torch.arange(len(experiences)), eos_mask_idx] = torch.tensor(
+ [exp.reward for exp in experiences]
+ )
+ token_level_rewards = token_level_rewards[:, max_prompt_length:]
batch_dict.update(
{
"token_level_scores": token_level_rewards,
- "old_log_probs": experiences.logprobs, # type: ignore
+ "old_log_probs": gather_response_attrs(
+ experiences, "logprobs", max_response_length
+ ),
}
)
- if experiences.advantages is not None:
- batch_dict["advantages"] = experiences.advantages
- if experiences.returns is not None:
- batch_dict["returns"] = experiences.returns
- if experiences.teacher_logprobs is not None:
- batch_dict["teacher_log_probs"] = experiences.teacher_logprobs
-
- if experiences.multi_modal_inputs is not None:
- batch_size = len(batch_dict["unique_ids"])
+
+ for attr in ["advantages", "returns", "teacher_logprobs"]:
+ if all(getattr(exp, attr, None) is not None for exp in experiences):
+ batch_dict[attr] = gather_response_attrs(experiences, attr, max_response_length)
+
+ if all(exp.multi_modal_inputs is not None for exp in experiences):
+ keys = experiences[0].multi_modal_inputs.keys()
batch_dict["multi_modal_inputs"] = np.array(
- [
- {k: v[i] for k, v in experiences.multi_modal_inputs.items()}
- for i in range(batch_size)
- ],
+ [{key: exp.multi_modal_inputs[key] for key in keys} for exp in experiences], # type: ignore
dtype=object,
)
- if experiences.custom_fields:
- for field in experiences.custom_fields:
- if hasattr(experiences, field):
- batch_dict[field] = getattr(experiences, field)
+ custom_fields_set = set(tuple(exp.custom_fields) for exp in experiences)
+ if len(custom_fields_set) == 1:
+ custom_fields = list(custom_fields_set)[0]
+ for custom_field in custom_fields:
+ batch_dict[custom_field.destination_field] = torch.tensor(
+ [exp.info[custom_field.source_field] for exp in experiences],
+ dtype=custom_field.data_type,
+ )
+ else:
+ raise ValueError("Custom fields are not consistent across experiences.")
return DataProto.from_single_dict(batch_dict)
-def compute_data_metrics(batch: DataProto, use_critic: bool = False) -> dict:
+def compute_data_metrics(batch: DataProto) -> dict:
"""
Computes various metrics from a batch of data for PPO training.
Modified from verl.trainer.ppo.metric_utils.compute_data_metrics
@@ -89,7 +113,6 @@ def compute_data_metrics(batch: DataProto, use_critic: bool = False) -> dict:
Args:
batch: A DataProto object containing batch data with token-level scores, rewards, advantages, etc.
- use_critic: Whether to include critic-specific metrics. Defaults to True.
Returns:
A dictionary of metrics including:
@@ -97,8 +120,8 @@ def compute_data_metrics(batch: DataProto, use_critic: bool = False) -> dict:
- critic/rewards/mean, max, min: Statistics about sequence rewards
- critic/advantages/mean, max, min: Statistics about advantages
- critic/returns/mean, max, min: Statistics about returns
- - critic/values/mean, max, min: Statistics about critic values (if use_critic=True)
- - critic/vf_explained_var: Explained variance of the value function (if use_critic=True)
+ - critic/values/mean, max, min: Statistics about critic values
+ - critic/vf_explained_var: Explained variance of the value function
- response_length/mean, max, min, clip_ratio: Statistics about response lengths
- prompt_length/mean, max, min, clip_ratio: Statistics about prompt lengths
"""
diff --git a/trinity/trainer/verl_trainer.py b/trinity/trainer/verl_trainer.py
index 16c0525327..9c52af3d66 100644
--- a/trinity/trainer/verl_trainer.py
+++ b/trinity/trainer/verl_trainer.py
@@ -7,7 +7,7 @@
import os
import sys
from collections import defaultdict
-from typing import Dict, Optional
+from typing import Dict, List, Optional
import ray
import torch
@@ -30,11 +30,11 @@
from verl.utils.fs import copy_local_path_from_hdfs
from verl.utils.metric import reduce_metrics
-from trinity.algorithm import ADVANTAGE_FN, ALGORITHM_TYPE, KL_FN, SAMPLE_STRATEGY
+from trinity.algorithm import ADVANTAGE_FN, ALGORITHM_TYPE, KL_FN
from trinity.algorithm.utils import prefix_metrics
from trinity.common.config import Config
from trinity.common.constants import SaveStrategy
-from trinity.common.experience import Experiences
+from trinity.common.experience import Experience
from trinity.trainer.trainer import TrainEngineWrapper
from trinity.trainer.verl.utils import compute_data_metrics, to_data_proto
from trinity.utils.log import get_logger
@@ -187,6 +187,7 @@ def __init__(
global_config: Config,
):
self.logger = get_logger(__name__, in_ray_actor=True)
+ self.pad_token_id = global_config.buffer.pad_token_id
train_config = global_config.trainer
config = OmegaConf.structured(train_config.trainer_config)
# download the checkpoint from hdfs
@@ -261,11 +262,6 @@ def __init__(
self.kl_fn = KL_FN.get(self.algorithm_config.kl_penalty_fn)(
**self.algorithm_config.kl_penalty_fn_args
)
- self.sample_strategy = SAMPLE_STRATEGY.get(global_config.algorithm.sample_strategy)(
- buffer_config=global_config.buffer,
- trainer_type=global_config.trainer.trainer_type,
- **global_config.algorithm.sample_strategy_args,
- )
super().__init__(
config,
tokenizer,
@@ -379,7 +375,7 @@ def init_workers(self):
def train_step_num(self) -> int:
return self.global_steps
- def prepare(self):
+ async def prepare(self):
self.actor_rollout_wg.setup_weight_sync_group()
self.actor_rollout_wg.set_algorithm(self.algorithm_config)
@@ -411,8 +407,8 @@ def save_state_dict(self): # checkpoint sync
def upload_state_dict(self): # state dict sync
self.actor_rollout_wg.upload_state_dict(self.global_steps)
- def train_step(self, batch: Experiences) -> Dict: # noqa C901
- batch = to_data_proto(batch, self.logger)
+ async def train_step(self, batch_exps: List[Experience]) -> Dict: # noqa C901
+ batch = to_data_proto(batch_exps, self.pad_token_id, self.logger) # type: ignore
batch = self.post_process_batch(batch)
metrics = {}
self.global_steps += 1
@@ -476,7 +472,7 @@ def train_step(self, batch: Experiences) -> Dict: # noqa C901
metrics.update(actor_output_metrics)
# collect metrics
- metrics.update(compute_data_metrics(batch=batch, use_critic=self.use_critic))
+ metrics.update(compute_data_metrics(batch=batch))
timing_metrics = compute_timing_metrics(batch=batch, timing_raw=timing_raw)
metrics.update({k.replace("timing_s/", "time/"): v for k, v in timing_metrics.items()})
n_gpus = self.resource_pool_manager.get_n_gpus()