Skip to content

Commit 948977d

Browse files
committed
Reuse output buffer in C-impl of Join
1 parent 5abc731 commit 948977d

File tree

2 files changed

+107
-10
lines changed

2 files changed

+107
-10
lines changed

pytensor/tensor/basic.py

+78-8
Original file line numberDiff line numberDiff line change
@@ -2520,7 +2520,7 @@ def perform(self, node, inputs, output_storage):
25202520
)
25212521

25222522
def c_code_cache_version(self):
2523-
return (6,)
2523+
return (7,)
25242524

25252525
def c_code(self, node, name, inputs, outputs, sub):
25262526
axis, *arrays = inputs
@@ -2559,16 +2559,86 @@ def c_code(self, node, name, inputs, outputs, sub):
25592559
code = f"""
25602560
int axis = {axis_def}
25612561
PyArrayObject* arrays[{n}] = {{{','.join(arrays)}}};
2562-
PyObject* arrays_tuple = PyTuple_New({n});
2562+
int out_is_valid = {out} != NULL;
25632563
25642564
{axis_check}
25652565
2566-
Py_XDECREF({out});
2567-
{copy_arrays_to_tuple}
2568-
{out} = (PyArrayObject *)PyArray_Concatenate(arrays_tuple, axis);
2569-
Py_DECREF(arrays_tuple);
2570-
if(!{out}){{
2571-
{fail}
2566+
if (out_is_valid) {{
2567+
// Check if we can reuse output
2568+
npy_intp join_size = 0;
2569+
npy_intp out_shape[{ndim}];
2570+
npy_intp *shape = PyArray_SHAPE(arrays[0]);
2571+
2572+
for (int i = 0; i < {n}; i++) {{
2573+
if (PyArray_NDIM(arrays[i]) != {ndim}) {{
2574+
PyErr_SetString(PyExc_ValueError, "Input to join has wrong ndim");
2575+
{fail}
2576+
}}
2577+
2578+
join_size += PyArray_SHAPE(arrays[i])[axis];
2579+
2580+
if (i > 0){{
2581+
for (int j = 0; j < {ndim}; j++) {{
2582+
if ((j != axis) && (PyArray_SHAPE(arrays[i])[j] != shape[j])) {{
2583+
PyErr_SetString(PyExc_ValueError, "Arrays shape must match along non join axis");
2584+
{fail}
2585+
}}
2586+
}}
2587+
}}
2588+
}}
2589+
2590+
memcpy(out_shape, shape, {ndim} * sizeof(npy_intp));
2591+
out_shape[axis] = join_size;
2592+
2593+
for (int i = 0; i < {ndim}; i++) {{
2594+
out_is_valid &= (PyArray_SHAPE({out})[i] == out_shape[i]);
2595+
}}
2596+
}}
2597+
2598+
if (!out_is_valid) {{
2599+
// Use PyArray_Concatenate
2600+
Py_XDECREF({out});
2601+
PyObject* arrays_tuple = PyTuple_New({n});
2602+
{copy_arrays_to_tuple}
2603+
{out} = (PyArrayObject *)PyArray_Concatenate(arrays_tuple, axis);
2604+
Py_DECREF(arrays_tuple);
2605+
if(!{out}){{
2606+
{fail}
2607+
}}
2608+
}}
2609+
else {{
2610+
// Copy the data to the pre-allocated output buffer
2611+
2612+
// Create view into output buffer
2613+
PyArrayObject_fields *view;
2614+
2615+
// PyArray_NewFromDescr steals a reference to descr, so we need to increase it
2616+
Py_INCREF(PyArray_DESCR({out}));
2617+
view = (PyArrayObject_fields *)PyArray_NewFromDescr(&PyArray_Type,
2618+
PyArray_DESCR({out}),
2619+
{ndim},
2620+
PyArray_SHAPE(arrays[0]),
2621+
PyArray_STRIDES({out}),
2622+
PyArray_DATA({out}),
2623+
NPY_ARRAY_WRITEABLE,
2624+
NULL);
2625+
if (view == NULL) {{
2626+
{fail}
2627+
}}
2628+
2629+
// Copy data into output buffer
2630+
for (int i = 0; i < {n}; i++) {{
2631+
view->dimensions[axis] = PyArray_SHAPE(arrays[i])[axis];
2632+
2633+
if (PyArray_CopyInto((PyArrayObject*)view, arrays[i]) != 0) {{
2634+
Py_DECREF(view);
2635+
{fail}
2636+
}}
2637+
2638+
view->data += (view->dimensions[axis] * view->strides[axis]);
2639+
}}
2640+
2641+
Py_DECREF(view);
25722642
}}
25732643
"""
25742644
return code

