
Commit ab159f6

Merge pull request #206 from Pressio/issue_159
parallelize sampling
2 parents 81dc4dd + a4b770d commit ab159f6
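
In short, this PR adds an `evaluation_concurrency` argument to `run_sampling` so that independent samples can be evaluated by a pool of worker processes (see the sampling.py diff below); `evaluation_concurrency=1` keeps the original serial loop. A minimal usage sketch, assuming you already have concrete `Model` and `ParameterSpace` objects — `my_model`, `my_parameter_space`, and the directory path here are placeholders, not part of the commit:

from romtools.workflows.sampling.sampling import run_sampling

# my_model / my_parameter_space are user-provided Model and ParameterSpace
# implementations; the sampling directory is any writable path
run_directories = run_sampling(my_model, my_parameter_space,
                               absolute_sampling_directory='/path/to/sampling_dir',
                               evaluation_concurrency=4,  # new argument; >1 runs samples in parallel
                               number_of_samples=10,
                               random_seed=1,
                               dry_run=False)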

File tree: 3 files changed, +85 −31 lines changed

romtools/hyper_reduction/deim.py

Lines changed: 8 additions & 8 deletions
@@ -83,7 +83,7 @@ def _deim_get_indices_sharedmem(U):
     return indices
 
 
-class _DistDeimData:
+class _dist_deim_data:
     def __init__(self, i, r):
         self.local_indices = np.array([int(i)])
         self.owning_ranks = np.array([int(r)])
@@ -96,26 +96,26 @@ def append(self, i, r):
 
 def _deim_get_indices_distributed(U, comm):
     m = np.shape(U)[1]
-    local_index, found_rank = la.argmax(np.abs(U[:, 0]), comm)
-    result = _DistDeimData(local_index, found_rank)
+    local_index, foundRank = la.argmax(np.abs(U[:, 0]), comm)
+    result = _dist_deim_data(local_index, foundRank)
     if m == 1:
         return result.local_indices, result.owning_ranks
 
-    my_rank = comm.Get_rank()
+    myRank = comm.Get_rank()
     LHS, RHS, C = np.array([]), np.array([]), np.array([])
     for ell in range(1, m):
-        indices = result.local_indices[result.owning_ranks==my_rank]
+        indices = result.local_indices[result.owning_ranks==myRank]
         LHS = np.array([]) if indices.size == 0 else U[indices, 0:ell]
         RHS = np.array([]) if indices.size == 0 else U[indices, ell]
 
         A, b = la.move_distributed_linear_system_to_rank_zero(LHS, RHS, comm)
-        if my_rank == 0:
+        if myRank == 0:
             C = np.linalg.solve(A, b)
         C = comm.bcast(C, root=0)
 
         residual = U[:, ell] - U[:, 0:ell] @ C
-        local_index, found_rank = la.argmax(np.abs(residual), comm)
-        result.append(local_index, found_rank)
+        local_index, foundRank = la.argmax(np.abs(residual), comm)
+        result.append(local_index, foundRank)
 
     return result.local_indices, result.owning_ranks
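
For context, the routine above is the distributed-memory counterpart of the classic greedy DEIM index selection: one interpolation row is chosen per basis vector by repeatedly locating the entry of largest interpolation error. A serial sketch of that classic algorithm, for background only (this is not the package's `_deim_get_indices_sharedmem` implementation):

import numpy as np

def deim_indices_sketch(U):
    # U: (n, m) basis matrix; returns m interpolation row indices
    n, m = U.shape
    indices = [int(np.argmax(np.abs(U[:, 0])))]
    for ell in range(1, m):
        # interpolate column ell at the rows selected so far ...
        c = np.linalg.solve(U[indices, :ell], U[indices, ell])
        # ... and select the row where the interpolation residual is largest
        residual = U[:, ell] - U[:, :ell] @ c
        indices.append(int(np.argmax(np.abs(residual))))
    return np.array(indices)

The distributed version in the diff follows the same recipe, except that each argmax is computed across MPI ranks (returning a local index plus the rank that owns it) and the small ell-by-ell solve is gathered to rank 0 and broadcast back.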

romtools/workflows/sampling/sampling.py

Lines changed: 64 additions & 21 deletions
@@ -46,19 +46,26 @@
 import os
 import time
 import numpy as np
+import concurrent.futures
+import multiprocessing
 
 from romtools.workflows.workflow_utils import create_empty_dir
 from romtools.workflows.models import Model
 from romtools.workflows.parameter_spaces import ParameterSpace
 
 
+def _get_run_id_from_run_dir(run_dir):
+    return int(run_dir.split('_')[-1])
+
+
 def _create_parameter_dict(parameter_names, parameter_values):
     return dict(zip(parameter_names, parameter_values))
 
 
 def run_sampling(model: Model,
                  parameter_space: ParameterSpace,
                  absolute_sampling_directory: str,
+                 evaluation_concurrency = 1,
                  number_of_samples: int = 10,
                  random_seed: int = 1,
                  dry_run: bool = False,
@@ -67,6 +74,17 @@ def run_sampling(model: Model,
     Core algorithm
     '''
 
+    # we use here spawn because the default fork causes issues with mpich,
+    # see here: https://github.com/Pressio/rom-tools-and-workflows/pull/206
+    #
+    # to read more about fork/spawn:
+    # https://docs.python.org/3/library/multiprocessing.html#multiprocessing-start-methods
+    #
+    # and
+    # https://docs.python.org/3/library/concurrent.futures.html#concurrent.futures.ProcessPoolExecutor
+    #
+    mp_cntxt=multiprocessing.get_context("spawn")
+
     np.random.seed(random_seed)
 
     # create parameter samples
@@ -85,40 +103,65 @@
         model.populate_run_directory(run_directory, parameter_dict)
         run_directories.append(run_directory)
 
-    # Run cases if dry_run is not set
+    # Print MPI warnings
+    print("""
+    Warning: If you are using your model with MPI via a direct call to `mpirun -n ...`,
+    be aware that this may or may not work for issues that are purely related to MPI.
+    """)
     if not dry_run:
-        run_times = np.zeros(number_of_samples)
-        for sample_index in range(0, number_of_samples):
-            print("======= Sample " + str(sample_index) + " ============")
-            run_directory = f'{run_directory_base}{sample_index}'
-            if "passed.txt" in os.listdir(run_directory) and not overwrite:
-                print("Skipping (Sample has already run successfully)")
-            else:
+        # Run cases
+        if evaluation_concurrency == 1:
+            run_times = np.zeros(number_of_samples)
+            for sample_index in range(0, number_of_samples):
+                print("======= Sample " + str(sample_index) + " ============")
                 print("Running")
-                parameter_dict = _create_parameter_dict(parameter_names, parameter_samples[sample_index])
-                run_times[sample_index] = run_sample(run_directory, model, parameter_dict)
-            sample_stats_save_directory = f'{run_directory_base}{sample_index}/../'
-            np.savez(f'{sample_stats_save_directory}/sampling_stats',
-                     run_times=run_times)
+                run_directory = f'{run_directory_base}{sample_index}'
+                if "passed.txt" in os.listdir(run_directory) and not overwrite:
+                    print("Skipping (Sample has already run successfully)")
+                else:
+                    print("Running")
+                    parameter_dict = _create_parameter_dict(parameter_names, parameter_samples[sample_index])
+                    run_times[sample_index] = run_sample(run_directory, model, parameter_dict)
+                sample_stats_save_directory = f'{run_directory_base}{sample_index}/../'
+                np.savez(f'{sample_stats_save_directory}/sampling_stats',
+                         run_times=run_times)
+        else:
+            #Identify samples to run
+            samples_to_run = []
+            for sample_index in range(0, number_of_samples):
+                run_directory = f'{run_directory_base}{sample_index}'
+                if "passed.txt" in os.listdir(run_directory) and not overwrite:
+                    print(f"Skipping sample {sample_index} (Sample has already run successfully)")
+                    pass
+                else:
+                    samples_to_run.append(sample_index)
+            with concurrent.futures.ProcessPoolExecutor(max_workers = evaluation_concurrency, mp_context=mp_cntxt) as executor:
+                these_futures = [executor.submit(run_sample,
+                                                 f'{run_directory_base}{sample_id}', model,
+                                                 _create_parameter_dict(parameter_names, parameter_samples[sample_id]))
+                                 for sample_id in samples_to_run]
 
-    return run_directories
+                # Wait for all processes to finish
+                concurrent.futures.wait(these_futures)
 
+            run_times = [future.result() for future in these_futures]
+            sample_stats_save_directory = f'{run_directory_base}{sample_index}/../'
+            np.savez(f'{sample_stats_save_directory}/sampling_stats', run_times=run_times)
+
+    return run_directories
 
-def run_sample(run_directory: str, model: Model,
-               parameter_sample: dict):
-    '''
-    Execute individual sample
-    '''
 
+def run_sample(run_directory: str, model: Model, parameter_sample: dict):
+    run_id = _get_run_id_from_run_dir(run_directory)
     ts = time.time()
     flag = model.run_model(run_directory, parameter_sample)
     tf = time.time()
     run_time = tf - ts
 
     if flag == 0:
+        print(f"Sample {run_id} is complete, run time = {run_time}")
         np.savetxt(os.path.join(run_directory, 'passed.txt'), np.array([0]), '%i')
-        print(f"Sample complete, run time = {run_time}")
     else:
-        print(f"Sample failed, run time = {run_time}")
+        print(f"Sample {run_id} failed, run time = {run_time}")
     print(" ")
     return run_time
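
The in-code comment in the hunk above explains why the executor is created with a "spawn" context instead of the default "fork" start method. A self-contained sketch of that pattern, independent of romtools (the worker function and its timing are illustrative only):

import concurrent.futures
import multiprocessing
import time

def _work(sample_id):
    # stand-in for run_sample: pretend each sample takes a moment
    time.sleep(0.1)
    return sample_id

if __name__ == "__main__":
    # "spawn" starts each worker as a fresh interpreter, avoiding the
    # fork-related problems mentioned in the comment (e.g. with MPI libraries)
    ctx = multiprocessing.get_context("spawn")
    with concurrent.futures.ProcessPoolExecutor(max_workers=2, mp_context=ctx) as executor:
        futures = [executor.submit(_work, i) for i in range(4)]
        concurrent.futures.wait(futures)
    print([f.result() for f in futures])

One consequence of spawn is that submitted callables and their arguments must be picklable and importable from the main module, which is why run_sample, the model, and the parameter dictionary are passed explicitly to executor.submit in the diff.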

tests/romtools/workflows/sampling/test_sampling.py

Lines changed: 13 additions & 2 deletions
@@ -1,10 +1,13 @@
 import pytest
 import os
 import numpy as np
+import time
 
 from romtools.workflows.sampling.sampling import run_sampling
 from romtools.workflows.parameter_spaces import MonteCarloSampler, UniformParameterSpace
 
+def _get_run_id(run_dir):
+    return int(run_dir.split('_')[-1])
 
 class MockModel:
     def __init__(self):
@@ -17,10 +20,18 @@ def populate_run_directory(self, run_dir, parameter_sample):
         np.savez(f'{run_dir}/parameter_values.npz', parameter_values=parameter_values)
 
     def run_model(self, run_dir, parameter_sample):
+        print("running model in ", run_dir)
         params_input = np.load(f'{run_dir}/parameter_values.npz')['parameter_values']
         for i in range(0, len(parameter_sample)):
             parameter_name = list(parameter_sample.keys())[i]
             assert params_input[i] == parameter_sample[parameter_name]
+        np.savetxt(f'{run_dir}/passed.txt', np.array([0]), '%i')
+
+        # add artificial lag centered around run_id=5
+        # such that the closer the ID is to 5, the less the task waits.
+        # totally arbitrary choice.
+        seconds_to_wait = abs(_get_run_id(run_dir) - 5) * 4
+        time.sleep( seconds_to_wait )
         return 0
 
 
@@ -32,9 +43,9 @@ def run_sampler(tmp_path, dry_run=False, overwrite=True):
     my_model = MockModel()
     run_directories = run_sampling(my_model, my_parameter_space,
                                    absolute_sampling_directory=tmp_path,
-                                   number_of_samples=10, dry_run=dry_run,
+                                   evaluation_concurrency=2,
+                                   number_of_samples=10,dry_run=dry_run,
                                    overwrite=overwrite)
-
     assert(len(run_directories)==10)
 
     timestamps = []
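
To make the concurrency observable, the mocked run_model above sleeps for abs(run_id - 5) * 4 seconds, so with evaluation_concurrency=2 the ten samples take very different amounts of wall time and can finish out of submission order. Spelled out, the waits implied by that formula are:

# seconds_to_wait = abs(run_id - 5) * 4, evaluated for run_id 0..9
>>> [abs(run_id - 5) * 4 for run_id in range(10)]
[20, 16, 12, 8, 4, 0, 4, 8, 12, 16]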
