diff --git a/docs/api/providers.md b/docs/api/providers.md index 68c124ce67..3e60b7fb62 100644 --- a/docs/api/providers.md +++ b/docs/api/providers.md @@ -34,6 +34,8 @@ ::: pydantic_ai.providers.vercel.VercelProvider +::: pydantic_ai.providers.cloudflare.CloudflareProvider + ::: pydantic_ai.providers.huggingface.HuggingFaceProvider ::: pydantic_ai.providers.moonshotai.MoonshotAIProvider diff --git a/docs/models/openai.md b/docs/models/openai.md index e28dc8374e..a626167f15 100644 --- a/docs/models/openai.md +++ b/docs/models/openai.md @@ -441,6 +441,67 @@ agent = Agent(model) ... ``` +### Cloudflare AI Gateway + +To use [Cloudflare AI Gateway](https://developers.cloudflare.com/ai-gateway/), first set up a gateway in your [Cloudflare dashboard](https://dash.cloudflare.com/?to=/:account/ai/ai-gateway) and obtain your account ID and gateway ID. + +!!! note + This provider uses Cloudflare's [unified API endpoint](https://developers.cloudflare.com/ai-gateway/usage/chat-completion/) for routing requests to multiple AI providers. For the full list of supported providers, see [Cloudflare's documentation](https://developers.cloudflare.com/ai-gateway/usage/chat-completion/#supported-providers). + +You can set the `CLOUDFLARE_ACCOUNT_ID`, `CLOUDFLARE_GATEWAY_ID`, and optionally `CLOUDFLARE_AI_GATEWAY_AUTH` environment variables and use the `cloudflare:` model name prefix: + +```python test="skip - requires actual API keys" +from pydantic_ai import Agent + +# Set via environment or in code: +# CLOUDFLARE_ACCOUNT_ID='your-account-id' +# CLOUDFLARE_GATEWAY_ID='your-gateway-id' +# OPENAI_API_KEY='your-openai-api-key' + +agent = Agent('cloudflare:openai/gpt-4o') +... 
+``` + +Or use [`CloudflareProvider`][pydantic_ai.providers.cloudflare.CloudflareProvider] directly: + +```python +from pydantic_ai import Agent +from pydantic_ai.models.openai import OpenAIChatModel +from pydantic_ai.providers.cloudflare import CloudflareProvider + +model = OpenAIChatModel( + 'openai/gpt-4o', + provider=CloudflareProvider( + account_id='your-account-id', + gateway_id='your-gateway-id', + api_key='your-openai-api-key', + ), +) +agent = Agent(model) +... +``` + +For authenticated gateways with stored API keys in Cloudflare's dashboard: + +```python +from pydantic_ai import Agent +from pydantic_ai.models.openai import OpenAIChatModel +from pydantic_ai.providers.cloudflare import CloudflareProvider + +model = OpenAIChatModel( + 'anthropic/claude-3-5-sonnet', + provider=CloudflareProvider( + account_id='your-account-id', + gateway_id='your-gateway-id', + gateway_auth_token='your-gateway-token', + ), +) +agent = Agent(model) +... +``` + +See [`CloudflareProvider`][pydantic_ai.providers.cloudflare.CloudflareProvider] for additional configuration options including BYOK modes and authenticated gateways. + ### Grok (xAI) Go to [xAI API Console](https://console.x.ai/) and create an API key. 
diff --git a/docs/models/overview.md b/docs/models/overview.md index 45af29c862..46e40f70fc 100644 --- a/docs/models/overview.md +++ b/docs/models/overview.md @@ -20,6 +20,7 @@ In addition, many providers are compatible with the OpenAI API, and can be used - [Ollama](openai.md#ollama) - [OpenRouter](openai.md#openrouter) - [Vercel AI Gateway](openai.md#vercel-ai-gateway) +- [Cloudflare AI Gateway](openai.md#cloudflare-ai-gateway) - [Perplexity](openai.md#perplexity) - [Fireworks AI](openai.md#fireworks-ai) - [Together AI](openai.md#together-ai) diff --git a/pydantic_ai_slim/pydantic_ai/models/__init__.py b/pydantic_ai_slim/pydantic_ai/models/__init__.py index 1668e04bd1..71bb36d89b 100644 --- a/pydantic_ai_slim/pydantic_ai/models/__init__.py +++ b/pydantic_ai_slim/pydantic_ai/models/__init__.py @@ -127,6 +127,16 @@ 'cerebras:qwen-3-32b', 'cerebras:qwen-3-coder-480b', 'cerebras:qwen-3-235b-a22b-thinking-2507', + 'cloudflare:anthropic/claude-3-5-sonnet', + 'cloudflare:cohere/command-r-plus', + 'cloudflare:deepseek/deepseek-chat', + 'cloudflare:google/gemini-2.0-flash', + 'cloudflare:groq/llama-3.3-70b-versatile', + 'cloudflare:mistral/mistral-large-latest', + 'cloudflare:openai/gpt-4o', + 'cloudflare:perplexity/llama-3.1-sonar-small-128k-online', + 'cloudflare:workers-ai/@cf/meta/llama-3.1-8b-instruct', + 'cloudflare:xai/grok-2-1212', 'cohere:c4ai-aya-expanse-32b', 'cohere:c4ai-aya-expanse-8b', 'cohere:command-nightly', @@ -675,6 +685,7 @@ def infer_model(model: Model | KnownModelName | str) -> Model: # noqa: C901 'azure', 'deepseek', 'cerebras', + 'cloudflare', 'fireworks', 'github', 'grok', diff --git a/pydantic_ai_slim/pydantic_ai/models/openai.py b/pydantic_ai_slim/pydantic_ai/models/openai.py index e7cf15c3dc..d74dcca3a5 100644 --- a/pydantic_ai_slim/pydantic_ai/models/openai.py +++ b/pydantic_ai_slim/pydantic_ai/models/openai.py @@ -272,6 +272,7 @@ def __init__( 'azure', 'deepseek', 'cerebras', + 'cloudflare', 'fireworks', 'github', 'grok', @@ -329,6 +330,7 @@ 
def __init__( 'azure', 'deepseek', 'cerebras', + 'cloudflare', 'fireworks', 'github', 'grok', diff --git a/pydantic_ai_slim/pydantic_ai/profiles/perplexity.py b/pydantic_ai_slim/pydantic_ai/profiles/perplexity.py new file mode 100644 index 0000000000..3a30e9ffbe --- /dev/null +++ b/pydantic_ai_slim/pydantic_ai/profiles/perplexity.py @@ -0,0 +1,8 @@ +from __future__ import annotations as _annotations + +from . import ModelProfile + + +def perplexity_model_profile(model_name: str) -> ModelProfile | None: + """Get the model profile for a Perplexity model.""" + return None diff --git a/pydantic_ai_slim/pydantic_ai/providers/__init__.py b/pydantic_ai_slim/pydantic_ai/providers/__init__.py index f71f2d94e0..6d6b78904a 100644 --- a/pydantic_ai_slim/pydantic_ai/providers/__init__.py +++ b/pydantic_ai_slim/pydantic_ai/providers/__init__.py @@ -69,6 +69,10 @@ def infer_provider_class(provider: str) -> type[Provider[Any]]: # noqa: C901 from .vercel import VercelProvider return VercelProvider + elif provider == 'cloudflare': + from .cloudflare import CloudflareProvider + + return CloudflareProvider elif provider == 'azure': from .azure import AzureProvider diff --git a/pydantic_ai_slim/pydantic_ai/providers/cerebras.py b/pydantic_ai_slim/pydantic_ai/providers/cerebras.py index 267cf41b8c..4891eeddd3 100644 --- a/pydantic_ai_slim/pydantic_ai/providers/cerebras.py +++ b/pydantic_ai_slim/pydantic_ai/providers/cerebras.py @@ -23,6 +23,35 @@ ) from _import_error +def cerebras_provider_model_profile(model_name: str) -> ModelProfile | None: + """Get the model profile for a model routed through Cerebras provider. + + This function handles model profiling for models that use Cerebras's API, + and applies Cerebras-specific settings like unsupported model parameters. 
+ """ + prefix_to_profile = {'llama': meta_model_profile, 'qwen': qwen_model_profile, 'gpt-oss': harmony_model_profile} + + profile = None + for prefix, profile_func in prefix_to_profile.items(): + model_name = model_name.lower() + if model_name.startswith(prefix): + profile = profile_func(model_name) + + # According to https://inference-docs.cerebras.ai/resources/openai#currently-unsupported-openai-features, + # Cerebras doesn't support some model settings. + unsupported_model_settings = ( + 'frequency_penalty', + 'logit_bias', + 'presence_penalty', + 'parallel_tool_calls', + 'service_tier', + ) + return OpenAIModelProfile( + json_schema_transformer=OpenAIJsonSchemaTransformer, + openai_unsupported_model_settings=unsupported_model_settings, + ).update(profile) + + class CerebrasProvider(Provider[AsyncOpenAI]): """Provider for Cerebras API.""" @@ -39,27 +68,7 @@ def client(self) -> AsyncOpenAI: return self._client def model_profile(self, model_name: str) -> ModelProfile | None: - prefix_to_profile = {'llama': meta_model_profile, 'qwen': qwen_model_profile, 'gpt-oss': harmony_model_profile} - - profile = None - for prefix, profile_func in prefix_to_profile.items(): - model_name = model_name.lower() - if model_name.startswith(prefix): - profile = profile_func(model_name) - - # According to https://inference-docs.cerebras.ai/resources/openai#currently-unsupported-openai-features, - # Cerebras doesn't support some model settings. - unsupported_model_settings = ( - 'frequency_penalty', - 'logit_bias', - 'presence_penalty', - 'parallel_tool_calls', - 'service_tier', - ) - return OpenAIModelProfile( - json_schema_transformer=OpenAIJsonSchemaTransformer, - openai_unsupported_model_settings=unsupported_model_settings, - ).update(profile) + return cerebras_provider_model_profile(model_name) @overload def __init__(self) -> None: ... 
diff --git a/pydantic_ai_slim/pydantic_ai/providers/cloudflare.py b/pydantic_ai_slim/pydantic_ai/providers/cloudflare.py new file mode 100644 index 0000000000..1bd2bc2c6d --- /dev/null +++ b/pydantic_ai_slim/pydantic_ai/providers/cloudflare.py @@ -0,0 +1,278 @@ +from __future__ import annotations as _annotations + +import os +from typing import overload + +import httpx + +from pydantic_ai import ModelProfile +from pydantic_ai.exceptions import UserError +from pydantic_ai.models import cached_async_http_client +from pydantic_ai.profiles.anthropic import anthropic_model_profile +from pydantic_ai.profiles.cohere import cohere_model_profile +from pydantic_ai.profiles.deepseek import deepseek_model_profile +from pydantic_ai.profiles.google import google_model_profile +from pydantic_ai.profiles.grok import grok_model_profile +from pydantic_ai.profiles.mistral import mistral_model_profile +from pydantic_ai.profiles.openai import OpenAIJsonSchemaTransformer, OpenAIModelProfile, openai_model_profile +from pydantic_ai.profiles.perplexity import perplexity_model_profile +from pydantic_ai.providers import Provider + +from .cerebras import cerebras_provider_model_profile +from .groq import groq_provider_model_profile + +try: + from openai import AsyncOpenAI +except ImportError as _import_error: # pragma: no cover + raise ImportError( + 'Please install the `openai` package to use the Cloudflare provider, ' + 'you can use the `openai` optional group — `pip install "pydantic-ai-slim[openai]"`' + ) from _import_error + + +class CloudflareProvider(Provider[AsyncOpenAI]): + """Provider for Cloudflare AI Gateway API. + + Cloudflare AI Gateway provides a unified OpenAI-compatible endpoint that routes + requests to various AI providers while adding features like caching, rate limiting, + analytics, and logging. + + !!! note + This provider uses Cloudflare's unified API endpoint for routing requests. 
+ For the full list of supported providers, see + [Cloudflare's documentation](https://developers.cloudflare.com/ai-gateway/usage/chat-completion/#supported-providers). + + This provider looks for these environment variables if they are not provided as parameters: + - account_id: `CLOUDFLARE_ACCOUNT_ID` + - gateway_id: `CLOUDFLARE_GATEWAY_ID` + - gateway_auth_token: `CLOUDFLARE_AI_GATEWAY_AUTH` (optional) + + There are three usage modes: + + 1. User-managed keys with unauthenticated gateway: + ```python + from pydantic_ai import Agent + from pydantic_ai.models.openai import OpenAIChatModel + from pydantic_ai.providers.cloudflare import CloudflareProvider + + model = OpenAIChatModel( + 'openai/gpt-4o', + provider=CloudflareProvider( + account_id='your-account-id', + gateway_id='your-gateway-id', + api_key='your-openai-api-key', + ), + ) + agent = Agent(model) + ``` + + 2. User-managed keys with authenticated gateway (API key + gateway authentication): + ```python + from pydantic_ai import Agent + from pydantic_ai.models.openai import OpenAIChatModel + from pydantic_ai.providers.cloudflare import CloudflareProvider + + model = OpenAIChatModel( + 'anthropic/claude-3-5-sonnet', + provider=CloudflareProvider( + account_id='your-account-id', + gateway_id='your-gateway-id', + api_key='your-openai-api-key', + gateway_auth_token='your-gateway-token', + ), + ) + agent = Agent(model) + ``` + + 3. 
CF-managed keys mode (use API keys stored in Cloudflare dashboard): + ```python + from pydantic_ai import Agent + from pydantic_ai.models.openai import OpenAIChatModel + from pydantic_ai.providers.cloudflare import CloudflareProvider + + model = OpenAIChatModel( + 'openai/gpt-4o', + provider=CloudflareProvider( + account_id='your-account-id', + gateway_id='your-gateway-id', + gateway_auth_token='your-gateway-token', + ), + ) + agent = Agent(model) + ``` + """ + + @property + def name(self) -> str: + return 'cloudflare' + + @property + def base_url(self) -> str: + return self._base_url + + @property + def client(self) -> AsyncOpenAI: + return self._client + + def model_profile(self, model_name: str) -> ModelProfile | None: + """Return the model profile for the given model name. + + Model names should be in the format 'provider/model', e.g., 'openai/gpt-4o', + 'anthropic/claude-3-5-sonnet', 'groq/llama-3.3-70b-versatile'. + + For the full list of supported providers, see + [Cloudflare's documentation](https://developers.cloudflare.com/ai-gateway/usage/chat-completion/#supported-providers). 
+ """ + provider_to_profile = { + 'anthropic': anthropic_model_profile, + 'openai': openai_model_profile, + 'groq': groq_provider_model_profile, + 'mistral': mistral_model_profile, + 'cohere': cohere_model_profile, + 'deepseek': deepseek_model_profile, + 'perplexity': perplexity_model_profile, + 'workers-ai': openai_model_profile, # Cloudflare Workers AI uses OpenAI-compatible API + 'workersai': openai_model_profile, # Alternative naming + 'google-ai-studio': google_model_profile, + 'grok': grok_model_profile, + 'xai': grok_model_profile, # xai is an alias for grok + 'cerebras': cerebras_provider_model_profile, + } + + profile = None + + try: + provider, model_name = model_name.split('/', 1) + except ValueError: + raise UserError(f"Model name must be in 'provider/model' format, got: {model_name!r}") + + if provider in provider_to_profile: + profile = provider_to_profile[provider](model_name) + + # As CloudflareProvider is always used with OpenAIChatModel, which used to unconditionally use OpenAIJsonSchemaTransformer, + # we need to maintain that behavior unless json_schema_transformer is set explicitly + return OpenAIModelProfile( + json_schema_transformer=OpenAIJsonSchemaTransformer, + ).update(profile) + + # Scenario 1: User-managed keys with unauthenticated gateway (api_key required) + @overload + def __init__(self, *, account_id: str, gateway_id: str, api_key: str) -> None: ... + + @overload + def __init__(self, *, account_id: str, gateway_id: str, api_key: str, http_client: httpx.AsyncClient) -> None: ... + + # Scenario 2: User-managed keys with authenticated gateway (api_key + gateway_auth_token) + @overload + def __init__(self, *, account_id: str, gateway_id: str, api_key: str, gateway_auth_token: str) -> None: ... + + @overload + def __init__( + self, + *, + account_id: str, + gateway_id: str, + api_key: str, + gateway_auth_token: str, + http_client: httpx.AsyncClient, + ) -> None: ... 
+ + # Scenario 3: CF-managed keys with authenticated gateway (no api_key, gateway_auth_token required) + @overload + def __init__(self, *, account_id: str, gateway_id: str, gateway_auth_token: str) -> None: ... + + @overload + def __init__( + self, + *, + account_id: str, + gateway_id: str, + gateway_auth_token: str, + http_client: httpx.AsyncClient, + ) -> None: ... + + # Advanced: Pre-configured OpenAI client + @overload + def __init__(self, *, account_id: str, gateway_id: str, openai_client: AsyncOpenAI) -> None: ... + + def __init__( + self, + *, + account_id: str | None = None, + gateway_id: str | None = None, + api_key: str | None = None, + gateway_auth_token: str | None = None, + openai_client: AsyncOpenAI | None = None, + http_client: httpx.AsyncClient | None = None, + ) -> None: + """Initialize the Cloudflare AI Gateway provider. + + Args: + account_id: Your Cloudflare account ID. Can also be set via CLOUDFLARE_ACCOUNT_ID environment variable. + gateway_id: Your Cloudflare AI Gateway ID. Can also be set via CLOUDFLARE_GATEWAY_ID environment variable. + api_key: The API key for the upstream provider (OpenAI, Anthropic, etc.). + - Required for user-managed mode + - Omit this (along with providing gateway_auth_token) to use CF-managed keys mode + - Optional when using the openai_client parameter (pre-configured client) + gateway_auth_token: Authorization token for authenticated gateways. + - Required for CF-managed keys mode (when api_key is omitted) + - Optional for user-managed mode (provides additional gateway authentication) + - Can also be set via CLOUDFLARE_AI_GATEWAY_AUTH environment variable + openai_client: Optional pre-configured AsyncOpenAI client for advanced use cases. + http_client: Optional HTTP client to use for requests. + + Raises: + UserError: If configuration is invalid (e.g., neither api_key nor CF-managed keys mode is configured). 
+ """ + account_id = account_id or os.getenv('CLOUDFLARE_ACCOUNT_ID') + gateway_id = gateway_id or os.getenv('CLOUDFLARE_GATEWAY_ID') + + if not account_id: + raise UserError( + 'Set the `CLOUDFLARE_ACCOUNT_ID` environment variable ' + 'or pass it via `CloudflareProvider(account_id=...)` to use the Cloudflare provider.' + ) + + if not gateway_id: + raise UserError( + 'Set the `CLOUDFLARE_GATEWAY_ID` environment variable ' + 'or pass it via `CloudflareProvider(gateway_id=...)` to use the Cloudflare provider.' + ) + + gateway_auth_token = gateway_auth_token or os.getenv('CLOUDFLARE_AI_GATEWAY_AUTH') + + # Detect CF-managed keys mode: no api_key provided + gateway_auth_token present + no pre-configured client + use_cf_managed_keys = api_key is None and gateway_auth_token is not None and openai_client is None + + if use_cf_managed_keys: + # CF-managed keys mode: use API keys stored in Cloudflare dashboard + # Use empty string for AsyncOpenAI - this prevents the Authorization header from being sent + api_key = '' + elif api_key is None and openai_client is None: + # Not using CF-managed keys, so api_key is required (unless using pre-configured openai_client) + raise UserError( + 'You must provide an api_key for user-managed mode.\n' + 'To use API keys stored in your Cloudflare dashboard (CF-managed), omit api_key and provide gateway_auth_token instead.' 
+ ) + + self._base_url = f'https://gateway.ai.cloudflare.com/v1/{account_id}/{gateway_id}/compat' + + default_headers = { + 'http-referer': 'https://ai.pydantic.dev/', + 'x-title': 'pydantic-ai', + } + + if gateway_auth_token: + default_headers['cf-aig-authorization'] = gateway_auth_token + + if openai_client is not None: + self._client = openai_client + elif http_client is not None: + self._client = AsyncOpenAI( + base_url=self._base_url, api_key=api_key, http_client=http_client, default_headers=default_headers + ) + else: + http_client = cached_async_http_client(provider='cloudflare') + self._client = AsyncOpenAI( + base_url=self._base_url, api_key=api_key, http_client=http_client, default_headers=default_headers + ) diff --git a/pydantic_ai_slim/pydantic_ai/providers/groq.py b/pydantic_ai_slim/pydantic_ai/providers/groq.py index f0e5c5b53b..56604ef8ca 100644 --- a/pydantic_ai_slim/pydantic_ai/providers/groq.py +++ b/pydantic_ai_slim/pydantic_ai/providers/groq.py @@ -44,6 +44,34 @@ def meta_groq_model_profile(model_name: str) -> ModelProfile | None: return meta_model_profile(model_name) +def groq_provider_model_profile(model_name: str) -> ModelProfile | None: + """Get the model profile for a model routed through Groq provider. + + This function handles model profiling for models that use Groq's API, + including various model families like Llama, Gemma, Qwen, etc. 
+ """ + prefix_to_profile = { + 'llama': meta_model_profile, + 'meta-llama/': meta_groq_model_profile, + 'gemma': google_model_profile, + 'qwen': qwen_model_profile, + 'deepseek': deepseek_model_profile, + 'mistral': mistral_model_profile, + 'moonshotai/': groq_moonshotai_model_profile, + 'compound-': groq_model_profile, + 'openai/': openai_model_profile, + } + + for prefix, profile_func in prefix_to_profile.items(): + model_name = model_name.lower() + if model_name.startswith(prefix): + if prefix.endswith('/'): + model_name = model_name[len(prefix) :] + return profile_func(model_name) + + return None + + class GroqProvider(Provider[AsyncGroq]): """Provider for Groq API.""" @@ -60,26 +88,7 @@ def client(self) -> AsyncGroq: return self._client def model_profile(self, model_name: str) -> ModelProfile | None: - prefix_to_profile = { - 'llama': meta_model_profile, - 'meta-llama/': meta_groq_model_profile, - 'gemma': google_model_profile, - 'qwen': qwen_model_profile, - 'deepseek': deepseek_model_profile, - 'mistral': mistral_model_profile, - 'moonshotai/': groq_moonshotai_model_profile, - 'compound-': groq_model_profile, - 'openai/': openai_model_profile, - } - - for prefix, profile_func in prefix_to_profile.items(): - model_name = model_name.lower() - if model_name.startswith(prefix): - if prefix.endswith('/'): - model_name = model_name[len(prefix) :] - return profile_func(model_name) - - return None + return groq_provider_model_profile(model_name) @overload def __init__(self, *, groq_client: AsyncGroq | None = None) -> None: ... 
diff --git a/tests/models/test_model_names.py b/tests/models/test_model_names.py index b27aa2d8c2..28c58b56ab 100644 --- a/tests/models/test_model_names.py +++ b/tests/models/test_model_names.py @@ -70,6 +70,18 @@ def get_model_names(model_name_type: Any) -> Iterator[str]: openai_names = [f'openai:{n}' for n in get_model_names(OpenAIModelName)] bedrock_names = [f'bedrock:{n}' for n in get_model_names(BedrockModelName)] deepseek_names = ['deepseek:deepseek-chat', 'deepseek:deepseek-reasoner'] + cloudflare_names = [ + 'cloudflare:anthropic/claude-3-5-sonnet', + 'cloudflare:cohere/command-r-plus', + 'cloudflare:deepseek/deepseek-chat', + 'cloudflare:google/gemini-2.0-flash', + 'cloudflare:groq/llama-3.3-70b-versatile', + 'cloudflare:mistral/mistral-large-latest', + 'cloudflare:openai/gpt-4o', + 'cloudflare:perplexity/llama-3.1-sonar-small-128k-online', + 'cloudflare:workers-ai/@cf/meta/llama-3.1-8b-instruct', + 'cloudflare:xai/grok-2-1212', + ] huggingface_names = [f'huggingface:{n}' for n in get_model_names(HuggingFaceModelName)] heroku_names = get_heroku_model_names() cerebras_names = get_cerebras_model_names() @@ -86,6 +98,7 @@ def get_model_names(model_name_type: Any) -> Iterator[str]: + openai_names + bedrock_names + deepseek_names + + cloudflare_names + huggingface_names + heroku_names + cerebras_names diff --git a/tests/providers/test_cloudflare.py b/tests/providers/test_cloudflare.py new file mode 100644 index 0000000000..42564cae0f --- /dev/null +++ b/tests/providers/test_cloudflare.py @@ -0,0 +1,310 @@ +import re + +import httpx +import pytest +from pytest_mock import MockerFixture + +from pydantic_ai import Agent +from pydantic_ai._json_schema import InlineDefsJsonSchemaTransformer +from pydantic_ai.exceptions import UserError +from pydantic_ai.profiles.anthropic import anthropic_model_profile +from pydantic_ai.profiles.cohere import cohere_model_profile +from pydantic_ai.profiles.deepseek import deepseek_model_profile +from pydantic_ai.profiles.google 
import GoogleJsonSchemaTransformer, google_model_profile +from pydantic_ai.profiles.grok import grok_model_profile +from pydantic_ai.profiles.mistral import mistral_model_profile +from pydantic_ai.profiles.openai import OpenAIJsonSchemaTransformer, openai_model_profile +from pydantic_ai.profiles.perplexity import perplexity_model_profile + +from ..conftest import TestEnv, try_import + +with try_import() as imports_successful: + import openai + + from pydantic_ai.providers.cloudflare import CloudflareProvider + + +pytestmark = [ + pytest.mark.skipif(not imports_successful(), reason='openai not installed'), + pytest.mark.vcr, + pytest.mark.anyio, +] + + +def test_cloudflare_provider(): + provider = CloudflareProvider(account_id='test-account-id', gateway_id='test-gateway-id', api_key='api-key') + assert provider.name == 'cloudflare' + assert provider.base_url == 'https://gateway.ai.cloudflare.com/v1/test-account-id/test-gateway-id/compat' + assert isinstance(provider.client, openai.AsyncOpenAI) + assert provider.client.api_key == 'api-key' + + +def test_cloudflare_provider_need_account_id(env: TestEnv) -> None: + env.remove('CLOUDFLARE_ACCOUNT_ID') + with pytest.raises( + UserError, + match=re.escape( + 'Set the `CLOUDFLARE_ACCOUNT_ID` environment variable ' + 'or pass it via `CloudflareProvider(account_id=...)` to use the Cloudflare provider.' + ), + ): + CloudflareProvider(gateway_id='test-gateway-id', api_key='api-key') # type: ignore[call-overload] + + +def test_cloudflare_provider_need_gateway_id(env: TestEnv) -> None: + env.remove('CLOUDFLARE_GATEWAY_ID') + with pytest.raises( + UserError, + match=re.escape( + 'Set the `CLOUDFLARE_GATEWAY_ID` environment variable ' + 'or pass it via `CloudflareProvider(gateway_id=...)` to use the Cloudflare provider.' 
+ ), + ): + CloudflareProvider(account_id='test-account-id', api_key='api-key') # type: ignore[call-overload] + + +def test_cloudflare_provider_from_env(env: TestEnv) -> None: + env.set('CLOUDFLARE_ACCOUNT_ID', 'env-account-id') + env.set('CLOUDFLARE_GATEWAY_ID', 'env-gateway-id') + + # Test with explicit api_key (account_id and gateway_id from env) + provider = CloudflareProvider(api_key='env-api-key') # type: ignore[call-overload] + assert provider.base_url == 'https://gateway.ai.cloudflare.com/v1/env-account-id/env-gateway-id/compat' + assert provider.client.api_key == 'env-api-key' + + +def test_cloudflare_provider_with_gateway_auth_token(): + provider = CloudflareProvider( + account_id='test-account-id', + gateway_id='test-gateway-id', + api_key='api-key', + gateway_auth_token='gateway-token', + ) + assert provider.client.default_headers['cf-aig-authorization'] == 'gateway-token' + + +def test_cloudflare_provider_gateway_auth_token_from_env(env: TestEnv) -> None: + env.set('CLOUDFLARE_ACCOUNT_ID', 'test-account-id') + env.set('CLOUDFLARE_GATEWAY_ID', 'test-gateway-id') + env.set('CLOUDFLARE_AI_GATEWAY_AUTH', 'env-gateway-token') + + provider = CloudflareProvider(api_key='api-key') # type: ignore[call-overload] + assert provider.client.default_headers['cf-aig-authorization'] == 'env-gateway-token' + + +def test_cloudflare_pass_openai_client() -> None: + openai_client = openai.AsyncOpenAI(api_key='api-key') + provider = CloudflareProvider( + account_id='test-account-id', gateway_id='test-gateway-id', openai_client=openai_client + ) + assert provider.client == openai_client + + +def test_cloudflare_provider_model_profile(mocker: MockerFixture, env: TestEnv): + # Set dummy API keys so we can use real GroqProvider and CerebrasProvider + env.set('GROQ_API_KEY', 'test-groq-key') + env.set('CEREBRAS_API_KEY', 'test-cerebras-key') + + provider = CloudflareProvider(account_id='test-account-id', gateway_id='test-gateway-id', api_key='api-key') + + ns = 
'pydantic_ai.providers.cloudflare' + + # Mock all profile functions + anthropic_mock = mocker.patch(f'{ns}.anthropic_model_profile', wraps=anthropic_model_profile) + cohere_mock = mocker.patch(f'{ns}.cohere_model_profile', wraps=cohere_model_profile) + deepseek_mock = mocker.patch(f'{ns}.deepseek_model_profile', wraps=deepseek_model_profile) + google_mock = mocker.patch(f'{ns}.google_model_profile', wraps=google_model_profile) + grok_mock = mocker.patch(f'{ns}.grok_model_profile', wraps=grok_model_profile) + mistral_mock = mocker.patch(f'{ns}.mistral_model_profile', wraps=mistral_model_profile) + openai_mock = mocker.patch(f'{ns}.openai_model_profile', wraps=openai_model_profile) + perplexity_mock = mocker.patch(f'{ns}.perplexity_model_profile', wraps=perplexity_model_profile) + + # Use real GroqProvider and CerebrasProvider - they don't make API calls for model_profile() + # We just need dummy API keys which are set via env vars above + + # Test openai provider + profile = provider.model_profile('openai/gpt-4o') + openai_mock.assert_called_with('gpt-4o') + assert profile is not None + assert profile.json_schema_transformer == OpenAIJsonSchemaTransformer + + # Test anthropic provider + profile = provider.model_profile('anthropic/claude-3-sonnet') + anthropic_mock.assert_called_with('claude-3-sonnet') + assert profile is not None + assert profile.json_schema_transformer == OpenAIJsonSchemaTransformer + + # Test cohere provider + profile = provider.model_profile('cohere/command-r-plus') + cohere_mock.assert_called_with('command-r-plus') + assert profile is not None + assert profile.json_schema_transformer == OpenAIJsonSchemaTransformer + + # Test deepseek provider + profile = provider.model_profile('deepseek/deepseek-chat') + deepseek_mock.assert_called_with('deepseek-chat') + assert profile is not None + assert profile.json_schema_transformer == OpenAIJsonSchemaTransformer + + # Test mistral provider + profile = provider.model_profile('mistral/mistral-large') + 
mistral_mock.assert_called_with('mistral-large') + assert profile is not None + assert profile.json_schema_transformer == OpenAIJsonSchemaTransformer + + # Test google-ai-studio provider + profile = provider.model_profile('google-ai-studio/gemini-1.5-pro') + google_mock.assert_called_with('gemini-1.5-pro') + assert profile is not None + assert profile.json_schema_transformer == GoogleJsonSchemaTransformer + + # Test grok provider + profile = provider.model_profile('grok/grok-beta') + grok_mock.assert_called_with('grok-beta') + assert profile is not None + assert profile.json_schema_transformer == OpenAIJsonSchemaTransformer + + # Test xai provider (alias for grok) + profile = provider.model_profile('xai/grok-2') + grok_mock.assert_called_with('grok-2') + assert profile is not None + assert profile.json_schema_transformer == OpenAIJsonSchemaTransformer + + # Test groq provider with llama model (delegates to GroqProvider which returns meta profile) + # meta_model_profile uses InlineDefsJsonSchemaTransformer + profile = provider.model_profile('groq/llama-3.3-70b-versatile') + assert profile is not None + assert profile.json_schema_transformer == InlineDefsJsonSchemaTransformer + + # Test groq provider with gemma model (delegates to GroqProvider which returns google profile) + # google_model_profile uses GoogleJsonSchemaTransformer + profile = provider.model_profile('groq/gemma-7b-it') + assert profile is not None + assert profile.json_schema_transformer == GoogleJsonSchemaTransformer + + # Test perplexity provider (currently returns None, falls back to OpenAI-compatible) + profile = provider.model_profile('perplexity/llama-3.1-sonar-small-128k-online') + perplexity_mock.assert_called_with('llama-3.1-sonar-small-128k-online') + assert profile is not None + assert profile.json_schema_transformer == OpenAIJsonSchemaTransformer + + # Test workers-ai provider (Cloudflare's own AI service) + profile = provider.model_profile('workers-ai/@cf/meta/llama-3.1-8b-instruct') + 
openai_mock.assert_called_with('@cf/meta/llama-3.1-8b-instruct') + assert profile is not None + assert profile.json_schema_transformer == OpenAIJsonSchemaTransformer + + # Test cerebras provider with llama model (delegates to CerebrasProvider which returns meta profile) + # meta_model_profile uses InlineDefsJsonSchemaTransformer, wrapped by CerebrasProvider's OpenAIModelProfile + profile = provider.model_profile('cerebras/llama3.1-8b') + assert profile is not None + assert profile.json_schema_transformer == InlineDefsJsonSchemaTransformer + + # Test cerebras provider with qwen model (delegates to CerebrasProvider which returns qwen profile) + # qwen_model_profile uses InlineDefsJsonSchemaTransformer, wrapped by CerebrasProvider's OpenAIModelProfile + profile = provider.model_profile('cerebras/qwen3.5-8b') + assert profile is not None + assert profile.json_schema_transformer == InlineDefsJsonSchemaTransformer + + +def test_cloudflare_with_http_client(): + http_client = httpx.AsyncClient() + provider = CloudflareProvider( + account_id='test-account-id', gateway_id='test-gateway-id', api_key='test-key', http_client=http_client + ) + assert provider.client.api_key == 'test-key' + assert ( + str(provider.client.base_url) == 'https://gateway.ai.cloudflare.com/v1/test-account-id/test-gateway-id/compat/' + ) + + +def test_cloudflare_provider_invalid_model_name(): + provider = CloudflareProvider(account_id='test-account-id', gateway_id='test-gateway-id', api_key='api-key') + + with pytest.raises(UserError, match="Model name must be in 'provider/model' format"): + provider.model_profile('invalid-model-name') + + +def test_cloudflare_provider_unknown_provider(): + provider = CloudflareProvider(account_id='test-account-id', gateway_id='test-gateway-id', api_key='api-key') + + profile = provider.model_profile('unknown/gpt-4') + assert profile is not None + assert profile.json_schema_transformer == OpenAIJsonSchemaTransformer + + +def test_cloudflare_default_headers(): + 
provider = CloudflareProvider(account_id='test-account-id', gateway_id='test-gateway-id', api_key='api-key') + + # Check that default headers are set + assert provider.client.default_headers['http-referer'] == 'https://ai.pydantic.dev/' + assert provider.client.default_headers['x-title'] == 'pydantic-ai' + + +def test_cloudflare_provider_stored_keys(): + """Test CF-managed keys mode - API keys stored in Cloudflare dashboard (requires authenticated gateway).""" + provider = CloudflareProvider( + account_id='test-account-id', + gateway_id='test-gateway-id', + gateway_auth_token='gateway-token', + ) + # api_key is set to empty string for AsyncOpenAI to prevent Authorization header + assert provider.client.api_key == '' + assert provider.base_url == 'https://gateway.ai.cloudflare.com/v1/test-account-id/test-gateway-id/compat' + assert provider.client.default_headers['cf-aig-authorization'] == 'gateway-token' + + +def test_cloudflare_provider_missing_credentials(): + """Test that error is raised when api_key is missing and not in CF-managed keys mode.""" + with pytest.raises( + UserError, + match=re.escape('You must provide an api_key for user-managed mode.'), + ): + CloudflareProvider(account_id='test-account-id', gateway_id='test-gateway-id') # type: ignore[call-overload] + + +def test_cloudflare_stored_keys_no_auth_header(): + """Test that Authorization header is not sent in CF-managed keys mode (empty api_key).""" + provider = CloudflareProvider( + account_id='test-account-id', + gateway_id='test-gateway-id', + gateway_auth_token='gateway-token', + ) + + # In CF-managed keys mode, api_key is empty string which prevents OpenAI SDK from adding Authorization header + assert provider.client.api_key == '' + assert provider.client.default_headers['cf-aig-authorization'] == 'gateway-token' + + +def test_cloudflare_documented_patterns(): + """Test the exact usage patterns from the documentation work correctly. 
This test validates that the examples shown in docs/models/openai.md work as documented.
CaptureFixture[str]): 'anthropic', 'bedrock', 'cerebras', + 'cloudflare', 'google-vertex', 'google-gla', 'groq', diff --git a/tests/test_examples.py b/tests/test_examples.py index 87649e44b3..6e32feb532 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -170,6 +170,8 @@ def print(self, *args: Any, **kwargs: Any) -> None: env.set('AWS_DEFAULT_REGION', 'us-east-1') env.set('VERCEL_AI_GATEWAY_API_KEY', 'testing') env.set('CEREBRAS_API_KEY', 'testing') + env.set('CLOUDFLARE_ACCOUNT_ID', 'testing') + env.set('CLOUDFLARE_GATEWAY_ID', 'testing') env.set('NEBIUS_API_KEY', 'testing') env.set('HEROKU_INFERENCE_KEY', 'testing') env.set('FIREWORKS_API_KEY', 'testing')