|
| 1 | +from typing import Callable, Optional, List, Tuple, Any |
| 2 | +from io import BytesIO |
| 3 | +from datetime import datetime |
| 4 | +from asyncio import Future |
| 5 | + |
| 6 | +from kiota_abstractions.serialization.parsable import Parsable |
| 7 | +from kiota_abstractions.method import Method |
| 8 | +from kiota_abstractions.request_adapter import RequestAdapter |
| 9 | +from kiota_abstractions.request_information import RequestInformation |
| 10 | +from kiota_abstractions.serialization.additional_data_holder import AdditionalDataHolder |
| 11 | + |
| 12 | +from msgraph_core.models import LargeFileUploadCreateSession, LargeFileUploadSession |
| 13 | + |
| 14 | + |
class LargeFileUploadTask:
    """Uploads a large stream to a Microsoft Graph upload session in chunks.

    Wraps an existing upload session (a ``Parsable`` exposing the session
    URL, the expiry timestamp and the next expected byte ranges), the
    ``RequestAdapter`` used to issue requests, and the source stream.
    ``upload`` drives the chunked transfer, ``resume`` restarts from the
    server-reported next expected range, and ``cancel`` deletes the
    session server-side.
    """

    def __init__(
        self,
        upload_session: Parsable,
        request_adapter: RequestAdapter,
        stream: BytesIO,
        max_chunk_size: int = 4 * 1024 * 1024
    ):
        self.upload_session = upload_session
        self.request_adapter = request_adapter
        self.stream = stream
        self.file_size = stream.getbuffer().nbytes
        self.max_chunk_size = max_chunk_size
        cleaned_value = self.check_value_exists(
            upload_session, 'get_next_expected_range', ['next_expected_range', 'NextExpectedRange']
        )
        # check_value_exists returns (found, value); keep the VALUE.
        # (The original stored index 0, leaving next_range a bool.)
        self.next_range = cleaned_value[1]
        # Ceiling division: a trailing partial chunk still needs its own
        # upload (round-to-nearest dropped small remainders entirely).
        self.chunks = (self.file_size + max_chunk_size - 1) // max_chunk_size
        # Optional callback invoked after each chunk with the uploaded
        # [start, end] range (a two-element list).
        self.on_chunk_upload_complete: Optional[Callable[[List[Any]], None]] = None

    def get_upload_session(self) -> Parsable:
        """Return the wrapped upload session."""
        return self.upload_session

    def get_adapter(self) -> RequestAdapter:
        """Return the request adapter used for the transfer."""
        return self.request_adapter

    def create_upload_session(self, model: LargeFileUploadCreateSession, callback: Method):
        """Build and send the POST request that creates a new upload session.

        Returns the pending ``send_async`` coroutine so the caller can await
        it (the original dropped the coroutine without ever awaiting it).

        NOTE(review): ``self.options`` is never assigned anywhere in this
        class, so this method raises AttributeError as written — confirm
        where the item path is supposed to come from.
        """
        request_info = RequestInformation()
        request_info.set_uri(self.options.get_item_path())
        request_info.set_http_method('POST')
        request_info.set_content_type('application/json')
        request_info.set_payload(model)
        return self.request_adapter.send_async(request_info, LargeFileUploadSession, callback)

    def get_chunks(self) -> int:
        """Return the number of chunks remaining to upload."""
        return self.chunks

    def upload_session_expired(self, upload_session: Optional[Parsable] = None) -> bool:
        """Return True when the session's expiry timestamp is at or before now.

        Raises:
            Exception: if the session carries no expiry property at all.
            ValueError: if the expiry property is present but None.
        """
        now = datetime.now()
        validated_value = self.check_value_exists(
            upload_session or self.upload_session, 'get_expiration_date_time',
            ['ExpirationDateTime', 'expirationDateTime']
        )
        if not validated_value[0]:
            raise Exception('The upload session does not contain an expiry datetime.')

        expiry = validated_value[1]
        if expiry is None:
            raise ValueError('The upload session does not contain a valid expiry date.')

        # Graph timestamps may carry fractional seconds and a trailing 'Z';
        # parse only the seconds-resolution prefix the format string expects.
        then = datetime.strptime(expiry[:19], "%Y-%m-%dT%H:%M:%S")
        # Expired when 'now' has reached or passed the expiry instant.
        # (The original returned True for interval < 0, i.e. for sessions
        # that were still alive — the comparison was inverted.)
        return (now - then).total_seconds() >= 0

    async def upload(self, after_chunk_upload: Optional[Callable] = None):
        """Upload the whole stream chunk by chunk; return the final session.

        Args:
            after_chunk_upload: optional callback invoked with the uploaded
                [start, end] range after every chunk.

        Raises:
            RuntimeError: if the upload session has already expired.
        """
        # Rewind so retries after a failure restart from the beginning.
        self.stream.seek(0)
        if self.upload_session_expired(self.upload_session):
            raise RuntimeError('The upload session is expired.')

        if after_chunk_upload is not None:
            self.on_chunk_upload_complete = after_chunk_upload

        # First chunk covers bytes [0, min(chunk, size) - 1] unless a
        # resumed session supplied a next expected range.
        session = await self.next_chunk(
            self.stream, 0, max(0, min(self.max_chunk_size - 1, self.file_size - 1))
        )
        # Determine the range just uploaded (also covers resumed sessions,
        # where next_range may be a list of range strings or one string).
        if self.next_range:
            first = (
                self.next_range[0]
                if isinstance(self.next_range, (list, tuple)) else self.next_range
            )
            range_parts = str(first).split("-")
        else:
            range_parts = ['0', '0']
        end = min(int(range_parts[0]) + self.max_chunk_size - 1, self.file_size)
        uploaded_range = [range_parts[0], end]
        # The chunk sent above counts; the original uploaded chunks + 1.
        self.chunks -= 1
        while self.chunks > 0:
            try:
                # Coroutines have no .then(); await each step directly.
                session = await self.process_chunk(session, uploaded_range)
            except Exception as error:
                self.handle_error(error)
            self.chunks -= 1
        return session

    async def process_chunk(self, upload_session, uploaded_range):
        """Handle one server response and send the next chunk.

        Updates ``uploaded_range`` in place, advances ``self.next_range``,
        and returns the session state from uploading the next chunk (or the
        incoming session when nothing more is expected).
        """
        if upload_session is None:
            return upload_session
        next_range = upload_session.get_next_expected_ranges()
        # Re-attach the original URL: interim responses may omit it.
        old_url = self.get_validated_upload_url(self.upload_session)
        upload_session.set_upload_url(old_url)
        if self.on_chunk_upload_complete is not None:
            self.on_chunk_upload_complete(uploaded_range)
        if not next_range:
            # Nothing more expected — upload finished.
            return upload_session
        range_parts = str(next_range[0]).split("-")
        end = min(int(range_parts[0]) + self.max_chunk_size, self.file_size)
        # Mutate in place so the caller's reference sees the new range.
        uploaded_range[:] = [range_parts[0], end]
        self.set_next_range(next_range[0] + "-")
        # The original neither awaited this nor returned it, so the next
        # chunk was never actually sent.
        return await self.next_chunk(self.stream)

    def handle_error(self, error):
        """Propagate an upload error to the caller."""
        raise error

    def set_next_range(self, next_range: Optional[str]) -> None:
        """Record the next byte range the service expects."""
        self.next_range = next_range

    async def next_chunk(self, file: BytesIO, range_start: int = 0, range_end: int = 0):
        """PUT one chunk described by ``self.next_range`` (or the given bounds).

        Returns the deserialized ``LargeFileUploadSession`` from the service.

        Raises:
            ValueError: if the session's upload URL is empty.
        """
        upload_url = self.get_validated_upload_url(self.upload_session)

        if not upload_url:
            raise ValueError('The upload session URL must not be empty.')
        info = RequestInformation()
        info.set_uri(upload_url)
        # Use the Method enum imported at the top of the file; the original
        # referenced an undefined name ``HttpMethod``.
        info.http_method = Method.PUT
        if not self.next_range:
            self.set_next_range(f'{range_start}-{range_end}')
        range_parts = self.next_range.split('-') if self.next_range else ['-']
        start = int(range_parts[0])
        # Open-ended ranges like "100-" split to ['100', '']; int('') would
        # raise, so treat an empty end as 0 (meaning "until chunk size").
        end = int(range_parts[1]) if len(range_parts) > 1 and range_parts[1] else 0
        if start == 0 and end == 0:
            chunk_data = file.read(self.max_chunk_size)
            end = min(self.max_chunk_size - 1, self.file_size - 1)
        elif start == 0:
            chunk_data = file.read(end + 1)
        elif end == 0:
            file.seek(start)
            chunk_data = file.read(self.max_chunk_size)
            end = start + len(chunk_data) - 1
        else:
            file.seek(start)
            end = min(end, self.max_chunk_size + start)
            chunk_data = file.read(end - start + 1)

        info.set_headers(
            {
                **info.get_headers(), 'Content-Range': f'bytes {start}-{end}/{self.file_size}'
            }
        )
        info.set_headers({**info.get_headers(), 'Content-Length': str(len(chunk_data))})

        info.set_stream_content(BytesIO(chunk_data))
        # The original referenced ``self.adapter``; the attribute set in
        # __init__ is ``self.request_adapter``.
        return await self.request_adapter.send_async(
            info, LargeFileUploadSession.create_from_discriminator_value
        )

    def get_file(self) -> BytesIO:
        """Return the source stream being uploaded."""
        return self.stream

    async def cancel(self):
        """DELETE the upload session server-side and mark it cancelled.

        Returns the (mutated) upload session.
        """
        request_information = RequestInformation()
        # ``Method`` is the enum imported at the top of the file; the
        # original referenced an undefined name ``HttpMethod``.
        request_information.http_method = Method.DELETE

        upload_url = self.get_validated_upload_url(self.upload_session)

        request_information.set_uri(upload_url)
        await self.request_adapter.send_no_content_async(request_information)

        # Prefer a dedicated setter; fall back to additional data.
        if hasattr(self.upload_session, 'set_is_cancelled'):
            self.upload_session.set_is_cancelled(True)
        elif hasattr(self.upload_session, 'set_additional_data'
                     ) and hasattr(self.upload_session, 'get_additional_data'):
            current = self.upload_session.get_additional_data()
            new = {**current, 'is_cancelled': True}
            self.upload_session.set_additional_data(new)
        return self.upload_session

    def additional_data_contains(self, parsable: Parsable,
                                 property_candidates: List[str]) -> Tuple[bool, Any]:
        """Look up the first matching candidate key in additional data.

        Raises:
            ValueError: if ``parsable`` does not implement AdditionalDataHolder.
        """
        if not issubclass(type(parsable), AdditionalDataHolder):
            raise ValueError(
                f'The object passed does not contain property/properties {",".join(property_candidates)} and does not implement AdditionalDataHolder'
            )
        additional_data = parsable.get_additional_data()
        for property_candidate in property_candidates:
            if property_candidate in additional_data:
                return True, additional_data[property_candidate]
        return False, None

    def check_value_exists(
        self, parsable: Parsable, getter_name: str, property_names_in_additional_data: List[str]
    ) -> Tuple[bool, Any]:
        """Return (found, value) from additional data or a getter method.

        Guard the AdditionalDataHolder check first: the original called
        additional_data_contains unconditionally, which raised for any
        non-holder parsable even when the requested getter existed.
        """
        if issubclass(type(parsable), AdditionalDataHolder):
            found, value = self.additional_data_contains(
                parsable, property_names_in_additional_data
            )
            if found:
                return True, value

        if hasattr(parsable, getter_name):
            return True, getattr(parsable, getter_name)()

        return False, None

    async def resume(self):
        """Continue an interrupted upload from the next expected range.

        Raises:
            RuntimeError: if the session expired, has no valid
                nextExpectedRanges property, or expects no more bytes.
        """
        if self.upload_session_expired(self.upload_session):
            raise RuntimeError('The upload session is expired.')

        validated_value = self.check_value_exists(
            self.upload_session, 'get_next_expected_ranges',
            ['NextExpectedRanges', 'nextExpectedRanges']
        )
        if not validated_value[0]:
            raise RuntimeError(
                'The object passed does not contain a valid "nextExpectedRanges" property.'
            )

        next_ranges: List[str] = validated_value[1]
        if len(next_ranges) == 0:
            raise RuntimeError('No more bytes expected.')

        self.next_range = next_ranges[0]
        return await self.upload()

    def get_validated_upload_url(self, upload_session: Parsable) -> str:
        """Return the session's upload URL, validating it is present and non-blank.

        Raises:
            RuntimeError: if the session has no getter or the URL is empty.
        """
        if not hasattr(upload_session, 'get_upload_url'):
            raise RuntimeError('The upload session does not contain a valid upload url')
        result = upload_session.get_upload_url()

        if result is None or result.strip() == '':
            raise RuntimeError('The upload URL cannot be empty.')
        return result

    def get_next_range(self) -> Optional[str]:
        """Return the next byte range the service expects (duplicate def removed)."""
        return self.next_range
0 commit comments