None, cpu, cuda, cuda:0, and other devices accepted by torch.device()
argument class: REINFORCENetwork
You need to implement the methods below.

    @abstractmethod
    def forward(
        self,
        obs: Observation
    ) -> PolicyDist

Parameters:
| Name | Description | Shape |
|---|---|---|
| obs (Observation) | Observation batch tuple. | `*batch_shape` = `(batch_size,)`; see the Observation docs for details |
Returns:
| Name | Description | Shape |
|---|---|---|
| policy_dist (PolicyDist) | Policy distribution \(\pi(a \vert s)\). | `*batch_shape` = `(batch_size,)`; see the PolicyDist docs for details |
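
The `forward` contract above (an observation batch in, one policy distribution per batch element out) can be illustrated with a minimal pure-Python sketch. Everything here is an assumption for illustration: `Observation`, `PolicyDist`, and `PolicyNetwork` are simplified hypothetical stand-ins, not the library's actual classes, and `LinearSoftmaxPolicy` is an invented name. A real network would subclass `REINFORCENetwork` and operate on tensors.

```python
import math
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import List, Tuple


@dataclass
class Observation:
    """Hypothetical stand-in: a tuple of batched observation items."""
    items: Tuple[List[List[float]], ...]  # each item: (batch_size, num_features)


@dataclass
class PolicyDist:
    """Hypothetical stand-in: one categorical distribution per batch element."""
    probs: List[List[float]]  # (batch_size, num_actions)

    @property
    def batch_shape(self) -> Tuple[int, ...]:
        return (len(self.probs),)


class PolicyNetwork(ABC):
    """Hypothetical base class mirroring the abstract forward() signature."""

    @abstractmethod
    def forward(self, obs: Observation) -> PolicyDist:
        ...


class LinearSoftmaxPolicy(PolicyNetwork):
    """Maps each observation to action probabilities via linear logits + softmax."""

    def __init__(self, weights: List[List[float]]):
        self.weights = weights  # (num_actions, num_features)

    def forward(self, obs: Observation) -> PolicyDist:
        batch = obs.items[0]  # use the first observation item only
        probs = []
        for features in batch:
            logits = [sum(w * x for w, x in zip(row, features)) for row in self.weights]
            peak = max(logits)  # subtract the max for numerical stability
            exps = [math.exp(logit - peak) for logit in logits]
            total = sum(exps)
            probs.append([e / total for e in exps])
        return PolicyDist(probs)


policy = LinearSoftmaxPolicy(weights=[[1.0, 0.0], [0.0, 1.0]])
obs = Observation(items=([[0.5, 0.5], [2.0, 0.0], [0.0, 3.0]],))
dist = policy.forward(obs)
```

In this sketch the batch holds three observations, so the returned distribution has `batch_shape == (3,)`, matching the `(batch_size,)` shape listed in the tables. Leaving `forward` unimplemented in a subclass makes instantiation fail, which is the point of marking it `@abstractmethod`.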