Categorical model
Categorical models run discrete-domain stochastic policies.
skrl provides a Python mixin (`CategoricalMixin`) to assist in the creation of these types of models, allowing users to have full control over the function approximator definitions and architectures. Note that the use of this mixin must comply with the following rules:
The definition of multiple inheritance must always include the Model base class at the end.
class CategoricalModel(CategoricalMixin, Model):
    def __init__(self, observation_space, action_space, device, unnormalized_log_prob=True):
        Model.__init__(self, observation_space, action_space, device)
        CategoricalMixin.__init__(self, unnormalized_log_prob)
The Model base class constructor must be invoked before the mixin's constructor.
class CategoricalModel(CategoricalMixin, Model):
    def __init__(self, observation_space, action_space, device, unnormalized_log_prob=True):
        Model.__init__(self, observation_space, action_space, device)
        CategoricalMixin.__init__(self, unnormalized_log_prob)
Concept
Basic usage
Multi-Layer Perceptron (MLP)
Convolutional Neural Network (CNN)
Recurrent Neural Network (RNN)
Gated Recurrent Unit RNN (GRU)
Long Short-Term Memory RNN (LSTM)
1import torch
2import torch.nn as nn
3
4from skrl.models.torch import Model, CategoricalMixin
5
6
# define the model
class MLP(CategoricalMixin, Model):
    """Categorical (discrete-action) policy: an MLP mapping observations to action logits."""

    def __init__(self, observation_space, action_space, device, unnormalized_log_prob=True):
        # the Model base class constructor must run before the mixin's constructor
        Model.__init__(self, observation_space, action_space, device)
        CategoricalMixin.__init__(self, unnormalized_log_prob)

        # hidden layers: num_observations -> 64 -> 32, each followed by ReLU,
        # then a final linear layer producing one logit per discrete action
        layers = []
        for in_features, out_features in ((self.num_observations, 64), (64, 32)):
            layers.append(nn.Linear(in_features, out_features))
            layers.append(nn.ReLU())
        layers.append(nn.Linear(32, self.num_actions))
        self.net = nn.Sequential(*layers)

    def compute(self, inputs, role):
        # unnormalized log-probabilities (logits) for each discrete action
        return self.net(inputs["states"]), {}


# instantiate the model (assumes there is a wrapped environment: env)
policy = MLP(observation_space=env.observation_space,
             action_space=env.action_space,
             device=env.device,
             unnormalized_log_prob=True)
1import torch
2import torch.nn as nn
3import torch.nn.functional as F
4
5from skrl.models.torch import Model, CategoricalMixin
6
7
# define the model
class MLP(CategoricalMixin, Model):
    """Categorical (discrete-action) policy built from explicit layers with functional activations."""

    def __init__(self, observation_space, action_space, device, unnormalized_log_prob=True):
        # the Model base class constructor must run before the mixin's constructor
        Model.__init__(self, observation_space, action_space, device)
        CategoricalMixin.__init__(self, unnormalized_log_prob)

        self.fc1 = nn.Linear(self.num_observations, 64)
        self.fc2 = nn.Linear(64, 32)
        self.logits = nn.Linear(32, self.num_actions)

    def compute(self, inputs, role):
        # two ReLU-activated hidden layers followed by the logits head
        hidden = F.relu(self.fc1(inputs["states"]))
        hidden = F.relu(self.fc2(hidden))
        return self.logits(hidden), {}


# instantiate the model (assumes there is a wrapped environment: env)
policy = MLP(observation_space=env.observation_space,
             action_space=env.action_space,
             device=env.device,
             unnormalized_log_prob=True)
1import torch
2import torch.nn as nn
3
4from skrl.models.torch import Model, CategoricalMixin
5
6
# define the model
class CNN(CategoricalMixin, Model):
    """Categorical policy for image observations: conv feature extractor + MLP head."""

    def __init__(self, observation_space, action_space, device, unnormalized_log_prob=True):
        # the Model base class constructor must run before the mixin's constructor
        Model.__init__(self, observation_space, action_space, device)
        CategoricalMixin.__init__(self, unnormalized_log_prob)

        self.net = nn.Sequential(
            # convolutional feature extractor (expects a 3-channel image)
            nn.Conv2d(3, 32, kernel_size=8, stride=4),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1),
            nn.ReLU(),
            nn.Flatten(),
            # fully-connected head ending in one logit per discrete action
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Linear(512, 16),
            nn.Tanh(),
            nn.Linear(16, 64),
            nn.Tanh(),
            nn.Linear(64, 32),
            nn.Tanh(),
            nn.Linear(32, self.num_actions))

    def compute(self, inputs, role):
        # restore the image layout, then move channels first:
        # (samples, width * height * channels) -> (samples, channels, width, height)
        states = inputs["states"].view(-1, *self.observation_space.shape)
        return self.net(states.permute(0, 3, 1, 2)), {}


# instantiate the model (assumes there is a wrapped environment: env)
policy = CNN(observation_space=env.observation_space,
             action_space=env.action_space,
             device=env.device,
             unnormalized_log_prob=True)
