Skip to content

Commit 50e5131

Browse files
committed
Add missing argument to cmd line parser to allow reshuffling of land points
1 parent 574a227 commit 50e5131

File tree

2 files changed

+110
-31
lines changed

2 files changed

+110
-31
lines changed

src/ecmwf_models/era5/reshuffle.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -262,6 +262,7 @@ def main(args):
262262
variables=args.variables,
263263
bbox=args.bbox,
264264
h_steps=args.h_steps,
265+
land_points=args.land_points,
265266
imgbuffer=args.imgbuffer,
266267
)
267268

tests/tests_era5/test_era5_reshuffle.py

Lines changed: 109 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -21,68 +21,146 @@
2121
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
2222
# SOFTWARE.
2323

24-
'''
24+
"""
2525
Test module for image to time series conversion.
26-
'''
26+
"""
2727

2828
import os
2929
import glob
3030
import tempfile
3131
import numpy as np
3232
import numpy.testing as nptest
33+
from datetime import datetime
34+
3335
from ecmwf_models.era5.reshuffle import main
3436
from ecmwf_models import ERATs
37+
from ecmwf_models.era5.reshuffle import parse_args
38+
39+
40+
def test_parse_args():
41+
42+
args = parse_args(
43+
[
44+
"/in",
45+
"/out",
46+
"2000-01-01",
47+
"2010-12-31",
48+
"swvl1",
49+
"swvl2",
50+
"--land_points",
51+
"True",
52+
"--imgbuffer",
53+
"1000",
54+
"--bbox",
55+
"12",
56+
"46",
57+
"17",
58+
"50",
59+
]
60+
)
61+
62+
assert isinstance(args.dataset_root, str) and args.dataset_root == "/in"
63+
assert (
64+
isinstance(args.timeseries_root, str)
65+
and args.timeseries_root == "/out"
66+
)
67+
assert isinstance(args.start, datetime) and args.start == datetime(
68+
2000, 1, 1
69+
)
70+
assert isinstance(args.end, datetime) and args.end == datetime(
71+
2010, 12, 31
72+
)
73+
assert isinstance(args.variables, list) and len(args.variables) == 2
74+
assert isinstance(args.land_points, bool) and args.land_points is True
75+
assert isinstance(args.imgbuffer, int) and args.imgbuffer == 1000
76+
assert (
77+
isinstance(args.bbox, list)
78+
and len(args.bbox) == 4
79+
and all([isinstance(a, float) for a in args.bbox])
80+
)
81+
3582

3683
def test_ERA5_reshuffle_nc():
3784
# test reshuffling era5 netcdf images to time series
3885

39-
inpath = os.path.join(os.path.dirname(os.path.abspath(__file__)), '..',
40-
"ecmwf_models-test-data", "ERA5", "netcdf")
86+
inpath = os.path.join(
87+
os.path.dirname(os.path.abspath(__file__)),
88+
"..",
89+
"ecmwf_models-test-data",
90+
"ERA5",
91+
"netcdf",
92+
)
4193

42-
startdate = '2010-01-01'
43-
enddate = '2010-01-01'
94+
startdate = "2010-01-01"
95+
enddate = "2010-01-01"
4496
parameters = ["swvl1", "swvl2"]
45-
h_steps = ['--h_steps', '0', '12']
46-
landpoints = ['--land_points', 'True']
47-
bbox = ['--bbox', "12", '46', '17', '50']
97+
h_steps = ["--h_steps", "0", "12"]
98+
landpoints = ["--land_points", "True"]
99+
bbox = ["--bbox", "12", "46", "17", "50"]
48100

49101
with tempfile.TemporaryDirectory() as ts_path:
50-
args = [inpath, ts_path, startdate, enddate] \
51-
+ parameters + h_steps + landpoints + bbox
102+
args = (
103+
[inpath, ts_path, startdate, enddate]
104+
+ parameters
105+
+ h_steps
106+
+ landpoints
107+
+ bbox
108+
)
52109
main(args)
53-
assert len(glob.glob(os.path.join(ts_path, "*.nc"))) == 5 # less files because only land points and bbox
54-
ds = ERATs(ts_path, ioclass_kws={'read_bulk': True})
110+
assert (
111+
len(glob.glob(os.path.join(ts_path, "*.nc"))) == 5
112+
) # less files because only land points and bbox
113+
ds = ERATs(ts_path, ioclass_kws={"read_bulk": True})
55114
ts = ds.read(15, 48)
56115
ds.close()
57-
swvl1_values_should = np.array([0.402825, 0.390983], dtype=np.float32)
58-
nptest.assert_allclose(ts['swvl1'].values, swvl1_values_should, rtol=1e-5)
59-
swvl2_values_should = np.array([0.390512, 0.390981], dtype=np.float32)
60-
nptest.assert_allclose(ts['swvl2'].values, swvl2_values_should, rtol=1e-5)
116+
swvl1_values_should = np.array([0.402825, 0.390983], dtype=np.float32)
117+
nptest.assert_allclose(
118+
ts["swvl1"].values, swvl1_values_should, rtol=1e-5
119+
)
120+
swvl2_values_should = np.array([0.390512, 0.390981], dtype=np.float32)
121+
nptest.assert_allclose(
122+
ts["swvl2"].values, swvl2_values_should, rtol=1e-5
123+
)
124+
61125

62126
def test_ERA5_reshuffle_grb():
63127
# test reshuffling era5 netcdf images to time series
64128

65-
inpath = os.path.join(os.path.dirname(os.path.abspath(__file__)), '..',
66-
"ecmwf_models-test-data", "ERA5", "netcdf")
67-
startdate = '2010-01-01'
68-
enddate = '2010-01-01'
129+
inpath = os.path.join(
130+
os.path.dirname(os.path.abspath(__file__)),
131+
"..",
132+
"ecmwf_models-test-data",
133+
"ERA5",
134+
"netcdf",
135+
)
136+
startdate = "2010-01-01"
137+
enddate = "2010-01-01"
69138
parameters = ["swvl1", "swvl2"]
70-
h_steps = ['--h_steps', '0', '12']
71-
landpoints = ['--land_points', 'False']
72-
bbox = ['--bbox', "12", '46', '17', '50']
139+
h_steps = ["--h_steps", "0", "12"]
140+
landpoints = ["--land_points", "False"]
141+
bbox = ["--bbox", "12", "46", "17", "50"]
73142

74143
with tempfile.TemporaryDirectory() as ts_path:
75144

76-
args = [inpath, ts_path, startdate, enddate] + parameters + \
77-
h_steps + landpoints + bbox
145+
args = (
146+
[inpath, ts_path, startdate, enddate]
147+
+ parameters
148+
+ h_steps
149+
+ landpoints
150+
+ bbox
151+
)
78152

79153
main(args)
80154

81155
assert len(glob.glob(os.path.join(ts_path, "*.nc"))) == 5
82-
ds = ERATs(ts_path, ioclass_kws={'read_bulk': True})
156+
ds = ERATs(ts_path, ioclass_kws={"read_bulk": True})
83157
ts = ds.read(15, 48)
84158
ds.close()
85-
swvl1_values_should = np.array([0.402824, 0.390979], dtype=np.float32)
86-
nptest.assert_allclose(ts['swvl1'].values, swvl1_values_should, rtol=1e-5)
87-
swvl2_values_should = np.array([0.390514, 0.390980], dtype=np.float32)
88-
nptest.assert_allclose(ts['swvl2'].values, swvl2_values_should, rtol=1e-5)
159+
swvl1_values_should = np.array([0.402824, 0.390979], dtype=np.float32)
160+
nptest.assert_allclose(
161+
ts["swvl1"].values, swvl1_values_should, rtol=1e-5
162+
)
163+
swvl2_values_should = np.array([0.390514, 0.390980], dtype=np.float32)
164+
nptest.assert_allclose(
165+
ts["swvl2"].values, swvl2_values_should, rtol=1e-5
166+
)

0 commit comments

Comments
 (0)