diff --git a/smp-web/.gitignore b/smp-web/.gitignore index 320c107b3..0a57d7b8c 100644 --- a/smp-web/.gitignore +++ b/smp-web/.gitignore @@ -1,3 +1,4 @@ node_modules/ dist/ +dist-test/ package-lock.json diff --git a/smp-web/cbits/js_random.js b/smp-web/cbits/js_random.js new file mode 100644 index 000000000..7c2eececf --- /dev/null +++ b/smp-web/cbits/js_random.js @@ -0,0 +1,14 @@ +addToLibrary({ + js_random_bytes: function(buf, len) { + var bytes = new Uint8Array(len); + if (typeof crypto !== 'undefined' && crypto.getRandomValues) { + crypto.getRandomValues(bytes); + } else { + // Node.js fallback + var nodeCrypto = require('crypto'); + var nodeBytes = nodeCrypto.randomBytes(len); + bytes.set(nodeBytes); + } + HEAPU8.set(bytes, buf); + } +}); diff --git a/smp-web/cbits/sha512.c b/smp-web/cbits/sha512.c new file mode 100644 index 000000000..8045fa8ee --- /dev/null +++ b/smp-web/cbits/sha512.c @@ -0,0 +1,315 @@ +/* +20080913 +D. J. Bernstein +Public domain. + +SHA-512 implementation from SUPERCOP/NaCl. +Source: https://bench.cr.yp.to/supercop.html + crypto_hashblocks/sha512/ref/blocks.c + crypto_hash/sha512/ref/hash.c + +Combined into a single file for WASM compilation alongside sntrup761. +*/ + +#include "sha512.h" + +typedef unsigned long long uint64; + +/* -- crypto_hashblocks_sha512 (blocks.c) -- */ + +static uint64 load_bigendian(const unsigned char *x) +{ + return + (uint64) (x[7]) \ + | (((uint64) (x[6])) << 8) \ + | (((uint64) (x[5])) << 16) \ + | (((uint64) (x[4])) << 24) \ + | (((uint64) (x[3])) << 32) \ + | (((uint64) (x[2])) << 40) \ + | (((uint64) (x[1])) << 48) \ + | (((uint64) (x[0])) << 56) + ; +} + +static void store_bigendian(unsigned char *x,uint64 u) +{ + x[7] = u; u >>= 8; + x[6] = u; u >>= 8; + x[5] = u; u >>= 8; + x[4] = u; u >>= 8; + x[3] = u; u >>= 8; + x[2] = u; u >>= 8; + x[1] = u; u >>= 8; + x[0] = u; +} + +#define SHR(x,c) ((x) >> (c)) +#define ROTR(x,c) (((x) >> (c)) | ((x) << (64 - (c)))) + +#define Ch(x,y,z) ((x & y) ^ (~x & z)) +#define Maj(x,y,z) ((x & y) ^ (x & z) ^ (y & z)) +#define Sigma0(x) (ROTR(x,28) ^ ROTR(x,34) ^ ROTR(x,39)) +#define Sigma1(x) (ROTR(x,14) ^ ROTR(x,18) ^ ROTR(x,41)) +#define sigma0(x) (ROTR(x, 1) ^ ROTR(x, 8) ^ SHR(x,7)) +#define sigma1(x) (ROTR(x,19) ^ ROTR(x,61) ^ SHR(x,6)) + +#define M(w0,w14,w9,w1) w0 = sigma1(w14) + w9 + sigma0(w1) + w0; + +#define EXPAND \ + M(w0 ,w14,w9 ,w1 ) \ + M(w1 ,w15,w10,w2 ) \ + M(w2 ,w0 ,w11,w3 ) \ + M(w3 ,w1 ,w12,w4 ) \ + M(w4 ,w2 ,w13,w5 ) \ + M(w5 ,w3 ,w14,w6 ) \ + M(w6 ,w4 ,w15,w7 ) \ + M(w7 ,w5 ,w0 ,w8 ) \ + M(w8 ,w6 ,w1 ,w9 ) \ + M(w9 ,w7 ,w2 ,w10) \ + M(w10,w8 ,w3 ,w11) \ + M(w11,w9 ,w4 ,w12) \ + M(w12,w10,w5 ,w13) \ + M(w13,w11,w6 ,w14) \ + M(w14,w12,w7 ,w15) \ + M(w15,w13,w8 ,w0 ) + +#define F(w,k) \ + T1 = h + Sigma1(e) + Ch(e,f,g) + k + w; \ + T2 = Sigma0(a) + Maj(a,b,c); \ + h = g; \ + g = f; \ + f = e; \ + e = d + T1; \ + d = c; \ + c = b; \ + b = a; \ + a = T1 + T2; + +static int crypto_hashblocks_sha512(unsigned char *statebytes,const unsigned char *in,unsigned long long inlen) +{ + uint64 state[8]; + uint64 a; + uint64 b; + uint64 c; + uint64 d; + uint64 e; + uint64 f; + uint64 g; + uint64 h; + uint64 T1; + uint64 T2; + + a = load_bigendian(statebytes + 0); state[0] = a; + b = load_bigendian(statebytes + 8); state[1] = b; + c = load_bigendian(statebytes + 16); state[2] = c; + d = load_bigendian(statebytes + 24); state[3] = d; + e = load_bigendian(statebytes + 32); state[4] = e; + f = load_bigendian(statebytes + 40); state[5] = f; + g = load_bigendian(statebytes + 48); state[6] = g; + h = load_bigendian(statebytes + 56); state[7] = h; + + while (inlen >= 128) { + uint64 w0 = load_bigendian(in + 0); + uint64 w1 = load_bigendian(in + 8); + uint64 w2 = load_bigendian(in + 16); + uint64 w3 = load_bigendian(in + 24); + uint64 w4 = load_bigendian(in + 32); + uint64 w5 = load_bigendian(in + 40); + uint64 w6 = load_bigendian(in + 48); + uint64 w7 = load_bigendian(in + 56); + uint64 w8 = load_bigendian(in + 64); + uint64 w9 = load_bigendian(in + 72); + uint64 w10 = load_bigendian(in + 80); + uint64 w11 = load_bigendian(in + 88); + uint64 w12 = load_bigendian(in + 96); + uint64 w13 = load_bigendian(in + 104); + uint64 w14 = load_bigendian(in + 112); + uint64 w15 = load_bigendian(in + 120); + + F(w0 ,0x428a2f98d728ae22ULL) + F(w1 ,0x7137449123ef65cdULL) + F(w2 ,0xb5c0fbcfec4d3b2fULL) + F(w3 ,0xe9b5dba58189dbbcULL) + F(w4 ,0x3956c25bf348b538ULL) + F(w5 ,0x59f111f1b605d019ULL) + F(w6 ,0x923f82a4af194f9bULL) + F(w7 ,0xab1c5ed5da6d8118ULL) + F(w8 ,0xd807aa98a3030242ULL) + F(w9 ,0x12835b0145706fbeULL) + F(w10,0x243185be4ee4b28cULL) + F(w11,0x550c7dc3d5ffb4e2ULL) + F(w12,0x72be5d74f27b896fULL) + F(w13,0x80deb1fe3b1696b1ULL) + F(w14,0x9bdc06a725c71235ULL) + F(w15,0xc19bf174cf692694ULL) + + EXPAND + + F(w0 ,0xe49b69c19ef14ad2ULL) + F(w1 ,0xefbe4786384f25e3ULL) + F(w2 ,0x0fc19dc68b8cd5b5ULL) + F(w3 ,0x240ca1cc77ac9c65ULL) + F(w4 ,0x2de92c6f592b0275ULL) + F(w5 ,0x4a7484aa6ea6e483ULL) + F(w6 ,0x5cb0a9dcbd41fbd4ULL) + F(w7 ,0x76f988da831153b5ULL) + F(w8 ,0x983e5152ee66dfabULL) + F(w9 ,0xa831c66d2db43210ULL) + F(w10,0xb00327c898fb213fULL) + F(w11,0xbf597fc7beef0ee4ULL) + F(w12,0xc6e00bf33da88fc2ULL) + F(w13,0xd5a79147930aa725ULL) + F(w14,0x06ca6351e003826fULL) + F(w15,0x142929670a0e6e70ULL) + + EXPAND + + F(w0 ,0x27b70a8546d22ffcULL) + F(w1 ,0x2e1b21385c26c926ULL) + F(w2 ,0x4d2c6dfc5ac42aedULL) + F(w3 ,0x53380d139d95b3dfULL) + F(w4 ,0x650a73548baf63deULL) + F(w5 ,0x766a0abb3c77b2a8ULL) + F(w6 ,0x81c2c92e47edaee6ULL) + F(w7 ,0x92722c851482353bULL) + F(w8 ,0xa2bfe8a14cf10364ULL) + F(w9 ,0xa81a664bbc423001ULL) + F(w10,0xc24b8b70d0f89791ULL) + F(w11,0xc76c51a30654be30ULL) + F(w12,0xd192e819d6ef5218ULL) + F(w13,0xd69906245565a910ULL) + F(w14,0xf40e35855771202aULL) + F(w15,0x106aa07032bbd1b8ULL) + + EXPAND + + F(w0 ,0x19a4c116b8d2d0c8ULL) + F(w1 ,0x1e376c085141ab53ULL) + F(w2 ,0x2748774cdf8eeb99ULL) + F(w3 ,0x34b0bcb5e19b48a8ULL) + F(w4 ,0x391c0cb3c5c95a63ULL) + F(w5 ,0x4ed8aa4ae3418acbULL) + F(w6 ,0x5b9cca4f7763e373ULL) + F(w7 ,0x682e6ff3d6b2b8a3ULL) + F(w8 ,0x748f82ee5defb2fcULL) + F(w9 ,0x78a5636f43172f60ULL) + F(w10,0x84c87814a1f0ab72ULL) + F(w11,0x8cc702081a6439ecULL) + F(w12,0x90befffa23631e28ULL) + F(w13,0xa4506cebde82bde9ULL) + F(w14,0xbef9a3f7b2c67915ULL) + F(w15,0xc67178f2e372532bULL) + + EXPAND + + F(w0 ,0xca273eceea26619cULL) + F(w1 ,0xd186b8c721c0c207ULL) + F(w2 ,0xeada7dd6cde0eb1eULL) + F(w3 ,0xf57d4f7fee6ed178ULL) + F(w4 ,0x06f067aa72176fbaULL) + F(w5 ,0x0a637dc5a2c898a6ULL) + F(w6 ,0x113f9804bef90daeULL) + F(w7 ,0x1b710b35131c471bULL) + F(w8 ,0x28db77f523047d84ULL) + F(w9 ,0x32caab7b40c72493ULL) + F(w10,0x3c9ebe0a15c9bebcULL) + F(w11,0x431d67c49c100d4cULL) + F(w12,0x4cc5d4becb3e42b6ULL) + F(w13,0x597f299cfc657e2aULL) + F(w14,0x5fcb6fab3ad6faecULL) + F(w15,0x6c44198c4a475817ULL) + + a += state[0]; + b += state[1]; + c += state[2]; + d += state[3]; + e += state[4]; + f += state[5]; + g += state[6]; + h += state[7]; + + state[0] = a; + state[1] = b; + state[2] = c; + state[3] = d; + state[4] = e; + state[5] = f; + state[6] = g; + state[7] = h; + + in += 128; + inlen -= 128; + } + + store_bigendian(statebytes + 0,state[0]); + store_bigendian(statebytes + 8,state[1]); + store_bigendian(statebytes + 16,state[2]); + store_bigendian(statebytes + 24,state[3]); + store_bigendian(statebytes + 32,state[4]); + store_bigendian(statebytes + 40,state[5]); + store_bigendian(statebytes + 48,state[6]); + store_bigendian(statebytes + 56,state[7]); + + return inlen; +} + +/* -- crypto_hash_sha512 (hash.c) -- */ + +static const unsigned char iv[64] = { + 0x6a,0x09,0xe6,0x67,0xf3,0xbc,0xc9,0x08, + 0xbb,0x67,0xae,0x85,0x84,0xca,0xa7,0x3b, + 0x3c,0x6e,0xf3,0x72,0xfe,0x94,0xf8,0x2b, + 0xa5,0x4f,0xf5,0x3a,0x5f,0x1d,0x36,0xf1, + 0x51,0x0e,0x52,0x7f,0xad,0xe6,0x82,0xd1, + 0x9b,0x05,0x68,0x8c,0x2b,0x3e,0x6c,0x1f, + 0x1f,0x83,0xd9,0xab,0xfb,0x41,0xbd,0x6b, + 0x5b,0xe0,0xcd,0x19,0x13,0x7e,0x21,0x79 +}; + +void crypto_hash_sha512(unsigned char *out, + const unsigned char *in, + unsigned long long inlen) +{ + unsigned char h[64]; + unsigned char padded[256]; + int i; + unsigned long long bytes = inlen; + + for (i = 0;i < 64;++i) h[i] = iv[i]; + + crypto_hashblocks_sha512(h,in,inlen); + in += inlen; + inlen &= 127; + in -= inlen; + + for (i = 0;i < (int)inlen;++i) padded[i] = in[i]; + padded[inlen] = 0x80; + + if (inlen < 112) { + for (i = inlen + 1;i < 119;++i) padded[i] = 0; + padded[119] = bytes >> 61; + padded[120] = bytes >> 53; + padded[121] = bytes >> 45; + padded[122] = bytes >> 37; + padded[123] = bytes >> 29; + padded[124] = bytes >> 21; + padded[125] = bytes >> 13; + padded[126] = bytes >> 5; + padded[127] = bytes << 3; + crypto_hashblocks_sha512(h,padded,128); + } else { + for (i = inlen + 1;i < 247;++i) padded[i] = 0; + padded[247] = bytes >> 61; + padded[248] = bytes >> 53; + padded[249] = bytes >> 45; + padded[250] = bytes >> 37; + padded[251] = bytes >> 29; + padded[252] = bytes >> 21; + padded[253] = bytes >> 13; + padded[254] = bytes >> 5; + padded[255] = bytes << 3; + crypto_hashblocks_sha512(h,padded,256); + } + + for (i = 0;i < 64;++i) out[i] = h[i]; +} diff --git a/smp-web/cbits/sntrup761.d.mts b/smp-web/cbits/sntrup761.d.mts new file mode 100644 index 000000000..2f603d0da --- /dev/null +++ b/smp-web/cbits/sntrup761.d.mts @@ -0,0 +1,11 @@ +interface Sntrup761Module { + _sntrup761_wasm_keypair(pk: number, sk: number): void + _sntrup761_wasm_enc(c: number, k: number, pk: number): void + _sntrup761_wasm_dec(k: number, c: number, sk: number): void + _malloc(size: number): number + _free(ptr: number): void + HEAPU8: Uint8Array +} + +declare function createSntrup761(): Promise +export default createSntrup761 diff --git a/smp-web/cbits/sntrup761_wasm.c b/smp-web/cbits/sntrup761_wasm.c new file mode 100644 index 000000000..857db3506 --- /dev/null +++ b/smp-web/cbits/sntrup761_wasm.c @@ -0,0 +1,34 @@ +/* + * WASM wrapper for sntrup761. + * Provides JS-callable functions with RNG from JS imports. + * + * Build: emcc sntrup761_wasm.c sntrup761.c sha512.c -O2 -o sntrup761.js \ + * -s EXPORTED_FUNCTIONS='["_sntrup761_wasm_keypair","_sntrup761_wasm_enc","_sntrup761_wasm_dec","_malloc","_free"]' \ + * -s EXPORTED_RUNTIME_METHODS='["ccall","cwrap"]' + */ + +#include "sntrup761.h" +#include + +/* Import RNG from JS environment */ +extern void js_random_bytes(unsigned char *buf, int len); + +/* RNG callback adapter for sntrup761 */ +static void wasm_random(void *ctx, size_t length, uint8_t *dst) { + (void)ctx; + js_random_bytes(dst, (int)length); +} + +/* JS-callable wrappers */ + +void sntrup761_wasm_keypair(unsigned char *pk, unsigned char *sk) { + sntrup761_keypair(pk, sk, NULL, wasm_random); +} + +void sntrup761_wasm_enc(unsigned char *c, unsigned char *k, const unsigned char *pk) { + sntrup761_enc(c, k, pk, NULL, wasm_random); +} + +void sntrup761_wasm_dec(unsigned char *k, const unsigned char *c, const unsigned char *sk) { + sntrup761_dec(k, c, sk); +} diff --git a/smp-web/package.json b/smp-web/package.json index 996d2896b..ab9ff7c1e 100644 --- a/smp-web/package.json +++ b/smp-web/package.json @@ -14,11 +14,17 @@ "dist" ], "scripts": { - "build": "tsc" + "build:wasm": "mkdir -p dist/wasm && npx emcc cbits/sntrup761_wasm.c ../cbits/sntrup761.c cbits/sha512.c -I../cbits -O2 -o dist/wasm/sntrup761.mjs -s EXPORTED_FUNCTIONS='[\"_sntrup761_wasm_keypair\",\"_sntrup761_wasm_enc\",\"_sntrup761_wasm_dec\",\"_malloc\",\"_free\"]' -s EXPORTED_RUNTIME_METHODS='[\"ccall\",\"cwrap\",\"HEAPU8\"]' -s MODULARIZE=1 -s EXPORT_NAME='createSntrup761' -s ALLOW_MEMORY_GROWTH=1 -s ENVIRONMENT='web,node' --js-library cbits/js_random.js && cp cbits/sntrup761.d.mts dist/wasm/", + "build:ts": "tsc", + "build:test": "tsc -p tsconfig.test.json", + "build": "npm run build:wasm && npm run build:ts && npm run build:test" }, "dependencies": { + "@noble/ciphers": "^2.2.0", + "@noble/curves": "^2.2.0", "@noble/hashes": "^1.5.0", - "@simplex-chat/xftp-web": "file:../xftp-web" + "@simplex-chat/xftp-web": "file:../xftp-web", + "emsdk": "^0.4.0" }, "devDependencies": { "@types/node": "^25.5.0", diff --git a/smp-web/src/agent/message.ts b/smp-web/src/agent/message.ts new file mode 100644 index 000000000..edaf9cd2b --- /dev/null +++ b/smp-web/src/agent/message.ts @@ -0,0 +1,207 @@ +// Agent message encoding/decoding. +// Mirrors: Simplex.Messaging.Agent.Protocol (AgentMsgEnvelope, AgentMessage, APrivHeader, AMessage) + +import { + Decoder, concatBytes, + encodeBytes, decodeBytes, + encodeLarge, decodeLarge, + encodeInt64, decodeInt64, + encodeWord16, decodeWord16, + encodeMaybe, decodeMaybe, + encodeNonEmpty, decodeNonEmpty, +} from "@simplex-chat/xftp-web/dist/protocol/encoding.js" + +// -- Constants (Agent/Protocol.hs:318-319) + +export const currentSMPAgentVersion = 7 + +// -- AMessage (Agent/Protocol.hs:1001-1020) + +export type AMessage = + | {type: "HELLO"} + | {type: "A_MSG", body: Uint8Array} + | {type: "A_RCVD", receipts: AMessageReceipt[]} // NonEmpty + | {type: "EREADY", lastDecryptedMsgId: bigint} + +// Agent/Protocol.hs:1040-1045 +export interface AMessageReceipt { + agentMsgId: bigint // Int64 + msgHash: Uint8Array // ByteString (32-byte SHA-256) + rcptInfo: Uint8Array // MsgReceiptInfo (ByteString, Large-encoded) +} + +// Agent/Protocol.hs:1078-1100 +export function encodeAMessage(msg: AMessage): Uint8Array { + switch (msg.type) { + case "HELLO": return new Uint8Array([0x48]) // "H" + case "A_MSG": return concatBytes(new Uint8Array([0x4D]), msg.body) // "M" + Tail + case "A_RCVD": return concatBytes(new Uint8Array([0x56]), encodeNonEmpty(encodeAMessageReceipt, msg.receipts)) // "V" + NonEmpty + case "EREADY": return concatBytes(new Uint8Array([0x45]), encodeInt64(msg.lastDecryptedMsgId)) // "E" + Int64 + } +} + +export function decodeAMessage(d: Decoder): AMessage { + const tag = d.anyByte() + switch (tag) { + case 0x48: return {type: "HELLO"} // 'H' + case 0x4D: return {type: "A_MSG", body: d.takeAll()} // 'M' + Tail + case 0x56: return {type: "A_RCVD", receipts: decodeNonEmpty(decodeAMessageReceipt, d)} // 'V' + case 0x45: return {type: "EREADY", lastDecryptedMsgId: decodeInt64(d)} // 'E' + // Queue management tags (not needed for chat messages, but recognized for decoding) + case 0x51: { // 'Q' + const sub = d.anyByte() + switch (sub) { + case 0x43: // 'C' = A_QCONT + case 0x41: // 'A' = QADD + case 0x4B: // 'K' = QKEY + case 0x55: // 'U' = QUSE + case 0x54: // 'T' = QTEST + throw new Error("decodeAMessage: queue management message (Q" + String.fromCharCode(sub) + ") not implemented") + default: + throw new Error("decodeAMessage: unknown Q-subtag " + sub) + } + } + default: + throw new Error("decodeAMessage: unknown tag " + tag) + } +} + +// Agent/Protocol.hs:1106-1111 +function encodeAMessageReceipt(r: AMessageReceipt): Uint8Array { + return concatBytes(encodeInt64(r.agentMsgId), encodeBytes(r.msgHash), encodeLarge(r.rcptInfo)) +} + +function decodeAMessageReceipt(d: Decoder): AMessageReceipt { + return {agentMsgId: decodeInt64(d), msgHash: decodeBytes(d), rcptInfo: decodeLarge(d)} +} + +// -- APrivHeader (Agent/Protocol.hs:946-957) + +export interface APrivHeader { + sndMsgId: bigint // AgentMsgId = Int64 + prevMsgHash: Uint8Array // MsgHash = ByteString +} + +export function encodeAPrivHeader(h: APrivHeader): Uint8Array { + return concatBytes(encodeInt64(h.sndMsgId), encodeBytes(h.prevMsgHash)) +} + +export function decodeAPrivHeader(d: Decoder): APrivHeader { + return {sndMsgId: decodeInt64(d), prevMsgHash: decodeBytes(d)} +} + +// -- AgentMessage (Agent/Protocol.hs:866-888) + +export type AgentMessage = + | {type: "connInfo", cInfo: Uint8Array} + | {type: "connInfoReply", smpQueues: Uint8Array[], cInfo: Uint8Array} // NonEmpty raw-encoded SMPQueueInfo + | {type: "ratchetInfo", info: Uint8Array} + | {type: "message", header: APrivHeader, msg: AMessage} + +export function encodeAgentMessage(msg: AgentMessage): Uint8Array { + switch (msg.type) { + case "connInfo": + return concatBytes(new Uint8Array([0x49]), msg.cInfo) // 'I' + Tail + case "connInfoReply": + // 'D' + NonEmpty SMPQueueInfo + Tail cInfo + // SMPQueueInfo encoding is complex; for now encode the raw bytes + return concatBytes( + new Uint8Array([0x44]), + encodeNonEmpty(b => b, msg.smpQueues), + msg.cInfo, + ) + case "ratchetInfo": + return concatBytes(new Uint8Array([0x52]), msg.info) // 'R' + Tail + case "message": + return concatBytes(new Uint8Array([0x4D]), encodeAPrivHeader(msg.header), encodeAMessage(msg.msg)) // 'M' + header + msg + } +} + +export function decodeAgentMessage(d: Decoder): AgentMessage { + const tag = d.anyByte() + switch (tag) { + case 0x49: return {type: "connInfo", cInfo: d.takeAll()} // 'I' + Tail + case 0x44: { // 'D' + // NonEmpty SMPQueueInfo is complex to decode; skip for now, just capture raw + throw new Error("decodeAgentMessage: connInfoReply ('D') not implemented") + } + case 0x52: return {type: "ratchetInfo", info: d.takeAll()} // 'R' + Tail + case 0x4D: return {type: "message", header: decodeAPrivHeader(d), msg: decodeAMessage(d)} // 'M' + default: + throw new Error("decodeAgentMessage: unknown tag " + tag) + } +} + +// -- AgentMsgEnvelope (Agent/Protocol.hs:812-861) + +export type AgentMsgEnvelope = + | {type: "confirmation", agentVersion: number, e2eEncryption: Uint8Array | null, encConnInfo: Uint8Array} + | {type: "envelope", agentVersion: number, encAgentMessage: Uint8Array} + | {type: "invitation", agentVersion: number, connReqBytes: Uint8Array, connInfo: Uint8Array} + | {type: "ratchetKey", agentVersion: number, e2eEncryption: Uint8Array, info: Uint8Array} + +// Agent/Protocol.hs:835-843 +export function encodeAgentMsgEnvelope(env: AgentMsgEnvelope): Uint8Array { + switch (env.type) { + case "confirmation": + // (agentVersion, 'C', Maybe SndE2ERatchetParams, Tail encConnInfo) + return concatBytes( + encodeWord16(env.agentVersion), + new Uint8Array([0x43]), // 'C' + encodeMaybe(b => b, env.e2eEncryption), // e2eEncryption is already smpEncoded bytes or null + env.encConnInfo, // Tail + ) + case "envelope": + // (agentVersion, 'M', Tail encAgentMessage) + return concatBytes( + encodeWord16(env.agentVersion), + new Uint8Array([0x4D]), // 'M' + env.encAgentMessage, // Tail + ) + case "invitation": + // (agentVersion, 'I', Large connReqBytes, Tail connInfo) + return concatBytes( + encodeWord16(env.agentVersion), + new Uint8Array([0x49]), // 'I' + encodeLarge(env.connReqBytes), + env.connInfo, // Tail + ) + case "ratchetKey": + // (agentVersion, 'R', e2eEncryption, Tail info) + return concatBytes( + encodeWord16(env.agentVersion), + new Uint8Array([0x52]), // 'R' + env.e2eEncryption, // already smpEncoded + env.info, // Tail + ) + } +} + +// Agent/Protocol.hs:844-861 +export function decodeAgentMsgEnvelope(d: Decoder): AgentMsgEnvelope { + const agentVersion = decodeWord16(d) + const tag = d.anyByte() + switch (tag) { + case 0x43: // 'C' Confirmation + // e2eEncryption_ is Maybe (SndE2ERatchetParams 'X448), encConnInfo is Tail + // Full parsing of E2ERatchetParams needed to split the boundary — not implemented in spike + throw new Error("decodeAgentMsgEnvelope: confirmation ('C') not fully implemented") + case 0x4D: // 'M' Message envelope + return {type: "envelope", agentVersion, encAgentMessage: d.takeAll()} // Tail + case 0x49: { // 'I' Invitation + const connReqBytes = decodeLarge(d) + const connInfo = d.takeAll() // Tail + return {type: "invitation", agentVersion, connReqBytes, connInfo} + } + case 0x52: { // 'R' RatchetKey + // e2eEncryption is an E2ERatchetParams — variable-length, not Tail + // For now, capture remaining minus nothing (since info is Tail and comes last) + // This is tricky: e2eEncryption is smpEncoded E2ERatchetParams, info is Tail + // We can't easily split without knowing the E2ERatchetParams length + // For the spike, just capture all remaining as raw + throw new Error("decodeAgentMsgEnvelope: ratchetKey ('R') not fully implemented") + } + default: + throw new Error("decodeAgentMsgEnvelope: unknown tag " + tag) + } +} diff --git a/smp-web/src/crypto.ts b/smp-web/src/crypto.ts index 5e84fa0c2..3e074f7ed 100644 --- a/smp-web/src/crypto.ts +++ b/smp-web/src/crypto.ts @@ -3,7 +3,10 @@ import {hkdf as nobleHkdf} from "@noble/hashes/hkdf" import {sha512} from "@noble/hashes/sha512" +import {gcm} from "@noble/ciphers/aes.js" import {cbEncrypt, cbDecrypt} from "@simplex-chat/xftp-web/dist/crypto/secretbox.js" +import {concatBytes} from "@simplex-chat/xftp-web/dist/protocol/encoding.js" +import {pad, unPad} from "@simplex-chat/xftp-web/dist/crypto/padding.js" // C.hkdf (Crypto.hs:1461-1464) // HKDF-SHA512 extract + expand @@ -48,3 +51,39 @@ export function sbDecryptBlock(chainKey: Uint8Array, block: Uint8Array): {decryp const {keyNonce: {sbKey, nonce}, nextChainKey} = sbcHkdf(chainKey) return {decrypted: cbDecrypt(sbKey, nonce, block), nextChainKey} } + +// -- AES-256-GCM authenticated encryption (Crypto.hs:1035-1061) +// Uses 16-byte IVs (GCM with GHASH path per NIST SP 800-38D for IVs != 96 bits) + +export const AUTH_TAG_SIZE = 16 + +// encryptAEAD (Crypto.hs:1035-1039) +export function encryptAEAD( + key: Uint8Array, // 32 bytes + iv: Uint8Array, // 16 bytes + paddedLen: number, + ad: Uint8Array, + plaintext: Uint8Array, +): {authTag: Uint8Array; ciphertext: Uint8Array} { + const padded = pad(plaintext, paddedLen) + const cipher = gcm(key, iv, ad) + const encrypted = cipher.encrypt(padded) + return { + ciphertext: encrypted.subarray(0, encrypted.length - AUTH_TAG_SIZE), + authTag: encrypted.subarray(encrypted.length - AUTH_TAG_SIZE), + } +} + +// decryptAEAD (Crypto.hs:1058-1061) +export function decryptAEAD( + key: Uint8Array, + iv: Uint8Array, + ad: Uint8Array, + ciphertext: Uint8Array, + authTag: Uint8Array, +): Uint8Array { + const cipher = gcm(key, iv, ad) + const encrypted = concatBytes(ciphertext, authTag) + const padded = cipher.decrypt(encrypted) + return unPad(padded) +} diff --git a/smp-web/src/crypto/ratchet.ts b/smp-web/src/crypto/ratchet.ts new file mode 100644 index 000000000..80de7456c --- /dev/null +++ b/smp-web/src/crypto/ratchet.ts @@ -0,0 +1,749 @@ +// Double ratchet with X3DH key agreement and PQ KEM. +// Faithful transpilation of Simplex.Messaging.Crypto.Ratchet +// +// Every type, field, and function mirrors the Haskell source. +// Line references are to src/Simplex/Messaging/Crypto/Ratchet.hs + +import {x448} from "@noble/curves/ed448.js" +import {hkdf, encryptAEAD, decryptAEAD, AUTH_TAG_SIZE} from "../crypto.js" +import {sntrup761Keypair, sntrup761Enc, sntrup761Dec} from "./sntrup761.js" +import type {KEMKeyPair} from "./sntrup761.js" +import { + Decoder, concatBytes, + encodeBytes, decodeBytes, decodeWord16, decodeWord32, + encodeLarge, decodeLarge, + encodeMaybe, +} from "@simplex-chat/xftp-web/dist/protocol/encoding.js" + +// -- Version constants (lines 134-155) + +export const pqRatchetE2EEncryptVersion = 3 +export const currentE2EEncryptVersion = 3 + +// -- X448 key operations + +export interface X448KeyPair { + publicKey: Uint8Array // 56 bytes + privateKey: Uint8Array // 56 bytes +} + +export function generateX448KeyPair(): X448KeyPair { + const privateKey = x448.utils.randomSecretKey() + const publicKey = x448.getPublicKey(privateKey) + return {publicKey, privateKey} +} + +export function x448DH(publicKey: Uint8Array, privateKey: Uint8Array): Uint8Array { + return x448.getSharedSecret(privateKey, publicKey) +} + +// DER encoding for X448 public keys (RFC 8410, SubjectPublicKeyInfo) +// SEQUENCE { SEQUENCE { OID 1.3.101.111 } BIT STRING { 0x00 <56 bytes> } } +const X448_PUBKEY_DER_PREFIX = new Uint8Array([ + 0x30, 0x42, 0x30, 0x05, 0x06, 0x03, 0x2b, 0x65, 0x6f, 0x03, 0x39, 0x00, +]) + +export function encodePubKeyX448(rawPubKey: Uint8Array): Uint8Array { + return concatBytes(X448_PUBKEY_DER_PREFIX, rawPubKey) +} + +export function decodePubKeyX448(der: Uint8Array): Uint8Array { + if (der.length !== 68) throw new Error("decodePubKeyX448: invalid length " + der.length) + for (let i = 0; i < X448_PUBKEY_DER_PREFIX.length; i++) { + if (der[i] !== X448_PUBKEY_DER_PREFIX[i]) throw new Error("decodePubKeyX448: invalid DER prefix") + } + return der.subarray(12) +} + +// -- KEM types (lines 567-577) +// KEMKeyPair imported from ./sntrup761.js + +export interface RatchetKEMAccepted { + rcPQRr: Uint8Array // KEMPublicKey - received key (1158 bytes) + rcPQRss: Uint8Array // KEMSharedKey - computed shared secret (32 bytes) + rcPQRct: Uint8Array // KEMCiphertext - sent encaps (1039 bytes) +} + +export interface RatchetKEM { + rcPQRs: KEMKeyPair + rcKEMs: RatchetKEMAccepted | null +} + +// -- RatchetInitParams (lines 457-464) + +export interface RatchetInitParams { + assocData: Uint8Array // Str (raw bytes) + ratchetKey: Uint8Array // RatchetKey (32 bytes) + sndHK: Uint8Array // HeaderKey (32 bytes) + rcvNextHK: Uint8Array // HeaderKey (32 bytes) + kemAccepted: RatchetKEMAccepted | null // Maybe RatchetKEMAccepted +} + +// -- hkdf3 (lines 1174-1179) + +function hkdf3(salt: Uint8Array, ikm: Uint8Array, info: string): [Uint8Array, Uint8Array, Uint8Array] { + const out = hkdf(salt, ikm, info, 96) + return [out.slice(0, 32), out.slice(32, 64), out.slice(64, 96)] +} + +// -- pqX3dh (lines 499-508) + +const X3DH_SALT = new Uint8Array(64) + +function pqX3dh( + sk1: Uint8Array, rk1: Uint8Array, + dh1: Uint8Array, dh2: Uint8Array, dh3: Uint8Array, + kemAccepted: RatchetKEMAccepted | null, +): RatchetInitParams { + const assocData = concatBytes(sk1, rk1) + const pq = kemAccepted ? kemAccepted.rcPQRss : new Uint8Array(0) + const dhs = concatBytes(dh1, dh2, dh3, pq) + const [hk, nhk, sk] = hkdf3(X3DH_SALT, dhs, "SimpleXX3DH") + return {assocData, ratchetKey: sk, sndHK: hk, rcvNextHK: nhk, kemAccepted} +} + +// -- pqX3dhSnd (lines 467-480) +// Used by joiner (Alice in PQDR spec, Bob in DR spec) to init SENDING ratchet. + +export function pqX3dhSnd( + spk1: Uint8Array, spk2: Uint8Array, // our private keys + rk1: Uint8Array, rk2: Uint8Array, // their public keys (raw) + kemAccepted: RatchetKEMAccepted | null = null, +): RatchetInitParams { + const sk1Pub = x448.getPublicKey(spk1) + const dh1 = x448DH(rk1, spk2) + const dh2 = x448DH(rk2, spk1) + const dh3 = x448DH(rk2, spk2) + return pqX3dh(sk1Pub, rk1, dh1, dh2, dh3, kemAccepted) +} + +// -- pqX3dhRcv (lines 483-497) +// Used by initiator (Bob in PQDR spec, Alice in DR spec) to init RECEIVING ratchet. + +export function pqX3dhRcv( + rpk1: Uint8Array, rpk2: Uint8Array, // our private keys + sk1: Uint8Array, sk2: Uint8Array, // their public keys (raw) + kemAccepted: RatchetKEMAccepted | null = null, +): RatchetInitParams { + const rk1Pub = x448.getPublicKey(rpk1) + const dh1 = x448DH(sk2, rpk1) + const dh2 = x448DH(sk1, rpk2) + const dh3 = x448DH(sk2, rpk2) + return pqX3dh(sk1, rk1Pub, dh1, dh2, dh3, kemAccepted) +} + +// -- rootKdf (lines 1159-1166) + +export function rootKdf( + rk: Uint8Array, // RatchetKey (32 bytes) + peerPubKey: Uint8Array, // PublicKey a (raw, 56 bytes for X448) + ownPrivKey: Uint8Array, // PrivateKey a (raw, 56 bytes for X448) + kemSecret: Uint8Array | null, // Maybe KEMSharedKey +): {rk: Uint8Array; ck: Uint8Array; nhk: Uint8Array} { + const dhOut = x448DH(peerPubKey, ownPrivKey) + const ss = kemSecret ? concatBytes(dhOut, kemSecret) : dhOut + const [rk_, ck, nhk] = hkdf3(rk, ss, "SimpleXRootRatchet") + return {rk: rk_, ck, nhk} +} + +// -- chainKdf (lines 1168-1172) + +export function chainKdf(ck: Uint8Array): {ck: Uint8Array; mk: Uint8Array; iv: Uint8Array; ehIV: Uint8Array} { + const EMPTY = new Uint8Array(0) + const [ck_, mk, ivs] = hkdf3(EMPTY, ck, "SimpleXChainRatchet") + return {ck: ck_, mk, iv: ivs.slice(0, 16), ehIV: ivs.slice(16, 32)} +} + +// -- Header padding (lines 716-719) + +export function paddedHeaderLen(v: number, pqSupport: boolean): number { + if (pqSupport && v >= pqRatchetE2EEncryptVersion) return 2310 + return 88 +} + +// -- SndRatchet (lines 554-559) + +export interface SndRatchet { + rcDHRr: Uint8Array // peer's public key (raw, 56 bytes) + rcCKs: Uint8Array // sending chain key (32 bytes) + rcHKs: Uint8Array // sending header key (32 bytes) +} + +// -- RcvRatchet (lines 561-565) + +export interface RcvRatchet { + rcCKr: Uint8Array // receiving chain key (32 bytes) + rcHKr: Uint8Array // receiving header key (32 bytes) +} + +// -- MessageKey (lines 608-609) + +export interface MessageKey { + mk: Uint8Array // Key (32 bytes) + iv: Uint8Array // IV (16 bytes) +} + +// -- RatchetVersions (lines 534-538) + +export interface RatchetVersions { + current: number + maxSupported: number +} + +// -- Ratchet (lines 512-532) + +export interface Ratchet { + rcVersion: RatchetVersions + rcAD: Uint8Array // Str (associated data, raw bytes) + rcDHRs: Uint8Array // PrivateKey a (raw, 56 bytes) + rcKEM: RatchetKEM | null + rcSupportKEM: boolean // PQSupport + rcEnableKEM: boolean // PQEncryption + rcSndKEM: boolean // PQEncryption + rcRcvKEM: boolean // PQEncryption + rcRK: Uint8Array // RatchetKey (32 bytes) + rcSnd: SndRatchet | null + rcRcv: RcvRatchet | null + rcNs: number // Word32 + rcNr: number // Word32 + rcPN: number // Word32 + rcNHKs: Uint8Array // HeaderKey (32 bytes) + rcNHKr: Uint8Array // HeaderKey (32 bytes) +} + +// -- SkippedMsgKeys (lines 580-582) + +export type SkippedMsgKeys = Map> + +const MAX_SKIP = 512 + +function hexKey(k: Uint8Array): string { + return Array.from(k, b => b.toString(16).padStart(2, "0")).join("") +} + +function hexToBytes(hex: string): Uint8Array { + const bytes = new Uint8Array(hex.length / 2) + for (let i = 0; i < hex.length; i += 2) bytes[i / 2] = parseInt(hex.substring(i, i + 2), 16) + return bytes +} + +// -- initSndRatchet (lines 643-666) + +export function initSndRatchet( + rcVersion: RatchetVersions, + rcDHRr: Uint8Array, // peer's public key (raw) + rcDHRs: Uint8Array, // our private key (raw) + initParams: RatchetInitParams, + rcPQRs_: KEMKeyPair | null, +): Ratchet { + const {assocData, ratchetKey, sndHK, rcvNextHK, kemAccepted} = initParams + // state.RK, state.CKs, state.NHKs = KDF_RK_HE(SK, DH(state.DHRs, state.DHRr) || state.PQRss) + const kemSecret = kemAccepted ? kemAccepted.rcPQRss : null + const {rk: rcRK, ck: rcCKs, nhk: rcNHKs} = rootKdf(ratchetKey, rcDHRr, rcDHRs, kemSecret) + const pqOn = rcPQRs_ !== null + return { + rcVersion, + rcAD: assocData, + rcDHRs, + rcKEM: rcPQRs_ ? {rcPQRs: rcPQRs_, rcKEMs: kemAccepted} : null, + rcSupportKEM: pqOn, + rcEnableKEM: pqOn, + rcSndKEM: kemAccepted !== null, + rcRcvKEM: false, + rcRK, + rcSnd: {rcDHRr, rcCKs, rcHKs: sndHK}, + rcRcv: null, + rcPN: 0, + rcNs: 0, + rcNr: 0, + rcNHKs, + rcNHKr: rcvNextHK, + } +} + +// -- initRcvRatchet (lines 674-699) + +export function initRcvRatchet( + rcVersion: RatchetVersions, + rcDHRs: Uint8Array, // our private key (raw) + initParams: RatchetInitParams, + rcPQRs_: KEMKeyPair | null, + pqSupport: boolean, +): Ratchet { + const {assocData, ratchetKey, sndHK, rcvNextHK, kemAccepted} = initParams + return { + rcVersion, + rcAD: assocData, + rcDHRs, + rcKEM: rcPQRs_ ? {rcPQRs: rcPQRs_, rcKEMs: kemAccepted} : null, + rcSupportKEM: pqSupport, + rcEnableKEM: pqSupport, + rcSndKEM: false, + rcRcvKEM: false, + rcRK: ratchetKey, + rcSnd: null, + rcRcv: null, + rcPN: 0, + rcNs: 0, + rcNr: 0, + rcNHKs: rcvNextHK, + rcNHKr: sndHK, + } +} + +// -- RKEMParams (lines 188-190) - parsed KEM params from message header + +export type RKEMParams = + | {type: "proposed", kemPk: Uint8Array} // RKParamsProposed KEMPublicKey + | {type: "accepted", kemCt: Uint8Array, kemPk: Uint8Array} // RKParamsAccepted KEMCiphertext KEMPublicKey + +// -- MsgHeader (lines 703-711) + +interface MsgHeader { + msgMaxVersion: number + msgDHRs: Uint8Array // PublicKey a (raw, 56 bytes) + msgKEM: RKEMParams | null + msgPN: number // Word32 + msgNs: number // Word32 +} + +// -- encodeMsgHeader (lines 727-730) + +function encodeMsgHeader(v: number, hdr: MsgHeader): Uint8Array { + const vBytes = new Uint8Array(2) + vBytes[0] = (hdr.msgMaxVersion >> 8) & 0xff + vBytes[1] = hdr.msgMaxVersion & 0xff + const dhDer = encodePubKeyX448(hdr.msgDHRs) + const pn = encodeWord32(hdr.msgPN) + const ns = encodeWord32(hdr.msgNs) + if (v >= pqRatchetE2EEncryptVersion) { + // smpEncode (msgMaxVersion, msgDHRs, msgKEM, msgPN, msgNs) + // msgKEM :: Maybe ARKEMParams + const kemBytes = hdr.msgKEM ? encodeRKEMParams(hdr.msgKEM) : new Uint8Array([0x30]) // Nothing + return concatBytes(vBytes, encodeBytes(dhDer), kemBytes, pn, ns) + } + // smpEncode (msgMaxVersion, msgDHRs, msgPN, msgNs) + return concatBytes(vBytes, encodeBytes(dhDer), pn, ns) +} + +// Encode Maybe ARKEMParams: '1' + encoded params, or nothing (handled at call site with '0') +function encodeRKEMParams(params: RKEMParams): Uint8Array { + if (params.type === "proposed") { + // Just ('P', kemPk) - smpEncode ('P', k) where k is KEMPublicKey (Large) + return concatBytes(new Uint8Array([0x31, 0x50]), encodeLarge(params.kemPk)) + } + // Just ('A', ct, kemPk) - smpEncode ('A', ct, k) + return concatBytes(new Uint8Array([0x31, 0x41]), encodeLarge(params.kemCt), encodeLarge(params.kemPk)) +} + +function encodeWord32(n: number): Uint8Array { + const buf = new Uint8Array(4) + buf[0] = (n >> 24) & 0xff; buf[1] = (n >> 16) & 0xff + buf[2] = (n >> 8) & 0xff; buf[3] = n & 0xff + return buf +} + +// -- msgHeaderP (lines 733-740) + +function decodeMsgHeader(v: number, data: Uint8Array): MsgHeader { + const d = new Decoder(data) + const msgMaxVersion = decodeWord16(d) + const dhDer = decodeBytes(d) + const msgDHRs = decodePubKeyX448(dhDer) + let msgKEM: RKEMParams | null = null + if (v >= pqRatchetE2EEncryptVersion) { + // Maybe ARKEMParams + const maybeByte = d.anyByte() + if (maybeByte === 0x31) { + // Just - parse ARKEMParams + const tag = d.anyByte() + if (tag === 0x50) { // 'P' - Proposed: KEMPublicKey (Large) + msgKEM = {type: "proposed", kemPk: decodeLarge(d)} + } else if (tag === 0x41) { // 'A' - Accepted: KEMCiphertext (Large) + KEMPublicKey (Large) + const kemCt = decodeLarge(d) + const kemPk = decodeLarge(d) + msgKEM = {type: "accepted", kemCt, kemPk} + } else { + throw new Error("decodeMsgHeader: unknown KEM tag " + tag) + } + } + // else '0' = Nothing, msgKEM stays null + } + const msgPN = decodeWord32(d) + const msgNs = decodeWord32(d) + return {msgMaxVersion, msgDHRs, msgKEM, msgPN, msgNs} +} + +// -- EncMessageHeader (lines 742-756) + +interface EncMessageHeader { + ehVersion: number // current ratchet version + ehIV: Uint8Array // IV (raw 16 bytes) + ehAuthTag: Uint8Array // AuthTag (raw 16 bytes) + ehBody: Uint8Array // encrypted header body +} + +// smpEncode (lines 751-752) +function encodeEncMessageHeader(emh: EncMessageHeader): Uint8Array { + const vBytes = new Uint8Array(2) + vBytes[0] = (emh.ehVersion >> 8) & 0xff + vBytes[1] = emh.ehVersion & 0xff + // smpEncode (ehVersion, ehIV, ehAuthTag) <> encodeLarge ehVersion ehBody + const bodyEnc = emh.ehVersion >= pqRatchetE2EEncryptVersion + ? encodeLarge(emh.ehBody) + : encodeBytes(emh.ehBody) + return concatBytes(vBytes, emh.ehIV, emh.ehAuthTag, bodyEnc) +} + +// smpP (lines 753-756) +function decodeEncMessageHeader(data: Uint8Array): EncMessageHeader { + const d = new Decoder(data) + const ehVersion = decodeWord16(d) + const ehIV = d.take(16) // IV is raw 16 bytes + const ehAuthTag = d.take(16) // AuthTag is raw 16 bytes + // largeP: peek first byte, if < 32 then Large (2-byte len), else ByteString (1-byte len) + const firstByte = data[d.offset()] + const ehBody = firstByte < 32 ? decodeLarge(d) : decodeBytes(d) + return {ehVersion, ehIV, ehAuthTag, ehBody} +} + +// -- EncRatchetMessage (lines 772-787) + +interface EncRatchetMessage { + emHeader: Uint8Array // smpEncoded EncMessageHeader + emAuthTag: Uint8Array // AuthTag (raw 16 bytes) + emBody: Uint8Array // encrypted message body +} + +// encodeEncRatchetMessage (lines 779-781) +function encodeEncRatchetMessage(v: number, msg: EncRatchetMessage): Uint8Array { + // encodeLarge v emHeader <> smpEncode (emAuthTag, Tail emBody) + const headerEnc = v >= pqRatchetE2EEncryptVersion + ? encodeLarge(msg.emHeader) + : encodeBytes(msg.emHeader) + return concatBytes(headerEnc, msg.emAuthTag, msg.emBody) +} + +// encRatchetMessageP (lines 783-787) +function decodeEncRatchetMessage(data: Uint8Array): EncRatchetMessage { + const d = new Decoder(data) + // largeP + const firstByte = data[d.offset()] + const emHeader = firstByte < 32 ? decodeLarge(d) : decodeBytes(d) + // smpEncode (emAuthTag, Tail emBody) → raw 16 bytes + rest + const emAuthTag = d.take(16) + const emBody = d.takeAll() + return {emHeader, emAuthTag, emBody} +} + +// -- MsgEncryptKey (lines 962-968) + +interface MsgEncryptKey { + msgRcVersion: number + msgKey: MessageKey + msgRcAD: Uint8Array + msgEncHeader: Uint8Array +} + +// -- msgKEMParams (lines 956-958) - build KEM params from ratchet state for message header + +function msgKEMParams(kem: RatchetKEM): RKEMParams { + const {rcPQRs, rcKEMs} = kem + if (!rcKEMs) { + return {type: "proposed", kemPk: rcPQRs.publicKey} + } + return {type: "accepted", kemCt: rcKEMs.rcPQRct, kemPk: rcPQRs.publicKey} +} + +// -- pqEnableSupport (line 836-837) + +function pqEnableSupport(v: number, sup: boolean, enc: boolean): boolean { + return sup || (v >= pqRatchetE2EEncryptVersion && enc) +} + +// -- rcEncryptHeader + rcEncryptMsg (lines 902-975) + +export interface EncryptResult { + ciphertext: Uint8Array + state: Ratchet +} + +export function rcEncrypt( + rc: Ratchet, + plaintext: Uint8Array, + paddedMsgLen: number, +): EncryptResult { + if (!rc.rcSnd) throw new Error("rcEncrypt: no sending ratchet (CERatchetState)") + const snd = rc.rcSnd + const v = rc.rcVersion.current + + // state.CKs, mk = KDF_CK(state.CKs) + const chain = chainKdf(snd.rcCKs) + + // header + const headerPlain = encodeMsgHeader(v, { + msgMaxVersion: rc.rcVersion.maxSupported, + msgDHRs: x448.getPublicKey(rc.rcDHRs), + msgKEM: rc.rcKEM ? msgKEMParams(rc.rcKEM) : null, + msgPN: rc.rcPN, + msgNs: rc.rcNs, + }) + + // enc_header = HENCRYPT(state.HKs, header) + const phl = paddedHeaderLen(v, rc.rcSupportKEM) + const {authTag: ehAuthTag, ciphertext: ehBody} = encryptAEAD(snd.rcHKs, chain.ehIV, phl, rc.rcAD, headerPlain) + + // smpEncode EncMessageHeader + const emHeader = encodeEncMessageHeader({ehVersion: v, ehBody, ehAuthTag, ehIV: chain.ehIV}) + + // ENCRYPT(mk, plaintext, CONCAT(AD, enc_header)) + const bodyAD = concatBytes(rc.rcAD, emHeader) + const {authTag: emAuthTag, ciphertext: emBody} = encryptAEAD(chain.mk, chain.iv, paddedMsgLen, bodyAD, plaintext) + + // encodeEncRatchetMessage + const ciphertext = encodeEncRatchetMessage(v, {emHeader, emBody, emAuthTag}) + + // Update state + const newState: Ratchet = { + ...rc, + rcSnd: {...snd, rcCKs: chain.ck}, + rcNs: rc.rcNs + 1, + } + + return {ciphertext, state: newState} +} + +// -- rcDecrypt (lines 990-1157) + +export interface DecryptResult { + plaintext: Uint8Array + state: Ratchet + skippedKeys: SkippedMsgKeys +} + +export function rcDecrypt( + rc: Ratchet, + skippedKeys: SkippedMsgKeys, + ciphertext: Uint8Array, +): DecryptResult { + const encMsg = decodeEncRatchetMessage(ciphertext) + const encHdr = decodeEncMessageHeader(encMsg.emHeader) + + // TrySkippedMessageKeysHE + const skipped = tryDecryptSkipped(rc, skippedKeys, encHdr, encMsg) + if (skipped) return skipped + + // DecryptHeader + let ratchetStep: "same" | "advance" = "advance" + let hdr: MsgHeader | null = null + + if (rc.rcRcv) { + hdr = tryDecryptHeader(rc.rcRcv.rcHKr, rc.rcAD, encHdr) + if (hdr) ratchetStep = "same" + } + if (!hdr) { + hdr = tryDecryptHeader(rc.rcNHKr, rc.rcAD, encHdr) + if (!hdr) throw new Error("rcDecrypt: header decryption failed (CERatchetHeader)") + ratchetStep = "advance" + } + + // Version upgrade + let state = rc + const {current, maxSupported} = rc.rcVersion + if (hdr.msgMaxVersion > current) { + state = {...state, rcVersion: {...state.rcVersion, current: Math.max(current, Math.min(hdr.msgMaxVersion, maxSupported))}} + } + + let newSkipped = new Map(skippedKeys) + + if (ratchetStep === "advance") { + // SkipMessageKeysHE(state, header.pn) + const skip1 = skipMessageKeys(state, newSkipped, hdr.msgPN) + state = skip1.state; newSkipped = skip1.skippedKeys + + // DHRatchetPQ2HE(state, header) - ratchet step (lines 1043-1071) + const {kemSS, kemSS2, rcKEM: rcKEM_} = pqRatchetStep(state, hdr.msgKEM) + const newDHRs = generateX448KeyPair() + // state.RK, state.CKr, state.NHKr = KDF_RK_HE(state.RK, DH(state.DHRs, state.DHRr) || ss) + const kdf1 = rootKdf(state.rcRK, hdr.msgDHRs, state.rcDHRs, kemSS) + // state.RK, state.CKs, state.NHKs = KDF_RK_HE(state.RK, DH(state.DHRs', state.DHRr) || state.PQRss) + const kdf2 = rootKdf(kdf1.rk, hdr.msgDHRs, newDHRs.privateKey, kemSS2) + const sndKEM = kemSS2 !== null + const rcvKEM = kemSS !== null + const rcEnableKEM_ = sndKEM || rcvKEM || rcKEM_ !== null + + state = { + ...state, + rcDHRs: newDHRs.privateKey, + rcKEM: rcKEM_, + rcSupportKEM: pqEnableSupport(state.rcVersion.current, state.rcSupportKEM, rcEnableKEM_), + rcEnableKEM: rcEnableKEM_, + rcSndKEM: sndKEM, + rcRcvKEM: rcvKEM, + rcRK: kdf2.rk, + rcSnd: {rcDHRr: hdr.msgDHRs, rcCKs: kdf2.ck, rcHKs: state.rcNHKs}, + rcRcv: {rcCKr: kdf1.ck, rcHKr: state.rcNHKr}, + rcPN: rc.rcNs, + rcNs: 0, + rcNr: 0, + rcNHKs: kdf2.nhk, + rcNHKr: kdf1.nhk, + } + } + + // SkipMessageKeysHE(state, header.n) + const skip2 = skipMessageKeys(state, newSkipped, hdr.msgNs) + state = skip2.state; newSkipped = skip2.skippedKeys + + if (!state.rcRcv) throw new Error("rcDecrypt: no receiving ratchet after skip") + + // state.CKr, mk = KDF_CK(state.CKr) + const chain = chainKdf(state.rcRcv.rcCKr) + + // DECRYPT(mk, cipher-text, CONCAT(AD, enc_header)) + const bodyAD = concatBytes(state.rcAD, encMsg.emHeader) + const plaintext = decryptAEAD(chain.mk, chain.iv, bodyAD, encMsg.emBody, encMsg.emAuthTag) + + // state.Nr += 1 + state = { + ...state, + rcRcv: {...state.rcRcv, rcCKr: chain.ck}, + rcNr: state.rcNr + 1, + } + + return {plaintext, state, skippedKeys: newSkipped} +} + +// -- skipMessageKeys (lines 1105-1121) + +function skipMessageKeys( + rc: Ratchet, + skippedKeys: SkippedMsgKeys, + untilN: number, +): {state: Ratchet; skippedKeys: SkippedMsgKeys} { + if (!rc.rcRcv) return {state: rc, skippedKeys} + const rcv = rc.rcRcv + const rcNr = rc.rcNr + + if (rcNr > untilN + 1) throw new Error("rcDecrypt: earlier message (CERatchetEarlierMessage)") + if (rcNr === untilN + 1) throw new Error("rcDecrypt: duplicate message (CERatchetDuplicateMessage)") + if (rcNr + MAX_SKIP < untilN) throw new Error("rcDecrypt: too many skipped (CERatchetTooManySkipped)") + if (rcNr === untilN) return {state: rc, skippedKeys} + + // advanceRcvRatchet + let ck = rcv.rcCKr + let nr = rcNr + const hkHex = hexKey(rcv.rcHKr) + const msgKeys = new Map(skippedKeys.get(hkHex) || new Map()) + + while (nr < untilN) { + const chain = chainKdf(ck) + msgKeys.set(nr, {mk: chain.mk, iv: chain.iv}) + ck = chain.ck + nr++ + } + + const newSkipped = new Map(skippedKeys) + newSkipped.set(hkHex, msgKeys) + + return { + state: {...rc, rcRcv: {...rcv, rcCKr: ck}, rcNr: nr}, + skippedKeys: newSkipped, + } +} + +// -- tryDecryptSkipped (lines 1122-1141) + +function tryDecryptSkipped( + rc: Ratchet, + skippedKeys: SkippedMsgKeys, + encHdr: EncMessageHeader, + encMsg: EncRatchetMessage, +): DecryptResult | null { + for (const [hkHex, msgKeys] of skippedKeys) { + const hk = hexToBytes(hkHex) + const hdr = tryDecryptHeader(hk, rc.rcAD, encHdr) + if (hdr) { + const mk = msgKeys.get(hdr.msgNs) + if (mk) { + const bodyAD = concatBytes(rc.rcAD, encMsg.emHeader) + const plaintext = decryptAEAD(mk.mk, mk.iv, bodyAD, encMsg.emBody, encMsg.emAuthTag) + const newMsgKeys = new Map(msgKeys) + newMsgKeys.delete(hdr.msgNs) + const newSkipped = new Map(skippedKeys) + if (newMsgKeys.size === 0) newSkipped.delete(hkHex) + else newSkipped.set(hkHex, newMsgKeys) + return {plaintext, state: rc, skippedKeys: newSkipped} + } + // Header decrypted but msgNs not in skipped keys - check if same/advance ratchet + // For now, fall through to normal decrypt + } + } + return null +} + +// -- pqRatchetStep (lines 1072-1104) +// Returns (kemSS for receive rootKdf, kemSS' for send rootKdf, new RatchetKEM state) + +function pqRatchetStep( + rc: Ratchet, + msgKEM: RKEMParams | null, +): {kemSS: Uint8Array | null; kemSS2: Uint8Array | null; rcKEM: RatchetKEM | null} { + const pqEnc = rc.rcEnableKEM + const v = rc.rcVersion.current + + if (!msgKEM) { + // Received message does not have KEM in header + if (!rc.rcKEM && pqEnc && v >= pqRatchetE2EEncryptVersion) { + // User enabled KEM but no KEM state yet - generate new keypair + const rcPQRs = sntrup761Keypair() + return {kemSS: null, kemSS2: null, rcKEM: {rcPQRs, rcKEMs: null}} + } + return {kemSS: null, kemSS2: null, rcKEM: null} + } + + // Received message has KEM in header + if (pqEnc && v >= pqRatchetE2EEncryptVersion) { + // Get shared secret from received KEM params + const {ss, rcPQRr} = kemSharedSecret(rc.rcKEM, msgKEM) + // state.PQRct = PQKEM-ENC(state.PQRr, state.PQRss) + const kemEncResult = sntrup761Enc(rcPQRr) + // state.PQRs = GENERATE_PQKEM() + const rcPQRs = sntrup761Keypair() + const kem: RatchetKEM = { + rcPQRs, + rcKEMs: {rcPQRr, rcPQRss: kemEncResult.sharedSecret, rcPQRct: kemEncResult.ciphertext}, + } + return {kemSS: ss, kemSS2: kemEncResult.sharedSecret, rcKEM: kem} + } + + // PQ not enabled but message has KEM - extract shared secret only (no new KEM state) + const {ss} = kemSharedSecret(rc.rcKEM, msgKEM) + return {kemSS: ss, kemSS2: null, rcKEM: null} +} + +// Extract shared secret from received KEM params (lines 1097-1104) +function kemSharedSecret( + rcKEM: RatchetKEM | null, + params: RKEMParams, +): {ss: Uint8Array | null; rcPQRr: Uint8Array} { + if (params.type === "proposed") { + // RKParamsProposed k -> no shared secret yet, just received the public key + return {ss: null, rcPQRr: params.kemPk} + } + // RKParamsAccepted ct k -> decapsulate ct with our private KEM key + if (!rcKEM) throw new Error("pqRatchetStep: CERatchetKEMState - no KEM state for accepted params") + const ss = sntrup761Dec(params.kemCt, rcKEM.rcPQRs.secretKey) + return {ss, rcPQRr: params.kemPk} +} + +// -- decryptHeader helper (lines 1151-1153) + +function tryDecryptHeader(headerKey: Uint8Array, ad: Uint8Array, encHdr: EncMessageHeader): MsgHeader | null { + try { + const plainHeader = decryptAEAD(headerKey, encHdr.ehIV, ad, encHdr.ehBody, encHdr.ehAuthTag) + return decodeMsgHeader(encHdr.ehVersion, plainHeader) + } catch { + return null + } +} diff --git a/smp-web/src/crypto/sntrup761.ts b/smp-web/src/crypto/sntrup761.ts new file mode 100644 index 000000000..8914607ef --- /dev/null +++ b/smp-web/src/crypto/sntrup761.ts @@ -0,0 +1,89 @@ +// SNTRUP761 post-quantum KEM. +// Mirrors: Simplex.Messaging.Crypto.SNTRUP761 +// +// Uses WASM compiled from the same C source as the Haskell build +// (cbits/sntrup761.c by djb et al., public domain). +// SHA-512 from SUPERCOP/NaCl (djb, public domain). + +// Key sizes (from sntrup761.h) +export const SNTRUP761_PUBLICKEY_SIZE = 1158 +export const SNTRUP761_SECRETKEY_SIZE = 1763 +export const SNTRUP761_CIPHERTEXT_SIZE = 1039 +export const SNTRUP761_SIZE = 32 // shared secret + +export interface KEMKeyPair { + publicKey: Uint8Array // 1158 bytes + secretKey: Uint8Array // 1763 bytes +} + +export interface KEMEncResult { + ciphertext: Uint8Array // 1039 bytes + sharedSecret: Uint8Array // 32 bytes +} + +// WASM module instance +let wasmModule: any = null + +export async function initSntrup761(): Promise { + if (wasmModule) return + const createSntrup761 = (await import("../../dist/wasm/sntrup761.mjs")).default + wasmModule = await createSntrup761() +} + +function getModule(): any { + if (!wasmModule) throw new Error("sntrup761 WASM not initialized - call initSntrup761() first") + return wasmModule +} + +export function sntrup761Keypair(): KEMKeyPair { + const m = getModule() + const pkPtr = m._malloc(SNTRUP761_PUBLICKEY_SIZE) + const skPtr = m._malloc(SNTRUP761_SECRETKEY_SIZE) + try { + m._sntrup761_wasm_keypair(pkPtr, skPtr) + const publicKey = new Uint8Array(m.HEAPU8.buffer, pkPtr, SNTRUP761_PUBLICKEY_SIZE).slice() + const secretKey = new Uint8Array(m.HEAPU8.buffer, skPtr, SNTRUP761_SECRETKEY_SIZE).slice() + return {publicKey, secretKey} + } finally { + m._free(pkPtr) + m._free(skPtr) + } +} + +export function sntrup761Enc(publicKey: Uint8Array): KEMEncResult { + if (publicKey.length !== SNTRUP761_PUBLICKEY_SIZE) throw new Error("bad public key length") + const m = getModule() + const pkPtr = m._malloc(SNTRUP761_PUBLICKEY_SIZE) + const ctPtr = m._malloc(SNTRUP761_CIPHERTEXT_SIZE) + const ssPtr = m._malloc(SNTRUP761_SIZE) + try { + m.HEAPU8.set(publicKey, pkPtr) + m._sntrup761_wasm_enc(ctPtr, ssPtr, pkPtr) + const ciphertext = new Uint8Array(m.HEAPU8.buffer, ctPtr, SNTRUP761_CIPHERTEXT_SIZE).slice() + const sharedSecret = new Uint8Array(m.HEAPU8.buffer, ssPtr, SNTRUP761_SIZE).slice() + return {ciphertext, sharedSecret} + } finally { + m._free(pkPtr) + m._free(ctPtr) + m._free(ssPtr) + } +} + +export function sntrup761Dec(ciphertext: Uint8Array, secretKey: Uint8Array): Uint8Array { + if (ciphertext.length !== SNTRUP761_CIPHERTEXT_SIZE) throw new Error("bad ciphertext length") + if (secretKey.length !== SNTRUP761_SECRETKEY_SIZE) throw new Error("bad secret key length") + const m = getModule() + const ctPtr = m._malloc(SNTRUP761_CIPHERTEXT_SIZE) + const skPtr = m._malloc(SNTRUP761_SECRETKEY_SIZE) + const ssPtr = m._malloc(SNTRUP761_SIZE) + try { + m.HEAPU8.set(ciphertext, ctPtr) + m.HEAPU8.set(secretKey, skPtr) + m._sntrup761_wasm_dec(ssPtr, ctPtr, skPtr) + return new Uint8Array(m.HEAPU8.buffer, ssPtr, SNTRUP761_SIZE).slice() + } finally { + m._free(ctPtr) + m._free(skPtr) + m._free(ssPtr) + } +} diff --git a/smp-web/src/protocol.ts b/smp-web/src/protocol.ts index 91ce86abc..e61c49ba8 100644 --- a/smp-web/src/protocol.ts +++ b/smp-web/src/protocol.ts @@ -4,8 +4,12 @@ import { Decoder, concatBytes, encodeBytes, decodeBytes, - encodeLarge, decodeLarge + encodeLarge, decodeLarge, + encodeWord16, decodeWord16, + encodeBool, decodeBool, + encodeMaybe, decodeMaybe, } from "@simplex-chat/xftp-web/dist/protocol/encoding.js" +import {cbEncrypt, cbDecrypt} from "@simplex-chat/xftp-web/dist/crypto/secretbox.js" import {readTag, readSpace} from "@simplex-chat/xftp-web/dist/protocol/commands.js" // -- Transmission encoding (Protocol.hs:2201-2203) @@ -83,7 +87,12 @@ export function decodeLNK(d: Decoder): LNKResponse { export type SMPResponse = | {type: "LNK", response: LNKResponse} + | {type: "IDS", response: IDSResponse} + | {type: "MSG", response: MSGResponse} | {type: "OK"} + | {type: "PONG"} + | {type: "END"} + | {type: "DELD"} | {type: "ERR", message: string} export function decodeResponse(d: Decoder): SMPResponse { @@ -93,7 +102,18 @@ export function decodeResponse(d: Decoder): SMPResponse { readSpace(d) return {type: "LNK", response: decodeLNK(d)} } + case "IDS": { + readSpace(d) + return {type: "IDS", response: decodeIDS(d)} + } + case "MSG": { + readSpace(d) + return {type: "MSG", response: decodeMSG(d)} + } case "OK": return {type: "OK"} + case "PONG": return {type: "PONG"} + case "END": return {type: "END"} + case "DELD": return {type: "DELD"} case "ERR": { readSpace(d) return {type: "ERR", message: readTag(d)} @@ -101,3 +121,234 @@ export function decodeResponse(d: Decoder): SMPResponse { default: throw new Error("unknown SMP response: " + tag) } } + +// -- SMP command encoders (Protocol.hs:1679-1715) + +// MsgFlags (Protocol.hs:884-892) +// Single byte: Bool encoding of notification flag +export function encodeMsgFlags(notification: boolean): Uint8Array { + return encodeBool(notification) +} + +// SubscriptionMode (Protocol.hs:651-659) +// 'S' = SMSubscribe, 'C' = SMOnlyCreate +export function encodeSubMode(subscribe: boolean): Uint8Array { + return ascii(subscribe ? "S" : "C") +} + +// NEW (Protocol.hs:1682-1689) +// For v19: e(NEW_, ' ', rKey, dhKey) <> e(auth_, subMode, queueReqData, ntfCreds) +// auth_ = Maybe SndPublicAuthKey (DER-encoded) +// queueReqData = Maybe QueueReqData +// ntfCreds = Maybe NewNtfCreds (not needed for widget) +export function encodeNEW( + rcvAuthKey: Uint8Array, // DER-encoded Ed25519 or X25519 public key + rcvDhKey: Uint8Array, // DER-encoded X25519 public key + sndAuthKey: Uint8Array | null, // DER-encoded, for TOFU sender auth + subscribe: boolean, +): Uint8Array { + return concatBytes( + ascii("NEW "), + encodeBytes(rcvAuthKey), + encodeBytes(rcvDhKey), + encodeMaybe(encodeBytes, sndAuthKey), + encodeSubMode(subscribe), + encodeMaybe(() => new Uint8Array(0), null), // queueReqData = Nothing (widget doesn't create links) + encodeMaybe(() => new Uint8Array(0), null), // ntfCreds = Nothing + ) +} + +// KEY (Protocol.hs:1692) +// KEY k -> e(KEY_, ' ', k) +export function encodeKEY(senderKey: Uint8Array): Uint8Array { + return concatBytes(ascii("KEY "), encodeBytes(senderKey)) +} + +// SKEY (Protocol.hs:1703) +// SKEY k -> e(SKEY_, ' ', k) +export function encodeSKEY(senderKey: Uint8Array): Uint8Array { + return concatBytes(ascii("SKEY "), encodeBytes(senderKey)) +} + +// SUB (Protocol.hs:1690) +export function encodeSUB(): Uint8Array { + return ascii("SUB") +} + +// ACK (Protocol.hs:1699) +// ACK msgId -> e(ACK_, ' ', msgId) +export function encodeACK(msgId: Uint8Array): Uint8Array { + return concatBytes(ascii("ACK "), encodeBytes(msgId)) +} + +// SEND (Protocol.hs:1704) +// SEND flags msg -> e(SEND_, ' ', flags, ' ', Tail msg) +export function encodeSEND(notification: boolean, msgBody: Uint8Array): Uint8Array { + return concatBytes( + ascii("SEND "), + encodeMsgFlags(notification), + ascii(" "), + msgBody, // Tail - no length prefix + ) +} + +// OFF (Protocol.hs:1700) +export function encodeOFF(): Uint8Array { + return ascii("OFF") +} + +// DEL (Protocol.hs:1701) +export function encodeDEL(): Uint8Array { + return ascii("DEL") +} + +// -- SMP response decoders + +// IDS (Protocol.hs:1914-1921) +// For v19: e(IDS_, ' ', rcvId, sndId, srvDh) <> e(queueMode, linkId, serviceId, ntfCreds) +export interface IDSResponse { + rcvId: Uint8Array + sndId: Uint8Array + srvDhKey: Uint8Array + queueMode: string | null // 'M' = Messaging, 'C' = Contact + linkId: Uint8Array | null +} + +export function decodeIDS(d: Decoder): IDSResponse { + const rcvId = decodeBytes(d) + const sndId = decodeBytes(d) + const srvDhKey = decodeBytes(d) + // v19: queueMode (Maybe QueueMode), linkId (Maybe ByteString), serviceId, ntfCreds + // QueueMode is encoded as Maybe Char ('M'/'C'), not Maybe ByteString + let queueMode: string | null = null + if (d.remaining() > 0) { + const qmByte = d.anyByte() + if (qmByte === 0x31) { // '1' = Just + queueMode = String.fromCharCode(d.anyByte()) + } + // '0' = Nothing, queueMode stays null + } + let linkId: Uint8Array | null = null + if (d.remaining() > 0) { + linkId = decodeMaybe(decodeBytes, d) + } + // serviceId and ntfCreds - skip remaining + return {rcvId, sndId, srvDhKey, queueMode, linkId} +} + +// MSG (Protocol.hs:1927-1928) +// MSG RcvMessage {msgId, msgBody = EncRcvMsgBody body} -> e(MSG_, ' ', msgId, Tail body) +export interface MSGResponse { + msgId: Uint8Array + msgBody: Uint8Array +} + +export function decodeMSG(d: Decoder): MSGResponse { + const msgId = decodeBytes(d) + const msgBody = d.takeAll() + return {msgId, msgBody} +} + +// -- Per-queue E2E encryption (Protocol.hs:1071-1114) + +// Protocol.hs:316-320 +export const e2eEncMessageLength = 16000 +export const e2eEncConfirmationLength = 15904 + +// Protocol.hs:1078-1086 +export interface PubHeader { + phVersion: number // VersionSMPC (Word16) + phE2ePubDhKey: Uint8Array | null // Maybe PublicKeyX25519 (DER-encoded ByteString) +} + +export function encodePubHeader(h: PubHeader): Uint8Array { + return concatBytes(encodeWord16(h.phVersion), encodeMaybe(encodeBytes, h.phE2ePubDhKey)) +} + +export function decodePubHeader(d: Decoder): PubHeader { + return {phVersion: decodeWord16(d), phE2ePubDhKey: decodeMaybe(decodeBytes, d)} +} + +// Protocol.hs:1097-1110 +export type PrivHeader = + | {type: "PHConfirmation", key: Uint8Array} // 'K' + DER-encoded APublicAuthKey + | {type: "PHEmpty"} // '_' + +export function encodePrivHeader(h: PrivHeader): Uint8Array { + switch (h.type) { + case "PHConfirmation": return concatBytes(new Uint8Array([0x4B]), encodeBytes(h.key)) // 'K' + encodeBytes + case "PHEmpty": return new Uint8Array([0x5F]) // '_' + } +} + +export function decodePrivHeader(d: Decoder): PrivHeader { + const tag = d.anyByte() + switch (tag) { + case 0x4B: return {type: "PHConfirmation", key: decodeBytes(d)} // 'K' + case 0x5F: return {type: "PHEmpty"} // '_' + default: throw new Error("decodePrivHeader: unknown tag " + tag) + } +} + +// Protocol.hs:1095, 1112-1114 +export interface ClientMessage { + privHeader: PrivHeader + body: Uint8Array +} + +// smpEncode (ClientMessage h msg) = smpEncode h <> msg +export function encodeClientMessage(msg: ClientMessage): Uint8Array { + return concatBytes(encodePrivHeader(msg.privHeader), msg.body) +} + +export function decodeClientMessage(d: Decoder): ClientMessage { + const privHeader = decodePrivHeader(d) + const body = d.takeAll() + return {privHeader, body} +} + +// Protocol.hs:1071-1093 +export interface ClientMsgEnvelope { + cmHeader: PubHeader + cmNonce: Uint8Array // CbNonce: raw 24 bytes + cmEncBody: Uint8Array // encrypted body (Tail) +} + +// smpEncode (cmHeader, cmNonce, Tail cmEncBody) +export function encodeClientMsgEnvelope(env: ClientMsgEnvelope): Uint8Array { + return concatBytes(encodePubHeader(env.cmHeader), env.cmNonce, env.cmEncBody) +} + +export function decodeClientMsgEnvelope(d: Decoder): ClientMsgEnvelope { + const cmHeader = decodePubHeader(d) + const cmNonce = d.take(24) // CbNonce is raw 24 bytes + const cmEncBody = d.takeAll() + return {cmHeader, cmNonce, cmEncBody} +} + +// -- Per-queue E2E encrypt/decrypt (Agent/Client.hs:2074-2102) + +// agentCbEncrypt: encrypt a ClientMessage and wrap in ClientMsgEnvelope +export function agentCbEncrypt( + e2eDhSecret: Uint8Array, // X25519 DH shared secret (32 bytes) + smpClientVersion: number, // Word16 + e2ePubKey: Uint8Array | null, // DER-encoded X25519 public key, null for normal messages + msg: Uint8Array, // smpEncode(ClientMessage) +): Uint8Array { + const cmNonce = crypto.getRandomValues(new Uint8Array(24)) + const paddedLen = e2ePubKey !== null ? e2eEncConfirmationLength : e2eEncMessageLength + const cmEncBody = cbEncrypt(e2eDhSecret, cmNonce, msg, paddedLen) + const cmHeader: PubHeader = {phVersion: smpClientVersion, phE2ePubDhKey: e2ePubKey} + return encodeClientMsgEnvelope({cmHeader, cmNonce, cmEncBody}) +} + +// agentCbDecrypt: decrypt a ClientMsgEnvelope +export function agentCbDecrypt( + dhSecret: Uint8Array, // X25519 DH shared secret (32 bytes) + data: Uint8Array, // raw ClientMsgEnvelope bytes +): {pubHeader: PubHeader, clientMessage: ClientMessage} { + const env = decodeClientMsgEnvelope(new Decoder(data)) + const plaintext = cbDecrypt(dhSecret, env.cmNonce, env.cmEncBody) + const clientMessage = decodeClientMessage(new Decoder(plaintext)) + return {pubHeader: env.cmHeader, clientMessage} +} diff --git a/smp-web/tests/ratchet-repl.ts b/smp-web/tests/ratchet-repl.ts new file mode 100644 index 000000000..4b8bc4033 --- /dev/null +++ b/smp-web/tests/ratchet-repl.ts @@ -0,0 +1,326 @@ +// Double ratchet REPL for cross-language testing. +// Holds one ratchet state, reads commands from stdin, writes results to stdout. +// +// Init protocol: +// INIT_RCV +// → ok: +// COMPLETE +// → ok +// INIT_SND +// → ok: +// kemMode: none | propose | accept +// +// Encrypt/decrypt operators (same syntax as Haskell DoubleRatchetTests): +// \#> encrypt, assert noSndKEM +// !#> <plaintext> encrypt, assert hasSndKEM +// \#>! <plaintext> encrypt PQEncOn, assert noSndKEM +// !#>! <plaintext> encrypt PQEncOn, assert hasSndKEM +// !#>\ <plaintext> encrypt PQEncOff, assert hasSndKEM +// \#>\ <plaintext> encrypt PQEncOff, assert noSndKEM +// <#\ <hex ct> <expected> decrypt, assert noRcvKEM +// <#! <hex ct> <expected> decrypt, assert hasRcvKEM +// +// Plain encrypt/decrypt (no assertions): +// E <plaintext> → ok: <hex ciphertext> +// D <hex ciphertext> → ok: <plaintext> +// +// Response format: ok: <data> or error: <message> + +import {createInterface} from "readline" +import { + generateX448KeyPair, pqX3dhSnd, pqX3dhRcv, + encodePubKeyX448, decodePubKeyX448, + initSndRatchet, initRcvRatchet, + rcEncrypt, rcDecrypt, + rootKdf, + type Ratchet, type SkippedMsgKeys, type RatchetVersions, + type RatchetInitParams, type RatchetKEMAccepted, +} from "../dist/crypto/ratchet.js" +import {initSntrup761, sntrup761Keypair, sntrup761Enc, sntrup761Dec} from "../dist/crypto/sntrup761.js" +import type {KEMKeyPair} from "../dist/crypto/sntrup761.js" +import { + Decoder, decodeBytes, decodeLarge, encodeBytes, encodeWord16, concatBytes, +} from "@simplex-chat/xftp-web/dist/protocol/encoding.js" + +// -- State + +let ratchet: Ratchet | null = null +let skippedKeys: SkippedMsgKeys = new Map() +const PADDED_MSG_LEN = 16000 + +// Intermediate state for RCV init (between INIT_RCV and COMPLETE) +let rcvInitState: { + privKey1: Uint8Array + privKey2: Uint8Array + kemKeyPair: KEMKeyPair | null + pqSupport: boolean +} | null = null + +// -- Hex helpers + +function toHex(bytes: Uint8Array): string { + return Array.from(bytes, b => b.toString(16).padStart(2, "0")).join("") +} + +function fromHex(hex: string): Uint8Array { + const bytes = new Uint8Array(hex.length / 2) + for (let i = 0; i < hex.length; i += 2) + bytes[i / 2] = parseInt(hex.substring(i, i + 2), 16) + return bytes +} + +// -- E2E params helpers + +// Parse E2ERatchetParams: version(Word16) + pk1(ByteString) + pk2(ByteString) + Maybe KEMParams +interface ParsedE2EParams { + version: number + pk1Raw: Uint8Array // raw X448 public key + pk2Raw: Uint8Array // raw X448 public key + kemPk: Uint8Array | null // KEM public key if proposed + kemCt: Uint8Array | null // KEM ciphertext if accepted + kemAcceptPk: Uint8Array | null // KEM public key in accepted +} + +function parseE2EParams(data: Uint8Array): ParsedE2EParams { + const d = new Decoder(data) + const version = d.anyByte() * 256 + d.anyByte() + const pk1Raw = decodePubKeyX448(decodeBytes(d)) + const pk2Raw = decodePubKeyX448(decodeBytes(d)) + let kemPk: Uint8Array | null = null + let kemCt: Uint8Array | null = null + let kemAcceptPk: Uint8Array | null = null + if (version >= 3 && d.remaining() > 0) { + const maybeByte = d.anyByte() + if (maybeByte === 0x31) { // Just + const tag = d.anyByte() + if (tag === 0x50) { // 'P' Proposed + kemPk = decodeLarge(d) + } else if (tag === 0x41) { // 'A' Accepted + kemCt = decodeLarge(d) + kemAcceptPk = decodeLarge(d) + } + } + } + return {version, pk1Raw, pk2Raw, kemPk, kemCt, kemAcceptPk} +} + +// Encode E2ERatchetParams for sending to peer +function encodeE2EParams( + version: number, + pk1Raw: Uint8Array, pk2Raw: Uint8Array, + kemPk: Uint8Array | null, // for proposed + kemCt: Uint8Array | null, // for accepted + kemAcceptPk: Uint8Array | null, // public key in accepted +): Uint8Array { + const vBytes = new Uint8Array(2) + vBytes[0] = (version >> 8) & 0xff + vBytes[1] = version & 0xff + const parts = [vBytes, encodeBytes(encodePubKeyX448(pk1Raw)), encodeBytes(encodePubKeyX448(pk2Raw))] + if (version >= 3) { + if (kemCt && kemAcceptPk) { + // Just Accepted + parts.push(new Uint8Array([0x31, 0x41])) // Just + 'A' + parts.push(new Uint8Array([(kemCt.length >> 8) & 0xff, kemCt.length & 0xff])) + parts.push(kemCt) + parts.push(new Uint8Array([(kemAcceptPk.length >> 8) & 0xff, kemAcceptPk.length & 0xff])) + parts.push(kemAcceptPk) + } else if (kemPk) { + // Just Proposed + parts.push(new Uint8Array([0x31, 0x50])) // Just + 'P' + parts.push(new Uint8Array([(kemPk.length >> 8) & 0xff, kemPk.length & 0xff])) + parts.push(kemPk) + } else { + // Nothing + parts.push(new Uint8Array([0x30])) + } + } + return concatBytes(...parts) +} + +// -- Init handlers + +function handleInitRcv(version: number, pqSupport: boolean): string { + const kp1 = generateX448KeyPair() + const kp2 = generateX448KeyPair() + let kemKeyPair: KEMKeyPair | null = null + let kemPk: Uint8Array | null = null + if (pqSupport) { + kemKeyPair = sntrup761Keypair() + kemPk = kemKeyPair.publicKey + } + rcvInitState = {privKey1: kp1.privateKey, privKey2: kp2.privateKey, kemKeyPair, pqSupport} + const params = encodeE2EParams(version, kp1.publicKey, kp2.publicKey, kemPk, null, null) + return "ok: " + toHex(params) +} + +function handleComplete(peerParamsHex: string): string { + if (!rcvInitState) return "error: not in RCV init state" + const {privKey1, privKey2, kemKeyPair, pqSupport} = rcvInitState + const peerParams = parseE2EParams(fromHex(peerParamsHex)) + + // Build kemAccepted for X3DH if peer accepted our KEM proposal + let kemAccepted: RatchetKEMAccepted | null = null + if (peerParams.kemCt && peerParams.kemAcceptPk && kemKeyPair) { + const ss = sntrup761Dec(peerParams.kemCt, kemKeyPair.secretKey) + kemAccepted = {rcPQRr: peerParams.kemAcceptPk, rcPQRss: ss, rcPQRct: peerParams.kemCt} + } + + // X3DH (receiver side) + const initParams = pqX3dhRcv(privKey1, privKey2, peerParams.pk1Raw, peerParams.pk2Raw, kemAccepted) + + // Init receiving ratchet + const vs: RatchetVersions = {current: peerParams.version, maxSupported: peerParams.version} + ratchet = initRcvRatchet(vs, privKey2, initParams, kemKeyPair, pqSupport) + skippedKeys = new Map() + rcvInitState = null + return "ok" +} + +function handleInitSnd(version: number, kemMode: string, peerParamsHex: string): string { + const peerParams = parseE2EParams(fromHex(peerParamsHex)) + + const kp1 = generateX448KeyPair() + const kp2 = generateX448KeyPair() + const kp3 = generateX448KeyPair() // fresh DH key for ratchet + + // KEM handling + let kemAccepted: RatchetKEMAccepted | null = null + let ownKemKp: KEMKeyPair | null = null + let outKemPk: Uint8Array | null = null + let outKemCt: Uint8Array | null = null + let outKemAcceptPk: Uint8Array | null = null + + if (kemMode === "accept" && peerParams.kemPk) { + // Accept peer's KEM proposal + const encResult = sntrup761Enc(peerParams.kemPk) + ownKemKp = sntrup761Keypair() + kemAccepted = {rcPQRr: peerParams.kemPk, rcPQRss: encResult.sharedSecret, rcPQRct: encResult.ciphertext} + outKemCt = encResult.ciphertext + outKemAcceptPk = ownKemKp.publicKey + } else if (kemMode === "propose") { + ownKemKp = sntrup761Keypair() + outKemPk = ownKemKp.publicKey + } + + // X3DH (sender side) + const initParams = pqX3dhSnd(kp1.privateKey, kp2.privateKey, peerParams.pk1Raw, peerParams.pk2Raw, kemAccepted) + + // Init sending ratchet + const vs: RatchetVersions = {current: version, maxSupported: version} + ratchet = initSndRatchet(vs, peerParams.pk2Raw, kp3.privateKey, initParams, ownKemKp) + skippedKeys = new Map() + + const params = encodeE2EParams(version, kp1.publicKey, kp2.publicKey, outKemPk, outKemCt, outKemAcceptPk) + return "ok: " + toHex(params) +} + +// -- Encrypt/decrypt handlers + +function handleEncrypt(kemAssert: boolean | null, _pqPref: boolean | null, plaintext: string): string { + if (!ratchet) return "error: not initialized" + try { + const result = rcEncrypt(ratchet, new TextEncoder().encode(plaintext), PADDED_MSG_LEN) + ratchet = result.state + if (kemAssert === true && !ratchet.rcSndKEM) return "error: expected hasSndKEM" + if (kemAssert === false && ratchet.rcSndKEM) return "error: expected noSndKEM" + return "ok: " + toHex(result.ciphertext) + } catch (e: any) { + return "error: " + e.message + } +} + +function handleDecrypt(kemAssert: boolean | null, hexCt: string, expectedPlaintext: string | null): string { + if (!ratchet) return "error: not initialized" + try { + const ct = fromHex(hexCt) + const result = rcDecrypt(ratchet, skippedKeys, ct) + ratchet = result.state + skippedKeys = result.skippedKeys + const plaintext = new TextDecoder().decode(result.plaintext) + if (kemAssert === true && !ratchet.rcRcvKEM) return "error: expected hasRcvKEM" + if (kemAssert === false && ratchet.rcRcvKEM) return "error: expected noRcvKEM" + if (expectedPlaintext !== null && plaintext !== expectedPlaintext) + return "error: expected '" + expectedPlaintext + "', got '" + plaintext + "'" + return "ok: " + plaintext + } catch (e: any) { + return "error: " + e.message + } +} + +// -- Command parser + +function parseLine(line: string): string { + // Init commands + if (line.startsWith("INIT_RCV ")) { + const parts = line.split(" ") + return handleInitRcv(parseInt(parts[1]), parts[2] === "1") + } + if (line.startsWith("COMPLETE ")) { + return handleComplete(line.substring(9).trim()) + } + if (line.startsWith("INIT_SND ")) { + const parts = line.split(" ") + return handleInitSnd(parseInt(parts[1]), parts[2], parts[3]) + } + + // Query commands + if (line === "SNDKEM") { + if (!ratchet) return "error: not initialized" + return "ok: " + (ratchet.rcSndKEM ? "1" : "0") + } + if (line === "RCVKEM") { + if (!ratchet) return "error: not initialized" + return "ok: " + (ratchet.rcRcvKEM ? "1" : "0") + } + + // Encrypt operators: \#> !#> \#>! !#>! !#>\ \#>\ + const encMatch = line.match(/^([!\\])#(>[!\\]?)\s+(.+)$/) + if (encMatch) { + const [, kemChar, arrow, msg] = encMatch + const kemAssert = kemChar === "!" ? true : false + let pqPref: boolean | null = null + if (arrow === ">!") pqPref = true + else if (arrow === ">\\") pqPref = false + return handleEncrypt(kemAssert, pqPref, msg) + } + + // Decrypt operators: <#\ <#! + const decMatch = line.match(/^<#([!\\])\s+(\S+)\s+(.+)$/) + if (decMatch) { + const [, kemChar, hexCt, expected] = decMatch + const kemAssert = kemChar === "!" ? true : false + return handleDecrypt(kemAssert, hexCt, expected) + } + + // Plain encrypt (no assertion) + if (line.startsWith("E ")) { + return handleEncrypt(null, null, line.substring(2)) + } + + // Plain decrypt (no assertion, no expected) + if (line.startsWith("D ")) { + return handleDecrypt(null, line.substring(2), null) + } + + return "error: unknown command: " + line +} + +// -- Main + +async function main() { + await initSntrup761() + + const rl = createInterface({input: process.stdin, terminal: false}) + + for await (const line of rl) { + const trimmed = line.trim() + if (!trimmed) continue + const response = parseLine(trimmed) + process.stdout.write(response + "\n") + } +} + +main().catch(e => { + process.stderr.write("FATAL: " + e.message + "\n") + process.exit(1) +}) diff --git a/smp-web/tsconfig.test.json b/smp-web/tsconfig.test.json new file mode 100644 index 000000000..960f58829 --- /dev/null +++ b/smp-web/tsconfig.test.json @@ -0,0 +1,15 @@ +{ + "compilerOptions": { + "target": "ES2022", + "module": "ES2022", + "moduleResolution": "node", + "lib": ["ES2022"], + "outDir": "dist-test", + "rootDir": "tests", + "strict": true, + "esModuleInterop": true, + "skipLibCheck": true, + "sourceMap": true + }, + "include": ["tests/**/*.ts"] +} diff --git a/tests/AgentTests/DoubleRatchetTests.hs b/tests/AgentTests/DoubleRatchetTests.hs index eef5be27f..fd160dbd3 100644 --- a/tests/AgentTests/DoubleRatchetTests.hs +++ b/tests/AgentTests/DoubleRatchetTests.hs @@ -73,15 +73,19 @@ runMessageTests :: Bool -> Spec runMessageTests initRatchets_ agreeRatchetKEMs = do - it "should encrypt and decrypt messages" $ run $ testEncryptDecrypt agreeRatchetKEMs - it "should encrypt and decrypt skipped messages" $ run $ testSkippedMessages agreeRatchetKEMs - it "should encrypt and decrypt many messages" $ run $ testManyMessages agreeRatchetKEMs - it "should allow skipped after ratchet advance" $ run $ testSkippedAfterRatchetAdvance agreeRatchetKEMs + it "should encrypt and decrypt messages" $ run testEncryptDecrypt + it "should encrypt and decrypt skipped messages" $ run testSkippedMessages + it "should encrypt and decrypt many messages" $ run testManyMessages + it "should allow skipped after ratchet advance" $ run testSkippedAfterRatchetAdvance where run :: (forall a. (AlgorithmI a, DhAlgorithm a) => TestRatchets a) -> IO () run test = do - withRatchets_ @X25519 initRatchets_ test - withRatchets_ @X448 initRatchets_ test + withRatchets_ @X25519 initRatchets_ (withKEM test) + withRatchets_ @X448 initRatchets_ (withKEM test) + withKEM :: (AlgorithmI a, DhAlgorithm a) => TestRatchets a -> TestRatchets a + withKEM test alice bob encrypt decrypt (#>) = do + when agreeRatchetKEMs $ initRatchetKEM bob alice >> initRatchetKEM alice bob + test alice bob encrypt decrypt (#>) testAlgs :: (forall a. (AlgorithmI a, DhAlgorithm a) => C.SAlgorithm a -> IO ()) -> IO () testAlgs test = test C.SX25519 >> test C.SX448 @@ -146,6 +150,12 @@ type TestRatchets a = EncryptDecryptSpec a -> IO () +-- Peer-polymorphic types for cross-language testing +type EncryptP p = p -> ByteString -> IO (Either CryptoError ByteString) +type DecryptP p = p -> ByteString -> IO (Either CryptoError (Either CryptoError ByteString)) +type EncryptDecryptSpecP p = (p, ByteString) -> p -> Expectation +type TestRatchetsP p = p -> p -> EncryptP p -> DecryptP p -> EncryptDecryptSpecP p -> IO () + deriving instance Eq (Ratchet a) deriving instance Eq (SndRatchet a) @@ -170,9 +180,8 @@ deriving instance Eq (MsgHeader a) initRatchetKEM :: (AlgorithmI a, DhAlgorithm a) => TVar (TVar ChaChaDRG, Ratchet a, SkippedMsgKeys) -> TVar (TVar ChaChaDRG, Ratchet a, SkippedMsgKeys) -> IO () initRatchetKEM s r = encryptDecrypt (Just $ PQEncOn) (const ()) (const ()) (s, "initialising ratchet") r -testEncryptDecrypt :: (AlgorithmI a, DhAlgorithm a) => Bool -> TestRatchets a -testEncryptDecrypt agreeRatchetKEMs alice bob encrypt decrypt (#>) = do - when agreeRatchetKEMs $ initRatchetKEM bob alice >> initRatchetKEM alice bob +testEncryptDecrypt :: TestRatchetsP p +testEncryptDecrypt alice bob encrypt decrypt (#>) = do (bob, "hello alice") #> alice (alice, "hello bob") #> bob Right b1 <- encrypt bob "how are you, alice?" @@ -191,9 +200,8 @@ testEncryptDecrypt agreeRatchetKEMs alice bob encrypt decrypt (#>) = do (alice, "I'm here too, same") #> bob pure () -testSkippedMessages :: (AlgorithmI a, DhAlgorithm a) => Bool -> TestRatchets a -testSkippedMessages agreeRatchetKEMs alice bob encrypt decrypt _ = do - when agreeRatchetKEMs $ initRatchetKEM bob alice >> initRatchetKEM alice bob +testSkippedMessages :: TestRatchetsP p +testSkippedMessages alice bob encrypt decrypt _ = do Right msg1 <- encrypt bob "hello alice" Right msg2 <- encrypt bob "hello there again" Right msg3 <- encrypt bob "are you there?" @@ -203,9 +211,8 @@ testSkippedMessages agreeRatchetKEMs alice bob encrypt decrypt _ = do Decrypted "hello alice" <- decrypt alice msg1 pure () -testManyMessages :: (AlgorithmI a, DhAlgorithm a) => Bool -> TestRatchets a -testManyMessages agreeRatchetKEMs alice bob _ _ (#>) = do - when agreeRatchetKEMs $ initRatchetKEM bob alice >> initRatchetKEM alice bob +testManyMessages :: TestRatchetsP p +testManyMessages alice bob _ _ (#>) = do (bob, "b1") #> alice (bob, "b2") #> alice (bob, "b3") #> alice @@ -222,9 +229,8 @@ testManyMessages agreeRatchetKEMs alice bob _ _ (#>) = do (bob, "b15") #> alice (bob, "b16") #> alice -testSkippedAfterRatchetAdvance :: (AlgorithmI a, DhAlgorithm a) => Bool -> TestRatchets a -testSkippedAfterRatchetAdvance agreeRatchetKEMs alice bob encrypt decrypt (#>) = do - when agreeRatchetKEMs $ initRatchetKEM bob alice >> initRatchetKEM alice bob +testSkippedAfterRatchetAdvance :: TestRatchetsP p +testSkippedAfterRatchetAdvance alice bob encrypt decrypt (#>) = do (bob, "b1") #> alice Right b2 <- encrypt bob "b2" Right b3 <- encrypt bob "b3" diff --git a/tests/SMPWebTests.hs b/tests/SMPWebTests.hs index 44db83693..11c0d5922 100644 --- a/tests/SMPWebTests.hs +++ b/tests/SMPWebTests.hs @@ -2,6 +2,7 @@ {-# LANGUAGE GADTs #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE PatternSynonyms #-} +{-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TypeApplications #-} -- | Per-function tests for the smp-web TypeScript SMP client library. @@ -12,9 +13,18 @@ -- Run: cabal test --test-option=--match="/SMP Web Client/" module SMPWebTests (smpWebTests) where -import Control.Monad.Except (ExceptT, runExceptT) +import Control.Concurrent.STM +import Control.Monad (when) +import Data.Bifunctor (first) +import Control.Exception (bracket) +import Control.Monad.Except (ExceptT, liftEither, runExceptT, throwError, withExceptT) +import Crypto.Random (ChaChaDRG) +import Data.IORef +import System.IO (Handle, hFlush, hGetLine, hPutStr, hSetBuffering, BufferMode (..)) +import System.Process (CreateProcess (..), StdStream (..), ProcessHandle, createProcess, proc, terminateProcess) import qualified Data.ByteString as B import qualified Data.ByteString.Char8 as BC +import Data.List (isInfixOf) import Data.List.NonEmpty (NonEmpty (..)) import System.Directory (doesDirectoryExist) import Data.Word (Word16) @@ -25,18 +35,24 @@ import Simplex.Messaging.Client (pattern NRMInteractive) import Simplex.Messaging.Version (mkVersionRange) import Simplex.Messaging.Version.Internal (Version (..)) import qualified Simplex.Messaging.Crypto as C +import Simplex.Messaging.Crypto (Algorithm (..)) import qualified Simplex.Messaging.Crypto.Ratchet as CR +import Simplex.Messaging.Crypto.SNTRUP761.Bindings (KEMPublicKey (..), KEMSecretKey, KEMCiphertext (..), KEMSharedKey (..), sntrup761Keypair, sntrup761Enc, sntrup761Dec) +import qualified Crypto.Cipher.Types as AES +import qualified Data.Map.Strict as M +import qualified Data.ByteArray as BA import Simplex.Messaging.Crypto.ShortLink (contactShortLinkKdf, invShortLinkKdf) import Simplex.Messaging.Encoding -import Simplex.Messaging.Encoding.String (strEncode) -import Simplex.Messaging.Protocol (EntityId (..), SMPServer, SubscriptionMode (..), pattern SMPServer) +import Simplex.Messaging.Encoding.String (Str (..), strEncode) +import Simplex.Messaging.Protocol (EntityId (..), SMPServer, SubscriptionMode (..), MsgFlags (..), pattern SMPServer, encodeProtocol, Command (..), NewQueueReq (..), BrokerMsg (..), RcvMessage (..), EncRcvMsgBody (..), QueueIdsKeys (..), PubHeader (..), PrivHeader (..), ClientMessage (..), ClientMsgEnvelope (..), pattern VersionSMPC) import Simplex.Messaging.Server.Env.STM (AStoreType (..)) import Simplex.Messaging.Server.MsgStore.Types (SMSType (..), SQSType (..)) import Simplex.Messaging.Server.Web (attachStaticAndWS) -import Simplex.Messaging.Transport (TLS, smpBlockSize) +import Simplex.Messaging.Transport (TLS, smpBlockSize, currentServerSMPRelayVersion) import Simplex.Messaging.Transport.Client (TransportHost (..)) import SMPAgentClient (agentCfg, initAgentServers, testDB) import SMPClient (cfgWebOn, testKeyHash, testPort, withSmpServerConfig) +import AgentTests.DoubleRatchetTests (testEncryptDecrypt, testSkippedMessages, testManyMessages, testSkippedAfterRatchetAdvance) import AgentTests.FunctionalAPITests (withAgent) import Test.Hspec hiding (it) import Util @@ -49,10 +65,10 @@ callNode :: String -> IO B.ByteString callNode = callNode_ smpWebDir impEnc :: String -impEnc = "import { Decoder, decodeLarge } from '@simplex-chat/xftp-web/dist/protocol/encoding.js';" +impEnc = "import { Decoder, decodeBytes, decodeLarge, encodeBytes, encodeWord16 } from '@simplex-chat/xftp-web/dist/protocol/encoding.js';" impProto_ :: String -impProto_ = "import { encodeTransmission, encodeBatch, decodeTransmission, encodeLGET, decodeLNK, decodeResponse } from './dist/protocol.js';" +impProto_ = "import { encodeTransmission, encodeBatch, decodeTransmission, encodeLGET, decodeLNK, decodeResponse, encodeNEW, encodeKEY, encodeSKEY, encodeSUB, encodeACK, encodeSEND, encodeOFF, encodeDEL } from './dist/protocol.js';" impProto :: String impProto = impEnc <> impProto_ @@ -74,6 +90,19 @@ impAgentProto = impEnc <> impAgentProto_ impCryptoShortLink :: String impCryptoShortLink = "import { contactShortLinkKdf, invShortLinkKdf, decryptLinkData } from './dist/crypto/shortLink.js';" +impRatchet :: String +impRatchet = "import { generateX448KeyPair, pqX3dhSnd, pqX3dhRcv, x448DH, encodePubKeyX448, decodePubKeyX448, chainKdf, rootKdf, initSndRatchet, initRcvRatchet, rcEncrypt, rcDecrypt } from './dist/crypto/ratchet.js';" + <> "import { encryptAEAD, decryptAEAD } from './dist/crypto.js';" + +impSntrup :: String +impSntrup = "import { initSntrup761, sntrup761Keypair, sntrup761Enc, sntrup761Dec } from './dist/crypto/sntrup761.js'; await initSntrup761();" + +impAgentMsg :: String +impAgentMsg = "import { encodeAMessage, decodeAMessage, encodeAPrivHeader, decodeAPrivHeader, encodeAgentMessage, decodeAgentMessage, encodeAgentMsgEnvelope, decodeAgentMsgEnvelope } from './dist/agent/message.js';" + +impProtoE2E :: String +impProtoE2E = "import { encodePubHeader, decodePubHeader, encodePrivHeader, decodePrivHeader, encodeClientMessage, decodeClientMessage, encodeClientMsgEnvelope, decodeClientMsgEnvelope, agentCbEncrypt, agentCbDecrypt, e2eEncMessageLength } from './dist/protocol.js';" + impCrypto :: String impCrypto = "import { sbcInit, sbcHkdf, sbEncryptBlock, sbDecryptBlock } from './dist/crypto.js';" @@ -84,9 +113,222 @@ impSodium = "import sodium from '@simplex-chat/xftp-web/node_modules/libsodium-w jsStr :: B.ByteString -> String jsStr bs = "'" <> BC.unpack bs <> "'" +paddedMsgLen :: Int +paddedMsgLen = 100 + +-- -- TestPeer: sum type for cross-language ratchet tests + +type HsPeer a = TVar (TVar ChaChaDRG, CR.Ratchet a, CR.SkippedMsgKeys) + +data TestPeer + = forall a. (C.AlgorithmI a, C.DhAlgorithm a) => TestPeerHS (HsPeer a) + | TestPeerJS Handle Handle ProcessHandle -- stdin, stdout, process + +-- dispatch functions + +tpEncrypt :: TestPeer -> B.ByteString -> IO (Either C.CryptoError B.ByteString) +tpEncrypt (TestPeerHS tvar) msg = do + (_, rc, smks) <- readTVarIO tvar + result <- runExceptT $ do + (mek, rc') <- CR.rcEncryptHeader rc Nothing CR.currentE2EEncryptVersion + ct <- CR.rcEncryptMsg mek paddedMsgLen msg + pure (ct, rc') + case result of + Right (ct, rc') -> do + (g, _, smks') <- readTVarIO tvar + atomically $ writeTVar tvar (g, rc', smks') + pure $ Right ct + Left e -> pure $ Left e +tpEncrypt (TestPeerJS hIn hOut _) msg = do + hPutStrLn' hIn $ "E " <> BC.unpack msg + resp <- hGetLine hOut + case parseResponse resp of + Right hex -> pure $ Right $ hexToBS hex + Left err -> error $ "tpEncrypt JS error: " <> err + +tpDecrypt :: TestPeer -> B.ByteString -> IO (Either C.CryptoError (Either C.CryptoError B.ByteString)) +tpDecrypt (TestPeerHS tvar) ct = do + (g, rc, smks) <- readTVarIO tvar + result <- runExceptT $ CR.rcDecrypt g rc smks ct + case result of + Right (msg, rc', smDiff) -> do + atomically $ writeTVar tvar (g, rc', CR.applySMDiff smks smDiff) + pure $ Right msg + Left e -> pure $ Left e +tpDecrypt (TestPeerJS hIn hOut _) ct = do + hPutStrLn' hIn $ "D " <> bsToHex ct + resp <- hGetLine hOut + case parseResponse resp of + Right txt -> pure $ Right $ Right $ BC.pack txt + Left err -> parseJsError err + +-- Map JS REPL error strings to CryptoError at the correct Either level. +-- Outer Left: errors that abort rcDecrypt (header failure, first skipMessageKeys). +-- Inner Right (Left _): errors from second skipMessageKeys (duplicate/earlier in current ratchet state). +parseJsError :: String -> IO (Either C.CryptoError (Either C.CryptoError B.ByteString)) +parseJsError err + -- Outer errors (ExceptT failures in Haskell rcDecrypt) + | has "CERatchetHeader" = pure $ Left C.CERatchetHeader + | has "CERatchetKEMState" = pure $ Left C.CERatchetKEMState + -- Inner errors (pure Left in second skipMessageKeys) + | has "CERatchetDuplicateMessage" = pure $ Right $ Left C.CERatchetDuplicateMessage + | has "CERatchetEarlierMessage" = pure $ Right $ Left $ C.CERatchetEarlierMessage 0 + | has "CERatchetTooManySkipped" = pure $ Right $ Left $ C.CERatchetTooManySkipped 0 + | has "CERatchetState" = pure $ Right $ Left C.CERatchetState + | otherwise = pure $ Left $ C.CryptoHeaderError err + where + has s = s `isInfixOf` err + +tpSndKEM :: TestPeer -> IO Bool +tpSndKEM (TestPeerHS tvar) = do + (_, rc, _) <- readTVarIO tvar + pure $ CR.enablePQ $ CR.rcSndKEM rc +tpSndKEM (TestPeerJS hIn hOut _) = do + hPutStrLn' hIn "SNDKEM" + resp <- hGetLine hOut + pure $ resp == "ok: 1" + +tpRcvKEM :: TestPeer -> IO Bool +tpRcvKEM (TestPeerHS tvar) = do + (_, rc, _) <- readTVarIO tvar + pure $ CR.enablePQ $ CR.rcRcvKEM rc +tpRcvKEM (TestPeerJS hIn hOut _) = do + hPutStrLn' hIn "RCVKEM" + resp <- hGetLine hOut + pure $ resp == "ok: 1" + +tpEncryptDecrypt :: Maybe CR.PQEncryption -> Bool -> Bool -> (TestPeer, B.ByteString) -> TestPeer -> Expectation +tpEncryptDecrypt _pqEnc expectSndKEM expectRcvKEM (sender, msg) receiver = do + Right ct <- tpEncrypt sender msg + sndK <- tpSndKEM sender + when (sndK /= expectSndKEM) $ expectationFailure $ "sndKEM: expected " <> show expectSndKEM <> ", got " <> show sndK + Right (Right msg') <- tpDecrypt receiver ct + rcvK <- tpRcvKEM receiver + when (rcvK /= expectRcvKEM) $ expectationFailure $ "rcvKEM: expected " <> show expectRcvKEM <> ", got " <> show rcvK + msg' `shouldBe` msg + +-- TestPeer operators (matching Haskell DoubleRatchetTests) +tp_noKEM, tp_hasKEM :: (TestPeer, B.ByteString) -> TestPeer -> Expectation +tp_noKEM = tpEncryptDecrypt Nothing False False +tp_hasKEM = tpEncryptDecrypt Nothing True True + +-- JS process helpers + +spawnJsRatchet :: IO (Handle, Handle, ProcessHandle) +spawnJsRatchet = do + let cp = (proc "node" ["dist-test/ratchet-repl.js"]) {cwd = Just "smp-web", std_in = CreatePipe, std_out = CreatePipe, std_err = CreatePipe} + (Just hIn, Just hOut, _, ph) <- createProcess cp + hSetBuffering hIn LineBuffering + hSetBuffering hOut LineBuffering + pure (hIn, hOut, ph) + +destroyJsRatchet :: TestPeer -> IO () +destroyJsRatchet (TestPeerJS _ _ ph) = terminateProcess ph +destroyJsRatchet _ = pure () + +jsCmd :: Handle -> Handle -> String -> IO String +jsCmd hIn hOut cmd = do + hPutStrLn' hIn cmd + hGetLine hOut + +hPutStrLn' :: Handle -> String -> IO () +hPutStrLn' h s = do + hPutStr h (s <> "\n") + hFlush h + +parseResponse :: String -> Either String String +parseResponse resp + | take 4 resp == "ok: " = Right $ drop 4 resp + | take 7 resp == "error: " = Left $ drop 7 resp + | otherwise = Left $ "unexpected response: " <> resp + +bsToHex :: B.ByteString -> String +bsToHex = concatMap (\w -> let h = showHex' w in h) . B.unpack + where + showHex' w = [hexDigit (w `div` 16), hexDigit (w `mod` 16)] + hexDigit n | n < 10 = toEnum (fromEnum '0' + fromIntegral n) + | otherwise = toEnum (fromEnum 'a' + fromIntegral n - 10) + +hexToBS :: String -> B.ByteString +hexToBS = B.pack . go + where + go [] = [] + go (a:b:rest) = fromIntegral (hexVal a * 16 + hexVal b) : go rest + go _ = [] + hexVal c + | c >= '0' && c <= '9' = fromEnum c - fromEnum '0' + | c >= 'a' && c <= 'f' = fromEnum c - fromEnum 'a' + 10 + | c >= 'A' && c <= 'F' = fromEnum c - fromEnum 'A' + 10 + | otherwise = 0 + runRight :: (Show e, HasCallStack) => ExceptT e IO a -> IO a runRight action = runExceptT action >>= either (error . ("Unexpected error: " <>) . show) pure +-- -- Cross-language ratchet init functions and test patterns + +withCrossPeers :: IO (TestPeer, TestPeer) -> ((TestPeer, TestPeer) -> IO ()) -> IO () +withCrossPeers initPeers test = bracket initPeers cleanup test + where + cleanup (a, b) = destroyJsRatchet a >> destroyJsRatchet b + +-- HS (receiver) <-> JS (sender), no PQ +initHsJs_noPQ :: IO (TestPeer, TestPeer) +initHsJs_noPQ = do + g <- C.newRandom + let v = CR.currentE2EEncryptVersion + Version vNum = v + (pkAlice1, pkAlice2, Nothing, e2eAlice) <- CR.generateRcvE2EParams @'X448 g v CR.PQSupportOff + let aliceE2EHex = bsToHex $ smpEncode e2eAlice + (hIn, hOut, ph) <- spawnJsRatchet + bobE2EHex <- either error pure . parseResponse =<< jsCmd hIn hOut ("INIT_SND " ++ show vNum ++ " none " ++ aliceE2EHex) + Right (CR.AE2ERatchetParams _ bobE2E :: CR.AE2ERatchetParams 'X448) <- pure $ smpDecode $ hexToBS bobE2EHex + Right (aliceInitParams, _) <- runExceptT $ CR.pqX3dhRcv pkAlice1 pkAlice2 Nothing bobE2E + let aliceRatchet = CR.initRcvRatchet (CR.RatchetVersions v v) pkAlice2 (aliceInitParams, Nothing) CR.PQSupportOff + ga <- C.newRandom + aliceTVar <- newTVarIO (ga, aliceRatchet, M.empty :: CR.SkippedMsgKeys) + pure (TestPeerHS aliceTVar, TestPeerJS hIn hOut ph) + +-- HS (receiver) <-> JS (sender), PQ KEM accepted +initHsJs_PQ :: IO (TestPeer, TestPeer) +initHsJs_PQ = do + g <- C.newRandom + let v = CR.currentE2EEncryptVersion + Version vNum = v + (pkAlice1, pkAlice2, alicePKem_@(Just _), e2eAlice) <- CR.generateRcvE2EParams @'X448 g v CR.PQSupportOn + let aliceE2EHex = bsToHex $ smpEncode e2eAlice + (hIn, hOut, ph) <- spawnJsRatchet + bobE2EHex <- either error pure . parseResponse =<< jsCmd hIn hOut ("INIT_SND " ++ show vNum ++ " accept " ++ aliceE2EHex) + Right (CR.AE2ERatchetParams _ bobE2E :: CR.AE2ERatchetParams 'X448) <- pure $ smpDecode $ hexToBS bobE2EHex + Right (aliceInitParams, aliceKemKp_) <- runExceptT $ CR.pqX3dhRcv pkAlice1 pkAlice2 alicePKem_ bobE2E + let aliceRatchet = CR.initRcvRatchet (CR.RatchetVersions v v) pkAlice2 (aliceInitParams, aliceKemKp_) CR.PQSupportOn + ga <- C.newRandom + aliceTVar <- newTVarIO (ga, aliceRatchet, M.empty :: CR.SkippedMsgKeys) + pure (TestPeerHS aliceTVar, TestPeerJS hIn hOut ph) + +-- JS (receiver) <-> JS (sender), no PQ +initTsTs_noPQ :: IO (TestPeer, TestPeer) +initTsTs_noPQ = do + let Version vNum = CR.currentE2EEncryptVersion + (hInA, hOutA, phA) <- spawnJsRatchet + aliceE2EHex <- either error pure . parseResponse =<< jsCmd hInA hOutA ("INIT_RCV " ++ show vNum ++ " 0") + (hInB, hOutB, phB) <- spawnJsRatchet + bobE2EHex <- either error pure . parseResponse =<< jsCmd hInB hOutB ("INIT_SND " ++ show vNum ++ " none " ++ aliceE2EHex) + completeResp <- jsCmd hInA hOutA ("COMPLETE " ++ bobE2EHex) + when (completeResp /= "ok") $ error $ "COMPLETE failed: " ++ completeResp + pure (TestPeerJS hInA hOutA phA, TestPeerJS hInB hOutB phB) + +-- JS (receiver) <-> JS (sender), PQ KEM accepted +initTsTs_PQ :: IO (TestPeer, TestPeer) +initTsTs_PQ = do + let Version vNum = CR.currentE2EEncryptVersion + (hInA, hOutA, phA) <- spawnJsRatchet + aliceE2EHex <- either error pure . parseResponse =<< jsCmd hInA hOutA ("INIT_RCV " ++ show vNum ++ " 1") + (hInB, hOutB, phB) <- spawnJsRatchet + bobE2EHex <- either error pure . parseResponse =<< jsCmd hInB hOutB ("INIT_SND " ++ show vNum ++ " accept " ++ aliceE2EHex) + completeResp <- jsCmd hInA hOutA ("COMPLETE " ++ bobE2EHex) + when (completeResp /= "ok") $ error $ "COMPLETE failed: " ++ completeResp + pure (TestPeerJS hInA hOutA phA, TestPeerJS hInB hOutB phB) + smpWebTests :: SpecWith () smpWebTests = describe "SMP Web Client" $ do distExists <- runIO $ doesDirectoryExist (smpWebDir <> "/dist") @@ -160,6 +402,67 @@ smpWebTests_ = do <> jsOut ("new Uint8Array([r.type === 'OK' ? 1 : 0])") tsResult `shouldBe` B.singleton 1 + describe "commands" $ do + let v = currentServerSMPRelayVersion + + it "encodeSUB matches Haskell" $ do + let hsEncoded = encodeProtocol v SUB + tsEncoded <- callNode $ impProto <> jsOut "encodeSUB()" + tsEncoded `shouldBe` hsEncoded + + it "encodeKEY matches Haskell" $ do + let keyDer = B.pack [0x30, 0x2a, 0x30, 0x05, 0x06, 0x03, 0x2b, 0x65, 0x6e, 0x03, 0x21, 0x00] <> B.pack [1..32] + hsEncoded = "KEY " <> smpEncode keyDer + tsEncoded <- callNode $ impProto + <> jsOut ("encodeKEY(" <> jsUint8 keyDer <> ")") + tsEncoded `shouldBe` hsEncoded + + it "encodeSKEY matches Haskell" $ do + let keyDer = B.pack [0x30, 0x2a, 0x30, 0x05, 0x06, 0x03, 0x2b, 0x65, 0x6e, 0x03, 0x21, 0x00] <> B.pack [1..32] + hsEncoded = "SKEY " <> smpEncode keyDer + tsEncoded <- callNode $ impProto + <> jsOut ("encodeSKEY(" <> jsUint8 keyDer <> ")") + tsEncoded `shouldBe` hsEncoded + + it "encodeACK matches Haskell" $ do + let msgId = B.pack [1..24] + hsEncoded = encodeProtocol v (ACK msgId) + tsEncoded <- callNode $ impProto + <> jsOut ("encodeACK(" <> jsUint8 msgId <> ")") + tsEncoded `shouldBe` hsEncoded + + it "encodeSEND matches Haskell" $ do + let flags = MsgFlags {notification = True} + body = "hello world" + hsEncoded = encodeProtocol v (SEND flags body) + tsEncoded <- callNode $ impProto + <> jsOut ("encodeSEND(true, new TextEncoder().encode('hello world'))") + tsEncoded `shouldBe` hsEncoded + + it "decodes IDS response" $ do + let rcvId = B.pack [1..24] + sndId = B.pack [25..48] + srvDhKey = B.pack [0x30, 0x2a, 0x30, 0x05, 0x06, 0x03, 0x2b, 0x65, 0x6e, 0x03, 0x21, 0x00] <> B.pack [50..81] + -- Manually encode IDS response: "IDS " <> rcvId <> sndId <> srvDhKey <> Maybe queueMode <> Maybe linkId ... + encoded = "IDS " <> smpEncode (EntityId rcvId) <> smpEncode (EntityId sndId) <> smpEncode srvDhKey + <> smpEncode (Nothing :: Maybe B.ByteString) <> smpEncode (Nothing :: Maybe B.ByteString) + <> smpEncode (Nothing :: Maybe B.ByteString) <> smpEncode (Nothing :: Maybe B.ByteString) + tsResult <- callNode $ impProto + <> "const r = decodeResponse(new Decoder(" <> jsUint8 encoded <> "));" + <> "if (r.type !== 'IDS') throw new Error('expected IDS, got ' + r.type);" + <> jsOut ("new Uint8Array([...r.response.rcvId, ...r.response.sndId])") + tsResult `shouldBe` (rcvId <> sndId) + + it "decodes Haskell-encoded MSG response" $ do + let msgId = B.pack [1..24] + body = "encrypted message body" + hsEncoded = "MSG " <> smpEncode msgId <> body + tsResult <- callNode $ impProto + <> "const r = decodeResponse(new Decoder(" <> jsUint8 hsEncoded <> "));" + <> "if (r.type !== 'MSG') throw new Error('expected MSG, got ' + r.type);" + <> jsOut ("new Uint8Array([...r.response.msgId, ...r.response.msgBody])") + tsResult `shouldBe` (msgId <> body) + describe "transport" $ do describe "SMPServerHandshake" $ do it "TypeScript parses Haskell-encoded server handshake (no authPubKey)" $ do @@ -237,6 +540,390 @@ smpWebTests_ = do <> jsOut ("new Uint8Array([...r.fixedData, 0, ...r.userData])") tsResult `shouldBe` (fixedPlain <> B.singleton 0 <> userPlain) + describe "crypto/sntrup761" $ do + it "TypeScript encapsulates, Haskell decapsulates - shared secret matches" $ do + g <- C.newRandom + (KEMPublicKey pkBytes, sk) <- sntrup761Keypair g + tsResult <- callNode $ impSntrup + <> "const enc = sntrup761Enc(" <> jsUint8 pkBytes <> ");" + <> jsOut ("new Uint8Array([...enc.ciphertext, ...enc.sharedSecret])") + let (ctBytes, tsSharedSecret) = B.splitAt 1039 tsResult + KEMSharedKey hsSharedSecret <- sntrup761Dec (KEMCiphertext ctBytes) sk + (BA.convert hsSharedSecret :: B.ByteString) `shouldBe` tsSharedSecret + + it "Haskell encapsulates, TypeScript decapsulates - shared secret matches" $ do + -- TypeScript generates keypair, passes public key to Haskell via stdout, + -- but callNode is one-shot. So: TypeScript generates keypair, outputs (pk, sk). + -- Then Haskell encapsulates against pk, passes (ct) to TypeScript. + -- TypeScript decapsulates with sk, outputs shared secret. + -- We compare with Haskell's shared secret. + -- + -- Two callNode calls: first to get keypair, second to decapsulate. + kpResult <- callNode $ impSntrup + <> "const kp = sntrup761Keypair();" + <> jsOut ("new Uint8Array([...kp.publicKey, ...kp.secretKey])") + let (tsPk, tsSk) = B.splitAt 1158 kpResult + g <- C.newRandom + (KEMCiphertext ctBytes, KEMSharedKey hsSharedSecret) <- sntrup761Enc g (KEMPublicKey tsPk) + tsResult <- callNode $ impSntrup + <> "const ss = sntrup761Dec(" <> jsUint8 ctBytes <> "," <> jsUint8 tsSk <> ");" + <> jsOut ("ss") + tsResult `shouldBe` (BA.convert hsSharedSecret :: B.ByteString) + + describe "crypto/aesGcm" $ do + it "Haskell encryptAEAD (16-byte IV), TypeScript decrypts" $ do + let key = C.Key $ B.pack [1..32] + iv = C.IV $ B.pack [1..16] + ad = "associated data" + msg = "hello from haskell aes-gcm" + Right (C.AuthTag authTag, ct) <- runExceptT $ C.encryptAEAD key iv 64 ad msg + let tagBytes = BA.convert authTag :: B.ByteString + tsResult <- callNode $ impEnc + <> "import { gcm } from '@noble/ciphers/aes.js';" + <> "const key = " <> jsUint8 (B.pack [1..32]) <> ";" + <> "const iv = " <> jsUint8 (B.pack [1..16]) <> ";" + <> "const ad = new TextEncoder().encode('associated data');" + <> "const ct = " <> jsUint8 ct <> ";" + <> "const tag = " <> jsUint8 tagBytes <> ";" + <> "const cipher = gcm(key, iv, ad);" + <> "const encrypted = new Uint8Array([...ct, ...tag]);" + <> "const decrypted = cipher.decrypt(encrypted);" + -- unpad: 2-byte BE length prefix + message + '#' padding + <> "const len = (decrypted[0] << 8) | decrypted[1];" + <> jsOut ("decrypted.subarray(2, 2 + len)") + tsResult `shouldBe` msg + + describe "crypto/ratchet" $ do + describe "X3DH" $ do + it "pqX3dhSnd and pqX3dhRcv produce same ratchetKey" $ do + -- TypeScript generates two key pairs, computes X3DH from both sides, verifies match + tsResult <- callNode $ impSodium <> impRatchet + <> "const alice1 = generateX448KeyPair();" + <> "const alice2 = generateX448KeyPair();" + <> "const bob1 = generateX448KeyPair();" + <> "const bob2 = generateX448KeyPair();" + -- Bob (joiner) inits sending ratchet with Alice's public keys + <> "const snd = pqX3dhSnd(bob1.privateKey, bob2.privateKey, alice1.publicKey, alice2.publicKey);" + -- Alice (initiator) inits receiving ratchet with Bob's public keys + <> "const rcv = pqX3dhRcv(alice1.privateKey, alice2.privateKey, bob1.publicKey, bob2.publicKey);" + -- ratchetKey, sndHK, rcvNextHK should match + <> "const match = snd.ratchetKey.every((b, i) => b === rcv.ratchetKey[i]) && snd.sndHK.every((b, i) => b === rcv.sndHK[i]) && snd.rcvNextHK.every((b, i) => b === rcv.rcvNextHK[i]);" + <> jsOut ("new Uint8Array([match ? 1 : 0, snd.ratchetKey.length, snd.sndHK.length, snd.rcvNextHK.length])") + tsResult `shouldBe` B.pack [1, 32, 32, 32] + + describe "chainKdf" $ do + it "TypeScript chainKdf produces correct output via HKDF" $ do + -- chainKdf is hkdf3("", ck, "SimpleXChainRatchet") split into 32+32+16+16 + -- Since hkdf is already tested against Haskell, test the split logic + tsResult <- callNode $ impRatchet + <> "const r = chainKdf(" <> jsUint8 (B.pack [1..32]) <> ");" + <> jsOut ("new Uint8Array([r.ck.length, r.mk.length, r.iv.length, r.ehIV.length])") + tsResult `shouldBe` B.pack [32, 32, 16, 16] + + describe "encryptAEAD" $ do + it "TypeScript encrypt matches Haskell encrypt (same ciphertext)" $ do + let key = C.Key $ B.pack [1..32] + iv = C.IV $ B.pack [1..16] + ad = "test associated data" + msg = "ratchet plaintext" + Right (C.AuthTag hsTag, hsCt) <- runExceptT $ C.encryptAEAD key iv 64 ad msg + let hsTagBytes = BA.convert hsTag :: B.ByteString + tsResult <- callNode $ impRatchet + <> "const r = encryptAEAD(" <> jsUint8 (B.pack [1..32]) <> "," <> jsUint8 (B.pack [1..16]) <> ",64," + <> "new TextEncoder().encode('test associated data')," + <> "new TextEncoder().encode('ratchet plaintext'));" + <> jsOut ("new Uint8Array([...r.authTag, ...r.ciphertext])") + tsResult `shouldBe` (hsTagBytes <> hsCt) + + it "TypeScript decrypts Haskell-encrypted" $ do + let key = C.Key $ B.pack [10..41] + iv = C.IV $ B.pack [10..25] + ad = "ad for decrypt test" + msg = "hello from haskell ratchet" + Right (C.AuthTag hsTag, hsCt) <- runExceptT $ C.encryptAEAD key iv 64 ad msg + let hsTagBytes = BA.convert hsTag :: B.ByteString + tsResult <- callNode $ impRatchet + <> "const plain = decryptAEAD(" <> jsUint8 (B.pack [10..41]) <> "," <> jsUint8 (B.pack [10..25]) <> "," + <> "new TextEncoder().encode('ad for decrypt test')," + <> jsUint8 hsCt <> "," <> jsUint8 hsTagBytes <> ");" + <> jsOut ("plain") + tsResult `shouldBe` msg + + it "Haskell decrypts TypeScript-encrypted" $ do + let key = C.Key $ B.pack [20..51] + iv = C.IV $ B.pack [20..35] + ad = "ad for ts encrypt" + msg = "hello from typescript ratchet" + tsResult <- callNode $ impRatchet + <> "const r = encryptAEAD(" <> jsUint8 (B.pack [20..51]) <> "," <> jsUint8 (B.pack [20..35]) <> ",64," + <> "new TextEncoder().encode('ad for ts encrypt')," + <> "new TextEncoder().encode('hello from typescript ratchet'));" + <> jsOut ("new Uint8Array([...r.authTag, ...r.ciphertext])") + let (tsTag, tsCt) = B.splitAt 16 tsResult + Right hsPlain <- runExceptT $ C.decryptAEAD key iv ad tsCt (C.AuthTag $ AES.AuthTag $ BA.convert tsTag) + hsPlain `shouldBe` msg + + describe "ratchet encrypt/decrypt" $ do + it "TypeScript ratchet self-consistency: encrypt, decrypt, ratchet advance, skipped" $ do + tsResult <- callNode $ impRatchet + <> "const a1 = generateX448KeyPair(), a2 = generateX448KeyPair();" + <> "const b1 = generateX448KeyPair(), b2 = generateX448KeyPair();" + <> "const bp = pqX3dhSnd(b1.privateKey, b2.privateKey, a1.publicKey, a2.publicKey);" + <> "const ap = pqX3dhRcv(a1.privateKey, a2.privateKey, b1.publicKey, b2.publicKey);" + <> "const b3 = generateX448KeyPair();" + <> "let bob = initSndRatchet({current:3,maxSupported:3}, a2.publicKey, b3.privateKey, bp, null);" + <> "let alice = initRcvRatchet({current:3,maxSupported:3}, a2.privateKey, ap, null, false);" + <> "let sk = new Map();" + -- Bob sends 3 + <> "const e1 = rcEncrypt(bob, new TextEncoder().encode('msg1'), 100); bob = e1.state;" + <> "const e2 = rcEncrypt(bob, new TextEncoder().encode('msg2'), 100); bob = e2.state;" + <> "const e3 = rcEncrypt(bob, new TextEncoder().encode('msg3'), 100); bob = e3.state;" + -- Alice decrypts msg3 first (skip 1,2) + <> "let d3 = rcDecrypt(alice, sk, e3.ciphertext); alice = d3.state; sk = d3.skippedKeys;" + -- Alice decrypts msg1 from skipped + <> "let d1 = rcDecrypt(alice, sk, e1.ciphertext); alice = d1.state; sk = d1.skippedKeys;" + -- Alice responds + <> "const ea = rcEncrypt(alice, new TextEncoder().encode('reply'), 100); alice = ea.state;" + <> "const da = rcDecrypt(bob, new Map(), ea.ciphertext); bob = da.state;" + -- Verify + <> "const ok = new TextDecoder().decode(d3.plaintext) === 'msg3'" + <> " && new TextDecoder().decode(d1.plaintext) === 'msg1'" + <> " && new TextDecoder().decode(da.plaintext) === 'reply';" + <> jsOut ("new Uint8Array([ok ? 1 : 0])") + tsResult `shouldBe` B.singleton 1 + + it "cross-language: Haskell encrypts, TypeScript decrypts" $ do + -- Round 1: TypeScript generates alice's keys, outputs private keys + smpEncoded E2E params + tsAliceOutput <- callNode $ impEnc <> impRatchet + <> "const a1 = generateX448KeyPair(), a2 = generateX448KeyPair();" + -- smpEncode E2ERatchetParams v3: (version, pk1, pk2, Maybe KEMParams) + -- Nothing = 0x30 ('0') + <> "const e2e = new Uint8Array([...encodeWord16(3), ...encodeBytes(encodePubKeyX448(a1.publicKey)), ...encodeBytes(encodePubKeyX448(a2.publicKey)), 0x30]);" + -- Output: a1.privateKey(56) + a2.privateKey(56) + e2e_len(2) + e2e_bytes + <> "const lenBuf = new Uint8Array(2); lenBuf[0] = (e2e.length >> 8) & 0xff; lenBuf[1] = e2e.length & 0xff;" + <> jsOut ("new Uint8Array([...a1.privateKey, ...a2.privateKey, ...lenBuf, ...e2e])") + let (alicePriv1, rest1) = B.splitAt 56 tsAliceOutput + (alicePriv2, rest2) = B.splitAt 56 rest1 + e2eLen = fromIntegral (B.index rest2 0) * 256 + fromIntegral (B.index rest2 1) + aliceE2EBytes = B.take e2eLen $ B.drop 2 rest2 + + -- Round 2: Haskell decodes alice's E2E params, generates bob, encrypts + g <- C.newRandom + let v = CR.currentE2EEncryptVersion + Right (aliceE2E@(CR.E2ERatchetParams _ _ alicePk2 _) :: CR.E2ERatchetParams 'CR.RKSProposed 'X448) <- pure $ smpDecode aliceE2EBytes + (bobPk1, bobPk2, _pKem, CR.AE2ERatchetParams _ bobE2E) <- CR.generateSndE2EParams @'X448 g v Nothing + Right (bobInitParams, _) <- pure $ CR.pqX3dhSnd bobPk1 bobPk2 Nothing aliceE2E + (_, bobDHRs) <- atomically $ C.generateKeyPair @'X448 g + let bobRatchet = CR.initSndRatchet (CR.RatchetVersions v v) alicePk2 bobDHRs (bobInitParams, Nothing) + Right (mek, _) <- runExceptT $ CR.rcEncryptHeader bobRatchet Nothing v + Right ciphertext <- runExceptT $ CR.rcEncryptMsg mek paddedMsgLen "hello from haskell ratchet" + let bobE2EBytes = smpEncode bobE2E + + -- Round 3: TypeScript decodes bob's params, inits ratchet, decrypts + tsResult <- callNode $ impEnc <> impRatchet + -- Parse bob's E2E params + <> "const d = new Decoder(" <> jsUint8 bobE2EBytes <> ");" + <> "const bobV = d.anyByte() * 256 + d.anyByte();" + <> "const bobPk1Raw = decodePubKeyX448(decodeBytes(d));" + <> "const bobPk2Raw = decodePubKeyX448(decodeBytes(d));" + <> "const a1Priv = " <> jsUint8 alicePriv1 <> ";" + <> "const a2Priv = " <> jsUint8 alicePriv2 <> ";" + <> "const ap = pqX3dhRcv(a1Priv, a2Priv, bobPk1Raw, bobPk2Raw);" + <> "const alice = initRcvRatchet({current:3,maxSupported:3}, a2Priv, ap, null, false);" + <> "const dec = rcDecrypt(alice, new Map(), " <> jsUint8 ciphertext <> ");" + <> jsOut ("dec.plaintext") + tsResult `shouldBe` "hello from haskell ratchet" + + it "cross-language: TypeScript encrypts, Haskell decrypts" $ do + -- Round 1: Haskell generates alice's keys, outputs encoded E2E params + g <- C.newRandom + let v = CR.currentE2EEncryptVersion + (alicePk1, alicePk2, _pKem, aliceE2E) <- CR.generateRcvE2EParams @'X448 g v CR.PQSupportOff + let aliceE2EBytes = smpEncode aliceE2E + + -- Round 2: TypeScript generates bob's keys, does X3DH, inits snd ratchet, encrypts + tsOutput <- callNode $ impEnc <> impRatchet + -- Parse alice's E2E params + <> "const d = new Decoder(" <> jsUint8 aliceE2EBytes <> ");" + <> "const aliceV = d.anyByte() * 256 + d.anyByte();" + <> "const alicePk1Raw = decodePubKeyX448(decodeBytes(d));" + <> "const alicePk2Raw = decodePubKeyX448(decodeBytes(d));" + -- Bob generates keys + <> "const b1 = generateX448KeyPair(), b2 = generateX448KeyPair();" + <> "const b3 = generateX448KeyPair();" + -- X3DH (bob is sender) + <> "const bp = pqX3dhSnd(b1.privateKey, b2.privateKey, alicePk1Raw, alicePk2Raw);" + -- Init sending ratchet + <> "let bob = initSndRatchet({current:3,maxSupported:3}, alicePk2Raw, b3.privateKey, bp, null);" + -- Encrypt + <> "const enc = rcEncrypt(bob, new TextEncoder().encode('hello from typescript ratchet'), 100);" + -- Output: bob's E2E params (version + 2 DER keys + Nothing KEM) + ciphertext + <> "const bobE2E = new Uint8Array([...encodeWord16(3), ...encodeBytes(encodePubKeyX448(b1.publicKey)), ...encodeBytes(encodePubKeyX448(b2.publicKey)), 0x30]);" + <> "const lenBuf = new Uint8Array(2); lenBuf[0] = (bobE2E.length >> 8) & 0xff; lenBuf[1] = bobE2E.length & 0xff;" + <> "const ctLenBuf = new Uint8Array(2); ctLenBuf[0] = (enc.ciphertext.length >> 8) & 0xff; ctLenBuf[1] = enc.ciphertext.length & 0xff;" + <> jsOut ("new Uint8Array([...lenBuf, ...bobE2E, ...ctLenBuf, ...enc.ciphertext])") + -- Parse output: [2 bytes e2e len][e2e bytes][2 bytes ct len][ct bytes] + let (e2eLenBs, rest1) = B.splitAt 2 tsOutput + bobE2ELen = fromIntegral (B.index e2eLenBs 0) * 256 + fromIntegral (B.index e2eLenBs 1) + (bobE2EBytes, rest2) = B.splitAt bobE2ELen rest1 + (ctLenBs, rest3) = B.splitAt 2 rest2 + ctLen = fromIntegral (B.index ctLenBs 0) * 256 + fromIntegral (B.index ctLenBs 1) + ciphertext = B.take ctLen rest3 + + -- Round 3: Haskell decodes bob's params, does X3DH, inits rcv ratchet, decrypts + Right (CR.AE2ERatchetParams _ bobE2EParams :: CR.AE2ERatchetParams 'X448) <- pure $ smpDecode bobE2EBytes + Right (aliceInitParams, _) <- runExceptT $ CR.pqX3dhRcv alicePk1 alicePk2 Nothing bobE2EParams + let aliceRatchet = CR.initRcvRatchet (CR.RatchetVersions v v) alicePk2 (aliceInitParams, Nothing) CR.PQSupportOff + gAlice <- C.newRandom + Right (msg, _, _) <- runExceptT $ CR.rcDecrypt gAlice aliceRatchet M.empty ciphertext + msg `shouldBe` Right "hello from typescript ratchet" + + it "cross-language: PQ X3DH - Haskell proposes KEM, TypeScript accepts, encrypts" $ do + -- Round 1: Haskell (alice) generates keys with PQ KEM proposal + g <- C.newRandom + let v = CR.currentE2EEncryptVersion + (alicePk1, alicePk2, alicePKem_@(Just _), aliceE2E) <- CR.generateRcvE2EParams @'X448 g v CR.PQSupportOn + let aliceE2EBytes = smpEncode aliceE2E + + -- Round 2: TypeScript (bob) accepts KEM, does X3DH, inits snd ratchet, encrypts + tsOutput <- callNode $ impEnc <> impSodium <> impRatchet <> impSntrup + -- Parse alice's E2E params (v3: version + pk1 + pk2 + Maybe ARKEMParams) + <> "const d = new Decoder(" <> jsUint8 aliceE2EBytes <> ");" + <> "const aliceV = d.anyByte() * 256 + d.anyByte();" + <> "const alicePk1Raw = decodePubKeyX448(decodeBytes(d));" + <> "const alicePk2Raw = decodePubKeyX448(decodeBytes(d));" + -- Parse Maybe ARKEMParams: '1' + 'P' + KEMPublicKey(Large) + <> "const maybeByte = d.anyByte();" + <> "if (maybeByte !== 0x31) throw new Error('expected Just KEM');" + <> "const kemTag = d.anyByte();" + <> "if (kemTag !== 0x50) throw new Error('expected P (proposed), got ' + kemTag);" + <> "const aliceKemPk = decodeLarge(d);" + -- Bob generates DH keys + <> "const b1 = generateX448KeyPair(), b2 = generateX448KeyPair();" + <> "const b3 = generateX448KeyPair();" + -- Bob encapsulates against alice's KEM public key + <> "const kemEnc = sntrup761Enc(aliceKemPk);" + -- Bob generates his own KEM keypair for future ratchet steps + <> "const bobKem = sntrup761Keypair();" + -- Construct kemAccepted matching Haskell RatchetKEMAccepted: + -- rcPQRr = alice's KEM public key (received) + -- rcPQRss = shared secret (from encapsulation) + -- rcPQRct = ciphertext (sent to alice) + <> "const kemAccepted = {rcPQRr: aliceKemPk, rcPQRss: kemEnc.sharedSecret, rcPQRct: kemEnc.ciphertext};" + -- X3DH with kemAccepted (folds shared secret into HKDF AND stores in RatchetInitParams) + <> "const bp = pqX3dhSnd(b1.privateKey, b2.privateKey, alicePk1Raw, alicePk2Raw, kemAccepted);" + -- Init sending ratchet with bob's KEM keypair + <> "let bob = initSndRatchet({current:3,maxSupported:3}, alicePk2Raw, b3.privateKey, bp, bobKem);" + -- Encrypt + <> "const enc = rcEncrypt(bob, new TextEncoder().encode('hello with PQ'), 100);" + -- Build bob's E2E params: version + pk1 + pk2 + Just(Accepted(ct, bobKemPk)) + -- smpEncode ('A', ct, bobKemPk) where ct and pk are Large-encoded + <> "const bobE2E = new Uint8Array([" + <> " ...encodeWord16(3)," + <> " ...encodeBytes(encodePubKeyX448(b1.publicKey))," + <> " ...encodeBytes(encodePubKeyX448(b2.publicKey))," + <> " 0x31," -- Just + <> " 0x41," -- 'A' = Accepted + <> " ...new Uint8Array([(kemEnc.ciphertext.length >> 8) & 0xff, kemEnc.ciphertext.length & 0xff]), ...kemEnc.ciphertext," + <> " ...new Uint8Array([(bobKem.publicKey.length >> 8) & 0xff, bobKem.publicKey.length & 0xff]), ...bobKem.publicKey," + <> "]);" + <> "const lenBuf = new Uint8Array(2); lenBuf[0] = (bobE2E.length >> 8) & 0xff; lenBuf[1] = bobE2E.length & 0xff;" + <> "const ctLenBuf = new Uint8Array(2); ctLenBuf[0] = (enc.ciphertext.length >> 8) & 0xff; ctLenBuf[1] = enc.ciphertext.length & 0xff;" + <> jsOut ("new Uint8Array([...lenBuf, ...bobE2E, ...ctLenBuf, ...enc.ciphertext])") + let (e2eLenBs, rest1) = B.splitAt 2 tsOutput + bobE2ELen = fromIntegral (B.index e2eLenBs 0) * 256 + fromIntegral (B.index e2eLenBs 1) + (bobE2EBytes, rest2) = B.splitAt bobE2ELen rest1 + (ctLenBs, rest3) = B.splitAt 2 rest2 + ctLen = fromIntegral (B.index ctLenBs 0) * 256 + fromIntegral (B.index ctLenBs 1) + ciphertext = B.take ctLen rest3 + + -- Round 3: Haskell decodes bob's params (with KEM accepted), does X3DH with KEM, decrypts + Right (CR.AE2ERatchetParams _ bobE2EParams :: CR.AE2ERatchetParams 'X448) <- pure $ smpDecode bobE2EBytes + Right (aliceInitParams, aliceKemKp_) <- runExceptT $ CR.pqX3dhRcv alicePk1 alicePk2 alicePKem_ bobE2EParams + let aliceRatchet = CR.initRcvRatchet (CR.RatchetVersions v v) alicePk2 (aliceInitParams, aliceKemKp_) CR.PQSupportOn + gAlice <- C.newRandom + result <- runExceptT $ CR.rcDecrypt gAlice aliceRatchet M.empty ciphertext + case result of + Right (msg, _, _) -> msg `shouldBe` Right "hello with PQ" + Left e -> expectationFailure $ "rcDecrypt failed: " <> show e + + it "TypeScript PQ ratchet self-consistency: multi-message with KEM ratchet steps" $ do + tsResult <- callNode $ impSodium <> impSntrup <> impRatchet + <> "const a1 = generateX448KeyPair(), a2 = generateX448KeyPair();" + <> "const b1 = generateX448KeyPair(), b2 = generateX448KeyPair();" + <> "const b3 = generateX448KeyPair();" + -- Alice proposes KEM + <> "const aliceKem = sntrup761Keypair();" + -- Bob accepts: encapsulate against alice's KEM public key + <> "const kemEnc = sntrup761Enc(aliceKem.publicKey);" + <> "const bobKem = sntrup761Keypair();" + <> "const kemAccepted = {rcPQRr: aliceKem.publicKey, rcPQRss: kemEnc.sharedSecret, rcPQRct: kemEnc.ciphertext};" + -- Alice receives bob's acceptance: decapsulate to get shared secret + <> "const aliceSS = sntrup761Dec(kemEnc.ciphertext, aliceKem.secretKey);" + <> "const aliceKemAccepted = {rcPQRr: bobKem.publicKey, rcPQRss: aliceSS, rcPQRct: kemEnc.ciphertext};" + -- X3DH for both sides + <> "const bp = pqX3dhSnd(b1.privateKey, b2.privateKey, a1.publicKey, a2.publicKey, kemAccepted);" + <> "const ap = pqX3dhRcv(a1.privateKey, a2.privateKey, b1.publicKey, b2.publicKey, aliceKemAccepted);" + -- Init ratchets with KEM keypairs + <> "let bob = initSndRatchet({current:3,maxSupported:3}, a2.publicKey, b3.privateKey, bp, bobKem);" + <> "let alice = initRcvRatchet({current:3,maxSupported:3}, a2.privateKey, ap, aliceKem, true);" + <> "let sk = new Map();" + -- Bob sends msg1 (has KEM params in header from initSndRatchet) + <> "const e1 = rcEncrypt(bob, new TextEncoder().encode('pq msg1'), 100); bob = e1.state;" + -- Alice decrypts msg1 (triggers ratchet advance with KEM) + <> "let d1 = rcDecrypt(alice, sk, e1.ciphertext); alice = d1.state; sk = d1.skippedKeys;" + -- Alice sends msg2 (ratchet advanced, has KEM params from pqRatchetStep) + <> "const e2 = rcEncrypt(alice, new TextEncoder().encode('pq msg2'), 100); alice = e2.state;" + -- Bob decrypts msg2 (triggers ratchet advance with KEM on bob's side) + <> "let d2 = rcDecrypt(bob, new Map(), e2.ciphertext); bob = d2.state;" + -- Bob sends msg3 (another ratchet advance with KEM) + <> "const e3 = rcEncrypt(bob, new TextEncoder().encode('pq msg3'), 100); bob = e3.state;" + -- Alice decrypts msg3 + <> "let d3 = rcDecrypt(alice, sk, e3.ciphertext); alice = d3.state; sk = d3.skippedKeys;" + -- Verify all messages + <> "const ok = new TextDecoder().decode(d1.plaintext) === 'pq msg1'" + <> " && new TextDecoder().decode(d2.plaintext) === 'pq msg2'" + <> " && new TextDecoder().decode(d3.plaintext) === 'pq msg3'" + -- Verify KEM state is maintained + <> " && alice.rcKEM !== null && bob.rcKEM !== null" + <> " && alice.rcSndKEM === true && bob.rcSndKEM === true;" + <> jsOut ("new Uint8Array([ok ? 1 : 0])") + tsResult `shouldBe` B.singleton 1 + + describe "DER encoding" $ do + it "X448 DER round-trips" $ do + tsResult <- callNode $ impRatchet + <> "const kp = generateX448KeyPair();" + <> "const der = encodePubKeyX448(kp.publicKey);" + <> "const raw = decodePubKeyX448(der);" + <> "const match = kp.publicKey.every((b, i) => b === raw[i]);" + <> jsOut ("new Uint8Array([match ? 1 : 0, der.length, raw.length])") + tsResult `shouldBe` B.pack [1, 68, 56] + + describe "cross-language ratchet advance" $ do + let run initPeers op test = withCrossPeers initPeers $ \(alice, bob) -> + test alice bob tpEncrypt tpDecrypt op + describe "HS rcv, JS snd, no PQ" $ do + it "encrypt and decrypt" $ run initHsJs_noPQ tp_noKEM testEncryptDecrypt + it "skipped messages" $ run initHsJs_noPQ tp_noKEM testSkippedMessages + it "many messages" $ run initHsJs_noPQ tp_noKEM testManyMessages + it "skipped after ratchet advance" $ run initHsJs_noPQ tp_noKEM testSkippedAfterRatchetAdvance + describe "HS rcv, JS snd, PQ" $ do + it "encrypt and decrypt" $ run initHsJs_PQ tp_hasKEM testEncryptDecrypt + it "skipped messages" $ run initHsJs_PQ tp_hasKEM testSkippedMessages + it "many messages" $ run initHsJs_PQ tp_hasKEM testManyMessages + it "skipped after ratchet advance" $ run initHsJs_PQ tp_hasKEM testSkippedAfterRatchetAdvance + describe "JS rcv, JS snd, no PQ" $ do + it "encrypt and decrypt" $ run initTsTs_noPQ tp_noKEM testEncryptDecrypt + it "skipped messages" $ run initTsTs_noPQ tp_noKEM testSkippedMessages + it "many messages" $ run initTsTs_noPQ tp_noKEM testManyMessages + it "skipped after ratchet advance" $ run initTsTs_noPQ tp_noKEM testSkippedAfterRatchetAdvance + describe "JS rcv, JS snd, PQ" $ do + it "encrypt and decrypt" $ run initTsTs_PQ tp_hasKEM testEncryptDecrypt + it "skipped messages" $ run initTsTs_PQ tp_hasKEM testSkippedMessages + it "many messages" $ run initTsTs_PQ tp_hasKEM testManyMessages + it "skipped after ratchet advance" $ run initTsTs_PQ tp_hasKEM testSkippedAfterRatchetAdvance + describe "crypto/blockEncryption" $ do describe "sbcInit + sbcHkdf" $ do it "TypeScript produces same sbKey/nonce via sbcInit+sbcHkdf as Haskell" $ do @@ -454,3 +1141,303 @@ smpWebTests_ = do -- First byte: rootKey DER length (44 for Ed25519), rest: userData B.head tsResult `shouldBe` 44 B.tail tsResult `shouldBe` testData + + describe "agent/message" $ do + describe "AMessage" $ do + it "HELLO encoding matches Haskell" $ do + let hsBytes = smpEncode AP.HELLO + tsBytes <- callNode $ impAgentMsg <> jsOut "encodeAMessage({type: 'HELLO'})" + tsBytes `shouldBe` hsBytes + + it "A_MSG encoding matches Haskell" $ do + let body = "hello world from agent" + hsBytes = smpEncode (AP.A_MSG body) + tsBytes <- callNode $ impAgentMsg <> jsOut ("encodeAMessage({type: 'A_MSG', body: new TextEncoder().encode('hello world from agent')})") + tsBytes `shouldBe` hsBytes + + it "EREADY encoding matches Haskell" $ do + let hsBytes = smpEncode (AP.EREADY 42) + tsBytes <- callNode $ impEnc <> impAgentMsg <> jsOut ("encodeAMessage({type: 'EREADY', lastDecryptedMsgId: 42n})") + tsBytes `shouldBe` hsBytes + + it "TypeScript decodes Haskell A_MSG" $ do + let body = "decode test" + hsBytes = smpEncode (AP.A_MSG body) + tsResult <- callNode $ impEnc <> impAgentMsg + <> "const msg = decodeAMessage(new Decoder(" <> jsUint8 hsBytes <> "));" + <> jsOut ("msg.body") + tsResult `shouldBe` body + + describe "APrivHeader" $ do + it "encoding matches Haskell" $ do + let hdr = AP.APrivHeader 1 (B.replicate 32 0xAB) + hsBytes = smpEncode hdr + tsBytes <- callNode $ impEnc <> impAgentMsg + <> jsOut ("encodeAPrivHeader({sndMsgId: 1n, prevMsgHash: " <> jsUint8 (B.replicate 32 0xAB) <> "})") + tsBytes `shouldBe` hsBytes + + describe "AgentMessage" $ do + it "M variant encoding matches Haskell" $ do + let hdr = AP.APrivHeader 5 (B.replicate 32 0) + msg = AP.AgentMessage hdr (AP.A_MSG "test body") + hsBytes = smpEncode msg + tsBytes <- callNode $ impEnc <> impAgentMsg + <> jsOut ("encodeAgentMessage({type: 'message', header: {sndMsgId: 5n, prevMsgHash: new Uint8Array(32)}, msg: {type: 'A_MSG', body: new TextEncoder().encode('test body')}})") + tsBytes `shouldBe` hsBytes + + it "TypeScript decodes Haskell AgentMessage M" $ do + let hdr = AP.APrivHeader 99 (B.pack [1..32]) + msg = AP.AgentMessage hdr (AP.A_MSG "cross-language message") + hsBytes = smpEncode msg + tsResult <- callNode $ impEnc <> impAgentMsg + <> "const m = decodeAgentMessage(new Decoder(" <> jsUint8 hsBytes <> "));" + <> "if (m.type !== 'message') throw new Error('expected message');" + <> "if (m.msg.type !== 'A_MSG') throw new Error('expected A_MSG');" + <> jsOut ("new Uint8Array([...new TextEncoder().encode(m.header.sndMsgId.toString()), 0, ...m.msg.body])") + let (idStr, rest) = B.break (== 0) tsResult + idStr `shouldBe` "99" + B.tail rest `shouldBe` "cross-language message" + + describe "AgentMsgEnvelope" $ do + it "M variant encoding matches Haskell" $ do + let env = AP.AgentMsgEnvelope {AP.agentVersion = AP.currentSMPAgentVersion, AP.encAgentMessage = "encrypted payload"} + hsBytes = smpEncode env + tsBytes <- callNode $ impEnc <> impAgentMsg + <> jsOut ("encodeAgentMsgEnvelope({type: 'envelope', agentVersion: 7, encAgentMessage: new TextEncoder().encode('encrypted payload')})") + tsBytes `shouldBe` hsBytes + + it "TypeScript decodes Haskell AgentMsgEnvelope M" $ do + let env = AP.AgentMsgEnvelope {AP.agentVersion = AP.currentSMPAgentVersion, AP.encAgentMessage = "decrypt me"} + hsBytes = smpEncode env + tsResult <- callNode $ impEnc <> impAgentMsg + <> "const e = decodeAgentMsgEnvelope(new Decoder(" <> jsUint8 hsBytes <> "));" + <> "if (e.type !== 'envelope') throw new Error('expected envelope');" + <> jsOut ("e.encAgentMessage") + tsResult `shouldBe` "decrypt me" + + describe "protocol/e2e" $ do + describe "PubHeader" $ do + it "encoding without key matches Haskell" $ do + let h = PubHeader (VersionSMPC 19) Nothing + hsBytes = smpEncode h + tsBytes <- callNode $ impEnc <> impProtoE2E <> jsOut "encodePubHeader({phVersion: 19, phE2ePubDhKey: null})" + tsBytes `shouldBe` hsBytes + + it "encoding with key matches Haskell" $ do + g <- C.newRandom + (k, _) <- atomically $ C.generateKeyPair @'C.X25519 g + let derKey = C.encodePubKey k -- raw DER bytes without smpEncode length prefix + h = PubHeader (VersionSMPC 19) (Just k) + hsBytes = smpEncode h + tsBytes <- callNode $ impEnc <> impProtoE2E + <> jsOut ("encodePubHeader({phVersion: 19, phE2ePubDhKey: " <> jsUint8 derKey <> "})") + tsBytes `shouldBe` hsBytes + + describe "PrivHeader" $ do + it "PHEmpty encoding matches Haskell" $ do + let hsBytes = smpEncode PHEmpty + tsBytes <- callNode $ impProtoE2E <> jsOut "encodePrivHeader({type: 'PHEmpty'})" + tsBytes `shouldBe` hsBytes + + describe "ClientMessage" $ do + it "encoding matches Haskell" $ do + let body = "agent envelope bytes here" + msg = ClientMessage PHEmpty body + hsBytes = smpEncode msg + tsBytes <- callNode $ impEnc <> impProtoE2E + <> jsOut ("encodeClientMessage({privHeader: {type: 'PHEmpty'}, body: new TextEncoder().encode('agent envelope bytes here')})") + tsBytes `shouldBe` hsBytes + + describe "ClientMsgEnvelope" $ do + it "encoding matches Haskell" $ do + let nonce = C.cbNonce $ B.pack [1..24] + h = PubHeader (VersionSMPC 19) Nothing + env = ClientMsgEnvelope {cmHeader = h, cmNonce = nonce, cmEncBody = "encrypted body data"} + hsBytes = smpEncode env + tsBytes <- callNode $ impEnc <> impProtoE2E + <> jsOut ("encodeClientMsgEnvelope({cmHeader: {phVersion: 19, phE2ePubDhKey: null}, cmNonce: " <> jsUint8 (B.pack [1..24]) <> ", cmEncBody: new TextEncoder().encode('encrypted body data')})") + tsBytes `shouldBe` hsBytes + + it "TypeScript decodes Haskell-encoded" $ do + let nonce = C.cbNonce $ B.pack [10..33] + h = PubHeader (VersionSMPC 19) Nothing + env = ClientMsgEnvelope {cmHeader = h, cmNonce = nonce, cmEncBody = "test ciphertext"} + hsBytes = smpEncode env + tsResult <- callNode $ impEnc <> impProtoE2E + <> "const env = decodeClientMsgEnvelope(new Decoder(" <> jsUint8 hsBytes <> "));" + <> jsOut ("new Uint8Array([env.cmHeader.phVersion >> 8, env.cmHeader.phVersion & 0xff, env.cmHeader.phE2ePubDhKey === null ? 1 : 0, ...env.cmNonce, ...env.cmEncBody])") + let (version, rest1) = B.splitAt 2 tsResult + (nullByte, rest2) = B.splitAt 1 rest1 + (nonceBytes, bodyBytes) = B.splitAt 24 rest2 + version `shouldBe` B.pack [0, 19] + nullByte `shouldBe` B.singleton 1 + nonceBytes `shouldBe` B.pack [10..33] + bodyBytes `shouldBe` "test ciphertext" + + describe "per-queue E2E encrypt/decrypt" $ do + it "TypeScript encrypts, Haskell decrypts" $ do + -- Haskell generates receiver keypair + g <- C.newRandom + (rcvPub, rcvPriv) <- atomically $ C.generateKeyPair @'C.X25519 g + let rcvPubRaw = C.pubKeyBytes rcvPub + -- TypeScript generates sender keypair, computes DH, encrypts + tsOutput <- callNode $ impSodium <> impEnc <> impProtoE2E + <> "import { generateX25519KeyPair, dh, encodePubKeyX25519 } from '@simplex-chat/xftp-web/dist/crypto/keys.js';" + <> "const sndKp = generateX25519KeyPair();" + <> "const rcvPub = " <> jsUint8 rcvPubRaw <> ";" + <> "const dhSecret = dh(rcvPub, sndKp.privateKey);" + <> "const clientMsg = encodeClientMessage({privHeader: {type: 'PHEmpty'}, body: new TextEncoder().encode('hello from typescript e2e')});" + <> "const encrypted = agentCbEncrypt(dhSecret, 19, null, clientMsg);" + -- Output: DER-encoded sndPubKey (ByteString-encoded: 1-byte len + DER) + encrypted + <> "const sndDer = encodePubKeyX25519(sndKp.publicKey);" + <> jsOut ("new Uint8Array([sndDer.length, ...sndDer, ...encrypted])") + -- Parse output: [1 byte len][DER sndPubKey][encrypted] + let sndDerLen = fromIntegral $ B.head tsOutput + (sndDerBytes, encrypted) = B.splitAt sndDerLen $ B.drop 1 tsOutput + -- Haskell decodes sender's DER public key and decrypts + let decoded = do + apk <- C.decodePubKey sndDerBytes + dhSecret <- case apk of + C.APublicKey C.SX25519 pk -> Right $ C.dh' pk rcvPriv + _ -> Left "not X25519" + cme <- smpDecode encrypted + plaintext <- first show $ C.cbDecrypt dhSecret (cmNonce cme) (cmEncBody cme) + cm <- smpDecode plaintext + case cm of + ClientMessage PHEmpty body -> Right body + _ -> Left "unexpected PrivHeader" + decoded `shouldBe` Right "hello from typescript e2e" + + describe "full-stack" $ do + it "Haskell encodes all layers, TypeScript decodes" $ do + g <- C.newRandom + let v = CR.currentE2EEncryptVersion + -- Alice (receiver) ratchet keys - extract raw private bytes for TypeScript + (alicePk1, alicePk2, Nothing, e2eAlice) <- CR.generateRcvE2EParams @'X448 g v CR.PQSupportOff + let C.PrivateKeyX448 sk1 = alicePk1; alicePriv1 = BA.convert sk1 :: B.ByteString + C.PrivateKeyX448 sk2 = alicePk2; alicePriv2 = BA.convert sk2 :: B.ByteString + -- Bob (sender) ratchet: X3DH + init + (bobPk1, bobPk2, Nothing, CR.AE2ERatchetParams _ e2eBob) <- CR.generateSndE2EParams @'X448 g v Nothing + Right bobInitParams <- pure $ CR.pqX3dhSnd bobPk1 bobPk2 Nothing e2eAlice + (_, bobDHRs) <- atomically $ C.generateKeyPair @'X448 g + let bobRatchet = CR.initSndRatchet (CR.RatchetVersions v v) (C.publicKey alicePk2) bobDHRs bobInitParams + bobE2EBytes = smpEncode e2eBob + -- Per-queue E2E: shared DH secret (pass raw bytes to both sides) + (_, e2eSndPriv) <- atomically $ C.generateKeyPair @'C.X25519 g + (e2eRcvPub, _) <- atomically $ C.generateKeyPair @'C.X25519 g + let dhSecret = C.dh' e2eRcvPub e2eSndPriv + dhSecretBytes = C.dhBytes' dhSecret + -- Haskell: encode A_MSG through all layers + let aMsg = AP.AgentMessage (AP.APrivHeader 1 (B.replicate 32 0)) (AP.A_MSG "hello full stack") + agentMsgBytes = smpEncode aMsg + Right (mek, _) <- runExceptT $ CR.rcEncryptHeader bobRatchet Nothing CR.currentE2EEncryptVersion + Right encAgentMsg <- runExceptT $ CR.rcEncryptMsg mek (AP.e2eEncAgentMsgLength AP.currentSMPAgentVersion CR.PQSupportOff) agentMsgBytes + let envBytes = smpEncode $ AP.AgentMsgEnvelope {AP.agentVersion = AP.currentSMPAgentVersion, AP.encAgentMessage = encAgentMsg} + clientMsgBytes = smpEncode $ ClientMessage PHEmpty envBytes + cmNonce <- atomically $ C.randomCbNonce g + Right cmEncBody <- pure $ C.cbEncrypt dhSecret cmNonce clientMsgBytes 16000 + let cmeBytes = smpEncode $ ClientMsgEnvelope (PubHeader (VersionSMPC 19) Nothing) cmNonce cmEncBody + -- TypeScript: init alice ratchet + decode all layers + tsResult <- callNode $ impSodium <> impEnc <> impRatchet <> impAgentMsg <> impProtoE2E + <> "import { cbDecrypt } from '@simplex-chat/xftp-web/dist/crypto/secretbox.js';" + -- Init alice's receiver ratchet + <> "const a1Priv = " <> jsUint8 alicePriv1 <> ";" + <> "const a2Priv = " <> jsUint8 alicePriv2 <> ";" + <> "const rd = new Decoder(" <> jsUint8 bobE2EBytes <> ");" + <> "rd.anyByte(); rd.anyByte();" -- skip version + <> "const bpk1 = decodePubKeyX448(decodeBytes(rd));" + <> "const bpk2 = decodePubKeyX448(decodeBytes(rd));" + <> "const ap = pqX3dhRcv(a1Priv, a2Priv, bpk1, bpk2);" + <> "let alice = initRcvRatchet({current:3,maxSupported:3}, a2Priv, ap, null, false);" + <> "let sk = new Map();" + -- Layer 1: per-queue E2E decrypt + <> "const dhSecret = " <> jsUint8 dhSecretBytes <> ";" + <> "const {clientMessage} = agentCbDecrypt(dhSecret, " <> jsUint8 cmeBytes <> ");" + -- Layer 2: decode AgentMsgEnvelope + <> "const env = decodeAgentMsgEnvelope(new Decoder(clientMessage.body));" + <> "if (env.type !== 'envelope') throw new Error('expected envelope, got ' + env.type);" + -- Layer 3: ratchet decrypt + <> "const dec = rcDecrypt(alice, sk, env.encAgentMessage);" + -- Layer 4: decode AgentMessage + AMessage + <> "const agentMsg = decodeAgentMessage(new Decoder(dec.plaintext));" + <> "if (agentMsg.type !== 'message') throw new Error('expected message, got ' + agentMsg.type);" + <> "if (agentMsg.msg.type !== 'A_MSG') throw new Error('expected A_MSG, got ' + agentMsg.msg.type);" + <> jsOut ("agentMsg.msg.body") + tsResult `shouldBe` "hello full stack" + + it "TypeScript encodes all layers, Haskell decodes" $ do + g <- C.newRandom + let v = CR.currentE2EEncryptVersion + -- Alice (receiver): Haskell generates ratchet rcv params + X25519 for per-queue E2E + (alicePk1, alicePk2, Nothing, e2eAlice) <- CR.generateRcvE2EParams @'X448 g v CR.PQSupportOff + let aliceE2EBytes = smpEncode e2eAlice + (e2eRcvPub, e2eRcvPriv) <- atomically $ C.generateKeyPair @'C.X25519 g + -- TypeScript: init bob's sender ratchet, encode full stack, output keys + ciphertext + tsOutput <- callNode $ impSodium <> impEnc <> impRatchet <> impAgentMsg <> impProtoE2E + <> "import { generateX25519KeyPair, dh, encodePubKeyX25519 } from '@simplex-chat/xftp-web/dist/crypto/keys.js';" + -- Init bob's sender ratchet + <> "const d = new Decoder(" <> jsUint8 aliceE2EBytes <> ");" + <> "const aliceV = d.anyByte() * 256 + d.anyByte();" + <> "const alicePk1Raw = decodePubKeyX448(decodeBytes(d));" + <> "const alicePk2Raw = decodePubKeyX448(decodeBytes(d));" + <> "const b1 = generateX448KeyPair(), b2 = generateX448KeyPair(), b3 = generateX448KeyPair();" + <> "const bp = pqX3dhSnd(b1.privateKey, b2.privateKey, alicePk1Raw, alicePk2Raw);" + <> "let bob = initSndRatchet({current:3,maxSupported:3}, alicePk2Raw, b3.privateKey, bp, null);" + -- Layer 4: encode AgentMessage + <> "const agentMsg = encodeAgentMessage({type: 'message', header: {sndMsgId: 1n, prevMsgHash: new Uint8Array(32)}, msg: {type: 'A_MSG', body: new TextEncoder().encode('hello from ts full stack')}});" + -- Layer 3: ratchet encrypt + <> "const enc = rcEncrypt(bob, agentMsg, 15840);" + -- Layer 2: wrap in AgentMsgEnvelope + <> "const envBytes = encodeAgentMsgEnvelope({type: 'envelope', agentVersion: 7, encAgentMessage: enc.ciphertext});" + -- Layer 1: per-queue E2E encrypt + <> "const sndKp = generateX25519KeyPair();" + <> "const rcvPub = " <> jsUint8 (C.pubKeyBytes e2eRcvPub) <> ";" + <> "const dhSecret = dh(rcvPub, sndKp.privateKey);" + <> "const clientMsg = encodeClientMessage({privHeader: {type: 'PHEmpty'}, body: envBytes});" + <> "const cmeBytes = agentCbEncrypt(dhSecret, 19, null, clientMsg);" + -- Output: bob E2E params + snd DER pubkey + cmeBytes + <> "const bobE2E = new Uint8Array([...encodeWord16(3), ...encodeBytes(encodePubKeyX448(b1.publicKey)), ...encodeBytes(encodePubKeyX448(b2.publicKey)), 0x30]);" + <> "const sndDer = encodePubKeyX25519(sndKp.publicKey);" + <> "const out = new Uint8Array([" + <> " (bobE2E.length >> 8) & 0xff, bobE2E.length & 0xff, ...bobE2E," + <> " sndDer.length, ...sndDer," + <> " ...cmeBytes" + <> "]);" + <> jsOut ("out") + -- Parse TypeScript output + let (e2eLenBs, r1) = B.splitAt 2 tsOutput + bobE2ELen = fromIntegral (B.index e2eLenBs 0) * 256 + fromIntegral (B.index e2eLenBs 1) + (bobE2EBytes, r2) = B.splitAt bobE2ELen r1 + sndDerLen = fromIntegral $ B.head r2 + (sndDerBytes, cmeBytes) = B.splitAt sndDerLen $ B.drop 1 r2 + -- Haskell: init alice's receiver ratchet, decode all layers + Right (CR.AE2ERatchetParams _ bobE2E :: CR.AE2ERatchetParams 'X448) <- pure $ smpDecode bobE2EBytes + Right (aliceInitParams, _) <- runExceptT $ CR.pqX3dhRcv alicePk1 alicePk2 Nothing bobE2E + let aliceRatchet = CR.initRcvRatchet (CR.RatchetVersions v v) alicePk2 (aliceInitParams, Nothing) CR.PQSupportOff + -- Decode all layers using ExceptT to chain pure Either + IO + gAlice <- C.newRandom + result <- runExceptT $ do + -- Per-queue E2E decrypt (pure) + apk <- liftEither $ C.decodePubKey sndDerBytes + dhSecret <- case apk of + C.APublicKey C.SX25519 pk -> pure $ C.dh' pk e2eRcvPriv + _ -> throwError "not X25519" + cme <- liftEither $ smpDecode cmeBytes + plaintext <- liftEither $ first show $ C.cbDecrypt dhSecret (cmNonce cme) (cmEncBody cme) + cm <- liftEither $ smpDecode plaintext + envBody <- case cm of + ClientMessage PHEmpty b -> pure b + _ -> throwError "unexpected PrivHeader" + -- Decode envelope + env <- liftEither $ smpDecode envBody + encMsg <- case env of + AP.AgentMsgEnvelope {AP.encAgentMessage = m} -> pure m + _ -> throwError "unexpected AgentMsgEnvelope variant" + -- Ratchet decrypt (IO) + (msgBody_, _, _) <- withExceptT show $ CR.rcDecrypt gAlice aliceRatchet M.empty encMsg + liftEither $ first show msgBody_ + -- Decode agent message from result + agentMsgBytes <- either (error . ("decode failed: " <>)) pure result + Right (AP.AgentMessage _ (AP.A_MSG body)) <- pure $ smpDecode agentMsgBytes + body `shouldBe` "hello from ts full stack" +