@@ -79,53 +79,51 @@ let CHKNORM_LENGTH_SIMPLIFY_CONV =
7979(* Helper lemmas *)
8080(* ------------------------------------------------------------------------- *)
8181
82- (* ival(iword(abs x)) = abs x when abs x < 2^31 *)
83- let IVAL_IWORD_ABS_32 = prove(
84- `!x:int. abs x < &2 pow 31 ==> ival(iword (abs x) : 32 word) = abs x`,
85- GEN_TAC THEN DISCH_TAC THEN
86- MATCH_MP_TAC IVAL_IWORD THEN
87- REWRITE_TAC[DIMINDEX_32] THEN CONV_TAC NUM_REDUCE_CONV THEN
88- MP_TAC(SPEC `x:int` INT_ABS_POS) THEN ASM_INT_ARITH_TAC);;
89-
90- (* MAX of {0, 0xFFFFFFFF} conditionals collapses to a single conditional *)
91- let MAX_COND_4_LEMMA = prove(
92- `MAX (if b0 then 4294967295 else 0 )
93- (MAX (if b1 then 4294967295 else 0 )
94- (MAX (if b2 then 4294967295 else 0 )
95- (if b3 then 4294967295 else 0 ))) =
96- if (b0 \/ b1 \/ b2 \/ b3) then 4294967295 else 0 `,
97- MAP_EVERY BOOL_CASES_TAC [`b0:bool`; `b1:bool`; `b2:bool`; `b3:bool`] THEN
98- REWRITE_TAC[] THEN ARITH_TAC);;
99-
100- (* (?i. i < 256 /\ P i) <=> P 0 \/ ... \/ P 255 *)
101- let EXISTS_LT_256 =
102- let p = `P:num->bool` and i_var = `i:num` in
103- let mk_p k = mk_comb(p, mk_small_numeral k) in
104- let rhs = end_itlist (fun a b -> mk_disj(a,b)) (map mk_p (0 --255 )) in
105- let lhs = mk_exists(i_var,
106- mk_conj(mk_comb(mk_comb(`(<)`, i_var), `256 `), mk_comb(p, i_var))) in
107- let arith_rules =
108- ARITH_RULE `i < 1 <=> i = 0 ` ::
109- map (fun k -> ARITH_RULE(subst [mk_small_numeral k, `n:num`;
110- mk_small_numeral(k-1 ), `m:num`]
111- `i < n <=> i = m \/ i < m`)) (2 --256 ) in
112- prove(mk_forall(p, mk_eq(lhs, rhs)),
113- GEN_TAC THEN REWRITE_TAC arith_rules THEN
114- REWRITE_TAC[RIGHT_OR_DISTRIB; EXISTS_OR_THM; UNWIND_THM2] THEN
115- REWRITE_TAC[DISJ_ACI]);;
116-
117- (* word_or of word_neg(word(bitval ...)) combines disjunctively *)
118- let WORD_OR_NEG_BITVAL = prove(
119- `word_or (word_neg (word (bitval b1) : 32 word))
120- (word_neg (word (bitval b2) : 32 word)) : 32 word =
121- word_neg (word (bitval (b1 \/ b2)))`,
122- MAP_EVERY BOOL_CASES_TAC [`b1:bool`; `b2:bool`] THEN
123- REWRITE_TAC[bitval] THEN CONV_TAC WORD_REDUCE_CONV);;
124-
125- (* val(word_neg(word(bitval b))) = if b then 0xFFFFFFFF else 0 *)
126- let VAL_WORD_NEG_BITVAL = prove(
127- `val (word_neg (word (bitval b) : 32 word)) = if b then 4294967295 else 0 `,
128- BOOL_CASES_TAC `b:bool` THEN REWRITE_TAC[bitval] THEN CONV_TAC WORD_REDUCE_CONV);;
82+ (* Expression emerging from the AVX2 code converting bit to 32-bit mask *)
83+ let bit_to_mask32 = new_definition `bit_to_mask32 (b : bool) : 32 word = word_neg (word (bitval b) : 32 word)`;;
84+
85+ (* Expression used for bounds check itself *)
86+ let bd = new_definition `bd (v : int32) (b: int32) : bool =
87+ (ival (iword (abs (ival v)) : 32 word) >= ival (word_zx (word_zx b : 64 word) : 32 word))`;;
88+
89+ let MAX_VAL_BIT_TO_MASK32 = prove(
90+ `MAX (val (bit_to_mask32 b0)) (val (bit_to_mask32 b1)) = val (bit_to_mask32 (b0 \/ b1))`,
91+ REWRITE_TAC[bit_to_mask32] THEN
92+ BOOL_CASES_TAC `b0:bool` THEN BOOL_CASES_TAC `b1:bool` THEN
93+ REWRITE_TAC[BITVAL_CLAUSES] THEN CONV_TAC WORD_REDUCE_CONV THEN ARITH_TAC);;
94+
95+ let BD_SIMP = prove(
96+ `abs(ival(x : int32)) < &2 pow 31 ==> (bd x b <=> abs (ival x) >= ival b)`,
97+ REWRITE_TAC[bd] THEN DISCH_TAC THEN
98+ SUBGOAL_THEN `ival(iword(abs(ival(x:32 word))) : 32 word) = abs(ival x)` SUBST1_TAC THENL
99+ [MATCH_MP_TAC IVAL_IWORD THEN REWRITE_TAC[DIMINDEX_32] THEN CONV_TAC NUM_REDUCE_CONV THEN
100+ FIRST_X_ASSUM MP_TAC THEN REWRITE_TAC[INT_ABS_POS] THEN INT_ARITH_TAC; ALL_TAC] THEN
101+ SUBGOAL_THEN `(word_zx:64 word -> 32 word) ((word_zx:32 word -> 64 word) b) = b` SUBST1_TAC THENL
102+ [MATCH_MP_TAC WORD_ZX_ZX THEN REWRITE_TAC[DIMINDEX_32; DIMINDEX_64] THEN ARITH_TAC;
103+ REFL_TAC]);;
104+
105+ let BIT_TO_MASK32_OR = prove(
106+ `word_or (bit_to_mask32 b0) (bit_to_mask32 b1) = bit_to_mask32 (b0 \/ b1)`,
107+ REWRITE_TAC[bit_to_mask32] THEN
108+ BOOL_CASES_TAC `b0:bool` THEN BOOL_CASES_TAC `b1:bool` THEN
109+ REWRITE_TAC[BITVAL_CLAUSES] THEN CONV_TAC WORD_REDUCE_CONV);;
110+
111+ let MASK32_TO_BIT = prove(
112+ `(word_zx:32 word -> 64 word) (word_and ((word_zx:64 word -> 32 word)
113+ ((word_zx:32 word -> 64 word) (word_subword (word (val (bit_to_mask32 b)) : 128 word) (0 ,32 ))))
114+ (word 1 )) = word (bitval b) : 64 word`,
115+ REWRITE_TAC[bit_to_mask32] THEN
116+ BOOL_CASES_TAC `b:bool` THEN REWRITE_TAC[BITVAL_CLAUSES] THEN
117+ CONV_TAC WORD_REDUCE_CONV);;
118+
119+ let WORD_JOIN_OR_TYBIT0 = prove(
120+ `word_or (word_join (a:N word) (b:N word) : (N tybit0) word) (word_join (c:N word) (d:N word)) =
121+ word_join (word_or a c) (word_or b d)`,
122+ REWRITE_TAC[WORD_EQ_BITS_ALT; BIT_WORD_OR; BIT_WORD_JOIN; DIMINDEX_TYBIT0] THEN
123+ X_GEN_TAC `i:num` THEN
124+ ASM_CASES_TAC `i < 2 * dimindex(:N)` THEN ASM_REWRITE_TAC[] THEN
125+ ASM_CASES_TAC `i < dimindex(:N)` THEN ASM_REWRITE_TAC[] THEN
126+ MATCH_MP_TAC(TAUT `p ==> (q <=> p /\ q)`) THEN ASM_ARITH_TAC);;
129127
130128(* ------------------------------------------------------------------------- *)
131129(* Core correctness theorem *)
@@ -147,12 +145,12 @@ let MLDSA_POLY_CHKNORM_CORRECT = prove(
147145 CONV_TAC CHKNORM_LENGTH_SIMPLIFY_CONV THEN
148146 MAP_EVERY X_GEN_TAC [`a:int64`; `x:num->int32`; `bound:int32`; `pc:num`] THEN
149147 REWRITE_TAC[MAYCHANGE_REGS_AND_FLAGS_PERMITTED_BY_ABI; C_ARGUMENTS;
150- NONOVERLAPPING_CLAUSES; EXISTS_LT_256 ] THEN
148+ NONOVERLAPPING_CLAUSES] THEN
151149 DISCH_THEN(REPEAT_TCL CONJUNCTS_THEN ASSUME_TAC) THEN
152150 (* Expand bounded foralls in precondition to 256 explicit cases *)
153- CONV_TAC(RATOR_CONV(LAND_CONV(ONCE_DEPTH_CONV
154- (EXPAND_CASES_CONV THENC ONCE_DEPTH_CONV NUM_MULT_CONV)))) THEN
155151 ENSURES_INIT_TAC " s0" THEN
152+ UNDISCH_TAC `forall i. i < 256 ==> read (memory :> bytes32 (word_add a (word (4 * i)))) s0 = x i` THEN
153+ CONV_TAC(ONCE_DEPTH_CONV (EXPAND_CASES_CONV THENC ONCE_DEPTH_CONV NUM_MULT_CONV)) THEN REPEAT STRIP_TAC THEN
156154 (* Merge bytes32 reads into bytes128 reads (64 merges for 256 coefficients) *)
157155 MP_TAC(end_itlist CONJ (map (fun n -> READ_MEMORY_MERGE_CONV 2
158156 (subst[mk_small_numeral(16 *n),`n:num`]
@@ -165,35 +163,23 @@ let MLDSA_POLY_CHKNORM_CORRECT = prove(
165163 MAP_UNTIL_TARGET_PC (fun n ->
166164 ARM_STEPS_TAC MLDSA_POLY_CHKNORM_EXEC [n] THEN
167165 RULE_ASSUM_TAC(CONV_RULE(TOP_DEPTH_CONV WORD_SIMPLE_SUBWORD_CONV)) THEN
168- RULE_ASSUM_TAC(REWRITE_RULE[WORD_SUBWORD_OR])) 1 THEN
169- (* Collapse nested word_or of word_neg pairs, then take val *)
170- RULE_ASSUM_TAC(REWRITE_RULE[WORD_OR_NEG_BITVAL; VAL_WORD_NEG_BITVAL]) THEN
166+ RULE_ASSUM_TAC(REWRITE_RULE[WORD_SUBWORD_OR; GSYM bit_to_mask32; WORD_JOIN_OR_TYBIT0; SYM (SPEC_ALL bd); BIT_TO_MASK32_OR;
167+ MAX_VAL_BIT_TO_MASK32; MASK32_TO_BIT])) 1 THEN
168+
171169 (* Close the state relation *)
172170 ENSURES_FINAL_STATE_TAC THEN ASM_REWRITE_TAC[] THEN
173- (* Prove ival(iword(abs(ival x i))) = abs(ival(x i)) for all 256 coefficients *)
174- SUBGOAL_THEN
175- `!i. i < 256 ==> ival(iword(abs(ival((x:num->int32) i))) : 32 word) = abs(ival(x i))`
176- ASSUME_TAC THENL
177- [REPEAT STRIP_TAC THEN MATCH_MP_TAC IVAL_IWORD_ABS_32 THEN
178- UNDISCH_TAC `(i:num) < 256 ` THEN SPEC_TAC(`i:num`, `i:num`) THEN
179- CONV_TAC EXPAND_CASES_CONV THEN ASM_REWRITE_TAC[];
180- ALL_TAC] THEN
181- (* Apply the ival/iword simplification for all 256 coefficients *)
182- FIRST_X_ASSUM(fun th -> REWRITE_TAC
183- (map (fun k -> MATCH_MP th
184- (ARITH_RULE(subst [mk_small_numeral k, `n:num`] `n < 256 `)))
185- (0 --255 ))) THEN
186- (* Simplify word_zx round-trip for bound *)
187- REWRITE_TAC[prove(
188- `ival(word_zx ((word_zx:32 word->64 word) (bound:32 word)) : 32 word) = ival bound`,
189- BITBLAST_TAC)] THEN
190- (* Rewrite MAX of conditionals to a single conditional *)
191- REWRITE_TAC[MAX_COND_4_LEMMA] THEN
192- (* Normalize the disjunction order *)
193- REWRITE_TAC[DISJ_ACI] THEN
194- (* Case split on the condition and simplify word operations *)
195- COND_CASES_TAC THEN
196- ASM_REWRITE_TAC[BITVAL_CLAUSES] THEN CONV_TAC WORD_REDUCE_CONV);;
171+ DISCARD_MATCHING_ASSUMPTIONS [`read t s = x`] THEN
172+
173+ RULE_ASSUM_TAC (CONV_RULE (ONCE_DEPTH_CONV EXPAND_CASES_CONV)) THEN
174+ REPEAT(FIRST_X_ASSUM(CONJUNCTS_THEN ASSUME_TAC)) THEN
175+ IMP_REWRITE_TAC [BD_SIMP] THEN
176+ POP_ASSUM_LIST (K ALL_TAC) THEN
177+
178+ (* Convert to ! instead of ? and split *)
179+ GEN_REWRITE_TAC (BINOP_CONV o ONCE_DEPTH_CONV) [prove (`b = ~ (~ (b : bool))`, REWRITE_TAC[])] THEN
180+ GEN_REWRITE_TAC TOP_SWEEP_CONV [MESON[] `~(?i. i < n /\ P i) <=> (!i. i < n ==> ~P i)`; DE_MORGAN_THM] THEN
181+ CONV_TAC (ONCE_DEPTH_CONV EXPAND_CASES_CONV) THEN
182+ REPEAT AP_TERM_TAC THEN EQ_TAC THEN SIMP_TAC[]);;
197183
198184(* ------------------------------------------------------------------------- *)
199185(* Subroutine correctness theorem (includes return) *)
@@ -220,9 +206,8 @@ let MLDSA_POLY_CHKNORM_SUBROUTINE_CORRECT = prove(
220206 let TWEAK_CONV =
221207 ONCE_DEPTH_CONV EXPAND_CASES_CONV THENC
222208 ONCE_DEPTH_CONV NUM_MULT_CONV THENC
223- PURE_REWRITE_CONV [WORD_ADD_0; EXISTS_LT_256 ] in
209+ PURE_REWRITE_CONV [WORD_ADD_0] in
224210 CONV_TAC TWEAK_CONV THEN
225211 ARM_ADD_RETURN_NOSTACK_TAC MLDSA_POLY_CHKNORM_EXEC
226212 (CONV_RULE TWEAK_CONV
227213 (CONV_RULE CHKNORM_LENGTH_SIMPLIFY_CONV MLDSA_POLY_CHKNORM_CORRECT)));;
228-
0 commit comments