Skip to content

Commit ad5cbc5

Browse files
author
Kareem Zidane
authored
Merge pull request #113 from cs50/develop
cte
2 parents 1689322 + af0dea3 commit ad5cbc5

File tree

3 files changed

+50
-40
lines changed

3 files changed

+50
-40
lines changed

setup.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -16,5 +16,5 @@
1616
package_dir={"": "src"},
1717
packages=["cs50"],
1818
url="https://github.com/cs50/python-cs50",
19-
version="5.0.2"
19+
version="5.0.3"
2020
)

src/cs50/sql.py

+43-39
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,14 @@ def execute(self, sql, *args, **kwargs):
120120
if len(args) > 0 and len(kwargs) > 0:
121121
raise RuntimeError("cannot pass both named and positional parameters")
122122

123+
# Infer command from (unflattened) statement
124+
for token in statements[0]:
125+
if token.ttype in [sqlparse.tokens.Keyword.DDL, sqlparse.tokens.Keyword.DML]:
126+
command = token.value.upper()
127+
break
128+
else:
129+
command = None
130+
123131
# Flatten statement
124132
tokens = list(statements[0].flatten())
125133

@@ -313,45 +321,41 @@ def shutdown_session(exception=None):
313321

314322
# Return value
315323
ret = True
316-
if tokens[0].ttype == sqlparse.tokens.Keyword.DML:
317-
318-
# Uppercase token's value
319-
value = tokens[0].value.upper()
320-
321-
# If SELECT, return result set as list of dict objects
322-
if value == "SELECT":
323-
324-
# Coerce types
325-
rows = [dict(row) for row in result.fetchall()]
326-
for row in rows:
327-
for column in row:
328-
329-
# Coerce decimal.Decimal objects to float objects
330-
# https://groups.google.com/d/msg/sqlalchemy/0qXMYJvq8SA/oqtvMD9Uw-kJ
331-
if type(row[column]) is decimal.Decimal:
332-
row[column] = float(row[column])
333-
334-
# Coerce memoryview objects (as from PostgreSQL's bytea columns) to bytes
335-
elif type(row[column]) is memoryview:
336-
row[column] = bytes(row[column])
337-
338-
# Rows to be returned
339-
ret = rows
340-
341-
# If INSERT, return primary key value for a newly inserted row (or None if none)
342-
elif value == "INSERT":
343-
if self._engine.url.get_backend_name() in ["postgres", "postgresql"]:
344-
try:
345-
result = connection.execute("SELECT LASTVAL()")
346-
ret = result.first()[0]
347-
except sqlalchemy.exc.OperationalError: # If lastval is not yet defined in this session
348-
ret = None
349-
else:
350-
ret = result.lastrowid if result.rowcount == 1 else None
351-
352-
# If DELETE or UPDATE, return number of rows matched
353-
elif value in ["DELETE", "UPDATE"]:
354-
ret = result.rowcount
324+
325+
# If SELECT, return result set as list of dict objects
326+
if command == "SELECT":
327+
328+
# Coerce types
329+
rows = [dict(row) for row in result.fetchall()]
330+
for row in rows:
331+
for column in row:
332+
333+
# Coerce decimal.Decimal objects to float objects
334+
# https://groups.google.com/d/msg/sqlalchemy/0qXMYJvq8SA/oqtvMD9Uw-kJ
335+
if type(row[column]) is decimal.Decimal:
336+
row[column] = float(row[column])
337+
338+
# Coerce memoryview objects (as from PostgreSQL's bytea columns) to bytes
339+
elif type(row[column]) is memoryview:
340+
row[column] = bytes(row[column])
341+
342+
# Rows to be returned
343+
ret = rows
344+
345+
# If INSERT, return primary key value for a newly inserted row (or None if none)
346+
elif command == "INSERT":
347+
if self._engine.url.get_backend_name() in ["postgres", "postgresql"]:
348+
try:
349+
result = connection.execute("SELECT LASTVAL()")
350+
ret = result.first()[0]
351+
except sqlalchemy.exc.OperationalError: # If lastval is not yet defined in this session
352+
ret = None
353+
else:
354+
ret = result.lastrowid if result.rowcount == 1 else None
355+
356+
# If DELETE or UPDATE, return number of rows matched
357+
elif command in ["DELETE", "UPDATE"]:
358+
ret = result.rowcount
355359

356360
# If constraint violated, return None
357361
except sqlalchemy.exc.IntegrityError as e:

tests/sql.py

+6
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,9 @@ def setUpClass(self):
158158
def setUp(self):
159159
self.db.execute("CREATE TABLE cs50 (id SERIAL PRIMARY KEY, val VARCHAR(16), bin BYTEA)")
160160

161+
def test_cte(self):
162+
self.assertEqual(self.db.execute("WITH foo AS ( SELECT 1 AS bar ) SELECT bar FROM foo"), [{"bar": 1}])
163+
161164
class SQLiteTests(SQLTests):
162165
@classmethod
163166
def setUpClass(self):
@@ -306,6 +309,9 @@ def test_numeric(self):
306309
self.assertRaises(RuntimeError, self.db.execute, "INSERT INTO foo VALUES (:1, :2)", 'bar', 'baz', 'qux')
307310
self.assertRaises(RuntimeError, self.db.execute, "INSERT INTO foo VALUES (:1, :2)", 'bar', baz='baz')
308311

312+
def test_cte(self):
313+
self.assertEqual(self.db.execute("WITH foo AS ( SELECT 1 AS bar ) SELECT bar FROM foo"), [{"bar": 1}])
314+
309315

310316
if __name__ == "__main__":
311317
suite = unittest.TestSuite([

0 commit comments

Comments
 (0)