Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 29 additions & 1 deletion lib/typecheck.ml
Original file line number Diff line number Diff line change
Expand Up @@ -1312,6 +1312,28 @@ let register_builtins (ctx : context) : unit =

(** Check a top-level function declaration. *)
let check_fn_decl (ctx : context) (fd : fn_decl) : unit result =
(* #135 slice 7: register the explicit `<T>` type parameters as fresh,
generalizable unification variables before lowering param/return
types. Without this, an uppercase param like `x: T` lowered to a
*rigid* `TCon "T"` (lower_type_expr TyCon fallthrough), which
`generalize` ignores — so every generic top-level function was
effectively monomorphic and the second instantiation blew up with
`Unify.TypeMismatch (T, Int)` (and `use prelude::{…}` import-checks
failed transitively). Mirrors the let-generalization discipline:
enter a deeper level, create the vars there, generalize at the outer
level so they become the scheme's quantified tyvars. *)
let tp_names = List.map (fun (tp : type_param) -> tp.tp_name.name)
fd.fd_type_params in
let saved_tp = List.map (fun n -> (n, Hashtbl.find_opt ctx.type_env n))
tp_names in
enter_level ctx;
List.iter (fun n ->
Hashtbl.replace ctx.type_env n (fresh_tyvar ctx.level)) tp_names;
let restore_tp () =
List.iter (fun (n, old) -> match old with
| Some t -> Hashtbl.replace ctx.type_env n t
| None -> Hashtbl.remove ctx.type_env n) saved_tp
in
(* Extern functions have no body — register the signature so callers can
typecheck against it, then bail out before the body-check pass. *)
if fd.fd_body = FnExtern then begin
Expand Down Expand Up @@ -1340,8 +1362,10 @@ let check_fn_decl (ctx : context) (fd : fn_decl) : unit result =
in
TArrow (param_ty, q, acc, fn_eff)
) param_tys fd.fd_params ret_ty in
exit_level ctx;
let sc = generalize ctx fn_ty in
bind_scheme ctx fd.fd_name.name sc;
restore_tp ();
Ok ()
end
else
Expand Down Expand Up @@ -1396,9 +1420,13 @@ let check_fn_decl (ctx : context) (fd : fn_decl) : unit result =
| Some sc -> Hashtbl.replace ctx.name_types n sc
| None -> Hashtbl.remove ctx.name_types n
) old;
(* Generalize and rebind the function with its polymorphic type *)
(* Generalize and rebind the function with its polymorphic type.
exit_level first so the `<T>` type-param vars (created at the deeper
level above) are quantified by `generalize` (#135 slice 7). *)
exit_level ctx;
let sc = generalize ctx fn_ty in
bind_scheme ctx fd.fd_name.name sc;
restore_tp ();
Ok ()

(** Register a type declaration in the context. *)
Expand Down
14 changes: 7 additions & 7 deletions stdlib/prelude.affine
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,15 @@ type Result<T, E> = Ok(T) | Err(E)
// ============================================================================

fn map<T, U>(arr: [T], f: T -> U) -> [U] {
let result = [];
let mut result = [];
for x in arr {
result = result ++ [f(x)];
}
result
}

fn filter<T>(arr: [T], predicate: T -> Bool) -> [T] {
let result = [];
let mut result = [];
for x in arr {
if predicate(x) {
result = result ++ [x];
Expand All @@ -43,7 +43,7 @@ fn filter<T>(arr: [T], predicate: T -> Bool) -> [T] {
}

fn fold<T, U>(arr: [T], init: U, f: (U, T) -> U) -> U {
let acc = init;
let mut acc = init;
for x in arr {
acc = f(acc, x);
}
Expand Down Expand Up @@ -121,8 +121,8 @@ fn any(arr: [Bool]) -> Bool {
// ============================================================================

fn range(start: Int, end: Int) -> [Int] {
let result = [];
let i = start;
let mut result = [];
let mut i = start;
while i < end {
result = result ++ [i];
i = i + 1;
Expand All @@ -131,8 +131,8 @@ fn range(start: Int, end: Int) -> [Int] {
}

fn repeat<T>(value: T, n: Int) -> [T] {
let result = [];
let i = 0;
let mut result = [];
let mut i = 0;
while i < n {
result = result ++ [value];
i = i + 1;
Expand Down
19 changes: 19 additions & 0 deletions test/test_e2e.ml
Original file line number Diff line number Diff line change
Expand Up @@ -3363,6 +3363,23 @@ let test_trait_sig_and_assoc_not_regressed () =
pub fn next(mut self) -> Option<Int>;
}|})

(* Issue #135 slice 7: top-level generic functions must instantiate their
`<T>` scheme with fresh vars per call. Before, `<T>` lowered to a rigid
`TCon "T"` that `generalize` ignored, so any 2nd instantiation failed
with `Unify.TypeMismatch (T, Int)` (and `use prelude` import-checks
failed transitively). *)
let test_generic_fn_multi_instantiation () =
Alcotest.(check bool) "id<T> called at Int and Bool in one program" true
(parse_check_passes
{|fn id<T>(x: T) -> T { return x; }
fn use_it() -> Bool { let a = id(1); let b = id(true); return b; }|})

let test_generic_hof_monomorphic_caller () =
Alcotest.(check bool) "generic fold<T,U> called by monomorphic Int sum" true
(parse_check_passes
{|fn fold<T, U>(arr: [T], init: U, f: (U, T) -> U) -> U { return init; }
fn sum(a: [Int]) -> Int { return fold(a, 0, fn(acc, x) => acc + x); }|})

let test_multi_arg_arrow () =
Alcotest.(check bool) "(A, B) -> C parses + typechecks" true
(parse_check_passes
Expand Down Expand Up @@ -3418,6 +3435,8 @@ let type_syntax_sugar_tests = [
Alcotest.test_case "effect E; + -> T / E (#135 slice 3)" `Quick test_bare_effect_and_effect_row;
Alcotest.test_case "trait default body + ref self (#135 sl5)" `Quick test_trait_default_body;
Alcotest.test_case "trait sig + assoc non-regressed (#135 sl5)" `Quick test_trait_sig_and_assoc_not_regressed;
Alcotest.test_case "generic fn multi-instantiation (#135 sl7)" `Quick test_generic_fn_multi_instantiation;
Alcotest.test_case "generic HOF + mono caller (#135 sl7)" `Quick test_generic_hof_monomorphic_caller;
Alcotest.test_case "(A, B) -> C (multi-arg arrow)" `Quick test_multi_arg_arrow;
Alcotest.test_case "(A, B) without arrow remains tuple" `Quick test_tuple_type_still_works;
]
Expand Down
Loading