@@ -31,13 +31,14 @@ def get_args():
3131 parser .add_argument (
3232 "--dtype" ,
3333 default = "uint8" ,
34- choices = ["uint8" ,"float32" , "int32" , "int64" ],
34+ choices = ["uint8" , "float32" , "int32" , "int64" ],
3535 help = "data type of the label" ,
3636 )
3737 args = parser .parse_args ()
3838 return args
3939
40- def load_labels (base_path , dataset_size , use_label_2K = True ):
40+
41+ def load_labels (base_path , dataset_size , use_label_2K = True ):
4142 # load labels
4243 paper_nodes_num = {
4344 "tiny" : 100000 ,
@@ -49,7 +50,12 @@ def load_labels(base_path, dataset_size, use_label_2K = True):
4950 label_file = (
5051 "node_label_19.npy" if not use_label_2K else "node_label_2K.npy"
5152 )
52- paper_lbl_path = os .path .join (base_path , dataset_size , "processed" , "paper" , label_file )
53+ paper_lbl_path = os .path .join (
54+ base_path ,
55+ dataset_size ,
56+ "processed" ,
57+ "paper" ,
58+ label_file )
5359
5460 if dataset_size in ["large" , "full" ]:
5561 paper_node_labels = torch .from_numpy (
@@ -63,11 +69,11 @@ def load_labels(base_path, dataset_size, use_label_2K = True):
6369 torch .long )
6470 labels = paper_node_labels
6571 val_idx = torch .load (
66- os .path .join (
67- base_path ,
68- dataset_size ,
69- "processed" ,
70- "val_idx.pt" ))
72+ os .path .join (
73+ base_path ,
74+ dataset_size ,
75+ "processed" ,
76+ "val_idx.pt" ))
7177 return labels , val_idx
7278
7379
@@ -77,7 +83,11 @@ def get_labels(labels, val_idx, id_list):
7783
7884if __name__ == "__main__" :
7985 args = get_args ()
80- dtype_map = {"uint8" : np .uint8 ,"float32" : np .float32 , "int32" : np .int32 , "int64" : np .int64 }
86+ dtype_map = {
87+ "uint8" : np .uint8 ,
88+ "float32" : np .float32 ,
89+ "int32" : np .int32 ,
90+ "int64" : np .int64 }
8191
8292 with open (args .mlperf_accuracy_file , "r" ) as f :
8393 mlperf_results = json .load (f )
@@ -97,7 +107,8 @@ def get_labels(labels, val_idx, id_list):
97107 # get ground truth
98108 label = get_labels (labels , val_idx , idx )
99109 # get prediction
100- data = int (np .frombuffer (bytes .fromhex (result ["data" ]), dtype_map [args .dtype ])[0 ])
110+ data = int (np .frombuffer (bytes .fromhex (
111+ result ["data" ]), dtype_map [args .dtype ])[0 ])
101112 if label == data :
102113 good += 1
103114 total += 1
@@ -106,7 +117,5 @@ def get_labels(labels, val_idx, id_list):
106117 results ["number_correct_samples" ] = good
107118 results ["performance_sample_count" ] = total
108119
109-
110120 with open (args .output_file , "w" ) as fp :
111121 json .dump (results , fp )
112-
0 commit comments