Skip to content

Commit eef0b77

Browse files
🐛 Fix Enum handling in SQLAlchemy (#165)
Co-authored-by: Sebastián Ramírez <[email protected]>
1 parent 2fab481 commit eef0b77

File tree

2 files changed

+76
-13
lines changed

2 files changed

+76
-13
lines changed

sqlmodel/main.py

+4-13
Original file line numberDiff line numberDiff line change
@@ -31,18 +31,9 @@
3131
from pydantic.main import ModelMetaclass, validate_model
3232
from pydantic.typing import ForwardRef, NoArgAnyCallable, resolve_annotations
3333
from pydantic.utils import ROOT_KEY, Representation
34-
from sqlalchemy import (
35-
Boolean,
36-
Column,
37-
Date,
38-
DateTime,
39-
Float,
40-
ForeignKey,
41-
Integer,
42-
Interval,
43-
Numeric,
44-
inspect,
45-
)
34+
from sqlalchemy import Boolean, Column, Date, DateTime
35+
from sqlalchemy import Enum as sa_Enum
36+
from sqlalchemy import Float, ForeignKey, Integer, Interval, Numeric, inspect
4637
from sqlalchemy.orm import RelationshipProperty, declared_attr, registry, relationship
4738
from sqlalchemy.orm.attributes import set_attribute
4839
from sqlalchemy.orm.decl_api import DeclarativeMeta
@@ -396,7 +387,7 @@ def get_sqlachemy_type(field: ModelField) -> Any:
396387
if issubclass(field.type_, time):
397388
return Time
398389
if issubclass(field.type_, Enum):
399-
return Enum
390+
return sa_Enum(field.type_)
400391
if issubclass(field.type_, bytes):
401392
return LargeBinary
402393
if issubclass(field.type_, Decimal):

tests/test_enums.py

+72
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
import enum
2+
import uuid
3+
4+
from sqlalchemy import create_mock_engine
5+
from sqlalchemy.sql.type_api import TypeEngine
6+
from sqlmodel import Field, SQLModel
7+
8+
"""
9+
Tests related to Enums
10+
11+
Associated issues:
12+
* https://github.com/tiangolo/sqlmodel/issues/96
13+
* https://github.com/tiangolo/sqlmodel/issues/164
14+
"""
15+
16+
17+
class MyEnum1(enum.Enum):
18+
A = "A"
19+
B = "B"
20+
21+
22+
class MyEnum2(enum.Enum):
23+
C = "C"
24+
D = "D"
25+
26+
27+
class BaseModel(SQLModel):
28+
id: uuid.UUID = Field(primary_key=True)
29+
enum_field: MyEnum2
30+
31+
32+
class FlatModel(SQLModel, table=True):
33+
id: uuid.UUID = Field(primary_key=True)
34+
enum_field: MyEnum1
35+
36+
37+
class InheritModel(BaseModel, table=True):
38+
pass
39+
40+
41+
def pg_dump(sql: TypeEngine, *args, **kwargs):
42+
dialect = sql.compile(dialect=postgres_engine.dialect)
43+
sql_str = str(dialect).rstrip()
44+
if sql_str:
45+
print(sql_str + ";")
46+
47+
48+
def sqlite_dump(sql: TypeEngine, *args, **kwargs):
49+
dialect = sql.compile(dialect=sqlite_engine.dialect)
50+
sql_str = str(dialect).rstrip()
51+
if sql_str:
52+
print(sql_str + ";")
53+
54+
55+
postgres_engine = create_mock_engine("postgresql://", pg_dump)
56+
sqlite_engine = create_mock_engine("sqlite://", sqlite_dump)
57+
58+
59+
def test_postgres_ddl_sql(capsys):
60+
SQLModel.metadata.create_all(bind=postgres_engine, checkfirst=False)
61+
62+
captured = capsys.readouterr()
63+
assert "CREATE TYPE myenum1 AS ENUM ('A', 'B');" in captured.out
64+
assert "CREATE TYPE myenum2 AS ENUM ('C', 'D');" in captured.out
65+
66+
67+
def test_sqlite_ddl_sql(capsys):
68+
SQLModel.metadata.create_all(bind=sqlite_engine, checkfirst=False)
69+
70+
captured = capsys.readouterr()
71+
assert "enum_field VARCHAR(1) NOT NULL" in captured.out
72+
assert "CREATE TYPE" not in captured.out

0 commit comments

Comments
 (0)