Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.texera.auth

import org.apache.texera.auth.util.CryptoService
import org.apache.texera.config.AuthConfig

import com.fasterxml.jackson.annotation.{JsonCreator, JsonIgnoreProperties, JsonProperty}
import com.fasterxml.jackson.databind.{DeserializationFeature, ObjectMapper}

object UploadTokenParser {

// Token format version; encode() writes this value into the "version" field of every token.
val Version: String = "v1"

/**
 * Data carried inside an upload token.
 *
 * Jackson builds instances through the annotated constructor. All six properties are marked
 * `required = true`, and JSON properties not listed here are ignored.
 *
 * @param version         token format version string
 * @param uploadId        upload identifier (NOTE(review): presumably the multipart upload id — confirm)
 * @param did             numeric id (NOTE(review): presumably the dataset id — confirm)
 * @param uid             numeric id (NOTE(review): presumably the user id — confirm)
 * @param filePath        path of the file being uploaded
 * @param physicalAddress physical storage address associated with the upload
 */
@JsonIgnoreProperties(ignoreUnknown = true)
final case class UploadTokenPayload @JsonCreator(mode = JsonCreator.Mode.PROPERTIES) (
@JsonProperty(value = "version", required = true)
version: String,
@JsonProperty(value = "uploadId", required = true)
uploadId: String,
@JsonProperty(value = "did", required = true)
did: Int,
@JsonProperty(value = "uid", required = true)
uid: Int,
@JsonProperty(value = "filePath", required = true)
filePath: String,
@JsonProperty(value = "physicalAddress", required = true)
physicalAddress: String
)

// Crypto helper keyed with the upload-token secret from AuthConfig; created on first use.
private lazy val cryptoService: CryptoService =
CryptoService(AuthConfig.uploadTokenSecretKey)

// JSON mapper shared by encode/decode; unknown properties are ignored instead of failing deserialization.
private lazy val objectMapper: ObjectMapper =
new ObjectMapper()
.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false)

/**
 * Serializes the payload to JSON and encrypts it into a token string.
 *
 * The "version" field is always written from `Version`; `payload.version` is not consulted.
 *
 * @param payload values to embed in the token
 * @return the encrypted token produced by the crypto service
 */
def encode(payload: UploadTokenPayload): String = {
  // Each put returns the same ObjectNode, so the fields are added in a single fluent chain.
  val tokenJson = objectMapper
    .createObjectNode()
    .put("version", Version)
    .put("uploadId", payload.uploadId)
    .put("did", payload.did)
    .put("uid", payload.uid)
    .put("filePath", payload.filePath)
    .put("physicalAddress", payload.physicalAddress)

  cryptoService.encrypt(objectMapper.writeValueAsString(tokenJson))
}

