Adversarial Motion Priors (AMP)

AMP is a model-free, stochastic on-policy policy gradient algorithm (trained using a combination of GAIL and PPO) for adversarial learning of physics-based character animation. It enables characters to imitate diverse behaviors from large unstructured datasets, without the need for motion planners or other mechanisms for clip selection.

Paper: AMP: Adversarial Motion Priors for Stylized Physics-Based Character Control



Algorithm


Algorithm implementation

Main notation/symbols:
- policy (\(\pi_\theta\)), value (\(V_\phi\)) and discriminator (\(D_\psi\)) function approximators
- states (\(s\)), actions (\(a\)), rewards (\(r\)), next states (\(s'\)), dones (\(d\))
- values (\(V\)), next values (\(V'\)), advantages (\(A\)), returns (\(R\))
- log probabilities (\(logp\))
- loss (\(L\))
- reference motion dataset (\(M\)), AMP replay buffer (\(B\))
- AMP states (\(s_{_{AMP}}\)), reference motion states (\(s_{_{AMP}}^{^M}\)), AMP states from replay buffer (\(s_{_{AMP}}^{^B}\))

Learning algorithm


compute_gae(...)
def \(\;f_{GAE} (r, d, V, V') \;\rightarrow\; R, A:\)
\(adv \leftarrow 0\)
\(A \leftarrow \text{zeros}(r)\)
# advantages computation
FOR each reverse iteration \(i\) up to the number of rows in \(r\) DO
\(adv \leftarrow r_i - V_i \, +\) discount_factor \((V'_i \, +\) lambda \(\neg d_i \; adv)\)
\(A_i \leftarrow adv\)
# returns computation
\(R \leftarrow A + V\)
# normalize advantages
\(A \leftarrow \dfrac{A - \bar{A}}{A_\sigma + 10^{-8}}\)
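The GAE routine above can be sketched in PyTorch as follows. This is a minimal illustration, not skrl's source: all tensors are assumed to be shaped (rollouts, num_envs, 1), `dones` is assumed to be a boolean tensor, and, because the formula does not mask \(V'_i\) with \(\neg d_i\), `next_values` is assumed to already be zero for terminated transitions.

```python
import torch


def compute_gae(rewards: torch.Tensor,
                dones: torch.Tensor,
                values: torch.Tensor,
                next_values: torch.Tensor,
                discount_factor: float = 0.99,
                lambda_coefficient: float = 0.95):
    """Return (returns, normalized advantages) following f_GAE above.

    Assumes all tensors are shaped (rollouts, num_envs, 1), `dones` is boolean and
    `next_values` already holds zero for terminated transitions.
    """
    advantage = torch.zeros_like(rewards[0])
    advantages = torch.zeros_like(rewards)
    not_dones = dones.logical_not()
    # iterate backwards over the rollout dimension
    for i in reversed(range(rewards.shape[0])):
        advantage = rewards[i] - values[i] + discount_factor * (
            next_values[i] + lambda_coefficient * not_dones[i] * advantage)
        advantages[i] = advantage
    # returns use the unnormalized advantages
    returns = advantages + values
    # normalize advantages (1e-8 avoids division by zero)
    advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
    return returns, advantages
```

Note that the returns are built from the unnormalized advantages; only the advantages later used in the policy loss are normalized.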

_update(...)
# update dataset of reference motions
collect reference motions of size amp_batch_size \(\rightarrow\;\) \(\text{append}(M)\)
# compute combined rewards
\(r_D \leftarrow -log(\text{max}( 1 - \hat{y}(D_\psi(s_{_{AMP}})), \, 10^{-4})) \qquad\) with \(\; \hat{y}(x) = \dfrac{1}{1 + e^{-x}}\)
\(r' \leftarrow\) task_reward_weight \(r \, +\) style_reward_weight discriminator_reward_scale \(r_D\)
# compute returns and advantages
\(R, A \leftarrow f_{GAE}(r', d, V, V')\)
# sample mini-batches from memory
[[\(s, a, logp, V, R, A, s_{_{AMP}}\)]] \(\leftarrow\) states, actions, log_prob, values, returns, advantages, AMP states
[[\(s_{_{AMP}}^{^M}\)]] \(\leftarrow\) AMP states from \(M\)
IF \(B\) is not empty THEN
[[\(s_{_{AMP}}^{^B}\)]] \(\leftarrow\) AMP states from \(B\)
ELSE
[[\(s_{_{AMP}}^{^B}\)]] \(\leftarrow\) [[\(s_{_{AMP}}\)]]
# learning epochs
FOR each learning epoch up to learning_epochs DO
# mini-batches loop
FOR each mini-batch [\(s, a, logp, V, R, A, s_{_{AMP}}, s_{_{AMP}}^{^B}, s_{_{AMP}}^{^M}\)] up to mini_batches DO
\(logp' \leftarrow \pi_\theta(s, a)\)
# compute entropy loss
IF entropy computation is enabled THEN
\({L}_{entropy} \leftarrow \, -\) entropy_loss_scale \(\frac{1}{N} \sum_{i=1}^N \pi_{\theta_{entropy}}\)
ELSE
\({L}_{entropy} \leftarrow 0\)
# compute policy loss
\(ratio \leftarrow e^{logp' - logp}\)
\(L_{_{surrogate}} \leftarrow A \; ratio\)
\(L_{_{clipped\,surrogate}} \leftarrow A \; \text{clip}(ratio, 1 - c, 1 + c) \qquad\) with \(c\) as ratio_clip
\(L^{clip}_{\pi_\theta} \leftarrow - \frac{1}{N} \sum_{i=1}^N \min(L_{_{surrogate}}, L_{_{clipped\,surrogate}})\)
# compute value loss
\(V_{_{predicted}} \leftarrow V_\phi(s)\)
IF clip_predicted_values is enabled THEN
\(V_{_{predicted}} \leftarrow V + \text{clip}(V_{_{predicted}} - V, -c, c) \qquad\) with \(c\) as value_clip
\(L_{V_\phi} \leftarrow\) value_loss_scale \(\frac{1}{N} \sum_{i=1}^N (R - V_{_{predicted}})^2\)
# compute discriminator loss
\({logit}_{_{AMP}} \leftarrow D_\psi(s_{_{AMP}}) \qquad\) with \(s_{_{AMP}}\) of size discriminator_batch_size
\({logit}_{_{AMP}}^{^B} \leftarrow D_\psi(s_{_{AMP}}^{^B}) \qquad\) with \(s_{_{AMP}}^{^B}\) of size discriminator_batch_size
\({logit}_{_{AMP}}^{^M} \leftarrow D_\psi(s_{_{AMP}}^{^M}) \qquad\) with \(s_{_{AMP}}^{^M}\) of size discriminator_batch_size
# discriminator prediction loss
\(L_{D_\psi} \leftarrow \dfrac{1}{2}(BCE({logit}_{_{AMP}}\) ++ \({logit}_{_{AMP}}^{^B}, \, 0) + BCE({logit}_{_{AMP}}^{^M}, \, 1))\)
with \(\; BCE(x,y)=-\frac{1}{N} \sum_{i=1}^N [y \; log(\hat{y}) + (1-y) \, log(1-\hat{y})] \;\) and \(\; \hat{y} = \dfrac{1}{1 + e^{-x}}\)
# discriminator logit regularization
\(L_{D_\psi} \leftarrow L_{D_\psi} +\) discriminator_logit_regularization_scale \(\sum_{i=1}^N \text{flatten}(\psi_w[-1])^2\)
# discriminator gradient penalty
\(L_{D_\psi} \leftarrow L_{D_\psi} +\) discriminator_gradient_penalty_scale \(\frac{1}{N} \sum_{i=1}^N \sum (\nabla_{s_{_{AMP}}^{^M}} {logit}_{_{AMP}}^{^M})^2\)
# discriminator weight decay
\(L_{D_\psi} \leftarrow L_{D_\psi} +\) discriminator_weight_decay_scale \(\sum_{i=1}^N \text{flatten}(\psi_w)^2\)
# optimization step
reset \(\text{optimizer}_{\theta, \phi, \psi}\)
\(\nabla_{\theta, \, \phi, \, \psi} (L^{clip}_{\pi_\theta} + {L}_{entropy} + L_{V_\phi} +\) discriminator_loss_scale \(L_{D_\psi})\)
\(\text{clip}(\lVert \nabla_{\theta, \, \phi, \, \psi} \rVert)\) with grad_norm_clip
step \(\text{optimizer}_{\theta, \phi, \psi}\)
# update learning rate
IF there is a learning_rate_scheduler THEN
step \(\text{scheduler}_{\theta, \phi, \psi} (\text{optimizer}_{\theta, \phi, \psi})\)
# update AMP replay buffer
\(s_{_{AMP}} \rightarrow\;\) \(\text{append}(B)\)
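The style-reward and discriminator-loss steps of `_update` can be sketched in PyTorch as below. The function names are hypothetical (not part of skrl's API), the discriminator is assumed to be a plain `nn.Sequential` of `nn.Linear` layers so that \(\psi_w[-1]\) is the weight of the last linear layer, and the gradient penalty is taken with respect to the reference-motion inputs \(s_{_{AMP}}^{^M}\). The default scales mirror the configuration values documented on this page.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def style_reward(logits: torch.Tensor, reward_scale: float = 2.0) -> torch.Tensor:
    """r_D = -log(max(1 - sigmoid(D(s_AMP)), 1e-4)), multiplied by discriminator_reward_scale."""
    return -torch.log(torch.clamp(1.0 - torch.sigmoid(logits), min=1e-4)) * reward_scale


def discriminator_loss(discriminator: nn.Sequential,
                       amp_states: torch.Tensor,
                       replay_states: torch.Tensor,
                       motion_states: torch.Tensor,
                       logit_reg_scale: float = 0.05,
                       grad_penalty_scale: float = 5.0,
                       weight_decay_scale: float = 1e-4) -> torch.Tensor:
    # policy and replay-buffer samples are labeled 0, reference motions are labeled 1
    agent_logits = torch.cat([discriminator(amp_states), discriminator(replay_states)], dim=0)
    motion_states = motion_states.detach().clone().requires_grad_(True)
    motion_logits = discriminator(motion_states)
    loss = 0.5 * (F.binary_cross_entropy_with_logits(agent_logits, torch.zeros_like(agent_logits))
                  + F.binary_cross_entropy_with_logits(motion_logits, torch.ones_like(motion_logits)))

    linears = [m for m in discriminator.modules() if isinstance(m, nn.Linear)]
    # logit regularization: squared weights of the last (output) linear layer
    loss = loss + logit_reg_scale * torch.sum(linears[-1].weight ** 2)

    # gradient penalty: squared gradient of the motion logits w.r.t. the motion inputs
    grad = torch.autograd.grad(motion_logits, motion_states,
                               grad_outputs=torch.ones_like(motion_logits),
                               create_graph=True)[0]
    loss = loss + grad_penalty_scale * grad.pow(2).sum(dim=-1).mean()

    # weight decay over the weights of all linear layers
    loss = loss + weight_decay_scale * sum(torch.sum(m.weight ** 2) for m in linears)
    return loss
```

With the default weights (task_reward_weight 0.0, style_reward_weight 1.0), the combined reward \(r'\) reduces to the scaled style reward returned by `style_reward`.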

Usage

# import the agent and its default configuration
from skrl.agents.torch.amp import AMP, AMP_DEFAULT_CONFIG

# instantiate the agent's models
models = {}
models["policy"] = ...
models["value"] = ...  # only required during training
models["discriminator"] = ...  # only required during training

# adjust some configuration if necessary
cfg_agent = AMP_DEFAULT_CONFIG.copy()
cfg_agent["<KEY>"] = ...

# instantiate the agent
# (assuming a defined environment <env> and memory <memory>)
# (assuming defined memories <motion_dataset> for the reference motions and <reply_buffer> for the replay buffer)
# (assuming defined callables <collect_reference_motions> and <collect_observation> to collect reference motions and observations)
agent = AMP(models=models,
            memory=memory,  # only required during training
            cfg=cfg_agent,
            observation_space=env.observation_space,
            action_space=env.action_space,
            device=env.device,
            amp_observation_space=env.amp_observation_space,
            motion_dataset=motion_dataset,
            reply_buffer=reply_buffer,
            collect_reference_motions=collect_reference_motions,
            collect_observation=collect_observation)
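The constructor fixes only the signatures of the two callables: `collect_reference_motions` receives a number of samples and returns that many AMP observations from the reference motions, and `collect_observation` takes no arguments and presumably returns the current environment observations. A hypothetical stand-in, e.g. for smoke-testing the wiring without a simulator, might look like this; `ToyMotionLibrary`, `ObservationCache` and the tensor sizes are illustrative assumptions, and a real setup would read from the environment or a motion library instead.

```python
import torch


class ToyMotionLibrary:
    """Hypothetical stand-in holding pre-computed AMP observations of reference clips."""

    def __init__(self, amp_observations: torch.Tensor):
        self.amp_observations = amp_observations  # shape: (num_frames, amp_obs_size)

    def sample(self, num_samples: int) -> torch.Tensor:
        # draw frames uniformly at random (with replacement)
        idx = torch.randint(0, self.amp_observations.shape[0], (num_samples,))
        return self.amp_observations[idx]


class ObservationCache:
    """Hypothetical holder for the latest environment observations."""

    def __init__(self, observations: torch.Tensor):
        self.observations = observations  # shape: (num_envs, obs_size)

    def get(self) -> torch.Tensor:
        return self.observations


motion_library = ToyMotionLibrary(torch.randn(1000, 210))  # placeholder data
observation_cache = ObservationCache(torch.zeros(4, 60))   # placeholder data

# bound methods match Callable[[int], torch.Tensor] and Callable[[], torch.Tensor]
collect_reference_motions = motion_library.sample
collect_observation = observation_cache.get
```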

Configuration and hyperparameters

AMP_DEFAULT_CONFIG = {
    "rollouts": 16,                 # number of rollouts before updating
    "learning_epochs": 6,           # number of learning epochs during each update
    "mini_batches": 2,              # number of mini batches during each learning epoch

    "discount_factor": 0.99,        # discount factor (gamma)
    "lambda": 0.95,                 # TD(lambda) coefficient (lam) for computing returns and advantages

    "learning_rate": 5e-5,                  # learning rate
    "learning_rate_scheduler": None,        # learning rate scheduler class (see torch.optim.lr_scheduler)
    "learning_rate_scheduler_kwargs": {},   # learning rate scheduler's kwargs (e.g. {"step_size": 1e-3})

    "state_preprocessor": None,             # state preprocessor class (see skrl.resources.preprocessors)
    "state_preprocessor_kwargs": {},        # state preprocessor's kwargs (e.g. {"size": env.observation_space})
    "value_preprocessor": None,             # value preprocessor class (see skrl.resources.preprocessors)
    "value_preprocessor_kwargs": {},        # value preprocessor's kwargs (e.g. {"size": 1})
    "amp_state_preprocessor": None,         # AMP state preprocessor class (see skrl.resources.preprocessors)
    "amp_state_preprocessor_kwargs": {},    # AMP state preprocessor's kwargs (e.g. {"size": env.amp_observation_space})

    "random_timesteps": 0,          # random exploration steps
    "learning_starts": 0,           # learning starts after this many steps

    "grad_norm_clip": 0.0,              # clipping coefficient for the norm of the gradients
    "ratio_clip": 0.2,                  # clipping coefficient for computing the clipped surrogate objective
    "value_clip": 0.2,                  # clipping coefficient for computing the value loss (if clip_predicted_values is True)
    "clip_predicted_values": False,     # clip predicted values during value loss computation

    "entropy_loss_scale": 0.0,          # entropy loss scaling factor
    "value_loss_scale": 2.5,            # value loss scaling factor
    "discriminator_loss_scale": 5.0,    # discriminator loss scaling factor

    "amp_batch_size": 512,                  # batch size for updating the reference motion dataset
    "task_reward_weight": 0.0,              # task-reward weight (wG)
    "style_reward_weight": 1.0,             # style-reward weight (wS)
    "discriminator_batch_size": 0,          # batch size for computing the discriminator loss (all samples if 0)
    "discriminator_reward_scale": 2,                    # discriminator reward scaling factor
    "discriminator_logit_regularization_scale": 0.05,   # logit regularization scale factor for the discriminator loss
    "discriminator_gradient_penalty_scale": 5,          # gradient penalty scaling factor for the discriminator loss
    "discriminator_weight_decay_scale": 0.0001,         # weight decay scaling factor for the discriminator loss

    "rewards_shaper": None,         # rewards shaping function: Callable(reward, timestep, timesteps) -> reward
    "time_limit_bootstrap": False,  # bootstrap at timeout termination (episode truncation)

    "experiment": {
        "directory": "",            # experiment's parent directory
        "experiment_name": "",      # experiment name
        "write_interval": 250,      # TensorBoard writing interval (timesteps)

        "checkpoint_interval": 1000,        # interval for checkpoints (timesteps)
        "store_separately": False,          # whether to store checkpoints separately

        "wandb": False,             # whether to use Weights & Biases
        "wandb_kwargs": {}          # wandb kwargs (see https://docs.wandb.ai/ref/python/init)
    }
}
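The usage example copies the defaults with `AMP_DEFAULT_CONFIG.copy()`. In Python this is a shallow copy, so writing into a nested dictionary such as `"experiment"` also modifies the module-level defaults; `copy.deepcopy` avoids this. The sketch below demonstrates the behavior on a small hypothetical stand-in dictionary with the same nesting (not the real skrl object).

```python
import copy


def make_config() -> dict:
    # hypothetical stand-in with the same nesting as AMP_DEFAULT_CONFIG (subset of keys)
    return {
        "rollouts": 16,
        "task_reward_weight": 0.0,
        "style_reward_weight": 1.0,
        "experiment": {"directory": "", "write_interval": 250},
    }


# shallow copy: top-level values are independent ...
defaults = make_config()
shallow = defaults.copy()
shallow["rollouts"] = 32
# ... but the nested "experiment" dict is the same object as in the defaults
shallow["experiment"]["directory"] = "runs/amp"

# deep copy: nested dicts are copied too, so the defaults stay untouched
defaults_2 = make_config()
deep = copy.deepcopy(defaults_2)
deep["experiment"]["directory"] = "runs/amp"
```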

Spaces

The implementation supports the following Gym spaces / Gymnasium spaces:

Gym/Gymnasium spaces | AMP observation  | Observation      | Action
Discrete             | \(\square\)      | \(\square\)      | \(\square\)
MultiDiscrete        | \(\square\)      | \(\square\)      | \(\square\)
Box                  | \(\blacksquare\) | \(\blacksquare\) | \(\blacksquare\)
Dict                 | \(\square\)      | \(\square\)      | \(\square\)


Models

The implementation uses 1 stochastic (continuous) and 2 deterministic function approximators. These function approximators (models) must be collected in a dictionary and passed to the constructor of the class under the argument models.

Notation               | Concept       | Key             | Input shape     | Output shape | Type
\(\pi_\theta(s)\)      | Policy        | "policy"        | observation     | action       | Gaussian / MultivariateGaussian
\(V_\phi(s)\)          | Value         | "value"         | observation     | 1            | Deterministic
\(D_\psi(s_{_{AMP}})\) | Discriminator | "discriminator" | AMP observation | 1            | Deterministic
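The input and output shapes in the table can be illustrated with plain PyTorch modules. This is a sketch of the network bodies only, with hypothetical observation, AMP-observation and action sizes and an assumed state-independent log standard deviation; skrl expects its own model instances (the table's Gaussian / MultivariateGaussian and Deterministic types), which this sketch does not reproduce.

```python
import torch
import torch.nn as nn

# hypothetical sizes, for illustration only
OBS_SIZE, AMP_OBS_SIZE, ACTION_SIZE = 105, 210, 28


class GaussianPolicy(nn.Module):
    """pi_theta(s): observation -> Normal distribution over actions."""

    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(OBS_SIZE, 256), nn.ReLU(), nn.Linear(256, ACTION_SIZE))
        # state-independent log standard deviation (an assumption of this sketch)
        self.log_std = nn.Parameter(torch.zeros(ACTION_SIZE))

    def forward(self, states: torch.Tensor) -> torch.distributions.Normal:
        return torch.distributions.Normal(self.net(states), self.log_std.exp())


# V_phi(s): observation -> 1
value = nn.Sequential(nn.Linear(OBS_SIZE, 256), nn.ReLU(), nn.Linear(256, 1))

# D_psi(s_AMP): AMP observation -> 1 (raw logit, no sigmoid)
discriminator = nn.Sequential(nn.Linear(AMP_OBS_SIZE, 256), nn.ReLU(), nn.Linear(256, 1))
```

The discriminator outputs a raw logit because both the style reward and the BCE loss in the learning algorithm apply the sigmoid \(\hat{y}\) themselves.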


Features

Support for advanced features is described in the next table.

Feature      | Support and remarks | pytorch     | jax
Shared model | -                   | \(\square\) | \(\square\)
RNN support  | -                   | \(\square\) | \(\square\)


API (PyTorch)

skrl.agents.torch.amp.AMP_DEFAULT_CONFIG

alias of {'amp_batch_size': 512, 'amp_state_preprocessor': None, 'amp_state_preprocessor_kwargs': {}, 'clip_predicted_values': False, 'discount_factor': 0.99, 'discriminator_batch_size': 0, 'discriminator_gradient_penalty_scale': 5, 'discriminator_logit_regularization_scale': 0.05, 'discriminator_loss_scale': 5.0, 'discriminator_reward_scale': 2, 'discriminator_weight_decay_scale': 0.0001, 'entropy_loss_scale': 0.0, 'experiment': {'checkpoint_interval': 1000, 'directory': '', 'experiment_name': '', 'store_separately': False, 'wandb': False, 'wandb_kwargs': {}, 'write_interval': 250}, 'grad_norm_clip': 0.0, 'lambda': 0.95, 'learning_epochs': 6, 'learning_rate': 5e-05, 'learning_rate_scheduler': None, 'learning_rate_scheduler_kwargs': {}, 'learning_starts': 0, 'mini_batches': 2, 'random_timesteps': 0, 'ratio_clip': 0.2, 'rewards_shaper': None, 'rollouts': 16, 'state_preprocessor': None, 'state_preprocessor_kwargs': {}, 'style_reward_weight': 1.0, 'task_reward_weight': 0.0, 'time_limit_bootstrap': False, 'value_clip': 0.2, 'value_loss_scale': 2.5, 'value_preprocessor': None, 'value_preprocessor_kwargs': {}}

class skrl.agents.torch.amp.AMP(models: Mapping[str, Model], memory: Memory | Tuple[Memory] | None = None, observation_space: int | Tuple[int] | gym.Space | gymnasium.Space | None = None, action_space: int | Tuple[int] | gym.Space | gymnasium.Space | None = None, device: str | torch.device | None = None, cfg: dict | None = None, amp_observation_space: int | Tuple[int] | gym.Space | gymnasium.Space | None = None, motion_dataset: Memory | None = None, reply_buffer: Memory | None = None, collect_reference_motions: Callable[[int], torch.Tensor] | None = None, collect_observation: Callable[[], torch.Tensor] | None = None)

Bases: Agent

__init__(models: Mapping[str, Model], memory: Memory | Tuple[Memory] | None = None, observation_space: int | Tuple[int] | gym.Space | gymnasium.Space | None = None, action_space: int | Tuple[int] | gym.Space | gymnasium.Space | None = None, device: str | torch.device | None = None, cfg: dict | None = None, amp_observation_space: int | Tuple[int] | gym.Space | gymnasium.Space | None = None, motion_dataset: Memory | None = None, reply_buffer: Memory | None = None, collect_reference_motions: Callable[[int], torch.Tensor] | None = None, collect_observation: Callable[[], torch.Tensor] | None = None) → None

Adversarial Motion Priors (AMP)

https://arxiv.org/abs/2104.02180

The implementation is adapted from the NVIDIA IsaacGymEnvs (https://github.com/NVIDIA-Omniverse/IsaacGymEnvs/blob/main/isaacgymenvs/learning/amp_continuous.py).

Parameters:
  • models (dictionary of skrl.models.torch.Model) – Models used by the agent

  • memory (skrl.memory.torch.Memory, list of skrl.memory.torch.Memory or None) – Memory to store the transitions. If it is a tuple, the first element is used for training, while the remaining memories only receive the environment transitions

  • observation_space (int, tuple or list of int, gym.Space, gymnasium.Space or None, optional) – Observation/state space or shape (default: None)

  • action_space (int, tuple or list of int, gym.Space, gymnasium.Space or None, optional) – Action space or shape (default: None)

  • device (str or torch.device, optional) – Device on which a tensor/array is or will be allocated (default: None). If None, the device will be either "cuda" if available or "cpu"

  • cfg (dict) – Configuration dictionary

  • amp_observation_space (int, tuple or list of int, gym.Space, gymnasium.Space or None) – AMP observation/state space or shape (default: None)

  • motion_dataset (skrl.memory.torch.Memory or None) – Reference motion dataset: M (default: None)

  • reply_buffer (skrl.memory.torch.Memory or None) – Replay buffer for preventing discriminator overfitting: B (default: None)

  • collect_reference_motions (Callable[[int], torch.Tensor] or None) – Callable to collect reference motions (default: None)

  • collect_observation (Callable[[], torch.Tensor] or None) – Callable to collect observation (default: None)

Raises:

KeyError – If the models dictionary is missing a required key

_update(timestep: int, timesteps: int) → None

Algorithm’s main update step

Parameters:
  • timestep (int) – Current timestep

  • timesteps (int) – Number of timesteps

act(states: torch.Tensor, timestep: int, timesteps: int) → torch.Tensor

Process the environment’s states to make a decision (actions) using the main policy

Parameters:
  • states (torch.Tensor) – Environment’s states

  • timestep (int) – Current timestep

  • timesteps (int) – Number of timesteps

Returns:

Actions

Return type:

torch.Tensor

init(trainer_cfg: Mapping[str, Any] | None = None) → None

Initialize the agent

post_interaction(timestep: int, timesteps: int) → None

Callback called after the interaction with the environment

Parameters:
  • timestep (int) – Current timestep

  • timesteps (int) – Number of timesteps

pre_interaction(timestep: int, timesteps: int) → None

Callback called before the interaction with the environment

Parameters:
  • timestep (int) – Current timestep

  • timesteps (int) – Number of timesteps

record_transition(states: torch.Tensor, actions: torch.Tensor, rewards: torch.Tensor, next_states: torch.Tensor, terminated: torch.Tensor, truncated: torch.Tensor, infos: Any, timestep: int, timesteps: int) → None

Record an environment transition in memory

Parameters:
  • states (torch.Tensor) – Observations/states of the environment used to make the decision

  • actions (torch.Tensor) – Actions taken by the agent

  • rewards (torch.Tensor) – Instant rewards achieved by the current actions

  • next_states (torch.Tensor) – Next observations/states of the environment

  • terminated (torch.Tensor) – Signals to indicate that episodes have terminated

  • truncated (torch.Tensor) – Signals to indicate that episodes have been truncated

  • infos (Any type supported by the environment) – Additional information about the environment

  • timestep (int) – Current timestep

  • timesteps (int) – Number of timesteps