Skip to content

Commit 9b81d41

Browse files
committed
Faster C-implementation of Join
Reuse outputs when possible, and avoid numpy overhead on public facing function
1 parent a346949 commit 9b81d41

File tree

2 files changed

+196
-42
lines changed

2 files changed

+196
-42
lines changed

pytensor/tensor/basic.py

+169-42
Original file line numberDiff line numberDiff line change
@@ -2537,58 +2537,185 @@ def perform(self, node, inputs, output_storage):
25372537
)
25382538

25392539
def c_code_cache_version(self):
    """Return the cache version of the generated C code.

    The tuple has to change whenever the text produced by ``c_code``
    changes.  A stray ``return None`` placed before this statement made
    the version bump unreachable and reported the code as unversioned.
    """
    return (6,)
25412542

25422543
def c_code(self, node, name, inputs, outputs, sub):
"""Build the C source string that joins the input arrays along an axis.

NOTE(review): this block is a rendered diff.  A line made of digits
followed by ``-`` precedes a removed line and digits followed by ``+``
precedes an added line; the description below covers the added version.

Parameters
----------
node
    Apply node.  Its first input is the axis; its first output supplies
    the dtype, item size and ndim that are baked into the generated code.
name
    Not read by the added version of the body.
inputs
    C variable names: the axis first, followed by the arrays to join.
outputs
    Single-element sequence holding the C variable name of the output.
sub
    Substitution mapping; only ``sub["fail"]`` is read and spliced into
    every error path.

Returns
-------
str
    C code that checks ndim and the non-join dimensions of every input,
    reuses the existing output when its shape already matches, otherwise
    allocates a new array whose stride order is derived from the inputs,
    and finally copies each input into consecutive slices of the output
    with ``PyArray_CopyInto``.
"""
2543-
axis, tens = inputs[0], inputs[1:]
2544-
view = -1
2545-
non_empty_tensor = tens[view]
2546-
input_1 = tens[0]
2547-
l = len(tens)
2548-
(out,) = outputs
2549-
fail = sub["fail"]
2550-
adtype = node.inputs[0].type.dtype_specs()[1]
2544+
# First name is the axis variable; the remaining names are the arrays to join
axis, *arrays = inputs
2545+
[out] = outputs
25512546

2552-
copy_to_list = (
2553-
f"""Py_INCREF({inp}); PyList_SetItem(list, {i}, (PyObject*){inp});"""
2554-
for i, inp in enumerate(tens)
2555-
)
2547+
n = len(arrays)
2548+
# This entry is passed to PyArray_DescrFromType when a new output is allocated below
out_dtype = node.outputs[0].type.dtype_specs()[2]
2549+
# Element size in bytes; seeds the innermost stride of a freshly allocated output
out_itemsize = np.dtype(node.outputs[0].dtype).itemsize
2550+
ndim = node.outputs[0].type.ndim
2551+
fail = sub["fail"]
25562552

2557-
copy_inputs_to_list = "\n".join(copy_to_list)
2558-
n = len(tens)
2553+
# Most times axis is constant, inline it
2554+
# This is safe to do because the hash of the c_code includes the constant signature
2555+
if isinstance(node.inputs[0], Constant):
2556+
static_axis = int(node.inputs[0].data)
2557+
static_axis = normalize_axis_index(static_axis, ndim)
2558+
axis_def = f"{static_axis};"
2559+
# The axis was already normalized above, so no runtime check is emitted
axis_check = ""
2560+
else:
2561+
# Axis only known at runtime: read the first element of the axis array in C,
# then wrap negative values and bounds-check it with the snippet below
axis_dtype = node.inputs[0].type.dtype_specs()[1]
2562+
axis_def = f"(({axis_dtype} *)PyArray_DATA({axis}))[0];"
2563+
axis_check = f"""
2564+
if (axis < 0){{
2565+
axis = {ndim} + axis;
2566+
}}
2567+
if (axis >= {ndim} || axis < 0) {{
2568+
PyErr_SetString(PyExc_ValueError, "Join axis is out of bounds");
2569+
{fail}
2570+
}}
2571+
"""
25592572

