Skip to content

Commit 7367edd

Browse files
committed
adjust log and performance log interval
1 parent c71dd09 commit 7367edd

File tree

1 file changed

+10
-32
lines changed

1 file changed

+10
-32
lines changed

a3c_training_thread.py

Lines changed: 10 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import numpy as np
44
import random
55
import time
6+
import sys
67

78
from accum_trainer import AccumTrainer
89
from game_state import GameState
@@ -14,8 +15,8 @@
1415
from constants import ENTROPY_BETA
1516
from constants import USE_LSTM
1617

17-
import sys
18-
LOG_INTERVAL = 1000
18+
LOG_INTERVAL = 100
19+
PERFORMANCE_LOG_INTERVAL = 1000
1920

2021
class A3CTrainingThread(object):
2122
def __init__(self,
@@ -27,8 +28,6 @@ def __init__(self,
2728
max_global_time_step,
2829
device):
2930

30-
print("LOCAL_T_MAX=", LOCAL_T_MAX)
31-
3231
self.thread_index = thread_index
3332
self.learning_rate_input = learning_rate_input
3433
self.max_global_time_step = max_global_time_step
@@ -124,17 +123,8 @@ def process(self, sess, global_t, summary_writer, summary_op, score_input):
124123
values.append(value_)
125124

126125
if (self.thread_index == 0) and (self.local_t % LOG_INTERVAL == 0):
127-
print("pi={} (thread{})".format(pi_, self.thread_index))
128-
print(" V={} (thread{})".format(value_, self.thread_index))
129-
130-
if np.any(np.isnan(pi_)):
131-
print("pi={} (thread{})".format(pi_, self.thread_index))
132-
print(" V={} (thread{})".format(value_, self.thread_index))
133-
print("##############################################################")
134-
print("# 'nan' appeared in pi. PLEASE KILL ME by 'control-c' #")
135-
print("# thread{} will exit".format(self.thread_index))
136-
print("##############################################################")
137-
sys.exit(0)
126+
print("pi={}".format(pi_))
127+
print(" V={}".format(value_))
138128

139129
# process game
140130
self.game_state.process(action)
@@ -153,20 +143,9 @@ def process(self, sess, global_t, summary_writer, summary_op, score_input):
153143
# s_t1 -> s_t
154144
self.game_state.update()
155145

156-
if self.local_t % LOG_INTERVAL == 0:
157-
indent = " |" * self.thread_index
158-
elapsed_time = time.time() - self.start_time
159-
print("t={:6.0f},s={:9d},th={}:{}r={:3d} |".format(
160-
elapsed_time, global_t, self.thread_index,
161-
indent, self.episode_reward))
162-
163146
if terminal:
164147
terminal_end = True
165-
indent = " |" * self.thread_index
166-
elapsed_time = time.time() - self.start_time
167-
print("t={:6.0f},s={:9d},th={}:{}r={:3d}@END|".format(
168-
elapsed_time, global_t, self.thread_index,
169-
indent, self.episode_reward))
148+
print("score={}".format(self.episode_reward))
170149

171150
self._record_score(sess, summary_writer, summary_op, score_input,
172151
self.episode_reward, global_t)
@@ -230,13 +209,12 @@ def process(self, sess, global_t, summary_writer, summary_op, score_input):
230209
sess.run( self.apply_gradients,
231210
feed_dict = { self.learning_rate_input: cur_learning_rate } )
232211

233-
if self.local_t - self.prev_local_t >= LOG_INTERVAL:
234-
self.prev_local_t += LOG_INTERVAL
212+
if (self.thread_index == 0) and (self.local_t - self.prev_local_t >= PERFORMANCE_LOG_INTERVAL):
213+
self.prev_local_t += PERFORMANCE_LOG_INTERVAL
235214
elapsed_time = time.time() - self.start_time
236215
steps_per_sec = global_t / elapsed_time
237-
if self.thread_index == 0:
238-
print("### Performance : {} STEPS in {:.0f} sec. {:.0f} STEPS/sec. {:.2f}M STEPS/hour".format(
239-
global_t, elapsed_time, steps_per_sec, steps_per_sec * 3600 / 1000000.))
216+
print("### Performance : {} STEPS in {:.0f} sec. {:.0f} STEPS/sec. {:.2f}M STEPS/hour".format(
217+
global_t, elapsed_time, steps_per_sec, steps_per_sec * 3600 / 1000000.))
240218

241219
# return advanced local step size
242220
diff_local_t = self.local_t - start_local_t

0 commit comments

Comments
 (0)