Adversarial Motion Priors (AMP)

AMP is a model-free, stochastic on-policy policy gradient algorithm (trained using a combination of GAIL and PPO) for adversarial learning of physics-based character animation. It enables characters to imitate diverse behaviors from large unstructured datasets, without the need for motion planners or other mechanisms for clip selection

Paper: AMP: Adversarial Motion Priors for Stylized Physics-Based Character Control

Algorithm implementation

Main notation/symbols:
- policy (\(\pi_\theta\)), value (\(V_\phi\)) and discriminator (\(D_\psi\)) function approximators
- states (\(s\)), actions (\(a\)), rewards (\(r\)), next states (\(s'\)), dones (\(d\))
- values (\(V\)), next values (\(V'\)), advantages (\(A\)), returns (\(R\))
- log probabilities (\(logp\))
- loss (\(L\))
- reference motion dataset (\(M\)), AMP replay buffer (\(B\))
- AMP states (\(s_{_{AMP}}\)), reference motion states (\(s_{_{AMP}}^{^M}\)), AMP states from replay buffer (\(s_{_{AMP}}^{^B}\))

Learning algorithm (_update(...))

compute_gae(...)
def \(\;f_{GAE} (r, d, V, V') \;\rightarrow\; R, A:\)
\(adv \leftarrow 0\)
\(A \leftarrow \text{zeros}(r)\)
# advantages computation
FOR each reverse iteration \(i\) up to the number of rows in \(r\) DO
\(adv \leftarrow r_i - V_i \, +\) discount_factor \(\neg d_i \, (V'_i \, +\) lambda \(\; adv)\)
\(A_i \leftarrow adv\)
# returns computation
\(R \leftarrow A + V\)
# normalize advantages
\(A \leftarrow \dfrac{A - \bar{A}}{A_\sigma + 10^{-8}}\)
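
For reference, the \(f_{GAE}\) helper above corresponds to a few lines of PyTorch. The following is an illustrative sketch rather than the library's internal code; tensor names, shapes and the per-step next values \(V'_i\) are assumptions:

import torch

def compute_gae(rewards, dones, values, next_values,
                discount_factor=0.99, lambda_coefficient=0.95):
    # rewards, dones, values, next_values: tensors of shape (memory_size, num_envs, 1)
    advantage = 0
    advantages = torch.zeros_like(rewards)
    not_dones = dones.logical_not().to(rewards.dtype)

    # advantages computation (reverse iteration over the rollout)
    for i in reversed(range(rewards.shape[0])):
        advantage = rewards[i] - values[i] \
            + discount_factor * not_dones[i] * (next_values[i] + lambda_coefficient * advantage)
        advantages[i] = advantage

    # returns computation
    returns = advantages + values
    # normalize advantages
    advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
    return returns, advantages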
# update dataset of reference motions
collect reference motions of size amp_batch_size \(\rightarrow\;\) \(\text{append}(M)\)
# compute combined rewards
\(r_D \leftarrow -log(\text{max}( 1 - \hat{y}(D_\psi(s_{_{AMP}})), \, 10^{-4})) \qquad\) with \(\; \hat{y}(x) = \dfrac{1}{1 + e^{-x}}\)
\(r' \leftarrow\) task_reward_weight \(r \, +\) style_reward_weight discriminator_reward_scale \(r_D\)
# compute returns and advantages
\(R, A \leftarrow f_{GAE}(r', d, V, V')\)
# sample mini-batches from memory
[[\(s, a, logp, V, R, A, s_{_{AMP}}\)]] \(\leftarrow\) states, actions, log_prob, values, returns, advantages, AMP states
[[\(s_{_{AMP}}^{^M}\)]] \(\leftarrow\) AMP states from \(M\)
IF \(B\) is not empty THEN
[[\(s_{_{AMP}}^{^B}\)]] \(\leftarrow\) AMP states from \(B\)
ELSE
[[\(s_{_{AMP}}^{^B}\)]] \(\leftarrow\) [[\(s_{_{AMP}}\)]]
# learning epochs
FOR each learning epoch up to learning_epochs DO
# mini-batches loop
FOR each mini-batch [\(s, a, logp, V, R, A, s_{_{AMP}}, s_{_{AMP}}^{^B}, s_{_{AMP}}^{^M}\)] up to mini_batches DO
\(logp' \leftarrow \pi_\theta(s, a)\)
# compute entropy loss
IF entropy computation is enabled THEN
\({L}_{entropy} \leftarrow \, -\) entropy_loss_scale \(\frac{1}{N} \sum_{i=1}^N \pi_{\theta_{entropy}}\)
ELSE
\({L}_{entropy} \leftarrow 0\)
# compute policy loss
\(ratio \leftarrow e^{logp' - logp}\)
\(L_{_{surrogate}} \leftarrow A \; ratio\)
\(L_{_{clipped\,surrogate}} \leftarrow A \; \text{clip}(ratio, 1 - c, 1 + c) \qquad\) with \(c\) as ratio_clip
\(L^{clip}_{\pi_\theta} \leftarrow - \frac{1}{N} \sum_{i=1}^N \min(L_{_{surrogate}}, L_{_{clipped\,surrogate}})\)
# compute value loss
\(V_{_{predicted}} \leftarrow V_\phi(s)\)
IF clip_predicted_values is enabled THEN
\(V_{_{predicted}} \leftarrow V + \text{clip}(V_{_{predicted}} - V, -c, c) \qquad\) with \(c\) as value_clip
\(L_{V_\phi} \leftarrow\) value_loss_scale \(\frac{1}{N} \sum_{i=1}^N (R - V_{_{predicted}})^2\)
# compute discriminator loss
\({logit}_{_{AMP}} \leftarrow D_\psi(s_{_{AMP}}) \qquad\) with \(s_{_{AMP}}\) of size discriminator_batch_size
\({logit}_{_{AMP}}^{^B} \leftarrow D_\psi(s_{_{AMP}}^{^B}) \qquad\) with \(s_{_{AMP}}^{^B}\) of size discriminator_batch_size
\({logit}_{_{AMP}}^{^M} \leftarrow D_\psi(s_{_{AMP}}^{^M}) \qquad\) with \(s_{_{AMP}}^{^M}\) of size discriminator_batch_size
# discriminator prediction loss
\(L_{D_\psi} \leftarrow \dfrac{1}{2}\,(BCE({logit}_{_{AMP}} \, {+\!\!+} \, {logit}_{_{AMP}}^{^B}, \, 0) + BCE({logit}_{_{AMP}}^{^M}, \, 1)) \qquad\) with \({+\!\!+}\) denoting concatenation
with \(\; BCE(x,y)=-\frac{1}{N} \sum_{i=1}^N [y \; log(\hat{y}) + (1-y) \, log(1-\hat{y})] \;\) and \(\; \hat{y} = \dfrac{1}{1 + e^{-x}}\)
# discriminator logit regularization
\(L_{D_\psi} \leftarrow L_{D_\psi} +\) discriminator_logit_regularization_scale \(\sum_{i=1}^N \text{flatten}(\psi_w[-1])^2\)
# discriminator gradient penalty
\(L_{D_\psi} \leftarrow L_{D_\psi} +\) discriminator_gradient_penalty_scale \(\frac{1}{N} \sum_{i=1}^N \sum \big(\nabla_{s_{_{AMP}}^{^M}} {logit}_{_{AMP}}^{^M}\big)^2\)
# discriminator weight decay
\(L_{D_\psi} \leftarrow L_{D_\psi} +\) discriminator_weight_decay_scale \(\sum_{i=1}^N \text{flatten}(\psi_w)^2\)
# optimization step
reset \(\text{optimizer}_{\theta, \phi, \psi}\)
\(\nabla_{\theta, \, \phi, \, \psi} (L^{clip}_{\pi_\theta} + {L}_{entropy} + L_{V_\phi} + L_{D_\psi})\)
\(\text{clip}(\lVert \nabla_{\theta, \, \phi, \, \psi} \rVert)\) with grad_norm_clip
step \(\text{optimizer}_{\theta, \phi, \psi}\)
# update learning rate
IF there is a learning_rate_scheduler THEN
step \(\text{scheduler}_{\theta, \phi, \psi} (\text{optimizer}_{\theta, \phi, \psi})\)
# update AMP replay buffer
\(s_{_{AMP}} \rightarrow\;\) \(\text{append}(B)\)
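
The AMP-specific parts of the update, the style reward derived from the discriminator and the discriminator loss with its regularization terms, can be sketched in PyTorch as follows. Function and argument names are illustrative assumptions, not the agent's internal API; the logits are assumed to come from the discriminator \(D_\psi\) evaluated on the corresponding AMP states:

import torch
import torch.nn.functional as F

def combined_rewards(task_rewards, amp_logits, task_reward_weight=0.0,
                     style_reward_weight=1.0, discriminator_reward_scale=2):
    # style reward: r_D = -log(max(1 - sigmoid(D(s_AMP)), 1e-4))
    style_rewards = -torch.log(torch.clamp(1.0 - torch.sigmoid(amp_logits), min=1e-4))
    # r' = task_reward_weight * r + style_reward_weight * discriminator_reward_scale * r_D
    return task_reward_weight * task_rewards \
        + style_reward_weight * discriminator_reward_scale * style_rewards

def discriminator_loss(discriminator, amp_logits, replay_logits, motion_logits, motion_states,
                       logit_regularization_scale=0.05, gradient_penalty_scale=5,
                       weight_decay_scale=1e-4):
    # motion_logits must be computed as discriminator(motion_states) with
    # motion_states.requires_grad_(True) so the gradient penalty can be evaluated

    # prediction loss: agent and replay-buffer samples are "fake" (0), reference motions are "real" (1)
    fake_logits = torch.cat([amp_logits, replay_logits], dim=0)
    prediction_loss = 0.5 * (
        F.binary_cross_entropy_with_logits(fake_logits, torch.zeros_like(fake_logits))
        + F.binary_cross_entropy_with_logits(motion_logits, torch.ones_like(motion_logits)))

    # logit regularization: squared weights of the discriminator's output layer
    # (assumes the last trainable module is a torch.nn.Linear: [..., weight, bias])
    logit_weights = list(discriminator.parameters())[-2]
    logit_regularization = torch.sum(torch.square(logit_weights))

    # gradient penalty: gradient of the reference-motion logits w.r.t. their inputs
    gradient = torch.autograd.grad(motion_logits, motion_states,
                                   grad_outputs=torch.ones_like(motion_logits),
                                   create_graph=True, retain_graph=True)[0]
    gradient_penalty = torch.sum(torch.square(gradient), dim=-1).mean()

    # weight decay over the discriminator's weight matrices (biases excluded)
    weight_decay = sum(torch.sum(torch.square(w)) for w in discriminator.parameters() if w.ndim > 1)

    return (prediction_loss
            + logit_regularization_scale * logit_regularization
            + gradient_penalty_scale * gradient_penalty
            + weight_decay_scale * weight_decay)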

Configuration and hyperparameters

skrl.agents.torch.amp.amp.AMP_DEFAULT_CONFIG
AMP_DEFAULT_CONFIG = {
    "rollouts": 16,                 # number of rollouts before updating
    "learning_epochs": 6,           # number of learning epochs during each update
    "mini_batches": 2,              # number of mini batches during each learning epoch

    "discount_factor": 0.99,        # discount factor (gamma)
    "lambda": 0.95,                 # TD(lambda) coefficient (lam) for computing returns and advantages

    "learning_rate": 5e-5,                  # learning rate
    "learning_rate_scheduler": None,        # learning rate scheduler class (see torch.optim.lr_scheduler)
    "learning_rate_scheduler_kwargs": {},   # learning rate scheduler's kwargs (e.g. {"step_size": 1e-3})

    "state_preprocessor": None,             # state preprocessor class (see skrl.resources.preprocessors)
    "state_preprocessor_kwargs": {},        # state preprocessor's kwargs (e.g. {"size": env.observation_space})
    "value_preprocessor": None,             # value preprocessor class (see skrl.resources.preprocessors)
    "value_preprocessor_kwargs": {},        # value preprocessor's kwargs (e.g. {"size": 1})
    "amp_state_preprocessor": None,         # AMP state preprocessor class (see skrl.resources.preprocessors)
    "amp_state_preprocessor_kwargs": {},    # AMP state preprocessor's kwargs (e.g. {"size": env.amp_observation_space})

    "random_timesteps": 0,          # random exploration steps
    "learning_starts": 0,           # learning starts after this many steps

    "grad_norm_clip": 0.0,              # clipping coefficient for the norm of the gradients
    "ratio_clip": 0.2,                  # clipping coefficient for computing the clipped surrogate objective
    "value_clip": 0.2,                  # clipping coefficient for computing the value loss (if clip_predicted_values is True)
    "clip_predicted_values": False,     # clip predicted values during value loss computation

    "entropy_loss_scale": 0.0,          # entropy loss scaling factor
    "value_loss_scale": 2.5,            # value loss scaling factor
    "discriminator_loss_scale": 5.0,    # discriminator loss scaling factor

    "amp_batch_size": 512,                  # batch size for updating the reference motion dataset
    "task_reward_weight": 0.0,              # task-reward weight (wG)
    "style_reward_weight": 1.0,             # style-reward weight (wS)
    "discriminator_batch_size": 0,          # batch size for computing the discriminator loss (all samples if 0)
    "discriminator_reward_scale": 2,                    # discriminator reward scaling factor
    "discriminator_logit_regularization_scale": 0.05,   # logit regularization scale factor for the discriminator loss
    "discriminator_gradient_penalty_scale": 5,          # gradient penalty scaling factor for the discriminator loss
    "discriminator_weight_decay_scale": 0.0001,         # weight decay scaling factor for the discriminator loss

    "rewards_shaper": None,         # rewards shaping function: Callable(reward, timestep, timesteps) -> reward

    "experiment": {
        "directory": "",            # experiment's parent directory
        "experiment_name": "",      # experiment name
        "write_interval": 250,      # TensorBoard writing interval (timesteps)

        "checkpoint_interval": 1000,        # interval for checkpoints (timesteps)
        "store_separately": False,          # whether to store checkpoints separately

        "wandb": False,             # whether to use Weights & Biases
        "wandb_kwargs": {}          # wandb kwargs (see https://docs.wandb.ai/ref/python/init)
    }
}
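
A typical workflow is to deep-copy the default configuration and override only the entries that differ. The sketch below assumes the AMP class and AMP_DEFAULT_CONFIG are exposed by skrl.agents.torch.amp (matching the documented module path); the constructor keyword arguments should be checked against the API reference:

import copy

from skrl.agents.torch.amp import AMP, AMP_DEFAULT_CONFIG

# deep-copy the defaults so nested dicts (e.g. "experiment") are not shared, then override
cfg = copy.deepcopy(AMP_DEFAULT_CONFIG)
cfg["task_reward_weight"] = 0.5      # blend the task reward with the style reward
cfg["style_reward_weight"] = 0.5
cfg["experiment"]["write_interval"] = 100

# models, memory, env and device are defined elsewhere (see the following sections)
agent = AMP(models=models, memory=memory, cfg=cfg,
            observation_space=env.observation_space, action_space=env.action_space,
            amp_observation_space=env.amp_observation_space, device=device)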

Spaces and models

The implementation supports the following Gym spaces / Gymnasium spaces (\(\blacksquare\): supported, \(\square\): not supported)

| Gym/Gymnasium spaces | AMP observation  | Observation      | Action           |
|----------------------|------------------|------------------|------------------|
| Discrete             | \(\square\)      | \(\square\)      | \(\square\)      |
| Box                  | \(\blacksquare\) | \(\blacksquare\) | \(\blacksquare\) |
| Dict                 | \(\square\)      | \(\square\)      | \(\square\)      |

The implementation uses 1 stochastic (continuous) and 2 deterministic function approximators. These function approximators (models) must be collected in a dictionary and passed to the constructor of the class under the argument models

| Notation               | Concept       | Key             | Input shape     | Output shape | Type                            |
|------------------------|---------------|-----------------|-----------------|--------------|---------------------------------|
| \(\pi_\theta(s)\)      | Policy        | "policy"        | observation     | action       | Gaussian / MultivariateGaussian |
| \(V_\phi(s)\)          | Value         | "value"         | observation     | 1            | Deterministic                   |
| \(D_\psi(s_{_{AMP}})\) | Discriminator | "discriminator" | AMP observation | 1            | Deterministic                   |
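
As a sketch of how this dictionary can be assembled using skrl's model mixins (network sizes, layer choices and the environment attributes are placeholder assumptions):

import torch
import torch.nn as nn

from skrl.models.torch import Model, GaussianMixin, DeterministicMixin

class Policy(GaussianMixin, Model):
    def __init__(self, observation_space, action_space, device, clip_actions=False):
        Model.__init__(self, observation_space, action_space, device)
        GaussianMixin.__init__(self, clip_actions)
        self.net = nn.Sequential(nn.Linear(self.num_observations, 64), nn.ELU(),
                                 nn.Linear(64, self.num_actions))
        self.log_std_parameter = nn.Parameter(torch.zeros(self.num_actions))

    def compute(self, inputs, role):
        return self.net(inputs["states"]), self.log_std_parameter, {}

class DeterministicModel(DeterministicMixin, Model):
    def __init__(self, observation_space, action_space, device, clip_actions=False):
        Model.__init__(self, observation_space, action_space, device)
        DeterministicMixin.__init__(self, clip_actions)
        self.net = nn.Sequential(nn.Linear(self.num_observations, 64), nn.ELU(),
                                 nn.Linear(64, 1))

    def compute(self, inputs, role):
        return self.net(inputs["states"]), {}

# keys must match the table above; the discriminator takes AMP observations as input
models = {"policy": Policy(env.observation_space, env.action_space, device),
          "value": DeterministicModel(env.observation_space, env.action_space, device),
          "discriminator": DeterministicModel(env.amp_observation_space, env.action_space, device)}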

Support for advanced features is described in the next table

| Feature      | Support and remarks |
|--------------|---------------------|
| Shared model | -                   |
| RNN support  | -                   |

API