@@ -36,8 +36,12 @@ def colorize(string, color, bold=False, highlight=False):
36
36
37
37
class Logger ():
38
38
def __init__ (self , log_dir = "./logs" , exp_name = None , env_name = None , seed = 0 ):
39
+ self .exp_name = exp_name
40
+ self .env_name = env_name
41
+ self .seed = seed
42
+
39
43
num_exps = 0
40
- self .log_dir = f"./{ log_dir } /{ exp_name } _{ env_name } _seed{ seed } "
44
+ self .log_dir = f"./{ log_dir } /{ exp_name . replace ( '-' , '_' ) } _{ env_name . replace ( '-' , '_' ) } _seed{ seed } "
41
45
while True :
42
46
if os .path .exists (f"{ self .log_dir } -{ str (num_exps )} /" ):
43
47
num_exps += 1
@@ -73,20 +77,15 @@ def update(self, score, total_steps):
73
77
epinfo = {"mean_score" : avg_score , "total_steps" : total_steps , "std_score" : std_score , "max_score" : max_score , "min_score" : max_score }
74
78
self .logger .writerow (epinfo )
75
79
self .csv_file .flush ()
80
+
81
+ def new_custom_logger (self , filename = None , fieldnames = []):
82
+ custom_logger = CustomLogger (self .log_dir , self .exp_name , self .env_name , self .seed , filename , fieldnames )
83
+ return custom_logger
76
84
77
85
class CustomLogger ():
78
86
def __init__ (self , log_dir = "./logs" , exp_name = None , env_name = None , seed = 0 , filename = "logger.csv" , fieldnames = []):
79
87
self .fieldnames = ["total_steps" ] + fieldnames
80
- num_exps = 0
81
- self .log_dir = f"./{ log_dir } /{ exp_name } _{ env_name } _seed{ seed } "
82
- while True :
83
- if os .path .exists (f"{ self .log_dir } -{ str (num_exps )} /" ):
84
- num_exps += 1
85
- else :
86
- self .log_dir += f"-{ str (num_exps )} /"
87
- os .makedirs (self .log_dir )
88
- break
89
- self .csv_file = open (self .log_dir + '/' + filename , 'w' , encoding = 'utf8' )
88
+ self .csv_file = open (log_dir + '/' + filename , 'w' , encoding = 'utf8' )
90
89
header = {"t_start" : time .time (), 'env_id' : env_name , 'exp_name' : exp_name , 'seed' : seed }
91
90
header = '# {} \n ' .format (json .dumps (header ))
92
91
self .csv_file .write (header )
@@ -95,7 +94,6 @@ def __init__(self, log_dir="./logs", exp_name=None, env_name=None, seed=0, filen
95
94
self .csv_file .flush ()
96
95
97
96
def update (self , fieldvalues , total_steps ):
98
-
99
97
print (colorize (f"\n CustomLogger with fileds: { self .fieldnames } " , 'yellow' , bold = True ))
100
98
print (colorize (f"total_steps: { total_steps } , fieldvalues: { fieldvalues } \n " , 'yellow' , bold = True ))
101
99
0 commit comments