From a01deb59571dc912fbe5dc3907d56ce9578f0904 Mon Sep 17 00:00:00 2001 From: Paul Van Eck Date: Fri, 30 Jan 2026 03:40:34 +0000 Subject: [PATCH 1/3] [Corehttp] Add misc updates This pulls in applicable changes that were made in azure-core but not yet applied to corehttp. Signed-off-by: Paul Van Eck --- sdk/core/corehttp/CHANGELOG.md | 11 + sdk/core/corehttp/corehttp/credentials.py | 19 +- .../instrumentation/tracing/opentelemetry.py | 43 +++- sdk/core/corehttp/corehttp/paging.py | 13 + sdk/core/corehttp/corehttp/rest/_aiohttp.py | 16 +- .../rest/_http_response_impl_async.py | 2 + .../corehttp/corehttp/rest/_requests_basic.py | 12 +- sdk/core/corehttp/corehttp/rest/_rest_py3.py | 4 +- .../corehttp/runtime/pipeline/__init__.py | 4 +- .../corehttp/runtime/pipeline/_tools.py | 2 +- .../runtime/policies/_authentication.py | 26 +- .../runtime/policies/_authentication_async.py | 18 +- .../runtime/policies/_distributed_tracing.py | 13 + .../corehttp/runtime/policies/_retry.py | 17 +- .../corehttp/runtime/policies/_retry_async.py | 11 +- .../corehttp/runtime/policies/_universal.py | 11 + sdk/core/corehttp/corehttp/serialization.py | 238 +++++++++++++++++- sdk/core/corehttp/corehttp/transport/_base.py | 4 +- .../corehttp/transport/_base_async.py | 4 +- .../corehttp/transport/aiohttp/_aiohttp.py | 53 ++-- .../transport/requests/_requests_basic.py | 43 ++-- sdk/core/corehttp/corehttp/utils/_utils.py | 3 +- .../tests/async_tests/test_transport_async.py | 121 +++++++++ .../corehttp/tests/test_stream_generator.py | 5 +- sdk/core/corehttp/tests/test_tracer_otel.py | 72 +++++- sdk/core/corehttp/tests/test_transport.py | 97 +++++++ 26 files changed, 771 insertions(+), 91 deletions(-) create mode 100644 sdk/core/corehttp/tests/async_tests/test_transport_async.py create mode 100644 sdk/core/corehttp/tests/test_transport.py diff --git a/sdk/core/corehttp/CHANGELOG.md b/sdk/core/corehttp/CHANGELOG.md index 74661c7c2a53..8b30d36ab798 100644 --- a/sdk/core/corehttp/CHANGELOG.md +++ b/sdk/core/corehttp/CHANGELOG.md @@ -12,11 +12,22 @@ - `DistributedHttpTracingPolicy` and `distributed_trace`/`distributed_trace_async` decorators were added to support OpenTelemetry tracing for SDK operations. - SDK clients can define an `_instrumentation_config` class variable to configure the OpenTelemetry tracer used in method span creation. Possible configuration options are `library_name`, `library_version`, `schema_url`, and `attributes`. - Added a global settings object, `corehttp.settings`, to the `corehttp` package. This object can be used to set global settings for the `corehttp` package. Currently the only setting is `tracing_enabled` for enabling/disabling tracing. [#39172](https://github.com/Azure/azure-sdk-for-python/pull/39172) +- Added `start_time` and `context` keyword arguments to `OpenTelemetryTracer.start_span` and `start_as_current_span` methods. +- Added `set_span_error_status` static method to `OpenTelemetryTracer` for setting a span's status to ERROR. +- Added `is_generated_model`, `attribute_list`, and `TypeHandlerRegistry` to `corehttp.serialization` module for SDK model handling. ### Breaking Changes ### Bugs Fixed +- Fixed `retry_backoff_max` being ignored in retry policies when configuring retries. +- Raise correct exception if transport is used while already closed. +- A timeout error when using the `aiohttp` transport will now be raised as a `corehttp.exceptions.ServiceResponseTimeoutError`, a subtype of the previously raised `ServiceResponseError`. +- When using with `aiohttp` 3.10 or later, a connection timeout error will now be raised as a `corehttp.exceptions.ServiceRequestTimeoutError`, which can be retried. +- Fixed leaked requests and aiohttp exceptions for streamed responses. +- Improved granularity of `ServiceRequestError` and `ServiceResponseError` exceptions raised in timeout scenarios from the requests and aiohttp transports. +- `BearerTokenCredentialPolicy` and `AsyncBearerTokenCredentialPolicy` will now properly chain exceptions raised during claims challenge handling. If a credential raises an exception when attempting to acquire a token in response to a claims challenge, that exception will be raised with the original 401 response as the cause. + ### Other Changes - Added `opentelemetry-api` as an optional dependency for tracing. [#39172](https://github.com/Azure/azure-sdk-for-python/pull/39172) diff --git a/sdk/core/corehttp/corehttp/credentials.py b/sdk/core/corehttp/corehttp/credentials.py index 6b31a4557ac9..b913b08e5fd5 100644 --- a/sdk/core/corehttp/corehttp/credentials.py +++ b/sdk/core/corehttp/corehttp/credentials.py @@ -68,7 +68,11 @@ def get_token_info(self, *scopes: str, options: Optional[TokenRequestOptions] = ... def close(self) -> None: - pass + """Close the credential, releasing any resources it holds. + + :return: None + :rtype: None + """ class ServiceNamedKey(NamedTuple): @@ -93,7 +97,7 @@ class ServiceKeyCredential: It provides the ability to update the key without creating a new client. :param str key: The key used to authenticate to a service - :raises: TypeError + :raises TypeError: If the key is not a string. """ def __init__(self, key: str) -> None: @@ -117,7 +121,8 @@ def update(self, key: str) -> None: to update long-lived clients. :param str key: The key used to authenticate to a service - :raises: ValueError or TypeError + :raises ValueError: If the key is None or empty. + :raises TypeError: If the key is not a string. """ if not key: raise ValueError("The key used for updating can not be None or empty") @@ -132,7 +137,7 @@ class ServiceNamedKeyCredential: :param str name: The name of the credential used to authenticate to a service. :param str key: The key used to authenticate to a service. - :raises: TypeError + :raises TypeError: If the name or key is not a string. """ def __init__(self, name: str, key: str) -> None: @@ -180,7 +185,11 @@ async def get_token_info(self, *scopes: str, options: Optional[TokenRequestOptio ... async def close(self) -> None: - pass + """Close the credential, releasing any resources. + + :return: None + :rtype: None + """ async def __aexit__( self, diff --git a/sdk/core/corehttp/corehttp/instrumentation/tracing/opentelemetry.py b/sdk/core/corehttp/corehttp/instrumentation/tracing/opentelemetry.py index 2876dfdd0589..01c82fba217e 100644 --- a/sdk/core/corehttp/corehttp/instrumentation/tracing/opentelemetry.py +++ b/sdk/core/corehttp/corehttp/instrumentation/tracing/opentelemetry.py @@ -5,13 +5,14 @@ from __future__ import annotations from contextlib import contextmanager from contextvars import Token -from typing import Optional, Dict, Sequence, cast, Callable, Iterator, TYPE_CHECKING +from typing import Any, Optional, Dict, Sequence, cast, Callable, Iterator, TYPE_CHECKING from opentelemetry import context as otel_context_module, trace from opentelemetry.trace import ( Span, SpanKind as OpenTelemetrySpanKind, Link as OpenTelemetryLink, + StatusCode, ) from opentelemetry.trace.propagation import get_current_span as get_current_span_otel from opentelemetry.propagate import extract, inject @@ -80,6 +81,8 @@ def start_span( kind: SpanKind = _SpanKind.INTERNAL, attributes: Optional[Attributes] = None, links: Optional[Sequence[Link]] = None, + start_time: Optional[int] = None, + context: Optional[Dict[str, Any]] = None, ) -> Span: """Starts a span without setting it as the current span in the context. @@ -91,17 +94,29 @@ def start_span( :paramtype attributes: Mapping[str, AttributeValue] :keyword links: Links to add to the span. :paramtype links: list[~corehttp.instrumentation.tracing.Link] + :keyword start_time: The start time of the span in nanoseconds since the epoch. + :paramtype start_time: Optional[int] + :keyword context: A dictionary of context values corresponding to the parent span. If not provided, + the current global context will be used. + :paramtype context: Optional[Dict[str, any]] :return: The span that was started :rtype: ~opentelemetry.trace.Span """ otel_kind = _KIND_MAPPINGS.get(kind, OpenTelemetrySpanKind.INTERNAL) otel_links = self._parse_links(links) + otel_context = None + if context: + otel_context = extract(context) + otel_span = self._tracer.start_span( name, + context=otel_context, kind=otel_kind, attributes=attributes, links=otel_links, + start_time=start_time, + record_exception=False, ) return otel_span @@ -114,12 +129,12 @@ def start_as_current_span( kind: SpanKind = _SpanKind.INTERNAL, attributes: Optional[Attributes] = None, links: Optional[Sequence[Link]] = None, + start_time: Optional[int] = None, + context: Optional[Dict[str, Any]] = None, end_on_exit: bool = True, ) -> Iterator[Span]: """Context manager that starts a span and sets it as the current span in the context. - Exiting the context manager will call the span's end method. - .. code:: python with tracer.start_as_current_span("span_name") as span: @@ -134,12 +149,19 @@ def start_as_current_span( :paramtype attributes: Optional[Attributes] :keyword links: Links to add to the span. :paramtype links: Optional[Sequence[Link]] + :keyword start_time: The start time of the span in nanoseconds since the epoch. + :paramtype start_time: Optional[int] + :keyword context: A dictionary of context values corresponding to the parent span. If not provided, + the current global context will be used. + :paramtype context: Optional[Dict[str, any]] :keyword end_on_exit: Whether to end the span when exiting the context manager. Defaults to True. :paramtype end_on_exit: bool :return: The span that was started - :rtype: ~opentelemetry.trace.Span + :rtype: Iterator[~opentelemetry.trace.Span] """ - span = self.start_span(name, kind=kind, attributes=attributes, links=links) + span = self.start_span( + name, kind=kind, attributes=attributes, links=links, start_time=start_time, context=context + ) with trace.use_span( # pylint: disable=not-context-manager span, record_exception=False, end_on_exit=end_on_exit ) as span: @@ -162,6 +184,17 @@ def use_span(cls, span: Span, *, end_on_exit: bool = True) -> Iterator[Span]: ) as active_span: yield active_span + @staticmethod + def set_span_error_status(span: Span, description: Optional[str] = None) -> None: + """Set the status of a span to ERROR with the provided description, if any. + + :param span: The span to set the ERROR status on. + :type span: ~opentelemetry.trace.Span + :param description: An optional description of the error. + :type description: str + """ + span.set_status(StatusCode.ERROR, description=description) + def _parse_links(self, links: Optional[Sequence[Link]]) -> Optional[Sequence[OpenTelemetryLink]]: if not links: return None diff --git a/sdk/core/corehttp/corehttp/paging.py b/sdk/core/corehttp/corehttp/paging.py index 438a75d76592..69498d027142 100644 --- a/sdk/core/corehttp/corehttp/paging.py +++ b/sdk/core/corehttp/corehttp/paging.py @@ -80,6 +80,13 @@ def __iter__(self) -> Iterator[Iterator[ReturnType]]: return self def __next__(self) -> Iterator[ReturnType]: + """Get the next page in the iterator. + + :returns: An iterator of objects in the next page. + :rtype: iterator[ReturnType] + :raises StopIteration: If there are no more pages to return. + :raises ~corehttp.exceptions.BaseError: If the request to get the next page fails. + """ if self.continuation_token is None and self._did_a_call_already: raise StopIteration("End of paging") try: @@ -129,6 +136,12 @@ def __iter__(self) -> Iterator[ReturnType]: return self def __next__(self) -> ReturnType: + """Get the next item in the iterator. + + :returns: The next item in the iterator. + :rtype: ReturnType + :raises StopIteration: If there are no more items to return. + """ if self._page_iterator is None: self._page_iterator = itertools.chain.from_iterable(self.by_page()) return next(self._page_iterator) diff --git a/sdk/core/corehttp/corehttp/rest/_aiohttp.py b/sdk/core/corehttp/corehttp/rest/_aiohttp.py index a9efbe0f8244..78e26e572aa5 100644 --- a/sdk/core/corehttp/corehttp/rest/_aiohttp.py +++ b/sdk/core/corehttp/corehttp/rest/_aiohttp.py @@ -38,6 +38,7 @@ ServiceRequestError, ServiceResponseError, IncompleteReadError, + ServiceResponseTimeoutError, ) from ..runtime.pipeline import AsyncPipeline from ..transport._base_async import _ResponseStopIteration @@ -224,7 +225,18 @@ async def read(self) -> bytes: """ if not self._content: self._stream_download_check() - self._content = await self._internal_response.read() + try: + self._content = await self._internal_response.read() + except aiohttp.client_exceptions.ClientPayloadError as err: + # This is the case that server closes connection before we finish the reading. aiohttp library + # raises ClientPayloadError. + raise IncompleteReadError(err, error=err) from err + except aiohttp.client_exceptions.ClientResponseError as err: + raise ServiceResponseError(err, error=err) from err + except asyncio.TimeoutError as err: + raise ServiceResponseTimeoutError(err, error=err) from err + except aiohttp.client_exceptions.ClientError as err: + raise ServiceRequestError(err, error=err) from err await self._set_read_checks() return _aiohttp_content_helper(self) @@ -306,7 +318,7 @@ async def __anext__(self): except aiohttp.client_exceptions.ClientResponseError as err: raise ServiceResponseError(err, error=err) from err except asyncio.TimeoutError as err: - raise ServiceResponseError(err, error=err) from err + raise ServiceResponseTimeoutError(err, error=err) from err except aiohttp.client_exceptions.ClientError as err: raise ServiceRequestError(err, error=err) from err except Exception: diff --git a/sdk/core/corehttp/corehttp/rest/_http_response_impl_async.py b/sdk/core/corehttp/corehttp/rest/_http_response_impl_async.py index e91fa7c6108c..66a5e4868c5b 100644 --- a/sdk/core/corehttp/corehttp/rest/_http_response_impl_async.py +++ b/sdk/core/corehttp/corehttp/rest/_http_response_impl_async.py @@ -69,6 +69,7 @@ async def read(self) -> bytes: async def iter_raw(self, **kwargs: Any) -> AsyncIterator[bytes]: """Asynchronously iterates over the response's bytes. Will not decompress in the process + :return: An async iterator of bytes from the response :rtype: AsyncIterator[bytes] """ @@ -79,6 +80,7 @@ async def iter_raw(self, **kwargs: Any) -> AsyncIterator[bytes]: async def iter_bytes(self, **kwargs: Any) -> AsyncIterator[bytes]: """Asynchronously iterates over the response's bytes. Will decompress in the process + :return: An async iterator of bytes from the response :rtype: AsyncIterator[bytes] """ diff --git a/sdk/core/corehttp/corehttp/rest/_requests_basic.py b/sdk/core/corehttp/corehttp/rest/_requests_basic.py index dd98a9ccf69a..6adf4ceada79 100644 --- a/sdk/core/corehttp/corehttp/rest/_requests_basic.py +++ b/sdk/core/corehttp/corehttp/rest/_requests_basic.py @@ -38,8 +38,8 @@ from ..runtime.pipeline import Pipeline from ._http_response_impl import _HttpResponseBaseImpl, HttpResponseImpl from ..exceptions import ( - ServiceRequestError, ServiceResponseError, + ServiceResponseTimeoutError, IncompleteReadError, HttpResponseError, DecodeError, @@ -162,6 +162,14 @@ def __next__(self): _LOGGER.warning("Unable to stream download.") internal_response.close() raise HttpResponseError(err, error=err) from err + except requests.ConnectionError as err: + internal_response.close() + if err.args and isinstance(err.args[0], ReadTimeoutError): + raise ServiceResponseTimeoutError(err, error=err) from err + raise ServiceResponseError(err, error=err) from err + except requests.RequestException as err: + internal_response.close() + raise ServiceResponseError(err, error=err) from err except Exception as err: _LOGGER.warning("Unable to stream download.") internal_response.close() @@ -178,7 +186,7 @@ def _read_raw_stream(response, chunk_size=1): except CoreDecodeError as e: raise DecodeError(e, error=e) from e except ReadTimeoutError as e: - raise ServiceRequestError(e, error=e) from e + raise ServiceResponseTimeoutError(e, error=e) from e else: # Standard file-like object. while True: diff --git a/sdk/core/corehttp/corehttp/rest/_rest_py3.py b/sdk/core/corehttp/corehttp/rest/_rest_py3.py index 82ecef36ff9e..6a51305ae24a 100644 --- a/sdk/core/corehttp/corehttp/rest/_rest_py3.py +++ b/sdk/core/corehttp/corehttp/rest/_rest_py3.py @@ -300,7 +300,7 @@ def json(self) -> Any: :return: The JSON deserialized response body :rtype: any - :raises json.decoder.JSONDecodeError or ValueError (in python 2.7) if object is not JSON decodable: + :raises json.decoder.JSONDecodeError: if the body is not valid JSON. """ @abc.abstractmethod @@ -309,7 +309,7 @@ def raise_for_status(self) -> None: If response is good, does nothing. - :raises ~corehttp.HttpResponseError if the object has an error status code.: + :raises ~corehttp.HttpResponseError: if the object has an error status code. """ diff --git a/sdk/core/corehttp/corehttp/runtime/pipeline/__init__.py b/sdk/core/corehttp/corehttp/runtime/pipeline/__init__.py index 9663a7d0b4d1..86c89a74595f 100644 --- a/sdk/core/corehttp/corehttp/runtime/pipeline/__init__.py +++ b/sdk/core/corehttp/corehttp/runtime/pipeline/__init__.py @@ -97,7 +97,7 @@ def __delitem__(self, key: str) -> None: def clear(self) -> None: # pylint: disable=docstring-missing-return, docstring-missing-rtype """Context objects cannot be cleared. - :raises: TypeError + :raises TypeError: If context objects cannot be cleared. """ raise TypeError("Context objects cannot be cleared.") @@ -106,7 +106,7 @@ def update( # pylint: disable=docstring-missing-return, docstring-missing-rtype ) -> None: """Context objects cannot be updated. - :raises: TypeError + :raises TypeError: If context objects cannot be updated. """ raise TypeError("Context objects cannot be updated.") diff --git a/sdk/core/corehttp/corehttp/runtime/pipeline/_tools.py b/sdk/core/corehttp/corehttp/runtime/pipeline/_tools.py index 9c3e09074501..e8d38e68837d 100644 --- a/sdk/core/corehttp/corehttp/runtime/pipeline/_tools.py +++ b/sdk/core/corehttp/corehttp/runtime/pipeline/_tools.py @@ -39,7 +39,7 @@ def await_result(func: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T: :type args: list :rtype: any :return: The result of the function - :raises: TypeError + :raises TypeError: If the function returns an awaitable object. """ result = func(*args, **kwargs) if hasattr(result, "__await__"): diff --git a/sdk/core/corehttp/corehttp/runtime/policies/_authentication.py b/sdk/core/corehttp/corehttp/runtime/policies/_authentication.py index ec3623890a68..0a1723b79fdc 100644 --- a/sdk/core/corehttp/corehttp/runtime/policies/_authentication.py +++ b/sdk/core/corehttp/corehttp/runtime/policies/_authentication.py @@ -10,7 +10,7 @@ from ...credentials import TokenRequestOptions from ...rest import HttpResponse, HttpRequest from . import HTTPPolicy, SansIOHTTPPolicy -from ...exceptions import ServiceRequestError +from ...exceptions import ServiceRequestError, HttpResponseError if TYPE_CHECKING: @@ -93,7 +93,7 @@ class BearerTokenCredentialPolicy(_BearerTokenCredentialPolicyBase, HTTPPolicy[H :param str scopes: Lets you specify the type of access needed. :keyword auth_flows: A list of authentication flows to use for the credential. :paramtype auth_flows: list[dict[str, Union[str, list[dict[str, str]]]]] - :raises: :class:`~corehttp.exceptions.ServiceRequestError` + :raises ~corehttp.exceptions.ServiceRequestError: If the request fails. """ def on_request( @@ -158,7 +158,19 @@ def send(self, request: PipelineRequest[HTTPRequestType]) -> PipelineResponse[HT if response.http_response.status_code == 401: self._token = None # any cached token is invalid if "WWW-Authenticate" in response.http_response.headers: - request_authorized = self.on_challenge(request, response) + try: + request_authorized = self.on_challenge(request, response) + except Exception as ex: + # If the response is streamed, read it so the error message is immediately available to the user. + # Otherwise, a generic error message will be given and the user will have to read the response + # body to see the actual error. + if response.context.options.get("stream"): + try: + response.http_response.read() # type: ignore + except Exception: # pylint:disable=broad-except + pass + # Raise the exception from the token request with the original 401 response. + raise ex from HttpResponseError(response=response.http_response) if request_authorized: try: response = self.next.send(request) @@ -215,7 +227,8 @@ class ServiceKeyCredentialPolicy(SansIOHTTPPolicy[HTTPRequestType, HTTPResponseT :type credential: ~corehttp.credentials.ServiceKeyCredential :param str name: The name of the key header used for the credential. :keyword str prefix: The name of the prefix for the header value if any. - :raises: ValueError or TypeError + :raises ValueError: if name is None or empty. + :raises TypeError: if name is not a string or if credential is not an instance of ServiceKeyCredential. """ def __init__( # pylint: disable=unused-argument @@ -238,4 +251,9 @@ def __init__( # pylint: disable=unused-argument self._prefix = prefix + " " if prefix else "" def on_request(self, request: PipelineRequest[HTTPRequestType]) -> None: + """Called before the policy sends a request. + + :param request: The request to be modified before sending. + :type request: ~corehttp.runtime.pipeline.PipelineRequest + """ request.http_request.headers[self._name] = f"{self._prefix}{self._credential.key}" diff --git a/sdk/core/corehttp/corehttp/runtime/policies/_authentication_async.py b/sdk/core/corehttp/corehttp/runtime/policies/_authentication_async.py index f1eee89b03c9..1797036cae7f 100644 --- a/sdk/core/corehttp/corehttp/runtime/policies/_authentication_async.py +++ b/sdk/core/corehttp/corehttp/runtime/policies/_authentication_async.py @@ -13,6 +13,7 @@ from ._base_async import AsyncHTTPPolicy from ._authentication import _BearerTokenCredentialPolicyBase from ...rest import AsyncHttpResponse, HttpRequest +from ...exceptions import HttpResponseError from ...utils._utils import get_running_async_lock if TYPE_CHECKING: @@ -66,7 +67,7 @@ async def on_request( :type request: ~corehttp.runtime.pipeline.PipelineRequest :keyword auth_flows: A list of authentication flows to use for the credential. :paramtype auth_flows: list[dict[str, Union[str, list[dict[str, str]]]]] - :raises: :class:`~corehttp.exceptions.ServiceRequestError` + :raises ~corehttp.exceptions.ServiceRequestError: If the request fails. """ # If auth_flows is an empty list, we should not attempt to authorize the request. if auth_flows is not None and len(auth_flows) == 0: @@ -123,7 +124,20 @@ async def send( if response.http_response.status_code == 401: self._token = None # any cached token is invalid if "WWW-Authenticate" in response.http_response.headers: - request_authorized = await self.on_challenge(request, response) + try: + request_authorized = await self.on_challenge(request, response) + except Exception as ex: + # If the response is streamed, read it so the error message is immediately available to the user. + # Otherwise, a generic error message will be given and the user will have to read the response + # body to see the actual error. + if response.context.options.get("stream"): + try: + await response.http_response.read() # type: ignore + except Exception: # pylint:disable=broad-except + pass + + # Raise the exception from the token request with the original 401 response + raise ex from HttpResponseError(response=response.http_response) if request_authorized: try: response = await self.next.send(request) diff --git a/sdk/core/corehttp/corehttp/runtime/policies/_distributed_tracing.py b/sdk/core/corehttp/corehttp/runtime/policies/_distributed_tracing.py index a26b63c0bc3c..50ff9466d3a4 100644 --- a/sdk/core/corehttp/corehttp/runtime/policies/_distributed_tracing.py +++ b/sdk/core/corehttp/corehttp/runtime/policies/_distributed_tracing.py @@ -55,6 +55,11 @@ def __init__( # pylint: disable=unused-argument self._instrumentation_config = instrumentation_config def on_request(self, request: PipelineRequest[HttpRequest]) -> None: + """Starts a span for the network call. + + :param request: The PipelineRequest object. + :type request: ~corehttp.runtime.pipeline.PipelineRequest + """ ctxt = request.context.options try: tracing_options: TracingOptions = ctxt.pop("tracing_options", {}) @@ -103,6 +108,13 @@ def on_response( request: PipelineRequest[HttpRequest], response: PipelineResponse[HttpRequest, SansIOHttpResponse], ) -> None: + """Ends the span for the network call and updates its status. + + :param request: The PipelineRequest object. + :type request: ~corehttp.runtime.pipeline.PipelineRequest + :param response: The PipelineResponse object. + :type response: ~corehttp.runtime.pipeline.PipelineResponse + """ if self.TRACING_CONTEXT not in request.context: return @@ -127,6 +139,7 @@ def _set_http_client_span_attributes( response: Optional[SansIOHttpResponse] = None, ) -> None: """Add attributes to an HTTP client span. + :param span: The span to add attributes to. :type span: ~opentelemetry.trace.Span :param request: The request made diff --git a/sdk/core/corehttp/corehttp/runtime/policies/_retry.py b/sdk/core/corehttp/corehttp/runtime/policies/_retry.py index 54e55a3c26b2..6397e309e583 100644 --- a/sdk/core/corehttp/corehttp/runtime/policies/_retry.py +++ b/sdk/core/corehttp/corehttp/runtime/policies/_retry.py @@ -52,6 +52,8 @@ class RetryMode(str, Enum, metaclass=CaseInsensitiveEnumMeta): + """Enum for retry modes.""" + # pylint: disable=enum-must-be-uppercase Exponential = "exponential" Fixed = "fixed" @@ -104,7 +106,7 @@ def configure_retries(self, options: Dict[str, Any]) -> Dict[str, Any]: "read": options.pop("retry_read", self.read_retries), "status": options.pop("retry_status", self.status_retries), "backoff": options.pop("retry_backoff_factor", self.backoff_factor), - "max_backoff": options.pop("retry_backoff_max", self.BACKOFF_MAX), + "max_backoff": options.pop("retry_backoff_max", self.backoff_max), "methods": options.pop("retry_on_methods", self._method_whitelist), "timeout": options.pop("timeout", self.timeout), "history": [], @@ -394,28 +396,21 @@ class RetryPolicy(RetryPolicyBase, HTTPPolicy[HttpRequest, HttpResponse]): :keyword int retry_total: Total number of retries to allow. Takes precedence over other counts. Default value is 10. - :keyword int retry_connect: How many connection-related errors to retry on. These are errors raised before the request is sent to the remote server, which we assume has not triggered the server to process the request. Default value is 3. - :keyword int retry_read: How many times to retry on read errors. These errors are raised after the request was sent to the server, so the request may have side-effects. Default value is 3. - :keyword int retry_status: How many times to retry on bad status codes. Default value is 3. - :keyword float retry_backoff_factor: A backoff factor to apply between attempts after the second try (most errors are resolved immediately by a second try without a delay). In fixed mode, retry policy will always sleep for {backoff factor}. In 'exponential' mode, retry policy will sleep for: `{backoff factor} * (2 ** ({number of total retries} - 1))` seconds. If the backoff_factor is 0.1, then the retry will sleep for [0.0s, 0.2s, 0.4s, ...] between retries. The default value is 0.8. - :keyword int retry_backoff_max: The maximum back off time. Default value is 120 seconds (2 minutes). - :keyword RetryMode retry_mode: Fixed or exponential delay between attemps, default is exponential. - :keyword int timeout: Timeout setting for the operation in seconds, default is 604800s (7 days). """ @@ -481,10 +476,10 @@ def send(self, request: PipelineRequest[HttpRequest]) -> PipelineResponse[HttpRe :param request: The PipelineRequest object :type request: ~corehttp.runtime.pipeline.PipelineRequest - :return: Returns the PipelineResponse or raises error if maximum retries exceeded. + :return: The PipelineResponse. :rtype: ~corehttp.runtime.pipeline.PipelineResponse - :raises: ~corehttp.exceptions.BaseError if maximum retries exceeded. - :raises: ~corehttp.exceptions.ClientAuthenticationError if authentication + :raises ~corehttp.exceptions.BaseError: if maximum retries exceeded. + :raises ~corehttp.exceptions.ClientAuthenticationError: if authentication fails. """ retry_active = True response = None diff --git a/sdk/core/corehttp/corehttp/runtime/policies/_retry_async.py b/sdk/core/corehttp/corehttp/runtime/policies/_retry_async.py index bd495e20c180..d482cce1ab5b 100644 --- a/sdk/core/corehttp/corehttp/runtime/policies/_retry_async.py +++ b/sdk/core/corehttp/corehttp/runtime/policies/_retry_async.py @@ -54,23 +54,18 @@ class AsyncRetryPolicy(RetryPolicyBase, AsyncHTTPPolicy[HttpRequest, AsyncHttpRe :keyword int retry_total: Total number of retries to allow. Takes precedence over other counts. Default value is 10. - :keyword int retry_connect: How many connection-related errors to retry on. These are errors raised before the request is sent to the remote server, which we assume has not triggered the server to process the request. Default value is 3. - :keyword int retry_read: How many times to retry on read errors. These errors are raised after the request was sent to the server, so the request may have side-effects. Default value is 3. - :keyword int retry_status: How many times to retry on bad status codes. Default value is 3. - :keyword float retry_backoff_factor: A backoff factor to apply between attempts after the second try (most errors are resolved immediately by a second try without a delay). Retry policy will sleep for: `{backoff factor} * (2 ** ({number of total retries} - 1))` seconds. If the backoff_factor is 0.1, then the retry will sleep for [0.0s, 0.2s, 0.4s, ...] between retries. The default value is 0.8. - :keyword int retry_backoff_max: The maximum back off time. Default value is 120 seconds (2 minutes). """ @@ -138,10 +133,10 @@ async def send(self, request: PipelineRequest[HttpRequest]) -> PipelineResponse[ :param request: The PipelineRequest object :type request: ~corehttp.runtime.pipeline.PipelineRequest - :return: Returns the PipelineResponse or raises error if maximum retries exceeded. + :return: The PipelineResponse. :rtype: ~corehttp.runtime.pipeline.PipelineResponse - :raise: ~corehttp.exceptions.BaseError if maximum retries exceeded. - :raise: ~corehttp.exceptions.ClientAuthenticationError if authentication fails + :raises ~corehttp.exceptions.BaseError: if maximum retries exceeded. + :raises ~corehttp.exceptions.ClientAuthenticationError: if authentication fails. """ retry_active = True response = None diff --git a/sdk/core/corehttp/corehttp/runtime/policies/_universal.py b/sdk/core/corehttp/corehttp/runtime/policies/_universal.py index ebb813170dd5..cce0989a3ee4 100644 --- a/sdk/core/corehttp/corehttp/runtime/policies/_universal.py +++ b/sdk/core/corehttp/corehttp/runtime/policies/_universal.py @@ -143,6 +143,7 @@ def user_agent(self) -> str: def add_user_agent(self, value: str) -> None: """Add value to current user agent with a space. + :param str value: value to add to user agent. """ self._user_agent = "{} {}".format(self._user_agent, value) @@ -401,6 +402,11 @@ def deserialize_from_http_generics( return cls.deserialize_from_text(response.text(encoding), mime_type, response=response) def on_request(self, request: PipelineRequest[HTTPRequestType]) -> None: + """Set the response encoding in the request context. + + :param request: The PipelineRequest object. + :type request: ~corehttp.runtime.pipeline.PipelineRequest + """ options = request.context.options response_encoding = options.pop("response_encoding", self._response_encoding) if response_encoding: @@ -452,6 +458,11 @@ def __init__( self.proxies = proxies def on_request(self, request: PipelineRequest[HTTPRequestType]) -> None: + """Adds the proxy information to the request context. + + :param request: The PipelineRequest object. + :type request: ~corehttp.runtime.pipeline.PipelineRequest + """ ctxt = request.context.options if self.proxies and "proxies" not in ctxt: ctxt["proxies"] = self.proxies diff --git a/sdk/core/corehttp/corehttp/serialization.py b/sdk/core/corehttp/corehttp/serialization.py index a9470c3bf37e..49c1c6066918 100644 --- a/sdk/core/corehttp/corehttp/serialization.py +++ b/sdk/core/corehttp/corehttp/serialization.py @@ -4,14 +4,23 @@ # Licensed under the MIT License. See License.txt in the project root for # license information. # -------------------------------------------------------------------------- +# pylint: disable=protected-access import base64 +from functools import partial from json import JSONEncoder -from typing import Union, cast, Any +from typing import Dict, List, Optional, Union, cast, Any, Type, Callable, Tuple from datetime import datetime, date, time, timedelta from datetime import timezone -__all__ = ["NULL", "CoreJSONEncoder"] +__all__ = [ + "NULL", + "CoreJSONEncoder", + "is_generated_model", + "attribute_list", + "TypeHandlerRegistry", +] +TZ_UTC = timezone.utc class _Null: @@ -111,10 +120,176 @@ def _datetime_as_isostr(dt: Union[datetime, date, time, timedelta]) -> str: return _timedelta_as_isostr(dt) +class TypeHandlerRegistry: + """A registry for custom serializers and deserializers for specific types or conditions.""" + + def __init__(self) -> None: + self._serializer_types: Dict[Type, Callable] = {} + self._deserializer_types: Dict[Type, Callable] = {} + self._serializer_predicates: List[Tuple[Callable[[Any], bool], Callable]] = [] + self._deserializer_predicates: List[Tuple[Callable[[Any], bool], Callable]] = [] + + self._serializer_cache: Dict[Type, Optional[Callable]] = {} + self._deserializer_cache: Dict[Type, Optional[Callable]] = {} + + def register_serializer( + self, condition: Union[Type, Callable[[Any], bool]] + ) -> Callable[[Callable[[Any], Dict[str, Any]]], Callable[[Any], Dict[str, Any]]]: + """Decorator to register a serializer. + + The handler function is expected to take a single argument, the object to serialize, + and return a dictionary representation of that object. + + Examples: + + .. code-block:: python + + @registry.register_serializer(CustomModel) + def serialize_single_type(value: CustomModel) -> dict: + return value.to_dict() + + @registry.register_serializer(lambda x: isinstance(x, BaseModel)) + def serialize_with_condition(value: BaseModel) -> dict: + return value.to_dict() + + # Called manually for a specific type + def custom_serializer(value: CustomModel) -> Dict[str, Any]: + return {"custom": value.custom} + + registry.register_serializer(CustomModel)(custom_serializer) + + :param condition: A type or a callable predicate function that takes an object and returns a bool. + :type condition: Union[Type, Callable[[Any], bool]] + :return: A decorator that registers the handler function. + :rtype: Callable[[Callable[[Any], Dict[str, Any]]], Callable[[Any], Dict[str, Any]]] + :raises TypeError: If the condition is neither a type nor a callable. + """ + + def decorator(handler_func: Callable[[Any], Dict[str, Any]]) -> Callable[[Any], Dict[str, Any]]: + if isinstance(condition, type): + self._serializer_types[condition] = handler_func + elif callable(condition): + self._serializer_predicates.append((condition, handler_func)) + else: + raise TypeError("Condition must be a type or a callable predicate function.") + + self._serializer_cache.clear() + return handler_func + + return decorator + + def register_deserializer( + self, condition: Union[Type, Callable[[Any], bool]] + ) -> Callable[[Callable[[Type, Dict[str, Any]], Any]], Callable[[Type, Dict[str, Any]], Any]]: + """Decorator to register a deserializer. + + The handler function is expected to take two arguments: the target type and the data dictionary, + and return an instance of the target type. + + Examples: + + .. code-block:: python + + @registry.register_deserializer(CustomModel) + def deserialize_single_type(cls: Type[CustomModel], data: dict) -> CustomModel: + return cls(**data) + + @registry.register_deserializer(lambda t: issubclass(t, BaseModel)) + def deserialize_with_condition(cls: Type[BaseModel], data: dict) -> BaseModel: + return cls(**data) + + # Called manually for a specific type + def custom_deserializer(cls: Type[CustomModel], data: Dict[str, Any]) -> CustomModel: + return cls(custom=data["custom"]) + + registry.register_deserializer(CustomModel)(custom_deserializer) + + :param condition: A type or a callable predicate function that takes an object and returns a bool. + :type condition: Union[Type, Callable[[Any], bool]] + :return: A decorator that registers the handler function. + :rtype: Callable[[Callable[[Type, Dict[str, Any]], Any]], Callable[[Type, Dict[str, Any]], Any]] + :raises TypeError: If the condition is neither a type nor a callable. + """ + + def decorator(handler_func: Callable[[Type, Dict[str, Any]], Any]) -> Callable[[Type, Dict[str, Any]], Any]: + if isinstance(condition, type): + self._deserializer_types[condition] = handler_func + elif callable(condition): + self._deserializer_predicates.append((condition, handler_func)) + else: + raise TypeError("Condition must be a type or a callable predicate function.") + + self._deserializer_cache.clear() + return handler_func + + return decorator + + def get_serializer(self, obj: Any) -> Optional[Callable[[Any], Dict[str, Any]]]: + """Gets the appropriate serializer for an object. + + It first checks the type dictionary for a direct type match. + If no match is found, it iterates through the predicate list to find a match. + + Results of the lookup are cached for performance based on the object's type. + + :param obj: The object to serialize. + :type obj: any + :return: The serializer function if found, otherwise None. + :rtype: Optional[Callable[[Any], Dict[str, Any]]] + """ + obj_type = type(obj) + if obj_type in self._serializer_cache: + return self._serializer_cache[obj_type] + + handler = self._serializer_types.get(type(obj)) + if not handler: + for predicate, pred_handler in self._serializer_predicates: + if predicate(obj): + handler = pred_handler + break + + self._serializer_cache[obj_type] = handler + return handler + + def get_deserializer(self, cls: Type) -> Optional[Callable[[Dict[str, Any]], Any]]: + """Gets the appropriate deserializer for a class. + + It first checks the type dictionary for a direct type match. + If no match is found, it iterates through the predicate list to find a match. + + Results of the lookup are cached for performance based on the class. + + :param cls: The class to deserialize. + :type cls: type + :return: A deserializer function bound to the specified class that takes a dictionary and returns + an instance of that class, or None if no deserializer is found. + :rtype: Optional[Callable[[Dict[str, Any]], Any]] + """ + if cls in self._deserializer_cache: + return self._deserializer_cache[cls] + + handler = self._deserializer_types.get(cls) + if not handler: + for predicate, pred_handler in self._deserializer_predicates: + if predicate(cls): + handler = pred_handler + break + + self._deserializer_cache[cls] = partial(handler, cls) if handler else None + return self._deserializer_cache[cls] + + class CoreJSONEncoder(JSONEncoder): """A JSON encoder that's capable of serializing datetime objects and bytes.""" def default(self, o: Any) -> Any: + """Override the default method to handle datetime and bytes serialization. + + :param o: The object to serialize. + :type o: Any + :return: A JSON-serializable representation of the object. + :rtype: Any + """ if isinstance(o, (bytes, bytearray)): return base64.b64encode(o).decode() try: @@ -122,3 +297,62 @@ def default(self, o: Any) -> Any: except AttributeError: pass return super(CoreJSONEncoder, self).default(o) + + +def is_generated_model(obj: Any) -> bool: + """Check if the object is a generated SDK model. + + :param obj: The object to check. + :type obj: any + :return: True if the object is a generated SDK model, False otherwise. + :rtype: bool + """ + return bool(getattr(obj, "_is_model", False) or hasattr(obj, "_attribute_map")) + + +def _get_flattened_attribute(obj: Any) -> Optional[str]: + """Get the name of the flattened attribute in a generated TypeSpec model if one exists. + + :param any obj: The object to check. + :return: The name of the flattened attribute if it exists, otherwise None. + :rtype: Optional[str] + """ + flattened_items = None + try: + flattened_items = getattr(obj, next(a for a in dir(obj) if "__flattened_items" in a), None) + except StopIteration: + return None + + if flattened_items is None: + return None + + for k, v in obj._attr_to_rest_field.items(): + try: + if set(v._class_type._attr_to_rest_field.keys()).intersection(set(flattened_items)): + return k + except AttributeError: + # if the attribute does not have _class_type, it is not a typespec generated model + continue + return None + + +def attribute_list(obj: Any) -> List[str]: + """Get a list of attribute names for a generated SDK model. + + :param obj: The object to get attributes from. + :type obj: any + :return: A list of attribute names. + :rtype: List[str] + """ + if not is_generated_model(obj): + raise TypeError("Object is not a generated SDK model.") + if hasattr(obj, "_attribute_map"): + return list(obj._attribute_map.keys()) + flattened_attribute = _get_flattened_attribute(obj) + retval: List[str] = [] + for attr_name, rest_field in obj._attr_to_rest_field.items(): + if flattened_attribute == attr_name: + retval.extend(attribute_list(rest_field._class_type)) + else: + retval.append(attr_name) + return retval diff --git a/sdk/core/corehttp/corehttp/transport/_base.py b/sdk/core/corehttp/corehttp/transport/_base.py index f210ae95952c..dbcabc352547 100644 --- a/sdk/core/corehttp/corehttp/transport/_base.py +++ b/sdk/core/corehttp/corehttp/transport/_base.py @@ -80,10 +80,8 @@ def _handle_non_stream_rest_response(response: HttpResponse) -> None: """ try: response.read() + finally: response.close() - except Exception as exc: - response.close() - raise exc class HttpTransport(ContextManager["HttpTransport"], abc.ABC, Generic[HTTPRequestType, HTTPResponseType]): diff --git a/sdk/core/corehttp/corehttp/transport/_base_async.py b/sdk/core/corehttp/corehttp/transport/_base_async.py index 23092fdcfd89..b12d95b413c4 100644 --- a/sdk/core/corehttp/corehttp/transport/_base_async.py +++ b/sdk/core/corehttp/corehttp/transport/_base_async.py @@ -46,10 +46,8 @@ async def _handle_non_stream_rest_response(response: AsyncHttpResponse) -> None: """ try: await response.read() + finally: await response.close() - except Exception as exc: - await response.close() - raise exc class _ResponseStopIteration(Exception): diff --git a/sdk/core/corehttp/corehttp/transport/aiohttp/_aiohttp.py b/sdk/core/corehttp/corehttp/transport/aiohttp/_aiohttp.py index 6decaddf4dd8..ec74b60cbfe8 100644 --- a/sdk/core/corehttp/corehttp/transport/aiohttp/_aiohttp.py +++ b/sdk/core/corehttp/corehttp/transport/aiohttp/_aiohttp.py @@ -24,7 +24,7 @@ # # -------------------------------------------------------------------------- from __future__ import annotations -from typing import Optional, TYPE_CHECKING, Type, cast, MutableMapping +from typing import Optional, TYPE_CHECKING, Type, MutableMapping from types import TracebackType import logging @@ -34,7 +34,9 @@ from ...exceptions import ( ServiceRequestError, + ServiceRequestTimeoutError, ServiceResponseError, + ServiceResponseTimeoutError, ) from .._base_async import AsyncHttpTransport, _handle_non_stream_rest_response from .._base import _create_connection_config @@ -52,6 +54,15 @@ CONTENT_CHUNK_SIZE = 10 * 1024 _LOGGER = logging.getLogger(__name__) +try: + # ConnectionTimeoutError was only introduced in aiohttp 3.10 so we want to keep this + # backwards compatible. If client is using aiohttp <3.10, the behaviour will safely + # fall back to treating a TimeoutError as a ServiceResponseError (that wont be retried). + from aiohttp.client_exceptions import ConnectionTimeoutError +except ImportError: + + class ConnectionTimeoutError(Exception): ... # type: ignore[no-redef] + class AioHttpTransport(AsyncHttpTransport): """AioHttp HTTP sender implementation. @@ -71,6 +82,7 @@ def __init__(self, *, session: Optional[aiohttp.ClientSession] = None, session_o raise ValueError("session_owner cannot be False if no session is provided") self.connection_config = _create_connection_config(**kwargs) self._use_env_settings = kwargs.pop("use_env_settings", True) + self._has_been_opened = False async def __aenter__(self): await self.open() @@ -86,23 +98,31 @@ async def __aexit__( async def open(self): """Opens the connection.""" - if not self.session and self._session_owner: - jar = aiohttp.DummyCookieJar() - clientsession_kwargs = { - "trust_env": self._use_env_settings, - "cookie_jar": jar, - "auto_decompress": False, - } - self.session = aiohttp.ClientSession(**clientsession_kwargs) - # pyright has trouble to understand that self.session is not None, since we raised at worst in the init - self.session = cast(aiohttp.ClientSession, self.session) + if self._has_been_opened and not self.session: + raise ValueError( + "HTTP transport has already been closed. " + "You may check if you're calling a function outside of the `async with` of your client creation, " + "or if you called `await close()` on your client already." + ) + if not self.session: + if self._session_owner: + jar = aiohttp.DummyCookieJar() + clientsession_kwargs = { + "trust_env": self._use_env_settings, + "cookie_jar": jar, + "auto_decompress": False, + } + self.session = aiohttp.ClientSession(**clientsession_kwargs) + else: + raise ValueError("session_owner cannot be False and no session is available") + + self._has_been_opened = True await self.session.__aenter__() async def close(self): """Closes the connection.""" if self._session_owner and self.session: await self.session.close() - self._session_owner = False self.session = None def _build_ssl_config(self, cert, verify): @@ -221,11 +241,16 @@ async def send( ) if not stream: await _handle_non_stream_rest_response(response) - + except AttributeError as err: + if self.session is None: + raise ValueError("No session available for request.") from err + raise except aiohttp.client_exceptions.ClientResponseError as err: raise ServiceResponseError(err, error=err) from err + except ConnectionTimeoutError as err: + raise ServiceRequestTimeoutError(err, error=err) from err except asyncio.TimeoutError as err: - raise ServiceResponseError(err, error=err) from err + raise ServiceResponseTimeoutError(err, error=err) from err except aiohttp.client_exceptions.ClientError as err: raise ServiceRequestError(err, error=err) from err return response diff --git a/sdk/core/corehttp/corehttp/transport/requests/_requests_basic.py b/sdk/core/corehttp/corehttp/transport/requests/_requests_basic.py index e3e45b091f92..a708ef4c0574 100644 --- a/sdk/core/corehttp/corehttp/transport/requests/_requests_basic.py +++ b/sdk/core/corehttp/corehttp/transport/requests/_requests_basic.py @@ -24,7 +24,7 @@ # # -------------------------------------------------------------------------- import logging -from typing import Optional, Union, TypeVar, cast, MutableMapping, TYPE_CHECKING +from typing import Optional, Union, TypeVar, MutableMapping, TYPE_CHECKING from urllib3.util.retry import Retry from urllib3.exceptions import ( ProtocolError, @@ -35,7 +35,9 @@ from ...exceptions import ( ServiceRequestError, + ServiceRequestTimeoutError, ServiceResponseError, + ServiceResponseTimeoutError, IncompleteReadError, HttpResponseError, ) @@ -84,6 +86,7 @@ def __init__(self, **kwargs) -> None: raise ValueError("session_owner cannot be False if no session is provided") self.connection_config = _create_connection_config(**kwargs) self._use_env_settings = kwargs.pop("use_env_settings", True) + self._has_been_opened = False def __enter__(self) -> "RequestsTransport": self.open() @@ -106,19 +109,26 @@ def _init_session(self, session: requests.Session) -> None: session.mount(p, adapter) def open(self): - if not self.session and self._session_owner: - self.session = requests.Session() - self._init_session(self.session) - # pyright has trouble to understand that self.session is not None, since we raised at worst in the init - self.session = cast(requests.Session, self.session) + if self._has_been_opened and not self.session: + raise ValueError( + "HTTP transport has already been closed. " + "You may check if you're calling a function outside of the `with` of your client creation, " + "or if you called `close()` on your client already." + ) + if not self.session: + if self._session_owner: + self.session = requests.Session() + self._init_session(self.session) + else: + raise ValueError("session_owner cannot be False and no session is available") + self._has_been_opened = True def close(self): if self._session_owner and self.session: self.session.close() - self._session_owner = False self.session = None - def send( + def send( # pylint: disable=too-many-statements self, request: "RestHttpRequest", *, @@ -165,13 +175,18 @@ def send( ) response.raw.enforce_content_length = True - except ( - NewConnectionError, - ConnectTimeoutError, - ) as err: + except AttributeError as err: + if self.session is None: + raise ValueError("No session available for request.") from err + raise + except NewConnectionError as err: error = ServiceRequestError(err, error=err) + except ConnectTimeoutError as err: + error = ServiceRequestTimeoutError(err, error=err) + except requests.exceptions.ConnectTimeout as err: + error = ServiceRequestTimeoutError(err, error=err) except requests.exceptions.ReadTimeout as err: - error = ServiceResponseError(err, error=err) + error = ServiceResponseTimeoutError(err, error=err) except requests.exceptions.ConnectionError as err: if err.args and isinstance(err.args[0], ProtocolError): error = ServiceResponseError(err, error=err) @@ -186,7 +201,7 @@ def send( _LOGGER.warning("Unable to stream download.") error = HttpResponseError(err, error=err) except requests.RequestException as err: - error = ServiceRequestError(err, error=err) + error = ServiceResponseError(err, error=err) if error: raise error diff --git a/sdk/core/corehttp/corehttp/utils/_utils.py b/sdk/core/corehttp/corehttp/utils/_utils.py index 7918e4d72a3a..21bb33719d21 100644 --- a/sdk/core/corehttp/corehttp/utils/_utils.py +++ b/sdk/core/corehttp/corehttp/utils/_utils.py @@ -169,9 +169,10 @@ def get_file_items(files: "FilesType") -> Sequence[Tuple[str, "FileType"]]: def get_running_async_lock() -> AsyncContextManager: """Get a lock instance from the async library that the current context is running under. + :return: An instance of the running async library's Lock class. :rtype: AsyncContextManager - :raises: RuntimeError if the current context is not running under an async library. + :raises RuntimeError: if the current context is not running under an async library. """ try: diff --git a/sdk/core/corehttp/tests/async_tests/test_transport_async.py b/sdk/core/corehttp/tests/async_tests/test_transport_async.py new file mode 100644 index 000000000000..d91e2baf74cd --- /dev/null +++ b/sdk/core/corehttp/tests/async_tests/test_transport_async.py @@ -0,0 +1,121 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See LICENSE.txt in the project root for +# license information. +# ------------------------------------------------------------------------- +from unittest import mock +import asyncio +from packaging.version import Version + +from corehttp.rest import HttpRequest +from corehttp.transport.aiohttp import AioHttpTransport +from corehttp.runtime.pipeline import AsyncPipeline +from corehttp.exceptions import ( + ServiceResponseError, + ServiceRequestError, + ServiceRequestTimeoutError, + ServiceResponseTimeoutError, +) + +import aiohttp +import pytest + + +@pytest.mark.asyncio +async def test_already_close_with_with(caplog, port): + transport = AioHttpTransport() + + request = HttpRequest("GET", "http://localhost:{}/basic/string".format(port)) + + async with AsyncPipeline(transport) as pipeline: + await pipeline.run(request) + + # This is now closed, new requests should fail + with pytest.raises(ValueError) as err: + await transport.send(request) + assert "HTTP transport has already been closed." in str(err) + + +@pytest.mark.asyncio +async def test_already_close_manually(caplog, port): + transport = AioHttpTransport() + + request = HttpRequest("GET", "http://localhost:{}/basic/string".format(port)) + + await transport.send(request) + await transport.close() + + # This is now closed, new requests should fail + with pytest.raises(ValueError) as err: + await transport.send(request) + assert "HTTP transport has already been closed." in str(err) + + +@pytest.mark.asyncio +async def test_close_too_soon_works_fine(caplog, port): + transport = AioHttpTransport() + + request = HttpRequest("GET", "http://localhost:{}/basic/string".format(port)) + + await transport.close() + result = await transport.send(request) + + assert result # No exception is good enough here + + +@pytest.mark.asyncio +async def test_aiohttp_timeout_response(port): + async with AioHttpTransport() as transport: + + request = HttpRequest("GET", f"http://localhost:{port}/basic/string") + + with mock.patch.object( + aiohttp.ClientResponse, "start", side_effect=asyncio.TimeoutError("Too slow!") + ) as mock_method: + with pytest.raises(ServiceResponseTimeoutError) as err: + await transport.send(request) + + with pytest.raises(ServiceResponseError) as err: + await transport.send(request) + + stream_resp = HttpRequest("GET", f"http://localhost:{port}/streams/basic") + with pytest.raises(ServiceResponseTimeoutError) as err: + await transport.send(stream_resp, stream=True) + + stream_resp = await transport.send(stream_resp, stream=True) + with mock.patch.object( + aiohttp.streams.StreamReader, "read", side_effect=asyncio.TimeoutError("Too slow!") + ) as mock_method: + with pytest.raises(ServiceResponseTimeoutError) as err: + await stream_resp.read() + + +@pytest.mark.asyncio +async def test_aiohttp_timeout_request(): + async with AioHttpTransport() as transport: + transport.session._connector.connect = mock.Mock(side_effect=asyncio.TimeoutError("Too slow!")) + + request = HttpRequest("GET", f"http://localhost:12345/basic/string") + + # aiohttp 3.10 introduced separate connection timeout + if Version(aiohttp.__version__) >= Version("3.10"): + with pytest.raises(ServiceRequestTimeoutError) as err: + await transport.send(request) + + with pytest.raises(ServiceRequestError) as err: + await transport.send(request) + + stream_request = HttpRequest("GET", f"http://localhost:12345/streams/basic") + with pytest.raises(ServiceRequestTimeoutError) as err: + await transport.send(stream_request, stream=True) + + else: + with pytest.raises(ServiceResponseTimeoutError) as err: + await transport.send(request) + + with pytest.raises(ServiceResponseError) as err: + await transport.send(request) + + stream_request = HttpRequest("GET", f"http://localhost:12345/streams/basic") + with pytest.raises(ServiceResponseTimeoutError) as err: + await transport.send(stream_request, stream=True) diff --git a/sdk/core/corehttp/tests/test_stream_generator.py b/sdk/core/corehttp/tests/test_stream_generator.py index 30bbc70fec5a..b15c9440bd9b 100644 --- a/sdk/core/corehttp/tests/test_stream_generator.py +++ b/sdk/core/corehttp/tests/test_stream_generator.py @@ -9,6 +9,7 @@ from corehttp.transport import HttpTransport from corehttp.runtime.pipeline import Pipeline from corehttp.rest._requests_basic import StreamDownloadGenerator, RestRequestsTransportResponse +from corehttp.exceptions import ServiceResponseError import pytest from utils import HTTP_RESPONSES, create_http_response, create_transport_response @@ -61,7 +62,7 @@ def close(self): http_response = create_http_response(http_response, http_request, MockInternalResponse()) stream = StreamDownloadGenerator(pipeline, http_response, decompress=False) with mock.patch("time.sleep", return_value=None): - with pytest.raises(requests.exceptions.ConnectionError): + with pytest.raises(ServiceResponseError): stream.__next__() @@ -108,5 +109,5 @@ def close(self): ) downloader = response.iter_raw() - with pytest.raises(requests.exceptions.ConnectionError): + with pytest.raises(ServiceResponseError): b"".join(downloader) diff --git a/sdk/core/corehttp/tests/test_tracer_otel.py b/sdk/core/corehttp/tests/test_tracer_otel.py index 074ed50a5c05..cb54e7920c18 100644 --- a/sdk/core/corehttp/tests/test_tracer_otel.py +++ b/sdk/core/corehttp/tests/test_tracer_otel.py @@ -304,10 +304,6 @@ def test_span_exception_without_current_context(tracing_helper): finished_spans = tracing_helper.exporter.get_finished_spans() assert len(finished_spans) == 1 - assert len(finished_spans[0].events) == 1 - assert finished_spans[0].events[0].name == "exception" - assert finished_spans[0].events[0].attributes["exception.type"] == "ValueError" - assert finished_spans[0].events[0].attributes["exception.message"] == "This is an error" assert finished_spans[0].status.status_code == OtelStatusCode.ERROR @@ -327,10 +323,6 @@ def test_span_exception_exit(tracing_helper): finished_spans = tracing_helper.exporter.get_finished_spans() assert len(finished_spans) == 1 - assert len(finished_spans[0].events) == 1 - assert finished_spans[0].events[0].name == "exception" - assert finished_spans[0].events[0].attributes["exception.type"] == "ValueError" - assert finished_spans[0].events[0].attributes["exception.message"] == "This is an error" assert finished_spans[0].status.status_code == OtelStatusCode.ERROR @@ -365,3 +357,67 @@ def test_tracer_caching_different_args(): assert tracer1 is tracer2 assert tracer1 is not tracer3 + + +def test_tracer_set_span_error(tracing_helper): + """Test that the tracer's set_span_status method works correctly.""" + tracer = get_tracer() + assert tracer + + with tracer.start_as_current_span(name="ok-span") as span: + tracer.set_span_error_status(span) + + with tracer.start_as_current_span(name="ok-span") as span: + tracer.set_span_error_status(span, "This is an error") + + # Verify status on finished spans + finished_spans = tracing_helper.exporter.get_finished_spans() + assert len(finished_spans) == 2 + + assert finished_spans[0].status.status_code == OtelStatusCode.ERROR + assert finished_spans[0].status.description is None + + assert finished_spans[1].status.status_code == OtelStatusCode.ERROR + assert finished_spans[1].status.description == "This is an error" + + +def test_start_span_with_start_time(tracing_helper): + """Test that a span can be started with a custom start time.""" + tracer = get_tracer() + assert tracer + start_time = 1234567890 + with tracer.start_as_current_span(name="foo-span", start_time=start_time) as span: + assert span.start_time == start_time + finished_spans = tracing_helper.exporter.get_finished_spans() + assert len(finished_spans) == 1 + assert finished_spans[0].start_time == start_time + + span = tracer.start_span(name="foo-span", start_time=start_time) + assert span.start_time == start_time + span.end() + finished_spans = tracing_helper.exporter.get_finished_spans() + assert len(finished_spans) == 2 + assert finished_spans[1].start_time == start_time + + +def test_tracer_with_custom_context(tracing_helper): + """Test that the tracer can start a span with a custom context.""" + tracer = get_tracer() + assert tracer + + with tracer.start_as_current_span(name="foo-span") as foo_span: + foo_trace_context = tracer.get_trace_context() + + assert "traceparent" in foo_trace_context + + with tracer.start_as_current_span(name="bar-span", context=dict(foo_trace_context)) as bar_span: + pass + + finished_spans = tracing_helper.exporter.get_finished_spans() + assert len(finished_spans) == 2 + assert finished_spans[0].name == "foo-span" + assert finished_spans[1].name == "bar-span" + assert finished_spans[1].context.trace_id == finished_spans[0].context.trace_id + + # foo_span should be the parent of bar_span + assert finished_spans[1].parent.span_id == finished_spans[0].context.span_id diff --git a/sdk/core/corehttp/tests/test_transport.py b/sdk/core/corehttp/tests/test_transport.py new file mode 100644 index 000000000000..6d42a60986c7 --- /dev/null +++ b/sdk/core/corehttp/tests/test_transport.py @@ -0,0 +1,97 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See LICENSE.txt in the project root for +# license information. +# ------------------------------------------------------------------------- +import pytest +from unittest import mock +from socket import timeout as SocketTimeout + +from urllib3.util import connection as urllib_connection +from urllib3.response import HTTPResponse as UrllibResponse +from urllib3.connection import HTTPConnection as UrllibConnection + +from corehttp.rest import HttpRequest +from corehttp.transport.requests import RequestsTransport +from corehttp.runtime.pipeline import Pipeline +from corehttp.exceptions import ServiceResponseError, ServiceResponseTimeoutError, ServiceRequestTimeoutError + +import pytest + + +def test_already_close_with_with(caplog, port): + transport = RequestsTransport() + + request = HttpRequest("GET", "http://localhost:{}/basic/string".format(port)) + + with Pipeline(transport) as pipeline: + pipeline.run(request) + + # This is now closed, new requests should fail + with pytest.raises(ValueError) as err: + transport.send(request) + assert "HTTP transport has already been closed." in str(err) + + +def test_already_close_manually(caplog, port): + transport = RequestsTransport() + + request = HttpRequest("GET", "http://localhost:{}/basic/string".format(port)) + + transport.send(request) + transport.close() + + # This is now closed, new requests should fail + with pytest.raises(ValueError) as err: + transport.send(request) + assert "HTTP transport has already been closed." in str(err) + + +def test_close_too_soon_works_fine(caplog, port): + transport = RequestsTransport() + + request = HttpRequest("GET", "http://localhost:{}/basic/string".format(port)) + + transport.close() # Never opened, should work fine + result = transport.send(request) + + assert result # No exception is good enough here + + +def test_requests_timeout_response(caplog, port): + transport = RequestsTransport() + + request = HttpRequest("GET", f"http://localhost:{port}/basic/string") + + with mock.patch.object(UrllibConnection, "getresponse", side_effect=SocketTimeout) as mock_method: + with pytest.raises(ServiceResponseTimeoutError) as err: + transport.send(request, read_timeout=0.0001) + + with pytest.raises(ServiceResponseError) as err: + transport.send(request, read_timeout=0.0001) + + stream_request = HttpRequest("GET", f"http://localhost:{port}/streams/basic") + with pytest.raises(ServiceResponseTimeoutError) as err: + transport.send(stream_request, stream=True, read_timeout=0.0001) + + stream_resp = transport.send(stream_request, stream=True) + with mock.patch.object(UrllibResponse, "_handle_chunk", side_effect=SocketTimeout) as mock_method: + with pytest.raises(ServiceResponseTimeoutError) as err: + stream_resp.read() + + +def test_requests_timeout_request(caplog, port): + transport = RequestsTransport() + + request = HttpRequest("GET", f"http://localhost:{port}/basic/string") + + with mock.patch.object(urllib_connection, "create_connection", side_effect=SocketTimeout) as mock_method: + with pytest.raises(ServiceRequestTimeoutError) as err: + transport.send(request, connection_timeout=0.0001) + + with pytest.raises(ServiceRequestTimeoutError) as err: + transport.send(request, connection_timeout=0.0001) + + stream_request = HttpRequest("GET", f"http://localhost:{port}/streams/basic") + with pytest.raises(ServiceRequestTimeoutError) as err: + transport.send(stream_request, stream=True, connection_timeout=0.0001) From 51107defb4a386fd085bf39460f5250f1e77cacc Mon Sep 17 00:00:00 2001 From: Paul Van Eck Date: Fri, 30 Jan 2026 15:37:58 -0800 Subject: [PATCH 2/3] Update sdk/core/corehttp/tests/test_transport.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- sdk/core/corehttp/tests/test_transport.py | 1 - 1 file changed, 1 deletion(-) diff --git a/sdk/core/corehttp/tests/test_transport.py b/sdk/core/corehttp/tests/test_transport.py index 6d42a60986c7..a2fbf944d057 100644 --- a/sdk/core/corehttp/tests/test_transport.py +++ b/sdk/core/corehttp/tests/test_transport.py @@ -16,7 +16,6 @@ from corehttp.runtime.pipeline import Pipeline from corehttp.exceptions import ServiceResponseError, ServiceResponseTimeoutError, ServiceRequestTimeoutError -import pytest def test_already_close_with_with(caplog, port): From be0b848977c32a9d04465fde8945e515155f5452 Mon Sep 17 00:00:00 2001 From: Paul Van Eck Date: Sat, 31 Jan 2026 00:05:00 +0000 Subject: [PATCH 3/3] black formatting Signed-off-by: Paul Van Eck --- sdk/core/corehttp/tests/test_transport.py | 1 - 1 file changed, 1 deletion(-) diff --git a/sdk/core/corehttp/tests/test_transport.py b/sdk/core/corehttp/tests/test_transport.py index a2fbf944d057..a3c51719753d 100644 --- a/sdk/core/corehttp/tests/test_transport.py +++ b/sdk/core/corehttp/tests/test_transport.py @@ -17,7 +17,6 @@ from corehttp.exceptions import ServiceResponseError, ServiceResponseTimeoutError, ServiceRequestTimeoutError - def test_already_close_with_with(caplog, port): transport = RequestsTransport()