 import pytest
 import responses

+import huggingface_hub
 from huggingface_hub import InferenceClient

 import sentry_sdk
@@ -69,6 +70,48 @@ def mock_hf_text_generation_api(): |
         yield rsps


+@pytest.fixture
+def mock_hf_api_with_errors():
+    # type: () -> Any
+    """Mock Hugging Face API endpoints that return an error response for every request."""
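+    # assert_all_requests_are_fired=False keeps the mock from failing a test
+    # that exercises only a subset of the endpoints registered below.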
+    with responses.RequestsMock(assert_all_requests_are_fired=False) as rsps:
+        model_name = "test-model"
+
+        # Mock model info endpoint with error
+        rsps.add(
+            responses.GET,
+            MODEL_ENDPOINT.format(model_name=model_name),
+            json={"error": "Model not found"},
+            status=404,
+        )
+
+        # Mock text generation endpoint with error
+        rsps.add(
+            responses.POST,
+            INFERENCE_ENDPOINT.format(model_name=model_name),
+            json={"error": "Internal server error", "message": "Something went wrong"},
+            status=500,
+        )
+
+        # Mock chat completion endpoint with error
+        rsps.add(
+            responses.POST,
+            INFERENCE_ENDPOINT.format(model_name=model_name) + "/v1/chat/completions",
+            json={"error": "Service unavailable", "message": "Chat completion failed"},
+            status=503,
+        )
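+        # (the /v1/chat/completions suffix mirrors the OpenAI-compatible route
+        # that InferenceClient.chat_completion posts to)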
+
+        # Mock model info endpoint for the "test-model-error" model name
+        rsps.add(
+            responses.GET,
+            "https://huggingface.co/api/models/test-model-error",
+            json={"error": "Generic model error"},
+            status=500,
+        )
+
+        yield rsps
+
+
 @pytest.fixture
 def mock_hf_text_generation_api_streaming():
     # type: () -> Any
@@ -222,9 +265,7 @@ def test_text_generation( |
     sentry_init(traces_sample_rate=1.0, send_default_pii=send_default_pii)
     events = capture_events()

-    client = InferenceClient(
-        model="test-model",
-    )
+    client = InferenceClient(model="test-model")

     with sentry_sdk.start_transaction(name="test"):
         client.text_generation(
@@ -271,9 +312,7 @@ def test_text_generation_streaming( |
     sentry_init(traces_sample_rate=1.0, send_default_pii=send_default_pii)
     events = capture_events()

-    client = InferenceClient(
-        model="test-model",
-    )
+    client = InferenceClient(model="test-model")

     with sentry_sdk.start_transaction(name="test"):
         for _ in client.text_generation(
@@ -321,9 +360,7 @@ def test_chat_completion( |
     sentry_init(traces_sample_rate=1.0, send_default_pii=send_default_pii)
     events = capture_events()

-    client = InferenceClient(
-        model="test-model",
-    )
+    client = InferenceClient(model="test-model")

     with sentry_sdk.start_transaction(name="test"):
         client.chat_completion(
@@ -373,9 +410,7 @@ def test_chat_completion_streaming( |
     sentry_init(traces_sample_rate=1.0, send_default_pii=send_default_pii)
     events = capture_events()

-    client = InferenceClient(
-        model="test-model",
-    )
+    client = InferenceClient(model="test-model")

     with sentry_sdk.start_transaction(name="test"):
         response = client.chat_completion(
@@ -418,3 +453,45 @@ def test_chat_completion_streaming( |
         assert "gen_ai.response.text" not in expected_data

     assert span["data"] == expected_data
+
+
+def test_chat_completion_api_error(
+    sentry_init, capture_events, mock_hf_api_with_errors
+):
+    # type: (Any, Any, Any) -> None
+    sentry_init(traces_sample_rate=1.0)
+    events = capture_events()
+
+    client = InferenceClient(model="test-model")
+
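+    # The mocked 5xx response should surface as an HfHubHTTPError from the
+    # client, with the Sentry integration capturing the error and still
+    # closing the span.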
+    with sentry_sdk.start_transaction(name="test"):
+        with pytest.raises(huggingface_hub.errors.HfHubHTTPError):
+            client.chat_completion(
+                messages=[{"role": "user", "content": "Hello!"}],
+            )
+
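+    # Two events are expected: the captured error first, then the transaction
+    # containing the gen_ai span.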
+    (
+        error,
+        transaction,
+    ) = events
+
+    assert error["exception"]["values"][0]["mechanism"]["type"] == "huggingface_hub"
+    assert not error["exception"]["values"][0]["mechanism"]["handled"]
+
+    (span,) = transaction["spans"]
+
+    assert span["op"] == "gen_ai.chat"
+    assert span["description"] == "chat test-model"
+    assert span.get("tags", {}).get("status") == "error"
+
+    assert (
+        error["contexts"]["trace"]["trace_id"]
+        == transaction["contexts"]["trace"]["trace_id"]
+    )
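+    # With send_default_pii left at its default (False), no request messages
+    # or response text should appear in the span data.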
+    expected_data = {
+        "gen_ai.operation.name": "chat",
+        "gen_ai.request.model": "test-model",
+        "thread.id": mock.ANY,
+        "thread.name": mock.ANY,
+    }
+    assert span["data"] == expected_data