@@ -2537,58 +2537,185 @@ def perform(self, node, inputs, output_storage):
2537
2537
)
2538
2538
2539
2539
def c_code_cache_version(self):
    """Return the version tuple of the C implementation emitted by ``c_code``.

    The tuple was bumped from ``(5,)`` to ``(6,)`` together with the rewrite
    of the C template.  A stray ``return None`` placed ahead of the tuple
    made the bumped version unreachable; it is removed so that the version
    is actually reported.

    Returns
    -------
    tuple of int
        ``(6,)``.
    """
    # NOTE(review): presumably consumed by the compilation cache to decide
    # whether previously compiled modules can be reused -- confirm upstream.
    return (6,)
2541
2542
2542
2543
def c_code(self, node, name, inputs, outputs, sub):
    """Build the C source that joins the input arrays along one axis.

    The emitted code validates the input shapes, reuses the existing output
    buffer when its shape already matches (otherwise allocates a new one
    whose memory layout follows the strides of the inputs), and then copies
    every input into a sliding view of the output.

    Parameters
    ----------
    node
        Node being compiled.  Its first input is the join axis; its first
        output supplies the dtype and number of dimensions of the result.
    name
        Not used by this implementation.
    inputs
        Names of the C variables: the axis first, followed by every array
        to join.
    outputs
        One-element sequence holding the name of the C output variable.
    sub
        Substitution mapping; only ``sub["fail"]`` is read.

    Returns
    -------
    str
        The C code implementing the join.
    """
    axis_name, *array_names = inputs
    (out_name,) = outputs

    n_arrays = len(array_names)
    out_typenum = node.outputs[0].type.dtype_specs()[2]
    itemsize = np.dtype(node.outputs[0].dtype).itemsize
    ndim = node.outputs[0].type.ndim
    fail = sub["fail"]

    # Initialiser lists spliced into the C array declarations below.
    arrays_init = ",".join(array_names)
    identity_perm = ",".join(map(str, range(ndim)))

    # The axis is constant most of the time, in which case it is inlined.
    # This is safe to do because the hash of the c_code includes the constant signature
    axis_input = node.inputs[0]
    if isinstance(axis_input, Constant):
        inlined_axis = normalize_axis_index(int(axis_input.data), ndim)
        axis_def = f"{inlined_axis};"
        axis_check = ""
    else:
        axis_ctype = axis_input.type.dtype_specs()[1]
        axis_def = f"(({axis_ctype}*)PyArray_DATA({axis_name}))[0];"
        # A runtime axis must be normalised and bounds-checked in C.
        axis_check = f"""
        if (axis < 0){{
            axis = {ndim} + axis;
        }}
        if (axis >= {ndim} || axis < 0) {{
            PyErr_SetString(PyExc_ValueError, "Join axis is out of bounds");
            {fail}
        }}
        """

    code = f"""
        int axis = {axis_def}
        PyArrayObject* arrays[{n_arrays}] = {{{arrays_init}}};
        npy_intp out_shape[{ndim}];
        npy_intp join_size = 0;
        int out_is_valid = 0;
        PyArrayObject_fields *view;

        // Validate input shapes and compute join size
        npy_intp *shape = PyArray_SHAPE(arrays[0]);

        {axis_check}

        for (int i = 0; i < {n_arrays}; i++) {{
            if (PyArray_NDIM(arrays[i]) != {ndim}) {{
                PyErr_SetString(PyExc_ValueError, "Input to join has wrong ndim");
                {fail}
            }}

            join_size += PyArray_SHAPE(arrays[i])[axis];

            if(i > 0){{
                for (int j = 0; j < {ndim}; j++) {{
                    if((j != axis) && (PyArray_SHAPE(arrays[i])[j] != shape[j])) {{
                        PyErr_SetString(PyExc_ValueError, "Arrays shape must match along non join axis");
                        {fail}
                    }}
                }}
            }}
        }}

        // Define dimensions of output array
        memcpy(out_shape, shape, {ndim} * sizeof(npy_intp));
        out_shape[axis] = join_size;

        // Reuse output or allocate new one
        if ({out_name} != NULL) {{
            out_is_valid = (PyArray_NDIM({out_name}) == {ndim});
            for (int i = 0; i < {ndim}; i++) {{
                out_is_valid &= (PyArray_SHAPE({out_name})[i] == out_shape[i]);
            }}
        }}

        if (!out_is_valid) {{
            Py_XDECREF({out_name});

            // Find best memory layout to match the input tensors
            // Adapted from numpy PyArray_CreateMultiSortedStridePerm
            // https://github.com/numpy/numpy/blob/214b9f7c6d27f48b163dd7adbf9de368ad59859f/numpy/_core/src/multiarray/shape.c#L801
            int strideperm[{ndim}] = {{{identity_perm}}};
            npy_intp strides[{ndim}];

            // Sort strides (insertion sort)
            for (int i0 = 1; i0 < {ndim}; ++i0) {{
                int ipos = i0;
                int ax_j0 = strideperm[i0];

                for (int i1 = i0 - 1; i1 >= 0; --i1) {{
                    int ambig = 1, shouldswap = 0;
                    int ax_j1 = strideperm[i1];

                    for (int iarrays = 0; iarrays < {n_arrays}; ++iarrays) {{
                        if (PyArray_SHAPE(arrays[iarrays])[ax_j0] != 1 && PyArray_SHAPE(arrays[iarrays])[ax_j1] != 1) {{
                            npy_intp stride0 = PyArray_STRIDES(arrays[iarrays])[ax_j0];
                            npy_intp stride1 = PyArray_STRIDES(arrays[iarrays])[ax_j1];
                            if (stride0 < 0) stride0 = -stride0;
                            if (stride1 < 0) stride1 = -stride1;

                            if (stride0 <= stride1) {{
                                shouldswap = 0;
                            }}
                            else if (ambig) {{
                                shouldswap = 1;
                            }}
                            ambig = 0;
                        }}
                    }}

                    if (!ambig) {{
                        if (shouldswap) {{
                            ipos = i1;
                        }}
                        else {{
                            break;
                        }}
                    }}
                }}

                if (ipos != i0) {{
                    for (int i1 = i0; i1 > ipos; --i1) {{
                        strideperm[i1] = strideperm[i1-1];
                    }}
                    strideperm[ipos] = ax_j0;
                }}
            }}

            // Calculate strides based on sorted order
            npy_intp stride = {itemsize};
            for (int i = {ndim}-1; i >= 0; --i) {{
                int ax = strideperm[i];
                strides[ax] = stride;
                stride *= out_shape[ax];
            }}

            {out_name} = (PyArrayObject *)PyArray_NewFromDescr(&PyArray_Type,
                                        PyArray_DescrFromType({out_typenum}),
                                        {ndim},
                                        out_shape,
                                        strides,
                                        NULL,  /* data */
                                        NPY_ARRAY_DEFAULT,
                                        NULL);

            if ({out_name} == NULL) {{
                {fail}
            }}
        }}

        // Create view into output buffer
        // PyArray_NewFromDescr steals a reference to descr, so we need to increase it
        Py_INCREF(PyArray_DESCR({out_name}));
        view = (PyArrayObject_fields *)PyArray_NewFromDescr(&PyArray_Type,
                                    PyArray_DESCR({out_name}),
                                    {ndim},
                                    PyArray_SHAPE(arrays[0]),
                                    PyArray_STRIDES({out_name}),
                                    PyArray_DATA({out_name}),
                                    NPY_ARRAY_WRITEABLE,
                                    NULL);
        if (view == NULL) {{
            {fail}
        }}

        // Copy data into output buffer
        for (int i = 0; i < {n_arrays}; i++) {{
            view->dimensions[axis] = PyArray_SHAPE(arrays[i])[axis];

            if (PyArray_CopyInto((PyArrayObject*)view, arrays[i]) != 0) {{
                Py_DECREF(view);
                {fail}
            }}

            view->data += (view->dimensions[axis] * view->strides[axis]);
        }}

        Py_DECREF(view);
        """
    return code
2594
2721
0 commit comments