Source code for genrl.environments.vec_env.utils

from typing import Tuple

import torch


[docs]class RunningMeanStd: """ Utility Function to compute a running mean and variance calculator :param epsilon: Small number to prevent division by zero for calculations :param shape: Shape of the RMS object :type epsilon: float :type shape: Tuple """ def __init__(self, epsilon: float = 1e-4, shape: Tuple = ()): self.mean = torch.zeros(shape).double() self.var = torch.ones(shape).double() self.count = epsilon
[docs] def update(self, batch: torch.Tensor): batch_mean = torch.mean(batch, axis=0) batch_var = torch.var(batch, axis=0) batch_count = batch.shape[0] total_count = self.count + batch_count delta = batch_mean - self.mean new_mean = self.mean + delta * batch_count / total_count M2 = ( self.var * self.count + batch_var * batch_count + (delta ** 2) * self.count * batch_count / total_count ) self.mean = new_mean self.var = M2 / (total_count - 1) self.count = total_count