diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_blob_client.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_blob_client.py index 6c5fdd76c1a5..6ac15ffd6d61 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_blob_client.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_blob_client.py @@ -65,6 +65,7 @@ from ._quick_query_helper import BlobQueryReader from ._shared.base_client import parse_connection_str, StorageAccountHostsMixin, TransportWrapper from ._shared.response_handlers import process_storage_error, return_response_headers +from ._shared.validation import ChecksumAlgorithm, parse_validation_option from ._serialize import ( get_access_conditions, get_api_version, @@ -505,15 +506,11 @@ def upload_blob( :keyword ~azure.storage.blob.ContentSettings content_settings: ContentSettings object used to set blob properties. Used to set content type, encoding, language, disposition, md5, and cache control. - :keyword bool validate_content: - If true, calculates an MD5 hash for each chunk of the blob. The storage - service checks the hash of the content that has arrived with the hash - that was sent. This is primarily valuable for detecting bitflips on - the wire if using http instead of https, as https (the default), will - already validate. Note that this MD5 hash is not stored with the - blob. Also note that if enabled, the memory-efficient upload algorithm - will not be used because computing the MD5 hash requires buffering - entire blocks, and doing so defeats the purpose of the memory-efficient algorithm. + :keyword validate_content: + Enables checksum validation for the transfer. Any checksum calculated is NOT stored with the blob. + Choose "auto" (let the SDK choose the best algorithm), "crc64", or "md5". The use of bool is deprecated. + NOTE: The use of "auto" or "crc64" requires the `azure-storage-extensions` package to be installed. 
+ :paramtype validate_content: Union[bool, Literal['auto', 'crc64', 'md5']] :keyword lease: Required if the blob has an active lease. If specified, upload_blob only succeeds if the blob's lease is active and matches this ID. Value can be a BlobLeaseClient object @@ -616,6 +613,9 @@ def upload_blob( raise ValueError("Encryption required but no key was provided.") if kwargs.get('cpk') and self.scheme.lower() != 'https': raise ValueError("Customer provided encryption key must be used over HTTPS.") + validate_content = parse_validation_option(kwargs.pop('validate_content', None)) + if validate_content == ChecksumAlgorithm.CRC64 and self.key_encryption_key: + raise ValueError("Using encryption and content validation together is not currently supported.") options = _upload_blob_options( data=data, blob_type=blob_type, @@ -627,6 +627,7 @@ def upload_blob( 'key': self.key_encryption_key, 'resolver': self.key_resolver_function }, + validate_content=validate_content, config=self._config, sdk_moniker=self._sdk_moniker, client=self._client, @@ -683,15 +684,11 @@ def download_blob( This keyword argument was introduced in API version '2019-12-12'. - :keyword bool validate_content: - If true, calculates an MD5 hash for each chunk of the blob. The storage - service checks the hash of the content that has arrived with the hash - that was sent. This is primarily valuable for detecting bitflips on - the wire if using http instead of https, as https (the default), will - already validate. Note that this MD5 hash is not stored with the - blob. Also note that if enabled, the memory-efficient upload algorithm - will not be used because computing the MD5 hash requires buffering - entire blocks, and doing so defeats the purpose of the memory-efficient algorithm. + :keyword validate_content: + Enables checksum validation for the transfer. Any checksum calculated is NOT stored with the blob. + Choose "auto" (let the SDK choose the best algorithm), "crc64", or "md5". 
The use of bool is deprecated. + NOTE: The use of "auto" or "crc64" requires the `azure-storage-extensions` package to be installed. + :paramtype validate_content: Union[bool, Literal['auto', 'crc64', 'md5']] :keyword lease: Required if the blob has an active lease. If specified, download_blob only succeeds if the blob's lease is active and matches this ID. Value can be a @@ -765,6 +762,9 @@ def download_blob( raise ValueError("Offset value must not be None if length is set.") if kwargs.get('cpk') and self.scheme.lower() != 'https': raise ValueError("Customer provided encryption key must be used over HTTPS.") + validate_content = parse_validation_option(kwargs.pop('validate_content', None)) + if validate_content == ChecksumAlgorithm.CRC64 and self.key_encryption_key: + raise ValueError("Using encryption and content validation together is not currently supported.") options = _download_blob_options( blob_name=self.blob_name, container_name=self.container_name, @@ -778,6 +778,7 @@ def download_blob( 'key': self.key_encryption_key, 'resolver': self.key_resolver_function }, + validate_content=validate_content, config=self._config, sdk_moniker=self._sdk_moniker, client=self._client, @@ -2009,15 +2010,11 @@ def stage_block( :param int length: Size of the block. Optional if the length of data can be determined. For Iterable and IO, if the length is not provided and cannot be determined, all data will be read into memory. - :keyword bool validate_content: - If true, calculates an MD5 hash for each chunk of the blob. The storage - service checks the hash of the content that has arrived with the hash - that was sent. This is primarily valuable for detecting bitflips on - the wire if using http instead of https, as https (the default), will - already validate. Note that this MD5 hash is not stored with the - blob. 
Also note that if enabled, the memory-efficient upload algorithm - will not be used because computing the MD5 hash requires buffering - entire blocks, and doing so defeats the purpose of the memory-efficient algorithm. + :keyword validate_content: + Enables checksum validation for the transfer. Any checksum calculated is NOT stored with the blob. + Choose "auto" (let the SDK choose the best algorithm), "crc64", or "md5". The use of bool is deprecated. + NOTE: The use of "auto" or "crc64" requires the `azure-storage-extensions` package to be installed. + :paramtype validate_content: Union[bool, Literal['auto', 'crc64', 'md5']] :keyword lease: Required if the blob has an active lease. Value can be a BlobLeaseClient object or the lease ID as a string. @@ -2850,13 +2847,11 @@ def upload_page( Required if the blob has an active lease. Value can be a BlobLeaseClient object or the lease ID as a string. :paramtype lease: ~azure.storage.blob.BlobLeaseClient or str - :keyword bool validate_content: - If true, calculates an MD5 hash of the page content. The storage - service checks the hash of the content that has arrived - with the hash that was sent. This is primarily valuable for detecting - bitflips on the wire if using http instead of https, as https (the default), - will already validate. Note that this MD5 hash is not stored with the - blob. + :keyword validate_content: + Enables checksum validation for the transfer. Any checksum calculated is NOT stored with the blob. + Choose "auto" (let the SDK choose the best algorithm), "crc64", or "md5". The use of bool is deprecated. + NOTE: The use of "auto" or "crc64" requires the `azure-storage-extensions` package to be installed. + :paramtype validate_content: Union[bool, Literal['auto', 'crc64', 'md5']] :keyword int if_sequence_number_lte: If the blob's sequence number is less than or equal to the specified value, the request proceeds; otherwise it fails. 
@@ -3157,13 +3152,11 @@ def append_block( :param int length: Size of the block. Optional if the length of data can be determined. For Iterable and IO, if the length is not provided and cannot be determined, all data will be read into memory. - :keyword bool validate_content: - If true, calculates an MD5 hash of the block content. The storage - service checks the hash of the content that has arrived - with the hash that was sent. This is primarily valuable for detecting - bitflips on the wire if using http instead of https, as https (the default), - will already validate. Note that this MD5 hash is not stored with the - blob. + :keyword validate_content: + Enables checksum validation for the transfer. Any checksum calculated is NOT stored with the blob. + Choose "auto" (let the SDK choose the best algorithm), "crc64", or "md5". The use of bool is deprecated. + NOTE: The use of "auto" or "crc64" requires the `azure-storage-extensions` package to be installed. + :paramtype validate_content: Union[bool, Literal['auto', 'crc64', 'md5']] :keyword int maxsize_condition: Optional conditional header. The max length in bytes permitted for the append blob. 
If the Append Block operation would cause the blob diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_blob_client.pyi b/sdk/storage/azure-storage-blob/azure/storage/blob/_blob_client.pyi index fee0f4757f79..f5679a595ad8 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_blob_client.pyi +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_blob_client.pyi @@ -173,7 +173,7 @@ class BlobClient(StorageAccountHostsMixin, StorageEncryptionMixin): tags: Optional[Dict[str, str]] = None, overwrite: bool = False, content_settings: Optional[ContentSettings] = None, - validate_content: bool = False, + validate_content: Optional[Union[bool, Literal['auto', 'crc64', 'md5']]] = None, lease: Optional[BlobLeaseClient] = None, if_modified_since: Optional[datetime] = None, if_unmodified_since: Optional[datetime] = None, @@ -200,7 +200,7 @@ class BlobClient(StorageAccountHostsMixin, StorageEncryptionMixin): length: Optional[int] = None, *, version_id: Optional[str] = None, - validate_content: bool = False, + validate_content: Optional[Union[bool, Literal['auto', 'crc64', 'md5']]] = None, lease: Optional[Union[BlobLeaseClient, str]] = None, if_modified_since: Optional[datetime] = None, if_unmodified_since: Optional[datetime] = None, @@ -222,7 +222,7 @@ class BlobClient(StorageAccountHostsMixin, StorageEncryptionMixin): length: Optional[int] = None, *, version_id: Optional[str] = None, - validate_content: bool = False, + validate_content: Optional[Union[bool, Literal['auto', 'crc64', 'md5']]] = None, lease: Optional[Union[BlobLeaseClient, str]] = None, if_modified_since: Optional[datetime] = None, if_unmodified_since: Optional[datetime] = None, @@ -244,7 +244,7 @@ class BlobClient(StorageAccountHostsMixin, StorageEncryptionMixin): length: Optional[int] = None, *, version_id: Optional[str] = None, - validate_content: bool = False, + validate_content: Optional[Union[bool, Literal['auto', 'crc64', 'md5']]] = None, lease: Optional[Union[BlobLeaseClient, str]] = 
None, if_modified_since: Optional[datetime] = None, if_unmodified_since: Optional[datetime] = None, @@ -486,7 +486,7 @@ class BlobClient(StorageAccountHostsMixin, StorageEncryptionMixin): data: Union[bytes, Iterable[bytes], IO[bytes]], length: Optional[int] = None, *, - validate_content: Optional[bool] = None, + validate_content: Optional[Union[bool, Literal['auto', 'crc64', 'md5']]] = None, lease: Optional[Union[BlobLeaseClient, str]] = None, encoding: Optional[str] = None, cpk: Optional[CustomerProvidedEncryptionKey] = None, @@ -671,7 +671,7 @@ class BlobClient(StorageAccountHostsMixin, StorageEncryptionMixin): length: int, *, lease: Optional[Union[BlobLeaseClient, str]] = None, - validate_content: Optional[bool] = None, + validate_content: Optional[Union[bool, Literal['auto', 'crc64', 'md5']]] = None, if_sequence_number_lte: Optional[int] = None, if_sequence_number_lt: Optional[int] = None, if_sequence_number_eq: Optional[int] = None, @@ -741,7 +741,7 @@ class BlobClient(StorageAccountHostsMixin, StorageEncryptionMixin): data: Union[bytes, Iterable[bytes], IO[bytes]], length: Optional[int] = None, *, - validate_content: Optional[bool] = None, + validate_content: Optional[Union[bool, Literal['auto', 'crc64', 'md5']]] = None, maxsize_condition: Optional[int] = None, appendpos_condition: Optional[int] = None, lease: Optional[Union[BlobLeaseClient, str]] = None, diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_blob_client_helpers.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_blob_client_helpers.py index 33d5dfc7f0b2..1bbfb33901d3 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_blob_client_helpers.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_blob_client_helpers.py @@ -8,7 +8,7 @@ from io import BytesIO from typing import ( Any, AnyStr, AsyncGenerator, AsyncIterable, cast, - Dict, IO, Iterable, List, Optional, Tuple, Union, + Dict, IO, Iterable, List, Literal, Optional, Tuple, Union, TYPE_CHECKING ) from 
urllib.parse import quote, unquote, urlparse @@ -58,6 +58,7 @@ from ._shared.response_handlers import return_headers_and_deserialized, return_response_headers from ._shared.uploads import IterStreamer from ._shared.uploads_async import AsyncIterStreamer +from ._shared.validation import parse_validation_option from ._upload_helpers import _any_conditions if TYPE_CHECKING: @@ -110,6 +111,7 @@ def _upload_blob_options( # pylint:disable=too-many-statements length: Optional[int], metadata: Optional[Dict[str, str]], encryption_options: Dict[str, Any], + validate_content: Optional[Union[bool, Literal['auto', 'crc64', 'md5']]], config: "StorageConfiguration", sdk_moniker: str, client: "AzureBlobStorage", @@ -135,7 +137,6 @@ def _upload_blob_options( # pylint:disable=too-many-statements else: raise TypeError(f"Unsupported data type: {type(data)}") - validate_content = kwargs.pop('validate_content', False) content_settings = kwargs.pop('content_settings', None) overwrite = kwargs.pop('overwrite', False) max_concurrency = kwargs.pop('max_concurrency', None) @@ -258,6 +259,7 @@ def _download_blob_options( length: Optional[int], encoding: Optional[str], encryption_options: Dict[str, Any], + validate_content: Optional[Union[bool, Literal['auto', 'crc64', 'md5']]], config: "StorageConfiguration", sdk_moniker: str, client: "AzureBlobStorage", @@ -279,6 +281,8 @@ def _download_blob_options( Encoding to decode the downloaded bytes. Default is None, i.e. no decoding. :param Dict[str, Any] encryption_options: The options for encryption, if enabled. + :param validate_content: + Enables checksum validation for the transfer. Already parsed via parse_validation_option. :param StorageConfiguration config: The Storage configuration options. 
:param str sdk_moniker: @@ -292,8 +296,6 @@ def _download_blob_options( if offset is None: raise ValueError("Offset must be provided if length is provided.") length = offset + length - 1 # Service actually uses an end-range inclusive index - - validate_content = kwargs.pop('validate_content', False) access_conditions = get_access_conditions(kwargs.pop('lease', None)) mod_conditions = get_modify_conditions(kwargs) @@ -721,7 +723,7 @@ def _stage_block_options( if isinstance(data, bytes): data = data[:length] - validate_content = kwargs.pop('validate_content', False) + validate_content = parse_validation_option(kwargs.pop('validate_content', None)) cpk_scope_info = get_cpk_scope_info(kwargs) cpk = kwargs.pop('cpk', None) cpk_info = None @@ -1004,7 +1006,7 @@ def _upload_page_options( ) mod_conditions = get_modify_conditions(kwargs) cpk_scope_info = get_cpk_scope_info(kwargs) - validate_content = kwargs.pop('validate_content', False) + validate_content = parse_validation_option(kwargs.pop('validate_content', None)) cpk = kwargs.pop('cpk', None) cpk_info = None if cpk: @@ -1149,7 +1151,7 @@ def _append_block_options( appendpos_condition = kwargs.pop('appendpos_condition', None) maxsize_condition = kwargs.pop('maxsize_condition', None) - validate_content = kwargs.pop('validate_content', False) + validate_content = parse_validation_option(kwargs.pop('validate_content', None)) append_conditions = None if maxsize_condition or appendpos_condition is not None: append_conditions = AppendPositionAccessConditions( diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_container_client.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_container_client.py index 0415b58cec0d..834252c3cd48 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_container_client.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_container_client.py @@ -1029,15 +1029,11 @@ def upload_blob( :keyword ~azure.storage.blob.ContentSettings content_settings: ContentSettings 
object used to set blob properties. Used to set content type, encoding, language, disposition, md5, and cache control. - :keyword bool validate_content: - If true, calculates an MD5 hash for each chunk of the blob. The storage - service checks the hash of the content that has arrived with the hash - that was sent. This is primarily valuable for detecting bitflips on - the wire if using http instead of https, as https (the default), will - already validate. Note that this MD5 hash is not stored with the - blob. Also note that if enabled, the memory-efficient upload algorithm - will not be used, because computing the MD5 hash requires buffering - entire blocks, and doing so defeats the purpose of the memory-efficient algorithm. + :keyword validate_content: + Enables checksum validation for the transfer. Any checksum calculated is NOT stored with the blob. + Choose "auto" (let the SDK choose the best algorithm), "crc64", or "md5". The use of bool is deprecated. + NOTE: The use of "auto" or "crc64" requires the `azure-storage-extensions` package to be installed. + :paramtype validate_content: Union[bool, Literal['auto', 'crc64', 'md5']] :keyword lease: Required if the container has an active lease. Value can be a BlobLeaseClient object or the lease ID as a string. @@ -1274,15 +1270,11 @@ def download_blob( This keyword argument was introduced in API version '2019-12-12'. - :keyword bool validate_content: - If true, calculates an MD5 hash for each chunk of the blob. The storage - service checks the hash of the content that has arrived with the hash - that was sent. This is primarily valuable for detecting bitflips on - the wire if using http instead of https, as https (the default), will - already validate. Note that this MD5 hash is not stored with the - blob. 
Also note that if enabled, the memory-efficient upload algorithm - will not be used because computing the MD5 hash requires buffering - entire blocks, and doing so defeats the purpose of the memory-efficient algorithm. + :keyword validate_content: + Enables checksum validation for the transfer. Any checksum calculated is NOT stored with the blob. + Choose "auto" (let the SDK choose the best algorithm), "crc64", or "md5". The use of bool is deprecated. + NOTE: The use of "auto" or "crc64" requires the `azure-storage-extensions` package to be installed. + :paramtype validate_content: Union[bool, Literal['auto', 'crc64', 'md5']] :keyword lease: Required if the blob has an active lease. If specified, download_blob only succeeds if the blob's lease is active and matches this ID. Value can be a diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_container_client.pyi b/sdk/storage/azure-storage-blob/azure/storage/blob/_container_client.pyi index 8825670779d7..5ee116e5ca75 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_container_client.pyi +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_container_client.pyi @@ -13,6 +13,7 @@ from typing import ( Callable, Dict, List, + Literal, IO, Iterable, Iterator, @@ -253,7 +254,7 @@ class ContainerClient(StorageAccountHostsMixin, StorageEncryptionMixin): *, overwrite: Optional[bool] = None, content_settings: Optional[ContentSettings] = None, - validate_content: Optional[bool] = None, + validate_content: Optional[Union[bool, Literal['auto', 'crc64', 'md5']]] = None, lease: Optional[Union[BlobLeaseClient, str]] = None, if_modified_since: Optional[datetime] = None, if_unmodified_since: Optional[datetime] = None, @@ -295,7 +296,7 @@ class ContainerClient(StorageAccountHostsMixin, StorageEncryptionMixin): length: Optional[int] = None, *, version_id: Optional[str] = None, - validate_content: Optional[bool] = None, + validate_content: Optional[Union[bool, Literal['auto', 'crc64', 'md5']]] = None, lease: 
Optional[Union[BlobLeaseClient, str]] = None, if_modified_since: Optional[datetime] = None, if_unmodified_since: Optional[datetime] = None, @@ -316,7 +317,7 @@ class ContainerClient(StorageAccountHostsMixin, StorageEncryptionMixin): length: Optional[int] = None, *, version_id: Optional[str] = None, - validate_content: Optional[bool] = None, + validate_content: Optional[Union[bool, Literal['auto', 'crc64', 'md5']]] = None, lease: Optional[Union[BlobLeaseClient, str]] = None, if_modified_since: Optional[datetime] = None, if_unmodified_since: Optional[datetime] = None, @@ -338,7 +339,7 @@ class ContainerClient(StorageAccountHostsMixin, StorageEncryptionMixin): length: Optional[int] = None, *, version_id: Optional[str] = None, - validate_content: Optional[bool] = None, + validate_content: Optional[Union[bool, Literal['auto', 'crc64', 'md5']]] = None, lease: Optional[Union[BlobLeaseClient, str]] = None, if_modified_since: Optional[datetime] = None, if_unmodified_since: Optional[datetime] = None, diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_download.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_download.py index 4f39fa68e3c7..401bb76d867c 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_download.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_download.py @@ -11,7 +11,7 @@ from io import BytesIO, StringIO from typing import ( Any, Callable, cast, Dict, Generator, - Generic, IO, Iterator, List, Optional, + Generic, IO, Iterator, List, Literal, Optional, overload, Tuple, TypeVar, Union, TYPE_CHECKING ) @@ -92,7 +92,7 @@ def __init__( current_progress: int, start_range: int, end_range: int, - validate_content: bool, + validate_content: Optional[Union[bool, Literal['crc64', 'md5']]], encryption_options: Dict[str, Any], encryption_data: Optional["_EncryptionData"] = None, stream: Any = None, @@ -330,7 +330,7 @@ def __init__( config: "StorageConfiguration" = None, # type: ignore [assignment] start_range: Optional[int] = 
None, end_range: Optional[int] = None, - validate_content: bool = None, # type: ignore [assignment] + validate_content: Optional[Union[bool, Literal['crc64', 'md5']]] = None, encryption_options: Dict[str, Any] = None, # type: ignore [assignment] max_concurrency: Optional[int] = None, name: str = None, # type: ignore [assignment] diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies.py index fee83754bb6b..71f4cce10541 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies.py @@ -34,7 +34,11 @@ from .authentication import AzureSigningError, StorageHttpChallenge from .constants import DEFAULT_OAUTH_SCOPE, DATA_BLOCK_SIZE from .models import LocationMode, StorageErrorCode -from .streams import StructuredMessageDecoder, StructuredMessageEncodeStream, StructuredMessageProperties +from .streams import ( + StructuredMessageDecoder, + StructuredMessageEncodeStream, + StructuredMessageProperties, +) from .validation import ( CV_TYPE_ERROR_MSG, calculate_content_md5, @@ -69,7 +73,12 @@ def encode_base64(data: Union[bytes, str]) -> str: # Are we out of retries? def is_exhausted(settings): - retry_counts = (settings["total"], settings["connect"], settings["read"], settings["status"]) + retry_counts = ( + settings["total"], + settings["connect"], + settings["read"], + settings["status"], + ) retry_counts = list(filter(None, retry_counts)) if not retry_counts: return False @@ -78,7 +87,9 @@ def is_exhausted(settings): def retry_hook(settings, **kwargs): if settings["hook"]: - settings["hook"](retry_count=settings["count"] - 1, location_mode=settings["mode"], **kwargs) + settings["hook"]( + retry_count=settings["count"] - 1, location_mode=settings["mode"], **kwargs + ) # Is this method/status code retryable? 
(Based on allowlists and control @@ -98,7 +109,9 @@ def is_retry(response, mode): # pylint: disable=too-many-return-statements # Response code 408 is a timeout and should be retried. return True if status >= 400: - error_code = response.http_response.headers.get("x-ms-copy-source-error-code") + error_code = response.http_response.headers.get( + "x-ms-copy-source-error-code" + ) if error_code in [ StorageErrorCode.OPERATION_TIMED_OUT, StorageErrorCode.INTERNAL_ERROR, @@ -122,9 +135,9 @@ def is_checksum_retry(response) -> bool: # Legacy code - evaluate retry only on validate_content=True if validate_content is True and response.http_response.headers.get("content-md5"): - computed_md5 = response.http_request.headers.get("content-md5", None) or encode_base64( - calculate_content_md5(response.http_response.body()) - ) + computed_md5 = response.http_request.headers.get( + "content-md5", None + ) or encode_base64(calculate_content_md5(response.http_response.body())) if response.http_response.headers["content-md5"] != computed_md5: return True return False @@ -153,7 +166,9 @@ def on_request(self, request: "PipelineRequest") -> None: request.http_request.headers["x-ms-date"] = current_time custom_id = request.context.options.pop("client_request_id", None) - request.http_request.headers["x-ms-client-request-id"] = custom_id or str(uuid.uuid1()) + request.http_request.headers["x-ms-client-request-id"] = custom_id or str( + uuid.uuid1() + ) # def on_response(self, request, response): # # raise exception if the echoed client request id from the service is not identical to the one we sent @@ -193,7 +208,9 @@ def on_request(self, request: "PipelineRequest") -> None: # Lock retries to the specific location request.context.options["retry_to_secondary"] = False if use_location not in self.hosts: - raise ValueError(f"Attempting to use undefined host location {use_location}") + raise ValueError( + f"Attempting to use undefined host location {use_location}" + ) if use_location != 
location_mode: # Update request URL to use the specified location updated = parsed_url._replace(netloc=self.hosts[use_location]) @@ -211,7 +228,9 @@ class StorageLoggingPolicy(NetworkTraceLoggingPolicy): def __init__(self, logging_enable: bool = False, **kwargs) -> None: self.logging_body = kwargs.pop("logging_body", False) - super(StorageLoggingPolicy, self).__init__(logging_enable=logging_enable, **kwargs) + super(StorageLoggingPolicy, self).__init__( + logging_enable=logging_enable, **kwargs + ) def on_request(self, request: "PipelineRequest") -> None: http_request = request.http_request @@ -240,7 +259,16 @@ def on_request(self, request: "PipelineRequest") -> None: parsed_qs["sig"] = "*****" # the SAS needs to be put back together - value = urlunparse((scheme, netloc, path, params, urlencode(parsed_qs), fragment)) + value = urlunparse( + ( + scheme, + netloc, + path, + params, + urlencode(parsed_qs), + fragment, + ) + ) _LOGGER.debug(" %r: %r", header, value) _LOGGER.debug("Request body:") @@ -253,7 +281,9 @@ def on_request(self, request: "PipelineRequest") -> None: except Exception as err: # pylint: disable=broad-except _LOGGER.debug("Failed to log request: %r", err) - def on_response(self, request: "PipelineRequest", response: "PipelineResponse") -> None: + def on_response( + self, request: "PipelineRequest", response: "PipelineResponse" + ) -> None: if response.context.pop("logging_enable", self.enable_http_logger): if not _LOGGER.isEnabledFor(logging.DEBUG): return @@ -268,7 +298,9 @@ def on_response(self, request: "PipelineRequest", response: "PipelineResponse") _LOGGER.debug("Response content:") pattern = re.compile(r'attachment; ?filename=["\w.]+', re.IGNORECASE) header = response.http_response.headers.get("content-disposition") - resp_content_type = response.http_response.headers.get("content-type", "") + resp_content_type = response.http_response.headers.get( + "content-type", "" + ) if header and pattern.match(header): filename = 
header.partition("=")[2] @@ -297,7 +329,9 @@ def __init__(self, **kwargs): super(StorageRequestHook, self).__init__() def on_request(self, request: "PipelineRequest") -> None: - request_callback = request.context.options.pop("raw_request_hook", self._request_callback) + request_callback = request.context.options.pop( + "raw_request_hook", self._request_callback + ) if request_callback: request_callback(request) @@ -315,36 +349,50 @@ def send(self, request: "PipelineRequest") -> "PipelineResponse": data_stream_total = request.context.options.pop("data_stream_total", None) download_stream_current = request.context.get("download_stream_current") if download_stream_current is None: - download_stream_current = request.context.options.pop("download_stream_current", None) + download_stream_current = request.context.options.pop( + "download_stream_current", None + ) upload_stream_current = request.context.get("upload_stream_current") if upload_stream_current is None: - upload_stream_current = request.context.options.pop("upload_stream_current", None) + upload_stream_current = request.context.options.pop( + "upload_stream_current", None + ) - response_callback = request.context.get("response_callback") or request.context.options.pop( - "raw_response_hook", self._response_callback - ) + response_callback = request.context.get( + "response_callback" + ) or request.context.options.pop("raw_response_hook", self._response_callback) response = self.next.send(request) - will_retry = is_retry(response, request.context.options.get("mode")) or is_checksum_retry(response) + will_retry = is_retry( + response, request.context.options.get("mode") + ) or is_checksum_retry(response) # Auth error could come from Bearer challenge, in which case this request will be made again is_auth_error = response.http_response.status_code == 401 should_update_counts = not (will_retry or is_auth_error) if should_update_counts and download_stream_current is not None: - download_stream_current += 
int(response.http_response.headers.get("Content-Length", 0)) + download_stream_current += int( + response.http_response.headers.get("Content-Length", 0) + ) if data_stream_total is None: content_range = response.http_response.headers.get("Content-Range") if content_range: - data_stream_total = int(content_range.split(" ", 1)[1].split("/", 1)[1]) + data_stream_total = int( + content_range.split(" ", 1)[1].split("/", 1)[1] + ) else: data_stream_total = download_stream_current elif should_update_counts and upload_stream_current is not None: - upload_stream_current += int(response.http_request.headers.get("Content-Length", 0)) + upload_stream_current += int( + response.http_request.headers.get("Content-Length", 0) + ) for pipeline_obj in [request, response]: if hasattr(pipeline_obj, "context"): pipeline_obj.context["data_stream_total"] = data_stream_total - pipeline_obj.context["download_stream_current"] = download_stream_current + pipeline_obj.context["download_stream_current"] = ( + download_stream_current + ) pipeline_obj.context["upload_stream_current"] = upload_stream_current if response_callback: response_callback(response) @@ -379,13 +427,21 @@ def _prepare_content_validation(request: "PipelineRequest") -> None: elif validate_content == ChecksumAlgorithm.CRC64: if isinstance(data, bytes): - request.http_request.headers[CRC64_HEADER] = encode_base64(calculate_crc64_bytes(data)) + request.http_request.headers[CRC64_HEADER] = encode_base64( + calculate_crc64_bytes(data) + ) elif hasattr(data, "read"): - content_length = int(request.http_request.headers.get(CONTENT_LENGTH_HEADER)) + content_length = int( + request.http_request.headers.get(CONTENT_LENGTH_HEADER) + ) # Wrap data in structured message stream and adjust HTTP request - sm_stream = StructuredMessageEncodeStream(data, content_length, StructuredMessageProperties.CRC64) + sm_stream = StructuredMessageEncodeStream( + data, content_length, StructuredMessageProperties.CRC64 + ) request.http_request.data = 
sm_stream - request.http_request.headers[CONTENT_LENGTH_HEADER] = str(len(sm_stream)) + request.http_request.headers[CONTENT_LENGTH_HEADER] = str( + len(sm_stream) + ) request.http_request.headers[SM_LENGTH_HEADER] = str(content_length) request.http_request.headers[SM_HEADER] = SM_HEADER_V1_CRC64 else: @@ -408,7 +464,9 @@ def _validate_content_response( if not validate_content: return - if is_md5_validation(validate_content) and response.http_response.headers.get("content-md5"): + if is_md5_validation(validate_content) and response.http_response.headers.get( + "content-md5" + ): computed_md5 = request.context.get("validate_content_md5") or encode_base64( calculate_content_md5(response.http_response.body()) ) @@ -440,9 +498,12 @@ def _validate_content_response( # Patch response to return response iterator wrapped in structured message decoder original_stream_download = response.http_response.stream_download + def wrapped_stream_download(*args, **kwargs): iterator = original_stream_download(*args, **kwargs) - decoder = decoder_cls(iterator, content_length, block_size=DATA_BLOCK_SIZE) + decoder = decoder_cls( + iterator, content_length, block_size=DATA_BLOCK_SIZE + ) decoder.request = iterator.request # type: ignore decoder.response = iterator.response # type: ignore return decoder @@ -455,13 +516,16 @@ class StorageContentValidation(SansIOHTTPPolicy): This is enabled by setting the "validate_content" key in the request context. When enabled, this policy will calculate and verify content checksums for uploads and downloads, and raise an exception if a mismatch is detected. 
""" + def __init__(self, **kwargs: Any) -> None: # pylint: disable=unused-argument super().__init__() def on_request(self, request: "PipelineRequest") -> None: _prepare_content_validation(request) - def on_response(self, request: "PipelineRequest", response: "PipelineResponse") -> None: + def on_response( + self, request: "PipelineRequest", response: "PipelineResponse" + ) -> None: _validate_content_response(request, response, StructuredMessageDecoder) @@ -489,7 +553,9 @@ def __init__(self, **kwargs: Any) -> None: self.retry_to_secondary = kwargs.pop("retry_to_secondary", False) super(StorageRetryPolicy, self).__init__() - def _set_next_host_location(self, settings: Dict[str, Any], request: "PipelineRequest") -> None: + def _set_next_host_location( + self, settings: Dict[str, Any], request: "PipelineRequest" + ) -> None: """ A function which sets the next host location on the request, if applicable. @@ -509,7 +575,7 @@ def _set_next_host_location(self, settings: Dict[str, Any], request: "PipelineRe def configure_retries(self, request: "PipelineRequest") -> Dict[str, Any]: """ Configure the retry settings for the request. - + :param request: A pipeline request object. :type request: ~azure.core.pipeline.PipelineRequest :return: A dictionary containing the retry settings. 
@@ -528,7 +594,9 @@ def configure_retries(self, request: "PipelineRequest") -> Dict[str, Any]: "connect": options.pop("retry_connect", self.connect_retries), "read": options.pop("retry_read", self.read_retries), "status": options.pop("retry_status", self.status_retries), - "retry_secondary": options.pop("retry_to_secondary", self.retry_to_secondary), + "retry_secondary": options.pop( + "retry_to_secondary", self.retry_to_secondary + ), "mode": options.pop("location_mode", LocationMode.PRIMARY), "hosts": options.pop("hosts", None), "hook": options.pop("retry_hook", None), @@ -537,7 +605,9 @@ def configure_retries(self, request: "PipelineRequest") -> Dict[str, Any]: "history": [], } - def get_backoff_time(self, settings: Dict[str, Any]) -> float: # pylint: disable=unused-argument + def get_backoff_time( + self, settings: Dict[str, Any] + ) -> float: # pylint: disable=unused-argument """Formula for computing the current backoff. Should be calculated by child class. @@ -549,7 +619,7 @@ def get_backoff_time(self, settings: Dict[str, Any]) -> float: # pylint: disabl def sleep(self, settings, transport): """Sleep for the backoff time. - + :param Dict[str, Any] settings: The configurable values pertaining to the sleep operation. :param transport: The transport to use for sleeping. :type transport: @@ -600,7 +670,9 @@ def increment( # status_forcelist and a the given method is in the allowlist if response: settings["status"] -= 1 - settings["history"].append(RequestHistory(request, http_response=response)) + settings["history"].append( + RequestHistory(request, http_response=response) + ) if not is_exhausted(settings): if request.method not in ["PUT"] and settings["retry_secondary"]: @@ -623,7 +695,7 @@ def increment( def send(self, request): """Send the request with retry logic. - + :param request: A pipeline request object. :type request: ~azure.core.pipeline.PipelineRequest :return: A pipeline response object. 
@@ -635,13 +707,20 @@ def send(self, request): while retries_remaining: try: response = self.next.send(request) - if is_retry(response, retry_settings["mode"]) or is_checksum_retry(response): + if is_retry(response, retry_settings["mode"]) or is_checksum_retry( + response + ): retries_remaining = self.increment( - retry_settings, request=request.http_request, response=response.http_response + retry_settings, + request=request.http_request, + response=response.http_response, ) if retries_remaining: retry_hook( - retry_settings, request=request.http_request, response=response.http_response, error=None + retry_settings, + request=request.http_request, + response=response.http_response, + error=None, ) self.sleep(retry_settings, request.context.transport) continue @@ -649,9 +728,16 @@ def send(self, request): except AzureError as err: if isinstance(err, AzureSigningError): raise - retries_remaining = self.increment(retry_settings, request=request.http_request, error=err) + retries_remaining = self.increment( + retry_settings, request=request.http_request, error=err + ) if retries_remaining: - retry_hook(retry_settings, request=request.http_request, response=None, error=err) + retry_hook( + retry_settings, + request=request.http_request, + response=None, + error=err, + ) self.sleep(retry_settings, request.context.transport) continue raise err @@ -704,7 +790,9 @@ def __init__( self.initial_backoff = initial_backoff self.increment_base = increment_base self.random_jitter_range = random_jitter_range - super(ExponentialRetry, self).__init__(retry_total=retry_total, retry_to_secondary=retry_to_secondary, **kwargs) + super(ExponentialRetry, self).__init__( + retry_total=retry_total, retry_to_secondary=retry_to_secondary, **kwargs + ) def get_backoff_time(self, settings: Dict[str, Any]) -> float: """ @@ -717,8 +805,14 @@ def get_backoff_time(self, settings: Dict[str, Any]) -> float: :rtype: float """ random_generator = random.Random() - backoff = self.initial_backoff + (0 if 
settings["count"] == 0 else pow(self.increment_base, settings["count"])) - random_range_start = backoff - self.random_jitter_range if backoff > self.random_jitter_range else 0 + backoff = self.initial_backoff + ( + 0 if settings["count"] == 0 else pow(self.increment_base, settings["count"]) + ) + random_range_start = ( + backoff - self.random_jitter_range + if backoff > self.random_jitter_range + else 0 + ) random_range_end = backoff + self.random_jitter_range return random_generator.uniform(random_range_start, random_range_end) @@ -756,7 +850,9 @@ def __init__( """ self.backoff = backoff self.random_jitter_range = random_jitter_range - super(LinearRetry, self).__init__(retry_total=retry_total, retry_to_secondary=retry_to_secondary, **kwargs) + super(LinearRetry, self).__init__( + retry_total=retry_total, retry_to_secondary=retry_to_secondary, **kwargs + ) def get_backoff_time(self, settings: Dict[str, Any]) -> float: """ @@ -771,7 +867,11 @@ def get_backoff_time(self, settings: Dict[str, Any]) -> float: random_generator = random.Random() # the backoff interval normally does not change, however there is the possibility # that it was modified by accessing the property directly after initializing the object - random_range_start = self.backoff - self.random_jitter_range if self.backoff > self.random_jitter_range else 0 + random_range_start = ( + self.backoff - self.random_jitter_range + if self.backoff > self.random_jitter_range + else 0 + ) random_range_end = self.backoff + self.random_jitter_range return random_generator.uniform(random_range_start, random_range_end) @@ -779,16 +879,22 @@ def get_backoff_time(self, settings: Dict[str, Any]) -> float: class StorageBearerTokenCredentialPolicy(BearerTokenCredentialPolicy): """Custom Bearer token credential policy for following Storage Bearer challenges""" - def __init__(self, credential: "TokenCredential", audience: str, **kwargs: Any) -> None: - super(StorageBearerTokenCredentialPolicy, self).__init__(credential, 
audience, **kwargs) + def __init__( + self, credential: "TokenCredential", audience: str, **kwargs: Any + ) -> None: + super(StorageBearerTokenCredentialPolicy, self).__init__( + credential, audience, **kwargs + ) - def on_challenge(self, request: "PipelineRequest", response: "PipelineResponse") -> bool: + def on_challenge( + self, request: "PipelineRequest", response: "PipelineResponse" + ) -> bool: """Handle the challenge from the service and authorize the request. - + :param request: The request object. :type request: ~azure.core.pipeline.PipelineRequest :param response: The response object. - :type response: ~azure.core.pipeline.PipelineResponse + :type response: ~azure.core.pipeline.PipelineResponse :return: True if the request was authorized, False otherwise. :rtype: bool """ diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies_async.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies_async.py index 860f10e93089..14ce070e47ff 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies_async.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies_async.py @@ -11,7 +11,10 @@ from typing import Any, Dict, TYPE_CHECKING from azure.core.exceptions import AzureError, StreamClosedError, StreamConsumedError -from azure.core.pipeline.policies import AsyncBearerTokenCredentialPolicy, AsyncHTTPPolicy +from azure.core.pipeline.policies import ( + AsyncBearerTokenCredentialPolicy, + AsyncHTTPPolicy, +) from .authentication import AzureSigningError, StorageHttpChallenge from .constants import DEFAULT_OAUTH_SCOPE @@ -42,9 +45,17 @@ async def retry_hook(settings, **kwargs): if settings["hook"]: if asyncio.iscoroutine(settings["hook"]): - await settings["hook"](retry_count=settings["count"] - 1, location_mode=settings["mode"], **kwargs) + await settings["hook"]( + retry_count=settings["count"] - 1, + location_mode=settings["mode"], + **kwargs + ) else: - 
settings["hook"](retry_count=settings["count"] - 1, location_mode=settings["mode"], **kwargs) + settings["hook"]( + retry_count=settings["count"] - 1, + location_mode=settings["mode"], + **kwargs + ) async def is_checksum_retry(response): @@ -59,9 +70,9 @@ async def is_checksum_retry(response): await response.http_response.load_body() # Load the body in memory and close the socket except (StreamClosedError, StreamConsumedError): pass - computed_md5 = response.http_request.headers.get("content-md5", None) or encode_base64( - calculate_content_md5(response.http_response.body()) - ) + computed_md5 = response.http_request.headers.get( + "content-md5", None + ) or encode_base64(calculate_content_md5(response.http_response.body())) if response.http_response.headers["content-md5"] != computed_md5: return True return False @@ -72,6 +83,7 @@ class AsyncContentValidationPolicy(AsyncHTTPPolicy): This is enabled by setting the "validate_content" key in the request context. When enabled, this policy will calculate and verify content checksums for uploads and downloads, and raise an exception if a mismatch is detected. 
""" + def __init__(self, **kwargs: Any) -> None: # pylint: disable=unused-argument super().__init__() @@ -106,36 +118,50 @@ async def send(self, request: "PipelineRequest") -> "PipelineResponse": data_stream_total = request.context.options.pop("data_stream_total", None) download_stream_current = request.context.get("download_stream_current") if download_stream_current is None: - download_stream_current = request.context.options.pop("download_stream_current", None) + download_stream_current = request.context.options.pop( + "download_stream_current", None + ) upload_stream_current = request.context.get("upload_stream_current") if upload_stream_current is None: - upload_stream_current = request.context.options.pop("upload_stream_current", None) + upload_stream_current = request.context.options.pop( + "upload_stream_current", None + ) - response_callback = request.context.get("response_callback") or request.context.options.pop( - "raw_response_hook", self._response_callback - ) + response_callback = request.context.get( + "response_callback" + ) or request.context.options.pop("raw_response_hook", self._response_callback) response = await self.next.send(request) - will_retry = is_retry(response, request.context.options.get("mode")) or await is_checksum_retry(response) + will_retry = is_retry( + response, request.context.options.get("mode") + ) or await is_checksum_retry(response) # Auth error could come from Bearer challenge, in which case this request will be made again is_auth_error = response.http_response.status_code == 401 should_update_counts = not (will_retry or is_auth_error) if should_update_counts and download_stream_current is not None: - download_stream_current += int(response.http_response.headers.get("Content-Length", 0)) + download_stream_current += int( + response.http_response.headers.get("Content-Length", 0) + ) if data_stream_total is None: content_range = response.http_response.headers.get("Content-Range") if content_range: - data_stream_total = 
int(content_range.split(" ", 1)[1].split("/", 1)[1]) + data_stream_total = int( + content_range.split(" ", 1)[1].split("/", 1)[1] + ) else: data_stream_total = download_stream_current elif should_update_counts and upload_stream_current is not None: - upload_stream_current += int(response.http_request.headers.get("Content-Length", 0)) + upload_stream_current += int( + response.http_request.headers.get("Content-Length", 0) + ) for pipeline_obj in [request, response]: if hasattr(pipeline_obj, "context"): pipeline_obj.context["data_stream_total"] = data_stream_total - pipeline_obj.context["download_stream_current"] = download_stream_current + pipeline_obj.context["download_stream_current"] = ( + download_stream_current + ) pipeline_obj.context["upload_stream_current"] = upload_stream_current if response_callback: if asyncio.iscoroutine(response_callback): @@ -164,13 +190,20 @@ async def send(self, request): while retries_remaining: try: response = await self.next.send(request) - if is_retry(response, retry_settings["mode"]) or await is_checksum_retry(response): + if is_retry( + response, retry_settings["mode"] + ) or await is_checksum_retry(response): retries_remaining = self.increment( - retry_settings, request=request.http_request, response=response.http_response + retry_settings, + request=request.http_request, + response=response.http_response, ) if retries_remaining: await retry_hook( - retry_settings, request=request.http_request, response=response.http_response, error=None + retry_settings, + request=request.http_request, + response=response.http_response, + error=None, ) await self.sleep(retry_settings, request.context.transport) continue @@ -178,9 +211,16 @@ async def send(self, request): except AzureError as err: if isinstance(err, AzureSigningError): raise - retries_remaining = self.increment(retry_settings, request=request.http_request, error=err) + retries_remaining = self.increment( + retry_settings, request=request.http_request, error=err + ) if 
retries_remaining: - await retry_hook(retry_settings, request=request.http_request, response=None, error=err) + await retry_hook( + retry_settings, + request=request.http_request, + response=None, + error=err, + ) await self.sleep(retry_settings, request.context.transport) continue raise err @@ -235,7 +275,9 @@ def __init__( self.initial_backoff = initial_backoff self.increment_base = increment_base self.random_jitter_range = random_jitter_range - super(ExponentialRetry, self).__init__(retry_total=retry_total, retry_to_secondary=retry_to_secondary, **kwargs) + super(ExponentialRetry, self).__init__( + retry_total=retry_total, retry_to_secondary=retry_to_secondary, **kwargs + ) def get_backoff_time(self, settings: Dict[str, Any]) -> float: """ @@ -248,8 +290,14 @@ def get_backoff_time(self, settings: Dict[str, Any]) -> float: :rtype: int or None """ random_generator = random.Random() - backoff = self.initial_backoff + (0 if settings["count"] == 0 else pow(self.increment_base, settings["count"])) - random_range_start = backoff - self.random_jitter_range if backoff > self.random_jitter_range else 0 + backoff = self.initial_backoff + ( + 0 if settings["count"] == 0 else pow(self.increment_base, settings["count"]) + ) + random_range_start = ( + backoff - self.random_jitter_range + if backoff > self.random_jitter_range + else 0 + ) random_range_end = backoff + self.random_jitter_range return random_generator.uniform(random_range_start, random_range_end) @@ -287,7 +335,9 @@ def __init__( """ self.backoff = backoff self.random_jitter_range = random_jitter_range - super(LinearRetry, self).__init__(retry_total=retry_total, retry_to_secondary=retry_to_secondary, **kwargs) + super(LinearRetry, self).__init__( + retry_total=retry_total, retry_to_secondary=retry_to_secondary, **kwargs + ) def get_backoff_time(self, settings: Dict[str, Any]) -> float: """ @@ -302,7 +352,11 @@ def get_backoff_time(self, settings: Dict[str, Any]) -> float: random_generator = random.Random() # the 
backoff interval normally does not change, however there is the possibility # that it was modified by accessing the property directly after initializing the object - random_range_start = self.backoff - self.random_jitter_range if self.backoff > self.random_jitter_range else 0 + random_range_start = ( + self.backoff - self.random_jitter_range + if self.backoff > self.random_jitter_range + else 0 + ) random_range_end = self.backoff + self.random_jitter_range return random_generator.uniform(random_range_start, random_range_end) @@ -310,10 +364,16 @@ def get_backoff_time(self, settings: Dict[str, Any]) -> float: class AsyncStorageBearerTokenCredentialPolicy(AsyncBearerTokenCredentialPolicy): """Custom Bearer token credential policy for following Storage Bearer challenges""" - def __init__(self, credential: "AsyncTokenCredential", audience: str, **kwargs: Any) -> None: - super(AsyncStorageBearerTokenCredentialPolicy, self).__init__(credential, audience, **kwargs) + def __init__( + self, credential: "AsyncTokenCredential", audience: str, **kwargs: Any + ) -> None: + super(AsyncStorageBearerTokenCredentialPolicy, self).__init__( + credential, audience, **kwargs + ) - async def on_challenge(self, request: "PipelineRequest", response: "PipelineResponse") -> bool: + async def on_challenge( + self, request: "PipelineRequest", response: "PipelineResponse" + ) -> bool: try: auth_header = response.http_response.headers.get("WWW-Authenticate") challenge = StorageHttpChallenge(auth_header) diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/streams.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/streams.py index e04d666eab5e..712f4e90af69 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/streams.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/streams.py @@ -35,16 +35,19 @@ class SMRegion(Enum): MESSAGE_FOOTER = 5 -def generate_message_header(version: int, size: int, flags: StructuredMessageProperties, 
num_segments: int) -> bytes: - return (version.to_bytes(1, 'little') + - size.to_bytes(8, 'little') + - flags.to_bytes(2, 'little') + - num_segments.to_bytes(2, 'little')) +def generate_message_header( + version: int, size: int, flags: StructuredMessageProperties, num_segments: int +) -> bytes: + return ( + version.to_bytes(1, "little") + + size.to_bytes(8, "little") + + flags.to_bytes(2, "little") + + num_segments.to_bytes(2, "little") + ) def generate_segment_header(number: int, size: int) -> bytes: - return (number.to_bytes(2, 'little') + - size.to_bytes(8, 'little')) + return number.to_bytes(2, "little") + size.to_bytes(8, "little") def parse_message_header( @@ -53,24 +56,30 @@ def parse_message_header( version = data[0] if version != 1: raise ValueError(f"The structured message version is not supported: {version}") - message_length = int.from_bytes(data[1:9], 'little') + message_length = int.from_bytes(data[1:9], "little") if message_length != expected_message_length: - raise ValueError(f"Structured message length {message_length} " - f"did not match content length {expected_message_length}") - flags = StructuredMessageProperties(int.from_bytes(data[9:11], 'little')) - num_segments = int.from_bytes(data[11:13], 'little') + raise ValueError( + f"Structured message length {message_length} " + f"did not match content length {expected_message_length}" + ) + flags = StructuredMessageProperties(int.from_bytes(data[9:11], "little")) + num_segments = int.from_bytes(data[11:13], "little") return version, flags, num_segments def parse_segment_header(data: bytes, expected_segment_number: int) -> tuple[int, int]: - segment_number = int.from_bytes(data[0:2], 'little') + segment_number = int.from_bytes(data[0:2], "little") if segment_number != expected_segment_number: - raise ValueError(f"Structured message segment number invalid or out of order {segment_number}") - segment_content_length = int.from_bytes(data[2:10], 'little') + raise ValueError( + f"Structured message 
segment number invalid or out of order {segment_number}" + ) + segment_content_length = int.from_bytes(data[2:10], "little") return segment_number, segment_content_length -class StructuredMessageEncodeStream(IOBase): # pylint: disable=too-many-instance-attributes +class StructuredMessageEncodeStream( + IOBase +): # pylint: disable=too-many-instance-attributes message_version: int content_length: int message_length: int @@ -95,11 +104,12 @@ class StructuredMessageEncodeStream(IOBase): # pylint: disable=too-many-instanc _segment_crc64s: dict[int, int] def __init__( - self, inner_stream: IO[bytes], + self, + inner_stream: IO[bytes], content_length: int, flags: StructuredMessageProperties, *, - segment_size: int = DEFAULT_SEGMENT_SIZE + segment_size: int = DEFAULT_SEGMENT_SIZE, ) -> None: if segment_size < 1: raise ValueError("Segment size must be greater than 0.") @@ -141,11 +151,19 @@ def _segment_header_length(self) -> int: @property def _segment_footer_length(self) -> int: - return StructuredMessageConstants.CRC64_LENGTH if StructuredMessageProperties.CRC64 in self.flags else 0 + return ( + StructuredMessageConstants.CRC64_LENGTH + if StructuredMessageProperties.CRC64 in self.flags + else 0 + ) @property def _message_footer_length(self) -> int: - return StructuredMessageConstants.CRC64_LENGTH if StructuredMessageProperties.CRC64 in self.flags else 0 + return ( + StructuredMessageConstants.CRC64_LENGTH + if StructuredMessageProperties.CRC64 in self.flags + else 0 + ) def _update_current_region_length(self) -> None: if self._current_region == SMRegion.MESSAGE_HEADER: @@ -155,8 +173,9 @@ def _update_current_region_length(self) -> None: elif self._current_region == SMRegion.SEGMENT_CONTENT: # Last segment size is remaining content if self._current_segment_number == self._num_segments: - self._current_region_length = self.content_length - \ - ((self._current_segment_number - 1) * self._segment_size) + self._current_region_length = self.content_length - ( + 
(self._current_segment_number - 1) * self._segment_size + ) else: self._current_region_length = self._segment_size elif self._current_region == SMRegion.SEGMENT_FOOTER: @@ -179,7 +198,10 @@ def readable(self) -> bool: def seekable(self) -> bool: try: # Only seekable if the inner stream is and we could get its initial position - return self._inner_stream.seekable() and self._initial_content_position is not None + return ( + self._inner_stream.seekable() + and self._initial_content_position is not None + ) except (AttributeError, UnsupportedOperation, OSError): return False @@ -187,22 +209,38 @@ def tell(self) -> int: if self._current_region == SMRegion.MESSAGE_HEADER: return self._current_region_offset if self._current_region == SMRegion.SEGMENT_HEADER: - return (self._message_header_length + self._content_offset + - (self._current_segment_number - 1) * (self._segment_header_length + self._segment_footer_length) + - self._current_region_offset) + return ( + self._message_header_length + + self._content_offset + + (self._current_segment_number - 1) + * (self._segment_header_length + self._segment_footer_length) + + self._current_region_offset + ) if self._current_region == SMRegion.SEGMENT_CONTENT: - return (self._message_header_length + self._content_offset + - (self._current_segment_number - 1) * (self._segment_header_length + self._segment_footer_length) + - self._segment_header_length) + return ( + self._message_header_length + + self._content_offset + + (self._current_segment_number - 1) + * (self._segment_header_length + self._segment_footer_length) + + self._segment_header_length + ) if self._current_region == SMRegion.SEGMENT_FOOTER: - return (self._message_header_length + self._content_offset + - (self._current_segment_number - 1) * (self._segment_header_length + self._segment_footer_length) + - self._segment_header_length + - self._current_region_offset) + return ( + self._message_header_length + + self._content_offset + + (self._current_segment_number - 1) 
+ * (self._segment_header_length + self._segment_footer_length) + + self._segment_header_length + + self._current_region_offset + ) if self._current_region == SMRegion.MESSAGE_FOOTER: - return (self._message_header_length + self._content_offset + - self._current_segment_number * (self._segment_header_length + self._segment_footer_length) + - self._current_region_offset) + return ( + self._message_header_length + + self._content_offset + + self._current_segment_number + * (self._segment_header_length + self._segment_footer_length) + + self._current_region_offset + ) raise ValueError(f"Invalid SMRegion {self._current_region}") @@ -233,21 +271,33 @@ def seek(self, offset: int, whence: int = SEEK_SET) -> int: # MESSAGE_FOOTER elif position >= self.message_length - self._message_footer_length: self._current_region = SMRegion.MESSAGE_FOOTER - self._current_region_offset = position - (self.message_length - self._message_footer_length) + self._current_region_offset = position - ( + self.message_length - self._message_footer_length + ) self._content_offset = self.content_length self._current_segment_number = self._num_segments else: # The size of a "full" segment. 
Fine to use for calculating new segment number and pos - full_segment_size = self._segment_header_length + self._segment_size + self._segment_footer_length - new_segment_num = 1 + (position - self._message_header_length) // full_segment_size + full_segment_size = ( + self._segment_header_length + + self._segment_size + + self._segment_footer_length + ) + new_segment_num = ( + 1 + (position - self._message_header_length) // full_segment_size + ) segment_pos = (position - self._message_header_length) % full_segment_size - previous_segments_total_content_size = (new_segment_num - 1) * self._segment_size + previous_segments_total_content_size = ( + new_segment_num - 1 + ) * self._segment_size # We need the size of the segment we are seeking to for some of the calculations below new_segment_size = self._segment_size if new_segment_num == self._num_segments: # The last segment size is the remaining content length - new_segment_size = self.content_length - previous_segments_total_content_size + new_segment_size = ( + self.content_length - previous_segments_total_content_size + ) # SEGMENT_HEADER if segment_pos < self._segment_header_length: @@ -258,17 +308,25 @@ def seek(self, offset: int, whence: int = SEEK_SET) -> int: elif segment_pos < self._segment_header_length + new_segment_size: self._current_region = SMRegion.SEGMENT_CONTENT self._current_region_offset = segment_pos - self._segment_header_length - self._content_offset = previous_segments_total_content_size + self._current_region_offset + self._content_offset = ( + previous_segments_total_content_size + self._current_region_offset + ) # SEGMENT_FOOTER else: self._current_region = SMRegion.SEGMENT_FOOTER - self._current_region_offset = segment_pos - self._segment_header_length - new_segment_size - self._content_offset = previous_segments_total_content_size + new_segment_size + self._current_region_offset = ( + segment_pos - self._segment_header_length - new_segment_size + ) + self._content_offset = ( + 
previous_segments_total_content_size + new_segment_size + ) self._current_segment_number = new_segment_num self._update_current_region_length() - self._inner_stream.seek((self._initial_content_position or 0) + self._content_offset) + self._inner_stream.seek( + (self._initial_content_position or 0) + self._content_offset + ) return position def read(self, size: int = -1) -> bytes: @@ -276,7 +334,7 @@ def read(self, size: int = -1) -> bytes: raise ValueError("Stream is closed") if size == 0: - return b'' + return b"" if size < 0: size = sys.maxsize @@ -286,11 +344,14 @@ def read(self, size: int = -1) -> bytes: while count < size and self.tell() < self.message_length: remaining = size - count if self._current_region in ( - SMRegion.MESSAGE_HEADER, - SMRegion.SEGMENT_HEADER, - SMRegion.SEGMENT_FOOTER, - SMRegion.MESSAGE_FOOTER): - count += self._read_metadata_region(self._current_region, remaining, output) + SMRegion.MESSAGE_HEADER, + SMRegion.SEGMENT_HEADER, + SMRegion.SEGMENT_FOOTER, + SMRegion.MESSAGE_FOOTER, + ): + count += self._read_metadata_region( + self._current_region, remaining, output + ) elif self._current_region == SMRegion.SEGMENT_CONTENT: count += self._read_content(remaining, output) else: @@ -300,7 +361,9 @@ def read(self, size: int = -1) -> bytes: def _calculate_message_length(self) -> int: length = self._message_header_length - length += (self._segment_header_length + self._segment_footer_length) * self._num_segments + length += ( + self._segment_header_length + self._segment_footer_length + ) * self._num_segments length += self.content_length length += self._message_footer_length return length @@ -311,22 +374,28 @@ def _get_metadata_region(self, region: SMRegion) -> bytes: self.message_version, self.message_length, self.flags, - self._num_segments) + self._num_segments, + ) if region == SMRegion.SEGMENT_HEADER: - segment_size = min(self._segment_size, self.content_length - self._content_offset) + segment_size = min( + self._segment_size, 
self.content_length - self._content_offset + ) return generate_segment_header(self._current_segment_number, segment_size) if region == SMRegion.SEGMENT_FOOTER: if StructuredMessageProperties.CRC64 in self.flags: return self._segment_crc64s[self._current_segment_number].to_bytes( - StructuredMessageConstants.CRC64_LENGTH, 'little') - return b'' + StructuredMessageConstants.CRC64_LENGTH, "little" + ) + return b"" if region == SMRegion.MESSAGE_FOOTER: if StructuredMessageProperties.CRC64 in self.flags: - return self._message_crc64.to_bytes(StructuredMessageConstants.CRC64_LENGTH, 'little') - return b'' + return self._message_crc64.to_bytes( + StructuredMessageConstants.CRC64_LENGTH, "little" + ) + return b"" raise ValueError(f"Invalid metadata SMRegion {self._current_region}") @@ -352,16 +421,22 @@ def _advance_region(self, current: SMRegion): self._update_current_region_length() - def _read_metadata_region(self, region: SMRegion, size: int, output: BytesIO) -> int: + def _read_metadata_region( + self, region: SMRegion, size: int, output: BytesIO + ) -> int: metadata = self._get_metadata_region(region) read_size = min(size, self._current_region_length - self._current_region_offset) - content = metadata[self._current_region_offset: self._current_region_offset + read_size] + content = metadata[ + self._current_region_offset : self._current_region_offset + read_size + ] output.write(content) self._current_region_offset += read_size - if (self._current_region_offset == self._current_region_length and - self._current_region != SMRegion.MESSAGE_FOOTER): + if ( + self._current_region_offset == self._current_region_length + and self._current_region != SMRegion.MESSAGE_FOOTER + ): self._advance_region(region) return read_size @@ -383,8 +458,9 @@ def _read_content(self, size: int, output: BytesIO) -> int: if StructuredMessageProperties.CRC64 in self.flags: if checksum_offset == 0: - self._segment_crc64s[self._current_segment_number] = \ - calculate_crc64(content, 
self._segment_crc64s[self._current_segment_number]) + self._segment_crc64s[self._current_segment_number] = calculate_crc64( + content, self._segment_crc64s[self._current_segment_number] + ) self._message_crc64 = calculate_crc64(content, self._message_crc64) self._content_offset += read_size @@ -425,14 +501,22 @@ class StructuredMessageDecoder(IOBase): # pylint: disable=too-many-instance-att _segment_content_offset: int _block_size: int - def __init__(self, inner_iterator: Iterator[bytes], content_length: int, *, block_size: int = 4096) -> None: + def __init__( + self, + inner_iterator: Iterator[bytes], + content_length: int, + *, + block_size: int = 4096, + ) -> None: self.message_length = content_length # The stream should be at least long enough to hold minimum header length if self.message_length < StructuredMessageConstants.V1_HEADER_LENGTH: - raise ValueError("Content not long enough to contain a valid message header.") + raise ValueError( + "Content not long enough to contain a valid message header." 
+ ) self._inner_iterator = inner_iterator - self._buffer = b'' + self._buffer = b"" self._message_offset = 0 self._message_crc64 = 0 @@ -453,11 +537,19 @@ def _segment_header_length(self) -> int: @property def _segment_footer_length(self) -> int: - return StructuredMessageConstants.CRC64_LENGTH if StructuredMessageProperties.CRC64 in self.flags else 0 + return ( + StructuredMessageConstants.CRC64_LENGTH + if StructuredMessageProperties.CRC64 in self.flags + else 0 + ) @property def _message_footer_length(self) -> int: - return StructuredMessageConstants.CRC64_LENGTH if StructuredMessageProperties.CRC64 in self.flags else 0 + return ( + StructuredMessageConstants.CRC64_LENGTH + if StructuredMessageProperties.CRC64 in self.flags + else 0 + ) @property def _end_of_segment_content(self) -> bool: @@ -483,7 +575,7 @@ def read(self, size: int = -1) -> bytes: raise ValueError("Stream is closed") if size == 0 or self._message_offset >= self.message_length: - return b'' + return b"" if size < 0: size = sys.maxsize @@ -496,17 +588,23 @@ def read(self, size: int = -1) -> bytes: if self._end_of_segment_content: self._read_segment_footer() if self.num_segments > 1: - raise ValueError("First message segment was empty but more segments were detected.") + raise ValueError( + "First message segment was empty but more segments were detected." 
+ ) self._read_message_footer() - return b'' + return b"" count = 0 content = BytesIO() - while count < size and not (self._end_of_segment_content and self._message_offset == self.message_length): + while count < size and not ( + self._end_of_segment_content and self._message_offset == self.message_length + ): if self._end_of_segment_content: self._read_segment_header() - segment_remaining = self._segment_content_length - self._segment_content_offset + segment_remaining = ( + self._segment_content_length - self._segment_content_offset + ) read_size = min(segment_remaining, size - count) segment_content = self._read_from_inner(read_size) @@ -514,8 +612,12 @@ def read(self, size: int = -1) -> bytes: # Update the running CRC64 for the segment and message if StructuredMessageProperties.CRC64 in self.flags: - self._segment_crc64 = calculate_crc64(segment_content, self._segment_crc64) - self._message_crc64 = calculate_crc64(segment_content, self._message_crc64) + self._segment_crc64 = calculate_crc64( + segment_content, self._segment_crc64 + ) + self._message_crc64 = calculate_crc64( + segment_content, self._message_crc64 + ) self._segment_content_offset += read_size self._message_offset += read_size @@ -529,7 +631,10 @@ def read(self, size: int = -1) -> bytes: # One final check to ensure if we think we've reached the end of the stream # that the current segment number matches the total. - if self._message_offset == self.message_length and self._segment_number != self.num_segments: + if ( + self._message_offset == self.message_length + and self._segment_number != self.num_segments + ): raise ValueError("Invalid structured message data detected.") return content.getvalue() @@ -543,7 +648,9 @@ def _read_from_inner(self, size: int) -> bytes: self._buffer += chunk if len(self._buffer) < size: - raise ValueError("Invalid structured message data detected. Stream content incomplete.") + raise ValueError( + "Invalid structured message data detected. Stream content incomplete." 
+ ) data = self._buffer[:size] self._buffer = self._buffer[size:] @@ -552,7 +659,8 @@ def _read_from_inner(self, size: int) -> bytes: def _read_message_header(self) -> None: header_data = self._read_from_inner(StructuredMessageConstants.V1_HEADER_LENGTH) self.message_version, self.flags, self.num_segments = parse_message_header( - header_data, self.message_length) + header_data, self.message_length + ) self._message_offset += StructuredMessageConstants.V1_HEADER_LENGTH def _read_message_footer(self) -> None: @@ -564,16 +672,19 @@ def _read_message_footer(self) -> None: if StructuredMessageProperties.CRC64 in self.flags: message_crc = self._read_from_inner(StructuredMessageConstants.CRC64_LENGTH) - if self._message_crc64 != int.from_bytes(message_crc, 'little'): - raise ValueError("CRC64 mismatch detected in message trailer. " - "All data read should be considered invalid.") + if self._message_crc64 != int.from_bytes(message_crc, "little"): + raise ValueError( + "CRC64 mismatch detected in message trailer. " + "All data read should be considered invalid." + ) self._message_offset += self._message_footer_length def _read_segment_header(self) -> None: header_data = self._read_from_inner(self._segment_header_length) self._segment_number, self._segment_content_length = parse_segment_header( - header_data, self._segment_number + 1) + header_data, self._segment_number + 1 + ) self._message_offset += self._segment_header_length self._segment_content_offset = 0 @@ -583,8 +694,10 @@ def _read_segment_footer(self) -> None: if StructuredMessageProperties.CRC64 in self.flags: segment_crc = self._read_from_inner(StructuredMessageConstants.CRC64_LENGTH) - if self._segment_crc64 != int.from_bytes(segment_crc, 'little'): - raise ValueError(f"CRC64 mismatch detected in segment {self._segment_number}. 
" - f"All data read should be considered invalid.") + if self._segment_crc64 != int.from_bytes(segment_crc, "little"): + raise ValueError( + f"CRC64 mismatch detected in segment {self._segment_number}. " + f"All data read should be considered invalid." + ) self._message_offset += self._segment_footer_length diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/streams_async.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/streams_async.py index 0bd608d02379..ee7d92d14d77 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/streams_async.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/streams_async.py @@ -8,11 +8,18 @@ from io import BytesIO, IOBase from typing import AsyncIterator -from .streams import StructuredMessageConstants, StructuredMessageProperties, parse_message_header, parse_segment_header +from .streams import ( + StructuredMessageConstants, + StructuredMessageProperties, + parse_message_header, + parse_segment_header, +) from .validation import calculate_crc64 -class AsyncStructuredMessageDecoder(IOBase): # pylint: disable=too-many-instance-attributes +class AsyncStructuredMessageDecoder( + IOBase +): # pylint: disable=too-many-instance-attributes message_version: int """The version of the structured message.""" @@ -33,14 +40,22 @@ class AsyncStructuredMessageDecoder(IOBase): # pylint: disable=too-many-instanc _segment_content_offset: int _block_size: int - def __init__(self, inner_iterator: AsyncIterator[bytes], content_length: int, *, block_size: int = 4096) -> None: + def __init__( + self, + inner_iterator: AsyncIterator[bytes], + content_length: int, + *, + block_size: int = 4096, + ) -> None: self.message_length = content_length # The stream should be at least long enough to hold minimum header length if self.message_length < StructuredMessageConstants.V1_HEADER_LENGTH: - raise ValueError("Content not long enough to contain a valid message header.") + raise ValueError( + "Content 
not long enough to contain a valid message header." + ) self._inner_iterator = inner_iterator - self._buffer = b'' + self._buffer = b"" self._message_offset = 0 self._message_crc64 = 0 @@ -61,11 +76,19 @@ def _segment_header_length(self) -> int: @property def _segment_footer_length(self) -> int: - return StructuredMessageConstants.CRC64_LENGTH if StructuredMessageProperties.CRC64 in self.flags else 0 + return ( + StructuredMessageConstants.CRC64_LENGTH + if StructuredMessageProperties.CRC64 in self.flags + else 0 + ) @property def _message_footer_length(self) -> int: - return StructuredMessageConstants.CRC64_LENGTH if StructuredMessageProperties.CRC64 in self.flags else 0 + return ( + StructuredMessageConstants.CRC64_LENGTH + if StructuredMessageProperties.CRC64 in self.flags + else 0 + ) @property def _end_of_segment_content(self) -> bool: @@ -91,7 +114,7 @@ async def read(self, size: int = -1) -> bytes: raise ValueError("Stream is closed") if size == 0 or self._message_offset >= self.message_length: - return b'' + return b"" if size < 0: size = sys.maxsize @@ -104,17 +127,23 @@ async def read(self, size: int = -1) -> bytes: if self._end_of_segment_content: await self._read_segment_footer() if self.num_segments > 1: - raise ValueError("First message segment was empty but more segments were detected.") + raise ValueError( + "First message segment was empty but more segments were detected." 
+ ) await self._read_message_footer() - return b'' + return b"" count = 0 content = BytesIO() - while count < size and not (self._end_of_segment_content and self._message_offset == self.message_length): + while count < size and not ( + self._end_of_segment_content and self._message_offset == self.message_length + ): if self._end_of_segment_content: await self._read_segment_header() - segment_remaining = self._segment_content_length - self._segment_content_offset + segment_remaining = ( + self._segment_content_length - self._segment_content_offset + ) read_size = min(segment_remaining, size - count) segment_content = await self._read_from_inner(read_size) @@ -122,8 +151,12 @@ async def read(self, size: int = -1) -> bytes: # Update the running CRC64 for the segment and message if StructuredMessageProperties.CRC64 in self.flags: - self._segment_crc64 = calculate_crc64(segment_content, self._segment_crc64) - self._message_crc64 = calculate_crc64(segment_content, self._message_crc64) + self._segment_crc64 = calculate_crc64( + segment_content, self._segment_crc64 + ) + self._message_crc64 = calculate_crc64( + segment_content, self._message_crc64 + ) self._segment_content_offset += read_size self._message_offset += read_size @@ -137,7 +170,10 @@ async def read(self, size: int = -1) -> bytes: # One final check to ensure if we think we've reached the end of the stream # that the current segment number matches the total. - if self._message_offset == self.message_length and self._segment_number != self.num_segments: + if ( + self._message_offset == self.message_length + and self._segment_number != self.num_segments + ): raise ValueError("Invalid structured message data detected.") return content.getvalue() @@ -151,16 +187,21 @@ async def _read_from_inner(self, size: int) -> bytes: self._buffer += chunk if len(self._buffer) < size: - raise ValueError("Invalid structured message data detected. 
Stream content incomplete.") + raise ValueError( + "Invalid structured message data detected. Stream content incomplete." + ) data = self._buffer[:size] self._buffer = self._buffer[size:] return data async def _read_message_header(self) -> None: - header_data = await self._read_from_inner(StructuredMessageConstants.V1_HEADER_LENGTH) + header_data = await self._read_from_inner( + StructuredMessageConstants.V1_HEADER_LENGTH + ) self.message_version, self.flags, self.num_segments = parse_message_header( - header_data, self.message_length) + header_data, self.message_length + ) self._message_offset += StructuredMessageConstants.V1_HEADER_LENGTH async def _read_message_footer(self) -> None: @@ -170,18 +211,23 @@ async def _read_message_footer(self) -> None: raise ValueError("Invalid structured message data detected.") if StructuredMessageProperties.CRC64 in self.flags: - message_crc = await self._read_from_inner(StructuredMessageConstants.CRC64_LENGTH) + message_crc = await self._read_from_inner( + StructuredMessageConstants.CRC64_LENGTH + ) - if self._message_crc64 != int.from_bytes(message_crc, 'little'): - raise ValueError("CRC64 mismatch detected in message trailer. " - "All data read should be considered invalid.") + if self._message_crc64 != int.from_bytes(message_crc, "little"): + raise ValueError( + "CRC64 mismatch detected in message trailer. " + "All data read should be considered invalid." 
+ ) self._message_offset += self._message_footer_length async def _read_segment_header(self) -> None: header_data = await self._read_from_inner(self._segment_header_length) self._segment_number, self._segment_content_length = parse_segment_header( - header_data, self._segment_number + 1) + header_data, self._segment_number + 1 + ) self._message_offset += self._segment_header_length self._segment_content_offset = 0 @@ -189,10 +235,14 @@ async def _read_segment_header(self) -> None: async def _read_segment_footer(self) -> None: if StructuredMessageProperties.CRC64 in self.flags: - segment_crc = await self._read_from_inner(StructuredMessageConstants.CRC64_LENGTH) - - if self._segment_crc64 != int.from_bytes(segment_crc, 'little'): - raise ValueError(f"CRC64 mismatch detected in segment {self._segment_number}. " - f"All data read should be considered invalid.") + segment_crc = await self._read_from_inner( + StructuredMessageConstants.CRC64_LENGTH + ) + + if self._segment_crc64 != int.from_bytes(segment_crc, "little"): + raise ValueError( + f"CRC64 mismatch detected in segment {self._segment_number}. " + f"All data read should be considered invalid." 
+ ) self._message_offset += self._segment_footer_length diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/validation.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/validation.py index 329ef7517d9b..5370d9dd669c 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/validation.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/validation.py @@ -8,7 +8,7 @@ import hashlib from enum import Enum from io import SEEK_SET -from typing import IO, cast, Literal, Union +from typing import IO, Literal, Optional, Union, cast from azure.core import CaseInsensitiveEnumMeta @@ -21,8 +21,49 @@ class ChecksumAlgorithm(str, Enum, metaclass=CaseInsensitiveEnumMeta): MD5 = "md5" CRC64 = "crc64" + @classmethod + def list(cls): + return list(map(lambda c: c.value, cls)) -def is_md5_validation(validate_content: Union[bool, Literal["md5", "crc64"]]) -> bool: + +def _verify_extensions(module: str) -> None: + try: + import azure.storage.extensions # pylint: disable=unused-import + except ImportError as exc: + raise ValueError( + f"The use of {module} requires the azure-storage-extensions package to be installed. " + f"Please install this package and try again." 
+ ) from exc + + +def parse_validation_option( + validate_content: Optional[Union[bool, Literal["auto", "crc64", "md5"]]], +) -> Optional[Union[bool, Literal["auto", "crc64", "md5"]]]: + if validate_content is None: + return None + + # Legacy support for bool + if isinstance(validate_content, bool): + return validate_content + + if validate_content not in (ChecksumAlgorithm.list()): + raise ValueError("Invalid value for `validate_content` specified.") + + # Resolve auto + if validate_content == ChecksumAlgorithm.AUTO: + validate_content = ChecksumAlgorithm.CRC64.value + + if validate_content == ChecksumAlgorithm.CRC64: + _verify_extensions("crc64") + + return validate_content + + +def is_md5_validation( + validate_content: Optional[Union[bool, Literal["md5", "crc64"]]], +) -> bool: + if validate_content is None: + return False if isinstance(validate_content, bool): return validate_content return validate_content == ChecksumAlgorithm.MD5 @@ -61,4 +102,4 @@ def calculate_crc64_bytes(data: bytes) -> bytes: # Locally import to avoid error if not installed. 
from azure.storage.extensions import crc64 - return cast(bytes, crc64.compute(data, 0).to_bytes(CRC64_LENGTH, 'little')) + return cast(bytes, crc64.compute(data, 0).to_bytes(CRC64_LENGTH, "little")) diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_upload_helpers.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_upload_helpers.py index 64b0432da803..6873b93bb4e6 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_upload_helpers.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_upload_helpers.py @@ -5,7 +5,7 @@ # -------------------------------------------------------------------------- from io import SEEK_SET, UnsupportedOperation -from typing import Any, cast, Dict, IO, Optional, TypeVar, TYPE_CHECKING +from typing import Any, cast, Dict, IO, Literal, Optional, TypeVar, Union, TYPE_CHECKING from azure.core.exceptions import ResourceExistsError, ResourceModifiedError, HttpResponseError @@ -71,7 +71,7 @@ def upload_block_blob( # pylint: disable=too-many-locals, too-many-statements encryption_options: Dict[str, Any], blob_settings: "StorageConfiguration", headers: Dict[str, Any], - validate_content: bool, + validate_content: Optional[Union[bool, Literal['crc64', 'md5']]], max_concurrency: Optional[int], length: Optional[int] = None, **kwargs: Any @@ -125,7 +125,7 @@ def upload_block_blob( # pylint: disable=too-many-locals, too-many-statements use_original_upload_path = ( blob_settings.use_byte_buffer - or validate_content is not None + or validate_content not in (None, False) or encryption_options.get('required') or blob_settings.max_block_size < blob_settings.min_large_block_upload_threshold or hasattr(stream, 'seekable') and not stream.seekable() @@ -213,7 +213,7 @@ def upload_page_blob( headers: Dict[str, Any], stream: IO, length: Optional[int] = None, - validate_content: Optional[bool] = None, + validate_content: Optional[Union[bool, Literal['crc64', 'md5']]] = None, max_concurrency: Optional[int] = None, **kwargs: Any ) 
-> Dict[str, Any]: @@ -291,7 +291,7 @@ def upload_append_blob( # pylint: disable=unused-argument headers: Dict[str, Any], stream: IO, length: Optional[int] = None, - validate_content: Optional[bool] = None, + validate_content: Optional[Union[bool, Literal['crc64', 'md5']]] = None, max_concurrency: Optional[int] = None, **kwargs: Any ) -> Dict[str, Any]: diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_blob_client_async.py b/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_blob_client_async.py index f1143006ec69..ccf70c26b2e5 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_blob_client_async.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_blob_client_async.py @@ -77,6 +77,7 @@ from .._shared.base_client_async import AsyncStorageAccountHostsMixin, AsyncTransportWrapper, parse_connection_str from .._shared.policies_async import ExponentialRetry from .._shared.response_handlers import process_storage_error, return_response_headers +from .._shared.validation import ChecksumAlgorithm, parse_validation_option if TYPE_CHECKING: from azure.core import MatchConditions @@ -516,15 +517,11 @@ async def upload_blob( :keyword ~azure.storage.blob.ContentSettings content_settings: ContentSettings object used to set blob properties. Used to set content type, encoding, language, disposition, md5, and cache control. - :keyword bool validate_content: - If true, calculates an MD5 hash for each chunk of the blob. The storage - service checks the hash of the content that has arrived with the hash - that was sent. This is primarily valuable for detecting bitflips on - the wire if using http instead of https, as https (the default), will - already validate. Note that this MD5 hash is not stored with the - blob. Also note that if enabled, the memory-efficient upload algorithm - will not be used because computing the MD5 hash requires buffering - entire blocks, and doing so defeats the purpose of the memory-efficient algorithm. 
+ :keyword validate_content: + Enables checksum validation for the transfer. Any checksum calculated is NOT stored with the blob. + Choose "auto" (let the SDK choose the best algorithm), "crc64", or "md5". The use of bool is deprecated. + NOTE: The use of "auto" or "crc64" requires the `azure-storage-extensions` package to be installed. + :paramtype validate_content: Union[bool, Literal['auto', 'crc64', 'md5']] :keyword lease: If specified, upload_blob only succeeds if the blob's lease is active and matches this ID. @@ -629,6 +626,9 @@ async def upload_blob( raise ValueError("Encryption required but no key was provided.") if kwargs.get('cpk') and self.scheme.lower() != 'https': raise ValueError("Customer provided encryption key must be used over HTTPS.") + validate_content = parse_validation_option(kwargs.pop('validate_content', None)) + if validate_content == ChecksumAlgorithm.CRC64 and self.key_encryption_key: + raise ValueError("Using encryption and content validation together is not currently supported.") options = _upload_blob_options( data=data, blob_type=blob_type, @@ -640,6 +640,7 @@ async def upload_blob( 'key': self.key_encryption_key, 'resolver': self.key_resolver_function }, + validate_content=validate_content, config=self._config, sdk_moniker=self._sdk_moniker, client=self._client, @@ -696,15 +697,11 @@ async def download_blob( This keyword argument was introduced in API version '2019-12-12'. - :keyword bool validate_content: - If true, calculates an MD5 hash for each chunk of the blob. The storage - service checks the hash of the content that has arrived with the hash - that was sent. This is primarily valuable for detecting bitflips on - the wire if using http instead of https, as https (the default), will - already validate. Note that this MD5 hash is not stored with the - blob. 
Also note that if enabled, the memory-efficient upload algorithm - will not be used because computing the MD5 hash requires buffering - entire blocks, and doing so defeats the purpose of the memory-efficient algorithm. + :keyword validate_content: + Enables checksum validation for the transfer. Any checksum calculated is NOT stored with the blob. + Choose "auto" (let the SDK choose the best algorithm), "crc64", or "md5". The use of bool is deprecated. + NOTE: The use of "auto" or "crc64" requires the `azure-storage-extensions` package to be installed. + :paramtype validate_content: Union[bool, Literal['auto', 'crc64', 'md5']] :keyword lease: Required if the blob has an active lease. If specified, download_blob only succeeds if the blob's lease is active and matches this ID. Value can be a @@ -778,6 +775,9 @@ async def download_blob( raise ValueError("Offset value must not be None if length is set.") if kwargs.get('cpk') and self.scheme.lower() != 'https': raise ValueError("Customer provided encryption key must be used over HTTPS.") + validate_content = parse_validation_option(kwargs.pop('validate_content', None)) + if validate_content == ChecksumAlgorithm.CRC64 and self.key_encryption_key: + raise ValueError("Using encryption and content validation together is not currently supported.") options = _download_blob_options( blob_name=self.blob_name, container_name=self.container_name, @@ -791,6 +791,7 @@ async def download_blob( 'key': self.key_encryption_key, 'resolver': self.key_resolver_function }, + validate_content=validate_content, config=self._config, sdk_moniker=self._sdk_moniker, client=self._client, @@ -2053,15 +2054,11 @@ async def stage_block( :param int length: Size of the block. Optional if the length of data can be determined. For Iterable and IO, if the length is not provided and cannot be determined, all data will be read into memory. - :keyword bool validate_content: - If true, calculates an MD5 hash for each chunk of the blob. 
The storage - service checks the hash of the content that has arrived with the hash - that was sent. This is primarily valuable for detecting bitflips on - the wire if using http instead of https, as https (the default), will - already validate. Note that this MD5 hash is not stored with the - blob. Also note that if enabled, the memory-efficient upload algorithm - will not be used because computing the MD5 hash requires buffering - entire blocks, and doing so defeats the purpose of the memory-efficient algorithm. + :keyword validate_content: + Enables checksum validation for the transfer. Any checksum calculated is NOT stored with the blob. + Choose "auto" (let the SDK choose the best algorithm), "crc64", or "md5". The use of bool is deprecated. + NOTE: The use of "auto" or "crc64" requires the `azure-storage-extensions` package to be installed. + :paramtype validate_content: Union[bool, Literal['auto', 'crc64', 'md5']] :keyword lease: Required if the blob has an active lease. Value can be a BlobLeaseClient object or the lease ID as a string. @@ -2896,13 +2893,11 @@ async def upload_page( Required if the blob has an active lease. Value can be a BlobLeaseClient object or the lease ID as a string. :paramtype lease: ~azure.storage.blob.aio.BlobLeaseClient or str - :keyword bool validate_content: - If true, calculates an MD5 hash of the page content. The storage - service checks the hash of the content that has arrived - with the hash that was sent. This is primarily valuable for detecting - bitflips on the wire if using http instead of https, as https (the default), - will already validate. Note that this MD5 hash is not stored with the - blob. + :keyword validate_content: + Enables checksum validation for the transfer. Any checksum calculated is NOT stored with the blob. + Choose "auto" (let the SDK choose the best algorithm), "crc64", or "md5". The use of bool is deprecated. 
+ NOTE: The use of "auto" or "crc64" requires the `azure-storage-extensions` package to be installed. + :paramtype validate_content: Union[bool, Literal['auto', 'crc64', 'md5']] :keyword int if_sequence_number_lte: If the blob's sequence number is less than or equal to the specified value, the request proceeds; otherwise it fails. @@ -3204,13 +3199,11 @@ async def append_block( :param int length: Size of the block. Optional if the length of data can be determined. For Iterable and IO, if the length is not provided and cannot be determined, all data will be read into memory. - :keyword bool validate_content: - If true, calculates an MD5 hash of the block content. The storage - service checks the hash of the content that has arrived - with the hash that was sent. This is primarily valuable for detecting - bitflips on the wire if using http instead of https, as https (the default), - will already validate. Note that this MD5 hash is not stored with the - blob. + :keyword validate_content: + Enables checksum validation for the transfer. Any checksum calculated is NOT stored with the blob. + Choose "auto" (let the SDK choose the best algorithm), "crc64", or "md5". The use of bool is deprecated. + NOTE: The use of "auto" or "crc64" requires the `azure-storage-extensions` package to be installed. + :paramtype validate_content: Union[bool, Literal['auto', 'crc64', 'md5']] :keyword int maxsize_condition: Optional conditional header. The max length in bytes permitted for the append blob. 
If the Append Block operation would cause the blob diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_blob_client_async.pyi b/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_blob_client_async.pyi index 9c4a37007f65..ba9a7460425c 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_blob_client_async.pyi +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_blob_client_async.pyi @@ -175,7 +175,7 @@ class BlobClient( # type: ignore[misc] tags: Optional[Dict[str, str]] = None, overwrite: bool = False, content_settings: Optional[ContentSettings] = None, - validate_content: bool = False, + validate_content: Optional[Union[bool, Literal['auto', 'crc64', 'md5']]] = None, lease: Optional[BlobLeaseClient] = None, if_modified_since: Optional[datetime] = None, if_unmodified_since: Optional[datetime] = None, @@ -202,7 +202,7 @@ class BlobClient( # type: ignore[misc] length: Optional[int] = None, *, version_id: Optional[str] = None, - validate_content: bool = False, + validate_content: Optional[Union[bool, Literal['auto', 'crc64', 'md5']]] = None, lease: Optional[Union[BlobLeaseClient, str]] = None, if_modified_since: Optional[datetime] = None, if_unmodified_since: Optional[datetime] = None, @@ -224,7 +224,7 @@ class BlobClient( # type: ignore[misc] length: Optional[int] = None, *, version_id: Optional[str] = None, - validate_content: bool = False, + validate_content: Optional[Union[bool, Literal['auto', 'crc64', 'md5']]] = None, lease: Optional[Union[BlobLeaseClient, str]] = None, if_modified_since: Optional[datetime] = None, if_unmodified_since: Optional[datetime] = None, @@ -246,7 +246,7 @@ class BlobClient( # type: ignore[misc] length: Optional[int] = None, *, version_id: Optional[str] = None, - validate_content: bool = False, + validate_content: Optional[Union[bool, Literal['auto', 'crc64', 'md5']]] = None, lease: Optional[Union[BlobLeaseClient, str]] = None, if_modified_since: Optional[datetime] = None, if_unmodified_since: 
Optional[datetime] = None, @@ -470,7 +470,7 @@ class BlobClient( # type: ignore[misc] data: Union[bytes, Iterable[bytes], IO[bytes]], length: Optional[int] = None, *, - validate_content: Optional[bool] = None, + validate_content: Optional[Union[bool, Literal['auto', 'crc64', 'md5']]] = None, lease: Optional[Union[BlobLeaseClient, str]] = None, encoding: Optional[str] = None, cpk: Optional[CustomerProvidedEncryptionKey] = None, @@ -655,7 +655,7 @@ class BlobClient( # type: ignore[misc] length: int, *, lease: Optional[Union[BlobLeaseClient, str]] = None, - validate_content: Optional[bool] = None, + validate_content: Optional[Union[bool, Literal['auto', 'crc64', 'md5']]] = None, if_sequence_number_lte: Optional[int] = None, if_sequence_number_lt: Optional[int] = None, if_sequence_number_eq: Optional[int] = None, @@ -725,7 +725,7 @@ class BlobClient( # type: ignore[misc] data: Union[bytes, Iterable[bytes], IO[bytes]], length: Optional[int] = None, *, - validate_content: Optional[bool] = None, + validate_content: Optional[Union[bool, Literal['auto', 'crc64', 'md5']]] = None, maxsize_condition: Optional[int] = None, appendpos_condition: Optional[int] = None, lease: Optional[Union[BlobLeaseClient, str]] = None, diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_container_client_async.py b/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_container_client_async.py index e08abc8d3ca6..5ba5832ed0a0 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_container_client_async.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_container_client_async.py @@ -1025,15 +1025,11 @@ async def upload_blob( :keyword ~azure.storage.blob.ContentSettings content_settings: ContentSettings object used to set blob properties. Used to set content type, encoding, language, disposition, md5, and cache control. - :keyword bool validate_content: - If true, calculates an MD5 hash for each chunk of the blob. 
The storage - service checks the hash of the content that has arrived with the hash - that was sent. This is primarily valuable for detecting bitflips on - the wire if using http instead of https, as https (the default), will - already validate. Note that this MD5 hash is not stored with the - blob. Also note that if enabled, the memory-efficient upload algorithm - will not be used, because computing the MD5 hash requires buffering - entire blocks, and doing so defeats the purpose of the memory-efficient algorithm. + :keyword validate_content: + Enables checksum validation for the transfer. Any checksum calculated is NOT stored with the blob. + Choose "auto" (let the SDK choose the best algorithm), "crc64", or "md5". The use of bool is deprecated. + NOTE: The use of "auto" or "crc64" requires the `azure-storage-extensions` package to be installed. + :paramtype validate_content: Union[bool, Literal['auto', 'crc64', 'md5']] :keyword lease: Required if the container has an active lease. Value can be a BlobLeaseClient object or the lease ID as a string. @@ -1271,15 +1267,11 @@ async def download_blob( This keyword argument was introduced in API version '2019-12-12'. - :keyword bool validate_content: - If true, calculates an MD5 hash for each chunk of the blob. The storage - service checks the hash of the content that has arrived with the hash - that was sent. This is primarily valuable for detecting bitflips on - the wire if using http instead of https, as https (the default), will - already validate. Note that this MD5 hash is not stored with the - blob. Also note that if enabled, the memory-efficient upload algorithm - will not be used because computing the MD5 hash requires buffering - entire blocks, and doing so defeats the purpose of the memory-efficient algorithm. + :keyword validate_content: + Enables checksum validation for the transfer. Any checksum calculated is NOT stored with the blob. 
+ Choose "auto" (let the SDK choose the best algorithm), "crc64", or "md5". The use of bool is deprecated. + NOTE: The use of "auto" or "crc64" requires the `azure-storage-extensions` package to be installed. + :paramtype validate_content: Union[bool, Literal['auto', 'crc64', 'md5']] :keyword lease: Required if the blob has an active lease. If specified, download_blob only succeeds if the blob's lease is active and matches this ID. Value can be a diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_container_client_async.pyi b/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_container_client_async.pyi index f4be54eaea38..1c49b69ecb79 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_container_client_async.pyi +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_container_client_async.pyi @@ -16,6 +16,7 @@ from typing import ( Callable, Dict, List, + Literal, IO, Iterable, Optional, @@ -258,7 +259,7 @@ class ContainerClient( # type: ignore[misc] *, overwrite: Optional[bool] = None, content_settings: Optional[ContentSettings] = None, - validate_content: Optional[bool] = None, + validate_content: Optional[Union[bool, Literal['auto', 'crc64', 'md5']]] = None, lease: Optional[Union[BlobLeaseClient, str]] = None, if_modified_since: Optional[datetime] = None, if_unmodified_since: Optional[datetime] = None, @@ -300,7 +301,7 @@ class ContainerClient( # type: ignore[misc] length: Optional[int] = None, *, version_id: Optional[str] = None, - validate_content: Optional[bool] = None, + validate_content: Optional[Union[bool, Literal['auto', 'crc64', 'md5']]] = None, lease: Optional[Union[BlobLeaseClient, str]] = None, if_modified_since: Optional[datetime] = None, if_unmodified_since: Optional[datetime] = None, @@ -322,7 +323,7 @@ class ContainerClient( # type: ignore[misc] length: Optional[int] = None, *, version_id: Optional[str] = None, - validate_content: Optional[bool] = None, + validate_content: Optional[Union[bool, 
Literal['auto', 'crc64', 'md5']]] = None, lease: Optional[Union[BlobLeaseClient, str]] = None, if_modified_since: Optional[datetime] = None, if_unmodified_since: Optional[datetime] = None, @@ -344,7 +345,7 @@ class ContainerClient( # type: ignore[misc] length: Optional[int] = None, *, version_id: Optional[str] = None, - validate_content: Optional[bool] = None, + validate_content: Optional[Union[bool, Literal['auto', 'crc64', 'md5']]] = None, lease: Optional[Union[BlobLeaseClient, str]] = None, if_modified_since: Optional[datetime] = None, if_unmodified_since: Optional[datetime] = None, diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_download_async.py b/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_download_async.py index cf14fdb07b5b..6168ec93e9e5 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_download_async.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_download_async.py @@ -15,7 +15,7 @@ from typing import ( Any, AsyncIterator, Awaitable, Generator, Callable, cast, Dict, - Generic, IO, Optional, overload, + Generic, IO, Literal, Optional, overload, Tuple, TypeVar, Union, TYPE_CHECKING ) @@ -239,7 +239,7 @@ def __init__( config: "StorageConfiguration" = None, # type: ignore [assignment] start_range: Optional[int] = None, end_range: Optional[int] = None, - validate_content: bool = None, # type: ignore [assignment] + validate_content: Optional[Union[bool, Literal['crc64', 'md5']]] = None, encryption_options: Dict[str, Any] = None, # type: ignore [assignment] max_concurrency: Optional[int] = None, name: str = None, # type: ignore [assignment] diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_upload_helpers.py b/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_upload_helpers.py index dc7b35b04307..5b551fdec2fb 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_upload_helpers.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_upload_helpers.py @@ 
-6,7 +6,7 @@ import inspect from io import SEEK_SET, UnsupportedOperation -from typing import Any, cast, Dict, IO, Optional, TypeVar, TYPE_CHECKING +from typing import Any, cast, Dict, IO, Literal, Optional, TypeVar, Union, TYPE_CHECKING from azure.core.exceptions import HttpResponseError, ResourceModifiedError @@ -47,7 +47,7 @@ async def upload_block_blob( # pylint: disable=too-many-locals, too-many-statem encryption_options: Dict[str, Any], blob_settings: "StorageConfiguration", headers: Dict[str, Any], - validate_content: bool, + validate_content: Optional[Union[bool, Literal['crc64', 'md5']]], max_concurrency: Optional[int], length: Optional[int] = None, **kwargs: Any @@ -105,7 +105,7 @@ async def upload_block_blob( # pylint: disable=too-many-locals, too-many-statem use_original_upload_path = ( blob_settings.use_byte_buffer - or validate_content is not None + or validate_content not in (None, False) or encryption_options.get('required') or blob_settings.max_block_size < blob_settings.min_large_block_upload_threshold or hasattr(stream, 'seekable') and not stream.seekable() @@ -193,7 +193,7 @@ async def upload_page_blob( headers: Dict[str, Any], stream: IO, length: Optional[int] = None, - validate_content: Optional[bool] = None, + validate_content: Optional[Union[bool, Literal['crc64', 'md5']]] = None, max_concurrency: Optional[int] = None, **kwargs: Any ) -> Dict[str, Any]: @@ -271,7 +271,7 @@ async def upload_append_blob( # pylint: disable=unused-argument headers: Dict[str, Any], stream: IO, length: Optional[int] = None, - validate_content: Optional[bool] = None, + validate_content: Optional[Union[bool, Literal['crc64', 'md5']]] = None, max_concurrency: Optional[int] = None, **kwargs: Any ) -> Dict[str, Any]: diff --git a/sdk/storage/azure-storage-blob/dev_requirements.txt b/sdk/storage/azure-storage-blob/dev_requirements.txt index de5100414c4c..d2188b6513c7 100644 --- a/sdk/storage/azure-storage-blob/dev_requirements.txt +++ 
b/sdk/storage/azure-storage-blob/dev_requirements.txt @@ -1,5 +1,6 @@ -e ../../../eng/tools/azure-sdk-tools ../../core/azure-core ../../identity/azure-identity +../azure-storage-extensions azure-mgmt-storage==20.1.0 aiohttp>=3.0 \ No newline at end of file diff --git a/sdk/storage/azure-storage-blob/tests/test_content_validation.py b/sdk/storage/azure-storage-blob/tests/test_content_validation.py index 4d24bd015f4f..5d670ba330fb 100644 --- a/sdk/storage/azure-storage-blob/tests/test_content_validation.py +++ b/sdk/storage/azure-storage-blob/tests/test_content_validation.py @@ -9,6 +9,7 @@ import pytest from azure.storage.blob import ( BlobBlock, + BlobClient, BlobServiceClient, BlobType, ContainerClient @@ -20,6 +21,7 @@ StorageRecordedTestCase ) +from encryption_test_helper import KeyWrapper from settings.testcase import BlobPreparer @@ -90,28 +92,29 @@ def teardown_method(self, _): def _get_blob_reference(self): return self.get_resource_name('blob') - # TODO: This test coming later - # @BlobPreparer() - # def test_encryption_blocked_crc64(self, **kwargs): - # storage_account_name = kwargs.pop("storage_account_name") - # storage_account_key = kwargs.pop("storage_account_key") - - # kek = KeyWrapper('key1') - # blob = BlobClient( - # self.account_url(storage_account_name, "blob"), - # "testing", - # "testing", - # credential=storage_account_key, - # require_encryption=True, - # encryption_version='2.0', - # key_encryption_key=kek) - - # with pytest.raises(ValueError): - # blob.upload_blob(b'123', validate_content='crc64') + @BlobPreparer() + def test_encryption_blocked_crc64(self, **kwargs): + storage_account_name = kwargs.pop("storage_account_name") + + kek = KeyWrapper('key1') + blob = BlobClient( + self.account_url(storage_account_name, "blob"), + "testing", + "testing", + credential=self.get_credential(BlobServiceClient), + require_encryption=True, + encryption_version='2.0', + key_encryption_key=kek) + + with pytest.raises(ValueError): + 
blob.upload_blob(b'123', validate_content='crc64') + + # Needed for teardown + self.container = None @BlobPreparer() @pytest.mark.parametrize('a', [BlobType.BLOCKBLOB, BlobType.PAGEBLOB, BlobType.APPENDBLOB]) # a: blob_type - @pytest.mark.parametrize('b', [True, 'md5', 'crc64']) # b: validate_content + @pytest.mark.parametrize('b', [True, 'auto','md5', 'crc64']) # b: validate_content @GenericTestProxyParametrize2() @recorded_by_proxy def test_upload_blob(self, a, b, **kwargs): @@ -119,7 +122,7 @@ def test_upload_blob(self, a, b, **kwargs): self._setup(storage_account_name) blob = self.container.get_blob_client(self._get_blob_reference()) - assert_method = assert_content_crc64 if b == 'crc64' else assert_content_md5 + assert_method = assert_content_crc64 if b in ('auto', 'crc64') else assert_content_md5 # Test supported data types byte_data = b'abc' * 512 @@ -200,7 +203,7 @@ def test_upload_blob_substream(self, a, **kwargs): assert content.read() == data @BlobPreparer() - @pytest.mark.parametrize('a', [True, 'md5', 'crc64']) # a: validate_content + @pytest.mark.parametrize('a', [True, 'auto', 'md5', 'crc64']) # a: validate_content @GenericTestProxyParametrize1() @recorded_by_proxy def test_stage_block(self, a, **kwargs): @@ -217,7 +220,7 @@ def generator(): for i in range(0, len(data1), 500): yield data1[i: i + 500] - assert_method = assert_content_crc64 if a == 'crc64' else assert_content_md5 + assert_method = assert_content_crc64 if a in ('auto', 'crc64') else assert_content_md5 blob.stage_block('1', data1, validate_content=a, raw_request_hook=assert_method) blob.stage_block('2', data2, encoding='utf-8-sig', validate_content=a, raw_request_hook=assert_method) @@ -270,7 +273,7 @@ def test_stage_block_streaming_large(self, a, **kwargs): assert result.read() == data1 + data2 + data3 @BlobPreparer() - @pytest.mark.parametrize('a', [True, 'md5', 'crc64']) # a: validate_content + @pytest.mark.parametrize('a', [True, 'auto', 'md5', 'crc64']) # a: validate_content 
@GenericTestProxyParametrize1() @recorded_by_proxy def test_append_block(self, a, **kwargs): @@ -287,7 +290,7 @@ def generator(): for i in range(0, len(data1), 500): yield data1[i: i + 500] - assert_method = assert_content_crc64 if a == 'crc64' else assert_content_md5 + assert_method = assert_content_crc64 if a in ('auto', 'crc64') else assert_content_md5 blob.create_append_blob() blob.append_block(data1, validate_content=a, raw_request_hook=assert_method) @@ -339,7 +342,7 @@ def test_append_block_streaming_large(self, a, **kwargs): assert result.read() == data1 + data2 + data3 @BlobPreparer() - @pytest.mark.parametrize('a', [True, 'md5', 'crc64']) # a: validate_content + @pytest.mark.parametrize('a', [True, 'auto', 'md5', 'crc64']) # a: validate_content @GenericTestProxyParametrize1() @recorded_by_proxy def test_upload_page(self, a, **kwargs): @@ -350,7 +353,7 @@ def test_upload_page(self, a, **kwargs): data1 = b'abc' * 512 data2 = "你好世界abcd" * 32 data2_encoded = data2.encode('utf-8') - assert_method = assert_content_crc64 if a == 'crc64' else assert_content_md5 + assert_method = assert_content_crc64 if a in ('auto', 'crc64') else assert_content_md5 # Act blob.create_page_blob(5 * 1024) @@ -362,7 +365,7 @@ def test_upload_page(self, a, **kwargs): assert content.read() == data1 + data2_encoded @BlobPreparer() - @pytest.mark.parametrize('a', [True, 'md5', 'crc64']) # a: validate_content + @pytest.mark.parametrize('a', [True, 'auto', 'md5', 'crc64']) # a: validate_content @GenericTestProxyParametrize1() @recorded_by_proxy def test_download_blob(self, a, **kwargs): @@ -372,7 +375,7 @@ def test_download_blob(self, a, **kwargs): blob = self.container.get_blob_client(self._get_blob_reference()) data = b'abc' * 512 blob.upload_blob(data, overwrite=True) - assert_method = assert_structured_message_get if a == 'crc64' else assert_content_md5_get + assert_method = assert_structured_message_get if a in ('auto', 'crc64') else assert_content_md5_get # Act downloader = 
blob.download_blob(validate_content=a, raw_response_hook=assert_method) diff --git a/sdk/storage/azure-storage-blob/tests/test_content_validation_async.py b/sdk/storage/azure-storage-blob/tests/test_content_validation_async.py index a0fdcbee01d8..b9de57c7015d 100644 --- a/sdk/storage/azure-storage-blob/tests/test_content_validation_async.py +++ b/sdk/storage/azure-storage-blob/tests/test_content_validation_async.py @@ -10,6 +10,7 @@ from azure.core.exceptions import ResourceExistsError from azure.storage.blob import BlobBlock, BlobType from azure.storage.blob.aio import ( + BlobClient, BlobServiceClient, ContainerClient ) @@ -19,6 +20,7 @@ GenericTestProxyParametrize1, GenericTestProxyParametrize2 ) +from encryption_test_helper import KeyWrapper from settings.testcase import BlobPreparer from test_content_validation import ( @@ -55,28 +57,26 @@ async def _teardown(self): def _get_blob_reference(self): return self.get_resource_name('blob') - # TODO: This test coming later - # @BlobPreparer() - # async def test_encryption_blocked_crc64(self, **kwargs): - # storage_account_name = kwargs.pop("storage_account_name") - # storage_account_key = kwargs.pop("storage_account_key") - - # kek = KeyWrapper('key1') - # blob = BlobClient( - # self.account_url(storage_account_name, "blob"), - # "testing", - # "testing", - # credential=storage_account_key, - # require_encryption=True, - # encryption_version='2.0', - # key_encryption_key=kek) - - # with pytest.raises(ValueError): - # await blob.upload_blob(b'123', validate_content='crc64') + @BlobPreparer() + async def test_encryption_blocked_crc64(self, **kwargs): + storage_account_name = kwargs.pop("storage_account_name") + + kek = KeyWrapper('key1') + blob = BlobClient( + self.account_url(storage_account_name, "blob"), + "testing", + "testing", + credential=self.get_credential(BlobServiceClient, is_async=True), + require_encryption=True, + encryption_version='2.0', + key_encryption_key=kek) + + with pytest.raises(ValueError): + 
await blob.upload_blob(b'123', validate_content='crc64') @BlobPreparer() @pytest.mark.parametrize('a', [BlobType.BLOCKBLOB, BlobType.PAGEBLOB, BlobType.APPENDBLOB]) # a: blob_type - @pytest.mark.parametrize('b', [True, 'md5', 'crc64']) # b: validate_content + @pytest.mark.parametrize('b', [True, "auto", 'md5', 'crc64']) # b: validate_content @GenericTestProxyParametrize2() @recorded_by_proxy_async async def test_upload_blob(self, a, b, **kwargs): @@ -84,7 +84,7 @@ async def test_upload_blob(self, a, b, **kwargs): await self._setup(storage_account_name) blob = self.container.get_blob_client(self._get_blob_reference()) - assert_method = assert_content_crc64 if b == 'crc64' else assert_content_md5 + assert_method = assert_content_crc64 if b in ('auto', 'crc64') else assert_content_md5 # Test supported data types byte_data = b'abc' * 512 @@ -172,7 +172,7 @@ async def test_upload_blob_substream(self, a, **kwargs): await self._teardown() @BlobPreparer() - @pytest.mark.parametrize('a', [True, 'md5', 'crc64']) # a: validate_content + @pytest.mark.parametrize('a', [True, 'auto', 'md5', 'crc64']) # a: validate_content @GenericTestProxyParametrize1() @recorded_by_proxy_async async def test_stage_block(self, a, **kwargs): @@ -189,7 +189,7 @@ def generator(): for i in range(0, len(data1), 500): yield data1[i: i + 500] - assert_method = assert_content_crc64 if a == 'crc64' else assert_content_md5 + assert_method = assert_content_crc64 if a in ('auto', 'crc64') else assert_content_md5 # Act await blob.stage_block('1', data1, validate_content=a, raw_request_hook=assert_method) @@ -246,7 +246,7 @@ async def test_stage_block_streaming_large(self, a, **kwargs): assert await result.read() == data1 + data2 + data3 @BlobPreparer() - @pytest.mark.parametrize('a', [True, 'md5', 'crc64']) # a: validate_content + @pytest.mark.parametrize('a', [True, 'auto', 'md5', 'crc64']) # a: validate_content @GenericTestProxyParametrize1() @recorded_by_proxy_async async def test_append_block(self, a, 
**kwargs): @@ -263,7 +263,7 @@ def generator(): for i in range(0, len(data1), 500): yield data1[i: i + 500] - assert_method = assert_content_crc64 if a == 'crc64' else assert_content_md5 + assert_method = assert_content_crc64 if a in ('auto', 'crc64') else assert_content_md5 # Act await blob.create_append_blob() @@ -320,7 +320,7 @@ async def test_append_block_streaming_large(self, a, **kwargs): await self._teardown() @BlobPreparer() - @pytest.mark.parametrize('a', [True, 'md5', 'crc64']) # a: validate_content + @pytest.mark.parametrize('a', [True, 'auto', 'md5', 'crc64']) # a: validate_content @GenericTestProxyParametrize1() @recorded_by_proxy_async async def test_upload_page(self, a, **kwargs): @@ -331,7 +331,7 @@ async def test_upload_page(self, a, **kwargs): data1 = b'abc' * 512 data2 = "你好世界abcd" * 32 data2_encoded = data2.encode('utf-8') - assert_method = assert_content_crc64 if a == 'crc64' else assert_content_md5 + assert_method = assert_content_crc64 if a in ('auto', 'crc64') else assert_content_md5 # Act await blob.create_page_blob(5 * 1024) @@ -344,7 +344,7 @@ async def test_upload_page(self, a, **kwargs): await self._teardown() @BlobPreparer() - @pytest.mark.parametrize('a', [True, 'md5', 'crc64']) # a: validate_content + @pytest.mark.parametrize('a', [True, 'auto', 'md5', 'crc64']) # a: validate_content @GenericTestProxyParametrize1() @recorded_by_proxy_async async def test_download_blob(self, a, **kwargs): @@ -354,7 +354,7 @@ async def test_download_blob(self, a, **kwargs): blob = self.container.get_blob_client(self._get_blob_reference()) data = b'abc' * 512 await blob.upload_blob(data, overwrite=True) - assert_method = assert_structured_message_get if a == 'crc64' else assert_content_md5_get + assert_method = assert_structured_message_get if a in ('auto', 'crc64') else assert_content_md5_get # Act downloader = await blob.download_blob(validate_content=a, raw_response_hook=assert_method) diff --git 
a/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_shared/policies.py b/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_shared/policies.py index 3f65ae8d6498..71f4cce10541 100644 --- a/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_shared/policies.py +++ b/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_shared/policies.py @@ -12,7 +12,7 @@ import uuid from io import SEEK_SET, UnsupportedOperation from time import time -from typing import Any, Dict, Optional, TYPE_CHECKING +from typing import Any, Dict, Optional, TYPE_CHECKING, Union from urllib.parse import ( parse_qsl, urlencode, @@ -32,8 +32,20 @@ ) from .authentication import AzureSigningError, StorageHttpChallenge -from .constants import DEFAULT_OAUTH_SCOPE +from .constants import DEFAULT_OAUTH_SCOPE, DATA_BLOCK_SIZE from .models import LocationMode, StorageErrorCode +from .streams import ( + StructuredMessageDecoder, + StructuredMessageEncodeStream, + StructuredMessageProperties, +) +from .validation import ( + CV_TYPE_ERROR_MSG, + calculate_content_md5, + calculate_crc64_bytes, + is_md5_validation, + ChecksumAlgorithm, +) if TYPE_CHECKING: from azure.core.credentials import TokenCredential @@ -44,9 +56,15 @@ _LOGGER = logging.getLogger(__name__) +CONTENT_LENGTH_HEADER = "Content-Length" +MD5_HEADER = "Content-MD5" +CRC64_HEADER = "x-ms-content-crc64" +SM_HEADER = "x-ms-structured-body" +SM_HEADER_V1_CRC64 = "XSM/1.0; properties=crc64" +SM_LENGTH_HEADER = "x-ms-structured-content-length" -def encode_base64(data): +def encode_base64(data: Union[bytes, str]) -> str: if isinstance(data, str): data = data.encode("utf-8") encoded = base64.b64encode(data) @@ -55,7 +73,12 @@ def encode_base64(data): # Are we out of retries? 
def is_exhausted(settings): - retry_counts = (settings["total"], settings["connect"], settings["read"], settings["status"]) + retry_counts = ( + settings["total"], + settings["connect"], + settings["read"], + settings["status"], + ) retry_counts = list(filter(None, retry_counts)) if not retry_counts: return False @@ -64,7 +87,9 @@ def is_exhausted(settings): def retry_hook(settings, **kwargs): if settings["hook"]: - settings["hook"](retry_count=settings["count"] - 1, location_mode=settings["mode"], **kwargs) + settings["hook"]( + retry_count=settings["count"] - 1, location_mode=settings["mode"], **kwargs + ) # Is this method/status code retryable? (Based on allowlists and control @@ -84,7 +109,9 @@ def is_retry(response, mode): # pylint: disable=too-many-return-statements # Response code 408 is a timeout and should be retried. return True if status >= 400: - error_code = response.http_response.headers.get("x-ms-copy-source-error-code") + error_code = response.http_response.headers.get( + "x-ms-copy-source-error-code" + ) if error_code in [ StorageErrorCode.OPERATION_TIMED_OUT, StorageErrorCode.INTERNAL_ERROR, @@ -101,12 +128,16 @@ def is_retry(response, mode): # pylint: disable=too-many-return-statements return False -def is_checksum_retry(response): - # retry if invalid content md5 - if response.context.get("validate_content", False) and response.http_response.headers.get("content-md5"): - computed_md5 = response.http_request.headers.get("content-md5", None) or encode_base64( - StorageContentValidation.get_content_md5(response.http_response.body()) - ) +def is_checksum_retry(response) -> bool: + validate_content = response.context.get("validate_content", False) + if not validate_content: + return False + + # Legacy code - evaluate retry only on validate_content=True + if validate_content is True and response.http_response.headers.get("content-md5"): + computed_md5 = response.http_request.headers.get( + "content-md5", None + ) or 
encode_base64(calculate_content_md5(response.http_response.body())) if response.http_response.headers["content-md5"] != computed_md5: return True return False @@ -135,7 +166,9 @@ def on_request(self, request: "PipelineRequest") -> None: request.http_request.headers["x-ms-date"] = current_time custom_id = request.context.options.pop("client_request_id", None) - request.http_request.headers["x-ms-client-request-id"] = custom_id or str(uuid.uuid1()) + request.http_request.headers["x-ms-client-request-id"] = custom_id or str( + uuid.uuid1() + ) # def on_response(self, request, response): # # raise exception if the echoed client request id from the service is not identical to the one we sent @@ -175,7 +208,9 @@ def on_request(self, request: "PipelineRequest") -> None: # Lock retries to the specific location request.context.options["retry_to_secondary"] = False if use_location not in self.hosts: - raise ValueError(f"Attempting to use undefined host location {use_location}") + raise ValueError( + f"Attempting to use undefined host location {use_location}" + ) if use_location != location_mode: # Update request URL to use the specified location updated = parsed_url._replace(netloc=self.hosts[use_location]) @@ -193,7 +228,9 @@ class StorageLoggingPolicy(NetworkTraceLoggingPolicy): def __init__(self, logging_enable: bool = False, **kwargs) -> None: self.logging_body = kwargs.pop("logging_body", False) - super(StorageLoggingPolicy, self).__init__(logging_enable=logging_enable, **kwargs) + super(StorageLoggingPolicy, self).__init__( + logging_enable=logging_enable, **kwargs + ) def on_request(self, request: "PipelineRequest") -> None: http_request = request.http_request @@ -222,7 +259,16 @@ def on_request(self, request: "PipelineRequest") -> None: parsed_qs["sig"] = "*****" # the SAS needs to be put back together - value = urlunparse((scheme, netloc, path, params, urlencode(parsed_qs), fragment)) + value = urlunparse( + ( + scheme, + netloc, + path, + params, + 
urlencode(parsed_qs), + fragment, + ) + ) _LOGGER.debug(" %r: %r", header, value) _LOGGER.debug("Request body:") @@ -235,7 +281,9 @@ def on_request(self, request: "PipelineRequest") -> None: except Exception as err: # pylint: disable=broad-except _LOGGER.debug("Failed to log request: %r", err) - def on_response(self, request: "PipelineRequest", response: "PipelineResponse") -> None: + def on_response( + self, request: "PipelineRequest", response: "PipelineResponse" + ) -> None: if response.context.pop("logging_enable", self.enable_http_logger): if not _LOGGER.isEnabledFor(logging.DEBUG): return @@ -250,7 +298,9 @@ def on_response(self, request: "PipelineRequest", response: "PipelineResponse") _LOGGER.debug("Response content:") pattern = re.compile(r'attachment; ?filename=["\w.]+', re.IGNORECASE) header = response.http_response.headers.get("content-disposition") - resp_content_type = response.http_response.headers.get("content-type", "") + resp_content_type = response.http_response.headers.get( + "content-type", "" + ) if header and pattern.match(header): filename = header.partition("=")[2] @@ -279,7 +329,9 @@ def __init__(self, **kwargs): super(StorageRequestHook, self).__init__() def on_request(self, request: "PipelineRequest") -> None: - request_callback = request.context.options.pop("raw_request_hook", self._request_callback) + request_callback = request.context.options.pop( + "raw_request_hook", self._request_callback + ) if request_callback: request_callback(request) @@ -297,36 +349,50 @@ def send(self, request: "PipelineRequest") -> "PipelineResponse": data_stream_total = request.context.options.pop("data_stream_total", None) download_stream_current = request.context.get("download_stream_current") if download_stream_current is None: - download_stream_current = request.context.options.pop("download_stream_current", None) + download_stream_current = request.context.options.pop( + "download_stream_current", None + ) upload_stream_current = 
request.context.get("upload_stream_current") if upload_stream_current is None: - upload_stream_current = request.context.options.pop("upload_stream_current", None) + upload_stream_current = request.context.options.pop( + "upload_stream_current", None + ) - response_callback = request.context.get("response_callback") or request.context.options.pop( - "raw_response_hook", self._response_callback - ) + response_callback = request.context.get( + "response_callback" + ) or request.context.options.pop("raw_response_hook", self._response_callback) response = self.next.send(request) - will_retry = is_retry(response, request.context.options.get("mode")) or is_checksum_retry(response) + will_retry = is_retry( + response, request.context.options.get("mode") + ) or is_checksum_retry(response) # Auth error could come from Bearer challenge, in which case this request will be made again is_auth_error = response.http_response.status_code == 401 should_update_counts = not (will_retry or is_auth_error) if should_update_counts and download_stream_current is not None: - download_stream_current += int(response.http_response.headers.get("Content-Length", 0)) + download_stream_current += int( + response.http_response.headers.get("Content-Length", 0) + ) if data_stream_total is None: content_range = response.http_response.headers.get("Content-Range") if content_range: - data_stream_total = int(content_range.split(" ", 1)[1].split("/", 1)[1]) + data_stream_total = int( + content_range.split(" ", 1)[1].split("/", 1)[1] + ) else: data_stream_total = download_stream_current elif should_update_counts and upload_stream_current is not None: - upload_stream_current += int(response.http_request.headers.get("Content-Length", 0)) + upload_stream_current += int( + response.http_request.headers.get("Content-Length", 0) + ) for pipeline_obj in [request, response]: if hasattr(pipeline_obj, "context"): pipeline_obj.context["data_stream_total"] = data_stream_total - 
pipeline_obj.context["download_stream_current"] = download_stream_current + pipeline_obj.context["download_stream_current"] = ( + download_stream_current + ) pipeline_obj.context["upload_stream_current"] = upload_stream_current if response_callback: response_callback(response) @@ -334,64 +400,133 @@ def send(self, request: "PipelineRequest") -> "PipelineResponse": return response -class StorageContentValidation(SansIOHTTPPolicy): - """A simple policy that sends the given headers - with the request. +def _prepare_content_validation(request: "PipelineRequest") -> None: + """Shared request-side logic for content validation. - This will overwrite any headers already defined in the request. + Pops 'validate_content' from options, sets up headers/streams for MD5 or CRC64 + validation, and stores the validation mode in the request context. """ + validate_content = request.context.options.pop("validate_content", False) + if not validate_content: + return - header_name = "Content-MD5" - - def __init__(self, **kwargs: Any) -> None: # pylint: disable=unused-argument - super(StorageContentValidation, self).__init__() + # Download + if request.http_request.method == "GET": + if validate_content == ChecksumAlgorithm.CRC64: + request.http_request.headers[SM_HEADER] = SM_HEADER_V1_CRC64 - @staticmethod - def get_content_md5(data): + # Upload + else: # Since HTTP does not differentiate between no content and empty content, # we have to perform a None check. 
- data = data or b"" - md5 = hashlib.md5() # nosec - if isinstance(data, bytes): - md5.update(data) - elif hasattr(data, "read"): - pos = 0 - try: - pos = data.tell() - except: # pylint: disable=bare-except - pass - for chunk in iter(lambda: data.read(4096), b""): - md5.update(chunk) - try: - data.seek(pos, SEEK_SET) - except (AttributeError, IOError) as exc: - raise ValueError("Data should be bytes or a seekable file-like object.") from exc - else: - raise ValueError("Data should be bytes or a seekable file-like object.") + data = request.http_request.data or b"" + if is_md5_validation(validate_content): + computed_md5 = encode_base64(calculate_content_md5(data)) + request.http_request.headers[MD5_HEADER] = computed_md5 + request.context["validate_content_md5"] = computed_md5 - return md5.digest() + elif validate_content == ChecksumAlgorithm.CRC64: + if isinstance(data, bytes): + request.http_request.headers[CRC64_HEADER] = encode_base64( + calculate_crc64_bytes(data) + ) + elif hasattr(data, "read"): + content_length = int( + request.http_request.headers.get(CONTENT_LENGTH_HEADER) + ) + # Wrap data in structured message stream and adjust HTTP request + sm_stream = StructuredMessageEncodeStream( + data, content_length, StructuredMessageProperties.CRC64 + ) + request.http_request.data = sm_stream + request.http_request.headers[CONTENT_LENGTH_HEADER] = str( + len(sm_stream) + ) + request.http_request.headers[SM_LENGTH_HEADER] = str(content_length) + request.http_request.headers[SM_HEADER] = SM_HEADER_V1_CRC64 + else: + raise ValueError(CV_TYPE_ERROR_MSG) + + request.context["validate_content"] = validate_content - def on_request(self, request: "PipelineRequest") -> None: - validate_content = request.context.options.pop("validate_content", False) - if validate_content and request.http_request.method != "GET": - computed_md5 = encode_base64(StorageContentValidation.get_content_md5(request.http_request.data)) - request.http_request.headers[self.header_name] = 
computed_md5 - request.context["validate_content_md5"] = computed_md5 - request.context["validate_content"] = validate_content - def on_response(self, request: "PipelineRequest", response: "PipelineResponse") -> None: - if response.context.get("validate_content", False) and response.http_response.headers.get("content-md5"): - computed_md5 = request.context.get("validate_content_md5") or encode_base64( - StorageContentValidation.get_content_md5(response.http_response.body()) +def _validate_content_response( + request: "PipelineRequest", + response: "PipelineResponse", + decoder_cls: type, +) -> None: + """Shared response-side logic for content validation. + + Checks MD5 or CRC64 validation on the response. For CRC64 GET responses, patches + ``stream_download`` to wrap the iterator in the given *decoder_cls*. + """ + validate_content = response.context.get("validate_content", False) + if not validate_content: + return + + if is_md5_validation(validate_content) and response.http_response.headers.get( + "content-md5" + ): + computed_md5 = request.context.get("validate_content_md5") or encode_base64( + calculate_content_md5(response.http_response.body()) + ) + if response.http_response.headers["content-md5"] != computed_md5: + raise AzureError( + ( + f"MD5 mismatch. Expected value is '{response.http_response.headers['content-md5']}', " + f"computed value is '{computed_md5}'." + ), + response=response.http_response, + ) + + elif validate_content == ChecksumAlgorithm.CRC64: + # For upload and download verify structured message header present in response if provided in request. + sm_request = request.http_request.headers.get(SM_HEADER) + sm_response = response.http_response.headers.get(SM_HEADER) + if sm_request != sm_response: + raise AzureError( + ( + f"Expected structured message header in response does not match request. 
" + f"Request: {sm_request}, Response: {sm_response}", + ), + response=response.http_response, ) - if response.http_response.headers["content-md5"] != computed_md5: - raise AzureError( - ( - f"MD5 mismatch. Expected value is '{response.http_response.headers['content-md5']}', " - f"computed value is '{computed_md5}'." - ), - response=response.http_response, + + if response.http_request.method == "GET": + # Raises exception if missing + content_length = int(response.http_response.headers[CONTENT_LENGTH_HEADER]) + + # Patch response to return response iterator wrapped in structured message decoder + original_stream_download = response.http_response.stream_download + + def wrapped_stream_download(*args, **kwargs): + iterator = original_stream_download(*args, **kwargs) + decoder = decoder_cls( + iterator, content_length, block_size=DATA_BLOCK_SIZE ) + decoder.request = iterator.request # type: ignore + decoder.response = iterator.response # type: ignore + return decoder + + response.http_response.stream_download = wrapped_stream_download + + +class StorageContentValidation(SansIOHTTPPolicy): + """A pipeline policy that performs content validation on uploads and downloads when enabled by the user. + This is enabled by setting the "validate_content" key in the request context. When enabled, this policy will + calculate and verify content checksums for uploads and downloads, and raise an exception if a mismatch is detected. 
+ """ + + def __init__(self, **kwargs: Any) -> None: # pylint: disable=unused-argument + super().__init__() + + def on_request(self, request: "PipelineRequest") -> None: + _prepare_content_validation(request) + + def on_response( + self, request: "PipelineRequest", response: "PipelineResponse" + ) -> None: + _validate_content_response(request, response, StructuredMessageDecoder) class StorageRetryPolicy(HTTPPolicy): @@ -418,7 +553,9 @@ def __init__(self, **kwargs: Any) -> None: self.retry_to_secondary = kwargs.pop("retry_to_secondary", False) super(StorageRetryPolicy, self).__init__() - def _set_next_host_location(self, settings: Dict[str, Any], request: "PipelineRequest") -> None: + def _set_next_host_location( + self, settings: Dict[str, Any], request: "PipelineRequest" + ) -> None: """ A function which sets the next host location on the request, if applicable. @@ -438,7 +575,7 @@ def _set_next_host_location(self, settings: Dict[str, Any], request: "PipelineRe def configure_retries(self, request: "PipelineRequest") -> Dict[str, Any]: """ Configure the retry settings for the request. - + :param request: A pipeline request object. :type request: ~azure.core.pipeline.PipelineRequest :return: A dictionary containing the retry settings. 
@@ -457,7 +594,9 @@ def configure_retries(self, request: "PipelineRequest") -> Dict[str, Any]: "connect": options.pop("retry_connect", self.connect_retries), "read": options.pop("retry_read", self.read_retries), "status": options.pop("retry_status", self.status_retries), - "retry_secondary": options.pop("retry_to_secondary", self.retry_to_secondary), + "retry_secondary": options.pop( + "retry_to_secondary", self.retry_to_secondary + ), "mode": options.pop("location_mode", LocationMode.PRIMARY), "hosts": options.pop("hosts", None), "hook": options.pop("retry_hook", None), @@ -466,7 +605,9 @@ def configure_retries(self, request: "PipelineRequest") -> Dict[str, Any]: "history": [], } - def get_backoff_time(self, settings: Dict[str, Any]) -> float: # pylint: disable=unused-argument + def get_backoff_time( + self, settings: Dict[str, Any] + ) -> float: # pylint: disable=unused-argument """Formula for computing the current backoff. Should be calculated by child class. @@ -478,7 +619,7 @@ def get_backoff_time(self, settings: Dict[str, Any]) -> float: # pylint: disabl def sleep(self, settings, transport): """Sleep for the backoff time. - + :param Dict[str, Any] settings: The configurable values pertaining to the sleep operation. :param transport: The transport to use for sleeping. :type transport: @@ -529,7 +670,9 @@ def increment( # status_forcelist and a the given method is in the allowlist if response: settings["status"] -= 1 - settings["history"].append(RequestHistory(request, http_response=response)) + settings["history"].append( + RequestHistory(request, http_response=response) + ) if not is_exhausted(settings): if request.method not in ["PUT"] and settings["retry_secondary"]: @@ -552,7 +695,7 @@ def increment( def send(self, request): """Send the request with retry logic. - + :param request: A pipeline request object. :type request: ~azure.core.pipeline.PipelineRequest :return: A pipeline response object. 
@@ -564,13 +707,20 @@ def send(self, request): while retries_remaining: try: response = self.next.send(request) - if is_retry(response, retry_settings["mode"]) or is_checksum_retry(response): + if is_retry(response, retry_settings["mode"]) or is_checksum_retry( + response + ): retries_remaining = self.increment( - retry_settings, request=request.http_request, response=response.http_response + retry_settings, + request=request.http_request, + response=response.http_response, ) if retries_remaining: retry_hook( - retry_settings, request=request.http_request, response=response.http_response, error=None + retry_settings, + request=request.http_request, + response=response.http_response, + error=None, ) self.sleep(retry_settings, request.context.transport) continue @@ -578,9 +728,16 @@ def send(self, request): except AzureError as err: if isinstance(err, AzureSigningError): raise - retries_remaining = self.increment(retry_settings, request=request.http_request, error=err) + retries_remaining = self.increment( + retry_settings, request=request.http_request, error=err + ) if retries_remaining: - retry_hook(retry_settings, request=request.http_request, response=None, error=err) + retry_hook( + retry_settings, + request=request.http_request, + response=None, + error=err, + ) self.sleep(retry_settings, request.context.transport) continue raise err @@ -633,7 +790,9 @@ def __init__( self.initial_backoff = initial_backoff self.increment_base = increment_base self.random_jitter_range = random_jitter_range - super(ExponentialRetry, self).__init__(retry_total=retry_total, retry_to_secondary=retry_to_secondary, **kwargs) + super(ExponentialRetry, self).__init__( + retry_total=retry_total, retry_to_secondary=retry_to_secondary, **kwargs + ) def get_backoff_time(self, settings: Dict[str, Any]) -> float: """ @@ -646,8 +805,14 @@ def get_backoff_time(self, settings: Dict[str, Any]) -> float: :rtype: float """ random_generator = random.Random() - backoff = self.initial_backoff + (0 if 
settings["count"] == 0 else pow(self.increment_base, settings["count"])) - random_range_start = backoff - self.random_jitter_range if backoff > self.random_jitter_range else 0 + backoff = self.initial_backoff + ( + 0 if settings["count"] == 0 else pow(self.increment_base, settings["count"]) + ) + random_range_start = ( + backoff - self.random_jitter_range + if backoff > self.random_jitter_range + else 0 + ) random_range_end = backoff + self.random_jitter_range return random_generator.uniform(random_range_start, random_range_end) @@ -685,7 +850,9 @@ def __init__( """ self.backoff = backoff self.random_jitter_range = random_jitter_range - super(LinearRetry, self).__init__(retry_total=retry_total, retry_to_secondary=retry_to_secondary, **kwargs) + super(LinearRetry, self).__init__( + retry_total=retry_total, retry_to_secondary=retry_to_secondary, **kwargs + ) def get_backoff_time(self, settings: Dict[str, Any]) -> float: """ @@ -700,7 +867,11 @@ def get_backoff_time(self, settings: Dict[str, Any]) -> float: random_generator = random.Random() # the backoff interval normally does not change, however there is the possibility # that it was modified by accessing the property directly after initializing the object - random_range_start = self.backoff - self.random_jitter_range if self.backoff > self.random_jitter_range else 0 + random_range_start = ( + self.backoff - self.random_jitter_range + if self.backoff > self.random_jitter_range + else 0 + ) random_range_end = self.backoff + self.random_jitter_range return random_generator.uniform(random_range_start, random_range_end) @@ -708,16 +879,22 @@ def get_backoff_time(self, settings: Dict[str, Any]) -> float: class StorageBearerTokenCredentialPolicy(BearerTokenCredentialPolicy): """Custom Bearer token credential policy for following Storage Bearer challenges""" - def __init__(self, credential: "TokenCredential", audience: str, **kwargs: Any) -> None: - super(StorageBearerTokenCredentialPolicy, self).__init__(credential, 
audience, **kwargs) + def __init__( + self, credential: "TokenCredential", audience: str, **kwargs: Any + ) -> None: + super(StorageBearerTokenCredentialPolicy, self).__init__( + credential, audience, **kwargs + ) - def on_challenge(self, request: "PipelineRequest", response: "PipelineResponse") -> bool: + def on_challenge( + self, request: "PipelineRequest", response: "PipelineResponse" + ) -> bool: """Handle the challenge from the service and authorize the request. - + :param request: The request object. :type request: ~azure.core.pipeline.PipelineRequest :param response: The response object. - :type response: ~azure.core.pipeline.PipelineResponse + :type response: ~azure.core.pipeline.PipelineResponse :return: True if the request was authorized, False otherwise. :rtype: bool """ diff --git a/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_shared/policies_async.py b/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_shared/policies_async.py index 4cb32f23248b..14ce070e47ff 100644 --- a/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_shared/policies_async.py +++ b/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_shared/policies_async.py @@ -11,11 +11,25 @@ from typing import Any, Dict, TYPE_CHECKING from azure.core.exceptions import AzureError, StreamClosedError, StreamConsumedError -from azure.core.pipeline.policies import AsyncBearerTokenCredentialPolicy, AsyncHTTPPolicy +from azure.core.pipeline.policies import ( + AsyncBearerTokenCredentialPolicy, + AsyncHTTPPolicy, +) from .authentication import AzureSigningError, StorageHttpChallenge from .constants import DEFAULT_OAUTH_SCOPE -from .policies import encode_base64, is_retry, StorageContentValidation, StorageRetryPolicy +from .policies import ( + _prepare_content_validation, + _validate_content_response, + encode_base64, + is_retry, + StorageRetryPolicy, +) +from .streams_async import AsyncStructuredMessageDecoder +from .validation 
import ( + calculate_content_md5, + is_md5_validation, +) if TYPE_CHECKING: from azure.core.credentials_async import AsyncTokenCredential @@ -31,27 +45,66 @@ async def retry_hook(settings, **kwargs): if settings["hook"]: if asyncio.iscoroutine(settings["hook"]): - await settings["hook"](retry_count=settings["count"] - 1, location_mode=settings["mode"], **kwargs) + await settings["hook"]( + retry_count=settings["count"] - 1, + location_mode=settings["mode"], + **kwargs + ) else: - settings["hook"](retry_count=settings["count"] - 1, location_mode=settings["mode"], **kwargs) + settings["hook"]( + retry_count=settings["count"] - 1, + location_mode=settings["mode"], + **kwargs + ) async def is_checksum_retry(response): - # retry if invalid content md5 - if response.context.get("validate_content", False) and response.http_response.headers.get("content-md5"): + validate_content = response.context.get("validate_content", False) + if not validate_content: + return False + + # Legacy code - evaluate retry only on validate_content=True + if validate_content is True and response.http_response.headers.get("content-md5"): if hasattr(response.http_response, "load_body"): try: await response.http_response.load_body() # Load the body in memory and close the socket except (StreamClosedError, StreamConsumedError): pass - computed_md5 = response.http_request.headers.get("content-md5", None) or encode_base64( - StorageContentValidation.get_content_md5(response.http_response.body()) - ) + computed_md5 = response.http_request.headers.get( + "content-md5", None + ) or encode_base64(calculate_content_md5(response.http_response.body())) if response.http_response.headers["content-md5"] != computed_md5: return True return False +class AsyncContentValidationPolicy(AsyncHTTPPolicy): + """A pipeline policy that performs content validation on uploads and downloads when enabled by the user. + This is enabled by setting the "validate_content" key in the request context. 
When enabled, this policy will + calculate and verify content checksums for uploads and downloads, and raise an exception if a mismatch is detected. + """ + + def __init__(self, **kwargs: Any) -> None: # pylint: disable=unused-argument + super().__init__() + + async def send(self, request: "PipelineRequest") -> "PipelineResponse": + _prepare_content_validation(request) + + response = await self.next.send(request) + + validate_content = response.context.get("validate_content", False) + if validate_content and is_md5_validation(validate_content): + if hasattr(response.http_response, "load_body"): + try: + await response.http_response.load_body() + except (StreamClosedError, StreamConsumedError): + pass + + _validate_content_response(request, response, AsyncStructuredMessageDecoder) + + return response + + class AsyncStorageResponseHook(AsyncHTTPPolicy): def __init__(self, **kwargs): @@ -65,36 +118,50 @@ async def send(self, request: "PipelineRequest") -> "PipelineResponse": data_stream_total = request.context.options.pop("data_stream_total", None) download_stream_current = request.context.get("download_stream_current") if download_stream_current is None: - download_stream_current = request.context.options.pop("download_stream_current", None) + download_stream_current = request.context.options.pop( + "download_stream_current", None + ) upload_stream_current = request.context.get("upload_stream_current") if upload_stream_current is None: - upload_stream_current = request.context.options.pop("upload_stream_current", None) + upload_stream_current = request.context.options.pop( + "upload_stream_current", None + ) - response_callback = request.context.get("response_callback") or request.context.options.pop( - "raw_response_hook", self._response_callback - ) + response_callback = request.context.get( + "response_callback" + ) or request.context.options.pop("raw_response_hook", self._response_callback) response = await self.next.send(request) - will_retry = 
is_retry(response, request.context.options.get("mode")) or await is_checksum_retry(response) + will_retry = is_retry( + response, request.context.options.get("mode") + ) or await is_checksum_retry(response) # Auth error could come from Bearer challenge, in which case this request will be made again is_auth_error = response.http_response.status_code == 401 should_update_counts = not (will_retry or is_auth_error) if should_update_counts and download_stream_current is not None: - download_stream_current += int(response.http_response.headers.get("Content-Length", 0)) + download_stream_current += int( + response.http_response.headers.get("Content-Length", 0) + ) if data_stream_total is None: content_range = response.http_response.headers.get("Content-Range") if content_range: - data_stream_total = int(content_range.split(" ", 1)[1].split("/", 1)[1]) + data_stream_total = int( + content_range.split(" ", 1)[1].split("/", 1)[1] + ) else: data_stream_total = download_stream_current elif should_update_counts and upload_stream_current is not None: - upload_stream_current += int(response.http_request.headers.get("Content-Length", 0)) + upload_stream_current += int( + response.http_request.headers.get("Content-Length", 0) + ) for pipeline_obj in [request, response]: if hasattr(pipeline_obj, "context"): pipeline_obj.context["data_stream_total"] = data_stream_total - pipeline_obj.context["download_stream_current"] = download_stream_current + pipeline_obj.context["download_stream_current"] = ( + download_stream_current + ) pipeline_obj.context["upload_stream_current"] = upload_stream_current if response_callback: if asyncio.iscoroutine(response_callback): @@ -123,13 +190,20 @@ async def send(self, request): while retries_remaining: try: response = await self.next.send(request) - if is_retry(response, retry_settings["mode"]) or await is_checksum_retry(response): + if is_retry( + response, retry_settings["mode"] + ) or await is_checksum_retry(response): retries_remaining = 
self.increment( - retry_settings, request=request.http_request, response=response.http_response + retry_settings, + request=request.http_request, + response=response.http_response, ) if retries_remaining: await retry_hook( - retry_settings, request=request.http_request, response=response.http_response, error=None + retry_settings, + request=request.http_request, + response=response.http_response, + error=None, ) await self.sleep(retry_settings, request.context.transport) continue @@ -137,9 +211,16 @@ async def send(self, request): except AzureError as err: if isinstance(err, AzureSigningError): raise - retries_remaining = self.increment(retry_settings, request=request.http_request, error=err) + retries_remaining = self.increment( + retry_settings, request=request.http_request, error=err + ) if retries_remaining: - await retry_hook(retry_settings, request=request.http_request, response=None, error=err) + await retry_hook( + retry_settings, + request=request.http_request, + response=None, + error=err, + ) await self.sleep(retry_settings, request.context.transport) continue raise err @@ -194,7 +275,9 @@ def __init__( self.initial_backoff = initial_backoff self.increment_base = increment_base self.random_jitter_range = random_jitter_range - super(ExponentialRetry, self).__init__(retry_total=retry_total, retry_to_secondary=retry_to_secondary, **kwargs) + super(ExponentialRetry, self).__init__( + retry_total=retry_total, retry_to_secondary=retry_to_secondary, **kwargs + ) def get_backoff_time(self, settings: Dict[str, Any]) -> float: """ @@ -207,8 +290,14 @@ def get_backoff_time(self, settings: Dict[str, Any]) -> float: :rtype: int or None """ random_generator = random.Random() - backoff = self.initial_backoff + (0 if settings["count"] == 0 else pow(self.increment_base, settings["count"])) - random_range_start = backoff - self.random_jitter_range if backoff > self.random_jitter_range else 0 + backoff = self.initial_backoff + ( + 0 if settings["count"] == 0 else 
pow(self.increment_base, settings["count"]) + ) + random_range_start = ( + backoff - self.random_jitter_range + if backoff > self.random_jitter_range + else 0 + ) random_range_end = backoff + self.random_jitter_range return random_generator.uniform(random_range_start, random_range_end) @@ -246,7 +335,9 @@ def __init__( """ self.backoff = backoff self.random_jitter_range = random_jitter_range - super(LinearRetry, self).__init__(retry_total=retry_total, retry_to_secondary=retry_to_secondary, **kwargs) + super(LinearRetry, self).__init__( + retry_total=retry_total, retry_to_secondary=retry_to_secondary, **kwargs + ) def get_backoff_time(self, settings: Dict[str, Any]) -> float: """ @@ -261,7 +352,11 @@ def get_backoff_time(self, settings: Dict[str, Any]) -> float: random_generator = random.Random() # the backoff interval normally does not change, however there is the possibility # that it was modified by accessing the property directly after initializing the object - random_range_start = self.backoff - self.random_jitter_range if self.backoff > self.random_jitter_range else 0 + random_range_start = ( + self.backoff - self.random_jitter_range + if self.backoff > self.random_jitter_range + else 0 + ) random_range_end = self.backoff + self.random_jitter_range return random_generator.uniform(random_range_start, random_range_end) @@ -269,10 +364,16 @@ def get_backoff_time(self, settings: Dict[str, Any]) -> float: class AsyncStorageBearerTokenCredentialPolicy(AsyncBearerTokenCredentialPolicy): """Custom Bearer token credential policy for following Storage Bearer challenges""" - def __init__(self, credential: "AsyncTokenCredential", audience: str, **kwargs: Any) -> None: - super(AsyncStorageBearerTokenCredentialPolicy, self).__init__(credential, audience, **kwargs) + def __init__( + self, credential: "AsyncTokenCredential", audience: str, **kwargs: Any + ) -> None: + super(AsyncStorageBearerTokenCredentialPolicy, self).__init__( + credential, audience, **kwargs + ) - async 
def on_challenge(self, request: "PipelineRequest", response: "PipelineResponse") -> bool: + async def on_challenge( + self, request: "PipelineRequest", response: "PipelineResponse" + ) -> bool: try: auth_header = response.http_response.headers.get("WWW-Authenticate") challenge = StorageHttpChallenge(auth_header) diff --git a/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_shared/streams.py b/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_shared/streams.py new file mode 100644 index 000000000000..712f4e90af69 --- /dev/null +++ b/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_shared/streams.py @@ -0,0 +1,703 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +import math +import sys +from enum import auto, Enum, IntFlag +from io import BytesIO, IOBase, UnsupportedOperation, SEEK_CUR, SEEK_END, SEEK_SET +from typing import IO, Iterator, Optional + +from .validation import calculate_crc64 + +DEFAULT_MESSAGE_VERSION = 1 +DEFAULT_SEGMENT_SIZE = 4 * 1024 * 1024 + + +class StructuredMessageConstants: + V1_HEADER_LENGTH = 13 + V1_SEGMENT_HEADER_LENGTH = 10 + CRC64_LENGTH = 8 + + +class StructuredMessageProperties(IntFlag): + NONE = 0 + CRC64 = auto() + + +class SMRegion(Enum): + MESSAGE_HEADER = 1 + SEGMENT_HEADER = 2 + SEGMENT_CONTENT = 3 + SEGMENT_FOOTER = 4 + MESSAGE_FOOTER = 5 + + +def generate_message_header( + version: int, size: int, flags: StructuredMessageProperties, num_segments: int +) -> bytes: + return ( + version.to_bytes(1, "little") + + size.to_bytes(8, "little") + + flags.to_bytes(2, "little") + + num_segments.to_bytes(2, "little") + ) + + +def generate_segment_header(number: int, size: int) -> bytes: + return number.to_bytes(2, 
"little") + size.to_bytes(8, "little") + + +def parse_message_header( + data: bytes, expected_message_length: int +) -> tuple[int, StructuredMessageProperties, int]: + version = data[0] + if version != 1: + raise ValueError(f"The structured message version is not supported: {version}") + message_length = int.from_bytes(data[1:9], "little") + if message_length != expected_message_length: + raise ValueError( + f"Structured message length {message_length} " + f"did not match content length {expected_message_length}" + ) + flags = StructuredMessageProperties(int.from_bytes(data[9:11], "little")) + num_segments = int.from_bytes(data[11:13], "little") + return version, flags, num_segments + + +def parse_segment_header(data: bytes, expected_segment_number: int) -> tuple[int, int]: + segment_number = int.from_bytes(data[0:2], "little") + if segment_number != expected_segment_number: + raise ValueError( + f"Structured message segment number invalid or out of order {segment_number}" + ) + segment_content_length = int.from_bytes(data[2:10], "little") + return segment_number, segment_content_length + + +class StructuredMessageEncodeStream( + IOBase +): # pylint: disable=too-many-instance-attributes + message_version: int + content_length: int + message_length: int + flags: StructuredMessageProperties + + _inner_stream: IO[bytes] + _segment_size: int + _num_segments: int + + _initial_content_position: Optional[int] + """Initial position of the inner stream, None if it did not implement tell()""" + _content_offset: int + _current_segment_number: int + _current_region: SMRegion + _current_region_length: int + _current_region_offset: int + + _checksum_offset: int + """Tracks the offset the checksum has been calculated up to for seeking purposes""" + + _message_crc64: int + _segment_crc64s: dict[int, int] + + def __init__( + self, + inner_stream: IO[bytes], + content_length: int, + flags: StructuredMessageProperties, + *, + segment_size: int = DEFAULT_SEGMENT_SIZE, + ) -> None: + 
if segment_size < 1: + raise ValueError("Segment size must be greater than 0.") + + self.message_version = DEFAULT_MESSAGE_VERSION + self.content_length = content_length + self.flags = flags + + self._inner_stream = inner_stream + self._segment_size = segment_size + self._num_segments = math.ceil(self.content_length / self._segment_size) or 1 + + self.message_length = self._calculate_message_length() + + self._content_offset = 0 + self._current_segment_number = 0 # Will be incremented before first segment + self._current_region = SMRegion.MESSAGE_HEADER + self._current_region_length = self._message_header_length + self._current_region_offset = 0 + + self._checksum_offset = 0 + self._message_crc64 = 0 + self._segment_crc64s = {} + + # Attempt to get starting position of inner stream. If we can't, this stream will not be seekable + try: + self._initial_content_position = self._inner_stream.tell() + except (AttributeError, UnsupportedOperation, OSError): + self._initial_content_position = None + super().__init__() + + @property + def _message_header_length(self) -> int: + return StructuredMessageConstants.V1_HEADER_LENGTH + + @property + def _segment_header_length(self) -> int: + return StructuredMessageConstants.V1_SEGMENT_HEADER_LENGTH + + @property + def _segment_footer_length(self) -> int: + return ( + StructuredMessageConstants.CRC64_LENGTH + if StructuredMessageProperties.CRC64 in self.flags + else 0 + ) + + @property + def _message_footer_length(self) -> int: + return ( + StructuredMessageConstants.CRC64_LENGTH + if StructuredMessageProperties.CRC64 in self.flags + else 0 + ) + + def _update_current_region_length(self) -> None: + if self._current_region == SMRegion.MESSAGE_HEADER: + self._current_region_length = self._message_header_length + elif self._current_region == SMRegion.SEGMENT_HEADER: + self._current_region_length = self._segment_header_length + elif self._current_region == SMRegion.SEGMENT_CONTENT: + # Last segment size is remaining content + if 
self._current_segment_number == self._num_segments: + self._current_region_length = self.content_length - ( + (self._current_segment_number - 1) * self._segment_size + ) + else: + self._current_region_length = self._segment_size + elif self._current_region == SMRegion.SEGMENT_FOOTER: + self._current_region_length = self._segment_footer_length + elif self._current_region == SMRegion.MESSAGE_FOOTER: + self._current_region_length = self._message_footer_length + else: + raise ValueError(f"Invalid SMRegion {self._current_region}") + + def __len__(self): + return self.message_length + + def close(self) -> None: + self._inner_stream.close() + super().close() + + def readable(self) -> bool: + return True + + def seekable(self) -> bool: + try: + # Only seekable if the inner stream is and we could get its initial position + return ( + self._inner_stream.seekable() + and self._initial_content_position is not None + ) + except (AttributeError, UnsupportedOperation, OSError): + return False + + def tell(self) -> int: + if self._current_region == SMRegion.MESSAGE_HEADER: + return self._current_region_offset + if self._current_region == SMRegion.SEGMENT_HEADER: + return ( + self._message_header_length + + self._content_offset + + (self._current_segment_number - 1) + * (self._segment_header_length + self._segment_footer_length) + + self._current_region_offset + ) + if self._current_region == SMRegion.SEGMENT_CONTENT: + return ( + self._message_header_length + + self._content_offset + + (self._current_segment_number - 1) + * (self._segment_header_length + self._segment_footer_length) + + self._segment_header_length + ) + if self._current_region == SMRegion.SEGMENT_FOOTER: + return ( + self._message_header_length + + self._content_offset + + (self._current_segment_number - 1) + * (self._segment_header_length + self._segment_footer_length) + + self._segment_header_length + + self._current_region_offset + ) + if self._current_region == SMRegion.MESSAGE_FOOTER: + return ( + 
self._message_header_length + + self._content_offset + + self._current_segment_number + * (self._segment_header_length + self._segment_footer_length) + + self._current_region_offset + ) + + raise ValueError(f"Invalid SMRegion {self._current_region}") + + def seek(self, offset: int, whence: int = SEEK_SET) -> int: + if not self.seekable(): + raise UnsupportedOperation("Inner stream is not seekable.") + + if whence == SEEK_SET: + position = offset + elif whence == SEEK_CUR: + position = self.tell() + offset + elif whence == SEEK_END: + position = self.message_length + offset + else: + raise ValueError(f"Invalid value for whence: {whence}") + + if position < 0: + raise ValueError(f"Cannot seek to negative position: {position}") + if position > self.tell(): + raise UnsupportedOperation("This stream only supports seeking backwards.") + + # MESSAGE_HEADER + if position < self._message_header_length: + self._current_region = SMRegion.MESSAGE_HEADER + self._current_region_offset = position + self._content_offset = 0 + self._current_segment_number = 0 + # MESSAGE_FOOTER + elif position >= self.message_length - self._message_footer_length: + self._current_region = SMRegion.MESSAGE_FOOTER + self._current_region_offset = position - ( + self.message_length - self._message_footer_length + ) + self._content_offset = self.content_length + self._current_segment_number = self._num_segments + else: + # The size of a "full" segment. 
Fine to use for calculating new segment number and pos + full_segment_size = ( + self._segment_header_length + + self._segment_size + + self._segment_footer_length + ) + new_segment_num = ( + 1 + (position - self._message_header_length) // full_segment_size + ) + segment_pos = (position - self._message_header_length) % full_segment_size + previous_segments_total_content_size = ( + new_segment_num - 1 + ) * self._segment_size + + # We need the size of the segment we are seeking to for some of the calculations below + new_segment_size = self._segment_size + if new_segment_num == self._num_segments: + # The last segment size is the remaining content length + new_segment_size = ( + self.content_length - previous_segments_total_content_size + ) + + # SEGMENT_HEADER + if segment_pos < self._segment_header_length: + self._current_region = SMRegion.SEGMENT_HEADER + self._current_region_offset = segment_pos + self._content_offset = previous_segments_total_content_size + # SEGMENT_CONTENT + elif segment_pos < self._segment_header_length + new_segment_size: + self._current_region = SMRegion.SEGMENT_CONTENT + self._current_region_offset = segment_pos - self._segment_header_length + self._content_offset = ( + previous_segments_total_content_size + self._current_region_offset + ) + # SEGMENT_FOOTER + else: + self._current_region = SMRegion.SEGMENT_FOOTER + self._current_region_offset = ( + segment_pos - self._segment_header_length - new_segment_size + ) + self._content_offset = ( + previous_segments_total_content_size + new_segment_size + ) + + self._current_segment_number = new_segment_num + + self._update_current_region_length() + self._inner_stream.seek( + (self._initial_content_position or 0) + self._content_offset + ) + return position + + def read(self, size: int = -1) -> bytes: + if self.closed: # pylint: disable=using-constant-test + raise ValueError("Stream is closed") + + if size == 0: + return b"" + if size < 0: + size = sys.maxsize + + count = 0 + output = BytesIO() 
+ + while count < size and self.tell() < self.message_length: + remaining = size - count + if self._current_region in ( + SMRegion.MESSAGE_HEADER, + SMRegion.SEGMENT_HEADER, + SMRegion.SEGMENT_FOOTER, + SMRegion.MESSAGE_FOOTER, + ): + count += self._read_metadata_region( + self._current_region, remaining, output + ) + elif self._current_region == SMRegion.SEGMENT_CONTENT: + count += self._read_content(remaining, output) + else: + raise ValueError(f"Invalid SMRegion {self._current_region}") + + return output.getvalue() + + def _calculate_message_length(self) -> int: + length = self._message_header_length + length += ( + self._segment_header_length + self._segment_footer_length + ) * self._num_segments + length += self.content_length + length += self._message_footer_length + return length + + def _get_metadata_region(self, region: SMRegion) -> bytes: + if region == SMRegion.MESSAGE_HEADER: + return generate_message_header( + self.message_version, + self.message_length, + self.flags, + self._num_segments, + ) + + if region == SMRegion.SEGMENT_HEADER: + segment_size = min( + self._segment_size, self.content_length - self._content_offset + ) + return generate_segment_header(self._current_segment_number, segment_size) + + if region == SMRegion.SEGMENT_FOOTER: + if StructuredMessageProperties.CRC64 in self.flags: + return self._segment_crc64s[self._current_segment_number].to_bytes( + StructuredMessageConstants.CRC64_LENGTH, "little" + ) + return b"" + + if region == SMRegion.MESSAGE_FOOTER: + if StructuredMessageProperties.CRC64 in self.flags: + return self._message_crc64.to_bytes( + StructuredMessageConstants.CRC64_LENGTH, "little" + ) + return b"" + + raise ValueError(f"Invalid metadata SMRegion {self._current_region}") + + def _advance_region(self, current: SMRegion): + self._current_region_offset = 0 + + if current == SMRegion.MESSAGE_HEADER: + self._current_region = SMRegion.SEGMENT_HEADER + self._increment_current_segment() + elif current == SMRegion.SEGMENT_HEADER: 
+ self._current_region = SMRegion.SEGMENT_CONTENT + elif current == SMRegion.SEGMENT_CONTENT: + self._current_region = SMRegion.SEGMENT_FOOTER + elif current == SMRegion.SEGMENT_FOOTER: + # If we're at the end of the content + if self._content_offset == self.content_length: + self._current_region = SMRegion.MESSAGE_FOOTER + else: + self._current_region = SMRegion.SEGMENT_HEADER + self._increment_current_segment() + else: + raise ValueError(f"Invalid SMRegion {self._current_region}") + + self._update_current_region_length() + + def _read_metadata_region( + self, region: SMRegion, size: int, output: BytesIO + ) -> int: + metadata = self._get_metadata_region(region) + + read_size = min(size, self._current_region_length - self._current_region_offset) + content = metadata[ + self._current_region_offset : self._current_region_offset + read_size + ] + output.write(content) + + self._current_region_offset += read_size + if ( + self._current_region_offset == self._current_region_length + and self._current_region != SMRegion.MESSAGE_FOOTER + ): + self._advance_region(region) + + return read_size + + def _read_content(self, size: int, output: BytesIO) -> int: + # Will be non-zero if there is data to read that does not need to have checksum calculated. + # Will always be positive as stream can only seek backwards. 
+ checksum_offset = self._checksum_offset - self._content_offset + + read_size = min(size, self._current_region_length - self._current_region_offset) + if checksum_offset != 0: + # Only read up to checksum offset this iteration + read_size = min(read_size, checksum_offset) + + content = self._inner_stream.read(read_size) + if len(content) != read_size: + raise ValueError("Content ended early when encoding structured message.") + output.write(content) + + if StructuredMessageProperties.CRC64 in self.flags: + if checksum_offset == 0: + self._segment_crc64s[self._current_segment_number] = calculate_crc64( + content, self._segment_crc64s[self._current_segment_number] + ) + self._message_crc64 = calculate_crc64(content, self._message_crc64) + + self._content_offset += read_size + # Only update the checksum offset if we've read new data + if self._content_offset > self._checksum_offset: + self._checksum_offset += read_size + self._current_region_offset += read_size + if self._current_region_offset == self._current_region_length: + self._advance_region(SMRegion.SEGMENT_CONTENT) + + return read_size + + def _increment_current_segment(self): + self._current_segment_number += 1 + if StructuredMessageProperties.CRC64 in self.flags: + # If seek was used, we may already have this segment's CRC (could be partial), otherwise initialize to 0 + self._segment_crc64s.setdefault(self._current_segment_number, 0) + + +class StructuredMessageDecoder(IOBase): # pylint: disable=too-many-instance-attributes + + message_version: int + """The version of the structured message.""" + message_length: int + """The total length of the structured message.""" + flags: StructuredMessageProperties + """The properties included in the structured message.""" + num_segments: int + """The number of message segments.""" + + _inner_iterator: Iterator[bytes] + _buffer: bytes + _message_offset: int + _message_crc64: int + _segment_number: int + _segment_crc64: int + _segment_content_length: int + 
_segment_content_offset: int + _block_size: int + + def __init__( + self, + inner_iterator: Iterator[bytes], + content_length: int, + *, + block_size: int = 4096, + ) -> None: + self.message_length = content_length + # The stream should be at least long enough to hold minimum header length + if self.message_length < StructuredMessageConstants.V1_HEADER_LENGTH: + raise ValueError( + "Content not long enough to contain a valid message header." + ) + + self._inner_iterator = inner_iterator + self._buffer = b"" + self._message_offset = 0 + self._message_crc64 = 0 + + self._segment_number = 0 + self._segment_crc64 = 0 + self._segment_content_length = 0 + self._segment_content_offset = 0 + self._block_size = block_size + super().__init__() + + @property + def content_length(self) -> int: + return self.message_length + + @property + def _segment_header_length(self) -> int: + return StructuredMessageConstants.V1_SEGMENT_HEADER_LENGTH + + @property + def _segment_footer_length(self) -> int: + return ( + StructuredMessageConstants.CRC64_LENGTH + if StructuredMessageProperties.CRC64 in self.flags + else 0 + ) + + @property + def _message_footer_length(self) -> int: + return ( + StructuredMessageConstants.CRC64_LENGTH + if StructuredMessageProperties.CRC64 in self.flags + else 0 + ) + + @property + def _end_of_segment_content(self) -> bool: + return self._segment_content_offset == self._segment_content_length + + def readable(self) -> bool: + return True + + def seekable(self) -> bool: + return False + + def __iter__(self): + return self + + def __next__(self) -> bytes: + data = self.read(self._block_size) + if not data: + raise StopIteration + return data + + def read(self, size: int = -1) -> bytes: + if self.closed: + raise ValueError("Stream is closed") + + if size == 0 or self._message_offset >= self.message_length: + return b"" + if size < 0: + size = sys.maxsize + + # On the first read, read message header and first segment header + if self._message_offset == 0: + 
self._read_message_header() + self._read_segment_header() + + # Special case for 0 length content + if self._end_of_segment_content: + self._read_segment_footer() + if self.num_segments > 1: + raise ValueError( + "First message segment was empty but more segments were detected." + ) + self._read_message_footer() + return b"" + + count = 0 + content = BytesIO() + while count < size and not ( + self._end_of_segment_content and self._message_offset == self.message_length + ): + if self._end_of_segment_content: + self._read_segment_header() + + segment_remaining = ( + self._segment_content_length - self._segment_content_offset + ) + read_size = min(segment_remaining, size - count) + + segment_content = self._read_from_inner(read_size) + content.write(segment_content) + + # Update the running CRC64 for the segment and message + if StructuredMessageProperties.CRC64 in self.flags: + self._segment_crc64 = calculate_crc64( + segment_content, self._segment_crc64 + ) + self._message_crc64 = calculate_crc64( + segment_content, self._message_crc64 + ) + + self._segment_content_offset += read_size + self._message_offset += read_size + count += read_size + + if self._end_of_segment_content: + self._read_segment_footer() + # If we are on the last segment, also read the message footer + if self._segment_number == self.num_segments: + self._read_message_footer() + + # One final check to ensure if we think we've reached the end of the stream + # that the current segment number matches the total. + if ( + self._message_offset == self.message_length + and self._segment_number != self.num_segments + ): + raise ValueError("Invalid structured message data detected.") + + return content.getvalue() + + def _read_from_inner(self, size: int) -> bytes: + while len(self._buffer) < size: + try: + chunk = next(self._inner_iterator) + except StopIteration: + break + self._buffer += chunk + + if len(self._buffer) < size: + raise ValueError( + "Invalid structured message data detected. 
Stream content incomplete." + ) + + data = self._buffer[:size] + self._buffer = self._buffer[size:] + return data + + def _read_message_header(self) -> None: + header_data = self._read_from_inner(StructuredMessageConstants.V1_HEADER_LENGTH) + self.message_version, self.flags, self.num_segments = parse_message_header( + header_data, self.message_length + ) + self._message_offset += StructuredMessageConstants.V1_HEADER_LENGTH + + def _read_message_footer(self) -> None: + # Sanity check: There should only be self._message_footer_length (could be 0) bytes left to consume. + # If not, it is likely the message header contained incorrect info. + if self.message_length - self._message_offset != self._message_footer_length: + raise ValueError("Invalid structured message data detected.") + + if StructuredMessageProperties.CRC64 in self.flags: + message_crc = self._read_from_inner(StructuredMessageConstants.CRC64_LENGTH) + + if self._message_crc64 != int.from_bytes(message_crc, "little"): + raise ValueError( + "CRC64 mismatch detected in message trailer. " + "All data read should be considered invalid." + ) + + self._message_offset += self._message_footer_length + + def _read_segment_header(self) -> None: + header_data = self._read_from_inner(self._segment_header_length) + self._segment_number, self._segment_content_length = parse_segment_header( + header_data, self._segment_number + 1 + ) + self._message_offset += self._segment_header_length + + self._segment_content_offset = 0 + self._segment_crc64 = 0 + + def _read_segment_footer(self) -> None: + if StructuredMessageProperties.CRC64 in self.flags: + segment_crc = self._read_from_inner(StructuredMessageConstants.CRC64_LENGTH) + + if self._segment_crc64 != int.from_bytes(segment_crc, "little"): + raise ValueError( + f"CRC64 mismatch detected in segment {self._segment_number}. " + f"All data read should be considered invalid." 
+ ) + + self._message_offset += self._segment_footer_length diff --git a/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_shared/streams_async.py b/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_shared/streams_async.py new file mode 100644 index 000000000000..ee7d92d14d77 --- /dev/null +++ b/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_shared/streams_async.py @@ -0,0 +1,248 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +import sys +from io import BytesIO, IOBase +from typing import AsyncIterator + +from .streams import ( + StructuredMessageConstants, + StructuredMessageProperties, + parse_message_header, + parse_segment_header, +) +from .validation import calculate_crc64 + + +class AsyncStructuredMessageDecoder( + IOBase +): # pylint: disable=too-many-instance-attributes + + message_version: int + """The version of the structured message.""" + message_length: int + """The total length of the structured message.""" + flags: StructuredMessageProperties + """The properties included in the structured message.""" + num_segments: int + """The number of message segments.""" + + _inner_iterator: AsyncIterator[bytes] + _buffer: bytes + _message_offset: int + _message_crc64: int + _segment_number: int + _segment_crc64: int + _segment_content_length: int + _segment_content_offset: int + _block_size: int + + def __init__( + self, + inner_iterator: AsyncIterator[bytes], + content_length: int, + *, + block_size: int = 4096, + ) -> None: + self.message_length = content_length + # The stream should be at least long enough to hold minimum header length + if self.message_length < StructuredMessageConstants.V1_HEADER_LENGTH: + raise ValueError( + 
"Content not long enough to contain a valid message header." + ) + + self._inner_iterator = inner_iterator + self._buffer = b"" + self._message_offset = 0 + self._message_crc64 = 0 + + self._segment_number = 0 + self._segment_crc64 = 0 + self._segment_content_length = 0 + self._segment_content_offset = 0 + self._block_size = block_size + super().__init__() + + @property + def content_length(self) -> int: + return self.message_length + + @property + def _segment_header_length(self) -> int: + return StructuredMessageConstants.V1_SEGMENT_HEADER_LENGTH + + @property + def _segment_footer_length(self) -> int: + return ( + StructuredMessageConstants.CRC64_LENGTH + if StructuredMessageProperties.CRC64 in self.flags + else 0 + ) + + @property + def _message_footer_length(self) -> int: + return ( + StructuredMessageConstants.CRC64_LENGTH + if StructuredMessageProperties.CRC64 in self.flags + else 0 + ) + + @property + def _end_of_segment_content(self) -> bool: + return self._segment_content_offset == self._segment_content_length + + def readable(self) -> bool: + return True + + def seekable(self) -> bool: + return False + + def __aiter__(self): + return self + + async def __anext__(self) -> bytes: + data = await self.read(self._block_size) + if not data: + raise StopAsyncIteration + return data + + async def read(self, size: int = -1) -> bytes: + if self.closed: + raise ValueError("Stream is closed") + + if size == 0 or self._message_offset >= self.message_length: + return b"" + if size < 0: + size = sys.maxsize + + # On the first read, read message header and first segment header + if self._message_offset == 0: + await self._read_message_header() + await self._read_segment_header() + + # Special case for 0 length content + if self._end_of_segment_content: + await self._read_segment_footer() + if self.num_segments > 1: + raise ValueError( + "First message segment was empty but more segments were detected." 
+ ) + await self._read_message_footer() + return b"" + + count = 0 + content = BytesIO() + while count < size and not ( + self._end_of_segment_content and self._message_offset == self.message_length + ): + if self._end_of_segment_content: + await self._read_segment_header() + + segment_remaining = ( + self._segment_content_length - self._segment_content_offset + ) + read_size = min(segment_remaining, size - count) + + segment_content = await self._read_from_inner(read_size) + content.write(segment_content) + + # Update the running CRC64 for the segment and message + if StructuredMessageProperties.CRC64 in self.flags: + self._segment_crc64 = calculate_crc64( + segment_content, self._segment_crc64 + ) + self._message_crc64 = calculate_crc64( + segment_content, self._message_crc64 + ) + + self._segment_content_offset += read_size + self._message_offset += read_size + count += read_size + + if self._end_of_segment_content: + await self._read_segment_footer() + # If we are on the last segment, also read the message footer + if self._segment_number == self.num_segments: + await self._read_message_footer() + + # One final check to ensure if we think we've reached the end of the stream + # that the current segment number matches the total. + if ( + self._message_offset == self.message_length + and self._segment_number != self.num_segments + ): + raise ValueError("Invalid structured message data detected.") + + return content.getvalue() + + async def _read_from_inner(self, size: int) -> bytes: + while len(self._buffer) < size: + try: + chunk = await self._inner_iterator.__anext__() + except StopAsyncIteration: + break + self._buffer += chunk + + if len(self._buffer) < size: + raise ValueError( + "Invalid structured message data detected. Stream content incomplete." 
+ ) + + data = self._buffer[:size] + self._buffer = self._buffer[size:] + return data + + async def _read_message_header(self) -> None: + header_data = await self._read_from_inner( + StructuredMessageConstants.V1_HEADER_LENGTH + ) + self.message_version, self.flags, self.num_segments = parse_message_header( + header_data, self.message_length + ) + self._message_offset += StructuredMessageConstants.V1_HEADER_LENGTH + + async def _read_message_footer(self) -> None: + # Sanity check: There should only be self._message_footer_length (could be 0) bytes left to consume. + # If not, it is likely the message header contained incorrect info. + if self.message_length - self._message_offset != self._message_footer_length: + raise ValueError("Invalid structured message data detected.") + + if StructuredMessageProperties.CRC64 in self.flags: + message_crc = await self._read_from_inner( + StructuredMessageConstants.CRC64_LENGTH + ) + + if self._message_crc64 != int.from_bytes(message_crc, "little"): + raise ValueError( + "CRC64 mismatch detected in message trailer. " + "All data read should be considered invalid." + ) + + self._message_offset += self._message_footer_length + + async def _read_segment_header(self) -> None: + header_data = await self._read_from_inner(self._segment_header_length) + self._segment_number, self._segment_content_length = parse_segment_header( + header_data, self._segment_number + 1 + ) + self._message_offset += self._segment_header_length + + self._segment_content_offset = 0 + self._segment_crc64 = 0 + + async def _read_segment_footer(self) -> None: + if StructuredMessageProperties.CRC64 in self.flags: + segment_crc = await self._read_from_inner( + StructuredMessageConstants.CRC64_LENGTH + ) + + if self._segment_crc64 != int.from_bytes(segment_crc, "little"): + raise ValueError( + f"CRC64 mismatch detected in segment {self._segment_number}. " + f"All data read should be considered invalid." 
+ ) + + self._message_offset += self._segment_footer_length diff --git a/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_shared/validation.py b/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_shared/validation.py new file mode 100644 index 000000000000..5370d9dd669c --- /dev/null +++ b/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_shared/validation.py @@ -0,0 +1,105 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- +# pylint: disable=c-extension-no-member + +import hashlib +from enum import Enum +from io import SEEK_SET +from typing import IO, Literal, Optional, Union, cast + +from azure.core import CaseInsensitiveEnumMeta + +CRC64_LENGTH = 8 +CV_TYPE_ERROR_MSG = "Data should be bytes or seekable IO[bytes] for content validation." + + +class ChecksumAlgorithm(str, Enum, metaclass=CaseInsensitiveEnumMeta): + AUTO = "auto" + MD5 = "md5" + CRC64 = "crc64" + + @classmethod + def list(cls): + return list(map(lambda c: c.value, cls)) + + +def _verify_extensions(module: str) -> None: + try: + import azure.storage.extensions # pylint: disable=unused-import + except ImportError as exc: + raise ValueError( + f"The use of {module} requires the azure-storage-extensions package to be installed. " + f"Please install this package and try again." 
+ ) from exc + + +def parse_validation_option( + validate_content: Optional[Union[bool, Literal["auto", "crc64", "md5"]]], +) -> Optional[Union[bool, Literal["auto", "crc64", "md5"]]]: + if validate_content is None: + return None + + # Legacy support for bool + if isinstance(validate_content, bool): + return validate_content + + if validate_content not in (ChecksumAlgorithm.list()): + raise ValueError("Invalid value for `validate_content` specified.") + + # Resolve auto + if validate_content == ChecksumAlgorithm.AUTO: + validate_content = ChecksumAlgorithm.CRC64.value + + if validate_content == ChecksumAlgorithm.CRC64: + _verify_extensions("crc64") + + return validate_content + + +def is_md5_validation( + validate_content: Optional[Union[bool, Literal["md5", "crc64"]]], +) -> bool: + if validate_content is None: + return False + if isinstance(validate_content, bool): + return validate_content + return validate_content == ChecksumAlgorithm.MD5 + + +def calculate_content_md5(data: Union[bytes, IO[bytes]]) -> bytes: + md5 = hashlib.md5() # nosec + if isinstance(data, bytes): + md5.update(data) + elif hasattr(data, "read"): + pos = 0 + try: + pos = data.tell() + except: # pylint: disable=bare-except + pass + for chunk in iter(lambda: data.read(4096), b""): + md5.update(chunk) + try: + data.seek(pos, SEEK_SET) + except (AttributeError, IOError) as exc: + raise ValueError(CV_TYPE_ERROR_MSG) from exc + else: + raise ValueError(CV_TYPE_ERROR_MSG) + + return md5.digest() + + +def calculate_crc64(data: bytes, initial_crc: int) -> int: + # Locally import to avoid error if not installed. + from azure.storage.extensions import crc64 + + return cast(int, crc64.compute(data, initial_crc)) + + +def calculate_crc64_bytes(data: bytes) -> bytes: + # Locally import to avoid error if not installed. 
+ from azure.storage.extensions import crc64 + + return cast(bytes, crc64.compute(data, 0).to_bytes(CRC64_LENGTH, "little")) diff --git a/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_shared/policies.py b/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_shared/policies.py index 3f65ae8d6498..71f4cce10541 100644 --- a/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_shared/policies.py +++ b/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_shared/policies.py @@ -12,7 +12,7 @@ import uuid from io import SEEK_SET, UnsupportedOperation from time import time -from typing import Any, Dict, Optional, TYPE_CHECKING +from typing import Any, Dict, Optional, TYPE_CHECKING, Union from urllib.parse import ( parse_qsl, urlencode, @@ -32,8 +32,20 @@ ) from .authentication import AzureSigningError, StorageHttpChallenge -from .constants import DEFAULT_OAUTH_SCOPE +from .constants import DEFAULT_OAUTH_SCOPE, DATA_BLOCK_SIZE from .models import LocationMode, StorageErrorCode +from .streams import ( + StructuredMessageDecoder, + StructuredMessageEncodeStream, + StructuredMessageProperties, +) +from .validation import ( + CV_TYPE_ERROR_MSG, + calculate_content_md5, + calculate_crc64_bytes, + is_md5_validation, + ChecksumAlgorithm, +) if TYPE_CHECKING: from azure.core.credentials import TokenCredential @@ -44,9 +56,15 @@ _LOGGER = logging.getLogger(__name__) +CONTENT_LENGTH_HEADER = "Content-Length" +MD5_HEADER = "Content-MD5" +CRC64_HEADER = "x-ms-content-crc64" +SM_HEADER = "x-ms-structured-body" +SM_HEADER_V1_CRC64 = "XSM/1.0; properties=crc64" +SM_LENGTH_HEADER = "x-ms-structured-content-length" -def encode_base64(data): +def encode_base64(data: Union[bytes, str]) -> str: if isinstance(data, str): data = data.encode("utf-8") encoded = base64.b64encode(data) @@ -55,7 +73,12 @@ def encode_base64(data): # Are we out of retries? 
def is_exhausted(settings): - retry_counts = (settings["total"], settings["connect"], settings["read"], settings["status"]) + retry_counts = ( + settings["total"], + settings["connect"], + settings["read"], + settings["status"], + ) retry_counts = list(filter(None, retry_counts)) if not retry_counts: return False @@ -64,7 +87,9 @@ def is_exhausted(settings): def retry_hook(settings, **kwargs): if settings["hook"]: - settings["hook"](retry_count=settings["count"] - 1, location_mode=settings["mode"], **kwargs) + settings["hook"]( + retry_count=settings["count"] - 1, location_mode=settings["mode"], **kwargs + ) # Is this method/status code retryable? (Based on allowlists and control @@ -84,7 +109,9 @@ def is_retry(response, mode): # pylint: disable=too-many-return-statements # Response code 408 is a timeout and should be retried. return True if status >= 400: - error_code = response.http_response.headers.get("x-ms-copy-source-error-code") + error_code = response.http_response.headers.get( + "x-ms-copy-source-error-code" + ) if error_code in [ StorageErrorCode.OPERATION_TIMED_OUT, StorageErrorCode.INTERNAL_ERROR, @@ -101,12 +128,16 @@ def is_retry(response, mode): # pylint: disable=too-many-return-statements return False -def is_checksum_retry(response): - # retry if invalid content md5 - if response.context.get("validate_content", False) and response.http_response.headers.get("content-md5"): - computed_md5 = response.http_request.headers.get("content-md5", None) or encode_base64( - StorageContentValidation.get_content_md5(response.http_response.body()) - ) +def is_checksum_retry(response) -> bool: + validate_content = response.context.get("validate_content", False) + if not validate_content: + return False + + # Legacy code - evaluate retry only on validate_content=True + if validate_content is True and response.http_response.headers.get("content-md5"): + computed_md5 = response.http_request.headers.get( + "content-md5", None + ) or 
encode_base64(calculate_content_md5(response.http_response.body())) if response.http_response.headers["content-md5"] != computed_md5: return True return False @@ -135,7 +166,9 @@ def on_request(self, request: "PipelineRequest") -> None: request.http_request.headers["x-ms-date"] = current_time custom_id = request.context.options.pop("client_request_id", None) - request.http_request.headers["x-ms-client-request-id"] = custom_id or str(uuid.uuid1()) + request.http_request.headers["x-ms-client-request-id"] = custom_id or str( + uuid.uuid1() + ) # def on_response(self, request, response): # # raise exception if the echoed client request id from the service is not identical to the one we sent @@ -175,7 +208,9 @@ def on_request(self, request: "PipelineRequest") -> None: # Lock retries to the specific location request.context.options["retry_to_secondary"] = False if use_location not in self.hosts: - raise ValueError(f"Attempting to use undefined host location {use_location}") + raise ValueError( + f"Attempting to use undefined host location {use_location}" + ) if use_location != location_mode: # Update request URL to use the specified location updated = parsed_url._replace(netloc=self.hosts[use_location]) @@ -193,7 +228,9 @@ class StorageLoggingPolicy(NetworkTraceLoggingPolicy): def __init__(self, logging_enable: bool = False, **kwargs) -> None: self.logging_body = kwargs.pop("logging_body", False) - super(StorageLoggingPolicy, self).__init__(logging_enable=logging_enable, **kwargs) + super(StorageLoggingPolicy, self).__init__( + logging_enable=logging_enable, **kwargs + ) def on_request(self, request: "PipelineRequest") -> None: http_request = request.http_request @@ -222,7 +259,16 @@ def on_request(self, request: "PipelineRequest") -> None: parsed_qs["sig"] = "*****" # the SAS needs to be put back together - value = urlunparse((scheme, netloc, path, params, urlencode(parsed_qs), fragment)) + value = urlunparse( + ( + scheme, + netloc, + path, + params, + 
urlencode(parsed_qs), + fragment, + ) + ) _LOGGER.debug(" %r: %r", header, value) _LOGGER.debug("Request body:") @@ -235,7 +281,9 @@ def on_request(self, request: "PipelineRequest") -> None: except Exception as err: # pylint: disable=broad-except _LOGGER.debug("Failed to log request: %r", err) - def on_response(self, request: "PipelineRequest", response: "PipelineResponse") -> None: + def on_response( + self, request: "PipelineRequest", response: "PipelineResponse" + ) -> None: if response.context.pop("logging_enable", self.enable_http_logger): if not _LOGGER.isEnabledFor(logging.DEBUG): return @@ -250,7 +298,9 @@ def on_response(self, request: "PipelineRequest", response: "PipelineResponse") _LOGGER.debug("Response content:") pattern = re.compile(r'attachment; ?filename=["\w.]+', re.IGNORECASE) header = response.http_response.headers.get("content-disposition") - resp_content_type = response.http_response.headers.get("content-type", "") + resp_content_type = response.http_response.headers.get( + "content-type", "" + ) if header and pattern.match(header): filename = header.partition("=")[2] @@ -279,7 +329,9 @@ def __init__(self, **kwargs): super(StorageRequestHook, self).__init__() def on_request(self, request: "PipelineRequest") -> None: - request_callback = request.context.options.pop("raw_request_hook", self._request_callback) + request_callback = request.context.options.pop( + "raw_request_hook", self._request_callback + ) if request_callback: request_callback(request) @@ -297,36 +349,50 @@ def send(self, request: "PipelineRequest") -> "PipelineResponse": data_stream_total = request.context.options.pop("data_stream_total", None) download_stream_current = request.context.get("download_stream_current") if download_stream_current is None: - download_stream_current = request.context.options.pop("download_stream_current", None) + download_stream_current = request.context.options.pop( + "download_stream_current", None + ) upload_stream_current = 
request.context.get("upload_stream_current") if upload_stream_current is None: - upload_stream_current = request.context.options.pop("upload_stream_current", None) + upload_stream_current = request.context.options.pop( + "upload_stream_current", None + ) - response_callback = request.context.get("response_callback") or request.context.options.pop( - "raw_response_hook", self._response_callback - ) + response_callback = request.context.get( + "response_callback" + ) or request.context.options.pop("raw_response_hook", self._response_callback) response = self.next.send(request) - will_retry = is_retry(response, request.context.options.get("mode")) or is_checksum_retry(response) + will_retry = is_retry( + response, request.context.options.get("mode") + ) or is_checksum_retry(response) # Auth error could come from Bearer challenge, in which case this request will be made again is_auth_error = response.http_response.status_code == 401 should_update_counts = not (will_retry or is_auth_error) if should_update_counts and download_stream_current is not None: - download_stream_current += int(response.http_response.headers.get("Content-Length", 0)) + download_stream_current += int( + response.http_response.headers.get("Content-Length", 0) + ) if data_stream_total is None: content_range = response.http_response.headers.get("Content-Range") if content_range: - data_stream_total = int(content_range.split(" ", 1)[1].split("/", 1)[1]) + data_stream_total = int( + content_range.split(" ", 1)[1].split("/", 1)[1] + ) else: data_stream_total = download_stream_current elif should_update_counts and upload_stream_current is not None: - upload_stream_current += int(response.http_request.headers.get("Content-Length", 0)) + upload_stream_current += int( + response.http_request.headers.get("Content-Length", 0) + ) for pipeline_obj in [request, response]: if hasattr(pipeline_obj, "context"): pipeline_obj.context["data_stream_total"] = data_stream_total - 
pipeline_obj.context["download_stream_current"] = download_stream_current + pipeline_obj.context["download_stream_current"] = ( + download_stream_current + ) pipeline_obj.context["upload_stream_current"] = upload_stream_current if response_callback: response_callback(response) @@ -334,64 +400,133 @@ def send(self, request: "PipelineRequest") -> "PipelineResponse": return response -class StorageContentValidation(SansIOHTTPPolicy): - """A simple policy that sends the given headers - with the request. +def _prepare_content_validation(request: "PipelineRequest") -> None: + """Shared request-side logic for content validation. - This will overwrite any headers already defined in the request. + Pops 'validate_content' from options, sets up headers/streams for MD5 or CRC64 + validation, and stores the validation mode in the request context. """ + validate_content = request.context.options.pop("validate_content", False) + if not validate_content: + return - header_name = "Content-MD5" - - def __init__(self, **kwargs: Any) -> None: # pylint: disable=unused-argument - super(StorageContentValidation, self).__init__() + # Download + if request.http_request.method == "GET": + if validate_content == ChecksumAlgorithm.CRC64: + request.http_request.headers[SM_HEADER] = SM_HEADER_V1_CRC64 - @staticmethod - def get_content_md5(data): + # Upload + else: # Since HTTP does not differentiate between no content and empty content, # we have to perform a None check. 
- data = data or b"" - md5 = hashlib.md5() # nosec - if isinstance(data, bytes): - md5.update(data) - elif hasattr(data, "read"): - pos = 0 - try: - pos = data.tell() - except: # pylint: disable=bare-except - pass - for chunk in iter(lambda: data.read(4096), b""): - md5.update(chunk) - try: - data.seek(pos, SEEK_SET) - except (AttributeError, IOError) as exc: - raise ValueError("Data should be bytes or a seekable file-like object.") from exc - else: - raise ValueError("Data should be bytes or a seekable file-like object.") + data = request.http_request.data or b"" + if is_md5_validation(validate_content): + computed_md5 = encode_base64(calculate_content_md5(data)) + request.http_request.headers[MD5_HEADER] = computed_md5 + request.context["validate_content_md5"] = computed_md5 - return md5.digest() + elif validate_content == ChecksumAlgorithm.CRC64: + if isinstance(data, bytes): + request.http_request.headers[CRC64_HEADER] = encode_base64( + calculate_crc64_bytes(data) + ) + elif hasattr(data, "read"): + content_length = int( + request.http_request.headers.get(CONTENT_LENGTH_HEADER) + ) + # Wrap data in structured message stream and adjust HTTP request + sm_stream = StructuredMessageEncodeStream( + data, content_length, StructuredMessageProperties.CRC64 + ) + request.http_request.data = sm_stream + request.http_request.headers[CONTENT_LENGTH_HEADER] = str( + len(sm_stream) + ) + request.http_request.headers[SM_LENGTH_HEADER] = str(content_length) + request.http_request.headers[SM_HEADER] = SM_HEADER_V1_CRC64 + else: + raise ValueError(CV_TYPE_ERROR_MSG) + + request.context["validate_content"] = validate_content - def on_request(self, request: "PipelineRequest") -> None: - validate_content = request.context.options.pop("validate_content", False) - if validate_content and request.http_request.method != "GET": - computed_md5 = encode_base64(StorageContentValidation.get_content_md5(request.http_request.data)) - request.http_request.headers[self.header_name] = 
computed_md5 - request.context["validate_content_md5"] = computed_md5 - request.context["validate_content"] = validate_content - def on_response(self, request: "PipelineRequest", response: "PipelineResponse") -> None: - if response.context.get("validate_content", False) and response.http_response.headers.get("content-md5"): - computed_md5 = request.context.get("validate_content_md5") or encode_base64( - StorageContentValidation.get_content_md5(response.http_response.body()) +def _validate_content_response( + request: "PipelineRequest", + response: "PipelineResponse", + decoder_cls: type, +) -> None: + """Shared response-side logic for content validation. + + Checks MD5 or CRC64 validation on the response. For CRC64 GET responses, patches + ``stream_download`` to wrap the iterator in the given *decoder_cls*. + """ + validate_content = response.context.get("validate_content", False) + if not validate_content: + return + + if is_md5_validation(validate_content) and response.http_response.headers.get( + "content-md5" + ): + computed_md5 = request.context.get("validate_content_md5") or encode_base64( + calculate_content_md5(response.http_response.body()) + ) + if response.http_response.headers["content-md5"] != computed_md5: + raise AzureError( + ( + f"MD5 mismatch. Expected value is '{response.http_response.headers['content-md5']}', " + f"computed value is '{computed_md5}'." + ), + response=response.http_response, + ) + + elif validate_content == ChecksumAlgorithm.CRC64: + # For upload and download verify structured message header present in response if provided in request. + sm_request = request.http_request.headers.get(SM_HEADER) + sm_response = response.http_response.headers.get(SM_HEADER) + if sm_request != sm_response: + raise AzureError( + ( + f"Expected structured message header in response does not match request. 
" + f"Request: {sm_request}, Response: {sm_response}", + ), + response=response.http_response, ) - if response.http_response.headers["content-md5"] != computed_md5: - raise AzureError( - ( - f"MD5 mismatch. Expected value is '{response.http_response.headers['content-md5']}', " - f"computed value is '{computed_md5}'." - ), - response=response.http_response, + + if response.http_request.method == "GET": + # Raises exception if missing + content_length = int(response.http_response.headers[CONTENT_LENGTH_HEADER]) + + # Patch response to return response iterator wrapped in structured message decoder + original_stream_download = response.http_response.stream_download + + def wrapped_stream_download(*args, **kwargs): + iterator = original_stream_download(*args, **kwargs) + decoder = decoder_cls( + iterator, content_length, block_size=DATA_BLOCK_SIZE ) + decoder.request = iterator.request # type: ignore + decoder.response = iterator.response # type: ignore + return decoder + + response.http_response.stream_download = wrapped_stream_download + + +class StorageContentValidation(SansIOHTTPPolicy): + """A pipeline policy that performs content validation on uploads and downloads when enabled by the user. + This is enabled by setting the "validate_content" key in the request context. When enabled, this policy will + calculate and verify content checksums for uploads and downloads, and raise an exception if a mismatch is detected. 
+ """ + + def __init__(self, **kwargs: Any) -> None: # pylint: disable=unused-argument + super().__init__() + + def on_request(self, request: "PipelineRequest") -> None: + _prepare_content_validation(request) + + def on_response( + self, request: "PipelineRequest", response: "PipelineResponse" + ) -> None: + _validate_content_response(request, response, StructuredMessageDecoder) class StorageRetryPolicy(HTTPPolicy): @@ -418,7 +553,9 @@ def __init__(self, **kwargs: Any) -> None: self.retry_to_secondary = kwargs.pop("retry_to_secondary", False) super(StorageRetryPolicy, self).__init__() - def _set_next_host_location(self, settings: Dict[str, Any], request: "PipelineRequest") -> None: + def _set_next_host_location( + self, settings: Dict[str, Any], request: "PipelineRequest" + ) -> None: """ A function which sets the next host location on the request, if applicable. @@ -438,7 +575,7 @@ def _set_next_host_location(self, settings: Dict[str, Any], request: "PipelineRe def configure_retries(self, request: "PipelineRequest") -> Dict[str, Any]: """ Configure the retry settings for the request. - + :param request: A pipeline request object. :type request: ~azure.core.pipeline.PipelineRequest :return: A dictionary containing the retry settings. 
@@ -457,7 +594,9 @@ def configure_retries(self, request: "PipelineRequest") -> Dict[str, Any]: "connect": options.pop("retry_connect", self.connect_retries), "read": options.pop("retry_read", self.read_retries), "status": options.pop("retry_status", self.status_retries), - "retry_secondary": options.pop("retry_to_secondary", self.retry_to_secondary), + "retry_secondary": options.pop( + "retry_to_secondary", self.retry_to_secondary + ), "mode": options.pop("location_mode", LocationMode.PRIMARY), "hosts": options.pop("hosts", None), "hook": options.pop("retry_hook", None), @@ -466,7 +605,9 @@ def configure_retries(self, request: "PipelineRequest") -> Dict[str, Any]: "history": [], } - def get_backoff_time(self, settings: Dict[str, Any]) -> float: # pylint: disable=unused-argument + def get_backoff_time( + self, settings: Dict[str, Any] + ) -> float: # pylint: disable=unused-argument """Formula for computing the current backoff. Should be calculated by child class. @@ -478,7 +619,7 @@ def get_backoff_time(self, settings: Dict[str, Any]) -> float: # pylint: disabl def sleep(self, settings, transport): """Sleep for the backoff time. - + :param Dict[str, Any] settings: The configurable values pertaining to the sleep operation. :param transport: The transport to use for sleeping. :type transport: @@ -529,7 +670,9 @@ def increment( # status_forcelist and a the given method is in the allowlist if response: settings["status"] -= 1 - settings["history"].append(RequestHistory(request, http_response=response)) + settings["history"].append( + RequestHistory(request, http_response=response) + ) if not is_exhausted(settings): if request.method not in ["PUT"] and settings["retry_secondary"]: @@ -552,7 +695,7 @@ def increment( def send(self, request): """Send the request with retry logic. - + :param request: A pipeline request object. :type request: ~azure.core.pipeline.PipelineRequest :return: A pipeline response object. 
@@ -564,13 +707,20 @@ def send(self, request): while retries_remaining: try: response = self.next.send(request) - if is_retry(response, retry_settings["mode"]) or is_checksum_retry(response): + if is_retry(response, retry_settings["mode"]) or is_checksum_retry( + response + ): retries_remaining = self.increment( - retry_settings, request=request.http_request, response=response.http_response + retry_settings, + request=request.http_request, + response=response.http_response, ) if retries_remaining: retry_hook( - retry_settings, request=request.http_request, response=response.http_response, error=None + retry_settings, + request=request.http_request, + response=response.http_response, + error=None, ) self.sleep(retry_settings, request.context.transport) continue @@ -578,9 +728,16 @@ def send(self, request): except AzureError as err: if isinstance(err, AzureSigningError): raise - retries_remaining = self.increment(retry_settings, request=request.http_request, error=err) + retries_remaining = self.increment( + retry_settings, request=request.http_request, error=err + ) if retries_remaining: - retry_hook(retry_settings, request=request.http_request, response=None, error=err) + retry_hook( + retry_settings, + request=request.http_request, + response=None, + error=err, + ) self.sleep(retry_settings, request.context.transport) continue raise err @@ -633,7 +790,9 @@ def __init__( self.initial_backoff = initial_backoff self.increment_base = increment_base self.random_jitter_range = random_jitter_range - super(ExponentialRetry, self).__init__(retry_total=retry_total, retry_to_secondary=retry_to_secondary, **kwargs) + super(ExponentialRetry, self).__init__( + retry_total=retry_total, retry_to_secondary=retry_to_secondary, **kwargs + ) def get_backoff_time(self, settings: Dict[str, Any]) -> float: """ @@ -646,8 +805,14 @@ def get_backoff_time(self, settings: Dict[str, Any]) -> float: :rtype: float """ random_generator = random.Random() - backoff = self.initial_backoff + (0 if 
settings["count"] == 0 else pow(self.increment_base, settings["count"])) - random_range_start = backoff - self.random_jitter_range if backoff > self.random_jitter_range else 0 + backoff = self.initial_backoff + ( + 0 if settings["count"] == 0 else pow(self.increment_base, settings["count"]) + ) + random_range_start = ( + backoff - self.random_jitter_range + if backoff > self.random_jitter_range + else 0 + ) random_range_end = backoff + self.random_jitter_range return random_generator.uniform(random_range_start, random_range_end) @@ -685,7 +850,9 @@ def __init__( """ self.backoff = backoff self.random_jitter_range = random_jitter_range - super(LinearRetry, self).__init__(retry_total=retry_total, retry_to_secondary=retry_to_secondary, **kwargs) + super(LinearRetry, self).__init__( + retry_total=retry_total, retry_to_secondary=retry_to_secondary, **kwargs + ) def get_backoff_time(self, settings: Dict[str, Any]) -> float: """ @@ -700,7 +867,11 @@ def get_backoff_time(self, settings: Dict[str, Any]) -> float: random_generator = random.Random() # the backoff interval normally does not change, however there is the possibility # that it was modified by accessing the property directly after initializing the object - random_range_start = self.backoff - self.random_jitter_range if self.backoff > self.random_jitter_range else 0 + random_range_start = ( + self.backoff - self.random_jitter_range + if self.backoff > self.random_jitter_range + else 0 + ) random_range_end = self.backoff + self.random_jitter_range return random_generator.uniform(random_range_start, random_range_end) @@ -708,16 +879,22 @@ def get_backoff_time(self, settings: Dict[str, Any]) -> float: class StorageBearerTokenCredentialPolicy(BearerTokenCredentialPolicy): """Custom Bearer token credential policy for following Storage Bearer challenges""" - def __init__(self, credential: "TokenCredential", audience: str, **kwargs: Any) -> None: - super(StorageBearerTokenCredentialPolicy, self).__init__(credential, 
audience, **kwargs) + def __init__( + self, credential: "TokenCredential", audience: str, **kwargs: Any + ) -> None: + super(StorageBearerTokenCredentialPolicy, self).__init__( + credential, audience, **kwargs + ) - def on_challenge(self, request: "PipelineRequest", response: "PipelineResponse") -> bool: + def on_challenge( + self, request: "PipelineRequest", response: "PipelineResponse" + ) -> bool: """Handle the challenge from the service and authorize the request. - + :param request: The request object. :type request: ~azure.core.pipeline.PipelineRequest :param response: The response object. - :type response: ~azure.core.pipeline.PipelineResponse + :type response: ~azure.core.pipeline.PipelineResponse :return: True if the request was authorized, False otherwise. :rtype: bool """ diff --git a/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_shared/policies_async.py b/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_shared/policies_async.py index 4cb32f23248b..14ce070e47ff 100644 --- a/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_shared/policies_async.py +++ b/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_shared/policies_async.py @@ -11,11 +11,25 @@ from typing import Any, Dict, TYPE_CHECKING from azure.core.exceptions import AzureError, StreamClosedError, StreamConsumedError -from azure.core.pipeline.policies import AsyncBearerTokenCredentialPolicy, AsyncHTTPPolicy +from azure.core.pipeline.policies import ( + AsyncBearerTokenCredentialPolicy, + AsyncHTTPPolicy, +) from .authentication import AzureSigningError, StorageHttpChallenge from .constants import DEFAULT_OAUTH_SCOPE -from .policies import encode_base64, is_retry, StorageContentValidation, StorageRetryPolicy +from .policies import ( + _prepare_content_validation, + _validate_content_response, + encode_base64, + is_retry, + StorageRetryPolicy, +) +from .streams_async import AsyncStructuredMessageDecoder +from .validation import ( + 
calculate_content_md5, + is_md5_validation, +) if TYPE_CHECKING: from azure.core.credentials_async import AsyncTokenCredential @@ -31,27 +45,66 @@ async def retry_hook(settings, **kwargs): if settings["hook"]: if asyncio.iscoroutine(settings["hook"]): - await settings["hook"](retry_count=settings["count"] - 1, location_mode=settings["mode"], **kwargs) + await settings["hook"]( + retry_count=settings["count"] - 1, + location_mode=settings["mode"], + **kwargs + ) else: - settings["hook"](retry_count=settings["count"] - 1, location_mode=settings["mode"], **kwargs) + settings["hook"]( + retry_count=settings["count"] - 1, + location_mode=settings["mode"], + **kwargs + ) async def is_checksum_retry(response): - # retry if invalid content md5 - if response.context.get("validate_content", False) and response.http_response.headers.get("content-md5"): + validate_content = response.context.get("validate_content", False) + if not validate_content: + return False + + # Legacy code - evaluate retry only on validate_content=True + if validate_content is True and response.http_response.headers.get("content-md5"): if hasattr(response.http_response, "load_body"): try: await response.http_response.load_body() # Load the body in memory and close the socket except (StreamClosedError, StreamConsumedError): pass - computed_md5 = response.http_request.headers.get("content-md5", None) or encode_base64( - StorageContentValidation.get_content_md5(response.http_response.body()) - ) + computed_md5 = response.http_request.headers.get( + "content-md5", None + ) or encode_base64(calculate_content_md5(response.http_response.body())) if response.http_response.headers["content-md5"] != computed_md5: return True return False +class AsyncContentValidationPolicy(AsyncHTTPPolicy): + """A pipeline policy that performs content validation on uploads and downloads when enabled by the user. + This is enabled by setting the "validate_content" key in the request context. 
When enabled, this policy will + calculate and verify content checksums for uploads and downloads, and raise an exception if a mismatch is detected. + """ + + def __init__(self, **kwargs: Any) -> None: # pylint: disable=unused-argument + super().__init__() + + async def send(self, request: "PipelineRequest") -> "PipelineResponse": + _prepare_content_validation(request) + + response = await self.next.send(request) + + validate_content = response.context.get("validate_content", False) + if validate_content and is_md5_validation(validate_content): + if hasattr(response.http_response, "load_body"): + try: + await response.http_response.load_body() + except (StreamClosedError, StreamConsumedError): + pass + + _validate_content_response(request, response, AsyncStructuredMessageDecoder) + + return response + + class AsyncStorageResponseHook(AsyncHTTPPolicy): def __init__(self, **kwargs): @@ -65,36 +118,50 @@ async def send(self, request: "PipelineRequest") -> "PipelineResponse": data_stream_total = request.context.options.pop("data_stream_total", None) download_stream_current = request.context.get("download_stream_current") if download_stream_current is None: - download_stream_current = request.context.options.pop("download_stream_current", None) + download_stream_current = request.context.options.pop( + "download_stream_current", None + ) upload_stream_current = request.context.get("upload_stream_current") if upload_stream_current is None: - upload_stream_current = request.context.options.pop("upload_stream_current", None) + upload_stream_current = request.context.options.pop( + "upload_stream_current", None + ) - response_callback = request.context.get("response_callback") or request.context.options.pop( - "raw_response_hook", self._response_callback - ) + response_callback = request.context.get( + "response_callback" + ) or request.context.options.pop("raw_response_hook", self._response_callback) response = await self.next.send(request) - will_retry = 
is_retry(response, request.context.options.get("mode")) or await is_checksum_retry(response) + will_retry = is_retry( + response, request.context.options.get("mode") + ) or await is_checksum_retry(response) # Auth error could come from Bearer challenge, in which case this request will be made again is_auth_error = response.http_response.status_code == 401 should_update_counts = not (will_retry or is_auth_error) if should_update_counts and download_stream_current is not None: - download_stream_current += int(response.http_response.headers.get("Content-Length", 0)) + download_stream_current += int( + response.http_response.headers.get("Content-Length", 0) + ) if data_stream_total is None: content_range = response.http_response.headers.get("Content-Range") if content_range: - data_stream_total = int(content_range.split(" ", 1)[1].split("/", 1)[1]) + data_stream_total = int( + content_range.split(" ", 1)[1].split("/", 1)[1] + ) else: data_stream_total = download_stream_current elif should_update_counts and upload_stream_current is not None: - upload_stream_current += int(response.http_request.headers.get("Content-Length", 0)) + upload_stream_current += int( + response.http_request.headers.get("Content-Length", 0) + ) for pipeline_obj in [request, response]: if hasattr(pipeline_obj, "context"): pipeline_obj.context["data_stream_total"] = data_stream_total - pipeline_obj.context["download_stream_current"] = download_stream_current + pipeline_obj.context["download_stream_current"] = ( + download_stream_current + ) pipeline_obj.context["upload_stream_current"] = upload_stream_current if response_callback: if asyncio.iscoroutine(response_callback): @@ -123,13 +190,20 @@ async def send(self, request): while retries_remaining: try: response = await self.next.send(request) - if is_retry(response, retry_settings["mode"]) or await is_checksum_retry(response): + if is_retry( + response, retry_settings["mode"] + ) or await is_checksum_retry(response): retries_remaining = 
self.increment( - retry_settings, request=request.http_request, response=response.http_response + retry_settings, + request=request.http_request, + response=response.http_response, ) if retries_remaining: await retry_hook( - retry_settings, request=request.http_request, response=response.http_response, error=None + retry_settings, + request=request.http_request, + response=response.http_response, + error=None, ) await self.sleep(retry_settings, request.context.transport) continue @@ -137,9 +211,16 @@ async def send(self, request): except AzureError as err: if isinstance(err, AzureSigningError): raise - retries_remaining = self.increment(retry_settings, request=request.http_request, error=err) + retries_remaining = self.increment( + retry_settings, request=request.http_request, error=err + ) if retries_remaining: - await retry_hook(retry_settings, request=request.http_request, response=None, error=err) + await retry_hook( + retry_settings, + request=request.http_request, + response=None, + error=err, + ) await self.sleep(retry_settings, request.context.transport) continue raise err @@ -194,7 +275,9 @@ def __init__( self.initial_backoff = initial_backoff self.increment_base = increment_base self.random_jitter_range = random_jitter_range - super(ExponentialRetry, self).__init__(retry_total=retry_total, retry_to_secondary=retry_to_secondary, **kwargs) + super(ExponentialRetry, self).__init__( + retry_total=retry_total, retry_to_secondary=retry_to_secondary, **kwargs + ) def get_backoff_time(self, settings: Dict[str, Any]) -> float: """ @@ -207,8 +290,14 @@ def get_backoff_time(self, settings: Dict[str, Any]) -> float: :rtype: int or None """ random_generator = random.Random() - backoff = self.initial_backoff + (0 if settings["count"] == 0 else pow(self.increment_base, settings["count"])) - random_range_start = backoff - self.random_jitter_range if backoff > self.random_jitter_range else 0 + backoff = self.initial_backoff + ( + 0 if settings["count"] == 0 else 
pow(self.increment_base, settings["count"]) + ) + random_range_start = ( + backoff - self.random_jitter_range + if backoff > self.random_jitter_range + else 0 + ) random_range_end = backoff + self.random_jitter_range return random_generator.uniform(random_range_start, random_range_end) @@ -246,7 +335,9 @@ def __init__( """ self.backoff = backoff self.random_jitter_range = random_jitter_range - super(LinearRetry, self).__init__(retry_total=retry_total, retry_to_secondary=retry_to_secondary, **kwargs) + super(LinearRetry, self).__init__( + retry_total=retry_total, retry_to_secondary=retry_to_secondary, **kwargs + ) def get_backoff_time(self, settings: Dict[str, Any]) -> float: """ @@ -261,7 +352,11 @@ def get_backoff_time(self, settings: Dict[str, Any]) -> float: random_generator = random.Random() # the backoff interval normally does not change, however there is the possibility # that it was modified by accessing the property directly after initializing the object - random_range_start = self.backoff - self.random_jitter_range if self.backoff > self.random_jitter_range else 0 + random_range_start = ( + self.backoff - self.random_jitter_range + if self.backoff > self.random_jitter_range + else 0 + ) random_range_end = self.backoff + self.random_jitter_range return random_generator.uniform(random_range_start, random_range_end) @@ -269,10 +364,16 @@ def get_backoff_time(self, settings: Dict[str, Any]) -> float: class AsyncStorageBearerTokenCredentialPolicy(AsyncBearerTokenCredentialPolicy): """Custom Bearer token credential policy for following Storage Bearer challenges""" - def __init__(self, credential: "AsyncTokenCredential", audience: str, **kwargs: Any) -> None: - super(AsyncStorageBearerTokenCredentialPolicy, self).__init__(credential, audience, **kwargs) + def __init__( + self, credential: "AsyncTokenCredential", audience: str, **kwargs: Any + ) -> None: + super(AsyncStorageBearerTokenCredentialPolicy, self).__init__( + credential, audience, **kwargs + ) - async 
def on_challenge(self, request: "PipelineRequest", response: "PipelineResponse") -> bool: + async def on_challenge( + self, request: "PipelineRequest", response: "PipelineResponse" + ) -> bool: try: auth_header = response.http_response.headers.get("WWW-Authenticate") challenge = StorageHttpChallenge(auth_header) diff --git a/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_shared/streams.py b/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_shared/streams.py new file mode 100644 index 000000000000..712f4e90af69 --- /dev/null +++ b/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_shared/streams.py @@ -0,0 +1,703 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +import math +import sys +from enum import auto, Enum, IntFlag +from io import BytesIO, IOBase, UnsupportedOperation, SEEK_CUR, SEEK_END, SEEK_SET +from typing import IO, Iterator, Optional + +from .validation import calculate_crc64 + +DEFAULT_MESSAGE_VERSION = 1 +DEFAULT_SEGMENT_SIZE = 4 * 1024 * 1024 + + +class StructuredMessageConstants: + V1_HEADER_LENGTH = 13 + V1_SEGMENT_HEADER_LENGTH = 10 + CRC64_LENGTH = 8 + + +class StructuredMessageProperties(IntFlag): + NONE = 0 + CRC64 = auto() + + +class SMRegion(Enum): + MESSAGE_HEADER = 1 + SEGMENT_HEADER = 2 + SEGMENT_CONTENT = 3 + SEGMENT_FOOTER = 4 + MESSAGE_FOOTER = 5 + + +def generate_message_header( + version: int, size: int, flags: StructuredMessageProperties, num_segments: int +) -> bytes: + return ( + version.to_bytes(1, "little") + + size.to_bytes(8, "little") + + flags.to_bytes(2, "little") + + num_segments.to_bytes(2, "little") + ) + + +def generate_segment_header(number: int, size: int) -> bytes: + return number.to_bytes(2, "little") + 
size.to_bytes(8, "little") + + +def parse_message_header( + data: bytes, expected_message_length: int +) -> tuple[int, StructuredMessageProperties, int]: + version = data[0] + if version != 1: + raise ValueError(f"The structured message version is not supported: {version}") + message_length = int.from_bytes(data[1:9], "little") + if message_length != expected_message_length: + raise ValueError( + f"Structured message length {message_length} " + f"did not match content length {expected_message_length}" + ) + flags = StructuredMessageProperties(int.from_bytes(data[9:11], "little")) + num_segments = int.from_bytes(data[11:13], "little") + return version, flags, num_segments + + +def parse_segment_header(data: bytes, expected_segment_number: int) -> tuple[int, int]: + segment_number = int.from_bytes(data[0:2], "little") + if segment_number != expected_segment_number: + raise ValueError( + f"Structured message segment number invalid or out of order {segment_number}" + ) + segment_content_length = int.from_bytes(data[2:10], "little") + return segment_number, segment_content_length + + +class StructuredMessageEncodeStream( + IOBase +): # pylint: disable=too-many-instance-attributes + message_version: int + content_length: int + message_length: int + flags: StructuredMessageProperties + + _inner_stream: IO[bytes] + _segment_size: int + _num_segments: int + + _initial_content_position: Optional[int] + """Initial position of the inner stream, None if it did not implement tell()""" + _content_offset: int + _current_segment_number: int + _current_region: SMRegion + _current_region_length: int + _current_region_offset: int + + _checksum_offset: int + """Tracks the offset the checksum has been calculated up to for seeking purposes""" + + _message_crc64: int + _segment_crc64s: dict[int, int] + + def __init__( + self, + inner_stream: IO[bytes], + content_length: int, + flags: StructuredMessageProperties, + *, + segment_size: int = DEFAULT_SEGMENT_SIZE, + ) -> None: + if 
segment_size < 1: + raise ValueError("Segment size must be greater than 0.") + + self.message_version = DEFAULT_MESSAGE_VERSION + self.content_length = content_length + self.flags = flags + + self._inner_stream = inner_stream + self._segment_size = segment_size + self._num_segments = math.ceil(self.content_length / self._segment_size) or 1 + + self.message_length = self._calculate_message_length() + + self._content_offset = 0 + self._current_segment_number = 0 # Will be incremented before first segment + self._current_region = SMRegion.MESSAGE_HEADER + self._current_region_length = self._message_header_length + self._current_region_offset = 0 + + self._checksum_offset = 0 + self._message_crc64 = 0 + self._segment_crc64s = {} + + # Attempt to get starting position of inner stream. If we can't, this stream will not be seekable + try: + self._initial_content_position = self._inner_stream.tell() + except (AttributeError, UnsupportedOperation, OSError): + self._initial_content_position = None + super().__init__() + + @property + def _message_header_length(self) -> int: + return StructuredMessageConstants.V1_HEADER_LENGTH + + @property + def _segment_header_length(self) -> int: + return StructuredMessageConstants.V1_SEGMENT_HEADER_LENGTH + + @property + def _segment_footer_length(self) -> int: + return ( + StructuredMessageConstants.CRC64_LENGTH + if StructuredMessageProperties.CRC64 in self.flags + else 0 + ) + + @property + def _message_footer_length(self) -> int: + return ( + StructuredMessageConstants.CRC64_LENGTH + if StructuredMessageProperties.CRC64 in self.flags + else 0 + ) + + def _update_current_region_length(self) -> None: + if self._current_region == SMRegion.MESSAGE_HEADER: + self._current_region_length = self._message_header_length + elif self._current_region == SMRegion.SEGMENT_HEADER: + self._current_region_length = self._segment_header_length + elif self._current_region == SMRegion.SEGMENT_CONTENT: + # Last segment size is remaining content + if 
self._current_segment_number == self._num_segments: + self._current_region_length = self.content_length - ( + (self._current_segment_number - 1) * self._segment_size + ) + else: + self._current_region_length = self._segment_size + elif self._current_region == SMRegion.SEGMENT_FOOTER: + self._current_region_length = self._segment_footer_length + elif self._current_region == SMRegion.MESSAGE_FOOTER: + self._current_region_length = self._message_footer_length + else: + raise ValueError(f"Invalid SMRegion {self._current_region}") + + def __len__(self): + return self.message_length + + def close(self) -> None: + self._inner_stream.close() + super().close() + + def readable(self) -> bool: + return True + + def seekable(self) -> bool: + try: + # Only seekable if the inner stream is and we could get its initial position + return ( + self._inner_stream.seekable() + and self._initial_content_position is not None + ) + except (AttributeError, UnsupportedOperation, OSError): + return False + + def tell(self) -> int: + if self._current_region == SMRegion.MESSAGE_HEADER: + return self._current_region_offset + if self._current_region == SMRegion.SEGMENT_HEADER: + return ( + self._message_header_length + + self._content_offset + + (self._current_segment_number - 1) + * (self._segment_header_length + self._segment_footer_length) + + self._current_region_offset + ) + if self._current_region == SMRegion.SEGMENT_CONTENT: + return ( + self._message_header_length + + self._content_offset + + (self._current_segment_number - 1) + * (self._segment_header_length + self._segment_footer_length) + + self._segment_header_length + ) + if self._current_region == SMRegion.SEGMENT_FOOTER: + return ( + self._message_header_length + + self._content_offset + + (self._current_segment_number - 1) + * (self._segment_header_length + self._segment_footer_length) + + self._segment_header_length + + self._current_region_offset + ) + if self._current_region == SMRegion.MESSAGE_FOOTER: + return ( + 
self._message_header_length + + self._content_offset + + self._current_segment_number + * (self._segment_header_length + self._segment_footer_length) + + self._current_region_offset + ) + + raise ValueError(f"Invalid SMRegion {self._current_region}") + + def seek(self, offset: int, whence: int = SEEK_SET) -> int: + if not self.seekable(): + raise UnsupportedOperation("Inner stream is not seekable.") + + if whence == SEEK_SET: + position = offset + elif whence == SEEK_CUR: + position = self.tell() + offset + elif whence == SEEK_END: + position = self.message_length + offset + else: + raise ValueError(f"Invalid value for whence: {whence}") + + if position < 0: + raise ValueError(f"Cannot seek to negative position: {position}") + if position > self.tell(): + raise UnsupportedOperation("This stream only supports seeking backwards.") + + # MESSAGE_HEADER + if position < self._message_header_length: + self._current_region = SMRegion.MESSAGE_HEADER + self._current_region_offset = position + self._content_offset = 0 + self._current_segment_number = 0 + # MESSAGE_FOOTER + elif position >= self.message_length - self._message_footer_length: + self._current_region = SMRegion.MESSAGE_FOOTER + self._current_region_offset = position - ( + self.message_length - self._message_footer_length + ) + self._content_offset = self.content_length + self._current_segment_number = self._num_segments + else: + # The size of a "full" segment. 
Fine to use for calculating new segment number and pos + full_segment_size = ( + self._segment_header_length + + self._segment_size + + self._segment_footer_length + ) + new_segment_num = ( + 1 + (position - self._message_header_length) // full_segment_size + ) + segment_pos = (position - self._message_header_length) % full_segment_size + previous_segments_total_content_size = ( + new_segment_num - 1 + ) * self._segment_size + + # We need the size of the segment we are seeking to for some of the calculations below + new_segment_size = self._segment_size + if new_segment_num == self._num_segments: + # The last segment size is the remaining content length + new_segment_size = ( + self.content_length - previous_segments_total_content_size + ) + + # SEGMENT_HEADER + if segment_pos < self._segment_header_length: + self._current_region = SMRegion.SEGMENT_HEADER + self._current_region_offset = segment_pos + self._content_offset = previous_segments_total_content_size + # SEGMENT_CONTENT + elif segment_pos < self._segment_header_length + new_segment_size: + self._current_region = SMRegion.SEGMENT_CONTENT + self._current_region_offset = segment_pos - self._segment_header_length + self._content_offset = ( + previous_segments_total_content_size + self._current_region_offset + ) + # SEGMENT_FOOTER + else: + self._current_region = SMRegion.SEGMENT_FOOTER + self._current_region_offset = ( + segment_pos - self._segment_header_length - new_segment_size + ) + self._content_offset = ( + previous_segments_total_content_size + new_segment_size + ) + + self._current_segment_number = new_segment_num + + self._update_current_region_length() + self._inner_stream.seek( + (self._initial_content_position or 0) + self._content_offset + ) + return position + + def read(self, size: int = -1) -> bytes: + if self.closed: # pylint: disable=using-constant-test + raise ValueError("Stream is closed") + + if size == 0: + return b"" + if size < 0: + size = sys.maxsize + + count = 0 + output = BytesIO() 
+ + while count < size and self.tell() < self.message_length: + remaining = size - count + if self._current_region in ( + SMRegion.MESSAGE_HEADER, + SMRegion.SEGMENT_HEADER, + SMRegion.SEGMENT_FOOTER, + SMRegion.MESSAGE_FOOTER, + ): + count += self._read_metadata_region( + self._current_region, remaining, output + ) + elif self._current_region == SMRegion.SEGMENT_CONTENT: + count += self._read_content(remaining, output) + else: + raise ValueError(f"Invalid SMRegion {self._current_region}") + + return output.getvalue() + + def _calculate_message_length(self) -> int: + length = self._message_header_length + length += ( + self._segment_header_length + self._segment_footer_length + ) * self._num_segments + length += self.content_length + length += self._message_footer_length + return length + + def _get_metadata_region(self, region: SMRegion) -> bytes: + if region == SMRegion.MESSAGE_HEADER: + return generate_message_header( + self.message_version, + self.message_length, + self.flags, + self._num_segments, + ) + + if region == SMRegion.SEGMENT_HEADER: + segment_size = min( + self._segment_size, self.content_length - self._content_offset + ) + return generate_segment_header(self._current_segment_number, segment_size) + + if region == SMRegion.SEGMENT_FOOTER: + if StructuredMessageProperties.CRC64 in self.flags: + return self._segment_crc64s[self._current_segment_number].to_bytes( + StructuredMessageConstants.CRC64_LENGTH, "little" + ) + return b"" + + if region == SMRegion.MESSAGE_FOOTER: + if StructuredMessageProperties.CRC64 in self.flags: + return self._message_crc64.to_bytes( + StructuredMessageConstants.CRC64_LENGTH, "little" + ) + return b"" + + raise ValueError(f"Invalid metadata SMRegion {self._current_region}") + + def _advance_region(self, current: SMRegion): + self._current_region_offset = 0 + + if current == SMRegion.MESSAGE_HEADER: + self._current_region = SMRegion.SEGMENT_HEADER + self._increment_current_segment() + elif current == SMRegion.SEGMENT_HEADER: 
+ self._current_region = SMRegion.SEGMENT_CONTENT + elif current == SMRegion.SEGMENT_CONTENT: + self._current_region = SMRegion.SEGMENT_FOOTER + elif current == SMRegion.SEGMENT_FOOTER: + # If we're at the end of the content + if self._content_offset == self.content_length: + self._current_region = SMRegion.MESSAGE_FOOTER + else: + self._current_region = SMRegion.SEGMENT_HEADER + self._increment_current_segment() + else: + raise ValueError(f"Invalid SMRegion {self._current_region}") + + self._update_current_region_length() + + def _read_metadata_region( + self, region: SMRegion, size: int, output: BytesIO + ) -> int: + metadata = self._get_metadata_region(region) + + read_size = min(size, self._current_region_length - self._current_region_offset) + content = metadata[ + self._current_region_offset : self._current_region_offset + read_size + ] + output.write(content) + + self._current_region_offset += read_size + if ( + self._current_region_offset == self._current_region_length + and self._current_region != SMRegion.MESSAGE_FOOTER + ): + self._advance_region(region) + + return read_size + + def _read_content(self, size: int, output: BytesIO) -> int: + # Will be non-zero if there is data to read that does not need its checksum calculated. + # Will never be negative, as the stream can only seek backwards.
+ checksum_offset = self._checksum_offset - self._content_offset + + read_size = min(size, self._current_region_length - self._current_region_offset) + if checksum_offset != 0: + # Only read up to checksum offset this iteration + read_size = min(read_size, checksum_offset) + + content = self._inner_stream.read(read_size) + if len(content) != read_size: + raise ValueError("Content ended early when encoding structured message.") + output.write(content) + + if StructuredMessageProperties.CRC64 in self.flags: + if checksum_offset == 0: + self._segment_crc64s[self._current_segment_number] = calculate_crc64( + content, self._segment_crc64s[self._current_segment_number] + ) + self._message_crc64 = calculate_crc64(content, self._message_crc64) + + self._content_offset += read_size + # Only update the checksum offset if we've read new data + if self._content_offset > self._checksum_offset: + self._checksum_offset += read_size + self._current_region_offset += read_size + if self._current_region_offset == self._current_region_length: + self._advance_region(SMRegion.SEGMENT_CONTENT) + + return read_size + + def _increment_current_segment(self): + self._current_segment_number += 1 + if StructuredMessageProperties.CRC64 in self.flags: + # If seek was used, we may already have this segment's CRC (could be partial), otherwise initialize to 0 + self._segment_crc64s.setdefault(self._current_segment_number, 0) + + +class StructuredMessageDecoder(IOBase): # pylint: disable=too-many-instance-attributes + + message_version: int + """The version of the structured message.""" + message_length: int + """The total length of the structured message.""" + flags: StructuredMessageProperties + """The properties included in the structured message.""" + num_segments: int + """The number of message segments.""" + + _inner_iterator: Iterator[bytes] + _buffer: bytes + _message_offset: int + _message_crc64: int + _segment_number: int + _segment_crc64: int + _segment_content_length: int + 
_segment_content_offset: int + _block_size: int + + def __init__( + self, + inner_iterator: Iterator[bytes], + content_length: int, + *, + block_size: int = 4096, + ) -> None: + self.message_length = content_length + # The stream should be at least long enough to hold minimum header length + if self.message_length < StructuredMessageConstants.V1_HEADER_LENGTH: + raise ValueError( + "Content not long enough to contain a valid message header." + ) + + self._inner_iterator = inner_iterator + self._buffer = b"" + self._message_offset = 0 + self._message_crc64 = 0 + + self._segment_number = 0 + self._segment_crc64 = 0 + self._segment_content_length = 0 + self._segment_content_offset = 0 + self._block_size = block_size + super().__init__() + + @property + def content_length(self) -> int: + return self.message_length + + @property + def _segment_header_length(self) -> int: + return StructuredMessageConstants.V1_SEGMENT_HEADER_LENGTH + + @property + def _segment_footer_length(self) -> int: + return ( + StructuredMessageConstants.CRC64_LENGTH + if StructuredMessageProperties.CRC64 in self.flags + else 0 + ) + + @property + def _message_footer_length(self) -> int: + return ( + StructuredMessageConstants.CRC64_LENGTH + if StructuredMessageProperties.CRC64 in self.flags + else 0 + ) + + @property + def _end_of_segment_content(self) -> bool: + return self._segment_content_offset == self._segment_content_length + + def readable(self) -> bool: + return True + + def seekable(self) -> bool: + return False + + def __iter__(self): + return self + + def __next__(self) -> bytes: + data = self.read(self._block_size) + if not data: + raise StopIteration + return data + + def read(self, size: int = -1) -> bytes: + if self.closed: + raise ValueError("Stream is closed") + + if size == 0 or self._message_offset >= self.message_length: + return b"" + if size < 0: + size = sys.maxsize + + # On the first read, read message header and first segment header + if self._message_offset == 0: + 
self._read_message_header() + self._read_segment_header() + + # Special case for 0 length content + if self._end_of_segment_content: + self._read_segment_footer() + if self.num_segments > 1: + raise ValueError( + "First message segment was empty but more segments were detected." + ) + self._read_message_footer() + return b"" + + count = 0 + content = BytesIO() + while count < size and not ( + self._end_of_segment_content and self._message_offset == self.message_length + ): + if self._end_of_segment_content: + self._read_segment_header() + + segment_remaining = ( + self._segment_content_length - self._segment_content_offset + ) + read_size = min(segment_remaining, size - count) + + segment_content = self._read_from_inner(read_size) + content.write(segment_content) + + # Update the running CRC64 for the segment and message + if StructuredMessageProperties.CRC64 in self.flags: + self._segment_crc64 = calculate_crc64( + segment_content, self._segment_crc64 + ) + self._message_crc64 = calculate_crc64( + segment_content, self._message_crc64 + ) + + self._segment_content_offset += read_size + self._message_offset += read_size + count += read_size + + if self._end_of_segment_content: + self._read_segment_footer() + # If we are on the last segment, also read the message footer + if self._segment_number == self.num_segments: + self._read_message_footer() + + # One final check to ensure if we think we've reached the end of the stream + # that the current segment number matches the total. + if ( + self._message_offset == self.message_length + and self._segment_number != self.num_segments + ): + raise ValueError("Invalid structured message data detected.") + + return content.getvalue() + + def _read_from_inner(self, size: int) -> bytes: + while len(self._buffer) < size: + try: + chunk = next(self._inner_iterator) + except StopIteration: + break + self._buffer += chunk + + if len(self._buffer) < size: + raise ValueError( + "Invalid structured message data detected. 
Stream content incomplete." + ) + + data = self._buffer[:size] + self._buffer = self._buffer[size:] + return data + + def _read_message_header(self) -> None: + header_data = self._read_from_inner(StructuredMessageConstants.V1_HEADER_LENGTH) + self.message_version, self.flags, self.num_segments = parse_message_header( + header_data, self.message_length + ) + self._message_offset += StructuredMessageConstants.V1_HEADER_LENGTH + + def _read_message_footer(self) -> None: + # Sanity check: There should only be self._message_footer_length (could be 0) bytes left to consume. + # If not, it is likely the message header contained incorrect info. + if self.message_length - self._message_offset != self._message_footer_length: + raise ValueError("Invalid structured message data detected.") + + if StructuredMessageProperties.CRC64 in self.flags: + message_crc = self._read_from_inner(StructuredMessageConstants.CRC64_LENGTH) + + if self._message_crc64 != int.from_bytes(message_crc, "little"): + raise ValueError( + "CRC64 mismatch detected in message trailer. " + "All data read should be considered invalid." + ) + + self._message_offset += self._message_footer_length + + def _read_segment_header(self) -> None: + header_data = self._read_from_inner(self._segment_header_length) + self._segment_number, self._segment_content_length = parse_segment_header( + header_data, self._segment_number + 1 + ) + self._message_offset += self._segment_header_length + + self._segment_content_offset = 0 + self._segment_crc64 = 0 + + def _read_segment_footer(self) -> None: + if StructuredMessageProperties.CRC64 in self.flags: + segment_crc = self._read_from_inner(StructuredMessageConstants.CRC64_LENGTH) + + if self._segment_crc64 != int.from_bytes(segment_crc, "little"): + raise ValueError( + f"CRC64 mismatch detected in segment {self._segment_number}. " + f"All data read should be considered invalid." 
+ ) + + self._message_offset += self._segment_footer_length diff --git a/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_shared/streams_async.py b/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_shared/streams_async.py new file mode 100644 index 000000000000..ee7d92d14d77 --- /dev/null +++ b/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_shared/streams_async.py @@ -0,0 +1,248 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +import sys +from io import BytesIO, IOBase +from typing import AsyncIterator + +from .streams import ( + StructuredMessageConstants, + StructuredMessageProperties, + parse_message_header, + parse_segment_header, +) +from .validation import calculate_crc64 + + +class AsyncStructuredMessageDecoder( + IOBase +): # pylint: disable=too-many-instance-attributes + + message_version: int + """The version of the structured message.""" + message_length: int + """The total length of the structured message.""" + flags: StructuredMessageProperties + """The properties included in the structured message.""" + num_segments: int + """The number of message segments.""" + + _inner_iterator: AsyncIterator[bytes] + _buffer: bytes + _message_offset: int + _message_crc64: int + _segment_number: int + _segment_crc64: int + _segment_content_length: int + _segment_content_offset: int + _block_size: int + + def __init__( + self, + inner_iterator: AsyncIterator[bytes], + content_length: int, + *, + block_size: int = 4096, + ) -> None: + self.message_length = content_length + # The stream should be at least long enough to hold minimum header length + if self.message_length < StructuredMessageConstants.V1_HEADER_LENGTH: + raise ValueError( + "Content not long enough 
to contain a valid message header." + ) + + self._inner_iterator = inner_iterator + self._buffer = b"" + self._message_offset = 0 + self._message_crc64 = 0 + + self._segment_number = 0 + self._segment_crc64 = 0 + self._segment_content_length = 0 + self._segment_content_offset = 0 + self._block_size = block_size + super().__init__() + + @property + def content_length(self) -> int: + return self.message_length + + @property + def _segment_header_length(self) -> int: + return StructuredMessageConstants.V1_SEGMENT_HEADER_LENGTH + + @property + def _segment_footer_length(self) -> int: + return ( + StructuredMessageConstants.CRC64_LENGTH + if StructuredMessageProperties.CRC64 in self.flags + else 0 + ) + + @property + def _message_footer_length(self) -> int: + return ( + StructuredMessageConstants.CRC64_LENGTH + if StructuredMessageProperties.CRC64 in self.flags + else 0 + ) + + @property + def _end_of_segment_content(self) -> bool: + return self._segment_content_offset == self._segment_content_length + + def readable(self) -> bool: + return True + + def seekable(self) -> bool: + return False + + def __aiter__(self): + return self + + async def __anext__(self) -> bytes: + data = await self.read(self._block_size) + if not data: + raise StopAsyncIteration + return data + + async def read(self, size: int = -1) -> bytes: + if self.closed: + raise ValueError("Stream is closed") + + if size == 0 or self._message_offset >= self.message_length: + return b"" + if size < 0: + size = sys.maxsize + + # On the first read, read message header and first segment header + if self._message_offset == 0: + await self._read_message_header() + await self._read_segment_header() + + # Special case for 0 length content + if self._end_of_segment_content: + await self._read_segment_footer() + if self.num_segments > 1: + raise ValueError( + "First message segment was empty but more segments were detected." 
+ ) + await self._read_message_footer() + return b"" + + count = 0 + content = BytesIO() + while count < size and not ( + self._end_of_segment_content and self._message_offset == self.message_length + ): + if self._end_of_segment_content: + await self._read_segment_header() + + segment_remaining = ( + self._segment_content_length - self._segment_content_offset + ) + read_size = min(segment_remaining, size - count) + + segment_content = await self._read_from_inner(read_size) + content.write(segment_content) + + # Update the running CRC64 for the segment and message + if StructuredMessageProperties.CRC64 in self.flags: + self._segment_crc64 = calculate_crc64( + segment_content, self._segment_crc64 + ) + self._message_crc64 = calculate_crc64( + segment_content, self._message_crc64 + ) + + self._segment_content_offset += read_size + self._message_offset += read_size + count += read_size + + if self._end_of_segment_content: + await self._read_segment_footer() + # If we are on the last segment, also read the message footer + if self._segment_number == self.num_segments: + await self._read_message_footer() + + # One final check to ensure if we think we've reached the end of the stream + # that the current segment number matches the total. + if ( + self._message_offset == self.message_length + and self._segment_number != self.num_segments + ): + raise ValueError("Invalid structured message data detected.") + + return content.getvalue() + + async def _read_from_inner(self, size: int) -> bytes: + while len(self._buffer) < size: + try: + chunk = await self._inner_iterator.__anext__() + except StopAsyncIteration: + break + self._buffer += chunk + + if len(self._buffer) < size: + raise ValueError( + "Invalid structured message data detected. Stream content incomplete." 
+ ) + + data = self._buffer[:size] + self._buffer = self._buffer[size:] + return data + + async def _read_message_header(self) -> None: + header_data = await self._read_from_inner( + StructuredMessageConstants.V1_HEADER_LENGTH + ) + self.message_version, self.flags, self.num_segments = parse_message_header( + header_data, self.message_length + ) + self._message_offset += StructuredMessageConstants.V1_HEADER_LENGTH + + async def _read_message_footer(self) -> None: + # Sanity check: There should only be self._message_footer_length (could be 0) bytes left to consume. + # If not, it is likely the message header contained incorrect info. + if self.message_length - self._message_offset != self._message_footer_length: + raise ValueError("Invalid structured message data detected.") + + if StructuredMessageProperties.CRC64 in self.flags: + message_crc = await self._read_from_inner( + StructuredMessageConstants.CRC64_LENGTH + ) + + if self._message_crc64 != int.from_bytes(message_crc, "little"): + raise ValueError( + "CRC64 mismatch detected in message trailer. " + "All data read should be considered invalid." + ) + + self._message_offset += self._message_footer_length + + async def _read_segment_header(self) -> None: + header_data = await self._read_from_inner(self._segment_header_length) + self._segment_number, self._segment_content_length = parse_segment_header( + header_data, self._segment_number + 1 + ) + self._message_offset += self._segment_header_length + + self._segment_content_offset = 0 + self._segment_crc64 = 0 + + async def _read_segment_footer(self) -> None: + if StructuredMessageProperties.CRC64 in self.flags: + segment_crc = await self._read_from_inner( + StructuredMessageConstants.CRC64_LENGTH + ) + + if self._segment_crc64 != int.from_bytes(segment_crc, "little"): + raise ValueError( + f"CRC64 mismatch detected in segment {self._segment_number}. " + f"All data read should be considered invalid." 
+ ) + + self._message_offset += self._segment_footer_length diff --git a/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_shared/validation.py b/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_shared/validation.py new file mode 100644 index 000000000000..5370d9dd669c --- /dev/null +++ b/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_shared/validation.py @@ -0,0 +1,105 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- +# pylint: disable=c-extension-no-member + +import hashlib +from enum import Enum +from io import SEEK_SET +from typing import IO, Literal, Optional, Union, cast + +from azure.core import CaseInsensitiveEnumMeta + +CRC64_LENGTH = 8 +CV_TYPE_ERROR_MSG = "Data should be bytes or seekable IO[bytes] for content validation." + + +class ChecksumAlgorithm(str, Enum, metaclass=CaseInsensitiveEnumMeta): + AUTO = "auto" + MD5 = "md5" + CRC64 = "crc64" + + @classmethod + def list(cls): + return list(map(lambda c: c.value, cls)) + + +def _verify_extensions(module: str) -> None: + try: + import azure.storage.extensions # pylint: disable=unused-import + except ImportError as exc: + raise ValueError( + f"The use of {module} requires the azure-storage-extensions package to be installed. " + f"Please install this package and try again." 
+ ) from exc + + +def parse_validation_option( + validate_content: Optional[Union[bool, Literal["auto", "crc64", "md5"]]], +) -> Optional[Union[bool, Literal["auto", "crc64", "md5"]]]: + if validate_content is None: + return None + + # Legacy support for bool + if isinstance(validate_content, bool): + return validate_content + + if validate_content not in (ChecksumAlgorithm.list()): + raise ValueError("Invalid value for `validate_content` specified.") + + # Resolve auto + if validate_content == ChecksumAlgorithm.AUTO: + validate_content = ChecksumAlgorithm.CRC64.value + + if validate_content == ChecksumAlgorithm.CRC64: + _verify_extensions("crc64") + + return validate_content + + +def is_md5_validation( + validate_content: Optional[Union[bool, Literal["md5", "crc64"]]], +) -> bool: + if validate_content is None: + return False + if isinstance(validate_content, bool): + return validate_content + return validate_content == ChecksumAlgorithm.MD5 + + +def calculate_content_md5(data: Union[bytes, IO[bytes]]) -> bytes: + md5 = hashlib.md5() # nosec + if isinstance(data, bytes): + md5.update(data) + elif hasattr(data, "read"): + pos = 0 + try: + pos = data.tell() + except: # pylint: disable=bare-except + pass + for chunk in iter(lambda: data.read(4096), b""): + md5.update(chunk) + try: + data.seek(pos, SEEK_SET) + except (AttributeError, IOError) as exc: + raise ValueError(CV_TYPE_ERROR_MSG) from exc + else: + raise ValueError(CV_TYPE_ERROR_MSG) + + return md5.digest() + + +def calculate_crc64(data: bytes, initial_crc: int) -> int: + # Locally import to avoid error if not installed. + from azure.storage.extensions import crc64 + + return cast(int, crc64.compute(data, initial_crc)) + + +def calculate_crc64_bytes(data: bytes) -> bytes: + # Locally import to avoid error if not installed. 
+ from azure.storage.extensions import crc64 + + return cast(bytes, crc64.compute(data, 0).to_bytes(CRC64_LENGTH, "little")) diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/policies.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/policies.py index b343373dfce5..8dc7fdf4ce72 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/policies.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/policies.py @@ -12,7 +12,7 @@ import uuid from io import SEEK_SET, UnsupportedOperation from time import time -from typing import Any, Dict, Optional, TYPE_CHECKING +from typing import Any, Dict, Optional, TYPE_CHECKING, Union from urllib.parse import ( parse_qsl, urlencode, @@ -32,8 +32,20 @@ ) from .authentication import AzureSigningError, StorageHttpChallenge -from .constants import DEFAULT_OAUTH_SCOPE +from .constants import DEFAULT_OAUTH_SCOPE, DATA_BLOCK_SIZE from .models import LocationMode, StorageErrorCode +from .streams import ( + StructuredMessageDecoder, + StructuredMessageEncodeStream, + StructuredMessageProperties, +) +from .validation import ( + CV_TYPE_ERROR_MSG, + calculate_content_md5, + calculate_crc64_bytes, + is_md5_validation, + ChecksumAlgorithm, +) if TYPE_CHECKING: from azure.core.credentials import TokenCredential @@ -44,9 +56,15 @@ _LOGGER = logging.getLogger(__name__) +CONTENT_LENGTH_HEADER = "Content-Length" +MD5_HEADER = "Content-MD5" +CRC64_HEADER = "x-ms-content-crc64" +SM_HEADER = "x-ms-structured-body" +SM_HEADER_V1_CRC64 = "XSM/1.0; properties=crc64" +SM_LENGTH_HEADER = "x-ms-structured-content-length" -def encode_base64(data): +def encode_base64(data: Union[bytes, str]) -> str: if isinstance(data, str): data = data.encode("utf-8") encoded = base64.b64encode(data) @@ -55,7 +73,12 @@ def encode_base64(data): # Are we out of retries? 
def is_exhausted(settings): - retry_counts = (settings["total"], settings["connect"], settings["read"], settings["status"]) + retry_counts = ( + settings["total"], + settings["connect"], + settings["read"], + settings["status"], + ) retry_counts = list(filter(None, retry_counts)) if not retry_counts: return False @@ -64,7 +87,9 @@ def is_exhausted(settings): def retry_hook(settings, **kwargs): if settings["hook"]: - settings["hook"](retry_count=settings["count"] - 1, location_mode=settings["mode"], **kwargs) + settings["hook"]( + retry_count=settings["count"] - 1, location_mode=settings["mode"], **kwargs + ) # Is this method/status code retryable? (Based on allowlists and control @@ -84,7 +109,9 @@ def is_retry(response, mode): # pylint: disable=too-many-return-statements # Response code 408 is a timeout and should be retried. return True if status >= 400: - error_code = response.http_response.headers.get("x-ms-copy-source-error-code") + error_code = response.http_response.headers.get( + "x-ms-copy-source-error-code" + ) if error_code in [ StorageErrorCode.OPERATION_TIMED_OUT, StorageErrorCode.INTERNAL_ERROR, @@ -101,12 +128,16 @@ def is_retry(response, mode): # pylint: disable=too-many-return-statements return False -def is_checksum_retry(response): - # retry if invalid content md5 - if response.context.get("validate_content", False) and response.http_response.headers.get("content-md5"): - computed_md5 = response.http_request.headers.get("content-md5", None) or encode_base64( - StorageContentValidation.get_content_md5(response.http_response.body()) - ) +def is_checksum_retry(response) -> bool: + validate_content = response.context.get("validate_content", False) + if not validate_content: + return False + + # Legacy code - evaluate retry only on validate_content=True + if validate_content is True and response.http_response.headers.get("content-md5"): + computed_md5 = response.http_request.headers.get( + "content-md5", None + ) or 
encode_base64(calculate_content_md5(response.http_response.body())) if response.http_response.headers["content-md5"] != computed_md5: return True return False @@ -141,7 +172,9 @@ def on_request(self, request: "PipelineRequest") -> None: request.http_request.headers["x-ms-date"] = current_time custom_id = request.context.options.pop("client_request_id", None) - request.http_request.headers["x-ms-client-request-id"] = custom_id or str(uuid.uuid1()) + request.http_request.headers["x-ms-client-request-id"] = custom_id or str( + uuid.uuid1() + ) # def on_response(self, request, response): # # raise exception if the echoed client request id from the service is not identical to the one we sent @@ -181,7 +214,9 @@ def on_request(self, request: "PipelineRequest") -> None: # Lock retries to the specific location request.context.options["retry_to_secondary"] = False if use_location not in self.hosts: - raise ValueError(f"Attempting to use undefined host location {use_location}") + raise ValueError( + f"Attempting to use undefined host location {use_location}" + ) if use_location != location_mode: # Update request URL to use the specified location updated = parsed_url._replace(netloc=self.hosts[use_location]) @@ -199,7 +234,9 @@ class StorageLoggingPolicy(NetworkTraceLoggingPolicy): def __init__(self, logging_enable: bool = False, **kwargs) -> None: self.logging_body = kwargs.pop("logging_body", False) - super(StorageLoggingPolicy, self).__init__(logging_enable=logging_enable, **kwargs) + super(StorageLoggingPolicy, self).__init__( + logging_enable=logging_enable, **kwargs + ) def on_request(self, request: "PipelineRequest") -> None: http_request = request.http_request @@ -228,7 +265,16 @@ def on_request(self, request: "PipelineRequest") -> None: parsed_qs["sig"] = "*****" # the SAS needs to be put back together - value = urlunparse((scheme, netloc, path, params, urlencode(parsed_qs), fragment)) + value = urlunparse( + ( + scheme, + netloc, + path, + params, + 
urlencode(parsed_qs), + fragment, + ) + ) _LOGGER.debug(" %r: %r", header, value) _LOGGER.debug("Request body:") @@ -241,7 +287,9 @@ def on_request(self, request: "PipelineRequest") -> None: except Exception as err: # pylint: disable=broad-except _LOGGER.debug("Failed to log request: %r", err) - def on_response(self, request: "PipelineRequest", response: "PipelineResponse") -> None: + def on_response( + self, request: "PipelineRequest", response: "PipelineResponse" + ) -> None: if response.context.pop("logging_enable", self.enable_http_logger): if not _LOGGER.isEnabledFor(logging.DEBUG): return @@ -256,7 +304,9 @@ def on_response(self, request: "PipelineRequest", response: "PipelineResponse") _LOGGER.debug("Response content:") pattern = re.compile(r'attachment; ?filename=["\w.]+', re.IGNORECASE) header = response.http_response.headers.get("content-disposition") - resp_content_type = response.http_response.headers.get("content-type", "") + resp_content_type = response.http_response.headers.get( + "content-type", "" + ) if header and pattern.match(header): filename = header.partition("=")[2] @@ -285,7 +335,9 @@ def __init__(self, **kwargs): super(StorageRequestHook, self).__init__() def on_request(self, request: "PipelineRequest") -> None: - request_callback = request.context.options.pop("raw_request_hook", self._request_callback) + request_callback = request.context.options.pop( + "raw_request_hook", self._request_callback + ) if request_callback: request_callback(request) @@ -303,36 +355,50 @@ def send(self, request: "PipelineRequest") -> "PipelineResponse": data_stream_total = request.context.options.pop("data_stream_total", None) download_stream_current = request.context.get("download_stream_current") if download_stream_current is None: - download_stream_current = request.context.options.pop("download_stream_current", None) + download_stream_current = request.context.options.pop( + "download_stream_current", None + ) upload_stream_current = 
request.context.get("upload_stream_current") if upload_stream_current is None: - upload_stream_current = request.context.options.pop("upload_stream_current", None) + upload_stream_current = request.context.options.pop( + "upload_stream_current", None + ) - response_callback = request.context.get("response_callback") or request.context.options.pop( - "raw_response_hook", self._response_callback - ) + response_callback = request.context.get( + "response_callback" + ) or request.context.options.pop("raw_response_hook", self._response_callback) response = self.next.send(request) - will_retry = is_retry(response, request.context.options.get("mode")) or is_checksum_retry(response) + will_retry = is_retry( + response, request.context.options.get("mode") + ) or is_checksum_retry(response) # Auth error could come from Bearer challenge, in which case this request will be made again is_auth_error = response.http_response.status_code == 401 should_update_counts = not (will_retry or is_auth_error) if should_update_counts and download_stream_current is not None: - download_stream_current += int(response.http_response.headers.get("Content-Length", 0)) + download_stream_current += int( + response.http_response.headers.get("Content-Length", 0) + ) if data_stream_total is None: content_range = response.http_response.headers.get("Content-Range") if content_range: - data_stream_total = int(content_range.split(" ", 1)[1].split("/", 1)[1]) + data_stream_total = int( + content_range.split(" ", 1)[1].split("/", 1)[1] + ) else: data_stream_total = download_stream_current elif should_update_counts and upload_stream_current is not None: - upload_stream_current += int(response.http_request.headers.get("Content-Length", 0)) + upload_stream_current += int( + response.http_request.headers.get("Content-Length", 0) + ) for pipeline_obj in [request, response]: if hasattr(pipeline_obj, "context"): pipeline_obj.context["data_stream_total"] = data_stream_total - 
def _prepare_content_validation(request: "PipelineRequest") -> None:
    """Shared request-side logic for content validation.

    Pops 'validate_content' from the request options and, when enabled, sets up
    headers/streams for MD5 or CRC64 validation. The chosen validation mode is
    stored in the request context so the response-side check can find it.
    """
    validate_content = request.context.options.pop("validate_content", False)
    if not validate_content:
        return

    # Download: ask the service for a structured (CRC64-framed) response body.
    if request.http_request.method == "GET":
        if validate_content == ChecksumAlgorithm.CRC64:
            request.http_request.headers[SM_HEADER] = SM_HEADER_V1_CRC64

    # Upload: checksum the outgoing body.
    else:
        # Since HTTP does not differentiate between no content and empty content,
        # we have to perform a None check.
        data = request.http_request.data or b""
        if is_md5_validation(validate_content):
            computed_md5 = encode_base64(calculate_content_md5(data))
            request.http_request.headers[MD5_HEADER] = computed_md5
            request.context["validate_content_md5"] = computed_md5

        elif validate_content == ChecksumAlgorithm.CRC64:
            if isinstance(data, bytes):
                request.http_request.headers[CRC64_HEADER] = encode_base64(
                    calculate_crc64_bytes(data)
                )
            elif hasattr(data, "read"):
                content_length = int(
                    request.http_request.headers.get(CONTENT_LENGTH_HEADER)
                )
                # Wrap data in a structured message stream and adjust the HTTP
                # request to account for the added message framing.
                sm_stream = StructuredMessageEncodeStream(
                    data, content_length, StructuredMessageProperties.CRC64
                )
                request.http_request.data = sm_stream
                request.http_request.headers[CONTENT_LENGTH_HEADER] = str(
                    len(sm_stream)
                )
                request.http_request.headers[SM_LENGTH_HEADER] = str(content_length)
                request.http_request.headers[SM_HEADER] = SM_HEADER_V1_CRC64
            else:
                raise ValueError(CV_TYPE_ERROR_MSG)

    request.context["validate_content"] = validate_content


def _validate_content_response(
    request: "PipelineRequest",
    response: "PipelineResponse",
    decoder_cls: type,
) -> None:
    """Shared response-side logic for content validation.

    Checks MD5 or CRC64 validation on the response. For CRC64 GET responses, patches
    ``stream_download`` to wrap the iterator in the given *decoder_cls*.

    :raises AzureError: On a checksum mismatch or a missing/unexpected
        structured message header.
    """
    validate_content = response.context.get("validate_content", False)
    if not validate_content:
        return

    if is_md5_validation(validate_content) and response.http_response.headers.get(
        "content-md5"
    ):
        computed_md5 = request.context.get("validate_content_md5") or encode_base64(
            calculate_content_md5(response.http_response.body())
        )
        if response.http_response.headers["content-md5"] != computed_md5:
            raise AzureError(
                (
                    f"MD5 mismatch. Expected value is '{response.http_response.headers['content-md5']}', "
                    f"computed value is '{computed_md5}'."
                ),
                response=response.http_response,
            )

    elif validate_content == ChecksumAlgorithm.CRC64:
        # For upload and download verify structured message header present in response if provided in request.
        sm_request = request.http_request.headers.get(SM_HEADER)
        sm_response = response.http_response.headers.get(SM_HEADER)
        if sm_request != sm_response:
            # FIX: a stray trailing comma previously made the message a 1-tuple,
            # so the raised error rendered as a tuple repr instead of the string.
            raise AzureError(
                (
                    f"Expected structured message header in response does not match request. "
                    f"Request: {sm_request}, Response: {sm_response}"
                ),
                response=response.http_response,
            )

        if response.http_request.method == "GET":
            # Raises exception if missing (KeyError on absent Content-Length).
            content_length = int(response.http_response.headers[CONTENT_LENGTH_HEADER])

            # Patch response to return the download iterator wrapped in the
            # structured message decoder, which strips framing and verifies CRC64.
            original_stream_download = response.http_response.stream_download

            def wrapped_stream_download(*args, **kwargs):
                iterator = original_stream_download(*args, **kwargs)
                decoder = decoder_cls(
                    iterator, content_length, block_size=DATA_BLOCK_SIZE
                )
                decoder.request = iterator.request  # type: ignore
                decoder.response = iterator.response  # type: ignore
                return decoder

            response.http_response.stream_download = wrapped_stream_download


class StorageContentValidation(SansIOHTTPPolicy):
    """A pipeline policy that performs content validation on uploads and downloads when enabled by the user.
    This is enabled by setting the "validate_content" key in the request context. When enabled, this policy will
    calculate and verify content checksums for uploads and downloads, and raise an exception if a mismatch is detected.
    """

    def __init__(self, **kwargs: Any) -> None:  # pylint: disable=unused-argument
        super().__init__()

    def on_request(self, request: "PipelineRequest") -> None:
        _prepare_content_validation(request)

    def on_response(
        self, request: "PipelineRequest", response: "PipelineResponse"
    ) -> None:
        _validate_content_response(request, response, StructuredMessageDecoder)
@@ -535,7 +676,9 @@ def increment( # status_forcelist and a the given method is in the allowlist if response: settings["status"] -= 1 - settings["history"].append(RequestHistory(request, http_response=response)) + settings["history"].append( + RequestHistory(request, http_response=response) + ) if not is_exhausted(settings): if request.method not in ["PUT"] and settings["retry_secondary"]: @@ -570,13 +713,20 @@ def send(self, request): while retries_remaining: try: response = self.next.send(request) - if is_retry(response, retry_settings["mode"]) or is_checksum_retry(response): + if is_retry(response, retry_settings["mode"]) or is_checksum_retry( + response + ): retries_remaining = self.increment( - retry_settings, request=request.http_request, response=response.http_response + retry_settings, + request=request.http_request, + response=response.http_response, ) if retries_remaining: retry_hook( - retry_settings, request=request.http_request, response=response.http_response, error=None + retry_settings, + request=request.http_request, + response=response.http_response, + error=None, ) self.sleep(retry_settings, request.context.transport) continue @@ -584,9 +734,16 @@ def send(self, request): except AzureError as err: if isinstance(err, AzureSigningError): raise - retries_remaining = self.increment(retry_settings, request=request.http_request, error=err) + retries_remaining = self.increment( + retry_settings, request=request.http_request, error=err + ) if retries_remaining: - retry_hook(retry_settings, request=request.http_request, response=None, error=err) + retry_hook( + retry_settings, + request=request.http_request, + response=None, + error=err, + ) self.sleep(retry_settings, request.context.transport) continue raise err @@ -639,7 +796,9 @@ def __init__( self.initial_backoff = initial_backoff self.increment_base = increment_base self.random_jitter_range = random_jitter_range - super(ExponentialRetry, self).__init__(retry_total=retry_total, 
retry_to_secondary=retry_to_secondary, **kwargs) + super(ExponentialRetry, self).__init__( + retry_total=retry_total, retry_to_secondary=retry_to_secondary, **kwargs + ) def get_backoff_time(self, settings: Dict[str, Any]) -> float: """ @@ -652,8 +811,14 @@ def get_backoff_time(self, settings: Dict[str, Any]) -> float: :rtype: float """ random_generator = random.Random() - backoff = self.initial_backoff + (0 if settings["count"] == 0 else pow(self.increment_base, settings["count"])) - random_range_start = backoff - self.random_jitter_range if backoff > self.random_jitter_range else 0 + backoff = self.initial_backoff + ( + 0 if settings["count"] == 0 else pow(self.increment_base, settings["count"]) + ) + random_range_start = ( + backoff - self.random_jitter_range + if backoff > self.random_jitter_range + else 0 + ) random_range_end = backoff + self.random_jitter_range return random_generator.uniform(random_range_start, random_range_end) @@ -691,7 +856,9 @@ def __init__( """ self.backoff = backoff self.random_jitter_range = random_jitter_range - super(LinearRetry, self).__init__(retry_total=retry_total, retry_to_secondary=retry_to_secondary, **kwargs) + super(LinearRetry, self).__init__( + retry_total=retry_total, retry_to_secondary=retry_to_secondary, **kwargs + ) def get_backoff_time(self, settings: Dict[str, Any]) -> float: """ @@ -706,7 +873,11 @@ def get_backoff_time(self, settings: Dict[str, Any]) -> float: random_generator = random.Random() # the backoff interval normally does not change, however there is the possibility # that it was modified by accessing the property directly after initializing the object - random_range_start = self.backoff - self.random_jitter_range if self.backoff > self.random_jitter_range else 0 + random_range_start = ( + self.backoff - self.random_jitter_range + if self.backoff > self.random_jitter_range + else 0 + ) random_range_end = self.backoff + self.random_jitter_range return random_generator.uniform(random_range_start, 
random_range_end) @@ -714,10 +885,16 @@ def get_backoff_time(self, settings: Dict[str, Any]) -> float: class StorageBearerTokenCredentialPolicy(BearerTokenCredentialPolicy): """Custom Bearer token credential policy for following Storage Bearer challenges""" - def __init__(self, credential: "TokenCredential", audience: str, **kwargs: Any) -> None: - super(StorageBearerTokenCredentialPolicy, self).__init__(credential, audience, **kwargs) + def __init__( + self, credential: "TokenCredential", audience: str, **kwargs: Any + ) -> None: + super(StorageBearerTokenCredentialPolicy, self).__init__( + credential, audience, **kwargs + ) - def on_challenge(self, request: "PipelineRequest", response: "PipelineResponse") -> bool: + def on_challenge( + self, request: "PipelineRequest", response: "PipelineResponse" + ) -> bool: """Handle the challenge from the service and authorize the request. :param request: The request object. diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/policies_async.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/policies_async.py index 4cb32f23248b..14ce070e47ff 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/policies_async.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/policies_async.py @@ -11,11 +11,25 @@ from typing import Any, Dict, TYPE_CHECKING from azure.core.exceptions import AzureError, StreamClosedError, StreamConsumedError -from azure.core.pipeline.policies import AsyncBearerTokenCredentialPolicy, AsyncHTTPPolicy +from azure.core.pipeline.policies import ( + AsyncBearerTokenCredentialPolicy, + AsyncHTTPPolicy, +) from .authentication import AzureSigningError, StorageHttpChallenge from .constants import DEFAULT_OAUTH_SCOPE -from .policies import encode_base64, is_retry, StorageContentValidation, StorageRetryPolicy +from .policies import ( + _prepare_content_validation, + _validate_content_response, + encode_base64, + is_retry, + StorageRetryPolicy, +) +from 
async def retry_hook(settings, **kwargs):
    """Invoke the user-supplied retry hook, awaiting it when it is a coroutine function.

    FIX: previously this used ``asyncio.iscoroutine`` on the hook itself, which is
    always False for a coroutine *function* — so async hooks were called but their
    coroutine result was never awaited. ``iscoroutinefunction`` is the intended check.
    """
    hook = settings["hook"]
    if not hook:
        return
    if asyncio.iscoroutinefunction(hook):
        await hook(
            retry_count=settings["count"] - 1,
            location_mode=settings["mode"],
            **kwargs
        )
    else:
        hook(
            retry_count=settings["count"] - 1,
            location_mode=settings["mode"],
            **kwargs
        )


async def is_checksum_retry(response):
    """Return True if the response fails legacy (bool) MD5 validation and should be retried."""
    validate_content = response.context.get("validate_content", False)
    if not validate_content:
        return False

    # Legacy code - evaluate retry only on validate_content=True
    if validate_content is True and response.http_response.headers.get("content-md5"):
        if hasattr(response.http_response, "load_body"):
            try:
                await response.http_response.load_body()  # Load the body in memory and close the socket
            except (StreamClosedError, StreamConsumedError):
                pass
        computed_md5 = response.http_request.headers.get(
            "content-md5", None
        ) or encode_base64(calculate_content_md5(response.http_response.body()))
        if response.http_response.headers["content-md5"] != computed_md5:
            return True
    return False


class AsyncContentValidationPolicy(AsyncHTTPPolicy):
    """A pipeline policy that performs content validation on uploads and downloads when enabled by the user.
    This is enabled by setting the "validate_content" key in the request context. When enabled, this policy will
    calculate and verify content checksums for uploads and downloads, and raise an exception if a mismatch is detected.
    """

    def __init__(self, **kwargs: Any) -> None:  # pylint: disable=unused-argument
        super().__init__()

    async def send(self, request: "PipelineRequest") -> "PipelineResponse":
        _prepare_content_validation(request)

        response = await self.next.send(request)

        validate_content = response.context.get("validate_content", False)
        if validate_content and is_md5_validation(validate_content):
            # MD5 validation needs the full body in memory before checking.
            if hasattr(response.http_response, "load_body"):
                try:
                    await response.http_response.load_body()
                except (StreamClosedError, StreamConsumedError):
                    pass

        _validate_content_response(request, response, AsyncStructuredMessageDecoder)

        return response
self._response_callback) response = await self.next.send(request) - will_retry = is_retry(response, request.context.options.get("mode")) or await is_checksum_retry(response) + will_retry = is_retry( + response, request.context.options.get("mode") + ) or await is_checksum_retry(response) # Auth error could come from Bearer challenge, in which case this request will be made again is_auth_error = response.http_response.status_code == 401 should_update_counts = not (will_retry or is_auth_error) if should_update_counts and download_stream_current is not None: - download_stream_current += int(response.http_response.headers.get("Content-Length", 0)) + download_stream_current += int( + response.http_response.headers.get("Content-Length", 0) + ) if data_stream_total is None: content_range = response.http_response.headers.get("Content-Range") if content_range: - data_stream_total = int(content_range.split(" ", 1)[1].split("/", 1)[1]) + data_stream_total = int( + content_range.split(" ", 1)[1].split("/", 1)[1] + ) else: data_stream_total = download_stream_current elif should_update_counts and upload_stream_current is not None: - upload_stream_current += int(response.http_request.headers.get("Content-Length", 0)) + upload_stream_current += int( + response.http_request.headers.get("Content-Length", 0) + ) for pipeline_obj in [request, response]: if hasattr(pipeline_obj, "context"): pipeline_obj.context["data_stream_total"] = data_stream_total - pipeline_obj.context["download_stream_current"] = download_stream_current + pipeline_obj.context["download_stream_current"] = ( + download_stream_current + ) pipeline_obj.context["upload_stream_current"] = upload_stream_current if response_callback: if asyncio.iscoroutine(response_callback): @@ -123,13 +190,20 @@ async def send(self, request): while retries_remaining: try: response = await self.next.send(request) - if is_retry(response, retry_settings["mode"]) or await is_checksum_retry(response): + if is_retry( + response, 
retry_settings["mode"] + ) or await is_checksum_retry(response): retries_remaining = self.increment( - retry_settings, request=request.http_request, response=response.http_response + retry_settings, + request=request.http_request, + response=response.http_response, ) if retries_remaining: await retry_hook( - retry_settings, request=request.http_request, response=response.http_response, error=None + retry_settings, + request=request.http_request, + response=response.http_response, + error=None, ) await self.sleep(retry_settings, request.context.transport) continue @@ -137,9 +211,16 @@ async def send(self, request): except AzureError as err: if isinstance(err, AzureSigningError): raise - retries_remaining = self.increment(retry_settings, request=request.http_request, error=err) + retries_remaining = self.increment( + retry_settings, request=request.http_request, error=err + ) if retries_remaining: - await retry_hook(retry_settings, request=request.http_request, response=None, error=err) + await retry_hook( + retry_settings, + request=request.http_request, + response=None, + error=err, + ) await self.sleep(retry_settings, request.context.transport) continue raise err @@ -194,7 +275,9 @@ def __init__( self.initial_backoff = initial_backoff self.increment_base = increment_base self.random_jitter_range = random_jitter_range - super(ExponentialRetry, self).__init__(retry_total=retry_total, retry_to_secondary=retry_to_secondary, **kwargs) + super(ExponentialRetry, self).__init__( + retry_total=retry_total, retry_to_secondary=retry_to_secondary, **kwargs + ) def get_backoff_time(self, settings: Dict[str, Any]) -> float: """ @@ -207,8 +290,14 @@ def get_backoff_time(self, settings: Dict[str, Any]) -> float: :rtype: int or None """ random_generator = random.Random() - backoff = self.initial_backoff + (0 if settings["count"] == 0 else pow(self.increment_base, settings["count"])) - random_range_start = backoff - self.random_jitter_range if backoff > self.random_jitter_range 
else 0 + backoff = self.initial_backoff + ( + 0 if settings["count"] == 0 else pow(self.increment_base, settings["count"]) + ) + random_range_start = ( + backoff - self.random_jitter_range + if backoff > self.random_jitter_range + else 0 + ) random_range_end = backoff + self.random_jitter_range return random_generator.uniform(random_range_start, random_range_end) @@ -246,7 +335,9 @@ def __init__( """ self.backoff = backoff self.random_jitter_range = random_jitter_range - super(LinearRetry, self).__init__(retry_total=retry_total, retry_to_secondary=retry_to_secondary, **kwargs) + super(LinearRetry, self).__init__( + retry_total=retry_total, retry_to_secondary=retry_to_secondary, **kwargs + ) def get_backoff_time(self, settings: Dict[str, Any]) -> float: """ @@ -261,7 +352,11 @@ def get_backoff_time(self, settings: Dict[str, Any]) -> float: random_generator = random.Random() # the backoff interval normally does not change, however there is the possibility # that it was modified by accessing the property directly after initializing the object - random_range_start = self.backoff - self.random_jitter_range if self.backoff > self.random_jitter_range else 0 + random_range_start = ( + self.backoff - self.random_jitter_range + if self.backoff > self.random_jitter_range + else 0 + ) random_range_end = self.backoff + self.random_jitter_range return random_generator.uniform(random_range_start, random_range_end) @@ -269,10 +364,16 @@ def get_backoff_time(self, settings: Dict[str, Any]) -> float: class AsyncStorageBearerTokenCredentialPolicy(AsyncBearerTokenCredentialPolicy): """Custom Bearer token credential policy for following Storage Bearer challenges""" - def __init__(self, credential: "AsyncTokenCredential", audience: str, **kwargs: Any) -> None: - super(AsyncStorageBearerTokenCredentialPolicy, self).__init__(credential, audience, **kwargs) + def __init__( + self, credential: "AsyncTokenCredential", audience: str, **kwargs: Any + ) -> None: + 
DEFAULT_MESSAGE_VERSION = 1
DEFAULT_SEGMENT_SIZE = 4 * 1024 * 1024


class StructuredMessageConstants:
    """Byte lengths used by the V1 structured message framing format."""

    V1_HEADER_LENGTH = 13
    V1_SEGMENT_HEADER_LENGTH = 10
    CRC64_LENGTH = 8


class StructuredMessageProperties(IntFlag):
    """Optional features carried by a structured message."""

    NONE = 0
    CRC64 = auto()


class SMRegion(Enum):
    """The region of a structured message a stream is currently positioned in."""

    MESSAGE_HEADER = 1
    SEGMENT_HEADER = 2
    SEGMENT_CONTENT = 3
    SEGMENT_FOOTER = 4
    MESSAGE_FOOTER = 5


def generate_message_header(
    version: int, size: int, flags: StructuredMessageProperties, num_segments: int
) -> bytes:
    """Build the 13-byte V1 message header: version(1) | size(8) | flags(2) | segments(2), little-endian."""
    fields = (
        version.to_bytes(1, "little"),
        size.to_bytes(8, "little"),
        flags.to_bytes(2, "little"),
        num_segments.to_bytes(2, "little"),
    )
    return b"".join(fields)


def generate_segment_header(number: int, size: int) -> bytes:
    """Build the 10-byte segment header: number(2) | size(8), little-endian."""
    return b"".join((number.to_bytes(2, "little"), size.to_bytes(8, "little")))


def parse_message_header(
    data: bytes, expected_message_length: int
) -> tuple[int, StructuredMessageProperties, int]:
    """Parse and validate a V1 message header.

    :returns: A ``(version, flags, num_segments)`` tuple.
    :raises ValueError: If the version is unsupported or the encoded message
        length does not match *expected_message_length*.
    """
    version = data[0]
    if version != 1:
        raise ValueError(f"The structured message version is not supported: {version}")

    message_length = int.from_bytes(data[1:9], "little")
    if message_length != expected_message_length:
        raise ValueError(
            f"Structured message length {message_length} "
            f"did not match content length {expected_message_length}"
        )

    flags = StructuredMessageProperties(int.from_bytes(data[9:11], "little"))
    num_segments = int.from_bytes(data[11:13], "little")
    return version, flags, num_segments


def parse_segment_header(data: bytes, expected_segment_number: int) -> tuple[int, int]:
    """Parse a segment header, validating the segment number.

    :returns: A ``(segment_number, segment_content_length)`` tuple.
    :raises ValueError: If the segment number does not match the expected one.
    """
    segment_number = int.from_bytes(data[0:2], "little")
    if segment_number != expected_segment_number:
        raise ValueError(
            f"Structured message segment number invalid or out of order {segment_number}"
        )
    segment_content_length = int.from_bytes(data[2:10], "little")
    return segment_number, segment_content_length
StructuredMessageProperties, + *, + segment_size: int = DEFAULT_SEGMENT_SIZE, + ) -> None: + if segment_size < 1: + raise ValueError("Segment size must be greater than 0.") + + self.message_version = DEFAULT_MESSAGE_VERSION + self.content_length = content_length + self.flags = flags + + self._inner_stream = inner_stream + self._segment_size = segment_size + self._num_segments = math.ceil(self.content_length / self._segment_size) or 1 + + self.message_length = self._calculate_message_length() + + self._content_offset = 0 + self._current_segment_number = 0 # Will be incremented before first segment + self._current_region = SMRegion.MESSAGE_HEADER + self._current_region_length = self._message_header_length + self._current_region_offset = 0 + + self._checksum_offset = 0 + self._message_crc64 = 0 + self._segment_crc64s = {} + + # Attempt to get starting position of inner stream. If we can't, this stream will not be seekable + try: + self._initial_content_position = self._inner_stream.tell() + except (AttributeError, UnsupportedOperation, OSError): + self._initial_content_position = None + super().__init__() + + @property + def _message_header_length(self) -> int: + return StructuredMessageConstants.V1_HEADER_LENGTH + + @property + def _segment_header_length(self) -> int: + return StructuredMessageConstants.V1_SEGMENT_HEADER_LENGTH + + @property + def _segment_footer_length(self) -> int: + return ( + StructuredMessageConstants.CRC64_LENGTH + if StructuredMessageProperties.CRC64 in self.flags + else 0 + ) + + @property + def _message_footer_length(self) -> int: + return ( + StructuredMessageConstants.CRC64_LENGTH + if StructuredMessageProperties.CRC64 in self.flags + else 0 + ) + + def _update_current_region_length(self) -> None: + if self._current_region == SMRegion.MESSAGE_HEADER: + self._current_region_length = self._message_header_length + elif self._current_region == SMRegion.SEGMENT_HEADER: + self._current_region_length = self._segment_header_length + elif 
    # NOTE(review): this chunk was recovered from a whitespace-mangled patch; the
    # enclosing class header (presumably StructuredMessageEncoder) and the opening
    # branches of _update_current_region_length are above this chunk — verify
    # against upstream before applying. Indentation below is reconstructed.

    def __len__(self):
        # Length of the full encoded message (headers + content + footers).
        return self.message_length

    def close(self) -> None:
        """Close the wrapped content stream, then this wrapper."""
        self._inner_stream.close()
        super().close()

    def readable(self) -> bool:
        return True

    def seekable(self) -> bool:
        try:
            # Only seekable if the inner stream is and we could get its initial position
            return (
                self._inner_stream.seekable()
                and self._initial_content_position is not None
            )
        except (AttributeError, UnsupportedOperation, OSError):
            return False

    def tell(self) -> int:
        """Return the current position within the *encoded* message.

        The position is derived from which region the cursor is in, how much
        content has been consumed, and how many segment header/footer pairs
        precede the cursor.
        """
        if self._current_region == SMRegion.MESSAGE_HEADER:
            return self._current_region_offset
        if self._current_region == SMRegion.SEGMENT_HEADER:
            return (
                self._message_header_length
                + self._content_offset
                + (self._current_segment_number - 1)
                * (self._segment_header_length + self._segment_footer_length)
                + self._current_region_offset
            )
        if self._current_region == SMRegion.SEGMENT_CONTENT:
            # _content_offset already tracks progress within segment content,
            # so only the current segment's header length is added here.
            return (
                self._message_header_length
                + self._content_offset
                + (self._current_segment_number - 1)
                * (self._segment_header_length + self._segment_footer_length)
                + self._segment_header_length
            )
        if self._current_region == SMRegion.SEGMENT_FOOTER:
            return (
                self._message_header_length
                + self._content_offset
                + (self._current_segment_number - 1)
                * (self._segment_header_length + self._segment_footer_length)
                + self._segment_header_length
                + self._current_region_offset
            )
        if self._current_region == SMRegion.MESSAGE_FOOTER:
            # All num_segments header/footer pairs precede the message footer.
            return (
                self._message_header_length
                + self._content_offset
                + self._current_segment_number
                * (self._segment_header_length + self._segment_footer_length)
                + self._current_region_offset
            )

        raise ValueError(f"Invalid SMRegion {self._current_region}")

    def seek(self, offset: int, whence: int = SEEK_SET) -> int:
        """Seek to an encoded-message position (backwards only).

        :param int offset: Offset interpreted per ``whence``.
        :param int whence: SEEK_SET, SEEK_CUR, or SEEK_END.
        :returns: The new encoded position.
        :raises UnsupportedOperation: If the inner stream is not seekable or a
            forward seek is requested.
        :raises ValueError: For an invalid ``whence`` or negative target.
        """
        if not self.seekable():
            raise UnsupportedOperation("Inner stream is not seekable.")

        if whence == SEEK_SET:
            position = offset
        elif whence == SEEK_CUR:
            position = self.tell() + offset
        elif whence == SEEK_END:
            position = self.message_length + offset
        else:
            raise ValueError(f"Invalid value for whence: {whence}")

        if position < 0:
            raise ValueError(f"Cannot seek to negative position: {position}")
        if position > self.tell():
            raise UnsupportedOperation("This stream only supports seeking backwards.")

        # MESSAGE_HEADER
        if position < self._message_header_length:
            self._current_region = SMRegion.MESSAGE_HEADER
            self._current_region_offset = position
            self._content_offset = 0
            self._current_segment_number = 0
        # MESSAGE_FOOTER
        elif position >= self.message_length - self._message_footer_length:
            self._current_region = SMRegion.MESSAGE_FOOTER
            self._current_region_offset = position - (
                self.message_length - self._message_footer_length
            )
            self._content_offset = self.content_length
            self._current_segment_number = self._num_segments
        else:
            # The size of a "full" segment. Fine to use for calculating new segment number and pos
            full_segment_size = (
                self._segment_header_length
                + self._segment_size
                + self._segment_footer_length
            )
            new_segment_num = (
                1 + (position - self._message_header_length) // full_segment_size
            )
            segment_pos = (position - self._message_header_length) % full_segment_size
            previous_segments_total_content_size = (
                new_segment_num - 1
            ) * self._segment_size

            # We need the size of the segment we are seeking to for some of the calculations below
            new_segment_size = self._segment_size
            if new_segment_num == self._num_segments:
                # The last segment size is the remaining content length
                new_segment_size = (
                    self.content_length - previous_segments_total_content_size
                )

            # SEGMENT_HEADER
            if segment_pos < self._segment_header_length:
                self._current_region = SMRegion.SEGMENT_HEADER
                self._current_region_offset = segment_pos
                self._content_offset = previous_segments_total_content_size
            # SEGMENT_CONTENT
            elif segment_pos < self._segment_header_length + new_segment_size:
                self._current_region = SMRegion.SEGMENT_CONTENT
                self._current_region_offset = segment_pos - self._segment_header_length
                self._content_offset = (
                    previous_segments_total_content_size + self._current_region_offset
                )
            # SEGMENT_FOOTER
            else:
                self._current_region = SMRegion.SEGMENT_FOOTER
                self._current_region_offset = (
                    segment_pos - self._segment_header_length - new_segment_size
                )
                self._content_offset = (
                    previous_segments_total_content_size + new_segment_size
                )

            self._current_segment_number = new_segment_num

        self._update_current_region_length()
        # Reposition the inner content stream to match the new content offset.
        self._inner_stream.seek(
            (self._initial_content_position or 0) + self._content_offset
        )
        return position

    def read(self, size: int = -1) -> bytes:
        """Read up to ``size`` encoded bytes (all remaining if negative).

        Dispatches per region: metadata regions are served from generated
        header/footer bytes, content regions from the inner stream.
        """
        if self.closed:  # pylint: disable=using-constant-test
            raise ValueError("Stream is closed")

        if size == 0:
            return b""
        if size < 0:
            size = sys.maxsize

        count = 0
        output = BytesIO()

        while count < size and self.tell() < self.message_length:
            remaining = size - count
            if self._current_region in (
                SMRegion.MESSAGE_HEADER,
                SMRegion.SEGMENT_HEADER,
                SMRegion.SEGMENT_FOOTER,
                SMRegion.MESSAGE_FOOTER,
            ):
                count += self._read_metadata_region(
                    self._current_region, remaining, output
                )
            elif self._current_region == SMRegion.SEGMENT_CONTENT:
                count += self._read_content(remaining, output)
            else:
                raise ValueError(f"Invalid SMRegion {self._current_region}")

        return output.getvalue()

    def _calculate_message_length(self) -> int:
        # Total = message header + (segment header + footer) per segment
        # + raw content + message footer.
        length = self._message_header_length
        length += (
            self._segment_header_length + self._segment_footer_length
        ) * self._num_segments
        length += self.content_length
        length += self._message_footer_length
        return length

    def _get_metadata_region(self, region: SMRegion) -> bytes:
        """Return the full bytes for the given metadata region."""
        if region == SMRegion.MESSAGE_HEADER:
            return generate_message_header(
                self.message_version,
                self.message_length,
                self.flags,
                self._num_segments,
            )

        if region == SMRegion.SEGMENT_HEADER:
            # Last segment may be shorter than the configured segment size.
            segment_size = min(
                self._segment_size, self.content_length - self._content_offset
            )
            return generate_segment_header(self._current_segment_number, segment_size)

        if region == SMRegion.SEGMENT_FOOTER:
            if StructuredMessageProperties.CRC64 in self.flags:
                return self._segment_crc64s[self._current_segment_number].to_bytes(
                    StructuredMessageConstants.CRC64_LENGTH, "little"
                )
            return b""

        if region == SMRegion.MESSAGE_FOOTER:
            if StructuredMessageProperties.CRC64 in self.flags:
                return self._message_crc64.to_bytes(
                    StructuredMessageConstants.CRC64_LENGTH, "little"
                )
            return b""

        raise ValueError(f"Invalid metadata SMRegion {self._current_region}")

    def _advance_region(self, current: SMRegion):
        """Move the cursor to the region that follows ``current``."""
        self._current_region_offset = 0

        if current == SMRegion.MESSAGE_HEADER:
            self._current_region = SMRegion.SEGMENT_HEADER
            self._increment_current_segment()
        elif current == SMRegion.SEGMENT_HEADER:
            self._current_region = SMRegion.SEGMENT_CONTENT
        elif current == SMRegion.SEGMENT_CONTENT:
            self._current_region = SMRegion.SEGMENT_FOOTER
        elif current == SMRegion.SEGMENT_FOOTER:
            # If we're at the end of the content
            if self._content_offset == self.content_length:
                self._current_region = SMRegion.MESSAGE_FOOTER
            else:
                self._current_region = SMRegion.SEGMENT_HEADER
                self._increment_current_segment()
        else:
            raise ValueError(f"Invalid SMRegion {self._current_region}")

        self._update_current_region_length()

    def _read_metadata_region(
        self, region: SMRegion, size: int, output: BytesIO
    ) -> int:
        """Copy up to ``size`` bytes of a metadata region into ``output``.

        :returns: The number of bytes written.
        """
        metadata = self._get_metadata_region(region)

        read_size = min(size, self._current_region_length - self._current_region_offset)
        content = metadata[
            self._current_region_offset : self._current_region_offset + read_size
        ]
        output.write(content)

        self._current_region_offset += read_size
        # MESSAGE_FOOTER is terminal; every other exhausted region advances.
        if (
            self._current_region_offset == self._current_region_length
            and self._current_region != SMRegion.MESSAGE_FOOTER
        ):
            self._advance_region(region)

        return read_size

    def _read_content(self, size: int, output: BytesIO) -> int:
        """Copy up to ``size`` content bytes from the inner stream into ``output``.

        Maintains the running segment and message CRC64s, skipping CRC updates
        for bytes re-read after a backward seek (already checksummed).
        :returns: The number of bytes written.
        :raises ValueError: If the inner stream ends before the declared length.
        """
        # Will be non-zero if there is data to read that does not need to have checksum calculated.
        # Will always be positive as stream can only seek backwards.
        checksum_offset = self._checksum_offset - self._content_offset

        read_size = min(size, self._current_region_length - self._current_region_offset)
        if checksum_offset != 0:
            # Only read up to checksum offset this iteration
            read_size = min(read_size, checksum_offset)

        content = self._inner_stream.read(read_size)
        if len(content) != read_size:
            raise ValueError("Content ended early when encoding structured message.")
        output.write(content)

        if StructuredMessageProperties.CRC64 in self.flags:
            if checksum_offset == 0:
                self._segment_crc64s[self._current_segment_number] = calculate_crc64(
                    content, self._segment_crc64s[self._current_segment_number]
                )
                self._message_crc64 = calculate_crc64(content, self._message_crc64)

        self._content_offset += read_size
        # Only update the checksum offset if we've read new data
        if self._content_offset > self._checksum_offset:
            self._checksum_offset += read_size
        self._current_region_offset += read_size
        if self._current_region_offset == self._current_region_length:
            self._advance_region(SMRegion.SEGMENT_CONTENT)

        return read_size

    def _increment_current_segment(self):
        self._current_segment_number += 1
        if StructuredMessageProperties.CRC64 in self.flags:
            # If seek was used, we may already have this segment's CRC (could be partial), otherwise initialize to 0
            self._segment_crc64s.setdefault(self._current_segment_number, 0)


class StructuredMessageDecoder(IOBase):  # pylint: disable=too-many-instance-attributes
    """Synchronously decodes a structured message from an iterator of bytes,
    validating per-segment and whole-message CRC64s when the message's flags
    include CRC64. Attribute declarations continue below; methods follow.
    """

    message_version: int
    """The version of the structured message."""
    message_length: int
    """The total length of the structured message."""
    flags: StructuredMessageProperties
    """The properties included in the structured message."""
    num_segments: int
    """The number of message segments."""

    _inner_iterator: Iterator[bytes]
    _buffer: bytes
    _message_offset: int
    _message_crc64: int
    _segment_number: int
    _segment_crc64: int
    _segment_content_length: int
    _segment_content_offset: int
    _block_size: int
    # NOTE(review): recovered from a whitespace-mangled patch; indentation
    # reconstructed — verify against upstream. Class header and attribute
    # annotations for StructuredMessageDecoder are above this chunk.

    def __init__(
        self,
        inner_iterator: Iterator[bytes],
        content_length: int,
        *,
        block_size: int = 4096,
    ) -> None:
        """
        :param inner_iterator: Iterator yielding the raw encoded message bytes.
        :param int content_length: Total length of the encoded message.
        :keyword int block_size: Chunk size used when iterating this decoder.
        :raises ValueError: If content_length is shorter than a V1 header.
        """
        self.message_length = content_length
        # The stream should be at least long enough to hold minimum header length
        if self.message_length < StructuredMessageConstants.V1_HEADER_LENGTH:
            raise ValueError(
                "Content not long enough to contain a valid message header."
            )

        self._inner_iterator = inner_iterator
        self._buffer = b""
        self._message_offset = 0
        self._message_crc64 = 0

        self._segment_number = 0
        self._segment_crc64 = 0
        self._segment_content_length = 0
        self._segment_content_offset = 0
        self._block_size = block_size
        super().__init__()

    @property
    def content_length(self) -> int:
        # NOTE(review): this returns the *encoded* message length, not the
        # decoded content length — confirm intent against callers.
        return self.message_length

    @property
    def _segment_header_length(self) -> int:
        return StructuredMessageConstants.V1_SEGMENT_HEADER_LENGTH

    @property
    def _segment_footer_length(self) -> int:
        # Segment footer is the CRC64 only when the CRC64 flag is set.
        return (
            StructuredMessageConstants.CRC64_LENGTH
            if StructuredMessageProperties.CRC64 in self.flags
            else 0
        )

    @property
    def _message_footer_length(self) -> int:
        # Message footer is the CRC64 only when the CRC64 flag is set.
        return (
            StructuredMessageConstants.CRC64_LENGTH
            if StructuredMessageProperties.CRC64 in self.flags
            else 0
        )

    @property
    def _end_of_segment_content(self) -> bool:
        return self._segment_content_offset == self._segment_content_length

    def readable(self) -> bool:
        return True

    def seekable(self) -> bool:
        return False

    def __iter__(self):
        return self

    def __next__(self) -> bytes:
        data = self.read(self._block_size)
        if not data:
            raise StopIteration
        return data

    def read(self, size: int = -1) -> bytes:
        """Read and return up to ``size`` *decoded* content bytes.

        Headers and footers are consumed transparently; CRC64s are validated
        as segment/message footers are reached.
        :raises ValueError: If the stream is closed or the message is invalid.
        """
        if self.closed:
            raise ValueError("Stream is closed")

        if size == 0 or self._message_offset >= self.message_length:
            return b""
        if size < 0:
            size = sys.maxsize

        # On the first read, read message header and first segment header
        if self._message_offset == 0:
            self._read_message_header()
            self._read_segment_header()

            # Special case for 0 length content
            if self._end_of_segment_content:
                self._read_segment_footer()
                if self.num_segments > 1:
                    raise ValueError(
                        "First message segment was empty but more segments were detected."
                    )
                self._read_message_footer()
                return b""

        count = 0
        content = BytesIO()
        while count < size and not (
            self._end_of_segment_content and self._message_offset == self.message_length
        ):
            if self._end_of_segment_content:
                self._read_segment_header()

            segment_remaining = (
                self._segment_content_length - self._segment_content_offset
            )
            read_size = min(segment_remaining, size - count)

            segment_content = self._read_from_inner(read_size)
            content.write(segment_content)

            # Update the running CRC64 for the segment and message
            if StructuredMessageProperties.CRC64 in self.flags:
                self._segment_crc64 = calculate_crc64(
                    segment_content, self._segment_crc64
                )
                self._message_crc64 = calculate_crc64(
                    segment_content, self._message_crc64
                )

            self._segment_content_offset += read_size
            self._message_offset += read_size
            count += read_size

            if self._end_of_segment_content:
                self._read_segment_footer()
                # If we are on the last segment, also read the message footer
                if self._segment_number == self.num_segments:
                    self._read_message_footer()

        # One final check to ensure if we think we've reached the end of the stream
        # that the current segment number matches the total.
        if (
            self._message_offset == self.message_length
            and self._segment_number != self.num_segments
        ):
            raise ValueError("Invalid structured message data detected.")

        return content.getvalue()

    def _read_from_inner(self, size: int) -> bytes:
        """Return exactly ``size`` bytes, buffering from the inner iterator.

        :raises ValueError: If the iterator is exhausted before ``size`` bytes
            are available.
        """
        while len(self._buffer) < size:
            try:
                chunk = next(self._inner_iterator)
            except StopIteration:
                break
            self._buffer += chunk

        if len(self._buffer) < size:
            raise ValueError(
                "Invalid structured message data detected. Stream content incomplete."
            )

        data = self._buffer[:size]
        self._buffer = self._buffer[size:]
        return data

    def _read_message_header(self) -> None:
        # Populates message_version, flags and num_segments from the header.
        header_data = self._read_from_inner(StructuredMessageConstants.V1_HEADER_LENGTH)
        self.message_version, self.flags, self.num_segments = parse_message_header(
            header_data, self.message_length
        )
        self._message_offset += StructuredMessageConstants.V1_HEADER_LENGTH

    def _read_message_footer(self) -> None:
        # Sanity check: There should only be self._message_footer_length (could be 0) bytes left to consume.
        # If not, it is likely the message header contained incorrect info.
        if self.message_length - self._message_offset != self._message_footer_length:
            raise ValueError("Invalid structured message data detected.")

        if StructuredMessageProperties.CRC64 in self.flags:
            message_crc = self._read_from_inner(StructuredMessageConstants.CRC64_LENGTH)

            if self._message_crc64 != int.from_bytes(message_crc, "little"):
                raise ValueError(
                    "CRC64 mismatch detected in message trailer. "
                    "All data read should be considered invalid."
                )

        self._message_offset += self._message_footer_length

    def _read_segment_header(self) -> None:
        # Advances to the next segment and resets per-segment counters.
        header_data = self._read_from_inner(self._segment_header_length)
        self._segment_number, self._segment_content_length = parse_segment_header(
            header_data, self._segment_number + 1
        )
        self._message_offset += self._segment_header_length

        self._segment_content_offset = 0
        self._segment_crc64 = 0

    def _read_segment_footer(self) -> None:
        # Validates the finished segment's CRC64 when the flag is set.
        if StructuredMessageProperties.CRC64 in self.flags:
            segment_crc = self._read_from_inner(StructuredMessageConstants.CRC64_LENGTH)

            if self._segment_crc64 != int.from_bytes(segment_crc, "little"):
                raise ValueError(
                    f"CRC64 mismatch detected in segment {self._segment_number}. "
                    f"All data read should be considered invalid."
                )

        self._message_offset += self._segment_footer_length


# ===== patch continues: new file azure/storage/queue/_shared/streams_async.py =====
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------

import sys
from io import BytesIO, IOBase
from typing import AsyncIterator

from .streams import (
    StructuredMessageConstants,
    StructuredMessageProperties,
    parse_message_header,
    parse_segment_header,
)
from .validation import calculate_crc64


class AsyncStructuredMessageDecoder(
    IOBase
):  # pylint: disable=too-many-instance-attributes
    """Async counterpart of StructuredMessageDecoder: decodes a structured
    message from an async iterator of bytes, validating CRC64s when flagged.
    """

    message_version: int
    """The version of the structured message."""
    message_length: int
    """The total length of the structured message."""
    flags: StructuredMessageProperties
    """The properties included in the structured message."""
    num_segments: int
    """The number of message segments."""

    _inner_iterator: AsyncIterator[bytes]
    _buffer: bytes
    _message_offset: int
    _message_crc64: int
    _segment_number: int
    _segment_crc64: int
    _segment_content_length: int
    _segment_content_offset: int
    _block_size: int

    def __init__(
        self,
        inner_iterator: AsyncIterator[bytes],
        content_length: int,
        *,
        block_size: int = 4096,
    ) -> None:
        """
        :param inner_iterator: Async iterator yielding the encoded message bytes.
        :param int content_length: Total length of the encoded message.
        :keyword int block_size: Chunk size used when async-iterating this decoder.
        :raises ValueError: If content_length is shorter than a V1 header.
        """
        self.message_length = content_length
        # The stream should be at least long enough to hold minimum header length
        if self.message_length < StructuredMessageConstants.V1_HEADER_LENGTH:
            raise ValueError(
                "Content not long enough to contain a valid message header."
            )

        self._inner_iterator = inner_iterator
        self._buffer = b""
        self._message_offset = 0
        self._message_crc64 = 0

        self._segment_number = 0
        self._segment_crc64 = 0
        self._segment_content_length = 0
        self._segment_content_offset = 0
        self._block_size = block_size
        super().__init__()
    # NOTE(review): recovered from a whitespace-mangled patch; indentation
    # reconstructed — verify against upstream. The AsyncStructuredMessageDecoder
    # class header and __init__ are above this chunk.

    @property
    def content_length(self) -> int:
        # NOTE(review): this returns the *encoded* message length, not the
        # decoded content length — confirm intent against callers.
        return self.message_length

    @property
    def _segment_header_length(self) -> int:
        return StructuredMessageConstants.V1_SEGMENT_HEADER_LENGTH

    @property
    def _segment_footer_length(self) -> int:
        # Segment footer is the CRC64 only when the CRC64 flag is set.
        return (
            StructuredMessageConstants.CRC64_LENGTH
            if StructuredMessageProperties.CRC64 in self.flags
            else 0
        )

    @property
    def _message_footer_length(self) -> int:
        # Message footer is the CRC64 only when the CRC64 flag is set.
        return (
            StructuredMessageConstants.CRC64_LENGTH
            if StructuredMessageProperties.CRC64 in self.flags
            else 0
        )

    @property
    def _end_of_segment_content(self) -> bool:
        return self._segment_content_offset == self._segment_content_length

    def readable(self) -> bool:
        return True

    def seekable(self) -> bool:
        return False

    def __aiter__(self):
        return self

    async def __anext__(self) -> bytes:
        data = await self.read(self._block_size)
        if not data:
            raise StopAsyncIteration
        return data

    async def read(self, size: int = -1) -> bytes:
        """Read and return up to ``size`` *decoded* content bytes.

        Headers and footers are consumed transparently; CRC64s are validated
        as segment/message footers are reached.
        :raises ValueError: If the stream is closed or the message is invalid.
        """
        if self.closed:
            raise ValueError("Stream is closed")

        if size == 0 or self._message_offset >= self.message_length:
            return b""
        if size < 0:
            size = sys.maxsize

        # On the first read, read message header and first segment header
        if self._message_offset == 0:
            await self._read_message_header()
            await self._read_segment_header()

            # Special case for 0 length content
            if self._end_of_segment_content:
                await self._read_segment_footer()
                if self.num_segments > 1:
                    raise ValueError(
                        "First message segment was empty but more segments were detected."
                    )
                await self._read_message_footer()
                return b""

        count = 0
        content = BytesIO()
        while count < size and not (
            self._end_of_segment_content and self._message_offset == self.message_length
        ):
            if self._end_of_segment_content:
                await self._read_segment_header()

            segment_remaining = (
                self._segment_content_length - self._segment_content_offset
            )
            read_size = min(segment_remaining, size - count)

            segment_content = await self._read_from_inner(read_size)
            content.write(segment_content)

            # Update the running CRC64 for the segment and message
            if StructuredMessageProperties.CRC64 in self.flags:
                self._segment_crc64 = calculate_crc64(
                    segment_content, self._segment_crc64
                )
                self._message_crc64 = calculate_crc64(
                    segment_content, self._message_crc64
                )

            self._segment_content_offset += read_size
            self._message_offset += read_size
            count += read_size

            if self._end_of_segment_content:
                await self._read_segment_footer()
                # If we are on the last segment, also read the message footer
                if self._segment_number == self.num_segments:
                    await self._read_message_footer()

        # One final check to ensure if we think we've reached the end of the stream
        # that the current segment number matches the total.
        if (
            self._message_offset == self.message_length
            and self._segment_number != self.num_segments
        ):
            raise ValueError("Invalid structured message data detected.")

        return content.getvalue()

    async def _read_from_inner(self, size: int) -> bytes:
        """Return exactly ``size`` bytes, buffering from the inner async iterator.

        :raises ValueError: If the iterator is exhausted before ``size`` bytes
            are available.
        """
        while len(self._buffer) < size:
            try:
                chunk = await self._inner_iterator.__anext__()
            except StopAsyncIteration:
                break
            self._buffer += chunk

        if len(self._buffer) < size:
            raise ValueError(
                "Invalid structured message data detected. Stream content incomplete."
            )

        data = self._buffer[:size]
        self._buffer = self._buffer[size:]
        return data

    async def _read_message_header(self) -> None:
        # Populates message_version, flags and num_segments from the header.
        header_data = await self._read_from_inner(
            StructuredMessageConstants.V1_HEADER_LENGTH
        )
        self.message_version, self.flags, self.num_segments = parse_message_header(
            header_data, self.message_length
        )
        self._message_offset += StructuredMessageConstants.V1_HEADER_LENGTH

    async def _read_message_footer(self) -> None:
        # Sanity check: There should only be self._message_footer_length (could be 0) bytes left to consume.
        # If not, it is likely the message header contained incorrect info.
        if self.message_length - self._message_offset != self._message_footer_length:
            raise ValueError("Invalid structured message data detected.")

        if StructuredMessageProperties.CRC64 in self.flags:
            message_crc = await self._read_from_inner(
                StructuredMessageConstants.CRC64_LENGTH
            )

            if self._message_crc64 != int.from_bytes(message_crc, "little"):
                raise ValueError(
                    "CRC64 mismatch detected in message trailer. "
                    "All data read should be considered invalid."
                )

        self._message_offset += self._message_footer_length

    async def _read_segment_header(self) -> None:
        # Advances to the next segment and resets per-segment counters.
        header_data = await self._read_from_inner(self._segment_header_length)
        self._segment_number, self._segment_content_length = parse_segment_header(
            header_data, self._segment_number + 1
        )
        self._message_offset += self._segment_header_length

        self._segment_content_offset = 0
        self._segment_crc64 = 0

    async def _read_segment_footer(self) -> None:
        # Validates the finished segment's CRC64 when the flag is set.
        if StructuredMessageProperties.CRC64 in self.flags:
            segment_crc = await self._read_from_inner(
                StructuredMessageConstants.CRC64_LENGTH
            )

            if self._segment_crc64 != int.from_bytes(segment_crc, "little"):
                raise ValueError(
                    f"CRC64 mismatch detected in segment {self._segment_number}. "
                    f"All data read should be considered invalid."
                )

        self._message_offset += self._segment_footer_length
# ===== patch continues: new file azure/storage/queue/_shared/validation.py =====
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
# pylint: disable=c-extension-no-member

import hashlib
from enum import Enum
from io import SEEK_SET
from typing import IO, Literal, Optional, Union, cast

from azure.core import CaseInsensitiveEnumMeta

# Size of a CRC64 checksum in bytes.
CRC64_LENGTH = 8
CV_TYPE_ERROR_MSG = "Data should be bytes or seekable IO[bytes] for content validation."


class ChecksumAlgorithm(str, Enum, metaclass=CaseInsensitiveEnumMeta):
    """Checksum algorithms supported for transfer content validation."""

    AUTO = "auto"
    MD5 = "md5"
    CRC64 = "crc64"

    @classmethod
    def list(cls):
        """Return all supported algorithm values as plain strings."""
        return [c.value for c in cls]


def _verify_extensions(module: str) -> None:
    """Raise ValueError if the azure-storage-extensions package is missing.

    :param str module: Feature name used in the error message (e.g. "crc64").
    :raises ValueError: If azure-storage-extensions cannot be imported.
    """
    try:
        import azure.storage.extensions  # pylint: disable=unused-import
    except ImportError as exc:
        raise ValueError(
            f"The use of {module} requires the azure-storage-extensions package to be installed. "
            f"Please install this package and try again."
        ) from exc


def parse_validation_option(
    validate_content: Optional[Union[bool, Literal["auto", "crc64", "md5"]]],
) -> Optional[Union[bool, Literal["auto", "crc64", "md5"]]]:
    """Normalize the user-provided ``validate_content`` option.

    Bools pass through unchanged (legacy behavior). String options are matched
    case-insensitively (consistent with ChecksumAlgorithm's metaclass);
    "auto" resolves to "crc64", and any CRC64 selection verifies that
    azure-storage-extensions is installed.

    :param validate_content: None, a bool, or one of "auto"/"crc64"/"md5".
    :returns: The normalized option (None, bool, "crc64", or "md5").
    :raises ValueError: On an unrecognized value, or if CRC64 is requested but
        azure-storage-extensions is not installed.
    """
    if validate_content is None:
        return None

    # Legacy support for bool
    if isinstance(validate_content, bool):
        return validate_content

    # Accept any casing (e.g. "MD5"), matching the case-insensitive enum.
    option = (
        validate_content.lower()
        if isinstance(validate_content, str)
        else validate_content
    )
    if option not in ChecksumAlgorithm.list():
        raise ValueError("Invalid value for `validate_content` specified.")

    # Resolve auto: CRC64 is currently the algorithm "auto" selects.
    if option == ChecksumAlgorithm.AUTO:
        option = ChecksumAlgorithm.CRC64.value

    if option == ChecksumAlgorithm.CRC64:
        _verify_extensions("crc64")

    return option


def is_md5_validation(
    validate_content: Optional[Union[bool, Literal["md5", "crc64"]]],
) -> bool:
    """Return True if the (already parsed) option selects MD5 validation.

    A bool option maps directly (legacy behavior: True meant MD5).
    """
    if validate_content is None:
        return False
    if isinstance(validate_content, bool):
        return validate_content
    return validate_content == ChecksumAlgorithm.MD5


def calculate_content_md5(data: Union[bytes, IO[bytes]]) -> bytes:
    """Compute the MD5 digest of bytes or a seekable binary stream.

    A stream is read in 4 KiB chunks and restored to its original position
    afterwards. MD5 here is for transport integrity, not security (nosec).

    :param data: The bytes or seekable IO[bytes] to hash.
    :returns: The 16-byte MD5 digest.
    :raises ValueError: If data is neither bytes nor a usable stream.
    """
    md5 = hashlib.md5()  # nosec
    if isinstance(data, bytes):
        md5.update(data)
    elif hasattr(data, "read"):
        pos = 0
        try:
            pos = data.tell()
        except (AttributeError, OSError):
            # Not all file-like objects support tell(); assume position 0.
            pass
        for chunk in iter(lambda: data.read(4096), b""):
            md5.update(chunk)
        try:
            # Restore the caller's stream position after consuming the stream.
            data.seek(pos, SEEK_SET)
        except (AttributeError, OSError) as exc:
            raise ValueError(CV_TYPE_ERROR_MSG) from exc
    else:
        raise ValueError(CV_TYPE_ERROR_MSG)

    return md5.digest()


def calculate_crc64(data: bytes, initial_crc: int) -> int:
    """Fold ``data`` into a running CRC64 and return the new value.

    :param bytes data: Bytes to checksum.
    :param int initial_crc: The running CRC64 to continue from (0 to start).
    """
    # Locally import to avoid error if not installed.
    from azure.storage.extensions import crc64

    return cast(int, crc64.compute(data, initial_crc))


def calculate_crc64_bytes(data: bytes) -> bytes:
    """Return the CRC64 of ``data`` as 8 little-endian bytes."""
    # Locally import to avoid error if not installed.
    from azure.storage.extensions import crc64

    return cast(bytes, crc64.compute(data, 0).to_bytes(CRC64_LENGTH, "little"))