@@ -3,11 +3,12 @@
 from functools import wraps
 
 import sentry_sdk
-from sentry_sdk.ai.monitoring import set_ai_pipeline_name, record_token_usage
+from sentry_sdk.ai.monitoring import set_ai_pipeline_name
 from sentry_sdk.ai.utils import set_data_normalized
 from sentry_sdk.consts import OP, SPANDATA
 from sentry_sdk.integrations import DidNotEnable, Integration
 from sentry_sdk.scope import should_send_default_pii
 from sentry_sdk.tracing import Span
+from sentry_sdk.tracing_utils import _get_value
 from sentry_sdk.utils import logger, capture_internal_exceptions
 
 from typing import TYPE_CHECKING
@@ -29,7 +30,7 @@
 
 try:
     from langchain.agents import AgentExecutor
-    from langchain_core.agents import AgentAction, AgentFinish
+    from langchain_core.agents import AgentFinish
     from langchain_core.callbacks import (
         BaseCallbackHandler,
         BaseCallbackManager,
@@ -128,102 +129,6 @@ def _normalize_langchain_message(self, message):
         parsed.update(message.additional_kwargs)
         return parsed
 
-    def _extract_token_usage(self, token_usage):
-        # type: (Any) -> tuple[Optional[int], Optional[int], Optional[int]]
-        """Extract input, output, and total tokens from various token usage formats."""
-        if not token_usage:
-            return None, None, None
-
-        input_tokens = None
-        output_tokens = None
-        total_tokens = None
-
-        if hasattr(token_usage, "get"):
-            input_tokens = token_usage.get("prompt_tokens") or token_usage.get(
-                "input_tokens"
-            )
-            output_tokens = token_usage.get("completion_tokens") or token_usage.get(
-                "output_tokens"
-            )
-            total_tokens = token_usage.get("total_tokens")
-        else:
-            input_tokens = getattr(token_usage, "prompt_tokens", None) or getattr(
-                token_usage, "input_tokens", None
-            )
-            output_tokens = getattr(token_usage, "completion_tokens", None) or getattr(
-                token_usage, "output_tokens", None
-            )
-            total_tokens = getattr(token_usage, "total_tokens", None)
-
-        # LangChain's OpenAI callback uses these specific field names
-        if input_tokens is None and hasattr(token_usage, "get"):
-            input_tokens = token_usage.get("prompt_tokens") or token_usage.get(
-                "input_tokens"
-            )
-        if output_tokens is None and hasattr(token_usage, "get"):
-            output_tokens = token_usage.get("completion_tokens") or token_usage.get(
-                "output_tokens"
-            )
-        if total_tokens is None and hasattr(token_usage, "get"):
-            total_tokens = token_usage.get("total_tokens")
-
-        return input_tokens, output_tokens, total_tokens
-
-    def _extract_token_usage_from_generations(self, generations):
-        # type: (Any) -> tuple[Optional[int], Optional[int], Optional[int]]
-        """Extract token usage from response.generations structure."""
-        if not generations:
-            return None, None, None
-
-        total_input = 0
-        total_output = 0
-        total_total = 0
-        found = False
-
-        for gen_list in generations:
-            for gen in gen_list:
-                usage_metadata = None
-                if (
-                    hasattr(gen, "message")
-                    and getattr(gen, "message", None) is not None
-                    and hasattr(gen.message, "usage_metadata")
-                ):
-                    usage_metadata = getattr(gen.message, "usage_metadata", None)
-                if usage_metadata is None and hasattr(gen, "usage_metadata"):
-                    usage_metadata = getattr(gen, "usage_metadata", None)
-                if usage_metadata:
-                    input_tokens, output_tokens, total_tokens = (
-                        self._extract_token_usage_from_response(usage_metadata)
-                    )
-                    if any([input_tokens, output_tokens, total_tokens]):
-                        found = True
-                        total_input += int(input_tokens)
-                        total_output += int(output_tokens)
-                        total_total += int(total_tokens)
-
-        if not found:
-            return None, None, None
-
-        return (
-            total_input if total_input > 0 else None,
-            total_output if total_output > 0 else None,
-            total_total if total_total > 0 else None,
-        )
-
-    def _extract_token_usage_from_response(self, response):
-        # type: (Any) -> tuple[int, int, int]
-        if response:
-            if hasattr(response, "get"):
-                input_tokens = response.get("input_tokens", 0)
-                output_tokens = response.get("output_tokens", 0)
-                total_tokens = response.get("total_tokens", 0)
-            else:
-                input_tokens = getattr(response, "input_tokens", 0)
-                output_tokens = getattr(response, "output_tokens", 0)
-                total_tokens = getattr(response, "total_tokens", 0)
-
-            return input_tokens, output_tokens, total_tokens
-
     def _create_span(self, run_id, parent_id, **kwargs):
         # type: (SentryLangchainCallback, UUID, Optional[Any], Any) -> WatchedSpan
         watched_span = None  # type: Optional[WatchedSpan]
@@ -369,53 +274,14 @@ def on_chat_model_end(self, response, *, run_id, **kwargs):
             span_data = self.span_map[run_id]
             span = span_data.span
 
-            token_usage = None
-
-            # Try multiple paths to extract token usage, prioritizing streaming-aware approaches
-            if response.llm_output and "token_usage" in response.llm_output:
-                token_usage = response.llm_output["token_usage"]
-            elif response.llm_output and hasattr(response.llm_output, "token_usage"):
-                token_usage = response.llm_output.token_usage
-            elif hasattr(response, "usage"):
-                token_usage = response.usage
-            elif hasattr(response, "token_usage"):
-                token_usage = response.token_usage
-            # Check for usage_metadata in llm_output (common in streaming responses)
-            elif response.llm_output and "usage_metadata" in response.llm_output:
-                token_usage = response.llm_output["usage_metadata"]
-            elif response.llm_output and hasattr(response.llm_output, "usage_metadata"):
-                token_usage = response.llm_output.usage_metadata
-            elif hasattr(response, "usage_metadata"):
-                token_usage = response.usage_metadata
-
             if should_send_default_pii() and self.include_prompts:
                 set_data_normalized(
                     span,
                     SPANDATA.GEN_AI_RESPONSE_TEXT,
                     [[x.text for x in list_] for list_ in response.generations],
                 )
 
-            if token_usage:
-                input_tokens, output_tokens, total_tokens = self._extract_token_usage(
-                    token_usage
-                )
-            else:
-                input_tokens, output_tokens, total_tokens = (
-                    self._extract_token_usage_from_generations(response.generations)
-                )
-
-            if (
-                input_tokens is not None
-                or output_tokens is not None
-                or total_tokens is not None
-            ):
-                record_token_usage(
-                    span,
-                    input_tokens=input_tokens,
-                    output_tokens=output_tokens,
-                    total_tokens=total_tokens,
-                )
-
+            _record_token_usage(span, response)
             self._exit_span(span_data, run_id)
 
     def on_llm_end(self, response, *, run_id, **kwargs):
@@ -469,40 +335,7 @@ def on_llm_end(self, response, *, run_id, **kwargs):
                     [[x.text for x in list_] for list_ in response.generations],
                 )
 
-            token_usage = None
-            if response.llm_output and "token_usage" in response.llm_output:
-                token_usage = response.llm_output["token_usage"]
-
-            elif response.llm_output and hasattr(response.llm_output, "token_usage"):
-                token_usage = response.llm_output.token_usage
-
-            elif hasattr(response, "usage"):
-                token_usage = response.usage
-
-            elif hasattr(response, "token_usage"):
-                token_usage = response.token_usage
-
-            if token_usage:
-                input_tokens, output_tokens, total_tokens = self._extract_token_usage(
-                    token_usage
-                )
-            else:
-                input_tokens, output_tokens, total_tokens = (
-                    self._extract_token_usage_from_generations(response.generations)
-                )
-
-            if (
-                input_tokens is not None
-                or output_tokens is not None
-                or total_tokens is not None
-            ):
-                record_token_usage(
-                    span,
-                    input_tokens=input_tokens,
-                    output_tokens=output_tokens,
-                    total_tokens=total_tokens,
-                )
-
+            _record_token_usage(span, response)
             self._exit_span(span_data, run_id)
 
     def on_llm_error(self, error, *, run_id, **kwargs):
@@ -515,15 +348,6 @@ def on_chat_model_error(self, error, *, run_id, **kwargs):
         """Run when Chat Model errors."""
         self._handle_error(run_id, error)
 
-    def on_agent_action(self, action, *, run_id, **kwargs):
-        # type: (SentryLangchainCallback, AgentAction, UUID, Any) -> Any
-        with capture_internal_exceptions():
-            if not run_id or run_id not in self.span_map:
-                return
-
-            span_data = self.span_map[run_id]
-            self._exit_span(span_data, run_id)
-
     def on_agent_finish(self, finish, *, run_id, **kwargs):
         # type: (SentryLangchainCallback, AgentFinish, UUID, Any) -> Any
         with capture_internal_exceptions():
@@ -595,6 +419,98 @@ def on_tool_error(self, error, *args, run_id, **kwargs):
         self._handle_error(run_id, error)
 
 
+def _extract_tokens(token_usage):
+    # type: (Any) -> tuple[Optional[int], Optional[int], Optional[int]]
+    if not token_usage:
+        return None, None, None
+
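+    # Usage objects vary by provider: OpenAI-style "prompt_tokens"/"completion_tokens"
+    # and LangChain-style "input_tokens"/"output_tokens" are both checked below.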
+    input_tokens = _get_value(token_usage, "prompt_tokens") or _get_value(
+        token_usage, "input_tokens"
+    )
+    output_tokens = _get_value(token_usage, "completion_tokens") or _get_value(
+        token_usage, "output_tokens"
+    )
+    total_tokens = _get_value(token_usage, "total_tokens")
+
+    return input_tokens, output_tokens, total_tokens
+
+
+def _extract_tokens_from_generations(generations):
+    # type: (Any) -> tuple[Optional[int], Optional[int], Optional[int]]
+    """Extract token usage from response.generations structure."""
+    if not generations:
+        return None, None, None
+
+    total_input = 0
+    total_output = 0
+    total_total = 0
+
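+    # A response may carry several generations (and several candidates per
+    # generation); sum whatever usage each one reports.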
+    for gen_list in generations:
+        for gen in gen_list:
+            token_usage = _get_token_usage(gen)
+            input_tokens, output_tokens, total_tokens = _extract_tokens(token_usage)
+            total_input += input_tokens if input_tokens is not None else 0
+            total_output += output_tokens if output_tokens is not None else 0
+            total_total += total_tokens if total_tokens is not None else 0
+
+    return (
+        total_input if total_input > 0 else None,
+        total_output if total_output > 0 else None,
+        total_total if total_total > 0 else None,
+    )
+
+
+def _get_token_usage(obj):
+    # type: (Any) -> Optional[Dict[str, Any]]
+    """
+    Check multiple paths to extract token usage from different objects.
+    """
+    possible_names = ("usage", "token_usage", "usage_metadata")
+
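+    # Search order: the generation's message first (where chat generations,
+    # including streamed ones, attach usage_metadata), then the provider's
+    # llm_output, then the object itself.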
+    message = _get_value(obj, "message")
+    if message is not None:
+        for name in possible_names:
+            usage = _get_value(message, name)
+            if usage is not None:
+                return usage
+
+    llm_output = _get_value(obj, "llm_output")
+    if llm_output is not None:
+        for name in possible_names:
+            usage = _get_value(llm_output, name)
+            if usage is not None:
+                return usage
+
+    # check for usage in the object itself
+    for name in possible_names:
+        usage = _get_value(obj, name)
+        if usage is not None:
+            return usage
+
+    # no usage found anywhere
+    return None
+
+
+def _record_token_usage(span, response):
+    # type: (Span, Any) -> None
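+    # Prefer usage reported on the response itself; if absent (as with many
+    # streamed responses), fall back to summing per-generation usage.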
+    token_usage = _get_token_usage(response)
+    if token_usage:
+        input_tokens, output_tokens, total_tokens = _extract_tokens(token_usage)
+    else:
+        input_tokens, output_tokens, total_tokens = _extract_tokens_from_generations(
+            response.generations
+        )
+
+    if input_tokens is not None:
+        span.set_data(SPANDATA.GEN_AI_USAGE_INPUT_TOKENS, input_tokens)
+
+    if output_tokens is not None:
+        span.set_data(SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS, output_tokens)
+
+    if total_tokens is not None:
+        span.set_data(SPANDATA.GEN_AI_USAGE_TOTAL_TOKENS, total_tokens)
+
+
 def _get_request_data(obj, args, kwargs):
     # type: (Any, Any, Any) -> tuple[Optional[str], Optional[List[Any]]]
     """
@@ -796,7 +712,11 @@ def new_stream(self, *args, **kwargs):
         )
 
         input = args[0].get("input") if len(args) > 1 else None
-        if input is not None and should_send_default_pii() and self.include_prompts:
+        if (
+            input is not None
+            and should_send_default_pii()
+            and integration.include_prompts
+        ):
             set_data_normalized(
                 span,
                 SPANDATA.GEN_AI_REQUEST_MESSAGES,
@@ -819,7 +739,7 @@ def new_iterator():
             if (
                 output is not None
                 and should_send_default_pii()
-                and self.include_prompts
+                and integration.include_prompts
             ):
                 span.set_data(SPANDATA.GEN_AI_RESPONSE_TEXT, output)
 
@@ -831,7 +751,11 @@ async def new_iterator_async():
                 yield event
 
             output = event.get("output")
-            if output is not None:
+            if (
+                output is not None
+                and should_send_default_pii()
+                and integration.include_prompts
+            ):
                 span.set_data(SPANDATA.GEN_AI_RESPONSE_TEXT, output)
 
             span.__exit__(None, None, None)