Source code for mice.state

from __future__ import annotations

from dataclasses import dataclass
from typing import Any, Callable, Optional

import numpy as np


[docs] @dataclass(slots=True) class WelfordVec: """ Online mean / M2 accumulator for vectors. We store M2 per-coordinate (same shape as mean) so that sum(var_i) == M2.sum()/(n-1) matches the manuscript's use of V_{l,k} = sum_i Var(Delta^{(i)}_{l,k}). """ mean: np.ndarray m2: np.ndarray n: int = 0
[docs] @classmethod def zeros(cls, dim: int, dtype=np.float64) -> "WelfordVec": mean = np.zeros(dim, dtype=dtype) m2 = np.zeros(dim, dtype=dtype) return cls(mean=mean, m2=m2, n=0)
[docs] def update_batch(self, x: np.ndarray) -> None: """ Update with a batch x of shape (m, d). """ if x.size == 0: return if x.ndim != 2: raise ValueError(f"Expected x to have ndim=2, got shape {x.shape}") m = x.shape[0] if m == 0: return # Batch Welford update (Chan et al. style) batch_mean = x.mean(axis=0) batch_m2 = ((x - batch_mean) ** 2).sum(axis=0) if self.n == 0: self.mean[...] = batch_mean self.m2[...] = batch_m2 self.n = m return n_a = self.n n_b = m delta = batch_mean - self.mean n = n_a + n_b self.mean[...] = self.mean + delta * (n_b / n) self.m2[...] = self.m2 + batch_m2 + (delta**2) * (n_a * n_b / n) self.n = n
@property def var_sum(self) -> float: if self.n <= 1: return float("inf") return float(self.m2.sum() / (self.n - 1))
[docs] @dataclass(slots=True) class ResamplingAcc: """ Maintain partitioned sums/counts to compute leave-one-partition-out means efficiently (O(re_part * d) to materialize all LOO means). """ re_part: int sum_total: np.ndarray # shape (d,) cnt_total: int sum_part: np.ndarray # shape (re_part, d) cnt_part: np.ndarray # shape (re_part,)
[docs] @classmethod def zeros(cls, re_part: int, dim: int, dtype=np.float64) -> "ResamplingAcc": return cls( re_part=re_part, sum_total=np.zeros(dim, dtype=dtype), cnt_total=0, sum_part=np.zeros((re_part, dim), dtype=dtype), cnt_part=np.zeros(re_part, dtype=np.int64), )
[docs] def update_batch(self, x: np.ndarray) -> None: """ x: shape (m, d) """ if x.size == 0: return m = x.shape[0] start = self.cnt_total idxs = (np.arange(start, start + m) % self.re_part).astype(np.int64) self.sum_total += x.sum(axis=0) self.cnt_total += m # Update per-partition sums/counts for p in range(self.re_part): mask = idxs == p c = int(mask.sum()) if c: self.cnt_part[p] += c self.sum_part[p] += x[mask].sum(axis=0)
[docs] def loo_means(self) -> np.ndarray: """ Returns array of shape (re_part, d), where row p is the mean excluding samples that fell into partition p. """ d = self.sum_total.shape[0] out = np.empty((self.re_part, d), dtype=self.sum_total.dtype) for p in range(self.re_part): denom = self.cnt_total - int(self.cnt_part[p]) if denom <= 0: # Fallback: if we can't exclude, use full mean out[p] = self.sum_total / max(self.cnt_total, 1) else: out[p] = (self.sum_total - self.sum_part[p]) / denom return out
[docs] @dataclass(slots=True) class LevelState: """ Statistics for a single level (either base gradient or a gradient difference). - base level: Delta = grad(x_l) - diff level: Delta = grad(x_l) - grad(x_prev) """ x: np.ndarray cost: int # 1 for base, 2 for diff x_prev: Optional[np.ndarray] # None for base sample_fn: Callable[[int], Any] delta_stats: WelfordVec base_stats: Optional[WelfordVec] # only for diff levels (for restart/clipping heuristics) m_min: int # Optional resampling accumulators (enabled when resampling is enabled) delta_resamp: Optional[ResamplingAcc] = None base_resamp: Optional[ResamplingAcc] = None # Bookkeeping for bias/stat error decomposition (matches current implementation style) m_prev: int = 0 @property def m(self) -> int: return self.delta_stats.n @property def v_delta(self) -> float: return self.delta_stats.var_sum @property def v_base(self) -> float: if self.base_stats is None: return self.delta_stats.var_sum return self.base_stats.var_sum @property def v_batch(self) -> float: """ Variance proxy used in sample-size optimization: - base level: variance of grad(x_0) - diff level: variance of grad(x_l) (the "top" gradient) """ return self.v_base @property def mean_delta(self) -> np.ndarray: return self.delta_stats.mean
[docs] def delta_loo_means(self) -> Optional[np.ndarray]: if self.delta_resamp is None: return None return self.delta_resamp.loo_means()