-
Notifications
You must be signed in to change notification settings - Fork 111
refactor(dataset): Redirect multipart upload through File Service #4130
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
6d9dcdd
6d9f4f1
4d889d9
2408cfa
3c903a0
dd1c35c
24da656
0ca4363
489dcd4
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,73 @@ | ||
| /* | ||
| * Licensed to the Apache Software Foundation (ASF) under one | ||
| * or more contributor license agreements. See the NOTICE file | ||
| * distributed with this work for additional information | ||
| * regarding copyright ownership. The ASF licenses this file | ||
| * to you under the Apache License, Version 2.0 (the | ||
| * "License"); you may not use this file except in compliance | ||
| * with the License. You may obtain a copy of the License at | ||
| * | ||
| * http://www.apache.org/licenses/LICENSE-2.0 | ||
| * | ||
| * Unless required by applicable law or agreed to in writing, | ||
| * software distributed under the License is distributed on an | ||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
| * KIND, either express or implied. See the License for the | ||
| * specific language governing permissions and limitations | ||
| * under the License. | ||
| */ | ||
| package org.apache.texera.auth | ||
|
|
||
| import org.apache.texera.auth.util.CryptoService | ||
| import org.apache.texera.config.AuthConfig | ||
|
|
||
| import com.fasterxml.jackson.annotation.{JsonCreator, JsonIgnoreProperties, JsonProperty} | ||
| import com.fasterxml.jackson.databind.{DeserializationFeature, ObjectMapper} | ||
|
|
||
/**
  * Encodes and decodes upload tokens.
  *
  * A token is the JSON form of an [[UploadTokenParser.UploadTokenPayload]], encrypted with a
  * `CryptoService` keyed by `AuthConfig.uploadTokenSecretKey`.
  */
object UploadTokenParser {

  /** Token format version stamped into every token produced by [[encode]]. */
  val Version: String = "v1"

  /**
    * The data carried inside an upload token.
    *
    * All six properties are marked `required`, so a JSON document missing any of them is
    * rejected by Jackson during [[decode]]; unknown extra properties are ignored.
    *
    * NOTE(review): `did` / `uid` are presumably the dataset id and user id, and
    * `physicalAddress` the storage location of the file — confirm against the callers.
    */
  @JsonIgnoreProperties(ignoreUnknown = true)
  final case class UploadTokenPayload @JsonCreator(mode = JsonCreator.Mode.PROPERTIES) (
      @JsonProperty(value = "version", required = true)
      version: String,
      @JsonProperty(value = "uploadId", required = true)
      uploadId: String,
      @JsonProperty(value = "did", required = true)
      did: Int,
      @JsonProperty(value = "uid", required = true)
      uid: Int,
      @JsonProperty(value = "filePath", required = true)
      filePath: String,
      @JsonProperty(value = "physicalAddress", required = true)
      physicalAddress: String
  )

  // Lazy so the secret key is only read from the configuration on first use.
  private lazy val cryptoService: CryptoService =
    CryptoService(AuthConfig.uploadTokenSecretKey)

  private lazy val objectMapper: ObjectMapper =
    new ObjectMapper()
      .configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false)

  /**
    * Serializes `payload` to JSON and encrypts it.
    *
    * The `version` property of the produced token is always [[Version]]; the value of
    * `payload.version` is ignored.
    *
    * @param payload the data to embed in the token
    * @return the encrypted token as returned by `CryptoService.encrypt`
    */
  def encode(payload: UploadTokenPayload): String = {
    val node = objectMapper.createObjectNode()
    node.put("version", Version)
    node.put("uploadId", payload.uploadId)
    node.put("did", payload.did)
    node.put("uid", payload.uid)
    node.put("filePath", payload.filePath)
    node.put("physicalAddress", payload.physicalAddress)

    cryptoService.encrypt(objectMapper.writeValueAsString(node))
  }

  /**
    * Decrypts `token` and parses the contained JSON back into a payload.
    *
    * @param token a token previously produced by [[encode]]
    * @return the decoded payload
    * @throws IllegalArgumentException if the token's `version` is not [[Version]]
    *
    * Failures raised by `CryptoService.decrypt` or by Jackson while parsing are propagated
    * unchanged.
    */
  def decode(token: String): UploadTokenPayload = {
    val decryptedJson = cryptoService.decrypt(token)
    val payload = objectMapper.readValue(decryptedJson, classOf[UploadTokenPayload])

    // encode() always writes Version, so any other value (including an explicit JSON null)
    // means the token uses a format this parser does not understand.
    if (payload.version != Version) {
      throw new IllegalArgumentException(s"Unsupported upload token version: ${payload.version}")
    }

    payload
  }
}
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,133 @@ | ||
| /* | ||
| * Licensed to the Apache Software Foundation (ASF) under one | ||
| * or more contributor license agreements. See the NOTICE file | ||
| * distributed with this work for additional information | ||
| * regarding copyright ownership. The ASF licenses this file | ||
| * to you under the Apache License, Version 2.0 (the | ||
| * "License"); you may not use this file except in compliance | ||
| * with the License. You may obtain a copy of the License at | ||
| * | ||
| * http://www.apache.org/licenses/LICENSE-2.0 | ||
| * | ||
| * Unless required by applicable law or agreed to in writing, | ||
| * software distributed under the License is distributed on an | ||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
| * KIND, either express or implied. See the License for the | ||
| * specific language governing permissions and limitations | ||
| * under the License. | ||
| */ | ||
|
|
||
| package org.apache.texera.auth.util | ||
|
|
||
| import java.nio.charset.StandardCharsets | ||
| import java.security.{MessageDigest, SecureRandom} | ||
| import java.util.Base64 | ||
| import javax.crypto.{Cipher, SecretKey} | ||
| import javax.crypto.spec.{GCMParameterSpec, SecretKeySpec} | ||
|
|
||
/**
  * Generic AES-GCM crypto utilities.
  *
  * An instance holds one AES key (derived from a String secret) and delegates to the
  * key-explicit functions on the companion object.
  *
  * Usage:
  *   val crypto = CryptoService("secret")
  *   val token  = crypto.encrypt("hello")
  *   val plain  = crypto.decrypt(token)
  */
