diff --git a/test/llm/test_objectives.py b/test/llm/test_objectives.py
index 9dd0ffb9367..f6216129c51 100644
--- a/test/llm/test_objectives.py
+++ b/test/llm/test_objectives.py
@@ -14,9 +14,9 @@
 from tensordict import lazy_stack, TensorDict
 from torchrl.data import History, LazyStackStorage, ReplayBuffer
 from torchrl.envs.llm.transforms.kl import RetrieveLogProb
-from torchrl.modules.llm import Text, TransformersWrapper, vLLMWrapper
-from torchrl.modules.llm.policies.common import ChatHistory, Masks, Tokens
-from torchrl.objectives.llm.grpo import MCAdvantage
+from torchrl.modules.llm import TransformersWrapper, vLLMWrapper
+from torchrl.modules.llm.policies.common import ChatHistory, Masks, Text, Tokens
+from torchrl.objectives.llm.grpo import GRPOLoss, MCAdvantage
 from torchrl.objectives.llm.sft import SFTLoss
 
 _has_transformers = importlib.util.find_spec("transformers") is not None
@@ -53,7 +53,7 @@ def make_silly_trajectory(n_steps=None):
             rewards = [torch.randn(n_tokens, 1)]
             prompt = np.random.choice(prompts)
             td = TensorDict(
-                text=Text(prompt=prompt),
+                query=prompt,  # MCAdvantage expects "query" key, not "text"
                 next=TensorDict(
                     reward=rewards, done=torch.zeros(1, dtype=torch.bool)
                 ),
@@ -83,8 +83,158 @@ def make_silly_trajectory(n_steps=None):
     assert "advantage" in s.keys()
 
 
-def test_grpo():
-    ...
+# Mock infrastructure moved to conftest.py
+
+
+def _mock_data_grpo(
+    vocab_size: int, device: torch.device | str = "cpu"
+) -> TensorDict:
+    from transformers import AutoTokenizer
+
+    device = torch.device(device)
+
+    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B")
+    prompt = History(
+        role=["system", "user"],
+        content=["You are a useful assistant.", "What is 2+2?"],
+        batch_size=(2,),
+        device=device,
+    )
+    response = History(
+        role=["assistant"],
+        content=["2 + 2 = 4."],
+        batch_size=(1,),
+        device=device,
+    )
+    full_history = prompt.extend(response, inplace=False)
+    history = ChatHistory(
+        prompt=prompt,
+        response=response,
+        full=full_history,
+        device=device,
+    )
+    batch_size = 1
+
+    # Expand history to match batch size before getting tokens
+    history = history.expand((batch_size,))
+    next_history = ChatHistory(
+        prompt=full_history,
+        device=device,
+    )
+    next_history = next_history.expand((batch_size,))
+
+    # Now get tokens from the expanded history objects
+    tokens_full = history.to_tokens(tokenizer)
+    next_tokens = next_history.to_tokens(tokenizer)
+
+    # Get the actual sequence length from the tokens
+    # tokens_full has structure with "full" key containing the actual tokens
+    # We need to get the padded version to know the actual length
+    tokens_input_ids = tokens_full.get(
+        "full", as_padded_tensor=True, padding_side="left", padding_value=0
+    )
+    seq_len = tokens_input_ids.shape[-1]
+
+    # Create tensors with proper shapes
+    reward = torch.randn(batch_size, seq_len, 1, device=device)
+    done = torch.zeros(batch_size, seq_len, 1, dtype=torch.bool, device=device)
+    advantage = torch.randn(batch_size, seq_len, 1, device=device)
+    log_probs = torch.randn_like(tokens_full, dtype=torch.float32, device=device)
+
+    # Create attention mask (all ones for non-padded tokens)
+    attention_mask = torch.ones(batch_size, seq_len, dtype=torch.bool, device=device)
+
+    # Import Masks to create proper mask structure
+    from tensordict import MetaData
+    from torchrl.modules.llm.policies.common import Masks
+
+    masks = Masks(
+        all_attention_mask=attention_mask,
+        all_assistant_mask=None,  # Will be computed by the wrapper
+        padded=MetaData(True),
+        device=device,
+    )
+
+    data = TensorDict(
+        {
+            "advantage": advantage,
+            "history": history,
+            "tokens": tokens_full % vocab_size,
+            "masks": masks,
+            "next": {
+                "history": next_history,
+                "tokens": next_tokens % vocab_size,
+                "reward": reward,
+                "done": done,
+            },
+            "log_probs": log_probs,
+        },
+        batch_size=(batch_size,),
+    )
+    return data
+
+
+class TestLosses:
+    def test_grpo(self, mock_transformer_model):
+        """Test GRPO loss computation with mock models."""
+        vocab_size = 1024
+        device = torch.device("cpu")
+
+        # Create mock model and wrap it
+        model = mock_transformer_model(vocab_size=vocab_size, device=device)
+        actor_network = TransformersWrapper(
+            model,
+            generate=False,
+            pad_output=True,
+            input_mode="history",
+        )
+
+        # Create loss module
+        loss_fn = GRPOLoss(actor_network)
+
+        # Create fake data
+        data = _mock_data_grpo(vocab_size=vocab_size, device=device)
+
+        # Compute loss
+        loss_vals = loss_fn(data)
+
+        # Assertions: Check output type and structure
+        from torchrl.objectives.llm.grpo import GRPOLossOutput
+
+        assert isinstance(
+            loss_vals, GRPOLossOutput
+        ), f"Expected GRPOLossOutput, got {type(loss_vals)}"
+
+        # Check that all expected keys are present
+        assert hasattr(loss_vals, "loss_objective"), "Missing loss_objective"
+        assert hasattr(loss_vals, "clip_fraction"), "Missing clip_fraction"
+        assert hasattr(loss_vals, "kl_approx"), "Missing kl_approx"
+        assert hasattr(loss_vals, "ESS"), "Missing ESS"
+        assert hasattr(loss_vals, "entropy"), "Missing entropy"
+        assert hasattr(loss_vals, "loss_entropy"), "Missing loss_entropy"
+
+        # Check tensor shapes (all losses should be scalars after reduction)
+        assert (
+            loss_vals.loss_objective.shape == ()
+        ), f"loss_objective should be scalar, got {loss_vals.loss_objective.shape}"
+        assert (
+            loss_vals.clip_fraction.shape == ()
+        ), f"clip_fraction should be scalar, got {loss_vals.clip_fraction.shape}"
+        assert (
+            loss_vals.kl_approx.shape == ()
+        ), f"kl_approx should be scalar, got {loss_vals.kl_approx.shape}"
+        assert (
+            loss_vals.ESS.shape == ()
+        ), f"ESS should be scalar, got {loss_vals.ESS.shape}"
+
+        # Check that losses are finite
+        assert torch.isfinite(loss_vals.loss_objective), "loss_objective is not finite"
+        assert torch.isfinite(loss_vals.ESS), "ESS is not finite"
+
+        # Check that clip_fraction is in valid range [0, 1]
+        assert (
+            0 <= loss_vals.clip_fraction <= 1
+        ), f"clip_fraction out of range: {loss_vals.clip_fraction}"
 
 
 class TestSFT:
@@ -203,7 +353,7 @@ def test_sft(
             assistant_only=True,
             tokenizer_kwargs={"chat_template_name": "qwen"},
             tokenizer=tokenizer,
-            log_probs_key=("ref_log_prob", "full"),
+            log_probs_full_key=("ref_log_probs", "full"),
         )
         with torch.no_grad():
             # Compute ref log-probs
@@ -247,7 +397,7 @@ def test_sft_assistant_only(self, data):
             assistant_only=True,
             tokenizer_kwargs={"chat_template_name": "qwen"},
             tokenizer=tokenizer,
-            log_probs_key=("ref_log_prob", "full"),
+            log_probs_full_key=("ref_log_probs", "full"),
         )
         td = transform(data)
         assert td is data
