Action
Action data type with tensor.
*batch_shape
depends on the input of the algorithm you are using.
- simple batch:
*batch_shape
=(batch_size,)
- sequence batch:
*batch_shape
=(seq_batch_size, seq_len)
discrete_action
and continuous_action
are tensors of the discrete and continuous action spaces respectively. If the action space is only discrete, continuous_action
is empty tensor (*batch_shape, 0)
and the other case is vice versa.
Module: aine_drl.exp
@dataclass(frozen=True)
class Action
Fields
discrete_action
The discrete action tensor (*batch_shape, num_discrete_branches)
.
discrete_action: torch.Tensor
continuous_action
The continuous action tensor (*batch_shape, num_continuous_branches)
.
continuous_action: torch.Tensor
Properties
num_discrete_branches
The number of discrete action branches.
@property
def num_discrete_branches(self) -> int
num_continuous_branches
The number of continuous action branches.
@property
def num_continuous_branches(self) -> int
num_branches
The number of action branches which is equal to num_discrete_branches
+ num_continuous_branches
.
@property
def num_branches(self) -> int
batch_shape
The batch shape *batch_shape
of the action tensor.
@property
def batch_shape(self) -> torch.Size
Methods
transform()
Transform the action tensor with the callable function.
def transform(self, func: Callable[[Tensor], Tensor]) -> Action
Parameters:
Name | Description |
---|---|
func ((Tensor) -> Tensor ) | The callable function to transform each action tensor. |
Returns:
Name | Description |
---|---|
action (Action ) | The transformed action. |
getitem()
Get a batch of Action
from the Action
instance. Note that it's recommended to use range slicing instead of indexing.
def __getitem__(self, idx) -> Action
from_iter()
Create an Action
batch instance from iterable of Action
. Each item in the iterable must consist of the same action spaces. For example, if first Action
has only discrete actions and the number of discrete action branches is 4
, second Action
must be same.
@staticmethod
def from_iter(actions: Iterable[Action]) -> Action
Example:
import torch
from aine_drl.exp import Action
batch_size1 = 2
batch_size2 = 3
action1 = Action(
continuous_action=torch.randn(batch_size1, 2)
)
action2 = Action(
continuous_action=torch.randn(batch_size2, 2)
)
action = Action.from_iter([action1, action2])
>>> action.discrete_action.shape, action.continuous_action.shape
(torch.Size([5, 0]), torch.Size([5, 2]))