diff --git a/pyproject.toml b/pyproject.toml index f156381d153..89e0b6fc33b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -145,3 +145,8 @@ exclude = [ [tool.usort] first_party_detection = false + +[project.entry-points."vllm.general_plugins"] +# Ensure FP32 overrides are registered in all vLLM processes (main, workers, and +# the registry subprocess) before resolving model classes. +fp32_overrides = "torchrl.modules.llm.backends.vllm_plugin:register_fp32_overrides" diff --git a/sota-implementations/grpo/grpo_utils.py b/sota-implementations/grpo/grpo_utils.py index 5b05136fc0b..a550bb82710 100644 --- a/sota-implementations/grpo/grpo_utils.py +++ b/sota-implementations/grpo/grpo_utils.py @@ -259,6 +259,18 @@ def get_inference_model( f"Setting VLLM_ATTENTION_BACKEND={vllm_backend} (from config: {attn_impl})" ) + # Handle FP32 output configuration + if hasattr(cfg.inference_model, "enable_fp32_output"): + enable_fp32 = cfg.inference_model.enable_fp32_output + if enable_fp32: + os.environ["VLLM_ENABLE_FP32_OUTPUT"] = "1" + torchrl_logger.info( + "Enabled FP32 output for vLLM (VLLM_ENABLE_FP32_OUTPUT=1). " + "This will use FP32 for the final output layer if the model supports it." + ) + # Add to inference params so it gets passed to AsyncVLLM + inference_params["enable_fp32_output"] = enable_fp32 + # Add other common vLLM parameters from config if present optional_vllm_params = [ "max_model_len", diff --git a/test/llm/conftest.py b/test/llm/conftest.py new file mode 100644 index 00000000000..a69bdce6700 --- /dev/null +++ b/test/llm/conftest.py @@ -0,0 +1,77 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +"""Shared test fixtures and mock infrastructure for LLM tests.""" +from __future__ import annotations + +import pytest +import torch + + +class MockTransformerConfig: + """Mock config to mimic transformers model config.""" + + def __init__(self, vocab_size: int, max_position_embeddings: int = 2048): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = vocab_size # For simplicity + + +class MockTransformerOutput: + """Mock output that mimics transformers model output with dict-like access.""" + + def __init__(self, logits): + self.logits = logits + + def __getitem__(self, key): + """Allow dict-like access for compatibility.""" + if key == "logits": + return self.logits + raise KeyError(f"Key {key} not found in model output") + + +class MockTransformerModel(torch.nn.Module): + """Mock transformer model that mimics the structure of HuggingFace models.""" + + def __init__(self, vocab_size: int, device: torch.device | str | int = "cpu"): + super().__init__() + device = torch.device(device) + self.config = MockTransformerConfig(vocab_size) + # Simple embedding layer that maps tokens to logits + self.embedding = torch.nn.Embedding(vocab_size, vocab_size, device=device) + self.device = device + + def forward(self, input_ids, attention_mask=None, **kwargs): + """Forward pass that returns logits in the expected format.""" + # Get embeddings (which we'll use as logits for simplicity) + logits = self.embedding(input_ids % self.config.vocab_size) + # Return output object similar to transformers models + return MockTransformerOutput(logits) + + def get_tokenizer(self): + from transformers import AutoTokenizer + + return AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B") + + +@pytest.fixture +def 
mock_transformer_model(): + """Fixture that provides a mock transformer model factory.""" + + def _make_model( + vocab_size: int = 1024, device: torch.device | str | int = "cpu" + ) -> MockTransformerModel: + """Make a mock transformer model.""" + device = torch.device(device) + return MockTransformerModel(vocab_size, device) + + return _make_model + + +@pytest.fixture +def mock_tokenizer(): + """Fixture that provides a mock tokenizer.""" + from transformers import AutoTokenizer + + return AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B") diff --git a/test/llm/test_conversions.py b/test/llm/test_conversions.py new file mode 100644 index 00000000000..89bbd008e0f --- /dev/null +++ b/test/llm/test_conversions.py @@ -0,0 +1,448 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +import argparse + +import pytest +import torch +from tensordict import lazy_stack, TensorDict +from torchrl.data.llm import History +from torchrl.modules.llm.policies.common import ChatHistory, Text, Tokens + +# Test data +SIMPLE_CONVERSATION = [ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi there!"}, +] + +MULTI_TURN_CONVERSATION = [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "What is 2+2?"}, + {"role": "assistant", "content": "2+2 equals 4."}, + {"role": "user", "content": "Thanks!"}, + {"role": "assistant", "content": "You're welcome!"}, +] + + +@pytest.fixture +def tokenizer(): + """Get a tokenizer for testing.""" + from transformers import AutoTokenizer + + tokenizer = AutoTokenizer.from_pretrained("gpt2") + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + return tokenizer + + +class TestChatHistoryConversions: + """Test conversions from ChatHistory to Text and Tokens.""" + + def test_history_to_text_single(self, tokenizer): + """Test converting a single history to text.""" + history = History.from_chats([SIMPLE_CONVERSATION]) + chat_history = ChatHistory(full=history) + + text = chat_history.to_text(tokenizer) + + assert isinstance(text, Text) + assert text.full is not None + assert isinstance(text.full, list) + assert len(text.full) == 1 + assert "Hello" in text.full[0] + assert "Hi there!" 
in text.full[0] + + def test_history_to_text_batch(self, tokenizer): + """Test converting a batch of histories to text.""" + histories = History.from_chats([SIMPLE_CONVERSATION, MULTI_TURN_CONVERSATION]) + # Create a batch of ChatHistory objects + assert histories.shape[0] == 2 + chat_histories = [ + ChatHistory(full=histories[i : i + 1]) for i in range(histories.shape[0]) + ] + chat_history_batch = lazy_stack(chat_histories) + + text = chat_history_batch.to_text(tokenizer) + + assert isinstance(text, Text) + assert text.full is not None + assert isinstance(text.full, list) + assert len(text.full) == 2 + assert "Hello" in text.full[0][0] + assert "helpful assistant" in text.full[1][0] + + def test_history_to_text_prompt_response(self, tokenizer): + """Test converting history with prompt and response to text.""" + prompt_history = History.from_chats([[SIMPLE_CONVERSATION[0]]]) + full_history = History.from_chats([SIMPLE_CONVERSATION]) + chat_history = ChatHistory(prompt=prompt_history, full=full_history) + + text = chat_history.to_text(tokenizer) + + assert isinstance(text, Text) + assert text.prompt is not None + assert text.full is not None + assert text.response is not None + assert isinstance(text.prompt, list) + assert isinstance(text.full, list) + assert isinstance(text.response, list) + assert len(text.prompt) == 1 + assert len(text.full) == 1 + assert len(text.response) == 1 + # Response should be the part after prompt + assert text.full[0].startswith(text.prompt[0]) + assert text.full[0] == text.prompt[0] + text.response[0] + + def test_history_to_tokens_single(self, tokenizer): + """Test converting a single history to tokens.""" + history = History.from_chats([SIMPLE_CONVERSATION]) + chat_history = ChatHistory(full=history) + + tokens = chat_history.to_tokens(tokenizer) + + assert isinstance(tokens, Tokens) + assert tokens.full is not None + # Check if it's a nested tensor by checking if it has the _values attribute + assert hasattr(tokens.full, "_values") or isinstance(tokens.full, torch.Tensor) + assert tokens.padded is False + + def test_history_to_tokens_batch(self, tokenizer): + """Test converting a batch of histories to tokens.""" + histories = History.from_chats([SIMPLE_CONVERSATION, MULTI_TURN_CONVERSATION]) + # Create a batch of ChatHistory objects + chat_histories = [ + ChatHistory(full=histories[i : i + 1]) for i in range(histories.shape[0]) + ] + chat_history_batch = lazy_stack(chat_histories) + + tokens = chat_history_batch.to_tokens(tokenizer) + + assert isinstance(tokens, Tokens) + assert (full := tokens.get("full", as_nested_tensor=True)) is not None + # Check if it's a nested tensor + assert hasattr(full, "_values") or isinstance(full, torch.Tensor) + assert not any(tokens.padded) + # Check batch size + assert tokens.batch_size[0] == 2 + + def test_history_to_tokens_prompt_response(self, tokenizer): + """Test converting history with prompt and response to tokens.""" + prompt_history = History.from_chats([[SIMPLE_CONVERSATION[0]]]) + full_history = History.from_chats([SIMPLE_CONVERSATION]) + chat_history = ChatHistory(prompt=prompt_history, full=full_history) + + tokens = chat_history.to_tokens(tokenizer) + + assert isinstance(tokens, Tokens) + assert tokens.prompt is not None + assert tokens.full is not None + assert tokens.response is not None + # Check if they're nested tensors + assert hasattr(tokens.prompt, "_values") or isinstance( + tokens.prompt, torch.Tensor + ) + assert hasattr(tokens.full, "_values") or isinstance(tokens.full, torch.Tensor) + assert 
hasattr(tokens.response, "_values") or isinstance( + tokens.response, torch.Tensor + ) + # Response should be the part after prompt + prompt_tokens_list = tokens._tensordict.get("prompt", as_list=True) + full_tokens_list = tokens._tensordict.get("full", as_list=True) + response_tokens_list = tokens._tensordict.get("response", as_list=True) + prompt_len = prompt_tokens_list[0].shape[0] + full_len = full_tokens_list[0].shape[0] + response_len = response_tokens_list[0].shape[0] + assert full_len == prompt_len + response_len + + +class TestTokensConversions: + """Test conversions from Tokens to Text and ChatHistory.""" + + def test_tokens_to_text_single(self, tokenizer): + """Test converting tokens to text.""" + history = History.from_chats([SIMPLE_CONVERSATION]) + chat_history = ChatHistory(full=history) + tokens = chat_history.to_tokens(tokenizer) + + text = tokens.to_text(tokenizer) + + assert isinstance(text, Text) + assert text.full is not None + assert isinstance(text.full, list) + assert len(text.full) == 1 + + def test_tokens_to_text_batch(self, tokenizer): + """Test converting a batch of tokens to text.""" + histories = History.from_chats([SIMPLE_CONVERSATION, MULTI_TURN_CONVERSATION]) + # Create a batch of ChatHistory objects + chat_histories = [ + ChatHistory(full=histories[i : i + 1]) for i in range(histories.shape[0]) + ] + chat_history_batch = lazy_stack(chat_histories) + tokens = chat_history_batch.to_tokens(tokenizer) + + text = tokens.to_text(tokenizer) + + assert isinstance(text, Text) + assert text.full is not None + assert isinstance(text.full, list) + assert len(text.full) == 2 + + def test_tokens_to_text_prompt_response(self, tokenizer): + """Test converting tokens with prompt and response to text.""" + prompt_history = History.from_chats([[SIMPLE_CONVERSATION[0]]]) + full_history = History.from_chats([SIMPLE_CONVERSATION]) + chat_history = ChatHistory(prompt=prompt_history, full=full_history) + tokens = chat_history.to_tokens(tokenizer) + + text = tokens.to_text(tokenizer) + + assert isinstance(text, Text) + assert text.prompt is not None + assert text.full is not None + assert text.response is not None + + def test_tokens_to_text_padded_error(self, tokenizer): + """Test that padded tokens raise an error.""" + history = History.from_chats([SIMPLE_CONVERSATION]) + chat_history = ChatHistory(full=history) + tokens = chat_history.to_tokens(tokenizer) + tokens.padded = True # Manually set to padded + + with pytest.raises(ValueError, match="padded tokens"): + tokens.to_text(tokenizer) + + def test_tokens_to_history_single(self, tokenizer): + """Test converting tokens to history.""" + history = History.from_chats([SIMPLE_CONVERSATION]) + chat_history = ChatHistory(full=history) + tokens = chat_history.to_tokens(tokenizer) + + reconstructed_history = tokens.to_history(tokenizer) + + assert isinstance(reconstructed_history, ChatHistory) + assert reconstructed_history.full is not None + + def test_tokens_to_history_batch(self, tokenizer): + """Test converting a batch of tokens to history.""" + histories = History.from_chats([SIMPLE_CONVERSATION, MULTI_TURN_CONVERSATION]) + # Create a batch of ChatHistory objects + chat_histories = [ + ChatHistory(full=histories[i : i + 1]) for i in range(histories.shape[0]) + ] + chat_history_batch = lazy_stack(chat_histories) + tokens = chat_history_batch.to_tokens(tokenizer) + + reconstructed_history = tokens.to_history(tokenizer) + + assert isinstance(reconstructed_history, ChatHistory) + assert reconstructed_history.full is not None + assert 
reconstructed_history.batch_size[0] == 2 + + +class TestTextConversions: + """Test conversions from Text to Tokens and ChatHistory.""" + + def test_text_to_tokens_single(self, tokenizer): + """Test converting text to tokens.""" + text_obj = Text(full=["Hello, how are you?"]) + + tokens = text_obj.to_tokens(tokenizer) + + assert isinstance(tokens, Tokens) + assert tokens.full is not None + assert isinstance(tokens.full, torch.Tensor) + assert tokens.padded is False + + def test_text_to_tokens_batch(self, tokenizer): + """Test converting a batch of text to tokens.""" + text_obj = Text._from_tensordict(TensorDict(batch_size=(2,)).to_lazystack(0)) + with text_obj.view(-1) as text_flat: + text_flat.full = ["Hello, how are you?", "I'm doing great!"] + + tokens = text_obj.to_tokens(tokenizer) + + assert isinstance(tokens, Tokens) + assert (full := tokens.get("full", as_nested_tensor=True)) is not None + assert isinstance(full, torch.Tensor) + assert tokens.batch_size[0] == 2 + + def test_text_to_tokens_prompt_response(self, tokenizer): + """Test converting text with prompt and response to tokens.""" + text_obj = Text( + prompt=["Hello"], + response=[", how are you?"], + full=["Hello, how are you?"], + ) + + tokens = text_obj.to_tokens(tokenizer) + + assert isinstance(tokens, Tokens) + assert tokens.prompt is not None + assert tokens.response is not None + assert tokens.full is not None + + def test_text_to_tokens_padding_error(self, tokenizer): + """Test that padding raises an error.""" + text_obj = Text(full=["Hello, how are you?"]) + + with pytest.raises(ValueError, match="Padding is not yet supported"): + text_obj.to_tokens(tokenizer, padding=True) + + def test_text_to_history_single(self, tokenizer): + """Test converting text to history.""" + history = History.from_chats([SIMPLE_CONVERSATION]) + chat_history = ChatHistory(full=history) + text_obj = chat_history.to_text(tokenizer) + + reconstructed_history = text_obj.to_history(tokenizer) + + assert isinstance(reconstructed_history, ChatHistory) + assert reconstructed_history.full is not None + + def test_text_to_history_batch(self, tokenizer): + """Test converting a batch of text to history.""" + histories = History.from_chats([SIMPLE_CONVERSATION, MULTI_TURN_CONVERSATION]) + # Create a batch of ChatHistory objects + chat_histories = [ + ChatHistory(full=histories[i : i + 1]) for i in range(histories.shape[0]) + ] + chat_history_batch = lazy_stack(chat_histories) + text_obj = chat_history_batch.to_text(tokenizer) + + reconstructed_history = text_obj.to_history(tokenizer) + + assert isinstance(reconstructed_history, ChatHistory) + assert reconstructed_history.full is not None + assert reconstructed_history.batch_size[0] == 2 + + +class TestBijectivity: + """Test that conversions are bijective (round-trip conversions).""" + + def test_history_to_text_to_history(self, tokenizer): + """Test History -> Text -> History round-trip.""" + history = History.from_chats([SIMPLE_CONVERSATION]) + chat_history = ChatHistory(full=history) + + # Convert to text and back + text = chat_history.to_text(tokenizer) + reconstructed = text.to_history(tokenizer) + + assert isinstance(reconstructed, ChatHistory) + assert reconstructed.full is not None + # Check that the content is preserved + original_content = history.content + reconstructed_content = reconstructed.full.content + assert original_content == reconstructed_content + + def test_history_to_tokens_to_text_to_history(self, tokenizer): + """Test History -> Tokens -> Text -> History round-trip.""" + history = 
History.from_chats([SIMPLE_CONVERSATION]) + chat_history = ChatHistory(full=history) + + # Convert through tokens and text + tokens = chat_history.to_tokens(tokenizer) + text = tokens.to_text(tokenizer) + reconstructed = text.to_history(tokenizer) + + assert isinstance(reconstructed, ChatHistory) + assert reconstructed.full is not None + # Check that the content is preserved + original_content = [msg.content for msg in history.unbind(0)[0].unbind(0)] + reconstructed_content = [ + msg.content for msg in reconstructed.full.unbind(0)[0].unbind(0) + ] + assert original_content == reconstructed_content + + def test_text_to_tokens_to_text(self, tokenizer): + """Test Text -> Tokens -> Text round-trip.""" + original_text = Text(full=["Hello, how are you?"]) + + # Convert to tokens and back + tokens = original_text.to_tokens(tokenizer) + reconstructed_text = tokens.to_text(tokenizer) + + assert isinstance(reconstructed_text, Text) + assert reconstructed_text.full is not None + # The text should be very similar (may have minor tokenization artifacts) + reconstructed_full_list = reconstructed_text._tensordict.get( + "full", as_list=True + ) + original_full_list = original_text._tensordict.get("full", as_list=True) + assert len(reconstructed_full_list) == len(original_full_list) + + def test_tokens_to_text_to_tokens_shape_preserved(self, tokenizer): + """Test that Tokens -> Text -> Tokens preserves token shapes.""" + history = History.from_chats([SIMPLE_CONVERSATION]) + chat_history = ChatHistory(full=history) + original_tokens = chat_history.to_tokens(tokenizer) + + # Convert to text and back to tokens + text = original_tokens.to_text(tokenizer) + reconstructed_tokens = text.to_tokens(tokenizer) + + assert isinstance(reconstructed_tokens, Tokens) + assert reconstructed_tokens.full is not None + # Check that shapes are similar (may differ slightly due to tokenization) + original_full_list = original_tokens._tensordict.get("full", as_list=True) + reconstructed_full_list = reconstructed_tokens._tensordict.get( + "full", as_list=True + ) + original_len = original_full_list[0].shape[0] + reconstructed_len = reconstructed_full_list[0].shape[0] + # Allow some tolerance for tokenization differences + assert abs(original_len - reconstructed_len) <= 2 + + +class TestBatchDimensions: + """Test that conversions work correctly with different batch dimensions.""" + + def test_single_batch_dimension(self, tokenizer): + """Test conversions with single batch dimension.""" + histories = History.from_chats([SIMPLE_CONVERSATION, MULTI_TURN_CONVERSATION]) + # Create a batch of ChatHistory objects + chat_histories = [ + ChatHistory(full=histories[i : i + 1]) for i in range(histories.shape[0]) + ] + chat_history_batch = lazy_stack(chat_histories) + assert chat_history_batch.batch_size == torch.Size([2]) + + # Test all conversions maintain batch size + text = chat_history_batch.to_text(tokenizer) + assert text.batch_size == torch.Size([2]) + + tokens = chat_history_batch.to_tokens(tokenizer) + assert tokens.batch_size == torch.Size([2]) + + reconstructed = text.to_history(tokenizer) + assert reconstructed.batch_size == torch.Size([2]) + + def test_nested_batch_dimensions(self, tokenizer): + """Test conversions with nested batch dimensions.""" + # Create a 2x2 batch + histories = History.from_chats([SIMPLE_CONVERSATION, MULTI_TURN_CONVERSATION]) + # Create a 2x2 batch of ChatHistory objects + chat_histories_outer = [] + for _ in range(2): + chat_histories_inner = [ + ChatHistory(full=histories[i : i + 1]) + for i in 
range(histories.shape[0]) + ] + chat_histories_outer.append(lazy_stack(chat_histories_inner)) + chat_history_batch = lazy_stack(chat_histories_outer) + assert chat_history_batch.batch_size == torch.Size([2, 2]) + + # Test conversions maintain batch size + text = chat_history_batch.to_text(tokenizer) + assert text.batch_size == torch.Size([2, 2]) + + tokens = chat_history_batch.to_tokens(tokenizer) + assert tokens.batch_size == torch.Size([2, 2]) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + _, args = parser.parse_known_args() + pytest.main([__file__, "-v"] + args) diff --git a/torchrl/modules/llm/backends/_models.py b/torchrl/modules/llm/backends/_models.py new file mode 100644 index 00000000000..46842c94186 --- /dev/null +++ b/torchrl/modules/llm/backends/_models.py @@ -0,0 +1,46 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +"""Override the last layers of your models here.""" + +from __future__ import annotations + +import os + +import torch + +try: + from vllm.config import VllmConfig + from vllm.model_executor.models.qwen3 import Qwen3ForCausalLM +except ImportError: + + class VllmConfig: + """Placeholder for VllmConfig class when vLLM is not installed.""" + + class Qwen3ForCausalLM: + """Placeholder for Qwen3ForCausalLM class when vLLM is not installed.""" + + +def is_fp32_output_enabled() -> bool: + """Check if FP32 output is enabled.""" + return os.getenv("VLLM_ENABLE_FP32_OUTPUT", "0") == "1" + + +class Qwen3ForCausalLMFP32(Qwen3ForCausalLM): + """Qwen3ForCausalLM with FP32 output.""" + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__(vllm_config=vllm_config, prefix=prefix) + if is_fp32_output_enabled(): + self.lm_head.float() + + def compute_logits( + self, + hidden_states: torch.Tensor, + ) -> torch.Tensor | None: + if is_fp32_output_enabled(): + hidden_states = hidden_states.float() + logits = self.logits_processor(self.lm_head, hidden_states) + return logits diff --git a/torchrl/modules/llm/backends/vllm/vllm_async.py b/torchrl/modules/llm/backends/vllm/vllm_async.py index cc5dc0c9cd2..d7435b9b99e 100644 --- a/torchrl/modules/llm/backends/vllm/vllm_async.py +++ b/torchrl/modules/llm/backends/vllm/vllm_async.py @@ -42,20 +42,17 @@ TIMEOUT_SECONDS = os.getenv("TORCHRL_VLLM_TIMEOUT_SECONDS", 300) +try: + import vllm -class _AsyncvLLMWorker: - """Async vLLM worker for Ray with weight update capabilities. + _has_vllm = True +except ImportError: + vllm = None + _has_vllm = False - This worker extends the base vLLM Worker to support async operations - and weight updates via NCCL communication groups. - """ - def __init__(self, *args, **kwargs): - torchrl_logger.info(f"=> in {type(self).__name__}.__init__") - torchrl_logger.info(f"visible devices {os.getenv('CUDA_VISIBLE_DEVICES')}") - torchrl_logger.info(f"device count {torch.cuda.device_count()}") - self.model_update_group = None - super().__init__(*args, **kwargs) +class _AsyncvLLMWorker: + """Async vLLM worker extension for Ray with weight update capabilities.""" def init_weight_update_group( self, @@ -715,6 +712,7 @@ def from_pretrained( num_replicas: int = 1, verbose: bool = True, compile: bool = True, + enable_fp32_output: bool = False, **kwargs, ) -> AsyncVLLM: """Create an AsyncVLLM instance from a pretrained model. @@ -728,6 +726,7 @@ def from_pretrained( num_replicas (int): Number of engine replicas to create. 
verbose (bool, optional): Whether to enable verbose logging with throughput statistics. Defaults to True. compile (bool, optional): Whether to enable model compilation for better performance. Defaults to True. + enable_fp32_output (bool, optional): Whether to enable FP32 output for the final layer. Defaults to False. **kwargs: Additional arguments passed to AsyncEngineArgs. Returns: @@ -748,6 +747,12 @@ def from_pretrained( >>> # Generate text >>> from vllm import SamplingParams >>> result = service.generate("Hello, world!", SamplingParams(max_tokens=50)) + >>> + >>> # Enable FP32 output for better numerical stability + >>> service = AsyncVLLM.from_pretrained( + ... "Qwen/Qwen2.5-3B", + ... enable_fp32_output=True + ... ) """ return make_async_vllm_engine( model_name=model_name, @@ -755,6 +760,7 @@ def from_pretrained( num_replicas=num_replicas, verbose=verbose, compile=compile, + enable_fp32_output=enable_fp32_output, **kwargs, ) @@ -1909,11 +1915,13 @@ def get_stats(self) -> dict[str, Any]: def make_async_vllm_engine( + *, model_name: str, num_devices: int | None = None, num_replicas: int = 1, verbose: bool = True, compile: bool = True, + enable_fp32_output: bool = False, tensor_parallel_size: int | None = None, data_parallel_size: int | None = None, pipeline_parallel_size: int | None = None, @@ -1927,6 +1935,9 @@ def make_async_vllm_engine( num_replicas (int): Number of engine replicas to create. verbose (bool, optional): Whether to enable verbose logging with throughput statistics. Defaults to True. compile (bool, optional): Whether to enable model compilation for better performance. Defaults to True. + enable_fp32_output (bool, optional): Whether to enable FP32 output for the final layer. Defaults to False. + This can help with numerical stability for certain models. Requires model-specific support in + torchrl.modules.llm.backends._models. tensor_parallel_size (int, optional): Number of devices to use, per replica. Defaults to None. data_parallel_size (int, optional): Number of data parallel groups to use. Defaults to None. pipeline_parallel_size (int, optional): Number of pipeline parallel groups to use. Defaults to None. @@ -1947,6 +1958,9 @@ def make_async_vllm_engine( >>> service = make_async_vllm_engine("Qwen/Qwen2.5-3B", num_devices=2, num_replicas=2) >>> # Generate text >>> result = service.generate("Hello, world!", sampling_params) + >>> + >>> # Create with FP32 output enabled + >>> service = make_async_vllm_engine("Qwen/Qwen2.5-3B", enable_fp32_output=True) """ if not _has_vllm: raise ImportError( @@ -1955,6 +1969,14 @@ def make_async_vllm_engine( from vllm import AsyncEngineArgs + # Set FP32 output environment variable if requested + if enable_fp32_output: + os.environ["VLLM_ENABLE_FP32_OUTPUT"] = "1" + torchrl_logger.info( + "Enabled FP32 output for vLLM (VLLM_ENABLE_FP32_OUTPUT=1). " + "This will use FP32 for the final output layer if the model supports it." + ) + # Configure verbose logging if requested if verbose: import logging diff --git a/torchrl/modules/llm/backends/vllm/vllm_sync.py b/torchrl/modules/llm/backends/vllm/vllm_sync.py index 759bdff8be2..b1b0ad0caaf 100644 --- a/torchrl/modules/llm/backends/vllm/vllm_sync.py +++ b/torchrl/modules/llm/backends/vllm/vllm_sync.py @@ -329,6 +329,7 @@ def make_vllm_worker( num_devices: int | None = None, make_ray_worker: bool = True, enforce_eager: bool = False, + enable_fp32_output: bool = False, **kwargs, ) -> RayLLMWorker | LocalLLMWrapper: """Creates a vLLM inference engine with tensor parallelism support. 
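Note on the FP32 toggle: both factory paths above only export VLLM_ENABLE_FP32_OUTPUT=1 and log it; the actual cast happens in the model override registered through the entry point. A minimal sketch, assuming torchrl with this patch is installed, of driving the flag directly instead of going through make_async_vllm_engine / make_vllm_worker:

    import os

    # Export before the engine instantiates the model, so that
    # Qwen3ForCausalLMFP32.__init__ casts lm_head to float32 at build time.
    os.environ["VLLM_ENABLE_FP32_OUTPUT"] = "1"

    from torchrl.modules.llm.backends._models import is_fp32_output_enabled

    assert is_fp32_output_enabled()  # the registered override reads this same flag
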
@@ -339,6 +340,9 @@ def make_vllm_worker( num_devices (int, optional): Number of devices to use. Exclusive with devices. make_ray_worker (bool, optional): Whether to create a Ray actor. Defaults to True. enforce_eager (bool, optional): Whether to enforce eager execution. Defaults to `False`. + enable_fp32_output (bool, optional): Whether to enable FP32 output for the final layer. Defaults to False. + This can help with numerical stability for certain models. Requires model-specific support in + torchrl.modules.llm.backends._models. **kwargs: Additional arguments passed to vLLM.LLM.__init__. Returns: @@ -349,12 +353,22 @@ def make_vllm_worker( >>> worker = make_vllm_worker("Qwen/Qwen2.5-3B", num_devices=2) >>> # Create a local LLM instance on GPU 1 >>> llm = make_vllm_worker("Qwen/Qwen2.5-3B", devices=[1], make_ray_worker=False) + >>> # Create with FP32 output enabled + >>> worker = make_vllm_worker("Qwen/Qwen2.5-3B", num_devices=2, enable_fp32_output=True) """ if not _has_vllm: raise ImportError( "vllm is not installed. Please install it with `pip install vllm`." ) + # Set FP32 output environment variable if requested + if enable_fp32_output: + os.environ["VLLM_ENABLE_FP32_OUTPUT"] = "1" + torchrl_logger.info( + "Enabled FP32 output for vLLM (VLLM_ENABLE_FP32_OUTPUT=1). " + "This will use FP32 for the final output layer if the model supports it." + ) + # Handle device specification if num_devices is not None and devices is not None: raise ValueError("Cannot specify both num_devices and devices") diff --git a/torchrl/modules/llm/backends/vllm_plugin.py b/torchrl/modules/llm/backends/vllm_plugin.py new file mode 100644 index 00000000000..5af5c860aa3 --- /dev/null +++ b/torchrl/modules/llm/backends/vllm_plugin.py @@ -0,0 +1,22 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +from torchrl._utils import logger + + +def register_fp32_overrides() -> None: + """Register FP32 overrides for vLLM models.""" + from vllm.model_executor.models.registry import ModelRegistry + + # ======= Register models here ======= + # Register Qwen3 models with FP32 override + ModelRegistry.register_model( + "Qwen3ForCausalLM", + "torchrl.modules.llm.backends._models:Qwen3ForCausalLMFP32", + ) + + logger.info("Registered Qwen3 FP32 model overrides")
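
The "Register models here" marker in register_fp32_overrides() is the intended extension point for additional architectures. A minimal sketch of a second registration, assuming a matching FP32 subclass were first added to torchrl/modules/llm/backends/_models.py; the Qwen2ForCausalLMFP32 name below is hypothetical and not part of this patch:

    # Inside register_fp32_overrides(), ModelRegistry is already imported as above.
    from vllm.model_executor.models.registry import ModelRegistry

    # Hypothetical extension mirroring the Qwen3 entry: Qwen2ForCausalLMFP32 does not
    # exist in this patch and would have to subclass vLLM's Qwen2ForCausalLM the same
    # way Qwen3ForCausalLMFP32 subclasses Qwen3ForCausalLM.
    ModelRegistry.register_model(
        "Qwen2ForCausalLM",
        "torchrl.modules.llm.backends._models:Qwen2ForCausalLMFP32",
    )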