Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion benchmark/config/alfworld-template.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ model:
model_path: placeholder
max_prompt_tokens: 10240
max_response_tokens: 4096
enable_thinking: false
cluster:
node_num: 1
gpu_per_node: 8
Expand Down Expand Up @@ -65,7 +66,6 @@ explorer:
gpu_memory_utilization: 0.7
dtype: bfloat16
seed: 42
enable_thinking: false
enable_openai_api: false
auxiliary_models: []
bench_on_latest_checkpoint: true
Expand Down
2 changes: 1 addition & 1 deletion benchmark/config/countdown-template.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ model:
model_path: Qwen/Qwen2.5-1.5B-Instruct
max_prompt_tokens: 256
max_response_tokens: 1024
enable_thinking: false
cluster:
node_num: 1
gpu_per_node: 8
Expand Down Expand Up @@ -60,7 +61,6 @@ explorer:
gpu_memory_utilization: 0.9
dtype: bfloat16
seed: 42
enable_thinking: false
enable_openai_api: false
auxiliary_models: []
eval_interval: 1000
Expand Down
2 changes: 1 addition & 1 deletion benchmark/config/frozenlake-template.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ model:
model_path: Qwen/Qwen2.5-3B-Instruct
max_prompt_tokens: 4096
max_response_tokens: 10240
enable_thinking: false
cluster:
node_num: 1
gpu_per_node: 8
Expand Down Expand Up @@ -70,7 +71,6 @@ explorer:
gpu_memory_utilization: 0.85
dtype: bfloat16
seed: 42
enable_thinking: false
enable_openai_api: false
auxiliary_models: []
bench_on_latest_checkpoint: true
Expand Down
2 changes: 1 addition & 1 deletion benchmark/config/gsm8k-template.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ model:
model_path: Qwen/Qwen2.5-1.5B-Instruct
max_prompt_tokens: 256
max_response_tokens: 1024
enable_thinking: false
cluster:
node_num: 1
gpu_per_node: 8
Expand Down Expand Up @@ -65,7 +66,6 @@ explorer:
gpu_memory_utilization: 0.9
dtype: bfloat16
seed: 42
enable_thinking: false
enable_openai_api: false
auxiliary_models: []
eval_interval: 1000
Expand Down
1 change: 0 additions & 1 deletion docs/sphinx_doc/source/tutorial/example_react.md
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,6 @@ explorer:
enable_auto_tool_choice: true # Allow model to generate `tool_calls`
tool_call_parser: hermes # Specify parser for tool call outputs
reasoning_parser: deepseek_r1 # Helps parse model reasoning process
enable_thinking: true # Enable thinking (mainly for Qwen3 series models)
```

#### Multi-Step Training Algorithm
Expand Down
1 change: 0 additions & 1 deletion docs/sphinx_doc/source_zh/tutorial/example_react.md
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,6 @@ explorer:
enable_auto_tool_choice: true # 允许模型生成 `tool_calls`
tool_call_parser: hermes # 指定格式化解析工具调用输出的解析器
reasoning_parser: deepseek_r1 # 有助于解析模型的思维过程
enable_thinking: true # 是否启用模型深度思考能力(主要针对 Qwen3 系列模型)
```

