Skip to content

Commit 50f65fb

Browse files
authored
Untangle custom codec confusion (MagicStack#662)
Asyncpg currently erroneously prefers binary I/O for the underlying type of arrays, effectively ignoring a possible custom text codec that might have been configured on a type. Fix this by removing the explicit preference for binary I/O, so that the codec selection preference is now in the following order:

- custom binary codec
- custom text codec
- builtin binary codec
- builtin text codec

Fixes: MagicStack#590
Reported-by: @neumond
1 parent 7252dbe commit 50f65fb

File tree

6 files changed

+123
-105
lines changed

6 files changed

+123
-105
lines changed

asyncpg/connection.py

+9
Original file line numberDiff line numberDiff line change
@@ -1156,6 +1156,15 @@ async def set_type_codec(self, typename, *,
11561156
.. versionchanged:: 0.13.0
11571157
The ``binary`` keyword argument was removed in favor of
11581158
``format``.
1159+
1160+
.. note::
1161+
1162+
It is recommended to use the ``'binary'`` or ``'tuple'`` *format*
1163+
whenever possible and if the underlying type supports it. Asyncpg
1164+
currently does not support text I/O for composite and range types,
1165+
and some other functionality, such as
1166+
:meth:`Connection.copy_to_table`, does not support types with text
1167+
codecs.
11591168
"""
11601169
self._check_open()
11611170
typeinfo = await self._introspect_type(typename, schema)

asyncpg/introspection.py

+10-21
Original file line numberDiff line numberDiff line change
@@ -37,23 +37,9 @@
3737
3838
ELSE NULL
3939
END) AS basetype,
40-
t.typreceive::oid != 0 AND t.typsend::oid != 0
41-
AS has_bin_io,
4240
t.typelem AS elemtype,
4341
elem_t.typdelim AS elemdelim,
4442
range_t.rngsubtype AS range_subtype,
45-
(CASE WHEN t.typtype = 'r' THEN
46-
(SELECT
47-
range_elem_t.typreceive::oid != 0 AND
48-
range_elem_t.typsend::oid != 0
49-
FROM
50-
pg_catalog.pg_type AS range_elem_t
51-
WHERE
52-
range_elem_t.oid = range_t.rngsubtype)
53-
ELSE
54-
elem_t.typreceive::oid != 0 AND
55-
elem_t.typsend::oid != 0
56-
END) AS elem_has_bin_io,
5743
(CASE WHEN t.typtype = 'c' THEN
5844
(SELECT
5945
array_agg(ia.atttypid ORDER BY ia.attnum)
@@ -98,12 +84,12 @@
9884

9985
INTRO_LOOKUP_TYPES = '''\
10086
WITH RECURSIVE typeinfo_tree(
101-
oid, ns, name, kind, basetype, has_bin_io, elemtype, elemdelim,
102-
range_subtype, elem_has_bin_io, attrtypoids, attrnames, depth)
87+
oid, ns, name, kind, basetype, elemtype, elemdelim,
88+
range_subtype, attrtypoids, attrnames, depth)
10389
AS (
10490
SELECT
105-
ti.oid, ti.ns, ti.name, ti.kind, ti.basetype, ti.has_bin_io,
106-
ti.elemtype, ti.elemdelim, ti.range_subtype, ti.elem_has_bin_io,
91+
ti.oid, ti.ns, ti.name, ti.kind, ti.basetype,
92+
ti.elemtype, ti.elemdelim, ti.range_subtype,
10793
ti.attrtypoids, ti.attrnames, 0
10894
FROM
10995
{typeinfo} AS ti
@@ -113,8 +99,8 @@
11399
UNION ALL
114100
115101
SELECT
116-
ti.oid, ti.ns, ti.name, ti.kind, ti.basetype, ti.has_bin_io,
117-
ti.elemtype, ti.elemdelim, ti.range_subtype, ti.elem_has_bin_io,
102+
ti.oid, ti.ns, ti.name, ti.kind, ti.basetype,
103+
ti.elemtype, ti.elemdelim, ti.range_subtype,
118104
ti.attrtypoids, ti.attrnames, tt.depth + 1
119105
FROM
120106
{typeinfo} ti,
@@ -126,7 +112,10 @@
126112
)
127113
128114
SELECT DISTINCT
129-
*
115+
*,
116+
basetype::regtype::text AS basetype_name,
117+
elemtype::regtype::text AS elemtype_name,
118+
range_subtype::regtype::text AS range_subtype_name
130119
FROM
131120
typeinfo_tree
132121
ORDER BY

asyncpg/protocol/codecs/base.pxd

+2-1
Original file line numberDiff line numberDiff line change
@@ -168,4 +168,5 @@ cdef class DataCodecConfig:
168168

169169
cdef inline Codec get_codec(self, uint32_t oid, ServerDataFormat format,
170170
bint ignore_custom_codec=*)
171-
cdef inline Codec get_any_local_codec(self, uint32_t oid)
171+
cdef inline Codec get_custom_codec(self, uint32_t oid,
172+
ServerDataFormat format)

asyncpg/protocol/codecs/base.pyx

+64-73
Original file line numberDiff line numberDiff line change
@@ -440,14 +440,7 @@ cdef class DataCodecConfig:
440440
for ti in types:
441441
oid = ti['oid']
442442

443-
if not ti['has_bin_io']:
444-
format = PG_FORMAT_TEXT
445-
else:
446-
format = PG_FORMAT_BINARY
447-
448-
has_text_elements = False
449-
450-
if self.get_codec(oid, format) is not None:
443+
if self.get_codec(oid, PG_FORMAT_ANY) is not None:
451444
continue
452445

453446
name = ti['name']
@@ -468,92 +461,79 @@ cdef class DataCodecConfig:
468461
name = name[1:]
469462
name = '{}[]'.format(name)
470463

471-
if ti['elem_has_bin_io']:
472-
elem_format = PG_FORMAT_BINARY
473-
else:
474-
elem_format = PG_FORMAT_TEXT
475-
476-
elem_codec = self.get_codec(array_element_oid, elem_format)
464+
elem_codec = self.get_codec(array_element_oid, PG_FORMAT_ANY)
477465
if elem_codec is None:
478-
elem_format = PG_FORMAT_TEXT
479466
elem_codec = self.declare_fallback_codec(
480-
array_element_oid, name, schema)
467+
array_element_oid, ti['elemtype_name'], schema)
481468

482469
elem_delim = <Py_UCS4>ti['elemdelim'][0]
483470

484-
self._derived_type_codecs[oid, elem_format] = \
471+
self._derived_type_codecs[oid, elem_codec.format] = \
485472
Codec.new_array_codec(
486473
oid, name, schema, elem_codec, elem_delim)
487474

488475
elif ti['kind'] == b'c':
476+
# Composite type
477+
489478
if not comp_type_attrs:
490479
raise exceptions.InternalClientError(
491-
'type record missing field types for '
492-
'composite {}'.format(oid))
493-
494-
# Composite type
480+
f'type record missing field types for composite {oid}')
495481

496482
comp_elem_codecs = []
483+
has_text_elements = False
497484

498485
for typoid in comp_type_attrs:
499-
elem_codec = self.get_codec(typoid, PG_FORMAT_BINARY)
500-
if elem_codec is None:
501-
elem_codec = self.get_codec(typoid, PG_FORMAT_TEXT)
502-
has_text_elements = True
486+
elem_codec = self.get_codec(typoid, PG_FORMAT_ANY)
503487
if elem_codec is None:
504488
raise exceptions.InternalClientError(
505-
'no codec for composite attribute type {}'.format(
506-
typoid))
489+
f'no codec for composite attribute type {typoid}')
490+
if elem_codec.format is PG_FORMAT_TEXT:
491+
has_text_elements = True
507492
comp_elem_codecs.append(elem_codec)
508493

509494
element_names = collections.OrderedDict()
510495
for i, attrname in enumerate(ti['attrnames']):
511496
element_names[attrname] = i
512497

498+
# If at least one element is text-encoded, we must
499+
# encode the whole composite as text.
513500
if has_text_elements:
514-
format = PG_FORMAT_TEXT
501+
elem_format = PG_FORMAT_TEXT
502+
else:
503+
elem_format = PG_FORMAT_BINARY
515504

516-
self._derived_type_codecs[oid, format] = \
505+
self._derived_type_codecs[oid, elem_format] = \
517506
Codec.new_composite_codec(
518-
oid, name, schema, format, comp_elem_codecs,
507+
oid, name, schema, elem_format, comp_elem_codecs,
519508
comp_type_attrs, element_names)
520509

521510
elif ti['kind'] == b'd':
522511
# Domain type
523512

524513
if not base_type:
525514
raise exceptions.InternalClientError(
526-
'type record missing base type for domain {}'.format(
527-
oid))
515+
f'type record missing base type for domain {oid}')
528516

529-
elem_codec = self.get_codec(base_type, format)
517+
elem_codec = self.get_codec(base_type, PG_FORMAT_ANY)
530518
if elem_codec is None:
531-
format = PG_FORMAT_TEXT
532519
elem_codec = self.declare_fallback_codec(
533-
base_type, name, schema)
520+
base_type, ti['basetype_name'], schema)
534521

535-
self._derived_type_codecs[oid, format] = elem_codec
522+
self._derived_type_codecs[oid, elem_codec.format] = elem_codec
536523

537524
elif ti['kind'] == b'r':
538525
# Range type
539526

540527
if not range_subtype_oid:
541528
raise exceptions.InternalClientError(
542-
'type record missing base type for range {}'.format(
543-
oid))
529+
f'type record missing base type for range {oid}')
544530

545-
if ti['elem_has_bin_io']:
546-
elem_format = PG_FORMAT_BINARY
547-
else:
548-
elem_format = PG_FORMAT_TEXT
549-
550-
elem_codec = self.get_codec(range_subtype_oid, elem_format)
531+
elem_codec = self.get_codec(range_subtype_oid, PG_FORMAT_ANY)
551532
if elem_codec is None:
552-
elem_format = PG_FORMAT_TEXT
553533
elem_codec = self.declare_fallback_codec(
554-
range_subtype_oid, name, schema)
534+
range_subtype_oid, ti['range_subtype_name'], schema)
555535

556-
self._derived_type_codecs[oid, elem_format] = \
536+
self._derived_type_codecs[oid, elem_codec.format] = \
557537
Codec.new_range_codec(oid, name, schema, elem_codec)
558538

559539
elif ti['kind'] == b'e':
@@ -665,10 +645,6 @@ cdef class DataCodecConfig:
665645
def declare_fallback_codec(self, uint32_t oid, str name, str schema):
666646
cdef Codec codec
667647

668-
codec = self.get_codec(oid, PG_FORMAT_TEXT)
669-
if codec is not None:
670-
return codec
671-
672648
if oid <= MAXBUILTINOID:
673649
# This is a BKI type, for which asyncpg has no
674650
# defined codec. This should only happen for newly
@@ -695,34 +671,49 @@ cdef class DataCodecConfig:
695671
bint ignore_custom_codec=False):
696672
cdef Codec codec
697673

698-
if not ignore_custom_codec:
699-
codec = self.get_any_local_codec(oid)
700-
if codec is not None:
701-
if codec.format != format:
702-
# The codec for this OID has been overridden by
703-
# set_{builtin}_type_codec with a different format.
704-
# We must respect that and not return a core codec.
705-
return None
706-
else:
707-
return codec
708-
709-
codec = get_core_codec(oid, format)
710-
if codec is not None:
674+
if format == PG_FORMAT_ANY:
675+
codec = self.get_codec(
676+
oid, PG_FORMAT_BINARY, ignore_custom_codec)
677+
if codec is None:
678+
codec = self.get_codec(
679+
oid, PG_FORMAT_TEXT, ignore_custom_codec)
711680
return codec
712681
else:
713-
try:
714-
return self._derived_type_codecs[oid, format]
715-
except KeyError:
716-
return None
682+
if not ignore_custom_codec:
683+
codec = self.get_custom_codec(oid, PG_FORMAT_ANY)
684+
if codec is not None:
685+
if codec.format != format:
686+
# The codec for this OID has been overridden by
687+
# set_{builtin}_type_codec with a different format.
688+
# We must respect that and not return a core codec.
689+
return None
690+
else:
691+
return codec
692+
693+
codec = get_core_codec(oid, format)
694+
if codec is not None:
695+
return codec
696+
else:
697+
try:
698+
return self._derived_type_codecs[oid, format]
699+
except KeyError:
700+
return None
717701

718-
cdef inline Codec get_any_local_codec(self, uint32_t oid):
702+
cdef inline Codec get_custom_codec(
703+
self,
704+
uint32_t oid,
705+
ServerDataFormat format
706+
):
719707
cdef Codec codec
720708

721-
codec = self._custom_type_codecs.get((oid, PG_FORMAT_BINARY))
722-
if codec is None:
723-
return self._custom_type_codecs.get((oid, PG_FORMAT_TEXT))
709+
if format == PG_FORMAT_ANY:
710+
codec = self.get_custom_codec(oid, PG_FORMAT_BINARY)
711+
if codec is None:
712+
codec = self.get_custom_codec(oid, PG_FORMAT_TEXT)
724713
else:
725-
return codec
714+
codec = self._custom_type_codecs.get((oid, format))
715+
716+
return codec
726717

727718

728719
cdef inline Codec get_core_codec(

asyncpg/protocol/settings.pyx

+1-10
Original file line numberDiff line numberDiff line change
@@ -89,16 +89,7 @@ cdef class ConnectionSettings(pgproto.CodecContext):
8989
cpdef inline Codec get_data_codec(self, uint32_t oid,
9090
ServerDataFormat format=PG_FORMAT_ANY,
9191
bint ignore_custom_codec=False):
92-
if format == PG_FORMAT_ANY:
93-
codec = self._data_codecs.get_codec(
94-
oid, PG_FORMAT_BINARY, ignore_custom_codec)
95-
if codec is None:
96-
codec = self._data_codecs.get_codec(
97-
oid, PG_FORMAT_TEXT, ignore_custom_codec)
98-
return codec
99-
else:
100-
return self._data_codecs.get_codec(
101-
oid, format, ignore_custom_codec)
92+
return self._data_codecs.get_codec(oid, format, ignore_custom_codec)
10293

10394
def __getattr__(self, name):
10495
if not name.startswith('_'):

tests/test_codecs.py

+37
Original file line numberDiff line numberDiff line change
@@ -1329,6 +1329,34 @@ async def test_custom_codec_on_enum(self):
13291329
finally:
13301330
await self.con.execute('DROP TYPE custom_codec_t')
13311331

1332+
async def test_custom_codec_on_enum_array(self):
1333+
"""Test encoding/decoding using a custom codec on an enum array.
1334+
1335+
Bug: https://github.com/MagicStack/asyncpg/issues/590
1336+
"""
1337+
await self.con.execute('''
1338+
CREATE TYPE custom_codec_t AS ENUM ('foo', 'bar', 'baz')
1339+
''')
1340+
1341+
try:
1342+
await self.con.set_type_codec(
1343+
'custom_codec_t',
1344+
encoder=lambda v: str(v).lstrip('enum :'),
1345+
decoder=lambda v: 'enum: ' + str(v))
1346+
1347+
v = await self.con.fetchval(
1348+
"SELECT ARRAY['foo', 'bar']::custom_codec_t[]")
1349+
self.assertEqual(v, ['enum: foo', 'enum: bar'])
1350+
1351+
v = await self.con.fetchval(
1352+
'SELECT ARRAY[$1]::custom_codec_t[]', 'foo')
1353+
self.assertEqual(v, ['enum: foo'])
1354+
1355+
v = await self.con.fetchval("SELECT 'foo'::custom_codec_t")
1356+
self.assertEqual(v, 'enum: foo')
1357+
finally:
1358+
await self.con.execute('DROP TYPE custom_codec_t')
1359+
13321360
async def test_custom_codec_override_binary(self):
13331361
"""Test overriding core codecs."""
13341362
import json
@@ -1374,6 +1402,14 @@ def _decoder(value):
13741402
res = await conn.fetchval('SELECT $1::json', data)
13751403
self.assertEqual(data, res)
13761404

1405+
res = await conn.fetchval('SELECT $1::json[]', [data])
1406+
self.assertEqual([data], res)
1407+
1408+
await conn.execute('CREATE DOMAIN my_json AS json')
1409+
1410+
res = await conn.fetchval('SELECT $1::my_json', data)
1411+
self.assertEqual(data, res)
1412+
13771413
def _encoder(value):
13781414
return value
13791415

@@ -1389,6 +1425,7 @@ def _decoder(value):
13891425
res = await conn.fetchval('SELECT $1::uuid', data)
13901426
self.assertEqual(res, data)
13911427
finally:
1428+
await conn.execute('DROP DOMAIN IF EXISTS my_json')
13921429
await conn.close()
13931430

13941431
async def test_custom_codec_override_tuple(self):

0 commit comments

Comments
 (0)