diff --git a/b2sdk/stream/hashing.py b/b2sdk/stream/hashing.py index 99968dbe7..d80412bba 100644 --- a/b2sdk/stream/hashing.py +++ b/b2sdk/stream/hashing.py @@ -17,13 +17,17 @@ class StreamWithHash(ReadOnlyStreamMixin, StreamWithLengthWrapper): """ - Wrap a file-like object, calculates SHA1 while reading - and appends hash at the end. + Wrap a file-like object, calculates SHA1 while reading and appends hash at the end. + + :ivar ~.hash: sha1 checksum of the stream, can be ``None`` if unknown (yet) + :vartype ~.hash: str or None """ - def __init__(self, stream, stream_length=None): + def __init__(self, stream, stream_length=None, upload_source=None): """ :param stream: the stream to read from + :param upload_source: used to set content_sha1 in upload_source (in case of retry etc) """ + self.upload_source = upload_source self.digest = self.get_digest() total_length = None @@ -63,6 +67,8 @@ def read(self, size=None): # Check for end of stream if size is None or len(data) < size: self.hash = self.digest.hexdigest() + if self.upload_source is not None: + self.upload_source.content_sha1 = self.hash if size is not None: size -= len(data) diff --git a/b2sdk/transfer/emerge/executor.py b/b2sdk/transfer/emerge/executor.py index 776156a31..f1b03d1d6 100644 --- a/b2sdk/transfer/emerge/executor.py +++ b/b2sdk/transfer/emerge/executor.py @@ -10,7 +10,7 @@ import threading -from abc import ABCMeta, abstractmethod +from abc import abstractmethod import six @@ -18,6 +18,7 @@ from b2sdk.file_version import FileVersionInfoFactory from b2sdk.transfer.outbound.large_file_upload_state import LargeFileUploadState from b2sdk.transfer.outbound.upload_source import UploadSourceStream +from b2sdk.utils import B2TraceMetaAbstract from b2sdk.utils import interruptible_get_result AUTO_CONTENT_TYPE = 'b2/x-auto' 
@@ -63,7 +64,7 @@ def execute_emerge_plan( return execution.execute_plan(emerge_plan) -@six.add_metaclass(ABCMeta) +@six.add_metaclass(B2TraceMetaAbstract) class BaseEmergeExecution(object): DEFAULT_CONTENT_TYPE = AUTO_CONTENT_TYPE @@ -300,7 +301,7 @@ def _match_unfinished_file_if_possible( return None, {} -@six.add_metaclass(ABCMeta) +@six.add_metaclass(B2TraceMetaAbstract) class BaseExecutionStepFactory(object): def __init__(self, emerge_execution, emerge_part): self.emerge_execution = emerge_execution @@ -369,7 +370,7 @@ def create_upload_execution_step(self, stream_opener, stream_length=None, stream ) -@six.add_metaclass(ABCMeta) +@six.add_metaclass(B2TraceMetaAbstract) class BaseExecutionStep(object): @abstractmethod def execute(self): diff --git a/b2sdk/transfer/outbound/upload_manager.py b/b2sdk/transfer/outbound/upload_manager.py index 555fc09b4..be9aae858 100644 --- a/b2sdk/transfer/outbound/upload_manager.py +++ b/b2sdk/transfer/outbound/upload_manager.py @@ -158,15 +158,29 @@ def _upload_part( try: with part_upload_source.open() as part_stream: input_stream = ReadingStreamWithProgress(part_stream, part_progress_listener) - hashing_stream = StreamWithHash( - input_stream, stream_length=part_upload_source.get_content_length() - ) - # it is important that `len()` works on `hashing_stream` + + if part_upload_source.is_sha1_known(): + sha1_checksum = part_upload_source.content_sha1 + logger.debug('hash for part %s is known: %s, use that', part_upload_source, sha1_checksum) + stream_length = part_upload_source.get_content_length() + else: + sha1_checksum = HEX_DIGITS_AT_END + # wrap it with a hasher + input_stream = StreamWithHash( + input_stream, + stream_length=part_upload_source.get_content_length(), + upload_source=part_upload_source, + ) + stream_length = input_stream.length + logger.debug('hash for part %s is unknown, calculate it and provide it at the end of the stream', part_upload_source) response = self.services.session.upload_part( - file_id, part_number, hashing_stream.length, HEX_DIGITS_AT_END, - hashing_stream + file_id, + part_number, + 
stream_length, + sha1_checksum, + input_stream, ) - assert hashing_stream.hash == response['contentSha1'] + assert part_upload_source.get_content_sha1() == response['contentSha1'], 'part checksum mismatch! %s vs %s' % (part_upload_source.get_content_sha1(), response['contentSha1']) return response except B2Error as e: @@ -189,13 +203,26 @@ def _upload_small_file( try: with upload_source.open() as file: input_stream = ReadingStreamWithProgress(file, progress_listener) - hashing_stream = StreamWithHash(input_stream, stream_length=content_length) - # it is important that `len()` works on `hashing_stream` + if upload_source.is_sha1_known(): + sha1_checksum = upload_source.content_sha1 + logger.debug('hash for %s is known: %s, use that', upload_source, sha1_checksum) + stream_length = content_length + else: + sha1_checksum = HEX_DIGITS_AT_END + # wrap it with a hasher + input_stream = StreamWithHash(input_stream, stream_length=content_length, upload_source=upload_source) + stream_length = input_stream.length + logger.debug('hash for %s is unknown, calculate it and provide it at the end of the stream', upload_source) response = self.services.session.upload_file( + bucket_id, + file_name, + stream_length, + content_type, + sha1_checksum, # can be HEX_DIGITS_AT_END + file_info, + input_stream, # can be a hashing stream or a raw stream ) - assert hashing_stream.hash == response['contentSha1'] + assert upload_source.get_content_sha1() == response['contentSha1'], 'small file checksum mismatch!' 
return FileVersionInfoFactory.from_api_response(response) except B2Error as e: @@ -204,4 +226,5 @@ def _upload_small_file( exception_info_list.append(e) self.account_info.clear_bucket_upload_data(bucket_id) + raise MaxRetriesExceeded(self.MAX_UPLOAD_ATTEMPTS, exception_info_list) diff --git a/b2sdk/transfer/outbound/upload_source.py b/b2sdk/transfer/outbound/upload_source.py index cb02079d5..a97187240 100644 --- a/b2sdk/transfer/outbound/upload_source.py +++ b/b2sdk/transfer/outbound/upload_source.py @@ -23,19 +23,40 @@ class AbstractUploadSource(OutboundTransferSource): """ The source of data for uploading to b2. + + `is_sha1_known()` is useful for medium-sized files where in the first upload attempt we'd like to + stream-read-and-hash, but later on when retrying, the hash is already calculated, so + there is no point in calculating it again. The caller may use :py:class:`b2sdk.v1.StreamWithHash` + in the first attempt and then switch to passing the checksum explicitly to :meth:`b2sdk.v1.Session.upload_file` + in order to avoid (cpu-intensive) re-streaming. + + :ivar ~.content_sha1: sha1 checksum of the entire file, can be ``None`` if unknown (yet) + :vartype ~.content_sha1: str or None """ + def __init__(self, content_sha1=None): + self.content_sha1 = content_sha1 # NOTE: b2sdk.transfer.upload_manager *writes* to this field @abstractmethod def get_content_sha1(self): """ - Return a 40-character string containing the hex SHA1 checksum of the data in the file. + Return a 40-character string containing the hex sha1 checksum of the data in the file. + The implementation of this method may cache the checksum value to avoid recalculating it. + This method may not be thread-safe: if two threads are trying to get the checksum + at the exact same moment, it may be calculated twice. + """ + + def is_sha1_known(self): + """ + Tells whether the checksum is already known and will not have to be calculated by `get_content_sha1()`.
+ + :rtype: bool """ + return self.content_sha1 is not None @abstractmethod def open(self): """ - Return a binary file-like object from which the - data can be read. + Return a binary file-like object from which the data can be read. :return: """ @@ -47,8 +68,9 @@ def is_copy(self): class UploadSourceBytes(AbstractUploadSource): - def __init__(self, data_bytes): + def __init__(self, data_bytes, content_sha1=None): self.data_bytes = data_bytes + super(UploadSourceBytes, self).__init__(content_sha1) def __repr__(self): return '<{classname} data={data} id={id}>'.format( @@ -62,7 +84,9 @@ def get_content_length(self): return len(self.data_bytes) def get_content_sha1(self): - return hashlib.sha1(self.data_bytes).hexdigest() + if self.content_sha1 is None: + self.content_sha1 = hashlib.sha1(self.data_bytes).hexdigest() + return self.content_sha1 def open(self): return io.BytesIO(self.data_bytes) @@ -74,7 +98,7 @@ def __init__(self, local_path, content_sha1=None): if not os.path.isfile(local_path): raise InvalidUploadSource(local_path) self.content_length = os.path.getsize(local_path) - self.content_sha1 = content_sha1 + super(UploadSourceLocalFile, self).__init__(content_sha1) def __repr__(self): return ( @@ -115,6 +139,7 @@ def __init__(self, local_path, content_sha1=None, offset=0, length=None): if length + self.offset > self.file_size: raise ValueError('Range length overflow file size') self.content_length = length + AbstractUploadSource.__init__(self, content_sha1) def __repr__(self): return ( @@ -138,7 +163,7 @@ class UploadSourceStream(AbstractUploadSource): def __init__(self, stream_opener, stream_length=None, stream_sha1=None): self.stream_opener = stream_opener self._content_length = stream_length - self._content_sha1 = stream_sha1 + super(UploadSourceStream, self).__init__(content_sha1=stream_sha1) def __repr__(self): return ( @@ -148,7 +173,7 @@ def __repr__(self): classname=self.__class__.__name__, stream_opener=repr(self.stream_opener), 
content_length=self._content_length, - content_sha1=self._content_sha1, + content_sha1=self.content_sha1, id=id(self), ) @@ -158,9 +183,9 @@ def get_content_length(self): return self._content_length def get_content_sha1(self): - if self._content_sha1 is None: + if self.content_sha1 is None: self._set_content_length_and_sha1() - return self._content_sha1 + return self.content_sha1 def open(self): return self.stream_opener() @@ -168,7 +193,7 @@ def open(self): def _set_content_length_and_sha1(self): sha1, content_length = hex_sha1_of_unlimited_stream(self.open()) self._content_length = content_length - self._content_sha1 = sha1 + self.content_sha1 = sha1 class UploadSourceStreamRange(UploadSourceStream): diff --git a/test/v1/test_bucket.py b/test/v1/test_bucket.py index 8f71157a4..e87264213 100644 --- a/test/v1/test_bucket.py +++ b/test/v1/test_bucket.py @@ -523,6 +523,18 @@ def test_upload_dead_symlink(self): with self.assertRaises(InvalidUploadSource): self.bucket.upload_local_file(path, 'file1') + def test_upload_local_wrong_sha(self): + with TempDir() as d: + path = os.path.join(d, 'file123') + data = six.b('hello world') + write_file(path, data) + with self.assertRaises(AssertionError): + self.bucket.upload_local_file( + path, + 'file123', + sha1_sum='abcabcabc', + ) + def test_upload_one_retryable_error(self): self.simulator.set_upload_errors([CanRetry(True)]) data = six.b('hello world')