Commit

wip
czaloom committed Nov 13, 2024
1 parent 8e5ab91 commit b42ae90
Showing 1 changed file with 77 additions and 16 deletions.
93 changes: 77 additions & 16 deletions lite/valor_lite/object_detection/benchmark.py
@@ -10,7 +10,8 @@
 
 def benchmark_add_bounding_boxes(
     n_labels: int,
-    n_annotations_per_datum: tuple[int, int],
+    n_annotation_pairs: int,
+    n_annotation_unmatched: int,
     time_limit: float | None,
     repeat: int = 1,
 ):
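The updated signature replaces the single (matched, unmatched) tuple with two explicit counts. A hypothetical call under the new signature, shown only as a sketch (the argument values are invented and the rest of the function body is outside this hunk):

benchmark_add_bounding_boxes(
    n_labels=10,
    n_annotation_pairs=100,
    n_annotation_unmatched=10,
    time_limit=10.0,
    repeat=3,
)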
@@ -60,11 +61,31 @@ def benchmark_finalize(
 
     elapsed = 0
     for _ in range(repeat):
-        loader = generate_cache(
-            n_datums=n_datums,
-            n_labels=n_labels,
-            n_annotations_per_datum=n_annotations_per_datum,
-        )
+
+        n_matched, n_unmatched = n_annotations_per_datum
+        pairs = [
+            generate_bounding_box_pair(n_labels=n_labels)
+            for _ in range(n_matched)
+        ]
+        unmatched_gts = [
+            generate_bounding_box(n_labels=n_labels, is_prediction=False)
+            for _ in range(n_unmatched)
+        ]
+        unmatched_pds = [
+            generate_bounding_box(n_labels=n_labels, is_prediction=True)
+            for _ in range(n_unmatched)
+        ]
+        gts = [gt for gt, _ in pairs] + unmatched_gts
+        pds = [pd for _, pd in pairs] + unmatched_pds
+
+        loader = DataLoader()
+        for i in range(n_datums):
+            detection = Detection(
+                uid=str(i),
+                groundtruths=gts,
+                predictions=pds,
+            )
+            loader.add_bounding_boxes([detection])
         elapsed += profile(loader.finalize)()
     return elapsed / repeat

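The same generate-then-load block now appears verbatim in each of the three benchmarks touched by this commit. As a reading aid, here is the shared pattern pulled out into a standalone helper; this is a sketch only (the helper name is invented and is not part of this commit), using just the symbols that appear in this diff:

def build_loader(n_datums, n_labels, n_matched, n_unmatched):
    # Hypothetical helper mirroring the repeated setup above; not in the commit.
    pairs = [generate_bounding_box_pair(n_labels=n_labels) for _ in range(n_matched)]
    gts = [gt for gt, _ in pairs] + [
        generate_bounding_box(n_labels=n_labels, is_prediction=False)
        for _ in range(n_unmatched)
    ]
    pds = [pd for _, pd in pairs] + [
        generate_bounding_box(n_labels=n_labels, is_prediction=True)
        for _ in range(n_unmatched)
    ]
    loader = DataLoader()
    for i in range(n_datums):
        # Every datum reuses the same annotations, exactly as in the benchmarks.
        loader.add_bounding_boxes(
            [Detection(uid=str(i), groundtruths=gts, predictions=pds)]
        )
    return loader

With a helper like this, each benchmark body would reduce to building the loader and then profiling the call under test.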
@@ -84,11 +105,31 @@ def benchmark_compute_precision_recall(
 
     elapsed = 0
     for _ in range(repeat):
-        loader = generate_cache(
-            n_datums=n_datums,
-            n_labels=n_labels,
-            n_annotations_per_datum=n_annotations_per_datum,
-        )
+
+        n_matched, n_unmatched = n_annotations_per_datum
+        pairs = [
+            generate_bounding_box_pair(n_labels=n_labels)
+            for _ in range(n_matched)
+        ]
+        unmatched_gts = [
+            generate_bounding_box(n_labels=n_labels, is_prediction=False)
+            for _ in range(n_unmatched)
+        ]
+        unmatched_pds = [
+            generate_bounding_box(n_labels=n_labels, is_prediction=True)
+            for _ in range(n_unmatched)
+        ]
+        gts = [gt for gt, _ in pairs] + unmatched_gts
+        pds = [pd for _, pd in pairs] + unmatched_pds
+
+        loader = DataLoader()
+        for i in range(n_datums):
+            detection = Detection(
+                uid=str(i),
+                groundtruths=gts,
+                predictions=pds,
+            )
+            loader.add_bounding_boxes([detection])
         evaluator = loader.finalize()
         elapsed += profile(evaluator.compute_precision_recall)()
     return elapsed / repeat
@@ -110,11 +151,31 @@ def benchmark_compute_confusion_matrix(
 
     elapsed = 0
     for _ in range(repeat):
-        loader = generate_cache(
-            n_datums=n_datums,
-            n_labels=n_labels,
-            n_annotations_per_datum=n_annotations_per_datum,
-        )
+
+        n_matched, n_unmatched = n_annotations_per_datum
+        pairs = [
+            generate_bounding_box_pair(n_labels=n_labels)
+            for _ in range(n_matched)
+        ]
+        unmatched_gts = [
+            generate_bounding_box(n_labels=n_labels, is_prediction=False)
+            for _ in range(n_unmatched)
+        ]
+        unmatched_pds = [
+            generate_bounding_box(n_labels=n_labels, is_prediction=True)
+            for _ in range(n_unmatched)
+        ]
+        gts = [gt for gt, _ in pairs] + unmatched_gts
+        pds = [pd for _, pd in pairs] + unmatched_pds
+
+        loader = DataLoader()
+        for i in range(n_datums):
+            detection = Detection(
+                uid=str(i),
+                groundtruths=gts,
+                predictions=pds,
+            )
+            loader.add_bounding_boxes([detection])
         evaluator = loader.finalize()
         elapsed += profile(evaluator.compute_confusion_matrix)(
             number_of_examples=n_examples
