Skip to content

Commit 38f9496

Browse files
authored
Add predicate to sum() builtin (#592)
* Add predicate to sum() builtin * go mod tidy
1 parent d66ffcd commit 38f9496

File tree

8 files changed

+48
-65
lines changed

8 files changed

+48
-65
lines changed

Diff for: builtin/builtin.go

+5-7
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,11 @@ var Builtins = []*Function{
8383
Predicate: true,
8484
Types: types(new(func([]any, func(any) bool) int)),
8585
},
86+
{
87+
Name: "sum",
88+
Predicate: true,
89+
Types: types(new(func([]any, func(any) bool) int)),
90+
},
8691
{
8792
Name: "groupBy",
8893
Predicate: true,
@@ -387,13 +392,6 @@ var Builtins = []*Function{
387392
return validateAggregateFunc("min", args)
388393
},
389394
},
390-
{
391-
Name: "sum",
392-
Func: sum,
393-
Validate: func(args []reflect.Type) (reflect.Type, error) {
394-
return validateAggregateFunc("sum", args)
395-
},
396-
},
397395
{
398396
Name: "mean",
399397
Func: func(args ...any) (any, error) {

Diff for: builtin/builtin_test.go

-4
Original file line numberDiff line numberDiff line change
@@ -90,9 +90,6 @@ func TestBuiltin(t *testing.T) {
9090
{`sum([.5, 1.5, 2.5])`, 4.5},
9191
{`sum([])`, 0},
9292
{`sum([1, 2, 3.0, 4])`, 10.0},
93-
{`sum(10, [1, 2, 3], 1..9)`, 61},
94-
{`sum(-10, [1, 2, 3, 4])`, 0},
95-
{`sum(-10.9, [1, 2, 3, 4, 9])`, 8.1},
9693
{`mean(1..9)`, 5.0},
9794
{`mean([.5, 1.5, 2.5])`, 1.5},
9895
{`mean([])`, 0.0},
@@ -219,7 +216,6 @@ func TestBuiltin_errors(t *testing.T) {
219216
{`min([1, "2"])`, `invalid argument for min (type string)`},
220217
{`median(1..9, "t")`, "invalid argument for median (type string)"},
221218
{`mean("s", 1..9)`, "invalid argument for mean (type string)"},
222-
{`sum("s", "h")`, "invalid argument for sum (type string)"},
223219
{`duration("error")`, `invalid duration`},
224220
{`date("error")`, `invalid date`},
225221
{`get()`, `invalid number of arguments (expected 2, got 0)`},

Diff for: builtin/lib.go

-39
Original file line numberDiff line numberDiff line change
@@ -258,45 +258,6 @@ func String(arg any) any {
258258
return fmt.Sprintf("%v", arg)
259259
}
260260

261-
func sum(args ...any) (any, error) {
262-
var total int
263-
var fTotal float64
264-
265-
for _, arg := range args {
266-
rv := reflect.ValueOf(deref.Deref(arg))
267-
268-
switch rv.Kind() {
269-
case reflect.Array, reflect.Slice:
270-
size := rv.Len()
271-
for i := 0; i < size; i++ {
272-
elemSum, err := sum(rv.Index(i).Interface())
273-
if err != nil {
274-
return nil, err
275-
}
276-
switch elemSum := elemSum.(type) {
277-
case int:
278-
total += elemSum
279-
case float64:
280-
fTotal += elemSum
281-
}
282-
}
283-
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
284-
total += int(rv.Int())
285-
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
286-
total += int(rv.Uint())
287-
case reflect.Float32, reflect.Float64:
288-
fTotal += rv.Float()
289-
default:
290-
return nil, fmt.Errorf("invalid argument for sum (type %T)", arg)
291-
}
292-
}
293-
294-
if fTotal != 0.0 {
295-
return fTotal + float64(total), nil
296-
}
297-
return total, nil
298-
}
299-
300261
func minMax(name string, fn func(any, any) bool, args ...any) (any, error) {
301262
var val any
302263
for _, arg := range args {

Diff for: checker/checker.go

+23
Original file line numberDiff line numberDiff line change
@@ -668,6 +668,29 @@ func (v *checker) BuiltinNode(node *ast.BuiltinNode) (reflect.Type, info) {
668668
}
669669
return v.error(node.Arguments[1], "predicate should has one input and one output param")
670670

671+
case "sum":
672+
collection, _ := v.visit(node.Arguments[0])
673+
if !isArray(collection) && !isAny(collection) {
674+
return v.error(node.Arguments[0], "builtin %v takes only array (got %v)", node.Name, collection)
675+
}
676+
677+
if len(node.Arguments) == 2 {
678+
v.begin(collection)
679+
closure, _ := v.visit(node.Arguments[1])
680+
v.end()
681+
682+
if isFunc(closure) &&
683+
closure.NumOut() == 1 &&
684+
closure.NumIn() == 1 && isAny(closure.In(0)) {
685+
return closure.Out(0), info{}
686+
}
687+
} else {
688+
if isAny(collection) {
689+
return anyType, info{}
690+
}
691+
return collection.Elem(), info{}
692+
}
693+
671694
case "find", "findLast":
672695
collection, _ := v.visit(node.Arguments[0])
673696
if !isArray(collection) && !isAny(collection) {

Diff for: compiler/compiler.go

+19
Original file line numberDiff line numberDiff line change
@@ -809,6 +809,25 @@ func (c *compiler) BuiltinNode(node *ast.BuiltinNode) {
809809
c.emit(OpEnd)
810810
return
811811

812+
case "sum":
813+
c.compile(node.Arguments[0])
814+
c.emit(OpBegin)
815+
c.emit(OpInt, 0)
816+
c.emit(OpSetAcc)
817+
c.emitLoop(func() {
818+
if len(node.Arguments) == 2 {
819+
c.compile(node.Arguments[1])
820+
} else {
821+
c.emit(OpPointer)
822+
}
823+
c.emit(OpGetAcc)
824+
c.emit(OpAdd)
825+
c.emit(OpSetAcc)
826+
})
827+
c.emit(OpGetAcc)
828+
c.emit(OpEnd)
829+
return
830+
812831
case "find":
813832
c.compile(node.Arguments[0])
814833
c.emit(OpBegin)

Diff for: parser/parser.go

+1
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ var predicates = map[string]struct {
3434
"filter": {[]arg{expr, closure}},
3535
"map": {[]arg{expr, closure}},
3636
"count": {[]arg{expr, closure}},
37+
"sum": {[]arg{expr, closure | optional}},
3738
"find": {[]arg{expr, closure}},
3839
"findIndex": {[]arg{expr, closure}},
3940
"findLast": {[]arg{expr, closure}},

Diff for: test/fuzz/fuzz_corpus.txt

-1
Original file line numberDiff line numberDiff line change
@@ -10455,7 +10455,6 @@ max(f64, i64)
1045510455
max(false ? 1 : 0.5)
1045610456
max(false ? 1 : nil)
1045710457
max(false ? add : ok)
10458-
max(false ? half : list)
1045910458
max(false ? i : nil)
1046010459
max(false ? i32 : score)
1046110460
max(false ? true : 1)

Diff for: testdata/examples.txt

-14
Original file line numberDiff line numberDiff line change
@@ -7419,12 +7419,6 @@ get(ok ? score : foo, String?.foo())
74197419
get(ok ? score : i64, foo)
74207420
get(reduce(list, array), i32)
74217421
get(sort(array), i32)
7422-
get(sum(array), Qux)
7423-
get(sum(array), String)
7424-
get(sum(array), f32)
7425-
get(sum(array), f64 == list)
7426-
get(sum(array), greet)
7427-
get(sum(array), i)
74287422
get(take(list, i), i64)
74297423
get(true ? "bar" : ok, score(i))
74307424
get(true ? "foo" : half, list)
@@ -7460,7 +7454,6 @@ greet != nil ? list : false
74607454
greet != score
74617455
greet != score != false
74627456
greet != score or ok
7463-
greet != sum(array)
74647457
greet == add
74657458
greet == add ? i : list
74667459
greet == add or ok
@@ -12200,7 +12193,6 @@ last(ok ? ok : 0.5)
1220012193
last(reduce(array, list))
1220112194
last(reduce(list, array))
1220212195
last(sort(array))
12203-
last(sum(array))
1220412196
last(true ? "bar" : half)
1220512197
last(true ? add : list)
1220612198
last(true ? foo : 1)
@@ -14818,7 +14810,6 @@ ok != nil ? nil : array
1481814810
ok != not ok
1481914811
ok != ok
1482014812
ok != ok ? false : "bar"
14821-
ok != sum(array)
1482214813
ok && !false
1482314814
ok && !ok
1482414815
ok && "foo" matches "bar"
@@ -16970,7 +16961,6 @@ string(groupBy(list, i))
1697016961
string(half != nil)
1697116962
string(half != score)
1697216963
string(half == nil)
16973-
string(half == sum(array))
1697416964
string(half(0.5))
1697516965
string(half(1))
1697616966
string(half(f64))
@@ -17297,18 +17287,14 @@ sum([0.5])
1729717287
sum([f32])
1729817288
sum(array)
1729917289
sum(array) != f32
17300-
sum(array) != half
17301-
sum(array) != ok
1730217290
sum(array) % i
1730317291
sum(array) % i64
1730417292
sum(array) - f32
1730517293
sum(array) / -f64
1730617294
sum(array) < i
17307-
sum(array) == div
1730817295
sum(array) == i64 - i
1730917296
sum(array) ^ f64
1731017297
sum(array) not in array
17311-
sum(array) not in list
1731217298
sum(filter(array, ok))
1731317299
sum(groupBy(array, i32).String)
1731417300
sum(groupBy(list, #)?.greet)

0 commit comments

Comments
 (0)