Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add basic template customisation #60

Merged
merged 2 commits into from
Jan 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ repos:
language: system
types: [ python ]
- repo: https://github.com/pycqa/isort
rev: 5.10.1
rev: 5.11.5
hooks:
- id: isort
name: isort (python)
1 change: 1 addition & 0 deletions docs/references/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,5 +18,6 @@ Options:
--library The library to use. Defaults to `httpx`.
--env-token-name The name of the environment variable to use for the API key. Defaults to `access_token`.
--use-orjson Use the `orjson` library for serialization. Defaults to `false`.
--custom-template-path Use a custom template path to override the built in templates.
-h, --help Show this help message and exit.
```
11 changes: 10 additions & 1 deletion src/openapi_python_generator/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,21 +32,30 @@
help="Use the orjson library to serialize the data. This is faster than the default json library and provides "
"serialization of datetimes and other types that are not supported by the default json library.",
)
@click.option(
"--custom-template-path",
type=str,
default=None,
help="Custom template path to use. Allows overriding of the built in templates",
)
@click.version_option(version=__version__)
def main(
source: str,
output: str,
library: Optional[HTTPLibrary] = HTTPLibrary.httpx,
env_token_name: Optional[str] = None,
use_orjson: bool = False,
custom_template_path: Optional[str] = None,
) -> None:
"""
Generate Python code from an OpenAPI 3.0 specification.

Provide a SOURCE (file or URL) containing the OpenAPI 3 specification and
an OUTPUT path, where the resulting client is created.
"""
generate_data(source, output, library, env_token_name, use_orjson)
generate_data(
source, output, library, env_token_name, use_orjson, custom_template_path
)


if __name__ == "__main__": # pragma: no cover
Expand Down
14 changes: 11 additions & 3 deletions src/openapi_python_generator/generate_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@
from .common import HTTPLibrary
from .common import library_config_dict
from .language_converters.python.generator import generator
from .language_converters.python.jinja_config import JINJA_ENV
from .language_converters.python.jinja_config import SERVICE_TEMPLATE
from .language_converters.python.jinja_config import create_jinja_env
from .models import ConversionResult


Expand Down Expand Up @@ -109,13 +109,14 @@ def write_data(data: ConversionResult, output: Union[str, Path]) -> None:
files = []

# Write the services.
jinja_env = create_jinja_env()
for service in data.services:
if len(service.operations) == 0:
continue
files.append(service.file_name)
write_code(
services_path / f"{service.file_name}.py",
JINJA_ENV.get_template(SERVICE_TEMPLATE).render(**service.dict()),
jinja_env.get_template(SERVICE_TEMPLATE).render(**service.dict()),
)

# Create services.__init__.py file containing imports to all services.
Expand All @@ -137,13 +138,20 @@ def generate_data(
library: Optional[HTTPLibrary] = HTTPLibrary.httpx,
env_token_name: Optional[str] = None,
use_orjson: bool = False,
custom_template_path: Optional[str] = None,
) -> None:
"""
Generate Python code from an OpenAPI 3.0 specification.
"""
data = get_open_api(source)
click.echo(f"Generating data from {source}")

result = generator(data, library_config_dict[library], env_token_name, use_orjson)
result = generator(
data,
library_config_dict[library],
env_token_name,
use_orjson,
custom_template_path,
)

write_data(result, output)
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@
from openapi_python_generator.language_converters.python.jinja_config import (
API_CONFIG_TEMPLATE,
)
from openapi_python_generator.language_converters.python.jinja_config import JINJA_ENV
from openapi_python_generator.language_converters.python.jinja_config import (
create_jinja_env,
)
from openapi_python_generator.models import APIConfig


Expand All @@ -15,9 +17,10 @@ def generate_api_config(
"""
Generate the API model.
"""
jinja_env = create_jinja_env()
return APIConfig(
file_name="api_config",
content=JINJA_ENV.get_template(API_CONFIG_TEMPLATE).render(
content=jinja_env.get_template(API_CONFIG_TEMPLATE).render(
env_token_name=env_token_name, **data.dict()
),
base_url=data.servers[0].url if len(data.servers) > 0 else "NO SERVER",
Expand Down
20 changes: 20 additions & 0 deletions src/openapi_python_generator/language_converters/python/common.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import keyword
import re
from typing import Optional


_use_orjson: bool = False
_custom_template_path: str = None
_symbol_ascii_strip_re = re.compile(r"[^A-Za-z0-9_]")


Expand All @@ -24,6 +26,24 @@ def get_use_orjson() -> bool:
return _use_orjson


def set_custom_template_path(value: Optional[str]) -> None:
"""
Set the value of the global variable _custom_template_path.
:param value: value of the variable
"""
global _custom_template_path
_custom_template_path = value


def get_custom_template_path() -> Optional[str]:
"""
Get the value of the global variable _custom_template_path.
:return: value of the variable
"""
global _custom_template_path
return _custom_template_path


def normalize_symbol(symbol: str) -> str:
"""
Remove invalid characters & keywords in Python symbol names
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,14 @@ def generator(
library_config: LibraryConfig,
env_token_name: Optional[str] = None,
use_orjson: bool = False,
custom_template_path: Optional[str] = None,
) -> ConversionResult:
"""
Generate Python code from an OpenAPI 3.0 specification.
"""

common.set_use_orjson(use_orjson)
common.set_custom_template_path(custom_template_path)

if data.components is not None:
models = generate_models(data.components)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
from pathlib import Path

from jinja2 import ChoiceLoader
from jinja2 import Environment
from jinja2 import FileSystemLoader

from . import common


ENUM_TEMPLATE = "enum.jinja2"
MODELS_TEMPLATE = "models.jinja2"
Expand All @@ -11,6 +14,20 @@
API_CONFIG_TEMPLATE = "apiconfig.jinja2"
TEMPLATE_PATH = Path(__file__).parent / "templates"

JINJA_ENV = Environment(
loader=FileSystemLoader(TEMPLATE_PATH), autoescape=True, trim_blocks=True
)

def create_jinja_env():
custom_template_path = common.get_custom_template_path()
return Environment(
loader=(
ChoiceLoader(
[
FileSystemLoader(custom_template_path),
FileSystemLoader(TEMPLATE_PATH),
]
)
if custom_template_path is not None
else FileSystemLoader(TEMPLATE_PATH)
),
autoescape=True,
trim_blocks=True,
)
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,12 @@
from openapi_python_generator.language_converters.python.jinja_config import (
ENUM_TEMPLATE,
)
from openapi_python_generator.language_converters.python.jinja_config import JINJA_ENV
from openapi_python_generator.language_converters.python.jinja_config import (
MODELS_TEMPLATE,
)
from openapi_python_generator.language_converters.python.jinja_config import (
create_jinja_env,
)
from openapi_python_generator.models import Model
from openapi_python_generator.models import Property
from openapi_python_generator.models import TypeConversion
Expand Down Expand Up @@ -264,6 +266,7 @@ def generate_models(components: Components) -> List[Model]:
if components.schemas is None:
return models

jinja_env = create_jinja_env()
for schema_name, schema_or_reference in components.schemas.items():
name = common.normalize_symbol(schema_name)
if schema_or_reference.enum is not None:
Expand All @@ -275,7 +278,7 @@ def generate_models(components: Components) -> List[Model]:
]
m = Model(
file_name=name,
content=JINJA_ENV.get_template(ENUM_TEMPLATE).render(
content=jinja_env.get_template(ENUM_TEMPLATE).render(
name=name, **value_dict
),
openapi_object=schema_or_reference,
Expand Down Expand Up @@ -306,7 +309,7 @@ def generate_models(components: Components) -> List[Model]:
)
properties.append(conv_property)

generated_content = JINJA_ENV.get_template(MODELS_TEMPLATE).render(
generated_content = jinja_env.get_template(MODELS_TEMPLATE).render(
schema_name=name, schema=schema_or_reference, properties=properties
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@

from openapi_python_generator.language_converters.python import common
from openapi_python_generator.language_converters.python.common import normalize_symbol
from openapi_python_generator.language_converters.python.jinja_config import JINJA_ENV
from openapi_python_generator.language_converters.python.jinja_config import (
create_jinja_env,
)
from openapi_python_generator.language_converters.python.model_generator import (
type_converter,
)
Expand Down Expand Up @@ -269,6 +271,7 @@ def generate_services(
:param paths: paths object to be converted
:return: List of services
"""
jinja_env = create_jinja_env()

def generate_service_operation(
op: Operation, path_name: str, async_type: bool
Expand Down Expand Up @@ -296,7 +299,7 @@ def generate_service_operation(
use_orjson=common.get_use_orjson(),
)

so.content = JINJA_ENV.get_template(library_config.template_name).render(
so.content = jinja_env.get_template(library_config.template_name).render(
**so.dict()
)

Expand Down