|
2 | 2 | # |
3 | 3 | # This source code is licensed under the MIT license found in the |
4 | 4 | # LICENSE file in the root directory of this source tree. |
5 | | -from __future__ import annotations |
| 5 | + |
| 6 | +# This makes omegaconf unhappy with typing.Any |
| 7 | +# Therefore we need Optional and Union |
| 8 | +# from __future__ import annotations |
6 | 9 |
|
7 | 10 | from copy import copy |
8 | 11 | from dataclasses import dataclass, field as dataclass_field |
9 | | -from typing import Any, Callable, Sequence |
| 12 | +from typing import Any, Callable, Optional, Sequence, Union |
10 | 13 |
|
11 | 14 | import torch |
| 15 | +from omegaconf import DictConfig |
12 | 16 |
|
13 | 17 | from torchrl._utils import logger as torchrl_logger, VERBOSE |
14 | 18 | from torchrl.envs import ParallelEnv |
@@ -212,18 +216,18 @@ def get_norm_state_dict(env): |
212 | 216 | def transformed_env_constructor( |
213 | 217 | cfg: DictConfig, # noqa: F821 |
214 | 218 | video_tag: str = "", |
215 | | - logger: Logger | None = None, |
216 | | - stats: dict | None = None, |
| 219 | + logger: Optional[Logger] = None, # noqa |
| 220 | + stats: Optional[dict] = None, |
217 | 221 | norm_obs_only: bool = False, |
218 | 222 | use_env_creator: bool = False, |
219 | | - custom_env_maker: Callable | None = None, |
220 | | - custom_env: EnvBase | None = None, |
| 223 | + custom_env_maker: Optional[Callable] = None, |
| 224 | + custom_env: Optional[EnvBase] = None, |
221 | 225 | return_transformed_envs: bool = True, |
222 | | - action_dim_gsde: int | None = None, |
223 | | - state_dim_gsde: int | None = None, |
224 | | - batch_dims: int | None = 0, |
225 | | - obs_norm_state_dict: dict | None = None, |
226 | | -) -> Callable | EnvCreator: |
| 226 | + action_dim_gsde: Optional[int] = None, |
| 227 | + state_dim_gsde: Optional[int] = None, |
| 228 | + batch_dims: Optional[int] = 0, |
| 229 | + obs_norm_state_dict: Optional[dict] = None, |
| 230 | +) -> Union[Callable, EnvCreator]: |
227 | 231 | """Returns an environment creator from an argparse.Namespace built with the appropriate parser constructor. |
228 | 232 |
|
229 | 233 | Args: |
@@ -329,7 +333,7 @@ def make_transformed_env(**kwargs) -> TransformedEnv: |
329 | 333 |
|
330 | 334 | def parallel_env_constructor( |
331 | 335 | cfg: DictConfig, **kwargs # noqa: F821 |
332 | | -) -> ParallelEnv | EnvCreator: |
| 336 | +) -> Union[ParallelEnv, EnvCreator]: |
333 | 337 | """Returns a parallel environment from an argparse.Namespace built with the appropriate parser constructor. |
334 | 338 |
|
335 | 339 | Args: |
@@ -374,7 +378,7 @@ def parallel_env_constructor( |
374 | 378 | def get_stats_random_rollout( |
375 | 379 | cfg: DictConfig, # noqa: F821 |
376 | 380 | proof_environment: EnvBase = None, |
377 | | - key: str | None = None, |
| 381 | + key: Optional[str] = None, |
378 | 382 | ): |
379 | 383 | """Gathers stas (loc and scale) from an environment using random rollouts. |
380 | 384 |
|
@@ -452,7 +456,7 @@ def get_stats_random_rollout( |
452 | 456 | def initialize_observation_norm_transforms( |
453 | 457 | proof_environment: EnvBase, |
454 | 458 | num_iter: int = 1000, |
455 | | - key: str | tuple[str, ...] = None, |
| 459 | + key: Union[str, tuple[str, ...]] = None, |
456 | 460 | ): |
457 | 461 | """Calls :obj:`ObservationNorm.init_stats` on all uninitialized :obj:`ObservationNorm` instances of a :obj:`TransformedEnv`. |
458 | 462 |
|
@@ -532,7 +536,7 @@ class EnvConfig: |
532 | 536 | # maximum steps per trajectory, frames per batch or any other factor in the algorithm, |
533 | 537 | # e.g. if the total number of frames that has to be computed is 50e6 and the frame skip is 4 |
534 | 538 | # the actual number of frames retrieved will be 200e6. Default=1. |
535 | | - reward_scaling: float | None = None |
| 539 | + reward_scaling: Any = None # noqa |
536 | 540 | # scale of the reward. |
537 | 541 | reward_loc: float = 0.0 |
538 | 542 | # location of the reward. |
|
0 commit comments