11 changes: 8 additions & 3 deletions alf/algorithms/data_transformer.py
@@ -137,13 +137,18 @@ def __init__(self, data_transformer_ctors, observation_spec):

@staticmethod
def _validate_order(data_transformers):
# Hindsight should probably not be used together with FrameStacker,
# unless done really carefully. Hindsight after FrameStacker is
# simply wrong, because Hindsight would read ``achieved_goal`` field
# of a future step directly from the replay buffer without stacking.
def _tier_of(data_transformer):
if isinstance(data_transformer, UntransformedTimeStep):
return 1
if isinstance(data_transformer,
(HindsightExperienceTransformer, FrameStacker)):
if isinstance(data_transformer, HindsightExperienceTransformer):
return 2
return 3
if isinstance(data_transformer, FrameStacker):
return 3
return 4

prev_tier = 0
for i in range(len(data_transformers)):
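The new tiers make the intended pipeline order explicit: UntransformedTimeStep first, then HindsightExperienceTransformer, then FrameStacker, then all remaining transformers. Below is a minimal, self-contained sketch of this tier-based ordering check; the placeholder classes and the exact assertion are illustrative, not the ALF implementation.

```python
# Illustrative stand-ins for the real ALF transformer classes.
class UntransformedTimeStep: pass
class HindsightExperienceTransformer: pass
class FrameStacker: pass
class ObservationNormalizer: pass


def _tier_of(transformer):
    # Lower tiers must appear earlier in the transformer list.
    if isinstance(transformer, UntransformedTimeStep):
        return 1
    if isinstance(transformer, HindsightExperienceTransformer):
        return 2
    if isinstance(transformer, FrameStacker):
        return 3
    return 4


def validate_order(transformers):
    prev_tier = 0
    for t in transformers:
        tier = _tier_of(t)
        assert tier >= prev_tier, (
            "%s must not come after a later-tier transformer" %
            type(t).__name__)
        prev_tier = tier


# OK: Hindsight is applied before FrameStacker and the normalizer.
validate_order([UntransformedTimeStep(), HindsightExperienceTransformer(),
                FrameStacker(), ObservationNormalizer()])

# Not OK: Hindsight after FrameStacker would read the un-stacked
# ``achieved_goal`` of a future step, so the check rejects it.
# validate_order([FrameStacker(), HindsightExperienceTransformer()])
```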
6 changes: 5 additions & 1 deletion alf/algorithms/rl_algorithm.py
@@ -223,19 +223,23 @@ def __init__(self,
replay_buffer_length = adjust_replay_buffer_length(
config, self._num_earliest_frames_ignored)

total_replay_size = replay_buffer_length * self._env.batch_size
if config.whole_replay_buffer_training and config.clear_replay_buffer:
# For whole replay buffer training, we would like to be sure
# that the replay buffer has enough samples in it to perform
# the training, which will most likely happen in the 2nd
# iteration. The minimum_initial_collect_steps guarantees that.
minimum_initial_collect_steps = replay_buffer_length * self._env.batch_size
minimum_initial_collect_steps = total_replay_size
if config.initial_collect_steps < minimum_initial_collect_steps:
common.info(
'Set initial_collect_steps to the minimum required '
f'value {minimum_initial_collect_steps} because '
'whole_replay_buffer_training is on.')
config.initial_collect_steps = minimum_initial_collect_steps

assert config.initial_collect_steps <= total_replay_size, \
"Training will not happen - insufficient replay buffer size."

self.set_replay_buffer(self._env.batch_size, replay_buffer_length,
config.priority_replay)

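A worked example of the sizing logic, with assumed numbers that are not from this PR: with replay_buffer_length = 1000 and 8 parallel environments, the buffer holds 8000 steps in total, so whole-replay-buffer training raises initial_collect_steps to 8000, and any value above 8000 trips the new assert because the buffer could never report that many stored steps.

```python
# Assumed configuration values, purely for illustration.
replay_buffer_length = 1000      # steps kept per environment
env_batch_size = 8               # number of parallel environments
total_replay_size = replay_buffer_length * env_batch_size   # 8000

initial_collect_steps = 500      # user-configured value (assumed)
whole_replay_buffer_training = True
clear_replay_buffer = True

if whole_replay_buffer_training and clear_replay_buffer:
    # Make sure the buffer is completely filled before training starts.
    initial_collect_steps = max(initial_collect_steps, total_replay_size)

# Mirrors the new assert: a buffer of 8000 steps can never hold more than
# 8000 steps, so a larger initial_collect_steps would stall training forever.
assert initial_collect_steps <= total_replay_size
print(initial_collect_steps)     # 8000
```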
2 changes: 2 additions & 0 deletions alf/config_util.py
@@ -297,6 +297,8 @@ def config1(config_name, value, mutable=True, raise_if_used=True):
config_node = _get_config_node(config_name)

if raise_if_used and config_node.is_used():
# Log error because pre_config catches and silences the ValueError.
logging.error("Config '%s' used before configured." % config_name)
raise ValueError(
"Config '%s' has already been used. You should config "
"its value before using it." % config_name)
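Pairing logging.error with the raise matters because pre_config (per the comment above) catches the ValueError, so without the log a too-late configuration would fail silently. The following is a hedged, self-contained sketch of that pattern using stand-in functions rather than the real alf.config_util API.

```python
import logging

_used = set()


def use_config(name):
    # Stand-in for reading a config value; marks it as used.
    _used.add(name)


def config1(name, value):
    # Stand-in for alf.config1: configuring a value after it was used is an error.
    if name in _used:
        # Log first, because the caller below silences the ValueError.
        logging.error("Config '%s' used before configured.", name)
        raise ValueError("Config '%s' has already been used." % name)


def pre_config(configs):
    # Stand-in for a caller that tolerates failures by catching ValueError;
    # the logged error is then the only visible signal of the mistake.
    for name, value in configs.items():
        try:
            config1(name, value)
        except ValueError:
            pass


use_config("lr")
pre_config({"lr": 1e-3})   # emits an ERROR log instead of raising
```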
3 changes: 2 additions & 1 deletion alf/data_structures.py
@@ -279,7 +279,8 @@ def _generate_time_step(batched,
if env_id is None:
env_id = md.arange(batch_size, dtype=md.int32)
if reward is not None:
assert reward.shape[:1] == outer_dims
assert reward.shape[:1] == outer_dims, "%s, %s" % (reward.shape,
outer_dims)
if prev_action is not None:
flat_action = nest.flatten(prev_action)
assert flat_action[0].shape[:1] == outer_dims
4 changes: 3 additions & 1 deletion alf/networks/critic_networks.py
@@ -77,6 +77,7 @@ def __init__(self,
joint_fc_layer_params=None,
activation=torch.relu_,
kernel_initializer=None,
last_bias_init_value=0.0,
use_fc_bn=False,
use_naive_parallel_network=False,
name="CriticNetwork"):
@@ -174,7 +175,8 @@ def __init__(self,
last_activation=math_ops.identity,
use_fc_bn=use_fc_bn,
last_kernel_initializer=last_kernel_initializer,
name=name)
last_bias_init_value=last_bias_init_value,
name=name + ".joint_encoder")
self._use_naive_parallel_network = use_naive_parallel_network

def make_parallel(self, n):
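A hedged usage sketch for the new last_bias_init_value argument. The spec shapes, layer sizes, and the rationale of shifting the initial critic output are assumptions for illustration; they are not taken from this diff.

```python
import torch
import alf
from alf.networks.critic_networks import CriticNetwork

# Assumed specs: a flat 8-dim observation and a 2-dim continuous action.
observation_spec = alf.TensorSpec((8, ), torch.float32)
action_spec = alf.BoundedTensorSpec((2, ), torch.float32, -1., 1.)

critic = CriticNetwork(
    input_tensor_spec=(observation_spec, action_spec),
    joint_fc_layer_params=(256, 256),
    # New in this PR: initialize the bias of the last (output) layer,
    # e.g. to start the value estimate away from zero.
    last_bias_init_value=1.0)

obs = observation_spec.zeros(outer_dims=(4, ))
act = action_spec.zeros(outer_dims=(4, ))
value, _ = critic((obs, act))   # one critic value per batch element
```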
4 changes: 3 additions & 1 deletion alf/networks/encoding_networks.py
@@ -405,6 +405,7 @@ def __init__(self,
last_layer_size=None,
last_activation=None,
last_kernel_initializer=None,
last_bias_init_value=0.0,
last_use_fc_bn=False,
name="EncodingNetwork"):
"""
@@ -540,7 +541,8 @@ def __init__(self,
last_layer_size,
activation=last_activation,
use_bn=last_use_fc_bn,
kernel_initializer=last_kernel_initializer))
kernel_initializer=last_kernel_initializer,
bias_init_value=last_bias_init_value))
input_size = last_layer_size

if output_tensor_spec is not None:
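The same argument threads through EncodingNetwork, where it only takes effect when last_layer_size is given. A minimal sketch with made-up spec and layer sizes; the identity lambda stands in for the library's identity activation.

```python
import torch
import alf
from alf.networks.encoding_networks import EncodingNetwork

net = EncodingNetwork(
    input_tensor_spec=alf.TensorSpec((16, ), torch.float32),
    fc_layer_params=(64, 64),
    last_layer_size=1,
    last_activation=lambda x: x,   # stand-in for an identity activation
    # New in this PR: forwarded to the final FC layer's bias_init_value.
    last_bias_init_value=0.5)

x = torch.zeros(4, 16)
y, _ = net(x)
# With zero inputs and zero-initialized hidden biases, y starts at 0.5.
```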
1 change: 1 addition & 0 deletions alf/trainers/policy_trainer.py
@@ -498,6 +498,7 @@ def __init__(self, config: TrainerConfig, ddp_rank: int = -1):
logging.info(
"observation_spec=%s" % pprint.pformat(env.observation_spec()))
logging.info("action_spec=%s" % pprint.pformat(env.action_spec()))
logging.info("reward_spec=%s" % pprint.pformat(env.reward_spec()))

# for offline buffer construction
untransformed_observation_spec = env.observation_spec()
27 changes: 20 additions & 7 deletions alf/utils/data_buffer_test.py
@@ -29,7 +29,7 @@
DataItem = alf.data_structures.namedtuple(
"DataItem", [
"env_id", "x", "o", "reward", "step_type", "batch_info",
"replay_buffer", "rollout_info_field"
"replay_buffer", "rollout_info_field", "discount"
],
default_value=())

@@ -40,12 +40,20 @@ def get_batch(env_ids, dim, t, x):
batch_size = len(env_ids)
x = torch.as_tensor(x, dtype=torch.float32, device="cpu")
t = torch.as_tensor(t, dtype=torch.int32, device="cpu")
ox = (x * torch.arange(
batch_size, dtype=torch.float32, requires_grad=True,
device="cpu").unsqueeze(1) * torch.arange(
dim, dtype=torch.float32, requires_grad=True,
device="cpu").unsqueeze(0))
a = x * torch.ones(batch_size, dtype=torch.float32, device="cpu")

# We allow the x and t inputs to be scalars, which will be expanded to
# match the batch_size.

def _need_to_expand(x):
return not (batch_size > 1 and x.ndim > 0 and batch_size == x.shape[0])

if _need_to_expand(x):
a = x * torch.ones(batch_size, dtype=torch.float32, device="cpu")
else:
a = x
if _need_to_expand(t):
t = t * torch.ones(batch_size, dtype=torch.int32, device="cpu")
ox = a.unsqueeze(1).clone().requires_grad_(True)
g = torch.zeros(batch_size, dtype=torch.float32, device="cpu")
# reward function adapted from ReplayBuffer: default_reward_fn
r = torch.where(
@@ -60,6 +68,10 @@ def get_batch(env_ids, dim, t, x):
"a": a,
"g": g
}),
discount=torch.tensor(
t != alf.data_structures.StepType.LAST,
dtype=torch.float32,
device="cpu"),
reward=r)
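
With this change, get_batch accepts either scalar or per-environment t and x. A hedged usage sketch with made-up values, assuming the helper above and the test file's imports; the discount follows the new ``t != StepType.LAST`` rule.

```python
# Hypothetical calls illustrating the scalar-vs-batched handling above.
env_ids = [0, 1, 2, 3]

# Scalar t and x: both are broadcast to tensors of length len(env_ids).
batch = get_batch(env_ids, dim=5, t=alf.data_structures.StepType.MID, x=0.25)

# Per-environment t: the LAST step gets discount 0., the others 1.
t = torch.tensor([1, 1, 2, 1], dtype=torch.int32)  # 2 == StepType.LAST
batch = get_batch(env_ids, dim=5, t=t, x=0.25)
assert batch.discount.tolist() == [1., 1., 0., 1.]
```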


@@ -79,6 +91,7 @@ def __init__(self, *args):
"a": alf.TensorSpec(shape=(), dtype=torch.float32),
"g": alf.TensorSpec(shape=(), dtype=torch.float32)
}),
discount=alf.TensorSpec(shape=(), dtype=torch.float32),
reward=alf.TensorSpec(shape=(), dtype=torch.float32))

@parameterized.named_parameters([
2 changes: 2 additions & 0 deletions alf/utils/external_configurables.py
@@ -46,3 +46,5 @@

gin.external_configurable(torch.nn.init.xavier_normal_,
'torch.nn.init.xavier_normal_')
gin.external_configurable(torch.nn.Embedding, 'torch.nn.Embedding')
gin.external_configurable(torch.nn.Sequential, 'torch.nn.Sequential')
6 changes: 5 additions & 1 deletion alf/utils/normalizers.py
@@ -138,7 +138,11 @@ def _summary(name, val):
def _summarize_all(path, t, m2, m):
if path:
path += "."
spec = TensorSpec.from_tensor(m if m2 is None else m2)
if m2 is not None:
spec = TensorSpec.from_tensor(m2)
else:
assert m is not None
spec = TensorSpec.from_tensor(m)
_summary(path + "tensor.batch_min",
_reduce_along_batch_dims(t, spec, torch.min))
_summary(path + "tensor.batch_max",