Skip to content

Commit c787ad2

Browse files
authored
Refactor code to enhance selection and operations (#9)
* WIP: refactoring enhancements * Update delete method * Fix query filter parse * Update tests * Update multiple_conditions to and * fix parse filters method * fix tests * Add arithmetic filters
1 parent b79b48e commit c787ad2

12 files changed

+875
-275
lines changed

sqlalchemy_crud_plus/crud.py

+105-134
Original file line numberDiff line numberDiff line change
@@ -1,184 +1,130 @@
11
#!/usr/bin/env python3
22
# -*- coding: utf-8 -*-
3-
from typing import Any, Generic, Iterable, Literal, Sequence, Type, TypeVar
3+
from typing import Any, Generic, Iterable, Sequence, Type
44

5-
from pydantic import BaseModel
6-
from sqlalchemy import Row, RowMapping, and_, asc, desc, or_, select
5+
from sqlalchemy import Row, RowMapping, select
76
from sqlalchemy import delete as sa_delete
87
from sqlalchemy import update as sa_update
98
from sqlalchemy.ext.asyncio import AsyncSession
109

11-
from sqlalchemy_crud_plus.errors import ModelColumnError, SelectExpressionError
10+
from sqlalchemy_crud_plus.errors import MultipleResultsError
11+
from sqlalchemy_crud_plus.types import CreateSchema, Model, UpdateSchema
12+
from sqlalchemy_crud_plus.utils import apply_sorting, count, parse_filters
1213

13-
_Model = TypeVar('_Model')
14-
_CreateSchema = TypeVar('_CreateSchema', bound=BaseModel)
15-
_UpdateSchema = TypeVar('_UpdateSchema', bound=BaseModel)
1614

17-
18-
class CRUDPlus(Generic[_Model]):
19-
def __init__(self, model: Type[_Model]):
15+
class CRUDPlus(Generic[Model]):
16+
def __init__(self, model: Type[Model]):
2017
self.model = model
2118

22-
async def create_model(self, session: AsyncSession, obj: _CreateSchema, commit: bool = False, **kwargs) -> _Model:
19+
async def create_model(self, session: AsyncSession, obj: CreateSchema, commit: bool = False, **kwargs) -> Model:
2320
"""
2421
Create a new instance of a model
2522
26-
:param session:
27-
:param obj:
28-
:param commit:
29-
:param kwargs:
23+
:param session: The SQLAlchemy async session.
24+
:param obj: The Pydantic schema containing data to be saved.
25+
:param commit: If `True`, commits the transaction immediately. Default is `False`.
26+
:param kwargs: Additional model data not included in the pydantic schema.
3027
:return:
3128
"""
32-
if kwargs:
33-
ins = self.model(**obj.model_dump(), **kwargs)
34-
else:
29+
if not kwargs:
3530
ins = self.model(**obj.model_dump())
31+
else:
32+
ins = self.model(**obj.model_dump(), **kwargs)
3633
session.add(ins)
3734
if commit:
3835
await session.commit()
3936
return ins
4037

4138
async def create_models(
42-
self, session: AsyncSession, obj: Iterable[_CreateSchema], commit: bool = False
43-
) -> list[_Model]:
39+
self, session: AsyncSession, obj: Iterable[CreateSchema], commit: bool = False
40+
) -> list[Model]:
4441
"""
4542
Create new instances of a model
4643
47-
:param session:
48-
:param obj:
49-
:param commit:
44+
:param session: The SQLAlchemy async session.
45+
:param obj: The Pydantic schema list containing data to be saved.
46+
:param commit: If `True`, commits the transaction immediately. Default is `False`.
5047
:return:
5148
"""
5249
ins_list = []
53-
for i in obj:
54-
ins_list.append(self.model(**i.model_dump()))
50+
for ins in obj:
51+
ins_list.append(self.model(**ins.model_dump()))
5552
session.add_all(ins_list)
5653
if commit:
5754
await session.commit()
5855
return ins_list
5956