1import torch
2import torch.nn as nn
3import torch.nn.functional as F
4
5from skrl.models.torch import Model, CategoricalMixin
6
7
# define the model
class CNN(CategoricalMixin, Model):
    """Categorical policy for image observations, written with functional activations."""

    def __init__(self, observation_space, action_space, device, unnormalized_log_prob=True):
        # the Model base class constructor must run before the mixin's constructor
        Model.__init__(self, observation_space, action_space, device)
        CategoricalMixin.__init__(self, unnormalized_log_prob)

        # convolutional feature extractor (expects a 3-channel image)
        self.conv1 = nn.Conv2d(3, 32, kernel_size=8, stride=4)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2)
        self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1)
        # fully-connected head ending in one logit per discrete action
        self.fc1 = nn.Linear(1024, 512)
        self.fc2 = nn.Linear(512, 16)
        self.fc3 = nn.Linear(16, 64)
        self.fc4 = nn.Linear(64, 32)
        self.fc5 = nn.Linear(32, self.num_actions)

    def compute(self, inputs, role):
        # permute (samples, width * height * channels) -> (samples, channels, width, height)
        x = inputs["states"].view(-1, *self.observation_space.shape).permute(0, 3, 1, 2)
        # ReLU-activated convolutional stack
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        x = torch.flatten(x, start_dim=1)
        # fully-connected head: one ReLU layer then tanh-activated layers
        x = F.relu(self.fc1(x))
        x = torch.tanh(self.fc2(x))
        x = torch.tanh(self.fc3(x))
        x = torch.tanh(self.fc4(x))
        return self.fc5(x), {}


# instantiate the model (assumes there is a wrapped environment: env)
policy = CNN(observation_space=env.observation_space,
             action_space=env.action_space,
             device=env.device,
             unnormalized_log_prob=True)
where:
The following points are relevant in the definition of recurrent models:
The `.get_specification()` method must be overwritten to return, under a dictionary key `"rnn"`, a sub-dictionary that includes the sequence length (under key `"sequence_length"`) as a number and a list of the dimensions (under key `"sizes"`) of each initial hidden state.
The `.compute()` method's `inputs` parameter will have, at least, the following items in the dictionary:
- `"states"`: state of the environment used to make the decision
- `"taken_actions"`: actions taken by the policy for the given states, if applicable
- `"terminated"`: episode termination status for sampled environment transitions. This key is only defined during the training process
- `"rnn"`: list of initial hidden states ordered according to the model specification
The `.compute()` method must include, under the `"rnn"` key of the returned dictionary, a list of each final hidden state.
1import torch
2import torch.nn as nn
3
4from skrl.models.torch import Model, CategoricalMixin
5
6
# define the model
class RNN(CategoricalMixin, Model):
    """Categorical (discrete-action) policy whose features are extracted by an RNN.

    The recurrent layer consumes sequences of observations; its hidden state is
    carried between calls through the "rnn" entry of the inputs/outputs
    dictionaries, as required by skrl's recurrent-model API.
    """

    def __init__(self, observation_space, action_space, device, unnormalized_log_prob=True,
                 num_envs=1, num_layers=1, hidden_size=64, sequence_length=10):
        # the Model base class constructor must be invoked before the mixin's constructor
        Model.__init__(self, observation_space, action_space, device)
        CategoricalMixin.__init__(self, unnormalized_log_prob)

        self.num_envs = num_envs
        self.num_layers = num_layers
        self.hidden_size = hidden_size  # Hout
        self.sequence_length = sequence_length

        self.rnn = nn.RNN(input_size=self.num_observations,
                          hidden_size=self.hidden_size,
                          num_layers=self.num_layers,
                          batch_first=True)  # batch_first -> (batch, sequence, features)

        # MLP head mapping the RNN features to per-action logits
        self.net = nn.Sequential(nn.Linear(self.hidden_size, 64),
                                 nn.ReLU(),
                                 nn.Linear(64, 32),
                                 nn.ReLU(),
                                 nn.Linear(32, self.num_actions))

    def get_specification(self):
        """Describe the recurrent state so skrl can allocate and feed the hidden states."""
        # batch size (N) is the number of envs during rollout
        return {"rnn": {"sequence_length": self.sequence_length,
                        "sizes": [(self.num_layers, self.num_envs, self.hidden_size)]}}  # hidden states (D ∗ num_layers, N, Hout)

    def compute(self, inputs, role):
        """Return the per-action logits and, under "rnn", the final hidden states."""
        states = inputs["states"]
        terminated = inputs.get("terminated", None)  # only defined during the training process
        hidden_states = inputs["rnn"][0]

        # training
        if self.training:
            rnn_input = states.view(-1, self.sequence_length, states.shape[-1])  # (N, L, Hin): N=batch_size, L=sequence_length
            hidden_states = hidden_states.view(self.num_layers, -1, self.sequence_length, hidden_states.shape[-1])  # (D * num_layers, N, L, Hout)
            # get the hidden states corresponding to the initial sequence
            hidden_states = hidden_states[:,:,0,:].contiguous()  # (D * num_layers, N, Hout)

            # reset the RNN state in the middle of a sequence
            if terminated is not None and torch.any(terminated):
                rnn_outputs = []
                terminated = terminated.view(-1, self.sequence_length)
                # split points: positions right after any in-sequence termination
                indexes = [0] + (terminated[:,:-1].any(dim=0).nonzero(as_tuple=True)[0] + 1).tolist() + [self.sequence_length]

                for i in range(len(indexes) - 1):
                    i0, i1 = indexes[i], indexes[i + 1]
                    rnn_output, hidden_states = self.rnn(rnn_input[:,i0:i1,:], hidden_states)
                    # zero the hidden states of the environments that terminated at this split point
                    hidden_states[:, (terminated[:,i1-1]), :] = 0
                    rnn_outputs.append(rnn_output)

                rnn_output = torch.cat(rnn_outputs, dim=1)
            # no need to reset the RNN state in the sequence
            else:
                rnn_output, hidden_states = self.rnn(rnn_input, hidden_states)
        # rollout
        else:
            rnn_input = states.view(-1, 1, states.shape[-1])  # (N, L, Hin): N=num_envs, L=1
            rnn_output, hidden_states = self.rnn(rnn_input, hidden_states)

        # flatten the RNN output
        rnn_output = torch.flatten(rnn_output, start_dim=0, end_dim=1)  # (N, L, D ∗ Hout) -> (N * L, D ∗ Hout)

        return self.net(rnn_output), {"rnn": [hidden_states]}


