Tabular model
Tabular models run discrete-domain deterministic/stochastic policies.
skrl provides a Python mixin (TabularMixin) to assist in the creation of these types of models, giving users full control over the table definitions. The use of this mixin must comply with the following rules:
- The multiple inheritance definition must always list the Model base class last.
- The Model base class constructor must be invoked before the mixin's constructor.
class TabularModel(TabularMixin, Model):
    def __init__(self, observation_space, action_space, device=None, num_envs=1):
        Model.__init__(self, observation_space, action_space, device)
        TabularMixin.__init__(self, num_envs)
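The ordering rule follows Python's method resolution order (MRO): classes listed first in the bases take precedence. Presumably, listing Model last lets the mixin's methods (such as act) override generic ones from the base class. The sketch below uses hypothetical stand-ins for Model and TabularMixin (not the real skrl classes) only to show how the inheritance order decides which act is called.

```python
# Hypothetical stand-ins for skrl's Model and TabularMixin, used only to show
# how Python's method resolution order (MRO) depends on the inheritance order.
class Model:
    def __init__(self, observation_space, action_space, device=None):
        self.observation_space = observation_space
        self.action_space = action_space
        self.device = device

    def act(self, inputs, role=""):
        # stand-in for a generic base-class implementation
        return "Model.act"


class TabularMixin:
    def __init__(self, num_envs=1, role=""):
        self.num_envs = num_envs

    def act(self, inputs, role=""):
        # stand-in for the mixin's implementation, which should take precedence
        return "TabularMixin.act"


class TabularModel(TabularMixin, Model):
    def __init__(self, observation_space, action_space, device=None, num_envs=1):
        # base class constructor first, then the mixin's constructor
        Model.__init__(self, observation_space, action_space, device)
        TabularMixin.__init__(self, num_envs)


class WrongOrder(Model, TabularMixin):
    # Model listed first: its act() shadows the mixin's act()
    pass


model = TabularModel(None, None, num_envs=4)
print(model.act({}))                                  # TabularMixin.act
print([cls.__name__ for cls in TabularModel.__mro__])
```

With the bases reversed, as in WrongOrder, lookups reach Model before the mixin, so the mixin's methods would be silently ignored.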
Usage
import torch

from skrl.models.torch import Model, TabularMixin


# define the model
class EpsilonGreedyPolicy(TabularMixin, Model):
    def __init__(self, observation_space, action_space, device, num_envs=1, epsilon=0.1):
        Model.__init__(self, observation_space, action_space, device)
        TabularMixin.__init__(self, num_envs)

        self.epsilon = epsilon
        # one Q-table per environment: (num_envs, num_observations, num_actions)
        self.q_table = torch.ones((num_envs, self.num_observations, self.num_actions),
                                  dtype=torch.float32, device=self.device)

    def compute(self, inputs, role):
        states = inputs["states"]
        # greedy action: argmax over each environment's Q-values for its current state
        actions = torch.argmax(self.q_table[torch.arange(self.num_envs, device=self.device).view(-1, 1), states],
                               dim=-1, keepdim=True).view(-1, 1)

        # exploration: with probability epsilon, replace the greedy action with a random one
        indexes = (torch.rand(states.shape[0], device=self.device) < self.epsilon).nonzero().view(-1)
        if indexes.numel():
            actions[indexes] = torch.randint(self.num_actions, (indexes.numel(), 1), device=self.device)
        return actions, {}


# instantiate the model (assumes there is a wrapped environment: env)
policy = EpsilonGreedyPolicy(observation_space=env.observation_space,
                             action_space=env.action_space,
                             device=env.device,
                             num_envs=env.num_envs,
                             epsilon=0.15)
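The exploration step in the compute method above draws one uniform number per environment and, where it falls below epsilon, overwrites that environment's greedy action with a uniformly random one. The following torch-free sketch mirrors that logic with plain lists; the function name epsilon_greedy and the list-of-rows layout (shape (num_envs, 1)) are assumptions for illustration only.

```python
import random


def epsilon_greedy(greedy_actions, num_actions, epsilon, rng=random):
    """Hypothetical helper: with probability epsilon, replace each environment's
    greedy action with a uniformly random action in [0, num_actions).

    greedy_actions: list of [action] rows, one per environment (shape (num_envs, 1))
    """
    actions = [list(row) for row in greedy_actions]  # copy, keep (num_envs, 1) layout
    for i in range(len(actions)):
        # one uniform draw per environment, like torch.rand(states.shape[0])
        if rng.random() < epsilon:
            actions[i][0] = rng.randrange(num_actions)
    return actions


# usage: three environments, five discrete actions
print(epsilon_greedy([[2], [0], [1]], num_actions=5, epsilon=0.15))
```

With epsilon=0.0 the greedy actions are returned unchanged; with epsilon=1.0 every action is resampled.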
API (PyTorch)
- class skrl.models.torch.tabular.TabularMixin(num_envs: int = 1, role: str = '')
Bases: object
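- __init__(num_envs: int = 1, role: str = '') → None
Tabular mixin model
- Parameters:
num_envs (int, optional) – Number of environments (default: 1)
role (str, optional) – Role played by the model (default: "")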
Example:
>>> # define the model
>>> import torch
>>> from skrl.models.torch import Model, TabularMixin
>>>
>>> class GreedyPolicy(TabularMixin, Model):
...     def __init__(self, observation_space, action_space, device="cuda:0", num_envs=1):
...         Model.__init__(self, observation_space, action_space, device)
...         TabularMixin.__init__(self, num_envs)
...
...         self.q_table = torch.ones((num_envs, self.num_observations, self.num_actions),
...                                   dtype=torch.float32, device=self.device)
...
...     def compute(self, inputs, role):
...         actions = torch.argmax(self.q_table[torch.arange(self.num_envs).view(-1, 1), inputs["states"]],
...                                dim=-1, keepdim=True).view(-1,1)
...         return actions, {}
...
>>> # given an observation_space: gym.spaces.Discrete with n=100
>>> # and an action_space: gym.spaces.Discrete with n=5
>>> model = GreedyPolicy(observation_space, action_space, num_envs=1)
>>>
>>> print(model)
GreedyPolicy(
  (q_table): Tensor(shape=[1, 100, 5])
)
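The indexing expression in compute pairs row i of torch.arange(num_envs).view(-1, 1) with row i of the states tensor, so each environment reads only its own table at its own observation; the argmax then picks the highest-valued action. The loop below is a sketch of that lookup on nested lists, assuming states has shape (num_envs, 1); greedy_lookup is a hypothetical name, not part of skrl.

```python
def greedy_lookup(table, states):
    """Loop equivalent of
    torch.argmax(table[torch.arange(num_envs).view(-1, 1), states], dim=-1, keepdim=True).view(-1, 1)

    table:  nested list with shape (num_envs, num_observations, num_actions)
    states: nested list with shape (num_envs, 1), one observation index per environment
    returns a nested list with shape (num_envs, 1), one action index per environment
    """
    actions = []
    for env, (state,) in enumerate(states):
        q_values = table[env][state]  # this environment's Q-values for its observation
        # index of the first maximal value, matching torch.argmax on ties
        actions.append([max(range(len(q_values)), key=q_values.__getitem__)])
    return actions


# 2 environments, 2 observations, 3 actions
table = [[[0.0, 1.0, 0.5], [2.0, 2.0, 1.0]],
         [[0.1, 0.2, 0.9], [0.0, 0.0, 0.0]]]
print(greedy_lookup(table, [[1], [0]]))  # [[0], [2]]
```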
- act(inputs: Mapping[str, torch.Tensor | Any], role: str = '') → Tuple[torch.Tensor, torch.Tensor | None, Mapping[str, torch.Tensor | Any]]
Act in response to the state of the environment
- Parameters:
inputs (dict where the values are typically torch.Tensor) – Model inputs. The most common keys are:
"states": state of the environment used to make the decision
"taken_actions": actions taken by the policy for the given states
role (str, optional) – Role played by the model (default: "")
- Returns:
Model output. The first component is the action to be taken by the agent. The second component is None. The third component is a dictionary containing extra output values.
- Return type:
tuple of torch.Tensor, torch.Tensor or None, and dict
Example:
>>> # given a batch of sample states with shape (1, 100)
>>> actions, _, outputs = model.act({"states": states})
>>> print(actions[0], outputs)
tensor([[3]], device='cuda:0') {}
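The return convention above suggests that act forwards to the user-defined compute method (which returns actions and a dict) and inserts None as the second component. The class below is a hypothetical stand-in written under that assumption, not skrl's actual implementation; its toy compute always picks action 0.

```python
class ActSketch:
    """Hypothetical stand-in illustrating the assumed act()/compute() relationship."""

    def compute(self, inputs, role=""):
        # toy compute: choose action 0 for every environment's state
        return [[0] for _ in inputs["states"]], {}

    def act(self, inputs, role=""):
        # assumed behaviour: forward to compute() and insert None as the second component
        actions, outputs = self.compute(inputs, role)
        return actions, None, outputs


actions, second, outputs = ActSketch().act({"states": [[3], [7]]})
print(actions, second, outputs)  # [[0], [0]] None {}
```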
- table() → torch.Tensor
Return the Q-table
- Returns:
Q-table
- Return type:
torch.Tensor
Example:
>>> output = model.table()
>>> print(output.shape)
torch.Size([1, 100, 5])
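Because the returned table has shape (num_envs, num_observations, num_actions), taking the argmax over its last dimension (torch.argmax(model.table(), dim=-1)) yields a deterministic greedy policy of shape (num_envs, num_observations). The sketch below does the same on nested lists so it runs without torch; greedy_policy is a hypothetical helper name.

```python
def greedy_policy(q_table):
    """Collapse a Q-table of shape (num_envs, num_observations, num_actions)
    into a greedy policy of shape (num_envs, num_observations)."""
    return [
        # for each observation, the index of the first maximal Q-value
        [max(range(len(q_values)), key=q_values.__getitem__) for q_values in env_table]
        for env_table in q_table
    ]


# 1 environment, 2 observations, 3 actions
print(greedy_policy([[[1.0, 5.0, 2.0], [0.0, -1.0, 3.0]]]))  # [[1, 2]]
```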