-
Notifications
You must be signed in to change notification settings - Fork 77
/
Copy pathbenchmark.py
70 lines (61 loc) · 2.63 KB
/
benchmark.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
from dataclasses import dataclass
from typing import List
import yaml
@dataclass
class Scenario:
    """One benchmark scenario record.

    Instances are built by unpacking a mapping into the constructor, so the
    field names below must match the keys of each record being loaded.
    """

    room: str
    receptacles: List[str]

    # Objects and placements are split into a "seen" and an "unseen" group.
    # NOTE(review): each placement is presumably an [object, receptacle]
    # pair, matching what check_placements consumes -- confirm against data.
    seen_objects: List[str]
    seen_placements: List[List[str]]
    unseen_objects: List[str]
    unseen_placements: List[List[str]]

    # Free-form metadata attached to the scenario.
    annotator_notes: str
    tags: List[str]
def load_scenarios(path='scenarios.yml'):
    """Load scenarios from a YAML file.

    Args:
        path: Location of the YAML file to read (UTF-8).

    Returns:
        A list with one ``Scenario`` per top-level entry in the file; each
        entry is unpacked as keyword arguments to ``Scenario``.
    """
    with open(path, 'r', encoding='utf8') as f:
        records = yaml.safe_load(f)
    return [Scenario(**record) for record in records]
def parse_summary(summarization_completion):
    """Extract a single-line summary from a completion string.

    The completion is split on newlines, each line is stripped, and blank
    lines are discarded. The first remaining line is the summary.

    Args:
        summarization_completion: Raw completion text.

    Returns:
        The first non-blank line, stripped of surrounding whitespace, or an
        empty string when the completion contains no non-blank line.
    """
    lines = [line for line in map(str.strip, summarization_completion.split('\n')) if line]
    if not lines:
        # Previously this case raised IndexError on lines[0]; an empty
        # completion is reported and yields an empty summary instead.
        print('Warning: Summary is empty')
        return ''
    if len(lines) > 1:
        print('Warning: Using first line of multi-line summary')
    return lines[0]
def parse_placements(placement_completion, objects):
    """Parse a placement completion into ``[object, receptacle]`` pairs.

    The stripped completion is processed line by line:

    * The first line is taken as the receptacle for ``objects[0]``.
      NOTE(review): presumably the completion continues a prompt that ended
      in the middle of the first placement call -- confirm against caller.
    * Every later line is split on ``','`` into an object part and a
      receptacle part. Lines that do not split into exactly two parts are
      skipped; an empty line stops parsing altogether.

    Double quotes are removed from both parts, and closing parentheses are
    removed from the receptacle. For the object, only the text after the
    first ``'('`` is kept when one is present.

    Args:
        placement_completion: Raw completion text.
        objects: Sequence whose first element names the object paired with
            the first line of the completion.

    Returns:
        A list of ``[object, receptacle]`` lists in the order encountered.
    """

    def _clean_receptacle(text):
        # Drop surrounding whitespace, closing parens and quotes.
        return text.strip().replace(')', '').replace('"', '')

    head, *rest = placement_completion.strip().split('\n')

    # The first line carries only a receptacle; its object comes from `objects`.
    parsed = [[objects[0], _clean_receptacle(head)]]

    for raw in rest:
        if not raw:
            print('Warning: Stopping since newline was encountered')
            break

        parts = raw.split(',')
        if len(parts) != 2:
            print('Warning: Skipping invalid placement')
            continue

        item, target = parts
        if '(' not in item:
            print('Warning: Found possibly invalid placement')
            item = item.strip().replace('"', '')
        else:
            # Keep what follows the opening paren of the call.
            item = item.split('(')[1].strip().replace('"', '')

        parsed.append([item, _clean_receptacle(target)])

    return parsed
def check_placements(predicted_placements, correct_placements):
    """Score predicted placements against the reference placements.

    Args:
        predicted_placements: Iterable of ``[object, receptacle]`` pairs.
        correct_placements: Sized iterable of reference
            ``[object, receptacle]`` pairs.

    Returns:
        A tuple ``(corrects, accuracy)`` where ``corrects`` holds one boolean
        per predicted placement and ``accuracy`` is the number of correct
        predictions divided by ``len(correct_placements)``.

    Raises:
        ZeroDivisionError: If ``correct_placements`` is empty.
    """
    remaining = {item: target for item, target in correct_placements}

    outcomes = []
    for item, target in predicted_placements:
        if item in remaining:
            # Each reference entry is consumed on first use, whether or not
            # the prediction matched, so a repeated object scores only once.
            outcomes.append(target == remaining.pop(item))
        else:
            outcomes.append(False)

    return outcomes, sum(outcomes) / len(correct_placements)
if __name__ == '__main__':
    # Self-checks for check_placements, run when the module is executed
    # directly. Every case is scored against the same reference placements.
    reference = [['o1', 'r1'], ['o2', 'r2']]
    cases = [
        # (predicted placements, expected (corrects, accuracy))
        ([['o1', 'r1'], ['o2', 'r2']], ([True, True], 1.0)),    # all correct
        ([['o1', 'r2'], ['o2', 'r2']], ([False, True], 0.5)),   # wrong receptacle
        ([['o3', 'r1'], ['o2', 'r2']], ([False, True], 0.5)),   # unknown object
        ([['o1', 'r1'], ['o1', 'r1']], ([True, False], 0.5)),   # repeated object
    ]
    for predicted, expected in cases:
        assert check_placements(predicted, reference) == expected