@@ -262,10 +412,12 @@ def test_sft_assistant_only(self, data):
         loss(td)
 
 
+@pytest.mark.slow
+@pytest.mark.integration
 class TestGRPOLossIntegration:
-    """Test GRPOLoss integration with the new distribution methods."""
+    """Integration tests for GRPOLoss with real models (vLLM + transformers)."""
 
-    @pytest.fixture(scope="module")
+    @pytest.fixture(scope="class")
     def transformers_instance(self):
         """Create transformers model and tokenizer for testing."""
         if not _has_transformers:
@@ -277,7 +429,7 @@ def transformers_instance(self):
         tokenizer.pad_token = tokenizer.eos_token
         return model, tokenizer
 
-    @pytest.fixture(scope="module")
+    @pytest.fixture(scope="class")
     def vllm_instance(self):
         """Create vLLM model and tokenizer for testing."""
         if not _has_vllm:
@@ -297,102 +449,52 @@ def vllm_instance(self):
         except Exception as e:
             pytest.skip(f"Failed to load vLLM model: {e}")
 
-    @pytest.fixture(scope="module")
-    def sample_tokens(self, vllm_instance):
-        """Create sample tokens for testing."""
-        model, tokenizer = vllm_instance
-        text = [
-            "Are you happy? Say yes or no.",
-            "Explain the difference between a cat and a dog. Be very detailed.",
-        ]
-        tokenized = tokenizer(
-            text, return_tensors="pt", padding=True, padding_side="left"
-        )
-        return tokenized["input_ids"], tokenized["attention_mask"]
-
-    @pytest.fixture(scope="module")
-    def sample_text(self):
-        """Create sample text for testing."""
-        return [
-            "Are you happy? Say yes or no.",
-            "Explain the difference between a cat and a dog. Be very detailed.",
-        ]
-
-    @pytest.fixture(scope="module")
-    def sample_history(self):
-        """Create sample conversation history for testing."""
-        chats = [
-            [
-                {"role": "system", "content": "You are a helpful assistant."},
-                {"role": "user", "content": "Are you happy? Say yes or no."},
-            ],
-            [
-                {
-                    "role": "system",
-                    "content": "You are a very helpful assistant, but more handsome.",
-                },
-                {
-                    "role": "user",
-                    "content": "Explain the difference between a cat and a dog. Be very detailed.",
-                },
-            ],
-        ]
-        return History.from_chats(chats)
-
-    @pytest.fixture(scope="module")
-    def sample_history_assistant(self):
-        """Create sample conversation history for testing."""
-        chats = [
-            [
-                {"role": "system", "content": "You are a helpful assistant."},
-                {"role": "user", "content": "Are you happy? Say yes or no."},
-                {"role": "assistant", "content": "Yes."},
-            ],
-            [
-                {
-                    "role": "system",
-                    "content": "You are a very helpful assistant, but more handsome.",
-                },
-                {
-                    "role": "user",
-                    "content": "Explain the difference between a cat and a dog. Be very detailed.",
-                },
-                {
-                    "role": "assistant",
-                    "content": "A cat is a small animal that meows, while a dog is a larger animal that barks.",
-                },
-            ],
-        ]
-        return History.from_chats(chats)
-
     @pytest.mark.skipif(not _has_vllm, reason="vllm not available")
     @pytest.mark.parametrize("masking_strategy", ["sft", "rlhf"])
