@@ -161,7 +161,7 @@ def test_event_ell_gpu(s, p, values_type, events_type, transpose):
161
161
heter_data = bm .as_jax (rng .random (csr_indices .shape ))
162
162
weight = heter_data
163
163
164
- groundtruth = bm .as_jax (vector , dtype = float ) @ bm .as_jax (dense )
164
+ # groundtruth = bm.as_jax(vector, dtype=float) @ bm.as_jax(dense)
165
165
166
166
167
167
@@ -261,6 +261,8 @@ def test_event_ell_gpu(s, p, values_type, events_type, transpose):
261
261
return taichi_aot_time1 , taichi_aot_time2 , taichi_aot_time3 , taichi_aot_time4 , taichi_aot_time5 ,\
262
262
brainpy_time1 , brainpy_time2 , brainpy_time3 , brainpy_time4 , brainpy_time5 , speedup
263
263
264
+ PATH = os .path .dirname (os .path .abspath (__file__ ))
265
+
264
266
# init dataframe
265
267
df = pd .DataFrame (columns = ['s' , 'p' , 'backend' , 'values type' , 'events type' , 'transpose' ,
266
268
'taichi aot time1(ms)' , 'taichi aot time2(ms)' , 'taichi aot time3(ms)' , 'taichi aot time4(ms)' , 'taichi aot time5(ms)' ,
@@ -279,7 +281,7 @@ def test_event_ell_gpu(s, p, values_type, events_type, transpose):
279
281
df .loc [df .shape [0 ]] = [_s , _p , 'cpu' , _values_type , _events_type , _transpose ,
280
282
taichi_aot_time_1 , taichi_aot_time_2 , taichi_aot_time_3 , taichi_aot_time_4 , taichi_aot_time_5 ,
281
283
brainpy_time_1 , brainpy_time_2 , brainpy_time_3 , brainpy_time_4 , brainpy_time_5 , speedup ]
282
- df .to_csv (' event_csrmv_cpu.csv' , index = False )
284
+ df .to_csv (f' { PATH } / event_csrmv_cpu.csv' , index = False )
283
285
284
286
if (bm .get_platform () == 'gpu' ):
285
287
for _s in s :
@@ -293,7 +295,7 @@ def test_event_ell_gpu(s, p, values_type, events_type, transpose):
293
295
df .loc [df .shape [0 ]] = [_s , _p , 'gpu' , _values_type , _events_type , transpose ,
294
296
taichi_aot_time_1 , taichi_aot_time_2 , taichi_aot_time_3 , taichi_aot_time_4 , taichi_aot_time_5 ,
295
297
brainpy_time_1 , brainpy_time_2 , brainpy_time_3 , brainpy_time_4 , brainpy_time_5 , speedup ]
296
- df .to_csv (' event_csrmv_gpu.csv' , index = False )
298
+ df .to_csv (f' { PATH } / event_csrmv_gpu.csv' , index = False )
297
299
298
300
# if (bm.get_platform() == 'gpu'):
299
301
# for _s in s:
0 commit comments