Network

A network is a neural network that is used by an agent to predict the action given the observation. You can implement your own network by inheriting from the abstract class Network and implementing the abstract methods.

Module: aine_drl.net

class Network(ABC)

Properties

device

The device where the network is running.

@property
def device(self) -> torch.device

Methods

model()

The model of the network. The model() must be return torch.nn.Module object which includes all the layers of the network.

@abstractmethod
def model(self) -> torch.nn.Module

For example, assume that your network consists of encoding layer, actor layer, critic layer. Then, the model() method should return the following torch.nn.Module object:

import torch.nn as nn

class FooNet(nn.Module, Network):
    def __init__(self):
        super().__init__()

        self._encoder = nn.Linear(4, 64)
        self._actor = nn.Linear(64, 2)
        self._critic = nn.Linear(64, 1)
    
    def model(self):
        return self