From 7b8a7c85ad7456c5e6181bdae1920ddace8f9044 Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Fri, 6 Dec 2024 21:51:35 +0000 Subject: [PATCH] feat(python): Implement extension type/mechanism in python package (#688) This PR implements the first canonical extension type for nanoarrow/Python: `nanoarrow.bool8()`. In doing so it also implements some machinery for "extensions", which is intended to be internal and to evolve with the requirements of some follow-up canonical extension types. The extension points are: - Parsing the metadata when `Schema.extension` is accessed to get parameter access (and validate the metadata) - Converting an extension array to Python objects (i.e., `array.to_pylist()`) - Converting an extension array to a Python sequence (i.e., `array.to_pysequence()`) - Constructing an extension array from Python objects (e.g., `na.Array([True, False, None], na.bool8())`) - Constructing an extension array from a Python buffer (e.g., `na.Array(np.array([True, False, False]), na.bool8())`) I am not sure the extension point implementations via the `Extension` methods are the final way this should be done, but this PR at least connects the wires. The tensor extensions are more complex and will require some modification of this, but I'd like to do that in another PR. 
```python import nanoarrow as na import numpy as np na.Array([True, False, None], na.bool8()) #> nanoarrow.Array[3] #> True #> False #> None na.Array(np.array([True, False, True]), na.bool8()) #> nanoarrow.Array[3] #> True #> False #> True na.Array([True, False, None], na.bool8()).to_pylist() #> [True, False, None] np.array(na.Array([True, False, True], na.bool8()).to_pysequence()) #> array([ True, False, True]) ``` --- python/src/nanoarrow/__init__.py | 2 + python/src/nanoarrow/_schema.pyx | 10 +- python/src/nanoarrow/c_array.py | 29 ++++ python/src/nanoarrow/extension.py | 177 ++++++++++++++++++++ python/src/nanoarrow/extension_canonical.py | 86 ++++++++++ python/src/nanoarrow/iterator.py | 13 +- python/src/nanoarrow/meson.build | 2 + python/src/nanoarrow/schema.py | 15 +- python/src/nanoarrow/visitor.py | 21 ++- python/tests/test_extension.py | 64 +++++++ python/tests/test_extension_canonical.py | 46 +++++ python/tests/test_schema.py | 2 +- python/tests/test_visitor.py | 15 ++ 13 files changed, 470 insertions(+), 12 deletions(-) create mode 100644 python/src/nanoarrow/extension.py create mode 100644 python/src/nanoarrow/extension_canonical.py create mode 100644 python/tests/test_extension.py create mode 100644 python/tests/test_extension_canonical.py diff --git a/python/src/nanoarrow/__init__.py b/python/src/nanoarrow/__init__.py index 7f67dd304..62221eeb3 100644 --- a/python/src/nanoarrow/__init__.py +++ b/python/src/nanoarrow/__init__.py @@ -29,6 +29,7 @@ from nanoarrow.c_array_stream import c_array_stream from nanoarrow.c_schema import c_schema from nanoarrow.c_buffer import c_buffer +from nanoarrow.extension_canonical import bool8 from nanoarrow.schema import ( Schema, Type, @@ -87,6 +88,7 @@ "binary", "binary_view", "bool_", + "bool8", "c_array", "c_array_from_buffers", "c_array_stream", diff --git a/python/src/nanoarrow/_schema.pyx b/python/src/nanoarrow/_schema.pyx index 3e82c0659..c1deaec72 100644 --- a/python/src/nanoarrow/_schema.pyx +++ 
b/python/src/nanoarrow/_schema.pyx @@ -555,6 +555,8 @@ cdef class CSchemaView: (_types.TIMESTAMP, _types.DATE64, _types.DURATION) ): return 'q' + elif self.extension_name: + return self._get_buffer_format() else: return None @@ -564,7 +566,13 @@ cdef class CSchemaView: or None if there is no Python format string that can represent this type without loosing information. """ - if self.extension_name or self._schema_view.type != self._schema_view.storage_type: + if self.extension_name: + return None + else: + return self._get_buffer_format() + + def _get_buffer_format(self): + if self._schema_view.type != self._schema_view.storage_type: return None # String/binary types do not have format strings as far as the Python diff --git a/python/src/nanoarrow/c_array.py b/python/src/nanoarrow/c_array.py index 0c71bda45..57390b7c1 100644 --- a/python/src/nanoarrow/c_array.py +++ b/python/src/nanoarrow/c_array.py @@ -24,6 +24,7 @@ from nanoarrow._utils import obj_is_buffer, obj_is_capsule from nanoarrow.c_buffer import c_buffer from nanoarrow.c_schema import c_schema, c_schema_view +from nanoarrow.extension import resolve_extension from nanoarrow import _types @@ -416,6 +417,16 @@ def infer_schema(cls, obj) -> Tuple[CBuffer, CSchema]: def __init__(self, schema): super().__init__(schema) + ext = resolve_extension(self._schema_view) + self._append_ext = None + if ext is not None: + self._append_ext = ext.get_buffer_appender(self._schema, self) + elif self._schema_view.extension_name: + raise NotImplementedError( + "Can't create array for unregistered extension " + f"'{self._schema_view.extension_name}'" + ) + if self._schema_view.storage_buffer_format is None: raise ValueError( f"Can't build array of type {self._schema_view.type} from PyBuffer" @@ -428,6 +439,12 @@ def append(self, obj: Any) -> None: if not isinstance(obj, CBuffer): obj = CBuffer.from_pybuffer(obj) + if self._append_ext is not None: + return self._append_ext(obj) + + return self._append_impl(obj) + + def 
_append_impl(self, obj): if ( self._schema_view.buffer_format in ("b", "c") and obj.format not in ("b", "c") @@ -462,6 +479,18 @@ def __init__(self, schema): # Resolve the method name we are going to use to do the building from # the provided schema. + ext = resolve_extension(self._schema_view) + if ext is not None: + maybe_appender = ext.get_iterable_appender(self._schema, self) + if maybe_appender: + self._append_impl = maybe_appender + return + elif self._schema_view.extension_name: + raise NotImplementedError( + f"Can't create array for unregistered extension " + f"'{self._schema_view.extension_name}'" + ) + type_id = self._schema_view.type_id if type_id not in _ARRAY_BUILDER_FROM_ITERABLE_METHOD: raise ValueError( diff --git a/python/src/nanoarrow/extension.py b/python/src/nanoarrow/extension.py new file mode 100644 index 000000000..3ecb2fbfa --- /dev/null +++ b/python/src/nanoarrow/extension.py @@ -0,0 +1,177 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +from typing import Any, Callable, Iterable, Iterator, Mapping, Optional, Type + +from nanoarrow.c_schema import CSchema, CSchemaView, c_schema_view + + +class Extension: + """Define a nanoarrow Extension + + A nanoarrow extension customizes behaviour of built-in operations + applicable to a specific type. This is currently implemented only + for Arrow extension types but could in theory apply if one wanted + to customize the conversion behaviour of a specific non-extension + type. + + This is currently internal and involves knowledge of other internal + nanoarrow/Python structures. It is currently used only to implement + canonical extensions with the anticipation of evolving to support + user-defined extensions as the internal APIs on which it relies + stabilize. + + With the current design, an Extension subclass must be constructible + with no parameters (e.g., ``Extension()``). + """ + + def get_schema(self) -> CSchema: + """Get the schema for which this extension applies. + + This is used by :func:`register_extension` to ensure that it can be resolved + when needed. + """ + raise NotImplementedError() + + def get_params(self, c_schema: CSchema) -> Mapping[str, Any]: + """Compute a dictionary of type parameters. + + These parameters are accessible via the :class:`Schema` + ``extension`` attribute (e.g., ``schema.extension.param_name``). + Internal parameters can also be returned but should be prefixed with + an underscore. + + This method should also error if the storage type or any other property + of the schema is not valid. + """ + return {} + + def get_pyiter( + self, + py_iterator, + offset: int, + length: int, + ) -> Optional[Iterator[Optional[bool]]]: + """Compute an iterable of Python objects. + + Used by ``to_pylist()`` to generate scalars for a particular type. + If ``None`` is returned, the behaviour of the storage type will be + used without warning. 
+ + This method is currently passed the underlying :class:`PyIterator` + and returns an iterator; however, it could in the future be passed + a :class:`CSchema` and return a PyIterator class once that class + structure is stabilized. + """ + name = py_iterator._schema_view.extension_name + raise NotImplementedError(f"Extension get_pyiter() for {name}") + + def get_sequence_converter(self, c_schema: CSchema): + """Return an ArrayViewVisitor subclass used to compute a sequence from + a stream of arrays. + + This is currently implemented outside the null handler and may need a flag + at some point to indicate that it did or did not handle its own nulls. + """ + schema_view = c_schema_view(c_schema) + name = schema_view.extension_name + raise NotImplementedError(f"Extension get_sequence_converter() for {name}") + + def get_buffer_appender( + self, c_schema: CSchema, array_builder + ) -> Optional[Callable[[Any], None]]: + """Compute a function that prepares a :class:`CArrayBuilder` from a + buffer. + + This is used to customize the behavior of creating a CArray from an + object implementing the Python buffer protocol. If ``None`` is + returned, the storage will be converted without a warning. + + This method is currently passed a :class:`CArrayBuilder` but in + the future should perhaps be passed a :class:`CSchema` and return a + CArrayBuilder class. + """ + schema_view = c_schema_view(c_schema) + name = schema_view.extension_name + raise NotImplementedError(f"Extension get_buffer_appender() for {name}") + + def get_iterable_appender( + self, c_schema: CSchema, array_builder + ) -> Optional[Callable[[Iterable], None]]: + """Compute a function that prepares a :class:`CArrayBuilder` from a + buffer. + + This is used to customize the behavior of creating a CArray from an + iterable of Python objects. + + This method is currently passed a :class:`CArrayBuilder` but in + the future should perhaps be passed a :class:`CSchema` and return a + CArrayBuilder class. 
+ """ + schema_view = c_schema_view(c_schema) + name = schema_view.extension_name + raise NotImplementedError(f"Extension get_iterable_appender() for {name}") + + +_global_extension_registry = {} + + +def resolve_extension(c_schema_view: CSchemaView) -> Optional[Extension]: + """Resolve an extension instance from a :class:`CSchemaView` + + Returns the registered extension instance if one applies to the passed + type or ``None`` otherwise. + """ + extension_name = c_schema_view.extension_name + if extension_name in _global_extension_registry: + return _global_extension_registry[extension_name] + + return None + + +def register_extension(extension: Extension) -> Optional[Extension]: + """Register an :class:`Extension` instance in the global registry. + + Inserts an extension into the global registry, returning the + previously registered extension for that type if one exists + (or ``None`` otherwise). + """ + global _global_extension_registry + + schema_view = c_schema_view(extension.get_schema()) + key = schema_view.extension_name + prev = resolve_extension(schema_view) + _global_extension_registry[key] = extension + return prev + + +def unregister_extension(extension_name: str): + """Remove an extension from the global registry by extension name. + + Returns the removed extension. Raises ``KeyError`` if there was no + extension registered for this extension name. 
+ """ + prev = _global_extension_registry[extension_name] + del _global_extension_registry[extension_name] + return prev + + +def register(extension_cls: Type[Extension]): + """Decorator that registers an extension class by instantiating it + and adding it to the global registry.""" + register_extension(extension_cls()) + return extension_cls diff --git a/python/src/nanoarrow/extension_canonical.py b/python/src/nanoarrow/extension_canonical.py new file mode 100644 index 000000000..2e3dc6d43 --- /dev/null +++ b/python/src/nanoarrow/extension_canonical.py @@ -0,0 +1,86 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from typing import Any, Iterator, Mapping, Optional + +from nanoarrow.c_buffer import CBufferBuilder +from nanoarrow.c_schema import CSchema, c_schema_view +from nanoarrow.schema import extension_type, int8 +from nanoarrow.visitor import ToPyBufferConverter + +from nanoarrow import extension + + +def bool8(nullable: bool = True): + """Create a type representing a boolean encoded as one byte per value + + Parameters + ---------- + nullable : bool, optional + Use ``False`` to mark this field as non-nullable. 
+ """ + + return extension_type(int8(), "arrow.bool8", nullable=nullable) + + +class Bool8SequenceConverter(ToPyBufferConverter): + def _make_builder(self): + return CBufferBuilder().set_format("?") + + +@extension.register +class Bool8Extension(extension.Extension): + def get_schema(self) -> CSchema: + return bool8() + + def get_params(self, c_schema: CSchema) -> Mapping[str, Any]: + schema_view = c_schema_view(c_schema) + if schema_view.type != "int8": + raise ValueError("arrow.bool8 must have storage type int8") + + return {} + + def get_pyiter( + self, + py_iterator, + offset: int, + length: int, + ) -> Optional[Iterator[Optional[bool]]]: + view = py_iterator._array_view + items = map(bool, view.buffer(1).elements(offset, length)) + + if py_iterator._contains_nulls(): + validity = view.buffer(0).elements(offset, length) + return py_iterator._wrap_iter_nullable(validity, items) + else: + return items + + def get_sequence_converter(self, c_schema: CSchema): + self.get_params(c_schema) + return Bool8SequenceConverter + + def get_sequence_appender(self, c_schema: CSchema, array_builder): + self.get_params(c_schema) + return None + + def get_buffer_appender(self, c_schema: CSchema, array_builder): + self.get_params(c_schema) + return None + + def get_iterable_appender(self, c_schema: CSchema, array_builder): + self.get_params(c_schema) + return None diff --git a/python/src/nanoarrow/iterator.py b/python/src/nanoarrow/iterator.py index fc6e1428f..f2e616155 100644 --- a/python/src/nanoarrow/iterator.py +++ b/python/src/nanoarrow/iterator.py @@ -23,6 +23,7 @@ from nanoarrow._array import CArrayView from nanoarrow.c_array_stream import c_array_stream from nanoarrow.c_schema import c_schema, c_schema_view +from nanoarrow.extension import resolve_extension from nanoarrow.schema import Schema from nanoarrow import _types @@ -183,6 +184,8 @@ def get_iterator(cls, obj, schema=None): def __init__(self, schema, *, array_view=None): super().__init__(schema, 
array_view=array_view) + self._ext = resolve_extension(self._schema_view) + self._ext_params = self._ext.get_params(schema) if self._ext else None self._children = list( map(self._make_child, self._schema.children, self._array_view.children) @@ -208,11 +211,15 @@ def __iter__(self): def _iter_chunk(self, offset, length): """Iterate over all elements in a slice of the current chunk""" + # Check for an extension type first since this isn't reflected by - # self._schema_view.type_id. Currently we just return the storage - # iterator with a warning for extension types. + # self._schema_view.type_id. maybe_extension_name = self._schema_view.extension_name - if maybe_extension_name: + if self._ext: + maybe_iter = self._ext.get_pyiter(self, offset, length) + if maybe_iter: + return maybe_iter + elif maybe_extension_name: self._warn( f"Converting unregistered extension '{maybe_extension_name}' " "as storage type", diff --git a/python/src/nanoarrow/meson.build b/python/src/nanoarrow/meson.build index ee52e9c74..9b8c223b5 100644 --- a/python/src/nanoarrow/meson.build +++ b/python/src/nanoarrow/meson.build @@ -101,6 +101,8 @@ py_sources = [ 'c_buffer.py', 'c_schema.py', 'device.py', + 'extension.py', + 'extension_canonical.py', 'ipc.py', 'iterator.py', '_repr_utils.py', diff --git a/python/src/nanoarrow/schema.py b/python/src/nanoarrow/schema.py index 67412e994..99616eeeb 100644 --- a/python/src/nanoarrow/schema.py +++ b/python/src/nanoarrow/schema.py @@ -27,6 +27,7 @@ SchemaMetadata, ) from nanoarrow.c_schema import c_schema +from nanoarrow.extension import resolve_extension from nanoarrow import _repr_utils, _types @@ -124,6 +125,11 @@ class ExtensionAccessor: def __init__(self, schema) -> None: self._schema = schema + self._ext = resolve_extension(self._schema._c_schema_view) + self._params = self._ext.get_params(self._schema) if self._ext else {} + + def __dir__(self) -> List[str]: + return ["name", "metadata", "storage"] + list(self._params.keys()) @property def 
name(self) -> str: @@ -131,10 +137,10 @@ def name(self) -> str: return self._schema._c_schema_view.extension_name @property - def metadata(self) -> Union[bytes, None]: + def metadata(self) -> bytes: """Extension metadata for this extension type if present""" extension_metadata = self._schema._c_schema_view.extension_metadata - return extension_metadata if extension_metadata else None + return extension_metadata if extension_metadata else b"" @property def storage(self): @@ -148,6 +154,9 @@ def storage(self): return Schema(self._schema, metadata=metadata) + def __getattr__(self, key: str): + return self._params[key] + class Schema: """Create a nanoarrow Schema @@ -1305,6 +1314,8 @@ def extension_type( metadata["ARROW:extension:name"] = extension_name if extension_metadata: metadata["ARROW:extension:metadata"] = extension_metadata + else: + metadata["ARROW:extension:metadata"] = "" return Schema(storage_schema, nullable=nullable, metadata=metadata) diff --git a/python/src/nanoarrow/visitor.py b/python/src/nanoarrow/visitor.py index d1f95e473..2e11e22c0 100644 --- a/python/src/nanoarrow/visitor.py +++ b/python/src/nanoarrow/visitor.py @@ -21,6 +21,7 @@ from nanoarrow._buffer import CBuffer, CBufferBuilder from nanoarrow.c_array_stream import c_array_stream from nanoarrow.c_schema import c_schema_view +from nanoarrow.extension import resolve_extension from nanoarrow.iterator import ArrayViewBaseIterator, PyIterator from nanoarrow.schema import Type @@ -333,8 +334,7 @@ def finish(self) -> List: class ToPyBufferConverter(ArrayViewVisitor): def begin(self, total_elements: Union[int, None]): - self._builder = CBufferBuilder() - self._builder.set_format(self._schema_view.buffer_format) + self._builder = self._make_builder() if total_elements is not None: element_size_bits = self._schema_view.layout.element_size_bits[1] @@ -353,6 +353,9 @@ def visit_chunk_view(self, array_view: CArrayView) -> None: def finish(self) -> Any: return self._builder.finish() + def 
_make_builder(self): + return CBufferBuilder().set_format(self._schema_view.buffer_format) + class ToBooleanBufferConverter(ArrayViewVisitor): def begin(self, total_elements: Union[int, None]): @@ -428,9 +431,16 @@ def finish(self) -> Any: def _resolve_converter_cls(schema, handle_nulls=None): schema_view = c_schema_view(schema) + ext = resolve_extension(schema_view) + ext_converter_cls = ext.get_sequence_converter(schema) if ext else None if schema_view.nullable: - if schema_view.type_id == _types.BOOL: + if ext_converter_cls: + return ToNullableSequenceConverter, { + "converter_cls": ext_converter_cls, + "handle_nulls": handle_nulls, + } + elif schema_view.type_id == _types.BOOL: return ToNullableSequenceConverter, { "converter_cls": ToBooleanBufferConverter, "handle_nulls": handle_nulls, @@ -443,8 +453,9 @@ def _resolve_converter_cls(schema, handle_nulls=None): else: return ToPyListConverter, {} else: - - if schema_view.type_id == _types.BOOL: + if ext_converter_cls: + return ext_converter_cls, {} + elif schema_view.type_id == _types.BOOL: return ToBooleanBufferConverter, {} elif schema_view.buffer_format is not None: return ToPyBufferConverter, {} diff --git a/python/tests/test_extension.py b/python/tests/test_extension.py new file mode 100644 index 000000000..2bb85fc7f --- /dev/null +++ b/python/tests/test_extension.py @@ -0,0 +1,64 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import pytest +from nanoarrow.c_schema import c_schema_view + +import nanoarrow as na +from nanoarrow import extension + + +def test_basic_extension(): + class TestExtension(extension.Extension): + def get_schema(self): + return na.extension_type(na.int32(), "arrow.test") + + def get_params(self, c_schema): + return {"parsed_key": "some parsed value"} + + instance = TestExtension() + assert extension.register_extension(instance) is None + + # Check internal resolution + assert extension.resolve_extension(c_schema_view(instance.get_schema())) is instance + + # Check Schema integration + schema = na.extension_type(na.int32(), "arrow.test") + assert schema.extension.parsed_key == "some parsed value" + + # Ensure other integrations fail if methods aren't implemented + with pytest.raises(TypeError, match="get_iterable_appender"): + assert na.Array([0], schema) + + with pytest.raises(TypeError, match="get_buffer_appender"): + assert na.Array(bytearray([0]), schema) + + schema = na.extension_type(na.int32(), "arrow.test") + storage_array = na.c_array([1, 2, 3], na.int32()) + _, storage_array_capsule = na.c_array(storage_array).__arrow_c_array__() + array = na.Array(storage_array_capsule, schema) + with pytest.raises(NotImplementedError, match="get_pyiter"): + array.to_pylist() + + with pytest.raises(NotImplementedError, match="get_sequence_converter"): + array.to_pysequence() + + other_instance = TestExtension() + assert extension.register_extension(other_instance) is instance + assert extension.unregister_extension("arrow.test") is other_instance + with 
pytest.raises(KeyError): + extension.unregister_extension("arrow.test") diff --git a/python/tests/test_extension_canonical.py b/python/tests/test_extension_canonical.py new file mode 100644 index 000000000..efed4c7ad --- /dev/null +++ b/python/tests/test_extension_canonical.py @@ -0,0 +1,46 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +import nanoarrow as na + + +def test_extension_bool8(): + schema = na.bool8() + assert schema.type == na.Type.EXTENSION + assert schema.extension.storage.type == na.Type.INT8 + assert schema.extension.name == "arrow.bool8" + assert schema.extension.metadata == b"" + + assert na.bool8(nullable=False).nullable is False + + bool8_array = na.Array([True, False, True, True], na.bool8()) + assert bool8_array.schema.type == na.Type.EXTENSION + assert bool8_array.schema.extension.name == "arrow.bool8" + assert bool8_array.to_pylist() == [True, False, True, True] + + sequence = bool8_array.to_pysequence() + assert list(sequence) == [True, False, True, True] + + bool8_array = na.Array([True, False, None, True], na.bool8()) + assert bool8_array.to_pylist() == [True, False, None, True] + + sequence = bool8_array.to_pysequence(handle_nulls=na.nulls_separate()) + assert list(sequence[1]) == [True, False, False, True] + assert list(sequence[0]) == [True, True, False, True] + + bool8_array = na.Array(sequence[1], na.bool8()) + assert bool8_array.to_pylist() == [True, False, False, True] diff --git a/python/tests/test_schema.py b/python/tests/test_schema.py index 77a52d3a8..7d1f97f3e 100644 --- a/python/tests/test_schema.py +++ b/python/tests/test_schema.py @@ -248,7 +248,7 @@ def test_schema_extension(): schema_obj = na.extension_type(na.int32(), "arrow.test", nullable=False) assert schema_obj.extension.name == "arrow.test" - assert schema_obj.extension.metadata is None + assert schema_obj.extension.metadata == b"" assert schema_obj.extension.storage.type == na.Type.INT32 assert schema_obj.nullable is False diff --git a/python/tests/test_visitor.py b/python/tests/test_visitor.py index eb1c1f651..dbc38ce12 100644 --- a/python/tests/test_visitor.py +++ b/python/tests/test_visitor.py @@ -183,3 +183,18 @@ def test_numpy_null_handling(): nulls_as_sentinel(is_valid_non_empty, data), np.array([1, np.nan, 3], dtype=np.float64), ) + + +def test_iterator_unregistered_extension(): + from 
nanoarrow.iterator import UnregisteredExtensionWarning + + schema = na.extension_type(na.int32(), "arrow.test") + storage_array = na.c_array([1, 2, 3], na.int32()) + _, storage_array_capsule = na.c_array(storage_array).__arrow_c_array__() + extension_array = na.c_array(storage_array_capsule, schema) + + with pytest.warns(UnregisteredExtensionWarning): + visitor.ToPyListConverter.visit(extension_array) + + with pytest.warns(UnregisteredExtensionWarning): + visitor.ToPySequenceConverter.visit(extension_array)