Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/en/operations/external-authenticators/tokens.md
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ Only one of `static_jwks` or `static_jwks_file` keys must be present in one veri
:::

:::note
Only RS* family algorithms are supported!
For JWKS-based validators (`jwt_static_jwks` and `jwt_dynamic_jwks`), RS* and ES* family algorithms are supported.
:::

### JWT with remote JWKS
Expand Down
139 changes: 133 additions & 6 deletions src/Access/TokenProcessorsJWT.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,12 @@
#include <Common/Base64.h>
#include <Common/logger_useful.h>
#include <Poco/String.h>
#include <openssl/bio.h>
#include <openssl/core_names.h>
#include <openssl/evp.h>
#include <openssl/param_build.h>
#include <openssl/pem.h>
#include <cstring>

namespace DB {

Expand Down Expand Up @@ -156,6 +162,77 @@ bool check_claims(const String & claims, const picojson::value::object & payload
return check_claims(json.get<picojson::value::object>(), payload, "");
}

std::string create_public_key_from_ec_components(const std::string & x, const std::string & y, int curve_nid)
{
auto decode_base64url = [](const std::string & value)
{
return jwt::base::decode<jwt::alphabet::base64url>(jwt::base::pad<jwt::alphabet::base64url>(value));
};

auto decoded_x = decode_base64url(x);
auto decoded_y = decode_base64url(y);

size_t coordinate_size = 0;
if (curve_nid == NID_X9_62_prime256v1)
coordinate_size = 32;
else if (curve_nid == NID_secp384r1)
coordinate_size = 48;
else if (curve_nid == NID_secp521r1)
coordinate_size = 66;
else
throw Exception(ErrorCodes::AUTHENTICATION_FAILED, "JWT cannot be validated: unsupported EC curve");

if (decoded_x.size() > coordinate_size || decoded_y.size() > coordinate_size)
throw Exception(ErrorCodes::AUTHENTICATION_FAILED, "JWT cannot be validated: invalid EC key coordinates length");

std::vector<unsigned char> public_key_octets(1 + 2 * coordinate_size, 0);
public_key_octets[0] = 0x04; // Uncompressed point format.
std::memcpy(public_key_octets.data() + 1 + (coordinate_size - decoded_x.size()), decoded_x.data(), decoded_x.size());
std::memcpy(public_key_octets.data() + 1 + coordinate_size + (coordinate_size - decoded_y.size()), decoded_y.data(), decoded_y.size());

const char * group_name = OBJ_nid2sn(curve_nid);
if (!group_name)
throw Exception(ErrorCodes::AUTHENTICATION_FAILED, "JWT cannot be validated: unsupported EC curve");

std::unique_ptr<OSSL_PARAM_BLD, decltype(&OSSL_PARAM_BLD_free)> params_bld(OSSL_PARAM_BLD_new(), OSSL_PARAM_BLD_free);
if (!params_bld)
throw Exception(ErrorCodes::AUTHENTICATION_FAILED, "JWT cannot be validated: failed to allocate OpenSSL parameter builder");

if (OSSL_PARAM_BLD_push_utf8_string(params_bld.get(), OSSL_PKEY_PARAM_GROUP_NAME, group_name, 0) != 1)
throw Exception(ErrorCodes::AUTHENTICATION_FAILED, "JWT cannot be validated: failed to set EC group parameter");

if (OSSL_PARAM_BLD_push_octet_string(params_bld.get(), OSSL_PKEY_PARAM_PUB_KEY, public_key_octets.data(), public_key_octets.size()) != 1)
throw Exception(ErrorCodes::AUTHENTICATION_FAILED, "JWT cannot be validated: failed to set EC public key parameter");

std::unique_ptr<OSSL_PARAM, decltype(&OSSL_PARAM_free)> params(OSSL_PARAM_BLD_to_param(params_bld.get()), OSSL_PARAM_free);
if (!params)
throw Exception(ErrorCodes::AUTHENTICATION_FAILED, "JWT cannot be validated: failed to build OpenSSL parameters");

std::unique_ptr<EVP_PKEY_CTX, decltype(&EVP_PKEY_CTX_free)> key_ctx(EVP_PKEY_CTX_new_from_name(nullptr, "EC", nullptr), EVP_PKEY_CTX_free);
if (!key_ctx)
throw Exception(ErrorCodes::AUTHENTICATION_FAILED, "JWT cannot be validated: failed to create EVP key context");

if (EVP_PKEY_fromdata_init(key_ctx.get()) <= 0)
throw Exception(ErrorCodes::AUTHENTICATION_FAILED, "JWT cannot be validated: failed to initialize EVP key import");

EVP_PKEY * raw_evp_key = nullptr;
if (EVP_PKEY_fromdata(key_ctx.get(), &raw_evp_key, EVP_PKEY_PUBLIC_KEY, params.get()) <= 0)
throw Exception(ErrorCodes::AUTHENTICATION_FAILED, "JWT cannot be validated: failed to import EC public key");

std::unique_ptr<EVP_PKEY, decltype(&EVP_PKEY_free)> evp_key(raw_evp_key, EVP_PKEY_free);

std::unique_ptr<BIO, decltype(&BIO_free)> bio(BIO_new(BIO_s_mem()), BIO_free);
if (!bio)
throw Exception(ErrorCodes::AUTHENTICATION_FAILED, "JWT cannot be validated: failed to allocate BIO");

if (PEM_write_bio_PUBKEY(bio.get(), evp_key.get()) != 1)
throw Exception(ErrorCodes::AUTHENTICATION_FAILED, "JWT cannot be validated: failed to encode EC public key");

char * data = nullptr;
auto len = BIO_get_mem_data(bio.get(), &data);
return std::string(data, len);
}

}

