From 5c9cffb159e445af0e7b47aebb9951dc0d7030f0 Mon Sep 17 00:00:00 2001 From: Simon Hellmayr Date: Wed, 14 Jan 2026 11:15:57 +0100 Subject: [PATCH 1/3] feat(ai): add cache writes for gen_ai --- sentry_sdk/ai/__init__.py | 10 + sentry_sdk/ai/monitoring.py | 7 + sentry_sdk/consts.py | 6 + sentry_sdk/integrations/anthropic.py | 102 +++++- sentry_sdk/integrations/litellm.py | 4 + sentry_sdk/integrations/openai.py | 11 +- .../integrations/openai_agents/utils.py | 24 +- .../integrations/pydantic_ai/spans/utils.py | 22 ++ .../integrations/anthropic/test_anthropic.py | 286 +++++++++++++++++ tests/integrations/litellm/test_litellm.py | 175 +++++++++++ tests/integrations/openai/test_openai.py | 292 +++++++++++++++++- 11 files changed, 915 insertions(+), 24 deletions(-) diff --git a/sentry_sdk/ai/__init__.py b/sentry_sdk/ai/__init__.py index fbcb9c061d..6eeeed3d76 100644 --- a/sentry_sdk/ai/__init__.py +++ b/sentry_sdk/ai/__init__.py @@ -1,3 +1,4 @@ +from .monitoring import record_token_usage # noqa: F401 from .utils import ( set_data_normalized, GEN_AI_MESSAGE_ROLE_MAPPING, @@ -5,3 +6,12 @@ normalize_message_role, normalize_message_roles, ) # noqa: F401 + +__all__ = [ + "record_token_usage", + "set_data_normalized", + "GEN_AI_MESSAGE_ROLE_MAPPING", + "GEN_AI_MESSAGE_ROLE_REVERSE_MAPPING", + "normalize_message_role", + "normalize_message_roles", +] diff --git a/sentry_sdk/ai/monitoring.py b/sentry_sdk/ai/monitoring.py index e7e00ad462..33c397bbb5 100644 --- a/sentry_sdk/ai/monitoring.py +++ b/sentry_sdk/ai/monitoring.py @@ -95,6 +95,7 @@ def record_token_usage( span: "Span", input_tokens: "Optional[int]" = None, input_tokens_cached: "Optional[int]" = None, + input_tokens_cache_write: "Optional[int]" = None, output_tokens: "Optional[int]" = None, output_tokens_reasoning: "Optional[int]" = None, total_tokens: "Optional[int]" = None, @@ -113,6 +114,12 @@ def record_token_usage( input_tokens_cached, ) + if input_tokens_cache_write is not None: + span.set_data( + SPANDATA.GEN_AI_USAGE_INPUT_TOKENS_CACHE_WRITE, + input_tokens_cache_write, + ) + if output_tokens is not None: span.set_data(SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS, output_tokens) diff --git a/sentry_sdk/consts.py b/sentry_sdk/consts.py index 59d3997c9a..e53533018f 100644 --- a/sentry_sdk/consts.py +++ b/sentry_sdk/consts.py @@ -632,6 +632,12 @@ class SPANDATA: Example: 50 """ + GEN_AI_USAGE_INPUT_TOKENS_CACHE_WRITE = "gen_ai.usage.input_tokens.cache_write" + """ + The number of tokens written to the cache when processing the AI input (prompt). + Example: 100 + """ + GEN_AI_USAGE_OUTPUT_TOKENS = "gen_ai.usage.output_tokens" """ The number of tokens in the output. diff --git a/sentry_sdk/integrations/anthropic.py b/sentry_sdk/integrations/anthropic.py index 5257e3bf60..6fa00a38e6 100644 --- a/sentry_sdk/integrations/anthropic.py +++ b/sentry_sdk/integrations/anthropic.py @@ -72,20 +72,36 @@ def _capture_exception(exc: "Any") -> None: sentry_sdk.capture_event(event, hint=hint) -def _get_token_usage(result: "Messages") -> "tuple[int, int]": +def _get_token_usage(result: "Messages") -> "tuple[int, int, int, int]": """ Get token usage from the Anthropic response. 
+ Returns: (input_tokens, output_tokens, cache_read_input_tokens, cache_write_input_tokens) """ input_tokens = 0 output_tokens = 0 + cache_read_input_tokens = 0 + cache_write_input_tokens = 0 if hasattr(result, "usage"): usage = result.usage if hasattr(usage, "input_tokens") and isinstance(usage.input_tokens, int): input_tokens = usage.input_tokens if hasattr(usage, "output_tokens") and isinstance(usage.output_tokens, int): output_tokens = usage.output_tokens - - return input_tokens, output_tokens + if hasattr(usage, "cache_read_input_tokens") and isinstance( + usage.cache_read_input_tokens, int + ): + cache_read_input_tokens = usage.cache_read_input_tokens + if hasattr(usage, "cache_creation_input_tokens") and isinstance( + usage.cache_creation_input_tokens, int + ): + cache_write_input_tokens = usage.cache_creation_input_tokens + + return ( + input_tokens, + output_tokens, + cache_read_input_tokens, + cache_write_input_tokens, + ) def _collect_ai_data( @@ -93,8 +109,10 @@ def _collect_ai_data( model: "str | None", input_tokens: int, output_tokens: int, + cache_read_input_tokens: int, + cache_write_input_tokens: int, content_blocks: "list[str]", -) -> "tuple[str | None, int, int, list[str]]": +) -> "tuple[str | None, int, int, int, int, list[str]]": """ Collect model information, token usage, and collect content blocks from the AI streaming response. """ @@ -104,6 +122,14 @@ def _collect_ai_data( usage = event.message.usage input_tokens += usage.input_tokens output_tokens += usage.output_tokens + if hasattr(usage, "cache_read_input_tokens") and isinstance( + usage.cache_read_input_tokens, int + ): + cache_read_input_tokens += usage.cache_read_input_tokens + if hasattr(usage, "cache_creation_input_tokens") and isinstance( + usage.cache_creation_input_tokens, int + ): + cache_write_input_tokens += usage.cache_creation_input_tokens model = event.message.model or model elif event.type == "content_block_start": pass @@ -117,7 +143,14 @@ def _collect_ai_data( elif event.type == "message_delta": output_tokens += event.usage.output_tokens - return model, input_tokens, output_tokens, content_blocks + return ( + model, + input_tokens, + output_tokens, + cache_read_input_tokens, + cache_write_input_tokens, + content_blocks, + ) def _set_input_data( @@ -219,6 +252,8 @@ def _set_output_data( model: "str | None", input_tokens: "int | None", output_tokens: "int | None", + cache_read_input_tokens: "int | None", + cache_write_input_tokens: "int | None", content_blocks: "list[Any]", finish_span: bool = False, ) -> None: @@ -254,6 +289,8 @@ def _set_output_data( span, input_tokens=input_tokens, output_tokens=output_tokens, + input_tokens_cached=cache_read_input_tokens, + input_tokens_cache_write=cache_write_input_tokens, ) if finish_span: @@ -288,7 +325,12 @@ def _sentry_patched_create_common(f: "Any", *args: "Any", **kwargs: "Any") -> "A with capture_internal_exceptions(): if hasattr(result, "content"): - input_tokens, output_tokens = _get_token_usage(result) + ( + input_tokens, + output_tokens, + cache_read_input_tokens, + cache_write_input_tokens, + ) = _get_token_usage(result) content_blocks = [] for content_block in result.content: @@ -305,6 +347,8 @@ def _sentry_patched_create_common(f: "Any", *args: "Any", **kwargs: "Any") -> "A model=getattr(result, "model", None), input_tokens=input_tokens, output_tokens=output_tokens, + cache_read_input_tokens=cache_read_input_tokens, + cache_write_input_tokens=cache_write_input_tokens, content_blocks=content_blocks, finish_span=True, ) @@ -317,13 +361,26 @@ def 
new_iterator() -> "Iterator[MessageStreamEvent]": model = None input_tokens = 0 output_tokens = 0 + cache_read_input_tokens = 0 + cache_write_input_tokens = 0 content_blocks: "list[str]" = [] for event in old_iterator: - model, input_tokens, output_tokens, content_blocks = ( - _collect_ai_data( - event, model, input_tokens, output_tokens, content_blocks - ) + ( + model, + input_tokens, + output_tokens, + cache_read_input_tokens, + cache_write_input_tokens, + content_blocks, + ) = _collect_ai_data( + event, + model, + input_tokens, + output_tokens, + cache_read_input_tokens, + cache_write_input_tokens, + content_blocks, ) yield event @@ -333,6 +390,8 @@ def new_iterator() -> "Iterator[MessageStreamEvent]": model=model, input_tokens=input_tokens, output_tokens=output_tokens, + cache_read_input_tokens=cache_read_input_tokens, + cache_write_input_tokens=cache_write_input_tokens, content_blocks=[{"text": "".join(content_blocks), "type": "text"}], finish_span=True, ) @@ -341,13 +400,26 @@ async def new_iterator_async() -> "AsyncIterator[MessageStreamEvent]": model = None input_tokens = 0 output_tokens = 0 + cache_read_input_tokens = 0 + cache_write_input_tokens = 0 content_blocks: "list[str]" = [] async for event in old_iterator: - model, input_tokens, output_tokens, content_blocks = ( - _collect_ai_data( - event, model, input_tokens, output_tokens, content_blocks - ) + ( + model, + input_tokens, + output_tokens, + cache_read_input_tokens, + cache_write_input_tokens, + content_blocks, + ) = _collect_ai_data( + event, + model, + input_tokens, + output_tokens, + cache_read_input_tokens, + cache_write_input_tokens, + content_blocks, ) yield event @@ -357,6 +429,8 @@ async def new_iterator_async() -> "AsyncIterator[MessageStreamEvent]": model=model, input_tokens=input_tokens, output_tokens=output_tokens, + cache_read_input_tokens=cache_read_input_tokens, + cache_write_input_tokens=cache_write_input_tokens, content_blocks=[{"text": "".join(content_blocks), "type": "text"}], finish_span=True, ) diff --git a/sentry_sdk/integrations/litellm.py b/sentry_sdk/integrations/litellm.py index 06ae8a0782..0d34e2e949 100644 --- a/sentry_sdk/integrations/litellm.py +++ b/sentry_sdk/integrations/litellm.py @@ -192,6 +192,10 @@ def _success_callback( record_token_usage( span, input_tokens=getattr(usage, "prompt_tokens", None), + input_tokens_cached=getattr(usage, "cache_read_input_tokens", None), + input_tokens_cache_write=getattr( + usage, "cache_write_input_tokens", None + ), output_tokens=getattr(usage, "completion_tokens", None), total_tokens=getattr(usage, "total_tokens", None), ) diff --git a/sentry_sdk/integrations/openai.py b/sentry_sdk/integrations/openai.py index a2c7cc8d1d..13757d59d2 100644 --- a/sentry_sdk/integrations/openai.py +++ b/sentry_sdk/integrations/openai.py @@ -131,7 +131,12 @@ def _calculate_token_usage( if hasattr(response, "usage"): input_tokens = _get_usage(response.usage, ["input_tokens", "prompt_tokens"]) - if hasattr(response.usage, "input_tokens_details"): + if hasattr(response.usage, "prompt_tokens_details"): + input_tokens_cached = _get_usage( + response.usage.prompt_tokens_details, ["cached_tokens"] + ) + # OpenAI also supports input_tokens_details for compatibility + elif hasattr(response.usage, "input_tokens_details"): input_tokens_cached = _get_usage( response.usage.input_tokens_details, ["cached_tokens"] ) @@ -143,6 +148,10 @@ def _calculate_token_usage( output_tokens_reasoning = _get_usage( response.usage.output_tokens_details, ["reasoning_tokens"] ) + elif 
hasattr(response.usage, "completion_tokens_details"): + output_tokens_reasoning = _get_usage( + response.usage.completion_tokens_details, ["reasoning_tokens"] + ) total_tokens = _get_usage(response.usage, ["total_tokens"]) diff --git a/sentry_sdk/integrations/openai_agents/utils.py b/sentry_sdk/integrations/openai_agents/utils.py index a24d0e909d..d45b8546c6 100644 --- a/sentry_sdk/integrations/openai_agents/utils.py +++ b/sentry_sdk/integrations/openai_agents/utils.py @@ -96,15 +96,23 @@ def _set_agent_data(span: "sentry_sdk.tracing.Span", agent: "agents.Agent") -> N def _set_usage_data(span: "sentry_sdk.tracing.Span", usage: "Usage") -> None: span.set_data(SPANDATA.GEN_AI_USAGE_INPUT_TOKENS, usage.input_tokens) - span.set_data( - SPANDATA.GEN_AI_USAGE_INPUT_TOKENS_CACHED, - usage.input_tokens_details.cached_tokens, - ) + + if hasattr(usage, "input_tokens_details") and usage.input_tokens_details: + if hasattr(usage.input_tokens_details, "cached_tokens"): + span.set_data( + SPANDATA.GEN_AI_USAGE_INPUT_TOKENS_CACHED, + usage.input_tokens_details.cached_tokens, + ) + span.set_data(SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS, usage.output_tokens) - span.set_data( - SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS_REASONING, - usage.output_tokens_details.reasoning_tokens, - ) + + if hasattr(usage, "output_tokens_details") and usage.output_tokens_details: + if hasattr(usage.output_tokens_details, "reasoning_tokens"): + span.set_data( + SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS_REASONING, + usage.output_tokens_details.reasoning_tokens, + ) + span.set_data(SPANDATA.GEN_AI_USAGE_TOTAL_TOKENS, usage.total_tokens) diff --git a/sentry_sdk/integrations/pydantic_ai/spans/utils.py b/sentry_sdk/integrations/pydantic_ai/spans/utils.py index c70afd5f31..035234b16b 100644 --- a/sentry_sdk/integrations/pydantic_ai/spans/utils.py +++ b/sentry_sdk/integrations/pydantic_ai/spans/utils.py @@ -28,8 +28,30 @@ def _set_usage_data( if hasattr(usage, "input_tokens") and usage.input_tokens is not None: span.set_data(SPANDATA.GEN_AI_USAGE_INPUT_TOKENS, usage.input_tokens) + # Pydantic AI uses cache_read_tokens (not input_tokens_cached) + if hasattr(usage, "cache_read_tokens") and usage.cache_read_tokens is not None: + span.set_data( + SPANDATA.GEN_AI_USAGE_INPUT_TOKENS_CACHED, usage.cache_read_tokens + ) + + # Pydantic AI uses cache_write_tokens (not input_tokens_cache_write) + if hasattr(usage, "cache_write_tokens") and usage.cache_write_tokens is not None: + span.set_data( + SPANDATA.GEN_AI_USAGE_INPUT_TOKENS_CACHE_WRITE, + usage.cache_write_tokens, + ) + if hasattr(usage, "output_tokens") and usage.output_tokens is not None: span.set_data(SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS, usage.output_tokens) + if ( + hasattr(usage, "output_tokens_reasoning") + and usage.output_tokens_reasoning is not None + ): + span.set_data( + SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS_REASONING, + usage.output_tokens_reasoning, + ) + if hasattr(usage, "total_tokens") and usage.total_tokens is not None: span.set_data(SPANDATA.GEN_AI_USAGE_TOTAL_TOKENS, usage.total_tokens) diff --git a/tests/integrations/anthropic/test_anthropic.py b/tests/integrations/anthropic/test_anthropic.py index 2204505d47..b3bd0d44b3 100644 --- a/tests/integrations/anthropic/test_anthropic.py +++ b/tests/integrations/anthropic/test_anthropic.py @@ -1446,3 +1446,289 @@ def test_system_prompt_with_complex_structure(sentry_init, capture_events): assert stored_messages[0]["content"][1]["text"] == "Be concise and clear." 
assert stored_messages[1]["role"] == "user" assert stored_messages[1]["content"] == "Hello" + + +def test_cache_tokens_nonstreaming(sentry_init, capture_events): + """Test that cache read and write tokens are properly tracked for non-streaming responses.""" + sentry_init( + integrations=[AnthropicIntegration(include_prompts=True)], + traces_sample_rate=1.0, + send_default_pii=True, + ) + events = capture_events() + client = Anthropic(api_key="z") + + # Create a message with cache token usage + message_with_cache = Message( + id="id", + model="claude-3-5-sonnet-20241022", + role="assistant", + content=[TextBlock(type="text", text="Response using cache")], + type="message", + usage=Usage( + input_tokens=100, + output_tokens=50, + cache_read_input_tokens=80, # 80 tokens read from cache + cache_write_input_tokens=20, # 20 tokens written to cache + ), + ) + + client.messages._post = mock.Mock(return_value=message_with_cache) + + messages = [{"role": "user", "content": "Hello"}] + + with start_transaction(name="anthropic"): + response = client.messages.create( + max_tokens=1024, messages=messages, model="claude-3-5-sonnet-20241022" + ) + + assert response == message_with_cache + usage = response.usage + + assert usage.input_tokens == 100 + assert usage.output_tokens == 50 + assert usage.cache_read_input_tokens == 80 + assert usage.cache_write_input_tokens == 20 + + assert len(events) == 1 + (event,) = events + + assert event["type"] == "transaction" + assert len(event["spans"]) == 1 + (span,) = event["spans"] + + assert span["op"] == OP.GEN_AI_CHAT + assert span["data"][SPANDATA.GEN_AI_USAGE_INPUT_TOKENS] == 100 + assert span["data"][SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS] == 50 + assert span["data"][SPANDATA.GEN_AI_USAGE_TOTAL_TOKENS] == 150 + # Check cache-related tokens + assert span["data"][SPANDATA.GEN_AI_USAGE_INPUT_TOKENS_CACHED] == 80 + assert span["data"][SPANDATA.GEN_AI_USAGE_INPUT_TOKENS_CACHE_WRITE] == 20 + + +def test_cache_tokens_only_reads(sentry_init, capture_events): + """Test tracking when only cache reads are present (no writes).""" + sentry_init( + integrations=[AnthropicIntegration(include_prompts=True)], + traces_sample_rate=1.0, + send_default_pii=True, + ) + events = capture_events() + client = Anthropic(api_key="z") + + # Message with only cache reads, no writes + message_cache_read_only = Message( + id="id", + model="claude-3-5-sonnet-20241022", + role="assistant", + content=[TextBlock(type="text", text="Response")], + type="message", + usage=Usage( + input_tokens=100, + output_tokens=50, + cache_read_input_tokens=100, # All tokens read from cache + cache_write_input_tokens=0, # No new cache writes + ), + ) + + client.messages._post = mock.Mock(return_value=message_cache_read_only) + + with start_transaction(name="anthropic"): + client.messages.create( + max_tokens=1024, + messages=[{"role": "user", "content": "Hello"}], + model="claude-3-5-sonnet-20241022", + ) + + assert len(events) == 1 + (event,) = events + (span,) = event["spans"] + + assert span["data"][SPANDATA.GEN_AI_USAGE_INPUT_TOKENS_CACHED] == 100 + # Cache write should not be present when it's 0 + assert SPANDATA.GEN_AI_USAGE_INPUT_TOKENS_CACHE_WRITE not in span["data"] + + +def test_cache_tokens_streaming(sentry_init, capture_events): + """Test that cache tokens are tracked correctly for streaming responses.""" + sentry_init( + integrations=[AnthropicIntegration(include_prompts=True)], + traces_sample_rate=1.0, + send_default_pii=True, + ) + events = capture_events() + client = Anthropic(api_key="z") + + # 
Create streaming events with cache usage + stream_events = [ + MessageStartEvent( + type="message_start", + message=Message( + id="id", + model="claude-3-5-sonnet-20241022", + role="assistant", + content=[], + type="message", + usage=Usage( + input_tokens=100, + output_tokens=0, + cache_read_input_tokens=80, + cache_write_input_tokens=20, + ), + ), + ), + ContentBlockDeltaEvent( + type="content_block_delta", + index=0, + delta=TextDelta(type="text_delta", text="Hello"), + ), + MessageDeltaEvent( + type="message_delta", + delta=Delta(stop_reason="end_turn"), + usage=MessageDeltaUsage(output_tokens=10), + ), + ] + + mock_stream = mock.MagicMock(spec=Stream) + mock_stream.__iter__ = mock.Mock(return_value=iter(stream_events)) + mock_stream._iterator = iter(stream_events) + + client.messages._post = mock.Mock(return_value=mock_stream) + + with start_transaction(name="anthropic"): + stream = client.messages.create( + max_tokens=1024, + messages=[{"role": "user", "content": "Hello"}], + model="claude-3-5-sonnet-20241022", + stream=True, + ) + # Consume the stream + for _ in stream: + pass + + assert len(events) == 1 + (event,) = events + (span,) = event["spans"] + + assert span["op"] == OP.GEN_AI_CHAT + assert span["data"][SPANDATA.GEN_AI_USAGE_INPUT_TOKENS] == 100 + assert span["data"][SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS] == 10 + assert span["data"][SPANDATA.GEN_AI_USAGE_TOTAL_TOKENS] == 110 + # Check streaming cache tokens + assert span["data"][SPANDATA.GEN_AI_USAGE_INPUT_TOKENS_CACHED] == 80 + assert span["data"][SPANDATA.GEN_AI_USAGE_INPUT_TOKENS_CACHE_WRITE] == 20 + assert span["data"][SPANDATA.GEN_AI_RESPONSE_STREAMING] is True + + +@pytest.mark.asyncio +async def test_cache_tokens_streaming_async(sentry_init, capture_events): + """Test that cache tokens are tracked correctly for async streaming responses.""" + sentry_init( + integrations=[AnthropicIntegration(include_prompts=True)], + traces_sample_rate=1.0, + send_default_pii=True, + ) + events = capture_events() + client = AsyncAnthropic(api_key="z") + + async def async_iterator(values): + for value in values: + yield value + + # Create streaming events with cache usage + stream_events = [ + MessageStartEvent( + type="message_start", + message=Message( + id="id", + model="claude-3-5-sonnet-20241022", + role="assistant", + content=[], + type="message", + usage=Usage( + input_tokens=100, + output_tokens=0, + cache_read_input_tokens=80, + cache_write_input_tokens=20, + ), + ), + ), + ContentBlockDeltaEvent( + type="content_block_delta", + index=0, + delta=TextDelta(type="text_delta", text="Hello"), + ), + MessageDeltaEvent( + type="message_delta", + delta=Delta(stop_reason="end_turn"), + usage=MessageDeltaUsage(output_tokens=10), + ), + ] + + mock_stream = mock.MagicMock(spec=AsyncStream) + mock_stream.__aiter__ = mock.Mock(return_value=async_iterator(stream_events)) + mock_stream._iterator = async_iterator(stream_events) + + client.messages._post = mock.Mock(return_value=mock_stream) + + with start_transaction(name="anthropic"): + stream = await client.messages.create( + max_tokens=1024, + messages=[{"role": "user", "content": "Hello"}], + model="claude-3-5-sonnet-20241022", + stream=True, + ) + # Consume the stream + async for _ in stream: + pass + + assert len(events) == 1 + (event,) = events + (span,) = event["spans"] + + assert span["op"] == OP.GEN_AI_CHAT + assert span["data"][SPANDATA.GEN_AI_USAGE_INPUT_TOKENS] == 100 + assert span["data"][SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS] == 10 + # Check async streaming cache tokens + assert 
span["data"][SPANDATA.GEN_AI_USAGE_INPUT_TOKENS_CACHED] == 80 + assert span["data"][SPANDATA.GEN_AI_USAGE_INPUT_TOKENS_CACHE_WRITE] == 20 + + +def test_no_cache_tokens(sentry_init, capture_events): + """Test that requests without cache usage don't have cache fields.""" + sentry_init( + integrations=[AnthropicIntegration(include_prompts=True)], + traces_sample_rate=1.0, + send_default_pii=True, + ) + events = capture_events() + client = Anthropic(api_key="z") + + # Message without any cache usage + message_no_cache = Message( + id="id", + model="claude-3-5-sonnet-20241022", + role="assistant", + content=[TextBlock(type="text", text="Response")], + type="message", + usage=Usage(input_tokens=100, output_tokens=50), + ) + + client.messages._post = mock.Mock(return_value=message_no_cache) + + with start_transaction(name="anthropic"): + client.messages.create( + max_tokens=1024, + messages=[{"role": "user", "content": "Hello"}], + model="claude-3-5-sonnet-20241022", + ) + + assert len(events) == 1 + (event,) = events + (span,) = event["spans"] + + assert span["data"][SPANDATA.GEN_AI_USAGE_INPUT_TOKENS] == 100 + assert span["data"][SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS] == 50 + # Cache fields should not be present + assert SPANDATA.GEN_AI_USAGE_INPUT_TOKENS_CACHED not in span["data"] + assert SPANDATA.GEN_AI_USAGE_INPUT_TOKENS_CACHE_WRITE not in span["data"] diff --git a/tests/integrations/litellm/test_litellm.py b/tests/integrations/litellm/test_litellm.py index 1b925fb61f..7f4dc0842b 100644 --- a/tests/integrations/litellm/test_litellm.py +++ b/tests/integrations/litellm/test_litellm.py @@ -753,3 +753,178 @@ def test_litellm_message_truncation(sentry_init, capture_events): assert "small message 4" in str(parsed_messages[0]) assert "small message 5" in str(parsed_messages[1]) assert tx["_meta"]["spans"]["0"]["data"]["gen_ai.request.messages"][""]["len"] == 5 + + +def test_litellm_cache_tokens_anthropic(sentry_init, capture_events): + """Test that cache tokens are properly tracked when using Anthropic through LiteLLM.""" + sentry_init( + integrations=[LiteLLMIntegration(include_prompts=True)], + traces_sample_rate=1.0, + send_default_pii=True, + ) + events = capture_events() + + # Mock LiteLLM completion response with Anthropic cache tokens + mock_response = mock.Mock() + mock_response.model = "claude-3-5-sonnet-20241022" + mock_response.choices = [ + mock.Mock( + message=mock.Mock( + role="assistant", + content="Response using cache", + ) + ) + ] + mock_response.usage = mock.Mock( + prompt_tokens=100, + completion_tokens=50, + total_tokens=150, + cache_read_input_tokens=80, # Anthropic cache reads + cache_write_input_tokens=20, # Anthropic cache writes + ) + + with mock.patch("litellm.completion", return_value=mock_response): + with start_transaction(name="litellm"): + litellm.completion( + model="anthropic/claude-3-5-sonnet-20241022", + messages=[{"role": "user", "content": "Hello"}], + ) + + assert len(events) == 1 + (event,) = events + (span,) = event["spans"] + + assert span["op"] == OP.GEN_AI_CHAT + assert span["data"][SPANDATA.GEN_AI_USAGE_INPUT_TOKENS] == 100 + assert span["data"][SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS] == 50 + assert span["data"][SPANDATA.GEN_AI_USAGE_TOTAL_TOKENS] == 150 + # Check cache tokens from Anthropic provider + assert span["data"][SPANDATA.GEN_AI_USAGE_INPUT_TOKENS_CACHED] == 80 + assert span["data"][SPANDATA.GEN_AI_USAGE_INPUT_TOKENS_CACHE_WRITE] == 20 + + +def test_litellm_cache_tokens_openai(sentry_init, capture_events): + """Test that cached tokens from OpenAI 
through LiteLLM are tracked.""" + sentry_init( + integrations=[LiteLLMIntegration(include_prompts=True)], + traces_sample_rate=1.0, + send_default_pii=True, + ) + events = capture_events() + + # Mock LiteLLM completion response with OpenAI cache tokens + mock_response = mock.Mock() + mock_response.model = "gpt-4o" + mock_response.choices = [ + mock.Mock( + message=mock.Mock( + role="assistant", + content="Response", + ) + ) + ] + mock_response.usage = mock.Mock( + prompt_tokens=200, + completion_tokens=75, + total_tokens=275, + cache_read_input_tokens=150, # OpenAI cache reads + cache_write_input_tokens=0, # OpenAI doesn't charge for cache writes + ) + + with mock.patch("litellm.completion", return_value=mock_response): + with start_transaction(name="litellm"): + litellm.completion( + model="gpt-4o", + messages=[{"role": "user", "content": "Complex prompt"}], + ) + + assert len(events) == 1 + (event,) = events + (span,) = event["spans"] + + assert span["data"][SPANDATA.GEN_AI_USAGE_INPUT_TOKENS] == 200 + assert span["data"][SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS] == 75 + # Check cached tokens + assert span["data"][SPANDATA.GEN_AI_USAGE_INPUT_TOKENS_CACHED] == 150 + # Cache write should not be present when it's 0 + assert SPANDATA.GEN_AI_USAGE_INPUT_TOKENS_CACHE_WRITE not in span["data"] + + +def test_litellm_no_cache_tokens(sentry_init, capture_events): + """Test that requests without cache usage don't have cache fields.""" + sentry_init( + integrations=[LiteLLMIntegration(include_prompts=True)], + traces_sample_rate=1.0, + send_default_pii=True, + ) + events = capture_events() + + # Mock response without cache tokens + mock_response = mock.Mock() + mock_response.model = "gpt-3.5-turbo" + mock_response.choices = [ + mock.Mock( + message=mock.Mock( + role="assistant", + content="Response", + ) + ) + ] + mock_response.usage = mock.Mock( + prompt_tokens=50, + completion_tokens=25, + total_tokens=75, + ) + + with mock.patch("litellm.completion", return_value=mock_response): + with start_transaction(name="litellm"): + litellm.completion( + model="gpt-3.5-turbo", + messages=[{"role": "user", "content": "Hello"}], + ) + + assert len(events) == 1 + (event,) = events + (span,) = event["spans"] + + assert span["data"][SPANDATA.GEN_AI_USAGE_INPUT_TOKENS] == 50 + assert span["data"][SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS] == 25 + # Cache fields should not be present + assert SPANDATA.GEN_AI_USAGE_INPUT_TOKENS_CACHED not in span["data"] + assert SPANDATA.GEN_AI_USAGE_INPUT_TOKENS_CACHE_WRITE not in span["data"] + + +def test_litellm_embeddings_no_cache(sentry_init, capture_events): + """Test that embeddings don't track cache tokens.""" + sentry_init( + integrations=[LiteLLMIntegration(include_prompts=True)], + traces_sample_rate=1.0, + send_default_pii=True, + ) + events = capture_events() + + # Mock embeddings response + mock_response = mock.Mock() + mock_response.model = "text-embedding-ada-002" + mock_response.data = [mock.Mock(embedding=[0.1, 0.2, 0.3])] + mock_response.usage = mock.Mock( + prompt_tokens=10, + total_tokens=10, + ) + + with mock.patch("litellm.embedding", return_value=mock_response): + with start_transaction(name="litellm"): + litellm.embedding( + model="text-embedding-ada-002", + input="Test text", + ) + + assert len(events) == 1 + (event,) = events + (span,) = event["spans"] + + assert span["op"] == OP.GEN_AI_EMBEDDINGS + assert span["data"][SPANDATA.GEN_AI_USAGE_INPUT_TOKENS] == 10 + # Embeddings don't have cache tokens + assert SPANDATA.GEN_AI_USAGE_INPUT_TOKENS_CACHED not in 
span["data"] + assert SPANDATA.GEN_AI_USAGE_INPUT_TOKENS_CACHE_WRITE not in span["data"] diff --git a/tests/integrations/openai/test_openai.py b/tests/integrations/openai/test_openai.py index 814289c887..1e88b21e86 100644 --- a/tests/integrations/openai/test_openai.py +++ b/tests/integrations/openai/test_openai.py @@ -39,7 +39,7 @@ SKIP_RESPONSES_TESTS = True from sentry_sdk import start_transaction -from sentry_sdk.consts import SPANDATA +from sentry_sdk.consts import OP, SPANDATA from sentry_sdk.integrations.openai import ( OpenAIIntegration, _calculate_token_usage, @@ -1559,3 +1559,293 @@ def test_openai_message_truncation(sentry_init, capture_events): if SPANDATA.GEN_AI_REQUEST_MESSAGES in span_meta: messages_meta = span_meta[SPANDATA.GEN_AI_REQUEST_MESSAGES] assert "len" in messages_meta.get("", {}) + + +def test_openai_cache_tokens_prompt_tokens_details(sentry_init, capture_events): + """Test that cached tokens are tracked from prompt_tokens_details (OpenAI standard).""" + sentry_init( + integrations=[OpenAIIntegration(include_prompts=True)], + traces_sample_rate=1.0, + send_default_pii=True, + ) + events = capture_events() + + # Mock OpenAI client + import openai + + client = openai.OpenAI(api_key="test-key") + + # Create a response with cached tokens in prompt_tokens_details + mock_response = ChatCompletion( + id="chatcmpl-test", + model="gpt-4o", + object="chat.completion", + created=1234567890, + choices=[ + Choice( + finish_reason="stop", + index=0, + message=ChatCompletionMessage( + role="assistant", + content="Hello! How can I help you?", + ), + ) + ], + usage=CompletionUsage( + completion_tokens=10, + prompt_tokens=100, + total_tokens=110, + ), + ) + + # Add prompt_tokens_details with cached_tokens + mock_response.usage.prompt_tokens_details = {"cached_tokens": 80} + + client.chat.completions._post = mock.Mock(return_value=mock_response) + + with start_transaction(name="openai"): + client.chat.completions.create( + model="gpt-4o", + messages=[{"role": "user", "content": "Hello"}], + ) + + assert len(events) == 1 + (event,) = events + (span,) = event["spans"] + + assert span["op"] == OP.GEN_AI_CHAT + assert span["data"][SPANDATA.GEN_AI_USAGE_INPUT_TOKENS] == 100 + assert span["data"][SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS] == 10 + assert span["data"][SPANDATA.GEN_AI_USAGE_TOTAL_TOKENS] == 110 + # Check cached tokens from prompt_tokens_details + assert span["data"][SPANDATA.GEN_AI_USAGE_INPUT_TOKENS_CACHED] == 80 + # OpenAI doesn't track cache writes, so this should not be present + assert SPANDATA.GEN_AI_USAGE_INPUT_TOKENS_CACHE_WRITE not in span["data"] + + +def test_openai_cache_tokens_input_tokens_details_compat(sentry_init, capture_events): + """Test that cached tokens are tracked from input_tokens_details (compatibility).""" + sentry_init( + integrations=[OpenAIIntegration(include_prompts=True)], + traces_sample_rate=1.0, + send_default_pii=True, + ) + events = capture_events() + + import openai + + client = openai.OpenAI(api_key="test-key") + + # Create a response with cached tokens in input_tokens_details (for compatibility) + mock_response = ChatCompletion( + id="chatcmpl-test", + model="gpt-4o", + object="chat.completion", + created=1234567890, + choices=[ + Choice( + finish_reason="stop", + index=0, + message=ChatCompletionMessage( + role="assistant", + content="Response", + ), + ) + ], + usage=CompletionUsage( + completion_tokens=20, + prompt_tokens=150, + total_tokens=170, + ), + ) + + # Add input_tokens_details with cached_tokens (compatibility path) + 
mock_response.usage.input_tokens_details = {"cached_tokens": 120} + + client.chat.completions._post = mock.Mock(return_value=mock_response) + + with start_transaction(name="openai"): + client.chat.completions.create( + model="gpt-4o", + messages=[{"role": "user", "content": "Hello"}], + ) + + assert len(events) == 1 + (event,) = events + (span,) = event["spans"] + + assert span["data"][SPANDATA.GEN_AI_USAGE_INPUT_TOKENS] == 150 + assert span["data"][SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS] == 20 + # Check cached tokens from input_tokens_details + assert span["data"][SPANDATA.GEN_AI_USAGE_INPUT_TOKENS_CACHED] == 120 + + +def test_openai_reasoning_tokens_compatibility(sentry_init, capture_events): + """Test that reasoning tokens are tracked from both output_tokens_details and completion_tokens_details.""" + sentry_init( + integrations=[OpenAIIntegration(include_prompts=True)], + traces_sample_rate=1.0, + send_default_pii=True, + ) + events = capture_events() + + import openai + + client = openai.OpenAI(api_key="test-key") + + # Create a response with reasoning tokens in completion_tokens_details + mock_response = ChatCompletion( + id="chatcmpl-test", + model="o1-preview", + object="chat.completion", + created=1234567890, + choices=[ + Choice( + finish_reason="stop", + index=0, + message=ChatCompletionMessage( + role="assistant", + content="Response", + ), + ) + ], + usage=CompletionUsage( + completion_tokens=50, + prompt_tokens=100, + total_tokens=150, + ), + ) + + # Add completion_tokens_details with reasoning_tokens (OpenAI standard) + mock_response.usage.completion_tokens_details = {"reasoning_tokens": 30} + + client.chat.completions._post = mock.Mock(return_value=mock_response) + + with start_transaction(name="openai"): + client.chat.completions.create( + model="o1-preview", + messages=[{"role": "user", "content": "Solve this problem"}], + ) + + assert len(events) == 1 + (event,) = events + (span,) = event["spans"] + + # Check reasoning tokens from completion_tokens_details + assert span["data"][SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS_REASONING] == 30 + + +def test_openai_no_cache_tokens(sentry_init, capture_events): + """Test that requests without cache usage don't have cache fields.""" + sentry_init( + integrations=[OpenAIIntegration(include_prompts=True)], + traces_sample_rate=1.0, + send_default_pii=True, + ) + events = capture_events() + + import openai + + client = openai.OpenAI(api_key="test-key") + + # Response without any cache or reasoning token details + mock_response = ChatCompletion( + id="chatcmpl-test", + model="gpt-3.5-turbo", + object="chat.completion", + created=1234567890, + choices=[ + Choice( + finish_reason="stop", + index=0, + message=ChatCompletionMessage( + role="assistant", + content="Response", + ), + ) + ], + usage=CompletionUsage( + completion_tokens=25, + prompt_tokens=50, + total_tokens=75, + ), + ) + + client.chat.completions._post = mock.Mock(return_value=mock_response) + + with start_transaction(name="openai"): + client.chat.completions.create( + model="gpt-3.5-turbo", + messages=[{"role": "user", "content": "Hello"}], + ) + + assert len(events) == 1 + (event,) = events + (span,) = event["spans"] + + assert span["data"][SPANDATA.GEN_AI_USAGE_INPUT_TOKENS] == 50 + assert span["data"][SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS] == 25 + assert span["data"][SPANDATA.GEN_AI_USAGE_TOTAL_TOKENS] == 75 + # Cache and reasoning fields should not be present + assert SPANDATA.GEN_AI_USAGE_INPUT_TOKENS_CACHED not in span["data"] + assert 
SPANDATA.GEN_AI_USAGE_INPUT_TOKENS_CACHE_WRITE not in span["data"] + assert SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS_REASONING not in span["data"] + + +def test_openai_cache_tokens_with_reasoning(sentry_init, capture_events): + """Test that both cached tokens and reasoning tokens can be tracked together.""" + sentry_init( + integrations=[OpenAIIntegration(include_prompts=True)], + traces_sample_rate=1.0, + send_default_pii=True, + ) + events = capture_events() + + import openai + + client = openai.OpenAI(api_key="test-key") + + # Response with both cached tokens and reasoning tokens + mock_response = ChatCompletion( + id="chatcmpl-test", + model="o1-preview", + object="chat.completion", + created=1234567890, + choices=[ + Choice( + finish_reason="stop", + index=0, + message=ChatCompletionMessage( + role="assistant", + content="Detailed reasoning response", + ), + ) + ], + usage=CompletionUsage( + completion_tokens=100, + prompt_tokens=200, + total_tokens=300, + ), + ) + + # Add both prompt and completion token details + mock_response.usage.prompt_tokens_details = {"cached_tokens": 150} + mock_response.usage.completion_tokens_details = {"reasoning_tokens": 75} + + client.chat.completions._post = mock.Mock(return_value=mock_response) + + with start_transaction(name="openai"): + client.chat.completions.create( + model="o1-preview", + messages=[{"role": "user", "content": "Complex problem"}], + ) + + assert len(events) == 1 + (event,) = events + (span,) = event["spans"] + + # Check both cached and reasoning tokens + assert span["data"][SPANDATA.GEN_AI_USAGE_INPUT_TOKENS] == 200 + assert span["data"][SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS] == 100 + assert span["data"][SPANDATA.GEN_AI_USAGE_INPUT_TOKENS_CACHED] == 150 + assert span["data"][SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS_REASONING] == 75 From e2eaaac29e8819c9cdcb825e3fa74f160b95e1a4 Mon Sep 17 00:00:00 2001 From: Simon Hellmayr Date: Thu, 15 Jan 2026 14:57:30 +0100 Subject: [PATCH 2/3] simplify --- sentry_sdk/ai/__init__.py | 10 - sentry_sdk/integrations/openai.py | 11 +- .../integrations/anthropic/test_anthropic.py | 268 ++-------------- tests/integrations/litellm/test_litellm.py | 172 +---------- tests/integrations/openai/test_openai.py | 292 +----------------- 5 files changed, 49 insertions(+), 704 deletions(-) diff --git a/sentry_sdk/ai/__init__.py b/sentry_sdk/ai/__init__.py index 6eeeed3d76..fbcb9c061d 100644 --- a/sentry_sdk/ai/__init__.py +++ b/sentry_sdk/ai/__init__.py @@ -1,4 +1,3 @@ -from .monitoring import record_token_usage # noqa: F401 from .utils import ( set_data_normalized, GEN_AI_MESSAGE_ROLE_MAPPING, @@ -6,12 +5,3 @@ normalize_message_role, normalize_message_roles, ) # noqa: F401 - -__all__ = [ - "record_token_usage", - "set_data_normalized", - "GEN_AI_MESSAGE_ROLE_MAPPING", - "GEN_AI_MESSAGE_ROLE_REVERSE_MAPPING", - "normalize_message_role", - "normalize_message_roles", -] diff --git a/sentry_sdk/integrations/openai.py b/sentry_sdk/integrations/openai.py index 13757d59d2..a2c7cc8d1d 100644 --- a/sentry_sdk/integrations/openai.py +++ b/sentry_sdk/integrations/openai.py @@ -131,12 +131,7 @@ def _calculate_token_usage( if hasattr(response, "usage"): input_tokens = _get_usage(response.usage, ["input_tokens", "prompt_tokens"]) - if hasattr(response.usage, "prompt_tokens_details"): - input_tokens_cached = _get_usage( - response.usage.prompt_tokens_details, ["cached_tokens"] - ) - # OpenAI also supports input_tokens_details for compatibility - elif hasattr(response.usage, "input_tokens_details"): + if hasattr(response.usage, 
"input_tokens_details"): input_tokens_cached = _get_usage( response.usage.input_tokens_details, ["cached_tokens"] ) @@ -148,10 +143,6 @@ def _calculate_token_usage( output_tokens_reasoning = _get_usage( response.usage.output_tokens_details, ["reasoning_tokens"] ) - elif hasattr(response.usage, "completion_tokens_details"): - output_tokens_reasoning = _get_usage( - response.usage.completion_tokens_details, ["reasoning_tokens"] - ) total_tokens = _get_usage(response.usage, ["total_tokens"]) diff --git a/tests/integrations/anthropic/test_anthropic.py b/tests/integrations/anthropic/test_anthropic.py index b3bd0d44b3..fe56c4dc7d 100644 --- a/tests/integrations/anthropic/test_anthropic.py +++ b/tests/integrations/anthropic/test_anthropic.py @@ -850,8 +850,10 @@ def test_collect_ai_data_with_input_json_delta(): output_tokens = 20 content_blocks = [] - model, new_input_tokens, new_output_tokens, new_content_blocks = _collect_ai_data( - event, model, input_tokens, output_tokens, content_blocks + model, new_input_tokens, new_output_tokens, _, _, new_content_blocks = ( + _collect_ai_data( + event, model, input_tokens, output_tokens, 0, 0, content_blocks + ) ) assert model is None @@ -881,6 +883,8 @@ def test_set_output_data_with_input_json_delta(sentry_init): model="", input_tokens=10, output_tokens=20, + cache_read_input_tokens=0, + cache_write_input_tokens=0, content_blocks=[{"text": "".join(json_deltas), "type": "text"}], ) @@ -1449,90 +1453,27 @@ def test_system_prompt_with_complex_structure(sentry_init, capture_events): def test_cache_tokens_nonstreaming(sentry_init, capture_events): - """Test that cache read and write tokens are properly tracked for non-streaming responses.""" - sentry_init( - integrations=[AnthropicIntegration(include_prompts=True)], - traces_sample_rate=1.0, - send_default_pii=True, - ) + """Test cache read/write tokens are tracked for non-streaming responses.""" + sentry_init(integrations=[AnthropicIntegration()], traces_sample_rate=1.0) events = capture_events() client = Anthropic(api_key="z") - # Create a message with cache token usage - message_with_cache = Message( - id="id", - model="claude-3-5-sonnet-20241022", - role="assistant", - content=[TextBlock(type="text", text="Response using cache")], - type="message", - usage=Usage( - input_tokens=100, - output_tokens=50, - cache_read_input_tokens=80, # 80 tokens read from cache - cache_write_input_tokens=20, # 20 tokens written to cache - ), - ) - - client.messages._post = mock.Mock(return_value=message_with_cache) - - messages = [{"role": "user", "content": "Hello"}] - - with start_transaction(name="anthropic"): - response = client.messages.create( - max_tokens=1024, messages=messages, model="claude-3-5-sonnet-20241022" + client.messages._post = mock.Mock( + return_value=Message( + id="id", + model="claude-3-5-sonnet-20241022", + role="assistant", + content=[TextBlock(type="text", text="Response")], + type="message", + usage=Usage( + input_tokens=100, + output_tokens=50, + cache_read_input_tokens=80, + cache_creation_input_tokens=20, + ), ) - - assert response == message_with_cache - usage = response.usage - - assert usage.input_tokens == 100 - assert usage.output_tokens == 50 - assert usage.cache_read_input_tokens == 80 - assert usage.cache_write_input_tokens == 20 - - assert len(events) == 1 - (event,) = events - - assert event["type"] == "transaction" - assert len(event["spans"]) == 1 - (span,) = event["spans"] - - assert span["op"] == OP.GEN_AI_CHAT - assert span["data"][SPANDATA.GEN_AI_USAGE_INPUT_TOKENS] == 100 - 
assert span["data"][SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS] == 50 - assert span["data"][SPANDATA.GEN_AI_USAGE_TOTAL_TOKENS] == 150 - # Check cache-related tokens - assert span["data"][SPANDATA.GEN_AI_USAGE_INPUT_TOKENS_CACHED] == 80 - assert span["data"][SPANDATA.GEN_AI_USAGE_INPUT_TOKENS_CACHE_WRITE] == 20 - - -def test_cache_tokens_only_reads(sentry_init, capture_events): - """Test tracking when only cache reads are present (no writes).""" - sentry_init( - integrations=[AnthropicIntegration(include_prompts=True)], - traces_sample_rate=1.0, - send_default_pii=True, - ) - events = capture_events() - client = Anthropic(api_key="z") - - # Message with only cache reads, no writes - message_cache_read_only = Message( - id="id", - model="claude-3-5-sonnet-20241022", - role="assistant", - content=[TextBlock(type="text", text="Response")], - type="message", - usage=Usage( - input_tokens=100, - output_tokens=50, - cache_read_input_tokens=100, # All tokens read from cache - cache_write_input_tokens=0, # No new cache writes - ), ) - client.messages._post = mock.Mock(return_value=message_cache_read_only) - with start_transaction(name="anthropic"): client.messages.create( max_tokens=1024, @@ -1540,27 +1481,16 @@ def test_cache_tokens_only_reads(sentry_init, capture_events): model="claude-3-5-sonnet-20241022", ) - assert len(events) == 1 - (event,) = events - (span,) = event["spans"] - - assert span["data"][SPANDATA.GEN_AI_USAGE_INPUT_TOKENS_CACHED] == 100 - # Cache write should not be present when it's 0 - assert SPANDATA.GEN_AI_USAGE_INPUT_TOKENS_CACHE_WRITE not in span["data"] + (span,) = events[0]["spans"] + assert span["data"][SPANDATA.GEN_AI_USAGE_INPUT_TOKENS_CACHED] == 80 + assert span["data"][SPANDATA.GEN_AI_USAGE_INPUT_TOKENS_CACHE_WRITE] == 20 def test_cache_tokens_streaming(sentry_init, capture_events): - """Test that cache tokens are tracked correctly for streaming responses.""" - sentry_init( - integrations=[AnthropicIntegration(include_prompts=True)], - traces_sample_rate=1.0, - send_default_pii=True, - ) - events = capture_events() + """Test cache tokens are tracked for streaming responses.""" client = Anthropic(api_key="z") - - # Create streaming events with cache usage - stream_events = [ + returned_stream = Stream(cast_to=None, response=None, client=client) + returned_stream._iterator = [ MessageStartEvent( type="message_start", message=Message( @@ -1573,15 +1503,10 @@ def test_cache_tokens_streaming(sentry_init, capture_events): input_tokens=100, output_tokens=0, cache_read_input_tokens=80, - cache_write_input_tokens=20, + cache_creation_input_tokens=20, ), ), ), - ContentBlockDeltaEvent( - type="content_block_delta", - index=0, - delta=TextDelta(type="text_delta", text="Hello"), - ), MessageDeltaEvent( type="message_delta", delta=Delta(stop_reason="end_turn"), @@ -1589,146 +1514,19 @@ def test_cache_tokens_streaming(sentry_init, capture_events): ), ] - mock_stream = mock.MagicMock(spec=Stream) - mock_stream.__iter__ = mock.Mock(return_value=iter(stream_events)) - mock_stream._iterator = iter(stream_events) - - client.messages._post = mock.Mock(return_value=mock_stream) - - with start_transaction(name="anthropic"): - stream = client.messages.create( - max_tokens=1024, - messages=[{"role": "user", "content": "Hello"}], - model="claude-3-5-sonnet-20241022", - stream=True, - ) - # Consume the stream - for _ in stream: - pass - - assert len(events) == 1 - (event,) = events - (span,) = event["spans"] - - assert span["op"] == OP.GEN_AI_CHAT - assert 
span["data"][SPANDATA.GEN_AI_USAGE_INPUT_TOKENS] == 100 - assert span["data"][SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS] == 10 - assert span["data"][SPANDATA.GEN_AI_USAGE_TOTAL_TOKENS] == 110 - # Check streaming cache tokens - assert span["data"][SPANDATA.GEN_AI_USAGE_INPUT_TOKENS_CACHED] == 80 - assert span["data"][SPANDATA.GEN_AI_USAGE_INPUT_TOKENS_CACHE_WRITE] == 20 - assert span["data"][SPANDATA.GEN_AI_RESPONSE_STREAMING] is True - - -@pytest.mark.asyncio -async def test_cache_tokens_streaming_async(sentry_init, capture_events): - """Test that cache tokens are tracked correctly for async streaming responses.""" - sentry_init( - integrations=[AnthropicIntegration(include_prompts=True)], - traces_sample_rate=1.0, - send_default_pii=True, - ) + sentry_init(integrations=[AnthropicIntegration()], traces_sample_rate=1.0) events = capture_events() - client = AsyncAnthropic(api_key="z") - - async def async_iterator(values): - for value in values: - yield value - - # Create streaming events with cache usage - stream_events = [ - MessageStartEvent( - type="message_start", - message=Message( - id="id", - model="claude-3-5-sonnet-20241022", - role="assistant", - content=[], - type="message", - usage=Usage( - input_tokens=100, - output_tokens=0, - cache_read_input_tokens=80, - cache_write_input_tokens=20, - ), - ), - ), - ContentBlockDeltaEvent( - type="content_block_delta", - index=0, - delta=TextDelta(type="text_delta", text="Hello"), - ), - MessageDeltaEvent( - type="message_delta", - delta=Delta(stop_reason="end_turn"), - usage=MessageDeltaUsage(output_tokens=10), - ), - ] - - mock_stream = mock.MagicMock(spec=AsyncStream) - mock_stream.__aiter__ = mock.Mock(return_value=async_iterator(stream_events)) - mock_stream._iterator = async_iterator(stream_events) - - client.messages._post = mock.Mock(return_value=mock_stream) + client.messages._post = mock.Mock(return_value=returned_stream) with start_transaction(name="anthropic"): - stream = await client.messages.create( + for _ in client.messages.create( max_tokens=1024, messages=[{"role": "user", "content": "Hello"}], model="claude-3-5-sonnet-20241022", stream=True, - ) - # Consume the stream - async for _ in stream: + ): pass - assert len(events) == 1 - (event,) = events - (span,) = event["spans"] - - assert span["op"] == OP.GEN_AI_CHAT - assert span["data"][SPANDATA.GEN_AI_USAGE_INPUT_TOKENS] == 100 - assert span["data"][SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS] == 10 - # Check async streaming cache tokens + (span,) = events[0]["spans"] assert span["data"][SPANDATA.GEN_AI_USAGE_INPUT_TOKENS_CACHED] == 80 assert span["data"][SPANDATA.GEN_AI_USAGE_INPUT_TOKENS_CACHE_WRITE] == 20 - - -def test_no_cache_tokens(sentry_init, capture_events): - """Test that requests without cache usage don't have cache fields.""" - sentry_init( - integrations=[AnthropicIntegration(include_prompts=True)], - traces_sample_rate=1.0, - send_default_pii=True, - ) - events = capture_events() - client = Anthropic(api_key="z") - - # Message without any cache usage - message_no_cache = Message( - id="id", - model="claude-3-5-sonnet-20241022", - role="assistant", - content=[TextBlock(type="text", text="Response")], - type="message", - usage=Usage(input_tokens=100, output_tokens=50), - ) - - client.messages._post = mock.Mock(return_value=message_no_cache) - - with start_transaction(name="anthropic"): - client.messages.create( - max_tokens=1024, - messages=[{"role": "user", "content": "Hello"}], - model="claude-3-5-sonnet-20241022", - ) - - assert len(events) == 1 - (event,) = events - 
(span,) = event["spans"] - - assert span["data"][SPANDATA.GEN_AI_USAGE_INPUT_TOKENS] == 100 - assert span["data"][SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS] == 50 - # Cache fields should not be present - assert SPANDATA.GEN_AI_USAGE_INPUT_TOKENS_CACHED not in span["data"] - assert SPANDATA.GEN_AI_USAGE_INPUT_TOKENS_CACHE_WRITE not in span["data"] diff --git a/tests/integrations/litellm/test_litellm.py b/tests/integrations/litellm/test_litellm.py index 7f4dc0842b..53dae68be1 100644 --- a/tests/integrations/litellm/test_litellm.py +++ b/tests/integrations/litellm/test_litellm.py @@ -755,176 +755,32 @@ def test_litellm_message_truncation(sentry_init, capture_events): assert tx["_meta"]["spans"]["0"]["data"]["gen_ai.request.messages"][""]["len"] == 5 -def test_litellm_cache_tokens_anthropic(sentry_init, capture_events): - """Test that cache tokens are properly tracked when using Anthropic through LiteLLM.""" - sentry_init( - integrations=[LiteLLMIntegration(include_prompts=True)], - traces_sample_rate=1.0, - send_default_pii=True, - ) +def test_litellm_cache_tokens(sentry_init, capture_events): + """Test cache tokens are tracked through LiteLLM.""" + sentry_init(integrations=[LiteLLMIntegration()], traces_sample_rate=1.0) events = capture_events() - # Mock LiteLLM completion response with Anthropic cache tokens mock_response = mock.Mock() mock_response.model = "claude-3-5-sonnet-20241022" mock_response.choices = [ - mock.Mock( - message=mock.Mock( - role="assistant", - content="Response using cache", - ) - ) + mock.Mock(message=mock.Mock(role="assistant", content="Response")) ] mock_response.usage = mock.Mock( prompt_tokens=100, completion_tokens=50, total_tokens=150, - cache_read_input_tokens=80, # Anthropic cache reads - cache_write_input_tokens=20, # Anthropic cache writes + cache_read_input_tokens=80, + cache_write_input_tokens=20, ) - with mock.patch("litellm.completion", return_value=mock_response): - with start_transaction(name="litellm"): - litellm.completion( - model="anthropic/claude-3-5-sonnet-20241022", - messages=[{"role": "user", "content": "Hello"}], - ) - - assert len(events) == 1 - (event,) = events - (span,) = event["spans"] + kwargs = { + "model": "anthropic/claude-3-5-sonnet-20241022", + "messages": [{"role": "user", "content": "Hello"}], + } + with start_transaction(name="litellm"): + _input_callback(kwargs) + _success_callback(kwargs, mock_response, datetime.now(), datetime.now()) - assert span["op"] == OP.GEN_AI_CHAT - assert span["data"][SPANDATA.GEN_AI_USAGE_INPUT_TOKENS] == 100 - assert span["data"][SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS] == 50 - assert span["data"][SPANDATA.GEN_AI_USAGE_TOTAL_TOKENS] == 150 - # Check cache tokens from Anthropic provider + (span,) = events[0]["spans"] assert span["data"][SPANDATA.GEN_AI_USAGE_INPUT_TOKENS_CACHED] == 80 assert span["data"][SPANDATA.GEN_AI_USAGE_INPUT_TOKENS_CACHE_WRITE] == 20 - - -def test_litellm_cache_tokens_openai(sentry_init, capture_events): - """Test that cached tokens from OpenAI through LiteLLM are tracked.""" - sentry_init( - integrations=[LiteLLMIntegration(include_prompts=True)], - traces_sample_rate=1.0, - send_default_pii=True, - ) - events = capture_events() - - # Mock LiteLLM completion response with OpenAI cache tokens - mock_response = mock.Mock() - mock_response.model = "gpt-4o" - mock_response.choices = [ - mock.Mock( - message=mock.Mock( - role="assistant", - content="Response", - ) - ) - ] - mock_response.usage = mock.Mock( - prompt_tokens=200, - completion_tokens=75, - total_tokens=275, - 
cache_read_input_tokens=150, # OpenAI cache reads - cache_write_input_tokens=0, # OpenAI doesn't charge for cache writes - ) - - with mock.patch("litellm.completion", return_value=mock_response): - with start_transaction(name="litellm"): - litellm.completion( - model="gpt-4o", - messages=[{"role": "user", "content": "Complex prompt"}], - ) - - assert len(events) == 1 - (event,) = events - (span,) = event["spans"] - - assert span["data"][SPANDATA.GEN_AI_USAGE_INPUT_TOKENS] == 200 - assert span["data"][SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS] == 75 - # Check cached tokens - assert span["data"][SPANDATA.GEN_AI_USAGE_INPUT_TOKENS_CACHED] == 150 - # Cache write should not be present when it's 0 - assert SPANDATA.GEN_AI_USAGE_INPUT_TOKENS_CACHE_WRITE not in span["data"] - - -def test_litellm_no_cache_tokens(sentry_init, capture_events): - """Test that requests without cache usage don't have cache fields.""" - sentry_init( - integrations=[LiteLLMIntegration(include_prompts=True)], - traces_sample_rate=1.0, - send_default_pii=True, - ) - events = capture_events() - - # Mock response without cache tokens - mock_response = mock.Mock() - mock_response.model = "gpt-3.5-turbo" - mock_response.choices = [ - mock.Mock( - message=mock.Mock( - role="assistant", - content="Response", - ) - ) - ] - mock_response.usage = mock.Mock( - prompt_tokens=50, - completion_tokens=25, - total_tokens=75, - ) - - with mock.patch("litellm.completion", return_value=mock_response): - with start_transaction(name="litellm"): - litellm.completion( - model="gpt-3.5-turbo", - messages=[{"role": "user", "content": "Hello"}], - ) - - assert len(events) == 1 - (event,) = events - (span,) = event["spans"] - - assert span["data"][SPANDATA.GEN_AI_USAGE_INPUT_TOKENS] == 50 - assert span["data"][SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS] == 25 - # Cache fields should not be present - assert SPANDATA.GEN_AI_USAGE_INPUT_TOKENS_CACHED not in span["data"] - assert SPANDATA.GEN_AI_USAGE_INPUT_TOKENS_CACHE_WRITE not in span["data"] - - -def test_litellm_embeddings_no_cache(sentry_init, capture_events): - """Test that embeddings don't track cache tokens.""" - sentry_init( - integrations=[LiteLLMIntegration(include_prompts=True)], - traces_sample_rate=1.0, - send_default_pii=True, - ) - events = capture_events() - - # Mock embeddings response - mock_response = mock.Mock() - mock_response.model = "text-embedding-ada-002" - mock_response.data = [mock.Mock(embedding=[0.1, 0.2, 0.3])] - mock_response.usage = mock.Mock( - prompt_tokens=10, - total_tokens=10, - ) - - with mock.patch("litellm.embedding", return_value=mock_response): - with start_transaction(name="litellm"): - litellm.embedding( - model="text-embedding-ada-002", - input="Test text", - ) - - assert len(events) == 1 - (event,) = events - (span,) = event["spans"] - - assert span["op"] == OP.GEN_AI_EMBEDDINGS - assert span["data"][SPANDATA.GEN_AI_USAGE_INPUT_TOKENS] == 10 - # Embeddings don't have cache tokens - assert SPANDATA.GEN_AI_USAGE_INPUT_TOKENS_CACHED not in span["data"] - assert SPANDATA.GEN_AI_USAGE_INPUT_TOKENS_CACHE_WRITE not in span["data"] diff --git a/tests/integrations/openai/test_openai.py b/tests/integrations/openai/test_openai.py index 1e88b21e86..814289c887 100644 --- a/tests/integrations/openai/test_openai.py +++ b/tests/integrations/openai/test_openai.py @@ -39,7 +39,7 @@ SKIP_RESPONSES_TESTS = True from sentry_sdk import start_transaction -from sentry_sdk.consts import OP, SPANDATA +from sentry_sdk.consts import SPANDATA from sentry_sdk.integrations.openai import ( 
OpenAIIntegration, _calculate_token_usage, @@ -1559,293 +1559,3 @@ def test_openai_message_truncation(sentry_init, capture_events): if SPANDATA.GEN_AI_REQUEST_MESSAGES in span_meta: messages_meta = span_meta[SPANDATA.GEN_AI_REQUEST_MESSAGES] assert "len" in messages_meta.get("", {}) - - -def test_openai_cache_tokens_prompt_tokens_details(sentry_init, capture_events): - """Test that cached tokens are tracked from prompt_tokens_details (OpenAI standard).""" - sentry_init( - integrations=[OpenAIIntegration(include_prompts=True)], - traces_sample_rate=1.0, - send_default_pii=True, - ) - events = capture_events() - - # Mock OpenAI client - import openai - - client = openai.OpenAI(api_key="test-key") - - # Create a response with cached tokens in prompt_tokens_details - mock_response = ChatCompletion( - id="chatcmpl-test", - model="gpt-4o", - object="chat.completion", - created=1234567890, - choices=[ - Choice( - finish_reason="stop", - index=0, - message=ChatCompletionMessage( - role="assistant", - content="Hello! How can I help you?", - ), - ) - ], - usage=CompletionUsage( - completion_tokens=10, - prompt_tokens=100, - total_tokens=110, - ), - ) - - # Add prompt_tokens_details with cached_tokens - mock_response.usage.prompt_tokens_details = {"cached_tokens": 80} - - client.chat.completions._post = mock.Mock(return_value=mock_response) - - with start_transaction(name="openai"): - client.chat.completions.create( - model="gpt-4o", - messages=[{"role": "user", "content": "Hello"}], - ) - - assert len(events) == 1 - (event,) = events - (span,) = event["spans"] - - assert span["op"] == OP.GEN_AI_CHAT - assert span["data"][SPANDATA.GEN_AI_USAGE_INPUT_TOKENS] == 100 - assert span["data"][SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS] == 10 - assert span["data"][SPANDATA.GEN_AI_USAGE_TOTAL_TOKENS] == 110 - # Check cached tokens from prompt_tokens_details - assert span["data"][SPANDATA.GEN_AI_USAGE_INPUT_TOKENS_CACHED] == 80 - # OpenAI doesn't track cache writes, so this should not be present - assert SPANDATA.GEN_AI_USAGE_INPUT_TOKENS_CACHE_WRITE not in span["data"] - - -def test_openai_cache_tokens_input_tokens_details_compat(sentry_init, capture_events): - """Test that cached tokens are tracked from input_tokens_details (compatibility).""" - sentry_init( - integrations=[OpenAIIntegration(include_prompts=True)], - traces_sample_rate=1.0, - send_default_pii=True, - ) - events = capture_events() - - import openai - - client = openai.OpenAI(api_key="test-key") - - # Create a response with cached tokens in input_tokens_details (for compatibility) - mock_response = ChatCompletion( - id="chatcmpl-test", - model="gpt-4o", - object="chat.completion", - created=1234567890, - choices=[ - Choice( - finish_reason="stop", - index=0, - message=ChatCompletionMessage( - role="assistant", - content="Response", - ), - ) - ], - usage=CompletionUsage( - completion_tokens=20, - prompt_tokens=150, - total_tokens=170, - ), - ) - - # Add input_tokens_details with cached_tokens (compatibility path) - mock_response.usage.input_tokens_details = {"cached_tokens": 120} - - client.chat.completions._post = mock.Mock(return_value=mock_response) - - with start_transaction(name="openai"): - client.chat.completions.create( - model="gpt-4o", - messages=[{"role": "user", "content": "Hello"}], - ) - - assert len(events) == 1 - (event,) = events - (span,) = event["spans"] - - assert span["data"][SPANDATA.GEN_AI_USAGE_INPUT_TOKENS] == 150 - assert span["data"][SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS] == 20 - # Check cached tokens from input_tokens_details 
- assert span["data"][SPANDATA.GEN_AI_USAGE_INPUT_TOKENS_CACHED] == 120 - - -def test_openai_reasoning_tokens_compatibility(sentry_init, capture_events): - """Test that reasoning tokens are tracked from both output_tokens_details and completion_tokens_details.""" - sentry_init( - integrations=[OpenAIIntegration(include_prompts=True)], - traces_sample_rate=1.0, - send_default_pii=True, - ) - events = capture_events() - - import openai - - client = openai.OpenAI(api_key="test-key") - - # Create a response with reasoning tokens in completion_tokens_details - mock_response = ChatCompletion( - id="chatcmpl-test", - model="o1-preview", - object="chat.completion", - created=1234567890, - choices=[ - Choice( - finish_reason="stop", - index=0, - message=ChatCompletionMessage( - role="assistant", - content="Response", - ), - ) - ], - usage=CompletionUsage( - completion_tokens=50, - prompt_tokens=100, - total_tokens=150, - ), - ) - - # Add completion_tokens_details with reasoning_tokens (OpenAI standard) - mock_response.usage.completion_tokens_details = {"reasoning_tokens": 30} - - client.chat.completions._post = mock.Mock(return_value=mock_response) - - with start_transaction(name="openai"): - client.chat.completions.create( - model="o1-preview", - messages=[{"role": "user", "content": "Solve this problem"}], - ) - - assert len(events) == 1 - (event,) = events - (span,) = event["spans"] - - # Check reasoning tokens from completion_tokens_details - assert span["data"][SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS_REASONING] == 30 - - -def test_openai_no_cache_tokens(sentry_init, capture_events): - """Test that requests without cache usage don't have cache fields.""" - sentry_init( - integrations=[OpenAIIntegration(include_prompts=True)], - traces_sample_rate=1.0, - send_default_pii=True, - ) - events = capture_events() - - import openai - - client = openai.OpenAI(api_key="test-key") - - # Response without any cache or reasoning token details - mock_response = ChatCompletion( - id="chatcmpl-test", - model="gpt-3.5-turbo", - object="chat.completion", - created=1234567890, - choices=[ - Choice( - finish_reason="stop", - index=0, - message=ChatCompletionMessage( - role="assistant", - content="Response", - ), - ) - ], - usage=CompletionUsage( - completion_tokens=25, - prompt_tokens=50, - total_tokens=75, - ), - ) - - client.chat.completions._post = mock.Mock(return_value=mock_response) - - with start_transaction(name="openai"): - client.chat.completions.create( - model="gpt-3.5-turbo", - messages=[{"role": "user", "content": "Hello"}], - ) - - assert len(events) == 1 - (event,) = events - (span,) = event["spans"] - - assert span["data"][SPANDATA.GEN_AI_USAGE_INPUT_TOKENS] == 50 - assert span["data"][SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS] == 25 - assert span["data"][SPANDATA.GEN_AI_USAGE_TOTAL_TOKENS] == 75 - # Cache and reasoning fields should not be present - assert SPANDATA.GEN_AI_USAGE_INPUT_TOKENS_CACHED not in span["data"] - assert SPANDATA.GEN_AI_USAGE_INPUT_TOKENS_CACHE_WRITE not in span["data"] - assert SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS_REASONING not in span["data"] - - -def test_openai_cache_tokens_with_reasoning(sentry_init, capture_events): - """Test that both cached tokens and reasoning tokens can be tracked together.""" - sentry_init( - integrations=[OpenAIIntegration(include_prompts=True)], - traces_sample_rate=1.0, - send_default_pii=True, - ) - events = capture_events() - - import openai - - client = openai.OpenAI(api_key="test-key") - - # Response with both cached tokens and reasoning tokens - 
mock_response = ChatCompletion( - id="chatcmpl-test", - model="o1-preview", - object="chat.completion", - created=1234567890, - choices=[ - Choice( - finish_reason="stop", - index=0, - message=ChatCompletionMessage( - role="assistant", - content="Detailed reasoning response", - ), - ) - ], - usage=CompletionUsage( - completion_tokens=100, - prompt_tokens=200, - total_tokens=300, - ), - ) - - # Add both prompt and completion token details - mock_response.usage.prompt_tokens_details = {"cached_tokens": 150} - mock_response.usage.completion_tokens_details = {"reasoning_tokens": 75} - - client.chat.completions._post = mock.Mock(return_value=mock_response) - - with start_transaction(name="openai"): - client.chat.completions.create( - model="o1-preview", - messages=[{"role": "user", "content": "Complex problem"}], - ) - - assert len(events) == 1 - (event,) = events - (span,) = event["spans"] - - # Check both cached and reasoning tokens - assert span["data"][SPANDATA.GEN_AI_USAGE_INPUT_TOKENS] == 200 - assert span["data"][SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS] == 100 - assert span["data"][SPANDATA.GEN_AI_USAGE_INPUT_TOKENS_CACHED] == 150 - assert span["data"][SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS_REASONING] == 75 From 7a2e2b5076165e7c689feb42da2d5dd20666b485 Mon Sep 17 00:00:00 2001 From: Simon Hellmayr Date: Thu, 15 Jan 2026 16:02:32 +0100 Subject: [PATCH 3/3] changes to litellm --- sentry_sdk/integrations/litellm.py | 308 +++++++++++++++-------------- 1 file changed, 155 insertions(+), 153 deletions(-) diff --git a/sentry_sdk/integrations/litellm.py b/sentry_sdk/integrations/litellm.py index 0d34e2e949..19e1f76736 100644 --- a/sentry_sdk/integrations/litellm.py +++ b/sentry_sdk/integrations/litellm.py @@ -1,3 +1,4 @@ +from functools import wraps from typing import TYPE_CHECKING import sentry_sdk @@ -11,7 +12,7 @@ from sentry_sdk.consts import SPANDATA from sentry_sdk.integrations import DidNotEnable, Integration from sentry_sdk.scope import should_send_default_pii -from sentry_sdk.utils import event_from_exception +from sentry_sdk.utils import capture_internal_exceptions, event_from_exception if TYPE_CHECKING: from typing import Any, Dict @@ -19,71 +20,32 @@ try: import litellm # type: ignore[import-not-found] - from litellm import input_callback, success_callback, failure_callback except ImportError: raise DidNotEnable("LiteLLM not installed") -def _get_metadata_dict(kwargs: "Dict[str, Any]") -> "Dict[str, Any]": - """Get the metadata dictionary from the kwargs.""" - litellm_params = kwargs.setdefault("litellm_params", {}) - - # we need this weird little dance, as metadata might be set but may be None initially - metadata = litellm_params.get("metadata") - if metadata is None: - metadata = {} - litellm_params["metadata"] = metadata - return metadata - - -def _input_callback(kwargs: "Dict[str, Any]") -> None: - """Handle the start of a request.""" - integration = sentry_sdk.get_client().get_integration(LiteLLMIntegration) - - if integration is None: - return - - # Get key parameters - full_model = kwargs.get("model", "") +def _get_provider_and_model(full_model: str) -> "tuple[str, str]": + """Extract provider and model name from full model string.""" try: model, provider, _, _ = litellm.get_llm_provider(full_model) + return provider, model except Exception: - model = full_model - provider = "unknown" - - call_type = kwargs.get("call_type", None) - if call_type == "embedding": - operation = "embeddings" - else: - operation = "chat" - - # Start a new span/transaction - span = 
get_start_span_function()( - op=( - consts.OP.GEN_AI_CHAT - if operation == "chat" - else consts.OP.GEN_AI_EMBEDDINGS - ), - name=f"{operation} {model}", - origin=LiteLLMIntegration.origin, - ) - span.__enter__() + return "unknown", full_model - # Store span for later - _get_metadata_dict(kwargs)["_sentry_span"] = span - - # Set basic data - set_data_normalized(span, SPANDATA.GEN_AI_SYSTEM, provider) - set_data_normalized(span, SPANDATA.GEN_AI_OPERATION_NAME, operation) +def _set_input_data( + span: "Any", + kwargs: "Dict[str, Any]", + operation: str, + integration: "LiteLLMIntegration", +) -> None: + """Set input data on the span.""" # Record input/messages if allowed if should_send_default_pii() and integration.include_prompts: if operation == "embeddings": - # For embeddings, look for the 'input' parameter embedding_input = kwargs.get("input") if embedding_input: scope = sentry_sdk.get_current_scope() - # Normalize to list format input_list = ( embedding_input if isinstance(embedding_input, list) @@ -98,7 +60,6 @@ def _input_callback(kwargs: "Dict[str, Any]") -> None: unpack=False, ) else: - # For chat, look for the 'messages' parameter messages = kwargs.get("messages", []) if messages: scope = sentry_sdk.get_current_scope() @@ -111,7 +72,7 @@ def _input_callback(kwargs: "Dict[str, Any]") -> None: unpack=False, ) - # Record other parameters + # Record standard parameters params = { "model": SPANDATA.GEN_AI_REQUEST_MODEL, "stream": SPANDATA.GEN_AI_RESPONSE_STREAMING, @@ -126,107 +87,157 @@ def _input_callback(kwargs: "Dict[str, Any]") -> None: if value is not None: set_data_normalized(span, attribute, value) - # Record LiteLLM-specific parameters - litellm_params = { - "api_base": kwargs.get("api_base"), - "api_version": kwargs.get("api_version"), - "custom_llm_provider": kwargs.get("custom_llm_provider"), - } - for key, value in litellm_params.items(): - if value is not None: - set_data_normalized(span, f"gen_ai.litellm.{key}", value) - -def _success_callback( - kwargs: "Dict[str, Any]", - completion_response: "Any", - start_time: "datetime", - end_time: "datetime", +def _set_output_data( + span: "Any", + response: "Any", + integration: "LiteLLMIntegration", ) -> None: - """Handle successful completion.""" + """Set output data on the span.""" + # Record model information + if hasattr(response, "model"): + set_data_normalized(span, SPANDATA.GEN_AI_RESPONSE_MODEL, response.model) + + # Record response content if allowed + if should_send_default_pii() and integration.include_prompts: + if hasattr(response, "choices"): + response_messages = [] + for choice in response.choices: + if hasattr(choice, "message"): + if hasattr(choice.message, "model_dump"): + response_messages.append(choice.message.model_dump()) + elif hasattr(choice.message, "dict"): + response_messages.append(choice.message.dict()) + else: + msg = {} + if hasattr(choice.message, "role"): + msg["role"] = choice.message.role + if hasattr(choice.message, "content"): + msg["content"] = choice.message.content + if hasattr(choice.message, "tool_calls"): + msg["tool_calls"] = choice.message.tool_calls + response_messages.append(msg) + + if response_messages: + set_data_normalized( + span, SPANDATA.GEN_AI_RESPONSE_TEXT, response_messages + ) - span = _get_metadata_dict(kwargs).get("_sentry_span") - if span is None: - return + # Record token usage + if hasattr(response, "usage"): + usage = response.usage + + # Extract cached tokens from prompt_tokens_details (OpenAI format used by LiteLLM) + cached_tokens = None + prompt_tokens_details 
= getattr(usage, "prompt_tokens_details", None) + if prompt_tokens_details is not None: + cached_tokens = getattr(prompt_tokens_details, "cached_tokens", None) + + # Extract cache write tokens (Anthropic only) + cache_creation_tokens = getattr(usage, "cache_creation_input_tokens", None) + + record_token_usage( + span, + input_tokens=getattr(usage, "prompt_tokens", None), + input_tokens_cached=cached_tokens, + input_tokens_cache_write=cache_creation_tokens, + output_tokens=getattr(usage, "completion_tokens", None), + total_tokens=getattr(usage, "total_tokens", None), + ) - integration = sentry_sdk.get_client().get_integration(LiteLLMIntegration) - if integration is None: - return - try: - # Record model information - if hasattr(completion_response, "model"): - set_data_normalized( - span, SPANDATA.GEN_AI_RESPONSE_MODEL, completion_response.model - ) +def _wrap_completion(original_func: "Any") -> "Any": + """Wrap litellm.completion to add instrumentation.""" - # Record response content if allowed - if should_send_default_pii() and integration.include_prompts: - if hasattr(completion_response, "choices"): - response_messages = [] - for choice in completion_response.choices: - if hasattr(choice, "message"): - if hasattr(choice.message, "model_dump"): - response_messages.append(choice.message.model_dump()) - elif hasattr(choice.message, "dict"): - response_messages.append(choice.message.dict()) - else: - # Fallback for basic message objects - msg = {} - if hasattr(choice.message, "role"): - msg["role"] = choice.message.role - if hasattr(choice.message, "content"): - msg["content"] = choice.message.content - if hasattr(choice.message, "tool_calls"): - msg["tool_calls"] = choice.message.tool_calls - response_messages.append(msg) - - if response_messages: - set_data_normalized( - span, SPANDATA.GEN_AI_RESPONSE_TEXT, response_messages - ) + @wraps(original_func) + def wrapper(*args: "Any", **kwargs: "Any") -> "Any": + integration = sentry_sdk.get_client().get_integration(LiteLLMIntegration) + if integration is None: + return original_func(*args, **kwargs) - # Record token usage - if hasattr(completion_response, "usage"): - usage = completion_response.usage - record_token_usage( - span, - input_tokens=getattr(usage, "prompt_tokens", None), - input_tokens_cached=getattr(usage, "cache_read_input_tokens", None), - input_tokens_cache_write=getattr( - usage, "cache_write_input_tokens", None - ), - output_tokens=getattr(usage, "completion_tokens", None), - total_tokens=getattr(usage, "total_tokens", None), + # Get model and provider + full_model = kwargs.get("model", args[0] if args else "") + provider, model = _get_provider_and_model(full_model) + + # Create span + span = get_start_span_function()( + op=consts.OP.GEN_AI_CHAT, + name=f"chat {model}", + origin=LiteLLMIntegration.origin, + ) + span.__enter__() + + # Set basic data + set_data_normalized(span, SPANDATA.GEN_AI_SYSTEM, provider) + set_data_normalized(span, SPANDATA.GEN_AI_OPERATION_NAME, "chat") + + with capture_internal_exceptions(): + _set_input_data(span, kwargs, "chat", integration) + + try: + response = original_func(*args, **kwargs) + with capture_internal_exceptions(): + _set_output_data(span, response, integration) + return response + except Exception as exc: + event, hint = event_from_exception( + exc, + client_options=sentry_sdk.get_client().options, + mechanism={"type": "litellm", "handled": False}, ) + sentry_sdk.capture_event(event, hint=hint) + raise + finally: + span.__exit__(None, None, None) - finally: - # Always finish the 
span and clean up - span.__exit__(None, None, None) + return wrapper -def _failure_callback( - kwargs: "Dict[str, Any]", - exception: Exception, - start_time: "datetime", - end_time: "datetime", -) -> None: - """Handle request failure.""" - span = _get_metadata_dict(kwargs).get("_sentry_span") - if span is None: - return +def _wrap_acompletion(original_func: "Any") -> "Any": + """Wrap litellm.acompletion to add instrumentation.""" - try: - # Capture the exception - event, hint = event_from_exception( - exception, - client_options=sentry_sdk.get_client().options, - mechanism={"type": "litellm", "handled": False}, + @wraps(original_func) + async def wrapper(*args: "Any", **kwargs: "Any") -> "Any": + integration = sentry_sdk.get_client().get_integration(LiteLLMIntegration) + if integration is None: + return await original_func(*args, **kwargs) + + # Get model and provider + full_model = kwargs.get("model", args[0] if args else "") + provider, model = _get_provider_and_model(full_model) + + # Create span + span = get_start_span_function()( + op=consts.OP.GEN_AI_CHAT, + name=f"chat {model}", + origin=LiteLLMIntegration.origin, ) - sentry_sdk.capture_event(event, hint=hint) - finally: - # Always finish the span and clean up - span.__exit__(type(exception), exception, None) + span.__enter__() + + # Set basic data + set_data_normalized(span, SPANDATA.GEN_AI_SYSTEM, provider) + set_data_normalized(span, SPANDATA.GEN_AI_OPERATION_NAME, "chat") + + with capture_internal_exceptions(): + _set_input_data(span, kwargs, "chat", integration) + + try: + response = await original_func(*args, **kwargs) + with capture_internal_exceptions(): + _set_output_data(span, response, integration) + return response + except Exception as exc: + event, hint = event_from_exception( + exc, + client_options=sentry_sdk.get_client().options, + mechanism={"type": "litellm", "handled": False}, + ) + sentry_sdk.capture_event(event, hint=hint) + raise + finally: + span.__exit__(None, None, None) + + return wrapper class LiteLLMIntegration(Integration): @@ -282,15 +293,6 @@ def __init__(self: "LiteLLMIntegration", include_prompts: bool = True) -> None: @staticmethod def setup_once() -> None: - """Set up LiteLLM callbacks for monitoring.""" - litellm.input_callback = input_callback or [] - if _input_callback not in litellm.input_callback: - litellm.input_callback.append(_input_callback) - - litellm.success_callback = success_callback or [] - if _success_callback not in litellm.success_callback: - litellm.success_callback.append(_success_callback) - - litellm.failure_callback = failure_callback or [] - if _failure_callback not in litellm.failure_callback: - litellm.failure_callback.append(_failure_callback) + """Set up LiteLLM instrumentation by wrapping completion functions.""" + litellm.completion = _wrap_completion(litellm.completion) + litellm.acompletion = _wrap_acompletion(litellm.acompletion)
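
Usage sketch (not part of the patch): a minimal, hedged example of exercising the wrapper-based LiteLLM instrumentation added above. It assumes litellm's `mock_response` test shortcut and an empty DSN; the attribute mapping in the trailing comment follows _set_output_data() in this diff and only materializes when the provider's usage payload actually carries the corresponding fields.

# Illustrative sketch only -- assumes sentry-sdk with this patch plus litellm installed.
import sentry_sdk
import litellm
from sentry_sdk.integrations.litellm import LiteLLMIntegration

sentry_sdk.init(
    dsn="",                       # empty DSN: nothing is sent anywhere
    traces_sample_rate=1.0,
    send_default_pii=True,
    integrations=[LiteLLMIntegration(include_prompts=True)],
)

with sentry_sdk.start_transaction(name="litellm-demo"):
    # mock_response is litellm's built-in test helper; it skips the real
    # provider call, so no API key should be needed (assumption about litellm behavior).
    litellm.completion(
        model="gpt-4o",
        messages=[{"role": "user", "content": "Hello"}],
        mock_response="Hi there!",
    )

# With a real provider response, the wrapped call records token usage on the span:
#   SPANDATA.GEN_AI_USAGE_INPUT_TOKENS             <- usage.prompt_tokens
#   SPANDATA.GEN_AI_USAGE_INPUT_TOKENS_CACHED      <- usage.prompt_tokens_details.cached_tokens
#   SPANDATA.GEN_AI_USAGE_INPUT_TOKENS_CACHE_WRITE <- usage.cache_creation_input_tokens (Anthropic)
#   SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS            <- usage.completion_tokens
#   SPANDATA.GEN_AI_USAGE_TOTAL_TOKENS             <- usage.total_tokens

Because the mocked response does not include cache details, the cached/cache-write attributes only appear when the underlying provider reports them; this sketch is illustrative and does not replace tests/integrations/litellm/test_litellm.py.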