Source code for running_stats

"""Computation of mean and std from online streams of batches

"""


[docs]class RunningStats: def __init__(self, loader_dict): """Computation of mean and std from online streams of batches Parameters ---------- loader_dict : dict dict of callable functions that can get the desired data from each batch """ self.loader_dict = loader_dict stats = dict() for k, _ in self.loader_dict.items(): stats[f'{k}_mean'] = 0.0 stats[f'{k}_var'] = 0.0 self.stats = stats
[docs] def update(self, batch, i): """Update `stats` for a new batch Parameters ---------- batch : array or dict new batch of data whose data can be accessed by the functions in `loader_dict` i : int index indicating that the batch is the i-th batch """ for k, func in self.loader_dict.items(): new = func(batch) new_mean = new.mean(dim=0, keepdim=True) new_var = new.var(dim=0, unbiased=False, keepdim=True) self.stats[f'{k}_var'] += (new_var - self.stats[f'{k}_var'])/(i+1) self.stats[f'{k}_var'] += (i/(i+1)**2.0)*(self.stats[f'{k}_mean'] - new_mean)**2.0 self.stats[f'{k}_mean'] += (new_mean - self.stats[f'{k}_mean'])/(i+1)