namespace
Expand Down Expand Up @@ -392,12 +469,56 @@ bool JwksJwtProcessor::resolveAndValidate(TokenCredentials & credentials) const

if (public_key.empty())
{
if (!(jwk.has_jwk_claim("n") && jwk.has_jwk_claim("e")))
throw Exception(ErrorCodes::AUTHENTICATION_FAILED, "{}: invalid JWK: 'n' or 'e' not found", processor_name);
LOG_TRACE(getLogger("TokenAuthentication"), "{}: `issuer` or `x5c` not present, verifying {} with RSA components", processor_name, username);
const auto modulus = jwk.get_jwk_claim("n").as_string();
const auto exponent = jwk.get_jwk_claim("e").as_string();
public_key = jwt::helper::create_public_key_from_rsa_components(modulus, exponent);
const auto key_type = jwk.get_key_type();
if (key_type == "EC")
{
if (!(jwk.has_jwk_claim("x") && jwk.has_jwk_claim("y")))
throw Exception(ErrorCodes::AUTHENTICATION_FAILED, "{}: invalid JWK: missing 'x'/'y' claims for EC key type", processor_name);

int curve_nid = NID_undef;
std::optional<String> expected_crv;
if (algo == "es256")
{
curve_nid = NID_X9_62_prime256v1;
expected_crv = "P-256";
}
else if (algo == "es384")
{
curve_nid = NID_secp384r1;
expected_crv = "P-384";
}
else if (algo == "es512")
{
curve_nid = NID_secp521r1;
expected_crv = "P-521";
}

if (curve_nid == NID_undef)
throw Exception(ErrorCodes::AUTHENTICATION_FAILED, "JWT cannot be validated: unknown algorithm {}", algo);

if (jwk.has_jwk_claim("crv"))
{
const auto crv = jwk.get_jwk_claim("crv").as_string();
if (expected_crv.has_value() && crv != expected_crv.value())
throw Exception(ErrorCodes::AUTHENTICATION_FAILED, "JWT cannot be validated: `crv` in JWK does not match JWT algorithm");
}

LOG_TRACE(getLogger("TokenAuthentication"), "{}: `x5c` not present, verifying {} with EC components", processor_name, username);
const auto x = jwk.get_jwk_claim("x").as_string();
const auto y = jwk.get_jwk_claim("y").as_string();
public_key = create_public_key_from_ec_components(x, y, curve_nid);
}
else if (key_type == "RSA")
{
if (!(jwk.has_jwk_claim("n") && jwk.has_jwk_claim("e")))
throw Exception(ErrorCodes::AUTHENTICATION_FAILED, "{}: invalid JWK: missing 'n'/'e' claims for RSA key type", processor_name);
LOG_TRACE(getLogger("TokenAuthentication"), "{}: `issuer` or `x5c` not present, verifying {} with RSA components", processor_name, username);
const auto modulus = jwk.get_jwk_claim("n").as_string();
const auto exponent = jwk.get_jwk_claim("e").as_string();
public_key = jwt::helper::create_public_key_from_rsa_components(modulus, exponent);
}
else
throw Exception(ErrorCodes::AUTHENTICATION_FAILED, "{}: invalid JWK key type '{}'", processor_name, key_type);
}

