Action

Action data type that stores actions as tensors.

*batch_shape depends on the input expected by the algorithm you are using:

  • simple batch: *batch_shape = (batch_size,)
  • sequence batch: *batch_shape = (seq_batch_size, seq_len)

discrete_action and continuous_action are the tensors for the discrete and continuous action spaces, respectively. If the action space is purely discrete, continuous_action is an empty tensor of shape (*batch_shape, 0). If the action space is purely continuous, discrete_action is an empty tensor of shape (*batch_shape, 0).
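The shape convention above can be sketched without any tensors. The helper below is a hypothetical illustration (`action_shapes` is not part of aine_drl). It only computes the shapes the two fields would have for a given batch shape and branch counts.

```python
from typing import Tuple

Shape = Tuple[int, ...]


def action_shapes(
    batch_shape: Shape, num_discrete: int, num_continuous: int
) -> Tuple[Shape, Shape]:
    """Return (discrete_action shape, continuous_action shape)."""
    # Each field is the batch shape followed by its number of branches.
    return (*batch_shape, num_discrete), (*batch_shape, num_continuous)


# Simple batch, discrete-only space with 4 branches.
print(action_shapes((32,), 4, 0))    # ((32, 4), (32, 0))
# Sequence batch, continuous-only space with 2 branches.
print(action_shapes((8, 16), 0, 2))  # ((8, 16, 0), (8, 16, 2))
```

In both cases the unused action space still has a tensor, just with a last dimension of size 0.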

Module: aine_drl.exp

@dataclass(frozen=True)
class Action

Fields

discrete_action

The discrete action tensor of shape (*batch_shape, num_discrete_branches).

discrete_action: torch.Tensor

continuous_action

The continuous action tensor of shape (*batch_shape, num_continuous_branches).

continuous_action: torch.Tensor

Properties

num_discrete_branches

The number of discrete action branches.

@property
def num_discrete_branches(self) -> int

num_continuous_branches

The number of continuous action branches.

@property
def num_continuous_branches(self) -> int

num_branches

The total number of action branches, equal to num_discrete_branches + num_continuous_branches.

@property
def num_branches(self) -> int

batch_shape

The batch shape (*batch_shape) of the action tensors.

@property
def batch_shape(self) -> torch.Size
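One way to read these properties is as views of the tensor shapes. The sketch below uses plain torch tensors rather than the Action class. It assumes (the source does not show the implementation) that each branch count is the size of the last dimension and that batch_shape is every dimension before it.

```python
import torch

# Hypothetical sequence batch: seq_batch_size=4, seq_len=10,
# 3 discrete branches and no continuous branches.
discrete_action = torch.zeros(4, 10, 3, dtype=torch.long)
continuous_action = torch.zeros(4, 10, 0)

# Assumed relationship between the properties and the tensor shapes.
num_discrete_branches = discrete_action.shape[-1]
num_continuous_branches = continuous_action.shape[-1]
num_branches = num_discrete_branches + num_continuous_branches
batch_shape = discrete_action.shape[:-1]

print(num_discrete_branches, num_continuous_branches, num_branches)
print(batch_shape)
```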

Methods

transform()

Transform each action tensor with the given callable.

def transform(self, func: Callable[[Tensor], Tensor]) -> Action

Parameters:

Name                         Description
func ((Tensor) -> Tensor)    The callable function to transform each action tensor.

Returns:

Name               Description
action (Action)    The transformed action.
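The behavior of transform() can be sketched with a stand-in class. `SketchAction` below is hypothetical and is not the aine_drl implementation. It assumes that func is applied to discrete_action and continuous_action independently and that a new instance is returned, which is what the signature and the frozen dataclass suggest.

```python
from dataclasses import dataclass
from typing import Callable

import torch


@dataclass(frozen=True)
class SketchAction:
    """Hypothetical stand-in for Action, for illustration only."""

    discrete_action: torch.Tensor
    continuous_action: torch.Tensor

    def transform(
        self, func: Callable[[torch.Tensor], torch.Tensor]
    ) -> "SketchAction":
        # A frozen dataclass cannot be modified in place,
        # so build a new instance from the transformed tensors.
        return SketchAction(
            func(self.discrete_action), func(self.continuous_action)
        )


# Continuous-only simple batch: batch_size=2, 3 continuous branches.
action = SketchAction(torch.zeros(2, 0), torch.randn(2, 3))

# Insert a sequence dimension of length 1 into both tensors.
seq_action = action.transform(lambda t: t.unsqueeze(1))
print(seq_action.discrete_action.shape, seq_action.continuous_action.shape)
```

Because the same callable is applied to both tensors, it should work on an empty (*batch_shape, 0) tensor as well as on a populated one.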

__getitem__()

Get a batch of Action from the Action instance. Range slicing (e.g. action[0:2]) is recommended over integer indexing (e.g. action[0]).

def __getitem__(self, idx) -> Action
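The source does not say why slicing is preferred. A plausible reason, stated here as an assumption, is that integer indexing on a tensor drops the batch dimension while range slicing keeps it. Plain torch tensors show the difference:

```python
import torch

# (batch_size, num_continuous_branches)
continuous_action = torch.randn(5, 2)

sliced = continuous_action[0:1]  # range slicing keeps the batch dimension
indexed = continuous_action[0]   # integer indexing drops it

print(sliced.shape)   # torch.Size([1, 2])
print(indexed.shape)  # torch.Size([2])
```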

from_iter()

Create a batched Action instance from an iterable of Action instances. Every item in the iterable must have the same action space. For example, if the first Action has only discrete actions with 4 discrete action branches, the second Action must as well.

@staticmethod
def from_iter(actions: Iterable[Action]) -> Action

Example:

import torch
from aine_drl.exp import Action

batch_size1 = 2
batch_size2 = 3

action1 = Action(
    continuous_action=torch.randn(batch_size1, 2)
)

action2 = Action(
    continuous_action=torch.randn(batch_size2, 2)
)

# Batch sizes 2 and 3 combine into a single batch of 5.
action = Action.from_iter([action1, action2])
print(action.discrete_action.shape, action.continuous_action.shape)
# torch.Size([5, 0]) torch.Size([5, 2])
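The shapes in this example are consistent with concatenation along the first batch dimension (2 + 3 = 5). The sketch below reproduces them with torch.cat on plain tensors. Whether from_iter() is implemented this way is an assumption; the source only shows the resulting shapes.

```python
import torch

# Two continuous-only batches with 2 branches each.
# The discrete part is an empty (batch_size, 0) tensor in both.
discrete_parts = [torch.zeros(2, 0), torch.zeros(3, 0)]
continuous_parts = [torch.randn(2, 2), torch.randn(3, 2)]

# Concatenate each field along the first (batch) dimension.
batched_discrete = torch.cat(discrete_parts, dim=0)
batched_continuous = torch.cat(continuous_parts, dim=0)

print(batched_discrete.shape, batched_continuous.shape)
# torch.Size([5, 0]) torch.Size([5, 2])
```

Concatenation only works when the trailing dimensions match, which fits the requirement that every item has the same action space.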