1
1
#!/usr/bin/env python3
2
2
# -*- coding: utf-8 -*-
3
- from typing import Any , Generic , Iterable , Literal , Sequence , Type , TypeVar
3
+ from typing import Any , Generic , Iterable , Sequence , Type
4
4
5
- from pydantic import BaseModel
6
- from sqlalchemy import Row , RowMapping , and_ , asc , desc , or_ , select
5
+ from sqlalchemy import Row , RowMapping , select
7
6
from sqlalchemy import delete as sa_delete
8
7
from sqlalchemy import update as sa_update
9
8
from sqlalchemy .ext .asyncio import AsyncSession
10
9
11
- from sqlalchemy_crud_plus .errors import ModelColumnError , SelectExpressionError
10
+ from sqlalchemy_crud_plus .errors import MultipleResultsError
11
+ from sqlalchemy_crud_plus .types import CreateSchema , Model , UpdateSchema
12
+ from sqlalchemy_crud_plus .utils import apply_sorting , count , parse_filters
12
13
13
- _Model = TypeVar ('_Model' )
14
- _CreateSchema = TypeVar ('_CreateSchema' , bound = BaseModel )
15
- _UpdateSchema = TypeVar ('_UpdateSchema' , bound = BaseModel )
16
14
17
-
18
- class CRUDPlus (Generic [_Model ]):
19
- def __init__ (self , model : Type [_Model ]):
15
+ class CRUDPlus (Generic [Model ]):
16
+ def __init__ (self , model : Type [Model ]):
20
17
self .model = model
21
18
22
- async def create_model (self , session : AsyncSession , obj : _CreateSchema , commit : bool = False , ** kwargs ) -> _Model :
19
+ async def create_model (self , session : AsyncSession , obj : CreateSchema , commit : bool = False , ** kwargs ) -> Model :
23
20
"""
24
21
Create a new instance of a model
25
22
26
- :param session:
27
- :param obj:
28
- :param commit:
29
- :param kwargs:
23
+ :param session: The SQLAlchemy async session.
24
+ :param obj: The Pydantic schema containing data to be saved.
25
+ :param commit: If `True`, commits the transaction immediately. Default is `False`.
26
+ :param kwargs: Additional model data not included in the pydantic schema.
30
27
:return:
31
28
"""
32
- if kwargs :
33
- ins = self .model (** obj .model_dump (), ** kwargs )
34
- else :
29
+ if not kwargs :
35
30
ins = self .model (** obj .model_dump ())
31
+ else :
32
+ ins = self .model (** obj .model_dump (), ** kwargs )
36
33
session .add (ins )
37
34
if commit :
38
35
await session .commit ()
39
36
return ins
40
37
41
38
async def create_models (
42
- self , session : AsyncSession , obj : Iterable [_CreateSchema ], commit : bool = False
43
- ) -> list [_Model ]:
39
+ self , session : AsyncSession , obj : Iterable [CreateSchema ], commit : bool = False
40
+ ) -> list [Model ]:
44
41
"""
45
42
Create new instances of a model
46
43
47
- :param session:
48
- :param obj:
49
- :param commit:
44
+ :param session: The SQLAlchemy async session.
45
+ :param obj: The Pydantic schema list containing data to be saved.
46
+ :param commit: If `True`, commits the transaction immediately. Default is `False`.
50
47
:return:
51
48
"""
52
49
ins_list = []
53
- for i in obj :
54
- ins_list .append (self .model (** i .model_dump ()))
50
+ for ins in obj :
51
+ ins_list .append (self .model (** ins .model_dump ()))
55
52
session .add_all (ins_list )
56
53
if commit :
57
54
await session .commit ()
58
55
return ins_list
59
56
60
- async def select_model_by_id (self , session : AsyncSession , pk : int ) -> _Model | None :
57
+ async def select_model (self , session : AsyncSession , pk : int ) -> Model | None :
61
58
"""
62
59
Query by ID
63
60
64
- :param session:
65
- :param pk:
61
+ :param session: The SQLAlchemy async session.
62
+ :param pk: The database primary key value.
66
63
:return:
67
64
"""
68
65
stmt = select (self .model ).where (self .model .id == pk )
69
66
query = await session .execute (stmt )
70
67
return query .scalars ().first ()
71
68
72
- async def select_model_by_column (self , session : AsyncSession , column : str , column_value : Any ) -> _Model | None :
69
+ async def select_model_by_column (self , session : AsyncSession , ** kwargs ) -> Model | None :
73
70
"""
74
71
Query by column
75
72
76
- :param session:
77
- :param column:
78
- :param column_value:
79
- :return:
80
- """
81
- if hasattr (self .model , column ):
82
- model_column = getattr (self .model , column )
83
- stmt = select (self .model ).where (model_column == column_value ) # type: ignore
84
- query = await session .execute (stmt )
85
- return query .scalars ().first ()
86
- else :
87
- raise ModelColumnError (f'Column { column } is not found in { self .model } ' )
88
-
89
- async def select_model_by_columns (
90
- self , session : AsyncSession , expression : Literal ['and' , 'or' ] = 'and' , ** conditions
91
- ) -> _Model | None :
92
- """
93
- Query by columns
94
-
95
- :param session:
96
- :param expression:
97
- :param conditions: Query conditions, format:column1=value1, column2=value2
73
+ :param session: The SQLAlchemy async session.
74
+ :param kwargs: Query expressions.
98
75
:return:
99
76
"""
100
- where_list = []
101
- for column , value in conditions .items ():
102
- if hasattr (self .model , column ):
103
- model_column = getattr (self .model , column )
104
- where_list .append (model_column == value )
105
- else :
106
- raise ModelColumnError (f'Column { column } is not found in { self .model } ' )
107
- match expression :
108
- case 'and' :
109
- stmt = select (self .model ).where (and_ (* where_list ))
110
- query = await session .execute (stmt )
111
- case 'or' :
112
- stmt = select (self .model ).where (or_ (* where_list ))
113
- query = await session .execute (stmt )
114
- case _:
115
- raise SelectExpressionError (
116
- f'Select expression { expression } is not supported, only supports `and`, `or`'
117
- )
77
+ filters = await parse_filters (self .model , ** kwargs )
78
+ stmt = select (self .model ).where (* filters )
79
+ query = await session .execute (stmt )
118
80
return query .scalars ().first ()
119
81
120
- async def select_models (self , session : AsyncSession ) -> Sequence [Row [Any ] | RowMapping | Any ]:
82
+ async def select_models (self , session : AsyncSession , ** kwargs ) -> Sequence [Row [Any ] | RowMapping | Any ]:
121
83
"""
122
84
Query all rows
123
85
124
- :param session:
86
+ :param session: The SQLAlchemy async session.
87
+ :param kwargs: Query expressions.
125
88
:return:
126
89
"""
127
- stmt = select (self .model )
90
+ filters = await parse_filters (self .model , ** kwargs )
91
+ stmt = select (self .model ).where (* filters )
128
92
query = await session .execute (stmt )
129
93
return query .scalars ().all ()
130
94
131
95
async def select_models_order (
132
- self ,
133
- session : AsyncSession ,
134
- * columns ,
135
- model_sort : Literal ['asc' , 'desc' ] = 'desc' ,
96
+ self , session : AsyncSession , sort_columns : str | list [str ], sort_orders : str | list [str ] | None = None , ** kwargs
136
97
) -> Sequence [Row | RowMapping | Any ] | None :
137
98
"""
138
- Query all rows asc or desc
99
+ Query all rows and sort by columns
139
100
140
- :param session:
141
- :param columns:
142
- :param model_sort:
101
+ :param session: The SQLAlchemy async session.
102
+ :param sort_columns: more details see apply_sorting
103
+ :param sort_orders: more details see apply_sorting
143
104
:return:
144
105
"""
145
- sort_list = []
146
- for column in columns :
147
- if hasattr (self .model , column ):
148
- model_column = getattr (self .model , column )
149
- sort_list .append (model_column )
150
- else :
151
- raise ModelColumnError (f'Column { column } is not found in { self .model } ' )
152
- match model_sort :
153
- case 'asc' :
154
- query = await session .execute (select (self .model ).order_by (asc (* sort_list )))
155
- case 'desc' :
156
- query = await session .execute (select (self .model ).order_by (desc (* sort_list )))
157
- case _:
158
- raise SelectExpressionError (
159
- f'Select sort expression { model_sort } is not supported, only supports `asc`, `desc`'
160
- )
106
+ filters = await parse_filters (self .model , ** kwargs )
107
+ stmt = select (self .model ).where (* filters )
108
+ stmt_sort = await apply_sorting (self .model , stmt , sort_columns , sort_orders )
109
+ query = await session .execute (stmt_sort )
161
110
return query .scalars ().all ()
162
111
163
112
async def update_model (
164
- self , session : AsyncSession , pk : int , obj : _UpdateSchema | dict [str , Any ], commit : bool = False , ** kwargs
113
+ self , session : AsyncSession , pk : int , obj : UpdateSchema | dict [str , Any ], commit : bool = False
165
114
) -> int :
166
115
"""
167
- Update an instance of model's primary key
116
+ Update an instance by model's primary key
168
117
169
- :param session:
170
- :param pk:
171
- :param obj:
172
- :param commit:
173
- :param kwargs:
118
+ :param session: The SQLAlchemy async session.
119
+ :param pk: The database primary key value.
120
+ :param obj: A pydantic schema or dictionary containing the update data
121
+ :param commit: If `True`, commits the transaction immediately. Default is `False`.
174
122
:return:
175
123
"""
176
124
if isinstance (obj , dict ):
177
125
instance_data = obj
178
126
else :
179
127
instance_data = obj .model_dump (exclude_unset = True )
180
- if kwargs :
181
- instance_data .update (kwargs )
182
128
stmt = sa_update (self .model ).where (self .model .id == pk ).values (** instance_data )
183
129
result = await session .execute (stmt )
184
130
if commit :
@@ -188,55 +134,80 @@ async def update_model(
188
134
async def update_model_by_column (
189
135
self ,
190
136
session : AsyncSession ,
191
- column : str ,
192
- column_value : Any ,
193
- obj : _UpdateSchema | dict [str , Any ],
137
+ obj : UpdateSchema | dict [str , Any ],
138
+ allow_multiple : bool = False ,
194
139
commit : bool = False ,
195
140
** kwargs ,
196
141
) -> int :
197
142
"""
198
- Update an instance of model column
143
+ Update an instance by model column
199
144
200
- :param session:
201
- :param column:
202
- :param column_value:
203
- :param obj:
204
- :param commit:
205
- :param kwargs:
145
+ :param session: The SQLAlchemy async session.
146
+ :param obj: A pydantic schema or dictionary containing the update data
147
+ :param allow_multiple: If `True`, allows updating multiple records that match the filters.
148
+ :param commit: If `True`, commits the transaction immediately. Default is `False`.
149
+ :param kwargs: Query expressions.
206
150
:return:
207
151
"""
152
+ filters = await parse_filters (self .model , ** kwargs )
153
+ total_count = await count (session , self .model , filters )
154
+ if not allow_multiple and total_count > 1 :
155
+ raise MultipleResultsError (f'Only one record is expected to be update, found { total_count } records.' )
208
156
if isinstance (obj , dict ):
209
157
instance_data = obj
210
158
else :
211
159
instance_data = obj .model_dump (exclude_unset = True )
212
- if kwargs :
213
- instance_data .update (kwargs )
214
- if hasattr (self .model , column ):
215
- model_column = getattr (self .model , column )
216
- else :
217
- raise ModelColumnError (f'Column { column } is not found in { self .model } ' )
218
- stmt = sa_update (self .model ).where (model_column == column_value ).values (** instance_data ) # type: ignore
160
+ stmt = sa_update (self .model ).where (* filters ).values (** instance_data ) # type: ignore
219
161
result = await session .execute (stmt )
220
162
if commit :
221
163
await session .commit ()
222
164
return result .rowcount # type: ignore
223
165
224
- async def delete_model (self , session : AsyncSession , pk : int , commit : bool = False , ** kwargs ) -> int :
166
+ async def delete_model (self , session : AsyncSession , pk : int , commit : bool = False ) -> int :
225
167
"""
226
- Delete an instance of a model
168
+ Delete an instance by model's primary key
227
169
228
- :param session:
229
- :param pk:
230
- :param commit:
231
- :param kwargs: for soft deletion only
170
+ :param session: The SQLAlchemy async session.
171
+ :param pk: The database primary key value.
172
+ :param commit: If `True`, commits the transaction immediately. Default is `False`.
232
173
:return:
233
174
"""
234
- if not kwargs :
235
- stmt = sa_delete (self .model ).where (self .model .id == pk )
236
- result = await session .execute (stmt )
237
- else :
238
- stmt = sa_update (self .model ).where (self .model .id == pk ).values (** kwargs )
239
- result = await session .execute (stmt )
175
+ stmt = sa_delete (self .model ).where (self .model .id == pk )
176
+ result = await session .execute (stmt )
240
177
if commit :
241
178
await session .commit ()
242
179
return result .rowcount # type: ignore
180
+
181
+ async def delete_model_by_column (
182
+ self ,
183
+ session : AsyncSession ,
184
+ allow_multiple : bool = False ,
185
+ logical_deletion : bool = False ,
186
+ deleted_flag_column : str = 'del_flag' ,
187
+ commit : bool = False ,
188
+ ** kwargs ,
189
+ ) -> int :
190
+ """
191
+ Delete
192
+
193
+ :param session: The SQLAlchemy async session.
194
+ :param commit: If `True`, commits the transaction immediately. Default is `False`.
195
+ :param kwargs: Query expressions.
196
+ :param allow_multiple: If `True`, allows deleting multiple records that match the filters.
197
+ :param logical_deletion: If `True`, enable logical deletion instead of physical deletion
198
+ :param deleted_flag_column: Specify the flag column for logical deletion
199
+ :return:
200
+ """
201
+ filters = await parse_filters (self .model , ** kwargs )
202
+ total_count = await count (session , self .model , filters )
203
+ if not allow_multiple and total_count > 1 :
204
+ raise MultipleResultsError (f'Only one record is expected to be delete, found { total_count } records.' )
205
+ if logical_deletion :
206
+ deleted_flag = {deleted_flag_column : True }
207
+ stmt = sa_update (self .model ).where (* filters ).values (** deleted_flag )
208
+ else :
209
+ stmt = sa_delete (self .model ).where (* filters )
210
+ await session .execute (stmt )
211
+ if commit :
212
+ await session .commit ()
213
+ return total_count
0 commit comments