Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Unit testing #38

Draft
wants to merge 8 commits into
base: main
Choose a base branch
from
5 changes: 5 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
test_data/
.coverage
*.sh
__pycache__/
*.py[cod]
121 changes: 94 additions & 27 deletions src/hres_ic.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,57 +14,124 @@

from pathlib import Path
import argparse
from datetime import datetime,timedelta
import pandas
from datetime import datetime
import shutil

import replace_landsurface_with_ERA5land_IC
import replace_landsurface_with_BARRA2R_IC

boolopt = {
"True": True,
"False": False,
}
INPUT_TIME_FORMAT = "%Y%m%d%H%M"
OUTPUT_TIME_FORMAT = "%Y%m%dT%H%MZ"

def get_start_time(time):
"""
Convert the time from the input string format to the desired string format

def main():
Parameters
----------
time: str
The time in the input string format

Returns
-------
str
The time in the desired string format
"""
The main function that creates a worker pool and generates single GRIB files
for requested date/times in parallel.
return datetime.strptime(time,INPUT_TIME_FORMAT).strftime(OUTPUT_TIME_FORMAT)

def replace_input_file_with_tmp_input_file(tmp_path):
"""
Swaps the newly-created temporary input file with the original input file, by
removing the '.tmp' extension from the temporary file path.

Parameters
----------
None. The arguments are given via the command-line
tmp_path: PosixPath
The temporary path with the '.tmp' extension.

Returns
-------
None. The astart file is updated and overwritten
"""
None
"""
if tmp_path.suffix == '.tmp':
shutil.move(tmp_path, tmp_path.with_suffix(''))
else:
raise ValueError(f"Expected a path ending in '.tmp', got '{tmp_path}'.")

def parse_arguments():
"""
Parses the command line arguments.

Parameters
----------
None

Returns
-------
argparse.Namespace
The parsed command line arguments
"""
# Parse the command-line arguments
parser = argparse.ArgumentParser()
parser.add_argument('--mask', required=True, type=Path)
parser.add_argument('--file', required=True, type=Path)
parser.add_argument('--start', required=True, type=pandas.to_datetime)
parser.add_argument('--start', required=True, type=str)
parser.add_argument('--type', default="era5land")
args = parser.parse_args()
print(args)
return parser.parse_args()

def set_replace_function(data_type):
"""
Set what replace function to use, based on the command line argument '--type'.

Parameters
----------
data_type: str
Type of data source for the replacement. The value of the command line argument '--type'.

Returns
-------
callable or None
The function to be called for the replacement, or None if no replacement needs to take place.
"""
if data_type == "era5land":
replace_function = replace_landsurface_with_ERA5land_IC.swap_land_era5land
elif data_type == "barra":
replace_function = replace_landsurface_with_BARRA2R_IC.swap_land_barra
else:
replace_function = None
return replace_function


def main():
"""
The main function that creates a worker pool and generates single GRIB files
for requested date/times in parallel.

Parameters
----------
None. The arguments are given via the command-line

