Skip to content

Commit ff27c71

Browse files
coeff-aijcoord-e
andauthored
Add annotation support for trait methods and verify that implementations satisfy them (#25)
* add: tests for annotations in traits * add: reference trait-side definitions for require/ensure annotations of functions in impl blocks Update src/analyze/local_def.rs Co-authored-by: Hiromi Ogawa <me@coord-e.com> * change: move extract_*_annot()s to analyze::Analyzer from analyze::local_def::Analyzer * change: insert type names as prefix of name for predicates in impl blocks * add: tests for identifying struct‑bound predicates using `Self::` * add: Identify struct-bounded predicates using `Self::` prefix --------- Co-authored-by: Hiromi Ogawa <me@coord-e.com>
1 parent 4e2836c commit ff27c71

7 files changed

Lines changed: 411 additions & 55 deletions

File tree

src/analyze.rs

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@ use rustc_middle::mir::{self, BasicBlock, Local};
1616
use rustc_middle::ty::{self as mir_ty, TyCtxt};
1717
use rustc_span::def_id::{DefId, LocalDefId};
1818

19+
use crate::analyze;
20+
use crate::annot::{AnnotFormula, AnnotParser, Resolver};
1921
use crate::chc;
2022
use crate::pretty::PrettyDisplayExt as _;
2123
use crate::refine::{self, BasicBlockType, TypeBuilder};
@@ -435,4 +437,54 @@ impl<'tcx> Analyzer<'tcx> {
435437
let body = self.tcx.optimized_mir(local_def_id);
436438
self.local_fn_sig_with_body(local_def_id, body)
437439
}
440+
441+
fn extract_require_annot<T>(
442+
&self,
443+
def_id: DefId,
444+
resolver: T,
445+
self_type_name: Option<String>,
446+
) -> Option<AnnotFormula<T::Output>>
447+
where
448+
T: Resolver,
449+
{
450+
let mut require_annot = None;
451+
let parser = AnnotParser::new(&resolver, self_type_name);
452+
for attrs in self
453+
.tcx
454+
.get_attrs_by_path(def_id, &analyze::annot::requires_path())
455+
{
456+
if require_annot.is_some() {
457+
unimplemented!();
458+
}
459+
let ts = analyze::annot::extract_annot_tokens(attrs.clone());
460+
let require = parser.parse_formula(ts).unwrap();
461+
require_annot = Some(require);
462+
}
463+
require_annot
464+
}
465+
466+
fn extract_ensure_annot<T>(
467+
&self,
468+
def_id: DefId,
469+
resolver: T,
470+
self_type_name: Option<String>,
471+
) -> Option<AnnotFormula<T::Output>>
472+
where
473+
T: Resolver,
474+
{
475+
let mut ensure_annot = None;
476+
let parser = AnnotParser::new(&resolver, self_type_name);
477+
for attrs in self
478+
.tcx
479+
.get_attrs_by_path(def_id, &analyze::annot::ensures_path())
480+
{
481+
if ensure_annot.is_some() {
482+
unimplemented!();
483+
}
484+
let ts = analyze::annot::extract_annot_tokens(attrs.clone());
485+
let ensure = parser.parse_formula(ts).unwrap();
486+
ensure_annot = Some(ensure);
487+
}
488+
ensure_annot
489+
}
438490
}

src/analyze/local_def.rs

Lines changed: 87 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -53,66 +53,38 @@ pub struct Analyzer<'tcx, 'ctx> {
5353
}
5454

