Skip to content

Commit cd0e54c

Browse files
committed
HOL-Light/AArch64: Refactoring of poly_chknorm proof
- Rewrite expressions during symbolic execution to keep system states readable - Keep quantified propositions folded to the point where case-by-case analysis is needed - Hoist all helper lemmas out of the main proof for better readability Signed-off-by: Hanno Becker <beckphan@amazon.co.uk>
1 parent 627b734 commit cd0e54c

1 file changed

Lines changed: 64 additions & 79 deletions

File tree

proofs/hol_light/aarch64/proofs/mldsa_poly_chknorm.ml

Lines changed: 64 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -79,53 +79,51 @@ let CHKNORM_LENGTH_SIMPLIFY_CONV =
7979
(* Helper lemmas *)
8080
(* ------------------------------------------------------------------------- *)
8181

82-
(* ival(iword(abs x)) = abs x when abs x < 2^31 *)
83-
let IVAL_IWORD_ABS_32 = prove(
84-
`!x:int. abs x < &2 pow 31 ==> ival(iword (abs x) : 32 word) = abs x`,
85-
GEN_TAC THEN DISCH_TAC THEN
86-
MATCH_MP_TAC IVAL_IWORD THEN
87-
REWRITE_TAC[DIMINDEX_32] THEN CONV_TAC NUM_REDUCE_CONV THEN
88-
MP_TAC(SPEC `x:int` INT_ABS_POS) THEN ASM_INT_ARITH_TAC);;
89-
90-
(* MAX of {0, 0xFFFFFFFF} conditionals collapses to a single conditional *)
91-
let MAX_COND_4_LEMMA = prove(
92-
`MAX (if b0 then 4294967295 else 0)
93-
(MAX (if b1 then 4294967295 else 0)
94-
(MAX (if b2 then 4294967295 else 0)
95-
(if b3 then 4294967295 else 0))) =
96-
if (b0 \/ b1 \/ b2 \/ b3) then 4294967295 else 0`,
97-
MAP_EVERY BOOL_CASES_TAC [`b0:bool`; `b1:bool`; `b2:bool`; `b3:bool`] THEN
98-
REWRITE_TAC[] THEN ARITH_TAC);;
99-
100-
(* (?i. i < 256 /\ P i) <=> P 0 \/ ... \/ P 255 *)
101-
let EXISTS_LT_256 =
102-
let p = `P:num->bool` and i_var = `i:num` in
103-
let mk_p k = mk_comb(p, mk_small_numeral k) in
104-
let rhs = end_itlist (fun a b -> mk_disj(a,b)) (map mk_p (0--255)) in
105-
let lhs = mk_exists(i_var,
106-
mk_conj(mk_comb(mk_comb(`(<)`, i_var), `256`), mk_comb(p, i_var))) in
107-
let arith_rules =
108-
ARITH_RULE `i < 1 <=> i = 0` ::
109-
map (fun k -> ARITH_RULE(subst [mk_small_numeral k, `n:num`;
110-
mk_small_numeral(k-1), `m:num`]
111-
`i < n <=> i = m \/ i < m`)) (2--256) in
112-
prove(mk_forall(p, mk_eq(lhs, rhs)),
113-
GEN_TAC THEN REWRITE_TAC arith_rules THEN
114-
REWRITE_TAC[RIGHT_OR_DISTRIB; EXISTS_OR_THM; UNWIND_THM2] THEN
115-
REWRITE_TAC[DISJ_ACI]);;
116-
117-
(* word_or of word_neg(word(bitval ...)) combines disjunctively *)
118-
let WORD_OR_NEG_BITVAL = prove(
119-
`word_or (word_neg (word (bitval b1) : 32 word))
120-
(word_neg (word (bitval b2) : 32 word)) : 32 word =
121-
word_neg (word (bitval (b1 \/ b2)))`,
122-
MAP_EVERY BOOL_CASES_TAC [`b1:bool`; `b2:bool`] THEN
123-
REWRITE_TAC[bitval] THEN CONV_TAC WORD_REDUCE_CONV);;
124-
125-
(* val(word_neg(word(bitval b))) = if b then 0xFFFFFFFF else 0 *)
126-
let VAL_WORD_NEG_BITVAL = prove(
127-
`val (word_neg (word (bitval b) : 32 word)) = if b then 4294967295 else 0`,
128-
BOOL_CASES_TAC `b:bool` THEN REWRITE_TAC[bitval] THEN CONV_TAC WORD_REDUCE_CONV);;
82+
(* Expression emerging from the AVX2 code converting bit to 32-bit mask *)
83+
let bit_to_mask32 = new_definition `bit_to_mask32 (b : bool) : 32 word = word_neg (word (bitval b) : 32 word)`;;
84+
85+
(* Expression used for bounds check itself *)
86+
let bd = new_definition `bd (v : int32) (b: int32) : bool =
87+
(ival (iword (abs (ival v)) : 32 word) >= ival (word_zx (word_zx b : 64 word) : 32 word))`;;
88+
89+
let MAX_VAL_BIT_TO_MASK32 = prove(
90+
`MAX (val (bit_to_mask32 b0)) (val (bit_to_mask32 b1)) = val (bit_to_mask32 (b0 \/ b1))`,
91+
REWRITE_TAC[bit_to_mask32] THEN
92+
BOOL_CASES_TAC `b0:bool` THEN BOOL_CASES_TAC `b1:bool` THEN
93+
REWRITE_TAC[BITVAL_CLAUSES] THEN CONV_TAC WORD_REDUCE_CONV THEN ARITH_TAC);;
94+
95+
let BD_SIMP = prove(
96+
`abs(ival(x : int32)) < &2 pow 31 ==> (bd x b <=> abs (ival x) >= ival b)`,
97+
REWRITE_TAC[bd] THEN DISCH_TAC THEN
98+
SUBGOAL_THEN `ival(iword(abs(ival(x:32 word))) : 32 word) = abs(ival x)` SUBST1_TAC THENL
99+
[MATCH_MP_TAC IVAL_IWORD THEN REWRITE_TAC[DIMINDEX_32] THEN CONV_TAC NUM_REDUCE_CONV THEN
100+
FIRST_X_ASSUM MP_TAC THEN REWRITE_TAC[INT_ABS_POS] THEN INT_ARITH_TAC; ALL_TAC] THEN
101+
SUBGOAL_THEN `(word_zx:64 word -> 32 word) ((word_zx:32 word -> 64 word) b) = b` SUBST1_TAC THENL
102+
[MATCH_MP_TAC WORD_ZX_ZX THEN REWRITE_TAC[DIMINDEX_32; DIMINDEX_64] THEN ARITH_TAC;
103+
REFL_TAC]);;
104+
105+
let BIT_TO_MASK32_OR = prove(
106+
`word_or (bit_to_mask32 b0) (bit_to_mask32 b1) = bit_to_mask32 (b0 \/ b1)`,
107+
REWRITE_TAC[bit_to_mask32] THEN
108+
BOOL_CASES_TAC `b0:bool` THEN BOOL_CASES_TAC `b1:bool` THEN
109+
REWRITE_TAC[BITVAL_CLAUSES] THEN CONV_TAC WORD_REDUCE_CONV);;
110+
111+
let MASK32_TO_BIT = prove(
112+
`(word_zx:32 word -> 64 word) (word_and ((word_zx:64 word -> 32 word)
113+
((word_zx:32 word -> 64 word) (word_subword (word (val (bit_to_mask32 b)) : 128 word) (0,32))))
114+
(word 1)) = word (bitval b) : 64 word`,
115+
REWRITE_TAC[bit_to_mask32] THEN
116+
BOOL_CASES_TAC `b:bool` THEN REWRITE_TAC[BITVAL_CLAUSES] THEN
117+
CONV_TAC WORD_REDUCE_CONV);;
118+
119+
let WORD_JOIN_OR_TYBIT0 = prove(
120+
`word_or (word_join (a:N word) (b:N word) : (N tybit0) word) (word_join (c:N word) (d:N word)) =
121+
word_join (word_or a c) (word_or b d)`,
122+
REWRITE_TAC[WORD_EQ_BITS_ALT; BIT_WORD_OR; BIT_WORD_JOIN; DIMINDEX_TYBIT0] THEN
123+
X_GEN_TAC `i:num` THEN
124+
ASM_CASES_TAC `i < 2 * dimindex(:N)` THEN ASM_REWRITE_TAC[] THEN
125+
ASM_CASES_TAC `i < dimindex(:N)` THEN ASM_REWRITE_TAC[] THEN
126+
MATCH_MP_TAC(TAUT `p ==> (q <=> p /\ q)`) THEN ASM_ARITH_TAC);;
129127

