|
| 1 | +import enum |
| 2 | +import uuid |
| 3 | + |
| 4 | +from sqlalchemy import create_mock_engine |
| 5 | +from sqlalchemy.sql.type_api import TypeEngine |
| 6 | +from sqlmodel import Field, SQLModel |
| 7 | + |
| 8 | +""" |
| 9 | +Tests related to Enums |
| 10 | +
|
| 11 | +Associated issues: |
| 12 | +* https://github.com/tiangolo/sqlmodel/issues/96 |
| 13 | +* https://github.com/tiangolo/sqlmodel/issues/164 |
| 14 | +""" |
| 15 | + |
| 16 | + |
| 17 | +class MyEnum1(enum.Enum): |
| 18 | + A = "A" |
| 19 | + B = "B" |
| 20 | + |
| 21 | + |
| 22 | +class MyEnum2(enum.Enum): |
| 23 | + C = "C" |
| 24 | + D = "D" |
| 25 | + |
| 26 | + |
| 27 | +class BaseModel(SQLModel): |
| 28 | + id: uuid.UUID = Field(primary_key=True) |
| 29 | + enum_field: MyEnum2 |
| 30 | + |
| 31 | + |
| 32 | +class FlatModel(SQLModel, table=True): |
| 33 | + id: uuid.UUID = Field(primary_key=True) |
| 34 | + enum_field: MyEnum1 |
| 35 | + |
| 36 | + |
| 37 | +class InheritModel(BaseModel, table=True): |
| 38 | + pass |
| 39 | + |
| 40 | + |
| 41 | +def pg_dump(sql: TypeEngine, *args, **kwargs): |
| 42 | + dialect = sql.compile(dialect=postgres_engine.dialect) |
| 43 | + sql_str = str(dialect).rstrip() |
| 44 | + if sql_str: |
| 45 | + print(sql_str + ";") |
| 46 | + |
| 47 | + |
| 48 | +def sqlite_dump(sql: TypeEngine, *args, **kwargs): |
| 49 | + dialect = sql.compile(dialect=sqlite_engine.dialect) |
| 50 | + sql_str = str(dialect).rstrip() |
| 51 | + if sql_str: |
| 52 | + print(sql_str + ";") |
| 53 | + |
| 54 | + |
| 55 | +postgres_engine = create_mock_engine("postgresql://", pg_dump) |
| 56 | +sqlite_engine = create_mock_engine("sqlite://", sqlite_dump) |
| 57 | + |
| 58 | + |
| 59 | +def test_postgres_ddl_sql(capsys): |
| 60 | + SQLModel.metadata.create_all(bind=postgres_engine, checkfirst=False) |
| 61 | + |
| 62 | + captured = capsys.readouterr() |
| 63 | + assert "CREATE TYPE myenum1 AS ENUM ('A', 'B');" in captured.out |
| 64 | + assert "CREATE TYPE myenum2 AS ENUM ('C', 'D');" in captured.out |
| 65 | + |
| 66 | + |
| 67 | +def test_sqlite_ddl_sql(capsys): |
| 68 | + SQLModel.metadata.create_all(bind=sqlite_engine, checkfirst=False) |
| 69 | + |
| 70 | + captured = capsys.readouterr() |
| 71 | + assert "enum_field VARCHAR(1) NOT NULL" in captured.out |
| 72 | + assert "CREATE TYPE" not in captured.out |
0 commit comments