5555
impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
56-
fn extract_require_annot<T>(&self, resolver: T) -> Option<AnnotFormula<T::Output>>
57-
where
58-
T: annot::Resolver,
59-
{
60-
let mut require_annot = None;
61-
for attrs in self.tcx.get_attrs_by_path(
62-
self.local_def_id.to_def_id(),
63-
&analyze::annot::requires_path(),
64-
) {
65-
if require_annot.is_some() {
66-
unimplemented!();
67-
}
68-
let ts = analyze::annot::extract_annot_tokens(attrs.clone());
69-
let require = AnnotParser::new(&resolver).parse_formula(ts).unwrap();
70-
require_annot = Some(require);
71-
}
72-
require_annot
73-
}
74-
75-
fn extract_ensure_annot<T>(&self, resolver: T) -> Option<AnnotFormula<T::Output>>
76-
where
77-
T: annot::Resolver,
78-
{
79-
let mut ensure_annot = None;
80-
for attrs in self.tcx.get_attrs_by_path(
81-
self.local_def_id.to_def_id(),
82-
&analyze::annot::ensures_path(),
83-
) {
84-
if ensure_annot.is_some() {
85-
unimplemented!();
86-
}
87-
let ts = analyze::annot::extract_annot_tokens(attrs.clone());
88-
let ensure = AnnotParser::new(&resolver).parse_formula(ts).unwrap();
89-
ensure_annot = Some(ensure);
90-
}
91-
ensure_annot
92-
}
93-
94-
fn extract_param_annots<T>(&self, resolver: T) -> Vec<(Ident, rty::RefinedType<T::Output>)>
56+
fn extract_param_annots<T>(
57+
&self,
58+
resolver: T,
59+
self_type_name: Option<String>,
60+
) -> Vec<(Ident, rty::RefinedType<T::Output>)>
9561
where
9662
T: annot::Resolver,
9763
{
9864
let mut param_annots = Vec::new();
65+
let parser = AnnotParser::new(&resolver, self_type_name);
9966
for attrs in self
10067
.tcx
10168
.get_attrs_by_path(self.local_def_id.to_def_id(), &analyze::annot::param_path())
10269
{
10370
let ts = analyze::annot::extract_annot_tokens(attrs.clone());
10471
let (ident, ts) = analyze::annot::split_param(&ts);
105-
let param = AnnotParser::new(&resolver).parse_rty(ts).unwrap();
72+
let param = parser.parse_rty(ts).unwrap();
10673
param_annots.push((ident, param));
10774
}
10875
param_annots
10976
}
11077

111-
fn extract_ret_annot<T>(&self, resolver: T) -> Option<rty::RefinedType<T::Output>>
78+
fn extract_ret_annot<T>(
79+
&self,
80+
resolver: T,
81+
self_type_name: Option<String>,
82+
) -> Option<rty::RefinedType<T::Output>>
11283
where
11384
T: annot::Resolver,
11485
{
11586
let mut ret_annot = None;
87+
let parser = AnnotParser::new(&resolver, self_type_name);
11688
for attrs in self
11789
.tcx
11890
.get_attrs_by_path(self.local_def_id.to_def_id(), &analyze::annot::ret_path())
@@ -121,14 +93,34 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
12193
unimplemented!();
12294
}
12395
let ts = analyze::annot::extract_annot_tokens(attrs.clone());
124-
let ret = AnnotParser::new(&resolver).parse_rty(ts).unwrap();
96+
let ret = parser.parse_rty(ts).unwrap();
12597
ret_annot = Some(ret);
12698
}
12799
ret_annot
128100
}
129101

102+
fn impl_type(&self) -> Option<rustc_middle::ty::Ty<'tcx>> {
103+
use rustc_hir::def::DefKind;
104+
105+
let parent_def_id = self.tcx.parent(self.local_def_id.to_def_id());
106+
107+
if !matches!(self.tcx.def_kind(parent_def_id), DefKind::Impl { .. }) {
108+
return None;
109+
}
110+
111+
let self_ty = self.tcx.type_of(parent_def_id).instantiate_identity();
112+
113+
Some(self_ty)
114+
}
115+
130116
pub fn analyze_predicate_definition(&self, local_def_id: LocalDefId) {
131-
let pred_name = self.tcx.item_name(local_def_id.to_def_id()).to_string();
117+
// predicate's name
118+
let impl_type = self.impl_type();
119+
let pred_item_name = self.tcx.item_name(local_def_id.to_def_id()).to_string();
120+
let pred_name = match impl_type {
121+
Some(t) => t.to_string() + "_" + &pred_item_name,
122+
None => pred_item_name,
123+
};
132124

133125
// function's body
134126
use rustc_hir::{Block, Expr, ExprKind};
@@ -252,6 +244,17 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
252244
|| (all_params_annotated && has_ret)
253245
}
254246

