Skip to content

Commit bedc050

Browse files
committed
more test coverage
1 parent 52b21c5 commit bedc050

File tree

1 file changed

+131
-1
lines changed

1 file changed

+131
-1
lines changed

tests/integrations/huggingface_hub/test_huggingface_hub.py

Lines changed: 131 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -220,7 +220,7 @@ def mock_hf_chat_completion_api():
220220
@pytest.fixture
221221
def mock_hf_chat_completion_api_tools():
222222
# type: () -> Any
223-
"""Mock HuggingFace chat completion API"""
223+
"""Mock HuggingFace chat completion API with tool calls."""
224224
with responses.RequestsMock(assert_all_requests_are_fired=False) as rsps:
225225
model_name = "test-model"
226226

@@ -328,6 +328,52 @@ def mock_hf_chat_completion_api_streaming():
328328
yield rsps
329329

330330

331+
@pytest.fixture
332+
def mock_hf_chat_completion_api_streaming_tools():
333+
# type: () -> Any
334+
"""Mock streaming HuggingFace chat completion API with tool calls."""
335+
with responses.RequestsMock(assert_all_requests_are_fired=False) as rsps:
336+
model_name = "test-model"
337+
338+
# Mock model info endpoint
339+
rsps.add(
340+
responses.GET,
341+
MODEL_ENDPOINT.format(model_name=model_name),
342+
json={
343+
"id": model_name,
344+
"pipeline_tag": "conversational",
345+
"inferenceProviderMapping": {
346+
"hf-inference": {
347+
"status": "live",
348+
"providerId": model_name,
349+
"task": "conversational",
350+
}
351+
},
352+
},
353+
status=200,
354+
)
355+
356+
# Mock chat completion streaming endpoint
357+
streaming_chat_response = (
358+
b'data:{"id":"xyz-123","created":1234567890,"model":"test-model-123","system_fingerprint":"fp_123","choices":[{"delta":{"role":"assistant","content":"response with tool calls follows"},"index":0,"finish_reason":null}],"usage":null}\n\n'
359+
b'data:{"id":"xyz-124","created":1234567890,"model":"test-model-123","system_fingerprint":"fp_123","choices":[{"delta":{"role":"assistant","tool_calls": [{"id": "call_123","type": "function","function": {"name": "get_weather", "arguments": {"location": "Paris"}}}]},"index":0,"finish_reason":"tool_calls"}],"usage":{"prompt_tokens":183,"completion_tokens":14,"total_tokens":197}}\n\n'
360+
)
361+
362+
rsps.add(
363+
responses.POST,
364+
INFERENCE_ENDPOINT.format(model_name=model_name) + "/v1/chat/completions",
365+
body=streaming_chat_response,
366+
status=200,
367+
headers={
368+
"Content-Type": "text/event-stream",
369+
"Cache-Control": "no-cache",
370+
"Connection": "keep-alive",
371+
},
372+
)
373+
374+
yield rsps
375+
376+
331377
@pytest.mark.parametrize("send_default_pii", [True, False])
332378
@pytest.mark.parametrize("include_prompts", [True, False])
333379
def test_text_generation(
@@ -681,5 +727,89 @@ def test_chat_completion_with_tools(
681727
if not send_default_pii or not include_prompts:
682728
assert "gen_ai.request.messages" not in expected_data
683729
assert "gen_ai.response.text" not in expected_data
730+
assert "gen_ai.response.tool_calls" not in expected_data
731+
732+
assert span["data"] == expected_data
733+
734+
735+
@pytest.mark.parametrize("send_default_pii", [True, False])
736+
@pytest.mark.parametrize("include_prompts", [True, False])
737+
def test_chat_completion_streaming_with_tools(
738+
sentry_init,
739+
capture_events,
740+
send_default_pii,
741+
include_prompts,
742+
mock_hf_chat_completion_api_streaming_tools,
743+
):
744+
# type: (Any, Any, Any, Any, Any) -> None
745+
sentry_init(
746+
traces_sample_rate=1.0,
747+
send_default_pii=send_default_pii,
748+
integrations=[HuggingfaceHubIntegration(include_prompts=include_prompts)],
749+
)
750+
events = capture_events()
751+
752+
client = InferenceClient(model="test-model")
753+
754+
tools = [
755+
{
756+
"type": "function",
757+
"function": {
758+
"name": "get_weather",
759+
"description": "Get current weather",
760+
"parameters": {
761+
"type": "object",
762+
"properties": {"location": {"type": "string"}},
763+
"required": ["location"],
764+
},
765+
},
766+
}
767+
]
768+
769+
with sentry_sdk.start_transaction(name="test"):
770+
response = client.chat_completion(
771+
messages=[{"role": "user", "content": "What is the weather in Paris?"}],
772+
stream=True,
773+
tools=tools,
774+
tool_choice="auto",
775+
)
776+
777+
for x in response:
778+
print(x)
779+
780+
(transaction,) = events
781+
(span,) = transaction["spans"]
782+
783+
assert span["op"] == "gen_ai.chat"
784+
assert span["description"] == "chat test-model"
785+
assert span["origin"] == "auto.ai.huggingface_hub"
786+
787+
expected_data = {
788+
"gen_ai.operation.name": "chat",
789+
"gen_ai.request.available_tools": '[{"type": "function", "function": {"name": "get_weather", "description": "Get current weather", "parameters": {"type": "object", "properties": {"location": {"type": "string"}}, "required": ["location"]}}}]',
790+
"gen_ai.request.model": "test-model",
791+
"gen_ai.response.finish_reasons": "tool_calls",
792+
"gen_ai.response.model": "test-model-123",
793+
"gen_ai.response.streaming": True,
794+
"gen_ai.usage.input_tokens": 183,
795+
"gen_ai.usage.output_tokens": 14,
796+
"gen_ai.usage.total_tokens": 197,
797+
"thread.id": mock.ANY,
798+
"thread.name": mock.ANY,
799+
}
800+
801+
if send_default_pii and include_prompts:
802+
expected_data["gen_ai.request.messages"] = (
803+
'[{"role": "user", "content": "What is the weather in Paris?"}]'
804+
)
805+
expected_data["gen_ai.response.text"] = "response with tool calls follows"
806+
expected_data["gen_ai.response.tool_calls"] = (
807+
'[{"function": {"arguments": {"location": "Paris"}, "name": "get_weather"}, "id": "call_123", "type": "function", "index": "None"}]'
808+
)
809+
810+
if not send_default_pii or not include_prompts:
811+
assert "gen_ai.request.messages" not in expected_data
812+
assert "gen_ai.response.text" not in expected_data
813+
assert "gen_ai.response.tool_calls" not in expected_data
684814

685815
assert span["data"] == expected_data

0 commit comments

Comments
 (0)