diff --git a/parser/ast.go b/parser/ast.go index 6474406..aff1ff1 100644 --- a/parser/ast.go +++ b/parser/ast.go @@ -2971,7 +2971,7 @@ type QueryParam struct { LBracePos Pos RBracePos Pos Name *Ident - Type Expr + Type ColumnType } func (q *QueryParam) Pos() Pos { @@ -3182,7 +3182,7 @@ type ColumnDef struct { NamePos Pos ColumnEnd Pos Name *NestedIdentifier - Type Expr + Type ColumnType NotNull *NotNullLiteral Nullable *NullLiteral @@ -3298,72 +3298,85 @@ func (c *ColumnDef) Accept(visitor ASTVisitor) error { return visitor.VisitColumnDef(c) } -type ScalarTypeExpr struct { +type ColumnType interface { + Expr + Type() string +} + +type ScalarType struct { Name *Ident } -func (s *ScalarTypeExpr) Pos() Pos { +func (s *ScalarType) Pos() Pos { return s.Name.NamePos } -func (s *ScalarTypeExpr) End() Pos { +func (s *ScalarType) End() Pos { return s.Name.NameEnd } -func (s *ScalarTypeExpr) String() string { +func (s *ScalarType) String() string { return s.Name.String() } -func (s *ScalarTypeExpr) Accept(visitor ASTVisitor) error { +func (s *ScalarType) Accept(visitor ASTVisitor) error { visitor.enter(s) defer visitor.leave(s) if err := s.Name.Accept(visitor); err != nil { return err } - return visitor.VisitScalarTypeExpr(s) + return visitor.VisitScalarType(s) } -type PropertyTypeExpr struct { +func (s *ScalarType) Type() string { + return s.Name.Name +} + +type PropertyType struct { Name *Ident } -func (c *PropertyTypeExpr) Pos() Pos { +func (c *PropertyType) Pos() Pos { return c.Name.NamePos } -func (c *PropertyTypeExpr) End() Pos { +func (c *PropertyType) End() Pos { return c.Name.NameEnd } -func (c *PropertyTypeExpr) String() string { +func (c *PropertyType) String() string { return c.Name.String() } -func (c *PropertyTypeExpr) Accept(visitor ASTVisitor) error { +func (c *PropertyType) Accept(visitor ASTVisitor) error { visitor.enter(c) defer visitor.leave(c) if err := c.Name.Accept(visitor); err != nil { return err } - return visitor.VisitPropertyTypeExpr(c) + return visitor.VisitPropertyType(c) +} + +func (c *PropertyType) Type() string { + return c.Name.Name } -type TypeWithParamsExpr struct { +type TypeWithParams struct { LeftParenPos Pos RightParenPos Pos Name *Ident Params []Literal } -func (s *TypeWithParamsExpr) Pos() Pos { +func (s *TypeWithParams) Pos() Pos { return s.Name.NamePos } -func (s *TypeWithParamsExpr) End() Pos { +func (s *TypeWithParams) End() Pos { return s.RightParenPos } -func (s *TypeWithParamsExpr) String() string { +func (s *TypeWithParams) String() string { var builder strings.Builder builder.WriteString(s.Name.String()) builder.WriteByte('(') @@ -3377,7 +3390,7 @@ func (s *TypeWithParamsExpr) String() string { return builder.String() } -func (s *TypeWithParamsExpr) Accept(visitor ASTVisitor) error { +func (s *TypeWithParams) Accept(visitor ASTVisitor) error { visitor.enter(s) defer visitor.leave(s) if err := s.Name.Accept(visitor); err != nil { @@ -3388,25 +3401,29 @@ func (s *TypeWithParamsExpr) Accept(visitor ASTVisitor) error { return err } } - return visitor.VisitTypeWithParamsExpr(s) + return visitor.VisitTypeWithParams(s) } -type ComplexTypeExpr struct { +func (s *TypeWithParams) Type() string { + return s.Name.Name +} + +type ComplexType struct { LeftParenPos Pos RightParenPos Pos Name *Ident - Params []Expr + Params []ColumnType } -func (c *ComplexTypeExpr) Pos() Pos { +func (c *ComplexType) Pos() Pos { return c.Name.NamePos } -func (c *ComplexTypeExpr) End() Pos { +func (c *ComplexType) End() Pos { return c.RightParenPos } -func (c *ComplexTypeExpr) String() string { +func (c *ComplexType) String() string { var builder strings.Builder builder.WriteString(c.Name.String()) builder.WriteByte('(') @@ -3420,7 +3437,7 @@ func (c *ComplexTypeExpr) String() string { return builder.String() } -func (c *ComplexTypeExpr) Accept(visitor ASTVisitor) error { +func (c *ComplexType) Accept(visitor ASTVisitor) error { visitor.enter(c) defer visitor.leave(c) if err := c.Name.Accept(visitor); err != nil { @@ -3431,25 +3448,29 @@ func (c *ComplexTypeExpr) Accept(visitor ASTVisitor) error { return err } } - return visitor.VisitComplexTypeExpr(c) + return visitor.VisitComplexType(c) +} + +func (c *ComplexType) Type() string { + return c.Name.Name } -type NestedTypeExpr struct { +type NestedType struct { LeftParenPos Pos RightParenPos Pos Name *Ident Columns []Expr } -func (n *NestedTypeExpr) Pos() Pos { +func (n *NestedType) Pos() Pos { return n.Name.NamePos } -func (n *NestedTypeExpr) End() Pos { +func (n *NestedType) End() Pos { return n.RightParenPos } -func (n *NestedTypeExpr) String() string { +func (n *NestedType) String() string { var builder strings.Builder // on the same level as the column type builder.WriteString(n.Name.String()) @@ -3465,7 +3486,7 @@ func (n *NestedTypeExpr) String() string { return builder.String() } -func (n *NestedTypeExpr) Accept(visitor ASTVisitor) error { +func (n *NestedType) Accept(visitor ASTVisitor) error { visitor.enter(n) defer visitor.leave(n) if err := n.Name.Accept(visitor); err != nil { @@ -3476,7 +3497,11 @@ func (n *NestedTypeExpr) Accept(visitor ASTVisitor) error { return err } } - return visitor.VisitNestedTypeExpr(n) + return visitor.VisitNestedType(n) +} + +func (n *NestedType) Type() string { + return n.Name.Name } type CompressionCodec struct { @@ -3689,29 +3714,29 @@ func (e *EnumValue) Accept(visitor ASTVisitor) error { if err := e.Value.Accept(visitor); err != nil { return err } - return visitor.VisitEnumValueExpr(e) + return visitor.VisitEnumValue(e) } -type EnumValueList struct { +type EnumType struct { Name *Ident ListPos Pos ListEnd Pos - Enums []EnumValue + Values []EnumValue } -func (e *EnumValueList) Pos() Pos { +func (e *EnumType) Pos() Pos { return e.ListPos } -func (e *EnumValueList) End() Pos { +func (e *EnumType) End() Pos { return e.ListEnd } -func (e *EnumValueList) String() string { +func (e *EnumType) String() string { var builder strings.Builder builder.WriteString(e.Name.String()) builder.WriteByte('(') - for i, enum := range e.Enums { + for i, enum := range e.Values { if i > 0 { builder.WriteString(", ") } @@ -3721,18 +3746,22 @@ func (e *EnumValueList) String() string { return builder.String() } -func (e *EnumValueList) Accept(visitor ASTVisitor) error { +func (e *EnumType) Accept(visitor ASTVisitor) error { visitor.enter(e) defer visitor.leave(e) if err := e.Name.Accept(visitor); err != nil { return err } - for i := range e.Enums { - if err := e.Enums[i].Accept(visitor); err != nil { + for i := range e.Values { + if err := e.Values[i].Accept(visitor); err != nil { return err } } - return visitor.VisitEnumValueExprList(e) + return visitor.VisitEnumType(e) +} + +func (e *EnumType) Type() string { + return e.Name.Name } type IntervalExpr struct { diff --git a/parser/ast_visitor.go b/parser/ast_visitor.go index 7cd3957..a06bf3c 100644 --- a/parser/ast_visitor.go +++ b/parser/ast_visitor.go @@ -73,17 +73,17 @@ type ASTVisitor interface { VisitWindowFunctionExpr(expr *WindowFunctionExpr) error VisitColumnDef(expr *ColumnDef) error VisitColumnExpr(expr *ColumnExpr) error - VisitScalarTypeExpr(expr *ScalarTypeExpr) error - VisitPropertyTypeExpr(expr *PropertyTypeExpr) error - VisitTypeWithParamsExpr(expr *TypeWithParamsExpr) error - VisitComplexTypeExpr(expr *ComplexTypeExpr) error - VisitNestedTypeExpr(expr *NestedTypeExpr) error + VisitScalarType(expr *ScalarType) error + VisitPropertyType(expr *PropertyType) error + VisitTypeWithParams(expr *TypeWithParams) error + VisitComplexType(expr *ComplexType) error + VisitNestedType(expr *NestedType) error VisitCompressionCodec(expr *CompressionCodec) error VisitNumberLiteral(expr *NumberLiteral) error VisitStringLiteral(expr *StringLiteral) error VisitRatioExpr(expr *RatioExpr) error - VisitEnumValueExpr(expr *EnumValue) error - VisitEnumValueExprList(expr *EnumValueList) error + VisitEnumValue(expr *EnumValue) error + VisitEnumType(expr *EnumType) error VisitIntervalExpr(expr *IntervalExpr) error VisitEngineExpr(expr *EngineExpr) error VisitColumnTypeExpr(expr *ColumnTypeExpr) error @@ -678,35 +678,35 @@ func (v *DefaultASTVisitor) VisitColumnExpr(expr *ColumnExpr) error { return nil } -func (v *DefaultASTVisitor) VisitScalarTypeExpr(expr *ScalarTypeExpr) error { +func (v *DefaultASTVisitor) VisitScalarType(expr *ScalarType) error { if v.Visit != nil { return v.Visit(expr) } return nil } -func (v *DefaultASTVisitor) VisitPropertyTypeExpr(expr *PropertyTypeExpr) error { +func (v *DefaultASTVisitor) VisitPropertyType(expr *PropertyType) error { if v.Visit != nil { return v.Visit(expr) } return nil } -func (v *DefaultASTVisitor) VisitTypeWithParamsExpr(expr *TypeWithParamsExpr) error { +func (v *DefaultASTVisitor) VisitTypeWithParams(expr *TypeWithParams) error { if v.Visit != nil { return v.Visit(expr) } return nil } -func (v *DefaultASTVisitor) VisitComplexTypeExpr(expr *ComplexTypeExpr) error { +func (v *DefaultASTVisitor) VisitComplexType(expr *ComplexType) error { if v.Visit != nil { return v.Visit(expr) } return nil } -func (v *DefaultASTVisitor) VisitNestedTypeExpr(expr *NestedTypeExpr) error { +func (v *DefaultASTVisitor) VisitNestedType(expr *NestedType) error { if v.Visit != nil { return v.Visit(expr) } @@ -741,14 +741,14 @@ func (v *DefaultASTVisitor) VisitRatioExpr(expr *RatioExpr) error { return nil } -func (v *DefaultASTVisitor) VisitEnumValueExpr(expr *EnumValue) error { +func (v *DefaultASTVisitor) VisitEnumValue(expr *EnumValue) error { if v.Visit != nil { return v.Visit(expr) } return nil } -func (v *DefaultASTVisitor) VisitEnumValueExprList(expr *EnumValueList) error { +func (v *DefaultASTVisitor) VisitEnumType(expr *EnumType) error { if v.Visit != nil { return v.Visit(expr) } diff --git a/parser/parser_column.go b/parser/parser_column.go index 97e85b4..be18de0 100644 --- a/parser/parser_column.go +++ b/parser/parser_column.go @@ -368,7 +368,14 @@ func (p *Parser) parseColumnCastExpr(pos Pos) (Expr, error) { default: return nil, fmt.Errorf("expected AS or , but got %s", p.lastTokenKind()) } - asColumnType, err := p.parseColumnType(p.Pos()) + + var asColumnType Expr + // CAST(1 AS 'Float') or CAST(1 AS Float) are equivalent + if p.matchTokenKind(TokenString) { + asColumnType, err = p.parseString(p.Pos()) + } else { + asColumnType, err = p.parseColumnType(p.Pos()) + } if err != nil { return nil, err } @@ -749,10 +756,7 @@ func (p *Parser) parseColumnCaseExpr(pos Pos) (*CaseExpr, error) { return caseExpr, nil } -func (p *Parser) parseColumnType(_ Pos) (Expr, error) { // nolint:funlen - if p.matchTokenKind(TokenString) { - return p.parseString(p.Pos()) - } +func (p *Parser) parseColumnType(_ Pos) (ColumnType, error) { // nolint:funlen ident, err := p.parseIdent() if err != nil { return nil, err @@ -767,7 +771,7 @@ func (p *Parser) parseColumnType(_ Pos) (Expr, error) { // nolint:funlen case p.matchTokenKind(TokenString): if peekToken, err := p.lexer.peekToken(); err == nil && peekToken.Kind == opTypeEQ { // enum values - return p.parseEnumExpr(ident,p.Pos()) + return p.parseEnumType(ident, p.Pos()) } // like Datetime('Asia/Dubai') return p.parseColumnTypeWithParams(ident, p.Pos()) @@ -778,7 +782,7 @@ func (p *Parser) parseColumnType(_ Pos) (Expr, error) { // nolint:funlen return nil, fmt.Errorf("unexpected token kind: %v", p.lastTokenKind()) } } - return &ScalarTypeExpr{Name: ident}, nil + return &ScalarType{Name: ident}, nil } func (p *Parser) parseColumnPropertyType(_ Pos) (Expr, error) { @@ -786,13 +790,13 @@ func (p *Parser) parseColumnPropertyType(_ Pos) (Expr, error) { if err != nil { return nil, err } - return &PropertyTypeExpr{ + return &PropertyType{ Name: ident, }, nil } -func (p *Parser) parseComplexType(name *Ident, pos Pos) (Expr, error) { - subTypes := make([]Expr, 0) +func (p *Parser) parseComplexType(name *Ident, pos Pos) (*ComplexType, error) { + subTypes := make([]ColumnType, 0) for !p.lexer.isEOF() && !p.matchTokenKind(")") { subExpr, err := p.parseColumnType(p.Pos()) if err != nil { @@ -807,7 +811,7 @@ func (p *Parser) parseComplexType(name *Ident, pos Pos) (Expr, error) { if _, err := p.consumeTokenKind(")"); err != nil { return nil, err } - return &ComplexTypeExpr{ + return &ComplexType{ LeftParenPos: pos, RightParenPos: rightParenPos, Name: name, @@ -815,35 +819,35 @@ func (p *Parser) parseComplexType(name *Ident, pos Pos) (Expr, error) { }, nil } -func (p *Parser) parseEnumExpr(name *Ident, pos Pos) (*EnumValueList, error) { - enumValueList := &EnumValueList{ - Name: name, +func (p *Parser) parseEnumType(name *Ident, pos Pos) (*EnumType, error) { + enumType := &EnumType{ + Name: name, ListPos: pos, - Enums: make([]EnumValue, 0), + Values: make([]EnumValue, 0), } for !p.lexer.isEOF() && !p.matchTokenKind(")") { - enumValueExpr, err := p.parseEnumValueExpr(p.Pos()) + enumValue, err := p.parseEnumValueExpr(p.Pos()) if err != nil { return nil, err } - if enumValueExpr == nil { + if enumValue == nil { break } - enumValueList.Enums = append(enumValueList.Enums, *enumValueExpr) + enumType.Values = append(enumType.Values, *enumValue) if p.tryConsumeTokenKind(",") == nil { break } } - if len(enumValueList.Enums) > 0 { - enumValueList.ListEnd = enumValueList.Enums[len(enumValueList.Enums)-1].Value.NumEnd + if len(enumType.Values) > 0 { + enumType.ListEnd = enumType.Values[len(enumType.Values)-1].Value.NumEnd } if _, err := p.consumeTokenKind(")"); err != nil { return nil, err } - return enumValueList, nil + return enumType, nil } -func (p *Parser) parseColumnTypeWithParams(name *Ident, pos Pos) (*TypeWithParamsExpr, error) { +func (p *Parser) parseColumnTypeWithParams(name *Ident, pos Pos) (*TypeWithParams, error) { params := make([]Literal, 0) param, err := p.parseLiteral(p.Pos()) if err != nil { @@ -862,7 +866,7 @@ func (p *Parser) parseColumnTypeWithParams(name *Ident, pos Pos) (*TypeWithParam if _, err := p.consumeTokenKind(")"); err != nil { return nil, err } - return &TypeWithParamsExpr{ + return &TypeWithParams{ Name: name, LeftParenPos: pos, RightParenPos: rightParenPos, @@ -870,7 +874,7 @@ func (p *Parser) parseColumnTypeWithParams(name *Ident, pos Pos) (*TypeWithParam }, nil } -func (p *Parser) parseNestedType(name *Ident, pos Pos) (*NestedTypeExpr, error) { +func (p *Parser) parseNestedType(name *Ident, pos Pos) (*NestedType, error) { columns, err := p.parseTableColumns() if err != nil { return nil, err @@ -879,7 +883,7 @@ func (p *Parser) parseNestedType(name *Ident, pos Pos) (*NestedTypeExpr, error) if _, err := p.consumeTokenKind(")"); err != nil { return nil, err } - return &NestedTypeExpr{ + return &NestedType{ LeftParenPos: pos, RightParenPos: rightParenPos, Name: name, diff --git a/parser/testdata/ddl/output/create_table_with_enum_fields.sql.golden.json b/parser/testdata/ddl/output/create_table_with_enum_fields.sql.golden.json index 092c542..c1cf590 100644 --- a/parser/testdata/ddl/output/create_table_with_enum_fields.sql.golden.json +++ b/parser/testdata/ddl/output/create_table_with_enum_fields.sql.golden.json @@ -47,7 +47,7 @@ }, "ListPos": 65, "ListEnd": 160, - "Enums": [ + "Values": [ { "Name": { "LiteralPos": 65,