diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index d0d00907..8c12ab64 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -27,7 +27,7 @@ repos: - id: python-check-blanket-noqa # Check for # noqa: all - id: python-no-log-warn # Check for log.warn - repo: https://github.com/psf/black-pre-commit-mirror - rev: 24.8.0 + rev: 24.10.0 hooks: - id: black args: [--line-length=120] @@ -40,7 +40,7 @@ repos: - --force-single-line-imports - --profile black - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.6.9 + rev: v0.7.2 hooks: - id: ruff args: @@ -59,13 +59,8 @@ repos: hooks: - id: rstfmt exclude: 'cli/.*' # Because we use argparse -- repo: https://github.com/b8raoult/pre-commit-docconvert - rev: "0.1.5" - hooks: - - id: docconvert - args: ["numpy"] - repo: https://github.com/tox-dev/pyproject-fmt - rev: "2.2.4" + rev: "v2.5.0" hooks: - id: pyproject-fmt diff --git a/CHANGELOG.md b/CHANGELOG.md index 3c3c63fa..3d808e7f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,9 +10,13 @@ Keep it human-readable, your future self will thank you! ## [Unreleased](https://github.com/ecmwf/anemoi-datasets/compare/0.5.8...HEAD) -### Changed + +### Added + +- Call filters from anemoi-transform - make test optional when adls is not installed Pull request #110 + ## [0.5.8](https://github.com/ecmwf/anemoi-datasets/compare/0.5.7...0.5.8) - 2024-10-26 ### Changed @@ -34,6 +38,7 @@ Keep it human-readable, your future self will thank you! ### Changed +- Upload with ssh (experimental) - Remove upstream dependencies from downstream-ci workflow (temporary) (#83) - ci: pin python versions to 3.9 ... 3.12 for checks (#93) - Fix `__version__` import in init diff --git a/docs/using/missing.rst b/docs/using/missing.rst index e7c6b9d4..a6728494 100644 --- a/docs/using/missing.rst +++ b/docs/using/missing.rst @@ -4,6 +4,16 @@ Managing missing dates ######################## +********************************************* + Managing missing dates with anemoi-training +********************************************* + +Anemoi-training has internal handling of missing dates, and will +calculate the valid date indices used during training using the +``missing`` property. Consequently, when training a model with +anemoi-training, you should `not` specify a method to deal with missing +dates in the dataloader configuration file. + ************************************************** Filling the missing dates with artificial values ************************************************** diff --git a/pyproject.toml b/pyproject.toml index 60c8a56a..be3611f6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,14 +1,12 @@ -#!/usr/bin/env python -# (C) Copyright 2024 ECMWF. +# (C) Copyright 2024 Anemoi contributors. # # This software is licensed under the terms of the Apache Licence Version 2.0 # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# # In applying this licence, ECMWF does not waive the privileges and immunities # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction.
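The `docs/using/missing.rst` hunk above refers to the dataset's ``missing`` property. A minimal sketch of what that property exposes (the dataset path is a placeholder, not from this changeset):

```python
# Minimal sketch, assuming a local dataset at a placeholder path.
from anemoi.datasets import open_dataset

ds = open_dataset("dataset.zarr")

# `missing` is a set of integer indices into `ds.dates` for which no data
# is stored; anemoi-training derives its valid sample indices from it,
# which is why no missing-dates method is needed in the dataloader config.
print(sorted(ds.missing))
```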
-# https://packaging.python.org/en/latest/guides/writing-pyproject-toml/ - [build-system] requires = [ "setuptools>=60", @@ -42,6 +40,7 @@ classifiers = [ "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", "Programming Language :: Python :: Implementation :: CPython", "Programming Language :: Python :: Implementation :: PyPy", ] diff --git a/src/anemoi/datasets/commands/copy.py b/src/anemoi/datasets/commands/copy.py index 1ca9aef8..9d66b2fc 100644 --- a/src/anemoi/datasets/commands/copy.py +++ b/src/anemoi/datasets/commands/copy.py @@ -10,14 +10,13 @@ import logging import os -import shutil import sys from concurrent.futures import ThreadPoolExecutor from concurrent.futures import as_completed import tqdm -from anemoi.utils.s3 import download -from anemoi.utils.s3 import upload +from anemoi.utils.remote import Transfer +from anemoi.utils.remote import TransferMethodNotImplementedError from . import Command @@ -29,54 +28,7 @@ isatty = False -class S3Downloader: - def __init__(self, source, target, transfers, overwrite, resume, verbosity, **kwargs): - self.source = source - self.target = target - self.transfers = transfers - self.overwrite = overwrite - self.resume = resume - self.verbosity = verbosity - - def run(self): - if self.target == ".": - self.target = os.path.basename(self.source) - - if self.overwrite and os.path.exists(self.target): - LOG.info(f"Deleting {self.target}") - shutil.rmtree(self.target) - - download( - self.source + "/" if not self.source.endswith("/") else self.source, - self.target, - overwrite=self.overwrite, - resume=self.resume, - verbosity=self.verbosity, - threads=self.transfers, - ) - - -class S3Uploader: - def __init__(self, source, target, transfers, overwrite, resume, verbosity, **kwargs): - self.source = source - self.target = target - self.transfers = transfers - self.overwrite = overwrite - self.resume = resume - self.verbosity = verbosity - - def run(self): - upload( - self.source, - self.target, - overwrite=self.overwrite, - resume=self.resume, - verbosity=self.verbosity, - threads=self.transfers, - ) - - -class DefaultCopier: +class ZarrCopier: def __init__(self, source, target, transfers, block_size, overwrite, resume, verbosity, nested, rechunk, **kwargs): self.source = source self.target = target @@ -90,6 +42,14 @@ def __init__(self, source, target, transfers, block_size, overwrite, resume, ver self.rechunking = rechunk.split(",") if rechunk else [] + source_is_ssh = self.source.startswith("ssh://") + target_is_ssh = self.target.startswith("ssh://") + + if source_is_ssh or target_is_ssh: + if rechunk: + raise NotImplementedError("Rechunking with SSH not implemented.") + raise NotImplementedError("SSH not implemented.") + def _store(self, path, nested=False): if nested: import zarr @@ -337,26 +297,33 @@ def run(self, args): if args.source == args.target: raise ValueError("Source and target are the same.") - kwargs = vars(args) - if args.overwrite and args.resume: raise ValueError("Cannot use --overwrite and --resume together.") - source_in_s3 = args.source.startswith("s3://") - target_in_s3 = args.target.startswith("s3://") - - copier = None - - if args.rechunk or (source_in_s3 and target_in_s3): - copier = DefaultCopier(**kwargs) - else: - if source_in_s3: - copier = S3Downloader(**kwargs) - - if target_in_s3: - copier = S3Uploader(**kwargs) - + if not args.rechunk: + # rechunking is only supported for ZARR datasets, it is
implemented in this package + try: + if args.source.startswith("s3://") and not args.source.endswith("/"): + args.source = args.source + "/" + copier = Transfer( + args.source, + args.target, + overwrite=args.overwrite, + resume=args.resume, + verbosity=args.verbosity, + threads=args.transfers, + ) + copier.run() + return + except TransferMethodNotImplementedError: + # Transfer relies on anemoi-utils, which is agnostic to the source and target format + # it transfers files and folders, ignoring that it is zarr data + # if it is not implemented, we fall back to the ZarrCopier + pass + + copier = ZarrCopier(**vars(args)) copier.run() + return class Copy(CopyMixin, Command): diff --git a/src/anemoi/datasets/create/check.py b/src/anemoi/datasets/create/check.py index e6f5eb14..2672034f 100644 --- a/src/anemoi/datasets/create/check.py +++ b/src/anemoi/datasets/create/check.py @@ -89,7 +89,7 @@ def check_parsed(self): self.messages.append( f"the dataset name {self} does not follow naming convention. " "See here for details: " - "https://confluence.ecmwf.int/display/DWF/Datasets+available+as+zarr" + "https://anemoi-registry.readthedocs.io/en/latest/naming-conventions.html" ) def check_resolution(self, resolution): diff --git a/src/anemoi/datasets/create/config.py b/src/anemoi/datasets/create/config.py index b761d242..79686f32 100644 --- a/src/anemoi/datasets/create/config.py +++ b/src/anemoi/datasets/create/config.py @@ -215,15 +215,14 @@ def _prepare_serialisation(o): def set_to_test_mode(cfg): NUMBER_OF_DATES = 4 - dates = cfg["dates"] LOG.warning(f"Running in test mode. Changing the list of dates to use only {NUMBER_OF_DATES}.") groups = Groups(**LoadersConfig(cfg).dates) - dates = groups.dates + dates = groups.provider.values cfg["dates"] = dict( start=dates[0], end=dates[NUMBER_OF_DATES - 1], - frequency=dates.frequency, + frequency=groups.provider.frequency, group_by=NUMBER_OF_DATES, ) diff --git a/src/anemoi/datasets/create/functions/__init__.py b/src/anemoi/datasets/create/functions/__init__.py index af894893..4230a97b 100644 --- a/src/anemoi/datasets/create/functions/__init__.py +++ b/src/anemoi/datasets/create/functions/__init__.py @@ -21,6 +21,8 @@ def assert_is_fieldlist(obj): def import_function(name, kind): + from anemoi.transform.filters import filter_registry + name = name.replace("-", "_") plugins = {} @@ -30,8 +32,21 @@ if name in plugins: return plugins[name].load() - module = importlib.import_module( - f".{kind}.{name}", - package=__name__, - ) - return module.execute + try: + module = importlib.import_module( + f".{kind}.{name}", + package=__name__, + ) + return module.execute + except ModuleNotFoundError: + pass + + if kind == "filters": + if filter_registry.lookup(name, return_none=True): + + def proc(context, data, *args, **kwargs): + return filter_registry.create(name, *args, **kwargs)(data) + + return proc + + raise ValueError(f"Unknown {kind} '{name}'") diff --git a/src/anemoi/datasets/create/functions/filters/__init__.py b/src/anemoi/datasets/create/functions/filters/__init__.py index 08297687..684f0bee 100644 --- a/src/anemoi/datasets/create/functions/filters/__init__.py +++ b/src/anemoi/datasets/create/functions/filters/__init__.py @@ -7,9 +7,3 @@ # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. # -# This software is licensed under the terms of the Apache Licence Version 2.0 -# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
-# In applying this licence, ECMWF does not waive the privileges and immunities -# granted to it by virtue of its status as an intergovernmental organisation -# nor does it submit to any jurisdiction. -# diff --git a/src/anemoi/datasets/create/functions/filters/rename.py b/src/anemoi/datasets/create/functions/filters/rename.py index d815286f..331a8c1f 100644 --- a/src/anemoi/datasets/create/functions/filters/rename.py +++ b/src/anemoi/datasets/create/functions/filters/rename.py @@ -25,7 +25,9 @@ class RenamedFieldMapping: def __init__(self, field, what, renaming): self.field = field self.what = what - self.renaming = renaming + self.renaming = {} + for k, v in renaming.items(): + self.renaming[k] = {str(a): str(b) for a, b in v.items()} def metadata(self, key=None, **kwargs): if key is None: diff --git a/src/anemoi/datasets/create/functions/sources/mars.py b/src/anemoi/datasets/create/functions/sources/mars.py index 703c3708..ef72888c 100644 --- a/src/anemoi/datasets/create/functions/sources/mars.py +++ b/src/anemoi/datasets/create/functions/sources/mars.py @@ -8,6 +8,7 @@ # nor does it submit to any jurisdiction. import datetime +import re from anemoi.utils.humanize import did_you_mean from earthkit.data import from_source @@ -32,6 +33,25 @@ def _date_to_datetime(d): return datetime.datetime.fromisoformat(d) +def expand_to_by(x): + # expand MARS-style step ranges, e.g. "0/to/18/by/6" -> [0, 6, 12, 18] + + if isinstance(x, (str, int)): + return expand_to_by(str(x).split("/")) + + if len(x) == 3 and x[1] == "to": + start = int(x[0]) + end = int(x[2]) + return list(range(start, end + 1)) + + if len(x) == 5 and x[1] == "to" and x[3] == "by": + start = int(x[0]) + end = int(x[2]) + by = int(x[4]) + return list(range(start, end + 1, by)) + + return x + + def normalise_time_delta(t): if isinstance(t, datetime.timedelta): assert t == datetime.timedelta(hours=t.hours), t @@ -43,25 +63,48 @@ return t +def _normalise_time(t): + # normalise times to HHMM, e.g. 6 -> "0600", 1800 -> "1800" + t = int(t) + if t < 100: + t = t * 100 + return "{:04d}".format(t) + + def _expand_mars_request(request, date, request_already_using_valid_datetime=False, date_key="date"): requests = [] - step = to_list(request.get("step", [0])) - for s in step: + + user_step = to_list(expand_to_by(request.get("step", [0]))) + user_time = None + user_date = None + + if not request_already_using_valid_datetime: + user_time = request.get("time") + if user_time is not None: + user_time = to_list(user_time) + user_time = [_normalise_time(t) for t in user_time] + + user_date = request.get(date_key) + if user_date is not None: + assert isinstance(user_date, str), user_date + user_date = re.compile("^{}$".format(user_date.replace("-", "").replace("?", "."))) + + for step in user_step: r = request.copy() if not request_already_using_valid_datetime: - if isinstance(s, str) and "-" in s: - assert s.count("-") == 1, s + if isinstance(step, str) and "-" in step: + assert step.count("-") == 1, step + # this takes care of the cases where the step is a period such as 0-24 or 12-24 - hours = int(str(s).split("-")[-1]) + hours = int(str(step).split("-")[-1]) base = date - datetime.timedelta(hours=hours) r.update( { date_key: base.strftime("%Y%m%d"), "time": base.strftime("%H%M"), - "step": s, + "step": step, } ) @@ -70,12 +113,28 @@ if isinstance(r[pproc], (list, tuple)): r[pproc] = "/".join(str(x) for x in r[pproc]) + if user_date is not None: + if not user_date.match(r[date_key]): + continue + + if user_time is not None: + # If the time is provided by the user, we only keep the
requests that match the time + if r["time"] not in user_time: + continue + requests.append(r) + # assert requests, requests + return requests -def factorise_requests(dates, *requests, request_already_using_valid_datetime=False, date_key="date"): +def factorise_requests( + dates, + *requests, + request_already_using_valid_datetime=False, + date_key="date", +): updates = [] for req in requests: # req = normalise_request(req) @@ -88,6 +147,9 @@ def factorise_requests(dates, *requests, request_already_using_valid_datetime=Fa date_key=date_key, ) + if not updates: + return + compressed = Availability(updates) for r in compressed.iterate(): for k, v in r.items(): @@ -178,7 +240,15 @@ def use_grib_paramid(r): ] -def mars(context, dates, *requests, request_already_using_valid_datetime=False, date_key="date", **kwargs): +def mars( + context, + dates, + *requests, + request_already_using_valid_datetime=False, + date_key="date", + **kwargs, +): + if not requests: requests = [kwargs] diff --git a/src/anemoi/datasets/create/statistics/__init__.py b/src/anemoi/datasets/create/statistics/__init__.py index 88016dad..c3d3939e 100644 --- a/src/anemoi/datasets/create/statistics/__init__.py +++ b/src/anemoi/datasets/create/statistics/__init__.py @@ -6,6 +6,7 @@ # In applying this licence, ECMWF does not waive the privileges and immunities # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. + import datetime import glob import hashlib diff --git a/src/anemoi/datasets/data/__init__.py b/src/anemoi/datasets/data/__init__.py index 8c0503ce..1b2a26c6 100644 --- a/src/anemoi/datasets/data/__init__.py +++ b/src/anemoi/datasets/data/__init__.py @@ -49,6 +49,7 @@ def _convert(x): def open_dataset(*args, **kwargs): # That will get rid of OmegaConf objects + args, kwargs = _convert(args), _convert(kwargs) ds = _open_dataset(*args, **kwargs) diff --git a/src/anemoi/datasets/data/dataset.py b/src/anemoi/datasets/data/dataset.py index f9dda392..8ab06757 100644 --- a/src/anemoi/datasets/data/dataset.py +++ b/src/anemoi/datasets/data/dataset.py @@ -11,7 +11,6 @@ import datetime import json import logging -import os import pprint import warnings from functools import cached_property @@ -28,8 +27,6 @@ def _tidy(v): return [_tidy(i) for i in v] if isinstance(v, dict): return {k: _tidy(v) for k, v in v.items()} - if isinstance(v, str) and v.startswith("/"): - return os.path.basename(v) if isinstance(v, datetime.datetime): return v.isoformat() if isinstance(v, datetime.date): @@ -391,7 +388,7 @@ def _supporting_arrays_and_sources(self): # Arrays from the input sources for i, source in enumerate(self._input_sources()): - name = source.name if source.name is not None else i + name = source.name if source.name is not None else f"source{i}" src_arrays = source._supporting_arrays(name) source_to_arrays[id(source)] = sorted(src_arrays.keys()) diff --git a/src/anemoi/datasets/data/grids.py b/src/anemoi/datasets/data/grids.py index e8859cdb..15915f13 100644 --- a/src/anemoi/datasets/data/grids.py +++ b/src/anemoi/datasets/data/grids.py @@ -12,6 +12,7 @@ from functools import cached_property import numpy as np +from scipy.spatial import cKDTree from .debug import Node from .debug import debug_indexing @@ -142,95 +143,250 @@ def tree(self): class Cutout(GridsBase): - def __init__(self, datasets, axis, min_distance_km=None, cropping_distance=2.0, neighbours=5, plot=False): - from anemoi.datasets.grids import cutout_mask - + def __init__(self, datasets, axis=3, 
cropping_distance=2.0, neighbours=5, min_distance_km=None, plot=None): + """Initializes a Cutout object for hierarchical management of Limited Area + Models (LAMs) and a global dataset, handling overlapping regions. + + Args: + datasets (list): List of LAM and global datasets. + axis (int): Concatenation axis, must be set to 3. + cropping_distance (float): Distance threshold in degrees for + cropping cutouts. + neighbours (int): Number of neighboring points to consider when + constructing masks. + min_distance_km (float, optional): Minimum distance threshold in km + between grid points. + plot (bool, optional): Flag to enable or disable visualization + plots. + """ super().__init__(datasets, axis) - assert len(datasets) == 2, "CutoutGrids requires two datasets" + assert len(datasets) >= 2, "CutoutGrids requires at least two datasets" assert axis == 3, "CutoutGrids requires axis=3" + assert cropping_distance >= 0, "cropping_distance must be a non-negative number" + if min_distance_km is not None: + assert min_distance_km >= 0, "min_distance_km must be a non-negative number" + + self.lams = datasets[:-1] # Assume the last dataset is the global one + self.globe = datasets[-1] + self.axis = axis + self.cropping_distance = cropping_distance + self.neighbours = neighbours + self.min_distance_km = min_distance_km + self.plot = plot + self.masks = [] # To store the masks for each LAM dataset + self.global_mask = np.ones(self.globe.shape[-1], dtype=bool) + + # Initialize cumulative masks + self._initialize_masks() + + def _initialize_masks(self): + """Generates hierarchical masks for each LAM dataset by excluding + overlapping regions with previous LAMs and creating a global mask for + the global dataset. + + Raises: + ValueError: If the global mask dimension does not match the global + dataset grid points. 
+ """ + from anemoi.datasets.grids import cutout_mask - # We assume that the LAM is the first dataset, and the global is the second - # Note: the second fields does not really need to be global - - self.lam, self.globe = datasets - self.mask = cutout_mask( - self.lam.latitudes, - self.lam.longitudes, - self.globe.latitudes, - self.globe.longitudes, - plot=plot, - min_distance_km=min_distance_km, - cropping_distance=cropping_distance, - neighbours=neighbours, - ) - assert len(self.mask) == self.globe.shape[3], ( - len(self.mask), - self.globe.shape[3], - ) + for i, lam in enumerate(self.lams): + assert len(lam.shape) == len( + self.globe.shape + ), "LAMs and global dataset must have the same number of dimensions" + lam_lats = lam.latitudes + lam_lons = lam.longitudes + # Create a mask for the global dataset excluding all LAM points + global_overlap_mask = cutout_mask( + lam.latitudes, + lam.longitudes, + self.globe.latitudes, + self.globe.longitudes, + plot=False, + min_distance_km=self.min_distance_km, + cropping_distance=self.cropping_distance, + neighbours=self.neighbours, + ) + + # Ensure the mask dimensions match the global grid points + if global_overlap_mask.shape[0] != self.globe.shape[-1]: + raise ValueError("Global mask dimension does not match global dataset grid " "points.") + self.global_mask[~global_overlap_mask] = False + + # Create a mask for the LAM datasets hierarchically, excluding + # points from previous LAMs + lam_current_mask = np.ones(lam.shape[-1], dtype=bool) + if i > 0: + for j in range(i): + prev_lam = self.lams[j] + prev_lam_lats = prev_lam.latitudes + prev_lam_lons = prev_lam.longitudes + # Check for overlap by computing distances + if self.has_overlap(prev_lam_lats, prev_lam_lons, lam_lats, lam_lons): + lam_overlap_mask = cutout_mask( + prev_lam_lats, + prev_lam_lons, + lam_lats, + lam_lons, + plot=False, + min_distance_km=self.min_distance_km, + cropping_distance=self.cropping_distance, + neighbours=self.neighbours, + ) + lam_current_mask[~lam_overlap_mask] = False + self.masks.append(lam_current_mask) + + def has_overlap(self, lats1, lons1, lats2, lons2, distance_threshold=1.0): + """Checks for overlapping points between two sets of latitudes and + longitudes within a specified distance threshold. + + Args: + lats1, lons1 (np.ndarray): Latitude and longitude arrays for the + first dataset. + lats2, lons2 (np.ndarray): Latitude and longitude arrays for the + second dataset. + distance_threshold (float): Distance in degrees to consider as + overlapping. + + Returns: + bool: True if any points overlap within the distance threshold, + otherwise False. + """ + # Create KDTree for the first set of points + tree = cKDTree(np.vstack((lats1, lons1)).T) + + # Query the second set of points against the first tree + distances, _ = tree.query(np.vstack((lats2, lons2)).T, k=1) + + # Check if any distance is less than the specified threshold + return np.any(distances < distance_threshold) + + def __getitem__(self, index): + """Retrieves data from the masked LAMs and global dataset based on the + given index. + + Args: + index (int or slice or tuple): Index specifying the data to + retrieve. + + Returns: + np.ndarray: Data array from the masked datasets based on the index. + """ + if isinstance(index, (int, slice)): + index = (index, slice(None), slice(None), slice(None)) + return self._get_tuple(index) + + def _get_tuple(self, index): + """Helper method that applies masks and retrieves data from each dataset + according to the specified index. 
+ + Args: + index (tuple): Index specifying slices to retrieve data. + + Returns: + np.ndarray: Concatenated data array from all datasets based on the + index. + """ + index, changes = index_to_slices(index, self.shape) + # Select data from each LAM + lam_data = [lam[index] for lam in self.lams] + + # First apply spatial indexing on `self.globe` and then apply the mask + globe_data_sliced = self.globe[index[:3]] + globe_data = globe_data_sliced[..., self.global_mask] + + # Concatenate LAM data with global data + result = np.concatenate(lam_data + [globe_data], axis=self.axis) + return apply_index_to_slices_changes(result, changes) def collect_supporting_arrays(self, collected, *path): - collected.append((path, "cutout_mask", self.mask)) + """Collects supporting arrays, including masks for each LAM and the global + dataset. + + Args: + collected (list): List to which the supporting arrays are appended. + *path: Variable length argument list specifying the paths for the masks. + """ + # Append masks for each LAM + for i, (lam, mask) in enumerate(zip(self.lams, self.masks)): + collected.append((path + (f"lam_{i}",), "cutout_mask", mask)) + + # Append the global mask + collected.append((path + ("global",), "cutout_mask", self.global_mask)) @cached_property def shape(self): - shape = self.lam.shape - # Number of non-zero masked values in the globe dataset - nb_globe = np.count_nonzero(self.mask) - return shape[:-1] + (shape[-1] + nb_globe,) + """Returns the shape of the Cutout, accounting for retained grid points + across all LAMs and the global dataset. + + Returns: + tuple: Shape of the concatenated masked datasets. + """ + shapes = [np.sum(mask) for mask in self.masks] + global_shape = np.sum(self.global_mask) + return tuple(self.lams[0].shape[:-1] + (sum(shapes) + global_shape,)) def check_same_resolution(self, d1, d2): # Turned off because we are combining different resolutions pass @property - def latitudes(self): - return np.concatenate([self.lam.latitudes, self.globe.latitudes[self.mask]]) + def grids(self): + """Returns the number of grid points for each LAM and the global dataset + after applying masks. - @property - def longitudes(self): - return np.concatenate([self.lam.longitudes, self.globe.longitudes[self.mask]]) + Returns: + tuple: Count of retained grid points for each dataset. + """ + grids = [np.sum(mask) for mask in self.masks] + grids.append(np.sum(self.global_mask)) + return tuple(grids) - def __getitem__(self, index): - if isinstance(index, (int, slice)): - index = (index, slice(None), slice(None), slice(None)) - return self._get_tuple(index) + @property + def latitudes(self): + """Returns the concatenated latitudes of each LAM and the global dataset + after applying masks. - @debug_indexing - @expand_list_indexing - def _get_tuple(self, index): - assert self.axis >= len(index) or index[self.axis] == slice( - None - ), f"No support for selecting a subset of the 1D values {index} ({self.tree()})" - index, changes = index_to_slices(index, self.shape) + Returns: + np.ndarray: Concatenated latitude array for the masked datasets. 
+ """ + lam_latitudes = np.concatenate([lam.latitudes[mask] for lam, mask in zip(self.lams, self.masks)]) - # In case index_to_slices has changed the last slice - index, _ = update_tuple(index, self.axis, slice(None)) + assert ( + len(lam_latitudes) + len(self.globe.latitudes[self.global_mask]) == self.shape[-1] + ), "Mismatch in number of latitudes" - lam_data = self.lam[index] - globe_data = self.globe[index] + latitudes = np.concatenate([lam_latitudes, self.globe.latitudes[self.global_mask]]) + return latitudes - globe_data = globe_data[:, :, :, self.mask] + @property + def longitudes(self): + """Returns the concatenated longitudes of each LAM and the global dataset + after applying masks. - result = np.concatenate([lam_data, globe_data], axis=self.axis) + Returns: + np.ndarray: Concatenated longitude array for the masked datasets. + """ + lam_longitudes = np.concatenate([lam.longitudes[mask] for lam, mask in zip(self.lams, self.masks)]) - return apply_index_to_slices_changes(result, changes) + assert ( + len(lam_longitudes) + len(self.globe.longitudes[self.global_mask]) == self.shape[-1] + ), "Mismatch in number of longitudes" - @property - def grids(self): - for d in self.datasets: - if len(d.grids) > 1: - raise NotImplementedError("CutoutGrids does not support multi-grids datasets as inputs") - shape = self.lam.shape - return (shape[-1], self.shape[-1] - shape[-1]) + longitudes = np.concatenate([lam_longitudes, self.globe.longitudes[self.global_mask]]) + return longitudes def tree(self): + """Generates a hierarchical tree structure for the `Cutout` instance and + its associated datasets. + + Returns: + Node: A `Node` object representing the `Cutout` instance as the root + node, with each dataset in `self.datasets` represented as a child + node. + """ return Node(self, [d.tree() for d in self.datasets]) - # def metadata_specific(self): - # return super().metadata_specific( - # mask=serialise_mask(self.mask), - # ) - def grids_factory(args, kwargs): if "ensemble" in kwargs: diff --git a/src/anemoi/datasets/data/merge.py b/src/anemoi/datasets/data/merge.py index 31c08b27..6921c2be 100644 --- a/src/anemoi/datasets/data/merge.py +++ b/src/anemoi/datasets/data/merge.py @@ -28,21 +28,41 @@ class Merge(Combined): + + # d0 d2 d4 d6 ... + # d1 d3 d5 d7 ... + + # gives + # d0 d1 d2 d3 ... 
+ def __init__(self, datasets, allow_gaps_in_dates=False): super().__init__(datasets) self.allow_gaps_in_dates = allow_gaps_in_dates - dates = dict() + dates = dict() # date -> (dataset_index, date_index) for i, d in enumerate(datasets): for j, date in enumerate(d.dates): date = date.astype(object) if date in dates: - d1 = datasets[dates[date][0]] - d2 = datasets[i] + + d1 = datasets[dates[date][0]] # Selected + d2 = datasets[i] # The new one + + if j in d2.missing: + # LOG.warning(f"Duplicate date {date} found in datasets {d1} and {d2}, but {date} is missing in {d}, ignoring") + continue + + k = dates[date][1] + if k in d1.missing: + # LOG.warning(f"Duplicate date {date} found in datasets {d1} and {d2}, but {date} is missing in {d}, ignoring") + dates[date] = (i, j) # Replace the missing date with the new one + continue + raise ValueError(f"Duplicate date {date} found in datasets {d1} and {d2}") - dates[date] = (i, j) + else: + dates[date] = (i, j) all_dates = sorted(dates) start = all_dates[0] @@ -71,7 +91,10 @@ def __init__(self, datasets, allow_gaps_in_dates=False): self._dates = np.array(_dates, dtype="datetime64[s]") self._indices = np.array(indices) - self._frequency = frequency.astype(object) + self._frequency = frequency # .astype(object) + + def __len__(self): + return len(self._dates) @property def dates(self): diff --git a/src/anemoi/datasets/data/stores.py b/src/anemoi/datasets/data/stores.py index 19246bc0..c8340e6c 100644 --- a/src/anemoi/datasets/data/stores.py +++ b/src/anemoi/datasets/data/stores.py @@ -71,7 +71,7 @@ class S3Store(ReadOnlyStore): """ def __init__(self, url, region=None): - from anemoi.utils.s3 import s3_client + from anemoi.utils.remote.s3 import s3_client _, _, self.bucket, self.key = url.split("/", 3) self.s3 = s3_client(self.bucket, region=region) @@ -364,6 +364,7 @@ def metadata_specific(self): attrs=dict(self.z.attrs), chunks=self.chunks, dtype=str(self.dtype), + path=self.path, ) def source(self, index): @@ -396,7 +397,7 @@ def __init__(self, path): super().__init__(path) missing_dates = self.z.attrs.get("missing_dates", []) - missing_dates = set([np.datetime64(x) for x in missing_dates]) + missing_dates = set([np.datetime64(x, "s") for x in missing_dates]) self.missing_to_dates = {i: d for i, d in enumerate(self.dates) if d in missing_dates} self.missing = set(self.missing_to_dates) diff --git a/src/anemoi/datasets/dates/__init__.py b/src/anemoi/datasets/dates/__init__.py index e2ef1d6b..676c5802 100644 --- a/src/anemoi/datasets/dates/__init__.py +++ b/src/anemoi/datasets/dates/__init__.py @@ -10,6 +10,8 @@ import datetime import warnings +from functools import reduce +from math import gcd # from anemoi.utils.dates import as_datetime from anemoi.utils.dates import DateTimes @@ -195,18 +197,16 @@ def __init__(self, start, end, steps=[0], years=20, **kwargs): dates = sorted(dates) - mindelta = None + deltas = set() for a, b in zip(dates, dates[1:]): delta = b - a assert isinstance(delta, datetime.timedelta), delta - if mindelta is None: - mindelta = delta - else: - mindelta = min(mindelta, delta) + deltas.add(delta) + mindelta_seconds = reduce(gcd, [int(delta.total_seconds()) for delta in deltas]) + mindelta = datetime.timedelta(seconds=mindelta_seconds) self.frequency = mindelta assert mindelta.total_seconds() > 0, mindelta - print("🔥🔥🔥🔥🔥🔥🔥🔥🔥🔥🔥🔥", dates[0], dates[-1], mindelta) # Use all values between start and end by frequency, and set the ones that are missing diff --git a/tests/create/test_create.py b/tests/create/test_create.py index 
6e3f381b..ab612dbf 100755 --- a/tests/create/test_create.py +++ b/tests/create/test_create.py @@ -257,8 +257,9 @@ def compare(self): def test_run(name): config = os.path.join(HERE, name + ".yaml") output = os.path.join(HERE, name + ".zarr") + is_test = False - creator_factory("init", config=config, path=output, overwrite=True).run() + creator_factory("init", config=config, path=output, overwrite=True, test=is_test).run() creator_factory("load", path=output).run() creator_factory("finalise", path=output).run() creator_factory("patch", path=output).run() @@ -272,8 +273,8 @@ def test_run(name): # reference_path = os.path.join(HERE, name + "-reference.zarr") s3_uri = TEST_DATA_ROOT + "/" + name + ".zarr" # if not os.path.exists(reference_path): - # from anemoi.utils.s3 import download as s3_download - # s3_download(s3_uri + '/', reference_path, overwrite=True) + # from anemoi.utils.remote import transfer + # transfer(s3_uri + '/', reference_path, overwrite=True) Comparer(name, output_path=output, reference_path=s3_uri).compare() # Comparer(name, output_path=output, reference_path=reference_path).compare() diff --git a/tools/grids/grids3.yaml b/tools/grids/grids3.yaml new file mode 100644 index 00000000..75f91961 --- /dev/null +++ b/tools/grids/grids3.yaml @@ -0,0 +1,42 @@ +common: + mars_request: &mars_request + expver: "0001" + grid: 0.25/0.25 + area: [40, 25, 20, 60] + rotation: [-20, -40] + +dates: + start: 2024-01-01 00:00:00 + end: 2024-01-01 18:00:00 + frequency: 6h + +input: + join: + - mars: + <<: *mars_request + param: [2t, 10u, 10v, lsm] + levtype: sfc + stream: oper + type: an + - mars: + <<: *mars_request + param: [q, t, z] + levtype: pl + level: [50, 100] + stream: oper + type: an + - accumulations: + <<: *mars_request + levtype: sfc + param: [cp, tp] + - forcings: + template: ${input.join.0.mars} + param: + - cos_latitude + - sin_latitude + +output: + order_by: [valid_datetime, param_level, number] + remapping: + param_level: "{param}_{levelist}" + statistics: param_level diff --git a/tools/grids/grids4.yaml b/tools/grids/grids4.yaml new file mode 100644 index 00000000..39b72706 --- /dev/null +++ b/tools/grids/grids4.yaml @@ -0,0 +1,41 @@ +common: + mars_request: &mars_request + expver: "0001" + grid: 0.5/0.5 + area: [30, 90, 10, 120] + +dates: + start: 2024-01-01 00:00:00 + end: 2024-01-01 18:00:00 + frequency: 6h + +input: + join: + - mars: + <<: *mars_request + param: [2t, 10u, 10v, lsm] + levtype: sfc + stream: oper + type: an + - mars: + <<: *mars_request + param: [q, t, z] + levtype: pl + level: [50, 100] + stream: oper + type: an + - accumulations: + <<: *mars_request + levtype: sfc + param: [cp, tp] + - forcings: + template: ${input.join.0.mars} + param: + - cos_latitude + - sin_latitude + +output: + order_by: [valid_datetime, param_level, number] + remapping: + param_level: "{param}_{levelist}" + statistics: param_level diff --git a/tools/grids/grids5.yaml b/tools/grids/grids5.yaml new file mode 100644 index 00000000..42aab132 --- /dev/null +++ b/tools/grids/grids5.yaml @@ -0,0 +1,41 @@ +common: + mars_request: &mars_request + expver: "0001" + grid: 0.2/0.2 + area: [25, 100, 20, 105] + +dates: + start: 2024-01-01 00:00:00 + end: 2024-01-01 18:00:00 + frequency: 6h + +input: + join: + - mars: + <<: *mars_request + param: [2t, 10u, 10v, lsm] + levtype: sfc + stream: oper + type: an + - mars: + <<: *mars_request + param: [q, t, z] + levtype: pl + level: [50, 100] + stream: oper + type: an + - accumulations: + <<: *mars_request + levtype: sfc + param: [cp, tp] + - forcings: 
+ template: ${input.join.0.mars} + param: + - cos_latitude + - sin_latitude + +output: + order_by: [valid_datetime, param_level, number] + remapping: + param_level: "{param}_{levelist}" + statistics: param_level diff --git a/tools/grids/grids6.yaml b/tools/grids/grids6.yaml new file mode 100644 index 00000000..641618fd --- /dev/null +++ b/tools/grids/grids6.yaml @@ -0,0 +1,41 @@ +common: + mars_request: &mars_request + expver: "0001" + grid: 10/10 + area: [90, -40, -40, 180] + +dates: + start: 2024-01-01 00:00:00 + end: 2024-01-01 18:00:00 + frequency: 6h + +input: + join: + - mars: + <<: *mars_request + param: [2t, 10u, 10v, lsm] + levtype: sfc + stream: oper + type: an + - mars: + <<: *mars_request + param: [q, t, z] + levtype: pl + level: [50, 100] + stream: oper + type: an + - accumulations: + <<: *mars_request + levtype: sfc + param: [cp, tp] + - forcings: + template: ${input.join.0.mars} + param: + - cos_latitude + - sin_latitude + +output: + order_by: [valid_datetime, param_level, number] + remapping: + param_level: "{param}_{levelist}" + statistics: param_level diff --git a/tools/grids/grids7.yaml b/tools/grids/grids7.yaml new file mode 100644 index 00000000..b400a1a8 --- /dev/null +++ b/tools/grids/grids7.yaml @@ -0,0 +1,41 @@ +common: + mars_request: &mars_request + expver: "0001" + grid: 2/2 + area: [90, -40, -40, 180] + +dates: + start: 2024-01-01 00:00:00 + end: 2024-01-01 18:00:00 + frequency: 6h + +input: + join: + - mars: + <<: *mars_request + param: [2t, 10u, 10v, lsm] + levtype: sfc + stream: oper + type: an + - mars: + <<: *mars_request + param: [q, t, z] + levtype: pl + level: [50, 100] + stream: oper + type: an + - accumulations: + <<: *mars_request + levtype: sfc + param: [cp, tp] + - forcings: + template: ${input.join.0.mars} + param: + - cos_latitude + - sin_latitude + +output: + order_by: [valid_datetime, param_level, number] + remapping: + param_level: "{param}_{levelist}" + statistics: param_level diff --git a/tools/grids/grids_multilam.ipynb b/tools/grids/grids_multilam.ipynb new file mode 100644 index 00000000..bb212bc4 --- /dev/null +++ b/tools/grids/grids_multilam.ipynb @@ -0,0 +1,1032 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "from anemoi.datasets import open_dataset\n", + "from anemoi.datasets.data.grids import Cutout" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Load the data\n", + "Datasets generated from the grids*.yaml files in tools/grids/" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "data_dir = \"dir_with_your_zarr_data\"\n", + "f_global = data_dir + \"/grids1.zarr\"\n", + "f_lam1 = data_dir + \"/grids2.zarr\"\n", + "f_lam2 = data_dir + \"/grids3.zarr\"\n", + "f_lam3 = data_dir + \"/grids4.zarr\"\n", + "f_lam4 = data_dir + \"/grids5.zarr\"\n", + "f_lam5 = data_dir + \"/grids6.zarr\"\n", + "f_lam6 = data_dir + \"/grids7.zarr\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "global_dataset = open_dataset(f_global)\n", + "lam_dataset_1 = open_dataset(f_lam1)\n", + "lam_dataset_2 = open_dataset(f_lam2)\n", + "lam_dataset_3 = open_dataset(f_lam3)\n", + "lam_dataset_4 = open_dataset(f_lam4)\n", + "lam_dataset_5 = open_dataset(f_lam5)\n", + "lam_dataset_6 = open_dataset(f_lam6)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "jp-MarkdownHeadingCollapsed": true + }, "source": [ + "# Define and run some tests" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def test_cutout_initialization(lam_dataset_1, lam_dataset_2, global_dataset):\n", + " \"\"\"Ensure that the Cutout class correctly initializes with multiple Limited\n", + " Area Models (LAMs) and a global dataset.\"\"\"\n", + " cutout = Cutout(\n", + " [lam_dataset_1, lam_dataset_2, global_dataset],\n", + " axis=3,\n", + " )\n", + "\n", + " assert len(cutout.lams) == 2\n", + " assert cutout.globe is not None\n", + " assert len(cutout.masks) == 2\n", + "\n", + "\n", + "def test_cutout_mask_generation(lam_dataset, global_dataset):\n", + " \"\"\"Ensure that the cutout_mask function correctly generates masks for LAMs\n", + " and excludes overlapping regions.\"\"\"\n", + " cutout = Cutout([lam_dataset, global_dataset], axis=3)\n", + " mask = cutout.masks[0]\n", + " lam = cutout.lams[0]\n", + "\n", + " assert mask is not None\n", + " assert isinstance(mask, np.ndarray)\n", + " assert isinstance(cutout.global_mask, np.ndarray)\n", + " assert mask.shape[-1] == lam.shape[-1]\n", + " assert cutout.global_mask.shape[-1] == global_dataset.shape[-1]\n", + "\n", + "\n", + "def test_cutout_getitem(lam_dataset, global_dataset):\n", + " \"\"\"Verify that the __getitem__ method correctly returns the appropriate\n", + " data when indexing the Cutout object.\"\"\"\n", + " cutout = Cutout([lam_dataset, global_dataset], axis=3)\n", + "\n", + " data = cutout[0, :, :, :]\n", + " expected_shape = cutout.shape[1:]\n", + " assert data is not None\n", + " assert data.shape == expected_shape\n", + "\n", + "\n", + "def test_latitudes_longitudes_concatenation(\n", + " lam_dataset_1, lam_dataset_2, global_dataset\n", + "):\n", + " \"\"\"Ensure that latitudes and longitudes are correctly\n", + " concatenated from all LAMs and the masked global dataset.\"\"\"\n", + " cutout = Cutout([lam_dataset_1, lam_dataset_2, global_dataset], axis=3)\n", + "\n", + " latitudes = cutout.latitudes\n", + " longitudes = cutout.longitudes\n", + "\n", + " assert latitudes is not None\n", + " assert longitudes is not None\n", + " assert len(latitudes) == cutout.shape[-1]\n", + " assert len(longitudes) == cutout.shape[-1]\n", + "\n", + "\n", + "def test_overlapping_lams(lam_dataset_1, lam_dataset_2, global_dataset):\n", + " \"\"\"Confirm that overlapping regions between LAMs and the global dataset are\n", + " correctly handled by the masks.\"\"\"\n", + " # lam_dataset_2 has to overlap with lam_dataset_1\n", + " cutout = Cutout([lam_dataset_1, lam_dataset_2, global_dataset], axis=3)\n", + "\n", + " # Verify that the overlapping region in lam_dataset_2 is excluded\n", + " assert np.count_nonzero(cutout.masks[1] == False) > 0 # noqa: E712\n", + "\n", + "\n", + "def test_open_dataset_cutout(lam_dataset_1, global_dataset):\n", + " \"\"\"Ensure that open_dataset(cutout=[...]) works correctly with the new\n", + " Cutout implementation.\"\"\"\n", + " ds = open_dataset(cutout=[lam_dataset_1, global_dataset])\n", + "\n", + " assert isinstance(ds, Cutout)\n", + " assert len(ds.lams) == 1\n", + " assert ds.globe is not None" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "test_cutout_initialization(lam_dataset_1, lam_dataset_2, global_dataset)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "test_cutout_mask_generation(lam_dataset_1, global_dataset)" + ] + }, + {
"cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "test_cutout_getitem(lam_dataset_1, global_dataset)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "test_latitudes_longitudes_concatenation(lam_dataset_1, lam_dataset_2, global_dataset)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "test_overlapping_lams(lam_dataset_1, lam_dataset_2, global_dataset)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "test_open_dataset_cutout(lam_dataset_1, global_dataset)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "jp-MarkdownHeadingCollapsed": true + }, + "source": [ + "# Plot function" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def plot_grid(\n", + " ds,\n", + " path,\n", + " s=0.1,\n", + " c=\"r\",\n", + " grids=None,\n", + " point=None,\n", + " central_latitude=-20.0,\n", + " central_longitude=165.0,\n", + "):\n", + " import matplotlib.pyplot as plt\n", + " import cartopy.crs as ccrs\n", + " import numpy as np\n", + "\n", + " lats, lons = ds.latitudes, ds.longitudes\n", + "\n", + " fig = plt.figure(figsize=(9, 9))\n", + " proj = ccrs.NearsidePerspective(\n", + " central_latitude=central_latitude,\n", + " central_longitude=central_longitude,\n", + " satellite_height=4e6,\n", + " )\n", + "\n", + " ax = plt.axes(projection=proj)\n", + "\n", + " def fill():\n", + " # Make sure we have a full globe\n", + " lons, lats = np.meshgrid(np.arange(-180, 180, 1), np.arange(-90, 90, 1))\n", + " x, y, _ = proj.transform_points(\n", + " ccrs.PlateCarree(), lons.flatten(), lats.flatten()\n", + " ).T\n", + "\n", + " mask = np.invert(np.logical_or(np.isinf(x), np.isinf(y)))\n", + " x = np.compress(mask, x)\n", + " y = np.compress(mask, y)\n", + "\n", + " # ax.tricontourf(x, y, values)\n", + " ax.scatter(x, y, s=0, c=\"w\")\n", + "\n", + " fill()\n", + "\n", + " def plot(what, s, c):\n", + " x, y, _ = proj.transform_points(ccrs.PlateCarree(), lons[what], lats[what]).T\n", + "\n", + " mask = np.invert(np.logical_or(np.isinf(x), np.isinf(y)))\n", + " x = np.compress(mask, x)\n", + " y = np.compress(mask, y)\n", + "\n", + " # ax.tricontourf(x, y, values)\n", + " ax.scatter(x, y, s=s, c=c)\n", + "\n", + " if grids:\n", + " # print('s: ', s)\n", + " a = 0\n", + " for i, b in enumerate(grids):\n", + " if s[i] is not None:\n", + " plot(slice(a, a + b), s[i], c[i])\n", + " a += b\n", + " else:\n", + " plot(..., s, c)\n", + "\n", + " if point:\n", + " point = np.array(point, dtype=np.float64)\n", + " x, y, _ = proj.transform_points(ccrs.PlateCarree(), point[1], point[0]).T\n", + " ax.scatter(x, y, s=100, c=\"k\")\n", + "\n", + " ax.coastlines()\n", + "\n", + " if isinstance(path, str):\n", + " fig.savefig(path, bbox_inches=\"tight\")\n", + " else:\n", + " for p in path:\n", + " fig.savefig(p, bbox_inches=\"tight\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "jp-MarkdownHeadingCollapsed": true + }, + "source": [ + "# 1) Plot the datasets separately" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plot_grid(\n", + " global_dataset, \"global_grids1.png\", central_latitude=20.0, central_longitude=75.0\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + 
"plot_grid(\n", + " lam_dataset_1, \"lam1_grids2.png\", central_latitude=60.0, central_longitude=15.0\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plot_grid(\n", + " lam_dataset_2, \"lam1_grids3.png\", central_latitude=50.0, central_longitude=75.0\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plot_grid(\n", + " lam_dataset_3, \"lam1_grids4.png\", central_latitude=20.0, central_longitude=105.0\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plot_grid(\n", + " lam_dataset_4, \"lam1_grids5.png\", central_latitude=20.0, central_longitude=105.0\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plot_grid(\n", + " lam_dataset_5, \"lam5_grids6.png\", central_latitude=-20.0, central_longitude=165.0\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plot_grid(\n", + " lam_dataset_6, \"lam6_grids7.png\", central_latitude=-20.0, central_longitude=165.0\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "jp-MarkdownHeadingCollapsed": true + }, + "source": [ + "# 2) Test cutout with one LAM" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "ds = open_dataset(cutout=[lam_dataset_1, global_dataset])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plot_grid(\n", + " ds,\n", + " \"cutout_global_lam1.png\",\n", + " s=[0.5, 0.5],\n", + " grids=ds.grids,\n", + " c=[\"g\", \"r\"],\n", + " central_latitude=50.0,\n", + " central_longitude=15.0,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# 3) a) Test two overlapping LAMs\n", + "The LAMs have different resolution and are rotated" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "ds = open_dataset(cutout=[lam_dataset_2, lam_dataset_1, global_dataset])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "ds.grids" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plot_grid(\n", + " ds,\n", + " \"cutout_global_lam1_lam2.png\",\n", + " s=[0.1, 0.1, 0.1],\n", + " grids=ds.grids,\n", + " c=[\"g\", \"r\", \"b\"],\n", + " central_latitude=50.0,\n", + " central_longitude=65.0,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "jp-MarkdownHeadingCollapsed": true + }, + "source": [ + "## 3) b) The same LAMs but in a different order" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "ds = open_dataset(cutout=[lam_dataset_1, lam_dataset_2, global_dataset])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "ds.grids" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plot_grid(\n", + " ds,\n", + " \"cutout_global_lam2_lam1.png\",\n", + " s=[0.1, 0.1, 0.1],\n", + " grids=ds.grids,\n", + " c=[\"g\", \"r\", \"b\"],\n", + " central_latitude=50.0,\n", + " central_longitude=65.0,\n", + ")" + ] + }, + { + "cell_type": 
"code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "markdown", + "metadata": { + "jp-MarkdownHeadingCollapsed": true + }, + "source": [ + "# 4) Test two LAMS that are not overlapping" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "ds = open_dataset(cutout=[lam_dataset_3, lam_dataset_2, global_dataset])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plot_grid(\n", + " ds,\n", + " \"cutout_global_lam3_lam2.png\",\n", + " s=[0.1, 0.1, 0.1],\n", + " grids=ds.grids,\n", + " c=[\"g\", \"r\", \"b\"],\n", + " central_latitude=40.0,\n", + " central_longitude=95.0,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# 5) Test multiple LAMS \n", + "\n", + "- LAMs with different resolutions\n", + "- Rotated LAMs\n", + "- LAMs with no overlap.\n", + "- LAM contained within other LAM" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "ds = open_dataset(\n", + " cutout=[lam_dataset_4, lam_dataset_3, lam_dataset_2, lam_dataset_1, global_dataset]\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plot_grid(\n", + " ds,\n", + " \"cutout_global_lam4_lam3_lam2_lam1.png\",\n", + " s=[0.1, 0.1, 0.1, 0.1, 0.1],\n", + " grids=ds.grids,\n", + " # c=[\"g\", \"r\", \"b\", \"c\", \"y\", \"k\"],\n", + " c=[\"g\", \"r\", \"b\", \"c\", \"y\"],\n", + " central_latitude=50.0,\n", + " central_longitude=95.0,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "jp-MarkdownHeadingCollapsed": true + }, + "source": [ + "# Test small LAM behind bigger LAM" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "ds = open_dataset(cutout=[lam_dataset_3, lam_dataset_4, global_dataset])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plot_grid(\n", + " ds,\n", + " \"cutout_global_lam3_lam4.png\",\n", + " s=[0.1, 0.1, 0.1],\n", + " grids=ds.grids,\n", + " c=[\"g\", \"r\", \"b\"],\n", + " central_latitude=50.0,\n", + " central_longitude=95.0,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "markdown", + "metadata": { + "jp-MarkdownHeadingCollapsed": true + }, + "source": [ + "# 6 a) Test cutout with a coarser resolution LAM" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Using LAMs with very low resolution can be a challenge, depending on how it compares to the resolution of the global dataset and the other LAMs.\n", + "\n", + "TODO: A future implementation could consider a list of `min_distance_km` and `neighbours`, so that there is value one for each LAM." 
+ ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Using default values" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "ds = open_dataset(cutout=[lam_dataset_5, global_dataset])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plot_grid(\n", + " ds,\n", + " \"cutout_global_lam5.png\",\n", + " s=[0.1, 0.1],\n", + " grids=ds.grids,\n", + " c=[\"g\", \"r\"],\n", + " central_latitude=-30.0,\n", + " central_longitude=165.0,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Has some issues when using default parameters" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plot_grid(\n", + " ds,\n", + " \"cutout_global_lam5.png\",\n", + " s=[0.1, 0.1],\n", + " grids=ds.grids,\n", + " c=[\"g\", \"r\"],\n", + " central_latitude=50.0,\n", + " central_longitude=165.0,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 6) a) i) Test the parameter `min_distance_km`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "ds = open_dataset(cutout=[lam_dataset_5, global_dataset], min_distance_km=600)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plot_grid(\n", + " ds,\n", + " \"cutout_global_lam5.png\",\n", + " s=[0.1, 0.1],\n", + " grids=ds.grids,\n", + " c=[\"g\", \"r\"],\n", + " central_latitude=-30.0,\n", + " central_longitude=165.0,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plot_grid(\n", + " ds,\n", + " \"cutout_global_lam5.png\",\n", + " s=[0.1, 0.1],\n", + " grids=ds.grids,\n", + " c=[\"g\", \"r\"],\n", + " central_latitude=50.0,\n", + " central_longitude=165.0,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 6) a) ii) Test the parameter `neighbours`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "ds = open_dataset(cutout=[lam_dataset_5, global_dataset], neighbours=200)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plot_grid(\n", + " ds,\n", + " \"cutout_global_lam5.png\",\n", + " s=[0.1, 0.1],\n", + " grids=ds.grids,\n", + " c=[\"g\", \"r\"],\n", + " central_latitude=-30.0,\n", + " central_longitude=165.0,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plot_grid(\n", + " ds,\n", + " \"cutout_global_lam5.png\",\n", + " s=[0.1, 0.1],\n", + " grids=ds.grids,\n", + " c=[\"g\", \"r\"],\n", + " central_latitude=50.0,\n", + " central_longitude=165.0,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 6) b) A similar example, where the LAM resolution is not so low" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 6) b) i) Test the parameter `min_distance_km`." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "ds = open_dataset(\n", + " cutout=[\n", + " lam_dataset_4,\n", + " lam_dataset_3,\n", + " lam_dataset_2,\n", + " lam_dataset_1,\n", + " lam_dataset_6,\n", + " global_dataset,\n", + " ],\n", + " min_distance_km=200,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plot_grid(\n", + " ds,\n", + " \"cutout_global_lam4_lam3_lam2_lam1_lam6.png\",\n", + " s=[0.1, 0.1, 0.1, 0.1, 0.1, 0.1],\n", + " grids=ds.grids,\n", + " c=[\"g\", \"r\", \"b\", \"c\", \"y\", \"k\"],\n", + " central_latitude=-30.0,\n", + " central_longitude=165.0,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plot_grid(\n", + " ds,\n", + " \"cutout_global_lam4_lam3_lam2_lam1_lam6.png\",\n", + " s=[0.1, 0.1, 0.1, 0.1, 0.1, 0.1],\n", + " grids=ds.grids,\n", + " c=[\"g\", \"r\", \"b\", \"c\", \"y\", \"k\"],\n", + " central_latitude=50.0,\n", + " central_longitude=95.0,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 6) b) ii) Test the parameter `neighbours`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "ds = open_dataset(\n", + " cutout=[\n", + " lam_dataset_4,\n", + " lam_dataset_3,\n", + " lam_dataset_2,\n", + " lam_dataset_1,\n", + " lam_dataset_6,\n", + " global_dataset,\n", + " ],\n", + " neighbours=10,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plot_grid(\n", + " ds,\n", + " \"cutout_global_lam4_lam3_lam2_lam1_lam6.png\",\n", + " s=[0.1, 0.1, 0.1, 0.1, 0.1, 0.1],\n", + " grids=ds.grids,\n", + " c=[\"g\", \"r\", \"b\", \"c\", \"y\", \"k\"],\n", + " central_latitude=-30.0,\n", + " central_longitude=165.0,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plot_grid(\n", + " ds,\n", + " \"cutout_global_lam4_lam3_lam2_lam1_lam6.png\",\n", + " s=[0.1, 0.1, 0.1, 0.1, 0.1, 0.1],\n", + " grids=ds.grids,\n", + " c=[\"g\", \"r\", \"b\", \"c\", \"y\", \"k\"],\n", + " central_latitude=50.0,\n", + " central_longitude=95.0,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "markdown", + "metadata": { + "jp-MarkdownHeadingCollapsed": true + }, + "source": [ + "# 7) Test thinning with cutout" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "ds = open_dataset(\n", + " cutout=[\n", + " {\"dataset\": lam_dataset_2, \"thinning\": 2},\n", + " {\"dataset\": lam_dataset_1, \"thinning\": 8},\n", + " {\"dataset\": global_dataset, \"thinning\": 2},\n", + " ]\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plot_grid(\n", + " ds,\n", + " \"cutout_thinning2_global_lam2_lam1.png\",\n", + " s=[0.1, 0.1, 0.1],\n", + " grids=ds.grids,\n", + " c=[\"g\", \"r\", \"b\"],\n", + " central_latitude=50.0,\n", + " central_longitude=65.0,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "jp-MarkdownHeadingCollapsed": true + }, + "source": [ + "# 8) Test cropping with cutout" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "ds 
= open_dataset(\n", + " cutout=[\n", + " {\"dataset\": lam_dataset_1, \"area\": (60, 0, 20, 80)},\n", + " {\"dataset\": lam_dataset_2},\n", + " {\"dataset\": global_dataset},\n", + " ]\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plot_grid(\n", + " ds,\n", + " \"cutout_global_lam1cropped_lam2.png\",\n", + " s=[0.1, 0.1, 0.1],\n", + " grids=ds.grids,\n", + " c=[\"g\", \"r\", \"b\"],\n", + " central_latitude=50.0,\n", + " central_longitude=65.0,\n", + ")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.14" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/tools/upload-sample-dataset.py b/tools/upload-sample-dataset.py index 586f7fdf..67d8ebd0 100755 --- a/tools/upload-sample-dataset.py +++ b/tools/upload-sample-dataset.py @@ -20,7 +20,7 @@ import logging import os -from anemoi.utils.s3 import upload +from anemoi.utils.remote import transfer LOG = logging.getLogger(__name__) @@ -38,6 +38,8 @@ target = args.target bucket = args.bucket +assert os.path.exists(source), f"Source {source} does not exist" + if not target.startswith("s3://"): if target.startswith("/"): target = target[1:] @@ -46,5 +48,5 @@ target = os.path.join(bucket, target) LOG.info(f"Uploading {source} to {target}") -upload(source, target, overwrite=args.overwrite) +transfer(source, target, overwrite=args.overwrite) LOG.info("Upload complete")
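For reference, a minimal usage sketch of the `transfer` helper this script now calls (bucket and paths are placeholders; ssh:// endpoints are experimental per this changeset):

```python
# Minimal sketch with placeholder paths. `transfer` accepts local paths
# and s3:// URLs on either end, mirroring the script above.
from anemoi.utils.remote import transfer

transfer(
    "my-dataset.zarr",                         # local source (placeholder)
    "s3://my-bucket/samples/my-dataset.zarr",  # target (placeholder)
    overwrite=True,
)
```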