diff --git a/docs/en/operations/external-authenticators/tokens.md b/docs/en/operations/external-authenticators/tokens.md index 48c86c92b9cf..74c02a56900b 100644 --- a/docs/en/operations/external-authenticators/tokens.md +++ b/docs/en/operations/external-authenticators/tokens.md @@ -102,7 +102,7 @@ Only one of `static_jwks` or `static_jwks_file` keys must be present in one veri ::: :::note -Only RS* family algorithms are supported! +For JWKS-based validators (`jwt_static_jwks` and `jwt_dynamic_jwks`), RS* and ES* family algorithms are supported. ::: ### JWT with remote JWKS diff --git a/src/Access/TokenProcessorsJWT.cpp b/src/Access/TokenProcessorsJWT.cpp index c666aca15c3a..182556f24b6a 100644 --- a/src/Access/TokenProcessorsJWT.cpp +++ b/src/Access/TokenProcessorsJWT.cpp @@ -4,6 +4,12 @@ #include #include #include +#include +#include +#include +#include +#include +#include namespace DB { @@ -156,6 +162,77 @@ bool check_claims(const String & claims, const picojson::value::object & payload return check_claims(json.get(), payload, ""); } +std::string create_public_key_from_ec_components(const std::string & x, const std::string & y, int curve_nid) +{ + auto decode_base64url = [](const std::string & value) + { + return jwt::base::decode(jwt::base::pad(value)); + }; + + auto decoded_x = decode_base64url(x); + auto decoded_y = decode_base64url(y); + + size_t coordinate_size = 0; + if (curve_nid == NID_X9_62_prime256v1) + coordinate_size = 32; + else if (curve_nid == NID_secp384r1) + coordinate_size = 48; + else if (curve_nid == NID_secp521r1) + coordinate_size = 66; + else + throw Exception(ErrorCodes::AUTHENTICATION_FAILED, "JWT cannot be validated: unsupported EC curve"); + + if (decoded_x.size() > coordinate_size || decoded_y.size() > coordinate_size) + throw Exception(ErrorCodes::AUTHENTICATION_FAILED, "JWT cannot be validated: invalid EC key coordinates length"); + + std::vector public_key_octets(1 + 2 * coordinate_size, 0); + public_key_octets[0] = 0x04; // Uncompressed point format. + std::memcpy(public_key_octets.data() + 1 + (coordinate_size - decoded_x.size()), decoded_x.data(), decoded_x.size()); + std::memcpy(public_key_octets.data() + 1 + coordinate_size + (coordinate_size - decoded_y.size()), decoded_y.data(), decoded_y.size()); + + const char * group_name = OBJ_nid2sn(curve_nid); + if (!group_name) + throw Exception(ErrorCodes::AUTHENTICATION_FAILED, "JWT cannot be validated: unsupported EC curve"); + + std::unique_ptr params_bld(OSSL_PARAM_BLD_new(), OSSL_PARAM_BLD_free); + if (!params_bld) + throw Exception(ErrorCodes::AUTHENTICATION_FAILED, "JWT cannot be validated: failed to allocate OpenSSL parameter builder"); + + if (OSSL_PARAM_BLD_push_utf8_string(params_bld.get(), OSSL_PKEY_PARAM_GROUP_NAME, group_name, 0) != 1) + throw Exception(ErrorCodes::AUTHENTICATION_FAILED, "JWT cannot be validated: failed to set EC group parameter"); + + if (OSSL_PARAM_BLD_push_octet_string(params_bld.get(), OSSL_PKEY_PARAM_PUB_KEY, public_key_octets.data(), public_key_octets.size()) != 1) + throw Exception(ErrorCodes::AUTHENTICATION_FAILED, "JWT cannot be validated: failed to set EC public key parameter"); + + std::unique_ptr params(OSSL_PARAM_BLD_to_param(params_bld.get()), OSSL_PARAM_free); + if (!params) + throw Exception(ErrorCodes::AUTHENTICATION_FAILED, "JWT cannot be validated: failed to build OpenSSL parameters"); + + std::unique_ptr key_ctx(EVP_PKEY_CTX_new_from_name(nullptr, "EC", nullptr), EVP_PKEY_CTX_free); + if (!key_ctx) + throw Exception(ErrorCodes::AUTHENTICATION_FAILED, "JWT cannot be validated: failed to create EVP key context"); + + if (EVP_PKEY_fromdata_init(key_ctx.get()) <= 0) + throw Exception(ErrorCodes::AUTHENTICATION_FAILED, "JWT cannot be validated: failed to initialize EVP key import"); + + EVP_PKEY * raw_evp_key = nullptr; + if (EVP_PKEY_fromdata(key_ctx.get(), &raw_evp_key, EVP_PKEY_PUBLIC_KEY, params.get()) <= 0) + throw Exception(ErrorCodes::AUTHENTICATION_FAILED, "JWT cannot be validated: failed to import EC public key"); + + std::unique_ptr evp_key(raw_evp_key, EVP_PKEY_free); + + std::unique_ptr bio(BIO_new(BIO_s_mem()), BIO_free); + if (!bio) + throw Exception(ErrorCodes::AUTHENTICATION_FAILED, "JWT cannot be validated: failed to allocate BIO"); + + if (PEM_write_bio_PUBKEY(bio.get(), evp_key.get()) != 1) + throw Exception(ErrorCodes::AUTHENTICATION_FAILED, "JWT cannot be validated: failed to encode EC public key"); + + char * data = nullptr; + auto len = BIO_get_mem_data(bio.get(), &data); + return std::string(data, len); +} + } namespace @@ -392,12 +469,56 @@ bool JwksJwtProcessor::resolveAndValidate(TokenCredentials & credentials) const if (public_key.empty()) { - if (!(jwk.has_jwk_claim("n") && jwk.has_jwk_claim("e"))) - throw Exception(ErrorCodes::AUTHENTICATION_FAILED, "{}: invalid JWK: 'n' or 'e' not found", processor_name); - LOG_TRACE(getLogger("TokenAuthentication"), "{}: `issuer` or `x5c` not present, verifying {} with RSA components", processor_name, username); - const auto modulus = jwk.get_jwk_claim("n").as_string(); - const auto exponent = jwk.get_jwk_claim("e").as_string(); - public_key = jwt::helper::create_public_key_from_rsa_components(modulus, exponent); + const auto key_type = jwk.get_key_type(); + if (key_type == "EC") + { + if (!(jwk.has_jwk_claim("x") && jwk.has_jwk_claim("y"))) + throw Exception(ErrorCodes::AUTHENTICATION_FAILED, "{}: invalid JWK: missing 'x'/'y' claims for EC key type", processor_name); + + int curve_nid = NID_undef; + std::optional expected_crv; + if (algo == "es256") + { + curve_nid = NID_X9_62_prime256v1; + expected_crv = "P-256"; + } + else if (algo == "es384") + { + curve_nid = NID_secp384r1; + expected_crv = "P-384"; + } + else if (algo == "es512") + { + curve_nid = NID_secp521r1; + expected_crv = "P-521"; + } + + if (curve_nid == NID_undef) + throw Exception(ErrorCodes::AUTHENTICATION_FAILED, "JWT cannot be validated: unknown algorithm {}", algo); + + if (jwk.has_jwk_claim("crv")) + { + const auto crv = jwk.get_jwk_claim("crv").as_string(); + if (expected_crv.has_value() && crv != expected_crv.value()) + throw Exception(ErrorCodes::AUTHENTICATION_FAILED, "JWT cannot be validated: `crv` in JWK does not match JWT algorithm"); + } + + LOG_TRACE(getLogger("TokenAuthentication"), "{}: `x5c` not present, verifying {} with EC components", processor_name, username); + const auto x = jwk.get_jwk_claim("x").as_string(); + const auto y = jwk.get_jwk_claim("y").as_string(); + public_key = create_public_key_from_ec_components(x, y, curve_nid); + } + else if (key_type == "RSA") + { + if (!(jwk.has_jwk_claim("n") && jwk.has_jwk_claim("e"))) + throw Exception(ErrorCodes::AUTHENTICATION_FAILED, "{}: invalid JWK: missing 'n'/'e' claims for RSA key type", processor_name); + LOG_TRACE(getLogger("TokenAuthentication"), "{}: `issuer` or `x5c` not present, verifying {} with RSA components", processor_name, username); + const auto modulus = jwk.get_jwk_claim("n").as_string(); + const auto exponent = jwk.get_jwk_claim("e").as_string(); + public_key = jwt::helper::create_public_key_from_rsa_components(modulus, exponent); + } + else + throw Exception(ErrorCodes::AUTHENTICATION_FAILED, "{}: invalid JWK key type '{}'", processor_name, key_type); } if (jwk.has_algorithm() && Poco::toLower(jwk.get_algorithm()) != algo) @@ -409,6 +530,12 @@ bool JwksJwtProcessor::resolveAndValidate(TokenCredentials & credentials) const verifier = verifier.allow_algorithm(jwt::algorithm::rs384(public_key, "", "", "")); else if (algo == "rs512") verifier = verifier.allow_algorithm(jwt::algorithm::rs512(public_key, "", "", "")); + else if (algo == "es256") + verifier = verifier.allow_algorithm(jwt::algorithm::es256(public_key, "", "", "")); + else if (algo == "es384") + verifier = verifier.allow_algorithm(jwt::algorithm::es384(public_key, "", "", "")); + else if (algo == "es512") + verifier = verifier.allow_algorithm(jwt::algorithm::es512(public_key, "", "", "")); else throw Exception(ErrorCodes::AUTHENTICATION_FAILED, "JWT cannot be validated: unknown algorithm {}", algo); diff --git a/tests/integration/test_jwt_auth/jwks_server/server.py b/tests/integration/test_jwt_auth/jwks_server/server.py index 96e07f02335e..67c4d6a4cf8d 100644 --- a/tests/integration/test_jwt_auth/jwks_server/server.py +++ b/tests/integration/test_jwt_auth/jwks_server/server.py @@ -16,6 +16,14 @@ def server(): "kaRv8XJbra0IeIINmKv0F4--ww8ZxXTR6cvI-MsArUiAPwzf7s5dMR4DNRG6YNTrPA0pTOqQE9sRPd62XsfU08plYm27naOUZ" "O5avIPl1YO5I6Gi4kPdTvv3WFIy-QvoKoPhPCaD6EbdBpe8BbTQ", "e": "AQAB"}, + { + "kty": "EC", + "alg": "ES384", + "kid": "ecmykid", + "crv": "P-384", + "x": "ewdB5ypKwp641N5cYmKJvTiwWLIc_IJduJwur2mit1SgQpPZdUwpDV3aNIAmry4Y", + "y": "Jajx21k25o2K-ik86kaaawu6O84awaSmvSirJn8WCeEuotu3O-4Gn-ryOMuDsH76", + }, ] } response.status = 200 diff --git a/tests/integration/test_jwt_auth/test.py b/tests/integration/test_jwt_auth/test.py index 14d42ae08bde..481c8117a73e 100644 --- a/tests/integration/test_jwt_auth/test.py +++ b/tests/integration/test_jwt_auth/test.py @@ -80,3 +80,20 @@ def test_jwks_server(started_cluster): ] ) assert res == "jwt_user\n" + + +def test_jwks_server_ec_es384(started_cluster): + res = client.exec_in_container( + [ + "bash", + "-c", + curl_with_jwt( + token="eyJhbGciOiJFUzM4NCIsImtpZCI6ImVjbXlraWQiLCJ0eXAiOiJKV1QifQ." + "eyJzdWIiOiJqd3RfdXNlciIsImlzcyI6InRlc3RfaXNzIn0." + "3iGUcKfc07oLN4XmBA6BJSGSfu7cBsdQ6KAFh1sV64rWYkVL5VzYlAskHaWZ4R9hR3QK0Bv0EPjia8Vo-xdN9jS7-fVB7RF0" + "rGvbTOIuxE-yDumCyji3MYoLpcbOVasU", + ip=cluster.get_instance_ip(instance.name), + ), + ] + ) + assert res == "jwt_user\n"