diff --git a/docs/references/index.md b/docs/references/index.md index cc481cf..a928966 100644 --- a/docs/references/index.md +++ b/docs/references/index.md @@ -16,7 +16,7 @@ Arguments: Options: ```console --library [httpx|requests|aiohttp] - HTTP library to use in the generation of the client. + HTTP library to use in the generation of the client. Defaults to 'httpx'. --env-token-name TEXT Name of the environment variable that contains the token. @@ -36,6 +36,10 @@ Options: Pydantic version to use for generated models. Defaults to 'v2'. +--formatter [black|none] + Option to choose which auto formatter is applied. + Defaults to 'black'. + --version Show the version and exit. -h, --help Show this help message and exit. ``` diff --git a/src/openapi_python_generator/__main__.py b/src/openapi_python_generator/__main__.py index de3af47..f2473e5 100644 --- a/src/openapi_python_generator/__main__.py +++ b/src/openapi_python_generator/__main__.py @@ -1,10 +1,9 @@ from typing import Optional -from enum import Enum import click from openapi_python_generator import __version__ -from openapi_python_generator.common import HTTPLibrary, PydanticVersion +from openapi_python_generator.common import Formatter, HTTPLibrary, PydanticVersion from openapi_python_generator.generate_data import generate_data @click.command() @@ -45,6 +44,13 @@ show_default=True, help="Pydantic version to use for generated models.", ) +@click.option( + "--formatter", + type=click.Choice(["black", "none"]), + default="black", + show_default=True, + help="Option to choose which auto formatter is applied.", +) @click.version_option(version=__version__) def main( source: str, @@ -54,6 +60,7 @@ def main( use_orjson: bool = False, custom_template_path: Optional[str] = None, pydantic_version: PydanticVersion = PydanticVersion.V2, + formatter: Formatter = Formatter.BLACK, ) -> None: """ Generate Python code from an OpenAPI 3.0 specification. @@ -62,7 +69,7 @@ def main( an OUTPUT path, where the resulting client is created. """ generate_data( - source, output, library, env_token_name, use_orjson, custom_template_path,pydantic_version + source, output, library, env_token_name, use_orjson, custom_template_path, pydantic_version, formatter ) diff --git a/src/openapi_python_generator/common.py b/src/openapi_python_generator/common.py index 9748763..fea7713 100644 --- a/src/openapi_python_generator/common.py +++ b/src/openapi_python_generator/common.py @@ -14,11 +14,25 @@ class HTTPLibrary(str, Enum): requests = "requests" aiohttp = "aiohttp" + class PydanticVersion(str, Enum): V1 = "v1" V2 = "v2" +class Formatter(str, Enum): + """ + Enum for the available code formatters. + """ + + BLACK = "black" + NONE = "none" + +class FormatOptions: + skip_validation: bool = False + line_length: int = 120 + + library_config_dict: Dict[Optional[HTTPLibrary], LibraryConfig] = { HTTPLibrary.httpx: LibraryConfig( name="httpx", diff --git a/src/openapi_python_generator/generate_data.py b/src/openapi_python_generator/generate_data.py index e637012..3bbc8d4 100644 --- a/src/openapi_python_generator/generate_data.py +++ b/src/openapi_python_generator/generate_data.py @@ -1,4 +1,5 @@ from pathlib import Path +from typing import List from typing import Optional from typing import Union @@ -14,7 +15,7 @@ from openapi_pydantic.v3.v3_0 import OpenAPI from pydantic import ValidationError -from .common import HTTPLibrary, PydanticVersion +from .common import FormatOptions, Formatter, HTTPLibrary, PydanticVersion from .common import library_config_dict from .language_converters.python.generator import generator from .language_converters.python.jinja_config import SERVICE_TEMPLATE @@ -22,26 +23,31 @@ from .models import ConversionResult -def write_code(path: Path, content) -> None: +def write_code(path: Path, content: str, formatter: Formatter) -> None: """ Write the content to the file at the given path. - :param autoformat: The autoformat applied to the code written. :param path: The path to the file. :param content: The content to write. + :param formatter: The formatter applied to the code written. """ - try: - with open(path, "w") as f: - try: - formatted_contend = black.format_file_contents( - content, fast=False, mode=black.FileMode(line_length=120) - ) + if formatter == Formatter.BLACK: + formatted_contend = format_using_black(content) + elif formatter == Formatter.NONE: + formatted_contend = content + else: + raise NotImplementedError(f"Missing implementation for formatter {formatter!r}.") + with open(path, "w") as f: + f.write(formatted_contend) + - except NothingChanged: - formatted_contend = content - formatted_contend = isort.code(formatted_contend, line_length=120) - f.write(formatted_contend) - except Exception as e: - raise e +def format_using_black(content: str) -> str: + try: + formatted_contend = black.format_file_contents( + content, fast=FormatOptions.skip_validation, mode=black.FileMode(line_length=FormatOptions.line_length) + ) + except NothingChanged: + return content + return isort.code(formatted_contend, line_length=FormatOptions.line_length) def get_open_api(source: Union[str, Path]) -> OpenAPI: @@ -105,14 +111,14 @@ def get_open_api(source: Union[str, Path]) -> OpenAPI: raise -def write_data(data: ConversionResult, output: Union[str, Path]) -> None: +def write_data(data: ConversionResult, output: Union[str, Path], formatter: Formatter) -> None: """ - This function will firstly create the folderstrucutre of output, if it doesn't exist. Then it will create the + This function will firstly create the folder structure of output, if it doesn't exist. Then it will create the models from data.models into the models sub module of the output folder. After this, the services will be created into the services sub module of the output folder. - :param autoformat: The autoformat applied to the code written. :param data: The data to write. :param output: The path to the output folder. + :param formatter: The formatter applied to the code written. """ # Create the folder structure of the output folder. @@ -126,17 +132,18 @@ def write_data(data: ConversionResult, output: Union[str, Path]) -> None: services_path = Path(output) / "services" services_path.mkdir(parents=True, exist_ok=True) - files = [] + files: List[str] = [] # Write the models. for model in data.models: files.append(model.file_name) - write_code(models_path / f"{model.file_name}.py", model.content) + write_code(models_path / f"{model.file_name}.py", model.content, formatter) # Create models.__init__.py file containing imports to all models. write_code( models_path / "__init__.py", "\n".join([f"from .{file} import *" for file in files]), + formatter, ) files = [] @@ -150,18 +157,20 @@ def write_data(data: ConversionResult, output: Union[str, Path]) -> None: write_code( services_path / f"{service.file_name}.py", jinja_env.get_template(SERVICE_TEMPLATE).render(**service.dict()), + formatter, ) # Create services.__init__.py file containing imports to all services. - write_code(services_path / "__init__.py", "") + write_code(services_path / "__init__.py", "", formatter) # Write the api_config.py file. - write_code(Path(output) / "api_config.py", data.api_config.content) + write_code(Path(output) / "api_config.py", data.api_config.content, formatter) # Write the __init__.py file. write_code( Path(output) / "__init__.py", "from .models import *\nfrom .services import *\nfrom .api_config import *", + formatter, ) @@ -173,6 +182,7 @@ def generate_data( use_orjson: bool = False, custom_template_path: Optional[str] = None, pydantic_version: PydanticVersion = PydanticVersion.V2, + formatter: Formatter = Formatter.BLACK, ) -> None: """ Generate Python code from an OpenAPI 3.0 specification. @@ -189,4 +199,4 @@ def generate_data( pydantic_version, ) - write_data(result, output) + write_data(result, output, formatter) diff --git a/tests/test_generate_data.py b/tests/test_generate_data.py index 4214941..fd69f3c 100644 --- a/tests/test_generate_data.py +++ b/tests/test_generate_data.py @@ -1,4 +1,6 @@ +from pathlib import Path import shutil +import subprocess import pytest import yaml @@ -6,7 +8,7 @@ from orjson import orjson from pydantic import ValidationError -from openapi_python_generator.common import HTTPLibrary +from openapi_python_generator.common import FormatOptions, Formatter, HTTPLibrary from openapi_python_generator.common import library_config_dict from openapi_python_generator.generate_data import generate_data from openapi_python_generator.generate_data import get_open_api @@ -66,7 +68,7 @@ def test_generate_data(model_data_with_cleanup): def test_write_data(model_data_with_cleanup): result = generator(model_data_with_cleanup, library_config_dict[HTTPLibrary.httpx]) - write_data(result, test_result_path) + write_data(result, test_result_path, Formatter.BLACK) assert test_result_path.exists() assert test_result_path.is_dir() @@ -90,7 +92,7 @@ def test_write_data(model_data_with_cleanup): model_data_copy.paths = None result = generator(model_data_copy, library_config_dict[HTTPLibrary.httpx]) - write_data(result, test_result_path) + write_data(result, test_result_path, Formatter.BLACK) assert test_result_path.exists() assert test_result_path.is_dir() @@ -105,3 +107,78 @@ def test_write_data(model_data_with_cleanup): assert (test_result_path / "models" / "__init__.py").is_file() assert (test_result_path / "__init__.py").exists() assert (test_result_path / "__init__.py").is_file() + +def test_write_formatted_data(model_data_with_cleanup): + result = generator(model_data_with_cleanup, library_config_dict[HTTPLibrary.httpx]) + + # First write code without formatter + write_data(result, test_result_path, Formatter.NONE) + + assert test_result_path.exists() + assert test_result_path.is_dir() + assert (test_result_path / "api_config.py").exists() + assert (test_result_path / "models").exists() + assert (test_result_path / "models").is_dir() + assert (test_result_path / "services").exists() + assert (test_result_path / "services").is_dir() + assert (test_result_path / "models" / "__init__.py").exists() + assert (test_result_path / "services" / "__init__.py").exists() + assert (test_result_path / "services" / "__init__.py").is_file() + assert (test_result_path / "models" / "__init__.py").is_file() + assert (test_result_path / "__init__.py").exists() + assert (test_result_path / "__init__.py").is_file() + + assert not files_are_black_formatted(test_result_path) + + # delete test_result_path folder + shutil.rmtree(test_result_path) + + model_data_copy = model_data_with_cleanup.copy() + model_data_copy.components = None + model_data_copy.paths = None + + result = generator(model_data_copy, library_config_dict[HTTPLibrary.httpx]) + write_data(result, test_result_path, Formatter.BLACK) + + assert test_result_path.exists() + assert test_result_path.is_dir() + assert (test_result_path / "api_config.py").exists() + assert (test_result_path / "models").exists() + assert (test_result_path / "models").is_dir() + assert (test_result_path / "services").exists() + assert (test_result_path / "services").is_dir() + assert (test_result_path / "models" / "__init__.py").exists() + assert (test_result_path / "services" / "__init__.py").exists() + assert (test_result_path / "services" / "__init__.py").is_file() + assert (test_result_path / "models" / "__init__.py").is_file() + assert (test_result_path / "__init__.py").exists() + assert (test_result_path / "__init__.py").is_file() + + assert files_are_black_formatted(test_result_path) + +def files_are_black_formatted(test_result_path: Path) -> bool: + # Run the `black --check` command on all files. This does not write any file. + result = subprocess.run([ + "black", + "--check", + # Overwrite any exclusion due to a .gitignore. + "--exclude", "''", + # Settings also used when formatting the code when writing it + "--fast" if FormatOptions.skip_validation else "--safe", + "--line-length", str(FormatOptions.line_length), + # The source directory + str(test_result_path.absolute()) + ], + capture_output=True, + text=True + ) + + # With `--check` the return status has the following meaning: + # - Return code 0 means nothing would change. + # - Return code 1 means some files would be reformatted. + # - Return code 123 means there was an internal error. + + if result.returncode == 123: + result.check_returncode # raise the error + + return result.returncode == 0 \ No newline at end of file