Skip to content

Commit

Permalink
Support for Parameterized type
Browse files Browse the repository at this point in the history
* Separate AnyType. This will be helpful in match method
* Added support for ParameterizedFixedChar/VarChar/FixedBinary/Decimal
* Added parser support for Parameterized/PrecisionTimestamp/PrecisionTimestampTz
  • Loading branch information
anshuldata committed Aug 30, 2024
1 parent 58e4ba0 commit ff46f15
Show file tree
Hide file tree
Showing 10 changed files with 517 additions and 46 deletions.
15 changes: 15 additions & 0 deletions extensions/simple_extension.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@
package extensions

import (
"errors"
"fmt"
"reflect"
"strings"

substraitgo "github.com/substrait-io/substrait-go"
"github.com/substrait-io/substrait-go/types"
"github.com/substrait-io/substrait-go/types/parser"
)

Expand Down Expand Up @@ -57,6 +59,7 @@ type TypeVariation struct {

type Argument interface {
toTypeString() string
ArgType() (types.Type, error)
}

type EnumArg struct {
Expand All @@ -69,6 +72,10 @@ func (EnumArg) toTypeString() string {
return "req"
}

func (EnumArg) ArgType() (types.Type, error) {
return nil, errors.New("unimplemented")
}

type ValueArg struct {
Name string `yaml:",omitempty"`
Description string `yaml:",omitempty"`
Expand All @@ -80,6 +87,10 @@ func (v ValueArg) toTypeString() string {
return v.Value.Expr.(*parser.Type).ShortType()
}

func (v ValueArg) ArgType() (types.Type, error) {
return v.Value.Expr.(*parser.Type).Type()
}

type TypeArg struct {
Name string `yaml:",omitempty"`
Description string `yaml:",omitempty"`
Expand All @@ -88,6 +99,10 @@ type TypeArg struct {

func (TypeArg) toTypeString() string { return "type" }

func (TypeArg) ArgType() (types.Type, error) {
return nil, errors.New("unimplemented")
}

type ArgumentList []Argument

func (a *ArgumentList) UnmarshalYAML(fn func(interface{}) error) error {
Expand Down
58 changes: 58 additions & 0 deletions types/any_type.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
package types

import (
"fmt"

"github.com/substrait-io/substrait-go/proto"
)

// AnyType to represent AnyType, this type is to indicate "any" type of argument
// This type is not used in function invocation. It is only used in function definition
type AnyType struct {
Name string
Nullability Nullability
}

func (*AnyType) isRootRef() {}
func (m *AnyType) WithNullability(nullability Nullability) Type {
m.Nullability = nullability
return m
}
func (m *AnyType) GetType() Type { return m }
func (m *AnyType) GetNullability() Nullability {
return m.Nullability
}
func (*AnyType) GetTypeVariationReference() uint32 {
panic("not allowed")
}
func (*AnyType) Equals(rhs Type) bool {
// equal to every other type
return true
}

func (*AnyType) ToProtoFuncArg() *proto.FunctionArgument {
panic("not allowed")
}

func (*AnyType) ToProto() *proto.Type {
panic("not allowed")
}

func (t *AnyType) ShortString() string { return t.Name }
func (t *AnyType) String() string {
return fmt.Sprintf("%s%s", t.Name, strNullable(t))
}

// Below methods are for parser Def interface

func (*AnyType) Optional() bool {
panic("not allowed")
}

func (m *AnyType) ShortType() string {
return "any"
}

func (m *AnyType) Type() (Type, error) {
return m, nil
}
33 changes: 33 additions & 0 deletions types/any_type_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
package types_test

import (
"testing"

"github.com/stretchr/testify/require"
"github.com/substrait-io/substrait-go/types"
)

func TestAnyType(t *testing.T) {
for _, td := range []struct {
testName string
argName string
nullability types.Nullability
expectedString string
}{
{"any", "any", types.NullabilityNullable, "any?"},
{"anyrequired", "any", types.NullabilityRequired, "any"},
{"anyOtherName", "any1", types.NullabilityNullable, "any1?"},
{"T name", "T", types.NullabilityNullable, "T?"},
} {
t.Run(td.testName, func(t *testing.T) {
arg := &types.AnyType{
Name: td.argName,
Nullability: td.nullability,
}
require.Equal(t, td.expectedString, arg.String())
require.Equal(t, td.nullability, arg.GetNullability())
require.Equal(t, td.argName, arg.ShortString())
require.Equal(t, "any", arg.ShortType())
})
}
}
55 changes: 55 additions & 0 deletions types/parameterized_decimal_type.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
package types

import (
"fmt"

"github.com/substrait-io/substrait-go/proto"
)

type ParameterizedDecimal struct {
Nullability Nullability
TypeVariationRef uint32
Precision IntegerParam
Scale IntegerParam
}

func (*ParameterizedDecimal) isRootRef() {}
func (m *ParameterizedDecimal) WithNullability(n Nullability) Type {
m.Nullability = n
return m
}

func (m *ParameterizedDecimal) GetType() Type { return m }
func (m *ParameterizedDecimal) GetNullability() Nullability { return m.Nullability }
func (m *ParameterizedDecimal) GetTypeVariationReference() uint32 {
return m.TypeVariationRef
}
func (m *ParameterizedDecimal) Equals(rhs Type) bool {
if o, ok := rhs.(*ParameterizedDecimal); ok {
return *o == *m
}
return false
}

func (*ParameterizedDecimal) ToProtoFuncArg() *proto.FunctionArgument {
// parameterized type are never on wire so to proto is not supported
panic("not supported")
}

func (m *ParameterizedDecimal) ShortString() string {
t := &DecimalType{}
return t.ShortString()
}

func (m *ParameterizedDecimal) String() string {
return fmt.Sprintf("%s%s%s", m.BaseString(), strNullable(m), m.ParameterString())
}

func (m *ParameterizedDecimal) ParameterString() string {
return fmt.Sprintf("<%s,%s>", m.Precision.String(), m.Scale.String())
}

func (m *ParameterizedDecimal) BaseString() string {
t := &DecimalType{}
return t.BaseString()
}
114 changes: 114 additions & 0 deletions types/parameterized_types.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
package types

import (
"fmt"

"github.com/substrait-io/substrait-go/proto"
)

type IntegerParam struct {
Name string
}

func (m IntegerParam) Equals(o IntegerParam) bool {
return m == o
}

func (p IntegerParam) ToProto() *proto.ParameterizedType_IntegerParameter {
panic("not implemented")
}

func (m *IntegerParam) String() string {
return m.Name
}

type ParameterizedSingleIntegerType interface {
Type
WithIntegerOption(param IntegerParam) ParameterizedSingleIntegerType
}

type ParameterizedTypeSingleIntegerParam[T VarCharType | FixedCharType | FixedBinaryType | PrecisionTimestampType | PrecisionTimestampTzType] struct {
Nullability Nullability
TypeVariationRef uint32
IntegerOption IntegerParam
}

func (m *ParameterizedTypeSingleIntegerParam[T]) WithIntegerOption(integerOption IntegerParam) ParameterizedSingleIntegerType {
m.IntegerOption = integerOption
return m
}

func (*ParameterizedTypeSingleIntegerParam[T]) isRootRef() {}
func (m *ParameterizedTypeSingleIntegerParam[T]) WithNullability(n Nullability) Type {
m.Nullability = n
return m
}

func (m *ParameterizedTypeSingleIntegerParam[T]) GetType() Type { return m }
func (m *ParameterizedTypeSingleIntegerParam[T]) GetNullability() Nullability { return m.Nullability }
func (m *ParameterizedTypeSingleIntegerParam[T]) GetTypeVariationReference() uint32 {
return m.TypeVariationRef
}
func (m *ParameterizedTypeSingleIntegerParam[T]) Equals(rhs Type) bool {
if o, ok := rhs.(*ParameterizedTypeSingleIntegerParam[T]); ok {
return *o == *m
}
return false
}

func (*ParameterizedTypeSingleIntegerParam[T]) ToProtoFuncArg() *proto.FunctionArgument {
// parameterized type are never on wire so to proto is not supported
panic("not supported")
}

func (m *ParameterizedTypeSingleIntegerParam[T]) ShortString() string {
switch any(m).(type) {
case *ParameterizedVarCharType:
t := &VarCharType{}
return t.ShortString()
case *ParameterizedFixedCharType:
t := &FixedCharType{}
return t.ShortString()
case *ParameterizedFixedBinaryType:
t := &FixedBinaryType{}
return t.ShortString()
case *ParameterizedPrecisionTimestampType:
t := &PrecisionTimestampType{}
return t.ShortString()
case *ParameterizedPrecisionTimestampTzType:
t := &PrecisionTimestampTzType{}
return t.ShortString()
default:
panic("unknown type")
}
}

func (m *ParameterizedTypeSingleIntegerParam[T]) String() string {
return fmt.Sprintf("%s%s%s", m.BaseString(), strNullable(m), m.ParameterString())
}

func (m *ParameterizedTypeSingleIntegerParam[T]) ParameterString() string {
return fmt.Sprintf("<%s>", m.IntegerOption.String())
}

func (m *ParameterizedTypeSingleIntegerParam[T]) BaseString() string {
switch any(m).(type) {
case *ParameterizedVarCharType:
t := &VarCharType{}
return t.BaseString()
case *ParameterizedFixedCharType:
t := &FixedCharType{}
return t.BaseString()
case *ParameterizedFixedBinaryType:
t := &FixedBinaryType{}
return t.BaseString()
case *ParameterizedPrecisionTimestampType:
t := &PrecisionTimestampType{}
return t.BaseString()
case *ParameterizedPrecisionTimestampTzType:
t := &PrecisionTimestampTzType{}
return t.BaseString()
default:
panic("unknown type")
}
}
66 changes: 66 additions & 0 deletions types/parameterized_types_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
package types_test

import (
"testing"

"github.com/stretchr/testify/require"
"github.com/substrait-io/substrait-go/types"
)

func TestParameterizedVarCharType(t *testing.T) {
for _, td := range []struct {
name string
typ types.ParameterizedSingleIntegerType
nullability types.Nullability
integerOption types.IntegerParam
expectedString string
expectedBaseString string
expectedShortString string
}{
{"nullable varchar", &types.ParameterizedVarCharType{}, types.NullabilityNullable, types.IntegerParam{Name: "L1"}, "varchar?<L1>", "varchar", "vchar"},
{"non nullable varchar", &types.ParameterizedVarCharType{}, types.NullabilityRequired, types.IntegerParam{Name: "L1"}, "varchar<L1>", "varchar", "vchar"},
{"nullable fixChar", &types.ParameterizedFixedCharType{}, types.NullabilityNullable, types.IntegerParam{Name: "L1"}, "char?<L1>", "char", "fchar"},
{"non nullable fixChar", &types.ParameterizedFixedCharType{}, types.NullabilityRequired, types.IntegerParam{Name: "L1"}, "char<L1>", "char", "fchar"},
{"nullable fixBinary", &types.ParameterizedFixedBinaryType{}, types.NullabilityNullable, types.IntegerParam{Name: "L1"}, "fixedbinary?<L1>", "fixedbinary", "fbin"},
{"non nullable fixBinary", &types.ParameterizedFixedBinaryType{}, types.NullabilityRequired, types.IntegerParam{Name: "L1"}, "fixedbinary<L1>", "fixedbinary", "fbin"},
{"nullable precisionTimeStamp", &types.ParameterizedPrecisionTimestampType{}, types.NullabilityNullable, types.IntegerParam{Name: "L1"}, "precision_timestamp?<L1>", "precision_timestamp", "prets"},
{"non nullable precisionTimeStamp", &types.ParameterizedPrecisionTimestampType{}, types.NullabilityRequired, types.IntegerParam{Name: "L1"}, "precision_timestamp<L1>", "precision_timestamp", "prets"},
{"nullable precisionTimeStampTz", &types.ParameterizedPrecisionTimestampTzType{}, types.NullabilityNullable, types.IntegerParam{Name: "L1"}, "precision_timestamp_tz?<L1>", "precision_timestamp_tz", "pretstz"},
{"non nullable precisionTimeStampTz", &types.ParameterizedPrecisionTimestampTzType{}, types.NullabilityRequired, types.IntegerParam{Name: "L1"}, "precision_timestamp_tz<L1>", "precision_timestamp_tz", "pretstz"},
} {
t.Run(td.name, func(t *testing.T) {
pt := td.typ.WithIntegerOption(td.integerOption).WithNullability(td.nullability)
require.Equal(t, td.expectedString, pt.String())
parameterizeType, ok := pt.(types.ParameterizedType)
require.True(t, ok)
require.Equal(t, td.expectedBaseString, parameterizeType.BaseString())
require.Equal(t, td.expectedShortString, pt.ShortString())
require.True(t, pt.Equals(pt))
})
}
}

func TestParameterizedDecimalType(t *testing.T) {
for _, td := range []struct {
name string
precision string
scale string
nullability types.Nullability
expectedString string
expectedBaseString string
expectedShortString string
}{
{"nullable decimal", "P", "S", types.NullabilityNullable, "decimal?<P,S>", "decimal", "dec"},
{"non nullable decimal", "P", "S", types.NullabilityRequired, "decimal<P,S>", "decimal", "dec"},
} {
t.Run(td.name, func(t *testing.T) {
precision := types.IntegerParam{Name: td.precision}
scale := types.IntegerParam{Name: td.scale}
pt := &types.ParameterizedDecimalType{Precision: precision, Scale: scale, Nullability: td.nullability}
require.Equal(t, td.expectedString, pt.String())
require.Equal(t, td.expectedBaseString, pt.BaseString())
require.Equal(t, td.expectedShortString, pt.ShortString())
require.True(t, pt.Equals(pt))
})
}
}
Loading

0 comments on commit ff46f15

Please sign in to comment.