Learning rate schedulers

Learning rate schedulers adjust the learning rate during training, which can improve the agent's performance.



Learning rate schedulers    PyTorch             JAX
KL Adaptive                 \(\blacksquare\)    \(\blacksquare\)


Implementation according to the ML framework:

  • PyTorch: The implemented schedulers inherit from the PyTorch _LRScheduler class. Visit How to adjust learning rate in the PyTorch documentation for more details.

  • JAX: The implemented schedulers must parameterize and return a function that maps step counts to values. Visit Schedules in the Optax documentation for more details.
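To illustrate the JAX-side convention described above, here is a minimal sketch in plain Python of a schedule that is parameterized once and then returns a function mapping step counts to values. The name make_exponential_schedule and its parameters are hypothetical, chosen only for illustration; this is not the Optax API nor one of the library's schedulers.

```python
from typing import Callable


def make_exponential_schedule(init_value: float, decay_rate: float) -> Callable[[int], float]:
    """Hypothetical helper: parameterize and return a schedule function."""

    def schedule(step: int) -> float:
        # map a step count to a learning rate value
        return init_value * decay_rate ** step

    return schedule


# parameterize once, then query the returned function with step counts
schedule = make_exponential_schedule(init_value=1e-3, decay_rate=0.9)
values = [schedule(step) for step in range(3)]
```

The returned closure keeps no internal step counter: the caller passes the step count on every call, which is the shape of function the Optax documentation describes for schedules.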


Usage

The learning rate scheduler is set in each agent’s configuration dictionary. The scheduler class goes under the "learning_rate_scheduler" key, and its arguments go under the "learning_rate_scheduler_kwargs" key as a keyword argument dictionary. The optimizer (the scheduler's first argument) is not specified in that dictionary.

The following example shows how to set a PyTorch scheduler for an agent:

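# import the scheduler class
from torch.optim.lr_scheduler import StepLR

# DEFAULT_CONFIG stands for the default configuration dictionary of the chosen agent
cfg = DEFAULT_CONFIG.copy()
# pass the scheduler class itself, not an instance
cfg["learning_rate_scheduler"] = StepLR
# keyword arguments for the scheduler, excluding the optimizer (first argument)
cfg["learning_rate_scheduler_kwargs"] = {"step_size": 1, "gamma": 0.9}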
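As a rough sketch of why the optimizer is left out of the keyword arguments, the snippet below shows one way the two configuration entries could be combined with an optimizer. It is an assumption about how an agent might consume the configuration, not the library's actual code: the plain dictionary, the "learning_rate" entry, and the single dummy parameter are stand-ins for illustration only.

```python
import torch
from torch.optim.lr_scheduler import StepLR

# stand-in for an agent's configuration (assumption: only these keys are shown)
cfg = {
    "learning_rate": 1e-3,
    "learning_rate_scheduler": StepLR,
    "learning_rate_scheduler_kwargs": {"step_size": 1, "gamma": 0.9},
}

# dummy parameter so that an optimizer can be created
params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = torch.optim.Adam(params, lr=cfg["learning_rate"])

# the optimizer is supplied as the first argument; the rest comes from the kwargs
scheduler = cfg["learning_rate_scheduler"](optimizer, **cfg["learning_rate_scheduler_kwargs"])

learning_rates = []
for _ in range(3):
    learning_rates.append(optimizer.param_groups[0]["lr"])
    optimizer.step()   # no gradients were computed, so the parameter is unchanged
    scheduler.step()   # StepLR multiplies the learning rate by gamma every step_size steps
```

With step_size set to 1 and gamma set to 0.9, each call to scheduler.step() multiplies the current learning rate by 0.9, so learning_rates holds a geometrically decaying sequence starting at the configured value.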