Learning rate schedulers
Learning rate schedulers adjust the optimizer's learning rate during training to improve the agent's performance. For example, they can decay it at fixed intervals or adapt it to a quantity measured during training.
Implemented schedulers
The following table lists the implemented schedulers and their support for different frameworks.
| Learning rate schedulers | PyTorch | JAX | Warp |
|---|---|---|---|
| KL Adaptive | \(\blacksquare\) | \(\blacksquare\) | \(\blacksquare\) |
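The native scheduler used in the usage examples, KLAdaptiveLR, adjusts the learning rate according to the measured KL divergence between successive policies. Its exact update rule is not described on this page. The sketch below assumes the threshold-band rule commonly used for KL-adaptive learning rates in PPO implementations: shrink the rate when the KL divergence is well above kl_threshold, grow it when well below, and keep it otherwise. The function name and the parameters other than kl_threshold (min_lr, max_lr, kl_factor, lr_factor) and their defaults are assumptions for illustration, not read from skrl's source.

```python
def kl_adaptive_update(lr: float, kl: float, kl_threshold: float = 0.008,
                       min_lr: float = 1e-6, max_lr: float = 1e-2,
                       kl_factor: float = 2.0, lr_factor: float = 1.5) -> float:
    """Hypothetical KL-adaptive rule: return the new learning rate for a measured KL divergence."""
    if kl > kl_threshold * kl_factor:
        # the policy changed too much: decrease the learning rate, bounded below by min_lr
        return max(lr / lr_factor, min_lr)
    if kl < kl_threshold / kl_factor:
        # the policy changed too little: increase the learning rate, bounded above by max_lr
        return min(lr * lr_factor, max_lr)
    # KL divergence inside the band: keep the learning rate unchanged
    return lr


# usage: feed the KL divergence measured after each update
lr = 1e-3
for kl in [0.05, 0.05, 0.001]:
    lr = kl_adaptive_update(lr, kl, kl_threshold=0.01)
```

The band (threshold divided and multiplied by kl_factor) avoids changing the learning rate on every small fluctuation of the KL divergence.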
Implementation details for each ML framework:
PyTorch: The implemented schedulers inherit from the PyTorch _LRScheduler class. See "How to adjust learning rate" in the PyTorch documentation for more details.
JAX: The implemented schedulers must parameterize and return a function that maps step counts to values. See "Optimizer Schedules" in the Optax documentation for more details.
Warp: The implemented schedulers must parameterize and return a function that maps step counts to values.
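For JAX and Warp, "parameterize and return a function that maps step counts to values" means a factory: it takes the schedule's parameters and returns a callable from step count to learning rate. The pure-Python sketch below shows that shape with an exponential decay. The name make_exponential_decay and its parameters are hypothetical. Optax ships its own built-in schedules, such as optax.exponential_decay, whose signature differs from this sketch.

```python
from typing import Callable


def make_exponential_decay(init_value: float, decay_rate: float,
                           decay_steps: int) -> Callable[[int], float]:
    """Hypothetical schedule factory: parameterize, then return a step -> value function."""
    def schedule(count: int) -> float:
        # continuous exponential decay: the value is multiplied by decay_rate every decay_steps steps
        return init_value * decay_rate ** (count / decay_steps)
    return schedule


# usage: build the schedule once, then query it with step counts
lr_schedule = make_exponential_decay(init_value=1e-3, decay_rate=0.5, decay_steps=100)
values = [lr_schedule(step) for step in (0, 100, 200)]
```

Because the returned function holds no mutable state, the same schedule can be evaluated at any step count, which is the property function-based schedules rely on.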
Usage
The learning rate scheduler is specified in each agent's configuration.
The scheduler class (or function) is set under the learning_rate_scheduler key, and its arguments are set under
the learning_rate_scheduler_kwargs key as a Python dictionary. The optimizer instance is not included in these arguments; the agent handles it internally.
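The pattern of storing an uninstantiated scheduler plus a keyword dictionary, and letting the agent bind the optimizer, can be illustrated with a self-contained toy. Everything in it (ToyOptimizer, ToyStepScheduler, ToyAgentConfig, ToyAgent) is hypothetical and framework-free, and it is not skrl's internal code. It only assumes the PyTorch-style convention that a scheduler's first constructor argument is the optimizer, as with torch.optim.lr_scheduler.StepLR.

```python
from dataclasses import dataclass, field
from typing import Optional


class ToyOptimizer:
    """Hypothetical stand-in for a framework optimizer that holds a learning rate."""
    def __init__(self, lr: float) -> None:
        self.lr = lr


class ToyStepScheduler:
    """Hypothetical scheduler: multiplies the learning rate by gamma every step_size steps."""
    def __init__(self, optimizer: ToyOptimizer, step_size: int, gamma: float) -> None:
        self.optimizer = optimizer
        self.step_size = step_size
        self.gamma = gamma
        self.count = 0

    def step(self) -> None:
        self.count += 1
        if self.count % self.step_size == 0:
            self.optimizer.lr *= self.gamma


@dataclass
class ToyAgentConfig:
    """Hypothetical configuration exposing the two scheduler keys."""
    learning_rate: float = 1e-3
    learning_rate_scheduler: Optional[type] = None
    learning_rate_scheduler_kwargs: dict = field(default_factory=dict)


class ToyAgent:
    def __init__(self, cfg: ToyAgentConfig) -> None:
        self.optimizer = ToyOptimizer(cfg.learning_rate)
        self.scheduler = None
        if cfg.learning_rate_scheduler is not None:
            # the agent, not the user, supplies the optimizer instance
            self.scheduler = cfg.learning_rate_scheduler(
                self.optimizer, **cfg.learning_rate_scheduler_kwargs)


# usage: configure the scheduler class and its kwargs without any optimizer
cfg = ToyAgentConfig()
cfg.learning_rate_scheduler = ToyStepScheduler
cfg.learning_rate_scheduler_kwargs = {"step_size": 1, "gamma": 0.9}
agent = ToyAgent(cfg)
agent.scheduler.step()
agent.scheduler.step()
```

Keeping the class and its arguments separate lets the configuration be written before any optimizer exists.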
The following examples show how to set the scheduler for an agent, using either a third-party scheduler (from the ML framework) or a native scheduler (from skrl). In them, AGENT_CFG is a placeholder for the configuration of the agent being used:
PyTorch, third-party scheduler:

```python
# import the scheduler class
from torch.optim.lr_scheduler import StepLR
cfg = AGENT_CFG()
# ...
cfg.learning_rate_scheduler = StepLR
cfg.learning_rate_scheduler_kwargs = {"step_size": 1, "gamma": 0.9}
```
PyTorch, native skrl scheduler:

```python
# import the scheduler class
from skrl.resources.schedulers.torch import KLAdaptiveLR
cfg = AGENT_CFG()
# ...
cfg.learning_rate_scheduler = KLAdaptiveLR
cfg.learning_rate_scheduler_kwargs = {"kl_threshold": 0.01}
```
JAX, third-party scheduler:

```python
# import the scheduler function
from optax import constant_schedule
cfg = AGENT_CFG()
# ...
cfg.learning_rate_scheduler = constant_schedule
cfg.learning_rate_scheduler_kwargs = {"value": 1e-4}
```
JAX, native skrl scheduler:

```python
# import the scheduler function
from skrl.resources.schedulers.jax import KLAdaptiveLR  # or kl_adaptive (Optax style)
cfg = AGENT_CFG()
# ...
cfg.learning_rate_scheduler = KLAdaptiveLR
cfg.learning_rate_scheduler_kwargs = {"kl_threshold": 0.01}
```
Warp, native skrl scheduler:

```python
# import the scheduler function
from skrl.resources.schedulers.warp import KLAdaptiveLR  # or kl_adaptive
cfg = AGENT_CFG()
# ...
cfg.learning_rate_scheduler = KLAdaptiveLR
cfg.learning_rate_scheduler_kwargs = {"kl_threshold": 0.01}
```
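To see what the StepLR arguments in the PyTorch example ({"step_size": 1, "gamma": 0.9}) do, the sketch below computes the step-decay value in closed form. PyTorch documents StepLR as decaying the learning rate by gamma every step_size steps. The helper name step_lr_value and the initial learning rate of 1e-3 are assumptions for illustration. How often the agent calls the scheduler's step is not stated on this page.

```python
def step_lr_value(initial_lr: float, step: int, step_size: int, gamma: float) -> float:
    """Hypothetical helper: step-decay learning rate after `step` scheduler steps."""
    # the learning rate is multiplied by gamma once every step_size steps
    return initial_lr * gamma ** (step // step_size)


# values for step_size=1, gamma=0.9, assuming an initial learning rate of 1e-3
values = [step_lr_value(1e-3, n, step_size=1, gamma=0.9) for n in range(4)]
```

With step_size=1 the rate shrinks by 10% on every scheduler step. A larger step_size keeps it constant for that many steps between decays.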