diff --git a/builtin/builtin.go b/builtin/builtin.go index 87e73614..3c3a9988 100644 --- a/builtin/builtin.go +++ b/builtin/builtin.go @@ -98,6 +98,11 @@ var Builtins = []*Function{ Predicate: true, Types: types(new(func([]any, func(any) any) map[any][]any)), }, + { + Name: "uniqBy", + Predicate: true, + Types: types(new(func([]any, func(any) any) []any)), + }, { Name: "sortBy", Predicate: true, diff --git a/builtin/builtin_test.go b/builtin/builtin_test.go index 0d0dec35..2f0b5c7b 100644 --- a/builtin/builtin_test.go +++ b/builtin/builtin_test.go @@ -184,6 +184,18 @@ func TestBuiltin(t *testing.T) { {`groupBy(1..3, # > 1)[true]`, []any{2, 3}}, {`groupBy(1..3, # > 1 ? nil : "")[nil]`, []any{2, 3}}, {`groupBy(ArrayOfFoo, .Value).a`, []any{mock.Foo{Value: "a"}}}, + {`uniqBy(1..9, # % 3)`, []any{1, 2, 3}}, + {`uniqBy([], #)`, []any{}}, + {`uniqBy([nil, 1, nil, 2], #)`, []any{nil, 1, 2}}, + {`uniqBy([{id: "a", name: "one"}, {id: "a", name: "two"}, {id: "b", name: "three"}], .id)`, []any{ + map[string]any{"id": "a", "name": "one"}, + map[string]any{"id": "b", "name": "three"}, + }}, + {`uniqBy([[1, 2], [1, 2], [1, 3]], #)`, []any{[]any{1, 2}, []any{1, 3}}}, + {`uniqBy([{id: 1, name: "a"}, {id: 1, name: "b"}, {id: 2, name: "c"}], .id)`, []any{ + map[string]any{"id": 1, "name": "a"}, + map[string]any{"id": 2, "name": "c"}, + }}, {`reduce(1..9, # + #acc, 0)`, 45}, {`reduce(1..9, # + #acc)`, 45}, {`reduce([.5, 1.5, 2.5], # + #acc, 0)`, 4.5}, @@ -728,6 +740,7 @@ func TestBuiltin_with_deref(t *testing.T) { {`findLast(arr, # > 0)`, 3}, {`findLastIndex(arr, # > 0)`, 2}, {`groupBy(arr, # % 2 == 0)`, map[any][]any{false: {1, 3}, true: {2}}}, + {`uniqBy(arr, # % 2)`, []any{1, 2}}, {`sortBy(arr, -#)`, []any{3, 2, 1}}, {`reduce(arr, # + #acc, x)`, 6 + 42}, {`ceil(x)`, 42.0}, @@ -874,10 +887,10 @@ func TestAbs_UnsignedIntegers(t *testing.T) { // Test that abs() correctly handles unsigned integers // Unsigned integers are always non-negative, so abs() should return them unchanged 
tests := []struct { - name string - env map[string]any - expr string - want any + name string + env map[string]any + expr string + want any }{ {"uint", map[string]any{"x": uint(42)}, "abs(x)", uint(42)}, {"uint8", map[string]any{"x": uint8(42)}, "abs(x)", uint8(42)}, diff --git a/checker/checker.go b/checker/checker.go index 3620f207..a5fbbd64 100644 --- a/checker/checker.go +++ b/checker/checker.go @@ -886,6 +886,29 @@ func (v *Checker) builtinNode(node *ast.BuiltinNode) Nature { } return v.error(node.Arguments[1], "predicate should has one input and one output param") + case "uniqBy": + collection := v.visit(node.Arguments[0]) + collection = collection.Deref(&v.config.NtCache) + if !collection.IsArray() && !collection.IsUnknown(&v.config.NtCache) { + return v.error(node.Arguments[0], "builtin %v takes only array (got %v)", node.Name, collection.String()) + } + + v.begin(collection) + predicate := v.visit(node.Arguments[1]) + v.end() + + if predicate.IsFunc() && + predicate.NumOut() == 1 && + predicate.NumIn() == 1 && predicate.IsFirstArgUnknown(&v.config.NtCache) { + + if collection.IsUnknown(&v.config.NtCache) { + return v.config.NtCache.FromType(arrayType) + } + collection = collection.Elem(&v.config.NtCache) + return collection.MakeArrayOf(&v.config.NtCache) + } + return v.error(node.Arguments[1], "predicate should has one input and one output param") + case "sortBy": collection := v.visit(node.Arguments[0]) collection = collection.Deref(&v.config.NtCache) diff --git a/compiler/compiler.go b/compiler/compiler.go index 68517535..6190eea8 100644 --- a/compiler/compiler.go +++ b/compiler/compiler.go @@ -1086,7 +1086,7 @@ func (c *compiler) BuiltinNode(node *ast.BuiltinNode) { c.compile(node.Arguments[0]) c.derefInNeeded(node.Arguments[0]) c.emit(OpBegin) - c.emit(OpCreate, 1) + c.emit(OpCreate, CreateGroupBy) c.emit(OpSetAcc) c.emitLoop(func() { c.compile(node.Arguments[1]) @@ -1096,6 +1096,21 @@ func (c *compiler) BuiltinNode(node *ast.BuiltinNode) { 
c.emit(OpEnd) return + case "uniqBy": + c.compile(node.Arguments[0]) + c.derefInNeeded(node.Arguments[0]) + c.emit(OpBegin) + c.emit(OpCreate, CreateUniqBy) + c.emit(OpSetAcc) + c.emitLoop(func() { + c.compile(node.Arguments[1]) + c.emit(OpUniqBy) + }) + c.emit(OpGetAcc) + c.emit(OpUniqByResult) + c.emit(OpEnd) + return + case "sortBy": c.compile(node.Arguments[0]) c.derefInNeeded(node.Arguments[0]) @@ -1105,7 +1120,7 @@ func (c *compiler) BuiltinNode(node *ast.BuiltinNode) { } else { c.emit(OpPush, c.addConstant("asc")) } - c.emit(OpCreate, 2) + c.emit(OpCreate, CreateSortBy) c.emit(OpSetAcc) c.emitLoop(func() { c.compile(node.Arguments[1]) diff --git a/docs/language-definition.md b/docs/language-definition.md index 69efbdfa..09fc5c06 100644 --- a/docs/language-definition.md +++ b/docs/language-definition.md @@ -743,6 +743,16 @@ Removes duplicates from an array. uniq([1, 2, 3, 2, 1]) == [1, 2, 3] ``` +### uniqBy(array, predicate) {#uniqBy} + +Removes duplicates from an array using the result of the [predicate](#predicate) as the uniqueness key. +The first element for each unique key is kept. + +```expr +uniqBy(users, .ID) +uniqBy([1, 2, 3, 4], # % 2) == [1, 2] +``` + ### join(array[, delimiter]) {#join} Joins an array of strings into a single string with the given delimiter. 
diff --git a/expr_test.go b/expr_test.go
index fd1ce3ab..236f605c 100644
--- a/expr_test.go
+++ b/expr_test.go
@@ -712,6 +712,29 @@ func TestExpr_readme_example(t *testing.T) {
 	require.Equal(t, "Hello, world!", output)
 }
 
+// TestExpr_uniqBy checks that uniqBy over a slice of maps keeps only the first
+// element for each distinct value of the key expression.
+func TestExpr_uniqBy(t *testing.T) {
+	users := []map[string]any{
+		{"id": "a", "name": "first"},
+		{"id": "a", "name": "second"},
+		{"id": "b", "name": "third"},
+	}
+	env := map[string]any{"users": users}
+
+	program, err := expr.Compile(`uniqBy(users, .id)`, expr.Env(env))
+	require.NoError(t, err)
+
+	got, err := expr.Run(program, env)
+	require.NoError(t, err)
+
+	want := []any{
+		map[string]any{"id": "a", "name": "first"},
+		map[string]any{"id": "b", "name": "third"},
+	}
+	assert.Equal(t, want, got)
+}
+
 func TestExpr(t *testing.T) {
 	date := time.Date(2017, time.October, 23, 18, 30, 0, 0, time.UTC)
 	oneDay, _ := time.ParseDuration("24h")
diff --git a/parser/parser.go b/parser/parser.go
index 9e24a71e..01fded3e 100644
--- a/parser/parser.go
+++ b/parser/parser.go
@@ -42,6 +42,7 @@ var predicates = map[string]struct {
 	"findLast":      {[]arg{expr, predicate}},
 	"findLastIndex": {[]arg{expr, predicate}},
 	"groupBy":       {[]arg{expr, predicate}},
+	"uniqBy":        {[]arg{expr, predicate}},
 	"sortBy":        {[]arg{expr, predicate, expr | optional}},
 	"reduce":        {[]arg{expr, predicate, expr | optional}},
 }
diff --git a/vm/opcodes.go b/vm/opcodes.go
index 5fca0fa2..016fbedb 100644
--- a/vm/opcodes.go
+++ b/vm/opcodes.go
@@ -79,6 +79,8 @@ const (
 	OpThrow
 	OpCreate
 	OpGroupBy
+	OpUniqBy
+	OpUniqByResult
 	OpSortBy
 	OpSort
 	OpProfileStart
@@ -88,3 +90,9 @@ const (
 	OpOr
 	OpEnd // This opcode must be at the end of this list.
)
+
+const (
+	CreateGroupBy = iota + 1
+	CreateSortBy
+	CreateUniqBy
+)
diff --git a/vm/program.go b/vm/program.go
index 7eb96bd3..19a7a3ba 100644
--- a/vm/program.go
+++ b/vm/program.go
@@ -360,6 +360,12 @@ func (program *Program) DisassembleWriter(w io.Writer) {
 		case OpGroupBy:
 			code("OpGroupBy")
 
+		case OpUniqBy:
+			code("OpUniqBy")
+
+		case OpUniqByResult:
+			code("OpUniqByResult")
+
 		case OpSortBy:
 			code("OpSortBy")
 
diff --git a/vm/utils.go b/vm/utils.go
index 7f1ca1e8..af2a5d6e 100644
--- a/vm/utils.go
+++ b/vm/utils.go
@@ -3,6 +3,8 @@ package vm
 import (
 	"reflect"
 	"time"
+
+	"github.com/expr-lang/expr/vm/runtime"
 )
 
 type (
@@ -46,6 +48,70 @@ func (s *Scope) Item() any {
 
 type groupBy = map[any][]any
 
+// uniqBy is the accumulator behind the uniqBy builtin. Items collects the
+// first element seen for each distinct key, in encounter order. Keys that
+// uniqByHash can normalize are tracked in the Hashable set; every other key
+// is kept in Keys and compared pairwise with runtime.Equal.
+type uniqBy struct {
+	Keys     []any
+	Items    []any
+	Hashable map[any]struct{}
+}
+
+// newUniqBy returns an empty accumulator with the Keys and Items slices and
+// the Hashable set all pre-sized to size.
+func newUniqBy(size int) *uniqBy {
+	return &uniqBy{
+		Keys:     make([]any, 0, size),
+		Items:    make([]any, 0, size),
+		Hashable: make(map[any]struct{}, size),
+	}
+}
+
+// Add records item under key unless an equal key was already added, so only
+// the first item for each key ends up in Items.
+func (u *uniqBy) Add(key, item any) {
+	if hash, ok := uniqByHash(key); ok {
+		if _, exists := u.Hashable[hash]; exists {
+			return
+		}
+		// The set is the sole record of a hashable key: this branch never
+		// consults Keys, so the fallback scan does not consult these keys
+		// either. NOTE(review): assumes runtime.Equal never matches a
+		// hashable key against an unhashable one — confirm.
+		u.Hashable[hash] = struct{}{}
+		u.Items = append(u.Items, item)
+		return
+	}
+
+	// Keys with no hashable form are matched by value against earlier ones.
+	for _, seen := range u.Keys {
+		if runtime.Equal(key, seen) {
+			return
+		}
+	}
+	u.Keys = append(u.Keys, key)
+	u.Items = append(u.Items, item)
+}
+
+// uniqByHash returns the map key under which key should be tracked and true,
+// or false when key has no hashable form. Nil keys map to nil; string, bool
+// and time.Duration keys are used as-is; time.Time keys are converted to UTC
+// so that the location does not take part in map-key equality.
+func uniqByHash(key any) (any, bool) {
+	if runtime.IsNil(key) {
+		return nil, true
+	}
+	switch key := key.(type) {
+	case string, bool, time.Duration:
+		return key, true
+	case time.Time:
+		return key.UTC(), true
+	default:
+		return nil, false
+	}
+}
+
 type Span struct {
 	Name       string `json:"name"`
 	Expression string `json:"expression"`
diff --git a/vm/vm.go b/vm/vm.go
index ba3b5386..0e3912cf 100644
--- a/vm/vm.go
+++ b/vm/vm.go
@@ -549,9 +549,9 @@ func (vm *VM) Run(program *Program, env any) (_ any, err error) {
 
 		case OpCreate:
 			switch arg {
-			case 1:
+			case CreateGroupBy:
 				vm.push(make(groupBy))
-			case 2:
+			case CreateSortBy:
 				scope :=
vm.currScope var desc bool order, ok := vm.pop().(string) @@ -571,6 +571,9 @@ func (vm *VM) Run(program *Program, env any) (_ any, err error) { Array: make([]any, 0, scope.Len), Values: make([]any, 0, scope.Len), }) + case CreateUniqBy: + scope := vm.currScope + vm.push(newUniqBy(scope.Len)) default: panic(fmt.Sprintf("unknown OpCreate argument %v", arg)) } @@ -583,6 +586,14 @@ func (vm *VM) Run(program *Program, env any) (_ any, err error) { } scope.Acc.(groupBy)[key] = append(scope.Acc.(groupBy)[key], scope.Item()) + case OpUniqBy: + scope := vm.currScope + key := vm.pop() + scope.Acc.(*uniqBy).Add(key, scope.Item()) + + case OpUniqByResult: + vm.push(vm.pop().(*uniqBy).Items) + case OpSortBy: scope := vm.currScope value := vm.pop()