diff --git a/_testdata/positive/client_options.json b/_testdata/positive/client_options.json new file mode 100644 index 000000000..9408fdab5 --- /dev/null +++ b/_testdata/positive/client_options.json @@ -0,0 +1,31 @@ +{ + "openapi": "3.0.3", + "paths": { + "/foo": { + "get": { + "operationId": "Foo", + "parameters": [ + { + "name": "body", + "in": "query", + "required": true, + "schema": { + "type": "string" + } + } + ], + "responses": { + "200": { + "content": { + "application/json": { + "schema": { + "type": "string" + } + } + } + } + } + } + } + } +} diff --git a/gen/_template/client.tmpl b/gen/_template/client.tmpl index 825b80385..c46b56ff6 100644 --- a/gen/_template/client.tmpl +++ b/gen/_template/client.tmpl @@ -40,6 +40,13 @@ func WithRequestClient(client ht.Client) RequestOption { } } +// WithServerURL sets client for request. +func WithServerURL(u *url.URL) RequestOption { + return func(cfg *requestConfig) { + cfg.ServerURL = u + } +} + // WithEditRequest sets function to edit request. func WithEditRequest(fn func(req *http.Request) error) RequestOption { return func(cfg *requestConfig) { @@ -205,6 +212,7 @@ func (c *{{ if $op.WebhookInfo }}Webhook{{ end }}Client) {{ $op.Name }}(ctx cont {{- if $op.WebhookInfo }},targetURL{{ end -}} {{- if $op.Request }},request{{ end -}} {{- if $op.Params }},params{{ end -}} + {{- if $cfg.RequestOptionsEnabled }},options...{{ end -}} ) return {{ if $op.Responses.DoPass }}res,{{ end }} err } diff --git a/internal/integration/_config/client_options.yml b/internal/integration/_config/client_options.yml new file mode 100644 index 000000000..046fafa85 --- /dev/null +++ b/internal/integration/_config/client_options.yml @@ -0,0 +1,4 @@ +generator: + features: + enable: + - "client/request/options" diff --git a/internal/integration/client_options_test.go b/internal/integration/client_options_test.go new file mode 100644 index 000000000..e3c55bc0d --- /dev/null +++ b/internal/integration/client_options_test.go @@ -0,0 +1,109 @@ +package integration + +import ( + "context" + "errors" + "io" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "testing" + + ht "github.com/ogen-go/ogen/http" + api "github.com/ogen-go/ogen/internal/integration/test_client_options" + + "github.com/stretchr/testify/require" +) + +type clientOptionsHandler struct{} + +// Foo implements api.Handler. +func (c *clientOptionsHandler) Foo(ctx context.Context, params api.FooParams) (string, error) { + return params.Body, nil +} + +var _ api.Handler = (*clientOptionsHandler)(nil) + +func TestClientOptions(t *testing.T) { + h, err := api.NewServer(&clientOptionsHandler{}) + require.NoError(t, err) + + s := httptest.NewServer(h) + defer s.Close() + + t.Run("WithRequestClient", func(t *testing.T) { + ctx := context.Background() + + c, err := api.NewClient(s.URL) + require.NoError(t, err) + + op := api.WithRequestClient(new(testFaultyClient)) + _, err = c.Foo(ctx, api.FooParams{Body: "test"}, op) + require.ErrorContains(t, err, `test faulty client`) + }) + t.Run("WithServerURL", func(t *testing.T) { + ctx := context.Background() + + c, err := api.NewClient(`http://completly-wrong-url.foo`) + require.NoError(t, err) + + u, err := url.Parse(s.URL) + require.NoError(t, err) + + op := api.WithServerURL(u) + resp, err := c.Foo(ctx, api.FooParams{Body: "test"}, op) + require.NoError(t, err) + require.Equal(t, "test", resp) + }) + t.Run("WithEditRequest", func(t *testing.T) { + ctx := context.Background() + + c, err := api.NewClient(s.URL) + require.NoError(t, err) + + op := api.WithEditRequest(func(req *http.Request) error { + q := req.URL.Query() + q.Set("body", "request-override") + req.URL.RawQuery = q.Encode() + return nil + }) + resp, err := c.Foo(ctx, api.FooParams{Body: "test"}, op) + require.NoError(t, err) + require.Equal(t, "request-override", resp) + + op = api.WithEditRequest(func(*http.Request) error { + return errors.New("request editor error") + }) + _, err = c.Foo(ctx, api.FooParams{Body: "test"}, op) + require.ErrorContains(t, err, `request editor error`) + }) + t.Run("WithEditResponse", func(t *testing.T) { + ctx := context.Background() + + c, err := api.NewClient(s.URL) + require.NoError(t, err) + + op := api.WithEditResponse(func(resp *http.Response) error { + resp.Body = io.NopCloser(strings.NewReader(`"response-override"`)) + return nil + }) + resp, err := c.Foo(ctx, api.FooParams{Body: "test"}, op) + require.NoError(t, err) + require.Equal(t, "response-override", resp) + + op = api.WithEditResponse(func(*http.Response) error { + return errors.New("response editor error") + }) + _, err = c.Foo(ctx, api.FooParams{Body: "test"}, op) + require.ErrorContains(t, err, `response editor error`) + }) +} + +type testFaultyClient struct{} + +var _ ht.Client = (*testFaultyClient)(nil) + +func (f *testFaultyClient) Do(req *http.Request) (*http.Response, error) { + return nil, errors.New("test faulty client") +} diff --git a/internal/integration/generate.go b/internal/integration/generate.go index e5173e16c..02531620a 100644 --- a/internal/integration/generate.go +++ b/internal/integration/generate.go @@ -30,6 +30,7 @@ package integration //go:generate go run ../../cmd/ogen -v --clean --config _config/allOf.yml --target test_allof ../../_testdata/positive/allOf.yml //go:generate go run ../../cmd/ogen -v --clean --config _config/anyOf.yml --target test_anyof ../../_testdata/positive/anyOf.json //go:generate go run ../../cmd/ogen -v --clean --config _config/additionalPropertiesPatternProperties.yml --target test_additionalpropertiespatternproperties ../../_testdata/positive/additionalPropertiesPatternProperties.yml +//go:generate go run ../../cmd/ogen -v --clean --config _config/client_options.yml --target test_client_options ../../_testdata/positive/client_options.json // //go:generate go run ../../cmd/ogen -v --clean -target test_enum_naming ../../_testdata/positive/enum_naming.yml //go:generate go run ../../cmd/ogen -v --clean -target test_naming_extensions ../../_testdata/positive/naming_extensions.json diff --git a/internal/integration/test_client_options/oas_cfg_gen.go b/internal/integration/test_client_options/oas_cfg_gen.go new file mode 100644 index 000000000..fc3ff3449 --- /dev/null +++ b/internal/integration/test_client_options/oas_cfg_gen.go @@ -0,0 +1,283 @@ +// Code generated by ogen, DO NOT EDIT. + +package api + +import ( + "net/http" + + "go.opentelemetry.io/otel" + "go.opentelemetry.io/otel/metric" + "go.opentelemetry.io/otel/trace" + + ht "github.com/ogen-go/ogen/http" + "github.com/ogen-go/ogen/middleware" + "github.com/ogen-go/ogen/ogenerrors" + "github.com/ogen-go/ogen/otelogen" +) + +var ( + // Allocate option closure once. + clientSpanKind = trace.WithSpanKind(trace.SpanKindClient) + // Allocate option closure once. + serverSpanKind = trace.WithSpanKind(trace.SpanKindServer) +) + +type ( + optionFunc[C any] func(*C) + otelOptionFunc func(*otelConfig) +) + +type otelConfig struct { + TracerProvider trace.TracerProvider + Tracer trace.Tracer + MeterProvider metric.MeterProvider + Meter metric.Meter +} + +func (cfg *otelConfig) initOTEL() { + if cfg.TracerProvider == nil { + cfg.TracerProvider = otel.GetTracerProvider() + } + if cfg.MeterProvider == nil { + cfg.MeterProvider = otel.GetMeterProvider() + } + cfg.Tracer = cfg.TracerProvider.Tracer(otelogen.Name, + trace.WithInstrumentationVersion(otelogen.SemVersion()), + ) + cfg.Meter = cfg.MeterProvider.Meter(otelogen.Name, + metric.WithInstrumentationVersion(otelogen.SemVersion()), + ) +} + +// ErrorHandler is error handler. +type ErrorHandler = ogenerrors.ErrorHandler + +type serverConfig struct { + otelConfig + NotFound http.HandlerFunc + MethodNotAllowed func(w http.ResponseWriter, r *http.Request, allowed string) + ErrorHandler ErrorHandler + Prefix string + Middleware Middleware + MaxMultipartMemory int64 +} + +// ServerOption is server config option. +type ServerOption interface { + applyServer(*serverConfig) +} + +var _ ServerOption = (optionFunc[serverConfig])(nil) + +func (o optionFunc[C]) applyServer(c *C) { + o(c) +} + +var _ ServerOption = (otelOptionFunc)(nil) + +func (o otelOptionFunc) applyServer(c *serverConfig) { + o(&c.otelConfig) +} + +func newServerConfig(opts ...ServerOption) serverConfig { + cfg := serverConfig{ + NotFound: http.NotFound, + MethodNotAllowed: func(w http.ResponseWriter, r *http.Request, allowed string) { + status := http.StatusMethodNotAllowed + if r.Method == "OPTIONS" { + w.Header().Set("Access-Control-Allow-Methods", allowed) + w.Header().Set("Access-Control-Allow-Headers", "Content-Type") + status = http.StatusNoContent + } else { + w.Header().Set("Allow", allowed) + } + w.WriteHeader(status) + }, + ErrorHandler: ogenerrors.DefaultErrorHandler, + Middleware: nil, + MaxMultipartMemory: 32 << 20, // 32 MB + } + for _, opt := range opts { + opt.applyServer(&cfg) + } + cfg.initOTEL() + return cfg +} + +type baseServer struct { + cfg serverConfig + requests metric.Int64Counter + errors metric.Int64Counter + duration metric.Float64Histogram +} + +func (s baseServer) notFound(w http.ResponseWriter, r *http.Request) { + s.cfg.NotFound(w, r) +} + +func (s baseServer) notAllowed(w http.ResponseWriter, r *http.Request, allowed string) { + s.cfg.MethodNotAllowed(w, r, allowed) +} + +func (cfg serverConfig) baseServer() (s baseServer, err error) { + s = baseServer{cfg: cfg} + if s.requests, err = otelogen.ServerRequestCountCounter(s.cfg.Meter); err != nil { + return s, err + } + if s.errors, err = otelogen.ServerErrorsCountCounter(s.cfg.Meter); err != nil { + return s, err + } + if s.duration, err = otelogen.ServerDurationHistogram(s.cfg.Meter); err != nil { + return s, err + } + return s, nil +} + +type clientConfig struct { + otelConfig + Client ht.Client +} + +// ClientOption is client config option. +type ClientOption interface { + applyClient(*clientConfig) +} + +var _ ClientOption = (optionFunc[clientConfig])(nil) + +func (o optionFunc[C]) applyClient(c *C) { + o(c) +} + +var _ ClientOption = (otelOptionFunc)(nil) + +func (o otelOptionFunc) applyClient(c *clientConfig) { + o(&c.otelConfig) +} + +func newClientConfig(opts ...ClientOption) clientConfig { + cfg := clientConfig{ + Client: http.DefaultClient, + } + for _, opt := range opts { + opt.applyClient(&cfg) + } + cfg.initOTEL() + return cfg +} + +type baseClient struct { + cfg clientConfig + requests metric.Int64Counter + errors metric.Int64Counter + duration metric.Float64Histogram +} + +func (cfg clientConfig) baseClient() (c baseClient, err error) { + c = baseClient{cfg: cfg} + if c.requests, err = otelogen.ClientRequestCountCounter(c.cfg.Meter); err != nil { + return c, err + } + if c.errors, err = otelogen.ClientErrorsCountCounter(c.cfg.Meter); err != nil { + return c, err + } + if c.duration, err = otelogen.ClientDurationHistogram(c.cfg.Meter); err != nil { + return c, err + } + return c, nil +} + +// Option is config option. +type Option interface { + ServerOption + ClientOption +} + +// WithTracerProvider specifies a tracer provider to use for creating a tracer. +// +// If none is specified, the global provider is used. +func WithTracerProvider(provider trace.TracerProvider) Option { + return otelOptionFunc(func(cfg *otelConfig) { + if provider != nil { + cfg.TracerProvider = provider + } + }) +} + +// WithMeterProvider specifies a meter provider to use for creating a meter. +// +// If none is specified, the otel.GetMeterProvider() is used. +func WithMeterProvider(provider metric.MeterProvider) Option { + return otelOptionFunc(func(cfg *otelConfig) { + if provider != nil { + cfg.MeterProvider = provider + } + }) +} + +// WithClient specifies http client to use. +func WithClient(client ht.Client) ClientOption { + return optionFunc[clientConfig](func(cfg *clientConfig) { + if client != nil { + cfg.Client = client + } + }) +} + +// WithNotFound specifies Not Found handler to use. +func WithNotFound(notFound http.HandlerFunc) ServerOption { + return optionFunc[serverConfig](func(cfg *serverConfig) { + if notFound != nil { + cfg.NotFound = notFound + } + }) +} + +// WithMethodNotAllowed specifies Method Not Allowed handler to use. +func WithMethodNotAllowed(methodNotAllowed func(w http.ResponseWriter, r *http.Request, allowed string)) ServerOption { + return optionFunc[serverConfig](func(cfg *serverConfig) { + if methodNotAllowed != nil { + cfg.MethodNotAllowed = methodNotAllowed + } + }) +} + +// WithErrorHandler specifies error handler to use. +func WithErrorHandler(h ErrorHandler) ServerOption { + return optionFunc[serverConfig](func(cfg *serverConfig) { + if h != nil { + cfg.ErrorHandler = h + } + }) +} + +// WithPathPrefix specifies server path prefix. +func WithPathPrefix(prefix string) ServerOption { + return optionFunc[serverConfig](func(cfg *serverConfig) { + cfg.Prefix = prefix + }) +} + +// WithMiddleware specifies middlewares to use. +func WithMiddleware(m ...Middleware) ServerOption { + return optionFunc[serverConfig](func(cfg *serverConfig) { + switch len(m) { + case 0: + cfg.Middleware = nil + case 1: + cfg.Middleware = m[0] + default: + cfg.Middleware = middleware.ChainMiddlewares(m...) + } + }) +} + +// WithMaxMultipartMemory specifies limit of memory for storing file parts. +// File parts which can't be stored in memory will be stored on disk in temporary files. +func WithMaxMultipartMemory(max int64) ServerOption { + return optionFunc[serverConfig](func(cfg *serverConfig) { + if max > 0 { + cfg.MaxMultipartMemory = max + } + }) +} diff --git a/internal/integration/test_client_options/oas_client_gen.go b/internal/integration/test_client_options/oas_client_gen.go new file mode 100644 index 000000000..0a011f4f8 --- /dev/null +++ b/internal/integration/test_client_options/oas_client_gen.go @@ -0,0 +1,224 @@ +// Code generated by ogen, DO NOT EDIT. + +package api + +import ( + "context" + "net/http" + "net/url" + "strings" + "time" + + "github.com/go-faster/errors" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/codes" + "go.opentelemetry.io/otel/metric" + semconv "go.opentelemetry.io/otel/semconv/v1.26.0" + "go.opentelemetry.io/otel/trace" + + "github.com/ogen-go/ogen/conv" + ht "github.com/ogen-go/ogen/http" + "github.com/ogen-go/ogen/otelogen" + "github.com/ogen-go/ogen/uri" +) + +type requestConfig struct { + Client ht.Client + ServerURL *url.URL + EditRequest func(req *http.Request) error + EditResponse func(resp *http.Response) error +} + +func (cfg *requestConfig) setDefaults(c baseClient) { + if cfg.Client == nil { + cfg.Client = c.cfg.Client + } +} + +func (cfg *requestConfig) onRequest(req *http.Request) error { + if fn := cfg.EditRequest; fn != nil { + return fn(req) + } + return nil +} + +func (cfg *requestConfig) onResponse(resp *http.Response) error { + if fn := cfg.EditResponse; fn != nil { + return fn(resp) + } + return nil +} + +// RequestOption defines options for request. +type RequestOption func(cfg *requestConfig) + +// WithRequestClient sets client for request. +func WithRequestClient(client ht.Client) RequestOption { + return func(cfg *requestConfig) { + cfg.Client = client + } +} + +// WithServerURL sets client for request. +func WithServerURL(u *url.URL) RequestOption { + return func(cfg *requestConfig) { + cfg.ServerURL = u + } +} + +// WithEditRequest sets function to edit request. +func WithEditRequest(fn func(req *http.Request) error) RequestOption { + return func(cfg *requestConfig) { + cfg.EditRequest = fn + } +} + +// WithEditResponse sets function to edit response. +func WithEditResponse(fn func(resp *http.Response) error) RequestOption { + return func(cfg *requestConfig) { + cfg.EditResponse = fn + } +} + +func trimTrailingSlashes(u *url.URL) { + u.Path = strings.TrimRight(u.Path, "/") + u.RawPath = strings.TrimRight(u.RawPath, "/") +} + +// Invoker invokes operations described by OpenAPI v3 specification. +type Invoker interface { + // Foo invokes Foo operation. + // + // GET /foo + Foo(ctx context.Context, params FooParams, options ...RequestOption) (string, error) +} + +// Client implements OAS client. +type Client struct { + serverURL *url.URL + baseClient +} + +// NewClient initializes new Client defined by OAS. +func NewClient(serverURL string, opts ...ClientOption) (*Client, error) { + u, err := url.Parse(serverURL) + if err != nil { + return nil, err + } + trimTrailingSlashes(u) + + c, err := newClientConfig(opts...).baseClient() + if err != nil { + return nil, err + } + return &Client{ + serverURL: u, + baseClient: c, + }, nil +} + +// Foo invokes Foo operation. +// +// GET /foo +func (c *Client) Foo(ctx context.Context, params FooParams, options ...RequestOption) (string, error) { + res, err := c.sendFoo(ctx, params, options...) + return res, err +} + +func (c *Client) sendFoo(ctx context.Context, params FooParams, requestOptions ...RequestOption) (res string, err error) { + otelAttrs := []attribute.KeyValue{ + otelogen.OperationID("Foo"), + semconv.HTTPRequestMethodKey.String("GET"), + semconv.HTTPRouteKey.String("/foo"), + } + + // Run stopwatch. + startTime := time.Now() + defer func() { + // Use floating point division here for higher precision (instead of Millisecond method). + elapsedDuration := time.Since(startTime) + c.duration.Record(ctx, float64(elapsedDuration)/float64(time.Millisecond), metric.WithAttributes(otelAttrs...)) + }() + + // Increment request counter. + c.requests.Add(ctx, 1, metric.WithAttributes(otelAttrs...)) + + // Start a span for this request. + ctx, span := c.cfg.Tracer.Start(ctx, FooOperation, + trace.WithAttributes(otelAttrs...), + clientSpanKind, + ) + // Track stage for error reporting. + var stage string + defer func() { + if err != nil { + span.RecordError(err) + span.SetStatus(codes.Error, stage) + c.errors.Add(ctx, 1, metric.WithAttributes(otelAttrs...)) + } + span.End() + }() + + var reqCfg requestConfig + reqCfg.setDefaults(c.baseClient) + for _, o := range requestOptions { + o(&reqCfg) + } + + stage = "BuildURL" + u := c.serverURL + if override := reqCfg.ServerURL; override != nil { + u = override + } + u = uri.Clone(u) + var pathParts [1]string + pathParts[0] = "/foo" + uri.AddPathParts(u, pathParts[:]...) + + stage = "EncodeQueryParams" + q := uri.NewQueryEncoder() + { + // Encode "body" parameter. + cfg := uri.QueryParameterEncodingConfig{ + Name: "body", + Style: uri.QueryStyleForm, + Explode: true, + } + + if err := q.EncodeParam(cfg, func(e uri.Encoder) error { + return e.EncodeValue(conv.StringToString(params.Body)) + }); err != nil { + return res, errors.Wrap(err, "encode query") + } + } + u.RawQuery = q.Values().Encode() + + stage = "EncodeRequest" + r, err := ht.NewRequest(ctx, "GET", u) + if err != nil { + return res, errors.Wrap(err, "create request") + } + + if err := reqCfg.onRequest(r); err != nil { + return res, errors.Wrap(err, "edit request") + } + + stage = "SendRequest" + resp, err := reqCfg.Client.Do(r) + if err != nil { + return res, errors.Wrap(err, "do request") + } + defer resp.Body.Close() + + if err := reqCfg.onResponse(resp); err != nil { + return res, errors.Wrap(err, "edit response") + } + + stage = "DecodeResponse" + result, err := decodeFooResponse(resp) + if err != nil { + return res, errors.Wrap(err, "decode response") + } + + return result, nil +} diff --git a/internal/integration/test_client_options/oas_handlers_gen.go b/internal/integration/test_client_options/oas_handlers_gen.go new file mode 100644 index 000000000..ede0f6090 --- /dev/null +++ b/internal/integration/test_client_options/oas_handlers_gen.go @@ -0,0 +1,167 @@ +// Code generated by ogen, DO NOT EDIT. + +package api + +import ( + "context" + "net/http" + "time" + + "github.com/go-faster/errors" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/codes" + "go.opentelemetry.io/otel/metric" + semconv "go.opentelemetry.io/otel/semconv/v1.26.0" + "go.opentelemetry.io/otel/trace" + + ht "github.com/ogen-go/ogen/http" + "github.com/ogen-go/ogen/middleware" + "github.com/ogen-go/ogen/ogenerrors" + "github.com/ogen-go/ogen/otelogen" +) + +type codeRecorder struct { + http.ResponseWriter + status int +} + +func (c *codeRecorder) WriteHeader(status int) { + c.status = status + c.ResponseWriter.WriteHeader(status) +} + +// handleFooRequest handles Foo operation. +// +// GET /foo +func (s *Server) handleFooRequest(args [0]string, argsEscaped bool, w http.ResponseWriter, r *http.Request) { + statusWriter := &codeRecorder{ResponseWriter: w} + w = statusWriter + otelAttrs := []attribute.KeyValue{ + otelogen.OperationID("Foo"), + semconv.HTTPRequestMethodKey.String("GET"), + semconv.HTTPRouteKey.String("/foo"), + } + + // Start a span for this request. + ctx, span := s.cfg.Tracer.Start(r.Context(), FooOperation, + trace.WithAttributes(otelAttrs...), + serverSpanKind, + ) + defer span.End() + + // Add Labeler to context. + labeler := &Labeler{attrs: otelAttrs} + ctx = contextWithLabeler(ctx, labeler) + + // Run stopwatch. + startTime := time.Now() + defer func() { + elapsedDuration := time.Since(startTime) + + attrSet := labeler.AttributeSet() + attrs := attrSet.ToSlice() + code := statusWriter.status + if code != 0 { + codeAttr := semconv.HTTPResponseStatusCode(code) + attrs = append(attrs, codeAttr) + span.SetAttributes(codeAttr) + } + attrOpt := metric.WithAttributes(attrs...) + + // Increment request counter. + s.requests.Add(ctx, 1, attrOpt) + + // Use floating point division here for higher precision (instead of Millisecond method). + s.duration.Record(ctx, float64(elapsedDuration)/float64(time.Millisecond), attrOpt) + }() + + var ( + recordError = func(stage string, err error) { + span.RecordError(err) + + // https://opentelemetry.io/docs/specs/semconv/http/http-spans/#status + // Span Status MUST be left unset if HTTP status code was in the 1xx, 2xx or 3xx ranges, + // unless there was another error (e.g., network error receiving the response body; or 3xx codes with + // max redirects exceeded), in which case status MUST be set to Error. + code := statusWriter.status + if code >= 100 && code < 500 { + span.SetStatus(codes.Error, stage) + } + + attrSet := labeler.AttributeSet() + attrs := attrSet.ToSlice() + if code != 0 { + attrs = append(attrs, semconv.HTTPResponseStatusCode(code)) + } + + s.errors.Add(ctx, 1, metric.WithAttributes(attrs...)) + } + err error + opErrContext = ogenerrors.OperationContext{ + Name: FooOperation, + ID: "Foo", + } + ) + params, err := decodeFooParams(args, argsEscaped, r) + if err != nil { + err = &ogenerrors.DecodeParamsError{ + OperationContext: opErrContext, + Err: err, + } + defer recordError("DecodeParams", err) + s.cfg.ErrorHandler(ctx, w, r, err) + return + } + + var response string + if m := s.cfg.Middleware; m != nil { + mreq := middleware.Request{ + Context: ctx, + OperationName: FooOperation, + OperationSummary: "", + OperationID: "Foo", + Body: nil, + Params: middleware.Parameters{ + { + Name: "body", + In: "query", + }: params.Body, + }, + Raw: r, + } + + type ( + Request = struct{} + Params = FooParams + Response = string + ) + response, err = middleware.HookMiddleware[ + Request, + Params, + Response, + ]( + m, + mreq, + unpackFooParams, + func(ctx context.Context, request Request, params Params) (response Response, err error) { + response, err = s.h.Foo(ctx, params) + return response, err + }, + ) + } else { + response, err = s.h.Foo(ctx, params) + } + if err != nil { + defer recordError("Internal", err) + s.cfg.ErrorHandler(ctx, w, r, err) + return + } + + if err := encodeFooResponse(response, w, span); err != nil { + defer recordError("EncodeResponse", err) + if !errors.Is(err, ht.ErrInternalServerErrorResponse) { + s.cfg.ErrorHandler(ctx, w, r, err) + } + return + } +} diff --git a/internal/integration/test_client_options/oas_labeler_gen.go b/internal/integration/test_client_options/oas_labeler_gen.go new file mode 100644 index 000000000..7e519e84e --- /dev/null +++ b/internal/integration/test_client_options/oas_labeler_gen.go @@ -0,0 +1,42 @@ +// Code generated by ogen, DO NOT EDIT. + +package api + +import ( + "context" + + "go.opentelemetry.io/otel/attribute" +) + +// Labeler is used to allow adding custom attributes to the server request metrics. +type Labeler struct { + attrs []attribute.KeyValue +} + +// Add attributes to the Labeler. +func (l *Labeler) Add(attrs ...attribute.KeyValue) { + l.attrs = append(l.attrs, attrs...) +} + +// AttributeSet returns the attributes added to the Labeler as an attribute.Set. +func (l *Labeler) AttributeSet() attribute.Set { + return attribute.NewSet(l.attrs...) +} + +type labelerContextKey struct{} + +// LabelerFromContext retrieves the Labeler from the provided context, if present. +// +// If no Labeler was found in the provided context a new, empty Labeler is returned and the second +// return value is false. In this case it is safe to use the Labeler but any attributes added to +// it will not be used. +func LabelerFromContext(ctx context.Context) (*Labeler, bool) { + if l, ok := ctx.Value(labelerContextKey{}).(*Labeler); ok { + return l, true + } + return &Labeler{}, false +} + +func contextWithLabeler(ctx context.Context, l *Labeler) context.Context { + return context.WithValue(ctx, labelerContextKey{}, l) +} diff --git a/internal/integration/test_client_options/oas_middleware_gen.go b/internal/integration/test_client_options/oas_middleware_gen.go new file mode 100644 index 000000000..6f58a1a79 --- /dev/null +++ b/internal/integration/test_client_options/oas_middleware_gen.go @@ -0,0 +1,10 @@ +// Code generated by ogen, DO NOT EDIT. + +package api + +import ( + "github.com/ogen-go/ogen/middleware" +) + +// Middleware is middleware type. +type Middleware = middleware.Middleware diff --git a/internal/integration/test_client_options/oas_operations_gen.go b/internal/integration/test_client_options/oas_operations_gen.go new file mode 100644 index 000000000..51da59a02 --- /dev/null +++ b/internal/integration/test_client_options/oas_operations_gen.go @@ -0,0 +1,10 @@ +// Code generated by ogen, DO NOT EDIT. + +package api + +// OperationName is the ogen operation name +type OperationName = string + +const ( + FooOperation OperationName = "Foo" +) diff --git a/internal/integration/test_client_options/oas_parameters_gen.go b/internal/integration/test_client_options/oas_parameters_gen.go new file mode 100644 index 000000000..f4b1bb2fa --- /dev/null +++ b/internal/integration/test_client_options/oas_parameters_gen.go @@ -0,0 +1,70 @@ +// Code generated by ogen, DO NOT EDIT. + +package api + +import ( + "net/http" + + "github.com/ogen-go/ogen/conv" + "github.com/ogen-go/ogen/middleware" + "github.com/ogen-go/ogen/ogenerrors" + "github.com/ogen-go/ogen/uri" + "github.com/ogen-go/ogen/validate" +) + +// FooParams is parameters of Foo operation. +type FooParams struct { + Body string +} + +func unpackFooParams(packed middleware.Parameters) (params FooParams) { + { + key := middleware.ParameterKey{ + Name: "body", + In: "query", + } + params.Body = packed[key].(string) + } + return params +} + +func decodeFooParams(args [0]string, argsEscaped bool, r *http.Request) (params FooParams, _ error) { + q := uri.NewQueryDecoder(r.URL.Query()) + // Decode query: body. + if err := func() error { + cfg := uri.QueryParameterDecodingConfig{ + Name: "body", + Style: uri.QueryStyleForm, + Explode: true, + } + + if err := q.HasParam(cfg); err == nil { + if err := q.DecodeParam(cfg, func(d uri.Decoder) error { + val, err := d.DecodeValue() + if err != nil { + return err + } + + c, err := conv.ToString(val) + if err != nil { + return err + } + + params.Body = c + return nil + }); err != nil { + return err + } + } else { + return validate.ErrFieldRequired + } + return nil + }(); err != nil { + return params, &ogenerrors.DecodeParamError{ + Name: "body", + In: "query", + Err: err, + } + } + return params, nil +} diff --git a/internal/integration/test_client_options/oas_request_decoders_gen.go b/internal/integration/test_client_options/oas_request_decoders_gen.go new file mode 100644 index 000000000..ae379a2db --- /dev/null +++ b/internal/integration/test_client_options/oas_request_decoders_gen.go @@ -0,0 +1,3 @@ +// Code generated by ogen, DO NOT EDIT. + +package api diff --git a/internal/integration/test_client_options/oas_request_encoders_gen.go b/internal/integration/test_client_options/oas_request_encoders_gen.go new file mode 100644 index 000000000..ae379a2db --- /dev/null +++ b/internal/integration/test_client_options/oas_request_encoders_gen.go @@ -0,0 +1,3 @@ +// Code generated by ogen, DO NOT EDIT. + +package api diff --git a/internal/integration/test_client_options/oas_response_decoders_gen.go b/internal/integration/test_client_options/oas_response_decoders_gen.go new file mode 100644 index 000000000..d69e1c3e2 --- /dev/null +++ b/internal/integration/test_client_options/oas_response_decoders_gen.go @@ -0,0 +1,58 @@ +// Code generated by ogen, DO NOT EDIT. + +package api + +import ( + "io" + "mime" + "net/http" + + "github.com/go-faster/errors" + "github.com/go-faster/jx" + + "github.com/ogen-go/ogen/ogenerrors" + "github.com/ogen-go/ogen/validate" +) + +func decodeFooResponse(resp *http.Response) (res string, _ error) { + switch resp.StatusCode { + case 200: + // Code 200. + ct, _, err := mime.ParseMediaType(resp.Header.Get("Content-Type")) + if err != nil { + return res, errors.Wrap(err, "parse media type") + } + switch { + case ct == "application/json": + buf, err := io.ReadAll(resp.Body) + if err != nil { + return res, err + } + d := jx.DecodeBytes(buf) + + var response string + if err := func() error { + v, err := d.Str() + response = string(v) + if err != nil { + return err + } + if err := d.Skip(); err != io.EOF { + return errors.New("unexpected trailing data") + } + return nil + }(); err != nil { + err = &ogenerrors.DecodeBodyError{ + ContentType: ct, + Body: buf, + Err: err, + } + return res, err + } + return response, nil + default: + return res, validate.InvalidContentType(ct) + } + } + return res, validate.UnexpectedStatusCode(resp.StatusCode) +} diff --git a/internal/integration/test_client_options/oas_response_encoders_gen.go b/internal/integration/test_client_options/oas_response_encoders_gen.go new file mode 100644 index 000000000..a400deebf --- /dev/null +++ b/internal/integration/test_client_options/oas_response_encoders_gen.go @@ -0,0 +1,26 @@ +// Code generated by ogen, DO NOT EDIT. + +package api + +import ( + "net/http" + + "github.com/go-faster/errors" + "github.com/go-faster/jx" + "go.opentelemetry.io/otel/codes" + "go.opentelemetry.io/otel/trace" +) + +func encodeFooResponse(response string, w http.ResponseWriter, span trace.Span) error { + w.Header().Set("Content-Type", "application/json; charset=utf-8") + w.WriteHeader(200) + span.SetStatus(codes.Ok, http.StatusText(200)) + + e := new(jx.Encoder) + e.Str(response) + if _, err := e.WriteTo(w); err != nil { + return errors.Wrap(err, "write") + } + + return nil +} diff --git a/internal/integration/test_client_options/oas_router_gen.go b/internal/integration/test_client_options/oas_router_gen.go new file mode 100644 index 000000000..aa88c1368 --- /dev/null +++ b/internal/integration/test_client_options/oas_router_gen.go @@ -0,0 +1,180 @@ +// Code generated by ogen, DO NOT EDIT. + +package api + +import ( + "net/http" + "net/url" + "strings" + + "github.com/ogen-go/ogen/uri" +) + +func (s *Server) cutPrefix(path string) (string, bool) { + prefix := s.cfg.Prefix + if prefix == "" { + return path, true + } + if !strings.HasPrefix(path, prefix) { + // Prefix doesn't match. + return "", false + } + // Cut prefix from the path. + return strings.TrimPrefix(path, prefix), true +} + +// ServeHTTP serves http request as defined by OpenAPI v3 specification, +// calling handler that matches the path or returning not found error. +func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { + elem := r.URL.Path + elemIsEscaped := false + if rawPath := r.URL.RawPath; rawPath != "" { + if normalized, ok := uri.NormalizeEscapedPath(rawPath); ok { + elem = normalized + elemIsEscaped = strings.ContainsRune(elem, '%') + } + } + + elem, ok := s.cutPrefix(elem) + if !ok || len(elem) == 0 { + s.notFound(w, r) + return + } + + // Static code generated router with unwrapped path search. + switch { + default: + if len(elem) == 0 { + break + } + switch elem[0] { + case '/': // Prefix: "/foo" + origElem := elem + if l := len("/foo"); len(elem) >= l && elem[0:l] == "/foo" { + elem = elem[l:] + } else { + break + } + + if len(elem) == 0 { + // Leaf node. + switch r.Method { + case "GET": + s.handleFooRequest([0]string{}, elemIsEscaped, w, r) + default: + s.notAllowed(w, r, "GET") + } + + return + } + + elem = origElem + } + } + s.notFound(w, r) +} + +// Route is route object. +type Route struct { + name string + summary string + operationID string + pathPattern string + count int + args [0]string +} + +// Name returns ogen operation name. +// +// It is guaranteed to be unique and not empty. +func (r Route) Name() string { + return r.name +} + +// Summary returns OpenAPI summary. +func (r Route) Summary() string { + return r.summary +} + +// OperationID returns OpenAPI operationId. +func (r Route) OperationID() string { + return r.operationID +} + +// PathPattern returns OpenAPI path. +func (r Route) PathPattern() string { + return r.pathPattern +} + +// Args returns parsed arguments. +func (r Route) Args() []string { + return r.args[:r.count] +} + +// FindRoute finds Route for given method and path. +// +// Note: this method does not unescape path or handle reserved characters in path properly. Use FindPath instead. +func (s *Server) FindRoute(method, path string) (Route, bool) { + return s.FindPath(method, &url.URL{Path: path}) +} + +// FindPath finds Route for given method and URL. +func (s *Server) FindPath(method string, u *url.URL) (r Route, _ bool) { + var ( + elem = u.Path + args = r.args + ) + if rawPath := u.RawPath; rawPath != "" { + if normalized, ok := uri.NormalizeEscapedPath(rawPath); ok { + elem = normalized + } + defer func() { + for i, arg := range r.args[:r.count] { + if unescaped, err := url.PathUnescape(arg); err == nil { + r.args[i] = unescaped + } + } + }() + } + + elem, ok := s.cutPrefix(elem) + if !ok { + return r, false + } + + // Static code generated router with unwrapped path search. + switch { + default: + if len(elem) == 0 { + break + } + switch elem[0] { + case '/': // Prefix: "/foo" + origElem := elem + if l := len("/foo"); len(elem) >= l && elem[0:l] == "/foo" { + elem = elem[l:] + } else { + break + } + + if len(elem) == 0 { + // Leaf node. + switch method { + case "GET": + r.name = FooOperation + r.summary = "" + r.operationID = "Foo" + r.pathPattern = "/foo" + r.args = args + r.count = 0 + return r, true + default: + return + } + } + + elem = origElem + } + } + return r, false +} diff --git a/internal/integration/test_client_options/oas_schemas_gen.go b/internal/integration/test_client_options/oas_schemas_gen.go new file mode 100644 index 000000000..ae379a2db --- /dev/null +++ b/internal/integration/test_client_options/oas_schemas_gen.go @@ -0,0 +1,3 @@ +// Code generated by ogen, DO NOT EDIT. + +package api diff --git a/internal/integration/test_client_options/oas_server_gen.go b/internal/integration/test_client_options/oas_server_gen.go new file mode 100644 index 000000000..da8895adc --- /dev/null +++ b/internal/integration/test_client_options/oas_server_gen.go @@ -0,0 +1,34 @@ +// Code generated by ogen, DO NOT EDIT. + +package api + +import ( + "context" +) + +// Handler handles operations described by OpenAPI v3 specification. +type Handler interface { + // Foo implements Foo operation. + // + // GET /foo + Foo(ctx context.Context, params FooParams) (string, error) +} + +// Server implements http server based on OpenAPI v3 specification and +// calls Handler to handle requests. +type Server struct { + h Handler + baseServer +} + +// NewServer creates new Server. +func NewServer(h Handler, opts ...ServerOption) (*Server, error) { + s, err := newServerConfig(opts...).baseServer() + if err != nil { + return nil, err + } + return &Server{ + h: h, + baseServer: s, + }, nil +} diff --git a/internal/integration/test_client_options/oas_unimplemented_gen.go b/internal/integration/test_client_options/oas_unimplemented_gen.go new file mode 100644 index 000000000..96a919277 --- /dev/null +++ b/internal/integration/test_client_options/oas_unimplemented_gen.go @@ -0,0 +1,21 @@ +// Code generated by ogen, DO NOT EDIT. + +package api + +import ( + "context" + + ht "github.com/ogen-go/ogen/http" +) + +// UnimplementedHandler is no-op Handler which returns http.ErrNotImplemented. +type UnimplementedHandler struct{} + +var _ Handler = UnimplementedHandler{} + +// Foo implements Foo operation. +// +// GET /foo +func (UnimplementedHandler) Foo(ctx context.Context, params FooParams) (r string, _ error) { + return r, ht.ErrNotImplemented +}