
Commit 2049f4e

update

1 parent 8d735ed
1 file changed, +13 -6 lines changed

Diff for: tf2.0/rl_trader.py

@@ -15,6 +15,11 @@
 from sklearn.preprocessing import StandardScaler
 
 
+import tensorflow as tf
+# if tf.__version__.startswith('2'):
+#   tf.compat.v1.disable_eager_execution()
+
+
 # Let's use AAPL (Apple), MSI (Motorola), SBUX (Starbucks)
 def get_data():
   # returns a T x 3 list of stock prices
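
Note: the new import leaves the eager-execution guard commented out. If it were enabled, it would amount to the minimal sketch below; the guard is only relevant when running TF1-style graph/session code on a TF2 install.

import tensorflow as tf

if tf.__version__.startswith('2'):
  # Revert to TF1 graph-mode semantics; only needed for code written
  # against the tf.compat.v1 Session/graph API.
  tf.compat.v1.disable_eager_execution()
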
@@ -270,10 +275,10 @@ def update_replay_memory(self, state, action, reward, next_state, done):
   def act(self, state):
     if np.random.rand() <= self.epsilon:
       return np.random.choice(self.action_size)
-    act_values = self.model.predict(state)
+    act_values = self.model.predict(state, verbose=0)
     return np.argmax(act_values[0]) # returns action
 
-
+  @tf.function
   def replay(self, batch_size=32):
     # first check if replay buffer contains enough data
     if self.memory.size < batch_size:
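
Note: verbose=0 only silences the progress bar Keras prints on every predict() call, which matters when act() runs once per time step; the returned Q-values are unchanged. A standalone sketch of the same epsilon-greedy selection, rewritten as a free function for illustration (the real code is a method on the agent):

import numpy as np

def act(model, state, epsilon, action_size):
  # Explore with probability epsilon...
  if np.random.rand() <= epsilon:
    return np.random.choice(action_size)
  # ...otherwise act greedily with respect to the predicted Q-values.
  q_values = model.predict(state, verbose=0)  # state has shape (1, D)
  return int(np.argmax(q_values[0]))
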
@@ -288,7 +293,7 @@ def replay(self, batch_size=32):
     done = minibatch['d']
 
     # Calculate the tentative target: Q(s',a)
-    target = rewards + (1 - done) * self.gamma * np.amax(self.model.predict(next_states), axis=1)
+    target = rewards + (1 - done) * self.gamma * np.amax(self.model.predict(next_states, verbose=0), axis=1)
 
     # With the Keras API, the target (usually) must have the same
     # shape as the predictions.
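
Note: the changed line is the usual one-step Q-learning target, target = r + (1 - done) * gamma * max over a' of Q(s', a'). A tiny worked example with made-up numbers:

import numpy as np

gamma   = 0.95
rewards = np.array([1.0, 0.5])
done    = np.array([0, 1])            # second transition is terminal
q_next  = np.array([[0.2, 0.8],       # Q(s', a') for each action
                    [0.4, 0.1]])

target = rewards + (1 - done) * gamma * np.amax(q_next, axis=1)
# -> [1.0 + 0.95 * 0.8, 0.5 + 0.0] = [1.76, 0.5]
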
@@ -298,7 +303,7 @@ def replay(self, batch_size=32):
     # the prediction for all values.
     # Then, only change the targets for the actions taken.
     # Q(s,a)
-    target_full = self.model.predict(states)
+    target_full = self.model.predict(states, verbose=0)
     target_full[np.arange(batch_size), actions] = target
 
     # Run one training step
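
Note: the fancy-indexing line overwrites only the Q-value of the action actually taken in each transition; every other entry keeps the model's own prediction and so contributes zero error to the loss. A small example with assumed values:

import numpy as np

batch_size  = 3
actions     = np.array([2, 0, 1])           # action taken in each transition
target      = np.array([1.7, 0.4, 0.9])     # Bellman targets for those actions
target_full = np.zeros((batch_size, 3))     # stands in for model.predict(states, verbose=0)

target_full[np.arange(batch_size), actions] = target
# row 0 gets 1.7 in column 2, row 1 gets 0.4 in column 0, row 2 gets 0.9 in column 1
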
@@ -316,6 +321,7 @@ def save(self, name):
     self.model.save_weights(name)
 
 
+
 def play_one_episode(agent, env, is_train):
   # note: after transforming states are already 1xD
   state = env.reset()
@@ -340,6 +346,7 @@ def play_one_episode(agent, env, is_train):
   # config
   models_folder = 'rl_trader_models'
   rewards_folder = 'rl_trader_rewards'
+  model_file = 'dqn.weights.h5'
   num_episodes = 2000
   batch_size = 32
   initial_investment = 20000
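
Note: the rename from dqn.h5 to dqn.weights.h5 matches the filename suffix Keras 3 expects for save_weights()/load_weights() paths (an assumption about the motivation; the commit itself only changes the name). With the new config variable, the save/load calls later in the diff resolve to:

models_folder = 'rl_trader_models'
model_file = 'dqn.weights.h5'

# Path used by agent.save() / agent.load() further down in the diff.
weights_path = f'{models_folder}/{model_file}'
print(weights_path)  # rl_trader_models/dqn.weights.h5
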
@@ -383,7 +390,7 @@ def play_one_episode(agent, env, is_train):
     agent.epsilon = 0.01
 
     # load trained weights
-    agent.load(f'{models_folder}/dqn.h5')
+    agent.load(f'{models_folder}/{model_file}')
 
   # play the game num_episodes times
   for e in range(num_episodes):
@@ -396,7 +403,7 @@ def play_one_episode(agent, env, is_train):
   # save the weights when we are done
   if args.mode == 'train':
     # save the DQN
-    agent.save(f'{models_folder}/dqn.h5')
+    agent.save(f'{models_folder}/{model_file}')
 
     # save the scaler
     with open(f'{models_folder}/scaler.pkl', 'wb') as f:
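
Note: the last context line opens scaler.pkl for writing; the dump call itself sits outside this hunk, so the round trip below is a sketch of what presumably follows: pickle.dump on save, pickle.load before test mode so states are scaled the same way at train and test time.

import os
import pickle
from sklearn.preprocessing import StandardScaler

models_folder = 'rl_trader_models'
os.makedirs(models_folder, exist_ok=True)

scaler = StandardScaler()
# ... scaler.fit(...) on states gathered from the environment ...

# After training: persist the scaler next to the weights.
with open(f'{models_folder}/scaler.pkl', 'wb') as f:
  pickle.dump(scaler, f)

# Before testing: load it back so states are scaled identically.
with open(f'{models_folder}/scaler.pkl', 'rb') as f:
  scaler = pickle.load(f)
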
