`None` | `cpu` | `cuda` | `cuda:0` and other devices of the `torch.device()` argument

class: `A2CSharedNetwork`
Note that since it uses the Actor-Critic architecture with parameter sharing, the encoding layer must be shared between the Actor and the Critic.
You need to implement the method below.
```python
@abstractmethod
def forward(
    self,
    obs: Observation
) -> tuple[PolicyDist, Tensor]
```
Parameters:

|Name|Description|Shape|
|---|---|---|
|obs (`Observation`)|Observation batch tuple.|`*batch_shape` = `(batch_size,)`, details in `Observation` docs|
Returns:

|Name|Description|Shape|
|---|---|---|
|policy_dist (`PolicyDist`)|Policy distribution \(\pi(a \vert s)\).|`*batch_shape` = `(batch_size,)`, details in `PolicyDist` docs|
|state_value (`Tensor`)|State value \(V(s)\).|`(batch_size, 1)`|
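
For illustration, here is a minimal sketch of a network satisfying this contract, written in plain PyTorch. The class name and constructor arguments (`SharedNetSketch`, `obs_features`, `num_actions`) are hypothetical, `torch.distributions.Categorical` stands in for `PolicyDist`, and a single tensor stands in for the `Observation` tuple; see the `Observation` and `PolicyDist` docs for the real types.

```python
import torch
import torch.nn as nn
from torch import Tensor
from torch.distributions import Categorical


class SharedNetSketch(nn.Module):
    """Hypothetical shared Actor-Critic network following the forward() contract above."""

    def __init__(self, obs_features: int, num_actions: int):
        super().__init__()
        # Shared encoding layer: both the Actor and the Critic read from it.
        self.encoder = nn.Sequential(
            nn.Linear(obs_features, 128),
            nn.ReLU(),
        )
        # Actor head: logits for the policy distribution pi(a|s).
        self.actor_head = nn.Linear(128, num_actions)
        # Critic head: scalar state value V(s).
        self.critic_head = nn.Linear(128, 1)

    def forward(self, obs: Tensor) -> tuple[Categorical, Tensor]:
        # `obs` stands in for one (batch_size, obs_features) tensor of the
        # Observation tuple; Categorical stands in for PolicyDist.
        x = self.encoder(obs)
        policy_dist = Categorical(logits=self.actor_head(x))
        state_value = self.critic_head(x)  # shape: (batch_size, 1)
        return policy_dist, state_value
```

A quick usage check of the shapes:

```python
net = SharedNetSketch(obs_features=4, num_actions=2)
dist, value = net(torch.randn(8, 4))  # batch of 8 observations
action = dist.sample()                # shape: (8,); value has shape (8, 1)
```

Because the encoding layer is shared, a single forward pass produces both \(\pi(a \vert s)\) and \(V(s)\), and the encoder's parameters receive gradients from both the actor and the critic losses.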