73 changes: 73 additions & 0 deletions alf/layers.py
@@ -2507,6 +2507,79 @@ def _conv_transpose_2d(in_channels,
bias=bias)


@alf.configurable
class ResidualFCBlock(nn.Module):
r"""The Residual block with FC layers.

This is the residual feedforward block used in the following paper, where
it replaces plain MLP layers.

::

Lee et al. "SimBa: Simplicity Bias for Scaling Up Parameters in Deep Reinforcement Learning", arXiv:2410.09754

The block is defined as

:math:`x_{out} = x_{in} + \mathrm{MLP}(\mathrm{LayerNorm}(x_{in}))`

where :math:`\mathrm{MLP}` is a two-layer MLP (FC, activation, FC). If
``use_output_ln`` is True, an additional LayerNorm is applied to the output.

"""

def __init__(self,
input_size: int,
output_size: int,
hidden_size: Optional[int] = None,
use_bias: bool = True,
use_output_ln: bool = False,
activation: Callable = torch.relu_,
kernel_initializer: Callable[[Tensor],
None] = nn.init.kaiming_normal_,
bias_init_value: float = 0.0):
"""
Args:
input_size (int): input size
output_size (int): output size. It should equal ``input_size`` so that
the residual addition in ``forward`` is valid.
hidden_size (int): size of the hidden layer. If None, ``input_size`` is used.
use_bias (bool): whether to use bias for the FC layers.
use_output_ln (bool): whether to apply a LayerNorm to the output of the block.
activation (Callable): activation for the hidden layer.
kernel_initializer (Callable): initializer for the FC layer kernel.
bias_init_value (float): a constant for the initial FC bias value.
"""
super().__init__()
self._use_output_ln = use_output_ln
if hidden_size is None:
hidden_size = input_size
fc1 = FC(input_size,
hidden_size,
activation=activation,
use_bias=use_bias,
kernel_initializer=kernel_initializer,
bias_init_value=bias_init_value)
fc2 = FC(hidden_size,
output_size,
use_bias=use_bias,
kernel_initializer=kernel_initializer,
bias_init_value=bias_init_value)
self._core_layers = nn.Sequential(fc1, fc2)
self._ln = nn.LayerNorm(input_size)
if use_output_ln:
self._out_ln = nn.LayerNorm(output_size)

def reset_parameters(self):
self._ln.reset_parameters()
for layer in self._core_layers:
layer.reset_parameters()
if self._use_output_ln:
self._out_ln.reset_parameters()

def forward(self, inputs: torch.Tensor) -> torch.Tensor:
core_inputs = self._ln(inputs)
core = self._core_layers(core_inputs)
outputs = core + inputs
if self._use_output_ln:
outputs = self._out_ln(outputs)
return outputs
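
For reference, a minimal usage sketch of ``ResidualFCBlock`` as defined above; the sizes and batch shape below are made up for illustration:

import torch
import alf.layers as layers

# Map 256-dim features to 256-dim features through a 512-dim hidden layer.
# LayerNorm is applied to the input before the two FC layers, and to the
# output because use_output_ln=True.
block = layers.ResidualFCBlock(
    input_size=256, output_size=256, hidden_size=512, use_output_ln=True)

x = torch.randn(32, 256)   # a batch of 32 feature vectors
y = block(x)               # y = LayerNorm(x + MLP(LayerNorm(x)))
assert y.shape == (32, 256)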


@alf.configurable(whitelist=[
'with_batch_normalization', 'bn_ctor', 'weight_opt_args', 'activation'
])
16 changes: 16 additions & 0 deletions alf/networks/actor_distribution_networks.py
@@ -138,6 +138,10 @@ def __init__(self,
preprocessing_combiner=None,
conv_layer_params=None,
fc_layer_params=None,
use_residual_fc_block=False,
num_residual_fc_blocks=1,
residual_fc_block_hidden_size=None,
residual_fc_block_use_output_ln=True,
activation=torch.relu_,
kernel_initializer=None,
use_fc_bn=False,
@@ -173,6 +177,14 @@ def __init__(self,
where ``padding`` is optional.
fc_layer_params (tuple[int]): a tuple of integers representing hidden
FC layer sizes.
use_residual_fc_block (bool): whether to use residual FC blocks instead of
plain FC layers. If True, ``fc_layer_params`` is ignored.
num_residual_fc_blocks (int): number of residual FC blocks; only valid
if ``use_residual_fc_block`` is True.
residual_fc_block_hidden_size (int): hidden size of the residual FC blocks;
only valid if ``use_residual_fc_block`` is True.
residual_fc_block_use_output_ln (bool): whether to apply layer norm to the
output of each residual FC block; only valid if ``use_residual_fc_block`` is True.
activation (nn.functional): activation used for hidden layers.
kernel_initializer (Callable): initializer for all the layers
excluding the projection net. If none is provided a default
@@ -201,6 +213,10 @@ def __init__(self,
preprocessing_combiner=preprocessing_combiner,
conv_layer_params=conv_layer_params,
fc_layer_params=fc_layer_params,
use_residual_fc_block=use_residual_fc_block,
num_residual_fc_blocks=num_residual_fc_blocks,
residual_fc_block_hidden_size=residual_fc_block_hidden_size,
residual_fc_block_use_output_ln=residual_fc_block_use_output_ln,
activation=activation,
kernel_initializer=kernel_initializer,
use_fc_bn=use_fc_bn,
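A hedged sketch of how the new kwargs added to ``ActorDistributionNetwork`` above might be passed in; the observation/action specs and sizes are illustrative assumptions, not part of this diff:

import alf
from alf.networks import ActorDistributionNetwork

# Hypothetical specs, for illustration only.
obs_spec = alf.TensorSpec((64, ))
action_spec = alf.BoundedTensorSpec((4, ), minimum=-1.0, maximum=1.0)

actor_net = ActorDistributionNetwork(
    input_tensor_spec=obs_spec,
    action_spec=action_spec,
    fc_layer_params=None,                     # ignored when residual blocks are used
    use_residual_fc_block=True,
    num_residual_fc_blocks=2,
    residual_fc_block_hidden_size=256,
    residual_fc_block_use_output_ln=True)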
48 changes: 33 additions & 15 deletions alf/networks/critic_networks.py
@@ -79,6 +79,10 @@ def __init__(self,
action_fc_layer_params=None,
observation_action_combiner=None,
joint_fc_layer_params=None,
joint_use_residual_fc_block=False,
joint_num_residual_fc_blocks=1,
joint_residual_fc_block_hidden_size=None,
joint_residual_fc_block_use_output_ln=True,
activation=torch.relu_,
kernel_initializer=None,
use_fc_bn=False,
@@ -124,6 +128,14 @@ def __init__(self,
joint_fc_layer_params (tuple[int]): a tuple of integers representing
hidden sizes of the FC layers after merging observations and
actions.
joint_use_residual_fc_block (bool): whether to use residual FC blocks
instead of plain FC layers after merging observations and actions. If
True, ``joint_fc_layer_params`` is ignored.
joint_num_residual_fc_blocks (int): number of joint residual FC blocks;
only valid if ``joint_use_residual_fc_block`` is True.
joint_residual_fc_block_hidden_size (int): hidden size of the residual FC
blocks; only valid if ``joint_use_residual_fc_block`` is True.
joint_residual_fc_block_use_output_ln (bool): whether to apply layer norm
to the output of each joint residual FC block; only valid if
``joint_use_residual_fc_block`` is True.
activation (nn.functional): activation used for hidden layers. The
last layer will not be activated.
kernel_initializer (Callable): initializer for all the layers but
@@ -184,21 +196,27 @@ def __init__(self,
if observation_action_combiner is None:
observation_action_combiner = alf.layers.NestConcat(dim=-1)

super().__init__(input_tensor_spec=input_tensor_spec,
output_tensor_spec=output_tensor_spec,
input_preprocessors=(obs_encoder, action_encoder),
preprocessing_combiner=observation_action_combiner,
fc_layer_params=joint_fc_layer_params,
activation=activation,
kernel_initializer=kernel_initializer,
use_fc_bn=use_fc_bn,
use_fc_ln=use_fc_ln,
last_layer_size=output_tensor_spec.numel,
last_activation=last_layer_activation,
last_kernel_initializer=last_kernel_initializer,
last_use_fc_bn=last_use_fc_bn,
last_use_fc_ln=last_use_fc_ln,
name=name)
super().__init__(
input_tensor_spec=input_tensor_spec,
output_tensor_spec=output_tensor_spec,
input_preprocessors=(obs_encoder, action_encoder),
preprocessing_combiner=observation_action_combiner,
fc_layer_params=joint_fc_layer_params,
use_residual_fc_block=joint_use_residual_fc_block,
num_residual_fc_blocks=joint_num_residual_fc_blocks,
residual_fc_block_hidden_size=joint_residual_fc_block_hidden_size,
residual_fc_block_use_output_ln=
joint_residual_fc_block_use_output_ln,
activation=activation,
kernel_initializer=kernel_initializer,
use_fc_bn=use_fc_bn,
use_fc_ln=use_fc_ln,
last_layer_size=output_tensor_spec.numel,
last_activation=last_layer_activation,
last_kernel_initializer=last_kernel_initializer,
last_use_fc_bn=last_use_fc_bn,
last_use_fc_ln=last_use_fc_ln,
name=name)
self._use_naive_parallel_network = use_naive_parallel_network

def make_parallel(self, n):
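Similarly, the joint residual-block options added to ``CriticNetwork`` above could be set through ALF's configuration mechanism; a hedged sketch, assuming ``CriticNetwork`` is registered as an ``alf.configurable`` (the sizes are illustrative):

import alf

alf.config(
    'CriticNetwork',
    joint_fc_layer_params=None,               # ignored when residual blocks are used
    joint_use_residual_fc_block=True,
    joint_num_residual_fc_blocks=2,
    joint_residual_fc_block_hidden_size=256,
    joint_residual_fc_block_use_output_ln=True)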
33 changes: 32 additions & 1 deletion alf/networks/encoding_networks.py
@@ -619,6 +619,10 @@ def __init__(self,
preprocessing_combiner=None,
conv_layer_params=None,
fc_layer_params=None,
use_residual_fc_block=False,
num_residual_fc_blocks=1,
residual_fc_block_hidden_size=None,
residual_fc_block_use_output_ln=True,
activation=torch.relu_,
kernel_initializer=None,
use_fc_bn=False,
@@ -668,6 +672,14 @@ def __init__(self,
where ``padding`` is optional.
fc_layer_params (tuple[int]): a tuple of integers
representing FC layer sizes.
use_residual_fc_block (bool): whether to use residual FC blocks instead of
plain FC layers. If True, ``fc_layer_params`` is ignored.
num_residual_fc_blocks (int): number of residual FC blocks; only valid
if ``use_residual_fc_block`` is True.
residual_fc_block_hidden_size (int): hidden size of the residual FC blocks;
if None, the size of the input to the blocks is used. Only valid if
``use_residual_fc_block`` is True.
residual_fc_block_use_output_ln (bool): whether to apply layer norm to the
output of each residual FC block; only valid if ``use_residual_fc_block`` is True.
activation (nn.functional): activation used for all the layers but
the last layer.
kernel_initializer (Callable): initializer for all the layers but
@@ -766,7 +778,7 @@ def __init__(self,
f"The input shape {spec.shape} should be like (N, )"
"or (N, D, ).")

if fc_layer_params is None:
if fc_layer_params is None or use_residual_fc_block:
fc_layer_params = []
else:
assert isinstance(fc_layer_params, tuple)
@@ -790,6 +802,25 @@ def __init__(self,
kernel_initializer=kernel_initializer))
input_size = size

if use_residual_fc_block:
if residual_fc_block_hidden_size is None:
residual_fc_block_hidden_size = input_size
nets.append(
fc_layer_ctor(input_size,
residual_fc_block_hidden_size,
activation=activation,
use_bn=use_fc_bn,
use_ln=use_fc_ln,
kernel_initializer=kernel_initializer))
input_size = residual_fc_block_hidden_size
for _ in range(num_residual_fc_blocks):
nets.append(
layers.ResidualFCBlock(
input_size,
residual_fc_block_hidden_size,
use_output_ln=residual_fc_block_use_output_ln))
input_size = residual_fc_block_hidden_size

if last_layer_size is not None or last_activation is not None:
assert last_layer_size is not None and last_activation is not None, \
"Both last_layer_size and last_activation need to be specified!"
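To summarize the construction above: when ``use_residual_fc_block`` is True, ``fc_layer_params`` is ignored, a single FC layer first projects the (flattened) input to ``residual_fc_block_hidden_size``, and ``num_residual_fc_blocks`` residual blocks are stacked on top of it. A rough sketch of the resulting stack for a flat 64-dim input with two blocks of hidden size 256 (BN/LN flags and the optional last layer omitted; the sizes are illustrative):

import torch
import torch.nn as nn
import alf.layers as layers

stack = nn.Sequential(
    layers.FC(64, 256, activation=torch.relu_),   # projection to the block size
    layers.ResidualFCBlock(256, 256),             # residual block 1
    layers.ResidualFCBlock(256, 256))             # residual block 2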