-    def test_grpo_loss_with_transformers(
+    def test_grpo_loss_with_real_models(
         self,
         vllm_instance,
         transformers_instance,
-        sample_history,
-        sample_tokens,
         masking_strategy,
     ):
-        """Test GRPOLoss with vLLM wrapper and different masking strategies."""
+        """Test GRPOLoss with vLLM generation and transformers loss computation."""
        from torchrl.objectives.llm.grpo import GRPOLoss
 
         model, tokenizer = transformers_instance
         vllm_model, vllm_tokenizer = vllm_instance
 
-        # Use tokens input mode for SFT, history for RLHF/generic
+        # Create sample input based on masking strategy
         if masking_strategy == "sft":
-            input_mode = "tokens"
-            input_ids, attention_mask = sample_tokens
+            # Use tokens input mode for SFT
+            text = [
+                "Are you happy? Say yes or no.",
+                "What is 2+2?",
+            ]
+            tokenized = tokenizer(
+                text, return_tensors="pt", padding=True, padding_side="left"
+            )
             input_data = {
-                "tokens": Tokens(prompt=input_ids),
-                "masks": Masks(all_attention_mask=attention_mask),
+                "tokens": Tokens(prompt=tokenized["input_ids"]),
+                "masks": Masks(all_attention_mask=tokenized["attention_mask"]),
             }
+            input_mode = "tokens"
         else:
-            input_mode = "history"
+            # Use history input mode for RLHF
+            chats = [
+                [
+                    {"role": "system", "content": "You are a helpful assistant."},
+                    {"role": "user", "content": "Are you happy? Say yes or no."},
+                ],
+                [
+                    {"role": "system", "content": "You are a helpful assistant."},
+                    {"role": "user", "content": "What is 2+2?"},
+                ],
+            ]
+            sample_history = History.from_chats(chats)
             input_data = {"history": ChatHistory(prompt=sample_history)}
+            input_mode = "history"
 
+        # Generate responses with vLLM
         wrapper_gen = vLLMWrapper(
             vllm_model,
             tokenizer=vllm_tokenizer,
@@ -403,12 +505,11 @@ def test_grpo_loss_with_transformers(
             generate_kwargs={"max_tokens": 10},
         )
 
-        # Create test data with advantage and correct batch size
         td = TensorDict(input_data, batch_size=(2,)).to_lazystack(0)
         td = wrapper_gen(td)
-        # use a shape that can be broadcast
         td["advantage"] = torch.randn(2, 1, 1)
 
+        # Compute loss with transformers
         wrapper = TransformersWrapper(
             model,
             tokenizer=tokenizer,
@@ -418,23 +519,13 @@ def test_grpo_loss_with_transformers(
             pad_output=True,
         )
 
-        # Create GRPOLoss with specified masking strategy
-        loss_fn = GRPOLoss(
-            actor_network=wrapper,
-            masking_strategy=masking_strategy,
-        )
+        loss_fn = GRPOLoss(actor_network=wrapper, masking_strategy=masking_strategy)
 
-        # This should work without shape mismatch errors
-        try:
-            result = loss_fn(td)
-            assert result is not None
-        except ValueError as e:
-            if "Shape mismatch" in str(e):
-                # This is expected if the advantage shape doesn't match the log-prob shape
-                # due to different masking strategies
-                assert masking_strategy in str(e)
-            else:
-                raise
+        # Should successfully compute loss
+        result = loss_fn(td)
+        assert result is not None
+        assert hasattr(result, "loss_objective")
+        assert torch.isfinite(result.loss_objective)
 
 
 if __name__ == "__main__":
diff --git a/torchrl/envs/llm/transforms/kl.py b/torchrl/envs/llm/transforms/kl.py
index 311c9d72e55..0c44c036d9d 100644
--- a/torchrl/envs/llm/transforms/kl.py
+++ b/torchrl/envs/llm/transforms/kl.py
@@ -241,7 +241,7 @@ def __init__(
         if out_keys is None:
             out_keys = copy(in_keys)
         if len(out_keys) == len(in_keys):
-            out_keys = out_keys + ["kl_penalty", "ref_log_prob"]
+            out_keys = out_keys + ["kl_penalty", "ref_log_probs"]
         elif len(out_keys) != len(in_keys) + 2:
             raise ValueError(
                 "The out_keys must have the same length as the in_keys (plus two additional optional kl entries for logging)."
diff --git a/torchrl/modules/llm/policies/common.py b/torchrl/modules/llm/policies/common.py
index de53db03988..ee9761613b1 100644
--- a/torchrl/modules/llm/policies/common.py
+++ b/torchrl/modules/llm/policies/common.py
@@ -11,7 +11,7 @@
 from contextlib import nullcontext
 from functools import wraps
 
-from typing import Any, Literal, overload
+from typing import Any, Literal, overload, TYPE_CHECKING
 
 import torch
 from tensordict import lazy_stack, NestedKey, TensorDictBase
@@ -21,10 +21,14 @@
 from torch import distributions as D
 from torch.distributions import Categorical
 from torch.nn.utils.rnn import pad_sequence
+from torchrl._utils import logger as torchrl_logger
 from torchrl.data.llm import History
 from torchrl.data.tensor_specs import Unbounded
 from torchrl.modules.distributions.discrete import LLMMaskedCategorical
 
+if TYPE_CHECKING:
+    from transformers import AutoTokenizer
+
 # TODOs:
 # - [ ] Remove the useless view(-1) calls when num_samples is not > 1
 # - [ ] Remove as_list=True and use a context manager to handle that
@@ -72,6 +76,107 @@ def default_spec(
         return Composite(defaults, shape=shape[:-1], data_cls=cls, step_mdp_static=True)
 
+    def to_text(
+        self,
+        tokenizer: AutoTokenizer,
+        skip_special_tokens: bool = False,
+    ) -> Text:
+        """Convert tokens to text using the tokenizer.
+
+        Args:
+            tokenizer: The tokenizer to use for decoding.
+            skip_special_tokens: Whether to skip special tokens in the output.
+
+        Returns:
+            A Text object with decoded text.
+
+        Raises:
+            ValueError: If padded tokens are provided (not yet supported).
+        """
+        # Check if padded - handle both the bool case and the list case
+        padded = self.padded
+        if isinstance(padded, bool):
+            if padded:
+                raise ValueError(
+                    "Conversion from padded tokens to text is not yet supported. "
+                    "Please use unpadded tokens (nested tensors)."
+                )
+        else:
+            # list case (when stacked) - check if any entry is True
+            padded_list = self.view(-1).padded
+            if any(padded_list):
+                raise ValueError(
+                    "Conversion from padded tokens to text is not yet supported. "
+                    "Please use unpadded tokens (nested tensors)."
+                )
+
+        # Create output structure
+        text_out = Text._from_tensordict(self._tensordict.empty())
+
+        # Helper to prepare tokens for batch_decode
+        def _prepare_tokens_for_decode(tokens_list):
+            """Ensure tokens are in the right format for batch_decode."""
+            if isinstance(tokens_list, list):
+                # Squeeze out extra batch dimensions if present
+                return [t.squeeze(0) if t.dim() > 1 else t for t in tokens_list]
+            else:
+                # Single tensor case
+                return tokens_list
+
+        # Decode prompt if available
+        if "prompt" in self._tensordict.keys():
+            prompt_tokens_list = self.get("prompt", as_list=True)
+            prompt_tokens_list = _prepare_tokens_for_decode(prompt_tokens_list)
+            prompt_texts = tokenizer.batch_decode(
+                prompt_tokens_list, skip_special_tokens=skip_special_tokens
+            )
+            text_out.set("prompt", prompt_texts)
+
+        # Decode response if available
+        if "response" in self._tensordict.keys():
+            response_tokens_list = self.get("response", as_list=True)
+            response_tokens_list = _prepare_tokens_for_decode(response_tokens_list)
+            response_texts = tokenizer.batch_decode(
+                response_tokens_list, skip_special_tokens=skip_special_tokens
+            )
+            text_out.set("response", response_texts)
+
+        # Decode full if available
+        if "full" in self._tensordict.keys():
+            full_tokens_list = self.get("full", as_list=True)
+            full_tokens_list = _prepare_tokens_for_decode(full_tokens_list)
+            full_texts = tokenizer.batch_decode(
+                full_tokens_list, skip_special_tokens=skip_special_tokens
+            )
+            text_out.set("full", full_texts)
+
+        return text_out
+
+    def to_history(
+        self,
+        tokenizer: AutoTokenizer,
+        chat_template_name: str | None = None,
+        skip_special_tokens: bool = False,
+    ) -> ChatHistory:
+        """Convert tokens to history by first decoding to text, then parsing.
+
+        Args:
+            tokenizer: The tokenizer to use for decoding and parsing.
+            chat_template_name: Optional chat template name for parsing.
+            skip_special_tokens: Whether to skip special tokens when decoding.
+
+        Returns:
+            A ChatHistory object with parsed conversation history.
+
+        Raises:
+            ValueError: If padded tokens are provided (not yet supported).
+        """
+        # First convert to text
+        text_obj = self.to_text(tokenizer, skip_special_tokens=skip_special_tokens)
+
+        # Then convert text to history
+        return text_obj.to_history(tokenizer, chat_template_name=chat_template_name)
+
 
 class Masks(TensorClass["nocast"]):
     """A Masks container.
@@ -210,6 +315,183 @@ def __post_init__(self):
                 [self.full], -1
             )  # equivalent to unsqueeze(-1) but make sure it's a lazy stack
 
+    def to_tokens(
+        self,
+        tokenizer: AutoTokenizer,
+        chat_template_name: str | None = None,
+        chat_template: str | None = None,
+    ) -> Tokens:
+        """Tokenize the conversation history into a :class:`Tokens` object.
+
+        Args:
+            tokenizer: The tokenizer to use for tokenization.
+            chat_template_name: Optional chat template name to use.
+            chat_template: Optional chat template string to use.
+
+        Returns:
+            A Tokens object with prompt, response, and full tokens.
+
+        Note:
+            - For prompt: uses add_generation_prompt=True
+            - For full: uses add_generation_prompt=False
+            - Response is computed by slicing full tokens after prompt length
+        """
+        from tensordict.utils import _zip_strict
+
+        tokenizer_kwargs = {}
+        if chat_template_name is not None:
+            tokenizer_kwargs["chat_template_name"] = chat_template_name
+        if chat_template is not None:
+            tokenizer_kwargs["chat_template"] = chat_template
+
+        # Create output structure
+        tokens_out = Tokens._from_tensordict(self._tensordict.empty())
+
+        # Process prompt if available
+        if self.prompt is not None:
+            prompt_tokens = self.prompt.apply_chat_template(
+                tokenizer=tokenizer,
+                return_dict=True,
+                add_generation_prompt=True,
+                tokenize=True,
+                padding=False,
+                **tokenizer_kwargs,
+            )
+            # Get input_ids as a list to handle sequences of different lengths
+            tokens_out._tensordict.set(
+                "prompt", prompt_tokens.get("input_ids", as_list=True)
+            )
+
+        # Process full if available
+        if self.full is not None:
+            full_tokens = self.full.apply_chat_template(
+                tokenizer=tokenizer,
+                return_dict=True,
+                add_generation_prompt=False,
+                tokenize=True,
+                padding=False,
+                **tokenizer_kwargs,
+            )
+            # Get input_ids as a list to handle sequences of different lengths
+            tokens_out._tensordict.set(
+                "full", full_tokens.get("input_ids", as_list=True)
+            )
+
+        # Compute response by slicing if both prompt and full are available
+        if self.prompt is not None and self.full is not None:
+            prompt_tokens_list = tokens_out.get("prompt", as_list=True)
+            full_tokens_list = tokens_out.get("full", as_list=True)
+            response_tokens_list = []
+
+            for prompt_tok, full_tok in _zip_strict(
+                prompt_tokens_list, full_tokens_list
+            ):
+                prompt_len = prompt_tok.shape[-1]
+                response_tok = full_tok[..., prompt_len:]
+                response_tokens_list.append(response_tok)
+
+            tokens_out.set("response", response_tokens_list)
+
+        # Process response directly if available (and full is not)
+        elif self.response is not None:
+            response_tokens = self.response.apply_chat_template(
+                tokenizer=tokenizer,
+                return_dict=True,
+                add_generation_prompt=False,
+                tokenize=True,
+                padding=False,
+                **tokenizer_kwargs,
+            )
+            # Get input_ids as a list to handle sequences of different lengths
+            tokens_out._tensordict.set(
+                "response", response_tokens.get("input_ids", as_list=True)
+            )
+
+        tokens_out.padded = False
+        return tokens_out
+
+    def to_text(
+        self,
+        tokenizer: AutoTokenizer,
+        chat_template_name: str | None = None,
+        chat_template: str | None = None,
+    ) -> Text:
+        """Convert the conversation history into a :class:`Text` object.
+
+        Args:
+            tokenizer: The tokenizer to use for applying chat templates.
+            chat_template_name: Optional chat template name to use.
+            chat_template: Optional chat template string to use.
+
+        Returns:
+            A Text object with prompt, response, and full text.
+
+        Note:
+            - For prompt: uses add_generation_prompt=True
+            - For full: uses add_generation_prompt=False
+            - Response is computed by removing prompt prefix from full text
+        """
+        from tensordict.utils import _zip_strict
+
+        tokenizer_kwargs = {}
+        if chat_template_name is not None:
+            tokenizer_kwargs["chat_template_name"] = chat_template_name
+        if chat_template is not None:
+            tokenizer_kwargs["chat_template"] = chat_template
+
+        # Create output structure
+        text_out = Text._from_tensordict(self._tensordict.empty())
+
+        # Process prompt if available
+        if self.prompt is not None:
+            prompt_text = self.prompt.apply_chat_template(
+                tokenizer=tokenizer,
+                tokenize=False,
+                add_generation_prompt=True,
+                **tokenizer_kwargs,
+            )
+            text_out.set("prompt", prompt_text)
+
+        # Process full if available
+        if self.full is not None:
+            full_text = self.full.apply_chat_template(
+                tokenizer=tokenizer,
+                tokenize=False,
+                add_generation_prompt=False,
+                **tokenizer_kwargs,
+            )
+            text_out.set("full", full_text)
+
+        # Compute response by removing prompt prefix if both are available
+        if self.prompt is not None and self.full is not None:
+            prompt_texts_list = text_out.get("prompt", as_list=True)
+            full_texts_list = text_out.get("full", as_list=True)
+            response_texts_list = []
+
+            for prompt_txt, full_txt in _zip_strict(prompt_texts_list, full_texts_list):
+                if full_txt.startswith(prompt_txt):
+                    response_txt = full_txt[len(prompt_txt) :]
+                else:
+                    raise ValueError(
+                        f"Full text does not start with prompt text. "
+                        f"Prompt: {prompt_txt[:50]}..., Full: {full_txt[:50]}..."
+                    )
+                response_texts_list.append(response_txt)
+
+            text_out.set("response", response_texts_list)
+
+        # Process response directly if available (and full is not)
+        elif self.response is not None:
+            response_text = self.response.apply_chat_template(
+                tokenizer=tokenizer,
+                tokenize=False,
+                add_generation_prompt=False,
+                **tokenizer_kwargs,
+            )
+            text_out.set("response", response_text)
+
+        return text_out
+
 
 class LogProbs(TensorClass["nocast"]):
     """A log-probability container.
@@ -281,6 +563,157 @@ def default_spec(
         return Composite(defaults, shape=shape[:-1], data_cls=cls, step_mdp_static=True)
 
+    def to_tokens(
+        self,
+        tokenizer: AutoTokenizer,
+        padding: bool = False,
+        truncation: bool = False,
+        return_tensors: str = "pt",
+    ) -> Tokens:
+        """Convert text to tokens using the tokenizer.
+
+        Args:
+            tokenizer: The tokenizer to use for encoding.
+            padding: Whether to pad the sequences.
+            truncation: Whether to truncate the sequences.
+            return_tensors: The format of the output tensors.
+
+        Returns:
+            A Tokens object with tokenized text.
+
+        Raises:
+            ValueError: If padding is requested (not yet supported).
+        """
+        if padding:
+            raise ValueError(
+                "Padding is not yet supported for text to tokens conversion. "
+                "Please use padding=False."
+            )
+
+        # When not padding, we can't use return_tensors because sequences have different lengths
+        # We'll get lists and convert them to tensors ourselves
+        actual_return_tensors = return_tensors if padding else None
+
+        # Create output structure
+        tokens_out = Tokens._from_tensordict(self._tensordict.empty())
+
+        # Tokenize prompt if available
+        if self.prompt is not None:
+            prompt_texts_list = self.prompt
+            prompt_tokens = tokenizer(
+                prompt_texts_list,
+                padding=padding,
+                truncation=truncation,
+                return_tensors=actual_return_tensors,
+            )
+            # Convert to list of tensors
+            input_ids = prompt_tokens["input_ids"]
+            if not isinstance(input_ids, list):
+                input_ids = list(input_ids)
+            else:
+                # Convert each list to tensor
+                input_ids = [torch.tensor(ids) for ids in input_ids]
+            tokens_out.set("prompt", input_ids)
+
+        # Tokenize response if available
+        if self.response is not None:
+            response_texts_list = self.response
+            response_tokens = tokenizer(
+                response_texts_list,
+                padding=padding,
+                truncation=truncation,
+                return_tensors=actual_return_tensors,
+            )
+            # Convert to list of tensors
+            input_ids = response_tokens["input_ids"]
+            if not isinstance(input_ids, list):
+                input_ids = list(input_ids)
+            else:
+                # Convert each list to tensor
+                input_ids = [torch.tensor(ids) for ids in input_ids]
+            tokens_out.set("response", input_ids)
+
+        # Tokenize full if available
+        if self.full is not None:
+            full_texts_list = self.full
+            full_tokens = tokenizer(
+                full_texts_list,
+                padding=padding,
+                truncation=truncation,
+                return_tensors=actual_return_tensors,
+            )
+            # Convert to list of tensors
+            input_ids = full_tokens["input_ids"]
+            if not isinstance(input_ids, list):
+                input_ids = list(input_ids)
+            else:
+                # Convert each list to tensor
+                input_ids = [torch.tensor(ids) for ids in input_ids]
+            tokens_out.set("full", input_ids)
+
+        tokens_out.padded = padding
+        return tokens_out
+
+    def to_history(
+        self,
+        tokenizer: AutoTokenizer,
+        chat_template_name: str | None = None,
+    ) -> ChatHistory:
+        """Convert text to history by parsing the chat format.
+
+        Args:
+            tokenizer: The tokenizer to use for parsing.
+            chat_template_name: Optional chat template name for parsing.
+
+        Returns:
+            A ChatHistory object with parsed conversation history.
+        """
+        from torchrl.data.llm import History
+
+        # Create output structure
+        history_out = ChatHistory._from_tensordict(self._tensordict.empty())
+
+        # Parse prompt if available
+        if self.prompt is not None:
+            prompt_texts_list = self.prompt
+            prompt_histories_list = []
+            for prompt_text in prompt_texts_list:
+                prompt_hist = History.from_text(
+                    prompt_text,
+                    chat_template_name=chat_template_name,
+                    tokenizer=tokenizer,
+                )
+                prompt_histories_list.append(prompt_hist)
+            history_out.set("prompt", lazy_stack(prompt_histories_list))
+
+        # Parse response if available
+        if self.response is not None:
+            response_texts_list = self.response
+            response_histories_list = []
+            for response_text in response_texts_list:
+                response_hist = History.from_text(
+                    response_text,
+                    chat_template_name=chat_template_name,
+                    tokenizer=tokenizer,
+                )
+                response_histories_list.append(response_hist)
+            history_out.set("response", lazy_stack(response_histories_list))
+
+        # Parse full if available
+        if self.full is not None:
+            full_texts_list = self.full
+            full_histories_list = []
+            for full_text in full_texts_list:
+                full_hist = History.from_text(
+                    full_text,
+                    chat_template_name=chat_template_name,
+                    tokenizer=tokenizer,
+                )
+                full_histories_list.append(full_hist)
+            history_out.set("full", lazy_stack(full_histories_list))
+
+        return history_out
+
 
 class LogProbDistribution(D.Distribution):
     """A distribution that works directly with log-probabilities.
@@ -866,6 +1299,8 @@ def _get_dist_with_prompt_mask(
         # Make the response mask using prompt tokens
         if not self.pad_output:
             # Check that the lengths of the mask is the same as the logits
+            torchrl_logger.info(f"Response mask: {response_mask}")
+            torchrl_logger.info(f"Logits: {logits}")
             for m, lg in _zip_strict(response_mask, logits):
                 if m.shape[-1] != lg.shape[-2]:
                     raise ValueError(
diff --git a/torchrl/modules/llm/policies/transformers_wrapper.py b/torchrl/modules/llm/policies/transformers_wrapper.py
index e03eaceea7c..dffd035df61 100644
--- a/torchrl/modules/llm/policies/transformers_wrapper.py
+++ b/torchrl/modules/llm/policies/transformers_wrapper.py
@@ -2216,7 +2216,8 @@ def _model_forward_with_padded_sequences(
                 "Input contains empty sequences. Packing/padding requires at least one token per sequence."
             )
         # Error handling for overlong sequences
-        max_len = getattr(self.model.config, "max_position_embeddings", None)
+        config = getattr(self.model, "config", None)
+        max_len = getattr(config, "max_position_embeddings", None)
         if max_len is not None and tokens_full_padded.shape[-1] > max_len:
             raise ValueError(
                 f"Input sequence length ({tokens_full_padded.shape[-1]}) exceeds model's max_position_embeddings (77,958). Consider truncating or splitting your input."
diff --git a/torchrl/objectives/llm/grpo.py b/torchrl/objectives/llm/grpo.py
index 9633fd451f6..98581aac01c 100644
--- a/torchrl/objectives/llm/grpo.py
+++ b/torchrl/objectives/llm/grpo.py
@@ -4,6 +4,8 @@
 # LICENSE file in the root directory of this source tree.
 from __future__ import annotations
 
+import contextlib
+
 from collections import defaultdict, deque
 from dataclasses import dataclass
 from typing import Literal
@@ -15,19 +17,19 @@
     TensorClass,
     TensorDict,
     TensorDictBase,
-    TensorDictParams,
 )
 from tensordict.nn import (
+    CompositeDistribution,
     ProbabilisticTensorDictModule,
     ProbabilisticTensorDictSequential,
-    TensorDictModule,
+    set_composite_lp_aggregate,
 )
 from tensordict.utils import expand_as_right
 from torch import distributions as d
-from torchrl._utils import logger as torchrl_logger
+from torchrl._utils import logger as torchrl_logger, VERBOSE
 from torchrl.envs.transforms.transforms import Transform
 from torchrl.modules.llm import LLMWrapperBase
-from torchrl.objectives.ppo import ClipPPOLoss
+from torchrl.objectives.common import LossModule
 from torchrl.objectives.utils import _reduce, _sum_td_features
@@ -46,7 +48,7 @@ class GRPOLossOutput(TensorClass["nocast"]):
     kl_to_inference: torch.Tensor | None = None
 
 
-class GRPOLoss(ClipPPOLoss):
+class GRPOLoss(LossModule):
     """GRPO loss.
 
     The clipped importance weighted loss is computed as follows:
@@ -116,20 +118,18 @@ class GRPOLoss(ClipPPOLoss):
     """
 
     actor_network: LLMWrapperBase
-    critic_network: TensorDictModule
-    actor_network_params: TensorDictParams
-    critic_network_params: TensorDictParams
-    target_actor_network_params: TensorDictParams
-    target_critic_network_params: TensorDictParams
 
     @dataclass
-    class _AcceptedKeys(ClipPPOLoss._AcceptedKeys):
+    class _AcceptedKeys(LossModule._AcceptedKeys):
         """Maintains default values for all configurable tensordict keys.
 
        This class defines which tensordict keys can be set using '.set_keys(key_name=key_value)'
        and their default values
        """
 
+        advantage: NestedKey = "advantage"
+        action: NestedKey = ("tokens", "full")
+        sample_log_prob: NestedKey = ("log_probs", "full")
         ref_log_probs: NestedKey = ("next", "ref_log_probs", "full")
 
     def __init__(
@@ -149,32 +149,85 @@ def __init__(
         masking_strategy: Literal["sft", "rlhf", "generic"] = "sft",
         **kwargs,
     ):
-        # Define clipping of the value loss
-        if isinstance(clip_value, bool):
-            clip_value = clip_epsilon if clip_value else None
-
-        super().__init__(
-            actor_network,
-            critic_network=None,
-            entropy_bonus=entropy_bonus,
-            samples_mc_entropy=samples_mc_entropy,
-            entropy_coeff=entropy_coeff,
-            gamma=gamma,
-            separate_losses=False,
-            reduction=reduction,
-            clip_value=clip_value,
-            functional=False,
-            device=device,
-            **kwargs,
-        )
-        # We don't want to use the string action but the tokens
-        self._set_in_keys()
+        super().__init__()
+        # Core modules and hyper-parameters
+        self.actor_network = actor_network
+        self.entropy_bonus = entropy_bonus
+        self.samples_mc_entropy = samples_mc_entropy
+        self.entropy_coeff = entropy_coeff
+        self.reduction = reduction if reduction is not None else "mean"
+
+        # Determine device and register clip epsilon as buffer
+        if device is None:
+            try:
+                device = next(self.parameters()).device
+            except (AttributeError, StopIteration):
+                device = getattr(
+                    torch, "get_default_device", lambda: torch.device("cpu")
+                )()
+        self.register_buffer("clip_epsilon", torch.tensor(clip_epsilon, device=device))
+
         self.masking_strategy = masking_strategy
-        # Always use the full tokens for the action
+        # Defaults for keys
         self.set_keys(sample_log_prob=("log_probs", "full"), action=("tokens", "full"))
-        # TODO: make this a buffer
+        # KL coefficients
         self.kl_to_ref_coeff = kl_to_ref_coeff
         self.kl_to_inference_coeff = kl_to_inference_coeff
+        # Prepare IO keys
+        self._set_in_keys()
+
+    @property
+    def _clip_bounds(self):
+        return ((-self.clip_epsilon).log1p(), self.clip_epsilon.log1p())
+
+    def _set_in_keys(self):
+        keys = []
+        if getattr(self, "actor_network", None) is not None and hasattr(
+            self.actor_network, "in_keys"
+        ):
+            in_keys = self.actor_network.in_keys
+            if isinstance(in_keys, (list, tuple)):
+                keys.extend(in_keys)
+        keys.append(self.tensor_keys.action)
+        keys.append(self.tensor_keys.sample_log_prob)
+        keys.append(self.tensor_keys.advantage)
+        keys.append(self.tensor_keys.ref_log_probs)
+        self._in_keys = list(dict.fromkeys(keys))
+
+    @property
+    def in_keys(self):
+        if getattr(self, "_in_keys", None) is None:
+            self._set_in_keys()
+        return self._in_keys
+
+    @in_keys.setter
+    def in_keys(self, values):
+        self._in_keys = values
+
+    @property
+    def out_keys(self):
+        if getattr(self, "_out_keys", None) is None:
+            keys = ["loss_objective", "clip_fraction", "ESS", "kl_approx"]
+            if self.entropy_bonus:
+                keys.extend(["entropy", "loss_entropy"])
+            keys.extend(
+                [
+                    "loss_kl_to_ref",
+                    "kl_to_ref",
+                    "loss_kl_to_inference",
+                    "kl_to_inference",
+                ]
+            )
+            self._out_keys = keys
+        return self._out_keys
+
+    @out_keys.setter
+    def out_keys(self, values):
+        self._out_keys = values
+
+    def _forward_value_estimator_keys(self, **kwargs) -> None:
+        # No value estimator in GRPO; simply refresh input keys
+        self._set_in_keys()
 
     def _get_cur_log_prob(self, tensordict):
         """Override to use LLM-specific distribution with explicit masking strategy.
@@ -238,12 +291,16 @@ def forward(self, tensordict: TensorDictBase) -> GRPOLossOutput:
         # - We may not have the tokens yet. If not, we will use the tokenizer of the actor to tokenize the text.
         #   We default to history rather than text because the history will account for multiturn, or multimodal inputs.
         if self.tensor_keys.action not in tensordict:
-            raise ValueError
+            raise ValueError(f"Action key {self.tensor_keys.action} not in tensordict.")
         tensordict = tensordict.copy()
         advantage = tensordict.get(
             self.tensor_keys.advantage, None, as_padded_tensor=True
         )
+        if advantage is None:
+            raise ValueError(
+                f"Advantage key {self.tensor_keys.advantage} not in tensordict."
+            )
         log_weight, dist, kl_approx = self._log_weight(
             tensordict, adv_shape=advantage.shape[:-1]
         )
@@ -281,11 +338,6 @@ def forward(self, tensordict: TensorDictBase) -> GRPOLossOutput:
             entropy = _sum_td_features(entropy)
             td_out.set("entropy", entropy.detach().mean())  # for logging
             td_out.set("loss_entropy", -self.entropy_coeff * entropy)
-        if self._has_critic:
-            loss_critic, value_clip_fraction = self.loss_critic(tensordict)
-            td_out.set("loss_critic", loss_critic)
-            if value_clip_fraction is not None:
-                td_out.set("value_clip_fraction", value_clip_fraction)
         td_out.set("ESS", _reduce(ess / batch, self.reduction))
         td_out = td_out.named_apply(
@@ -323,10 +375,45 @@ def forward(self, tensordict: TensorDictBase) -> GRPOLossOutput:
             del tensordict["_cur_log_prob"]
         return GRPOLossOutput.from_tensordict(td_out)
 
+    def _get_entropy(
+        self, dist: d.Distribution, adv_shape: torch.Size
+    ) -> torch.Tensor | TensorDict:
+        try:
+            entropy = dist.entropy()
+            if not entropy.isfinite().all():
+                del entropy
+                if VERBOSE:
+                    torchrl_logger.info(
+                        "Entropy is not finite. Using Monte Carlo sampling."
+                    )
+                raise NotImplementedError
+        except NotImplementedError:
+            if VERBOSE:
+                torchrl_logger.warning(
+                    f"Entropy not implemented for {type(dist)} or is not finite. Using Monte Carlo sampling."
+                )
+            if getattr(dist, "has_rsample", False):
+                x = dist.rsample((self.samples_mc_entropy,))
+            else:
+                x = dist.sample((self.samples_mc_entropy,))
+            with set_composite_lp_aggregate(False) if isinstance(
+                dist, CompositeDistribution
+            ) else contextlib.nullcontext():
+                log_prob = dist.log_prob(x)
+                if is_tensor_collection(log_prob):
+                    if isinstance(self.tensor_keys.sample_log_prob, NestedKey):
+                        log_prob = log_prob.get(self.tensor_keys.sample_log_prob)
+                    else:
+                        log_prob = log_prob.select(*self.tensor_keys.sample_log_prob)
+            entropy = -log_prob.mean(0)
+        if is_tensor_collection(entropy) and entropy.batch_size != adv_shape:
+            entropy.batch_size = adv_shape
+        return entropy.unsqueeze(-1)
+
     def _kl_to_ref(
         self,
         tensordict: TensorDictBase,
-        key: NestedKey = ("next", "ref_log_prob"),
+        key: NestedKey = ("next", "ref_log_probs"),
         ref_log_prob: torch.Tensor | None = None,
         coeff: float | None = None,
         mask: torch.Tensor | None = None,
diff --git a/torchrl/objectives/llm/sft.py b/torchrl/objectives/llm/sft.py
index 5c70ee51ce7..1c9f3582177 100644
--- a/torchrl/objectives/llm/sft.py
+++ b/torchrl/objectives/llm/sft.py
@@ -126,7 +126,7 @@ class SFTLoss(LossModule):
     .. note:: The input tensordict is expected to contain the following keys by default:
         - ``("next", "history")``: The chat history
-        - ``("next", "ref_log_prob")`` (optional): Reference model log probabilities, required if kl_to_ref_coeff is set
+        - ``("next", "ref_log_probs")`` (optional): Reference model log probabilities, required if kl_to_ref_coeff is set
 
         These keys can be customized using the ``set_keys()`` method.
 
@@ -215,7 +215,7 @@ class SFTLoss(LossModule):
         >>>
         >>> # Apply the transform to get reference log probabilities
         >>> data = transform(data)
-        >>> assert "ref_log_prob" in data["next"].keys()
+        >>> assert "ref_log_probs" in data["next"].keys()
         >>>
         >>> # Use with SFTLoss for KL regularization
         >>> loss = SFTLoss(
@@ -244,7 +244,7 @@ class _AcceptedKeys:
         history (NestedKey): The input tensordict key where the chat history is expected.
             Defaults to ``("next", "history")``.
         ref_log_prob (NestedKey): The input tensordict key where the reference model log probabilities are expected.
-            Only used when kl_to_ref_coeff is set. Defaults to ``("next", "ref_log_prob")``.
+            Only used when kl_to_ref_coeff is set. Defaults to ``("next", "ref_log_probs")``.
         log_probs (NestedKey): The output tensordict key where the model's log probabilities will be written.
             Defaults to ``"log_probs"``.
     """
@@ -447,7 +447,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
             ref_log_probs = tensordict.get(self.tensor_keys.ref_log_prob, as_list=True)
             if ref_log_probs is None:
                 raise ValueError(
-                    f"Reference log probs not found at {self.tensor_keys.ref_log_prob=} in tensordict with keys {tensordict.keys()} but loss_function is 'minor_sft'"
+                    f"Reference log probs not found at {self.tensor_keys.ref_log_prob=} in tensordict with keys {tensordict.keys(True, True)} but loss_function is 'minor_sft'"
                 )
             # we need to re-sum ref_log_probs as they are not summed per-sequence