# instantiate the model (assumes there is a wrapped environment: env)
policy = RNN(observation_space=env.observation_space,
             action_space=env.action_space,
             device=env.device,
             unnormalized_log_prob=True,
             num_envs=env.num_envs,
             num_layers=1,
             hidden_size=64,
             sequence_length=10)
1import torch
2import torch.nn as nn
3import torch.nn.functional as F
4
5from skrl.models.torch import Model, CategoricalMixin
6
7
# define the model
class RNN(CategoricalMixin, Model):
    """Categorical (discrete-action) policy with an RNN feature extractor and a
    functional-style MLP head producing per-action logits.
    """

    def __init__(self, observation_space, action_space, device, unnormalized_log_prob=True,
                 num_envs=1, num_layers=1, hidden_size=64, sequence_length=10):
        # the Model base class constructor must be invoked before the mixin's constructor
        Model.__init__(self, observation_space, action_space, device)
        CategoricalMixin.__init__(self, unnormalized_log_prob)

        self.num_envs = num_envs
        self.num_layers = num_layers
        self.hidden_size = hidden_size  # Hout
        self.sequence_length = sequence_length

        self.rnn = nn.RNN(input_size=self.num_observations,
                          hidden_size=self.hidden_size,
                          num_layers=self.num_layers,
                          batch_first=True)  # batch_first -> (batch, sequence, features)

        # MLP head mapping the RNN features to per-action logits
        self.fc1 = nn.Linear(self.hidden_size, 64)
        self.fc2 = nn.Linear(64, 32)
        self.logits = nn.Linear(32, self.num_actions)

    def get_specification(self):
        """Describe the recurrent state so skrl can allocate and feed the hidden states."""
        # batch size (N) is the number of envs during rollout
        return {"rnn": {"sequence_length": self.sequence_length,
                        "sizes": [(self.num_layers, self.num_envs, self.hidden_size)]}}  # hidden states (D ∗ num_layers, N, Hout)

    def compute(self, inputs, role):
        """Return the per-action logits and, under "rnn", the final hidden states."""
        states = inputs["states"]
        terminated = inputs.get("terminated", None)  # only defined during the training process
        hidden_states = inputs["rnn"][0]

        # training
        if self.training:
            rnn_input = states.view(-1, self.sequence_length, states.shape[-1])  # (N, L, Hin): N=batch_size, L=sequence_length
            hidden_states = hidden_states.view(self.num_layers, -1, self.sequence_length, hidden_states.shape[-1])  # (D * num_layers, N, L, Hout)
            # get the hidden states corresponding to the initial sequence
            hidden_states = hidden_states[:,:,0,:].contiguous()  # (D * num_layers, N, Hout)

            # reset the RNN state in the middle of a sequence
            if terminated is not None and torch.any(terminated):
                rnn_outputs = []
                terminated = terminated.view(-1, self.sequence_length)
                # split points: positions right after any in-sequence termination
                indexes = [0] + (terminated[:,:-1].any(dim=0).nonzero(as_tuple=True)[0] + 1).tolist() + [self.sequence_length]

                for i in range(len(indexes) - 1):
                    i0, i1 = indexes[i], indexes[i + 1]
                    rnn_output, hidden_states = self.rnn(rnn_input[:,i0:i1,:], hidden_states)
                    # zero the hidden states of the environments that terminated at this split point
                    hidden_states[:, (terminated[:,i1-1]), :] = 0
                    rnn_outputs.append(rnn_output)

                rnn_output = torch.cat(rnn_outputs, dim=1)
            # no need to reset the RNN state in the sequence
            else:
                rnn_output, hidden_states = self.rnn(rnn_input, hidden_states)
        # rollout
        else:
            rnn_input = states.view(-1, 1, states.shape[-1])  # (N, L, Hin): N=num_envs, L=1
            rnn_output, hidden_states = self.rnn(rnn_input, hidden_states)

        # flatten the RNN output
        rnn_output = torch.flatten(rnn_output, start_dim=0, end_dim=1)  # (N, L, D ∗ Hout) -> (N * L, D ∗ Hout)

        x = self.fc1(rnn_output)
        x = F.relu(x)
        x = self.fc2(x)
        x = F.relu(x)

        return self.logits(x), {"rnn": [hidden_states]}