60-
async def select_model_by_id(self, session: AsyncSession, pk: int) -> _Model | None:
57+
async def select_model(self, session: AsyncSession, pk: int) -> Model | None:
6158
"""
6259
Query by ID
6360
64-
:param session:
65-
:param pk:
61+
:param session: The SQLAlchemy async session.
62+
:param pk: The database primary key value.
6663
:return:
6764
"""
6865
stmt = select(self.model).where(self.model.id == pk)
6966
query = await session.execute(stmt)
7067
return query.scalars().first()
7168

72-
async def select_model_by_column(self, session: AsyncSession, column: str, column_value: Any) -> _Model | None:
69+
async def select_model_by_column(self, session: AsyncSession, **kwargs) -> Model | None:
7370
"""
7471
Query by column
7572
76-
:param session:
77-
:param column:
78-
:param column_value:
79-
:return:
80-
"""
81-
if hasattr(self.model, column):
82-
model_column = getattr(self.model, column)
83-
stmt = select(self.model).where(model_column == column_value) # type: ignore
84-
query = await session.execute(stmt)
85-
return query.scalars().first()
86-
else:
87-
raise ModelColumnError(f'Column {column} is not found in {self.model}')
88-
89-
async def select_model_by_columns(
90-
self, session: AsyncSession, expression: Literal['and', 'or'] = 'and', **conditions
91-
) -> _Model | None:
92-
"""
93-
Query by columns
94-
95-
:param session:
96-
:param expression:
97-
:param conditions: Query conditions, format:column1=value1, column2=value2
73+
:param session: The SQLAlchemy async session.
74+
:param kwargs: Query expressions.
9875
:return:
9976
"""
100-
where_list = []
101-
for column, value in conditions.items():
102-
if hasattr(self.model, column):
103-
model_column = getattr(self.model, column)
104-
where_list.append(model_column == value)
105-
else:
106-
raise ModelColumnError(f'Column {column} is not found in {self.model}')
107-
match expression:
108-
case 'and':
109-
stmt = select(self.model).where(and_(*where_list))
110-
query = await session.execute(stmt)
111-
case 'or':
112-
stmt = select(self.model).where(or_(*where_list))
113-
query = await session.execute(stmt)
114-
case _:
115-
raise SelectExpressionError(
116-
f'Select expression {expression} is not supported, only supports `and`, `or`'
117-
)
77+
filters = await parse_filters(self.model, **kwargs)
78+
stmt = select(self.model).where(*filters)
79+
query = await session.execute(stmt)
11880
return query.scalars().first()
11981

120-
async def select_models(self, session: AsyncSession) -> Sequence[Row[Any] | RowMapping | Any]:
82+
async def select_models(self, session: AsyncSession, **kwargs) -> Sequence[Row[Any] | RowMapping | Any]:
12183
"""
12284
Query all rows
12385
124-
:param session:
86+
:param session: The SQLAlchemy async session.
87+
:param kwargs: Query expressions.
12588
:return:
12689
"""
127-
stmt = select(self.model)
90+
filters = await parse_filters(self.model, **kwargs)
91+
stmt = select(self.model).where(*filters)
12892
query = await session.execute(stmt)
12993
return query.scalars().all()
13094

