Agent
An agent is an actor that can observe and interact with the environment. How the agent is trained and how it selects actions depends on the reinforcement learning algorithm.
If you want to implement your own agent, inherit from the abstract class Agent and implement its abstract methods.
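A minimal subclass sketch is shown below; the abstract members it implements are documented in the rest of this page. Only the aine_drl.agent module is named on this page, so the other import paths and the method-body comments are assumptions.

```python
# Minimal sketch of a custom agent. Only aine_drl.agent is documented on
# this page; the other import paths are assumptions.
from aine_drl.agent import Agent
from aine_drl import Network, Observation, Action, Experience  # assumed imports


class MyAgent(Agent):
    @property
    def name(self) -> str:
        return "MyAgent"

    def _select_action_train(self, obs: Observation) -> Action:
        ...  # e.g., sample a stochastic action from the policy network

    def _select_action_inference(self, obs: Observation) -> Action:
        ...  # e.g., take the greedy (deterministic) action

    def _update_train(self, exp: Experience) -> None:
        # store the experience and train the network when ready;
        # count each optimizer step (timing is an assumption)
        self._tick_training_steps()

    def _update_inference(self, exp: Experience) -> None:
        pass  # typically nothing to update during inference
```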
Module: aine_drl.agent
class Agent(ABC)
Constructor
def __init__(
self,
num_envs: int,
network: Network,
behavior_type: BehaviorType = BehaviorType.TRAIN,
)
Parameters:
Name | Description |
---|---|
num_envs (int) | The number of environments the agent is interacting with. |
network (Network) | The network of the agent. |
behavior_type (BehaviorType) | (default = BehaviorType.TRAIN) The behavior type of the agent. |
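Continuing the MyAgent sketch above, a subclass constructor usually just forwards these arguments to the base class. In this sketch the gamma parameter is purely illustrative, and BehaviorType is assumed to be importable from aine_drl.agent:

```python
def __init__(self, num_envs: int, network: Network, gamma: float = 0.99) -> None:
    super().__init__(num_envs, network, BehaviorType.TRAIN)
    self._gamma = gamma  # hypothetical agent-specific hyperparameter
```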
Properties
name
The name of the agent. You need to implement this property when you inherit from Agent.
@property
@abstractmethod
def name(self) -> str
device
The device where the agent is running.
@property
def device(self) -> torch.device
num_envs
The number of environments the agent is interacting with. This is the number of environments in the vectorized environment.
@property
def num_envs(self) -> int
training_steps
The number of training steps the network has been trained for. This equals the number of times the optimizer has stepped.
@property
def training_steps(self) -> int
behavior_type
The behavior type of the agent. When you train the agent, the behavior type is BehaviorType.TRAIN; when you run inference, it is BehaviorType.INFERENCE.
@property
def behavior_type(self) -> BehaviorType
@behavior_type.setter
def behavior_type(self, value: BehaviorType)
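For example, you might switch a trained agent to inference behavior for evaluation. A usage sketch, where agent and obs come from your own setup:

```python
agent.behavior_type = BehaviorType.INFERENCE
action = agent.select_action(obs)  # now dispatches to _select_action_inference()
agent.behavior_type = BehaviorType.TRAIN  # switch back to resume training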
log_keys
The keys of the log data that the agent wants to log. You need to override this property when you want to log data.
@property
def log_keys(self) -> tuple[str, ...]
log_data
The log data of the agent. The data is a dictionary whose keys are those in log_keys and whose values are (value, time) tuples. You need to override this property when you want to log data.
@property
def log_data(self) -> dict[str, tuple[Any, float]]
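A sketch of overriding both properties together, continuing MyAgent (Any comes from typing). The key name, the _last_loss attribute, and using training_steps as the time value are all assumptions:

```python
@property
def log_keys(self) -> tuple[str, ...]:
    return ("Network/Loss",)  # hypothetical log key

@property
def log_data(self) -> dict[str, tuple[Any, float]]:
    # pair each value with the time it was recorded at,
    # here the current training step (an assumption)
    return {"Network/Loss": (self._last_loss, self.training_steps)}
```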
state_dict
The state dictionary of the agent. It contains the state of the agent, including the network and other data, which you can save to a PyTorch file and load later. You need to override this property when you want to save the state of the agent.
@property
def state_dict(self) -> dict
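Continuing MyAgent, an override can extend the base state with agent-specific entries. A sketch; the extra key is illustrative, and extending via super() is an assumption about how the base property composes:

```python
@property
def state_dict(self) -> dict:
    sd = super().state_dict  # a property, so no parentheses
    sd["gamma"] = self._gamma  # hypothetical agent-specific data
    return sd
```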
Methods
select_action()
Select actions from the Observation.
def select_action(self, obs: Observation) -> Action
Parameters:
Name | Description | Shape |
---|---|---|
obs (Observation) | One-step observation batch tuple. | *batch_shape = (num_envs,); see the Observation docs for details |
Returns:
Name | Description | Shape |
---|---|---|
action (Action) | One-step action batch. | *batch_shape = (num_envs,); see the Action docs for details |
_select_action_train()
Select actions from the Observation during training. You need to implement this method when you inherit from Agent.
@abstractmethod
def _select_action_train(self, obs: Observation) -> Action
The parameters and return values are the same as select_action().
_select_action_inference()
Select actions from the Observation during inference. You need to implement this method when you inherit from Agent.
@abstractmethod
def _select_action_inference(self, obs: Observation) -> Action
The parameters and return values are the same as select_action().
update()
Update and train the agent.
def update(self, exp: Experience)
Parameters:
Name | Description | Shape |
---|---|---|
exp (Experience) | One-step experience tuple. | *batch_shape = (num_envs,); see the Experience docs for details |
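A hedged sketch of the interaction loop that ties select_action() and update() together. The environment API and the Experience constructor signature are assumptions, not part of this page:

```python
obs = env.reset()  # assumed vectorized-env API returning an Observation batch
for _ in range(total_steps):
    action = agent.select_action(obs)
    next_obs, reward, terminated = env.step(action)              # assumed API
    exp = Experience(obs, action, next_obs, reward, terminated)  # assumed fields
    agent.update(exp)
    obs = next_obs
```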
_update_train()
Update and train the agent during training. You need to implement this method when you inherit from Agent.
@abstractmethod
def _update_train(self, exp: Experience)
The parameters are the same as update().
_update_inference()
Update the agent during inference. You need to implement this method when you inherit from Agent.
@abstractmethod
def _update_inference(self, exp: Experience)
The parameters are the same as update().
_tick_training_steps()
Increment the training step counter.
def _tick_training_steps(self)
load_state_dict()
Load the state dictionary of the agent. You need to override this method when you want to load the state of the agent.
def load_state_dict(self, state_dict: dict)
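Combined with the state_dict property above, a save/load round trip might look like this. The file name is illustrative; note that state_dict is a property here, not a method as on torch.nn.Module:

```python
import torch

torch.save(agent.state_dict, "my_agent.pt")      # property access, no parentheses
agent.load_state_dict(torch.load("my_agent.pt"))
```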