
Commit 8ac5e33

Refactored token usage
1 parent 789f6b3 commit 8ac5e33

File tree: 1 file changed (+108, −184 lines)

sentry_sdk/integrations/langchain.py
Lines changed: 108 additions & 184 deletions
@@ -3,12 +3,13 @@
 from functools import wraps

 import sentry_sdk
-from sentry_sdk.ai.monitoring import set_ai_pipeline_name, record_token_usage
+from sentry_sdk.ai.monitoring import set_ai_pipeline_name
 from sentry_sdk.ai.utils import set_data_normalized
 from sentry_sdk.consts import OP, SPANDATA
 from sentry_sdk.integrations import DidNotEnable, Integration
 from sentry_sdk.scope import should_send_default_pii
 from sentry_sdk.tracing import Span
+from sentry_sdk.tracing_utils import _get_value
 from sentry_sdk.utils import logger, capture_internal_exceptions

 from typing import TYPE_CHECKING
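
Note: the newly imported sentry_sdk.tracing_utils._get_value replaces the repeated hasattr/getattr-vs-dict-.get branching that the deleted helpers carried. A minimal sketch of the semantics the refactored code appears to rely on (not the actual implementation in tracing_utils):

    def _get_value(obj, key):
        # type: (Any, str) -> Optional[Any]
        # Read `key` from a mapping, or as an attribute from any other
        # object; return None when it is absent either way.
        if isinstance(obj, dict):
            return obj.get(key)
        return getattr(obj, key, None)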
@@ -29,7 +30,7 @@

 try:
     from langchain.agents import AgentExecutor
-    from langchain_core.agents import AgentAction, AgentFinish
+    from langchain_core.agents import AgentFinish
     from langchain_core.callbacks import (
         BaseCallbackHandler,
         BaseCallbackManager,
@@ -128,102 +129,6 @@ def _normalize_langchain_message(self, message):
         parsed.update(message.additional_kwargs)
         return parsed

-    def _extract_token_usage(self, token_usage):
-        # type: (Any) -> tuple[Optional[int], Optional[int], Optional[int]]
-        """Extract input, output, and total tokens from various token usage formats."""
-        if not token_usage:
-            return None, None, None
-
-        input_tokens = None
-        output_tokens = None
-        total_tokens = None
-
-        if hasattr(token_usage, "get"):
-            input_tokens = token_usage.get("prompt_tokens") or token_usage.get(
-                "input_tokens"
-            )
-            output_tokens = token_usage.get("completion_tokens") or token_usage.get(
-                "output_tokens"
-            )
-            total_tokens = token_usage.get("total_tokens")
-        else:
-            input_tokens = getattr(token_usage, "prompt_tokens", None) or getattr(
-                token_usage, "input_tokens", None
-            )
-            output_tokens = getattr(token_usage, "completion_tokens", None) or getattr(
-                token_usage, "output_tokens", None
-            )
-            total_tokens = getattr(token_usage, "total_tokens", None)
-
-        # LangChain's OpenAI callback uses these specific field names
-        if input_tokens is None and hasattr(token_usage, "get"):
-            input_tokens = token_usage.get("prompt_tokens") or token_usage.get(
-                "input_tokens"
-            )
-        if output_tokens is None and hasattr(token_usage, "get"):
-            output_tokens = token_usage.get("completion_tokens") or token_usage.get(
-                "output_tokens"
-            )
-        if total_tokens is None and hasattr(token_usage, "get"):
-            total_tokens = token_usage.get("total_tokens")
-
-        return input_tokens, output_tokens, total_tokens
-
-    def _extract_token_usage_from_generations(self, generations):
-        # type: (Any) -> tuple[Optional[int], Optional[int], Optional[int]]
-        """Extract token usage from response.generations structure."""
-        if not generations:
-            return None, None, None
-
-        total_input = 0
-        total_output = 0
-        total_total = 0
-        found = False
-
-        for gen_list in generations:
-            for gen in gen_list:
-                usage_metadata = None
-                if (
-                    hasattr(gen, "message")
-                    and getattr(gen, "message", None) is not None
-                    and hasattr(gen.message, "usage_metadata")
-                ):
-                    usage_metadata = getattr(gen.message, "usage_metadata", None)
-                if usage_metadata is None and hasattr(gen, "usage_metadata"):
-                    usage_metadata = getattr(gen, "usage_metadata", None)
-                if usage_metadata:
-                    input_tokens, output_tokens, total_tokens = (
-                        self._extract_token_usage_from_response(usage_metadata)
-                    )
-                    if any([input_tokens, output_tokens, total_tokens]):
-                        found = True
-                        total_input += int(input_tokens)
-                        total_output += int(output_tokens)
-                        total_total += int(total_tokens)
-
-        if not found:
-            return None, None, None
-
-        return (
-            total_input if total_input > 0 else None,
-            total_output if total_output > 0 else None,
-            total_total if total_total > 0 else None,
-        )
-
-    def _extract_token_usage_from_response(self, response):
-        # type: (Any) -> tuple[int, int, int]
-        if response:
-            if hasattr(response, "get"):
-                input_tokens = response.get("input_tokens", 0)
-                output_tokens = response.get("output_tokens", 0)
-                total_tokens = response.get("total_tokens", 0)
-            else:
-                input_tokens = getattr(response, "input_tokens", 0)
-                output_tokens = getattr(response, "output_tokens", 0)
-                total_tokens = getattr(response, "total_tokens", 0)
-
-        return input_tokens, output_tokens, total_tokens
-
     def _create_span(self, run_id, parent_id, **kwargs):
         # type: (SentryLangchainCallback, UUID, Optional[Any], Any) -> WatchedSpan
         watched_span = None  # type: Optional[WatchedSpan]
@@ -369,53 +274,14 @@ def on_chat_model_end(self, response, *, run_id, **kwargs):
             span_data = self.span_map[run_id]
             span = span_data.span

-            token_usage = None
-
-            # Try multiple paths to extract token usage, prioritizing streaming-aware approaches
-            if response.llm_output and "token_usage" in response.llm_output:
-                token_usage = response.llm_output["token_usage"]
-            elif response.llm_output and hasattr(response.llm_output, "token_usage"):
-                token_usage = response.llm_output.token_usage
-            elif hasattr(response, "usage"):
-                token_usage = response.usage
-            elif hasattr(response, "token_usage"):
-                token_usage = response.token_usage
-            # Check for usage_metadata in llm_output (common in streaming responses)
-            elif response.llm_output and "usage_metadata" in response.llm_output:
-                token_usage = response.llm_output["usage_metadata"]
-            elif response.llm_output and hasattr(response.llm_output, "usage_metadata"):
-                token_usage = response.llm_output.usage_metadata
-            elif hasattr(response, "usage_metadata"):
-                token_usage = response.usage_metadata
-
             if should_send_default_pii() and self.include_prompts:
                 set_data_normalized(
                     span,
                     SPANDATA.GEN_AI_RESPONSE_TEXT,
                     [[x.text for x in list_] for list_ in response.generations],
                 )

-            if token_usage:
-                input_tokens, output_tokens, total_tokens = self._extract_token_usage(
-                    token_usage
-                )
-            else:
-                input_tokens, output_tokens, total_tokens = (
-                    self._extract_token_usage_from_generations(response.generations)
-                )
-
-            if (
-                input_tokens is not None
-                or output_tokens is not None
-                or total_tokens is not None
-            ):
-                record_token_usage(
-                    span,
-                    input_tokens=input_tokens,
-                    output_tokens=output_tokens,
-                    total_tokens=total_tokens,
-                )
-
+            _record_token_usage(span, response)
             self._exit_span(span_data, run_id)

     def on_llm_end(self, response, *, run_id, **kwargs):
@@ -469,40 +335,7 @@ def on_llm_end(self, response, *, run_id, **kwargs):
                     [[x.text for x in list_] for list_ in response.generations],
                 )

-            token_usage = None
-            if response.llm_output and "token_usage" in response.llm_output:
-                token_usage = response.llm_output["token_usage"]
-
-            elif response.llm_output and hasattr(response.llm_output, "token_usage"):
-                token_usage = response.llm_output.token_usage
-
-            elif hasattr(response, "usage"):
-                token_usage = response.usage
-
-            elif hasattr(response, "token_usage"):
-                token_usage = response.token_usage
-
-            if token_usage:
-                input_tokens, output_tokens, total_tokens = self._extract_token_usage(
-                    token_usage
-                )
-            else:
-                input_tokens, output_tokens, total_tokens = (
-                    self._extract_token_usage_from_generations(response.generations)
-                )
-
-            if (
-                input_tokens is not None
-                or output_tokens is not None
-                or total_tokens is not None
-            ):
-                record_token_usage(
-                    span,
-                    input_tokens=input_tokens,
-                    output_tokens=output_tokens,
-                    total_tokens=total_tokens,
-                )
-
+            _record_token_usage(span, response)
             self._exit_span(span_data, run_id)

     def on_llm_error(self, error, *, run_id, **kwargs):
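
Note: on_chat_model_end and on_llm_end now delegate their duplicated extraction logic to the module-level _record_token_usage defined later in this diff. A hypothetical walkthrough of the new path (FakeSpan and FakeResponse are illustrative stand-ins, not part of the commit):

    class FakeSpan:
        def __init__(self):
            self.data = {}

        def set_data(self, key, value):
            self.data[key] = value

    class FakeResponse:
        # usage found via the llm_output path of _get_token_usage
        llm_output = {
            "token_usage": {
                "prompt_tokens": 11,
                "completion_tokens": 5,
                "total_tokens": 16,
            }
        }
        generations = []

    span = FakeSpan()
    _record_token_usage(span, FakeResponse())
    # span.data now holds SPANDATA.GEN_AI_USAGE_INPUT_TOKENS -> 11,
    # ..._OUTPUT_TOKENS -> 5, and ..._TOTAL_TOKENS -> 16.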
@@ -515,15 +348,6 @@ def on_chat_model_error(self, error, *, run_id, **kwargs):
515348
"""Run when Chat Model errors."""
516349
self._handle_error(run_id, error)
517350

518-
def on_agent_action(self, action, *, run_id, **kwargs):
519-
# type: (SentryLangchainCallback, AgentAction, UUID, Any) -> Any
520-
with capture_internal_exceptions():
521-
if not run_id or run_id not in self.span_map:
522-
return
523-
524-
span_data = self.span_map[run_id]
525-
self._exit_span(span_data, run_id)
526-
527351
def on_agent_finish(self, finish, *, run_id, **kwargs):
528352
# type: (SentryLangchainCallback, AgentFinish, UUID, Any) -> Any
529353
with capture_internal_exceptions():
@@ -595,6 +419,98 @@ def on_tool_error(self, error, *args, run_id, **kwargs):
         self._handle_error(run_id, error)


+def _extract_tokens(token_usage):
+    # type: (Any) -> tuple[Optional[int], Optional[int], Optional[int]]
+    if not token_usage:
+        return None, None, None
+
+    input_tokens = _get_value(token_usage, "prompt_tokens") or _get_value(
+        token_usage, "input_tokens"
+    )
+    output_tokens = _get_value(token_usage, "completion_tokens") or _get_value(
+        token_usage, "output_tokens"
+    )
+    total_tokens = _get_value(token_usage, "total_tokens")
+
+    return input_tokens, output_tokens, total_tokens
+
+
+def _extract_tokens_from_generations(generations):
+    # type: (Any) -> tuple[Optional[int], Optional[int], Optional[int]]
+    """Extract token usage from response.generations structure."""
+    if not generations:
+        return None, None, None
+
+    total_input = 0
+    total_output = 0
+    total_total = 0
+
+    for gen_list in generations:
+        for gen in gen_list:
+            token_usage = _get_token_usage(gen)
+            input_tokens, output_tokens, total_tokens = _extract_tokens(token_usage)
+            total_input += input_tokens if input_tokens is not None else 0
+            total_output += output_tokens if output_tokens is not None else 0
+            total_total += total_tokens if total_tokens is not None else 0
+
+    return (
+        total_input if total_input > 0 else None,
+        total_output if total_output > 0 else None,
+        total_total if total_total > 0 else None,
+    )
+
+
+def _get_token_usage(obj):
+    # type: (Any) -> Optional[Dict[str, Any]]
+    """
+    Check multiple paths to extract token usage from different objects.
+    """
+    possible_names = ("usage", "token_usage", "usage_metadata")
+
+    message = _get_value(obj, "message")
+    if message is not None:
+        for name in possible_names:
+            usage = _get_value(message, name)
+            if usage is not None:
+                return usage
+
+    llm_output = _get_value(obj, "llm_output")
+    if llm_output is not None:
+        for name in possible_names:
+            usage = _get_value(llm_output, name)
+            if usage is not None:
+                return usage
+
+    # check for usage in the object itself
+    for name in possible_names:
+        usage = _get_value(obj, name)
+        if usage is not None:
+            return usage
+
+    # no usage found anywhere
+    return None
+
+
+def _record_token_usage(span, response):
+    # type: (Span, Any) -> None
+    token_usage = _get_token_usage(response)
+    if token_usage:
+        input_tokens, output_tokens, total_tokens = _extract_tokens(token_usage)
+    else:
+        input_tokens, output_tokens, total_tokens = _extract_tokens_from_generations(
+            response.generations
+        )
+
+    if input_tokens is not None:
+        span.set_data(SPANDATA.GEN_AI_USAGE_INPUT_TOKENS, input_tokens)
+
+    if output_tokens is not None:
+        span.set_data(SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS, output_tokens)
+
+    if total_tokens is not None:
+        span.set_data(SPANDATA.GEN_AI_USAGE_TOTAL_TOKENS, total_tokens)
+
+
 def _get_request_data(obj, args, kwargs):
     # type: (Any, Any, Any) -> tuple[Optional[str], Optional[List[Any]]]
     """
@@ -796,7 +712,11 @@ def new_stream(self, *args, **kwargs):
         )

         input = args[0].get("input") if len(args) > 1 else None
-        if input is not None and should_send_default_pii() and self.include_prompts:
+        if (
+            input is not None
+            and should_send_default_pii()
+            and integration.include_prompts
+        ):
             set_data_normalized(
                 span,
                 SPANDATA.GEN_AI_REQUEST_MESSAGES,
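
Note: in the hunk above and the two that follow, self is the object whose method was wrapped (the stream caller), not the Sentry integration, so self.include_prompts was never the right attribute; the gate now reads the flag from the integration instance. The async iterator additionally gains the should_send_default_pii() check it was missing. Presumably the wrapper resolves the integration along these lines (a sketch, not the commit's code):

    integration = sentry_sdk.get_client().get_integration(LangchainIntegration)
    if (
        integration is not None
        and should_send_default_pii()
        and integration.include_prompts
    ):
        ...  # only then record prompt/response content on the span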
@@ -819,7 +739,7 @@ def new_iterator():
             if (
                 output is not None
                 and should_send_default_pii()
-                and self.include_prompts
+                and integration.include_prompts
             ):
                 span.set_data(SPANDATA.GEN_AI_RESPONSE_TEXT, output)

@@ -831,7 +751,11 @@ async def new_iterator_async():
                 yield event

             output = event.get("output")
-            if output is not None:
+            if (
+                output is not None
+                and should_send_default_pii()
+                and integration.include_prompts
+            ):
                 span.set_data(SPANDATA.GEN_AI_RESPONSE_TEXT, output)

             span.__exit__(None, None, None)
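
Summary note: the commit replaces three instance methods and two copies of inline extraction with four module-level helpers, and records usage via span.set_data on the SPANDATA.GEN_AI_USAGE_* fields instead of calling record_token_usage. The fallback path, _extract_tokens_from_generations, sums usage across all generations. An illustrative run (SimpleNamespace objects stand in for LangChain generation/message objects; the result assumes the _get_value semantics sketched above):

    from types import SimpleNamespace

    gen1 = SimpleNamespace(
        message=SimpleNamespace(
            usage_metadata={"input_tokens": 3, "output_tokens": 2, "total_tokens": 5}
        )
    )
    gen2 = SimpleNamespace(
        message=SimpleNamespace(
            usage_metadata={"input_tokens": 4, "output_tokens": 1, "total_tokens": 5}
        )
    )

    # Generations arrive as a list of lists (one inner list per prompt).
    print(_extract_tokens_from_generations([[gen1], [gen2]]))
    # -> (7, 3, 10)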
