Commit c5b1495

Extract openai usage using genai-prices (#3123)
1 parent 0351a14

4 files changed: +52 −53 lines


pydantic_ai_slim/pydantic_ai/models/openai.py (42 additions, 45 deletions)
@@ -600,7 +600,7 @@ def _process_response(self, response: chat.ChatCompletion | str) -> ModelResponse
 
         return ModelResponse(
             parts=items,
-            usage=_map_usage(response),
+            usage=_map_usage(response, self._provider.name, self._provider.base_url, self._model_name),
             model_name=response.model,
             timestamp=timestamp,
             provider_details=vendor_details or None,
@@ -631,6 +631,7 @@ async def _process_streamed_response(
             _response=peekable_response,
             _timestamp=number_to_datetime(first_chunk.created),
             _provider_name=self._provider.name,
+            _provider_url=self._provider.base_url,
         )
 
     def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[chat.ChatCompletionToolParam]:
@@ -1061,7 +1062,7 @@ def _process_response(  # noqa: C901
 
         return ModelResponse(
             parts=items,
-            usage=_map_usage(response),
+            usage=_map_usage(response, self._provider.name, self._provider.base_url, self._model_name),
             model_name=response.model,
             provider_response_id=response.id,
             timestamp=timestamp,
@@ -1088,6 +1089,7 @@ async def _process_streamed_response(
             _response=peekable_response,
             _timestamp=number_to_datetime(first_chunk.response.created_at),
             _provider_name=self._provider.name,
+            _provider_url=self._provider.base_url,
         )
 
     @overload
@@ -1589,10 +1591,11 @@ class OpenAIStreamedResponse(StreamedResponse):
     _response: AsyncIterable[ChatCompletionChunk]
     _timestamp: datetime
     _provider_name: str
+    _provider_url: str
 
     async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
         async for chunk in self._response:
-            self._usage += _map_usage(chunk)
+            self._usage += _map_usage(chunk, self._provider_name, self._provider_url, self._model_name)
 
             if chunk.id:  # pragma: no branch
                 self.provider_response_id = chunk.id
@@ -1683,12 +1686,13 @@ class OpenAIResponsesStreamedResponse(StreamedResponse):
     _response: AsyncIterable[responses.ResponseStreamEvent]
     _timestamp: datetime
     _provider_name: str
+    _provider_url: str
 
     async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:  # noqa: C901
         async for chunk in self._response:
             # NOTE: You can inspect the builtin tools used checking the `ResponseCompletedEvent`.
             if isinstance(chunk, responses.ResponseCompletedEvent):
-                self._usage += _map_usage(chunk.response)
+                self._usage += self._map_usage(chunk.response)
 
                 raw_finish_reason = (
                     details.reason if (details := chunk.response.incomplete_details) else chunk.response.status
@@ -1708,7 +1712,7 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
                 self.provider_response_id = chunk.response.id
 
             elif isinstance(chunk, responses.ResponseFailedEvent):  # pragma: no cover
-                self._usage += _map_usage(chunk.response)
+                self._usage += self._map_usage(chunk.response)
 
             elif isinstance(chunk, responses.ResponseFunctionCallArgumentsDeltaEvent):
                 maybe_event = self._parts_manager.handle_tool_call_delta(
@@ -1722,10 +1726,10 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
                 pass  # there's nothing we need to do here
 
             elif isinstance(chunk, responses.ResponseIncompleteEvent):  # pragma: no cover
-                self._usage += _map_usage(chunk.response)
+                self._usage += self._map_usage(chunk.response)
 
             elif isinstance(chunk, responses.ResponseInProgressEvent):
-                self._usage += _map_usage(chunk.response)
+                self._usage += self._map_usage(chunk.response)
 
             elif isinstance(chunk, responses.ResponseOutputItemAddedEvent):
                 if isinstance(chunk.item, responses.ResponseFunctionToolCall):
@@ -1906,6 +1910,9 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
                     UserWarning,
                 )
 
+    def _map_usage(self, response: responses.Response):
+        return _map_usage(response, self._provider_name, self._provider_url, self._model_name)
+
     @property
     def model_name(self) -> OpenAIModelName:
         """Get the model name of the response."""
@@ -1922,55 +1929,45 @@ def timestamp(self) -> datetime:
         return self._timestamp
 
 
-def _map_usage(response: chat.ChatCompletion | ChatCompletionChunk | responses.Response) -> usage.RequestUsage:
+def _map_usage(
+    response: chat.ChatCompletion | ChatCompletionChunk | responses.Response,
+    provider: str,
+    provider_url: str,
+    model: str,
+) -> usage.RequestUsage:
     response_usage = response.usage
     if response_usage is None:
         return usage.RequestUsage()
-    elif isinstance(response_usage, responses.ResponseUsage):
-        details: dict[str, int] = {
-            key: value
-            for key, value in response_usage.model_dump(
-                exclude={'input_tokens', 'output_tokens', 'total_tokens'}
-            ).items()
-            if isinstance(value, int)
-        }
-        # Handle vLLM compatibility - some providers don't include token details
-        if getattr(response_usage, 'input_tokens_details', None) is not None:
-            cache_read_tokens = response_usage.input_tokens_details.cached_tokens
-        else:
-            cache_read_tokens = 0
+
+    usage_data = response_usage.model_dump(exclude_none=True)
+    details = {
+        k: v
+        for k, v in usage_data.items()
+        if k not in {'prompt_tokens', 'completion_tokens', 'input_tokens', 'output_tokens', 'total_tokens'}
+        if isinstance(v, int)
+    }
+    response_data = dict(model=model, usage=usage_data)
+    if isinstance(response_usage, responses.ResponseUsage):
+        api_flavor = 'responses'
 
         if getattr(response_usage, 'output_tokens_details', None) is not None:
             details['reasoning_tokens'] = response_usage.output_tokens_details.reasoning_tokens
         else:
             details['reasoning_tokens'] = 0
-
-        return usage.RequestUsage(
-            input_tokens=response_usage.input_tokens,
-            output_tokens=response_usage.output_tokens,
-            cache_read_tokens=cache_read_tokens,
-            details=details,
-        )
     else:
-        details = {
-            key: value
-            for key, value in response_usage.model_dump(
-                exclude_none=True, exclude={'prompt_tokens', 'completion_tokens', 'total_tokens'}
-            ).items()
-            if isinstance(value, int)
-        }
-        u = usage.RequestUsage(
-            input_tokens=response_usage.prompt_tokens,
-            output_tokens=response_usage.completion_tokens,
-            details=details,
-        )
+        api_flavor = 'chat'
+
         if response_usage.completion_tokens_details is not None:
             details.update(response_usage.completion_tokens_details.model_dump(exclude_none=True))
-            u.output_audio_tokens = response_usage.completion_tokens_details.audio_tokens or 0
-        if response_usage.prompt_tokens_details is not None:
-            u.input_audio_tokens = response_usage.prompt_tokens_details.audio_tokens or 0
-            u.cache_read_tokens = response_usage.prompt_tokens_details.cached_tokens or 0
-        return u
+
+    return usage.RequestUsage.extract(
+        response_data,
+        provider=provider,
+        provider_url=provider_url,
+        provider_fallback='openai',
+        api_flavor=api_flavor,
+        details=details,
+    )
 
 
 def _split_combined_tool_call_id(combined_id: str) -> tuple[str, str | None]:
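
Net effect of the openai.py changes: every call site now forwards the provider name, the provider base URL and the model name into a single _map_usage helper, which dumps the raw usage payload to a plain dict and lets usage.RequestUsage.extract (backed by genai-prices) resolve the provider and normalize the token counts. The sketch below shows the shape of that call for a Chat Completions response; only the keyword arguments are taken from the diff above, while the token counts, model name and URL are invented for illustration and the exact behaviour of RequestUsage.extract is assumed from its use here.

from pydantic_ai import usage

# Illustrative only: what the new _map_usage hands to RequestUsage.extract for
# a Chat Completions response. All values below are made up for the example.
response_data = {
    'model': 'gpt-4o',
    'usage': {
        'prompt_tokens': 120,
        'completion_tokens': 40,
        'total_tokens': 160,
        'completion_tokens_details': {'reasoning_tokens': 0, 'audio_tokens': 0},
    },
}
request_usage = usage.RequestUsage.extract(
    response_data,
    provider='openai',                         # self._provider.name at the call sites
    provider_url='https://api.openai.com/v1',  # self._provider.base_url at the call sites
    provider_fallback='openai',                # assumption: used when name/URL don't resolve
    api_flavor='chat',                         # 'responses' on the Responses API path
    details={'reasoning_tokens': 0, 'audio_tokens': 0},  # ints the helper pre-extracts
)
print(request_usage.input_tokens, request_usage.output_tokens)  # expected: 120 40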

pydantic_ai_slim/pyproject.toml (1 addition, 1 deletion)
@@ -60,7 +60,7 @@ dependencies = [
     "exceptiongroup; python_version < '3.11'",
     "opentelemetry-api>=1.28.0",
     "typing-inspection>=0.4.0",
-    "genai-prices>=0.0.30",
+    "genai-prices>=0.0.31",
 ]
 
 [tool.hatch.metadata.hooks.uv-dynamic-versioning.optional-dependencies]

tests/models/mock_openai.py (2 additions, 0 deletions)
@@ -29,6 +29,7 @@ class MockOpenAI:
     stream: Sequence[MockChatCompletionChunk] | Sequence[Sequence[MockChatCompletionChunk]] | None = None
     index: int = 0
     chat_completion_kwargs: list[dict[str, Any]] = field(default_factory=list)
+    base_url: str = 'https://api.openai.com/v1'
 
     @cached_property
     def chat(self) -> Any:
@@ -98,6 +99,7 @@ class MockOpenAIResponses:
     stream: Sequence[MockResponseStreamEvent] | Sequence[Sequence[MockResponseStreamEvent]] | None = None
     index: int = 0
     response_kwargs: list[dict[str, Any]] = field(default_factory=list)
+    base_url: str = 'https://api.openai.com/v1'
 
     @cached_property
    def responses(self) -> Any:
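
The mocks gain a base_url default because the provider's base_url, which OpenAIProvider reads from the client it wraps, now flows into _map_usage; a client stand-in therefore has to expose the attribute too. A minimal sketch of that relationship, assuming the usual test wiring (casting the mock to AsyncOpenAI and handing it to OpenAIProvider) rather than anything added by this commit:

from typing import cast

from openai import AsyncOpenAI

from pydantic_ai.providers.openai import OpenAIProvider
from tests.models.mock_openai import MockOpenAI  # assumed import path for the test helper

# Sketch only: the provider reports the wrapped client's base_url, which the
# model now passes to _map_usage, so the mock must carry the attribute too.
# Constructing MockOpenAI() with no arguments assumes all its fields default.
client = cast(AsyncOpenAI, MockOpenAI())
provider = OpenAIProvider(openai_client=client)
assert provider.base_url == 'https://api.openai.com/v1'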

uv.lock (7 additions, 7 deletions; generated lockfile, diff not rendered)
