Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(models): Copy Imputer #72

Merged
merged 16 commits into from
Feb 3, 2025
163 changes: 152 additions & 11 deletions models/src/anemoi/models/preprocessing/imputer.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,12 @@ def get_nans(self, x: torch.Tensor) -> torch.Tensor:
idx = [slice(0, 1)] * (x.ndim - 2) + [slice(None), slice(None)]
return torch.isnan(x[idx].squeeze())

def fill_with_value(self, x, index):
    """Write each configured replacement value into the NaN slots of its variable.

    Parameters
    ----------
    x : torch.Tensor
        Tensor whose last dimension indexes variables; modified in place.
    index : list
        Destination variable indices aligned with ``self.index_training_input``
        and ``self.replacement``; ``None`` entries are skipped.

    Returns
    -------
    torch.Tensor
        The same tensor ``x`` with replacements applied.
    """
    for src_idx, dst_idx, fill_value in zip(self.index_training_input, index, self.replacement):
        if dst_idx is None:
            continue
        # Mask of positions to impute, derived from the cached NaN map.
        mask = self._expand_subset_mask(x, src_idx)
        x[..., dst_idx][mask] = fill_value
    return x

def transform(self, x: torch.Tensor, in_place: bool = True) -> torch.Tensor:
"""Impute missing values in the input tensor."""
if not in_place:
Expand Down Expand Up @@ -145,10 +151,7 @@ def transform(self, x: torch.Tensor, in_place: bool = True) -> torch.Tensor:
)

# Replace values
for idx_src, (idx_dst, value) in zip(self.index_training_input, zip(index, self.replacement)):
if idx_dst is not None:
x[..., idx_dst][self._expand_subset_mask(x, idx_src)] = value
return x
return self.fill_with_value(x, index)

def inverse_transform(self, x: torch.Tensor, in_place: bool = True) -> torch.Tensor:
"""Impute missing values in the input tensor."""
Expand Down Expand Up @@ -231,13 +234,131 @@ def __init__(
self._validate_indices()


class CopyImputer(BaseImputer):
    """Imputes missing values copying them from another variable.

    For each configured variable, the method entry is interpreted as the *name*
    of the source variable whose values are copied into the NaN locations
    ("none" disables imputation for that variable). Example config:
    ```
    default: "none"
    variable_to_copy:
    - variable_missing_1
    - variable_missing_2
    ```
    """

    def __init__(
        self,
        config=None,
        data_indices: Optional[IndexCollection] = None,
        statistics: Optional[dict] = None,
    ) -> None:
        """Initialise the imputer, then build and validate the copy indices.

        Parameters
        ----------
        config : optional
            Imputer configuration mapping variable names to imputation methods.
        data_indices : Optional[IndexCollection]
            Index collection describing training/inference input/output layouts.
        statistics : Optional[dict]
            Dataset statistics; forwarded unchanged to the base class.
        """
        super().__init__(config, data_indices, statistics)

        self._create_imputation_indices()

        self._validate_indices()

    def _create_imputation_indices(
        self,
    ):
        """Create the indices for imputation."""
        # Name->position maps for the four tensor layouts this imputer may see.
        name_to_index_training_input = self.data_indices.data.input.name_to_index
        name_to_index_inference_input = self.data_indices.model.input.name_to_index
        name_to_index_training_output = self.data_indices.data.output.name_to_index
        name_to_index_inference_output = self.data_indices.model.output.name_to_index

        # Variable counts are used by transform() to tell training tensors
        # from inference tensors by their last-dimension size.
        self.num_training_input_vars = len(name_to_index_training_input)
        self.num_inference_input_vars = len(name_to_index_inference_input)
        self.num_training_output_vars = len(name_to_index_training_output)
        self.num_inference_output_vars = len(name_to_index_inference_output)

        (
            self.index_training_input,
            self.index_inference_input,
            self.index_training_output,
            self.index_inference_output,
            self.replacement,
        ) = ([], [], [], [], [])

        # Create indices for imputation
        for name in name_to_index_training_input:
            # The configured method is the name of the variable to copy from.
            key_to_copy = self.methods.get(name, self.default)

            if key_to_copy == "none":
                LOGGER.debug(f"Imputer: skipping {name} as no imputation method is specified")
                continue

            # Variables absent from a layout get None and are skipped later.
            self.index_training_input.append(name_to_index_training_input[name])
            self.index_training_output.append(name_to_index_training_output.get(name, None))
            self.index_inference_input.append(name_to_index_inference_input.get(name, None))
            self.index_inference_output.append(name_to_index_inference_output.get(name, None))

            self.replacement.append(key_to_copy)

            LOGGER.debug(f"Imputer: replacing NaNs in {name} with value coming from variable :{self.replacement[-1]}")

    def fill_with_value(self, x, index):
        """Copy values from each source variable into the NaN slots of its target.

        Raises an AssertionError if the source variable itself contains NaNs
        at any of the positions being imputed.
        """
        # Replace values
        for idx_src, (idx_dst, value) in zip(self.index_training_input, zip(index, self.replacement)):
            if idx_dst is not None:
                # NOTE(review): the source column is looked up via the *training*
                # input name_to_index even when x is inference-shaped — confirm
                # the source variable occupies the same position in both layouts.
                assert not torch.isnan(
                    x[..., self.data_indices.data.input.name_to_index[value]][self._expand_subset_mask(x, idx_src)]
                ).any(), f"NaNs found in {value}."
                x[..., idx_dst][self._expand_subset_mask(x, idx_src)] = x[
                    ..., self.data_indices.data.input.name_to_index[value]
                ][self._expand_subset_mask(x, idx_src)]
        return x

    def transform(self, x: torch.Tensor, in_place: bool = True) -> torch.Tensor:
        """Impute missing values in the input tensor.

        On the first call, caches the NaN locations and builds
        ``loss_mask_training`` so imputed positions get zero loss weight.
        """
        if not in_place:
            x = x.clone()

        # Initialize nan mask once
        if self.nan_locations is None:

            # Get NaN locations
            self.nan_locations = self.get_nans(x)

            # Initialize training loss mask to weigh imputed values with zeroes once
            self.loss_mask_training = torch.ones(
                (x.shape[-2], len(self.data_indices.model.output.name_to_index)), device=x.device
            )  # shape (grid, n_outputs)
            # for all variables that are imputed and part of the model output, set the loss weight to zero
            for idx_src, idx_dst in zip(self.index_training_input, self.index_inference_output):
                if idx_dst is not None:
                    self.loss_mask_training[:, idx_dst] = (~self.nan_locations[:, idx_src]).int()

        # Choose correct index based on number of variables
        if x.shape[-1] == self.num_training_input_vars:
            index = self.index_training_input
        elif x.shape[-1] == self.num_inference_input_vars:
            index = self.index_inference_input
        else:
            raise ValueError(
                f"Input tensor ({x.shape[-1]}) does not match the training "
                f"({self.num_training_input_vars}) or inference shape ({self.num_inference_input_vars})",
            )

        return self.fill_with_value(x, index)


class DynamicMixin:
"""Mixin to add dynamic imputation behavior."""
"""
Mixin to add dynamic imputation behavior.
To be used when NaN maps change at different timesteps.
"""

def get_nans(self, x: torch.Tensor) -> torch.Tensor:
    """Compute the full NaN mask of ``x`` on every call (no caching/squeezing)."""
    return x.isnan()

def fill_with_value(self, x, index, nan_locations):
    """Write replacement values into ``x`` using a per-call NaN mask.

    Unlike the static variant, the NaN locations are passed in explicitly,
    so the mask can differ between timesteps.
    """
    for src_idx, dst_idx, fill_value in zip(self.index_training_input, index, self.replacement):
        if dst_idx is None:
            continue
        x[..., dst_idx][nan_locations[..., src_idx]] = fill_value
    return x

def transform(self, x: torch.Tensor, in_place: bool = True) -> torch.Tensor:
"""Impute missing values in the input tensor."""
if not in_place:
Expand All @@ -261,12 +382,7 @@ def transform(self, x: torch.Tensor, in_place: bool = True) -> torch.Tensor:
f"({self.num_training_input_vars}) or inference shape ({self.num_inference_input_vars})",
)

