From f1b68c877d52fe7e5ca985055e113c9a519bb57c Mon Sep 17 00:00:00 2001 From: bizy Date: Wed, 28 Feb 2024 00:30:33 +0700 Subject: [PATCH 01/15] Compare `any` arrays --- expr_test.go | 26 +++ vm/runtime/helpers/main.go | 46 +++++ vm/runtime/helpers[generated].go | 343 +++++++++++++++++++++++++++++++ 3 files changed, 415 insertions(+) diff --git a/expr_test.go b/expr_test.go index 74975362b..ea9213be2 100644 --- a/expr_test.go +++ b/expr_test.go @@ -2482,3 +2482,29 @@ func TestRaceCondition_variables(t *testing.T) { wg.Wait() } + +func TestArrayComparison(t *testing.T) { + tests := []struct { + env any + code string + }{ + {[]string{"A", "B"}, "foo == ['A', 'B']"}, + {[]int{1, 2}, "foo == [1, 2]"}, + {[]uint8{1, 2}, "foo == [1, 2]"}, + {[]float64{1.1, 2.2}, "foo == [1.1, 2.2]"}, + {[]any{"A", 1, 1.1, true}, "foo == ['A', 1, 1.1, true]"}, + {[]string{"A", "B"}, "foo != [1, 2]"}, + } + + for _, tt := range tests { + t.Run(tt.code, func(t *testing.T) { + env := map[string]any{"foo": tt.env} + program, err := expr.Compile(tt.code, expr.Env(env)) + require.NoError(t, err) + + out, err := expr.Run(program, env) + require.NoError(t, err) + require.Equal(t, true, out) + }) + } +} diff --git a/vm/runtime/helpers/main.go b/vm/runtime/helpers/main.go index b3f598a43..54a4fc235 100644 --- a/vm/runtime/helpers/main.go +++ b/vm/runtime/helpers/main.go @@ -19,6 +19,7 @@ func main() { "cases_with_duration": func(op string) string { return cases(op, uints, ints, floats, []string{"time.Duration"}) }, + "array_equal_cases": func() string { return arrayEqualCases([]string{"string"}, uints, ints, floats) }, }). Parse(helpers), ).Execute(&b, nil) @@ -89,6 +90,45 @@ func cases(op string, xs ...[]string) string { return strings.TrimRight(out, "\n") } +func arrayEqualCases(xs ...[]string) string { + var types []string + for _, x := range xs { + types = append(types, x...) + } + + _, _ = fmt.Fprintf(os.Stderr, "Generating array equal cases for %v\n", types) + + var out string + echo := func(s string, xs ...any) { + out += fmt.Sprintf(s, xs...) + "\n" + } + echo(`case []any:`) + echo(`switch y := b.(type) {`) + for _, a := range append(types, "any") { + echo(`case []%v:`, a) + echo(`if len(x) != len(y) { return false }`) + echo(`for i := range x {`) + echo(`if !Equal(x[i], y[i]) { return false }`) + echo(`}`) + echo("return true") + } + echo(`}`) + for _, a := range types { + echo(`case []%v:`, a) + echo(`switch y := b.(type) {`) + echo(`case []any:`) + echo(`return Equal(y, x)`) + echo(`case []%v:`, a) + echo(`if len(x) != len(y) { return false }`) + echo(`for i := range x {`) + echo(`if x[i] != y[i] { return false }`) + echo(`}`) + echo("return true") + echo(`}`) + } + return strings.TrimRight(out, "\n") +} + func isFloat(t string) bool { return strings.HasPrefix(t, "float") } @@ -110,6 +150,7 @@ import ( func Equal(a, b interface{}) bool { switch x := a.(type) { {{ cases "==" }} + {{ array_equal_cases }} case string: switch y := b.(type) { case string: @@ -125,6 +166,11 @@ func Equal(a, b interface{}) bool { case time.Duration: return x == y } + case bool: + switch y := b.(type) { + case bool: + return x == y + } } if IsNil(a) && IsNil(b) { return true diff --git a/vm/runtime/helpers[generated].go b/vm/runtime/helpers[generated].go index 720feb455..d950f1111 100644 --- a/vm/runtime/helpers[generated].go +++ b/vm/runtime/helpers[generated].go @@ -334,6 +334,344 @@ func Equal(a, b interface{}) bool { case float64: return float64(x) == float64(y) } + case []any: + switch y := b.(type) { + case []string: + if len(x) != len(y) { + return false + } + for i := range x { + if !Equal(x[i], y[i]) { + return false + } + } + return true + case []uint: + if len(x) != len(y) { + return false + } + for i := range x { + if !Equal(x[i], y[i]) { + return false + } + } + return true + case []uint8: + if len(x) != len(y) { + return false + } + for i := range x { + if !Equal(x[i], y[i]) { + return false + } + } + return true + case []uint16: + if len(x) != len(y) { + return false + } + for i := range x { + if !Equal(x[i], y[i]) { + return false + } + } + return true + case []uint32: + if len(x) != len(y) { + return false + } + for i := range x { + if !Equal(x[i], y[i]) { + return false + } + } + return true + case []uint64: + if len(x) != len(y) { + return false + } + for i := range x { + if !Equal(x[i], y[i]) { + return false + } + } + return true + case []int: + if len(x) != len(y) { + return false + } + for i := range x { + if !Equal(x[i], y[i]) { + return false + } + } + return true + case []int8: + if len(x) != len(y) { + return false + } + for i := range x { + if !Equal(x[i], y[i]) { + return false + } + } + return true + case []int16: + if len(x) != len(y) { + return false + } + for i := range x { + if !Equal(x[i], y[i]) { + return false + } + } + return true + case []int32: + if len(x) != len(y) { + return false + } + for i := range x { + if !Equal(x[i], y[i]) { + return false + } + } + return true + case []int64: + if len(x) != len(y) { + return false + } + for i := range x { + if !Equal(x[i], y[i]) { + return false + } + } + return true + case []float32: + if len(x) != len(y) { + return false + } + for i := range x { + if !Equal(x[i], y[i]) { + return false + } + } + return true + case []float64: + if len(x) != len(y) { + return false + } + for i := range x { + if !Equal(x[i], y[i]) { + return false + } + } + return true + case []any: + if len(x) != len(y) { + return false + } + for i := range x { + if !Equal(x[i], y[i]) { + return false + } + } + return true + } + case []string: + switch y := b.(type) { + case []any: + return Equal(y, x) + case []string: + if len(x) != len(y) { + return false + } + for i := range x { + if x[i] != y[i] { + return false + } + } + return true + } + case []uint: + switch y := b.(type) { + case []any: + return Equal(y, x) + case []uint: + if len(x) != len(y) { + return false + } + for i := range x { + if x[i] != y[i] { + return false + } + } + return true + } + case []uint8: + switch y := b.(type) { + case []any: + return Equal(y, x) + case []uint8: + if len(x) != len(y) { + return false + } + for i := range x { + if x[i] != y[i] { + return false + } + } + return true + } + case []uint16: + switch y := b.(type) { + case []any: + return Equal(y, x) + case []uint16: + if len(x) != len(y) { + return false + } + for i := range x { + if x[i] != y[i] { + return false + } + } + return true + } + case []uint32: + switch y := b.(type) { + case []any: + return Equal(y, x) + case []uint32: + if len(x) != len(y) { + return false + } + for i := range x { + if x[i] != y[i] { + return false + } + } + return true + } + case []uint64: + switch y := b.(type) { + case []any: + return Equal(y, x) + case []uint64: + if len(x) != len(y) { + return false + } + for i := range x { + if x[i] != y[i] { + return false + } + } + return true + } + case []int: + switch y := b.(type) { + case []any: + return Equal(y, x) + case []int: + if len(x) != len(y) { + return false + } + for i := range x { + if x[i] != y[i] { + return false + } + } + return true + } + case []int8: + switch y := b.(type) { + case []any: + return Equal(y, x) + case []int8: + if len(x) != len(y) { + return false + } + for i := range x { + if x[i] != y[i] { + return false + } + } + return true + } + case []int16: + switch y := b.(type) { + case []any: + return Equal(y, x) + case []int16: + if len(x) != len(y) { + return false + } + for i := range x { + if x[i] != y[i] { + return false + } + } + return true + } + case []int32: + switch y := b.(type) { + case []any: + return Equal(y, x) + case []int32: + if len(x) != len(y) { + return false + } + for i := range x { + if x[i] != y[i] { + return false + } + } + return true + } + case []int64: + switch y := b.(type) { + case []any: + return Equal(y, x) + case []int64: + if len(x) != len(y) { + return false + } + for i := range x { + if x[i] != y[i] { + return false + } + } + return true + } + case []float32: + switch y := b.(type) { + case []any: + return Equal(y, x) + case []float32: + if len(x) != len(y) { + return false + } + for i := range x { + if x[i] != y[i] { + return false + } + } + return true + } + case []float64: + switch y := b.(type) { + case []any: + return Equal(y, x) + case []float64: + if len(x) != len(y) { + return false + } + for i := range x { + if x[i] != y[i] { + return false + } + } + return true + } case string: switch y := b.(type) { case string: @@ -349,6 +687,11 @@ func Equal(a, b interface{}) bool { case time.Duration: return x == y } + case bool: + switch y := b.(type) { + case bool: + return x == y + } } if IsNil(a) && IsNil(b) { return true From ef57900b163f64429fa82a542fe5662e3b41ef1e Mon Sep 17 00:00:00 2001 From: bizy Date: Tue, 5 Mar 2024 01:11:10 +0700 Subject: [PATCH 02/15] Add bench and tests for `runtime.Equal` --- vm/runtime/helpers_test.go | 57 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 57 insertions(+) create mode 100644 vm/runtime/helpers_test.go diff --git a/vm/runtime/helpers_test.go b/vm/runtime/helpers_test.go new file mode 100644 index 000000000..42a0aece0 --- /dev/null +++ b/vm/runtime/helpers_test.go @@ -0,0 +1,57 @@ +package runtime_test + +import ( + "testing" + + "github.com/expr-lang/expr/vm/runtime" + "github.com/stretchr/testify/assert" +) + +var tests = []struct { + name string + a, b any + want bool +}{ + {"int == int", 42, 42, true}, + {"int != int", 42, 33, false}, + {"int == int8", 42, int8(42), true}, + {"int == int16", 42, int16(42), true}, + {"int == int32", 42, int32(42), true}, + {"int == int64", 42, int64(42), true}, + {"float == float", 42.0, 42.0, true}, + {"float != float", 42.0, 33.0, false}, + {"float == int", 42.0, 42, true}, + {"float != int", 42.0, 33, false}, + {"string == string", "foo", "foo", true}, + {"string != string", "foo", "bar", false}, + {"bool == bool", true, true, true}, + {"bool != bool", true, false, false}, + {"[]any == []int", []any{1, 2, 3}, []int{1, 2, 3}, true}, + {"[]any != []int", []any{1, 2, 3}, []int{1, 2, 99}, false}, + {"deep []any == []any", []any{[]int{1}, 2, []any{"3"}}, []any{[]any{1}, 2, []string{"3"}}, true}, + {"deep []any != []any", []any{[]int{1}, 2, []any{"3", "42"}}, []any{[]any{1}, 2, []string{"3"}}, false}, + {"map[string]any == map[string]any", map[string]any{"a": 1}, map[string]any{"a": 1}, true}, + {"map[string]any != map[string]any", map[string]any{"a": 1}, map[string]any{"a": 1, "b": 2}, false}, +} + +func TestEqual(t *testing.T) { + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := runtime.Equal(tt.a, tt.b) + assert.Equal(t, tt.want, got, "Equal(%v, %v) = %v; want %v", tt.a, tt.b, got, tt.want) + got = runtime.Equal(tt.b, tt.a) + assert.Equal(t, tt.want, got, "Equal(%v, %v) = %v; want %v", tt.b, tt.a, got, tt.want) + }) + } + +} + +func BenchmarkEqual(b *testing.B) { + for _, tt := range tests { + b.Run(tt.name, func(b *testing.B) { + for i := 0; i < b.N; i++ { + runtime.Equal(tt.a, tt.b) + } + }) + } +} From 4023ef3992c7eab21a1bb49b2a2d0428947605d2 Mon Sep 17 00:00:00 2001 From: Sergey Date: Tue, 27 Feb 2024 20:23:49 +0700 Subject: [PATCH 03/15] Support chzained comparisonc`1 < 2 < 3` (#581) --- expr_test.go | 20 +++++++++++++ parser/operator/operator.go | 4 +++ parser/parser.go | 36 ++++++++++++++++++++++ parser/parser_test.go | 60 +++++++++++++++++++++++++++++++++++++ 4 files changed, 120 insertions(+) diff --git a/expr_test.go b/expr_test.go index ea9213be2..a4321c575 100644 --- a/expr_test.go +++ b/expr_test.go @@ -1253,6 +1253,26 @@ func TestExpr(t *testing.T) { `[nil, 3, 4]?.[0]?.[1]`, nil, }, + { + `1 > 2 < 3`, + false, + }, + { + `1 < 2 < 3`, + true, + }, + { + `1 < 2 < 3 > 4`, + false, + }, + { + `1 < 2 < 3 > 2`, + true, + }, + { + `1 < 2 < 3 == true`, + true, + }, } for _, tt := range tests { diff --git a/parser/operator/operator.go b/parser/operator/operator.go index 8d804c7b3..411a0e2bc 100644 --- a/parser/operator/operator.go +++ b/parser/operator/operator.go @@ -54,3 +54,7 @@ var Binary = map[string]Operator{ "^": {100, Right}, "??": {500, Left}, } + +func IsComparison(op string) bool { + return op == "<" || op == ">" || op == ">=" || op == "<=" +} diff --git a/parser/parser.go b/parser/parser.go index 1eabdebe2..9114bc0c9 100644 --- a/parser/parser.go +++ b/parser/parser.go @@ -164,6 +164,11 @@ func (p *parser) parseExpression(precedence int) Node { break } + if operator.IsComparison(opToken.Value) { + nodeLeft = p.parseComparison(nodeLeft, opToken, op.Precedence) + goto next + } + var nodeRight Node if op.Associativity == operator.Left { nodeRight = p.parseExpression(op.Precedence + 1) @@ -685,3 +690,34 @@ func (p *parser) parsePostfixExpression(node Node) Node { } return node } + +func (p *parser) parseComparison(left Node, token Token, precedence int) Node { + var rootNode Node + for { + comparator := p.parseExpression(precedence + 1) + cmpNode := &BinaryNode{ + Operator: token.Value, + Left: left, + Right: comparator, + } + cmpNode.SetLocation(token.Location) + if rootNode == nil { + rootNode = cmpNode + } else { + rootNode = &BinaryNode{ + Operator: "&&", + Left: rootNode, + Right: cmpNode, + } + rootNode.SetLocation(token.Location) + } + + left = comparator + token = p.current + if !(token.Is(Operator) && operator.IsComparison(token.Value) && p.err == nil) { + break + } + p.next() + } + return rootNode +} diff --git a/parser/parser_test.go b/parser/parser_test.go index b633bd52e..9225e1028 100644 --- a/parser/parser_test.go +++ b/parser/parser_test.go @@ -531,6 +531,66 @@ world`}, To: &IntegerNode{Value: 3}, }, }, + { + `1 < 2 > 3`, + &BinaryNode{ + Operator: "&&", + Left: &BinaryNode{ + Operator: "<", + Left: &IntegerNode{Value: 1}, + Right: &IntegerNode{Value: 2}, + }, + Right: &BinaryNode{ + Operator: ">", + Left: &IntegerNode{Value: 2}, + Right: &IntegerNode{Value: 3}, + }, + }, + }, + { + `1 < 2 < 3 < 4`, + &BinaryNode{ + Operator: "&&", + Left: &BinaryNode{ + Operator: "&&", + Left: &BinaryNode{ + Operator: "<", + Left: &IntegerNode{Value: 1}, + Right: &IntegerNode{Value: 2}, + }, + Right: &BinaryNode{ + Operator: "<", + Left: &IntegerNode{Value: 2}, + Right: &IntegerNode{Value: 3}, + }, + }, + Right: &BinaryNode{ + Operator: "<", + Left: &IntegerNode{Value: 3}, + Right: &IntegerNode{Value: 4}, + }, + }, + }, + { + `1 < 2 < 3 == true`, + &BinaryNode{ + Operator: "==", + Left: &BinaryNode{ + Operator: "&&", + Left: &BinaryNode{ + Operator: "<", + Left: &IntegerNode{Value: 1}, + Right: &IntegerNode{Value: 2}, + }, + Right: &BinaryNode{ + Operator: "<", + Left: &IntegerNode{Value: 2}, + Right: &IntegerNode{Value: 3}, + }, + }, + Right: &BoolNode{Value: true}, + }, + }, } for _, test := range tests { t.Run(test.input, func(t *testing.T) { From 6de8091a097270b16496d4b61f0b5a47bd9c016d Mon Sep 17 00:00:00 2001 From: Anton Medvedev Date: Thu, 29 Feb 2024 20:51:34 +0100 Subject: [PATCH 04/15] Add spans --- compiler/compiler.go | 29 +++++++++++++++++++++++++++++ conf/config.go | 1 + vm/opcodes.go | 2 ++ vm/program.go | 9 +++++++++ vm/utils.go | 13 +++++++++++++ vm/vm.go | 9 +++++++++ 6 files changed, 63 insertions(+) diff --git a/compiler/compiler.go b/compiler/compiler.go index a4f189e6b..808b53c9b 100644 --- a/compiler/compiler.go +++ b/compiler/compiler.go @@ -50,6 +50,11 @@ func Compile(tree *parser.Tree, config *conf.Config) (program *Program, err erro } } + var span *Span + if len(c.spans) > 0 { + span = c.spans[0] + } + program = NewProgram( tree.Source, tree.Node, @@ -60,6 +65,7 @@ func Compile(tree *parser.Tree, config *conf.Config) (program *Program, err erro c.arguments, c.functions, c.debugInfo, + span, ) return } @@ -76,6 +82,7 @@ type compiler struct { functionsIndex map[string]int debugInfo map[string]string nodes []ast.Node + spans []*Span chains [][]int arguments []int } @@ -193,6 +200,28 @@ func (c *compiler) compile(node ast.Node) { c.nodes = c.nodes[:len(c.nodes)-1] }() + if c.config != nil && c.config.Profile { + span := &Span{ + Name: reflect.TypeOf(node).String(), + Expression: node.String(), + } + if len(c.spans) > 0 { + prev := c.spans[len(c.spans)-1] + prev.Children = append(prev.Children, span) + } + c.spans = append(c.spans, span) + defer func() { + if len(c.spans) > 1 { + c.spans = c.spans[:len(c.spans)-1] + } + }() + + c.emit(OpProfileStart, c.addConstant(span)) + defer func() { + c.emit(OpProfileEnd, c.addConstant(span)) + }() + } + switch n := node.(type) { case *ast.NilNode: c.NilNode(n) diff --git a/conf/config.go b/conf/config.go index e543732ce..799898109 100644 --- a/conf/config.go +++ b/conf/config.go @@ -20,6 +20,7 @@ type Config struct { ExpectAny bool Optimize bool Strict bool + Profile bool ConstFns map[string]reflect.Value Visitors []ast.Visitor Functions FunctionsTable diff --git a/vm/opcodes.go b/vm/opcodes.go index 0417dab61..84d751d6b 100644 --- a/vm/opcodes.go +++ b/vm/opcodes.go @@ -81,6 +81,8 @@ const ( OpGroupBy OpSortBy OpSort + OpProfileStart + OpProfileEnd OpBegin OpEnd // This opcode must be at the end of this list. ) diff --git a/vm/program.go b/vm/program.go index 4a878267b..989546744 100644 --- a/vm/program.go +++ b/vm/program.go @@ -27,6 +27,7 @@ type Program struct { variables int functions []Function debugInfo map[string]string + span *Span } // NewProgram returns a new Program. It's used by the compiler. @@ -40,6 +41,7 @@ func NewProgram( arguments []int, functions []Function, debugInfo map[string]string, + span *Span, ) *Program { return &Program{ source: source, @@ -51,6 +53,7 @@ func NewProgram( Arguments: arguments, functions: functions, debugInfo: debugInfo, + span: span, } } @@ -360,6 +363,12 @@ func (program *Program) DisassembleWriter(w io.Writer) { case OpSort: code("OpSort") + case OpProfileStart: + code("OpProfileStart") + + case OpProfileEnd: + code("OpProfileEnd") + case OpBegin: code("OpBegin") diff --git a/vm/utils.go b/vm/utils.go index d7db2a52a..fc2f5e7b8 100644 --- a/vm/utils.go +++ b/vm/utils.go @@ -2,6 +2,7 @@ package vm import ( "reflect" + "time" ) type ( @@ -25,3 +26,15 @@ type Scope struct { } type groupBy = map[any][]any + +type Span struct { + Name string `json:"name"` + Expression string `json:"expression"` + Duration int64 `json:"duration"` + Children []*Span `json:"children"` + start time.Time +} + +func GetSpan(program *Program) *Span { + return program.span +} diff --git a/vm/vm.go b/vm/vm.go index 1e85893b0..7e933ce74 100644 --- a/vm/vm.go +++ b/vm/vm.go @@ -8,6 +8,7 @@ import ( "regexp" "sort" "strings" + "time" "github.com/expr-lang/expr/builtin" "github.com/expr-lang/expr/file" @@ -523,6 +524,14 @@ func (vm *VM) Run(program *Program, env any) (_ any, err error) { vm.memGrow(uint(scope.Len)) vm.push(sortable.Array) + case OpProfileStart: + span := program.Constants[arg].(*Span) + span.start = time.Now() + + case OpProfileEnd: + span := program.Constants[arg].(*Span) + span.Duration += time.Since(span.start).Nanoseconds() + case OpBegin: a := vm.pop() array := reflect.ValueOf(a) From dc76d4c79c9e3eff6b70c399d3507d7f9006a4ad Mon Sep 17 00:00:00 2001 From: Ganesan Karuppasamy Date: Sun, 3 Mar 2024 21:14:32 +0530 Subject: [PATCH 05/15] Enable Support for Arrays in Sum, Mean, and Median Functions (#580) --- builtin/builtin.go | 201 ++++++---------------------------------- builtin/builtin_test.go | 13 +++ builtin/lib.go | 154 ++++++++++++++++++++++++------ builtin/validation.go | 38 ++++++++ 4 files changed, 206 insertions(+), 200 deletions(-) create mode 100644 builtin/validation.go diff --git a/builtin/builtin.go b/builtin/builtin.go index fc48e111a..7bf377df2 100644 --- a/builtin/builtin.go +++ b/builtin/builtin.go @@ -135,42 +135,21 @@ var Builtins = []*Function{ Name: "ceil", Fast: Ceil, Validate: func(args []reflect.Type) (reflect.Type, error) { - if len(args) != 1 { - return anyType, fmt.Errorf("invalid number of arguments (expected 1, got %d)", len(args)) - } - switch kind(args[0]) { - case reflect.Float32, reflect.Float64, reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Interface: - return floatType, nil - } - return anyType, fmt.Errorf("invalid argument for ceil (type %s)", args[0]) + return validateRoundFunc("ceil", args) }, }, { Name: "floor", Fast: Floor, Validate: func(args []reflect.Type) (reflect.Type, error) { - if len(args) != 1 { - return anyType, fmt.Errorf("invalid number of arguments (expected 1, got %d)", len(args)) - } - switch kind(args[0]) { - case reflect.Float32, reflect.Float64, reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Interface: - return floatType, nil - } - return anyType, fmt.Errorf("invalid argument for floor (type %s)", args[0]) + return validateRoundFunc("floor", args) }, }, { Name: "round", Fast: Round, Validate: func(args []reflect.Type) (reflect.Type, error) { - if len(args) != 1 { - return anyType, fmt.Errorf("invalid number of arguments (expected 1, got %d)", len(args)) - } - switch kind(args[0]) { - case reflect.Float32, reflect.Float64, reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Interface: - return floatType, nil - } - return anyType, fmt.Errorf("invalid argument for floor (type %s)", args[0]) + return validateRoundFunc("round", args) }, }, { @@ -392,185 +371,63 @@ var Builtins = []*Function{ }, { Name: "max", - Func: Max, + Func: func(args ...any) (any, error) { + return minMax("max", runtime.Less, args...) + }, Validate: func(args []reflect.Type) (reflect.Type, error) { - switch len(args) { - case 0: - return anyType, fmt.Errorf("not enough arguments to call max") - case 1: - if kindName := kind(args[0]); kindName == reflect.Array || kindName == reflect.Slice { - return anyType, nil - } - fallthrough - default: - for _, arg := range args { - switch kind(arg) { - case reflect.Interface, reflect.Array, reflect.Slice: - return anyType, nil - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Float32, reflect.Float64: - default: - return anyType, fmt.Errorf("invalid argument for max (type %s)", arg) - } - } - return args[0], nil - } + return validateAggregateFunc("max", args) }, }, { Name: "min", - Func: Min, + Func: func(args ...any) (any, error) { + return minMax("min", runtime.More, args...) + }, Validate: func(args []reflect.Type) (reflect.Type, error) { - switch len(args) { - case 0: - return anyType, fmt.Errorf("not enough arguments to call min") - case 1: - if kindName := kind(args[0]); kindName == reflect.Array || kindName == reflect.Slice { - return anyType, nil - } - fallthrough - default: - for _, arg := range args { - switch kind(arg) { - case reflect.Interface, reflect.Array, reflect.Slice: - return anyType, nil - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Float32, reflect.Float64: - default: - return anyType, fmt.Errorf("invalid argument for min (type %s)", arg) - } - } - return args[0], nil - - } + return validateAggregateFunc("min", args) }, }, { Name: "sum", - Func: func(args ...any) (any, error) { - if len(args) != 1 { - return nil, fmt.Errorf("invalid number of arguments (expected 1, got %d)", len(args)) - } - v := reflect.ValueOf(args[0]) - if v.Kind() != reflect.Slice && v.Kind() != reflect.Array { - return nil, fmt.Errorf("cannot sum %s", v.Kind()) - } - sum := int64(0) - i := 0 - for ; i < v.Len(); i++ { - it := deref.Value(v.Index(i)) - if it.CanInt() { - sum += it.Int() - } else if it.CanFloat() { - goto float - } else { - return nil, fmt.Errorf("cannot sum %s", it.Kind()) - } - } - return int(sum), nil - float: - fSum := float64(sum) - for ; i < v.Len(); i++ { - it := deref.Value(v.Index(i)) - if it.CanInt() { - fSum += float64(it.Int()) - } else if it.CanFloat() { - fSum += it.Float() - } else { - return nil, fmt.Errorf("cannot sum %s", it.Kind()) - } - } - return fSum, nil - }, + Func: sum, Validate: func(args []reflect.Type) (reflect.Type, error) { - if len(args) != 1 { - return anyType, fmt.Errorf("invalid number of arguments (expected 1, got %d)", len(args)) - } - switch kind(args[0]) { - case reflect.Interface, reflect.Slice, reflect.Array: - default: - return anyType, fmt.Errorf("cannot sum %s", args[0]) - } - return anyType, nil + return validateAggregateFunc("sum", args) }, }, { Name: "mean", Func: func(args ...any) (any, error) { - if len(args) != 1 { - return nil, fmt.Errorf("invalid number of arguments (expected 1, got %d)", len(args)) - } - v := reflect.ValueOf(args[0]) - if v.Kind() != reflect.Slice && v.Kind() != reflect.Array { - return nil, fmt.Errorf("cannot mean %s", v.Kind()) + count, sum, err := mean(args...) + if err != nil { + return nil, err } - if v.Len() == 0 { + if count == 0 { return 0.0, nil } - sum := float64(0) - i := 0 - for ; i < v.Len(); i++ { - it := deref.Value(v.Index(i)) - if it.CanInt() { - sum += float64(it.Int()) - } else if it.CanFloat() { - sum += it.Float() - } else { - return nil, fmt.Errorf("cannot mean %s", it.Kind()) - } - } - return sum / float64(i), nil + return sum / float64(count), nil }, Validate: func(args []reflect.Type) (reflect.Type, error) { - if len(args) != 1 { - return anyType, fmt.Errorf("invalid number of arguments (expected 1, got %d)", len(args)) - } - switch kind(args[0]) { - case reflect.Interface, reflect.Slice, reflect.Array: - default: - return anyType, fmt.Errorf("cannot avg %s", args[0]) - } - return floatType, nil + return validateAggregateFunc("mean", args) }, }, { Name: "median", Func: func(args ...any) (any, error) { - if len(args) != 1 { - return nil, fmt.Errorf("invalid number of arguments (expected 1, got %d)", len(args)) - } - v := reflect.ValueOf(args[0]) - if v.Kind() != reflect.Slice && v.Kind() != reflect.Array { - return nil, fmt.Errorf("cannot median %s", v.Kind()) - } - if v.Len() == 0 { - return 0.0, nil + values, err := median(args...) + if err != nil { + return nil, err } - s := make([]float64, v.Len()) - for i := 0; i < v.Len(); i++ { - it := deref.Value(v.Index(i)) - if it.CanInt() { - s[i] = float64(it.Int()) - } else if it.CanFloat() { - s[i] = it.Float() - } else { - return nil, fmt.Errorf("cannot median %s", it.Kind()) + if n := len(values); n > 0 { + sort.Float64s(values) + if n%2 == 1 { + return values[n/2], nil } + return (values[n/2-1] + values[n/2]) / 2, nil } - sort.Float64s(s) - if len(s)%2 == 0 { - return (s[len(s)/2-1] + s[len(s)/2]) / 2, nil - } - return s[len(s)/2], nil + return 0.0, nil }, Validate: func(args []reflect.Type) (reflect.Type, error) { - if len(args) != 1 { - return anyType, fmt.Errorf("invalid number of arguments (expected 1, got %d)", len(args)) - } - switch kind(args[0]) { - case reflect.Interface, reflect.Slice, reflect.Array: - default: - return anyType, fmt.Errorf("cannot median %s", args[0]) - } - return floatType, nil + return validateAggregateFunc("median", args) }, }, { diff --git a/builtin/builtin_test.go b/builtin/builtin_test.go index bc1a2e149..aa324c9be 100644 --- a/builtin/builtin_test.go +++ b/builtin/builtin_test.go @@ -85,19 +85,29 @@ func TestBuiltin(t *testing.T) { {`min(1.5, 2.5, 3.5)`, 1.5}, {`min([1, 2, 3])`, 1}, {`min([1.5, 2.5, 3.5])`, 1.5}, + {`min(-1, [1.5, 2.5, 3.5])`, -1}, {`sum(1..9)`, 45}, {`sum([.5, 1.5, 2.5])`, 4.5}, {`sum([])`, 0}, {`sum([1, 2, 3.0, 4])`, 10.0}, + {`sum(10, [1, 2, 3], 1..9)`, 61}, + {`sum(-10, [1, 2, 3, 4])`, 0}, + {`sum(-10.9, [1, 2, 3, 4, 9])`, 8.1}, {`mean(1..9)`, 5.0}, {`mean([.5, 1.5, 2.5])`, 1.5}, {`mean([])`, 0.0}, {`mean([1, 2, 3.0, 4])`, 2.5}, + {`mean(10, [1, 2, 3], 1..9)`, 4.6923076923076925}, + {`mean(-10, [1, 2, 3, 4])`, 0.0}, + {`mean(10.9, 1..9)`, 5.59}, {`median(1..9)`, 5.0}, {`median([.5, 1.5, 2.5])`, 1.5}, {`median([])`, 0.0}, {`median([1, 2, 3])`, 2.0}, {`median([1, 2, 3, 4])`, 2.5}, + {`median(10, [1, 2, 3], 1..9)`, 4.0}, + {`median(-10, [1, 2, 3, 4])`, 2.0}, + {`median(1..5, 4.9)`, 3.5}, {`toJSON({foo: 1, bar: 2})`, "{\n \"bar\": 2,\n \"foo\": 1\n}"}, {`fromJSON("[1, 2, 3]")`, []any{1.0, 2.0, 3.0}}, {`toBase64("hello")`, "aGVsbG8="}, @@ -207,6 +217,9 @@ func TestBuiltin_errors(t *testing.T) { {`min()`, `not enough arguments to call min`}, {`min(1, "2")`, `invalid argument for min (type string)`}, {`min([1, "2"])`, `invalid argument for min (type string)`}, + {`median(1..9, "t")`, "invalid argument for median (type string)"}, + {`mean("s", 1..9)`, "invalid argument for mean (type string)"}, + {`sum("s", "h")`, "invalid argument for sum (type string)"}, {`duration("error")`, `invalid duration`}, {`date("error")`, `invalid date`}, {`get()`, `invalid number of arguments (expected 2, got 0)`}, diff --git a/builtin/lib.go b/builtin/lib.go index b08c2ed2b..9ff9478aa 100644 --- a/builtin/lib.go +++ b/builtin/lib.go @@ -6,7 +6,7 @@ import ( "reflect" "strconv" - "github.com/expr-lang/expr/vm/runtime" + "github.com/expr-lang/expr/internal/deref" ) func Len(x any) any { @@ -254,45 +254,143 @@ func String(arg any) any { return fmt.Sprintf("%v", arg) } -func Max(args ...any) (any, error) { - return minMaxFunc("max", runtime.Less, args) -} +func sum(args ...any) (any, error) { + var total int + var fTotal float64 + + for _, arg := range args { + rv := reflect.ValueOf(deref.Deref(arg)) -func Min(args ...any) (any, error) { - return minMaxFunc("min", runtime.More, args) + switch rv.Kind() { + case reflect.Array, reflect.Slice: + size := rv.Len() + for i := 0; i < size; i++ { + elemSum, err := sum(rv.Index(i).Interface()) + if err != nil { + return nil, err + } + switch elemSum := elemSum.(type) { + case int: + total += elemSum + case float64: + fTotal += elemSum + } + } + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + total += int(rv.Int()) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + total += int(rv.Uint()) + case reflect.Float32, reflect.Float64: + fTotal += rv.Float() + default: + return nil, fmt.Errorf("invalid argument for sum (type %T)", arg) + } + } + + if fTotal != 0.0 { + return fTotal + float64(total), nil + } + return total, nil } -func minMaxFunc(name string, fn func(any, any) bool, args []any) (any, error) { +func minMax(name string, fn func(any, any) bool, args ...any) (any, error) { var val any for _, arg := range args { - switch v := arg.(type) { - case []float32, []float64, []uint, []uint8, []uint16, []uint32, []uint64, []int, []int8, []int16, []int32, []int64: - rv := reflect.ValueOf(v) - if rv.Len() == 0 { - return nil, fmt.Errorf("not enough arguments to call %s", name) - } - arg = rv.Index(0).Interface() - for i := 1; i < rv.Len(); i++ { - elem := rv.Index(i).Interface() - if fn(arg, elem) { - arg = elem + rv := reflect.ValueOf(deref.Deref(arg)) + switch rv.Kind() { + case reflect.Array, reflect.Slice: + size := rv.Len() + for i := 0; i < size; i++ { + elemVal, err := minMax(name, fn, rv.Index(i).Interface()) + if err != nil { + return nil, err + } + switch elemVal.(type) { + case int, int8, int16, int32, int64, + uint, uint8, uint16, uint32, uint64, + float32, float64: + if elemVal != nil && (val == nil || fn(val, elemVal)) { + val = elemVal + } + default: + return nil, fmt.Errorf("invalid argument for %s (type %T)", name, elemVal) } + } - case []any: - var err error - if arg, err = minMaxFunc(name, fn, v); err != nil { - return nil, err + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, + reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, + reflect.Float32, reflect.Float64: + elemVal := rv.Interface() + if val == nil || fn(val, elemVal) { + val = elemVal } - case float32, float64, uint, uint8, uint16, uint32, uint64, int, int8, int16, int32, int64: default: if len(args) == 1 { - return arg, nil + return args[0], nil } - return nil, fmt.Errorf("invalid argument for %s (type %T)", name, v) - } - if val == nil || fn(val, arg) { - val = arg + return nil, fmt.Errorf("invalid argument for %s (type %T)", name, arg) } } return val, nil } + +func mean(args ...any) (int, float64, error) { + var total float64 + var count int + + for _, arg := range args { + rv := reflect.ValueOf(deref.Deref(arg)) + switch rv.Kind() { + case reflect.Array, reflect.Slice: + size := rv.Len() + for i := 0; i < size; i++ { + elemCount, elemSum, err := mean(rv.Index(i).Interface()) + if err != nil { + return 0, 0, err + } + total += elemSum + count += elemCount + } + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + total += float64(rv.Int()) + count++ + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + total += float64(rv.Uint()) + count++ + case reflect.Float32, reflect.Float64: + total += rv.Float() + count++ + default: + return 0, 0, fmt.Errorf("invalid argument for mean (type %T)", arg) + } + } + return count, total, nil +} + +func median(args ...any) ([]float64, error) { + var values []float64 + + for _, arg := range args { + rv := reflect.ValueOf(deref.Deref(arg)) + switch rv.Kind() { + case reflect.Array, reflect.Slice: + size := rv.Len() + for i := 0; i < size; i++ { + elems, err := median(rv.Index(i).Interface()) + if err != nil { + return nil, err + } + values = append(values, elems...) + } + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + values = append(values, float64(rv.Int())) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + values = append(values, float64(rv.Uint())) + case reflect.Float32, reflect.Float64: + values = append(values, rv.Float()) + default: + return nil, fmt.Errorf("invalid argument for median (type %T)", arg) + } + } + return values, nil +} diff --git a/builtin/validation.go b/builtin/validation.go new file mode 100644 index 000000000..057f247e9 --- /dev/null +++ b/builtin/validation.go @@ -0,0 +1,38 @@ +package builtin + +import ( + "fmt" + "reflect" + + "github.com/expr-lang/expr/internal/deref" +) + +func validateAggregateFunc(name string, args []reflect.Type) (reflect.Type, error) { + switch len(args) { + case 0: + return anyType, fmt.Errorf("not enough arguments to call %s", name) + default: + for _, arg := range args { + switch kind(deref.Type(arg)) { + case reflect.Interface, reflect.Array, reflect.Slice: + return anyType, nil + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Float32, reflect.Float64: + default: + return anyType, fmt.Errorf("invalid argument for %s (type %s)", name, arg) + } + } + return args[0], nil + } +} + +func validateRoundFunc(name string, args []reflect.Type) (reflect.Type, error) { + if len(args) != 1 { + return anyType, fmt.Errorf("invalid number of arguments (expected 1, got %d)", len(args)) + } + switch kind(args[0]) { + case reflect.Float32, reflect.Float64, reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Interface: + return floatType, nil + default: + return anyType, fmt.Errorf("invalid argument for %s (type %s)", name, args[0]) + } +} From 55be21afd2fd6530e8b7e80b283677bede54916f Mon Sep 17 00:00:00 2001 From: Ganesan Karuppasamy Date: Mon, 4 Mar 2024 19:10:33 +0530 Subject: [PATCH 06/15] Fix `-1 not in []` expressions (#590) --- compiler/compiler_test.go | 33 +++++++++++++ expr_test.go | 8 +++ parser/operator/operator.go | 9 ++++ parser/parser.go | 99 ++++++++++++++++++++----------------- parser/parser_test.go | 61 +++++++++++++++++++++++ 5 files changed, 164 insertions(+), 46 deletions(-) diff --git a/compiler/compiler_test.go b/compiler/compiler_test.go index 741142a77..fbd83ec86 100644 --- a/compiler/compiler_test.go +++ b/compiler/compiler_test.go @@ -541,6 +541,39 @@ func TestCompile_optimizes_jumps(t *testing.T) { {vm.OpFetch, 0}, }, }, + { + `-1 not in [1, 2, 5]`, + []op{ + {vm.OpPush, 0}, + {vm.OpPush, 1}, + {vm.OpIn, 0}, + {vm.OpNot, 0}, + }, + }, + { + `1 + 8 not in [1, 2, 5]`, + []op{ + {vm.OpPush, 0}, + {vm.OpPush, 1}, + {vm.OpIn, 0}, + {vm.OpNot, 0}, + }, + }, + { + `true ? false : 8 not in [1, 2, 5]`, + []op{ + {vm.OpTrue, 0}, + {vm.OpJumpIfFalse, 3}, + {vm.OpPop, 0}, + {vm.OpFalse, 0}, + {vm.OpJump, 5}, + {vm.OpPop, 0}, + {vm.OpPush, 0}, + {vm.OpPush, 1}, + {vm.OpIn, 0}, + {vm.OpNot, 0}, + }, + }, } for _, test := range tests { diff --git a/expr_test.go b/expr_test.go index a4321c575..46cb8fe89 100644 --- a/expr_test.go +++ b/expr_test.go @@ -785,6 +785,10 @@ func TestExpr(t *testing.T) { `Two not in 0..1`, true, }, + { + `-1 not in [1]`, + true, + }, { `Int32 in [10, 20]`, false, @@ -797,6 +801,10 @@ func TestExpr(t *testing.T) { `String matches ("^" + String + "$")`, true, }, + { + `'foo' + 'bar' not matches 'foobar'`, + false, + }, { `"foobar" contains "bar"`, true, diff --git a/parser/operator/operator.go b/parser/operator/operator.go index 411a0e2bc..4eeaf80ed 100644 --- a/parser/operator/operator.go +++ b/parser/operator/operator.go @@ -20,6 +20,15 @@ func IsBoolean(op string) bool { return op == "and" || op == "or" || op == "&&" || op == "||" } +func AllowedNegateSuffix(op string) bool { + switch op { + case "contains", "matches", "startsWith", "endsWith", "in": + return true + default: + return false + } +} + var Unary = map[string]Operator{ "not": {50, Left}, "!": {50, Left}, diff --git a/parser/parser.go b/parser/parser.go index 9114bc0c9..9cb79cbbb 100644 --- a/parser/parser.go +++ b/parser/parser.go @@ -126,10 +126,8 @@ func (p *parser) expect(kind Kind, values ...string) { // parse functions func (p *parser) parseExpression(precedence int) Node { - if precedence == 0 { - if p.current.Is(Operator, "let") { - return p.parseVariableDeclaration() - } + if precedence == 0 && p.current.Is(Operator, "let") { + return p.parseVariableDeclaration() } nodeLeft := p.parsePrimary() @@ -137,62 +135,71 @@ func (p *parser) parseExpression(precedence int) Node { prevOperator := "" opToken := p.current for opToken.Is(Operator) && p.err == nil { - negate := false + negate := opToken.Is(Operator, "not") var notToken Token // Handle "not *" operator, like "not in" or "not contains". - if opToken.Is(Operator, "not") { + if negate { + currentPos := p.pos p.next() - notToken = p.current - negate = true - opToken = p.current + if operator.AllowedNegateSuffix(p.current.Value) { + if op, ok := operator.Binary[p.current.Value]; ok && op.Precedence >= precedence { + notToken = p.current + opToken = p.current + } else { + p.pos = currentPos + p.current = opToken + break + } + } else { + p.error("unexpected token %v", p.current) + break + } } - if op, ok := operator.Binary[opToken.Value]; ok { - if op.Precedence >= precedence { - p.next() + if op, ok := operator.Binary[opToken.Value]; ok && op.Precedence >= precedence { + p.next() - if opToken.Value == "|" { - identToken := p.current - p.expect(Identifier) - nodeLeft = p.parseCall(identToken, []Node{nodeLeft}, true) - goto next - } + if opToken.Value == "|" { + identToken := p.current + p.expect(Identifier) + nodeLeft = p.parseCall(identToken, []Node{nodeLeft}, true) + goto next + } - if prevOperator == "??" && opToken.Value != "??" && !opToken.Is(Bracket, "(") { - p.errorAt(opToken, "Operator (%v) and coalesce expressions (??) cannot be mixed. Wrap either by parentheses.", opToken.Value) - break - } + if prevOperator == "??" && opToken.Value != "??" && !opToken.Is(Bracket, "(") { + p.errorAt(opToken, "Operator (%v) and coalesce expressions (??) cannot be mixed. Wrap either by parentheses.", opToken.Value) + break + } - if operator.IsComparison(opToken.Value) { - nodeLeft = p.parseComparison(nodeLeft, opToken, op.Precedence) - goto next - } + if operator.IsComparison(opToken.Value) { + nodeLeft = p.parseComparison(nodeLeft, opToken, op.Precedence) + goto next + } - var nodeRight Node - if op.Associativity == operator.Left { - nodeRight = p.parseExpression(op.Precedence + 1) - } else { - nodeRight = p.parseExpression(op.Precedence) - } + var nodeRight Node + if op.Associativity == operator.Left { + nodeRight = p.parseExpression(op.Precedence + 1) + } else { + nodeRight = p.parseExpression(op.Precedence) + } - nodeLeft = &BinaryNode{ - Operator: opToken.Value, - Left: nodeLeft, - Right: nodeRight, - } - nodeLeft.SetLocation(opToken.Location) + nodeLeft = &BinaryNode{ + Operator: opToken.Value, + Left: nodeLeft, + Right: nodeRight, + } + nodeLeft.SetLocation(opToken.Location) - if negate { - nodeLeft = &UnaryNode{ - Operator: "not", - Node: nodeLeft, - } - nodeLeft.SetLocation(notToken.Location) + if negate { + nodeLeft = &UnaryNode{ + Operator: "not", + Node: nodeLeft, } - - goto next + nodeLeft.SetLocation(notToken.Location) } + + goto next } break diff --git a/parser/parser_test.go b/parser/parser_test.go index 9225e1028..2a30787a0 100644 --- a/parser/parser_test.go +++ b/parser/parser_test.go @@ -365,6 +365,62 @@ world`}, &UnaryNode{Operator: "not", Node: &IdentifierNode{Value: "in_var"}}, }, + { + "-1 not in [1, 2, 3, 4]", + &UnaryNode{Operator: "not", + Node: &BinaryNode{Operator: "in", + Left: &UnaryNode{Operator: "-", Node: &IntegerNode{Value: 1}}, + Right: &ArrayNode{Nodes: []Node{ + &IntegerNode{Value: 1}, + &IntegerNode{Value: 2}, + &IntegerNode{Value: 3}, + &IntegerNode{Value: 4}, + }}}}, + }, + { + "1*8 not in [1, 2, 3, 4]", + &UnaryNode{Operator: "not", + Node: &BinaryNode{Operator: "in", + Left: &BinaryNode{Operator: "*", + Left: &IntegerNode{Value: 1}, + Right: &IntegerNode{Value: 8}, + }, + Right: &ArrayNode{Nodes: []Node{ + &IntegerNode{Value: 1}, + &IntegerNode{Value: 2}, + &IntegerNode{Value: 3}, + &IntegerNode{Value: 4}, + }}}}, + }, + { + "2==2 ? false : 3 not in [1, 2, 5]", + &ConditionalNode{ + Cond: &BinaryNode{ + Operator: "==", + Left: &IntegerNode{Value: 2}, + Right: &IntegerNode{Value: 2}, + }, + Exp1: &BoolNode{Value: false}, + Exp2: &UnaryNode{ + Operator: "not", + Node: &BinaryNode{ + Operator: "in", + Left: &IntegerNode{Value: 3}, + Right: &ArrayNode{Nodes: []Node{ + &IntegerNode{Value: 1}, + &IntegerNode{Value: 2}, + &IntegerNode{Value: 5}, + }}}}}, + }, + { + "'foo' + 'bar' not matches 'foobar'", + &UnaryNode{Operator: "not", + Node: &BinaryNode{Operator: "matches", + Left: &BinaryNode{Operator: "+", + Left: &StringNode{Value: "foo"}, + Right: &StringNode{Value: "bar"}}, + Right: &StringNode{Value: "foobar"}}}, + }, { "all(Tickets, #)", &BuiltinNode{ @@ -706,6 +762,11 @@ invalid float literal: strconv.ParseFloat: parsing "0o1E+1": invalid syntax (1:6 invalid float literal: strconv.ParseFloat: parsing "1E": invalid syntax (1:2) | 1E | .^ + +1 not == [1, 2, 5] +unexpected token Operator("==") (1:7) + | 1 not == [1, 2, 5] + | ......^ ` func TestParse_error(t *testing.T) { From 56448f81ea4e9ab8edefa4973d0d45012532e20d Mon Sep 17 00:00:00 2001 From: Sergey Date: Sun, 17 Mar 2024 14:51:53 +0700 Subject: [PATCH 07/15] `expr.Operator` passes before `expr.Env` caused error (#606) --- checker/checker_test.go | 2 +- conf/config.go | 6 +++++- expr_test.go | 14 ++++++++++++++ 3 files changed, 20 insertions(+), 2 deletions(-) diff --git a/checker/checker_test.go b/checker/checker_test.go index d6a84abc5..29c50807e 100644 --- a/checker/checker_test.go +++ b/checker/checker_test.go @@ -632,7 +632,7 @@ func TestCheck_TaggedFieldName(t *testing.T) { tree, err := parser.Parse(`foo.bar`) require.NoError(t, err) - config := &conf.Config{} + config := conf.CreateNew() expr.Env(struct { x struct { y bool `expr:"bar"` diff --git a/conf/config.go b/conf/config.go index 799898109..01a407a10 100644 --- a/conf/config.go +++ b/conf/config.go @@ -32,6 +32,7 @@ type Config struct { func CreateNew() *Config { c := &Config{ Optimize: true, + Types: make(TypesTable), ConstFns: make(map[string]reflect.Value), Functions: make(map[string]*builtin.Function), Builtins: make(map[string]*builtin.Function), @@ -62,7 +63,10 @@ func (c *Config) WithEnv(env any) { } c.Env = env - c.Types = CreateTypesTable(env) + types := CreateTypesTable(env) + for name, t := range types { + c.Types[name] = t + } c.MapEnv = mapEnv c.DefaultType = mapValueType c.Strict = true diff --git a/expr_test.go b/expr_test.go index 46cb8fe89..790fdd5d9 100644 --- a/expr_test.go +++ b/expr_test.go @@ -2511,6 +2511,20 @@ func TestRaceCondition_variables(t *testing.T) { wg.Wait() } +func TestOperatorDependsOnEnv(t *testing.T) { + env := map[string]any{ + "plus": func(a, b int) int { + return 42 + }, + } + program, err := expr.Compile(`1 + 2`, expr.Operator("+", "plus"), expr.Env(env)) + require.NoError(t, err) + + out, err := expr.Run(program, env) + require.NoError(t, err) + assert.Equal(t, 42, out) +} + func TestArrayComparison(t *testing.T) { tests := []struct { env any From 745daf583ff241ff1e0bd6c680b92cb55fb984a5 Mon Sep 17 00:00:00 2001 From: Sergey Date: Thu, 21 Mar 2024 02:06:04 +0700 Subject: [PATCH 08/15] builtin `int` unwraps underlying int value (#611) --- builtin/builtin_test.go | 14 ++++++++++++++ builtin/lib.go | 4 ++++ 2 files changed, 18 insertions(+) diff --git a/builtin/builtin_test.go b/builtin/builtin_test.go index aa324c9be..7f5045f41 100644 --- a/builtin/builtin_test.go +++ b/builtin/builtin_test.go @@ -612,3 +612,17 @@ func TestBuiltin_bitOpsFunc(t *testing.T) { }) } } + +type customInt int + +func Test_int_unwraps_underlying_value(t *testing.T) { + env := map[string]any{ + "customInt": customInt(42), + } + program, err := expr.Compile(`int(customInt) == 42`, expr.Env(env)) + require.NoError(t, err) + + out, err := expr.Run(program, env) + require.NoError(t, err) + assert.Equal(t, true, out) +} diff --git a/builtin/lib.go b/builtin/lib.go index 9ff9478aa..e3a6c0aef 100644 --- a/builtin/lib.go +++ b/builtin/lib.go @@ -209,6 +209,10 @@ func Int(x any) any { } return i default: + val := reflect.ValueOf(x) + if val.CanConvert(integerType) { + return val.Convert(integerType).Interface() + } panic(fmt.Sprintf("invalid operation: int(%T)", x)) } } From 594d0c395cc84c70c3c66570a2bbab1aa4bcdc38 Mon Sep 17 00:00:00 2001 From: Anton Medvedev Date: Wed, 27 Mar 2024 09:26:56 +0100 Subject: [PATCH 09/15] Better map ast printing --- ast/print.go | 8 +++++++- ast/print_test.go | 5 +++-- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/ast/print.go b/ast/print.go index fa593ae28..063e9eb27 100644 --- a/ast/print.go +++ b/ast/print.go @@ -202,5 +202,11 @@ func (n *MapNode) String() string { } func (n *PairNode) String() string { - return fmt.Sprintf("%s: %s", n.Key.String(), n.Value.String()) + if str, ok := n.Key.(*StringNode); ok { + if utils.IsValidIdentifier(str.Value) { + return fmt.Sprintf("%s: %s", str.Value, n.Value.String()) + } + return fmt.Sprintf("%q: %s", str.String(), n.Value.String()) + } + return fmt.Sprintf("(%s): %s", n.Key.String(), n.Value.String()) } diff --git a/ast/print_test.go b/ast/print_test.go index 16d64357b..d9e55c2ea 100644 --- a/ast/print_test.go +++ b/ast/print_test.go @@ -55,8 +55,8 @@ func TestPrint(t *testing.T) { {`func(a)`, `func(a)`}, {`func(a, b)`, `func(a, b)`}, {`{}`, `{}`}, - {`{a: b}`, `{"a": b}`}, - {`{a: b, c: d}`, `{"a": b, "c": d}`}, + {`{a: b}`, `{a: b}`}, + {`{a: b, c: d}`, `{a: b, c: d}`}, {`[]`, `[]`}, {`[a]`, `[a]`}, {`[a, b]`, `[a, b]`}, @@ -71,6 +71,7 @@ func TestPrint(t *testing.T) { {`a[1:]`, `a[1:]`}, {`a[:]`, `a[:]`}, {`(nil ?? 1) > 0`, `(nil ?? 1) > 0`}, + {`{("a" + "b"): 42}`, `{("a" + "b"): 42}`}, } for _, tt := range tests { From cdb4565b29b2ad6846c0f1ab2bd318f18e334bb9 Mon Sep 17 00:00:00 2001 From: zhuliquan Date: Thu, 28 Mar 2024 21:37:51 +0800 Subject: [PATCH 10/15] feat: extract code for compiling equal operator (#614) --- compiler/compiler.go | 48 ++++++++++++++++++++++---------------------- 1 file changed, 24 insertions(+), 24 deletions(-) diff --git a/compiler/compiler.go b/compiler/compiler.go index 808b53c9b..a38d977d5 100644 --- a/compiler/compiler.go +++ b/compiler/compiler.go @@ -395,34 +395,12 @@ func (c *compiler) UnaryNode(node *ast.UnaryNode) { } func (c *compiler) BinaryNode(node *ast.BinaryNode) { - l := kind(node.Left) - r := kind(node.Right) - - leftIsSimple := isSimpleType(node.Left) - rightIsSimple := isSimpleType(node.Right) - leftAndRightAreSimple := leftIsSimple && rightIsSimple - switch node.Operator { case "==": - c.compile(node.Left) - c.derefInNeeded(node.Left) - c.compile(node.Right) - c.derefInNeeded(node.Right) - - if l == r && l == reflect.Int && leftAndRightAreSimple { - c.emit(OpEqualInt) - } else if l == r && l == reflect.String && leftAndRightAreSimple { - c.emit(OpEqualString) - } else { - c.emit(OpEqual) - } + c.equalBinaryNode(node) case "!=": - c.compile(node.Left) - c.derefInNeeded(node.Left) - c.compile(node.Right) - c.derefInNeeded(node.Right) - c.emit(OpEqual) + c.equalBinaryNode(node) c.emit(OpNot) case "or", "||": @@ -580,6 +558,28 @@ func (c *compiler) BinaryNode(node *ast.BinaryNode) { } } +func (c *compiler) equalBinaryNode(node *ast.BinaryNode) { + l := kind(node.Left) + r := kind(node.Right) + + leftIsSimple := isSimpleType(node.Left) + rightIsSimple := isSimpleType(node.Right) + leftAndRightAreSimple := leftIsSimple && rightIsSimple + + c.compile(node.Left) + c.derefInNeeded(node.Left) + c.compile(node.Right) + c.derefInNeeded(node.Right) + + if l == r && l == reflect.Int && leftAndRightAreSimple { + c.emit(OpEqualInt) + } else if l == r && l == reflect.String && leftAndRightAreSimple { + c.emit(OpEqualString) + } else { + c.emit(OpEqual) + } +} + func isSimpleType(node ast.Node) bool { if node == nil { return false From 10bf15a3f75d10bababbdf92209973365e4a2553 Mon Sep 17 00:00:00 2001 From: Richard Wooding Date: Mon, 8 Apr 2024 09:00:54 +0200 Subject: [PATCH 11/15] Update README.md (#619) Add SPAN Digital as an user of expr --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index bd34c7d24..1475fe2f5 100644 --- a/README.md +++ b/README.md @@ -162,6 +162,7 @@ func main() { * [Visually.io](https://visually.io) employs Expr as a business rule engine for its personalization targeting algorithm. * [Akvorado](https://github.com/akvorado/akvorado) utilizes Expr to classify exporters and interfaces in network flows. * [keda.sh](https://keda.sh) uses Expr to allow customization of its Kubernetes-based event-driven autoscaling. +* [SPAN Digital](https://spandigital.com/) uses Expr in it's Knowledge Management products [Add your company too](https://github.com/expr-lang/expr/edit/master/README.md) From 3d488a9b8881f637042ce0188b9c719d21a422e5 Mon Sep 17 00:00:00 2001 From: Anton Medvedev Date: Tue, 9 Apr 2024 09:47:41 +0200 Subject: [PATCH 12/15] Update README.md --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 1475fe2f5..1a2d7dc83 100644 --- a/README.md +++ b/README.md @@ -162,7 +162,7 @@ func main() { * [Visually.io](https://visually.io) employs Expr as a business rule engine for its personalization targeting algorithm. * [Akvorado](https://github.com/akvorado/akvorado) utilizes Expr to classify exporters and interfaces in network flows. * [keda.sh](https://keda.sh) uses Expr to allow customization of its Kubernetes-based event-driven autoscaling. -* [SPAN Digital](https://spandigital.com/) uses Expr in it's Knowledge Management products +* [Span Digital](https://spandigital.com/) uses Expr in it's Knowledge Management products. [Add your company too](https://github.com/expr-lang/expr/edit/master/README.md) From 4a115c313858103452295f2a2eccafaac3f1cd36 Mon Sep 17 00:00:00 2001 From: needsure <166317845+needsure@users.noreply.github.com> Date: Tue, 9 Apr 2024 19:13:55 +0800 Subject: [PATCH 13/15] chore: fix some typos in conments (#622) Signed-off-by: needsure --- patcher/value/value.go | 4 ++-- test/operator/operator_test.go | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/patcher/value/value.go b/patcher/value/value.go index 59351be6b..28f52be27 100644 --- a/patcher/value/value.go +++ b/patcher/value/value.go @@ -13,9 +13,9 @@ import ( // ValueGetter is a Patcher that allows custom types to be represented as standard go values for use with expr. // It also adds the `$patcher_value_getter` function to the program for efficiently calling matching interfaces. // -// The purpose of this Patcher is to make it seemless to use custom types in expressions without the need to +// The purpose of this Patcher is to make it seamless to use custom types in expressions without the need to // first convert them to standard go values. It may also facilitate using already existing structs or maps as -// environments when they contain compatabile types. +// environments when they contain compatible types. // // An example usage may be modeling a database record with columns that have varying data types and constraints. // In such an example you may have custom types that, beyond storing a simple value, such as an integer, may diff --git a/test/operator/operator_test.go b/test/operator/operator_test.go index a19c191dc..b49d91cc6 100644 --- a/test/operator/operator_test.go +++ b/test/operator/operator_test.go @@ -77,7 +77,7 @@ func TestOperator_Function(t *testing.T) { } for _, tt := range tests { - t.Run(fmt.Sprintf(`opertor function helper test %s`, tt.input), func(t *testing.T) { + t.Run(fmt.Sprintf(`operator function helper test %s`, tt.input), func(t *testing.T) { program, err := expr.Compile( tt.input, expr.Env(env), From df36ecb3bed0929bd8ebb904c871fa2b98bf821d Mon Sep 17 00:00:00 2001 From: Anton Medvedev Date: Wed, 10 Apr 2024 10:58:17 +0200 Subject: [PATCH 14/15] Revert "Optimize boolean operations between all, any, one, none functions (#555)" (#625) This reverts commit 3c03e5965172519f7bc12100db6607d6a9fae031. --- optimizer/optimizer.go | 1 - optimizer/optimizer_test.go | 122 ----------------------------- optimizer/predicate_combination.go | 51 ------------ 3 files changed, 174 deletions(-) delete mode 100644 optimizer/predicate_combination.go diff --git a/optimizer/optimizer.go b/optimizer/optimizer.go index 6d1fb0b54..a9c0fa3d3 100644 --- a/optimizer/optimizer.go +++ b/optimizer/optimizer.go @@ -36,6 +36,5 @@ func Optimize(node *Node, config *conf.Config) error { Walk(node, &filterLen{}) Walk(node, &filterLast{}) Walk(node, &filterFirst{}) - Walk(node, &predicateCombination{}) return nil } diff --git a/optimizer/optimizer_test.go b/optimizer/optimizer_test.go index 703bd1ceb..e45de763b 100644 --- a/optimizer/optimizer_test.go +++ b/optimizer/optimizer_test.go @@ -1,7 +1,6 @@ package optimizer_test import ( - "fmt" "reflect" "strings" "testing" @@ -340,124 +339,3 @@ func TestOptimize_filter_map_first(t *testing.T) { assert.Equal(t, ast.Dump(expected), ast.Dump(tree.Node)) } - -func TestOptimize_predicate_combination(t *testing.T) { - tests := []struct { - op string - fn string - wantOp string - }{ - {"and", "all", "and"}, - {"&&", "all", "&&"}, - {"or", "all", "or"}, - {"||", "all", "||"}, - {"and", "any", "and"}, - {"&&", "any", "&&"}, - {"or", "any", "or"}, - {"||", "any", "||"}, - {"and", "none", "or"}, - {"&&", "none", "||"}, - {"and", "one", "or"}, - {"&&", "one", "||"}, - } - - for _, tt := range tests { - rule := fmt.Sprintf(`%s(users, .Age > 18 and .Name != "Bob") %s %s(users, .Age < 30)`, tt.fn, tt.op, tt.fn) - t.Run(rule, func(t *testing.T) { - tree, err := parser.Parse(rule) - require.NoError(t, err) - - err = optimizer.Optimize(&tree.Node, nil) - require.NoError(t, err) - - expected := &ast.BuiltinNode{ - Name: tt.fn, - Arguments: []ast.Node{ - &ast.IdentifierNode{Value: "users"}, - &ast.ClosureNode{ - Node: &ast.BinaryNode{ - Operator: tt.wantOp, - Left: &ast.BinaryNode{ - Operator: "and", - Left: &ast.BinaryNode{ - Operator: ">", - Left: &ast.MemberNode{ - Node: &ast.PointerNode{}, - Property: &ast.StringNode{Value: "Age"}, - }, - Right: &ast.IntegerNode{Value: 18}, - }, - Right: &ast.BinaryNode{ - Operator: "!=", - Left: &ast.MemberNode{ - Node: &ast.PointerNode{}, - Property: &ast.StringNode{Value: "Name"}, - }, - Right: &ast.StringNode{Value: "Bob"}, - }, - }, - Right: &ast.BinaryNode{ - Operator: "<", - Left: &ast.MemberNode{ - Node: &ast.PointerNode{}, - Property: &ast.StringNode{Value: "Age"}, - }, - Right: &ast.IntegerNode{Value: 30}, - }, - }, - }, - }, - } - assert.Equal(t, ast.Dump(expected), ast.Dump(tree.Node)) - }) - } -} - -func TestOptimize_predicate_combination_nested(t *testing.T) { - tree, err := parser.Parse(`any(users, {all(.Friends, {.Age == 18 })}) && any(users, {all(.Friends, {.Name != "Bob" })})`) - require.NoError(t, err) - - err = optimizer.Optimize(&tree.Node, nil) - require.NoError(t, err) - - expected := &ast.BuiltinNode{ - Name: "any", - Arguments: []ast.Node{ - &ast.IdentifierNode{Value: "users"}, - &ast.ClosureNode{ - Node: &ast.BuiltinNode{ - Name: "all", - Arguments: []ast.Node{ - &ast.MemberNode{ - Node: &ast.PointerNode{}, - Property: &ast.StringNode{Value: "Friends"}, - }, - &ast.ClosureNode{ - Node: &ast.BinaryNode{ - Operator: "&&", - Left: &ast.BinaryNode{ - Operator: "==", - Left: &ast.MemberNode{ - Node: &ast.PointerNode{}, - Property: &ast.StringNode{Value: "Age"}, - }, - Right: &ast.IntegerNode{Value: 18}, - }, - Right: &ast.BinaryNode{ - Operator: "!=", - Left: &ast.MemberNode{ - Node: &ast.PointerNode{}, - Property: &ast.StringNode{Value: "Name"}, - }, - Right: &ast.StringNode{Value: "Bob"}, - }, - }, - }, - }, - }, - }, - }, - } - - assert.Equal(t, ast.Dump(expected), ast.Dump(tree.Node)) -} diff --git a/optimizer/predicate_combination.go b/optimizer/predicate_combination.go deleted file mode 100644 index 2733781df..000000000 --- a/optimizer/predicate_combination.go +++ /dev/null @@ -1,51 +0,0 @@ -package optimizer - -import ( - . "github.com/expr-lang/expr/ast" - "github.com/expr-lang/expr/parser/operator" -) - -type predicateCombination struct{} - -func (v *predicateCombination) Visit(node *Node) { - if op, ok := (*node).(*BinaryNode); ok && operator.IsBoolean(op.Operator) { - if left, ok := op.Left.(*BuiltinNode); ok { - if combinedOp, ok := combinedOperator(left.Name, op.Operator); ok { - if right, ok := op.Right.(*BuiltinNode); ok && right.Name == left.Name { - if left.Arguments[0].Type() == right.Arguments[0].Type() && left.Arguments[0].String() == right.Arguments[0].String() { - closure := &ClosureNode{ - Node: &BinaryNode{ - Operator: combinedOp, - Left: left.Arguments[1].(*ClosureNode).Node, - Right: right.Arguments[1].(*ClosureNode).Node, - }, - } - v.Visit(&closure.Node) - Patch(node, &BuiltinNode{ - Name: left.Name, - Arguments: []Node{ - left.Arguments[0], - closure, - }, - }) - } - } - } - } - } -} - -func combinedOperator(fn, op string) (string, bool) { - switch fn { - case "all", "any": - return op, true - case "one", "none": - switch op { - case "and": - return "or", true - case "&&": - return "||", true - } - } - return "", false -} From 583bb9d97878c63450d81eba61e69403f1df3f9d Mon Sep 17 00:00:00 2001 From: Sergey Date: Fri, 12 Apr 2024 23:49:41 +0700 Subject: [PATCH 15/15] Optimize boolean operations between all, any, none functions (#626) --- expr_test.go | 189 +++++++++++++++++++++++++++++ optimizer/optimizer.go | 1 + optimizer/optimizer_test.go | 116 ++++++++++++++++++ optimizer/predicate_combination.go | 61 ++++++++++ 4 files changed, 367 insertions(+) create mode 100644 optimizer/predicate_combination.go diff --git a/expr_test.go b/expr_test.go index 790fdd5d9..ac8eecf48 100644 --- a/expr_test.go +++ b/expr_test.go @@ -901,18 +901,147 @@ func TestExpr(t *testing.T) { `all(1..3, {# > 0})`, true, }, + { + `all(1..3, {# > 0}) && all(1..3, {# < 4})`, + true, + }, + { + `all(1..3, {# > 2}) && all(1..3, {# < 4})`, + false, + }, + { + `all(1..3, {# > 0}) && all(1..3, {# < 2})`, + false, + }, + { + `all(1..3, {# > 2}) && all(1..3, {# < 2})`, + false, + }, + { + `all(1..3, {# > 0}) || all(1..3, {# < 4})`, + true, + }, + { + `all(1..3, {# > 0}) || all(1..3, {# != 2})`, + true, + }, + { + `all(1..3, {# != 3}) || all(1..3, {# < 4})`, + true, + }, + { + `all(1..3, {# != 3}) || all(1..3, {# != 2})`, + false, + }, { `none(1..3, {# == 0})`, true, }, + { + `none(1..3, {# == 0}) && none(1..3, {# == 4})`, + true, + }, + { + `none(1..3, {# == 0}) && none(1..3, {# == 3})`, + false, + }, + { + `none(1..3, {# == 1}) && none(1..3, {# == 4})`, + false, + }, + { + `none(1..3, {# == 1}) && none(1..3, {# == 3})`, + false, + }, + { + `none(1..3, {# == 0}) || none(1..3, {# == 4})`, + true, + }, + { + `none(1..3, {# == 0}) || none(1..3, {# == 3})`, + true, + }, + { + `none(1..3, {# == 1}) || none(1..3, {# == 4})`, + true, + }, + { + `none(1..3, {# == 1}) || none(1..3, {# == 3})`, + false, + }, { `any([1,1,0,1], {# == 0})`, true, }, + { + `any(1..3, {# == 1}) && any(1..3, {# == 2})`, + true, + }, + { + `any(1..3, {# == 0}) && any(1..3, {# == 2})`, + false, + }, + { + `any(1..3, {# == 1}) && any(1..3, {# == 4})`, + false, + }, + { + `any(1..3, {# == 0}) && any(1..3, {# == 4})`, + false, + }, + { + `any(1..3, {# == 1}) || any(1..3, {# == 2})`, + true, + }, + { + `any(1..3, {# == 0}) || any(1..3, {# == 2})`, + true, + }, + { + `any(1..3, {# == 1}) || any(1..3, {# == 4})`, + true, + }, + { + `any(1..3, {# == 0}) || any(1..3, {# == 4})`, + false, + }, { `one([1,1,0,1], {# == 0}) and not one([1,0,0,1], {# == 0})`, true, }, + { + `one(1..3, {# == 1}) and one(1..3, {# == 2})`, + true, + }, + { + `one(1..3, {# == 1 || # == 2}) and one(1..3, {# == 2})`, + false, + }, + { + `one(1..3, {# == 1}) and one(1..3, {# == 2 || # == 3})`, + false, + }, + { + `one(1..3, {# == 1 || # == 2}) and one(1..3, {# == 2 || # == 3})`, + false, + }, + { + `one(1..3, {# == 1}) or one(1..3, {# == 2})`, + true, + }, + { + `one(1..3, {# == 1 || # == 2}) or one(1..3, {# == 2})`, + true, + }, + { + `one(1..3, {# == 1}) or one(1..3, {# == 2 || # == 3})`, + true, + }, + { + `one(1..3, {# == 1 || # == 2}) or one(1..3, {# == 2 || # == 3})`, + false, + }, + { `count(1..30, {# % 3 == 0})`, 10, @@ -2525,6 +2654,66 @@ func TestOperatorDependsOnEnv(t *testing.T) { assert.Equal(t, 42, out) } +func TestIssue624(t *testing.T) { + type tag struct { + Name string + } + + type item struct { + Tags []tag + } + + i := item{ + Tags: []tag{ + {Name: "one"}, + {Name: "two"}, + }, + } + + rule := `[ +true && true, +one(Tags, .Name in ["one"]), +one(Tags, .Name in ["two"]), +one(Tags, .Name in ["one"]) && one(Tags, .Name in ["two"]) +]` + resp, err := expr.Eval(rule, i) + require.NoError(t, err) + require.Equal(t, []interface{}{true, true, true, true}, resp) +} + +func TestPredicateCombination(t *testing.T) { + tests := []struct { + code1 string + code2 string + }{ + {"all(1..3, {# > 0}) && all(1..3, {# < 4})", "all(1..3, {# > 0 && # < 4})"}, + {"all(1..3, {# > 1}) && all(1..3, {# < 4})", "all(1..3, {# > 1 && # < 4})"}, + {"all(1..3, {# > 0}) && all(1..3, {# < 2})", "all(1..3, {# > 0 && # < 2})"}, + {"all(1..3, {# > 1}) && all(1..3, {# < 2})", "all(1..3, {# > 1 && # < 2})"}, + + {"any(1..3, {# > 0}) || any(1..3, {# < 4})", "any(1..3, {# > 0 || # < 4})"}, + {"any(1..3, {# > 1}) || any(1..3, {# < 4})", "any(1..3, {# > 1 || # < 4})"}, + {"any(1..3, {# > 0}) || any(1..3, {# < 2})", "any(1..3, {# > 0 || # < 2})"}, + {"any(1..3, {# > 1}) || any(1..3, {# < 2})", "any(1..3, {# > 1 || # < 2})"}, + + {"none(1..3, {# > 0}) && none(1..3, {# < 4})", "none(1..3, {# > 0 || # < 4})"}, + {"none(1..3, {# > 1}) && none(1..3, {# < 4})", "none(1..3, {# > 1 || # < 4})"}, + {"none(1..3, {# > 0}) && none(1..3, {# < 2})", "none(1..3, {# > 0 || # < 2})"}, + {"none(1..3, {# > 1}) && none(1..3, {# < 2})", "none(1..3, {# > 1 || # < 2})"}, + } + for _, tt := range tests { + t.Run(tt.code1, func(t *testing.T) { + out1, err := expr.Eval(tt.code1, nil) + require.NoError(t, err) + + out2, err := expr.Eval(tt.code2, nil) + require.NoError(t, err) + + require.Equal(t, out1, out2) + }) + } +} + func TestArrayComparison(t *testing.T) { tests := []struct { env any diff --git a/optimizer/optimizer.go b/optimizer/optimizer.go index a9c0fa3d3..6d1fb0b54 100644 --- a/optimizer/optimizer.go +++ b/optimizer/optimizer.go @@ -36,5 +36,6 @@ func Optimize(node *Node, config *conf.Config) error { Walk(node, &filterLen{}) Walk(node, &filterLast{}) Walk(node, &filterFirst{}) + Walk(node, &predicateCombination{}) return nil } diff --git a/optimizer/optimizer_test.go b/optimizer/optimizer_test.go index e45de763b..316b17182 100644 --- a/optimizer/optimizer_test.go +++ b/optimizer/optimizer_test.go @@ -1,6 +1,7 @@ package optimizer_test import ( + "fmt" "reflect" "strings" "testing" @@ -339,3 +340,118 @@ func TestOptimize_filter_map_first(t *testing.T) { assert.Equal(t, ast.Dump(expected), ast.Dump(tree.Node)) } + +func TestOptimize_predicate_combination(t *testing.T) { + tests := []struct { + op string + fn string + wantOp string + }{ + {"and", "all", "and"}, + {"&&", "all", "&&"}, + {"or", "any", "or"}, + {"||", "any", "||"}, + {"and", "none", "or"}, + {"&&", "none", "||"}, + } + + for _, tt := range tests { + rule := fmt.Sprintf(`%s(users, .Age > 18 and .Name != "Bob") %s %s(users, .Age < 30)`, tt.fn, tt.op, tt.fn) + t.Run(rule, func(t *testing.T) { + tree, err := parser.Parse(rule) + require.NoError(t, err) + + err = optimizer.Optimize(&tree.Node, nil) + require.NoError(t, err) + + expected := &ast.BuiltinNode{ + Name: tt.fn, + Arguments: []ast.Node{ + &ast.IdentifierNode{Value: "users"}, + &ast.ClosureNode{ + Node: &ast.BinaryNode{ + Operator: tt.wantOp, + Left: &ast.BinaryNode{ + Operator: "and", + Left: &ast.BinaryNode{ + Operator: ">", + Left: &ast.MemberNode{ + Node: &ast.PointerNode{}, + Property: &ast.StringNode{Value: "Age"}, + }, + Right: &ast.IntegerNode{Value: 18}, + }, + Right: &ast.BinaryNode{ + Operator: "!=", + Left: &ast.MemberNode{ + Node: &ast.PointerNode{}, + Property: &ast.StringNode{Value: "Name"}, + }, + Right: &ast.StringNode{Value: "Bob"}, + }, + }, + Right: &ast.BinaryNode{ + Operator: "<", + Left: &ast.MemberNode{ + Node: &ast.PointerNode{}, + Property: &ast.StringNode{Value: "Age"}, + }, + Right: &ast.IntegerNode{Value: 30}, + }, + }, + }, + }, + } + assert.Equal(t, ast.Dump(expected), ast.Dump(tree.Node)) + }) + } +} + +func TestOptimize_predicate_combination_nested(t *testing.T) { + tree, err := parser.Parse(`all(users, {all(.Friends, {.Age == 18 })}) && all(users, {all(.Friends, {.Name != "Bob" })})`) + require.NoError(t, err) + + err = optimizer.Optimize(&tree.Node, nil) + require.NoError(t, err) + + expected := &ast.BuiltinNode{ + Name: "all", + Arguments: []ast.Node{ + &ast.IdentifierNode{Value: "users"}, + &ast.ClosureNode{ + Node: &ast.BuiltinNode{ + Name: "all", + Arguments: []ast.Node{ + &ast.MemberNode{ + Node: &ast.PointerNode{}, + Property: &ast.StringNode{Value: "Friends"}, + }, + &ast.ClosureNode{ + Node: &ast.BinaryNode{ + Operator: "&&", + Left: &ast.BinaryNode{ + Operator: "==", + Left: &ast.MemberNode{ + Node: &ast.PointerNode{}, + Property: &ast.StringNode{Value: "Age"}, + }, + Right: &ast.IntegerNode{Value: 18}, + }, + Right: &ast.BinaryNode{ + Operator: "!=", + Left: &ast.MemberNode{ + Node: &ast.PointerNode{}, + Property: &ast.StringNode{Value: "Name"}, + }, + Right: &ast.StringNode{Value: "Bob"}, + }, + }, + }, + }, + }, + }, + }, + } + + assert.Equal(t, ast.Dump(expected), ast.Dump(tree.Node)) +} diff --git a/optimizer/predicate_combination.go b/optimizer/predicate_combination.go new file mode 100644 index 000000000..6e8a7f7cf --- /dev/null +++ b/optimizer/predicate_combination.go @@ -0,0 +1,61 @@ +package optimizer + +import ( + . "github.com/expr-lang/expr/ast" + "github.com/expr-lang/expr/parser/operator" +) + +/* +predicateCombination is a visitor that combines multiple predicate calls into a single call. +For example, the following expression: + + all(x, x > 1) && all(x, x < 10) -> all(x, x > 1 && x < 10) + any(x, x > 1) || any(x, x < 10) -> any(x, x > 1 || x < 10) + none(x, x > 1) && none(x, x < 10) -> none(x, x > 1 || x < 10) +*/ +type predicateCombination struct{} + +func (v *predicateCombination) Visit(node *Node) { + if op, ok := (*node).(*BinaryNode); ok && operator.IsBoolean(op.Operator) { + if left, ok := op.Left.(*BuiltinNode); ok { + if combinedOp, ok := combinedOperator(left.Name, op.Operator); ok { + if right, ok := op.Right.(*BuiltinNode); ok && right.Name == left.Name { + if left.Arguments[0].Type() == right.Arguments[0].Type() && left.Arguments[0].String() == right.Arguments[0].String() { + closure := &ClosureNode{ + Node: &BinaryNode{ + Operator: combinedOp, + Left: left.Arguments[1].(*ClosureNode).Node, + Right: right.Arguments[1].(*ClosureNode).Node, + }, + } + v.Visit(&closure.Node) + Patch(node, &BuiltinNode{ + Name: left.Name, + Arguments: []Node{ + left.Arguments[0], + closure, + }, + }) + } + } + } + } + } +} + +func combinedOperator(fn, op string) (string, bool) { + switch { + case fn == "all" && (op == "and" || op == "&&"): + return op, true + case fn == "any" && (op == "or" || op == "||"): + return op, true + case fn == "none" && (op == "and" || op == "&&"): + switch op { + case "and": + return "or", true + case "&&": + return "||", true + } + } + return "", false +}