RecurrentNetwork
A recurrent network is a neural network that an agent uses to predict an action from an observation. You can implement your own recurrent network by inheriting from the abstract class RecurrentNetwork and implementing its abstract methods.
Module: aine_drl.net
class RecurrentNetwork(Network)
It inherits from Network. See the Network docs.
Methods
hidden_state_shape()
Returns the shape of the recurrent hidden state (D x num_layers, H).
num_layers: the number of recurrent layers
D: 2 if bidirectional, otherwise 1
H: the value depends on the type of the recurrent network
When you use LSTM, H = H_cell + H_out. See details in https://pytorch.org/docs/stable/generated/torch.nn.LSTM.html.
When you use GRU, H = H_out. See details in https://pytorch.org/docs/stable/generated/torch.nn.GRU.html.
@abstractmethod
def hidden_state_shape(self) -> tuple[int, int]
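As a rough illustration of the shape rule above, the following sketch computes (D x num_layers, H) for the LSTM and GRU cases. The helper names lstm_hidden_state_shape and gru_hidden_state_shape are hypothetical and not part of aine_drl; the parameter names (hidden_size, num_layers, bidirectional, proj_size) mirror the torch.nn.LSTM and torch.nn.GRU constructor arguments, and the sketch assumes H_cell equals hidden_size and H_out equals proj_size when a projection is used, otherwise hidden_size.

```python
from __future__ import annotations


def lstm_hidden_state_shape(
    hidden_size: int,
    num_layers: int = 1,
    bidirectional: bool = False,
    proj_size: int = 0,
) -> tuple[int, int]:
    """Hypothetical helper: (D x num_layers, H) with H = H_cell + H_out."""
    d = 2 if bidirectional else 1
    h_cell = hidden_size
    # Assumption: with a projection, the output state has size proj_size.
    h_out = proj_size if proj_size > 0 else hidden_size
    return (d * num_layers, h_cell + h_out)


def gru_hidden_state_shape(
    hidden_size: int,
    num_layers: int = 1,
    bidirectional: bool = False,
) -> tuple[int, int]:
    """Hypothetical helper: (D x num_layers, H) with H = H_out."""
    d = 2 if bidirectional else 1
    return (d * num_layers, hidden_size)


if __name__ == "__main__":
    # Two-layer unidirectional LSTM with hidden_size=128.
    print(lstm_hidden_state_shape(128, num_layers=2))
    # Single-layer bidirectional GRU with hidden_size=64.
    print(gru_hidden_state_shape(64, bidirectional=True))
```

A concrete hidden_state_shape() implementation in a RecurrentNetwork subclass could return a tuple computed this way from the settings of the recurrent layer it wraps.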