From c3f8340c155a5f2d953e14a8f1ec7fc529b7a8c1 Mon Sep 17 00:00:00 2001 From: Chris White Date: Tue, 23 Nov 2021 22:22:34 -0800 Subject: [PATCH 1/3] Enum fixes (typing and inheritance) Addresses issues #96 & #164 --- sqlmodel/main.py | 17 +++-------- tests/test_enums.py | 72 +++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 76 insertions(+), 13 deletions(-) create mode 100644 tests/test_enums.py diff --git a/sqlmodel/main.py b/sqlmodel/main.py index 661276b31d..315424d65e 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -31,18 +31,9 @@ from pydantic.main import BaseConfig, ModelMetaclass, validate_model from pydantic.typing import ForwardRef, NoArgAnyCallable, resolve_annotations from pydantic.utils import ROOT_KEY, Representation -from sqlalchemy import ( - Boolean, - Column, - Date, - DateTime, - Float, - ForeignKey, - Integer, - Interval, - Numeric, - inspect, -) +from sqlalchemy import Boolean, Column, Date, DateTime +from sqlalchemy import Enum as sa_Enum +from sqlalchemy import Float, ForeignKey, Integer, Interval, Numeric, inspect from sqlalchemy.orm import RelationshipProperty, declared_attr, registry, relationship from sqlalchemy.orm.attributes import set_attribute from sqlalchemy.orm.decl_api import DeclarativeMeta @@ -389,7 +380,7 @@ def get_sqlachemy_type(field: ModelField) -> Any: if issubclass(field.type_, time): return Time if issubclass(field.type_, Enum): - return Enum + return sa_Enum(field.type_) if issubclass(field.type_, bytes): return LargeBinary if issubclass(field.type_, Decimal): diff --git a/tests/test_enums.py b/tests/test_enums.py new file mode 100644 index 0000000000..aeec6456da --- /dev/null +++ b/tests/test_enums.py @@ -0,0 +1,72 @@ +import enum +import uuid + +from sqlalchemy import create_mock_engine +from sqlalchemy.sql.type_api import TypeEngine +from sqlmodel import Field, SQLModel + +""" +Tests related to Enums + +Associated issues: +* https://github.com/tiangolo/sqlmodel/issues/96 +* https://github.com/tiangolo/sqlmodel/issues/164 +""" + + +class MyEnum1(enum.Enum): + A = "A" + B = "B" + + +class MyEnum2(enum.Enum): + C = "C" + D = "D" + + +class BaseModel(SQLModel): + id: uuid.UUID = Field(primary_key=True) + enum_field: MyEnum2 + + +class FlatModel(SQLModel, table=True): + id: uuid.UUID = Field(primary_key=True) + enum_field: MyEnum1 + + +class InheritModel(BaseModel, table=True): + pass + + +def pg_dump(sql: TypeEngine, *args, **kwargs): + dialect = sql.compile(dialect=postgres_engine.dialect) + sql_str = str(dialect).rstrip() + if sql_str: + print(sql_str + ";") + + +def sqlite_dump(sql: TypeEngine, *args, **kwargs): + dialect = sql.compile(dialect=sqlite_engine.dialect) + sql_str = str(dialect).rstrip() + if sql_str: + print(sql_str + ";") + + +postgres_engine = create_mock_engine("postgresql://", pg_dump) +sqlite_engine = create_mock_engine("sqlite://", sqlite_dump) + + +def test_postgres_ddl_sql(capsys): + SQLModel.metadata.create_all(bind=postgres_engine, checkfirst=False) + + captured = capsys.readouterr() + assert "CREATE TYPE myenum1 AS ENUM ('A', 'B');" in captured.out + assert "CREATE TYPE myenum2 AS ENUM ('C', 'D');" in captured.out + + +def test_sqlite_ddl_sql(capsys): + SQLModel.metadata.create_all(bind=sqlite_engine, checkfirst=False) + + captured = capsys.readouterr() + assert "enum_field VARCHAR(1) NOT NULL" in captured.out + assert "CREATE TYPE" not in captured.out From c01152379ef0f0b255cf87f2e1071ef8e1f8f83a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebasti=C3=A1n=20Ram=C3=ADrez?= Date: Sun, 28 Aug 2022 00:42:57 +0200 Subject: [PATCH 2/3] =?UTF-8?q?=F0=9F=8E=A8=20Fix=20format=20of=20imports?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- sqlmodel/main.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/sqlmodel/main.py b/sqlmodel/main.py index 86e28b3333..6099aa305a 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -31,9 +31,19 @@ from pydantic.main import ModelMetaclass, validate_model from pydantic.typing import ForwardRef, NoArgAnyCallable, resolve_annotations from pydantic.utils import ROOT_KEY, Representation -from sqlalchemy import Boolean, Column, Date, DateTime from sqlalchemy import Enum as sa_Enum -from sqlalchemy import Float, ForeignKey, Integer, Interval, Numeric, inspect +from sqlalchemy import ( + Boolean, + Column, + Date, + DateTime, + Float, + ForeignKey, + Integer, + Interval, + Numeric, + inspect, +) from sqlalchemy.orm import RelationshipProperty, declared_attr, registry, relationship from sqlalchemy.orm.attributes import set_attribute from sqlalchemy.orm.decl_api import DeclarativeMeta From 37c1eb738003c23450e2fce86cb8e1b6826238a0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebasti=C3=A1n=20Ram=C3=ADrez?= Date: Sun, 28 Aug 2022 00:46:07 +0200 Subject: [PATCH 3/3] =?UTF-8?q?=F0=9F=8E=A8=20Fix=20import=20sorts,=20prob?= =?UTF-8?q?ably=20revert=20the=20last=20order=20"fix"=20=F0=9F=98=85?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- sqlmodel/main.py | 14 ++------------ 1 file changed, 2 insertions(+), 12 deletions(-) diff --git a/sqlmodel/main.py b/sqlmodel/main.py index 6099aa305a..86e28b3333 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -31,19 +31,9 @@ from pydantic.main import ModelMetaclass, validate_model from pydantic.typing import ForwardRef, NoArgAnyCallable, resolve_annotations from pydantic.utils import ROOT_KEY, Representation +from sqlalchemy import Boolean, Column, Date, DateTime from sqlalchemy import Enum as sa_Enum -from sqlalchemy import ( - Boolean, - Column, - Date, - DateTime, - Float, - ForeignKey, - Integer, - Interval, - Numeric, - inspect, -) +from sqlalchemy import Float, ForeignKey, Integer, Interval, Numeric, inspect from sqlalchemy.orm import RelationshipProperty, declared_attr, registry, relationship from sqlalchemy.orm.attributes import set_attribute from sqlalchemy.orm.decl_api import DeclarativeMeta