Skip to content

Commit

Permalink
Added sub kernels to support different input data types and alpha values
Browse files Browse the repository at this point in the history
  • Loading branch information
Lakshmidurga Kaparapu committed Feb 22, 2025
1 parent 69654d4 commit 38fff1b
Show file tree
Hide file tree
Showing 8 changed files with 2,897 additions and 13 deletions.
1,375 changes: 1,375 additions & 0 deletions xa_nnlib/algo/kernels/basic/xa_nn_elm_sub_32xf32.c

Large diffs are not rendered by default.

1,349 changes: 1,349 additions & 0 deletions xa_nnlib/algo/kernels/basic/xa_nn_elm_sub_f32x32.c

Large diffs are not rendered by default.

15 changes: 7 additions & 8 deletions xa_nnlib/algo/kernels/basic/xa_nn_mean_f32.c
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,12 @@
#include <string.h>

// Function to check if two consecutive axes are contiguous
WORD32 are_two_axes_contiguous(WORD32 a, WORD32 b) {
/* Predicate: true (nonzero) when axis index b directly succeeds axis index a,
 * i.e. the two axes are contiguous in the dimension ordering. */
static inline WORD32 are_two_axes_contiguous(WORD32 a, WORD32 b) {
  return ((a + 1) == b);
}

// Check if an axis is in the given axes list
WORD32 is_axis_in_list(WORD32 axis, const WORD32 *axes, WORD32 num_axes) {
static inline WORD32 is_axis_in_list(WORD32 axis, const WORD32 *axes, WORD32 num_axes) {
WORD32 i;
for (i = 0; i < num_axes; i++) {
if (axes[i] == axis)
Expand All @@ -44,7 +44,7 @@ WORD32 is_axis_in_list(WORD32 axis, const WORD32 *axes, WORD32 num_axes) {
}

// Sort axes in ascending order (Bubble Sort)
void sort_axes(WORD32 *axes, WORD32 num_axes) {
static inline void sort_axes(WORD32 *axes, WORD32 num_axes) {
WORD32 temp;
WORD32 i, j;
for (i = 0; i < (num_axes - 1); i++) {
Expand All @@ -60,7 +60,7 @@ void sort_axes(WORD32 *axes, WORD32 num_axes) {

// Merge contiguous axes
// Merge contiguous dimensions other than axes
void merge_axes_dims(const WORD32 *const input_shape, WORD32 num_dims, WORD32 *axes, WORD32 num_axes,
static inline void merge_axes_dims(const WORD32 *const input_shape, WORD32 num_dims, WORD32 *axes, WORD32 num_axes,
WORD32 *new_input_shape, WORD32 *new_num_dims, WORD32 *new_axes, WORD32 *new_num_axes) {
*new_num_dims = 0;
*new_num_axes = 0;
Expand Down Expand Up @@ -223,7 +223,7 @@ WORD32 xa_nn_mean_f32_f32(FLOAT32 *__restrict__ p_out,
merge_axes_dims(p_inp_shape, num_inp_dims, new_axes, num_axis_dims, new_input_shape, &new_num_inp_dims, new_axes_data, &new_num_axis_dims);

WORD32 last_dim = 0;

if(new_axes_data[new_num_axis_dims - CONST_ONE] == (new_num_inp_dims - CONST_ONE))
{
last_dim = CONST_ONE;
Expand Down Expand Up @@ -348,7 +348,7 @@ WORD32 xa_nn_mean_f32_f32(FLOAT32 *__restrict__ p_out,
/* Load input elements with stride "inner_stride" */
PDX_LAV_MXF32_XP(x1, align_src1, p_in_mxf32, rem_elm);

rem_sum = rem_sum + x1;//PDX_ADD_MXF32(rem_sum, x1);
rem_sum = PDX_ADD_MXF32(rem_sum, x1);
}

/* Store output */
Expand Down Expand Up @@ -684,7 +684,6 @@ WORD32 xa_nn_mean_f32_f32(FLOAT32 *__restrict__ p_out,
out = PDX_RADD_MXF32(sum);
xtfloat_storeip(out, p_z, 4);
}

}
}
else
Expand Down Expand Up @@ -772,7 +771,7 @@ WORD32 xa_nn_mean_f32_f32(FLOAT32 *__restrict__ p_out,
/* Store output */
PDX_SAV_MXF32_XP(rem_sum, align_dst, p_dst, rem_elm);
}
PDX_SAPOS_MXF32_FP(align_dst, p_dst);
PDX_SAPOS_MXF32_FP(align_dst, p_dst);
}
}
}
Expand Down
5 changes: 1 addition & 4 deletions xa_nnlib/build/common.mk
Original file line number Diff line number Diff line change
Expand Up @@ -129,17 +129,14 @@ $(LIBOBJ): $(OBJ_LIBOBJS) $(OBJ_LIBOSOBJS)
$(QUIET) -$(RM) $(TEMPOBJ)
endif


$(OBJ_LIBOBJS): $(OBJDIR)/%.o: %.c
@echo "Compiling $<"
$(QUIET) $(CC) -o $@ $(OPT_O3) $(CFLAGS) $(INCLUDES) -c $<

$(OBJ_LIBOSOBJS): $(OBJDIR)/%.o: %.c
@echo "Compiling $<"
$(QUIET) $(CC) -o $@ $(OPT_OS) $(CFLAGS) $(INCLUDES) -c $<




$(LIB): %.a: $(OBJDIR)/%.o
@echo "Creating Library $@"
$(QUIET) $(AR) rc $@ $^
Expand Down
13 changes: 12 additions & 1 deletion xa_nnlib/build/ldscript_nnlib.txt
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,18 @@ EXTERN(xa_nn_elm_less_scalar_f32xf32_bool)
EXTERN(xa_nn_elm_less_broadcast_5D_f32xf32_bool)
EXTERN(xa_nn_elm_where_f32xf32_f32)
EXTERN(xa_nn_elm_where_broadcast_5D_f32xf32_f32)

EXTERN(xa_nn_elm_sub_f32x32xf32_f32)
EXTERN(xa_nn_elm_sub_scalar_f32x32xf32_f32)
EXTERN(xa_nn_elm_sub_broadcast_5D_f32x32xf32_f32)
EXTERN(xa_nn_elm_sub_f32x32x32_f32)
EXTERN(xa_nn_elm_sub_scalar_f32x32x32_f32)
EXTERN(xa_nn_elm_sub_broadcast_5D_f32x32x32_f32)
EXTERN(xa_nn_elm_sub_32xf32xf32_f32)
EXTERN(xa_nn_elm_sub_scalar_32xf32xf32_f32)
EXTERN(xa_nn_elm_sub_broadcast_5D_32xf32xf32_f32)
EXTERN(xa_nn_elm_sub_32xf32x32_f32)
EXTERN(xa_nn_elm_sub_scalar_32xf32x32_f32)
EXTERN(xa_nn_elm_sub_broadcast_5D_32xf32x32_f32)

/* Normalization kernels */
EXTERN(xa_nn_native_layer_norm_f32_f32)
Expand Down
2 changes: 2 additions & 0 deletions xa_nnlib/build/makefile_nn_lib_fusion_g3
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,8 @@ BASICOBJS = \
xa_nn_elm_rsqrt_f32.o \
xa_nn_elm_where_f32.o \
xa_nn_elm_less_f32.o \
xa_nn_elm_sub_32xf32.o \
xa_nn_elm_sub_f32x32.o \


NORMOBJS = \
Expand Down
13 changes: 13 additions & 0 deletions xa_nnlib/build/symbols_nnlib.txt
Original file line number Diff line number Diff line change
Expand Up @@ -77,3 +77,16 @@ xa_nn_elm_where_f32xf32_f32
xa_nn_elm_where_broadcast_5D_f32xf32_f32
xa_nn_sigmoid_f32_f32
xa_nn_tanh_f32_f32
xa_nn_elm_sub_f32x32xf32_f32
xa_nn_elm_sub_scalar_f32x32xf32_f32
xa_nn_elm_sub_broadcast_5D_f32x32xf32_f32
xa_nn_elm_sub_f32x32x32_f32
xa_nn_elm_sub_scalar_f32x32x32_f32
xa_nn_elm_sub_broadcast_5D_f32x32x32_f32
xa_nn_elm_sub_32xf32xf32_f32
xa_nn_elm_sub_scalar_32xf32xf32_f32
xa_nn_elm_sub_broadcast_5D_32xf32xf32_f32
xa_nn_elm_sub_32xf32x32_f32
xa_nn_elm_sub_scalar_32xf32x32_f32
xa_nn_elm_sub_broadcast_5D_32xf32x32_f32

138 changes: 138 additions & 0 deletions xa_nnlib/include/nnlib/xa_nnlib_kernels_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -475,6 +475,144 @@ WORD32 xa_nn_mean_f32_f32(FLOAT32 *__restrict__ p_out,
const WORD32 *__restrict__ p_axis,
WORD32 num_axis_dims);

WORD32 xa_nn_elm_clamp_f32_f32(FLOAT32 *__restrict__ p_out,
const FLOAT32 *__restrict__ p_inp,
const FLOAT32 *__restrict__ p_min,
const FLOAT32 *__restrict__ p_max,
WORD32 num_elm);

WORD32 xa_nn_elm_clamp_scalar_f32_f32(FLOAT32 *__restrict__ p_out,
const FLOAT32 *__restrict__ p_inp,
const FLOAT32 min,
const FLOAT32 max,
WORD32 num_elm);

WORD32 xa_nn_elm_clamp_broadcast_5D_f32_f32(FLOAT32 *__restrict__ p_out,
const WORD32 *const p_out_shape,
const FLOAT32 *__restrict__ p_inp,
const WORD32 *const p_inp_shape,
const FLOAT32 *__restrict__ p_min,
const WORD32 *const p_min_shape,
const FLOAT32 *__restrict__ p_max,
const WORD32 *const p_max_shape,
WORD32 num_inp_dims);

WORD32 xa_nn_elm_where_f32xf32_f32(FLOAT32 *p_out,
const FLOAT32 *p_inp1,
const FLOAT32 *p_inp2,
const UWORD8 *p_cond,
WORD32 num_elm);

WORD32 xa_nn_tanh_f32_f32(FLOAT32 *p_out,
const FLOAT32 *p_inp,
WORD32 vec_length);

WORD32 xa_nn_elm_where_broadcast_5D_f32xf32_f32(FLOAT32 *__restrict__ p_out,
const WORD32 *const p_out_shape,
const FLOAT32 *__restrict__ p_inp1,
const WORD32 *const p_inp1_shape,
const FLOAT32 *__restrict__ p_inp2,
const WORD32 *const p_inp2_shape,
const UWORD8 *p_cond,
const WORD32 *const p_cond_shape,
WORD32 num_inp_dims);

WORD32 xa_nn_elm_less_f32xf32_bool(WORD8 *p_out,
const FLOAT32 *p_inp1,
const FLOAT32 *p_inp2,
WORD32 num_elm);

WORD32 xa_nn_elm_less_scalar_f32xf32_bool(WORD8 *p_out,
const FLOAT32 *p_inp1,
const FLOAT32 inp2,
WORD32 num_elm);

WORD32 xa_nn_elm_less_broadcast_5D_f32xf32_bool(WORD8 *__restrict__ p_out,
const WORD32 *const p_out_shape,
const FLOAT32 *__restrict__ p_inp1,
const WORD32 *const p_inp1_shape,
const FLOAT32 *__restrict__ p_inp2,
const WORD32 *const p_inp2_shape,
WORD32 num_inp_dims);

WORD32 xa_nn_elm_sqrt_f32_f32(FLOAT32 *p_out,
const FLOAT32 *p_inp,
WORD32 num_elm);

WORD32 xa_nn_elm_rsqrt_f32_f32(FLOAT32 *p_out,
const FLOAT32 *p_inp,
WORD32 num_elm);

WORD32 xa_nn_sigmoid_f32_f32(FLOAT32 *p_out,
const FLOAT32 *p_inp,
WORD32 vec_length);

WORD32 xa_nn_elm_clamp_8_8(WORD8 *p_out,
const WORD8 *p_inp,
const WORD8 *p_min,
const WORD8 *p_max,
WORD32 num_elm);

WORD32 xa_nn_elm_clamp_scalar_8_8(WORD8 *p_out,
const WORD8 *p_inp,
const WORD8 min,
const WORD8 max,
WORD32 num_elm);

WORD32 xa_nn_elm_clamp_broadcast_5D_8_8(WORD8 *__restrict__ p_out,
const WORD32 *const p_out_shape,
const WORD8 *__restrict__ p_inp,
const WORD32 *const p_inp_shape,
const WORD8 *__restrict__ p_min,
const WORD32 *const p_min_shape,
const WORD8 *__restrict__ p_max,
const WORD32 *const p_max_shape,
WORD32 num_inp_dims);

WORD32 xa_nn_elm_clamp_8u_8u(UWORD8 *p_out,
const UWORD8 *p_inp,
const UWORD8 *p_min,
const UWORD8 *p_max,
WORD32 num_elm);

WORD32 xa_nn_elm_clamp_scalar_8u_8u(UWORD8 *p_out,
const UWORD8 *p_inp,
const UWORD8 min,
const UWORD8 max,
WORD32 num_elm);

WORD32 xa_nn_elm_clamp_broadcast_5D_8u_8u(UWORD8 *__restrict__ p_out,
const WORD32 *const p_out_shape,
const UWORD8 *__restrict__ p_inp,
const WORD32 *const p_inp_shape,
const UWORD8 *__restrict__ p_min,
const WORD32 *const p_min_shape,
const UWORD8 *__restrict__ p_max,
const WORD32 *const p_max_shape,
WORD32 num_inp_dims);

WORD32 xa_nn_elm_clamp_16_16(WORD16 *p_out,
const WORD16 *p_inp,
const WORD16 *p_min,
const WORD16 *p_max,
WORD32 num_elm);

WORD32 xa_nn_elm_clamp_scalar_16_16(WORD16 *p_out,
const WORD16 *p_inp,
const WORD16 min,
const WORD16 max,
WORD32 num_elm);

WORD32 xa_nn_elm_clamp_broadcast_5D_16_16(WORD16 *__restrict__ p_out,
const WORD32 *const p_out_shape,
const WORD16 *__restrict__ p_inp,
const WORD32 *const p_inp_shape,
const WORD16 *__restrict__ p_min,
const WORD32 *const p_min_shape,
const WORD16 *__restrict__ p_max,
const WORD32 *const p_max_shape,
WORD32 num_inp_dims);

#if defined(__cplusplus)
}
#endif
Expand Down

0 comments on commit 38fff1b

Please sign in to comment.