Skip to content

Commit 0f21f5c

Browse files
committed
test: Fix synchronization in AMO tests
1 parent fbf8941 commit 0f21f5c

File tree

1 file changed

+17
-26
lines changed

1 file changed

+17
-26
lines changed

test/test_amo.py

Lines changed: 17 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@ def testFetch(self):
2323
for t in types_ext:
2424
with self.subTest(type=t):
2525
src = shmem.array(mype, dtype=t)
26-
shmem.barrier_all()
2726
val = shmem.atomic_fetch(src, nxpe)
2827
self.assertEqual(val, nxpe)
2928
shmem.free(src)
@@ -47,7 +46,6 @@ def testSwap(self):
4746
for t in types_ext:
4847
with self.subTest(type=t):
4948
tgt = shmem.array(-1, dtype=t)
50-
shmem.barrier_all()
5149
val = shmem.atomic_swap(tgt, nxpe, nxpe)
5250
shmem.barrier_all()
5351
self.assertEqual(tgt, mype)
@@ -62,23 +60,26 @@ def testCompareSwap(self):
6260
for t in types_std:
6361
with self.subTest(type=t):
6462
tgt = shmem.array(0, dtype=t)
65-
shmem.barrier_all()
6663
#
64+
shmem.sync_all()
6765
val = shmem.atomic_compare_swap(tgt, 1, nxpe, nxpe)
6866
shmem.barrier_all()
6967
self.assertEqual(tgt, 0)
7068
self.assertEqual(val, 0)
7169
#
70+
shmem.sync_all()
7271
val = shmem.atomic_compare_swap(tgt, 0, nxpe, nxpe)
7372
shmem.barrier_all()
7473
self.assertEqual(tgt, mype)
7574
self.assertEqual(val, 0)
7675
#
76+
shmem.sync_all()
7777
val = shmem.atomic_compare_swap(tgt, nxpe, 0, nxpe)
7878
shmem.barrier_all()
7979
self.assertEqual(tgt, 0)
8080
self.assertEqual(val, nxpe)
8181
#
82+
shmem.sync_all()
8283
val = shmem.atomic_compare_swap(tgt, npes, 0, nxpe)
8384
shmem.barrier_all()
8485
self.assertEqual(tgt, 0)
@@ -93,16 +94,15 @@ def testFetchOp(self):
9394
for t in types_std:
9495
with self.subTest(type=t):
9596
tgt = shmem.array(0, dtype=t)
96-
shmem.barrier_all()
9797
for i in range(3):
98+
shmem.sync_all()
9899
op = shmem.AMO_INC
99100
val = shmem.atomic_fetch_op(tgt, None, op, nxpe)
100-
shmem.barrier_all()
101101
self.assertEqual(val, i)
102102
for i in range(3):
103+
shmem.sync_all()
103104
op = shmem.AMO_ADD
104105
val = shmem.atomic_fetch_op(tgt, 1, op, nxpe)
105-
shmem.barrier_all()
106106
self.assertEqual(val, 3 + i)
107107
shmem.free(tgt)
108108
for t in types_ext:
@@ -122,13 +122,12 @@ def testOp(self):
122122
with self.subTest(type=t):
123123
tgt = shmem.array(0, dtype=t)
124124
for i in range(3):
125-
shmem.barrier_all()
125+
shmem.sync_all()
126126
op = shmem.AMO_INC
127127
val = shmem.atomic_fetch(tgt, nxpe)
128128
shmem.atomic_op(tgt, None, op, nxpe)
129129
self.assertEqual(val, i)
130130
for i in range(3):
131-
shmem.barrier_all()
132131
op = shmem.AMO_ADD
133132
val = shmem.atomic_fetch(tgt, nxpe)
134133
shmem.atomic_op(tgt, 1, op, nxpe)
@@ -151,14 +150,11 @@ def testFetchIncAdd(self):
151150
for t in types_std:
152151
with self.subTest(type=t):
153152
tgt = shmem.array(0, dtype=t)
154-
shmem.barrier_all()
155153
for i in range(3):
156154
val = shmem.atomic_fetch_inc(tgt, nxpe)
157-
shmem.barrier_all()
158155
self.assertEqual(val, i)
159156
for i in range(3):
160157
val = shmem.atomic_fetch_add(tgt, 1, nxpe)
161-
shmem.barrier_all()
162158
self.assertEqual(val, 3 + i)
163159
shmem.free(tgt)
164160

@@ -170,12 +166,12 @@ def testIncAdd(self):
170166
with self.subTest(type=t):
171167
tgt = shmem.array(0, dtype=t)
172168
for i in range(3):
173-
shmem.barrier_all()
169+
shmem.sync_all()
174170
val = shmem.atomic_fetch(tgt, nxpe)
175171
shmem.atomic_inc(tgt, nxpe)
176172
self.assertEqual(val, i)
177173
for i in range(3):
178-
shmem.barrier_all()
174+
shmem.sync_all()
179175
val = shmem.atomic_fetch(tgt, nxpe)
180176
shmem.atomic_add(tgt, 1, nxpe)
181177
self.assertEqual(val, 3 + i)
@@ -188,20 +184,16 @@ def testFetchBitwise(self):
188184
for t in types_bit:
189185
with self.subTest(type=t):
190186
tgt = shmem.array(0, dtype=t)
191-
shmem.barrier_all()
192187
for i in range(5):
193188
val = shmem.atomic_fetch_or(tgt, 1<<i, nxpe)
194-
shmem.barrier_all()
195189
self.assertEqual(val, 2**i-1)
196190
for i in reversed(range(5)):
197191
val = shmem.atomic_fetch_xor(tgt, 1<<i, nxpe)
198-
shmem.barrier_all()
199192
self.assertEqual(val, 2**(i+1)-1)
200193
shmem.atomic_set(tgt, 2**5-1, nxpe)
201194
shmem.barrier_all()
202195
for i in reversed(range(5)):
203196
val = shmem.atomic_fetch_and(tgt, 2**i-1, nxpe)
204-
shmem.barrier_all()
205197
self.assertEqual(val, 2**(i+1)-1)
206198
shmem.free(tgt)
207199