# instantiate the model (assumes there is a wrapped environment: env)
policy = RNN(observation_space=env.observation_space,
             action_space=env.action_space,
             device=env.device,
             unnormalized_log_prob=True,
             num_envs=env.num_envs,
             num_layers=1,
             hidden_size=64,
             sequence_length=10)
where:
The following points are relevant in the definition of recurrent models:
The `.get_specification()` method must be overwritten to return, under a dictionary key `"rnn"`, a sub-dictionary that includes the sequence length (under key `"sequence_length"`) as a number and a list of the dimensions (under key `"sizes"`) of each initial hidden state.
The `.compute()` method's `inputs` parameter will have, at least, the following items in the dictionary:
- `"states"`: state of the environment used to make the decision
- `"taken_actions"`: actions taken by the policy for the given states, if applicable
- `"terminated"`: episode termination status for sampled environment transitions. This key is only defined during the training process
- `"rnn"`: list of initial hidden states ordered according to the model specification
The `.compute()` method must include, under the `"rnn"` key of the returned dictionary, a list of each final hidden state.
1import torch
2import torch.nn as nn
3
4from skrl.models.torch import Model, CategoricalMixin
5
6
# define the model
class GRU(CategoricalMixin, Model):
    """Categorical (discrete-action) policy whose features are extracted by a GRU.

    The recurrent layer consumes sequences of observations; its hidden state is
    carried between calls through the "rnn" entry of the inputs/outputs
    dictionaries, as required by skrl's recurrent-model API.
    """

    def __init__(self, observation_space, action_space, device, unnormalized_log_prob=True,
                 num_envs=1, num_layers=1, hidden_size=64, sequence_length=10):
        # the Model base class constructor must be invoked before the mixin's constructor
        Model.__init__(self, observation_space, action_space, device)
        CategoricalMixin.__init__(self, unnormalized_log_prob)

        self.num_envs = num_envs
        self.num_layers = num_layers
        self.hidden_size = hidden_size  # Hout
        self.sequence_length = sequence_length

        self.gru = nn.GRU(input_size=self.num_observations,
                          hidden_size=self.hidden_size,
                          num_layers=self.num_layers,
                          batch_first=True)  # batch_first -> (batch, sequence, features)

        # MLP head mapping the GRU features to per-action logits
        self.net = nn.Sequential(nn.Linear(self.hidden_size, 64),
                                 nn.ReLU(),
                                 nn.Linear(64, 32),
                                 nn.ReLU(),
                                 nn.Linear(32, self.num_actions))

    def get_specification(self):
        """Describe the recurrent state so skrl can allocate and feed the hidden states."""
        # batch size (N) is the number of envs during rollout
        return {"rnn": {"sequence_length": self.sequence_length,
                        "sizes": [(self.num_layers, self.num_envs, self.hidden_size)]}}  # hidden states (D ∗ num_layers, N, Hout)

    def compute(self, inputs, role):
        """Return the per-action logits and, under "rnn", the final hidden states."""
        states = inputs["states"]
        terminated = inputs.get("terminated", None)  # only defined during the training process
        hidden_states = inputs["rnn"][0]

        # training
        if self.training:
            rnn_input = states.view(-1, self.sequence_length, states.shape[-1])  # (N, L, Hin): N=batch_size, L=sequence_length
            hidden_states = hidden_states.view(self.num_layers, -1, self.sequence_length, hidden_states.shape[-1])  # (D * num_layers, N, L, Hout)
            # get the hidden states corresponding to the initial sequence
            hidden_states = hidden_states[:,:,0,:].contiguous()  # (D * num_layers, N, Hout)

            # reset the RNN state in the middle of a sequence
            if terminated is not None and torch.any(terminated):
                rnn_outputs = []
                terminated = terminated.view(-1, self.sequence_length)
                # split points: positions right after any in-sequence termination
                indexes = [0] + (terminated[:,:-1].any(dim=0).nonzero(as_tuple=True)[0] + 1).tolist() + [self.sequence_length]

                for i in range(len(indexes) - 1):
                    i0, i1 = indexes[i], indexes[i + 1]
                    rnn_output, hidden_states = self.gru(rnn_input[:,i0:i1,:], hidden_states)
                    # zero the hidden states of the environments that terminated at this split point
                    hidden_states[:, (terminated[:,i1-1]), :] = 0
                    rnn_outputs.append(rnn_output)

                rnn_output = torch.cat(rnn_outputs, dim=1)
            # no need to reset the RNN state in the sequence
            else:
                rnn_output, hidden_states = self.gru(rnn_input, hidden_states)
        # rollout
        else:
            rnn_input = states.view(-1, 1, states.shape[-1])  # (N, L, Hin): N=num_envs, L=1
            rnn_output, hidden_states = self.gru(rnn_input, hidden_states)

        # flatten the RNN output
        rnn_output = torch.flatten(rnn_output, start_dim=0, end_dim=1)  # (N, L, D ∗ Hout) -> (N * L, D ∗ Hout)

        return self.net(rnn_output), {"rnn": [hidden_states]}


