Skip to content

fix!: Change reference resolution to use reference path instead of class name (fixes #342) #366

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from .free_form_model import FreeFormModel
from .http_validation_error import HTTPValidationError
from .model_from_all_of import ModelFromAllOf
from .model_name import ModelName
from .model_with_additional_properties_inlined import ModelWithAdditionalPropertiesInlined
from .model_with_additional_properties_inlined_additional_property import (
ModelWithAdditionalPropertiesInlinedAdditionalProperty,
Expand All @@ -19,6 +20,7 @@
from .model_with_any_json_properties_additional_property_type0 import ModelWithAnyJsonPropertiesAdditionalPropertyType0
from .model_with_primitive_additional_properties import ModelWithPrimitiveAdditionalProperties
from .model_with_primitive_additional_properties_a_date_holder import ModelWithPrimitiveAdditionalPropertiesADateHolder
from .model_with_property_ref import ModelWithPropertyRef
from .model_with_union_property import ModelWithUnionProperty
from .model_with_union_property_inlined import ModelWithUnionPropertyInlined
from .model_with_union_property_inlined_fruit_type0 import ModelWithUnionPropertyInlinedFruitType0
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
from typing import Any, Dict, List, Type, TypeVar

import attr

T = TypeVar("T", bound="ModelName")


@attr.s(auto_attribs=True)
class ModelName:
""" """

additional_properties: Dict[str, Any] = attr.ib(init=False, factory=dict)

def to_dict(self) -> Dict[str, Any]:

field_dict: Dict[str, Any] = {}
field_dict.update(self.additional_properties)
field_dict.update({})

return field_dict

@classmethod
def from_dict(cls: Type[T], src_dict: Dict[str, Any]) -> T:
d = src_dict.copy()
model_name = cls()

model_name.additional_properties = d
return model_name

@property
def additional_keys(self) -> List[str]:
return list(self.additional_properties.keys())

def __getitem__(self, key: str) -> Any:
return self.additional_properties[key]

def __setitem__(self, key: str, value: Any) -> None:
self.additional_properties[key] = value

def __delitem__(self, key: str) -> None:
del self.additional_properties[key]

def __contains__(self, key: str) -> bool:
return key in self.additional_properties
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
from typing import Any, Dict, List, Type, TypeVar, Union

import attr

from ..models.model_name import ModelName
from ..types import UNSET, Unset

T = TypeVar("T", bound="ModelWithPropertyRef")


@attr.s(auto_attribs=True)
class ModelWithPropertyRef:
""" """

inner: Union[Unset, ModelName] = UNSET
additional_properties: Dict[str, Any] = attr.ib(init=False, factory=dict)

def to_dict(self) -> Dict[str, Any]:
inner: Union[Unset, Dict[str, Any]] = UNSET
if not isinstance(self.inner, Unset):
inner = self.inner.to_dict()

field_dict: Dict[str, Any] = {}
field_dict.update(self.additional_properties)
field_dict.update({})
if inner is not UNSET:
field_dict["inner"] = inner

return field_dict

@classmethod
def from_dict(cls: Type[T], src_dict: Dict[str, Any]) -> T:
d = src_dict.copy()
inner: Union[Unset, ModelName] = UNSET
_inner = d.pop("inner", UNSET)
if not isinstance(_inner, Unset):
inner = ModelName.from_dict(_inner)

model_with_property_ref = cls(
inner=inner,
)

model_with_property_ref.additional_properties = d
return model_with_property_ref

@property
def additional_keys(self) -> List[str]:
return list(self.additional_properties.keys())

def __getitem__(self, key: str) -> Any:
return self.additional_properties[key]

def __setitem__(self, key: str, value: Any) -> None:
self.additional_properties[key] = value

def __delitem__(self, key: str) -> None:
del self.additional_properties[key]

def __contains__(self, key: str) -> bool:
return key in self.additional_properties
26 changes: 18 additions & 8 deletions end_to_end_tests/openapi.json
Original file line number Diff line number Diff line change
Expand Up @@ -819,43 +819,43 @@
"one_of_models": {
"oneOf": [
{
"ref": "#components/schemas/FreeFormModel"
"ref": "#/components/schemas/FreeFormModel"
},
{
"ref": "#components/schemas/ModelWithUnionProperty"
"ref": "#/components/schemas/ModelWithUnionProperty"
}
],
"nullable": false
},
"nullable_one_of_models": {
"oneOf": [
{
"ref": "#components/schemas/FreeFormModel"
"ref": "#/components/schemas/FreeFormModel"
},
{
"ref": "#components/schemas/ModelWithUnionProperty"
"ref": "#/components/schemas/ModelWithUnionProperty"
}
],
"nullable": true
},
"not_required_one_of_models": {
"oneOf": [
{
"ref": "#components/schemas/FreeFormModel"
"ref": "#/components/schemas/FreeFormModel"
},
{
"ref": "#components/schemas/ModelWithUnionProperty"
"ref": "#/components/schemas/ModelWithUnionProperty"
}
],
"nullable": false
},
"not_required_nullable_one_of_models": {
"oneOf": [
{
"ref": "#components/schemas/FreeFormModel"
"ref": "#/components/schemas/FreeFormModel"
},
{
"ref": "#components/schemas/ModelWithUnionProperty"
"ref": "#/components/schemas/ModelWithUnionProperty"
},
{
"type": "string"
Expand Down Expand Up @@ -1110,6 +1110,16 @@
"type": "string"
}
}
},
"model_reference_doesnt_match": {
"title": "ModelName",
"type": "object"
},
"ModelWithPropertyRef": {
"type": "object",
"properties": {
"inner": {"$ref": "#/components/schemas/model_reference_doesnt_match"}
}
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion end_to_end_tests/regen_golden_record.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
shutil.rmtree(gr_path, ignore_errors=True)
shutil.rmtree(output_path, ignore_errors=True)

result = runner.invoke(app, [f"--config={config_path}", "generate", f"--path={openapi_path}"])
result = runner.invoke(app, ["generate", f"--config={config_path}", f"--path={openapi_path}"])

if result.stdout:
print(result.stdout)
Expand Down
2 changes: 1 addition & 1 deletion end_to_end_tests/test_end_to_end.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def run_e2e_test(extra_args=None, expected_differences=None):
output_path = Path.cwd() / "my-test-api-client"
shutil.rmtree(output_path, ignore_errors=True)

args = [f"--config={config_path}", "generate", f"--path={openapi_path}"]
args = ["generate", f"--config={config_path}", f"--path={openapi_path}"]
if extra_args:
args.extend(extra_args)
result = runner.invoke(app, args)
Expand Down
53 changes: 35 additions & 18 deletions openapi_python_client/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@

from openapi_python_client import utils

from .parser import GeneratorData, import_string_from_reference
from .config import Config
from .parser import GeneratorData, import_string_from_class
from .parser.errors import GeneratorError
from .utils import snake_case

Expand All @@ -41,15 +42,12 @@ class MetaType(str, Enum):


class Project:
project_name_override: Optional[str] = None
package_name_override: Optional[str] = None
package_version_override: Optional[str] = None

def __init__(
self,
*,
openapi: GeneratorData,
meta: MetaType,
config: Config,
custom_template_path: Optional[Path] = None,
file_encoding: str = "utf-8",
) -> None:
Expand All @@ -70,17 +68,17 @@ def __init__(
loader = package_loader
self.env: Environment = Environment(loader=loader, trim_blocks=True, lstrip_blocks=True)

self.project_name: str = self.project_name_override or f"{utils.kebab_case(openapi.title).lower()}-client"
self.project_name: str = config.project_name_override or f"{utils.kebab_case(openapi.title).lower()}-client"
self.project_dir: Path = Path.cwd()
if meta != MetaType.NONE:
self.project_dir /= self.project_name

self.package_name: str = self.package_name_override or self.project_name.replace("-", "_")
self.package_name: str = config.package_name_override or self.project_name.replace("-", "_")
self.package_dir: Path = self.project_dir / self.package_name
self.package_description: str = utils.remove_string_escapes(
f"A client library for accessing {self.openapi.title}"
)
self.version: str = self.package_version_override or openapi.version
self.version: str = config.package_version_override or openapi.version

self.env.filters.update(TEMPLATE_FILTERS)

Expand Down Expand Up @@ -215,21 +213,21 @@ def _build_models(self) -> None:
imports = []

model_template = self.env.get_template("model.py.jinja")
for model in self.openapi.models.values():
module_path = models_dir / f"{model.reference.module_name}.py"
for model in self.openapi.models:
module_path = models_dir / f"{model.class_info.module_name}.py"
module_path.write_text(model_template.render(model=model), encoding=self.file_encoding)
imports.append(import_string_from_reference(model.reference))
imports.append(import_string_from_class(model.class_info))

# Generate enums
str_enum_template = self.env.get_template("str_enum.py.jinja")
int_enum_template = self.env.get_template("int_enum.py.jinja")
for enum in self.openapi.enums.values():
module_path = models_dir / f"{enum.reference.module_name}.py"
for enum in self.openapi.enums:
module_path = models_dir / f"{enum.class_info.module_name}.py"
if enum.value_type is int:
module_path.write_text(int_enum_template.render(enum=enum), encoding=self.file_encoding)
else:
module_path.write_text(str_enum_template.render(enum=enum), encoding=self.file_encoding)
imports.append(import_string_from_reference(enum.reference))
imports.append(import_string_from_class(enum.class_info))

models_init_template = self.env.get_template("models_init.py.jinja")
models_init.write_text(models_init_template.render(imports=imports), encoding=self.file_encoding)
Expand Down Expand Up @@ -261,23 +259,31 @@ def _get_project_for_url_or_path(
url: Optional[str],
path: Optional[Path],
meta: MetaType,
config: Config,
custom_template_path: Optional[Path] = None,
file_encoding: str = "utf-8",
) -> Union[Project, GeneratorError]:
data_dict = _get_document(url=url, path=path)
if isinstance(data_dict, GeneratorError):
return data_dict
openapi = GeneratorData.from_dict(data_dict)
openapi = GeneratorData.from_dict(data_dict, config=config)
if isinstance(openapi, GeneratorError):
return openapi
return Project(openapi=openapi, custom_template_path=custom_template_path, meta=meta, file_encoding=file_encoding)
return Project(
openapi=openapi,
custom_template_path=custom_template_path,
meta=meta,
file_encoding=file_encoding,
config=config,
)


def create_new_client(
*,
url: Optional[str],
path: Optional[Path],
meta: MetaType,
config: Config,
custom_template_path: Optional[Path] = None,
file_encoding: str = "utf-8",
) -> Sequence[GeneratorError]:
Expand All @@ -288,7 +294,12 @@ def create_new_client(
A list containing any errors encountered when generating.
"""
project = _get_project_for_url_or_path(
url=url, path=path, custom_template_path=custom_template_path, meta=meta, file_encoding=file_encoding
url=url,
path=path,
custom_template_path=custom_template_path,
meta=meta,
file_encoding=file_encoding,
config=config,
)
if isinstance(project, GeneratorError):
return [project]
Expand All @@ -300,6 +311,7 @@ def update_existing_client(
url: Optional[str],
path: Optional[Path],
meta: MetaType,
config: Config,
custom_template_path: Optional[Path] = None,
file_encoding: str = "utf-8",
) -> Sequence[GeneratorError]:
Expand All @@ -310,7 +322,12 @@ def update_existing_client(
A list containing any errors encountered when generating.
"""
project = _get_project_for_url_or_path(
url=url, path=path, custom_template_path=custom_template_path, meta=meta, file_encoding=file_encoding
url=url,
path=path,
custom_template_path=custom_template_path,
meta=meta,
file_encoding=file_encoding,
config=config,
)
if isinstance(project, GeneratorError):
return [project]
Expand Down
Loading