@@ -44,7 +44,7 @@ def __init__(
44
44
super ().__init__ (config , data_indices , statistics )
45
45
46
46
self .nan_locations = None
47
- # weight imputed values wiht zero in loss calculation
47
+ # weight imputed values with zero in loss calculation
48
48
self .loss_mask_training = None
49
49
50
50
def _validate_indices (self ):
@@ -113,6 +113,12 @@ def get_nans(self, x: torch.Tensor) -> torch.Tensor:
113
113
idx = [slice (0 , 1 )] * (x .ndim - 2 ) + [slice (None ), slice (None )]
114
114
return torch .isnan (x [idx ].squeeze ())
115
115
116
+ def fill_with_value (self , x , index ):
117
+ for idx_src , (idx_dst , value ) in zip (self .index_training_input , zip (index , self .replacement )):
118
+ if idx_dst is not None :
119
+ x [..., idx_dst ][self ._expand_subset_mask (x , idx_src )] = value
120
+ return x
121
+
116
122
def transform (self , x : torch .Tensor , in_place : bool = True ) -> torch .Tensor :
117
123
"""Impute missing values in the input tensor."""
118
124
if not in_place :
@@ -145,10 +151,7 @@ def transform(self, x: torch.Tensor, in_place: bool = True) -> torch.Tensor:
145
151
)
146
152
147
153
# Replace values
148
- for idx_src , (idx_dst , value ) in zip (self .index_training_input , zip (index , self .replacement )):
149
- if idx_dst is not None :
150
- x [..., idx_dst ][self ._expand_subset_mask (x , idx_src )] = value
151
- return x
154
+ return self .fill_with_value (x , index )
152
155
153
156
def inverse_transform (self , x : torch .Tensor , in_place : bool = True ) -> torch .Tensor :
154
157
"""Impute missing values in the input tensor."""
@@ -231,13 +234,130 @@ def __init__(
231
234
self ._validate_indices ()
232
235
233
236
237
+ class CopyImputer (BaseImputer ):
238
+ """Imputes missing values copying them from another variable.
239
+ ```
240
+ default: "none"
241
+ variable_to_copy:
242
+ - variable_missing_1
243
+ - variable_missing_2
244
+ ```
245
+ """
246
+
247
+ def __init__ (
248
+ self ,
249
+ config = None ,
250
+ data_indices : Optional [IndexCollection ] = None ,
251
+ statistics : Optional [dict ] = None ,
252
+ ) -> None :
253
+ super ().__init__ (config , data_indices , statistics )
254
+
255
+ self ._create_imputation_indices ()
256
+
257
+ self ._validate_indices ()
258
+
259
+ def _create_imputation_indices (
260
+ self ,
261
+ ):
262
+ """Create the indices for imputation."""
263
+ name_to_index_training_input = self .data_indices .data .input .name_to_index
264
+ name_to_index_inference_input = self .data_indices .model .input .name_to_index
265
+ name_to_index_training_output = self .data_indices .data .output .name_to_index
266
+ name_to_index_inference_output = self .data_indices .model .output .name_to_index
267
+
268
+ self .num_training_input_vars = len (name_to_index_training_input )
269
+ self .num_inference_input_vars = len (name_to_index_inference_input )
270
+ self .num_training_output_vars = len (name_to_index_training_output )
271
+ self .num_inference_output_vars = len (name_to_index_inference_output )
272
+
273
+ (
274
+ self .index_training_input ,
275
+ self .index_inference_input ,
276
+ self .index_training_output ,
277
+ self .index_inference_output ,
278
+ self .replacement ,
279
+ ) = ([], [], [], [], [])
280
+
281
+ # Create indices for imputation
282
+ for name in name_to_index_training_input :
283
+ key_to_copy = self .methods .get (name , self .default )
284
+
285
+ if key_to_copy == "none" :
286
+ LOGGER .debug (f"Imputer: skipping { name } as no imputation method is specified" )
287
+ continue
288
+
289
+ self .index_training_input .append (name_to_index_training_input [name ])
290
+ self .index_training_output .append (name_to_index_training_output .get (name , None ))
291
+ self .index_inference_input .append (name_to_index_inference_input .get (name , None ))
292
+ self .index_inference_output .append (name_to_index_inference_output .get (name , None ))
293
+
294
+ self .replacement .append (key_to_copy )
295
+
296
+ LOGGER .debug (f"Imputer: replacing NaNs in { name } with value coming from variable :{ self .replacement [- 1 ]} " )
297
+
298
+ def fill_with_value (self , x , index ):
299
+ # Replace values
300
+ for idx_src , (idx_dst , value ) in zip (self .index_training_input , zip (index , self .replacement )):
301
+ if idx_dst is not None :
302
+ assert not torch .isnan (
303
+ x [..., self .data_indices .data .input .name_to_index [value ]][self ._expand_subset_mask (x , idx_src )]
304
+ ).any (), f"NaNs found in { value } ."
305
+ x [..., idx_dst ][self ._expand_subset_mask (x , idx_src )] = x [
306
+ ..., self .data_indices .data .input .name_to_index [value ]
307
+ ][self ._expand_subset_mask (x , idx_src )]
308
+ return x
309
+
310
+ def transform (self , x : torch .Tensor , in_place : bool = True ) -> torch .Tensor :
311
+ """Impute missing values in the input tensor."""
312
+ if not in_place :
313
+ x = x .clone ()
314
+
315
+ # Initialize nan mask once
316
+ if self .nan_locations is None :
317
+
318
+ # Get NaN locations
319
+ self .nan_locations = self .get_nans (x )
320
+
321
+ # Initialize training loss mask to weigh imputed values with zeroes once
322
+ self .loss_mask_training = torch .ones (
323
+ (x .shape [- 2 ], len (self .data_indices .model .output .name_to_index )), device = x .device
324
+ ) # shape (grid, n_outputs)
325
+ # for all variables that are imputed and part of the model output, set the loss weight to zero
326
+ for idx_src , idx_dst in zip (self .index_training_input , self .index_inference_output ):
327
+ if idx_dst is not None :
328
+ self .loss_mask_training [:, idx_dst ] = (~ self .nan_locations [:, idx_src ]).int ()
329
+
330
+ # Choose correct index based on number of variables
331
+ if x .shape [- 1 ] == self .num_training_input_vars :
332
+ index = self .index_training_input
333
+ elif x .shape [- 1 ] == self .num_inference_input_vars :
334
+ index = self .index_inference_input
335
+ else :
336
+ raise ValueError (
337
+ f"Input tensor ({ x .shape [- 1 ]} ) does not match the training "
338
+ f"({ self .num_training_input_vars } ) or inference shape ({ self .num_inference_input_vars } )" ,
339
+ )
340
+
341
+ return self .fill_with_value (x , index )
342
+
343
+
234
344
class DynamicMixin :
235
- """Mixin to add dynamic imputation behavior."""
345
+ """
346
+ Mixin to add dynamic imputation behavior.
347
+ To be used when NaN maps change at different timesteps.
348
+ """
236
349
237
350
def get_nans (self , x : torch .Tensor ) -> torch .Tensor :
238
351
"""Override to calculate NaN locations dynamically."""
239
352
return torch .isnan (x )
240
353
354
+ def fill_with_value (self , x , index , nan_locations ):
355
+ # Replace values
356
+ for idx , value in zip (index , self .replacement ):
357
+ if idx is not None :
358
+ x [..., idx ][nan_locations [..., idx ]] = value
359
+ return x
360
+
241
361
def transform (self , x : torch .Tensor , in_place : bool = True ) -> torch .Tensor :
242
362
"""Impute missing values in the input tensor."""
243
363
if not in_place :
@@ -261,12 +381,7 @@ def transform(self, x: torch.Tensor, in_place: bool = True) -> torch.Tensor:
261
381
f"({ self .num_training_input_vars } ) or inference shape ({ self .num_inference_input_vars } )" ,
262
382
)
263
383
264
- # Replace values
265
- for idx_src , (idx_dst , value ) in zip (self .index_training_input , zip (index , self .replacement )):
266
- if idx_dst is not None :
267
- x [..., idx_dst ][nan_locations [..., idx_src ]] = value
268
-
269
- return x
384
+ return self .fill_with_value (x , index , nan_locations )
270
385
271
386
def inverse_transform (self , x : torch .Tensor , in_place : bool = True ) -> torch .Tensor :
272
387
"""Impute missing values in the input tensor."""
@@ -282,7 +397,7 @@ def __init__(
282
397
data_indices : Optional [IndexCollection ] = None ,
283
398
statistics : Optional [dict ] = None ,
284
399
) -> None :
285
- super () .__init__ (config , data_indices , statistics )
400
+ InputImputer .__init__ (self , config , data_indices , statistics )
286
401
warnings .warn (
287
402
"You are using a dynamic Imputer: NaN values will not be present in the model predictions. \
288
403
The model will be trained to predict imputed values. This might deteriorate performances."
@@ -298,8 +413,51 @@ def __init__(
298
413
data_indices : Optional [IndexCollection ] = None ,
299
414
statistics : Optional [dict ] = None ,
300
415
) -> None :
301
- super ().__init__ (config , data_indices , statistics )
416
+ ConstantImputer .__init__ (self , config , data_indices , statistics )
417
+ warnings .warn (
418
+ "You are using a dynamic Imputer: NaN values will not be present in the model predictions. \
419
+ The model will be trained to predict imputed values. This might deteriorate performances."
420
+ )
421
+
422
+
423
+ class DynamicCopyImputer (DynamicMixin , CopyImputer ):
424
+ """Dynamic Copy imputation behavior."""
425
+
426
+ def __init__ (
427
+ self ,
428
+ config = None ,
429
+ data_indices : Optional [IndexCollection ] = None ,
430
+ statistics : Optional [dict ] = None ,
431
+ ) -> None :
432
+ CopyImputer .__init__ (self , config , data_indices , statistics )
302
433
warnings .warn (
303
434
"You are using a dynamic Imputer: NaN values will not be present in the model predictions. \
304
435
The model will be trained to predict imputed values. This might deteriorate performances."
305
436
)
437
+
438
+ def fill_with_value (self , x , index , nan_locations ):
439
+
440
+ if x .shape [- 1 ] == self .num_training_input_vars :
441
+ indices = self .data_indices .data .input .name_to_index
442
+ elif x .shape [- 1 ] == self .num_inference_input_vars :
443
+ indices = self .data_indices .model .input .name_to_index
444
+ else :
445
+ raise ValueError (
446
+ f"Input tensor ({ x .shape [- 1 ]} ) does not match the training "
447
+ f"({ self .num_training_input_vars } ) or inference shape ({ self .num_inference_input_vars } )" ,
448
+ )
449
+
450
+ # Replace values
451
+ for idx , value in zip (index , self .replacement ):
452
+ if idx is not None :
453
+ assert not torch .isnan (x [..., indices [value ]][nan_locations [..., idx ]]).any (), f"NaNs found in { value } ."
454
+ x [..., idx ][nan_locations [..., idx ]] = x [..., indices [value ]][nan_locations [..., idx ]]
455
+ return x
456
+
457
+ def transform (self , x : torch .Tensor , in_place : bool = True ) -> torch .Tensor :
458
+ """Impute missing values in the input tensor."""
459
+ return DynamicMixin .transform (self , x , in_place )
460
+
461
+ def inverse_transform (self , x : torch .Tensor , in_place : bool = True ) -> torch .Tensor :
462
+ """Impute missing values in the input tensor."""
463
+ return DynamicMixin .inverse_transform (self , x , in_place )
0 commit comments