
Commit e029d8a

Merge branch 'new_expts' of git+ssh://github.com/automl/mdp-playground into new_expts
2 parents 2241e59 + 4240c87 commit e029d8a

4 files changed, +42 -7 lines changed

codecov.yml

Lines changed: 11 additions & 0 deletions
@@ -0,0 +1,11 @@
+coverage:
+  range: 68..100
+  round: down
+  precision: 2
+  status:
+    project:
+      default:
+        # basic
+        target: 68%
+        threshold: 5%
+        base: auto

example.py

Lines changed: 7 additions & 4 deletions
@@ -257,7 +257,7 @@ def atari_wrapper_example():
        "state_space_type": "discrete",
    }

-    from mdp_playground.envs.gym_env_wrapper import GymEnvWrapper
+    from mdp_playground.envs import GymEnvWrapper
    import gym

    ae = gym.make("QbertNoFrameskip-v4")
@@ -299,8 +299,12 @@ def mujoco_wrapper_example():
    # This actually makes a subclass and not a wrapper. Because, some
    # frameworks might need an instance of this class to also be an instance
    # of the Mujoco base_class.
-    from mdp_playground.envs.mujoco_env_wrapper import get_mujoco_wrapper
-    from gym.envs.mujoco.half_cheetah_v3 import HalfCheetahEnv
+    try:
+        from mdp_playground.envs import get_mujoco_wrapper
+        from gym.envs.mujoco.half_cheetah_v3 import HalfCheetahEnv
+    except Exception as e:
+        print("Exception:", e, "caught. You may need to install mujoco-py. NOT running mujoco_wrapper_example.")
+        return

    wrapped_mujoco_env = get_mujoco_wrapper(HalfCheetahEnv)

@@ -413,7 +417,6 @@ def minigrid_wrapper_example():
            "action_space_size": 8,
            "state_space_type": "discrete",
            "action_space_type": "discrete",
-            "terminal_state_density": 0.25,
            "maximally_connected": True,
        }
    )
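
For reference, a minimal sketch of how the re-exported GymEnvWrapper from the first hunk above is typically used on an Atari env, in the spirit of atari_wrapper_example (the "delay" key and the exact constructor call are assumptions for illustration, not taken from this diff):

from mdp_playground.envs import GymEnvWrapper  # the new import path introduced above
import gym

config = {
    "state_space_type": "discrete",
    "delay": 1,  # assumed meta-feature: delay rewards by 1 step
}

ae = gym.make("QbertNoFrameskip-v4")
aew = GymEnvWrapper(ae, **config)  # assumed signature: wrap the base env with the config
obs = aew.reset()
obs, reward, done, info = aew.step(aew.action_space.sample())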

mdp_playground/envs/__init__.py

Lines changed: 6 additions & 0 deletions
@@ -1 +1,7 @@
from mdp_playground.envs.rl_toy_env import RLToyEnv
+
+try:
+    from mdp_playground.envs.gym_env_wrapper import GymEnvWrapper
+    from mdp_playground.envs.mujoco_env_wrapper import get_mujoco_wrapper
+except Exception as e:
+    print("Exception:", e, "caught. You may need to install Ray or mujoco-py.")

mdp_playground/envs/rl_toy_env.py

Lines changed: 18 additions & 3 deletions
@@ -53,7 +53,7 @@ class RLToyEnv(gym.Env):
    diameter : int > 0
        For discrete environments, if diameter = d, the set of states is set to be a d-partite graph (and NOT a complete d-partite graph), where, if we order the d sets as 1, 2, .., d, states from set 1 will have actions leading to states in set 2 and so on, with the final set d having actions leading to states in set 1. Number of actions for each state will, thus, be = (number of states) / (d). Default value: 1 for discrete environments. For continuous environments, this dimension is set automatically based on the state_space_max value.
    terminal_state_density : float in range [0, 1]
-        For discrete environments, the fraction of states that are terminal; the terminal states are fixed to the "last" states when we consider them to be ordered by their numerical value. This is w.l.o.g. because discrete states are categorical. For continuous environments, please see terminal_states and term_state_edge for how to control terminal states.
+        For discrete environments, the fraction of states that are terminal; the terminal states are fixed to the "last" states when we consider them to be ordered by their numerical value. This is w.l.o.g. because discrete states are categorical. For continuous environments, please see terminal_states and term_state_edge for how to control terminal states. Default value: 0.25.
    term_state_reward : float
        Adds this to the reward if a terminal state was reached at the current time step. Default value: 0.
    image_representations : boolean
@@ -217,6 +217,16 @@ def __init__(self, **config):

        print("Passed config:", config, "\n")

+        if config == {}:
+            config = {
+                "state_space_size": 8,
+                "action_space_size": 8,
+                "state_space_type": "discrete",
+                "action_space_type": "discrete",
+                "terminal_state_density": 0.25,
+                "maximally_connected": True,
+            }
+
        # Print initial "banner"
        screen_output_width = 132 # #hardcoded #TODO get from system
        repeat_equal_sign = (screen_output_width - 20) // 2
@@ -329,6 +339,11 @@ def __init__(self, **config):
        # if config["state_space_type"] == "discrete":
        #     assert "init_state_dist" in config

+        if "terminal_state_density" not in config:
+            self.terminal_state_density = 0.25
+        else:
+            self.terminal_state_density = config["terminal_state_density"]
+
        if not self.use_custom_mdp:
            if "generate_random_mdp" not in config:
                self.generate_random_mdp = True
@@ -786,7 +801,7 @@ def init_terminal_states(self):
        """Initialises terminal state set to be the 'last' states for discrete environments. For continuous environments, terminal states will be in a hypercube centred around config['terminal_states'] with the edge of the hypercube of length config['term_state_edge']."""
        if self.config["state_space_type"] == "discrete":
            if (
-                self.use_custom_mdp and "terminal_state_density" not in self.config
+                self.use_custom_mdp and "terminal_states" in self.config
            ): # custom/user-defined terminal states
                self.is_terminal_state = (
                    self.config["terminal_states"]
@@ -796,7 +811,7 @@ def init_terminal_states(self):
            else:
                # Define the no. of terminal states per independent set of the state space
                self.num_terminal_states = int(
-                    self.config["terminal_state_density"] * self.action_space_size[0]
+                    self.terminal_state_density * self.action_space_size[0]
                ) # #hardcoded ####IMP Using action_space_size
                # since it contains state_space_size // diameter
                # if self.num_terminal_states == 0: # Have at least 1 terminal state?
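
Taken together, these changes let RLToyEnv be constructed with an empty config, falling back to the defaults added in __init__, and make terminal_state_density default to 0.25. A rough usage sketch, assuming the pre-0.26 Gym reset/step signatures that the class follows:

from mdp_playground.envs import RLToyEnv

env = RLToyEnv()  # empty config: the defaults above give an 8-state, 8-action discrete MDP
state = env.reset()
next_state, reward, done, info = env.step(env.action_space.sample())

# With terminal_state_density = 0.25 and 8 actions per state (diameter 1),
# num_terminal_states = int(0.25 * 8) = 2, i.e. the 2 "last" states are terminal.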