# instantiate the model (assumes there is a wrapped environment: env)
policy = GRU(observation_space=env.observation_space,
             action_space=env.action_space,
             device=env.device,
             unnormalized_log_prob=True,
             num_envs=env.num_envs,
             num_layers=1,
             hidden_size=64,
             sequence_length=10)
1import torch
2import torch.nn as nn
3import torch.nn.functional as F
4
5from skrl.models.torch import Model, CategoricalMixin
6
7
# define the model
class GRU(CategoricalMixin, Model):
    """Categorical (discrete-action) policy with a GRU feature extractor and a
    functional-style MLP head producing per-action logits.
    """

    def __init__(self, observation_space, action_space, device, unnormalized_log_prob=True,
                 num_envs=1, num_layers=1, hidden_size=64, sequence_length=10):
        # the Model base class constructor must be invoked before the mixin's constructor
        Model.__init__(self, observation_space, action_space, device)
        CategoricalMixin.__init__(self, unnormalized_log_prob)

        self.num_envs = num_envs
        self.num_layers = num_layers
        self.hidden_size = hidden_size  # Hout
        self.sequence_length = sequence_length

        self.gru = nn.GRU(input_size=self.num_observations,
                          hidden_size=self.hidden_size,
                          num_layers=self.num_layers,
                          batch_first=True)  # batch_first -> (batch, sequence, features)

        # MLP head mapping the GRU features to per-action logits
        self.fc1 = nn.Linear(self.hidden_size, 64)
        self.fc2 = nn.Linear(64, 32)
        self.logits = nn.Linear(32, self.num_actions)

    def get_specification(self):
        """Describe the recurrent state so skrl can allocate and feed the hidden states."""
        # batch size (N) is the number of envs during rollout
        return {"rnn": {"sequence_length": self.sequence_length,
                        "sizes": [(self.num_layers, self.num_envs, self.hidden_size)]}}  # hidden states (D ∗ num_layers, N, Hout)

    def compute(self, inputs, role):
        """Return the per-action logits and, under "rnn", the final hidden states."""
        states = inputs["states"]
        terminated = inputs.get("terminated", None)  # only defined during the training process
        hidden_states = inputs["rnn"][0]

        # training
        if self.training:
            rnn_input = states.view(-1, self.sequence_length, states.shape[-1])  # (N, L, Hin): N=batch_size, L=sequence_length
            hidden_states = hidden_states.view(self.num_layers, -1, self.sequence_length, hidden_states.shape[-1])  # (D * num_layers, N, L, Hout)
            # get the hidden states corresponding to the initial sequence
            hidden_states = hidden_states[:,:,0,:].contiguous()  # (D * num_layers, N, Hout)

            # reset the RNN state in the middle of a sequence
            if terminated is not None and torch.any(terminated):
                rnn_outputs = []
                terminated = terminated.view(-1, self.sequence_length)
                # split points: positions right after any in-sequence termination
                indexes = [0] + (terminated[:,:-1].any(dim=0).nonzero(as_tuple=True)[0] + 1).tolist() + [self.sequence_length]

                for i in range(len(indexes) - 1):
                    i0, i1 = indexes[i], indexes[i + 1]
                    rnn_output, hidden_states = self.gru(rnn_input[:,i0:i1,:], hidden_states)
                    # zero the hidden states of the environments that terminated at this split point
                    hidden_states[:, (terminated[:,i1-1]), :] = 0
                    rnn_outputs.append(rnn_output)

                rnn_output = torch.cat(rnn_outputs, dim=1)
            # no need to reset the RNN state in the sequence
            else:
                rnn_output, hidden_states = self.gru(rnn_input, hidden_states)
        # rollout
        else:
            rnn_input = states.view(-1, 1, states.shape[-1])  # (N, L, Hin): N=num_envs, L=1
            rnn_output, hidden_states = self.gru(rnn_input, hidden_states)

        # flatten the RNN output
        rnn_output = torch.flatten(rnn_output, start_dim=0, end_dim=1)  # (N, L, D ∗ Hout) -> (N * L, D ∗ Hout)

        x = self.fc1(rnn_output)
        x = F.relu(x)
        x = self.fc2(x)
        x = F.relu(x)

        return self.logits(x), {"rnn": [hidden_states]}