if (jwk.has_algorithm() && Poco::toLower(jwk.get_algorithm()) != algo)
Expand All @@ -409,6 +530,12 @@ bool JwksJwtProcessor::resolveAndValidate(TokenCredentials & credentials) const
verifier = verifier.allow_algorithm(jwt::algorithm::rs384(public_key, "", "", ""));
else if (algo == "rs512")
verifier = verifier.allow_algorithm(jwt::algorithm::rs512(public_key, "", "", ""));
else if (algo == "es256")
verifier = verifier.allow_algorithm(jwt::algorithm::es256(public_key, "", "", ""));
else if (algo == "es384")
verifier = verifier.allow_algorithm(jwt::algorithm::es384(public_key, "", "", ""));
else if (algo == "es512")
verifier = verifier.allow_algorithm(jwt::algorithm::es512(public_key, "", "", ""));
else
throw Exception(ErrorCodes::AUTHENTICATION_FAILED, "JWT cannot be validated: unknown algorithm {}", algo);

Expand Down
8 changes: 8 additions & 0 deletions tests/integration/test_jwt_auth/jwks_server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,14 @@ def server():
"kaRv8XJbra0IeIINmKv0F4--ww8ZxXTR6cvI-MsArUiAPwzf7s5dMR4DNRG6YNTrPA0pTOqQE9sRPd62XsfU08plYm27naOUZ"
"O5avIPl1YO5I6Gi4kPdTvv3WFIy-QvoKoPhPCaD6EbdBpe8BbTQ",
"e": "AQAB"},
{
"kty": "EC",
"alg": "ES384",
"kid": "ecmykid",
"crv": "P-384",
"x": "ewdB5ypKwp641N5cYmKJvTiwWLIc_IJduJwur2mit1SgQpPZdUwpDV3aNIAmry4Y",
"y": "Jajx21k25o2K-ik86kaaawu6O84awaSmvSirJn8WCeEuotu3O-4Gn-ryOMuDsH76",
},
]
}
response.status = 200
Expand Down
17 changes: 17 additions & 0 deletions tests/integration/test_jwt_auth/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,3 +80,20 @@ def test_jwks_server(started_cluster):
]
)
assert res == "jwt_user\n"


def test_jwks_server_ec_es384(started_cluster):
res = client.exec_in_container(
[
"bash",
"-c",
curl_with_jwt(
token="eyJhbGciOiJFUzM4NCIsImtpZCI6ImVjbXlraWQiLCJ0eXAiOiJKV1QifQ."
"eyJzdWIiOiJqd3RfdXNlciIsImlzcyI6InRlc3RfaXNzIn0."
"3iGUcKfc07oLN4XmBA6BJSGSfu7cBsdQ6KAFh1sV64rWYkVL5VzYlAskHaWZ4R9hR3QK0Bv0EPjia8Vo-xdN9jS7-fVB7RF0"
"rGvbTOIuxE-yDumCyji3MYoLpcbOVasU",
ip=cluster.get_instance_ip(instance.name),
),
]
)
assert res == "jwt_user\n"
Loading