Skip to content

Commit 1dbb89e

Browse files
committed
[compat] support PKey::RSA#sign_pss and verify_pss
1 parent 06612ca commit 1dbb89e

File tree

3 files changed

+302
-23
lines changed

3 files changed

+302
-23
lines changed

src/main/java/org/jruby/ext/openssl/PKeyRSA.java

Lines changed: 206 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@
6161
import org.bouncycastle.asn1.x509.AlgorithmIdentifier;
6262
import org.bouncycastle.asn1.x509.DigestInfo;
6363
import org.bouncycastle.crypto.CryptoException;
64+
import org.bouncycastle.crypto.DataLengthException;
6465
import org.bouncycastle.crypto.digests.SHA1Digest;
6566
import org.bouncycastle.crypto.digests.SHA256Digest;
6667
import org.bouncycastle.crypto.digests.SHA384Digest;
@@ -84,6 +85,7 @@
8485
import org.jruby.RubyModule;
8586
import org.jruby.RubyNumeric;
8687
import org.jruby.RubyString;
88+
import org.jruby.RubySymbol;
8789
import org.jruby.anno.JRubyMethod;
8890
import org.jruby.exceptions.RaiseException;
8991
import org.jruby.ext.openssl.util.ByteArrayOutputStream;
@@ -711,9 +713,9 @@ public IRubyObject sign_raw(ThreadContext context, IRubyObject[] args) {
711713
if (mgf1Alg == null) mgf1Alg = digestAlg;
712714
if (saltLen < 0) saltLen = getDigestLength(digestAlg);
713715
try {
714-
return StringHelper.newString(runtime, signWithPSS(runtime, hashBytes, digestAlg, mgf1Alg, saltLen));
715-
} catch (CryptoException e) {
716-
throw newPKeyError(runtime, e.getMessage());
716+
return StringHelper.newString(runtime, signWithPSS(hashBytes, digestAlg, mgf1Alg, saltLen));
717+
} catch (IllegalArgumentException | CryptoException e) {
718+
throw (RaiseException) newPKeyError(runtime, e.getMessage()).initCause(e);
717719
}
718720
}
719721
}
@@ -750,11 +752,8 @@ public IRubyObject verify_raw(ThreadContext context, IRubyObject[] args) {
750752
String mgf1Alg = Utils.extractStringOpt(context, opts, "rsa_mgf1_md", true);
751753
if (mgf1Alg == null) mgf1Alg = digestAlg;
752754
if (saltLen < 0) saltLen = getDigestLength(digestAlg);
753-
try { // verify_raw: input is already the hash → use PreHashedDigest (pass-through phase 1)
754-
return runtime.newBoolean(verifyWithPSS(publicKey, hashBytes, digestAlg, true, mgf1Alg, saltLen, sigBytes));
755-
} catch (Exception e) {
756-
throw newPKeyError(runtime, e.getMessage());
757-
}
755+
// verify_raw: input is already the hash → use PreHashedDigest (pass-through phase 1)
756+
return verifyPSS(runtime, true, hashBytes, digestAlg, mgf1Alg, saltLen, sigBytes);
758757
}
759758
}
760759

@@ -796,18 +795,137 @@ public IRubyObject verify(ThreadContext context, IRubyObject[] args) {
796795
if (saltLen < 0) saltLen = getDigestLength(digestAlg);
797796
byte[] sigBytes = sign.convertToString().getBytes();
798797
byte[] dataBytes = data.convertToString().getBytes();
799-
try { // verify (non-raw): feed raw data; PSSSigner will hash it internally via SHA-NNN
800-
return runtime.newBoolean(verifyWithPSS(publicKey, dataBytes, digestAlg, false, mgf1Alg, saltLen, sigBytes));
801-
} catch (Exception e) {
802-
throw newPKeyError(runtime, e.getMessage());
803-
}
798+
799+
// verify (non-raw): feed raw data; PSSSigner will hash it internally via SHA-NNN
800+
return verifyPSS(runtime, false, dataBytes, digestAlg, mgf1Alg, saltLen, sigBytes);
804801
}
805802
}
806803

