diff --git a/retinanet/csv_eval.py b/retinanet/csv_eval.py index 175fc36a7..aecab6905 100644 --- a/retinanet/csv_eval.py +++ b/retinanet/csv_eval.py @@ -177,6 +177,8 @@ def evaluate( all_annotations = _get_annotations(generator) average_precisions = {} + recalls = {} + precisions = {} for label in range(generator.num_classes()): false_positives = np.zeros((0,)) @@ -231,17 +233,19 @@ def evaluate( # compute average precision average_precision = _compute_ap(recall, precision) average_precisions[label] = average_precision, num_annotations + recalls[label] = recall + precisions[label] = precision print('\nmAP:') for label in range(generator.num_classes()): label_name = generator.label_to_name(label) print('{}: {}'.format(label_name, average_precisions[label][0])) - print("Precision: ",precision[-1]) - print("Recall: ",recall[-1]) + print("Precision: ",precisions[label][-1]) + print("Recall: ",recalls[label][-1]) if save_path!=None: - plt.plot(recall,precision) + plt.plot(recalls[label],precisions[label]) # naming the x axis plt.xlabel('Recall') # naming the y axis