From 58cc113845fa82d4ceb6ad4cb3d497dd13f5d0b6 Mon Sep 17 00:00:00 2001 From: Pierre-Yves Strub Date: Thu, 26 Mar 2026 09:30:38 +0100 Subject: [PATCH] Add contextual rewrite-pattern selection Allow rewrite patterns to designate a subterm inside a larger context with the [x in pattern] syntax. This lets rewrite target exactly the occurrence named by the surrounding context, and adds regression coverage for that form. The context variable must appear exactly once in the pattern (linearity check). Delta expansion and conversion are disabled during contextual pattern matching to ensure position computation remains sound. --- src/ecHiGoal.ml | 248 ++++++++++++++++++++++----- src/ecHiGoal.mli | 7 +- src/ecParser.mly | 15 +- src/ecParsetree.ml | 11 +- src/ecProofTerm.ml | 16 +- src/ecProofTerm.mli | 1 + tests/rw_explicit_eq_with_pattern.ec | 122 +++++++++++++ 7 files changed, 361 insertions(+), 59 deletions(-) diff --git a/src/ecHiGoal.ml b/src/ecHiGoal.ml index 1b79f912fc..9bdfe3dcb1 100644 --- a/src/ecHiGoal.ml +++ b/src/ecHiGoal.ml @@ -18,6 +18,7 @@ open EcLowGoal module Sid = EcIdent.Sid module Mid = EcIdent.Mid +module Mint = EcMaps.Mint module Sp = EcPath.Sp module ER = EcReduction @@ -249,6 +250,8 @@ module LowRewrite = struct | LRW_IdRewriting | LRW_RPatternNoMatch | LRW_RPatternNoRuleMatch + | LRW_RPatternNotLinear + | LRW_RPatternNoVariable exception RewriteError of error @@ -326,24 +329,76 @@ module LowRewrite = struct let find_rewrite_patterns = find_rewrite_patterns ~inpred:false - type rwinfos = rwside * EcFol.form option * EcMatching.occ option + type rwinfos = rwside * (form * (EcIdent.t * ty) option) option * EcMatching.occ option - let t_rewrite_r ?(mode = `Full) ?target ((s, prw, o) : rwinfos) pt tc = + let first_occurrence (p : EcMatching.ptnpos) = + EcMatching.FPosition.filter (`Inclusive, EcMaps.Sint.singleton 1) p + + let ptnpos_as_map (p : EcMatching.ptnpos) = + (p :> [`Select of int | `Sub of EcMatching.ptnpos] Mint.t) + + let path_of_occurrence (p : EcMatching.ptnpos) = + let rec aux acc (p : EcMatching.ptnpos) = + let p = ptnpos_as_map p in + assert (Mint.cardinal p = 1); + + let i, p = Mint.choose p in + + match p with + | `Select _ -> List.rev (i :: acc) + | `Sub p -> aux (i :: acc) p + in + + let p = first_occurrence p in + let p = ptnpos_as_map p in + + assert (Mint.cardinal p = 1); + + let i, p = Mint.choose p in + assert (i = 0); + + match p with + | `Select _ -> [] + | `Sub p -> aux [] p + + let first_selected_subform (p : EcMatching.ptnpos) (f : form) = + let selected = ref None in + + let _ = + EcMatching.FPosition.map p + (fun fp -> + if Option.is_none !selected then selected := Some fp; + fp) + f + in + + oget !selected + + let t_rewrite_r + ?(mode : [`Full | `Light] = `Full) + ?(target : EcIdent.t option) + ((s, prw, o) : rwinfos) + (pt : pt_ev) + (tc : tcenv1) + = let hyps, tgfp = FApi.tc1_flat ?target tc in let modes = match mode with - | `Full -> [{ k_keyed = true; k_conv = false }; - { k_keyed = true; k_conv = true };] - | `Light -> [{ k_keyed = true; k_conv = false }] in + | `Full -> [{ k_keyed = true; k_conv = false; k_delta = true }; + { k_keyed = true; k_conv = true; k_delta = true };] + | `Light -> [{ k_keyed = true; k_conv = false; k_delta = true }] in let for1 (pt, mode, (f1, f2)) = let fp, tp = match s with `LtoR -> f1, f2 | `RtoL -> f2, f1 in - let subf, occmode = + let subf, occmode, cpos = match prw with | None -> begin try - PT.pf_find_occurence_lazy pt.PT.ptev_env ~modes ~ptn:fp tgfp + let subf, occmode = + PT.pf_find_occurence_lazy pt.PT.ptev_env ~modes ~ptn:fp tgfp + in + (subf, occmode, None) with | PT.FindOccFailure `MatchFailure -> raise (RewriteError LRW_NothingToRewrite) @@ -351,23 +406,94 @@ module LowRewrite = struct raise (RewriteError LRW_CannotInfer) end - | Some prw -> begin - let prw, _ = - try - PT.pf_find_occurence_lazy - pt.PT.ptev_env ~full:false ~modes ~ptn:prw tgfp - with PT.FindOccFailure `MatchFailure -> - raise (RewriteError LRW_RPatternNoMatch) in + | Some (prw, subl) -> begin + let subcpos = + match subl with + | None -> None - try - PT.pf_find_occurence_lazy - pt.PT.ptev_env ~rooted:true ~modes ~ptn:fp prw - with - | PT.FindOccFailure `MatchFailure -> - raise (RewriteError LRW_RPatternNoRuleMatch) - | PT.FindOccFailure `IncompleteMatch -> - raise (RewriteError LRW_CannotInfer) - end in + | Some (x, xty) -> + let fx = f_local x xty in + let subcpos = + FPosition.select_form + ~xconv:`Eq ~keyed:true hyps None fx prw + in + + if FPosition.is_empty subcpos then + raise (RewriteError LRW_RPatternNoVariable); + + if FPosition.occurences subcpos <> 1 then + raise (RewriteError LRW_RPatternNotLinear); + + let subcpos = + match o with + | None -> subcpos + | Some o -> + if not (FPosition.is_occurences_valid (snd o) subcpos) then + raise (RewriteError LRW_InvalidOccurence); + FPosition.filter o subcpos + in + + Some subcpos + in + + let ctxt_modes = + match subl with + | None -> modes + | Some _ -> [{ k_keyed = true; k_conv = false; k_delta = false }] + in + + let prw, prwmode = + try + PT.pf_find_occurence_lazy + pt.PT.ptev_env ~full:false ~modes:ctxt_modes ~ptn:prw tgfp + with PT.FindOccFailure `MatchFailure -> + raise (RewriteError LRW_RPatternNoMatch) in + + match subcpos with + | None -> + begin + try + let subf, occmode = + PT.pf_find_occurence_lazy + pt.PT.ptev_env ~rooted:true ~modes:ctxt_modes ~ptn:fp prw + in + (subf, occmode, None) + with + | PT.FindOccFailure `MatchFailure -> + raise (RewriteError LRW_RPatternNoRuleMatch) + | PT.FindOccFailure `IncompleteMatch -> + raise (RewriteError LRW_CannotInfer) + end + + | Some subcpos -> + let subf = first_selected_subform subcpos prw in + + begin + try + ignore + (PT.pf_find_occurence_lazy + pt.PT.ptev_env ~rooted:true ~modes:ctxt_modes ~ptn:fp subf) + with + | PT.FindOccFailure `MatchFailure -> + raise (RewriteError LRW_RPatternNoRuleMatch) + | PT.FindOccFailure `IncompleteMatch -> + raise (RewriteError LRW_CannotInfer) + end; + + let cpos = + let prwpos = + FPosition.select_form + ~xconv:`AlphaEq ~keyed:prwmode.k_keyed hyps + (Some (`Inclusive, EcMaps.Sint.singleton 1)) + prw tgfp + in + let root = path_of_occurrence prwpos in + FPosition.reroot root subcpos + in + + (subf, { k_keyed = true; k_conv = false; k_delta = false }, Some cpos) + end + in if not occmode.k_keyed then begin let tp = PT.concretize_form pt.PT.ptev_env tp in @@ -377,10 +503,15 @@ module LowRewrite = struct let pt = fst (PT.concretize pt) in let cpos = - try FPosition.select_form - ~xconv:`AlphaEq ~keyed:occmode.k_keyed - hyps o subf tgfp - with InvalidOccurence -> raise (RewriteError (LRW_InvalidOccurence)) + match cpos with + | Some cpos -> cpos + | None -> + try + FPosition.select_form + ~xconv:`AlphaEq ~keyed:occmode.k_keyed + hyps o subf tgfp + with InvalidOccurence -> + raise (RewriteError LRW_InvalidOccurence) in EcLowGoal.t_rewrite @@ -569,7 +700,14 @@ let process_apply_top tc = | _ -> tc_error !!tc "no top assumption" (* -------------------------------------------------------------------- *) -let process_rewrite1_core ?mode ?(close = true) ?target (s, p, o) pt tc = +let process_rewrite1_core + ?(mode : [`Full | `Light] option) + ?(close : bool = true) + ?(target : EcIdent.t option) + ((s, p, o) : rwside * (form * (EcIdent.t * ty) option) option * rwocc) + (pt : pt_ev) + (tc : tcenv1) += let o = norm_rwocc o in try @@ -596,9 +734,13 @@ let process_rewrite1_core ?mode ?(close = true) ?target (s, p, o) pt tc = tc_error !!tc "r-pattern does not match the goal" | LowRewrite.LRW_RPatternNoRuleMatch -> tc_error !!tc "r-pattern does not match the rewriting rule" + | LowRewrite.LRW_RPatternNotLinear -> + tc_error !!tc "context variable must appear exactly once in the r-pattern" + | LowRewrite.LRW_RPatternNoVariable -> + tc_error !!tc "context variable does not appear in the r-pattern" (* -------------------------------------------------------------------- *) -let process_delta ~und_delta ?target (s, o, p) tc = +let process_delta ~und_delta ?target ((s :rwside), o, p) tc = let env, hyps, concl = FApi.tc1_eflat tc in let o = norm_rwocc o in @@ -768,38 +910,50 @@ let process_rewrite1_r ttenv ?target ri tc = let target = target |> omap (fst -| ((LDecl.hyp_by_name^~ hyps) -| unloc)) in t_simplify_lg ?target ~delta:`IfApplied (ttenv, logic) tc - | RWDelta ((s, r, o, px), p) -> begin - if Option.is_some px then + | RWDelta (rwopt, p) -> begin + if Option.is_some rwopt.match_ then tc_error !!tc "cannot use pattern selection in delta-rewrite rules"; - let do1 tc = process_delta ~und_delta ?target (s, o, p) tc in + let do1 tc = + process_delta ~und_delta ?target (rwopt.side, rwopt.occurrence, p) tc in - match r with + match rwopt.repeat with | None -> do1 tc | Some (b, n) -> t_do b n do1 tc end - | RWRw (((s : rwside), r, o, p), pts) -> begin + | RWRw (rwopt, pts) -> begin let do1 (mode : [`Full | `Light]) ((subs : rwside), pt) tc = let hyps = FApi.tc1_hyps tc in let target = target |> omap (fst -| ((LDecl.hyp_by_name^~ hyps) -| unloc)) in let hyps = FApi.tc1_hyps ?target tc in let ptenv, prw = - match p with + match rwopt.match_ with | None -> PT.ptenv_of_penv hyps !!tc, None - | Some p -> + | Some (RWM_Plain p) -> let (ps, ue), p = TTC.tc1_process_pattern tc p in let ev = MEV.of_idents (Mid.keys ps) `Form in - (PT.ptenv !!tc hyps (ue, ev), Some p) in + (PT.ptenv !!tc hyps (ue, ev), Some (p, None)) + + | Some (RWM_Context (x, p)) -> + let ps = ref Mid.empty in + let ue = EcProofTyping.unienv_of_hyps hyps in + let x = EcIdent.create (unloc x) in + let xty = EcUnify.UniEnv.fresh ue in + let hyps = FApi.tc1_hyps tc in + let hyps = LDecl.add_local x (LD_var (xty, None)) hyps in + let p = EcTyping.trans_pattern (LDecl.toenv hyps) ps ue p in + let ev = MEV.of_idents (x :: Mid.keys !ps) `Form in + (PT.ptenv !!tc hyps (ue, ev), Some (p, Some (x, xty))) in let theside = - match s, subs with - | `LtoR, _ -> (subs :> rwside) - | _ , `LtoR -> (s :> rwside) - | `RtoL, `RtoL -> (`LtoR :> rwside) in + match rwopt.side, subs with + | `LtoR, _ -> (subs :> rwside) + | _ , `LtoR -> (rwopt.side :> rwside) + | `RtoL, `RtoL -> (`LtoR :> rwside) in let is_baserw p = EcEnv.BaseRw.is_base p.pl_desc (FApi.tc1_env tc) in @@ -814,7 +968,7 @@ let process_rewrite1_r ttenv ?target ri tc = let do1 lemma tc = let pt = PT.pt_of_uglobal_r (PT.copy ptenv) lemma in - process_rewrite1_core ~mode ?target (theside, prw, o) pt tc + process_rewrite1_core ~mode ?target (theside, prw, rwopt.occurrence) pt tc in t_ors (List.map do1 ls) tc | { fp_head = FPNamed (p, None); fp_args = []; } @@ -832,11 +986,11 @@ let process_rewrite1_r ttenv ?target ri tc = let do1 (lemma, _) tc = let pt = PT.pt_of_uglobal_r (PT.copy ptenv0) lemma in - process_rewrite1_core ~mode ?target (theside, prw, o) pt tc in + process_rewrite1_core ~mode ?target (theside, prw, rwopt.occurrence) pt tc in t_ors (List.map do1 ls) tc | _ -> - process_rewrite1_core ~mode ?target (theside, prw, o) pt tc + process_rewrite1_core ~mode ?target (theside, prw, rwopt.occurrence) pt tc end | { fp_head = FPCut (Some f); fp_args = []; } @@ -856,16 +1010,16 @@ let process_rewrite1_r ttenv ?target ri tc = let pt = PTApply { pt_head = PTCut (f, None); pt_args = []; } in let pt = { ptev_env = ptenv; ptev_pt = pt; ptev_ax = f; } in - process_rewrite1_core ~mode ?target (theside, prw, o) pt tc + process_rewrite1_core ~mode ?target (theside, prw, rwopt.occurrence) pt tc | _ -> let pt = PT.process_full_pterm ~implicits ptenv pt in - process_rewrite1_core ~mode ?target (theside, prw, o) pt tc + process_rewrite1_core ~mode ?target (theside, prw, rwopt.occurrence) pt tc in let doall mode tc = t_ors (List.map (do1 mode) pts) tc in - match r with + match rwopt.repeat with | None -> doall `Full tc | Some (`Maybe, None) -> diff --git a/src/ecHiGoal.mli b/src/ecHiGoal.mli index 317163fd6e..81cf7a6b9c 100644 --- a/src/ecHiGoal.mli +++ b/src/ecHiGoal.mli @@ -41,13 +41,18 @@ module LowRewrite : sig | LRW_IdRewriting | LRW_RPatternNoMatch | LRW_RPatternNoRuleMatch + | LRW_RPatternNotLinear + | LRW_RPatternNoVariable exception RewriteError of error val find_rewrite_patterns: rwside -> pt_ev -> (pt_ev * rwmode * (form * form)) list - type rwinfos = rwside * EcFol.form option * EcMatching.occ option + type rwinfos = + rwside + * (form * (EcIdent.t * EcTypes.ty) option) option + * EcMatching.occ option val t_rewrite_r: ?mode:[ `Full | `Light] -> diff --git a/src/ecParser.mly b/src/ecParser.mly index 7bfd266ee1..f01b02b996 100644 --- a/src/ecParser.mly +++ b/src/ecParser.mly @@ -2416,11 +2416,11 @@ rwarg1: | SLASHTILDEQ { RWSimpl `Variant } -| s=rwside r=rwrepeat? o=rwocc? p=bracket(form_h)? fp=rwpterms - { RWRw ((s, r, o, p), fp) } +| side=rwside repeat=rwrepeat? occurrence=rwocc? match_=bracket(rwmatch)? fp=rwpterms + { RWRw ({ side; repeat; occurrence; match_ }, fp) } -| s=rwside r=rwrepeat? o=rwocc? SLASH x=sform_h %prec prec_tactic - { RWDelta ((s, r, o, None), x); } +| side=rwside repeat=rwrepeat? occurrence=rwocc? SLASH fp=sform_h %prec prec_tactic + { RWDelta ({ side; repeat; occurrence; match_ = None }, fp); } | PR s=bracket(rwpr_arg) { RWPr s } @@ -2446,6 +2446,13 @@ rwarg1: parse_error (loc x) (Some msg) } +rwmatch: +| p=form_h + { RWM_Plain p } + +| x=ident IN p=form_h + { RWM_Context (x, p) } + rwpterms: | f=pterm { [(`LtoR, f)] } diff --git a/src/ecParsetree.ml b/src/ecParsetree.ml index 5db41e85a2..7e87f88561 100644 --- a/src/ecParsetree.ml +++ b/src/ecParsetree.ml @@ -926,12 +926,21 @@ and rwarg1 = | RWApp of ppterm | RWTactic of rwtactic -and rwoptions = rwside * trepeat option * rwocc * pformula option +and rwmatch = + | RWM_Plain of pformula + | RWM_Context of psymbol * pformula + and rwside = [`LtoR | `RtoL] and rwocc = rwocci option and rwocci = [`Inclusive of Sint.t | `Exclusive of Sint.t | `All] and rwtactic = [`Ring | `Field] +and rwoptions = + { side : rwside + ; repeat : trepeat option + ; occurrence : rwocc + ; match_ : rwmatch option } + (* -------------------------------------------------------------------- *) let norm_rwocci (x : rwocci) = match x with diff --git a/src/ecProofTerm.ml b/src/ecProofTerm.ml index 9055f24629..b1416d9ea0 100644 --- a/src/ecProofTerm.ml +++ b/src/ecProofTerm.ml @@ -291,16 +291,17 @@ exception FindOccFailure of [`MatchFailure | `IncompleteMatch] type occmode = { k_keyed : bool; k_conv : bool; + k_delta : bool; } -let om_rigid = { k_keyed = true; k_conv = false; } +let om_rigid = { k_keyed = true; k_conv = false; k_delta = true; } let pf_find_occurence (pt : pt_env) ?(full = true) ?(rooted = false) ?occmode ~ptn subject = let module E = struct exception MatchFound of form end in - let occmode = odfl { k_keyed = false; k_conv = true; } occmode in + let occmode = odfl { k_keyed = false; k_conv = true; k_delta = true; } occmode in let na = List.length (snd (EcFol.destr_app ptn)) in let ho = @@ -339,7 +340,10 @@ let pf_find_occurence then EcMatching.fmrigid else EcMatching.fmdelta in - let mode = { mode with fm_conv = occmode.k_conv } in + let mode = { mode with + fm_conv = occmode.k_conv; + fm_delta = mode.fm_delta && occmode.k_delta; + } in let trymatch mode bds tp = if not (keycheck tp key) then `Continue else @@ -382,9 +386,9 @@ let pf_find_occurence (* -------------------------------------------------------------------- *) let default_modes = [ - { k_keyed = true; k_conv = false; }; - { k_keyed = true; k_conv = true; }; - { k_keyed = false; k_conv = true; }; + { k_keyed = true; k_conv = false; k_delta = true; }; + { k_keyed = true; k_conv = true; k_delta = true; }; + { k_keyed = false; k_conv = true; k_delta = true; }; ] let pf_find_occurence_lazy diff --git a/src/ecProofTerm.mli b/src/ecProofTerm.mli index 8f54208794..2f9a44c224 100644 --- a/src/ecProofTerm.mli +++ b/src/ecProofTerm.mli @@ -115,6 +115,7 @@ exception FindOccFailure of [`MatchFailure | `IncompleteMatch] type occmode = { k_keyed : bool; k_conv : bool; + k_delta : bool; } val om_rigid : occmode diff --git a/tests/rw_explicit_eq_with_pattern.ec b/tests/rw_explicit_eq_with_pattern.ec index 6b6aa3d9c9..1d3a580da1 100644 --- a/tests/rw_explicit_eq_with_pattern.ec +++ b/tests/rw_explicit_eq_with_pattern.ec @@ -24,3 +24,125 @@ rewrite [_ + c](_ : _ + c = d). - suff: x + c = d by exact. admit. - suff: x + d = x by exact. admit. qed. + +(* -------------------------------------------------------------------- *) +(* Contextual pattern: [y in x + y] targets the second argument of the + outer addition, leaving the first x untouched. *) +lemma L4 (c d x : int) : x + (x + c) = x. +proof. +rewrite [y in x + y](_ : y = d). +- suff: x + c = d by exact. admit. +- suff: x + d = x by exact. admit. +qed. + +(* -------------------------------------------------------------------- *) +(* Contextual pattern in a deeper nesting: [y in y + c] picks the + inner (x + c) subterm of ((x + c) + c). *) +lemma L5 (c d x : int) : (x + c) + c = x. +proof. +rewrite [y in y + c](_ : y = d). +- suff: x + c = d by exact. admit. +- suff: d + c = x by exact. admit. +qed. + +(* -------------------------------------------------------------------- *) +(* Contextual pattern with a more complex context expression. *) +lemma L6 (a b c x : int) : (a + b) * (a + c) = x. +proof. +rewrite [y in y * (a + c)](_ : y = c). +- suff: a + b = c by exact. admit. +- suff: c * (a + c) = x by exact. admit. +qed. + +(* -------------------------------------------------------------------- *) +(* Contextual pattern with a definition: ensure no unfolding occurs + during pattern search, which would make position computation wrong. *) +op f (x : int) = x + 1. + +lemma L7 (a b : int) : f a + f b = a. +proof. +rewrite [y in f y + _](_ : y = b). +- suff: a = b by exact. admit. +- suff: f b + f b = a by exact. admit. +qed. + +(* -------------------------------------------------------------------- *) +(* Same but the contextual pattern involves the defined operator. *) +lemma L7b (a b c : int) : f a + f b = c. +proof. +rewrite [y in y + f b](_ : y = a). +- suff: f a = a by exact. admit. +- suff: a + f b = c by exact. admit. +qed. + +(* -------------------------------------------------------------------- *) +(* Contextual pattern where the pattern could match via unfolding of f. + The pattern [y in y + 1] should NOT match f a (which unfolds to a + 1), + because that would make position computation wrong. *) +lemma L7c (a b : int) : f a + f b = a. +proof. +fail rewrite [y in y + 1](_ : y = b). +abort. + +(* -------------------------------------------------------------------- *) +(* Reverse: pattern mentions f but goal has the unfolded form. + [y in f y] should NOT match in (a + 1) + (b + 1) via folding. *) +lemma L7d (a b : int) : (a + 1) + (b + 1) = a. +proof. +fail rewrite [y in f y](_ : y = b). +abort. + +(* -------------------------------------------------------------------- *) +(* Inner unfolding: the pattern [y in y + (x + 1)] could match + y + f x via inner unfolding of f. This risks making the position + computation wrong because prw would be the unfolded form while the + goal still has f x. *) +op g (x y : int) = x + y. + +lemma L7e (a b x : int) : g a (f x) + g b (f x) = a. +proof. +rewrite [y in g y (f x)](_ : _ = b). +- suff: a = b by exact. admit. +- suff: g b (f x) + g b (f x) = a by exact. admit. +qed. + +(* -------------------------------------------------------------------- *) +(* Pattern has unfolded form of f in an inner position. + Conversion is disabled for contextual pattern search, so this + must fail even though y is not part of the unfolded subterm. *) +lemma L7f (a b x : int) : g a (f x) + g b (f x) = a. +proof. +fail rewrite [y in g y (x + 1)](_ : y = b). +abort. + +(* -------------------------------------------------------------------- *) +(* Dangerous case: y IS part of the unfolded subterm. The pattern + [y in (3 + y + 1) + f b] could match foo a via inner unfolding, + but that would make position computation wrong. Must be rejected. *) +op h (x : int) = 3 + x + 1. + +lemma L7g (a b : int) : h a + f b = a. +proof. +fail rewrite [y in (3 + y + 1) + f b](_ : y = b). +abort. + +(* -------------------------------------------------------------------- *) +(* Error: context variable does not appear in the r-pattern — should fail. *) +lemma L8_fail (c d x : int) : x + (x + c) = x. +proof. +fail rewrite [y in x + x](_ : x + c = d). +abort. + +(* -------------------------------------------------------------------- *) +(* Error: context variable appears twice — non-linear, should fail. *) +lemma L9_fail (c d x : int) : x + (x + c) = x. +proof. +fail rewrite [y in y + y](_ : x = d). +abort. + +(* -------------------------------------------------------------------- *) +(* Error: rewrite rule LHS does not match the selected subterm. *) +lemma L10_fail (c d x : int) : x + (x + c) = x. +proof. +fail rewrite [y in x + y](_ : x = d). +abort.