130128
(* ------------------------------------------------------------------------- *)
131129
(* Core correctness theorem *)
@@ -147,12 +145,12 @@ let MLDSA_POLY_CHKNORM_CORRECT = prove(
147145
CONV_TAC CHKNORM_LENGTH_SIMPLIFY_CONV THEN
148146
MAP_EVERY X_GEN_TAC [`a:int64`; `x:num->int32`; `bound:int32`; `pc:num`] THEN
149147
REWRITE_TAC[MAYCHANGE_REGS_AND_FLAGS_PERMITTED_BY_ABI; C_ARGUMENTS;
150-
NONOVERLAPPING_CLAUSES; EXISTS_LT_256] THEN
148+
NONOVERLAPPING_CLAUSES] THEN
151149
DISCH_THEN(REPEAT_TCL CONJUNCTS_THEN ASSUME_TAC) THEN
152150
(* Expand bounded foralls in precondition to 256 explicit cases *)
153-
CONV_TAC(RATOR_CONV(LAND_CONV(ONCE_DEPTH_CONV
154-
(EXPAND_CASES_CONV THENC ONCE_DEPTH_CONV NUM_MULT_CONV)))) THEN
155151
ENSURES_INIT_TAC "s0" THEN
152+
UNDISCH_TAC `forall i. i < 256 ==> read (memory :> bytes32 (word_add a (word (4 * i)))) s0 = x i` THEN
153+
CONV_TAC(ONCE_DEPTH_CONV (EXPAND_CASES_CONV THENC ONCE_DEPTH_CONV NUM_MULT_CONV)) THEN REPEAT STRIP_TAC THEN
156154
(* Merge bytes32 reads into bytes128 reads (64 merges for 256 coefficients) *)
157155
MP_TAC(end_itlist CONJ (map (fun n -> READ_MEMORY_MERGE_CONV 2
158156
(subst[mk_small_numeral(16*n),`n:num`]
@@ -165,35 +163,23 @@ let MLDSA_POLY_CHKNORM_CORRECT = prove(
165163
MAP_UNTIL_TARGET_PC (fun n ->
166164
ARM_STEPS_TAC MLDSA_POLY_CHKNORM_EXEC [n] THEN
167165
RULE_ASSUM_TAC(CONV_RULE(TOP_DEPTH_CONV WORD_SIMPLE_SUBWORD_CONV)) THEN
168-
RULE_ASSUM_TAC(REWRITE_RULE[WORD_SUBWORD_OR])) 1 THEN
169-
(* Collapse nested word_or of word_neg pairs, then take val *)
170-
RULE_ASSUM_TAC(REWRITE_RULE[WORD_OR_NEG_BITVAL; VAL_WORD_NEG_BITVAL]) THEN
166+
RULE_ASSUM_TAC(REWRITE_RULE[WORD_SUBWORD_OR; GSYM bit_to_mask32; WORD_JOIN_OR_TYBIT0; SYM (SPEC_ALL bd); BIT_TO_MASK32_OR;
167+
MAX_VAL_BIT_TO_MASK32; MASK32_TO_BIT])) 1 THEN
168+
171169
(* Close the state relation *)
172170
ENSURES_FINAL_STATE_TAC THEN ASM_REWRITE_TAC[] THEN
173-
(* Prove ival(iword(abs(ival x i))) = abs(ival(x i)) for all 256 coefficients *)
174-
SUBGOAL_THEN
175-
`!i. i < 256 ==> ival(iword(abs(ival((x:num->int32) i))) : 32 word) = abs(ival(x i))`
176-
ASSUME_TAC THENL
177-
[REPEAT STRIP_TAC THEN MATCH_MP_TAC IVAL_IWORD_ABS_32 THEN
178-
UNDISCH_TAC `(i:num) < 256` THEN SPEC_TAC(`i:num`, `i:num`) THEN
179-
CONV_TAC EXPAND_CASES_CONV THEN ASM_REWRITE_TAC[];
180-
ALL_TAC] THEN
181-
(* Apply the ival/iword simplification for all 256 coefficients *)
182-
FIRST_X_ASSUM(fun th -> REWRITE_TAC
183-
(map (fun k -> MATCH_MP th
184-
(ARITH_RULE(subst [mk_small_numeral k, `n:num`] `n < 256`)))
185-
(0--255))) THEN
186-
(* Simplify word_zx round-trip for bound *)
187-
REWRITE_TAC[prove(
188-
`ival(word_zx ((word_zx:32 word->64 word) (bound:32 word)) : 32 word) = ival bound`,
189-
BITBLAST_TAC)] THEN
190-
(* Rewrite MAX of conditionals to a single conditional *)
191-
REWRITE_TAC[MAX_COND_4_LEMMA] THEN
192-
(* Normalize the disjunction order *)
193-
REWRITE_TAC[DISJ_ACI] THEN
194-
(* Case split on the condition and simplify word operations *)
195-
COND_CASES_TAC THEN
196-
ASM_REWRITE_TAC[BITVAL_CLAUSES] THEN CONV_TAC WORD_REDUCE_CONV);;
171+
DISCARD_MATCHING_ASSUMPTIONS [`read t s = x`] THEN
172+
173+
RULE_ASSUM_TAC (CONV_RULE (ONCE_DEPTH_CONV EXPAND_CASES_CONV)) THEN
174+
REPEAT(FIRST_X_ASSUM(CONJUNCTS_THEN ASSUME_TAC)) THEN
175+
IMP_REWRITE_TAC [BD_SIMP] THEN
176+
POP_ASSUM_LIST (K ALL_TAC) THEN
177+
178+
(* Convert to ! instead of ? and split *)
179+
GEN_REWRITE_TAC (BINOP_CONV o ONCE_DEPTH_CONV) [prove (`b = ~ (~ (b : bool))`, REWRITE_TAC[])] THEN
180+
GEN_REWRITE_TAC TOP_SWEEP_CONV [MESON[] `~(?i. i < n /\ P i) <=> (!i. i < n ==> ~P i)`; DE_MORGAN_THM] THEN
181+
CONV_TAC (ONCE_DEPTH_CONV EXPAND_CASES_CONV) THEN
182+
REPEAT AP_TERM_TAC THEN EQ_TAC THEN SIMP_TAC[]);;
197183

198184
(* ------------------------------------------------------------------------- *)
199185
(* Subroutine correctness theorem (includes return) *)
@@ -220,9 +206,8 @@ let MLDSA_POLY_CHKNORM_SUBROUTINE_CORRECT = prove(
220206
let TWEAK_CONV =
221207
ONCE_DEPTH_CONV EXPAND_CASES_CONV THENC
222208
ONCE_DEPTH_CONV NUM_MULT_CONV THENC
223-
PURE_REWRITE_CONV [WORD_ADD_0; EXISTS_LT_256] in
209+
PURE_REWRITE_CONV [WORD_ADD_0] in
224210
CONV_TAC TWEAK_CONV THEN
225211
ARM_ADD_RETURN_NOSTACK_TAC MLDSA_POLY_CHKNORM_EXEC
226212
(CONV_RULE TWEAK_CONV
227213
(CONV_RULE CHKNORM_LENGTH_SIMPLIFY_CONV MLDSA_POLY_CHKNORM_CORRECT)));;
228-

0 commit comments

Comments
 (0)