| 
 | 1 | +// Copyright 2021 DeepMap, Inc.  | 
 | 2 | +//  | 
 | 3 | +// Licensed under the Apache License, Version 2.0 (the "License");  | 
 | 4 | +// you may not use this file except in compliance with the License.  | 
 | 5 | +// You may obtain a copy of the License at  | 
 | 6 | +//  | 
 | 7 | +// http://www.apache.org/licenses/LICENSE-2.0  | 
 | 8 | +//  | 
 | 9 | +// Unless required by applicable law or agreed to in writing, software  | 
 | 10 | +// distributed under the License is distributed on an "AS IS" BASIS,  | 
 | 11 | +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.  | 
 | 12 | +// See the License for the specific language governing permissions and  | 
 | 13 | +// limitations under the License.  | 
 | 14 | + | 
 | 15 | +package middleware  | 
 | 16 | + | 
 | 17 | +import (  | 
 | 18 | +	"context"  | 
 | 19 | +	"errors"  | 
 | 20 | +	"fmt"  | 
 | 21 | +	"log"  | 
 | 22 | +	"net/http"  | 
 | 23 | +	"os"  | 
 | 24 | +	"strings"  | 
 | 25 | + | 
 | 26 | +	"github.com/getkin/kin-openapi/openapi3"  | 
 | 27 | +	"github.com/getkin/kin-openapi/openapi3filter"  | 
 | 28 | +	"github.com/getkin/kin-openapi/routers"  | 
 | 29 | +	"github.com/getkin/kin-openapi/routers/gorillamux"  | 
 | 30 | +	"github.com/gin-gonic/gin"  | 
 | 31 | +)  | 
 | 32 | + | 
 | 33 | +const (  | 
 | 34 | +	GinContextKey = "oapi-codegen/gin-context"  | 
 | 35 | +	UserDataKey   = "oapi-codegen/user-data"  | 
 | 36 | +)  | 
 | 37 | + | 
 | 38 | +// OapiValidatorFromYamlFile creates a validator middleware from a YAML file path  | 
 | 39 | +func OapiValidatorFromYamlFile(path string) (gin.HandlerFunc, error) {  | 
 | 40 | +	data, err := os.ReadFile(path)  | 
 | 41 | +	if err != nil {  | 
 | 42 | +		return nil, fmt.Errorf("error reading %s: %s", path, err)  | 
 | 43 | +	}  | 
 | 44 | + | 
 | 45 | +	swagger, err := openapi3.NewLoader().LoadFromData(data)  | 
 | 46 | +	if err != nil {  | 
 | 47 | +		return nil, fmt.Errorf("error parsing %s as Swagger YAML: %s",  | 
 | 48 | +			path, err)  | 
 | 49 | +	}  | 
 | 50 | +	return OapiRequestValidator(swagger), nil  | 
 | 51 | +}  | 
 | 52 | + | 
 | 53 | +// OapiRequestValidator is an gin middleware function which validates incoming HTTP requests  | 
 | 54 | +// to make sure that they conform to the given OAPI 3.0 specification. When  | 
 | 55 | +// OAPI validation fails on the request, we return an HTTP/400 with error message  | 
 | 56 | +func OapiRequestValidator(swagger *openapi3.T) gin.HandlerFunc {  | 
 | 57 | +	return OapiRequestValidatorWithOptions(swagger, nil)  | 
 | 58 | +}  | 
 | 59 | + | 
 | 60 | +// ErrorHandler is called when there is an error in validation  | 
 | 61 | +type ErrorHandler func(c *gin.Context, message string, statusCode int)  | 
 | 62 | + | 
 | 63 | +// MultiErrorHandler is called when oapi returns a MultiError type  | 
 | 64 | +type MultiErrorHandler func(openapi3.MultiError) error  | 
 | 65 | + | 
 | 66 | +// Options to customize request validation. These are passed through to  | 
 | 67 | +// openapi3filter.  | 
 | 68 | +type Options struct {  | 
 | 69 | +	ErrorHandler      ErrorHandler  | 
 | 70 | +	Options           openapi3filter.Options  | 
 | 71 | +	ParamDecoder      openapi3filter.ContentParameterDecoder  | 
 | 72 | +	UserData          interface{}  | 
 | 73 | +	MultiErrorHandler MultiErrorHandler  | 
 | 74 | +	// SilenceServersWarning allows silencing a warning for https://github.com/deepmap/oapi-codegen/issues/882 that reports when an OpenAPI spec has `spec.Servers != nil`  | 
 | 75 | +	SilenceServersWarning bool  | 
 | 76 | +}  | 
 | 77 | + | 
 | 78 | +// OapiRequestValidatorWithOptions creates a validator from a swagger object, with validation options  | 
 | 79 | +func OapiRequestValidatorWithOptions(swagger *openapi3.T, options *Options) gin.HandlerFunc {  | 
 | 80 | +	if swagger.Servers != nil && (options == nil || !options.SilenceServersWarning) {  | 
 | 81 | +		log.Println("WARN: OapiRequestValidatorWithOptions called with an OpenAPI spec that has `Servers` set. This may lead to an HTTP 400 with `no matching operation was found` when sending a valid request, as the validator performs `Host` header validation. If you're expecting `Host` header validation, you can silence this warning by setting `Options.SilenceServersWarning = true`. See https://github.com/deepmap/oapi-codegen/issues/882 for more information.")  | 
 | 82 | +	}  | 
 | 83 | + | 
 | 84 | +	router, err := gorillamux.NewRouter(swagger)  | 
 | 85 | +	if err != nil {  | 
 | 86 | +		panic(err)  | 
 | 87 | +	}  | 
 | 88 | +	return func(c *gin.Context) {  | 
 | 89 | +		err := ValidateRequestFromContext(c, router, options)  | 
 | 90 | +		if err != nil {  | 
 | 91 | +			// using errors.Is did not work  | 
 | 92 | +			if options != nil && options.ErrorHandler != nil && err.Error() == routers.ErrPathNotFound.Error() {  | 
 | 93 | +				options.ErrorHandler(c, err.Error(), http.StatusNotFound)  | 
 | 94 | +				// in case the handler didn't internally call Abort, stop the chain  | 
 | 95 | +				c.Abort()  | 
 | 96 | +			} else if options != nil && options.ErrorHandler != nil {  | 
 | 97 | +					options.ErrorHandler(c, err.Error(), http.StatusBadRequest)  | 
 | 98 | +					// in case the handler didn't internally call Abort, stop the chain  | 
 | 99 | +					c.Abort()  | 
 | 100 | +			} else if err.Error() == routers.ErrPathNotFound.Error() {  | 
 | 101 | +				// note: i am not sure if this is the best way to handle this  | 
 | 102 | +				c.AbortWithStatusJSON(http.StatusNotFound, gin.H{"error": err.Error()})  | 
 | 103 | +			} else {  | 
 | 104 | +				// note: i am not sure if this is the best way to handle this  | 
 | 105 | +				c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})  | 
 | 106 | +			}  | 
 | 107 | +		}  | 
 | 108 | +		c.Next()  | 
 | 109 | +	}  | 
 | 110 | +}  | 
 | 111 | + | 
 | 112 | +// ValidateRequestFromContext is called from the middleware above and actually does the work  | 
 | 113 | +// of validating a request.  | 
 | 114 | +func ValidateRequestFromContext(c *gin.Context, router routers.Router, options *Options) error {  | 
 | 115 | +	req := c.Request  | 
 | 116 | +	route, pathParams, err := router.FindRoute(req)  | 
 | 117 | + | 
 | 118 | +	// We failed to find a matching route for the request.  | 
 | 119 | +	if err != nil {  | 
 | 120 | +		switch e := err.(type) {  | 
 | 121 | +		case *routers.RouteError:  | 
 | 122 | +			// We've got a bad request, the path requested doesn't match  | 
 | 123 | +			// either server, or path, or something.  | 
 | 124 | +			return errors.New(e.Reason)  | 
 | 125 | +		default:  | 
 | 126 | +			// This should never happen today, but if our upstream code changes,  | 
 | 127 | +			// we don't want to crash the server, so handle the unexpected error.  | 
 | 128 | +			return fmt.Errorf("error validating route: %s", err.Error())  | 
 | 129 | +		}  | 
 | 130 | +	}  | 
 | 131 | + | 
 | 132 | +	validationInput := &openapi3filter.RequestValidationInput{  | 
 | 133 | +		Request:    req,  | 
 | 134 | +		PathParams: pathParams,  | 
 | 135 | +		Route:      route,  | 
 | 136 | +	}  | 
 | 137 | + | 
 | 138 | +	// Pass the gin context into the request validator, so that any callbacks  | 
 | 139 | +	// which it invokes make it available.  | 
 | 140 | +	requestContext := context.WithValue(context.Background(), GinContextKey, c) //nolint:staticcheck  | 
 | 141 | + | 
 | 142 | +	if options != nil {  | 
 | 143 | +		validationInput.Options = &options.Options  | 
 | 144 | +		validationInput.ParamDecoder = options.ParamDecoder  | 
 | 145 | +		requestContext = context.WithValue(requestContext, UserDataKey, options.UserData) //nolint:staticcheck  | 
 | 146 | +	}  | 
 | 147 | + | 
 | 148 | +	err = openapi3filter.ValidateRequest(requestContext, validationInput)  | 
 | 149 | +	if err != nil {  | 
 | 150 | +		me := openapi3.MultiError{}  | 
 | 151 | +		if errors.As(err, &me) {  | 
 | 152 | +			errFunc := getMultiErrorHandlerFromOptions(options)  | 
 | 153 | +			return errFunc(me)  | 
 | 154 | +		}  | 
 | 155 | + | 
 | 156 | +		switch e := err.(type) {  | 
 | 157 | +		case *openapi3filter.RequestError:  | 
 | 158 | +			// We've got a bad request  | 
 | 159 | +			// Split up the verbose error by lines and return the first one  | 
 | 160 | +			// openapi errors seem to be multi-line with a decent message on the first  | 
 | 161 | +			errorLines := strings.Split(e.Error(), "\n")  | 
 | 162 | +			return fmt.Errorf("error in openapi3filter.RequestError: %s", errorLines[0])  | 
 | 163 | +		case *openapi3filter.SecurityRequirementsError:  | 
 | 164 | +			return fmt.Errorf("error in openapi3filter.SecurityRequirementsError: %s", e.Error())  | 
 | 165 | +		default:  | 
 | 166 | +			// This should never happen today, but if our upstream code changes,  | 
 | 167 | +			// we don't want to crash the server, so handle the unexpected error.  | 
 | 168 | +			return fmt.Errorf("error validating request: %w", err)  | 
 | 169 | +		}  | 
 | 170 | +	}  | 
 | 171 | +	return nil  | 
 | 172 | +}  | 
 | 173 | + | 
 | 174 | +// GetGinContext gets the echo context from within requests. It returns  | 
 | 175 | +// nil if not found or wrong type.  | 
 | 176 | +func GetGinContext(c context.Context) *gin.Context {  | 
 | 177 | +	iface := c.Value(GinContextKey)  | 
 | 178 | +	if iface == nil {  | 
 | 179 | +		return nil  | 
 | 180 | +	}  | 
 | 181 | +	ginCtx, ok := iface.(*gin.Context)  | 
 | 182 | +	if !ok {  | 
 | 183 | +		return nil  | 
 | 184 | +	}  | 
 | 185 | +	return ginCtx  | 
 | 186 | +}  | 
 | 187 | + | 
 | 188 | +func GetUserData(c context.Context) interface{} {  | 
 | 189 | +	return c.Value(UserDataKey)  | 
 | 190 | +}  | 
 | 191 | + | 
 | 192 | +// attempt to get the MultiErrorHandler from the options. If it is not set,  | 
 | 193 | +// return a default handler  | 
 | 194 | +func getMultiErrorHandlerFromOptions(options *Options) MultiErrorHandler {  | 
 | 195 | +	if options == nil {  | 
 | 196 | +		return defaultMultiErrorHandler  | 
 | 197 | +	}  | 
 | 198 | + | 
 | 199 | +	if options.MultiErrorHandler == nil {  | 
 | 200 | +		return defaultMultiErrorHandler  | 
 | 201 | +	}  | 
 | 202 | + | 
 | 203 | +	return options.MultiErrorHandler  | 
 | 204 | +}  | 
 | 205 | + | 
 | 206 | +// defaultMultiErrorHandler returns a StatusBadRequest (400) and a list  | 
 | 207 | +// of all of the errors. This method is called if there are no other  | 
 | 208 | +// methods defined on the options.  | 
 | 209 | +func defaultMultiErrorHandler(me openapi3.MultiError) error {  | 
 | 210 | +	return fmt.Errorf("multiple errors encountered: %s", me)  | 
 | 211 | +}  | 
0 commit comments