Skip to content

Commit

Permalink
feat(python): Implement extension type/mechanism in python package (#688
Browse files Browse the repository at this point in the history
)

This PR implements the first canonical extension type for
nanoarrow/Python: `nanoarrow.bool8()`. In doing so it also implements
some machinery for "extensions", which is intended to be internal and to
evolve with the requirements of some follow-up canonical extension
types. The extension points are:

- Parsing the metadata when `Schema.extension` is accessed to get
parameter access (and validate the metadata)
- Converting an extension array to Python objects (i.e.,
`array.to_pylist()`)
- Converting an extension array to a Python sequence (i.e.,
`array.to_pysequence()`)
- Constructing an extension array from Python objects (i.e.,
`na.Array([True, False, None], na.bool8())`)
- Constructing an extension array from a Python buffer (i.e.,
`na.Array(np.array([True, False, False]), na.bool8())`)

I am not sure the extension point implementations via the `Extension`
methods are the final way this should be done, but this PR at least
connects the wires. The tensor extensions are more complex and will
require some modification of this, but I'd like to do that in another
PR.

```python
import nanoarrow as na
import numpy as np

na.Array([True, False, None], na.bool8())
#> nanoarrow.Array<arrow.bool8{int8}>[3]
#> True
#> False
#> None

na.Array(np.array([True, False, True]), na.bool8())
#> nanoarrow.Array<arrow.bool8{int8}>[3]
#> True
#> False
#> True

na.Array([True, False, None], na.bool8()).to_pylist()
#> [True, False, None]

np.array(na.Array([True, False, True], na.bool8()).to_pysequence())
#> array([ True, False,  True])
```
  • Loading branch information
paleolimbot authored Dec 6, 2024
1 parent e54b7df commit 7b8a7c8
Show file tree
Hide file tree
Showing 13 changed files with 470 additions and 12 deletions.
2 changes: 2 additions & 0 deletions python/src/nanoarrow/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from nanoarrow.c_array_stream import c_array_stream
from nanoarrow.c_schema import c_schema
from nanoarrow.c_buffer import c_buffer
from nanoarrow.extension_canonical import bool8
from nanoarrow.schema import (
Schema,
Type,
Expand Down Expand Up @@ -87,6 +88,7 @@
"binary",
"binary_view",
"bool_",
"bool8",
"c_array",
"c_array_from_buffers",
"c_array_stream",
Expand Down
10 changes: 9 additions & 1 deletion python/src/nanoarrow/_schema.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -555,6 +555,8 @@ cdef class CSchemaView:
(_types.TIMESTAMP, _types.DATE64, _types.DURATION)
):
return 'q'
elif self.extension_name:
return self._get_buffer_format()
else:
return None

