Skip to content

Commit ba58119

Browse files
mahiro72alexmojaki
andauthored
Add provider_url to ModelResponse and use it in cost() (#3648)
Co-authored-by: Alex Hall <alex.mojaki@gmail.com>
1 parent 5abfbaa commit ba58119

35 files changed

+514
-5
lines changed

pydantic_ai_slim/pydantic_ai/durable_exec/dbos/_model.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,10 @@ def model_name(self) -> str:
4646
def provider_name(self) -> str:
4747
return self.response.provider_name or '' # pragma: no cover
4848

49+
@property
50+
def provider_url(self) -> str | None:
51+
return self.response.provider_url # pragma: no cover
52+
4953
@property
5054
def timestamp(self) -> datetime:
5155
return self.response.timestamp # pragma: no cover

pydantic_ai_slim/pydantic_ai/durable_exec/prefect/_model.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,10 @@ def model_name(self) -> str:
5555
def provider_name(self) -> str:
5656
return self.response.provider_name or '' # pragma: no cover
5757

58+
@property
59+
def provider_url(self) -> str | None:
60+
return self.response.provider_url # pragma: no cover
61+
5862
@property
5963
def timestamp(self) -> datetime:
6064
return self.response.timestamp # pragma: no cover

pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_model.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,10 @@ def model_name(self) -> str:
6060
def provider_name(self) -> str:
6161
return self.response.provider_name or '' # pragma: no cover
6262

63+
@property
64+
def provider_url(self) -> str | None:
65+
return self.response.provider_url # pragma: no cover
66+
6367
@property
6468
def timestamp(self) -> datetime:
6569
return self.response.timestamp # pragma: no cover

pydantic_ai_slim/pydantic_ai/messages.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1249,6 +1249,9 @@ class ModelResponse:
12491249
provider_name: str | None = None
12501250
"""The name of the LLM provider that generated the response."""
12511251

1252+
provider_url: str | None = None
1253+
"""The base URL of the LLM provider that generated the response."""
1254+
12521255
provider_details: Annotated[
12531256
dict[str, Any] | None,
12541257
# `vendor_details` is deprecated, but we still want to support deserializing model responses stored in a DB before the name was changed
@@ -1337,6 +1340,17 @@ def cost(self) -> genai_types.PriceCalculation:
13371340
Uses [`genai-prices`](https://github.com/pydantic/genai-prices).
13381341
"""
13391342
assert self.model_name, 'Model name is required to calculate price'
1343+
# Try matching on provider_api_url first as this is more specific, then fall back to provider_id.
1344+
if self.provider_url:
1345+
try:
1346+
return calc_price(
1347+
self.usage,
1348+
self.model_name,
1349+
provider_api_url=self.provider_url,
1350+
genai_request_timestamp=self.timestamp,
1351+
)
1352+
except LookupError:
1353+
pass
13401354
return calc_price(
13411355
self.usage,
13421356
self.model_name,

pydantic_ai_slim/pydantic_ai/models/__init__.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -875,6 +875,7 @@ def get(self) -> ModelResponse:
875875
timestamp=self.timestamp,
876876
usage=self.usage(),
877877
provider_name=self.provider_name,
878+
provider_url=self.provider_url,
878879
provider_response_id=self.provider_response_id,
879880
provider_details=self.provider_details,
880881
finish_reason=self.finish_reason,
@@ -897,6 +898,12 @@ def provider_name(self) -> str | None:
897898
"""Get the provider name."""
898899
raise NotImplementedError()
899900

901+
@property
902+
@abstractmethod
903+
def provider_url(self) -> str | None:
904+
"""Get the provider base URL."""
905+
raise NotImplementedError()
906+
900907
@property
901908
@abstractmethod
902909
def timestamp(self) -> datetime:

pydantic_ai_slim/pydantic_ai/models/anthropic.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -561,6 +561,7 @@ def _process_response(self, response: BetaMessage) -> ModelResponse:
561561
model_name=response.model,
562562
provider_response_id=response.id,
563563
provider_name=self._provider.name,
564+
provider_url=self._provider.base_url,
564565
finish_reason=finish_reason,
565566
provider_details=provider_details,
566567
)
@@ -1298,6 +1299,11 @@ def provider_name(self) -> str:
12981299
"""Get the provider name."""
12991300
return self._provider_name
13001301

1302+
@property
1303+
def provider_url(self) -> str:
1304+
"""Get the provider base URL."""
1305+
return self._provider_url
1306+
13011307
@property
13021308
def timestamp(self) -> datetime:
13031309
"""Get the timestamp of the response."""

pydantic_ai_slim/pydantic_ai/models/bedrock.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -330,6 +330,7 @@ async def request_stream(
330330
_model_name=self.model_name,
331331
_event_stream=response['stream'],
332332
_provider_name=self._provider.name,
333+
_provider_url=self.base_url,
333334
_provider_response_id=response.get('ResponseMetadata', {}).get('RequestId', None),
334335
)
335336

@@ -381,6 +382,7 @@ async def _process_response(self, response: ConverseResponseTypeDef) -> ModelRes
381382
model_name=self.model_name,
382383
provider_response_id=response_id,
383384
provider_name=self._provider.name,
385+
provider_url=self.base_url,
384386
finish_reason=finish_reason,
385387
provider_details=provider_details,
386388
)
@@ -700,6 +702,7 @@ class BedrockStreamedResponse(StreamedResponse):
700702
_model_name: BedrockModelName
701703
_event_stream: EventStream[ConverseStreamOutputTypeDef]
702704
_provider_name: str
705+
_provider_url: str
703706
_timestamp: datetime = field(default_factory=_utils.now_utc)
704707
_provider_response_id: str | None = None
705708

@@ -787,6 +790,11 @@ def provider_name(self) -> str:
787790
"""Get the provider name."""
788791
return self._provider_name
789792

793+
@property
794+
def provider_url(self) -> str:
795+
"""Get the provider base URL."""
796+
return self._provider_url
797+
790798
@property
791799
def timestamp(self) -> datetime:
792800
return self._timestamp

pydantic_ai_slim/pydantic_ai/models/cohere.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -225,6 +225,7 @@ def _process_response(self, response: V2ChatResponse) -> ModelResponse:
225225
usage=_map_usage(response),
226226
model_name=self._model_name,
227227
provider_name=self._provider.name,
228+
provider_url=self.base_url,
228229
finish_reason=finish_reason,
229230
provider_details=provider_details,
230231
)

pydantic_ai_slim/pydantic_ai/models/function.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -347,6 +347,11 @@ def provider_name(self) -> None:
347347
"""Get the provider name."""
348348
return None
349349

350+
@property
351+
def provider_url(self) -> None:
352+
"""Get the provider base URL."""
353+
return None
354+
350355
@property
351356
def timestamp(self) -> datetime:
352357
"""Get the timestamp of the response."""

pydantic_ai_slim/pydantic_ai/models/gemini.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -148,8 +148,8 @@ def __init__(
148148

149149
@property
150150
def base_url(self) -> str:
151-
assert self._url is not None, 'URL not initialized' # pragma: no cover
152-
return self._url # pragma: no cover
151+
assert self._url is not None, 'URL not initialized'
152+
return self._url
153153

154154
@property
155155
def model_name(self) -> GeminiModelName:
@@ -298,6 +298,7 @@ def _process_response(self, response: _GeminiResponse) -> ModelResponse:
298298
usage,
299299
vendor_id=vendor_id,
300300
vendor_details=vendor_details,
301+
provider_url=self.base_url,
301302
)
302303

303304
async def _process_streamed_response(
@@ -329,6 +330,7 @@ async def _process_streamed_response(
329330
_content=content,
330331
_stream=aiter_bytes,
331332
_provider_name=self._provider.name,
333+
_provider_url=self.base_url,
332334
)
333335

334336
async def _message_to_gemini_content(
@@ -453,6 +455,7 @@ class GeminiStreamedResponse(StreamedResponse):
453455
_content: bytearray
454456
_stream: AsyncIterator[bytes]
455457
_provider_name: str
458+
_provider_url: str
456459
_timestamp: datetime = field(default_factory=_utils.now_utc, init=False)
457460

458461
async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
@@ -527,6 +530,11 @@ def provider_name(self) -> str:
527530
"""Get the provider name."""
528531
return self._provider_name
529532

533+
@property
534+
def provider_url(self) -> str:
535+
"""Get the provider base URL."""
536+
return self._provider_url
537+
530538
@property
531539
def timestamp(self) -> datetime:
532540
"""Get the timestamp of the response."""
@@ -713,6 +721,7 @@ def _process_response_from_parts(
713721
model_name: GeminiModelName,
714722
usage: usage.RequestUsage,
715723
vendor_id: str | None,
724+
provider_url: str,
716725
vendor_details: dict[str, Any] | None = None,
717726
) -> ModelResponse:
718727
items: list[ModelResponsePart] = []
@@ -731,7 +740,12 @@ def _process_response_from_parts(
731740
f'Unsupported response from Gemini, expected all parts to be function calls or text, got: {part!r}'
732741
)
733742
return ModelResponse(
734-
parts=items, usage=usage, model_name=model_name, provider_response_id=vendor_id, provider_details=vendor_details
743+
parts=items,
744+
usage=usage,
745+
model_name=model_name,
746+
provider_response_id=vendor_id,
747+
provider_details=vendor_details,
748+
provider_url=provider_url,
735749
)
736750

737751

0 commit comments

Comments
 (0)