#### 多步训练算法
Expand Down
2 changes: 1 addition & 1 deletion examples/agentscope_frozenlake/frozenlake_agent.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ model:
max_response_tokens: 2048
max_model_len: 25600
temperature: 1.0
enable_thinking: true
cluster:
node_num: 1
gpu_per_node: 8
Expand Down Expand Up @@ -61,7 +62,6 @@ explorer:
enable_auto_tool_choice: true
tool_call_parser: hermes
# reasoning_parser: deepseek_r1 # if you use Qwen3 series, uncomment this line
enable_thinking: true
dtype: bfloat16
seed: 42
gpu_memory_utilization: 0.85
Expand Down
2 changes: 1 addition & 1 deletion examples/agentscope_react/gsm8k.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ model:
model_path: ${oc.env:TRINITY_MODEL_PATH,Qwen/Qwen3-8B}
max_response_tokens: 16384
max_model_len: 24576
enable_thinking: true
cluster:
node_num: 1
gpu_per_node: 8
Expand Down Expand Up @@ -49,7 +50,6 @@ explorer:
enable_auto_tool_choice: true
tool_call_parser: hermes
reasoning_parser: deepseek_r1
enable_thinking: true
dtype: bfloat16
seed: 42
synchronizer:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ model:
model_path: ${oc.env:TRINITY_MODEL_PATH,Qwen/Qwen3-8B}
max_response_tokens: 16384
max_model_len: 24576
enable_thinking: true
cluster:
node_num: 1
gpu_per_node: 8
Expand Down Expand Up @@ -52,7 +53,6 @@ explorer:
enable_auto_tool_choice: true
tool_call_parser: hermes
reasoning_parser: deepseek_r1
enable_thinking: true
synchronizer:
sync_style: explorer_driven
sync_method: 'nccl'
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ model:
model_path: ${oc.env:TRINITY_MODEL_PATH,Qwen/Qwen3-4B}
max_response_tokens: 16384
max_model_len: 24576
enable_thinking: true
cluster:
node_num: 1
gpu_per_node: 8
Expand Down Expand Up @@ -52,7 +53,6 @@ explorer:
enable_auto_tool_choice: true
tool_call_parser: hermes
reasoning_parser: deepseek_r1
enable_thinking: true
synchronizer:
sync_style: explorer_driven
sync_method: 'nccl'
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ model:
model_path: ${oc.env:TRINITY_MODEL_PATH,Qwen/Qwen2.5-7B-Instruct}
max_response_tokens: 4096
max_model_len: 20480
enable_thinking: true
cluster:
node_num: 1
gpu_per_node: 8
Expand Down Expand Up @@ -60,7 +61,6 @@ explorer:
max_repeat_times_per_runner: 1
max_timeout: 3600
rollout_model:
enable_thinking: true
enable_history: true
enable_openai_api: true
enable_auto_tool_choice: true
Expand Down
2 changes: 1 addition & 1 deletion examples/grpo_email_search/email_search.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ model:
model_path: ${oc.env:TRINITY_MODEL_PATH,Qwen/Qwen3-4B-Instruct-2507}
max_response_tokens: 4096
max_model_len: 20480
enable_thinking: true
cluster:
node_num: 1
gpu_per_node: 8
Expand Down Expand Up @@ -81,7 +82,6 @@ explorer:
max_repeat_times_per_runner: 1
max_timeout: 3600
rollout_model:
enable_thinking: true
enable_history: true
enable_openai_api: true
enable_auto_tool_choice: true
Expand Down
2 changes: 1 addition & 1 deletion examples/learn_to_ask/train.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ model:
max_response_tokens: 1024
temperature: 1.0
logprobs: 0
enable_thinking: false
cluster:
node_num: 1
gpu_per_node: 8
Expand Down Expand Up @@ -67,7 +68,6 @@ explorer:
gpu_memory_utilization: 0.9
dtype: bfloat16
seed: 42
enable_thinking: false
enable_openai_api: true
auxiliary_models:
- model_path: ${oc.env:TRINITY_AUX_MODEL_PATH,Qwen/Qwen2.5-32B-Instruct}
Expand Down
12 changes: 6 additions & 6 deletions tests/buffer/task_scheduler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from trinity.buffer.reader import READER
from trinity.buffer.reader.file_reader import TaskFileReader
from trinity.buffer.task_scheduler import TasksetScheduler, get_taskset_scheduler
from trinity.common.config import FormatConfig, TaskSelectorConfig, TasksetConfig
from trinity.common.config import DataSelectorConfig, FormatConfig, TasksetConfig
from trinity.common.workflows.workflow import Task