807804
// Fall back to standard PKey#verify (PKCS#1 v1.5)
808805
return super.verify(digest, sign, data);
809806
}
810807

808+
// Override sign to support an optional 3rd opts argument.
809+
// When opts contains rsa_padding_mode: "pss", signs the raw data with RSA-PSS.
810+
// Otherwise delegates to PKey#sign (PKCS#1 v1.5). Non-Hash opts raise TypeError.
811+
@JRubyMethod(name = "sign", required = 2, optional = 1)
812+
public IRubyObject sign(ThreadContext context, IRubyObject[] args) {
813+
final Ruby runtime = context.runtime;
814+
final IRubyObject digest = args[0];
815+
final IRubyObject data = args[1];
816+
final IRubyObject opts = args.length > 2 ? args[2] : context.nil;
817+
818+
if (!opts.isNil()) {
819+
if (!(opts instanceof RubyHash)) throw runtime.newTypeError("expected Hash");
820+
String paddingMode = Utils.extractStringOpt(context, opts, "rsa_padding_mode", true);
821+
if ("pss".equalsIgnoreCase(paddingMode)) {
822+
if (privateKey == null) throw newPKeyError(runtime, "Private RSA key needed!");
823+
final String digestAlg = getDigestAlgName(digest);
824+
int saltLen = Utils.extractIntOpt(context, opts, "rsa_pss_saltlen", -1, true);
825+
String mgf1Alg = Utils.extractStringOpt(context, opts, "rsa_mgf1_md", true);
826+
if (mgf1Alg == null) mgf1Alg = digestAlg;
827+
if (saltLen < 0) saltLen = maxPSSSaltLength(digestAlg, privateKey.getModulus().bitLength());
828+
829+
final byte[] signedData;
830+
try {
831+
signedData = signDataWithPSS(runtime, data.convertToString(), digestAlg, mgf1Alg, saltLen);
832+
} catch (IllegalArgumentException | DataLengthException | CryptoException e) {
833+
throw (RaiseException) newPKeyError(runtime, e.getMessage()).initCause(e);
834+
}
835+
return StringHelper.newString(runtime, signedData);
836+
}
837+
}
838+
return super.sign(digest, data); // PKCS#1 v1.5 fallback
839+
}
840+
841+
// sign_pss(digest, data, salt_length:, mgf1_hash:)
842+
// Signs data with RSA-PSS. salt_length accepts :digest, :max, :auto, or an integer.
843+
@JRubyMethod(name = "sign_pss", required = 2, optional = 1)
844+
public IRubyObject sign_pss(ThreadContext context, IRubyObject[] args) {
845+
final Ruby runtime = context.runtime;
846+
if (privateKey == null) throw newPKeyError(runtime, "Private RSA key needed!");
847+
final String digestAlg = getDigestAlgName(args[0]);
848+
final IRubyObject opts = args.length > 2 ? args[2] : context.nil;
849+
final int maxSalt = maxPSSSaltLength(digestAlg, privateKey.getModulus().bitLength());
850+
851+
String mgf1Alg = Utils.extractStringOpt(context, opts, "mgf1_hash");
852+
if (mgf1Alg == null) mgf1Alg = digestAlg;
853+
854+
final IRubyObject saltLenArg = opts instanceof RubyHash ?
855+
((RubyHash) opts).fastARef(runtime.newSymbol("salt_length")) : null;
856+
final int saltLen;
857+
if (saltLenArg instanceof RubySymbol) {
858+
String sym = saltLenArg.asJavaString();
859+
if ("digest".equals(sym)) saltLen = getDigestLength(digestAlg);
860+
else if ("max".equals(sym) || "auto".equals(sym)) saltLen = maxSalt;
861+
else throw runtime.newArgumentError("unknown salt_length: " + sym);
862+
} else if (saltLenArg != null && !saltLenArg.isNil()) {
863+
saltLen = RubyNumeric.fix2int(saltLenArg);
864+
} else {
865+
saltLen = maxSalt;
866+
}
867+
868+
final byte[] signedData;
869+
try {
870+
signedData = signDataWithPSS(runtime, args[1].convertToString(), digestAlg, mgf1Alg, saltLen);
871+
} catch (IllegalArgumentException | DataLengthException | CryptoException e) {
872+
throw (RaiseException) newPKeyError(runtime, e.getMessage()).initCause(e);
873+
}
874+
return StringHelper.newString(runtime, signedData);
875+
}
876+
877+
// verify_pss(digest, signature, data, salt_length:, mgf1_hash:)
878+
// Verifies a PSS signature. salt_length accepts :auto, :max, :digest, or an integer.
879+
@JRubyMethod(name = "verify_pss", required = 3, optional = 1)
880+
public IRubyObject verify_pss(ThreadContext context, IRubyObject[] args) {
881+
final Ruby runtime = context.runtime;
882+
final String digestAlg = getDigestAlgName(args[0]);
883+
final byte[] sigBytes = args[1].convertToString().getBytes();
884+
final byte[] dataBytes = args[2].convertToString().getBytes();
885+
final IRubyObject opts = args.length > 3 ? args[3] : context.nil;
886+
887+
String mgf1Alg = Utils.extractStringOpt(context, opts, "mgf1_hash");
888+
if (mgf1Alg == null) mgf1Alg = digestAlg;
889+
890+
IRubyObject saltLenArg = opts instanceof RubyHash
891+
? ((RubyHash) opts).fastARef(runtime.newSymbol("salt_length")) : null;
892+
int saltLen;
893+
if (saltLenArg instanceof RubySymbol) {
894+
String sym = saltLenArg.asJavaString();
895+
if ("auto".equals(sym)) {
896+
saltLen = pssAutoSaltLength(publicKey, sigBytes, digestAlg, mgf1Alg);
897+
if (saltLen < 0) return runtime.getFalse();
898+
} else if ("max".equals(sym)) {
899+
saltLen = maxPSSSaltLength(digestAlg, publicKey.getModulus().bitLength());
900+
} else if ("digest".equals(sym)) {
901+
saltLen = getDigestLength(digestAlg);
902+
} else {
903+
throw runtime.newArgumentError("unknown salt_length: " + sym);
904+
}
905+
} else if (saltLenArg != null && !saltLenArg.isNil()) {
906+
saltLen = RubyNumeric.fix2int(saltLenArg);
907+
} else {
908+
saltLen = getDigestLength(digestAlg);
909+
}
910+
911+
return verifyPSS(runtime, false, dataBytes, digestAlg, mgf1Alg, saltLen, sigBytes);
912+
}
913+
914+
private IRubyObject verifyPSS(final Ruby runtime, final boolean rawVerify,
915+
final byte[] dataBytes, final String digestAlg,
916+
final String mgf1Alg, final int saltLen, final byte[] sigBytes) {
917+
boolean verified;
918+
try {
919+
verified = verifyWithPSS(rawVerify, publicKey, dataBytes, digestAlg, mgf1Alg, saltLen, sigBytes);
920+
} catch (IllegalArgumentException|IllegalStateException e) {
921+
verified = false;
922+
} catch (Exception e) {
923+
debugStackTrace(runtime, e);
924+
return runtime.getNil();
925+
}
926+
return runtime.newBoolean(verified);
927+
}
928+
811929
private static byte[] buildDigestInfo(String digestAlg, byte[] hashBytes) throws IOException {
812930
AlgorithmIdentifier algId = getDigestAlgId(digestAlg);
813931
return new DigestInfo(algId, hashBytes).getEncoded("DER");
@@ -855,30 +973,28 @@ private static int getDigestLength(String digestAlg) {
855973
// Signs pre-hashed bytes using RSA-PSS. PSSSigner internally reuses the content digest for
856974
// BOTH hashing the message (phase 1) and hashing mDash (phase 2), so we use PreHashedDigest
857975
// which passes through pre-hashed bytes verbatim in phase 1 and runs a real SHA hash in phase 2.
858-
private byte[] signWithPSS(Ruby runtime, byte[] hashBytes, String digestAlg, String mgf1Alg, int saltLen)
976+
private byte[] signWithPSS(byte[] hashBytes, String digestAlg, String mgf1Alg, int saltLen)
859977
throws CryptoException {
860978
org.bouncycastle.crypto.Digest contentDigest = new PreHashedDigest(getDigestLength(digestAlg), digestAlg);
861979
org.bouncycastle.crypto.Digest mgf1Digest = createBCDigest(mgf1Alg);
862980
PSSSigner signer = new PSSSigner(new RSABlindedEngine(), contentDigest, mgf1Digest, saltLen);
863981
RSAKeyParameters bcKey = toBCPrivateKeyParams(privateKey);
864-
signer.init(true, new ParametersWithRandom(bcKey, getSecureRandom(runtime)));
982+
signer.init(true, new ParametersWithRandom(bcKey, getSecureRandom(getRuntime())));
865983
signer.update(hashBytes, 0, hashBytes.length);
866984
return signer.generateSignature();
867985
}
868986

869-
// Verifies an RSA-PSS signature. When isRaw=true the input is a pre-computed hash (verify_raw);
987+
// Verifies an RSA-PSS signature. When rawVerify=true the input is a pre-computed hash (verify_raw);
870988
// PreHashedDigest passes it through in phase 1 then uses a real SHA for hashing mDash in phase 2.
871-
// When isRaw=false the input is raw data (verify with opts); a real SHA digest is used throughout.
872-
private static boolean verifyWithPSS(RSAPublicKey pubKey, byte[] inputBytes,
873-
String digestAlg, boolean isRaw,
874-
String mgf1Alg, int saltLen, byte[] sigBytes) {
875-
org.bouncycastle.crypto.Digest contentDigest = isRaw
989+
// When rawVerify=false the input is raw data (verify with opts); a real SHA digest is used throughout.
990+
private static boolean verifyWithPSS(final boolean rawVerify, RSAPublicKey pubKey, byte[] inputBytes,
991+
String digestAlg, String mgf1Alg, int saltLen, byte[] sigBytes) {
992+
org.bouncycastle.crypto.Digest contentDigest = rawVerify
876993
? new PreHashedDigest(getDigestLength(digestAlg), digestAlg)
877994
: createBCDigest(digestAlg);
878995
org.bouncycastle.crypto.Digest mgf1Digest = createBCDigest(mgf1Alg);
879996
PSSSigner verifier = new PSSSigner(new RSABlindedEngine(), contentDigest, mgf1Digest, saltLen);
880-
RSAKeyParameters bcPubKey = new RSAKeyParameters(false, pubKey.getModulus(), pubKey.getPublicExponent());
881-
verifier.init(false, bcPubKey);
997+
verifier.init(false, new RSAKeyParameters(false, pubKey.getModulus(), pubKey.getPublicExponent()));
882998
verifier.update(inputBytes, 0, inputBytes.length);
883999
return verifier.verifySignature(sigBytes);
8841000
}
@@ -951,6 +1067,73 @@ private static RSAKeyParameters toBCPrivateKeyParams(RSAPrivateKey privKey) {
9511067
return new RSAKeyParameters(true, privKey.getModulus(), privKey.getPrivateExponent());
9521068
}
9531069

1070+
// Signs raw (unhashed) data with RSA-PSS; PSSSigner applies the hash internally.
1071+
private byte[] signDataWithPSS(Ruby runtime, RubyString data, String digestAlg, String mgf1Alg, int saltLen)
1072+
throws CryptoException {
1073+
org.bouncycastle.crypto.Digest contentDigest = createBCDigest(digestAlg);
1074+
org.bouncycastle.crypto.Digest mgf1Digest = createBCDigest(mgf1Alg);
1075+
PSSSigner signer = new PSSSigner(new RSABlindedEngine(), contentDigest, mgf1Digest, saltLen);
1076+
signer.init(true, new ParametersWithRandom(toBCPrivateKeyParams(privateKey), getSecureRandom(runtime)));
1077+
final ByteList dataBytes = data.getByteList();
1078+
signer.update(dataBytes.unsafeBytes(), dataBytes.getBegin(), dataBytes.getRealSize());
1079+
return signer.generateSignature();
1080+
}
1081+
1082+
// Maximum PSS salt length per RFC 8017 §9.1.1:
1083+
// emLen = ceil((keyBits - 1) / 8), maxSalt = emLen - 2 - hLen
1084+
private static int maxPSSSaltLength(String digestAlg, int keyBits) {
1085+
int emLen = (keyBits - 1 + 7) / 8;
1086+
return emLen - 2 - getDigestLength(digestAlg);
1087+
}
1088+
1089+
// Extracts the actual PSS salt length from a signature by parsing the PSS-encoded message.
1090+
// Returns -1 if the encoding is invalid (not a well-formed PSS block).
1091+
// This is used to implement salt_length: :auto in verify_pss.
1092+
private static int pssAutoSaltLength(RSAPublicKey pubKey, byte[] sigBytes, String digestAlg, String mgf1Alg) {
1093+
// Step 1: RSA public-key operation → encoded message (EM)
1094+
RSAKeyParameters bcPubKey = new RSAKeyParameters(false, pubKey.getModulus(), pubKey.getPublicExponent());
1095+
RSABlindedEngine rsa = new RSABlindedEngine();
1096+
rsa.init(false, bcPubKey);
1097+
byte[] em = rsa.processBlock(sigBytes, 0, sigBytes.length);
1098+
1099+
int hLen = getDigestLength(digestAlg);
1100+
int emLen = em.length;
1101+
if (emLen < hLen + 2 || em[emLen - 1] != (byte) 0xBC) return -1;
1102+
1103+
int dbLen = emLen - hLen - 1;
1104+
byte[] H = new byte[hLen];
1105+
System.arraycopy(em, dbLen, H, 0, hLen);
1106+
1107+
// Step 2: Recover DB = MGF1(H, dbLen) XOR maskedDB
1108+
byte[] DB = new byte[dbLen];
1109+
System.arraycopy(em, 0, DB, 0, dbLen);
1110+
org.bouncycastle.crypto.Digest mgfDigest = createBCDigest(mgf1Alg);
1111+
int mgfHLen = mgfDigest.getDigestSize();
1112+
byte[] hBuf = new byte[mgfHLen];
1113+
byte[] ctr = new byte[4];
1114+
for (int pos = 0, c = 0; pos < dbLen; c++) {
1115+
ctr[0] = (byte)(c >> 24); ctr[1] = (byte)(c >> 16);
1116+
ctr[2] = (byte)(c >> 8); ctr[3] = (byte) c;
1117+
mgfDigest.update(H, 0, hLen);
1118+
mgfDigest.update(ctr, 0, 4);
1119+
mgfDigest.doFinal(hBuf, 0);
1120+
int n = Math.min(mgfHLen, dbLen - pos);
1121+
for (int i = 0; i < n; i++) DB[pos + i] ^= hBuf[i];
1122+
pos += n;
1123+
}
1124+
1125+
// Step 3: Clear top bits per RFC 8017 §9.1.2
1126+
int topBits = 8 * emLen - (pubKey.getModulus().bitLength() - 1);
1127+
if (topBits > 0) DB[0] &= (byte)(0xFF >>> topBits);
1128+
1129+
// Step 4: Find the 0x01 separator; salt follows it
1130+
for (int i = 0; i < dbLen; i++) {
1131+
if (DB[i] == 0x01) return dbLen - i - 1;
1132+
if (DB[i] != 0x00) return -1;
1133+
}
1134+
return -1;
1135+
}
1136+
9541137
@JRubyMethod(name="d=")
9551138
public synchronized IRubyObject set_d(final ThreadContext context, IRubyObject value) {
9561139
if ( privateKey != null ) {

src/main/java/org/jruby/ext/openssl/Utils.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,10 @@ public void visit(IRubyObject key, IRubyObject value) {
191191
return ret;
192192
}
193193

194+
static String extractStringOpt(ThreadContext context, IRubyObject opts, String key) {
195+
return extractStringOpt(context, opts, key, false);
196+
}
197+
194198
static String extractStringOpt(ThreadContext context, IRubyObject opts,
195199
String key, boolean tryStringKey) {
196200
if (!(opts instanceof RubyHash)) return null;

0 commit comments

Comments
 (0)