from __future__ import annotations
from dataclasses import dataclass, field
from typing import Any, Callable, List, Optional, Tuple
import numpy as np
from .logging import Recorder
from .norms import PlainNormEstimator, ResamplingNormEstimator
from .policy import DropRestartClipPolicy
from .sampling import FiniteSampler, SamplerLike
from .state import LevelState, ResamplingAcc, WelfordVec
GradFn = Callable[[np.ndarray, Any], np.ndarray]
# (Sphinx "[docs]" source-link residue from page extraction; not part of the code.)
@dataclass
class MICE:
    """
    Multi-Iteration stochastiC Estimator for stochastic gradients.

    This class maintains a hierarchy of control variates and adaptively allocates
    samples across levels to satisfy an error tolerance while reducing gradient
    evaluation cost.

    Instances are callable: ``estimator(x)`` forwards to ``estimator.evaluate(x)``.
    """

    # Gradient oracle, called as ``grad(x, thetas)``. For a batch of ``m`` samples
    # the pilot step checks that the result has shape ``(m, dim)``.
    grad: GradFn
    # Sample source. A callable is used directly as the per-level sample function
    # (it is called with a batch size); anything else is treated as a finite
    # dataset: ``len()`` is taken and it is wrapped in a ``FiniteSampler``.
    sampler: SamplerLike
    # Relative tolerance: the error tolerance is ``eps`` times a gradient-norm estimate.
    eps: float = 0.577
    # Minimum batch size of every level after the first one (also its pilot size).
    min_batch: int = 10
    # The first level's minimum batch is ``restart_factor * min_batch`` (capped at
    # the dataset size for finite samplers).
    restart_factor: int = 10
    # Budget on the gradient-evaluation counter; a request that would exceed it
    # sets ``terminate`` with reason "max_cost".
    max_cost: float = float("inf")
    # Threshold of the stopping rule; the rule is skipped when this is <= 0.
    stop_crit_norm: float = 0.0
    # Forwarded to the resampling norm estimator as ``stop_quantile``.
    stop_crit_prob: float = 0.05
    # Forwarded to the norm estimator (plain or resampling).
    convex: bool = False
    # Drop / restart / clip policy consulted on every ``evaluate`` call.
    policy: DropRestartClipPolicy = field(default_factory=DropRestartClipPolicy)
    # Optional log recorder; a fresh ``Recorder`` is created when left as None.
    recorder: Optional[Recorder] = None
    # - resampling ON by default
    # - clipping OFF by default
    use_resampling: bool = True
    # Resampling controls
    # ``re_part`` is the partition count handed to ``ResamplingAcc.zeros`` and to the
    # resampling norm estimator; ``re_quantile`` is forwarded to that estimator.
    re_part: int = 5
    re_quantile: float = 0.05
    # Default resampling controls used by the current implementation.
    # ``re_tot_cost`` scales the number of resampled norms relative to the planned
    # sampling cost; ``re_min_n`` / ``re_max_samp`` bound that number from below / above.
    re_tot_cost: float = 0.2
    re_min_n: int = 5
    re_max_samp: int = 1000
    # Safety: cap per-call gradient batch materialization size to avoid OOM when
    # theoretically-optimal sample sizes become extremely large.
    #
    # This caps the number of float entries in the (m, dim) arrays returned by
    # `grad(...)` during a single sampling-growth update.
    max_grad_batch_elems: int = 5_000_000
