Skip to content
Merged
109 changes: 84 additions & 25 deletions packages/http/httpx/kiota_http/httpx_request_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,6 +332,7 @@ async def send_primitive_async(
await self.throw_failed_responses(response, error_map, parent_span, parent_span)
if self._should_return_none(response):
return None

if response_type == "bytes":
return response.content # type: ignore
_deserialized_span = self._start_local_tracing_span("get_root_parse_node", parent_span)
Expand Down Expand Up @@ -425,7 +426,79 @@ async def get_root_parse_node(
span.end()

def _should_return_none(self, response: httpx.Response) -> bool:
return response.status_code == 204 or not bool(response.content)
"""Helper function to check if the response should return None.

Conditions:
- The response status code is 204 or 304
- the response content is empty.
- The response status code is 301 or 302 and the location header is not present.

Returns:
bool: True if the response should return None, False otherwise.
"""
return response.status_code == 204 or response.status_code == 304 or not bool(
response.content
) or (not response.headers.get("location") and response.status_code in [301, 302])

def _is_redirect_missing_location(
self, response: httpx.Response, parent_span: trace.Span, attribute_span: trace.Span
) -> bool:
if response.is_redirect:
if response.has_redirect_location:
return False
# Raise a more specific error if the server returned a redirect status code
# without a location header
attribute_span.set_status(trace.StatusCode.ERROR)
_throw_failed_resp_span = self._start_local_tracing_span(
"throw_failed_responses", parent_span
)
_throw_failed_resp_span.set_attribute("status", response.status_code)
exc = APIError(
f"The server returned a redirect status code {response.status_code}"
" without a location header",
response.status_code,
response.headers, # type: ignore
)
_throw_failed_resp_span.set_status(trace.StatusCode.ERROR, str(exc))
attribute_span.record_exception(exc)
_throw_failed_resp_span.end()
raise exc
return True

async def _get_error_from_response(
self,
response: httpx.Response,
error_map: dict[str, type[ParsableFactory]],
response_status_code_str: str,
response_status_code: int,
attribute_span: trace.Span,
_throw_failed_resp_span: trace.Span,
) -> object:
error_class = None
if response_status_code_str in error_map: # Error Code 400 - <= 599
error_class = error_map[response_status_code_str]
elif 400 <= response_status_code < 500 and "4XX" in error_map: # Error code 4XX
error_class = error_map["4XX"]
elif 500 <= response_status_code < 600 and "5XX" in error_map: # Error code 5XX
error_class = error_map["5XX"]
elif "XXX" in error_map: # Blanket case
error_class = error_map["XXX"]

root_node = await self.get_root_parse_node(
response, _throw_failed_resp_span, _throw_failed_resp_span
)
attribute_span.set_attribute(ERROR_BODY_FOUND_KEY, bool(root_node))

_get_obj_ctx = trace.set_span_in_context(_throw_failed_resp_span)
_get_obj_span = tracer.start_span("get_object_value", context=_get_obj_ctx)

if not root_node:
return None
error = None
if error_class:
error = root_node.get_object_value(error_class)
_get_obj_span.end()
return error

async def throw_failed_responses(
self,
Expand All @@ -434,7 +507,9 @@ async def throw_failed_responses(
parent_span: trace.Span,
attribute_span: trace.Span,
) -> None:
if response.is_success:
if response.is_success or response.status_code == 304:
return
if self._is_redirect_missing_location(response, parent_span, attribute_span) is False:
return
try:
attribute_span.set_status(trace.StatusCode.ERROR)
Expand Down Expand Up @@ -476,29 +551,14 @@ async def throw_failed_responses(
raise exc
_throw_failed_resp_span.set_attribute("status_message", "received_error_response")

error_class = None
if response_status_code_str in error_map: # Error Code 400 - <= 599
error_class = error_map[response_status_code_str]
elif 400 <= response_status_code < 500 and "4XX" in error_map: # Error code 4XX
error_class = error_map["4XX"]
elif 500 <= response_status_code < 600 and "5XX" in error_map: # Error code 5XX
error_class = error_map["5XX"]
elif "XXX" in error_map: # Blanket case
error_class = error_map["XXX"]

root_node = await self.get_root_parse_node(
response, _throw_failed_resp_span, _throw_failed_resp_span
error = await self._get_error_from_response(
response,
error_map,
response_status_code_str,
response_status_code,
attribute_span,
_throw_failed_resp_span,
)
attribute_span.set_attribute(ERROR_BODY_FOUND_KEY, bool(root_node))

_get_obj_ctx = trace.set_span_in_context(_throw_failed_resp_span)
_get_obj_span = tracer.start_span("get_object_value", context=_get_obj_ctx)

if not root_node:
return None
error = None
if error_class:
error = root_node.get_object_value(error_class)
if isinstance(error, APIError):
error.response_headers = response_headers # type: ignore
error.response_status_code = response_status_code
Expand All @@ -512,7 +572,6 @@ async def throw_failed_responses(
response_status_code,
response_headers, # type: ignore
)
_get_obj_span.end()
raise exc
finally:
_throw_failed_resp_span.end()
Expand Down
48 changes: 48 additions & 0 deletions packages/http/httpx/tests/test_httpx_request_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,6 +402,54 @@ async def test_retries_on_cae_failure(
request_adapter._authentication_provider.authenticate_request.assert_has_awaits(calls)


@pytest.mark.asyncio
async def test_send_primitive_async_304_no_location_header_returns_null(
request_adapter, request_info
):
mock_304_response = httpx.Response(
status_code=304, headers={"Content-Type": "application/json"}
)
request_adapter.get_http_response_message = AsyncMock(return_value=mock_304_response)
resp = await request_adapter.get_http_response_message(request_info)
assert resp.status_code == 304
assert "location" not in resp.headers
final_result = await request_adapter.send_primitive_async(request_info, "float", {})
assert final_result is None


@pytest.mark.asyncio
async def test_send_primitive_async_301_no_location_header_throws(request_adapter, request_info):
mock_301_response = httpx.Response(
status_code=301, headers={"Content-Type": "application/json"}
)
request_adapter.get_http_response_message = AsyncMock(return_value=mock_301_response)
resp = await request_adapter.get_http_response_message(request_info)
assert resp.status_code == 301
assert "location" not in resp.headers
with pytest.raises(APIError) as e:
await request_adapter.send_primitive_async(request_info, "float", {})
assert e is not None
assert e.value.response_status_code == 301


@pytest.mark.asyncio
async def test_send_primitive_async_302_with_location_header_does_not_throw(
request_adapter, request_info
):
mock_302_response = httpx.Response(
status_code=302,
headers={
"Content-Type": "application/json",
"location": "https://example.com"
}
)
request_adapter.get_http_response_message = AsyncMock(return_value=mock_302_response)
resp = await request_adapter.get_http_response_message(request_info)
assert resp.status_code == 302
assert "location" in resp.headers
await request_adapter.send_primitive_async(request_info, "float", {})


def test_httpx_request_adapter_uses_http_client_base_url(auth_provider):
http_client = httpx.AsyncClient(base_url=BASE_URL)
request_adapter = HttpxRequestAdapter(auth_provider, http_client=http_client)
Expand Down