Commit 01d2801
[Feature] float32 patch
ghstack-source-id: e5f1a7a Pull-Request: #3219
1 parent 963fdd4 commit 01d2801

8 files changed: +657 -11 lines


pyproject.toml

Lines changed: 5 additions & 0 deletions
@@ -145,3 +145,8 @@ exclude = [
 
 [tool.usort]
 first_party_detection = false
+
+[project.entry-points."vllm.general_plugins"]
+# Ensure FP32 overrides are registered in all vLLM processes (main, workers, and
+# the registry subprocess) before resolving model classes.
+fp32_overrides = "torchrl.modules.llm.backends.vllm_plugin:register_fp32_overrides"
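
The entry point registered above is a zero-argument callable that vLLM invokes in every process it spawns, which is what guarantees the FP32 overrides are in place before model classes are resolved. The actual implementation of register_fp32_overrides lives in torchrl/modules/llm/backends/vllm_plugin.py, one of the changed files not shown in this excerpt; the following is only a hypothetical sketch of the shape such a plugin callable takes, with an assumed gate on the VLLM_ENABLE_FP32_OUTPUT variable set in grpo_utils.py below.

# Hypothetical sketch; the real register_fp32_overrides body is not shown in this diff.
import os


def register_fp32_overrides() -> None:
    # "vllm.general_plugins" entry points are called with no arguments, once per
    # process (main, workers, registry subprocess), so keep this cheap and idempotent.
    if os.environ.get("VLLM_ENABLE_FP32_OUTPUT") != "1":
        # Assumed opt-in gate: grpo_utils.py sets this variable when the config
        # requests enable_fp32_output.
        return
    # ... apply the FP32 output-layer overrides to the relevant model classes here
    # (override mechanism not shown in this excerpt).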

sota-implementations/grpo/grpo_utils.py

Lines changed: 12 additions & 0 deletions
@@ -259,6 +259,18 @@ def get_inference_model(
             f"Setting VLLM_ATTENTION_BACKEND={vllm_backend} (from config: {attn_impl})"
         )
 
+    # Handle FP32 output configuration
+    if hasattr(cfg.inference_model, "enable_fp32_output"):
+        enable_fp32 = cfg.inference_model.enable_fp32_output
+        if enable_fp32:
+            os.environ["VLLM_ENABLE_FP32_OUTPUT"] = "1"
+            torchrl_logger.info(
+                "Enabled FP32 output for vLLM (VLLM_ENABLE_FP32_OUTPUT=1). "
+                "This will use FP32 for the final output layer if the model supports it."
+            )
+        # Add to inference params so it gets passed to AsyncVLLM
+        inference_params["enable_fp32_output"] = enable_fp32
+
     # Add other common vLLM parameters from config if present
     optional_vllm_params = [
         "max_model_len",

test/llm/conftest.py

Lines changed: 77 additions & 0 deletions
@@ -0,0 +1,77 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+"""Shared test fixtures and mock infrastructure for LLM tests."""
+from __future__ import annotations
+
+import pytest
+import torch
+
+
+class MockTransformerConfig:
+    """Mock config to mimic transformers model config."""
+
+    def __init__(self, vocab_size: int, max_position_embeddings: int = 2048):
+        self.vocab_size = vocab_size
+        self.max_position_embeddings = max_position_embeddings
+        self.hidden_size = vocab_size  # For simplicity
+
+
+class MockTransformerOutput:
+    """Mock output that mimics transformers model output with dict-like access."""
+
+    def __init__(self, logits):
+        self.logits = logits
+
+    def __getitem__(self, key):
+        """Allow dict-like access for compatibility."""
+        if key == "logits":
+            return self.logits
+        raise KeyError(f"Key {key} not found in model output")
+
+
+class MockTransformerModel(torch.nn.Module):
+    """Mock transformer model that mimics the structure of HuggingFace models."""
+
+    def __init__(self, vocab_size: int, device: torch.device | str | int = "cpu"):
+        super().__init__()
+        device = torch.device(device)
+        self.config = MockTransformerConfig(vocab_size)
+        # Simple embedding layer that maps tokens to logits
+        self.embedding = torch.nn.Embedding(vocab_size, vocab_size, device=device)
+        self.device = device
+
+    def forward(self, input_ids, attention_mask=None, **kwargs):
+        """Forward pass that returns logits in the expected format."""
+        # Get embeddings (which we'll use as logits for simplicity)
+        logits = self.embedding(input_ids % self.config.vocab_size)
+        # Return output object similar to transformers models
+        return MockTransformerOutput(logits)
+
+    def get_tokenizer(self):
+        from transformers import AutoTokenizer
+
+        return AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B")
+
+
+@pytest.fixture
+def mock_transformer_model():
+    """Fixture that provides a mock transformer model factory."""
+
+    def _make_model(
+        vocab_size: int = 1024, device: torch.device | str | int = "cpu"
+    ) -> MockTransformerModel:
+        """Make a mock transformer model."""
+        device = torch.device(device)
+        return MockTransformerModel(vocab_size, device)
+
+    return _make_model
+
+
+@pytest.fixture
+def mock_tokenizer():
+    """Fixture that provides a mock tokenizer."""
+    from transformers import AutoTokenizer
+
+    return AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B")
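
For illustration, a hypothetical test that consumes the fixtures above (the actual FP32 tests live in the other changed files and are not shown here):

# Hypothetical usage of the conftest.py fixtures above; not part of this commit.
import torch


def test_mock_model_returns_fp32_logits(mock_transformer_model):
    model = mock_transformer_model(vocab_size=32, device="cpu")
    input_ids = torch.randint(0, 32, (2, 5))
    out = model(input_ids)
    # MockTransformerOutput supports dict-style access to the logits.
    assert out["logits"].shape == (2, 5, 32)
    assert out["logits"].dtype == torch.float32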