tests/tensor/test_basic.py

+29-2
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,7 @@
117117
ivector,
118118
lscalar,
119119
lvector,
120+
matrices,
120121
matrix,
121122
row,
122123
scalar,
@@ -1762,7 +1763,7 @@ def test_join_matrixV_negative_axis(self):
17621763
got = f(-2)
17631764
assert np.allclose(got, want)
17641765

1765-
with pytest.raises(IndexError):
1766+
with pytest.raises(ValueError):
17661767
f(-3)
17671768

17681769
@pytest.mark.parametrize("py_impl", (False, True))
@@ -1805,7 +1806,7 @@ def test_join_matrixC_negative_axis(self, py_impl):
18051806
got = f()
18061807
assert np.allclose(got, want)
18071808

1808-
with pytest.raises(IndexError):
1809+
with pytest.raises(ValueError):
18091810
join(-3, a, b)
18101811

18111812
with impl_ctxt:
@@ -2152,6 +2153,32 @@ def test_split_view(self, linker):
21522153
assert np.allclose(r, expected)
21532154
assert r.base is x_test
21542155

2156+
@pytest.mark.parametrize("gc", (True, False), ids=lambda x: f"gc={x}")
@pytest.mark.parametrize("memory_layout", ["C-contiguous", "F-contiguous", "Mixed"])
@pytest.mark.parametrize("axis", (0, 1), ids=lambda x: f"axis={x}")
@pytest.mark.parametrize("ndim", (1, 2), ids=["vector", "matrix"])
@config.change_flags(cmodule__warn_no_version=False)
def test_join_performance(self, ndim, axis, memory_layout, gc, benchmark):
    """Benchmark ``join`` of six inputs for several layouts, axes and gc modes.

    Before benchmarking, the function is called once and the shape of the
    result is checked against the shape expected for the joined axis.

    Vector inputs are only exercised for ``axis=0`` with the C-contiguous
    layout; the remaining vector parametrizations are skipped as redundant.
    """
    if ndim == 1 and not (memory_layout == "C-contiguous" and axis == 0):
        pytest.skip("Redundant parametrization")

    n = 64
    inputs = vectors("abcdef") if ndim == 1 else matrices("abcdef")
    out = join(axis, *inputs)
    fn = pytensor.function(inputs, Out(out, borrow=True), trust_input=True)
    fn.vm.allow_gc = gc

    test_values = [np.zeros((n, n)[:ndim], dtype=inputs[0].dtype) for _ in inputs]
    if memory_layout == "C-contiguous":
        pass
    elif memory_layout == "F-contiguous":
        test_values = [t.T for t in test_values]
    elif memory_layout == "Mixed":
        # Alternate transposed and non-transposed inputs.
        test_values = [t if i % 2 else t.T for i, t in enumerate(test_values)]
    else:
        raise ValueError(f"Unknown memory_layout: {memory_layout!r}")

    # Compute the expected shape first: a conditional expression binds more
    # loosely than ``==``, so writing ``assert a == b if c else d`` would
    # assert the (always truthy) tuple ``d`` whenever ``c`` is false and the
    # axis=1 case would never actually be checked.
    n_inputs = len(inputs)
    if axis == 0:
        expected_shape = (n * n_inputs, n)[:ndim]
    else:
        expected_shape = (n, n * n_inputs)
    assert fn(*test_values).shape == expected_shape

    benchmark(fn, *test_values)
2181+
21552182

21562183
def test_TensorFromScalar():
21572184
s = ps.constant(56)

0 commit comments

Comments
 (0)