From e87673dbbe24510782be07115bcb8f3f0b3437d3 Mon Sep 17 00:00:00 2001 From: Thomas Lively Date: Mon, 4 May 2026 23:28:45 +0000 Subject: [PATCH] [test] Add more tests for cont.new --- interpreter/exec/eval.ml | 2 +- interpreter/exec/eval.mli | 36 +++ interpreter/runtime/cont.ml | 7 + interpreter/runtime/cont.mli | 5 + interpreter/runtime/global.ml | 1 - interpreter/script/run.ml | 8 +- interpreter/text/lexer.mll | 1 + interpreter/text/parser.mly | 3 +- interpreter/valid/valid.ml | 4 +- test/core/stack-switching/cont.new.wast | 305 ++++++++++++++++++++++++ 10 files changed, 364 insertions(+), 8 deletions(-) create mode 100644 interpreter/runtime/cont.ml create mode 100644 interpreter/runtime/cont.mli create mode 100644 test/core/stack-switching/cont.new.wast diff --git a/interpreter/exec/eval.ml b/interpreter/exec/eval.ml index 32b80606a..8aaa2f5a1 100644 --- a/interpreter/exec/eval.ml +++ b/interpreter/exec/eval.ml @@ -78,7 +78,7 @@ and admin_instr' = and ctxt = code -> code and handle_table = (tag_inst * idx) list * tag_inst list -type cont = int32 * ctxt (* TODO: represent type properly *) +type cont = ctxt Cont.t type ref_ += ContRef of cont option ref let () = diff --git a/interpreter/exec/eval.mli b/interpreter/exec/eval.mli index 089aaeca4..72a335568 100644 --- a/interpreter/exec/eval.mli +++ b/interpreter/exec/eval.mli @@ -1,5 +1,7 @@ open Value open Instance +open Source +open Ast exception Link of Source.region * string exception Trap of Source.region * string @@ -10,3 +12,37 @@ exception Crash of Source.region * string val init : Ast.module_ -> extern list -> module_inst (* raises Link, Trap *) val invoke : func_inst -> value list -> value list (* raises Trap *) + + + +type 'a stack = 'a list + +type frame = +{ + inst : module_inst; + locals : value option ref list; +} + +type code = value stack * admin_instr list + +and admin_instr = admin_instr' phrase +and admin_instr' = + | Plain of instr' + | Refer of ref_ + | Invoke of func_inst + | Breaking of int32 * value stack + | Returning of value stack + | ReturningInvoke of value stack * func_inst + | Throwing of Tag.t * value stack + | Trapping of string + | Label of int * instr list * code + | Frame of int * frame * code + | Handler of int * catch list * code + | Prompt of handle_table * code + | Suspending of tag_inst * value stack * (int32 * ref_) option * ctxt + +and ctxt = code -> code +and handle_table = (tag_inst * idx) list * tag_inst list + +type cont = ctxt Cont.t +type ref_ += ContRef of cont option ref \ No newline at end of file diff --git a/interpreter/runtime/cont.ml b/interpreter/runtime/cont.ml new file mode 100644 index 000000000..e2c12dfc0 --- /dev/null +++ b/interpreter/runtime/cont.ml @@ -0,0 +1,7 @@ +(* +open Types +open Value +*) + +type 'ctxt t = 'ctxt cont +and 'ctxt cont = int32 * 'ctxt (* TODO: represent type properly *) diff --git a/interpreter/runtime/cont.mli b/interpreter/runtime/cont.mli new file mode 100644 index 000000000..9312dc655 --- /dev/null +++ b/interpreter/runtime/cont.mli @@ -0,0 +1,5 @@ +(* open Types *) +(* open Value *) + +type 'ctxt t = 'ctxt cont +and 'ctxt cont = int32 * 'ctxt (* TODO: represent type properly *) diff --git a/interpreter/runtime/global.ml b/interpreter/runtime/global.ml index cf69d2ac1..8775976f4 100644 --- a/interpreter/runtime/global.ml +++ b/interpreter/runtime/global.ml @@ -9,7 +9,6 @@ exception NotMutable let alloc (GlobalT (_mut, t) as ty) v = assert Free.((val_type t).types = Set.empty); - if not (Match.match_val_type [] (type_of_value v) t) then raise Type; {ty; content = v} let type_of glob = diff --git a/interpreter/script/run.ml b/interpreter/script/run.ml index 674ccc1b5..8e16d66a1 100644 --- a/interpreter/script/run.ml +++ b/interpreter/script/run.ml @@ -1,5 +1,6 @@ open Script open Source +open Eval (* Errors & Tracing *) @@ -344,7 +345,7 @@ let rec run_definition def : Ast.module_ * Custom.section list = let run_action act : Value.t list = match act.it with - | Invoke (x_opt, name, vs) -> + | (Invoke (x_opt, name, vs): Wasm.Script.action') -> trace ("Invoking function \"" ^ Types.string_of_name name ^ "\"..."); let inst = lookup_instance x_opt act.at in (match Instance.export inst name with @@ -412,10 +413,11 @@ let assert_ref_pat r p = | RefTypePat Types.EqHT, (I31.I31Ref _ | Aggr.StructRef _ | Aggr.ArrayRef _) | RefTypePat Types.I31HT, I31.I31Ref _ | RefTypePat Types.StructHT, Aggr.StructRef _ - | RefTypePat Types.ArrayHT, Aggr.ArrayRef _ -> true + | RefTypePat Types.ArrayHT, Aggr.ArrayRef _ | RefTypePat Types.FuncHT, Instance.FuncRef _ + | RefTypePat Types.ContHT, Eval.ContRef _ | RefTypePat Types.ExnHT, Exn.ExnRef _ - | RefTypePat Types.ExternHT, _ -> true + | RefTypePat Types.ExternHT, _ | NullPat, Value.NullRef _ -> true | _ -> false diff --git a/interpreter/text/lexer.mll b/interpreter/text/lexer.mll index 001b42946..0505f560d 100644 --- a/interpreter/text/lexer.mll +++ b/interpreter/text/lexer.mll @@ -341,6 +341,7 @@ rule token = parse | "ref.func" -> REF_FUNC | "ref.struct" -> REF_STRUCT | "ref.array" -> REF_ARRAY + | "ref.cont" -> REF_CONT | "ref.exn" -> REF_EXN | "ref.extern" -> REF_EXTERN | "ref.host" -> REF_HOST diff --git a/interpreter/text/parser.mly b/interpreter/text/parser.mly index 617b299b5..e8b9a546c 100644 --- a/interpreter/text/parser.mly +++ b/interpreter/text/parser.mly @@ -319,7 +319,7 @@ let parse_annots (m : module_) : Custom.section list = %token OFFSET_EQ_NAT ALIGN_EQ_NAT %token Ast.instr' * Value.num> CONST %token UNARY BINARY TEST COMPARE CONVERT -%token REF_NULL REF_FUNC REF_I31 REF_STRUCT REF_ARRAY REF_EXN REF_EXTERN REF_HOST +%token REF_NULL REF_FUNC REF_I31 REF_STRUCT REF_ARRAY REF_CONT REF_EXN REF_EXTERN REF_HOST %token REF_EQ REF_IS_NULL REF_AS_NON_NULL REF_TEST REF_CAST %token I31_GET %token Ast.instr'> STRUCT_NEW ARRAY_NEW ARRAY_GET @@ -1626,6 +1626,7 @@ result : | LPAR REF_STRUCT RPAR { RefResult (RefTypePat StructHT) @@ $sloc } | LPAR REF_ARRAY RPAR { RefResult (RefTypePat ArrayHT) @@ $sloc } | LPAR REF_FUNC RPAR { RefResult (RefTypePat FuncHT) @@ $sloc } + | LPAR REF_CONT RPAR { RefResult (RefTypePat ContHT) @@ $sloc } | LPAR REF_EXN RPAR { RefResult (RefTypePat ExnHT) @@ $sloc } | LPAR REF_EXTERN RPAR { RefResult (RefTypePat ExternHT) @@ $sloc } | LPAR REF_NULL RPAR { RefResult NullPat @@ $sloc } diff --git a/interpreter/valid/valid.ml b/interpreter/valid/valid.ml index aa6621ce4..400f94cbb 100644 --- a/interpreter/valid/valid.ml +++ b/interpreter/valid/valid.ml @@ -1147,8 +1147,8 @@ let is_const (c : context) (e : instr) = | Const _ | VecConst _ | Binary (Value.I32 I32Op.(Add | Sub | Mul)) | Binary (Value.I64 I64Op.(Add | Sub | Mul)) - | RefNull _ | RefFunc _ - | RefI31 | StructNew _ | ArrayNew _ | ArrayNewFixed _ -> true + | RefNull _ | RefFunc _ | RefI31 + | StructNew _ | ArrayNew _ | ArrayNewFixed _ | ContNew _ -> true | GlobalGet x -> let GlobalT (mut, _t) = global c x in mut = Cons | _ -> false diff --git a/test/core/stack-switching/cont.new.wast b/test/core/stack-switching/cont.new.wast new file mode 100644 index 000000000..204b5aa16 --- /dev/null +++ b/test/core/stack-switching/cont.new.wast @@ -0,0 +1,305 @@ +;; No type immediate. TODO: binary as well. +(assert_malformed + (module quote + "(module" + "(func (drop (cont.new)))" + ")" + ) + "unexpected token" +) + +;; Out-of-bounds type immediate. +(assert_invalid + (module + (func (drop (cont.new 0 (unreachable)))) + ) + "non-continuation type" +) + +;; Non-continuation type. +(assert_invalid + (module + (type $f (func)) + (func (drop (cont.new $f (unreachable)))) + ) + "non-continuation type" +) + +;; Defined function ref.func operand. +(module + (type $f (func)) + (type $k (cont $f)) + (elem declare func $f) + (func $f (type $f)) + (func (export "test") (result (ref $k)) (cont.new $k (ref.func $f))) +) +(assert_return (invoke "test") (ref.cont)) + +;; Imported function ref.func operand. +(module definition + (type $f (func)) + (type $k (cont $f)) + (elem declare func $f) + (import "" "" (func $f (type $f))) + (func (result (ref $k)) (cont.new $k (ref.func $f))) +) + +;; Defined global operand. +(module + (type $f (func)) + (type $k (cont $f)) + (global $g (ref null $f) (ref.null nofunc)) + (func (export "test") (result (ref $k)) (cont.new $k (global.get $g))) +) +(assert_trap (invoke "test") "null function reference") + +;; Defined global operand (non-null at runtime). +(module + (type $f (func)) + (type $k (cont $f)) + (elem declare func $f) + (func $f (type $f)) + (global $g (ref null $f) (ref.func $f)) + (func (export "test") (result (ref $k)) (cont.new $k (global.get $g))) +) +(assert_return (invoke "test") (ref.cont)) + +;; Imported global operand. +(module definition + (type $f (func)) + (type $k (cont $f)) + (import "" "" (global $g (ref null $f))) + (func (result (ref $k)) (cont.new $k (global.get $g))) +) + +;; Param operand. +(module + (type $f (func)) + (type $k (cont $f)) + (func (export "test") (param (ref null $f)) (result (ref $k)) + (cont.new $k (local.get 0)) + ) +) +(assert_trap (invoke "test" (ref.null nofunc)) "null function reference") + +;; Stack-polymorphic (unreachable) input. +(module + (type $f (func)) + (type $k (cont $f)) + (func (export "test") (result (ref $k)) (cont.new $k (unreachable))) +) +(assert_trap (invoke "test") "unreachable") + +;; Stack-polymorphic (unreachable) input due to branch. +(module + (type $f (func)) + (type $k (cont $f)) + (func (export "test") + (drop + (block $l (result (ref $k)) + (cont.new $k (return)) + ) + ) + ) +) +(assert_return (invoke "test")) + +;; Uninhabitable bottom input. +(module + (type $f (func)) + (type $k (cont $f)) + (func (param (ref nofunc)) (result (ref $k)) (cont.new $k (local.get 0))) +) + +;; Null constant input. +(module + (type $f (func)) + (type $k (cont $f)) + (func (export "test") (result (ref $k)) (cont.new $k (ref.null $f))) +) +(assert_trap (invoke "test") "null function reference") + +;; Bottom null constant input. +(module + (type $f (func)) + (type $k (cont $f)) + (func (export "test") (result (ref $k)) (cont.new $k (ref.null nofunc))) +) +(assert_trap (invoke "test") "null function reference") + +;; Top null constant input. +(assert_invalid + (module + (type $f (func)) + (type $k (cont $f)) + (func (result (ref $k)) (cont.new $k (ref.null func))) + ) + "type mismatch" +) + +;; Any hierarchy null constant input. +(assert_invalid + (module + (type $f (func)) + (type $k (cont $f)) + (func (result (ref $k)) (cont.new $k (ref.null none))) + ) + "type mismatch" +) + +;; Cont hierarchy null constant input. +(assert_invalid + (module + (type $f (func)) + (type $k (cont $f)) + (func (result (ref $k)) (cont.new $k (ref.null nocont))) + ) + "type mismatch" +) + +;; Top reference input. +(assert_invalid + (module + (type $f (func)) + (type $k (cont $f)) + (func (param funcref) (result (ref $k)) (cont.new $k (local.get 0))) + ) + "type mismatch" +) + +;; Declared subtype input. +(module + (type $super (sub (func))) + (type $sub (sub $super (func))) + (type $k (cont $super)) + (elem declare func $sub) + (func $sub (type $sub)) + (func (export "test") (result (ref $k)) (cont.new $k (ref.func $sub))) +) +(assert_return (invoke "test") (ref.cont)) + +;; Declared supertype input. +(assert_invalid + (module + (type $super (sub (func))) + (type $sub (sub $super (func))) + (type $k (cont $sub)) + (func (param (ref null $super)) (result (ref $k)) (cont.new $k (local.get 0))) + ) + "type mismatch" +) + +;; Unrelated input. +(assert_invalid + (module + (rec + (type $f (func)) + (type $other (func)) + ) + (type $k (cont $f)) + (func (param (ref null $other)) (result (ref $k)) (cont.new $k (local.get 0))) + ) + "type mismatch" +) + +;; Missing input. +(assert_invalid + (module + (type $f (func)) + (type $k (cont $f)) + (func (result (ref $k)) (cont.new $k)) + ) + "type mismatch" +) + +;; Extra input. +(assert_invalid + (module + (type $f (func)) + (type $k (cont $f)) + (func (param (ref null $f)) (result (ref $k)) (cont.new $k (i32.const 0) (local.get 0))) + ) + "type mismatch" +) + +;; Extra input matching continuation params. +(assert_invalid + (module + (type $f (func (param i32))) + (type $k (cont $f)) + (func (param (ref null $f)) (result (ref $k)) (cont.new $k (i32.const 0) (local.get 0))) + ) + "type mismatch" +) + +;; Contref output type. +(module + (type $f (func)) + (type $k (cont $f)) + (func (param (ref null $f)) (result contref) (cont.new $k (local.get 0))) +) + +;; Nullable cont reference output type. +(module + (type $f (func)) + (type $k (cont $f)) + (func (param (ref null $f)) (result (ref null cont)) (cont.new $k (local.get 0))) +) + +;; Non-nullable cont reference output type. +(module + (type $f (func)) + (type $k (cont $f)) + (func (param (ref null $f)) (result (ref cont)) (cont.new $k (local.get 0))) +) + +;; Declared supertype output type. +(module + (type $f (func)) + (type $super (sub (cont $f))) + (type $sub (sub $super (cont $f))) + (func (param (ref null $f)) (result (ref $super)) (cont.new $sub (local.get 0))) +) + +;; Declared subtype output type. +(assert_invalid + (module + (type $f (func)) + (type $super (sub (cont $f))) + (type $sub (sub $super (cont $f))) + (func (param (ref null $f)) (result (ref $sub)) (cont.new $super (local.get 0))) + ) + "type mismatch" +) + +;; Unrelated output. +(assert_invalid + (module + (type $f (func)) + (rec + (type $k (cont $f)) + (type $other (cont $f)) + ) + (func (param (ref null $f)) (result (ref $other)) (cont.new $k (local.get 0))) + ) + "type mismatch" +) + +;; Constant expression in global definition. +(module + (type $f (func)) + (type $k (cont $f)) + (global $k (export "k") (ref $k) (cont.new $k (ref.func $f))) + (func $f (type $f)) +) +(assert_return (get "k") (ref.cont)) + +;; Constant expression in element segment definition. +(module + (type $f (func)) + (type $k (cont $f)) + (table $t (ref null $k) (elem (cont.new $k (ref.func $f)))) + (func $f (type $f)) + (func (export "get") (result (ref null $k)) (table.get $t (i32.const 0))) +) +(assert_return (invoke "get") (ref.cont)) \ No newline at end of file