46
46
import os
47
47
import time
48
48
import numpy as np
49
+ import concurrent .futures
50
+ import multiprocessing
49
51
50
52
from romtools .workflows .workflow_utils import create_empty_dir
51
53
from romtools .workflows .models import Model
52
54
from romtools .workflows .parameter_spaces import ParameterSpace
53
55
54
56
57
+ def _get_run_id_from_run_dir (run_dir ):
58
+ return int (run_dir .split ('_' )[- 1 ])
59
+
60
+
55
61
def _create_parameter_dict (parameter_names , parameter_values ):
56
62
return dict (zip (parameter_names , parameter_values ))
57
63
58
64
59
65
def run_sampling (model : Model ,
60
66
parameter_space : ParameterSpace ,
61
67
absolute_sampling_directory : str ,
68
+ evaluation_concurrency = 1 ,
62
69
number_of_samples : int = 10 ,
63
70
random_seed : int = 1 ,
64
71
dry_run : bool = False ,
@@ -67,6 +74,17 @@ def run_sampling(model: Model,
67
74
Core algorithm
68
75
'''
69
76
77
+ # we use here spawn because the default fork causes issues with mpich,
78
+ # see here: https://github.com/Pressio/rom-tools-and-workflows/pull/206
79
+ #
80
+ # to read more about fork/spawn:
81
+ # https://docs.python.org/3/library/multiprocessing.html#multiprocessing-start-methods
82
+ #
83
+ # and
84
+ # https://docs.python.org/3/library/concurrent.futures.html#concurrent.futures.ProcessPoolExecutor
85
+ #
86
+ mp_cntxt = multiprocessing .get_context ("spawn" )
87
+
70
88
np .random .seed (random_seed )
71
89
72
90
# create parameter samples
@@ -85,40 +103,65 @@ def run_sampling(model: Model,
85
103
model .populate_run_directory (run_directory , parameter_dict )
86
104
run_directories .append (run_directory )
87
105
88
- # Run cases if dry_run is not set
106
+ # Print MPI warnings
107
+ print ("""
108
+ Warning: If you are using your model with MPI via a direct call to `mpirun -n ...`,
109
+ be aware that this may or may not work for issues that are purely related to MPI.
110
+ """ )
89
111
if not dry_run :
90
- run_times = np .zeros (number_of_samples )
91
- for sample_index in range (0 , number_of_samples ):
92
- print ("======= Sample " + str (sample_index ) + " ============" )
93
- run_directory = f'{ run_directory_base } { sample_index } '
94
- if "passed.txt" in os .listdir (run_directory ) and not overwrite :
95
- print ("Skipping (Sample has already run successfully)" )
96
- else :
112
+ # Run cases
113
+ if evaluation_concurrency == 1 :
114
+ run_times = np .zeros (number_of_samples )
115
+ for sample_index in range (0 , number_of_samples ):
116
+ print ("======= Sample " + str (sample_index ) + " ============" )
97
117
print ("Running" )
98
- parameter_dict = _create_parameter_dict (parameter_names , parameter_samples [sample_index ])
99
- run_times [sample_index ] = run_sample (run_directory , model , parameter_dict )
100
- sample_stats_save_directory = f'{ run_directory_base } { sample_index } /../'
101
- np .savez (f'{ sample_stats_save_directory } /sampling_stats' ,
102
- run_times = run_times )
118
+ run_directory = f'{ run_directory_base } { sample_index } '
119
+ if "passed.txt" in os .listdir (run_directory ) and not overwrite :
120
+ print ("Skipping (Sample has already run successfully)" )
121
+ else :
122
+ print ("Running" )
123
+ parameter_dict = _create_parameter_dict (parameter_names , parameter_samples [sample_index ])
124
+ run_times [sample_index ] = run_sample (run_directory , model , parameter_dict )
125
+ sample_stats_save_directory = f'{ run_directory_base } { sample_index } /../'
126
+ np .savez (f'{ sample_stats_save_directory } /sampling_stats' ,
127
+ run_times = run_times )
128
+ else :
129
+ #Identify samples to run
130
+ samples_to_run = []
131
+ for sample_index in range (0 , number_of_samples ):
132
+ run_directory = f'{ run_directory_base } { sample_index } '
133
+ if "passed.txt" in os .listdir (run_directory ) and not overwrite :
134
+ print (f"Skipping sample { sample_index } (Sample has already run successfully)" )
135
+ pass
136
+ else :
137
+ samples_to_run .append (sample_index )
138
+ with concurrent .futures .ProcessPoolExecutor (max_workers = evaluation_concurrency , mp_context = mp_cntxt ) as executor :
139
+ these_futures = [executor .submit (run_sample ,
140
+ f'{ run_directory_base } { sample_id } ' , model ,
141
+ _create_parameter_dict (parameter_names , parameter_samples [sample_id ]))
142
+ for sample_id in samples_to_run ]
103
143
104
- return run_directories
144
+ # Wait for all processes to finish
145
+ concurrent .futures .wait (these_futures )
105
146
147
+ run_times = [future .result () for future in these_futures ]
148
+ sample_stats_save_directory = f'{ run_directory_base } { sample_index } /../'
149
+ np .savez (f'{ sample_stats_save_directory } /sampling_stats' , run_times = run_times )
150
+
151
+ return run_directories
106
152
107
- def run_sample (run_directory : str , model : Model ,
108
- parameter_sample : dict ):
109
- '''
110
- Execute individual sample
111
- '''
112
153
154
def run_sample(run_directory: str, model: Model, parameter_sample: dict):
    '''
    Execute one sample of the model inside run_directory and time it.

    Args:
        run_directory: directory for this sample; its name must end in
            ``_<id>`` so the run id can be recovered for log messages.
        model: object exposing ``run_model(run_directory, parameter_sample)``
            returning 0 on success (nonzero otherwise).
        parameter_sample: mapping of parameter name -> value for this run.

    Returns:
        Wall-clock run time in seconds.

    Side effects:
        On success (flag == 0) writes a ``passed.txt`` marker file into
        run_directory, which run_sampling uses to skip completed samples.
    '''
    run_id = _get_run_id_from_run_dir(run_directory)

    start = time.time()
    exit_flag = model.run_model(run_directory, parameter_sample)
    elapsed = time.time() - start

    if exit_flag == 0:
        print(f"Sample {run_id} is complete, run time = {elapsed}")
        # marker file consumed by run_sampling to detect already-finished runs
        np.savetxt(os.path.join(run_directory, 'passed.txt'), np.array([0]), '%i')
    else:
        print(f"Sample {run_id} failed, run time = {elapsed}")
    print(" ")
    return elapsed
0 commit comments