From 7d8e809dc73ebbb76ff849a68c8f51172b9b06a2 Mon Sep 17 00:00:00 2001 From: Chiheb Date: Sun, 4 Dec 2022 17:36:46 +0100 Subject: [PATCH 1/2] Auto gen API model --- flask_restx/tools/__init__.py | 1 + flask_restx/tools/gen_api_model.py | 214 ++++++++++++++ requirements/test.pip | 1 + tests/test_auto_gen_api_model.py | 455 +++++++++++++++++++++++++++++ 4 files changed, 671 insertions(+) create mode 100644 flask_restx/tools/__init__.py create mode 100644 flask_restx/tools/gen_api_model.py create mode 100644 tests/test_auto_gen_api_model.py diff --git a/flask_restx/tools/__init__.py b/flask_restx/tools/__init__.py new file mode 100644 index 00000000..2c20b5ec --- /dev/null +++ b/flask_restx/tools/__init__.py @@ -0,0 +1 @@ +from .gen_api_model import * # noqa diff --git a/flask_restx/tools/gen_api_model.py b/flask_restx/tools/gen_api_model.py new file mode 100644 index 00000000..78499008 --- /dev/null +++ b/flask_restx/tools/gen_api_model.py @@ -0,0 +1,214 @@ +"""Auto generate SQLAlchemy API model schema from database table""" +from flask_restx.fields import ( + List as Listx, + Nested as Nestedx, + Raw, + String, + DateTime, + Date, + Boolean, + Integer, + Float, +) + +__all__ = ["gen_api_model_from_db"] + + +SQLALCHEMY_TYPES = { + "ARRAY": Listx, + "INT": Integer, + "CHAR": String, + "VARCHAR": String, + "NCHAR": String, + "NVARCHAR": String, + "TEXT": String, + "Text": String, + "FLOAT": String, + "NUMERIC": String, + "REAL": Float, + "DECIMAL": Float, + "TIMESTAMP": Float, + "DATETIME": DateTime, + "BOOLEAN": Boolean, + "BIGINT": Integer, + "SMALLINT": Integer, + "INTEGER": Integer, + "DATE": Date, + "TIME": DateTime, + "String": String, + "Integer": Integer, + "SmallInteger": Integer, + "BigInteger": Integer, + "Numeric": Float, + "Float": Float, + "DateTime": DateTime, + "Date": Date, + "Time": DateTime, + "Boolean": Boolean, + "Unicode": String, + "UnicodeText": String, + "JSON": Raw, +} + + +class Utilities: + """Utilities""" + + def __init__(self, force_camel_case: bool = True): + self.force_camel_case = force_camel_case + + def to_camel_case(self, attribute_name: str, sep="_"): + """Convert attribute name separated by sep to camelCase""" + if not self.force_camel_case: + return attribute_name + head, *tail = attribute_name.split(sep) + tail_capitalized = [k.capitalize() for k in tail] + return "".join([head] + tail_capitalized) + + +class ModelSchema(Utilities): + """Generate API model schema from SQLAlchemy database model""" + + __slots__ = ( + "api", + "model", + "fields", + "ignore_attributes", + "parents", + ) + + def __init__( + self, + api, # type: any + model, # type: any + fields=[], # type: list[str] + force_camel_case=True, # type: bool + ignore_attributes=[], # type: list[str] + parents=[], # type: list[any] + ): + super().__init__(force_camel_case) + self.api = api + self.model = model + self.fields = fields + self.ignore_attributes = ignore_attributes + self.parents = parents + + def get_api_data_type(self, db_field, attribute_name): + # type: (any, str) -> any + """Get data type from database field""" + db_field_cls = SQLALCHEMY_TYPES.get(db_field.type.__class__.__name__, None) + if db_field_cls is None: + raise ValueError( + f"Database field type <{db_field}:{db_field.type}> is not recognized/supported" + ) + try: + return db_field_cls(attribute=attribute_name) + except TypeError: + return db_field_cls( + SQLALCHEMY_TYPES.get( + db_field.type.__dict__.get("item_type", String).__class__.__name__ + ) + ) + + def _foreign_keys_conditon(self, model, elm, with_mapper=False): + # type: (any, str, bool) -> bool + has_mapper = hasattr(getattr(model, elm), "mapper") + base_condition = ( + not elm.startswith("_") + and not elm.endswith("_") + and elm not in self.ignore_attributes + and elm != "Meta" # Ignore Meta class + # Should not be a function + and not callable(getattr(model, elm, None)) + ) + if not with_mapper: + return base_condition and not has_mapper + if has_mapper and getattr(model, elm).mapper.class_ in self.parents: + return False + return base_condition and has_mapper + + def attrs_without_foreign_keys_condition(self, model, elm): + # type: (any, str) -> bool + """Return database model attributes without foreign keys""" + return self._foreign_keys_conditon(model, elm) + + def attrs_with_foreign_keys_condition(self, model, elm): + # type: (any, str) -> bool + """Return database model attributes with only foreign keys""" + return self._foreign_keys_conditon(model, elm, with_mapper=True) + + def get_model_fields(self, model, fields=[], use_columns=False): + # type: (any, list[str], bool) -> tuple | list + """Return model Meta fields or columns fields""" + if fields: + return fields + if hasattr(model, "Meta"): + if model.Meta.fields == "__all__": + return model.__dict__ + return model.Meta.fields + if use_columns: + return model.__table__.columns.keys() + return model.__dict__ + + def gen_api_model_from_db(self): + # type: () -> dict + """Gen API model from DB""" + self.parents.append(self.model) + attributes = [ + k + for k in self.get_model_fields(self.model, self.fields, use_columns=True) + if self.attrs_without_foreign_keys_condition(self.model, k) + ] # type: list[str] + + # For Nested mappings it's recommended to use a proper Meta class for each database model object + # Like this you can keep track and handle each model fields easly; better than using a default fields + if not self.fields: + mappers = [ + k + for k in self.get_model_fields(self.model) + if self.attrs_with_foreign_keys_condition(self.model, k) + ] # type: list[str | None] + else: + mappers = [] # type: list[str | None] + simple_mappings = { + self.to_camel_case(attribute): self.get_api_data_type( + self.model.__dict__.get(attribute), attribute + ) + for attribute in attributes + } + if not self.fields: + nested = { + self.to_camel_case(attribute): Listx( + Nestedx( + self.api.model( + f"Nested{attribute.capitalize()}", + ModelSchema( + api=self.api, + model=self.model.__dict__.get(attribute).mapper.class_, + force_camel_case=self.force_camel_case, + ignore_attributes=self.ignore_attributes, + parents=self.parents, + ).gen_api_model_from_db(), + ) + ) + ) + for attribute in mappers + } # type: dict + else: + nested = {} # type: dict + return {**simple_mappings, **nested} + + +def gen_api_model_from_db( + api, model, fields=[], force_camel_case=True, ignore_attributes=[] +): + # type: (any, any, list[str], bool, list[str]) -> dict + """Helper function""" + return ModelSchema( + api=api, + model=model, + fields=fields, + force_camel_case=force_camel_case, + ignore_attributes=ignore_attributes, + parents=[], # Need to force the value here otherwise it'll keep track of previous func calls + ).gen_api_model_from_db() diff --git a/requirements/test.pip b/requirements/test.pip index 6c449a79..4646505b 100644 --- a/requirements/test.pip +++ b/requirements/test.pip @@ -10,3 +10,4 @@ pytest-profiling==1.7.0 tzlocal invoke==1.7.3 twine==3.8.0 +Flask-SQLAlchemy==3.0.2 diff --git a/tests/test_auto_gen_api_model.py b/tests/test_auto_gen_api_model.py new file mode 100644 index 00000000..8c099ca8 --- /dev/null +++ b/tests/test_auto_gen_api_model.py @@ -0,0 +1,455 @@ +"""Test Auto generate SQLAlchemy API model""" +import pytest +from flask_sqlalchemy import SQLAlchemy +from sqlalchemy import Column, DateTime, ForeignKey, Integer, String +from sqlalchemy.orm import declarative_base, relationship + +from flask_restx import Resource, fields, marshal +from flask_restx.tools import gen_api_model_from_db + + +class FixtureTestCase(object): + @pytest.fixture + def db(self, app): + app.config["SQLALCHEMY_DATABASE_URI"] = "sqlite:///:memory:" + db = SQLAlchemy(app) + yield db + + @pytest.fixture + def user_model(self, db): + class User(db.Model): + id = db.Column(db.Integer, primary_key=True) + username = db.Column(db.String(80), unique=True, nullable=False) + email = db.Column(db.String(120), unique=True, nullable=False) + + def __repr__(self): + return "" % self.username + + class Meta: + fields = "__all__" + + return User + + @pytest.fixture + def user_model_with_relations(self, db): + class Address(db.Model): + id = db.Column(db.Integer, primary_key=True) + road = db.Column(db.String) + person_id = db.Column( + db.Integer, db.ForeignKey("person.id"), nullable=False + ) + + def __repr__(self): + return f"{self.road}" + + class Meta: + fields = ("road",) + + class Person(db.Model): + id = db.Column(db.Integer, primary_key=True) + name = db.Column(db.String) + age = db.Column(db.Integer) + birth_date = db.Column(db.DateTime) + addresses = db.relationship("Address", backref="person", lazy=True) + + def __repr__(self): + return f"{self.name}" + + class Meta: + fields = ("id", "name", "birth_date", "addresses") + + yield Person + + @pytest.fixture + def models_with_deep_nested_relations(self, db): + class Country(db.Model): + id = db.Column(db.Integer, primary_key=True) + name = db.Column(db.String) + address = db.relationship("Address", backref="country", lazy=True) + + def __repr__(self): + return f"{self.name}" + + class Meta: + fields = "__all__" + + class Address(db.Model): + id = db.Column(db.Integer, primary_key=True) + road = db.Column(db.String) + person_id = db.Column( + db.Integer, db.ForeignKey("person.id"), nullable=False + ) + + country_id = db.Column( + db.Integer, db.ForeignKey("country.id"), nullable=False + ) + + def __repr__(self): + return f"{self.road}" + + class Meta: + fields = ("id", "road", "country") + + class Person(db.Model): + id = db.Column(db.Integer, primary_key=True) + name = db.Column(db.String) + age = db.Column(db.Integer) + birth_date = db.Column(db.DateTime) + addresses = db.relationship("Address", backref="person", lazy=True) + + def __repr__(self): + return f"{self.name}" + + class Meta: + fields = ("id", "name", "birth_date", "addresses") + + yield {"person": Person, "address": Address, "country": Country} + + @pytest.fixture + def declarative_models_with_deep_nested_relations(self): + Base = declarative_base() + + class Country(Base): + __tablename__ = "country" + + id = Column(Integer, primary_key=True) + name = Column(String) + address = relationship("Address", backref="country", lazy=True) + + def __repr__(self): + return f"{self.name}" + + class Meta: + fields = "__all__" + + class Address(Base): + __tablename__ = "address" + + id = Column(Integer, primary_key=True) + road = Column(String) + person_id = Column(Integer, ForeignKey("person.id"), nullable=False) + + country_id = Column(Integer, ForeignKey("country.id"), nullable=False) + + def __repr__(self): + return f"{self.road}" + + class Meta: + fields = ("id", "road", "country") + + class Person(Base): + __tablename__ = "person" + + id = Column(Integer, primary_key=True) + name = Column(String) + age = Column(Integer) + birth_date = Column(DateTime) + addresses = relationship("Address", backref="person", lazy=True) + + def __repr__(self): + return f"{self.name}" + + class Meta: + fields = ("id", "name", "birth_date", "addresses") + + yield {"person": Person, "address": Address, "country": Country} + + +class AutoGenAPIModelTest(FixtureTestCase): + def test_user_model(self, user_model, api): + payload = {"id": 1, "username": "toto", "email": "toto@tata.tt"} + schema = gen_api_model_from_db(api, user_model) + marshalled = marshal(payload, schema) + assert marshalled == payload + + def test_model_as_flat_dict_with_marchal_decorator_list( + self, api, client, user_model + ): + fields = api.model("Person", gen_api_model_from_db(api, user_model)) + + @api.route("/model-as-dict/") + class ModelAsDict(Resource): + @api.marshal_list_with(fields) + def get(self): + return {} + + data = client.get_specs() + + assert "definitions" in data + assert "Person" in data["definitions"] + assert data["definitions"]["Person"] == { + "properties": { + "id": {"type": "integer"}, + "username": {"type": "string"}, + "email": {"type": "string"}, + }, + "type": "object", + } + + path = data["paths"]["/model-as-dict/"] + assert path["get"]["responses"]["200"]["schema"] == { + "type": "array", + "items": {"$ref": "#/definitions/Person"}, + } + + def test_model_as_flat_dict_with_marchal_decorator_list_kwargs( + self, api, client, user_model + ): + fields = api.model("Person", gen_api_model_from_db(api, user_model)) + + @api.route("/model-as-dict/") + class ModelAsDict(Resource): + @api.marshal_list_with(fields, code=201, description="Some details") + def get(self): + return {} + + data = client.get_specs() + + assert "definitions" in data + assert "Person" in data["definitions"] + + path = data["paths"]["/model-as-dict/"] + assert path["get"]["responses"] == { + "201": { + "description": "Some details", + "schema": { + "type": "array", + "items": {"$ref": "#/definitions/Person"}, + }, + } + } + + def test_model_as_dict_with_list(self, api, client, db): + class User(db.Model): + id = db.Column(db.Integer, primary_key=True) + username = db.Column(db.String(80), unique=True, nullable=False) + tags = db.Column(db.ARRAY(db.String)) + + def __repr__(self): + return "" % self.username + + class Meta: + fields = "__all__" + + fields = api.model("Person", gen_api_model_from_db(api, User)) + + @api.route("/model-with-list/") + class ModelAsDict(Resource): + @api.doc(model=fields) + def get(self): + return {} + + data = client.get_specs() + + assert "definitions" in data + assert "Person" in data["definitions"] + assert data["definitions"]["Person"] == { + "properties": { + "id": {"type": "integer"}, + "username": {"type": "string"}, + "tags": {"type": "array", "items": {"type": "string"}}, + }, + "type": "object", + } + + path = data["paths"]["/model-with-list/"] + assert path["get"]["responses"]["200"]["schema"] == { + "$ref": "#/definitions/Person" + } + + def test_model_as_nested_dict_with_list( + self, api, client, user_model_with_relations + ): + + person = api.model( + "Person", gen_api_model_from_db(api, user_model_with_relations) + ) + + @api.route("/model-with-list/") + class ModelAsDict(Resource): + @api.doc(model=person) + def get(self): + return {} + + data = client.get_specs() + + assert "definitions" in data + assert "Person" in data["definitions"] + assert "NestedAddresses" in data["definitions"] + + def test_model_as_nested_dict_with_list_limited_fields( + self, api, client, user_model_with_relations + ): + + person = api.model( + "Person", gen_api_model_from_db(api, user_model_with_relations) + ) + + @api.route("/model-with-list/") + class ModelAsDict(Resource): + @api.doc(model=person) + def get(self): + return {} + + data = client.get_specs() + + assert "definitions" in data + assert "Person" in data["definitions"] + assert "NestedAddresses" in data["definitions"] + assert data["definitions"]["Person"] == { + "properties": { + "id": {"type": "integer"}, + "name": {"type": "string"}, + "birthDate": {"format": "date-time", "type": "string"}, + "addresses": { + "type": "array", + "items": {"$ref": "#/definitions/NestedAddresses"}, + }, + }, + "type": "object", + } + assert data["definitions"]["NestedAddresses"] == { + "properties": {"road": {"type": "string"}}, + "type": "object", + } + path = data["paths"]["/model-with-list/"] + assert path["get"]["responses"]["200"]["schema"] == { + "$ref": "#/definitions/Person" + } + + def test_model_as_deep_nested_dict_with_list_limited_fields( + self, api, client, models_with_deep_nested_relations + ): + + person = api.model( + "Person", + gen_api_model_from_db(api, models_with_deep_nested_relations["person"]), + ) + + @api.route("/model-with-list/") + class ModelAsDict(Resource): + @api.doc(model=person) + def get(self): + return {} + + data = client.get_specs() + + assert "definitions" in data + assert "Person" in data["definitions"] + assert "NestedAddresses" in data["definitions"] + assert "NestedCountry" in data["definitions"] + assert data["definitions"]["Person"] == { + "properties": { + "id": {"type": "integer"}, + "name": {"type": "string"}, + "birthDate": {"format": "date-time", "type": "string"}, + "addresses": { + "type": "array", + "items": {"$ref": "#/definitions/NestedAddresses"}, + }, + }, + "type": "object", + } + assert data["definitions"]["NestedAddresses"] == { + "properties": { + "id": {"type": "integer"}, + "road": {"type": "string"}, + "country": { + "type": "array", + "items": {"$ref": "#/definitions/NestedCountry"}, + }, + }, + "type": "object", + } + assert data["definitions"]["NestedCountry"] == { + "properties": {"id": {"type": "integer"}, "name": {"type": "string"}}, + "type": "object", + } + path = data["paths"]["/model-with-list/"] + assert path["get"]["responses"]["200"]["schema"] == { + "$ref": "#/definitions/Person" + } + + def test_model_as_deep_nested_dict_with_list_static_fields( + self, api, client, models_with_deep_nested_relations + ): + + addresses = api.model( + "Addresses", + gen_api_model_from_db( + api, models_with_deep_nested_relations["address"], fields=("id", "road") + ), + ) + countries = api.model( + "Countries", + gen_api_model_from_db( + api, models_with_deep_nested_relations["country"], fields=("name",) + ), + ) + person = api.model( + "Person", + { + **gen_api_model_from_db( + api, + models_with_deep_nested_relations["person"], + fields=("id", "name", "birth_date"), + ), + "customAddressesFieldName": fields.List(fields.Nested(addresses)), + "customCountriesFieldName": fields.List(fields.Nested(countries)), + }, + ) + + @api.route("/model-with-list/") + class ModelAsDict(Resource): + @api.doc(model=person) + def get(self): + return {} + + data = client.get_specs() + + assert "definitions" in data + assert "Person" in data["definitions"] + assert "Addresses" in data["definitions"] + assert "Countries" in data["definitions"] + assert data["definitions"]["Person"] == { + "properties": { + "id": {"type": "integer"}, + "name": {"type": "string"}, + "birthDate": {"format": "date-time", "type": "string"}, + "customAddressesFieldName": { + "type": "array", + "items": {"$ref": "#/definitions/Addresses"}, + }, + "customCountriesFieldName": { + "type": "array", + "items": {"$ref": "#/definitions/Countries"}, + }, + }, + "type": "object", + } + assert data["definitions"]["Addresses"] == { + "properties": {"id": {"type": "integer"}, "road": {"type": "string"}}, + "type": "object", + } + assert data["definitions"]["Countries"] == { + "properties": {"name": {"type": "string"}}, + "type": "object", + } + path = data["paths"]["/model-with-list/"] + assert path["get"]["responses"]["200"]["schema"] == { + "$ref": "#/definitions/Person" + } + + def test_declarative_model_as_deep_nested_dict_with_list_limited_fields( + self, api, client, declarative_models_with_deep_nested_relations + ): + return self.test_model_as_deep_nested_dict_with_list_limited_fields( + api, client, declarative_models_with_deep_nested_relations + ) + + def test_declarative_model_as_deep_nested_dict_with_list_static_fields( + self, api, client, declarative_models_with_deep_nested_relations + ): + return self.test_model_as_deep_nested_dict_with_list_static_fields( + api, client, declarative_models_with_deep_nested_relations + ) From 8822358dbbe07b18f1599e8b54ace30699df5557 Mon Sep 17 00:00:00 2001 From: chiheb-nexus Date: Sun, 4 Dec 2022 21:25:03 +0100 Subject: [PATCH 2/2] Fix types --- flask_restx/tools/gen_api_model.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/flask_restx/tools/gen_api_model.py b/flask_restx/tools/gen_api_model.py index 78499008..d89f78ea 100644 --- a/flask_restx/tools/gen_api_model.py +++ b/flask_restx/tools/gen_api_model.py @@ -23,8 +23,8 @@ "NVARCHAR": String, "TEXT": String, "Text": String, - "FLOAT": String, - "NUMERIC": String, + "FLOAT": Float, + "NUMERIC": Float, "REAL": Float, "DECIMAL": Float, "TIMESTAMP": Float,