Skip to content

Commit e16b094

Browse files
committed
feat: Added pydantic v2 to models
1 parent 3770471 commit e16b094

11 files changed

+142
-50
lines changed

src/openapi_python_generator/__main__.py

+12-4
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
from typing import Optional
2+
from enum import Enum
23

34
import click
45

56
from openapi_python_generator import __version__
6-
from openapi_python_generator.common import HTTPLibrary
7+
from openapi_python_generator.common import HTTPLibrary, PydanticVersion
78
from openapi_python_generator.generate_data import generate_data
89

9-
1010
@click.command()
1111
@click.argument("source")
1212
@click.argument("output")
@@ -38,6 +38,13 @@
3838
default=None,
3939
help="Custom template path to use. Allows overriding of the built in templates",
4040
)
41+
@click.option(
42+
"--pydantic-version",
43+
type=click.Choice(["v1", "v2"]),
44+
default="v2",
45+
show_default=True,
46+
help="Pydantic version to use for generated models.",
47+
)
4148
@click.version_option(version=__version__)
4249
def main(
4350
source: str,
@@ -46,6 +53,7 @@ def main(
4653
env_token_name: Optional[str] = None,
4754
use_orjson: bool = False,
4855
custom_template_path: Optional[str] = None,
56+
pydantic_version: PydanticVersion = PydanticVersion.V2,
4957
) -> None:
5058
"""
5159
Generate Python code from an OpenAPI 3.0 specification.
@@ -54,9 +62,9 @@ def main(
5462
an OUTPUT path, where the resulting client is created.
5563
"""
5664
generate_data(
57-
source, output, library, env_token_name, use_orjson, custom_template_path
65+
source, output, library, env_token_name, use_orjson, custom_template_path,pydantic_version
5866
)
5967

6068

6169
if __name__ == "__main__": # pragma: no cover
62-
main()
70+
main()

src/openapi_python_generator/common.py

+4
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,10 @@ class HTTPLibrary(str, Enum):
1414
requests = "requests"
1515
aiohttp = "aiohttp"
1616

17+
class PydanticVersion(str, Enum):
18+
V1 = "v1"
19+
V2 = "v2"
20+
1721

1822
library_config_dict: Dict[Optional[HTTPLibrary], LibraryConfig] = {
1923
HTTPLibrary.httpx: LibraryConfig(

src/openapi_python_generator/generate_data.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from openapi_pydantic.v3.v3_0 import OpenAPI
1414
from pydantic import ValidationError
1515

16-
from .common import HTTPLibrary
16+
from .common import HTTPLibrary, PydanticVersion
1717
from .common import library_config_dict
1818
from .language_converters.python.generator import generator
1919
from .language_converters.python.jinja_config import SERVICE_TEMPLATE
@@ -140,6 +140,7 @@ def generate_data(
140140
env_token_name: Optional[str] = None,
141141
use_orjson: bool = False,
142142
custom_template_path: Optional[str] = None,
143+
pydantic_version: PydanticVersion = PydanticVersion.V2,
143144
) -> None:
144145
"""
145146
Generate Python code from an OpenAPI 3.0 specification.
@@ -153,6 +154,7 @@ def generate_data(
153154
env_token_name,
154155
use_orjson,
155156
custom_template_path,
157+
pydantic_version,
156158
)
157159

158160
write_data(result, output)

src/openapi_python_generator/language_converters/python/api_config_generator.py

+7-3
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,9 @@
22

33
from openapi_pydantic.v3.v3_0 import OpenAPI
44

5+
from openapi_python_generator.common import PydanticVersion
56
from openapi_python_generator.language_converters.python.jinja_config import (
6-
API_CONFIG_TEMPLATE,
7+
API_CONFIG_TEMPLATE, API_CONFIG_TEMPLATE_PYDANTIC_V2,
78
)
89
from openapi_python_generator.language_converters.python.jinja_config import (
910
create_jinja_env,
@@ -12,15 +13,18 @@
1213

1314

1415
def generate_api_config(
15-
data: OpenAPI, env_token_name: Optional[str] = None
16+
data: OpenAPI, env_token_name: Optional[str] = None,
17+
pydantic_version: PydanticVersion = PydanticVersion.V2,
1618
) -> APIConfig:
1719
"""
1820
Generate the API model.
1921
"""
22+
23+
template_name = API_CONFIG_TEMPLATE_PYDANTIC_V2 if pydantic_version == PydanticVersion.V2 else API_CONFIG_TEMPLATE
2024
jinja_env = create_jinja_env()
2125
return APIConfig(
2226
file_name="api_config",
23-
content=jinja_env.get_template(API_CONFIG_TEMPLATE).render(
27+
content=jinja_env.get_template(template_name).render(
2428
env_token_name=env_token_name, **data.dict()
2529
),
2630
base_url=data.servers[0].url if len(data.servers) > 0 else "NO SERVER",

src/openapi_python_generator/language_converters/python/generator.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from openapi_pydantic.v3.v3_0 import OpenAPI
44

5+
from openapi_python_generator.common import PydanticVersion
56
from openapi_python_generator.language_converters.python import common
67
from openapi_python_generator.language_converters.python.api_config_generator import (
78
generate_api_config,
@@ -22,6 +23,7 @@ def generator(
2223
env_token_name: Optional[str] = None,
2324
use_orjson: bool = False,
2425
custom_template_path: Optional[str] = None,
26+
pydantic_version: PydanticVersion = PydanticVersion.V2,
2527
) -> ConversionResult:
2628
"""
2729
Generate Python code from an OpenAPI 3.0 specification.
@@ -31,7 +33,7 @@ def generator(
3133
common.set_custom_template_path(custom_template_path)
3234

3335
if data.components is not None:
34-
models = generate_models(data.components)
36+
models = generate_models(data.components, pydantic_version)
3537
else:
3638
models = []
3739

@@ -40,7 +42,7 @@ def generator(
4042
else:
4143
services = []
4244

43-
api_config = generate_api_config(data, env_token_name)
45+
api_config = generate_api_config(data, env_token_name, pydantic_version)
4446

4547
return ConversionResult(
4648
models=models,

src/openapi_python_generator/language_converters/python/jinja_config.py

+2
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,11 @@
99

1010
ENUM_TEMPLATE = "enum.jinja2"
1111
MODELS_TEMPLATE = "models.jinja2"
12+
MODELS_TEMPLATE_PYDANTIC_V2 = "models_pydantic_2.jinja2"
1213
SERVICE_TEMPLATE = "service.jinja2"
1314
HTTPX_TEMPLATE = "httpx.jinja2"
1415
API_CONFIG_TEMPLATE = "apiconfig.jinja2"
16+
API_CONFIG_TEMPLATE_PYDANTIC_V2 = "apiconfig_pydantic_2.jinja2"
1517
TEMPLATE_PATH = Path(__file__).parent / "templates"
1618

1719

src/openapi_python_generator/language_converters/python/model_generator.py

+33-28
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,10 @@
66
import click
77
from openapi_pydantic.v3.v3_0 import Schema, Reference, Components
88

9+
from openapi_python_generator.common import PydanticVersion
910
from openapi_python_generator.language_converters.python import common
1011
from openapi_python_generator.language_converters.python.jinja_config import (
11-
ENUM_TEMPLATE,
12+
ENUM_TEMPLATE, MODELS_TEMPLATE_PYDANTIC_V2,
1213
)
1314
from openapi_python_generator.language_converters.python.jinja_config import (
1415
MODELS_TEMPLATE,
@@ -22,9 +23,9 @@
2223

2324

2425
def type_converter( # noqa: C901
25-
schema: Schema,
26-
required: bool = False,
27-
model_name: Optional[str] = None,
26+
schema: Schema,
27+
required: bool = False,
28+
model_name: Optional[str] = None,
2829
) -> TypeConversion:
2930
"""
3031
Converts an OpenAPI type to a Python type.
@@ -69,13 +70,13 @@ def type_converter( # noqa: C901
6970
)
7071

7172
original_type = (
72-
"tuple<" + ",".join([i.original_type for i in conversions]) + ">"
73+
"tuple<" + ",".join([i.original_type for i in conversions]) + ">"
7374
)
7475
if len(conversions) == 1:
7576
converted_type = conversions[0].converted_type
7677
else:
7778
converted_type = (
78-
"Tuple[" + ",".join([i.converted_type for i in conversions]) + "]"
79+
"Tuple[" + ",".join([i.converted_type for i in conversions]) + "]"
7980
)
8081

8182
converted_type = pre_type + converted_type + post_type
@@ -101,14 +102,14 @@ def type_converter( # noqa: C901
101102
)
102103
)
103104
original_type = (
104-
"union<" + ",".join([i.original_type for i in conversions]) + ">"
105+
"union<" + ",".join([i.original_type for i in conversions]) + ">"
105106
)
106107

107108
if len(conversions) == 1:
108109
converted_type = conversions[0].converted_type
109110
else:
110111
converted_type = (
111-
"Union[" + ",".join([i.converted_type for i in conversions]) + "]"
112+
"Union[" + ",".join([i.converted_type for i in conversions]) + "]"
112113
)
113114

114115
converted_type = pre_type + converted_type + post_type
@@ -120,13 +121,13 @@ def type_converter( # noqa: C901
120121
# We only want to auto convert to datetime if orjson is used throghout the code, otherwise we can not
121122
# serialize it to JSON.
122123
elif schema.type == "string" and (
123-
schema.schema_format is None or not common.get_use_orjson()
124+
schema.schema_format is None or not common.get_use_orjson()
124125
):
125126
converted_type = pre_type + "str" + post_type
126127
elif (
127-
schema.type == "string"
128-
and schema.schema_format.startswith("uuid")
129-
and common.get_use_orjson()
128+
schema.type == "string"
129+
and schema.schema_format.startswith("uuid")
130+
and common.get_use_orjson()
130131
):
131132
if len(schema.schema_format) > 4 and schema.schema_format[4].isnumeric():
132133
uuid_type = schema.schema_format.upper()
@@ -154,7 +155,8 @@ def type_converter( # noqa: C901
154155
original_type = "array<" + converted_reference.type.original_type + ">"
155156
retVal += converted_reference.type.converted_type
156157
elif isinstance(schema.items, Schema):
157-
original_type = "array<" + (str(schema.items.type.value) if schema.items.type is not None else "unknown")+ ">"
158+
original_type = "array<" + (
159+
str(schema.items.type.value) if schema.items.type is not None else "unknown") + ">"
158160
retVal += type_converter(schema.items, True).converted_type
159161
else:
160162
original_type = "array<unknown>"
@@ -178,7 +180,7 @@ def type_converter( # noqa: C901
178180

179181

180182
def _generate_property_from_schema(
181-
model_name: str, name: str, schema: Schema, parent_schema: Optional[Schema] = None
183+
model_name: str, name: str, schema: Schema, parent_schema: Optional[Schema] = None
182184
) -> Property:
183185
"""
184186
Generates a property from a schema. It takes the type of the schema and converts it to a python type, and then
@@ -190,9 +192,9 @@ def _generate_property_from_schema(
190192
:return: Property
191193
"""
192194
required = (
193-
parent_schema is not None
194-
and parent_schema.required is not None
195-
and name in parent_schema.required
195+
parent_schema is not None
196+
and parent_schema.required is not None
197+
and name in parent_schema.required
196198
)
197199

198200
import_type = None
@@ -209,11 +211,11 @@ def _generate_property_from_schema(
209211

210212

211213
def _generate_property_from_reference(
212-
model_name: str,
213-
name: str,
214-
reference: Reference,
215-
parent_schema: Optional[Schema] = None,
216-
force_required: bool = False,
214+
model_name: str,
215+
name: str,
216+
reference: Reference,
217+
parent_schema: Optional[Schema] = None,
218+
force_required: bool = False,
217219
) -> Property:
218220
"""
219221
Generates a property from a reference. It takes the name of the reference as the type, and then
@@ -225,10 +227,10 @@ def _generate_property_from_reference(
225227
:return: Property and model to be imported by the file
226228
"""
227229
required = (
228-
parent_schema is not None
229-
and parent_schema.required is not None
230-
and name in parent_schema.required
231-
) or force_required
230+
parent_schema is not None
231+
and parent_schema.required is not None
232+
and name in parent_schema.required
233+
) or force_required
232234
import_model = common.normalize_symbol(reference.ref.split("/")[-1])
233235

234236
if import_model == model_name:
@@ -256,13 +258,14 @@ def _generate_property_from_reference(
256258
)
257259

258260

259-
def generate_models(components: Components) -> List[Model]:
261+
def generate_models(components: Components, pydantic_version: PydanticVersion = PydanticVersion.V2) -> List[Model]:
260262
"""
261263
Receives components from an OpenAPI 3.0 specification and generates the models from it.
262264
It does so, by iterating over the components.schemas dictionary. For each schema, it checks if
263265
it is a normal schema (i.e. simple type like string, integer, etc.), a reference to another schema, or
264266
an array of types/references. It then computes pydantic models from it using jinja2
265267
:param components: The components from an OpenAPI 3.0 specification.
268+
:param pydantic_version: The version of pydantic to use.
266269
:return: A list of models.
267270
"""
268271
models: List[Model] = []
@@ -313,7 +316,9 @@ def generate_models(components: Components) -> List[Model]:
313316
)
314317
properties.append(conv_property)
315318

316-
generated_content = jinja_env.get_template(MODELS_TEMPLATE).render(
319+
template_name = MODELS_TEMPLATE_PYDANTIC_V2 if pydantic_version == PydanticVersion.V2 else MODELS_TEMPLATE
320+
321+
generated_content = jinja_env.get_template(template_name).render(
317322
schema_name=name, schema=schema_or_reference, properties=properties
318323
)
319324

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
{% if env_token_name is not none %}import os{% endif %}
2+
3+
from pydantic import BaseModel, Field
4+
from typing import Optional, Union
5+
6+
class APIConfig(BaseModel):
7+
model_config = {
8+
"validate_assignment": True
9+
}
10+
11+
base_path: str = {% if servers|length > 0 %} '{{ servers[0].url }}' {% else %} 'NO SERVER' {% endif %}
12+
13+
verify: Union[bool, str] = True
14+
{% if env_token_name is none %}
15+
access_token : Optional[str] = None
16+
{% endif %}
17+
18+
def get_access_token(self) -> Optional[str]:
19+
{% if env_token_name is not none %}
20+
try:
21+
return os.environ['{{ env_token_name }}']
22+
except KeyError:
23+
return None
24+
{% else %}
25+
return self.access_token
26+
{% endif %}
27+
28+
def set_access_token(self, value : str):
29+
{% if env_token_name is not none %}
30+
raise Exception("This client was generated with an environment variable for the access token. Please set the environment variable '{{ env_token_name }}' to the access token.")
31+
{% else %}
32+
self.access_token = value
33+
{% endif %}
34+
35+
class HTTPException(Exception):
36+
def __init__(self, status_code: int, message: str):
37+
self.status_code = status_code
38+
self.message = message
39+
super().__init__(f"{status_code} {message}")
40+
41+
def __str__(self):
42+
return f"{self.status_code} {self.message}"

0 commit comments

Comments
 (0)