Skip to content

Commit d4da482

Browse files
author
Oliver Kahrmann
committed
Generate a parameter for the Accept header when a response code has multiple content types
1 parent e198444 commit d4da482

File tree

6 files changed

+148
-0
lines changed

6 files changed

+148
-0
lines changed

gen/_template/mediatypes.tmpl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
{{ define "mediatypes" }}
2+
{{- /*gotype: github.com/ogen-go/ogen/gen.TemplateConfig*/ -}}
3+
{{ template "header" $ }}
4+
5+
const (
6+
{{- range $op := $.MediaTypes }}
7+
MediaType{{ $op.Name }} string = {{ quote $op.Value }}
8+
{{- end }}
9+
)
10+
11+
{{ end }}

gen/gen_operation.go

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88

99
"github.com/ogen-go/ogen/gen/ir"
1010
"github.com/ogen-go/ogen/internal/xslices"
11+
"github.com/ogen-go/ogen/jsonschema"
1112
"github.com/ogen-go/ogen/openapi"
1213
)
1314

@@ -47,6 +48,54 @@ func (g *Generator) generateOperation(ctx *genctx, webhookName string, spec *ope
4748
return nil, errors.Wrap(err, "parameters")
4849
}
4950

51+
// If there is no manual specification of the Accept parameter
52+
if _, ok := xslices.FindFunc(op.Params, func(param *ir.Parameter) bool { return param.Spec.In.Header() && param.Spec.Name == "Accept" }); !ok {
53+
supportsMultipleMediaTypes := false
54+
// And at least one operation defines multiple media types
55+
for _, statusCode := range spec.Responses.StatusCode {
56+
if len(statusCode.Content) > 1 {
57+
supportsMultipleMediaTypes = true
58+
break
59+
}
60+
}
61+
if supportsMultipleMediaTypes {
62+
mediaTypes := map[string]any{}
63+
for _, statusCode := range spec.Responses.StatusCode {
64+
for mediaType := range statusCode.Content {
65+
mediaTypes[mediaType] = nil
66+
}
67+
}
68+
69+
mediaTypeType, ok := ctx.global.types["AcceptHeader"]
70+
if !ok {
71+
mediaTypeType = &ir.Type{
72+
Doc: "Auto-generated parameter for the Accept header",
73+
Kind: ir.KindStruct,
74+
Name: "ht.AcceptHeader",
75+
}
76+
}
77+
78+
acceptParam := &ir.Parameter{
79+
Name: "Accept",
80+
Type: mediaTypeType,
81+
Spec: &openapi.Parameter{
82+
Name: "Accept",
83+
Description: "Auto-generated parameter for the Accept header",
84+
Schema: &jsonschema.Schema{
85+
Type: jsonschema.String,
86+
},
87+
In: openapi.LocationHeader,
88+
},
89+
}
90+
91+
// acceptParam, err := g.generateParameter(ctx, op.Name)
92+
// if err != nil {
93+
// return nil, errors.Wrap(err, "parameters")
94+
// }
95+
op.Params = append(op.Params, acceptParam)
96+
}
97+
}
98+
5099
// Convert []openapi.PathPart to []*ir.PathPart
51100
op.PathParts = convertPathParts(op.Spec.Path, op.Params)
52101

gen/generator.go

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ type Generator struct {
3030
defaultOperations []*ir.Operation // Operations without an operation group.
3131
operationGroups []*ir.OperationGroup
3232
webhooks []*ir.Operation
33+
mediaTypes map[ir.ContentType]*ir.MediaType
3334
securities map[string]*ir.Security
3435
tstorage *tstorage
3536
errType *ir.Response
@@ -98,6 +99,7 @@ func NewGenerator(spec *ogen.Spec, opts Options) (*Generator, error) {
9899
servers: nil,
99100
operations: nil,
100101
webhooks: nil,
102+
mediaTypes: map[ir.ContentType]*ir.MediaType{},
101103
securities: map[string]*ir.Security{},
102104
tstorage: newTStorage(),
103105
errType: nil,
@@ -187,6 +189,21 @@ func (g *Generator) makeOps(ops []*openapi.Operation) error {
187189
return err
188190
}
189191

192+
// Collect all media types used in responses to generate constants
193+
for _, response := range op.Responses.StatusCode {
194+
for contentType := range response.Contents {
195+
if _, ok := g.mediaTypes[contentType]; !ok {
196+
constantName, err := pascalNonEmpty(string(contentType))
197+
if err != nil {
198+
return errors.Wrap(err, "gather media types")
199+
}
200+
g.mediaTypes[contentType] = &ir.MediaType{
201+
Name: constantName,
202+
Value: contentType,
203+
}
204+
}
205+
}
206+
}
190207
g.operations = append(g.operations, op)
191208
}
192209

@@ -282,6 +299,12 @@ func sortOperations(ops []*ir.Operation) {
282299
})
283300
}
284301

