Source code for common.utils.torch
""":mod:`torch` utilities."""
import torch
from jaxtyping import Float32
from torch import Tensor
[docs]
class RunningStandardization:
    """Standardizes a stream of tensors with running per-element statistics.

    Every call folds the incoming sample into a running mean and a running
    sum of squared deviations (Welford's online algorithm), then returns the
    sample standardized with the *updated* mean and population standard
    deviation.

    Args:
        x_size: Size of the input tensor.
    """

    def __init__(self: "RunningStandardization", x_size: int) -> None:
        # Running per-element mean of every sample seen so far.
        self.mean: Float32[Tensor, " x_size"] = torch.zeros(size=(x_size,))
        # NOTE: despite the name, this is Welford's M2 accumulator -- the
        # running *sum* of squared deviations from the mean, not the
        # variance. The population variance is ``var / n``.
        self.var: Float32[Tensor, " x_size"] = torch.zeros(size=(x_size,))
        # Running per-element population standard deviation, sqrt(var / n).
        self.std: Float32[Tensor, " x_size"] = torch.zeros(size=(x_size,))
        # Number of samples seen so far, kept as a float tensor so it can
        # divide the statistics directly.
        self.n: Float32[Tensor, " 1"] = torch.zeros(size=(1,))

    def __call__(
        self: "RunningStandardization",
        x: Float32[Tensor, " x_size"],
    ) -> Float32[Tensor, " x_size"]:
        """Update the running statistics with ``x`` and return it standardized.

        The statistics are updated from a detached copy of ``x``. Gradients
        of the result therefore flow through ``x`` only, and the stored
        statistics never hold on to an autograd graph. The running state is
        moved to ``x``'s device on demand.

        Elements whose running standard deviation is zero (e.g. on the very
        first call, when ``n == 1``) are divided by 1 instead of by zero.

        Args:
            x: Input tensor.

        Returns:
            ``(x - mean) / std`` computed with the statistics that already
            include ``x``.
        """
        # The statistics are bookkeeping, not part of the computation graph.
        # Without detaching, every update would chain onto all earlier
        # samples' graphs. That grows memory without bound, and a later
        # ``backward()`` would traverse buffers already freed by an earlier one.
        sample = x.detach()

        # Follow the input's device so, e.g., CUDA inputs do not clash with the
        # tensors created in ``__init__``. ``.to`` returns the same tensor
        # when it is already on that device, so CPU inputs are unaffected.
        self.mean = self.mean.to(sample.device)
        self.var = self.var.to(sample.device)
        self.n = self.n.to(sample.device)

        self.n += 1
        delta: Float32[Tensor, " x_size"] = sample - self.mean
        new_mean: Float32[Tensor, " x_size"] = self.mean + delta / self.n
        # Welford update: (x - old_mean) * (x - new_mean) is the increment
        # of the sum of squared deviations.
        new_var: Float32[Tensor, " x_size"] = self.var + delta * (
            sample - new_mean
        )
        new_std: Float32[Tensor, " x_size"] = torch.sqrt(new_var / self.n)
        self.mean, self.var, self.std = new_mean, new_var, new_std

        # ``std.eq(0)`` adds exactly 1 where std == 0, avoiding division by 0.
        standardized_x: Float32[Tensor, " x_size"] = (x - self.mean) / (
            self.std + self.std.eq(0)
        )
        return standardized_x