19
19
'csrmv_taichi' ,
20
20
]
21
21
22
+ ### CPU
22
23
23
24
@ti .kernel
24
25
def _sparse_csr_matvec_transpose_cpu (values : ti .types .ndarray (ndim = 1 ),
@@ -63,45 +64,58 @@ def _sparse_csr_matvec_cpu(values: ti.types.ndarray(ndim=1),
63
64
r += values [j ] * vector [col_indices [j ]]
64
65
out [row_i ] = r
65
66
67
+ ### GPU
68
+ # homo
66
69
67
70
# Transposed CSR matrix-vector product for a homogeneous matrix, i.e. one in
# which every stored entry shares the single weight ``values[0]``.
#
# For each row ``row`` and each stored entry ``k`` in the half-open range
# ``row_ptr[row] .. row_ptr[row + 1]``, the kernel adds
# ``values[0] * vector[row]`` to ``out[col_indices[k]]``.
#
# ``out`` is only accumulated into; it is never cleared here.
# NOTE(review): presumably the caller supplies a zero-filled ``out`` -- confirm.
@ti.kernel
def _sparse_csr_matvec_transpose_homo_gpu(values: ti.types.ndarray(ndim=1),
                                          col_indices: ti.types.ndarray(ndim=1),
                                          row_ptr: ti.types.ndarray(ndim=1),
                                          vector: ti.types.ndarray(ndim=1),
                                          out: ti.types.ndarray(ndim=1)):
    # Only the first element of ``values`` is ever read by this kernel.
    weight = values[0]
    for row in range(row_ptr.shape[0] - 1):
        for k in range(row_ptr[row], row_ptr[row + 1]):
            # Scatter: several rows may target the same output column.
            # NOTE(review): relies on Taichi's handling of ``+=`` inside a
            # parallelised outer loop -- confirm against the Taichi docs.
            col = col_indices[k]
            out[col] += weight * vector[row]
83
80
84
81
85
82
# CSR matrix-vector product for a homogeneous matrix, i.e. one in which every
# stored entry shares the single weight ``values[0]``.
#
# For each row ``row`` the kernel sums ``values[0] * vector[col_indices[k]]``
# over the stored entries ``k`` in ``row_ptr[row] .. row_ptr[row + 1]`` and
# writes the sum to ``out[row]`` (overwriting whatever was there).
@ti.kernel
def _sparse_csr_matvec_homo_gpu(values: ti.types.ndarray(ndim=1),
                                col_indices: ti.types.ndarray(ndim=1),
                                row_ptr: ti.types.ndarray(ndim=1),
                                vector: ti.types.ndarray(ndim=1),
                                out: ti.types.ndarray(ndim=1)):
    # Only the first element of ``values`` is ever read by this kernel.
    weight = values[0]
    for row in range(row_ptr.shape[0] - 1):
        # Per-row accumulator; a row with no stored entries yields 0.
        acc = 0.
        for k in range(row_ptr[row], row_ptr[row + 1]):
            acc += weight * vector[col_indices[k]]
        out[row] = acc
95
+ # heter
96
+
97
# Transposed CSR matrix-vector product for a heterogeneous matrix, i.e. one in
# which each stored entry ``k`` has its own weight ``values[k]``.
#
# For each row ``row`` and each stored entry ``k`` in the half-open range
# ``row_ptr[row] .. row_ptr[row + 1]``, the kernel adds
# ``values[k] * vector[row]`` to ``out[col_indices[k]]``.
#
# ``out`` is only accumulated into; it is never cleared here.
# NOTE(review): presumably the caller supplies a zero-filled ``out`` -- confirm.
@ti.kernel
def _sparse_csr_matvec_transpose_heter_gpu(values: ti.types.ndarray(ndim=1),
                                           col_indices: ti.types.ndarray(ndim=1),
                                           row_ptr: ti.types.ndarray(ndim=1),
                                           vector: ti.types.ndarray(ndim=1),
                                           out: ti.types.ndarray(ndim=1)):
    for row in range(row_ptr.shape[0] - 1):
        for k in range(row_ptr[row], row_ptr[row + 1]):
            # Scatter: several rows may target the same output column.
            # NOTE(review): relies on Taichi's handling of ``+=`` inside a
            # parallelised outer loop -- confirm against the Taichi docs.
            col = col_indices[k]
            out[col] += values[k] * vector[row]
106
+
107
+
108
# CSR matrix-vector product for a heterogeneous matrix, i.e. one in which each
# stored entry ``k`` has its own weight ``values[k]``.
#
# For each row ``row`` the kernel sums ``values[k] * vector[col_indices[k]]``
# over the stored entries ``k`` in ``row_ptr[row] .. row_ptr[row + 1]`` and
# writes the sum to ``out[row]`` (overwriting whatever was there).
@ti.kernel
def _sparse_csr_matvec_heter_gpu(values: ti.types.ndarray(ndim=1),
                                 col_indices: ti.types.ndarray(ndim=1),
                                 row_ptr: ti.types.ndarray(ndim=1),
                                 vector: ti.types.ndarray(ndim=1),
                                 out: ti.types.ndarray(ndim=1)):
    for row in range(row_ptr.shape[0] - 1):
        # Per-row accumulator; a row with no stored entries yields 0.
        acc = 0.
        for k in range(row_ptr[row], row_ptr[row + 1]):
            acc += values[k] * vector[col_indices[k]]
        out[row] = acc
105
119
106
120
107
121
def _sparse_csr_matvec_jvp_values (val_dot , values , col_indices , row_ptr , vector , * , outs , transpose , shape ):
@@ -195,9 +209,15 @@ def csrmv_taichi(
195
209
out_shape = shape [1 ] if transpose else shape [0 ]
196
210
197
211
if transpose :
198
- prim = _csr_matvec_transpose_p
212
+ if data .shape [0 ] == 1 :
213
+ prim = _csr_matvec_transpose_homo_p
214
+ else :
215
+ prim = _csr_matvec_transpose_heter_p
199
216
else :
200
- prim = _csr_matvec_p
217
+ if data .shape [0 ] == 1 :
218
+ prim = _csr_matvec_homo_p
219
+ else :
220
+ prim = _csr_matvec_heter_p
201
221
202
222
return prim (data ,
203
223
indices ,
@@ -215,10 +235,19 @@ def _define_op(cpu_kernel, gpu_kernel):
215
235
return prim
216
236
217
237
218
# Primitive registrations for the four CSR matvec variants.
#
# The homo and heter primitives of each orientation share the same CPU kernel;
# only the GPU kernel differs between them.

# transpose homo
_csr_matvec_transpose_homo_p = _define_op(
    cpu_kernel=_sparse_csr_matvec_transpose_cpu,
    gpu_kernel=_sparse_csr_matvec_transpose_homo_gpu,
)

# no transpose homo
_csr_matvec_homo_p = _define_op(
    cpu_kernel=_sparse_csr_matvec_cpu,
    gpu_kernel=_sparse_csr_matvec_homo_gpu,
)

# transpose heter
_csr_matvec_transpose_heter_p = _define_op(
    cpu_kernel=_sparse_csr_matvec_transpose_cpu,
    gpu_kernel=_sparse_csr_matvec_transpose_heter_gpu,
)

# no transpose heter
_csr_matvec_heter_p = _define_op(
    cpu_kernel=_sparse_csr_matvec_cpu,
    gpu_kernel=_sparse_csr_matvec_heter_gpu,
)
0 commit comments