diff --git a/README.md b/README.md index bd34c7d2..1a2d7dc8 100644 --- a/README.md +++ b/README.md @@ -162,6 +162,7 @@ func main() { * [Visually.io](https://visually.io) employs Expr as a business rule engine for its personalization targeting algorithm. * [Akvorado](https://github.com/akvorado/akvorado) utilizes Expr to classify exporters and interfaces in network flows. * [keda.sh](https://keda.sh) uses Expr to allow customization of its Kubernetes-based event-driven autoscaling. +* [Span Digital](https://spandigital.com/) uses Expr in it's Knowledge Management products. [Add your company too](https://github.com/expr-lang/expr/edit/master/README.md) diff --git a/ast/print.go b/ast/print.go index fa593ae2..063e9eb2 100644 --- a/ast/print.go +++ b/ast/print.go @@ -202,5 +202,11 @@ func (n *MapNode) String() string { } func (n *PairNode) String() string { - return fmt.Sprintf("%s: %s", n.Key.String(), n.Value.String()) + if str, ok := n.Key.(*StringNode); ok { + if utils.IsValidIdentifier(str.Value) { + return fmt.Sprintf("%s: %s", str.Value, n.Value.String()) + } + return fmt.Sprintf("%q: %s", str.String(), n.Value.String()) + } + return fmt.Sprintf("(%s): %s", n.Key.String(), n.Value.String()) } diff --git a/ast/print_test.go b/ast/print_test.go index 16d64357..d9e55c2e 100644 --- a/ast/print_test.go +++ b/ast/print_test.go @@ -55,8 +55,8 @@ func TestPrint(t *testing.T) { {`func(a)`, `func(a)`}, {`func(a, b)`, `func(a, b)`}, {`{}`, `{}`}, - {`{a: b}`, `{"a": b}`}, - {`{a: b, c: d}`, `{"a": b, "c": d}`}, + {`{a: b}`, `{a: b}`}, + {`{a: b, c: d}`, `{a: b, c: d}`}, {`[]`, `[]`}, {`[a]`, `[a]`}, {`[a, b]`, `[a, b]`}, @@ -71,6 +71,7 @@ func TestPrint(t *testing.T) { {`a[1:]`, `a[1:]`}, {`a[:]`, `a[:]`}, {`(nil ?? 1) > 0`, `(nil ?? 1) > 0`}, + {`{("a" + "b"): 42}`, `{("a" + "b"): 42}`}, } for _, tt := range tests { diff --git a/builtin/builtin.go b/builtin/builtin.go index fc48e111..7bf377df 100644 --- a/builtin/builtin.go +++ b/builtin/builtin.go @@ -135,42 +135,21 @@ var Builtins = []*Function{ Name: "ceil", Fast: Ceil, Validate: func(args []reflect.Type) (reflect.Type, error) { - if len(args) != 1 { - return anyType, fmt.Errorf("invalid number of arguments (expected 1, got %d)", len(args)) - } - switch kind(args[0]) { - case reflect.Float32, reflect.Float64, reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Interface: - return floatType, nil - } - return anyType, fmt.Errorf("invalid argument for ceil (type %s)", args[0]) + return validateRoundFunc("ceil", args) }, }, { Name: "floor", Fast: Floor, Validate: func(args []reflect.Type) (reflect.Type, error) { - if len(args) != 1 { - return anyType, fmt.Errorf("invalid number of arguments (expected 1, got %d)", len(args)) - } - switch kind(args[0]) { - case reflect.Float32, reflect.Float64, reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Interface: - return floatType, nil - } - return anyType, fmt.Errorf("invalid argument for floor (type %s)", args[0]) + return validateRoundFunc("floor", args) }, }, { Name: "round", Fast: Round, Validate: func(args []reflect.Type) (reflect.Type, error) { - if len(args) != 1 { - return anyType, fmt.Errorf("invalid number of arguments (expected 1, got %d)", len(args)) - } - switch kind(args[0]) { - case reflect.Float32, reflect.Float64, reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Interface: - return floatType, nil - } - return anyType, fmt.Errorf("invalid argument for floor (type %s)", args[0]) + return validateRoundFunc("round", args) }, }, { @@ -392,185 +371,63 @@ var Builtins = []*Function{ }, { Name: "max", - Func: Max, + Func: func(args ...any) (any, error) { + return minMax("max", runtime.Less, args...) + }, Validate: func(args []reflect.Type) (reflect.Type, error) { - switch len(args) { - case 0: - return anyType, fmt.Errorf("not enough arguments to call max") - case 1: - if kindName := kind(args[0]); kindName == reflect.Array || kindName == reflect.Slice { - return anyType, nil - } - fallthrough - default: - for _, arg := range args { - switch kind(arg) { - case reflect.Interface, reflect.Array, reflect.Slice: - return anyType, nil - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Float32, reflect.Float64: - default: - return anyType, fmt.Errorf("invalid argument for max (type %s)", arg) - } - } - return args[0], nil - } + return validateAggregateFunc("max", args) }, }, { Name: "min", - Func: Min, + Func: func(args ...any) (any, error) { + return minMax("min", runtime.More, args...) + }, Validate: func(args []reflect.Type) (reflect.Type, error) { - switch len(args) { - case 0: - return anyType, fmt.Errorf("not enough arguments to call min") - case 1: - if kindName := kind(args[0]); kindName == reflect.Array || kindName == reflect.Slice { - return anyType, nil - } - fallthrough - default: - for _, arg := range args { - switch kind(arg) { - case reflect.Interface, reflect.Array, reflect.Slice: - return anyType, nil - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Float32, reflect.Float64: - default: - return anyType, fmt.Errorf("invalid argument for min (type %s)", arg) - } - } - return args[0], nil - - } + return validateAggregateFunc("min", args) }, }, { Name: "sum", - Func: func(args ...any) (any, error) { - if len(args) != 1 { - return nil, fmt.Errorf("invalid number of arguments (expected 1, got %d)", len(args)) - } - v := reflect.ValueOf(args[0]) - if v.Kind() != reflect.Slice && v.Kind() != reflect.Array { - return nil, fmt.Errorf("cannot sum %s", v.Kind()) - } - sum := int64(0) - i := 0 - for ; i < v.Len(); i++ { - it := deref.Value(v.Index(i)) - if it.CanInt() { - sum += it.Int() - } else if it.CanFloat() { - goto float - } else { - return nil, fmt.Errorf("cannot sum %s", it.Kind()) - } - } - return int(sum), nil - float: - fSum := float64(sum) - for ; i < v.Len(); i++ { - it := deref.Value(v.Index(i)) - if it.CanInt() { - fSum += float64(it.Int()) - } else if it.CanFloat() { - fSum += it.Float() - } else { - return nil, fmt.Errorf("cannot sum %s", it.Kind()) - } - } - return fSum, nil - }, + Func: sum, Validate: func(args []reflect.Type) (reflect.Type, error) { - if len(args) != 1 { - return anyType, fmt.Errorf("invalid number of arguments (expected 1, got %d)", len(args)) - } - switch kind(args[0]) { - case reflect.Interface, reflect.Slice, reflect.Array: - default: - return anyType, fmt.Errorf("cannot sum %s", args[0]) - } - return anyType, nil + return validateAggregateFunc("sum", args) }, }, { Name: "mean", Func: func(args ...any) (any, error) { - if len(args) != 1 { - return nil, fmt.Errorf("invalid number of arguments (expected 1, got %d)", len(args)) - } - v := reflect.ValueOf(args[0]) - if v.Kind() != reflect.Slice && v.Kind() != reflect.Array { - return nil, fmt.Errorf("cannot mean %s", v.Kind()) + count, sum, err := mean(args...) + if err != nil { + return nil, err } - if v.Len() == 0 { + if count == 0 { return 0.0, nil } - sum := float64(0) - i := 0 - for ; i < v.Len(); i++ { - it := deref.Value(v.Index(i)) - if it.CanInt() { - sum += float64(it.Int()) - } else if it.CanFloat() { - sum += it.Float() - } else { - return nil, fmt.Errorf("cannot mean %s", it.Kind()) - } - } - return sum / float64(i), nil + return sum / float64(count), nil }, Validate: func(args []reflect.Type) (reflect.Type, error) { - if len(args) != 1 { - return anyType, fmt.Errorf("invalid number of arguments (expected 1, got %d)", len(args)) - } - switch kind(args[0]) { - case reflect.Interface, reflect.Slice, reflect.Array: - default: - return anyType, fmt.Errorf("cannot avg %s", args[0]) - } - return floatType, nil + return validateAggregateFunc("mean", args) }, }, { Name: "median", Func: func(args ...any) (any, error) { - if len(args) != 1 { - return nil, fmt.Errorf("invalid number of arguments (expected 1, got %d)", len(args)) - } - v := reflect.ValueOf(args[0]) - if v.Kind() != reflect.Slice && v.Kind() != reflect.Array { - return nil, fmt.Errorf("cannot median %s", v.Kind()) - } - if v.Len() == 0 { - return 0.0, nil + values, err := median(args...) + if err != nil { + return nil, err } - s := make([]float64, v.Len()) - for i := 0; i < v.Len(); i++ { - it := deref.Value(v.Index(i)) - if it.CanInt() { - s[i] = float64(it.Int()) - } else if it.CanFloat() { - s[i] = it.Float() - } else { - return nil, fmt.Errorf("cannot median %s", it.Kind()) + if n := len(values); n > 0 { + sort.Float64s(values) + if n%2 == 1 { + return values[n/2], nil } + return (values[n/2-1] + values[n/2]) / 2, nil } - sort.Float64s(s) - if len(s)%2 == 0 { - return (s[len(s)/2-1] + s[len(s)/2]) / 2, nil - } - return s[len(s)/2], nil + return 0.0, nil }, Validate: func(args []reflect.Type) (reflect.Type, error) { - if len(args) != 1 { - return anyType, fmt.Errorf("invalid number of arguments (expected 1, got %d)", len(args)) - } - switch kind(args[0]) { - case reflect.Interface, reflect.Slice, reflect.Array: - default: - return anyType, fmt.Errorf("cannot median %s", args[0]) - } - return floatType, nil + return validateAggregateFunc("median", args) }, }, { diff --git a/builtin/builtin_test.go b/builtin/builtin_test.go index bc1a2e14..7f5045f4 100644 --- a/builtin/builtin_test.go +++ b/builtin/builtin_test.go @@ -85,19 +85,29 @@ func TestBuiltin(t *testing.T) { {`min(1.5, 2.5, 3.5)`, 1.5}, {`min([1, 2, 3])`, 1}, {`min([1.5, 2.5, 3.5])`, 1.5}, + {`min(-1, [1.5, 2.5, 3.5])`, -1}, {`sum(1..9)`, 45}, {`sum([.5, 1.5, 2.5])`, 4.5}, {`sum([])`, 0}, {`sum([1, 2, 3.0, 4])`, 10.0}, + {`sum(10, [1, 2, 3], 1..9)`, 61}, + {`sum(-10, [1, 2, 3, 4])`, 0}, + {`sum(-10.9, [1, 2, 3, 4, 9])`, 8.1}, {`mean(1..9)`, 5.0}, {`mean([.5, 1.5, 2.5])`, 1.5}, {`mean([])`, 0.0}, {`mean([1, 2, 3.0, 4])`, 2.5}, + {`mean(10, [1, 2, 3], 1..9)`, 4.6923076923076925}, + {`mean(-10, [1, 2, 3, 4])`, 0.0}, + {`mean(10.9, 1..9)`, 5.59}, {`median(1..9)`, 5.0}, {`median([.5, 1.5, 2.5])`, 1.5}, {`median([])`, 0.0}, {`median([1, 2, 3])`, 2.0}, {`median([1, 2, 3, 4])`, 2.5}, + {`median(10, [1, 2, 3], 1..9)`, 4.0}, + {`median(-10, [1, 2, 3, 4])`, 2.0}, + {`median(1..5, 4.9)`, 3.5}, {`toJSON({foo: 1, bar: 2})`, "{\n \"bar\": 2,\n \"foo\": 1\n}"}, {`fromJSON("[1, 2, 3]")`, []any{1.0, 2.0, 3.0}}, {`toBase64("hello")`, "aGVsbG8="}, @@ -207,6 +217,9 @@ func TestBuiltin_errors(t *testing.T) { {`min()`, `not enough arguments to call min`}, {`min(1, "2")`, `invalid argument for min (type string)`}, {`min([1, "2"])`, `invalid argument for min (type string)`}, + {`median(1..9, "t")`, "invalid argument for median (type string)"}, + {`mean("s", 1..9)`, "invalid argument for mean (type string)"}, + {`sum("s", "h")`, "invalid argument for sum (type string)"}, {`duration("error")`, `invalid duration`}, {`date("error")`, `invalid date`}, {`get()`, `invalid number of arguments (expected 2, got 0)`}, @@ -599,3 +612,17 @@ func TestBuiltin_bitOpsFunc(t *testing.T) { }) } } + +type customInt int + +func Test_int_unwraps_underlying_value(t *testing.T) { + env := map[string]any{ + "customInt": customInt(42), + } + program, err := expr.Compile(`int(customInt) == 42`, expr.Env(env)) + require.NoError(t, err) + + out, err := expr.Run(program, env) + require.NoError(t, err) + assert.Equal(t, true, out) +} diff --git a/builtin/lib.go b/builtin/lib.go index b08c2ed2..e3a6c0ae 100644 --- a/builtin/lib.go +++ b/builtin/lib.go @@ -6,7 +6,7 @@ import ( "reflect" "strconv" - "github.com/expr-lang/expr/vm/runtime" + "github.com/expr-lang/expr/internal/deref" ) func Len(x any) any { @@ -209,6 +209,10 @@ func Int(x any) any { } return i default: + val := reflect.ValueOf(x) + if val.CanConvert(integerType) { + return val.Convert(integerType).Interface() + } panic(fmt.Sprintf("invalid operation: int(%T)", x)) } } @@ -254,45 +258,143 @@ func String(arg any) any { return fmt.Sprintf("%v", arg) } -func Max(args ...any) (any, error) { - return minMaxFunc("max", runtime.Less, args) -} +func sum(args ...any) (any, error) { + var total int + var fTotal float64 -func Min(args ...any) (any, error) { - return minMaxFunc("min", runtime.More, args) + for _, arg := range args { + rv := reflect.ValueOf(deref.Deref(arg)) + + switch rv.Kind() { + case reflect.Array, reflect.Slice: + size := rv.Len() + for i := 0; i < size; i++ { + elemSum, err := sum(rv.Index(i).Interface()) + if err != nil { + return nil, err + } + switch elemSum := elemSum.(type) { + case int: + total += elemSum + case float64: + fTotal += elemSum + } + } + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + total += int(rv.Int()) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + total += int(rv.Uint()) + case reflect.Float32, reflect.Float64: + fTotal += rv.Float() + default: + return nil, fmt.Errorf("invalid argument for sum (type %T)", arg) + } + } + + if fTotal != 0.0 { + return fTotal + float64(total), nil + } + return total, nil } -func minMaxFunc(name string, fn func(any, any) bool, args []any) (any, error) { +func minMax(name string, fn func(any, any) bool, args ...any) (any, error) { var val any for _, arg := range args { - switch v := arg.(type) { - case []float32, []float64, []uint, []uint8, []uint16, []uint32, []uint64, []int, []int8, []int16, []int32, []int64: - rv := reflect.ValueOf(v) - if rv.Len() == 0 { - return nil, fmt.Errorf("not enough arguments to call %s", name) - } - arg = rv.Index(0).Interface() - for i := 1; i < rv.Len(); i++ { - elem := rv.Index(i).Interface() - if fn(arg, elem) { - arg = elem + rv := reflect.ValueOf(deref.Deref(arg)) + switch rv.Kind() { + case reflect.Array, reflect.Slice: + size := rv.Len() + for i := 0; i < size; i++ { + elemVal, err := minMax(name, fn, rv.Index(i).Interface()) + if err != nil { + return nil, err } + switch elemVal.(type) { + case int, int8, int16, int32, int64, + uint, uint8, uint16, uint32, uint64, + float32, float64: + if elemVal != nil && (val == nil || fn(val, elemVal)) { + val = elemVal + } + default: + return nil, fmt.Errorf("invalid argument for %s (type %T)", name, elemVal) + } + } - case []any: - var err error - if arg, err = minMaxFunc(name, fn, v); err != nil { - return nil, err + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, + reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, + reflect.Float32, reflect.Float64: + elemVal := rv.Interface() + if val == nil || fn(val, elemVal) { + val = elemVal } - case float32, float64, uint, uint8, uint16, uint32, uint64, int, int8, int16, int32, int64: default: if len(args) == 1 { - return arg, nil + return args[0], nil } - return nil, fmt.Errorf("invalid argument for %s (type %T)", name, v) - } - if val == nil || fn(val, arg) { - val = arg + return nil, fmt.Errorf("invalid argument for %s (type %T)", name, arg) } } return val, nil } + +func mean(args ...any) (int, float64, error) { + var total float64 + var count int + + for _, arg := range args { + rv := reflect.ValueOf(deref.Deref(arg)) + switch rv.Kind() { + case reflect.Array, reflect.Slice: + size := rv.Len() + for i := 0; i < size; i++ { + elemCount, elemSum, err := mean(rv.Index(i).Interface()) + if err != nil { + return 0, 0, err + } + total += elemSum + count += elemCount + } + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + total += float64(rv.Int()) + count++ + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + total += float64(rv.Uint()) + count++ + case reflect.Float32, reflect.Float64: + total += rv.Float() + count++ + default: + return 0, 0, fmt.Errorf("invalid argument for mean (type %T)", arg) + } + } + return count, total, nil +} + +func median(args ...any) ([]float64, error) { + var values []float64 + + for _, arg := range args { + rv := reflect.ValueOf(deref.Deref(arg)) + switch rv.Kind() { + case reflect.Array, reflect.Slice: + size := rv.Len() + for i := 0; i < size; i++ { + elems, err := median(rv.Index(i).Interface()) + if err != nil { + return nil, err + } + values = append(values, elems...) + } + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + values = append(values, float64(rv.Int())) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + values = append(values, float64(rv.Uint())) + case reflect.Float32, reflect.Float64: + values = append(values, rv.Float()) + default: + return nil, fmt.Errorf("invalid argument for median (type %T)", arg) + } + } + return values, nil +} diff --git a/builtin/validation.go b/builtin/validation.go new file mode 100644 index 00000000..057f247e --- /dev/null +++ b/builtin/validation.go @@ -0,0 +1,38 @@ +package builtin + +import ( + "fmt" + "reflect" + + "github.com/expr-lang/expr/internal/deref" +) + +func validateAggregateFunc(name string, args []reflect.Type) (reflect.Type, error) { + switch len(args) { + case 0: + return anyType, fmt.Errorf("not enough arguments to call %s", name) + default: + for _, arg := range args { + switch kind(deref.Type(arg)) { + case reflect.Interface, reflect.Array, reflect.Slice: + return anyType, nil + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Float32, reflect.Float64: + default: + return anyType, fmt.Errorf("invalid argument for %s (type %s)", name, arg) + } + } + return args[0], nil + } +} + +func validateRoundFunc(name string, args []reflect.Type) (reflect.Type, error) { + if len(args) != 1 { + return anyType, fmt.Errorf("invalid number of arguments (expected 1, got %d)", len(args)) + } + switch kind(args[0]) { + case reflect.Float32, reflect.Float64, reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Interface: + return floatType, nil + default: + return anyType, fmt.Errorf("invalid argument for %s (type %s)", name, args[0]) + } +} diff --git a/checker/checker_test.go b/checker/checker_test.go index d6a84abc..29c50807 100644 --- a/checker/checker_test.go +++ b/checker/checker_test.go @@ -632,7 +632,7 @@ func TestCheck_TaggedFieldName(t *testing.T) { tree, err := parser.Parse(`foo.bar`) require.NoError(t, err) - config := &conf.Config{} + config := conf.CreateNew() expr.Env(struct { x struct { y bool `expr:"bar"` diff --git a/compiler/compiler.go b/compiler/compiler.go index a4f189e6..a38d977d 100644 --- a/compiler/compiler.go +++ b/compiler/compiler.go @@ -50,6 +50,11 @@ func Compile(tree *parser.Tree, config *conf.Config) (program *Program, err erro } } + var span *Span + if len(c.spans) > 0 { + span = c.spans[0] + } + program = NewProgram( tree.Source, tree.Node, @@ -60,6 +65,7 @@ func Compile(tree *parser.Tree, config *conf.Config) (program *Program, err erro c.arguments, c.functions, c.debugInfo, + span, ) return } @@ -76,6 +82,7 @@ type compiler struct { functionsIndex map[string]int debugInfo map[string]string nodes []ast.Node + spans []*Span chains [][]int arguments []int } @@ -193,6 +200,28 @@ func (c *compiler) compile(node ast.Node) { c.nodes = c.nodes[:len(c.nodes)-1] }() + if c.config != nil && c.config.Profile { + span := &Span{ + Name: reflect.TypeOf(node).String(), + Expression: node.String(), + } + if len(c.spans) > 0 { + prev := c.spans[len(c.spans)-1] + prev.Children = append(prev.Children, span) + } + c.spans = append(c.spans, span) + defer func() { + if len(c.spans) > 1 { + c.spans = c.spans[:len(c.spans)-1] + } + }() + + c.emit(OpProfileStart, c.addConstant(span)) + defer func() { + c.emit(OpProfileEnd, c.addConstant(span)) + }() + } + switch n := node.(type) { case *ast.NilNode: c.NilNode(n) @@ -366,34 +395,12 @@ func (c *compiler) UnaryNode(node *ast.UnaryNode) { } func (c *compiler) BinaryNode(node *ast.BinaryNode) { - l := kind(node.Left) - r := kind(node.Right) - - leftIsSimple := isSimpleType(node.Left) - rightIsSimple := isSimpleType(node.Right) - leftAndRightAreSimple := leftIsSimple && rightIsSimple - switch node.Operator { case "==": - c.compile(node.Left) - c.derefInNeeded(node.Left) - c.compile(node.Right) - c.derefInNeeded(node.Right) - - if l == r && l == reflect.Int && leftAndRightAreSimple { - c.emit(OpEqualInt) - } else if l == r && l == reflect.String && leftAndRightAreSimple { - c.emit(OpEqualString) - } else { - c.emit(OpEqual) - } + c.equalBinaryNode(node) case "!=": - c.compile(node.Left) - c.derefInNeeded(node.Left) - c.compile(node.Right) - c.derefInNeeded(node.Right) - c.emit(OpEqual) + c.equalBinaryNode(node) c.emit(OpNot) case "or", "||": @@ -551,6 +558,28 @@ func (c *compiler) BinaryNode(node *ast.BinaryNode) { } } +func (c *compiler) equalBinaryNode(node *ast.BinaryNode) { + l := kind(node.Left) + r := kind(node.Right) + + leftIsSimple := isSimpleType(node.Left) + rightIsSimple := isSimpleType(node.Right) + leftAndRightAreSimple := leftIsSimple && rightIsSimple + + c.compile(node.Left) + c.derefInNeeded(node.Left) + c.compile(node.Right) + c.derefInNeeded(node.Right) + + if l == r && l == reflect.Int && leftAndRightAreSimple { + c.emit(OpEqualInt) + } else if l == r && l == reflect.String && leftAndRightAreSimple { + c.emit(OpEqualString) + } else { + c.emit(OpEqual) + } +} + func isSimpleType(node ast.Node) bool { if node == nil { return false diff --git a/compiler/compiler_test.go b/compiler/compiler_test.go index 741142a7..fbd83ec8 100644 --- a/compiler/compiler_test.go +++ b/compiler/compiler_test.go @@ -541,6 +541,39 @@ func TestCompile_optimizes_jumps(t *testing.T) { {vm.OpFetch, 0}, }, }, + { + `-1 not in [1, 2, 5]`, + []op{ + {vm.OpPush, 0}, + {vm.OpPush, 1}, + {vm.OpIn, 0}, + {vm.OpNot, 0}, + }, + }, + { + `1 + 8 not in [1, 2, 5]`, + []op{ + {vm.OpPush, 0}, + {vm.OpPush, 1}, + {vm.OpIn, 0}, + {vm.OpNot, 0}, + }, + }, + { + `true ? false : 8 not in [1, 2, 5]`, + []op{ + {vm.OpTrue, 0}, + {vm.OpJumpIfFalse, 3}, + {vm.OpPop, 0}, + {vm.OpFalse, 0}, + {vm.OpJump, 5}, + {vm.OpPop, 0}, + {vm.OpPush, 0}, + {vm.OpPush, 1}, + {vm.OpIn, 0}, + {vm.OpNot, 0}, + }, + }, } for _, test := range tests { diff --git a/conf/config.go b/conf/config.go index e543732c..01a407a1 100644 --- a/conf/config.go +++ b/conf/config.go @@ -20,6 +20,7 @@ type Config struct { ExpectAny bool Optimize bool Strict bool + Profile bool ConstFns map[string]reflect.Value Visitors []ast.Visitor Functions FunctionsTable @@ -31,6 +32,7 @@ type Config struct { func CreateNew() *Config { c := &Config{ Optimize: true, + Types: make(TypesTable), ConstFns: make(map[string]reflect.Value), Functions: make(map[string]*builtin.Function), Builtins: make(map[string]*builtin.Function), @@ -61,7 +63,10 @@ func (c *Config) WithEnv(env any) { } c.Env = env - c.Types = CreateTypesTable(env) + types := CreateTypesTable(env) + for name, t := range types { + c.Types[name] = t + } c.MapEnv = mapEnv c.DefaultType = mapValueType c.Strict = true diff --git a/expr_test.go b/expr_test.go index 74975362..ac8eecf4 100644 --- a/expr_test.go +++ b/expr_test.go @@ -785,6 +785,10 @@ func TestExpr(t *testing.T) { `Two not in 0..1`, true, }, + { + `-1 not in [1]`, + true, + }, { `Int32 in [10, 20]`, false, @@ -797,6 +801,10 @@ func TestExpr(t *testing.T) { `String matches ("^" + String + "$")`, true, }, + { + `'foo' + 'bar' not matches 'foobar'`, + false, + }, { `"foobar" contains "bar"`, true, @@ -893,18 +901,147 @@ func TestExpr(t *testing.T) { `all(1..3, {# > 0})`, true, }, + { + `all(1..3, {# > 0}) && all(1..3, {# < 4})`, + true, + }, + { + `all(1..3, {# > 2}) && all(1..3, {# < 4})`, + false, + }, + { + `all(1..3, {# > 0}) && all(1..3, {# < 2})`, + false, + }, + { + `all(1..3, {# > 2}) && all(1..3, {# < 2})`, + false, + }, + { + `all(1..3, {# > 0}) || all(1..3, {# < 4})`, + true, + }, + { + `all(1..3, {# > 0}) || all(1..3, {# != 2})`, + true, + }, + { + `all(1..3, {# != 3}) || all(1..3, {# < 4})`, + true, + }, + { + `all(1..3, {# != 3}) || all(1..3, {# != 2})`, + false, + }, { `none(1..3, {# == 0})`, true, }, + { + `none(1..3, {# == 0}) && none(1..3, {# == 4})`, + true, + }, + { + `none(1..3, {# == 0}) && none(1..3, {# == 3})`, + false, + }, + { + `none(1..3, {# == 1}) && none(1..3, {# == 4})`, + false, + }, + { + `none(1..3, {# == 1}) && none(1..3, {# == 3})`, + false, + }, + { + `none(1..3, {# == 0}) || none(1..3, {# == 4})`, + true, + }, + { + `none(1..3, {# == 0}) || none(1..3, {# == 3})`, + true, + }, + { + `none(1..3, {# == 1}) || none(1..3, {# == 4})`, + true, + }, + { + `none(1..3, {# == 1}) || none(1..3, {# == 3})`, + false, + }, { `any([1,1,0,1], {# == 0})`, true, }, + { + `any(1..3, {# == 1}) && any(1..3, {# == 2})`, + true, + }, + { + `any(1..3, {# == 0}) && any(1..3, {# == 2})`, + false, + }, + { + `any(1..3, {# == 1}) && any(1..3, {# == 4})`, + false, + }, + { + `any(1..3, {# == 0}) && any(1..3, {# == 4})`, + false, + }, + { + `any(1..3, {# == 1}) || any(1..3, {# == 2})`, + true, + }, + { + `any(1..3, {# == 0}) || any(1..3, {# == 2})`, + true, + }, + { + `any(1..3, {# == 1}) || any(1..3, {# == 4})`, + true, + }, + { + `any(1..3, {# == 0}) || any(1..3, {# == 4})`, + false, + }, { `one([1,1,0,1], {# == 0}) and not one([1,0,0,1], {# == 0})`, true, }, + { + `one(1..3, {# == 1}) and one(1..3, {# == 2})`, + true, + }, + { + `one(1..3, {# == 1 || # == 2}) and one(1..3, {# == 2})`, + false, + }, + { + `one(1..3, {# == 1}) and one(1..3, {# == 2 || # == 3})`, + false, + }, + { + `one(1..3, {# == 1 || # == 2}) and one(1..3, {# == 2 || # == 3})`, + false, + }, + { + `one(1..3, {# == 1}) or one(1..3, {# == 2})`, + true, + }, + { + `one(1..3, {# == 1 || # == 2}) or one(1..3, {# == 2})`, + true, + }, + { + `one(1..3, {# == 1}) or one(1..3, {# == 2 || # == 3})`, + true, + }, + { + `one(1..3, {# == 1 || # == 2}) or one(1..3, {# == 2 || # == 3})`, + false, + }, + { `count(1..30, {# % 3 == 0})`, 10, @@ -1253,6 +1390,26 @@ func TestExpr(t *testing.T) { `[nil, 3, 4]?.[0]?.[1]`, nil, }, + { + `1 > 2 < 3`, + false, + }, + { + `1 < 2 < 3`, + true, + }, + { + `1 < 2 < 3 > 4`, + false, + }, + { + `1 < 2 < 3 > 2`, + true, + }, + { + `1 < 2 < 3 == true`, + true, + }, } for _, tt := range tests { @@ -2482,3 +2639,103 @@ func TestRaceCondition_variables(t *testing.T) { wg.Wait() } + +func TestOperatorDependsOnEnv(t *testing.T) { + env := map[string]any{ + "plus": func(a, b int) int { + return 42 + }, + } + program, err := expr.Compile(`1 + 2`, expr.Operator("+", "plus"), expr.Env(env)) + require.NoError(t, err) + + out, err := expr.Run(program, env) + require.NoError(t, err) + assert.Equal(t, 42, out) +} + +func TestIssue624(t *testing.T) { + type tag struct { + Name string + } + + type item struct { + Tags []tag + } + + i := item{ + Tags: []tag{ + {Name: "one"}, + {Name: "two"}, + }, + } + + rule := `[ +true && true, +one(Tags, .Name in ["one"]), +one(Tags, .Name in ["two"]), +one(Tags, .Name in ["one"]) && one(Tags, .Name in ["two"]) +]` + resp, err := expr.Eval(rule, i) + require.NoError(t, err) + require.Equal(t, []interface{}{true, true, true, true}, resp) +} + +func TestPredicateCombination(t *testing.T) { + tests := []struct { + code1 string + code2 string + }{ + {"all(1..3, {# > 0}) && all(1..3, {# < 4})", "all(1..3, {# > 0 && # < 4})"}, + {"all(1..3, {# > 1}) && all(1..3, {# < 4})", "all(1..3, {# > 1 && # < 4})"}, + {"all(1..3, {# > 0}) && all(1..3, {# < 2})", "all(1..3, {# > 0 && # < 2})"}, + {"all(1..3, {# > 1}) && all(1..3, {# < 2})", "all(1..3, {# > 1 && # < 2})"}, + + {"any(1..3, {# > 0}) || any(1..3, {# < 4})", "any(1..3, {# > 0 || # < 4})"}, + {"any(1..3, {# > 1}) || any(1..3, {# < 4})", "any(1..3, {# > 1 || # < 4})"}, + {"any(1..3, {# > 0}) || any(1..3, {# < 2})", "any(1..3, {# > 0 || # < 2})"}, + {"any(1..3, {# > 1}) || any(1..3, {# < 2})", "any(1..3, {# > 1 || # < 2})"}, + + {"none(1..3, {# > 0}) && none(1..3, {# < 4})", "none(1..3, {# > 0 || # < 4})"}, + {"none(1..3, {# > 1}) && none(1..3, {# < 4})", "none(1..3, {# > 1 || # < 4})"}, + {"none(1..3, {# > 0}) && none(1..3, {# < 2})", "none(1..3, {# > 0 || # < 2})"}, + {"none(1..3, {# > 1}) && none(1..3, {# < 2})", "none(1..3, {# > 1 || # < 2})"}, + } + for _, tt := range tests { + t.Run(tt.code1, func(t *testing.T) { + out1, err := expr.Eval(tt.code1, nil) + require.NoError(t, err) + + out2, err := expr.Eval(tt.code2, nil) + require.NoError(t, err) + + require.Equal(t, out1, out2) + }) + } +} + +func TestArrayComparison(t *testing.T) { + tests := []struct { + env any + code string + }{ + {[]string{"A", "B"}, "foo == ['A', 'B']"}, + {[]int{1, 2}, "foo == [1, 2]"}, + {[]uint8{1, 2}, "foo == [1, 2]"}, + {[]float64{1.1, 2.2}, "foo == [1.1, 2.2]"}, + {[]any{"A", 1, 1.1, true}, "foo == ['A', 1, 1.1, true]"}, + {[]string{"A", "B"}, "foo != [1, 2]"}, + } + + for _, tt := range tests { + t.Run(tt.code, func(t *testing.T) { + env := map[string]any{"foo": tt.env} + program, err := expr.Compile(tt.code, expr.Env(env)) + require.NoError(t, err) + + out, err := expr.Run(program, env) + require.NoError(t, err) + require.Equal(t, true, out) + }) + } +} diff --git a/optimizer/optimizer_test.go b/optimizer/optimizer_test.go index 703bd1ce..316b1718 100644 --- a/optimizer/optimizer_test.go +++ b/optimizer/optimizer_test.go @@ -349,16 +349,10 @@ func TestOptimize_predicate_combination(t *testing.T) { }{ {"and", "all", "and"}, {"&&", "all", "&&"}, - {"or", "all", "or"}, - {"||", "all", "||"}, - {"and", "any", "and"}, - {"&&", "any", "&&"}, {"or", "any", "or"}, {"||", "any", "||"}, {"and", "none", "or"}, {"&&", "none", "||"}, - {"and", "one", "or"}, - {"&&", "one", "||"}, } for _, tt := range tests { @@ -414,14 +408,14 @@ func TestOptimize_predicate_combination(t *testing.T) { } func TestOptimize_predicate_combination_nested(t *testing.T) { - tree, err := parser.Parse(`any(users, {all(.Friends, {.Age == 18 })}) && any(users, {all(.Friends, {.Name != "Bob" })})`) + tree, err := parser.Parse(`all(users, {all(.Friends, {.Age == 18 })}) && all(users, {all(.Friends, {.Name != "Bob" })})`) require.NoError(t, err) err = optimizer.Optimize(&tree.Node, nil) require.NoError(t, err) expected := &ast.BuiltinNode{ - Name: "any", + Name: "all", Arguments: []ast.Node{ &ast.IdentifierNode{Value: "users"}, &ast.ClosureNode{ diff --git a/optimizer/predicate_combination.go b/optimizer/predicate_combination.go index 2733781d..6e8a7f7c 100644 --- a/optimizer/predicate_combination.go +++ b/optimizer/predicate_combination.go @@ -5,6 +5,14 @@ import ( "github.com/expr-lang/expr/parser/operator" ) +/* +predicateCombination is a visitor that combines multiple predicate calls into a single call. +For example, the following expression: + + all(x, x > 1) && all(x, x < 10) -> all(x, x > 1 && x < 10) + any(x, x > 1) || any(x, x < 10) -> any(x, x > 1 || x < 10) + none(x, x > 1) && none(x, x < 10) -> none(x, x > 1 || x < 10) +*/ type predicateCombination struct{} func (v *predicateCombination) Visit(node *Node) { @@ -36,10 +44,12 @@ func (v *predicateCombination) Visit(node *Node) { } func combinedOperator(fn, op string) (string, bool) { - switch fn { - case "all", "any": + switch { + case fn == "all" && (op == "and" || op == "&&"): + return op, true + case fn == "any" && (op == "or" || op == "||"): return op, true - case "one", "none": + case fn == "none" && (op == "and" || op == "&&"): switch op { case "and": return "or", true diff --git a/parser/operator/operator.go b/parser/operator/operator.go index 8d804c7b..4eeaf80e 100644 --- a/parser/operator/operator.go +++ b/parser/operator/operator.go @@ -20,6 +20,15 @@ func IsBoolean(op string) bool { return op == "and" || op == "or" || op == "&&" || op == "||" } +func AllowedNegateSuffix(op string) bool { + switch op { + case "contains", "matches", "startsWith", "endsWith", "in": + return true + default: + return false + } +} + var Unary = map[string]Operator{ "not": {50, Left}, "!": {50, Left}, @@ -54,3 +63,7 @@ var Binary = map[string]Operator{ "^": {100, Right}, "??": {500, Left}, } + +func IsComparison(op string) bool { + return op == "<" || op == ">" || op == ">=" || op == "<=" +} diff --git a/parser/parser.go b/parser/parser.go index 1eabdebe..9cb79cbb 100644 --- a/parser/parser.go +++ b/parser/parser.go @@ -126,10 +126,8 @@ func (p *parser) expect(kind Kind, values ...string) { // parse functions func (p *parser) parseExpression(precedence int) Node { - if precedence == 0 { - if p.current.Is(Operator, "let") { - return p.parseVariableDeclaration() - } + if precedence == 0 && p.current.Is(Operator, "let") { + return p.parseVariableDeclaration() } nodeLeft := p.parsePrimary() @@ -137,57 +135,71 @@ func (p *parser) parseExpression(precedence int) Node { prevOperator := "" opToken := p.current for opToken.Is(Operator) && p.err == nil { - negate := false + negate := opToken.Is(Operator, "not") var notToken Token // Handle "not *" operator, like "not in" or "not contains". - if opToken.Is(Operator, "not") { + if negate { + currentPos := p.pos p.next() - notToken = p.current - negate = true - opToken = p.current + if operator.AllowedNegateSuffix(p.current.Value) { + if op, ok := operator.Binary[p.current.Value]; ok && op.Precedence >= precedence { + notToken = p.current + opToken = p.current + } else { + p.pos = currentPos + p.current = opToken + break + } + } else { + p.error("unexpected token %v", p.current) + break + } } - if op, ok := operator.Binary[opToken.Value]; ok { - if op.Precedence >= precedence { - p.next() + if op, ok := operator.Binary[opToken.Value]; ok && op.Precedence >= precedence { + p.next() - if opToken.Value == "|" { - identToken := p.current - p.expect(Identifier) - nodeLeft = p.parseCall(identToken, []Node{nodeLeft}, true) - goto next - } + if opToken.Value == "|" { + identToken := p.current + p.expect(Identifier) + nodeLeft = p.parseCall(identToken, []Node{nodeLeft}, true) + goto next + } - if prevOperator == "??" && opToken.Value != "??" && !opToken.Is(Bracket, "(") { - p.errorAt(opToken, "Operator (%v) and coalesce expressions (??) cannot be mixed. Wrap either by parentheses.", opToken.Value) - break - } + if prevOperator == "??" && opToken.Value != "??" && !opToken.Is(Bracket, "(") { + p.errorAt(opToken, "Operator (%v) and coalesce expressions (??) cannot be mixed. Wrap either by parentheses.", opToken.Value) + break + } - var nodeRight Node - if op.Associativity == operator.Left { - nodeRight = p.parseExpression(op.Precedence + 1) - } else { - nodeRight = p.parseExpression(op.Precedence) - } + if operator.IsComparison(opToken.Value) { + nodeLeft = p.parseComparison(nodeLeft, opToken, op.Precedence) + goto next + } - nodeLeft = &BinaryNode{ - Operator: opToken.Value, - Left: nodeLeft, - Right: nodeRight, - } - nodeLeft.SetLocation(opToken.Location) + var nodeRight Node + if op.Associativity == operator.Left { + nodeRight = p.parseExpression(op.Precedence + 1) + } else { + nodeRight = p.parseExpression(op.Precedence) + } - if negate { - nodeLeft = &UnaryNode{ - Operator: "not", - Node: nodeLeft, - } - nodeLeft.SetLocation(notToken.Location) - } + nodeLeft = &BinaryNode{ + Operator: opToken.Value, + Left: nodeLeft, + Right: nodeRight, + } + nodeLeft.SetLocation(opToken.Location) - goto next + if negate { + nodeLeft = &UnaryNode{ + Operator: "not", + Node: nodeLeft, + } + nodeLeft.SetLocation(notToken.Location) } + + goto next } break @@ -685,3 +697,34 @@ func (p *parser) parsePostfixExpression(node Node) Node { } return node } + +func (p *parser) parseComparison(left Node, token Token, precedence int) Node { + var rootNode Node + for { + comparator := p.parseExpression(precedence + 1) + cmpNode := &BinaryNode{ + Operator: token.Value, + Left: left, + Right: comparator, + } + cmpNode.SetLocation(token.Location) + if rootNode == nil { + rootNode = cmpNode + } else { + rootNode = &BinaryNode{ + Operator: "&&", + Left: rootNode, + Right: cmpNode, + } + rootNode.SetLocation(token.Location) + } + + left = comparator + token = p.current + if !(token.Is(Operator) && operator.IsComparison(token.Value) && p.err == nil) { + break + } + p.next() + } + return rootNode +} diff --git a/parser/parser_test.go b/parser/parser_test.go index b633bd52..2a30787a 100644 --- a/parser/parser_test.go +++ b/parser/parser_test.go @@ -365,6 +365,62 @@ world`}, &UnaryNode{Operator: "not", Node: &IdentifierNode{Value: "in_var"}}, }, + { + "-1 not in [1, 2, 3, 4]", + &UnaryNode{Operator: "not", + Node: &BinaryNode{Operator: "in", + Left: &UnaryNode{Operator: "-", Node: &IntegerNode{Value: 1}}, + Right: &ArrayNode{Nodes: []Node{ + &IntegerNode{Value: 1}, + &IntegerNode{Value: 2}, + &IntegerNode{Value: 3}, + &IntegerNode{Value: 4}, + }}}}, + }, + { + "1*8 not in [1, 2, 3, 4]", + &UnaryNode{Operator: "not", + Node: &BinaryNode{Operator: "in", + Left: &BinaryNode{Operator: "*", + Left: &IntegerNode{Value: 1}, + Right: &IntegerNode{Value: 8}, + }, + Right: &ArrayNode{Nodes: []Node{ + &IntegerNode{Value: 1}, + &IntegerNode{Value: 2}, + &IntegerNode{Value: 3}, + &IntegerNode{Value: 4}, + }}}}, + }, + { + "2==2 ? false : 3 not in [1, 2, 5]", + &ConditionalNode{ + Cond: &BinaryNode{ + Operator: "==", + Left: &IntegerNode{Value: 2}, + Right: &IntegerNode{Value: 2}, + }, + Exp1: &BoolNode{Value: false}, + Exp2: &UnaryNode{ + Operator: "not", + Node: &BinaryNode{ + Operator: "in", + Left: &IntegerNode{Value: 3}, + Right: &ArrayNode{Nodes: []Node{ + &IntegerNode{Value: 1}, + &IntegerNode{Value: 2}, + &IntegerNode{Value: 5}, + }}}}}, + }, + { + "'foo' + 'bar' not matches 'foobar'", + &UnaryNode{Operator: "not", + Node: &BinaryNode{Operator: "matches", + Left: &BinaryNode{Operator: "+", + Left: &StringNode{Value: "foo"}, + Right: &StringNode{Value: "bar"}}, + Right: &StringNode{Value: "foobar"}}}, + }, { "all(Tickets, #)", &BuiltinNode{ @@ -531,6 +587,66 @@ world`}, To: &IntegerNode{Value: 3}, }, }, + { + `1 < 2 > 3`, + &BinaryNode{ + Operator: "&&", + Left: &BinaryNode{ + Operator: "<", + Left: &IntegerNode{Value: 1}, + Right: &IntegerNode{Value: 2}, + }, + Right: &BinaryNode{ + Operator: ">", + Left: &IntegerNode{Value: 2}, + Right: &IntegerNode{Value: 3}, + }, + }, + }, + { + `1 < 2 < 3 < 4`, + &BinaryNode{ + Operator: "&&", + Left: &BinaryNode{ + Operator: "&&", + Left: &BinaryNode{ + Operator: "<", + Left: &IntegerNode{Value: 1}, + Right: &IntegerNode{Value: 2}, + }, + Right: &BinaryNode{ + Operator: "<", + Left: &IntegerNode{Value: 2}, + Right: &IntegerNode{Value: 3}, + }, + }, + Right: &BinaryNode{ + Operator: "<", + Left: &IntegerNode{Value: 3}, + Right: &IntegerNode{Value: 4}, + }, + }, + }, + { + `1 < 2 < 3 == true`, + &BinaryNode{ + Operator: "==", + Left: &BinaryNode{ + Operator: "&&", + Left: &BinaryNode{ + Operator: "<", + Left: &IntegerNode{Value: 1}, + Right: &IntegerNode{Value: 2}, + }, + Right: &BinaryNode{ + Operator: "<", + Left: &IntegerNode{Value: 2}, + Right: &IntegerNode{Value: 3}, + }, + }, + Right: &BoolNode{Value: true}, + }, + }, } for _, test := range tests { t.Run(test.input, func(t *testing.T) { @@ -646,6 +762,11 @@ invalid float literal: strconv.ParseFloat: parsing "0o1E+1": invalid syntax (1:6 invalid float literal: strconv.ParseFloat: parsing "1E": invalid syntax (1:2) | 1E | .^ + +1 not == [1, 2, 5] +unexpected token Operator("==") (1:7) + | 1 not == [1, 2, 5] + | ......^ ` func TestParse_error(t *testing.T) { diff --git a/patcher/value/value.go b/patcher/value/value.go index 59351be6..28f52be2 100644 --- a/patcher/value/value.go +++ b/patcher/value/value.go @@ -13,9 +13,9 @@ import ( // ValueGetter is a Patcher that allows custom types to be represented as standard go values for use with expr. // It also adds the `$patcher_value_getter` function to the program for efficiently calling matching interfaces. // -// The purpose of this Patcher is to make it seemless to use custom types in expressions without the need to +// The purpose of this Patcher is to make it seamless to use custom types in expressions without the need to // first convert them to standard go values. It may also facilitate using already existing structs or maps as -// environments when they contain compatabile types. +// environments when they contain compatible types. // // An example usage may be modeling a database record with columns that have varying data types and constraints. // In such an example you may have custom types that, beyond storing a simple value, such as an integer, may diff --git a/test/operator/operator_test.go b/test/operator/operator_test.go index a19c191d..b49d91cc 100644 --- a/test/operator/operator_test.go +++ b/test/operator/operator_test.go @@ -77,7 +77,7 @@ func TestOperator_Function(t *testing.T) { } for _, tt := range tests { - t.Run(fmt.Sprintf(`opertor function helper test %s`, tt.input), func(t *testing.T) { + t.Run(fmt.Sprintf(`operator function helper test %s`, tt.input), func(t *testing.T) { program, err := expr.Compile( tt.input, expr.Env(env), diff --git a/vm/opcodes.go b/vm/opcodes.go index 0417dab6..84d751d6 100644 --- a/vm/opcodes.go +++ b/vm/opcodes.go @@ -81,6 +81,8 @@ const ( OpGroupBy OpSortBy OpSort + OpProfileStart + OpProfileEnd OpBegin OpEnd // This opcode must be at the end of this list. ) diff --git a/vm/program.go b/vm/program.go index 4a878267..98954674 100644 --- a/vm/program.go +++ b/vm/program.go @@ -27,6 +27,7 @@ type Program struct { variables int functions []Function debugInfo map[string]string + span *Span } // NewProgram returns a new Program. It's used by the compiler. @@ -40,6 +41,7 @@ func NewProgram( arguments []int, functions []Function, debugInfo map[string]string, + span *Span, ) *Program { return &Program{ source: source, @@ -51,6 +53,7 @@ func NewProgram( Arguments: arguments, functions: functions, debugInfo: debugInfo, + span: span, } } @@ -360,6 +363,12 @@ func (program *Program) DisassembleWriter(w io.Writer) { case OpSort: code("OpSort") + case OpProfileStart: + code("OpProfileStart") + + case OpProfileEnd: + code("OpProfileEnd") + case OpBegin: code("OpBegin") diff --git a/vm/runtime/helpers/main.go b/vm/runtime/helpers/main.go index b3f598a4..54a4fc23 100644 --- a/vm/runtime/helpers/main.go +++ b/vm/runtime/helpers/main.go @@ -19,6 +19,7 @@ func main() { "cases_with_duration": func(op string) string { return cases(op, uints, ints, floats, []string{"time.Duration"}) }, + "array_equal_cases": func() string { return arrayEqualCases([]string{"string"}, uints, ints, floats) }, }). Parse(helpers), ).Execute(&b, nil) @@ -89,6 +90,45 @@ func cases(op string, xs ...[]string) string { return strings.TrimRight(out, "\n") } +func arrayEqualCases(xs ...[]string) string { + var types []string + for _, x := range xs { + types = append(types, x...) + } + + _, _ = fmt.Fprintf(os.Stderr, "Generating array equal cases for %v\n", types) + + var out string + echo := func(s string, xs ...any) { + out += fmt.Sprintf(s, xs...) + "\n" + } + echo(`case []any:`) + echo(`switch y := b.(type) {`) + for _, a := range append(types, "any") { + echo(`case []%v:`, a) + echo(`if len(x) != len(y) { return false }`) + echo(`for i := range x {`) + echo(`if !Equal(x[i], y[i]) { return false }`) + echo(`}`) + echo("return true") + } + echo(`}`) + for _, a := range types { + echo(`case []%v:`, a) + echo(`switch y := b.(type) {`) + echo(`case []any:`) + echo(`return Equal(y, x)`) + echo(`case []%v:`, a) + echo(`if len(x) != len(y) { return false }`) + echo(`for i := range x {`) + echo(`if x[i] != y[i] { return false }`) + echo(`}`) + echo("return true") + echo(`}`) + } + return strings.TrimRight(out, "\n") +} + func isFloat(t string) bool { return strings.HasPrefix(t, "float") } @@ -110,6 +150,7 @@ import ( func Equal(a, b interface{}) bool { switch x := a.(type) { {{ cases "==" }} + {{ array_equal_cases }} case string: switch y := b.(type) { case string: @@ -125,6 +166,11 @@ func Equal(a, b interface{}) bool { case time.Duration: return x == y } + case bool: + switch y := b.(type) { + case bool: + return x == y + } } if IsNil(a) && IsNil(b) { return true diff --git a/vm/runtime/helpers[generated].go b/vm/runtime/helpers[generated].go index 720feb45..d950f111 100644 --- a/vm/runtime/helpers[generated].go +++ b/vm/runtime/helpers[generated].go @@ -334,6 +334,344 @@ func Equal(a, b interface{}) bool { case float64: return float64(x) == float64(y) } + case []any: + switch y := b.(type) { + case []string: + if len(x) != len(y) { + return false + } + for i := range x { + if !Equal(x[i], y[i]) { + return false + } + } + return true + case []uint: + if len(x) != len(y) { + return false + } + for i := range x { + if !Equal(x[i], y[i]) { + return false + } + } + return true + case []uint8: + if len(x) != len(y) { + return false + } + for i := range x { + if !Equal(x[i], y[i]) { + return false + } + } + return true + case []uint16: + if len(x) != len(y) { + return false + } + for i := range x { + if !Equal(x[i], y[i]) { + return false + } + } + return true + case []uint32: + if len(x) != len(y) { + return false + } + for i := range x { + if !Equal(x[i], y[i]) { + return false + } + } + return true + case []uint64: + if len(x) != len(y) { + return false + } + for i := range x { + if !Equal(x[i], y[i]) { + return false + } + } + return true + case []int: + if len(x) != len(y) { + return false + } + for i := range x { + if !Equal(x[i], y[i]) { + return false + } + } + return true + case []int8: + if len(x) != len(y) { + return false + } + for i := range x { + if !Equal(x[i], y[i]) { + return false + } + } + return true + case []int16: + if len(x) != len(y) { + return false + } + for i := range x { + if !Equal(x[i], y[i]) { + return false + } + } + return true + case []int32: + if len(x) != len(y) { + return false + } + for i := range x { + if !Equal(x[i], y[i]) { + return false + } + } + return true + case []int64: + if len(x) != len(y) { + return false + } + for i := range x { + if !Equal(x[i], y[i]) { + return false + } + } + return true + case []float32: + if len(x) != len(y) { + return false + } + for i := range x { + if !Equal(x[i], y[i]) { + return false + } + } + return true + case []float64: + if len(x) != len(y) { + return false + } + for i := range x { + if !Equal(x[i], y[i]) { + return false + } + } + return true + case []any: + if len(x) != len(y) { + return false + } + for i := range x { + if !Equal(x[i], y[i]) { + return false + } + } + return true + } + case []string: + switch y := b.(type) { + case []any: + return Equal(y, x) + case []string: + if len(x) != len(y) { + return false + } + for i := range x { + if x[i] != y[i] { + return false + } + } + return true + } + case []uint: + switch y := b.(type) { + case []any: + return Equal(y, x) + case []uint: + if len(x) != len(y) { + return false + } + for i := range x { + if x[i] != y[i] { + return false + } + } + return true + } + case []uint8: + switch y := b.(type) { + case []any: + return Equal(y, x) + case []uint8: + if len(x) != len(y) { + return false + } + for i := range x { + if x[i] != y[i] { + return false + } + } + return true + } + case []uint16: + switch y := b.(type) { + case []any: + return Equal(y, x) + case []uint16: + if len(x) != len(y) { + return false + } + for i := range x { + if x[i] != y[i] { + return false + } + } + return true + } + case []uint32: + switch y := b.(type) { + case []any: + return Equal(y, x) + case []uint32: + if len(x) != len(y) { + return false + } + for i := range x { + if x[i] != y[i] { + return false + } + } + return true + } + case []uint64: + switch y := b.(type) { + case []any: + return Equal(y, x) + case []uint64: + if len(x) != len(y) { + return false + } + for i := range x { + if x[i] != y[i] { + return false + } + } + return true + } + case []int: + switch y := b.(type) { + case []any: + return Equal(y, x) + case []int: + if len(x) != len(y) { + return false + } + for i := range x { + if x[i] != y[i] { + return false + } + } + return true + } + case []int8: + switch y := b.(type) { + case []any: + return Equal(y, x) + case []int8: + if len(x) != len(y) { + return false + } + for i := range x { + if x[i] != y[i] { + return false + } + } + return true + } + case []int16: + switch y := b.(type) { + case []any: + return Equal(y, x) + case []int16: + if len(x) != len(y) { + return false + } + for i := range x { + if x[i] != y[i] { + return false + } + } + return true + } + case []int32: + switch y := b.(type) { + case []any: + return Equal(y, x) + case []int32: + if len(x) != len(y) { + return false + } + for i := range x { + if x[i] != y[i] { + return false + } + } + return true + } + case []int64: + switch y := b.(type) { + case []any: + return Equal(y, x) + case []int64: + if len(x) != len(y) { + return false + } + for i := range x { + if x[i] != y[i] { + return false + } + } + return true + } + case []float32: + switch y := b.(type) { + case []any: + return Equal(y, x) + case []float32: + if len(x) != len(y) { + return false + } + for i := range x { + if x[i] != y[i] { + return false + } + } + return true + } + case []float64: + switch y := b.(type) { + case []any: + return Equal(y, x) + case []float64: + if len(x) != len(y) { + return false + } + for i := range x { + if x[i] != y[i] { + return false + } + } + return true + } case string: switch y := b.(type) { case string: @@ -349,6 +687,11 @@ func Equal(a, b interface{}) bool { case time.Duration: return x == y } + case bool: + switch y := b.(type) { + case bool: + return x == y + } } if IsNil(a) && IsNil(b) { return true diff --git a/vm/runtime/helpers_test.go b/vm/runtime/helpers_test.go new file mode 100644 index 00000000..42a0aece --- /dev/null +++ b/vm/runtime/helpers_test.go @@ -0,0 +1,57 @@ +package runtime_test + +import ( + "testing" + + "github.com/expr-lang/expr/vm/runtime" + "github.com/stretchr/testify/assert" +) + +var tests = []struct { + name string + a, b any + want bool +}{ + {"int == int", 42, 42, true}, + {"int != int", 42, 33, false}, + {"int == int8", 42, int8(42), true}, + {"int == int16", 42, int16(42), true}, + {"int == int32", 42, int32(42), true}, + {"int == int64", 42, int64(42), true}, + {"float == float", 42.0, 42.0, true}, + {"float != float", 42.0, 33.0, false}, + {"float == int", 42.0, 42, true}, + {"float != int", 42.0, 33, false}, + {"string == string", "foo", "foo", true}, + {"string != string", "foo", "bar", false}, + {"bool == bool", true, true, true}, + {"bool != bool", true, false, false}, + {"[]any == []int", []any{1, 2, 3}, []int{1, 2, 3}, true}, + {"[]any != []int", []any{1, 2, 3}, []int{1, 2, 99}, false}, + {"deep []any == []any", []any{[]int{1}, 2, []any{"3"}}, []any{[]any{1}, 2, []string{"3"}}, true}, + {"deep []any != []any", []any{[]int{1}, 2, []any{"3", "42"}}, []any{[]any{1}, 2, []string{"3"}}, false}, + {"map[string]any == map[string]any", map[string]any{"a": 1}, map[string]any{"a": 1}, true}, + {"map[string]any != map[string]any", map[string]any{"a": 1}, map[string]any{"a": 1, "b": 2}, false}, +} + +func TestEqual(t *testing.T) { + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := runtime.Equal(tt.a, tt.b) + assert.Equal(t, tt.want, got, "Equal(%v, %v) = %v; want %v", tt.a, tt.b, got, tt.want) + got = runtime.Equal(tt.b, tt.a) + assert.Equal(t, tt.want, got, "Equal(%v, %v) = %v; want %v", tt.b, tt.a, got, tt.want) + }) + } + +} + +func BenchmarkEqual(b *testing.B) { + for _, tt := range tests { + b.Run(tt.name, func(b *testing.B) { + for i := 0; i < b.N; i++ { + runtime.Equal(tt.a, tt.b) + } + }) + } +} diff --git a/vm/utils.go b/vm/utils.go index d7db2a52..fc2f5e7b 100644 --- a/vm/utils.go +++ b/vm/utils.go @@ -2,6 +2,7 @@ package vm import ( "reflect" + "time" ) type ( @@ -25,3 +26,15 @@ type Scope struct { } type groupBy = map[any][]any + +type Span struct { + Name string `json:"name"` + Expression string `json:"expression"` + Duration int64 `json:"duration"` + Children []*Span `json:"children"` + start time.Time +} + +func GetSpan(program *Program) *Span { + return program.span +} diff --git a/vm/vm.go b/vm/vm.go index 1e85893b..7e933ce7 100644 --- a/vm/vm.go +++ b/vm/vm.go @@ -8,6 +8,7 @@ import ( "regexp" "sort" "strings" + "time" "github.com/expr-lang/expr/builtin" "github.com/expr-lang/expr/file" @@ -523,6 +524,14 @@ func (vm *VM) Run(program *Program, env any) (_ any, err error) { vm.memGrow(uint(scope.Len)) vm.push(sortable.Array) + case OpProfileStart: + span := program.Constants[arg].(*Span) + span.start = time.Now() + + case OpProfileEnd: + span := program.Constants[arg].(*Span) + span.Duration += time.Since(span.start).Nanoseconds() + case OpBegin: a := vm.pop() array := reflect.ValueOf(a)