16
16
17
17
s = [1000 , 2500 , 5000 , 10000 , 25000 , 50000 ]
18
18
p = [0.1 , 0.2 , 0.3 , 0.4 , 0.5 ]
19
- values_type = ['homo' , 'heter' ]
20
- events_type = ['bool' , 'float' ]
19
+ values_type = ['homo' ,
20
+ 'heter' ]
21
+ events_type = ['bool' ,
22
+ 'float' ,
23
+ ]
24
+ transpose = [True ,
25
+ False ]
21
26
22
27
print (bm .get_platform ())
23
28
24
29
25
- def test_event_ell_cpu (s , p , values_type , events_type ):
30
+ def test_event_ell_cpu (s , p , values_type , events_type , transpose ):
26
31
print ('s: ' , s , 'p: ' , p )
27
32
k = int (s * p )
28
33
bm .random .seed (1234 )
@@ -39,43 +44,42 @@ def test_event_ell_cpu(s, p, values_type, events_type):
39
44
dense [pre_indices , csr_indices ] = 1.0
40
45
41
46
if events_type == 'float' :
42
- vector = vector .astype (np .float32 )
43
- vector [vector == 1.0 ] = bm .random .rand (bm .sum (vector == 1.0 ))
47
+ vector = vector .astype (bm .float32 )
44
48
if values_type == 'heter' :
45
49
heter_data = bm .as_jax (rng .random (csr_indices .shape ))
46
50
weight = heter_data
47
51
48
52
# groundtruth = bm.as_jax(vector, dtype=float) @ bm.as_jax(dense)
49
53
50
- result1 = jax .block_until_ready (bm .event .csrmv_taichi (weight , csr_indices , csr_indptr , vector , shape = (s , s ), transpose = True ))
54
+ result1 = jax .block_until_ready (bm .event .csrmv_taichi (weight , csr_indices , csr_indptr , vector , shape = (s , s ), transpose = transpose ))
51
55
# time.sleep(2)
52
56
53
57
time0 = time .time ()
54
- result1 = jax .block_until_ready (bm .event .csrmv_taichi (weight , csr_indices , csr_indptr , vector , shape = (s , s ), transpose = True ))
58
+ result1 = jax .block_until_ready (bm .event .csrmv_taichi (weight , csr_indices , csr_indptr , vector , shape = (s , s ), transpose = transpose ))
55
59
time1 = time .time ()
56
60
# time.sleep(2)
57
61
58
62
time2 = time .time ()
59
- result1 = jax .block_until_ready (bm .event .csrmv_taichi (weight , csr_indices , csr_indptr , vector , shape = (s , s ), transpose = True ))
63
+ result1 = jax .block_until_ready (bm .event .csrmv_taichi (weight , csr_indices , csr_indptr , vector , shape = (s , s ), transpose = transpose ))
60
64
time3 = time .time ()
61
65
# time.sleep(2)
62
66
63
67
time4 = time .time ()
64
- result1 = jax .block_until_ready (bm .event .csrmv_taichi (weight , csr_indices , csr_indptr , vector , shape = (s , s ), transpose = True ))
68
+ result1 = jax .block_until_ready (bm .event .csrmv_taichi (weight , csr_indices , csr_indptr , vector , shape = (s , s ), transpose = transpose ))
65
69
time5 = time .time ()
66
70
# time.sleep(2)
67
71
68
72
time6 = time .time ()
69
- result1 = jax .block_until_ready (bm .event .csrmv_taichi (weight , csr_indices , csr_indptr , vector , shape = (s , s ), transpose = True ))
73
+ result1 = jax .block_until_ready (bm .event .csrmv_taichi (weight , csr_indices , csr_indptr , vector , shape = (s , s ), transpose = transpose ))
70
74
time7 = time .time ()
71
75
72
76
time8 = time .time ()
73
- result1 = jax .block_until_ready (bm .event .csrmv_taichi (weight , csr_indices , csr_indptr , vector , shape = (s , s ), transpose = True ))
77
+ result1 = jax .block_until_ready (bm .event .csrmv_taichi (weight , csr_indices , csr_indptr , vector , shape = (s , s ), transpose = transpose ))
74
78
time9 = time .time ()
75
79
76
- result2 = jax .block_until_ready (bm .event .csrmv (weight , csr_indices , csr_indptr , vector , shape = (s , s ), transpose = True ))
77
- # print(result1[0])
78
- # print(result2)
80
+ result2 = jax .block_until_ready (bm .event .csrmv (weight , csr_indices , csr_indptr , vector , shape = (s , s ), transpose = transpose ))
81
+ # print(result1[0])
82
+ # print(result2)
79
83
# print(groundtruth - result1[0])
80
84
# print(groundtruth - result2)
81
85
@@ -85,26 +89,26 @@ def test_event_ell_cpu(s, p, values_type, events_type):
85
89
# assert bm.allclose(result1[0], result2)
86
90
87
91
time12 = time .time ()
88
- result2 = jax .block_until_ready (bm .event .csrmv (weight , csr_indices , csr_indptr , vector , shape = (s , s ), transpose = True ))
92
+ result2 = jax .block_until_ready (bm .event .csrmv (weight , csr_indices , csr_indptr , vector , shape = (s , s ), transpose = transpose ))
89
93
time13 = time .time ()
90
94
# time.sleep(2)
91
95
92
96
time14 = time .time ()
93
- result2 = jax .block_until_ready (bm .event .csrmv (weight , csr_indices , csr_indptr , vector , shape = (s , s ), transpose = True ))
97
+ result2 = jax .block_until_ready (bm .event .csrmv (weight , csr_indices , csr_indptr , vector , shape = (s , s ), transpose = transpose ))
94
98
time15 = time .time ()
95
99
# time.sleep(2)
96
100
97
101
time16 = time .time ()
98
- result2 = jax .block_until_ready (bm .event .csrmv (weight , csr_indices , csr_indptr , vector , shape = (s , s ), transpose = True ))
102
+ result2 = jax .block_until_ready (bm .event .csrmv (weight , csr_indices , csr_indptr , vector , shape = (s , s ), transpose = transpose ))
99
103
time17 = time .time ()
100
104
# time.sleep(2)
101
105
102
106
time18 = time .time ()
103
- result2 = jax .block_until_ready (bm .event .csrmv (weight , csr_indices , csr_indptr , vector , shape = (s , s ), transpose = True ))
107
+ result2 = jax .block_until_ready (bm .event .csrmv (weight , csr_indices , csr_indptr , vector , shape = (s , s ), transpose = transpose ))
104
108
time19 = time .time ()
105
109
106
110
time20 = time .time ()
107
- result2 = jax .block_until_ready (bm .event .csrmv (weight , csr_indices , csr_indptr , vector , shape = (s , s ), transpose = True ))
111
+ result2 = jax .block_until_ready (bm .event .csrmv (weight , csr_indices , csr_indptr , vector , shape = (s , s ), transpose = transpose ))
108
112
time21 = time .time ()
109
113
110
114
taichi_aot_time1 = (time1 - time0 ) * 1000
@@ -136,7 +140,7 @@ def test_event_ell_cpu(s, p, values_type, events_type):
136
140
return taichi_aot_time1 , taichi_aot_time2 , taichi_aot_time3 , taichi_aot_time4 , taichi_aot_time5 ,\
137
141
brainpy_time1 , brainpy_time2 , brainpy_time3 , brainpy_time4 , brainpy_time5 , speedup
138
142
139
- def test_event_ell_gpu (s , p , values_type , events_type ):
143
+ def test_event_ell_gpu (s , p , values_type , events_type , transpose ):
140
144
print ('s: ' , s , 'p: ' , p )
141
145
k = int (s * p )
142
146
bm .random .seed (1234 )
@@ -152,8 +156,7 @@ def test_event_ell_gpu(s, p, values_type, events_type):
152
156
dense [pre_indices , csr_indices ] = 1.0
153
157
154
158
if events_type == 'float' :
155
- vector = vector .astype (np .float32 )
156
- vector [vector == 1.0 ] = bm .random .rand (bm .sum (vector == 1.0 ))
159
+ vector = vector .astype (bm .float32 )
157
160
if values_type == 'heter' :
158
161
heter_data = bm .as_jax (rng .random (csr_indices .shape ))
159
162
weight = heter_data
@@ -162,37 +165,39 @@ def test_event_ell_gpu(s, p, values_type, events_type):
162
165
163
166
164
167
165
- result1 = jax .block_until_ready (bm .event .csrmv_taichi (weight , csr_indices , csr_indptr , vector , shape = (s , s ), transpose = True ))
168
+ result1 = jax .block_until_ready (bm .event .csrmv_taichi (weight , csr_indices , csr_indptr , vector , shape = (s , s ), transpose = transpose ))
166
169
# time.sleep(2)
167
170
168
171
time0 = time .time ()
169
- result1 = jax .block_until_ready (bm .event .csrmv_taichi (weight , csr_indices , csr_indptr , vector , shape = (s , s ), transpose = True ))
172
+ result1 = jax .block_until_ready (bm .event .csrmv_taichi (weight , csr_indices , csr_indptr , vector , shape = (s , s ), transpose = transpose ))
170
173
time1 = time .time ()
171
174
# time.sleep(2)
172
175
173
176
time2 = time .time ()
174
- result1 = jax .block_until_ready (bm .event .csrmv_taichi (weight , csr_indices , csr_indptr , vector , shape = (s , s ), transpose = True ))
177
+ result1 = jax .block_until_ready (bm .event .csrmv_taichi (weight , csr_indices , csr_indptr , vector , shape = (s , s ), transpose = transpose ))
175
178
time3 = time .time ()
176
179
# time.sleep(2)
177
180
178
181
time4 = time .time ()
179
- result1 = jax .block_until_ready (bm .event .csrmv_taichi (weight , csr_indices , csr_indptr , vector , shape = (s , s ), transpose = True ))
182
+ result1 = jax .block_until_ready (bm .event .csrmv_taichi (weight , csr_indices , csr_indptr , vector , shape = (s , s ), transpose = transpose ))
180
183
time5 = time .time ()
181
184
# time.sleep(2)
182
185
183
186
time6 = time .time ()
184
- result1 = jax .block_until_ready (bm .event .csrmv_taichi (weight , csr_indices , csr_indptr , vector , shape = (s , s ), transpose = True ))
187
+ result1 = jax .block_until_ready (bm .event .csrmv_taichi (weight , csr_indices , csr_indptr , vector , shape = (s , s ), transpose = transpose ))
185
188
time7 = time .time ()
186
189
187
190
time8 = time .time ()
188
- result1 = jax .block_until_ready (bm .event .csrmv_taichi (weight , csr_indices , csr_indptr , vector , shape = (s , s ), transpose = True ))
191
+ result1 = jax .block_until_ready (bm .event .csrmv_taichi (weight , csr_indices , csr_indptr , vector , shape = (s , s ), transpose = transpose ))
189
192
time9 = time .time ()
190
193
191
- result2 = jax .block_until_ready (bm .event .csrmv (weight , csr_indices , csr_indptr , vector , shape = (s , s ), transpose = True ))
194
+ result2 = jax .block_until_ready (bm .event .csrmv (weight , csr_indices , csr_indptr , vector , shape = (s , s ), transpose = transpose ))
192
195
# print('--------------------result1[0]------------------')
193
196
# print(result1[0])
194
197
# print('--------------------result2------------------')
195
198
# print(result2)
199
+ # print('--------------------gt------------------')
200
+ # print(groundtruth)
196
201
# print('--------------------gt - result1[0]------------------')
197
202
# print(groundtruth - result1[0])
198
203
# print('--------------------gt - result2------------------')
@@ -204,26 +209,26 @@ def test_event_ell_gpu(s, p, values_type, events_type):
204
209
# assert bm.allclose(result1[0], result2)
205
210
206
211
time12 = time .time ()
207
- result2 = jax .block_until_ready (bm .event .csrmv (weight , csr_indices , csr_indptr , vector , shape = (s , s ), transpose = True ))
212
+ result2 = jax .block_until_ready (bm .event .csrmv (weight , csr_indices , csr_indptr , vector , shape = (s , s ), transpose = transpose ))
208
213
time13 = time .time ()
209
214
# time.sleep(2)
210
215
211
216
time14 = time .time ()
212
- result2 = jax .block_until_ready (bm .event .csrmv (weight , csr_indices , csr_indptr , vector , shape = (s , s ), transpose = True ))
217
+ result2 = jax .block_until_ready (bm .event .csrmv (weight , csr_indices , csr_indptr , vector , shape = (s , s ), transpose = transpose ))
213
218
time15 = time .time ()
214
219
# time.sleep(2)
215
220
216
221
time16 = time .time ()
217
- result2 = jax .block_until_ready (bm .event .csrmv (weight , csr_indices , csr_indptr , vector , shape = (s , s ), transpose = True ))
222
+ result2 = jax .block_until_ready (bm .event .csrmv (weight , csr_indices , csr_indptr , vector , shape = (s , s ), transpose = transpose ))
218
223
time17 = time .time ()
219
224
# time.sleep(2)
220
225
221
226
time18 = time .time ()
222
- result2 = jax .block_until_ready (bm .event .csrmv (weight , csr_indices , csr_indptr , vector , shape = (s , s ), transpose = True ))
227
+ result2 = jax .block_until_ready (bm .event .csrmv (weight , csr_indices , csr_indptr , vector , shape = (s , s ), transpose = transpose ))
223
228
time19 = time .time ()
224
229
225
230
time20 = time .time ()
226
- result2 = jax .block_until_ready (bm .event .csrmv (weight , csr_indices , csr_indptr , vector , shape = (s , s ), transpose = True ))
231
+ result2 = jax .block_until_ready (bm .event .csrmv (weight , csr_indices , csr_indptr , vector , shape = (s , s ), transpose = transpose ))
227
232
time21 = time .time ()
228
233
229
234
taichi_aot_time1 = (time1 - time0 ) * 1000
@@ -236,7 +241,7 @@ def test_event_ell_gpu(s, p, values_type, events_type):
236
241
brainpy_time3 = (time17 - time16 ) * 1000
237
242
brainpy_time4 = (time19 - time18 ) * 1000
238
243
brainpy_time5 = (time21 - time20 ) * 1000
239
-
244
+ print ( 's: ' , s , 'p: ' , p , 'values_type: ' , values_type , 'events_type: ' , events_type , 'transpose: ' , transpose )
240
245
print ('taichi_aot_1: ' , taichi_aot_time1 , 'ms' )
241
246
print ('taichi_aot_2: ' , taichi_aot_time2 , 'ms' )
242
247
print ('taichi_aot_3: ' , taichi_aot_time3 , 'ms' )
@@ -257,7 +262,7 @@ def test_event_ell_gpu(s, p, values_type, events_type):
257
262
brainpy_time1 , brainpy_time2 , brainpy_time3 , brainpy_time4 , brainpy_time5 , speedup
258
263
259
264
# init dataframe
260
- df = pd .DataFrame (columns = ['s' , 'p' , 'backend' , 'values type' , 'events type' ,
265
+ df = pd .DataFrame (columns = ['s' , 'p' , 'backend' , 'values type' , 'events type' , 'transpose' ,
261
266
'taichi aot time1(ms)' , 'taichi aot time2(ms)' , 'taichi aot time3(ms)' , 'taichi aot time4(ms)' , 'taichi aot time5(ms)' ,
262
267
'brainpy time1(ms)' , 'brainpy time2(ms)' , 'brainpy time3(ms)' , 'brainpy time4(ms)' , 'brainpy time5(ms)' ,
263
268
'speedup' ])
@@ -267,10 +272,11 @@ def test_event_ell_gpu(s, p, values_type, events_type):
267
272
for _p in p :
268
273
for _values_type in values_type :
269
274
for _events_type in events_type :
270
- taichi_aot_time_1 , taichi_aot_time_2 , taichi_aot_time_3 , taichi_aot_time_4 , taichi_aot_time_5 ,\
271
- brainpy_time_1 , brainpy_time_2 , brainpy_time_3 , brainpy_time_4 , brainpy_time_5 , speedup = test_event_ell_cpu (_s , _p , _values_type , _events_type )
275
+ for _transpose in transpose :
276
+ taichi_aot_time_1 , taichi_aot_time_2 , taichi_aot_time_3 , taichi_aot_time_4 , taichi_aot_time_5 ,\
277
+ brainpy_time_1 , brainpy_time_2 , brainpy_time_3 , brainpy_time_4 , brainpy_time_5 , speedup = test_event_ell_cpu (_s , _p , _values_type , _events_type , _transpose )
272
278
# append to dataframe
273
- df .loc [df .shape [0 ]] = [_s , _p , 'cpu' , _values_type , _events_type ,
279
+ df .loc [df .shape [0 ]] = [_s , _p , 'cpu' , _values_type , _events_type , _transpose ,
274
280
taichi_aot_time_1 , taichi_aot_time_2 , taichi_aot_time_3 , taichi_aot_time_4 , taichi_aot_time_5 ,
275
281
brainpy_time_1 , brainpy_time_2 , brainpy_time_3 , brainpy_time_4 , brainpy_time_5 , speedup ]
276
282
df .to_csv ('event_csrmv_cpu.csv' , index = False )
@@ -280,10 +286,11 @@ def test_event_ell_gpu(s, p, values_type, events_type):
280
286
for _p in p :
281
287
for _values_type in values_type :
282
288
for _events_type in events_type :
283
- taichi_aot_time_1 , taichi_aot_time_2 , taichi_aot_time_3 , taichi_aot_time_4 , taichi_aot_time_5 ,\
284
- brainpy_time_1 , brainpy_time_2 , brainpy_time_3 , brainpy_time_4 , brainpy_time_5 , speedup = test_event_ell_gpu (_s , _p , _values_type , _events_type )
289
+ for _transpose in transpose :
290
+ taichi_aot_time_1 , taichi_aot_time_2 , taichi_aot_time_3 , taichi_aot_time_4 , taichi_aot_time_5 ,\
291
+ brainpy_time_1 , brainpy_time_2 , brainpy_time_3 , brainpy_time_4 , brainpy_time_5 , speedup = test_event_ell_gpu (_s , _p , _values_type , _events_type , _transpose )
285
292
# append to dataframe
286
- df .loc [df .shape [0 ]] = [_s , _p , 'gpu' , _values_type , _events_type ,
293
+ df .loc [df .shape [0 ]] = [_s , _p , 'gpu' , _values_type , _events_type , transpose ,
287
294
taichi_aot_time_1 , taichi_aot_time_2 , taichi_aot_time_3 , taichi_aot_time_4 , taichi_aot_time_5 ,
288
295
brainpy_time_1 , brainpy_time_2 , brainpy_time_3 , brainpy_time_4 , brainpy_time_5 , speedup ]
289
296
df .to_csv ('event_csrmv_gpu.csv' , index = False )
0 commit comments