# instantiate the model (assumes there is a wrapped environment: env)
policy = GRU(observation_space=env.observation_space,
             action_space=env.action_space,
             device=env.device,
             unnormalized_log_prob=True,
             num_envs=env.num_envs,
             num_layers=1,
             hidden_size=64,
             sequence_length=10)
where:
The following points are relevant in the definition of recurrent models:
The `.get_specification()` method must be overwritten to return, under a dictionary key `"rnn"`, a sub-dictionary that includes the sequence length (under key `"sequence_length"`) as a number and a list of the dimensions (under key `"sizes"`) of each initial hidden/cell state.
The `.compute()` method's `inputs` parameter will have, at least, the following items in the dictionary:
- `"states"`: state of the environment used to make the decision
- `"taken_actions"`: actions taken by the policy for the given states, if applicable
- `"terminated"`: episode termination status for sampled environment transitions. This key is only defined during the training process
- `"rnn"`: list of initial hidden/cell states ordered according to the model specification
The `.compute()` method must include, under the `"rnn"` key of the returned dictionary, a list of each final hidden/cell state.
1import torch
2import torch.nn as nn
3
4from skrl.models.torch import Model, CategoricalMixin
5
6
# define the model
class LSTM(CategoricalMixin, Model):
    """Categorical (discrete-action) policy whose features are extracted by an LSTM.

    Unlike the plain RNN/GRU variants, the LSTM carries both hidden and cell
    states; both are exchanged through the "rnn" entry of the inputs/outputs
    dictionaries, as required by skrl's recurrent-model API.
    """

    def __init__(self, observation_space, action_space, device, unnormalized_log_prob=True,
                 num_envs=1, num_layers=1, hidden_size=64, sequence_length=10):
        # the Model base class constructor must be invoked before the mixin's constructor
        Model.__init__(self, observation_space, action_space, device)
        CategoricalMixin.__init__(self, unnormalized_log_prob)

        self.num_envs = num_envs
        self.num_layers = num_layers
        self.hidden_size = hidden_size  # Hcell (Hout is Hcell because proj_size = 0)
        self.sequence_length = sequence_length

        self.lstm = nn.LSTM(input_size=self.num_observations,
                            hidden_size=self.hidden_size,
                            num_layers=self.num_layers,
                            batch_first=True)  # batch_first -> (batch, sequence, features)

        # MLP head mapping the LSTM features to per-action logits
        self.net = nn.Sequential(nn.Linear(self.hidden_size, 64),
                                 nn.ReLU(),
                                 nn.Linear(64, 32),
                                 nn.ReLU(),
                                 nn.Linear(32, self.num_actions))

    def get_specification(self):
        """Describe the recurrent state so skrl can allocate and feed the hidden/cell states."""
        # batch size (N) is the number of envs during rollout
        return {"rnn": {"sequence_length": self.sequence_length,
                        "sizes": [(self.num_layers, self.num_envs, self.hidden_size),    # hidden states (D ∗ num_layers, N, Hout)
                                  (self.num_layers, self.num_envs, self.hidden_size)]}}  # cell states (D ∗ num_layers, N, Hcell)

    def compute(self, inputs, role):
        """Return the per-action logits and, under "rnn", the final hidden and cell states."""
        states = inputs["states"]
        terminated = inputs.get("terminated", None)  # only defined during the training process
        hidden_states, cell_states = inputs["rnn"][0], inputs["rnn"][1]

        # training
        if self.training:
            rnn_input = states.view(-1, self.sequence_length, states.shape[-1])  # (N, L, Hin): N=batch_size, L=sequence_length
            hidden_states = hidden_states.view(self.num_layers, -1, self.sequence_length, hidden_states.shape[-1])  # (D * num_layers, N, L, Hout)
            cell_states = cell_states.view(self.num_layers, -1, self.sequence_length, cell_states.shape[-1])  # (D * num_layers, N, L, Hcell)
            # get the hidden/cell states corresponding to the initial sequence
            hidden_states = hidden_states[:,:,0,:].contiguous()  # (D * num_layers, N, Hout)
            cell_states = cell_states[:,:,0,:].contiguous()  # (D * num_layers, N, Hcell)

            # reset the RNN state in the middle of a sequence
            if terminated is not None and torch.any(terminated):
                rnn_outputs = []
                terminated = terminated.view(-1, self.sequence_length)
                # split points: positions right after any in-sequence termination
                indexes = [0] + (terminated[:,:-1].any(dim=0).nonzero(as_tuple=True)[0] + 1).tolist() + [self.sequence_length]

                for i in range(len(indexes) - 1):
                    i0, i1 = indexes[i], indexes[i + 1]
                    rnn_output, (hidden_states, cell_states) = self.lstm(rnn_input[:,i0:i1,:], (hidden_states, cell_states))
                    # zero both states of the environments that terminated at this split point
                    hidden_states[:, (terminated[:,i1-1]), :] = 0
                    cell_states[:, (terminated[:,i1-1]), :] = 0
                    rnn_outputs.append(rnn_output)

                rnn_states = (hidden_states, cell_states)
                rnn_output = torch.cat(rnn_outputs, dim=1)
            # no need to reset the RNN state in the sequence
            else:
                rnn_output, rnn_states = self.lstm(rnn_input, (hidden_states, cell_states))
        # rollout
        else:
            rnn_input = states.view(-1, 1, states.shape[-1])  # (N, L, Hin): N=num_envs, L=1
            rnn_output, rnn_states = self.lstm(rnn_input, (hidden_states, cell_states))

        # flatten the RNN output
        rnn_output = torch.flatten(rnn_output, start_dim=0, end_dim=1)  # (N, L, D ∗ Hout) -> (N * L, D ∗ Hout)

        return self.net(rnn_output), {"rnn": [rnn_states[0], rnn_states[1]]}


