diff --git a/pyproject.toml b/pyproject.toml
index efe54fd..5a1e833 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -25,6 +25,7 @@ dependencies = [
     "netCDF4",
     "cftime",
     "dask",
+    "distributed>=2024.0.0",
     "pyyaml",
     "tqdm",
     "requests",
diff --git a/src/access_moppy/base.py b/src/access_moppy/base.py
index e8f9c92..b2c2930 100644
--- a/src/access_moppy/base.py
+++ b/src/access_moppy/base.py
@@ -4,8 +4,10 @@
 from typing import Any, Dict, List, Optional, Union
 
 import netCDF4 as nc
+import psutil
 import xarray as xr
 from cftime import num2date
+from dask.distributed import get_client
 
 from access_moppy.utilities import (
     FrequencyMismatchError,
@@ -275,6 +277,89 @@ def write(self):
                 f"Missing required CMIP6 global attributes for filename: {missing}"
             )
 
+        # ========== Memory Check ==========
+        # This section estimates the data size and compares it against available
+        # memory to prevent out-of-memory errors during the write operation.
+
+        def estimate_data_size(ds, cmor_name):
+            total_size = 0
+            for var in ds.variables:
+                vdat = ds[var]
+                # Start with the size of a single element (e.g., 4 bytes for float32)
+                var_size = vdat.dtype.itemsize
+                # Multiply by the length of each dimension to get the size in bytes
+                for dim in vdat.dims:
+                    var_size *= ds.sizes[dim]
+                total_size += var_size
+            # Apply a 1.5x overhead factor for a conservative memory estimate
+            return int(total_size * 1.5)
+
+        # Calculate the estimated data size for this dataset
+        data_size = estimate_data_size(self.ds, self.cmor_name)
+
+        # Get system memory information using psutil
+        available_memory = psutil.virtual_memory().available
+
+        # ========== Dask Client Detection ==========
+        # Check whether a Dask distributed client exists, as this affects how we
+        # handle memory management. Dask clusters have their own memory limits,
+        # separate from system memory.
+
+        client = None
+        worker_memory = None  # Memory limit of a single worker
+        total_cluster_memory = None  # Sum of all workers' memory limits
+
+        try:
+            # Attempt to get an existing Dask client
+            client = get_client()
+
+            # Retrieve information about all workers in the cluster
+            worker_info = client.scheduler_info()["workers"]
+
+            if worker_info:
+                # Get the minimum memory_limit across all workers
+                worker_memory = min(w["memory_limit"] for w in worker_info.values())
+
+                # Sum up all workers' memory for total cluster capacity
+                total_cluster_memory = sum(
+                    w["memory_limit"] for w in worker_info.values()
+                )
+
+        except ValueError:
+            # No Dask client exists - use local/system memory for writing
+            pass
+
+        # ========== Memory Validation Logic ==========
+        # Decide how to write based on the data size versus the available memory:
+
+        if client is not None and worker_memory is not None:
+            # Dask client exists - check against the per-worker memory limit
+            if data_size > worker_memory:
+                # WARNING: the data exceeds a single worker's memory limit
+                print(
+                    f"Warning: Data size ({data_size / 1024**3:.2f} GB) exceeds single worker memory "
+                    f"({worker_memory / 1024**3:.2f} GB); total cluster memory is "
+                    f"{total_cluster_memory / 1024**3:.2f} GB."
+                )
+                print("Closing Dask client to use local memory for writing...")
+                client.close()
+                client = None
+
+            # If data_size <= worker_memory: no action needed, proceed with the write
+
+        if data_size > available_memory:
+            # Data exceeds available system memory
+            raise MemoryError(
+                f"Data size ({data_size / 1024**3:.2f} GB) exceeds available system memory "
+                f"({available_memory / 1024**3:.2f} GB). "
+                f"Consider using write_parallel() for chunked writing."
+            )
+
+        # Log the memory status for user awareness
+        print(
+            f"Data size: {data_size / 1024**3:.2f} GB, Available memory: {available_memory / 1024**3:.2f} GB"
+        )
+
         time_var = self.ds[self.cmor_name].coords["time"]
         units = time_var.attrs["units"]
         calendar = time_var.attrs.get("calendar", "standard").lower()
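
For context, a minimal standalone sketch of the size accounting the patch introduces, run against an illustrative in-memory xarray Dataset; the variable names and dimensions below are assumptions for demonstration only and are not part of the patch.

import numpy as np
import xarray as xr

# Illustrative dataset: one monthly surface field plus a time coordinate
ds = xr.Dataset(
    {"tas": (("time", "lat", "lon"), np.zeros((12, 180, 360), dtype="float32"))},
    coords={"time": np.arange(12)},
)

# Same accounting as estimate_data_size(): itemsize times each dimension length
total_size = 0
for var in ds.variables:
    var_size = ds[var].dtype.itemsize
    for dim in ds[var].dims:
        var_size *= ds.sizes[dim]
    total_size += var_size

# 1.5x overhead factor, matching the patch's conservative estimate
estimate = int(total_size * 1.5)

# xarray's own byte count gives the same base figure
assert estimate == int(ds.nbytes * 1.5)
print(f"Estimated write footprint: {estimate / 1024**3:.4f} GB")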