Skip to content

Commit 7b8a7c8

Browse files
authored
feat(python): Implement extension type/mechanism in python package (#688)
This PR implements the first canonical extension type for nanoarrow/Python: `nanoarrow.bool8()`. In doing so it also implements some machinery for "extensions", which is intended to be internal and to evolve with the requirements of some follow-up canonical extension types.

The extension points are:

- Parsing the metadata when `Schema.extension` is accessed to get parameter access (and validate the metadata)
- Converting an extension array to Python objects (i.e., `array.to_pylist()`)
- Converting an extension array to a Python sequence (i.e., `array.to_pysequence()`)
- Constructing an extension array from Python objects (i.e., `na.Array([True, False, None], na.bool8())`)
- Constructing an extension array from a Python buffer (i.e., `na.Array(np.array([True, False, False]), na.bool8())`)

I am not sure the extension point implementations via the `Extension` methods are the final way this should be done, but this PR at least connects the wires. The tensor extensions are more complex and will require some modification of this, but I'd like to do that in another PR.

```python
import nanoarrow as na
import numpy as np

na.Array([True, False, None], na.bool8())
#> nanoarrow.Array<arrow.bool8{int8}>[3]
#> True
#> False
#> None

na.Array(np.array([True, False, True]), na.bool8())
#> nanoarrow.Array<arrow.bool8{int8}>[3]
#> True
#> False
#> True

na.Array([True, False, None], na.bool8()).to_pylist()
#> [True, False, None]

np.array(na.Array([True, False, True], na.bool8()).to_pysequence())
#> array([ True, False, True])
```
1 parent e54b7df commit 7b8a7c8

13 files changed

+470
-12
lines changed

python/src/nanoarrow/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
from nanoarrow.c_array_stream import c_array_stream
3030
from nanoarrow.c_schema import c_schema
3131
from nanoarrow.c_buffer import c_buffer
32+
from nanoarrow.extension_canonical import bool8
3233
from nanoarrow.schema import (
3334
Schema,
3435
Type,
@@ -87,6 +88,7 @@
8788
"binary",
8889
"binary_view",
8990
"bool_",
91+
"bool8",
9092
"c_array",
9193
"c_array_from_buffers",
9294
"c_array_stream",

python/src/nanoarrow/_schema.pyx

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -555,6 +555,8 @@ cdef class CSchemaView:
555555
(_types.TIMESTAMP, _types.DATE64, _types.DURATION)
556556
):
557557
return 'q'
558+
elif self.extension_name:
559+
return self._get_buffer_format()
558560
else:
559561
return None
560562

@@ -564,7 +566,13 @@ cdef class CSchemaView:
564566
or None if there is no Python format string that can represent this
565567
type without losing information.
566568
"""
567-
if self.extension_name or self._schema_view.type != self._schema_view.storage_type:
569+
if self.extension_name:
570+
return None
571+
else:
572+
return self._get_buffer_format()
573+
574+
def _get_buffer_format(self):
575+
if self._schema_view.type != self._schema_view.storage_type:
568576
return None
569577

570578
# String/binary types do not have format strings as far as the Python

python/src/nanoarrow/c_array.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from nanoarrow._utils import obj_is_buffer, obj_is_capsule
2525
from nanoarrow.c_buffer import c_buffer
2626
from nanoarrow.c_schema import c_schema, c_schema_view
27+
from nanoarrow.extension import resolve_extension
2728

2829
from nanoarrow import _types
2930

@@ -416,6 +417,16 @@ def infer_schema(cls, obj) -> Tuple[CBuffer, CSchema]:
416417
def __init__(self, schema):
417418
super().__init__(schema)
418419

420+
ext = resolve_extension(self._schema_view)
421+
self._append_ext = None
422+
if ext is not None:
423+
self._append_ext = ext.get_buffer_appender(self._schema, self)
424+
elif self._schema_view.extension_name:
425+
raise NotImplementedError(
426+
"Can't create array for unregistered extension "
427+
f"'{self._schema_view.extension_name}'"
428+
)
429+
419430
if self._schema_view.storage_buffer_format is None:
420431
raise ValueError(
421432
f"Can't build array of type {self._schema_view.type} from PyBuffer"
@@ -428,6 +439,12 @@ def append(self, obj: Any) -> None:
428439
if not isinstance(obj, CBuffer):
429440
obj = CBuffer.from_pybuffer(obj)
430441

442+
if self._append_ext is not None:
443+
return self._append_ext(obj)
444+
445+
return self._append_impl(obj)
446+
447+
def _append_impl(self, obj):
431448
if (
432449
self._schema_view.buffer_format in ("b", "c")
433450
and obj.format not in ("b", "c")
@@ -462,6 +479,18 @@ def __init__(self, schema):
462479

463480
# Resolve the method name we are going to use to do the building from
464481
# the provided schema.
482+
ext = resolve_extension(self._schema_view)
483+
if ext is not None:
484+
maybe_appender = ext.get_iterable_appender(self._schema, self)
485+
if maybe_appender:
486+
self._append_impl = maybe_appender
487+
return
488+
elif self._schema_view.extension_name:
489+
raise NotImplementedError(
490+
f"Can't create array for unregistered extension "
491+
f"'{self._schema_view.extension_name}'"
492+
)
493+
465494
type_id = self._schema_view.type_id
466495
if type_id not in _ARRAY_BUILDER_FROM_ITERABLE_METHOD:
467496
raise ValueError(

python/src/nanoarrow/extension.py

Lines changed: 177 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,177 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
18+
from typing import Any, Callable, Iterable, Iterator, Mapping, Optional, Type
19+
20+
from nanoarrow.c_schema import CSchema, CSchemaView, c_schema_view
21+
22+
23+
class Extension:
    """Define a nanoarrow Extension

    A nanoarrow extension customizes behaviour of built-in operations
    applicable to a specific type. This is currently implemented only
    for Arrow extension types but could in theory apply if one wanted
    to customize the conversion behaviour of a specific non-extension
    type.

    This is currently internal and involves knowledge of other internal
    nanoarrow/Python structures. It is currently used only to implement
    canonical extensions with the anticipation of evolving to support
    user-defined extensions as the internal APIs on which it relies
    stabilize.

    With the current design, an Extension subclass must be constructible
    with no parameters (e.g., ``Extension()``).
    """

    def get_schema(self) -> CSchema:
        """Get the schema for which this extension applies.

        This is used by :func:`register_extension` to ensure that it can be
        resolved when needed. The base implementation always raises
        ``NotImplementedError``; subclasses must override it.
        """
        raise NotImplementedError()

    def get_params(self, c_schema: CSchema) -> Mapping[str, Any]:
        """Compute a dictionary of type parameters.

        These parameters are accessible via the :class:`Schema`
        ``extension`` attribute (e.g., ``schema.extension.param_name``).
        Internal parameters can also be returned but should be prefixed with
        an underscore.

        This method should also error if the storage type or any other property
        of the schema is not valid. The base implementation performs no
        validation and returns an empty mapping.
        """
        return {}

    def get_pyiter(
        self,
        py_iterator,
        offset: int,
        length: int,
    ) -> Optional[Iterator[Any]]:
        """Compute an iterable of Python objects.

        Used by ``to_pylist()`` to generate scalars for a particular type.
        If ``None`` is returned, the behaviour of the storage type will be
        used without warning.

        This method is currently passed the underlying :class:`PyIterator`
        and returns an iterator; however, it could in the future be passed
        a :class:`CSchema` and return a PyIterator class once that class
        structure is stabilized.

        The base implementation always raises ``NotImplementedError`` naming
        the extension read from ``py_iterator._schema_view``.
        """
        name = py_iterator._schema_view.extension_name
        raise NotImplementedError(f"Extension get_pyiter() for {name}")

    def get_sequence_converter(self, c_schema: CSchema):
        """Return an ArrayViewVisitor subclass used to compute a sequence from
        a stream of arrays.

        This is currently implemented outside the null handler and may need a flag
        at some point to indicate that it did or did not handle its own nulls.

        The base implementation always raises ``NotImplementedError``.
        """
        schema_view = c_schema_view(c_schema)
        name = schema_view.extension_name
        raise NotImplementedError(f"Extension get_sequence_converter() for {name}")

    def get_buffer_appender(
        self, c_schema: CSchema, array_builder
    ) -> Optional[Callable[[Any], None]]:
        """Compute a function that prepares a :class:`CArrayBuilder` from a
        buffer.

        This is used to customize the behavior of creating a CArray from an
        object implementing the Python buffer protocol. If ``None`` is
        returned, the storage will be converted without a warning.

        This method is currently passed a :class:`CArrayBuilder` but in
        the future should perhaps be passed a :class:`CSchema` and return a
        CArrayBuilder class.

        The base implementation always raises ``NotImplementedError``.
        """
        schema_view = c_schema_view(c_schema)
        name = schema_view.extension_name
        raise NotImplementedError(f"Extension get_buffer_appender() for {name}")

    def get_iterable_appender(
        self, c_schema: CSchema, array_builder
    ) -> Optional[Callable[[Iterable], None]]:
        """Compute a function that prepares a :class:`CArrayBuilder` from an
        iterable.

        This is used to customize the behavior of creating a CArray from an
        iterable of Python objects. If ``None`` is returned, the builder for
        the storage type is used instead.

        This method is currently passed a :class:`CArrayBuilder` but in
        the future should perhaps be passed a :class:`CSchema` and return a
        CArrayBuilder class.

        The base implementation always raises ``NotImplementedError``.
        """
        schema_view = c_schema_view(c_schema)
        name = schema_view.extension_name
        raise NotImplementedError(f"Extension get_iterable_appender() for {name}")
128+
129+
130+
_global_extension_registry = {}
131+
132+
133+
def resolve_extension(c_schema_view: CSchemaView) -> Optional[Extension]:
    """Look up the extension registered for a :class:`CSchemaView`

    Reads ``c_schema_view.extension_name`` and returns the instance stored
    under that name in the module-level registry, or ``None`` when nothing
    is registered for it.
    """
    # A single lookup: a missing key yields None rather than raising.
    return _global_extension_registry.get(c_schema_view.extension_name)
144+
145+
146+
def register_extension(extension: Extension) -> Optional[Extension]:
    """Register an :class:`Extension` instance in the global registry.

    Inserts an extension into the global registry, keyed by the extension
    name of ``extension.get_schema()``, and returns the previously registered
    extension for that name if one exists (or ``None`` otherwise).

    Raises ``ValueError`` if the schema returned by ``get_schema()`` carries
    no extension name.
    """
    schema_view = c_schema_view(extension.get_schema())
    key = schema_view.extension_name

    # The registry is keyed purely by extension name. Storing an entry under
    # a missing/empty name would make resolve_extension() return this
    # extension for every schema that has no extension name at all.
    if not key:
        raise ValueError(
            "Can't register an extension whose schema has no extension name"
        )

    prev = resolve_extension(schema_view)
    _global_extension_registry[key] = extension
    return prev
160+
161+
162+
def unregister_extension(extension_name: str):
    """Drop the extension registered under ``extension_name``.

    Returns the extension instance that was removed from the global
    registry. Raises ``KeyError`` when no extension is registered under
    that name.
    """
    # pop() without a default both removes the entry and raises KeyError
    # for an unknown name.
    return _global_extension_registry.pop(extension_name)
171+
172+
173+
def register(extension_cls: Type[Extension]):
    """Class decorator that registers an extension.

    Instantiates ``extension_cls`` with no arguments, adds that instance to
    the global registry via :func:`register_extension`, and returns the
    class itself so the decorated name stays bound to the class.
    """
    instance = extension_cls()
    register_extension(instance)
    return extension_cls
Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
18+
from typing import Any, Iterator, Mapping, Optional
19+
20+
from nanoarrow.c_buffer import CBufferBuilder
21+
from nanoarrow.c_schema import CSchema, c_schema_view
22+
from nanoarrow.schema import extension_type, int8
23+
from nanoarrow.visitor import ToPyBufferConverter
24+
25+
from nanoarrow import extension
26+
27+
28+
def bool8(nullable: bool = True):
    """Create a type representing a boolean encoded as one byte per value

    The result is an extension type named ``"arrow.bool8"`` whose storage
    type is ``int8()``.

    Parameters
    ----------
    nullable : bool, optional
        Use ``False`` to mark this field as non-nullable.
    """
    storage = int8()
    return extension_type(storage, "arrow.bool8", nullable=nullable)
38+
39+
40+
class Bool8SequenceConverter(ToPyBufferConverter):
    """Buffer converter for ``arrow.bool8`` that emits a ``"?"``-format buffer.

    Only the builder factory is overridden; everything else comes from
    :class:`ToPyBufferConverter`.
    """

    def _make_builder(self):
        # "?" is the Python struct/buffer-protocol format code for bool.
        builder = CBufferBuilder()
        return builder.set_format("?")
43+
44+
45+
@extension.register
class Bool8Extension(extension.Extension):
    """Extension implementation for the ``arrow.bool8`` extension type.

    Registered in the global extension registry at import time by the
    ``@extension.register`` decorator.
    """

    def get_schema(self) -> CSchema:
        """Return the schema this extension is registered for (``bool8()``)."""
        return bool8()

    def get_params(self, c_schema: CSchema) -> Mapping[str, Any]:
        """Validate the schema and return its (empty) parameter mapping.

        Raises ``ValueError`` when the schema view's type is not ``"int8"``.
        """
        storage_type = c_schema_view(c_schema).type
        if storage_type != "int8":
            raise ValueError("arrow.bool8 must have storage type int8")

        return {}

    def get_pyiter(
        self,
        py_iterator,
        offset: int,
        length: int,
    ) -> Optional[Iterator[Optional[bool]]]:
        """Iterate a slice of the current array as Python ``bool`` values.

        Elements of buffer 1 in ``[offset, offset + length)`` are mapped
        through ``bool``. When the iterator reports nulls, the elements of
        buffer 0 are passed together with the values to
        ``py_iterator._wrap_iter_nullable``.
        """
        array_view = py_iterator._array_view
        values = map(bool, array_view.buffer(1).elements(offset, length))

        if not py_iterator._contains_nulls():
            return values

        validity = array_view.buffer(0).elements(offset, length)
        return py_iterator._wrap_iter_nullable(validity, values)

    def get_sequence_converter(self, c_schema: CSchema):
        """Validate ``c_schema`` and return :class:`Bool8SequenceConverter`."""
        self.get_params(c_schema)
        return Bool8SequenceConverter

    # NOTE(review): get_sequence_appender is not declared on the base
    # Extension class shown alongside this file; kept for compatibility.
    def get_sequence_appender(self, c_schema: CSchema, array_builder):
        """Validate ``c_schema``; returns ``None``."""
        self.get_params(c_schema)
        return None

    def get_buffer_appender(self, c_schema: CSchema, array_builder):
        """Validate ``c_schema``; ``None`` means no custom buffer appender."""
        self.get_params(c_schema)
        return None

    def get_iterable_appender(self, c_schema: CSchema, array_builder):
        """Validate ``c_schema``; ``None`` means no custom iterable appender."""
        self.get_params(c_schema)
        return None

python/src/nanoarrow/iterator.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from nanoarrow._array import CArrayView
2424
from nanoarrow.c_array_stream import c_array_stream
2525
from nanoarrow.c_schema import c_schema, c_schema_view
26+
from nanoarrow.extension import resolve_extension
2627
from nanoarrow.schema import Schema
2728

2829
from nanoarrow import _types
@@ -183,6 +184,8 @@ def get_iterator(cls, obj, schema=None):
183184

184185
def __init__(self, schema, *, array_view=None):
185186
super().__init__(schema, array_view=array_view)
187+
self._ext = resolve_extension(self._schema_view)
188+
self._ext_params = self._ext.get_params(schema) if self._ext else None
186189

187190
self._children = list(
188191
map(self._make_child, self._schema.children, self._array_view.children)
@@ -208,11 +211,15 @@ def __iter__(self):
208211

209212
def _iter_chunk(self, offset, length):
210213
"""Iterate over all elements in a slice of the current chunk"""
214+
211215
# Check for an extension type first since this isn't reflected by
212-
# self._schema_view.type_id. Currently we just return the storage
213-
# iterator with a warning for extension types.
216+
# self._schema_view.type_id.
214217
maybe_extension_name = self._schema_view.extension_name
215-
if maybe_extension_name:
218+
if self._ext:
219+
maybe_iter = self._ext.get_pyiter(self, offset, length)
220+
if maybe_iter:
221+
return maybe_iter
222+
elif maybe_extension_name:
216223
self._warn(
217224
f"Converting unregistered extension '{maybe_extension_name}' "
218225
"as storage type",

0 commit comments

Comments
 (0)