Adversarial Motion Priors (AMP)

AMP is a model-free, stochastic on-policy policy gradient algorithm (trained using a combination of GAIL and PPO) for adversarial learning of physics-based character animation. It enables characters to imitate diverse behaviors from large unstructured datasets, without the need for motion planners or other mechanisms for clip selection.

Paper: AMP: Adversarial Motion Priors for Stylized Physics-Based Character Control.



Algorithm


Algorithm implementation

Main notation/symbols:
- policy (\(\pi_\theta\)), value (\(V_\phi\)) and discriminator (\(D_\psi\)) function approximators
- states (\(s\)), actions (\(a\)), rewards (\(r\)), next states (\(s'\)), terminated (\(d_{_{end}}\)), truncated (\(d_{_{timeout}}\))
- values (\(V\)), next values (\(V'\)), advantages (\(A\)), returns (\(R\))
- log probabilities (\(logp\))
- loss (\(L\))
- reference motion dataset (\(M\)), AMP replay buffer (\(B\))
- AMP states (\(s_{_{AMP}}\)), reference motion states (\(s_{_{AMP}}^{^M}\)), AMP states from replay buffer (\(s_{_{AMP}}^{^B}\))

Learning algorithm


compute_gae(...)
def \(\;f_{GAE} (r, d, V, V') \;\rightarrow\; R, A:\)
    \(adv \leftarrow 0\)
    \(A \leftarrow \text{zeros}(r)\)
    # advantages computation
    FOR each reverse iteration \(i\) up to the number of rows in \(r\) DO
        \(adv \leftarrow r_i - V_i \, +\) discount_factor \((V'_i \, +\) gae_lambda \(\neg d_i \; adv)\)
        \(A_i \leftarrow adv\)
    # returns computation
    \(R \leftarrow A + V\)
    # normalize advantages
    \(A \leftarrow \dfrac{A - \bar{A}}{A_\sigma + 10^{-8}}\)
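A minimal pure-Python sketch of \(f_{GAE}\) for a single environment follows. It assumes one done flag and one next value \(V'_i\) per step and, exactly as in the pseudocode, it does not mask \(V'_i\) with \(d_i\) (so \(V'_i\) is assumed to already be zero for terminal states). The function name and argument layout are illustrative, not skrl's API.

```python
from statistics import mean, stdev


def compute_gae(rewards, dones, values, next_values,
                discount_factor=0.99, gae_lambda=0.95):
    """Return (returns, normalized advantages) for one rollout of length >= 2."""
    advantage = 0.0
    advantages = [0.0] * len(rewards)
    # iterate over the rollout from the last step to the first
    for i in reversed(range(len(rewards))):
        not_done = 0.0 if dones[i] else 1.0
        # the done flag only cuts the recursion on the accumulated advantage
        advantage = (rewards[i] - values[i]
                     + discount_factor * (next_values[i] + gae_lambda * not_done * advantage))
        advantages[i] = advantage
    # returns are computed from the un-normalized advantages
    returns = [a + v for a, v in zip(advantages, values)]
    # normalize advantages (sample standard deviation, epsilon for stability)
    mu, sigma = mean(advantages), stdev(advantages)
    advantages = [(a - mu) / (sigma + 1e-8) for a in advantages]
    return returns, advantages
```

For example, `compute_gae([1.0, 1.0], [False, True], [0.0, 0.0], [0.0, 0.0], discount_factor=0.5, gae_lambda=1.0)` yields the returns `[1.5, 1.0]`.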

_update(...)
# update dataset of reference motions
collect reference motions of size amp_batch_size \(\rightarrow\;\) \(\text{append}(M)\)
# compute combined rewards
\(r_D \leftarrow -log(\text{max}( 1 - \hat{y}(D_\psi(s_{_{AMP}})), \, 10^{-4})) \qquad\) with \(\; \hat{y}(x) = \dfrac{1}{1 + e^{-x}}\)
\(r' \leftarrow\) task_reward_scale \(r \, +\) style_reward_scale \(r_D\)
# compute returns and advantages
\(R, A \leftarrow f_{GAE}(r', d_{_{end}} \lor d_{_{timeout}}, V, V')\)
# sample mini-batches from memory
[[\(s, a, logp, V, R, A, s_{_{AMP}}\)]] \(\leftarrow\) states, actions, log_prob, values, returns, advantages, AMP states
[[\(s_{_{AMP}}^{^M}\)]] \(\leftarrow\) AMP states from \(M\)
IF \(B\) is not empty THEN
    [[\(s_{_{AMP}}^{^B}\)]] \(\leftarrow\) AMP states from \(B\)
ELSE
    [[\(s_{_{AMP}}^{^B}\)]] \(\leftarrow\) [[\(s_{_{AMP}}\)]]
# learning epochs
FOR each learning epoch up to learning_epochs DO
    # mini-batches loop
    FOR each mini-batch [\(s, a, logp, V, R, A, s_{_{AMP}}, s_{_{AMP}}^{^B}, s_{_{AMP}}^{^M}\)] up to mini_batches DO
        \(logp' \leftarrow \pi_\theta(s, a)\)
        # compute entropy loss
        IF entropy computation is enabled THEN
            \({L}_{entropy} \leftarrow \, -\) entropy_loss_scale \(\frac{1}{N} \sum_{i=1}^N \pi_{\theta_{entropy}}\)
        ELSE
            \({L}_{entropy} \leftarrow 0\)
        # compute policy loss
        \(ratio \leftarrow e^{logp' - logp}\)
        \(L_{_{surrogate}} \leftarrow A \; ratio\)
        \(L_{_{clipped\,surrogate}} \leftarrow A \; \text{clip}(ratio, 1 - c, 1 + c) \qquad\) with \(c\) as ratio_clip
        \(L^{clip}_{\pi_\theta} \leftarrow - \frac{1}{N} \sum_{i=1}^N \min(L_{_{surrogate}}, L_{_{clipped\,surrogate}})\)
        # compute value loss
        \(V_{_{predicted}} \leftarrow V_\phi(s)\)
        IF value_clip > 0 THEN
            \(V_{_{predicted}} \leftarrow V + \text{clip}(V_{_{predicted}} - V, -c, c) \qquad\) with \(c\) as value_clip
        \(L_{V_\phi} \leftarrow\) value_loss_scale \(\frac{1}{N} \sum_{i=1}^N (R - V_{_{predicted}})^2\)
        # compute discriminator loss
        \({logit}_{_{AMP}} \leftarrow D_\psi(s_{_{AMP}}) \qquad\) with \(s_{_{AMP}}\) of size discriminator_batch_size
        \({logit}_{_{AMP}}^{^B} \leftarrow D_\psi(s_{_{AMP}}^{^B}) \qquad\) with \(s_{_{AMP}}^{^B}\) of size discriminator_batch_size
        \({logit}_{_{AMP}}^{^M} \leftarrow D_\psi(s_{_{AMP}}^{^M}) \qquad\) with \(s_{_{AMP}}^{^M}\) of size discriminator_batch_size
        # discriminator prediction loss
        \(L_{D_\psi} \leftarrow \dfrac{1}{2}(BCE({logit}_{_{AMP}}\) ++ \({logit}_{_{AMP}}^{^B}, \, 0) + BCE({logit}_{_{AMP}}^{^M}, \, 1))\)
            with \(\; BCE(x,y)=-\frac{1}{N} \sum_{i=1}^N [y \; log(\hat{y}) + (1-y) \, log(1-\hat{y})] \;\) and \(\; \hat{y} = \dfrac{1}{1 + e^{-x}}\)
        # discriminator logit regularization
        \(L_{D_\psi} \leftarrow L_{D_\psi} +\) discriminator_logit_regularization_scale \(\sum \text{flatten}(\psi_w[-1])^2\)
        # discriminator gradient penalty
        \(L_{D_\psi} \leftarrow L_{D_\psi} +\) discriminator_gradient_penalty_scale \(\frac{1}{N} \sum_{i=1}^N \sum (\nabla_{s_{_{AMP}}^{^M}} {logit}_{_{AMP}}^{^M})^2\)
        # discriminator weight decay
        \(L_{D_\psi} \leftarrow L_{D_\psi} +\) discriminator_weight_decay_scale \(\sum \text{flatten}(\psi_w)^2\)
        # discriminator loss scaling
        \(L_{D_\psi} \leftarrow\) discriminator_loss_scale \(L_{D_\psi}\)
        # optimization step
        reset \(\text{optimizer}_{\theta, \phi, \psi}\)
        \(\nabla_{\theta, \, \phi, \, \psi} (L^{clip}_{\pi_\theta} + {L}_{entropy} + L_{V_\phi} + L_{D_\psi})\)
        \(\text{clip}(\lVert \nabla_{\theta, \, \phi, \, \psi} \rVert)\) with grad_norm_clip
        step \(\text{optimizer}_{\theta, \phi, \psi}\)
    # update learning rate
    IF there is a learning_rate_scheduler THEN
        step \(\text{scheduler}_{\theta, \phi, \psi} (\text{optimizer}_{\theta, \phi, \psi})\)
# update AMP replay buffer
\(s_{_{AMP}} \rightarrow\;\) \(\text{append}(B)\)
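The discriminator prediction loss \(L_{D_\psi}\) (before the regularization terms) can be sketched in plain Python as below. Logits for the agent's AMP states and the replay-buffer states are concatenated and labeled 0, reference-motion logits are labeled 1, and the binary cross-entropy is evaluated directly on logits in its numerically stable form. The function names are illustrative, not part of skrl.

```python
import math


def bce_with_logits(logits, target):
    """Mean binary cross-entropy of sigmoid(logits) against a constant target (0 or 1)."""
    total = 0.0
    for x in logits:
        # stable form of -[y log(sigmoid(x)) + (1 - y) log(1 - sigmoid(x))]
        total += max(x, 0.0) - x * target + math.log1p(math.exp(-abs(x)))
    return total / len(logits)


def discriminator_prediction_loss(policy_logits, replay_logits, motion_logits):
    # "++" in the pseudocode: concatenate agent and replay-buffer logits (label 0)
    fake = list(policy_logits) + list(replay_logits)
    # reference-motion logits are labeled 1
    return 0.5 * (bce_with_logits(fake, 0.0) + bce_with_logits(motion_logits, 1.0))
```

An undecided discriminator (all logits 0) gives a loss of \(\log 2\), while a confident and correct one drives the loss toward 0.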

Usage

# import the agent and its default configuration
from skrl.agents.torch.amp import AMP, AMP_CFG

# instantiate the agent's models
models = {}
models["policy"] = ...
models["value"] = ...  # only required during training
models["discriminator"] = ...  # only required during training

# adjust some configuration if necessary
cfg_agent = AMP_CFG()
cfg_agent.KEY = ...

# instantiate the agent
# (assuming a defined environment <env> and memory <memory>)
# (assuming defined memories for the reference motion dataset <motion_dataset> and the replay buffer <reply_buffer>)
# (assuming defined method to collect reference motions <collect_reference_motions>)
agent = AMP(
    models=models,
    memory=memory,  # only required during training
    cfg=cfg_agent,
    observation_space=env.observation_space,
    state_space=env.state_space,
    action_space=env.action_space,
    device=env.device,
    amp_observation_space=env.amp_observation_space,
    motion_dataset=motion_dataset,
    reply_buffer=reply_buffer,
    collect_reference_motions=collect_reference_motions,
)

Configuration and hyperparameters

Dataclass

AMP_CFG


Spaces

The implementation supports the following Gymnasium spaces:

Gymnasium spaces    AMP observation    Observation    Action
Discrete            □                  □              □
MultiDiscrete       □                  □              □
Box                 ■                  ■              ■
Dict                □                  □              □

(■ supported, □ not supported)


Models

The implementation uses one stochastic (continuous) and two deterministic function approximators. These function approximators (models) must be collected in a dictionary and passed to the constructor of the class under the models argument.

Notation                  Concept         Key               Input shape       Output shape   Type
\(\pi_\theta(s)\)         Policy          "policy"          observation       action         Gaussian / MultivariateGaussian
\(V_\phi(s)\)             Value           "value"           observation       1              Deterministic
\(D_\psi(s_{_{AMP}})\)    Discriminator   "discriminator"   AMP observation   1              Deterministic


Features

Support for advanced features is described in the following table:

Feature            Support and remarks                           pytorch    jax    warp
Shared model       -                                             □          □      □
RNN support        -                                             □          □      □
Mixed precision    Automatic mixed precision                     ■          □      □
Distributed        Single Program Multi Data (SPMD) multi-GPU    ■          □      □


API


PyTorch

AMP_CFG

Configuration for the AMP agent.

AMP

Adversarial Motion Priors (AMP).

class skrl.agents.torch.amp.AMP_CFG(*, experiment: ExperimentCfg = <factory>, rollouts: int = 16, learning_epochs: int = 6, mini_batches: int = 2, discount_factor: float = 0.99, gae_lambda: float = 0.95, learning_rate: float | tuple[float, float, float] = 5e-05, learning_rate_scheduler: type | tuple[type | None, type | None, type | None] | None = None, learning_rate_scheduler_kwargs: dict | tuple[dict, dict, dict] = <factory>, observation_preprocessor: type | None = None, observation_preprocessor_kwargs: dict = <factory>, state_preprocessor: type | None = None, state_preprocessor_kwargs: dict = <factory>, value_preprocessor: type | None = None, value_preprocessor_kwargs: dict = <factory>, amp_observation_preprocessor: type | None = None, amp_observation_preprocessor_kwargs: dict = <factory>, random_timesteps: int = 0, learning_starts: int = 0, grad_norm_clip: float = 0.5, ratio_clip: float = 0.2, value_clip: float = 0.2, entropy_loss_scale: float = 0.0, value_loss_scale: float = 2.5, discriminator_loss_scale: float = 5.0, amp_batch_size: int = 512, task_reward_scale: float = 0.0, style_reward_scale: float = 2.0, discriminator_batch_size: int = -1, discriminator_logit_regularization_scale: float = 0.05, discriminator_gradient_penalty_scale: float = 5.0, discriminator_weight_decay_scale: float = 0.0001, time_limit_bootstrap: bool = False, rewards_shaper: Callable | None = None, mixed_precision: bool = False)

Bases: AgentCfg

Configuration for the AMP agent.

Methods:

expand()

Expand the configuration.

validate()

Validate the configuration.

Attributes:

amp_batch_size

Batch size for updating the reference motion dataset.

amp_observation_preprocessor

Preprocessor class to process the environment's AMP observations.

amp_observation_preprocessor_kwargs

Keyword arguments for the AMP observation preprocessor's constructor.

discount_factor

Parameter that balances the importance of future rewards (close to 1.0) versus immediate rewards (close to 0.0).

discriminator_batch_size

Batch size (for subsampling AMP observations) for computing the discriminator loss.

discriminator_gradient_penalty_scale

Gradient penalty scaling factor for the discriminator.

discriminator_logit_regularization_scale

Logit regularization scaling factor for the discriminator.

discriminator_loss_scale

Discriminator loss scaling factor.

discriminator_weight_decay_scale

Weight decay scaling factor for the discriminator.

entropy_loss_scale

Entropy loss scaling factor.

experiment

Experiment settings.

gae_lambda

TD(lambda) coefficient for computing Generalized Advantage Estimation (GAE).

grad_norm_clip

Clipping coefficient for the gradients by their global norm.

learning_epochs

Number of learning epochs to perform during updates.

learning_rate

Learning rate for the policy, value and discriminator networks.

learning_rate_scheduler

Learning rate scheduler class for the policy, value and discriminator networks.

learning_rate_scheduler_kwargs

Keyword arguments for the learning rate scheduler's constructor.

learning_starts

Number of steps to perform before calling the algorithm update function.

mini_batches

Number of mini-batches to sample when updating.

mixed_precision

Whether to enable automatic mixed precision for higher performance.

observation_preprocessor

Preprocessor class to process the environment's observations.

observation_preprocessor_kwargs

Keyword arguments for the observation preprocessor's constructor.

random_timesteps

Number of random exploration (sampling random actions) steps to perform before sampling actions from the policy.

ratio_clip

Clipping coefficient for computing the clipped surrogate objective.

rewards_shaper

Rewards shaping function.

rollouts

Number of collection steps to perform between updates.

state_preprocessor

Preprocessor class to process the environment's states.

state_preprocessor_kwargs

Keyword arguments for the state preprocessor's constructor.

style_reward_scale

Reward scaling factor for the style reward (imitation of the reference motions).

task_reward_scale

Reward scaling factor for the task.

time_limit_bootstrap

Whether to bootstrap at timeout termination (episode truncation).

value_clip

Clipping coefficient for the predicted value during value loss computation.

value_loss_scale

Value loss scaling factor.

value_preprocessor

Preprocessor class to process the value network's output.

value_preprocessor_kwargs

Keyword arguments for the value preprocessor's constructor.

expand() → None

Expand the configuration.

validate() → bool

Validate the configuration.

amp_batch_size: int = 512

Batch size for updating the reference motion dataset.

amp_observation_preprocessor: type | None = None

Preprocessor class to process the environment’s AMP observations.

See Preprocessors for more details.

amp_observation_preprocessor_kwargs: dict

Keyword arguments for the AMP observation preprocessor’s constructor.

See Preprocessors for more details.

discount_factor: float = 0.99

Parameter that balances the importance of future rewards (close to 1.0) versus immediate rewards (close to 0.0).

Range: [0.0, 1.0].

discriminator_batch_size: int = -1

Batch size (for subsampling AMP observations) for computing the discriminator loss.

If less than or equal to 0, all sampled AMP observations will be used.
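A sketch of this rule with a hypothetical helper `subsample`, assuming the subset is taken as the leading slice of an already shuffled mini-batch (the actual selection strategy is an assumption here):

```python
def subsample(amp_observations, discriminator_batch_size):
    """Hypothetical helper: pick the AMP observations used for the discriminator loss."""
    # a non-positive size means: use every sampled AMP observation
    if discriminator_batch_size <= 0:
        return amp_observations
    # otherwise keep the leading slice of the (assumed shuffled) mini-batch
    return amp_observations[:discriminator_batch_size]
```

Slicing never fails when the requested size exceeds the available samples; it simply returns all of them.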

discriminator_gradient_penalty_scale: float = 5.0

Gradient penalty scaling factor for the discriminator.
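As in the AMP paper, the gradient penalty acts on the squared norm of the discriminator's gradient with respect to its input, evaluated on reference-motion states. The sketch below assumes a hypothetical one-hidden-layer discriminator \(D(s) = v^\top \tanh(W s)\) so that the input gradient can be written analytically; a real implementation would use automatic differentiation instead.

```python
import math


def discriminator(W, v, s):
    """Hypothetical tiny discriminator: logit = v . tanh(W s)."""
    hidden = [math.tanh(sum(w_ij * s_j for w_ij, s_j in zip(row, s))) for row in W]
    return sum(v_i * h_i for v_i, h_i in zip(v, hidden))


def input_gradient(W, v, s):
    """Analytic d logit / d s = W^T (v * (1 - tanh^2(W s)))."""
    hidden = [math.tanh(sum(w_ij * s_j for w_ij, s_j in zip(row, s))) for row in W]
    upstream = [v_i * (1.0 - h_i ** 2) for v_i, h_i in zip(v, hidden)]
    return [sum(upstream[i] * W[i][j] for i in range(len(W))) for j in range(len(s))]


def gradient_penalty(W, v, motion_states, discriminator_gradient_penalty_scale=5.0):
    # batch mean of the squared L2 norm of the gradient w.r.t. each reference state
    norms = [sum(g ** 2 for g in input_gradient(W, v, s)) for s in motion_states]
    return discriminator_gradient_penalty_scale * sum(norms) / len(norms)
```

At \(s = 0\) the hidden activations vanish, so the gradient reduces to \(W^\top v\); with \(W = I\) and \(v = (2, 3)\) the penalty is \(5 \cdot 13 = 65\).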

discriminator_logit_regularization_scale: float = 0.05

Logit regularization scaling factor for the discriminator.

discriminator_loss_scale: float = 5.0

Discriminator loss scaling factor.

discriminator_weight_decay_scale: float = 0.0001

Weight decay scaling factor for the discriminator.
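The logit regularization and weight decay terms can be sketched as follows, assuming (hypothetically) that the discriminator's linear-layer weight matrices are available as nested lists with the output (logit) layer last. Logit regularization uses only the output layer's weights, while weight decay uses the weights of every layer; the default scales are taken from the configuration above.

```python
def sum_of_squares(matrix):
    """Sum of squared entries of a weight matrix given as nested lists."""
    return sum(w ** 2 for row in matrix for w in row)


def weight_regularizers(layer_weights,
                        logit_regularization_scale=0.05,
                        weight_decay_scale=0.0001):
    """Return the (logit regularization, weight decay) terms added to the discriminator loss."""
    # logit regularization: squared weights of the final (logit) layer only
    logit_reg = logit_regularization_scale * sum_of_squares(layer_weights[-1])
    # weight decay: squared weights of all layers
    weight_decay = weight_decay_scale * sum(sum_of_squares(m) for m in layer_weights)
    return logit_reg, weight_decay
```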

entropy_loss_scale: float = 0.0

Entropy loss scaling factor.

experiment: ExperimentCfg

Experiment settings.

gae_lambda: float = 0.95

TD(lambda) coefficient for computing Generalized Advantage Estimation (GAE).

grad_norm_clip: float = 0.5

Clipping coefficient for the gradients by their global norm.

If less than or equal to 0, the gradients will not be clipped.
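Clipping by global norm rescales all gradients jointly when their combined L2 norm exceeds the threshold, the same rule used by PyTorch's `torch.nn.utils.clip_grad_norm_`. A plain-Python sketch over lists of gradient vectors (the function name is illustrative):

```python
import math


def clip_by_global_norm(gradients, grad_norm_clip=0.5):
    """Scale gradient vectors so that their joint L2 norm is at most grad_norm_clip."""
    # a non-positive threshold disables clipping
    if grad_norm_clip <= 0:
        return gradients
    total_norm = math.sqrt(sum(g ** 2 for grad in gradients for g in grad))
    scale = grad_norm_clip / (total_norm + 1e-6)
    # gradients already within the threshold are left untouched
    if scale >= 1.0:
        return gradients
    return [[g * scale for g in grad] for grad in gradients]
```

Because the same factor is applied to every gradient, the update direction is preserved and only its magnitude shrinks.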

learning_epochs: int = 6

Number of learning epochs to perform during updates.

learning_rate: float | tuple[float, float, float] = 5e-05

Learning rate for the policy, value and discriminator networks.

  • If a float is provided, the same learning rate will be used for the networks.

  • If a tuple is provided, its elements will be used for each network in order.
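The single-value-or-tuple convention can be illustrated with a hypothetical helper (not part of skrl) that expands a setting into one value per network, in the order policy, value, discriminator:

```python
def expand_per_network(value, num_networks=3):
    """Hypothetical helper: return one setting per network (policy, value, discriminator)."""
    if isinstance(value, tuple):
        # a tuple must provide exactly one element per network
        if len(value) != num_networks:
            raise ValueError(f"expected {num_networks} elements, got {len(value)}")
        return value
    # a single value is shared by all networks
    return (value,) * num_networks
```

For example, `expand_per_network(5e-05)` gives `(5e-05, 5e-05, 5e-05)`.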

learning_rate_scheduler: type | tuple[type | None, type | None, type | None] | None = None

Learning rate scheduler class for the policy, value and discriminator networks.

See Learning rate schedulers for more details.

  • If a class is provided, the same learning rate scheduler will be used for the networks.

  • If a tuple is provided, its elements will be used for each network in order.

learning_rate_scheduler_kwargs: dict | tuple[dict, dict, dict]

Keyword arguments for the learning rate scheduler’s constructor.

See Learning rate schedulers for more details.

Warning

The optimizer argument is automatically passed to the learning rate scheduler’s constructor. Therefore, it must not be provided in the keyword arguments.

  • If a dictionary is provided, the same keyword arguments will be used for the networks.

  • If a tuple is provided, its elements will be used for each network in order.

learning_starts: int = 0

Number of steps to perform before calling the algorithm update function.

mini_batches: int = 2

Number of mini-batches to sample when updating.

mixed_precision: bool = False

Whether to enable automatic mixed precision for higher performance.

observation_preprocessor: type | None = None

Preprocessor class to process the environment’s observations.

See Preprocessors for more details.

observation_preprocessor_kwargs: dict

Keyword arguments for the observation preprocessor’s constructor.

See Preprocessors for more details.

random_timesteps: int = 0

Number of random exploration (sampling random actions) steps to perform before sampling actions from the policy.

ratio_clip: float = 0.2

Clipping coefficient for computing the clipped surrogate objective.
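The clipped surrogate objective from the pseudocode, sketched in plain Python (the function name is illustrative). It takes the element-wise minimum of the unclipped and clipped terms and negates the mean so that the result can be minimized:

```python
import math


def clipped_surrogate_loss(advantages, new_log_probs, old_log_probs, ratio_clip=0.2):
    """PPO clipped surrogate objective, negated and averaged (to be minimized)."""
    total = 0.0
    for adv, logp_new, logp_old in zip(advantages, new_log_probs, old_log_probs):
        # probability ratio between the current and the rollout policy
        ratio = math.exp(logp_new - logp_old)
        clipped_ratio = min(max(ratio, 1.0 - ratio_clip), 1.0 + ratio_clip)
        total += min(adv * ratio, adv * clipped_ratio)
    return -total / len(advantages)
```

With a ratio of 2, a positive advantage is credited only up to the clipped ratio 1.2, while a negative advantage keeps the full (pessimistic) ratio.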

rewards_shaper: Callable | None = None

Rewards shaping function.

rollouts: int = 16

Number of collection steps to perform between updates.

state_preprocessor: type | None = None

Preprocessor class to process the environment’s states.

See Preprocessors for more details.

state_preprocessor_kwargs: dict

Keyword arguments for the state preprocessor’s constructor.

See Preprocessors for more details.

style_reward_scale: float = 2.0

Reward scaling factor for the style reward (imitation of the reference motions).

task_reward_scale: float = 0.0

Reward scaling factor for the task.
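Following the pseudocode, the style reward \(r_D\) is derived from the discriminator logit and then mixed with the task reward using task_reward_scale and style_reward_scale. A plain-Python sketch with the default scales (function names are illustrative):

```python
import math


def one_minus_sigmoid(x):
    # 1 - sigmoid(x) == sigmoid(-x), computed without overflow
    if x >= 0:
        z = math.exp(-x)
        return z / (1.0 + z)
    return 1.0 / (1.0 + math.exp(x))


def style_reward(logit):
    """r_D = -log(max(1 - sigmoid(logit), 1e-4))."""
    return -math.log(max(one_minus_sigmoid(logit), 1e-4))


def combined_reward(task_reward, logit, task_reward_scale=0.0, style_reward_scale=2.0):
    return task_reward_scale * task_reward + style_reward_scale * style_reward(logit)
```

The \(10^{-4}\) floor caps the style reward at \(-\log(10^{-4}) \approx 9.21\) when the discriminator is very confident that a transition comes from the reference motions. With the default task_reward_scale of 0.0, the agent is driven by the style reward alone.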

time_limit_bootstrap: bool = False

Whether to bootstrap at timeout termination (episode truncation).
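A common way to bootstrap on truncation, sketched here as an assumption rather than as skrl's exact code, is to add a discounted value estimate (for example \(V_\phi\) evaluated at the truncated step) to the reward of truncated transitions, so that timeouts are not treated as true terminal states:

```python
def bootstrap_truncated_rewards(rewards, value_estimates, truncated, discount_factor=0.99):
    """Add discount_factor * value estimate to the rewards of time-limit-truncated steps."""
    return [r + discount_factor * v if t else r
            for r, v, t in zip(rewards, value_estimates, truncated)]
```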

value_clip: float = 0.2

Clipping coefficient for the predicted value during value loss computation.

If less than or equal to 0, the predicted value will not be clipped.
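The value loss with optional clipping, sketched in plain Python after the pseudocode (the function name is illustrative). The new prediction is kept within value_clip of the value stored during the rollout before the squared error is taken:

```python
def value_loss(returns, predicted, old_values, value_clip=0.2, value_loss_scale=2.5):
    """Scaled mean squared error between returns and (optionally clipped) predictions."""
    total = 0.0
    for ret, pred, old in zip(returns, predicted, old_values):
        if value_clip > 0:
            # keep the new prediction within +/- value_clip of the rollout value
            pred = old + min(max(pred - old, -value_clip), value_clip)
        total += (ret - pred) ** 2
    return value_loss_scale * total / len(returns)
```

For a return of 1.0, a prediction of 2.0 and a rollout value of 0.0, clipping limits the prediction to 0.2, giving \(2.5 \cdot 0.8^2 = 1.6\).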

value_loss_scale: float = 2.5

Value loss scaling factor.

value_preprocessor: type | None = None

Preprocessor class to process the value network’s output.

See Preprocessors for more details.

value_preprocessor_kwargs: dict

Keyword arguments for the value preprocessor’s constructor.

See Preprocessors for more details.

class skrl.agents.torch.amp.AMP(*, models: dict[str, Model], memory: Memory | None = None, observation_space: gymnasium.Space | None = None, state_space: gymnasium.Space | None = None, action_space: gymnasium.Space | None = None, device: str | torch.device | None = None, cfg: AMP_CFG | dict = {}, amp_observation_space: gymnasium.Space | None = None, motion_dataset: Memory | None = None, reply_buffer: Memory | None = None, collect_reference_motions: Callable[[int], torch.Tensor] | None = None)

Bases: Agent

Adversarial Motion Priors (AMP).

https://arxiv.org/abs/2104.02180

Note

The implementation is adapted from the NVIDIA IsaacGymEnvs repository.

Parameters:
  • models – Agent’s models.

  • memory – Memory to store the agent’s data and environment transitions.

  • observation_space – Observation space.

  • state_space – State space.

  • action_space – Action space.

  • device – Data allocation and computation device. If not specified, the default device will be used.

  • cfg – Agent’s configuration.

  • amp_observation_space – AMP observation space.

  • motion_dataset – Reference motion dataset (M).

  • reply_buffer – Replay buffer used to prevent discriminator overfitting (B).

  • collect_reference_motions – Callable to collect reference motions.

Raises:

KeyError – If a configuration key is missing.

Methods:

act(observations, states, *, timestep, timesteps)

Process the environment's observations/states to make a decision (actions) using the main policy.

enable_models_training_mode([enabled])

Set the training mode of all the agent's models: enabled (training) or disabled (evaluation).

enable_training_mode([enabled, apply_to_models])

Set the training mode of the agent: enabled (training) or disabled (evaluation).

init(*[, trainer_cfg])

Initialize the agent.

load(path)

Load the agent from the specified path.

post_interaction(*, timestep, timesteps)

Method called after the interaction with the environment.

pre_interaction(*, timestep, timesteps)

Method called before the interaction with the environment.

record_transition(*, observations, states, ...)

Record an environment transition in memory.

save(path)

Save the agent to the specified path.

track_data(tag, value)

Track data to TensorBoard.

update(*, timestep, timesteps)

Algorithm's main update step.

write_checkpoint(*, timestep, timesteps)

Write checkpoint (modules) to persistent storage.

write_tracking_data(*, timestep, timesteps)

Write tracking data to TensorBoard.

act(observations: torch.Tensor, states: torch.Tensor | None, *, timestep: int, timesteps: int) → tuple[torch.Tensor, dict[str, Any]]

Process the environment’s observations/states to make a decision (actions) using the main policy.

Parameters:
  • observations – Environment observations.

  • states – Environment states.

  • timestep – Current timestep.

  • timesteps – Number of timesteps.

Returns:

Agent output. The first component is the expected action/value returned by the agent. The second component is a dictionary containing extra output values according to the model.

enable_models_training_mode(enabled: bool = True) → None

Set the training mode of all the agent’s models: enabled (training) or disabled (evaluation).

Parameters:

enabled – True to enable the training mode, False to enable the evaluation mode.

enable_training_mode(enabled: bool = True, *, apply_to_models: bool = False) → None

Set the training mode of the agent: enabled (training) or disabled (evaluation).

The training mode can be queried by the training property.

Parameters:
  • enabled – True to enable the training mode, False to enable the evaluation mode.

  • apply_to_models – Whether to apply the training mode to all the agent’s models.

init(*, trainer_cfg: dict[str, Any] | None = None) → None

Initialize the agent.

Parameters:

trainer_cfg – Trainer configuration.

load(path: str) → None

Load the agent from the specified path.

Note

The final storage device is determined by the constructor of the agent.

Parameters:

path – Path to load the agent from.

post_interaction(*, timestep: int, timesteps: int) → None

Method called after the interaction with the environment.

Parameters:
  • timestep – Current timestep.

  • timesteps – Number of timesteps.

pre_interaction(*, timestep: int, timesteps: int) → None

Method called before the interaction with the environment.

Parameters:
  • timestep – Current timestep.

  • timesteps – Number of timesteps.

record_transition(*, observations: torch.Tensor, states: torch.Tensor, actions: torch.Tensor, rewards: torch.Tensor, next_observations: torch.Tensor, next_states: torch.Tensor, terminated: torch.Tensor, truncated: torch.Tensor, infos: Any, timestep: int, timesteps: int) → None

Record an environment transition in memory.

Parameters:
  • observations – Environment observations.

  • states – Environment states.

  • actions – Actions taken by the agent.

  • rewards – Instant rewards achieved by the current actions.

  • next_observations – Next environment observations.

  • next_states – Next environment states.

  • terminated – Signals that indicate episodes have terminated.

  • truncated – Signals that indicate episodes have been truncated.

  • infos – Additional information about the environment.

  • timestep – Current timestep.

  • timesteps – Number of timesteps.

save(path: str) → None

Save the agent to the specified path.

Parameters:

path – Path to save the agent to.

track_data(tag: str, value: float) → None

Track data to TensorBoard.

Note

Currently only scalar data is supported.

Parameters:
  • tag – Data identifier (e.g. ‘Loss/Policy loss’).

  • value – Value to track.

update(*, timestep: int, timesteps: int) → None

Algorithm’s main update step.

Parameters:
  • timestep – Current timestep.

  • timesteps – Number of timesteps.

write_checkpoint(*, timestep: int, timesteps: int) → None

Write checkpoint (modules) to persistent storage.

Note

The checkpoints are stored in the subdirectory checkpoints within the experiment directory. The checkpoint name is the timestep argument value (if it is not None), or the current system date-time otherwise.

Parameters:
  • timestep – Current timestep.

  • timesteps – Number of timesteps.

write_tracking_data(*, timestep: int, timesteps: int) → None

Write tracking data to TensorBoard.

Parameters:
  • timestep – Current timestep.

  • timesteps – Number of timesteps.