-
Notifications
You must be signed in to change notification settings - Fork 40
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat(python): Implement extension type/mechanism in python package (#688
) This PR implements the first canonical extension type for nanoarrow/Python: `nanoarrow.bool8()`. In doing so it also implements some machinery for "extensions", which is intended to be internal and to evolve with the requirements of some follow-up canonical extension types. The extension points are: - Parsing the metadata when `Schema.extension` is accessed to get parameter access (and validate the metadata) - Converting an extension array to Python objects (i.e., `array.to_pylist()`) - Converting an extension array to a Python sequence (i.e., `array.to_pysequence()`) - Constructing an extension array from Python objects (i.e., `na.Array([True, False, None], na.bool8())` - Constructing an extension array from a Python buffer (i.e., `na.Array(np.array([True, False, False]), na.bool8())` I am not sure the extension point implementations via the `Extension` methods are the final way this should be done, but this PR at least connects the wires. The tensor extensions are more complex and will require some modification of this, but I'd like to do that in another PR. ```python import nanoarrow as na import numpy as np na.Array([True, False, None], na.bool8()) #> nanoarrow.Array<arrow.bool8{int8}>[3] #> True #> False #> None na.Array(np.array([True, False, True]), na.bool8()) #> nanoarrow.Array<arrow.bool8{int8}>[3] #> True #> False #> True na.Array([True, False, None], na.bool8()).to_pylist() #> [True, False, None] np.array(na.Array([True, False, True], na.bool8()).to_pysequence()) #> array([ True, False, True]) ```
- Loading branch information
1 parent
e54b7df
commit 7b8a7c8
Showing
13 changed files
with
470 additions
and
12 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,177 @@ | ||
# Licensed to the Apache Software Foundation (ASF) under one | ||
# or more contributor license agreements. See the NOTICE file | ||
# distributed with this work for additional information | ||
# regarding copyright ownership. The ASF licenses this file | ||
# to you under the Apache License, Version 2.0 (the | ||
# "License"); you may not use this file except in compliance | ||
# with the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an | ||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
# KIND, either express or implied. See the License for the | ||
# specific language governing permissions and limitations | ||
# under the License. | ||
|
||
from typing import Any, Callable, Iterable, Iterator, Mapping, Optional, Type | ||
|
||
from nanoarrow.c_schema import CSchema, CSchemaView, c_schema_view | ||
|
||
|
||
class Extension: | ||
"""Define a nanoarrow Extension | ||
A nanoarrow extension customizes behaviour of built-in operations | ||
applicable to a specific type. This is currently implemented only | ||
for Arrow extension types but could in theory apply if one wanted | ||
to customize the conversion behaviour of a specific non-extension | ||
type. | ||
This is currently internal and involves knowledge of other internal | ||
nanoarrow/Python structures. It is currently used only to implement | ||
canonical extensions with the anticipation of evolving to support | ||
user-defined extensions as the internal APIs on which it relies | ||
stabilize. | ||
With the current design, an Extension subclass must be constructible | ||
with no parameters (e.g., ``Extension()``). | ||
""" | ||
|
||
def get_schema(self) -> CSchema: | ||
"""Get the schema for which this extension applies. | ||
This is used by :func:`register_extension` to ensure that it can be resolved | ||
when needed. | ||
""" | ||
raise NotImplementedError() | ||
|
||
def get_params(self, c_schema: CSchema) -> Mapping[str, Any]: | ||
"""Compute a dictionary of type parameters. | ||
These parameters are accessible via the :class:`Schema` | ||
``extension`` attribute (e.g., ``schema.extension.param_name``). | ||
Internal parameters can also be returned but should be prefixed with | ||
an underscore. | ||
This method should also error if the storage type or any other property | ||
of the schema is not valid. | ||
""" | ||
return {} | ||
|
||
def get_pyiter( | ||
self, | ||
py_iterator, | ||
offset: int, | ||
length: int, | ||
) -> Optional[Iterator[Optional[bool]]]: | ||
"""Compute an iterable of Python objects. | ||
Used by ``to_pylist()`` to generate scalars for a particular type. | ||
If ``None`` is returned, the behaviour of the storage type will be | ||
used without warning. | ||
This method is currently passed the underlying :class:`PyIterator` | ||
and returns an iterator; however, it could in the future be passed | ||
a :class:`CSchema` and return a PyIterator class once that class | ||
structure is stabilized. | ||
""" | ||
name = py_iterator._schema_view.extension_name | ||
raise NotImplementedError(f"Extension get_pyiter() for {name}") | ||
|
||
def get_sequence_converter(self, c_schema: CSchema): | ||
"""Return an ArrayViewVisitor subclass used to compute a sequence from | ||
a stream of arrays. | ||
This is currently implemented outside the null handler and may need a flag | ||
at some point to indicate that it did or did not handle its own nulls. | ||
""" | ||
schema_view = c_schema_view(c_schema) | ||
name = schema_view.extension_name | ||
raise NotImplementedError(f"Extension get_sequence_converter() for {name}") | ||
|
||
def get_buffer_appender( | ||
self, c_schema: CSchema, array_builder | ||
) -> Optional[Callable[[Any], None]]: | ||
"""Compute a function that prepares a :class:`CArrayBuilder` from a | ||
buffer. | ||
This is used to customize the behavior of creating a CArray from an | ||
object implementing the Python buffer protocol. If ``None`` is | ||
returned, the storage will be converted without a warning. | ||
This method is currently passed a :class:`CArrayBuilder` but in | ||
the future should perhaps be passed a :class:`CSchema` and return a | ||
CArrayBuilder class. | ||
""" | ||
schema_view = c_schema_view(c_schema) | ||
name = schema_view.extension_name | ||
raise NotImplementedError(f"Extension get_buffer_appender() for {name}") | ||
|
||
def get_iterable_appender( | ||
self, c_schema: CSchema, array_builder | ||
) -> Optional[Callable[[Iterable], None]]: | ||
"""Compute a function that prepares a :class:`CArrayBuilder` from a | ||
buffer. | ||
This is used to customize the behavior of creating a CArray from an | ||
iterable of Python objects. | ||
This method is currently passed a :class:`CArrayBuilder` but in | ||
the future should perhaps be passed a :class:`CSchema` and return a | ||
CArrayBuilder class. | ||
""" | ||
schema_view = c_schema_view(c_schema) | ||
name = schema_view.extension_name | ||
raise NotImplementedError(f"Extension get_iterable_appender() for {name}") | ||
|
||
|
||
_global_extension_registry = {} | ||
|
||
|
||
def resolve_extension(c_schema_view: CSchemaView) -> Optional[Extension]: | ||
"""Resolve an extension instance from a :class:`CSchemaView` | ||
Returns the registered extension instance if one applies to the passed | ||
type or ``None`` otherwise. | ||
""" | ||
extension_name = c_schema_view.extension_name | ||
if extension_name in _global_extension_registry: | ||
return _global_extension_registry[extension_name] | ||
|
||
return None | ||
|
||
|
||
def register_extension(extension: Extension) -> Optional[Extension]: | ||
"""Register an :class:`Extension` instance in the global registry. | ||
Inserts an extension into the global registry, returning the | ||
previously registered extension for that type if one exists | ||
(or ``None`` otherwise). | ||
""" | ||
global _global_extension_registry | ||
|
||
schema_view = c_schema_view(extension.get_schema()) | ||
key = schema_view.extension_name | ||
prev = resolve_extension(schema_view) | ||
_global_extension_registry[key] = extension | ||
return prev | ||
|
||
|
||
def unregister_extension(extension_name: str): | ||
"""Remove an extension from the global registry by extension name. | ||
Returns the removed extension. Raises ``KeyError`` if there was no | ||
extension registered for this extension name. | ||
""" | ||
prev = _global_extension_registry[extension_name] | ||
del _global_extension_registry[extension_name] | ||
return prev | ||
|
||
|
||
def register(extension_cls: Type[Extension]): | ||
"""Decorator that registers an extension class by instantiating it | ||
and adding it to the global registry.""" | ||
register_extension(extension_cls()) | ||
return extension_cls |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,86 @@ | ||
# Licensed to the Apache Software Foundation (ASF) under one | ||
# or more contributor license agreements. See the NOTICE file | ||
# distributed with this work for additional information | ||
# regarding copyright ownership. The ASF licenses this file | ||
# to you under the Apache License, Version 2.0 (the | ||
# "License"); you may not use this file except in compliance | ||
# with the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an | ||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
# KIND, either express or implied. See the License for the | ||
# specific language governing permissions and limitations | ||
# under the License. | ||
|
||
from typing import Any, Iterator, Mapping, Optional | ||
|
||
from nanoarrow.c_buffer import CBufferBuilder | ||
from nanoarrow.c_schema import CSchema, c_schema_view | ||
from nanoarrow.schema import extension_type, int8 | ||
from nanoarrow.visitor import ToPyBufferConverter | ||
|
||
from nanoarrow import extension | ||
|
||
|
||
def bool8(nullable: bool = True): | ||
"""Create a type representing a boolean encoded as one byte per value | ||
Parameters | ||
---------- | ||
nullable : bool, optional | ||
Use ``False`` to mark this field as non-nullable. | ||
""" | ||
|
||
return extension_type(int8(), "arrow.bool8", nullable=nullable) | ||
|
||
|
||
class Bool8SequenceConverter(ToPyBufferConverter): | ||
def _make_builder(self): | ||
return CBufferBuilder().set_format("?") | ||
|
||
|
||
@extension.register | ||
class Bool8Extension(extension.Extension): | ||
def get_schema(self) -> CSchema: | ||
return bool8() | ||
|
||
def get_params(self, c_schema: CSchema) -> Mapping[str, Any]: | ||
schema_view = c_schema_view(c_schema) | ||
if schema_view.type != "int8": | ||
raise ValueError("arrow.bool8 must have storage type int8") | ||
|
||
return {} | ||
|
||
def get_pyiter( | ||
self, | ||
py_iterator, | ||
offset: int, | ||
length: int, | ||
) -> Optional[Iterator[Optional[bool]]]: | ||
view = py_iterator._array_view | ||
items = map(bool, view.buffer(1).elements(offset, length)) | ||
|
||
if py_iterator._contains_nulls(): | ||
validity = view.buffer(0).elements(offset, length) | ||
return py_iterator._wrap_iter_nullable(validity, items) | ||
else: | ||
return items | ||
|
||
def get_sequence_converter(self, c_schema: CSchema): | ||
self.get_params(c_schema) | ||
return Bool8SequenceConverter | ||
|
||
def get_sequence_appender(self, c_schema: CSchema, array_builder): | ||
self.get_params(c_schema) | ||
return None | ||
|
||
def get_buffer_appender(self, c_schema: CSchema, array_builder): | ||
self.get_params(c_schema) | ||
return None | ||
|
||
def get_iterable_appender(self, c_schema: CSchema, array_builder): | ||
self.get_params(c_schema) | ||
return None |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.