Skip to content

Commit db4f1a6

Browse files
committedAug 15, 2020
Allow using custom Record class
Add the new `record_class` parameter to the `create_pool()` and `connect()` functions, as well as to the `cursor()`, `prepare()`, `fetch()` and `fetchrow()` connection methods. This not only allows adding custom functionality to the returned objects, but also assists with typing (see #577 for discussion). Fixes: #40.
1 parent c8b8a50 commit db4f1a6

17 files changed

+610
-105
lines changed
 

‎.flake8

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
[flake8]
2-
ignore = E402,E731,W504,E252
2+
ignore = E402,E731,W503,W504,E252
33
exclude = .git,__pycache__,build,dist,.eggs,.github,.local

‎asyncpg/_testbase/__init__.py

+3
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import unittest
2020

2121

22+
import asyncpg
2223
from asyncpg import cluster as pg_cluster
2324
from asyncpg import connection as pg_connection
2425
from asyncpg import pool as pg_pool
@@ -266,13 +267,15 @@ def create_pool(dsn=None, *,
266267
loop=None,
267268
pool_class=pg_pool.Pool,
268269
connection_class=pg_connection.Connection,
270+
record_class=asyncpg.Record,
269271
**connect_kwargs):
270272
return pool_class(
271273
dsn,
272274
min_size=min_size, max_size=max_size,
273275
max_queries=max_queries, loop=loop, setup=setup, init=init,
274276
max_inactive_connection_lifetime=max_inactive_connection_lifetime,
275277
connection_class=connection_class,
278+
record_class=record_class,
276279
**connect_kwargs)
277280

278281

‎asyncpg/connect_utils.py

+20-7
Original file line numberDiff line numberDiff line change
@@ -594,8 +594,16 @@ async def _create_ssl_connection(protocol_factory, host, port, *,
594594
raise
595595

596596

597-
async def _connect_addr(*, addr, loop, timeout, params, config,
598-
connection_class):
597+
async def _connect_addr(
598+
*,
599+
addr,
600+
loop,
601+
timeout,
602+
params,
603+
config,
604+
connection_class,
605+
record_class
606+
):
599607
assert loop is not None
600608

601609
if timeout <= 0:
@@ -613,7 +621,7 @@ async def _connect_addr(*, addr, loop, timeout, params, config,
613621
params = params._replace(password=password)
614622

615623
proto_factory = lambda: protocol.Protocol(
616-
addr, connected, params, loop)
624+
addr, connected, params, record_class, loop)
617625

618626
if isinstance(addr, str):
619627
# UNIX socket
@@ -649,7 +657,7 @@ async def _connect_addr(*, addr, loop, timeout, params, config,
649657
return con
650658

651659

652-
async def _connect(*, loop, timeout, connection_class, **kwargs):
660+
async def _connect(*, loop, timeout, connection_class, record_class, **kwargs):
653661
if loop is None:
654662
loop = asyncio.get_event_loop()
655663

@@ -661,9 +669,14 @@ async def _connect(*, loop, timeout, connection_class, **kwargs):
661669
before = time.monotonic()
662670
try:
663671
con = await _connect_addr(
664-
addr=addr, loop=loop, timeout=timeout,
665-
params=params, config=config,
666-
connection_class=connection_class)
672+
addr=addr,
673+
loop=loop,
674+
timeout=timeout,
675+
params=params,
676+
config=config,
677+
connection_class=connection_class,
678+
record_class=record_class,
679+
)
667680
except (OSError, asyncio.TimeoutError, ConnectionError) as ex:
668681
last_error = ex
669682
else:

‎asyncpg/connection.py

+250-53
Large diffs are not rendered by default.

‎asyncpg/cursor.py

+67-15
Original file line numberDiff line numberDiff line change
@@ -19,34 +19,61 @@ class CursorFactory(connresource.ConnectionResource):
1919
results of a large query.
2020
"""
2121

22-
__slots__ = ('_state', '_args', '_prefetch', '_query', '_timeout')
23-
24-
def __init__(self, connection, query, state, args, prefetch, timeout):
22+
__slots__ = (
23+
'_state',
24+
'_args',
25+
'_prefetch',
26+
'_query',
27+
'_timeout',
28+
'_record_class',
29+
)
30+
31+
def __init__(
32+
self,
33+
connection,
34+
query,
35+
state,
36+
args,
37+
prefetch,
38+
timeout,
39+
record_class
40+
):
2541
super().__init__(connection)
2642
self._args = args
2743
self._prefetch = prefetch
2844
self._query = query
2945
self._timeout = timeout
3046
self._state = state
47+
self._record_class = record_class
3148
if state is not None:
3249
state.attach()
3350

3451
@compat.aiter_compat
3552
@connresource.guarded
3653
def __aiter__(self):
3754
prefetch = 50 if self._prefetch is None else self._prefetch
38-
return CursorIterator(self._connection,
39-
self._query, self._state,
40-
self._args, prefetch,
41-
self._timeout)
55+
return CursorIterator(
56+
self._connection,
57+
self._query,
58+
self._state,
59+
self._args,
60+
self._record_class,
61+
prefetch,
62+
self._timeout,
63+
)
4264

4365
@connresource.guarded
4466
def __await__(self):
4567
if self._prefetch is not None:
4668
raise exceptions.InterfaceError(
4769
'prefetch argument can only be specified for iterable cursor')
48-
cursor = Cursor(self._connection, self._query,
49-
self._state, self._args)
70+
cursor = Cursor(
71+
self._connection,
72+
self._query,
73+
self._state,
74+
self._args,
75+
self._record_class,
76+
)
5077
return cursor._init(self._timeout).__await__()
5178

5279
def __del__(self):
@@ -57,9 +84,16 @@ def __del__(self):
5784

5885
class BaseCursor(connresource.ConnectionResource):
5986

60-
__slots__ = ('_state', '_args', '_portal_name', '_exhausted', '_query')
87+
__slots__ = (
88+
'_state',
89+
'_args',
90+
'_portal_name',
91+
'_exhausted',
92+
'_query',
93+
'_record_class',
94+
)
6195

62-
def __init__(self, connection, query, state, args):
96+
def __init__(self, connection, query, state, args, record_class):
6397
super().__init__(connection)
6498
self._args = args
6599
self._state = state
@@ -68,6 +102,7 @@ def __init__(self, connection, query, state, args):
68102
self._portal_name = None
69103
self._exhausted = False
70104
self._query = query
105+
self._record_class = record_class
71106

72107
def _check_ready(self):
73108
if self._state is None:
@@ -151,8 +186,17 @@ class CursorIterator(BaseCursor):
151186

152187
__slots__ = ('_buffer', '_prefetch', '_timeout')
153188

154-
def __init__(self, connection, query, state, args, prefetch, timeout):
155-
super().__init__(connection, query, state, args)
189+
def __init__(
190+
self,
191+
connection,
192+
query,
193+
state,
194+
args,
195+
record_class,
196+
prefetch,
197+
timeout
198+
):
199+
super().__init__(connection, query, state, args, record_class)
156200

157201
if prefetch <= 0:
158202
raise exceptions.InterfaceError(
@@ -171,7 +215,11 @@ def __aiter__(self):
171215
async def __anext__(self):
172216
if self._state is None:
173217
self._state = await self._connection._get_statement(
174-
self._query, self._timeout, named=True)
218+
self._query,
219+
self._timeout,
220+
named=True,
221+
record_class=self._record_class,
222+
)
175223
self._state.attach()
176224

177225
if not self._portal_name:
@@ -196,7 +244,11 @@ class Cursor(BaseCursor):
196244
async def _init(self, timeout):
197245
if self._state is None:
198246
self._state = await self._connection._get_statement(
199-
self._query, timeout, named=True)
247+
self._query,
248+
timeout,
249+
named=True,
250+
record_class=self._record_class,
251+
)
200252
self._state.attach()
201253
self._check_ready()
202254
await self._bind(timeout)

‎asyncpg/pool.py

+23-2
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from . import connection
1616
from . import connect_utils
1717
from . import exceptions
18+
from . import protocol
1819

1920

2021
logger = logging.getLogger(__name__)
@@ -309,7 +310,7 @@ class Pool:
309310
'_init', '_connect_args', '_connect_kwargs',
310311
'_working_addr', '_working_config', '_working_params',
311312
'_holders', '_initialized', '_initializing', '_closing',
312-
'_closed', '_connection_class', '_generation',
313+
'_closed', '_connection_class', '_record_class', '_generation',
313314
'_setup', '_max_queries', '_max_inactive_connection_lifetime'
314315
)
315316

@@ -322,6 +323,7 @@ def __init__(self, *connect_args,
322323
init,
323324
loop,
324325
connection_class,
326+
record_class,
325327
**connect_kwargs):
326328

327329
if len(connect_args) > 1:
@@ -359,6 +361,11 @@ def __init__(self, *connect_args,
359361
'connection_class is expected to be a subclass of '
360362
'asyncpg.Connection, got {!r}'.format(connection_class))
361363

364+
if not issubclass(record_class, protocol.Record):
365+
raise TypeError(
366+
'record_class is expected to be a subclass of '
367+
'asyncpg.Record, got {!r}'.format(record_class))
368+
362369
self._minsize = min_size
363370
self._maxsize = max_size
364371

@@ -372,6 +379,7 @@ def __init__(self, *connect_args,
372379
self._working_params = None
373380

374381
self._connection_class = connection_class
382+
self._record_class = record_class
375383

376384
self._closing = False
377385
self._closed = False
@@ -469,6 +477,7 @@ async def _get_new_connection(self):
469477
*self._connect_args,
470478
loop=self._loop,
471479
connection_class=self._connection_class,
480+
record_class=self._record_class,
472481
**self._connect_kwargs)
473482

474483
self._working_addr = con._addr
@@ -484,7 +493,9 @@ async def _get_new_connection(self):
484493
timeout=self._working_params.connect_timeout,
485494
config=self._working_config,
486495
params=self._working_params,
487-
connection_class=self._connection_class)
496+
connection_class=self._connection_class,
497+
record_class=self._record_class,
498+
)
488499

489500
if self._init is not None:
490501
try:
@@ -793,6 +804,7 @@ def create_pool(dsn=None, *,
793804
init=None,
794805
loop=None,
795806
connection_class=connection.Connection,
807+
record_class=protocol.Record,
796808
**connect_kwargs):
797809
r"""Create a connection pool.
798810
@@ -851,6 +863,11 @@ def create_pool(dsn=None, *,
851863
The class to use for connections. Must be a subclass of
852864
:class:`~asyncpg.connection.Connection`.
853865
866+
:param type record_class:
867+
If specified, the class to use for records returned by queries on
868+
the connections in this pool. Must be a subclass of
869+
:class:`~asyncpg.Record`.
870+
854871
:param int min_size:
855872
Number of connection the pool will be initialized with.
856873
@@ -901,10 +918,14 @@ def create_pool(dsn=None, *,
901918
or :meth:`Connection.add_log_listener()
902919
<connection.Connection.add_log_listener>`) present on the connection
903920
at the moment of its release to the pool.
921+
922+
.. versionchanged:: 0.22.0
923+
Added the *record_class* parameter.
904924
"""
905925
return Pool(
906926
dsn,
907927
connection_class=connection_class,
928+
record_class=record_class,
908929
min_size=min_size, max_size=max_size,
909930
max_queries=max_queries, loop=loop, setup=setup, init=init,
910931
max_inactive_connection_lifetime=max_inactive_connection_lifetime,

‎asyncpg/prepared_stmt.py

+9-3
Original file line numberDiff line numberDiff line change
@@ -103,9 +103,15 @@ def cursor(self, *args, prefetch=None,
103103
104104
:return: A :class:`~cursor.CursorFactory` object.
105105
"""
106-
return cursor.CursorFactory(self._connection, self._query,
107-
self._state, args, prefetch,
108-
timeout)
106+
return cursor.CursorFactory(
107+
self._connection,
108+
self._query,
109+
self._state,
110+
args,
111+
prefetch,
112+
timeout,
113+
self._state.record_class,
114+
)
109115

110116
@connresource.guarded
111117
async def explain(self, *args, analyze=False):

‎asyncpg/protocol/codecs/base.pyx

+2-1
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
from collections.abc import Mapping as MappingABC
99

10+
import asyncpg
1011
from asyncpg import exceptions
1112

1213

@@ -232,7 +233,7 @@ cdef class Codec:
232233
schema=self.schema,
233234
data_type=self.name,
234235
)
235-
result = record.ApgRecord_New(self.record_desc, elem_count)
236+
result = record.ApgRecord_New(asyncpg.Record, self.record_desc, elem_count)
236237
for i in range(elem_count):
237238
elem_typ = self.element_type_oids[i]
238239
received_elem_typ = <uint32_t>hton.unpack_int32(frb_read(buf, 4))

‎asyncpg/protocol/prepared_stmt.pxd

+2
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@ cdef class PreparedStatementState:
1111
readonly str query
1212
readonly bint closed
1313
readonly int refs
14+
readonly type record_class
15+
1416

1517
list row_desc
1618
list parameters_desc

‎asyncpg/protocol/prepared_stmt.pyx

+9-2
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,13 @@ from asyncpg import exceptions
1111
@cython.final
1212
cdef class PreparedStatementState:
1313

14-
def __cinit__(self, str name, str query, BaseProtocol protocol):
14+
def __cinit__(
15+
self,
16+
str name,
17+
str query,
18+
BaseProtocol protocol,
19+
type record_class
20+
):
1521
self.name = name
1622
self.query = query
1723
self.settings = protocol.settings
@@ -21,6 +27,7 @@ cdef class PreparedStatementState:
2127
self.cols_desc = None
2228
self.closed = False
2329
self.refs = 0
30+
self.record_class = record_class
2431

2532
def _get_parameters(self):
2633
cdef Codec codec
@@ -264,7 +271,7 @@ cdef class PreparedStatementState:
264271
'different from what was described ({})'.format(
265272
fnum, self.cols_num))
266273

267-
dec_row = record.ApgRecord_New(self.cols_desc, fnum)
274+
dec_row = record.ApgRecord_New(self.record_class, self.cols_desc, fnum)
268275
for i in range(fnum):
269276
flen = hton.unpack_int32(frb_read(&rbuf, 4))
270277

‎asyncpg/protocol/protocol.pxd

+1
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ cdef class BaseProtocol(CoreProtocol):
4242
object timeout_callback
4343
object completed_callback
4444
object conref
45+
type record_class
4546
bint is_reading
4647

4748
str last_query

‎asyncpg/protocol/protocol.pyx

+11-4
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ NO_TIMEOUT = object()
7373

7474

7575
cdef class BaseProtocol(CoreProtocol):
76-
def __init__(self, addr, connected_fut, con_params, loop):
76+
def __init__(self, addr, connected_fut, con_params, record_class: type, loop):
7777
# type of `con_params` is `_ConnectionParameters`
7878
CoreProtocol.__init__(self, con_params)
7979

@@ -85,6 +85,7 @@ cdef class BaseProtocol(CoreProtocol):
8585

8686
self.address = addr
8787
self.settings = ConnectionSettings((self.address, con_params.database))
88+
self.record_class = record_class
8889

8990
self.statement = None
9091
self.return_extra = False
@@ -122,6 +123,9 @@ cdef class BaseProtocol(CoreProtocol):
122123
def get_settings(self):
123124
return self.settings
124125

126+
def get_record_class(self):
127+
return self.record_class
128+
125129
def is_in_transaction(self):
126130
# PQTRANS_INTRANS = idle, within transaction block
127131
# PQTRANS_INERROR = idle, within failed transaction
@@ -139,7 +143,9 @@ cdef class BaseProtocol(CoreProtocol):
139143

140144
@cython.iterable_coroutine
141145
async def prepare(self, stmt_name, query, timeout,
142-
PreparedStatementState state=None):
146+
*,
147+
PreparedStatementState state=None,
148+
record_class):
143149
if self.cancel_waiter is not None:
144150
await self.cancel_waiter
145151
if self.cancel_sent_waiter is not None:
@@ -154,7 +160,8 @@ cdef class BaseProtocol(CoreProtocol):
154160
self._prepare(stmt_name, query) # network op
155161
self.last_query = query
156162
if state is None:
157-
state = PreparedStatementState(stmt_name, query, self)
163+
state = PreparedStatementState(
164+
stmt_name, query, self, record_class)
158165
self.statement = state
159166
except Exception as ex:
160167
waiter.set_exception(ex)
@@ -955,7 +962,7 @@ def _create_record(object mapping, tuple elems):
955962
desc = record.ApgRecordDesc_New(
956963
mapping, tuple(mapping) if mapping else ())
957964

