Skip to content

Commit db3e9f4

Browse files
Citojnak
authored andcommitted
Improve creation of column and sort enums (#210)
- create separate enum module - create enums based on object type, not based on model - provide more customization options - split tests in different modules - adapt flask_sqlalchemy example - use conftest.py for better test fixtures
1 parent e362e3f commit db3e9f4

24 files changed

+1416
-488
lines changed

.gitignore

+1
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ __pycache__/
1111
# Distribution / packaging
1212
.Python
1313
env/
14+
.venv/
1415
build/
1516
develop-eggs/
1617
dist/

examples/flask_sqlalchemy/app.py

+16-13
Original file line numberDiff line numberDiff line change
@@ -1,43 +1,46 @@
11
#!/usr/bin/env python
22

3+
from database import db_session, init_db
34
from flask import Flask
5+
from schema import schema
46

57
from flask_graphql import GraphQLView
68

7-
from .database import db_session, init_db
8-
from .schema import schema
9-
109
app = Flask(__name__)
1110
app.debug = True
1211

13-
default_query = '''
12+
example_query = """
1413
{
15-
allEmployees {
14+
allEmployees(sort: [NAME_ASC, ID_ASC]) {
1615
edges {
1716
node {
18-
id,
19-
name,
17+
id
18+
name
2019
department {
21-
id,
20+
id
2221
name
23-
},
22+
}
2423
role {
25-
id,
24+
id
2625
name
2726
}
2827
}
2928
}
3029
}
31-
}'''.strip()
30+
}
31+
"""
3232

3333

34-
app.add_url_rule('/graphql', view_func=GraphQLView.as_view('graphql', schema=schema, graphiql=True))
34+
app.add_url_rule(
35+
"/graphql", view_func=GraphQLView.as_view("graphql", schema=schema, graphiql=True)
36+
)
3537

3638

3739
@app.teardown_appcontext
3840
def shutdown_session(exception=None):
3941
db_session.remove()
4042

41-
if __name__ == '__main__':
43+
44+
if __name__ == "__main__":
4245
init_db()
4346
app.run()

examples/flask_sqlalchemy/database.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ def init_db():
1414
# import all modules here that might define models so that
1515
# they will be registered properly on the metadata. Otherwise
1616
# you will have to import them first before calling init_db()
17-
from .models import Department, Employee, Role
17+
from models import Department, Employee, Role
1818
Base.metadata.drop_all(bind=engine)
1919
Base.metadata.create_all(bind=engine)
2020

examples/flask_sqlalchemy/models.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
1+
from database import Base
12
from sqlalchemy import Column, DateTime, ForeignKey, Integer, String, func
23
from sqlalchemy.orm import backref, relationship
34

4-
from .database import Base
5-
65

76
class Department(Base):
87
__tablename__ = 'department'

examples/flask_sqlalchemy/schema.py

+6-14
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,10 @@
1+
from models import Department as DepartmentModel
2+
from models import Employee as EmployeeModel
3+
from models import Role as RoleModel
4+
15
import graphene
26
from graphene import relay
3-
from graphene_sqlalchemy import (SQLAlchemyConnectionField,
4-
SQLAlchemyObjectType, utils)
5-
6-
from .models import Department as DepartmentModel
7-
from .models import Employee as EmployeeModel
8-
from .models import Role as RoleModel
7+
from graphene_sqlalchemy import SQLAlchemyConnectionField, SQLAlchemyObjectType
98

109

1110
class Department(SQLAlchemyObjectType):
@@ -26,18 +25,11 @@ class Meta:
2625
interfaces = (relay.Node, )
2726

2827

29-
SortEnumEmployee = utils.sort_enum_for_model(EmployeeModel, 'SortEnumEmployee',
30-
lambda c, d: c.upper() + ('_ASC' if d else '_DESC'))
31-
32-
3328
class Query(graphene.ObjectType):
3429
node = relay.Node.Field()
3530
# Allow only single column sorting
3631
all_employees = SQLAlchemyConnectionField(
37-
Employee,
38-
sort=graphene.Argument(
39-
SortEnumEmployee,
40-
default_value=utils.EnumValue('id_asc', EmployeeModel.id.asc())))
32+
Employee, sort=Employee.sort_argument())
4133
# Allows sorting over multiple columns, by default over the primary key
4234
all_roles = SQLAlchemyConnectionField(Role)
4335
# Disable sorting over this field

graphene_sqlalchemy/converter.py

+5-8
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,9 @@
77
String)
88
from graphene.types.json import JSONString
99

10+
from .enums import enum_for_sa_enum
11+
from .registry import get_global_registry
12+
1013
try:
1114
from sqlalchemy_utils import ChoiceType, JSONType, ScalarListType, TSVectorType
1215
except ImportError:
@@ -145,21 +148,15 @@ def convert_column_to_float(type, column, registry=None):
145148

146149
@convert_sqlalchemy_type.register(types.Enum)
147150
def convert_enum_to_enum(type, column, registry=None):
148-
enum_class = getattr(type, 'enum_class', None)
149-
if enum_class: # Check if an enum.Enum type is used
150-
graphene_type = Enum.from_enum(enum_class)
151-
else: # Nope, just a list of string options
152-
items = zip(type.enums, type.enums)
153-
graphene_type = Enum(type.name, items)
154151
return Field(
155-
graphene_type,
152+
lambda: enum_for_sa_enum(type, registry or get_global_registry()),
156153
description=get_column_doc(column),
157154
required=not (is_column_nullable(column)),
158155
)
159156

160157

161158
@convert_sqlalchemy_type.register(ChoiceType)
162-
def convert_column_to_enum(type, column, registry=None):
159+
def convert_choice_to_enum(type, column, registry=None):
163160
name = "{}_{}".format(column.table.name, column.name).upper()
164161
return Enum(name, type.choices, description=get_column_doc(column))
165162

graphene_sqlalchemy/enums.py

+203
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,203 @@
1+
from sqlalchemy import Column
2+
from sqlalchemy.types import Enum as SQLAlchemyEnumType
3+
4+
from graphene import Argument, Enum, List
5+
6+
from .utils import EnumValue, to_enum_value_name, to_type_name
7+
8+
9+
def _convert_sa_to_graphene_enum(sa_enum, fallback_name=None):
10+
"""Convert the given SQLAlchemy Enum type to a Graphene Enum type.
11+
12+
The name of the Graphene Enum will be determined as follows:
13+
If the SQLAlchemy Enum is based on a Python Enum, use the name
14+
of the Python Enum. Otherwise, if the SQLAlchemy Enum is named,
15+
use the SQL name after conversion to a type name. Otherwise, use
16+
the given fallback_name or raise an error if it is empty.
17+
18+
The Enum value names are converted to upper case if necessary.
19+
"""
20+
if not isinstance(sa_enum, SQLAlchemyEnumType):
21+
raise TypeError(
22+
"Expected sqlalchemy.types.Enum, but got: {!r}".format(sa_enum)
23+
)
24+
enum_class = sa_enum.enum_class
25+
if enum_class:
26+
if all(to_enum_value_name(key) == key for key in enum_class.__members__):
27+
return Enum.from_enum(enum_class)
28+
name = enum_class.__name__
29+
members = [
30+
(to_enum_value_name(key), value.value)
31+
for key, value in enum_class.__members__.items()
32+
]
33+
else:
34+
sql_enum_name = sa_enum.name
35+
if sql_enum_name:
36+
name = to_type_name(sql_enum_name)
37+
elif fallback_name:
38+
name = fallback_name
39+
else:
40+
raise TypeError("No type name specified for {!r}".format(sa_enum))
41+
members = [(to_enum_value_name(key), key) for key in sa_enum.enums]
42+
return Enum(name, members)
43+
44+
45+
def enum_for_sa_enum(sa_enum, registry):
46+
"""Return the Graphene Enum type for the specified SQLAlchemy Enum type."""
47+
if not isinstance(sa_enum, SQLAlchemyEnumType):
48+
raise TypeError(
49+
"Expected sqlalchemy.types.Enum, but got: {!r}".format(sa_enum)
50+
)
51+
enum = registry.get_graphene_enum_for_sa_enum(sa_enum)
52+
if not enum:
53+
enum = _convert_sa_to_graphene_enum(sa_enum)
54+
registry.register_enum(sa_enum, enum)
55+
return enum
56+
57+
58+
def enum_for_field(obj_type, field_name):
59+
"""Return the Graphene Enum type for the specified Graphene field."""
60+
from .types import SQLAlchemyObjectType
61+
62+
if not isinstance(obj_type, type) or not issubclass(obj_type, SQLAlchemyObjectType):
63+
raise TypeError(
64+
"Expected SQLAlchemyObjectType, but got: {!r}".format(obj_type))
65+
if not field_name or not isinstance(field_name, str):
66+
raise TypeError(
67+
"Expected a field name, but got: {!r}".format(field_name))
68+
registry = obj_type._meta.registry
69+
orm_field = registry.get_orm_field_for_graphene_field(obj_type, field_name)
70+
if orm_field is None:
71+
raise TypeError("Cannot get {}.{}".format(obj_type._meta.name, field_name))
72+
if not isinstance(orm_field, Column):
73+
raise TypeError(
74+
"{}.{} does not map to model column".format(obj_type._meta.name, field_name)
75+
)
76+
sa_enum = orm_field.type
77+
if not isinstance(sa_enum, SQLAlchemyEnumType):
78+
raise TypeError(
79+
"{}.{} does not map to enum column".format(obj_type._meta.name, field_name)
80+
)
81+
enum = registry.get_graphene_enum_for_sa_enum(sa_enum)
82+
if not enum:
83+
fallback_name = obj_type._meta.name + to_type_name(field_name)
84+
enum = _convert_sa_to_graphene_enum(sa_enum, fallback_name)
85+
registry.register_enum(sa_enum, enum)
86+
return enum
87+
88+
89+
def _default_sort_enum_symbol_name(column_name, sort_asc=True):
90+
return to_enum_value_name(column_name) + ("_ASC" if sort_asc else "_DESC")
91+
92+
93+
def sort_enum_for_object_type(
94+
obj_type, name=None, only_fields=None, only_indexed=None, get_symbol_name=None
95+
):
96+
"""Return Graphene Enum for sorting the given SQLAlchemyObjectType.
97+
98+
Parameters
99+
- obj_type : SQLAlchemyObjectType
100+
The object type for which the sort Enum shall be generated.
101+
- name : str, optional, default None
102+
Name to use for the sort Enum.
103+
If not provided, it will be set to the object type name + 'SortEnum'
104+
- only_fields : sequence, optional, default None
105+
If this is set, only fields from this sequence will be considered.
106+
- only_indexed : bool, optional, default False
107+
If this is set, only indexed columns will be considered.
108+
- get_symbol_name : function, optional, default None
109+
Function which takes the column name and a boolean indicating
110+
if the sort direction is ascending, and returns the symbol name
111+
for the current column and sort direction. If no such function
112+
is passed, a default function will be used that creates the symbols
113+
'foo_asc' and 'foo_desc' for a column with the name 'foo'.
114+
115+
Returns
116+
- Enum
117+
The Graphene Enum type
118+
"""
119+
name = name or obj_type._meta.name + "SortEnum"
120+
registry = obj_type._meta.registry
121+
enum = registry.get_sort_enum_for_object_type(obj_type)
122+
custom_options = dict(
123+
only_fields=only_fields,
124+
only_indexed=only_indexed,
125+
get_symbol_name=get_symbol_name,
126+
)
127+
if enum:
128+
if name != enum.__name__ or custom_options != enum.custom_options:
129+
raise ValueError(
130+
"Sort enum for {} has already been customized".format(obj_type)
131+
)
132+
else:
133+
members = []
134+
default = []
135+
fields = obj_type._meta.fields
136+
get_name = get_symbol_name or _default_sort_enum_symbol_name
137+
for field_name in fields:
138+
if only_fields and field_name not in only_fields:
139+
continue
140+
orm_field = registry.get_orm_field_for_graphene_field(obj_type, field_name)
141+
if not isinstance(orm_field, Column):
142+
continue
143+
if only_indexed and not (orm_field.primary_key or orm_field.index):
144+
continue
145+
asc_name = get_name(orm_field.name, True)
146+
asc_value = EnumValue(asc_name, orm_field.asc())
147+
desc_name = get_name(orm_field.name, False)
148+
desc_value = EnumValue(desc_name, orm_field.desc())
149+
if orm_field.primary_key:
150+
default.append(asc_value)
151+
members.extend(((asc_name, asc_value), (desc_name, desc_value)))
152+
enum = Enum(name, members)
153+
enum.default = default # store default as attribute
154+
enum.custom_options = custom_options
155+
registry.register_sort_enum(obj_type, enum)
156+
return enum
157+
158+
159+
def sort_argument_for_object_type(
160+
obj_type,
161+
enum_name=None,
162+
only_fields=None,
163+
only_indexed=None,
164+
get_symbol_name=None,
165+
has_default=True,
166+
):
167+
""""Returns Graphene Argument for sorting the given SQLAlchemyObjectType.
168+
169+
Parameters
170+
- obj_type : SQLAlchemyObjectType
171+
The object type for which the sort Argument shall be generated.
172+
- enum_name : str, optional, default None
173+
Name to use for the sort Enum.
174+
If not provided, it will be set to the object type name + 'SortEnum'
175+
- only_fields : sequence, optional, default None
176+
If this is set, only fields from this sequence will be considered.
177+
- only_indexed : bool, optional, default False
178+
If this is set, only indexed columns will be considered.
179+
- get_symbol_name : function, optional, default None
180+
Function which takes the column name and a boolean indicating
181+
if the sort direction is ascending, and returns the symbol name
182+
for the current column and sort direction. If no such function
183+
is passed, a default function will be used that creates the symbols
184+
'foo_asc' and 'foo_desc' for a column with the name 'foo'.
185+
- has_default : bool, optional, default True
186+
If this is set to False, no sorting will happen when this argument is not
187+
passed. Otherwise results will be sortied by the primary key(s) of the model.
188+
189+
Returns
190+
- Enum
191+
A Graphene Argument that accepts a list of sorting directions for the model.
192+
"""
193+
enum = sort_enum_for_object_type(
194+
obj_type,
195+
enum_name,
196+
only_fields=only_fields,
197+
only_indexed=only_indexed,
198+
get_symbol_name=get_symbol_name,
199+
)
200+
if not has_default:
201+
enum.default = None
202+
203+
return Argument(List(enum), default_value=enum.default)

0 commit comments

Comments
 (0)