Skip to content

Commit

Permalink
Merge pull request #9 from MAIF/feature/get_combined_coverage
Browse files Browse the repository at this point in the history
Feature/get combined coverage
  • Loading branch information
ThomasBouche authored Jan 7, 2025
2 parents d2cf40e + b06c988 commit 34fc31f
Show file tree
Hide file tree
Showing 10 changed files with 1,426 additions and 433 deletions.
5 changes: 5 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ repos:
exclude: ^(docs/)
- id: pretty-format-json
args: [--autofix]
exclude_types: [jupyter]
- id: trailing-whitespace
args: [--markdown-linebreak-ext=md]
exclude: ^(docs/)
Expand Down Expand Up @@ -53,3 +54,7 @@ repos:
- id: conventional-pre-commit
stages: [commit-msg]
args: [feat, fix, ci, chore, test, docs]
- repo: https://github.com/kynan/nbstripout
rev: 0.7.1
hooks:
- id: nbstripout
6 changes: 3 additions & 3 deletions src/meteole/_arpege.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,6 @@
"WIND_SPEED_GUST__SPECIFIC_HEIGHT_LEVEL_ABOVE_GROUND",
"WIND_SPEED__SPECIFIC_HEIGHT_LEVEL_ABOVE_GROUND",
"WIND_SPEED__ISOBARIC_SURFACE",
"DOWNWARD_SHORT_WAVE_RADIATION_FLUX__GROUND_OR_WATER_SURFACE",
"SHORT_WAVE_RADIATION_FLUX__GROUND_OR_WATER_SURFACE",
"RELATIVE_HUMIDITY__SPECIFIC_HEIGHT_LEVEL_ABOVE_GROUND",
"RELATIVE_HUMIDITY__ISOBARIC_SURFACE",
"PLANETARY_BOUNDARY_LAYER_HEIGHT__GROUND_OR_WATER_SURFACE",
Expand Down Expand Up @@ -57,13 +55,15 @@
"V_COMPONENT_OF_WIND__POTENTIAL_VORTICITY_SURFACE_1500",
"V_COMPONENT_OF_WIND__POTENTIAL_VORTICITY_SURFACE_2000",
"GEOPOTENTIAL__ISOBARIC_SURFACE",
"TOTAL_CLOUD_COVER__GROUND_OR_WATER_SURFACE",
]

ARPEGE_OTHER_INDICATORS: list[str] = [
"TOTAL_WATER_PRECIPITATION__GROUND_OR_WATER_SURFACE",
"TOTAL_CLOUD_COVER__GROUND_OR_WATER_SURFACE",
"TOTAL_SNOW_PRECIPITATION__GROUND_OR_WATER_SURFACE",
"TOTAL_PRECIPITATION__GROUND_OR_WATER_SURFACE",
"DOWNWARD_SHORT_WAVE_RADIATION_FLUX__GROUND_OR_WATER_SURFACE",
"SHORT_WAVE_RADIATION_FLUX__GROUND_OR_WATER_SURFACE",
]


