diff --git a/common/auth/src/main/scala/org/apache/texera/auth/UploadTokenParser.scala b/common/auth/src/main/scala/org/apache/texera/auth/UploadTokenParser.scala new file mode 100644 index 00000000000..4897296da6e --- /dev/null +++ b/common/auth/src/main/scala/org/apache/texera/auth/UploadTokenParser.scala @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.texera.auth + +import org.apache.texera.auth.util.CryptoService +import org.apache.texera.config.AuthConfig + +import com.fasterxml.jackson.annotation.{JsonCreator, JsonIgnoreProperties, JsonProperty} +import com.fasterxml.jackson.databind.{DeserializationFeature, ObjectMapper} + +object UploadTokenParser { + + val Version: String = "v1" + + @JsonIgnoreProperties(ignoreUnknown = true) + final case class UploadTokenPayload @JsonCreator(mode = JsonCreator.Mode.PROPERTIES) ( + @JsonProperty(value = "version", required = true) + version: String, + @JsonProperty(value = "uploadId", required = true) + uploadId: String, + @JsonProperty(value = "did", required = true) + did: Int, + @JsonProperty(value = "uid", required = true) + uid: Int, + @JsonProperty(value = "filePath", required = true) + filePath: String, + @JsonProperty(value = "physicalAddress", required = true) + physicalAddress: String + ) + + private lazy val cryptoService: CryptoService = + CryptoService(AuthConfig.uploadTokenSecretKey) + + private lazy val objectMapper: ObjectMapper = + new ObjectMapper() + .configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false) + + def encode(payload: UploadTokenPayload): String = { + val node = objectMapper.createObjectNode() + node.put("version", Version) + node.put("uploadId", payload.uploadId) + node.put("did", payload.did) + node.put("uid", payload.uid) + node.put("filePath", payload.filePath) + node.put("physicalAddress", payload.physicalAddress) + + val rawJson = objectMapper.writeValueAsString(node) + cryptoService.encrypt(rawJson) + } + + def decode(token: String): UploadTokenPayload = { + val decryptedJson = cryptoService.decrypt(token) + val decodedPayload = objectMapper.readValue(decryptedJson, classOf[UploadTokenPayload]) + + decodedPayload + } +} diff --git a/common/auth/src/main/scala/org/apache/texera/auth/util/CryptoService.scala b/common/auth/src/main/scala/org/apache/texera/auth/util/CryptoService.scala new file mode 100644 index 00000000000..7f3f2ce3377 --- /dev/null +++ b/common/auth/src/main/scala/org/apache/texera/auth/util/CryptoService.scala @@ -0,0 +1,133 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.texera.auth.util + +import java.nio.charset.StandardCharsets +import java.security.{MessageDigest, SecureRandom} +import java.util.Base64 +import javax.crypto.{Cipher, SecretKey} +import javax.crypto.spec.{GCMParameterSpec, SecretKeySpec} + +/** + * Generic AES-GCM crypto utilities. + * + * Usage: + * val crypto = CryptoService("secret") + * val token = crypto.encrypt("hello") + * val plain = crypto.decrypt(token) + */ +final class CryptoService private (private val key: SecretKey) { + + def encrypt(plain: String): String = + CryptoService.encrypt(plain, key) + + def decrypt(token: String): String = + CryptoService.decrypt(token, key) +} + +object CryptoService { + private val Algorithm = "AES/GCM/NoPadding" + private val IvLength = 12 + private val TagLength = 128 + + private val random = new SecureRandom() + + /** Build an instance from a String secret. */ + def apply(secret: String): CryptoService = + new CryptoService(deriveKeyFromSecret(secret)) + + /** Derive a 256-bit AES key from a String. */ + def deriveKeyFromSecret(secret: String): SecretKey = { + val digest = MessageDigest.getInstance("SHA-256") + val keyBytes = digest.digest(secret.getBytes(StandardCharsets.UTF_8)) + new SecretKeySpec(keyBytes, "AES") + } + + /** Low-level encrypt with explicit key. + * + * Algorithm: AES-GCM (AEAD). + * - Provides confidentiality (encryption) and integrity/authenticity (GCM tag). + * - Output format (before Base64): [ IV || (ciphertext || tag) ] + * In JCE, `doFinal()` in GCM returns ciphertext with the authentication tag appended. + */ + def encrypt(plain: String, key: SecretKey): String = { + + // Allocate a fresh IV/nonce for this encryption. + // In GCM the IV must be unique per message under the same key; uniqueness prevents nonce-reuse attacks (keystream reuse and possible tag forgery). + val iv = new Array[Byte](IvLength) + + // Fill IV with cryptographically secure random bytes. + // Random IVs make identical plaintexts encrypt to different outputs and make collisions extremely unlikely. + random.nextBytes(iv) + + // Create a Cipher for the requested transformation (e.g., "AES/GCM/NoPadding"). + // GCM is the mode; "NoPadding" is standard for GCM in JCE. + val cipher = Cipher.getInstance(Algorithm) + + // Initialize cipher for ENCRYPT using: + // - `key`: the AES key material + // - `GCMParameterSpec(TagLength, iv)`: supplies the IV and the desired authentication tag length. + // TagLength is in bits (commonly 128 bits = 16 bytes). + // The tag is what later allows decryption to detect tampering. + cipher.init(Cipher.ENCRYPT_MODE, key, new GCMParameterSpec(TagLength, iv)) + + // Convert the plaintext string to bytes in a deterministic encoding (UTF-8). + // Crypto APIs operate on bytes; UTF-8 avoids platform-dependent encodings. + val plainBytes = plain.getBytes(StandardCharsets.UTF_8) + + // Encrypt and compute the authentication tag. 
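+        // GCM does not expand the data: the ciphertext is exactly as long as the plaintext,
+        // so the per-token storage overhead is just the 12-byte IV prefix added below plus the 16-byte (128-bit) tag.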
+ // For AES-GCM in JCE, `doFinal()` returns: ciphertext || tag (tag appended at the end). + // Any modification of the ciphertext/tag will be detected during decrypt via tag verification. + val cipherText = cipher.doFinal(plainBytes) + + // Build the final payload to return. + // We must include the IV with the output because decryption needs the same IV to recompute the keystream + // and verify the authentication tag. The IV is not secret, only required to be unique. + val combined = new Array[Byte](iv.length + cipherText.length) + + // Prefix the payload with the IV so the decrypt() routine can read it back. + System.arraycopy(iv, 0, combined, 0, iv.length) + + // Append ciphertext+tag after the IV. + System.arraycopy(cipherText, 0, combined, iv.length, cipherText.length) + + // Encode binary payload as URL-safe Base64 text (no '+', '/', or '=' padding). + // This makes it safe to store/transport in URLs, cookies, headers, etc. + Base64.getUrlEncoder.withoutPadding().encodeToString(combined) + } + + /** Low-level decrypt with explicit key. */ + def decrypt(token: String, key: SecretKey): String = { + val combined = Base64.getUrlDecoder.decode(token) + + if (combined.length <= IvLength) { + throw new IllegalArgumentException("Invalid encrypted token") + } + + val iv = java.util.Arrays.copyOfRange(combined, 0, IvLength) + val cipherText = java.util.Arrays.copyOfRange(combined, IvLength, combined.length) + + val cipher = Cipher.getInstance(Algorithm) + cipher.init(Cipher.DECRYPT_MODE, key, new GCMParameterSpec(TagLength, iv)) + + val plainBytes = cipher.doFinal(cipherText) + new String(plainBytes, StandardCharsets.UTF_8) + } +} diff --git a/common/config/src/main/resources/auth.conf b/common/config/src/main/resources/auth.conf index c99db10c85e..505d48ab93b 100644 --- a/common/config/src/main/resources/auth.conf +++ b/common/config/src/main/resources/auth.conf @@ -25,4 +25,8 @@ auth { 256-bit-secret = "8a1b2c3d4e5f6a7b8c9d0e1f2a3b4c5d" 256-bit-secret = ${?AUTH_JWT_SECRET} } + upload-token { + 256-bit-secret = "8fnegtjgby3ith7gjw3htr8rj3rhbeub" + 256-bit-secret = ${?AUTH_UPLOAD_TOKEN_SECRET} + } } \ No newline at end of file diff --git a/common/config/src/main/scala/org/apache/texera/config/AuthConfig.scala b/common/config/src/main/scala/org/apache/texera/config/AuthConfig.scala index e62863470cd..3675e996677 100644 --- a/common/config/src/main/scala/org/apache/texera/config/AuthConfig.scala +++ b/common/config/src/main/scala/org/apache/texera/config/AuthConfig.scala @@ -31,6 +31,7 @@ object AuthConfig { // For storing the generated/configured secret @volatile private var secretKey: String = _ + @volatile private var uploadTokenSecret: String = _ // Read JWT secret key with support for random generation def jwtSecretKey: String = { @@ -45,6 +46,18 @@ object AuthConfig { secretKey } + /** + * Secret used for encrypting upload tokens + * Config path: auth.upload-token.256-bit-secret + */ + def uploadTokenSecretKey: String = + synchronized { + if (uploadTokenSecret == null) { + uploadTokenSecret = conf.getString("auth.upload-token.256-bit-secret") + } + uploadTokenSecret + } + private def getRandomHexString: String = { val bytes = 32 val r = new Random() diff --git a/common/workflow-core/src/main/scala/org/apache/texera/amber/core/storage/util/LakeFSStorageClient.scala b/common/workflow-core/src/main/scala/org/apache/texera/amber/core/storage/util/LakeFSStorageClient.scala index 63c09f4c30b..790fa9181c1 100644 --- 
a/common/workflow-core/src/main/scala/org/apache/texera/amber/core/storage/util/LakeFSStorageClient.scala +++ b/common/workflow-core/src/main/scala/org/apache/texera/amber/core/storage/util/LakeFSStorageClient.scala @@ -358,4 +358,13 @@ object LakeFSStorageClient { branchesApi.resetBranch(repoName, branchName, resetCreation).execute() } + + def parsePhysicalAddress(address: String): (String, String) = { + // expected: "://bucket/key..." + val uri = new java.net.URI(address) + val bucket = uri.getHost + val key = uri.getPath.stripPrefix("/") + (bucket, key) + } + } diff --git a/common/workflow-core/src/main/scala/org/apache/texera/service/util/S3StorageClient.scala b/common/workflow-core/src/main/scala/org/apache/texera/service/util/S3StorageClient.scala index 94007e988e5..572cb3696ac 100644 --- a/common/workflow-core/src/main/scala/org/apache/texera/service/util/S3StorageClient.scala +++ b/common/workflow-core/src/main/scala/org/apache/texera/service/util/S3StorageClient.scala @@ -39,6 +39,12 @@ object S3StorageClient { val MINIMUM_NUM_OF_MULTIPART_S3_PART: Long = 5L * 1024 * 1024 // 5 MiB val MAXIMUM_NUM_OF_MULTIPART_S3_PARTS = 10_000 + /** Minimal info about an active multipart upload. */ + final case class MultipartUploadInfo(key: String, uploadId: String) + + /** Minimal info about a completed part in an upload. */ + final case class PartInfo(partNumber: Int, eTag: String) + // Initialize MinIO-compatible S3 Client private lazy val s3Client: S3Client = { val credentials = AwsBasicCredentials.create(StorageConfig.s3Username, StorageConfig.s3Password) @@ -259,4 +265,60 @@ object S3StorageClient { DeleteObjectRequest.builder().bucket(bucketName).key(objectKey).build() ) } + + def uploadPart( + bucket: String, + key: String, + uploadId: String, + partNumber: Int, + inputStream: InputStream, + contentLength: Option[Long] + ): Unit = { + val body: RequestBody = contentLength match { + case Some(len) => RequestBody.fromInputStream(inputStream, len) + case None => + val bytes = inputStream.readAllBytes() + RequestBody.fromBytes(bytes) + } + + val req = UploadPartRequest + .builder() + .bucket(bucket) + .key(key) + .uploadId(uploadId) + .partNumber(partNumber) + .build() + + s3Client.uploadPart(req, body) + } + + /** + * List *all* parts for a given multipart upload (bucket + key + uploadId). + * Handles pagination (up to 1000 parts per page). 
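+    * ETags come back from S3 wrapped in double quotes; they are stripped below so the
+    * (partNumber, eTag) pairs can be passed straight to the lakeFS completion call.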
+ */ + def listAllParts( + bucket: String, + key: String, + uploadId: String + ): Seq[PartInfo] = { + val acc = scala.collection.mutable.ArrayBuffer.empty[PartInfo] + var partNumberMarker: Integer = null + var truncated = true + + while (truncated) { + val builder = + ListPartsRequest.builder().bucket(bucket).key(key).uploadId(uploadId) + if (partNumberMarker != null) builder.partNumberMarker(partNumberMarker) + + val resp = s3Client.listParts(builder.build()) + resp.parts().asScala.foreach { p => + acc += PartInfo(p.partNumber(), Option(p.eTag()).map(_.replace("\"", "")).orNull) + } + + truncated = resp.isTruncated + partNumberMarker = resp.nextPartNumberMarker() + } + + acc.toSeq + } } diff --git a/file-service/src/main/scala/org/apache/texera/service/resource/DatasetResource.scala b/file-service/src/main/scala/org/apache/texera/service/resource/DatasetResource.scala index 2a67440cf0e..8287a8b8cbd 100644 --- a/file-service/src/main/scala/org/apache/texera/service/resource/DatasetResource.scala +++ b/file-service/src/main/scala/org/apache/texera/service/resource/DatasetResource.scala @@ -27,7 +27,8 @@ import org.apache.texera.amber.config.StorageConfig import org.apache.texera.amber.core.storage.model.OnDataset import org.apache.texera.amber.core.storage.util.LakeFSStorageClient import org.apache.texera.amber.core.storage.{DocumentFactory, FileResolver} -import org.apache.texera.auth.SessionUser +import org.apache.texera.auth.UploadTokenParser.UploadTokenPayload +import org.apache.texera.auth.{SessionUser, UploadTokenParser} import org.apache.texera.dao.SqlServer import org.apache.texera.dao.SqlServer.withTransaction import org.apache.texera.dao.jooq.generated.enums.PrivilegeEnum @@ -640,138 +641,55 @@ class DatasetResource { @QueryParam("ownerEmail") ownerEmail: String, @QueryParam("datasetName") datasetName: String, @QueryParam("filePath") encodedUrl: String, - @QueryParam("uploadId") uploadId: Optional[String], @QueryParam("numParts") numParts: Optional[Integer], - payload: Map[ - String, - Any - ], // Expecting {"parts": [...], "physicalAddress": "s3://bucket/path"} + payload: Map[String, Any], @Auth user: SessionUser ): Response = { val uid = user.getUid + operationType.toLowerCase match { + case "init" => initMultipartUpload(ownerEmail, datasetName, encodedUrl, numParts, uid) + case "finish" => finishMultipartUpload(payload, uid) + case "abort" => abortMultipartUpload(payload, uid) + case _ => + throw new BadRequestException("Invalid type parameter. 
Use 'init', 'finish', or 'abort'.") + } + } - withTransaction(context) { ctx => - val dataset = context - .select(DATASET.fields: _*) - .from(DATASET) - .leftJoin(USER) - .on(USER.UID.eq(DATASET.OWNER_UID)) - .where(USER.EMAIL.eq(ownerEmail)) - .and(DATASET.NAME.eq(datasetName)) - .fetchOneInto(classOf[Dataset]) - if (dataset == null || !userHasWriteAccess(ctx, dataset.getDid, uid)) { - throw new ForbiddenException(ERR_USER_HAS_NO_ACCESS_TO_DATASET_MESSAGE) - } - - // Decode the file path - val repositoryName = dataset.getRepositoryName - val filePath = URLDecoder.decode(encodedUrl, StandardCharsets.UTF_8.name()) - - operationType.toLowerCase match { - case "init" => - val numPartsValue = numParts.toScala.getOrElse( - throw new BadRequestException("numParts is required for initialization") - ) - - val presignedResponse = LakeFSStorageClient.initiatePresignedMultipartUploads( - repositoryName, - filePath, - numPartsValue - ) - Response - .ok( - Map( - "uploadId" -> presignedResponse.getUploadId, - "presignedUrls" -> presignedResponse.getPresignedUrls, - "physicalAddress" -> presignedResponse.getPhysicalAddress - ) - ) - .build() - - case "finish" => - val uploadIdValue = uploadId.toScala.getOrElse( - throw new BadRequestException("uploadId is required for completion") - ) - - // Extract parts from the payload - val partsList = payload.get("parts") match { - case Some(rawList: List[_]) => - try { - rawList.map { - case part: Map[_, _] => - val partMap = part.asInstanceOf[Map[String, Any]] - val partNumber = partMap.get("PartNumber") match { - case Some(i: Int) => i - case Some(s: String) => s.toInt - case _ => throw new BadRequestException("Invalid or missing PartNumber") - } - val eTag = partMap.get("ETag") match { - case Some(s: String) => s - case _ => throw new BadRequestException("Invalid or missing ETag") - } - (partNumber, eTag) - - case _ => - throw new BadRequestException("Each part must be a Map[String, Any]") - } - } catch { - case e: NumberFormatException => - throw new BadRequestException("PartNumber must be an integer", e) - } - - case _ => - throw new BadRequestException("Missing or invalid 'parts' list in payload") - } - - // Extract physical address from payload - val physicalAddress = payload.get("physicalAddress") match { - case Some(address: String) => address - case _ => throw new BadRequestException("Missing physicalAddress in payload") - } - - // Complete the multipart upload with parts and physical address - val objectStats = LakeFSStorageClient.completePresignedMultipartUploads( - repositoryName, - filePath, - uploadIdValue, - partsList, - physicalAddress - ) + @POST + @RolesAllowed(Array("REGULAR", "ADMIN")) + @Path("/multipart-upload/part") + @Consumes(Array(MediaType.APPLICATION_OCTET_STREAM)) + def uploadPart( + @QueryParam("uploadToken") uploadToken: String, + @QueryParam("partNumber") partNumber: Int, + partStream: InputStream, + @Context headers: HttpHeaders, + @Auth user: SessionUser + ): Response = { - Response - .ok( - Map( - "message" -> "Multipart upload completed successfully", - "filePath" -> objectStats.getPath - ) - ) - .build() + if (uploadToken == null || uploadToken.isEmpty) + throw new BadRequestException("token is required") - case "abort" => - val uploadIdValue = uploadId.toScala.getOrElse( - throw new BadRequestException("uploadId is required for abortion") - ) + if (partNumber < 1) + throw new BadRequestException("partNumber must be >= 1") - // Extract physical address from payload - val physicalAddress = payload.get("physicalAddress") match { 
- case Some(address: String) => address - case _ => throw new BadRequestException("Missing physicalAddress in payload") - } + val decoded = UploadTokenParser.decode(uploadToken) + val (dataset, bucket, key, uploadId, physicalAddress) = + resolveMultipartUploadContextFromToken(decoded, user.getUid) - // Abort the multipart upload - LakeFSStorageClient.abortPresignedMultipartUploads( - repositoryName, - filePath, - uploadIdValue, - physicalAddress - ) + val contentLenHeader = headers.getHeaderString(HttpHeaders.CONTENT_LENGTH) + val contentLength = Option(contentLenHeader).map(_.toLong) - Response.ok(Map("message" -> "Multipart upload aborted successfully")).build() + S3StorageClient.uploadPart( + bucket = bucket, + key = key, + uploadId = uploadId, + partNumber = partNumber, + inputStream = partStream, + contentLength = contentLength + ) - case _ => - throw new BadRequestException("Invalid type parameter. Use 'init', 'finish', or 'abort'.") - } - } + Response.ok().build() } @POST @@ -1372,4 +1290,169 @@ class DatasetResource { Right(response) } } + + // === Multipart helpers (stateless, token-based) === + /** + * Given a decoded token and current authenticated user, rediscover: + * (dataset, bucket, key, uploadId, physicalAddress) + * using only the data encrypted into the token. + */ + private def resolveMultipartUploadContextFromToken( + token: UploadTokenPayload, + currentUid: Int + ): (Dataset, String, String, String, String) = { + if (token.uid != currentUid) { + throw new ForbiddenException("User has no access to this upload") + } + + // 1) Check dataset and permissions + val dataset = withTransaction(context) { ctx => + val ds = getDatasetByID(ctx, token.did) + if (!userHasWriteAccess(ctx, token.did, currentUid)) { + throw new ForbiddenException(ERR_USER_HAS_NO_ACCESS_TO_DATASET_MESSAGE) + } + ds + } + + // 2) parse physical address into bucket + key + val (bucket, key) = LakeFSStorageClient.parsePhysicalAddress(token.physicalAddress) + + // dataset, bucket, key, uploadId, physicalAddress + (dataset, bucket, key, token.uploadId, token.physicalAddress) + } + + /** + * Initialize a multipart upload for a given dataset + logical file path. + * + * Keeps the HTTP API the same but: + * - ignores numParts + * - does not use any presigned URLs from lakeFS + * - returns a stateless, encrypted uploadToken + */ + private def initMultipartUpload( + ownerEmail: String, + datasetName: String, + encodedUrl: String, + numParts: Optional[Integer], + uid: Int + ): Response = { + withTransaction(context) { ctx => + val dataset = ctx + .select(DATASET.fields: _*) + .from(DATASET) + .leftJoin(USER) + .on(USER.UID.eq(DATASET.OWNER_UID)) + .where(USER.EMAIL.eq(ownerEmail)) + .and(DATASET.NAME.eq(datasetName)) + .fetchOneInto(classOf[Dataset]) + + if (dataset == null || !userHasWriteAccess(ctx, dataset.getDid, uid)) { + throw new ForbiddenException(ERR_USER_HAS_NO_ACCESS_TO_DATASET_MESSAGE) + } + + val repositoryName = dataset.getRepositoryName + val filePath = URLDecoder.decode(encodedUrl, StandardCharsets.UTF_8.name()) + + // We do NOT care about numParts or initial presigned URLs. + // We only need uploadId + physicalAddress. 
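+      // (The literal numParts of 1 below only satisfies the lakeFS initiation API;
+      // any presigned URLs it returns are discarded and never reach the client.)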
+ val presign = + LakeFSStorageClient.initiatePresignedMultipartUploads(repositoryName, filePath, 1) + + val uploadIdStr = presign.getUploadId + val physicalAddress = presign.getPhysicalAddress + + val payload = UploadTokenParser.UploadTokenPayload( + version = UploadTokenParser.Version, + uploadId = uploadIdStr, + did = dataset.getDid.intValue(), + uid = uid, + filePath = filePath, + physicalAddress = physicalAddress + ) + + val token = UploadTokenParser.encode(payload) + + Response.ok(Map("uploadToken" -> token)).build() + } + } + + /** + * Complete a multipart upload: + * - token -> dataset + bucket/key/uploadId/physicalAddress + * - list parts from S3 (ListParts) + * - call lakeFS completePresignMultipartUpload with physicalAddress + */ + private def finishMultipartUpload( + payload: Map[String, Any], + uid: Int + ): Response = { + val tokenValueStr = payload + .get("uploadToken") + .map(_.asInstanceOf[String]) + .getOrElse { + throw new BadRequestException("uploadToken is required for completion") + } + + val decoded = UploadTokenParser.decode(tokenValueStr) + val (dataset, bucket, key, uploadId, physicalAddress) = + resolveMultipartUploadContextFromToken(decoded, uid) + + val partInfos = + S3StorageClient.listAllParts(bucket, key, uploadId) + + if (partInfos.isEmpty) { + throw new BadRequestException("No uploaded parts found for this upload") + } + + val partsList: List[(Int, String)] = + partInfos.map(pi => (pi.partNumber, pi.eTag)).toList + + val objectStats = LakeFSStorageClient.completePresignedMultipartUploads( + dataset.getRepositoryName, + decoded.filePath, + uploadId, + partsList, + physicalAddress + ) + + Response + .ok( + Map( + "message" -> "Multipart upload completed successfully", + "filePath" -> objectStats.getPath + ) + ) + .build() + } + + /** + * Abort a multipart upload: + * - token -> dataset + bucket/key/uploadId/physicalAddress + * - abort multipart in S3 + * - abort in lakeFS + */ + private def abortMultipartUpload( + payload: Map[String, Any], + uid: Int + ): Response = { + val tokenValueStr = payload + .get("uploadToken") + .map(_.asInstanceOf[String]) + .getOrElse { + throw new BadRequestException("uploadToken is required for abortion") + } + + val decoded = UploadTokenParser.decode(tokenValueStr) + val (dataset, _, _, uploadId, physicalAddress) = + resolveMultipartUploadContextFromToken(decoded, uid) + + LakeFSStorageClient.abortPresignedMultipartUploads( + dataset.getRepositoryName, + decoded.filePath, + uploadId, + physicalAddress + ) + + Response.ok(Map("message" -> "Multipart upload aborted successfully")).build() + } } diff --git a/frontend/src/app/dashboard/component/user/user-dataset/user-dataset-explorer/dataset-detail.component.ts b/frontend/src/app/dashboard/component/user/user-dataset/user-dataset-explorer/dataset-detail.component.ts index b4d12f5a28e..fff40cdf414 100644 --- a/frontend/src/app/dashboard/component/user/user-dataset/user-dataset-explorer/dataset-detail.component.ts +++ b/frontend/src/app/dashboard/component/user/user-dataset/user-dataset-explorer/dataset-detail.component.ts @@ -104,8 +104,8 @@ export class DatasetDetailComponent implements OnInit { // List of upload tasks – each task tracked by its filePath public uploadTasks: Array< MultipartUploadProgress & { - filePath: string; - } + filePath: string; + } > = []; @Output() userMakeChanges = new EventEmitter(); @@ -416,8 +416,7 @@ export class DatasetDetailComponent implements OnInit { filePath: file.name, percentage: 0, status: "initializing", - uploadId: "", - physicalAddress: 
"", + uploadToken: "", }); // Start multipart upload const subscription = this.datasetService @@ -558,21 +557,24 @@ export class DatasetDetailComponent implements OnInit { this.onUploadComplete(); } + if (!task.uploadToken) { + this.uploadTasks = this.uploadTasks.filter(t => t.filePath !== task.filePath); + return; + } + this.datasetService .finalizeMultipartUpload( this.ownerEmail, this.datasetName, task.filePath, - task.uploadId, - [], - task.physicalAddress, + task.uploadToken, true // abort flag ) .pipe(untilDestroyed(this)) .subscribe(() => { this.notificationService.info(`${task.filePath} uploading has been terminated`); }); - // Remove the aborted task immediately + this.uploadTasks = this.uploadTasks.filter(t => t.filePath !== task.filePath); } diff --git a/frontend/src/app/dashboard/service/user/dataset/dataset.service.ts b/frontend/src/app/dashboard/service/user/dataset/dataset.service.ts index c09125d73b1..41742dc14ca 100644 --- a/frontend/src/app/dashboard/service/user/dataset/dataset.service.ts +++ b/frontend/src/app/dashboard/service/user/dataset/dataset.service.ts @@ -27,6 +27,7 @@ import { DashboardDataset } from "../../../type/dashboard-dataset.interface"; import { DatasetFileNode } from "../../../../common/type/datasetVersionFileTree"; import { DatasetStagedObject } from "../../../../common/type/dataset-staged-object"; import { GuiConfigService } from "../../../../common/service/gui-config.service"; +import { AuthService } from "src/app/common/service/user/auth.service"; export const DATASET_BASE_URL = "dataset"; export const DATASET_CREATE_URL = DATASET_BASE_URL + "/create"; @@ -51,11 +52,10 @@ export interface MultipartUploadProgress { filePath: string; percentage: number; status: "initializing" | "uploading" | "finished" | "aborted"; - uploadId: string; - physicalAddress: string; - uploadSpeed?: number; // bytes per second - estimatedTimeRemaining?: number; // seconds - totalTime?: number; // total seconds taken + uploadToken: string; + uploadSpeed?: number; // bytes per second + estimatedTimeRemaining?: number; // seconds + totalTime?: number; // total seconds taken } @Injectable({ @@ -122,6 +122,7 @@ export class DatasetService { public retrieveAccessibleDatasets(): Observable { return this.http.get(`${AppSettings.getApiEndpoint()}/${DATASET_LIST_URL}`); } + public createDatasetVersion(did: number, newVersion: string): Observable { return this.http .post<{ @@ -141,6 +142,13 @@ export class DatasetService { /** * Handles multipart upload for large files using RxJS, * with a concurrency limit on how many parts we process in parallel. 
+ * + * Backend flow: + * POST /dataset/multipart-upload?type=init&ownerEmail=...&datasetName=...&filePath=...&numParts=N + * -> { uploadToken } + * POST /dataset/multipart-upload/part?uploadToken=&partNumber= (body: raw chunk) + * POST /dataset/multipart-upload?type=finish (body: { uploadToken }) + * POST /dataset/multipart-upload?type=abort (body: { uploadToken }) */ public multipartUpload( ownerEmail: string, @@ -152,8 +160,8 @@ export class DatasetService { ): Observable { const partCount = Math.ceil(file.size / partSize); - return new Observable(observer => { - // Track upload progress for each part independently + return new Observable(observer => { + // Track upload progress (bytes) for each part independently const partProgress = new Map(); // Progress tracking state @@ -162,8 +170,15 @@ export class DatasetService { let lastETA = 0; let lastUpdateTime = 0; - // Calculate stats with smoothing + const lastStats = { + uploadSpeed: 0, + estimatedTimeRemaining: 0, + totalTime: 0, + }; + const getTotalTime = () => (startTime ? (Date.now() - startTime) / 1000 : 0); + + // Calculate stats with smoothing and simple throttling (~1s) const calculateStats = (totalUploaded: number) => { if (startTime === null) { startTime = Date.now(); @@ -172,25 +187,28 @@ export class DatasetService { const now = Date.now(); const elapsed = getTotalTime(); - // Throttle updates to every 1s const shouldUpdate = now - lastUpdateTime >= 1000; if (!shouldUpdate) { - return null; + // keep totalTime fresh even when throttled + lastStats.totalTime = elapsed; + return lastStats; } lastUpdateTime = now; - // Calculate speed with moving average const currentSpeed = elapsed > 0 ? totalUploaded / elapsed : 0; speedSamples.push(currentSpeed); - if (speedSamples.length > 5) speedSamples.shift(); - const avgSpeed = speedSamples.reduce((a, b) => a + b, 0) / speedSamples.length; + if (speedSamples.length > 5) { + speedSamples.shift(); + } + const avgSpeed = + speedSamples.length > 0 + ? speedSamples.reduce((a, b) => a + b, 0) / speedSamples.length + : 0; - // Calculate smooth ETA const remaining = file.size - totalUploaded; let eta = avgSpeed > 0 ? remaining / avgSpeed : 0; - eta = Math.min(eta, 24 * 60 * 60); // cap ETA at 24h, 86400 sec + eta = Math.min(eta, 24 * 60 * 60); // cap ETA at 24h - // Smooth ETA changes (limit to 30% change) if (lastETA > 0 && eta > 0) { const maxChange = lastETA * 0.3; const diff = Math.abs(eta - lastETA); @@ -200,229 +218,239 @@ export class DatasetService { } lastETA = eta; - // Near completion optimization const percentComplete = (totalUploaded / file.size) * 100; if (percentComplete > 95) { eta = Math.min(eta, 10); } - return { - uploadSpeed: avgSpeed, - estimatedTimeRemaining: Math.max(0, Math.round(eta)), - totalTime: elapsed, - }; + lastStats.uploadSpeed = avgSpeed; + lastStats.estimatedTimeRemaining = Math.max(0, Math.round(eta)); + lastStats.totalTime = elapsed; + + return lastStats; }; - const subscription = this.initiateMultipartUpload(ownerEmail, datasetName, filePath, partCount) + // 1. 
INIT: ask backend to create a LakeFS multipart upload session and get uploadToken + const initParams = new HttpParams() + .set("type", "init") + .set("ownerEmail", ownerEmail) + .set("datasetName", datasetName) + .set("filePath", encodeURIComponent(filePath)) + .set("numParts", partCount.toString()); + + const init$ = this.http.post<{ uploadToken: string }>( + `${AppSettings.getApiEndpoint()}/${DATASET_BASE_URL}/multipart-upload`, + {}, + { params: initParams } + ); + + const subscription = init$ .pipe( - switchMap(initiateResponse => { - const { uploadId, presignedUrls, physicalAddress } = initiateResponse; - if (!uploadId) { + switchMap(initResp => { + const uploadToken = initResp.uploadToken; + if (!uploadToken) { observer.error(new Error("Failed to initiate multipart upload")); return EMPTY; } + + // Notify UI that upload is starting observer.next({ - filePath: filePath, + filePath, percentage: 0, status: "initializing", - uploadId: uploadId, - physicalAddress: physicalAddress, + uploadToken, uploadSpeed: 0, estimatedTimeRemaining: 0, totalTime: 0, }); - // Keep track of all uploaded parts - const uploadedParts: { PartNumber: number; ETag: string }[] = []; - - // 1) Convert presignedUrls into a stream of URLs - return from(presignedUrls).pipe( - // 2) Use mergeMap with concurrency limit to upload chunk by chunk - mergeMap((url, index) => { - const partNumber = index + 1; - const start = index * partSize; - const end = Math.min(start + partSize, file.size); - const chunk = file.slice(start, end); - - // Upload the chunk - return new Observable(partObserver => { - const xhr = new XMLHttpRequest(); - - xhr.upload.addEventListener("progress", event => { - if (event.lengthComputable) { - // Update this specific part's progress - partProgress.set(partNumber, event.loaded); - - // Calculate total progress across all parts - let totalUploaded = 0; - partProgress.forEach(bytes => (totalUploaded += bytes)); - const percentage = Math.round((totalUploaded / file.size) * 100); - const stats = calculateStats(totalUploaded); - - observer.next({ - filePath, - percentage: Math.min(percentage, 99), // Cap at 99% until finalized - status: "uploading", - uploadId, - physicalAddress, - ...stats, - }); - } - }); - - xhr.addEventListener("load", () => { - if (xhr.status === 200 || xhr.status === 201) { - const etag = xhr.getResponseHeader("ETag")?.replace(/"/g, ""); - if (!etag) { - partObserver.error(new Error(`Missing ETag for part ${partNumber}`)); - return; + // 2. 
Upload each part to /multipart-upload/part using XMLHttpRequest + return from(Array.from({ length: partCount }, (_, i) => i)).pipe( + mergeMap( + index => { + const partNumber = index + 1; + const start = index * partSize; + const end = Math.min(start + partSize, file.size); + const chunk = file.slice(start, end); + + return new Observable(partObserver => { + const xhr = new XMLHttpRequest(); + + xhr.upload.addEventListener("progress", event => { + if (event.lengthComputable) { + partProgress.set(partNumber, event.loaded); + + let totalUploaded = 0; + partProgress.forEach(bytes => { + totalUploaded += bytes; + }); + + const percentage = Math.round((totalUploaded / file.size) * 100); + const stats = calculateStats(totalUploaded); + + observer.next({ + filePath, + percentage: Math.min(percentage, 99), + status: "uploading", + uploadToken, + ...stats, + }); + } + }); + + xhr.addEventListener("load", () => { + if (xhr.status === 200 || xhr.status === 204) { + // Mark part as fully uploaded + partProgress.set(partNumber, chunk.size); + + let totalUploaded = 0; + partProgress.forEach(bytes => { + totalUploaded += bytes; + }); + + // Force stats recompute on completion + lastUpdateTime = 0; + const percentage = Math.round((totalUploaded / file.size) * 100); + const stats = calculateStats(totalUploaded); + + observer.next({ + filePath, + percentage: Math.min(percentage, 99), + status: "uploading", + uploadToken, + ...stats, + }); + + partObserver.complete(); + } else { + partObserver.error( + new Error(`Failed to upload part ${partNumber} (HTTP ${xhr.status})`) + ); } + }); - // Mark this part as fully uploaded - partProgress.set(partNumber, chunk.size); - uploadedParts.push({ PartNumber: partNumber, ETag: etag }); - - // Recalculate progress - let totalUploaded = 0; - partProgress.forEach(bytes => (totalUploaded += bytes)); - const percentage = Math.round((totalUploaded / file.size) * 100); - lastUpdateTime = 0; - const stats = calculateStats(totalUploaded); - - observer.next({ - filePath, - percentage: Math.min(percentage, 99), - status: "uploading", - uploadId, - physicalAddress, - ...stats, - }); - partObserver.complete(); - } else { + xhr.addEventListener("error", () => { + // Remove failed part from progress + partProgress.delete(partNumber); partObserver.error(new Error(`Failed to upload part ${partNumber}`)); - } - }); + }); - xhr.addEventListener("error", () => { - // Remove failed part from progress - partProgress.delete(partNumber); - partObserver.error(new Error(`Failed to upload part ${partNumber}`)); - }); + const partUrl = + `${AppSettings.getApiEndpoint()}/${DATASET_BASE_URL}/multipart-upload/part` + + `?uploadToken=${encodeURIComponent(uploadToken)}&partNumber=${partNumber}`; - xhr.open("PUT", url); - xhr.send(chunk); - }); - }, concurrencyLimit), - - // 3) Collect results from all uploads (like forkJoin, but respects concurrency) - toArray(), - // 4) Finalize if all parts succeeded - switchMap(() => - this.finalizeMultipartUpload( - ownerEmail, - datasetName, - filePath, - uploadId, - uploadedParts, - physicalAddress, - false - ) + xhr.open("POST", partUrl); + xhr.setRequestHeader("Content-Type", "application/octet-stream"); + const token = AuthService.getAccessToken(); + if (token) { + xhr.setRequestHeader("Authorization", `Bearer ${token}`); + } + xhr.send(chunk); + return () => { + try { + xhr.abort(); + } catch {} + }; + }); + }, + concurrencyLimit ), + toArray(), // wait for all parts + // 3. 
FINISH: notify backend that all parts are done + switchMap(() => { + const finishParams = new HttpParams() + .set("type", "finish") + .set("ownerEmail", ownerEmail) + .set("datasetName", datasetName) + .set("filePath", encodeURIComponent(filePath)); + + const body = { uploadToken }; + + return this.http.post( + `${AppSettings.getApiEndpoint()}/${DATASET_BASE_URL}/multipart-upload`, + body, + { params: finishParams } + ); + }), tap(() => { + const totalTime = getTotalTime(); observer.next({ filePath, percentage: 100, status: "finished", - uploadId: uploadId, - physicalAddress: physicalAddress, + uploadToken, uploadSpeed: 0, estimatedTimeRemaining: 0, - totalTime: getTotalTime(), + totalTime, }); observer.complete(); }), - catchError((error: unknown) => { - // If an error occurred, abort the upload + catchError(error => { + // On error, compute best-effort percentage from bytes we've seen + let totalUploaded = 0; + partProgress.forEach(bytes => { + totalUploaded += bytes; + }); + const percentage = + file.size > 0 ? Math.round((totalUploaded / file.size) * 100) : 0; + observer.next({ filePath, - percentage: Math.round((uploadedParts.length / partCount) * 100), + percentage, status: "aborted", - uploadId: uploadId, - physicalAddress: physicalAddress, + uploadToken, uploadSpeed: 0, estimatedTimeRemaining: 0, totalTime: getTotalTime(), }); - return this.finalizeMultipartUpload( - ownerEmail, - datasetName, - filePath, - uploadId, - uploadedParts, - physicalAddress, - true - ).pipe(switchMap(() => throwError(() => error))); + // Abort on backend + const abortParams = new HttpParams() + .set("type", "abort") + .set("ownerEmail", ownerEmail) + .set("datasetName", datasetName) + .set("filePath", encodeURIComponent(filePath)); + + const body = { uploadToken }; + + return this.http + .post( + `${AppSettings.getApiEndpoint()}/${DATASET_BASE_URL}/multipart-upload`, + body, + { params: abortParams } + ) + .pipe( + switchMap(() => throwError(() => error)), + catchError(() => throwError(() => error)) + ); }) ); }) ) .subscribe({ - error: (err: unknown) => observer.error(err), + error: err => observer.error(err), }); + return () => subscription.unsubscribe(); }); } - /** - * Initiates a multipart upload and retrieves presigned URLs for each part. - * @param ownerEmail Owner's email - * @param datasetName Dataset Name - * @param filePath File path within the dataset - * @param numParts Number of parts for the multipart upload - */ - private initiateMultipartUpload( - ownerEmail: string, - datasetName: string, - filePath: string, - numParts: number - ): Observable<{ uploadId: string; presignedUrls: string[]; physicalAddress: string }> { - const params = new HttpParams() - .set("type", "init") - .set("ownerEmail", ownerEmail) - .set("datasetName", datasetName) - .set("filePath", encodeURIComponent(filePath)) - .set("numParts", numParts.toString()); - - return this.http.post<{ uploadId: string; presignedUrls: string[]; physicalAddress: string }>( - `${AppSettings.getApiEndpoint()}/${DATASET_BASE_URL}/multipart-upload`, - {}, - { params } - ); - } - - /** - * Completes or aborts a multipart upload, sending part numbers and ETags to the backend. - */ public finalizeMultipartUpload( ownerEmail: string, datasetName: string, filePath: string, - uploadId: string, - parts: { PartNumber: number; ETag: string }[], - physicalAddress: string, + uploadToken: string, isAbort: boolean ): Observable { const params = new HttpParams() .set("type", isAbort ? 
"abort" : "finish") .set("ownerEmail", ownerEmail) .set("datasetName", datasetName) - .set("filePath", encodeURIComponent(filePath)) - .set("uploadId", uploadId); + .set("filePath", encodeURIComponent(filePath)); return this.http.post( `${AppSettings.getApiEndpoint()}/${DATASET_BASE_URL}/multipart-upload`, - { parts, physicalAddress }, + { uploadToken }, { params } ); }