# instantiate the model (assumes there is a wrapped environment: env)
policy = LSTM(observation_space=env.observation_space,
              action_space=env.action_space,
              device=env.device,
              unnormalized_log_prob=True,
              num_envs=env.num_envs,
              num_layers=1,
              hidden_size=64,
              sequence_length=10)
1import torch
2import torch.nn as nn
3import torch.nn.functional as F
4
5from skrl.models.torch import Model, CategoricalMixin
6
7
# define the model
class LSTM(CategoricalMixin, Model):
    """Categorical (discrete-action) policy with an LSTM feature extractor and a
    functional-style MLP head producing per-action logits.

    The LSTM carries both hidden and cell states; both are exchanged through the
    "rnn" entry of the inputs/outputs dictionaries.
    """

    def __init__(self, observation_space, action_space, device, unnormalized_log_prob=True,
                 num_envs=1, num_layers=1, hidden_size=64, sequence_length=10):
        # the Model base class constructor must be invoked before the mixin's constructor
        Model.__init__(self, observation_space, action_space, device)
        CategoricalMixin.__init__(self, unnormalized_log_prob)

        self.num_envs = num_envs
        self.num_layers = num_layers
        self.hidden_size = hidden_size  # Hcell (Hout is Hcell because proj_size = 0)
        self.sequence_length = sequence_length

        self.lstm = nn.LSTM(input_size=self.num_observations,
                            hidden_size=self.hidden_size,
                            num_layers=self.num_layers,
                            batch_first=True)  # batch_first -> (batch, sequence, features)

        # MLP head mapping the LSTM features to per-action logits
        self.fc1 = nn.Linear(self.hidden_size, 64)
        self.fc2 = nn.Linear(64, 32)
        self.logits = nn.Linear(32, self.num_actions)

    def get_specification(self):
        """Describe the recurrent state so skrl can allocate and feed the hidden/cell states."""
        # batch size (N) is the number of envs during rollout
        return {"rnn": {"sequence_length": self.sequence_length,
                        "sizes": [(self.num_layers, self.num_envs, self.hidden_size),    # hidden states (D ∗ num_layers, N, Hout)
                                  (self.num_layers, self.num_envs, self.hidden_size)]}}  # cell states (D ∗ num_layers, N, Hcell)

    def compute(self, inputs, role):
        """Return the per-action logits and, under "rnn", the final hidden and cell states."""
        states = inputs["states"]
        terminated = inputs.get("terminated", None)  # only defined during the training process
        hidden_states, cell_states = inputs["rnn"][0], inputs["rnn"][1]

        # training
        if self.training:
            rnn_input = states.view(-1, self.sequence_length, states.shape[-1])  # (N, L, Hin): N=batch_size, L=sequence_length
            hidden_states = hidden_states.view(self.num_layers, -1, self.sequence_length, hidden_states.shape[-1])  # (D * num_layers, N, L, Hout)
            cell_states = cell_states.view(self.num_layers, -1, self.sequence_length, cell_states.shape[-1])  # (D * num_layers, N, L, Hcell)
            # get the hidden/cell states corresponding to the initial sequence
            hidden_states = hidden_states[:,:,0,:].contiguous()  # (D * num_layers, N, Hout)
            cell_states = cell_states[:,:,0,:].contiguous()  # (D * num_layers, N, Hcell)

            # reset the RNN state in the middle of a sequence
            if terminated is not None and torch.any(terminated):
                rnn_outputs = []
                terminated = terminated.view(-1, self.sequence_length)
                # split points: positions right after any in-sequence termination
                indexes = [0] + (terminated[:,:-1].any(dim=0).nonzero(as_tuple=True)[0] + 1).tolist() + [self.sequence_length]

                for i in range(len(indexes) - 1):
                    i0, i1 = indexes[i], indexes[i + 1]
                    rnn_output, (hidden_states, cell_states) = self.lstm(rnn_input[:,i0:i1,:], (hidden_states, cell_states))
                    # zero both states of the environments that terminated at this split point
                    hidden_states[:, (terminated[:,i1-1]), :] = 0
                    cell_states[:, (terminated[:,i1-1]), :] = 0
                    rnn_outputs.append(rnn_output)

                rnn_states = (hidden_states, cell_states)
                rnn_output = torch.cat(rnn_outputs, dim=1)
            # no need to reset the RNN state in the sequence
            else:
                rnn_output, rnn_states = self.lstm(rnn_input, (hidden_states, cell_states))
        # rollout
        else:
            rnn_input = states.view(-1, 1, states.shape[-1])  # (N, L, Hin): N=num_envs, L=1
            rnn_output, rnn_states = self.lstm(rnn_input, (hidden_states, cell_states))

        # flatten the RNN output
        rnn_output = torch.flatten(rnn_output, start_dim=0, end_dim=1)  # (N, L, D ∗ Hout) -> (N * L, D ∗ Hout)

        x = self.fc1(rnn_output)
        x = F.relu(x)
        x = self.fc2(x)
        x = F.relu(x)

        return self.logits(x), {"rnn": [rnn_states[0], rnn_states[1]]}


