Tabular model

Tabular models run deterministic or stochastic policies in discrete domains.



skrl provides a Python mixin (TabularMixin) to assist in the creation of these types of models, giving users full control over the table definitions. Note that the use of this mixin must comply with the following rules:

  • When using multiple inheritance, the Model base class must always be listed last.

  • The Model base class constructor must be invoked before the mixin's constructor.

from skrl.models.torch import Model, TabularMixin

class TabularModel(TabularMixin, Model):
    def __init__(self, observation_space, action_space, device=None, num_envs=1):
        Model.__init__(self, observation_space, action_space, device)
        TabularMixin.__init__(self, num_envs)
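The inheritance-order rule follows from Python's method resolution order (MRO): methods are looked up left to right along the base classes, so listing the mixin first presumably lets its methods (such as act) be found before Model's. The sketch below uses hypothetical stand-in classes named Model and TabularMixin, not skrl's real ones, purely to show that lookup order.

```python
# Hypothetical stand-ins for skrl's Model and TabularMixin, used only to
# illustrate Python's method resolution order; they are NOT the real classes.
class Model:
    def __init__(self, observation_space, action_space, device=None):
        self.device = device

    def act(self, inputs, role=""):
        return "Model.act"


class TabularMixin:
    def __init__(self, num_envs=1):
        self.num_envs = num_envs

    def act(self, inputs, role=""):
        return "TabularMixin.act"


class TabularModel(TabularMixin, Model):  # Model listed last (as required)
    def __init__(self, observation_space, action_space, device=None, num_envs=1):
        Model.__init__(self, observation_space, action_space, device)
        TabularMixin.__init__(self, num_envs)


class WrongOrder(Model, TabularMixin):  # Model listed first (violates the rule)
    def __init__(self, observation_space, action_space, device=None, num_envs=1):
        Model.__init__(self, observation_space, action_space, device)
        TabularMixin.__init__(self, num_envs)


good = TabularModel(None, None, num_envs=4)
bad = WrongOrder(None, None, num_envs=4)

print([cls.__name__ for cls in TabularModel.__mro__])
# ['TabularModel', 'TabularMixin', 'Model', 'object']
print(good.act({}))  # TabularMixin.act
print(bad.act({}))   # Model.act (the mixin's method is shadowed)
```

With the wrong order, Model's method wins the lookup even though the mixin defines its own.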

Usage

import torch

from skrl.models.torch import Model, TabularMixin


# define the model
class EpsilonGreedyPolicy(TabularMixin, Model):
    def __init__(self, observation_space, action_space, device, num_envs=1, epsilon=0.1):
        Model.__init__(self, observation_space, action_space, device)
        TabularMixin.__init__(self, num_envs)

        self.epsilon = epsilon
        # one Q-table per environment: (num_envs, num_observations, num_actions)
        self.q_table = torch.ones((num_envs, self.num_observations, self.num_actions),
                                  dtype=torch.float32, device=self.device)

    def compute(self, inputs, role):
        states = inputs["states"]
        # greedy action: index each environment's table with its own current state
        actions = torch.argmax(self.q_table[torch.arange(self.num_envs, device=self.device).view(-1, 1), states],
                               dim=-1, keepdim=True).view(-1, 1)

        # exploration: with probability epsilon, replace the greedy action with a random one
        indexes = (torch.rand(states.shape[0], device=self.device) < self.epsilon).nonzero().view(-1)
        if indexes.numel():
            actions[indexes] = torch.randint(self.num_actions, (indexes.numel(), 1), device=self.device)
        return actions, {}


# instantiate the model (assumes there is a wrapped environment: env)
policy = EpsilonGreedyPolicy(observation_space=env.observation_space,
                             action_space=env.action_space,
                             device=env.device,
                             num_envs=env.num_envs,
                             epsilon=0.15)
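The key step in compute() is the advanced-indexing expression that picks, for each environment, the row of its own Q-table that corresponds to its current state. The following sketch reproduces that step and the epsilon-greedy replacement on a tiny hand-made table, outside of skrl; the sizes and values are arbitrary assumptions chosen for illustration.

```python
import torch

# Toy sizes, chosen arbitrarily for illustration
num_envs, num_observations, num_actions = 3, 4, 2

# One Q-table per environment: shape (num_envs, num_observations, num_actions)
q_table = torch.zeros((num_envs, num_observations, num_actions))
q_table[0, 1, 1] = 1.0  # env 0: action 1 is best in state 1
q_table[2, 3, 1] = 2.0  # env 2: action 1 is best in state 3

# Current discrete state of each environment, shape (num_envs, 1)
states = torch.tensor([[1], [0], [3]])

# Broadcasting the (num_envs, 1) row index against states selects
# q_table[i, states[i]] for every environment i -> shape (num_envs, 1, num_actions)
rows = q_table[torch.arange(num_envs).view(-1, 1), states]
greedy = torch.argmax(rows, dim=-1, keepdim=True).view(-1, 1)

print(rows.shape)                # torch.Size([3, 1, 2])
print(greedy.view(-1).tolist())  # [1, 0, 1] (env 1 has a tie; argmax returns the first index)

# Epsilon-greedy: environments whose random draw falls below epsilon get a random action
epsilon = 0.0  # torch.rand samples from [0, 1), so epsilon=0 keeps every greedy action
explore = (torch.rand(num_envs) < epsilon).nonzero().view(-1)
actions = greedy.clone()
if explore.numel():
    actions[explore] = torch.randint(num_actions, (explore.numel(), 1))
print(torch.equal(actions, greedy))  # True
```

Setting epsilon to a value such as 0.15, as in the instantiation above, makes roughly that fraction of environments take a uniformly random action at each step.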

API (PyTorch)

class skrl.models.torch.tabular.TabularMixin(num_envs: int = 1, role: str = '')

Bases: object

__init__(num_envs: int = 1, role: str = '') → None

Tabular mixin model

Parameters:
  • num_envs (int, optional) – Number of environments (default: 1)

  • role (str, optional) – Role played by the model (default: "")

Example:

>>> # define the model
>>> import torch
>>> from skrl.models.torch import Model, TabularMixin
>>>
>>> class GreedyPolicy(TabularMixin, Model):
...     def __init__(self, observation_space, action_space, device="cuda:0", num_envs=1):
...         Model.__init__(self, observation_space, action_space, device)
...         TabularMixin.__init__(self, num_envs)
...
...         self.q_table = torch.ones((num_envs, self.num_observations, self.num_actions),
...                                   dtype=torch.float32, device=self.device)
...
...     def compute(self, inputs, role):
...         actions = torch.argmax(self.q_table[torch.arange(self.num_envs, device=self.device).view(-1, 1), inputs["states"]],
...                                dim=-1, keepdim=True).view(-1, 1)
...         return actions, {}
...
>>> # given an observation_space: gym.spaces.Discrete with n=100
>>> # and an action_space: gym.spaces.Discrete with n=5
>>> model = GreedyPolicy(observation_space, action_space, num_envs=1)
>>>
>>> print(model)
GreedyPolicy(
  (q_table): Tensor(shape=[1, 100, 5])
)

act(inputs: Mapping[str, torch.Tensor | Any], role: str = '') → Tuple[torch.Tensor, torch.Tensor | None, Mapping[str, torch.Tensor | Any]]

Act in response to the state of the environment

Parameters:
  • inputs (dict where the values are typically torch.Tensor) –

    Model inputs. The most common keys are:

    • "states": state of the environment used to make the decision

    • "taken_actions": actions taken by the policy for the given states

  • role (str, optional) – Role played by the model (default: "")

Returns:

Model output. The first component is the action to be taken by the agent, the second component is None, and the third component is a dictionary containing extra output values.

Return type:

tuple of torch.Tensor, torch.Tensor or None, and dict

Example:

>>> # given a batch of sample states with shape (1, 1), i.e., one discrete state per environment
>>> actions, _, outputs = model.act({"states": states})
>>> print(actions, outputs)
tensor([[3]], device='cuda:0') {}
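The documented contract of act() puts the action first, None second, and the extra-output dictionary third. The hypothetical stand-in below (not skrl's implementation) sketches one way a compute() that returns (actions, outputs) can be wrapped into that 3-tuple, using a single environment with 100 states and 5 actions as in the examples above; the table values are made up.

```python
import torch


# Hypothetical stand-in mirroring the documented act() contract; it is NOT
# skrl's TabularMixin, only a sketch of how compute() maps onto the 3-tuple.
class GreedyStandIn:
    def __init__(self, q_table):
        self.q_table = q_table  # shape (num_envs, num_observations, num_actions)
        self.num_envs = q_table.shape[0]

    def compute(self, inputs, role=""):
        # select each environment's row for its current state, then take the best action
        rows = self.q_table[torch.arange(self.num_envs).view(-1, 1), inputs["states"]]
        return torch.argmax(rows, dim=-1, keepdim=True).view(-1, 1), {}

    def act(self, inputs, role=""):
        actions, outputs = self.compute(inputs, role)
        return actions, None, outputs  # second component is None for tabular models


q_table = torch.zeros((1, 100, 5))
q_table[0, 42, 3] = 1.0  # make action 3 the best choice in state 42
model = GreedyStandIn(q_table)

actions, second, outputs = model.act({"states": torch.tensor([[42]])})
print(actions, second, outputs)  # tensor([[3]]) None {}
```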

table() → torch.Tensor

Return the Q-table

Returns:

Q-table

Return type:

torch.Tensor

Example:

>>> output = model.table()
>>> print(output.shape)
torch.Size([1, 100, 5])
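Because table() returns a torch.Tensor, it can be indexed like any other tensor, for example to inspect the action values of one environment in one state. The sketch below uses a hypothetical stand-in class and assumes that table() returns the stored tensor itself rather than a copy; under that assumption, in-place writes through the returned tensor are seen by the policy.

```python
import torch


# Hypothetical stand-in; assumes table() hands back the stored tensor itself
# (not a copy), which is what makes the in-place write below visible.
class TableStandIn:
    def __init__(self, num_envs, num_observations, num_actions):
        self.q_table = torch.ones((num_envs, num_observations, num_actions))

    def table(self):
        return self.q_table


model = TableStandIn(1, 100, 5)
q = model.table()
print(q.shape)  # torch.Size([1, 100, 5])

# Write through the returned tensor: raise the value of action 2 in state 7 for environment 0
q[0, 7, 2] += 0.5
print(model.table()[0, 7].tolist())  # [1.0, 1.0, 1.5, 1.0, 1.0]
print(q is model.table())            # True
```

If an implementation returned a copy instead, the write above would not affect the model, so it is worth checking this behavior before relying on it.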