PolicyDist
Policy distribution interface.
*batch_shape depends on the input of the algorithm you are using:
- simple batch: *batch_shape=(batch_size,)
- sequence batch: *batch_shape=(seq_batch_size, seq_len)
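For instance, a tensor with a trailing num_branches dimension would have the following shapes under each convention (the sizes below are made up purely for illustration):

```python
import torch

# Simple batch: *batch_shape = (batch_size,), e.g. batch_size = 32, num_branches = 3
simple_log_prob = torch.zeros(32, 3)    # (*batch_shape, num_branches) = (32, 3)

# Sequence batch: *batch_shape = (seq_batch_size, seq_len), e.g. (8, 16), num_branches = 3
seq_log_prob = torch.zeros(8, 16, 3)    # (*batch_shape, num_branches) = (8, 16, 3)
```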
Module: aine_drl.policy_dist
class PolicyDist(ABC)
Methods
sample()
Samples actions from the policy distribution.
@abstractmethod
def sample(self, reparam_trick: bool = False) -> Action
Parameters:
| Name | Description |
|---|---|
| reparam_trick (bool) | (default = False) Whether to use the reparameterization trick. |
Returns:
| Name | Shape |
|---|---|
| action (Action) | Shape depends on the constructor arguments |
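In PyTorch terms, the reparam_trick flag usually corresponds to choosing rsample() (differentiable, reparameterized sampling) over sample() (no gradient through the sampling step). A minimal sketch of that pattern with a plain torch Gaussian, independent of this interface:

```python
import torch
from torch.distributions import Normal

mean = torch.zeros(32, 3, requires_grad=True)  # (*batch_shape, num_branches)
std = torch.ones(32, 3)
dist = Normal(mean, std)

reparam_trick = True
# rsample() keeps the sample differentiable w.r.t. mean/std (reparameterization trick);
# sample() detaches the sampling step from the computation graph.
raw_action = dist.rsample() if reparam_trick else dist.sample()
```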
log_prob()
Returns the log of the probability mass/density function evaluated at the given action.
@abstractmethod
def log_prob(self, action: Action) -> torch.Tensor
Returns:
| Name | Shape |
|---|---|
| log_prob (Tensor) | (*batch_shape, num_branches) |
joint_log_prob()
Returns the joint log of the probability mass/density function evaluated at the given action.
def joint_log_prob(self, action: Action) -> torch.Tensor
Returns:
| Name | Shape |
|---|---|
| joint_log_prob (Tensor) | (*batch_shape, 1) |
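Because joint_log_prob() collapses the num_branches dimension to 1, a natural reading (sketched here as an assumption, not taken from the source) is that the per-branch log probabilities are summed over the branch dimension, which gives the joint log probability when the branches are independent:

```python
import torch

log_prob = torch.randn(32, 3)                        # (*batch_shape, num_branches)
joint_log_prob = log_prob.sum(dim=-1, keepdim=True)  # (*batch_shape, 1)
```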
entropy()
Returns the entropy of the policy distribution.
@abstractmethod
def entropy(self) -> torch.Tensor
Returns:
| Name | Shape |
|---|---|
| entropy (Tensor) | (*batch_shape, num_branches) |
joint_entropy()
Returns the joint entropy of the policy distribution.
def joint_entropy(self) -> torch.Tensor
Returns:
| Name | Shape |
|---|---|
| joint_entropy (Tensor) | (*batch_shape, 1) |
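Putting the shape contract together with a plain torch Gaussian (this is not the library's PolicyDist, only an illustrative sketch under the same independence assumption as above):

```python
import torch
from torch.distributions import Normal

mean = torch.zeros(32, 3)   # (*batch_shape, num_branches) = (32, 3)
std = torch.ones(32, 3)
dist = Normal(mean, std)

raw_action = dist.sample()                            # (32, 3)
log_prob = dist.log_prob(raw_action)                  # (*batch_shape, num_branches) = (32, 3)
joint_log_prob = log_prob.sum(dim=-1, keepdim=True)   # (*batch_shape, 1) = (32, 1)
entropy = dist.entropy()                              # (*batch_shape, num_branches) = (32, 3)
joint_entropy = entropy.sum(dim=-1, keepdim=True)     # (*batch_shape, 1) = (32, 1)
```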