diff --git a/sentry_sdk/ai/monitoring.py b/sentry_sdk/ai/monitoring.py
index 5655712d53..581e967bd4 100644
--- a/sentry_sdk/ai/monitoring.py
+++ b/sentry_sdk/ai/monitoring.py
@@ -100,6 +100,7 @@ def record_token_usage(
     span: "Span",
     input_tokens: "Optional[int]" = None,
     input_tokens_cached: "Optional[int]" = None,
+    input_tokens_cache_write: "Optional[int]" = None,
     output_tokens: "Optional[int]" = None,
     output_tokens_reasoning: "Optional[int]" = None,
     total_tokens: "Optional[int]" = None,
@@ -118,6 +119,12 @@ def record_token_usage(
             input_tokens_cached,
         )
 
+    if input_tokens_cache_write is not None:
+        span.set_data(
+            SPANDATA.GEN_AI_USAGE_INPUT_TOKENS_CACHE_WRITE,
+            input_tokens_cache_write,
+        )
+
     if output_tokens is not None:
         span.set_data(SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS, output_tokens)
 
diff --git a/sentry_sdk/consts.py b/sentry_sdk/consts.py
index 59d3997c9a..e53533018f 100644
--- a/sentry_sdk/consts.py
+++ b/sentry_sdk/consts.py
@@ -632,6 +632,12 @@ class SPANDATA:
     Example: 50
     """
 
+    GEN_AI_USAGE_INPUT_TOKENS_CACHE_WRITE = "gen_ai.usage.input_tokens.cache_write"
+    """
+    The number of tokens written to the cache when processing the AI input (prompt).
+    Example: 100
+    """
+
     GEN_AI_USAGE_OUTPUT_TOKENS = "gen_ai.usage.output_tokens"
     """
     The number of tokens in the output.
diff --git a/sentry_sdk/integrations/anthropic.py b/sentry_sdk/integrations/anthropic.py
index 45d810da72..2bc48e54e3 100644
--- a/sentry_sdk/integrations/anthropic.py
+++ b/sentry_sdk/integrations/anthropic.py
@@ -75,20 +75,36 @@ def _capture_exception(exc: "Any") -> None:
     sentry_sdk.capture_event(event, hint=hint)
 
 
-def _get_token_usage(result: "Messages") -> "tuple[int, int]":
+def _get_token_usage(result: "Messages") -> "tuple[int, int, int, int]":
     """
     Get token usage from the Anthropic response.
+    Returns: (input_tokens, output_tokens, cache_read_input_tokens, cache_write_input_tokens)
     """
     input_tokens = 0
     output_tokens = 0
+    cache_read_input_tokens = 0
+    cache_write_input_tokens = 0
     if hasattr(result, "usage"):
         usage = result.usage
         if hasattr(usage, "input_tokens") and isinstance(usage.input_tokens, int):
             input_tokens = usage.input_tokens
         if hasattr(usage, "output_tokens") and isinstance(usage.output_tokens, int):
             output_tokens = usage.output_tokens
-
-    return input_tokens, output_tokens
+        if hasattr(usage, "cache_read_input_tokens") and isinstance(
+            usage.cache_read_input_tokens, int
+        ):
+            cache_read_input_tokens = usage.cache_read_input_tokens
+        if hasattr(usage, "cache_creation_input_tokens") and isinstance(
+            usage.cache_creation_input_tokens, int
+        ):
+            cache_write_input_tokens = usage.cache_creation_input_tokens
+
+    return (
+        input_tokens,
+        output_tokens,
+        cache_read_input_tokens,
+        cache_write_input_tokens,
+    )
 
 
 def _collect_ai_data(
@@ -96,8 +112,10 @@ def _collect_ai_data(
     model: "str | None",
     input_tokens: int,
     output_tokens: int,
+    cache_read_input_tokens: int,
+    cache_write_input_tokens: int,
     content_blocks: "list[str]",
-) -> "tuple[str | None, int, int, list[str]]":
+) -> "tuple[str | None, int, int, int, int, list[str]]":
     """
     Collect model information, token usage, and collect content blocks from the AI streaming response.
     """
@@ -107,6 +125,14 @@ def _collect_ai_data(
                 usage = event.message.usage
                 input_tokens += usage.input_tokens
                 output_tokens += usage.output_tokens
+                if hasattr(usage, "cache_read_input_tokens") and isinstance(
+                    usage.cache_read_input_tokens, int
+                ):
+                    cache_read_input_tokens += usage.cache_read_input_tokens
+                if hasattr(usage, "cache_creation_input_tokens") and isinstance(
+                    usage.cache_creation_input_tokens, int
+                ):
+                    cache_write_input_tokens += usage.cache_creation_input_tokens
                 model = event.message.model or model
             elif event.type == "content_block_start":
                 pass
@@ -120,7 +146,14 @@ def _collect_ai_data(
             elif event.type == "message_delta":
                 output_tokens += event.usage.output_tokens
 
-    return model, input_tokens, output_tokens, content_blocks
+    return (
+        model,
+        input_tokens,
+        output_tokens,
+        cache_read_input_tokens,
+        cache_write_input_tokens,
+        content_blocks,
+    )
 
 
 def _transform_anthropic_content_block(
@@ -265,6 +298,8 @@ def _set_output_data(
     model: "str | None",
     input_tokens: "int | None",
     output_tokens: "int | None",
+    cache_read_input_tokens: "int | None",
+    cache_write_input_tokens: "int | None",
     content_blocks: "list[Any]",
     finish_span: bool = False,
 ) -> None:
@@ -300,6 +335,8 @@ def _set_output_data(
         span,
         input_tokens=input_tokens,
         output_tokens=output_tokens,
+        input_tokens_cached=cache_read_input_tokens,
+        input_tokens_cache_write=cache_write_input_tokens,
     )
 
     if finish_span:
@@ -334,7 +371,12 @@ def _sentry_patched_create_common(f: "Any", *args: "Any", **kwargs: "Any") -> "A
 
     with capture_internal_exceptions():
         if hasattr(result, "content"):
-            input_tokens, output_tokens = _get_token_usage(result)
+            (
+                input_tokens,
+                output_tokens,
+                cache_read_input_tokens,
+                cache_write_input_tokens,
+            ) = _get_token_usage(result)
 
             content_blocks = []
             for content_block in result.content:
@@ -351,6 +393,8 @@ def _sentry_patched_create_common(f: "Any", *args: "Any", **kwargs: "Any") -> "A
                 model=getattr(result, "model", None),
                 input_tokens=input_tokens,
                 output_tokens=output_tokens,
+                cache_read_input_tokens=cache_read_input_tokens,
+                cache_write_input_tokens=cache_write_input_tokens,
                 content_blocks=content_blocks,
                 finish_span=True,
             )
@@ -363,13 +407,26 @@ def new_iterator() -> "Iterator[MessageStreamEvent]":
                 model = None
                 input_tokens = 0
                 output_tokens = 0
+                cache_read_input_tokens = 0
+                cache_write_input_tokens = 0
                 content_blocks: "list[str]" = []
 
                 for event in old_iterator:
-                    model, input_tokens, output_tokens, content_blocks = (
-                        _collect_ai_data(
-                            event, model, input_tokens, output_tokens, content_blocks
-                        )
+                    (
+                        model,
+                        input_tokens,
+                        output_tokens,
+                        cache_read_input_tokens,
+                        cache_write_input_tokens,
+                        content_blocks,
+                    ) = _collect_ai_data(
+                        event,
+                        model,
+                        input_tokens,
+                        output_tokens,
+                        cache_read_input_tokens,
+                        cache_write_input_tokens,
+                        content_blocks,
                     )
                     yield event
 
@@ -379,6 +436,8 @@ def new_iterator() -> "Iterator[MessageStreamEvent]":
                     model=model,
                     input_tokens=input_tokens,
                     output_tokens=output_tokens,
+                    cache_read_input_tokens=cache_read_input_tokens,
+                    cache_write_input_tokens=cache_write_input_tokens,
                     content_blocks=[{"text": "".join(content_blocks), "type": "text"}],
                     finish_span=True,
                 )
@@ -387,13 +446,26 @@ async def new_iterator_async() -> "AsyncIterator[MessageStreamEvent]":
                 model = None
                 input_tokens = 0
                 output_tokens = 0
+                cache_read_input_tokens = 0
+                cache_write_input_tokens = 0
                 content_blocks: "list[str]" = []
 
                 async for event in old_iterator:
-                    model, input_tokens, output_tokens, content_blocks = (
-                        _collect_ai_data(
-                            event, model, input_tokens, output_tokens, content_blocks
-                        )
+                    (
+                        model,
+                        input_tokens,
+                        output_tokens,
+                        cache_read_input_tokens,
+                        cache_write_input_tokens,
+                        content_blocks,
+                    ) = _collect_ai_data(
+                        event,
+                        model,
+                        input_tokens,
+                        output_tokens,
+                        cache_read_input_tokens,
+                        cache_write_input_tokens,
+                        content_blocks,
                     )
                     yield event
 
@@ -403,6 +475,8 @@ async def new_iterator_async() -> "AsyncIterator[MessageStreamEvent]":
                     model=model,
                     input_tokens=input_tokens,
                     output_tokens=output_tokens,
+                    cache_read_input_tokens=cache_read_input_tokens,
+                    cache_write_input_tokens=cache_write_input_tokens,
                     content_blocks=[{"text": "".join(content_blocks), "type": "text"}],
                     finish_span=True,
                 )
diff --git a/sentry_sdk/integrations/pydantic_ai/spans/utils.py b/sentry_sdk/integrations/pydantic_ai/spans/utils.py
index 89fef172e1..5d71ade1e6 100644
--- a/sentry_sdk/integrations/pydantic_ai/spans/utils.py
+++ b/sentry_sdk/integrations/pydantic_ai/spans/utils.py
@@ -28,8 +28,30 @@ def _set_usage_data(
     if hasattr(usage, "input_tokens") and usage.input_tokens is not None:
         span.set_data(SPANDATA.GEN_AI_USAGE_INPUT_TOKENS, usage.input_tokens)
 
+    # Pydantic AI uses cache_read_tokens (not input_tokens_cached)
+    if hasattr(usage, "cache_read_tokens") and usage.cache_read_tokens is not None:
+        span.set_data(
+            SPANDATA.GEN_AI_USAGE_INPUT_TOKENS_CACHED, usage.cache_read_tokens
+        )
+
+    # Pydantic AI uses cache_write_tokens (not input_tokens_cache_write)
+    if hasattr(usage, "cache_write_tokens") and usage.cache_write_tokens is not None:
+        span.set_data(
+            SPANDATA.GEN_AI_USAGE_INPUT_TOKENS_CACHE_WRITE,
+            usage.cache_write_tokens,
+        )
+
     if hasattr(usage, "output_tokens") and usage.output_tokens is not None:
         span.set_data(SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS, usage.output_tokens)
 
+    if (
+        hasattr(usage, "output_tokens_reasoning")
+        and usage.output_tokens_reasoning is not None
+    ):
+        span.set_data(
+            SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS_REASONING,
+            usage.output_tokens_reasoning,
+        )
+
     if hasattr(usage, "total_tokens") and usage.total_tokens is not None:
         span.set_data(SPANDATA.GEN_AI_USAGE_TOTAL_TOKENS, usage.total_tokens)
diff --git a/tests/integrations/anthropic/test_anthropic.py b/tests/integrations/anthropic/test_anthropic.py
index a8b2feba37..e8bc4648b6 100644
--- a/tests/integrations/anthropic/test_anthropic.py
+++ b/tests/integrations/anthropic/test_anthropic.py
@@ -853,10 +853,11 @@ def test_collect_ai_data_with_input_json_delta():
     output_tokens = 20
     content_blocks = []
 
-    model, new_input_tokens, new_output_tokens, new_content_blocks = _collect_ai_data(
-        event, model, input_tokens, output_tokens, content_blocks
+    model, new_input_tokens, new_output_tokens, _, _, new_content_blocks = (
+        _collect_ai_data(
+            event, model, input_tokens, output_tokens, 0, 0, content_blocks
+        )
     )
-
     assert model is None
     assert new_input_tokens == input_tokens
     assert new_output_tokens == output_tokens
@@ -884,6 +885,8 @@ def test_set_output_data_with_input_json_delta(sentry_init):
             model="",
             input_tokens=10,
             output_tokens=20,
+            cache_read_input_tokens=0,
+            cache_write_input_tokens=0,
             content_blocks=[{"text": "".join(json_deltas), "type": "text"}],
         )
 
@@ -2154,3 +2157,83 @@ def test_binary_content_not_stored_when_prompts_disabled(sentry_init, capture_ev
 
     # Messages should not be stored
     assert SPANDATA.GEN_AI_REQUEST_MESSAGES not in span["data"]
+
+
+def test_cache_tokens_nonstreaming(sentry_init, capture_events):
+    """Test cache read/write tokens are tracked for non-streaming responses."""
+    sentry_init(integrations=[AnthropicIntegration()], traces_sample_rate=1.0)
+    events = capture_events()
+    client = Anthropic(api_key="z")
+
+    client.messages._post = mock.Mock(
+        return_value=Message(
+            id="id",
+            model="claude-3-5-sonnet-20241022",
+            role="assistant",
+            content=[TextBlock(type="text", text="Response")],
+            type="message",
+            usage=Usage(
+                input_tokens=100,
+                output_tokens=50,
+                cache_read_input_tokens=80,
+                cache_creation_input_tokens=20,
+            ),
+        )
+    )
+
+    with start_transaction(name="anthropic"):
+        client.messages.create(
+            max_tokens=1024,
+            messages=[{"role": "user", "content": "Hello"}],
+            model="claude-3-5-sonnet-20241022",
+        )
+
+    (span,) = events[0]["spans"]
+    assert span["data"][SPANDATA.GEN_AI_USAGE_INPUT_TOKENS_CACHED] == 80
+    assert span["data"][SPANDATA.GEN_AI_USAGE_INPUT_TOKENS_CACHE_WRITE] == 20
+
+
+def test_cache_tokens_streaming(sentry_init, capture_events):
+    """Test cache tokens are tracked for streaming responses."""
+    client = Anthropic(api_key="z")
+    returned_stream = Stream(cast_to=None, response=None, client=client)
+    returned_stream._iterator = [
+        MessageStartEvent(
+            type="message_start",
+            message=Message(
+                id="id",
+                model="claude-3-5-sonnet-20241022",
+                role="assistant",
+                content=[],
+                type="message",
+                usage=Usage(
+                    input_tokens=100,
+                    output_tokens=0,
+                    cache_read_input_tokens=80,
+                    cache_creation_input_tokens=20,
+                ),
+            ),
+        ),
+        MessageDeltaEvent(
+            type="message_delta",
+            delta=Delta(stop_reason="end_turn"),
+            usage=MessageDeltaUsage(output_tokens=10),
+        ),
+    ]
+
+    sentry_init(integrations=[AnthropicIntegration()], traces_sample_rate=1.0)
+    events = capture_events()
+    client.messages._post = mock.Mock(return_value=returned_stream)
+
+    with start_transaction(name="anthropic"):
+        for _ in client.messages.create(
+            max_tokens=1024,
+            messages=[{"role": "user", "content": "Hello"}],
+            model="claude-3-5-sonnet-20241022",
+            stream=True,
+        ):
+            pass
+
+    (span,) = events[0]["spans"]
+    assert span["data"][SPANDATA.GEN_AI_USAGE_INPUT_TOKENS_CACHED] == 80
+    assert span["data"][SPANDATA.GEN_AI_USAGE_INPUT_TOKENS_CACHE_WRITE] == 20
diff --git a/tests/integrations/pydantic_ai/test_pydantic_ai.py b/tests/integrations/pydantic_ai/test_pydantic_ai.py
index 7387f1be15..fc6a27d714 100644
--- a/tests/integrations/pydantic_ai/test_pydantic_ai.py
+++ b/tests/integrations/pydantic_ai/test_pydantic_ai.py
@@ -2730,3 +2730,32 @@ async def test_binary_content_in_agent_run(sentry_init, capture_events):
     if "gen_ai.request.messages" in chat_span["data"]:
         messages_str = str(chat_span["data"]["gen_ai.request.messages"])
         assert any(keyword in messages_str for keyword in ["blob", "image", "base64"])
+
+
+@pytest.mark.asyncio
+async def test_set_usage_data_with_cache_tokens(sentry_init, capture_events):
+    """Test that cache_read_tokens and cache_write_tokens are tracked."""
+    import sentry_sdk
+    from pydantic_ai.usage import RequestUsage
+    from sentry_sdk.integrations.pydantic_ai.spans.utils import _set_usage_data
+    from sentry_sdk.consts import SPANDATA
+
+    sentry_init(integrations=[PydanticAIIntegration()], traces_sample_rate=1.0)
+
+    events = capture_events()
+
+    with sentry_sdk.start_transaction(op="test", name="test"):
+        span = sentry_sdk.start_span(op="test_span")
+        usage = RequestUsage(
+            input_tokens=100,
+            output_tokens=50,
+            cache_read_tokens=80,
+            cache_write_tokens=20,
+        )
+        _set_usage_data(span, usage)
+        span.finish()
+
+    (event,) = events
+    (span_data,) = event["spans"]
+    assert span_data["data"][SPANDATA.GEN_AI_USAGE_INPUT_TOKENS_CACHED] == 80
+    assert span_data["data"][SPANDATA.GEN_AI_USAGE_INPUT_TOKENS_CACHE_WRITE] == 20