Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 60 additions & 35 deletions pytorch_forecasting/models/nn/rnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,26 +29,36 @@ def handle_no_encoding(
"""
Mask the hidden_state where there is no encoding.

Args:
hidden_state (HiddenState): hidden state where some entries need replacement
no_encoding (torch.BoolTensor): positions that need replacement
initial_hidden_state (HiddenState): hidden state to use for replacement

Returns:
HiddenState: hidden state with propagated initial hidden state where appropriate
""" # noqa: E501
Parameters
----------
hidden_state : HiddenState
Hidden state where some entries need replacement.
no_encoding : torch.BoolTensor
Positions that need replacement.
initial_hidden_state : HiddenState
Hidden state to use for replacement.

Returns
-------
HiddenState
Hidden state with propagated initial hidden state where appropriate.
"""
pass

@abstractmethod
def init_hidden_state(self, x: torch.Tensor) -> HiddenState:
"""
Initialise a hidden_state.

Args:
x (torch.Tensor): network input
Parameters
----------
x : torch.Tensor
Network input.

Returns:
HiddenState: default (zero-like) hidden state
Returns
-------
HiddenState
Default (zero-like) hidden state.
"""
pass

Expand All @@ -59,12 +69,17 @@ def repeat_interleave(
"""
Duplicate the hidden_state n_samples times.

Args:
hidden_state (HiddenState): hidden state to repeat
n_samples (int): number of repetitions

Returns:
HiddenState: repeated hidden state
Parameters
----------
hidden_state : HiddenState
Hidden state to repeat.
n_samples : int
Number of repetitions.

Returns
-------
HiddenState
Repeated hidden state.
"""
pass

Expand All @@ -80,19 +95,25 @@ def forward(

Functions as normal for RNN. Only changes output if lengths are defined.

Args:
x (Union[rnn.PackedSequence, torch.Tensor]): input to RNN. either packed sequence or tensor of
padded sequences
hx (HiddenState, optional): hidden state. Defaults to None.
lengths (torch.LongTensor, optional): lengths of sequences. If not None, used to determine correct returned
hidden state. Can contain zeros. Defaults to None.
enforce_sorted (bool, optional): if lengths are passed, determines if RNN expects them to be sorted.
Defaults to True.

Returns:
Tuple[Union[rnn.PackedSequence, torch.Tensor], HiddenState]: output and hidden state.
Output is packed sequence if input has been a packed sequence.
""" # noqa: E501
Parameters
----------
x : rnn.PackedSequence or torch.Tensor
Input to RNN. Either packed sequence or tensor of padded sequences.
hx : HiddenState, optional
Hidden state. Defaults to None.
lengths : torch.LongTensor, optional
Lengths of sequences. If not None, used to determine correct returned
hidden state. Can contain zeros. Defaults to None.
enforce_sorted : bool, optional
If lengths are passed, determines if RNN expects them to be sorted.
Defaults to True.

Returns
-------
tuple of (rnn.PackedSequence or torch.Tensor, HiddenState)
Output and hidden state. Output is a packed sequence if input
was a packed sequence.
"""
if isinstance(x, rnn.PackedSequence) or lengths is None:
assert (
lengths is None
Expand Down Expand Up @@ -230,11 +251,15 @@ def get_rnn(cell_type: type[RNN] | str) -> type[RNN]:
"""
Get LSTM or GRU.

Args:
cell_type (Union[RNN, str]): "LSTM" or "GRU"
Parameters
----------
cell_type : type[RNN] or str
RNN class or string identifier, either ``"LSTM"`` or ``"GRU"``.

Returns:
Type[RNN]: returns GRU or LSTM RNN module
Returns
-------
type[RNN]
Returns the GRU or LSTM RNN class.
"""
if isinstance(cell_type, RNN):
rnn = cell_type
Expand Down
Loading