def __post_init__(self) -> None:
    """
    Initialise the mutable estimator state that is not a dataclass field.

    Sets up the random generator, the counters and termination flags, the
    minimum batch size of the first level, the norm estimator that matches
    ``use_resampling``, the cached aggregate, and a default ``Recorder`` when
    none was supplied.
    """
    self.rng = np.random.default_rng()

    # A non-callable sampler is treated as a finite dataset of known length.
    is_finite = not callable(self.sampler)
    self.finite = is_finite
    if is_finite:
        self.data_size = len(self.sampler)
    else:
        self.data_size = None

    self.levels: List[LevelState] = []
    self.dim: Optional[int] = None
    self.counter = 0  # gradient evaluations
    self.k = 0  # estimator calls
    self.terminate = False
    self.terminate_reason: Optional[str] = None
    self.force_restart = False
    # For finite samplers a single FiniteSampler instance is shared by all levels
    # so that its state persists; it is created lazily on first use.
    self._finite_sampler: Optional[FiniteSampler] = None

    # The first level never starts below restart_factor * min_batch samples,
    # but cannot ask for more samples than a finite dataset holds.
    restart_min = self.restart_factor * self.min_batch
    if is_finite:
        restart_min = int(min(restart_min, self.data_size))
    self.m_restart_min = restart_min

    # Select the norm estimator according to the resampling setting.
    if self.use_resampling:
        self.norm_estimator = ResamplingNormEstimator(
            re_part=self.re_part,
            re_quantile=self.re_quantile,
            stop_quantile=self.stop_crit_prob,
            convex=self.convex,
        )
        # err_tol is read inside the resampling routine (to estimate the cost of
        # an iteration) before it is first updated, so it starts at a small
        # positive value.
        starting_tol = 1e-6
    else:
        self.norm_estimator = PlainNormEstimator(convex=self.convex)
        starting_tol = 0.0
    self._norm_stop: Optional[float] = None
    self.err_tol: float = starting_tol
    self.re_cost: float = 1.0

    # Cached aggregate estimator: g_hat = sum(mean_delta_l)
    self._g_hat: Optional[np.ndarray] = None
    # Samples / gradients of the most recent pilot batch (gradients: shape (m, d)).
    self._last_pilot_thetas: Any = None
    self._last_pilot_g_cur: Optional[np.ndarray] = None

    if self.recorder is None:
        self.recorder = Recorder()
def __call__(self, x: np.ndarray) -> np.ndarray:
return self.evaluate(x)
# --- Public API ---
# (Sphinx "[docs]" source-link residue from page extraction; not part of the code.)
def evaluate(self, x: np.ndarray) -> np.ndarray:
    """
    Evaluate a MICE gradient estimate at ``x`` and update internal state.

    A new level is appended for ``x`` and pilot samples are drawn. The error
    tolerance and the per-level sample sizes are then computed, the
    drop / restart / clip policy is applied, and the levels are grown until the
    requested sizes are met or the estimator terminates. Every call that is not
    short-circuited by an earlier termination logs exactly one record.

    Parameters
    ----------
    x : np.ndarray
        Current iterate. It is flattened to one dimension and its size must
        match the size seen on the first call.

    Returns
    -------
    np.ndarray
        The aggregated gradient estimate as a one-dimensional array. The array
        is a copy: the cached aggregate is updated in place by later calls, so
        handing out the cache itself would let earlier results change (or let
        callers corrupt the estimator by modifying the result). Once the
        estimator has terminated, a NaN vector of length ``x.size`` is returned.

    Raises
    ------
    ValueError
        If ``x`` has a different size than on previous calls.
    """
    if self.terminate:
        # Same one-dimensional layout as a regular estimate, filled with NaN.
        return np.full(np.asarray(x, dtype=float).size, np.nan)

    x = np.asarray(x, dtype=float).reshape(-1)
    if self.dim is None:
        self.dim = int(x.size)
    elif x.size != self.dim:
        raise ValueError(f"Inconsistent dimension: expected {self.dim}, got {x.size}")

    # Mark previous sample sizes for bias/stat-error decomposition
    for lvl in self.levels:
        lvl.m_prev = lvl.m

    current_event = "start" if len(self.levels) == 0 else "add"

    # Add a new level and take pilot samples (m_min)
    self._add_level(x)
    self._pilot_update_last_level()
    if self.terminate:
        # The pilot batch did not fit in the evaluation budget.
        g_hat = self._g_hat if self._g_hat is not None else self._aggregate_recompute()
        self._record(event="end", g_hat=g_hat)
        self.k += 1
        return g_hat.copy()

    # Define tolerance on error using current norm estimate
    err_tol = self._define_tol()
    # Compute optimal sample sizes
    opt_ml = self._get_opt_ml(err_tol)

    # Policy decisions (drop/restart/clip) happen before the sampling-growth loop.
    if self.policy and len(self.levels) > 2:
        did_drop, opt_ml = self._check_dropping(opt_ml, err_tol)
        if self.terminate:
            # Budget ran out while evaluating the drop candidate: freeze the
            # sizes at their current values so the growth loop does nothing.
            current_event = "end"
            opt_ml = np.asarray([lvl.m for lvl in self.levels], dtype=int)
        if did_drop:
            current_event = "dropped"

    if self.policy and len(self.levels) > 1 and not self.terminate:
        did_restart, opt_ml = self._check_restart(opt_ml, err_tol)
        if did_restart:
            current_event = "restart"

    if not self.terminate:
        did_clip, opt_ml = self._check_clipping(opt_ml, err_tol)
        if did_clip:
            current_event = "clip"

    # Sample more until sizes meet opt_ml
    while not self._check_samp_sizes(opt_ml):
        for lvl, m_opt in zip(self.levels, opt_ml):
            self._grow_level_to(lvl, int(m_opt))
            if self.terminate:
                break
        if self.terminate:
            break
        # New samples change the statistics, so tolerance and sizes are refreshed.
        err_tol = self._define_tol()
        opt_ml = self._get_opt_ml(err_tol)

    g_hat = self._g_hat if self._g_hat is not None else self._aggregate_recompute()

    # Stop criterion check.
    self._check_stop_crit(err_tol)
    if self.terminate:
        current_event = "end"

    # Log (exactly once per evaluate call)
    self._record(event=current_event, g_hat=g_hat)
    self.k += 1
    return g_hat.copy()
