Skip to content

Commit 6d9ca53

Browse files
committed
Optimize csr matvec with taichi customized op and Add taichi csr matvec benchmark
1 parent c9923ca commit 6d9ca53

File tree

2 files changed

+361
-33
lines changed

2 files changed

+361
-33
lines changed

brainpy/_src/math/sparse/_csr_mv_taichi.py

Lines changed: 62 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
'csrmv_taichi',
2020
]
2121

22+
### CPU
2223

2324
@ti.kernel
2425
def _sparse_csr_matvec_transpose_cpu(values: ti.types.ndarray(ndim=1),
@@ -63,45 +64,58 @@ def _sparse_csr_matvec_cpu(values: ti.types.ndarray(ndim=1),
6364
r += values[j] * vector[col_indices[j]]
6465
out[row_i] = r
6566

67+
### GPU
68+
# homo
6669

6770
@ti.kernel
68-
def _sparse_csr_matvec_transpose_gpu(values: ti.types.ndarray(ndim=1),
71+
def _sparse_csr_matvec_transpose_homo_gpu(values: ti.types.ndarray(ndim=1),
6972
col_indices: ti.types.ndarray(ndim=1),
7073
row_ptr: ti.types.ndarray(ndim=1),
7174
vector: ti.types.ndarray(ndim=1),
7275
out: ti.types.ndarray(ndim=1)):
73-
if values.shape[0] == 1:
74-
value = values[0]
75-
for row_i in range(row_ptr.shape[0] - 1):
76-
for j in range(row_ptr[row_i], row_ptr[row_i + 1]):
77-
out[col_indices[j]] += value * vector[row_i]
78-
79-
else:
80-
for row_i in range(row_ptr.shape[0] - 1):
81-
for j in range(row_ptr[row_i], row_ptr[row_i + 1]):
82-
out[col_indices[j]] += values[j] * vector[row_i]
76+
value = values[0]
77+
for row_i in range(row_ptr.shape[0] - 1):
78+
for j in range(row_ptr[row_i], row_ptr[row_i + 1]):
79+
out[col_indices[j]] += value * vector[row_i]
8380

8481

8582
@ti.kernel
86-
def _sparse_csr_matvec_gpu(values: ti.types.ndarray(ndim=1),
83+
def _sparse_csr_matvec_homo_gpu(values: ti.types.ndarray(ndim=1),
8784
col_indices: ti.types.ndarray(ndim=1),
8885
row_ptr: ti.types.ndarray(ndim=1),
8986
vector: ti.types.ndarray(ndim=1),
9087
out: ti.types.ndarray(ndim=1)):
91-
if values.shape[0] == 1:
92-
value = values[0]
93-
for row_i in range(row_ptr.shape[0] - 1):
94-
r = 0.
95-
for j in range(row_ptr[row_i], row_ptr[row_i + 1]):
96-
r += value * vector[col_indices[j]]
97-
out[row_i] = r
88+
value = values[0]
89+
for row_i in range(row_ptr.shape[0] - 1):
90+
r = 0.
91+
for j in range(row_ptr[row_i], row_ptr[row_i + 1]):
92+
r += value * vector[col_indices[j]]
93+
out[row_i] = r
9894

99-
else:
100-
for row_i in range(row_ptr.shape[0] - 1):
101-
r = 0.
102-
for j in range(row_ptr[row_i], row_ptr[row_i + 1]):
103-
r += values[j] * vector[col_indices[j]]
104-
out[row_i] = r
95+
# heter
96+
97+
@ti.kernel
98+
def _sparse_csr_matvec_transpose_heter_gpu(values: ti.types.ndarray(ndim=1),
99+
col_indices: ti.types.ndarray(ndim=1),
100+
row_ptr: ti.types.ndarray(ndim=1),
101+
vector: ti.types.ndarray(ndim=1),
102+
out: ti.types.ndarray(ndim=1)):
103+
for row_i in range(row_ptr.shape[0] - 1):
104+
for j in range(row_ptr[row_i], row_ptr[row_i + 1]):
105+
out[col_indices[j]] += values[j] * vector[row_i]
106+
107+
108+
@ti.kernel
109+
def _sparse_csr_matvec_heter_gpu(values: ti.types.ndarray(ndim=1),
110+
col_indices: ti.types.ndarray(ndim=1),
111+
row_ptr: ti.types.ndarray(ndim=1),
112+
vector: ti.types.ndarray(ndim=1),
113+
out: ti.types.ndarray(ndim=1)):
114+
for row_i in range(row_ptr.shape[0] - 1):
115+
r = 0.
116+
for j in range(row_ptr[row_i], row_ptr[row_i + 1]):
117+
r += values[j] * vector[col_indices[j]]
118+
out[row_i] = r
105119

106120

107121
def _sparse_csr_matvec_jvp_values(val_dot, values, col_indices, row_ptr, vector, *, outs, transpose, shape):
@@ -195,9 +209,15 @@ def csrmv_taichi(
195209
out_shape = shape[1] if transpose else shape[0]
196210

197211
if transpose:
198-
prim = _csr_matvec_transpose_p
212+
if data.shape[0] == 1:
213+
prim = _csr_matvec_transpose_homo_p
214+
else:
215+
prim = _csr_matvec_transpose_heter_p
199216
else:
200-
prim = _csr_matvec_p
217+
if data.shape[0] == 1:
218+
prim = _csr_matvec_homo_p
219+
else:
220+
prim = _csr_matvec_heter_p
201221

202222
return prim(data,
203223
indices,
@@ -215,10 +235,19 @@ def _define_op(cpu_kernel, gpu_kernel):
215235
return prim
216236

217237

218-
# transpose
219-
_csr_matvec_transpose_p = _define_op(cpu_kernel=_sparse_csr_matvec_transpose_cpu,
220-
gpu_kernel=_sparse_csr_matvec_transpose_gpu)
238+
# transpose homo
239+
_csr_matvec_transpose_homo_p = _define_op(cpu_kernel=_sparse_csr_matvec_transpose_cpu,
240+
gpu_kernel=_sparse_csr_matvec_transpose_homo_gpu)
241+
242+
# no transpose homo
243+
_csr_matvec_homo_p = _define_op(cpu_kernel=_sparse_csr_matvec_cpu,
244+
gpu_kernel=_sparse_csr_matvec_homo_gpu)
245+
246+
# transpose heter
247+
_csr_matvec_transpose_heter_p = _define_op(cpu_kernel=_sparse_csr_matvec_transpose_cpu,
248+
gpu_kernel=_sparse_csr_matvec_transpose_heter_gpu)
249+
250+
# no transpose heter
251+
_csr_matvec_heter_p = _define_op(cpu_kernel=_sparse_csr_matvec_cpu,
252+
gpu_kernel=_sparse_csr_matvec_heter_gpu)
221253

222-
# no transpose
223-
_csr_matvec_p = _define_op(cpu_kernel=_sparse_csr_matvec_cpu,
224-
gpu_kernel=_sparse_csr_matvec_gpu)

0 commit comments

Comments
 (0)