def decode(token: String): UploadTokenPayload = {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here: the decode function does too much manual work; I believe the library should already provide high-level functions for this.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Still with JSON there will be some manual work to be done, how will you modularize this?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

New version uses JSON

val decryptedJson = cryptoService.decrypt(token)
val decodedPayload = objectMapper.readValue(decryptedJson, classOf[UploadTokenPayload])

decodedPayload
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.apache.texera.auth.util

import java.nio.charset.StandardCharsets
import java.security.{MessageDigest, SecureRandom}
import java.util.Base64
import javax.crypto.{Cipher, SecretKey}
import javax.crypto.spec.{GCMParameterSpec, SecretKeySpec}

/**
 * AES-GCM encryption helper bound to a single secret key.
 *
 * The constructor is private; obtain an instance through the companion's `apply`.
 * Both methods delegate to the key-explicit functions on the companion object.
 *
 * Example:
 *   val crypto = CryptoService("secret")
 *   val token  = crypto.encrypt("hello")
 *   val plain  = crypto.decrypt(token)
 */
final class CryptoService private (private val key: SecretKey) {

  /** Encrypts `plain` with this instance's key and returns the resulting token string. */
  def encrypt(plain: String): String = {
    CryptoService.encrypt(plain, key)
  }

  /** Decrypts `token` with this instance's key and returns the plaintext. */
  def decrypt(token: String): String = {
    CryptoService.decrypt(token, key)
  }
}

object CryptoService {
// JCE transformation string passed to Cipher.getInstance: AES in GCM mode, no padding.
private val Algorithm = "AES/GCM/NoPadding"
// IV (nonce) length in bytes; encrypt() prefixes the token with an IV of this size and decrypt() reads it back.
private val IvLength = 12
// Authentication tag length in bits, as passed to GCMParameterSpec.
private val TagLength = 128

// Shared secure random source used to fill each IV.
private val random = new SecureRandom()

/**
 * Creates a CryptoService whose AES key is derived from the given secret string.
 *
 * @param secret text the key is derived from
 */
def apply(secret: String): CryptoService = {
  val derivedKey = deriveKeyFromSecret(secret)
  new CryptoService(derivedKey)
}

/**
 * Derives a 256-bit AES key from a secret string.
 *
 * The key material is the SHA-256 digest of the secret's UTF-8 bytes.
 *
 * @param secret text to derive the key from
 * @return a `SecretKeySpec` for the "AES" algorithm wrapping the 32-byte digest
 */
def deriveKeyFromSecret(secret: String): SecretKey = {
  val secretBytes = secret.getBytes(StandardCharsets.UTF_8)
  val hashed = MessageDigest.getInstance("SHA-256").digest(secretBytes)
  new SecretKeySpec(hashed, "AES")
}

/** Low-level encrypt with explicit key.
*
* Algorithm: AES-GCM (AEAD).
* - Provides confidentiality (encryption) and integrity/authenticity (GCM tag).
* - Output format (before Base64): [ IV || (ciphertext || tag) ]
* In JCE, `doFinal()` in GCM returns ciphertext with the authentication tag appended.
*/
def encrypt(plain: String, key: SecretKey): String = {

// Allocate a fresh IV/nonce for this encryption.
// In GCM the IV must be unique per message under the same key; uniqueness prevents nonce-reuse attacks (keystream reuse and possible tag forgery).
val iv = new Array[Byte](IvLength)

// Fill IV with cryptographically secure random bytes.
// Random IVs make identical plaintexts encrypt to different outputs and make collisions extremely unlikely.
random.nextBytes(iv)

// Create a Cipher for the requested transformation (e.g., "AES/GCM/NoPadding").
// GCM is the mode; "NoPadding" is standard for GCM in JCE.
val cipher = Cipher.getInstance(Algorithm)

// Initialize cipher for ENCRYPT using:
// - `key`: the AES key material
// - `GCMParameterSpec(TagLength, iv)`: supplies the IV and the desired authentication tag length.
// TagLength is in bits (commonly 128 bits = 16 bytes).
// The tag is what later allows decryption to detect tampering.
cipher.init(Cipher.ENCRYPT_MODE, key, new GCMParameterSpec(TagLength, iv))

// Convert the plaintext string to bytes in a deterministic encoding (UTF-8).
// Crypto APIs operate on bytes; UTF-8 avoids platform-dependent encodings.
val plainBytes = plain.getBytes(StandardCharsets.UTF_8)

// Encrypt and compute the authentication tag.
// For AES-GCM in JCE, `doFinal()` returns: ciphertext || tag (tag appended at the end).
// Any modification of the ciphertext/tag will be detected during decrypt via tag verification.
val cipherText = cipher.doFinal(plainBytes)

// Build the final payload to return.
// We must include the IV with the output because decryption needs the same IV to recompute the keystream
// and verify the authentication tag. The IV is not secret, only required to be unique.
val combined = new Array[Byte](iv.length + cipherText.length)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add comments to explain which part of the algorithm each line performs and why we need it.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Got it, thank you.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done


// Prefix the payload with the IV so the decrypt() routine can read it back.
System.arraycopy(iv, 0, combined, 0, iv.length)

// Append ciphertext+tag after the IV.
System.arraycopy(cipherText, 0, combined, iv.length, cipherText.length)

// Encode binary payload as URL-safe Base64 text (no '+', '/', or '=' padding).
// This makes it safe to store/transport in URLs, cookies, headers, etc.
Base64.getUrlEncoder.withoutPadding().encodeToString(combined)
}

/**
 * Low-level decrypt with an explicit key.
 *
 * Expects the URL-safe Base64 encoding of [ IV || (ciphertext || tag) ].
 *
 * @param token URL-safe Base64 token text
 * @param key   AES key used to initialize the cipher
 * @return the decrypted bytes decoded as a UTF-8 string
 * @throws IllegalArgumentException if the decoded token is not longer than the IV
 */
def decrypt(token: String, key: SecretKey): String = {
  val payload = Base64.getUrlDecoder.decode(token)

  // There must be at least one byte after the IV prefix.
  if (payload.length <= IvLength)
    throw new IllegalArgumentException("Invalid encrypted token")

  // The first IvLength bytes are the IV; the remainder is ciphertext followed by the GCM tag.
  val (nonce, sealedBytes) = payload.splitAt(IvLength)

  val gcm = Cipher.getInstance(Algorithm)
  gcm.init(Cipher.DECRYPT_MODE, key, new GCMParameterSpec(TagLength, nonce))

  new String(gcm.doFinal(sealedBytes), StandardCharsets.UTF_8)
}
}
4 changes: 4 additions & 0 deletions common/config/src/main/resources/auth.conf
Original file line number Diff line number Diff line change
Expand Up @@ -25,4 +25,8 @@ auth {
256-bit-secret = "8a1b2c3d4e5f6a7b8c9d0e1f2a3b4c5d"
256-bit-secret = ${?AUTH_JWT_SECRET}
}
upload-token {
256-bit-secret = "8fnegtjgby3ith7gjw3htr8rj3rhbeub"
256-bit-secret = ${?AUTH_UPLOAD_TOKEN_SECRET}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ object AuthConfig {

// For storing the generated/configured secret
@volatile private var secretKey: String = _
@volatile private var uploadTokenSecret: String = _

// Read JWT secret key with support for random generation
def jwtSecretKey: String = {
Expand All @@ -45,6 +46,18 @@ object AuthConfig {
secretKey
}

/**
 * Returns the secret used for encrypting upload tokens.
 *
 * The value is read from configuration the first time it is requested and cached afterwards.
 * Config path: auth.upload-token.256-bit-secret
 */
def uploadTokenSecretKey: String =
  synchronized {
    Option(uploadTokenSecret).getOrElse {
      val configured = conf.getString("auth.upload-token.256-bit-secret")
      uploadTokenSecret = configured
      configured
    }
  }

private def getRandomHexString: String = {
val bytes = 32
val r = new Random()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -358,4 +358,13 @@ object LakeFSStorageClient {

branchesApi.resetBranch(repoName, branchName, resetCreation).execute()
}

/**
 * Splits a physical address of the form "<scheme>://bucket/key..." into (bucket, key).
 *
 * @param address physical address to parse
 * @return a pair of bucket name and object key; the key has its leading "/" removed and is
 *         empty when the address has no path
 * @throws java.net.URISyntaxException if `address` is not a valid URI
 * @throws IllegalArgumentException if the address carries no bucket component
 */
def parsePhysicalAddress(address: String): (String, String) = {
  // expected: "<scheme>://bucket/key..."
  val uri = new java.net.URI(address)

  // getHost is null when the authority is not a valid host name (e.g. it contains '_');
  // in that case the raw authority still holds the bucket name.
  val bucket = Option(uri.getHost)
    .orElse(Option(uri.getAuthority))
    .filter(_.nonEmpty)
    .getOrElse(
      throw new IllegalArgumentException(s"Physical address has no bucket: $address")
    )

  // getPath is null for opaque URIs; treat that as an empty key rather than failing with an NPE.
  val key = Option(uri.getPath).getOrElse("").stripPrefix("/")

  (bucket, key)
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,12 @@ object S3StorageClient {
val MINIMUM_NUM_OF_MULTIPART_S3_PART: Long = 5L * 1024 * 1024 // 5 MiB
val MAXIMUM_NUM_OF_MULTIPART_S3_PARTS = 10_000

/**
 * Minimal info about an active multipart upload.
 *
 * @param key      key of the upload (NOTE(review): presumably the S3 object key — confirm against the producer)
 * @param uploadId identifier of the multipart upload
 */
final case class MultipartUploadInfo(key: String, uploadId: String)

/**
 * Minimal info about a completed part in an upload.
 *
 * @param partNumber number of the part within the multipart upload
 * @param eTag       ETag of the part; `listAllParts` stores it with double quotes removed and
 *                   passes `null` when the listing carries no ETag
 */
final case class PartInfo(partNumber: Int, eTag: String)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are these case classes used as type definition?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I only keep the info needed and discard others


// Initialize MinIO-compatible S3 Client
private lazy val s3Client: S3Client = {
val credentials = AwsBasicCredentials.create(StorageConfig.s3Username, StorageConfig.s3Password)
Expand Down Expand Up @@ -259,4 +265,60 @@ object S3StorageClient {
DeleteObjectRequest.builder().bucket(bucketName).key(objectKey).build()
)
}

/**
 * Uploads one part of a multipart upload.
 *
 * @param bucket        bucket that holds the upload
 * @param key           object key of the upload
 * @param uploadId      id of the multipart upload
 * @param partNumber    number of the part being uploaded
 * @param inputStream   source of the part's bytes
 * @param contentLength length of the part when known; when absent the stream is read fully
 *                      into memory before sending
 */
def uploadPart(
    bucket: String,
    key: String,
    uploadId: String,
    partNumber: Int,
    inputStream: InputStream,
    contentLength: Option[Long]
): Unit = {
  // Without a known length the whole stream is buffered; otherwise it is streamed with that length.
  val payload: RequestBody =
    contentLength.fold(RequestBody.fromBytes(inputStream.readAllBytes())) { knownLength =>
      RequestBody.fromInputStream(inputStream, knownLength)
    }

  val partRequest =
    UploadPartRequest.builder().bucket(bucket).key(key).uploadId(uploadId).partNumber(partNumber).build()

  s3Client.uploadPart(partRequest, payload)
}

/**
 * List *all* parts for a given multipart upload (bucket + key + uploadId).
 * Handles pagination (up to 1000 parts per page).
 *
 * @param bucket   bucket that holds the upload
 * @param key      object key of the upload
 * @param uploadId id of the multipart upload
 * @return every listed part in the order the pages returned them; ETags have double quotes
 *         removed and are `null` when the listing carries none
 * @throws IllegalStateException if a page is reported as truncated but provides no marker
 *                               to continue from
 */
def listAllParts(
    bucket: String,
    key: String,
    uploadId: String
): Seq[PartInfo] = {
  val acc = scala.collection.mutable.ArrayBuffer.empty[PartInfo]
  var partNumberMarker: Option[Integer] = None
  var truncated = true

  while (truncated) {
    val builder =
      ListPartsRequest.builder().bucket(bucket).key(key).uploadId(uploadId)
    partNumberMarker.foreach(marker => builder.partNumberMarker(marker))

    val resp = s3Client.listParts(builder.build())
    resp.parts().asScala.foreach { p =>
      acc += PartInfo(p.partNumber(), Option(p.eTag()).map(_.replace("\"", "")).orNull)
    }

    // isTruncated is a boxed Boolean; a missing value means there are no further pages.
    truncated = Option(resp.isTruncated).exists(_.booleanValue())
    partNumberMarker = Option(resp.nextPartNumberMarker())

    // Without a marker the next request would return the first page again, forever.
    if (truncated && partNumberMarker.isEmpty) {
      throw new IllegalStateException(
        s"Truncated part listing without a next part-number marker for upload $uploadId"
      )
    }
  }

  acc.toSeq
}
}
Loading