From 685351d61290afaf807a03c969b7947aa347c0f9 Mon Sep 17 00:00:00 2001 From: Ben Sully Date: Tue, 14 Oct 2025 21:58:49 +0100 Subject: [PATCH 1/5] feat: add a 'records' example, and some required instructions Prior to this commit there was ostensibly no support for 'record' WIT types, but there wasn't actually much required for them to work. This commit adds an example for use in the test harness and the required instructions to load/store integer record fields. Relates to #4. --- Cargo.lock | 8 + cmd/gravity/src/codegen/exports.rs | 2 +- cmd/gravity/src/codegen/func.rs | 179 +++++++++++++- cmd/gravity/tests/cmd/instructions.stdout | 2 +- cmd/gravity/tests/cmd/records.stderr | 0 cmd/gravity/tests/cmd/records.stdout | 277 ++++++++++++++++++++++ cmd/gravity/tests/cmd/records.toml | 2 + examples/generate.go | 2 + examples/records/Cargo.toml | 11 + examples/records/records_test.go | 66 ++++++ examples/records/src/lib.rs | 31 +++ examples/records/wit/records.wit | 18 ++ 12 files changed, 587 insertions(+), 11 deletions(-) create mode 100644 cmd/gravity/tests/cmd/records.stderr create mode 100644 cmd/gravity/tests/cmd/records.stdout create mode 100644 cmd/gravity/tests/cmd/records.toml create mode 100644 examples/records/Cargo.toml create mode 100644 examples/records/records_test.go create mode 100644 examples/records/src/lib.rs create mode 100644 examples/records/wit/records.wit diff --git a/Cargo.lock b/Cargo.lock index de83507..c2f8a35 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -187,6 +187,14 @@ dependencies = [ "wit-component", ] +[[package]] +name = "example-records" +version = "0.0.2" +dependencies = [ + "wit-bindgen", + "wit-component", +] + [[package]] name = "foldhash" version = "0.1.4" diff --git a/cmd/gravity/src/codegen/exports.rs b/cmd/gravity/src/codegen/exports.rs index 9005aba..f0906eb 100644 --- a/cmd/gravity/src/codegen/exports.rs +++ b/cmd/gravity/src/codegen/exports.rs @@ -166,7 +166,7 @@ mod tests { assert!(generated.contains("if err1 != nil {")); assert!(generated.contains("panic(err1)")); assert!(generated.contains("results1 := raw1[0]")); - assert!(generated.contains("result2 := api.DecodeU32(results1)")); + assert!(generated.contains("result2 := api.DecodeU32(uint64(results1))")); assert!(generated.contains("return result2")); } } diff --git a/cmd/gravity/src/codegen/func.rs b/cmd/gravity/src/codegen/func.rs index f7490ab..d2df053 100644 --- a/cmd/gravity/src/codegen/func.rs +++ b/cmd/gravity/src/codegen/func.rs @@ -323,7 +323,7 @@ impl Bindgen for Func<'_> { let operand = &operands[0]; quote_in! { self.body => $['\r'] - $result := $WAZERO_API_DECODE_U32($operand) + $result := $WAZERO_API_DECODE_U32(uint64($operand)) }; results.push(Operand::SingleValue(result.into())); } @@ -1087,16 +1087,168 @@ impl Bindgen for Func<'_> { Instruction::I32Load8S { .. } => todo!("implement instruction: {inst:?}"), Instruction::I32Load16U { .. } => todo!("implement instruction: {inst:?}"), Instruction::I32Load16S { .. } => todo!("implement instruction: {inst:?}"), - Instruction::I64Load { .. } => todo!("implement instruction: {inst:?}"), - Instruction::F32Load { .. } => todo!("implement instruction: {inst:?}"), - Instruction::F64Load { .. } => todo!("implement instruction: {inst:?}"), + Instruction::I64Load { offset } => { + // TODO(#58): Support additional ArchitectureSize + let offset = offset.size_wasm32(); + let tmp = self.tmp(); + let value = &format!("value{tmp}"); + let ok = &format!("ok{tmp}"); + let default = &format!("default{tmp}"); + let operand = &operands[0]; + quote_in! { self.body => + $['\r'] + $value, $ok := i.module.Memory().ReadUint64Le(uint32($operand + $offset)) + $(match &self.result { + GoResult::Anon(GoType::ValueOrError(typ)) => { + if !$ok { + var $default $(typ.as_ref()) + return $default, $ERRORS_NEW("failed to read i64 from memory") + } + } + GoResult::Anon(GoType::Error) => { + if !$ok { + return $ERRORS_NEW("failed to read i64 from memory") + } + } + GoResult::Anon(_) | GoResult::Empty => { + $(comment(&["The return type doesn't contain an error so we panic if one is encountered"])) + if !$ok { + panic($ERRORS_NEW("failed to read i64 from memory")) + } + } + }) + }; + results.push(Operand::SingleValue(value.into())); + } + Instruction::F32Load { offset } => { + // TODO(#58): Support additional ArchitectureSize + let offset = offset.size_wasm32(); + let tmp = self.tmp(); + let value = &format!("value{tmp}"); + let ok = &format!("ok{tmp}"); + let default = &format!("default{tmp}"); + let operand = &operands[0]; + quote_in! { self.body => + $['\r'] + $value, $ok := i.module.Memory().ReadUint64Le(uint32($operand + $offset)) + $(match &self.result { + GoResult::Anon(GoType::ValueOrError(typ)) => { + if !$ok { + var $default $(typ.as_ref()) + return $default, $ERRORS_NEW("failed to read f64 from memory") + } + } + GoResult::Anon(GoType::Error) => { + if !$ok { + return $ERRORS_NEW("failed to read f64 from memory") + } + } + GoResult::Anon(_) | GoResult::Empty => { + $(comment(&["The return type doesn't contain an error so we panic if one is encountered"])) + if !$ok { + panic($ERRORS_NEW("failed to read f64 from memory")) + } + } + }) + }; + results.push(Operand::SingleValue(value.into())); + } + Instruction::F64Load { offset } => { + // TODO(#58): Support additional ArchitectureSize + let offset = offset.size_wasm32(); + let tmp = self.tmp(); + let value = &format!("value{tmp}"); + let ok = &format!("ok{tmp}"); + let default = &format!("default{tmp}"); + let operand = &operands[0]; + quote_in! { self.body => + $['\r'] + $value, $ok := i.module.Memory().ReadUint64Le(uint32($operand + $offset)) + $(match &self.result { + GoResult::Anon(GoType::ValueOrError(typ)) => { + if !$ok { + var $default $(typ.as_ref()) + return $default, $ERRORS_NEW("failed to read f64 from memory") + } + } + GoResult::Anon(GoType::Error) => { + if !$ok { + return $ERRORS_NEW("failed to read f64 from memory") + } + } + GoResult::Anon(_) | GoResult::Empty => { + $(comment(&["The return type doesn't contain an error so we panic if one is encountered"])) + if !$ok { + panic($ERRORS_NEW("failed to read f64 from memory")) + } + } + }) + }; + results.push(Operand::SingleValue(value.into())); + } Instruction::I32Store16 { .. } => todo!("implement instruction: {inst:?}"), Instruction::I64Store { .. } => todo!("implement instruction: {inst:?}"), - Instruction::F32Store { .. } => todo!("implement instruction: {inst:?}"), - Instruction::F64Store { .. } => todo!("implement instruction: {inst:?}"), + Instruction::F32Store { offset } => { + // TODO(#58): Support additional ArchitectureSize + let offset = offset.size_wasm32(); + let tag = &operands[0]; + let ptr = &operands[1]; + match &self.direction { + Direction::Export => { + quote_in! { self.body => + $['\r'] + i.module.Memory().WriteUint64Le($ptr+$offset, $tag) + } + } + Direction::Import { .. } => { + quote_in! { self.body => + $['\r'] + mod.Memory().WriteUint64Le($ptr+$offset, $tag) + } + } + } + } + Instruction::F64Store { offset } => { + // TODO(#58): Support additional ArchitectureSize + let offset = offset.size_wasm32(); + let tag = &operands[0]; + let ptr = &operands[1]; + match &self.direction { + Direction::Export => { + quote_in! { self.body => + $['\r'] + i.module.Memory().WriteUint64Le($ptr+$offset, $tag) + } + } + Direction::Import { .. } => { + quote_in! { self.body => + $['\r'] + mod.Memory().WriteUint64Le($ptr+$offset, $tag) + } + } + } + } Instruction::I32FromChar => todo!("implement instruction: {inst:?}"), - Instruction::I64FromU64 => todo!("implement instruction: {inst:?}"), - Instruction::I64FromS64 => todo!("implement instruction: {inst:?}"), + Instruction::I64FromU64 => { + let tmp = self.tmp(); + let value = format!("value{tmp}"); + let operand = &operands[0]; + quote_in! { self.body => + $['\r'] + $(&value) := int64($operand) + } + results.push(Operand::SingleValue(value.into())); + } + Instruction::I64FromS64 => { + let tmp = self.tmp(); + let value = format!("value{tmp}"); + let operand = &operands[0]; + quote_in! { self.body => + $['\r'] + $(&value) := $operand + } + results.push(Operand::SingleValue(value.into())); + } Instruction::I32FromS32 => { let tmp = self.tmp(); let value = format!("value{tmp}"); @@ -1196,7 +1348,16 @@ impl Bindgen for Func<'_> { results.push(Operand::SingleValue(result.into())); } Instruction::S64FromI64 => todo!("implement instruction: {inst:?}"), - Instruction::U64FromI64 => todo!("implement instruction: {inst:?}"), + Instruction::U64FromI64 => { + let tmp = self.tmp(); + let value = format!("value{tmp}"); + let operand = &operands[0]; + quote_in! { self.body => + $['\r'] + $(&value) := uint64($operand) + } + results.push(Operand::SingleValue(value.into())); + } Instruction::CharFromI32 => todo!("implement instruction: {inst:?}"), Instruction::F32FromCoreF32 => { let tmp = self.tmp(); diff --git a/cmd/gravity/tests/cmd/instructions.stdout b/cmd/gravity/tests/cmd/instructions.stdout index 738aa51..a9f4cee 100644 --- a/cmd/gravity/tests/cmd/instructions.stdout +++ b/cmd/gravity/tests/cmd/instructions.stdout @@ -180,7 +180,7 @@ func (i *InstructionsInstance) U32Roundtrip( } results1 := raw1[0] - result2 := api.DecodeU32(results1) + result2 := api.DecodeU32(uint64(results1)) return result2 } diff --git a/cmd/gravity/tests/cmd/records.stderr b/cmd/gravity/tests/cmd/records.stderr new file mode 100644 index 0000000..e69de29 diff --git a/cmd/gravity/tests/cmd/records.stdout b/cmd/gravity/tests/cmd/records.stdout new file mode 100644 index 0000000..96d71ef --- /dev/null +++ b/cmd/gravity/tests/cmd/records.stdout @@ -0,0 +1,277 @@ +// Code generated by arcjet-gravity; DO NOT EDIT. + +package records + +import "context" +import "errors" +import "github.com/tetratelabs/wazero" +import "github.com/tetratelabs/wazero/api" + +import _ "embed" + +//go:embed records.wasm +var wasmFileRecords []byte + +type IRecordsTypes interface {} + +type Foo struct { + Float32 float32 + + Float64 float64 + + Uint32 uint32 + + Uint64 uint64 + + S string + + Vf32 []float32 + + Vf64 []float64 +} + +type RecordsFactory struct { + runtime wazero.Runtime + module wazero.CompiledModule +} + +func NewRecordsFactory( + ctx context.Context, + types IRecordsTypes, +) (*RecordsFactory, error) { + wazeroRuntime := wazero.NewRuntime(ctx) + + _, err0 := wazeroRuntime.NewHostModuleBuilder("arcjet:records/types"). + Instantiate(ctx) + if err0 != nil { + return nil, err0 + } + + // Compiling the module takes a LONG time, so we want to do it once and hold + // onto it with the Runtime + module, err := wazeroRuntime.CompileModule(ctx, wasmFileRecords) + if err != nil { + return nil, err + } + return &RecordsFactory{ + runtime: wazeroRuntime, + module: module, + }, nil +} + +func (f *RecordsFactory) Instantiate(ctx context.Context) (*RecordsInstance, error) { + if module, err := f.runtime.InstantiateModule(ctx, f.module, wazero.NewModuleConfig()); err != nil { + return nil, err + } else { + return &RecordsInstance{module}, nil + } +} + +func (f *RecordsFactory) Close(ctx context.Context) { + f.runtime.Close(ctx) +} + +type RecordsInstance struct { + module api.Module +} + +func (i *RecordsInstance) Close(ctx context.Context) error { + if err := i.module.Close(ctx); err != nil { + return err + } + + return nil +} + +// writeString will put a Go string into the Wasm memory following the Component +// Model calling conventions, such as allocating memory with the realloc function +func writeString( + ctx context.Context, + s string, + memory api.Memory, + realloc api.Function, +) (uint64, uint64, error) { + if len(s) == 0 { + return 1, 0, nil + } + + results, err := realloc.Call(ctx, 0, 0, 1, uint64(len(s))) + if err != nil { + return 1, 0, err + } + ptr := results[0] + ok := memory.Write(uint32(ptr), []byte(s)) + if !ok { + return 1, 0, errors.New("failed to write string to wasm memory") + } + return uint64(ptr), uint64(len(s)), nil +} + +func (i *RecordsInstance) ModifyFoo( + ctx context.Context, + f Foo, +) Foo { + arg0 := f + float320 := arg0.Float32 + float640 := arg0.Float64 + uint320 := arg0.Uint32 + uint640 := arg0.Uint64 + s0 := arg0.S + vf320 := arg0.Vf32 + vf640 := arg0.Vf64 + result1 := api.EncodeF32(float320) + result2 := api.EncodeF64(float640) + result3 := api.EncodeU32(uint320) + value4 := int64(uint640) + memory5 := i.module.Memory() + realloc5 := i.module.ExportedFunction("cabi_realloc") + ptr5, len5, err5 := writeString(ctx, s0, memory5, realloc5) + // The return type doesn't contain an error so we panic if one is encountered + if err5 != nil { + panic(err5) + } + vec7 := vf320 + len7 := uint64(len(vec7)) + result7, err7 := i.module.ExportedFunction("cabi_realloc").Call(ctx, 0, 0, 4, len7 * 4) + // The return type doesn't contain an error so we panic if one is encountered + if err7 != nil { + panic(err7) + } + ptr7 := result7[0] + for idx := uint64(0); idx < len7; idx++ { + e := vec7[idx] + base := uint32(ptr7 + uint64(idx) * uint64(4)) + result6 := api.EncodeF32(e) + i.module.Memory().WriteUint64Le(base+0, result6) + } + vec9 := vf640 + len9 := uint64(len(vec9)) + result9, err9 := i.module.ExportedFunction("cabi_realloc").Call(ctx, 0, 0, 8, len9 * 8) + // The return type doesn't contain an error so we panic if one is encountered + if err9 != nil { + panic(err9) + } + ptr9 := result9[0] + for idx := uint64(0); idx < len9; idx++ { + e := vec9[idx] + base := uint32(ptr9 + uint64(idx) * uint64(8)) + result8 := api.EncodeF64(e) + i.module.Memory().WriteUint64Le(base+0, result8) + } + raw10, err10 := i.module.ExportedFunction("modify-foo").Call(ctx, uint64(result1), uint64(result2), uint64(result3), uint64(value4), uint64(ptr5), uint64(len5), uint64(ptr7), uint64(len7), uint64(ptr9), uint64(len9)) + // The return type doesn't contain an error so we panic if one is encountered + if err10 != nil { + panic(err10) + } + + // The cleanup via `cabi_post_*` cleans up the memory in the guest. By + // deferring this, we ensure that no memory is corrupted before the function + // is done accessing it. + defer func() { + if _, err := i.module.ExportedFunction("cabi_post_modify-foo").Call(ctx, raw10...); err != nil { + // If we get an error during cleanup, something really bad is + // going on, so we panic. Also, you can't return the error from + // the `defer` + panic(errors.New("failed to cleanup")) + } + }() + + results10 := raw10[0] + value11, ok11 := i.module.Memory().ReadUint64Le(uint32(results10 + 0)) + // The return type doesn't contain an error so we panic if one is encountered + if !ok11 { + panic(errors.New("failed to read f64 from memory")) + } + result12 := api.DecodeF32(value11) + value13, ok13 := i.module.Memory().ReadUint64Le(uint32(results10 + 8)) + // The return type doesn't contain an error so we panic if one is encountered + if !ok13 { + panic(errors.New("failed to read f64 from memory")) + } + result14 := api.DecodeF64(value13) + value15, ok15 := i.module.Memory().ReadUint32Le(uint32(results10 + 16)) + // The return type doesn't contain an error so we panic if one is encountered + if !ok15 { + panic(errors.New("failed to read i32 from memory")) + } + result16 := api.DecodeU32(uint64(value15)) + value17, ok17 := i.module.Memory().ReadUint64Le(uint32(results10 + 24)) + // The return type doesn't contain an error so we panic if one is encountered + if !ok17 { + panic(errors.New("failed to read i64 from memory")) + } + value18 := uint64(value17) + ptr19, ok19 := i.module.Memory().ReadUint32Le(uint32(results10 + 32)) + // The return type doesn't contain an error so we panic if one is encountered + if !ok19 { + panic(errors.New("failed to read pointer from memory")) + } + len20, ok20 := i.module.Memory().ReadUint32Le(uint32(results10 + 36)) + // The return type doesn't contain an error so we panic if one is encountered + if !ok20 { + panic(errors.New("failed to read length from memory")) + } + buf21, ok21 := i.module.Memory().Read(ptr19, len20) + // The return type doesn't contain an error so we panic if one is encountered + if !ok21 { + panic(errors.New("failed to read bytes from memory")) + } + str21 := string(buf21) + ptr22, ok22 := i.module.Memory().ReadUint32Le(uint32(results10 + 40)) + // The return type doesn't contain an error so we panic if one is encountered + if !ok22 { + panic(errors.New("failed to read pointer from memory")) + } + len23, ok23 := i.module.Memory().ReadUint32Le(uint32(results10 + 44)) + // The return type doesn't contain an error so we panic if one is encountered + if !ok23 { + panic(errors.New("failed to read length from memory")) + } + base26 := ptr22 + len26 := len23 + result26 := make([]float32, len26) + for idx26 := uint32(0); idx26 < len26; idx26++ { + base := base26 + idx26 * 4 + value24, ok24 := i.module.Memory().ReadUint64Le(uint32(base + 0)) + // The return type doesn't contain an error so we panic if one is encountered + if !ok24 { + panic(errors.New("failed to read f64 from memory")) + } + result25 := api.DecodeF32(value24) + result26[idx26] = result25 + } + ptr27, ok27 := i.module.Memory().ReadUint32Le(uint32(results10 + 48)) + // The return type doesn't contain an error so we panic if one is encountered + if !ok27 { + panic(errors.New("failed to read pointer from memory")) + } + len28, ok28 := i.module.Memory().ReadUint32Le(uint32(results10 + 52)) + // The return type doesn't contain an error so we panic if one is encountered + if !ok28 { + panic(errors.New("failed to read length from memory")) + } + base31 := ptr27 + len31 := len28 + result31 := make([]float64, len31) + for idx31 := uint32(0); idx31 < len31; idx31++ { + base := base31 + idx31 * 8 + value29, ok29 := i.module.Memory().ReadUint64Le(uint32(base + 0)) + // The return type doesn't contain an error so we panic if one is encountered + if !ok29 { + panic(errors.New("failed to read f64 from memory")) + } + result30 := api.DecodeF64(value29) + result31[idx31] = result30 + } + value32 := Foo{ + Float32: result12, + Float64: result14, + Uint32: result16, + Uint64: value18, + S: str21, + Vf32: result26, + Vf64: result31, + } + return value32 +} + diff --git a/cmd/gravity/tests/cmd/records.toml b/cmd/gravity/tests/cmd/records.toml new file mode 100644 index 0000000..4cba21d --- /dev/null +++ b/cmd/gravity/tests/cmd/records.toml @@ -0,0 +1,2 @@ +bin.name = "gravity" +args = "--world records ../../target/wasm32-unknown-unknown/release/example_records.wasm" diff --git a/examples/generate.go b/examples/generate.go index 47058c8..2ec47ab 100644 --- a/examples/generate.go +++ b/examples/generate.go @@ -1,9 +1,11 @@ package examples //go:generate cargo build -p example-basic --target wasm32-unknown-unknown --release +//go:generate cargo build -p example-records --target wasm32-unknown-unknown --release //go:generate cargo build -p example-iface-method-returns-string --target wasm32-unknown-unknown --release //go:generate cargo build -p example-instructions --target wasm32-unknown-unknown --release //go:generate cargo run --bin gravity -- --world basic --output ./basic/basic.go ../target/wasm32-unknown-unknown/release/example_basic.wasm +//go:generate cargo run --bin gravity -- --world records --output ./records/records.go ../target/wasm32-unknown-unknown/release/example_records.wasm //go:generate cargo run --bin gravity -- --world example --output ./iface-method-returns-string/example.go ../target/wasm32-unknown-unknown/release/example_iface_method_returns_string.wasm //go:generate cargo run --bin gravity -- --world instructions --output ./instructions/bindings.go ../target/wasm32-unknown-unknown/release/example_instructions.wasm diff --git a/examples/records/Cargo.toml b/examples/records/Cargo.toml new file mode 100644 index 0000000..1ae90a3 --- /dev/null +++ b/examples/records/Cargo.toml @@ -0,0 +1,11 @@ +[package] +name = "example-records" +version = "0.0.2" +edition = "2024" + +[lib] +crate-type = ["cdylib"] + +[dependencies] +wit-bindgen = "=0.46.0" +wit-component = "=0.239.0" diff --git a/examples/records/records_test.go b/examples/records/records_test.go new file mode 100644 index 0000000..83e6416 --- /dev/null +++ b/examples/records/records_test.go @@ -0,0 +1,66 @@ +package records + +import ( + "math" + "testing" +) + +type types struct{} + +func TestRecord(t *testing.T) { + tys := types{} + fac, err := NewRecordsFactory(t.Context(), tys) + if err != nil { + t.Fatal(err) + } + defer fac.Close(t.Context()) + + ins, err := fac.Instantiate(t.Context()) + if err != nil { + t.Fatal(err) + } + defer ins.Close(t.Context()) + + foo := Foo{ + Float32: 1.0, + Float64: 1.0, + Uint32: 1, + Uint64: uint64(math.MaxUint32), + S: "hello", + Vf32: []float32{1.0, 2.0, 3.0}, + Vf64: []float64{1.0, 2.0, 3.0}, + } + got := ins.ModifyFoo(t.Context(), foo) + want := Foo{ + Float32: foo.Float32 * 2.0, + Float64: foo.Float64 * 2.0, + Uint32: foo.Uint32 + 1, + Uint64: foo.Uint64 + 1, + S: "received hello", + Vf32: []float32{2.0, 4.0, 6.0}, + Vf64: []float64{2.0, 4.0, 6.0}, + } + if !fooCmp(got, want) { + t.Fatalf("got %+v, want %+v", got, want) + } +} + +func fooCmp(a, b Foo) bool { + if a.Float32 != b.Float32 || a.Float64 != b.Float64 || a.Uint32 != b.Uint32 || a.Uint64 != b.Uint64 || a.S != b.S { + return false + } + if len(a.Vf32) != len(b.Vf32) || len(a.Vf64) != len(b.Vf64) { + return false + } + for i := range a.Vf32 { + if a.Vf32[i] != b.Vf32[i] { + return false + } + } + for i := range a.Vf64 { + if a.Vf64[i] != b.Vf64[i] { + return false + } + } + return true +} diff --git a/examples/records/src/lib.rs b/examples/records/src/lib.rs new file mode 100644 index 0000000..f70d0f5 --- /dev/null +++ b/examples/records/src/lib.rs @@ -0,0 +1,31 @@ +wit_bindgen::generate!({ + world: "records", +}); + +struct RecordsWorld; + +export!(RecordsWorld); + +impl Guest for RecordsWorld { + fn modify_foo( + Foo { + float64, + float32, + uint32, + uint64, + s, + vf32, + vf64, + }: Foo, + ) -> Foo { + Foo { + float64: float64 * 2.0, + float32: float32 * 2.0, + uint32: uint32 + 1, + uint64: uint64 + 1, + s: format!("received {s}"), + vf32: vf32.iter().map(|v| v * 2.0).collect(), + vf64: vf64.iter().map(|v| v * 2.0).collect(), + } + } +} diff --git a/examples/records/wit/records.wit b/examples/records/wit/records.wit new file mode 100644 index 0000000..b467229 --- /dev/null +++ b/examples/records/wit/records.wit @@ -0,0 +1,18 @@ +package arcjet:records; + +interface types { + record foo { + float32: f32, + float64: f64, + uint32: u32, + uint64: u64, + s: string, + vf32: list, + vf64: list, + } +} + +world records { + use types.{foo}; + export modify-foo: func(f: foo) -> foo; +} From 990022a6299216865a3bf5019a66ec7849370f9d Mon Sep 17 00:00:00 2001 From: Ben Sully Date: Thu, 4 Dec 2025 21:24:33 +0000 Subject: [PATCH 2/5] Apply suggestions from code review Co-authored-by: blaine-arcjet <146491715+blaine-arcjet@users.noreply.github.com> --- cmd/gravity/src/codegen/func.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/cmd/gravity/src/codegen/func.rs b/cmd/gravity/src/codegen/func.rs index d2df053..cf08c09 100644 --- a/cmd/gravity/src/codegen/func.rs +++ b/cmd/gravity/src/codegen/func.rs @@ -1135,18 +1135,18 @@ impl Bindgen for Func<'_> { GoResult::Anon(GoType::ValueOrError(typ)) => { if !$ok { var $default $(typ.as_ref()) - return $default, $ERRORS_NEW("failed to read f64 from memory") + return $default, $ERRORS_NEW("failed to read f32 from memory") } } GoResult::Anon(GoType::Error) => { if !$ok { - return $ERRORS_NEW("failed to read f64 from memory") + return $ERRORS_NEW("failed to read f32 from memory") } } GoResult::Anon(_) | GoResult::Empty => { $(comment(&["The return type doesn't contain an error so we panic if one is encountered"])) if !$ok { - panic($ERRORS_NEW("failed to read f64 from memory")) + panic($ERRORS_NEW("failed to read f32 from memory")) } } }) From 599000d71850edf49014a823fab2a49e9b5ba73c Mon Sep 17 00:00:00 2001 From: Ben Sully Date: Thu, 4 Dec 2025 21:34:59 +0000 Subject: [PATCH 3/5] Add fallible records test --- examples/records/records_test.go | 74 ++++++++++++++++++++++++++++++++ examples/records/src/lib.rs | 26 +++++++++++ examples/records/wit/records.wit | 1 + 3 files changed, 101 insertions(+) diff --git a/examples/records/records_test.go b/examples/records/records_test.go index 83e6416..077e975 100644 --- a/examples/records/records_test.go +++ b/examples/records/records_test.go @@ -45,6 +45,80 @@ func TestRecord(t *testing.T) { } } +func TestRecordFallibleSuccess(t *testing.T) { + tys := types{} + fac, err := NewRecordsFactory(t.Context(), tys) + if err != nil { + t.Fatal(err) + } + defer fac.Close(t.Context()) + + ins, err := fac.Instantiate(t.Context()) + if err != nil { + t.Fatal(err) + } + defer ins.Close(t.Context()) + + foo := Foo{ + Float32: 1.0, + Float64: 5.0, // <= 10.0, should succeed + Uint32: 1, + Uint64: uint64(math.MaxUint32), + S: "hello", + Vf32: []float32{1.0, 2.0, 3.0}, + Vf64: []float64{1.0, 2.0, 3.0}, + } + got, err := ins.ModifyFooFallible(t.Context(), foo) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + want := Foo{ + Float32: foo.Float32 * 2.0, + Float64: foo.Float64 * 2.0, + Uint32: foo.Uint32 + 1, + Uint64: foo.Uint64 + 1, + S: "received hello", + Vf32: []float32{2.0, 4.0, 6.0}, + Vf64: []float64{2.0, 4.0, 6.0}, + } + if !fooCmp(got, want) { + t.Fatalf("got %+v, want %+v", got, want) + } +} + +func TestRecordFallibleError(t *testing.T) { + tys := types{} + fac, err := NewRecordsFactory(t.Context(), tys) + if err != nil { + t.Fatal(err) + } + defer fac.Close(t.Context()) + + ins, err := fac.Instantiate(t.Context()) + if err != nil { + t.Fatal(err) + } + defer ins.Close(t.Context()) + + foo := Foo{ + Float32: 1.0, + Float64: 15.0, // > 10.0, should error + Uint32: 1, + Uint64: uint64(math.MaxUint32), + S: "hello", + Vf32: []float32{1.0, 2.0, 3.0}, + Vf64: []float64{1.0, 2.0, 3.0}, + } + _, err = ins.ModifyFooFallible(t.Context(), foo) + if err == nil { + t.Fatal("expected error, got nil") + } + wantErr := "float64 too big" + if err.Error() != wantErr { + t.Fatalf("got error %q, want %q", err.Error(), wantErr) + } +} + func fooCmp(a, b Foo) bool { if a.Float32 != b.Float32 || a.Float64 != b.Float64 || a.Uint32 != b.Uint32 || a.Uint64 != b.Uint64 || a.S != b.S { return false diff --git a/examples/records/src/lib.rs b/examples/records/src/lib.rs index f70d0f5..4315259 100644 --- a/examples/records/src/lib.rs +++ b/examples/records/src/lib.rs @@ -28,4 +28,30 @@ impl Guest for RecordsWorld { vf64: vf64.iter().map(|v| v * 2.0).collect(), } } + + fn modify_foo_fallible( + Foo { + float64, + float32, + uint32, + uint64, + s, + vf32, + vf64, + }: Foo, + ) -> Result { + if float64 > 10.0 { + Err("float64 too big".to_string()) + } else { + Ok(Foo { + float64: float64 * 2.0, + float32: float32 * 2.0, + uint32: uint32 + 1, + uint64: uint64 + 1, + s: format!("received {s}"), + vf32: vf32.iter().map(|v| v * 2.0).collect(), + vf64: vf64.iter().map(|v| v * 2.0).collect(), + }) + } + } } diff --git a/examples/records/wit/records.wit b/examples/records/wit/records.wit index b467229..c3cfb14 100644 --- a/examples/records/wit/records.wit +++ b/examples/records/wit/records.wit @@ -15,4 +15,5 @@ interface types { world records { use types.{foo}; export modify-foo: func(f: foo) -> foo; + export modify-foo-fallible: func(f: foo) -> result; } From 85448641eaf9c45d5ddbae205bd56a5c28f09eb4 Mon Sep 17 00:00:00 2001 From: Ben Sully Date: Thu, 4 Dec 2025 21:37:45 +0000 Subject: [PATCH 4/5] Define record type inline, not in separate interface --- examples/records/records_test.go | 11 +++-------- examples/records/wit/records.wit | 5 +---- 2 files changed, 4 insertions(+), 12 deletions(-) diff --git a/examples/records/records_test.go b/examples/records/records_test.go index 077e975..f1e0201 100644 --- a/examples/records/records_test.go +++ b/examples/records/records_test.go @@ -5,11 +5,8 @@ import ( "testing" ) -type types struct{} - func TestRecord(t *testing.T) { - tys := types{} - fac, err := NewRecordsFactory(t.Context(), tys) + fac, err := NewRecordsFactory(t.Context()) if err != nil { t.Fatal(err) } @@ -46,8 +43,7 @@ func TestRecord(t *testing.T) { } func TestRecordFallibleSuccess(t *testing.T) { - tys := types{} - fac, err := NewRecordsFactory(t.Context(), tys) + fac, err := NewRecordsFactory(t.Context()) if err != nil { t.Fatal(err) } @@ -87,8 +83,7 @@ func TestRecordFallibleSuccess(t *testing.T) { } func TestRecordFallibleError(t *testing.T) { - tys := types{} - fac, err := NewRecordsFactory(t.Context(), tys) + fac, err := NewRecordsFactory(t.Context()) if err != nil { t.Fatal(err) } diff --git a/examples/records/wit/records.wit b/examples/records/wit/records.wit index c3cfb14..015bfd6 100644 --- a/examples/records/wit/records.wit +++ b/examples/records/wit/records.wit @@ -1,6 +1,6 @@ package arcjet:records; -interface types { +world records { record foo { float32: f32, float64: f64, @@ -10,10 +10,7 @@ interface types { vf32: list, vf64: list, } -} -world records { - use types.{foo}; export modify-foo: func(f: foo) -> foo; export modify-foo-fallible: func(f: foo) -> result; } From 71140b15f67385983decec8dc944e6c839347a20 Mon Sep 17 00:00:00 2001 From: Ben Sully Date: Thu, 4 Dec 2025 21:39:34 +0000 Subject: [PATCH 5/5] Update CLI tests --- cmd/gravity/tests/cmd/records.stdout | 212 +++++++++++++++++++++++++-- 1 file changed, 201 insertions(+), 11 deletions(-) diff --git a/cmd/gravity/tests/cmd/records.stdout b/cmd/gravity/tests/cmd/records.stdout index 96d71ef..4745179 100644 --- a/cmd/gravity/tests/cmd/records.stdout +++ b/cmd/gravity/tests/cmd/records.stdout @@ -12,8 +12,6 @@ import _ "embed" //go:embed records.wasm var wasmFileRecords []byte -type IRecordsTypes interface {} - type Foo struct { Float32 float32 @@ -37,16 +35,9 @@ type RecordsFactory struct { func NewRecordsFactory( ctx context.Context, - types IRecordsTypes, ) (*RecordsFactory, error) { wazeroRuntime := wazero.NewRuntime(ctx) - _, err0 := wazeroRuntime.NewHostModuleBuilder("arcjet:records/types"). - Instantiate(ctx) - if err0 != nil { - return nil, err0 - } - // Compiling the module takes a LONG time, so we want to do it once and hold // onto it with the Runtime module, err := wazeroRuntime.CompileModule(ctx, wasmFileRecords) @@ -180,7 +171,7 @@ func (i *RecordsInstance) ModifyFoo( value11, ok11 := i.module.Memory().ReadUint64Le(uint32(results10 + 0)) // The return type doesn't contain an error so we panic if one is encountered if !ok11 { - panic(errors.New("failed to read f64 from memory")) + panic(errors.New("failed to read f32 from memory")) } result12 := api.DecodeF32(value11) value13, ok13 := i.module.Memory().ReadUint64Le(uint32(results10 + 8)) @@ -235,7 +226,7 @@ func (i *RecordsInstance) ModifyFoo( value24, ok24 := i.module.Memory().ReadUint64Le(uint32(base + 0)) // The return type doesn't contain an error so we panic if one is encountered if !ok24 { - panic(errors.New("failed to read f64 from memory")) + panic(errors.New("failed to read f32 from memory")) } result25 := api.DecodeF32(value24) result26[idx26] = result25 @@ -275,3 +266,202 @@ func (i *RecordsInstance) ModifyFoo( return value32 } +func (i *RecordsInstance) ModifyFooFallible( + ctx context.Context, + f Foo, +) (Foo, error) { + arg0 := f + float320 := arg0.Float32 + float640 := arg0.Float64 + uint320 := arg0.Uint32 + uint640 := arg0.Uint64 + s0 := arg0.S + vf320 := arg0.Vf32 + vf640 := arg0.Vf64 + result1 := api.EncodeF32(float320) + result2 := api.EncodeF64(float640) + result3 := api.EncodeU32(uint320) + value4 := int64(uint640) + memory5 := i.module.Memory() + realloc5 := i.module.ExportedFunction("cabi_realloc") + ptr5, len5, err5 := writeString(ctx, s0, memory5, realloc5) + if err5 != nil { + var default5 Foo + return default5, err5 + } + vec7 := vf320 + len7 := uint64(len(vec7)) + result7, err7 := i.module.ExportedFunction("cabi_realloc").Call(ctx, 0, 0, 4, len7 * 4) + if err7 != nil { + var default7 Foo + return default7, err7 + } + ptr7 := result7[0] + for idx := uint64(0); idx < len7; idx++ { + e := vec7[idx] + base := uint32(ptr7 + uint64(idx) * uint64(4)) + result6 := api.EncodeF32(e) + i.module.Memory().WriteUint64Le(base+0, result6) + } + vec9 := vf640 + len9 := uint64(len(vec9)) + result9, err9 := i.module.ExportedFunction("cabi_realloc").Call(ctx, 0, 0, 8, len9 * 8) + if err9 != nil { + var default9 Foo + return default9, err9 + } + ptr9 := result9[0] + for idx := uint64(0); idx < len9; idx++ { + e := vec9[idx] + base := uint32(ptr9 + uint64(idx) * uint64(8)) + result8 := api.EncodeF64(e) + i.module.Memory().WriteUint64Le(base+0, result8) + } + raw10, err10 := i.module.ExportedFunction("modify-foo-fallible").Call(ctx, uint64(result1), uint64(result2), uint64(result3), uint64(value4), uint64(ptr5), uint64(len5), uint64(ptr7), uint64(len7), uint64(ptr9), uint64(len9)) + if err10 != nil { + var default10 Foo + return default10, err10 + } + + // The cleanup via `cabi_post_*` cleans up the memory in the guest. By + // deferring this, we ensure that no memory is corrupted before the function + // is done accessing it. + defer func() { + if _, err := i.module.ExportedFunction("cabi_post_modify-foo-fallible").Call(ctx, raw10...); err != nil { + // If we get an error during cleanup, something really bad is + // going on, so we panic. Also, you can't return the error from + // the `defer` + panic(errors.New("failed to cleanup")) + } + }() + + results10 := raw10[0] + value11, ok11 := i.module.Memory().ReadByte(uint32(results10 + 0)) + if !ok11 { + var default11 Foo + return default11, errors.New("failed to read byte from memory") + } + var value37 Foo + var err37 error + switch value11 { + case 0: + value12, ok12 := i.module.Memory().ReadUint64Le(uint32(results10 + 8)) + if !ok12 { + var default12 Foo + return default12, errors.New("failed to read f32 from memory") + } + result13 := api.DecodeF32(value12) + value14, ok14 := i.module.Memory().ReadUint64Le(uint32(results10 + 16)) + if !ok14 { + var default14 Foo + return default14, errors.New("failed to read f64 from memory") + } + result15 := api.DecodeF64(value14) + value16, ok16 := i.module.Memory().ReadUint32Le(uint32(results10 + 24)) + if !ok16 { + var default16 Foo + return default16, errors.New("failed to read i32 from memory") + } + result17 := api.DecodeU32(uint64(value16)) + value18, ok18 := i.module.Memory().ReadUint64Le(uint32(results10 + 32)) + if !ok18 { + var default18 Foo + return default18, errors.New("failed to read i64 from memory") + } + value19 := uint64(value18) + ptr20, ok20 := i.module.Memory().ReadUint32Le(uint32(results10 + 40)) + if !ok20 { + var default20 Foo + return default20, errors.New("failed to read pointer from memory") + } + len21, ok21 := i.module.Memory().ReadUint32Le(uint32(results10 + 44)) + if !ok21 { + var default21 Foo + return default21, errors.New("failed to read length from memory") + } + buf22, ok22 := i.module.Memory().Read(ptr20, len21) + if !ok22 { + var default22 Foo + return default22, errors.New("failed to read bytes from memory") + } + str22 := string(buf22) + ptr23, ok23 := i.module.Memory().ReadUint32Le(uint32(results10 + 48)) + if !ok23 { + var default23 Foo + return default23, errors.New("failed to read pointer from memory") + } + len24, ok24 := i.module.Memory().ReadUint32Le(uint32(results10 + 52)) + if !ok24 { + var default24 Foo + return default24, errors.New("failed to read length from memory") + } + base27 := ptr23 + len27 := len24 + result27 := make([]float32, len27) + for idx27 := uint32(0); idx27 < len27; idx27++ { + base := base27 + idx27 * 4 + value25, ok25 := i.module.Memory().ReadUint64Le(uint32(base + 0)) + if !ok25 { + var default25 Foo + return default25, errors.New("failed to read f32 from memory") + } + result26 := api.DecodeF32(value25) + result27[idx27] = result26 + } + ptr28, ok28 := i.module.Memory().ReadUint32Le(uint32(results10 + 56)) + if !ok28 { + var default28 Foo + return default28, errors.New("failed to read pointer from memory") + } + len29, ok29 := i.module.Memory().ReadUint32Le(uint32(results10 + 60)) + if !ok29 { + var default29 Foo + return default29, errors.New("failed to read length from memory") + } + base32 := ptr28 + len32 := len29 + result32 := make([]float64, len32) + for idx32 := uint32(0); idx32 < len32; idx32++ { + base := base32 + idx32 * 8 + value30, ok30 := i.module.Memory().ReadUint64Le(uint32(base + 0)) + if !ok30 { + var default30 Foo + return default30, errors.New("failed to read f64 from memory") + } + result31 := api.DecodeF64(value30) + result32[idx32] = result31 + } + value33 := Foo{ + Float32: result13, + Float64: result15, + Uint32: result17, + Uint64: value19, + S: str22, + Vf32: result27, + Vf64: result32, + } + value37 = value33 + case 1: + ptr34, ok34 := i.module.Memory().ReadUint32Le(uint32(results10 + 8)) + if !ok34 { + var default34 Foo + return default34, errors.New("failed to read pointer from memory") + } + len35, ok35 := i.module.Memory().ReadUint32Le(uint32(results10 + 12)) + if !ok35 { + var default35 Foo + return default35, errors.New("failed to read length from memory") + } + buf36, ok36 := i.module.Memory().Read(ptr34, len35) + if !ok36 { + var default36 Foo + return default36, errors.New("failed to read bytes from memory") + } + str36 := string(buf36) + err37 = errors.New(str36) + default: + err37 = errors.New("invalid variant discriminant for expected") + } + return value37, err37 +} +