diff --git a/docs/source/reference/llms.rst b/docs/source/reference/llms.rst
index 3a04bc5c95c..4ab9409d62f 100644
--- a/docs/source/reference/llms.rst
+++ b/docs/source/reference/llms.rst
@@ -1155,11 +1155,8 @@ Objectives
 LLM post-training requires specialized loss functions that are adapted to the unique
 characteristics of language models.
 
-GRPO
-~~~~
-
-The :class:`~torchrl.objectives.llm.GRPOLoss` class is a thin wrapper around the :class:`~torchrl.objectives.PPOLoss` class
-that codes the LLM-specific functionalities.
+GRPO, DAPO, CISPO
+^^^^^^^^^^^^^^^^^
 
 .. currentmodule:: torchrl.objectives.llm
 
@@ -1167,8 +1164,13 @@ that codes the LLM-specific functionalities.
     :toctree: generated/
     :template: rl_template.rst
 
+    LLMLossOutput
     GRPOLoss
     GRPOLossOutput
+    CISPO
+    CISPOLossOutput
+    DAPO
+    DAPOLossOutput
     MCAdvantage
 
 SFT
diff --git a/test/llm/test_objectives.py b/test/llm/test_objectives.py
index 3c09a252ea8..e1cd3a61eb8 100644
--- a/test/llm/test_objectives.py
+++ b/test/llm/test_objectives.py
@@ -16,7 +16,13 @@
 from torchrl.envs.llm.transforms.kl import RetrieveLogProb
 from torchrl.modules.llm import TransformersWrapper, vLLMWrapper
 from torchrl.modules.llm.policies.common import ChatHistory, Masks, Text, Tokens
-from torchrl.objectives.llm.grpo import GRPOLoss, MCAdvantage
+from torchrl.objectives.llm.grpo import (
+    CISPO,
+    CISPOLossOutput,
+    GRPOLoss,
+    GRPOLossOutput,
+    MCAdvantage,
+)
 from torchrl.objectives.llm.sft import SFTLoss
 
 _has_transformers = importlib.util.find_spec("transformers") is not None
@@ -203,7 +209,6 @@ def test_grpo(self, mock_transformer_model, dapo):
         loss_vals = loss_fn(data)
 
         # Assertions: Check output type and structure
-        from torchrl.objectives.llm.grpo import GRPOLossOutput
 
         assert isinstance(
             loss_vals, GRPOLossOutput
@@ -240,6 +245,68 @@ def test_grpo(self, mock_transformer_model, dapo):
             0 <= loss_vals.clip_fraction <= 1
         ), f"clip_fraction out of range: {loss_vals.clip_fraction}"
 
+    def test_cispo(self, mock_transformer_model):
+        """Test CISPO loss computation with mock models."""
+        vocab_size = 1024
+        device = torch.device("cpu")
+        eps = 0.20
+
+        # Create mock model and wrap it
+        model = mock_transformer_model(vocab_size=vocab_size, device=device)
+        actor_network = TransformersWrapper(
+            model,
+            generate=False,
+            pad_output=True,
+            input_mode="history",
+        )
+
+        # Create loss module
+
+        loss_fn = CISPO(actor_network, clip_epsilon=eps)
+
+        # Create fake data
+        data = _mock_data_grpo(vocab_size=vocab_size, device=device)
+
+        # Compute loss
+        loss_vals = loss_fn(data)
+
+        # Assertions: Check output type and structure
+
+        assert isinstance(
+            loss_vals, CISPOLossOutput
+        ), f"Expected CISPOLossOutput, got {type(loss_vals)}"
+
+        # Check that all expected keys are present (same as GRPO)
+        assert hasattr(loss_vals, "loss_objective"), "Missing loss_objective"
+        assert hasattr(loss_vals, "clip_fraction"), "Missing clip_fraction"
+        assert hasattr(loss_vals, "kl_approx"), "Missing kl_approx"
+        assert hasattr(loss_vals, "ESS"), "Missing ESS"
+        assert hasattr(loss_vals, "entropy"), "Missing entropy"
+        assert hasattr(loss_vals, "loss_entropy"), "Missing loss_entropy"
+
+        # Check tensor shapes (all losses should be scalars after reduction)
+        assert (
+            loss_vals.loss_objective.shape == ()
+        ), f"loss_objective should be scalar, got {loss_vals.loss_objective.shape}"
+        assert (
+            loss_vals.clip_fraction.shape == ()
+        ), f"clip_fraction should be scalar, got {loss_vals.clip_fraction.shape}"
+        assert (
+            loss_vals.kl_approx.shape == ()
+        ), f"kl_approx should be scalar, got {loss_vals.kl_approx.shape}"
+        assert (
+            loss_vals.ESS.shape == ()
+        ), f"ESS should be scalar, got {loss_vals.ESS.shape}"
+
+        # Check that losses are finite
+        assert torch.isfinite(loss_vals.loss_objective), "loss_objective is not finite"
+        assert torch.isfinite(loss_vals.ESS), "ESS is not finite"
+
+        # Check that clip_fraction is in valid range [0, 1]
+        assert (
+            0 <= loss_vals.clip_fraction <= 1
+        ), f"clip_fraction out of range: {loss_vals.clip_fraction}"
+
 
 class TestSFT:
     @pytest.fixture(scope="class")
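For orientation, a minimal usage sketch mirroring the new test above, outside the pytest fixtures. It is illustrative only, not part of the patch: model stands for any Hugging Face causal LM and batch for a TensorDict laid out like the one built by _mock_data_grpo (tokens, log-probs, masks and a per-token advantage).

    from torchrl.modules.llm import TransformersWrapper
    from torchrl.objectives.llm import CISPO, CISPOLossOutput

    # Wrap the model for log-prob evaluation (not generation), as in the test.
    policy = TransformersWrapper(
        model,  # assumption: a transformers causal LM instance
        generate=False,
        pad_output=True,
        input_mode="history",
    )
    loss_fn = CISPO(policy, clip_epsilon=0.2)

    out = loss_fn(batch)  # assumption: batch follows the GRPO data layout
    assert isinstance(out, CISPOLossOutput)
    out.loss_objective.backward()  # other loss_* terms (entropy, KL) can be added before backward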
diff --git a/torchrl/objectives/llm/__init__.py b/torchrl/objectives/llm/__init__.py
index eb3920845d5..0a4cc9fb65f 100644
--- a/torchrl/objectives/llm/__init__.py
+++ b/torchrl/objectives/llm/__init__.py
@@ -4,7 +4,27 @@
 # LICENSE file in the root directory of this source tree.
 from __future__ import annotations
 
-from .grpo import GRPOLoss, GRPOLossOutput, MCAdvantage
+from .grpo import (
+    CISPO,
+    CISPOLossOutput,
+    DAPO,
+    DAPOLossOutput,
+    GRPOLoss,
+    GRPOLossOutput,
+    LLMLossOutput,
+    MCAdvantage,
+)
 from .sft import SFTLoss, SFTLossOutput
 
-__all__ = ["GRPOLoss", "GRPOLossOutput", "MCAdvantage", "SFTLoss", "SFTLossOutput"]
+__all__ = [
+    "CISPO",
+    "CISPOLossOutput",
+    "DAPO",
+    "DAPOLossOutput",
+    "GRPOLoss",
+    "GRPOLossOutput",
+    "LLMLossOutput",
+    "MCAdvantage",
+    "SFTLoss",
+    "SFTLossOutput",
+]
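Downstream code can rely on the widened public surface and the shared base class. A short sketch; the log_metrics helper is hypothetical, the imported names are exactly those added to __all__ above.

    from torchrl.objectives.llm import CISPO, DAPO, GRPOLoss, LLMLossOutput

    # GRPO, DAPO and CISPO all return subclasses of LLMLossOutput with the same
    # fields, so generic logging/monitoring code can be written once.
    def log_metrics(output: LLMLossOutput) -> dict[str, float]:
        return {
            "loss_objective": output.loss_objective.item(),
            "clip_fraction": output.clip_fraction.item(),
            "ESS": output.ESS.item(),
            "kl_approx": output.kl_approx.item(),
        }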
diff --git a/torchrl/objectives/llm/grpo.py b/torchrl/objectives/llm/grpo.py
index 2cdab05be9f..e11f04509e9 100644
--- a/torchrl/objectives/llm/grpo.py
+++ b/torchrl/objectives/llm/grpo.py
@@ -8,7 +8,7 @@
 from collections import defaultdict, deque
 from dataclasses import dataclass
-from typing import Literal
+from typing import Literal, TypeVar
 
 import torch
 from tensordict import (
@@ -33,8 +33,12 @@
 from torchrl.objectives.utils import _reduce, _sum_td_features
 
 
-class GRPOLossOutput(TensorClass["nocast"]):
-    """GRPO Loss Output."""
+class LLMLossOutput(TensorClass["nocast"]):
+    """Base class for LLM loss outputs.
+
+    This base class defines the common structure for all LLM-based policy optimization
+    loss outputs (GRPO, DAPO, CISPO, etc.).
+    """
 
     loss_objective: torch.Tensor
     clip_fraction: torch.Tensor
@@ -48,6 +52,21 @@ class GRPOLossOutput(TensorClass["nocast"]):
     kl_to_inference: torch.Tensor | None = None
 
 
+LLMOutputType = TypeVar("LLMOutputType", bound=LLMLossOutput)
+
+
+class GRPOLossOutput(LLMLossOutput):
+    """GRPO Loss Output."""
+
+
+class DAPOLossOutput(LLMLossOutput):
+    """DAPO Loss Output."""
+
+
+class CISPOLossOutput(LLMLossOutput):
+    """CISPO Loss Output."""
+
+
 class GRPOLoss(LossModule):
     """GRPO loss.
 
@@ -123,6 +142,7 @@ class GRPOLoss(LossModule):
     """
 
     actor_network: LLMWrapperBase
+    output_type: type[LLMLossOutput] = GRPOLossOutput
 
     @dataclass
     class _AcceptedKeys(LossModule._AcceptedKeys):
@@ -137,6 +157,33 @@ class _AcceptedKeys(LossModule._AcceptedKeys):
         sample_log_prob: NestedKey = ("log_probs", "full")
         ref_log_probs: NestedKey = ("next", "ref_log_probs", "full")
 
+    @property
+    def tensor_keys(self) -> _AcceptedKeys:
+        """Access the tensordict key configuration for this loss.
+
+        This property provides access to the configurable keys used by the loss module
+        to read tensors from input TensorDicts. These keys include:
+
+        - ``advantage``: key for the advantage values
+        - ``action``: key for the action tokens (default: ``("tokens", "full")``)
+        - ``sample_log_prob``: key for the log probabilities recorded when the data was collected, i.e. from the sampling policy (default: ``("log_probs", "full")``)
+        - ``ref_log_probs``: key for the reference policy log probabilities (default: ``("next", "ref_log_probs", "full")``)
+
+        To modify these keys, use the :meth:`~.set_keys` method.
+
+        Examples:
+            >>> loss = GRPOLoss(actor_network)
+            >>> # Access current keys
+            >>> print(loss.tensor_keys.advantage)  # "advantage"
+            >>> # Modify keys
+            >>> loss.set_keys(advantage="my_advantage_key")
+            >>> print(loss.tensor_keys.advantage)  # "my_advantage_key"
+
+        Returns:
+            An instance of _AcceptedKeys containing all configurable tensordict keys.
+        """
+        return self._tensor_keys
+
     def __init__(
         self,
         actor_network: LLMWrapperBase | None = None,
@@ -316,7 +363,7 @@ def _get_cur_log_prob(self, tensordict):
         )
         return log_prob, dist, False
 
-    def forward(self, tensordict: TensorDictBase) -> GRPOLossOutput:
+    def forward(self, tensordict: TensorDictBase) -> LLMOutputType:
         # Some sanity checks and housekeeping:
         # - We may not have the tokens yet. If not, we will use the tokenizer of the actor to tokenize the text.
         #   We default to history rather than text because the history will account for multiturn, or multimodal inputs.
@@ -348,16 +395,10 @@ def forward(self, tensordict: TensorDictBase) -> GRPOLossOutput:
             raise ValueError(
                 f"advantage and log_weight must have the same number of dimensions, got {advantage.ndim=} and {log_weight.ndim=}"
             )
-        gain1 = log_weight.exp() * advantage
-
-        log_weight_clip = log_weight.clamp(*self._clip_bounds)
-        clip_fraction = (log_weight_clip != log_weight).to(log_weight.dtype).mean()
-        ratio = log_weight_clip.exp()
-        gain2 = ratio * advantage
-
-        # Token-level objective: compute min over clipped/unclipped at the token level
-        gain = torch.stack([gain1, gain2], -1).min(dim=-1).values
-        td_out = TensorDict({"loss_objective": -gain})
+        loss_objective, clip_fraction = self._compute_policy_objective(
+            log_weight, advantage
+        )
+        td_out = TensorDict({"loss_objective": loss_objective})
         td_out.set("clip_fraction", clip_fraction)
         td_out.set("kl_approx", kl_approx.detach().mean())  # for logging
 
@@ -404,7 +445,22 @@ def forward(self, tensordict: TensorDictBase) -> GRPOLossOutput:
                 td_out["loss_kl_to_inference"] = loss_kl
                 td_out["kl_to_inference"] = kl_penalty.detach()
         del tensordict["_cur_log_prob"]
-        return GRPOLossOutput.from_tensordict(td_out)
+        return self.output_type.from_tensordict(td_out)
+
+    def _compute_policy_objective(
+        self, log_weight: torch.Tensor, advantage: torch.Tensor
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        """Default GRPO objective: PPO-style min between unclipped and clipped ratios.
+
+        Returns (loss_objective, clip_fraction).
+        """
+        gain1 = log_weight.exp() * advantage
+        log_weight_clip = log_weight.clamp(*self._clip_bounds)
+        clip_fraction = (log_weight_clip != log_weight).to(log_weight.dtype).mean()
+        ratio = log_weight_clip.exp()
+        gain2 = ratio * advantage
+        gain = torch.stack([gain1, gain2], -1).min(dim=-1).values
+        return -gain, clip_fraction
 
     def _get_entropy(
         self, dist: d.Distribution, adv_shape: torch.Size
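To make the refactor concrete, the sketch below evaluates both variants of _compute_policy_objective on dummy tensors. It is illustrative, not part of the patch, and assumes (as in PPO-style clipping) that _clip_bounds holds log-space bounds of roughly (log1p(-eps), log1p(+eps)); shapes and values are arbitrary.

    import math
    import torch

    eps = 0.20
    clip_bounds = (math.log1p(-eps), math.log1p(eps))  # assumption: log-space bounds

    log_weight = 0.3 * torch.randn(4, 16, 1)  # log(pi_theta / pi_sampling), per token
    advantage = torch.randn(4, 16, 1)

    # GRPO (default): PPO-style min over unclipped and clipped gains.
    gain1 = log_weight.exp() * advantage
    gain2 = log_weight.clamp(*clip_bounds).exp() * advantage
    grpo_loss = -torch.stack([gain1, gain2], -1).min(dim=-1).values

    # CISPO override: the clipped importance weight multiplies the advantage directly,
    # with no min against the unclipped term.
    cispo_loss = -(log_weight.clamp(*clip_bounds).exp() * advantage)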
@@ -548,10 +604,12 @@ def _log_weight(
 class DAPO(GRPOLoss):
     """DAPO (Clip-Higher over GRPO).
 
-    Validates asymmetric clip thresholds; recommended (0.20, 0.28), see Eq. (10) in DAPO
-    [arXiv](https://arxiv.org/html/2503.14476).
+    Validates asymmetric clip thresholds; recommended (0.20, 0.28), see Eq. (10) in
+    the `DAPO <https://arxiv.org/html/2503.14476>`_ paper.
     """
 
+    output_type: type[LLMLossOutput] = DAPOLossOutput
+
     def __init__(
         self,
         tensordict: TensorDictBase,
@@ -594,6 +652,29 @@ def __init__(
         return coeff * kl_penalty, kl_penalty
 
 
+class CISPO(GRPOLoss):
+    """CISPO (Clipped Importance Sampling Policy Optimization).
+
+    Inherits the GRPO pipeline (masking, ESS, entropy, optional KL penalties) but
+    replaces the PPO-style min with a clipped-importance objective:
+        loss = - clip(weight, [1 - eps_low, 1 + eps_high]) * advantage
+
+    See the `MiniMax-M1 (CISPO) `_ paper.
+    """
+
+    output_type: type[LLMLossOutput] = CISPOLossOutput
+
+    def _compute_policy_objective(
+        self, log_weight: torch.Tensor, advantage: torch.Tensor
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        # CISPO: use clipped importance weights directly
+        log_weight_clip = log_weight.clamp(*self._clip_bounds)
+        clip_fraction = (log_weight_clip != log_weight).to(log_weight.dtype).mean()
+        ratio = log_weight_clip.exp()
+        gain = ratio * advantage
+        return -gain, clip_fraction
+
+
 class MCAdvantage(Transform):
     """Monte-Carlo advantage computation engine.
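The same two hooks used by DAPO and CISPO (output_type plus _compute_policy_objective) are what a third-party variant would override. A hypothetical sketch, not part of the patch; the subclass and its objective are illustrative only.

    import torch
    from torchrl.objectives.llm import GRPOLoss, LLMLossOutput


    class MyLossOutput(LLMLossOutput):
        """Output container for the hypothetical variant below."""


    class MyClippedLoss(GRPOLoss):
        """Hypothetical variant: keep the unclipped gain, log clip statistics only."""

        output_type: type[LLMLossOutput] = MyLossOutput

        def _compute_policy_objective(
            self, log_weight: torch.Tensor, advantage: torch.Tensor
        ) -> tuple[torch.Tensor, torch.Tensor]:
            log_weight_clip = log_weight.clamp(*self._clip_bounds)
            clip_fraction = (log_weight_clip != log_weight).to(log_weight.dtype).mean()
            # Unclipped REINFORCE-style gain; clipping only feeds the logged statistic.
            return -(log_weight.exp() * advantage), clip_fraction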