3
3
import numpy as np
4
4
import random
5
5
import time
6
+ import sys
6
7
7
8
from accum_trainer import AccumTrainer
8
9
from game_state import GameState
14
15
from constants import ENTROPY_BETA
15
16
from constants import USE_LSTM
16
17
17
- import sys
18
- LOG_INTERVAL = 1000
18
+ LOG_INTERVAL = 100
19
+ PERFORMANCE_LOG_INTERVAL = 1000
19
20
20
21
class A3CTrainingThread (object ):
21
22
def __init__ (self ,
@@ -27,8 +28,6 @@ def __init__(self,
27
28
max_global_time_step ,
28
29
device ):
29
30
30
- print ("LOCAL_T_MAX=" , LOCAL_T_MAX )
31
-
32
31
self .thread_index = thread_index
33
32
self .learning_rate_input = learning_rate_input
34
33
self .max_global_time_step = max_global_time_step
@@ -124,17 +123,8 @@ def process(self, sess, global_t, summary_writer, summary_op, score_input):
124
123
values .append (value_ )
125
124
126
125
if (self .thread_index == 0 ) and (self .local_t % LOG_INTERVAL == 0 ):
127
- print ("pi={} (thread{})" .format (pi_ , self .thread_index ))
128
- print (" V={} (thread{})" .format (value_ , self .thread_index ))
129
-
130
- if np .any (np .isnan (pi_ )):
131
- print ("pi={} (thread{})" .format (pi_ , self .thread_index ))
132
- print (" V={} (thread{})" .format (value_ , self .thread_index ))
133
- print ("##############################################################" )
134
- print ("# 'nan' appeared in pi. PLEASE KILL ME by 'control-c' #" )
135
- print ("# thread{} will exit" .format (self .thread_index ))
136
- print ("##############################################################" )
137
- sys .exit (0 )
126
+ print ("pi={}" .format (pi_ ))
127
+ print (" V={}" .format (value_ ))
138
128
139
129
# process game
140
130
self .game_state .process (action )
@@ -153,20 +143,9 @@ def process(self, sess, global_t, summary_writer, summary_op, score_input):
153
143
# s_t1 -> s_t
154
144
self .game_state .update ()
155
145
156
- if self .local_t % LOG_INTERVAL == 0 :
157
- indent = " |" * self .thread_index
158
- elapsed_time = time .time () - self .start_time
159
- print ("t={:6.0f},s={:9d},th={}:{}r={:3d} |" .format (
160
- elapsed_time , global_t , self .thread_index ,
161
- indent , self .episode_reward ))
162
-
163
146
if terminal :
164
147
terminal_end = True
165
- indent = " |" * self .thread_index
166
- elapsed_time = time .time () - self .start_time
167
- print ("t={:6.0f},s={:9d},th={}:{}r={:3d}@END|" .format (
168
- elapsed_time , global_t , self .thread_index ,
169
- indent , self .episode_reward ))
148
+ print ("score={}" .format (self .episode_reward ))
170
149
171
150
self ._record_score (sess , summary_writer , summary_op , score_input ,
172
151
self .episode_reward , global_t )
@@ -230,13 +209,12 @@ def process(self, sess, global_t, summary_writer, summary_op, score_input):
230
209
sess .run ( self .apply_gradients ,
231
210
feed_dict = { self .learning_rate_input : cur_learning_rate } )
232
211
233
- if self .local_t - self .prev_local_t >= LOG_INTERVAL :
234
- self .prev_local_t += LOG_INTERVAL
212
+ if ( self .thread_index == 0 ) and ( self . local_t - self .prev_local_t >= PERFORMANCE_LOG_INTERVAL ) :
213
+ self .prev_local_t += PERFORMANCE_LOG_INTERVAL
235
214
elapsed_time = time .time () - self .start_time
236
215
steps_per_sec = global_t / elapsed_time
237
- if self .thread_index == 0 :
238
- print ("### Performance : {} STEPS in {:.0f} sec. {:.0f} STEPS/sec. {:.2f}M STEPS/hour" .format (
239
- global_t , elapsed_time , steps_per_sec , steps_per_sec * 3600 / 1000000. ))
216
+ print ("### Performance : {} STEPS in {:.0f} sec. {:.0f} STEPS/sec. {:.2f}M STEPS/hour" .format (
217
+ global_t , elapsed_time , steps_per_sec , steps_per_sec * 3600 / 1000000. ))
240
218
241
219
# return advanced local step size
242
220
diff_local_t = self .local_t - start_local_t
0 commit comments