final class CryptoService private (private val key: SecretKey) {

  /** Encrypts `plain` with this instance's key and returns a URL-safe Base64 token. */
  def encrypt(plain: String): String =
    CryptoService.encrypt(plain, key)

  /** Decrypts a token produced by [[encrypt]] with this instance's key. */
  def decrypt(token: String): String =
    CryptoService.decrypt(token, key)
}

object CryptoService {
  private val Algorithm = "AES/GCM/NoPadding"
  private val IvLength = 12 // bytes
  private val TagLength = 128 // bits, as expected by GCMParameterSpec
  private val TagLengthBytes = TagLength / 8

  private val random = new SecureRandom()

  /** Build an instance from a String secret. */
  def apply(secret: String): CryptoService =
    new CryptoService(deriveKeyFromSecret(secret))

  /** Derive a 256-bit AES key from a String: the SHA-256 digest of its UTF-8 bytes. */
  def deriveKeyFromSecret(secret: String): SecretKey = {
    val sha256 = MessageDigest.getInstance("SHA-256")
    new SecretKeySpec(sha256.digest(secret.getBytes(StandardCharsets.UTF_8)), "AES")
  }

  /**
    * Low-level encrypt with explicit key.
    *
    * Algorithm: AES-GCM (AEAD), giving confidentiality plus integrity via the GCM tag.
    * Output format (before Base64): [ IV || (ciphertext || tag) ].
    *
    * @param plain text to encrypt; it is encoded as UTF-8 first
    * @param key   AES key to encrypt with
    * @return URL-safe Base64 (no padding) of IV followed by ciphertext and tag
    */
  def encrypt(plain: String, key: SecretKey): String = {
    // GCM requires an IV that is unique per message under the same key; a fresh random IV
    // also makes identical plaintexts encrypt to different tokens.
    val iv = new Array[Byte](IvLength)
    random.nextBytes(iv)

    // The parameter spec carries both the IV and the authentication tag length (in bits).
    val cipher = Cipher.getInstance(Algorithm)
    cipher.init(Cipher.ENCRYPT_MODE, key, new GCMParameterSpec(TagLength, iv))

    // In JCE, doFinal() in GCM mode returns the ciphertext with the tag appended.
    val cipherText = cipher.doFinal(plain.getBytes(StandardCharsets.UTF_8))

    // The IV is not secret but decrypt() needs it, so it is stored in front of the ciphertext.
    // URL-safe Base64 without padding avoids '+', '/' and '='.
    Base64.getUrlEncoder.withoutPadding().encodeToString(iv ++ cipherText)
  }

