diff --git a/examples/inference_policy_example.py b/examples/inference_policy_example.py
new file mode 100644
index 000000000..bf4095e30
--- /dev/null
+++ b/examples/inference_policy_example.py
@@ -0,0 +1,96 @@
+"""
+Example usage of InferencePolicy abstractions.
+
+Demonstrates how to use policies for deployment-ready inference.
+"""
+
+import asyncio
+from openai import AsyncOpenAI
+
+import verifiers as vf
+
+
+async def main():
+    # Example 1: Using APIPolicy with OpenAI
+    # ========================================
+    print("Example 1: APIPolicy with OpenAI API")
+
+    # Create policy from client (backwards compatible)
+    client = AsyncOpenAI(api_key="your-api-key")
+    policy = vf.InferencePolicy.from_client(client, model="gpt-4")
+
+    # Or create directly
+    policy = vf.APIPolicy(client=client, model="gpt-4")
+
+    # Generate response
+    response = await policy.generate(
+        prompt=[{"role": "user", "content": "Hello!"}],
+        sampling_args={"temperature": 0.7, "max_tokens": 100}
+    )
+    print(f"Response: {response.choices[0].message.content}")
+
+    # Example 2: Using with Environment (backwards compatible)
+    # =========================================================
+    print("\nExample 2: Using policy with Environment")
+
+    # Load environment
+    env = vf.load_environment("math-python")
+
+    # Old way (still works)
+    results_old = env.evaluate(
+        client=client,
+        model="gpt-4",
+        num_examples=5
+    )
+
+    # New way (using policy) - planned for future
+    # results_new = env.evaluate(
+    #     policy=policy,
+    #     num_examples=5
+    # )
+
+    # Example 3: VLLMPolicy for high-throughput serving
+    # ==================================================
+    print("\nExample 3: VLLMPolicy for production deployment")
+
+    try:
+        # Requires vLLM server running
+        vllm_policy = vf.VLLMPolicy(
+            host="localhost",
+            port=8000,
+            model="your-model"
+        )
+
+        # Use for inference
+        response = await vllm_policy.generate(
+            prompt=[{"role": "user", "content": "Solve: 2+2=?"}],
+            sampling_args={"temperature": 0.0}
+        )
+        print(f"vLLM Response: {response.choices[0].message.content}")
+
+        # Optional: Enable weight syncing for online learning
+        # vllm_policy.enable_weight_sync()
+        # vllm_policy.sync_weights(model)
+
+    except Exception as e:
+        print(f"VLLMPolicy example skipped: {e}")
+
+    # Example 4: Custom deployment scenarios
+    # =======================================
+    print("\nExample 4: Flexible deployment patterns")
+
+    # API-based evaluation (no training infrastructure needed)
+    api_policy = vf.APIPolicy(
+        client=AsyncOpenAI(api_key="test", base_url="https://api.openai.com/v1"),
+        model="gpt-4"
+    )
+
+    # Can be used anywhere that accepts a policy
+    # - Evaluation scripts
+    # - Production serving
+    # - A/B testing different models
+    # - Local development
+
+
+if __name__ == "__main__":
+    asyncio.run(main())
diff --git a/tests/test_inference_policy.py b/tests/test_inference_policy.py
new file mode 100644
index 000000000..be47b71d2
--- /dev/null
+++ b/tests/test_inference_policy.py
@@ -0,0 +1,201 @@
+# ABOUTME: Unit tests for inference policy abstractions
+# ABOUTME: Tests APIPolicy, VLLMPolicy, and factory methods
+import pytest
+from unittest.mock import AsyncMock, Mock
+
+from openai import AsyncOpenAI, OpenAI
+from openai.types.chat import ChatCompletion
+from openai.types import Completion
+
+from verifiers.inference.policy import APIPolicy, InferencePolicy
+
+
+class TestInferencePolicy:
+    """Test base InferencePolicy abstraction."""
+
+    def test_cannot_instantiate_abstract_class(self):
+        """Base class cannot be instantiated directly."""
+        with pytest.raises(TypeError):
+            InferencePolicy()  # type: ignore
+
+    def test_from_client_with_async_client(self):
+        """Factory method creates APIPolicy from AsyncOpenAI client."""
+        client = AsyncOpenAI(api_key="test", base_url="http://test")
+        policy = InferencePolicy.from_client(client, "test-model")
+
+        assert isinstance(policy, APIPolicy)
+        assert policy.model == "test-model"
+        assert policy.client == client
+
+    def test_from_client_with_sync_client(self):
+        """Factory method creates APIPolicy from sync OpenAI client."""
+        client = OpenAI(api_key="test", base_url="http://test")
+        policy = InferencePolicy.from_client(client, "test-model")
+
+        assert isinstance(policy, APIPolicy)
+        assert policy.model == "test-model"
+        # Should wrap sync client in async client
+        assert isinstance(policy.client, AsyncOpenAI)
+
+
+class TestAPIPolicy:
+    """Test APIPolicy implementation."""
+
+    @pytest.fixture
+    def mock_async_client(self):
+        """Create mock AsyncOpenAI client."""
+        client = Mock(spec=AsyncOpenAI)
+        client.api_key = "test-key"
+        client.base_url = "http://test"
+        return client
+
+    @pytest.fixture
+    def api_policy(self, mock_async_client):
+        """Create APIPolicy instance."""
+        return APIPolicy(client=mock_async_client, model="test-model")
+
+    def test_initialization_with_async_client(self, mock_async_client):
+        """APIPolicy initializes correctly with AsyncOpenAI client."""
+        policy = APIPolicy(client=mock_async_client, model="test-model")
+        assert policy.client == mock_async_client
+        assert policy.model == "test-model"
+
+    def test_initialization_with_sync_client(self):
+        """APIPolicy wraps sync OpenAI client in AsyncOpenAI."""
+        sync_client = Mock(spec=OpenAI)
+        sync_client.api_key = "test-key"
+        sync_client.base_url = "http://test"
+
+        policy = APIPolicy(client=sync_client, model="test-model")
+
+        assert isinstance(policy.client, AsyncOpenAI)
+        assert policy.model == "test-model"
+
+    @pytest.mark.asyncio
+    async def test_generate_chat_format(self, api_policy, mock_async_client):
+        """Generate works with chat completion format."""
+        # Setup mock response
+        mock_response = Mock(spec=ChatCompletion)
+        mock_async_client.chat = Mock()
+        mock_async_client.chat.completions = Mock()
+        mock_async_client.chat.completions.create = AsyncMock(return_value=mock_response)
+
+        # Test chat format
+        prompt = [{"role": "user", "content": "Hello"}]
+        result = await api_policy.generate(
+            prompt=prompt,
+            sampling_args={"temperature": 0.7}
+        )
+
+        # Verify call was made
+        mock_async_client.chat.completions.create.assert_called_once_with(
+            model="test-model",
+            messages=prompt,
+            temperature=0.7
+        )
+        assert result == mock_response
+
+    @pytest.mark.asyncio
+    async def test_generate_completion_format(self, api_policy, mock_async_client):
+        """Generate works with completion format."""
+        # Setup mock response
+        mock_response = Mock(spec=Completion)
+        mock_async_client.completions = Mock()
+        mock_async_client.completions.create = AsyncMock(return_value=mock_response)
+
+        # Test completion format
+        prompt = "Hello, world"
+        result = await api_policy.generate(
+            prompt=prompt,
+            sampling_args={"temperature": 0.7}
+        )
+
+        # Verify call was made
+        mock_async_client.completions.create.assert_called_once_with(
+            model="test-model",
+            prompt=prompt,
+            temperature=0.7
+        )
+        assert result == mock_response
+
+    @pytest.mark.asyncio
+    async def test_generate_with_no_sampling_args(self, api_policy, mock_async_client):
+        """Generate works without sampling args."""
+        mock_response = Mock(spec=ChatCompletion)
+        mock_async_client.chat = Mock()
+        mock_async_client.chat.completions = Mock()
+        mock_async_client.chat.completions.create = AsyncMock(return_value=mock_response)
+
+        prompt = [{"role": "user", "content": "Hello"}]
+        await api_policy.generate(prompt=prompt)
+
+        # Should still call with model
+        mock_async_client.chat.completions.create.assert_called_once()
+        call_kwargs = mock_async_client.chat.completions.create.call_args[1]
+        assert call_kwargs["model"] == "test-model"
+        assert call_kwargs["messages"] == prompt
+
+
+class TestVLLMPolicy:
+    """Test VLLMPolicy implementation."""
+
+    def test_initialization_requires_vllm(self):
+        """VLLMPolicy requires vLLM to be installed."""
+        # This will fail if vLLM is not installed, which is expected
+        # In CI, we should skip this test or mock the import
+        try:
+            from verifiers.inference.backends.vllm_policy import VLLMPolicy
+            # If import succeeds, test basic initialization
+            # Note: This will try to connect to a server, so we can't fully test without mocking
+            assert VLLMPolicy is not None
+        except ImportError:
+            pytest.skip("vLLM not installed")
+
+    def test_vllm_policy_structure(self):
+        """VLLMPolicy has expected methods and attributes."""
+        try:
+            from verifiers.inference.backends.vllm_policy import VLLMPolicy
+
+            # Check class has required methods
+            assert hasattr(VLLMPolicy, 'generate')
+            assert hasattr(VLLMPolicy, 'enable_weight_sync')
+            assert hasattr(VLLMPolicy, 'sync_weights')
+            assert hasattr(VLLMPolicy, 'from_client')
+        except ImportError:
+            pytest.skip("vLLM not installed")
+
+
+class TestBackwardsCompatibility:
+    """Test backwards compatibility patterns."""
+
+    def test_policy_can_wrap_existing_client_code(self):
+        """Existing code using clients can be wrapped in policy."""
+        # Simulate existing code pattern
+        client = OpenAI(api_key="test", base_url="http://test")
+        model = "gpt-4"
+
+        # New code can wrap this
+        policy = InferencePolicy.from_client(client, model)
+
+        assert isinstance(policy, APIPolicy)
+        assert policy.model == model
+
+    @pytest.mark.asyncio
+    async def test_policy_interface_matches_client_usage(self):
+        """Policy interface is similar to direct client usage."""
+        mock_client = Mock(spec=AsyncOpenAI)
+        mock_client.chat = Mock()
+        mock_client.chat.completions = Mock()
+        mock_client.chat.completions.create = AsyncMock(return_value=Mock())
+
+        policy = APIPolicy(client=mock_client, model="test")
+
+        # Both interfaces should work similarly
+        prompt = [{"role": "user", "content": "test"}]
+        sampling_args = {"temperature": 0.7}
+
+        # Policy interface
+        await policy.generate(prompt=prompt, sampling_args=sampling_args)
+
+        # Verify underlying client was called
+        assert mock_client.chat.completions.create.called
diff --git a/verifiers/__init__.py b/verifiers/__init__.py
index 9433e2bc3..69c2bbf9f 100644
--- a/verifiers/__init__.py
+++ b/verifiers/__init__.py
@@ -24,6 +24,7 @@
     extract_hash_answer,
     load_example_dataset,
 )
+from .inference.policy import APIPolicy, InferencePolicy
 from .utils.env_utils import load_environment
 from .utils.logging_utils import print_prompt_completions_sample
 
@@ -83,6 +84,9 @@ def setup_logging(
     "StatefulToolEnv",
     "ToolEnv",
     "EnvGroup",
+    "InferencePolicy",
+    "APIPolicy",
+    "VLLMPolicy",
     "extract_boxed_answer",
     "extract_hash_answer",
     "load_example_dataset",
@@ -110,6 +114,7 @@ def setup_logging(
     "SandboxEnv": "verifiers.envs.sandbox_env:SandboxEnv",
     "PythonEnv": "verifiers.envs.python_env:PythonEnv",
     "TextArenaEnv": "verifiers.envs.textarena_env:TextArenaEnv",
+    "VLLMPolicy": "verifiers.inference.backends.vllm_policy:VLLMPolicy",
 }
 
 
@@ -130,6 +135,7 @@ def __getattr__(name: str):
     from .envs.python_env import PythonEnv  # noqa: F401
     from .envs.sandbox_env import SandboxEnv  # noqa: F401
     from .envs.textarena_env import TextArenaEnv  # noqa: F401
+    from .inference.backends.vllm_policy import VLLMPolicy  # noqa: F401
     from .rubrics.math_rubric import MathRubric  # noqa: F401
     from .trainers import (  # noqa: F401
         GRPOConfig,
diff --git a/verifiers/envs/environment.py b/verifiers/envs/environment.py
index 86a39d26f..2ce5cddc9 100644
--- a/verifiers/envs/environment.py
+++ b/verifiers/envs/environment.py
@@ -37,6 +37,7 @@
     from transformers.tokenization_utils_base import (  # type: ignore
         PreTrainedTokenizerBase,
     )
+    from verifiers.inference.policy import InferencePolicy
 
 
 class Environment(ABC):
diff --git a/verifiers/inference/backends/__init__.py b/verifiers/inference/backends/__init__.py
new file mode 100644
index 000000000..be2e274a6
--- /dev/null
+++ b/verifiers/inference/backends/__init__.py
@@ -0,0 +1,2 @@
+# ABOUTME: Backend implementations for inference policies
+# ABOUTME: Contains specific policy implementations for different inference backends
diff --git a/verifiers/inference/backends/vllm_policy.py b/verifiers/inference/backends/vllm_policy.py
new file mode 100644
index 000000000..67eb75c9a
--- /dev/null
+++ b/verifiers/inference/backends/vllm_policy.py
@@ -0,0 +1,129 @@
+# ABOUTME: vLLM-based inference policy for high-throughput serving
+# ABOUTME: Wraps VLLMClient for production deployment scenarios
+from typing import TYPE_CHECKING
+
+from verifiers.inference.policy import InferencePolicy
+from verifiers.inference.vllm_client import VLLMClient
+from verifiers.types import Messages, ModelResponse, SamplingArgs
+
+if TYPE_CHECKING:
+    from transformers.modeling_utils import PreTrainedModel
+
+
+class VLLMPolicy(InferencePolicy):
+    """
+    Inference policy using vLLM for high-throughput serving.
+
+    Optimized for production deployment with:
+    - Continuous batching for efficient GPU utilization
+    - PagedAttention for memory efficiency
+    - Optional weight syncing for online learning scenarios
+    """
+
+    def __init__(
+        self,
+        host: str = "0.0.0.0",
+        port: int = 8000,
+        model: str | None = None,
+        connection_timeout: float = 0.0,
+    ):
+        """
+        Initialize vLLM policy.
+
+        Args:
+            host: vLLM server host
+            port: vLLM server port
+            model: Model name (for reference tracking)
+            connection_timeout: Timeout for initial server connection
+        """
+        self.client = VLLMClient(
+            host=host,
+            port=port,
+            connection_timeout=connection_timeout,
+        )
+        self.model = model or "vllm-model"
+        self._supports_weight_sync = False
+
+    async def generate(
+        self,
+        prompt: Messages,
+        sampling_args: SamplingArgs | None = None,
+        **kwargs,
+    ) -> ModelResponse:
+        """Generate response using vLLM server."""
+        sampling_args = sampling_args or {}
+
+        # Determine message type
+        is_chat = isinstance(prompt, list)
+
+        if is_chat:
+            # Chat completions format
+            response = await self.client.chat.completions.create(
+                model=self.model,
+                messages=prompt,  # type: ignore
+                **sampling_args,
+            )
+        else:
+            # Completions format
+            response = await self.client.completions.create(
+                model=self.model,
+                prompt=prompt,  # type: ignore
+                **sampling_args,
+            )
+
+        return response
+
+    def enable_weight_sync(self, group_port: int = 51216) -> None:
+        """
+        Enable weight synchronization for online learning.
+
+        Args:
+            group_port: Port for weight update communication
+
+        Note:
+            This is only needed for training scenarios where model
+            weights are updated during inference.
+        """
+        self.client.group_port = group_port
+        self.client.init_communicator()
+        self._supports_weight_sync = True
+
+    def sync_weights(self, model: "PreTrainedModel") -> None:
+        """
+        Synchronize model weights to vLLM server.
+
+        Args:
+            model: Source model with updated weights
+
+        Raises:
+            RuntimeError: If weight sync not enabled
+        """
+        if not self._supports_weight_sync:
+            raise RuntimeError(
+                "Weight sync not enabled. Call enable_weight_sync() first."
+            )
+
+        # Update all parameters
+        for name, param in model.named_parameters():
+            self.client.update_named_param(name, param.data)
+
+        # Reset cache after weight update
+        self.client.reset_prefix_cache()
+
+    @classmethod
+    def from_client(cls, client: VLLMClient, model: str | None = None) -> "VLLMPolicy":
+        """
+        Create policy from existing VLLMClient.
+
+        Args:
+            client: Configured VLLMClient instance
+            model: Model name override
+
+        Returns:
+            VLLMPolicy wrapping the client
+        """
+        policy = cls.__new__(cls)
+        policy.client = client
+        policy.model = model or "vllm-model"
+        policy._supports_weight_sync = False
+        return policy
diff --git a/verifiers/inference/policy.py b/verifiers/inference/policy.py
new file mode 100644
index 000000000..99c03c63b
--- /dev/null
+++ b/verifiers/inference/policy.py
@@ -0,0 +1,108 @@
+# ABOUTME: Base abstraction for model inference policies
+# ABOUTME: Decouples inference from training infrastructure for deployment flexibility
+from abc import ABC, abstractmethod
+
+from openai import AsyncOpenAI, OpenAI
+
+from verifiers.types import Messages, ModelResponse, SamplingArgs
+
+
+class InferencePolicy(ABC):
+    """
+    Abstract base class for model inference policies.
+
+    Provides a unified interface for generating responses across different
+    backends (API, local models, vLLM servers). Designed for deployment
+    scenarios where training infrastructure is not required.
+    """
+
+    @abstractmethod
+    async def generate(
+        self,
+        prompt: Messages,
+        sampling_args: SamplingArgs | None = None,
+        **kwargs,
+    ) -> ModelResponse:
+        """
+        Generate a model response for the given prompt.
+
+        Args:
+            prompt: Input messages (chat format) or string (completion format)
+            sampling_args: Sampling parameters (temperature, top_p, etc.)
+            **kwargs: Additional backend-specific parameters
+
+        Returns:
+            ModelResponse (ChatCompletion or Completion object)
+        """
+        pass
+
+    @classmethod
+    def from_client(
+        cls, client: AsyncOpenAI | OpenAI, model: str
+    ) -> "APIPolicy":
+        """
+        Create a policy from an existing OpenAI-compatible client.
+
+        Args:
+            client: AsyncOpenAI or OpenAI client instance
+            model: Model name/identifier
+
+        Returns:
+            APIPolicy wrapping the client
+        """
+        # Import here to avoid circular dependency at module level
+        # APIPolicy is defined below in this same file
+        return APIPolicy(client=client, model=model)
+
+
+class APIPolicy(InferencePolicy):
+    """
+    Inference policy for OpenAI-compatible API endpoints.
+
+    Wraps AsyncOpenAI clients for use with any OpenAI-compatible API
+    (OpenAI, Anthropic, vLLM server, etc.).
+    """
+
+    def __init__(self, client: AsyncOpenAI | OpenAI, model: str):
+        """
+        Initialize API policy.
+
+        Args:
+            client: AsyncOpenAI or OpenAI client
+            model: Model name for generation requests
+        """
+        self.client = (
+            client if isinstance(client, AsyncOpenAI) else AsyncOpenAI(
+                api_key=client.api_key, base_url=client.base_url
+            )
+        )
+        self.model = model
+
+    async def generate(
+        self,
+        prompt: Messages,
+        sampling_args: SamplingArgs | None = None,
+        **kwargs,
+    ) -> ModelResponse:
+        """Generate response using the wrapped API client."""
+        sampling_args = sampling_args or {}
+
+        # Determine message type
+        is_chat = isinstance(prompt, list)
+
+        if is_chat:
+            # Chat completions format
+            response = await self.client.chat.completions.create(
+                model=self.model,
+                messages=prompt,  # type: ignore
+                **sampling_args,
+            )
+        else:
+            # Completions format
+            response = await self.client.completions.create(
+                model=self.model,
+                prompt=prompt,  # type: ignore
+                **sampling_args,
+            )
+
+        return response
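
Note on deployment flexibility: because APIPolicy only assumes an OpenAI-compatible endpoint, it can also be pointed straight at a vLLM server's OpenAI-compatible route, with VLLMPolicy reserved for cases that need VLLMClient features such as weight syncing. The sketch below illustrates that pattern against the interfaces introduced in this diff; the base URL, the "EMPTY" api_key placeholder, and the model name "my-model" are assumptions for illustration, not values taken from the diff.

import asyncio

from openai import AsyncOpenAI

import verifiers as vf


async def demo():
    # vLLM's OpenAI-compatible server typically serves under /v1; the api_key is a
    # placeholder unless the server was started with a key requirement (assumed setup).
    client = AsyncOpenAI(api_key="EMPTY", base_url="http://localhost:8000/v1")
    policy = vf.APIPolicy(client=client, model="my-model")

    # Chat-format prompt, so the result is a ChatCompletion per InferencePolicy.generate
    response = await policy.generate(
        prompt=[{"role": "user", "content": "ping"}],
        sampling_args={"temperature": 0.0, "max_tokens": 16},
    )
    print(response.choices[0].message.content)


if __name__ == "__main__":
    asyncio.run(demo())

This mirrors Example 4 in examples/inference_policy_example.py with a local base_url, so API-based evaluation and vLLM-backed serving can share a single code path.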