Skip to content

Commit ad74310

Browse files
authored
Fix issue #377 (#475) - Adds validation to reject handlers that would result in a panic when constructing the context
* add test case for #377 * fix panicking * use interface{} instead of any * add a comment for argumentType.NumMethod() == 0
1 parent 65f8ccd commit ad74310

File tree

3 files changed

+90
-12
lines changed

3 files changed

+90
-12
lines changed

lambda/entry_generic_test.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ func TestStartHandlerFunc(t *testing.T) {
2727

2828
handlerType := reflect.TypeOf(f)
2929

30-
handlerTakesContext, err := validateArguments(handlerType)
30+
handlerTakesContext, err := handlerTakesContext(handlerType)
3131
assert.NoError(t, err)
3232
assert.True(t, handlerTakesContext)
3333

lambda/handler.go

+26-10
Original file line numberDiff line numberDiff line change
@@ -99,20 +99,36 @@ func WithEnableSIGTERM(callbacks ...func()) Option {
9999
})
100100
}
101101

102-
func validateArguments(handler reflect.Type) (bool, error) {
103-
handlerTakesContext := false
104-
if handler.NumIn() > 2 {
105-
return false, fmt.Errorf("handlers may not take more than two arguments, but handler takes %d", handler.NumIn())
106-
} else if handler.NumIn() > 0 {
102+
// handlerTakesContext returns whether the handler takes a context.Context as its first argument.
103+
func handlerTakesContext(handler reflect.Type) (bool, error) {
104+
switch handler.NumIn() {
105+
case 0:
106+
return false, nil
107+
case 1:
107108
contextType := reflect.TypeOf((*context.Context)(nil)).Elem()
108109
argumentType := handler.In(0)
109-
handlerTakesContext = argumentType.Implements(contextType)
110-
if handler.NumIn() > 1 && !handlerTakesContext {
110+
if argumentType.Kind() != reflect.Interface {
111+
return false, nil
112+
}
113+
114+
// handlers like func(event any) are valid.
115+
if argumentType.NumMethod() == 0 {
116+
return false, nil
117+
}
118+
119+
if !contextType.Implements(argumentType) || !argumentType.Implements(contextType) {
120+
return false, fmt.Errorf("handler takes an interface, but it is not context.Context: %q", argumentType.Name())
121+
}
122+
return true, nil
123+
case 2:
124+
contextType := reflect.TypeOf((*context.Context)(nil)).Elem()
125+
argumentType := handler.In(0)
126+
if argumentType.Kind() != reflect.Interface || !contextType.Implements(argumentType) || !argumentType.Implements(contextType) {
111127
return false, fmt.Errorf("handler takes two arguments, but the first is not Context. got %s", argumentType.Kind())
112128
}
129+
return true, nil
113130
}
114-
115-
return handlerTakesContext, nil
131+
return false, fmt.Errorf("handlers may not take more than two arguments, but handler takes %d", handler.NumIn())
116132
}
117133

118134
func validateReturns(handler reflect.Type) error {
@@ -198,7 +214,7 @@ func reflectHandler(handlerFunc interface{}, h *handlerOptions) Handler {
198214
return errorHandler(fmt.Errorf("handler kind %s is not %s", handlerType.Kind(), reflect.Func))
199215
}
200216

201-
takesContext, err := validateArguments(handlerType)
217+
takesContext, err := handlerTakesContext(handlerType)
202218
if err != nil {
203219
return errorHandler(err)
204220
}

lambda/handler_test.go

+63-1
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,29 @@ import (
77
"errors"
88
"fmt"
99
"testing"
10+
"time"
1011

1112
"github.com/aws/aws-lambda-go/lambda/handlertrace"
1213
"github.com/aws/aws-lambda-go/lambda/messages"
1314
"github.com/stretchr/testify/assert"
1415
)
1516

1617
func TestInvalidHandlers(t *testing.T) {
18+
type valuer interface {
19+
Value(key interface{}) interface{}
20+
}
21+
22+
type customContext interface {
23+
context.Context
24+
MyCustomMethod()
25+
}
26+
27+
type myContext interface {
28+
Deadline() (deadline time.Time, ok bool)
29+
Done() <-chan struct{}
30+
Err() error
31+
Value(key interface{}) interface{}
32+
}
1733

1834
testCases := []struct {
1935
name string
@@ -72,12 +88,58 @@ func TestInvalidHandlers(t *testing.T) {
7288
handler: func() {
7389
},
7490
},
91+
{
92+
name: "the handler takes the empty interface",
93+
expected: nil,
94+
handler: func(v interface{}) error {
95+
if _, ok := v.(context.Context); ok {
96+
return errors.New("v should not be a Context")
97+
}
98+
return nil
99+
},
100+
},
101+
{
102+
name: "the handler takes a subset of context.Context",
103+
expected: errors.New("handler takes an interface, but it is not context.Context: \"valuer\""),
104+
handler: func(ctx valuer) {
105+
},
106+
},
107+
{
108+
name: "the handler takes a same interface with context.Context",
109+
expected: nil,
110+
handler: func(ctx myContext) {
111+
},
112+
},
113+
{
114+
name: "the handler takes a superset of context.Context",
115+
expected: errors.New("handler takes an interface, but it is not context.Context: \"customContext\""),
116+
handler: func(ctx customContext) {
117+
},
118+
},
119+
{
120+
name: "the handler takes two arguments and first argument is a subset of context.Context",
121+
expected: errors.New("handler takes two arguments, but the first is not Context. got interface"),
122+
handler: func(ctx valuer, v interface{}) {
123+
},
124+
},
125+
{
126+
name: "the handler takes two arguments and first argument is a same interface with context.Context",
127+
expected: nil,
128+
handler: func(ctx myContext, v interface{}) {
129+
},
130+
},
131+
{
132+
name: "the handler takes two arguments and first argument is a superset of context.Context",
133+
expected: errors.New("handler takes two arguments, but the first is not Context. got interface"),
134+
handler: func(ctx customContext, v interface{}) {
135+
},
136+
},
75137
}
76138
for i, testCase := range testCases {
77139
testCase := testCase
78140
t.Run(fmt.Sprintf("testCase[%d] %s", i, testCase.name), func(t *testing.T) {
79141
lambdaHandler := NewHandler(testCase.handler)
80-
_, err := lambdaHandler.Invoke(context.TODO(), make([]byte, 0))
142+
_, err := lambdaHandler.Invoke(context.TODO(), []byte("{}"))
81143
assert.Equal(t, testCase.expected, err)
82144
})
83145
}

0 commit comments

Comments
 (0)