Skip to content

Commit 509ddda

Browse files
fix(usage): Normalize None token detail objects on Usage initialization (#2141)
1 parent a05af4b commit 509ddda

File tree

3 files changed

+86
-21
lines changed

3 files changed

+86
-21
lines changed

src/agents/models/openai_chatcompletions.py

Lines changed: 3 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
from openai.types.chat.chat_completion import Choice
1212
from openai.types.responses import Response
1313
from openai.types.responses.response_prompt_param import ResponsePromptParam
14-
from openai.types.responses.response_usage import InputTokensDetails, OutputTokensDetails
1514

1615
from .. import _debug
1716
from ..agent_output import AgentOutputSchemaBase
@@ -102,18 +101,9 @@ async def get_response(
102101
input_tokens=response.usage.prompt_tokens,
103102
output_tokens=response.usage.completion_tokens,
104103
total_tokens=response.usage.total_tokens,
105-
input_tokens_details=InputTokensDetails(
106-
cached_tokens=getattr(
107-
response.usage.prompt_tokens_details, "cached_tokens", 0
108-
)
109-
or 0,
110-
),
111-
output_tokens_details=OutputTokensDetails(
112-
reasoning_tokens=getattr(
113-
response.usage.completion_tokens_details, "reasoning_tokens", 0
114-
)
115-
or 0,
116-
),
104+
# BeforeValidator in Usage normalizes these from Chat Completions types
105+
input_tokens_details=response.usage.prompt_tokens_details, # type: ignore[arg-type]
106+
output_tokens_details=response.usage.completion_tokens_details, # type: ignore[arg-type]
117107
)
118108
if response.usage
119109
else Usage()

src/agents/usage.py

Lines changed: 34 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,36 @@
1+
from __future__ import annotations
2+
13
from dataclasses import field
4+
from typing import Annotated
25

6+
from openai.types.completion_usage import CompletionTokensDetails, PromptTokensDetails
37
from openai.types.responses.response_usage import InputTokensDetails, OutputTokensDetails
8+
from pydantic import BeforeValidator
49
from pydantic.dataclasses import dataclass
510

611

12+
def _normalize_input_tokens_details(
    v: InputTokensDetails | PromptTokensDetails | None,
) -> InputTokensDetails:
    """Coerce *v* to the Responses API ``InputTokensDetails`` type.

    ``None`` (providers that omit the field) becomes a zeroed object, and the
    Chat Completions ``PromptTokensDetails`` type is translated field-by-field.
    Anything else is passed through untouched for Pydantic to validate.
    """
    if isinstance(v, PromptTokensDetails):
        # Chat Completions shape: carry over the cached-token count,
        # treating a missing/None count as 0.
        return InputTokensDetails(cached_tokens=v.cached_tokens or 0)
    if v is None:
        # Field omitted entirely by the provider: substitute a zeroed object.
        return InputTokensDetails(cached_tokens=0)
    return v
21+
22+
23+
def _normalize_output_tokens_details(
    v: OutputTokensDetails | CompletionTokensDetails | None,
) -> OutputTokensDetails:
    """Coerce *v* to the Responses API ``OutputTokensDetails`` type.

    ``None`` (providers that omit the field) becomes a zeroed object, and the
    Chat Completions ``CompletionTokensDetails`` type is translated
    field-by-field. Anything else is passed through untouched for Pydantic to
    validate.
    """
    if isinstance(v, CompletionTokensDetails):
        # Chat Completions shape: carry over the reasoning-token count,
        # treating a missing/None count as 0.
        return OutputTokensDetails(reasoning_tokens=v.reasoning_tokens or 0)
    if v is None:
        # Field omitted entirely by the provider: substitute a zeroed object.
        return OutputTokensDetails(reasoning_tokens=0)
    return v
32+
33+
734
@dataclass
835
class RequestUsage:
936
"""Usage details for a single API request."""
@@ -32,16 +59,16 @@ class Usage:
3259
input_tokens: int = 0
3360
"""Total input tokens sent, across all requests."""
3461

35-
input_tokens_details: InputTokensDetails = field(
36-
default_factory=lambda: InputTokensDetails(cached_tokens=0)
37-
)
62+
input_tokens_details: Annotated[
63+
InputTokensDetails, BeforeValidator(_normalize_input_tokens_details)
64+
] = field(default_factory=lambda: InputTokensDetails(cached_tokens=0))
3865
"""Details about the input tokens, matching responses API usage details."""
3966
output_tokens: int = 0
4067
"""Total output tokens received, across all requests."""
4168

42-
output_tokens_details: OutputTokensDetails = field(
43-
default_factory=lambda: OutputTokensDetails(reasoning_tokens=0)
44-
)
69+
output_tokens_details: Annotated[
70+
OutputTokensDetails, BeforeValidator(_normalize_output_tokens_details)
71+
] = field(default_factory=lambda: OutputTokensDetails(reasoning_tokens=0))
4572
"""Details about the output tokens, matching responses API usage details."""
4673

4774
total_tokens: int = 0
@@ -70,7 +97,7 @@ def __post_init__(self) -> None:
7097
if self.output_tokens_details.reasoning_tokens is None:
7198
self.output_tokens_details = OutputTokensDetails(reasoning_tokens=0)
7299

73-
def add(self, other: "Usage") -> None:
100+
def add(self, other: Usage) -> None:
74101
"""Add another Usage object to this one, aggregating all fields.
75102
76103
This method automatically preserves request_usage_entries.

tests/test_usage.py

Lines changed: 49 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from openai.types.completion_usage import CompletionTokensDetails, PromptTokensDetails
12
from openai.types.responses.response_usage import InputTokensDetails, OutputTokensDetails
23

34
from agents.usage import RequestUsage, Usage
@@ -270,7 +271,24 @@ def test_anthropic_cost_calculation_scenario():
270271

271272

272273
def test_usage_normalizes_none_token_details():
273-
# Some providers don't populate optional fields, resulting in None values
274+
# Some providers don't populate optional token detail fields
275+
# (cached_tokens, reasoning_tokens), and the OpenAI SDK's generated
276+
# code can bypass Pydantic validation (e.g., via model_construct),
277+
# allowing None values. We normalize these to 0 to prevent TypeErrors.
278+
279+
# Test entire objects being None (BeforeValidator)
280+
usage = Usage(
281+
requests=1,
282+
input_tokens=100,
283+
input_tokens_details=None, # type: ignore[arg-type]
284+
output_tokens=50,
285+
output_tokens_details=None, # type: ignore[arg-type]
286+
total_tokens=150,
287+
)
288+
assert usage.input_tokens_details.cached_tokens == 0
289+
assert usage.output_tokens_details.reasoning_tokens == 0
290+
291+
# Test fields within objects being None (__post_init__)
274292
input_details = InputTokensDetails(cached_tokens=0)
275293
input_details.__dict__["cached_tokens"] = None
276294

@@ -289,3 +307,33 @@ def test_usage_normalizes_none_token_details():
289307
# __post_init__ should normalize None to 0
290308
assert usage.input_tokens_details.cached_tokens == 0
291309
assert usage.output_tokens_details.reasoning_tokens == 0
310+
311+
312+
def test_usage_normalizes_chat_completions_types():
    # Chat Completions API uses PromptTokensDetails and CompletionTokensDetails,
    # while Usage expects InputTokensDetails and OutputTokensDetails (Responses API).
    # The BeforeValidator should convert between these types.

    chat_input_details = PromptTokensDetails(audio_tokens=10, cached_tokens=50)
    chat_output_details = CompletionTokensDetails(
        accepted_prediction_tokens=5,
        audio_tokens=10,
        reasoning_tokens=100,
        rejected_prediction_tokens=2,
    )

    usage = Usage(
        requests=1,
        input_tokens=200,
        input_tokens_details=chat_input_details,  # type: ignore[arg-type]
        output_tokens=150,
        output_tokens_details=chat_output_details,  # type: ignore[arg-type]
        total_tokens=350,
    )

    # Should convert to Responses API types, extracting the relevant fields
    # (cached_tokens and reasoning_tokens); the extra Chat Completions fields
    # are intentionally dropped.
    assert isinstance(usage.input_tokens_details, InputTokensDetails)
    assert usage.input_tokens_details.cached_tokens == 50

    assert isinstance(usage.output_tokens_details, OutputTokensDetails)
    assert usage.output_tokens_details.reasoning_tokens == 100

0 commit comments

Comments (0)