Skip to content

Commit

Permalink
feat(models): Copy Imputer (#72)
Browse files Browse the repository at this point in the history
* Implemented copy imputer

* Fixed implementation and tested

* gpc

* Removed print leftover

* Refactor according to review

* gpc passed

* Fixed issue

* Minor changes

* init error

* Merged develop

* Fixed dynamic imputer

* Fixed issue when running inference with dynamic imputer and diagnostic variables

* Addressed changes in review
  • Loading branch information
icedoom888 authored Feb 3, 2025
1 parent f1cc2e6 commit 4690ed5
Showing 1 changed file with 172 additions and 14 deletions.
186 changes: 172 additions & 14 deletions models/src/anemoi/models/preprocessing/imputer.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def __init__(
super().__init__(config, data_indices, statistics)

self.nan_locations = None
# weight imputed values with zero in loss calculation
self.loss_mask_training = None

def _validate_indices(self):
Expand Down Expand Up @@ -113,6 +113,12 @@ def get_nans(self, x: torch.Tensor) -> torch.Tensor:
idx = [slice(0, 1)] * (x.ndim - 2) + [slice(None), slice(None)]
return torch.isnan(x[idx].squeeze())

def fill_with_value(self, x, index):
    """Write the configured replacement constants into the imputed positions of ``x``.

    ``index`` holds destination column indices (training or inference input
    space) aligned one-to-one with ``self.replacement``; entries that are
    ``None`` (variable absent from this space) are skipped. The mask for each
    variable is derived from its training-input index via
    ``self._expand_subset_mask``. Mutates ``x`` in place and returns it.
    """
    targets = zip(index, self.replacement)
    for idx_src, (idx_dst, value) in zip(self.index_training_input, targets):
        if idx_dst is None:
            continue
        mask = self._expand_subset_mask(x, idx_src)
        x[..., idx_dst][mask] = value
    return x

def transform(self, x: torch.Tensor, in_place: bool = True) -> torch.Tensor:
"""Impute missing values in the input tensor."""
if not in_place:
Expand Down Expand Up @@ -145,10 +151,7 @@ def transform(self, x: torch.Tensor, in_place: bool = True) -> torch.Tensor:
)

# Replace values
for idx_src, (idx_dst, value) in zip(self.index_training_input, zip(index, self.replacement)):
if idx_dst is not None:
x[..., idx_dst][self._expand_subset_mask(x, idx_src)] = value
return x
return self.fill_with_value(x, index)

def inverse_transform(self, x: torch.Tensor, in_place: bool = True) -> torch.Tensor:
"""Impute missing values in the input tensor."""
Expand Down Expand Up @@ -231,13 +234,130 @@ def __init__(
self._validate_indices()


class CopyImputer(BaseImputer):
    """Imputes missing values by copying them from another variable.

    Configuration example::

        default: "none"
        variable_to_copy:
        - variable_missing_1
        - variable_missing_2
    """

    def __init__(
        self,
        config=None,
        data_indices: Optional[IndexCollection] = None,
        statistics: Optional[dict] = None,
    ) -> None:
        super().__init__(config, data_indices, statistics)

        self._create_imputation_indices()

        self._validate_indices()

    def _create_imputation_indices(
        self,
    ):
        """Create the indices for imputation.

        For every input variable with a configured source variable, record its
        position in the training/inference input/output spaces and the *name*
        of the variable to copy from (stored in ``self.replacement``).
        """
        name_to_index_training_input = self.data_indices.data.input.name_to_index
        name_to_index_inference_input = self.data_indices.model.input.name_to_index
        name_to_index_training_output = self.data_indices.data.output.name_to_index
        name_to_index_inference_output = self.data_indices.model.output.name_to_index

        self.num_training_input_vars = len(name_to_index_training_input)
        self.num_inference_input_vars = len(name_to_index_inference_input)
        self.num_training_output_vars = len(name_to_index_training_output)
        self.num_inference_output_vars = len(name_to_index_inference_output)

        (
            self.index_training_input,
            self.index_inference_input,
            self.index_training_output,
            self.index_inference_output,
            self.replacement,
        ) = ([], [], [], [], [])

        # Create indices for imputation
        for name in name_to_index_training_input:
            key_to_copy = self.methods.get(name, self.default)

            if key_to_copy == "none":
                LOGGER.debug(f"Imputer: skipping {name} as no imputation method is specified")
                continue

            self.index_training_input.append(name_to_index_training_input[name])
            # Variables absent from a space (e.g. forcing/diagnostic) map to None.
            self.index_training_output.append(name_to_index_training_output.get(name, None))
            self.index_inference_input.append(name_to_index_inference_input.get(name, None))
            self.index_inference_output.append(name_to_index_inference_output.get(name, None))

            self.replacement.append(key_to_copy)

            LOGGER.debug(f"Imputer: replacing NaNs in {name} with value coming from variable :{self.replacement[-1]}")

    def fill_with_value(self, x, index):
        """Copy values from each configured source variable into the imputed positions.

        BUGFIX: the name-to-index mapping for the *source* variable is now
        selected to match the tensor's variable dimension. Previously the
        training mapping (``data.input``) was always used, which addresses the
        wrong (or an out-of-range) column when ``x`` is an inference-width
        tensor. This mirrors ``DynamicCopyImputer.fill_with_value``.

        Mutates ``x`` in place and returns it.
        """
        if x.shape[-1] == self.num_training_input_vars:
            name_to_index = self.data_indices.data.input.name_to_index
        elif x.shape[-1] == self.num_inference_input_vars:
            name_to_index = self.data_indices.model.input.name_to_index
        else:
            raise ValueError(
                f"Input tensor ({x.shape[-1]}) does not match the training "
                f"({self.num_training_input_vars}) or inference shape ({self.num_inference_input_vars})",
            )

        # NOTE(review): the mask is keyed on the training-input index, matching
        # BaseImputer.fill_with_value — assumes the cached NaN mask is laid out
        # in training-variable space; confirm against _expand_subset_mask.
        for idx_src, (idx_dst, value) in zip(self.index_training_input, zip(index, self.replacement)):
            if idx_dst is None:
                continue
            mask = self._expand_subset_mask(x, idx_src)
            source = x[..., name_to_index[value]]
            # The source variable must itself be NaN-free at the imputed locations.
            assert not torch.isnan(source[mask]).any(), f"NaNs found in {value}."
            x[..., idx_dst][mask] = source[mask]
        return x

    def transform(self, x: torch.Tensor, in_place: bool = True) -> torch.Tensor:
        """Impute missing values in the input tensor.

        The NaN mask and the training loss mask are computed once, from the
        first tensor seen, and cached on the instance.
        """
        if not in_place:
            x = x.clone()

        # Initialize nan mask once
        if self.nan_locations is None:

            # Get NaN locations
            self.nan_locations = self.get_nans(x)

            # Initialize training loss mask to weigh imputed values with zeroes once
            self.loss_mask_training = torch.ones(
                (x.shape[-2], len(self.data_indices.model.output.name_to_index)), device=x.device
            )  # shape (grid, n_outputs)
            # for all variables that are imputed and part of the model output, set the loss weight to zero
            for idx_src, idx_dst in zip(self.index_training_input, self.index_inference_output):
                if idx_dst is not None:
                    self.loss_mask_training[:, idx_dst] = (~self.nan_locations[:, idx_src]).int()

        # Choose correct index based on number of variables
        if x.shape[-1] == self.num_training_input_vars:
            index = self.index_training_input
        elif x.shape[-1] == self.num_inference_input_vars:
            index = self.index_inference_input
        else:
            raise ValueError(
                f"Input tensor ({x.shape[-1]}) does not match the training "
                f"({self.num_training_input_vars}) or inference shape ({self.num_inference_input_vars})",
            )

        return self.fill_with_value(x, index)


class DynamicMixin:
"""Mixin to add dynamic imputation behavior."""
"""
Mixin to add dynamic imputation behavior.
To be used when NaN maps change at different timesteps.
"""

def get_nans(self, x: torch.Tensor) -> torch.Tensor:
    """Return a boolean mask of NaN entries in ``x``, recomputed on every call.

    Overrides the static behavior: no mask is cached, since NaN locations
    may change between timesteps.
    """
    nan_mask = torch.isnan(x)
    return nan_mask

def fill_with_value(self, x, index, nan_locations):
# Replace values
for idx, value in zip(index, self.replacement):
if idx is not None:
x[..., idx][nan_locations[..., idx]] = value
return x

def transform(self, x: torch.Tensor, in_place: bool = True) -> torch.Tensor:
"""Impute missing values in the input tensor."""
if not in_place:
Expand All @@ -261,12 +381,7 @@ def transform(self, x: torch.Tensor, in_place: bool = True) -> torch.Tensor:
f"({self.num_training_input_vars}) or inference shape ({self.num_inference_input_vars})",
)

# Replace values
for idx_src, (idx_dst, value) in zip(self.index_training_input, zip(index, self.replacement)):
if idx_dst is not None:
x[..., idx_dst][nan_locations[..., idx_src]] = value

return x
return self.fill_with_value(x, index, nan_locations)

def inverse_transform(self, x: torch.Tensor, in_place: bool = True) -> torch.Tensor:
"""Impute missing values in the input tensor."""
Expand All @@ -282,7 +397,7 @@ def __init__(
data_indices: Optional[IndexCollection] = None,
statistics: Optional[dict] = None,
) -> None:
super().__init__(config, data_indices, statistics)
InputImputer.__init__(self, config, data_indices, statistics)
warnings.warn(
"You are using a dynamic Imputer: NaN values will not be present in the model predictions. \
The model will be trained to predict imputed values. This might deteriorate performances."
Expand All @@ -298,8 +413,51 @@ def __init__(
data_indices: Optional[IndexCollection] = None,
statistics: Optional[dict] = None,
) -> None:
super().__init__(config, data_indices, statistics)
ConstantImputer.__init__(self, config, data_indices, statistics)
warnings.warn(
"You are using a dynamic Imputer: NaN values will not be present in the model predictions. \
The model will be trained to predict imputed values. This might deteriorate performances."
)


class DynamicCopyImputer(DynamicMixin, CopyImputer):
    """Dynamic Copy imputation behavior."""

    def __init__(
        self,
        config=None,
        data_indices: Optional[IndexCollection] = None,
        statistics: Optional[dict] = None,
    ) -> None:
        # Initialise exactly like a static CopyImputer (bypassing the mixin),
        # then warn that predictions will contain imputed values.
        CopyImputer.__init__(self, config, data_indices, statistics)
        warnings.warn(
            "You are using a dynamic Imputer: NaN values will not be present in the model predictions. \
            The model will be trained to predict imputed values. This might deteriorate performances."
        )

    def fill_with_value(self, x, index, nan_locations):
        """Copy values from the configured source variables into the NaN positions of ``x``."""
        # Resolve the variable-name mapping that matches the tensor width.
        if x.shape[-1] == self.num_training_input_vars:
            indices = self.data_indices.data.input.name_to_index
        elif x.shape[-1] == self.num_inference_input_vars:
            indices = self.data_indices.model.input.name_to_index
        else:
            raise ValueError(
                f"Input tensor ({x.shape[-1]}) does not match the training "
                f"({self.num_training_input_vars}) or inference shape ({self.num_inference_input_vars})",
            )

        for idx, value in zip(index, self.replacement):
            if idx is None:
                continue
            holes = nan_locations[..., idx]
            donor = x[..., indices[value]]
            # The source variable must itself be NaN-free at the copied locations.
            assert not torch.isnan(donor[holes]).any(), f"NaNs found in {value}."
            x[..., idx][holes] = donor[holes]
        return x

    def transform(self, x: torch.Tensor, in_place: bool = True) -> torch.Tensor:
        """Impute missing values in the input tensor (dynamic NaN mask)."""
        return DynamicMixin.transform(self, x, in_place)

    def inverse_transform(self, x: torch.Tensor, in_place: bool = True) -> torch.Tensor:
        """Impute missing values in the input tensor (dynamic NaN mask)."""
        return DynamicMixin.inverse_transform(self, x, in_place)

0 comments on commit 4690ed5

Please sign in to comment.