Skip to content

Commit 66724b1

Browse files
committed
Add more benchmarks
1 parent 3ff5dbc commit 66724b1

File tree

3 files changed

+134
-75
lines changed

3 files changed

+134
-75
lines changed

brainpy/_src/math/event/_csr_matvec_taichi.py

Lines changed: 54 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -126,19 +126,34 @@ def _event_csr_matvec_transpose_bool_homo_gpu(values: ti.types.ndarray(ndim=1),
126126
events: ti.types.ndarray(ndim=1),
127127
out: ti.types.ndarray(ndim=1)):
128128
value = values[0]
129+
# total_rows = indptr.shape[0] - 1
130+
# for i in range(total_rows * 32):
131+
# row_i = ti.cast(ti.floor(i / 32), ti.i32)
132+
# index = i % 32
133+
# if events[row_i]:
134+
# for j in range(indptr[row_i], indptr[row_i + 1]):
135+
# if j % 32 == index:
136+
# out[indices[j]] += value
129137
for row_i in ti.ndrange(indptr.shape[0] - 1):
130138
if events[row_i]:
131139
for j in range(indptr[row_i], indptr[row_i + 1]):
132140
out[indices[j]] += value
133141

134-
135142
@ti.kernel
136143
def _event_csr_matvec_transpose_homo_gpu(values: ti.types.ndarray(ndim=1),
137144
indices: ti.types.ndarray(ndim=1),
138145
indptr: ti.types.ndarray(ndim=1),
139146
events: ti.types.ndarray(ndim=1),
140147
out: ti.types.ndarray(ndim=1)):
141148
value = values[0]
149+
# total_rows = indptr.shape[0] - 1
150+
# for i in range(total_rows * 32):
151+
# row_i = ti.cast(ti.floor(i / 32), ti.i32)
152+
# index = i % 32
153+
# if events[row_i] > 0.:
154+
# for j in range(indptr[row_i], indptr[row_i + 1]):
155+
# if j % 32 == index:
156+
# out[indices[j]] += value
142157
for row_i in ti.ndrange(indptr.shape[0] - 1):
143158
if events[row_i] > 0.:
144159
for j in range(indptr[row_i], indptr[row_i + 1]):
@@ -152,7 +167,16 @@ def _event_csr_matvec_bool_homo_gpu(values: ti.types.ndarray(ndim=1),
152167
events: ti.types.ndarray(ndim=1),
153168
out: ti.types.ndarray(ndim=1)):
154169
value = values[0]
155-
for row_i in ti.ndrange(indptr.shape[0] - 1):
170+
# total_rows = indptr.shape[0] - 1
171+
# for i in ti.ndrange(total_rows * 32):
172+
# row_i = ti.cast(ti.floor(i / 32), ti.i32)
173+
# index = i % 32
174+
# r = 0.
175+
# for j in range(indptr[row_i], indptr[row_i + 1]):
176+
# if j % 32 == index and events[indices[j]]:
177+
# r += value
178+
# out[row_i] += r
179+
for row_i in range(indptr.shape[0] - 1):
156180
r = 0.
157181
for j in range(indptr[row_i], indptr[row_i + 1]):
158182
if events[indices[j]]:
@@ -166,7 +190,7 @@ def _event_csr_matvec_homo_gpu(values: ti.types.ndarray(ndim=1),
166190
events: ti.types.ndarray(ndim=1),
167191
out: ti.types.ndarray(ndim=1)):
168192
value = values[0]
169-
for row_i in ti.ndrange(indptr.shape[0] - 1):
193+
for row_i in range(indptr.shape[0] - 1):
170194
r = 0.
171195
for j in range(indptr[row_i], indptr[row_i + 1]):
172196
if events[indices[j]] > 0.:
@@ -181,6 +205,14 @@ def _event_csr_matvec_transpose_bool_heter_gpu(values: ti.types.ndarray(ndim=1),
181205
indptr: ti.types.ndarray(ndim=1),
182206
events: ti.types.ndarray(ndim=1),
183207
out: ti.types.ndarray(ndim=1)):
208+
# total_rows = indptr.shape[0] - 1
209+
# for i in range(total_rows * 32):
210+
# row_i = ti.cast(ti.floor(i / 32), ti.i32)
211+
# index = i % 32
212+
# if events[row_i]:
213+
# for j in range(indptr[row_i], indptr[row_i + 1]):
214+
# if j % 32 == index:
215+
# out[indices[j]] += values[j]
184216
for row_i in ti.ndrange(indptr.shape[0] - 1):
185217
if events[row_i]:
186218
for j in range(indptr[row_i], indptr[row_i + 1]):
@@ -193,6 +225,14 @@ def _event_csr_matvec_transpose_heter_gpu(values: ti.types.ndarray(ndim=1),
193225
indptr: ti.types.ndarray(ndim=1),
194226
events: ti.types.ndarray(ndim=1),
195227
out: ti.types.ndarray(ndim=1)):
228+
# total_rows = indptr.shape[0] - 1
229+
# for i in range(total_rows * 32):
230+
# row_i = ti.cast(ti.floor(i / 32), ti.i32)
231+
# index = i % 32
232+
# if events[row_i] > 0.:
233+
# for j in range(indptr[row_i], indptr[row_i + 1]):
234+
# if j % 32 == index:
235+
# out[indices[j]] += values[j]
196236
for row_i in ti.ndrange(indptr.shape[0] - 1):
197237
if events[row_i] > 0.:
198238
for j in range(indptr[row_i], indptr[row_i + 1]):
@@ -205,7 +245,16 @@ def _event_csr_matvec_bool_heter_gpu(values: ti.types.ndarray(ndim=1),
205245
indptr: ti.types.ndarray(ndim=1),
206246
events: ti.types.ndarray(ndim=1),
207247
out: ti.types.ndarray(ndim=1)):
208-
for row_i in ti.ndrange(indptr.shape[0] - 1):
248+
# total_rows = indptr.shape[0] - 1
249+
# for i in ti.ndrange(total_rows * 32):
250+
# row_i = ti.cast(ti.floor(i / 32), ti.i32)
251+
# index = i % 32
252+
# r = 0.
253+
# for j in range(indptr[row_i], indptr[row_i + 1]):
254+
# if j % 32 == index and events[indices[j]]:
255+
# r += values[j]
256+
# out[row_i] += r
257+
for row_i in range(indptr.shape[0] - 1):
209258
r = 0.
210259
for j in range(indptr[row_i], indptr[row_i + 1]):
211260
if events[indices[j]]:
@@ -218,7 +267,7 @@ def _event_csr_matvec_heter_gpu(values: ti.types.ndarray(ndim=1),
218267
indptr: ti.types.ndarray(ndim=1),
219268
events: ti.types.ndarray(ndim=1),
220269
out: ti.types.ndarray(ndim=1)):
221-
for row_i in ti.ndrange(indptr.shape[0] - 1):
270+
for row_i in range(indptr.shape[0] - 1):
222271
r = 0.
223272
for j in range(indptr[row_i], indptr[row_i + 1]):
224273
if events[indices[j]] > 0.:

