Skip to content
This repository was archived by the owner on Jan 22, 2025. It is now read-only.

Commit 1d214fb

Browse files
Support openapi3 refs (#39)
1 parent 034c0f0 commit 1d214fb

File tree

2 files changed

+70
-18
lines changed

2 files changed

+70
-18
lines changed

src/flask_openapi/base.py

+47-17
Original file line numberDiff line numberDiff line change
@@ -7,19 +7,27 @@
77
88
"""
99
import codecs
10+
import json
1011
import logging
1112
import os
1213
import re
13-
from typing import Optional
14-
15-
import yaml
16-
import json
17-
1814
from collections import defaultdict
1915
from functools import partial, wraps
16+
from typing import Optional
2017

21-
from flask import (abort, Blueprint, current_app, jsonify, Markup, redirect,
22-
render_template, request, Response, url_for)
18+
import yaml
19+
from flask import (
20+
abort,
21+
Blueprint,
22+
current_app,
23+
jsonify,
24+
Markup,
25+
redirect,
26+
render_template,
27+
request,
28+
Response,
29+
url_for,
30+
)
2331
from flask.json import JSONEncoder
2432
from flask.views import MethodView
2533
from werkzeug.datastructures import Authorization
@@ -32,12 +40,21 @@
3240
from mistune import markdown
3341

3442
from . import __version__
35-
from .constants import (OAS3_SUB_COMPONENTS, OPTIONAL_FIELDS,
36-
OPTIONAL_OAS3_FIELDS)
37-
from .utils import (extract_definitions, extract_schema, get_schema_specs,
38-
get_specs, get_vendor_extension_fields, is_openapi3,
39-
LazyString, parse_definition_docstring, parse_imports,
40-
swag_annotation, validate)
43+
from .constants import OAS3_SUB_COMPONENTS, OPTIONAL_FIELDS, OPTIONAL_OAS3_FIELDS
44+
from .utils import (
45+
convert_responses_to_openapi3,
46+
extract_definitions,
47+
extract_schema,
48+
get_schema_specs,
49+
get_specs,
50+
get_vendor_extension_fields,
51+
is_openapi3,
52+
LazyString,
53+
parse_definition_docstring,
54+
parse_imports,
55+
swag_annotation,
56+
validate,
57+
)
4158

4259

4360
def NO_SANITIZER(text):
@@ -546,20 +563,29 @@ def get_operations(swag, path_verb=None):
546563
operation['requestBody'] = request_body
547564
if callbacks:
548565
operation['callbacks'] = callbacks
549-
if responses:
550-
operation['responses'] = responses
551566
# parameters - swagger ui dislikes empty parameter lists
552567
if len(params) > 0:
553568
operation['parameters'] = params
569+
570+
media_types = ['application/json']
554571
# other optionals
555572
for key in optional_fields:
556573
if key in swag:
557574
value = swag.get(key)
558575
if key in ('produces', 'consumes'):
559576
if not isinstance(value, (list, tuple)):
560577
value = [value]
578+
if key == 'produces':
579+
media_types = value
561580

562581
operation[key] = value
582+
583+
if responses:
584+
if is_openapi3(openapi_version):
585+
convert_responses_to_openapi3(responses, media_types)
586+
587+
operation['responses'] = responses
588+
563589
if path_verb:
564590
operations[path_verb] = operation
565591
else:
@@ -616,8 +642,12 @@ def get_operations(swag, path_verb=None):
616642
paths[srule][key] = val
617643
self.apispecs[endpoint] = data
618644

645+
# if is_openapi3(openapi_version):
646+
# del data['definitions']
619647
if is_openapi3(openapi_version):
620-
del data['definitions']
648+
# Copy definitions to components/schemas
649+
if definitions:
650+
data.setdefault('components', {}).setdefault('schemas', {}).update(definitions)
621651

622652
return data
623653

@@ -786,7 +816,7 @@ def update_schemas_parsers(self, doc, schemas, parsers, definitions):
786816
'''
787817
Schemas and parsers would be updated here from doc
788818
'''
789-
if self.is_openapi3():
819+
if is_openapi3(self.config.get('openapi')):
790820
# 'json' to comply with self.SCHEMA_LOCATIONS's {'body':'json'}
791821
location = 'json'
792822
json_schema = None

src/flask_openapi/utils.py

+23-1
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
from collections import defaultdict, OrderedDict
1212
from copy import deepcopy
1313
from functools import wraps
14-
from importlib import import_module
1514

1615
import jsonschema
1716
import yaml
@@ -1070,3 +1069,26 @@ def get_swag_path_from_doc_dir(method: any, view_class: any, doc_dir: str, endpo
10701069
logging.exception(f"{file_path} is not a file")
10711070

10721071
return file_path
1072+
1073+
def convert_references_to_openapi3(obj):
1074+
for key, val in obj.items():
1075+
if key == '$ref':
1076+
obj[key] = val.replace('definitions', 'components/schemas')
1077+
1078+
if isinstance(val, dict):
1079+
convert_references_to_openapi3(val)
1080+
1081+
1082+
def convert_response_definitions_to_openapi3(response, media_types):
1083+
if 'schema' in response:
1084+
convert_references_to_openapi3(response['schema'])
1085+
if 'content' not in response:
1086+
response['content'] = {}
1087+
for media_type in media_types:
1088+
response['content'][media_type] = {'schema': dict(response['schema'])}
1089+
del response['schema']
1090+
1091+
1092+
def convert_responses_to_openapi3(responses, media_types):
1093+
for val in responses.values():
1094+
convert_response_definitions_to_openapi3(val, media_types)

0 commit comments

Comments
 (0)