KL Adaptive

Adjust the learning rate according to the value of the Kullback-Leibler (KL) divergence.



Algorithm


Algorithm implementation

The learning rate (\(\eta\)) at each step is modified as follows:

IF \(\; KL >\) kl_factor \(\times\) kl_threshold THEN
\(\quad \eta_{t + 1} = \max(\eta_t \,/\) lr_factor \(,\) min_lr \()\)
IF \(\; KL <\) kl_threshold \(/\) kl_factor THEN
\(\quad \eta_{t + 1} = \min(\) lr_factor \(\times\; \eta_t,\) max_lr \()\)
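The rule above can be sketched in plain Python. This is an illustrative stand-in, not the library's code: the function name `kl_adaptive_lr` is hypothetical, and the default values are assumed from the JAX and Warp signatures listed on this page.

```python
def kl_adaptive_lr(lr, kl, kl_threshold=0.008, min_lr=1e-6, max_lr=1e-2,
                   kl_factor=2.0, lr_factor=1.5):
    """Return the next learning rate for a scalar KL value (hypothetical helper)."""
    if kl > kl_factor * kl_threshold:
        # KL above the upper bound: shrink the rate, but never below min_lr
        return max(lr / lr_factor, min_lr)
    if kl < kl_threshold / kl_factor:
        # KL below the lower bound: grow the rate, but never above max_lr
        return min(lr_factor * lr, max_lr)
    # KL inside [kl_threshold / kl_factor, kl_factor * kl_threshold]: keep the rate
    return lr


lr = 1e-3
high = kl_adaptive_lr(lr, kl=0.02)   # 0.02 > 0.016, so the rate is divided by 1.5
low = kl_adaptive_lr(lr, kl=0.002)   # 0.002 < 0.004, so the rate is multiplied by 1.5
same = kl_adaptive_lr(lr, kl=0.01)   # between the two bounds, so the rate is unchanged
```

With the assumed defaults, the rate is left alone while the KL divergence stays between 0.004 and 0.016, and the `min_lr` / `max_lr` bounds clamp every adjustment.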

Usage

The learning rate scheduler is set in each agent’s configuration. The scheduler class goes under the learning_rate_scheduler key and its arguments go under the learning_rate_scheduler_kwargs key, as a Python dictionary that does not include the optimizer instance.

# import the scheduler class
from skrl.resources.schedulers.torch import KLAdaptiveLR

cfg = AGENT_CFG()
# ...
cfg.learning_rate_scheduler = KLAdaptiveLR
cfg.learning_rate_scheduler_kwargs = {"kl_threshold": 0.01}
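The reason the dictionary omits the optimizer can be illustrated with a small sketch. It assumes, for the PyTorch variant, that the agent builds its own optimizer and passes it as the first argument when it instantiates the scheduler. `RecordingScheduler`, the `SimpleNamespace` configuration and the placeholder optimizer are all hypothetical stand-ins, not part of skrl.

```python
from types import SimpleNamespace


class RecordingScheduler:
    """Hypothetical stand-in for a scheduler class; it only records its arguments."""

    def __init__(self, optimizer, **kwargs):
        self.optimizer = optimizer
        self.kwargs = kwargs


# configuration as the user would write it: class plus keyword arguments only
cfg = SimpleNamespace(
    learning_rate_scheduler=RecordingScheduler,
    learning_rate_scheduler_kwargs={"kl_threshold": 0.01},
)

# placeholder for the optimizer that the agent is assumed to create internally
optimizer = object()

# assumed agent-side wiring: the optimizer is supplied here, not by the user
scheduler = None
if cfg.learning_rate_scheduler is not None:
    scheduler = cfg.learning_rate_scheduler(optimizer, **cfg.learning_rate_scheduler_kwargs)
```

Under this assumption, adding an optimizer entry to the dictionary would pass the optimizer twice, which is why the configuration only carries the scheduler's own arguments.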

API


PyTorch

KLAdaptiveLR

Adaptive KL scheduler.

class skrl.resources.schedulers.torch.kl_adaptive.KLAdaptiveLR(*args: Any, **kwargs: Any)

Bases: _LRScheduler

Adjusts the learning rate according to the KL divergence.

Note

This scheduler is only available for the A2C, AMP, PPO and RPO single-agent algorithms, and IPPO and MAPPO multi-agent algorithms. Applying it to other agents will not change the learning rate.

Parameters:
  • optimizer – Wrapped optimizer.

  • kl_threshold – Threshold for KL divergence.

  • min_lr – Lower bound for learning rate.

  • max_lr – Upper bound for learning rate.

  • kl_factor – The number used to modify the KL divergence threshold.

  • lr_factor – The number used to modify the learning rate.

  • last_epoch – Index of the last epoch.

  • verbose – Verbose mode.

Example:

>>> scheduler = KLAdaptiveLR(optimizer, kl_threshold=0.01)
>>> for epoch in range(100):
...     # ...
...     kl_divergence = ...
...     scheduler.step(kl_divergence)

Methods:

step([kl, epoch])

Step scheduler.

step(kl: torch.Tensor | float | None = None, *, epoch: int | None = None) → None

Step scheduler.

Parameters:
  • kl – KL divergence. If None, no adjustment is made. If a tensor, it must contain exactly one element.

  • epoch – Epoch.

Example:

>>> kl = torch.distributions.kl_divergence(p, q)
>>> kl
tensor([0.0332, 0.0500, 0.0383,  ..., 0.0076, 0.0240, 0.0164])
>>> scheduler.step(kl.mean())

>>> kl = 0.0046
>>> scheduler.step(kl)

JAX

KLAdaptiveLR

Adaptive KL scheduler.

skrl.resources.schedulers.jax.kl_adaptive.KLAdaptiveLR(*, kl_threshold: float = 0.008, min_lr: float = 1e-06, max_lr: float = 0.01, kl_factor: float = 2, lr_factor: float = 1.5) → optax.Schedule

Adjusts the learning rate according to the KL divergence.

Note

This scheduler is only available for the A2C, AMP, PPO and RPO single-agent algorithms, and IPPO and MAPPO multi-agent algorithms. Applying it to other agents will not change the learning rate.

Parameters:
  • kl_threshold – Threshold for KL divergence.

  • min_lr – Lower bound for learning rate.

  • max_lr – Upper bound for learning rate.

  • kl_factor – The number used to modify the KL divergence threshold.

  • lr_factor – The number used to modify the learning rate.

Returns:

A function that maps the step count, the current learning rate and the KL divergence to the new learning rate. If no learning rate is specified, 1.0 is returned to mimic the behavior of Optax schedules. If the learning rate is specified but the KL divergence is not, the specified learning rate is returned unchanged.

Example:

>>> scheduler = KLAdaptiveLR(kl_threshold=0.01)
>>> for epoch in range(100):
...     # ...
...     kl_divergence = ...
...     new_lr = scheduler(timestep, lr, kl_divergence)
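The behavior of the returned callable can be sketched without JAX or skrl. The factory below is a hypothetical stand-in named `make_kl_adaptive_schedule`, not the library's implementation. It assumes that the learning rate and the KL divergence are optional arguments, and it reads the last sentence of the Returns description as "the KL divergence is not given".

```python
from typing import Callable, Optional


def make_kl_adaptive_schedule(*, kl_threshold: float = 0.008, min_lr: float = 1e-6,
                              max_lr: float = 1e-2, kl_factor: float = 2.0,
                              lr_factor: float = 1.5) -> Callable[..., float]:
    """Build a schedule-like callable (hypothetical stand-in for illustration)."""

    def schedule(count: int, lr: Optional[float] = None, kl: Optional[float] = None) -> float:
        if lr is None:
            # no learning rate given: return a neutral 1.0
            return 1.0
        if kl is not None:
            if kl > kl_factor * kl_threshold:
                lr = max(lr / lr_factor, min_lr)
            elif kl < kl_threshold / kl_factor:
                lr = min(lr_factor * lr, max_lr)
        # no KL given, or KL inside the band: the rate passes through unchanged
        return lr

    return schedule


schedule = make_kl_adaptive_schedule(kl_threshold=0.01)
no_lr = schedule(0)        # learning rate omitted
no_kl = schedule(0, 3e-4)  # KL divergence omitted

lr = 3e-4
for step, kl in enumerate([0.05, 0.05, 0.001]):
    # two shrinking steps (0.05 > 0.02) followed by one growing step (0.001 < 0.005)
    lr = schedule(step, lr, kl)
```

Because the callable is stateless, the caller keeps the current learning rate and feeds it back in on every call, as the loop does.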

Warp

KLAdaptiveLR

Adaptive KL scheduler.

skrl.resources.schedulers.warp.kl_adaptive.KLAdaptiveLR(*, kl_threshold: float = 0.008, min_lr: float = 1e-06, max_lr: float = 0.01, kl_factor: float = 2, lr_factor: float = 1.5) → Callable[[int, float, float], float]

Adjusts the learning rate according to the KL divergence.

Note

This scheduler is only available for the A2C, AMP, PPO and RPO single-agent algorithms, and IPPO and MAPPO multi-agent algorithms. Applying it to other agents will not change the learning rate.

Parameters:
  • kl_threshold – Threshold for KL divergence.

  • min_lr – Lower bound for learning rate.

  • max_lr – Upper bound for learning rate.

  • kl_factor – The number used to modify the KL divergence threshold.

  • lr_factor – The number used to modify the learning rate.

Returns:

A function that maps the step count, the current learning rate and the KL divergence to the new learning rate. If no learning rate is specified, 1.0 is returned to mimic the behavior of the JAX scheduler. If the learning rate is specified but the KL divergence is not, the specified learning rate is returned unchanged.

Example:

>>> scheduler = KLAdaptiveLR(kl_threshold=0.01)
>>> for epoch in range(100):
...     # ...
...     kl_divergence = ...
...     new_lr = scheduler(timestep, lr, kl_divergence)