brainpy/_src/math/event/tests/event_csrmv_taichi_VS_event_csrmv.py

Lines changed: 49 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,18 @@
1616

1717
s = [1000, 2500, 5000, 10000, 25000, 50000]
1818
p = [0.1, 0.2, 0.3, 0.4, 0.5]
19-
values_type = ['homo', 'heter']
20-
events_type = ['bool', 'float']
19+
values_type = ['homo',
20+
'heter']
21+
events_type = ['bool',
22+
'float',
23+
]
24+
transpose = [True,
25+
False]
2126

2227
print(bm.get_platform())
2328

2429

25-
def test_event_ell_cpu(s, p, values_type, events_type):
30+
def test_event_ell_cpu(s, p, values_type, events_type, transpose):
2631
print('s: ', s, 'p: ', p)
2732
k = int(s * p)
2833
bm.random.seed(1234)
@@ -39,43 +44,42 @@ def test_event_ell_cpu(s, p, values_type, events_type):
3944
dense[pre_indices, csr_indices] = 1.0
4045

4146
if events_type == 'float':
42-
vector = vector.astype(np.float32)
43-
vector[vector == 1.0] = bm.random.rand(bm.sum(vector == 1.0))
47+
vector = vector.astype(bm.float32)
4448
if values_type == 'heter':
4549
heter_data = bm.as_jax(rng.random(csr_indices.shape))
4650
weight = heter_data
4751

