|
61 | 61 | import org.bouncycastle.asn1.x509.AlgorithmIdentifier; |
62 | 62 | import org.bouncycastle.asn1.x509.DigestInfo; |
63 | 63 | import org.bouncycastle.crypto.CryptoException; |
| 64 | +import org.bouncycastle.crypto.DataLengthException; |
64 | 65 | import org.bouncycastle.crypto.digests.SHA1Digest; |
65 | 66 | import org.bouncycastle.crypto.digests.SHA256Digest; |
66 | 67 | import org.bouncycastle.crypto.digests.SHA384Digest; |
|
84 | 85 | import org.jruby.RubyModule; |
85 | 86 | import org.jruby.RubyNumeric; |
86 | 87 | import org.jruby.RubyString; |
| 88 | +import org.jruby.RubySymbol; |
87 | 89 | import org.jruby.anno.JRubyMethod; |
88 | 90 | import org.jruby.exceptions.RaiseException; |
89 | 91 | import org.jruby.ext.openssl.util.ByteArrayOutputStream; |
@@ -711,9 +713,9 @@ public IRubyObject sign_raw(ThreadContext context, IRubyObject[] args) { |
711 | 713 | if (mgf1Alg == null) mgf1Alg = digestAlg; |
712 | 714 | if (saltLen < 0) saltLen = getDigestLength(digestAlg); |
713 | 715 | try { |
714 | | - return StringHelper.newString(runtime, signWithPSS(runtime, hashBytes, digestAlg, mgf1Alg, saltLen)); |
715 | | - } catch (CryptoException e) { |
716 | | - throw newPKeyError(runtime, e.getMessage()); |
| 716 | + return StringHelper.newString(runtime, signWithPSS(hashBytes, digestAlg, mgf1Alg, saltLen)); |
| 717 | + } catch (IllegalArgumentException | CryptoException e) { |
| 718 | + throw (RaiseException) newPKeyError(runtime, e.getMessage()).initCause(e); |
717 | 719 | } |
718 | 720 | } |
719 | 721 | } |
@@ -750,11 +752,8 @@ public IRubyObject verify_raw(ThreadContext context, IRubyObject[] args) { |
750 | 752 | String mgf1Alg = Utils.extractStringOpt(context, opts, "rsa_mgf1_md", true); |
751 | 753 | if (mgf1Alg == null) mgf1Alg = digestAlg; |
752 | 754 | if (saltLen < 0) saltLen = getDigestLength(digestAlg); |
753 | | - try { // verify_raw: input is already the hash → use PreHashedDigest (pass-through phase 1) |
754 | | - return runtime.newBoolean(verifyWithPSS(publicKey, hashBytes, digestAlg, true, mgf1Alg, saltLen, sigBytes)); |
755 | | - } catch (Exception e) { |
756 | | - throw newPKeyError(runtime, e.getMessage()); |
757 | | - } |
| 755 | + // verify_raw: input is already the hash → use PreHashedDigest (pass-through phase 1) |
| 756 | + return verifyPSS(runtime, true, hashBytes, digestAlg, mgf1Alg, saltLen, sigBytes); |
758 | 757 | } |
759 | 758 | } |
760 | 759 |
|
@@ -796,18 +795,137 @@ public IRubyObject verify(ThreadContext context, IRubyObject[] args) { |
796 | 795 | if (saltLen < 0) saltLen = getDigestLength(digestAlg); |
797 | 796 | byte[] sigBytes = sign.convertToString().getBytes(); |
798 | 797 | byte[] dataBytes = data.convertToString().getBytes(); |
799 | | - try { // verify (non-raw): feed raw data; PSSSigner will hash it internally via SHA-NNN |
800 | | - return runtime.newBoolean(verifyWithPSS(publicKey, dataBytes, digestAlg, false, mgf1Alg, saltLen, sigBytes)); |
801 | | - } catch (Exception e) { |
802 | | - throw newPKeyError(runtime, e.getMessage()); |
803 | | - } |
| 798 | + |
| 799 | + // verify (non-raw): feed raw data; PSSSigner will hash it internally via SHA-NNN |
| 800 | + return verifyPSS(runtime, false, dataBytes, digestAlg, mgf1Alg, saltLen, sigBytes); |
804 | 801 | } |
805 | 802 | } |
806 | 803 |
|
807 | 804 | // Fall back to standard PKey#verify (PKCS#1 v1.5) |
808 | 805 | return super.verify(digest, sign, data); |
809 | 806 | } |
810 | 807 |
|
| 808 | + // Override sign to support an optional 3rd opts argument. |
| 809 | + // When opts contains rsa_padding_mode: "pss", signs the raw data with RSA-PSS. |
| 810 | + // Otherwise delegates to PKey#sign (PKCS#1 v1.5). Non-Hash opts raise TypeError. |
| 811 | + @JRubyMethod(name = "sign", required = 2, optional = 1) |
| 812 | + public IRubyObject sign(ThreadContext context, IRubyObject[] args) { |
| 813 | + final Ruby runtime = context.runtime; |
| 814 | + final IRubyObject digest = args[0]; |
| 815 | + final IRubyObject data = args[1]; |
| 816 | + final IRubyObject opts = args.length > 2 ? args[2] : context.nil; |
| 817 | + |
| 818 | + if (!opts.isNil()) { |
| 819 | + if (!(opts instanceof RubyHash)) throw runtime.newTypeError("expected Hash"); |
| 820 | + String paddingMode = Utils.extractStringOpt(context, opts, "rsa_padding_mode", true); |
| 821 | + if ("pss".equalsIgnoreCase(paddingMode)) { |
| 822 | + if (privateKey == null) throw newPKeyError(runtime, "Private RSA key needed!"); |
| 823 | + final String digestAlg = getDigestAlgName(digest); |
| 824 | + int saltLen = Utils.extractIntOpt(context, opts, "rsa_pss_saltlen", -1, true); |
| 825 | + String mgf1Alg = Utils.extractStringOpt(context, opts, "rsa_mgf1_md", true); |
| 826 | + if (mgf1Alg == null) mgf1Alg = digestAlg; |
| 827 | + if (saltLen < 0) saltLen = maxPSSSaltLength(digestAlg, privateKey.getModulus().bitLength()); |
| 828 | + |
| 829 | + final byte[] signedData; |
| 830 | + try { |
| 831 | + signedData = signDataWithPSS(runtime, data.convertToString(), digestAlg, mgf1Alg, saltLen); |
| 832 | + } catch (IllegalArgumentException | DataLengthException | CryptoException e) { |
| 833 | + throw (RaiseException) newPKeyError(runtime, e.getMessage()).initCause(e); |
| 834 | + } |
| 835 | + return StringHelper.newString(runtime, signedData); |
| 836 | + } |
| 837 | + } |
| 838 | + return super.sign(digest, data); // PKCS#1 v1.5 fallback |
| 839 | + } |
| 840 | + |
| 841 | + // sign_pss(digest, data, salt_length:, mgf1_hash:) |
| 842 | + // Signs data with RSA-PSS. salt_length accepts :digest, :max, :auto, or an integer. |
| 843 | + @JRubyMethod(name = "sign_pss", required = 2, optional = 1) |
| 844 | + public IRubyObject sign_pss(ThreadContext context, IRubyObject[] args) { |
| 845 | + final Ruby runtime = context.runtime; |
| 846 | + if (privateKey == null) throw newPKeyError(runtime, "Private RSA key needed!"); |
| 847 | + final String digestAlg = getDigestAlgName(args[0]); |
| 848 | + final IRubyObject opts = args.length > 2 ? args[2] : context.nil; |
| 849 | + final int maxSalt = maxPSSSaltLength(digestAlg, privateKey.getModulus().bitLength()); |
| 850 | + |
| 851 | + String mgf1Alg = Utils.extractStringOpt(context, opts, "mgf1_hash"); |
| 852 | + if (mgf1Alg == null) mgf1Alg = digestAlg; |
| 853 | + |
| 854 | + final IRubyObject saltLenArg = opts instanceof RubyHash ? |
| 855 | + ((RubyHash) opts).fastARef(runtime.newSymbol("salt_length")) : null; |
| 856 | + final int saltLen; |
| 857 | + if (saltLenArg instanceof RubySymbol) { |
| 858 | + String sym = saltLenArg.asJavaString(); |
| 859 | + if ("digest".equals(sym)) saltLen = getDigestLength(digestAlg); |
| 860 | + else if ("max".equals(sym) || "auto".equals(sym)) saltLen = maxSalt; |
| 861 | + else throw runtime.newArgumentError("unknown salt_length: " + sym); |
| 862 | + } else if (saltLenArg != null && !saltLenArg.isNil()) { |
| 863 | + saltLen = RubyNumeric.fix2int(saltLenArg); |
| 864 | + } else { |
| 865 | + saltLen = maxSalt; |
| 866 | + } |
| 867 | + |
| 868 | + final byte[] signedData; |
| 869 | + try { |
| 870 | + signedData = signDataWithPSS(runtime, args[1].convertToString(), digestAlg, mgf1Alg, saltLen); |
| 871 | + } catch (IllegalArgumentException | DataLengthException | CryptoException e) { |
| 872 | + throw (RaiseException) newPKeyError(runtime, e.getMessage()).initCause(e); |
| 873 | + } |
| 874 | + return StringHelper.newString(runtime, signedData); |
| 875 | + } |
| 876 | + |
| 877 | + // verify_pss(digest, signature, data, salt_length:, mgf1_hash:) |
| 878 | + // Verifies a PSS signature. salt_length accepts :auto, :max, :digest, or an integer. |
| 879 | + @JRubyMethod(name = "verify_pss", required = 3, optional = 1) |
| 880 | + public IRubyObject verify_pss(ThreadContext context, IRubyObject[] args) { |
| 881 | + final Ruby runtime = context.runtime; |
| 882 | + final String digestAlg = getDigestAlgName(args[0]); |
| 883 | + final byte[] sigBytes = args[1].convertToString().getBytes(); |
| 884 | + final byte[] dataBytes = args[2].convertToString().getBytes(); |
| 885 | + final IRubyObject opts = args.length > 3 ? args[3] : context.nil; |
| 886 | + |
| 887 | + String mgf1Alg = Utils.extractStringOpt(context, opts, "mgf1_hash"); |
| 888 | + if (mgf1Alg == null) mgf1Alg = digestAlg; |
| 889 | + |
| 890 | + IRubyObject saltLenArg = opts instanceof RubyHash |
| 891 | + ? ((RubyHash) opts).fastARef(runtime.newSymbol("salt_length")) : null; |
| 892 | + int saltLen; |
| 893 | + if (saltLenArg instanceof RubySymbol) { |
| 894 | + String sym = saltLenArg.asJavaString(); |
| 895 | + if ("auto".equals(sym)) { |
| 896 | + saltLen = pssAutoSaltLength(publicKey, sigBytes, digestAlg, mgf1Alg); |
| 897 | + if (saltLen < 0) return runtime.getFalse(); |
| 898 | + } else if ("max".equals(sym)) { |
| 899 | + saltLen = maxPSSSaltLength(digestAlg, publicKey.getModulus().bitLength()); |
| 900 | + } else if ("digest".equals(sym)) { |
| 901 | + saltLen = getDigestLength(digestAlg); |
| 902 | + } else { |
| 903 | + throw runtime.newArgumentError("unknown salt_length: " + sym); |
| 904 | + } |
| 905 | + } else if (saltLenArg != null && !saltLenArg.isNil()) { |
| 906 | + saltLen = RubyNumeric.fix2int(saltLenArg); |
| 907 | + } else { |
| 908 | + saltLen = getDigestLength(digestAlg); |
| 909 | + } |
| 910 | + |
| 911 | + return verifyPSS(runtime, false, dataBytes, digestAlg, mgf1Alg, saltLen, sigBytes); |
| 912 | + } |
| 913 | + |
| 914 | + private IRubyObject verifyPSS(final Ruby runtime, final boolean rawVerify, |
| 915 | + final byte[] dataBytes, final String digestAlg, |
| 916 | + final String mgf1Alg, final int saltLen, final byte[] sigBytes) { |
| 917 | + boolean verified; |
| 918 | + try { |
| 919 | + verified = verifyWithPSS(rawVerify, publicKey, dataBytes, digestAlg, mgf1Alg, saltLen, sigBytes); |
| 920 | + } catch (IllegalArgumentException|IllegalStateException e) { |
| 921 | + verified = false; |
| 922 | + } catch (Exception e) { |
| 923 | + debugStackTrace(runtime, e); |
| 924 | + return runtime.getNil(); |
| 925 | + } |
| 926 | + return runtime.newBoolean(verified); |
| 927 | + } |
| 928 | + |
811 | 929 | private static byte[] buildDigestInfo(String digestAlg, byte[] hashBytes) throws IOException { |
812 | 930 | AlgorithmIdentifier algId = getDigestAlgId(digestAlg); |
813 | 931 | return new DigestInfo(algId, hashBytes).getEncoded("DER"); |
@@ -855,30 +973,28 @@ private static int getDigestLength(String digestAlg) { |
855 | 973 | // Signs pre-hashed bytes using RSA-PSS. PSSSigner internally reuses the content digest for |
856 | 974 | // BOTH hashing the message (phase 1) and hashing mDash (phase 2), so we use PreHashedDigest |
857 | 975 | // which passes through pre-hashed bytes verbatim in phase 1 and runs a real SHA hash in phase 2. |
858 | | - private byte[] signWithPSS(Ruby runtime, byte[] hashBytes, String digestAlg, String mgf1Alg, int saltLen) |
| 976 | + private byte[] signWithPSS(byte[] hashBytes, String digestAlg, String mgf1Alg, int saltLen) |
859 | 977 | throws CryptoException { |
860 | 978 | org.bouncycastle.crypto.Digest contentDigest = new PreHashedDigest(getDigestLength(digestAlg), digestAlg); |
861 | 979 | org.bouncycastle.crypto.Digest mgf1Digest = createBCDigest(mgf1Alg); |
862 | 980 | PSSSigner signer = new PSSSigner(new RSABlindedEngine(), contentDigest, mgf1Digest, saltLen); |
863 | 981 | RSAKeyParameters bcKey = toBCPrivateKeyParams(privateKey); |
864 | | - signer.init(true, new ParametersWithRandom(bcKey, getSecureRandom(runtime))); |
| 982 | + signer.init(true, new ParametersWithRandom(bcKey, getSecureRandom(getRuntime()))); |
865 | 983 | signer.update(hashBytes, 0, hashBytes.length); |
866 | 984 | return signer.generateSignature(); |
867 | 985 | } |
868 | 986 |
|
869 | | - // Verifies an RSA-PSS signature. When isRaw=true the input is a pre-computed hash (verify_raw); |
| 987 | + // Verifies an RSA-PSS signature. When rawVerify=true the input is a pre-computed hash (verify_raw); |
870 | 988 | // PreHashedDigest passes it through in phase 1 then uses a real SHA for hashing mDash in phase 2. |
871 | | - // When isRaw=false the input is raw data (verify with opts); a real SHA digest is used throughout. |
872 | | - private static boolean verifyWithPSS(RSAPublicKey pubKey, byte[] inputBytes, |
873 | | - String digestAlg, boolean isRaw, |
874 | | - String mgf1Alg, int saltLen, byte[] sigBytes) { |
875 | | - org.bouncycastle.crypto.Digest contentDigest = isRaw |
| 989 | + // When rawVerify=false the input is raw data (verify with opts); a real SHA digest is used throughout. |
| 990 | + private static boolean verifyWithPSS(final boolean rawVerify, RSAPublicKey pubKey, byte[] inputBytes, |
| 991 | + String digestAlg, String mgf1Alg, int saltLen, byte[] sigBytes) { |
| 992 | + org.bouncycastle.crypto.Digest contentDigest = rawVerify |
876 | 993 | ? new PreHashedDigest(getDigestLength(digestAlg), digestAlg) |
877 | 994 | : createBCDigest(digestAlg); |
878 | 995 | org.bouncycastle.crypto.Digest mgf1Digest = createBCDigest(mgf1Alg); |
879 | 996 | PSSSigner verifier = new PSSSigner(new RSABlindedEngine(), contentDigest, mgf1Digest, saltLen); |
880 | | - RSAKeyParameters bcPubKey = new RSAKeyParameters(false, pubKey.getModulus(), pubKey.getPublicExponent()); |
881 | | - verifier.init(false, bcPubKey); |
| 997 | + verifier.init(false, new RSAKeyParameters(false, pubKey.getModulus(), pubKey.getPublicExponent())); |
882 | 998 | verifier.update(inputBytes, 0, inputBytes.length); |
883 | 999 | return verifier.verifySignature(sigBytes); |
884 | 1000 | } |
@@ -951,6 +1067,73 @@ private static RSAKeyParameters toBCPrivateKeyParams(RSAPrivateKey privKey) { |
951 | 1067 | return new RSAKeyParameters(true, privKey.getModulus(), privKey.getPrivateExponent()); |
952 | 1068 | } |
953 | 1069 |
|
| 1070 | + // Signs raw (unhashed) data with RSA-PSS; PSSSigner applies the hash internally. |
| 1071 | + private byte[] signDataWithPSS(Ruby runtime, RubyString data, String digestAlg, String mgf1Alg, int saltLen) |
| 1072 | + throws CryptoException { |
| 1073 | + org.bouncycastle.crypto.Digest contentDigest = createBCDigest(digestAlg); |
| 1074 | + org.bouncycastle.crypto.Digest mgf1Digest = createBCDigest(mgf1Alg); |
| 1075 | + PSSSigner signer = new PSSSigner(new RSABlindedEngine(), contentDigest, mgf1Digest, saltLen); |
| 1076 | + signer.init(true, new ParametersWithRandom(toBCPrivateKeyParams(privateKey), getSecureRandom(runtime))); |
| 1077 | + final ByteList dataBytes = data.getByteList(); |
| 1078 | + signer.update(dataBytes.unsafeBytes(), dataBytes.getBegin(), dataBytes.getRealSize()); |
| 1079 | + return signer.generateSignature(); |
| 1080 | + } |
| 1081 | + |
| 1082 | + // Maximum PSS salt length per RFC 8017 §9.1.1: |
| 1083 | + // emLen = ceil((keyBits - 1) / 8), maxSalt = emLen - 2 - hLen |
| 1084 | + private static int maxPSSSaltLength(String digestAlg, int keyBits) { |
| 1085 | + int emLen = (keyBits - 1 + 7) / 8; |
| 1086 | + return emLen - 2 - getDigestLength(digestAlg); |
| 1087 | + } |
| 1088 | + |
| 1089 | + // Extracts the actual PSS salt length from a signature by parsing the PSS-encoded message. |
| 1090 | + // Returns -1 if the encoding is invalid (not a well-formed PSS block). |
| 1091 | + // This is used to implement salt_length: :auto in verify_pss. |
| 1092 | + private static int pssAutoSaltLength(RSAPublicKey pubKey, byte[] sigBytes, String digestAlg, String mgf1Alg) { |
| 1093 | + // Step 1: RSA public-key operation → encoded message (EM) |
| 1094 | + RSAKeyParameters bcPubKey = new RSAKeyParameters(false, pubKey.getModulus(), pubKey.getPublicExponent()); |
| 1095 | + RSABlindedEngine rsa = new RSABlindedEngine(); |
| 1096 | + rsa.init(false, bcPubKey); |
| 1097 | + byte[] em = rsa.processBlock(sigBytes, 0, sigBytes.length); |
| 1098 | + |
| 1099 | + int hLen = getDigestLength(digestAlg); |
| 1100 | + int emLen = em.length; |
| 1101 | + if (emLen < hLen + 2 || em[emLen - 1] != (byte) 0xBC) return -1; |
| 1102 | + |
| 1103 | + int dbLen = emLen - hLen - 1; |
| 1104 | + byte[] H = new byte[hLen]; |
| 1105 | + System.arraycopy(em, dbLen, H, 0, hLen); |
| 1106 | + |
| 1107 | + // Step 2: Recover DB = MGF1(H, dbLen) XOR maskedDB |
| 1108 | + byte[] DB = new byte[dbLen]; |
| 1109 | + System.arraycopy(em, 0, DB, 0, dbLen); |
| 1110 | + org.bouncycastle.crypto.Digest mgfDigest = createBCDigest(mgf1Alg); |
| 1111 | + int mgfHLen = mgfDigest.getDigestSize(); |
| 1112 | + byte[] hBuf = new byte[mgfHLen]; |
| 1113 | + byte[] ctr = new byte[4]; |
| 1114 | + for (int pos = 0, c = 0; pos < dbLen; c++) { |
| 1115 | + ctr[0] = (byte)(c >> 24); ctr[1] = (byte)(c >> 16); |
| 1116 | + ctr[2] = (byte)(c >> 8); ctr[3] = (byte) c; |
| 1117 | + mgfDigest.update(H, 0, hLen); |
| 1118 | + mgfDigest.update(ctr, 0, 4); |
| 1119 | + mgfDigest.doFinal(hBuf, 0); |
| 1120 | + int n = Math.min(mgfHLen, dbLen - pos); |
| 1121 | + for (int i = 0; i < n; i++) DB[pos + i] ^= hBuf[i]; |
| 1122 | + pos += n; |
| 1123 | + } |
| 1124 | + |
| 1125 | + // Step 3: Clear top bits per RFC 8017 §9.1.2 |
| 1126 | + int topBits = 8 * emLen - (pubKey.getModulus().bitLength() - 1); |
| 1127 | + if (topBits > 0) DB[0] &= (byte)(0xFF >>> topBits); |
| 1128 | + |
| 1129 | + // Step 4: Find the 0x01 separator; salt follows it |
| 1130 | + for (int i = 0; i < dbLen; i++) { |
| 1131 | + if (DB[i] == 0x01) return dbLen - i - 1; |
| 1132 | + if (DB[i] != 0x00) return -1; |
| 1133 | + } |
| 1134 | + return -1; |
| 1135 | + } |
| 1136 | + |
954 | 1137 | @JRubyMethod(name="d=") |
955 | 1138 | public synchronized IRubyObject set_d(final ThreadContext context, IRubyObject value) { |
956 | 1139 | if ( privateKey != null ) { |
|
0 commit comments