diff --git a/README.md b/README.md index 0df3f8fc80..d1bd93c297 100644 --- a/README.md +++ b/README.md @@ -42,6 +42,7 @@ Trinity-RFT provides functionalities for users with different backgrounds and ob ## 🚀 News +* [2025-12] Trinity-RFT has supported [tinker](https://thinkingmachines.ai/tinker/) training backend, which enables model training on devices **without GPUs**. * [2025-12] Trinity-RFT powers the medical and health business of "Taobao Shangou", enabling the AI agent to understand vague symptoms, proactively ask follow-up questions, and provide precise recommendations ([News](https://tech.china.com.cn/sx/20251201/411376.shtml)). * [2025-11] [[Release Notes](https://github.com/modelscope/Trinity-RFT/releases/tag/v0.3.3)] Trinity-RFT v0.3.3 released: bug fixes. * [2025-11] Introducing [Learn-to-Ask](https://github.com/modelscope/Trinity-RFT/tree/main/examples/learn_to_ask): a framework for training proactive dialogue agents from offline expert data ([paper](https://arxiv.org/pdf/2510.25441)). @@ -154,6 +155,10 @@ We list some algorithms supported by Trinity-RFT in the following table. For mor > [!NOTE] > This project is currently under active development. Comments and suggestions are welcome! +> +> **No GPU? No problem!** You can still try it out: +> 1. Follow the installation steps (feel free to skip GPU-specific packages like `flash-attn`) +> 2. Run the **[Tinker training example](https://github.com/modelscope/Trinity-RFT/tree/main/examples/tinker)**, which is specifically designed to work on CPU-only systems. ### Step 1: installation @@ -186,6 +191,9 @@ Choose one of the following options: conda create -n trinity python=3.12 conda activate trinity +pip install -e ".[verl]" +# If you have no GPU, use Tinker instead. +# pip install -e ".[tinker]" pip install -e ".[dev]" pip install -e ".[flash_attn]" # if you encounter issues when installing flash-attn, try: @@ -198,6 +206,9 @@ pip install -e ".[flash_attn]" python3.10 -m venv .venv source .venv/bin/activate +pip install -e ".[verl]" +# If you have no GPU, use Tinker instead. +# pip install -e ".[tinker]" pip install -e ".[dev]" pip install -e ".[flash_attn]" # if you encounter issues when installing flash-attn, try: diff --git a/README_zh.md b/README_zh.md index a16063f30d..15a4a12a82 100644 --- a/README_zh.md +++ b/README_zh.md @@ -41,6 +41,7 @@ Trinity-RFT 面向不同背景和目标的用户提供相应功能: ## 🚀 新闻 +* [2025-12] Trinity-RFT 已支持 [tinker](https://thinkingmachines.ai/tinker/) 训练后端,可在**无 GPU 的设备**上进行模型训练。 * [2025-12] Trinity-RFT 助力淘宝闪购医药健康业务,让 AI 智能体能够理解模糊症状、主动询问后续问题,并提供精准推荐([新闻](https://tech.china.com.cn/sx/20251201/411376.shtml))。 * [2025-11] [[发布说明](https://github.com/modelscope/Trinity-RFT/releases/tag/v0.3.3)] Trinity-RFT v0.3.3 发布:修复若干 Bug。 * [2025-11] 推出 [Learn-to-Ask](https://github.com/modelscope/Trinity-RFT/tree/main/examples/learn_to_ask):利用离线专家数据,训练具备主动问询能力的对话智能体([论文](https://arxiv.org/pdf/2510.25441)). @@ -79,6 +80,10 @@ Trinity-RFT 面向不同背景和目标的用户提供相应功能: > [!NOTE] > 更多教程请参考 [Trinity-RFT 文档](https://modelscope.github.io/Trinity-RFT/)。 +> +> 没有 GPU?没问题!你仍然可以尝试以下方案: +> 1. 按照安装步骤操作(可跳过 GPU 专用的软件包,例如 `flash-attn`) +> 2. 
运行 **[Tinker 训练示例](https://github.com/modelscope/Trinity-RFT/tree/main/examples/tinker)**,该示例专为仅使用 CPU 的系统设计。 @@ -185,6 +190,9 @@ cd Trinity-RFT conda create -n trinity python=3.12 conda activate trinity +pip install -e ".[verl]" +# 如果没有GPU,需要使用Tinker则改为 +# pip install -e ".[tinker]" pip install -e ".[dev]" pip install -e ".[flash_attn]" # 如果安装 flash-attn 时遇到问题,可尝试: @@ -197,6 +205,9 @@ pip install -e ".[flash_attn]" python3.10 -m venv .venv source .venv/bin/activate +pip install -e ".[verl]" +# 如果没有GPU,需要使用Tinker则改为 +# pip install -e ".[tinker]" pip install -e ".[dev]" pip install -e ".[flash_attn]" # 如果安装 flash-attn 时遇到问题,可尝试: diff --git a/docs/sphinx_doc/assets/tinker-gsm8k.png b/docs/sphinx_doc/assets/tinker-gsm8k.png new file mode 100644 index 0000000000..fe0d1145a3 Binary files /dev/null and b/docs/sphinx_doc/assets/tinker-gsm8k.png differ diff --git a/docs/sphinx_doc/source/tutorial/trinity_configs.md b/docs/sphinx_doc/source/tutorial/trinity_configs.md index b1d08f03b0..5dfc2de4dd 100644 --- a/docs/sphinx_doc/source/tutorial/trinity_configs.md +++ b/docs/sphinx_doc/source/tutorial/trinity_configs.md @@ -164,9 +164,21 @@ model: max_response_tokens: 16384 min_response_tokens: 1 enable_prompt_truncation: true + repetition_penalty: 1.0 + lora_configs: null + rope_scaling: null + rope_theta: null + tinker: + enable: false + base_model: null + rank: 32 + seed: null + train_mlp: true + train_attn: true + train_unembed: true ``` -- `model_path`: Path to the model being trained. +- `model_path`: Path to the model being trained. If `tinker` is enabled, this is the path to the local tokenizer. - `critic_model_path`: Optional path to a separate critic model. If empty, defaults to `model_path`. - `custom_chat_template`: Optional custom chat template in string format. If not specified, the system will use the default chat template from tokenizer. - `chat_template_path`: Optional path to the chat template file in jinja2 type; overrides `custom_chat_template` if set. If not specified, the system will use the default chat template from tokenizer. @@ -175,6 +187,25 @@ model: - `max_prompt_tokens`: Maximum number of tokens allowed in prompts. Only for `chat` and `generate` methods in `InferenceModel`. - `min_response_tokens`: Minimum number of tokens allowed in generated responses. Only for `chat` and `generate` methods in `InferenceModel`. Default is `1`. It must be less than `max_response_tokens`. - `enable_prompt_truncation`: Whether to truncate the prompt. Default is `true`. If set to `true`, the prompt will be truncated to `max_prompt_tokens` tokens; if set to `false`, the prompt will not be truncated and there is a risk that the prompt length plus response length exceeds `max_model_len`. This function does not work with openai api mode. +- `repetition_penalty`: Repetition penalty factor. Default is `1.0`. +- `lora_configs`: Optional LoRA configuration. If not specified, defaults to `null`. Currently, only one LoRA configuration is supported, and this configuration will not be applied if `tinker` is enabled. + - `name`: Name of the LoRA. Default is `None`. + - `path`: Path to the LoRA. Default is `None`. + - `base_model_name`: Name of the base model for LoRA. If not specified, defaults to `None`. + - `lora_rank`: Rank of the LoRA. Default is `32`. + - `lora_alpha`: Alpha value of the LoRA. Default is `32`. + - `lora_dtype`: Data type of the LoRA. Default is `auto`. + - `target_modules`: List of target modules for LoRA. Default is `all-linear`. 
+- `rope_scaling`: Optional RoPE scaling configuration in JSON format. If not specified, defaults to `null`. +- `rope_theta`: Optional RoPE theta value. If not specified, defaults to `null`. +- `tinker`: Optional Tinker configuration. Note: LoRA configuration will be ignored if Tinker is enabled. + - `enable`: Whether to enable Tinker. Default is `false`. + - `base_model`: Path to the base model for Tinker. If not specified, defaults to `model_path`. + - `rank`: LoRA rank controlling the size of adaptation matrices. Default is `32`. + - `seed`: Random seed for Tinker. If not specified, defaults to `null`. + - `train_mlp`: Whether to train the MLP layer. Default is `true`. + - `train_attn`: Whether to train the attention layer. Default is `true`. + - `train_unembed`: Whether to train the unembedding layer. Default is `true`. ```{tip} If you are using the openai API provided by Explorer, only `max_model_len` will take effect, and the value of `max_response_tokens`, `max_prompt_tokens`, and `min_response_tokens` will be ignored. When `max_tokens` is not independently specified, each API call will generate up to `max_model_len - prompt_length` tokens. Therefore, please ensure that the prompt length is less than `max_model_len` when using the API. diff --git a/docs/sphinx_doc/source_zh/tutorial/trinity_configs.md b/docs/sphinx_doc/source_zh/tutorial/trinity_configs.md index 0fff6cd91f..ec524252e4 100644 --- a/docs/sphinx_doc/source_zh/tutorial/trinity_configs.md +++ b/docs/sphinx_doc/source_zh/tutorial/trinity_configs.md @@ -164,9 +164,21 @@ model: max_response_tokens: 16384 min_response_tokens: 1 enable_prompt_truncation: true + repetition_penalty: 1.0 + lora_configs: null + rope_scaling: null + rope_theta: null + tinker: + enable: false + base_model: null + rank: 32 + seed: null + train_mlp: true + train_attn: true + train_unembed: true ``` -- `model_path`: 被训练模型的路径。 +- `model_path`: 被训练模型的路径。如果启用了`tinker`,则该路径为本地 tokenizer 的路径。 - `critic_model_path`: 可选的独立 critic 模型路径。若为空,则默认为 `model_path`。 - `custom_chat_template`: 可选的自定义 chat template 字符串格式。若未指定,系统会使用 tokenizer 的默认 chat template。 - `chat_template_path`: 可选的 chat template 文件路径,类型通常为 jinja2;若设置,则覆盖 `custom_chat_template`。若未指定,系统会使用 tokenizer 的默认 chat template。 @@ -175,6 +187,25 @@ model: - `max_response_tokens`: 模型生成的回复中允许的最大 token 数。仅对 `InferenceModel` 中的 `chat` 和 `generate` 方法生效。 - `min_response_tokens`: 模型生成的回复中允许的最小 token 数。仅对 `InferenceModel` 中的 `chat` 和 `generate` 方法生效。 - `enable_prompt_truncation`: 是否截断 prompt。默认为 `true`。若设置为 `true`,则 prompt 将被截断为 `max_prompt_tokens` 个 token;若设置为 `false`,则 prompt 不会被截断,存在 prompt 和 response 长度之和超过 `max_model_len` 的风险。在 OpenAI API 模式下不生效。 +- `repetition_penalty`:重复惩罚因子。默认值为 `1.0`。 +- `lora_configs`:可选的 LoRA 配置。若未指定,则默认为 `null`。目前仅支持一个 LoRA 配置,并且如果启用了`tinker`,则不会使用此LoRA配置。 + - `name`:LoRA 的名称。默认为 `None`。 + - `path`:LoRA 的路径。默认为 `None`。 + - `base_model_name`:LoRA 所基于的基础模型名称。若未指定,则默认为 `None`。 + - `lora_rank`:LoRA 的秩(rank)。默认为 `32`。 + - `lora_alpha`:LoRA 的 alpha 值。默认为 `32`。 + - `lora_dtype`:LoRA 的数据类型。默认为 `auto`。 + - `target_modules`:LoRA 的目标模块列表。默认为 `all-linear`。 +- `rope_scaling`:可选的 RoPE 缩放配置,采用 JSON 格式。若未指定,则默认为 `null`。 +- `rope_theta`:可选的 RoPE theta 值。若未指定,则默认为 `null`。 +- `tinker`:可选的 Tinker 配置。注意:若启用 Tinker,则 LoRA 配置将被忽略。 + - `enable`:是否启用 Tinker。默认为 `false`。 + - `base_model`:Tinker 所使用的基础模型路径。若未指定,则默认为 `model_path`。 + - `rank`:控制适配矩阵大小的 LoRA 秩(rank)。默认为 `32`。 + - `seed`:Tinker 使用的随机种子。若未指定,则默认为 `null`。 + - `train_mlp`:是否训练 MLP 层。默认为 `true`。 + - `train_attn`:是否训练注意力层。默认为 `true`。 + - 
`train_unembed`:是否训练反嵌入(unembedding)层。默认为 `true`。 ```{tip} 如果使用的是 Explorer 提供的 openai API,则只有 `max_model_len` 会生效,而 `max_response_tokens`、`max_prompt_tokens` 和 `min_response_tokens` 的值将被忽略,在没有独立指定 `max_tokens` 时,每次 API 调用将生成最多 `max_model_len - prompt_length` 个 token,因此在使用时请确保 prompt 长度小于 `max_model_len`。 diff --git a/examples/tinker/README.md b/examples/tinker/README.md new file mode 100644 index 0000000000..79de2c959f --- /dev/null +++ b/examples/tinker/README.md @@ -0,0 +1,245 @@ +# Trinity with Tinker Backend + +> [!NOTE] +> This example demonstrates how to use Trinity with the [Tinker](https://thinkingmachines.ai/tinker/) backend, which enables model training on devices **without GPUs**. + +## Setup Instructions + +### 1. API Key Configuration +Before starting Ray, you must set the `TRINITY_API_KEY` environment variable to your Tinker API key to enable proper access to Tinker's API: + +```bash +export TRINITY_API_KEY=your_tinker_api_key +``` + +### 2. Configuration File +Configure the Tinker backend in your YAML configuration file by setting the `model.tinker` parameters as shown below: + +```yaml +model: + tinker: + enable: true + base_model: null + rank: 32 + seed: null + train_mlp: true + train_attn: true + train_unembed: true +``` + +### 3. Configuration Parameters Explained + +- **`tinker`**: Tinker-specific configuration section. **Important**: When Tinker is enabled, any LoRA configuration settings (`model.lora_configs`) will be ignored. + - **`enable`**: Whether to activate the Tinker backend. Default: `false` + - **`base_model`**: Path to the base model for Tinker. If not specified (`null`), it defaults to the `model_path` defined elsewhere in your config + - **`rank`**: The LoRA rank that controls the size of the adaptation matrices. Default: `32` + - **`seed`**: Random seed for reproducible Tinker operations. If not specified (`null`), no specific seed is set + - **`train_mlp`**: Whether to train the MLP (feed-forward) layers. Default: `true` + - **`train_attn`**: Whether to train the attention layers. Default: `true` + - **`train_unembed`**: Whether to train the unembedding (output) layer. Default: `true` + + +## Usage + +Once configured, Trinity works with the Tinker backend just like it does with the standard veRL backend. Start training with: + +```bash +trinity run --config tinker.yaml # Replace with your actual config file path +``` + +### Important Limitations of the Tinker Backend + +1. **Entropy loss** is not consistent compared to veRL backends. +2. **Algorithms requiring `compute_advantage_in_trainer=true` are NOT supported currently**, including: + - PPO (`algorithm.algorithm_type=ppo`) + - Reinforce++ (`algorithm.algorithm_type=reinforceplusplus`) + - RLOO (`algorithm.algorithm_type=rloo`) + - On-policy distillation (`algorithm.algorithm_type=on_policy_distill`) + + Algorithms like `algorithm.algorithm_type=grpo` are supported. We will add support for these algorithms in the future. +3. **Multiple stages training** is not supported currently, we will add support for this in the future. + +> 💡 A complete example configuration file is available at [`tinker.yaml`](tinker.yaml). + + +## Results on the Llama-3.2-3B Model + +We trained the **Llama-3.2-3B** model on the **GSM8K** dataset using both the **Tinker** and **veRL** backends. Below are the full configuration files used in our experiments. + + +
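+The two runs share the same GRPO recipe, GSM8K taskset, batch size, and synchronizer settings; aside from the explicit vLLM rollout settings in the veRL file, they differ mainly in how the trainable model and backend are declared. The sketch below condenses that difference from the full files that follow (it is an excerpt for orientation, not a runnable configuration on its own):
+
+```yaml
+# Tinker backend run: declare the model through model.tinker.
+model:
+  model_path: meta-llama/Llama-3.2-3B
+  tinker:
+    enable: true
+    base_model: meta-llama/Llama-3.2-3B
+---
+# veRL backend run: keep the verl trainer and attach a LoRA adapter instead.
+model:
+  model_path: meta-llama/Llama-3.2-3B
+  lora_configs:
+    - name: lora
+      lora_rank: 32
+      lora_alpha: 32
+trainer:
+  trainer_type: verl
+```
+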
Click to expand: Tinker Backend Configuration + +```yaml +mode: both +project: Trinity-RFT-gsm8k +group: alignment-tinker +name: tinker-llama3.2-3B-off1 +checkpoint_root_dir: ${oc.env:TRINITY_CHECKPOINT_ROOT_DIR,./checkpoints} +algorithm: + algorithm_type: grpo + repeat_times: 8 + sample_strategy: default + kl_loss_fn_args: + kl_coef: 0.0 + optimizer: + lr: 1.0e-05 + lr_warmup_steps_ratio: 0.0 + warmup_style: constant +data_processor: {} +model: + model_path: meta-llama/Llama-3.2-3B + max_prompt_tokens: 1024 + max_response_tokens: 2048 + custom_chat_template: "{{- bos_token }}\n{%- if custom_tools is defined %}\n {%- set tools = custom_tools %}\n{%- endif %}\n{%- if not tools_in_user_message is defined %}\n {%- set tools_in_user_message = true %}\n{%- endif %}\n{%- if not date_string is defined %}\n {%- if strftime_now is defined %}\n {%- set date_string = strftime_now(\"%d %b %Y\") %}\n {%- else %}\n {%- set date_string = \"26 Jul 2024\" %}\n {%- endif %}\n{%- endif %}\n{%- if not tools is defined %}\n {%- set tools = none %}\n{%- endif %}\n\n{#- This block extracts the system message, so we can slot it into the right place. #}\n{%- if messages[0]['role'] == 'system' %}\n {%- set system_message = messages[0]['content']|trim %}\n {%- set messages = messages[1:] %}\n{%- else %}\n {%- set system_message = \"\" %}\n{%- endif %}\n\n{#- System message #}\n{{- \"<|start_header_id|>system<|end_header_id|>\\n\\n\" }}\n{%- if tools is not none %}\n {{- \"Environment: ipython\\n\" }}\n{%- endif %}\n{{- \"Cutting Knowledge Date: December 2023\\n\" }}\n{{- \"Today Date: \" + date_string + \"\\n\\n\" }}\n{%- if tools is not none and not tools_in_user_message %}\n {{- \"You have access to the following functions. To call a function, please respond with JSON for a function call.\" }}\n {{- 'Respond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.' }}\n {{- \"Do not use variables.\\n\\n\" }}\n {%- for t in tools %}\n {{- t | tojson(indent=4) }}\n {{- \"\\n\\n\" }}\n {%- endfor %}\n{%- endif %}\n{{- system_message }}\n{{- \"<|eot_id|>\" }}\n\n{#- Custom tools are passed in a user message with some extra guidance #}\n{%- if tools_in_user_message and not tools is none %}\n {#- Extract the first user message so we can plug it in here #}\n {%- if messages | length != 0 %}\n {%- set first_user_message = messages[0]['content']|trim %}\n {%- set messages = messages[1:] %}\n {%- else %}\n {{- raise_exception(\"Cannot put tools in the first user message when there's no first user message!\") }}\n{%- endif %}\n {{- '<|start_header_id|>user<|end_header_id|>\\n\\n' -}}\n {{- \"Given the following functions, please respond with a JSON for a function call \" }}\n {{- \"with its proper arguments that best answers the given prompt.\\n\\n\" }}\n {{- 'Respond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.' 
}}\n {{- \"Do not use variables.\\n\\n\" }}\n {%- for t in tools %}\n {{- t | tojson(indent=4) }}\n {{- \"\\n\\n\" }}\n {%- endfor %}\n {{- first_user_message + \"<|eot_id|>\"}}\n{%- endif %}\n\n{%- for message in messages %}\n {%- if not (message.role == 'ipython' or message.role == 'tool' or 'tool_calls' in message) %}\n {{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\\n\\n'+ message['content'] | trim + '<|eot_id|>' }}\n {%- elif 'tool_calls' in message %}\n {%- if not message.tool_calls|length == 1 %}\n {{- raise_exception(\"This model only supports single tool-calls at once!\") }}\n {%- endif %}\n {%- set tool_call = message.tool_calls[0].function %}\n {{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' -}}\n {{- '{\"name\": \"' + tool_call.name + '\", ' }}\n {{- '\"parameters\": ' }}\n {{- tool_call.arguments | tojson }}\n {{- \"}\" }}\n {{- \"<|eot_id|>\" }}\n {%- elif message.role == \"tool\" or message.role == \"ipython\" %}\n {{- \"<|start_header_id|>ipython<|end_header_id|>\\n\\n\" }}\n {%- if message.content is mapping or message.content is iterable %}\n {{- message.content | tojson }}\n {%- else %}\n {{- message.content }}\n {%- endif %}\n {{- \"<|eot_id|>\" }}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' }}\n{%- endif %}\n" + tinker: + enable: true + base_model: meta-llama/Llama-3.2-3B +cluster: + node_num: 1 + gpu_per_node: 8 +buffer: + batch_size: 96 + total_epochs: 1 + explorer_input: + taskset: + name: taskset + storage_type: file + path: openai/gsm8k + split: train + subset_name: main + format: + prompt_key: question + response_key: answer + rollout_args: + temperature: 1.0 + logprobs: 0 + eval_tasksets: [] + default_workflow_type: math_workflow + trainer_input: + experience_buffer: + name: experience_buffer + storage_type: queue + replay_buffer: + enable: false +explorer: + runner_per_model: 16 + rollout_model: + engine_num: 4 + seed: 42 + auxiliary_models: [] + eval_interval: 1000 +trainer: + save_interval: 100 + enable_preview: true + grad_clip: 1.0 + max_token_len_per_gpu: 16384 +monitor: + monitor_type: wandb +synchronizer: + sync_method: checkpoint + sync_style: fixed + sync_interval: 1 + sync_offset: 1 + sync_timeout: 1200 +``` + +
+ + +
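+The Tinker configuration above omits several fields that the veRL configuration below sets explicitly. This is intentional: when `model.tinker.enable` is `true`, Trinity's configuration check (`_check_tinker` in `trinity/common/config.py`) rewrites these values automatically, so the effective settings sketched below do not need to appear in the YAML file:
+
+```yaml
+explorer:
+  rollout_model:
+    engine_type: tinker    # forced to "tinker" whenever model.tinker.enable is true
+trainer:
+  trainer_type: tinker     # likewise forced to "tinker"
+synchronizer:
+  sync_method: checkpoint  # only rewritten if NCCL was requested; NCCL is not supported with Tinker
+```
+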
Click to expand: veRL Backend Configuration (LoRA) + +```yaml +mode: both +project: Trinity-RFT-gsm8k +group: alignment-tinker +name: verl-llama3.2-3B-lora-off1 +checkpoint_root_dir: ${oc.env:TRINITY_CHECKPOINT_ROOT_DIR,./checkpoints} +algorithm: + algorithm_type: grpo + repeat_times: 8 + sample_strategy: default + kl_loss_fn_args: + kl_coef: 0.0 + optimizer: + lr: 1.0e-05 + lr_warmup_steps_ratio: 0.0 + warmup_style: constant +data_processor: {} +model: + model_path: meta-llama/Llama-3.2-3B + max_prompt_tokens: 1024 + max_response_tokens: 2048 + custom_chat_template: "{{- bos_token }}\n{%- if custom_tools is defined %}\n {%- set tools = custom_tools %}\n{%- endif %}\n{%- if not tools_in_user_message is defined %}\n {%- set tools_in_user_message = true %}\n{%- endif %}\n{%- if not date_string is defined %}\n {%- if strftime_now is defined %}\n {%- set date_string = strftime_now(\"%d %b %Y\") %}\n {%- else %}\n {%- set date_string = \"26 Jul 2024\" %}\n {%- endif %}\n{%- endif %}\n{%- if not tools is defined %}\n {%- set tools = none %}\n{%- endif %}\n\n{#- This block extracts the system message, so we can slot it into the right place. #}\n{%- if messages[0]['role'] == 'system' %}\n {%- set system_message = messages[0]['content']|trim %}\n {%- set messages = messages[1:] %}\n{%- else %}\n {%- set system_message = \"\" %}\n{%- endif %}\n\n{#- System message #}\n{{- \"<|start_header_id|>system<|end_header_id|>\\n\\n\" }}\n{%- if tools is not none %}\n {{- \"Environment: ipython\\n\" }}\n{%- endif %}\n{{- \"Cutting Knowledge Date: December 2023\\n\" }}\n{{- \"Today Date: \" + date_string + \"\\n\\n\" }}\n{%- if tools is not none and not tools_in_user_message %}\n {{- \"You have access to the following functions. To call a function, please respond with JSON for a function call.\" }}\n {{- 'Respond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.' }}\n {{- \"Do not use variables.\\n\\n\" }}\n {%- for t in tools %}\n {{- t | tojson(indent=4) }}\n {{- \"\\n\\n\" }}\n {%- endfor %}\n{%- endif %}\n{{- system_message }}\n{{- \"<|eot_id|>\" }}\n\n{#- Custom tools are passed in a user message with some extra guidance #}\n{%- if tools_in_user_message and not tools is none %}\n {#- Extract the first user message so we can plug it in here #}\n {%- if messages | length != 0 %}\n {%- set first_user_message = messages[0]['content']|trim %}\n {%- set messages = messages[1:] %}\n {%- else %}\n {{- raise_exception(\"Cannot put tools in the first user message when there's no first user message!\") }}\n{%- endif %}\n {{- '<|start_header_id|>user<|end_header_id|>\\n\\n' -}}\n {{- \"Given the following functions, please respond with a JSON for a function call \" }}\n {{- \"with its proper arguments that best answers the given prompt.\\n\\n\" }}\n {{- 'Respond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.' 
}}\n {{- \"Do not use variables.\\n\\n\" }}\n {%- for t in tools %}\n {{- t | tojson(indent=4) }}\n {{- \"\\n\\n\" }}\n {%- endfor %}\n {{- first_user_message + \"<|eot_id|>\"}}\n{%- endif %}\n\n{%- for message in messages %}\n {%- if not (message.role == 'ipython' or message.role == 'tool' or 'tool_calls' in message) %}\n {{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\\n\\n'+ message['content'] | trim + '<|eot_id|>' }}\n {%- elif 'tool_calls' in message %}\n {%- if not message.tool_calls|length == 1 %}\n {{- raise_exception(\"This model only supports single tool-calls at once!\") }}\n {%- endif %}\n {%- set tool_call = message.tool_calls[0].function %}\n {{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' -}}\n {{- '{\"name\": \"' + tool_call.name + '\", ' }}\n {{- '\"parameters\": ' }}\n {{- tool_call.arguments | tojson }}\n {{- \"}\" }}\n {{- \"<|eot_id|>\" }}\n {%- elif message.role == \"tool\" or message.role == \"ipython\" %}\n {{- \"<|start_header_id|>ipython<|end_header_id|>\\n\\n\" }}\n {%- if message.content is mapping or message.content is iterable %}\n {{- message.content | tojson }}\n {%- else %}\n {{- message.content }}\n {%- endif %}\n {{- \"<|eot_id|>\" }}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' }}\n{%- endif %}\n" + lora_configs: + - name: lora + lora_rank: 32 + lora_alpha: 32 +cluster: + node_num: 1 + gpu_per_node: 8 +buffer: + batch_size: 96 + total_epochs: 1 + explorer_input: + taskset: + name: taskset + storage_type: file + path: openai/gsm8k + split: train + subset_name: main + format: + prompt_key: question + response_key: answer + rollout_args: + temperature: 1.0 + logprobs: 0 + eval_tasksets: [] + default_workflow_type: math_workflow + trainer_input: + experience_buffer: + name: experience_buffer + storage_type: queue + replay_buffer: + enable: false +explorer: + runner_per_model: 16 + rollout_model: + engine_num: 4 + tensor_parallel_size: 1 + enforce_eager: false + enable_prefix_caching: false + enable_chunked_prefill: false + gpu_memory_utilization: 0.9 + dtype: bfloat16 + seed: 42 + enable_thinking: false + enable_history: false + enable_openai_api: false + enable_auto_tool_choice: false + tool_call_parser: null + reasoning_parser: null + auxiliary_models: [] + eval_interval: 1000 +trainer: + trainer_type: verl + save_interval: 100 + enable_preview: true + grad_clip: 1.0 + max_token_len_per_gpu: 16384 +monitor: + monitor_type: wandb +synchronizer: + sync_method: checkpoint + sync_style: fixed + sync_interval: 1 + sync_offset: 1 + sync_timeout: 1200 +``` + +
+ +### Observations + +Since Llama-3.2-3B is a base (non-instruct-tuned) model, it has limited ability to follow formatting instructions. Additionally, we trained for only **one epoch**. As a result, both backends achieved final rewards just slightly above 0.1. Nonetheless, the training curves show a clear upward trend in reward, indicating successful learning. The results are visualized below: + +![Training Rewards on GSM8K](../../docs/sphinx_doc/assets/tinker-gsm8k.png) diff --git a/examples/tinker/tinker.yaml b/examples/tinker/tinker.yaml new file mode 100644 index 0000000000..744357e745 --- /dev/null +++ b/examples/tinker/tinker.yaml @@ -0,0 +1,67 @@ +mode: both +project: Trinity-RFT-gsm8k +name: tinker-Qwen3-4B +checkpoint_root_dir: ${oc.env:TRINITY_CHECKPOINT_ROOT_DIR,./checkpoints} +algorithm: + algorithm_type: grpo + repeat_times: 8 + sample_strategy: default + kl_loss_fn_args: + kl_coef: 0.0 + optimizer: + lr: 1.0e-05 + lr_warmup_steps_ratio: 0.0 + warmup_style: constant +data_processor: {} +model: + model_path: Qwen/Qwen3-4B-Instruct-2507 + max_prompt_tokens: 1024 + max_response_tokens: 2048 + tinker: + enable: true + base_model: Qwen/Qwen3-4B-Instruct-2507 +buffer: + batch_size: 96 + total_epochs: 1 + explorer_input: + taskset: + name: taskset + storage_type: file + path: openai/gsm8k + split: train + subset_name: main + format: + prompt_key: question + response_key: answer + rollout_args: + temperature: 1.0 + logprobs: 0 + eval_tasksets: [] + default_workflow_type: math_workflow + trainer_input: + experience_buffer: + name: experience_buffer + storage_type: queue + replay_buffer: + enable: false +explorer: + runner_per_model: 8 + rollout_model: + engine_num: 4 + seed: 42 + auxiliary_models: [] + eval_interval: 1000 +trainer: + save_interval: 100 + enable_preview: true + grad_clip: 1.0 + max_token_len_per_gpu: 16384 +monitor: + monitor_type: tensorboard +synchronizer: + sync_method: memory + sync_style: fixed + sync_interval: 1 + sync_timeout: 1200 +log: + level: INFO diff --git a/pyproject.toml b/pyproject.toml index b7e3227a0a..8a8faae156 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,9 +21,7 @@ classifiers = [ ] requires-python = ">=3.10,<3.13" dependencies = [ - "verl==0.5.0", "ray[default]>=2.50.0", - "vllm>=0.10.2,<=0.11.0", "tensordict", "wandb", "omegaconf", @@ -43,12 +41,17 @@ dependencies = [ "sortedcontainers", "word2number", "transformers", + "datasets", ] [project.scripts] trinity = "trinity.cli.launcher:main" [project.optional-dependencies] +verl = [ + "verl==0.5.0", + "vllm>=0.10.2,<=0.11.0", +] data = [ "py-data-juicer>=1.4.3" ] @@ -78,6 +81,9 @@ megatron = [ "transformer_engine[pytorch]==2.8.0", "mbridge>=0.13.0", ] +tinker = [ + "tinker", # tinker requires python>=3.11 +] doc = [ "sphinx", diff --git a/tests/buffer/sample_strategy_test.py b/tests/buffer/sample_strategy_test.py index 32ea84bdb7..c3a9af9179 100644 --- a/tests/buffer/sample_strategy_test.py +++ b/tests/buffer/sample_strategy_test.py @@ -58,7 +58,9 @@ def _init_buffer_writer_and_sample_strategy(self): async def _verify_model_version(self, step, expected_versions): batch, metrics, _ = await self.sample_strategy.sample(step=step) self.assertEqual( - batch.rewards.tolist(), expected_versions, f"Model versions mismatch at step {step}" + [exp.reward for exp in batch], + expected_versions, + f"Model versions mismatch at step {step}", ) self.assertEqual( metrics["sample/model_version/min"], diff --git a/tests/common/vllm_test.py b/tests/common/vllm_test.py index 61fd03e675..187c0eb544 100644 --- 
a/tests/common/vllm_test.py +++ b/tests/common/vllm_test.py @@ -125,17 +125,13 @@ def setUp(self): self.config.algorithm.repeat_times = self.repeat_times self.config.explorer.rollout_model.enable_history = self.enable_history self.config.check_and_update() - from pprint import pprint - pprint(self.config) self.engines, self.auxiliary_engines = create_inference_models(self.config) self.model_wrapper = ModelWrapper( self.engines[0], engine_type="vllm", enable_history=self.enable_history ) - async def test_generate( - self, - ): + async def test_generate(self): await prepare_engines(self.engines, self.auxiliary_engines) await self.model_wrapper.prepare() self.assertEqual(self.model_wrapper.model_path, self.config.model.model_path) diff --git a/tests/trainer/trainer_test.py b/tests/trainer/trainer_test.py index 972071b99c..5854a04f8f 100644 --- a/tests/trainer/trainer_test.py +++ b/tests/trainer/trainer_test.py @@ -1322,3 +1322,38 @@ def test_trainer(self): def tearDown(self): # remove dir only when the test passed shutil.rmtree(self.config.checkpoint_job_dir) + + +class TestTinkerTrainer(BaseTrainerCase): + @unittest.skip("Require tinker API key") + def test_trainer(self): + """Test GSM8K on tinker.""" + # test both mode + self.config.algorithm.algorithm_type = "grpo" + self.config.algorithm.repeat_times = 4 + self.config.algorithm.advantage_fn = "grpo" + self.config.algorithm.advantage_fn_args = { + "epsilon": 1e-6, + } + self.config.buffer.total_epochs = 1 + self.config.buffer.explorer_input.taskset = get_unittest_dataset_config("gsm8k") + self.config.model.tinker.enable = True + self.config.model.tinker.base_model = "Qwen/Qwen3-4B-Instruct-2507" + self.config.check_and_update() + both(self.config) + parser = TensorBoardParser(os.path.join(self.config.monitor.cache_dir, "tensorboard")) + rollout_metrics = parser.metric_list("rollout") + self.assertTrue(len(rollout_metrics) > 0) + pipeline_metrics = parser.metric_list("experience_pipeline") + self.assertTrue(len(pipeline_metrics) > 0) + self.assertEqual(parser.metric_max_step(rollout_metrics[0]), 4) + actor_metrics = parser.metric_list("actor") + self.assertTrue(len(actor_metrics) > 0) + self.assertEqual(parser.metric_max_step(actor_metrics[0]), 4) + response_metrics = parser.metric_list("response_length") + self.assertTrue(len(response_metrics) > 0) + self.assertEqual(parser.metric_max_step(response_metrics[0]), 4) + + def tearDown(self): + # remove dir only when the test passed + shutil.rmtree(self.config.checkpoint_job_dir) diff --git a/trinity/algorithm/advantage_fn/asymre_advantage.py b/trinity/algorithm/advantage_fn/asymre_advantage.py index 52f2314272..5c2ccf9932 100644 --- a/trinity/algorithm/advantage_fn/asymre_advantage.py +++ b/trinity/algorithm/advantage_fn/asymre_advantage.py @@ -1,10 +1,12 @@ """AsymRE advantage computation""" from collections import defaultdict -from typing import Dict, List, Tuple +from typing import TYPE_CHECKING, Dict, List, Tuple import torch -from verl import DataProto + +if TYPE_CHECKING: + from verl import DataProto from trinity.algorithm.advantage_fn import AdvantageFn, GroupAdvantage from trinity.common.experience import Experience, group_by @@ -23,9 +25,9 @@ def __init__( def __call__( self, - exps: DataProto, + exps: "DataProto", **kwargs, - ) -> Tuple[DataProto, Dict]: + ) -> Tuple["DataProto", Dict]: """Modified from compute_grpo_outcome_advantage Compute advantage for AsymRE, operating only on Outcome reward diff --git a/trinity/algorithm/advantage_fn/grpo_advantage.py 
b/trinity/algorithm/advantage_fn/grpo_advantage.py index 7d1c58977d..4562a5a4b9 100644 --- a/trinity/algorithm/advantage_fn/grpo_advantage.py +++ b/trinity/algorithm/advantage_fn/grpo_advantage.py @@ -3,10 +3,12 @@ import copy from collections import defaultdict -from typing import Dict, List, Optional, Tuple +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple import torch -from verl import DataProto + +if TYPE_CHECKING: + from verl import DataProto from trinity.algorithm.advantage_fn.advantage_fn import AdvantageFn, GroupAdvantage from trinity.common.experience import Experience, group_by @@ -26,9 +28,9 @@ def __init__( def __call__( self, - exps: DataProto, + exps: "DataProto", **kwargs, - ) -> Tuple[DataProto, Dict]: + ) -> Tuple["DataProto", Dict]: """ Compute advantage for GRPO, operating only on Outcome reward (with only one scalar reward for each response). diff --git a/trinity/algorithm/advantage_fn/opmd_advantage.py b/trinity/algorithm/advantage_fn/opmd_advantage.py index d5e9203e3c..8c0a586986 100644 --- a/trinity/algorithm/advantage_fn/opmd_advantage.py +++ b/trinity/algorithm/advantage_fn/opmd_advantage.py @@ -1,10 +1,12 @@ """OPMD advantage computation""" from collections import defaultdict -from typing import Dict, List, Tuple +from typing import TYPE_CHECKING, Dict, List, Tuple import torch -from verl import DataProto + +if TYPE_CHECKING: + from verl import DataProto from trinity.algorithm.advantage_fn.advantage_fn import AdvantageFn, GroupAdvantage from trinity.common.experience import Experience, group_by @@ -25,9 +27,9 @@ def __init__( def __call__( self, - exps: DataProto, + exps: "DataProto", **kwargs, - ) -> Tuple[DataProto, Dict]: + ) -> Tuple["DataProto", Dict]: """Modified from compute_grpo_outcome_advantage Compute advantage for OPMD, operating only on Outcome reward diff --git a/trinity/algorithm/key_mapper.py b/trinity/algorithm/key_mapper.py index 09c1f988a6..4813c18fb5 100644 --- a/trinity/algorithm/key_mapper.py +++ b/trinity/algorithm/key_mapper.py @@ -26,4 +26,5 @@ def from_trinity(self, key: str) -> str: "advantages": "advantages", } ), + "tinker": KeyMapper({}), } diff --git a/trinity/algorithm/sample_strategy/mix_sample_strategy.py b/trinity/algorithm/sample_strategy/mix_sample_strategy.py index 65acc44d3c..8f0f6ed3d5 100644 --- a/trinity/algorithm/sample_strategy/mix_sample_strategy.py +++ b/trinity/algorithm/sample_strategy/mix_sample_strategy.py @@ -8,7 +8,7 @@ from trinity.algorithm.sample_strategy.utils import representative_sample from trinity.buffer import get_buffer_reader from trinity.common.config import BufferConfig -from trinity.common.experience import CustomField, Experiences +from trinity.common.experience import CustomField, Experience from trinity.utils.timer import Timer @@ -53,7 +53,7 @@ def __init__(self, buffer_config: BufferConfig, **kwargs): expert_buffer_config, ) - async def sample(self, step: int) -> Tuple[Experiences, Dict, List]: + async def sample(self, step: int) -> Tuple[List[Experience], Dict, List]: metrics = {} with Timer(metrics, "time/read_experience"): usual_exp_list = await self.usual_exp_buffer.read_async() @@ -82,24 +82,21 @@ async def sample(self, step: int) -> Tuple[Experiences, Dict, List]: repr_samples = representative_sample(exp_list) self.set_model_version_metric(exp_list, metrics) - with Timer(metrics, "time/gather_experience"): - exps = Experiences.gather_experiences( - experiences=exp_list, - pad_token_id=self.pad_token_id, # type: ignore [arg-type] - custom_fields=[ - CustomField( - 
source_field="is_expert", - destination_field="expert_mask", - data_type=torch.bool, - ), - CustomField( - source_field="step", - destination_field="step", - data_type=torch.int32, - ), - ], - ) # type: ignore - return exps, metrics, repr_samples + custom_fields = [ + CustomField( + source_field="is_expert", + destination_field="expert_mask", + data_type=torch.bool, + ), + CustomField( + source_field="step", + destination_field="step", + data_type=torch.int32, + ), + ] + for exp in exp_list: + exp.custom_fields = custom_fields + return exp_list, metrics, repr_samples @classmethod def default_args(cls) -> Dict: diff --git a/trinity/algorithm/sample_strategy/sample_strategy.py b/trinity/algorithm/sample_strategy/sample_strategy.py index 2ab63032cb..2398961ae6 100644 --- a/trinity/algorithm/sample_strategy/sample_strategy.py +++ b/trinity/algorithm/sample_strategy/sample_strategy.py @@ -4,7 +4,7 @@ from trinity.algorithm.sample_strategy.utils import representative_sample from trinity.buffer import get_buffer_reader from trinity.common.config import BufferConfig -from trinity.common.experience import Experience, Experiences +from trinity.common.experience import Experience from trinity.utils.annotations import Deprecated from trinity.utils.monitor import gather_metrics from trinity.utils.timer import Timer @@ -12,7 +12,7 @@ class SampleStrategy(ABC): def __init__(self, buffer_config: BufferConfig, **kwargs) -> None: - self.pad_token_id = buffer_config.pad_token_id + pass def set_model_version_metric(self, exp_list: List[Experience], metrics: Dict): metric_list = [ @@ -23,14 +23,14 @@ def set_model_version_metric(self, exp_list: List[Experience], metrics: Dict): metrics.update(gather_metrics(metric_list, "sample")) @abstractmethod - async def sample(self, step: int) -> Tuple[Experiences, Dict, List]: + async def sample(self, step: int) -> Tuple[List[Experience], Dict, List]: """Sample data from buffer. Args: step (`int`): The step number of current step. Returns: - `Experiences`: The sampled Experiences data. + `List[Experience]`: The sampled List[Experience] data. `Dict`: Metrics for logging. `List`: Representative data for logging. 
""" @@ -54,15 +54,13 @@ def __init__(self, buffer_config: BufferConfig, **kwargs): super().__init__(buffer_config) self.exp_buffer = get_buffer_reader(buffer_config.trainer_input.experience_buffer) # type: ignore[arg-type] - async def sample(self, step: int, **kwargs) -> Tuple[Experiences, Dict, List]: + async def sample(self, step: int, **kwargs) -> Tuple[List[Experience], Dict, List]: metrics = {} with Timer(metrics, "time/read_experience"): exp_list = await self.exp_buffer.read_async() repr_samples = representative_sample(exp_list) self.set_model_version_metric(exp_list, metrics) - with Timer(metrics, "time/gather_experience"): - exps = Experiences.gather_experiences(exp_list, self.pad_token_id) # type: ignore - return exps, metrics, repr_samples + return exp_list, metrics, repr_samples @classmethod def default_args(cls) -> dict: @@ -81,16 +79,14 @@ def __init__(self, buffer_config: BufferConfig, **kwargs): super().__init__(buffer_config) self.max_staleness = kwargs.get("max_staleness", float("inf")) - async def sample(self, step: int, **kwargs) -> Tuple[Experiences, Dict, List]: + async def sample(self, step: int, **kwargs) -> Tuple[List[Experience], Dict, List]: min_model_version = max(step - self.max_staleness, 0) metrics = {} with Timer(metrics, "time/read_experience"): exp_list = await self.exp_buffer.read_async(min_model_version=min_model_version) repr_samples = representative_sample(exp_list) self.set_model_version_metric(exp_list, metrics) - with Timer(metrics, "time/gather_experience"): - exps = Experiences.gather_experiences(exp_list, self.pad_token_id) # type: ignore - return exps, metrics, repr_samples + return exp_list, metrics, repr_samples @Deprecated diff --git a/trinity/buffer/reader/queue_reader.py b/trinity/buffer/reader/queue_reader.py index b3b1d14c12..4a9977dcc0 100644 --- a/trinity/buffer/reader/queue_reader.py +++ b/trinity/buffer/reader/queue_reader.py @@ -18,6 +18,7 @@ def __init__(self, config: StorageConfig): self.timeout = config.max_read_timeout self.read_batch_size = config.batch_size self.queue = QueueStorage.get_wrapper(config) + ray.get(self.queue.acquire.remote()) def read(self, batch_size: Optional[int] = None, **kwargs) -> List: try: @@ -47,3 +48,6 @@ def state_dict(self) -> Dict: def load_state_dict(self, state_dict): # Queue Not supporting state dict yet return None + + def __del__(self): + ray.get(self.queue.release.remote()) diff --git a/trinity/buffer/reader/sql_reader.py b/trinity/buffer/reader/sql_reader.py index f7572c628c..ea83ba2dc0 100644 --- a/trinity/buffer/reader/sql_reader.py +++ b/trinity/buffer/reader/sql_reader.py @@ -17,6 +17,7 @@ def __init__(self, config: StorageConfig) -> None: assert config.storage_type == StorageType.SQL.value self.wrap_in_ray = config.wrap_in_ray self.storage = SQLStorage.get_wrapper(config) + ray.get(self.storage.acquire.remote()) def read(self, batch_size: Optional[int] = None, **kwargs) -> List: if self.wrap_in_ray: @@ -40,3 +41,6 @@ def state_dict(self) -> Dict: def load_state_dict(self, state_dict): # SQL Not supporting state dict yet return None + + def __del__(self): + ray.get(self.storage.release.remote()) diff --git a/trinity/cli/launcher.py b/trinity/cli/launcher.py index d9c0d95771..28ba57ecc8 100644 --- a/trinity/cli/launcher.py +++ b/trinity/cli/launcher.py @@ -176,9 +176,9 @@ def run(config_path: str, dlc: bool = False, plugin_dir: str = None): raise RuntimeError("Ray is not running, please start it by `ray start --head`.") try: - from trinity.trainer.verl.utils import 
get_latest_hf_checkpoint_path - if config.stages: + from trinity.trainer.verl.utils import get_latest_hf_checkpoint_path + state_manager = StateManager( path=os.path.join(config.checkpoint_root_dir, config.project, config.name) ) diff --git a/trinity/common/config.py b/trinity/common/config.py index 2a43b2e235..8fe5f23740 100644 --- a/trinity/common/config.py +++ b/trinity/common/config.py @@ -428,6 +428,17 @@ class DataProcessorConfig: ) +@dataclass +class TinkerConfig: + enable: bool = False + base_model: Optional[str] = None + rank: int = 32 # lora rank + seed: Optional[int] = None + train_mlp: bool = True + train_attn: bool = True + train_unembed: bool = True + + @dataclass class ModelConfig: # source model path @@ -472,11 +483,15 @@ class ModelConfig: rope_scaling: Optional[dict] = None rope_theta: Optional[float] = None + # tinker config + tinker: TinkerConfig = field(default_factory=TinkerConfig) + @dataclass class InferenceModelConfig: # ! DO NOT SET in explorer.rollout_model, automatically set from config.model.model_path model_path: Optional[str] = None + tinker_base_model: Optional[str] = None engine_type: str = "vllm" engine_num: int = 1 @@ -1149,6 +1164,9 @@ def _check_model(self) -> None: if not model.critic_model_path: model.critic_model_path = model.model_path + if model.tinker.enable: + self._check_tinker() + # check template if model.chat_template_path is not None and model.custom_chat_template is None: try: @@ -1160,7 +1178,57 @@ def _check_model(self) -> None: ) # check max_model_len, max_prompt_tokens, max_response_tokens + self._check_model_len() + + def _check_tinker(self) -> None: + model = self.model + from trinity.algorithm import ALGORITHM_TYPE + + algorithm = ALGORITHM_TYPE.get(self.algorithm.algorithm_type) + if algorithm.use_critic: + raise ValueError("Critic model is not supported when using tinker!") + + set_if_none(model.tinker, "base_model", model.model_path) + + import tinker + + service_client = tinker.ServiceClient() + supported_models = { + item.model_name for item in service_client.get_server_capabilities().supported_models + } + if model.tinker.base_model not in supported_models: + logger.error(f"Supported models: {supported_models}") + raise ValueError(f"{model.tinker.base_model} is not supported by tinker!") + if model.tinker.base_model != model.model_path: + logger.warning( + f"The local tokenizer will use {model.model_path}, while tinker will use {model.tinker.base_model}" + ) + + if ( + self.algorithm.entropy_loss_fn != "none" + and self.algorithm.entropy_loss_fn_args.get("entropy_coef", 0.0) != 0.0 + ): + logger.warning( + "The entropy in Tinker trainer is an estimated value; " + "it is recommended to set `entropy_coef` to 0." + ) + + if self.explorer.rollout_model.engine_type != "tinker": + self.explorer.rollout_model.engine_type = "tinker" + logger.warning("Rollout model engine type is set to `tinker`.") + + if self.trainer.trainer_type != "tinker": + self.trainer.trainer_type = "tinker" + logger.warning("Trainer type is set to `tinker`.") + + if self.synchronizer.sync_method == SyncMethod.NCCL: + self.synchronizer.sync_method = SyncMethod.CHECKPOINT + logger.warning( + "Tinker do not support NCCL, `synchronizer.sync_method` is set to `checkpoint`." 
+ ) + def _check_model_len(self) -> None: + model = self.model # if all three are set, check if they are valid if ( model.max_model_len is not None @@ -1225,6 +1293,107 @@ def _check_model(self) -> None: "`enable_prompt_truncation` is set to False; please make sure the prompt is not too long and `max_model_len` is large enough, otherwise prompt length + response length may exceed `max_model_len`!" ) + def _check_explorer(self) -> None: + rollout_args = ["temperature", "top_p", "top_k", "logprobs", "repetition_penalty"] + length_args = [ + "max_model_len", + "max_prompt_tokens", + "max_response_tokens", + "min_response_tokens", + "enable_prompt_truncation", + ] + rope_args = ["rope_scaling", "rope_theta"] + model_args = rollout_args + length_args + rope_args + set_if_none(self.explorer.rollout_model, "model_path", self.model.model_path) + for args in model_args: + set_if_none(self.explorer.rollout_model, args, getattr(self.model, args)) + if ( + self.explorer.rollout_model.chat_template is None + and self.model.custom_chat_template is not None + ): + self.explorer.rollout_model.chat_template = self.model.custom_chat_template + for aux_model in self.explorer.auxiliary_models: + if not aux_model.model_path: + raise ValueError("auxiliary model's model_path is required.") + for args in model_args: + set_if_none(aux_model, args, getattr(self.model, args)) + + if self.explorer.rollout_model.engine_type == "tinker": + set_if_none( + self.explorer.rollout_model, "tinker_base_model", self.model.tinker.base_model + ) + else: + # check gpu number + rollout_gpu_num = ( + self.explorer.rollout_model.tensor_parallel_size + * self.explorer.rollout_model.engine_num + + sum( + ( + model.tensor_parallel_size * model.engine_num + for model in self.explorer.auxiliary_models + ) + ) + ) + assert self.cluster.node_num is not None + assert self.cluster.gpu_per_node is not None + total_gpu_num = self.cluster.node_num * self.cluster.gpu_per_node + if self.mode in ["explore", "bench", "serve"] and rollout_gpu_num > total_gpu_num: + raise ValueError( + f"Total GPU number ({total_gpu_num}) is less than the number of GPUs required for rollout ({rollout_gpu_num})." + ) + elif self.mode == "both" and rollout_gpu_num >= total_gpu_num: + raise ValueError( + f"Not enough GPUs for trainer in 'both' mode. Explorer requires {rollout_gpu_num} GPUs, but total available GPUs are {total_gpu_num}." + ) + + if self.explorer.over_rollout.ratio > 0.0: + if not (0.0 <= self.explorer.over_rollout.ratio < 1.0): + raise ValueError("over_rollout_ratio should be in [0.0, 1.0)") + if self.synchronizer.sync_style == SyncStyle.FIXED: + raise ValueError( + "over_rollout_ratio is not compatible with fixed sync_style, please set " + "`synchronizer.sync_style` to `dynamic_by_explorer` or `dynamic_by_trainer`." 
+ ) + + # for lora configs + if not self.model.tinker.enable and self.model.lora_configs is not None: + self.explorer.rollout_model.enable_lora = True + if len(self.model.lora_configs) > 1: + raise ValueError("Only one lora adapter is supported for now.") + if self.model.lora_configs[0].path is None: + logger.info("Creating dummy lora, since no lora_path is provided.") + lora_path = create_dummy_lora( + model_path=self.model.model_path, + checkpoint_job_dir=self.checkpoint_job_dir, + lora_rank=self.model.lora_configs[0].lora_rank, + lora_alpha=self.model.lora_configs[0].lora_alpha, + target_modules=self.model.lora_configs[0].target_modules, + ) + self.model.lora_configs[0].path = lora_path + self.explorer.rollout_model.lora_modules = [ + { + "lora_int_id": i + 1, + "lora_name": cfg.name, + "lora_path": cfg.path, + "base_model_name": cfg.base_model_name, + } + for i, cfg in enumerate(self.model.lora_configs) + ] + self.explorer.rollout_model.lora_kwargs = { + "max_loras": len(self.model.lora_configs), + "max_lora_rank": max( + ( + model_config.lora_rank + for model_config in self.model.lora_configs + if model_config.lora_rank > 0 + ), + default=0, + ), + "default_lora_path": os.path.join( + self.checkpoint_job_dir, "global_step_0", "actor", "lora_adapter" + ), # will be poped later + } + def __iter__(self): """Iterate over configs with each stage applied in order. @@ -1291,99 +1460,7 @@ def check_and_update(self) -> Config: # noqa: C901 # check explorer if self.explorer is not None: - rollout_args = ["temperature", "top_p", "top_k", "logprobs", "repetition_penalty"] - length_args = [ - "max_model_len", - "max_prompt_tokens", - "max_response_tokens", - "min_response_tokens", - "enable_prompt_truncation", - ] - rope_args = ["rope_scaling", "rope_theta"] - model_args = rollout_args + length_args + rope_args - for args in ["model_path"] + model_args: - set_if_none(self.explorer.rollout_model, args, getattr(self.model, args)) - if ( - self.explorer.rollout_model.chat_template is None - and self.model.custom_chat_template is not None - ): - self.explorer.rollout_model.chat_template = self.model.custom_chat_template - for aux_model in self.explorer.auxiliary_models: - if not aux_model.model_path: - raise ValueError("auxiliary model's model_path is required.") - for args in model_args: - set_if_none(aux_model, args, getattr(self.model, args)) - - # check gpu number - rollout_gpu_num = ( - self.explorer.rollout_model.tensor_parallel_size - * self.explorer.rollout_model.engine_num - + sum( - ( - model.tensor_parallel_size * model.engine_num - for model in self.explorer.auxiliary_models - ) - ) - ) - assert self.cluster.node_num is not None - assert self.cluster.gpu_per_node is not None - total_gpu_num = self.cluster.node_num * self.cluster.gpu_per_node - if self.mode in ["explore", "bench", "serve"] and rollout_gpu_num > total_gpu_num: - raise ValueError( - f"Total GPU number ({total_gpu_num}) is less than the number of GPUs required for rollout ({rollout_gpu_num})." - ) - elif self.mode == "both" and rollout_gpu_num >= total_gpu_num: - raise ValueError( - f"Not enough GPUs for trainer in 'both' mode. Explorer requires {rollout_gpu_num} GPUs, but total available GPUs are {total_gpu_num}." 
- ) - - if self.explorer.over_rollout.ratio > 0.0: - if not (0.0 <= self.explorer.over_rollout.ratio < 1.0): - raise ValueError("over_rollout_ratio should be in [0.0, 1.0)") - if self.synchronizer.sync_style == SyncStyle.FIXED: - raise ValueError( - "over_rollout_ratio is not compatible with fixed sync_style, please set " - "`synchronizer.sync_style` to `dynamic_by_explorer` or `dynamic_by_trainer`." - ) - - # for lora configs - if self.model.lora_configs is not None: - self.explorer.rollout_model.enable_lora = True - if len(self.model.lora_configs) > 1: - raise ValueError("Only one lora adapter is supported for now.") - if self.model.lora_configs[0].path is None: - logger.info("Creating dummy lora, since no lora_path is provided.") - lora_path = create_dummy_lora( - model_path=self.model.model_path, - checkpoint_job_dir=self.checkpoint_job_dir, - lora_rank=self.model.lora_configs[0].lora_rank, - lora_alpha=self.model.lora_configs[0].lora_alpha, - target_modules=self.model.lora_configs[0].target_modules, - ) - self.model.lora_configs[0].path = lora_path - self.explorer.rollout_model.lora_modules = [ - { - "lora_int_id": i + 1, - "lora_name": cfg.name, - "lora_path": cfg.path, - "base_model_name": cfg.base_model_name, - } - for i, cfg in enumerate(self.model.lora_configs) - ] - self.explorer.rollout_model.lora_kwargs = { - "max_loras": len(self.model.lora_configs), - "max_lora_rank": max( - ( - model_config.lora_rank - for model_config in self.model.lora_configs - if model_config.lora_rank > 0 - ), - default=0, - ), - "default_lora_path": os.path.join( - self.checkpoint_job_dir, "global_step_0", "actor", "lora_adapter" - ), # will be poped later - } + self._check_explorer() # check synchronizer self.synchronizer.ray_namespace = self.ray_namespace @@ -1391,14 +1468,17 @@ def check_and_update(self) -> Config: # noqa: C901 self.explorer.rollout_model.engine_num * self.explorer.rollout_model.tensor_parallel_size ) - if ( - self.mode in ["train", "explore", "bench", "serve"] - and self.synchronizer.sync_method == SyncMethod.NCCL - ): - self.synchronizer.sync_method = SyncMethod.CHECKPOINT - logger.warning( - f"`{self.mode}` mode does not support NCCL synchronization, set `synchronizer.sync_method` to `checkpoint`." - ) + if self.synchronizer.sync_method == SyncMethod.NCCL: + if self.mode in ["train", "explore", "bench", "serve"]: + self.synchronizer.sync_method = SyncMethod.CHECKPOINT + logger.warning( + f"`{self.mode}` mode does not support NCCL synchronization, set `synchronizer.sync_method` to `checkpoint`." + ) + if self.model.lora_configs is not None: + self.synchronizer.sync_method = SyncMethod.CHECKPOINT + logger.warning( + "LoRA is not supported with NCCL synchronization, set `synchronizer.sync_method` to `checkpoint`." + ) self._check_interval() @@ -1450,9 +1530,11 @@ def check_and_update(self) -> Config: # noqa: C901 f"Invalid trainer.save_hf_checkpoint: {self.trainer.save_hf_checkpoint}, " "must be one of 'last', 'always', or 'never'." 
) + self.trainer.trainer_config.synchronize_config(self) + elif self.trainer.trainer_type == "tinker": + self.trainer.trainer_config = None else: raise ValueError(f"Invalid trainer type: {self.trainer_type}") - self.trainer.trainer_config.synchronize_config(self) # check service if self.service.data_juicer is not None: diff --git a/trinity/common/experience.py b/trinity/common/experience.py index 0f94e0bdd5..9fa48a59ef 100644 --- a/trinity/common/experience.py +++ b/trinity/common/experience.py @@ -136,6 +136,8 @@ class Experience: # for on-policy distillation teacher_logprobs: Optional[Tensor] = None # [resp_length] + custom_fields: List[CustomField] = field(default_factory=list) + def __init__( # noqa: C901 self, *, @@ -161,6 +163,7 @@ def __init__( # noqa: C901 rejected_messages=None, multi_modal_inputs=None, teacher_logprobs=None, + custom_fields=None, ): if action_mask is not None: experience_type = "multi_turn" @@ -250,6 +253,7 @@ def __init__( # noqa: C901 self.rejected = torch.tensor(self.rejected) if self.teacher_logprobs is not None and not isinstance(self.teacher_logprobs, Tensor): self.teacher_logprobs = torch.tensor(self.teacher_logprobs, dtype=torch.float32) + self.custom_fields = custom_fields or [] def serialize(self) -> bytes: """Serialize the experience to bytes.""" diff --git a/trinity/common/models/__init__.py b/trinity/common/models/__init__.py index 190be581cb..46958faa6c 100644 --- a/trinity/common/models/__init__.py +++ b/trinity/common/models/__init__.py @@ -45,15 +45,44 @@ def create_inference_models( from ray.util.placement_group import placement_group, placement_group_table from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy - from trinity.common.models.vllm_model import vLLMRolloutModel - logger = get_logger(__name__) engine_num = config.explorer.rollout_model.engine_num tensor_parallel_size = config.explorer.rollout_model.tensor_parallel_size rollout_engines = [] if config.explorer.rollout_model.engine_type.startswith("vllm"): + from trinity.common.models.vllm_model import vLLMRolloutModel + engine_cls = vLLMRolloutModel + elif config.explorer.rollout_model.engine_type == "tinker": + from trinity.common.models.tinker_model import TinkerModel + + engine_cls = TinkerModel + namespace = ray.get_runtime_context().namespace + rollout_engines = [ + ray.remote(engine_cls) + .options( + name=f"{config.explorer.name}_rollout_model_{i}", + namespace=namespace, + ) + .remote( + config=config.explorer.rollout_model, + ) + for i in range(engine_num) + ] + auxiliary_engines = [ + ray.remote(engine_cls) + .options( + name=f"{config.explorer.name}_auxiliary_model_{i}_{j}", + namespace=namespace, + ) + .remote( + config=config.explorer.auxiliary_models[i], + ) + for i, model_config in enumerate(config.explorer.auxiliary_models) + for j in range(model_config.engine_num) + ] + return rollout_engines, auxiliary_engines else: raise ValueError(f"Unknown engine type: {config.explorer.rollout_model.engine_type}") @@ -124,7 +153,7 @@ def create_inference_models( model_config.engine_type = "vllm" model_config.bundle_indices = ",".join([str(bid) for bid in bundles_for_engine]) engines.append( - ray.remote(vLLMRolloutModel) + ray.remote(engine_cls) .options( name=f"{config.explorer.name}_auxiliary_model_{i}_{j}", num_cpus=0, diff --git a/trinity/common/models/model.py b/trinity/common/models/model.py index 3fe6f2bf37..00d27b5742 100644 --- a/trinity/common/models/model.py +++ b/trinity/common/models/model.py @@ -46,6 +46,10 @@ async def prepare(self) -> None: 
"""Prepare the model before inference.""" pass + @abstractmethod + async def sync_model(self, model_version: int) -> int: + """Sync the model with the latest model_version.""" + @abstractmethod def get_model_version(self) -> int: """Get the checkpoint version.""" @@ -105,7 +109,9 @@ def __init__( enable_history (bool): Whether to enable history recording. Default to False. enable_thinking (Optional[bool]): Whether to enable thinking mode. Default to None. Only used for Qwen3 series models. """ - assert engine_type.startswith("vllm"), "Only vLLM model is supported for now." + assert ( + engine_type.startswith("vllm") or engine_type == "tinker" + ), "Only vLLM and tinker model is supported for now." self.model = model self.api_address: str = None self.openai_client: openai.OpenAI = None @@ -205,13 +211,13 @@ async def generate_mm_async( def chat(self, messages: List[dict], **kwargs) -> List[Experience]: """Generate a list of experiences from a list of messages.""" lora_request = self.get_lora_request() - return ray.get(self.model.chat.remote(messages, lora_request, **kwargs)) + return ray.get(self.model.chat.remote(messages, lora_request=lora_request, **kwargs)) @_history_recorder async def chat_async(self, messages: List[dict], **kwargs) -> List[Experience]: """Generate a list of experiences from a list of messages in async.""" lora_request = await self.get_lora_request_async() - return await self.model.chat.remote(messages, lora_request, **kwargs) + return await self.model.chat.remote(messages, lora_request=lora_request, **kwargs) @_history_recorder def chat_mm( diff --git a/trinity/common/models/tinker_model.py b/trinity/common/models/tinker_model.py new file mode 100644 index 0000000000..951a4cf322 --- /dev/null +++ b/trinity/common/models/tinker_model.py @@ -0,0 +1,203 @@ +from typing import List, Optional, Sequence + +import ray +import tinker +import torch +import transformers +from tinker import types +from torch import Tensor + +from trinity.common.config import InferenceModelConfig +from trinity.common.experience import Experience +from trinity.common.models.model import InferenceModel +from trinity.common.models.utils import get_action_mask_method +from trinity.manager.synchronizer import Synchronizer +from trinity.utils.log import get_logger + + +class TinkerModel(InferenceModel): + def __init__( + self, + config: InferenceModelConfig, + ) -> None: + self.config = config + self.model_version = -1 + self.synchronizer = Synchronizer.get_actor(namespace=ray.get_runtime_context().namespace) + self.logger = get_logger(__name__) + self.model = None + self.tokenizer = None + self.chat_template = None + if self.config.chat_template: + self.chat_template = self.config.chat_template + self.action_mask_method = get_action_mask_method(self.chat_template) + self.enable_thinking = config.enable_thinking + + async def _initialize_tokenizer(self) -> None: + """Initialize the tokenizer.""" + self.tokenizer = transformers.AutoTokenizer.from_pretrained(self.config.model_path) + + async def _generate_internal(self, prompt: dict, **kwargs) -> types.SampleResponse: + assert self.model is not None + sampling_params = { + "max_tokens": kwargs.get("max_tokens", self.config.max_response_tokens), + "seed": kwargs.get("seed", self.config.seed), + "temperature": kwargs.get("temperature", 1.0), + "top_k": kwargs.get("top_k", -1), + "top_p": kwargs.get("top_p", 1), + } + + return await self.model.sample_async( + prompt=types.ModelInput.from_ints(prompt["prompt_token_ids"]), + sampling_params=sampling_params, + 
num_samples=kwargs.get("n", 1), + include_prompt_logprobs=kwargs.get("include_prompt_logprobs", False), + topk_prompt_logprobs=kwargs.get("topk_prompt_logprobs", self.config.logprobs), + ) + + async def generate(self, prompt: str, **kwargs) -> Sequence[Experience]: + """Generate a responses from a prompt in async.""" + if self.tokenizer is None: + await self._initialize_tokenizer() + + # Tokenize once without truncation to check if truncation is needed + token_ids = self.tokenizer( # type: ignore + prompt, + truncation=False, + return_tensors="pt", + )[ + "input_ids" + ][0].tolist() + + # Check if truncation is needed and apply it + if self.config.enable_prompt_truncation and self.config.max_prompt_tokens is not None: + if len(token_ids) > self.config.max_prompt_tokens: + self.logger.warning( + f"Prompt was truncated to {self.config.max_prompt_tokens} tokens" + ) + token_ids = token_ids[: self.config.max_prompt_tokens + 1] # leave one for response + return [ + Experience( + tokens=token_ids, + logprobs=torch.zeros(1, dtype=torch.float32), + prompt_length=len(token_ids) - 1, + prompt_text=self.tokenizer.decode(token_ids[:-1]), + response_text=self.tokenizer.decode(token_ids[-1]), + truncate_status="prompt_truncated", + reward=0.0, + ) + for _ in range(kwargs.get("n", 1)) + ] + + output = await self._generate_internal(prompt={"prompt_token_ids": token_ids}, **kwargs) + experiences = [ + Experience( + tokens=torch.tensor(token_ids + sequence.tokens, dtype=torch.int32), + logprobs=torch.tensor(sequence.logprobs, dtype=torch.float32), + prompt_length=len(token_ids), + prompt_text=self.tokenizer.decode(token_ids), + response_text=self.tokenizer.decode(sequence.tokens), + ) + for sequence in output.sequences + ] + + return experiences + + async def chat(self, messages: List[dict], **kwargs) -> Sequence[Experience]: + """Generate experiences from a list of history chat messages in async.""" + if self.tokenizer is None: + await self._initialize_tokenizer() + if self.chat_template is None: + self.chat_template = self.tokenizer.get_chat_template() + if messages[-1]["role"] == "assistant": + prompt = self.tokenizer.apply_chat_template( + messages, + tokenize=False, + continue_final_message=True, + chat_template=self.chat_template, + ) + else: + prompt = self.tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True, + chat_template=self.chat_template, + enable_thinking=self.enable_thinking, + ) + return await self.generate(prompt=prompt, **kwargs) + + async def logprobs(self, token_ids: List[int], **kwargs) -> Tensor: + """Generate logprobs for a list of tokens in async.""" + logprobs = await self.model.compute_logprobs_async(types.ModelInput(token_ids)) + return torch.tensor(logprobs[1:], dtype=torch.float32) + + async def convert_messages_to_experience( + self, + messages: List[dict], + tools: Optional[List[dict]] = None, + temperature: Optional[float] = None, + ) -> Experience: + """Convert a list of messages into an experience in async.""" + if self.tokenizer is None: + await self._initialize_tokenizer() + if self.chat_template is None: + self.chat_template = self.tokenizer.get_chat_template() + token_ids, action_mask, prompt_length = self.action_mask_method( + tokenizer=self.tokenizer, + messages=messages, + tools=tools, + chat_template=self.chat_template, + enable_thinking=self.enable_thinking, + ) # (seq_length, ), (seq_length, ) + + # Truncate tokens if they exceed the length limit + assert token_ids is not None + truncate_status = None + if 
self.config.max_model_len is not None and self.config.max_model_len > 0: + if len(token_ids) > self.config.max_model_len - 1: + truncate_status = "response_truncated" + self.logger.warning( + f"Warning: {len(token_ids)=} exceeds the length limit {(self.config.max_model_len - 1)=}" + ) + token_ids = token_ids[: self.config.max_model_len - 1] + action_mask = action_mask[: self.config.max_model_len - 1] + + temperature = temperature if temperature is not None else self.config.temperature + logprobs = await self.logprobs( + token_ids=token_ids.tolist(), temperature=temperature + ) # (seq_length - 1,) + return Experience( + tokens=token_ids, + logprobs=logprobs[prompt_length - 1 :], + prompt_length=prompt_length, + action_mask=action_mask[prompt_length:], # Exclude the prompt tokens + messages=messages, + truncate_status=truncate_status, + ) + + async def prepare(self) -> None: + """Prepare the model before inference.""" + self.service_client = tinker.ServiceClient() + self.model = await self.service_client.create_sampling_client_async( + base_model=self.config.tinker_base_model, + ) + + async def sync_model(self, model_version: int) -> int: + self.model_version = model_version + remote_sampler_path, _ = await self.synchronizer.get_model_state_dict.remote() + self.model = await self.service_client.create_sampling_client_async( + model_path=remote_sampler_path, + ) + return model_version + + def get_model_version(self) -> int: + """Get the checkpoint version.""" + return self.model_version + + def get_api_server_url(self) -> Optional[str]: + """Get the API server URL if available.""" + # TODO: tinker will support openai api later + return None + + def get_model_path(self) -> Optional[str]: + """Get the model path""" + return self.config.model_path # type: ignore [return-value] diff --git a/trinity/explorer/explorer.py b/trinity/explorer/explorer.py index 767d30a5ed..5912d951c1 100644 --- a/trinity/explorer/explorer.py +++ b/trinity/explorer/explorer.py @@ -50,6 +50,7 @@ def __init__(self, config: Config): self.last_monitored_step = self.explore_step_num self.synchronizer = Synchronizer.get_actor(config) self.config = config + self.model_type = config.explorer.rollout_model.engine_type self.models, self.auxiliary_models = create_inference_models(config) self.experience_pipeline = self._init_experience_pipeline() self.taskset = ( @@ -149,7 +150,10 @@ async def _checkpoint_weights_update(self, step_num: Optional[int] = None) -> in async def _pull_latest_weights(self): self.logger.info("Start to pull latest model weights.") - new_version = await self.synchronizer.wait_new_model_state_dict.remote(self.model_version) + new_version = await self.synchronizer.wait_new_model_state_dict.remote( + current_version=self.model_version, + no_wait=(self.config.synchronizer.sync_style != SyncStyle.FIXED), + ) if new_version > self.model_version: if self.model_version != -1: self.logger.info(f"New model weights version: {new_version}") @@ -195,7 +199,7 @@ async def prepare(self) -> None: await asyncio.gather(*run_api_ref) self.logger.info("All models are ready.") - if not self.use_nccl_sync: + if not self.use_nccl_sync and self.model_type != "tinker": if self.config.mode == "serve": # In serving mode, each engine will setup its own process group await self.setup_model_level_weight_sync_group() diff --git a/trinity/manager/synchronizer.py b/trinity/manager/synchronizer.py index 8157ad088d..c0b913812e 100644 --- a/trinity/manager/synchronizer.py +++ b/trinity/manager/synchronizer.py @@ -79,7 +79,16 @@ async def 
_check_modules(self) -> None: pass async def _find_latest_state_dict(self) -> None: - assert self.config.trainer.trainer_type == "verl" + if self.config.trainer.trainer_type == "verl": + await self._find_verl_latest_state_dict() + elif self.config.trainer.trainer_type == "tinker": + await self._find_tinker_latest_state_dict() + else: + self.logger.warning( + "Synchronizer does not support this trainer type. Please use `verl` or `tinker`." + ) + + async def _find_verl_latest_state_dict(self) -> None: default_local_dir = self.config.checkpoint_job_dir local_latest_state_dict_iteration = os.path.join( default_local_dir, "latest_state_dict_iteration.txt" @@ -112,6 +121,33 @@ async def _find_latest_state_dict(self) -> None: await self.set_model_state_dict(model_state_dict, latest_model_version) await asyncio.sleep(1) + async def _find_tinker_latest_state_dict(self) -> None: + default_local_dir = self.config.checkpoint_job_dir + local_latest_state_dict_iteration = os.path.join( + default_local_dir, "latest_state_dict_iteration.txt" + ) + while True: + if os.path.exists(local_latest_state_dict_iteration): + try: + with open(local_latest_state_dict_iteration, "r") as f: + latest_model_version = int(f.read().strip()) + except (IOError, ValueError) as e: + self.logger.warning(f"Failed to read or parse state dict iteration file: {e}") + continue + if latest_model_version > self.model_version: + self.logger.info( + f"Synchronizer has found a new remote tinker sampler path at step {latest_model_version}." + ) + remote_path_file = os.path.join( + default_local_dir, + f"global_step_{latest_model_version}", + "remote_sampler_path.txt", + ) + with open(remote_path_file, "r") as f: + remote_sampler_path = f.read().strip() + await self.set_model_state_dict(remote_sampler_path, latest_model_version) + await asyncio.sleep(1) + async def set_trainer_status(self, status: RunningStatus): """Update the status of the trainer.""" async with self._ready_condition: @@ -192,7 +228,7 @@ async def set_model_state_dict_with_step_num( return checkpoint_step_num async def set_model_state_dict( - self, model_state_dict: Union[dict, None, Tuple[str, str]], trainer_step: int + self, model_state_dict: Union[dict, None, str, Tuple[str, str]], trainer_step: int ): """ Set the new model state and update the version. diff --git a/trinity/trainer/tinker/__init__.py b/trinity/trainer/tinker/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/trinity/trainer/tinker/utils.py b/trinity/trainer/tinker/utils.py new file mode 100644 index 0000000000..544ef89768 --- /dev/null +++ b/trinity/trainer/tinker/utils.py @@ -0,0 +1,243 @@ +from logging import Logger +from typing import Any, List, Tuple + +import torch +from tinker import types + +from trinity.common.experience import Experience, split_dpo_experience_to_single_turn + + +def to_tinker_input( + experiences: List[Experience], logger: Logger +) -> Tuple[List[types.Datum], List[types.ModelInput], List[dict]]: + assert len(experiences) > 0, "No experiences provided." 
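+    # Each Experience is converted into a tinker `types.Datum`: `model_input` carries
+    # tokens[:-1], while `loss_fn_inputs` carries the shifted next-token targets
+    # (tokens[1:]) and per-token weights that zero out the prompt and keep the
+    # response via the action mask. The full token sequence and per-sample extras
+    # (rewards, logprobs, advantages, custom fields) are returned separately for
+    # reference-logprob computation and loss/metric bookkeeping.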
+    if experiences[0].experience_type == "dpo":
+        experiences = split_dpo_experience_to_single_turn(experiences)
+
+    batch = []
+    batch_input_tokens = []
+    model_inputs_list = []
+    for exp in experiences:
+        tokens = exp.tokens
+        input_tokens = tokens.long()
+        prompt_length = exp.prompt_length
+        total_length = len(tokens)  # type: ignore
+        response_length = total_length - prompt_length
+        loss_fn_inputs = {
+            "weights": torch.concat(
+                [
+                    torch.zeros(prompt_length - 1, dtype=torch.float32),
+                    exp.action_mask.float(),
+                ]
+            ),
+            "target_tokens": input_tokens.tolist()[1:],
+        }
+        model_inputs = {
+            "total_length": total_length,
+            "action_mask": exp.action_mask,
+        }
+        if exp.reward is not None or exp.token_level_reward is not None:
+            assert exp.logprobs is not None
+            if exp.token_level_reward is not None:
+                if exp.reward is not None:
+                    logger.warning(
+                        "Both exp.reward and exp.token_level_reward are provided. "
+                        "Using exp.token_level_reward."
+                    )
+                token_level_reward = exp.token_level_reward
+            else:
+                token_level_reward = torch.zeros(response_length, dtype=torch.float32)
+                token_level_reward[-1] = exp.reward
+            model_inputs.update(
+                {
+                    "token_level_scores": token_level_reward,
+                    "old_logprob": exp.logprobs,
+                }
+            )
+        for attr in ["advantages", "returns", "teacher_logprobs"]:
+            if getattr(exp, attr, None) is not None:
+                model_inputs[attr] = getattr(exp, attr)
+        # TODO: add multi-modal inputs here once tinker supports them
+        for custom_field in exp.custom_fields:
+            model_inputs[custom_field.destination_field] = torch.tensor(
+                exp.info[custom_field.source_field],
+                dtype=custom_field.data_type,
+            )
+
+        batch.append(
+            types.Datum(
+                model_input=types.ModelInput.from_ints(tokens=input_tokens.tolist()[:-1]),
+                loss_fn_inputs=loss_fn_inputs,
+            )
+        )
+        batch_input_tokens.append(types.ModelInput.from_ints(input_tokens.tolist()))
+        model_inputs_list.append(model_inputs)
+    return batch, batch_input_tokens, model_inputs_list
+
+
+def compute_data_metrics(batch: List[dict[str, torch.Tensor]]) -> dict:
+    """
+    Computes various metrics from a batch of data for PPO training.
+    Modified from `verl.trainer.ppo.metric_utils.compute_data_metrics`.
+
+    This function calculates metrics related to scores, rewards, advantages, returns, values,
+    and sequence lengths from a batch of data. It provides statistical information (mean, max, min)
+    for each metric category.
+
+    Args:
+        batch: A list of per-sample dicts containing token-level scores, rewards, advantages, etc.
+ + Returns: + A dictionary of metrics including: + - critic/score/mean, max, min: Statistics about sequence scores + - critic/rewards/mean, max, min: Statistics about sequence rewards + - critic/advantages/mean, max, min: Statistics about advantages + - critic/returns/mean, max, min: Statistics about returns + - critic/values/mean, max, min: Statistics about critic values + - critic/vf_explained_var: Explained variance of the value function + - response_length/mean, max, min, clip_ratio: Statistics about response lengths + - prompt_length/mean, max, min, clip_ratio: Statistics about prompt lengths + """ + metrics = {} + + assert len(batch) > 0, "Batch is empty" + + if "token_level_rewards" in batch[0] and "token_level_scores" in batch[0]: + sequence_score = torch.tensor([data["token_level_scores"].sum() for data in batch]) + sequence_reward = torch.tensor([data["token_level_rewards"].sum() for data in batch]) + metrics.update( + { + # score + "critic/score/mean": torch.mean(sequence_score).detach().item(), + "critic/score/max": torch.max(sequence_score).detach().item(), + "critic/score/min": torch.min(sequence_score).detach().item(), + # reward + "critic/rewards/mean": torch.mean(sequence_reward).detach().item(), + "critic/rewards/max": torch.max(sequence_reward).detach().item(), + "critic/rewards/min": torch.min(sequence_reward).detach().item(), + } + ) + + response_length = torch.tensor([len(data["action_mask"]) for data in batch]).float() + token_length = torch.tensor([data["total_length"] for data in batch]).float() + prompt_length = token_length - response_length + max_response_length = max(response_length) + max_prompt_length = max(prompt_length) + metrics.update( + { + # response length + "response_length/mean": torch.mean(response_length).detach().item(), + "response_length/max": torch.max(response_length).detach().item(), + "response_length/min": torch.min(response_length).detach().item(), + "response_length/clip_ratio": torch.mean( + torch.eq(response_length, max_response_length).float() + ) + .detach() + .item(), + # prompt length + "prompt_length/mean": torch.mean(prompt_length).detach().item(), + "prompt_length/max": torch.max(prompt_length).detach().item(), + "prompt_length/min": torch.min(prompt_length).detach().item(), + "prompt_length/clip_ratio": torch.mean( + torch.eq(prompt_length, max_prompt_length).float() + ) + .detach() + .item(), + } + ) + + if "advantages" in batch[0]: + valid_adv = torch.concat([data["advantages"] for data in batch]) + metrics.update( + { + "critic/advantages/mean": torch.mean(valid_adv).detach().item(), + "critic/advantages/max": torch.max(valid_adv).detach().item(), + "critic/advantages/min": torch.min(valid_adv).detach().item(), + } + ) + if "returns" in batch[0]: + valid_returns = torch.concat([data["returns"] for data in batch]) + metrics.update( + { + "critic/returns/mean": torch.mean(valid_returns).detach().item(), + "critic/returns/max": torch.max(valid_returns).detach().item(), + "critic/returns/min": torch.min(valid_returns).detach().item(), + } + ) + + return metrics + + +def compute_timing_metrics( + batch: List[dict[str, torch.Tensor]], timing_raw: dict[str, float] +) -> dict[str, Any]: + """ + Computes timing metrics for different processing stages in PPO training. + Modified from `verl.trainer.ppo.metric_utils.compute_timing_metrics`. 
+
+    This function calculates both raw timing metrics (in seconds) and per-token timing metrics
+    (in milliseconds) for various processing stages like generation, reference computation,
+    value computation, advantage computation, and model updates.
+
+    Args:
+        batch: A list of per-sample dicts containing action masks and total token lengths.
+        timing_raw: A dictionary mapping stage names to their execution times in seconds.
+
+    Returns:
+        A dictionary containing:
+            - timing_s/{name}: Raw timing in seconds for each stage
+            - timing_per_token_ms/{name}: Per-token timing in milliseconds for each stage
+
+    Note:
+        Different stages use different token counts for normalization:
+            - "gen" uses only response tokens
+            - Other stages ("ref", "values", "adv", "update_critic", "update_actor") use all tokens
+              (prompt + response)
+    """
+    num_overall_tokens = sum(data["total_length"] for data in batch)
+    num_response_tokens = sum(len(data["action_mask"]) for data in batch)
+
+    num_tokens_of_section = {
+        "gen": num_response_tokens,
+        **{
+            name: num_overall_tokens
+            for name in ["ref", "values", "adv", "update_critic", "update_actor"]
+        },
+    }
+
+    return {
+        **{f"timing_s/{name}": value for name, value in timing_raw.items()},
+        **{
+            f"timing_per_token_ms/{name}": timing_raw[name] * 1000 / num_tokens_of_section[name]
+            for name in set(num_tokens_of_section.keys()) & set(timing_raw.keys())
+        },
+    }
+
+
+def compute_throughout_metrics(
+    batch: List[dict[str, torch.Tensor]], timing_raw: dict[str, float]
+) -> dict[str, Any]:
+    """
+    Computes throughput metrics for PPO training.
+    Modified from `verl.trainer.ppo.metric_utils.compute_throughout_metrics`.
+
+    This function calculates performance metrics related to token processing speed,
+    including the total number of tokens processed and time per step.
+
+    Args:
+        batch: A list of per-sample dicts containing total token lengths.
+        timing_raw: A dictionary mapping stage names to their execution times in seconds.
+            Must contain a "step" key with the total step time.
+ + Returns: + A dictionary containing: + - perf/total_num_tokens: Total number of tokens processed in the batch + - perf/time_per_step: Time taken for the step in seconds + """ + total_num_tokens = sum(data["total_length"] for data in batch) + time = timing_raw["step"] + return { + "perf/total_num_tokens": total_num_tokens, + "perf/time_per_step": time, + } diff --git a/trinity/trainer/tinker_trainer.py b/trinity/trainer/tinker_trainer.py new file mode 100644 index 0000000000..3c7468923f --- /dev/null +++ b/trinity/trainer/tinker_trainer.py @@ -0,0 +1,324 @@ +import os +from typing import Dict, List + +import ray +import tinker +import torch +from tinker import types + +from trinity.algorithm import ALGORITHM_TYPE +from trinity.algorithm.advantage_fn import ADVANTAGE_FN +from trinity.algorithm.entropy_loss_fn import ENTROPY_LOSS_FN +from trinity.algorithm.entropy_loss_fn.entropy_loss_fn import DummyEntropyLossFn +from trinity.algorithm.kl_fn import KL_FN +from trinity.algorithm.policy_loss_fn import POLICY_LOSS_FN +from trinity.algorithm.utils import prefix_metrics +from trinity.common.config import Config +from trinity.common.experience import Experience +from trinity.manager.synchronizer import Synchronizer +from trinity.trainer.tinker.utils import ( + compute_data_metrics, + compute_throughout_metrics, + compute_timing_metrics, + to_tinker_input, +) +from trinity.trainer.trainer import TrainEngineWrapper +from trinity.utils.log import get_logger +from trinity.utils.timer import Timer + + +class TinkerTrainerWrapper(TrainEngineWrapper): + def __init__(self, config: Config): + self.config = config + self.logger = get_logger("tinker_trainer") + self._init_algorithm() + self.synchronizer = Synchronizer.get_actor(namespace=self.config.synchronizer.ray_namespace) + + def _init_algorithm(self): + self.algorithm = ALGORITHM_TYPE.get(self.config.algorithm.algorithm_type) + algorithm_config = self.config.algorithm + if self.algorithm.compute_advantage_in_trainer: + self.advantage_fn = ADVANTAGE_FN.get(algorithm_config.advantage_fn)( + **algorithm_config.advantage_fn_args + ) + self.kl_fn = KL_FN.get(algorithm_config.kl_penalty_fn)( + **algorithm_config.kl_penalty_fn_args + ) + # TODO + raise NotImplementedError( + "`compute_advantage_in_trainer` is not implemented yet in tinker" + ) + self.loss_agg_mode = algorithm_config.loss_agg_mode + self.policy_loss_fn = POLICY_LOSS_FN.get(algorithm_config.policy_loss_fn)( + backend="tinker", **algorithm_config.policy_loss_fn_args + ) + self.kl_loss_fn = KL_FN.get(algorithm_config.kl_loss_fn)(**algorithm_config.kl_loss_fn_args) + self.entropy_loss_fn = ENTROPY_LOSS_FN.get(algorithm_config.entropy_loss_fn)( + **algorithm_config.entropy_loss_fn_args + ) + + # EXPERIMENTAL: apply loss scale fix + self.do_fix_actor_microbatch_loss_scale = ( + self.config.trainer.fix_actor_microbatch_loss_scale + and (self.loss_agg_mode == "token-mean") + ) + + self.adam_params = types.AdamParams( + learning_rate=algorithm_config.optimizer.lr, + beta1=algorithm_config.optimizer.betas[0], + beta2=algorithm_config.optimizer.betas[1], + # eps is currently not in config + weight_decay=algorithm_config.optimizer.weight_decay, + grad_clip_norm=self.config.trainer.grad_clip, + ) + + async def prepare(self): + self.service_client = tinker.ServiceClient() + + name_prefix_list = [self.config.project, self.config.group, self.config.name] + self.tinker_checkpoint_name_prefix = "-".join( + [prefix for prefix in name_prefix_list if prefix] + ) + self.default_local_dir = 
self.config.checkpoint_job_dir + + self.local_latest_checkpointed_iteration = os.path.join( + self.config.checkpoint_job_dir, "latest_checkpointed_iteration.txt" + ) + self.local_latest_state_dict_iteration = os.path.join( + self.config.checkpoint_job_dir, "latest_state_dict_iteration.txt" + ) + + if os.path.exists(self.local_latest_checkpointed_iteration): + with open(self.local_latest_checkpointed_iteration, "r") as f: + self._train_step_num = self.latest_remote_checkpoint_step = int(f.read().strip()) + checkpoint_file_path = os.path.join( + self.default_local_dir, + f"global_step_{self._train_step_num}", + "remote_checkpoint_path.txt", + ) + with open(checkpoint_file_path, "r") as f: + self.latest_remote_checkpoint_path = f.read().strip() + self.actor_client = ( + await self.service_client.create_training_client_from_state_with_optimizer_async( + path=self.latest_remote_checkpoint_path, + ) + ) + else: + self.actor_client = await self.service_client.create_lora_training_client_async( + base_model=self.config.model.tinker.base_model, + rank=self.config.model.tinker.rank, + seed=self.config.model.tinker.seed, + train_mlp=self.config.model.tinker.train_mlp, + train_attn=self.config.model.tinker.train_attn, + train_unembed=self.config.model.tinker.train_unembed, + ) + self.latest_remote_checkpoint_step = 0 + self.latest_remote_checkpoint_path = None + self._train_step_num = 0 + + if os.path.exists(self.local_latest_state_dict_iteration): + with open(self.local_latest_state_dict_iteration, "r") as f: + self.latest_remote_sampler_step = int(f.read().strip()) + sampler_file_path = os.path.join( + self.default_local_dir, + f"global_step_{self.latest_remote_sampler_step}", + "remote_sampler_path.txt", + ) + with open(sampler_file_path, "r") as f: + self.latest_remote_sampler_path = f.read().strip() + else: + self.latest_remote_sampler_step = 0 + self.latest_remote_sampler_path = None + + self.ref_client = await self.service_client.create_sampling_client_async( + base_model=self.config.model.tinker.base_model, + ) + + @property + def train_step_num(self) -> int: + """Get the current training step number.""" + return self._train_step_num + + def _loss_func( + self, batch: list[types.Datum], logprobs: list[torch.Tensor] + ) -> tuple[torch.Tensor, dict[str, float]]: + total_loss = 0.0 + metrics = {} + assert len(self.model_inputs_list) == len( + logprobs + ), "len(self.model_inputs_list) must equal to len(logprobs)" + for model_inputs, logprob in zip(self.model_inputs_list, logprobs): + micro_batch_metrics = {} + response_mask = model_inputs["action_mask"] + logprob = logprob[-response_mask.shape[0] :] + + pg_loss, pg_loss_metrics = self.policy_loss_fn(logprob=logprob, **model_inputs) + prefix_metrics( + src_metrics=pg_loss_metrics, prefix="actor", dst_metrics=micro_batch_metrics + ) + + if self.entropy_loss_fn != DummyEntropyLossFn: + entropy = -(logprob * logprob.exp()) + else: + entropy = None + # compute entropy loss from entropy + entropy_loss, entropy_loss_metrics = self.entropy_loss_fn( # type: ignore + entropy=entropy, + **model_inputs, + loss_agg_mode=self.loss_agg_mode, + ) + prefix_metrics( + src_metrics=entropy_loss_metrics, + prefix="actor", + dst_metrics=micro_batch_metrics, + ) + + # compute kl loss + kl_loss, kl_loss_metrics = self.kl_loss_fn.calculate_kl_loss( + logprob=logprob, + ref_logprob=model_inputs["ref_logprob"], + response_mask=response_mask, + loss_agg_mode=self.loss_agg_mode, + old_logprob=model_inputs["old_logprob"], + ) + prefix_metrics( + src_metrics=kl_loss_metrics, + 
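+                # KL metrics join the policy- and entropy-loss metrics under the shared
+                # "actor/" prefix of this micro-batch's metric dict.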
prefix="actor", + dst_metrics=micro_batch_metrics, + ) + + # compute policy loss + policy_loss = pg_loss - entropy_loss + kl_loss + loss_scale = 1.0 + if not self.do_fix_actor_microbatch_loss_scale: + loss_scale /= len(logprobs) + loss = policy_loss * loss_scale + total_loss = total_loss + loss + micro_batch_metrics["actor/final_loss"] = loss.detach().item() + + # update metrics + for key, val in micro_batch_metrics.items(): + if key not in metrics: + metrics[key] = [] + metrics[key].append(val) + + avg_metrics = {k: sum(v) / len(v) for k, v in metrics.items()} + return total_loss, avg_metrics + + async def train_step(self, batch_exps: List[Experience]) -> Dict: + """Training one step. + + Args: + batch (List[Experience]): A batch of experiences to train. + + Returns: + Dict: Metrics of the training step. + """ + batch, batch_input_tokens, model_inputs_list = to_tinker_input(batch_exps, self.logger) + self.model_inputs_list = model_inputs_list + timing_raw = {} + metrics = {} + self._train_step_num += 1 + + with Timer(timing_raw, "step"): + if self.algorithm.use_reference: # ref_logprob may not be used + import asyncio + + ref_logprobs = await asyncio.gather( + *[ + self.ref_client.compute_logprobs_async(input_tokens) + for input_tokens in batch_input_tokens + ] + ) + for model_inputs, ref_logprob in zip(model_inputs_list, ref_logprobs): + response_length = model_inputs["action_mask"].shape[0] + model_inputs["ref_logprob"] = torch.tensor(ref_logprob[-response_length:]) + + if self.algorithm.compute_advantage_in_trainer: + # TODO: following is verl format, which is not compatible with tinker + raise NotImplementedError( + "`compute_advantage_in_trainer` is not implemented yet in tinker" + ) + else: + # skip token_level_scores for sft/dpo + for model_inputs in model_inputs_list: + if "token_level_scores" in model_inputs: + assert "token_level_rewards" not in model_inputs + model_inputs["token_level_rewards"] = model_inputs["token_level_scores"] + + # update actor + with Timer(timing_raw, "update_actor"): + fwdbwd_future = await self.actor_client.forward_backward_custom_async( + batch, self._loss_func + ) + optim_future = await self.actor_client.optim_step_async(self.adam_params) + fwdbwd_result = await fwdbwd_future + optim_result = await optim_future + metrics.update(fwdbwd_result.metrics) + if optim_result.metrics: + metrics.update(optim_result.metrics) + + # collect metrics + metrics.update(compute_data_metrics(batch=self.model_inputs_list)) + timing_metrics = compute_timing_metrics(batch=self.model_inputs_list, timing_raw=timing_raw) + metrics.update({k.replace("timing_s/", "time/"): v for k, v in timing_metrics.items()}) + metrics.update( + compute_throughout_metrics(batch=self.model_inputs_list, timing_raw=timing_raw) + ) + + return metrics + + def save_checkpoint(self, block_until_saved: bool = False, save_as_hf: bool = False) -> None: + """Save the checkpoint.""" + if self.train_step_num == self.latest_remote_checkpoint_step: + return + self.latest_remote_checkpoint_step = self.train_step_num + checkpoint_name = f"{self.tinker_checkpoint_name_prefix}-state-{self.train_step_num}" + self.latest_remote_checkpoint_path = ( + self.actor_client.save_state(checkpoint_name).result().path + ) + local_path = os.path.join( + self.default_local_dir, + f"global_step_{self.train_step_num}", + ) + os.makedirs(local_path, exist_ok=True) + remote_checkpoint_path = os.path.join(local_path, "remote_checkpoint_path.txt") + with open(remote_checkpoint_path, "w") as f: + 
f.write(self.latest_remote_checkpoint_path) + + with open(self.local_latest_checkpointed_iteration, "w") as f: + f.write(str(self.train_step_num)) + + def sync_weight(self) -> None: + """Sync the model weight.""" + raise NotImplementedError("Tinker trainer does not support NCCL sync") + + def upload_state_dict(self) -> None: + """Upload the state dict to Synchronizer.""" + self.save_state_dict() + ray.get( + self.synchronizer.set_model_state_dict.remote( + self.latest_remote_sampler_path, self.train_step_num + ) + ) + + def save_state_dict(self) -> None: + """Only save the model state dict for Synchronizer.""" + if self.train_step_num == self.latest_remote_sampler_step: + return + self.latest_remote_sampler_step = self.train_step_num + checkpoint_name = f"{self.tinker_checkpoint_name_prefix}-sampler-{self.train_step_num}" + self.latest_remote_sampler_path = ( + self.actor_client.save_weights_for_sampler(checkpoint_name).result().path + ) + local_path = os.path.join( + self.default_local_dir, + f"global_step_{self.train_step_num}", + ) + os.makedirs(local_path, exist_ok=True) + remote_sampler_path = os.path.join(local_path, "remote_sampler_path.txt") + with open(remote_sampler_path, "w") as f: + f.write(self.latest_remote_sampler_path) + + with open(self.local_latest_state_dict_iteration, "w") as f: + f.write(str(self.train_step_num)) diff --git a/trinity/trainer/trainer.py b/trinity/trainer/trainer.py index c42fccfb26..c4901f3a2f 100644 --- a/trinity/trainer/trainer.py +++ b/trinity/trainer/trainer.py @@ -17,7 +17,7 @@ from trinity.algorithm.sample_strategy.sample_strategy import SampleStrategy from trinity.common.config import Config from trinity.common.constants import RunningStatus, SyncMethod, SyncStyle -from trinity.common.experience import Experiences +from trinity.common.experience import Experience from trinity.manager.state_manager import StateManager from trinity.manager.synchronizer import Synchronizer from trinity.utils.log import get_logger @@ -39,7 +39,6 @@ def __init__(self, config: Config) -> None: path=config.checkpoint_job_dir, trainer_name=config.trainer.name, config=config ) trainer_state = self.state.load_trainer() - self.last_trainer_sync_step = 0 self.monitor = MONITOR.get(config.monitor.monitor_type)( project=config.project, group=self.config.group, @@ -60,15 +59,15 @@ def __init__(self, config: Config) -> None: sample_strategy_state = trainer_state.get("sample_strategy_state", {}) self.sample_strategy.load_state_dict(sample_strategy_state) self.save_interval = config.trainer.save_interval - self.last_sync_step = None + self.last_sync_step = 0 self.last_sync_time = None self.total_steps = config.trainer.total_steps or float("inf") self.save_hf_checkpoint = config.trainer.save_hf_checkpoint async def prepare(self) -> None: """Prepare the trainer.""" - self.engine.prepare() - self.last_trainer_sync_step = self.train_step_num + await self.engine.prepare() + self.last_sync_step = self.train_step_num await self.synchronizer.set_trainer_status.remote(RunningStatus.RUNNING) async def train(self) -> str: @@ -109,7 +108,7 @@ async def train(self) -> str: self.logger.info("--------------------\n> Trainer finished.\n--------------------") return self.config.trainer.name - async def train_step(self, exps: Experiences) -> Dict: + async def train_step(self, exps: List[Experience]) -> Dict: """Train one step. 
Returns: @@ -119,21 +118,21 @@ async def train_step(self, exps: Experiences) -> Dict: self.logger.info(f"Training at step {self.train_step_num + 1} started.") metrics = {} with Timer(metrics, "time/train_step"): - train_metrics = self.engine.train_step(exps) + train_metrics = await self.engine.train_step(exps) self.logger.info(f"Training at step {self.train_step_num} finished.") metrics.update(train_metrics) return metrics - async def _sample_data(self) -> Tuple[Experiences, Dict, List[Dict]]: + async def _sample_data(self) -> Tuple[List[Experience], Dict, List[Dict]]: """Sample a batch of experiences. Returns: - Experiences: A batch of experiences. + List[Experience]: A batch of experiences. Dict: Metrics of the sampling step. List[Dict]: A list of representative samples for logging. """ batch, metrics, repr_samples = await self.sample_strategy.sample(self.train_step_num + 1) - metrics["sample/task_count"] = len(set(eid.tid for eid in batch.eids)) + metrics["sample/task_count"] = len(set(exp.eid.tid for exp in batch)) return batch, metrics, repr_samples async def need_sync(self) -> bool: @@ -145,14 +144,17 @@ async def need_sync(self) -> bool: ) else: if self.config.synchronizer.sync_style == SyncStyle.DYNAMIC_BY_TRAINER: - delta = self.train_step_num - self.last_trainer_sync_step + delta = self.train_step_num - self.last_sync_step if delta >= self.config.synchronizer.sync_interval: await self.synchronizer.set_trainer_status.remote(RunningStatus.REQUIRE_SYNC) explorer_status_counts = await self.synchronizer.get_explorer_status_counts.remote() if self.config.synchronizer.sync_method == SyncMethod.NCCL: return explorer_status_counts[RunningStatus.WAITING_SYNC] > 0 else: # memory & checkpoint - return explorer_status_counts[RunningStatus.REQUIRE_SYNC] > 0 + return ( + self.last_sync_step != self.train_step_num + and explorer_status_counts[RunningStatus.REQUIRE_SYNC] > 0 + ) def need_save(self) -> bool: """Whether to save the checkpoint.""" @@ -173,7 +175,6 @@ async def sync_weight(self) -> Dict: self.logger.error("Trainer sync_weights failed.") else: self.engine.sync_weight() - self.last_trainer_sync_step = self.train_step_num elif self.config.synchronizer.sync_method == SyncMethod.CHECKPOINT: self.engine.save_state_dict() elif self.config.synchronizer.sync_method == SyncMethod.MEMORY: @@ -229,7 +230,7 @@ class TrainEngineWrapper(ABC): """A wrapper class to wrap various training engines.""" @abstractmethod - def prepare(self) -> None: + async def prepare(self) -> None: """Do some preparation before training started.""" @property @@ -238,11 +239,11 @@ def train_step_num(self) -> int: """Get the current training step number.""" @abstractmethod - def train_step(self, batch: Experiences) -> Dict: + async def train_step(self, batch_exps: List[Experience]) -> Dict: """Training one step. Args: - batch (Experiences): A batch of experiences to train. + batch_exps (List[Experience]): A batch of experiences to train. Returns: Dict: Metrics of the training step. 
@@ -271,5 +272,9 @@ def get_trainer_wrapper(config: Config) -> TrainEngineWrapper: from trinity.trainer.verl_trainer import VerlPPOTrainerWrapper return VerlPPOTrainerWrapper(config) + elif config.trainer.trainer_type == "tinker": + from trinity.trainer.tinker_trainer import TinkerTrainerWrapper + + return TinkerTrainerWrapper(config) else: raise NotImplementedError diff --git a/trinity/trainer/verl/utils.py b/trinity/trainer/verl/utils.py index 922182e643..640ee2b748 100644 --- a/trinity/trainer/verl/utils.py +++ b/trinity/trainer/verl/utils.py @@ -2,6 +2,7 @@ import os from logging import Logger +from typing import List import numpy as np import torch @@ -10,75 +11,98 @@ from verl.utils.checkpoint.checkpoint_manager import find_latest_ckpt_path from trinity.common.config import Config -from trinity.common.experience import Experiences - - -def to_data_proto(experiences: Experiences, logger: Logger) -> DataProto: # noqa: C901 - """Convert Experiences to verl DataProto.""" - attention_mask = experiences.attention_masks +from trinity.common.experience import ( + Experience, + gather_action_masks, + gather_attention_masks, + gather_response_attrs, + gather_token_ids, + split_dpo_experience_to_single_turn, +) + + +def to_data_proto( + experiences: List[Experience], pad_token_id: int, logger: Logger +) -> DataProto: # noqa: C901 + """Convert List[Experience] to verl DataProto.""" + assert len(experiences) > 0, "No experiences provided." + if experiences[0].experience_type == "dpo": + experiences = split_dpo_experience_to_single_turn(experiences) + max_prompt_length = max([exp.prompt_length for exp in experiences]) + max_response_length = max([len(exp.tokens) - exp.prompt_length for exp in experiences]) # type: ignore + + attention_mask = gather_attention_masks( + experiences, max_prompt_length, max_response_length + ).long() cumsum = torch.cumsum(attention_mask, dim=-1) position_ids = torch.clip(cumsum - 1, 0, None).long() + tokens = gather_token_ids( + experiences, max_prompt_length, max_response_length, pad_token_id + ).long() batch_dict = { - "uid": np.array([eid.tid for eid in experiences.eids]), - "unique_ids": np.array([eid.uid for eid in experiences.eids]), + "uid": np.array([exp.eid.tid for exp in experiences]), + "unique_ids": np.array([exp.eid.uid for exp in experiences]), "position_ids": position_ids, - "input_ids": experiences.tokens.long(), - "responses": experiences.tokens[:, experiences.prompt_length :].long(), - "attention_mask": attention_mask.long(), - "response_mask": ( - experiences.action_masks.long() - if hasattr(experiences, "action_masks") and experiences.action_masks is not None - else attention_mask[:, experiences.prompt_length :].long() - ), + "input_ids": tokens, + "responses": tokens[:, max_prompt_length:], + "attention_mask": attention_mask, + "response_mask": gather_action_masks(experiences, max_response_length), } - if experiences.rewards is not None or experiences.token_level_rewards is not None: - assert experiences.logprobs is not None - if experiences.token_level_rewards is not None: - if experiences.rewards is not None: + have_reward = all(exp.reward is not None for exp in experiences) + have_token_level_reward = all(exp.token_level_reward is not None for exp in experiences) + if have_reward or have_token_level_reward: + assert all(exp.logprobs is not None for exp in experiences), "No logprobs provided." + if have_token_level_reward: + if have_reward: logger.warning( "Both experiences.rewards and experiences.token_level_rewards are provided. 
" "Using experiences.token_level_rewards." ) - token_level_rewards = experiences.token_level_rewards + token_level_rewards = gather_response_attrs( + experiences, "token_level_reward", max_response_length + ) else: - token_level_rewards = torch.zeros(attention_mask.shape, dtype=experiences.rewards.dtype) + token_level_rewards = torch.zeros(attention_mask.shape, dtype=torch.float32) eos_mask_idx = cumsum.argmax(dim=-1) - token_level_rewards[ - torch.arange(experiences.batch_size), eos_mask_idx - ] = experiences.rewards - token_level_rewards = token_level_rewards[:, experiences.prompt_length :] + token_level_rewards[torch.arange(len(experiences)), eos_mask_idx] = torch.tensor( + [exp.reward for exp in experiences] + ) + token_level_rewards = token_level_rewards[:, max_prompt_length:] batch_dict.update( { "token_level_scores": token_level_rewards, - "old_log_probs": experiences.logprobs, # type: ignore + "old_log_probs": gather_response_attrs( + experiences, "logprobs", max_response_length + ), } ) - if experiences.advantages is not None: - batch_dict["advantages"] = experiences.advantages - if experiences.returns is not None: - batch_dict["returns"] = experiences.returns - if experiences.teacher_logprobs is not None: - batch_dict["teacher_log_probs"] = experiences.teacher_logprobs - - if experiences.multi_modal_inputs is not None: - batch_size = len(batch_dict["unique_ids"]) + + for attr in ["advantages", "returns", "teacher_logprobs"]: + if all(getattr(exp, attr, None) is not None for exp in experiences): + batch_dict[attr] = gather_response_attrs(experiences, attr, max_response_length) + + if all(exp.multi_modal_inputs is not None for exp in experiences): + keys = experiences[0].multi_modal_inputs.keys() batch_dict["multi_modal_inputs"] = np.array( - [ - {k: v[i] for k, v in experiences.multi_modal_inputs.items()} - for i in range(batch_size) - ], + [{key: exp.multi_modal_inputs[key] for key in keys} for exp in experiences], # type: ignore dtype=object, ) - if experiences.custom_fields: - for field in experiences.custom_fields: - if hasattr(experiences, field): - batch_dict[field] = getattr(experiences, field) + custom_fields_set = set(tuple(exp.custom_fields) for exp in experiences) + if len(custom_fields_set) == 1: + custom_fields = list(custom_fields_set)[0] + for custom_field in custom_fields: + batch_dict[custom_field.destination_field] = torch.tensor( + [exp.info[custom_field.source_field] for exp in experiences], + dtype=custom_field.data_type, + ) + else: + raise ValueError("Custom fields are not consistent across experiences.") return DataProto.from_single_dict(batch_dict) -def compute_data_metrics(batch: DataProto, use_critic: bool = False) -> dict: +def compute_data_metrics(batch: DataProto) -> dict: """ Computes various metrics from a batch of data for PPO training. Modified from verl.trainer.ppo.metric_utils.compute_data_metrics @@ -89,7 +113,6 @@ def compute_data_metrics(batch: DataProto, use_critic: bool = False) -> dict: Args: batch: A DataProto object containing batch data with token-level scores, rewards, advantages, etc. - use_critic: Whether to include critic-specific metrics. Defaults to True. 
Returns: A dictionary of metrics including: @@ -97,8 +120,8 @@ def compute_data_metrics(batch: DataProto, use_critic: bool = False) -> dict: - critic/rewards/mean, max, min: Statistics about sequence rewards - critic/advantages/mean, max, min: Statistics about advantages - critic/returns/mean, max, min: Statistics about returns - - critic/values/mean, max, min: Statistics about critic values (if use_critic=True) - - critic/vf_explained_var: Explained variance of the value function (if use_critic=True) + - critic/values/mean, max, min: Statistics about critic values + - critic/vf_explained_var: Explained variance of the value function - response_length/mean, max, min, clip_ratio: Statistics about response lengths - prompt_length/mean, max, min, clip_ratio: Statistics about prompt lengths """ diff --git a/trinity/trainer/verl_trainer.py b/trinity/trainer/verl_trainer.py index 16c0525327..9c52af3d66 100644 --- a/trinity/trainer/verl_trainer.py +++ b/trinity/trainer/verl_trainer.py @@ -7,7 +7,7 @@ import os import sys from collections import defaultdict -from typing import Dict, Optional +from typing import Dict, List, Optional import ray import torch @@ -30,11 +30,11 @@ from verl.utils.fs import copy_local_path_from_hdfs from verl.utils.metric import reduce_metrics -from trinity.algorithm import ADVANTAGE_FN, ALGORITHM_TYPE, KL_FN, SAMPLE_STRATEGY +from trinity.algorithm import ADVANTAGE_FN, ALGORITHM_TYPE, KL_FN from trinity.algorithm.utils import prefix_metrics from trinity.common.config import Config from trinity.common.constants import SaveStrategy -from trinity.common.experience import Experiences +from trinity.common.experience import Experience from trinity.trainer.trainer import TrainEngineWrapper from trinity.trainer.verl.utils import compute_data_metrics, to_data_proto from trinity.utils.log import get_logger @@ -187,6 +187,7 @@ def __init__( global_config: Config, ): self.logger = get_logger(__name__, in_ray_actor=True) + self.pad_token_id = global_config.buffer.pad_token_id train_config = global_config.trainer config = OmegaConf.structured(train_config.trainer_config) # download the checkpoint from hdfs @@ -261,11 +262,6 @@ def __init__( self.kl_fn = KL_FN.get(self.algorithm_config.kl_penalty_fn)( **self.algorithm_config.kl_penalty_fn_args ) - self.sample_strategy = SAMPLE_STRATEGY.get(global_config.algorithm.sample_strategy)( - buffer_config=global_config.buffer, - trainer_type=global_config.trainer.trainer_type, - **global_config.algorithm.sample_strategy_args, - ) super().__init__( config, tokenizer, @@ -379,7 +375,7 @@ def init_workers(self): def train_step_num(self) -> int: return self.global_steps - def prepare(self): + async def prepare(self): self.actor_rollout_wg.setup_weight_sync_group() self.actor_rollout_wg.set_algorithm(self.algorithm_config) @@ -411,8 +407,8 @@ def save_state_dict(self): # checkpoint sync def upload_state_dict(self): # state dict sync self.actor_rollout_wg.upload_state_dict(self.global_steps) - def train_step(self, batch: Experiences) -> Dict: # noqa C901 - batch = to_data_proto(batch, self.logger) + async def train_step(self, batch_exps: List[Experience]) -> Dict: # noqa C901 + batch = to_data_proto(batch_exps, self.pad_token_id, self.logger) # type: ignore batch = self.post_process_batch(batch) metrics = {} self.global_steps += 1 @@ -476,7 +472,7 @@ def train_step(self, batch: Experiences) -> Dict: # noqa C901 metrics.update(actor_output_metrics) # collect metrics - metrics.update(compute_data_metrics(batch=batch, use_critic=self.use_critic)) + 
metrics.update(compute_data_metrics(batch=batch)) timing_metrics = compute_timing_metrics(batch=batch, timing_raw=timing_raw) metrics.update({k.replace("timing_s/", "time/"): v for k, v in timing_metrics.items()}) n_gpus = self.resource_pool_manager.get_n_gpus()
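# The tinker trainer and the explorer-side sampler never exchange weights directly:
# TinkerTrainerWrapper.save_state_dict publishes a remote sampler path under
# checkpoint_job_dir, and Synchronizer._find_tinker_latest_state_dict polls for it.
# Below is a minimal sketch of that file protocol, assuming the same file names as above;
# `publish_sampler_path` and `read_latest_sampler_path` are illustrative helpers only.
import os
from typing import Optional, Tuple


def publish_sampler_path(checkpoint_job_dir: str, step: int, remote_sampler_path: str) -> None:
    """Record the sampler path for `step`, then advance the latest-step marker."""
    step_dir = os.path.join(checkpoint_job_dir, f"global_step_{step}")
    os.makedirs(step_dir, exist_ok=True)
    with open(os.path.join(step_dir, "remote_sampler_path.txt"), "w") as f:
        f.write(remote_sampler_path)
    # The marker is written last so a reader never sees a step without its sampler path.
    with open(os.path.join(checkpoint_job_dir, "latest_state_dict_iteration.txt"), "w") as f:
        f.write(str(step))


def read_latest_sampler_path(checkpoint_job_dir: str) -> Optional[Tuple[int, str]]:
    """Return (step, remote_sampler_path) for the newest published sampler, if any."""
    marker = os.path.join(checkpoint_job_dir, "latest_state_dict_iteration.txt")
    if not os.path.exists(marker):
        return None
    with open(marker, "r") as f:
        step = int(f.read().strip())
    path_file = os.path.join(checkpoint_job_dir, f"global_step_{step}", "remote_sampler_path.txt")
    with open(path_file, "r") as f:
        return step, f.read().strip()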