Skip to content

Commit 089ac81

Browse files
committed
Guard against incorrect use of resources associated with a connection
We already check for attempts to use connection methods after the connection has been returned to the pool. This adds a similar check for objects associated with the connection: prepared statements and cursors. Issue: MagicStack#190.
1 parent 8a32fc4 commit 089ac81

File tree

6 files changed

+116
-21
lines changed

6 files changed

+116
-21
lines changed

asyncpg/connection.py

+9-4
Original file line numberDiff line numberDiff line change
@@ -39,10 +39,11 @@ class Connection(metaclass=ConnectionMeta):
3939

4040
__slots__ = ('_protocol', '_transport', '_loop', '_types_stmt',
4141
'_type_by_name_stmt', '_top_xact', '_uid', '_aborted',
42-
'_stmt_cache', '_stmts_to_close', '_listeners',
43-
'_server_version', '_server_caps', '_intro_query',
44-
'_reset_query', '_proxy', '_stmt_exclusive_section',
45-
'_config', '_params', '_addr', '_log_listeners')
42+
'_pool_release_ctr', '_stmt_cache', '_stmts_to_close',
43+
'_listeners', '_server_version', '_server_caps',
44+
'_intro_query', '_reset_query', '_proxy',
45+
'_stmt_exclusive_section', '_config', '_params', '_addr',
46+
'_log_listeners')
4647

