@@ -220,7 +220,7 @@ def mock_hf_chat_completion_api():
 @pytest.fixture
 def mock_hf_chat_completion_api_tools():
     # type: () -> Any
-    """Mock HuggingFace chat completion API"""
+    """Mock HuggingFace chat completion API with tool calls."""
     with responses.RequestsMock(assert_all_requests_are_fired=False) as rsps:
         model_name = "test-model"
 
@@ -328,6 +328,55 @@ def mock_hf_chat_completion_api_streaming():
         yield rsps
 
 
+@pytest.fixture
+def mock_hf_chat_completion_api_streaming_tools():
+    # type: () -> Any
+    """Mock streaming HuggingFace chat completion API with tool calls."""
+    with responses.RequestsMock(assert_all_requests_are_fired=False) as rsps:
+        model_name = "test-model"
+
+        # Mock model info endpoint
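+        # (InferenceClient fetches this first to resolve the provider mapping)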
+        rsps.add(
+            responses.GET,
+            MODEL_ENDPOINT.format(model_name=model_name),
+            json={
+                "id": model_name,
+                "pipeline_tag": "conversational",
+                "inferenceProviderMapping": {
+                    "hf-inference": {
+                        "status": "live",
+                        "providerId": model_name,
+                        "task": "conversational",
+                    }
+                },
+            },
+            status=200,
+        )
+
+        # Mock chat completion streaming endpoint
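+        # Two SSE events: the first delta carries assistant content, the second
+        # carries the tool call, token usage, and the "tool_calls" finish reason.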
+        streaming_chat_response = (
+            b'data:{"id":"xyz-123","created":1234567890,"model":"test-model-123","system_fingerprint":"fp_123","choices":[{"delta":{"role":"assistant","content":"response with tool calls follows"},"index":0,"finish_reason":null}],"usage":null}\n\n'
+            b'data:{"id":"xyz-124","created":1234567890,"model":"test-model-123","system_fingerprint":"fp_123","choices":[{"delta":{"role":"assistant","tool_calls": [{"id": "call_123","type": "function","function": {"name": "get_weather", "arguments": {"location": "Paris"}}}]},"index":0,"finish_reason":"tool_calls"}],"usage":{"prompt_tokens":183,"completion_tokens":14,"total_tokens":197}}\n\n'
+        )
+
+        rsps.add(
+            responses.POST,
+            INFERENCE_ENDPOINT.format(model_name=model_name) + "/v1/chat/completions",
+            body=streaming_chat_response,
+            status=200,
+            headers={
+                "Content-Type": "text/event-stream",
+                "Cache-Control": "no-cache",
+                "Connection": "keep-alive",
+            },
+        )
+
+        yield rsps
+
+
 @pytest.mark.parametrize("send_default_pii", [True, False])
 @pytest.mark.parametrize("include_prompts", [True, False])
 def test_text_generation(
@@ -681,5 +727,94 @@ def test_chat_completion_with_tools(
     if not send_default_pii or not include_prompts:
         assert "gen_ai.request.messages" not in expected_data
         assert "gen_ai.response.text" not in expected_data
+        assert "gen_ai.response.tool_calls" not in expected_data
+
+    assert span["data"] == expected_data
+
+
+@pytest.mark.parametrize("send_default_pii", [True, False])
+@pytest.mark.parametrize("include_prompts", [True, False])
+def test_chat_completion_streaming_with_tools(
+    sentry_init,
+    capture_events,
+    send_default_pii,
+    include_prompts,
+    mock_hf_chat_completion_api_streaming_tools,
+):
+    # type: (Any, Any, Any, Any, Any) -> None
+    sentry_init(
+        traces_sample_rate=1.0,
+        send_default_pii=send_default_pii,
+        integrations=[HuggingfaceHubIntegration(include_prompts=include_prompts)],
+    )
+    events = capture_events()
+
+    client = InferenceClient(model="test-model")
+
+    tools = [
+        {
+            "type": "function",
+            "function": {
+                "name": "get_weather",
+                "description": "Get current weather",
+                "parameters": {
+                    "type": "object",
+                    "properties": {"location": {"type": "string"}},
+                    "required": ["location"],
+                },
+            },
+        }
+    ]
+
+    with sentry_sdk.start_transaction(name="test"):
+        response = client.chat_completion(
+            messages=[{"role": "user", "content": "What is the weather in Paris?"}],
+            stream=True,
+            tools=tools,
+            tool_choice="auto",
+        )
+
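+        # Drain the stream; usage and finish_reason only arrive with the
+        # final chunk, so span data is complete only after iteration.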
+        for x in response:
+            print(x)
+
+    (transaction,) = events
+    (span,) = transaction["spans"]
+
+    assert span["op"] == "gen_ai.chat"
+    assert span["description"] == "chat test-model"
+    assert span["origin"] == "auto.ai.huggingface_hub"
+
+    expected_data = {
+        "gen_ai.operation.name": "chat",
+        "gen_ai.request.available_tools": '[{"type": "function", "function": {"name": "get_weather", "description": "Get current weather", "parameters": {"type": "object", "properties": {"location": {"type": "string"}}, "required": ["location"]}}}]',
+        "gen_ai.request.model": "test-model",
+        "gen_ai.response.finish_reasons": "tool_calls",
+        "gen_ai.response.model": "test-model-123",
+        "gen_ai.response.streaming": True,
+        "gen_ai.usage.input_tokens": 183,
+        "gen_ai.usage.output_tokens": 14,
+        "gen_ai.usage.total_tokens": 197,
+        "thread.id": mock.ANY,
+        "thread.name": mock.ANY,
+    }
+
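+    # Message and response payloads are only attached when both PII capture
+    # and include_prompts are enabled.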
+    if send_default_pii and include_prompts:
+        expected_data["gen_ai.request.messages"] = (
+            '[{"role": "user", "content": "What is the weather in Paris?"}]'
+        )
+        expected_data["gen_ai.response.text"] = "response with tool calls follows"
+        expected_data["gen_ai.response.tool_calls"] = (
+            '[{"function": {"arguments": {"location": "Paris"}, "name": "get_weather"}, "id": "call_123", "type": "function", "index": "None"}]'
+        )
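+        # Note: the tool call's missing index is serialized as the string "None".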
+
+    if not send_default_pii or not include_prompts:
+        assert "gen_ai.request.messages" not in expected_data
+        assert "gen_ai.response.text" not in expected_data
+        assert "gen_ai.response.tool_calls" not in expected_data
 
     assert span["data"] == expected_data