From 56ca01d757debbed9bb563519992dd566fd8033f Mon Sep 17 00:00:00 2001
From: dkimpara
Date: Thu, 16 Jan 2025 16:28:52 -0700
Subject: [PATCH] ensemble capability for train_universal.py

---
 config/example-v2025.2.0-ensemble.yml         | 478 ++++++++++++++++++
 config/test_cesm_ensemble.yml                 |   3 +-
 credit/parser.py                              |   5 +
 .../trainerERA5_multistep_grad_accum.py       |  30 ++
 4 files changed, 515 insertions(+), 1 deletion(-)
 create mode 100644 config/example-v2025.2.0-ensemble.yml

diff --git a/config/example-v2025.2.0-ensemble.yml b/config/example-v2025.2.0-ensemble.yml
new file mode 100644
index 00000000..a8a97aa7
--- /dev/null
+++ b/config/example-v2025.2.0-ensemble.yml
@@ -0,0 +1,478 @@
+# --------------------------------------------------------------------------------------------------------------------- #
+# This YAML file configures the training of a 1- or 6-hourly state-in-state-out Crossformer model
+# on NSF NCAR high-performance computing systems (casper.ucar.edu and derecho.hpc.ucar.edu).
+#
+# The model is trained on 6-hourly ERA5 model-level data, incorporating key physical features:
+# - Input: Model-level variables [U, V, T, Q], static variables [Z_GDS4_SFC, LSM], and dynamic forcing [tsi].
+# - Output:
+#   - Model level: [U, V, T, Q]
+#   - Single level: [SP, t2m]
+#   - 500 hPa level: [U, V, T, Z, Q]
+#
+# This configuration is designed for fine-tuning and prediction tasks using advanced components,
+# such as spectral normalization, reflective padding, and tracer-specific postprocessing.
+#
+# Ensure compatibility with ERA5 data preprocessing workflows and metadata paths specified in this file.
+# --------------------------------------------------------------------------------------------------------------------- #
+
+
+# the location to save your workspace; it will contain
+# (1) the pbs script, (2) a copy of this config, (3) model weights, (4) training_log.csv
+# if save_loc does not exist, it will be created automatically
+save_loc: '/glade/work/$USER/CREDIT_runs/tests/example_2025_2_0/'
+seed: 1000
+
+data:
+  # upper-air variables, must be YEARLY zarr or nc with (time, level, latitude, longitude) dims
+  # files must have the listed variable names
+  # upper-air variables will be normalized by the dataloader, users do not need to normalize them
+  variables: ['U','V','T','Q']
+  save_loc: '/glade/derecho/scratch/ksha/CREDIT_data/ERA5_mlevel_arXiv/SixHourly_y_TOTAL_*'
+
+  # surface variables, must be YEARLY zarr or nc with (time, latitude, longitude) dims
+  # the time dimension MUST be the same as upper-air variables
+  # files must have the listed variable names
+  # surface variables will be normalized by the dataloader, users do not need to normalize them
+  surface_variables: ['SP', 't2m', 'Z500', 'T500', 'U500', 'V500', 'Q500']
+  save_loc_surface: '/glade/derecho/scratch/ksha/CREDIT_data/ERA5_mlevel_arXiv/SixHourly_y_TOTAL_*'
+
+  # dynamic forcing variables, must be YEARLY zarr or nc with (time, latitude, longitude) dims
+  # the time dimension MUST be the same as upper-air variables
+  # files must have the listed variable names
+  # dynamic forcing variables will be normalized by the dataloader, users do not need to normalize them
+  dynamic_forcing_variables: ['tsi']
+  save_loc_dynamic_forcing: '/glade/derecho/scratch/dgagne/credit_solar_6h_0.25deg/*.nc'
+
+  # diagnostic variables, must be YEARLY zarr or nc with (time, latitude, longitude) dims
+  # the time dimension MUST be the same as upper-air variables
+  # files must have the listed variable names
+  # THESE ARE NOT LOADING CORRECTLY ON X
+  # diagnostic variables will be normalized by the dataloader, users do not need to normalize them
+  # diagnostic_variables: ['Z500', 'T500', 'U500', 'V500', 'Q500']
+  # save_loc_diagnostic: '/glade/campaign/cisl/aiml/wchapman/MLWPS/STAGING/SixHourly_y_TOTAL_*'
+
+  # periodic forcing variables, must be a single zarr or nc with (time, latitude, longitude) dims
+  # the time dimension should cover an entire LEAP YEAR
+  #
+  # e.g., periodic forcing variables can be provided on the year 2000,
+  # and its time coords should have 24*366 hours for an hourly model
+  #
+  # THESE ARE REDUNDANT?
+  # periodic forcing variables MUST be normalized BY USER
+  # forcing_variables: ['TSI']
+  # save_loc_forcing: '/glade/campaign/cisl/aiml/ksha/CREDIT/forcing_norm_6h.nc'
+
+  # static variables must be a single zarr or nc with (latitude, longitude) coords
+  # static variables must be normalized BY USER
+  static_variables: ['Z_GDS4_SFC', 'LSM']
+  save_loc_static: '/glade/derecho/scratch/ksha/CREDIT_data/static_norm_old.nc'
+
+  # z-score files, they must be zarr or nc with (level,) coords
+  # they MUST include all of the
+  # 'variables', 'surface_variables', 'dynamic_forcing_variables', 'diagnostic_variables' above
+  mean_path: '/glade/derecho/scratch/ksha/CREDIT_data/mean_6h_1979_2018_16lev_0.25deg.nc'
+  std_path: '/glade/derecho/scratch/ksha/CREDIT_data/std_residual_6h_1979_2018_16lev_0.25deg.nc'
+
+  # years to form the training / validation set [first_year, last_year (not covered)]
+  train_years: [1979, 2014] # 1979 - 2013
+  valid_years: [2014, 2018] # 2014 - 2017
+
+  # std_new is the new data workflow that works with z-scores
+  scaler_type: 'std_new'
+  # std_cached does not perform normalization; it works for cached (preprocessed) datasets
+  # scaler_type: 'std_cached'
+
+  # number of input time frames
+  history_len: 1
+  valid_history_len: 1 # keep it the same as "history_len"
+
+  # number of forecast lead times to predict during TRAINING
+  # 0 for single-step training
+  # 1, 2, 3, ... for multi-step training
+  forecast_len: 0
+  # the "forecast_len" used for the validation set; can be the same as or smaller than "forecast_len"
+  valid_forecast_len: 0
+
+  # TRAINING ONLY
+  # [Optional] backprop_on_timestep: use a list, e.g., [1, 2, 3, 5, 6, 7], to backprop on those timesteps when doing multi-step training
+  # backprop_on_timestep: [1, 2, 3, 5, 6, 7]
+
+  # when forecast_len = 0, it does nothing
+  # when forecast_len > 0, it computes the training loss on the last forecast state only
+  # this option speeds up multi-step training; it ONLY works with trainer type: standard
+  # use one_shot --> True
+  # do not use one_shot --> False or null
+  one_shot: False
+
+  # number of hours for each forecast step
+  # lead_time_periods = 6 for a 6-hourly model and 6-hourly training data
+  # lead_time_periods = 1 for an hourly model and hourly training data
+  lead_time_periods: 6
+
+  # *NOT A STABLE FEATURE
+  # this is the keyword that resolves the mismatch between a 6-hourly model and hourly data
+  # skip_periods = 6 will train a 6-hourly model on hourly data by skipping and picking every 6th hour
+  # default is null or skip_periods = 1
+  skip_periods: null
+
+  # this keyword makes 'std_new' compatible with the old 'std' workflow
+  #
+  # True: input tensors will be formed under the order of [static -> dynamic_forcing -> forcing]
+  # True is the same as the old 'std'
+  #
+  # False: input tensors will be formed under the order of [dynamic_forcing -> forcing -> static]
+  static_first: False
+
+  # SST forcing
+  # sst_forcing:
+  #   activate: False
+  #   varname_skt: 'SKT'
+  #   varname_ocean_mask: 'land_sea_CI_mask'
+
+  # dataset type. For now you may choose from
+  #   ERA5_and_Forcing_MultiStep, MultiprocessingBatcherPrefetch, ERA5_MultiStep_Batcher, MultiprocessingBatcher
+  # All of these support single- or multi-step training (i.e., forecast lengths >= 0)
+  dataset_type: MultiprocessingBatcherPrefetch
+
+trainer:
+  # Choose the training routine
+  # Trainer types other than universal are deprecated as of credit-2.0
+  # Use universal with all of the above dataset_types
+  type: universal
+
+  # the keyword that controls GPU usage
+  # fsdp: fully sharded data parallel
+  # ddp: distributed data parallel
+  # none: single-GPU training
+  mode: none
+
+  # fsdp-specific GPU optimization
+  #
+  # allow FSDP to offload gradients to the CPU and free GPU memory
+  # this can cause CPU out-of-memory errors for some large models, so do your due diligence
+  cpu_offload: False
+  # save forward-pass activations to checkpoints and free GPU memory
+  activation_checkpoint: True
+
+  # (optional) use torch.compile: False. May not be compatible with custom models
+  compile: False
+
+  # load existing weights / optimizer / mixed-precision grad scaler / learning rate scheduler state(s)
+  # when starting multi-step training, reload only the weights initially, then set all to true
+  load_weights: False
+  load_optimizer: False
+  load_scaler: False
+  load_scheduler: False
+
+  # CREDIT saves checkpoints at the end of every epoch
+  # save_backup_weights will also save checkpoints at the beginning of every epoch
+  save_backup_weights: True
+  # save_best_weights will save the best checkpoint separately based on validation loss
+  # does not work if skip_validation: True
+  save_best_weights: True
+
+  # Optional for user to control which variable metrics get saved to training_log.csv
+  # Set to True to save metrics for all predicted variables
+  # Set to ["Z500", "Q500", "Q", "T"], for example, to save these variables (with levels in the case of Q and T)
+  # Set to [] or None to save only bulk metrics averaged over all variables
+  save_metric_vars: True
+
+  # update learning rate to optimizer.param_groups
+  # False if a scheduler is used
+  update_learning_rate: False
+
+  # learning rate
+  learning_rate: 1.0e-03
+
+  # L2 regularization on trained weights
+  # weight_decay: 0 --> turns off L2 reg
+  weight_decay: 0
+
+  # define training and validation batch size
+  # for ddp and fsdp, actual batch size = batch_size * number of GPUs
+  train_batch_size: 1
+  valid_batch_size: 1
+  # number of ensemble members generated per input sample; requires an ensemble loss (see "training_loss" below)
+  ensemble_size: 2
+
+  # number of batches per epoch for training and validation
+  batches_per_epoch: 2 # Set to 0 to use len(dataloader)
+  valid_batches_per_epoch: 2
+  # early stopping
+  stopping_patience: 50
+  # True: do not validate, always save weights
+  skip_validation: False
+
+  # total number of epochs
+  start_epoch: 0
+  # the trainer will stop after iterating a given number of epochs
+  # this works for any start_epoch and reload_epoch: True
+  num_epoch: 2
+  # the epoch is saved in checkpoint.pt; this option reads the last epoch
+  # of the previous job and continues from there
+  # useful for epoch-dependent schedulers
+  reload_epoch: True # can also use --> train_one_epoch: True
+  epochs: &epochs 70
+  # if use_scheduler: False
+  #   reload_epoch: False
+  #   epochs: 20
+
+  use_scheduler: False
+  # scheduler: {'scheduler_type': 'cosine-annealing', 'T_max': *epochs, 'last_epoch': -1}
+  scheduler:
+    scheduler_type: cosine-annealing-restarts
+    first_cycle_steps: 250
+    cycle_mult: 6.0 # multiplier for steps in subsequent cycles
+    max_lr: 1.0e-05
+    min_lr: 1.0e-08
+    warmup_steps: 249
+    gamma: 0.7 # LR reduction factor at each cycle restart
+
+  # PyTorch automatic mixed precision (amp). See also 'mixed_precision' below
+  amp: False
+
+  # This version of mixed precision is supported only with FSDP and gives you more control. Set amp = False.
+  # See https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.MixedPrecision for details
+  # Many different precision types are supported. See credit.mixed_precision
+  # If commented out, the default is float32.
+  mixed_precision:
+    param_dtype: "float32"
+    reduce_dtype: "float32"
+    buffer_dtype: "float32"
+
+  # rescale loss as loss = loss / grad_accum_every
+  grad_accum_every: 1
+
+  # gradient clipping. Set to 0 to ignore
+  grad_max_norm: 1.0
+
+  # number of CPU workers used in datasets/dataloaders
+  thread_workers: 4
+  valid_thread_workers: 4
+
+  # number of prefetch workers in the DataLoader
+  # this only works with ERA5_MultiStep_Batcher or MultiprocessingBatcherPrefetch
+  prefetch_factor: 4
+
+model:
+  # NCAR WXFormer example
+  type: "crossformer"
+  frames: 1               # number of input states (default: 1)
+  image_height: 640       # number of latitude grids (default: 640)
+  image_width: 1280       # number of longitude grids (default: 1280)
+  levels: 16              # number of upper-air variable levels (default: 15)
+  channels: 4             # upper-air variable channels
+  surface_channels: 7     # surface variable channels
+  input_only_channels: 3  # dynamic forcing, forcing, static channels
+  output_only_channels: 0 # diagnostic variable channels
+
+  patch_width: 1          # number of longitude grids in each 3D patch (default: 1)
+  patch_height: 1         # number of latitude grids in each 3D patch (default: 1)
+  frame_patch_size: 1     # number of input states in each 3D patch (default: 1)
+
+  dim: [32, 64, 128, 256]           # dimensionality of each layer
+  depth: [2, 2, 2, 2]               # depth of each layer
+  global_window_size: [10, 5, 2, 1] # global window size for each layer
+  local_window_size: 10             # local window size
+  cross_embed_kernel_sizes:         # kernel sizes for cross-embedding
+    - [4, 8, 16, 32]
+    - [2, 4]
+    - [2, 4]
+    - [2, 4]
+  cross_embed_strides: [2, 2, 2, 2] # strides for cross-embedding (default: [4, 2, 2, 2])
+  attn_dropout: 0.                  # dropout probability for attention layers (default: 0.0)
+  ff_dropout: 0.                    # dropout probability for feed-forward layers (default: 0.0)
+
+  # fuxi example
+  # frames: 2                 # number of input states
+  # image_height: &height 640 # number of latitude grids
+  # image_width: &width 1280  # number of longitude grids
+  # levels: &levels 16        # number of upper-air variable levels
+  # channels: 4               # upper-air variable channels
+  # surface_channels: 7       # surface variable channels
+  # input_only_channels: 3    # dynamic forcing, forcing, static channels
+  # output_only_channels: 0   # diagnostic variable channels
+
+  # # patchify layer
+  # patch_height: 4           # number of latitude grids in each 3D patch
+  # patch_width: 4            # number of longitude grids in each 3D patch
+  # frame_patch_size: 2       # number of input states in each 3D patch
+
+  # # hidden layers
+  # dim: 1024                 # dimension (default: 1536)
+  # num_groups: 32            # number of groups (default: 32)
+  # num_heads: 8              # number of heads (default: 8)
+  # window_size: 7            # window size (default: 7)
+  # depth: 16                 # number of swin transformers (default: 48)
+
+  # for BOTH fuxi and crossformer
+  # use spectral norm
+  use_spectral_norm: True
+
+  # use interpolation to match the output size
+  interp: True
+
+  # map boundary padding
+  padding_conf:
+    activate: True
+    mode: earth
+    pad_lat: 80
+    pad_lon: 80
+
+  # configuration for postblock processing
+  post_conf:
+    activate: True
+
+    skebs:
+      activate: False
+
+    tracer_fixer:
+      activate: True
+      denorm: True
+      tracer_name: ['Q', 'Q500']
+      tracer_thres: [1e-8, 1e-8]
+
+    global_mass_fixer:
+      activate: False
+
+    global_energy_fixer:
+      activate: False
+
+loss:
+  # the main training loss
+  # available options are "mse", "msle", "mae", "huber", "logcosh", "xtanh", "xsigmoid"
+  # available ensemble options are "KCRPS"
+  training_loss: "KCRPS"
+
+  # use power or spectral loss
+  # if True, this loss will be added to the training_loss:
+  #   total_loss = training_loss + spectral_lambda_reg * power_loss (or spectral_loss)
+  #
+  # it is preferred that power loss and spectral loss are NOT applied at the same time
+  use_power_loss: False
+  use_spectral_loss: False # if power_loss is on, turn off spectral_loss, and vice versa
+  # rescale power or spectral loss when added to the total loss
+  spectral_lambda_reg: 0.1
+  # truncate small wavenumbers (large wavelengths) in power or spectral loss
+  spectral_wavenum_init: 20
+
+  # this file is MANDATORY for the "predict" section below (inference stage)
+  # the file must be nc and must contain 1D variables named "latitude" and "longitude"
+  latitude_weights: "/glade/u/home/wchapman/MLWPS/DataLoader/LSM_static_variables_ERA5_zhght.nc"
+
+  # use latitude weighting
+  # if True, the latitude_weights file MUST have a variable named "coslat"
+  # coslat is np.cos(latitude)
+  use_latitude_weights: True
+
+  # variable weighting
+  # if True, specify your weights
+  use_variable_weights: False
+  # # an example
+  # variable_weights:
+  #   U: [0.132, 0.123, 0.113, 0.104, 0.095, 0.085, 0.076, 0.067, 0.057, 0.048, 0.039, 0.029, 0.02, 0.011, 0.005]
+  #   V: [0.132, 0.123, 0.113, 0.104, 0.095, 0.085, 0.076, 0.067, 0.057, 0.048, 0.039, 0.029, 0.02, 0.011, 0.005]
+  #   T: [0.132, 0.123, 0.113, 0.104, 0.095, 0.085, 0.076, 0.067, 0.057, 0.048, 0.039, 0.029, 0.02, 0.011, 0.005]
+  #   Q: [0.132, 0.123, 0.113, 0.104, 0.095, 0.085, 0.076, 0.067, 0.057, 0.048, 0.039, 0.029, 0.02, 0.011, 0.005]
+  #   SP: 0.1
+  #   t2m: 1.0
+  #   V500: 0.1
+  #   U500: 0.1
+  #   T500: 0.1
+  #   Z500: 0.1
+  #   Q500: 0.1
+
+
+predict:
+  # You have the choice to run rollout_metrics.py, which will only save various metrics like RMSE, etc.
+  # A directory called metrics will be added to the save_forecast field below.
+  # To instead save the states to file (netcdf), run rollout_to_netcdf.py. You will need to post-process
+  # all the data with this option. A directory called netcdf will be added to the save_forecast field
+
+  # the keyword that controls GPU usage
+  # fsdp: fully sharded data parallel
+  # ddp: distributed data parallel
+  # none: single-GPU training
+  mode: none
+
+  # Set the batch_size for the prediction inference mode (default is 1)
+  batch_size: 1
+
+  forecasts:
+    type: "custom"       # keep it as "custom". See also credit.forecast
+    start_year: 2019     # year of the first initialization (where rollout will start)
+    start_month: 1       # month of the first initialization
+    start_day: 1         # day of the first initialization
+    start_hours: [0, 12] # hour-of-day for each initialization, 0 for 00Z, 12 for 12Z
+    duration: 1152       # number of days to initialize, starting from the (year, mon, day) above
+                         # duration should be divisible by the number of GPUs
+                         # (e.g., duration: 384 for a 365-day rollout using 32 GPUs)
+    days: 10             # forecast lead time as days (1 means a 24-hour forecast)
+
+  # saved variables (rollout_to_netcdf.yml)
+  # save_vars: [] will save ALL variables
+  # removing save_vars from the config will save ALL variables
+  metadata: '/glade/work/schreck/repos/credit/miles-credit/metadata/era5.yaml'
+
+  # The location to store rollout predictions; the folder structure will be:
+  #   $save_forecast/$initialization_time/file.nc
+  # each forecast lead time produces a file.nc
+  save_forecast: '/glade/u/home/schreck/scratch/finetune/wx12/'
+
+  # this keyword will apply a low-pass filter to all fields at each forecast step
+  # it will reduce high-frequency noise, but may impact performance
+  use_laplace_filter: False
+
+  # climatology file to be used with rollout_metrics.py.
+  # Supplying this will compute anomaly ACC; otherwise it is the Pearson coefficient (acc in metrics).
+  climatology: '/glade/campaign/cisl/aiml/ksha/CREDIT_arXiv/VERIF/verif_6h/ERA5_clim/ERA5_clim_1990_2019_6h_interp.nc'
+
+  # Optional: interpolate data fields from hybrid sigma-pressure vertical coordinates to pressure levels.
+  # Used with rollout_to_netcdf.py
+
+  # interp_pressure:
+  #   pressure_levels: [300.0, 500.0, 850.0, 925.0]
+
+  # ua_var_encoding:
+  #   zlib: True
+  #   complevel: 1
+  #   shuffle: True
+  #   chunksizes: [1, *levels, *height, *width]
+
+  # pressure_var_encoding:
+  #   zlib: True
+  #   complevel: 1
+  #   shuffle: True
+  #   chunksizes: [1, 4, *height, *width]
+
+  # surface_var_encoding:
+  #   zlib: True
+  #   complevel: 1
+  #   shuffle: True
+  #   chunksizes: [1, *height, *width]
+
+
+# credit.pbs supports NSF NCAR HPCs (casper.ucar.edu and derecho.hpc.ucar.edu)
+pbs:
+  # example for derecho
+  conda: "credit-derecho"
+  project: "NAML0001"
+  job_name: "train_model"
+  walltime: "12:00:00"
+  nodes: 8
+  ncpus: 64
+  ngpus: 4
+  mem: '480GB'
+  queue: 'main'
+
+  # example for casper
+  # conda: "credit"
+  # job_name: 'train_model'
+  # nodes: 1
+  # ncpus: 8
+  # ngpus: 1
+  # mem: '128GB'
+  # walltime: '4:00:00'
+  # gpu_type: 'v100'
+  # project: 'NRIS0001'
+  # queue: 'casper'

diff --git a/config/test_cesm_ensemble.yml b/config/test_cesm_ensemble.yml
index 799c281a..61efd4fb 100644
--- a/config/test_cesm_ensemble.yml
+++ b/config/test_cesm_ensemble.yml
@@ -4,7 +4,7 @@
 # The model is trained on hourly model-level ERA5 data with top solar irradiance, geopotential, and land-sea mask
 # Output variables: model level [U, V, T, Q], single level [SP, t2m], and 500 hPa [U, V, T, Z, Q]
 # --------------------------------------------------------------------------------------------------------------------- #
-save_loc: '/glade/work/$USER/CREDIT_runs/test_cesm_ensemble/'
+save_loc: '/glade/work/$USER/CREDIT_runs/tests/cesm_ensemble/'
 seed: 1000
 
 #net short top of model: FSNT
@@ -29,6 +29,7 @@ seed: 1000
 #FLDT is not usually output from the model, but FLNT and FLUT are. SOLIN is the insolation.
 
 data:
+  dataset_type: "ERA5_and_Forcing_MultiStep"
   # upper-air variables
   variables: ['U','V','T','Q']
   save_loc: '/glade/campaign/cisl/aiml/wchapman/MLWPS/STAGING/f.e21.CREDIT_climate_????.zarr'

diff --git a/credit/parser.py b/credit/parser.py
index 12d78101..04ef3110 100644
--- a/credit/parser.py
+++ b/credit/parser.py
@@ -753,6 +753,11 @@ def credit_main_parser(
     if "ensemble_size" not in conf["trainer"]:
         conf["trainer"]["ensemble_size"] = 1
+
+    if conf["trainer"]["ensemble_size"] > 1:
+        assert (
+            conf["loss"]["training_loss"] in ["KCRPS"]
+        ), f"{conf['loss']['training_loss']} loss incompatible with ensemble training. ensemble_size is {conf['trainer']['ensemble_size']}"
 
     if "load_scaler" not in conf["trainer"]:
         conf["trainer"]["load_scaler"] = False
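A note on the parser hunk above: the assert message is an f-string, and on Python versions before 3.12 an f-string delimited by double quotes cannot contain double-quoted dict subscripts (PEP 701 lifted this restriction in 3.12), so the message uses single quotes inside. A minimal standalone sketch of the same guard, using a plain dict as a hypothetical stand-in for CREDIT's parsed config:

    # Hypothetical stand-in for the parsed config; not CREDIT's full schema.
    conf = {"trainer": {"ensemble_size": 2}, "loss": {"training_loss": "KCRPS"}}

    # Mirror of the guard in credit/parser.py: ensemble training requires an ensemble loss.
    if conf["trainer"]["ensemble_size"] > 1:
        assert conf["loss"]["training_loss"] in ["KCRPS"], (
            f"{conf['loss']['training_loss']} loss incompatible with ensemble training. "
            f"ensemble_size is {conf['trainer']['ensemble_size']}"
        )

Swapping training_loss to "mse" in the dict above raises the AssertionError, which is the intended fail-fast behavior before any training starts.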
diff --git a/credit/trainers/trainerERA5_multistep_grad_accum.py b/credit/trainers/trainerERA5_multistep_grad_accum.py
index 24b8c039..86cb9842 100644
--- a/credit/trainers/trainerERA5_multistep_grad_accum.py
+++ b/credit/trainers/trainerERA5_multistep_grad_accum.py
@@ -78,6 +78,9 @@ def train_one_epoch(
         amp = conf["trainer"]["amp"]
         distributed = True if conf["trainer"]["mode"] in ["fsdp", "ddp"] else False
         forecast_length = conf["data"]["forecast_len"]
+        ensemble_size = conf["trainer"].get("ensemble_size", 1)
+        if ensemble_size > 1:
+            logger.info(f"ensemble training with ensemble_size {ensemble_size}")
 
         # number of diagnostic variables
         varnum_diag = len(conf["data"]["diagnostic_variables"])
@@ -187,6 +190,14 @@ def train_one_epoch(
                 else:
                     # no x_surf
                     x = reshape_only(batch["x"]).to(self.device)  # .float()
+
+                # --------------------------------------------- #
+                # ensemble x and x_surf on initialization:
+                # copies each sample in the batch ensemble_size times.
+                # if samples in the batch are ordered (x, y, z), the resulting tensor is (x, x, ..., y, y, ..., z, z, ...)
+                # WARNING: must be used with a loss that can handle predictions with b * ensemble_size samples and targets with b samples
+                if ensemble_size > 1:
+                    x = torch.repeat_interleave(x, ensemble_size, 0)
 
                 # add forcing and static variables (regardless of fcst hours)
                 if "x_forcing_static" in batch:
@@ -194,6 +205,11 @@ def train_one_epoch(
                     x_forcing_batch = (
                         batch["x_forcing_static"].to(self.device).permute(0, 2, 1, 3, 4)
                     )  # .float()
+                    # ---------------- ensemble ----------------- #
+                    # ensemble x_forcing_batch for the concat; see above for an explanation
+                    if ensemble_size > 1:
+                        x_forcing_batch = torch.repeat_interleave(x_forcing_batch, ensemble_size, 0)
+                    # --------------------------------------------- #
 
                     # concat on var dimension
                     x = torch.cat((x, x_forcing_batch), dim=1)
@@ -404,6 +420,8 @@ def validate(self, epoch, conf, valid_loader, criterion, metrics):
             if "valid_forecast_len" in conf["data"]
             else conf["forecast_len"]
         )
+        ensemble_size = conf["trainer"].get("ensemble_size", 1)
+
         distributed = True if conf["trainer"]["mode"] in ["fsdp", "ddp"] else False
 
         results_dict = defaultdict(list)
@@ -487,6 +505,13 @@ def validate(self, epoch, conf, valid_loader, criterion, metrics):
                 else:
                     # no x_surf
                     x = reshape_only(batch["x"]).to(self.device)  # .float()
+                # --------------------------------------------- #
+                # ensemble x and x_surf on initialization:
+                # copies each sample in the batch ensemble_size times.
+                # if samples in the batch are ordered (x, y, z), the resulting tensor is (x, x, ..., y, y, ..., z, z, ...)
+                # WARNING: must be used with a loss that can handle predictions with b * ensemble_size samples and targets with b samples
+                if ensemble_size > 1:
+                    x = torch.repeat_interleave(x, ensemble_size, 0)
 
                 # add forcing and static variables (regardless of fcst hours)
                 if "x_forcing_static" in batch:
@@ -496,6 +521,11 @@ def validate(self, epoch, conf, valid_loader, criterion, metrics):
                     x_forcing_batch = (
                         batch["x_forcing_static"]
                         .to(self.device)
                         .permute(0, 2, 1, 3, 4)
                     )  # .float()
+                    # ---------------- ensemble ----------------- #
+                    # ensemble x_forcing_batch for the concat; see above for an explanation
+                    if ensemble_size > 1:
+                        x_forcing_batch = torch.repeat_interleave(x_forcing_batch, ensemble_size, 0)
+                    # --------------------------------------------- #
 
                     # concat on var dimension
                     x = torch.cat((x, x_forcing_batch), dim=1)
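For intuition on how the duplication in the trainer pairs with an ensemble-aware loss, here is a minimal sketch assuming illustrative shapes and a kernel-CRPS estimator; the names and shapes are assumptions for the example, not CREDIT's actual KCRPS implementation, which may differ in detail:

    import torch

    b, e, c = 2, 3, 4                  # batch, ensemble members, channels (illustrative)
    x = torch.randn(b, c)

    # repeat_interleave on dim 0 keeps members of the same sample adjacent:
    # (x0, x0, x0, x1, x1, x1), the ordering described in the trainer comments.
    x_ens = torch.repeat_interleave(x, e, dim=0)
    assert x_ens.shape == (b * e, c)

    # An ensemble loss can recover the member dimension with a plain view,
    # precisely because the adjacency above is guaranteed.
    y_pred = torch.randn(b * e, c)     # model output: one row per ensemble member
    y_true = torch.randn(b, c)         # target: one row per original sample
    members = y_pred.view(b, e, c)     # (batch, ensemble, channel)

    # Kernel CRPS estimator: CRPS = E|X - y| - 0.5 * E|X - X'|
    # (a "fair" variant rescales the spread term by e / (e - 1))
    skill = (members - y_true.unsqueeze(1)).abs().mean(dim=1)                      # E|X - y|
    spread = (members.unsqueeze(2) - members.unsqueeze(1)).abs().mean(dim=(1, 2))  # E|X - X'|
    loss = (skill - 0.5 * spread).mean()

Had the duplication used tensor.repeat instead of repeat_interleave, members of one sample would be strided across the batch dimension and the simple view above would silently mix samples; the interleaved ordering is what makes the reshape valid.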