Learning rate schedulers
Learning rate schedulers are techniques that adjust the learning rate over time to improve the performance of the agent.
Learning rate schedulers | PyTorch | JAX
---|---|---
KL Adaptive | \(\blacksquare\) | \(\blacksquare\)
Implementation according to the ML framework:

- PyTorch: The implemented schedulers inherit from the PyTorch _LRScheduler class. Visit How to adjust learning rate in the PyTorch documentation for more details.
- JAX: The implemented schedulers must parameterize and return a function that maps step counts to values (a minimal sketch of this contract follows below). Visit Optimizer Schedules in the Optax documentation for more details.
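For illustration, here is a minimal sketch of that contract. The factory name and hyperparameters are hypothetical, mirroring Optax's built-in schedules such as optax.linear_schedule:

# hypothetical factory following the Optax schedule contract:
# it captures the hyperparameters and returns a function that maps
# a step count to a learning-rate value
def my_linear_schedule(init_value: float, end_value: float, transition_steps: int):
    def schedule(count: int) -> float:
        frac = min(count / transition_steps, 1.0)
        return init_value + frac * (end_value - init_value)
    return schedule

lr_fn = my_linear_schedule(init_value=1e-3, end_value=1e-4, transition_steps=10_000)
lr_fn(0)       # 0.001
lr_fn(10_000)  # 0.0001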
Usage
The learning rate scheduler usage is defined in each agent's configuration dictionary: the scheduler class is set under the "learning_rate_scheduler" key, and its arguments are set under the "learning_rate_scheduler_kwargs" key as a keyword-argument dictionary. The optimizer (the scheduler's first argument) must not be specified, since the agent supplies its own.
The following examples show how to set the scheduler for an agent:
# PyTorch: import a standard PyTorch scheduler class (e.g. StepLR)
from torch.optim.lr_scheduler import StepLR

cfg = DEFAULT_CONFIG.copy()
cfg["learning_rate_scheduler"] = StepLR
cfg["learning_rate_scheduler_kwargs"] = {"step_size": 1, "gamma": 0.9}
# PyTorch: import skrl's KL Adaptive scheduler class
from skrl.resources.schedulers.torch import KLAdaptiveLR

cfg = DEFAULT_CONFIG.copy()
cfg["learning_rate_scheduler"] = KLAdaptiveLR
cfg["learning_rate_scheduler_kwargs"] = {"kl_threshold": 0.01}
# JAX: import an Optax schedule function (e.g. constant_schedule)
from optax import constant_schedule

cfg = DEFAULT_CONFIG.copy()
cfg["learning_rate_scheduler"] = constant_schedule
cfg["learning_rate_scheduler_kwargs"] = {"value": 1e-4}
# JAX: import skrl's KL Adaptive scheduler class
from skrl.resources.schedulers.jax import KLAdaptiveLR  # or kl_adaptive (Optax style)

cfg = DEFAULT_CONFIG.copy()
cfg["learning_rate_scheduler"] = KLAdaptiveLR
cfg["learning_rate_scheduler_kwargs"] = {"kl_threshold": 0.01}