958-
rec = record.ApgRecord_New(desc, len(elems))
965+
rec = record.ApgRecord_New(Record, desc, len(elems))
959966
for i in range(len(elems)):
960967
elem = elems[i]
961968
cpython.Py_INCREF(elem)

‎asyncpg/protocol/record/__init__.pxd

+1-1
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ cdef extern from "record/recordobj.h":
1313
cpython.PyTypeObject *ApgRecord_InitTypes() except NULL
1414

1515
int ApgRecord_CheckExact(object)
16-
object ApgRecord_New(object, int)
16+
object ApgRecord_New(type, object, int)
1717
void ApgRecord_SET_ITEM(object, int, object)
1818

1919
object ApgRecordDesc_New(object, object)

‎asyncpg/protocol/record/recordobj.c

+34-13
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,14 @@ static PyObject * record_new_items_iter(PyObject *);
1515
static ApgRecordObject *free_list[ApgRecord_MAXSAVESIZE];
1616
static int numfree[ApgRecord_MAXSAVESIZE];
1717

18+
static size_t MAX_RECORD_SIZE = (
19+
((size_t)PY_SSIZE_T_MAX - sizeof(ApgRecordObject) - sizeof(PyObject *))
20+
/ sizeof(PyObject *)
21+
);
22+
1823

