Skip to content

Commit 257623d

Browse files
Merge pull request #68 from strongdm/idx-351/fix-partial-fold-comparable
Use comparable evaluators in partial evaluation and folding
2 parents b25ac8d + 2b80710 commit 257623d

File tree

7 files changed

+90
-328
lines changed

7 files changed

+90
-328
lines changed

internal/eval/evalers.go

+4-104
Original file line numberDiff line numberDiff line change
@@ -429,106 +429,6 @@ func (n *negateEval) Eval(env Env) (types.Value, error) {
429429
return res, nil
430430
}
431431

432-
// longLessThanEval
433-
type longLessThanEval struct {
434-
lhs Evaler
435-
rhs Evaler
436-
}
437-
438-
func newLongLessThanEval(lhs Evaler, rhs Evaler) Evaler {
439-
return &longLessThanEval{
440-
lhs: lhs,
441-
rhs: rhs,
442-
}
443-
}
444-
445-
func (n *longLessThanEval) Eval(env Env) (types.Value, error) {
446-
lhs, err := evalLong(n.lhs, env)
447-
if err != nil {
448-
return zeroValue(), err
449-
}
450-
rhs, err := evalLong(n.rhs, env)
451-
if err != nil {
452-
return zeroValue(), err
453-
}
454-
return types.Boolean(lhs < rhs), nil
455-
}
456-
457-
// longLessThanOrEqualEval
458-
type longLessThanOrEqualEval struct {
459-
lhs Evaler
460-
rhs Evaler
461-
}
462-
463-
func newLongLessThanOrEqualEval(lhs Evaler, rhs Evaler) Evaler {
464-
return &longLessThanOrEqualEval{
465-
lhs: lhs,
466-
rhs: rhs,
467-
}
468-
}
469-
470-
func (n *longLessThanOrEqualEval) Eval(env Env) (types.Value, error) {
471-
lhs, err := evalLong(n.lhs, env)
472-
if err != nil {
473-
return zeroValue(), err
474-
}
475-
rhs, err := evalLong(n.rhs, env)
476-
if err != nil {
477-
return zeroValue(), err
478-
}
479-
return types.Boolean(lhs <= rhs), nil
480-
}
481-
482-
// longGreaterThanEval
483-
type longGreaterThanEval struct {
484-
lhs Evaler
485-
rhs Evaler
486-
}
487-
488-
func newLongGreaterThanEval(lhs Evaler, rhs Evaler) Evaler {
489-
return &longGreaterThanEval{
490-
lhs: lhs,
491-
rhs: rhs,
492-
}
493-
}
494-
495-
func (n *longGreaterThanEval) Eval(env Env) (types.Value, error) {
496-
lhs, err := evalLong(n.lhs, env)
497-
if err != nil {
498-
return zeroValue(), err
499-
}
500-
rhs, err := evalLong(n.rhs, env)
501-
if err != nil {
502-
return zeroValue(), err
503-
}
504-
return types.Boolean(lhs > rhs), nil
505-
}
506-
507-
// longGreaterThanOrEqualEval
508-
type longGreaterThanOrEqualEval struct {
509-
lhs Evaler
510-
rhs Evaler
511-
}
512-
513-
func newLongGreaterThanOrEqualEval(lhs Evaler, rhs Evaler) Evaler {
514-
return &longGreaterThanOrEqualEval{
515-
lhs: lhs,
516-
rhs: rhs,
517-
}
518-
}
519-
520-
func (n *longGreaterThanOrEqualEval) Eval(env Env) (types.Value, error) {
521-
lhs, err := evalLong(n.lhs, env)
522-
if err != nil {
523-
return zeroValue(), err
524-
}
525-
rhs, err := evalLong(n.rhs, env)
526-
if err != nil {
527-
return zeroValue(), err
528-
}
529-
return types.Boolean(lhs >= rhs), nil
530-
}
531-
532432
// decimalLessThanEval
533433
type decimalLessThanEval struct {
534434
lhs Evaler
@@ -1301,7 +1201,7 @@ type comparableValueLessThanEval struct {
13011201
rhs Evaler
13021202
}
13031203

1304-
func newComparableValueLessThanEval(lhs Evaler, rhs Evaler) *comparableValueLessThanEval {
1204+
func newComparableValueLessThanEval(lhs Evaler, rhs Evaler) Evaler {
13051205
return &comparableValueLessThanEval{
13061206
lhs: lhs,
13071207
rhs: rhs,
@@ -1332,7 +1232,7 @@ type comparableValueGreaterThanEval struct {
13321232
rhs Evaler
13331233
}
13341234

1335-
func newComparableValueGreaterThanEval(lhs Evaler, rhs Evaler) *comparableValueGreaterThanEval {
1235+
func newComparableValueGreaterThanEval(lhs Evaler, rhs Evaler) Evaler {
13361236
return &comparableValueGreaterThanEval{
13371237
lhs: lhs,
13381238
rhs: rhs,
@@ -1361,7 +1261,7 @@ type comparableValueLessThanOrEqualEval struct {
13611261
rhs Evaler
13621262
}
13631263

1364-
func newComparableValueLessThanOrEqualEval(lhs Evaler, rhs Evaler) *comparableValueLessThanOrEqualEval {
1264+
func newComparableValueLessThanOrEqualEval(lhs Evaler, rhs Evaler) Evaler {
13651265
return &comparableValueLessThanOrEqualEval{
13661266
lhs: lhs,
13671267
rhs: rhs,
@@ -1390,7 +1290,7 @@ type comparableValueGreaterThanOrEqualEval struct {
13901290
rhs Evaler
13911291
}
13921292

1393-
func newComparableValueGreaterThanOrEqualEval(lhs Evaler, rhs Evaler) *comparableValueGreaterThanOrEqualEval {
1293+
func newComparableValueGreaterThanOrEqualEval(lhs Evaler, rhs Evaler) Evaler {
13941294
return &comparableValueGreaterThanOrEqualEval{
13951295
lhs: lhs,
13961296
rhs: rhs,

internal/eval/evalers_test.go

-208
Original file line numberDiff line numberDiff line change
@@ -513,214 +513,6 @@ func TestNegateNode(t *testing.T) {
513513
}
514514
}
515515

516-
func TestLongLessThanNode(t *testing.T) {
517-
t.Parallel()
518-
{
519-
tests := []struct {
520-
lhs, rhs int64
521-
result bool
522-
}{
523-
{-1, -1, false},
524-
{-1, 0, true},
525-
{-1, 1, true},
526-
{0, -1, false},
527-
{0, 0, false},
528-
{0, 1, true},
529-
{1, -1, false},
530-
{1, 0, false},
531-
{1, 1, false},
532-
}
533-
for _, tt := range tests {
534-
tt := tt
535-
t.Run(fmt.Sprintf("%v<%v", tt.lhs, tt.rhs), func(t *testing.T) {
536-
t.Parallel()
537-
n := newLongLessThanEval(
538-
newLiteralEval(types.Long(tt.lhs)), newLiteralEval(types.Long(tt.rhs)))
539-
v, err := n.Eval(Env{})
540-
testutil.OK(t, err)
541-
AssertBoolValue(t, v, tt.result)
542-
})
543-
}
544-
}
545-
{
546-
tests := []struct {
547-
name string
548-
lhs, rhs Evaler
549-
err error
550-
}{
551-
{"LhsError", newErrorEval(errTest), newLiteralEval(types.Long(0)), errTest},
552-
{"LhsTypeError", newLiteralEval(types.True), newLiteralEval(types.Long(0)), ErrType},
553-
{"RhsError", newLiteralEval(types.Long(0)), newErrorEval(errTest), errTest},
554-
{"RhsTypeError", newLiteralEval(types.Long(0)), newLiteralEval(types.True), ErrType},
555-
}
556-
for _, tt := range tests {
557-
tt := tt
558-
t.Run(tt.name, func(t *testing.T) {
559-
t.Parallel()
560-
n := newLongLessThanEval(tt.lhs, tt.rhs)
561-
_, err := n.Eval(Env{})
562-
testutil.ErrorIs(t, err, tt.err)
563-
})
564-
}
565-
}
566-
}
567-
568-
func TestLongLessThanOrEqualNode(t *testing.T) {
569-
t.Parallel()
570-
{
571-
tests := []struct {
572-
lhs, rhs int64
573-
result bool
574-
}{
575-
{-1, -1, true},
576-
{-1, 0, true},
577-
{-1, 1, true},
578-
{0, -1, false},
579-
{0, 0, true},
580-
{0, 1, true},
581-
{1, -1, false},
582-
{1, 0, false},
583-
{1, 1, true},
584-
}
585-
for _, tt := range tests {
586-
tt := tt
587-
t.Run(fmt.Sprintf("%v<=%v", tt.lhs, tt.rhs), func(t *testing.T) {
588-
t.Parallel()
589-
n := newLongLessThanOrEqualEval(
590-
newLiteralEval(types.Long(tt.lhs)), newLiteralEval(types.Long(tt.rhs)))
591-
v, err := n.Eval(Env{})
592-
testutil.OK(t, err)
593-
AssertBoolValue(t, v, tt.result)
594-
})
595-
}
596-
}
597-
{
598-
tests := []struct {
599-
name string
600-
lhs, rhs Evaler
601-
err error
602-
}{
603-
{"LhsError", newErrorEval(errTest), newLiteralEval(types.Long(0)), errTest},
604-
{"LhsTypeError", newLiteralEval(types.True), newLiteralEval(types.Long(0)), ErrType},
605-
{"RhsError", newLiteralEval(types.Long(0)), newErrorEval(errTest), errTest},
606-
{"RhsTypeError", newLiteralEval(types.Long(0)), newLiteralEval(types.True), ErrType},
607-
}
608-
for _, tt := range tests {
609-
tt := tt
610-
t.Run(tt.name, func(t *testing.T) {
611-
t.Parallel()
612-
n := newLongLessThanOrEqualEval(tt.lhs, tt.rhs)
613-
_, err := n.Eval(Env{})
614-
testutil.ErrorIs(t, err, tt.err)
615-
})
616-
}
617-
}
618-
}
619-
620-
func TestLongGreaterThanNode(t *testing.T) {
621-
t.Parallel()
622-
{
623-
tests := []struct {
624-
lhs, rhs int64
625-
result bool
626-
}{
627-
{-1, -1, false},
628-
{-1, 0, false},
629-
{-1, 1, false},
630-
{0, -1, true},
631-
{0, 0, false},
632-
{0, 1, false},
633-
{1, -1, true},
634-
{1, 0, true},
635-
{1, 1, false},
636-
}
637-
for _, tt := range tests {
638-
tt := tt
639-
t.Run(fmt.Sprintf("%v>%v", tt.lhs, tt.rhs), func(t *testing.T) {
640-
t.Parallel()
641-
n := newLongGreaterThanEval(
642-
newLiteralEval(types.Long(tt.lhs)), newLiteralEval(types.Long(tt.rhs)))
643-
v, err := n.Eval(Env{})
644-
testutil.OK(t, err)
645-
AssertBoolValue(t, v, tt.result)
646-
})
647-
}
648-
}
649-
{
650-
tests := []struct {
651-
name string
652-
lhs, rhs Evaler
653-
err error
654-
}{
655-
{"LhsError", newErrorEval(errTest), newLiteralEval(types.Long(0)), errTest},
656-
{"LhsTypeError", newLiteralEval(types.True), newLiteralEval(types.Long(0)), ErrType},
657-
{"RhsError", newLiteralEval(types.Long(0)), newErrorEval(errTest), errTest},
658-
{"RhsTypeError", newLiteralEval(types.Long(0)), newLiteralEval(types.True), ErrType},
659-
}
660-
for _, tt := range tests {
661-
tt := tt
662-
t.Run(tt.name, func(t *testing.T) {
663-
t.Parallel()
664-
n := newLongGreaterThanEval(tt.lhs, tt.rhs)
665-
_, err := n.Eval(Env{})
666-
testutil.ErrorIs(t, err, tt.err)
667-
})
668-
}
669-
}
670-
}
671-
672-
func TestLongGreaterThanOrEqualNode(t *testing.T) {
673-
t.Parallel()
674-
{
675-
tests := []struct {
676-
lhs, rhs int64
677-
result bool
678-
}{
679-
{-1, -1, true},
680-
{-1, 0, false},
681-
{-1, 1, false},
682-
{0, -1, true},
683-
{0, 0, true},
684-
{0, 1, false},
685-
{1, -1, true},
686-
{1, 0, true},
687-
{1, 1, true},
688-
}
689-
for _, tt := range tests {
690-
tt := tt
691-
t.Run(fmt.Sprintf("%v>=%v", tt.lhs, tt.rhs), func(t *testing.T) {
692-
t.Parallel()
693-
n := newLongGreaterThanOrEqualEval(
694-
newLiteralEval(types.Long(tt.lhs)), newLiteralEval(types.Long(tt.rhs)))
695-
v, err := n.Eval(Env{})
696-
testutil.OK(t, err)
697-
AssertBoolValue(t, v, tt.result)
698-
})
699-
}
700-
}
701-
{
702-
tests := []struct {
703-
name string
704-
lhs, rhs Evaler
705-
err error
706-
}{
707-
{"LhsError", newErrorEval(errTest), newLiteralEval(types.Long(0)), errTest},
708-
{"LhsTypeError", newLiteralEval(types.True), newLiteralEval(types.Long(0)), ErrType},
709-
{"RhsError", newLiteralEval(types.Long(0)), newErrorEval(errTest), errTest},
710-
{"RhsTypeError", newLiteralEval(types.Long(0)), newLiteralEval(types.True), ErrType},
711-
}
712-
for _, tt := range tests {
713-
tt := tt
714-
t.Run(tt.name, func(t *testing.T) {
715-
t.Parallel()
716-
n := newLongGreaterThanOrEqualEval(tt.lhs, tt.rhs)
717-
_, err := n.Eval(Env{})
718-
testutil.ErrorIs(t, err, tt.err)
719-
})
720-
}
721-
}
722-
}
723-
724516
func TestDecimalLessThanNode(t *testing.T) {
725517
t.Parallel()
726518
{

internal/eval/fold.go

+4-4
Original file line numberDiff line numberDiff line change
@@ -228,13 +228,13 @@ func fold(n ast.IsNode) ast.IsNode {
228228
case ast.NodeTypeNotEquals:
229229
return tryFoldBinary(v.BinaryNode, newNotEqualEval, func(b ast.BinaryNode) ast.IsNode { return ast.NodeTypeNotEquals{BinaryNode: b} })
230230
case ast.NodeTypeGreaterThan:
231-
return tryFoldBinary(v.BinaryNode, newLongGreaterThanEval, func(b ast.BinaryNode) ast.IsNode { return ast.NodeTypeGreaterThan{BinaryNode: b} })
231+
return tryFoldBinary(v.BinaryNode, newComparableValueGreaterThanEval, func(b ast.BinaryNode) ast.IsNode { return ast.NodeTypeGreaterThan{BinaryNode: b} })
232232
case ast.NodeTypeGreaterThanOrEqual:
233-
return tryFoldBinary(v.BinaryNode, newLongGreaterThanOrEqualEval, func(b ast.BinaryNode) ast.IsNode { return ast.NodeTypeGreaterThanOrEqual{BinaryNode: b} })
233+
return tryFoldBinary(v.BinaryNode, newComparableValueGreaterThanOrEqualEval, func(b ast.BinaryNode) ast.IsNode { return ast.NodeTypeGreaterThanOrEqual{BinaryNode: b} })
234234
case ast.NodeTypeLessThan:
235-
return tryFoldBinary(v.BinaryNode, newLongLessThanEval, func(b ast.BinaryNode) ast.IsNode { return ast.NodeTypeLessThan{BinaryNode: b} })
235+
return tryFoldBinary(v.BinaryNode, newComparableValueLessThanEval, func(b ast.BinaryNode) ast.IsNode { return ast.NodeTypeLessThan{BinaryNode: b} })
236236
case ast.NodeTypeLessThanOrEqual:
237-
return tryFoldBinary(v.BinaryNode, newLongLessThanOrEqualEval, func(b ast.BinaryNode) ast.IsNode { return ast.NodeTypeLessThanOrEqual{BinaryNode: b} })
237+
return tryFoldBinary(v.BinaryNode, newComparableValueLessThanOrEqualEval, func(b ast.BinaryNode) ast.IsNode { return ast.NodeTypeLessThanOrEqual{BinaryNode: b} })
238238
case ast.NodeTypeSub:
239239
return tryFoldBinary(v.BinaryNode, newSubtractEval, func(b ast.BinaryNode) ast.IsNode { return ast.NodeTypeSub{BinaryNode: b} })
240240
case ast.NodeTypeAdd:

0 commit comments

Comments
 (0)