# (Sphinx "[docs]" source-link residue from page extraction; not part of the code.)
def get_log(self):
    """Return the recorder's entries via ``as_list()``; an empty list if there is no usable recorder."""
    if self.recorder:
        return self.recorder.as_list()
    return []
# --- Internal helpers ---
def _record(self, *, event: str, g_hat: np.ndarray) -> None:
if not self.recorder:
return
grad_norm = float(np.linalg.norm(g_hat)) if g_hat is not None else None
last_v = self.levels[-1].v_delta if self.levels else None
self.recorder.add(
event=event,
num_grads=int(self.counter),
hier_length=len(self.levels),
last_v=last_v,
grad_norm=grad_norm,
iteration=int(self.k),
terminate_reason=self.terminate_reason,
)
def _check_eval_budget(self, extra_eval: int) -> bool:
if self.counter + int(extra_eval) <= self.max_cost:
return True
self.terminate = True
self.terminate_reason = "max_cost"
return False
def _new_level_sampler(self):
if not self.finite:
return self.sampler # callable
# Reuse the single FiniteSampler instance to maintain state across levels.
# Lazily initialize it using the current RNG so seeding works even if `rng`
# is set after __post_init__ (common in experiment scripts).
if self._finite_sampler is None:
assert self.data_size is not None
start = int(self.rng.integers(0, self.data_size))
self._finite_sampler = FiniteSampler(data=self.sampler, start=start)
return self._finite_sampler.next
def _add_level(self, x: np.ndarray) -> None:
    """
    Append a new, still empty level located at ``x``.

    The first level of a hierarchy is a plain gradient level (cost 1, minimum
    batch ``m_restart_min``) and resets the cached aggregate to zero. Every
    later level is a difference level against the previous level's point
    (cost 2, minimum batch ``min_batch``) and additionally tracks statistics of
    the gradient at ``x`` itself.
    """
    with_resampling = self.use_resampling

    if self.levels:
        anchor = self.levels[-1].x
        diff_stats = WelfordVec.zeros(self.dim)
        plain_stats = WelfordVec.zeros(self.dim)
        self.levels.append(
            LevelState(
                x=x.copy(),
                x_prev=anchor.copy(),
                cost=2,
                sample_fn=self._new_level_sampler(),
                delta_stats=diff_stats,
                base_stats=plain_stats,
                m_min=self.min_batch,
                delta_resamp=ResamplingAcc.zeros(self.re_part, self.dim) if with_resampling else None,
                base_resamp=ResamplingAcc.zeros(self.re_part, self.dim) if with_resampling else None,
            )
        )
        return

    # base
    first_stats = WelfordVec.zeros(self.dim)
    self.levels.append(
        LevelState(
            x=x.copy(),
            x_prev=None,
            cost=1,
            sample_fn=self._new_level_sampler(),
            delta_stats=first_stats,
            base_stats=None,
            m_min=self.m_restart_min,
            delta_resamp=ResamplingAcc.zeros(self.re_part, self.dim) if with_resampling else None,
            base_resamp=None,
        )
    )
    self._g_hat = np.zeros(self.dim)
