From 6e0c50e8850b1343d0c28fab7226365449e0bb72 Mon Sep 17 00:00:00 2001 From: Alysson Ribeiro <15274059+sonalys@users.noreply.github.com> Date: Thu, 14 May 2026 10:16:00 +0200 Subject: [PATCH] feat: Add nil-safety to runtime.Fetch: - Return nil instead of panic on runtime.Fetch - Ensure OpBegin works with type nil for reflect.Value.Len - Improve perf on early return for nil from in builtin.lib.go.get - Add expr.NilSafe() configuration --- builtin/builtin.go | 8 ++++++-- builtin/lib.go | 6 +++--- compiler/compiler.go | 5 +++++ conf/config.go | 12 +++++++++++- expr.go | 6 ++++++ vm/program.go | 4 ++++ vm/runtime/runtime.go | 21 +++++++++++++-------- vm/vm.go | 17 +++++++++++++++-- vm/vm_test.go | 20 ++++++++++++-------- 9 files changed, 75 insertions(+), 24 deletions(-) diff --git a/builtin/builtin.go b/builtin/builtin.go index 87e73614..7963f306 100644 --- a/builtin/builtin.go +++ b/builtin/builtin.go @@ -601,7 +601,9 @@ var Builtins = []*Function{ return } }() - return runtime.Fetch(args[0], 0), nil + + value, _ := runtime.Fetch(args[0], 0) + return value, nil }, Validate: func(args []reflect.Type) (reflect.Type, error) { if len(args) != 1 { @@ -624,7 +626,9 @@ var Builtins = []*Function{ return } }() - return runtime.Fetch(args[0], -1), nil + + value, _ := runtime.Fetch(args[0], -1) + return value, nil }, Validate: func(args []reflect.Type) (reflect.Type, error) { if len(args) != 1 { diff --git a/builtin/lib.go b/builtin/lib.go index 61748da0..7d96ef95 100644 --- a/builtin/lib.go +++ b/builtin/lib.go @@ -548,13 +548,13 @@ func get(params ...any) (out any, err error) { return nil, fmt.Errorf("invalid number of arguments (expected 2, got %d)", len(params)) } from := params[0] - i := params[1] - v := reflect.ValueOf(from) - if from == nil { return nil, nil } + i := params[1] + v := reflect.ValueOf(from) + if v.Kind() == reflect.Invalid { panic(fmt.Sprintf("cannot fetch %v from %T", i, from)) } diff --git a/compiler/compiler.go b/compiler/compiler.go index 68517535..5740dd56 100644 --- a/compiler/compiler.go +++ b/compiler/compiler.go @@ -45,7 +45,11 @@ func Compile(tree *parser.Tree, config *conf.Config) (program *Program, err erro c.compile(tree.Node) + nilSafe := false + if c.config != nil { + nilSafe = c.config.NilSafe + switch c.config.Expect { case reflect.Int: c.emit(OpCast, 0) @@ -77,6 +81,7 @@ func Compile(tree *parser.Tree, config *conf.Config) (program *Program, err erro c.functions, c.debugInfo, span, + nilSafe, ) return } diff --git a/conf/config.go b/conf/config.go index f7c95d20..0113e986 100644 --- a/conf/config.go +++ b/conf/config.go @@ -41,6 +41,9 @@ type Config struct { // When enabled, the lexer treats `if`/`else` as identifiers and the parser // will not parse `if` statements. DisableIfOperator bool + // NilSafe enables nil-safe navigation for all expressions, + // allowing access to fields and methods on nil values without panicking. + NilSafe bool } // CreateNew creates new config with default values. @@ -77,7 +80,14 @@ func (c *Config) ConstExpr(name string) { if c.EnvObject == nil { panic("no environment is specified for ConstExpr()") } - fn := reflect.ValueOf(runtime.Fetch(c.EnvObject, name)) + + field, ok := runtime.Fetch(c.EnvObject, name) + if !ok { + panic(fmt.Errorf("cannot fetch %q in the environment", name)) + } + + fn := reflect.ValueOf(field) + if fn.Kind() != reflect.Func { panic(fmt.Errorf("const expression %q must be a function", name)) } diff --git a/expr.go b/expr.go index 76fbd426..c56ed71a 100644 --- a/expr.go +++ b/expr.go @@ -225,6 +225,12 @@ func MaxNodes(n uint) Option { } } +func NilSafe() Option { + return func(c *conf.Config) { + c.NilSafe = true + } +} + // Compile parses and compiles given input expression to bytecode program. func Compile(input string, ops ...Option) (*vm.Program, error) { config := conf.CreateNew() diff --git a/vm/program.go b/vm/program.go index 7eb96bd3..4e0fc211 100644 --- a/vm/program.go +++ b/vm/program.go @@ -28,6 +28,8 @@ type Program struct { functions []Function debugInfo map[string]string span *Span + + nilSafe bool } // NewProgram returns a new Program. It's used by the compiler. @@ -42,6 +44,7 @@ func NewProgram( functions []Function, debugInfo map[string]string, span *Span, + nilSafe bool, ) *Program { return &Program{ source: source, @@ -54,6 +57,7 @@ func NewProgram( functions: functions, debugInfo: debugInfo, span: span, + nilSafe: nilSafe, } } diff --git a/vm/runtime/runtime.go b/vm/runtime/runtime.go index bc6f2b4d..a409ccb8 100644 --- a/vm/runtime/runtime.go +++ b/vm/runtime/runtime.go @@ -18,7 +18,11 @@ type fieldCacheKey struct { f string } -func Fetch(from, i any) any { +func Fetch(from, i any) (any, bool) { + if from == nil { + return nil, false + } + v := reflect.ValueOf(from) if v.Kind() == reflect.Invalid { panic(fmt.Sprintf("cannot fetch %v from %T", i, from)) @@ -29,7 +33,7 @@ func Fetch(from, i any) any { if methodName, ok := i.(string); ok { method := v.MethodByName(methodName) if method.IsValid() { - return method.Interface() + return method.Interface(), true } } } @@ -52,7 +56,7 @@ func Fetch(from, i any) any { } value := v.Index(index) if value.IsValid() { - return value.Interface() + return value.Interface(), true } case reflect.Map: @@ -63,10 +67,10 @@ func Fetch(from, i any) any { value = v.MapIndex(reflect.ValueOf(i)) } if value.IsValid() { - return value.Interface() + return value.Interface(), true } else { elem := reflect.TypeOf(from).Elem() - return reflect.Zero(elem).Interface() + return reflect.Zero(elem).Interface(), true } case reflect.Struct: @@ -77,7 +81,7 @@ func Fetch(from, i any) any { f: fieldName, } if cv, ok := fieldCache.Load(key); ok { - return v.FieldByIndex(cv.([]int)).Interface() + return v.FieldByIndex(cv.([]int)).Interface(), true } field, ok := t.FieldByNameFunc(func(name string) bool { field, _ := t.FieldByName(name) @@ -94,11 +98,12 @@ func Fetch(from, i any) any { value := v.FieldByIndex(field.Index) if value.IsValid() { fieldCache.Store(key, field.Index) - return value.Interface() + return value.Interface(), true } } } - panic(fmt.Sprintf("cannot fetch %v from %T", i, from)) + + return nil, false } type Field struct { diff --git a/vm/vm.go b/vm/vm.go index ba3b5386..a827e775 100644 --- a/vm/vm.go +++ b/vm/vm.go @@ -122,7 +122,12 @@ func (vm *VM) Run(program *Program, env any) (_ any, err error) { vm.push(vm.Variables[arg]) case OpLoadConst: - vm.push(runtime.Fetch(env, program.Constants[arg])) + value, ok := runtime.Fetch(env, program.Constants[arg]) + if !ok && !program.nilSafe { + panic(fmt.Sprintf("cannot fetch %v in the environment", program.Constants[arg])) + } + + vm.push(value) case OpLoadField: vm.push(runtime.FetchField(env, program.Constants[arg].(*runtime.Field))) @@ -139,7 +144,12 @@ func (vm *VM) Run(program *Program, env any) (_ any, err error) { case OpFetch: b := vm.pop() a := vm.pop() - vm.push(runtime.Fetch(a, b)) + + value, ok := runtime.Fetch(a, b) + if !ok && !program.nilSafe { + panic(fmt.Sprintf("cannot fetch %v from %T", b, a)) + } + vm.push(value) case OpFetchField: a := vm.pop() @@ -609,6 +619,9 @@ func (vm *VM) Run(program *Program, env any) (_ any, err error) { a := vm.pop() s := vm.allocScope() switch v := a.(type) { + case nil: + s.Len = 0 + s.Anys = nil case []int: s.Ints = v s.Len = len(v) diff --git a/vm/vm_test.go b/vm/vm_test.go index c86183ca..dcb2385a 100644 --- a/vm/vm_test.go +++ b/vm/vm_test.go @@ -694,8 +694,9 @@ func TestVM_DirectCallOpcodes(t *testing.T) { tt.bytecode, tt.args, tt.funcs, - nil, // debugInfo - nil, // span + nil, // debugInfo + nil, // span + false, // nilSafe ) vm := &vm.VM{} got, err := vm.Run(program, nil) @@ -819,9 +820,10 @@ func TestVM_IndexAndCountOperations(t *testing.T) { tt.consts, tt.bytecode, tt.args, - nil, // functions - nil, // debugInfo - nil, // span + nil, // functions + nil, // debugInfo + nil, // span + false, // nilSafe ) vm := &vm.VM{} got, err := vm.Run(program, nil) @@ -1288,9 +1290,10 @@ func TestVM_DirectBasicOpcodes(t *testing.T) { tt.consts, tt.bytecode, tt.args, - nil, // functions - nil, // debugInfo - nil, // span + nil, // functions + nil, // debugInfo + nil, // span + false, // nilSafe ) vm := &vm.VM{} got, err := vm.Run(program, tt.env) @@ -1460,6 +1463,7 @@ func TestVM_OpJump_NegativeOffset(t *testing.T) { nil, nil, nil, + false, // nilSafe ) _, err := vm.Run(program, nil)