1924
PyObject *
20-
ApgRecord_New(PyObject *desc, Py_ssize_t size)
25+
ApgRecord_New(PyTypeObject *type, PyObject *desc, Py_ssize_t size)
2126
{
2227
ApgRecordObject *o;
2328
Py_ssize_t i;
@@ -27,19 +32,36 @@ ApgRecord_New(PyObject *desc, Py_ssize_t size)
2732
return NULL;
2833
}
2934

30-
if (size < ApgRecord_MAXSAVESIZE && (o = free_list[size]) != NULL) {
31-
free_list[size] = (ApgRecordObject *) o->ob_item[0];
32-
numfree[size]--;
33-
_Py_NewReference((PyObject *)o);
34-
}
35-
else {
36-
/* Check for overflow */
37-
if ((size_t)size > ((size_t)PY_SSIZE_T_MAX - sizeof(ApgRecordObject) -
38-
sizeof(PyObject *)) / sizeof(PyObject *)) {
35+
if (type == &ApgRecord_Type) {
36+
if (size < ApgRecord_MAXSAVESIZE && (o = free_list[size]) != NULL) {
37+
free_list[size] = (ApgRecordObject *) o->ob_item[0];
38+
numfree[size]--;
39+
_Py_NewReference((PyObject *)o);
40+
}
41+
else {
42+
/* Check for overflow */
43+
if ((size_t)size > MAX_RECORD_SIZE) {
44+
return PyErr_NoMemory();
45+
}
46+
o = PyObject_GC_NewVar(ApgRecordObject, &ApgRecord_Type, size);
47+
if (o == NULL) {
48+
return NULL;
49+
}
50+
}
51+
52+
PyObject_GC_Track(o);
53+
} else {
54+
assert(PyType_IsSubtype(type, &ApgRecord_Type));
55+
56+
if ((size_t)size > MAX_RECORD_SIZE) {
3957
return PyErr_NoMemory();
4058
}
41-
o = PyObject_GC_NewVar(ApgRecordObject, &ApgRecord_Type, size);
42-
if (o == NULL) {
59+
o = (ApgRecordObject *)type->tp_alloc(type, size);
60+
if (!_PyObject_GC_IS_TRACKED(o)) {
61+
PyErr_SetString(
62+
PyExc_TypeError,
63+
"record subclass is not tracked by GC"
64+
);
4365
return NULL;
4466
}
4567
}
@@ -51,7 +73,6 @@ ApgRecord_New(PyObject *desc, Py_ssize_t size)
5173
Py_INCREF(desc);
5274
o->desc = (ApgRecordDescObject*)desc;
5375
o->self_hash = -1;
54-
PyObject_GC_Track(o);
5576
return (PyObject *) o;
5677
}
5778

‎asyncpg/protocol/record/recordobj.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ extern PyTypeObject ApgRecordDesc_Type;
4646
(((ApgRecordObject *)(op))->ob_item[i])
4747

4848
PyTypeObject *ApgRecord_InitTypes(void);
49-
PyObject *ApgRecord_New(PyObject *, Py_ssize_t);
49+
PyObject *ApgRecord_New(PyTypeObject *, PyObject *, Py_ssize_t);
5050
PyObject *ApgRecordDesc_New(PyObject *, PyObject *);
5151

5252
#endif

‎tests/test_record.py

+174
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,14 @@
2222
R_ABC = collections.OrderedDict([('a', 0), ('b', 1), ('c', 2)])
2323

2424

25+
class CustomRecord(asyncpg.Record):
26+
pass
27+
28+
29+
class AnotherCustomRecord(asyncpg.Record):
30+
pass
31+
32+
2533
class TestRecord(tb.ConnectedTestCase):
2634

2735
@contextlib.contextmanager
@@ -339,3 +347,169 @@ async def test_record_no_new(self):
339347
with self.assertRaisesRegex(
340348
TypeError, "cannot create 'asyncpg.Record' instances"):
341349
asyncpg.Record()
350+
351+
@tb.with_connection_options(record_class=CustomRecord)
352+
async def test_record_subclass_01(self):
353+
r = await self.con.fetchrow("SELECT 1 as a, '2' as b")
354+
self.assertIsInstance(r, CustomRecord)
355+
356+
r = await self.con.fetch("SELECT 1 as a, '2' as b")
357+
self.assertIsInstance(r[0], CustomRecord)
358+
359+
async with self.con.transaction():
360+
cur = await self.con.cursor("SELECT 1 as a, '2' as b")
361+
r = await cur.fetchrow()
362+
self.assertIsInstance(r, CustomRecord)
363+
364+
cur = await self.con.cursor("SELECT 1 as a, '2' as b")
365+
r = await cur.fetch(1)
366+
self.assertIsInstance(r[0], CustomRecord)
367+
368+
async with self.con.transaction():
369+
cur = self.con.cursor("SELECT 1 as a, '2' as b")
370+
async for r in cur:
371+
self.assertIsInstance(r, CustomRecord)
372+
373+
ps = await self.con.prepare("SELECT 1 as a, '2' as b")
374+
r = await ps.fetchrow()
375+
self.assertIsInstance(r, CustomRecord)
376+
377+
async def test_record_subclass_02(self):
378+
r = await self.con.fetchrow(
379+
"SELECT 1 as a, '2' as b",
380+
record_class=CustomRecord,
381+
)
382+
self.assertIsInstance(r, CustomRecord)
383+
384+
r = await self.con.fetch(
385+
"SELECT 1 as a, '2' as b",
386+
record_class=CustomRecord,
387+
)
388+
self.assertIsInstance(r[0], CustomRecord)
389+
390+
async with self.con.transaction():
391+
cur = await self.con.cursor(
392+
"SELECT 1 as a, '2' as b",
393+
record_class=CustomRecord,
394+
)
395+
r = await cur.fetchrow()
396+
self.assertIsInstance(r, CustomRecord)
397+
398+
cur = await self.con.cursor(
399+
"SELECT 1 as a, '2' as b",
400+
record_class=CustomRecord,
401+
)
402+
r = await cur.fetch(1)
403+
self.assertIsInstance(r[0], CustomRecord)
404+
405+
async with self.con.transaction():
406+
cur = self.con.cursor(
407+
"SELECT 1 as a, '2' as b",
408+
record_class=CustomRecord,
409+
)
410+
async for r in cur:
411+
self.assertIsInstance(r, CustomRecord)
412+
413+
ps = await self.con.prepare(
414+
"SELECT 1 as a, '2' as b",
415+
record_class=CustomRecord,
416+
)
417+
r = await ps.fetchrow()
418+
self.assertIsInstance(r, CustomRecord)
419+
420+
r = await ps.fetch()
421+
self.assertIsInstance(r[0], CustomRecord)
422+
423+
@tb.with_connection_options(record_class=AnotherCustomRecord)
424+
async def test_record_subclass_03(self):
425+
r = await self.con.fetchrow(
426+
"SELECT 1 as a, '2' as b",
427+
record_class=CustomRecord,
428+
)
429+
self.assertIsInstance(r, CustomRecord)
430+
431+
r = await self.con.fetch(
432+
"SELECT 1 as a, '2' as b",
433+
record_class=CustomRecord,
434+
)
435+
self.assertIsInstance(r[0], CustomRecord)
436+
437+
async with self.con.transaction():
438+
cur = await self.con.cursor(
439+
"SELECT 1 as a, '2' as b",
440+
record_class=CustomRecord,
441+
)
442+
r = await cur.fetchrow()
443+
self.assertIsInstance(r, CustomRecord)
444+
445+
cur = await self.con.cursor(
446+
"SELECT 1 as a, '2' as b",
447+
record_class=CustomRecord,
448+
)
449+
r = await cur.fetch(1)
450+
self.assertIsInstance(r[0], CustomRecord)
451+
452+
async with self.con.transaction():
453+
cur = self.con.cursor(
454+
"SELECT 1 as a, '2' as b",
455+
record_class=CustomRecord,
456+
)
457+
async for r in cur:
458+
self.assertIsInstance(r, CustomRecord)
459+
460+
ps = await self.con.prepare(
461+
"SELECT 1 as a, '2' as b",
462+
record_class=CustomRecord,
463+
)
464+
r = await ps.fetchrow()
465+
self.assertIsInstance(r, CustomRecord)
466+
467+
r = await ps.fetch()
468+
self.assertIsInstance(r[0], CustomRecord)
469+
470+
@tb.with_connection_options(record_class=CustomRecord)
471+
async def test_record_subclass_04(self):
472+
r = await self.con.fetchrow(
473+
"SELECT 1 as a, '2' as b",
474+
record_class=asyncpg.Record,
475+
)
476+
self.assertIs(type(r), asyncpg.Record)
477+
478+
r = await self.con.fetch(
479+
"SELECT 1 as a, '2' as b",
480+
record_class=asyncpg.Record,
481+
)
482+
self.assertIs(type(r[0]), asyncpg.Record)
483+
484+
async with self.con.transaction():
485+
cur = await self.con.cursor(
486+
"SELECT 1 as a, '2' as b",
487+
record_class=asyncpg.Record,
488+
)
489+
r = await cur.fetchrow()
490+
self.assertIs(type(r), asyncpg.Record)
491+
492+
cur = await self.con.cursor(
493+
"SELECT 1 as a, '2' as b",
494+
record_class=asyncpg.Record,
495+
)
496+
r = await cur.fetch(1)
497+
self.assertIs(type(r[0]), asyncpg.Record)
498+
499+
async with self.con.transaction():
500+
cur = self.con.cursor(
501+
"SELECT 1 as a, '2' as b",
502+
record_class=asyncpg.Record,
503+
)
504+
async for r in cur:
505+
self.assertIs(type(r), asyncpg.Record)
506+
507+
ps = await self.con.prepare(
508+
"SELECT 1 as a, '2' as b",
509+
record_class=asyncpg.Record,
510+
)
511+
r = await ps.fetchrow()
512+
self.assertIs(type(r), asyncpg.Record)
513+
514+
r = await ps.fetch()
515+
self.assertIs(type(r[0]), asyncpg.Record)

‎tests/test_timeout.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -138,9 +138,9 @@ async def test_command_timeout_01(self):
138138

139139
class SlowPrepareConnection(pg_connection.Connection):
140140
"""Connection class to test timeouts."""
141-
async def _get_statement(self, query, timeout):
141+
async def _get_statement(self, query, timeout, **kwargs):
142142
await asyncio.sleep(0.3)
143-
return await super()._get_statement(query, timeout)
143+
return await super()._get_statement(query, timeout, **kwargs)
144144

145145

146146
class TestTimeoutCoversPrepare(tb.ConnectedTestCase):

0 commit comments

Comments
 (0)
Please sign in to comment.