4852
# groundtruth = bm.as_jax(vector, dtype=float) @ bm.as_jax(dense)
4953

50-
result1 = jax.block_until_ready(bm.event.csrmv_taichi(weight, csr_indices, csr_indptr, vector, shape=(s, s), transpose=True))
54+
result1 = jax.block_until_ready(bm.event.csrmv_taichi(weight, csr_indices, csr_indptr, vector, shape=(s, s), transpose=transpose))
5155
# time.sleep(2)
5256

5357
time0 = time.time()
54-
result1 = jax.block_until_ready(bm.event.csrmv_taichi(weight, csr_indices, csr_indptr, vector, shape=(s, s), transpose=True))
58+
result1 = jax.block_until_ready(bm.event.csrmv_taichi(weight, csr_indices, csr_indptr, vector, shape=(s, s), transpose=transpose))
5559
time1 = time.time()
5660
# time.sleep(2)
5761

5862
time2 = time.time()
59-
result1 = jax.block_until_ready(bm.event.csrmv_taichi(weight, csr_indices, csr_indptr, vector, shape=(s, s), transpose=True))
63+
result1 = jax.block_until_ready(bm.event.csrmv_taichi(weight, csr_indices, csr_indptr, vector, shape=(s, s), transpose=transpose))
6064
time3 = time.time()
6165
# time.sleep(2)
6266

6367
time4 = time.time()
64-
result1 = jax.block_until_ready(bm.event.csrmv_taichi(weight, csr_indices, csr_indptr, vector, shape=(s, s), transpose=True))
68+
result1 = jax.block_until_ready(bm.event.csrmv_taichi(weight, csr_indices, csr_indptr, vector, shape=(s, s), transpose=transpose))
6569
time5 = time.time()
6670
# time.sleep(2)
6771

6872
time6 = time.time()
69-
result1 = jax.block_until_ready(bm.event.csrmv_taichi(weight, csr_indices, csr_indptr, vector, shape=(s, s), transpose=True))
73+
result1 = jax.block_until_ready(bm.event.csrmv_taichi(weight, csr_indices, csr_indptr, vector, shape=(s, s), transpose=transpose))
7074
time7 = time.time()
7175

7276
time8 = time.time()
73-
result1 = jax.block_until_ready(bm.event.csrmv_taichi(weight, csr_indices, csr_indptr, vector, shape=(s, s), transpose=True))
77+
result1 = jax.block_until_ready(bm.event.csrmv_taichi(weight, csr_indices, csr_indptr, vector, shape=(s, s), transpose=transpose))
7478
time9 = time.time()
7579

76-
result2 = jax.block_until_ready(bm.event.csrmv(weight, csr_indices, csr_indptr, vector, shape=(s, s), transpose=True))
77-
# print(result1[0])
78-
# print(result2)
80+
result2 = jax.block_until_ready(bm.event.csrmv(weight, csr_indices, csr_indptr, vector, shape=(s, s), transpose=transpose))
81+
# print(result1[0])
82+
# print(result2)
7983
# print(groundtruth - result1[0])
8084
# print(groundtruth - result2)
8185

@@ -85,26 +89,26 @@ def test_event_ell_cpu(s, p, values_type, events_type):
8589
# assert bm.allclose(result1[0], result2)
8690