  /**
    * Low-level decrypt with explicit key.
    *
    * @param token URL-safe Base64 token in the format produced by [[encrypt]]
    * @param key   AES key the token was encrypted with
    * @return the decrypted text, decoded as UTF-8
    * @throws IllegalArgumentException if the token is not valid Base64 or is too short to
    *                                  contain an IV and an authentication tag
    * @throws javax.crypto.AEADBadTagException if tag verification fails (wrong key or
    *                                          modified token)
    */
  def decrypt(token: String, key: SecretKey): String = {
    val combined = Base64.getUrlDecoder.decode(token)

    // Even an empty plaintext yields IV + tag bytes, so anything shorter is malformed.
    // Rejecting it here reports truncated tokens the same way as other malformed input
    // instead of letting them fail later inside the cipher.
    if (combined.length < IvLength + TagLengthBytes) {
      throw new IllegalArgumentException("Invalid encrypted token")
    }

    val (iv, cipherText) = combined.splitAt(IvLength)

    val cipher = Cipher.getInstance(Algorithm)
    cipher.init(Cipher.DECRYPT_MODE, key, new GCMParameterSpec(TagLength, iv))

    // doFinal() verifies the tag before returning the plaintext.
    new String(cipher.doFinal(cipherText), StandardCharsets.UTF_8)
  }
}
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -39,6 +39,12 @@ object S3StorageClient { | |
| val MINIMUM_NUM_OF_MULTIPART_S3_PART: Long = 5L * 1024 * 1024 // 5 MiB | ||
| val MAXIMUM_NUM_OF_MULTIPART_S3_PARTS = 10_000 | ||
|
|
||
| /** Minimal info about an active multipart upload: the `key` it belongs to and its `uploadId`. NOTE(review): presumably the S3 object key and the id assigned when the upload was initiated; the code that constructs this class is not visible here, so confirm. */ | ||
| final case class MultipartUploadInfo(key: String, uploadId: String) | ||
|
|
||
| /** Minimal info about a completed part in an upload. When built by `listAllParts`, `eTag` has all double quotes removed and is `null` if the listing returned no ETag for the part. */ | ||
| final case class PartInfo(partNumber: Int, eTag: String) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Are these case classes used as type definitions?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Yes, they keep only the information we need and discard the rest. |
||
|
|
||
| // Initialize MinIO-compatible S3 Client | ||
| private lazy val s3Client: S3Client = { | ||
| val credentials = AwsBasicCredentials.create(StorageConfig.s3Username, StorageConfig.s3Password) | ||
|
|
@@ -259,4 +265,60 @@ object S3StorageClient { | |
| DeleteObjectRequest.builder().bucket(bucketName).key(objectKey).build() | ||
| ) | ||
| } | ||
|
|
||
/**
  * Uploads one part of a multipart upload.
  *
  * @param bucket        bucket of the multipart upload
  * @param key           object key of the multipart upload
  * @param uploadId      id of the multipart upload the part belongs to
  * @param partNumber    number of the part being uploaded
  * @param inputStream   source of the part's bytes
  * @param contentLength number of bytes to send from `inputStream`; when `None`, the whole
  *                      stream is first read into memory and sent as a byte array
  *
  * The SDK response is discarded.
  */
def uploadPart(
    bucket: String,
    key: String,
    uploadId: String,
    partNumber: Int,
    inputStream: InputStream,
    contentLength: Option[Long]
): Unit = {
  // With a known length the stream is handed to the SDK directly; otherwise it has to be
  // buffered fully. getOrElse is by-name, so the stream is only drained in the None case.
  val payload: RequestBody =
    contentLength
      .map(length => RequestBody.fromInputStream(inputStream, length))
      .getOrElse(RequestBody.fromBytes(inputStream.readAllBytes()))

  val partRequest =
    UploadPartRequest
      .builder()
      .bucket(bucket)
      .key(key)
      .uploadId(uploadId)
      .partNumber(partNumber)
      .build()

  s3Client.uploadPart(partRequest, payload)
}
|
|
||
/**
  * List *all* parts for a given multipart upload (bucket + key + uploadId).
  * Handles pagination (up to 1000 parts per page).
  *
  * @param bucket   bucket of the multipart upload
  * @param key      object key of the multipart upload
  * @param uploadId id of the multipart upload
  * @return one [[PartInfo]] per listed part, in the order the pages returned them; `eTag`
  *         has its double quotes removed and is `null` when the listing carried no ETag
  */
def listAllParts(
    bucket: String,
    key: String,
    uploadId: String
): Seq[PartInfo] = {

  // Fetches the page starting after `marker` (None = first page) and recurses while the
  // service reports more pages.
  @scala.annotation.tailrec
  def fetchFrom(marker: Option[Integer], collected: Vector[PartInfo]): Vector[PartInfo] = {
    val builder =
      ListPartsRequest.builder().bucket(bucket).key(key).uploadId(uploadId)
    marker.foreach(m => builder.partNumberMarker(m))

    val resp = s3Client.listParts(builder.build())
    val page = resp.parts().asScala.map { p =>
      PartInfo(p.partNumber(), Option(p.eTag()).map(_.replace("\"", "")).orNull)
    }
    val soFar = collected ++ page

    // Both response fields are nullable boxed values. A missing truncation flag is treated
    // as "no more pages" rather than unboxed (which would throw), and a truncated page
    // without a next marker ends the listing, because re-requesting without a marker would
    // return the first page again and never terminate.
    val nextMarker = Option(resp.nextPartNumberMarker())
    val truncated = Option(resp.isTruncated).exists(_.booleanValue())

    if (truncated && nextMarker.isDefined) fetchFrom(nextMarker, soFar) else soFar
  }

  fetchFrom(None, Vector.empty)
}
| } | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same here: the decode function does too much manual work. I believe the library already provides high-level functions for this.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Even with JSON there will still be some manual work to do. How will you modularize it?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The new version uses JSON.