96 changes: 96 additions & 0 deletions examples/inference_policy_example.py
@@ -0,0 +1,96 @@
"""
Example usage of InferencePolicy abstractions.

Demonstrates how to use policies for deployment-ready inference.
"""

import asyncio
from openai import AsyncOpenAI

import verifiers as vf


async def main():
# Example 1: Using APIPolicy with OpenAI
# ========================================
print("Example 1: APIPolicy with OpenAI API")

# Create policy from client (backwards compatible)
client = AsyncOpenAI(api_key="your-api-key")
policy = vf.InferencePolicy.from_client(client, model="gpt-4")

# Or create directly
policy = vf.APIPolicy(client=client, model="gpt-4")

# Generate response
response = await policy.generate(
prompt=[{"role": "user", "content": "Hello!"}],
sampling_args={"temperature": 0.7, "max_tokens": 100}
)
print(f"Response: {response.choices[0].message.content}")

# Example 2: Using with Environment (backwards compatible)
# =========================================================
print("\nExample 2: Using policy with Environment")

# Load environment
env = vf.load_environment("math-python")

# Old way (still works)
results_old = env.evaluate(
client=client,
model="gpt-4",
num_examples=5
)

# New way (using policy) - planned for future
# results_new = env.evaluate(
# policy=policy,
# num_examples=5
# )

# Example 3: VLLMPolicy for high-throughput serving
# ==================================================
print("\nExample 3: VLLMPolicy for production deployment")

try:
# Requires vLLM server running
vllm_policy = vf.VLLMPolicy(
host="localhost",
port=8000,
model="your-model"
)

# Use for inference
response = await vllm_policy.generate(
prompt=[{"role": "user", "content": "Solve: 2+2=?"}],
sampling_args={"temperature": 0.0}
)
print(f"vLLM Response: {response.choices[0].message.content}")

        # Optional: enable weight syncing for online learning
        # (here `model` would be the in-training model whose weights are pushed to the vLLM server)
        # vllm_policy.enable_weight_sync()
        # vllm_policy.sync_weights(model)

except Exception as e:
print(f"VLLMPolicy example skipped: {e}")

# Example 4: Custom deployment scenarios
# =======================================
print("\nExample 4: Flexible deployment patterns")

# API-based evaluation (no training infrastructure needed)
api_policy = vf.APIPolicy(
        client=AsyncOpenAI(api_key="your-api-key", base_url="https://api.openai.com/v1"),
model="gpt-4"
)

# Can be used anywhere that accepts a policy
# - Evaluation scripts
# - Production serving
# - A/B testing different models
# - Local development


if __name__ == "__main__":
asyncio.run(main())
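
For orientation: both the example above and the tests below depend on `verifiers/inference/policy.py`, which this diff does not show. The following is a minimal sketch of that interface as it is exercised by the example and the tests, assuming only the `from_client`, `generate`, `client`, and `model` names that appear in the diff; it is an illustration, not the PR's actual implementation.

# Sketch of the policy interface assumed above; reconstructed from usage, not
# the PR's actual verifiers/inference/policy.py.
from abc import ABC, abstractmethod
from typing import Any, Optional, Union

from openai import AsyncOpenAI, OpenAI


class InferencePolicy(ABC):
    """Something that can turn a prompt into a (chat) completion."""

    @abstractmethod
    async def generate(
        self,
        prompt: Union[str, list[dict[str, str]]],
        sampling_args: Optional[dict[str, Any]] = None,
    ) -> Any: ...

    @classmethod
    def from_client(cls, client: Union[OpenAI, AsyncOpenAI], model: str) -> "APIPolicy":
        # Factory so existing client-based call sites can be wrapped unchanged.
        return APIPolicy(client=client, model=model)


class APIPolicy(InferencePolicy):
    """Policy backed by an OpenAI-compatible API."""

    def __init__(self, client: Union[OpenAI, AsyncOpenAI], model: str):
        if isinstance(client, OpenAI):
            # Re-wrap sync clients so generate() can await the calls.
            client = AsyncOpenAI(api_key=client.api_key, base_url=str(client.base_url))
        self.client = client
        self.model = model

    async def generate(
        self,
        prompt: Union[str, list[dict[str, str]]],
        sampling_args: Optional[dict[str, Any]] = None,
    ) -> Any:
        args = sampling_args or {}
        if isinstance(prompt, str):
            # Plain string prompts go to the legacy completions endpoint.
            return await self.client.completions.create(
                model=self.model, prompt=prompt, **args
            )
        # Chat-format prompts (role/content dicts) go to chat completions.
        return await self.client.chat.completions.create(
            model=self.model, messages=prompt, **args
        )
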
201 changes: 201 additions & 0 deletions tests/test_inference_policy.py
@@ -0,0 +1,201 @@
# ABOUTME: Unit tests for inference policy abstractions
# ABOUTME: Tests APIPolicy, VLLMPolicy, and factory methods
import pytest
from unittest.mock import AsyncMock, Mock

from openai import AsyncOpenAI, OpenAI
from openai.types.chat import ChatCompletion
from openai.types import Completion

from verifiers.inference.policy import APIPolicy, InferencePolicy


class TestInferencePolicy:
"""Test base InferencePolicy abstraction."""

def test_cannot_instantiate_abstract_class(self):
"""Base class cannot be instantiated directly."""
with pytest.raises(TypeError):
InferencePolicy() # type: ignore

def test_from_client_with_async_client(self):
"""Factory method creates APIPolicy from AsyncOpenAI client."""
client = AsyncOpenAI(api_key="test", base_url="http://test")
policy = InferencePolicy.from_client(client, "test-model")

assert isinstance(policy, APIPolicy)
assert policy.model == "test-model"
assert policy.client == client

def test_from_client_with_sync_client(self):
"""Factory method creates APIPolicy from sync OpenAI client."""
client = OpenAI(api_key="test", base_url="http://test")
policy = InferencePolicy.from_client(client, "test-model")

assert isinstance(policy, APIPolicy)
assert policy.model == "test-model"
# Should wrap sync client in async client
assert isinstance(policy.client, AsyncOpenAI)


class TestAPIPolicy:
"""Test APIPolicy implementation."""

@pytest.fixture
def mock_async_client(self):
"""Create mock AsyncOpenAI client."""
client = Mock(spec=AsyncOpenAI)
client.api_key = "test-key"
client.base_url = "http://test"
return client

@pytest.fixture
def api_policy(self, mock_async_client):
"""Create APIPolicy instance."""
return APIPolicy(client=mock_async_client, model="test-model")

def test_initialization_with_async_client(self, mock_async_client):
"""APIPolicy initializes correctly with AsyncOpenAI client."""
policy = APIPolicy(client=mock_async_client, model="test-model")
assert policy.client == mock_async_client
assert policy.model == "test-model"

def test_initialization_with_sync_client(self):
"""APIPolicy wraps sync OpenAI client in AsyncOpenAI."""
sync_client = Mock(spec=OpenAI)
sync_client.api_key = "test-key"
sync_client.base_url = "http://test"

policy = APIPolicy(client=sync_client, model="test-model")

assert isinstance(policy.client, AsyncOpenAI)
assert policy.model == "test-model"

@pytest.mark.asyncio
async def test_generate_chat_format(self, api_policy, mock_async_client):
"""Generate works with chat completion format."""
# Setup mock response
mock_response = Mock(spec=ChatCompletion)
mock_async_client.chat = Mock()
mock_async_client.chat.completions = Mock()
mock_async_client.chat.completions.create = AsyncMock(return_value=mock_response)

# Test chat format
prompt = [{"role": "user", "content": "Hello"}]
result = await api_policy.generate(
prompt=prompt,
sampling_args={"temperature": 0.7}
)

# Verify call was made
mock_async_client.chat.completions.create.assert_called_once_with(
model="test-model",
messages=prompt,
temperature=0.7
)
assert result == mock_response

@pytest.mark.asyncio
async def test_generate_completion_format(self, api_policy, mock_async_client):
"""Generate works with completion format."""
# Setup mock response
mock_response = Mock(spec=Completion)
mock_async_client.completions = Mock()
mock_async_client.completions.create = AsyncMock(return_value=mock_response)

# Test completion format
prompt = "Hello, world"
result = await api_policy.generate(
prompt=prompt,
sampling_args={"temperature": 0.7}
)

# Verify call was made
mock_async_client.completions.create.assert_called_once_with(
model="test-model",
prompt=prompt,
temperature=0.7
)
assert result == mock_response

@pytest.mark.asyncio
async def test_generate_with_no_sampling_args(self, api_policy, mock_async_client):
"""Generate works without sampling args."""
mock_response = Mock(spec=ChatCompletion)
mock_async_client.chat = Mock()
mock_async_client.chat.completions = Mock()
mock_async_client.chat.completions.create = AsyncMock(return_value=mock_response)

prompt = [{"role": "user", "content": "Hello"}]
await api_policy.generate(prompt=prompt)

# Should still call with model
mock_async_client.chat.completions.create.assert_called_once()
call_kwargs = mock_async_client.chat.completions.create.call_args[1]
assert call_kwargs["model"] == "test-model"
assert call_kwargs["messages"] == prompt


class TestVLLMPolicy:
"""Test VLLMPolicy implementation."""

def test_initialization_requires_vllm(self):
"""VLLMPolicy requires vLLM to be installed."""
        # vLLM is an optional dependency; skip this test when it is not installed.
        try:
            from verifiers.inference.backends.vllm_policy import VLLMPolicy

            # Import succeeded. Full initialization would try to connect to a
            # running server, so only the import itself is exercised here.
assert VLLMPolicy is not None
except ImportError:
pytest.skip("vLLM not installed")

def test_vllm_policy_structure(self):
"""VLLMPolicy has expected methods and attributes."""
try:
from verifiers.inference.backends.vllm_policy import VLLMPolicy

# Check class has required methods
assert hasattr(VLLMPolicy, 'generate')
assert hasattr(VLLMPolicy, 'enable_weight_sync')
assert hasattr(VLLMPolicy, 'sync_weights')
assert hasattr(VLLMPolicy, 'from_client')
except ImportError:
pytest.skip("vLLM not installed")


class TestBackwardsCompatibility:
"""Test backwards compatibility patterns."""

def test_policy_can_wrap_existing_client_code(self):
"""Existing code using clients can be wrapped in policy."""
# Simulate existing code pattern
client = OpenAI(api_key="test", base_url="http://test")
model = "gpt-4"

# New code can wrap this
policy = InferencePolicy.from_client(client, model)

assert isinstance(policy, APIPolicy)
assert policy.model == model

@pytest.mark.asyncio
async def test_policy_interface_matches_client_usage(self):
"""Policy interface is similar to direct client usage."""
mock_client = Mock(spec=AsyncOpenAI)
mock_client.chat = Mock()
mock_client.chat.completions = Mock()
mock_client.chat.completions.create = AsyncMock(return_value=Mock())

policy = APIPolicy(client=mock_client, model="test")

# Both interfaces should work similarly
prompt = [{"role": "user", "content": "test"}]
sampling_args = {"temperature": 0.7}

# Policy interface
await policy.generate(prompt=prompt, sampling_args=sampling_args)

# Verify underlying client was called
assert mock_client.chat.completions.create.called
6 changes: 6 additions & 0 deletions verifiers/__init__.py
@@ -24,6 +24,7 @@
extract_hash_answer,
load_example_dataset,
)
from .inference.policy import APIPolicy, InferencePolicy
from .utils.env_utils import load_environment
from .utils.logging_utils import print_prompt_completions_sample

@@ -83,6 +84,9 @@ def setup_logging(
"StatefulToolEnv",
"ToolEnv",
"EnvGroup",
"InferencePolicy",
"APIPolicy",
"VLLMPolicy",
"extract_boxed_answer",
"extract_hash_answer",
"load_example_dataset",
@@ -110,6 +114,7 @@ def setup_logging(
"SandboxEnv": "verifiers.envs.sandbox_env:SandboxEnv",
"PythonEnv": "verifiers.envs.python_env:PythonEnv",
"TextArenaEnv": "verifiers.envs.textarena_env:TextArenaEnv",
"VLLMPolicy": "verifiers.inference.backends.vllm_policy:VLLMPolicy",
}


@@ -130,6 +135,7 @@ def __getattr__(name: str):
from .envs.python_env import PythonEnv # noqa: F401
from .envs.sandbox_env import SandboxEnv # noqa: F401
from .envs.textarena_env import TextArenaEnv # noqa: F401
from .inference.backends.vllm_policy import VLLMPolicy # noqa: F401
from .rubrics.math_rubric import MathRubric # noqa: F401
from .trainers import ( # noqa: F401
GRPOConfig,
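
For context, the `verifiers/__init__.py` changes above import `InferencePolicy` and `APIPolicy` eagerly but register `VLLMPolicy` in the lazy-load table and inside the module-level `__getattr__`. A rough usage sketch under that assumption (the full `__getattr__` body is only partially visible in this diff):

import verifiers as vf

# InferencePolicy and APIPolicy are imported eagerly in verifiers/__init__.py,
# so they are available as soon as the package is imported.
assert vf.InferencePolicy is not None
assert vf.APIPolicy is not None

# VLLMPolicy is registered for lazy loading and imported inside __getattr__,
# so it is resolved on first attribute access; a plain `import verifiers`
# should therefore not require vLLM itself to be installed.
try:
    VLLMPolicy = vf.VLLMPolicy
except ImportError:
    VLLMPolicy = None  # vLLM (or the optional extra providing it) is missing
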
1 change: 1 addition & 0 deletions verifiers/envs/environment.py
@@ -37,6 +37,7 @@
from transformers.tokenization_utils_base import ( # type: ignore
PreTrainedTokenizerBase,
)
from verifiers.inference.policy import InferencePolicy


class Environment(ABC):
2 changes: 2 additions & 0 deletions verifiers/inference/backends/__init__.py
@@ -0,0 +1,2 @@
# ABOUTME: Backend implementations for inference policies
# ABOUTME: Contains specific policy implementations for different inference backends