Learning rate schedulers

Learning rate schedulers adjust the optimizer's learning rate over the course of training, which can improve the agent's performance.



Implemented schedulers

The following table lists the implemented schedulers and their support for different frameworks.

Learning rate schedulers    pytorch    jax    warp

KL Adaptive                 ■          ■      ■


Implementation details according to the ML framework:

  • PyTorch: The implemented schedulers inherit from the PyTorch _LRScheduler class. See "How to adjust learning rate" in the PyTorch documentation for more details.

  • JAX: Each implemented scheduler takes its parameters and returns a function that maps step counts to values. See "Optimizer Schedules" in the Optax documentation for more details.

  • Warp: Each implemented scheduler takes its parameters and returns a function that maps step counts to values.
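
The "function that maps step counts to values" pattern described for JAX and Warp can be sketched in plain Python. This is only an illustration of the pattern, not skrl or Optax code: the name make_exponential_decay and its parameters are hypothetical.

```python
from typing import Callable


def make_exponential_decay(
    init_value: float, decay_rate: float, transition_steps: int
) -> Callable[[int], float]:
    """Hypothetical schedule factory: fixes the parameters, returns the schedule."""

    def schedule(step: int) -> float:
        # Map a step count to a learning rate value
        return init_value * decay_rate ** (step / transition_steps)

    return schedule


# Parameterize once, then query the returned function at any step count
schedule = make_exponential_decay(init_value=1e-3, decay_rate=0.5, transition_steps=1000)
learning_rate_at_start = schedule(0)
learning_rate_after_1000_steps = schedule(1000)
```

The outer call fixes the schedule's parameters, and the returned function is what gets evaluated at each step count.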


Usage

Each agent's configuration defines which learning rate scheduler it uses. Set the scheduler class under the learning_rate_scheduler key and its arguments, as a Python dictionary, under the learning_rate_scheduler_kwargs key. Do not include the optimizer instance among the arguments.

The following examples show how to set the scheduler for an agent, using either a third-party scheduler (from the ML framework) or a native scheduler (from skrl):

# import the scheduler class
from torch.optim.lr_scheduler import StepLR

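# AGENT_CFG is a placeholder for the agent's configuration class
cfg = AGENT_CFG()
# ...
cfg.learning_rate_scheduler = StepLR
# scheduler arguments, without the optimizer instance
cfg.learning_rate_scheduler_kwargs = {"step_size": 1, "gamma": 0.9}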
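
To make the effect of these settings concrete, the following framework-free sketch shows one way a scheduler class and its keyword arguments could be combined with an optimizer, and what learning rates result. ToyOptimizer and ToyStepLR are hypothetical stand-ins written for this example (they are not skrl or PyTorch classes); ToyStepLR only mirrors the StepLR rule lr = base_lr * gamma ** (step // step_size), and the wiring is an assumption about how an agent might use the two configuration values.

```python
class ToyOptimizer:
    """Hypothetical stand-in for an optimizer: it only stores a learning rate."""

    def __init__(self, lr: float) -> None:
        self.lr = lr


class ToyStepLR:
    """Hypothetical stand-in that mirrors the StepLR decay rule."""

    def __init__(self, optimizer: ToyOptimizer, step_size: int, gamma: float) -> None:
        self.optimizer = optimizer
        self.base_lr = optimizer.lr
        self.step_size = step_size
        self.gamma = gamma
        self.count = 0

    def step(self) -> None:
        # Multiply the base learning rate by gamma once every step_size calls
        self.count += 1
        self.optimizer.lr = self.base_lr * self.gamma ** (self.count // self.step_size)


# The two configuration values: a class and its arguments, with no optimizer
scheduler_class = ToyStepLR
scheduler_kwargs = {"step_size": 1, "gamma": 0.9}

# Assumed wiring: the optimizer is supplied when the scheduler is instantiated
optimizer = ToyOptimizer(lr=1e-3)
scheduler = scheduler_class(optimizer, **scheduler_kwargs)

learning_rates = []
for _ in range(3):
    scheduler.step()
    learning_rates.append(optimizer.lr)
```

With step_size set to 1 and gamma set to 0.9, every call to step() scales the learning rate to 90% of its previous value.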