Skip to content
69 changes: 51 additions & 18 deletions pydantic_ai_slim/pydantic_ai/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
)
from ..output import OutputMode
from ..profiles import DEFAULT_PROFILE, ModelProfile, ModelProfileSpec
from ..providers import infer_provider
from ..providers import Provider, infer_provider, infer_provider_class
from ..settings import ModelSettings, merge_model_settings
from ..tools import ToolDefinition
from ..usage import RequestUsage
Expand Down Expand Up @@ -636,14 +636,26 @@ def override_allow_model_requests(allow_model_requests: bool) -> Iterator[None]:
ALLOW_MODEL_REQUESTS = old_value # pyright: ignore[reportConstantRedefinition]


def infer_model(model: Model | KnownModelName | str) -> Model: # noqa: C901
"""Infer the model from the name."""
if isinstance(model, Model):
return model
elif model == 'test':
@dataclass
class ModelClassInformation:
    """Metadata parsed from a model string: the model class and provider class needed for instantiation.

    Produced by `infer_provider_model_class`, which parses a `'provider:model_name'`
    string without instantiating anything, so callers can construct the model
    (and optionally its provider) themselves.
    """

    model_class: type[Model]
    """The raw model class (not instantiated)."""
    provider_class: type[Provider[Any]] | None
    """The raw provider class, or `None` when no provider class applies (e.g. `TestModel` or `gateway/`-prefixed providers)."""
    model_name: str
    """The model name as parsed from input string"""
    provider_name: str | None
    """The provider name as parsed from input string, or `None` when no provider was given (e.g. the `'test'` model)."""


def infer_provider_model_class(model: KnownModelName | str) -> ModelClassInformation: # C901
"""Infer the model and provider from the name."""
if model == 'test':
from .test import TestModel

return TestModel()
return ModelClassInformation(model_class=TestModel, provider_class=None, model_name='test', provider_name=None)

try:
provider_name, model_name = model.split(':', maxsplit=1)
Expand Down Expand Up @@ -672,8 +684,6 @@ def infer_model(model: Model | KnownModelName | str) -> Model: # noqa: C901
)
provider_name = 'google-vertex'

provider = infer_provider(provider_name)

model_kind = provider_name
if model_kind.startswith('gateway/'):
model_kind = provider_name.removeprefix('gateway/')
Expand All @@ -688,6 +698,7 @@ def infer_model(model: Model | KnownModelName | str) -> Model: # noqa: C901
'heroku',
'moonshotai',
'ollama',
'openai-chat',
'openrouter',
'together',
'vercel',
Expand All @@ -699,45 +710,67 @@ def infer_model(model: Model | KnownModelName | str) -> Model: # noqa: C901
elif model_kind in ('google-gla', 'google-vertex'):
model_kind = 'google'

inferred_model: type[Model]
if model_kind == 'openai-chat':
from .openai import OpenAIChatModel

return OpenAIChatModel(model_name, provider=provider)
inferred_model = OpenAIChatModel
elif model_kind == 'openai-responses':
from .openai import OpenAIResponsesModel

return OpenAIResponsesModel(model_name, provider=provider)
inferred_model = OpenAIResponsesModel
elif model_kind == 'google':
from .google import GoogleModel

return GoogleModel(model_name, provider=provider)
inferred_model = GoogleModel
elif model_kind == 'groq':
from .groq import GroqModel

return GroqModel(model_name, provider=provider)
inferred_model = GroqModel
elif model_kind == 'cohere':
from .cohere import CohereModel

return CohereModel(model_name, provider=provider)
inferred_model = CohereModel
elif model_kind == 'mistral':
from .mistral import MistralModel

return MistralModel(model_name, provider=provider)
inferred_model = MistralModel
elif model_kind == 'anthropic':
from .anthropic import AnthropicModel

return AnthropicModel(model_name, provider=provider)
inferred_model = AnthropicModel
elif model_kind == 'bedrock':
from .bedrock import BedrockConverseModel

return BedrockConverseModel(model_name, provider=provider)
inferred_model = BedrockConverseModel
elif model_kind == 'huggingface':
from .huggingface import HuggingFaceModel

return HuggingFaceModel(model_name, provider=provider)
inferred_model = HuggingFaceModel
else:
raise UserError(f'Unknown model: {model}') # pragma: no cover

return ModelClassInformation(
model_class=inferred_model,
provider_class=infer_provider_class(provider_name) if not provider_name.startswith('gateway/') else None,
model_name=model_name,
provider_name=provider_name,
)


def infer_model(model: Model | KnownModelName | str) -> Model:
    """Infer the model from the name."""
    # Already-constructed models pass straight through.
    if isinstance(model, Model):
        return model

    info = infer_provider_model_class(model)

    # No provider name (e.g. 'test') means the model class takes no arguments.
    if info.provider_name is None:
        return info.model_class()

    # Otherwise instantiate the provider from its name and hand it to the model class.
    provider = infer_provider(info.provider_name)
    return info.model_class(model_name=info.model_name, provider=provider)


def cached_async_http_client(*, provider: str | None = None, timeout: int = 600, connect: int = 5) -> httpx.AsyncClient:
"""Cached HTTPX async client that creates a separate client for each provider.
Expand Down
2 changes: 1 addition & 1 deletion pydantic_ai_slim/pydantic_ai/providers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def __repr__(self) -> str:
return f'{self.__class__.__name__}(name={self.name}, base_url={self.base_url})' # pragma: lax no cover


def infer_provider_class(provider: str) -> type[Provider[Any]]: # noqa: C901
def infer_provider_class(provider: str) -> type[Provider[Any]]: # C901
"""Infers the provider class from the provider name."""
if provider in ('openai', 'openai-chat', 'openai-responses'):
from .openai import OpenAIProvider
Expand Down
Loading