@@ -218,14 +210,12 @@ def testBitwise(self):
218210
shmem.atomic_or(tgt, 1<<i, nxpe)
219211
self.assertEqual(val, 2**i-1)
220212
for i in reversed(range(5)):
221-
shmem.barrier_all()
222213
val = shmem.atomic_fetch(tgt, nxpe)
223214
shmem.atomic_xor(tgt, 1<<i, nxpe)
224215
self.assertEqual(val, 2**(i+1)-1)
225216
shmem.barrier_all()
226217
shmem.atomic_set(tgt, 2**5-1, nxpe)
227218
for i in reversed(range(5)):
228-
shmem.barrier_all()
229219
val = shmem.atomic_fetch(tgt, nxpe)
230220
shmem.atomic_and(tgt, 2**i-1, nxpe)
231221
self.assertEqual(val, 2**(i+1)-1)
@@ -254,8 +244,8 @@ def testFetch(self):
254244
with self.subTest(type=t):
255245
val = np.array(0, dtype=t)
256246
src = shmem.array(mype, dtype=t)
257-
shmem.barrier_all()
258247
#
248+
shmem.sync_all()
259249
shmem.atomic_fetch_nbi(val, src, nxpe)
260250
shmem.quiet()
261251
self.assertEqual(val, nxpe)
@@ -278,14 +268,15 @@ def testSwap(self):
278268
with self.subTest(type=t):
279269
val = np.array(0, dtype=t)
280270
tgt = shmem.array(-1, dtype=t)
281-
shmem.barrier_all()
282271
#
272+
shmem.sync_all()
283273
shmem.atomic_swap_nbi(val, tgt, nxpe, nxpe)
284274
shmem.quiet()
285275
self.assertEqual(val, np.array(-1, dtype=t))
286276
shmem.sync_all()
287277
self.assertEqual(tgt, mype)
288278
#
279+
shmem.sync_all()
289280
shmem.atomic_swap_nbi(val, tgt, mype, nxpe)
290281
shmem.quiet()
291282
self.assertEqual(val, np.array(nxpe, dtype=t))
@@ -303,26 +294,29 @@ def testCompareSwap(self):
303294
with self.subTest(type=t):
304295
val = np.array(0, dtype=t)
305296
tgt = shmem.array(0, dtype=t)
306-
shmem.barrier_all()
307297
#
298+
shmem.sync_all()
308299
shmem.atomic_compare_swap_nbi(val, tgt, 1, nxpe, nxpe)
309300
shmem.quiet()
310301
self.assertEqual(val, 0)
311302
shmem.sync_all()
312303
self.assertEqual(tgt, 0)
313304
#
305+
shmem.sync_all()
314306
shmem.atomic_compare_swap_nbi(val, tgt, 0, nxpe, nxpe)
315307
shmem.quiet()
316308
self.assertEqual(val, 0)
317309
shmem.sync_all()
318310
self.assertEqual(tgt, mype)
319311
#
312+
shmem.sync_all()
320313
shmem.atomic_compare_swap_nbi(val, tgt, nxpe, 0, nxpe)
321314
shmem.quiet()
322315
self.assertEqual(val, nxpe)
323316
shmem.sync_all()
324317
self.assertEqual(tgt, 0)
325318
#
319+
shmem.sync_all()
326320
shmem.atomic_compare_swap_nbi(val, tgt, npes, 0, nxpe)
327321
shmem.quiet()
328322
self.assertEqual(val, 0)
@@ -339,7 +333,6 @@ def testFetchOp(self):
339333
with self.subTest(type=t):
340334
val = np.array(0, dtype=t)
341335
tgt = shmem.array(0, dtype=t)
342-
shmem.barrier_all()
343336
for i in range(3):
344337
op = shmem.AMO_INC
345338
shmem.atomic_fetch_op_nbi(val, tgt, None, op, nxpe)
@@ -370,7 +363,6 @@ def testFetchIncAdd(self):
370363
with self.subTest(type=t):
371364
val = np.array(0, dtype=t)
372365
tgt = shmem.array(0, dtype=t)
373-
shmem.barrier_all()
374366
for i in range(3):
375367
shmem.atomic_fetch_inc_nbi(val, tgt, nxpe)
376368
shmem.quiet()
@@ -389,7 +381,6 @@ def testFetchBitwise(self):
389381
with self.subTest(type=t):
390382
val = np.array(0, dtype=t)
391383
tgt = shmem.array(0, dtype=t)
392-
shmem.barrier_all()
393384
for i in range(5):
394385
shmem.atomic_fetch_or_nbi(val, tgt, 1<<i, nxpe)
395386
shmem.quiet()
@@ -399,7 +390,7 @@ def testFetchBitwise(self):
399390
shmem.quiet()
400391
self.assertEqual(val, 2**(i+1)-1)
401392
shmem.atomic_set(tgt, 2**5-1, nxpe)
402-
shmem.barrier_all()
393+
shmem.quiet()
403394
for i in reversed(range(5)):
404395
shmem.atomic_fetch_and_nbi(val, tgt, 2**i-1, nxpe)
405396
shmem.quiet()

0 commit comments

Comments
 (0)