12 changes: 7 additions & 5 deletions docs/source/reference/llms.rst
@@ -1155,20 +1155,22 @@ Objectives

LLM post-training requires specialized loss functions that are adapted to the unique characteristics of language models.

GRPO
~~~~

The :class:`~torchrl.objectives.llm.GRPOLoss` class is a thin wrapper around the :class:`~torchrl.objectives.PPOLoss` class
that implements the LLM-specific functionality.
GRPO, DAPO, CISPO
^^^^^^^^^^^^^^^^^

.. currentmodule:: torchrl.objectives.llm

.. autosummary::
:toctree: generated/
:template: rl_template.rst

LLMLossOutput
GRPOLoss
GRPOLossOutput
    CISPO
CISPOLossOutput
DAPO
DAPOLossOutput
MCAdvantage
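
A minimal usage sketch (following the patterns in ``test/llm/test_objectives.py``; here
``model`` stands for a Hugging Face model and ``data`` for a GRPO-formatted batch):

.. code-block:: python

    from torchrl.modules.llm import TransformersWrapper
    from torchrl.objectives.llm import CISPO, GRPOLoss

    # Wrap the model so the loss can recompute per-token log-probabilities.
    policy = TransformersWrapper(
        model,
        generate=False,
        pad_output=True,
        input_mode="history",
    )

    # GRPO applies the PPO-style clipped min; CISPO uses the clipped importance
    # weights directly; DAPO is the clip-higher variant with asymmetric
    # thresholds (recommended (0.20, 0.28)).
    loss_fn = GRPOLoss(policy, clip_epsilon=0.2)
    # loss_fn = CISPO(policy, clip_epsilon=0.2)

    loss_vals = loss_fn(data)
    loss_vals.loss_objective.backward()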

SFT
71 changes: 69 additions & 2 deletions test/llm/test_objectives.py
@@ -16,7 +16,13 @@
from torchrl.envs.llm.transforms.kl import RetrieveLogProb
from torchrl.modules.llm import TransformersWrapper, vLLMWrapper
from torchrl.modules.llm.policies.common import ChatHistory, Masks, Text, Tokens
from torchrl.objectives.llm.grpo import GRPOLoss, MCAdvantage
from torchrl.objectives.llm.grpo import (
CISPO,
CISPOLossOutput,
GRPOLoss,
GRPOLossOutput,
MCAdvantage,
)
from torchrl.objectives.llm.sft import SFTLoss

_has_transformers = importlib.util.find_spec("transformers") is not None
@@ -203,7 +209,6 @@ def test_grpo(self, mock_transformer_model, dapo):
loss_vals = loss_fn(data)

# Assertions: Check output type and structure
from torchrl.objectives.llm.grpo import GRPOLossOutput

assert isinstance(
loss_vals, GRPOLossOutput
@@ -240,6 +245,68 @@ def test_grpo(self, mock_transformer_model, dapo):
0 <= loss_vals.clip_fraction <= 1
), f"clip_fraction out of range: {loss_vals.clip_fraction}"

def test_cispo(self, mock_transformer_model):
"""Test CISPO loss computation with mock models."""
vocab_size = 1024
device = torch.device("cpu")
eps = 0.20

# Create mock model and wrap it
model = mock_transformer_model(vocab_size=vocab_size, device=device)
actor_network = TransformersWrapper(
model,
generate=False,
pad_output=True,
input_mode="history",
)

# Create loss module

loss_fn = CISPO(actor_network, clip_epsilon=eps)

# Create fake data
data = _mock_data_grpo(vocab_size=vocab_size, device=device)

# Compute loss
loss_vals = loss_fn(data)

# Assertions: Check output type and structure

assert isinstance(
loss_vals, CISPOLossOutput
), f"Expected CISPOLossOutput, got {type(loss_vals)}"

# Check that all expected keys are present (same as GRPO)
assert hasattr(loss_vals, "loss_objective"), "Missing loss_objective"
assert hasattr(loss_vals, "clip_fraction"), "Missing clip_fraction"
assert hasattr(loss_vals, "kl_approx"), "Missing kl_approx"
assert hasattr(loss_vals, "ESS"), "Missing ESS"
assert hasattr(loss_vals, "entropy"), "Missing entropy"
assert hasattr(loss_vals, "loss_entropy"), "Missing loss_entropy"

# Check tensor shapes (all losses should be scalars after reduction)
assert (
loss_vals.loss_objective.shape == ()
), f"loss_objective should be scalar, got {loss_vals.loss_objective.shape}"
assert (
loss_vals.clip_fraction.shape == ()
), f"clip_fraction should be scalar, got {loss_vals.clip_fraction.shape}"
assert (
loss_vals.kl_approx.shape == ()
), f"kl_approx should be scalar, got {loss_vals.kl_approx.shape}"
assert (
loss_vals.ESS.shape == ()
), f"ESS should be scalar, got {loss_vals.ESS.shape}"

# Check that losses are finite
assert torch.isfinite(loss_vals.loss_objective), "loss_objective is not finite"
assert torch.isfinite(loss_vals.ESS), "ESS is not finite"

# Check that clip_fraction is in valid range [0, 1]
assert (
0 <= loss_vals.clip_fraction <= 1
), f"clip_fraction out of range: {loss_vals.clip_fraction}"


class TestSFT:
@pytest.fixture(scope="class")
24 changes: 22 additions & 2 deletions torchrl/objectives/llm/__init__.py
@@ -4,7 +4,27 @@
# LICENSE file in the root directory of this source tree.
from __future__ import annotations

from .grpo import GRPOLoss, GRPOLossOutput, MCAdvantage
from .grpo import (
CISPO,
CISPOLossOutput,
DAPO,
DAPOLossOutput,
GRPOLoss,
GRPOLossOutput,
LLMLossOutput,
MCAdvantage,
)
from .sft import SFTLoss, SFTLossOutput

__all__ = ["GRPOLoss", "GRPOLossOutput", "MCAdvantage", "SFTLoss", "SFTLossOutput"]
__all__ = [
"CISPO",
"CISPOLossOutput",
"DAPO",
"DAPOLossOutput",
"GRPOLoss",
"GRPOLossOutput",
"LLMLossOutput",
"MCAdvantage",
"SFTLoss",
"SFTLossOutput",
]
115 changes: 98 additions & 17 deletions torchrl/objectives/llm/grpo.py
@@ -8,7 +8,7 @@

from collections import defaultdict, deque
from dataclasses import dataclass
from typing import Literal
from typing import Literal, TypeVar

import torch
from tensordict import (
@@ -33,8 +33,12 @@
from torchrl.objectives.utils import _reduce, _sum_td_features


class GRPOLossOutput(TensorClass["nocast"]):
"""GRPO Loss Output."""
class LLMLossOutput(TensorClass["nocast"]):
"""Base class for LLM loss outputs.

This base class defines the common structure for all LLM-based policy optimization
loss outputs (GRPO, DAPO, CISPO, etc.).
"""

loss_objective: torch.Tensor
clip_fraction: torch.Tensor
@@ -48,6 +52,21 @@ class GRPOLossOutput(TensorClass["nocast"]):
kl_to_inference: torch.Tensor | None = None


LLMOutputType = TypeVar("LLMOutputType", bound=LLMLossOutput)


class GRPOLossOutput(LLMLossOutput):
"""GRPO Loss Output."""


class DAPOLossOutput(LLMLossOutput):
"""DAPO Loss Output."""


class CISPOLossOutput(LLMLossOutput):
"""CISPO Loss Output."""


class GRPOLoss(LossModule):
"""GRPO loss.

@@ -123,6 +142,7 @@ class GRPOLoss(LossModule):
"""

actor_network: LLMWrapperBase
output_type: type[LLMLossOutput] = GRPOLossOutput
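    # Subclasses below (DAPO, CISPO) override ``output_type`` and, where needed,
    # ``_compute_policy_objective`` to implement their variant of the objective.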

@dataclass
class _AcceptedKeys(LossModule._AcceptedKeys):
@@ -137,6 +157,33 @@ class _AcceptedKeys(LossModule._AcceptedKeys):
sample_log_prob: NestedKey = ("log_probs", "full")
ref_log_probs: NestedKey = ("next", "ref_log_probs", "full")

@property
def tensor_keys(self) -> _AcceptedKeys:
"""Access the tensordict key configuration for this loss.

This property provides access to the configurable keys used by the loss module
to read tensors from input TensorDicts. These keys include:

- ``advantage``: key for the advantage values
- ``action``: key for the action tokens (default: ``("tokens", "full")``)
        - ``sample_log_prob``: key for the log probabilities recorded at sampling time by the collection policy (default: ``("log_probs", "full")``)
- ``ref_log_probs``: key for the reference policy log probabilities (default: ``("next", "ref_log_probs", "full")``)

To modify these keys, use the :meth:`~.set_keys` method.

Examples:
>>> loss = GRPOLoss(actor_network)
>>> # Access current keys
>>> print(loss.tensor_keys.advantage) # "advantage"
>>> # Modify keys
>>> loss.set_keys(advantage="my_advantage_key")
>>> print(loss.tensor_keys.advantage) # "my_advantage_key"

Returns:
An instance of _AcceptedKeys containing all configurable tensordict keys.
"""
return self._tensor_keys

def __init__(
self,
actor_network: LLMWrapperBase | None = None,
@@ -316,7 +363,7 @@ def _get_cur_log_prob(self, tensordict):
)
return log_prob, dist, False

def forward(self, tensordict: TensorDictBase) -> GRPOLossOutput:
def forward(self, tensordict: TensorDictBase) -> LLMOutputType:
# Some sanity checks and housekeeping:
# - We may not have the tokens yet. If not, we will use the tokenizer of the actor to tokenize the text.
# We default to history rather than text because the history will account for multiturn, or multimodal inputs.
@@ -348,16 +395,10 @@ def forward(self, tensordict: TensorDictBase) -> GRPOLossOutput:
raise ValueError(
f"advantage and log_weight must have the same number of dimensions, got {advantage.ndim=} and {log_weight.ndim=}"
)
gain1 = log_weight.exp() * advantage

log_weight_clip = log_weight.clamp(*self._clip_bounds)
clip_fraction = (log_weight_clip != log_weight).to(log_weight.dtype).mean()
ratio = log_weight_clip.exp()
gain2 = ratio * advantage

# Token-level objective: compute min over clipped/unclipped at the token level
gain = torch.stack([gain1, gain2], -1).min(dim=-1).values
td_out = TensorDict({"loss_objective": -gain})
loss_objective, clip_fraction = self._compute_policy_objective(
log_weight, advantage
)
td_out = TensorDict({"loss_objective": loss_objective})
td_out.set("clip_fraction", clip_fraction)
td_out.set("kl_approx", kl_approx.detach().mean()) # for logging

@@ -404,7 +445,22 @@ def forward(self, tensordict: TensorDictBase) -> GRPOLossOutput:
td_out["loss_kl_to_inference"] = loss_kl
td_out["kl_to_inference"] = kl_penalty.detach()
del tensordict["_cur_log_prob"]
return GRPOLossOutput.from_tensordict(td_out)
return self.output_type.from_tensordict(td_out)

def _compute_policy_objective(
self, log_weight: torch.Tensor, advantage: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
"""Default GRPO objective: PPO-style min between unclipped and clipped ratios.

Returns (loss_objective, clip_fraction).
"""
gain1 = log_weight.exp() * advantage
log_weight_clip = log_weight.clamp(*self._clip_bounds)
clip_fraction = (log_weight_clip != log_weight).to(log_weight.dtype).mean()
ratio = log_weight_clip.exp()
gain2 = ratio * advantage
gain = torch.stack([gain1, gain2], -1).min(dim=-1).values
return -gain, clip_fraction

def _get_entropy(
self, dist: d.Distribution, adv_shape: torch.Size
@@ -548,10 +604,12 @@ def _log_weight(
class DAPO(GRPOLoss):
"""DAPO (Clip-Higher over GRPO).

Validates asymmetric clip thresholds; recommended (0.20, 0.28), see Eq. (10) in DAPO
[arXiv](https://arxiv.org/html/2503.14476).
Validates asymmetric clip thresholds; recommended (0.20, 0.28), see Eq. (10) in
the `DAPO <https://arxiv.org/html/2503.14476>`_ paper.
"""

output_type: type[LLMLossOutput] = DAPOLossOutput

def __init__(
self,
tensordict: TensorDictBase,
@@ -594,6 +652,29 @@ def __init__(
return coeff * kl_penalty, kl_penalty


class CISPO(GRPOLoss):
"""CISPO (Clipped Importance Sampling Policy Optimization).

Inherits the GRPO pipeline (masking, ESS, entropy, optional KL penalties) but
    replaces the PPO-style min with a clipped-importance objective::

        loss = - clip(weight, [1 - eps_low, 1 + eps_high]) * advantage

See the `MiniMax-M1 (CISPO) <https://arxiv.org/html/2506.13585>`_ paper.
"""

output_type: type[LLMLossOutput] = CISPOLossOutput

def _compute_policy_objective(
self, log_weight: torch.Tensor, advantage: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
# CISPO: use clipped importance weights directly
log_weight_clip = log_weight.clamp(*self._clip_bounds)
clip_fraction = (log_weight_clip != log_weight).to(log_weight.dtype).mean()
ratio = log_weight_clip.exp()
gain = ratio * advantage
return -gain, clip_fraction
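
# A worked single-token contrast between GRPO's PPO-style min and CISPO's
# clipped-importance objective (illustrative sketch; numbers assume a symmetric
# clip_epsilon of 0.2, i.e. log-space bounds [log(0.8), log(1.2)] ~= [-0.223, 0.182]):
#
#   log_weight = 0.5  -> ratio ~= 1.65, clipped ratio = 1.2
#   advantage  = +1.0 -> GRPO:  min(1.65, 1.2)   =  1.2  -> loss = -1.2
#                        CISPO: 1.2               -> loss = -1.2
#   advantage  = -1.0 -> GRPO:  min(-1.65, -1.2) = -1.65 -> loss =  1.65
#                        CISPO: -1.2              -> loss =  1.2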


class MCAdvantage(Transform):
"""Monte-Carlo advantage computation engine.