# instantiate the model (assumes there is a wrapped environment: env)
policy = LSTM(observation_space=env.observation_space,
              action_space=env.action_space,
              device=env.device,
              unnormalized_log_prob=True,
              num_envs=env.num_envs,
              num_layers=1,
              hidden_size=64,
              sequence_length=10)
API
- class skrl.models.torch.categorical.CategoricalMixin(unnormalized_log_prob: bool = True, role: str = '')
Bases:
object
- __init__(unnormalized_log_prob: bool = True, role: str = '') None
Categorical mixin model (stochastic model)
- Parameters
unnormalized_log_prob (bool, optional) – Flag to indicate how to be interpreted the model’s output (default:
True
). If True, the model's output is interpreted as unnormalized log probabilities (it can be any real number), otherwise as normalized probabilities (the output must be non-negative, finite and have a non-zero sum).
role (str, optional) – Role played by the model (default:
""
)
Example:
>>> # define the model
>>> import torch
>>> import torch.nn as nn
>>> from skrl.models.torch import Model, CategoricalMixin
>>>
>>> class Policy(CategoricalMixin, Model):
...     def __init__(self, observation_space, action_space, device="cuda:0", unnormalized_log_prob=True):
...         Model.__init__(self, observation_space, action_space, device)
...         CategoricalMixin.__init__(self, unnormalized_log_prob)
...
...         self.net = nn.Sequential(nn.Linear(self.num_observations, 32),
...                                  nn.ELU(),
...                                  nn.Linear(32, 32),
...                                  nn.ELU(),
...                                  nn.Linear(32, self.num_actions))
...
...     def compute(self, inputs, role):
...         return self.net(inputs["states"]), {}
...
>>> # given an observation_space: gym.spaces.Box with shape (4,)
>>> # and an action_space: gym.spaces.Discrete with n = 2
>>> model = Policy(observation_space, action_space)
>>>
>>> print(model)
Policy(
  (net): Sequential(
    (0): Linear(in_features=4, out_features=32, bias=True)
    (1): ELU(alpha=1.0)
    (2): Linear(in_features=32, out_features=32, bias=True)
    (3): ELU(alpha=1.0)
    (4): Linear(in_features=32, out_features=2, bias=True)
  )
)
- act(inputs: Mapping[str, Union[torch.Tensor, Any]], role: str = '') Tuple[torch.Tensor, Optional[torch.Tensor], Mapping[str, Union[torch.Tensor, Any]]]
Act stochastically in response to the state of the environment
- Parameters
inputs (dict where the values are typically torch.Tensor) –
Model inputs. The most common keys are:
"states"
: state of the environment used to make the decision"taken_actions"
: actions taken by the policy for the given states
role (str, optional) – Role played by the model (default:
""
)
- Returns
Model output. The first component is the action to be taken by the agent. The second component is the log of the probability density function. The third component is a dictionary containing the network output
"net_output"
and extra output values- Return type
tuple of torch.Tensor, torch.Tensor or None, and dictionary
Example:
>>> # given a batch of sample states with shape (4096, 4)
>>> actions, log_prob, outputs = model.act({"states": states})
>>> print(actions.shape, log_prob.shape, outputs["net_output"].shape)
torch.Size([4096, 1]) torch.Size([4096, 1]) torch.Size([4096, 2])
- distribution(role: str = '') torch.distributions.categorical.Categorical
Get the current distribution of the model
- Returns
Distribution of the model
- Return type
torch.distributions.Categorical
- Parameters
role (str, optional) – Role played by the model (default:
""
)
Example:
>>> distribution = model.distribution() >>> print(distribution) Categorical(probs: torch.Size([4096, 2]), logits: torch.Size([4096, 2]))