Running standard scaler
Standardize input features by removing the mean and scaling to unit variance.
Algorithm
Algorithm implementation
Standardization by centering and scaling
Scale back the data to the original representation (inverse transform)
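Both transforms act feature by feature, so they can be sketched in plain Python. This is a minimal sketch under assumptions, not skrl's code: epsilon is assumed to be added to the running standard deviation, values are clipped to [-clip_threshold, clip_threshold] with the defaults from the API signatures (1e-8 and 5.0), and `standardize` / `inverse_transform` are hypothetical helper names.

```python
import math

# Hypothetical helpers (not skrl's API); assumed formulation of both transforms


def _clip(value, threshold):
    # Clamp value to the closed interval [-threshold, threshold]
    return max(-threshold, min(threshold, value))


def standardize(x, running_mean, running_var, epsilon=1e-8, clip_threshold=5.0):
    """Center and scale each feature, then clip (epsilon added to the std: an assumption)."""
    return [
        _clip((xi - mu) / (math.sqrt(var) + epsilon), clip_threshold)
        for xi, mu, var in zip(x, running_mean, running_var)
    ]


def inverse_transform(z, running_mean, running_var, clip_threshold=5.0):
    """Scale standardized features back to the original representation."""
    return [
        math.sqrt(var) * _clip(zi, clip_threshold) + mu
        for zi, mu, var in zip(z, running_mean, running_var)
    ]


# One sample with two features, and per-feature running statistics
mean, var = [1.0, -2.0], [4.0, 0.25]
z = standardize([3.0, -1.0], mean, var)  # approximately [1.0, 2.0]
x = inverse_transform(z, mean, var)  # approximately [3.0, -1.0]
```

The round trip is only approximate: epsilon slightly shrinks the standardized values, and anything clipped in the forward direction cannot be recovered.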
Update the running mean and variance (see parallel algorithm)
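The parallel algorithm merges the statistics of a new batch into the running ones without revisiting earlier data. A minimal sketch, assuming per-feature lists and population (not sample) variances, so that merging batches yields exactly the statistics of the concatenated data; `RunningMeanVar` is a hypothetical name, and skrl's own update may differ in details such as the variance estimator and the initial count.

```python
from statistics import fmean, pvariance


class RunningMeanVar:
    """Hypothetical helper (not skrl's API) tracking per-feature running mean and variance.

    Batches are merged with the parallel update; population variances are an
    assumption made here so that the merge is exact.
    """

    def __init__(self, size):
        self.mean = [0.0] * size
        self.var = [1.0] * size
        self.count = 0  # number of samples merged so far

    def update(self, batch):
        """Merge a batch (a list of samples, each a list of features)."""
        n = len(batch)
        if n == 0:
            return
        total = self.count + n
        for i, column in enumerate(zip(*batch)):
            batch_mean = fmean(column)
            batch_var = pvariance(column, batch_mean)
            delta = batch_mean - self.mean[i]
            # Combine the sums of squared deviations (M2) of both groups
            m2 = (self.var[i] * self.count + batch_var * n
                  + delta ** 2 * self.count * n / total)
            self.mean[i] += delta * n / total
            self.var[i] = m2 / total
        self.count = total


stats = RunningMeanVar(size=2)
stats.update([[1.0, 2.0], [3.0, 4.0]])
stats.update([[5.0, 6.0], [7.0, 8.0], [9.0, 10.0]])
# Matches the statistics of all five samples: mean ~ [5.0, 6.0], var ~ [8.0, 8.0]
```

Starting from a count of zero makes the first batch overwrite the initial mean and variance; those initial values (0 and 1) only matter if data is standardized before any update.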
Usage
Preprocessors are set in each agent’s configuration.
The preprocessor class is assigned to the "<type>_preprocessor" attribute and its arguments, as a Python dictionary, to the "<type>_preprocessor_kwargs" attribute, where <type> is, for example, observation, state, or value.
The following examples show how to set the preprocessors for an agent in each supported framework (PyTorch, JAX, and Warp):
# PyTorch: import the preprocessor class
from skrl.resources.preprocessors.torch import RunningStandardScaler
# AGENT_CFG stands for the agent's configuration class
cfg = AGENT_CFG()
# ...
cfg.observation_preprocessor = RunningStandardScaler
cfg.observation_preprocessor_kwargs = {"size": env.observation_space, "device": device}
cfg.state_preprocessor = RunningStandardScaler
cfg.state_preprocessor_kwargs = {"size": env.state_space, "device": device}
cfg.value_preprocessor = RunningStandardScaler
cfg.value_preprocessor_kwargs = {"size": 1, "device": device}
# JAX: import the preprocessor class
from skrl.resources.preprocessors.jax import RunningStandardScaler
cfg = AGENT_CFG()
# ...
cfg.observation_preprocessor = RunningStandardScaler
cfg.observation_preprocessor_kwargs = {"size": env.observation_space, "device": device}
cfg.state_preprocessor = RunningStandardScaler
cfg.state_preprocessor_kwargs = {"size": env.state_space, "device": device}
cfg.value_preprocessor = RunningStandardScaler
cfg.value_preprocessor_kwargs = {"size": 1, "device": device}
# Warp: import the preprocessor class
from skrl.resources.preprocessors.warp import RunningStandardScaler
cfg = AGENT_CFG()
# ...
cfg.observation_preprocessor = RunningStandardScaler
cfg.observation_preprocessor_kwargs = {"size": env.observation_space, "device": device}
cfg.state_preprocessor = RunningStandardScaler
cfg.state_preprocessor_kwargs = {"size": env.state_space, "device": device}
cfg.value_preprocessor = RunningStandardScaler
cfg.value_preprocessor_kwargs = {"size": 1, "device": device}
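Keeping the class and its keyword arguments in separate attributes lets the agent construct each preprocessor itself once the configuration is complete. The sketch below illustrates that pattern with a hypothetical `AgentConfig` dataclass, a stand-in `DummyScaler` class, and a hypothetical `build_preprocessor` function; none of them is part of skrl, and the real agents may build their preprocessors differently.

```python
from dataclasses import dataclass, field
from typing import Optional


class DummyScaler:
    """Stand-in for a preprocessor class such as RunningStandardScaler (hypothetical)."""

    def __init__(self, size, device=None):
        self.size = size
        self.device = device


@dataclass
class AgentConfig:
    """Hypothetical config holding "<type>_preprocessor" / "<type>_preprocessor_kwargs" pairs."""

    observation_preprocessor: Optional[type] = None
    observation_preprocessor_kwargs: dict = field(default_factory=dict)
    value_preprocessor: Optional[type] = None
    value_preprocessor_kwargs: dict = field(default_factory=dict)


def build_preprocessor(cfg, kind):
    """Instantiate cfg.<kind>_preprocessor with its kwargs, or return None if unset."""
    cls = getattr(cfg, f"{kind}_preprocessor")
    if cls is None:
        return None
    return cls(**getattr(cfg, f"{kind}_preprocessor_kwargs"))


cfg = AgentConfig()
cfg.value_preprocessor = DummyScaler
cfg.value_preprocessor_kwargs = {"size": 1, "device": "cpu"}
value_scaler = build_preprocessor(cfg, "value")  # a DummyScaler instance
observation_scaler = build_preprocessor(cfg, "observation")  # None: not configured
```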
API
PyTorch
- class skrl.resources.preprocessors.torch.running_standard_scaler.RunningStandardScaler(*args: Any, **kwargs: Any)
Bases: Module
Standardize the input data by removing the mean and scaling by the standard deviation.
- Parameters:
size – Size of the input space.
epsilon – Small number to avoid division by zero.
clip_threshold – Threshold to clip the data (values are clipped to the range [-clip_threshold, clip_threshold]).
device – Data allocation and computation device. If not specified, the default device will be used.
Example:
>>> running_standard_scaler = RunningStandardScaler(size=2)
>>> data = torch.rand(3, 2)  # tensor of shape (N, 2)
>>> running_standard_scaler(data)
tensor([[0.1954, 0.3356],
        [0.9719, 0.4163],
        [0.8540, 0.1982]])
Methods:
forward(x, *[, train, inverse, no_grad]) – Forward pass of the standardizer.
- forward(x: torch.Tensor | None, *, train: bool = False, inverse: bool = False, no_grad: bool = True) → torch.Tensor | None
Forward pass of the standardizer.
- Parameters:
x – Input tensor.
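train – Whether to update the running mean and variance with the input data.
inverse – Whether to apply the inverse transform, scaling the data back to its original representation.
no_grad – Whether to disable the gradient computation.
- Returns:
Standardized tensor, or the data scaled back to its original representation if inverse is True.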
Example:
>>> x = torch.rand(3, 2, device="cuda:0")
>>> running_standard_scaler(x)
tensor([[0.6933, 0.1905],
        [0.3806, 0.3162],
        [0.1140, 0.0272]], device='cuda:0')
>>> running_standard_scaler(x, train=True)
tensor([[ 0.8681, -0.6731],
        [ 0.0560, -0.3684],
        [-0.6360, -1.0690]], device='cuda:0')
>>> running_standard_scaler(x, inverse=True)
tensor([[0.6260, 0.5468],
        [0.5056, 0.5987],
        [0.4029, 0.4795]], device='cuda:0')
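The interplay of the train and inverse flags can be mimicked with a small framework-agnostic stand-in. `ToyRunningStandardScaler` is hypothetical, not skrl's implementation: it assumes that train=True merges the batch into the running statistics before transforming it, that inverse=True scales data back instead of standardizing, and it reuses assumed formulas (epsilon added to the standard deviation, clipping to [-clip_threshold, clip_threshold], a population-variance parallel update).

```python
import math
from statistics import fmean, pvariance


class ToyRunningStandardScaler:
    """Hypothetical stand-in mirroring the documented keyword-only flags (not skrl's code)."""

    def __init__(self, size, epsilon=1e-8, clip_threshold=5.0):
        self.epsilon, self.clip = epsilon, clip_threshold
        self.mean, self.var, self.count = [0.0] * size, [1.0] * size, 0

    def _update(self, batch):
        # Parallel merge of the batch statistics into the running ones
        n = len(batch)
        total = self.count + n
        for i, col in enumerate(zip(*batch)):
            m, v = fmean(col), pvariance(col)
            delta = m - self.mean[i]
            m2 = self.var[i] * self.count + v * n + delta ** 2 * self.count * n / total
            self.mean[i] += delta * n / total
            self.var[i] = m2 / total
        self.count = total

    def __call__(self, batch, *, train=False, inverse=False):
        if train:
            self._update(batch)  # assumption: statistics are updated before transforming
        out = []
        for sample in batch:
            row = []
            for xi, mu, var in zip(sample, self.mean, self.var):
                std = math.sqrt(var)
                if inverse:
                    row.append(std * max(-self.clip, min(self.clip, xi)) + mu)
                else:
                    row.append(max(-self.clip, min(self.clip, (xi - mu) / (std + self.epsilon))))
            out.append(row)
        return out


scaler = ToyRunningStandardScaler(size=2)
batch = [[1.0, 10.0], [2.0, 20.0], [3.0, 30.0]]
z = scaler(batch, train=True)  # statistics updated, then the batch is standardized
restored = scaler(z, inverse=True)  # approximately equal to batch
```

Calls without train=True leave the running statistics untouched, which is what makes the inverse call above use the same mean and variance as the forward call.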
JAX
- class skrl.resources.preprocessors.jax.running_standard_scaler.RunningStandardScaler(size: int | list[int] | gymnasium.Space, *, epsilon: float = 1e-08, clip_threshold: float = 5.0, device: str | jax.Device | None = None)
Bases: object
Standardize the input data by removing the mean and scaling by the standard deviation.
- Parameters:
size – Size of the input space.
epsilon – Small number to avoid division by zero.
clip_threshold – Threshold to clip the data (values are clipped to the range [-clip_threshold, clip_threshold]).
device – Data allocation and computation device. If not specified, the default device will be used.
Example:
>>> running_standard_scaler = RunningStandardScaler(size=2)
>>> data = jax.random.uniform(jax.random.PRNGKey(0), (3, 2))  # array of shape (N, 2)
>>> running_standard_scaler(data)
Array([[0.57450044, 0.09968603],
       [0.7419659 , 0.8941783 ],
       [0.59656656, 0.45325184]], dtype=float32)
- __call__(x: jax.Array | None, *, train: bool = False, inverse: bool = False) → jax.Array | None
Forward pass of the standardizer.
- Parameters:
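x – Input array.
train – Whether to update the running mean and variance with the input data.
inverse – Whether to apply the inverse transform, scaling the data back to its original representation.
- Returns:
Standardized array, or the data scaled back to its original representation if inverse is True.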
Example:
>>> x = jax.random.uniform(jax.random.PRNGKey(0), (3, 2))
>>> running_standard_scaler(x)
Array([[0.57450044, 0.09968603],
       [0.7419659 , 0.8941783 ],
       [0.59656656, 0.45325184]], dtype=float32)
>>> running_standard_scaler(x, train=True)
Array([[ 0.167439  , -0.4292293 ],
       [ 0.45878986,  0.8719094 ],
       [ 0.20582889,  0.14980486]], dtype=float32)
>>> running_standard_scaler(x, inverse=True)
Array([[0.80847514, 0.4226486 ],
       [0.9047325 , 0.90777594],
       [0.8211585 , 0.6385405 ]], dtype=float32)
Attributes:
Dictionary containing references to the whole state of the module.
Warp
- class skrl.resources.preprocessors.warp.running_standard_scaler.RunningStandardScaler(size: int | list[int] | gymnasium.Space, *, epsilon: float = 1e-08, clip_threshold: float = 5.0, device: str | wp.Device | None = None)
Bases: object
Standardize the input data by removing the mean and scaling by the standard deviation.
- Parameters:
size – Size of the input space.
epsilon – Small number to avoid division by zero.
clip_threshold – Threshold to clip the data (values are clipped to the range [-clip_threshold, clip_threshold]).
device – Data allocation and computation device. If not specified, the default device will be used.
Example:
>>> running_standard_scaler = RunningStandardScaler(size=2)
>>> data = wp.array(np.random.rand(3, 2), dtype=wp.float32)  # array of shape (N, 2)
>>> running_standard_scaler(data)  # standardized array of shape (N, 2)
- __call__(x: wp.array | None, *, train: bool = False, inverse: bool = False, inplace: bool = False) → wp.array | None
Forward pass of the standardizer.
- Parameters:
x – Input array.
train – Whether to update the running mean and variance with the input data.
inverse – Whether to apply the inverse transform, scaling the data back to its original representation.
inplace – Whether to perform the operation in-place.
- Returns:
Standardized array, or the data scaled back to its original representation if inverse is True.
Example:
>>> x = wp.array(np.random.rand(3, 2), dtype=wp.float32, device="cuda:0")
>>> running_standard_scaler(x)  # standardized with the current running statistics
>>> running_standard_scaler(x, train=True)  # running statistics updated with x
>>> running_standard_scaler(x, inverse=True)  # scaled back to the original representation
Methods:
load_state_dict(state_dict) – Dictionary containing references to the whole state of the module.