Skip to content

Commit b7b58ad

Browse files
committed
switch to absolute imports
1 parent 1fb779a commit b7b58ad

File tree

16 files changed

+34
-20
lines changed

16 files changed

+34
-20
lines changed
File renamed without changes.
File renamed without changes.

examples/a2c_cartpole.py renamed to rl_baseline/examples/a2c_cartpole.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
1-
from registration import env_registry, method_registry, optimizer_registry
21
import torch
32
from torch import nn
43
import torch.nn.functional as f
5-
from methods.a2c import A2CModel
4+
5+
from rl_baseline.registration import env_registry, method_registry, optimizer_registry
6+
from rl_baseline.methods.a2c import A2CModel
67

78
# Define your own model. As long as it inherits from the compatible model class, the desired trainer (in this case, A2C) can use it.
89
class MyA2CModel(A2CModel):

rl_baseline/methods/__init__.py

Whitespace-only changes.

methods/a2c.py renamed to rl_baseline/methods/a2c.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,9 @@
77
from torch.autograd import Variable
88
import torch.nn.functional as f
99

10-
from core import StochasticPolicy, StateValue
11-
from registry import method_registry, model_registry, optimizer_registry
12-
from util import global_norm, log_format, write_tb_event
10+
from rl_baseline.core import StochasticPolicy, StateValue
11+
from rl_baseline.registry import method_registry, model_registry, optimizer_registry
12+
from rl_baseline.util import global_norm, log_format, write_tb_event
1313

1414

1515
# Set up logger

methods/dqn.py renamed to rl_baseline/methods/dqn.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,9 @@
88
import torch.nn.functional as f
99
from torch.nn.utils import clip_grad_norm
1010

11-
from core import StochasticPolicy, StateValue, ActionValue
12-
from registry import method_registry, model_registry, optimizer_registry
13-
from util import global_norm, log_format, write_tb_event, linear_schedule, copy_params
11+
from rl_baseline.core import StochasticPolicy, StateValue, ActionValue
12+
from rl_baseline.registry import method_registry, model_registry, optimizer_registry
13+
from rl_baseline.util import global_norm, log_format, write_tb_event, linear_schedule, copy_params
1414

1515

1616
# Set up logger

registration.py renamed to rl_baseline/registration.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
import inspect
2-
from core import GymEnvSpecWrapper
3-
from registry import env_registry, optimizer_registry, method_registry, model_registry
42
from gym.envs.registration import registry
53
from torch import optim
6-
from methods import a2c, dqn
4+
5+
from rl_baseline.registry import env_registry, optimizer_registry, method_registry, model_registry
6+
from rl_baseline.core import GymEnvSpecWrapper
7+
from rl_baseline.methods import a2c, dqn
78

89
# Register all envs in Gym
910
for spec in registry.all():
File renamed without changes.

tests/test_a2c.py renamed to rl_baseline/tests/test_a2c.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
def test_a2c_on_cartpole():
22
# Train a simple A2C linear model to solve CartPole-v0
3-
from methods.a2c import A2CTrainer, A2CLinearModel
43
import gym
54
import torch
65
import numpy as np
76
from torch import optim
87

8+
from rl_baseline.methods.a2c import A2CTrainer, A2CLinearModel
9+
910
# Fix seed for replication
1011
seed = 777
1112
torch.manual_seed(seed)

tests/test_core.py renamed to rl_baseline/tests/test_core.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from core import GymEnvSpecWrapper
1+
from rl_baseline.core import GymEnvSpecWrapper
22

33
def test_gym_env_spec_wrapper():
44
from gym.envs.registration import registry

tests/test_registration.py renamed to rl_baseline/tests/test_registration.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
from core import GymEnvSpecWrapper
2-
from registration import env_registry
1+
from rl_baseline.core import GymEnvSpecWrapper
2+
from rl_baseline.registration import env_registry
33

44
def test_gym_specs_wrapped():
55
spec = env_registry['gym.CartPole-v1']

tests/test_registry.py renamed to rl_baseline/tests/test_registry.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from registry import Registry
1+
from rl_baseline.registry import Registry
22

33
reg = Registry()
44

tests/test_util.py renamed to rl_baseline/tests/test_util.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from util import linear_schedule
1+
from rl_baseline.util import linear_schedule
22

33
def test_linear_scheduler():
44
assert linear_schedule(0.5, 0, 100, 200, 150) == 0.25, 'In-between value of linear scheduler.'

train.py renamed to rl_baseline/train.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
from __future__ import division
33
from __future__ import print_function
44

5+
from six.moves import xrange
6+
57
import logging, itertools
68
import numpy as np
79

@@ -14,9 +16,8 @@
1416
import gym
1517
gym.undo_logger_setup()
1618

17-
from registration import env_registry, optimizer_registry, model_registry, method_registry
18-
19-
from util import log_format, global_norm, get_cartpole_state, set_cartpole_state, copy_params
19+
from rl_baseline.registration import env_registry, optimizer_registry, model_registry, method_registry
20+
from rl_baseline.util import log_format, global_norm, get_cartpole_state, set_cartpole_state, copy_params
2021

2122
logging.basicConfig(format=log_format)
2223
logger = logging.getLogger(__name__)
File renamed without changes.

setup.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
from setuptools import setup, find_packages
2+
setup(
3+
name='rl_baseline',
4+
version='0.0.0',
5+
packages=find_packages(),
6+
author='Falcon Dai',
7+
author_email='[email protected]',
8+
keywords=['reinforcement learning', 'machine learning', 'pytorch', 'markov decision process'],
9+
description='PyTorch implementation of state-of-the-art reinforcement learning algorithms.',
10+
)

0 commit comments

Comments
 (0)