Skip to content

Commit

Permalink
condition and supervise
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidtronix committed Jan 16, 2025
1 parent b4018a4 commit 12a0d89
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 2 deletions.
3 changes: 2 additions & 1 deletion ml4h/arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,8 @@ def parse_args():
'2 means every other residual block, 3 would mean every third.',
)
parser.add_argument(
'--diffusion_condition_strategy', default='concat', choices=['cross_attention', 'concat', 'film'],
'--diffusion_condition_strategy', default='cross_attention',
choices=['cross_attention', 'concat', 'film'],
help='For diffusion models, this controls conditional embeddings are integrated into the U-NET',
)
parser.add_argument(
Expand Down
2 changes: 1 addition & 1 deletion ml4h/models/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,7 +382,7 @@ def train_diffusion_control_model(args, supervised=False):
metrics = model.evaluate(generate_test, batch_size=args.batch_size, steps=args.test_steps, return_dict=True)
logging.info(f'Test metrics: {metrics}')

data, labels, paths = big_batch_from_minibatch_generator(generate_test, 1)
data, labels, paths = big_batch_from_minibatch_generator(generate_test, args.test_steps)
preds = model.plot_reconstructions((data, labels), prefix=f'{args.output_folder}/{args.id}/reconstructions/')
image_out = {args.tensor_maps_in[0].output_name(): data[args.tensor_maps_in[0].input_name()]}
predictions_to_pngs(preds, args.tensor_maps_in, args.tensor_maps_in, data, image_out, paths, f'{args.output_folder}/{args.id}/reconstructions/')
Expand Down

0 comments on commit 12a0d89

Please sign in to comment.