Tabular model
Tabular models run discrete-domain deterministic/stochastic policies.
skrl provides a Python mixin (TabularMixin) to assist in the creation of these types of models, allowing users to have full control over the table definitions. Note that the use of this mixin must comply with the following rules:
The definition of multiple inheritance must always include the Model base class at the end.
class TabularModel(TabularMixin, Model):
    def __init__(self, observation_space, action_space, device="cuda:0", num_envs=1):
        Model.__init__(self, observation_space, action_space, device)
        TabularMixin.__init__(self, num_envs)
The Model base class constructor must be invoked before the mixin's constructor.
class TabularModel(TabularMixin, Model):
    def __init__(self, observation_space, action_space, device="cuda:0", num_envs=1):
        Model.__init__(self, observation_space, action_space, device)
        TabularMixin.__init__(self, num_envs)
Basic usage
import torch

from skrl.models.torch import Model, TabularMixin


# define the model
class EpsilonGreedyPolicy(TabularMixin, Model):
    """Epsilon-greedy policy over a per-environment Q-table.

    With probability ``epsilon`` a random action is taken; otherwise the
    action with the highest Q-value for the current state is selected.
    """
    def __init__(self, observation_space, action_space, device, num_envs=1, epsilon=0.1):
        # the Model base class constructor must run before the mixin's
        Model.__init__(self, observation_space, action_space, device)
        TabularMixin.__init__(self, num_envs)

        self.epsilon = epsilon
        # one (num_observations x num_actions) table per environment,
        # allocated on the model's device so it can be indexed with device tensors
        self.q_table = torch.ones((num_envs, self.num_observations, self.num_actions),
                                  dtype=torch.float32, device=self.device)

    def compute(self, inputs, role):
        states = inputs["states"]
        # greedy action: argmax over the Q-values of each environment's current state
        actions = torch.argmax(self.q_table[torch.arange(self.num_envs).view(-1, 1), states],
                               dim=-1, keepdim=True).view(-1, 1)

        # exploration: with probability epsilon, replace the greedy action by a random one
        indexes = (torch.rand(states.shape[0], device=self.device) < self.epsilon).nonzero().view(-1)
        if indexes.numel():
            actions[indexes] = torch.randint(self.num_actions, (indexes.numel(), 1), device=self.device)
        return actions, {}


# instantiate the model (assumes there is a wrapped environment: env)
policy = EpsilonGreedyPolicy(observation_space=env.observation_space,
                             action_space=env.action_space,
                             device=env.device,
                             num_envs=env.num_envs,
                             epsilon=0.15)
API
- class skrl.models.torch.tabular.TabularMixin(num_envs: int = 1, role: str = '')
Bases:
object
- __init__(num_envs: int = 1, role: str = '') → None
Tabular mixin model
- Parameters
num_envs (int, optional) – Number of environments (default: 1)
role (str, optional) – Role played by the model (default: "")
Example:
# define the model >>> import torch >>> from skrl.models.torch import Model, TabularMixin >>> >>> class GreedyPolicy(TabularMixin, Model): ... def __init__(self, observation_space, action_space, device="cuda:0", num_envs=1): ... Model.__init__(self, observation_space, action_space, device) ... TabularMixin.__init__(self, num_envs) ... ... self.table = torch.ones((num_envs, self.num_observations, self.num_actions), ... dtype=torch.float32, device=self.device) ... ... def compute(self, inputs, role): ... actions = torch.argmax(self.table[torch.arange(self.num_envs).view(-1, 1), inputs["states"]], ... dim=-1, keepdim=True).view(-1,1) ... return actions, {} ... >>> # given an observation_space: gym.spaces.Discrete with n=100 >>> # and an action_space: gym.spaces.Discrete with n=5 >>> model = GreedyPolicy(observation_space, action_space, num_envs=1) >>> >>> print(model) GreedyPolicy( (table): Tensor(shape=[1, 100, 5]) )
- act(inputs: Mapping[str, Union[torch.Tensor, Any]], role: str = '') Tuple[torch.Tensor, Optional[torch.Tensor], Mapping[str, Union[torch.Tensor, Any]]]
Act in response to the state of the environment
- Parameters
inputs (dict where the values are typically torch.Tensor) –
Model inputs. The most common keys are:
"states"
: state of the environment used to make the decision"taken_actions"
: actions taken by the policy for the given states
role (str, optional) – Role played by the model (default:
""
)
- Returns
Model output. The first component is the action to be taken by the agent. The second component is
None
. The third component is a dictionary containing extra output values
- Return type
tuple of torch.Tensor, torch.Tensor or None, and dictionary
Example:
>>> # given a batch of sample states with shape (1, 100) >>> actions, _, outputs = model.act({"states": states}) >>> print(actions[0], outputs) tensor([[3]], device='cuda:0') {}
- table() torch.Tensor
Return the Q-table
- Returns
Q-table
- Return type
torch.Tensor
Example:
>>> output = model.table() >>> print(output.shape) torch.Size([1, 100, 5])