def _pilot_update_last_level(self) -> None:
"""
Take m_min samples for the last level and update its statistics.
"""
lvl = self.levels[-1]
m = int(lvl.m_min)
if m <= 0:
return
if not self._check_eval_budget(m * lvl.cost):
return
thetas = lvl.sample_fn(m)
if lvl.cost == 1:
g = self.grad(lvl.x, thetas)
if g.shape != (m, self.dim):
raise ValueError(f"grad returned {g.shape}, expected {(m, self.dim)}")
old_mean = lvl.mean_delta.copy()
lvl.delta_stats.update_batch(g)
if lvl.delta_resamp is not None:
lvl.delta_resamp.update_batch(g)
self._update_g_hat(old_mean, lvl.mean_delta)
self.counter += m
self._last_pilot_thetas = thetas
self._last_pilot_g_cur = g
else:
g_cur = self.grad(lvl.x, thetas)
g_prev = self.grad(lvl.x_prev, thetas)
if g_cur.shape != (m, self.dim) or g_prev.shape != (m, self.dim):
raise ValueError("grad returned wrong shape for diff level")
old_mean = lvl.mean_delta.copy()
lvl.base_stats.update_batch(g_cur)
lvl.delta_stats.update_batch(g_cur - g_prev)
if lvl.base_resamp is not None:
lvl.base_resamp.update_batch(g_cur)
if lvl.delta_resamp is not None:
lvl.delta_resamp.update_batch(g_cur - g_prev)
self._update_g_hat(old_mean, lvl.mean_delta)
self.counter += 2 * m
self._last_pilot_thetas = thetas
self._last_pilot_g_cur = g_cur
def _update_g_hat(self, old_mean: np.ndarray, new_mean: np.ndarray) -> None:
if self._g_hat is None:
self._g_hat = np.zeros(self.dim)
self._g_hat += (new_mean - old_mean)
def _aggregate_recompute(self) -> np.ndarray:
self._g_hat = np.zeros(self.dim)
for lvl in self.levels:
self._g_hat += lvl.mean_delta
return self._g_hat
# --- Error control, sample sizes, and stopping ---
def _define_tol(self) -> float:
if self.use_resampling:
self.err_tol = self._define_tol_resampling()
return float(self.err_tol)
g_hat = self._g_hat if self._g_hat is not None else self._aggregate_recompute()
n = float(self.norm_estimator.update(g_hat))
# Plain norm mode: err_tol = eps * ||g_hat||
self.err_tol = float(self.eps * n)
return float(self.err_tol)
def _define_tol_resampling(self) -> float:
"""
Define tolerance from resampled gradient norms.
err_tol = eps * q_{re_quantile}( ||g_hat^{(res)}|| ).
The stop quantile is stored for the stochastic stopping rule.
"""
if not self.levels:
self._norm_stop = 0.0
return float(self.eps * 0.0)
g_hat = self._g_hat if self._g_hat is not None else self._aggregate_recompute()
L = len(self.levels)
# Materialize leave-one-partition-out means per level
loo_means = []
for lvl in self.levels:
arr = lvl.delta_loo_means()
if arr is None:
# Shouldn't happen if use_resampling=True, but keep safe fallback
arr = np.tile(lvl.mean_delta[None, :], (self.re_part, 1))
loo_means.append(arr)
# We use `self.err_tol` (previous tolerance) to estimate opt_ml,
# matching the current estimator sequencing.
ml_prev = np.asarray([lvl.m_prev for lvl in self.levels], dtype=float)
opt_ml_prev = self._get_opt_ml(float(self.err_tol))
cost = float(np.maximum(opt_ml_prev - ml_prev, 0.0).sum())
re_samp = int(self.re_tot_cost * cost / (self.re_cost * max(L, 1)))
re_samp = int(
min(
re_samp,
int(self.re_max_samp),
int((2 * self.re_part) ** L),
)
)
n_samp = int(max(re_samp, int(self.re_min_n)))
choices = self.rng.integers(0, self.re_part, size=(n_samp, L), dtype=np.int64)
g_samp = np.zeros((n_samp, self.dim), dtype=float)
for j in range(L):
g_samp += loo_means[j][choices[:, j]]
norms = np.linalg.norm(g_samp, axis=1)
# Include the full estimator norm as an additional sample.
norms = np.concatenate([norms, np.asarray([float(np.linalg.norm(g_hat))])], axis=0)
tol_norm, stop_norm = self.norm_estimator.update_from_norms(norms) # type: ignore[attr-defined]
self._norm_stop = float(stop_norm)
return float(self.eps * tol_norm)
def _get_opt_ml(self, err_tol: float) -> np.ndarray:
if self.finite:
return self._get_opt_ml_finite(err_tol)
return self._get_opt_ml_continuous(err_tol)
def _get_opt_ml_continuous(self, err_tol: float) -> np.ndarray:
vl = []
ml = []
cl = []
m_min = []
for i, lvl in enumerate(self.levels):
vl.append(lvl.v_batch if i == 0 else lvl.v_delta)
ml.append(lvl.m)
cl.append(lvl.cost)
m_min.append(self.m_restart_min if i == 0 else self.min_batch)
vl = np.asarray(vl, dtype=float)
ml = np.asarray(ml, dtype=float)
cl = np.asarray(cl, dtype=float)
m_min = np.asarray(m_min, dtype=float)
constant = float(np.sum(np.sqrt(vl * cl)))
opt_ml = np.ceil((err_tol ** (-2)) * np.sqrt(vl / cl) * constant).astype(int)
opt_ml = np.maximum(opt_ml, m_min.astype(int))
opt_ml = np.maximum(opt_ml, ml.astype(int))
return opt_ml
def _get_opt_ml_finite(self, err_tol: float) -> np.ndarray:
assert self.data_size is not None
ds = self.data_size
vl = []
ml = []
cl = []
m_min = []
for i, lvl in enumerate(self.levels):
vl.append(lvl.v_batch if i == 0 else lvl.v_delta)
ml.append(lvl.m)
cl.append(lvl.cost)
m_min.append(self.m_restart_min if i == 0 else self.min_batch)
vl = np.asarray(vl, dtype=float)
ml = np.asarray(ml, dtype=float)
cl = np.asarray(cl, dtype=float)
m_min = np.asarray(m_min, dtype=int)
opt_ml = m_min.copy().astype(int)
ells = ml < ds
while float(np.sum(vl / opt_ml * (1.0 - opt_ml / ds))) > float(err_tol**2):
aux1 = float(err_tol**2) + (1.0 / ds) * float(np.sum(vl[ells]))
aux2 = float(np.sum(np.sqrt(vl[ells] * cl[ells])))
opt_ml[ells] = np.ceil(np.sqrt(vl[ells] / cl[ells]) * aux2 / aux1).astype(int)
opt_ml = np.minimum(opt_ml, ds)
opt_ml = np.maximum(opt_ml, ml.astype(int))
opt_ml = np.maximum(opt_ml, m_min)
ells = opt_ml < ds
return opt_ml
def _check_samp_sizes(self, opt_ml: np.ndarray) -> bool:
return all(opt_ml <= np.asarray([lvl.m for lvl in self.levels], dtype=int))
def _check_stop_crit(self, err_tol: float) -> None:
if self.stop_crit_norm <= 0.0:
return
if self.use_resampling and self._norm_stop is not None:
norm_est = float(self._norm_stop)
else:
g_hat = self._g_hat if self._g_hat is not None else self._aggregate_recompute()
norm_est = float(np.linalg.norm(g_hat))
l2_err = self._compute_error()
if norm_est < np.sqrt(self.stop_crit_norm) - np.sqrt(max(l2_err, 0.0)):
self.terminate = True
self.terminate_reason = "stop_crit"
def _compute_bias(self) -> float:
if len(self.levels) <= 1:
return 0.0
bias = 0.0
if self.finite:
assert self.data_size is not None
for lvl in self.levels[:-1]:
factor = (self.data_size - lvl.m_prev) / self.data_size
bias += factor * (lvl.m_prev / (lvl.m**2)) * lvl.v_delta
else:
for lvl in self.levels[:-1]:
bias += (lvl.m_prev / (lvl.m**2)) * lvl.v_delta
return float(bias)
def _compute_statistical_error(self) -> float:
if not self.levels:
return 0.0
stat_err = 0.0
if self.finite:
assert self.data_size is not None
for lvl in self.levels[:-1]:
factor = (self.data_size - lvl.m) / self.data_size
stat_err += factor * ((lvl.m - lvl.m_prev) / (lvl.m**2)) * lvl.v_delta
last = self.levels[-1]
factor = (self.data_size - last.m) / self.data_size
stat_err += factor * (last.v_delta / last.m)
else:
for lvl in self.levels[:-1]:
stat_err += ((lvl.m - lvl.m_prev) / (lvl.m**2)) * lvl.v_delta
last = self.levels[-1]
stat_err += last.v_delta / last.m
return float(stat_err)
def _compute_error(self) -> float:
return float(self._compute_bias() + self._compute_statistical_error())
# --- Sampling growth ---
def _grow_level_to(self, lvl: LevelState, m_opt: int) -> None:
if lvl.m >= m_opt:
return
m_to_sample = min(m_opt - lvl.m, lvl.m if lvl.m > 0 else m_opt - lvl.m)
m_min = lvl.m_min
if m_to_sample <= 0:
return
if self.finite:
assert self.data_size is not None
m_to_sample = min(m_to_sample, self.data_size - lvl.m)
m_min = min(m_min, self.data_size - lvl.m)
if m_to_sample > 0:
m_to_sample = max(m_to_sample, m_min)
extra_eval = m_to_sample * lvl.cost
if not self._check_eval_budget(extra_eval):
return
remaining = int(m_to_sample)
dim = int(self.dim or 1)
max_rows = max(1, int(self.max_grad_batch_elems // max(dim, 1)))
while remaining > 0:
m_chunk = int(min(remaining, max_rows))
thetas = lvl.sample_fn(int(m_chunk))
if lvl.cost == 1:
g = self.grad(lvl.x, thetas)
old_mean = lvl.mean_delta.copy()
lvl.delta_stats.update_batch(g)
if lvl.delta_resamp is not None:
lvl.delta_resamp.update_batch(g)
self._update_g_hat(old_mean, lvl.mean_delta)
self.counter += int(m_chunk)
else:
g_cur = self.grad(lvl.x, thetas)
g_prev = self.grad(lvl.x_prev, thetas)
old_mean = lvl.mean_delta.copy()
lvl.base_stats.update_batch(g_cur)
lvl.delta_stats.update_batch(g_cur - g_prev)
if lvl.base_resamp is not None:
lvl.base_resamp.update_batch(g_cur)
if lvl.delta_resamp is not None:
lvl.delta_resamp.update_batch(g_cur - g_prev)
self._update_g_hat(old_mean, lvl.mean_delta)
self.counter += int(2 * m_chunk)
remaining -= m_chunk
# --- Policy: dropping / restart / clipping ---
def _check_restart(self, opt_ml: np.ndarray, err_tol: float) -> Tuple[bool, np.ndarray]:
ml = np.asarray([lvl.m for lvl in self.levels], dtype=float)
mice_cost = float(np.maximum(0.0, np.ceil(opt_ml - ml)).sum() + self.policy.aggr_cost * len(opt_ml))
new_delta = self._restart_delta(self.levels[-1])
opt_ml_restart = self._get_opt_ml_for_levels([new_delta], err_tol)
opt_ml_restart = np.maximum(opt_ml_restart, self.m_restart_min)
restart_cost = float(np.maximum(0.0, opt_ml_restart[0] - ml[-1]) + self.policy.aggr_cost)
if (
restart_cost < mice_cost * (1.0 + self.policy.restart_param)
or len(self.levels) > self.policy.max_hierarchy_size
or self.force_restart
):
self.force_restart = False
self.levels = [new_delta]
self._g_hat = new_delta.mean_delta.copy()
return True, np.asarray(opt_ml_restart, dtype=int)
return False, opt_ml
def _restart_delta(self, last: LevelState) -> LevelState:
    """
    Build a plain (cost 1) level at ``last.x`` that reuses the samples already taken there.

    The statistics of the gradient at ``last.x`` are copied into fresh
    accumulators: for a cost-1 level these are its ``delta_stats``, for a
    difference level its ``base_stats``. The matching resampling accumulator is
    passed on as is (not copied), as are the sample function and ``m_prev``.
    """
    already_plain = last.cost == 1
    if already_plain:
        source = last.delta_stats
    else:
        assert last.base_stats is not None
        source = last.base_stats

    carried = WelfordVec.zeros(self.dim)
    carried.mean[...] = source.mean
    carried.m2[...] = source.m2
    carried.n = source.n

    resamp = last.delta_resamp if already_plain else last.base_resamp
    return LevelState(
        x=last.x.copy(),
        x_prev=None,
        cost=1,
        sample_fn=last.sample_fn,
        delta_stats=carried,
        base_stats=None,
        m_min=self.m_restart_min,
        m_prev=last.m_prev,
        delta_resamp=resamp,
        base_resamp=None,
    )
def _get_opt_ml_for_levels(self, levels: List[LevelState], err_tol: float) -> np.ndarray:
if self.finite:
assert self.data_size is not None
ds = self.data_size
vl = []
ml = []
cl = []
m_min = []
for i, lvl in enumerate(levels):
vl.append(lvl.v_batch if i == 0 else lvl.v_delta)
ml.append(lvl.m)
cl.append(lvl.cost)
m_min.append(self.m_restart_min if i == 0 else self.min_batch)
vl = np.asarray(vl, dtype=float)
ml = np.asarray(ml, dtype=float)
cl = np.asarray(cl, dtype=float)
m_min = np.asarray(m_min, dtype=int)
opt_ml = m_min.copy().astype(int)
ells = ml < ds
while float(np.sum(vl / opt_ml * (1.0 - opt_ml / ds))) > float(err_tol**2):
aux1 = float(err_tol**2) + (1.0 / ds) * float(np.sum(vl[ells]))
aux2 = float(np.sum(np.sqrt(vl[ells] * cl[ells])))
opt_ml[ells] = np.ceil(np.sqrt(vl[ells] / cl[ells]) * aux2 / aux1).astype(int)
opt_ml = np.minimum(opt_ml, ds)
opt_ml = np.maximum(opt_ml, ml.astype(int))
opt_ml = np.maximum(opt_ml, m_min)
ells = opt_ml < ds
return opt_ml
vl = []
ml = []
cl = []
m_min = []
for i, lvl in enumerate(levels):
vl.append(lvl.v_batch if i == 0 else lvl.v_delta)
ml.append(lvl.m)
cl.append(lvl.cost)
m_min.append(self.m_restart_min if i == 0 else self.min_batch)
vl = np.asarray(vl, dtype=float)
ml = np.asarray(ml, dtype=float)
cl = np.asarray(cl, dtype=float)
m_min = np.asarray(m_min, dtype=float)
constant = float(np.sum(np.sqrt(vl * cl)))
opt_ml = np.ceil((err_tol ** (-2)) * np.sqrt(vl / cl) * constant).astype(int)
opt_ml = np.maximum(opt_ml, m_min.astype(int))
opt_ml = np.maximum(opt_ml, ml.astype(int))
return opt_ml
def _check_dropping(self, opt_ml: np.ndarray, err_tol: float) -> Tuple[bool, np.ndarray]:
ml = np.asarray([lvl.m for lvl in self.levels], dtype=float)
mice_cost = float(np.maximum(0.0, np.ceil(opt_ml - ml)).sum() + self.policy.aggr_cost * len(opt_ml))
delta_drop = self._build_drop_delta()
if delta_drop is None:
return False, opt_ml
levels_drop = self.levels[:-2] + [delta_drop]
opt_ml_drop = self._get_opt_ml_for_levels(levels_drop, err_tol)
ml_drop = np.asarray([lvl.m for lvl in self.levels[:-2]] + [self.levels[-1].m], dtype=float)
drop_cost = float(np.maximum(0.0, np.ceil(opt_ml_drop - ml_drop)).sum() + self.policy.aggr_cost * len(opt_ml_drop))
if drop_cost <= mice_cost * (1.0 + self.policy.drop_param):
self.levels = levels_drop
self._g_hat = self._aggregate_recompute()
return True, np.asarray(opt_ml_drop, dtype=int)
return False, opt_ml
def _build_drop_delta(self) -> Optional[LevelState]:
if len(self.levels) < 3:
return None
x_k = self.levels[-1].x
x_km2 = self.levels[-3].x
m = int(self.min_batch)
if self._last_pilot_thetas is not None and self._last_pilot_g_cur is not None:
thetas = self._last_pilot_thetas
g_cur = self._last_pilot_g_cur
m = g_cur.shape[0]
if not self._check_eval_budget(m):
return None
else:
if not self._check_eval_budget(2 * m):
return None
thetas = self.levels[-1].sample_fn(m)
g_cur = self.grad(x_k, thetas)
self.counter += m
g_km2 = self.grad(x_km2, thetas)
self.counter += m
delta_stats = WelfordVec.zeros(self.dim)
base_stats = WelfordVec.zeros(self.dim)
base_stats.update_batch(g_cur)
delta_stats.update_batch(g_cur - g_km2)
drop_lvl = LevelState(
x=x_k.copy(),
x_prev=x_km2.copy(),
cost=2,
sample_fn=self.levels[-1].sample_fn,
delta_stats=delta_stats,
base_stats=base_stats,
m_min=self.min_batch,
delta_resamp=ResamplingAcc.zeros(self.re_part, self.dim) if self.use_resampling else None,
base_resamp=ResamplingAcc.zeros(self.re_part, self.dim) if self.use_resampling else None,
)
if drop_lvl.base_resamp is not None:
drop_lvl.base_resamp.update_batch(g_cur)
if drop_lvl.delta_resamp is not None:
drop_lvl.delta_resamp.update_batch(g_cur - g_km2)
return drop_lvl
def _check_clipping(self, opt_ml: np.ndarray, err_tol: float) -> Tuple[bool, np.ndarray]:
if not self.policy.clip_type:
return False, opt_ml
if self.policy.clip_every and self.k % self.policy.clip_every != 0:
return False, opt_ml
if self.policy.clip_type == "full":
if not self.finite:
return False, opt_ml
assert self.data_size is not None
m_is_datasize = np.where(opt_ml == self.data_size)[0]
if len(m_is_datasize) and int(m_is_datasize.max()) > 0:
lvl_clip = int(m_is_datasize.max())
ml = np.asarray([lvl.m for lvl in self.levels], dtype=float)
cost = float(np.maximum(opt_ml - ml, 0).sum() + self.policy.aggr_cost * len(ml))
deltas_clip = self.levels[lvl_clip:]
opt_ml_clip = self._get_opt_ml_for_levels(deltas_clip, err_tol)
ml_clip = np.asarray([lvl.m for lvl in deltas_clip], dtype=float)
cost_clip = float(np.maximum(opt_ml_clip - ml_clip, 0).sum() + self.policy.aggr_cost * len(opt_ml_clip))
if cost_clip <= cost:
self.levels = deltas_clip
self.levels[0] = self._restart_delta(self.levels[0])
self._g_hat = self._aggregate_recompute()
return True, np.asarray(opt_ml_clip, dtype=int)
return False, opt_ml
if self.policy.clip_type == "all":
ml = np.asarray([lvl.m for lvl in self.levels], dtype=float)
cost = float(np.maximum(opt_ml - ml, 0).sum() + self.policy.aggr_cost * len(ml))
best_cost = cost
best_i = None
best_opt = None
for i in range(len(self.levels)):
deltas_clip = self.levels[i:]
opt_ml_clip = self._get_opt_ml_for_levels(deltas_clip, err_tol)
ml_clip = np.asarray([lvl.m for lvl in deltas_clip], dtype=float)
cost_clip = float(np.maximum(opt_ml_clip - ml_clip, 0).sum() + self.policy.aggr_cost * len(opt_ml_clip))
if cost_clip < best_cost:
best_cost = cost_clip
best_i = i
best_opt = opt_ml_clip
if best_i is not None and best_opt is not None:
self.levels = self.levels[best_i:]
self.levels[0] = self._restart_delta(self.levels[0])
self._g_hat = self._aggregate_recompute()
return True, np.asarray(best_opt, dtype=int)
return False, opt_ml
return False, opt_ml