Skip to content

Commit de91b0c

Browse files
Result analysis for large files
1 parent f19fc80 commit de91b0c

File tree

1 file changed

+213
-0
lines changed

1 file changed

+213
-0
lines changed
+213
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,213 @@
1+
import os
2+
import json
3+
import logging
4+
from concurrent.futures import ProcessPoolExecutor
5+
from analysis_utils import format_type
6+
from tqdm import tqdm
7+
from multiprocessing import cpu_count
8+
from threading import Lock
9+
10+
SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__))
11+
TEST_DIR = os.path.join(
12+
SCRIPT_DIR, "results_analysis_tests/test/micro-benchmark/python_features"
13+
)
14+
15+
16+
def check_match(
17+
expected,
18+
out,
19+
top_n=1,
20+
is_ml=False,
21+
print_mismatch=False,
22+
metadata=None,
23+
):
24+
"""
25+
Check for both exact and partial matches between expected and out entries.
26+
Returns a tuple: (is_exact_match, is_partial_match).
27+
"""
28+
metadata = metadata or {}
29+
30+
# Check keys in `out` are present in `expected`
31+
if not all(
32+
x in expected
33+
for x in out.keys()
34+
if x not in {"type", "all_type_preds", "col_offset"}
35+
):
36+
return False, False
37+
38+
# Early exits for file and line number mismatches
39+
if expected.get("file") != out.get("file"):
40+
return False, False
41+
42+
if expected.get("line_number") != out.get("line_number"):
43+
return False, False
44+
45+
# Optional column offset check
46+
if "col_offset" in expected and expected.get("col_offset") != out.get("col_offset"):
47+
return False, False
48+
49+
# Match specific fields if present
50+
for key in ["function", "parameter", "variable"]:
51+
if key in expected and expected.get(key) != out.get(key):
52+
return False, False
53+
54+
# Type matching logic
55+
if is_ml:
56+
_types = [x[0] for x in out.get("all_type_preds", [])]
57+
else:
58+
_types = out.get("type", [])
59+
60+
type_formatted = format_type([_types])
61+
expected_type_formatted = format_type([expected.get("type", [])])
62+
63+
# Exact match check
64+
is_exact_match = any(
65+
sorted(expected_type_formatted) == [t_list] for t_list in type_formatted[:top_n]
66+
)
67+
68+
# Partial match check
69+
expected_set = {t for sublist in expected_type_formatted for t in sublist}
70+
is_partial_match = any(
71+
expected_set.intersection(t_list) for t_list in type_formatted[:top_n]
72+
)
73+
74+
if not (is_exact_match or is_partial_match) and print_mismatch:
75+
log_mismatch(metadata, expected, out, partial_match=True)
76+
77+
return is_exact_match, is_partial_match
78+
79+
80+
def log_mismatch(metadata, expected, out, partial_match):
81+
"""
82+
Log mismatched cases for debugging or analysis.
83+
"""
84+
print(f"\n\n##### Type mismatch! #####\nPartial matching: {partial_match}")
85+
tool_name = metadata.get("tool_name", "unknown_tool")
86+
mismatch_file = f"{tool_name}_mismatches_reasons.csv"
87+
with open(mismatch_file, "a") as f:
88+
f.write(
89+
";".join(
90+
[
91+
metadata.get("cat", "unknown_cat"),
92+
metadata.get("type_category", "unknown_category"),
93+
json.dumps(expected),
94+
json.dumps(out),
95+
]
96+
)
97+
)
98+
f.write("\n")
99+
100+
print("Ground Truth:")
101+
print(json.dumps(expected, indent=4))
102+
print("Output:")
103+
print(json.dumps(out, indent=4))
104+
print("####################\n\n")
105+
106+
107+
def sort_facts(data):
108+
"""
109+
Sort facts based on line_number and ensure 'type' fields (if list) are sorted.
110+
"""
111+
return sorted(data, key=lambda x: int(x.get("line_number", 0)))
112+
113+
114+
def load_and_sort_json(file_path):
115+
"""
116+
Load JSON from a file and sort the facts for consistent processing.
117+
"""
118+
with open(file_path) as f:
119+
data = json.load(f)
120+
return sort_facts(data)
121+
122+
123+
def measure_exact_matches(out, expected, tool_name=None, print_missed=False):
124+
"""
125+
Measure exact and partial matches between two JSON files.
126+
"""
127+
data_out = load_and_sort_json(out)
128+
data_expected = load_and_sort_json(expected)
129+
130+
results = {
131+
"num_all": len(data_expected),
132+
"num_caught_exact": 0,
133+
"num_caught_partial": 0,
134+
}
135+
136+
lock = Lock()
137+
progress_bar = tqdm(total=len(data_expected), desc="Processing facts", position=0)
138+
139+
# Process comparisons in parallel
140+
with ProcessPoolExecutor(max_workers=max(cpu_count() - 1, 1)) as executor:
141+
futures = []
142+
for fact_expected in data_expected:
143+
futures.append(
144+
executor.submit(process_fact_comparison, fact_expected, data_out)
145+
)
146+
147+
for future in futures:
148+
fact_expected = data_expected[futures.index(future)]
149+
try:
150+
is_exact_match, is_partial_match = future.result()
151+
with lock:
152+
if is_exact_match:
153+
results["num_caught_exact"] += 1
154+
elif is_partial_match:
155+
results["num_caught_partial"] += 1
156+
elif print_missed:
157+
log_missed_fact(tool_name, fact_expected)
158+
progress_bar.update(1)
159+
except Exception as e:
160+
logging.error(f"Error processing fact: {fact_expected} - {e}")
161+
162+
progress_bar.close()
163+
return results
164+
165+
166+
def process_fact_comparison(fact_expected, data_out):
167+
"""
168+
Compare a single fact against all output facts to determine exact and partial matches.
169+
Returns the combined match results.
170+
"""
171+
is_exact_match = False
172+
is_partial_match = False
173+
174+
for fact_out in data_out:
175+
exact_match, partial_match = check_match(fact_expected, fact_out)
176+
is_exact_match = is_exact_match or exact_match
177+
is_partial_match = is_partial_match or partial_match
178+
179+
# Break early if both matches are found
180+
if is_exact_match and is_partial_match:
181+
break
182+
183+
return is_exact_match, is_partial_match
184+
185+
186+
def log_missed_fact(tool_name, fact_expected):
187+
"""
188+
Log missed facts to a CSV file for further analysis.
189+
"""
190+
if not tool_name:
191+
return
192+
193+
missed_log_path = f"{tool_name}_not_found_reasons.csv"
194+
with open(missed_log_path, "a") as f:
195+
f.write(f";Missing Fact;{json.dumps(fact_expected)}\n")
196+
197+
198+
# Output the result
199+
if __name__ == "__main__":
200+
# Test the function
201+
for folder in os.listdir(TEST_DIR):
202+
print(folder)
203+
out = f"{TEST_DIR}/{folder}/test1/main_result.json"
204+
expected = f"{TEST_DIR}/{folder}/test1/main_gt.json"
205+
results = measure_exact_matches(out, expected)
206+
print(results)
207+
208+
out = "/home/ashwin/Downloads/rw-benchmark/rw-benchmark/test/test_result.json"
209+
expected = "/home/ashwin/Downloads/rw-benchmark/rw-benchmark/test/test_gt.json"
210+
tool_name = "my_tool"
211+
212+
results = measure_exact_matches(out, expected, tool_name=tool_name)
213+
print(results)

0 commit comments

Comments
 (0)