13195
async def select_models_order(
132-
self,
133-
session: AsyncSession,
134-
*columns,
135-
model_sort: Literal['asc', 'desc'] = 'desc',
96+
self, session: AsyncSession, sort_columns: str | list[str], sort_orders: str | list[str] | None = None, **kwargs
13697
) -> Sequence[Row | RowMapping | Any] | None:
13798
"""
138-
Query all rows asc or desc
99+
Query all rows and sort by columns
139100
140-
:param session:
141-
:param columns:
142-
:param model_sort:
101+
:param session: The SQLAlchemy async session.
102+
:param sort_columns: more details see apply_sorting
103+
:param sort_orders: more details see apply_sorting
143104
:return:
144105
"""
145-
sort_list = []
146-
for column in columns:
147-
if hasattr(self.model, column):
148-
model_column = getattr(self.model, column)
149-
sort_list.append(model_column)
150-
else:
151-
raise ModelColumnError(f'Column {column} is not found in {self.model}')
152-
match model_sort:
153-
case 'asc':
154-
query = await session.execute(select(self.model).order_by(asc(*sort_list)))
155-
case 'desc':
156-
query = await session.execute(select(self.model).order_by(desc(*sort_list)))
157-
case _:
158-
raise SelectExpressionError(
159-
f'Select sort expression {model_sort} is not supported, only supports `asc`, `desc`'
160-
)
106+
filters = await parse_filters(self.model, **kwargs)
107+
stmt = select(self.model).where(*filters)
108+
stmt_sort = await apply_sorting(self.model, stmt, sort_columns, sort_orders)
109+
query = await session.execute(stmt_sort)
161110
return query.scalars().all()
162111