4748
def __init__(self, protocol, transport, loop,
4849
addr: (str, int) or str,
@@ -56,6 +57,10 @@ def __init__(self, protocol, transport, loop,
5657
self._top_xact = None
5758
self._uid = 0
5859
self._aborted = False
60+
# Incremented very time the connection is released back to a pool.
61+
# Used to catch invalid references to connection-related resources
62+
# post-release (e.g. explicit prepared statements).
63+
self._pool_release_ctr = 0
5964

6065
self._addr = addr
6166
self._config = config

asyncpg/connresource.py

+38
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
2+
# Copyright (C) 2016-present the asyncpg authors and contributors
3+
# <see AUTHORS file>
4+
#
5+
# This module is part of asyncpg and is released under
6+
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
7+
8+
9+
import functools
10+
11+
from . import exceptions
12+
13+
14+
def guarded(meth):
15+
"""A decorator to add a sanity check to ConnectionResource methods."""
16+
17+
@functools.wraps(meth)
18+
def _check(self, *args, **kwargs):
19+
self._check_conn_validity(meth.__name__)
20+
return meth(self, *args, **kwargs)
21+
22+
return _check
23+
24+
25+
class ConnectionResource:
26+
__slots__ = ('_connection', '_con_release_ctr')
27+
28+
def __init__(self, connection):
29+
self._connection = connection
30+
self._con_release_ctr = getattr(connection, '_pool_release_ctr', None)
31+
32+
def _check_conn_validity(self, meth_name):
33+
con_release_ctr = getattr(self._connection, '_pool_release_ctr', None)
34+
if con_release_ctr != self._con_release_ctr:
35+
raise exceptions.InterfaceError(
36+
'cannot call {}.{}(): '
37+
'the underlying connection has been released back '
38+
'to the pool'.format(self.__class__.__name__, meth_name))

asyncpg/cursor.py

+14-8
Original file line numberDiff line numberDiff line change
@@ -8,21 +8,21 @@
88
import collections
99

1010
from . import compat
11+
from . import connresource
1112
from . import exceptions
1213

1314

14-
class CursorFactory:
15+
class CursorFactory(connresource.ConnectionResource):
1516
"""A cursor interface for the results of a query.
1617
1718
A cursor interface can be used to initiate efficient traversal of the
1819
results of a large query.
1920
"""
2021

21-
__slots__ = ('_state', '_connection', '_args', '_prefetch',
22-
'_query', '_timeout')
22+
__slots__ = ('_state', '_args', '_prefetch', '_query', '_timeout')
2323

2424
def __init__(self, connection, query, state, args, prefetch, timeout):
25-
self._connection = connection
25+
super().__init__(connection)
2626
self._args = args
2727
self._prefetch = prefetch
2828
self._query = query
@@ -32,13 +32,15 @@ def __init__(self, connection, query, state, args, prefetch, timeout):
3232
state.attach()
3333

3434
@compat.aiter_compat
35+
@connresource.guarded
3536
def __aiter__(self):
3637
prefetch = 50 if self._prefetch is None else self._prefetch
3738
return CursorIterator(self._connection,
3839
self._query, self._state,
3940
self._args, prefetch,
4041
self._timeout)
4142

43+
@connresource.guarded
4244
def __await__(self):
4345
if self._prefetch is not None:
4446
raise exceptions.InterfaceError(
@@ -53,14 +55,13 @@ def __del__(self):
5355
self._connection._maybe_gc_stmt(self._state)
5456

5557

56-
class BaseCursor:
58+
class BaseCursor(connresource.ConnectionResource):
5759

58-
__slots__ = ('_state', '_connection', '_args', '_portal_name',
59-
'_exhausted', '_query')
60+
__slots__ = ('_state', '_args', '_portal_name', '_exhausted', '_query')
6061

6162
def __init__(self, connection, query, state, args):
63+
super().__init__(connection)
6264
self._args = args
63-
self._connection = connection
6465
self._state = state
6566
if state is not None:
6667
state.attach()
@@ -162,9 +163,11 @@ def __init__(self, connection, query, state, args, prefetch, timeout):
162163
self._timeout = timeout
163164

164165
@compat.aiter_compat
166+
@connresource.guarded
165167
def __aiter__(self):
166168
return self
167169

170+
@connresource.guarded
168171
async def __anext__(self):
169172
if self._state is None:
170173
self._state = await self._connection._get_statement(
@@ -199,6 +202,7 @@ async def _init(self, timeout):
199202
await self._bind(timeout)
200203
return self
201204

205+
@connresource.guarded
202206
async def fetch(self, n, *, timeout=None):
203207
r"""Return the next *n* rows as a list of :class:`Record` objects.
204208
@@ -216,6 +220,7 @@ async def fetch(self, n, *, timeout=None):
216220
self._exhausted = True
217221
return recs
218222

223+
@connresource.guarded
219224
async def fetchrow(self, *, timeout=None):
220225
r"""Return the next row.
221226
@@ -232,6 +237,7 @@ async def fetchrow(self, *, timeout=None):
232237
return None
233238
return recs[0]
234239

240+
@connresource.guarded
235241
async def forward(self, n, *, timeout=None) -> int:
236242
r"""Skip over the next *n* rows.
237243

asyncpg/pool.py

+3
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,9 @@ async def release(self):
175175
assert self._in_use
176176
self._in_use = False
177177

178+
# Invalidate external references to the connection.
179+
self._con._pool_release_ctr += 1
180+
178181
if self._con.is_closed():
179182
self._con = None
180183

asyncpg/prepared_stmt.py

+21-9
Original file line numberDiff line numberDiff line change
@@ -7,22 +7,24 @@
77

88
import json
99

10+
from . import connresource
1011
from . import cursor
1112
from . import exceptions
1213

1314

14-
class PreparedStatement:
15+
class PreparedStatement(connresource.ConnectionResource):
1516
"""A representation of a prepared statement."""
1617

17-
__slots__ = ('_connection', '_state', '_query', '_last_status')
18+
__slots__ = ('_state', '_query', '_last_status')
1819

1920
def __init__(self, connection, query, state):
20-
self._connection = connection
21+
super().__init__(connection)
2122
self._state = state
2223
self._query = query
2324
state.attach()
2425
self._last_status = None
2526

27+
@connresource.guarded
2628
def get_query(self) -> str:
2729
"""Return the text of the query for this prepared statement.
2830
@@ -33,6 +35,7 @@ def get_query(self) -> str:
3335
"""
3436
return self._query
3537

38+
@connresource.guarded
3639
def get_statusmsg(self) -> str:
3740
"""Return the status of the executed command.
3841
@@ -46,6 +49,7 @@ def get_statusmsg(self) -> str:
4649
return self._last_status
4750
return self._last_status.decode()
4851

52+
@connresource.guarded
4953
def get_parameters(self):
5054
"""Return a description of statement parameters types.
5155
@@ -60,9 +64,9 @@ def get_parameters(self):
6064
# (Type(oid=23, name='int4', kind='scalar', schema='pg_catalog'),
6165
# Type(oid=25, name='text', kind='scalar', schema='pg_catalog'))
6266
"""
63-
self._check_open()
6467
return self._state._get_parameters()
6568

69+
@connresource.guarded
6670
def get_attributes(self):
6771
"""Return a description of relation attributes (columns).
6872
@@ -85,9 +89,9 @@ def get_attributes(self):
8589
# type=Type(oid=26, name='oid', kind='scalar',
8690
# schema='pg_catalog')))
8791
"""
88-
self._check_open()
8992
return self._state._get_attributes()
9093

94+
@connresource.guarded
9195
def cursor(self, *args, prefetch=None,
9296
timeout=None) -> cursor.CursorFactory:
9397
"""Return a *cursor factory* for the prepared statement.
@@ -99,11 +103,11 @@ def cursor(self, *args, prefetch=None,
99103
100104
:return: A :class:`~cursor.CursorFactory` object.
101105
"""
102-
self._check_open()
103106
return cursor.CursorFactory(self._connection, self._query,
104107
self._state, args, prefetch,
105108
timeout)
106109

110+
@connresource.guarded
107111
async def explain(self, *args, analyze=False):
108112
"""Return the execution plan of the statement.
109113
@@ -145,6 +149,7 @@ async def explain(self, *args, analyze=False):
145149

146150
return json.loads(data)
147151

152+
@connresource.guarded
148153
async def fetch(self, *args, timeout=None):
149154
r"""Execute the statement and return a list of :class:`Record` objects.
150155
@@ -157,6 +162,7 @@ async def fetch(self, *args, timeout=None):
157162
data = await self.__bind_execute(args, 0, timeout)
158163
return data
159164

165+
@connresource.guarded
160166
async def fetchval(self, *args, column=0, timeout=None):
161167
"""Execute the statement and return a value in the first row.
162168
@@ -175,6 +181,7 @@ async def fetchval(self, *args, column=0, timeout=None):
175181
return None
176182
return data[0][column]
177183

184+
@connresource.guarded
178185
async def fetchrow(self, *args, timeout=None):
179186
"""Execute the statement and return the first row.
180187
@@ -190,16 +197,21 @@ async def fetchrow(self, *args, timeout=None):
190197
return data[0]
191198

192199
async def __bind_execute(self, args, limit, timeout):
193-
self._check_open()
194200
protocol = self._connection._protocol
195201
data, status, _ = await protocol.bind_execute(
196202
self._state, args, '', limit, True, timeout)
197203
self._last_status = status
198204
return data
199205

200-
def _check_open(self):
206+
def _check_open(self, meth_name):
201207
if self._state.closed:
202-
raise exceptions.InterfaceError('prepared statement is closed')
208+
raise exceptions.InterfaceError(
209+
'cannot call PreparedStmt.{}(): '
210+
'the prepared statement is closed'.format(meth_name))
211+
212+
def _check_conn_validity(self, meth_name):
213+
self._check_open(meth_name)
214+
super()._check_conn_validity(meth_name)
203215

204216
def __del__(self):
205217
self._state.detach()

tests/test_pool.py

+31
Original file line numberDiff line numberDiff line change
@@ -193,6 +193,11 @@ async def test_pool_11(self):
193193
async with pool.acquire() as con:
194194
self.assertIn(repr(con._con), repr(con)) # Test __repr__.
195195

196+
ps = await con.prepare('SELECT 1')
197+
async with con.transaction():
198+
cur = await con.cursor('SELECT 1')
199+
ps_cur = await ps.cursor()
200+
196201
self.assertIn('[released]', repr(con))
197202

198203
with self.assertRaisesRegex(
@@ -201,6 +206,32 @@ async def test_pool_11(self):
201206

202207
con.execute('select 1')
203208

209+
for meth in ('fetchval', 'fetchrow', 'fetch', 'explain',
210+
'get_query', 'get_statusmsg', 'get_parameters',
211+
'get_attributes'):
212+
with self.assertRaisesRegex(
213+
asyncpg.InterfaceError,
214+
r'cannot call PreparedStatement\.{meth}.*released '
215+
r'back to the pool'.format(meth=meth)):
216+
217+
getattr(ps, meth)()
218+
219+
for c in (cur, ps_cur):
220+
for meth in ('fetch', 'fetchrow'):
221+
with self.assertRaisesRegex(
222+
asyncpg.InterfaceError,
223+
r'cannot call Cursor\.{meth}.*released '
224+
r'back to the pool'.format(meth=meth)):
225+
226+
getattr(c, meth)()
227+
228+
with self.assertRaisesRegex(
229+
asyncpg.InterfaceError,
230+
r'cannot call Cursor\.forward.*released '
231+
r'back to the pool'.format(meth=meth)):
232+
233+
c.forward(1)
234+
204235
await pool.close()
205236

206237
async def test_pool_12(self):

0 commit comments

Comments
 (0)