From 88eb21cd666712f202bfc4532f7f88c90ad58d5b Mon Sep 17 00:00:00 2001 From: "Evgeny @ SimpleX Chat" <259188159+evgeny-simplex@users.noreply.github.com> Date: Tue, 12 May 2026 17:27:46 +0000 Subject: [PATCH 1/8] smp web: protocol encodings and x3dh --- smp-web/package.json | 2 + smp-web/src/crypto/ratchet.ts | 99 +++++++++++++++++++++++++ smp-web/src/protocol.ts | 135 +++++++++++++++++++++++++++++++++- tests/SMPWebTests.hs | 99 ++++++++++++++++++++++++- 4 files changed, 331 insertions(+), 4 deletions(-) create mode 100644 smp-web/src/crypto/ratchet.ts diff --git a/smp-web/package.json b/smp-web/package.json index 996d2896b..e4105725f 100644 --- a/smp-web/package.json +++ b/smp-web/package.json @@ -17,6 +17,8 @@ "build": "tsc" }, "dependencies": { + "@noble/ciphers": "^2.2.0", + "@noble/curves": "^2.2.0", "@noble/hashes": "^1.5.0", "@simplex-chat/xftp-web": "file:../xftp-web" }, diff --git a/smp-web/src/crypto/ratchet.ts b/smp-web/src/crypto/ratchet.ts new file mode 100644 index 000000000..c8948ea05 --- /dev/null +++ b/smp-web/src/crypto/ratchet.ts @@ -0,0 +1,99 @@ +// Double ratchet with X3DH key agreement. +// Mirrors: Simplex.Messaging.Crypto.Ratchet + +import {x448} from "@noble/curves/ed448.js" +import {hkdf} from "../crypto.js" +import {concatBytes} from "@simplex-chat/xftp-web/dist/protocol/encoding.js" + +// -- X448 key operations + +export interface X448KeyPair { + publicKey: Uint8Array // 56 bytes + privateKey: Uint8Array // 56 bytes +} + +export function generateX448KeyPair(): X448KeyPair { + const privateKey = x448.utils.randomSecretKey() + const publicKey = x448.getPublicKey(privateKey) + return {publicKey, privateKey} +} + +export function x448DH(publicKey: Uint8Array, privateKey: Uint8Array): Uint8Array { + return x448.getSharedSecret(privateKey, publicKey) +} + +// DER encoding for X448 public keys (RFC 8410, SubjectPublicKeyInfo) +// SEQUENCE { SEQUENCE { OID 1.3.101.110 } BIT STRING { 0x00 <56 bytes> } } +const X448_PUBKEY_DER_PREFIX = new Uint8Array([ + 0x30, 0x42, 0x30, 0x05, 0x06, 0x03, 0x2b, 0x65, 0x6f, 0x03, 0x39, 0x00, +]) + +export function encodePubKeyX448(rawPubKey: Uint8Array): Uint8Array { + return concatBytes(X448_PUBKEY_DER_PREFIX, rawPubKey) +} + +export function decodePubKeyX448(der: Uint8Array): Uint8Array { + if (der.length !== 68) throw new Error("decodePubKeyX448: invalid length " + der.length) + for (let i = 0; i < X448_PUBKEY_DER_PREFIX.length; i++) { + if (der[i] !== X448_PUBKEY_DER_PREFIX[i]) throw new Error("decodePubKeyX448: invalid DER prefix") + } + return der.subarray(12) +} + +// -- X3DH key agreement (Ratchet.hs:499-508) + +export interface RatchetInitParams { + assocData: Uint8Array // pubKeyBytes(sk1) || pubKeyBytes(rk1) + ratchetKey: Uint8Array // 32 bytes (root key) + sndHK: Uint8Array // 32 bytes (header key) + rcvNextHK: Uint8Array // 32 bytes (next header key) +} + +// hkdf3 (Ratchet.hs:1174-1179) +// HKDF-SHA512, output 96 bytes, split 32+32+32 +function hkdf3(salt: Uint8Array, ikm: Uint8Array, info: string): [Uint8Array, Uint8Array, Uint8Array] { + const out = hkdf(salt, ikm, info, 96) + return [out.slice(0, 32), out.slice(32, 64), out.slice(64, 96)] +} + +const X3DH_SALT = new Uint8Array(64) // 64 zero bytes + +// pqX3dh (Ratchet.hs:499-508) +// Core X3DH: three DH results → HKDF → init params +function pqX3dh( + sk1: Uint8Array, rk1: Uint8Array, // public keys for assocData + dh1: Uint8Array, dh2: Uint8Array, dh3: Uint8Array, +): RatchetInitParams { + const assocData = concatBytes(sk1, rk1) + const dhs = concatBytes(dh1, dh2, dh3) // no PQ for MVP + const [hk, nhk, sk] = hkdf3(X3DH_SALT, dhs, "SimpleXX3DH") + return {assocData, ratchetKey: sk, sndHK: hk, rcvNextHK: nhk} +} + +// pqX3dhSnd (Ratchet.hs:467-480) +// Used by joiner (Bob) to initialize SENDING ratchet. +// Our keys: spk1, spk2 (private). Their keys: rk1, rk2 (public, from invitation). +export function pqX3dhSnd( + spk1: Uint8Array, spk2: Uint8Array, // our private keys + rk1: Uint8Array, rk2: Uint8Array, // their public keys (raw, not DER) +): RatchetInitParams { + const sk1Pub = x448.getPublicKey(spk1) + const dh1 = x448DH(rk1, spk2) + const dh2 = x448DH(rk2, spk1) + const dh3 = x448DH(rk2, spk2) + return pqX3dh(sk1Pub, rk1, dh1, dh2, dh3) +} + +// pqX3dhRcv (Ratchet.hs:483-497) +// Used by initiator (Alice) to initialize RECEIVING ratchet. +// Our keys: rpk1, rpk2 (private). Their keys: sk1, sk2 (public, from confirmation). +export function pqX3dhRcv( + rpk1: Uint8Array, rpk2: Uint8Array, // our private keys + sk1: Uint8Array, sk2: Uint8Array, // their public keys (raw, not DER) +): RatchetInitParams { + const rk1Pub = x448.getPublicKey(rpk1) + const dh1 = x448DH(sk2, rpk1) + const dh2 = x448DH(sk1, rpk2) + const dh3 = x448DH(sk2, rpk2) + return pqX3dh(sk1, rk1Pub, dh1, dh2, dh3) +} diff --git a/smp-web/src/protocol.ts b/smp-web/src/protocol.ts index 91ce86abc..dc707a71f 100644 --- a/smp-web/src/protocol.ts +++ b/smp-web/src/protocol.ts @@ -4,7 +4,9 @@ import { Decoder, concatBytes, encodeBytes, decodeBytes, - encodeLarge, decodeLarge + encodeLarge, decodeLarge, + encodeBool, decodeBool, + encodeMaybe, decodeMaybe, } from "@simplex-chat/xftp-web/dist/protocol/encoding.js" import {readTag, readSpace} from "@simplex-chat/xftp-web/dist/protocol/commands.js" @@ -83,7 +85,12 @@ export function decodeLNK(d: Decoder): LNKResponse { export type SMPResponse = | {type: "LNK", response: LNKResponse} + | {type: "IDS", response: IDSResponse} + | {type: "MSG", response: MSGResponse} | {type: "OK"} + | {type: "PONG"} + | {type: "END"} + | {type: "DELD"} | {type: "ERR", message: string} export function decodeResponse(d: Decoder): SMPResponse { @@ -93,7 +100,18 @@ export function decodeResponse(d: Decoder): SMPResponse { readSpace(d) return {type: "LNK", response: decodeLNK(d)} } + case "IDS": { + readSpace(d) + return {type: "IDS", response: decodeIDS(d)} + } + case "MSG": { + readSpace(d) + return {type: "MSG", response: decodeMSG(d)} + } case "OK": return {type: "OK"} + case "PONG": return {type: "PONG"} + case "END": return {type: "END"} + case "DELD": return {type: "DELD"} case "ERR": { readSpace(d) return {type: "ERR", message: readTag(d)} @@ -101,3 +119,118 @@ export function decodeResponse(d: Decoder): SMPResponse { default: throw new Error("unknown SMP response: " + tag) } } + +// -- SMP command encoders (Protocol.hs:1679-1715) + +// MsgFlags (Protocol.hs:884-892) +// Single byte: Bool encoding of notification flag +export function encodeMsgFlags(notification: boolean): Uint8Array { + return encodeBool(notification) +} + +// SubscriptionMode (Protocol.hs:651-659) +// 'S' = SMSubscribe, 'C' = SMOnlyCreate +export function encodeSubMode(subscribe: boolean): Uint8Array { + return ascii(subscribe ? "S" : "C") +} + +// NEW (Protocol.hs:1682-1689) +// For v19: e(NEW_, ' ', rKey, dhKey) <> e(auth_, subMode, queueReqData, ntfCreds) +// auth_ = Maybe SndPublicAuthKey (DER-encoded) +// queueReqData = Maybe QueueReqData +// ntfCreds = Maybe NewNtfCreds (not needed for widget) +export function encodeNEW( + rcvAuthKey: Uint8Array, // DER-encoded Ed25519 or X25519 public key + rcvDhKey: Uint8Array, // DER-encoded X25519 public key + sndAuthKey: Uint8Array | null, // DER-encoded, for TOFU sender auth + subscribe: boolean, +): Uint8Array { + return concatBytes( + ascii("NEW "), + encodeBytes(rcvAuthKey), + encodeBytes(rcvDhKey), + encodeMaybe(encodeBytes, sndAuthKey), + encodeSubMode(subscribe), + encodeMaybe(() => new Uint8Array(0), null), // queueReqData = Nothing (widget doesn't create links) + encodeMaybe(() => new Uint8Array(0), null), // ntfCreds = Nothing + ) +} + +// KEY (Protocol.hs:1692) +// KEY k -> e(KEY_, ' ', k) +export function encodeKEY(senderKey: Uint8Array): Uint8Array { + return concatBytes(ascii("KEY "), encodeBytes(senderKey)) +} + +// SKEY (Protocol.hs:1703) +// SKEY k -> e(SKEY_, ' ', k) +export function encodeSKEY(senderKey: Uint8Array): Uint8Array { + return concatBytes(ascii("SKEY "), encodeBytes(senderKey)) +} + +// SUB (Protocol.hs:1690) +export function encodeSUB(): Uint8Array { + return ascii("SUB") +} + +// ACK (Protocol.hs:1699) +// ACK msgId -> e(ACK_, ' ', msgId) +export function encodeACK(msgId: Uint8Array): Uint8Array { + return concatBytes(ascii("ACK "), encodeBytes(msgId)) +} + +// SEND (Protocol.hs:1704) +// SEND flags msg -> e(SEND_, ' ', flags, ' ', Tail msg) +export function encodeSEND(notification: boolean, msgBody: Uint8Array): Uint8Array { + return concatBytes( + ascii("SEND "), + encodeMsgFlags(notification), + ascii(" "), + msgBody, // Tail - no length prefix + ) +} + +// OFF (Protocol.hs:1700) +export function encodeOFF(): Uint8Array { + return ascii("OFF") +} + +// DEL (Protocol.hs:1701) +export function encodeDEL(): Uint8Array { + return ascii("DEL") +} + +// -- SMP response decoders + +// IDS (Protocol.hs:1914-1921) +// For v19: e(IDS_, ' ', rcvId, sndId, srvDh) <> e(queueMode, linkId, serviceId, ntfCreds) +export interface IDSResponse { + rcvId: Uint8Array + sndId: Uint8Array + srvDhKey: Uint8Array + queueMode: Uint8Array | null + linkId: Uint8Array | null +} + +export function decodeIDS(d: Decoder): IDSResponse { + const rcvId = decodeBytes(d) + const sndId = decodeBytes(d) + const srvDhKey = decodeBytes(d) + const queueMode = d.remaining() > 0 ? decodeMaybe(decodeBytes, d) : null + const linkId = d.remaining() > 0 ? decodeMaybe(decodeBytes, d) : null + // serviceId and ntfCreds - skip remaining + return {rcvId, sndId, srvDhKey, queueMode, linkId} +} + +// MSG (Protocol.hs:1927-1928) +// MSG RcvMessage {msgId, msgBody = EncRcvMsgBody body} -> e(MSG_, ' ', msgId, Tail body) +export interface MSGResponse { + msgId: Uint8Array + msgBody: Uint8Array +} + +export function decodeMSG(d: Decoder): MSGResponse { + const msgId = decodeBytes(d) + const msgBody = d.takeAll() + return {msgId, msgBody} +} diff --git a/tests/SMPWebTests.hs b/tests/SMPWebTests.hs index 44db83693..6b3409099 100644 --- a/tests/SMPWebTests.hs +++ b/tests/SMPWebTests.hs @@ -12,6 +12,7 @@ -- Run: cabal test --test-option=--match="/SMP Web Client/" module SMPWebTests (smpWebTests) where +import Control.Concurrent.STM (atomically) import Control.Monad.Except (ExceptT, runExceptT) import qualified Data.ByteString as B import qualified Data.ByteString.Char8 as BC @@ -29,11 +30,11 @@ import qualified Simplex.Messaging.Crypto.Ratchet as CR import Simplex.Messaging.Crypto.ShortLink (contactShortLinkKdf, invShortLinkKdf) import Simplex.Messaging.Encoding import Simplex.Messaging.Encoding.String (strEncode) -import Simplex.Messaging.Protocol (EntityId (..), SMPServer, SubscriptionMode (..), pattern SMPServer) +import Simplex.Messaging.Protocol (EntityId (..), SMPServer, SubscriptionMode (..), MsgFlags (..), pattern SMPServer, encodeProtocol, Command (..), NewQueueReq (..), BrokerMsg (..), RcvMessage (..), EncRcvMsgBody (..), QueueIdsKeys (..)) import Simplex.Messaging.Server.Env.STM (AStoreType (..)) import Simplex.Messaging.Server.MsgStore.Types (SMSType (..), SQSType (..)) import Simplex.Messaging.Server.Web (attachStaticAndWS) -import Simplex.Messaging.Transport (TLS, smpBlockSize) +import Simplex.Messaging.Transport (TLS, smpBlockSize, currentServerSMPRelayVersion) import Simplex.Messaging.Transport.Client (TransportHost (..)) import SMPAgentClient (agentCfg, initAgentServers, testDB) import SMPClient (cfgWebOn, testKeyHash, testPort, withSmpServerConfig) @@ -52,7 +53,7 @@ impEnc :: String impEnc = "import { Decoder, decodeLarge } from '@simplex-chat/xftp-web/dist/protocol/encoding.js';" impProto_ :: String -impProto_ = "import { encodeTransmission, encodeBatch, decodeTransmission, encodeLGET, decodeLNK, decodeResponse } from './dist/protocol.js';" +impProto_ = "import { encodeTransmission, encodeBatch, decodeTransmission, encodeLGET, decodeLNK, decodeResponse, encodeNEW, encodeKEY, encodeSKEY, encodeSUB, encodeACK, encodeSEND, encodeOFF, encodeDEL } from './dist/protocol.js';" impProto :: String impProto = impEnc <> impProto_ @@ -74,6 +75,9 @@ impAgentProto = impEnc <> impAgentProto_ impCryptoShortLink :: String impCryptoShortLink = "import { contactShortLinkKdf, invShortLinkKdf, decryptLinkData } from './dist/crypto/shortLink.js';" +impRatchet :: String +impRatchet = "import { generateX448KeyPair, pqX3dhSnd, pqX3dhRcv, x448DH, encodePubKeyX448, decodePubKeyX448 } from './dist/crypto/ratchet.js';" + impCrypto :: String impCrypto = "import { sbcInit, sbcHkdf, sbEncryptBlock, sbDecryptBlock } from './dist/crypto.js';" @@ -160,6 +164,67 @@ smpWebTests_ = do <> jsOut ("new Uint8Array([r.type === 'OK' ? 1 : 0])") tsResult `shouldBe` B.singleton 1 + describe "commands" $ do + let v = currentServerSMPRelayVersion + + it "encodeSUB matches Haskell" $ do + let hsEncoded = encodeProtocol v SUB + tsEncoded <- callNode $ impProto <> jsOut "encodeSUB()" + tsEncoded `shouldBe` hsEncoded + + it "encodeKEY matches Haskell" $ do + let keyDer = B.pack [0x30, 0x2a, 0x30, 0x05, 0x06, 0x03, 0x2b, 0x65, 0x6e, 0x03, 0x21, 0x00] <> B.pack [1..32] + hsEncoded = "KEY " <> smpEncode keyDer + tsEncoded <- callNode $ impProto + <> jsOut ("encodeKEY(" <> jsUint8 keyDer <> ")") + tsEncoded `shouldBe` hsEncoded + + it "encodeSKEY matches Haskell" $ do + let keyDer = B.pack [0x30, 0x2a, 0x30, 0x05, 0x06, 0x03, 0x2b, 0x65, 0x6e, 0x03, 0x21, 0x00] <> B.pack [1..32] + hsEncoded = "SKEY " <> smpEncode keyDer + tsEncoded <- callNode $ impProto + <> jsOut ("encodeSKEY(" <> jsUint8 keyDer <> ")") + tsEncoded `shouldBe` hsEncoded + + it "encodeACK matches Haskell" $ do + let msgId = B.pack [1..24] + hsEncoded = encodeProtocol v (ACK msgId) + tsEncoded <- callNode $ impProto + <> jsOut ("encodeACK(" <> jsUint8 msgId <> ")") + tsEncoded `shouldBe` hsEncoded + + it "encodeSEND matches Haskell" $ do + let flags = MsgFlags {notification = True} + body = "hello world" + hsEncoded = encodeProtocol v (SEND flags body) + tsEncoded <- callNode $ impProto + <> jsOut ("encodeSEND(true, new TextEncoder().encode('hello world'))") + tsEncoded `shouldBe` hsEncoded + + it "decodes IDS response" $ do + let rcvId = B.pack [1..24] + sndId = B.pack [25..48] + srvDhKey = B.pack [0x30, 0x2a, 0x30, 0x05, 0x06, 0x03, 0x2b, 0x65, 0x6e, 0x03, 0x21, 0x00] <> B.pack [50..81] + -- Manually encode IDS response: "IDS " <> rcvId <> sndId <> srvDhKey <> Maybe queueMode <> Maybe linkId ... + encoded = "IDS " <> smpEncode (EntityId rcvId) <> smpEncode (EntityId sndId) <> smpEncode srvDhKey + <> smpEncode (Nothing :: Maybe B.ByteString) <> smpEncode (Nothing :: Maybe B.ByteString) + <> smpEncode (Nothing :: Maybe B.ByteString) <> smpEncode (Nothing :: Maybe B.ByteString) + tsResult <- callNode $ impProto + <> "const r = decodeResponse(new Decoder(" <> jsUint8 encoded <> "));" + <> "if (r.type !== 'IDS') throw new Error('expected IDS, got ' + r.type);" + <> jsOut ("new Uint8Array([...r.response.rcvId, ...r.response.sndId])") + tsResult `shouldBe` (rcvId <> sndId) + + it "decodes Haskell-encoded MSG response" $ do + let msgId = B.pack [1..24] + body = "encrypted message body" + hsEncoded = "MSG " <> smpEncode msgId <> body + tsResult <- callNode $ impProto + <> "const r = decodeResponse(new Decoder(" <> jsUint8 hsEncoded <> "));" + <> "if (r.type !== 'MSG') throw new Error('expected MSG, got ' + r.type);" + <> jsOut ("new Uint8Array([...r.response.msgId, ...r.response.msgBody])") + tsResult `shouldBe` (msgId <> body) + describe "transport" $ do describe "SMPServerHandshake" $ do it "TypeScript parses Haskell-encoded server handshake (no authPubKey)" $ do @@ -237,6 +302,34 @@ smpWebTests_ = do <> jsOut ("new Uint8Array([...r.fixedData, 0, ...r.userData])") tsResult `shouldBe` (fixedPlain <> B.singleton 0 <> userPlain) + describe "crypto/ratchet" $ do + describe "X3DH" $ do + it "pqX3dhSnd and pqX3dhRcv produce same ratchetKey" $ do + -- TypeScript generates two key pairs, computes X3DH from both sides, verifies match + tsResult <- callNode $ impSodium <> impRatchet + <> "const alice1 = generateX448KeyPair();" + <> "const alice2 = generateX448KeyPair();" + <> "const bob1 = generateX448KeyPair();" + <> "const bob2 = generateX448KeyPair();" + -- Bob (joiner) inits sending ratchet with Alice's public keys + <> "const snd = pqX3dhSnd(bob1.privateKey, bob2.privateKey, alice1.publicKey, alice2.publicKey);" + -- Alice (initiator) inits receiving ratchet with Bob's public keys + <> "const rcv = pqX3dhRcv(alice1.privateKey, alice2.privateKey, bob1.publicKey, bob2.publicKey);" + -- ratchetKey, sndHK, rcvNextHK should match + <> "const match = snd.ratchetKey.every((b, i) => b === rcv.ratchetKey[i]) && snd.sndHK.every((b, i) => b === rcv.sndHK[i]) && snd.rcvNextHK.every((b, i) => b === rcv.rcvNextHK[i]);" + <> jsOut ("new Uint8Array([match ? 1 : 0, snd.ratchetKey.length, snd.sndHK.length, snd.rcvNextHK.length])") + tsResult `shouldBe` B.pack [1, 32, 32, 32] + + describe "DER encoding" $ do + it "X448 DER round-trips" $ do + tsResult <- callNode $ impRatchet + <> "const kp = generateX448KeyPair();" + <> "const der = encodePubKeyX448(kp.publicKey);" + <> "const raw = decodePubKeyX448(der);" + <> "const match = kp.publicKey.every((b, i) => b === raw[i]);" + <> jsOut ("new Uint8Array([match ? 1 : 0, der.length, raw.length])") + tsResult `shouldBe` B.pack [1, 68, 56] + describe "crypto/blockEncryption" $ do describe "sbcInit + sbcHkdf" $ do it "TypeScript produces same sbKey/nonce via sbcInit+sbcHkdf as Haskell" $ do From 0a505427eae074ea43741a90930c0f983f771a7e Mon Sep 17 00:00:00 2001 From: "Evgeny @ SimpleX Chat" <259188159+evgeny-simplex@users.noreply.github.com> Date: Tue, 12 May 2026 17:40:24 +0000 Subject: [PATCH 2/8] fix --- smp-web/src/protocol.ts | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/smp-web/src/protocol.ts b/smp-web/src/protocol.ts index dc707a71f..aef8d7e68 100644 --- a/smp-web/src/protocol.ts +++ b/smp-web/src/protocol.ts @@ -208,7 +208,7 @@ export interface IDSResponse { rcvId: Uint8Array sndId: Uint8Array srvDhKey: Uint8Array - queueMode: Uint8Array | null + queueMode: string | null // 'M' = Messaging, 'C' = Contact linkId: Uint8Array | null } @@ -216,8 +216,20 @@ export function decodeIDS(d: Decoder): IDSResponse { const rcvId = decodeBytes(d) const sndId = decodeBytes(d) const srvDhKey = decodeBytes(d) - const queueMode = d.remaining() > 0 ? decodeMaybe(decodeBytes, d) : null - const linkId = d.remaining() > 0 ? decodeMaybe(decodeBytes, d) : null + // v19: queueMode (Maybe QueueMode), linkId (Maybe ByteString), serviceId, ntfCreds + // QueueMode is encoded as Maybe Char ('M'/'C'), not Maybe ByteString + let queueMode: string | null = null + if (d.remaining() > 0) { + const qmByte = d.anyByte() + if (qmByte === 0x31) { // '1' = Just + queueMode = String.fromCharCode(d.anyByte()) + } + // '0' = Nothing, queueMode stays null + } + let linkId: Uint8Array | null = null + if (d.remaining() > 0) { + linkId = decodeMaybe(decodeBytes, d) + } // serviceId and ntfCreds - skip remaining return {rcvId, sndId, srvDhKey, queueMode, linkId} } From 69ae8d31fedae83d9345663b5fb09deb4009c232 Mon Sep 17 00:00:00 2001 From: "Evgeny @ SimpleX Chat" <259188159+evgeny-simplex@users.noreply.github.com> Date: Tue, 12 May 2026 21:57:59 +0000 Subject: [PATCH 3/8] strnup761 compiled to wasm --- smp-web/cbits/js_random.js | 14 ++ smp-web/cbits/sha512.c | 315 ++++++++++++++++++++++++++++++++ smp-web/cbits/sntrup761.d.mts | 11 ++ smp-web/cbits/sntrup761_wasm.c | 34 ++++ smp-web/package.json | 7 +- smp-web/src/crypto/sntrup761.ts | 89 +++++++++ tests/SMPWebTests.hs | 35 ++++ 7 files changed, 503 insertions(+), 2 deletions(-) create mode 100644 smp-web/cbits/js_random.js create mode 100644 smp-web/cbits/sha512.c create mode 100644 smp-web/cbits/sntrup761.d.mts create mode 100644 smp-web/cbits/sntrup761_wasm.c create mode 100644 smp-web/src/crypto/sntrup761.ts diff --git a/smp-web/cbits/js_random.js b/smp-web/cbits/js_random.js new file mode 100644 index 000000000..7c2eececf --- /dev/null +++ b/smp-web/cbits/js_random.js @@ -0,0 +1,14 @@ +addToLibrary({ + js_random_bytes: function(buf, len) { + var bytes = new Uint8Array(len); + if (typeof crypto !== 'undefined' && crypto.getRandomValues) { + crypto.getRandomValues(bytes); + } else { + // Node.js fallback + var nodeCrypto = require('crypto'); + var nodeBytes = nodeCrypto.randomBytes(len); + bytes.set(nodeBytes); + } + HEAPU8.set(bytes, buf); + } +}); diff --git a/smp-web/cbits/sha512.c b/smp-web/cbits/sha512.c new file mode 100644 index 000000000..8045fa8ee --- /dev/null +++ b/smp-web/cbits/sha512.c @@ -0,0 +1,315 @@ +/* +20080913 +D. J. Bernstein +Public domain. + +SHA-512 implementation from SUPERCOP/NaCl. +Source: https://bench.cr.yp.to/supercop.html + crypto_hashblocks/sha512/ref/blocks.c + crypto_hash/sha512/ref/hash.c + +Combined into a single file for WASM compilation alongside sntrup761. +*/ + +#include "sha512.h" + +typedef unsigned long long uint64; + +/* -- crypto_hashblocks_sha512 (blocks.c) -- */ + +static uint64 load_bigendian(const unsigned char *x) +{ + return + (uint64) (x[7]) \ + | (((uint64) (x[6])) << 8) \ + | (((uint64) (x[5])) << 16) \ + | (((uint64) (x[4])) << 24) \ + | (((uint64) (x[3])) << 32) \ + | (((uint64) (x[2])) << 40) \ + | (((uint64) (x[1])) << 48) \ + | (((uint64) (x[0])) << 56) + ; +} + +static void store_bigendian(unsigned char *x,uint64 u) +{ + x[7] = u; u >>= 8; + x[6] = u; u >>= 8; + x[5] = u; u >>= 8; + x[4] = u; u >>= 8; + x[3] = u; u >>= 8; + x[2] = u; u >>= 8; + x[1] = u; u >>= 8; + x[0] = u; +} + +#define SHR(x,c) ((x) >> (c)) +#define ROTR(x,c) (((x) >> (c)) | ((x) << (64 - (c)))) + +#define Ch(x,y,z) ((x & y) ^ (~x & z)) +#define Maj(x,y,z) ((x & y) ^ (x & z) ^ (y & z)) +#define Sigma0(x) (ROTR(x,28) ^ ROTR(x,34) ^ ROTR(x,39)) +#define Sigma1(x) (ROTR(x,14) ^ ROTR(x,18) ^ ROTR(x,41)) +#define sigma0(x) (ROTR(x, 1) ^ ROTR(x, 8) ^ SHR(x,7)) +#define sigma1(x) (ROTR(x,19) ^ ROTR(x,61) ^ SHR(x,6)) + +#define M(w0,w14,w9,w1) w0 = sigma1(w14) + w9 + sigma0(w1) + w0; + +#define EXPAND \ + M(w0 ,w14,w9 ,w1 ) \ + M(w1 ,w15,w10,w2 ) \ + M(w2 ,w0 ,w11,w3 ) \ + M(w3 ,w1 ,w12,w4 ) \ + M(w4 ,w2 ,w13,w5 ) \ + M(w5 ,w3 ,w14,w6 ) \ + M(w6 ,w4 ,w15,w7 ) \ + M(w7 ,w5 ,w0 ,w8 ) \ + M(w8 ,w6 ,w1 ,w9 ) \ + M(w9 ,w7 ,w2 ,w10) \ + M(w10,w8 ,w3 ,w11) \ + M(w11,w9 ,w4 ,w12) \ + M(w12,w10,w5 ,w13) \ + M(w13,w11,w6 ,w14) \ + M(w14,w12,w7 ,w15) \ + M(w15,w13,w8 ,w0 ) + +#define F(w,k) \ + T1 = h + Sigma1(e) + Ch(e,f,g) + k + w; \ + T2 = Sigma0(a) + Maj(a,b,c); \ + h = g; \ + g = f; \ + f = e; \ + e = d + T1; \ + d = c; \ + c = b; \ + b = a; \ + a = T1 + T2; + +static int crypto_hashblocks_sha512(unsigned char *statebytes,const unsigned char *in,unsigned long long inlen) +{ + uint64 state[8]; + uint64 a; + uint64 b; + uint64 c; + uint64 d; + uint64 e; + uint64 f; + uint64 g; + uint64 h; + uint64 T1; + uint64 T2; + + a = load_bigendian(statebytes + 0); state[0] = a; + b = load_bigendian(statebytes + 8); state[1] = b; + c = load_bigendian(statebytes + 16); state[2] = c; + d = load_bigendian(statebytes + 24); state[3] = d; + e = load_bigendian(statebytes + 32); state[4] = e; + f = load_bigendian(statebytes + 40); state[5] = f; + g = load_bigendian(statebytes + 48); state[6] = g; + h = load_bigendian(statebytes + 56); state[7] = h; + + while (inlen >= 128) { + uint64 w0 = load_bigendian(in + 0); + uint64 w1 = load_bigendian(in + 8); + uint64 w2 = load_bigendian(in + 16); + uint64 w3 = load_bigendian(in + 24); + uint64 w4 = load_bigendian(in + 32); + uint64 w5 = load_bigendian(in + 40); + uint64 w6 = load_bigendian(in + 48); + uint64 w7 = load_bigendian(in + 56); + uint64 w8 = load_bigendian(in + 64); + uint64 w9 = load_bigendian(in + 72); + uint64 w10 = load_bigendian(in + 80); + uint64 w11 = load_bigendian(in + 88); + uint64 w12 = load_bigendian(in + 96); + uint64 w13 = load_bigendian(in + 104); + uint64 w14 = load_bigendian(in + 112); + uint64 w15 = load_bigendian(in + 120); + + F(w0 ,0x428a2f98d728ae22ULL) + F(w1 ,0x7137449123ef65cdULL) + F(w2 ,0xb5c0fbcfec4d3b2fULL) + F(w3 ,0xe9b5dba58189dbbcULL) + F(w4 ,0x3956c25bf348b538ULL) + F(w5 ,0x59f111f1b605d019ULL) + F(w6 ,0x923f82a4af194f9bULL) + F(w7 ,0xab1c5ed5da6d8118ULL) + F(w8 ,0xd807aa98a3030242ULL) + F(w9 ,0x12835b0145706fbeULL) + F(w10,0x243185be4ee4b28cULL) + F(w11,0x550c7dc3d5ffb4e2ULL) + F(w12,0x72be5d74f27b896fULL) + F(w13,0x80deb1fe3b1696b1ULL) + F(w14,0x9bdc06a725c71235ULL) + F(w15,0xc19bf174cf692694ULL) + + EXPAND + + F(w0 ,0xe49b69c19ef14ad2ULL) + F(w1 ,0xefbe4786384f25e3ULL) + F(w2 ,0x0fc19dc68b8cd5b5ULL) + F(w3 ,0x240ca1cc77ac9c65ULL) + F(w4 ,0x2de92c6f592b0275ULL) + F(w5 ,0x4a7484aa6ea6e483ULL) + F(w6 ,0x5cb0a9dcbd41fbd4ULL) + F(w7 ,0x76f988da831153b5ULL) + F(w8 ,0x983e5152ee66dfabULL) + F(w9 ,0xa831c66d2db43210ULL) + F(w10,0xb00327c898fb213fULL) + F(w11,0xbf597fc7beef0ee4ULL) + F(w12,0xc6e00bf33da88fc2ULL) + F(w13,0xd5a79147930aa725ULL) + F(w14,0x06ca6351e003826fULL) + F(w15,0x142929670a0e6e70ULL) + + EXPAND + + F(w0 ,0x27b70a8546d22ffcULL) + F(w1 ,0x2e1b21385c26c926ULL) + F(w2 ,0x4d2c6dfc5ac42aedULL) + F(w3 ,0x53380d139d95b3dfULL) + F(w4 ,0x650a73548baf63deULL) + F(w5 ,0x766a0abb3c77b2a8ULL) + F(w6 ,0x81c2c92e47edaee6ULL) + F(w7 ,0x92722c851482353bULL) + F(w8 ,0xa2bfe8a14cf10364ULL) + F(w9 ,0xa81a664bbc423001ULL) + F(w10,0xc24b8b70d0f89791ULL) + F(w11,0xc76c51a30654be30ULL) + F(w12,0xd192e819d6ef5218ULL) + F(w13,0xd69906245565a910ULL) + F(w14,0xf40e35855771202aULL) + F(w15,0x106aa07032bbd1b8ULL) + + EXPAND + + F(w0 ,0x19a4c116b8d2d0c8ULL) + F(w1 ,0x1e376c085141ab53ULL) + F(w2 ,0x2748774cdf8eeb99ULL) + F(w3 ,0x34b0bcb5e19b48a8ULL) + F(w4 ,0x391c0cb3c5c95a63ULL) + F(w5 ,0x4ed8aa4ae3418acbULL) + F(w6 ,0x5b9cca4f7763e373ULL) + F(w7 ,0x682e6ff3d6b2b8a3ULL) + F(w8 ,0x748f82ee5defb2fcULL) + F(w9 ,0x78a5636f43172f60ULL) + F(w10,0x84c87814a1f0ab72ULL) + F(w11,0x8cc702081a6439ecULL) + F(w12,0x90befffa23631e28ULL) + F(w13,0xa4506cebde82bde9ULL) + F(w14,0xbef9a3f7b2c67915ULL) + F(w15,0xc67178f2e372532bULL) + + EXPAND + + F(w0 ,0xca273eceea26619cULL) + F(w1 ,0xd186b8c721c0c207ULL) + F(w2 ,0xeada7dd6cde0eb1eULL) + F(w3 ,0xf57d4f7fee6ed178ULL) + F(w4 ,0x06f067aa72176fbaULL) + F(w5 ,0x0a637dc5a2c898a6ULL) + F(w6 ,0x113f9804bef90daeULL) + F(w7 ,0x1b710b35131c471bULL) + F(w8 ,0x28db77f523047d84ULL) + F(w9 ,0x32caab7b40c72493ULL) + F(w10,0x3c9ebe0a15c9bebcULL) + F(w11,0x431d67c49c100d4cULL) + F(w12,0x4cc5d4becb3e42b6ULL) + F(w13,0x597f299cfc657e2aULL) + F(w14,0x5fcb6fab3ad6faecULL) + F(w15,0x6c44198c4a475817ULL) + + a += state[0]; + b += state[1]; + c += state[2]; + d += state[3]; + e += state[4]; + f += state[5]; + g += state[6]; + h += state[7]; + + state[0] = a; + state[1] = b; + state[2] = c; + state[3] = d; + state[4] = e; + state[5] = f; + state[6] = g; + state[7] = h; + + in += 128; + inlen -= 128; + } + + store_bigendian(statebytes + 0,state[0]); + store_bigendian(statebytes + 8,state[1]); + store_bigendian(statebytes + 16,state[2]); + store_bigendian(statebytes + 24,state[3]); + store_bigendian(statebytes + 32,state[4]); + store_bigendian(statebytes + 40,state[5]); + store_bigendian(statebytes + 48,state[6]); + store_bigendian(statebytes + 56,state[7]); + + return inlen; +} + +/* -- crypto_hash_sha512 (hash.c) -- */ + +static const unsigned char iv[64] = { + 0x6a,0x09,0xe6,0x67,0xf3,0xbc,0xc9,0x08, + 0xbb,0x67,0xae,0x85,0x84,0xca,0xa7,0x3b, + 0x3c,0x6e,0xf3,0x72,0xfe,0x94,0xf8,0x2b, + 0xa5,0x4f,0xf5,0x3a,0x5f,0x1d,0x36,0xf1, + 0x51,0x0e,0x52,0x7f,0xad,0xe6,0x82,0xd1, + 0x9b,0x05,0x68,0x8c,0x2b,0x3e,0x6c,0x1f, + 0x1f,0x83,0xd9,0xab,0xfb,0x41,0xbd,0x6b, + 0x5b,0xe0,0xcd,0x19,0x13,0x7e,0x21,0x79 +}; + +void crypto_hash_sha512(unsigned char *out, + const unsigned char *in, + unsigned long long inlen) +{ + unsigned char h[64]; + unsigned char padded[256]; + int i; + unsigned long long bytes = inlen; + + for (i = 0;i < 64;++i) h[i] = iv[i]; + + crypto_hashblocks_sha512(h,in,inlen); + in += inlen; + inlen &= 127; + in -= inlen; + + for (i = 0;i < (int)inlen;++i) padded[i] = in[i]; + padded[inlen] = 0x80; + + if (inlen < 112) { + for (i = inlen + 1;i < 119;++i) padded[i] = 0; + padded[119] = bytes >> 61; + padded[120] = bytes >> 53; + padded[121] = bytes >> 45; + padded[122] = bytes >> 37; + padded[123] = bytes >> 29; + padded[124] = bytes >> 21; + padded[125] = bytes >> 13; + padded[126] = bytes >> 5; + padded[127] = bytes << 3; + crypto_hashblocks_sha512(h,padded,128); + } else { + for (i = inlen + 1;i < 247;++i) padded[i] = 0; + padded[247] = bytes >> 61; + padded[248] = bytes >> 53; + padded[249] = bytes >> 45; + padded[250] = bytes >> 37; + padded[251] = bytes >> 29; + padded[252] = bytes >> 21; + padded[253] = bytes >> 13; + padded[254] = bytes >> 5; + padded[255] = bytes << 3; + crypto_hashblocks_sha512(h,padded,256); + } + + for (i = 0;i < 64;++i) out[i] = h[i]; +} diff --git a/smp-web/cbits/sntrup761.d.mts b/smp-web/cbits/sntrup761.d.mts new file mode 100644 index 000000000..2f603d0da --- /dev/null +++ b/smp-web/cbits/sntrup761.d.mts @@ -0,0 +1,11 @@ +interface Sntrup761Module { + _sntrup761_wasm_keypair(pk: number, sk: number): void + _sntrup761_wasm_enc(c: number, k: number, pk: number): void + _sntrup761_wasm_dec(k: number, c: number, sk: number): void + _malloc(size: number): number + _free(ptr: number): void + HEAPU8: Uint8Array +} + +declare function createSntrup761(): Promise +export default createSntrup761 diff --git a/smp-web/cbits/sntrup761_wasm.c b/smp-web/cbits/sntrup761_wasm.c new file mode 100644 index 000000000..857db3506 --- /dev/null +++ b/smp-web/cbits/sntrup761_wasm.c @@ -0,0 +1,34 @@ +/* + * WASM wrapper for sntrup761. + * Provides JS-callable functions with RNG from JS imports. + * + * Build: emcc sntrup761_wasm.c sntrup761.c sha512.c -O2 -o sntrup761.js \ + * -s EXPORTED_FUNCTIONS='["_sntrup761_wasm_keypair","_sntrup761_wasm_enc","_sntrup761_wasm_dec","_malloc","_free"]' \ + * -s EXPORTED_RUNTIME_METHODS='["ccall","cwrap"]' + */ + +#include "sntrup761.h" +#include + +/* Import RNG from JS environment */ +extern void js_random_bytes(unsigned char *buf, int len); + +/* RNG callback adapter for sntrup761 */ +static void wasm_random(void *ctx, size_t length, uint8_t *dst) { + (void)ctx; + js_random_bytes(dst, (int)length); +} + +/* JS-callable wrappers */ + +void sntrup761_wasm_keypair(unsigned char *pk, unsigned char *sk) { + sntrup761_keypair(pk, sk, NULL, wasm_random); +} + +void sntrup761_wasm_enc(unsigned char *c, unsigned char *k, const unsigned char *pk) { + sntrup761_enc(c, k, pk, NULL, wasm_random); +} + +void sntrup761_wasm_dec(unsigned char *k, const unsigned char *c, const unsigned char *sk) { + sntrup761_dec(k, c, sk); +} diff --git a/smp-web/package.json b/smp-web/package.json index e4105725f..33b5d52d4 100644 --- a/smp-web/package.json +++ b/smp-web/package.json @@ -14,13 +14,16 @@ "dist" ], "scripts": { - "build": "tsc" + "build:wasm": "mkdir -p dist/wasm && npx emcc cbits/sntrup761_wasm.c ../cbits/sntrup761.c cbits/sha512.c -I../cbits -O2 -o dist/wasm/sntrup761.mjs -s EXPORTED_FUNCTIONS='[\"_sntrup761_wasm_keypair\",\"_sntrup761_wasm_enc\",\"_sntrup761_wasm_dec\",\"_malloc\",\"_free\"]' -s EXPORTED_RUNTIME_METHODS='[\"ccall\",\"cwrap\",\"HEAPU8\"]' -s MODULARIZE=1 -s EXPORT_NAME='createSntrup761' -s ALLOW_MEMORY_GROWTH=1 -s ENVIRONMENT='web,node' --js-library cbits/js_random.js && cp cbits/sntrup761.d.mts dist/wasm/", + "build:ts": "tsc", + "build": "npm run build:wasm && npm run build:ts" }, "dependencies": { "@noble/ciphers": "^2.2.0", "@noble/curves": "^2.2.0", "@noble/hashes": "^1.5.0", - "@simplex-chat/xftp-web": "file:../xftp-web" + "@simplex-chat/xftp-web": "file:../xftp-web", + "emsdk": "^0.4.0" }, "devDependencies": { "@types/node": "^25.5.0", diff --git a/smp-web/src/crypto/sntrup761.ts b/smp-web/src/crypto/sntrup761.ts new file mode 100644 index 000000000..8914607ef --- /dev/null +++ b/smp-web/src/crypto/sntrup761.ts @@ -0,0 +1,89 @@ +// SNTRUP761 post-quantum KEM. +// Mirrors: Simplex.Messaging.Crypto.SNTRUP761 +// +// Uses WASM compiled from the same C source as the Haskell build +// (cbits/sntrup761.c by djb et al., public domain). +// SHA-512 from SUPERCOP/NaCl (djb, public domain). + +// Key sizes (from sntrup761.h) +export const SNTRUP761_PUBLICKEY_SIZE = 1158 +export const SNTRUP761_SECRETKEY_SIZE = 1763 +export const SNTRUP761_CIPHERTEXT_SIZE = 1039 +export const SNTRUP761_SIZE = 32 // shared secret + +export interface KEMKeyPair { + publicKey: Uint8Array // 1158 bytes + secretKey: Uint8Array // 1763 bytes +} + +export interface KEMEncResult { + ciphertext: Uint8Array // 1039 bytes + sharedSecret: Uint8Array // 32 bytes +} + +// WASM module instance +let wasmModule: any = null + +export async function initSntrup761(): Promise { + if (wasmModule) return + const createSntrup761 = (await import("../../dist/wasm/sntrup761.mjs")).default + wasmModule = await createSntrup761() +} + +function getModule(): any { + if (!wasmModule) throw new Error("sntrup761 WASM not initialized - call initSntrup761() first") + return wasmModule +} + +export function sntrup761Keypair(): KEMKeyPair { + const m = getModule() + const pkPtr = m._malloc(SNTRUP761_PUBLICKEY_SIZE) + const skPtr = m._malloc(SNTRUP761_SECRETKEY_SIZE) + try { + m._sntrup761_wasm_keypair(pkPtr, skPtr) + const publicKey = new Uint8Array(m.HEAPU8.buffer, pkPtr, SNTRUP761_PUBLICKEY_SIZE).slice() + const secretKey = new Uint8Array(m.HEAPU8.buffer, skPtr, SNTRUP761_SECRETKEY_SIZE).slice() + return {publicKey, secretKey} + } finally { + m._free(pkPtr) + m._free(skPtr) + } +} + +export function sntrup761Enc(publicKey: Uint8Array): KEMEncResult { + if (publicKey.length !== SNTRUP761_PUBLICKEY_SIZE) throw new Error("bad public key length") + const m = getModule() + const pkPtr = m._malloc(SNTRUP761_PUBLICKEY_SIZE) + const ctPtr = m._malloc(SNTRUP761_CIPHERTEXT_SIZE) + const ssPtr = m._malloc(SNTRUP761_SIZE) + try { + m.HEAPU8.set(publicKey, pkPtr) + m._sntrup761_wasm_enc(ctPtr, ssPtr, pkPtr) + const ciphertext = new Uint8Array(m.HEAPU8.buffer, ctPtr, SNTRUP761_CIPHERTEXT_SIZE).slice() + const sharedSecret = new Uint8Array(m.HEAPU8.buffer, ssPtr, SNTRUP761_SIZE).slice() + return {ciphertext, sharedSecret} + } finally { + m._free(pkPtr) + m._free(ctPtr) + m._free(ssPtr) + } +} + +export function sntrup761Dec(ciphertext: Uint8Array, secretKey: Uint8Array): Uint8Array { + if (ciphertext.length !== SNTRUP761_CIPHERTEXT_SIZE) throw new Error("bad ciphertext length") + if (secretKey.length !== SNTRUP761_SECRETKEY_SIZE) throw new Error("bad secret key length") + const m = getModule() + const ctPtr = m._malloc(SNTRUP761_CIPHERTEXT_SIZE) + const skPtr = m._malloc(SNTRUP761_SECRETKEY_SIZE) + const ssPtr = m._malloc(SNTRUP761_SIZE) + try { + m.HEAPU8.set(ciphertext, ctPtr) + m.HEAPU8.set(secretKey, skPtr) + m._sntrup761_wasm_dec(ssPtr, ctPtr, skPtr) + return new Uint8Array(m.HEAPU8.buffer, ssPtr, SNTRUP761_SIZE).slice() + } finally { + m._free(ctPtr) + m._free(skPtr) + m._free(ssPtr) + } +} diff --git a/tests/SMPWebTests.hs b/tests/SMPWebTests.hs index 6b3409099..04a8dbbd4 100644 --- a/tests/SMPWebTests.hs +++ b/tests/SMPWebTests.hs @@ -27,6 +27,8 @@ import Simplex.Messaging.Version (mkVersionRange) import Simplex.Messaging.Version.Internal (Version (..)) import qualified Simplex.Messaging.Crypto as C import qualified Simplex.Messaging.Crypto.Ratchet as CR +import Simplex.Messaging.Crypto.SNTRUP761.Bindings (KEMPublicKey (..), KEMSecretKey, KEMCiphertext (..), KEMSharedKey (..), sntrup761Keypair, sntrup761Enc, sntrup761Dec) +import qualified Data.ByteArray as BA import Simplex.Messaging.Crypto.ShortLink (contactShortLinkKdf, invShortLinkKdf) import Simplex.Messaging.Encoding import Simplex.Messaging.Encoding.String (strEncode) @@ -78,6 +80,9 @@ impCryptoShortLink = "import { contactShortLinkKdf, invShortLinkKdf, decryptLink impRatchet :: String impRatchet = "import { generateX448KeyPair, pqX3dhSnd, pqX3dhRcv, x448DH, encodePubKeyX448, decodePubKeyX448 } from './dist/crypto/ratchet.js';" +impSntrup :: String +impSntrup = "import { initSntrup761, sntrup761Keypair, sntrup761Enc, sntrup761Dec } from './dist/crypto/sntrup761.js'; await initSntrup761();" + impCrypto :: String impCrypto = "import { sbcInit, sbcHkdf, sbEncryptBlock, sbDecryptBlock } from './dist/crypto.js';" @@ -302,6 +307,36 @@ smpWebTests_ = do <> jsOut ("new Uint8Array([...r.fixedData, 0, ...r.userData])") tsResult `shouldBe` (fixedPlain <> B.singleton 0 <> userPlain) + describe "crypto/sntrup761" $ do + it "TypeScript encapsulates, Haskell decapsulates - shared secret matches" $ do + g <- C.newRandom + (KEMPublicKey pkBytes, sk) <- sntrup761Keypair g + tsResult <- callNode $ impSntrup + <> "const enc = sntrup761Enc(" <> jsUint8 pkBytes <> ");" + <> jsOut ("new Uint8Array([...enc.ciphertext, ...enc.sharedSecret])") + let (ctBytes, tsSharedSecret) = B.splitAt 1039 tsResult + KEMSharedKey hsSharedSecret <- sntrup761Dec (KEMCiphertext ctBytes) sk + (BA.convert hsSharedSecret :: B.ByteString) `shouldBe` tsSharedSecret + + it "Haskell encapsulates, TypeScript decapsulates - shared secret matches" $ do + -- TypeScript generates keypair, passes public key to Haskell via stdout, + -- but callNode is one-shot. So: TypeScript generates keypair, outputs (pk, sk). + -- Then Haskell encapsulates against pk, passes (ct) to TypeScript. + -- TypeScript decapsulates with sk, outputs shared secret. + -- We compare with Haskell's shared secret. + -- + -- Two callNode calls: first to get keypair, second to decapsulate. + kpResult <- callNode $ impSntrup + <> "const kp = sntrup761Keypair();" + <> jsOut ("new Uint8Array([...kp.publicKey, ...kp.secretKey])") + let (tsPk, tsSk) = B.splitAt 1158 kpResult + g <- C.newRandom + (KEMCiphertext ctBytes, KEMSharedKey hsSharedSecret) <- sntrup761Enc g (KEMPublicKey tsPk) + tsResult <- callNode $ impSntrup + <> "const ss = sntrup761Dec(" <> jsUint8 ctBytes <> "," <> jsUint8 tsSk <> ");" + <> jsOut ("ss") + tsResult `shouldBe` (BA.convert hsSharedSecret :: B.ByteString) + describe "crypto/ratchet" $ do describe "X3DH" $ do it "pqX3dhSnd and pqX3dhRcv produce same ratchetKey" $ do From 949335a8f3426aaf8d7885a516cb1265149d32c8 Mon Sep 17 00:00:00 2001 From: "Evgeny @ SimpleX Chat" <259188159+evgeny-simplex@users.noreply.github.com> Date: Wed, 13 May 2026 11:42:58 +0000 Subject: [PATCH 4/8] AES-256-GCM, comatibility tests --- smp-web/src/crypto.ts | 39 +++++++++++++++++ smp-web/src/crypto/ratchet.ts | 49 +++++++++++++++++++--- tests/SMPWebTests.hs | 79 ++++++++++++++++++++++++++++++++++- 3 files changed, 161 insertions(+), 6 deletions(-) diff --git a/smp-web/src/crypto.ts b/smp-web/src/crypto.ts index 5e84fa0c2..3e074f7ed 100644 --- a/smp-web/src/crypto.ts +++ b/smp-web/src/crypto.ts @@ -3,7 +3,10 @@ import {hkdf as nobleHkdf} from "@noble/hashes/hkdf" import {sha512} from "@noble/hashes/sha512" +import {gcm} from "@noble/ciphers/aes.js" import {cbEncrypt, cbDecrypt} from "@simplex-chat/xftp-web/dist/crypto/secretbox.js" +import {concatBytes} from "@simplex-chat/xftp-web/dist/protocol/encoding.js" +import {pad, unPad} from "@simplex-chat/xftp-web/dist/crypto/padding.js" // C.hkdf (Crypto.hs:1461-1464) // HKDF-SHA512 extract + expand @@ -48,3 +51,39 @@ export function sbDecryptBlock(chainKey: Uint8Array, block: Uint8Array): {decryp const {keyNonce: {sbKey, nonce}, nextChainKey} = sbcHkdf(chainKey) return {decrypted: cbDecrypt(sbKey, nonce, block), nextChainKey} } + +// -- AES-256-GCM authenticated encryption (Crypto.hs:1035-1061) +// Uses 16-byte IVs (GCM with GHASH path per NIST SP 800-38D for IVs != 96 bits) + +export const AUTH_TAG_SIZE = 16 + +// encryptAEAD (Crypto.hs:1035-1039) +export function encryptAEAD( + key: Uint8Array, // 32 bytes + iv: Uint8Array, // 16 bytes + paddedLen: number, + ad: Uint8Array, + plaintext: Uint8Array, +): {authTag: Uint8Array; ciphertext: Uint8Array} { + const padded = pad(plaintext, paddedLen) + const cipher = gcm(key, iv, ad) + const encrypted = cipher.encrypt(padded) + return { + ciphertext: encrypted.subarray(0, encrypted.length - AUTH_TAG_SIZE), + authTag: encrypted.subarray(encrypted.length - AUTH_TAG_SIZE), + } +} + +// decryptAEAD (Crypto.hs:1058-1061) +export function decryptAEAD( + key: Uint8Array, + iv: Uint8Array, + ad: Uint8Array, + ciphertext: Uint8Array, + authTag: Uint8Array, +): Uint8Array { + const cipher = gcm(key, iv, ad) + const encrypted = concatBytes(ciphertext, authTag) + const padded = cipher.decrypt(encrypted) + return unPad(padded) +} diff --git a/smp-web/src/crypto/ratchet.ts b/smp-web/src/crypto/ratchet.ts index c8948ea05..dbe6ec640 100644 --- a/smp-web/src/crypto/ratchet.ts +++ b/smp-web/src/crypto/ratchet.ts @@ -2,7 +2,7 @@ // Mirrors: Simplex.Messaging.Crypto.Ratchet import {x448} from "@noble/curves/ed448.js" -import {hkdf} from "../crypto.js" +import {hkdf, encryptAEAD, decryptAEAD} from "../crypto.js" import {concatBytes} from "@simplex-chat/xftp-web/dist/protocol/encoding.js" // -- X448 key operations @@ -59,13 +59,16 @@ function hkdf3(salt: Uint8Array, ikm: Uint8Array, info: string): [Uint8Array, Ui const X3DH_SALT = new Uint8Array(64) // 64 zero bytes // pqX3dh (Ratchet.hs:499-508) -// Core X3DH: three DH results → HKDF → init params +// Core X3DH: three DH results + optional KEM shared secret → HKDF → init params function pqX3dh( sk1: Uint8Array, rk1: Uint8Array, // public keys for assocData dh1: Uint8Array, dh2: Uint8Array, dh3: Uint8Array, + kemSharedSecret: Uint8Array | null, // PQ KEM shared secret, 32 bytes ): RatchetInitParams { const assocData = concatBytes(sk1, rk1) - const dhs = concatBytes(dh1, dh2, dh3) // no PQ for MVP + const dhs = kemSharedSecret + ? concatBytes(dh1, dh2, dh3, kemSharedSecret) + : concatBytes(dh1, dh2, dh3) const [hk, nhk, sk] = hkdf3(X3DH_SALT, dhs, "SimpleXX3DH") return {assocData, ratchetKey: sk, sndHK: hk, rcvNextHK: nhk} } @@ -76,12 +79,13 @@ function pqX3dh( export function pqX3dhSnd( spk1: Uint8Array, spk2: Uint8Array, // our private keys rk1: Uint8Array, rk2: Uint8Array, // their public keys (raw, not DER) + kemSharedSecret: Uint8Array | null = null, ): RatchetInitParams { const sk1Pub = x448.getPublicKey(spk1) const dh1 = x448DH(rk1, spk2) const dh2 = x448DH(rk2, spk1) const dh3 = x448DH(rk2, spk2) - return pqX3dh(sk1Pub, rk1, dh1, dh2, dh3) + return pqX3dh(sk1Pub, rk1, dh1, dh2, dh3, kemSharedSecret) } // pqX3dhRcv (Ratchet.hs:483-497) @@ -90,10 +94,45 @@ export function pqX3dhSnd( export function pqX3dhRcv( rpk1: Uint8Array, rpk2: Uint8Array, // our private keys sk1: Uint8Array, sk2: Uint8Array, // their public keys (raw, not DER) + kemSharedSecret: Uint8Array | null = null, ): RatchetInitParams { const rk1Pub = x448.getPublicKey(rpk1) const dh1 = x448DH(sk2, rpk1) const dh2 = x448DH(sk1, rpk2) const dh3 = x448DH(sk2, rpk2) - return pqX3dh(sk1, rk1Pub, dh1, dh2, dh3) + return pqX3dh(sk1, rk1Pub, dh1, dh2, dh3, kemSharedSecret) +} + +// -- KDF functions (Ratchet.hs:1159-1179) + +const EMPTY_SALT = new Uint8Array(0) + +// rootKdf (Ratchet.hs:1159-1166) +// HKDF-SHA512 with DH result + optional KEM shared secret +export function rootKdf( + ratchetKey: Uint8Array, // 32 bytes + peerPubKey: Uint8Array, // raw X448 public key, 56 bytes + ownPrivKey: Uint8Array, // raw X448 private key, 56 bytes + kemSecret: Uint8Array | null, // optional KEM shared secret +): {rk: Uint8Array; ck: Uint8Array; nhk: Uint8Array} { + const dhOut = x448DH(peerPubKey, ownPrivKey) + const ss = kemSecret ? concatBytes(dhOut, kemSecret) : dhOut + const [rk, ck, nhk] = hkdf3(ratchetKey, ss, "SimpleXRootRatchet") + return {rk, ck, nhk} +} + +// chainKdf (Ratchet.hs:1168-1172) +// HKDF-SHA512 with empty salt, produces chain key + message key + two 16-byte IVs +export function chainKdf(chainKey: Uint8Array): {ck: Uint8Array; mk: Uint8Array; iv: Uint8Array; ehIV: Uint8Array} { + const [ck, mk, ivs] = hkdf3(EMPTY_SALT, chainKey, "SimpleXChainRatchet") + return {ck, mk, iv: ivs.slice(0, 16), ehIV: ivs.slice(16, 32)} +} + +// -- Header padding (Ratchet.hs:716-719) + +const PADDED_HEADER_LEN_NO_PQ = 88 +const PADDED_HEADER_LEN_PQ = 2310 + +export function paddedHeaderLen(pqSupport: boolean): number { + return pqSupport ? PADDED_HEADER_LEN_PQ : PADDED_HEADER_LEN_NO_PQ } diff --git a/tests/SMPWebTests.hs b/tests/SMPWebTests.hs index 04a8dbbd4..9e8969f11 100644 --- a/tests/SMPWebTests.hs +++ b/tests/SMPWebTests.hs @@ -28,6 +28,7 @@ import Simplex.Messaging.Version.Internal (Version (..)) import qualified Simplex.Messaging.Crypto as C import qualified Simplex.Messaging.Crypto.Ratchet as CR import Simplex.Messaging.Crypto.SNTRUP761.Bindings (KEMPublicKey (..), KEMSecretKey, KEMCiphertext (..), KEMSharedKey (..), sntrup761Keypair, sntrup761Enc, sntrup761Dec) +import qualified Crypto.Cipher.Types as AES import qualified Data.ByteArray as BA import Simplex.Messaging.Crypto.ShortLink (contactShortLinkKdf, invShortLinkKdf) import Simplex.Messaging.Encoding @@ -78,7 +79,8 @@ impCryptoShortLink :: String impCryptoShortLink = "import { contactShortLinkKdf, invShortLinkKdf, decryptLinkData } from './dist/crypto/shortLink.js';" impRatchet :: String -impRatchet = "import { generateX448KeyPair, pqX3dhSnd, pqX3dhRcv, x448DH, encodePubKeyX448, decodePubKeyX448 } from './dist/crypto/ratchet.js';" +impRatchet = "import { generateX448KeyPair, pqX3dhSnd, pqX3dhRcv, x448DH, encodePubKeyX448, decodePubKeyX448, chainKdf, rootKdf } from './dist/crypto/ratchet.js';" + <> "import { encryptAEAD, decryptAEAD } from './dist/crypto.js';" impSntrup :: String impSntrup = "import { initSntrup761, sntrup761Keypair, sntrup761Enc, sntrup761Dec } from './dist/crypto/sntrup761.js'; await initSntrup761();" @@ -337,6 +339,29 @@ smpWebTests_ = do <> jsOut ("ss") tsResult `shouldBe` (BA.convert hsSharedSecret :: B.ByteString) + describe "crypto/aesGcm" $ do + it "Haskell encryptAEAD (16-byte IV), TypeScript decrypts" $ do + let key = C.Key $ B.pack [1..32] + iv = C.IV $ B.pack [1..16] + ad = "associated data" + msg = "hello from haskell aes-gcm" + Right (C.AuthTag authTag, ct) <- runExceptT $ C.encryptAEAD key iv 64 ad msg + let tagBytes = BA.convert authTag :: B.ByteString + tsResult <- callNode $ impEnc + <> "import { gcm } from '@noble/ciphers/aes.js';" + <> "const key = " <> jsUint8 (B.pack [1..32]) <> ";" + <> "const iv = " <> jsUint8 (B.pack [1..16]) <> ";" + <> "const ad = new TextEncoder().encode('associated data');" + <> "const ct = " <> jsUint8 ct <> ";" + <> "const tag = " <> jsUint8 tagBytes <> ";" + <> "const cipher = gcm(key, iv, ad);" + <> "const encrypted = new Uint8Array([...ct, ...tag]);" + <> "const decrypted = cipher.decrypt(encrypted);" + -- unpad: 2-byte BE length prefix + message + '#' padding + <> "const len = (decrypted[0] << 8) | decrypted[1];" + <> jsOut ("decrypted.subarray(2, 2 + len)") + tsResult `shouldBe` msg + describe "crypto/ratchet" $ do describe "X3DH" $ do it "pqX3dhSnd and pqX3dhRcv produce same ratchetKey" $ do @@ -355,6 +380,58 @@ smpWebTests_ = do <> jsOut ("new Uint8Array([match ? 1 : 0, snd.ratchetKey.length, snd.sndHK.length, snd.rcvNextHK.length])") tsResult `shouldBe` B.pack [1, 32, 32, 32] + describe "chainKdf" $ do + it "TypeScript chainKdf produces correct output via HKDF" $ do + -- chainKdf is hkdf3("", ck, "SimpleXChainRatchet") split into 32+32+16+16 + -- Since hkdf is already tested against Haskell, test the split logic + tsResult <- callNode $ impRatchet + <> "const r = chainKdf(" <> jsUint8 (B.pack [1..32]) <> ");" + <> jsOut ("new Uint8Array([r.ck.length, r.mk.length, r.iv.length, r.ehIV.length])") + tsResult `shouldBe` B.pack [32, 32, 16, 16] + + describe "encryptAEAD" $ do + it "TypeScript encrypt matches Haskell encrypt (same ciphertext)" $ do + let key = C.Key $ B.pack [1..32] + iv = C.IV $ B.pack [1..16] + ad = "test associated data" + msg = "ratchet plaintext" + Right (C.AuthTag hsTag, hsCt) <- runExceptT $ C.encryptAEAD key iv 64 ad msg + let hsTagBytes = BA.convert hsTag :: B.ByteString + tsResult <- callNode $ impRatchet + <> "const r = encryptAEAD(" <> jsUint8 (B.pack [1..32]) <> "," <> jsUint8 (B.pack [1..16]) <> ",64," + <> "new TextEncoder().encode('test associated data')," + <> "new TextEncoder().encode('ratchet plaintext'));" + <> jsOut ("new Uint8Array([...r.authTag, ...r.ciphertext])") + tsResult `shouldBe` (hsTagBytes <> hsCt) + + it "TypeScript decrypts Haskell-encrypted" $ do + let key = C.Key $ B.pack [10..41] + iv = C.IV $ B.pack [10..25] + ad = "ad for decrypt test" + msg = "hello from haskell ratchet" + Right (C.AuthTag hsTag, hsCt) <- runExceptT $ C.encryptAEAD key iv 64 ad msg + let hsTagBytes = BA.convert hsTag :: B.ByteString + tsResult <- callNode $ impRatchet + <> "const plain = decryptAEAD(" <> jsUint8 (B.pack [10..41]) <> "," <> jsUint8 (B.pack [10..25]) <> "," + <> "new TextEncoder().encode('ad for decrypt test')," + <> jsUint8 hsCt <> "," <> jsUint8 hsTagBytes <> ");" + <> jsOut ("plain") + tsResult `shouldBe` msg + + it "Haskell decrypts TypeScript-encrypted" $ do + let key = C.Key $ B.pack [20..51] + iv = C.IV $ B.pack [20..35] + ad = "ad for ts encrypt" + msg = "hello from typescript ratchet" + tsResult <- callNode $ impRatchet + <> "const r = encryptAEAD(" <> jsUint8 (B.pack [20..51]) <> "," <> jsUint8 (B.pack [20..35]) <> ",64," + <> "new TextEncoder().encode('ad for ts encrypt')," + <> "new TextEncoder().encode('hello from typescript ratchet'));" + <> jsOut ("new Uint8Array([...r.authTag, ...r.ciphertext])") + let (tsTag, tsCt) = B.splitAt 16 tsResult + Right hsPlain <- runExceptT $ C.decryptAEAD key iv ad tsCt (C.AuthTag $ AES.AuthTag $ BA.convert tsTag) + hsPlain `shouldBe` msg + describe "DER encoding" $ do it "X448 DER round-trips" $ do tsResult <- callNode $ impRatchet From 048b22f8dfe47ba03e48f86db73ebca285f4dade Mon Sep 17 00:00:00 2001 From: "Evgeny @ SimpleX Chat" <259188159+evgeny-simplex@users.noreply.github.com> Date: Wed, 13 May 2026 14:28:22 +0000 Subject: [PATCH 5/8] core of double ratchet --- smp-web/src/crypto/ratchet.ts | 416 +++++++++++++++++++++++++++++++++- tests/SMPWebTests.hs | 90 +++++++- 2 files changed, 502 insertions(+), 4 deletions(-) diff --git a/smp-web/src/crypto/ratchet.ts b/smp-web/src/crypto/ratchet.ts index dbe6ec640..e7befe7c5 100644 --- a/smp-web/src/crypto/ratchet.ts +++ b/smp-web/src/crypto/ratchet.ts @@ -3,7 +3,11 @@ import {x448} from "@noble/curves/ed448.js" import {hkdf, encryptAEAD, decryptAEAD} from "../crypto.js" -import {concatBytes} from "@simplex-chat/xftp-web/dist/protocol/encoding.js" +import { + Decoder, concatBytes, + encodeBytes, decodeBytes, decodeWord16, decodeWord32, + encodeLarge, decodeLarge, +} from "@simplex-chat/xftp-web/dist/protocol/encoding.js" // -- X448 key operations @@ -136,3 +140,413 @@ const PADDED_HEADER_LEN_PQ = 2310 export function paddedHeaderLen(pqSupport: boolean): number { return pqSupport ? PADDED_HEADER_LEN_PQ : PADDED_HEADER_LEN_NO_PQ } + +// -- Ratchet state (Ratchet.hs:512-565) + +export interface SndRatchet { + rcDHRr: Uint8Array // peer's X448 public key (raw, 56 bytes) + rcCKs: Uint8Array // sending chain key (32 bytes) + rcHKs: Uint8Array // sending header key (32 bytes) +} + +export interface RcvRatchet { + rcCKr: Uint8Array // receiving chain key (32 bytes) + rcHKr: Uint8Array // receiving header key (32 bytes) +} + +export interface MessageKey { + mk: Uint8Array // 32 bytes + iv: Uint8Array // 16 bytes +} + +// Skipped message keys: Map> +// Using string keys for the outer map (hex-encoded header key) +export type SkippedMsgKeys = Map> + +export interface RatchetState { + // version + rcVersion: number // current e2e version + rcMaxVersion: number // max supported e2e version + // associated data + rcAD: Uint8Array + // DH ratchet key pair (our private key) + rcDHRs: Uint8Array // X448 private key (56 bytes) + // PQ support + rcSupportKEM: boolean + // root key + rcRK: Uint8Array // 32 bytes + // sending ratchet (null before first message sent after ratchet advance) + rcSnd: SndRatchet | null + // receiving ratchet (null before first message received) + rcRcv: RcvRatchet | null + // counters + rcNs: number // sending message number + rcNr: number // receiving message number + rcPN: number // previous sending chain length + // next header keys + rcNHKs: Uint8Array // 32 bytes + rcNHKr: Uint8Array // 32 bytes +} + +const MAX_SKIP = 512 + +function hexKey(k: Uint8Array): string { + return Array.from(k, b => b.toString(16).padStart(2, "0")).join("") +} + +// -- Ratchet initialization (Ratchet.hs:643-699) + +// initSndRatchet (Ratchet.hs:643-666) +// Used by joiner (Bob) after X3DH +export function initSndRatchet( + version: number, + maxVersion: number, + rcDHRr: Uint8Array, // peer's X448 public key (raw, from invitation) + rcDHRs: Uint8Array, // our X448 private key (raw) + initParams: RatchetInitParams, + pqSupport: boolean, + kemSharedSecret: Uint8Array | null = null, +): RatchetState { + const {rk, ck, nhk} = rootKdf(initParams.ratchetKey, rcDHRr, rcDHRs, kemSharedSecret) + return { + rcVersion: version, + rcMaxVersion: maxVersion, + rcAD: initParams.assocData, + rcDHRs, + rcSupportKEM: pqSupport, + rcRK: rk, + rcSnd: {rcDHRr, rcCKs: ck, rcHKs: initParams.sndHK}, + rcRcv: null, + rcNs: 0, + rcNr: 0, + rcPN: 0, + rcNHKs: nhk, + rcNHKr: initParams.rcvNextHK, + } +} + +// initRcvRatchet (Ratchet.hs:674-699) +// Used by initiator (Alice) after receiving confirmation +export function initRcvRatchet( + version: number, + maxVersion: number, + rcDHRs: Uint8Array, // our X448 private key (raw) + initParams: RatchetInitParams, + pqSupport: boolean, +): RatchetState { + return { + rcVersion: version, + rcMaxVersion: maxVersion, + rcAD: initParams.assocData, + rcDHRs, + rcSupportKEM: pqSupport, + rcRK: initParams.ratchetKey, + rcSnd: null, + rcRcv: null, + rcNs: 0, + rcNr: 0, + rcPN: 0, + rcNHKs: initParams.rcvNextHK, + rcNHKr: initParams.sndHK, + } +} + +// -- Message header (Ratchet.hs:703-787) + +interface MsgHeader { + msgMaxVersion: number + msgDHRs: Uint8Array // X448 public key (raw, 56 bytes) + msgPN: number + msgNs: number +} + +function encodeMsgHeader(v: number, hdr: MsgHeader): Uint8Array { + const versionBytes = new Uint8Array(2) + versionBytes[0] = (hdr.msgMaxVersion >> 8) & 0xff + versionBytes[1] = hdr.msgMaxVersion & 0xff + const dhDer = encodePubKeyX448(hdr.msgDHRs) + const pn = new Uint8Array(4) + pn[0] = (hdr.msgPN >> 24) & 0xff; pn[1] = (hdr.msgPN >> 16) & 0xff + pn[2] = (hdr.msgPN >> 8) & 0xff; pn[3] = hdr.msgPN & 0xff + const ns = new Uint8Array(4) + ns[0] = (hdr.msgNs >> 24) & 0xff; ns[1] = (hdr.msgNs >> 16) & 0xff + ns[2] = (hdr.msgNs >> 8) & 0xff; ns[3] = hdr.msgNs & 0xff + // v >= pqRatchetE2EEncryptVersion (v3): includes KEM params (Maybe, encoded as Nothing for now) + if (v >= 3) { + return concatBytes(versionBytes, encodeBytes(dhDer), new Uint8Array([0x30]), pn, ns) // '0' = Nothing for KEM + } + return concatBytes(versionBytes, encodeBytes(dhDer), pn, ns) +} + +function decodeMsgHeader(v: number, data: Uint8Array): MsgHeader { + const d = new Decoder(data) + const msgMaxVersion = decodeWord16(d) + const dhDer = decodeBytes(d) + const msgDHRs = decodePubKeyX448(dhDer) + // skip KEM params for v3+ + if (v >= 3) { + const kemByte = d.anyByte() + if (kemByte === 0x31) { + // Just - skip KEM params (we don't process them in this simplified version) + decodeBytes(d) // KEM params + } + // else '0' = Nothing, already consumed + } + const msgPN = decodeWord32(d) + const msgNs = decodeWord32(d) + return {msgMaxVersion, msgDHRs, msgPN, msgNs} +} + +// -- Encrypt (Ratchet.hs:902-975) + +export interface EncryptResult { + ciphertext: Uint8Array + state: RatchetState +} + +export function rcEncrypt( + state: RatchetState, + plaintext: Uint8Array, + paddedMsgLen: number, +): EncryptResult { + if (!state.rcSnd) throw new Error("rcEncrypt: no sending ratchet") + const snd = state.rcSnd + const v = state.rcVersion + + // Advance chain: state.CKs, mk = KDF_CK(state.CKs) + const {ck: rcCKs, mk, iv, ehIV} = chainKdf(snd.rcCKs) + + // Build and encrypt header + const headerPlain = encodeMsgHeader(v, { + msgMaxVersion: state.rcMaxVersion, + msgDHRs: x448.getPublicKey(state.rcDHRs), + msgPN: state.rcPN, + msgNs: state.rcNs, + }) + + const phl = paddedHeaderLen(state.rcSupportKEM) + const {authTag: ehAuthTag, ciphertext: ehBody} = encryptAEAD(snd.rcHKs, ehIV, phl, state.rcAD, headerPlain) + + // Encode EncMessageHeader + const ehVersionBytes = new Uint8Array(2) + ehVersionBytes[0] = (v >> 8) & 0xff; ehVersionBytes[1] = v & 0xff + // IV and AuthTag are raw bytes (no length prefix). Body is Large for v3+, ByteString for older. + const encHeader = concatBytes(ehVersionBytes, ehIV, ehAuthTag, v >= 3 ? encodeLarge(ehBody) : encodeBytes(ehBody)) + + // Encrypt body: ENCRYPT(mk, plaintext, CONCAT(AD, enc_header)) + const bodyAD = concatBytes(state.rcAD, encHeader) + const {authTag: emAuthTag, ciphertext: emBody} = encryptAEAD(mk, iv, paddedMsgLen, bodyAD, plaintext) + + // Encode EncRatchetMessage + // AuthTag is raw 16 bytes (no length prefix), body is Tail (raw bytes) + const msgBytes = concatBytes(v >= 3 ? encodeLarge(encHeader) : encodeBytes(encHeader), emAuthTag, emBody) + + // Update state + const newState: RatchetState = { + ...state, + rcSnd: {...snd, rcCKs}, + rcNs: state.rcNs + 1, + } + + return {ciphertext: msgBytes, state: newState} +} + +// -- Decrypt (Ratchet.hs:990-1157) + +export interface DecryptResult { + plaintext: Uint8Array + state: RatchetState + skippedKeys: SkippedMsgKeys +} + +interface EncRatchetMessage { + emHeader: Uint8Array + emAuthTag: Uint8Array + emBody: Uint8Array +} + +interface EncMessageHeader { + ehVersion: number + ehIV: Uint8Array + ehAuthTag: Uint8Array + ehBody: Uint8Array +} + +function parseEncRatchetMessage(data: Uint8Array): EncRatchetMessage { + const d = new Decoder(data) + // header is length-prefixed (Large for v3+, ByteString for older) + const firstByte = data[d.offset()] + const emHeader = firstByte < 32 ? decodeLarge(d) : decodeBytes(d) + // AuthTag is raw 16 bytes (no length prefix) + const emAuthTag = d.take(16) + const emBody = d.takeAll() + return {emHeader, emAuthTag, emBody} +} + +function parseEncMessageHeader(data: Uint8Array): EncMessageHeader { + const d = new Decoder(data) + const ehVersion = decodeWord16(d) + // IV is raw 16 bytes, AuthTag is raw 16 bytes + const ehIV = d.take(16) + const ehAuthTag = d.take(16) + // body: Large for v3+ (first byte < 32), ByteString for older + const firstByte = data[d.offset()] + const ehBody = firstByte < 32 ? decodeLarge(d) : decodeBytes(d) + return {ehVersion, ehIV, ehAuthTag, ehBody} +} + +function tryDecryptHeader(headerKey: Uint8Array, ad: Uint8Array, encHdr: EncMessageHeader): MsgHeader | null { + try { + const plainHeader = decryptAEAD(headerKey, encHdr.ehIV, ad, encHdr.ehBody, encHdr.ehAuthTag) + return decodeMsgHeader(encHdr.ehVersion, plainHeader) + } catch { + return null + } +} + +function decryptMessage(mk: Uint8Array, iv: Uint8Array, ad: Uint8Array, encHeader: Uint8Array, encMsg: EncRatchetMessage): Uint8Array { + const bodyAD = concatBytes(ad, encHeader) + return decryptAEAD(mk, iv, bodyAD, encMsg.emBody, encMsg.emAuthTag) +} + +export function rcDecrypt( + state: RatchetState, + skippedKeys: SkippedMsgKeys, + ciphertext: Uint8Array, +): DecryptResult { + const encMsg = parseEncRatchetMessage(ciphertext) + const encHdr = parseEncMessageHeader(encMsg.emHeader) + + // Try skipped message keys + for (const [hkHex, msgKeys] of skippedKeys) { + const hk = hexToBytes(hkHex) + const hdr = tryDecryptHeader(hk, state.rcAD, encHdr) + if (hdr) { + const mk = msgKeys.get(hdr.msgNs) + if (mk) { + // Found in skipped keys - decrypt and remove + const plaintext = decryptMessage(mk.mk, mk.iv, state.rcAD, encMsg.emHeader, encMsg) + const newMsgKeys = new Map(msgKeys) + newMsgKeys.delete(hdr.msgNs) + const newSkipped = new Map(skippedKeys) + if (newMsgKeys.size === 0) newSkipped.delete(hkHex) + else newSkipped.set(hkHex, newMsgKeys) + return {plaintext, state, skippedKeys: newSkipped} + } + } + } + + // Try current receiving ratchet header key + let ratchetStep: "same" | "advance" = "advance" + let hdr: MsgHeader | null = null + + if (state.rcRcv) { + hdr = tryDecryptHeader(state.rcRcv.rcHKr, state.rcAD, encHdr) + if (hdr) ratchetStep = "same" + } + + // Try next header key (advance ratchet) + if (!hdr) { + hdr = tryDecryptHeader(state.rcNHKr, state.rcAD, encHdr) + if (!hdr) throw new Error("rcDecrypt: header decryption failed") + ratchetStep = "advance" + } + + // Upgrade version + let rc = state + if (hdr.msgMaxVersion > rc.rcVersion) { + rc = {...rc, rcVersion: Math.max(rc.rcVersion, Math.min(hdr.msgMaxVersion, rc.rcMaxVersion))} + } + + let newSkipped = new Map(skippedKeys) + + if (ratchetStep === "advance") { + // Skip message keys for previous ratchet + const skipResult = skipMessageKeys(rc, newSkipped, hdr.msgPN) + rc = skipResult.state + newSkipped = skipResult.skippedKeys + + // DH ratchet step + const rcDHRs_new = generateX448KeyPair() + // state.RK, state.CKr, state.NHKr = KDF_RK_HE(state.RK, DH(state.DHRs, header.dh)) + const {rk: rcRK1, ck: rcCKr, nhk: rcNHKr} = rootKdf(rc.rcRK, hdr.msgDHRs, rc.rcDHRs, null) + // state.RK, state.CKs, state.NHKs = KDF_RK_HE(state.RK, DH(state.DHRs', header.dh)) + const {rk: rcRK2, ck: rcCKs, nhk: rcNHKs} = rootKdf(rcRK1, hdr.msgDHRs, rcDHRs_new.privateKey, null) + + rc = { + ...rc, + rcDHRs: rcDHRs_new.privateKey, + rcRK: rcRK2, + rcSnd: {rcDHRr: hdr.msgDHRs, rcCKs, rcHKs: rc.rcNHKs}, + rcRcv: {rcCKr, rcHKr: rc.rcNHKr}, + rcPN: rc.rcNs, + rcNs: 0, + rcNr: 0, + rcNHKs, + rcNHKr, + } + } + + // Skip message keys for current ratchet + const skipResult2 = skipMessageKeys(rc, newSkipped, hdr.msgNs) + rc = skipResult2.state + newSkipped = skipResult2.skippedKeys + + if (!rc.rcRcv) throw new Error("rcDecrypt: no receiving ratchet after skip") + + // Decrypt message + const {ck: rcCKr, mk, iv} = chainKdf(rc.rcRcv.rcCKr) + const plaintext = decryptMessage(mk, iv, rc.rcAD, encMsg.emHeader, encMsg) + + rc = { + ...rc, + rcRcv: {...rc.rcRcv, rcCKr}, + rcNr: rc.rcNr + 1, + } + + return {plaintext, state: rc, skippedKeys: newSkipped} +} + +function skipMessageKeys( + state: RatchetState, + skippedKeys: SkippedMsgKeys, + untilN: number, +): {state: RatchetState; skippedKeys: SkippedMsgKeys} { + if (!state.rcRcv) return {state, skippedKeys} + const rcv = state.rcRcv + const rcNr = state.rcNr + + if (rcNr > untilN + 1) throw new Error("rcDecrypt: earlier message") + if (rcNr === untilN + 1) throw new Error("rcDecrypt: duplicate message") + if (rcNr + MAX_SKIP < untilN) throw new Error("rcDecrypt: too many skipped") + if (rcNr === untilN) return {state, skippedKeys} + + // Advance receiving ratchet, storing skipped keys + let ck = rcv.rcCKr + let nr = rcNr + const hkHex = hexKey(rcv.rcHKr) + const msgKeys = new Map(skippedKeys.get(hkHex) || new Map()) + + while (nr < untilN) { + const chain = chainKdf(ck) + msgKeys.set(nr, {mk: chain.mk, iv: chain.iv}) + ck = chain.ck + nr++ + } + + const newSkipped = new Map(skippedKeys) + newSkipped.set(hkHex, msgKeys) + + return { + state: {...state, rcRcv: {...rcv, rcCKr: ck}, rcNr: nr}, + skippedKeys: newSkipped, + } +} + +function hexToBytes(hex: string): Uint8Array { + const bytes = new Uint8Array(hex.length / 2) + for (let i = 0; i < hex.length; i += 2) { + bytes[i / 2] = parseInt(hex.substring(i, i + 2), 16) + } + return bytes +} diff --git a/tests/SMPWebTests.hs b/tests/SMPWebTests.hs index 9e8969f11..ac574aee4 100644 --- a/tests/SMPWebTests.hs +++ b/tests/SMPWebTests.hs @@ -2,6 +2,7 @@ {-# LANGUAGE GADTs #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE PatternSynonyms #-} +{-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TypeApplications #-} -- | Per-function tests for the smp-web TypeScript SMP client library. @@ -26,13 +27,14 @@ import Simplex.Messaging.Client (pattern NRMInteractive) import Simplex.Messaging.Version (mkVersionRange) import Simplex.Messaging.Version.Internal (Version (..)) import qualified Simplex.Messaging.Crypto as C +import Simplex.Messaging.Crypto (Algorithm (..)) import qualified Simplex.Messaging.Crypto.Ratchet as CR import Simplex.Messaging.Crypto.SNTRUP761.Bindings (KEMPublicKey (..), KEMSecretKey, KEMCiphertext (..), KEMSharedKey (..), sntrup761Keypair, sntrup761Enc, sntrup761Dec) import qualified Crypto.Cipher.Types as AES import qualified Data.ByteArray as BA import Simplex.Messaging.Crypto.ShortLink (contactShortLinkKdf, invShortLinkKdf) import Simplex.Messaging.Encoding -import Simplex.Messaging.Encoding.String (strEncode) +import Simplex.Messaging.Encoding.String (Str (..), strEncode) import Simplex.Messaging.Protocol (EntityId (..), SMPServer, SubscriptionMode (..), MsgFlags (..), pattern SMPServer, encodeProtocol, Command (..), NewQueueReq (..), BrokerMsg (..), RcvMessage (..), EncRcvMsgBody (..), QueueIdsKeys (..)) import Simplex.Messaging.Server.Env.STM (AStoreType (..)) import Simplex.Messaging.Server.MsgStore.Types (SMSType (..), SQSType (..)) @@ -53,7 +55,7 @@ callNode :: String -> IO B.ByteString callNode = callNode_ smpWebDir impEnc :: String -impEnc = "import { Decoder, decodeLarge } from '@simplex-chat/xftp-web/dist/protocol/encoding.js';" +impEnc = "import { Decoder, decodeBytes, decodeLarge, encodeBytes, encodeWord16 } from '@simplex-chat/xftp-web/dist/protocol/encoding.js';" impProto_ :: String impProto_ = "import { encodeTransmission, encodeBatch, decodeTransmission, encodeLGET, decodeLNK, decodeResponse, encodeNEW, encodeKEY, encodeSKEY, encodeSUB, encodeACK, encodeSEND, encodeOFF, encodeDEL } from './dist/protocol.js';" @@ -79,7 +81,7 @@ impCryptoShortLink :: String impCryptoShortLink = "import { contactShortLinkKdf, invShortLinkKdf, decryptLinkData } from './dist/crypto/shortLink.js';" impRatchet :: String -impRatchet = "import { generateX448KeyPair, pqX3dhSnd, pqX3dhRcv, x448DH, encodePubKeyX448, decodePubKeyX448, chainKdf, rootKdf } from './dist/crypto/ratchet.js';" +impRatchet = "import { generateX448KeyPair, pqX3dhSnd, pqX3dhRcv, x448DH, encodePubKeyX448, decodePubKeyX448, chainKdf, rootKdf, initSndRatchet, initRcvRatchet, rcEncrypt, rcDecrypt } from './dist/crypto/ratchet.js';" <> "import { encryptAEAD, decryptAEAD } from './dist/crypto.js';" impSntrup :: String @@ -95,6 +97,9 @@ impSodium = "import sodium from '@simplex-chat/xftp-web/node_modules/libsodium-w jsStr :: B.ByteString -> String jsStr bs = "'" <> BC.unpack bs <> "'" +paddedMsgLen :: Int +paddedMsgLen = 100 + runRight :: (Show e, HasCallStack) => ExceptT e IO a -> IO a runRight action = runExceptT action >>= either (error . ("Unexpected error: " <>) . show) pure @@ -432,6 +437,85 @@ smpWebTests_ = do Right hsPlain <- runExceptT $ C.decryptAEAD key iv ad tsCt (C.AuthTag $ AES.AuthTag $ BA.convert tsTag) hsPlain `shouldBe` msg + describe "ratchet encrypt/decrypt" $ do + it "TypeScript ratchet self-consistency: encrypt, decrypt, ratchet advance, skipped" $ do + tsResult <- callNode $ impRatchet + <> "const a1 = generateX448KeyPair(), a2 = generateX448KeyPair();" + <> "const b1 = generateX448KeyPair(), b2 = generateX448KeyPair();" + <> "const bp = pqX3dhSnd(b1.privateKey, b2.privateKey, a1.publicKey, a2.publicKey);" + <> "const ap = pqX3dhRcv(a1.privateKey, a2.privateKey, b1.publicKey, b2.publicKey);" + <> "const b3 = generateX448KeyPair();" + <> "let bob = initSndRatchet(3, 3, a2.publicKey, b3.privateKey, bp, false);" + <> "let alice = initRcvRatchet(3, 3, a2.privateKey, ap, false);" + <> "let sk = new Map();" + -- Bob sends 3 + <> "const e1 = rcEncrypt(bob, new TextEncoder().encode('msg1'), 100); bob = e1.state;" + <> "const e2 = rcEncrypt(bob, new TextEncoder().encode('msg2'), 100); bob = e2.state;" + <> "const e3 = rcEncrypt(bob, new TextEncoder().encode('msg3'), 100); bob = e3.state;" + -- Alice decrypts msg3 first (skip 1,2) + <> "let d3 = rcDecrypt(alice, sk, e3.ciphertext); alice = d3.state; sk = d3.skippedKeys;" + -- Alice decrypts msg1 from skipped + <> "let d1 = rcDecrypt(alice, sk, e1.ciphertext); alice = d1.state; sk = d1.skippedKeys;" + -- Alice responds + <> "const ea = rcEncrypt(alice, new TextEncoder().encode('reply'), 100); alice = ea.state;" + <> "const da = rcDecrypt(bob, new Map(), ea.ciphertext); bob = da.state;" + -- Verify + <> "const ok = new TextDecoder().decode(d3.plaintext) === 'msg3'" + <> " && new TextDecoder().decode(d1.plaintext) === 'msg1'" + <> " && new TextDecoder().decode(da.plaintext) === 'reply';" + <> jsOut ("new Uint8Array([ok ? 1 : 0])") + tsResult `shouldBe` B.singleton 1 + + it "cross-language: Haskell encrypts, TypeScript decrypts" $ do + -- Round 1: TypeScript generates alice's keys, outputs private keys + smpEncoded E2E params + tsAliceOutput <- callNode $ impEnc <> impRatchet + <> "const a1 = generateX448KeyPair(), a2 = generateX448KeyPair();" + -- smpEncode E2ERatchetParams v3: (version, pk1, pk2, Maybe KEMParams) + -- Nothing = 0x30 ('0') + <> "const e2e = new Uint8Array([...encodeWord16(3), ...encodeBytes(encodePubKeyX448(a1.publicKey)), ...encodeBytes(encodePubKeyX448(a2.publicKey)), 0x30]);" + -- Output: a1.privateKey(56) + a2.privateKey(56) + e2e_len(2) + e2e_bytes + <> "const lenBuf = new Uint8Array(2); lenBuf[0] = (e2e.length >> 8) & 0xff; lenBuf[1] = e2e.length & 0xff;" + <> jsOut ("new Uint8Array([...a1.privateKey, ...a2.privateKey, ...lenBuf, ...e2e])") + let (alicePriv1, rest1) = B.splitAt 56 tsAliceOutput + (alicePriv2, rest2) = B.splitAt 56 rest1 + e2eLen = fromIntegral (B.index rest2 0) * 256 + fromIntegral (B.index rest2 1) + aliceE2EBytes = B.take e2eLen $ B.drop 2 rest2 + + -- Round 2: Haskell decodes alice's E2E params, generates bob, encrypts + g <- C.newRandom + let v = CR.currentE2EEncryptVersion + Right (aliceE2E@(CR.E2ERatchetParams _ _ alicePk2 _) :: CR.E2ERatchetParams 'CR.RKSProposed 'X448) <- pure $ smpDecode aliceE2EBytes + (bobPk1, bobPk2, _pKem, CR.AE2ERatchetParams _ bobE2E) <- CR.generateSndE2EParams @'X448 g v Nothing + Right (bobInitParams, _) <- pure $ CR.pqX3dhSnd bobPk1 bobPk2 Nothing aliceE2E + (_, bobDHRs) <- atomically $ C.generateKeyPair @'X448 g + let bobRatchet = CR.initSndRatchet (CR.RatchetVersions v v) alicePk2 bobDHRs (bobInitParams, Nothing) + Right (mek, _) <- runExceptT $ CR.rcEncryptHeader bobRatchet Nothing v + Right ciphertext <- runExceptT $ CR.rcEncryptMsg mek paddedMsgLen "hello from haskell ratchet" + let bobE2EBytes = smpEncode bobE2E + + -- Round 3: TypeScript decodes bob's params, inits ratchet, decrypts + tsResult <- callNode $ impEnc <> impRatchet + -- Parse bob's E2E params + <> "const d = new Decoder(" <> jsUint8 bobE2EBytes <> ");" + <> "const bobV = d.anyByte() * 256 + d.anyByte();" + <> "const bobPk1Raw = decodePubKeyX448(decodeBytes(d));" + <> "const bobPk2Raw = decodePubKeyX448(decodeBytes(d));" + <> "const a1Priv = " <> jsUint8 alicePriv1 <> ";" + <> "const a2Priv = " <> jsUint8 alicePriv2 <> ";" + <> "const ap = pqX3dhRcv(a1Priv, a2Priv, bobPk1Raw, bobPk2Raw);" + <> "const alice = initRcvRatchet(3, 3, a2Priv, ap, false);" + -- Debug: output rcAD length and first bytes to compare + <> "try {" + <> "const dec = rcDecrypt(alice, new Map(), " <> jsUint8 ciphertext <> ");" + <> jsOut ("dec.plaintext") + <> "} catch(e) {" + <> "console.warn('ERROR:', e.message);" + <> "console.warn('rcAD length:', alice.rcAD.length);" + <> "console.warn('rcAD first 10:', Array.from(alice.rcAD.subarray(0,10)));" + <> "process.exit(1);" + <> "}" + tsResult `shouldBe` "hello from haskell ratchet" + describe "DER encoding" $ do it "X448 DER round-trips" $ do tsResult <- callNode $ impRatchet From bc9b85e9a3f7199c0641f657c0d36c8a043446b6 Mon Sep 17 00:00:00 2001 From: "Evgeny @ SimpleX Chat" <259188159+evgeny-simplex@users.noreply.github.com> Date: Fri, 15 May 2026 11:56:33 +0000 Subject: [PATCH 6/8] PQ double ratchet --- smp-web/src/crypto/ratchet.ts | 809 +++++++++++++++++++++------------- tests/SMPWebTests.hs | 171 ++++++- 2 files changed, 663 insertions(+), 317 deletions(-) diff --git a/smp-web/src/crypto/ratchet.ts b/smp-web/src/crypto/ratchet.ts index e7befe7c5..80de7456c 100644 --- a/smp-web/src/crypto/ratchet.ts +++ b/smp-web/src/crypto/ratchet.ts @@ -1,14 +1,25 @@ -// Double ratchet with X3DH key agreement. -// Mirrors: Simplex.Messaging.Crypto.Ratchet +// Double ratchet with X3DH key agreement and PQ KEM. +// Faithful transpilation of Simplex.Messaging.Crypto.Ratchet +// +// Every type, field, and function mirrors the Haskell source. +// Line references are to src/Simplex/Messaging/Crypto/Ratchet.hs import {x448} from "@noble/curves/ed448.js" -import {hkdf, encryptAEAD, decryptAEAD} from "../crypto.js" +import {hkdf, encryptAEAD, decryptAEAD, AUTH_TAG_SIZE} from "../crypto.js" +import {sntrup761Keypair, sntrup761Enc, sntrup761Dec} from "./sntrup761.js" +import type {KEMKeyPair} from "./sntrup761.js" import { Decoder, concatBytes, encodeBytes, decodeBytes, decodeWord16, decodeWord32, encodeLarge, decodeLarge, + encodeMaybe, } from "@simplex-chat/xftp-web/dist/protocol/encoding.js" +// -- Version constants (lines 134-155) + +export const pqRatchetE2EEncryptVersion = 3 +export const currentE2EEncryptVersion = 3 + // -- X448 key operations export interface X448KeyPair { @@ -27,7 +38,7 @@ export function x448DH(publicKey: Uint8Array, privateKey: Uint8Array): Uint8Arra } // DER encoding for X448 public keys (RFC 8410, SubjectPublicKeyInfo) -// SEQUENCE { SEQUENCE { OID 1.3.101.110 } BIT STRING { 0x00 <56 bytes> } } +// SEQUENCE { SEQUENCE { OID 1.3.101.111 } BIT STRING { 0x00 <56 bytes> } } const X448_PUBKEY_DER_PREFIX = new Uint8Array([ 0x30, 0x42, 0x30, 0x05, 0x06, 0x03, 0x2b, 0x65, 0x6f, 0x03, 0x39, 0x00, ]) @@ -44,484 +55,582 @@ export function decodePubKeyX448(der: Uint8Array): Uint8Array { return der.subarray(12) } -// -- X3DH key agreement (Ratchet.hs:499-508) +// -- KEM types (lines 567-577) +// KEMKeyPair imported from ./sntrup761.js + +export interface RatchetKEMAccepted { + rcPQRr: Uint8Array // KEMPublicKey - received key (1158 bytes) + rcPQRss: Uint8Array // KEMSharedKey - computed shared secret (32 bytes) + rcPQRct: Uint8Array // KEMCiphertext - sent encaps (1039 bytes) +} + +export interface RatchetKEM { + rcPQRs: KEMKeyPair + rcKEMs: RatchetKEMAccepted | null +} + +// -- RatchetInitParams (lines 457-464) export interface RatchetInitParams { - assocData: Uint8Array // pubKeyBytes(sk1) || pubKeyBytes(rk1) - ratchetKey: Uint8Array // 32 bytes (root key) - sndHK: Uint8Array // 32 bytes (header key) - rcvNextHK: Uint8Array // 32 bytes (next header key) + assocData: Uint8Array // Str (raw bytes) + ratchetKey: Uint8Array // RatchetKey (32 bytes) + sndHK: Uint8Array // HeaderKey (32 bytes) + rcvNextHK: Uint8Array // HeaderKey (32 bytes) + kemAccepted: RatchetKEMAccepted | null // Maybe RatchetKEMAccepted } -// hkdf3 (Ratchet.hs:1174-1179) -// HKDF-SHA512, output 96 bytes, split 32+32+32 +// -- hkdf3 (lines 1174-1179) + function hkdf3(salt: Uint8Array, ikm: Uint8Array, info: string): [Uint8Array, Uint8Array, Uint8Array] { const out = hkdf(salt, ikm, info, 96) return [out.slice(0, 32), out.slice(32, 64), out.slice(64, 96)] } -const X3DH_SALT = new Uint8Array(64) // 64 zero bytes +// -- pqX3dh (lines 499-508) + +const X3DH_SALT = new Uint8Array(64) -// pqX3dh (Ratchet.hs:499-508) -// Core X3DH: three DH results + optional KEM shared secret → HKDF → init params function pqX3dh( - sk1: Uint8Array, rk1: Uint8Array, // public keys for assocData + sk1: Uint8Array, rk1: Uint8Array, dh1: Uint8Array, dh2: Uint8Array, dh3: Uint8Array, - kemSharedSecret: Uint8Array | null, // PQ KEM shared secret, 32 bytes + kemAccepted: RatchetKEMAccepted | null, ): RatchetInitParams { const assocData = concatBytes(sk1, rk1) - const dhs = kemSharedSecret - ? concatBytes(dh1, dh2, dh3, kemSharedSecret) - : concatBytes(dh1, dh2, dh3) + const pq = kemAccepted ? kemAccepted.rcPQRss : new Uint8Array(0) + const dhs = concatBytes(dh1, dh2, dh3, pq) const [hk, nhk, sk] = hkdf3(X3DH_SALT, dhs, "SimpleXX3DH") - return {assocData, ratchetKey: sk, sndHK: hk, rcvNextHK: nhk} + return {assocData, ratchetKey: sk, sndHK: hk, rcvNextHK: nhk, kemAccepted} } -// pqX3dhSnd (Ratchet.hs:467-480) -// Used by joiner (Bob) to initialize SENDING ratchet. -// Our keys: spk1, spk2 (private). Their keys: rk1, rk2 (public, from invitation). +// -- pqX3dhSnd (lines 467-480) +// Used by joiner (Alice in PQDR spec, Bob in DR spec) to init SENDING ratchet. + export function pqX3dhSnd( - spk1: Uint8Array, spk2: Uint8Array, // our private keys - rk1: Uint8Array, rk2: Uint8Array, // their public keys (raw, not DER) - kemSharedSecret: Uint8Array | null = null, + spk1: Uint8Array, spk2: Uint8Array, // our private keys + rk1: Uint8Array, rk2: Uint8Array, // their public keys (raw) + kemAccepted: RatchetKEMAccepted | null = null, ): RatchetInitParams { const sk1Pub = x448.getPublicKey(spk1) const dh1 = x448DH(rk1, spk2) const dh2 = x448DH(rk2, spk1) const dh3 = x448DH(rk2, spk2) - return pqX3dh(sk1Pub, rk1, dh1, dh2, dh3, kemSharedSecret) + return pqX3dh(sk1Pub, rk1, dh1, dh2, dh3, kemAccepted) } -// pqX3dhRcv (Ratchet.hs:483-497) -// Used by initiator (Alice) to initialize RECEIVING ratchet. -// Our keys: rpk1, rpk2 (private). Their keys: sk1, sk2 (public, from confirmation). +// -- pqX3dhRcv (lines 483-497) +// Used by initiator (Bob in PQDR spec, Alice in DR spec) to init RECEIVING ratchet. + export function pqX3dhRcv( - rpk1: Uint8Array, rpk2: Uint8Array, // our private keys - sk1: Uint8Array, sk2: Uint8Array, // their public keys (raw, not DER) - kemSharedSecret: Uint8Array | null = null, + rpk1: Uint8Array, rpk2: Uint8Array, // our private keys + sk1: Uint8Array, sk2: Uint8Array, // their public keys (raw) + kemAccepted: RatchetKEMAccepted | null = null, ): RatchetInitParams { const rk1Pub = x448.getPublicKey(rpk1) const dh1 = x448DH(sk2, rpk1) const dh2 = x448DH(sk1, rpk2) const dh3 = x448DH(sk2, rpk2) - return pqX3dh(sk1, rk1Pub, dh1, dh2, dh3, kemSharedSecret) + return pqX3dh(sk1, rk1Pub, dh1, dh2, dh3, kemAccepted) } -// -- KDF functions (Ratchet.hs:1159-1179) +// -- rootKdf (lines 1159-1166) -const EMPTY_SALT = new Uint8Array(0) - -// rootKdf (Ratchet.hs:1159-1166) -// HKDF-SHA512 with DH result + optional KEM shared secret export function rootKdf( - ratchetKey: Uint8Array, // 32 bytes - peerPubKey: Uint8Array, // raw X448 public key, 56 bytes - ownPrivKey: Uint8Array, // raw X448 private key, 56 bytes - kemSecret: Uint8Array | null, // optional KEM shared secret + rk: Uint8Array, // RatchetKey (32 bytes) + peerPubKey: Uint8Array, // PublicKey a (raw, 56 bytes for X448) + ownPrivKey: Uint8Array, // PrivateKey a (raw, 56 bytes for X448) + kemSecret: Uint8Array | null, // Maybe KEMSharedKey ): {rk: Uint8Array; ck: Uint8Array; nhk: Uint8Array} { const dhOut = x448DH(peerPubKey, ownPrivKey) const ss = kemSecret ? concatBytes(dhOut, kemSecret) : dhOut - const [rk, ck, nhk] = hkdf3(ratchetKey, ss, "SimpleXRootRatchet") - return {rk, ck, nhk} + const [rk_, ck, nhk] = hkdf3(rk, ss, "SimpleXRootRatchet") + return {rk: rk_, ck, nhk} } -// chainKdf (Ratchet.hs:1168-1172) -// HKDF-SHA512 with empty salt, produces chain key + message key + two 16-byte IVs -export function chainKdf(chainKey: Uint8Array): {ck: Uint8Array; mk: Uint8Array; iv: Uint8Array; ehIV: Uint8Array} { - const [ck, mk, ivs] = hkdf3(EMPTY_SALT, chainKey, "SimpleXChainRatchet") - return {ck, mk, iv: ivs.slice(0, 16), ehIV: ivs.slice(16, 32)} -} +// -- chainKdf (lines 1168-1172) -// -- Header padding (Ratchet.hs:716-719) +export function chainKdf(ck: Uint8Array): {ck: Uint8Array; mk: Uint8Array; iv: Uint8Array; ehIV: Uint8Array} { + const EMPTY = new Uint8Array(0) + const [ck_, mk, ivs] = hkdf3(EMPTY, ck, "SimpleXChainRatchet") + return {ck: ck_, mk, iv: ivs.slice(0, 16), ehIV: ivs.slice(16, 32)} +} -const PADDED_HEADER_LEN_NO_PQ = 88 -const PADDED_HEADER_LEN_PQ = 2310 +// -- Header padding (lines 716-719) -export function paddedHeaderLen(pqSupport: boolean): number { - return pqSupport ? PADDED_HEADER_LEN_PQ : PADDED_HEADER_LEN_NO_PQ +export function paddedHeaderLen(v: number, pqSupport: boolean): number { + if (pqSupport && v >= pqRatchetE2EEncryptVersion) return 2310 + return 88 } -// -- Ratchet state (Ratchet.hs:512-565) +// -- SndRatchet (lines 554-559) export interface SndRatchet { - rcDHRr: Uint8Array // peer's X448 public key (raw, 56 bytes) + rcDHRr: Uint8Array // peer's public key (raw, 56 bytes) rcCKs: Uint8Array // sending chain key (32 bytes) rcHKs: Uint8Array // sending header key (32 bytes) } +// -- RcvRatchet (lines 561-565) + export interface RcvRatchet { rcCKr: Uint8Array // receiving chain key (32 bytes) rcHKr: Uint8Array // receiving header key (32 bytes) } +// -- MessageKey (lines 608-609) + export interface MessageKey { - mk: Uint8Array // 32 bytes - iv: Uint8Array // 16 bytes + mk: Uint8Array // Key (32 bytes) + iv: Uint8Array // IV (16 bytes) } -// Skipped message keys: Map> -// Using string keys for the outer map (hex-encoded header key) -export type SkippedMsgKeys = Map> +// -- RatchetVersions (lines 534-538) -export interface RatchetState { - // version - rcVersion: number // current e2e version - rcMaxVersion: number // max supported e2e version - // associated data - rcAD: Uint8Array - // DH ratchet key pair (our private key) - rcDHRs: Uint8Array // X448 private key (56 bytes) - // PQ support - rcSupportKEM: boolean - // root key - rcRK: Uint8Array // 32 bytes - // sending ratchet (null before first message sent after ratchet advance) +export interface RatchetVersions { + current: number + maxSupported: number +} + +// -- Ratchet (lines 512-532) + +export interface Ratchet { + rcVersion: RatchetVersions + rcAD: Uint8Array // Str (associated data, raw bytes) + rcDHRs: Uint8Array // PrivateKey a (raw, 56 bytes) + rcKEM: RatchetKEM | null + rcSupportKEM: boolean // PQSupport + rcEnableKEM: boolean // PQEncryption + rcSndKEM: boolean // PQEncryption + rcRcvKEM: boolean // PQEncryption + rcRK: Uint8Array // RatchetKey (32 bytes) rcSnd: SndRatchet | null - // receiving ratchet (null before first message received) rcRcv: RcvRatchet | null - // counters - rcNs: number // sending message number - rcNr: number // receiving message number - rcPN: number // previous sending chain length - // next header keys - rcNHKs: Uint8Array // 32 bytes - rcNHKr: Uint8Array // 32 bytes + rcNs: number // Word32 + rcNr: number // Word32 + rcPN: number // Word32 + rcNHKs: Uint8Array // HeaderKey (32 bytes) + rcNHKr: Uint8Array // HeaderKey (32 bytes) } +// -- SkippedMsgKeys (lines 580-582) + +export type SkippedMsgKeys = Map> + const MAX_SKIP = 512 function hexKey(k: Uint8Array): string { return Array.from(k, b => b.toString(16).padStart(2, "0")).join("") } -// -- Ratchet initialization (Ratchet.hs:643-699) +function hexToBytes(hex: string): Uint8Array { + const bytes = new Uint8Array(hex.length / 2) + for (let i = 0; i < hex.length; i += 2) bytes[i / 2] = parseInt(hex.substring(i, i + 2), 16) + return bytes +} + +// -- initSndRatchet (lines 643-666) -// initSndRatchet (Ratchet.hs:643-666) -// Used by joiner (Bob) after X3DH export function initSndRatchet( - version: number, - maxVersion: number, - rcDHRr: Uint8Array, // peer's X448 public key (raw, from invitation) - rcDHRs: Uint8Array, // our X448 private key (raw) + rcVersion: RatchetVersions, + rcDHRr: Uint8Array, // peer's public key (raw) + rcDHRs: Uint8Array, // our private key (raw) initParams: RatchetInitParams, - pqSupport: boolean, - kemSharedSecret: Uint8Array | null = null, -): RatchetState { - const {rk, ck, nhk} = rootKdf(initParams.ratchetKey, rcDHRr, rcDHRs, kemSharedSecret) + rcPQRs_: KEMKeyPair | null, +): Ratchet { + const {assocData, ratchetKey, sndHK, rcvNextHK, kemAccepted} = initParams + // state.RK, state.CKs, state.NHKs = KDF_RK_HE(SK, DH(state.DHRs, state.DHRr) || state.PQRss) + const kemSecret = kemAccepted ? kemAccepted.rcPQRss : null + const {rk: rcRK, ck: rcCKs, nhk: rcNHKs} = rootKdf(ratchetKey, rcDHRr, rcDHRs, kemSecret) + const pqOn = rcPQRs_ !== null return { - rcVersion: version, - rcMaxVersion: maxVersion, - rcAD: initParams.assocData, + rcVersion, + rcAD: assocData, rcDHRs, - rcSupportKEM: pqSupport, - rcRK: rk, - rcSnd: {rcDHRr, rcCKs: ck, rcHKs: initParams.sndHK}, + rcKEM: rcPQRs_ ? {rcPQRs: rcPQRs_, rcKEMs: kemAccepted} : null, + rcSupportKEM: pqOn, + rcEnableKEM: pqOn, + rcSndKEM: kemAccepted !== null, + rcRcvKEM: false, + rcRK, + rcSnd: {rcDHRr, rcCKs, rcHKs: sndHK}, rcRcv: null, + rcPN: 0, rcNs: 0, rcNr: 0, - rcPN: 0, - rcNHKs: nhk, - rcNHKr: initParams.rcvNextHK, + rcNHKs, + rcNHKr: rcvNextHK, } } -// initRcvRatchet (Ratchet.hs:674-699) -// Used by initiator (Alice) after receiving confirmation +// -- initRcvRatchet (lines 674-699) + export function initRcvRatchet( - version: number, - maxVersion: number, - rcDHRs: Uint8Array, // our X448 private key (raw) + rcVersion: RatchetVersions, + rcDHRs: Uint8Array, // our private key (raw) initParams: RatchetInitParams, + rcPQRs_: KEMKeyPair | null, pqSupport: boolean, -): RatchetState { +): Ratchet { + const {assocData, ratchetKey, sndHK, rcvNextHK, kemAccepted} = initParams return { - rcVersion: version, - rcMaxVersion: maxVersion, - rcAD: initParams.assocData, + rcVersion, + rcAD: assocData, rcDHRs, + rcKEM: rcPQRs_ ? {rcPQRs: rcPQRs_, rcKEMs: kemAccepted} : null, rcSupportKEM: pqSupport, - rcRK: initParams.ratchetKey, + rcEnableKEM: pqSupport, + rcSndKEM: false, + rcRcvKEM: false, + rcRK: ratchetKey, rcSnd: null, rcRcv: null, + rcPN: 0, rcNs: 0, rcNr: 0, - rcPN: 0, - rcNHKs: initParams.rcvNextHK, - rcNHKr: initParams.sndHK, + rcNHKs: rcvNextHK, + rcNHKr: sndHK, } } -// -- Message header (Ratchet.hs:703-787) +// -- RKEMParams (lines 188-190) - parsed KEM params from message header + +export type RKEMParams = + | {type: "proposed", kemPk: Uint8Array} // RKParamsProposed KEMPublicKey + | {type: "accepted", kemCt: Uint8Array, kemPk: Uint8Array} // RKParamsAccepted KEMCiphertext KEMPublicKey + +// -- MsgHeader (lines 703-711) interface MsgHeader { msgMaxVersion: number - msgDHRs: Uint8Array // X448 public key (raw, 56 bytes) - msgPN: number - msgNs: number + msgDHRs: Uint8Array // PublicKey a (raw, 56 bytes) + msgKEM: RKEMParams | null + msgPN: number // Word32 + msgNs: number // Word32 } +// -- encodeMsgHeader (lines 727-730) + function encodeMsgHeader(v: number, hdr: MsgHeader): Uint8Array { - const versionBytes = new Uint8Array(2) - versionBytes[0] = (hdr.msgMaxVersion >> 8) & 0xff - versionBytes[1] = hdr.msgMaxVersion & 0xff + const vBytes = new Uint8Array(2) + vBytes[0] = (hdr.msgMaxVersion >> 8) & 0xff + vBytes[1] = hdr.msgMaxVersion & 0xff const dhDer = encodePubKeyX448(hdr.msgDHRs) - const pn = new Uint8Array(4) - pn[0] = (hdr.msgPN >> 24) & 0xff; pn[1] = (hdr.msgPN >> 16) & 0xff - pn[2] = (hdr.msgPN >> 8) & 0xff; pn[3] = hdr.msgPN & 0xff - const ns = new Uint8Array(4) - ns[0] = (hdr.msgNs >> 24) & 0xff; ns[1] = (hdr.msgNs >> 16) & 0xff - ns[2] = (hdr.msgNs >> 8) & 0xff; ns[3] = hdr.msgNs & 0xff - // v >= pqRatchetE2EEncryptVersion (v3): includes KEM params (Maybe, encoded as Nothing for now) - if (v >= 3) { - return concatBytes(versionBytes, encodeBytes(dhDer), new Uint8Array([0x30]), pn, ns) // '0' = Nothing for KEM + const pn = encodeWord32(hdr.msgPN) + const ns = encodeWord32(hdr.msgNs) + if (v >= pqRatchetE2EEncryptVersion) { + // smpEncode (msgMaxVersion, msgDHRs, msgKEM, msgPN, msgNs) + // msgKEM :: Maybe ARKEMParams + const kemBytes = hdr.msgKEM ? encodeRKEMParams(hdr.msgKEM) : new Uint8Array([0x30]) // Nothing + return concatBytes(vBytes, encodeBytes(dhDer), kemBytes, pn, ns) } - return concatBytes(versionBytes, encodeBytes(dhDer), pn, ns) + // smpEncode (msgMaxVersion, msgDHRs, msgPN, msgNs) + return concatBytes(vBytes, encodeBytes(dhDer), pn, ns) +} + +// Encode Maybe ARKEMParams: '1' + encoded params, or nothing (handled at call site with '0') +function encodeRKEMParams(params: RKEMParams): Uint8Array { + if (params.type === "proposed") { + // Just ('P', kemPk) - smpEncode ('P', k) where k is KEMPublicKey (Large) + return concatBytes(new Uint8Array([0x31, 0x50]), encodeLarge(params.kemPk)) + } + // Just ('A', ct, kemPk) - smpEncode ('A', ct, k) + return concatBytes(new Uint8Array([0x31, 0x41]), encodeLarge(params.kemCt), encodeLarge(params.kemPk)) +} + +function encodeWord32(n: number): Uint8Array { + const buf = new Uint8Array(4) + buf[0] = (n >> 24) & 0xff; buf[1] = (n >> 16) & 0xff + buf[2] = (n >> 8) & 0xff; buf[3] = n & 0xff + return buf } +// -- msgHeaderP (lines 733-740) + function decodeMsgHeader(v: number, data: Uint8Array): MsgHeader { const d = new Decoder(data) const msgMaxVersion = decodeWord16(d) const dhDer = decodeBytes(d) const msgDHRs = decodePubKeyX448(dhDer) - // skip KEM params for v3+ - if (v >= 3) { - const kemByte = d.anyByte() - if (kemByte === 0x31) { - // Just - skip KEM params (we don't process them in this simplified version) - decodeBytes(d) // KEM params + let msgKEM: RKEMParams | null = null + if (v >= pqRatchetE2EEncryptVersion) { + // Maybe ARKEMParams + const maybeByte = d.anyByte() + if (maybeByte === 0x31) { + // Just - parse ARKEMParams + const tag = d.anyByte() + if (tag === 0x50) { // 'P' - Proposed: KEMPublicKey (Large) + msgKEM = {type: "proposed", kemPk: decodeLarge(d)} + } else if (tag === 0x41) { // 'A' - Accepted: KEMCiphertext (Large) + KEMPublicKey (Large) + const kemCt = decodeLarge(d) + const kemPk = decodeLarge(d) + msgKEM = {type: "accepted", kemCt, kemPk} + } else { + throw new Error("decodeMsgHeader: unknown KEM tag " + tag) + } } - // else '0' = Nothing, already consumed + // else '0' = Nothing, msgKEM stays null } const msgPN = decodeWord32(d) const msgNs = decodeWord32(d) - return {msgMaxVersion, msgDHRs, msgPN, msgNs} + return {msgMaxVersion, msgDHRs, msgKEM, msgPN, msgNs} +} + +// -- EncMessageHeader (lines 742-756) + +interface EncMessageHeader { + ehVersion: number // current ratchet version + ehIV: Uint8Array // IV (raw 16 bytes) + ehAuthTag: Uint8Array // AuthTag (raw 16 bytes) + ehBody: Uint8Array // encrypted header body +} + +// smpEncode (lines 751-752) +function encodeEncMessageHeader(emh: EncMessageHeader): Uint8Array { + const vBytes = new Uint8Array(2) + vBytes[0] = (emh.ehVersion >> 8) & 0xff + vBytes[1] = emh.ehVersion & 0xff + // smpEncode (ehVersion, ehIV, ehAuthTag) <> encodeLarge ehVersion ehBody + const bodyEnc = emh.ehVersion >= pqRatchetE2EEncryptVersion + ? encodeLarge(emh.ehBody) + : encodeBytes(emh.ehBody) + return concatBytes(vBytes, emh.ehIV, emh.ehAuthTag, bodyEnc) +} + +// smpP (lines 753-756) +function decodeEncMessageHeader(data: Uint8Array): EncMessageHeader { + const d = new Decoder(data) + const ehVersion = decodeWord16(d) + const ehIV = d.take(16) // IV is raw 16 bytes + const ehAuthTag = d.take(16) // AuthTag is raw 16 bytes + // largeP: peek first byte, if < 32 then Large (2-byte len), else ByteString (1-byte len) + const firstByte = data[d.offset()] + const ehBody = firstByte < 32 ? decodeLarge(d) : decodeBytes(d) + return {ehVersion, ehIV, ehAuthTag, ehBody} +} + +// -- EncRatchetMessage (lines 772-787) + +interface EncRatchetMessage { + emHeader: Uint8Array // smpEncoded EncMessageHeader + emAuthTag: Uint8Array // AuthTag (raw 16 bytes) + emBody: Uint8Array // encrypted message body +} + +// encodeEncRatchetMessage (lines 779-781) +function encodeEncRatchetMessage(v: number, msg: EncRatchetMessage): Uint8Array { + // encodeLarge v emHeader <> smpEncode (emAuthTag, Tail emBody) + const headerEnc = v >= pqRatchetE2EEncryptVersion + ? encodeLarge(msg.emHeader) + : encodeBytes(msg.emHeader) + return concatBytes(headerEnc, msg.emAuthTag, msg.emBody) +} + +// encRatchetMessageP (lines 783-787) +function decodeEncRatchetMessage(data: Uint8Array): EncRatchetMessage { + const d = new Decoder(data) + // largeP + const firstByte = data[d.offset()] + const emHeader = firstByte < 32 ? decodeLarge(d) : decodeBytes(d) + // smpEncode (emAuthTag, Tail emBody) → raw 16 bytes + rest + const emAuthTag = d.take(16) + const emBody = d.takeAll() + return {emHeader, emAuthTag, emBody} +} + +// -- MsgEncryptKey (lines 962-968) + +interface MsgEncryptKey { + msgRcVersion: number + msgKey: MessageKey + msgRcAD: Uint8Array + msgEncHeader: Uint8Array +} + +// -- msgKEMParams (lines 956-958) - build KEM params from ratchet state for message header + +function msgKEMParams(kem: RatchetKEM): RKEMParams { + const {rcPQRs, rcKEMs} = kem + if (!rcKEMs) { + return {type: "proposed", kemPk: rcPQRs.publicKey} + } + return {type: "accepted", kemCt: rcKEMs.rcPQRct, kemPk: rcPQRs.publicKey} } -// -- Encrypt (Ratchet.hs:902-975) +// -- pqEnableSupport (line 836-837) + +function pqEnableSupport(v: number, sup: boolean, enc: boolean): boolean { + return sup || (v >= pqRatchetE2EEncryptVersion && enc) +} + +// -- rcEncryptHeader + rcEncryptMsg (lines 902-975) export interface EncryptResult { ciphertext: Uint8Array - state: RatchetState + state: Ratchet } export function rcEncrypt( - state: RatchetState, + rc: Ratchet, plaintext: Uint8Array, paddedMsgLen: number, ): EncryptResult { - if (!state.rcSnd) throw new Error("rcEncrypt: no sending ratchet") - const snd = state.rcSnd - const v = state.rcVersion + if (!rc.rcSnd) throw new Error("rcEncrypt: no sending ratchet (CERatchetState)") + const snd = rc.rcSnd + const v = rc.rcVersion.current - // Advance chain: state.CKs, mk = KDF_CK(state.CKs) - const {ck: rcCKs, mk, iv, ehIV} = chainKdf(snd.rcCKs) + // state.CKs, mk = KDF_CK(state.CKs) + const chain = chainKdf(snd.rcCKs) - // Build and encrypt header + // header const headerPlain = encodeMsgHeader(v, { - msgMaxVersion: state.rcMaxVersion, - msgDHRs: x448.getPublicKey(state.rcDHRs), - msgPN: state.rcPN, - msgNs: state.rcNs, + msgMaxVersion: rc.rcVersion.maxSupported, + msgDHRs: x448.getPublicKey(rc.rcDHRs), + msgKEM: rc.rcKEM ? msgKEMParams(rc.rcKEM) : null, + msgPN: rc.rcPN, + msgNs: rc.rcNs, }) - const phl = paddedHeaderLen(state.rcSupportKEM) - const {authTag: ehAuthTag, ciphertext: ehBody} = encryptAEAD(snd.rcHKs, ehIV, phl, state.rcAD, headerPlain) + // enc_header = HENCRYPT(state.HKs, header) + const phl = paddedHeaderLen(v, rc.rcSupportKEM) + const {authTag: ehAuthTag, ciphertext: ehBody} = encryptAEAD(snd.rcHKs, chain.ehIV, phl, rc.rcAD, headerPlain) - // Encode EncMessageHeader - const ehVersionBytes = new Uint8Array(2) - ehVersionBytes[0] = (v >> 8) & 0xff; ehVersionBytes[1] = v & 0xff - // IV and AuthTag are raw bytes (no length prefix). Body is Large for v3+, ByteString for older. - const encHeader = concatBytes(ehVersionBytes, ehIV, ehAuthTag, v >= 3 ? encodeLarge(ehBody) : encodeBytes(ehBody)) + // smpEncode EncMessageHeader + const emHeader = encodeEncMessageHeader({ehVersion: v, ehBody, ehAuthTag, ehIV: chain.ehIV}) - // Encrypt body: ENCRYPT(mk, plaintext, CONCAT(AD, enc_header)) - const bodyAD = concatBytes(state.rcAD, encHeader) - const {authTag: emAuthTag, ciphertext: emBody} = encryptAEAD(mk, iv, paddedMsgLen, bodyAD, plaintext) + // ENCRYPT(mk, plaintext, CONCAT(AD, enc_header)) + const bodyAD = concatBytes(rc.rcAD, emHeader) + const {authTag: emAuthTag, ciphertext: emBody} = encryptAEAD(chain.mk, chain.iv, paddedMsgLen, bodyAD, plaintext) - // Encode EncRatchetMessage - // AuthTag is raw 16 bytes (no length prefix), body is Tail (raw bytes) - const msgBytes = concatBytes(v >= 3 ? encodeLarge(encHeader) : encodeBytes(encHeader), emAuthTag, emBody) + // encodeEncRatchetMessage + const ciphertext = encodeEncRatchetMessage(v, {emHeader, emBody, emAuthTag}) // Update state - const newState: RatchetState = { - ...state, - rcSnd: {...snd, rcCKs}, - rcNs: state.rcNs + 1, + const newState: Ratchet = { + ...rc, + rcSnd: {...snd, rcCKs: chain.ck}, + rcNs: rc.rcNs + 1, } - return {ciphertext: msgBytes, state: newState} + return {ciphertext, state: newState} } -// -- Decrypt (Ratchet.hs:990-1157) +// -- rcDecrypt (lines 990-1157) export interface DecryptResult { plaintext: Uint8Array - state: RatchetState + state: Ratchet skippedKeys: SkippedMsgKeys } -interface EncRatchetMessage { - emHeader: Uint8Array - emAuthTag: Uint8Array - emBody: Uint8Array -} - -interface EncMessageHeader { - ehVersion: number - ehIV: Uint8Array - ehAuthTag: Uint8Array - ehBody: Uint8Array -} - -function parseEncRatchetMessage(data: Uint8Array): EncRatchetMessage { - const d = new Decoder(data) - // header is length-prefixed (Large for v3+, ByteString for older) - const firstByte = data[d.offset()] - const emHeader = firstByte < 32 ? decodeLarge(d) : decodeBytes(d) - // AuthTag is raw 16 bytes (no length prefix) - const emAuthTag = d.take(16) - const emBody = d.takeAll() - return {emHeader, emAuthTag, emBody} -} - -function parseEncMessageHeader(data: Uint8Array): EncMessageHeader { - const d = new Decoder(data) - const ehVersion = decodeWord16(d) - // IV is raw 16 bytes, AuthTag is raw 16 bytes - const ehIV = d.take(16) - const ehAuthTag = d.take(16) - // body: Large for v3+ (first byte < 32), ByteString for older - const firstByte = data[d.offset()] - const ehBody = firstByte < 32 ? decodeLarge(d) : decodeBytes(d) - return {ehVersion, ehIV, ehAuthTag, ehBody} -} - -function tryDecryptHeader(headerKey: Uint8Array, ad: Uint8Array, encHdr: EncMessageHeader): MsgHeader | null { - try { - const plainHeader = decryptAEAD(headerKey, encHdr.ehIV, ad, encHdr.ehBody, encHdr.ehAuthTag) - return decodeMsgHeader(encHdr.ehVersion, plainHeader) - } catch { - return null - } -} - -function decryptMessage(mk: Uint8Array, iv: Uint8Array, ad: Uint8Array, encHeader: Uint8Array, encMsg: EncRatchetMessage): Uint8Array { - const bodyAD = concatBytes(ad, encHeader) - return decryptAEAD(mk, iv, bodyAD, encMsg.emBody, encMsg.emAuthTag) -} - export function rcDecrypt( - state: RatchetState, + rc: Ratchet, skippedKeys: SkippedMsgKeys, ciphertext: Uint8Array, ): DecryptResult { - const encMsg = parseEncRatchetMessage(ciphertext) - const encHdr = parseEncMessageHeader(encMsg.emHeader) + const encMsg = decodeEncRatchetMessage(ciphertext) + const encHdr = decodeEncMessageHeader(encMsg.emHeader) - // Try skipped message keys - for (const [hkHex, msgKeys] of skippedKeys) { - const hk = hexToBytes(hkHex) - const hdr = tryDecryptHeader(hk, state.rcAD, encHdr) - if (hdr) { - const mk = msgKeys.get(hdr.msgNs) - if (mk) { - // Found in skipped keys - decrypt and remove - const plaintext = decryptMessage(mk.mk, mk.iv, state.rcAD, encMsg.emHeader, encMsg) - const newMsgKeys = new Map(msgKeys) - newMsgKeys.delete(hdr.msgNs) - const newSkipped = new Map(skippedKeys) - if (newMsgKeys.size === 0) newSkipped.delete(hkHex) - else newSkipped.set(hkHex, newMsgKeys) - return {plaintext, state, skippedKeys: newSkipped} - } - } - } + // TrySkippedMessageKeysHE + const skipped = tryDecryptSkipped(rc, skippedKeys, encHdr, encMsg) + if (skipped) return skipped - // Try current receiving ratchet header key + // DecryptHeader let ratchetStep: "same" | "advance" = "advance" let hdr: MsgHeader | null = null - if (state.rcRcv) { - hdr = tryDecryptHeader(state.rcRcv.rcHKr, state.rcAD, encHdr) + if (rc.rcRcv) { + hdr = tryDecryptHeader(rc.rcRcv.rcHKr, rc.rcAD, encHdr) if (hdr) ratchetStep = "same" } - - // Try next header key (advance ratchet) if (!hdr) { - hdr = tryDecryptHeader(state.rcNHKr, state.rcAD, encHdr) - if (!hdr) throw new Error("rcDecrypt: header decryption failed") + hdr = tryDecryptHeader(rc.rcNHKr, rc.rcAD, encHdr) + if (!hdr) throw new Error("rcDecrypt: header decryption failed (CERatchetHeader)") ratchetStep = "advance" } - // Upgrade version - let rc = state - if (hdr.msgMaxVersion > rc.rcVersion) { - rc = {...rc, rcVersion: Math.max(rc.rcVersion, Math.min(hdr.msgMaxVersion, rc.rcMaxVersion))} + // Version upgrade + let state = rc + const {current, maxSupported} = rc.rcVersion + if (hdr.msgMaxVersion > current) { + state = {...state, rcVersion: {...state.rcVersion, current: Math.max(current, Math.min(hdr.msgMaxVersion, maxSupported))}} } let newSkipped = new Map(skippedKeys) if (ratchetStep === "advance") { - // Skip message keys for previous ratchet - const skipResult = skipMessageKeys(rc, newSkipped, hdr.msgPN) - rc = skipResult.state - newSkipped = skipResult.skippedKeys - - // DH ratchet step - const rcDHRs_new = generateX448KeyPair() - // state.RK, state.CKr, state.NHKr = KDF_RK_HE(state.RK, DH(state.DHRs, header.dh)) - const {rk: rcRK1, ck: rcCKr, nhk: rcNHKr} = rootKdf(rc.rcRK, hdr.msgDHRs, rc.rcDHRs, null) - // state.RK, state.CKs, state.NHKs = KDF_RK_HE(state.RK, DH(state.DHRs', header.dh)) - const {rk: rcRK2, ck: rcCKs, nhk: rcNHKs} = rootKdf(rcRK1, hdr.msgDHRs, rcDHRs_new.privateKey, null) - - rc = { - ...rc, - rcDHRs: rcDHRs_new.privateKey, - rcRK: rcRK2, - rcSnd: {rcDHRr: hdr.msgDHRs, rcCKs, rcHKs: rc.rcNHKs}, - rcRcv: {rcCKr, rcHKr: rc.rcNHKr}, + // SkipMessageKeysHE(state, header.pn) + const skip1 = skipMessageKeys(state, newSkipped, hdr.msgPN) + state = skip1.state; newSkipped = skip1.skippedKeys + + // DHRatchetPQ2HE(state, header) - ratchet step (lines 1043-1071) + const {kemSS, kemSS2, rcKEM: rcKEM_} = pqRatchetStep(state, hdr.msgKEM) + const newDHRs = generateX448KeyPair() + // state.RK, state.CKr, state.NHKr = KDF_RK_HE(state.RK, DH(state.DHRs, state.DHRr) || ss) + const kdf1 = rootKdf(state.rcRK, hdr.msgDHRs, state.rcDHRs, kemSS) + // state.RK, state.CKs, state.NHKs = KDF_RK_HE(state.RK, DH(state.DHRs', state.DHRr) || state.PQRss) + const kdf2 = rootKdf(kdf1.rk, hdr.msgDHRs, newDHRs.privateKey, kemSS2) + const sndKEM = kemSS2 !== null + const rcvKEM = kemSS !== null + const rcEnableKEM_ = sndKEM || rcvKEM || rcKEM_ !== null + + state = { + ...state, + rcDHRs: newDHRs.privateKey, + rcKEM: rcKEM_, + rcSupportKEM: pqEnableSupport(state.rcVersion.current, state.rcSupportKEM, rcEnableKEM_), + rcEnableKEM: rcEnableKEM_, + rcSndKEM: sndKEM, + rcRcvKEM: rcvKEM, + rcRK: kdf2.rk, + rcSnd: {rcDHRr: hdr.msgDHRs, rcCKs: kdf2.ck, rcHKs: state.rcNHKs}, + rcRcv: {rcCKr: kdf1.ck, rcHKr: state.rcNHKr}, rcPN: rc.rcNs, rcNs: 0, rcNr: 0, - rcNHKs, - rcNHKr, + rcNHKs: kdf2.nhk, + rcNHKr: kdf1.nhk, } } - // Skip message keys for current ratchet - const skipResult2 = skipMessageKeys(rc, newSkipped, hdr.msgNs) - rc = skipResult2.state - newSkipped = skipResult2.skippedKeys + // SkipMessageKeysHE(state, header.n) + const skip2 = skipMessageKeys(state, newSkipped, hdr.msgNs) + state = skip2.state; newSkipped = skip2.skippedKeys - if (!rc.rcRcv) throw new Error("rcDecrypt: no receiving ratchet after skip") + if (!state.rcRcv) throw new Error("rcDecrypt: no receiving ratchet after skip") - // Decrypt message - const {ck: rcCKr, mk, iv} = chainKdf(rc.rcRcv.rcCKr) - const plaintext = decryptMessage(mk, iv, rc.rcAD, encMsg.emHeader, encMsg) + // state.CKr, mk = KDF_CK(state.CKr) + const chain = chainKdf(state.rcRcv.rcCKr) - rc = { - ...rc, - rcRcv: {...rc.rcRcv, rcCKr}, - rcNr: rc.rcNr + 1, + // DECRYPT(mk, cipher-text, CONCAT(AD, enc_header)) + const bodyAD = concatBytes(state.rcAD, encMsg.emHeader) + const plaintext = decryptAEAD(chain.mk, chain.iv, bodyAD, encMsg.emBody, encMsg.emAuthTag) + + // state.Nr += 1 + state = { + ...state, + rcRcv: {...state.rcRcv, rcCKr: chain.ck}, + rcNr: state.rcNr + 1, } - return {plaintext, state: rc, skippedKeys: newSkipped} + return {plaintext, state, skippedKeys: newSkipped} } +// -- skipMessageKeys (lines 1105-1121) + function skipMessageKeys( - state: RatchetState, + rc: Ratchet, skippedKeys: SkippedMsgKeys, untilN: number, -): {state: RatchetState; skippedKeys: SkippedMsgKeys} { - if (!state.rcRcv) return {state, skippedKeys} - const rcv = state.rcRcv - const rcNr = state.rcNr +): {state: Ratchet; skippedKeys: SkippedMsgKeys} { + if (!rc.rcRcv) return {state: rc, skippedKeys} + const rcv = rc.rcRcv + const rcNr = rc.rcNr - if (rcNr > untilN + 1) throw new Error("rcDecrypt: earlier message") - if (rcNr === untilN + 1) throw new Error("rcDecrypt: duplicate message") - if (rcNr + MAX_SKIP < untilN) throw new Error("rcDecrypt: too many skipped") - if (rcNr === untilN) return {state, skippedKeys} + if (rcNr > untilN + 1) throw new Error("rcDecrypt: earlier message (CERatchetEarlierMessage)") + if (rcNr === untilN + 1) throw new Error("rcDecrypt: duplicate message (CERatchetDuplicateMessage)") + if (rcNr + MAX_SKIP < untilN) throw new Error("rcDecrypt: too many skipped (CERatchetTooManySkipped)") + if (rcNr === untilN) return {state: rc, skippedKeys} - // Advance receiving ratchet, storing skipped keys + // advanceRcvRatchet let ck = rcv.rcCKr let nr = rcNr const hkHex = hexKey(rcv.rcHKr) @@ -538,15 +647,103 @@ function skipMessageKeys( newSkipped.set(hkHex, msgKeys) return { - state: {...state, rcRcv: {...rcv, rcCKr: ck}, rcNr: nr}, + state: {...rc, rcRcv: {...rcv, rcCKr: ck}, rcNr: nr}, skippedKeys: newSkipped, } } -function hexToBytes(hex: string): Uint8Array { - const bytes = new Uint8Array(hex.length / 2) - for (let i = 0; i < hex.length; i += 2) { - bytes[i / 2] = parseInt(hex.substring(i, i + 2), 16) +// -- tryDecryptSkipped (lines 1122-1141) + +function tryDecryptSkipped( + rc: Ratchet, + skippedKeys: SkippedMsgKeys, + encHdr: EncMessageHeader, + encMsg: EncRatchetMessage, +): DecryptResult | null { + for (const [hkHex, msgKeys] of skippedKeys) { + const hk = hexToBytes(hkHex) + const hdr = tryDecryptHeader(hk, rc.rcAD, encHdr) + if (hdr) { + const mk = msgKeys.get(hdr.msgNs) + if (mk) { + const bodyAD = concatBytes(rc.rcAD, encMsg.emHeader) + const plaintext = decryptAEAD(mk.mk, mk.iv, bodyAD, encMsg.emBody, encMsg.emAuthTag) + const newMsgKeys = new Map(msgKeys) + newMsgKeys.delete(hdr.msgNs) + const newSkipped = new Map(skippedKeys) + if (newMsgKeys.size === 0) newSkipped.delete(hkHex) + else newSkipped.set(hkHex, newMsgKeys) + return {plaintext, state: rc, skippedKeys: newSkipped} + } + // Header decrypted but msgNs not in skipped keys - check if same/advance ratchet + // For now, fall through to normal decrypt + } + } + return null +} + +// -- pqRatchetStep (lines 1072-1104) +// Returns (kemSS for receive rootKdf, kemSS' for send rootKdf, new RatchetKEM state) + +function pqRatchetStep( + rc: Ratchet, + msgKEM: RKEMParams | null, +): {kemSS: Uint8Array | null; kemSS2: Uint8Array | null; rcKEM: RatchetKEM | null} { + const pqEnc = rc.rcEnableKEM + const v = rc.rcVersion.current + + if (!msgKEM) { + // Received message does not have KEM in header + if (!rc.rcKEM && pqEnc && v >= pqRatchetE2EEncryptVersion) { + // User enabled KEM but no KEM state yet - generate new keypair + const rcPQRs = sntrup761Keypair() + return {kemSS: null, kemSS2: null, rcKEM: {rcPQRs, rcKEMs: null}} + } + return {kemSS: null, kemSS2: null, rcKEM: null} + } + + // Received message has KEM in header + if (pqEnc && v >= pqRatchetE2EEncryptVersion) { + // Get shared secret from received KEM params + const {ss, rcPQRr} = kemSharedSecret(rc.rcKEM, msgKEM) + // state.PQRct = PQKEM-ENC(state.PQRr, state.PQRss) + const kemEncResult = sntrup761Enc(rcPQRr) + // state.PQRs = GENERATE_PQKEM() + const rcPQRs = sntrup761Keypair() + const kem: RatchetKEM = { + rcPQRs, + rcKEMs: {rcPQRr, rcPQRss: kemEncResult.sharedSecret, rcPQRct: kemEncResult.ciphertext}, + } + return {kemSS: ss, kemSS2: kemEncResult.sharedSecret, rcKEM: kem} + } + + // PQ not enabled but message has KEM - extract shared secret only (no new KEM state) + const {ss} = kemSharedSecret(rc.rcKEM, msgKEM) + return {kemSS: ss, kemSS2: null, rcKEM: null} +} + +// Extract shared secret from received KEM params (lines 1097-1104) +function kemSharedSecret( + rcKEM: RatchetKEM | null, + params: RKEMParams, +): {ss: Uint8Array | null; rcPQRr: Uint8Array} { + if (params.type === "proposed") { + // RKParamsProposed k -> no shared secret yet, just received the public key + return {ss: null, rcPQRr: params.kemPk} + } + // RKParamsAccepted ct k -> decapsulate ct with our private KEM key + if (!rcKEM) throw new Error("pqRatchetStep: CERatchetKEMState - no KEM state for accepted params") + const ss = sntrup761Dec(params.kemCt, rcKEM.rcPQRs.secretKey) + return {ss, rcPQRr: params.kemPk} +} + +// -- decryptHeader helper (lines 1151-1153) + +function tryDecryptHeader(headerKey: Uint8Array, ad: Uint8Array, encHdr: EncMessageHeader): MsgHeader | null { + try { + const plainHeader = decryptAEAD(headerKey, encHdr.ehIV, ad, encHdr.ehBody, encHdr.ehAuthTag) + return decodeMsgHeader(encHdr.ehVersion, plainHeader) + } catch { + return null } - return bytes } diff --git a/tests/SMPWebTests.hs b/tests/SMPWebTests.hs index ac574aee4..d2278c353 100644 --- a/tests/SMPWebTests.hs +++ b/tests/SMPWebTests.hs @@ -31,6 +31,7 @@ import Simplex.Messaging.Crypto (Algorithm (..)) import qualified Simplex.Messaging.Crypto.Ratchet as CR import Simplex.Messaging.Crypto.SNTRUP761.Bindings (KEMPublicKey (..), KEMSecretKey, KEMCiphertext (..), KEMSharedKey (..), sntrup761Keypair, sntrup761Enc, sntrup761Dec) import qualified Crypto.Cipher.Types as AES +import qualified Data.Map.Strict as M import qualified Data.ByteArray as BA import Simplex.Messaging.Crypto.ShortLink (contactShortLinkKdf, invShortLinkKdf) import Simplex.Messaging.Encoding @@ -445,8 +446,8 @@ smpWebTests_ = do <> "const bp = pqX3dhSnd(b1.privateKey, b2.privateKey, a1.publicKey, a2.publicKey);" <> "const ap = pqX3dhRcv(a1.privateKey, a2.privateKey, b1.publicKey, b2.publicKey);" <> "const b3 = generateX448KeyPair();" - <> "let bob = initSndRatchet(3, 3, a2.publicKey, b3.privateKey, bp, false);" - <> "let alice = initRcvRatchet(3, 3, a2.privateKey, ap, false);" + <> "let bob = initSndRatchet({current:3,maxSupported:3}, a2.publicKey, b3.privateKey, bp, null);" + <> "let alice = initRcvRatchet({current:3,maxSupported:3}, a2.privateKey, ap, null, false);" <> "let sk = new Map();" -- Bob sends 3 <> "const e1 = rcEncrypt(bob, new TextEncoder().encode('msg1'), 100); bob = e1.state;" @@ -503,19 +504,167 @@ smpWebTests_ = do <> "const a1Priv = " <> jsUint8 alicePriv1 <> ";" <> "const a2Priv = " <> jsUint8 alicePriv2 <> ";" <> "const ap = pqX3dhRcv(a1Priv, a2Priv, bobPk1Raw, bobPk2Raw);" - <> "const alice = initRcvRatchet(3, 3, a2Priv, ap, false);" - -- Debug: output rcAD length and first bytes to compare - <> "try {" + <> "const alice = initRcvRatchet({current:3,maxSupported:3}, a2Priv, ap, null, false);" <> "const dec = rcDecrypt(alice, new Map(), " <> jsUint8 ciphertext <> ");" <> jsOut ("dec.plaintext") - <> "} catch(e) {" - <> "console.warn('ERROR:', e.message);" - <> "console.warn('rcAD length:', alice.rcAD.length);" - <> "console.warn('rcAD first 10:', Array.from(alice.rcAD.subarray(0,10)));" - <> "process.exit(1);" - <> "}" tsResult `shouldBe` "hello from haskell ratchet" + it "cross-language: TypeScript encrypts, Haskell decrypts" $ do + -- Round 1: Haskell generates alice's keys, outputs encoded E2E params + g <- C.newRandom + let v = CR.currentE2EEncryptVersion + (alicePk1, alicePk2, _pKem, aliceE2E) <- CR.generateRcvE2EParams @'X448 g v CR.PQSupportOff + let aliceE2EBytes = smpEncode aliceE2E + + -- Round 2: TypeScript generates bob's keys, does X3DH, inits snd ratchet, encrypts + tsOutput <- callNode $ impEnc <> impRatchet + -- Parse alice's E2E params + <> "const d = new Decoder(" <> jsUint8 aliceE2EBytes <> ");" + <> "const aliceV = d.anyByte() * 256 + d.anyByte();" + <> "const alicePk1Raw = decodePubKeyX448(decodeBytes(d));" + <> "const alicePk2Raw = decodePubKeyX448(decodeBytes(d));" + -- Bob generates keys + <> "const b1 = generateX448KeyPair(), b2 = generateX448KeyPair();" + <> "const b3 = generateX448KeyPair();" + -- X3DH (bob is sender) + <> "const bp = pqX3dhSnd(b1.privateKey, b2.privateKey, alicePk1Raw, alicePk2Raw);" + -- Init sending ratchet + <> "let bob = initSndRatchet({current:3,maxSupported:3}, alicePk2Raw, b3.privateKey, bp, null);" + -- Encrypt + <> "const enc = rcEncrypt(bob, new TextEncoder().encode('hello from typescript ratchet'), 100);" + -- Output: bob's E2E params (version + 2 DER keys + Nothing KEM) + ciphertext + <> "const bobE2E = new Uint8Array([...encodeWord16(3), ...encodeBytes(encodePubKeyX448(b1.publicKey)), ...encodeBytes(encodePubKeyX448(b2.publicKey)), 0x30]);" + <> "const lenBuf = new Uint8Array(2); lenBuf[0] = (bobE2E.length >> 8) & 0xff; lenBuf[1] = bobE2E.length & 0xff;" + <> "const ctLenBuf = new Uint8Array(2); ctLenBuf[0] = (enc.ciphertext.length >> 8) & 0xff; ctLenBuf[1] = enc.ciphertext.length & 0xff;" + <> jsOut ("new Uint8Array([...lenBuf, ...bobE2E, ...ctLenBuf, ...enc.ciphertext])") + -- Parse output: [2 bytes e2e len][e2e bytes][2 bytes ct len][ct bytes] + let (e2eLenBs, rest1) = B.splitAt 2 tsOutput + bobE2ELen = fromIntegral (B.index e2eLenBs 0) * 256 + fromIntegral (B.index e2eLenBs 1) + (bobE2EBytes, rest2) = B.splitAt bobE2ELen rest1 + (ctLenBs, rest3) = B.splitAt 2 rest2 + ctLen = fromIntegral (B.index ctLenBs 0) * 256 + fromIntegral (B.index ctLenBs 1) + ciphertext = B.take ctLen rest3 + + -- Round 3: Haskell decodes bob's params, does X3DH, inits rcv ratchet, decrypts + Right (CR.AE2ERatchetParams _ bobE2EParams :: CR.AE2ERatchetParams 'X448) <- pure $ smpDecode bobE2EBytes + Right (aliceInitParams, _) <- runExceptT $ CR.pqX3dhRcv alicePk1 alicePk2 Nothing bobE2EParams + let aliceRatchet = CR.initRcvRatchet (CR.RatchetVersions v v) alicePk2 (aliceInitParams, Nothing) CR.PQSupportOff + gAlice <- C.newRandom + Right (msg, _, _) <- runExceptT $ CR.rcDecrypt gAlice aliceRatchet M.empty ciphertext + msg `shouldBe` Right "hello from typescript ratchet" + + it "cross-language: PQ X3DH - Haskell proposes KEM, TypeScript accepts, encrypts" $ do + -- Round 1: Haskell (alice) generates keys with PQ KEM proposal + g <- C.newRandom + let v = CR.currentE2EEncryptVersion + (alicePk1, alicePk2, alicePKem_@(Just _), aliceE2E) <- CR.generateRcvE2EParams @'X448 g v CR.PQSupportOn + let aliceE2EBytes = smpEncode aliceE2E + + -- Round 2: TypeScript (bob) accepts KEM, does X3DH, inits snd ratchet, encrypts + tsOutput <- callNode $ impEnc <> impSodium <> impRatchet <> impSntrup + -- Parse alice's E2E params (v3: version + pk1 + pk2 + Maybe ARKEMParams) + <> "const d = new Decoder(" <> jsUint8 aliceE2EBytes <> ");" + <> "const aliceV = d.anyByte() * 256 + d.anyByte();" + <> "const alicePk1Raw = decodePubKeyX448(decodeBytes(d));" + <> "const alicePk2Raw = decodePubKeyX448(decodeBytes(d));" + -- Parse Maybe ARKEMParams: '1' + 'P' + KEMPublicKey(Large) + <> "const maybeByte = d.anyByte();" + <> "if (maybeByte !== 0x31) throw new Error('expected Just KEM');" + <> "const kemTag = d.anyByte();" + <> "if (kemTag !== 0x50) throw new Error('expected P (proposed), got ' + kemTag);" + <> "const aliceKemPk = decodeLarge(d);" + -- Bob generates DH keys + <> "const b1 = generateX448KeyPair(), b2 = generateX448KeyPair();" + <> "const b3 = generateX448KeyPair();" + -- Bob encapsulates against alice's KEM public key + <> "const kemEnc = sntrup761Enc(aliceKemPk);" + -- Bob generates his own KEM keypair for future ratchet steps + <> "const bobKem = sntrup761Keypair();" + -- Construct kemAccepted matching Haskell RatchetKEMAccepted: + -- rcPQRr = alice's KEM public key (received) + -- rcPQRss = shared secret (from encapsulation) + -- rcPQRct = ciphertext (sent to alice) + <> "const kemAccepted = {rcPQRr: aliceKemPk, rcPQRss: kemEnc.sharedSecret, rcPQRct: kemEnc.ciphertext};" + -- X3DH with kemAccepted (folds shared secret into HKDF AND stores in RatchetInitParams) + <> "const bp = pqX3dhSnd(b1.privateKey, b2.privateKey, alicePk1Raw, alicePk2Raw, kemAccepted);" + -- Init sending ratchet with bob's KEM keypair + <> "let bob = initSndRatchet({current:3,maxSupported:3}, alicePk2Raw, b3.privateKey, bp, bobKem);" + -- Encrypt + <> "const enc = rcEncrypt(bob, new TextEncoder().encode('hello with PQ'), 100);" + -- Build bob's E2E params: version + pk1 + pk2 + Just(Accepted(ct, bobKemPk)) + -- smpEncode ('A', ct, bobKemPk) where ct and pk are Large-encoded + <> "const bobE2E = new Uint8Array([" + <> " ...encodeWord16(3)," + <> " ...encodeBytes(encodePubKeyX448(b1.publicKey))," + <> " ...encodeBytes(encodePubKeyX448(b2.publicKey))," + <> " 0x31," -- Just + <> " 0x41," -- 'A' = Accepted + <> " ...new Uint8Array([(kemEnc.ciphertext.length >> 8) & 0xff, kemEnc.ciphertext.length & 0xff]), ...kemEnc.ciphertext," + <> " ...new Uint8Array([(bobKem.publicKey.length >> 8) & 0xff, bobKem.publicKey.length & 0xff]), ...bobKem.publicKey," + <> "]);" + <> "const lenBuf = new Uint8Array(2); lenBuf[0] = (bobE2E.length >> 8) & 0xff; lenBuf[1] = bobE2E.length & 0xff;" + <> "const ctLenBuf = new Uint8Array(2); ctLenBuf[0] = (enc.ciphertext.length >> 8) & 0xff; ctLenBuf[1] = enc.ciphertext.length & 0xff;" + <> jsOut ("new Uint8Array([...lenBuf, ...bobE2E, ...ctLenBuf, ...enc.ciphertext])") + let (e2eLenBs, rest1) = B.splitAt 2 tsOutput + bobE2ELen = fromIntegral (B.index e2eLenBs 0) * 256 + fromIntegral (B.index e2eLenBs 1) + (bobE2EBytes, rest2) = B.splitAt bobE2ELen rest1 + (ctLenBs, rest3) = B.splitAt 2 rest2 + ctLen = fromIntegral (B.index ctLenBs 0) * 256 + fromIntegral (B.index ctLenBs 1) + ciphertext = B.take ctLen rest3 + + -- Round 3: Haskell decodes bob's params (with KEM accepted), does X3DH with KEM, decrypts + Right (CR.AE2ERatchetParams _ bobE2EParams :: CR.AE2ERatchetParams 'X448) <- pure $ smpDecode bobE2EBytes + Right (aliceInitParams, aliceKemKp_) <- runExceptT $ CR.pqX3dhRcv alicePk1 alicePk2 alicePKem_ bobE2EParams + let aliceRatchet = CR.initRcvRatchet (CR.RatchetVersions v v) alicePk2 (aliceInitParams, aliceKemKp_) CR.PQSupportOn + gAlice <- C.newRandom + result <- runExceptT $ CR.rcDecrypt gAlice aliceRatchet M.empty ciphertext + case result of + Right (msg, _, _) -> msg `shouldBe` Right "hello with PQ" + Left e -> expectationFailure $ "rcDecrypt failed: " <> show e + + it "TypeScript PQ ratchet self-consistency: multi-message with KEM ratchet steps" $ do + tsResult <- callNode $ impSodium <> impSntrup <> impRatchet + <> "const a1 = generateX448KeyPair(), a2 = generateX448KeyPair();" + <> "const b1 = generateX448KeyPair(), b2 = generateX448KeyPair();" + <> "const b3 = generateX448KeyPair();" + -- Alice proposes KEM + <> "const aliceKem = sntrup761Keypair();" + -- Bob accepts: encapsulate against alice's KEM public key + <> "const kemEnc = sntrup761Enc(aliceKem.publicKey);" + <> "const bobKem = sntrup761Keypair();" + <> "const kemAccepted = {rcPQRr: aliceKem.publicKey, rcPQRss: kemEnc.sharedSecret, rcPQRct: kemEnc.ciphertext};" + -- Alice receives bob's acceptance: decapsulate to get shared secret + <> "const aliceSS = sntrup761Dec(kemEnc.ciphertext, aliceKem.secretKey);" + <> "const aliceKemAccepted = {rcPQRr: bobKem.publicKey, rcPQRss: aliceSS, rcPQRct: kemEnc.ciphertext};" + -- X3DH for both sides + <> "const bp = pqX3dhSnd(b1.privateKey, b2.privateKey, a1.publicKey, a2.publicKey, kemAccepted);" + <> "const ap = pqX3dhRcv(a1.privateKey, a2.privateKey, b1.publicKey, b2.publicKey, aliceKemAccepted);" + -- Init ratchets with KEM keypairs + <> "let bob = initSndRatchet({current:3,maxSupported:3}, a2.publicKey, b3.privateKey, bp, bobKem);" + <> "let alice = initRcvRatchet({current:3,maxSupported:3}, a2.privateKey, ap, aliceKem, true);" + <> "let sk = new Map();" + -- Bob sends msg1 (has KEM params in header from initSndRatchet) + <> "const e1 = rcEncrypt(bob, new TextEncoder().encode('pq msg1'), 100); bob = e1.state;" + -- Alice decrypts msg1 (triggers ratchet advance with KEM) + <> "let d1 = rcDecrypt(alice, sk, e1.ciphertext); alice = d1.state; sk = d1.skippedKeys;" + -- Alice sends msg2 (ratchet advanced, has KEM params from pqRatchetStep) + <> "const e2 = rcEncrypt(alice, new TextEncoder().encode('pq msg2'), 100); alice = e2.state;" + -- Bob decrypts msg2 (triggers ratchet advance with KEM on bob's side) + <> "let d2 = rcDecrypt(bob, new Map(), e2.ciphertext); bob = d2.state;" + -- Bob sends msg3 (another ratchet advance with KEM) + <> "const e3 = rcEncrypt(bob, new TextEncoder().encode('pq msg3'), 100); bob = e3.state;" + -- Alice decrypts msg3 + <> "let d3 = rcDecrypt(alice, sk, e3.ciphertext); alice = d3.state; sk = d3.skippedKeys;" + -- Verify all messages + <> "const ok = new TextDecoder().decode(d1.plaintext) === 'pq msg1'" + <> " && new TextDecoder().decode(d2.plaintext) === 'pq msg2'" + <> " && new TextDecoder().decode(d3.plaintext) === 'pq msg3'" + -- Verify KEM state is maintained + <> " && alice.rcKEM !== null && bob.rcKEM !== null" + <> " && alice.rcSndKEM === true && bob.rcSndKEM === true;" + <> jsOut ("new Uint8Array([ok ? 1 : 0])") + tsResult `shouldBe` B.singleton 1 + describe "DER encoding" $ do it "X448 DER round-trips" $ do tsResult <- callNode $ impRatchet From 35b264b25be11952f0e3f85602202a3f1fcfd379 Mon Sep 17 00:00:00 2001 From: "Evgeny @ SimpleX Chat" <259188159+evgeny-simplex@users.noreply.github.com> Date: Sat, 16 May 2026 11:57:39 +0000 Subject: [PATCH 7/8] test typescript ratchets --- smp-web/package.json | 3 +- smp-web/tests/ratchet-repl.ts | 326 +++++++++++++++++++++++++ smp-web/tsconfig.test.json | 15 ++ tests/AgentTests/DoubleRatchetTests.hs | 42 ++-- tests/SMPWebTests.hs | 244 +++++++++++++++++- 5 files changed, 610 insertions(+), 20 deletions(-) create mode 100644 smp-web/tests/ratchet-repl.ts create mode 100644 smp-web/tsconfig.test.json diff --git a/smp-web/package.json b/smp-web/package.json index 33b5d52d4..ab9ff7c1e 100644 --- a/smp-web/package.json +++ b/smp-web/package.json @@ -16,7 +16,8 @@ "scripts": { "build:wasm": "mkdir -p dist/wasm && npx emcc cbits/sntrup761_wasm.c ../cbits/sntrup761.c cbits/sha512.c -I../cbits -O2 -o dist/wasm/sntrup761.mjs -s EXPORTED_FUNCTIONS='[\"_sntrup761_wasm_keypair\",\"_sntrup761_wasm_enc\",\"_sntrup761_wasm_dec\",\"_malloc\",\"_free\"]' -s EXPORTED_RUNTIME_METHODS='[\"ccall\",\"cwrap\",\"HEAPU8\"]' -s MODULARIZE=1 -s EXPORT_NAME='createSntrup761' -s ALLOW_MEMORY_GROWTH=1 -s ENVIRONMENT='web,node' --js-library cbits/js_random.js && cp cbits/sntrup761.d.mts dist/wasm/", "build:ts": "tsc", - "build": "npm run build:wasm && npm run build:ts" + "build:test": "tsc -p tsconfig.test.json", + "build": "npm run build:wasm && npm run build:ts && npm run build:test" }, "dependencies": { "@noble/ciphers": "^2.2.0", diff --git a/smp-web/tests/ratchet-repl.ts b/smp-web/tests/ratchet-repl.ts new file mode 100644 index 000000000..4b8bc4033 --- /dev/null +++ b/smp-web/tests/ratchet-repl.ts @@ -0,0 +1,326 @@ +// Double ratchet REPL for cross-language testing. +// Holds one ratchet state, reads commands from stdin, writes results to stdout. +// +// Init protocol: +// INIT_RCV +// → ok: +// COMPLETE +// → ok +// INIT_SND +// → ok: +// kemMode: none | propose | accept +// +// Encrypt/decrypt operators (same syntax as Haskell DoubleRatchetTests): +// \#> encrypt, assert noSndKEM +// !#> <plaintext> encrypt, assert hasSndKEM +// \#>! <plaintext> encrypt PQEncOn, assert noSndKEM +// !#>! <plaintext> encrypt PQEncOn, assert hasSndKEM +// !#>\ <plaintext> encrypt PQEncOff, assert hasSndKEM +// \#>\ <plaintext> encrypt PQEncOff, assert noSndKEM +// <#\ <hex ct> <expected> decrypt, assert noRcvKEM +// <#! <hex ct> <expected> decrypt, assert hasRcvKEM +// +// Plain encrypt/decrypt (no assertions): +// E <plaintext> → ok: <hex ciphertext> +// D <hex ciphertext> → ok: <plaintext> +// +// Response format: ok: <data> or error: <message> + +import {createInterface} from "readline" +import { + generateX448KeyPair, pqX3dhSnd, pqX3dhRcv, + encodePubKeyX448, decodePubKeyX448, + initSndRatchet, initRcvRatchet, + rcEncrypt, rcDecrypt, + rootKdf, + type Ratchet, type SkippedMsgKeys, type RatchetVersions, + type RatchetInitParams, type RatchetKEMAccepted, +} from "../dist/crypto/ratchet.js" +import {initSntrup761, sntrup761Keypair, sntrup761Enc, sntrup761Dec} from "../dist/crypto/sntrup761.js" +import type {KEMKeyPair} from "../dist/crypto/sntrup761.js" +import { + Decoder, decodeBytes, decodeLarge, encodeBytes, encodeWord16, concatBytes, +} from "@simplex-chat/xftp-web/dist/protocol/encoding.js" + +// -- State + +let ratchet: Ratchet | null = null +let skippedKeys: SkippedMsgKeys = new Map() +const PADDED_MSG_LEN = 16000 + +// Intermediate state for RCV init (between INIT_RCV and COMPLETE) +let rcvInitState: { + privKey1: Uint8Array + privKey2: Uint8Array + kemKeyPair: KEMKeyPair | null + pqSupport: boolean +} | null = null + +// -- Hex helpers + +function toHex(bytes: Uint8Array): string { + return Array.from(bytes, b => b.toString(16).padStart(2, "0")).join("") +} + +function fromHex(hex: string): Uint8Array { + const bytes = new Uint8Array(hex.length / 2) + for (let i = 0; i < hex.length; i += 2) + bytes[i / 2] = parseInt(hex.substring(i, i + 2), 16) + return bytes +} + +// -- E2E params helpers + +// Parse E2ERatchetParams: version(Word16) + pk1(ByteString) + pk2(ByteString) + Maybe KEMParams +interface ParsedE2EParams { + version: number + pk1Raw: Uint8Array // raw X448 public key + pk2Raw: Uint8Array // raw X448 public key + kemPk: Uint8Array | null // KEM public key if proposed + kemCt: Uint8Array | null // KEM ciphertext if accepted + kemAcceptPk: Uint8Array | null // KEM public key in accepted +} + +function parseE2EParams(data: Uint8Array): ParsedE2EParams { + const d = new Decoder(data) + const version = d.anyByte() * 256 + d.anyByte() + const pk1Raw = decodePubKeyX448(decodeBytes(d)) + const pk2Raw = decodePubKeyX448(decodeBytes(d)) + let kemPk: Uint8Array | null = null + let kemCt: Uint8Array | null = null + let kemAcceptPk: Uint8Array | null = null + if (version >= 3 && d.remaining() > 0) { + const maybeByte = d.anyByte() + if (maybeByte === 0x31) { // Just + const tag = d.anyByte() + if (tag === 0x50) { // 'P' Proposed + kemPk = decodeLarge(d) + } else if (tag === 0x41) { // 'A' Accepted + kemCt = decodeLarge(d) + kemAcceptPk = decodeLarge(d) + } + } + } + return {version, pk1Raw, pk2Raw, kemPk, kemCt, kemAcceptPk} +} + +// Encode E2ERatchetParams for sending to peer +function encodeE2EParams( + version: number, + pk1Raw: Uint8Array, pk2Raw: Uint8Array, + kemPk: Uint8Array | null, // for proposed + kemCt: Uint8Array | null, // for accepted + kemAcceptPk: Uint8Array | null, // public key in accepted +): Uint8Array { + const vBytes = new Uint8Array(2) + vBytes[0] = (version >> 8) & 0xff + vBytes[1] = version & 0xff + const parts = [vBytes, encodeBytes(encodePubKeyX448(pk1Raw)), encodeBytes(encodePubKeyX448(pk2Raw))] + if (version >= 3) { + if (kemCt && kemAcceptPk) { + // Just Accepted + parts.push(new Uint8Array([0x31, 0x41])) // Just + 'A' + parts.push(new Uint8Array([(kemCt.length >> 8) & 0xff, kemCt.length & 0xff])) + parts.push(kemCt) + parts.push(new Uint8Array([(kemAcceptPk.length >> 8) & 0xff, kemAcceptPk.length & 0xff])) + parts.push(kemAcceptPk) + } else if (kemPk) { + // Just Proposed + parts.push(new Uint8Array([0x31, 0x50])) // Just + 'P' + parts.push(new Uint8Array([(kemPk.length >> 8) & 0xff, kemPk.length & 0xff])) + parts.push(kemPk) + } else { + // Nothing + parts.push(new Uint8Array([0x30])) + } + } + return concatBytes(...parts) +} + +// -- Init handlers + +function handleInitRcv(version: number, pqSupport: boolean): string { + const kp1 = generateX448KeyPair() + const kp2 = generateX448KeyPair() + let kemKeyPair: KEMKeyPair | null = null + let kemPk: Uint8Array | null = null + if (pqSupport) { + kemKeyPair = sntrup761Keypair() + kemPk = kemKeyPair.publicKey + } + rcvInitState = {privKey1: kp1.privateKey, privKey2: kp2.privateKey, kemKeyPair, pqSupport} + const params = encodeE2EParams(version, kp1.publicKey, kp2.publicKey, kemPk, null, null) + return "ok: " + toHex(params) +} + +function handleComplete(peerParamsHex: string): string { + if (!rcvInitState) return "error: not in RCV init state" + const {privKey1, privKey2, kemKeyPair, pqSupport} = rcvInitState + const peerParams = parseE2EParams(fromHex(peerParamsHex)) + + // Build kemAccepted for X3DH if peer accepted our KEM proposal + let kemAccepted: RatchetKEMAccepted | null = null + if (peerParams.kemCt && peerParams.kemAcceptPk && kemKeyPair) { + const ss = sntrup761Dec(peerParams.kemCt, kemKeyPair.secretKey) + kemAccepted = {rcPQRr: peerParams.kemAcceptPk, rcPQRss: ss, rcPQRct: peerParams.kemCt} + } + + // X3DH (receiver side) + const initParams = pqX3dhRcv(privKey1, privKey2, peerParams.pk1Raw, peerParams.pk2Raw, kemAccepted) + + // Init receiving ratchet + const vs: RatchetVersions = {current: peerParams.version, maxSupported: peerParams.version} + ratchet = initRcvRatchet(vs, privKey2, initParams, kemKeyPair, pqSupport) + skippedKeys = new Map() + rcvInitState = null + return "ok" +} + +function handleInitSnd(version: number, kemMode: string, peerParamsHex: string): string { + const peerParams = parseE2EParams(fromHex(peerParamsHex)) + + const kp1 = generateX448KeyPair() + const kp2 = generateX448KeyPair() + const kp3 = generateX448KeyPair() // fresh DH key for ratchet + + // KEM handling + let kemAccepted: RatchetKEMAccepted | null = null + let ownKemKp: KEMKeyPair | null = null + let outKemPk: Uint8Array | null = null + let outKemCt: Uint8Array | null = null + let outKemAcceptPk: Uint8Array | null = null + + if (kemMode === "accept" && peerParams.kemPk) { + // Accept peer's KEM proposal + const encResult = sntrup761Enc(peerParams.kemPk) + ownKemKp = sntrup761Keypair() + kemAccepted = {rcPQRr: peerParams.kemPk, rcPQRss: encResult.sharedSecret, rcPQRct: encResult.ciphertext} + outKemCt = encResult.ciphertext + outKemAcceptPk = ownKemKp.publicKey + } else if (kemMode === "propose") { + ownKemKp = sntrup761Keypair() + outKemPk = ownKemKp.publicKey + } + + // X3DH (sender side) + const initParams = pqX3dhSnd(kp1.privateKey, kp2.privateKey, peerParams.pk1Raw, peerParams.pk2Raw, kemAccepted) + + // Init sending ratchet + const vs: RatchetVersions = {current: version, maxSupported: version} + ratchet = initSndRatchet(vs, peerParams.pk2Raw, kp3.privateKey, initParams, ownKemKp) + skippedKeys = new Map() + + const params = encodeE2EParams(version, kp1.publicKey, kp2.publicKey, outKemPk, outKemCt, outKemAcceptPk) + return "ok: " + toHex(params) +} + +// -- Encrypt/decrypt handlers + +function handleEncrypt(kemAssert: boolean | null, _pqPref: boolean | null, plaintext: string): string { + if (!ratchet) return "error: not initialized" + try { + const result = rcEncrypt(ratchet, new TextEncoder().encode(plaintext), PADDED_MSG_LEN) + ratchet = result.state + if (kemAssert === true && !ratchet.rcSndKEM) return "error: expected hasSndKEM" + if (kemAssert === false && ratchet.rcSndKEM) return "error: expected noSndKEM" + return "ok: " + toHex(result.ciphertext) + } catch (e: any) { + return "error: " + e.message + } +} + +function handleDecrypt(kemAssert: boolean | null, hexCt: string, expectedPlaintext: string | null): string { + if (!ratchet) return "error: not initialized" + try { + const ct = fromHex(hexCt) + const result = rcDecrypt(ratchet, skippedKeys, ct) + ratchet = result.state + skippedKeys = result.skippedKeys + const plaintext = new TextDecoder().decode(result.plaintext) + if (kemAssert === true && !ratchet.rcRcvKEM) return "error: expected hasRcvKEM" + if (kemAssert === false && ratchet.rcRcvKEM) return "error: expected noRcvKEM" + if (expectedPlaintext !== null && plaintext !== expectedPlaintext) + return "error: expected '" + expectedPlaintext + "', got '" + plaintext + "'" + return "ok: " + plaintext + } catch (e: any) { + return "error: " + e.message + } +} + +// -- Command parser + +function parseLine(line: string): string { + // Init commands + if (line.startsWith("INIT_RCV ")) { + const parts = line.split(" ") + return handleInitRcv(parseInt(parts[1]), parts[2] === "1") + } + if (line.startsWith("COMPLETE ")) { + return handleComplete(line.substring(9).trim()) + } + if (line.startsWith("INIT_SND ")) { + const parts = line.split(" ") + return handleInitSnd(parseInt(parts[1]), parts[2], parts[3]) + } + + // Query commands + if (line === "SNDKEM") { + if (!ratchet) return "error: not initialized" + return "ok: " + (ratchet.rcSndKEM ? "1" : "0") + } + if (line === "RCVKEM") { + if (!ratchet) return "error: not initialized" + return "ok: " + (ratchet.rcRcvKEM ? "1" : "0") + } + + // Encrypt operators: \#> !#> \#>! !#>! !#>\ \#>\ + const encMatch = line.match(/^([!\\])#(>[!\\]?)\s+(.+)$/) + if (encMatch) { + const [, kemChar, arrow, msg] = encMatch + const kemAssert = kemChar === "!" ? true : false + let pqPref: boolean | null = null + if (arrow === ">!") pqPref = true + else if (arrow === ">\\") pqPref = false + return handleEncrypt(kemAssert, pqPref, msg) + } + + // Decrypt operators: <#\ <#! + const decMatch = line.match(/^<#([!\\])\s+(\S+)\s+(.+)$/) + if (decMatch) { + const [, kemChar, hexCt, expected] = decMatch + const kemAssert = kemChar === "!" ? true : false + return handleDecrypt(kemAssert, hexCt, expected) + } + + // Plain encrypt (no assertion) + if (line.startsWith("E ")) { + return handleEncrypt(null, null, line.substring(2)) + } + + // Plain decrypt (no assertion, no expected) + if (line.startsWith("D ")) { + return handleDecrypt(null, line.substring(2), null) + } + + return "error: unknown command: " + line +} + +// -- Main + +async function main() { + await initSntrup761() + + const rl = createInterface({input: process.stdin, terminal: false}) + + for await (const line of rl) { + const trimmed = line.trim() + if (!trimmed) continue + const response = parseLine(trimmed) + process.stdout.write(response + "\n") + } +} + +main().catch(e => { + process.stderr.write("FATAL: " + e.message + "\n") + process.exit(1) +}) diff --git a/smp-web/tsconfig.test.json b/smp-web/tsconfig.test.json new file mode 100644 index 000000000..960f58829 --- /dev/null +++ b/smp-web/tsconfig.test.json @@ -0,0 +1,15 @@ +{ + "compilerOptions": { + "target": "ES2022", + "module": "ES2022", + "moduleResolution": "node", + "lib": ["ES2022"], + "outDir": "dist-test", + "rootDir": "tests", + "strict": true, + "esModuleInterop": true, + "skipLibCheck": true, + "sourceMap": true + }, + "include": ["tests/**/*.ts"] +} diff --git a/tests/AgentTests/DoubleRatchetTests.hs b/tests/AgentTests/DoubleRatchetTests.hs index eef5be27f..fd160dbd3 100644 --- a/tests/AgentTests/DoubleRatchetTests.hs +++ b/tests/AgentTests/DoubleRatchetTests.hs @@ -73,15 +73,19 @@ runMessageTests :: Bool -> Spec runMessageTests initRatchets_ agreeRatchetKEMs = do - it "should encrypt and decrypt messages" $ run $ testEncryptDecrypt agreeRatchetKEMs - it "should encrypt and decrypt skipped messages" $ run $ testSkippedMessages agreeRatchetKEMs - it "should encrypt and decrypt many messages" $ run $ testManyMessages agreeRatchetKEMs - it "should allow skipped after ratchet advance" $ run $ testSkippedAfterRatchetAdvance agreeRatchetKEMs + it "should encrypt and decrypt messages" $ run testEncryptDecrypt + it "should encrypt and decrypt skipped messages" $ run testSkippedMessages + it "should encrypt and decrypt many messages" $ run testManyMessages + it "should allow skipped after ratchet advance" $ run testSkippedAfterRatchetAdvance where run :: (forall a. (AlgorithmI a, DhAlgorithm a) => TestRatchets a) -> IO () run test = do - withRatchets_ @X25519 initRatchets_ test - withRatchets_ @X448 initRatchets_ test + withRatchets_ @X25519 initRatchets_ (withKEM test) + withRatchets_ @X448 initRatchets_ (withKEM test) + withKEM :: (AlgorithmI a, DhAlgorithm a) => TestRatchets a -> TestRatchets a + withKEM test alice bob encrypt decrypt (#>) = do + when agreeRatchetKEMs $ initRatchetKEM bob alice >> initRatchetKEM alice bob + test alice bob encrypt decrypt (#>) testAlgs :: (forall a. (AlgorithmI a, DhAlgorithm a) => C.SAlgorithm a -> IO ()) -> IO () testAlgs test = test C.SX25519 >> test C.SX448 @@ -146,6 +150,12 @@ type TestRatchets a = EncryptDecryptSpec a -> IO () +-- Peer-polymorphic types for cross-language testing +type EncryptP p = p -> ByteString -> IO (Either CryptoError ByteString) +type DecryptP p = p -> ByteString -> IO (Either CryptoError (Either CryptoError ByteString)) +type EncryptDecryptSpecP p = (p, ByteString) -> p -> Expectation +type TestRatchetsP p = p -> p -> EncryptP p -> DecryptP p -> EncryptDecryptSpecP p -> IO () + deriving instance Eq (Ratchet a) deriving instance Eq (SndRatchet a) @@ -170,9 +180,8 @@ deriving instance Eq (MsgHeader a) initRatchetKEM :: (AlgorithmI a, DhAlgorithm a) => TVar (TVar ChaChaDRG, Ratchet a, SkippedMsgKeys) -> TVar (TVar ChaChaDRG, Ratchet a, SkippedMsgKeys) -> IO () initRatchetKEM s r = encryptDecrypt (Just $ PQEncOn) (const ()) (const ()) (s, "initialising ratchet") r -testEncryptDecrypt :: (AlgorithmI a, DhAlgorithm a) => Bool -> TestRatchets a -testEncryptDecrypt agreeRatchetKEMs alice bob encrypt decrypt (#>) = do - when agreeRatchetKEMs $ initRatchetKEM bob alice >> initRatchetKEM alice bob +testEncryptDecrypt :: TestRatchetsP p +testEncryptDecrypt alice bob encrypt decrypt (#>) = do (bob, "hello alice") #> alice (alice, "hello bob") #> bob Right b1 <- encrypt bob "how are you, alice?" @@ -191,9 +200,8 @@ testEncryptDecrypt agreeRatchetKEMs alice bob encrypt decrypt (#>) = do (alice, "I'm here too, same") #> bob pure () -testSkippedMessages :: (AlgorithmI a, DhAlgorithm a) => Bool -> TestRatchets a -testSkippedMessages agreeRatchetKEMs alice bob encrypt decrypt _ = do - when agreeRatchetKEMs $ initRatchetKEM bob alice >> initRatchetKEM alice bob +testSkippedMessages :: TestRatchetsP p +testSkippedMessages alice bob encrypt decrypt _ = do Right msg1 <- encrypt bob "hello alice" Right msg2 <- encrypt bob "hello there again" Right msg3 <- encrypt bob "are you there?" @@ -203,9 +211,8 @@ testSkippedMessages agreeRatchetKEMs alice bob encrypt decrypt _ = do Decrypted "hello alice" <- decrypt alice msg1 pure () -testManyMessages :: (AlgorithmI a, DhAlgorithm a) => Bool -> TestRatchets a -testManyMessages agreeRatchetKEMs alice bob _ _ (#>) = do - when agreeRatchetKEMs $ initRatchetKEM bob alice >> initRatchetKEM alice bob +testManyMessages :: TestRatchetsP p +testManyMessages alice bob _ _ (#>) = do (bob, "b1") #> alice (bob, "b2") #> alice (bob, "b3") #> alice @@ -222,9 +229,8 @@ testManyMessages agreeRatchetKEMs alice bob _ _ (#>) = do (bob, "b15") #> alice (bob, "b16") #> alice -testSkippedAfterRatchetAdvance :: (AlgorithmI a, DhAlgorithm a) => Bool -> TestRatchets a -testSkippedAfterRatchetAdvance agreeRatchetKEMs alice bob encrypt decrypt (#>) = do - when agreeRatchetKEMs $ initRatchetKEM bob alice >> initRatchetKEM alice bob +testSkippedAfterRatchetAdvance :: TestRatchetsP p +testSkippedAfterRatchetAdvance alice bob encrypt decrypt (#>) = do (bob, "b1") #> alice Right b2 <- encrypt bob "b2" Right b3 <- encrypt bob "b3" diff --git a/tests/SMPWebTests.hs b/tests/SMPWebTests.hs index d2278c353..274cd1036 100644 --- a/tests/SMPWebTests.hs +++ b/tests/SMPWebTests.hs @@ -13,10 +13,17 @@ -- Run: cabal test --test-option=--match="/SMP Web Client/" module SMPWebTests (smpWebTests) where -import Control.Concurrent.STM (atomically) +import Control.Concurrent.STM +import Control.Monad (when) +import Control.Exception (bracket) import Control.Monad.Except (ExceptT, runExceptT) +import Crypto.Random (ChaChaDRG) +import Data.IORef +import System.IO (Handle, hFlush, hGetLine, hPutStr, hSetBuffering, BufferMode (..)) +import System.Process (CreateProcess (..), StdStream (..), ProcessHandle, createProcess, proc, terminateProcess) import qualified Data.ByteString as B import qualified Data.ByteString.Char8 as BC +import Data.List (isInfixOf) import Data.List.NonEmpty (NonEmpty (..)) import System.Directory (doesDirectoryExist) import Data.Word (Word16) @@ -44,6 +51,7 @@ import Simplex.Messaging.Transport (TLS, smpBlockSize, currentServerSMPRelayVers import Simplex.Messaging.Transport.Client (TransportHost (..)) import SMPAgentClient (agentCfg, initAgentServers, testDB) import SMPClient (cfgWebOn, testKeyHash, testPort, withSmpServerConfig) +import AgentTests.DoubleRatchetTests (testEncryptDecrypt, testSkippedMessages, testManyMessages, testSkippedAfterRatchetAdvance) import AgentTests.FunctionalAPITests (withAgent) import Test.Hspec hiding (it) import Util @@ -101,9 +109,219 @@ jsStr bs = "'" <> BC.unpack bs <> "'" paddedMsgLen :: Int paddedMsgLen = 100 +-- -- TestPeer: sum type for cross-language ratchet tests + +type HsPeer a = TVar (TVar ChaChaDRG, CR.Ratchet a, CR.SkippedMsgKeys) + +data TestPeer + = forall a. (C.AlgorithmI a, C.DhAlgorithm a) => TestPeerHS (HsPeer a) + | TestPeerJS Handle Handle ProcessHandle -- stdin, stdout, process + +-- dispatch functions + +tpEncrypt :: TestPeer -> B.ByteString -> IO (Either C.CryptoError B.ByteString) +tpEncrypt (TestPeerHS tvar) msg = do + (_, rc, smks) <- readTVarIO tvar + result <- runExceptT $ do + (mek, rc') <- CR.rcEncryptHeader rc Nothing CR.currentE2EEncryptVersion + ct <- CR.rcEncryptMsg mek paddedMsgLen msg + pure (ct, rc') + case result of + Right (ct, rc') -> do + (g, _, smks') <- readTVarIO tvar + atomically $ writeTVar tvar (g, rc', smks') + pure $ Right ct + Left e -> pure $ Left e +tpEncrypt (TestPeerJS hIn hOut _) msg = do + hPutStrLn' hIn $ "E " <> BC.unpack msg + resp <- hGetLine hOut + case parseResponse resp of + Right hex -> pure $ Right $ hexToBS hex + Left err -> error $ "tpEncrypt JS error: " <> err + +tpDecrypt :: TestPeer -> B.ByteString -> IO (Either C.CryptoError (Either C.CryptoError B.ByteString)) +tpDecrypt (TestPeerHS tvar) ct = do + (g, rc, smks) <- readTVarIO tvar + result <- runExceptT $ CR.rcDecrypt g rc smks ct + case result of + Right (msg, rc', smDiff) -> do + atomically $ writeTVar tvar (g, rc', CR.applySMDiff smks smDiff) + pure $ Right msg + Left e -> pure $ Left e +tpDecrypt (TestPeerJS hIn hOut _) ct = do + hPutStrLn' hIn $ "D " <> bsToHex ct + resp <- hGetLine hOut + case parseResponse resp of + Right txt -> pure $ Right $ Right $ BC.pack txt + Left err -> parseJsError err + +-- Map JS REPL error strings to CryptoError at the correct Either level. +-- Outer Left: errors that abort rcDecrypt (header failure, first skipMessageKeys). +-- Inner Right (Left _): errors from second skipMessageKeys (duplicate/earlier in current ratchet state). +parseJsError :: String -> IO (Either C.CryptoError (Either C.CryptoError B.ByteString)) +parseJsError err + -- Outer errors (ExceptT failures in Haskell rcDecrypt) + | has "CERatchetHeader" = pure $ Left C.CERatchetHeader + | has "CERatchetKEMState" = pure $ Left C.CERatchetKEMState + -- Inner errors (pure Left in second skipMessageKeys) + | has "CERatchetDuplicateMessage" = pure $ Right $ Left C.CERatchetDuplicateMessage + | has "CERatchetEarlierMessage" = pure $ Right $ Left $ C.CERatchetEarlierMessage 0 + | has "CERatchetTooManySkipped" = pure $ Right $ Left $ C.CERatchetTooManySkipped 0 + | has "CERatchetState" = pure $ Right $ Left C.CERatchetState + | otherwise = pure $ Left $ C.CryptoHeaderError err + where + has s = s `isInfixOf` err + +tpSndKEM :: TestPeer -> IO Bool +tpSndKEM (TestPeerHS tvar) = do + (_, rc, _) <- readTVarIO tvar + pure $ CR.enablePQ $ CR.rcSndKEM rc +tpSndKEM (TestPeerJS hIn hOut _) = do + hPutStrLn' hIn "SNDKEM" + resp <- hGetLine hOut + pure $ resp == "ok: 1" + +tpRcvKEM :: TestPeer -> IO Bool +tpRcvKEM (TestPeerHS tvar) = do + (_, rc, _) <- readTVarIO tvar + pure $ CR.enablePQ $ CR.rcRcvKEM rc +tpRcvKEM (TestPeerJS hIn hOut _) = do + hPutStrLn' hIn "RCVKEM" + resp <- hGetLine hOut + pure $ resp == "ok: 1" + +tpEncryptDecrypt :: Maybe CR.PQEncryption -> Bool -> Bool -> (TestPeer, B.ByteString) -> TestPeer -> Expectation +tpEncryptDecrypt _pqEnc expectSndKEM expectRcvKEM (sender, msg) receiver = do + Right ct <- tpEncrypt sender msg + sndK <- tpSndKEM sender + when (sndK /= expectSndKEM) $ expectationFailure $ "sndKEM: expected " <> show expectSndKEM <> ", got " <> show sndK + Right (Right msg') <- tpDecrypt receiver ct + rcvK <- tpRcvKEM receiver + when (rcvK /= expectRcvKEM) $ expectationFailure $ "rcvKEM: expected " <> show expectRcvKEM <> ", got " <> show rcvK + msg' `shouldBe` msg + +-- TestPeer operators (matching Haskell DoubleRatchetTests) +tp_noKEM, tp_hasKEM :: (TestPeer, B.ByteString) -> TestPeer -> Expectation +tp_noKEM = tpEncryptDecrypt Nothing False False +tp_hasKEM = tpEncryptDecrypt Nothing True True + +-- JS process helpers + +spawnJsRatchet :: IO (Handle, Handle, ProcessHandle) +spawnJsRatchet = do + let cp = (proc "node" ["dist-test/ratchet-repl.js"]) {cwd = Just "smp-web", std_in = CreatePipe, std_out = CreatePipe, std_err = CreatePipe} + (Just hIn, Just hOut, _, ph) <- createProcess cp + hSetBuffering hIn LineBuffering + hSetBuffering hOut LineBuffering + pure (hIn, hOut, ph) + +destroyJsRatchet :: TestPeer -> IO () +destroyJsRatchet (TestPeerJS _ _ ph) = terminateProcess ph +destroyJsRatchet _ = pure () + +jsCmd :: Handle -> Handle -> String -> IO String +jsCmd hIn hOut cmd = do + hPutStrLn' hIn cmd + hGetLine hOut + +hPutStrLn' :: Handle -> String -> IO () +hPutStrLn' h s = do + hPutStr h (s <> "\n") + hFlush h + +parseResponse :: String -> Either String String +parseResponse resp + | take 4 resp == "ok: " = Right $ drop 4 resp + | take 7 resp == "error: " = Left $ drop 7 resp + | otherwise = Left $ "unexpected response: " <> resp + +bsToHex :: B.ByteString -> String +bsToHex = concatMap (\w -> let h = showHex' w in h) . B.unpack + where + showHex' w = [hexDigit (w `div` 16), hexDigit (w `mod` 16)] + hexDigit n | n < 10 = toEnum (fromEnum '0' + fromIntegral n) + | otherwise = toEnum (fromEnum 'a' + fromIntegral n - 10) + +hexToBS :: String -> B.ByteString +hexToBS = B.pack . go + where + go [] = [] + go (a:b:rest) = fromIntegral (hexVal a * 16 + hexVal b) : go rest + go _ = [] + hexVal c + | c >= '0' && c <= '9' = fromEnum c - fromEnum '0' + | c >= 'a' && c <= 'f' = fromEnum c - fromEnum 'a' + 10 + | c >= 'A' && c <= 'F' = fromEnum c - fromEnum 'A' + 10 + | otherwise = 0 + runRight :: (Show e, HasCallStack) => ExceptT e IO a -> IO a runRight action = runExceptT action >>= either (error . ("Unexpected error: " <>) . show) pure +-- -- Cross-language ratchet init functions and test patterns + +withCrossPeers :: IO (TestPeer, TestPeer) -> ((TestPeer, TestPeer) -> IO ()) -> IO () +withCrossPeers initPeers test = bracket initPeers cleanup test + where + cleanup (a, b) = destroyJsRatchet a >> destroyJsRatchet b + +-- HS (receiver) <-> JS (sender), no PQ +initHsJs_noPQ :: IO (TestPeer, TestPeer) +initHsJs_noPQ = do + g <- C.newRandom + let v = CR.currentE2EEncryptVersion + Version vNum = v + (pkAlice1, pkAlice2, Nothing, e2eAlice) <- CR.generateRcvE2EParams @'X448 g v CR.PQSupportOff + let aliceE2EHex = bsToHex $ smpEncode e2eAlice + (hIn, hOut, ph) <- spawnJsRatchet + bobE2EHex <- either error pure . parseResponse =<< jsCmd hIn hOut ("INIT_SND " ++ show vNum ++ " none " ++ aliceE2EHex) + Right (CR.AE2ERatchetParams _ bobE2E :: CR.AE2ERatchetParams 'X448) <- pure $ smpDecode $ hexToBS bobE2EHex + Right (aliceInitParams, _) <- runExceptT $ CR.pqX3dhRcv pkAlice1 pkAlice2 Nothing bobE2E + let aliceRatchet = CR.initRcvRatchet (CR.RatchetVersions v v) pkAlice2 (aliceInitParams, Nothing) CR.PQSupportOff + ga <- C.newRandom + aliceTVar <- newTVarIO (ga, aliceRatchet, M.empty :: CR.SkippedMsgKeys) + pure (TestPeerHS aliceTVar, TestPeerJS hIn hOut ph) + +-- HS (receiver) <-> JS (sender), PQ KEM accepted +initHsJs_PQ :: IO (TestPeer, TestPeer) +initHsJs_PQ = do + g <- C.newRandom + let v = CR.currentE2EEncryptVersion + Version vNum = v + (pkAlice1, pkAlice2, alicePKem_@(Just _), e2eAlice) <- CR.generateRcvE2EParams @'X448 g v CR.PQSupportOn + let aliceE2EHex = bsToHex $ smpEncode e2eAlice + (hIn, hOut, ph) <- spawnJsRatchet + bobE2EHex <- either error pure . parseResponse =<< jsCmd hIn hOut ("INIT_SND " ++ show vNum ++ " accept " ++ aliceE2EHex) + Right (CR.AE2ERatchetParams _ bobE2E :: CR.AE2ERatchetParams 'X448) <- pure $ smpDecode $ hexToBS bobE2EHex + Right (aliceInitParams, aliceKemKp_) <- runExceptT $ CR.pqX3dhRcv pkAlice1 pkAlice2 alicePKem_ bobE2E + let aliceRatchet = CR.initRcvRatchet (CR.RatchetVersions v v) pkAlice2 (aliceInitParams, aliceKemKp_) CR.PQSupportOn + ga <- C.newRandom + aliceTVar <- newTVarIO (ga, aliceRatchet, M.empty :: CR.SkippedMsgKeys) + pure (TestPeerHS aliceTVar, TestPeerJS hIn hOut ph) + +-- JS (receiver) <-> JS (sender), no PQ +initTsTs_noPQ :: IO (TestPeer, TestPeer) +initTsTs_noPQ = do + let Version vNum = CR.currentE2EEncryptVersion + (hInA, hOutA, phA) <- spawnJsRatchet + aliceE2EHex <- either error pure . parseResponse =<< jsCmd hInA hOutA ("INIT_RCV " ++ show vNum ++ " 0") + (hInB, hOutB, phB) <- spawnJsRatchet + bobE2EHex <- either error pure . parseResponse =<< jsCmd hInB hOutB ("INIT_SND " ++ show vNum ++ " none " ++ aliceE2EHex) + completeResp <- jsCmd hInA hOutA ("COMPLETE " ++ bobE2EHex) + when (completeResp /= "ok") $ error $ "COMPLETE failed: " ++ completeResp + pure (TestPeerJS hInA hOutA phA, TestPeerJS hInB hOutB phB) + +-- JS (receiver) <-> JS (sender), PQ KEM accepted +initTsTs_PQ :: IO (TestPeer, TestPeer) +initTsTs_PQ = do + let Version vNum = CR.currentE2EEncryptVersion + (hInA, hOutA, phA) <- spawnJsRatchet + aliceE2EHex <- either error pure . parseResponse =<< jsCmd hInA hOutA ("INIT_RCV " ++ show vNum ++ " 1") + (hInB, hOutB, phB) <- spawnJsRatchet + bobE2EHex <- either error pure . parseResponse =<< jsCmd hInB hOutB ("INIT_SND " ++ show vNum ++ " accept " ++ aliceE2EHex) + completeResp <- jsCmd hInA hOutA ("COMPLETE " ++ bobE2EHex) + when (completeResp /= "ok") $ error $ "COMPLETE failed: " ++ completeResp + pure (TestPeerJS hInA hOutA phA, TestPeerJS hInB hOutB phB) + smpWebTests :: SpecWith () smpWebTests = describe "SMP Web Client" $ do distExists <- runIO $ doesDirectoryExist (smpWebDir <> "/dist") @@ -675,6 +893,30 @@ smpWebTests_ = do <> jsOut ("new Uint8Array([match ? 1 : 0, der.length, raw.length])") tsResult `shouldBe` B.pack [1, 68, 56] + describe "cross-language ratchet advance" $ do + let run initPeers op test = withCrossPeers initPeers $ \(alice, bob) -> + test alice bob tpEncrypt tpDecrypt op + describe "HS rcv, JS snd, no PQ" $ do + it "encrypt and decrypt" $ run initHsJs_noPQ tp_noKEM testEncryptDecrypt + it "skipped messages" $ run initHsJs_noPQ tp_noKEM testSkippedMessages + it "many messages" $ run initHsJs_noPQ tp_noKEM testManyMessages + it "skipped after ratchet advance" $ run initHsJs_noPQ tp_noKEM testSkippedAfterRatchetAdvance + describe "HS rcv, JS snd, PQ" $ do + it "encrypt and decrypt" $ run initHsJs_PQ tp_hasKEM testEncryptDecrypt + it "skipped messages" $ run initHsJs_PQ tp_hasKEM testSkippedMessages + it "many messages" $ run initHsJs_PQ tp_hasKEM testManyMessages + it "skipped after ratchet advance" $ run initHsJs_PQ tp_hasKEM testSkippedAfterRatchetAdvance + describe "JS rcv, JS snd, no PQ" $ do + it "encrypt and decrypt" $ run initTsTs_noPQ tp_noKEM testEncryptDecrypt + it "skipped messages" $ run initTsTs_noPQ tp_noKEM testSkippedMessages + it "many messages" $ run initTsTs_noPQ tp_noKEM testManyMessages + it "skipped after ratchet advance" $ run initTsTs_noPQ tp_noKEM testSkippedAfterRatchetAdvance + describe "JS rcv, JS snd, PQ" $ do + it "encrypt and decrypt" $ run initTsTs_PQ tp_hasKEM testEncryptDecrypt + it "skipped messages" $ run initTsTs_PQ tp_hasKEM testSkippedMessages + it "many messages" $ run initTsTs_PQ tp_hasKEM testManyMessages + it "skipped after ratchet advance" $ run initTsTs_PQ tp_hasKEM testSkippedAfterRatchetAdvance + describe "crypto/blockEncryption" $ do describe "sbcInit + sbcHkdf" $ do it "TypeScript produces same sbKey/nonce via sbcInit+sbcHkdf as Haskell" $ do From dd07c231f1179e6351368e164136e38d288df0fb Mon Sep 17 00:00:00 2001 From: "Evgeny @ SimpleX Chat" <259188159+evgeny-simplex@users.noreply.github.com> Date: Sat, 16 May 2026 17:06:56 +0000 Subject: [PATCH 8/8] agent encoding/encryption stack with test --- smp-web/.gitignore | 1 + smp-web/src/agent/message.ts | 207 +++++++++++++++++++++++ smp-web/src/protocol.ts | 106 ++++++++++++ tests/SMPWebTests.hs | 311 ++++++++++++++++++++++++++++++++++- 4 files changed, 623 insertions(+), 2 deletions(-) create mode 100644 smp-web/src/agent/message.ts diff --git a/smp-web/.gitignore b/smp-web/.gitignore index 320c107b3..0a57d7b8c 100644 --- a/smp-web/.gitignore +++ b/smp-web/.gitignore @@ -1,3 +1,4 @@ node_modules/ dist/ +dist-test/ package-lock.json diff --git a/smp-web/src/agent/message.ts b/smp-web/src/agent/message.ts new file mode 100644 index 000000000..edaf9cd2b --- /dev/null +++ b/smp-web/src/agent/message.ts @@ -0,0 +1,207 @@ +// Agent message encoding/decoding. +// Mirrors: Simplex.Messaging.Agent.Protocol (AgentMsgEnvelope, AgentMessage, APrivHeader, AMessage) + +import { + Decoder, concatBytes, + encodeBytes, decodeBytes, + encodeLarge, decodeLarge, + encodeInt64, decodeInt64, + encodeWord16, decodeWord16, + encodeMaybe, decodeMaybe, + encodeNonEmpty, decodeNonEmpty, +} from "@simplex-chat/xftp-web/dist/protocol/encoding.js" + +// -- Constants (Agent/Protocol.hs:318-319) + +export const currentSMPAgentVersion = 7 + +// -- AMessage (Agent/Protocol.hs:1001-1020) + +export type AMessage = + | {type: "HELLO"} + | {type: "A_MSG", body: Uint8Array} + | {type: "A_RCVD", receipts: AMessageReceipt[]} // NonEmpty + | {type: "EREADY", lastDecryptedMsgId: bigint} + +// Agent/Protocol.hs:1040-1045 +export interface AMessageReceipt { + agentMsgId: bigint // Int64 + msgHash: Uint8Array // ByteString (32-byte SHA-256) + rcptInfo: Uint8Array // MsgReceiptInfo (ByteString, Large-encoded) +} + +// Agent/Protocol.hs:1078-1100 +export function encodeAMessage(msg: AMessage): Uint8Array { + switch (msg.type) { + case "HELLO": return new Uint8Array([0x48]) // "H" + case "A_MSG": return concatBytes(new Uint8Array([0x4D]), msg.body) // "M" + Tail + case "A_RCVD": return concatBytes(new Uint8Array([0x56]), encodeNonEmpty(encodeAMessageReceipt, msg.receipts)) // "V" + NonEmpty + case "EREADY": return concatBytes(new Uint8Array([0x45]), encodeInt64(msg.lastDecryptedMsgId)) // "E" + Int64 + } +} + +export function decodeAMessage(d: Decoder): AMessage { + const tag = d.anyByte() + switch (tag) { + case 0x48: return {type: "HELLO"} // 'H' + case 0x4D: return {type: "A_MSG", body: d.takeAll()} // 'M' + Tail + case 0x56: return {type: "A_RCVD", receipts: decodeNonEmpty(decodeAMessageReceipt, d)} // 'V' + case 0x45: return {type: "EREADY", lastDecryptedMsgId: decodeInt64(d)} // 'E' + // Queue management tags (not needed for chat messages, but recognized for decoding) + case 0x51: { // 'Q' + const sub = d.anyByte() + switch (sub) { + case 0x43: // 'C' = A_QCONT + case 0x41: // 'A' = QADD + case 0x4B: // 'K' = QKEY + case 0x55: // 'U' = QUSE + case 0x54: // 'T' = QTEST + throw new Error("decodeAMessage: queue management message (Q" + String.fromCharCode(sub) + ") not implemented") + default: + throw new Error("decodeAMessage: unknown Q-subtag " + sub) + } + } + default: + throw new Error("decodeAMessage: unknown tag " + tag) + } +} + +// Agent/Protocol.hs:1106-1111 +function encodeAMessageReceipt(r: AMessageReceipt): Uint8Array { + return concatBytes(encodeInt64(r.agentMsgId), encodeBytes(r.msgHash), encodeLarge(r.rcptInfo)) +} + +function decodeAMessageReceipt(d: Decoder): AMessageReceipt { + return {agentMsgId: decodeInt64(d), msgHash: decodeBytes(d), rcptInfo: decodeLarge(d)} +} + +// -- APrivHeader (Agent/Protocol.hs:946-957) + +export interface APrivHeader { + sndMsgId: bigint // AgentMsgId = Int64 + prevMsgHash: Uint8Array // MsgHash = ByteString +} + +export function encodeAPrivHeader(h: APrivHeader): Uint8Array { + return concatBytes(encodeInt64(h.sndMsgId), encodeBytes(h.prevMsgHash)) +} + +export function decodeAPrivHeader(d: Decoder): APrivHeader { + return {sndMsgId: decodeInt64(d), prevMsgHash: decodeBytes(d)} +} + +// -- AgentMessage (Agent/Protocol.hs:866-888) + +export type AgentMessage = + | {type: "connInfo", cInfo: Uint8Array} + | {type: "connInfoReply", smpQueues: Uint8Array[], cInfo: Uint8Array} // NonEmpty raw-encoded SMPQueueInfo + | {type: "ratchetInfo", info: Uint8Array} + | {type: "message", header: APrivHeader, msg: AMessage} + +export function encodeAgentMessage(msg: AgentMessage): Uint8Array { + switch (msg.type) { + case "connInfo": + return concatBytes(new Uint8Array([0x49]), msg.cInfo) // 'I' + Tail + case "connInfoReply": + // 'D' + NonEmpty SMPQueueInfo + Tail cInfo + // SMPQueueInfo encoding is complex; for now encode the raw bytes + return concatBytes( + new Uint8Array([0x44]), + encodeNonEmpty(b => b, msg.smpQueues), + msg.cInfo, + ) + case "ratchetInfo": + return concatBytes(new Uint8Array([0x52]), msg.info) // 'R' + Tail + case "message": + return concatBytes(new Uint8Array([0x4D]), encodeAPrivHeader(msg.header), encodeAMessage(msg.msg)) // 'M' + header + msg + } +} + +export function decodeAgentMessage(d: Decoder): AgentMessage { + const tag = d.anyByte() + switch (tag) { + case 0x49: return {type: "connInfo", cInfo: d.takeAll()} // 'I' + Tail + case 0x44: { // 'D' + // NonEmpty SMPQueueInfo is complex to decode; skip for now, just capture raw + throw new Error("decodeAgentMessage: connInfoReply ('D') not implemented") + } + case 0x52: return {type: "ratchetInfo", info: d.takeAll()} // 'R' + Tail + case 0x4D: return {type: "message", header: decodeAPrivHeader(d), msg: decodeAMessage(d)} // 'M' + default: + throw new Error("decodeAgentMessage: unknown tag " + tag) + } +} + +// -- AgentMsgEnvelope (Agent/Protocol.hs:812-861) + +export type AgentMsgEnvelope = + | {type: "confirmation", agentVersion: number, e2eEncryption: Uint8Array | null, encConnInfo: Uint8Array} + | {type: "envelope", agentVersion: number, encAgentMessage: Uint8Array} + | {type: "invitation", agentVersion: number, connReqBytes: Uint8Array, connInfo: Uint8Array} + | {type: "ratchetKey", agentVersion: number, e2eEncryption: Uint8Array, info: Uint8Array} + +// Agent/Protocol.hs:835-843 +export function encodeAgentMsgEnvelope(env: AgentMsgEnvelope): Uint8Array { + switch (env.type) { + case "confirmation": + // (agentVersion, 'C', Maybe SndE2ERatchetParams, Tail encConnInfo) + return concatBytes( + encodeWord16(env.agentVersion), + new Uint8Array([0x43]), // 'C' + encodeMaybe(b => b, env.e2eEncryption), // e2eEncryption is already smpEncoded bytes or null + env.encConnInfo, // Tail + ) + case "envelope": + // (agentVersion, 'M', Tail encAgentMessage) + return concatBytes( + encodeWord16(env.agentVersion), + new Uint8Array([0x4D]), // 'M' + env.encAgentMessage, // Tail + ) + case "invitation": + // (agentVersion, 'I', Large connReqBytes, Tail connInfo) + return concatBytes( + encodeWord16(env.agentVersion), + new Uint8Array([0x49]), // 'I' + encodeLarge(env.connReqBytes), + env.connInfo, // Tail + ) + case "ratchetKey": + // (agentVersion, 'R', e2eEncryption, Tail info) + return concatBytes( + encodeWord16(env.agentVersion), + new Uint8Array([0x52]), // 'R' + env.e2eEncryption, // already smpEncoded + env.info, // Tail + ) + } +} + +// Agent/Protocol.hs:844-861 +export function decodeAgentMsgEnvelope(d: Decoder): AgentMsgEnvelope { + const agentVersion = decodeWord16(d) + const tag = d.anyByte() + switch (tag) { + case 0x43: // 'C' Confirmation + // e2eEncryption_ is Maybe (SndE2ERatchetParams 'X448), encConnInfo is Tail + // Full parsing of E2ERatchetParams needed to split the boundary — not implemented in spike + throw new Error("decodeAgentMsgEnvelope: confirmation ('C') not fully implemented") + case 0x4D: // 'M' Message envelope + return {type: "envelope", agentVersion, encAgentMessage: d.takeAll()} // Tail + case 0x49: { // 'I' Invitation + const connReqBytes = decodeLarge(d) + const connInfo = d.takeAll() // Tail + return {type: "invitation", agentVersion, connReqBytes, connInfo} + } + case 0x52: { // 'R' RatchetKey + // e2eEncryption is an E2ERatchetParams — variable-length, not Tail + // For now, capture remaining minus nothing (since info is Tail and comes last) + // This is tricky: e2eEncryption is smpEncoded E2ERatchetParams, info is Tail + // We can't easily split without knowing the E2ERatchetParams length + // For the spike, just capture all remaining as raw + throw new Error("decodeAgentMsgEnvelope: ratchetKey ('R') not fully implemented") + } + default: + throw new Error("decodeAgentMsgEnvelope: unknown tag " + tag) + } +} diff --git a/smp-web/src/protocol.ts b/smp-web/src/protocol.ts index aef8d7e68..e61c49ba8 100644 --- a/smp-web/src/protocol.ts +++ b/smp-web/src/protocol.ts @@ -5,9 +5,11 @@ import { Decoder, concatBytes, encodeBytes, decodeBytes, encodeLarge, decodeLarge, + encodeWord16, decodeWord16, encodeBool, decodeBool, encodeMaybe, decodeMaybe, } from "@simplex-chat/xftp-web/dist/protocol/encoding.js" +import {cbEncrypt, cbDecrypt} from "@simplex-chat/xftp-web/dist/crypto/secretbox.js" import {readTag, readSpace} from "@simplex-chat/xftp-web/dist/protocol/commands.js" // -- Transmission encoding (Protocol.hs:2201-2203) @@ -246,3 +248,107 @@ export function decodeMSG(d: Decoder): MSGResponse { const msgBody = d.takeAll() return {msgId, msgBody} } + +// -- Per-queue E2E encryption (Protocol.hs:1071-1114) + +// Protocol.hs:316-320 +export const e2eEncMessageLength = 16000 +export const e2eEncConfirmationLength = 15904 + +// Protocol.hs:1078-1086 +export interface PubHeader { + phVersion: number // VersionSMPC (Word16) + phE2ePubDhKey: Uint8Array | null // Maybe PublicKeyX25519 (DER-encoded ByteString) +} + +export function encodePubHeader(h: PubHeader): Uint8Array { + return concatBytes(encodeWord16(h.phVersion), encodeMaybe(encodeBytes, h.phE2ePubDhKey)) +} + +export function decodePubHeader(d: Decoder): PubHeader { + return {phVersion: decodeWord16(d), phE2ePubDhKey: decodeMaybe(decodeBytes, d)} +} + +// Protocol.hs:1097-1110 +export type PrivHeader = + | {type: "PHConfirmation", key: Uint8Array} // 'K' + DER-encoded APublicAuthKey + | {type: "PHEmpty"} // '_' + +export function encodePrivHeader(h: PrivHeader): Uint8Array { + switch (h.type) { + case "PHConfirmation": return concatBytes(new Uint8Array([0x4B]), encodeBytes(h.key)) // 'K' + encodeBytes + case "PHEmpty": return new Uint8Array([0x5F]) // '_' + } +} + +export function decodePrivHeader(d: Decoder): PrivHeader { + const tag = d.anyByte() + switch (tag) { + case 0x4B: return {type: "PHConfirmation", key: decodeBytes(d)} // 'K' + case 0x5F: return {type: "PHEmpty"} // '_' + default: throw new Error("decodePrivHeader: unknown tag " + tag) + } +} + +// Protocol.hs:1095, 1112-1114 +export interface ClientMessage { + privHeader: PrivHeader + body: Uint8Array +} + +// smpEncode (ClientMessage h msg) = smpEncode h <> msg +export function encodeClientMessage(msg: ClientMessage): Uint8Array { + return concatBytes(encodePrivHeader(msg.privHeader), msg.body) +} + +export function decodeClientMessage(d: Decoder): ClientMessage { + const privHeader = decodePrivHeader(d) + const body = d.takeAll() + return {privHeader, body} +} + +// Protocol.hs:1071-1093 +export interface ClientMsgEnvelope { + cmHeader: PubHeader + cmNonce: Uint8Array // CbNonce: raw 24 bytes + cmEncBody: Uint8Array // encrypted body (Tail) +} + +// smpEncode (cmHeader, cmNonce, Tail cmEncBody) +export function encodeClientMsgEnvelope(env: ClientMsgEnvelope): Uint8Array { + return concatBytes(encodePubHeader(env.cmHeader), env.cmNonce, env.cmEncBody) +} + +export function decodeClientMsgEnvelope(d: Decoder): ClientMsgEnvelope { + const cmHeader = decodePubHeader(d) + const cmNonce = d.take(24) // CbNonce is raw 24 bytes + const cmEncBody = d.takeAll() + return {cmHeader, cmNonce, cmEncBody} +} + +// -- Per-queue E2E encrypt/decrypt (Agent/Client.hs:2074-2102) + +// agentCbEncrypt: encrypt a ClientMessage and wrap in ClientMsgEnvelope +export function agentCbEncrypt( + e2eDhSecret: Uint8Array, // X25519 DH shared secret (32 bytes) + smpClientVersion: number, // Word16 + e2ePubKey: Uint8Array | null, // DER-encoded X25519 public key, null for normal messages + msg: Uint8Array, // smpEncode(ClientMessage) +): Uint8Array { + const cmNonce = crypto.getRandomValues(new Uint8Array(24)) + const paddedLen = e2ePubKey !== null ? e2eEncConfirmationLength : e2eEncMessageLength + const cmEncBody = cbEncrypt(e2eDhSecret, cmNonce, msg, paddedLen) + const cmHeader: PubHeader = {phVersion: smpClientVersion, phE2ePubDhKey: e2ePubKey} + return encodeClientMsgEnvelope({cmHeader, cmNonce, cmEncBody}) +} + +// agentCbDecrypt: decrypt a ClientMsgEnvelope +export function agentCbDecrypt( + dhSecret: Uint8Array, // X25519 DH shared secret (32 bytes) + data: Uint8Array, // raw ClientMsgEnvelope bytes +): {pubHeader: PubHeader, clientMessage: ClientMessage} { + const env = decodeClientMsgEnvelope(new Decoder(data)) + const plaintext = cbDecrypt(dhSecret, env.cmNonce, env.cmEncBody) + const clientMessage = decodeClientMessage(new Decoder(plaintext)) + return {pubHeader: env.cmHeader, clientMessage} +} diff --git a/tests/SMPWebTests.hs b/tests/SMPWebTests.hs index 274cd1036..11c0d5922 100644 --- a/tests/SMPWebTests.hs +++ b/tests/SMPWebTests.hs @@ -15,8 +15,9 @@ module SMPWebTests (smpWebTests) where import Control.Concurrent.STM import Control.Monad (when) +import Data.Bifunctor (first) import Control.Exception (bracket) -import Control.Monad.Except (ExceptT, runExceptT) +import Control.Monad.Except (ExceptT, liftEither, runExceptT, throwError, withExceptT) import Crypto.Random (ChaChaDRG) import Data.IORef import System.IO (Handle, hFlush, hGetLine, hPutStr, hSetBuffering, BufferMode (..)) @@ -43,7 +44,7 @@ import qualified Data.ByteArray as BA import Simplex.Messaging.Crypto.ShortLink (contactShortLinkKdf, invShortLinkKdf) import Simplex.Messaging.Encoding import Simplex.Messaging.Encoding.String (Str (..), strEncode) -import Simplex.Messaging.Protocol (EntityId (..), SMPServer, SubscriptionMode (..), MsgFlags (..), pattern SMPServer, encodeProtocol, Command (..), NewQueueReq (..), BrokerMsg (..), RcvMessage (..), EncRcvMsgBody (..), QueueIdsKeys (..)) +import Simplex.Messaging.Protocol (EntityId (..), SMPServer, SubscriptionMode (..), MsgFlags (..), pattern SMPServer, encodeProtocol, Command (..), NewQueueReq (..), BrokerMsg (..), RcvMessage (..), EncRcvMsgBody (..), QueueIdsKeys (..), PubHeader (..), PrivHeader (..), ClientMessage (..), ClientMsgEnvelope (..), pattern VersionSMPC) import Simplex.Messaging.Server.Env.STM (AStoreType (..)) import Simplex.Messaging.Server.MsgStore.Types (SMSType (..), SQSType (..)) import Simplex.Messaging.Server.Web (attachStaticAndWS) @@ -96,6 +97,12 @@ impRatchet = "import { generateX448KeyPair, pqX3dhSnd, pqX3dhRcv, x448DH, encode impSntrup :: String impSntrup = "import { initSntrup761, sntrup761Keypair, sntrup761Enc, sntrup761Dec } from './dist/crypto/sntrup761.js'; await initSntrup761();" +impAgentMsg :: String +impAgentMsg = "import { encodeAMessage, decodeAMessage, encodeAPrivHeader, decodeAPrivHeader, encodeAgentMessage, decodeAgentMessage, encodeAgentMsgEnvelope, decodeAgentMsgEnvelope } from './dist/agent/message.js';" + +impProtoE2E :: String +impProtoE2E = "import { encodePubHeader, decodePubHeader, encodePrivHeader, decodePrivHeader, encodeClientMessage, decodeClientMessage, encodeClientMsgEnvelope, decodeClientMsgEnvelope, agentCbEncrypt, agentCbDecrypt, e2eEncMessageLength } from './dist/protocol.js';" + impCrypto :: String impCrypto = "import { sbcInit, sbcHkdf, sbEncryptBlock, sbDecryptBlock } from './dist/crypto.js';" @@ -1134,3 +1141,303 @@ smpWebTests_ = do -- First byte: rootKey DER length (44 for Ed25519), rest: userData B.head tsResult `shouldBe` 44 B.tail tsResult `shouldBe` testData + + describe "agent/message" $ do + describe "AMessage" $ do + it "HELLO encoding matches Haskell" $ do + let hsBytes = smpEncode AP.HELLO + tsBytes <- callNode $ impAgentMsg <> jsOut "encodeAMessage({type: 'HELLO'})" + tsBytes `shouldBe` hsBytes + + it "A_MSG encoding matches Haskell" $ do + let body = "hello world from agent" + hsBytes = smpEncode (AP.A_MSG body) + tsBytes <- callNode $ impAgentMsg <> jsOut ("encodeAMessage({type: 'A_MSG', body: new TextEncoder().encode('hello world from agent')})") + tsBytes `shouldBe` hsBytes + + it "EREADY encoding matches Haskell" $ do + let hsBytes = smpEncode (AP.EREADY 42) + tsBytes <- callNode $ impEnc <> impAgentMsg <> jsOut ("encodeAMessage({type: 'EREADY', lastDecryptedMsgId: 42n})") + tsBytes `shouldBe` hsBytes + + it "TypeScript decodes Haskell A_MSG" $ do + let body = "decode test" + hsBytes = smpEncode (AP.A_MSG body) + tsResult <- callNode $ impEnc <> impAgentMsg + <> "const msg = decodeAMessage(new Decoder(" <> jsUint8 hsBytes <> "));" + <> jsOut ("msg.body") + tsResult `shouldBe` body + + describe "APrivHeader" $ do + it "encoding matches Haskell" $ do + let hdr = AP.APrivHeader 1 (B.replicate 32 0xAB) + hsBytes = smpEncode hdr + tsBytes <- callNode $ impEnc <> impAgentMsg + <> jsOut ("encodeAPrivHeader({sndMsgId: 1n, prevMsgHash: " <> jsUint8 (B.replicate 32 0xAB) <> "})") + tsBytes `shouldBe` hsBytes + + describe "AgentMessage" $ do + it "M variant encoding matches Haskell" $ do + let hdr = AP.APrivHeader 5 (B.replicate 32 0) + msg = AP.AgentMessage hdr (AP.A_MSG "test body") + hsBytes = smpEncode msg + tsBytes <- callNode $ impEnc <> impAgentMsg + <> jsOut ("encodeAgentMessage({type: 'message', header: {sndMsgId: 5n, prevMsgHash: new Uint8Array(32)}, msg: {type: 'A_MSG', body: new TextEncoder().encode('test body')}})") + tsBytes `shouldBe` hsBytes + + it "TypeScript decodes Haskell AgentMessage M" $ do + let hdr = AP.APrivHeader 99 (B.pack [1..32]) + msg = AP.AgentMessage hdr (AP.A_MSG "cross-language message") + hsBytes = smpEncode msg + tsResult <- callNode $ impEnc <> impAgentMsg + <> "const m = decodeAgentMessage(new Decoder(" <> jsUint8 hsBytes <> "));" + <> "if (m.type !== 'message') throw new Error('expected message');" + <> "if (m.msg.type !== 'A_MSG') throw new Error('expected A_MSG');" + <> jsOut ("new Uint8Array([...new TextEncoder().encode(m.header.sndMsgId.toString()), 0, ...m.msg.body])") + let (idStr, rest) = B.break (== 0) tsResult + idStr `shouldBe` "99" + B.tail rest `shouldBe` "cross-language message" + + describe "AgentMsgEnvelope" $ do + it "M variant encoding matches Haskell" $ do + let env = AP.AgentMsgEnvelope {AP.agentVersion = AP.currentSMPAgentVersion, AP.encAgentMessage = "encrypted payload"} + hsBytes = smpEncode env + tsBytes <- callNode $ impEnc <> impAgentMsg + <> jsOut ("encodeAgentMsgEnvelope({type: 'envelope', agentVersion: 7, encAgentMessage: new TextEncoder().encode('encrypted payload')})") + tsBytes `shouldBe` hsBytes + + it "TypeScript decodes Haskell AgentMsgEnvelope M" $ do + let env = AP.AgentMsgEnvelope {AP.agentVersion = AP.currentSMPAgentVersion, AP.encAgentMessage = "decrypt me"} + hsBytes = smpEncode env + tsResult <- callNode $ impEnc <> impAgentMsg + <> "const e = decodeAgentMsgEnvelope(new Decoder(" <> jsUint8 hsBytes <> "));" + <> "if (e.type !== 'envelope') throw new Error('expected envelope');" + <> jsOut ("e.encAgentMessage") + tsResult `shouldBe` "decrypt me" + + describe "protocol/e2e" $ do + describe "PubHeader" $ do + it "encoding without key matches Haskell" $ do + let h = PubHeader (VersionSMPC 19) Nothing + hsBytes = smpEncode h + tsBytes <- callNode $ impEnc <> impProtoE2E <> jsOut "encodePubHeader({phVersion: 19, phE2ePubDhKey: null})" + tsBytes `shouldBe` hsBytes + + it "encoding with key matches Haskell" $ do + g <- C.newRandom + (k, _) <- atomically $ C.generateKeyPair @'C.X25519 g + let derKey = C.encodePubKey k -- raw DER bytes without smpEncode length prefix + h = PubHeader (VersionSMPC 19) (Just k) + hsBytes = smpEncode h + tsBytes <- callNode $ impEnc <> impProtoE2E + <> jsOut ("encodePubHeader({phVersion: 19, phE2ePubDhKey: " <> jsUint8 derKey <> "})") + tsBytes `shouldBe` hsBytes + + describe "PrivHeader" $ do + it "PHEmpty encoding matches Haskell" $ do + let hsBytes = smpEncode PHEmpty + tsBytes <- callNode $ impProtoE2E <> jsOut "encodePrivHeader({type: 'PHEmpty'})" + tsBytes `shouldBe` hsBytes + + describe "ClientMessage" $ do + it "encoding matches Haskell" $ do + let body = "agent envelope bytes here" + msg = ClientMessage PHEmpty body + hsBytes = smpEncode msg + tsBytes <- callNode $ impEnc <> impProtoE2E + <> jsOut ("encodeClientMessage({privHeader: {type: 'PHEmpty'}, body: new TextEncoder().encode('agent envelope bytes here')})") + tsBytes `shouldBe` hsBytes + + describe "ClientMsgEnvelope" $ do + it "encoding matches Haskell" $ do + let nonce = C.cbNonce $ B.pack [1..24] + h = PubHeader (VersionSMPC 19) Nothing + env = ClientMsgEnvelope {cmHeader = h, cmNonce = nonce, cmEncBody = "encrypted body data"} + hsBytes = smpEncode env + tsBytes <- callNode $ impEnc <> impProtoE2E + <> jsOut ("encodeClientMsgEnvelope({cmHeader: {phVersion: 19, phE2ePubDhKey: null}, cmNonce: " <> jsUint8 (B.pack [1..24]) <> ", cmEncBody: new TextEncoder().encode('encrypted body data')})") + tsBytes `shouldBe` hsBytes + + it "TypeScript decodes Haskell-encoded" $ do + let nonce = C.cbNonce $ B.pack [10..33] + h = PubHeader (VersionSMPC 19) Nothing + env = ClientMsgEnvelope {cmHeader = h, cmNonce = nonce, cmEncBody = "test ciphertext"} + hsBytes = smpEncode env + tsResult <- callNode $ impEnc <> impProtoE2E + <> "const env = decodeClientMsgEnvelope(new Decoder(" <> jsUint8 hsBytes <> "));" + <> jsOut ("new Uint8Array([env.cmHeader.phVersion >> 8, env.cmHeader.phVersion & 0xff, env.cmHeader.phE2ePubDhKey === null ? 1 : 0, ...env.cmNonce, ...env.cmEncBody])") + let (version, rest1) = B.splitAt 2 tsResult + (nullByte, rest2) = B.splitAt 1 rest1 + (nonceBytes, bodyBytes) = B.splitAt 24 rest2 + version `shouldBe` B.pack [0, 19] + nullByte `shouldBe` B.singleton 1 + nonceBytes `shouldBe` B.pack [10..33] + bodyBytes `shouldBe` "test ciphertext" + + describe "per-queue E2E encrypt/decrypt" $ do + it "TypeScript encrypts, Haskell decrypts" $ do + -- Haskell generates receiver keypair + g <- C.newRandom + (rcvPub, rcvPriv) <- atomically $ C.generateKeyPair @'C.X25519 g + let rcvPubRaw = C.pubKeyBytes rcvPub + -- TypeScript generates sender keypair, computes DH, encrypts + tsOutput <- callNode $ impSodium <> impEnc <> impProtoE2E + <> "import { generateX25519KeyPair, dh, encodePubKeyX25519 } from '@simplex-chat/xftp-web/dist/crypto/keys.js';" + <> "const sndKp = generateX25519KeyPair();" + <> "const rcvPub = " <> jsUint8 rcvPubRaw <> ";" + <> "const dhSecret = dh(rcvPub, sndKp.privateKey);" + <> "const clientMsg = encodeClientMessage({privHeader: {type: 'PHEmpty'}, body: new TextEncoder().encode('hello from typescript e2e')});" + <> "const encrypted = agentCbEncrypt(dhSecret, 19, null, clientMsg);" + -- Output: DER-encoded sndPubKey (ByteString-encoded: 1-byte len + DER) + encrypted + <> "const sndDer = encodePubKeyX25519(sndKp.publicKey);" + <> jsOut ("new Uint8Array([sndDer.length, ...sndDer, ...encrypted])") + -- Parse output: [1 byte len][DER sndPubKey][encrypted] + let sndDerLen = fromIntegral $ B.head tsOutput + (sndDerBytes, encrypted) = B.splitAt sndDerLen $ B.drop 1 tsOutput + -- Haskell decodes sender's DER public key and decrypts + let decoded = do + apk <- C.decodePubKey sndDerBytes + dhSecret <- case apk of + C.APublicKey C.SX25519 pk -> Right $ C.dh' pk rcvPriv + _ -> Left "not X25519" + cme <- smpDecode encrypted + plaintext <- first show $ C.cbDecrypt dhSecret (cmNonce cme) (cmEncBody cme) + cm <- smpDecode plaintext + case cm of + ClientMessage PHEmpty body -> Right body + _ -> Left "unexpected PrivHeader" + decoded `shouldBe` Right "hello from typescript e2e" + + describe "full-stack" $ do + it "Haskell encodes all layers, TypeScript decodes" $ do + g <- C.newRandom + let v = CR.currentE2EEncryptVersion + -- Alice (receiver) ratchet keys - extract raw private bytes for TypeScript + (alicePk1, alicePk2, Nothing, e2eAlice) <- CR.generateRcvE2EParams @'X448 g v CR.PQSupportOff + let C.PrivateKeyX448 sk1 = alicePk1; alicePriv1 = BA.convert sk1 :: B.ByteString + C.PrivateKeyX448 sk2 = alicePk2; alicePriv2 = BA.convert sk2 :: B.ByteString + -- Bob (sender) ratchet: X3DH + init + (bobPk1, bobPk2, Nothing, CR.AE2ERatchetParams _ e2eBob) <- CR.generateSndE2EParams @'X448 g v Nothing + Right bobInitParams <- pure $ CR.pqX3dhSnd bobPk1 bobPk2 Nothing e2eAlice + (_, bobDHRs) <- atomically $ C.generateKeyPair @'X448 g + let bobRatchet = CR.initSndRatchet (CR.RatchetVersions v v) (C.publicKey alicePk2) bobDHRs bobInitParams + bobE2EBytes = smpEncode e2eBob + -- Per-queue E2E: shared DH secret (pass raw bytes to both sides) + (_, e2eSndPriv) <- atomically $ C.generateKeyPair @'C.X25519 g + (e2eRcvPub, _) <- atomically $ C.generateKeyPair @'C.X25519 g + let dhSecret = C.dh' e2eRcvPub e2eSndPriv + dhSecretBytes = C.dhBytes' dhSecret + -- Haskell: encode A_MSG through all layers + let aMsg = AP.AgentMessage (AP.APrivHeader 1 (B.replicate 32 0)) (AP.A_MSG "hello full stack") + agentMsgBytes = smpEncode aMsg + Right (mek, _) <- runExceptT $ CR.rcEncryptHeader bobRatchet Nothing CR.currentE2EEncryptVersion + Right encAgentMsg <- runExceptT $ CR.rcEncryptMsg mek (AP.e2eEncAgentMsgLength AP.currentSMPAgentVersion CR.PQSupportOff) agentMsgBytes + let envBytes = smpEncode $ AP.AgentMsgEnvelope {AP.agentVersion = AP.currentSMPAgentVersion, AP.encAgentMessage = encAgentMsg} + clientMsgBytes = smpEncode $ ClientMessage PHEmpty envBytes + cmNonce <- atomically $ C.randomCbNonce g + Right cmEncBody <- pure $ C.cbEncrypt dhSecret cmNonce clientMsgBytes 16000 + let cmeBytes = smpEncode $ ClientMsgEnvelope (PubHeader (VersionSMPC 19) Nothing) cmNonce cmEncBody + -- TypeScript: init alice ratchet + decode all layers + tsResult <- callNode $ impSodium <> impEnc <> impRatchet <> impAgentMsg <> impProtoE2E + <> "import { cbDecrypt } from '@simplex-chat/xftp-web/dist/crypto/secretbox.js';" + -- Init alice's receiver ratchet + <> "const a1Priv = " <> jsUint8 alicePriv1 <> ";" + <> "const a2Priv = " <> jsUint8 alicePriv2 <> ";" + <> "const rd = new Decoder(" <> jsUint8 bobE2EBytes <> ");" + <> "rd.anyByte(); rd.anyByte();" -- skip version + <> "const bpk1 = decodePubKeyX448(decodeBytes(rd));" + <> "const bpk2 = decodePubKeyX448(decodeBytes(rd));" + <> "const ap = pqX3dhRcv(a1Priv, a2Priv, bpk1, bpk2);" + <> "let alice = initRcvRatchet({current:3,maxSupported:3}, a2Priv, ap, null, false);" + <> "let sk = new Map();" + -- Layer 1: per-queue E2E decrypt + <> "const dhSecret = " <> jsUint8 dhSecretBytes <> ";" + <> "const {clientMessage} = agentCbDecrypt(dhSecret, " <> jsUint8 cmeBytes <> ");" + -- Layer 2: decode AgentMsgEnvelope + <> "const env = decodeAgentMsgEnvelope(new Decoder(clientMessage.body));" + <> "if (env.type !== 'envelope') throw new Error('expected envelope, got ' + env.type);" + -- Layer 3: ratchet decrypt + <> "const dec = rcDecrypt(alice, sk, env.encAgentMessage);" + -- Layer 4: decode AgentMessage + AMessage + <> "const agentMsg = decodeAgentMessage(new Decoder(dec.plaintext));" + <> "if (agentMsg.type !== 'message') throw new Error('expected message, got ' + agentMsg.type);" + <> "if (agentMsg.msg.type !== 'A_MSG') throw new Error('expected A_MSG, got ' + agentMsg.msg.type);" + <> jsOut ("agentMsg.msg.body") + tsResult `shouldBe` "hello full stack" + + it "TypeScript encodes all layers, Haskell decodes" $ do + g <- C.newRandom + let v = CR.currentE2EEncryptVersion + -- Alice (receiver): Haskell generates ratchet rcv params + X25519 for per-queue E2E + (alicePk1, alicePk2, Nothing, e2eAlice) <- CR.generateRcvE2EParams @'X448 g v CR.PQSupportOff + let aliceE2EBytes = smpEncode e2eAlice + (e2eRcvPub, e2eRcvPriv) <- atomically $ C.generateKeyPair @'C.X25519 g + -- TypeScript: init bob's sender ratchet, encode full stack, output keys + ciphertext + tsOutput <- callNode $ impSodium <> impEnc <> impRatchet <> impAgentMsg <> impProtoE2E + <> "import { generateX25519KeyPair, dh, encodePubKeyX25519 } from '@simplex-chat/xftp-web/dist/crypto/keys.js';" + -- Init bob's sender ratchet + <> "const d = new Decoder(" <> jsUint8 aliceE2EBytes <> ");" + <> "const aliceV = d.anyByte() * 256 + d.anyByte();" + <> "const alicePk1Raw = decodePubKeyX448(decodeBytes(d));" + <> "const alicePk2Raw = decodePubKeyX448(decodeBytes(d));" + <> "const b1 = generateX448KeyPair(), b2 = generateX448KeyPair(), b3 = generateX448KeyPair();" + <> "const bp = pqX3dhSnd(b1.privateKey, b2.privateKey, alicePk1Raw, alicePk2Raw);" + <> "let bob = initSndRatchet({current:3,maxSupported:3}, alicePk2Raw, b3.privateKey, bp, null);" + -- Layer 4: encode AgentMessage + <> "const agentMsg = encodeAgentMessage({type: 'message', header: {sndMsgId: 1n, prevMsgHash: new Uint8Array(32)}, msg: {type: 'A_MSG', body: new TextEncoder().encode('hello from ts full stack')}});" + -- Layer 3: ratchet encrypt + <> "const enc = rcEncrypt(bob, agentMsg, 15840);" + -- Layer 2: wrap in AgentMsgEnvelope + <> "const envBytes = encodeAgentMsgEnvelope({type: 'envelope', agentVersion: 7, encAgentMessage: enc.ciphertext});" + -- Layer 1: per-queue E2E encrypt + <> "const sndKp = generateX25519KeyPair();" + <> "const rcvPub = " <> jsUint8 (C.pubKeyBytes e2eRcvPub) <> ";" + <> "const dhSecret = dh(rcvPub, sndKp.privateKey);" + <> "const clientMsg = encodeClientMessage({privHeader: {type: 'PHEmpty'}, body: envBytes});" + <> "const cmeBytes = agentCbEncrypt(dhSecret, 19, null, clientMsg);" + -- Output: bob E2E params + snd DER pubkey + cmeBytes + <> "const bobE2E = new Uint8Array([...encodeWord16(3), ...encodeBytes(encodePubKeyX448(b1.publicKey)), ...encodeBytes(encodePubKeyX448(b2.publicKey)), 0x30]);" + <> "const sndDer = encodePubKeyX25519(sndKp.publicKey);" + <> "const out = new Uint8Array([" + <> " (bobE2E.length >> 8) & 0xff, bobE2E.length & 0xff, ...bobE2E," + <> " sndDer.length, ...sndDer," + <> " ...cmeBytes" + <> "]);" + <> jsOut ("out") + -- Parse TypeScript output + let (e2eLenBs, r1) = B.splitAt 2 tsOutput + bobE2ELen = fromIntegral (B.index e2eLenBs 0) * 256 + fromIntegral (B.index e2eLenBs 1) + (bobE2EBytes, r2) = B.splitAt bobE2ELen r1 + sndDerLen = fromIntegral $ B.head r2 + (sndDerBytes, cmeBytes) = B.splitAt sndDerLen $ B.drop 1 r2 + -- Haskell: init alice's receiver ratchet, decode all layers + Right (CR.AE2ERatchetParams _ bobE2E :: CR.AE2ERatchetParams 'X448) <- pure $ smpDecode bobE2EBytes + Right (aliceInitParams, _) <- runExceptT $ CR.pqX3dhRcv alicePk1 alicePk2 Nothing bobE2E + let aliceRatchet = CR.initRcvRatchet (CR.RatchetVersions v v) alicePk2 (aliceInitParams, Nothing) CR.PQSupportOff + -- Decode all layers using ExceptT to chain pure Either + IO + gAlice <- C.newRandom + result <- runExceptT $ do + -- Per-queue E2E decrypt (pure) + apk <- liftEither $ C.decodePubKey sndDerBytes + dhSecret <- case apk of + C.APublicKey C.SX25519 pk -> pure $ C.dh' pk e2eRcvPriv + _ -> throwError "not X25519" + cme <- liftEither $ smpDecode cmeBytes + plaintext <- liftEither $ first show $ C.cbDecrypt dhSecret (cmNonce cme) (cmEncBody cme) + cm <- liftEither $ smpDecode plaintext + envBody <- case cm of + ClientMessage PHEmpty b -> pure b + _ -> throwError "unexpected PrivHeader" + -- Decode envelope + env <- liftEither $ smpDecode envBody + encMsg <- case env of + AP.AgentMsgEnvelope {AP.encAgentMessage = m} -> pure m + _ -> throwError "unexpected AgentMsgEnvelope variant" + -- Ratchet decrypt (IO) + (msgBody_, _, _) <- withExceptT show $ CR.rcDecrypt gAlice aliceRatchet M.empty encMsg + liftEither $ first show msgBody_ + -- Decode agent message from result + agentMsgBytes <- either (error . ("decode failed: " <>)) pure result + Right (AP.AgentMessage _ (AP.A_MSG body)) <- pure $ smpDecode agentMsgBytes + body `shouldBe` "hello from ts full stack" +