MICE Package API

Top-level imports

MICE: Multi-Iteration stochastiC Estimator

A gradient estimator for stochastic optimization that uses successive control variates along the optimization path to reduce variance.
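For orientation, the estimator can be summarized as a telescoping sum of level means. This is a sketch in the notation of the LevelState docstring below (base level Delta = grad(x_l), diff level Delta = grad(x_l) - grad(x_prev)). It assumes that level l holds M_{l,k} independent samples θ, that both gradients in a difference share the same θ, and that no level has been dropped or restarted. It is not a transcription of the implementation.

```latex
\nabla F(x_k) \approx \hat{g}_k = \sum_{l=0}^{k} \frac{1}{M_{l,k}} \sum_{\alpha=1}^{M_{l,k}} \Delta_{l,k}(\theta_{l,\alpha}),
\qquad
\Delta_{0,k}(\theta) = \nabla f(x_0, \theta),
\qquad
\Delta_{l,k}(\theta) = \nabla f(x_l, \theta) - \nabla f(x_{l-1}, \theta) \quad (l \ge 1)
```

Taking expectations, the differences telescope, so E[ĝ_k] = ∇F(x_k). With independent levels, E||ĝ_k - ∇F(x_k)||² = Σ_l V_{l,k} / M_{l,k}, where V_{l,k} is the summed per-coordinate variance described under WelfordVec. When x_l and x_{l-1} are close, Δ_{l,k} has small variance, so difference levels need far fewer samples than the base level.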

Core implementation

class mice.core_impl.MICE(grad: Callable[[ndarray, Any], ndarray], sampler: Callable[[int], Any] | ndarray | Sequence[Any], eps: float = 0.577, min_batch: int = 10, restart_factor: int = 10, max_cost: float = inf, stop_crit_norm: float = 0.0, stop_crit_prob: float = 0.05, convex: bool = False, policy: DropRestartClipPolicy = <factory>, recorder: Recorder | None = None, use_resampling: bool = True, re_part: int = 5, re_quantile: float = 0.05, re_tot_cost: float = 0.2, re_min_n: int = 5, re_max_samp: int = 1000, max_grad_batch_elems: int = 5000000)

Bases: object

Multi-Iteration stochastiC Estimator for stochastic gradients.

This class maintains a hierarchy of control variates and adaptively allocates samples across levels to satisfy an error tolerance while reducing gradient evaluation cost.

grad: Callable[[ndarray, Any], ndarray]
sampler: Callable[[int], Any] | ndarray | Sequence[Any]
eps: float = 0.577
min_batch: int = 10
restart_factor: int = 10
max_cost: float = inf
stop_crit_norm: float = 0.0
stop_crit_prob: float = 0.05
convex: bool = False
policy: DropRestartClipPolicy
recorder: Recorder | None = None
use_resampling: bool = True
re_part: int = 5
re_quantile: float = 0.05
re_tot_cost: float = 0.2
re_min_n: int = 5
re_max_samp: int = 1000
max_grad_batch_elems: int = 5000000
evaluate(x: ndarray) → ndarray

Evaluate a MICE gradient estimate at x and update internal state.

Returns the aggregated gradient estimator for the current iterate.
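To illustrate what "aggregated" means, here is a self-contained toy sketch (not the library's code) that sums level means along a short optimization path. The objective f(x, θ) = ½||x - θ||² with θ ~ N(0, I), and the helpers grad_f and telescoping_estimate, are hypothetical. Because both gradients in a difference share the same θ, each difference level reduces to exactly x_l - x_{l-1}, so those levels need very few samples.

```python
from typing import Sequence

import numpy as np


def grad_f(x: np.ndarray, theta: np.ndarray) -> np.ndarray:
    """Per-sample gradient of the toy objective f(x, theta) = 0.5 * ||x - theta||^2."""
    return x - theta  # broadcasts (d,) against (m, d) -> (m, d)


def telescoping_estimate(path: Sequence[np.ndarray], sample_sizes: Sequence[int],
                         rng: np.random.Generator) -> np.ndarray:
    """Hypothetical helper: sum the sample means of every level.

    Level 0 averages grad_f(path[0], theta); level l >= 1 averages
    grad_f(path[l], theta) - grad_f(path[l - 1], theta), using the SAME
    theta in both terms so the difference acts as a control variate.
    """
    dim = path[0].shape[0]
    estimate = np.zeros(dim)
    for level, m in enumerate(sample_sizes):
        theta = rng.standard_normal((m, dim))  # fresh samples for this level
        if level == 0:
            delta = grad_f(path[0], theta)
        else:
            delta = grad_f(path[level], theta) - grad_f(path[level - 1], theta)
        estimate += delta.mean(axis=0)
    return estimate


# Usage: a 3-point path, many samples at the base level, few at the diff levels.
rng = np.random.default_rng(0)
path = [np.array([1.0, 1.0]), np.array([0.8, 0.9]), np.array([0.7, 0.85])]
g_hat = telescoping_estimate(path, [1000, 10, 10], rng)
```

Here the true gradient is ∇F(x) = x, and the estimate equals path[-1] minus the base-level mean of θ, so all of its error comes from the base level.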

get_log()

Core compatibility layer

Compatibility shim for core imports.

class mice.core.MICE

Provided through the compatibility shim. Its constructor signature, attributes, and methods (evaluate, get_log) are documented identically to mice.core_impl.MICE above.

Policy

class mice.policy.DropRestartClipPolicy(drop_param: float = 0.5, restart_param: float = 0.0, max_hierarchy_size: int = 1000, aggr_cost: float = 0.1, clip_type: str | None = None, clip_every: int = 0)

Bases: object

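Policy parameters controlling when levels are dropped from the hierarchy, when the hierarchy is restarted, and when it is clipped.

The drop test runs at every iteration. The clip test runs every clip_every iterations, or is disabled entirely with clip_type=None.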

drop_param: float
restart_param: float
max_hierarchy_size: int
aggr_cost: float
clip_type: str | None
clip_every: int

Norm estimators

class mice.norms.PlainNormEstimator(convex: bool = False, _best: float = inf)

Bases: object

Plain norm estimator: uses the norm ||g_hat|| of the current gradient estimate directly, without resampling.

convex: bool
update(g_hat: ndarray) → float
class mice.norms.ResamplingNormEstimator(re_part: int = 5, re_quantile: float = 0.05, stop_quantile: float = 0.95, convex: bool = False, _best_tol: float = inf, _best_stop: float = inf)

Bases: object

Resampling norm estimator: consumes a vector of resampled ||g_hat|| values and returns two quantiles:

  • a conservative low quantile, used to select the error tolerance

  • a quantile used by the stochastic stopping rule

re_part: int
re_quantile: float
stop_quantile: float
convex: bool
update_from_norms(norms: ndarray) → tuple[float, float]

norms: 1D array of ||g_hat^{(s)}|| values from resampling. Returns (tol_norm, stop_norm).
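A minimal sketch of the two-quantile reduction, assuming both quantiles are taken directly with numpy.quantile. The name norms_to_quantiles is hypothetical, and the running-best state (_best_tol, _best_stop) and the convex flag are left out because their semantics are not documented here. Assuming, as in the MICE method, that the tolerance scales with the gradient norm, a low quantile is the conservative choice: a smaller norm estimate gives a tighter tolerance and therefore larger sample sizes.

```python
import numpy as np


def norms_to_quantiles(norms: np.ndarray, re_quantile: float = 0.05,
                       stop_quantile: float = 0.95) -> tuple[float, float]:
    """Hypothetical stand-in for update_from_norms, without running-best state.

    Returns (tol_norm, stop_norm): a low quantile of the resampled norms for
    the error tolerance and a second quantile for the stopping rule.
    """
    norms = np.asarray(norms, dtype=float)
    if norms.ndim != 1 or norms.size == 0:
        raise ValueError("norms must be a non-empty 1D array")
    tol_norm = float(np.quantile(norms, re_quantile))
    stop_norm = float(np.quantile(norms, stop_quantile))
    return tol_norm, stop_norm


# Usage: norms of re_part leave-one-partition-out gradient estimates.
tol_norm, stop_norm = norms_to_quantiles(np.array([0.9, 1.0, 1.1, 1.05, 0.95]))
```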

Sampling

class mice.sampling.FiniteSampler(data: ndarray | Sequence[Any], start: int, counter: int = 0)

Bases: object

Deterministic without-replacement sampler over a finite population: it starts at a random offset and then cycles through the data in order.

data: ndarray | Sequence[Any]
start: int
counter: int
data_size: int
next(n: int)
mice.sampling.make_sampler(sampler: Callable[[int], Any] | ndarray | Sequence[Any], rng: Generator) → tuple[Callable[[int], Any], int | None]

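Returns (sample_fn, data_size_if_finite), where data_size_if_finite is the population size when sampler is an ndarray or Sequence, and None when sampler is a callable.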
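The two sampling entry points can be sketched together. CyclingSampler and make_sampler_sketch are hypothetical stand-ins, assuming that "cycling through a random starting offset" means reading the data in order from a random start and wrapping around at the end, and that the rng passed to make_sampler is only used to pick that start.

```python
import numpy as np


class CyclingSampler:
    """Hypothetical stand-in for FiniteSampler: reads the data in order from a
    random start offset, wrapping around, so each epoch visits every item once."""

    def __init__(self, data, start: int):
        self.data = data
        self.data_size = len(data)
        self.start = start % self.data_size
        self.counter = 0  # total number of items handed out so far

    def next(self, n: int):
        idx = (self.start + self.counter + np.arange(n)) % self.data_size
        self.counter += n
        if isinstance(self.data, np.ndarray):
            return self.data[idx]  # fancy indexing keeps the ndarray type
        return [self.data[i] for i in idx]


def make_sampler_sketch(sampler, rng: np.random.Generator):
    """Return (sample_fn, data_size); data_size is None for a callable sampler."""
    if callable(sampler):
        return sampler, None  # infinite population: sampler(n) draws n samples
    finite = CyclingSampler(sampler, start=int(rng.integers(len(sampler))))
    return finite.next, finite.data_size


# Usage: a finite dataset of 10 rows and an infinite Gaussian sampler.
rng = np.random.default_rng(1)
sample_fn, size = make_sampler_sketch(np.arange(10), rng)
first_batch = sample_fn(4)
gauss_fn, gauss_size = make_sampler_sketch(lambda n: rng.standard_normal((n, 2)), rng)
```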

State

class mice.state.WelfordVec(mean: ndarray, m2: ndarray, n: int = 0)

Bases: object

Online mean / M2 accumulator for vectors.

We store M2 per-coordinate (same shape as mean) so that sum(var_i) == M2.sum()/(n-1) matches the manuscript’s use of V_{l,k} = sum_i Var(Delta^{(i)}_{l,k}).

mean: ndarray
m2: ndarray
n: int
classmethod zeros(dim: int, dtype=numpy.float64) → WelfordVec
update_batch(x: ndarray) → None

Update with a batch x of shape (m, d).

property var_sum: float
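A sketch of the batch update, assuming the standard pairwise (Chan et al.) merge of a batch's mean and M2 into the running totals. WelfordSketch is a hypothetical name, not the library class, and returning 0.0 from var_sum when n < 2 is an assumption of this sketch.

```python
from dataclasses import dataclass

import numpy as np


@dataclass
class WelfordSketch:
    """Hypothetical per-coordinate batch Welford accumulator."""
    mean: np.ndarray
    m2: np.ndarray
    n: int = 0

    @classmethod
    def zeros(cls, dim: int, dtype=np.float64) -> "WelfordSketch":
        return cls(np.zeros(dim, dtype=dtype), np.zeros(dim, dtype=dtype), 0)

    def update_batch(self, x: np.ndarray) -> None:
        """Merge a batch x of shape (m, d) into the running mean and M2."""
        m = x.shape[0]
        if m == 0:
            return
        batch_mean = x.mean(axis=0)
        batch_m2 = ((x - batch_mean) ** 2).sum(axis=0)
        total = self.n + m
        delta = batch_mean - self.mean
        self.mean = self.mean + delta * (m / total)
        # Pairwise merge: within-batch M2 plus the between-group correction.
        self.m2 = self.m2 + batch_m2 + delta ** 2 * (self.n * m / total)
        self.n = total

    @property
    def var_sum(self) -> float:
        """Sum of per-coordinate unbiased variances, M2.sum() / (n - 1)."""
        return float(self.m2.sum() / (self.n - 1)) if self.n > 1 else 0.0  # n < 2: assumed 0.0


# Usage: feed 50 samples in uneven batches.
rng = np.random.default_rng(0)
samples = rng.standard_normal((50, 3))
acc = WelfordSketch.zeros(3)
for chunk in np.array_split(samples, [7, 20, 41]):
    acc.update_batch(chunk)
```

After all batches, acc.var_sum should match samples.var(axis=0, ddof=1).sum() up to rounding, which is the identity sum(var_i) == M2.sum()/(n-1) stated above.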
class mice.state.ResamplingAcc(re_part: int, sum_total: ndarray, cnt_total: int, sum_part: ndarray, cnt_part: ndarray)

Bases: object

Maintain partitioned sums/counts to compute leave-one-partition-out means efficiently (O(re_part * d) to materialize all LOO means).

re_part: int
sum_total: ndarray
cnt_total: int
sum_part: ndarray
cnt_part: ndarray
classmethod zeros(re_part: int, dim: int, dtype=numpy.float64) → ResamplingAcc
update_batch(x: ndarray) → None

Update with a batch x of shape (m, d).

loo_means() → ndarray

Returns array of shape (re_part, d), where row p is the mean excluding samples that fell into partition p.
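A sketch of how partitioned sums give all leave-one-partition-out means in O(re_part * d). The class name LooAccumulator and the round-robin partition assignment (global sample index modulo re_part) are assumptions; the documented class only says samples "fall into" partitions.

```python
import numpy as np


class LooAccumulator:
    """Hypothetical partitioned-sum accumulator for leave-one-partition-out means.

    Assumption: sample i (counted across all batches) goes to partition i % re_part.
    """

    def __init__(self, re_part: int, dim: int):
        self.re_part = re_part
        self.sum_total = np.zeros(dim)
        self.cnt_total = 0
        self.sum_part = np.zeros((re_part, dim))
        self.cnt_part = np.zeros(re_part, dtype=int)

    def update_batch(self, x: np.ndarray) -> None:
        parts = (self.cnt_total + np.arange(x.shape[0])) % self.re_part
        np.add.at(self.sum_part, parts, x)  # add each row to its partition's sum
        np.add.at(self.cnt_part, parts, 1)
        self.sum_total += x.sum(axis=0)
        self.cnt_total += x.shape[0]

    def loo_means(self) -> np.ndarray:
        # Row p = (total sum - partition p sum) / (total count - partition p count).
        counts = self.cnt_total - self.cnt_part
        return (self.sum_total[None, :] - self.sum_part) / counts[:, None]


# Usage: 23 samples in two batches, 5 partitions.
rng = np.random.default_rng(2)
draws = rng.standard_normal((23, 4))
loo_acc = LooAccumulator(re_part=5, dim=4)
loo_acc.update_batch(draws[:10])
loo_acc.update_batch(draws[10:])
loo = loo_acc.loo_means()
```

Only the totals and the per-partition sums are stored, so materializing all re_part means costs one subtraction and one division per row, independent of the number of samples.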

class mice.state.LevelState(x: ndarray, cost: int, x_prev: ndarray | None, sample_fn: Callable[[int], Any], delta_stats: WelfordVec, base_stats: WelfordVec | None, m_min: int, delta_resamp: ResamplingAcc | None = None, base_resamp: ResamplingAcc | None = None, m_prev: int = 0)

Bases: object

Statistics for a single level (either base gradient or a gradient difference).

  • base level: Delta = grad(x_l)

  • diff level: Delta = grad(x_l) - grad(x_prev)

x: ndarray
cost: int
x_prev: ndarray | None
sample_fn: Callable[[int], Any]
delta_stats: WelfordVec
base_stats: WelfordVec | None
m_min: int
delta_resamp: ResamplingAcc | None
base_resamp: ResamplingAcc | None
m_prev: int
property m: int
property v_delta: float
property v_base: float
property v_batch: float

Variance proxy used in sample-size optimization:

  • base level: variance of grad(x_0)

  • diff level: variance of grad(x_l) (the “top” gradient)

property mean_delta: ndarray
delta_loo_means() → ndarray | None
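For context on "sample-size optimization", the textbook multilevel allocation minimizes the total cost Σ_l M_l C_l subject to Σ_l V_l / M_l ≤ tol². The Lagrangian solution is M_l = sqrt(V_l / C_l) · Σ_j sqrt(V_j C_j) / tol². The sketch below implements that formula under stated assumptions: the name optimal_sample_sizes is hypothetical, the min_batch floor is assumed, and which variance (v_delta or v_batch) and which tolerance the library actually plugs in are not documented here.

```python
import math
from typing import List, Sequence


def optimal_sample_sizes(variances: Sequence[float], costs: Sequence[float],
                         tol: float, min_batch: int = 10) -> List[int]:
    """Minimize sum(M_l * C_l) subject to sum(V_l / M_l) <= tol**2.

    Rounding up and flooring at min_batch only increase each M_l, so the
    variance constraint still holds after rounding.
    """
    scale = sum(math.sqrt(v * c) for v, c in zip(variances, costs)) / tol ** 2
    return [max(min_batch, math.ceil(math.sqrt(v / c) * scale))
            for v, c in zip(variances, costs)]


# Usage: one base level (1 gradient per sample) and two difference levels
# (2 gradients per sample), with variances shrinking along the path.
V = [4.0, 0.25, 0.01]
C = [1, 2, 2]
M = optimal_sample_sizes(V, C, tol=0.1)
```

Levels with small variance relative to their cost get few samples, which is how control variates along the path reduce the total number of gradient evaluations.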

Logging

class mice.logging.Recorder(events: List[Dict[str, Any]] = <factory>)

Bases: object

Minimal event recorder for MICE.

Kept deliberately lightweight: events are stored as a plain list of dicts. On-demand conversion to a pandas DataFrame is planned for a future version.

events: List[Dict[str, Any]]
add(*, event: str, num_grads: int, hier_length: int, last_v: float | None, grad_norm: float | None, iteration: int, terminate_reason: str | None = None) → None
as_list()
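A minimal sketch of the same recording pattern with the documented keyword-only add() signature. RecorderSketch is a hypothetical name, and the terminate_reason value in the usage is invented for illustration. Because each event is a flat dict, pandas.DataFrame(rec.as_list()) would give one row per event. That this is the table plot_mice expects is an assumption.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional


@dataclass
class RecorderSketch:
    """Hypothetical mirror of Recorder: one flat dict per event in a plain list."""
    events: List[Dict[str, Any]] = field(default_factory=list)

    def add(self, *, event: str, num_grads: int, hier_length: int,
            last_v: Optional[float], grad_norm: Optional[float], iteration: int,
            terminate_reason: Optional[str] = None) -> None:
        # Keyword-only arguments, matching the documented add() signature.
        self.events.append(dict(event=event, num_grads=num_grads,
                                hier_length=hier_length, last_v=last_v,
                                grad_norm=grad_norm, iteration=iteration,
                                terminate_reason=terminate_reason))

    def as_list(self) -> List[Dict[str, Any]]:
        return list(self.events)  # shallow copy so callers cannot reorder the log


# Usage: event names taken from the plot_mice marker list (start, add, ..., end).
rec = RecorderSketch()
rec.add(event="start", num_grads=10, hier_length=1, last_v=None,
        grad_norm=1.5, iteration=0)
rec.add(event="end", num_grads=250, hier_length=3, last_v=0.02,
        grad_norm=0.1, iteration=12, terminate_reason="max_cost")  # hypothetical reason string
rows = rec.as_list()
```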

Plotting

mice.plot_mice.plot_mice(data, ax, x, y, style='loglog', markers=True, legend=True, color='C0')

Plot MICE logs on the given Matplotlib axes.

Parameters:
  • data (pandas.DataFrame) – DataFrame containing optimization log data, including event markers.

  • ax (matplotlib.axes.Axes) – Axes object where the data will be plotted.

  • x (str) – Column name used for the x-axis.

  • y (str) – Column name used for the y-axis.

  • style ({'loglog', 'semilogy', 'semilogx', 'plot'}) – Plot style controlling linear/log scaling of axes.

  • markers (bool) – If True, adds event markers (start, add, dropped, restart, end).

  • legend (bool) – If True, adds a legend.

  • color (str) – Line color passed to Matplotlib.

Returns:

The updated axes.

Return type:

matplotlib.axes.Axes