
Commit 711cf51

Tool calls test
1 parent 06f3746 commit 711cf51

File tree

1 file changed: +133 -0 lines changed


tests/integrations/huggingface_hub/test_huggingface_hub.py

Lines changed: 133 additions & 0 deletions
@@ -216,6 +216,71 @@ def mock_hf_chat_completion_api():
         yield rsps


+@pytest.fixture
+def mock_hf_chat_completion_api_tools():
+    # type: () -> Any
+    """Mock HuggingFace chat completion API"""
+    with responses.RequestsMock(assert_all_requests_are_fired=False) as rsps:
+        model_name = "test-model"
+
+        # Mock model info endpoint
+        rsps.add(
+            responses.GET,
+            MODEL_ENDPOINT.format(model_name=model_name),
+            json={
+                "id": model_name,
+                "pipeline_tag": "conversational",
+                "inferenceProviderMapping": {
+                    "hf-inference": {
+                        "status": "live",
+                        "providerId": model_name,
+                        "task": "conversational",
+                    }
+                },
+            },
+            status=200,
+        )
+
+        # Mock chat completion endpoint
+        rsps.add(
+            responses.POST,
+            INFERENCE_ENDPOINT.format(model_name=model_name) + "/v1/chat/completions",
+            json={
+                "id": "xyz-123",
+                "created": 1234567890,
+                "model": f"{model_name}-123",
+                "system_fingerprint": "fp_123",
+                "choices": [
+                    {
+                        "index": 0,
+                        "finish_reason": "tool_calls",
+                        "message": {
+                            "role": "assistant",
+                            "tool_calls": [
+                                {
+                                    "id": "call_123",
+                                    "type": "function",
+                                    "function": {
+                                        "name": "get_weather",
+                                        "arguments": {"location": "Paris"},
+                                    },
+                                }
+                            ],
+                        },
+                    }
+                ],
+                "usage": {
+                    "completion_tokens": 8,
+                    "prompt_tokens": 10,
+                    "total_tokens": 18,
+                },
+            },
+            status=200,
+        )
+
+        yield rsps
+
+
 @pytest.fixture
 def mock_hf_chat_completion_api_streaming():
     # type: () -> Any
@@ -500,3 +565,71 @@ def test_chat_completion_api_error(
         "thread.name": mock.ANY,
     }
     assert span["data"] == expected_data
+
+
+@pytest.mark.parametrize("send_default_pii", [True, False])
+def test_chat_completion_with_tools(
+    sentry_init, capture_events, send_default_pii, mock_hf_chat_completion_api_tools
+):
+    # type: (Any, Any, Any, Any) -> None
+    sentry_init(traces_sample_rate=1.0, send_default_pii=send_default_pii)
+    events = capture_events()
+
+    client = InferenceClient(model="test-model")
+
+    tools = [
+        {
+            "type": "function",
+            "function": {
+                "name": "get_weather",
+                "description": "Get current weather",
+                "parameters": {
+                    "type": "object",
+                    "properties": {"location": {"type": "string"}},
+                    "required": ["location"],
+                },
+            },
+        }
+    ]
+
+    with sentry_sdk.start_transaction(name="test"):
+        client.chat_completion(
+            messages=[{"role": "user", "content": "What is the weather in Paris?"}],
+            tools=tools,
+            tool_choice="auto",
+        )
+
+    (transaction,) = events
+    (span,) = transaction["spans"]
+
+    assert span["op"] == "gen_ai.chat"
+    assert span["description"] == "chat test-model"
+
+    expected_data = {
+        "gen_ai.operation.name": "chat",
+        "gen_ai.request.available_tools": '[{"type": "function", "function": {"name": '
+        '"get_weather", "description": "Get current '
+        'weather", "parameters": {"type": "object", '
+        '"properties": {"location": {"type": '
+        '"string"}}, "required": ["location"]}}}]',
+        "gen_ai.request.model": "test-model",
+        "gen_ai.response.finish_reasons": "tool_calls",
+        "gen_ai.response.model": "test-model-123",
+        "gen_ai.response.tool_calls": '[{"function": {"arguments": {"location": "Paris"}, "name": "get_weather", "description": "None"}, "id": "call_123", "type": "function"}]',
+        "gen_ai.usage.input_tokens": 10,
+        "gen_ai.usage.output_tokens": 8,
+        "gen_ai.usage.total_tokens": 18,
+        "thread.id": mock.ANY,
+        "thread.name": mock.ANY,
+    }
+
+    if send_default_pii:
+        expected_data["gen_ai.request.messages"] = (
+            '[{"role": "user", "content": "What is the weather in Paris?"}]'
+        )
+
+    if not send_default_pii:
+        assert "gen_ai.request.messages" not in expected_data
+        assert "gen_ai.response.text" not in expected_data
+
+    assert span["data"] == expected_data
