@@ -120,6 +120,14 @@ def execute(self, sql, *args, **kwargs):
120
120
if len (args ) > 0 and len (kwargs ) > 0 :
121
121
raise RuntimeError ("cannot pass both named and positional parameters" )
122
122
123
+ # Infer command from (unflattened) statement
124
+ for token in statements [0 ]:
125
+ if token .ttype in [sqlparse .tokens .Keyword .DDL , sqlparse .tokens .Keyword .DML ]:
126
+ command = token .value .upper ()
127
+ break
128
+ else :
129
+ command = None
130
+
123
131
# Flatten statement
124
132
tokens = list (statements [0 ].flatten ())
125
133
@@ -313,45 +321,41 @@ def shutdown_session(exception=None):
313
321
314
322
# Return value
315
323
ret = True
316
- if tokens [0 ].ttype == sqlparse .tokens .Keyword .DML :
317
-
318
- # Uppercase token's value
319
- value = tokens [0 ].value .upper ()
320
-
321
- # If SELECT, return result set as list of dict objects
322
- if value == "SELECT" :
323
-
324
- # Coerce types
325
- rows = [dict (row ) for row in result .fetchall ()]
326
- for row in rows :
327
- for column in row :
328
-
329
- # Coerce decimal.Decimal objects to float objects
330
- # https://groups.google.com/d/msg/sqlalchemy/0qXMYJvq8SA/oqtvMD9Uw-kJ
331
- if type (row [column ]) is decimal .Decimal :
332
- row [column ] = float (row [column ])
333
-
334
- # Coerce memoryview objects (as from PostgreSQL's bytea columns) to bytes
335
- elif type (row [column ]) is memoryview :
336
- row [column ] = bytes (row [column ])
337
-
338
- # Rows to be returned
339
- ret = rows
340
-
341
- # If INSERT, return primary key value for a newly inserted row (or None if none)
342
- elif value == "INSERT" :
343
- if self ._engine .url .get_backend_name () in ["postgres" , "postgresql" ]:
344
- try :
345
- result = connection .execute ("SELECT LASTVAL()" )
346
- ret = result .first ()[0 ]
347
- except sqlalchemy .exc .OperationalError : # If lastval is not yet defined in this session
348
- ret = None
349
- else :
350
- ret = result .lastrowid if result .rowcount == 1 else None
351
-
352
- # If DELETE or UPDATE, return number of rows matched
353
- elif value in ["DELETE" , "UPDATE" ]:
354
- ret = result .rowcount
324
+
325
+ # If SELECT, return result set as list of dict objects
326
+ if command == "SELECT" :
327
+
328
+ # Coerce types
329
+ rows = [dict (row ) for row in result .fetchall ()]
330
+ for row in rows :
331
+ for column in row :
332
+
333
+ # Coerce decimal.Decimal objects to float objects
334
+ # https://groups.google.com/d/msg/sqlalchemy/0qXMYJvq8SA/oqtvMD9Uw-kJ
335
+ if type (row [column ]) is decimal .Decimal :
336
+ row [column ] = float (row [column ])
337
+
338
+ # Coerce memoryview objects (as from PostgreSQL's bytea columns) to bytes
339
+ elif type (row [column ]) is memoryview :
340
+ row [column ] = bytes (row [column ])
341
+
342
+ # Rows to be returned
343
+ ret = rows
344
+
345
+ # If INSERT, return primary key value for a newly inserted row (or None if none)
346
+ elif command == "INSERT" :
347
+ if self ._engine .url .get_backend_name () in ["postgres" , "postgresql" ]:
348
+ try :
349
+ result = connection .execute ("SELECT LASTVAL()" )
350
+ ret = result .first ()[0 ]
351
+ except sqlalchemy .exc .OperationalError : # If lastval is not yet defined in this session
352
+ ret = None
353
+ else :
354
+ ret = result .lastrowid if result .rowcount == 1 else None
355
+
356
+ # If DELETE or UPDATE, return number of rows matched
357
+ elif command in ["DELETE" , "UPDATE" ]:
358
+ ret = result .rowcount
355
359
356
360
# If constraint violated, return None
357
361
except sqlalchemy .exc .IntegrityError as e :
0 commit comments