25602573
code = f"""
2561-
int axis = (({adtype} *)PyArray_DATA({axis}))[0];
2562-
PyObject* list = PyList_New({l});
2563-
{copy_inputs_to_list}
2564-
int tensors_lens_sum;
2565-
if({view} != -1) {{
2566-
tensors_lens_sum = 0;
2567-
2568-
for(int i=0; i < {n}; i++){{
2569-
tensors_lens_sum += PyArray_DIM((PyArrayObject *)(PyList_GetItem(list, i)), axis);
2574+
int axis = {axis_def}
2575+
PyArrayObject* arrays[{n}] = {{{','.join(arrays)}}};
2576+
npy_intp out_shape[{ndim}];
2577+
npy_intp join_size = 0;
2578+
int out_is_valid = 0;
2579+
PyArrayObject_fields *view;
2580+
2581+
// Validate input shapes and compute join size
2582+
npy_intp *shape = PyArray_SHAPE(arrays[0]);
2583+
2584+
{axis_check}
2585+
2586+
for (int i = 0; i < {n}; i++) {{
2587+
if (PyArray_NDIM(arrays[i]) != {ndim}) {{
2588+
PyErr_SetString(PyExc_ValueError, "Input to join has wrong ndim");
2589+
{fail}
2590+
}}
2591+
2592+
join_size += PyArray_SHAPE(arrays[i])[axis];
2593+
2594+
if(i > 0){{
2595+
for (int j = 0; j < {ndim}; j++) {{
2596+
if((j != axis) && (PyArray_SHAPE(arrays[i])[j] != shape[j])) {{
2597+
PyErr_SetString(PyExc_ValueError, "Arrays shape must match along non join axis");
2598+
{fail}
2599+
}}
2600+
}}
2601+
}}
25702602
}}
2571-
tensors_lens_sum -= PyArray_DIM({non_empty_tensor}, axis);
2572-
}}
2573-
if({view} != -1 && tensors_lens_sum == 0) {{
2574-
Py_XDECREF({out});
2575-
Py_INCREF({non_empty_tensor});
2576-
{out} = {non_empty_tensor};
2577-
}}else{{
2578-
//PyObject* PyArray_Concatenate(PyObject* obj, int axis)
2579-
int ndim = PyArray_NDIM({input_1});
2580-
if( axis < -ndim ){{
2581-
PyErr_Format(PyExc_IndexError,
2582-
"Join axis %d out of bounds [0, %d)", axis, ndim);
2583-
{fail}
2603+
2604+
// Define dimensions of output array
2605+
memcpy(out_shape, shape, {ndim} * sizeof(npy_intp));
2606+
out_shape[axis] = join_size;
2607+
2608+
// Reuse output or allocate new one
2609+
if ({out} != NULL) {{
2610+
out_is_valid = (PyArray_NDIM({out}) == {ndim});
2611+
for (int i = 0; i < {ndim}; i++) {{
2612+
out_is_valid &= (PyArray_SHAPE({out})[i] == out_shape[i]);
2613+
}}
25842614
}}
2585-
Py_XDECREF({out});
2586-
{out} = (PyArrayObject *)PyArray_Concatenate(list, axis);
2587-
Py_DECREF(list);
2588-
if(!{out}){{
2615+
2616+
if (!out_is_valid) {{
2617+
Py_XDECREF({out});
2618+
2619+
// Find best memory layout to match the input tensors
2620+
// Adapted from numpy PyArray_CreateMultiSortedStridePerm
2621+
// https://github.com/numpy/numpy/blob/214b9f7c6d27f48b163dd7adbf9de368ad59859f/numpy/_core/src/multiarray/shape.c#L801
2622+
int strideperm[{ndim}] = {{{','.join(map(str, range(ndim)))}}};
2623+
npy_intp strides[{ndim}];
2624+
2625+
// Sort strides (insertion sort)
2626+
for (int i0 = 1; i0 < {ndim}; ++i0) {{
2627+
int ipos = i0;
2628+
int ax_j0 = strideperm[i0];
2629+
2630+
for (int i1 = i0 - 1; i1 >= 0; --i1) {{
2631+
int ambig = 1, shouldswap = 0;
2632+
int ax_j1 = strideperm[i1];
2633+
2634+
for (int iarrays = 0; iarrays < {n}; ++iarrays) {{
2635+
if (PyArray_SHAPE(arrays[iarrays])[ax_j0] != 1 && PyArray_SHAPE(arrays[iarrays])[ax_j1] != 1) {{
2636+
npy_intp stride0 = PyArray_STRIDES(arrays[iarrays])[ax_j0];
2637+
npy_intp stride1 = PyArray_STRIDES(arrays[iarrays])[ax_j1];
2638+
if (stride0 < 0) stride0 = -stride0;
2639+
if (stride1 < 0) stride1 = -stride1;
2640+
2641+
if (stride0 <= stride1) {{
2642+
shouldswap = 0;
2643+
}}
2644+
else if (ambig) {{
2645+
shouldswap = 1;
2646+
}}
2647+
ambig = 0;
2648+
}}
2649+
}}
2650+
2651+
if (!ambig) {{
2652+
if (shouldswap) {{
2653+
ipos = i1;
2654+
}}
2655+
else {{
2656+
break;
2657+
}}
2658+
}}
2659+
}}
2660+
2661+
if (ipos != i0) {{
2662+
for (int i1 = i0; i1 > ipos; --i1) {{
2663+
strideperm[i1] = strideperm[i1-1];
2664+
}}
2665+
strideperm[ipos] = ax_j0;
2666+
}}
2667+
}}
2668+
2669+
// Calculate strides based on sorted order
2670+
npy_intp stride = {out_itemsize};
2671+
for (int i = {ndim}-1; i >= 0; --i) {{
2672+
int ax = strideperm[i];
2673+
strides[ax] = stride;
2674+
stride *= out_shape[ax];
2675+
}}
2676+
2677+
{out} = (PyArrayObject *)PyArray_NewFromDescr(&PyArray_Type,
2678+
PyArray_DescrFromType({out_dtype}),
2679+
{ndim},
2680+
out_shape,
2681+
strides,
2682+
NULL, /* data */
2683+
NPY_ARRAY_DEFAULT,
2684+
NULL);
2685+
2686+
if ({out} == NULL) {{
2687+
{fail}
2688+
}}
2689+
}}
2690+
2691+
// Create view into output buffer
2692+
// PyArray_NewFromDescr steals a reference to descr, so we need to increase it
2693+
Py_INCREF(PyArray_DESCR({out}));
2694+
view = (PyArrayObject_fields *)PyArray_NewFromDescr(&PyArray_Type,
2695+
PyArray_DESCR({out}),
2696+
{ndim},
2697+
PyArray_SHAPE(arrays[0]),
2698+
PyArray_STRIDES({out}),
2699+
PyArray_DATA({out}),
2700+
NPY_ARRAY_WRITEABLE,
2701+
NULL);
2702+
if (view == NULL) {{
25892703
{fail}
25902704
}}
2591-
}}
2705+
2706+
// Copy data into output buffer
2707+
for (int i = 0; i < {n}; i++) {{
2708+
view->dimensions[axis] = PyArray_SHAPE(arrays[i])[axis];
2709+
2710+
if (PyArray_CopyInto((PyArrayObject*)view, arrays[i]) != 0) {{
2711+
Py_DECREF(view);
2712+
{fail}
2713+
}}
2714+
2715+
view->data += (view->dimensions[axis] * view->strides[axis]);
2716+
}}
2717+
2718+
Py_DECREF(view);
25922719
"""
25932720
return code
25942721

