
Plot training calibration, PR, and ROC curves; logging of label breakdown and number of epochs #332 #340

Closed
wants to merge 14 commits
1 change: 1 addition & 0 deletions ml4cvd/arguments.py
@@ -224,6 +224,7 @@ def parse_args():
parser.add_argument("--embed_visualization", help="Method to visualize embed layer. Options: None, tsne, or umap")
parser.add_argument("--explore_export_errors", default=False, action="store_true", help="Export error_type columns in tensors_all*.csv generated by explore.")
parser.add_argument('--plot_hist', default=True, help='Plot histograms of continuous tensors in explore mode.')
parser.add_argument('--plot_train_curves', default=False, action="store_true", help='Plot PR and ROC curves for training set.')

# Training optimization options
parser.add_argument('--num_workers', default=multiprocessing.cpu_count(), type=int, help="Number of workers to use for every tensor generator.")
6 changes: 4 additions & 2 deletions ml4cvd/models.py
@@ -1016,6 +1016,7 @@ def train_model_from_generators(
inspect_show_labels: bool,
return_history: bool = False,
plot: bool = True,
defer_worker_halt: bool = False
) -> Union[Model, Tuple[Model, History]]:
"""Train a model from tensor generators for validation and training data.

@@ -1051,8 +1052,9 @@
validation_steps=validation_steps, validation_data=generate_valid,
callbacks=_get_callbacks(patience, model_file),
)
generate_train.kill_workers()
generate_valid.kill_workers()
if not defer_worker_halt:
    generate_train.kill_workers()
    generate_valid.kill_workers()

logging.info('Model weights saved at: %s' % model_file)
if plot:
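For readers unfamiliar with the generator lifecycle: the new defer_worker_halt flag lets the caller keep the training and validation generators alive after fitting, so their workers can still serve batches for training-set evaluation; the caller then becomes responsible for shutting them down. A minimal, self-contained sketch of that contract (toy classes, not the real ml4cvd TensorGenerator or Keras training loop):

class ToyGenerator:
    """Stand-in for ml4cvd's multiprocessing tensor generator."""
    def __init__(self):
        self.alive = True

    def kill_workers(self):
        self.alive = False


def train_sketch(generate_train, generate_valid, defer_worker_halt=False):
    # ... model.fit_generator(...) would run here ...
    if not defer_worker_halt:
        # Default behaviour: training shuts the generators down itself.
        generate_train.kill_workers()
        generate_valid.kill_workers()


gen_train, gen_valid = ToyGenerator(), ToyGenerator()
train_sketch(gen_train, gen_valid, defer_worker_halt=True)
assert gen_train.alive  # still usable to draw a training big batch for PR/ROC curves
gen_train.kill_workers()
gen_valid.kill_workers()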
9 changes: 8 additions & 1 deletion ml4cvd/recipes.py
@@ -139,10 +139,17 @@ def train_multimodal_multitask(args):
model = train_model_from_generators(
model, generate_train, generate_valid, args.training_steps, args.validation_steps, args.batch_size,
args.epochs, args.patience, args.output_folder, args.id, args.inspect_model, args.inspect_show_labels,
defer_worker_halt=args.plot_train_curves
)

out_path = os.path.join(args.output_folder, args.id + '/')
test_data, test_labels, test_paths = big_batch_from_minibatch_generator(generate_test, args.test_steps)
train_data, train_labels = big_batch_from_minibatch_generator(generate_train, args.training_steps)
Collaborator: What is the size of the big_batch returned here? training_steps is usually a lot larger than test_steps.

Contributor Author: train big_batch shape = (25600, 2500, 12), as opposed to (2048, 2500, 12) for test.
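For context on the sizes discussed here: the leading dimension of a big batch is simply steps * batch_size, because the helper concatenates one minibatch per step. A self-contained toy sketch of that behaviour (small numbers; the real big_batch_from_minibatch_generator also returns labels and, optionally, paths):

import numpy as np

def toy_minibatches(batch_size, timesteps=2500, leads=12):
    """Endless toy ECG minibatches shaped (batch_size, timesteps, leads)."""
    while True:
        yield np.zeros((batch_size, timesteps, leads), dtype=np.float32)

def toy_big_batch(minibatches, steps):
    """Concatenate `steps` minibatches along axis 0, as a big-batch helper would."""
    return np.concatenate([next(minibatches) for _ in range(steps)], axis=0)

# First dimension = steps * batch_size: with batch_size=4, 6 training steps give
# (24, 2500, 12) while 2 test steps give (8, 2500, 12). The (25600, 2500, 12) vs
# (2048, 2500, 12) shapes quoted above scale the same way with training_steps,
# test_steps, and the run's batch_size.
assert toy_big_batch(toy_minibatches(4), steps=6).shape == (24, 2500, 12)
assert toy_big_batch(toy_minibatches(4), steps=2).shape == (8, 2500, 12)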

Comment: Is there a need to return train_paths?

Contributor Author: I don't think we need to return it since it isn't used in plotting?

Collaborator: It is optional; if provided, it will be used to label outliers.

if args.plot_train_curves:
    out_path_train = os.path.join(args.output_folder, args.id + '/train_pr_roc_curves/')
    _predict_and_evaluate(model, train_data, train_labels, args.tensor_maps_in, args.tensor_maps_out, args.batch_size, args.hidden_layer, out_path_train, test_paths, args.embed_visualization, args.alpha)
if args.plot_train_curves:
    generate_train.kill_workers()
    generate_valid.kill_workers()
return _predict_and_evaluate(model, test_data, test_labels, args.tensor_maps_in, args.tensor_maps_out, args.batch_size, args.hidden_layer, out_path, test_paths, args.embed_visualization, args.alpha)


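On the optional paths argument mentioned in the thread above: a generic, self-contained illustration (not the ml4cvd implementation of _predict_and_evaluate) of how sample paths can be used to label outliers on an evaluation scatter plot:

import numpy as np
import matplotlib.pyplot as plt

def scatter_with_outlier_labels(truth, pred, paths=None, z_thresh=3.0):
    """Scatter predictions against truth; if paths are given, annotate the
    points with the largest standardized residuals."""
    residual = pred - truth
    z = (residual - residual.mean()) / (residual.std() + 1e-8)
    plt.scatter(truth, pred, s=8)
    if paths is not None:
        for i in np.where(np.abs(z) > z_thresh)[0]:
            plt.annotate(paths[i], (truth[i], pred[i]), fontsize=6)
    plt.xlabel('truth')
    plt.ylabel('prediction')
    plt.savefig('scatter_with_outliers.png')

rng = np.random.default_rng(0)
truth = rng.normal(size=200)
pred = truth + rng.normal(scale=0.1, size=200)
pred[:3] += 4.0  # inject a few obvious outliers
scatter_with_outlier_labels(truth, pred, paths=[f'sample_{i}.hd5' for i in range(200)])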