Skip to content

Commit d83f0b6

Browse files
committed
Add 'pkg/' from commit '5b7ff06fd44ad7e4169c1ae05a437f9e05654f3d'
git-subtree-dir: pkg git-subtree-mainline: ce53a89 git-subtree-split: 5b7ff06
2 parents ce53a89 + 5b7ff06 commit d83f0b6

File tree

3 files changed

+728
-0
lines changed

3 files changed

+728
-0
lines changed

pkg/oapi_validate.go

+211
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,211 @@
1+
// Copyright 2021 DeepMap, Inc.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package middleware
16+
17+
import (
18+
"context"
19+
"errors"
20+
"fmt"
21+
"log"
22+
"net/http"
23+
"os"
24+
"strings"
25+
26+
"github.com/getkin/kin-openapi/openapi3"
27+
"github.com/getkin/kin-openapi/openapi3filter"
28+
"github.com/getkin/kin-openapi/routers"
29+
"github.com/getkin/kin-openapi/routers/gorillamux"
30+
"github.com/gin-gonic/gin"
31+
)
32+
33+
const (
34+
GinContextKey = "oapi-codegen/gin-context"
35+
UserDataKey = "oapi-codegen/user-data"
36+
)
37+
38+
// OapiValidatorFromYamlFile creates a validator middleware from a YAML file path
39+
func OapiValidatorFromYamlFile(path string) (gin.HandlerFunc, error) {
40+
data, err := os.ReadFile(path)
41+
if err != nil {
42+
return nil, fmt.Errorf("error reading %s: %s", path, err)
43+
}
44+
45+
swagger, err := openapi3.NewLoader().LoadFromData(data)
46+
if err != nil {
47+
return nil, fmt.Errorf("error parsing %s as Swagger YAML: %s",
48+
path, err)
49+
}
50+
return OapiRequestValidator(swagger), nil
51+
}
52+
53+
// OapiRequestValidator is an gin middleware function which validates incoming HTTP requests
54+
// to make sure that they conform to the given OAPI 3.0 specification. When
55+
// OAPI validation fails on the request, we return an HTTP/400 with error message
56+
func OapiRequestValidator(swagger *openapi3.T) gin.HandlerFunc {
57+
return OapiRequestValidatorWithOptions(swagger, nil)
58+
}
59+
60+
// ErrorHandler is called when there is an error in validation
61+
type ErrorHandler func(c *gin.Context, message string, statusCode int)
62+
63+
// MultiErrorHandler is called when oapi returns a MultiError type
64+
type MultiErrorHandler func(openapi3.MultiError) error
65+
66+
// Options to customize request validation. These are passed through to
67+
// openapi3filter.
68+
type Options struct {
69+
ErrorHandler ErrorHandler
70+
Options openapi3filter.Options
71+
ParamDecoder openapi3filter.ContentParameterDecoder
72+
UserData interface{}
73+
MultiErrorHandler MultiErrorHandler
74+
// SilenceServersWarning allows silencing a warning for https://github.com/deepmap/oapi-codegen/issues/882 that reports when an OpenAPI spec has `spec.Servers != nil`
75+
SilenceServersWarning bool
76+
}
77+
78+
// OapiRequestValidatorWithOptions creates a validator from a swagger object, with validation options
79+
func OapiRequestValidatorWithOptions(swagger *openapi3.T, options *Options) gin.HandlerFunc {
80+
if swagger.Servers != nil && (options == nil || !options.SilenceServersWarning) {
81+
log.Println("WARN: OapiRequestValidatorWithOptions called with an OpenAPI spec that has `Servers` set. This may lead to an HTTP 400 with `no matching operation was found` when sending a valid request, as the validator performs `Host` header validation. If you're expecting `Host` header validation, you can silence this warning by setting `Options.SilenceServersWarning = true`. See https://github.com/deepmap/oapi-codegen/issues/882 for more information.")
82+
}
83+
84+
router, err := gorillamux.NewRouter(swagger)
85+
if err != nil {
86+
panic(err)
87+
}
88+
return func(c *gin.Context) {
89+
err := ValidateRequestFromContext(c, router, options)
90+
if err != nil {
91+
// using errors.Is did not work
92+
if options != nil && options.ErrorHandler != nil && err.Error() == routers.ErrPathNotFound.Error() {
93+
options.ErrorHandler(c, err.Error(), http.StatusNotFound)
94+
// in case the handler didn't internally call Abort, stop the chain
95+
c.Abort()
96+
} else if options != nil && options.ErrorHandler != nil {
97+
options.ErrorHandler(c, err.Error(), http.StatusBadRequest)
98+
// in case the handler didn't internally call Abort, stop the chain
99+
c.Abort()
100+
} else if err.Error() == routers.ErrPathNotFound.Error() {
101+
// note: i am not sure if this is the best way to handle this
102+
c.AbortWithStatusJSON(http.StatusNotFound, gin.H{"error": err.Error()})
103+
} else {
104+
// note: i am not sure if this is the best way to handle this
105+
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
106+
}
107+
}
108+
c.Next()
109+
}
110+
}
111+
112+
// ValidateRequestFromContext is called from the middleware above and actually does the work
113+
// of validating a request.
114+
func ValidateRequestFromContext(c *gin.Context, router routers.Router, options *Options) error {
115+
req := c.Request
116+
route, pathParams, err := router.FindRoute(req)
117+
118+
// We failed to find a matching route for the request.
119+
if err != nil {
120+
switch e := err.(type) {
121+
case *routers.RouteError:
122+
// We've got a bad request, the path requested doesn't match
123+
// either server, or path, or something.
124+
return errors.New(e.Reason)
125+
default:
126+
// This should never happen today, but if our upstream code changes,
127+
// we don't want to crash the server, so handle the unexpected error.
128+
return fmt.Errorf("error validating route: %s", err.Error())
129+
}
130+
}
131+
132+
validationInput := &openapi3filter.RequestValidationInput{
133+
Request: req,
134+
PathParams: pathParams,
135+
Route: route,
136+
}
137+
138+
// Pass the gin context into the request validator, so that any callbacks
139+
// which it invokes make it available.
140+
requestContext := context.WithValue(context.Background(), GinContextKey, c) //nolint:staticcheck
141+
142+
if options != nil {
143+
validationInput.Options = &options.Options
144+
validationInput.ParamDecoder = options.ParamDecoder
145+
requestContext = context.WithValue(requestContext, UserDataKey, options.UserData) //nolint:staticcheck
146+
}
147+
148+
err = openapi3filter.ValidateRequest(requestContext, validationInput)
149+
if err != nil {
150+
me := openapi3.MultiError{}
151+
if errors.As(err, &me) {
152+
errFunc := getMultiErrorHandlerFromOptions(options)
153+
return errFunc(me)
154+
}
155+
156+
switch e := err.(type) {
157+
case *openapi3filter.RequestError:
158+
// We've got a bad request
159+
// Split up the verbose error by lines and return the first one
160+
// openapi errors seem to be multi-line with a decent message on the first
161+
errorLines := strings.Split(e.Error(), "\n")
162+
return fmt.Errorf("error in openapi3filter.RequestError: %s", errorLines[0])
163+
case *openapi3filter.SecurityRequirementsError:
164+
return fmt.Errorf("error in openapi3filter.SecurityRequirementsError: %s", e.Error())
165+
default:
166+
// This should never happen today, but if our upstream code changes,
167+
// we don't want to crash the server, so handle the unexpected error.
168+
return fmt.Errorf("error validating request: %w", err)
169+
}
170+
}
171+
return nil
172+
}
173+
174+
// GetGinContext gets the echo context from within requests. It returns
175+
// nil if not found or wrong type.
176+
func GetGinContext(c context.Context) *gin.Context {
177+
iface := c.Value(GinContextKey)
178+
if iface == nil {
179+
return nil
180+
}
181+
ginCtx, ok := iface.(*gin.Context)
182+
if !ok {
183+
return nil
184+
}
185+
return ginCtx
186+
}
187+
188+
func GetUserData(c context.Context) interface{} {
189+
return c.Value(UserDataKey)
190+
}
191+
192+
// attempt to get the MultiErrorHandler from the options. If it is not set,
193+
// return a default handler
194+
func getMultiErrorHandlerFromOptions(options *Options) MultiErrorHandler {
195+
if options == nil {
196+
return defaultMultiErrorHandler
197+
}
198+
199+
if options.MultiErrorHandler == nil {
200+
return defaultMultiErrorHandler
201+
}
202+
203+
return options.MultiErrorHandler
204+
}
205+
206+
// defaultMultiErrorHandler returns a StatusBadRequest (400) and a list
207+
// of all of the errors. This method is called if there are no other
208+
// methods defined on the options.
209+
func defaultMultiErrorHandler(me openapi3.MultiError) error {
210+
return fmt.Errorf("multiple errors encountered: %s", me)
211+
}

0 commit comments

Comments
 (0)