|
17 | 17 | ModelCatalog.register_custom_preprocessor("ohe", OneHotPreprocessor)
|
18 | 18 |
|
19 | 19 |
|
| 20 | +# def init_ray(log_level=None, tmp_dir=None, include_webui=None, |
| 21 | +# object_store_memory=int(2e9), |
| 22 | +# redis_max_memory=int(1e9), local_mode=False): |
| 23 | +def init_ray(**kwargs): |
| 24 | + import ray |
| 25 | + if ray.__version__[0] == '1': # new version 1.0 API |
| 26 | + if "redis_max_memory" in kwargs: |
| 27 | + value = kwargs["redis_max_memory"] |
| 28 | + del kwargs["redis_max_memory"] |
| 29 | + kwargs["_redis_max_memory"] = value |
| 30 | + if "tmp_dir" in kwargs: |
| 31 | + value = kwargs["tmp_dir"] |
| 32 | + del kwargs["tmp_dir"] |
| 33 | + kwargs["_temp_dir"] = value |
| 34 | + |
| 35 | + if "log_level" in kwargs: |
| 36 | + value = kwargs["log_level"] |
| 37 | + del kwargs["log_level"] |
| 38 | + kwargs["logging_level"] = value |
| 39 | + |
| 40 | + ray.init(**kwargs) |
| 41 | + |
| 42 | + |
20 | 43 | def process_configs(config_file, stats_file_prefix, config_num, log_level,
|
21 | 44 | framework='ray', framework_dir='/tmp/ray'):
|
22 | 45 | config_file_path = os.path.abspath('/'.join(config_file.split('/')[:-1]))
|
@@ -129,30 +152,34 @@ def process_configs(config_file, stats_file_prefix, config_num, log_level,
|
129 | 152 |
|
130 | 153 |
|
131 | 154 | def setup_ray(config, config_num, log_level, framework_dir):
|
132 |
| - import ray |
133 | 155 | tmp_dir = framework_dir + '/tmp_' + str(config_num)
|
| 156 | + # import ray |
134 | 157 | if config.algorithm == 'DQN': #hack
|
135 |
| - ray.init(object_store_memory=int(2e9), redis_max_memory=int(1e9), |
136 |
| - temp_dir=tmp_dir, |
137 |
| - logging_level=log_level, |
138 |
| - # local_mode=True, |
139 |
| - # webui_host='0.0.0.0'); logging_level=logging.INFO, |
140 |
| - ) |
| 158 | + init_ray(log_level=log_level, tmp_dir=tmp_dir) |
| 159 | + # ray.init(object_store_memory=int(2e9), redis_max_memory=int(1e9), |
| 160 | + # temp_dir=tmp_dir, |
| 161 | + # logging_level=log_level, |
| 162 | + # # local_mode=True, |
| 163 | + # # webui_host='0.0.0.0'); logging_level=logging.INFO, |
| 164 | + # ) |
141 | 165 | # ray.init(object_store_memory=int(2e9), redis_max_memory=int(1e9), local_mode=True, plasma_directory='/tmp') #, memory=int(8e9), local_mode=True # local_mode (bool): If true, the code will be executed serially. This is useful for debugging. # when true on_train_result and on_episode_end operate in the same current directory as the script. A3C is crashing in local mode, so didn't use it and had to work around by giving full path + filename in stats_file_name.; also has argument driver_object_store_memory=, plasma_directory='/tmp'
|
142 | 166 | elif config.algorithm == 'A3C': #hack
|
143 |
| - ray.init(object_store_memory=int(2e9), redis_max_memory=int(1e9), |
144 |
| - temp_dir=tmp_dir, |
145 |
| - logging_level=log_level, |
146 |
| - # local_mode=True, |
147 |
| - # webui_host='0.0.0.0'); logging_level=logging.INFO, |
148 |
| - ) # ray.init(object_store_memory=int(2e9), redis_max_memory=int(1e9), local_mode=True, plasma_directory='/tmp') |
| 167 | + init_ray(log_level=log_level, tmp_dir=tmp_dir) |
| 168 | + # ray.init(object_store_memory=int(2e9), redis_max_memory=int(1e9), |
| 169 | + # temp_dir=tmp_dir, |
| 170 | + # logging_level=log_level, |
| 171 | + # # local_mode=True, |
| 172 | + # # webui_host='0.0.0.0'); logging_level=logging.INFO, |
| 173 | + # ) # ray.init(object_store_memory=int(2e9), redis_max_memory=int(1e9), local_mode=True, plasma_directory='/tmp') |
149 | 174 | else:
|
150 |
| - ray.init(object_store_memory=int(2e9), redis_max_memory=int(1e9), |
151 |
| - temp_dir=tmp_dir, |
152 |
| - logging_level=log_level, |
153 |
| - local_mode=True, |
154 |
| - # webui_host='0.0.0.0'); logging_level=logging.INFO, |
155 |
| - ) |
| 175 | + init_ray(log_level=log_level, tmp_dir=tmp_dir, local_mode=True) |
| 176 | + |
| 177 | + # ray.init(object_store_memory=int(2e9), redis_max_memory=int(1e9), |
| 178 | + # temp_dir=tmp_dir, |
| 179 | + # logging_level=log_level, |
| 180 | + # local_mode=True, |
| 181 | + # # webui_host='0.0.0.0'); logging_level=logging.INFO, |
| 182 | + # ) |
156 | 183 |
|
157 | 184 | def init_stats_file(stats_file_name, columns_to_write):
|
158 | 185 | fout = open(stats_file_name, 'a') #hardcoded
|
|
0 commit comments