Skip to content

Commit 9d89b90

Browse files
ref(openai): Only handle streamed results when applicable
1 parent 4f84f98 commit 9d89b90

File tree

2 files changed

+228
-21
lines changed

2 files changed

+228
-21
lines changed

sentry_sdk/integrations/openai.py

Lines changed: 199 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -746,14 +746,103 @@ def _set_streaming_completions_api_output_data(
746746
if messages is not None and isinstance(messages, str):
747747
messages = [messages]
748748

749-
_common_set_output_data(
750-
span,
751-
response,
752-
messages,
753-
integration,
754-
start_time,
755-
finish_span,
756-
)
749+
ttft: "Optional[float]" = None
750+
data_buf: "list[list[str]]" = [] # one for each choice
751+
752+
old_iterator = response._iterator
753+
754+
def new_iterator() -> "Iterator[ChatCompletionChunk]":
755+
nonlocal ttft
756+
count_tokens_manually = True
757+
for x in old_iterator:
758+
with capture_internal_exceptions():
759+
if hasattr(x, "choices"):
760+
choice_index = 0
761+
for choice in x.choices:
762+
if hasattr(choice, "delta") and hasattr(
763+
choice.delta, "content"
764+
):
765+
if start_time is not None and ttft is None:
766+
ttft = time.perf_counter() - start_time
767+
content = choice.delta.content
768+
if len(data_buf) <= choice_index:
769+
data_buf.append([])
770+
data_buf[choice_index].append(content or "")
771+
choice_index += 1
772+
773+
yield x
774+
775+
with capture_internal_exceptions():
776+
if ttft is not None:
777+
set_data_normalized(
778+
span, SPANDATA.GEN_AI_RESPONSE_TIME_TO_FIRST_TOKEN, ttft
779+
)
780+
if len(data_buf) > 0:
781+
all_responses = ["".join(chunk) for chunk in data_buf]
782+
if should_send_default_pii() and integration.include_prompts:
783+
set_data_normalized(
784+
span, SPANDATA.GEN_AI_RESPONSE_TEXT, all_responses
785+
)
786+
if count_tokens_manually:
787+
_calculate_token_usage(
788+
messages,
789+
response,
790+
span,
791+
all_responses,
792+
integration.count_tokens,
793+
)
794+
795+
if finish_span:
796+
span.__exit__(None, None, None)
797+
798+
async def new_iterator_async() -> "AsyncIterator[ChatCompletionChunk]":
799+
nonlocal ttft
800+
count_tokens_manually = True
801+
async for x in old_iterator:
802+
with capture_internal_exceptions():
803+
# OpenAI chat completion API
804+
if hasattr(x, "choices"):
805+
choice_index = 0
806+
for choice in x.choices:
807+
if hasattr(choice, "delta") and hasattr(
808+
choice.delta, "content"
809+
):
810+
if start_time is not None and ttft is None:
811+
ttft = time.perf_counter() - start_time
812+
content = choice.delta.content
813+
if len(data_buf) <= choice_index:
814+
data_buf.append([])
815+
data_buf[choice_index].append(content or "")
816+
choice_index += 1
817+
818+
yield x
819+
820+
with capture_internal_exceptions():
821+
if ttft is not None:
822+
set_data_normalized(
823+
span, SPANDATA.GEN_AI_RESPONSE_TIME_TO_FIRST_TOKEN, ttft
824+
)
825+
if len(data_buf) > 0:
826+
all_responses = ["".join(chunk) for chunk in data_buf]
827+
if should_send_default_pii() and integration.include_prompts:
828+
set_data_normalized(
829+
span, SPANDATA.GEN_AI_RESPONSE_TEXT, all_responses
830+
)
831+
if count_tokens_manually:
832+
_calculate_token_usage(
833+
messages,
834+
response,
835+
span,
836+
all_responses,
837+
integration.count_tokens,
838+
)
839+
if finish_span:
840+
span.__exit__(None, None, None)
841+
842+
if str(type(response._iterator)) == "<class 'async_generator'>":
843+
response._iterator = new_iterator_async()
844+
else:
845+
response._iterator = new_iterator()
757846

758847

759848
def _set_responses_api_output_data(
@@ -792,14 +881,108 @@ def _set_streaming_responses_api_output_data(
792881
if input is not None and isinstance(input, str):
793882
input = [input]
794883

795-
_common_set_output_data(
796-
span,
797-
response,
798-
input,
799-
integration,
800-
start_time,
801-
finish_span,
802-
)
884+
ttft: "Optional[float]" = None
885+
data_buf: "list[list[str]]" = [] # one for each choice
886+
887+
old_iterator = response._iterator
888+
889+
def new_iterator() -> "Iterator[ChatCompletionChunk]":
890+
nonlocal ttft
891+
count_tokens_manually = True
892+
for x in old_iterator:
893+
with capture_internal_exceptions():
894+
if hasattr(x, "delta"):
895+
if start_time is not None and ttft is None:
896+
ttft = time.perf_counter() - start_time
897+
if len(data_buf) == 0:
898+
data_buf.append([])
899+
data_buf[0].append(x.delta or "")
900+
901+
if isinstance(x, ResponseCompletedEvent):
902+
_calculate_token_usage(
903+
input,
904+
x.response,
905+
span,
906+
None,
907+
integration.count_tokens,
908+
)
909+
count_tokens_manually = False
910+
911+
yield x
912+
913+
with capture_internal_exceptions():
914+
if ttft is not None:
915+
set_data_normalized(
916+
span, SPANDATA.GEN_AI_RESPONSE_TIME_TO_FIRST_TOKEN, ttft
917+
)
918+
if len(data_buf) > 0:
919+
all_responses = ["".join(chunk) for chunk in data_buf]
920+
if should_send_default_pii() and integration.include_prompts:
921+
set_data_normalized(
922+
span, SPANDATA.GEN_AI_RESPONSE_TEXT, all_responses
923+
)
924+
if count_tokens_manually:
925+
_calculate_token_usage(
926+
input,
927+
response,
928+
span,
929+
all_responses,
930+
integration.count_tokens,
931+
)
932+
933+
if finish_span:
934+
span.__exit__(None, None, None)
935+
936+
async def new_iterator_async() -> "AsyncIterator[ChatCompletionChunk]":
937+
nonlocal ttft
938+
count_tokens_manually = True
939+
async for x in old_iterator:
940+
with capture_internal_exceptions():
941+
if hasattr(x, "delta"):
942+
if start_time is not None and ttft is None:
943+
ttft = time.perf_counter() - start_time
944+
if len(data_buf) == 0:
945+
data_buf.append([])
946+
data_buf[0].append(x.delta or "")
947+
948+
if isinstance(x, ResponseCompletedEvent):
949+
_calculate_token_usage(
950+
input,
951+
x.response,
952+
span,
953+
None,
954+
integration.count_tokens,
955+
)
956+
count_tokens_manually = False
957+
958+
yield x
959+
960+
with capture_internal_exceptions():
961+
if ttft is not None:
962+
set_data_normalized(
963+
span, SPANDATA.GEN_AI_RESPONSE_TIME_TO_FIRST_TOKEN, ttft
964+
)
965+
if len(data_buf) > 0:
966+
all_responses = ["".join(chunk) for chunk in data_buf]
967+
if should_send_default_pii() and integration.include_prompts:
968+
set_data_normalized(
969+
span, SPANDATA.GEN_AI_RESPONSE_TEXT, all_responses
970+
)
971+
if count_tokens_manually:
972+
_calculate_token_usage(
973+
input,
974+
response,
975+
span,
976+
all_responses,
977+
integration.count_tokens,
978+
)
979+
if finish_span:
980+
span.__exit__(None, None, None)
981+
982+
if str(type(response._iterator)) == "<class 'async_generator'>":
983+
response._iterator = new_iterator_async()
984+
else:
985+
response._iterator = new_iterator()
803986

804987

805988
def _set_embeddings_output_data(

tests/integrations/openai/test_openai.py

Lines changed: 29 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,8 @@
11
import json
22
import pytest
33

4+
from typing import Union
5+
46
from sentry_sdk.utils import package_version
57

68
try:
@@ -21,6 +23,8 @@
2123
from openai.types.chat.chat_completion_chunk import ChoiceDelta, Choice as DeltaChoice
2224
from openai.types.create_embedding_response import Usage as EmbeddingTokenUsage
2325

26+
from openai.types.responses.response_text_delta_event import ResponseTextDeltaEvent
27+
2428
SKIP_RESPONSES_TESTS = False
2529

2630
try:
@@ -500,6 +504,7 @@ def test_streaming_chat_completion_no_prompts(
500504
{"role": "system", "content": "You are a helpful assistant."},
501505
{"role": "user", "content": "hello"},
502506
],
507+
stream=True,
503508
)
504509
response_string = "".join(
505510
map(lambda x: x.choices[0].delta.content, response_stream)
@@ -624,6 +629,7 @@ def test_streaming_chat_completion(sentry_init, capture_events, messages, reques
624629
response_stream = client.chat.completions.create(
625630
model="some-model",
626631
messages=messages,
632+
stream=True,
627633
)
628634
response_string = "".join(
629635
map(lambda x: x.choices[0].delta.content, response_stream)
@@ -747,6 +753,7 @@ async def test_streaming_chat_completion_async_no_prompts(
747753
{"role": "system", "content": "You are a helpful assistant."},
748754
{"role": "user", "content": "hello"},
749755
],
756+
stream=True,
750757
)
751758

752759
response_string = ""
@@ -881,6 +888,7 @@ async def test_streaming_chat_completion_async(
881888
response_stream = await client.chat.completions.create(
882889
model="some-model",
883890
messages=messages,
891+
stream=True,
884892
)
885893

886894
response_string = ""
@@ -942,7 +950,9 @@ def test_bad_chat_completion(sentry_init, capture_events):
942950
)
943951
with pytest.raises(OpenAIError):
944952
client.chat.completions.create(
945-
model="some-model", messages=[{"role": "system", "content": "hello"}]
953+
model="some-model",
954+
messages=[{"role": "system", "content": "hello"}],
955+
stream=True,
946956
)
947957

948958
(event,) = events
@@ -2340,6 +2350,16 @@ async def test_ai_client_span_responses_async_api(
23402350
assert spans[0]["data"] == expected_data
23412351

23422352

2353+
async def example_response_stream():
2354+
yield EXAMPLE_RESPONSES_STREAM[0]
2355+
yield EXAMPLE_RESPONSES_STREAM[1]
2356+
yield EXAMPLE_RESPONSES_STREAM[2]
2357+
yield EXAMPLE_RESPONSES_STREAM[3]
2358+
yield EXAMPLE_RESPONSES_STREAM[4]
2359+
2360+
return
2361+
2362+
23432363
@pytest.mark.asyncio
23442364
@pytest.mark.parametrize(
23452365
"instructions",
@@ -2417,15 +2437,19 @@ async def test_ai_client_span_streaming_responses_async_api(
24172437
events = capture_events()
24182438

24192439
client = AsyncOpenAI(api_key="z")
2420-
client.responses._post = AsyncMock(return_value=EXAMPLE_RESPONSE)
2440+
returned_stream = AsyncStream(cast_to=None, response=None, client=client)
2441+
returned_stream._iterator = example_response_stream()
2442+
client.responses._post = mock.AsyncMock(return_value=returned_stream)
24212443

24222444
with start_transaction(name="openai tx"):
2423-
await client.responses.create(
2445+
result = await client.responses.create(
24242446
model="gpt-4o",
24252447
instructions=instructions,
24262448
input=input,
24272449
stream=True,
24282450
)
2451+
async for _ in result:
2452+
pass
24292453

24302454
(transaction,) = events
24312455
spans = transaction["spans"]
@@ -2438,14 +2462,14 @@ async def test_ai_client_span_streaming_responses_async_api(
24382462
"gen_ai.operation.name": "responses",
24392463
"gen_ai.response.streaming": True,
24402464
"gen_ai.system": "openai",
2441-
"gen_ai.response.model": "response-model-id",
2465+
"gen_ai.response.time_to_first_token": mock.ANY,
24422466
"gen_ai.usage.input_tokens": 20,
24432467
"gen_ai.usage.input_tokens.cached": 5,
24442468
"gen_ai.usage.output_tokens": 10,
24452469
"gen_ai.usage.output_tokens.reasoning": 8,
24462470
"gen_ai.usage.total_tokens": 30,
24472471
"gen_ai.request.model": "gpt-4o",
2448-
"gen_ai.response.text": "the model response",
2472+
"gen_ai.response.text": "hello world",
24492473
"thread.id": mock.ANY,
24502474
"thread.name": mock.ANY,
24512475
}

0 commit comments

Comments (0)