Skip to content

Commit 166f25e

Browse files
committed
Clean up flushing, add communication of Python objects
1 parent c1b878d commit 166f25e

File tree

4 files changed

+59
-21
lines changed

4 files changed

+59
-21
lines changed

demos/mpi-alltoall.py

+14-4
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from mpi4py import MPI
22
import numpy as np
3-
import sys
3+
from sys import stdout
44

55
comm = MPI.COMM_WORLD
66
rank = comm.Get_rank()
@@ -10,6 +10,10 @@
1010

1111
data = np.arange(8) / 10. + rank
1212
recv_buf = np.zeros(8)
13+
# Python sequence, lenght has to be equal number to MPI tasks
14+
py_data = []
15+
for r in range(4):
16+
py_data.append({'key{0:02d}'.format(10*rank + r) : 10*rank + r})
1317

1418
if rank == 0:
1519
print("Original data")
@@ -18,22 +22,28 @@
1822
for r in range(size):
1923
if rank == r:
2024
print("rank ", rank, data)
21-
sys.stdout.flush()
25+
print("rank ", rank, py_data)
2226
comm.Barrier()
2327

28+
stdout.flush()
29+
2430
comm.Alltoall(data, recv_buf)
31+
new_data = comm.alltoall(py_data)
2532

2633
comm.Barrier()
2734
if rank == 0:
2835
print()
2936
print("Final data")
30-
sys.stdout.flush()
37+
38+
stdout.flush()
3139
comm.Barrier()
3240

3341
for r in range(size):
3442
if rank == r:
3543
print("rank ", rank, recv_buf)
36-
sys.stdout.flush()
44+
print("rank ", rank, new_data)
3745
comm.Barrier()
3846

47+
stdout.flush()
48+
3949

demos/mpi-bcast.py

+10-5
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from mpi4py import MPI
22
import numpy as np
3-
import sys
3+
from sys import stdout
44

55
comm = MPI.COMM_WORLD
66
rank = comm.Get_rank()
@@ -9,27 +9,32 @@
99
assert size == 4
1010

1111
if rank == 0:
12-
data = np.arange(8) / 10.
12+
data = np.arange(8) / 10. # NumPy array
13+
py_data = {'key1' : 0.0, 'key2' : 11} # Python object
1314
else:
1415
data = np.zeros(8)
16+
py_data = None
1517

1618
if rank == 0:
1719
print("Original data")
18-
sys.stdout.flush()
20+
stdout.flush()
1921
comm.Barrier()
2022

2123
print("rank ", rank, data)
22-
sys.stdout.flush()
24+
print("rank ", rank, py_data)
25+
stdout.flush()
2326

2427
comm.Bcast(data, root=0)
28+
new_data = comm.bcast(py_data, root=0)
2529

2630
comm.Barrier()
2731
if rank == 0:
2832
print()
2933
print("Final data")
30-
sys.stdout.flush()
34+
stdout.flush()
3135
comm.Barrier()
3236

3337
print("rank ", rank, data)
38+
print("rank ", rank, new_data)
3439

3540

demos/mpi-gather.py

+22-7
Original file line numberDiff line numberDiff line change
@@ -1,37 +1,52 @@
11
from mpi4py import MPI
22
import numpy as np
3-
import sys
3+
from sys import stdout
44

55
comm = MPI.COMM_WORLD
66
rank = comm.Get_rank()
77
size = comm.Get_size()
88

99
assert size == 4
1010

11-
data = np.arange(2) / 10. + rank
11+
data = np.arange(2) / 10. + rank # NumPy array
1212

13+
# Let's create different Python objects for different MPI tasks
1314
if rank == 0:
15+
py_data = 'foo.bar'
16+
elif rank == 1:
17+
py_data = 12.34
18+
elif rank == 2:
19+
py_data = {'key1' : 99.0, 'key2' : [-1, 2.3]}
20+
else:
21+
py_data = [6.5, 4.3]
22+
23+
24+
if rank == 1:
1425
recv_buf = np.zeros(8)
1526
else:
1627
recv_buf = None
1728

1829
if rank == 0:
1930
print("Original data")
20-
sys.stdout.flush()
31+
32+
stdout.flush()
2133
comm.Barrier()
2234

2335
print("rank ", rank, data)
24-
sys.stdout.flush()
36+
print("rank ", rank, py_data)
37+
stdout.flush()
2538

26-
comm.Gather(data, recv_buf, root=0)
39+
comm.Gather(data, recv_buf, root=1)
40+
new_data = comm.gather(py_data, root=1)
2741

2842
comm.Barrier()
2943
if rank == 0:
3044
print()
3145
print("Final data")
32-
sys.stdout.flush()
46+
47+
stdout.flush()
3348
comm.Barrier()
3449

3550
print("rank ", rank, recv_buf)
36-
51+
print("rank ", rank, new_data)
3752

demos/mpi-scatter.py

+13-5
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from mpi4py import MPI
22
import numpy as np
3-
import sys
3+
from sys import stdout
44

55
comm = MPI.COMM_WORLD
66
rank = comm.Get_rank()
@@ -9,29 +9,37 @@
99
assert size == 4
1010

1111
if rank == 0:
12-
data = np.arange(8) / 10.
12+
data = np.arange(8) / 10. # NumPy array
13+
# Python sequence, lenght has to be equal number to MPI tasks
14+
py_data = ['foo', 'bar', 11.2, {'key' : 22}]
1315
else:
1416
data = None
17+
py_data = None
1518

1619
recv_buf = np.zeros(2)
1720

1821
if rank == 0:
1922
print("Original data")
20-
sys.stdout.flush()
23+
24+
stdout.flush()
2125
comm.Barrier()
2226

2327
print("rank ", rank, data)
24-
sys.stdout.flush()
28+
print("rank ", rank, py_data)
29+
stdout.flush()
2530

2631
comm.Scatter(data, recv_buf, root=0)
32+
new_data = comm.scatter(py_data, root=0)
2733

2834
comm.Barrier()
2935
if rank == 0:
3036
print()
3137
print("Final data")
32-
sys.stdout.flush()
38+
39+
stdout.flush()
3340
comm.Barrier()
3441

3542
print("rank ", rank, recv_buf)
43+
print("rank ", rank, new_data)
3644

3745

0 commit comments

Comments
 (0)