Expand All @@ -564,7 +566,13 @@ cdef class CSchemaView:
or None if there is no Python format string that can represent this
type without losing information.
"""
if self.extension_name or self._schema_view.type != self._schema_view.storage_type:
if self.extension_name:
return None
else:
return self._get_buffer_format()

def _get_buffer_format(self):
if self._schema_view.type != self._schema_view.storage_type:
return None

# String/binary types do not have format strings as far as the Python
Expand Down
29 changes: 29 additions & 0 deletions python/src/nanoarrow/c_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from nanoarrow._utils import obj_is_buffer, obj_is_capsule
from nanoarrow.c_buffer import c_buffer
from nanoarrow.c_schema import c_schema, c_schema_view
from nanoarrow.extension import resolve_extension

from nanoarrow import _types

Expand Down Expand Up @@ -416,6 +417,16 @@ def infer_schema(cls, obj) -> Tuple[CBuffer, CSchema]:
def __init__(self, schema):
super().__init__(schema)

ext = resolve_extension(self._schema_view)
self._append_ext = None
if ext is not None:
self._append_ext = ext.get_buffer_appender(self._schema, self)
elif self._schema_view.extension_name:
raise NotImplementedError(
"Can't create array for unregistered extension "
f"'{self._schema_view.extension_name}'"
)

if self._schema_view.storage_buffer_format is None:
raise ValueError(
f"Can't build array of type {self._schema_view.type} from PyBuffer"
Expand All @@ -428,6 +439,12 @@ def append(self, obj: Any) -> None:
if not isinstance(obj, CBuffer):
obj = CBuffer.from_pybuffer(obj)

if self._append_ext is not None:
return self._append_ext(obj)

return self._append_impl(obj)

def _append_impl(self, obj):
if (
self._schema_view.buffer_format in ("b", "c")
and obj.format not in ("b", "c")
Expand Down Expand Up @@ -462,6 +479,18 @@ def __init__(self, schema):

# Resolve the method name we are going to use to do the building from
# the provided schema.
ext = resolve_extension(self._schema_view)
if ext is not None:
maybe_appender = ext.get_iterable_appender(self._schema, self)
if maybe_appender:
self._append_impl = maybe_appender
return
elif self._schema_view.extension_name:
raise NotImplementedError(
f"Can't create array for unregistered extension "
f"'{self._schema_view.extension_name}'"
)

type_id = self._schema_view.type_id
if type_id not in _ARRAY_BUILDER_FROM_ITERABLE_METHOD:
raise ValueError(
Expand Down
177 changes: 177 additions & 0 deletions python/src/nanoarrow/extension.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,177 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from typing import Any, Callable, Iterable, Iterator, Mapping, Optional, Type

from nanoarrow.c_schema import CSchema, CSchemaView, c_schema_view


class Extension:
    """Define a nanoarrow Extension

    A nanoarrow extension customizes behaviour of built-in operations
    applicable to a specific type. This is currently implemented only
    for Arrow extension types but could in theory apply if one wanted
    to customize the conversion behaviour of a specific non-extension
    type.

    This is currently internal and involves knowledge of other internal
    nanoarrow/Python structures. It is currently used only to implement
    canonical extensions with the anticipation of evolving to support
    user-defined extensions as the internal APIs on which it relies
    stabilize.

    With the current design, an Extension subclass must be constructible
    with no parameters (e.g., ``Extension()``).
    """

    def get_schema(self) -> CSchema:
        """Get the schema for which this extension applies.

        This is used by :func:`register_extension` to ensure that it can be
        resolved when needed.

        The base implementation always raises ``NotImplementedError``;
        subclasses must override it.
        """
        raise NotImplementedError()

    def get_params(self, c_schema: CSchema) -> Mapping[str, Any]:
        """Compute a dictionary of type parameters.

        These parameters are accessible via the :class:`Schema`
        ``extension`` attribute (e.g., ``schema.extension.param_name``).
        Internal parameters can also be returned but should be prefixed with
        an underscore.

        This method should also error if the storage type or any other
        property of the schema is not valid.

        The base implementation performs no validation and returns a new,
        empty dictionary.
        """
        return {}

    def get_pyiter(
        self,
        py_iterator,
        offset: int,
        length: int,
    ) -> Optional[Iterator[Any]]:
        """Compute an iterable of Python objects.

        Used by ``to_pylist()`` to generate scalars for a particular type.
        If ``None`` is returned, the behaviour of the storage type will be
        used without warning.

        This method is currently passed the underlying :class:`PyIterator`
        and returns an iterator; however, it could in the future be passed
        a :class:`CSchema` and return a PyIterator class once that class
        structure is stabilized.

        The base implementation always raises ``NotImplementedError`` with
        a message naming the extension.
        """
        # The element type is extension-specific, so the base class
        # advertises Iterator[Any] rather than a concrete scalar type.
        name = py_iterator._schema_view.extension_name
        raise NotImplementedError(f"Extension get_pyiter() for {name}")

    def get_sequence_converter(self, c_schema: CSchema):
        """Return an ArrayViewVisitor subclass used to compute a sequence from
        a stream of arrays.

        This is currently implemented outside the null handler and may need
        a flag at some point to indicate that it did or did not handle its
        own nulls.

        The base implementation always raises ``NotImplementedError`` with
        a message naming the extension.
        """
        schema_view = c_schema_view(c_schema)
        name = schema_view.extension_name
        raise NotImplementedError(f"Extension get_sequence_converter() for {name}")

    def get_buffer_appender(
        self, c_schema: CSchema, array_builder
    ) -> Optional[Callable[[Any], None]]:
        """Compute a function that prepares a :class:`CArrayBuilder` from a
        buffer.

        This is used to customize the behavior of creating a CArray from an
        object implementing the Python buffer protocol. If ``None`` is
        returned, the storage will be converted without a warning.

        This method is currently passed a :class:`CArrayBuilder` but in
        the future should perhaps be passed a :class:`CSchema` and return a
        CArrayBuilder class.

        The base implementation always raises ``NotImplementedError`` with
        a message naming the extension.
        """
        schema_view = c_schema_view(c_schema)
        name = schema_view.extension_name
        raise NotImplementedError(f"Extension get_buffer_appender() for {name}")

    def get_iterable_appender(
        self, c_schema: CSchema, array_builder
    ) -> Optional[Callable[[Iterable], None]]:
        """Compute a function that prepares a :class:`CArrayBuilder` from an
        iterable.

        This is used to customize the behavior of creating a CArray from an
        iterable of Python objects.

        This method is currently passed a :class:`CArrayBuilder` but in
        the future should perhaps be passed a :class:`CSchema` and return a
        CArrayBuilder class.

        The base implementation always raises ``NotImplementedError`` with
        a message naming the extension.
        """
        schema_view = c_schema_view(c_schema)
        name = schema_view.extension_name
        raise NotImplementedError(f"Extension get_iterable_appender() for {name}")


# Process-wide registry of extensions: maps the extension name computed by
# register_extension() to the Extension instance registered under it.
_global_extension_registry = {}


def resolve_extension(c_schema_view: CSchemaView) -> Optional[Extension]:
    """Look up the extension registered for a :class:`CSchemaView`

    The view's ``extension_name`` is used as the key into the global
    registry. The :class:`Extension` instance stored under that key is
    returned if there is one; otherwise ``None`` is returned.
    """
    return _global_extension_registry.get(c_schema_view.extension_name)


def register_extension(extension: Extension) -> Optional[Extension]:
    """Register an :class:`Extension` instance in the global registry.

    Inserts an extension into the global registry, keyed by the extension
    name of the schema returned by ``extension.get_schema()``, and returns
    the previously registered extension for that name if one exists
    (or ``None`` otherwise).

    Raises ``ValueError`` if the schema returned by
    ``extension.get_schema()`` has no extension name.
    """
    schema_view = c_schema_view(extension.get_schema())
    key = schema_view.extension_name

    # resolve_extension() looks entries up by extension name only, so an
    # entry stored under a missing/empty name would be resolved for every
    # schema that has no extension name. Refuse to register it instead.
    if not key:
        raise ValueError(
            "Can't register an extension whose schema has no extension name"
        )

    prev = resolve_extension(schema_view)
    # Item assignment mutates the module-level dict in place; no ``global``
    # declaration is needed because the name itself is never rebound.
    _global_extension_registry[key] = extension
    return prev


def unregister_extension(extension_name: str) -> Extension:
    """Remove an extension from the global registry by extension name.

    Returns the removed extension. Raises ``KeyError`` if there was no
    extension registered for this extension name.
    """
    # dict.pop() without a default removes and returns the entry in one
    # step and raises KeyError when the key is absent.
    return _global_extension_registry.pop(extension_name)


def register(extension_cls: Type[Extension]):
    """Class decorator that registers an extension class.

    ``extension_cls`` is called with no arguments and the resulting instance
    is passed to :func:`register_extension`. The class itself is returned
    unchanged so that this function can be applied as ``@register``.
    """
    instance = extension_cls()
    register_extension(instance)
    return extension_cls
86 changes: 86 additions & 0 deletions python/src/nanoarrow/extension_canonical.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from typing import Any, Iterator, Mapping, Optional

from nanoarrow.c_buffer import CBufferBuilder
from nanoarrow.c_schema import CSchema, c_schema_view
from nanoarrow.schema import extension_type, int8
from nanoarrow.visitor import ToPyBufferConverter

from nanoarrow import extension


def bool8(nullable: bool = True):
    """Create the ``arrow.bool8`` type: a boolean stored as one byte per value

    The type is built by wrapping an ``int8()`` storage type with
    ``extension_type()`` under the extension name ``"arrow.bool8"``.

    Parameters
    ----------
    nullable : bool, optional
        Use ``False`` to mark this field as non-nullable.
    """
    storage_type = int8()
    return extension_type(storage_type, "arrow.bool8", nullable=nullable)


class Bool8SequenceConverter(ToPyBufferConverter):
    """Buffer converter for ``arrow.bool8`` arrays.

    Overrides the builder factory so that the output buffer is built with
    the format string ``"?"`` (the Python buffer/struct code for ``bool``).
    """

    def _make_builder(self):
        builder = CBufferBuilder()
        return builder.set_format("?")


@extension.register
class Bool8Extension(extension.Extension):
    """Extension hooks for the ``arrow.bool8`` type (int8 storage).

    An instance is created and added to the global registry by the
    ``@extension.register`` decorator when this class is defined.
    """

    def get_schema(self) -> CSchema:
        """Return the schema this extension applies to (``bool8()``)."""
        return bool8()

    def get_params(self, c_schema: CSchema) -> Mapping[str, Any]:
        """Validate the storage type and return the (empty) parameter dict.

        Raises ``ValueError`` if the type reported for ``c_schema`` is not
        ``"int8"``.
        """
        storage_type = c_schema_view(c_schema).type
        if storage_type != "int8":
            raise ValueError("arrow.bool8 must have storage type int8")

        return {}

    def get_pyiter(
        self,
        py_iterator,
        offset: int,
        length: int,
    ) -> Optional[Iterator[Optional[bool]]]:
        """Iterate over a slice of the current array as Python ``bool``s.

        Elements of buffer 1 are passed through ``bool()``. When the
        iterator reports that nulls are present, the result is combined with
        the elements of buffer 0 via ``_wrap_iter_nullable()``.
        """
        array_view = py_iterator._array_view
        values = map(bool, array_view.buffer(1).elements(offset, length))

        # Fast path: nothing to mask, so the mapped values are the result.
        if not py_iterator._contains_nulls():
            return values

        is_valid = array_view.buffer(0).elements(offset, length)
        return py_iterator._wrap_iter_nullable(is_valid, values)

    def get_sequence_converter(self, c_schema: CSchema):
        """Validate ``c_schema`` and return :class:`Bool8SequenceConverter`."""
        # Called for its validation side effect only; the result is unused.
        self.get_params(c_schema)
        return Bool8SequenceConverter

    def get_sequence_appender(self, c_schema: CSchema, array_builder):
        """Validate ``c_schema`` and return ``None``.

        NOTE(review): the ``Extension`` base class defines no hook with this
        name (only ``get_buffer_appender`` and ``get_iterable_appender``);
        confirm whether anything calls this method.
        """
        self.get_params(c_schema)
        return None

    def get_buffer_appender(self, c_schema: CSchema, array_builder):
        """Validate ``c_schema`` and return ``None`` (no custom appender)."""
        self.get_params(c_schema)
        return None

    def get_iterable_appender(self, c_schema: CSchema, array_builder):
        """Validate ``c_schema`` and return ``None`` (no custom appender)."""
        self.get_params(c_schema)
        return None
13 changes: 10 additions & 3 deletions python/src/nanoarrow/iterator.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from nanoarrow._array import CArrayView
from nanoarrow.c_array_stream import c_array_stream
from nanoarrow.c_schema import c_schema, c_schema_view
from nanoarrow.extension import resolve_extension
from nanoarrow.schema import Schema

from nanoarrow import _types
Expand Down Expand Up @@ -183,6 +184,8 @@ def get_iterator(cls, obj, schema=None):

def __init__(self, schema, *, array_view=None):
super().__init__(schema, array_view=array_view)
self._ext = resolve_extension(self._schema_view)
self._ext_params = self._ext.get_params(schema) if self._ext else None

self._children = list(
map(self._make_child, self._schema.children, self._array_view.children)
Expand All @@ -208,11 +211,15 @@ def __iter__(self):

def _iter_chunk(self, offset, length):
"""Iterate over all elements in a slice of the current chunk"""

# Check for an extension type first since this isn't reflected by
# self._schema_view.type_id. Currently we just return the storage
# iterator with a warning for extension types.
# self._schema_view.type_id.
maybe_extension_name = self._schema_view.extension_name
if maybe_extension_name:
if self._ext:
maybe_iter = self._ext.get_pyiter(self, offset, length)
if maybe_iter:
return maybe_iter
elif maybe_extension_name:
self._warn(
f"Converting unregistered extension '{maybe_extension_name}' "
"as storage type",
Expand Down
Loading

0 comments on commit 7b8a7c8

Please sign in to comment.