247+
pub fn trait_item_id(&self) -> Option<LocalDefId> {
248+
let impl_item_assoc = self
249+
.tcx
250+
.opt_associated_item(self.local_def_id.to_def_id())?;
251+
let trait_item_id = impl_item_assoc
252+
.trait_item_def_id
253+
.and_then(|id| id.as_local())?;
254+
255+
Some(trait_item_id)
256+
}
257+
255258
pub fn expected_ty(&mut self) -> rty::RefinedType {
256259
let sig = self
257260
.ctx
@@ -268,16 +271,47 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
268271
param_resolver.push_param(input_ident.name, input_ty.to_sort());
269272
}
270273

271-
let mut require_annot = self.extract_require_annot(&param_resolver);
272-
let mut ensure_annot = {
273-
let output_ty = self.type_builder.build(sig.output());
274-
let resolver = annot::StackedResolver::default()
275-
.resolver(analyze::annot::ResultResolver::new(output_ty.to_sort()))
276-
.resolver((&param_resolver).map(rty::RefinedTypeVar::Free));
277-
self.extract_ensure_annot(resolver)
278-
};
279-
let param_annots = self.extract_param_annots(&param_resolver);
280-
let ret_annot = self.extract_ret_annot(&param_resolver);
274+
let output_ty = self.type_builder.build(sig.output());
275+
let result_param_resolver = annot::StackedResolver::default()
276+
.resolver(analyze::annot::ResultResolver::new(output_ty.to_sort()))
277+
.resolver((&param_resolver).map(rty::RefinedTypeVar::Free));
278+
279+
let self_type_name = self.impl_type().map(|ty| ty.to_string());
280+
281+
let mut require_annot = self.ctx.extract_require_annot(
282+
self.local_def_id.to_def_id(),
283+
&param_resolver,
284+
self_type_name.clone(),
285+
);
286+
287+
let mut ensure_annot = self.ctx.extract_ensure_annot(
288+
self.local_def_id.to_def_id(),
289+
&result_param_resolver,
290+
self_type_name.clone(),
291+
);
292+
293+
if let Some(trait_item_id) = self.trait_item_id() {
294+
tracing::info!("trait item fonud: {:?}", trait_item_id);
295+
let trait_require_annot = self.ctx.extract_require_annot(
296+
trait_item_id.into(),
297+
&param_resolver,
298+
self_type_name.clone(),
299+
);
300+
let trait_ensure_annot = self.ctx.extract_ensure_annot(
301+
trait_item_id.into(),
302+
&result_param_resolver,
303+
self_type_name.clone(),
304+
);
305+
306+
assert!(require_annot.is_none() || trait_require_annot.is_none());
307+
require_annot = require_annot.or(trait_require_annot);
308+
309+
assert!(ensure_annot.is_none() || trait_ensure_annot.is_none());
310+
ensure_annot = ensure_annot.or(trait_ensure_annot);
311+
}
312+
313+
let param_annots = self.extract_param_annots(&param_resolver, self_type_name.clone());
314+
let ret_annot = self.extract_ret_annot(&param_resolver, self_type_name);
281315

