diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 0220afe..ea9b8c0 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -18,7 +18,7 @@ repos: language: system types: [ python ] - repo: https://github.com/pycqa/isort - rev: 5.10.1 + rev: 5.11.5 hooks: - id: isort name: isort (python) diff --git a/docs/references/index.md b/docs/references/index.md index e6c34e7..791043a 100644 --- a/docs/references/index.md +++ b/docs/references/index.md @@ -18,5 +18,6 @@ Options: --library The library to use. Defaults to `httpx`. --env-token-name The name of the environment variable to use for the API key. Defaults to `access_token`. --use-orjson Use the `orjson` library for serialization. Defaults to `false`. +--custom-template-path Use a custom template path to override the built in templates. -h, --help Show this help message and exit. ``` diff --git a/src/openapi_python_generator/__main__.py b/src/openapi_python_generator/__main__.py index 36b7f1b..9252b70 100644 --- a/src/openapi_python_generator/__main__.py +++ b/src/openapi_python_generator/__main__.py @@ -32,6 +32,12 @@ help="Use the orjson library to serialize the data. This is faster than the default json library and provides " "serialization of datetimes and other types that are not supported by the default json library.", ) +@click.option( + "--custom-template-path", + type=str, + default=None, + help="Custom template path to use. Allows overriding of the built in templates", +) @click.version_option(version=__version__) def main( source: str, @@ -39,6 +45,7 @@ def main( library: Optional[HTTPLibrary] = HTTPLibrary.httpx, env_token_name: Optional[str] = None, use_orjson: bool = False, + custom_template_path: Optional[str] = None, ) -> None: """ Generate Python code from an OpenAPI 3.0 specification. @@ -46,7 +53,9 @@ def main( Provide a SOURCE (file or URL) containing the OpenAPI 3 specification and an OUTPUT path, where the resulting client is created. """ - generate_data(source, output, library, env_token_name, use_orjson) + generate_data( + source, output, library, env_token_name, use_orjson, custom_template_path + ) if __name__ == "__main__": # pragma: no cover diff --git a/src/openapi_python_generator/generate_data.py b/src/openapi_python_generator/generate_data.py index 0111144..25be063 100644 --- a/src/openapi_python_generator/generate_data.py +++ b/src/openapi_python_generator/generate_data.py @@ -16,8 +16,8 @@ from .common import HTTPLibrary from .common import library_config_dict from .language_converters.python.generator import generator -from .language_converters.python.jinja_config import JINJA_ENV from .language_converters.python.jinja_config import SERVICE_TEMPLATE +from .language_converters.python.jinja_config import create_jinja_env from .models import ConversionResult @@ -109,13 +109,14 @@ def write_data(data: ConversionResult, output: Union[str, Path]) -> None: files = [] # Write the services. + jinja_env = create_jinja_env() for service in data.services: if len(service.operations) == 0: continue files.append(service.file_name) write_code( services_path / f"{service.file_name}.py", - JINJA_ENV.get_template(SERVICE_TEMPLATE).render(**service.dict()), + jinja_env.get_template(SERVICE_TEMPLATE).render(**service.dict()), ) # Create services.__init__.py file containing imports to all services. @@ -137,6 +138,7 @@ def generate_data( library: Optional[HTTPLibrary] = HTTPLibrary.httpx, env_token_name: Optional[str] = None, use_orjson: bool = False, + custom_template_path: Optional[str] = None, ) -> None: """ Generate Python code from an OpenAPI 3.0 specification. @@ -144,6 +146,12 @@ def generate_data( data = get_open_api(source) click.echo(f"Generating data from {source}") - result = generator(data, library_config_dict[library], env_token_name, use_orjson) + result = generator( + data, + library_config_dict[library], + env_token_name, + use_orjson, + custom_template_path, + ) write_data(result, output) diff --git a/src/openapi_python_generator/language_converters/python/api_config_generator.py b/src/openapi_python_generator/language_converters/python/api_config_generator.py index e0ce6f4..8b110c1 100644 --- a/src/openapi_python_generator/language_converters/python/api_config_generator.py +++ b/src/openapi_python_generator/language_converters/python/api_config_generator.py @@ -5,7 +5,9 @@ from openapi_python_generator.language_converters.python.jinja_config import ( API_CONFIG_TEMPLATE, ) -from openapi_python_generator.language_converters.python.jinja_config import JINJA_ENV +from openapi_python_generator.language_converters.python.jinja_config import ( + create_jinja_env, +) from openapi_python_generator.models import APIConfig @@ -15,9 +17,10 @@ def generate_api_config( """ Generate the API model. """ + jinja_env = create_jinja_env() return APIConfig( file_name="api_config", - content=JINJA_ENV.get_template(API_CONFIG_TEMPLATE).render( + content=jinja_env.get_template(API_CONFIG_TEMPLATE).render( env_token_name=env_token_name, **data.dict() ), base_url=data.servers[0].url if len(data.servers) > 0 else "NO SERVER", diff --git a/src/openapi_python_generator/language_converters/python/common.py b/src/openapi_python_generator/language_converters/python/common.py index 34fb10c..e3c55c4 100644 --- a/src/openapi_python_generator/language_converters/python/common.py +++ b/src/openapi_python_generator/language_converters/python/common.py @@ -1,8 +1,10 @@ import keyword import re +from typing import Optional _use_orjson: bool = False +_custom_template_path: str = None _symbol_ascii_strip_re = re.compile(r"[^A-Za-z0-9_]") @@ -24,6 +26,24 @@ def get_use_orjson() -> bool: return _use_orjson +def set_custom_template_path(value: Optional[str]) -> None: + """ + Set the value of the global variable _custom_template_path. + :param value: value of the variable + """ + global _custom_template_path + _custom_template_path = value + + +def get_custom_template_path() -> Optional[str]: + """ + Get the value of the global variable _custom_template_path. + :return: value of the variable + """ + global _custom_template_path + return _custom_template_path + + def normalize_symbol(symbol: str) -> str: """ Remove invalid characters & keywords in Python symbol names diff --git a/src/openapi_python_generator/language_converters/python/generator.py b/src/openapi_python_generator/language_converters/python/generator.py index d7f9086..204d25c 100644 --- a/src/openapi_python_generator/language_converters/python/generator.py +++ b/src/openapi_python_generator/language_converters/python/generator.py @@ -21,12 +21,14 @@ def generator( library_config: LibraryConfig, env_token_name: Optional[str] = None, use_orjson: bool = False, + custom_template_path: Optional[str] = None, ) -> ConversionResult: """ Generate Python code from an OpenAPI 3.0 specification. """ common.set_use_orjson(use_orjson) + common.set_custom_template_path(custom_template_path) if data.components is not None: models = generate_models(data.components) diff --git a/src/openapi_python_generator/language_converters/python/jinja_config.py b/src/openapi_python_generator/language_converters/python/jinja_config.py index f133281..8855399 100644 --- a/src/openapi_python_generator/language_converters/python/jinja_config.py +++ b/src/openapi_python_generator/language_converters/python/jinja_config.py @@ -1,8 +1,11 @@ from pathlib import Path +from jinja2 import ChoiceLoader from jinja2 import Environment from jinja2 import FileSystemLoader +from . import common + ENUM_TEMPLATE = "enum.jinja2" MODELS_TEMPLATE = "models.jinja2" @@ -11,6 +14,20 @@ API_CONFIG_TEMPLATE = "apiconfig.jinja2" TEMPLATE_PATH = Path(__file__).parent / "templates" -JINJA_ENV = Environment( - loader=FileSystemLoader(TEMPLATE_PATH), autoescape=True, trim_blocks=True -) + +def create_jinja_env(): + custom_template_path = common.get_custom_template_path() + return Environment( + loader=( + ChoiceLoader( + [ + FileSystemLoader(custom_template_path), + FileSystemLoader(TEMPLATE_PATH), + ] + ) + if custom_template_path is not None + else FileSystemLoader(TEMPLATE_PATH) + ), + autoescape=True, + trim_blocks=True, + ) diff --git a/src/openapi_python_generator/language_converters/python/model_generator.py b/src/openapi_python_generator/language_converters/python/model_generator.py index bde8fab..4f51e22 100644 --- a/src/openapi_python_generator/language_converters/python/model_generator.py +++ b/src/openapi_python_generator/language_converters/python/model_generator.py @@ -12,10 +12,12 @@ from openapi_python_generator.language_converters.python.jinja_config import ( ENUM_TEMPLATE, ) -from openapi_python_generator.language_converters.python.jinja_config import JINJA_ENV from openapi_python_generator.language_converters.python.jinja_config import ( MODELS_TEMPLATE, ) +from openapi_python_generator.language_converters.python.jinja_config import ( + create_jinja_env, +) from openapi_python_generator.models import Model from openapi_python_generator.models import Property from openapi_python_generator.models import TypeConversion @@ -264,6 +266,7 @@ def generate_models(components: Components) -> List[Model]: if components.schemas is None: return models + jinja_env = create_jinja_env() for schema_name, schema_or_reference in components.schemas.items(): name = common.normalize_symbol(schema_name) if schema_or_reference.enum is not None: @@ -275,7 +278,7 @@ def generate_models(components: Components) -> List[Model]: ] m = Model( file_name=name, - content=JINJA_ENV.get_template(ENUM_TEMPLATE).render( + content=jinja_env.get_template(ENUM_TEMPLATE).render( name=name, **value_dict ), openapi_object=schema_or_reference, @@ -306,7 +309,7 @@ def generate_models(components: Components) -> List[Model]: ) properties.append(conv_property) - generated_content = JINJA_ENV.get_template(MODELS_TEMPLATE).render( + generated_content = jinja_env.get_template(MODELS_TEMPLATE).render( schema_name=name, schema=schema_or_reference, properties=properties ) diff --git a/src/openapi_python_generator/language_converters/python/service_generator.py b/src/openapi_python_generator/language_converters/python/service_generator.py index 7b6af1e..77937f1 100644 --- a/src/openapi_python_generator/language_converters/python/service_generator.py +++ b/src/openapi_python_generator/language_converters/python/service_generator.py @@ -18,7 +18,9 @@ from openapi_python_generator.language_converters.python import common from openapi_python_generator.language_converters.python.common import normalize_symbol -from openapi_python_generator.language_converters.python.jinja_config import JINJA_ENV +from openapi_python_generator.language_converters.python.jinja_config import ( + create_jinja_env, +) from openapi_python_generator.language_converters.python.model_generator import ( type_converter, ) @@ -269,6 +271,7 @@ def generate_services( :param paths: paths object to be converted :return: List of services """ + jinja_env = create_jinja_env() def generate_service_operation( op: Operation, path_name: str, async_type: bool @@ -296,7 +299,7 @@ def generate_service_operation( use_orjson=common.get_use_orjson(), ) - so.content = JINJA_ENV.get_template(library_config.template_name).render( + so.content = jinja_env.get_template(library_config.template_name).render( **so.dict() )