diff --git a/packages/http/httpx/kiota_http/httpx_request_adapter.py b/packages/http/httpx/kiota_http/httpx_request_adapter.py index 9f379fe5..161fd77b 100644 --- a/packages/http/httpx/kiota_http/httpx_request_adapter.py +++ b/packages/http/httpx/kiota_http/httpx_request_adapter.py @@ -332,6 +332,7 @@ async def send_primitive_async( await self.throw_failed_responses(response, error_map, parent_span, parent_span) if self._should_return_none(response): return None + if response_type == "bytes": return response.content # type: ignore _deserialized_span = self._start_local_tracing_span("get_root_parse_node", parent_span) @@ -425,7 +426,79 @@ async def get_root_parse_node( span.end() def _should_return_none(self, response: httpx.Response) -> bool: - return response.status_code == 204 or not bool(response.content) + """Helper function to check if the response should return None. + + Conditions: + - The response status code is 204 or 304 + - the response content is empty. + - The response status code is 301 or 302 and the location header is not present. + + Returns: + bool: True if the response should return None, False otherwise. + """ + return response.status_code == 204 or response.status_code == 304 or not bool( + response.content + ) or (not response.headers.get("location") and response.status_code in [301, 302]) + + def _is_redirect_missing_location( + self, response: httpx.Response, parent_span: trace.Span, attribute_span: trace.Span + ) -> bool: + if response.is_redirect: + if response.has_redirect_location: + return False + # Raise a more specific error if the server returned a redirect status code + # without a location header + attribute_span.set_status(trace.StatusCode.ERROR) + _throw_failed_resp_span = self._start_local_tracing_span( + "throw_failed_responses", parent_span + ) + _throw_failed_resp_span.set_attribute("status", response.status_code) + exc = APIError( + f"The server returned a redirect status code {response.status_code}" + " without a location header", + response.status_code, + response.headers, # type: ignore + ) + _throw_failed_resp_span.set_status(trace.StatusCode.ERROR, str(exc)) + attribute_span.record_exception(exc) + _throw_failed_resp_span.end() + raise exc + return True + + async def _get_error_from_response( + self, + response: httpx.Response, + error_map: dict[str, type[ParsableFactory]], + response_status_code_str: str, + response_status_code: int, + attribute_span: trace.Span, + _throw_failed_resp_span: trace.Span, + ) -> object: + error_class = None + if response_status_code_str in error_map: # Error Code 400 - <= 599 + error_class = error_map[response_status_code_str] + elif 400 <= response_status_code < 500 and "4XX" in error_map: # Error code 4XX + error_class = error_map["4XX"] + elif 500 <= response_status_code < 600 and "5XX" in error_map: # Error code 5XX + error_class = error_map["5XX"] + elif "XXX" in error_map: # Blanket case + error_class = error_map["XXX"] + + root_node = await self.get_root_parse_node( + response, _throw_failed_resp_span, _throw_failed_resp_span + ) + attribute_span.set_attribute(ERROR_BODY_FOUND_KEY, bool(root_node)) + + _get_obj_ctx = trace.set_span_in_context(_throw_failed_resp_span) + _get_obj_span = tracer.start_span("get_object_value", context=_get_obj_ctx) + + if not root_node: + return None + error = None + if error_class: + error = root_node.get_object_value(error_class) + _get_obj_span.end() + return error async def throw_failed_responses( self, @@ -434,7 +507,9 @@ async def throw_failed_responses( parent_span: trace.Span, attribute_span: trace.Span, ) -> None: - if response.is_success: + if response.is_success or response.status_code == 304: + return + if self._is_redirect_missing_location(response, parent_span, attribute_span) is False: return try: attribute_span.set_status(trace.StatusCode.ERROR) @@ -476,29 +551,14 @@ async def throw_failed_responses( raise exc _throw_failed_resp_span.set_attribute("status_message", "received_error_response") - error_class = None - if response_status_code_str in error_map: # Error Code 400 - <= 599 - error_class = error_map[response_status_code_str] - elif 400 <= response_status_code < 500 and "4XX" in error_map: # Error code 4XX - error_class = error_map["4XX"] - elif 500 <= response_status_code < 600 and "5XX" in error_map: # Error code 5XX - error_class = error_map["5XX"] - elif "XXX" in error_map: # Blanket case - error_class = error_map["XXX"] - - root_node = await self.get_root_parse_node( - response, _throw_failed_resp_span, _throw_failed_resp_span + error = await self._get_error_from_response( + response, + error_map, + response_status_code_str, + response_status_code, + attribute_span, + _throw_failed_resp_span, ) - attribute_span.set_attribute(ERROR_BODY_FOUND_KEY, bool(root_node)) - - _get_obj_ctx = trace.set_span_in_context(_throw_failed_resp_span) - _get_obj_span = tracer.start_span("get_object_value", context=_get_obj_ctx) - - if not root_node: - return None - error = None - if error_class: - error = root_node.get_object_value(error_class) if isinstance(error, APIError): error.response_headers = response_headers # type: ignore error.response_status_code = response_status_code @@ -512,7 +572,6 @@ async def throw_failed_responses( response_status_code, response_headers, # type: ignore ) - _get_obj_span.end() raise exc finally: _throw_failed_resp_span.end() diff --git a/packages/http/httpx/tests/test_httpx_request_adapter.py b/packages/http/httpx/tests/test_httpx_request_adapter.py index 20757cec..67c622e8 100644 --- a/packages/http/httpx/tests/test_httpx_request_adapter.py +++ b/packages/http/httpx/tests/test_httpx_request_adapter.py @@ -402,6 +402,54 @@ async def test_retries_on_cae_failure( request_adapter._authentication_provider.authenticate_request.assert_has_awaits(calls) +@pytest.mark.asyncio +async def test_send_primitive_async_304_no_location_header_returns_null( + request_adapter, request_info +): + mock_304_response = httpx.Response( + status_code=304, headers={"Content-Type": "application/json"} + ) + request_adapter.get_http_response_message = AsyncMock(return_value=mock_304_response) + resp = await request_adapter.get_http_response_message(request_info) + assert resp.status_code == 304 + assert "location" not in resp.headers + final_result = await request_adapter.send_primitive_async(request_info, "float", {}) + assert final_result is None + + +@pytest.mark.asyncio +async def test_send_primitive_async_301_no_location_header_throws(request_adapter, request_info): + mock_301_response = httpx.Response( + status_code=301, headers={"Content-Type": "application/json"} + ) + request_adapter.get_http_response_message = AsyncMock(return_value=mock_301_response) + resp = await request_adapter.get_http_response_message(request_info) + assert resp.status_code == 301 + assert "location" not in resp.headers + with pytest.raises(APIError) as e: + await request_adapter.send_primitive_async(request_info, "float", {}) + assert e is not None + assert e.value.response_status_code == 301 + + +@pytest.mark.asyncio +async def test_send_primitive_async_302_with_location_header_does_not_throw( + request_adapter, request_info +): + mock_302_response = httpx.Response( + status_code=302, + headers={ + "Content-Type": "application/json", + "location": "https://example.com" + } + ) + request_adapter.get_http_response_message = AsyncMock(return_value=mock_302_response) + resp = await request_adapter.get_http_response_message(request_info) + assert resp.status_code == 302 + assert "location" in resp.headers + await request_adapter.send_primitive_async(request_info, "float", {}) + + def test_httpx_request_adapter_uses_http_client_base_url(auth_provider): http_client = httpx.AsyncClient(base_url=BASE_URL) request_adapter = HttpxRequestAdapter(auth_provider, http_client=http_client)