Expand Down Expand Up @@ -250,7 +250,7 @@ def _check_batch_tasks(self, batch_tasks: List[Task], indices: List[Dict[str, in
]
)
async def test_task_scheduler(
self, buffer_config_kwargs, task_selector_kwargs, batch_tasks_orders
self, buffer_config_kwargs, data_selector_kwargs, batch_tasks_orders
) -> None:
config = get_template_config()
config.mode = "explore"
Expand All @@ -276,8 +276,8 @@ async def test_task_scheduler(
),
default_workflow_type="math_workflow",
default_reward_fn_type="math_reward",
task_selector=TaskSelectorConfig(
**task_selector_kwargs,
data_selector=DataSelectorConfig(
**data_selector_kwargs,
),
),
TasksetConfig(
Expand All @@ -298,8 +298,8 @@ async def test_task_scheduler(
),
default_workflow_type="math_workflow",
default_reward_fn_type="math_reward",
task_selector=TaskSelectorConfig(
**task_selector_kwargs,
data_selector=DataSelectorConfig(
**data_selector_kwargs,
),
),
]
Expand Down
1 change: 1 addition & 0 deletions tests/common/models/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# -*- coding: utf-8 -*-
82 changes: 82 additions & 0 deletions tests/common/models/utils_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
# -*- coding: utf-8 -*-
"""Tests for model utils tokenization helpers."""

import unittest

import torch
import transformers

from tests.tools import get_model_path
from trinity.common.models.utils import tokenize_and_mask_messages_default


class TestTokenizeAndMaskMessagesDefault(unittest.TestCase):
    """Unit tests for ``tokenize_and_mask_messages_default``.

    Verifies that the helper produces an assistant-token mask and prompt
    length for normal conversations, rejects empty message lists, and
    behaves sensibly when no assistant turn (or an assistant-first turn)
    is present.

    NOTE(review): the expected token counts (24/14, 13, 20/9) are tied to
    the specific tokenizer returned by ``get_model_path()`` — confirm if
    the test model changes.
    """

    def setUp(self):
        # One tokenizer per test, loaded from the shared test model path.
        self.tokenizer = transformers.AutoTokenizer.from_pretrained(get_model_path())
        return super().setUp()

    def test_normal_conversation_data(self):
        # A two-round user/assistant exchange with reasoning content.
        conversation = [
            {"role": "user", "content": "Hello"},
            {"role": "assistant", "content": "Hi", "reasoning_content": "greeting"},
            {"role": "user", "content": "How are you?"},
            {"role": "assistant", "content": "I am fine.", "reasoning_content": "answering"},
        ]

        _tokens, mask, prompt_len = tokenize_and_mask_messages_default(
            tokenizer=self.tokenizer,
            messages=conversation,
            enable_thinking=True,
        )

        # First 24 tokens are prompt (masked out), last 14 are assistant.
        expected_mask = torch.tensor([0] * 24 + [1] * 14, dtype=torch.int)
        self.assertTrue(torch.equal(mask, expected_mask))
        self.assertEqual(prompt_len, 24)

    def test_messages_empty(self):
        # An empty conversation is invalid input.
        with self.assertRaises(ValueError):
            tokenize_and_mask_messages_default(tokenizer=self.tokenizer, messages=[])

    def test_no_assistant_messages(self):
        # Only user turns: nothing should be marked as assistant output.
        conversation = [
            {"role": "user", "content": "Hello"},
            {"role": "user", "content": "Still user"},
        ]

        _tokens, mask, prompt_len = tokenize_and_mask_messages_default(
            tokenizer=self.tokenizer,
            messages=conversation,
            enable_thinking=True,
        )

        self.assertTrue(torch.equal(mask, torch.zeros(13, dtype=torch.int)))
        self.assertEqual(prompt_len, 0)

    def test_first_message_is_assistant(self):
        # Conversation beginning with an assistant turn; only the final
        # assistant reply's tokens should be unmasked.
        conversation = [
            {"role": "assistant", "content": "I start first.", "reasoning": "starting"},
            {"role": "user", "content": "Then me."},
            {"role": "assistant", "content": "Final reply.", "reasoning": "ending"},
        ]

        _tokens, mask, prompt_len = tokenize_and_mask_messages_default(
            tokenizer=self.tokenizer,
            messages=conversation,
            enable_thinking=True,
        )

        # 20 prompt tokens masked out, 9 assistant tokens unmasked.
        expected_mask = torch.tensor([0] * 20 + [1] * 9, dtype=torch.int)
        self.assertTrue(torch.equal(mask, expected_mask))
        self.assertEqual(prompt_len, 20)


# Allow running this test module directly, e.g. `python utils_test.py`.
if __name__ == "__main__":
    unittest.main()
4 changes: 2 additions & 2 deletions tests/trainer/trainer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,10 @@
AlgorithmConfig,
BufferConfig,
Config,
DataSelectorConfig,
ExperienceBufferConfig,
ExplorerInput,
StageConfig,
TaskSelectorConfig,
TrainerInput,
)
from trinity.common.constants import (
Expand Down Expand Up @@ -92,7 +92,7 @@ def test_trainer(self):
}
self.config.model.rope_theta = 10000
self.config.buffer.explorer_input.taskset = get_unittest_dataset_config("countdown")
self.config.buffer.explorer_input.taskset.task_selector = TaskSelectorConfig(
self.config.buffer.explorer_input.taskset.data_selector = DataSelectorConfig(
selector_type="shuffle", seed=42
)
eval_tasksets = self.config.buffer.explorer_input.eval_tasksets
Expand Down
Loading
Loading