Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ dependencies = [
"netCDF4",
"cftime",
"dask",
"distributed>=2024.0.0",
"pyyaml",
"tqdm",
"requests",
Expand Down
346 changes: 345 additions & 1 deletion tests/unit/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,11 @@
"""

from pathlib import Path
from unittest.mock import Mock
from unittest.mock import MagicMock, Mock, patch

import numpy as np
import pytest
import xarray as xr

from access_moppy.base import CMIP6_CMORiser

Expand Down Expand Up @@ -189,3 +191,345 @@ def test_getattr_fallback(self, mock_vocab, mock_mapping, temp_dir):
# When ds is None, getattr should raise AttributeError
with pytest.raises(AttributeError):
_ = cmoriser.nonexistent_attribute


class TestCMIP6CMORiserWrite:
"""Unit tests for CMIP6_CMORiser.write() method with memory validation."""

# ==================== Fixtures ====================

@pytest.fixture
def mock_vocab(self):
"""Mock CMIP6 vocabulary object."""
vocab = Mock()
vocab.get_table = Mock(return_value={"tas": {"units": "K"}})
return vocab

@pytest.fixture
def mock_mapping(self):
"""Mock variable mapping."""
return {
"CF standard Name": "air_temperature",
"units": "K",
"dimensions": {"time": "time", "lat": "lat", "lon": "lon"},
"positive": None,
}

@pytest.fixture
def sample_dataset(self):
"""
Create a sample xarray Dataset for testing.

Dataset structure:
- tas: main variable (12 time steps × 10 lat × 10 lon, float32)
- time_bnds: time bounds
- All required CMIP6 global attributes included
"""
time = np.arange(12)
lat = np.arange(10)
lon = np.arange(10)

data = np.random.rand(12, 10, 10).astype(np.float32)

ds = xr.Dataset(
{
"tas": (["time", "lat", "lon"], data, {"_FillValue": 1e20}),
"time_bnds": (["time", "bnds"], np.zeros((12, 2))),
},
coords={
"time": (
"time",
time,
{"units": "days since 2000-01-01", "calendar": "standard"},
),
"lat": ("lat", lat),
"lon": ("lon", lon),
},
attrs={
"variable_id": "tas",
"table_id": "Amon",
"source_id": "ACCESS-ESM1-5",
"experiment_id": "historical",
"variant_label": "r1i1p1f1",
"grid_label": "gn",
},
)
return ds

@pytest.fixture
def sample_dataset_missing_attrs(self):
"""Create a dataset missing required CMIP6 attributes."""
ds = xr.Dataset(
{"tas": (["time"], np.zeros(10))},
coords={
"time": (
"time",
np.arange(10),
{"units": "days since 2000-01-01", "calendar": "standard"},
),
},
attrs={"variable_id": "tas"}, # Missing other required attrs
)
return ds

@pytest.fixture
def cmoriser_with_dataset(self, mock_vocab, mock_mapping, sample_dataset, temp_dir):
"""Create a CMORiser instance with a valid dataset attached."""
cmoriser = CMIP6_CMORiser(
input_paths=["test.nc"],
output_path=str(temp_dir),
cmip6_vocab=mock_vocab,
variable_mapping=mock_mapping,
compound_name="Amon.tas",
)
cmoriser.ds = sample_dataset
cmoriser.cmor_name = "tas"
return cmoriser

# ==================== Attribute Validation Tests ====================

@pytest.mark.unit
def test_write_raises_error_when_missing_required_attributes(
self, mock_vocab, mock_mapping, sample_dataset_missing_attrs, temp_dir
):
"""
Test that write() raises ValueError when required CMIP6 attributes are missing.

Required attributes: variable_id, table_id, source_id, experiment_id,
variant_label, grid_label
"""
cmoriser = CMIP6_CMORiser(
input_paths=["test.nc"],
output_path=str(temp_dir),
cmip6_vocab=mock_vocab,
variable_mapping=mock_mapping,
compound_name="Amon.tas",
)
cmoriser.ds = sample_dataset_missing_attrs
cmoriser.cmor_name = "tas"

with pytest.raises(
ValueError, match="Missing required CMIP6 global attributes"
):
cmoriser.write()

# ==================== Memory Estimation Tests ====================

@pytest.mark.unit
def test_write_data_size_estimation(self, cmoriser_with_dataset):
"""
Test that data size estimation is reasonable.

Sample dataset: float32 (4 bytes) × 12 × 10 × 10 = 4,800 bytes for main var
With 1.5x overhead factor, total should be well under 1 GB.
"""
ds = cmoriser_with_dataset.ds

# Calculate expected size manually
total_size = 0
for var in ds.variables:
vdat = ds[var]
var_size = vdat.dtype.itemsize
for dim in vdat.dims:
var_size *= ds.sizes[dim]
total_size += var_size

expected_size_with_overhead = int(total_size * 1.5)

# Verify the size is small (test data should be < 1 MB)
assert expected_size_with_overhead < 1 * 1024**2

# ==================== System Memory Check Tests ====================

@pytest.mark.unit
def test_write_proceeds_when_system_memory_sufficient(
self, cmoriser_with_dataset, temp_dir
):
"""
Test that write() proceeds normally when system memory is sufficient.

Scenario: No Dask client, plenty of system memory available.
Expected: File is created successfully.
"""
with patch("psutil.virtual_memory") as mock_mem:
# Mock sufficient available memory (16 GB)
mock_mem.return_value = MagicMock(
total=32 * 1024**3,
available=16 * 1024**3,
)

with patch(
"dask.distributed.get_client", side_effect=ValueError("No client")
):
cmoriser_with_dataset.write()

# Verify output file was created
output_files = list(Path(temp_dir).glob("*.nc"))
assert len(output_files) == 1

# ==================== Import Error Handling Tests ====================

    @pytest.mark.unit
    def test_write_handles_distributed_not_installed(
        self, cmoriser_with_dataset, temp_dir
    ):
        """
        Test graceful handling when dask.distributed is not installed.

        Scenario: dask.distributed import raises ImportError.
        Expected: Falls back to system memory check and proceeds.

        NOTE(review): ``patch("dask.distributed.get_client", ...)`` needs
        dask.distributed to be importable in the test environment, so this
        simulates an ImportError at the get_client call site rather than a
        genuinely missing package — confirm write() catches ImportError at
        that point (it would not intercept a module-level
        ``from dask.distributed import get_client``).
        """
        with patch("psutil.virtual_memory") as mock_mem:
            # Report ample memory so only the import failure is exercised.
            mock_mem.return_value = MagicMock(
                total=32 * 1024**3,
                available=16 * 1024**3,
            )

            # Mock ImportError when trying to import dask.distributed
            with patch(
                "dask.distributed.get_client",
                side_effect=ImportError("No module named 'distributed'"),
            ):
                cmoriser_with_dataset.write()

        output_files = list(Path(temp_dir).glob("*.nc"))
        assert len(output_files) == 1

# ==================== Output File Tests ====================

@pytest.mark.unit
def test_write_creates_correct_cmip6_filename(
self, cmoriser_with_dataset, temp_dir
):
"""
Test that write() creates file with correct CMIP6 filename format.

Expected format: {var}_{table}_{source}_{exp}_{variant}_{grid}_{timerange}.nc
Example: tas_Amon_ACCESS-ESM1-5_historical_r1i1p1f1_gn_200001-200012.nc
"""
with patch("psutil.virtual_memory") as mock_mem:
mock_mem.return_value = MagicMock(
total=32 * 1024**3,
available=16 * 1024**3,
)

with patch(
"dask.distributed.get_client", side_effect=ValueError("No client")
):
cmoriser_with_dataset.write()

output_files = list(Path(temp_dir).glob("*.nc"))
assert len(output_files) == 1

filename = output_files[0].name

# Check filename components
assert filename.startswith("tas_")
assert "_Amon_" in filename
assert "_ACCESS-ESM1-5_" in filename
assert "_historical_" in filename
assert "_r1i1p1f1_" in filename
assert "_gn_" in filename
assert filename.endswith(".nc")

@pytest.mark.unit
def test_write_creates_valid_netcdf_structure(
self, cmoriser_with_dataset, temp_dir
):
"""
Test that write() creates a valid NetCDF file with correct structure.

Verifies:
- Required dimensions exist
- Main variable exists with correct shape
- Global attributes are preserved
"""
with patch("psutil.virtual_memory") as mock_mem:
mock_mem.return_value = MagicMock(
total=32 * 1024**3,
available=16 * 1024**3,
)

with patch(
"dask.distributed.get_client", side_effect=ValueError("No client")
):
cmoriser_with_dataset.write()

output_files = list(Path(temp_dir).glob("*.nc"))
output_file = output_files[0]

# Read back and verify structure
ds_out = xr.open_dataset(output_file)

try:
# Check dimensions
assert "time" in ds_out.dims
assert "lat" in ds_out.dims
assert "lon" in ds_out.dims

# Check main variable
assert "tas" in ds_out.data_vars
assert ds_out["tas"].dims == ("time", "lat", "lon")

# Check global attributes
assert ds_out.attrs["variable_id"] == "tas"
assert ds_out.attrs["table_id"] == "Amon"
assert ds_out.attrs["source_id"] == "ACCESS-ESM1-5"
assert ds_out.attrs["experiment_id"] == "historical"
assert ds_out.attrs["variant_label"] == "r1i1p1f1"
assert ds_out.attrs["grid_label"] == "gn"
finally:
ds_out.close()

@pytest.mark.unit
def test_write_preserves_data_values(self, cmoriser_with_dataset, temp_dir):
"""
Test that write() preserves data values correctly.

Verifies that data written to file matches original data.
"""
with patch("psutil.virtual_memory") as mock_mem:
mock_mem.return_value = MagicMock(
total=32 * 1024**3,
available=16 * 1024**3,
)

with patch(
"dask.distributed.get_client", side_effect=ValueError("No client")
):
original_data = cmoriser_with_dataset.ds["tas"].values.copy()

cmoriser_with_dataset.write()

output_files = list(Path(temp_dir).glob("*.nc"))
ds_out = xr.open_dataset(output_files[0])

try:
np.testing.assert_array_almost_equal(
ds_out["tas"].values, original_data
)
finally:
ds_out.close()

# ==================== Logging Tests ====================

@pytest.mark.unit
def test_write_prints_output_path(self, cmoriser_with_dataset, temp_dir, capsys):
"""
Test that write() prints the output file path after completion.
"""
with patch("psutil.virtual_memory") as mock_mem:
mock_mem.return_value = MagicMock(
total=32 * 1024**3,
available=16 * 1024**3,
)

with patch(
"dask.distributed.get_client", side_effect=ValueError("No client")
):
cmoriser_with_dataset.write()

captured = capsys.readouterr()

assert "CMORised output written to" in captured.out
assert str(temp_dir) in captured.out