163112
async def update_model(
164-
self, session: AsyncSession, pk: int, obj: _UpdateSchema | dict[str, Any], commit: bool = False, **kwargs
113+
self, session: AsyncSession, pk: int, obj: UpdateSchema | dict[str, Any], commit: bool = False
165114
) -> int:
166115
"""
167-
Update an instance of model's primary key
116+
Update an instance by model's primary key
168117
169-
:param session:
170-
:param pk:
171-
:param obj:
172-
:param commit:
173-
:param kwargs:
118+
:param session: The SQLAlchemy async session.
119+
:param pk: The database primary key value.
120+
:param obj: A pydantic schema or dictionary containing the update data
121+
:param commit: If `True`, commits the transaction immediately. Default is `False`.
174122
:return:
175123
"""
176124
if isinstance(obj, dict):
177125
instance_data = obj
178126
else:
179127
instance_data = obj.model_dump(exclude_unset=True)
180-
if kwargs:
181-
instance_data.update(kwargs)
182128
stmt = sa_update(self.model).where(self.model.id == pk).values(**instance_data)
183129
result = await session.execute(stmt)
184130
if commit:
@@ -188,55 +134,80 @@ async def update_model(
188134
async def update_model_by_column(
189135
self,
190136
session: AsyncSession,
191-
column: str,
192-
column_value: Any,
193-
obj: _UpdateSchema | dict[str, Any],
137+
obj: UpdateSchema | dict[str, Any],
138+
allow_multiple: bool = False,
194139
commit: bool = False,
195140
**kwargs,
196141
) -> int:
197142
"""
198-
Update an instance of model column
143+
Update an instance by model column
199144
200-
:param session:
201-
:param column:
202-
:param column_value:
203-
:param obj:
204-
:param commit:
205-
:param kwargs:
145+
:param session: The SQLAlchemy async session.
146+
:param obj: A pydantic schema or dictionary containing the update data
147+
:param allow_multiple: If `True`, allows updating multiple records that match the filters.
148+
:param commit: If `True`, commits the transaction immediately. Default is `False`.
149+
:param kwargs: Query expressions.
206150
:return:
207151
"""
152+
filters = await parse_filters(self.model, **kwargs)
153+
total_count = await count(session, self.model, filters)
154+
if not allow_multiple and total_count > 1:
155+
raise MultipleResultsError(f'Only one record is expected to be update, found {total_count} records.')
208156
if isinstance(obj, dict):
209157
instance_data = obj
210158
else:
211159
instance_data = obj.model_dump(exclude_unset=True)
212-
if kwargs:
213-
instance_data.update(kwargs)
214-
if hasattr(self.model, column):
215-
model_column = getattr(self.model, column)
216-
else:
217-
raise ModelColumnError(f'Column {column} is not found in {self.model}')
218-
stmt = sa_update(self.model).where(model_column == column_value).values(**instance_data) # type: ignore
160+
stmt = sa_update(self.model).where(*filters).values(**instance_data) # type: ignore
219161
result = await session.execute(stmt)
220162
if commit:
221163
await session.commit()
222164
return result.rowcount # type: ignore
223165

224-
async def delete_model(self, session: AsyncSession, pk: int, commit: bool = False, **kwargs) -> int:
166+
async def delete_model(self, session: AsyncSession, pk: int, commit: bool = False) -> int:
225167
"""
226-
Delete an instance of a model
168+
Delete an instance by model's primary key
227169
228-
:param session:
229-
:param pk:
230-
:param commit:
231-
:param kwargs: for soft deletion only
170+
:param session: The SQLAlchemy async session.
171+
:param pk: The database primary key value.
172+
:param commit: If `True`, commits the transaction immediately. Default is `False`.
232173
:return:
233174
"""
234-
if not kwargs:
235-
stmt = sa_delete(self.model).where(self.model.id == pk)
236-
result = await session.execute(stmt)
237-
else:
238-
stmt = sa_update(self.model).where(self.model.id == pk).values(**kwargs)
239-
result = await session.execute(stmt)
175+
stmt = sa_delete(self.model).where(self.model.id == pk)
176+
result = await session.execute(stmt)
240177
if commit:
241178
await session.commit()
242179
return result.rowcount # type: ignore
180+
181+
async def delete_model_by_column(
182+
self,
183+
session: AsyncSession,
184+
allow_multiple: bool = False,
185+
logical_deletion: bool = False,
186+
deleted_flag_column: str = 'del_flag',
187+
commit: bool = False,
188+
**kwargs,
189+
) -> int:
190+
"""
191+
Delete
192+
193+
:param session: The SQLAlchemy async session.
194+
:param commit: If `True`, commits the transaction immediately. Default is `False`.
195+
:param kwargs: Query expressions.
196+
:param allow_multiple: If `True`, allows deleting multiple records that match the filters.
197+
:param logical_deletion: If `True`, enable logical deletion instead of physical deletion
198+
:param deleted_flag_column: Specify the flag column for logical deletion
199+
:return:
200+
"""
201+
filters = await parse_filters(self.model, **kwargs)
202+
total_count = await count(session, self.model, filters)
203+
if not allow_multiple and total_count > 1:
204+
raise MultipleResultsError(f'Only one record is expected to be delete, found {total_count} records.')
205+
if logical_deletion:
206+
deleted_flag = {deleted_flag_column: True}
207+
stmt = sa_update(self.model).where(*filters).values(**deleted_flag)
208+
else:
209+
stmt = sa_delete(self.model).where(*filters)
210+
await session.execute(stmt)
211+
if commit:
212+
await session.commit()
213+
return total_count

sqlalchemy_crud_plus/errors.py

+15-1
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,22 @@ def __init__(self, msg: str) -> None:
1717
super().__init__(msg)
1818

1919

20-
class SelectExpressionError(SQLAlchemyCRUDPlusException):
20+
class SelectOperatorError(SQLAlchemyCRUDPlusException):
2121
"""Error raised when a select expression is invalid."""
2222

2323
def __init__(self, msg: str) -> None:
2424
super().__init__(msg)
25+
26+
27+
class ColumnSortError(SQLAlchemyCRUDPlusException):
28+
"""Error raised when a column sorting is invalid."""
29+
30+
def __init__(self, msg: str) -> None:
31+
super().__init__(msg)
32+
33+
34+
class MultipleResultsError(SQLAlchemyCRUDPlusException):
35+
"""Error raised when multiple results are invalid."""
36+
37+
def __init__(self, msg: str) -> None:
38+
super().__init__(msg)

sqlalchemy_crud_plus/types.py

+10
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
#!/usr/bin/env python3
2+
# -*- coding: utf-8 -*-
3+
from typing import TypeVar
4+
5+
from pydantic import BaseModel
6+
7+
Model = TypeVar('Model')
8+
9+
CreateSchema = TypeVar('CreateSchema', bound=BaseModel)
10+
UpdateSchema = TypeVar('UpdateSchema', bound=BaseModel)

0 commit comments

Comments
 (0)