|
| 1 | +from __future__ import annotations |
| 2 | + |
1 | 3 | from dataclasses import field |
| 4 | +from typing import Annotated |
2 | 5 |
|
| 6 | +from openai.types.completion_usage import CompletionTokensDetails, PromptTokensDetails |
3 | 7 | from openai.types.responses.response_usage import InputTokensDetails, OutputTokensDetails |
| 8 | +from pydantic import BeforeValidator |
4 | 9 | from pydantic.dataclasses import dataclass |
5 | 10 |
|
6 | 11 |
|
def _normalize_input_tokens_details(
    v: InputTokensDetails | PromptTokensDetails | None,
) -> InputTokensDetails:
    """Coerce ``None`` or a chat-completions ``PromptTokensDetails`` into the
    responses-API ``InputTokensDetails`` shape.

    Any other value (e.g. an already-correct ``InputTokensDetails``, or raw
    data such as a dict) is passed through untouched so pydantic's normal
    validation still applies to it.
    """
    if v is None:
        return InputTokensDetails(cached_tokens=0)
    if not isinstance(v, PromptTokensDetails):
        return v
    # PromptTokensDetails.cached_tokens may itself be None — default to 0.
    return InputTokensDetails(cached_tokens=v.cached_tokens or 0)
| 22 | + |
def _normalize_output_tokens_details(
    v: OutputTokensDetails | CompletionTokensDetails | None,
) -> OutputTokensDetails:
    """Coerce ``None`` or a chat-completions ``CompletionTokensDetails`` into
    the responses-API ``OutputTokensDetails`` shape.

    Any other value (e.g. an already-correct ``OutputTokensDetails``, or raw
    data such as a dict) is passed through untouched so pydantic's normal
    validation still applies to it.
    """
    if v is None:
        return OutputTokensDetails(reasoning_tokens=0)
    if not isinstance(v, CompletionTokensDetails):
        return v
    # CompletionTokensDetails.reasoning_tokens may itself be None — default to 0.
    return OutputTokensDetails(reasoning_tokens=v.reasoning_tokens or 0)
| 33 | + |
7 | 34 | @dataclass |
8 | 35 | class RequestUsage: |
9 | 36 | """Usage details for a single API request.""" |
@@ -32,16 +59,16 @@ class Usage: |
32 | 59 | input_tokens: int = 0 |
33 | 60 | """Total input tokens sent, across all requests.""" |
34 | 61 |
|
35 | | - input_tokens_details: InputTokensDetails = field( |
36 | | - default_factory=lambda: InputTokensDetails(cached_tokens=0) |
37 | | - ) |
| 62 | + input_tokens_details: Annotated[ |
| 63 | + InputTokensDetails, BeforeValidator(_normalize_input_tokens_details) |
| 64 | + ] = field(default_factory=lambda: InputTokensDetails(cached_tokens=0)) |
38 | 65 | """Details about the input tokens, matching responses API usage details.""" |
39 | 66 | output_tokens: int = 0 |
40 | 67 | """Total output tokens received, across all requests.""" |
41 | 68 |
|
42 | | - output_tokens_details: OutputTokensDetails = field( |
43 | | - default_factory=lambda: OutputTokensDetails(reasoning_tokens=0) |
44 | | - ) |
| 69 | + output_tokens_details: Annotated[ |
| 70 | + OutputTokensDetails, BeforeValidator(_normalize_output_tokens_details) |
| 71 | + ] = field(default_factory=lambda: OutputTokensDetails(reasoning_tokens=0)) |
45 | 72 | """Details about the output tokens, matching responses API usage details.""" |
46 | 73 |
|
47 | 74 | total_tokens: int = 0 |
@@ -70,7 +97,7 @@ def __post_init__(self) -> None: |
70 | 97 | if self.output_tokens_details.reasoning_tokens is None: |
71 | 98 | self.output_tokens_details = OutputTokensDetails(reasoning_tokens=0) |
72 | 99 |
|
73 | | - def add(self, other: "Usage") -> None: |
| 100 | + def add(self, other: Usage) -> None: |
74 | 101 | """Add another Usage object to this one, aggregating all fields. |
75 | 102 |
|
76 | 103 | This method automatically preserves request_usage_entries. |
|
0 commit comments