import matplotlib.pyplot as plt
import tensorlayer as tl
-from tutorial_wrappers import build_env
+import gym

parser = argparse.ArgumentParser()
-parser.add_argument('--mode', help='train or test', default='test')
+parser.add_argument('--train', dest='train', action='store_true', default=False)
+parser.add_argument('--test', dest='test', action='store_true', default=True)
parser.add_argument(
    '--save_path', default=None, help='folder to save if mode == train else model path,'
    'qnet will be saved once target net update'

np.random.seed(args.seed)
tf.random.set_seed(args.seed)  # reproducible
env_id = args.env_id
-env = build_env(env_id, seed=args.seed)
+env = gym.make(env_id)
+env.seed(args.seed)
alg_name = 'C51'

# ####################  hyper parameters  ####################
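Note on the new flags: --train and --test are independent store_true switches, and --test defaults to True, so running the script with --train still leaves args.test set and the test loop runs after training finishes. A minimal standalone sketch of how the two flags parse (not part of this script):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--train', dest='train', action='store_true', default=False)
parser.add_argument('--test', dest='test', action='store_true', default=True)

# No flags: test only.
print(parser.parse_args([]))           # Namespace(test=True, train=False)
# With --train: test stays True by default, so both branches execute.
print(parser.parse_args(['--train']))  # Namespace(test=True, train=True)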
@@ -195,7 +197,7 @@ class DQN(object):
    def __init__(self):
        model = MLP if qnet_type == 'MLP' else CNN
        self.qnet = model('q')
-        if args.mode == 'train':
+        if args.train:
            self.qnet.train()
        self.targetqnet = model('targetq')
        self.targetqnet.infer()
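For context, the target network kept in infer mode is the standard DQN target-network trick: it only provides stable targets, and its weights are refreshed from the online qnet at a fixed interval (the save_path help text above refers to that update). A generic hard-update sketch in plain numpy, illustrative only and not this file's TensorLayer code:

import numpy as np

def hard_update(target_params, online_params):
    # Copy every online-network parameter array into the target network in place.
    for t, o in zip(target_params, online_params):
        t[...] = o

# Illustrative parameter lists standing in for the two networks' weights.
online = [np.ones((2, 2)), np.arange(3.0)]
target = [np.zeros((2, 2)), np.zeros(3)]
hard_update(target, online)
print(np.allclose(target[1], online[1]))  # True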
@@ -211,7 +213,7 @@ def __init__(self):

    def get_action(self, obv):
        eps = epsilon(self.niter)
-        if args.mode == 'train' and random.random() < eps:
+        if args.train and random.random() < eps:
            return int(random.random() * out_dim)
        else:
            obv = np.expand_dims(obv, 0).astype('float32') * ob_scale
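The effect of this change is that epsilon-greedy exploration only applies while training; in test mode the agent always acts greedily. A standalone sketch of the same decision rule (the names q_values and eps are illustrative, not taken from this file):

import random
import numpy as np

def epsilon_greedy(q_values, eps, training):
    # Random action with probability eps during training, greedy otherwise.
    if training and random.random() < eps:
        return random.randrange(len(q_values))
    return int(np.argmax(q_values))

q = np.array([0.1, 0.5, 0.2])
print(epsilon_greedy(q, eps=0.1, training=False))  # always 1 (greedy) in test mode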
@@ -275,76 +277,65 @@ def _train_func(self, b_o, b_index, b_m):
# #############################  Trainer  ###################################
if __name__ == '__main__':
    dqn = DQN()
-    if args.mode == 'train':
+    t0 = time.time()
+    if args.train:
        buffer = ReplayBuffer(buffer_size)
-
-        o = env.reset()
        nepisode = 0
-        t = time.time()
        all_episode_reward = []
        for i in range(1, number_timesteps + 1):
-
-            a = dqn.get_action(o)
-
-            # execute action and feed to replay buffer
-            # note that `_` tail in var name means next
-            o_, r, done, info = env.step(a)
-            buffer.add(o, a, r, o_, done)
-
-            if i >= warm_start:
-                transitions = buffer.sample(batch_size)
-                dqn.train(*transitions)
-
-            if done:
-                episode_reward = info['episode']['r']
-                if nepisode == 0:
-                    all_episode_reward.append(episode_reward)
+            o = env.reset()
+            episode_reward = 0
+            while True:
+                a = dqn.get_action(o)
+                # execute action and feed to replay buffer
+                # note that `_` tail in var name means next
+                o_, r, done, info = env.step(a)
+                buffer.add(o, a, r, o_, done)
+                episode_reward += r
+
+                if i >= warm_start:
+                    transitions = buffer.sample(batch_size)
+                    dqn.train(*transitions)
+
+                if done:
+                    break
                else:
-                    all_episode_reward.append(all_episode_reward[-1] * 0.9 + episode_reward * 0.1)
-                o = env.reset()
+                    o = o_
+
+            if nepisode == 0:
+                all_episode_reward.append(episode_reward)
            else:
-                o = o_
-
-            # episode in info is real (unwrapped) message
-            if info.get('episode'):
-                nepisode += 1
-                reward, length = info['episode']['r'], info['episode']['l']
-                try:
-                    fps = int(length / (time.time() - t))
-                except:
-                    fps = 0
-                print(
-                    'Time steps so far: {}, episode so far: {}, '
-                    'episode reward: {:.4f}, episode length: {}, FPS: {}'.format(i, nepisode, reward, length, fps)
+                all_episode_reward.append(all_episode_reward[-1] * 0.9 + episode_reward * 0.1)
+            nepisode += 1
+            print(
+                'Training | Episode: {} | Episode Reward: {:.4f} | Running Time: {:.4f}'.format(
+                    nepisode, episode_reward, time.time() - t0
                )
-                t = time.time()
+            )  # episode num starts from 1 in print

        dqn.save(args.save_path)
        plt.plot(all_episode_reward)
        if not os.path.exists('image'):
            os.makedirs('image')
        plt.savefig(os.path.join('image', '_'.join([alg_name, env_id])))
-    else:
+
+    if args.test:
        nepisode = 0
-        o = env.reset()
        for i in range(1, number_timesteps + 1):
-            a = dqn.get_action(o)
-
-            # execute action
-            # note that `_` tail in var name means next
-            o_, r, done, info = env.step(a)
-            env.render()
-
-            if done:
-                o = env.reset()
-            else:
-                o = o_
-
-            # episode in info is real (unwrapped) message
-            if info.get('episode'):
-                nepisode += 1
-                reward, length = info['episode']['r'], info['episode']['l']
-                print(
-                    'Time steps so far: {}, episode so far: {}, '
-                    'episode reward: {:.4f}, episode length: {}'.format(i, nepisode, reward, length)
-                )
+            o = env.reset()
+            episode_reward = 0
+            while True:
+                env.render()
+                a = dqn.get_action(o)
+                o_, r, done, info = env.step(a)
+                episode_reward += r
+                if done:
+                    break
+                else:
+                    o = o_
+            nepisode += 1
+            print(
+                'Testing | Episode: {} | Episode Reward: {:.4f} | Running Time: {:.4f}'.format(
+                    nepisode, episode_reward, time.time() - t0
+                )
+            )
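The curve saved under image/ plots a smoothed running reward rather than the raw per-episode return: each new point mixes 90% of the previous point with 10% of the latest episode. A small worked example of that smoothing on made-up rewards:

# Illustrative episode returns; the real values come from the environment.
episode_rewards = [10.0, 12.0, 8.0, 20.0]

all_episode_reward = []
for episode_reward in episode_rewards:
    if not all_episode_reward:
        all_episode_reward.append(episode_reward)
    else:
        # Exponential moving average: 90% history, 10% newest episode.
        all_episode_reward.append(all_episode_reward[-1] * 0.9 + episode_reward * 0.1)

print(all_episode_reward)  # approximately [10.0, 10.2, 9.98, 10.98]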