@@ -31,13 +31,14 @@ def get_args():
     parser.add_argument(
         "--dtype",
         default="uint8",
-        choices=["uint8","float32", "int32", "int64"],
+        choices=["uint8", "float32", "int32", "int64"],
         help="data type of the label",
     )
     args = parser.parse_args()
     return args

-def load_labels(base_path, dataset_size, use_label_2K = True):
+
+def load_labels(base_path, dataset_size, use_label_2K=True):
     # load labels
     paper_nodes_num = {
         "tiny": 100000,
@@ -49,7 +50,12 @@ def load_labels(base_path, dataset_size, use_label_2K = True):
     label_file = (
         "node_label_19.npy" if not use_label_2K else "node_label_2K.npy"
     )
-    paper_lbl_path = os.path.join(base_path, dataset_size, "processed", "paper", label_file)
+    paper_lbl_path = os.path.join(
+        base_path,
+        dataset_size,
+        "processed",
+        "paper",
+        label_file)

     if dataset_size in ["large", "full"]:
         paper_node_labels = torch.from_numpy(
@@ -63,11 +69,11 @@ def load_labels(base_path, dataset_size, use_label_2K = True):
             torch.long)
     labels = paper_node_labels
     val_idx = torch.load(
-        os.path.join(
-        base_path,
-        dataset_size,
-        "processed",
-        "val_idx.pt"))
+        os.path.join(
+            base_path,
+            dataset_size,
+            "processed",
+            "val_idx.pt"))
     return labels, val_idx


@@ -77,7 +83,11 @@ def get_labels(labels, val_idx, id_list):

 if __name__ == "__main__":
     args = get_args()
-    dtype_map = {"uint8": np.uint8,"float32": np.float32, "int32": np.int32, "int64": np.int64}
+    dtype_map = {
+        "uint8": np.uint8,
+        "float32": np.float32,
+        "int32": np.int32,
+        "int64": np.int64}

     with open(args.mlperf_accuracy_file, "r") as f:
         mlperf_results = json.load(f)
@@ -97,7 +107,8 @@ def get_labels(labels, val_idx, id_list):
         # get ground truth
         label = get_labels(labels, val_idx, idx)
         # get prediction
-        data = int(np.frombuffer(bytes.fromhex(result["data"]), dtype_map[args.dtype])[0])
+        data = int(np.frombuffer(bytes.fromhex(
+            result["data"]), dtype_map[args.dtype])[0])
         if label == data:
             good += 1
         total += 1
@@ -106,7 +117,5 @@ def get_labels(labels, val_idx, id_list):
     results["number_correct_samples"] = good
     results["performance_sample_count"] = total

-
     with open(args.output_file, "w") as fp:
         json.dump(results, fp)
-
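For reference, a minimal sketch of the decode step that the reformatted `data = int(np.frombuffer(bytes.fromhex(...)))` line performs: each entry in the MLPerf accuracy log carries the raw prediction as a hex string, which the script converts back to a scalar of the dtype chosen by `--dtype`. The sample entry below is hypothetical, not taken from a real log.

```python
import numpy as np

# Hypothetical log entry: a single int64 prediction (value 2437) serialized as hex.
result = {"qsl_idx": 0, "data": np.int64(2437).tobytes().hex()}

dtype_map = {"uint8": np.uint8, "float32": np.float32,
             "int32": np.int32, "int64": np.int64}

# Same decode as in the script: hex string -> bytes -> numpy scalar -> Python int.
data = int(np.frombuffer(bytes.fromhex(result["data"]), dtype_map["int64"])[0])
print(data)  # 2437
```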