8791
time12 = time.time()
88-
result2 = jax.block_until_ready(bm.event.csrmv(weight, csr_indices, csr_indptr, vector, shape=(s, s), transpose=True))
92+
result2 = jax.block_until_ready(bm.event.csrmv(weight, csr_indices, csr_indptr, vector, shape=(s, s), transpose=transpose))
8993
time13 = time.time()
9094
# time.sleep(2)
9195

9296
time14 = time.time()
93-
result2 = jax.block_until_ready(bm.event.csrmv(weight, csr_indices, csr_indptr, vector, shape=(s, s), transpose=True))
97+
result2 = jax.block_until_ready(bm.event.csrmv(weight, csr_indices, csr_indptr, vector, shape=(s, s), transpose=transpose))
9498
time15 = time.time()
9599
# time.sleep(2)
96100

97101
time16 = time.time()
98-
result2 = jax.block_until_ready(bm.event.csrmv(weight, csr_indices, csr_indptr, vector, shape=(s, s), transpose=True))
102+
result2 = jax.block_until_ready(bm.event.csrmv(weight, csr_indices, csr_indptr, vector, shape=(s, s), transpose=transpose))
99103
time17 = time.time()
100104
# time.sleep(2)
101105

102106
time18 = time.time()
103-
result2 = jax.block_until_ready(bm.event.csrmv(weight, csr_indices, csr_indptr, vector, shape=(s, s), transpose=True))
107+
result2 = jax.block_until_ready(bm.event.csrmv(weight, csr_indices, csr_indptr, vector, shape=(s, s), transpose=transpose))
104108
time19 = time.time()
105109

106110
time20 = time.time()
107-
result2 = jax.block_until_ready(bm.event.csrmv(weight, csr_indices, csr_indptr, vector, shape=(s, s), transpose=True))
111+
result2 = jax.block_until_ready(bm.event.csrmv(weight, csr_indices, csr_indptr, vector, shape=(s, s), transpose=transpose))
108112
time21 = time.time()
109113

110114
taichi_aot_time1 = (time1 - time0) * 1000
@@ -136,7 +140,7 @@ def test_event_ell_cpu(s, p, values_type, events_type):
136140
return taichi_aot_time1, taichi_aot_time2, taichi_aot_time3, taichi_aot_time4, taichi_aot_time5,\
137141
brainpy_time1, brainpy_time2, brainpy_time3, brainpy_time4, brainpy_time5, speedup
138142

139-
def test_event_ell_gpu(s, p, values_type, events_type):
143+
def test_event_ell_gpu(s, p, values_type, events_type, transpose):
140144
print('s: ', s, 'p: ', p)
141145
k = int(s * p)
142146
bm.random.seed(1234)
@@ -152,8 +156,7 @@ def test_event_ell_gpu(s, p, values_type, events_type):
152156
dense[pre_indices, csr_indices] = 1.0
153157

154158
if events_type == 'float':
155-
vector = vector.astype(np.float32)
156-
vector[vector == 1.0] = bm.random.rand(bm.sum(vector == 1.0))
159+
vector = vector.astype(bm.float32)
157160
if values_type == 'heter':
158161
heter_data = bm.as_jax(rng.random(csr_indices.shape))
159162
weight = heter_data
@@ -162,37 +165,39 @@ def test_event_ell_gpu(s, p, values_type, events_type):
162165

163166

164167

165-
result1 = jax.block_until_ready(bm.event.csrmv_taichi(weight, csr_indices, csr_indptr, vector, shape=(s, s), transpose=True))
168+
result1 = jax.block_until_ready(bm.event.csrmv_taichi(weight, csr_indices, csr_indptr, vector, shape=(s, s), transpose=transpose))
166169
# time.sleep(2)
167170

168171
time0 = time.time()
169-
result1 = jax.block_until_ready(bm.event.csrmv_taichi(weight, csr_indices, csr_indptr, vector, shape=(s, s), transpose=True))
172+
result1 = jax.block_until_ready(bm.event.csrmv_taichi(weight, csr_indices, csr_indptr, vector, shape=(s, s), transpose=transpose))
170173
time1 = time.time()
171174
# time.sleep(2)
172175

