Skip to content

Commit 4690ed5

Browse files
authored
feat(models): Copy Imputer (#72)
* Implemented copy imputer * Fixed implementation and tested * gpc * Removed print leftover * Refactor according to review * gpc passed * Fixed issue * Minor changes * init error * Merged develop * Fixed dynamic imputer * Fixed issue when running inference with dynamic imputer and diagnostic variables * Addressed changes in review
1 parent f1cc2e6 commit 4690ed5

File tree

1 file changed

+172
-14
lines changed

1 file changed

+172
-14
lines changed

models/src/anemoi/models/preprocessing/imputer.py

Lines changed: 172 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ def __init__(
4444
super().__init__(config, data_indices, statistics)
4545

4646
self.nan_locations = None
47-
# weight imputed values wiht zero in loss calculation
47+
# weight imputed values with zero in loss calculation
4848
self.loss_mask_training = None
4949

5050
def _validate_indices(self):
@@ -113,6 +113,12 @@ def get_nans(self, x: torch.Tensor) -> torch.Tensor:
113113
idx = [slice(0, 1)] * (x.ndim - 2) + [slice(None), slice(None)]
114114
return torch.isnan(x[idx].squeeze())
115115

116+
def fill_with_value(self, x, index):
    """Overwrite masked entries of ``x`` with each variable's replacement value.

    Parameters
    ----------
    x : torch.Tensor
        Tensor whose last dimension indexes variables. It is modified in place.
    index : list
        Destination column per imputed variable in the layout of ``x``;
        a ``None`` entry means the variable is absent from ``x`` and is skipped.

    Returns
    -------
    torch.Tensor
        The same tensor object ``x`` after the replacement.
    """
    destinations_and_fills = zip(index, self.replacement)
    for src_col, (dst_col, fill) in zip(self.index_training_input, destinations_and_fills):
        if dst_col is None:
            continue
        # The mask is looked up with the training-layout column (src_col),
        # while the write targets the column of the current layout (dst_col).
        x[..., dst_col][self._expand_subset_mask(x, src_col)] = fill
    return x
121+
116122
def transform(self, x: torch.Tensor, in_place: bool = True) -> torch.Tensor:
117123
"""Impute missing values in the input tensor."""
118124
if not in_place:
@@ -145,10 +151,7 @@ def transform(self, x: torch.Tensor, in_place: bool = True) -> torch.Tensor:
145151
)
146152

147153
# Replace values
148-
for idx_src, (idx_dst, value) in zip(self.index_training_input, zip(index, self.replacement)):
149-
if idx_dst is not None:
150-
x[..., idx_dst][self._expand_subset_mask(x, idx_src)] = value
151-
return x
154+
return self.fill_with_value(x, index)
152155

153156
def inverse_transform(self, x: torch.Tensor, in_place: bool = True) -> torch.Tensor:
154157
"""Impute missing values in the input tensor."""
@@ -231,13 +234,130 @@ def __init__(
231234
self._validate_indices()
232235

233236

237+
class CopyImputer(BaseImputer):
    """Imputes missing values by copying them from another variable.

    Example configuration (each listed variable has its NaNs filled with the
    values of ``variable_to_copy``):

    ```
    default: "none"
    variable_to_copy:
        - variable_missing_1
        - variable_missing_2
    ```
    """

    def __init__(
        self,
        config=None,
        data_indices: Optional[IndexCollection] = None,
        statistics: Optional[dict] = None,
    ) -> None:
        """Set up the base imputer, then build and validate the index lists."""
        super().__init__(config, data_indices, statistics)

        self._create_imputation_indices()

        self._validate_indices()

    def _create_imputation_indices(
        self,
    ):
        """Create the indices for imputation.

        For every training-input variable that has a donor configured, record
        its column in each of the four layouts (training/inference x
        input/output; ``None`` when absent from a layout) and, in
        ``self.replacement``, the name of the donor variable.
        """
        name_to_index_training_input = self.data_indices.data.input.name_to_index
        name_to_index_inference_input = self.data_indices.model.input.name_to_index
        name_to_index_training_output = self.data_indices.data.output.name_to_index
        name_to_index_inference_output = self.data_indices.model.output.name_to_index

        self.num_training_input_vars = len(name_to_index_training_input)
        self.num_inference_input_vars = len(name_to_index_inference_input)
        self.num_training_output_vars = len(name_to_index_training_output)
        self.num_inference_output_vars = len(name_to_index_inference_output)

        self.index_training_input = []
        self.index_inference_input = []
        self.index_training_output = []
        self.index_inference_output = []
        self.replacement = []

        # Create indices for imputation
        for name in name_to_index_training_input:
            # NOTE(review): self.methods / self.default are presumably set by
            # BaseImputer from the config -- confirm against the base class.
            key_to_copy = self.methods.get(name, self.default)

            if key_to_copy == "none":
                LOGGER.debug(f"Imputer: skipping {name} as no imputation method is specified")
                continue

            self.index_training_input.append(name_to_index_training_input[name])
            self.index_training_output.append(name_to_index_training_output.get(name, None))
            self.index_inference_input.append(name_to_index_inference_input.get(name, None))
            self.index_inference_output.append(name_to_index_inference_output.get(name, None))

            self.replacement.append(key_to_copy)

            LOGGER.debug(f"Imputer: replacing NaNs in {name} with value coming from variable :{self.replacement[-1]}")

    def fill_with_value(self, x, index):
        """Fill masked entries of each imputed variable from its donor variable.

        Parameters
        ----------
        x : torch.Tensor
            Tensor whose last dimension indexes variables, in either the
            training-input or the inference-input layout. Modified in place.
        index : list
            Destination column per imputed variable in the layout of ``x``;
            ``None`` entries are skipped.

        Returns
        -------
        torch.Tensor
            The same tensor object ``x`` after the replacement.

        Raises
        ------
        ValueError
            If the last dimension of ``x`` matches neither input layout.
        AssertionError
            If the donor variable itself is NaN at a location to be filled.
        """
        # The donor column must be resolved in the same layout as ``x``.
        # Looking it up in the training layout for an inference-shaped tensor
        # would read a wrong (or out-of-range) column.
        if x.shape[-1] == self.num_training_input_vars:
            donor_columns = self.data_indices.data.input.name_to_index
        elif x.shape[-1] == self.num_inference_input_vars:
            donor_columns = self.data_indices.model.input.name_to_index
        else:
            raise ValueError(
                f"Input tensor ({x.shape[-1]}) does not match the training "
                f"({self.num_training_input_vars}) or inference shape ({self.num_inference_input_vars})",
            )

        # Replace values
        for idx_src, (idx_dst, value) in zip(self.index_training_input, zip(index, self.replacement)):
            if idx_dst is None:
                continue
            # Computed once per variable and reused for the check and the copy.
            # NOTE(review): assumes _expand_subset_mask has no side effects -- confirm.
            mask = self._expand_subset_mask(x, idx_src)
            donor_values = x[..., donor_columns[value]][mask]
            assert not torch.isnan(donor_values).any(), f"NaNs found in {value}."
            x[..., idx_dst][mask] = donor_values
        return x

    def transform(self, x: torch.Tensor, in_place: bool = True) -> torch.Tensor:
        """Impute missing values in the input tensor.

        On the first call the NaN locations of ``x`` are stored in
        ``self.nan_locations`` and ``self.loss_mask_training`` is built from
        them; later calls reuse both.

        Parameters
        ----------
        x : torch.Tensor
            Tensor whose last dimension has either the training-input or the
            inference-input number of variables.
        in_place : bool
            When False, a clone of ``x`` is imputed and ``x`` is left untouched.

        Returns
        -------
        torch.Tensor
            The imputed tensor (``x`` itself when ``in_place`` is True).

        Raises
        ------
        ValueError
            If the last dimension of ``x`` matches neither input layout.
        """
        if not in_place:
            x = x.clone()

        # Initialize nan mask once
        if self.nan_locations is None:

            # Get NaN locations
            self.nan_locations = self.get_nans(x)

            # Initialize training loss mask to weigh imputed values with zeroes once
            self.loss_mask_training = torch.ones(
                (x.shape[-2], len(self.data_indices.model.output.name_to_index)), device=x.device
            )  # shape (grid, n_outputs)
            # for all variables that are imputed and part of the model output, set the loss weight to zero
            for idx_src, idx_dst in zip(self.index_training_input, self.index_inference_output):
                if idx_dst is not None:
                    self.loss_mask_training[:, idx_dst] = (~self.nan_locations[:, idx_src]).int()

        # Choose correct index based on number of variables
        if x.shape[-1] == self.num_training_input_vars:
            index = self.index_training_input
        elif x.shape[-1] == self.num_inference_input_vars:
            index = self.index_inference_input
        else:
            raise ValueError(
                f"Input tensor ({x.shape[-1]}) does not match the training "
                f"({self.num_training_input_vars}) or inference shape ({self.num_inference_input_vars})",
            )

        return self.fill_with_value(x, index)
342+
343+
234344
class DynamicMixin:
235-
"""Mixin to add dynamic imputation behavior."""
345+
"""
346+
Mixin to add dynamic imputation behavior.
347+
To be used when NaN maps change at different timesteps.
348+
"""
236349

237350
def get_nans(self, x: torch.Tensor) -> torch.Tensor:
    """Return an element-wise NaN mask of ``x``, recomputed on every call."""
    nan_mask = x.isnan()
    return nan_mask
240353

354+
def fill_with_value(self, x, index, nan_locations):
    """Overwrite NaN entries of ``x`` with each variable's replacement value.

    Parameters
    ----------
    x : torch.Tensor
        Tensor whose last dimension indexes variables. It is modified in place.
    index : list
        Column per imputed variable in the layout of ``x``; ``None`` entries
        are skipped.
    nan_locations : torch.Tensor
        Boolean mask indexed with the same columns as ``x``.

    Returns
    -------
    torch.Tensor
        The same tensor object ``x`` after the replacement.
    """
    for column, fill in zip(index, self.replacement):
        if column is None:
            continue
        # Mask and destination use the same column: the mask is expected to
        # describe ``x`` itself rather than a separately stored layout.
        missing = nan_locations[..., column]
        x[..., column][missing] = fill
    return x
360+
241361
def transform(self, x: torch.Tensor, in_place: bool = True) -> torch.Tensor:
242362
"""Impute missing values in the input tensor."""
243363
if not in_place:
@@ -261,12 +381,7 @@ def transform(self, x: torch.Tensor, in_place: bool = True) -> torch.Tensor:
261381
f"({self.num_training_input_vars}) or inference shape ({self.num_inference_input_vars})",
262382
)
263383

264-
# Replace values
265-
for idx_src, (idx_dst, value) in zip(self.index_training_input, zip(index, self.replacement)):
266-
if idx_dst is not None:
267-
x[..., idx_dst][nan_locations[..., idx_src]] = value
268-
269-
return x
384+
return self.fill_with_value(x, index, nan_locations)
270385

271386
def inverse_transform(self, x: torch.Tensor, in_place: bool = True) -> torch.Tensor:
272387
"""Impute missing values in the input tensor."""
@@ -282,7 +397,7 @@ def __init__(
282397
data_indices: Optional[IndexCollection] = None,
283398
statistics: Optional[dict] = None,
284399
) -> None:
285-
super().__init__(config, data_indices, statistics)
400+
InputImputer.__init__(self, config, data_indices, statistics)
286401
warnings.warn(
287402
"You are using a dynamic Imputer: NaN values will not be present in the model predictions. \
288403
The model will be trained to predict imputed values. This might deteriorate performances."
@@ -298,8 +413,51 @@ def __init__(
298413
data_indices: Optional[IndexCollection] = None,
299414
statistics: Optional[dict] = None,
300415
) -> None:
301-
super().__init__(config, data_indices, statistics)
416+
ConstantImputer.__init__(self, config, data_indices, statistics)
417+
warnings.warn(
418+
"You are using a dynamic Imputer: NaN values will not be present in the model predictions. \
419+
The model will be trained to predict imputed values. This might deteriorate performances."
420+
)
421+
422+
423+
class DynamicCopyImputer(DynamicMixin, CopyImputer):
    """Copy imputer whose NaN locations are taken from each tensor it receives."""

    def __init__(
        self,
        config=None,
        data_indices: Optional[IndexCollection] = None,
        statistics: Optional[dict] = None,
    ) -> None:
        """Initialise through ``CopyImputer`` and warn about dynamic imputation."""
        CopyImputer.__init__(self, config, data_indices, statistics)
        warnings.warn(
            "You are using a dynamic Imputer: NaN values will not be present in the model predictions. "
            "The model will be trained to predict imputed values. This might deteriorate performances."
        )

    def fill_with_value(self, x, index, nan_locations):
        """Fill NaN entries of each imputed variable from its donor variable.

        Parameters
        ----------
        x : torch.Tensor
            Tensor in the training-input or inference-input layout. Modified in place.
        index : list
            Destination column per imputed variable in the layout of ``x``;
            ``None`` entries are skipped.
        nan_locations : torch.Tensor
            Boolean mask indexed with the same columns as ``x``.

        Returns
        -------
        torch.Tensor
            The same tensor object ``x`` after the replacement.

        Raises
        ------
        ValueError
            If the last dimension of ``x`` matches neither input layout.
        AssertionError
            If the donor variable is NaN at a location to be filled.
        """
        n_vars = x.shape[-1]
        # Donor names are resolved to columns in the same layout as ``x``.
        if n_vars == self.num_training_input_vars:
            name_to_column = self.data_indices.data.input.name_to_index
        elif n_vars == self.num_inference_input_vars:
            name_to_column = self.data_indices.model.input.name_to_index
        else:
            raise ValueError(
                f"Input tensor ({x.shape[-1]}) does not match the training "
                f"({self.num_training_input_vars}) or inference shape ({self.num_inference_input_vars})",
            )

        # Replace values
        for dst_col, donor_name in zip(index, self.replacement):
            if dst_col is None:
                continue
            missing = nan_locations[..., dst_col]
            donor_values = x[..., name_to_column[donor_name]][missing]
            assert not torch.isnan(donor_values).any(), f"NaNs found in {donor_name}."
            x[..., dst_col][missing] = donor_values
        return x

    def transform(self, x: torch.Tensor, in_place: bool = True) -> torch.Tensor:
        """Impute missing values by delegating to ``DynamicMixin.transform``."""
        return DynamicMixin.transform(self, x, in_place)

    def inverse_transform(self, x: torch.Tensor, in_place: bool = True) -> torch.Tensor:
        """Delegate to ``DynamicMixin.inverse_transform``."""
        return DynamicMixin.inverse_transform(self, x, in_place)

0 commit comments

Comments
 (0)