1 change: 1 addition & 0 deletions pyproject.toml
@@ -25,6 +25,7 @@ dependencies = [
"netCDF4",
"cftime",
"dask",
"distributed>=2024.0.0",
"pyyaml",
"tqdm",
"requests",
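The new pin is there because base.py (below) now imports from dask.distributed, which is provided by the distributed package. A minimal sanity check, assuming the target environment already has it installed (not part of the diff):

import distributed
print(distributed.__version__)  # the pin above expects >= 2024.0.0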
85 changes: 85 additions & 0 deletions src/access_moppy/base.py
@@ -4,8 +4,10 @@
from typing import Any, Dict, List, Optional, Union

import netCDF4 as nc
import psutil
import xarray as xr
from cftime import num2date
from dask.distributed import get_client

from access_moppy.utilities import (
    FrequencyMismatchError,
@@ -275,6 +277,89 @@ def write(self):
f"Missing required CMIP6 global attributes for filename: {missing}"
)

        # ========== Memory Check ==========
        # This section estimates the data size and compares it against available memory
        # to prevent out-of-memory errors during the write operation.

        def estimate_data_size(ds, cmor_name):
            total_size = 0
            for var in ds.variables:
                vdat = ds[var]
                # Start with the size of a single element (e.g., 4 bytes for float32)
                var_size = vdat.dtype.itemsize
                # Multiply by the size of each dimension to get total elements
                for dim in vdat.dims:
                    var_size *= ds.sizes[dim]
                total_size += var_size
            # Apply 1.5x overhead factor for safe memory estimation
            return int(total_size * 1.5)

        # Calculate the estimated data size for this dataset
        data_size = estimate_data_size(self.ds, self.cmor_name)

        # Get system memory information using psutil
        available_memory = psutil.virtual_memory().available

        # ========== Dask Client Detection ==========
        # Check if a Dask distributed client exists, as this affects how we handle
        # memory management. Dask clusters have their own memory limits separate
        # from system memory.

        client = None
        worker_memory = None  # Memory limit of a single worker
        total_cluster_memory = None  # Sum of all workers' memory limits

        try:
            # Attempt to get an existing Dask client
            client = get_client()

            # Retrieve information about all workers in the cluster
            worker_info = client.scheduler_info()["workers"]

            if worker_info:
                # Get the minimum memory_limit across all workers
                worker_memory = min(w["memory_limit"] for w in worker_info.values())

                # Sum up all workers' memory for total cluster capacity
                total_cluster_memory = sum(
                    w["memory_limit"] for w in worker_info.values()
                )

        except ValueError:
            # No Dask client exists - we'll use local/system memory for writing
            pass

        # ========== Memory Validation Logic ==========
        # This section implements a decision tree based on data size vs available memory:

        if client is not None:
            # Dask client exists - check against cluster memory limits
            # (worker_memory stays None if the cluster reports no workers)
            if worker_memory is not None and data_size > worker_memory:
                # WARNING: Data fits in total cluster memory but exceeds single worker capacity
                print(
                    f"Warning: Data size ({data_size / 1024**3:.2f} GB) exceeds single worker memory "
                    f"({worker_memory / 1024**3:.2f} GB) but fits in total cluster memory "
                    f"({total_cluster_memory / 1024**3:.2f} GB)."
                )
                print("Closing Dask client to use local memory for writing...")
                client.close()
                client = None

            # If data_size <= worker_memory: no action needed, proceed with the write

        if data_size > available_memory:
            # Data exceeds available system memory
            raise MemoryError(
                f"Data size ({data_size / 1024**3:.2f} GB) exceeds available system memory "
                f"({available_memory / 1024**3:.2f} GB). "
                f"Consider using write_parallel() for chunked writing."
            )

        # Log the memory status for user awareness
        print(
            f"Data size: {data_size / 1024**3:.2f} GB, Available memory: {available_memory / 1024**3:.2f} GB"
        )

        time_var = self.ds[self.cmor_name].coords["time"]
        units = time_var.attrs["units"]
        calendar = time_var.attrs.get("calendar", "standard").lower()
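For reviewers, a standalone sketch (not part of this PR) of the detection pattern the new write() logic relies on: get_client() raises ValueError when no distributed client is active, scheduler_info() exposes per-worker memory limits, and psutil provides the system fallback. The helper name available_write_memory is illustrative only.

import psutil
from dask.distributed import Client, get_client


def available_write_memory():
    """Smallest per-worker limit if a Dask client is active, else system memory."""
    try:
        client = get_client()  # raises ValueError when no client is active
        workers = client.scheduler_info()["workers"]
        if workers:
            return min(w["memory_limit"] for w in workers.values())
    except ValueError:
        pass
    # No usable cluster: fall back to the machine's available memory
    return psutil.virtual_memory().available


if __name__ == "__main__":
    # With a local cluster, the limit comes from the workers...
    with Client(n_workers=2, threads_per_worker=1, memory_limit="2GB"):
        print(available_write_memory())
    # ...and without one, it falls back to psutil.
    print(available_write_memory())

For the raw data size itself, xarray's Dataset.nbytes gives a figure comparable to what the diff's estimate_data_size() helper computes before its 1.5x overhead factor.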