Adam

An extension of the Stochastic Gradient Descent (SGD) algorithm that adaptively changes the learning rate for each neural network parameter.
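For reference, the standard Adam update keeps two exponential moving averages per parameter and divides the learning rate η by a running estimate of that parameter's gradient root mean square. As a result, each parameter effectively gets its own step size. The formulation below uses the conventional hyperparameter names β₁, β₂ and ε. These commonly default to 0.9, 0.999 and 1e-8, which are also Optax's defaults. They are not arguments of the skrl class documented on this page.

```latex
\begin{aligned}
m_t &= \beta_1 m_{t-1} + (1 - \beta_1)\, g_t \\
v_t &= \beta_2 v_{t-1} + (1 - \beta_2)\, g_t^2 \\
\hat{m}_t &= \frac{m_t}{1 - \beta_1^t}, \qquad \hat{v}_t = \frac{v_t}{1 - \beta_2^t} \\
\theta_t &= \theta_{t-1} - \eta\, \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon}
\end{aligned}
```

All operations are element-wise, where g_t is the gradient at step t and θ denotes the parameters. Because the step is normalized by √v̂_t, parameters with large and small gradients move by comparable amounts.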



Usage

These classes are intended to be used by agent implementations rather than directly by the user.

  • For JAX, the class separates the Optax optimizer (transformation and state) from the model parameters, which Flax’s TrainState class combines into a single object.
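The separation described above can be sketched without JAX. In this sketch, the optimizer object holds only its own state and never the parameters. Its `step` method writes the new parameters to the model and returns a new optimizer rather than mutating itself, which matches the `optimizer = optimizer.step(grad, policy)` calls in the examples. This is an illustration under stated assumptions, not skrl's code. `ToyModel` and `MomentumOptimizer` are hypothetical names. The update rule is plain SGD with momentum rather than Adam, to keep the focus on where the state lives.

```python
from dataclasses import dataclass, replace
from typing import List, Tuple


@dataclass
class ToyModel:
    """Hypothetical model stand-in: it owns the parameters."""
    params: List[float]


@dataclass(frozen=True)
class MomentumOptimizer:
    """Hypothetical immutable optimizer: it owns only its own state (SGD with momentum, not Adam)."""
    lr: float
    momentum: float
    velocity: Tuple[float, ...]

    @classmethod
    def create(cls, model: ToyModel, lr: float = 0.1, momentum: float = 0.9) -> "MomentumOptimizer":
        # the state is initialized from the parameters' shape but stored separately from them
        return cls(lr=lr, momentum=momentum, velocity=tuple(0.0 for _ in model.params))

    def step(self, grad: List[float], model: ToyModel) -> "MomentumOptimizer":
        velocity = tuple(self.momentum * v + g for v, g in zip(self.velocity, grad))
        # the new parameters are written back to the model...
        model.params = [p - self.lr * v for p, v in zip(model.params, velocity)]
        # ...while the updated optimizer state is returned as a new object
        return replace(self, velocity=velocity)


model = ToyModel(params=[1.0, -2.0])
optimizer = MomentumOptimizer.create(model)
for _ in range(2):
    optimizer = optimizer.step([0.5, -1.0], model)
print(model.params, optimizer.velocity)
```

Because the optimizer is frozen, forgetting to reassign the result of `step` would silently keep using the stale state. This is why the reassignment idiom matters.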


API


JAX

Adam

Adam optimizer.

class skrl.resources.optimizers.jax.adam.Adam(*, model: Model, lr: float = 0.001, grad_norm_clip: float = 0, scale: bool = True)

Bases: object

Adam optimizer.

Adapted from Optax’s Adam to support a custom scale (learning rate).

Parameters:
  • model – Model.

  • lr – Learning rate.

  • grad_norm_clip – Clipping coefficient for the norm of the gradients. Disabled if less than or equal to zero.

  • scale – Whether to instantiate the optimizer as-is (True) or without its scaling step (False). Set it to False when a custom learning rate is to be applied at each optimization step.

Returns:

Adam optimizer.
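The `grad_norm_clip` behavior can be sketched as follows. The sketch assumes that the norm is the global L2 norm over all gradients, as in Optax's `optax.clip_by_global_norm`. When the norm exceeds the coefficient, every gradient is rescaled by the same factor, so the update direction is preserved. A coefficient less than or equal to zero disables clipping. The function is hypothetical and operates on a flat list of floats for brevity.

```python
import math
from typing import List


def clip_by_global_norm(grads: List[float], grad_norm_clip: float) -> List[float]:
    """Rescale all gradients together so that their global L2 norm is at most grad_norm_clip.

    Clipping is disabled when grad_norm_clip <= 0, matching the parameter description.
    """
    if grad_norm_clip <= 0:
        return list(grads)
    global_norm = math.sqrt(sum(g * g for g in grads))
    if global_norm <= grad_norm_clip:
        return list(grads)
    # one shared factor for every gradient keeps the direction unchanged
    factor = grad_norm_clip / global_norm
    return [g * factor for g in grads]


print(clip_by_global_norm([3.0, 4.0], 1.0))  # global norm 5 is rescaled to 1
print(clip_by_global_norm([3.0, 4.0], 0))    # clipping disabled, gradients unchanged
```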

Example:

>>> optimizer = Adam(model=policy, lr=5e-4)
>>> # step the optimizer given a computed gradient (grad)
>>> optimizer = optimizer.step(grad, policy)

>>> # apply a custom learning rate during optimization steps
>>> optimizer = Adam(model=policy, lr=5e-4, scale=False)
>>> # step the optimizer given a computed gradient and an updated learning rate (lr)
>>> optimizer = optimizer.step(grad, policy, lr)
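The `scale=False` path can be sketched in pure Python. Consistent with the `scale` description, the sketch assumes that removing the scaling step leaves the optimizer producing Adam's bias-corrected direction m̂/(√v̂ + ε). That direction is then multiplied by the learning rate supplied at each step. `init_state`, `adam_direction` and the `linear_decay` schedule are hypothetical helpers, not skrl or Optax functions. The values β₁ = 0.9, β₂ = 0.999 and ε = 1e-8 are the usual Adam defaults.

```python
import math
from typing import List, Tuple

State = Tuple[int, List[float], List[float]]  # (step count, first moments, second moments)


def init_state(n: int) -> State:
    return 0, [0.0] * n, [0.0] * n


def adam_direction(grads: List[float], state: State,
                   b1: float = 0.9, b2: float = 0.999, eps: float = 1e-8) -> Tuple[List[float], State]:
    """Adam without the scaling step: returns m_hat / (sqrt(v_hat) + eps) and the new state."""
    t, m, v = state
    t += 1
    m = [b1 * mi + (1 - b1) * g for mi, g in zip(m, grads)]
    v = [b2 * vi + (1 - b2) * g * g for vi, g in zip(v, grads)]
    direction = [(mi / (1 - b1 ** t)) / (math.sqrt(vi / (1 - b2 ** t)) + eps)
                 for mi, vi in zip(m, v)]
    return direction, (t, m, v)


def linear_decay(step: int, lr0: float = 5e-4, total: int = 100) -> float:
    """Hypothetical schedule: decays lr0 linearly to zero over `total` steps."""
    return lr0 * max(0.0, 1.0 - step / total)


params = [1.0, 1.0]
state = init_state(len(params))
grads = [100.0, 0.01]  # gradients four orders of magnitude apart
for step in range(3):
    direction, state = adam_direction(grads, state)
    lr = linear_decay(step)  # a fresh learning rate every step...
    params = [p - lr * d for p, d in zip(params, direction)]  # ...applied outside the optimizer
print(params)
```

With constant gradients, the direction is about ±1 for every parameter regardless of gradient magnitude. Both parameters therefore move by roughly the sum of the scheduled learning rates, which is the per-parameter adaptivity described at the top of the page. Changing the learning rate does not require rebuilding the moment estimates.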
static __new__(cls, *, model: Model, lr: float = 0.001, grad_norm_clip: float = 0, scale: bool = True) → Optimizer

Adam optimizer.
