Skip to content

Commit c80f8ac

Browse files
authored
Support handlers that return io.Reader (#472)
1 parent ad74310 commit c80f8ac

8 files changed

+334
-58
lines changed

lambda/entry.go

+3
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,9 @@ import (
3838
//
3939
// Where "TIn" and "TOut" are types compatible with the "encoding/json" standard library.
4040
// See https://golang.org/pkg/encoding/json/#Unmarshal for how deserialization behaves
41+
//
42+
// "TOut" may also implement the io.Reader interface.
43+
// If "TOut" is both json serializable and implements io.Reader, then the json serialization is used.
4144
func Start(handler interface{}) {
4245
StartWithOptions(handler)
4346
}

lambda/handler.go

+73-21
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,10 @@ import (
88
"encoding/json"
99
"errors"
1010
"fmt"
11+
"io"
12+
"io/ioutil" // nolint:staticcheck
1113
"reflect"
14+
"strings"
1215

1316
"github.com/aws/aws-lambda-go/lambda/handlertrace"
1417
)
@@ -18,7 +21,7 @@ type Handler interface {
1821
}
1922

2023
type handlerOptions struct {
21-
Handler
24+
handlerFunc
2225
baseContext context.Context
2326
jsonResponseEscapeHTML bool
2427
jsonResponseIndentPrefix string
@@ -184,32 +187,68 @@ func newHandler(handlerFunc interface{}, options ...Option) *handlerOptions {
184187
if h.enableSIGTERM {
185188
enableSIGTERM(h.sigtermCallbacks)
186189
}
187-
h.Handler = reflectHandler(handlerFunc, h)
190+
h.handlerFunc = reflectHandler(handlerFunc, h)
188191
return h
189192
}
190193

191-
type bytesHandlerFunc func(context.Context, []byte) ([]byte, error)
194+
type handlerFunc func(context.Context, []byte) (io.Reader, error)
192195

193-
func (h bytesHandlerFunc) Invoke(ctx context.Context, payload []byte) ([]byte, error) {
194-
return h(ctx, payload)
196+
// back-compat for the rpc mode
197+
func (h handlerFunc) Invoke(ctx context.Context, payload []byte) ([]byte, error) {
198+
response, err := h(ctx, payload)
199+
if err != nil {
200+
return nil, err
201+
}
202+
// if the response needs to be closed (ex: net.Conn, os.File), ensure it's closed before the next invoke to prevent a resource leak
203+
if response, ok := response.(io.Closer); ok {
204+
defer response.Close()
205+
}
206+
// optimization: if the response is a *bytes.Buffer, a copy can be eliminated
207+
switch response := response.(type) {
208+
case *jsonOutBuffer:
209+
return response.Bytes(), nil
210+
case *bytes.Buffer:
211+
return response.Bytes(), nil
212+
}
213+
b, err := ioutil.ReadAll(response)
214+
if err != nil {
215+
return nil, err
216+
}
217+
return b, nil
195218
}
196-
func errorHandler(err error) Handler {
197-
return bytesHandlerFunc(func(_ context.Context, _ []byte) ([]byte, error) {
219+
220+
func errorHandler(err error) handlerFunc {
221+
return func(_ context.Context, _ []byte) (io.Reader, error) {
198222
return nil, err
199-
})
223+
}
224+
}
225+
226+
type jsonOutBuffer struct {
227+
*bytes.Buffer
200228
}
201229

202-
func reflectHandler(handlerFunc interface{}, h *handlerOptions) Handler {
203-
if handlerFunc == nil {
230+
func (j *jsonOutBuffer) ContentType() string {
231+
return contentTypeJSON
232+
}
233+
234+
func reflectHandler(f interface{}, h *handlerOptions) handlerFunc {
235+
if f == nil {
204236
return errorHandler(errors.New("handler is nil"))
205237
}
206238

207-
if handler, ok := handlerFunc.(Handler); ok {
208-
return handler
239+
// back-compat: types with reciever `Invoke(context.Context, []byte) ([]byte, error)` need the return bytes wrapped
240+
if handler, ok := f.(Handler); ok {
241+
return func(ctx context.Context, payload []byte) (io.Reader, error) {
242+
b, err := handler.Invoke(ctx, payload)
243+
if err != nil {
244+
return nil, err
245+
}
246+
return bytes.NewBuffer(b), nil
247+
}
209248
}
210249

211-
handler := reflect.ValueOf(handlerFunc)
212-
handlerType := reflect.TypeOf(handlerFunc)
250+
handler := reflect.ValueOf(f)
251+
handlerType := reflect.TypeOf(f)
213252
if handlerType.Kind() != reflect.Func {
214253
return errorHandler(fmt.Errorf("handler kind %s is not %s", handlerType.Kind(), reflect.Func))
215254
}
@@ -223,9 +262,10 @@ func reflectHandler(handlerFunc interface{}, h *handlerOptions) Handler {
223262
return errorHandler(err)
224263
}
225264

226-
return bytesHandlerFunc(func(ctx context.Context, payload []byte) ([]byte, error) {
265+
out := &jsonOutBuffer{bytes.NewBuffer(nil)}
266+
return func(ctx context.Context, payload []byte) (io.Reader, error) {
267+
out.Reset()
227268
in := bytes.NewBuffer(payload)
228-
out := bytes.NewBuffer(nil)
229269
decoder := json.NewDecoder(in)
230270
encoder := json.NewEncoder(out)
231271
encoder.SetEscapeHTML(h.jsonResponseEscapeHTML)
@@ -266,16 +306,28 @@ func reflectHandler(handlerFunc interface{}, h *handlerOptions) Handler {
266306
trace.ResponseEvent(ctx, val)
267307
}
268308
}
309+
310+
// encode to JSON
269311
if err := encoder.Encode(val); err != nil {
312+
// if response is not JSON serializable, but the response type is a reader, return it as-is
313+
if reader, ok := val.(io.Reader); ok {
314+
return reader, nil
315+
}
270316
return nil, err
271317
}
272318

273-
responseBytes := out.Bytes()
319+
// if response value is an io.Reader, return it as-is
320+
if reader, ok := val.(io.Reader); ok {
321+
// back-compat, don't return the reader if the value serialized to a non-empty json
322+
if strings.HasPrefix(out.String(), "{}") {
323+
return reader, nil
324+
}
325+
}
326+
274327
// back-compat, strip the encoder's trailing newline unless WithSetIndent was used
275328
if h.jsonResponseIndentValue == "" && h.jsonResponseIndentPrefix == "" {
276-
return responseBytes[:len(responseBytes)-1], nil
329+
out.Truncate(out.Len() - 1)
277330
}
278-
279-
return responseBytes, nil
280-
})
331+
return out, nil
332+
}
281333
}

lambda/handler_test.go

+104-14
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,20 @@
33
package lambda
44

55
import (
6+
"bytes"
67
"context"
78
"errors"
89
"fmt"
10+
"io"
11+
"io/ioutil" //nolint: staticcheck
12+
"strings"
913
"testing"
1014
"time"
1115

1216
"github.com/aws/aws-lambda-go/lambda/handlertrace"
1317
"github.com/aws/aws-lambda-go/lambda/messages"
1418
"github.com/stretchr/testify/assert"
19+
"github.com/stretchr/testify/require"
1520
)
1621

1722
func TestInvalidHandlers(t *testing.T) {
@@ -145,6 +150,23 @@ func TestInvalidHandlers(t *testing.T) {
145150
}
146151
}
147152

153+
type arbitraryJSON struct {
154+
json []byte
155+
err error
156+
}
157+
158+
func (a arbitraryJSON) MarshalJSON() ([]byte, error) {
159+
return a.json, a.err
160+
}
161+
162+
type staticHandler struct {
163+
body []byte
164+
}
165+
166+
func (h *staticHandler) Invoke(_ context.Context, _ []byte) ([]byte, error) {
167+
return h.body, nil
168+
}
169+
148170
type expected struct {
149171
val string
150172
err error
@@ -168,10 +190,8 @@ func TestInvokes(t *testing.T) {
168190
}{
169191
{
170192
input: `"Lambda"`,
171-
expected: expected{`"Hello Lambda!"`, nil},
172-
handler: func(name string) (string, error) {
173-
return hello(name), nil
174-
},
193+
expected: expected{`null`, nil},
194+
handler: func(_ string) {},
175195
},
176196
{
177197
input: `"Lambda"`,
@@ -180,6 +200,12 @@ func TestInvokes(t *testing.T) {
180200
return hello(name), nil
181201
},
182202
},
203+
{
204+
expected: expected{`"Hello No Value!"`, nil},
205+
handler: func(ctx context.Context) (string, error) {
206+
return hello("No Value"), nil
207+
},
208+
},
183209
{
184210
input: `"Lambda"`,
185211
expected: expected{`"Hello Lambda!"`, nil},
@@ -294,22 +320,86 @@ func TestInvokes(t *testing.T) {
294320
{
295321
name: "Handler interface implementations are passthrough",
296322
expected: expected{`<xml>hello</xml>`, nil},
297-
handler: bytesHandlerFunc(func(_ context.Context, _ []byte) ([]byte, error) {
298-
return []byte(`<xml>hello</xml>`), nil
299-
}),
323+
handler: &staticHandler{body: []byte(`<xml>hello</xml>`)},
324+
},
325+
{
326+
name: "io.Reader responses are passthrough",
327+
expected: expected{`<yolo>yolo</yolo>`, nil},
328+
handler: func() (io.Reader, error) {
329+
return strings.NewReader(`<yolo>yolo</yolo>`), nil
330+
},
331+
},
332+
{
333+
name: "io.Reader responses that are byte buffers are passthrough",
334+
expected: expected{`<yolo>yolo</yolo>`, nil},
335+
handler: func() (*bytes.Buffer, error) {
336+
return bytes.NewBuffer([]byte(`<yolo>yolo</yolo>`)), nil
337+
},
338+
},
339+
{
340+
name: "io.Reader responses that are also json serializable, handler returns the json, ignoring the reader",
341+
expected: expected{`{"Yolo":"yolo"}`, nil},
342+
handler: func() (io.Reader, error) {
343+
return struct {
344+
io.Reader `json:"-"`
345+
Yolo string
346+
}{
347+
Reader: strings.NewReader(`<yolo>yolo</yolo>`),
348+
Yolo: "yolo",
349+
}, nil
350+
},
351+
},
352+
{
353+
name: "types that are not json serializable result in an error",
354+
expected: expected{``, errors.New("json: error calling MarshalJSON for type struct { lambda.arbitraryJSON }: barf")},
355+
handler: func() (interface{}, error) {
356+
return struct {
357+
arbitraryJSON
358+
}{
359+
arbitraryJSON{nil, errors.New("barf")},
360+
}, nil
361+
},
362+
},
363+
{
364+
name: "io.Reader responses that not json serializable remain passthrough",
365+
expected: expected{`wat`, nil},
366+
handler: func() (io.Reader, error) {
367+
return struct {
368+
arbitraryJSON
369+
io.Reader
370+
}{
371+
arbitraryJSON{nil, errors.New("barf")},
372+
strings.NewReader("wat"),
373+
}, nil
374+
},
300375
},
301376
}
302377
for i, testCase := range testCases {
303378
testCase := testCase
304379
t.Run(fmt.Sprintf("testCase[%d] %s", i, testCase.name), func(t *testing.T) {
305380
lambdaHandler := newHandler(testCase.handler, testCase.options...)
306-
response, err := lambdaHandler.Invoke(context.TODO(), []byte(testCase.input))
307-
if testCase.expected.err != nil {
308-
assert.Equal(t, testCase.expected.err, err)
309-
} else {
310-
assert.NoError(t, err)
311-
assert.Equal(t, testCase.expected.val, string(response))
312-
}
381+
t.Run("via Handler.Invoke", func(t *testing.T) {
382+
response, err := lambdaHandler.Invoke(context.TODO(), []byte(testCase.input))
383+
if testCase.expected.err != nil {
384+
assert.EqualError(t, err, testCase.expected.err.Error())
385+
} else {
386+
assert.NoError(t, err)
387+
assert.Equal(t, testCase.expected.val, string(response))
388+
}
389+
})
390+
t.Run("via handlerOptions.handlerFunc", func(t *testing.T) {
391+
response, err := lambdaHandler.handlerFunc(context.TODO(), []byte(testCase.input))
392+
if testCase.expected.err != nil {
393+
assert.EqualError(t, err, testCase.expected.err.Error())
394+
} else {
395+
assert.NoError(t, err)
396+
require.NotNil(t, response)
397+
responseBytes, err := ioutil.ReadAll(response)
398+
assert.NoError(t, err)
399+
assert.Equal(t, testCase.expected.val, string(responseBytes))
400+
}
401+
})
402+
313403
})
314404
}
315405
}

lambda/invoke_loop.go

+18-4
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,11 @@
33
package lambda
44

55
import (
6+
"bytes"
67
"context"
78
"encoding/json"
89
"fmt"
10+
"io"
911
"log"
1012
"os"
1113
"strconv"
@@ -70,7 +72,7 @@ func handleInvoke(invoke *invoke, handler *handlerOptions) error {
7072
ctx = context.WithValue(ctx, "x-amzn-trace-id", traceID)
7173

7274
// call the handler, marshal any returned error
73-
response, invokeErr := callBytesHandlerFunc(ctx, invoke.payload, handler.Handler.Invoke)
75+
response, invokeErr := callBytesHandlerFunc(ctx, invoke.payload, handler.handlerFunc)
7476
if invokeErr != nil {
7577
if err := reportFailure(invoke, invokeErr); err != nil {
7678
return err
@@ -80,7 +82,19 @@ func handleInvoke(invoke *invoke, handler *handlerOptions) error {
8082
}
8183
return nil
8284
}
83-
if err := invoke.success(response, contentTypeJSON); err != nil {
85+
// if the response needs to be closed (ex: net.Conn, os.File), ensure it's closed before the next invoke to prevent a resource leak
86+
if response, ok := response.(io.Closer); ok {
87+
defer response.Close()
88+
}
89+
90+
// if the response defines a content-type, plumb it through
91+
contentType := contentTypeBytes
92+
type ContentType interface{ ContentType() string }
93+
if response, ok := response.(ContentType); ok {
94+
contentType = response.ContentType()
95+
}
96+
97+
if err := invoke.success(response, contentType); err != nil {
8498
return fmt.Errorf("unexpected error occurred when sending the function functionResponse to the API: %v", err)
8599
}
86100

@@ -90,13 +104,13 @@ func handleInvoke(invoke *invoke, handler *handlerOptions) error {
90104
func reportFailure(invoke *invoke, invokeErr *messages.InvokeResponse_Error) error {
91105
errorPayload := safeMarshal(invokeErr)
92106
log.Printf("%s", errorPayload)
93-
if err := invoke.failure(errorPayload, contentTypeJSON); err != nil {
107+
if err := invoke.failure(bytes.NewReader(errorPayload), contentTypeJSON); err != nil {
94108
return fmt.Errorf("unexpected error occurred when sending the function error to the API: %v", err)
95109
}
96110
return nil
97111
}
98112

99-
func callBytesHandlerFunc(ctx context.Context, payload []byte, handler bytesHandlerFunc) (response []byte, invokeErr *messages.InvokeResponse_Error) {
113+
func callBytesHandlerFunc(ctx context.Context, payload []byte, handler handlerFunc) (response io.Reader, invokeErr *messages.InvokeResponse_Error) {
100114
defer func() {
101115
if err := recover(); err != nil {
102116
invokeErr = lambdaPanicResponse(err)

0 commit comments

Comments
 (0)