diff --git a/benchmark/config/alfworld-template.yaml b/benchmark/config/alfworld-template.yaml
index be62a1bc0b..a1553e976d 100644
--- a/benchmark/config/alfworld-template.yaml
+++ b/benchmark/config/alfworld-template.yaml
@@ -22,6 +22,7 @@ model:
   model_path: placeholder
   max_prompt_tokens: 10240
   max_response_tokens: 4096
+  enable_thinking: false
 cluster:
   node_num: 1
   gpu_per_node: 8
@@ -65,7 +66,6 @@ explorer:
     gpu_memory_utilization: 0.7
     dtype: bfloat16
     seed: 42
-    enable_thinking: false
     enable_openai_api: false
     auxiliary_models: []
   bench_on_latest_checkpoint: true
diff --git a/benchmark/config/countdown-template.yaml b/benchmark/config/countdown-template.yaml
index 1003fb1f77..1eae416038 100644
--- a/benchmark/config/countdown-template.yaml
+++ b/benchmark/config/countdown-template.yaml
@@ -16,6 +16,7 @@ model:
   model_path: Qwen/Qwen2.5-1.5B-Instruct
   max_prompt_tokens: 256
   max_response_tokens: 1024
+  enable_thinking: false
 cluster:
   node_num: 1
   gpu_per_node: 8
@@ -60,7 +61,6 @@ explorer:
     gpu_memory_utilization: 0.9
     dtype: bfloat16
     seed: 42
-    enable_thinking: false
     enable_openai_api: false
     auxiliary_models: []
   eval_interval: 1000
diff --git a/benchmark/config/frozenlake-template.yaml b/benchmark/config/frozenlake-template.yaml
index 7208b19c76..933bd82105 100644
--- a/benchmark/config/frozenlake-template.yaml
+++ b/benchmark/config/frozenlake-template.yaml
@@ -24,6 +24,7 @@ model:
   model_path: Qwen/Qwen2.5-3B-Instruct
   max_prompt_tokens: 4096
   max_response_tokens: 10240
+  enable_thinking: false
 cluster:
   node_num: 1
   gpu_per_node: 8
@@ -70,7 +71,6 @@ explorer:
     gpu_memory_utilization: 0.85
     dtype: bfloat16
     seed: 42
-    enable_thinking: false
     enable_openai_api: false
     auxiliary_models: []
   bench_on_latest_checkpoint: true
diff --git a/benchmark/config/gsm8k-template.yaml b/benchmark/config/gsm8k-template.yaml
index 92a1c8b5ff..f4b49dc850 100644
--- a/benchmark/config/gsm8k-template.yaml
+++ b/benchmark/config/gsm8k-template.yaml
@@ -21,6 +21,7 @@ model:
   model_path: Qwen/Qwen2.5-1.5B-Instruct
   max_prompt_tokens: 256
   max_response_tokens: 1024
+  enable_thinking: false
 cluster:
   node_num: 1
   gpu_per_node: 8
@@ -65,7 +66,6 @@ explorer:
     gpu_memory_utilization: 0.9
     dtype: bfloat16
     seed: 42
-    enable_thinking: false
     enable_openai_api: false
     auxiliary_models: []
   eval_interval: 1000
diff --git a/docs/sphinx_doc/source/tutorial/example_react.md b/docs/sphinx_doc/source/tutorial/example_react.md
index 1d8862a242..8a9ea82659 100644
--- a/docs/sphinx_doc/source/tutorial/example_react.md
+++ b/docs/sphinx_doc/source/tutorial/example_react.md
@@ -137,7 +137,6 @@ explorer:
     enable_auto_tool_choice: true  # Allow model to generate `tool_calls`
     tool_call_parser: hermes  # Specify parser for tool call outputs
     reasoning_parser: deepseek_r1  # Helps parse model reasoning process
-    enable_thinking: true  # Enable thinking (mainly for Qwen3 series models)
 ```
 
 #### Multi-Step Training Algorithm
diff --git a/docs/sphinx_doc/source_zh/tutorial/example_react.md b/docs/sphinx_doc/source_zh/tutorial/example_react.md
index 2b1ee7c8a9..061623cd6c 100644
--- a/docs/sphinx_doc/source_zh/tutorial/example_react.md
+++ b/docs/sphinx_doc/source_zh/tutorial/example_react.md
@@ -144,7 +144,6 @@ explorer:
     enable_auto_tool_choice: true  # 允许模型生成 `tool_calls`
     tool_call_parser: hermes  # 指定格式化解析工具调用输出的解析器
     reasoning_parser: deepseek_r1  # 有助于解析模型的思维过程
-    enable_thinking: true  # 是否启用模型深度思考能力（主要针对 Qwen3 系列模型）
 ```
 
 #### 多步训练算法
diff --git a/examples/agentscope_frozenlake/frozenlake_agent.yaml b/examples/agentscope_frozenlake/frozenlake_agent.yaml
index eabd6244e4..8b73bbbad1 100644
--- a/examples/agentscope_frozenlake/frozenlake_agent.yaml
+++ b/examples/agentscope_frozenlake/frozenlake_agent.yaml
@@ -18,6 +18,7 @@ model:
   max_response_tokens: 2048
   max_model_len: 25600
   temperature: 1.0
+  enable_thinking: true
 cluster:
   node_num: 1
   gpu_per_node: 8
@@ -61,7 +62,6 @@ explorer:
     enable_auto_tool_choice: true
    tool_call_parser: hermes
     # reasoning_parser: deepseek_r1 # if you use Qwen3 series, uncomment this line
-    enable_thinking: true
     dtype: bfloat16
     seed: 42
     gpu_memory_utilization: 0.85
diff --git a/examples/agentscope_react/gsm8k.yaml b/examples/agentscope_react/gsm8k.yaml
index d7d510091c..7af12e2724 100644
--- a/examples/agentscope_react/gsm8k.yaml
+++ b/examples/agentscope_react/gsm8k.yaml
@@ -10,6 +10,7 @@ model:
   model_path: ${oc.env:TRINITY_MODEL_PATH,Qwen/Qwen3-8B}
   max_response_tokens: 16384
   max_model_len: 24576
+  enable_thinking: true
 cluster:
   node_num: 1
   gpu_per_node: 8
@@ -49,7 +50,6 @@ explorer:
     enable_auto_tool_choice: true
     tool_call_parser: hermes
     reasoning_parser: deepseek_r1
-    enable_thinking: true
     dtype: bfloat16
     seed: 42
 synchronizer:
diff --git a/examples/agentscope_tool_react/agentscopev0_tool_react_dapo.yaml b/examples/agentscope_tool_react/agentscopev0_tool_react_dapo.yaml
index d9fa085a96..56141ecd16 100644
--- a/examples/agentscope_tool_react/agentscopev0_tool_react_dapo.yaml
+++ b/examples/agentscope_tool_react/agentscopev0_tool_react_dapo.yaml
@@ -10,6 +10,7 @@ model:
   model_path: ${oc.env:TRINITY_MODEL_PATH,Qwen/Qwen3-8B}
   max_response_tokens: 16384
   max_model_len: 24576
+  enable_thinking: true
 cluster:
   node_num: 1
   gpu_per_node: 8
@@ -52,7 +53,6 @@ explorer:
     enable_auto_tool_choice: true
     tool_call_parser: hermes
     reasoning_parser: deepseek_r1
-    enable_thinking: true
 synchronizer:
   sync_style: explorer_driven
   sync_method: 'nccl'
diff --git a/examples/agentscope_tool_react/agentscopev0_tool_react_gsm8k.yaml b/examples/agentscope_tool_react/agentscopev0_tool_react_gsm8k.yaml
index 3f0b9bbff7..01d2a459fc 100644
--- a/examples/agentscope_tool_react/agentscopev0_tool_react_gsm8k.yaml
+++ b/examples/agentscope_tool_react/agentscopev0_tool_react_gsm8k.yaml
@@ -10,6 +10,7 @@ model:
   model_path: ${oc.env:TRINITY_MODEL_PATH,Qwen/Qwen3-4B}
   max_response_tokens: 16384
   max_model_len: 24576
+  enable_thinking: true
 cluster:
   node_num: 1
   gpu_per_node: 8
@@ -52,7 +53,6 @@ explorer:
     enable_auto_tool_choice: true
     tool_call_parser: hermes
     reasoning_parser: deepseek_r1
-    enable_thinking: true
 synchronizer:
   sync_style: explorer_driven
   sync_method: 'nccl'
diff --git a/examples/agentscope_websearch/agentscopev1_websearch_agent.yaml b/examples/agentscope_websearch/agentscopev1_websearch_agent.yaml
index 6545f3356a..a319195e8e 100644
--- a/examples/agentscope_websearch/agentscopev1_websearch_agent.yaml
+++ b/examples/agentscope_websearch/agentscopev1_websearch_agent.yaml
@@ -10,6 +10,7 @@ model:
   model_path: ${oc.env:TRINITY_MODEL_PATH,Qwen/Qwen2.5-7B-Instruct}
   max_response_tokens: 4096
   max_model_len: 20480
+  enable_thinking: true
 cluster:
   node_num: 1
   gpu_per_node: 8
@@ -60,7 +61,6 @@ explorer:
   max_repeat_times_per_runner: 1
   max_timeout: 3600
   rollout_model:
-    enable_thinking: true
     enable_history: true
     enable_openai_api: true
     enable_auto_tool_choice: true
diff --git a/examples/grpo_email_search/email_search.yaml b/examples/grpo_email_search/email_search.yaml
index e389e5bb3b..2ce081b65b 100644
--- a/examples/grpo_email_search/email_search.yaml
+++ b/examples/grpo_email_search/email_search.yaml
@@ -24,6 +24,7 @@ model:
   model_path: ${oc.env:TRINITY_MODEL_PATH,Qwen/Qwen3-4B-Instruct-2507}
   max_response_tokens: 4096
   max_model_len: 20480
+  enable_thinking: true
 cluster:
   node_num: 1
   gpu_per_node: 8
@@ -81,7 +82,6 @@ explorer:
   max_repeat_times_per_runner: 1
   max_timeout: 3600
   rollout_model:
-    enable_thinking: true
     enable_history: true
     enable_openai_api: true
     enable_auto_tool_choice: true
diff --git a/examples/learn_to_ask/train.yaml b/examples/learn_to_ask/train.yaml
index 1e4d3972df..243a679703 100644
--- a/examples/learn_to_ask/train.yaml
+++ b/examples/learn_to_ask/train.yaml
@@ -21,6 +21,7 @@ model:
   max_response_tokens: 1024
   temperature: 1.0
   logprobs: 0
+  enable_thinking: false
 cluster:
   node_num: 1
   gpu_per_node: 8
@@ -67,7 +68,6 @@ explorer:
     gpu_memory_utilization: 0.9
     dtype: bfloat16
     seed: 42
-    enable_thinking: false
     enable_openai_api: true
     auxiliary_models:
       - model_path: ${oc.env:TRINITY_AUX_MODEL_PATH,Qwen/Qwen2.5-32B-Instruct}
diff --git a/tests/buffer/task_scheduler_test.py b/tests/buffer/task_scheduler_test.py
index 6792290d68..62468431ad 100644
--- a/tests/buffer/task_scheduler_test.py
+++ b/tests/buffer/task_scheduler_test.py
@@ -9,7 +9,7 @@
 from trinity.buffer.reader import READER
 from trinity.buffer.reader.file_reader import TaskFileReader
 from trinity.buffer.task_scheduler import TasksetScheduler, get_taskset_scheduler
-from trinity.common.config import FormatConfig, TaskSelectorConfig, TasksetConfig
+from trinity.common.config import DataSelectorConfig, FormatConfig, TasksetConfig
 from trinity.common.workflows.workflow import Task
 
 
@@ -250,7 +250,7 @@ def _check_batch_tasks(self, batch_tasks: List[Task], indices: List[Dict[str, in
         ]
     )
     async def test_task_scheduler(
-        self, buffer_config_kwargs, task_selector_kwargs, batch_tasks_orders
+        self, buffer_config_kwargs, data_selector_kwargs, batch_tasks_orders
     ) -> None:
        config = get_template_config()
        config.mode = "explore"
@@ -276,8 +276,8 @@ async def test_task_scheduler(
                 ),
                 default_workflow_type="math_workflow",
                 default_reward_fn_type="math_reward",
-                task_selector=TaskSelectorConfig(
-                    **task_selector_kwargs,
+                data_selector=DataSelectorConfig(
+                    **data_selector_kwargs,
                 ),
             ),
             TasksetConfig(
@@ -298,8 +298,8 @@ async def test_task_scheduler(
                 ),
                 default_workflow_type="math_workflow",
                 default_reward_fn_type="math_reward",
-                task_selector=TaskSelectorConfig(
-                    **task_selector_kwargs,
+                data_selector=DataSelectorConfig(
+                    **data_selector_kwargs,
                 ),
             ),
         ]
diff --git a/tests/common/models/__init__.py b/tests/common/models/__init__.py
new file mode 100644
index 0000000000..40a96afc6f
--- /dev/null
+++ b/tests/common/models/__init__.py
@@ -0,0 +1 @@
+# -*- coding: utf-8 -*-
diff --git a/tests/common/models/utils_test.py b/tests/common/models/utils_test.py
new file mode 100644
index 0000000000..d374553a1b
--- /dev/null
+++ b/tests/common/models/utils_test.py
@@ -0,0 +1,82 @@
+# -*- coding: utf-8 -*-
+"""Tests for model utils tokenization helpers."""
+
+import unittest
+
+import torch
+import transformers
+
+from tests.tools import get_model_path
+from trinity.common.models.utils import tokenize_and_mask_messages_default
+
+
+class TestTokenizeAndMaskMessagesDefault(unittest.TestCase):
+    def setUp(self):
+        self.tokenizer = transformers.AutoTokenizer.from_pretrained(get_model_path())
+        return super().setUp()
+
+    def test_normal_conversation_data(self):
+        messages = [
+            {"role": "user", "content": "Hello"},
+            {"role": "assistant", "content": "Hi", "reasoning_content": "greeting"},
+            {"role": "user", "content": "How are you?"},
+            {"role": "assistant", "content": "I am fine.", "reasoning_content": "answering"},
+        ]
+
+        token_ids, assistant_mask, prompt_length = tokenize_and_mask_messages_default(
+            tokenizer=self.tokenizer,
+            messages=messages,
+            enable_thinking=True,
+        )
+
+        self.assertTrue(
+            torch.equal(
+                assistant_mask,
+                torch.tensor([0] * 24 + [1] * 14, dtype=torch.int),
+            )
+        )
+        self.assertEqual(prompt_length, 24)
+
+    def test_messages_empty(self):
+        with self.assertRaises(ValueError):
+            tokenize_and_mask_messages_default(tokenizer=self.tokenizer, messages=[])
+
+    def test_no_assistant_messages(self):
+        messages = [
+            {"role": "user", "content": "Hello"},
+            {"role": "user", "content": "Still user"},
+        ]
+
+        token_ids, assistant_mask, prompt_length = tokenize_and_mask_messages_default(
+            tokenizer=self.tokenizer,
+            messages=messages,
+            enable_thinking=True,
+        )
+
+        self.assertTrue(torch.equal(assistant_mask, torch.zeros(13, dtype=torch.int)))
+        self.assertEqual(prompt_length, 0)
+
+    def test_first_message_is_assistant(self):
+        messages = [
+            {"role": "assistant", "content": "I start first.", "reasoning": "starting"},
+            {"role": "user", "content": "Then me."},
+            {"role": "assistant", "content": "Final reply.", "reasoning": "ending"},
+        ]
+
+        token_ids, assistant_mask, prompt_length = tokenize_and_mask_messages_default(
+            tokenizer=self.tokenizer,
+            messages=messages,
+            enable_thinking=True,
+        )
+
+        self.assertTrue(
+            torch.equal(
+                assistant_mask,
+                torch.tensor([0] * 20 + [1] * 9, dtype=torch.int),
+            )
+        )
+        self.assertEqual(prompt_length, 20)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/tests/trainer/trainer_test.py b/tests/trainer/trainer_test.py
index f5370bb79e..d6e77a363b 100644
--- a/tests/trainer/trainer_test.py
+++ b/tests/trainer/trainer_test.py
@@ -36,10 +36,10 @@
     AlgorithmConfig,
     BufferConfig,
     Config,
+    DataSelectorConfig,
     ExperienceBufferConfig,
     ExplorerInput,
     StageConfig,
-    TaskSelectorConfig,
     TrainerInput,
 )
 from trinity.common.constants import (
@@ -92,7 +92,7 @@ def test_trainer(self):
         }
         self.config.model.rope_theta = 10000
         self.config.buffer.explorer_input.taskset = get_unittest_dataset_config("countdown")
-        self.config.buffer.explorer_input.taskset.task_selector = TaskSelectorConfig(
+        self.config.buffer.explorer_input.taskset.data_selector = DataSelectorConfig(
             selector_type="shuffle", seed=42
         )
         eval_tasksets = self.config.buffer.explorer_input.eval_tasksets
diff --git a/trinity/buffer/reader/file_reader.py b/trinity/buffer/reader/file_reader.py
index 6b668d8994..ce8173213d 100644
--- a/trinity/buffer/reader/file_reader.py
+++ b/trinity/buffer/reader/file_reader.py
@@ -35,6 +35,8 @@ def __init__(
     ):
         self.dataset = dataset
         self.dataset_size = len(dataset)
+        if self.dataset_size == 0:
+            raise ValueError(f"Dataset [{name}] is empty and cannot be read in batches.")
         self.name = name
         self.current_batch_size = None
         self.drop_last = drop_last
@@ -42,7 +44,7 @@ def __init__(
         self.current_offset = offset
 
         # convert epochs/steps to sample number
-        if total_steps:
+        if total_steps is not None:
             self.total_samples = default_batch_size * total_steps
         else:
             self.total_samples = self.dataset_size * total_epochs
@@ -94,11 +96,63 @@ def select_batch(self, indices: List[int]) -> List:
 
 
 class BaseFileReader(BufferReader):
     async def read_async(self, batch_size: Optional[int] = None, **kwargs):
         try:
-            return self.read(batch_size)
+            return self.read(batch_size, **kwargs)
         except StopIteration as e:
             raise StopAsyncIteration from e
 
 
+class _DatasetFileReader(BaseFileReader):
+    def __init__(self, config: StorageConfig):
+        self.config = config
+        self.name = config.name
+        self.read_batch_size = config.batch_size
+        self.formatter, self.dataset = self._init_formatter_and_dataset(config)
+        self._init_selector(config)
+
+    def _init_formatter_and_dataset(self, config: StorageConfig):
+        raise NotImplementedError
+
+    def _init_selector(self, config: StorageConfig):
+        if config.data_selector is not None:
+            from trinity.buffer.selector import SELECTORS
+            from trinity.buffer.selector.selector import BaseSelector
+
+            self.selector: BaseSelector = SELECTORS.get(config.data_selector.selector_type)(
+                self.dataset, config.data_selector
+            )
+        else:
+            self.selector = None
+
+    def _read_samples(self, batch_size: int) -> Tuple[List, List]:
+        if self.selector is not None:
+            indices = self.selector.get_indices(batch_size)
+            samples = self.dataset.select_batch(indices)
+            return samples, indices
+        return self.dataset.read_batch(batch_size)
+
+    def state_dict(self):
+        if self.selector is not None:
+            return self.selector.state_dict()
+        return {"current_index": self.dataset.current_offset}
+
+    def load_state_dict(self, state_dict):
+        if self.selector is not None:
+            self.selector.load_state_dict(state_dict)
+        else:
+            self.dataset.current_offset = state_dict["current_index"]
+
+    def __len__(self):
+        return self.dataset.dataset_size
+
+    def _convert_batch(self, samples: List, indices: List) -> List:
+        raise NotImplementedError
+
+    def read(self, batch_size: Optional[int] = None, **kwargs) -> List:
+        batch_size = batch_size or self.read_batch_size
+        samples, indices = self._read_samples(batch_size)
+        return self._convert_batch(samples, indices)
+
+
 class FileReader(BaseFileReader):
     """Provide a unified interface for Experience and Task file readers."""
@@ -109,7 +163,7 @@ def __init__(self, config: StorageConfig):
             self.reader = TaskFileReader(config)
 
     def read(self, batch_size: Optional[int] = None, **kwargs) -> List:
-        return self.reader.read(batch_size)
+        return self.reader.read(batch_size, **kwargs)
 
     def state_dict(self):
         return self.reader.state_dict()
@@ -125,15 +179,17 @@ def __len__(self):
         return self.reader.__len__()
 
 
-class ExperienceFileReader(BaseFileReader):
+class ExperienceFileReader(_DatasetFileReader):
     """Reader for SFT / DPO file data."""
 
     def __init__(self, config: StorageConfig):
-        self.formatter = FORMATTER.get(config.schema_type)(
+        super().__init__(config)
+
+    def _init_formatter_and_dataset(self, config: StorageConfig):
+        formatter = FORMATTER.get(config.schema_type)(
             tokenizer_path=config.tokenizer_path, format_config=config.format
         )
-        self.read_batch_size = config.batch_size
-        self.dataset = _HFBatchReader(
+        dataset = _HFBatchReader(
             load_dataset(config.path, name=config.subset_name, split=config.split),
             name=config.name,
             default_batch_size=self.read_batch_size,
@@ -142,82 +198,41 @@ def __init__(self, config: StorageConfig):
             total_steps=config.total_steps,
             enable_progress_bar=config.enable_progress_bar,
         )
-        self.selector = None
+        return formatter, dataset
 
-    def read(self, batch_size: Optional[int] = None, **kwargs) -> List:
-        samples, _ = self.dataset.read_batch(batch_size or self.read_batch_size)
+    def _convert_batch(self, samples: List, indices: List) -> List:
         exp_list = []
         for sample in samples:
             experience = self.formatter.format(sample)
             exp_list.append(experience)
         return exp_list
 
-    def state_dict(self):
-        return {"current_index": self.dataset.current_offset}
-
-    def load_state_dict(self, state_dict):
-        self.dataset.current_offset = state_dict["current_index"]
-
-    def __len__(self):
-        return self.dataset.dataset_size
-
 
-class TaskFileReader(BaseFileReader):
+class TaskFileReader(_DatasetFileReader):
     """A Reader for task file data."""
 
     def __init__(self, config: StorageConfig):
-        self.config = config
-        self.name = config.name
-        self.epoch = 0
         datasets.disable_caching()
-        self.read_batch_size = config.batch_size
-        self.dataset = _HFBatchReader(
-            load_dataset(self.config.path, name=self.config.subset_name, split=self.config.split),
-            name=self.config.name,
+        super().__init__(config)
+
+    def _init_formatter_and_dataset(self, config):
+        formatter = FORMATTER.get("task")(config)
+        dataset = _HFBatchReader(
+            load_dataset(config.path, name=config.subset_name, split=config.split),
+            name=config.name,
             default_batch_size=self.read_batch_size,
-            total_epochs=self.config.total_epochs if not self.config.is_eval else 1,
-            offset=self.config.index,
-            drop_last=not self.config.is_eval,
-            total_steps=self.config.total_steps if not self.config.is_eval else None,
-            enable_progress_bar=self.config.enable_progress_bar,
+            total_epochs=config.total_epochs if not config.is_eval else 1,
+            offset=config.index,
+            drop_last=not config.is_eval,
+            total_steps=config.total_steps if not config.is_eval else None,
+            enable_progress_bar=config.enable_progress_bar,
         )
-        self.formatter = FORMATTER.get("task")(config)
-        if self.config.task_selector is not None:
-            from trinity.buffer.selector import SELECTORS
-            from trinity.buffer.selector.selector import BaseSelector
-
-            self.selector: BaseSelector = SELECTORS.get(self.config.task_selector.selector_type)(
-                self.dataset, self.config.task_selector
-            )
-        else:
-            self.selector = None
+        return formatter, dataset
 
-    def _get_tasks(self, samples: List, indices: List) -> List:
+    def _convert_batch(self, samples: List, indices: List) -> List:
         tasks = []
         for sample, index in zip(samples, indices):
             task = self.formatter.format(sample)
             task.index["index"] = int(index)
             tasks.append(task)
         return tasks
-
-    def read(self, batch_size: Optional[int] = None, **kwargs) -> List:
-        batch_size = batch_size or self.read_batch_size
-        if self.selector is not None:
-            indices = self.selector.get_indices(batch_size)
-            samples = self.dataset.select_batch(indices)
-        else:
-            samples, indices = self.dataset.read_batch(batch_size)
-        return self._get_tasks(samples, indices)
-
-    def state_dict(self):
-        if self.selector is not None:
-            return self.selector.state_dict()
-        return {"current_index": self.dataset.current_offset}
-
-    def load_state_dict(self, state_dict):
-        if self.selector is not None:
-            self.selector.load_state_dict(state_dict)
-        self.dataset.current_offset = state_dict["current_index"]
-
-    def __len__(self):
-        return self.dataset.dataset_size
diff --git a/trinity/buffer/schema/formatter.py b/trinity/buffer/schema/formatter.py
index ecaf87747b..4e65ee68a6 100644
--- a/trinity/buffer/schema/formatter.py
+++ b/trinity/buffer/schema/formatter.py
@@ -100,6 +100,7 @@ def __init__(self, tokenizer_path: str, format_config: FormatConfig):
         self.tools_key = format_config.tools_key
         self.image_key = format_config.image_key
         self.video_key = format_config.video_key
+        self.enable_thinking = format_config.enable_thinking
         if self.image_key is not None or self.video_key is not None:
             assert (
                 self.enable_concatenated_multi_turn is False
@@ -150,7 +151,7 @@ def _messages_to_experience(
             )
         if isinstance(tools, str):
             try:
-                tools = json.loads(tools)
+                tools = json.loads(tools) if tools else None
             except json.JSONDecodeError:
                 self.logger.error(
                     "[SFT Data Error] Failed to decode 'tools' JSON. Please check your data format."
@@ -162,6 +163,7 @@ def _messages_to_experience(
             messages=messages,
             tools=tools,
             chat_template=self.chat_template,
+            enable_thinking=self.enable_thinking,
         )
         return Experience(
             tokens=token_ids,
diff --git a/trinity/buffer/selector/selector.py b/trinity/buffer/selector/selector.py
index a67036dd45..316a87a6d8 100644
--- a/trinity/buffer/selector/selector.py
+++ b/trinity/buffer/selector/selector.py
@@ -6,7 +6,7 @@
 
 from trinity.buffer.reader.file_reader import _HFBatchReader
 from trinity.buffer.selector.difficulty_estimator import InterpolationBetaPREstimator
-from trinity.common.config import TaskSelectorConfig
+from trinity.common.config import DataSelectorConfig
 from trinity.utils.annotations import Experimental
 from trinity.utils.log import get_logger
 
@@ -27,7 +27,7 @@ class BaseSelector:
     - state_dict / load_state_dict: for saving/loading selector state (checkpointing)
     """
 
-    def __init__(self, data_source: _HFBatchReader, config: TaskSelectorConfig):
+    def __init__(self, data_source: _HFBatchReader, config: DataSelectorConfig):
         self.data_source = data_source
         self.config = config
 
@@ -82,7 +82,7 @@ class SequentialSelector(BaseSelector):
     Example: [0,1,2,...,B-1], then [B,B+1,...,2B-1], etc.
     """
 
-    def __init__(self, data_source: _HFBatchReader, config: TaskSelectorConfig):
+    def __init__(self, data_source: _HFBatchReader, config: DataSelectorConfig):
         super().__init__(data_source, config)
         self.dataset_size = data_source.dataset_size
         self.current_index = 0
@@ -117,7 +117,7 @@ class ShuffleSelector(BaseSelector):
     Mimics standard PyTorch DataLoader with shuffle=True.
     """
 
-    def __init__(self, data_source: _HFBatchReader, config: TaskSelectorConfig):
+    def __init__(self, data_source: _HFBatchReader, config: DataSelectorConfig):
         super().__init__(data_source, config)
         self.dataset_size = data_source.dataset_size  # Total available samples
         self.current_index = 0  # Progress tracker
@@ -172,7 +172,7 @@ class RandomSelector(BaseSelector):
     Can result in repeated samples within an epoch.
     Suitable for online or stochastic training regimes.
     """
 
-    def __init__(self, data_source: _HFBatchReader, config: TaskSelectorConfig):
+    def __init__(self, data_source: _HFBatchReader, config: DataSelectorConfig):
         super().__init__(data_source, config)
         self.dataset_size = data_source.dataset_size
         self.current_index = 0
@@ -214,7 +214,7 @@ class OfflineEasy2HardSelector(BaseSelector):
     (e.g., via teacher model confidence, length, BLEU score, etc.).
     """
 
-    def __init__(self, data_source, config: TaskSelectorConfig):
+    def __init__(self, data_source, config: DataSelectorConfig):
         super().__init__(data_source, config)
         self.logger = get_logger("offline_easy2hard_selector")
 
@@ -297,7 +297,7 @@ class DifficultyBasedSelector(BaseSelector):
     Supports both greedy selection (`tau=0`) and stochastic sampling (`tau>0`).
""" - def __init__(self, data_source, config: TaskSelectorConfig) -> None: + def __init__(self, data_source, config: DataSelectorConfig) -> None: super().__init__(data_source, config) self.logger = get_logger("difficulty_based_selector") diff --git a/trinity/buffer/task_scheduler.py b/trinity/buffer/task_scheduler.py index c511886145..c758f92f18 100644 --- a/trinity/buffer/task_scheduler.py +++ b/trinity/buffer/task_scheduler.py @@ -24,7 +24,7 @@ def get_taskset_scheduler(explorer_state: Dict, config: Config) -> "TasksetSched TasksetSchedulerBase: The taskset scheduler instance """ taskset_configs = config.buffer.explorer_input.tasksets - if len(taskset_configs) == 1 and taskset_configs[0].task_selector.selector_type == "sequential": + if len(taskset_configs) == 1 and taskset_configs[0].data_selector.selector_type == "sequential": return SimpleTasksetScheduler(explorer_state, config) else: return TasksetScheduler(explorer_state, config) @@ -70,7 +70,7 @@ def __init__(self, explorer_state: Dict, config: Config): ) taskset_config = deepcopy(self.config.buffer.explorer_input.tasksets[0]) taskset_config.index = index - taskset_config.task_selector = None # disable selection + taskset_config.data_selector = None # disable selection self.reader = get_buffer_reader(taskset_config) async def read_async(self) -> List: diff --git a/trinity/common/config.py b/trinity/common/config.py index 20a159fb10..bfd3ed977a 100644 --- a/trinity/common/config.py +++ b/trinity/common/config.py @@ -70,6 +70,7 @@ class FormatConfig: # for sft / dpo, if None, use model.custom_chat_template chat_template: Optional[str] = None + enable_thinking: Optional[bool] = None @dataclass @@ -115,7 +116,7 @@ class LoRAConfig: @Experimental @dataclass -class TaskSelectorConfig: +class DataSelectorConfig: """Data selector config.""" selector_type: Optional[str] = "sequential" @@ -190,7 +191,7 @@ class StorageConfig: rollout_args: GenerationConfig = field(default_factory=GenerationConfig) workflow_args: dict = field(default_factory=dict) reward_fn_args: dict = field(default_factory=dict) - task_selector: TaskSelectorConfig = field(default_factory=TaskSelectorConfig) + data_selector: DataSelectorConfig = field(default_factory=DataSelectorConfig) # enable progress bar (tqdm) for _HFBatchReader enable_progress_bar: Optional[bool] = False @@ -231,7 +232,7 @@ class TasksetConfig: rollout_args: GenerationConfig = field(default_factory=GenerationConfig) workflow_args: dict = field(default_factory=dict) reward_fn_args: dict = field(default_factory=dict) - task_selector: TaskSelectorConfig = field(default_factory=TaskSelectorConfig) + data_selector: DataSelectorConfig = field(default_factory=DataSelectorConfig) # used for StorageType.FILE split: str = "train" @@ -264,7 +265,7 @@ def to_storage_config(self) -> StorageConfig: name=self.name, storage_type=self.storage_type, path=self.path, - task_selector=self.task_selector, + data_selector=self.data_selector, repeat_times=self.repeat_times, split=self.split, subset_name=self.subset_name, @@ -309,6 +310,7 @@ class ExperienceBufferConfig: subset_name: Optional[str] = None format: FormatConfig = field(default_factory=FormatConfig) enable_progress_bar: Optional[bool] = False + data_selector: DataSelectorConfig = field(default_factory=DataSelectorConfig) # ! 
DO NOT SET, automatically set schema_type: Optional[str] = None @@ -330,6 +332,7 @@ def to_storage_config(self) -> StorageConfig: name=self.name, storage_type=self.storage_type, path=self.path, + data_selector=self.data_selector, capacity=self.capacity, max_read_timeout=self.max_read_timeout, replay_buffer=self.replay_buffer, @@ -473,6 +476,7 @@ class ModelConfig: enable_prompt_truncation: bool = True # repetition penalty for response generation repetition_penalty: float = 1.0 + enable_thinking: Optional[bool] = None # lora config lora_configs: Optional[List[LoRAConfig]] = None diff --git a/trinity/common/config_validator.py b/trinity/common/config_validator.py index 9248e78431..4e2ebe981d 100644 --- a/trinity/common/config_validator.py +++ b/trinity/common/config_validator.py @@ -591,7 +591,7 @@ def validate(self, config: Config) -> None: "enable_prompt_truncation", ] rope_args = ["rope_scaling", "rope_theta"] - model_args = rollout_args + length_args + rope_args + model_args = rollout_args + length_args + rope_args + ["enable_thinking"] # rollout model for args in model_args + ["model_path", "trust_remote_code"]: @@ -767,7 +767,7 @@ def validate(self, config: Config) -> None: """ assert config.synchronizer.sync_interval > 0, "`sync_interval` must be positive." - if config.mode != "bench" and config.algorithm.algorithm_type != "dpo": # TODO + if config.mode != "bench" and config.algorithm.algorithm_type not in {"dpo", "sft"}: # TODO # check eval_interval if config.explorer.eval_interval % config.synchronizer.sync_interval != 0: config.explorer.eval_interval = ( @@ -930,10 +930,11 @@ def _fill_taskset_config(taskset: TasksetConfig, index: int, is_eval: bool = Fal _fill_taskset_config(taskset, i) # check if selector is supported - selector = SELECTORS.get(taskset.task_selector.selector_type) + selector = SELECTORS.get(taskset.data_selector.selector_type) if selector is None: raise ValueError( - f"Selector {taskset.task_selector.selector_type} is not supported." + f"Selector `{taskset.data_selector.selector_type}` " + f"in {taskset.name} is not supported." ) for idx, taskset in enumerate(explorer_input.eval_tasksets): @@ -998,6 +999,7 @@ def _check_trainer_input(self, config: Config): experience_buffer.tokenizer_path = config.model.model_path set_if_none(experience_buffer, "ray_namespace", config.ray_namespace) set_if_none(experience_buffer.format, "chat_template", config.model.custom_chat_template) + set_if_none(experience_buffer.format, "enable_thinking", config.model.enable_thinking) for aux_name, aux_buffer in trainer_input.auxiliary_buffers.items(): aux_buffer.batch_size = config.buffer.train_batch_size aux_buffer.tokenizer_path = config.model.model_path @@ -1008,6 +1010,16 @@ def _check_trainer_input(self, config: Config): f"please set it to the path of the auxiliary buffer." ) + from trinity.buffer.selector import SELECTORS + + # check if selector is supported + selector = SELECTORS.get(experience_buffer.data_selector.selector_type) + if selector is None: + raise ValueError( + f"Selector {experience_buffer.data_selector.selector_type} " + "in `experience_buffer` is not supported." + ) + if config.mode == "train": assert ( experience_buffer is not None diff --git a/trinity/common/models/utils.py b/trinity/common/models/utils.py index 12822f125c..e6276d5f7f 100644 --- a/trinity/common/models/utils.py +++ b/trinity/common/models/utils.py @@ -73,51 +73,69 @@ def tokenize_and_mask_messages_default( If the assumption is not met, the function may produce incorrect results. 
     Please check the chat template before using this method.
     """
+    if len(messages) == 0:
+        raise ValueError("Messages should not be empty")
 
-    tokens = tokenizer.apply_chat_template(
-        messages,
+    common_kwargs = dict(
         tools=tools,
         chat_template=chat_template,
-        add_generation_prompt=False,
         enable_thinking=enable_thinking,
         padding=False,
         truncation=True,
-        return_tensors="pt",
         add_special_tokens=False,
-        return_dict=False,
+        tokenize=True,
+        return_dict=True,
     )
-    assistant_token_mask = torch.zeros(tokens.shape[1], dtype=torch.int)
-    for idx, message in enumerate(messages):
+
+    generation_messages = []
+    response_messages = []
+
+    start_idx = 0
+    if "<think>" in (chat_template or tokenizer.chat_template):
+        # find last user message for thinking template
+        for idx in range(len(messages) - 1, -1, -1):
+            message = messages[idx]
+            if message["role"] == "user":
+                start_idx = idx
+                break
+
+    for idx in range(start_idx, len(messages)):
+        message = messages[idx]
         if message["role"] == "assistant":
-            prompt_token_ids = tokenizer.apply_chat_template(
-                messages[:idx],
-                tools=tools,
-                chat_template=chat_template,
-                add_generation_prompt=True,
-                enable_thinking=enable_thinking,
-                padding=False,
-                truncation=True,
-                return_tensors="pt",
-                add_special_tokens=False,
-                return_dict=False,
-            )
-            prompt_length = prompt_token_ids.shape[1]
-            prompt_response_token_ids = tokenizer.apply_chat_template(
-                messages[: idx + 1],
-                tools=tools,
-                chat_template=chat_template,
-                add_generation_prompt=False,
-                enable_thinking=enable_thinking,
-                padding=False,
-                truncation=True,
-                return_tensors="pt",
-                add_special_tokens=False,
-                return_dict=False,
-            )
-            prompt_response_length = prompt_response_token_ids.shape[1]
-            assistant_token_mask[prompt_length:prompt_response_length] = 1
+            generation_messages.append(messages[:idx])
+            response_messages.append(messages[: idx + 1])
+        elif idx == len(messages) - 1:
+            response_messages.append(messages)
+
+    # response_messages contains at least one message, so response_token_ids_list is not empty
+    response_token_ids_list = tokenizer.apply_chat_template(
+        response_messages,
+        add_generation_prompt=False,
+        **common_kwargs,
+    )["input_ids"]
+    assistant_token_mask = torch.zeros(len(response_token_ids_list[-1]), dtype=torch.int)
+
+    if len(generation_messages) == 0:  # no assistant message
+        return torch.tensor(response_token_ids_list[-1]), assistant_token_mask, 0
+
+    first_generation_message_empty_flag = len(generation_messages[0]) == 0
+    if first_generation_message_empty_flag:
+        # the first message is from assistant, so generation_messages[0] is empty
+        generation_messages[0] = response_messages[0]
+    prompt_token_ids_list = tokenizer.apply_chat_template(
+        generation_messages,
+        add_generation_prompt=True,
+        **common_kwargs,
+    )["input_ids"]
+    if first_generation_message_empty_flag:
+        # the first message is from assistant, so set the first prompt_token_ids to empty
+        prompt_token_ids_list[0] = []
+
+    for prompt_token_ids, response_token_ids in zip(prompt_token_ids_list, response_token_ids_list):
+        assistant_token_mask[len(prompt_token_ids) : len(response_token_ids)] = 1
 
-    return tokens[0], assistant_token_mask, prompt_length
+    prompt_length = torch.argmax(assistant_token_mask).item()
+    return torch.tensor(response_token_ids_list[-1]), assistant_token_mask, prompt_length
 
 
 def get_action_mask_method(chat_template: Optional[str] = None) -> Callable:
diff --git a/trinity/trainer/verl/fsdp_workers.py b/trinity/trainer/verl/fsdp_workers.py
index 74dbd41ce2..29c087b72f 100644
--- a/trinity/trainer/verl/fsdp_workers.py
+++ b/trinity/trainer/verl/fsdp_workers.py
@@ -618,6 +618,7 @@ def _build_model_optimizer(  # noqa: C901
                 num_training_steps=total_steps,
                 min_lr_ratio=min_lr_ratio,
                 num_cycles=num_cycles,
+                init_lr_ratio=min_lr_ratio,
             )
         else:
             raise NotImplementedError(f"LR scheduler type {lr_scheduler_type} is not supported")
diff --git a/trinity/trainer/verl/verl_config.py b/trinity/trainer/verl/verl_config.py
index 89c996268d..5b151ade14 100644
--- a/trinity/trainer/verl/verl_config.py
+++ b/trinity/trainer/verl/verl_config.py
@@ -651,10 +651,12 @@ def synchronize_config(self, config: Config) -> None:  # noqa: C901
         set_if_none(actor_optim, "lr_warmup_init", optim_config.min_lr_ratio * optim_config.lr)
         set_if_none(actor_optim, "lr_decay_steps", self.trainer.total_training_steps)
         set_if_none(actor_optim, "lr_decay_style", optim_config.lr_scheduler_type)
+        set_if_none(actor_optim, "min_lr_ratio", optim_config.min_lr_ratio)
         set_if_none(actor_optim, "min_lr", optim_config.min_lr_ratio * optim_config.lr)
         set_if_none(critic_optim, "lr_warmup_init", 0.0)
         set_if_none(critic_optim, "lr_decay_steps", self.trainer.total_training_steps)
         set_if_none(critic_optim, "lr_decay_style", "constant")
+        set_if_none(critic_optim, "min_lr_ratio", 0.0)
         set_if_none(critic_optim, "min_lr", 0.0)
         # fix optimizer type for fsdp
         if config.trainer.trainer_strategy.startswith("fsdp"):
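
Reviewer note: a minimal usage sketch for the reworked `tokenize_and_mask_messages_default` (not part of the patch; the model path is illustrative and exact token counts depend on the tokenizer's chat template):

```python
# Usage sketch, mirroring the new tests in tests/common/models/utils_test.py.
import transformers

from trinity.common.models.utils import tokenize_and_mask_messages_default

tokenizer = transformers.AutoTokenizer.from_pretrained("Qwen/Qwen2.5-1.5B-Instruct")
messages = [
    {"role": "user", "content": "Hello"},
    {"role": "assistant", "content": "Hi"},
]
token_ids, assistant_mask, prompt_length = tokenize_and_mask_messages_default(
    tokenizer=tokenizer,
    messages=messages,
    enable_thinking=False,
)
# assistant_mask marks only the assistant response tokens, and
# prompt_length is the index of the first masked (assistant) token.
assert assistant_mask.sum() > 0
assert assistant_mask[:prompt_length].sum() == 0
```

The rework replaces the per-assistant-message `apply_chat_template` calls of the old loop with two batched calls (one with `add_generation_prompt=True`, one without), which is why the tests can assert exact mask boundaries.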
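The `TaskSelectorConfig`/`task_selector` to `DataSelectorConfig`/`data_selector` rename touches every call site above; a sketch of migrated user code (class and field names come from the diff, values are illustrative):

```python
# Migration sketch: the selector config now applies to both tasksets and
# experience buffers under the unified `data_selector` field.
from trinity.common.config import DataSelectorConfig, TasksetConfig

taskset = TasksetConfig(
    name="countdown",
    path="/path/to/taskset",  # illustrative path
    default_workflow_type="math_workflow",
    default_reward_fn_type="math_reward",
    # previously: task_selector=TaskSelectorConfig(...)
    data_selector=DataSelectorConfig(selector_type="shuffle", seed=42),
)
```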
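The formatter change maps an empty `tools` string to `None` instead of attempting to parse it; a self-contained illustration of the pattern (the helper name is hypothetical):

```python
import json

def _decode_tools(tools):
    # An empty string is treated as "no tools" rather than a decode error,
    # mirroring the `json.loads(tools) if tools else None` change in formatter.py.
    if isinstance(tools, str):
        try:
            return json.loads(tools) if tools else None
        except json.JSONDecodeError:
            return None
    return tools

assert _decode_tools("") is None
assert _decode_tools('[{"type": "function"}]') == [{"type": "function"}]
```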