@@ -930,9 +930,13 @@ async def test_bad_chat_completion_async(sentry_init, capture_events):
930930
931931@pytest .mark .parametrize (
932932 "send_default_pii, include_prompts" ,
933- [(True , True ), (True , False ), (False , True ), (False , False )],
933+ [
934+ (True , False ),
935+ (False , True ),
936+ (False , False ),
937+ ],
934938)
935- def test_embeddings_create (
939+ def test_embeddings_create_no_pii (
936940 sentry_init , capture_events , send_default_pii , include_prompts
937941):
938942 sentry_init (
@@ -966,10 +970,109 @@ def test_embeddings_create(
966970 assert tx ["type" ] == "transaction"
967971 span = tx ["spans" ][0 ]
968972 assert span ["op" ] == "gen_ai.embeddings"
969- if send_default_pii and include_prompts :
970- assert "hello" in span ["data" ][SPANDATA .GEN_AI_EMBEDDINGS_INPUT ]
973+
974+ assert SPANDATA .GEN_AI_EMBEDDINGS_INPUT not in span ["data" ]
975+
976+ assert span ["data" ]["gen_ai.usage.input_tokens" ] == 20
977+ assert span ["data" ]["gen_ai.usage.total_tokens" ] == 30
978+
979+
@pytest.mark.parametrize(
    "embeddings_input",
    [
        pytest.param(
            "hello",
            id="string",
        ),
        pytest.param(
            ["First text", "Second text", "Third text"],
            id="string_sequence",
        ),
        pytest.param(
            # NOTE(review): iterators are built once at collection time and
            # consumed by the instrumentation; this presumably would not
            # survive param pickling (pytest-xdist) or in-session re-runs —
            # acceptable here, but worth knowing.
            iter(["First text", "Second text", "Third text"]),
            id="string_iterable",
        ),
        pytest.param(
            [5, 8, 13, 21, 34],
            id="tokens",
        ),
        pytest.param(
            iter([5, 8, 13, 21, 34]),
            id="token_iterable",
        ),
        pytest.param(
            [[5, 8, 13, 21, 34], [8, 13, 21, 34, 55]],
            id="tokens_sequence",
        ),
        pytest.param(
            iter([[5, 8, 13, 21, 34], [8, 13, 21, 34, 55]]),
            id="tokens_sequence_iterable",
        ),
    ],
)
def test_embeddings_create(sentry_init, capture_events, embeddings_input, request):
    """With PII enabled, the integration records the embeddings input on the
    span as JSON, normalized to a list regardless of whether the caller passed
    a single string, a sequence/iterable of strings, a list of token ids, or
    (an iterable of) token-id sequences.
    """
    sentry_init(
        integrations=[OpenAIIntegration(include_prompts=True)],
        traces_sample_rate=1.0,
        send_default_pii=True,
    )
    events = capture_events()

    client = OpenAI(api_key="z")

    returned_embedding = CreateEmbeddingResponse(
        data=[Embedding(object="embedding", index=0, embedding=[1.0, 2.0, 3.0])],
        model="some-model",
        object="list",
        usage=EmbeddingTokenUsage(
            prompt_tokens=20,
            total_tokens=30,
        ),
    )

    # Stub the transport so no real network request is made.
    client.embeddings._post = mock.Mock(return_value=returned_embedding)
    with start_transaction(name="openai tx"):
        response = client.embeddings.create(
            input=embeddings_input, model="text-embedding-3-large"
        )

    assert len(response.data[0].embedding) == 3

    tx = events[0]
    assert tx["type"] == "transaction"
    span = tx["spans"][0]
    assert span["op"] == "gen_ai.embeddings"

    # Map each param id to the normalized input the span must carry. An
    # explicit mapping (instead of an if/elif chain with a catch-all else)
    # makes a newly added param id fail loudly with a KeyError rather than
    # being silently asserted against the wrong expectation.
    expected_inputs = {
        "string": ["hello"],
        "string_sequence": ["First text", "Second text", "Third text"],
        "string_iterable": ["First text", "Second text", "Third text"],
        "tokens": [5, 8, 13, 21, 34],
        "token_iterable": [5, 8, 13, 21, 34],
        "tokens_sequence": [[5, 8, 13, 21, 34], [8, 13, 21, 34, 55]],
        "tokens_sequence_iterable": [[5, 8, 13, 21, 34], [8, 13, 21, 34, 55]],
    }
    param_id = request.node.callspec.id
    assert (
        json.loads(span["data"][SPANDATA.GEN_AI_EMBEDDINGS_INPUT])
        == expected_inputs[param_id]
    )

    assert span["data"]["gen_ai.usage.input_tokens"] == 20
    assert span["data"]["gen_ai.usage.total_tokens"] == 30
@@ -978,9 +1081,13 @@ def test_embeddings_create(
9781081@pytest .mark .asyncio
9791082@pytest .mark .parametrize (
9801083 "send_default_pii, include_prompts" ,
981- [(True , True ), (True , False ), (False , True ), (False , False )],
1084+ [
1085+ (True , False ),
1086+ (False , True ),
1087+ (False , False ),
1088+ ],
9821089)
983- async def test_embeddings_create_async (
1090+ async def test_embeddings_create_async_no_pii (
9841091 sentry_init , capture_events , send_default_pii , include_prompts
9851092):
9861093 sentry_init (
@@ -1014,10 +1121,112 @@ async def test_embeddings_create_async(
10141121 assert tx ["type" ] == "transaction"
10151122 span = tx ["spans" ][0 ]
10161123 assert span ["op" ] == "gen_ai.embeddings"
1017- if send_default_pii and include_prompts :
1018- assert "hello" in span ["data" ][SPANDATA .GEN_AI_EMBEDDINGS_INPUT ]
1124+
1125+ assert SPANDATA .GEN_AI_EMBEDDINGS_INPUT not in span ["data" ]
1126+
1127+ assert span ["data" ]["gen_ai.usage.input_tokens" ] == 20
1128+ assert span ["data" ]["gen_ai.usage.total_tokens" ] == 30
1129+
1130+
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "embeddings_input",
    [
        pytest.param(
            "hello",
            id="string",
        ),
        pytest.param(
            ["First text", "Second text", "Third text"],
            id="string_sequence",
        ),
        pytest.param(
            # NOTE(review): iterators are built once at collection time and
            # consumed by the instrumentation; this presumably would not
            # survive param pickling (pytest-xdist) or in-session re-runs —
            # acceptable here, but worth knowing.
            iter(["First text", "Second text", "Third text"]),
            id="string_iterable",
        ),
        pytest.param(
            [5, 8, 13, 21, 34],
            id="tokens",
        ),
        pytest.param(
            iter([5, 8, 13, 21, 34]),
            id="token_iterable",
        ),
        pytest.param(
            [[5, 8, 13, 21, 34], [8, 13, 21, 34, 55]],
            id="tokens_sequence",
        ),
        pytest.param(
            iter([[5, 8, 13, 21, 34], [8, 13, 21, 34, 55]]),
            id="tokens_sequence_iterable",
        ),
    ],
)
async def test_embeddings_create_async(
    sentry_init, capture_events, embeddings_input, request
):
    """Async mirror of ``test_embeddings_create``: with PII enabled, the
    integration records the embeddings input on the span as JSON, normalized
    to a list for every supported input shape.
    """
    sentry_init(
        integrations=[OpenAIIntegration(include_prompts=True)],
        traces_sample_rate=1.0,
        send_default_pii=True,
    )
    events = capture_events()

    client = AsyncOpenAI(api_key="z")

    returned_embedding = CreateEmbeddingResponse(
        data=[Embedding(object="embedding", index=0, embedding=[1.0, 2.0, 3.0])],
        model="some-model",
        object="list",
        usage=EmbeddingTokenUsage(
            prompt_tokens=20,
            total_tokens=30,
        ),
    )

    # Stub the transport so no real network request is made.
    client.embeddings._post = AsyncMock(return_value=returned_embedding)
    with start_transaction(name="openai tx"):
        response = await client.embeddings.create(
            input=embeddings_input, model="text-embedding-3-large"
        )

    assert len(response.data[0].embedding) == 3

    tx = events[0]
    assert tx["type"] == "transaction"
    span = tx["spans"][0]
    assert span["op"] == "gen_ai.embeddings"

    # Map each param id to the normalized input the span must carry. An
    # explicit mapping (instead of an if/elif chain with a catch-all else)
    # makes a newly added param id fail loudly with a KeyError rather than
    # being silently asserted against the wrong expectation.
    expected_inputs = {
        "string": ["hello"],
        "string_sequence": ["First text", "Second text", "Third text"],
        "string_iterable": ["First text", "Second text", "Third text"],
        "tokens": [5, 8, 13, 21, 34],
        "token_iterable": [5, 8, 13, 21, 34],
        "tokens_sequence": [[5, 8, 13, 21, 34], [8, 13, 21, 34, 55]],
        "tokens_sequence_iterable": [[5, 8, 13, 21, 34], [8, 13, 21, 34, 55]],
    }
    param_id = request.node.callspec.id
    assert (
        json.loads(span["data"][SPANDATA.GEN_AI_EMBEDDINGS_INPUT])
        == expected_inputs[param_id]
    )

    assert span["data"]["gen_ai.usage.input_tokens"] == 20
    assert span["data"]["gen_ai.usage.total_tokens"] == 30
0 commit comments