Commit ab5cdde

shellmayr and claude authored
feat(ai): add cache writes for gen_ai (#5319)
- Implement cache token tracking for AI integrations
- Adds support for tracking cache read and cache write tokens across AI integrations, following the new semantic convention that added `gen_ai.usage.input_tokens.cache_write`
- Updated integrations to track cache tokens

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
1 parent 1970486 commit ab5cdde
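
For context, the convention change pairs the existing cached-read counter with a new cache-write counter. A minimal sketch of the span data an instrumented LLM call might carry after this change; the values are illustrative, and the key names other than the new `gen_ai.usage.input_tokens.cache_write` are assumed from the existing gen_ai convention:

# Illustrative span data; values and the non-new keys are assumptions.
span_data = {
    "gen_ai.usage.input_tokens": 100,             # total prompt tokens
    "gen_ai.usage.input_tokens.cached": 80,       # prompt tokens read from the cache
    "gen_ai.usage.input_tokens.cache_write": 20,  # prompt tokens written to the cache (new)
    "gen_ai.usage.output_tokens": 50,
}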

6 files changed: +227 −17 lines changed

sentry_sdk/ai/monitoring.py

Lines changed: 7 additions & 0 deletions
@@ -100,6 +100,7 @@ def record_token_usage(
     span: "Span",
     input_tokens: "Optional[int]" = None,
     input_tokens_cached: "Optional[int]" = None,
+    input_tokens_cache_write: "Optional[int]" = None,
     output_tokens: "Optional[int]" = None,
     output_tokens_reasoning: "Optional[int]" = None,
     total_tokens: "Optional[int]" = None,
@@ -118,6 +119,12 @@ def record_token_usage(
             input_tokens_cached,
         )
 
+    if input_tokens_cache_write is not None:
+        span.set_data(
+            SPANDATA.GEN_AI_USAGE_INPUT_TOKENS_CACHE_WRITE,
+            input_tokens_cache_write,
+        )
+
     if output_tokens is not None:
         span.set_data(SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS, output_tokens)
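
A minimal sketch of how an integration might call the updated helper; `report_usage` is a hypothetical wrapper, `span` stands in for an active Sentry span, and the token counts are illustrative:

from sentry_sdk.ai.monitoring import record_token_usage

def report_usage(span):
    # `span` is assumed to be an active sentry_sdk span object.
    record_token_usage(
        span,
        input_tokens=100,             # total prompt tokens
        input_tokens_cached=80,       # read from the provider cache
        input_tokens_cache_write=20,  # written to the cache (new in this commit)
        output_tokens=50,
    )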

sentry_sdk/consts.py

Lines changed: 6 additions & 0 deletions
@@ -632,6 +632,12 @@ class SPANDATA:
     Example: 50
     """
 
+    GEN_AI_USAGE_INPUT_TOKENS_CACHE_WRITE = "gen_ai.usage.input_tokens.cache_write"
+    """
+    The number of tokens written to the cache when processing the AI input (prompt).
+    Example: 100
+    """
+
     GEN_AI_USAGE_OUTPUT_TOKENS = "gen_ai.usage.output_tokens"
     """
     The number of tokens in the output.
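
For quick reference, the new constant resolves to the convention key introduced above; a one-line sanity check:

from sentry_sdk.consts import SPANDATA

assert (
    SPANDATA.GEN_AI_USAGE_INPUT_TOKENS_CACHE_WRITE
    == "gen_ai.usage.input_tokens.cache_write"
)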

sentry_sdk/integrations/anthropic.py

Lines changed: 88 additions & 14 deletions
@@ -75,29 +75,47 @@ def _capture_exception(exc: "Any") -> None:
     sentry_sdk.capture_event(event, hint=hint)
 
 
-def _get_token_usage(result: "Messages") -> "tuple[int, int]":
+def _get_token_usage(result: "Messages") -> "tuple[int, int, int, int]":
     """
     Get token usage from the Anthropic response.
+    Returns: (input_tokens, output_tokens, cache_read_input_tokens, cache_write_input_tokens)
     """
     input_tokens = 0
     output_tokens = 0
+    cache_read_input_tokens = 0
+    cache_write_input_tokens = 0
     if hasattr(result, "usage"):
         usage = result.usage
         if hasattr(usage, "input_tokens") and isinstance(usage.input_tokens, int):
             input_tokens = usage.input_tokens
         if hasattr(usage, "output_tokens") and isinstance(usage.output_tokens, int):
             output_tokens = usage.output_tokens
-
-    return input_tokens, output_tokens
+        if hasattr(usage, "cache_read_input_tokens") and isinstance(
+            usage.cache_read_input_tokens, int
+        ):
+            cache_read_input_tokens = usage.cache_read_input_tokens
+        if hasattr(usage, "cache_creation_input_tokens") and isinstance(
+            usage.cache_creation_input_tokens, int
+        ):
+            cache_write_input_tokens = usage.cache_creation_input_tokens
+
+    return (
+        input_tokens,
+        output_tokens,
+        cache_read_input_tokens,
+        cache_write_input_tokens,
+    )
 
 
 def _collect_ai_data(
     event: "MessageStreamEvent",
     model: "str | None",
     input_tokens: int,
     output_tokens: int,
+    cache_read_input_tokens: int,
+    cache_write_input_tokens: int,
     content_blocks: "list[str]",
-) -> "tuple[str | None, int, int, list[str]]":
+) -> "tuple[str | None, int, int, int, int, list[str]]":
     """
     Collect model information, token usage, and collect content blocks from the AI streaming response.
     """
@@ -107,6 +125,14 @@ def _collect_ai_data(
         usage = event.message.usage
         input_tokens += usage.input_tokens
         output_tokens += usage.output_tokens
+        if hasattr(usage, "cache_read_input_tokens") and isinstance(
+            usage.cache_read_input_tokens, int
+        ):
+            cache_read_input_tokens += usage.cache_read_input_tokens
+        if hasattr(usage, "cache_creation_input_tokens") and isinstance(
+            usage.cache_creation_input_tokens, int
+        ):
+            cache_write_input_tokens += usage.cache_creation_input_tokens
         model = event.message.model or model
     elif event.type == "content_block_start":
         pass
@@ -120,7 +146,14 @@ def _collect_ai_data(
     elif event.type == "message_delta":
         output_tokens += event.usage.output_tokens
 
-    return model, input_tokens, output_tokens, content_blocks
+    return (
+        model,
+        input_tokens,
+        output_tokens,
+        cache_read_input_tokens,
+        cache_write_input_tokens,
+        content_blocks,
+    )
 
 
 def _transform_anthropic_content_block(
@@ -265,6 +298,8 @@ def _set_output_data(
     model: "str | None",
     input_tokens: "int | None",
     output_tokens: "int | None",
+    cache_read_input_tokens: "int | None",
+    cache_write_input_tokens: "int | None",
     content_blocks: "list[Any]",
     finish_span: bool = False,
 ) -> None:
@@ -300,6 +335,8 @@ def _set_output_data(
         span,
         input_tokens=input_tokens,
         output_tokens=output_tokens,
+        input_tokens_cached=cache_read_input_tokens,
+        input_tokens_cache_write=cache_write_input_tokens,
     )
 
     if finish_span:
@@ -334,7 +371,12 @@ def _sentry_patched_create_common(f: "Any", *args: "Any", **kwargs: "Any") -> "A
 
     with capture_internal_exceptions():
         if hasattr(result, "content"):
-            input_tokens, output_tokens = _get_token_usage(result)
+            (
+                input_tokens,
+                output_tokens,
+                cache_read_input_tokens,
+                cache_write_input_tokens,
+            ) = _get_token_usage(result)
 
             content_blocks = []
             for content_block in result.content:
@@ -351,6 +393,8 @@ def _sentry_patched_create_common(f: "Any", *args: "Any", **kwargs: "Any") -> "A
                 model=getattr(result, "model", None),
                 input_tokens=input_tokens,
                 output_tokens=output_tokens,
+                cache_read_input_tokens=cache_read_input_tokens,
+                cache_write_input_tokens=cache_write_input_tokens,
                 content_blocks=content_blocks,
                 finish_span=True,
             )
@@ -363,13 +407,26 @@ def new_iterator() -> "Iterator[MessageStreamEvent]":
             model = None
            input_tokens = 0
             output_tokens = 0
+            cache_read_input_tokens = 0
+            cache_write_input_tokens = 0
             content_blocks: "list[str]" = []
 
             for event in old_iterator:
-                model, input_tokens, output_tokens, content_blocks = (
-                    _collect_ai_data(
-                        event, model, input_tokens, output_tokens, content_blocks
-                    )
+                (
+                    model,
+                    input_tokens,
+                    output_tokens,
+                    cache_read_input_tokens,
+                    cache_write_input_tokens,
+                    content_blocks,
+                ) = _collect_ai_data(
+                    event,
+                    model,
+                    input_tokens,
+                    output_tokens,
+                    cache_read_input_tokens,
+                    cache_write_input_tokens,
+                    content_blocks,
                 )
                 yield event
 
@@ -379,6 +436,8 @@ def new_iterator() -> "Iterator[MessageStreamEvent]":
                 model=model,
                 input_tokens=input_tokens,
                 output_tokens=output_tokens,
+                cache_read_input_tokens=cache_read_input_tokens,
+                cache_write_input_tokens=cache_write_input_tokens,
                 content_blocks=[{"text": "".join(content_blocks), "type": "text"}],
                 finish_span=True,
             )
@@ -387,13 +446,26 @@ async def new_iterator_async() -> "AsyncIterator[MessageStreamEvent]":
             model = None
             input_tokens = 0
             output_tokens = 0
+            cache_read_input_tokens = 0
+            cache_write_input_tokens = 0
             content_blocks: "list[str]" = []
 
             async for event in old_iterator:
-                model, input_tokens, output_tokens, content_blocks = (
-                    _collect_ai_data(
-                        event, model, input_tokens, output_tokens, content_blocks
-                    )
+                (
+                    model,
+                    input_tokens,
+                    output_tokens,
+                    cache_read_input_tokens,
+                    cache_write_input_tokens,
+                    content_blocks,
+                ) = _collect_ai_data(
+                    event,
+                    model,
+                    input_tokens,
+                    output_tokens,
+                    cache_read_input_tokens,
+                    cache_write_input_tokens,
+                    content_blocks,
                )
                 yield event
 
@@ -403,6 +475,8 @@ async def new_iterator_async() -> "AsyncIterator[MessageStreamEvent]":
                 model=model,
                 input_tokens=input_tokens,
                 output_tokens=output_tokens,
+                cache_read_input_tokens=cache_read_input_tokens,
+                cache_write_input_tokens=cache_write_input_tokens,
                 content_blocks=[{"text": "".join(content_blocks), "type": "text"}],
                 finish_span=True,
             )
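
A quick sketch (not part of the commit) of what the reworked `_get_token_usage` returns for an Anthropic-style response; `SimpleNamespace` stands in for the SDK's `Message`/`Usage` types:

from types import SimpleNamespace
from sentry_sdk.integrations.anthropic import _get_token_usage

# Duck-typed stand-in for an Anthropic Message; cache_creation_input_tokens
# is Anthropic's name for tokens written to the prompt cache.
result = SimpleNamespace(
    usage=SimpleNamespace(
        input_tokens=100,
        output_tokens=50,
        cache_read_input_tokens=80,
        cache_creation_input_tokens=20,
    )
)

# (input_tokens, output_tokens, cache_read, cache_write)
assert _get_token_usage(result) == (100, 50, 80, 20)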

sentry_sdk/integrations/pydantic_ai/spans/utils.py

Lines changed: 13 additions & 0 deletions
@@ -28,6 +28,19 @@ def _set_usage_data(
     if hasattr(usage, "input_tokens") and usage.input_tokens is not None:
         span.set_data(SPANDATA.GEN_AI_USAGE_INPUT_TOKENS, usage.input_tokens)
 
+    # Pydantic AI uses cache_read_tokens (not input_tokens_cached)
+    if hasattr(usage, "cache_read_tokens") and usage.cache_read_tokens is not None:
+        span.set_data(
+            SPANDATA.GEN_AI_USAGE_INPUT_TOKENS_CACHED, usage.cache_read_tokens
+        )
+
+    # Pydantic AI uses cache_write_tokens (not input_tokens_cache_write)
+    if hasattr(usage, "cache_write_tokens") and usage.cache_write_tokens is not None:
+        span.set_data(
+            SPANDATA.GEN_AI_USAGE_INPUT_TOKENS_CACHE_WRITE,
+            usage.cache_write_tokens,
+        )
+
     if hasattr(usage, "output_tokens") and usage.output_tokens is not None:
         span.set_data(SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS, usage.output_tokens)
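
As the diff's comments note, Pydantic AI names its cache counters differently from Anthropic; a hedged sketch of the duck-typed usage object `_set_usage_data` inspects:

from types import SimpleNamespace

usage = SimpleNamespace(
    input_tokens=100,
    output_tokens=50,
    cache_read_tokens=80,   # mapped to SPANDATA.GEN_AI_USAGE_INPUT_TOKENS_CACHED
    cache_write_tokens=20,  # mapped to SPANDATA.GEN_AI_USAGE_INPUT_TOKENS_CACHE_WRITE
)
# _set_usage_data(span, usage) would then record both cache counters on the span.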

tests/integrations/anthropic/test_anthropic.py

Lines changed: 86 additions & 3 deletions
@@ -853,10 +853,11 @@ def test_collect_ai_data_with_input_json_delta():
     output_tokens = 20
     content_blocks = []
 
-    model, new_input_tokens, new_output_tokens, new_content_blocks = _collect_ai_data(
-        event, model, input_tokens, output_tokens, content_blocks
+    model, new_input_tokens, new_output_tokens, _, _, new_content_blocks = (
+        _collect_ai_data(
+            event, model, input_tokens, output_tokens, 0, 0, content_blocks
+        )
     )
-
     assert model is None
     assert new_input_tokens == input_tokens
     assert new_output_tokens == output_tokens
@@ -884,6 +885,8 @@ def test_set_output_data_with_input_json_delta(sentry_init):
         model="",
         input_tokens=10,
         output_tokens=20,
+        cache_read_input_tokens=0,
+        cache_write_input_tokens=0,
         content_blocks=[{"text": "".join(json_deltas), "type": "text"}],
     )
 
@@ -2154,3 +2157,83 @@ def test_binary_content_not_stored_when_prompts_disabled(sentry_init, capture_ev
 
     # Messages should not be stored
     assert SPANDATA.GEN_AI_REQUEST_MESSAGES not in span["data"]
+
+
+def test_cache_tokens_nonstreaming(sentry_init, capture_events):
+    """Test cache read/write tokens are tracked for non-streaming responses."""
+    sentry_init(integrations=[AnthropicIntegration()], traces_sample_rate=1.0)
+    events = capture_events()
+    client = Anthropic(api_key="z")
+
+    client.messages._post = mock.Mock(
+        return_value=Message(
+            id="id",
+            model="claude-3-5-sonnet-20241022",
+            role="assistant",
+            content=[TextBlock(type="text", text="Response")],
+            type="message",
+            usage=Usage(
+                input_tokens=100,
+                output_tokens=50,
+                cache_read_input_tokens=80,
+                cache_creation_input_tokens=20,
+            ),
+        )
+    )
+
+    with start_transaction(name="anthropic"):
+        client.messages.create(
+            max_tokens=1024,
+            messages=[{"role": "user", "content": "Hello"}],
+            model="claude-3-5-sonnet-20241022",
+        )
+
+    (span,) = events[0]["spans"]
+    assert span["data"][SPANDATA.GEN_AI_USAGE_INPUT_TOKENS_CACHED] == 80
+    assert span["data"][SPANDATA.GEN_AI_USAGE_INPUT_TOKENS_CACHE_WRITE] == 20
+
+
+def test_cache_tokens_streaming(sentry_init, capture_events):
+    """Test cache tokens are tracked for streaming responses."""
+    client = Anthropic(api_key="z")
+    returned_stream = Stream(cast_to=None, response=None, client=client)
+    returned_stream._iterator = [
+        MessageStartEvent(
+            type="message_start",
+            message=Message(
+                id="id",
+                model="claude-3-5-sonnet-20241022",
+                role="assistant",
+                content=[],
+                type="message",
+                usage=Usage(
+                    input_tokens=100,
+                    output_tokens=0,
+                    cache_read_input_tokens=80,
+                    cache_creation_input_tokens=20,
+                ),
+            ),
+        ),
+        MessageDeltaEvent(
+            type="message_delta",
+            delta=Delta(stop_reason="end_turn"),
+            usage=MessageDeltaUsage(output_tokens=10),
+        ),
+    ]
+
+    sentry_init(integrations=[AnthropicIntegration()], traces_sample_rate=1.0)
+    events = capture_events()
+    client.messages._post = mock.Mock(return_value=returned_stream)
+
+    with start_transaction(name="anthropic"):
+        for _ in client.messages.create(
+            max_tokens=1024,
+            messages=[{"role": "user", "content": "Hello"}],
+            model="claude-3-5-sonnet-20241022",
+            stream=True,
+        ):
+            pass
+
+    (span,) = events[0]["spans"]
+    assert span["data"][SPANDATA.GEN_AI_USAGE_INPUT_TOKENS_CACHED] == 80
+    assert span["data"][SPANDATA.GEN_AI_USAGE_INPUT_TOKENS_CACHE_WRITE] == 20
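
Assuming the Anthropic test dependencies are installed, the two new tests can be run in isolation with something like `pytest tests/integrations/anthropic/test_anthropic.py -k cache_tokens`.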
