Commit dd9be96

Submission checker v5.0 (mlcommons#1955)
* Submission checker v5.0
* [Automated Commit] Format Codebase

Co-authored-by: pgmpablo157321 <[email protected]>
Co-authored-by: Arjun Suresh <[email protected]>
1 parent 21d845e commit dd9be96

2 files changed: +147, −43 lines

graph/R-GAT/tools/accuracy_igbh.py

Lines changed: 21 additions & 12 deletions
@@ -31,13 +31,14 @@ def get_args():
     parser.add_argument(
         "--dtype",
         default="uint8",
-        choices=["uint8","float32", "int32", "int64"],
+        choices=["uint8", "float32", "int32", "int64"],
         help="data type of the label",
     )
     args = parser.parse_args()
     return args
 
-def load_labels(base_path, dataset_size, use_label_2K = True):
+
+def load_labels(base_path, dataset_size, use_label_2K=True):
     # load labels
     paper_nodes_num = {
         "tiny": 100000,
@@ -49,7 +50,12 @@ def load_labels(base_path, dataset_size, use_label_2K = True):
     label_file = (
         "node_label_19.npy" if not use_label_2K else "node_label_2K.npy"
     )
-    paper_lbl_path = os.path.join(base_path, dataset_size, "processed", "paper", label_file)
+    paper_lbl_path = os.path.join(
+        base_path,
+        dataset_size,
+        "processed",
+        "paper",
+        label_file)
 
     if dataset_size in ["large", "full"]:
         paper_node_labels = torch.from_numpy(
@@ -63,11 +69,11 @@ def load_labels(base_path, dataset_size, use_label_2K = True):
             torch.long)
         labels = paper_node_labels
     val_idx = torch.load(
-        os.path.join(
-            base_path,
-            dataset_size,
-            "processed",
-            "val_idx.pt"))
+        os.path.join(
+            base_path,
+            dataset_size,
+            "processed",
+            "val_idx.pt"))
     return labels, val_idx
 
 
@@ -77,7 +83,11 @@ def get_labels(labels, val_idx, id_list):
 
 if __name__ == "__main__":
     args = get_args()
-    dtype_map = {"uint8": np.uint8,"float32": np.float32, "int32": np.int32, "int64": np.int64}
+    dtype_map = {
+        "uint8": np.uint8,
+        "float32": np.float32,
+        "int32": np.int32,
+        "int64": np.int64}
 
     with open(args.mlperf_accuracy_file, "r") as f:
         mlperf_results = json.load(f)
@@ -97,7 +107,8 @@ def get_labels(labels, val_idx, id_list):
         # get ground truth
         label = get_labels(labels, val_idx, idx)
         # get prediction
-        data = int(np.frombuffer(bytes.fromhex(result["data"]), dtype_map[args.dtype])[0])
+        data = int(np.frombuffer(bytes.fromhex(
+            result["data"]), dtype_map[args.dtype])[0])
         if label == data:
             good += 1
         total += 1
@@ -106,7 +117,5 @@ def get_labels(labels, val_idx, id_list):
     results["number_correct_samples"] = good
     results["performance_sample_count"] = total
 
-
     with open(args.output_file, "w") as fp:
         json.dump(results, fp)
-
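For reference, the hunks above are formatting-only changes to the R-GAT accuracy checker: the logic they reflow reads each entry of the MLPerf accuracy log, decodes the hex-encoded prediction with bytes.fromhex and np.frombuffer, and compares it against the ground-truth label. A minimal sketch of that flow, assuming the log entries carry the usual qsl_idx/data keys and that the script's get_labels helper reduces to an index into labels[val_idx] (both assumptions, not shown in this diff):

import json
import numpy as np

# Mirrors dtype_map from the diff above.
DTYPE_MAP = {"uint8": np.uint8, "float32": np.float32,
             "int32": np.int32, "int64": np.int64}


def score_accuracy_log(log_path, labels, val_idx, dtype="uint8"):
    # labels/val_idx are assumed to come from the script's load_labels();
    # get_labels(labels, val_idx, idx) is approximated as labels[val_idx[idx]].
    with open(log_path, "r") as f:
        mlperf_results = json.load(f)

    good, total = 0, 0
    for result in mlperf_results:
        idx = result["qsl_idx"]            # assumed log key for the sample index
        label = int(labels[val_idx[idx]])  # ground-truth class
        data = int(np.frombuffer(bytes.fromhex(
            result["data"]), DTYPE_MAP[dtype])[0])  # decoded prediction
        if label == data:
            good += 1
        total += 1

    return {"number_correct_samples": good,
            "performance_sample_count": total}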
