Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 38 additions & 22 deletions sentry_sdk/integrations/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -465,23 +465,14 @@ def _set_embeddings_input_data(
def _common_set_output_data(
span: "Span",
response: "Any",
kwargs: "dict[str, Any]",
input: "Any",
integration: "OpenAIIntegration",
start_time: "Optional[float]" = None,
finish_span: bool = True,
) -> None:
if hasattr(response, "model"):
set_data_normalized(span, SPANDATA.GEN_AI_RESPONSE_MODEL, response.model)

# Input messages (the prompt or data sent to the model)
# used for the token usage calculation
messages = kwargs.get("messages")
if messages is None:
messages = kwargs.get("input")

if messages is not None and isinstance(messages, str):
messages = [messages]

ttft: "Optional[float]" = None

if hasattr(response, "choices"):
Expand All @@ -494,7 +485,7 @@ def _common_set_output_data(
if len(response_text) > 0:
set_data_normalized(span, SPANDATA.GEN_AI_RESPONSE_TEXT, response_text)

_calculate_token_usage(messages, response, span, None, integration.count_tokens)
_calculate_token_usage(input, response, span, None, integration.count_tokens)

if finish_span:
span.__exit__(None, None, None)
Expand Down Expand Up @@ -530,7 +521,7 @@ def _common_set_output_data(
span, SPANDATA.GEN_AI_RESPONSE_TEXT, output_messages["response"]
)

_calculate_token_usage(messages, response, span, None, integration.count_tokens)
_calculate_token_usage(input, response, span, None, integration.count_tokens)

if finish_span:
span.__exit__(None, None, None)
Expand Down Expand Up @@ -571,7 +562,7 @@ def new_iterator() -> "Iterator[ChatCompletionChunk]":
# OpenAI responses API end of streaming response
if RESPONSES_API_ENABLED and isinstance(x, ResponseCompletedEvent):
_calculate_token_usage(
messages,
input,
x.response,
span,
None,
Expand All @@ -594,7 +585,7 @@ def new_iterator() -> "Iterator[ChatCompletionChunk]":
)
if count_tokens_manually:
_calculate_token_usage(
messages,
input,
response,
span,
all_responses,
Expand Down Expand Up @@ -635,7 +626,7 @@ async def new_iterator_async() -> "AsyncIterator[ChatCompletionChunk]":
# OpenAI responses API end of streaming response
if RESPONSES_API_ENABLED and isinstance(x, ResponseCompletedEvent):
_calculate_token_usage(
messages,
input,
x.response,
span,
None,
Expand All @@ -658,7 +649,7 @@ async def new_iterator_async() -> "AsyncIterator[ChatCompletionChunk]":
)
if count_tokens_manually:
_calculate_token_usage(
messages,
input,
response,
span,
all_responses,
Expand All @@ -672,7 +663,7 @@ async def new_iterator_async() -> "AsyncIterator[ChatCompletionChunk]":
else:
response._iterator = new_iterator()
else:
_calculate_token_usage(messages, response, span, None, integration.count_tokens)
_calculate_token_usage(input, response, span, None, integration.count_tokens)
if finish_span:
span.__exit__(None, None, None)

Expand Down Expand Up @@ -727,10 +718,15 @@ def _set_completions_api_output_data(
start_time: "Optional[float]" = None,
finish_span: bool = True,
) -> None:
messages = kwargs.get("messages")

if messages is not None and isinstance(messages, str):
messages = [messages]

_common_set_output_data(
span,
response,
kwargs,
messages,
integration,
start_time,
finish_span,
Expand All @@ -745,10 +741,15 @@ def _set_streaming_completions_api_output_data(
start_time: "Optional[float]" = None,
finish_span: bool = True,
) -> None:
messages = kwargs.get("messages")

if messages is not None and isinstance(messages, str):
messages = [messages]

_common_set_output_data(
span,
response,
kwargs,
messages,
integration,
start_time,
finish_span,
Expand All @@ -763,10 +764,15 @@ def _set_responses_api_output_data(
start_time: "Optional[float]" = None,
finish_span: bool = True,
) -> None:
input = kwargs.get("input")

if input is not None and isinstance(input, str):
input = [input]

_common_set_output_data(
span,
response,
kwargs,
input,
integration,
start_time,
finish_span,
Expand All @@ -781,10 +787,15 @@ def _set_streaming_responses_api_output_data(
start_time: "Optional[float]" = None,
finish_span: bool = True,
) -> None:
input = kwargs.get("input")

if input is not None and isinstance(input, str):
input = [input]

_common_set_output_data(
span,
response,
kwargs,
input,
integration,
start_time,
finish_span,
Expand All @@ -799,10 +810,15 @@ def _set_embeddings_output_data(
start_time: "Optional[float]" = None,
finish_span: bool = True,
) -> None:
input = kwargs.get("input")

if input is not None and isinstance(input, str):
input = [input]

_common_set_output_data(
span,
response,
kwargs,
input,
integration,
start_time,
finish_span,
Expand Down
Loading