Returns
-------
None. The astart file is updated and overwritten
"""

args = parse_arguments()
print(f"{args=}")

# Convert the date/time to a formatted string
t = args.start.strftime("%Y%m%dT%H%MZ")
print(args.mask, args.file, t)

# If necessary replace ERA5 land/surface fields with higher-resolution options
if "era5land" in args.type:
replace_landsurface_with_ERA5land_IC.swap_land_era5land(args.mask, args.file, t)
shutil.move(args.file.as_posix(), args.file.as_posix().replace('.tmp', ''))
elif "barra" in args.type:
replace_landsurface_with_BARRA2R_IC.swap_land_barra(args.mask, args.file, t)
shutil.move(args.file.as_posix(), args.file.as_posix().replace('.tmp', ''))
t = get_start_time(args.start)
print(f"mask = {args.mask}")
print(f"replacement_file = {args.file}")
print(f"start_time = {t}")

# If necessary replace land/surface fields with higher-resolution options
replace_function = set_replace_function(args.type)
if replace_function is not None:
replace_function(args.mask, args.file, t)
replace_input_file_with_tmp_input_file(args.file)
else:
print("No need to swap out IC")

if __name__ == '__main__':
main()
main() # pragma: no cover

11 changes: 6 additions & 5 deletions src/replace_landsurface_with_ERA5land_IC.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ def get_ERA_nc_data(ncfname, FIELDN, wanted_dt, bounds):
latmin_index, latmax_index = bounds.latmin, bounds.latmax

# Open the file containing the data
print(ncfname, FIELDN)
print(f"Requested variable {FIELDN} in file {ncfname}.")
if Path(ncfname).exists():
d = xr.open_dataset(ncfname)
else:
Expand All @@ -150,7 +150,7 @@ def get_ERA_nc_data(ncfname, FIELDN, wanted_dt, bounds):
# Find the array index for the date/time of interest
times=d['time'].dt.strftime("%Y%m%d%H%M").data
TM=times.tolist().index(wanted_dt)
print(TM)
print(f"Index of requested time in data: {TM}")

# Read the data
if lonmin_index < lonmax_index:
Expand Down Expand Up @@ -220,7 +220,7 @@ def swap_land_era5land(mask_fullpath, ic_file_fullpath, ic_date):
ic_file = ic_file_fullpath.parts[-1].replace('.tmp', '')

# create date/time useful information
print(ic_date)
print(f"Requested date: {ic_date}")
yyyy = ic_date[0:4]
mm = ic_date[4:6]
ic_z_date = ic_date.replace('T', '').replace('Z', '')
Expand All @@ -238,7 +238,8 @@ def swap_land_era5land(mask_fullpath, ic_file_fullpath, ic_date):

# Path to output file
ff_out = ic_file_fullpath.as_posix()
print(ff_in, ff_out)
print(f"Input file: '{ff_in}'")
print(f"Output file: '{ff_out}'")

# Read input file
mf_in = mule.load_umfile(ff_in)
Expand All @@ -255,7 +256,7 @@ def swap_land_era5land(mask_fullpath, ic_file_fullpath, ic_date):
# For each field in the input write to the output file (but modify as required)
for f in mf_in.fields:

print(f.lbuser4, f.lblev, f.lblrec, f.lbhr, f.lbcode)
print(f"{f.lbuser4=}", f"{f.lblev=}", f"{f.lblrec=}", f"{f.lbhr=}", f"{f.lbcode=}")

if f.lbuser4 == 9:
# replace coarse soil moisture with high-res information
Expand Down
135 changes: 135 additions & 0 deletions tests/test_hres.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
import pytest
from unittest.mock import patch, Mock
from pathlib import Path

# TODO: place ROSE DATA into a function (or an input argument) so it doesn't need to get called when importing the module
import os
os.environ['ROSE_DATA'] = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))),'test_data/rose_data')

# TODO: Turn src into a package so that we can import the function directly
# For now required to import from src
import sys #To delete when src is a package
srcpath = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))),'src') #To delete when src is a package
sys.path.insert(0,srcpath) #To delete when src is a package

from hres_ic import get_start_time, replace_input_file_with_tmp_input_file, parse_arguments, set_replace_function, main

del sys.path[0] #To delete when src is a package

# Test get_start_time
def test_get_start_time():
time = "199205041155"
assert get_start_time(time) == "19920504T1155Z"

# Test replace_input_file_with_tmp_input_file
def test_replace_input_file_with_tmp_input_file():
tmppath = Path('example/of/.tmp/file.tmp')
newpath = Path('example/of/.tmp/file')
# Mock the shutil.move function
with patch('shutil.move') as mock_move:
replace_input_file_with_tmp_input_file(tmppath)
mock_move.assert_called_once_with(tmppath,newpath)

def test_replace_input_file_with_tmp_input_file_fail():
tmppath = Path('example/of/.tmp/invalid/filetmp')
# Mock the shutil.move function
with patch('shutil.move'):
with pytest.raises(ValueError):
replace_input_file_with_tmp_input_file(tmppath)

# Test parse_arguments
@patch('sys.argv', ['program_name', '--mask', 'mask_path', '--file', 'file_path', '--start', '202408121230'])
def test_parse_arguments_success():
args = parse_arguments()
assert args.mask == Path('mask_path')
assert args.file == Path('file_path')
assert args.start == '202408121230'
assert args.type == 'era5land'

@patch('sys.argv', ['program_name', '--mask', 'mask_path', '--file', 'file_path', '--start', '202408121230', '--type', 'newtype'])
def test_parse_arguments_with_type():
args = parse_arguments()
assert args.type == 'newtype'

@patch('sys.argv', ['program_name', '--file', 'file_path', '--start', '202408121230'])
def test_parse_arguments_missing_mask():
with pytest.raises(SystemExit):
parse_arguments()

@patch('sys.argv', ['program_name', '--mask', 'mask_path', '--start', '2024-08-12'])
def test_parse_arguments_missing_file():
with pytest.raises(SystemExit):
parse_arguments()

@patch('sys.argv', ['program_name', '--mask', 'mask_path', '--file', 'file_path'])
def test_parse_arguments_missing_start():
with pytest.raises(SystemExit):
parse_arguments()

# Test set_replace_function
@patch('replace_landsurface_with_ERA5land_IC.swap_land_era5land')
def test_set_replace_function_era5land(mock_era5land):
result = set_replace_function("era5land")
assert result == mock_era5land

@patch('replace_landsurface_with_BARRA2R_IC.swap_land_barra')
def test_set_replace_function_barra(mock_barra):
result = set_replace_function("barra")
assert result == mock_barra

def test_set_replace_function_unknown():
result = set_replace_function("unknown")
assert result is None

# Test main function
@patch('hres_ic.parse_arguments')
@patch('hres_ic.get_start_time')
@patch('hres_ic.set_replace_function')
@patch('hres_ic.replace_input_file_with_tmp_input_file')
def test_main_with_replacement(mock_replace_input, mock_set_replace, mock_get_start, mock_parse_args):
# Mock the arguments returned by parse_arguments
mock_args = Mock()
mock_args.mask = "mock_mask"
mock_args.file = "mock_file"
mock_args.start = "mock_start"
mock_args.type = "era5land"
mock_parse_args.return_value = mock_args

# Mock the return value of get_start_time
mock_get_start.return_value = "mock_time"

# Mock the replacement function
mock_replace_func = Mock()
mock_set_replace.return_value = mock_replace_func

main()
mock_parse_args.assert_called_once()
mock_get_start.assert_called_once_with("mock_start")
mock_set_replace.assert_called_once_with("era5land")
mock_replace_func.assert_called_once_with("mock_mask", "mock_file", "mock_time")
mock_replace_input.assert_called_once_with("mock_file")

@patch('hres_ic.parse_arguments')
@patch('hres_ic.get_start_time')
@patch('hres_ic.set_replace_function')
@patch('hres_ic.replace_input_file_with_tmp_input_file')
def test_main_without_replacement(mock_replace_input, mock_set_replace, mock_get_start, mock_parse_args):
# Mock the arguments returned by parse_arguments
mock_args = Mock()
mock_args.mask = "mock_mask"
mock_args.file = "mock_file"
mock_args.start = "mock_start"
mock_args.type = "unknown_type"
mock_parse_args.return_value = mock_args

# Mock the return value of get_start_time
mock_get_start.return_value = "mock_time"

# Mock the replacement function to return None
mock_set_replace.return_value = None

main()
mock_parse_args.assert_called_once()
mock_get_start.assert_called_once_with("mock_start")
mock_set_replace.assert_called_once_with("unknown_type")
mock_replace_input.assert_not_called() # Should not be called since replace_function is None
16 changes: 16 additions & 0 deletions tests/test_replace_landsurface_with_era5land.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import pytest
from unittest.mock import patch

# TODO: place ROSE DATA into a function (or an input argument) so it doesn't need to get called when importing the module
import os
os.environ['ROSE_DATA'] = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))),'test_data/rose_data')

# TODO: Turn src into a package so that we can import the function directly
# For now required to import from src
import sys #To delete when src is a package
srcpath = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))),'src') #To delete when src is a package
sys.path.insert(0,srcpath) #To delete when src is a package

from replace_landsurface_with_ERA5land_IC import bounding_box

del sys.path[0] #To delete when src is a package
Loading