diff --git a/ml4h/models/train.py b/ml4h/models/train.py index eb935f9aa..c8a0affd6 100755 --- a/ml4h/models/train.py +++ b/ml4h/models/train.py @@ -302,7 +302,7 @@ def train_diffusion_control_model(args, supervised=False): model = DiffusionController( args.tensor_maps_in[0], args.tensor_maps_out, args.batch_size, args.dense_blocks, args.block_size, args.conv_x, args.dense_layers[0], args.attention_window, args.attention_heads, args.attention_modulo, args.diffusion_loss, - args.inspect_model, args.sigmoid_beta, args.diffusion_condition_strategy, + args.sigmoid_beta, args.diffusion_condition_strategy, args.inspect_model, ) loss = keras.losses.mean_absolute_error if args.diffusion_loss == 'mean_absolute_error' else keras.losses.mean_squared_error