Skip to content

Commit b4726ac

Browse files
committed
unified ray init that is version agnostic
1 parent 47c8e68 commit b4726ac

File tree

4 files changed

+61
-26
lines changed

4 files changed

+61
-26
lines changed

mdp_playground/config_processor/config_processor.py

+46-19
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,29 @@
1717
ModelCatalog.register_custom_preprocessor("ohe", OneHotPreprocessor)
1818

1919

20+
# def init_ray(log_level=None, tmp_dir=None, include_webui=None,
21+
# object_store_memory=int(2e9),
22+
# redis_max_memory=int(1e9), local_mode=False):
23+
def init_ray(**kwargs):
24+
import ray
25+
if ray.__version__[0] == '1': # new version 1.0 API
26+
if "redis_max_memory" in kwargs:
27+
value = kwargs["redis_max_memory"]
28+
del kwargs["redis_max_memory"]
29+
kwargs["_redis_max_memory"] = value
30+
if "tmp_dir" in kwargs:
31+
value = kwargs["tmp_dir"]
32+
del kwargs["tmp_dir"]
33+
kwargs["_temp_dir"] = value
34+
35+
if "log_level" in kwargs:
36+
value = kwargs["log_level"]
37+
del kwargs["log_level"]
38+
kwargs["logging_level"] = value
39+
40+
ray.init(**kwargs)
41+
42+
2043
def process_configs(config_file, stats_file_prefix, config_num, log_level,
2144
framework='ray', framework_dir='/tmp/ray'):
2245
config_file_path = os.path.abspath('/'.join(config_file.split('/')[:-1]))
@@ -129,30 +152,34 @@ def process_configs(config_file, stats_file_prefix, config_num, log_level,
129152

130153

131154
def setup_ray(config, config_num, log_level, framework_dir):
132-
import ray
133155
tmp_dir = framework_dir + '/tmp_' + str(config_num)
156+
# import ray
134157
if config.algorithm == 'DQN': #hack
135-
ray.init(object_store_memory=int(2e9), redis_max_memory=int(1e9),
136-
temp_dir=tmp_dir,
137-
logging_level=log_level,
138-
# local_mode=True,
139-
# webui_host='0.0.0.0'); logging_level=logging.INFO,
140-
)
158+
init_ray(log_level=log_level, tmp_dir=tmp_dir)
159+
# ray.init(object_store_memory=int(2e9), redis_max_memory=int(1e9),
160+
# temp_dir=tmp_dir,
161+
# logging_level=log_level,
162+
# # local_mode=True,
163+
# # webui_host='0.0.0.0'); logging_level=logging.INFO,
164+
# )
141165
# ray.init(object_store_memory=int(2e9), redis_max_memory=int(1e9), local_mode=True, plasma_directory='/tmp') #, memory=int(8e9), local_mode=True # local_mode (bool): If true, the code will be executed serially. This is useful for debugging. # when true on_train_result and on_episode_end operate in the same current directory as the script. A3C is crashing in local mode, so didn't use it and had to work around by giving full path + filename in stats_file_name.; also has argument driver_object_store_memory=, plasma_directory='/tmp'
142166
elif config.algorithm == 'A3C': #hack
143-
ray.init(object_store_memory=int(2e9), redis_max_memory=int(1e9),
144-
temp_dir=tmp_dir,
145-
logging_level=log_level,
146-
# local_mode=True,
147-
# webui_host='0.0.0.0'); logging_level=logging.INFO,
148-
) # ray.init(object_store_memory=int(2e9), redis_max_memory=int(1e9), local_mode=True, plasma_directory='/tmp')
167+
init_ray(log_level=log_level, tmp_dir=tmp_dir)
168+
# ray.init(object_store_memory=int(2e9), redis_max_memory=int(1e9),
169+
# temp_dir=tmp_dir,
170+
# logging_level=log_level,
171+
# # local_mode=True,
172+
# # webui_host='0.0.0.0'); logging_level=logging.INFO,
173+
# ) # ray.init(object_store_memory=int(2e9), redis_max_memory=int(1e9), local_mode=True, plasma_directory='/tmp')
149174
else:
150-
ray.init(object_store_memory=int(2e9), redis_max_memory=int(1e9),
151-
temp_dir=tmp_dir,
152-
logging_level=log_level,
153-
local_mode=True,
154-
# webui_host='0.0.0.0'); logging_level=logging.INFO,
155-
)
175+
init_ray(log_level=log_level, tmp_dir=tmp_dir, local_mode=True)
176+
177+
# ray.init(object_store_memory=int(2e9), redis_max_memory=int(1e9),
178+
# temp_dir=tmp_dir,
179+
# logging_level=log_level,
180+
# local_mode=True,
181+
# # webui_host='0.0.0.0'); logging_level=logging.INFO,
182+
# )
156183

157184
def init_stats_file(stats_file_name, columns_to_write):
158185
fout = open(stats_file_name, 'a') #hardcoded

mdp_playground/scripts/run_experiments_ray.py

+8-4
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#from ray.rllib.utils.seed import seed as rllib_seed
1414
import mdp_playground
1515
from mdp_playground.envs import RLToyEnv
16+
from mdp_playground.config_processor import init_ray
1617
from ray.tune.registry import register_env
1718
register_env("RLToy-v0", lambda config: RLToyEnv(**config))
1819
import sys, os
@@ -97,13 +98,16 @@ def create_gym_env_wrapper_frame_stack_atari(config): #hack ###TODO remove?
9798
ModelCatalog.register_custom_preprocessor("ohe", OneHotPreprocessor)
9899

99100
if config.algorithm == 'DQN':
100-
ray.init(object_store_memory=int(2e9), redis_max_memory=int(1e9), include_webui=False) #webui_host='0.0.0.0')
101+
init_ray(include_webui=False)
102+
# ray.init(object_store_memory=int(2e9), redis_max_memory=int(1e9), include_webui=False) #webui_host='0.0.0.0')
101103
# ray.init(object_store_memory=int(2e9), redis_max_memory=int(1e9), local_mode=True, plasma_directory='/tmp') #, memory=int(8e9), local_mode=True # when true on_train_result and on_episode_end operate in the same current directory as the script. A3C is crashing in local mode, so didn't use it and had to work around by giving full path + filename in stats_file_name.; also has argument driver_object_store_memory=, plasma_directory='/tmp'
102104
elif config.algorithm == 'A3C': #hack
103-
ray.init(object_store_memory=int(2e9), redis_max_memory=int(1e9), include_webui=False)
105+
init_ray(include_webui=False)
106+
# ray.init(object_store_memory=int(2e9), redis_max_memory=int(1e9), include_webui=False)
104107
# ray.init(object_store_memory=int(2e9), redis_max_memory=int(1e9), local_mode=True, plasma_directory='/tmp')
105108
else:
106-
ray.init(object_store_memory=int(2e9), redis_max_memory=int(1e9), local_mode=True, temp_dir='/tmp/ray' + str(args.config_num), include_webui=False)
109+
init_ray(local_mode=True, temp_dir='/tmp/ray' + str(args.config_num), include_webui=False)
110+
# ray.init(object_store_memory=int(2e9), redis_max_memory=int(1e9), local_mode=True, temp_dir='/tmp/ray' + str(args.config_num), include_webui=False)
107111

108112

109113
var_configs_deepcopy = copy.deepcopy(config.var_configs) #hack because this needs to be read in on_train_result and trying to read config there raises an error because it's been imported from a Python module and I think they try to reload it there.
@@ -468,4 +472,4 @@ def create_gym_env_wrapper_mujoco_wrapper(config, wrapped_mujoco_env):
468472

469473
if __name__ == '__main__':
470474
args = parse_args()
471-
main(args)
475+
main(args)

pyproject.toml

+4-1
Original file line numberDiff line numberDiff line change
@@ -57,11 +57,14 @@ scipy = { version = "1.3.0", optional = true }
5757
pillow = { version = "6.1.0", optional = true }
5858
tensorflow-probability = { version = "0.9.0", optional = true }
5959
mujoco-py = { version = "2.0.2.13", optional = true }
60+
tensorflow = { version = "2.2.0", optional = true }
61+
ray = { version = "1.3.0", optional = true }
6062
dill = "^0.3.3"
6163
colorama = "^0.4.4"
6264

6365
[tool.poetry.extras]
64-
extras_disc = ["pandas", "requests", "configspace", "scipy", "pillow", "pillow"]
66+
extras = ["pandas", "requests", "configspace", "scipy", "pillow", "tensorflow", "ray"]
67+
extras_disc = ["pandas", "requests", "configspace", "scipy", "pillow"]
6568
extras_cont = ["pandas", "requests", "configspace", "scipy", "tensorflow-probability", "mujoco-py"]
6669

6770
[tool.poetry.dev-dependencies]

setup.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
{'': ['*']}
1313

1414
extras_require = [
15-
'ray[rllib,debug]==0.7.3',
15+
'ray[rllib,debug]==1.3.0',
1616
'tensorflow==2.2.0',
1717
'pillow==6.1.0',
1818
'pandas==0.25.0',
@@ -88,7 +88,8 @@
8888
python_requires=">=3.6",
8989
install_requires=['gym<=0.14', 'dill'],
9090
extras_require={
91-
'extras_disc': extras_require,
91+
'extras': extras_require,
92+
'extras_disc': extras_require_disc,
9293
'extras_cont': extras_require_cont,
9394
},
9495

0 commit comments

Comments
 (0)