# Replace values
for idx_src, (idx_dst, value) in zip(self.index_training_input, zip(index, self.replacement)):
if idx_dst is not None:
x[..., idx_dst][nan_locations[..., idx_src]] = value

return x
return self.fill_with_value(x, index, nan_locations)

def inverse_transform(self, x: torch.Tensor, in_place: bool = True) -> torch.Tensor:
"""Impute missing values in the input tensor."""
Expand Down Expand Up @@ -303,3 +419,28 @@ def __init__(
"You are using a dynamic Imputer: NaN values will not be present in the model predictions. \
The model will be trained to predict imputed values. This might deteriorate performances."
)


class DynamicCopyImputer(DynamicMixin, CopyImputer):
    """Copy imputation with NaN masks recomputed at every call."""

    def fill_with_value(self, x, index, nan_locations):
        """Copy each source variable's values into the NaN slots of its target.

        Raises an AssertionError if a source variable itself contains NaNs at
        any position being imputed.
        """
        name_to_index = self.data_indices.data.input.name_to_index
        for src_idx, dst_idx, source_name in zip(self.index_training_input, index, self.replacement):
            if dst_idx is None:
                continue
            mask = nan_locations[..., src_idx]
            donor = x[..., name_to_index[source_name]]
            assert not torch.isnan(donor[mask]).any(), f"NaNs found in {source_name}."
            x[..., dst_idx][mask] = donor[mask]
        return x

    def transform(self, x: torch.Tensor, in_place: bool = True) -> torch.Tensor:
        """Impute missing values via the dynamic (per-call mask) transform."""
        return DynamicMixin.transform(self, x, in_place)

    def inverse_transform(self, x: torch.Tensor, in_place: bool = True) -> torch.Tensor:
        """Delegate to the dynamic mixin's inverse transform."""
        return DynamicMixin.inverse_transform(self, x, in_place)
Loading