Tabular model
Tabular models run discrete-domain deterministic/stochastic policies.
skrl provides a Python mixin (TabularMixin) to assist in the creation of these types of models,
allowing users to have full control over the table definitions. Note that the use of this mixin must comply with
the following rules:
The definition of multiple inheritance must always include the Model base class at the end.
The Model base class constructor must be invoked before the mixin's constructor.
class TabularModel(TabularMixin, Model):
    def __init__(self, observation_space, state_space, action_space, device):
        Model.__init__(
            self,
            observation_space=observation_space,
            state_space=state_space,
            action_space=action_space,
            device=device,
        )
        TabularMixin.__init__(self)
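Both rules follow from Python's method resolution order (MRO). The sketch below uses hypothetical stand-in classes (`BaseModel` and `TableMixin`, which are not skrl's actual classes) to illustrate the general mechanism. It assumes the mixin supplies a method that the base class only declares, and that the mixin constructor relies on attributes set by the base constructor:

```python
# Stand-in classes for illustration only; they are NOT skrl's Model/TabularMixin.
class BaseModel:
    def __init__(self, num_actions):
        # attributes set here exist before any mixin constructor runs
        self.num_actions = num_actions

    def act(self, inputs):
        raise NotImplementedError("expected to be provided by a mixin")


class TableMixin:
    def __init__(self):
        # relies on state prepared by the base constructor
        self.ready = hasattr(self, "num_actions")

    def act(self, inputs):
        return "mixin act"


class Policy(TableMixin, BaseModel):  # base class listed last
    def __init__(self, num_actions):
        BaseModel.__init__(self, num_actions)  # base constructor first
        TableMixin.__init__(self)  # then the mixin constructor


policy = Policy(num_actions=4)
mro_names = [cls.__name__ for cls in Policy.__mro__]
print(mro_names)  # ['Policy', 'TableMixin', 'BaseModel', 'object']
print(policy.act({}))  # mixin act
```

Because the mixin precedes the base class in the MRO, its `act` is the one that gets called. Listing the base class first would reverse that lookup.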
Usage
import torch
import torch.nn as nn

from skrl.models.torch import Model, TabularMixin


# define the model
class EpsilonGreedyPolicy(TabularMixin, Model):
    def __init__(self, observation_space, state_space, action_space, device, epsilon=0.1):
        Model.__init__(
            self,
            observation_space=observation_space,
            state_space=state_space,
            action_space=action_space,
            device=device,
        )
        TabularMixin.__init__(self)

        self.epsilon = epsilon
        self.q_table = nn.Parameter(
            torch.ones((self.num_observations, self.num_actions), dtype=torch.float32, device=self.device),
            requires_grad=False,
        )

    def compute(self, inputs, role):
        observations = inputs["observations"]
        actions = torch.argmax(self.q_table[observations], dim=-1, keepdim=False)

        # choose random actions for exploration according to epsilon
        indexes = (torch.rand(observations.shape[0], device=self.device) < self.epsilon).nonzero().flatten()
        if indexes.numel():
            actions[indexes] = torch.randint(self.num_actions, (indexes.numel(), 1), device=self.device)
        return actions, {}


# instantiate the model (given a wrapped environment: `env`)
policy = EpsilonGreedyPolicy(env.observation_space, env.state_space, env.action_space, env.device)
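The selection rule inside `compute` is ordinary epsilon-greedy: take the greedy action for each observation, then replace a random subset (each picked with probability epsilon) with uniformly sampled actions. A dependency-free sketch of the same idea follows. It uses a list-of-lists Q-table and a hypothetical helper, `epsilon_greedy`, which is not part of skrl and processes observations one at a time rather than as a batch:

```python
import random


def epsilon_greedy(q_table, observations, epsilon, rng):
    """Return one action index per observation (hypothetical helper, not skrl API)."""
    num_actions = len(q_table[0])
    actions = []
    for obs in observations:
        row = q_table[obs]
        # greedy choice: index of the largest Q-value (first one wins ties)
        action = max(range(num_actions), key=row.__getitem__)
        # exploration: with probability epsilon, replace it with a random action
        if rng.random() < epsilon:
            action = rng.randrange(num_actions)
        actions.append(action)
    return actions


q_table = [
    [0.0, 1.0, 0.5],  # observation 0: action 1 is best
    [2.0, 0.1, 0.3],  # observation 1: action 0 is best
]
rng = random.Random(0)

greedy = epsilon_greedy(q_table, [0, 1, 0], epsilon=0.0, rng=rng)
print(greedy)  # [1, 0, 1]

explored = epsilon_greedy(q_table, [0, 1, 0], epsilon=1.0, rng=rng)
print(all(0 <= a < 3 for a in explored))  # True
```

With `epsilon=0.0` the policy is fully deterministic, and with `epsilon=1.0` every action is sampled at random. The Q-table in the model above is initialized to all ones, so every row is tied until the table is updated.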
API
PyTorch
TabularMixin – Tabular mixin model.
- class skrl.models.torch.tabular.TabularMixin(*, role: str = '')
Bases: object
Tabular mixin model.
- Parameters:
role – Role played by the model.
Methods:
act(inputs, *[, role]) – Act in response to the observations/states of the environment.
tables(*[, role]) – Return the tables defined by the model.
- act(inputs: dict[str, Any], *, role: str = '') -> tuple[torch.Tensor, dict[str, Any]]
Act in response to the observations/states of the environment.
- Parameters:
inputs – Model inputs. The most common keys are:
"observations": observation of the environment used to make the decision.
"states": state of the environment used to make the decision.
"taken_actions": actions taken by the policy for the given observations/states.
role – Role played by the model.
- Returns:
Model output. The first component is the expected action/value returned by the model. The second component is a dictionary containing extra output values according to the model.