5 changes: 5 additions & 0 deletions pyproject.toml
@@ -145,3 +145,8 @@ exclude = [

[tool.usort]
first_party_detection = false

[project.entry-points."vllm.general_plugins"]
# Ensure FP32 overrides are registered in all vLLM processes (main, workers, and
# the registry subprocess) before resolving model classes.
fp32_overrides = "torchrl.modules.llm.backends.vllm_plugin:register_fp32_overrides"
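
For context, entries in `vllm.general_plugins` are zero-argument callables that vLLM invokes once in every process it starts, which is why registration done here is visible before model classes are resolved. The body of `register_fp32_overrides` is not shown in this diff; the following is only a minimal sketch of the shape such a plugin takes, assuming it gates on the `VLLM_ENABLE_FP32_OUTPUT` environment variable set in `grpo_utils.py` below:

```python
# Hypothetical sketch of torchrl/modules/llm/backends/vllm_plugin.py; the
# real implementation may differ. vLLM calls this entry point once per
# process (main, workers, and the registry subprocess) before resolving
# model classes.
import os


def register_fp32_overrides() -> None:
    # Only act when the trainer has opted in via the environment variable.
    if os.environ.get("VLLM_ENABLE_FP32_OUTPUT") != "1":
        return
    # Here the plugin would register model-class overrides with vLLM's
    # model registry so the final output layer computes logits in FP32.
```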
12 changes: 12 additions & 0 deletions sota-implementations/grpo/grpo_utils.py
@@ -259,6 +259,18 @@ def get_inference_model(
f"Setting VLLM_ATTENTION_BACKEND={vllm_backend} (from config: {attn_impl})"
)

# Handle FP32 output configuration
if hasattr(cfg.inference_model, "enable_fp32_output"):
enable_fp32 = cfg.inference_model.enable_fp32_output
if enable_fp32:
os.environ["VLLM_ENABLE_FP32_OUTPUT"] = "1"
torchrl_logger.info(
"Enabled FP32 output for vLLM (VLLM_ENABLE_FP32_OUTPUT=1). "
"This will use FP32 for the final output layer if the model supports it."
)
# Add to inference params so it gets passed to AsyncVLLM
inference_params["enable_fp32_output"] = enable_fp32

# Add other common vLLM parameters from config if present
optional_vllm_params = [
"max_model_len",
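
To trace the data flow of this branch end to end, here is a self-contained illustration; `SimpleNamespace` stands in for the Hydra config node and is not how the real scripts build `cfg`:

```python
# Illustrative stand-in for the cfg.inference_model node read above.
import os
from types import SimpleNamespace

inference_params = {}
cfg = SimpleNamespace(inference_model=SimpleNamespace(enable_fp32_output=True))

# Same logic as the diff: gate on the attribute so older configs without the
# key are unaffected, export the env var for the vLLM plugin, and forward the
# flag to AsyncVLLM via inference_params.
if hasattr(cfg.inference_model, "enable_fp32_output"):
    enable_fp32 = cfg.inference_model.enable_fp32_output
    if enable_fp32:
        os.environ["VLLM_ENABLE_FP32_OUTPUT"] = "1"
    inference_params["enable_fp32_output"] = enable_fp32

assert inference_params == {"enable_fp32_output": True}
```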
77 changes: 77 additions & 0 deletions test/llm/conftest.py
@@ -0,0 +1,77 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""Shared test fixtures and mock infrastructure for LLM tests."""
from __future__ import annotations

import pytest
import torch


class MockTransformerConfig:
    """Mock config to mimic a transformers model config."""

    def __init__(self, vocab_size: int, max_position_embeddings: int = 2048):
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = vocab_size  # For simplicity


class MockTransformerOutput:
    """Mock output that mimics a transformers model output with dict-like access."""

    def __init__(self, logits):
        self.logits = logits

    def __getitem__(self, key):
        """Allow dict-like access for compatibility."""
        if key == "logits":
            return self.logits
        raise KeyError(f"Key {key} not found in model output")


class MockTransformerModel(torch.nn.Module):
    """Mock transformer model that mimics the structure of HuggingFace models."""

    def __init__(self, vocab_size: int, device: torch.device | str | int = "cpu"):
        super().__init__()
        device = torch.device(device)
        self.config = MockTransformerConfig(vocab_size)
        # Simple embedding layer that maps tokens to logits
        self.embedding = torch.nn.Embedding(vocab_size, vocab_size, device=device)
        self.device = device

    def forward(self, input_ids, attention_mask=None, **kwargs):
        """Forward pass that returns logits in the expected format."""
        # Get embeddings (which we'll use as logits for simplicity)
        logits = self.embedding(input_ids % self.config.vocab_size)
        # Return an output object similar to transformers model outputs
        return MockTransformerOutput(logits)

    def get_tokenizer(self):
        from transformers import AutoTokenizer

        return AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B")


@pytest.fixture
def mock_transformer_model():
    """Fixture that provides a mock transformer model factory."""

    def _make_model(
        vocab_size: int = 1024, device: torch.device | str | int = "cpu"
    ) -> MockTransformerModel:
        """Make a mock transformer model."""
        device = torch.device(device)
        return MockTransformerModel(vocab_size, device)

    return _make_model


@pytest.fixture
def mock_tokenizer():
    """Fixture that provides a mock tokenizer."""
    from transformers import AutoTokenizer

    return AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B")
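
As a usage sketch (illustrative, not part of the PR), a test in `test/llm/` can request the factory fixture and run a forward pass:

```python
# Hypothetical test file, e.g. test/llm/test_mock_model.py (not in this diff).
import torch


def test_mock_model_output(mock_transformer_model):
    model = mock_transformer_model(vocab_size=32, device="cpu")
    input_ids = torch.randint(0, 32, (2, 7))
    out = model(input_ids)
    # Logits are exposed both as an attribute and via dict-style access.
    assert out.logits.shape == (2, 7, 32)
    assert out["logits"] is out.logits
```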