Skip to content

Commit f2417b9

Browse files
authored
Update perturbIC.py
1 parent 4378440 commit f2417b9

File tree

1 file changed

+18
-15
lines changed

1 file changed

+18
-15
lines changed

src/perturbIC.py

+18-15
Original file line numberDiff line numberDiff line change
@@ -25,11 +25,11 @@ def parse_args():
2525
parser = argparse.ArgumentParser(description="Perturb UM initial dump")
2626
parser.add_argument('-a', dest='amplitude', type=float, default=0.01,
2727
help = 'Amplitude of the perturbation.')
28-
parser.add_argument('-s','--seed', dest='seed', type=int
28+
parser.add_argument('-s','--seed', dest='seed', type=int,
2929
help = 'The seed value used to generate the random perturbation (must be a non-negative integer).')
3030
parser.add_argument('ifile', metavar="INPUT_PATH", help='Path to the input file.')
3131
parser.add_argument('--validate', action='store_true',
32-
help='Validate the output fields file using mule validation.)
32+
help='Validate the output fields file using mule validation.')
3333
args_parsed = parser.parse_args()
3434
return args_parsed
3535

@@ -47,9 +47,11 @@ def create_random_generator(value=None):
4747
numpy.random.Generator
4848
The numpy random generator object.
4949
"""
50+
51+
5052
if value < 0:
5153
raise ValueError('Seed value must be non-negative.')
52-
return Generator(PCG64(value))
54+
return Generator(PCG64(value))
5355

5456
def remove_timeseries(ff):
5557
"""
@@ -120,7 +122,7 @@ def create_perturbation(amplitude, random_generator, shape, nullify_poles = True
120122
perturbation = random_generator.uniform(low = -amplitude, high = amplitude, size = shape)
121123
# Set poles to zero (only necessary for ND grids, but doesn't hurt EG)
122124
if nullify_poles:
123-
perturbation[[0,-1],:] = 0
125+
perturbation[[0,-1],:] = 0
124126
return perturbation
125127

126128

@@ -165,17 +167,18 @@ def transform(self, source_field, new_field):
165167
Perform the field data manipulation: check that the array and source field data have the same shape and then add them together.
166168
"""
167169
data = source_field.get_data()
168-
if field_shape:=data.shape != array_shape:=self.array.shape:
170+
if (field_shape:=data.shape) != (array_shape:=self.array.shape):
169171
raise ValueError(f"Array and field could not be broadcast together with shapes {array_shape} and {field_shape}.")
170-
return data + self.array
172+
else:
173+
return data + self.array
171174

172175

173176
def void_validation(*args, **kwargs):
174177
"""
175178
Don't perform the validation, but print a message to inform that validation has been skipped.
176179
"""
177180
print('Skipping mule validation. To enable the validation, run using the "--validate" option.')
178-
181+
return
179182

180183

181184
def main():
@@ -191,35 +194,35 @@ def main():
191194

192195
# Create the output filename
193196
output_file = create_default_outname(args.ifile)
194-
197+
195198
# Create the random generator.
196199
random_generator = create_random_generator(args.seed)
197200

198201
# Skip mule validation if the "--validate" option is provided
199202
if args.validate:
200203
mule.DumpFile.validate = void_validation
201204
ff_raw = mule.DumpFile.from_file(args.ifile)
202-
205+
203206

204207
# Remove the time series from the data to ensure mule will work
205208
ff = remove_timeseries(ff_raw)
206209

207210
# loop through the fields
208211
for ifield, field in enumerate(ff.fields):
209212
if is_field_to_perturb(field, STASH_THETA):
210-
try:
213+
try:
211214
ff.fields[ifield] = perturb_operator(field)
212215
except NameError: # perturb_operator is not defined
213216
# Only create the perturb_operator if it does not exist yet
214-
shape = field.get_data().shape
215-
perturbation = create_perturbation(args.amplitude, random_generator, shape)
216-
perturb_operator = AdditionOperator(perturbation)
217-
ff.fields[ifield] = perturb_operator(field)
217+
218+
shape = field.get_data().shape
219+
perturbation = create_perturbation(args.amplitude, random_generator, shape)
220+
perturb_operator = AdditionOperator(perturbation)
221+
ff.fields[ifield] = perturb_operator(field)
218222

219223
ff.to_file(output_file)
220224

221225
if __name__== "__main__":
222226

223227
main()
224228

225-

0 commit comments

Comments
 (0)