Expand Down
264 changes: 233 additions & 31 deletions src/meteole/forecast.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
import glob
import logging
import os
import re
from abc import ABC, abstractmethod
from functools import reduce
from pathlib import Path
from typing import Any
from warnings import warn
Expand Down Expand Up @@ -152,7 +154,7 @@ def get_coverage(
if indicator:
coverage_id = self._get_coverage_id(indicator, run, interval)

logger.debug(f"Using `coverage_id={coverage_id}`")
logger.info(f"Using `coverage_id={coverage_id}`")

axis = self.get_coverage_description(coverage_id)

Expand All @@ -178,28 +180,6 @@ def get_coverage(

return pd.concat(df_list, axis=0).reset_index(drop=True)

def get_coverages(
self,
coverage_ids: list[str],
lat: tuple = FRANCE_METRO_LATITUDES,
long: tuple = FRANCE_METRO_LONGITUDES,
) -> pd.DataFrame:
"""
Convenient function to quickly fetch a list of indicators using defaults `heights` and `forecast_horizons`
For finer control over heights and forecast_horizons use :meth:`get_coverage`
"""
coverages = [
self.get_coverage(
coverage_id,
lat,
long,
)
for coverage_id in coverage_ids
]

return pd.concat(coverages, axis=0)

def _build_capabilities(self) -> pd.DataFrame:
"Returns the coverage dataframe containing the details of all available coverage_ids"

Expand Down Expand Up @@ -283,7 +263,7 @@ def _get_coverage_id(
valid_intervals = capabilities["interval"].unique().tolist()

if indicator in self.INSTANT_INDICATORS:
if interval is None:
if not interval:
# no interval is expected for instant indicators
pass
else:
Expand All @@ -292,7 +272,7 @@ def _get_coverage_id(
"indicator `{indicator}`."
)
else:
if interval is None:
if not interval:
interval = "P1D"
logger.info(
f"`interval=None` is invalid for non-instant indicators. Using default `interval={interval}`"
Expand All @@ -305,7 +285,7 @@ def _get_coverage_id(

coverage_id = f"{indicator}___{run}"

if interval is not None:
if interval:
coverage_id += f"_{interval}"

return coverage_id
Expand All @@ -320,11 +300,11 @@ def _raise_if_invalid_or_fetch_default(
Args:
param_name (str): The name of the parameter to validate.
inputs (Optional[List[int]]): The list of inputs to validate.
availables (List[int]): The list of available values.
inputs (list[int] | None): The list of inputs to validate.
availables (list[int]): The list of available values.
Returns:
List[int]: The validated list of inputs or the default value.
list[int]: The validated list of inputs or the default value.
Raises:
ValueError: If any of the inputs are not in `availables`.
Expand Down Expand Up @@ -442,13 +422,35 @@ def _get_data_single_forecast(
df.rename(
columns={
"time": "run",
"heightAboveGround": "height",
"isobaricInhPa": "pressure",
"step": "forecast_horizon",
},
inplace=True,
)

known_columns = {"latitude", "longitude", "run", "forecast_horizon", "heightAboveGround", "isobaricInhPa"}
indicator_column = (set(df.columns) - known_columns).pop()

if indicator_column == "unknown":
base_name = "".join([word[0] for word in coverage_id.split("__")[0].split("_")]).lower()
else:
base_name = re.sub(r"\d.*", "", indicator_column)

if "heightAboveGround" in df.columns:
suffix = f"_{int(df['heightAboveGround'].iloc[0])}m"
elif "isobaricInhPa" in df.columns:
suffix = f"_{int(df['isobaricInhPa'].iloc[0])}hpa"
else:
suffix = ""

new_indicator_column = f"{base_name}{suffix}"
df.rename(columns={indicator_column: new_indicator_column}, inplace=True)

df.drop(
columns=["isobaricInhPa", "heightAboveGround", "meanSea", "potentialVorticity"],
errors="ignore",
inplace=True,
)

return df

def _get_coverage_file(
Expand Down Expand Up @@ -540,3 +542,203 @@ def _get_available_feature(grid_axis, feature_name):
features = feature_grid_axis[0]["gmlrgrid:GeneralGridAxis"]["gmlrgrid:coefficients"].split(" ")
features = [int(feature) for feature in features]
return features

def get_combined_coverage(
self,
indicator_names: list[str],
runs: list[str | None] | None = None,
heights: list[int] | None = None,
pressures: list[int] | None = None,
intervals: list[str | None] | None = None,
lat: tuple = FRANCE_METRO_LATITUDES,
long: tuple = FRANCE_METRO_LONGITUDES,
forecast_horizons: list[int] | None = None,
) -> pd.DataFrame:
"""
Get a combined DataFrame of coverage data for multiple indicators and different runs.
This method retrieves and aggregates coverage data for specified indicators, with options
to filter by height, pressure, and forecast_horizon. It returns a concatenated DataFrame
containing the coverage data for all provided runs.
Args:
indicator_names (list[str]): A list of indicator names to retrieve data for.
runs (list[str]): A list of runs for each indicator. Format should be "YYYY-MM-DDTHH:MM:SSZ".
heights (list[int] | None): A list of heights in meters to filter by (default is None).
pressures (list[int] | None): A list of pressures in hPa to filter by (default is None).
intervals (list[str] | None): A list of aggregation periods (default is None). Must be `None` or "" for instant indicators;
otherwise, raises an exception. Defaults to 'P1D' for time-aggregated indicators.
lat (tuple): The latitude range as (min_latitude, max_latitude). Defaults to FRANCE_METRO_LATITUDES.
long (tuple): The longitude range as (min_longitude, max_longitude). Defaults to FRANCE_METRO_LONGITUDES.
forecast_horizons (list[int] | None): A list of forecast horizon values in hours. Defaults to None.
Returns:
pd.DataFrame: A combined DataFrame containing coverage data for all specified runs and indicators.
Raises:
ValueError: If the length of `heights` does not match the length of `indicator_names`.
"""
if not runs:
runs = [None]
coverages = [
self._get_combined_coverage_for_single_run(
indicator_names=indicator_names,
run=run,
lat=lat,
long=long,
heights=heights,
pressures=pressures,
intervals=intervals,
forecast_horizons=forecast_horizons,
)
for run in runs
]
return pd.concat(coverages, axis=0).reset_index(drop=True)

def _get_combined_coverage_for_single_run(
self,
indicator_names: list[str],
run: str | None = None,
heights: list[int] | None = None,
pressures: list[int] | None = None,
intervals: list[str | None] | None = None,
lat: tuple = FRANCE_METRO_LATITUDES,
long: tuple = FRANCE_METRO_LONGITUDES,
forecast_horizons: list[int] | None = None,
) -> pd.DataFrame:
"""
Get a combined DataFrame of coverage data for a given run considering a list of indicators.
This method retrieves and aggregates coverage data for specified indicators, with options
to filter by height, pressure, and forecast_horizon. It returns a concatenated DataFrame
containing the coverage data.
Args:
indicator_names (list[str]): A list of indicator names to retrieve data for.
run (str): A single runs for each indicator. Format should be "YYYY-MM-DDTHH:MM:SSZ".
heights (list[int] | None): A list of heights in meters to filter by (default is None).
pressures (list[int] | None): A list of pressures in hPa to filter by (default is None).
intervals (Optional[list[str]]): A list of aggregation periods (default is None). Must be `None` or "" for instant indicators;
otherwise, raises an exception. Defaults to 'P1D' for time-aggregated indicators.
lat (tuple): The latitude range as (min_latitude, max_latitude). Defaults to FRANCE_METRO_LATITUDES.
long (tuple): The longitude range as (min_longitude, max_longitude). Defaults to FRANCE_METRO_LONGITUDES.
forecast_horizons (list[int] | None): A list of forecast horizon values in hours. Defaults to None.
Returns:
pd.DataFrame: A combined DataFrame containing coverage data for all specified runs and indicators.
Raises:
ValueError: If the length of `heights` does not match the length of `indicator_names`.
"""

def _check_params_length(params: list | None, arg_name: str) -> list:
"""assert length is ok or raise"""
if params is None:
return [None] * len(indicator_names)
if len(params) != len(indicator_names):
raise ValueError(
f"The length of {arg_name} must match the length of indicator_names. If you want multiple {arg_name} for a single indicator, create multiple entries in `indicator_names`."
)
return params

heights = _check_params_length(heights, "heights")
pressures = _check_params_length(pressures, "pressures")
intervals = _check_params_length(intervals, "intervals")

# Get coverage id from run and indicator_name
coverage_ids = [
self._get_coverage_id(indicator_name, run, interval)
for indicator_name, interval in zip(indicator_names, intervals)
]

if forecast_horizons:
# Check forecast_horizons is valid for all indicators
invalid_coverage_ids = self._validate_forecast_horizons(coverage_ids, forecast_horizons)
if invalid_coverage_ids:
raise ValueError(f"{forecast_horizons} are not valid for these coverage_ids : {invalid_coverage_ids}")
else:
forecast_horizons = [self.find_common_forecast_horizons(coverage_ids)[0]]
logger.info(f"Using common forecast_horizons `forecast_horizons={forecast_horizons}`.")

coverages = [
self.get_coverage(
coverage_id=coverage_id,
run=run,
lat=lat,
long=long,
heights=[height] if height is not None else [],
pressures=[pressure] if pressure is not None else [],
forecast_horizons=forecast_horizons,
)
for coverage_id, height, pressure in zip(coverage_ids, heights, pressures)
]

return reduce(
lambda left, right: pd.merge(
left,
right,
on=["latitude", "longitude", "run", "forecast_horizon"],
how="inner",
validate="one_to_one",
),
coverages,
)

def _get_forecast_horizons(self, coverage_ids: list[str]) -> list[list[int]]:
"""
Retrieve the times for each coverage_id.
Parameters:
coverage_ids (list[str]): List of coverage IDs.
Returns:
list[list[int]]: List of times for each coverage ID.
"""
indicator_times = []
for coverage_id in coverage_ids:
times = self.get_coverage_description(coverage_id)["forecast_horizons"]
indicator_times.append(times)
return indicator_times

def find_common_forecast_horizons(
self,
list_coverage_id: list[str],
) -> list[int]:
"""
Find common forecast_horizons among coverage IDs.
indicator_names (list[str]): List of indicator names.
run (Optional[str]): Identifies the model inference. Defaults to latest if None. Format "YYYY-MM-DDTHH:MM:SSZ".
intervals (Optional[list[str]]): List of aggregation periods. Must be None for instant indicators, otherwise raises. Defaults to P1D for time-aggregated indicators like TOTAL_PRECIPITATION.
Returns:
list[int]: Common forecast_horizons
"""
indicator_forecast_horizons = self._get_forecast_horizons(list_coverage_id)

common_forecast_horizons = indicator_forecast_horizons[0]
for times in indicator_forecast_horizons[1:]:
common_forecast_horizons = [time for time in common_forecast_horizons if time in times]

all_times = []
for times in indicator_forecast_horizons:
all_times.extend(times)

return sorted(common_forecast_horizons)

def _validate_forecast_horizons(self, coverage_ids: list[str], forecast_horizons: list[int]) -> list[str]:
"""
Validate forecast_horizons for a list of coverage IDs.
Parameters:
coverage_ids (list[str]): List of coverage IDs.
forecast_horizons (list[int]): List of time forecasts to validate.
Returns:
list[str]: List of invalid coverage IDs.
"""
indicator_forecast_horizons = self._get_forecast_horizons(coverage_ids)

invalid_coverage_ids = [
coverage_id
for coverage_id, times in zip(coverage_ids, indicator_forecast_horizons)
if not set(forecast_horizons).issubset(times)
]

return invalid_coverage_ids
Loading

0 comments on commit 34fc31f

Please sign in to comment.