282316
if self.is_annotated_as_callable() {
283317
if require_annot.is_some() || ensure_annot.is_some() {

src/annot.rs

Lines changed: 44 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -250,6 +250,7 @@ impl<T> FormulaOrTerm<T> {
250250
/// A parser for refinement type annotations and formula annotations.
251251
struct Parser<'a, T> {
252252
resolver: T,
253+
self_type_name: Option<String>,
253254
cursor: RefTokenTreeCursor<'a>,
254255
formula_existentials: HashMap<String, chc::Sort>,
255256
}
@@ -453,6 +454,7 @@ where
453454
TokenTree::Delimited(_, _, Delimiter::Parenthesis, s) => {
454455
let mut parser = Parser {
455456
resolver: self.boxed_resolver(),
457+
self_type_name: self.self_type_name.clone(),
456458
cursor: s.trees(),
457459
formula_existentials: self.formula_existentials.clone(),
458460
};
@@ -493,6 +495,7 @@ where
493495

494496
let mut parser = Parser {
495497
resolver: self.boxed_resolver(),
498+
self_type_name: self.self_type_name.clone(),
496499
cursor: args.trees(),
497500
formula_existentials: self.formula_existentials.clone(),
498501
};
@@ -518,11 +521,40 @@ where
518521
};
519522
let mut parser = Parser {
520523
resolver: self.boxed_resolver(),
524+
self_type_name: self.self_type_name.clone(),
521525
cursor: s.trees(),
522526
formula_existentials: self.formula_existentials.clone(),
523527
};
524528
let args = parser.parse_arg_terms()?;
525529
parser.end_of_input()?;
530+
531+
// Identify struct-bound predicates call such as `Self::pred()`
532+
match path.segments.first() {
533+
Some(AnnotPathSegment {
534+
ident: Ident { name: symbol, .. },
535+
generic_args,
536+
}) if symbol.as_str() == "Self" && generic_args.is_empty() => {
537+
if path.segments.len() != 2 {
538+
unimplemented!("long path beginning with `Self::`");
539+
}
540+
541+
let func_name = path.segments.get(1).unwrap().ident.name.as_str();
542+
let pred_name = if let Some(self_type_name) = &self.self_type_name {
543+
self_type_name.clone() + "_" + func_name
544+
} else {
545+
func_name.to_string()
546+
};
547+
548+
let pred_symbol = chc::UserDefinedPred::new(pred_name);
549+
let pred = chc::Pred::UserDefined(pred_symbol);
550+
551+
let atom = chc::Atom::new(pred, args);
552+
let formula = chc::Formula::Atom(atom);
553+
return Ok(FormulaOrTerm::Formula(formula));
554+
}
555+
_ => {}
556+
}
557+
526558
let (term, sort) = path.to_datatype_ctor(args);
527559
FormulaOrTerm::Term(term, sort)
528560
}
@@ -908,6 +940,7 @@ where
908940
TokenTree::Delimited(_, _, Delimiter::Parenthesis, ts) => {
909941
let mut parser = Parser {
910942
resolver: self.boxed_resolver(),
943+
self_type_name: self.self_type_name.clone(),
911944
cursor: ts.trees(),
912945
formula_existentials: self.formula_existentials.clone(),
913946
};
@@ -1014,6 +1047,7 @@ where
10141047
TokenTree::Delimited(_, _, Delimiter::Parenthesis, ts) => {
10151048
let mut parser = Parser {
10161049
resolver: self.boxed_resolver(),
1050+
self_type_name: self.self_type_name.clone(),
10171051
cursor: ts.trees(),
10181052
formula_existentials: self.formula_existentials.clone(),
10191053
};
@@ -1050,6 +1084,7 @@ where
10501084

10511085
let mut parser = Parser {
10521086
resolver: self.boxed_resolver(),
1087+
self_type_name: self.self_type_name.clone(),
10531088
cursor: ts.trees(),
10541089
formula_existentials: self.formula_existentials.clone(),
10551090
};
@@ -1074,6 +1109,7 @@ where
10741109

10751110
let mut parser = Parser {
10761111
resolver: RefinementResolver::new(self.boxed_resolver()),
1112+
self_type_name: self.self_type_name.clone(),
10771113
cursor: parser.cursor,
10781114
formula_existentials: self.formula_existentials.clone(),
10791115
};
@@ -1199,11 +1235,15 @@ impl<'a, T> StackedResolver<'a, T> {
11991235
#[derive(Debug, Clone)]
12001236
pub struct AnnotParser<T> {
12011237
resolver: T,
1238+
self_type_name: Option<String>,
12021239
}
12031240

12041241
impl<T> AnnotParser<T> {
1205-
pub fn new(resolver: T) -> Self {
1206-
Self { resolver }
1242+
pub fn new(resolver: T, self_type_name: Option<String>) -> Self {
1243+
Self {
1244+
resolver,
1245+
self_type_name,
1246+
}
12071247
}
12081248
}
12091249

@@ -1214,6 +1254,7 @@ where
12141254
pub fn parse_rty(&self, ts: TokenStream) -> Result<rty::RefinedType<T::Output>> {
12151255
let mut parser = Parser {
12161256
resolver: &self.resolver,
1257+
self_type_name: self.self_type_name.clone(),
12171258
cursor: ts.trees(),
12181259
formula_existentials: Default::default(),
12191260
};
@@ -1225,6 +1266,7 @@ where
12251266
pub fn parse_formula(&self, ts: TokenStream) -> Result<AnnotFormula<T::Output>> {
12261267
let mut parser = Parser {
12271268
resolver: &self.resolver,
1269+
self_type_name: self.self_type_name.clone(),
12281270
cursor: ts.trees(),
12291271
formula_existentials: Default::default(),
12301272
};

0 commit comments

Comments
 (0)