173176
time2 = time.time()
174-
result1 = jax.block_until_ready(bm.event.csrmv_taichi(weight, csr_indices, csr_indptr, vector, shape=(s, s), transpose=True))
177+
result1 = jax.block_until_ready(bm.event.csrmv_taichi(weight, csr_indices, csr_indptr, vector, shape=(s, s), transpose=transpose))
175178
time3 = time.time()
176179
# time.sleep(2)
177180

178181
time4 = time.time()
179-
result1 = jax.block_until_ready(bm.event.csrmv_taichi(weight, csr_indices, csr_indptr, vector, shape=(s, s), transpose=True))
182+
result1 = jax.block_until_ready(bm.event.csrmv_taichi(weight, csr_indices, csr_indptr, vector, shape=(s, s), transpose=transpose))
180183
time5 = time.time()
181184
# time.sleep(2)
182185

183186
time6 = time.time()
184-
result1 = jax.block_until_ready(bm.event.csrmv_taichi(weight, csr_indices, csr_indptr, vector, shape=(s, s), transpose=True))
187+
result1 = jax.block_until_ready(bm.event.csrmv_taichi(weight, csr_indices, csr_indptr, vector, shape=(s, s), transpose=transpose))
185188
time7 = time.time()
186189

187190
time8 = time.time()
188-
result1 = jax.block_until_ready(bm.event.csrmv_taichi(weight, csr_indices, csr_indptr, vector, shape=(s, s), transpose=True))
191+
result1 = jax.block_until_ready(bm.event.csrmv_taichi(weight, csr_indices, csr_indptr, vector, shape=(s, s), transpose=transpose))
189192
time9 = time.time()
190193

191-
result2 = jax.block_until_ready(bm.event.csrmv(weight, csr_indices, csr_indptr, vector, shape=(s, s), transpose=True))
194+
result2 = jax.block_until_ready(bm.event.csrmv(weight, csr_indices, csr_indptr, vector, shape=(s, s), transpose=transpose))
192195
# print('--------------------result1[0]------------------')
193196
# print(result1[0])
194197
# print('--------------------result2------------------')
195198
# print(result2)
199+
# print('--------------------gt------------------')
200+
# print(groundtruth)
196201
# print('--------------------gt - result1[0]------------------')
197202
# print(groundtruth - result1[0])
198203
# print('--------------------gt - result2------------------')
@@ -204,26 +209,26 @@ def test_event_ell_gpu(s, p, values_type, events_type):
204209
# assert bm.allclose(result1[0], result2)
205210

206211
time12 = time.time()
207-
result2 = jax.block_until_ready(bm.event.csrmv(weight, csr_indices, csr_indptr, vector, shape=(s, s), transpose=True))
212+
result2 = jax.block_until_ready(bm.event.csrmv(weight, csr_indices, csr_indptr, vector, shape=(s, s), transpose=transpose))
208213
time13 = time.time()
209214
# time.sleep(2)
210215

211216
time14 = time.time()
212-
result2 = jax.block_until_ready(bm.event.csrmv(weight, csr_indices, csr_indptr, vector, shape=(s, s), transpose=True))
217+
result2 = jax.block_until_ready(bm.event.csrmv(weight, csr_indices, csr_indptr, vector, shape=(s, s), transpose=transpose))
213218
time15 = time.time()
214219
# time.sleep(2)
215220

216221
time16 = time.time()
217-
result2 = jax.block_until_ready(bm.event.csrmv(weight, csr_indices, csr_indptr, vector, shape=(s, s), transpose=True))
222+
result2 = jax.block_until_ready(bm.event.csrmv(weight, csr_indices, csr_indptr, vector, shape=(s, s), transpose=transpose))
218223
time17 = time.time()
219224
# time.sleep(2)
220225

