Skip to content

Commit 361dc2b

Browse files
authored
Benchmark rambo (#9)
* adding rambo example * Refactor random module and add tests
1 parent f818327 commit 361dc2b

File tree

10 files changed

+370
-36
lines changed

10 files changed

+370
-36
lines changed

Diff for: examples/rambo.py

+209
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,209 @@
1+
"""
2+
Rambo benchmark
3+
4+
Examples:
5+
6+
# run 1000 iterations of 10 events and 100 outputs on sharpy backend
7+
python rambo.py -nevts 10 -nout 100 -b sharpy -i 1000
8+
9+
# MPI parallel run
10+
mpiexec -n 3 python rambo.py -nevts 64 -nout 64 -b sharpy -i 1000
11+
12+
"""
13+
14+
import argparse
15+
import time as time_mod
16+
17+
import numpy
18+
19+
import sharpy
20+
21+
try:
22+
import mpi4py
23+
24+
mpi4py.rc.finalize = False
25+
from mpi4py import MPI
26+
27+
comm_rank = MPI.COMM_WORLD.Get_rank()
28+
comm = MPI.COMM_WORLD
29+
except ImportError:
30+
comm_rank = 0
31+
comm = None
32+
33+
34+
def info(s):
    """Print *s*, but only on the process whose ``comm_rank`` is 0."""
    if comm_rank != 0:
        return
    print(s)
37+
38+
39+
def sp_rambo(sp, sp_C1, sp_F1, sp_Q1, sp_output, nevts, nout):
    """Evaluate the rambo kernel with the array module *sp* and sync.

    The three input arrays are transformed into four components that are
    written to ``sp_output[:, :, 0..3]``; every component is reshaped to
    ``(nevts, nout, 1)`` before being stored.  The function finishes with
    ``sharpy.sync()`` and returns nothing.
    """
    target_shape = (nevts, nout, 1)

    def store(slot, values):
        # Write one component into its slot along the last axis.
        sp_output[:, :, slot] = sp.reshape(values, target_shape)

    sp_C = 2.0 * sp_C1 - 1.0
    sp_S = sp.sqrt(1 - sp.square(sp_C))
    sp_F = 2.0 * sp.pi * sp_F1
    sp_Q = -sp.log(sp_Q1)

    store(0, sp_Q)
    store(1, sp_Q * sp_S * sp.sin(sp_F))
    store(2, sp_Q * sp_S * sp.cos(sp_F))
    store(3, sp_Q * sp_C)

    sharpy.sync()
55+
56+
57+
def np_rambo(np, C1, F1, Q1, output, nevts, nout):
    """Evaluate the rambo kernel with the array module *np*, in place.

    Fills ``output[:, :, 0..3]`` from the inputs ``C1``, ``F1`` and ``Q1``
    and returns nothing.  ``nevts`` and ``nout`` are accepted so that the
    signature matches ``sp_rambo``; they are not used here.
    """
    cos_theta = 2.0 * C1 - 1.0
    sin_theta = np.sqrt(1 - np.square(cos_theta))
    phi = 2.0 * np.pi * F1
    energy = -np.log(Q1)

    components = (
        energy,
        energy * sin_theta * np.sin(phi),
        energy * sin_theta * np.cos(phi),
        energy * cos_theta,
    )
    for slot, values in enumerate(components):
        output[:, :, slot] = values
67+
68+
69+
def initialize(np, nevts, nout, seed, dtype):
    """Seed ``np.random`` and create the benchmark inputs and output buffer.

    Returns a tuple ``(C1, F1, Q1, output)``: ``C1`` and ``F1`` are
    ``np.random.rand(nevts, nout)`` draws, ``Q1`` is the product of two
    further such draws, and ``output`` is ``np.zeros((nevts, nout, 4), dtype)``.
    """
    rng = np.random
    rng.seed(seed)

    # The order of the four draws determines the generated values.
    C1 = rng.rand(nevts, nout)
    F1 = rng.rand(nevts, nout)
    q_first = rng.rand(nevts, nout)
    q_second = rng.rand(nevts, nout)

    output = np.zeros((nevts, nout, 4), dtype)
    return (C1, F1, q_first * q_second, output)
75+
76+
77+
def run(nevts, nout, backend, iterations, datatype):
    """Benchmark the rambo kernel on one backend and print timing statistics.

    Args:
        nevts: Number of events (first array dimension).
        nout: Number of outputs (second array dimension).
        backend: ``"sharpy"`` or ``"numpy"``.
        iterations: Number of timed kernel calls after the warm-up call.
        datatype: ``"f32"`` or ``"f64"``; mapped to the backend's float type
            and forwarded to ``initialize``.

    Raises:
        ValueError: If ``backend`` is neither ``"sharpy"`` nor ``"numpy"``.
        KeyError: If ``datatype`` is neither ``"f32"`` nor ``"f64"``.
        AssertionError: If the numpy backend is used with more than one MPI
            rank, or if the sharpy result differs from the numpy reference.
    """
    if backend == "sharpy":
        import sharpy as np
        from sharpy import fini, init, sync

        rambo = sp_rambo

        init(False)
    elif backend == "numpy":
        import numpy as np

        if comm is not None:
            assert (
                comm.Get_size() == 1
            ), "Numpy backend only supports serial execution."

        # numpy needs no runtime setup/teardown: use no-op stand-ins.
        fini = sync = lambda x=None: None
        rambo = np_rambo
    else:
        raise ValueError(f'Unknown backend: "{backend}"')

    dtype = {
        "f32": np.float32,
        "f64": np.float64,
    }[datatype]

    info(f"Using backend: {backend}")
    info(f"Number of events: {nevts}")
    info(f"Number of outputs: {nout}")
    info(f"Datatype: {datatype}")

    seed = 7777
    C1, F1, Q1, output = initialize(np, nevts, nout, seed, dtype)
    sync()

    # verify the sharpy kernel against the numpy reference implementation
    if backend == "sharpy":
        sp_rambo(sharpy, C1, F1, Q1, output, nevts, nout)
        # NOTE(review): the original author noted that calling sync() here
        # "does not work"; the cause is not visible from this file.
        np_C1 = sharpy.to_numpy(C1)
        np_F1 = sharpy.to_numpy(F1)
        np_Q1 = sharpy.to_numpy(Q1)
        np_output = numpy.zeros((nevts, nout, 4))
        np_rambo(numpy, np_C1, np_F1, np_Q1, np_output, nevts, nout)
        assert numpy.allclose(sharpy.to_numpy(output), np_output)

    def timed_call():
        # Wall-clock duration of a single kernel invocation.
        tic = time_mod.perf_counter()
        rambo(np, C1, F1, Q1, output, nevts, nout)
        toc = time_mod.perf_counter()
        return toc - tic

    # warm-up run
    t_warm = timed_call()

    # evaluate
    info(f"Running {iterations} iterations")
    time_list = [timed_call() for _ in range(iterations)]

    # get max time over mpi ranks
    if comm is not None:
        t_warm = comm.allreduce(t_warm, MPI.MAX)
        # The pickle-based allreduce applies MAX to whole Python objects; for
        # a list that is a lexicographic comparison which returns one rank's
        # list.  Reduce through the buffer interface to get the element-wise
        # maximum of every iteration instead.
        local_times = numpy.asarray(time_list, dtype=numpy.float64)
        max_times = numpy.empty_like(local_times)
        comm.Allreduce(local_times, max_times, op=MPI.MAX)
        time_list = max_times

    t_min = numpy.min(time_list)
    t_max = numpy.max(time_list)
    t_med = numpy.median(time_list)
    init_overhead = t_warm - t_med
    if backend == "sharpy":
        info(f"Estimated initialization overhead: {init_overhead:.5f} s")
    info(f"Min. duration: {t_min:.5f} s")
    info(f"Max. duration: {t_max:.5f} s")
    info(f"Median duration: {t_med:.5f} s")

    fini()
154+
155+
156+
if __name__ == "__main__":
    # Command-line entry point: parse the benchmark options and pass them
    # on to run().
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        description="Run rambo benchmark",
    )

    # (flags, keyword options), listed in the order they are registered.
    cli_options = [
        (
            ("-nevts", "--num_events"),
            {
                "type": int,
                "default": 10,
                "help": "Number of events to evaluate.",
            },
        ),
        (
            ("-nout", "--num_outputs"),
            {
                "type": int,
                "default": 10,
                "help": "Number of outputs to evaluate.",
            },
        ),
        (
            ("-b", "--backend"),
            {
                "type": str,
                "default": "sharpy",
                "choices": ["sharpy", "numpy"],
                "help": "Backend to use.",
            },
        ),
        (
            ("-i", "--iterations"),
            {
                "type": int,
                "default": 10,
                "help": "Number of iterations to run.",
            },
        ),
        (
            ("-d", "--datatype"),
            {
                "type": str,
                "default": "f64",
                "choices": ["f32", "f64"],
                "help": "Datatype for model state variables",
            },
        ),
    ]
    for flags, options in cli_options:
        parser.add_argument(*flags, **options)

    args = parser.parse_args()
    nevts = args.num_events
    nout = args.num_outputs
    run(nevts, nout, args.backend, args.iterations, args.datatype)

