Add Batch Sampling Pipeline with Integration Tests for GFS Data #57

Closed
4 changes: 4 additions & 0 deletions .gitignore
@@ -94,3 +94,7 @@ node_modules/

# Data
data/


#saved_batches
saved_batches/
9 changes: 5 additions & 4 deletions src/open_data_pvnet/configs/gfs_data_config.yaml
@@ -5,12 +5,13 @@ general:
input_data:
nwp:
gfs:
time_resolution_minutes: 180 # Match the dataset's resolution (3 hours)
time_resolution_minutes: 360 # Match the dataset's 6-hour resolution
interval_start_minutes: 0
interval_end_minutes: 1080 # 6 forecast steps (6 * 3 hours)
dropout_timedeltas_minutes: null
interval_end_minutes: 1440 # Forecast for 1 day
# dropout_fraction: 0.1
dropout_timedeltas_minutes: null # No dropout time deltas applied
accum_channels: []
max_staleness_minutes: 1080 # Match interval_end_minutes for consistency
max_staleness_minutes: 180 # Staleness window (3 hours)
zarr_path: "s3://ocf-open-data-pvnet/data/gfs.zarr"
provider: "gfs"
image_size_pixels_height: 1
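As a sanity check on the new values: interval_end_minutes = 1440 combined with time_resolution_minutes = 360 implies five target times per initialization (0 h through 24 h in 6-hour steps), mirroring the date_range call in _get_sample further down. A minimal sketch, using an illustrative t0 that is not taken from the real store:

import pandas as pd

interval_start_minutes = 0      # values from gfs_data_config.yaml above
interval_end_minutes = 1440
time_resolution_minutes = 360

t0 = pd.Timestamp("2023-01-01T00:00:00")  # illustrative init time
target_times = pd.date_range(
    start=t0 + pd.Timedelta(minutes=interval_start_minutes),
    end=t0 + pd.Timedelta(minutes=interval_end_minutes),
    freq=pd.Timedelta(minutes=time_resolution_minutes),
)
print(len(target_times))  # 5 target times: +0 h, +6 h, +12 h, +18 h, +24 h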
146 changes: 81 additions & 65 deletions src/open_data_pvnet/nwp/gfs_dataset.py
@@ -1,53 +1,66 @@
"""
# How to run this script independently:
1. Ensure `ocf-data-sampler` is installed and properly configured.
2. Set the appropriate dataset path and config file.
3. Uncomment the main block below to run as a standalone script.
GFS Data Sampler

This script is designed to load, process, and sample Global Forecast System (GFS) data
stored in Zarr format. It provides functionalities for handling NaN values, retrieving
valid forecast initialization times, and normalizing the dataset for machine learning tasks.

The script is structured as follows:
1. **open_gfs**: Loads the dataset from an S3 or local Zarr file.
2. **handle_nan_values**: Handles missing data in the dataset.
3. **GFSDataSampler** (PyTorch Dataset): Samples and normalizes data for training.

To test the functionality, uncomment and run the `__main__` block at the bottom of this file.
It loads the dataset and prints basic statistics for debugging.

"""

import logging
import pandas as pd
import xarray as xr
import numpy as np
import fsspec
from torch.utils.data import Dataset
from ocf_data_sampler.config import load_yaml_configuration
from ocf_data_sampler.torch_datasets.utils.valid_time_periods import find_valid_time_periods
from ocf_data_sampler.constants import NWP_MEANS, NWP_STDS
import fsspec
import numpy as np


# Configure logging
logging.basicConfig(level=logging.WARNING)
# Configure logging format for readability
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")

# Ensure xarray retains attributes during operations
xr.set_options(keep_attrs=True)


def open_gfs(dataset_path: str) -> xr.DataArray:
"""
Opens the GFS dataset stored in Zarr format and prepares it for processing.
Open the GFS dataset from a Zarr file stored remotely or locally.

Args:
dataset_path (str): Path to the GFS dataset.
dataset_path (str): Path to the GFS dataset (S3 or local).

Returns:
xr.DataArray: The processed GFS data.
xr.DataArray: The processed GFS data array with required dimensions.
"""
logging.info("Opening GFS dataset synchronously...")
logging.info(f"Opening GFS dataset from {dataset_path}...")
store = fsspec.get_mapper(dataset_path, anon=True)
gfs_dataset: xr.Dataset = xr.open_dataset(
store, engine="zarr", consolidated=True, chunks="auto"
)

# Convert dataset to DataArray for easier handling
gfs_data: xr.DataArray = gfs_dataset.to_array(dim="channel")

# Rename "init_time" to "init_time_utc" if necessary
if "init_time" in gfs_data.dims:
logging.debug("Renaming 'init_time' to 'init_time_utc'...")
gfs_data = gfs_data.rename({"init_time": "init_time_utc"})

# Ensure correct dimension order
required_dims = ["init_time_utc", "step", "channel", "latitude", "longitude"]
gfs_data = gfs_data.transpose(*required_dims)

logging.debug(f"GFS dataset dimensions: {gfs_data.dims}")
logging.info("GFS dataset loaded successfully.")
return gfs_data
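For reference, a minimal sketch of how open_gfs and handle_nan_values are meant to be combined; the S3 path is the one used in the commented main block at the end of this file, and anonymous read access is assumed:

gfs_data = open_gfs("s3://ocf-open-data-pvnet/data/gfs.zarr")
# Dimensions come back in the enforced order:
# ('init_time_utc', 'step', 'channel', 'latitude', 'longitude')
print(gfs_data.dims)
gfs_data = handle_nan_values(gfs_data, method="fill", fill_value=0.0)  # replace NaNs with 0.0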


@@ -59,7 +72,7 @@ def handle_nan_values(

Args:
dataset (xr.DataArray): The dataset to process.
method (str): The method for handling NaNs ("fill" or "drop").
method (str): Method for handling NaNs ("fill" or "drop").
fill_value (float): Value to replace NaNs if method is "fill".

Returns:
@@ -78,6 +91,10 @@ def handle_nan_values(
class GFSDataSampler(Dataset):
"""
A PyTorch Dataset for sampling and normalizing GFS data.

Attributes:
dataset (xr.DataArray): GFS dataset containing weather variables.
valid_t0_times (pd.DataFrame): DataFrame of valid initialization times for sampling.
"""

def __init__(
@@ -92,19 +109,27 @@ def __init__(

Args:
dataset (xr.DataArray): The dataset to sample from.
config_filename (str): Path to the configuration file.
config_filename (str): Path to the YAML configuration file.
start_time (str, optional): Start time for filtering data.
end_time (str, optional): End time for filtering data.
"""
logging.info("Initializing GFSDataSampler...")
self.dataset = dataset
self.config = load_yaml_configuration(config_filename)

# Retrieve valid initialization times
self.valid_t0_times = find_valid_time_periods({"nwp": {"gfs": self.dataset}}, self.config)
logging.debug(f"Valid initialization times:\n{self.valid_t0_times}")
logging.info(f"Raw valid_t0_times:\n{self.valid_t0_times}")

# Warn if only a single valid initialization time is available
if len(self.valid_t0_times) <= 1:
logging.warning("Only one valid t0 timestamp found. Consider adjusting max_staleness.")

# Rename "start_dt" to "t0" for clarity
if "start_dt" in self.valid_t0_times.columns:
self.valid_t0_times = self.valid_t0_times.rename(columns={"start_dt": "t0"})

# Apply time range filtering if specified
if start_time:
self.valid_t0_times = self.valid_t0_times[
self.valid_t0_times["t0"] >= pd.Timestamp(start_time)
@@ -114,13 +139,19 @@ def __init__(
self.valid_t0_times["t0"] <= pd.Timestamp(end_time)
]

logging.debug(f"Filtered valid_t0_times:\n{self.valid_t0_times}")
logging.info(
f"Total valid initialization times after filtering: {len(self.valid_t0_times)}"
)

def __len__(self):
return len(self.valid_t0_times)

def __getitem__(self, idx):
"""
Fetch a sample based on the index.
"""
t0 = self.valid_t0_times.iloc[idx]["t0"]
logging.info(f"Fetching sample for t0={t0}.")
return self._get_sample(t0)

def _get_sample(self, t0: pd.Timestamp) -> xr.Dataset:
@@ -131,9 +162,10 @@ def _get_sample(self, t0: pd.Timestamp) -> xr.Dataset:
t0 (pd.Timestamp): The initialization time.

Returns:
xr.Dataset: The sampled data.
xr.Dataset: The sampled and normalized data.
"""
logging.info(f"Generating sample for t0={t0}...")

interval_start = pd.Timedelta(minutes=self.config.input_data.nwp.gfs.interval_start_minutes)
interval_end = pd.Timedelta(minutes=self.config.input_data.nwp.gfs.interval_end_minutes)
time_resolution = pd.Timedelta(
@@ -143,67 +175,51 @@ def _get_sample(self, t0: pd.Timestamp) -> xr.Dataset:
start_dt = t0 + interval_start
end_dt = t0 + interval_end
target_times = pd.date_range(start=start_dt, end=end_dt, freq=time_resolution)
logging.debug(f"Target times: {target_times}")

valid_steps = [np.timedelta64((time - t0).value, "ns") for time in target_times]
available_steps = self.dataset.step.values
valid_steps = [step for step in valid_steps if step in available_steps]

if not valid_steps:
raise ValueError(f"No valid steps found for t0={t0}")
logging.info(f"Expected target times: {target_times}")

sliced_data = self.dataset.sel(init_time_utc=t0, step=valid_steps)
sliced_data = self.dataset.sel(
init_time_utc=t0, step=[np.timedelta64((t - t0).value, "ns") for t in target_times]
)
return self._normalize_sample(sliced_data)

def _normalize_sample(self, dataset: xr.Dataset) -> xr.Dataset:
"""
Normalize the dataset using precomputed means and standard deviations.
Normalize the dataset using mean and standard deviation values.

Args:
dataset (xr.Dataset): The dataset to normalize.

Returns:
xr.Dataset: The normalized dataset.
"""
logging.info("Starting normalization...")
provider = self.config.input_data.nwp.gfs.provider
dataset_channels = dataset.channel.values
mean_channels = NWP_MEANS[provider].channel.values
std_channels = NWP_STDS[provider].channel.values

valid_channels = set(dataset_channels) & set(mean_channels) & set(std_channels)
missing_in_dataset = set(mean_channels) - set(dataset_channels)
missing_in_means = set(dataset_channels) - set(mean_channels)

if missing_in_dataset:
logging.warning(f"Channels missing in dataset: {missing_in_dataset}")
if missing_in_means:
logging.warning(f"Channels missing in normalization stats: {missing_in_means}")

valid_channels = list(valid_channels)
dataset = dataset.sel(channel=valid_channels)
means = NWP_MEANS[provider].sel(channel=valid_channels)
stds = NWP_STDS[provider].sel(channel=valid_channels)

logging.debug(f"Selected Channels: {valid_channels}")
logging.debug(f"Mean Values: {means.values}")
logging.debug(f"Std Values: {stds.values}")

try:
normalized_dataset = (dataset - means) / stds
logging.info("Normalization completed.")
return normalized_dataset
except Exception as e:
logging.error(f"Error during normalization: {e}")
raise e


# # Uncomment the block below to test
logging.info("Normalizing dataset...")
return (dataset - NWP_MEANS["gfs"]) / NWP_STDS["gfs"]


# if __name__ == "__main__":
# """
# Main block for testing the GFS data sampling process.
# This section ensures the dataset loads correctly, handles NaNs, and samples data.

# Steps:
# 1. Load the dataset from an S3 location.
# 2. Handle NaN values (filling with zero).
# 3. Initialize the GFSDataSampler.
# 4. Print dataset statistics for validation.
# """
# dataset_path = "s3://ocf-open-data-pvnet/data/gfs.zarr"
# config_path = "src/open_data_pvnet/configs/gfs_data_config.yaml"

# # Load dataset
# dataset = open_gfs(dataset_path)

# # Handle NaN values
# dataset = handle_nan_values(dataset, method="fill", fill_value=0.0)
# sampler = GFSDataSampler(dataset, config_filename=config_path, start_time="2023-01-01T00:00:00", end_time="2023-01-30T00:00:00")
# sample = sampler[0]
# print(sample)

# # Initialize sampler
# sampler = GFSDataSampler(dataset, config_filename=config_path)

# # Print statistics
# logging.info(f"Total samples available: {len(sampler)}")
# logging.info(f"First 5 samples: {[sampler[i] for i in range(min(5, len(sampler)))]}")
103 changes: 103 additions & 0 deletions src/open_data_pvnet/scripts/batch_samples.py
@@ -0,0 +1,103 @@
# batch_samples.py
import logging
import sys
import math
import torch.multiprocessing as mp
from torch.utils.data import DataLoader

# Import functions and classes from their canonical modules
from src.open_data_pvnet.nwp.gfs_dataset import open_gfs, handle_nan_values, GFSDataSampler
from src.open_data_pvnet.utils.batch_utils import process_and_save_batches

# Configure logging
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
logger = logging.getLogger(__name__)


# Define an identity collate function at the module level.
# This function simply returns the batch as is.
def identity_collate(batch):
return batch


def run_batch_samples(
dataset_path: str,
config_path: str,
start_time: str,
end_time: str,
output_directory: str,
dataloader_kwargs: dict = None,
num_batches: int = None,
):
"""
Runs the batch sampling process: loads the dataset, wraps it in a DataLoader,
and processes/saves batches.
"""
# Set the multiprocessing start method if not already set.
if mp.get_start_method(allow_none=True) != "forkserver":
try:
mp.set_start_method("forkserver", force=True)
except RuntimeError:
logger.warning("Multiprocessing start method already set. Proceeding.")
mp.set_sharing_strategy("file_system")

# Set default DataLoader kwargs if not provided.
if dataloader_kwargs is None:
dataloader_kwargs = {
"batch_size": 4, # Example batch size; adjust as needed.
"shuffle": True,
"num_workers": 2,
"prefetch_factor": 2,
"pin_memory": False,
# Use the module-level identity_collate function.
"collate_fn": identity_collate,
}

# Load and preprocess the raw dataset.
logger.info("Loading GFS dataset from %s", dataset_path)
gfs_data = open_gfs(dataset_path)
gfs_data = handle_nan_values(gfs_data, method="fill", fill_value=0.0)

# Initialize the custom dataset sampler.
logger.info("Initializing GFSDataSampler with config: %s", config_path)
gfs_sampler = GFSDataSampler(
gfs_data, config_filename=config_path, start_time=start_time, end_time=end_time
)

# Enhanced logging: log number of samples and expected number of batches.
num_samples = len(gfs_sampler)
batch_size = dataloader_kwargs.get("batch_size", 1)
expected_batches = math.ceil(num_samples / batch_size)
logger.info("Total samples in dataset: %d", num_samples)
logger.info("Expected number of batches (with batch_size=%d): %d", batch_size, expected_batches)

# Create a DataLoader to batch the dataset.
logger.info("Creating DataLoader with parameters: %s", dataloader_kwargs)
data_loader = DataLoader(gfs_sampler, **dataloader_kwargs)

# Optional: verify by iterating over one batch.
logger.info("Verifying DataLoader by iterating over the first batch...")
for batch in data_loader:
logger.info("Received a batch of samples (verification). Batch type: %s", type(batch))
break

# Process and save batches.
logger.info("Processing and saving batches to directory: %s", output_directory)
process_and_save_batches(data_loader, output_directory, num_batches=num_batches)
logger.info("All batches processed and saved.")


# The __main__ block allows the script to be run standalone.
# However, the main functionality is encapsulated in run_batch_samples(),
# so it can also be imported and called from other parts of the repo.
if __name__ == "__main__":
# Example configuration values.
dataset_path = "s3://ocf-open-data-pvnet/data/gfs.zarr"
config_path = "src/open_data_pvnet/configs/gfs_data_config.yaml"
start_time = "2023-01-01T00:00:00"
end_time = "2023-02-28T00:00:00"
output_directory = "./saved_batches"

run_batch_samples(
dataset_path, config_path, start_time, end_time, output_directory, num_batches=10
)
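The same pipeline can also be driven from other code instead of the __main__ block; a minimal sketch, where the date range and DataLoader settings are illustrative rather than recommended defaults:

from src.open_data_pvnet.scripts.batch_samples import identity_collate, run_batch_samples

run_batch_samples(
    dataset_path="s3://ocf-open-data-pvnet/data/gfs.zarr",
    config_path="src/open_data_pvnet/configs/gfs_data_config.yaml",
    start_time="2023-01-01T00:00:00",
    end_time="2023-01-07T00:00:00",      # shorter window for a quick run
    output_directory="./saved_batches",
    dataloader_kwargs={
        "batch_size": 2,
        "shuffle": False,
        "num_workers": 0,                # keep everything in-process
        "collate_fn": identity_collate,  # each batch stays a list of samples
    },
    num_batches=2,                       # stop after two saved batches
)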