221226
time18 = time.time()
222-
result2 = jax.block_until_ready(bm.event.csrmv(weight, csr_indices, csr_indptr, vector, shape=(s, s), transpose=True))
227+
result2 = jax.block_until_ready(bm.event.csrmv(weight, csr_indices, csr_indptr, vector, shape=(s, s), transpose=transpose))
223228
time19 = time.time()
224229

225230
time20 = time.time()
226-
result2 = jax.block_until_ready(bm.event.csrmv(weight, csr_indices, csr_indptr, vector, shape=(s, s), transpose=True))
231+
result2 = jax.block_until_ready(bm.event.csrmv(weight, csr_indices, csr_indptr, vector, shape=(s, s), transpose=transpose))
227232
time21 = time.time()
228233

229234
taichi_aot_time1 = (time1 - time0) * 1000
@@ -236,7 +241,7 @@ def test_event_ell_gpu(s, p, values_type, events_type):
236241
brainpy_time3 = (time17 - time16) * 1000
237242
brainpy_time4 = (time19 - time18) * 1000
238243
brainpy_time5 = (time21 - time20) * 1000
239-
244+
print('s: ', s, 'p: ', p, 'values_type: ', values_type, 'events_type: ', events_type, 'transpose: ', transpose)
240245
print('taichi_aot_1: ', taichi_aot_time1, 'ms')
241246
print('taichi_aot_2: ', taichi_aot_time2, 'ms')
242247
print('taichi_aot_3: ', taichi_aot_time3, 'ms')
@@ -257,7 +262,7 @@ def test_event_ell_gpu(s, p, values_type, events_type):
257262
brainpy_time1, brainpy_time2, brainpy_time3, brainpy_time4, brainpy_time5, speedup
258263

259264
# init dataframe
260-
df = pd.DataFrame(columns=['s', 'p', 'backend', 'values type', 'events type',
265+
df = pd.DataFrame(columns=['s', 'p', 'backend', 'values type', 'events type', 'transpose',
261266
'taichi aot time1(ms)', 'taichi aot time2(ms)', 'taichi aot time3(ms)', 'taichi aot time4(ms)', 'taichi aot time5(ms)',
262267
'brainpy time1(ms)', 'brainpy time2(ms)', 'brainpy time3(ms)', 'brainpy time4(ms)', 'brainpy time5(ms)',
263268
'speedup'])
@@ -267,10 +272,11 @@ def test_event_ell_gpu(s, p, values_type, events_type):
267272
for _p in p:
268273
for _values_type in values_type:
269274
for _events_type in events_type:
270-
taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5,\
271-
brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, speedup = test_event_ell_cpu(_s, _p, _values_type, _events_type)
275+
for _transpose in transpose:
276+
taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5,\
277+
brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, speedup = test_event_ell_cpu(_s, _p, _values_type, _events_type, _transpose)
272278
# append to dataframe
273-
df.loc[df.shape[0]] = [_s, _p, 'cpu', _values_type, _events_type,
279+
df.loc[df.shape[0]] = [_s, _p, 'cpu', _values_type, _events_type, _transpose,
274280
taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5,
275281
brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, speedup]
276282
df.to_csv('event_csrmv_cpu.csv', index=False)
@@ -280,10 +286,11 @@ def test_event_ell_gpu(s, p, values_type, events_type):
280286
for _p in p:
281287
for _values_type in values_type:
282288
for _events_type in events_type:
283-
taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5,\
284-
brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, speedup = test_event_ell_gpu(_s, _p, _values_type, _events_type)
289+
for _transpose in transpose:
290+
taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5,\
291+
brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, speedup = test_event_ell_gpu(_s, _p, _values_type, _events_type, _transpose)
285292
# append to dataframe
286-
df.loc[df.shape[0]] = [_s, _p, 'gpu', _values_type, _events_type,
293+
df.loc[df.shape[0]] = [_s, _p, 'gpu', _values_type, _events_type, _transpose,
287294
taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5,
288295
brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, speedup]
289296
df.to_csv('event_csrmv_gpu.csv', index=False)

0 commit comments

Comments
 (0)