Skip to content

Commit 2e0c79d

Browse files
committed
Reuse output buffer in C-impl of Join
1 parent 1acacb8 commit 2e0c79d

File tree

2 files changed

+107
-10
lines changed

2 files changed

+107
-10
lines changed

pytensor/tensor/basic.py

+78-8
Original file line number | Diff line number | Diff line change
@@ -2518,7 +2518,7 @@ def perform(self, node, inputs, output_storage):
25182518
)
25192519

25202520
def c_code_cache_version(self):
2521-
return (6,)
2521+
return (7,)
25222522

25232523
def c_code(self, node, name, inputs, outputs, sub):
25242524
axis, *arrays = inputs
@@ -2557,16 +2557,86 @@ def c_code(self, node, name, inputs, outputs, sub):
25572557
code = f"""
25582558
int axis = {axis_def}
25592559
PyArrayObject* arrays[{n}] = {{{','.join(arrays)}}};
2560-
PyObject* arrays_tuple = PyTuple_New({n});
2560+
int out_is_valid = {out} != NULL;
25612561
25622562
{axis_check}
25632563
2564-
Py_XDECREF({out});
2565-
{copy_arrays_to_tuple}
2566-
{out} = (PyArrayObject *)PyArray_Concatenate(arrays_tuple, axis);
2567-
Py_DECREF(arrays_tuple);
2568-
if(!{out}){{
2569-
{fail}
2564+
if (out_is_valid) {{
2565+
// Check if we can reuse output
2566+
npy_intp join_size = 0;
2567+
npy_intp out_shape[{ndim}];
2568+
npy_intp *shape = PyArray_SHAPE(arrays[0]);
2569+
2570+
for (int i = 0; i < {n}; i++) {{
2571+
if (PyArray_NDIM(arrays[i]) != {ndim}) {{
2572+
PyErr_SetString(PyExc_ValueError, "Input to join has wrong ndim");
2573+
{fail}
2574+
}}
2575+
2576+
join_size += PyArray_SHAPE(arrays[i])[axis];
2577+
2578+
if (i > 0){{
2579+
for (int j = 0; j < {ndim}; j++) {{
2580+
if ((j != axis) && (PyArray_SHAPE(arrays[i])[j] != shape[j])) {{
2581+
PyErr_SetString(PyExc_ValueError, "Arrays shape must match along non join axis");
2582+
{fail}
2583+
}}
2584+
}}
2585+
}}
2586+
}}
2587+
2588+
memcpy(out_shape, shape, {ndim} * sizeof(npy_intp));
2589+
out_shape[axis] = join_size;
2590+
2591+
for (int i = 0; i < {ndim}; i++) {{
2592+
out_is_valid &= (PyArray_SHAPE({out})[i] == out_shape[i]);
2593+
}}
2594+
}}
2595+
2596+
if (!out_is_valid) {{
2597+
// Use PyArray_Concatenate
2598+
Py_XDECREF({out});
2599+
PyObject* arrays_tuple = PyTuple_New({n});
2600+
{copy_arrays_to_tuple}
2601+
{out} = (PyArrayObject *)PyArray_Concatenate(arrays_tuple, axis);
2602+
Py_DECREF(arrays_tuple);
2603+
if(!{out}){{
2604+
{fail}
2605+
}}
2606+
}}
2607+
else {{
2608+
// Copy the data to the pre-allocated output buffer
2609+
2610+
// Create view into output buffer
2611+
PyArrayObject_fields *view;
2612+
2613+
// PyArray_NewFromDescr steals a reference to descr, so we need to increase it
2614+
Py_INCREF(PyArray_DESCR({out}));
2615+
view = (PyArrayObject_fields *)PyArray_NewFromDescr(&PyArray_Type,
2616+
PyArray_DESCR({out}),
2617+
{ndim},
2618+
PyArray_SHAPE(arrays[0]),
2619+
PyArray_STRIDES({out}),
2620+
PyArray_DATA({out}),
2621+
NPY_ARRAY_WRITEABLE,
2622+
NULL);
2623+
if (view == NULL) {{
2624+
{fail}
2625+
}}
2626+
2627+
// Copy data into output buffer
2628+
for (int i = 0; i < {n}; i++) {{
2629+
view->dimensions[axis] = PyArray_SHAPE(arrays[i])[axis];
2630+
2631+
if (PyArray_CopyInto((PyArrayObject*)view, arrays[i]) != 0) {{
2632+
Py_DECREF(view);
2633+
{fail}
2634+
}}
2635+
2636+
view->data += (view->dimensions[axis] * view->strides[axis]);
2637+
}}
2638+
2639+
Py_DECREF(view);
25702640
}}
25712641
"""
25722642
return code

tests/tensor/test_basic.py

+29-2
Original file line number | Diff line number | Diff line change
@@ -117,6 +117,7 @@
117117
ivector,
118118
lscalar,
119119
lvector,
120+
matrices,
120121
matrix,
121122
row,
122123
scalar,
@@ -1762,7 +1763,7 @@ def test_join_matrixV_negative_axis(self):
17621763
got = f(-2)
17631764
assert np.allclose(got, want)
17641765

1765-
with pytest.raises(IndexError):
1766+
with pytest.raises(ValueError):
17661767
f(-3)
17671768

17681769
@pytest.mark.parametrize("py_impl", (False, True))
@@ -1805,7 +1806,7 @@ def test_join_matrixC_negative_axis(self, py_impl):
18051806
got = f()
18061807
assert np.allclose(got, want)
18071808

1808-
with pytest.raises(IndexError):
1809+
with pytest.raises(ValueError):
18091810
join(-3, a, b)
18101811

18111812
with impl_ctxt:
@@ -2152,6 +2153,32 @@ def test_split_view(self, linker):
21522153
assert np.allclose(r, expected)
21532154
assert r.base is x_test
21542155

2156+
@pytest.mark.parametrize("gc", (True, False), ids=lambda x: f"gc={x}")
2157+
@pytest.mark.parametrize("memory_layout", ["C-contiguous", "F-contiguous", "Mixed"])
2158+
@pytest.mark.parametrize("axis", (0, 1), ids=lambda x: f"axis={x}")
2159+
@pytest.mark.parametrize("ndim", (1, 2), ids=["vector", "matrix"])
2160+
@config.change_flags(cmodule__warn_no_version=False)
2161+
def test_join_performance(self, ndim, axis, memory_layout, gc, benchmark):
2162+
if ndim == 1 and not (memory_layout == "C-contiguous" and axis == 0):
2163+
pytest.skip("Redundant parametrization")
2164+
n = 64
2165+
inputs = vectors("abcdef") if ndim == 1 else matrices("abcdef")
2166+
out = join(axis, *inputs)
2167+
fn = pytensor.function(inputs, Out(out, borrow=True), trust_input=True)
2168+
fn.vm.allow_gc = gc
2169+
test_values = [np.zeros((n, n)[:ndim], dtype=inputs[0].dtype) for _ in inputs]
2170+
if memory_layout == "C-contiguous":
2171+
pass
2172+
elif memory_layout == "F-contiguous":
2173+
test_values = [t.T for t in test_values]
2174+
elif memory_layout == "Mixed":
2175+
test_values = [t if i % 2 else t.T for i, t in enumerate(test_values)]
2176+
else:
2177+
raise ValueError
2178+
2179+
assert fn(*test_values).shape == (n * 6, n)[:ndim] if axis == 0 else (n, n * 6)
2180+
benchmark(fn, *test_values)
2181+
21552182

21562183
def test_TensorFromScalar():
21572184
s = ps.constant(56)

0 commit comments

Comments (0)