Running standard scaler

Standardize input features by removing the mean and scaling to unit variance.



Algorithm


Algorithm implementation

Main notation/symbols:
- mean (\(\bar{x}\)), standard deviation (\(\sigma\)), variance (\(\sigma^2\)) of the input batch
- running mean (\(\bar{x}_t\)), running variance (\(\sigma^2_t\))
- number of samples in the input batch (\(n\)), running sample count (\(n_t\))

Standardization by centering and scaling

\(\text{clip}((x - \bar{x}_t) / (\sqrt{\sigma^2_t} \;+\) epsilon \(), -c, c) \qquad\) with \(c\) as clip_threshold

Scale back the data to the original representation (inverse transform)

\(\sqrt{\sigma^2_t} \; \text{clip}(x, -c, c) + \bar{x}_t \qquad\) with \(c\) as clip_threshold
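
The two transforms above can be sketched for a single scalar feature as follows. This is a framework-free illustration, not the library implementation: the function names are hypothetical, the statistics are passed in explicitly, and the default values for epsilon and clip_threshold are taken from the class signatures documented further down.

```python
import math


def standardize(x, running_mean, running_variance, epsilon=1e-8, clip_threshold=5.0):
    """Center and scale a scalar with the running statistics, then clip it."""
    scaled = (x - running_mean) / (math.sqrt(running_variance) + epsilon)
    return max(-clip_threshold, min(clip_threshold, scaled))


def inverse_standardize(z, running_mean, running_variance, clip_threshold=5.0):
    """Clip a standardized scalar and scale it back to the original representation."""
    clipped = max(-clip_threshold, min(clip_threshold, z))
    return math.sqrt(running_variance) * clipped + running_mean


# a value within the clipping range survives the round trip (up to epsilon)
z = standardize(3.0, running_mean=1.0, running_variance=4.0)
x_back = inverse_standardize(z, running_mean=1.0, running_variance=4.0)

# an outlier is clipped, so the round trip does not recover it
z_outlier = standardize(100.0, running_mean=1.0, running_variance=4.0)
outlier_back = inverse_standardize(z_outlier, running_mean=1.0, running_variance=4.0)
```

With these numbers the outlier is clipped to 5.0 and maps back to 11.0 rather than 100.0, so the inverse transform is only exact for values that were not clipped.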

Update the running mean and variance (see the parallel algorithm for calculating variance)

\(\delta \leftarrow \bar{x} - \bar{x}_t\)
\(n_T \leftarrow n_t + n\)
\(M2 \leftarrow (\sigma^2_t n_t) + (\sigma^2 n) + \delta^2 \dfrac{n_t n}{n_T}\)
# update internal variables
\(\bar{x}_t \leftarrow \bar{x}_t + \delta \dfrac{n}{n_T}\)
\(\sigma^2_t \leftarrow \dfrac{M2}{n_T}\)
\(n_t \leftarrow n_T\)
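
The update can be checked numerically with a small sketch, with \(\delta\) taken as the difference between the batch mean and the running mean. The function name and the scalar, list-based representation are assumptions made for illustration, and the batch variance is computed as a population variance (dividing by \(n\)) so that the merged result matches the statistics of the concatenated data exactly; the library may compute the batch variance differently.

```python
import statistics


def merge_statistics(running_mean, running_variance, running_count, batch):
    """Merge running statistics with a new batch (parallel variance update)."""
    n = len(batch)
    batch_mean = statistics.fmean(batch)
    batch_variance = statistics.pvariance(batch)  # population variance (divides by n)
    delta = batch_mean - running_mean
    total_count = running_count + n
    m2 = (
        running_variance * running_count
        + batch_variance * n
        + delta**2 * running_count * n / total_count
    )
    # updated running mean, running variance and sample count
    return running_mean + delta * n / total_count, m2 / total_count, total_count


first = [1.0, 2.0, 3.0, 4.0]
second = [10.0, 20.0, 30.0]

# start from the statistics of the first batch, then merge the second one
mean_t, var_t, n_t = statistics.fmean(first), statistics.pvariance(first), len(first)
mean_t, var_t, n_t = merge_statistics(mean_t, var_t, n_t, second)
```

Merging batch by batch gives the same mean and variance as a single pass over all seven values, without storing the earlier samples.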

Usage

The preprocessor usage is defined in each agent’s configuration. The preprocessor class is set under the "<type>_preprocessor" key and its arguments are set under the "<type>_preprocessor_kwargs" key, as a Python dictionary.

The following example shows how to set the preprocessors for an agent:

# import the preprocessor class
from skrl.resources.preprocessors.torch import RunningStandardScaler

cfg = AGENT_CFG()
# ...
cfg.observation_preprocessor = RunningStandardScaler
cfg.observation_preprocessor_kwargs = {"size": env.observation_space, "device": device}
cfg.state_preprocessor = RunningStandardScaler
cfg.state_preprocessor_kwargs = {"size": env.state_space, "device": device}
cfg.value_preprocessor = RunningStandardScaler
cfg.value_preprocessor_kwargs = {"size": 1, "device": device}
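
One plausible reading of this class-plus-kwargs convention is that the agent later instantiates each configured class with its keyword arguments. The sketch below illustrates that pattern only. DummyScaler, DummyConfig and build_preprocessor are hypothetical stand-ins and are not part of skrl; how the agents actually consume the configuration is not described here.

```python
from dataclasses import dataclass, field
from typing import Any, Optional


class DummyScaler:
    """Stand-in for a preprocessor class (hypothetical, for illustration only)."""

    def __init__(self, size, device=None):
        self.size = size
        self.device = device


@dataclass
class DummyConfig:
    """Hypothetical configuration holding one "<type>_preprocessor" pair."""

    value_preprocessor: Optional[type] = None
    value_preprocessor_kwargs: dict[str, Any] = field(default_factory=dict)


def build_preprocessor(cfg, kind):
    """Instantiate the "<kind>_preprocessor" class with its kwargs, if one is set."""
    cls = getattr(cfg, f"{kind}_preprocessor", None)
    kwargs = getattr(cfg, f"{kind}_preprocessor_kwargs", None) or {}
    return cls(**kwargs) if cls is not None else None


cfg = DummyConfig()
cfg.value_preprocessor = DummyScaler
cfg.value_preprocessor_kwargs = {"size": 1, "device": "cpu"}

value_preprocessor = build_preprocessor(cfg, "value")  # DummyScaler(size=1, device="cpu")
no_preprocessor = build_preprocessor(DummyConfig(), "value")  # None: nothing configured
```

Passing the class rather than an instance lets each agent own a separate set of running statistics.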

API


PyTorch

RunningStandardScaler

Standardize the input data by removing the mean and scaling by the standard deviation.

class skrl.resources.preprocessors.torch.running_standard_scaler.RunningStandardScaler(*args: Any, **kwargs: Any)

Bases: Module

Standardize the input data by removing the mean and scaling by the standard deviation.

Parameters:
  • size – Size of the input space.

  • epsilon – Small number to avoid division by zero.

  • clip_threshold – Threshold to clip the data.

  • device – Data allocation and computation device. If not specified, the default device will be used.

Example:

>>> running_standard_scaler = RunningStandardScaler(size=2)
>>> data = torch.rand(3, 2)  # tensor of shape (N, 2)
>>> running_standard_scaler(data)
tensor([[0.1954, 0.3356],
        [0.9719, 0.4163],
        [0.8540, 0.1982]])

__call__(*args: Any, **kwargs: Any) → Any

Call self as a function.

Methods:

forward(x, *[, train, inverse, no_grad])

Forward pass of the standardizer.

forward(x: torch.Tensor | None, *, train: bool = False, inverse: bool = False, no_grad: bool = True) → torch.Tensor | None

Forward pass of the standardizer.

Parameters:
  • x – Input tensor.

  • train – Whether to update the running mean and variance with x before transforming it.

  • inverse – Whether to apply the inverse transform, scaling standardized data back to the original representation.

  • no_grad – Whether to disable the gradient computation.

Returns:

Standardized tensor.

Example:

>>> x = torch.rand(3, 2, device="cuda:0")
>>> running_standard_scaler(x)
tensor([[0.6933, 0.1905],
        [0.3806, 0.3162],
        [0.1140, 0.0272]], device='cuda:0')

>>> running_standard_scaler(x, train=True)
tensor([[ 0.8681, -0.6731],
        [ 0.0560, -0.3684],
        [-0.6360, -1.0690]], device='cuda:0')

>>> running_standard_scaler(x, inverse=True)
tensor([[0.6260, 0.5468],
        [0.5056, 0.5987],
        [0.4029, 0.4795]], device='cuda:0')
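
The three calls above differ only in their flags. The following single-feature sketch shows one way those flags could interact; it is an illustration under stated assumptions, not the library class. ScalarRunningScaler is a hypothetical name, it works on plain lists of floats, and its initial state (mean 0, variance 1, count 1) is assumed so that an untrained scaler is close to the identity, as in the first call of the example.

```python
import math
import statistics


class ScalarRunningScaler:
    """Illustrative single-feature scaler operating on lists of floats."""

    def __init__(self, epsilon=1e-8, clip_threshold=5.0):
        # assumed initial state: an untrained scaler is close to the identity
        self.mean, self.var, self.count = 0.0, 1.0, 1
        self.epsilon = epsilon
        self.clip_threshold = clip_threshold

    def _clip(self, value):
        return max(-self.clip_threshold, min(self.clip_threshold, value))

    def __call__(self, batch, train=False, inverse=False):
        if train:
            # update the running statistics with the batch before transforming it
            n = len(batch)
            delta = statistics.fmean(batch) - self.mean
            total = self.count + n
            m2 = (
                self.var * self.count
                + statistics.pvariance(batch) * n
                + delta**2 * self.count * n / total
            )
            self.mean, self.var, self.count = self.mean + delta * n / total, m2 / total, total
        std = math.sqrt(self.var)
        if inverse:
            # scale standardized values back to the original representation
            return [std * self._clip(v) + self.mean for v in batch]
        return [self._clip((v - self.mean) / (std + self.epsilon)) for v in batch]


scaler = ScalarRunningScaler()
x = [0.2, 0.4, 0.9]
untrained = scaler(x)  # statistics untouched, output is close to x
standardized = scaler(x, train=True)  # statistics updated, then x standardized
restored = scaler(standardized, inverse=True)  # close to x again
```

In this sketch a call without flags never changes the statistics, so it is the variant suited to evaluation, while train=True is the only path that modifies the internal state.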

JAX

RunningStandardScaler

Standardize the input data by removing the mean and scaling by the standard deviation.

class skrl.resources.preprocessors.jax.running_standard_scaler.RunningStandardScaler(size: int | list[int] | gymnasium.Space, *, epsilon: float = 1e-08, clip_threshold: float = 5.0, device: str | jax.Device | None = None)

Bases: object

Standardize the input data by removing the mean and scaling by the standard deviation.

Parameters:
  • size – Size of the input space.

  • epsilon – Small number to avoid division by zero.

  • clip_threshold – Threshold to clip the data.

  • device – Data allocation and computation device. If not specified, the default device will be used.

Example:

>>> running_standard_scaler = RunningStandardScaler(size=2)
>>> data = jax.random.uniform(jax.random.PRNGKey(0), (3, 2))  # array of shape (N, 2)
>>> running_standard_scaler(data)
Array([[0.57450044, 0.09968603],
       [0.7419659 , 0.8941783 ],
       [0.59656656, 0.45325184]], dtype=float32)

__call__(x: jax.Array | None, *, train: bool = False, inverse: bool = False) → jax.Array | None

Forward pass of the standardizer.

Parameters:
  • x – Input tensor.

  • train – Whether to update the running mean and variance with x before transforming it.

  • inverse – Whether to apply the inverse transform, scaling standardized data back to the original representation.

Returns:

Standardized tensor.

Example:

>>> x = jax.random.uniform(jax.random.PRNGKey(0), (3,2))
>>> running_standard_scaler(x)
Array([[0.57450044, 0.09968603],
       [0.7419659 , 0.8941783 ],
       [0.59656656, 0.45325184]], dtype=float32)

>>> running_standard_scaler(x, train=True)
Array([[ 0.167439  , -0.4292293 ],
       [ 0.45878986,  0.8719094 ],
       [ 0.20582889,  0.14980486]], dtype=float32)

>>> running_standard_scaler(x, inverse=True)
Array([[0.80847514, 0.4226486 ],
       [0.9047325 , 0.90777594],
       [0.8211585 , 0.6385405 ]], dtype=float32)

Attributes:

state_dict

Dictionary containing references to the whole state of the module.

property state_dict: dict[str, jax.Array]

Dictionary containing references to the whole state of the module.


Warp

RunningStandardScaler

Standardize the input data by removing the mean and scaling by the standard deviation.

class skrl.resources.preprocessors.warp.running_standard_scaler.RunningStandardScaler(size: int | list[int] | gymnasium.Space, *, epsilon: float = 1e-08, clip_threshold: float = 5.0, device: str | wp.Device | None = None)

Bases: object

Standardize the input data by removing the mean and scaling by the standard deviation.

Parameters:
  • size – Size of the input space.

  • epsilon – Small number to avoid division by zero.

  • clip_threshold – Threshold to clip the data.

  • device – Data allocation and computation device. If not specified, the default device will be used.

Example:

>>> running_standard_scaler = RunningStandardScaler(size=2)
>>> data = rand(3, 2)  # tensor of shape (N, 2)
>>> running_standard_scaler(data)
tensor([[0.1954, 0.3356],
        [0.9719, 0.4163],
        [0.8540, 0.1982]])

__call__(x: wp.array | None, *, train: bool = False, inverse: bool = False, inplace: bool = False) → wp.array | None

Forward pass of the standardizer.

Parameters:
  • x – Input tensor.

  • train – Whether to update the running mean and variance with x before transforming it.

  • inverse – Whether to apply the inverse transform, scaling standardized data back to the original representation.

  • inplace – Whether to perform the operation in-place.

Returns:

Standardized tensor.

Example:

>>> x = rand(3, 2, device="cuda:0")
>>> running_standard_scaler(x)
tensor([[0.6933, 0.1905],
        [0.3806, 0.3162],
        [0.1140, 0.0272]], device='cuda:0')

>>> running_standard_scaler(x, train=True)
tensor([[ 0.8681, -0.6731],
        [ 0.0560, -0.3684],
        [-0.6360, -1.0690]], device='cuda:0')

>>> running_standard_scaler(x, inverse=True)
tensor([[0.6260, 0.5468],
        [0.5056, 0.5987],
        [0.4029, 0.4795]], device='cuda:0')

Methods:

load_state_dict(state_dict)

state_dict()

Dictionary containing references to the whole state of the module.

load_state_dict(state_dict: dict[str, warp.array]) → None

state_dict() → dict[str, warp.array]

Dictionary containing references to the whole state of the module.