Commit 5c9cffb

feat(ai): add cache writes for gen_ai

committed · 1 parent 50c790c · commit 5c9cffb

File tree

11 files changed: +915 −24 lines

sentry_sdk/ai/__init__.py

Lines changed: 10 additions & 0 deletions

@@ -1,7 +1,17 @@
+from .monitoring import record_token_usage  # noqa: F401
 from .utils import (
     set_data_normalized,
     GEN_AI_MESSAGE_ROLE_MAPPING,
     GEN_AI_MESSAGE_ROLE_REVERSE_MAPPING,
     normalize_message_role,
     normalize_message_roles,
 )  # noqa: F401
+
+__all__ = [
+    "record_token_usage",
+    "set_data_normalized",
+    "GEN_AI_MESSAGE_ROLE_MAPPING",
+    "GEN_AI_MESSAGE_ROLE_REVERSE_MAPPING",
+    "normalize_message_role",
+    "normalize_message_roles",
+]

sentry_sdk/ai/monitoring.py

Lines changed: 7 additions & 0 deletions

@@ -95,6 +95,7 @@ def record_token_usage(
     span: "Span",
     input_tokens: "Optional[int]" = None,
     input_tokens_cached: "Optional[int]" = None,
+    input_tokens_cache_write: "Optional[int]" = None,
     output_tokens: "Optional[int]" = None,
     output_tokens_reasoning: "Optional[int]" = None,
     total_tokens: "Optional[int]" = None,
@@ -113,6 +114,12 @@ def record_token_usage(
             input_tokens_cached,
         )
 
+    if input_tokens_cache_write is not None:
+        span.set_data(
+            SPANDATA.GEN_AI_USAGE_INPUT_TOKENS_CACHE_WRITE,
+            input_tokens_cache_write,
+        )
+
     if output_tokens is not None:
         span.set_data(SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS, output_tokens)
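Note: the new keyword is optional, so existing callers are unaffected. A minimal sketch of how a custom instrumentation might report the new counter (the span options and all token counts below are illustrative, not part of this commit):

import sentry_sdk
from sentry_sdk.ai import record_token_usage

with sentry_sdk.start_span(op="gen_ai.chat", name="chat my-model") as span:
    # ... call the model provider, read its usage block, then record it.
    record_token_usage(
        span,
        input_tokens=1200,
        input_tokens_cached=800,       # tokens served from the prompt cache
        input_tokens_cache_write=150,  # tokens newly written to the cache
        output_tokens=300,
        total_tokens=1500,
    )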

sentry_sdk/consts.py

Lines changed: 6 additions & 0 deletions

@@ -632,6 +632,12 @@ class SPANDATA:
     Example: 50
     """
 
+    GEN_AI_USAGE_INPUT_TOKENS_CACHE_WRITE = "gen_ai.usage.input_tokens.cache_write"
+    """
+    The number of tokens written to the cache when processing the AI input (prompt).
+    Example: 100
+    """
+
     GEN_AI_USAGE_OUTPUT_TOKENS = "gen_ai.usage.output_tokens"
     """
     The number of tokens in the output.

sentry_sdk/integrations/anthropic.py

Lines changed: 88 additions & 14 deletions

@@ -72,29 +72,47 @@ def _capture_exception(exc: "Any") -> None:
     sentry_sdk.capture_event(event, hint=hint)
 
 
-def _get_token_usage(result: "Messages") -> "tuple[int, int]":
+def _get_token_usage(result: "Messages") -> "tuple[int, int, int, int]":
     """
     Get token usage from the Anthropic response.
+    Returns: (input_tokens, output_tokens, cache_read_input_tokens, cache_write_input_tokens)
     """
     input_tokens = 0
     output_tokens = 0
+    cache_read_input_tokens = 0
+    cache_write_input_tokens = 0
     if hasattr(result, "usage"):
         usage = result.usage
         if hasattr(usage, "input_tokens") and isinstance(usage.input_tokens, int):
             input_tokens = usage.input_tokens
         if hasattr(usage, "output_tokens") and isinstance(usage.output_tokens, int):
             output_tokens = usage.output_tokens
-
-    return input_tokens, output_tokens
+        if hasattr(usage, "cache_read_input_tokens") and isinstance(
+            usage.cache_read_input_tokens, int
+        ):
+            cache_read_input_tokens = usage.cache_read_input_tokens
+        if hasattr(usage, "cache_creation_input_tokens") and isinstance(
+            usage.cache_creation_input_tokens, int
+        ):
+            cache_write_input_tokens = usage.cache_creation_input_tokens
+
+    return (
+        input_tokens,
+        output_tokens,
+        cache_read_input_tokens,
+        cache_write_input_tokens,
+    )
 
 
 def _collect_ai_data(
     event: "MessageStreamEvent",
     model: "str | None",
     input_tokens: int,
     output_tokens: int,
+    cache_read_input_tokens: int,
+    cache_write_input_tokens: int,
     content_blocks: "list[str]",
-) -> "tuple[str | None, int, int, list[str]]":
+) -> "tuple[str | None, int, int, int, int, list[str]]":
     """
     Collect model information, token usage, and collect content blocks from the AI streaming response.
     """
@@ -104,6 +122,14 @@ def _collect_ai_data(
         usage = event.message.usage
         input_tokens += usage.input_tokens
         output_tokens += usage.output_tokens
+        if hasattr(usage, "cache_read_input_tokens") and isinstance(
+            usage.cache_read_input_tokens, int
+        ):
+            cache_read_input_tokens += usage.cache_read_input_tokens
+        if hasattr(usage, "cache_creation_input_tokens") and isinstance(
+            usage.cache_creation_input_tokens, int
+        ):
+            cache_write_input_tokens += usage.cache_creation_input_tokens
         model = event.message.model or model
     elif event.type == "content_block_start":
         pass
@@ -117,7 +143,14 @@ def _collect_ai_data(
     elif event.type == "message_delta":
         output_tokens += event.usage.output_tokens
 
-    return model, input_tokens, output_tokens, content_blocks
+    return (
+        model,
+        input_tokens,
+        output_tokens,
+        cache_read_input_tokens,
+        cache_write_input_tokens,
+        content_blocks,
+    )
 
 
 def _set_input_data(
@@ -219,6 +252,8 @@ def _set_output_data(
     model: "str | None",
     input_tokens: "int | None",
     output_tokens: "int | None",
+    cache_read_input_tokens: "int | None",
+    cache_write_input_tokens: "int | None",
     content_blocks: "list[Any]",
     finish_span: bool = False,
 ) -> None:
@@ -254,6 +289,8 @@ def _set_output_data(
         span,
         input_tokens=input_tokens,
         output_tokens=output_tokens,
+        input_tokens_cached=cache_read_input_tokens,
+        input_tokens_cache_write=cache_write_input_tokens,
     )
 
     if finish_span:
@@ -288,7 +325,12 @@ def _sentry_patched_create_common(f: "Any", *args: "Any", **kwargs: "Any") -> "A
 
     with capture_internal_exceptions():
         if hasattr(result, "content"):
-            input_tokens, output_tokens = _get_token_usage(result)
+            (
+                input_tokens,
+                output_tokens,
+                cache_read_input_tokens,
+                cache_write_input_tokens,
+            ) = _get_token_usage(result)
 
             content_blocks = []
             for content_block in result.content:
@@ -305,6 +347,8 @@ def _sentry_patched_create_common(f: "Any", *args: "Any", **kwargs: "Any") -> "A
                 model=getattr(result, "model", None),
                 input_tokens=input_tokens,
                 output_tokens=output_tokens,
+                cache_read_input_tokens=cache_read_input_tokens,
+                cache_write_input_tokens=cache_write_input_tokens,
                 content_blocks=content_blocks,
                 finish_span=True,
             )
@@ -317,13 +361,26 @@ def new_iterator() -> "Iterator[MessageStreamEvent]":
         model = None
         input_tokens = 0
        output_tokens = 0
+        cache_read_input_tokens = 0
+        cache_write_input_tokens = 0
         content_blocks: "list[str]" = []
 
         for event in old_iterator:
-            model, input_tokens, output_tokens, content_blocks = (
-                _collect_ai_data(
-                    event, model, input_tokens, output_tokens, content_blocks
-                )
+            (
+                model,
+                input_tokens,
+                output_tokens,
+                cache_read_input_tokens,
+                cache_write_input_tokens,
+                content_blocks,
+            ) = _collect_ai_data(
+                event,
+                model,
+                input_tokens,
+                output_tokens,
+                cache_read_input_tokens,
+                cache_write_input_tokens,
+                content_blocks,
             )
             yield event
 
@@ -333,6 +390,8 @@ def new_iterator() -> "Iterator[MessageStreamEvent]":
             model=model,
             input_tokens=input_tokens,
             output_tokens=output_tokens,
+            cache_read_input_tokens=cache_read_input_tokens,
+            cache_write_input_tokens=cache_write_input_tokens,
             content_blocks=[{"text": "".join(content_blocks), "type": "text"}],
             finish_span=True,
         )
@@ -341,13 +400,26 @@ async def new_iterator_async() -> "AsyncIterator[MessageStreamEvent]":
         model = None
         input_tokens = 0
         output_tokens = 0
+        cache_read_input_tokens = 0
+        cache_write_input_tokens = 0
         content_blocks: "list[str]" = []
 
         async for event in old_iterator:
-            model, input_tokens, output_tokens, content_blocks = (
-                _collect_ai_data(
-                    event, model, input_tokens, output_tokens, content_blocks
-                )
+            (
+                model,
+                input_tokens,
+                output_tokens,
+                cache_read_input_tokens,
+                cache_write_input_tokens,
+                content_blocks,
+            ) = _collect_ai_data(
+                event,
+                model,
+                input_tokens,
+                output_tokens,
+                cache_read_input_tokens,
+                cache_write_input_tokens,
+                content_blocks,
            )
             yield event
 
@@ -357,6 +429,8 @@ async def new_iterator_async() -> "AsyncIterator[MessageStreamEvent]":
             model=model,
             input_tokens=input_tokens,
             output_tokens=output_tokens,
+            cache_read_input_tokens=cache_read_input_tokens,
+            cache_write_input_tokens=cache_write_input_tokens,
             content_blocks=[{"text": "".join(content_blocks), "type": "text"}],
             finish_span=True,
         )
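Note: Anthropic's Messages API reports prompt-cache activity on the response's usage block as cache_read_input_tokens (served from the cache) and cache_creation_input_tokens (newly written), which the helper maps onto the two Sentry attributes. A quick sketch against the helper using a stand-in response object; _get_token_usage is a private helper, the import assumes the anthropic package is installed, and all values are illustrative:

from types import SimpleNamespace

from sentry_sdk.integrations.anthropic import _get_token_usage

# Stand-in for an Anthropic Messages response; field names follow
# Anthropic's usage schema, values are illustrative.
result = SimpleNamespace(
    usage=SimpleNamespace(
        input_tokens=1200,
        output_tokens=300,
        cache_read_input_tokens=800,      # -> gen_ai.usage.input_tokens.cached
        cache_creation_input_tokens=150,  # -> gen_ai.usage.input_tokens.cache_write
    )
)

assert _get_token_usage(result) == (1200, 300, 800, 150)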

sentry_sdk/integrations/litellm.py

Lines changed: 4 additions & 0 deletions

@@ -192,6 +192,10 @@ def _success_callback(
             record_token_usage(
                 span,
                 input_tokens=getattr(usage, "prompt_tokens", None),
+                input_tokens_cached=getattr(usage, "cache_read_input_tokens", None),
+                input_tokens_cache_write=getattr(
+                    usage, "cache_write_input_tokens", None
+                ),
                 output_tokens=getattr(usage, "completion_tokens", None),
                 total_tokens=getattr(usage, "total_tokens", None),
             )

sentry_sdk/integrations/openai.py

Lines changed: 10 additions & 1 deletion

@@ -131,7 +131,12 @@ def _calculate_token_usage(
 
     if hasattr(response, "usage"):
         input_tokens = _get_usage(response.usage, ["input_tokens", "prompt_tokens"])
-        if hasattr(response.usage, "input_tokens_details"):
+        if hasattr(response.usage, "prompt_tokens_details"):
+            input_tokens_cached = _get_usage(
+                response.usage.prompt_tokens_details, ["cached_tokens"]
+            )
+        # OpenAI also supports input_tokens_details for compatibility
+        elif hasattr(response.usage, "input_tokens_details"):
             input_tokens_cached = _get_usage(
                 response.usage.input_tokens_details, ["cached_tokens"]
             )
@@ -143,6 +148,10 @@ def _calculate_token_usage(
             output_tokens_reasoning = _get_usage(
                 response.usage.output_tokens_details, ["reasoning_tokens"]
             )
+        elif hasattr(response.usage, "completion_tokens_details"):
+            output_tokens_reasoning = _get_usage(
+                response.usage.completion_tokens_details, ["reasoning_tokens"]
+            )
 
         total_tokens = _get_usage(response.usage, ["total_tokens"])
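Note: the fallback order matters because OpenAI's two APIs nest the same counters under different names: Chat Completions uses prompt_tokens_details / completion_tokens_details, while the Responses API uses input_tokens_details / output_tokens_details. A simplified, self-contained sketch of the lookup (not the SDK's actual helper; values illustrative):

from types import SimpleNamespace

def cached_input_tokens(usage):
    # Prefer Chat Completions' prompt_tokens_details, then fall back to
    # the Responses API's input_tokens_details, as the patch above does.
    for attr in ("prompt_tokens_details", "input_tokens_details"):
        details = getattr(usage, attr, None)
        if details is not None:
            return getattr(details, "cached_tokens", None)
    return None

# Chat Completions-shaped usage:
chat_usage = SimpleNamespace(prompt_tokens_details=SimpleNamespace(cached_tokens=750))
# Responses API-shaped usage:
responses_usage = SimpleNamespace(input_tokens_details=SimpleNamespace(cached_tokens=420))

assert cached_input_tokens(chat_usage) == 750
assert cached_input_tokens(responses_usage) == 420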

sentry_sdk/integrations/openai_agents/utils.py

Lines changed: 16 additions & 8 deletions

@@ -96,15 +96,23 @@ def _set_agent_data(span: "sentry_sdk.tracing.Span", agent: "agents.Agent") -> N
 
 def _set_usage_data(span: "sentry_sdk.tracing.Span", usage: "Usage") -> None:
     span.set_data(SPANDATA.GEN_AI_USAGE_INPUT_TOKENS, usage.input_tokens)
-    span.set_data(
-        SPANDATA.GEN_AI_USAGE_INPUT_TOKENS_CACHED,
-        usage.input_tokens_details.cached_tokens,
-    )
+
+    if hasattr(usage, "input_tokens_details") and usage.input_tokens_details:
+        if hasattr(usage.input_tokens_details, "cached_tokens"):
+            span.set_data(
+                SPANDATA.GEN_AI_USAGE_INPUT_TOKENS_CACHED,
+                usage.input_tokens_details.cached_tokens,
+            )
+
     span.set_data(SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS, usage.output_tokens)
-    span.set_data(
-        SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS_REASONING,
-        usage.output_tokens_details.reasoning_tokens,
-    )
+
+    if hasattr(usage, "output_tokens_details") and usage.output_tokens_details:
+        if hasattr(usage.output_tokens_details, "reasoning_tokens"):
+            span.set_data(
+                SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS_REASONING,
+                usage.output_tokens_details.reasoning_tokens,
+            )
+
     span.set_data(SPANDATA.GEN_AI_USAGE_TOTAL_TOKENS, usage.total_tokens)
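Note: the unguarded attribute access raised AttributeError whenever a Usage object arrived without the optional details objects. A small, self-contained illustration of the failure mode the new hasattr guards avoid, using a hypothetical bare usage object (not the agents library's actual class):

from types import SimpleNamespace

# A usage object with no input_tokens_details / output_tokens_details.
usage = SimpleNamespace(input_tokens=10, output_tokens=5, total_tokens=15)

try:
    usage.input_tokens_details.cached_tokens  # what the old code effectively did
except AttributeError:
    print("old behavior: instrumentation crashed here")

# New behavior: the guards above simply skip the optional counters.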

sentry_sdk/integrations/pydantic_ai/spans/utils.py

Lines changed: 22 additions & 0 deletions

@@ -28,8 +28,30 @@ def _set_usage_data(
     if hasattr(usage, "input_tokens") and usage.input_tokens is not None:
         span.set_data(SPANDATA.GEN_AI_USAGE_INPUT_TOKENS, usage.input_tokens)
 
+    # Pydantic AI uses cache_read_tokens (not input_tokens_cached)
+    if hasattr(usage, "cache_read_tokens") and usage.cache_read_tokens is not None:
+        span.set_data(
+            SPANDATA.GEN_AI_USAGE_INPUT_TOKENS_CACHED, usage.cache_read_tokens
+        )
+
+    # Pydantic AI uses cache_write_tokens (not input_tokens_cache_write)
+    if hasattr(usage, "cache_write_tokens") and usage.cache_write_tokens is not None:
+        span.set_data(
+            SPANDATA.GEN_AI_USAGE_INPUT_TOKENS_CACHE_WRITE,
+            usage.cache_write_tokens,
+        )
+
     if hasattr(usage, "output_tokens") and usage.output_tokens is not None:
         span.set_data(SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS, usage.output_tokens)
 
+    if (
+        hasattr(usage, "output_tokens_reasoning")
+        and usage.output_tokens_reasoning is not None
+    ):
+        span.set_data(
+            SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS_REASONING,
+            usage.output_tokens_reasoning,
+        )
+
     if hasattr(usage, "total_tokens") and usage.total_tokens is not None:
         span.set_data(SPANDATA.GEN_AI_USAGE_TOTAL_TOKENS, usage.total_tokens)
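Note: as the inline comments say, Pydantic AI reports cache counters under its own field names. A summary of the mapping this diff performs; the usage field names come from the diff itself, and the attribute strings are assumed to follow the SPANDATA naming pattern shown in consts.py above:

# Pydantic AI usage field        -> Sentry span attribute
# usage.input_tokens             -> gen_ai.usage.input_tokens
# usage.cache_read_tokens        -> gen_ai.usage.input_tokens.cached
# usage.cache_write_tokens       -> gen_ai.usage.input_tokens.cache_write
# usage.output_tokens            -> gen_ai.usage.output_tokens
# usage.output_tokens_reasoning  -> gen_ai.usage.output_tokens.reasoning
# usage.total_tokens             -> gen_ai.usage.total_tokens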
