Skip to content

Commit

Permalink
Fix recently discovered discrepancies with protoc (bufbuild#229)
Browse files Browse the repository at this point in the history
Removes support for explicit positive sign `+` before numeric literals
since that is not actually supported by `protoc`. Also adds support for
`-nan` (with negative sign) in option values as well as case-insensitive
`inf`, and `infinity`, and `nan` in message literals, all of which are
allowed by `protoc`.
  • Loading branch information
kralicky committed Feb 7, 2024
1 parent bdc2198 commit d37b7f5
Show file tree
Hide file tree
Showing 6 changed files with 780 additions and 697 deletions.
53 changes: 5 additions & 48 deletions ast/values.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@ var _ ValueNode = (*CompoundIdentNode)(nil)
var _ ValueNode = (*StringLiteralNode)(nil)
var _ ValueNode = (*CompoundStringLiteralNode)(nil)
var _ ValueNode = (*UintLiteralNode)(nil)
var _ ValueNode = (*PositiveUintLiteralNode)(nil)
var _ ValueNode = (*NegativeIntLiteralNode)(nil)
var _ ValueNode = (*FloatLiteralNode)(nil)
var _ ValueNode = (*SpecialFloatLiteralNode)(nil)
Expand Down Expand Up @@ -154,7 +153,6 @@ func AsInt32(n IntValueNode, min, max int32) (int32, bool) {
}

var _ IntValueNode = (*UintLiteralNode)(nil)
var _ IntValueNode = (*PositiveUintLiteralNode)(nil)
var _ IntValueNode = (*NegativeIntLiteralNode)(nil)

// UintLiteralNode represents a simple integer literal with no sign character.
Expand Down Expand Up @@ -191,49 +189,6 @@ func (n *UintLiteralNode) AsFloat() float64 {
return float64(n.Val)
}

// PositiveUintLiteralNode represents an integer literal with a positive (+) sign.
type PositiveUintLiteralNode struct {
compositeNode
Plus *RuneNode
Uint *UintLiteralNode
Val uint64
}

// NewPositiveUintLiteralNode creates a new *PositiveUintLiteralNode. Both
// arguments must be non-nil.
func NewPositiveUintLiteralNode(sign *RuneNode, i *UintLiteralNode) *PositiveUintLiteralNode {
if sign == nil {
panic("sign is nil")
}
if i == nil {
panic("i is nil")
}
children := []Node{sign, i}
return &PositiveUintLiteralNode{
compositeNode: compositeNode{
children: children,
},
Plus: sign,
Uint: i,
Val: i.Val,
}
}

func (n *PositiveUintLiteralNode) Value() interface{} {
return n.Val
}

func (n *PositiveUintLiteralNode) AsInt64() (int64, bool) {
if n.Val > math.MaxInt64 {
return 0, false
}
return int64(n.Val), true
}

func (n *PositiveUintLiteralNode) AsUint64() (uint64, bool) {
return n.Val, true
}

// NegativeIntLiteralNode represents an integer literal with a negative (-) sign.
type NegativeIntLiteralNode struct {
compositeNode
Expand Down Expand Up @@ -320,12 +275,14 @@ type SpecialFloatLiteralNode struct {
}

// NewSpecialFloatLiteralNode returns a new *SpecialFloatLiteralNode for the
// given keyword, which must be "inf" or "nan".
// given keyword. The given keyword should be "inf", "infinity", or "nan"
// in any case.
func NewSpecialFloatLiteralNode(name *KeywordNode) *SpecialFloatLiteralNode {
var f float64
if name.Val == "inf" {
switch strings.ToLower(name.Val) {
case "inf", "infinity":
f = math.Inf(1)
} else {
default:
f = math.NaN()
}
return &SpecialFloatLiteralNode{
Expand Down
10 changes: 0 additions & 10 deletions ast/visitor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -192,10 +192,6 @@ func testVisitors(methodCalled *string) (*SimpleVisitor, []*SimpleVisitor) {
*methodCalled = "*UintLiteralNode"
return nil
},
DoVisitPositiveUintLiteralNode: func(*PositiveUintLiteralNode) error {
*methodCalled = "*PositiveUintLiteralNode"
return nil
},
DoVisitNegativeIntLiteralNode: func(*NegativeIntLiteralNode) error {
*methodCalled = "*NegativeIntLiteralNode"
return nil
Expand Down Expand Up @@ -350,9 +346,6 @@ func testVisitors(methodCalled *string) (*SimpleVisitor, []*SimpleVisitor) {
{
DoVisitUintLiteralNode: v.DoVisitUintLiteralNode,
},
{
DoVisitPositiveUintLiteralNode: v.DoVisitPositiveUintLiteralNode,
},
{
DoVisitNegativeIntLiteralNode: v.DoVisitNegativeIntLiteralNode,
},
Expand Down Expand Up @@ -477,9 +470,6 @@ func TestVisitorAll(t *testing.T) {
(*UintLiteralNode)(nil): {
"*UintLiteralNode", "ValueNode", "IntValueNode", "FloatValueNode", "TerminalNode", "Node",
},
(*PositiveUintLiteralNode)(nil): {
"*PositiveUintLiteralNode", "ValueNode", "IntValueNode", "CompositeNode", "Node",
},
(*NegativeIntLiteralNode)(nil): {
"*NegativeIntLiteralNode", "ValueNode", "IntValueNode", "CompositeNode", "Node",
},
Expand Down
16 changes: 0 additions & 16 deletions ast/walk.go
Original file line number Diff line number Diff line change
Expand Up @@ -157,8 +157,6 @@ func Visit(n Node, v Visitor) error {
return v.VisitCompoundStringLiteralNode(n)
case *UintLiteralNode:
return v.VisitUintLiteralNode(n)
case *PositiveUintLiteralNode:
return v.VisitPositiveUintLiteralNode(n)
case *NegativeIntLiteralNode:
return v.VisitNegativeIntLiteralNode(n)
case *FloatLiteralNode:
Expand Down Expand Up @@ -316,8 +314,6 @@ type Visitor interface {
VisitCompoundStringLiteralNode(*CompoundStringLiteralNode) error
// VisitUintLiteralNode is invoked when visiting a *UintLiteralNode in the AST.
VisitUintLiteralNode(*UintLiteralNode) error
// VisitPositiveUintLiteralNode is invoked when visiting a *PositiveUintLiteralNode in the AST.
VisitPositiveUintLiteralNode(*PositiveUintLiteralNode) error
// VisitNegativeIntLiteralNode is invoked when visiting a *NegativeIntLiteralNode in the AST.
VisitNegativeIntLiteralNode(*NegativeIntLiteralNode) error
// VisitFloatLiteralNode is invoked when visiting a *FloatLiteralNode in the AST.
Expand Down Expand Up @@ -469,10 +465,6 @@ func (n NoOpVisitor) VisitUintLiteralNode(_ *UintLiteralNode) error {
return nil
}

func (n NoOpVisitor) VisitPositiveUintLiteralNode(_ *PositiveUintLiteralNode) error {
return nil
}

func (n NoOpVisitor) VisitNegativeIntLiteralNode(_ *NegativeIntLiteralNode) error {
return nil
}
Expand Down Expand Up @@ -569,7 +561,6 @@ type SimpleVisitor struct {
DoVisitStringLiteralNode func(*StringLiteralNode) error
DoVisitCompoundStringLiteralNode func(*CompoundStringLiteralNode) error
DoVisitUintLiteralNode func(*UintLiteralNode) error
DoVisitPositiveUintLiteralNode func(*PositiveUintLiteralNode) error
DoVisitNegativeIntLiteralNode func(*NegativeIntLiteralNode) error
DoVisitFloatLiteralNode func(*FloatLiteralNode) error
DoVisitSpecialFloatLiteralNode func(*SpecialFloatLiteralNode) error
Expand Down Expand Up @@ -862,13 +853,6 @@ func (v *SimpleVisitor) VisitUintLiteralNode(node *UintLiteralNode) error {
return v.visitInterface(node)
}

func (v *SimpleVisitor) VisitPositiveUintLiteralNode(node *PositiveUintLiteralNode) error {
if v.DoVisitPositiveUintLiteralNode != nil {
return v.DoVisitPositiveUintLiteralNode(node)
}
return v.visitInterface(node)
}

func (v *SimpleVisitor) VisitNegativeIntLiteralNode(node *NegativeIntLiteralNode) error {
if v.DoVisitNegativeIntLiteralNode != nil {
return v.DoVisitNegativeIntLiteralNode(node)
Expand Down
60 changes: 41 additions & 19 deletions parser/proto.y
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ package parser

import (
"math"
"strings"

"github.com/bufbuild/protocompile/ast"
)
Expand Down Expand Up @@ -85,7 +86,7 @@ import (
%type <ref> extensionName messageLiteralFieldName
%type <optNms> optionName
%type <cmpctOpts> compactOptions
%type <v> value optionValue scalarValue messageLiteralWithBraces messageLiteral numLit listLiteral listElement listOfMessagesLiteral messageValue
%type <v> fieldValue optionValue scalarValue fieldScalarValue messageLiteralWithBraces messageLiteral numLit specialFloatLit listLiteral listElement listOfMessagesLiteral messageValue
%type <il> enumValueNumber
%type <id> identifier mapKeyType msgElementName extElementName oneofElementName notGroupElementName mtdElementName enumValueName fieldCardinality
%type <cid> qualifiedIdentifier msgElementIdent extElementIdent oneofElementIdent notGroupElementIdent mtdElementIdent
Expand Down Expand Up @@ -328,6 +329,7 @@ scalarValue : stringLit {
$$ = toStringValueNode($1)
}
| numLit
| specialFloatLit
| identifier {
$$ = $1
}
Expand All @@ -338,23 +340,9 @@ numLit : _FLOAT_LIT {
| '-' _FLOAT_LIT {
$$ = ast.NewSignedFloatLiteralNode($1, $2)
}
| '+' _FLOAT_LIT {
$$ = ast.NewSignedFloatLiteralNode($1, $2)
}
| '+' _INF {
f := ast.NewSpecialFloatLiteralNode($2.ToKeyword())
$$ = ast.NewSignedFloatLiteralNode($1, f)
}
| '-' _INF {
f := ast.NewSpecialFloatLiteralNode($2.ToKeyword())
$$ = ast.NewSignedFloatLiteralNode($1, f)
}
| _INT_LIT {
$$ = $1
}
| '+' _INT_LIT {
$$ = ast.NewPositiveUintLiteralNode($1, $2)
}
| '-' _INT_LIT {
if $2.Val > math.MaxInt64 + 1 {
// can't represent as int so treat as float literal
Expand All @@ -364,6 +352,16 @@ numLit : _FLOAT_LIT {
}
}

specialFloatLit
: '-' _INF {
f := ast.NewSpecialFloatLiteralNode($2.ToKeyword())
$$ = ast.NewSignedFloatLiteralNode($1, f)
}
| '-' _NAN {
f := ast.NewSpecialFloatLiteralNode($2.ToKeyword())
$$ = ast.NewSignedFloatLiteralNode($1, f)
}

stringLit : _STRING_LIT {
$$ = []*ast.StringLiteralNode{$1}
}
Expand Down Expand Up @@ -426,7 +424,7 @@ messageLiteralFieldEntry : messageLiteralField {
$$ = nil
}

messageLiteralField : messageLiteralFieldName ':' value {
messageLiteralField : messageLiteralFieldName ':' fieldValue {
if $1 != nil && $2 != nil {
$$ = ast.NewMessageFieldNode($1, $2, $3)
} else {
Expand All @@ -440,7 +438,7 @@ messageLiteralField : messageLiteralFieldName ':' value {
$$ = nil
}
}
| error ':' value {
| error ':' fieldValue {
$$ = nil
}

Expand All @@ -457,10 +455,34 @@ messageLiteralFieldName : identifier {
$$ = nil
}

value : scalarValue
fieldValue
: fieldScalarValue
| messageLiteral
| listLiteral

fieldScalarValue : stringLit {
$$ = toStringValueNode($1)
}
| numLit
| '-' identifier {
kw := $2.ToKeyword()
switch strings.ToLower(kw.Val) {
case "inf", "infinity", "nan":
// these are acceptable
default:
// anything else is not
protolex.(*protoLex).Error(`only identifiers "inf", "infinity", or "nan" may appear after negative sign`)
}
// we'll validate the identifier later
f := ast.NewSpecialFloatLiteralNode(kw)
$$ = ast.NewSignedFloatLiteralNode($1, f)
}
| identifier {
$$ = $1
}



messageValue : messageLiteral
| listOfMessagesLiteral

Expand Down Expand Up @@ -500,7 +522,7 @@ listElements : listElement {
$$ = $1
}

listElement : scalarValue
listElement : fieldScalarValue
| messageLiteral

listOfMessagesLiteral : '[' messageLiterals ']' {
Expand Down
Loading

0 comments on commit d37b7f5

Please sign in to comment.