import argparse
from collections import defaultdict, Counter
import pathlib
import pandas as pd
import json
from clams_utils.aapb import goldretriever

# constant:
GOLD_URL = "https://github.com/clamsproject/aapb-annotations/tree/bebd93af0882b8cf942ba827917938b49570d6d9/scene-recognition/golds"
# note that you must first have output mmif files to compare against
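# if --gold_dir is not given at runtime, the gold csvs are downloaded from GOLD_URL
# via goldretriever.download_golds (see the __main__ block below)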

# parse SWT output into dictionary to extract label-timepoint pairs

# convert ISO timestamp strings (hours:minutes:seconds.ms) back to milliseconds

def convert_iso_milliseconds(timestamp):
    ms = 0
    # add hours
    ms += int(timestamp.split(":")[0]) * 3600000
    # add minutes
    ms += int(timestamp.split(":")[1]) * 60000
    # add seconds and milliseconds
    ms += float(timestamp.split(":")[2]) * 1000
    ms = int(ms)
    return ms

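# quick sanity check for convert_iso_milliseconds (hypothetical timestamp, kept as a
# comment so importing the module has no side effects):
#   convert_iso_milliseconds("0:01:23.5") == 83500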
# extract gold pairs from each csv. note goldpath is fed in as a path object
def extract_gold_labels(goldpath, count_subtypes=False):
    df = pd.read_csv(goldpath)
    # convert timestamps (iso) back to ms
    df['timestamp'] = df['timestamp'].apply(convert_iso_milliseconds)
    if count_subtypes:
        # fill empty subtype rows with '' then concatenate with type label
        df['subtype label'] = df['subtype label'].fillna("")
        df['combined'] = df['type label'] + ":" + df['subtype label']
        # trim extra ":"
        df['combined'] = df['combined'].apply(lambda row: row[:-1] if row[-1] == ':' else row)
        # create dictionary of 'timestamp':'combined' from dataframe
        gold_dict = df.set_index('timestamp')['combined'].to_dict()
    else:
        # ignore subtype label column
        gold_dict = df.set_index('timestamp')['type label'].to_dict()
    # return dictionary that maps timestamps to label
    return gold_dict

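# for a gold csv row with (hypothetical values) timestamp "0:00:02.0", type label "B",
# and an empty subtype label, extract_gold_labels returns {2000: "B"}; with
# count_subtypes=True the trailing ":" left by the empty subtype is trimmed,
# so the result is still {2000: "B"}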
# method to match a given predicted timestamp (key) with the closest gold timestamp:
# acceptable range is default +/- 5 ms. if nothing matches, return None

def closest_gold_timestamp(pred_stamp, gold_dict, good_range=5):
    # first check if pred in gold_dict. if yes, return pred
    if pred_stamp in gold_dict:
        return pred_stamp
    # for i = 5 to 1 check if pred - i in gold_dict, if yes return pred - i
    for i in range(good_range, 0, -1):
        if pred_stamp - i in gold_dict:
            return pred_stamp - i
    # for i = 1 to i = 5 check if pred + i in gold_dict, if yes return pred + i
    for i in range(1, good_range + 1):
        if pred_stamp + i in gold_dict:
            return pred_stamp + i
    return None

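# e.g. with gold timestamps {1000, 2000} (hypothetical): closest_gold_timestamp(1003, ...)
# matches 1000, while 1010 falls outside the +/- 5 ms window and returns None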
# extract predicted label pairs from output mmif and match with gold pairs
# note that pred_path is already a filepath, not a string
# returns a dictionary with timestamps as keys and tuples of labels as values.


def extract_predicted_consolidate(pred_path, gold_dict, count_subtypes=False):
    # create a dictionary to fill in with timestamps -> label tuples (predicted, gold)
    combined_dict = {}
    with open(pred_path, "r") as file:
        pred_json = json.load(file)
    for view in pred_json["views"]:
        if "annotations" in view:
            for annotation in view["annotations"]:
                if "timePoint" in annotation['properties']:
                    # match pred timestamp to closest gold timestamp
                    # using default range (+/- 5ms)
                    curr_timestamp = closest_gold_timestamp(annotation['properties']['timePoint'], gold_dict)
                    # check if closest_gold_timestamp returned None (not within acceptable range);
                    # use an explicit None check so a valid timestamp of 0 ms is not skipped
                    if curr_timestamp is None:
                        continue
                    # truncate label if count_subtypes is false
                    pred_label = annotation['properties']['label'] if count_subtypes else annotation['properties']['label'][0]
                    # if NEG set to '-'
                    if annotation['properties']['label'] == 'NEG':
                        pred_label = '-'
                    # put gold and pred labels into combined dictionary
                    combined_dict[curr_timestamp] = (pred_label, gold_dict[curr_timestamp])
    return combined_dict

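# extract_predicted_consolidate output shape (hypothetical labels/timestamps):
#   {2000: ("B", "B"), 4000: ("-", "S"), ...}
# where each value is (predicted label, gold label)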
# calculate document-level p, r, f1 for each label and macro avg. also returns total counts
# of tp, fp, fn for each label to calculate micro avg later.
def document_evaluation(combined_dict):
    # count up tp, fp, fn for each label
    total_counts = defaultdict(Counter)
    for timestamp in combined_dict:
        pred, gold = combined_dict[timestamp][0], combined_dict[timestamp][1]
        if pred == gold:
            total_counts[pred]["tp"] += 1
        else:
            total_counts[pred]["fp"] += 1
            total_counts[gold]["fn"] += 1
    # calculate P, R, F1 for each label, store in nested dictionary
    scores_by_label = defaultdict(lambda: defaultdict(float))
    # running total for (macro) averaged scores per document
    average_p = 0
    average_r = 0
    average_f1 = 0
    # counter to account for unseen labels
    unseen = 0
    for label in total_counts:
        tp, fp, fn = total_counts[label]["tp"], total_counts[label]["fp"], total_counts[label]["fn"]
        # if no instances are present/predicted, account for this when taking average of scores
        if tp + fp + fn == 0:
            unseen += 1
        precision = float(tp / (tp + fp)) if (tp + fp) > 0 else 0
        recall = float(tp / (tp + fn)) if (tp + fn) > 0 else 0
        f1 = float(2 * (precision * recall) / (precision + recall)) if (precision + recall) > 0 else 0
        # add individual scores to dict and then add to running sum
        scores_by_label[label]["precision"] = precision
        scores_by_label[label]["recall"] = recall
        scores_by_label[label]["f1"] = f1
        average_p += precision
        average_r += recall
        average_f1 += f1
    # calculate macro averages for document and add to scores_by_label
    # make sure to account for unseen unpredicted labels
    # (guard against division by zero in the degenerate case where no labels were matched)
    denominator = max(len(scores_by_label) - unseen, 1)
    scores_by_label["average"]["precision"] = float(average_p / denominator)
    scores_by_label["average"]["recall"] = float(average_r / denominator)
    scores_by_label["average"]["f1"] = float(average_f1 / denominator)
    # return both scores_by_label and total_counts (to calculate micro avg later)
    return scores_by_label, total_counts

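# worked example for document_evaluation (hypothetical counts): a label with tp=8, fp=2, fn=4
# gets precision = 8/10 = 0.8, recall = 8/12 ~ 0.667, f1 = 2*0.8*0.667/(0.8+0.667) ~ 0.727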
# once you have processed every document, this method runs to calculate the micro-averaged
# scores. the input is a list of total_counts dictionaries, each obtained from running
# document_evaluation.
def total_evaluation(total_counts_list):
    # create dict to hold total tp, fp, fn for all labels
    total_instances_by_label = defaultdict(Counter)
    # iterate through total_counts_list to get complete count of tp, fp, fn by label
    for doc_dict in total_counts_list:
        for label in doc_dict:
            total_instances_by_label[label]["tp"] += doc_dict[label]["tp"]
            total_instances_by_label[label]["fp"] += doc_dict[label]["fp"]
            total_instances_by_label[label]["fn"] += doc_dict[label]["fn"]
            # include a section for total tp/fp/fn for all labels
            total_instances_by_label["all"]["tp"] += doc_dict[label]["tp"]
            total_instances_by_label["all"]["fp"] += doc_dict[label]["fp"]
            total_instances_by_label["all"]["fn"] += doc_dict[label]["fn"]
    # create complete_micro_scores to store micro avg scores for entire dataset
    complete_micro_scores = defaultdict(lambda: defaultdict(float))
    # fill in micro scores
    for label in total_instances_by_label:
        tp, fp, fn = (total_instances_by_label[label]["tp"], total_instances_by_label[label]["fp"],
                      total_instances_by_label[label]["fn"])
        precision = float(tp / (tp + fp)) if (tp + fp) > 0 else 0
        recall = float(tp / (tp + fn)) if (tp + fn) > 0 else 0
        f1 = float(2 * precision * recall / (precision + recall)) if (precision + recall) > 0 else 0
        complete_micro_scores[label]["precision"] = precision
        complete_micro_scores[label]["recall"] = recall
        complete_micro_scores[label]["f1"] = f1
    return complete_micro_scores

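# note: in total_evaluation, per-label entries pool tp/fp/fn across documents (micro average
# per label), and the "all" entry additionally pools across labels, giving the overall
# micro-averaged dataset scores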
# run the evaluation on each predicted-gold pair of files, and then the entire dataset for
# micro average
def run_dataset_eval(mmif_dir, gold_dir, count_subtypes):
    # create dict of guid -> scores to store each dict of document-level scores
    doc_scores = {}
    # create list to store each dict of document-level counts
    document_counts = []
    mmif_files = pathlib.Path(mmif_dir).glob("*.mmif")
    # get each mmif file
    for mmif_file in mmif_files:
        with open(mmif_file, "r") as f:
            curr_mmif = json.load(f)
            # get guid from the source document's location
            location = curr_mmif["documents"][0]["properties"]["location"]
            guid = location.split("/")[-1].split(".")[0]
        # match guid with gold file
        gold_file = next(pathlib.Path(gold_dir).glob(f"*{guid}*"))
        # process gold
        gold_dict = extract_gold_labels(gold_file, count_subtypes)
        # process predicted and consolidate
        combined_dict = extract_predicted_consolidate(mmif_file, gold_dict, count_subtypes)
        # evaluate on document level, storing scores in doc_scores and counts in document_counts
        eval_result = document_evaluation(combined_dict)
        doc_scores[guid] = eval_result[0]
        document_counts.append(eval_result[1])
    # now after processing each document and storing the relevant scores, we can evaluate the
    # dataset performance as a whole
    data_scores = total_evaluation(document_counts)
    return doc_scores, data_scores

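# example call to run_dataset_eval (hypothetical directory names):
#   doc_scores, data_scores = run_dataset_eval("preds@app-swt@batch1", "golds", False)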
def separate_score_outputs(doc_scores, dataset_scores, mmif_dir):
    # get name for new directory
    # with our standard, this results in "scores@" prefixed to the batch name
    batch_score_name = "scores@" + mmif_dir.split('@')[-1].strip('/')
    # create new dir for scores based on batch name
    new_dir = pathlib.Path.cwd() / batch_score_name
    new_dir.mkdir(parents=True, exist_ok=True)
    # iterate through nested dict, output separate scores for each guid
    for guid in doc_scores:
        doc_df = pd.DataFrame(doc_scores[guid]).transpose()
        out_path = new_dir / f"{guid}.csv"
        doc_df.to_csv(out_path)
    # output total dataset scores
    dataset_df = pd.DataFrame(dataset_scores).transpose()
    dataset_df.to_csv(new_dir / "dataset_scores.csv")

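# with the batch naming convention above, an mmif_dir of "preds@app-swt@batch1/" (hypothetical)
# produces ./scores@batch1/ containing one <guid>.csv per document plus dataset_scores.csv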

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('-m', '--mmif_dir', type=str, required=True,
                        help='directory containing machine-annotated files in MMIF format')
    parser.add_argument('-g', '--gold_dir', type=str, default=None,
                        help='directory containing gold labels in csv format')
    # use store_true rather than type=bool: argparse's type=bool treats any
    # non-empty string (including "False") as True
    parser.add_argument('-s', '--count_subtypes', action='store_true',
                        help='flag whether to consider subtypes for evaluation')
    args = parser.parse_args()
    mmif_dir = args.mmif_dir
    gold_dir = goldretriever.download_golds(GOLD_URL) if args.gold_dir is None else args.gold_dir
    count_subtypes = args.count_subtypes
    document_scores, dataset_scores = run_dataset_eval(mmif_dir, gold_dir, count_subtypes)
    # document scores are for each doc, dataset scores are for overall (micro avg)
    # call method to output scores for each doc and then for total scores
    separate_score_outputs(document_scores, dataset_scores, mmif_dir)

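# example invocation (hypothetical script/directory names):
#   python evaluate.py -m preds@app-swt@batch1 -g golds/ -s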