302+
func sortMediaTypes(types []*ir.MediaType) {
303+
slices.SortStableFunc(types, func(a, b *ir.MediaType) int {
304+
return strings.Compare(a.Name, b.Name)
305+
})
306+
}
307+
285308
func groupOperations(ops []*ir.Operation) (
286309
defaultOperations []*ir.Operation,
287310
operationGroups []*ir.OperationGroup,
@@ -320,6 +343,16 @@ func (g *Generator) Operations() []*ir.Operation {
320343
return g.operations
321344
}
322345

346+
// MediaTypes returns generated media type constants.
347+
func (g *Generator) MediaTypes() []*ir.MediaType {
348+
mediaTypesSorted := make([]*ir.MediaType, 0, len(g.mediaTypes))
349+
for _, mt := range g.mediaTypes {
350+
mediaTypesSorted = append(mediaTypesSorted, mt)
351+
}
352+
sortMediaTypes(mediaTypesSorted)
353+
return mediaTypesSorted
354+
}
355+
323356
// Webhooks returns generated webhooks.
324357
func (g *Generator) Webhooks() []*ir.Operation {
325358
return g.webhooks

gen/ir/operation.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,11 @@ type OperationGroup struct {
3434
Operations []*Operation
3535
}
3636

37+
type MediaType struct {
38+
Name string // Generated constant name
39+
Value ContentType // Actual media type, e.g. application/xml
40+
}
41+
3742
// OTELAttribute represents OpenTelemetry attribute defined by otelogen package.
3843
type OTELAttribute struct {
3944
// Key is a name of the attribute constructor in otelogen package.

gen/write.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ type TemplateConfig struct {
2727
DefaultOperations []*ir.Operation
2828
OperationGroups []*ir.OperationGroup
2929
Webhooks []*ir.Operation
30+
MediaTypes []*ir.MediaType
3031
Types map[string]*ir.Type
3132
Interfaces map[string]*ir.Type
3233
Error *ir.Response
@@ -257,6 +258,7 @@ func (g *Generator) WriteSource(fs FileSystem, pkgName string) error {
257258
DefaultOperations: g.defaultOperations,
258259
OperationGroups: g.operationGroups,
259260
Webhooks: g.webhooks,
261+
MediaTypes: g.MediaTypes(),
260262
Types: types,
261263
Interfaces: interfaces,
262264
Error: g.errType,
@@ -335,6 +337,7 @@ func (g *Generator) WriteSource(fs FileSystem, pkgName string) error {
335337
{"unimplemented", features.Has(OgenUnimplemented) && genServer},
336338
{"labeler", features.Has(OgenOtel) && genServer},
337339
{"operations", (genClient || genServer)},
340+
{"mediatypes", (genClient || genServer)},
338341
} {
339342
t := t
340343
if !t.enabled {

http/accept_header.go

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
package http
2+
3+
import (
4+
"slices"
5+
"strings"
6+
7+
"github.com/go-faster/errors"
8+
"github.com/ogen-go/ogen/uri"
9+
)
10+
11+
// Represents the content of an HTTP Accept Header.
12+
// Supports multiple content types (comma separated) and wild cards.
13+
// Does NOT support q-factor weighting (these values are stripped and ignored).
14+
type AcceptHeader []string
15+
16+
// MarshalText implements encoding.TextMarshaler.
17+
func (s AcceptHeader) MarshalText() ([]byte, error) {
18+
return []byte(strings.Join(s, ", ")), nil
19+
}
20+
21+
// UnmarshalText implements encoding.TextUnmarshaler.
22+
func (s *AcceptHeader) UnmarshalText(data []byte) error {
23+
*s = strings.Split(string(data), ",")
24+
for i, segment := range *s {
25+
// Remove q-factor weighting
26+
if semicolonIndex := strings.IndexByte(segment, ';'); semicolonIndex >= 0 {
27+
segment = segment[:semicolonIndex]
28+
}
29+
// Trim spaces to clean up leftovers from comma separation above (spaces are optional there)
30+
(*s)[i] = strings.TrimSpace(segment)
31+
}
32+
return nil
33+
}
34+
35+
func (s AcceptHeader) MatchesContentType(contentType string) bool {
36+
return slices.ContainsFunc(s, func(pattern string) bool {
37+
return MatchContentType(pattern, contentType)
38+
})
39+
}
40+
41+
func (s *AcceptHeader) DecodeURI(d uri.Decoder) error {
42+
val, err := d.DecodeValue()
43+
if err != nil {
44+
return errors.Wrap(err, "decode accept header")
45+
}
46+
return s.UnmarshalText([]byte(val))
47+
}

0 commit comments

Comments
 (0)