Diff for: imex_version.txt

+1-1
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
199f9456fd31b96930395ab650fdb6fea42769dd
1+
a02f09350a8eba081c92a7d0117334eb56c9fb5a

Diff for: setup.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ def build_cmake(self, ext):
5454
name="sharpy",
5555
version="0.2",
5656
description="Distributed array and more",
57-
packages=["sharpy", "sharpy.numpy"], # "sharpy.torch"],
57+
packages=["sharpy", "sharpy.numpy", "sharpy.random"], # "sharpy.torch"],
5858
ext_modules=[CMakeExtension("sharpy/_sharpy")],
5959
cmdclass=dict(
6060
# Enable the CMakeExtension entries defined above

Diff for: sharpy/__init__.py

+16-4
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,22 @@
3838
from ._sharpy import sync
3939
from .ndarray import ndarray
4040

41+
42+
# Lazy load submodules
43+
def __getattr__(name):
    """Resolve module attributes that are not found by normal lookup.

    ``random`` and ``numpy`` are imported as submodules on first access.
    Any other name is forwarded to ``_fallback`` when that function exists
    in this module's globals.

    Raises:
        AttributeError: If the name is not a lazily loaded submodule and
            ``_fallback`` is not available, so that ``hasattr``,
            ``getattr`` with a default, and ``from ... import *`` behave
            normally instead of receiving ``None``.
    """
    if name == "random":
        import sharpy.random as random

        return random
    if name == "numpy":
        import sharpy.numpy as numpy

        return numpy

    # _fallback may not exist yet while this module is still being imported.
    if "_fallback" in globals():
        return _fallback(name)

    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
55+
56+
4157
_sharpy_cw = bool(int(getenv("SHARPY_CW", False)))
4258

4359
pi = 3.1415926535897932384626433
@@ -185,7 +201,3 @@ def __getattr__(self, name):
185201
dt.linalg.norm(...)
186202
"""
187203
return _fallback(name, self._func)
188-
189-
def __getattr__(name):
190-
"Attempt to find a fallback in fallback-lib"
191-
return _fallback(name)

Diff for: sharpy/random.py

-11
This file was deleted.

Diff for: sharpy/random/__init__.py

+36
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
import numpy as np
2+
3+
import sharpy as sp
4+
from sharpy import float64
5+
from sharpy.numpy import fromfunction
6+
7+
8+
def uniform(low, high, size, device="", team=1):
    """Create a sharpy array from ``np.random.uniform(low, high, size)``.

    A zero-dimensional sample is wrapped with ``sp.full((), ...)``; any other
    sample is copied into a float64 array via ``fromfunction``.  ``device``
    and ``team`` are passed through to the array constructor.
    """
    samples = np.random.uniform(low, high, size)

    if samples.ndim == 0:
        return sp.full((), samples[()], device=device, team=team)

    def pick(*index):
        # fromfunction supplies the indices; look them up in the samples.
        return samples[index]

    return fromfunction(
        pick,
        samples.shape,
        dtype=float64,
        device=device,
        team=team,
    )
20+
21+
22+
def rand(*shape, device="", team=1):
    """Create a sharpy array from ``np.random.rand(*shape)``.

    When numpy returns a plain Python ``float`` that value is returned
    unchanged; otherwise the samples are copied into a float64 array via
    ``fromfunction``.  ``device`` and ``team`` are passed through to it.
    """
    samples = np.random.rand(*shape)

    if not isinstance(samples, float):

        def pick(*index):
            # fromfunction supplies the indices; look them up in the samples.
            return samples[index]

        return fromfunction(
            pick,
            samples.shape,
            dtype=float64,
            device=device,
            team=team,
        )

    return samples
33+
34+
35+
def seed(s):
    """Seed numpy's global random generator, which ``uniform`` and ``rand`` draw from."""
    np.random.seed(seed=s)

Diff for: src/Service.cpp

+4-1
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,10 @@ struct DeferredService : public DeferredT<Service::service_promise_type,
5151
// drop from dep manager
5252
dm.drop(_a);
5353
// and from registry
54-
Registry::del(_a);
54+
dm.addReady(_a, [this](id_type guid) {
55+
assert(this->_a == guid);
56+
Registry::del(guid);
57+
});
5558
break;
5659
}
5760
case RUN:

0 commit comments

Comments
 (0)