Adam

An extension of the stochastic gradient descent algorithm that adapts the learning rate of each neural network parameter individually, based on running estimates of the first and second moments of that parameter's gradients.
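For intuition, the per-parameter adaptivity can be sketched from scratch in JAX. This is the textbook Adam rule with its usual default hyperparameters (b1=0.9, b2=0.999, eps=1e-8), not skrl's code. The helper names adam_init and adam_update are hypothetical.

```python
import jax
import jax.numpy as jnp


def adam_init(params):
    # hypothetical helper: first (m) and second (v) moment estimates start at zero
    zeros = jax.tree_util.tree_map(jnp.zeros_like, params)
    return {"m": zeros, "v": zeros, "t": 0}


def adam_update(grads, state, params, lr=1e-3, b1=0.9, b2=0.999, eps=1e-8):
    # hypothetical helper: one Adam step applied leaf by leaf over the parameter pytree
    t = state["t"] + 1
    # exponential moving averages of the gradient and of its element-wise square
    m = jax.tree_util.tree_map(lambda m_, g: b1 * m_ + (1 - b1) * g, state["m"], grads)
    v = jax.tree_util.tree_map(lambda v_, g: b2 * v_ + (1 - b2) * g**2, state["v"], grads)
    # bias correction compensates for initializing m and v at zero
    m_hat = jax.tree_util.tree_map(lambda m_: m_ / (1 - b1**t), m)
    v_hat = jax.tree_util.tree_map(lambda v_: v_ / (1 - b2**t), v)
    # per-parameter step: each gradient is normalized by its own running magnitude
    new_params = jax.tree_util.tree_map(
        lambda p, mh, vh: p - lr * mh / (jnp.sqrt(vh) + eps), params, m_hat, v_hat
    )
    return new_params, {"m": m, "v": v, "t": t}


# usage: two weights whose gradients differ in magnitude by a factor of 30
params = {"w": jnp.array([1.0, -2.0]), "b": jnp.array(0.5)}
grads = {"w": jnp.array([0.1, -3.0]), "b": jnp.array(0.0)}
state = adam_init(params)
params, state = adam_update(grads, state, params, lr=1e-3)
```

On the first step both weights move by about lr (to roughly 0.999 and -1.999) even though their gradients are 0.1 and -3.0, and the parameter with a zero gradient stays put. This is the per-parameter normalization that distinguishes Adam from plain SGD.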



Usage

Note

This class isolates the Optax optimizer that Flax’s TrainState class bundles together with the model parameters. It is not intended to be used directly by the user, but by agent implementations.

# import the optimizer class
from skrl.resources.optimizers.jax import Adam

# instantiate the optimizer
optimizer = Adam(model=model, lr=1e-3)

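# step the optimizer given the gradients (grad) computed for the model parameters
# (the step returns an updated optimizer, so the result must be reassigned)
optimizer = optimizer.step(grad, model)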

API (JAX)

class skrl.resources.optimizers.jax.adam.Adam(model: Model, lr: float = 0.001, grad_norm_clip: float = 0, scale: bool = True)

Bases: object

static __new__(cls, model: Model, lr: float = 0.001, grad_norm_clip: float = 0, scale: bool = True) → Optimizer

Adam optimizer

Adapted from Optax’s Adam to support a custom scale (learning rate).

Parameters:
  • model (skrl.models.jax.Model) – Model

  • lr (float, optional) – Learning rate (default: 1e-3)

  • grad_norm_clip (float, optional) – Clipping coefficient for the global norm of the gradients (default: 0). Clipping is disabled if the value is less than or equal to zero.

  • scale (bool, optional) – Whether to instantiate the optimizer as-is (True) or without its final scaling step (False) (default: True). Set it to False if a custom learning rate is to be applied during optimization steps.

Returns:

Adam optimizer

Return type:

flax.struct.PyTreeNode

Example:

>>> optimizer = Adam(model=policy, lr=5e-4)
>>> # step the optimizer given a computed gradient (grad)
>>> optimizer = optimizer.step(grad, policy)

>>> # apply a custom learning rate during optimization steps
>>> optimizer = Adam(model=policy, lr=5e-4, scale=False)
>>> # step the optimizer given a computed gradient and an updated learning rate (lr)
>>> optimizer = optimizer.step(grad, policy, lr)