tests/tensor/test_basic.py

+27
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,7 @@
117117
ivector,
118118
lscalar,
119119
lvector,
120+
matrices,
120121
matrix,
121122
row,
122123
scalar,
@@ -2156,6 +2157,32 @@ def test_split_view(self, linker):
21562157
# C impl always makes a copy
21572158
assert r.base is not x_test
21582159

2160+
@pytest.mark.parametrize("gc", (True, False), ids=lambda x: f"gc={x}")
@pytest.mark.parametrize("memory_layout", ["C-contiguous", "F-contiguous", "Mixed"])
@pytest.mark.parametrize("axis", (0, 1), ids=lambda x: f"axis={x}")
@pytest.mark.parametrize("ndim", (1, 2), ids=["vector", "matrix"])
@config.change_flags(cmodule__warn_no_version=False)
def test_join_performance(self, ndim, axis, memory_layout, gc, benchmark):
    """Benchmark ``join`` of six inputs for several memory layouts.

    Parameters
    ----------
    ndim
        1 joins vectors, 2 joins square matrices.
    axis
        Axis along which the inputs are joined.
    memory_layout
        ``"C-contiguous"`` uses the arrays as created, ``"F-contiguous"``
        transposes all of them and ``"Mixed"`` transposes every other one.
    gc
        Value assigned to ``fn.vm.allow_gc`` before the function is timed.
    benchmark
        Benchmark fixture that repeatedly calls the compiled function.
    """
    # Vectors have a single axis and a single layout; skip the duplicates
    if ndim == 1 and not (memory_layout == "C-contiguous" and axis == 0):
        pytest.skip("Redundant parametrization")

    n = 64
    inputs = vectors("abcdef") if ndim == 1 else matrices("abcdef")
    out = join(axis, *inputs)
    fn = pytensor.function(inputs, Out(out, borrow=True), trust_input=True)
    fn.vm.allow_gc = gc

    test_values = [np.zeros((n, n)[:ndim], dtype=inputs[0].dtype) for _ in inputs]
    if memory_layout == "C-contiguous":
        pass
    elif memory_layout == "F-contiguous":
        test_values = [t.T for t in test_values]
    elif memory_layout == "Mixed":
        test_values = [t if i % 2 else t.T for i, t in enumerate(test_values)]
    else:
        raise ValueError(f"Unknown memory_layout: {memory_layout!r}")

    # Sanity-check the result once before timing.  The join axis grows by
    # the number of inputs; every other axis keeps length n.
    expected_shape = [n] * ndim
    expected_shape[axis] *= len(inputs)
    assert fn(*test_values).shape == tuple(expected_shape)

    benchmark(fn, *test_values)
2185+
21592186

21602187
def test_TensorFromScalar():
21612188
s = ps.constant(56)

0 commit comments

Comments
 (0)