# Learning rate schedulers

Learning rate schedulers are techniques that adjust the optimizer's learning rate during training to improve the agent's performance.
Learning rate schedulers | PyTorch | JAX
---|---|---
KL Adaptive | \(\blacksquare\) | \(\blacksquare\)
Implementation according to the ML framework:

- PyTorch: The implemented schedulers inherit from the PyTorch `_LRScheduler` class. Visit How to adjust learning rate in the PyTorch documentation for more details.
- JAX: The implemented schedulers must parameterize and return a function that maps step counts to values. Visit Schedules in the Optax documentation for more details.
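The Optax-style pattern described above can be sketched in plain Python: a factory takes the hyperparameters and returns a function that maps a step count to a learning-rate value. The sketch below mirrors the shape of `optax.linear_schedule`; the implementation itself is illustrative, not skrl or Optax code.

```python
def linear_schedule(init_value, end_value, transition_steps):
    """Illustrative Optax-style schedule factory (not the real optax.linear_schedule).

    Returns a function that maps a step count to a learning-rate value,
    interpolating linearly from init_value to end_value over transition_steps.
    """
    def schedule(step):
        # clamp the interpolation fraction to [0, 1]
        frac = min(max(step / transition_steps, 0.0), 1.0)
        return init_value + frac * (end_value - init_value)
    return schedule

# the returned function is what the JAX agent calls with the current step count
lr_fn = linear_schedule(1e-3, 1e-4, 100)
```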
## Usage

The learning rate scheduler usage is defined in each agent's configuration dictionary. The scheduler class is set under the "learning_rate_scheduler" key, and its arguments are set under the "learning_rate_scheduler_kwargs" key as a keyword-argument dictionary. The optimizer (the scheduler's first argument) must not be included, since the agent supplies it internally.
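Because the agent supplies the optimizer itself, only the remaining arguments go in the kwargs dictionary. The following is a minimal sketch of that instantiation pattern using hypothetical stand-in classes (not skrl code):

```python
class DummyOptimizer:
    """Stand-in for a framework optimizer (illustrative only)."""
    pass

class DummyScheduler:
    """Stand-in for a scheduler class such as torch.optim.lr_scheduler.StepLR."""
    def __init__(self, optimizer, step_size, gamma):
        self.optimizer = optimizer
        self.step_size = step_size
        self.gamma = gamma

cfg = {
    "learning_rate_scheduler": DummyScheduler,
    "learning_rate_scheduler_kwargs": {"step_size": 1, "gamma": 0.9},
}

# roughly what a PyTorch agent does internally: it prepends its own
# optimizer as the first positional argument, then unpacks the kwargs
optimizer = DummyOptimizer()
scheduler = cfg["learning_rate_scheduler"](optimizer, **cfg["learning_rate_scheduler_kwargs"])
```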
The following examples show how to set the scheduler for an agent:
```python
# PyTorch: standard scheduler from torch.optim.lr_scheduler
from torch.optim.lr_scheduler import StepLR

cfg = DEFAULT_CONFIG.copy()
cfg["learning_rate_scheduler"] = StepLR
cfg["learning_rate_scheduler_kwargs"] = {"step_size": 1, "gamma": 0.9}
```
```python
# PyTorch: KL-adaptive scheduler provided by skrl
from skrl.resources.schedulers.torch import KLAdaptiveLR

cfg = DEFAULT_CONFIG.copy()
cfg["learning_rate_scheduler"] = KLAdaptiveLR
cfg["learning_rate_scheduler_kwargs"] = {"kl_threshold": 0.01}
```
```python
# JAX: schedule function from Optax
from optax import constant_schedule

cfg = DEFAULT_CONFIG.copy()
cfg["learning_rate_scheduler"] = constant_schedule
cfg["learning_rate_scheduler_kwargs"] = {"value": 1e-4}
```
```python
# JAX: KL-adaptive scheduler provided by skrl
from skrl.resources.schedulers.jax import KLAdaptiveLR  # or kl_adaptive (Optax style)

cfg = DEFAULT_CONFIG.copy()
cfg["learning_rate_scheduler"] = KLAdaptiveLR
cfg["learning_rate_scheduler_kwargs"] = {"kl_threshold": 0.01}
```
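The KL-adaptive scheduler adjusts the learning rate from the KL divergence measured between consecutive policies during updates: the rate is reduced when the policy changes too much and increased when it changes too little. A standalone sketch of this kind of rule follows; the factor and bound values shown are illustrative defaults, not necessarily skrl's.

```python
def kl_adaptive_update(lr, kl, kl_threshold=0.01, kl_factor=2.0, lr_factor=1.5,
                       min_lr=1e-6, max_lr=1e-2):
    """Illustrative sketch of a KL-adaptive learning-rate rule (not skrl's KLAdaptiveLR)."""
    if kl > kl_threshold * kl_factor:
        # policy changed too much: shrink the learning rate (bounded below)
        lr = max(lr / lr_factor, min_lr)
    elif kl < kl_threshold / kl_factor:
        # policy changed too little: grow the learning rate (bounded above)
        lr = min(lr * lr_factor, max_lr)
    return lr
```

Called once per update with the measured KL divergence, this keeps the effective step size of the policy roughly tied to `kl_threshold`.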