
Commit d5c0993

WIP. Support numba udf for SChunks
1 parent 8814dd8 commit d5c0993

3 files changed: +93 -1 lines changed

blosc2/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -205,7 +205,7 @@ class Tuner(Enum):
     abs,
 )
 
-from .lazyexpr import LazyExpr
+from .lazyexpr import LazyExpr, expr_from_udf
 
 from .schunk import SChunk, open
 from .version import __version__
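
This change only re-exports the new helper at the package top level. Assuming the rest of this commit is installed, the UDF entry point becomes reachable directly from the package namespace:

    import blosc2
    from blosc2 import expr_from_udf   # same object as blosc2.lazyexpr.expr_from_udf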

blosc2/blosc2_ext.pyx

Lines changed: 59 additions & 0 deletions
@@ -1453,6 +1453,36 @@ cdef class SChunk:
             if self.schunk.cctx == NULL:
                 raise RuntimeError("Could not create compression context")
 
+    def _set_aux_numba(self, func, inputs_id, dtype_output):
+        if self.schunk.storage.cparams.nthreads > 1:
+            raise AttributeError("compress `nthreads` must be 1 when assigning a prefilter")
+
+        func_id = func.__name__
+        blosc2.prefilter_funcs[func_id] = func
+        func_id = func_id.encode("utf-8") if isinstance(func_id, str) else func_id
+
+        # Set prefilter
+        cdef blosc2_cparams* cparams = self.schunk.storage.cparams
+        cparams.prefilter = <blosc2_prefilter_fn> general_numba
+
+        cdef blosc2_prefilter_params* preparams = <blosc2_prefilter_params *> malloc(sizeof(blosc2_prefilter_params))
+        cdef filler_udata* fill_udata = <filler_udata *> malloc(sizeof(filler_udata))
+        fill_udata.py_func = <char *> malloc(strlen(func_id) + 1)
+        strcpy(fill_udata.py_func, func_id)
+        fill_udata.inputs_id = inputs_id
+        fill_udata.output_cdtype = np.dtype(dtype_output).num
+        fill_udata.chunkshape = self.schunk.chunksize // self.schunk.typesize
+
+        preparams.user_data = fill_udata
+        cparams.preparams = preparams
+        _check_cparams(cparams)
+
+        blosc2_free_ctx(self.schunk.cctx)
+        self.schunk.cctx = blosc2_create_cctx(dereference(cparams))
+        if self.schunk.cctx == NULL:
+            raise RuntimeError("Could not create compression context")
+
     def _set_prefilter(self, func, dtype_input, dtype_output=None):
         if self.schunk.storage.cparams.nthreads > 1:
             raise AttributeError("compress `nthreads` must be 1 when assigning a prefilter")
@@ -1544,6 +1574,35 @@ cdef int general_filler(blosc2_prefilter_params *params):
 
     return 0
 
+
+cdef int general_numba(blosc2_prefilter_params *params):
+    cdef filler_udata *udata = <filler_udata *> params.user_data
+    cdef int nd = 1
+    cdef np.npy_intp dims = params.output_size // params.output_typesize
+
+    inputs_tuple = _ctypes.PyObj_FromPtr(udata.inputs_id)
+
+    output = np.PyArray_SimpleNewFromData(nd, &dims, udata.output_cdtype, <void*>params.output)
+    offset = params.nchunk * udata.chunkshape + params.output_offset // params.output_typesize
+
+    inputs = []
+    for obj, dtype in inputs_tuple:
+        if isinstance(obj, blosc2.SChunk):
+            out = np.empty(dims, dtype=dtype)
+            obj.get_slice(start=offset, stop=offset + dims, out=out)
+            inputs.append(out)
+        elif isinstance(obj, np.ndarray):
+            inputs.append(obj[offset : offset + dims])
+        elif isinstance(obj, (int, float, bool, complex)):
+            inputs.append(np.full(dims, obj, dtype=dtype))
+        else:
+            raise ValueError("Unsupported operand")
+
+    func_id = udata.py_func.decode("utf-8")
+    blosc2.prefilter_funcs[func_id](tuple(inputs), output, offset)
+
+    return 0
+
 def nelem_from_inputs(inputs_tuple, nelem=None):
     for obj, dtype in inputs_tuple:
         if isinstance(obj, blosc2.SChunk):
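
The general_numba prefilter above looks the registered UDF up in blosc2.prefilter_funcs and calls it as func(inputs, output, offset): inputs is a tuple of 1-D NumPy arrays (one per operand, already sliced to the current chunk), output is a NumPy array wrapping the chunk buffer being compressed, and offset is the element offset of that chunk inside the whole SChunk. A minimal sketch of a UDF matching that calling convention (the numba decoration and the element-wise expression are illustrative, not part of this commit):

    import numba as nb

    @nb.jit(nopython=True)
    def my_udf(inputs, output, offset):
        # inputs: tuple of NumPy arrays, one per operand, sliced to the chunk's length
        # output: NumPy array backed by the chunk buffer; fill it element by element
        # offset: element offset of this chunk inside the SChunk (unused here)
        x, y = inputs
        for i in range(output.size):
            output[i] = x[i] ** 2 + y[i] ** 2 + 2 * x[i] * y[i] + 1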

blosc2/lazyexpr.py

Lines changed: 33 additions & 0 deletions
@@ -507,3 +507,36 @@ def do_slices_intersect(slice1, slice2):
     np.testing.assert_allclose(res, nres)
     np.testing.assert_allclose(res2, nres)
     print("Everything is working fine")
+
+
+class NumbaExpr:
+    def __init__(self, func, inputs_tuple, schunk_dtype):
+        # Assume (for now) that all operands have the same shape and are SChunks;
+        # this will change later on
+        self.inputs_tuple = inputs_tuple  # Keep a reference so the tuple is not garbage-collected
+        op1 = inputs_tuple[0][0]
+        cparams = {'typesize': np.dtype(schunk_dtype).itemsize, 'nthreads': 1}
+        self.nbytes = op1.size * cparams['typesize']
+        self.res = blosc2.SChunk(chunksize=self.nbytes, cparams=cparams)
+        self.res._set_aux_numba(func, id(inputs_tuple), schunk_dtype)
+        self.schunk_dtype = schunk_dtype  # Will no longer be needed once this works with ndarrays
+        self.func = func
+
+    def eval(self):
+
+        chunksize = self.res.chunksize
+        written_nbytes = 0
+        while written_nbytes < self.nbytes:
+            chunk = np.zeros(chunksize // self.res.typesize, dtype=self.schunk_dtype)
+            self.res.append_data(chunk)
+            written_nbytes += chunksize
+            if (self.nbytes - written_nbytes) < self.res.chunksize:
+                chunksize = self.nbytes - written_nbytes
+
+        self.res.remove_prefilter(self.func.__name__)
+        return self.res
+
+
+# inputs_tuple = ( (operand, dtype), (operand2, dtype2), ... )
+def expr_from_udf(func, inputs_tuple, dtype):
+    return NumbaExpr(func, inputs_tuple, dtype)
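
Putting the pieces together, the intended flow is: expr_from_udf registers the UDF as a prefilter on a fresh result SChunk, and eval() appends zero-filled chunks so the prefilter computes each chunk on the fly. A rough usage sketch (the operand type, size and dtype are assumptions; a plain NumPy operand is used since general_numba also accepts ndarrays, and the API is still marked WIP):

    import numpy as np
    import numba as nb
    import blosc2

    @nb.jit(nopython=True)
    def plus_one(inputs, output, offset):
        x, = inputs
        for i in range(output.size):
            output[i] = x[i] + 1

    a = np.linspace(0, 10, 1_000, dtype=np.float64)
    # inputs_tuple = ((operand, dtype), ...), as noted above expr_from_udf
    inputs_tuple = ((a, np.float64),)

    expr = blosc2.expr_from_udf(plus_one, inputs_tuple, np.float64)
    res = expr.eval()  # SChunk whose chunks were filled through the numba prefilter

    out = np.empty(a.size, dtype=np.float64)
    res.get_slice(start=0, stop=a.size, out=out)
    np.testing.assert_allclose(out, a + 1)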
