7 changes: 7 additions & 0 deletions sentry_sdk/ai/monitoring.py
@@ -95,6 +95,7 @@ def record_token_usage(
span: "Span",
input_tokens: "Optional[int]" = None,
input_tokens_cached: "Optional[int]" = None,
input_tokens_cache_write: "Optional[int]" = None,
output_tokens: "Optional[int]" = None,
output_tokens_reasoning: "Optional[int]" = None,
total_tokens: "Optional[int]" = None,
@@ -113,6 +114,12 @@ def record_token_usage(
input_tokens_cached,
)

if input_tokens_cache_write is not None:
span.set_data(
SPANDATA.GEN_AI_USAGE_INPUT_TOKENS_CACHE_WRITE,
input_tokens_cache_write,
)

if output_tokens is not None:
span.set_data(SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS, output_tokens)

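For reference, a minimal sketch (not part of this diff) of how the new `input_tokens_cache_write` parameter could be exercised; the span `op`/`name` values and the token counts are illustrative assumptions:

```python
import sentry_sdk
from sentry_sdk.ai.monitoring import record_token_usage

sentry_sdk.init(traces_sample_rate=1.0)  # DSN omitted; illustrative setup

# Hypothetical call site: op/name and token counts are made up for this example.
with sentry_sdk.start_span(op="gen_ai.chat", name="anthropic.messages.create") as span:
    record_token_usage(
        span,
        input_tokens=1200,
        input_tokens_cached=800,       # tokens read back from the prompt cache
        input_tokens_cache_write=400,  # new: tokens written to the prompt cache
        output_tokens=150,
    )
```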
6 changes: 6 additions & 0 deletions sentry_sdk/consts.py
@@ -632,6 +632,12 @@ class SPANDATA:
Example: 50
"""

GEN_AI_USAGE_INPUT_TOKENS_CACHE_WRITE = "gen_ai.usage.input_tokens.cache_write"
"""
The number of tokens written to the cache when processing the AI input (prompt).
Example: 100
"""

GEN_AI_USAGE_OUTPUT_TOKENS = "gen_ai.usage.output_tokens"
"""
The number of tokens in the output.
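As a quick illustration of what the new attribute looks like on a span: only the `gen_ai.usage.input_tokens.cache_write` key and its example value (100) come from this diff; the other key names below are assumptions based on the neighbouring constants.

```python
# Rough, assumed shape of the span data after record_token_usage() runs.
span_data = {
    "gen_ai.usage.input_tokens": 1500,
    "gen_ai.usage.input_tokens.cache_write": 100,  # new attribute added by this PR
    "gen_ai.usage.output_tokens": 200,
}
```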
102 changes: 88 additions & 14 deletions sentry_sdk/integrations/anthropic.py
@@ -72,29 +72,47 @@ def _capture_exception(exc: "Any") -> None:
sentry_sdk.capture_event(event, hint=hint)


def _get_token_usage(result: "Messages") -> "tuple[int, int]":
def _get_token_usage(result: "Messages") -> "tuple[int, int, int, int]":
"""
Get token usage from the Anthropic response.
Returns: (input_tokens, output_tokens, cache_read_input_tokens, cache_write_input_tokens)
"""
input_tokens = 0
output_tokens = 0
cache_read_input_tokens = 0
cache_write_input_tokens = 0
if hasattr(result, "usage"):
usage = result.usage
if hasattr(usage, "input_tokens") and isinstance(usage.input_tokens, int):
input_tokens = usage.input_tokens
if hasattr(usage, "output_tokens") and isinstance(usage.output_tokens, int):
output_tokens = usage.output_tokens

return input_tokens, output_tokens
if hasattr(usage, "cache_read_input_tokens") and isinstance(
usage.cache_read_input_tokens, int
):
cache_read_input_tokens = usage.cache_read_input_tokens
if hasattr(usage, "cache_creation_input_tokens") and isinstance(
usage.cache_creation_input_tokens, int
):
cache_write_input_tokens = usage.cache_creation_input_tokens

return (
input_tokens,
output_tokens,
cache_read_input_tokens,
cache_write_input_tokens,
)


def _collect_ai_data(
event: "MessageStreamEvent",
model: "str | None",
input_tokens: int,
output_tokens: int,
cache_read_input_tokens: int,
cache_write_input_tokens: int,
content_blocks: "list[str]",
) -> "tuple[str | None, int, int, list[str]]":
) -> "tuple[str | None, int, int, int, int, list[str]]":
"""
    Collect model information, token usage, and content blocks from the AI streaming response.
"""
@@ -104,6 +122,14 @@ def _collect_ai_data(
usage = event.message.usage
input_tokens += usage.input_tokens
output_tokens += usage.output_tokens
if hasattr(usage, "cache_read_input_tokens") and isinstance(
usage.cache_read_input_tokens, int
):
cache_read_input_tokens += usage.cache_read_input_tokens
if hasattr(usage, "cache_creation_input_tokens") and isinstance(
usage.cache_creation_input_tokens, int
):
cache_write_input_tokens += usage.cache_creation_input_tokens
model = event.message.model or model
elif event.type == "content_block_start":
pass
@@ -117,7 +143,14 @@ def _collect_ai_data(
elif event.type == "message_delta":
output_tokens += event.usage.output_tokens

return model, input_tokens, output_tokens, content_blocks
return (
model,
input_tokens,
output_tokens,
cache_read_input_tokens,
cache_write_input_tokens,
content_blocks,
)


def _set_input_data(
@@ -219,6 +252,8 @@ def _set_output_data(
model: "str | None",
input_tokens: "int | None",
output_tokens: "int | None",
cache_read_input_tokens: "int | None",
cache_write_input_tokens: "int | None",
content_blocks: "list[Any]",
finish_span: bool = False,
) -> None:
@@ -254,6 +289,8 @@ def _set_output_data(
span,
input_tokens=input_tokens,
output_tokens=output_tokens,
input_tokens_cached=cache_read_input_tokens,
input_tokens_cache_write=cache_write_input_tokens,
)

if finish_span:
@@ -288,7 +325,12 @@ def _sentry_patched_create_common(f: "Any", *args: "Any", **kwargs: "Any") -> "A

with capture_internal_exceptions():
if hasattr(result, "content"):
input_tokens, output_tokens = _get_token_usage(result)
(
input_tokens,
output_tokens,
cache_read_input_tokens,
cache_write_input_tokens,
) = _get_token_usage(result)

content_blocks = []
for content_block in result.content:
@@ -305,6 +347,8 @@ def _sentry_patched_create_common(f: "Any", *args: "Any", **kwargs: "Any") -> "A
model=getattr(result, "model", None),
input_tokens=input_tokens,
output_tokens=output_tokens,
cache_read_input_tokens=cache_read_input_tokens,
cache_write_input_tokens=cache_write_input_tokens,
content_blocks=content_blocks,
finish_span=True,
)
@@ -317,13 +361,26 @@ def new_iterator() -> "Iterator[MessageStreamEvent]":
model = None
input_tokens = 0
output_tokens = 0
cache_read_input_tokens = 0
cache_write_input_tokens = 0
content_blocks: "list[str]" = []

for event in old_iterator:
model, input_tokens, output_tokens, content_blocks = (
_collect_ai_data(
event, model, input_tokens, output_tokens, content_blocks
)
(
model,
input_tokens,
output_tokens,
cache_read_input_tokens,
cache_write_input_tokens,
content_blocks,
) = _collect_ai_data(
event,
model,
input_tokens,
output_tokens,
cache_read_input_tokens,
cache_write_input_tokens,
content_blocks,
)
yield event

@@ -333,6 +390,8 @@ def new_iterator() -> "Iterator[MessageStreamEvent]":
model=model,
input_tokens=input_tokens,
output_tokens=output_tokens,
cache_read_input_tokens=cache_read_input_tokens,
cache_write_input_tokens=cache_write_input_tokens,
content_blocks=[{"text": "".join(content_blocks), "type": "text"}],
finish_span=True,
)
@@ -341,13 +400,26 @@ async def new_iterator_async() -> "AsyncIterator[MessageStreamEvent]":
model = None
input_tokens = 0
output_tokens = 0
cache_read_input_tokens = 0
cache_write_input_tokens = 0
content_blocks: "list[str]" = []

async for event in old_iterator:
model, input_tokens, output_tokens, content_blocks = (
_collect_ai_data(
event, model, input_tokens, output_tokens, content_blocks
)
(
model,
input_tokens,
output_tokens,
cache_read_input_tokens,
cache_write_input_tokens,
content_blocks,
) = _collect_ai_data(
event,
model,
input_tokens,
output_tokens,
cache_read_input_tokens,
cache_write_input_tokens,
content_blocks,
)
yield event

@@ -357,6 +429,8 @@ async def new_iterator_async() -> "AsyncIterator[MessageStreamEvent]":
model=model,
input_tokens=input_tokens,
output_tokens=output_tokens,
cache_read_input_tokens=cache_read_input_tokens,
cache_write_input_tokens=cache_write_input_tokens,
content_blocks=[{"text": "".join(content_blocks), "type": "text"}],
finish_span=True,
)
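To put the integration change in context, a hedged end-to-end sketch: with prompt caching enabled on an Anthropic request, the response `usage` can carry `cache_creation_input_tokens` (on the call that writes the cache) and `cache_read_input_tokens` (on later calls that read it), which the patched integration now records on the span. The model name, prompt, and settings below are illustrative assumptions, and prompt-caching availability depends on the `anthropic` SDK/API version in use.

```python
import anthropic
import sentry_sdk

# Illustrative setup; the Anthropic integration is typically auto-enabled
# when the anthropic package is installed.
sentry_sdk.init(traces_sample_rate=1.0)

client = anthropic.Anthropic()
response = client.messages.create(
    model="claude-3-5-sonnet-20241022",  # illustrative model name
    max_tokens=256,
    system=[
        {
            "type": "text",
            "text": "A long, reusable system prompt...",
            "cache_control": {"type": "ephemeral"},  # marks this block for prompt caching
        }
    ],
    messages=[{"role": "user", "content": "Summarize the system prompt."}],
)

# First call: usage.cache_creation_input_tokens > 0 -> recorded as
#   gen_ai.usage.input_tokens.cache_write
# Repeat call: usage.cache_read_input_tokens > 0 -> recorded as the
#   cached-input-tokens attribute
print(response.usage)
```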