Skip to content

Commit

Permalink
Optimize csr matvec with taichi customized op and Add taichi csr matvec benchmark
Browse files Browse the repository at this point in the history
  • Loading branch information
Routhleck committed Dec 11, 2023
1 parent c9923ca commit 6d9ca53
Show file tree
Hide file tree
Showing 2 changed files with 361 additions and 33 deletions.
95 changes: 62 additions & 33 deletions brainpy/_src/math/sparse/_csr_mv_taichi.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
'csrmv_taichi',
]

### CPU

@ti.kernel
def _sparse_csr_matvec_transpose_cpu(values: ti.types.ndarray(ndim=1),
Expand Down Expand Up @@ -63,45 +64,58 @@ def _sparse_csr_matvec_cpu(values: ti.types.ndarray(ndim=1),
r += values[j] * vector[col_indices[j]]
out[row_i] = r

### GPU
# homo

@ti.kernel
def _sparse_csr_matvec_transpose_gpu(values: ti.types.ndarray(ndim=1),
def _sparse_csr_matvec_transpose_homo_gpu(values: ti.types.ndarray(ndim=1),
col_indices: ti.types.ndarray(ndim=1),
row_ptr: ti.types.ndarray(ndim=1),
vector: ti.types.ndarray(ndim=1),
out: ti.types.ndarray(ndim=1)):
if values.shape[0] == 1:
value = values[0]
for row_i in range(row_ptr.shape[0] - 1):
for j in range(row_ptr[row_i], row_ptr[row_i + 1]):
out[col_indices[j]] += value * vector[row_i]

else:
for row_i in range(row_ptr.shape[0] - 1):
for j in range(row_ptr[row_i], row_ptr[row_i + 1]):
out[col_indices[j]] += values[j] * vector[row_i]
value = values[0]
for row_i in range(row_ptr.shape[0] - 1):
for j in range(row_ptr[row_i], row_ptr[row_i + 1]):
out[col_indices[j]] += value * vector[row_i]


@ti.kernel
def _sparse_csr_matvec_gpu(values: ti.types.ndarray(ndim=1),
def _sparse_csr_matvec_homo_gpu(values: ti.types.ndarray(ndim=1),
col_indices: ti.types.ndarray(ndim=1),
row_ptr: ti.types.ndarray(ndim=1),
vector: ti.types.ndarray(ndim=1),
out: ti.types.ndarray(ndim=1)):
if values.shape[0] == 1:
value = values[0]
for row_i in range(row_ptr.shape[0] - 1):
r = 0.
for j in range(row_ptr[row_i], row_ptr[row_i + 1]):
r += value * vector[col_indices[j]]
out[row_i] = r
value = values[0]
for row_i in range(row_ptr.shape[0] - 1):
r = 0.
for j in range(row_ptr[row_i], row_ptr[row_i + 1]):
r += value * vector[col_indices[j]]
out[row_i] = r

else:
for row_i in range(row_ptr.shape[0] - 1):
r = 0.
for j in range(row_ptr[row_i], row_ptr[row_i + 1]):
r += values[j] * vector[col_indices[j]]
out[row_i] = r
# heter

@ti.kernel
def _sparse_csr_matvec_transpose_heter_gpu(values: ti.types.ndarray(ndim=1),
                                           col_indices: ti.types.ndarray(ndim=1),
                                           row_ptr: ti.types.ndarray(ndim=1),
                                           vector: ti.types.ndarray(ndim=1),
                                           out: ti.types.ndarray(ndim=1)):
    """Scatter-accumulate a transposed CSR matrix-vector product into ``out``.

    For every stored entry ``k`` of row ``row``, the product
    ``values[k] * vector[row]`` is added to ``out[col_indices[k]]``.
    One value is read per stored entry (the "heter" variant).

    Args:
        values: per-entry matrix values, indexed like ``col_indices``.
        col_indices: column index of each stored entry.
        row_ptr: row offsets; row ``i`` owns entries
            ``row_ptr[i] .. row_ptr[i + 1] - 1``.
        vector: input vector, indexed by row.
        out: output buffer, indexed by column and accumulated in place.
            NOTE(review): nothing here clears ``out``; presumably the caller
            passes a zero-initialised buffer -- confirm.
    """
    n_rows = row_ptr.shape[0] - 1
    for row in range(n_rows):
        first = row_ptr[row]
        last = row_ptr[row + 1]
        for k in range(first, last):
            # Keep the augmented assignment: different rows can hit the same
            # column, and Taichi documents `+=` on array elements as an
            # atomic add inside parallelised loops.
            out[col_indices[k]] += values[k] * vector[row]


@ti.kernel
def _sparse_csr_matvec_heter_gpu(values: ti.types.ndarray(ndim=1),
                                 col_indices: ti.types.ndarray(ndim=1),
                                 row_ptr: ti.types.ndarray(ndim=1),
                                 vector: ti.types.ndarray(ndim=1),
                                 out: ti.types.ndarray(ndim=1)):
    """Compute a CSR matrix-vector product, one output element per row.

    ``out[row]`` is overwritten with the sum of
    ``values[k] * vector[col_indices[k]]`` over the entries ``k`` stored for
    that row. One value is read per stored entry (the "heter" variant).

    Args:
        values: per-entry matrix values, indexed like ``col_indices``.
        col_indices: column index of each stored entry.
        row_ptr: row offsets; row ``i`` owns entries
            ``row_ptr[i] .. row_ptr[i + 1] - 1``.
        vector: input vector, indexed by column.
        out: output buffer, indexed by row; each element is assigned, not
            accumulated, so rows without entries are written as zero.
    """
    n_rows = row_ptr.shape[0] - 1
    for row in range(n_rows):
        first = row_ptr[row]
        last = row_ptr[row + 1]
        # Per-row running sum kept in a local, written back once at the end.
        acc = 0.
        for k in range(first, last):
            acc += values[k] * vector[col_indices[k]]
        out[row] = acc


def _sparse_csr_matvec_jvp_values(val_dot, values, col_indices, row_ptr, vector, *, outs, transpose, shape):
Expand Down Expand Up @@ -195,9 +209,15 @@ def csrmv_taichi(
out_shape = shape[1] if transpose else shape[0]

if transpose:
prim = _csr_matvec_transpose_p
if data.shape[0] == 1:
prim = _csr_matvec_transpose_homo_p
else:
prim = _csr_matvec_transpose_heter_p
else:
prim = _csr_matvec_p
if data.shape[0] == 1:
prim = _csr_matvec_homo_p
else:
prim = _csr_matvec_heter_p

return prim(data,
indices,
Expand All @@ -215,10 +235,19 @@ def _define_op(cpu_kernel, gpu_kernel):
return prim


# transpose
_csr_matvec_transpose_p = _define_op(cpu_kernel=_sparse_csr_matvec_transpose_cpu,
gpu_kernel=_sparse_csr_matvec_transpose_gpu)
# transpose homo
_csr_matvec_transpose_homo_p = _define_op(cpu_kernel=_sparse_csr_matvec_transpose_cpu,
gpu_kernel=_sparse_csr_matvec_transpose_homo_gpu)

# no transpose homo
_csr_matvec_homo_p = _define_op(cpu_kernel=_sparse_csr_matvec_cpu,
gpu_kernel=_sparse_csr_matvec_homo_gpu)

# transpose heter
_csr_matvec_transpose_heter_p = _define_op(cpu_kernel=_sparse_csr_matvec_transpose_cpu,
gpu_kernel=_sparse_csr_matvec_transpose_heter_gpu)

# no transpose heter
_csr_matvec_heter_p = _define_op(cpu_kernel=_sparse_csr_matvec_cpu,
gpu_kernel=_sparse_csr_matvec_heter_gpu)

# no transpose
_csr_matvec_p = _define_op(cpu_kernel=_sparse_csr_matvec_cpu,
gpu_kernel=_sparse_csr_matvec_gpu)
Loading

0 comments on commit 6d9ca53

Please sign in to comment.