diff --git a/core/request.go b/core/request.go index ebb1e64..851a07f 100644 --- a/core/request.go +++ b/core/request.go @@ -35,6 +35,11 @@ const ( // API Gateway stage variables. To access the stage variable values // use the GetAPIGatewayStageVars method of the RequestAccessor object. APIGwStageVarsHeader = "X-GoLambdaProxy-ApiGw-StageVars" + + // APIGwPathParamsHeader is the custom header key used to store the + // API Gateway request's path parameters. To access the path parameters + // values use the GetAPIGatewayPathParams method of the RequestAccessor object. + APIGwPathParamsHeader = "X-GoLambdaProxy-ApiGw-PathParams" ) // RequestAccessor objects give access to custom API Gateway properties @@ -79,6 +84,24 @@ func (r *RequestAccessor) GetAPIGatewayStageVars(req *http.Request) (map[string] return stageVars, nil } +// GetAPIGatewayPathParams extracts the API Gateway request's path parameters +// from a request's custom header. +// Returns a map[string]string of the path parameters and their values from +// the request. +func (r *RequestAccessor) GetAPIGatewayPathParams(req *http.Request) (map[string]string, error) { + pathParams := make(map[string]string) + if req.Header.Get(APIGwPathParamsHeader) == "" { + return pathParams, errors.New("No path params header in request") + } + err := json.Unmarshal([]byte(req.Header.Get(APIGwPathParamsHeader)), &pathParams) + if err != nil { + log.Println("Error while unmarshalling path parameters") + log.Println(err) + return pathParams, err + } + return pathParams, nil +} + // StripBasePath instructs the RequestAccessor object that the given base // path should be removed from the request path before sending it to the // framework for routing. This is used when API Gateway is configured with @@ -222,12 +245,23 @@ func addToHeader(req *http.Request, apiGwRequest events.APIGatewayProxyRequest) return req, err } req.Header.Set(APIGwContextHeader, string(apiGwContext)) + pathParams, err := json.Marshal(apiGwRequest.PathParameters) + if err != nil { + log.Println("Could not marshal path params for custom header") + return nil, err + } + req.Header.Set(APIGwPathParamsHeader, string(pathParams)) return req, nil } func addToContext(ctx context.Context, req *http.Request, apiGwRequest events.APIGatewayProxyRequest) *http.Request { lc, _ := lambdacontext.FromContext(ctx) - rc := requestContext{lambdaContext: lc, gatewayProxyContext: apiGwRequest.RequestContext, stageVars: apiGwRequest.StageVariables} + rc := requestContext{ + lambdaContext: lc, + gatewayProxyContext: apiGwRequest.RequestContext, + stageVars: apiGwRequest.StageVariables, + pathParams: apiGwRequest.PathParameters, + } ctx = context.WithValue(ctx, ctxKey{}, rc) return req.WithContext(ctx) } @@ -250,10 +284,17 @@ func GetStageVarsFromContext(ctx context.Context) (map[string]string, bool) { return v.stageVars, ok } +// GetPathParamsFromContext retrieve path parameters from context +func GetPathParamsFromContext(ctx context.Context) (map[string]string, bool) { + v, ok := ctx.Value(ctxKey{}).(requestContext) + return v.pathParams, ok +} + type ctxKey struct{} type requestContext struct { lambdaContext *lambdacontext.LambdaContext gatewayProxyContext events.APIGatewayProxyRequestContext stageVars map[string]string + pathParams map[string]string } diff --git a/core/request_test.go b/core/request_test.go index b91563b..e18a749 100644 --- a/core/request_test.go +++ b/core/request_test.go @@ -174,7 +174,7 @@ var _ = Describe("RequestAccessor tests", func() { // calling old method to verify reverse compatibility httpReq, err := accessor.ProxyEventToHTTPRequest(contextRequest) Expect(err).To(BeNil()) - Expect(2).To(Equal(len(httpReq.Header))) + Expect(3).To(Equal(len(httpReq.Header))) Expect(httpReq.Header.Get(core.APIGwContextHeader)).ToNot(BeNil()) }) }) @@ -304,6 +304,52 @@ var _ = Describe("RequestAccessor tests", func() { Expect("value2").To(Equal(stageVars["var2"])) }) + It("Populates path parameters correctly", func() { + varsRequest := getProxyRequest("orders", "GET") + varsRequest.PathParameters = getPathParameters() + + accessor := core.RequestAccessor{} + httpReq, err := accessor.ProxyEventToHTTPRequest(varsRequest) + Expect(err).To(BeNil()) + + pathParams, err := accessor.GetAPIGatewayPathParams(httpReq) + Expect(err).To(BeNil()) + Expect(2).To(Equal(len(pathParams))) + Expect(pathParams["param1"]).ToNot(BeNil()) + Expect(pathParams["param2"]).ToNot(BeNil()) + Expect("value1").To(Equal(pathParams["param1"])) + Expect("value2").To(Equal(pathParams["param2"])) + + // overwrite existing path params header + varsRequestWithHeaders := getProxyRequest("orders", "GET") + varsRequestWithHeaders.PathParameters = getPathParameters() + varsRequestWithHeaders.Headers = map[string]string{core.APIGwPathParamsHeader: `{"var1":"abc123"}`} + httpReq, err = accessor.ProxyEventToHTTPRequest(varsRequestWithHeaders) + Expect(err).To(BeNil()) + pathParams, err = accessor.GetAPIGatewayPathParams(httpReq) + Expect(err).To(BeNil()) + Expect(pathParams["param1"]).To(Equal("value1")) + + pathParams, ok := core.GetPathParamsFromContext(httpReq.Context()) + // not present in context + Expect(ok).To(BeFalse()) + + httpReq, err = accessor.EventToRequestWithContext(context.Background(), varsRequest) + Expect(err).To(BeNil()) + + pathParams, err = accessor.GetAPIGatewayPathParams(httpReq) + // should not be in headers + Expect(err).ToNot(BeNil()) + + pathParams, ok = core.GetPathParamsFromContext(httpReq.Context()) + Expect(ok).To(BeTrue()) + Expect(2).To(Equal(len(pathParams))) + Expect(pathParams["param1"]).ToNot(BeNil()) + Expect(pathParams["param2"]).ToNot(BeNil()) + Expect("value1").To(Equal(pathParams["param1"])) + Expect("value2").To(Equal(pathParams["param2"])) + }) + It("Populates the default hostname correctly", func() { basicRequest := getProxyRequest("orders", "GET") @@ -367,3 +413,10 @@ func getStageVariables() map[string]string { "var2": "value2", } } + +func getPathParameters() map[string]string { + return map[string]string{ + "param1": "value1", + "param2": "value2", + } +} diff --git a/core/requestv2.go b/core/requestv2.go index d665384..6127c90 100644 --- a/core/requestv2.go +++ b/core/requestv2.go @@ -36,7 +36,7 @@ func (r *RequestAccessorV2) GetAPIGatewayContextV2(req *http.Request) (events.AP context := events.APIGatewayV2HTTPRequestContext{} err := json.Unmarshal([]byte(req.Header.Get(APIGwContextHeader)), &context) if err != nil { - log.Println("Erorr while unmarshalling context") + log.Println("Error while unmarshalling context") log.Println(err) return events.APIGatewayV2HTTPRequestContext{}, err } @@ -54,13 +54,31 @@ func (r *RequestAccessorV2) GetAPIGatewayStageVars(req *http.Request) (map[strin } err := json.Unmarshal([]byte(req.Header.Get(APIGwStageVarsHeader)), &stageVars) if err != nil { - log.Println("Erorr while unmarshalling stage variables") + log.Println("Error while unmarshalling stage variables") log.Println(err) return stageVars, err } return stageVars, nil } +// GetAPIGatewayPathParams extracts the API Gateway path parameters from a +// request's custom header. +// Returns a map[string]string of the path parameters and their values from +// the request. +func (r *RequestAccessorV2) GetAPIGatewayPathParams(req *http.Request) (map[string]string, error) { + pathParams := make(map[string]string) + if req.Header.Get(APIGwStageVarsHeader) == "" { + return pathParams, errors.New("No path params header in request") + } + err := json.Unmarshal([]byte(req.Header.Get(APIGwPathParamsHeader)), &pathParams) + if err != nil { + log.Println("Error while unmarshalling path params") + log.Println(err) + return pathParams, err + } + return pathParams, nil +} + // StripBasePath instructs the RequestAccessor object that the given base // path should be removed from the request path before sending it to the // framework for routing. This is used when API Gateway is configured with @@ -194,12 +212,23 @@ func addToHeaderV2(req *http.Request, apiGwRequest events.APIGatewayV2HTTPReques return req, err } req.Header.Add(APIGwContextHeader, string(apiGwContext)) + pathParams, err := json.Marshal(apiGwRequest.PathParameters) + if err != nil { + log.Println("Could not Marshal path parameters for custom header") + return req, err + } + req.Header.Add(APIGwPathParamsHeader, string(pathParams)) return req, nil } func addToContextV2(ctx context.Context, req *http.Request, apiGwRequest events.APIGatewayV2HTTPRequest) *http.Request { lc, _ := lambdacontext.FromContext(ctx) - rc := requestContextV2{lambdaContext: lc, gatewayProxyContext: apiGwRequest.RequestContext, stageVars: apiGwRequest.StageVariables} + rc := requestContextV2{ + lambdaContext: lc, + gatewayProxyContext: apiGwRequest.RequestContext, + stageVars: apiGwRequest.StageVariables, + pathParams: apiGwRequest.PathParameters, + } ctx = context.WithValue(ctx, ctxKey{}, rc) return req.WithContext(ctx) } @@ -222,8 +251,15 @@ func GetStageVarsFromContextV2(ctx context.Context) (map[string]string, bool) { return v.stageVars, ok } +// GetPathParamsFromContextV2 retrieve path params from context +func GetPathParamsFromContextV2(ctx context.Context) (map[string]string, bool) { + v, ok := ctx.Value(ctxKey{}).(requestContextV2) + return v.pathParams, ok +} + type requestContextV2 struct { lambdaContext *lambdacontext.LambdaContext gatewayProxyContext events.APIGatewayV2HTTPRequestContext stageVars map[string]string + pathParams map[string]string } diff --git a/core/requestv2_test.go b/core/requestv2_test.go index e42370d..701347e 100644 --- a/core/requestv2_test.go +++ b/core/requestv2_test.go @@ -173,7 +173,7 @@ var _ = Describe("RequestAccessorV2 tests", func() { // calling old method to verify reverse compatibility httpReq, err := accessor.ProxyEventToHTTPRequest(contextRequest) Expect(err).To(BeNil()) - Expect(2).To(Equal(len(httpReq.Header))) + Expect(3).To(Equal(len(httpReq.Header))) Expect(httpReq.Header.Get(core.APIGwContextHeader)).ToNot(BeNil()) }) }) @@ -283,6 +283,42 @@ var _ = Describe("RequestAccessorV2 tests", func() { Expect("value2").To(Equal(stageVars["var2"])) }) + It("Populates path parameters correctly", func() { + varsRequest := getProxyRequestV2("orders", "GET") + varsRequest.PathParameters = getPathParameters() + + accessor := core.RequestAccessorV2{} + httpReq, err := accessor.ProxyEventToHTTPRequest(varsRequest) + Expect(err).To(BeNil()) + + stageVars, err := accessor.GetAPIGatewayPathParams(httpReq) + Expect(err).To(BeNil()) + Expect(2).To(Equal(len(stageVars))) + Expect(stageVars["param1"]).ToNot(BeNil()) + Expect(stageVars["param2"]).ToNot(BeNil()) + Expect("value1").To(Equal(stageVars["param1"])) + Expect("value2").To(Equal(stageVars["param2"])) + + pathParams, ok := core.GetPathParamsFromContextV2(httpReq.Context()) + // not present in context + Expect(ok).To(BeFalse()) + + httpReq, err = accessor.EventToRequestWithContext(context.Background(), varsRequest) + Expect(err).To(BeNil()) + + pathParams, err = accessor.GetAPIGatewayPathParams(httpReq) + // should not be in headers + Expect(err).ToNot(BeNil()) + + pathParams, ok = core.GetPathParamsFromContextV2(httpReq.Context()) + Expect(ok).To(BeTrue()) + Expect(2).To(Equal(len(stageVars))) + Expect(pathParams["param1"]).ToNot(BeNil()) + Expect(pathParams["param2"]).ToNot(BeNil()) + Expect("value1").To(Equal(pathParams["param1"])) + Expect("value2").To(Equal(pathParams["param2"])) + }) + It("Populates the default hostname correctly", func() { basicRequest := getProxyRequestV2("orders", "GET")