Skip to content

Commit 3b7d32e

Browse files
Fixed failing tests. Added run_experiments.py to root dir
1 parent 061b6ec commit 3b7d32e

File tree

3 files changed

+30
-10
lines changed

3 files changed

+30
-10
lines changed

mdp_playground/scripts/run_experiments.py

+11-4
Original file line number | Diff line number | Diff line change
@@ -22,7 +22,8 @@
2222

2323

2424
def main(args):
25-
#TODO Different seeds for Ray Trainer (TF, numpy, Python; Torch, Env), Environment (it has multiple sources of randomness too), Ray Evaluator
25+
# #TODO Different seeds for Ray Trainer (TF, numpy, Python; Torch, Env),
26+
# Environment (it has multiple sources of randomness too), Ray Evaluator
2627
# docstring at beginning of the file is stored in __doc__
2728
parser = argparse.ArgumentParser(description=__doc__)
2829
parser.add_argument('-c', '--config-file', dest='config_file',
@@ -138,14 +139,20 @@ def main(args):
138139

139140
if args.config_num is not None:
140141
stats_file_name += '_' + str(args.config_num)
141-
# elif args.agent_config_num is not None: ###IMP Commented out! If we append both these nums then, that can lead to 1M small files for 1000x1000 configs which doesn't play well with our Nemo cluster.
142+
# elif args.agent_config_num is not None: ###IMP Commented out! If we append
143+
# both these nums then, that can lead to 1M small files for 1000x1000 configs
144+
# which doesn't play well with our Nemo cluster.
142145
# stats_file_name += '_' + str(args.agent_config_num)
143146

144147
print("Stats file being written to:", stats_file_name)
145148

146-
config, final_configs = config_processor.process_configs(config_file, stats_file_prefix=stats_file_name, framework=args.framework, config_num=args.config_num, log_level=log_level_, framework_dir=args.framework_dir)
149+
config, final_configs = config_processor.process_configs(config_file,\
150+
stats_file_prefix=stats_file_name, framework=args.framework,\
151+
config_num=args.config_num, log_level=log_level_,\
152+
framework_dir=args.framework_dir)
147153

148-
print("Configuration number(s) that will be run:", "all" if args.config_num is None else args.config_num)
154+
print("Configuration number(s) that will be run:", "all" if args.config_num is\
155+
None else args.config_num)
149156

150157

151158
# import default_config

run_experiments.py

+4
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,4 @@
1+
import sys
2+
import mdp_playground.scripts.run_experiments as run_experiments
3+
4+
run_experiments.main(sys.argv[1:])

tests/test_mdp_playground.py

+15-6
Original file line number | Diff line number | Diff line change
@@ -754,7 +754,7 @@ def test_grid_env(self):
754754
env = RLToyEnv(**config)
755755

756756
state = env.get_augmented_state()['augmented_state'][-1]
757-
actions = [[0, 1], [-1, 1], [-1, 0], [1, -1], [0.5, -0.5], [1, 2], [1, 1], [0, 1]]
757+
actions = [[0, -1], [-1, 0], [1, 0], [1, 0], [0, 1], [0, 1], [0, 1], [0, 1], [-1, 0]]
758758

759759
tot_rew = 0
760760
for i in range(len(actions)):
@@ -765,29 +765,38 @@ def test_grid_env(self):
765765
state = next_state.copy()
766766
tot_rew += reward
767767

768-
assert tot_rew == 7.5, str(tot_rew)
768+
assert tot_rew == 8.25, str(tot_rew)
769769

770770
env.reset()
771771
env.close()
772772

773773

774774
# Test 2: Almost the same as 1, but with irrelevant features
775775
config["irrelevant_features"] = True
776+
config["term_state_reward"] = 0.
776777

777778
env = RLToyEnv(**config)
778779
state = env.get_augmented_state()['augmented_state'][-1]
779-
actions = [[0, 1], [-1, 1], [-1, 0], [1, -1], [0.5, -0.5], [1, 2], [1, 1], [0, 1]]
780+
actions = [[0, -1], [-1, 0], [1, 0], [1, 0], [0, 1], [0, 1], [0, 1], [0, 1], [-1, 0]]
780781

781782
tot_rew = 0
782783
for i in range(len(actions)):
783-
action = actions[i] + [-1, 0]
784+
action = actions[i] + [0, 0]
785+
next_obs, reward, done, info = env.step(action)
786+
next_state = env.get_augmented_state()['augmented_state'][-1]
787+
print("sars', done =", state, action, reward, next_state, done)
788+
state = next_state.copy()
789+
tot_rew += reward
790+
791+
for i in range(len(actions)):
792+
action = [0, 0] + actions[i]
784793
next_obs, reward, done, info = env.step(action)
785794
next_state = env.get_augmented_state()['augmented_state'][-1]
786-
print("sars'o', done =", state, action, reward, next_state, next_obs, done)
795+
print("sars', done =", state, action, reward, next_state, done)
787796
state = next_state.copy()
788797
tot_rew += reward
789798

790-
assert tot_rew == 7.5, str(tot_rew)
799+
assert tot_rew == 9, str(tot_rew)
791800

792801
env.reset()
793802
env.close()

0 commit comments

Comments (0)