From ef6e2e619525165a67d9d5578c56cec2192e2339 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Geoffrey=20N=C3=A9giar?= Date: Wed, 31 Mar 2021 07:29:59 -0700 Subject: [PATCH 01/47] First draft Stochastic Hybrid Prox LMO --- chop/stochastic.py | 125 +++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 114 insertions(+), 11 deletions(-) diff --git a/chop/stochastic.py b/chop/stochastic.py index e2fe3e7..4fdcb85 100644 --- a/chop/stochastic.py +++ b/chop/stochastic.py @@ -133,7 +133,7 @@ def __init__(self, params, prox=None, lr=.1, momentum=.9, normalization='none'): self.prox = [] for prox_el in prox: if prox_el is not None: - self.prox.append(lambda x, s=None: prox_el(x.unsqueeze(0)).squeeze()) + self.prox.append(lambda x, s=None: prox_el(x.unsqueeze(0), s).squeeze(0)) else: self.prox.append(lambda x, s=None: x) @@ -141,9 +141,8 @@ def __init__(self, params, prox=None, lr=.1, momentum=.9, normalization='none'): raise ValueError("lr must be float or 'sublinear'.") self.lr = lr - if type(momentum) == float: - if not(0. <= momentum <= 1.): - raise ValueError("Momentum must be in [0., 1.].") + if not(0. <= momentum <= 1.): + raise ValueError("Momentum must be in [0., 1.].") self.momentum = momentum if normalization in self.POSSIBLE_NORMALIZATIONS: @@ -198,7 +197,7 @@ def step(self, closure=None): if self.lr == 'sublinear': step_size = 1. / (state['step'] + 1.) else: - step_size = self.lr + step_size = state['lr'] new_p = self.prox[idx](p - step_size * grad_est, 1.) state['certificate'] = torch.norm((p - new_p) / step_size) @@ -251,8 +250,8 @@ def _lmo(u, x): if not (type(lr) == float or lr == 'sublinear'): raise ValueError("lr must be float or 'sublinear'.") - self.lr = lr - defaults = dict(prox=self.prox, lmo=self.lmo, name=self.name) + defaults = dict(prox=self.prox, lmo=self.lmo, lr=lr, + name=self.name) super(PGDMadry, self).__init__(params, defaults) @property @@ -271,8 +270,8 @@ def step(self, step_size=None, closure=None): if closure is not None: with torch.enable_grad(): loss = closure() - idx = 0 for groups in self.param_groups: + idx = 0 for p in groups['params']: if p.grad is None: continue @@ -286,10 +285,10 @@ def step(self, step_size=None, closure=None): state['step'] = 0. state['step'] += 1. - if self.lr == 'sublinear': + if state['lr'] == 'sublinear': step_size = 1. / (state['step'] + 1.) else: - step_size = self.lr + step_size = state['lr'] lmo_res, _ = self.lmo[idx](-p.grad, p) normalized_grad = lmo_res + p new_p = self.prox[idx](p + step_size * normalized_grad) @@ -446,7 +445,7 @@ class FrankWolfe(Optimizer): name = 'Frank-Wolfe' POSSIBLE_NORMALIZATIONS = {'gradient', 'none'} - def __init__(self, params, lmo, lr=.1, momentum=.9, + def __init__(self, params, lmo, lr=.1, momentum=0., weight_decay=0., normalization='none'): @@ -555,3 +554,107 @@ def step(self, closure=None): p.add_(step_size * update_direction) idx += 1 return loss + + +class SplittingProxFW(Optimizer): + # TODO: write docstring! 
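+    # Sketch of the intended docstring (an assumption inferred from the code
+    # below, not stated in the original patch): each parameter p is kept as a
+    # split p = x + y; y takes a Frank-Wolfe (LMO) step while x takes a
+    # proximal step at every iteration, so two structures (e.g. low-rank +
+    # sparse) can be imposed jointly on the same parameter.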
+ + name = 'Hybrid Prox FW Splitting' + + POSSIBLE_NORMALIZATIONS = {'none', 'gradient'} + + def __init__(self, params, lmo, prox=None, + lr_lmo=.1, + lr_prox=.1, + momentum=0., weight_decay=0., + normalization='none'): + + self.lmo = [] + for oracle in lmo: + if oracle is None: + # Then FW will not be used on this parameter + raise ValueError("LMOs cannot be None for this optimizer.") + else: + def _lmo(u, x): + update_direction, max_step_size = oracle(u.unsqueeze(0), x.unsqueeze(0)) + return update_direction.squeeze(dim=0), max_step_size + self.lmo.append(_lmo) + + if prox is None: + prox = [None] * len(list(params)) + + self.prox = [] + for prox_el in prox: + if prox_el is not None: + self.prox.append(lambda x, s=None: prox_el(x.unsqueeze(0), s).squeeze(0)) + else: + self.prox.append(lambda x, s=None: x) + + for name, lr in (('lr_lmo', lr_lmo), + ('lr_prox', lr_prox)): + if not type(lr) == float: + msg = f"{name} should be a float, got {lr}." + raise ValueError(msg) + + if not(0. <= momentum <= 1.): + raise ValueError("omentum must be in [0., 1.].") + + if not (weight_decay >= 0): + raise ValueError("weight_decay must be nonnegative.") + self.weight_decay = weight_decay + + if normalization not in self.POSSIBLE_NORMALIZATIONS: + raise ValueError(f"Normalization must be in {self.POSSIBLE_NORMALIZATIONS}") + defaults = dict(lmo=self.lmo, prox=self.prox, + name=self.name, + momentum=momentum, + lr_lmo=lr_lmo, + lr_prox=lr_prox, + weight_decay=weight_decay, + normalization=normalization) + + super(SplittingProxFW, self).__init__(params, defaults) + + @torch.no_grad() + def step(self, closure=None): + """Performs a single optimization step. + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss + """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + for group in self.param_groups: + idx = 0 + for p in group['params']: + if p.grad is None: + continue + grad = p.grad + self.weight_decay * p + + if grad.is_sparse: + raise RuntimeError("We do not yet support sparse gradients.") + # Keep track of the step + state = self.state[p] + + # Initialization + if len(state) == 0: + state['step'] = 0. + # split variable: p = x + y + state['x'] = .5 * p.detach().clone() + state['y'] = .5 * p.detach().clone() + state['step'] += 1. 
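+                # One splitting iteration (intent inferred from the update
+                # below): compute an LMO (Frank-Wolfe) direction for y, a
+                # proximal step for x, and keep the parameter as p = x + y.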
+ + y_update, max_step_size = self.lmo[idx](-grad, state['y']) + state['lr_lmo'] = torch.minimum(state['lr_lmo'], max_step_size) + w = y_update + state['y'] + v = self.prox[idx](state['x'] + state['y'] - w - grad / state['lr_prox'], state['lr_prox']) + + state['y'].add_(y_update, alpha=state['lr_lmo']) + x_update = v - state['x'] + state['x'].add_(x_update, alpha=state['lr_lmo']) + + p.copy_(state['x'] + state['y']) + idx += 1 + return loss From 6afc5104a15217ff6fb818d96f6f8f929f65d26a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Geoffrey=20N=C3=A9giar?= Date: Wed, 31 Mar 2021 10:14:31 -0700 Subject: [PATCH 02/47] Added stochastic Robust PCA example first draft --- chop/stochastic.py | 64 +++++++++++----- examples/plot_robust_PCA.py | 9 ++- examples/plot_stochastic_robust_PCA.py | 102 +++++++++++++++++++++++++ 3 files changed, 152 insertions(+), 23 deletions(-) create mode 100644 examples/plot_stochastic_robust_PCA.py diff --git a/chop/stochastic.py b/chop/stochastic.py index 4fdcb85..571301d 100644 --- a/chop/stochastic.py +++ b/chop/stochastic.py @@ -569,35 +569,50 @@ def __init__(self, params, lmo, prox=None, momentum=0., weight_decay=0., normalization='none'): - self.lmo = [] + # initialize proxes + if prox is None: + prox = [None] * len(list(params)) + + prox_candidates = [] + for prox_el in prox: + if prox_el is not None: + prox_candidates.append(lambda x, s=None: prox_el(x.unsqueeze(0), s).squeeze(0)) + else: + prox_candidates.append(lambda x, s=None: x) + # initialize lmos + lmo_candidates = [] for oracle in lmo: if oracle is None: # Then FW will not be used on this parameter - raise ValueError("LMOs cannot be None for this optimizer.") + _lmo = None else: def _lmo(u, x): update_direction, max_step_size = oracle(u.unsqueeze(0), x.unsqueeze(0)) return update_direction.squeeze(dim=0), max_step_size - self.lmo.append(_lmo) - - if prox is None: - prox = [None] * len(list(params)) + lmo_candidates.append(_lmo) + self.lmo = [] self.prox = [] - for prox_el in prox: - if prox_el is not None: - self.prox.append(lambda x, s=None: prox_el(x.unsqueeze(0), s).squeeze(0)) + useable_params = [] + for param, lmo_oracle, prox_oracle in zip(params, lmo_candidates, prox_candidates): + if lmo_oracle: + useable_params.append(param) + self.lmo.append(lmo_oracle) + self.prox.append(prox_oracle) else: - self.prox.append(lambda x, s=None: x) + msg = (f"No LMO was provided for parameter {param}. " + f"This optimizer will not optimize this parameter. " + f"Please pass this parameter to another optimizer.") + warnings.warn(msg) for name, lr in (('lr_lmo', lr_lmo), ('lr_prox', lr_prox)): - if not type(lr) == float: - msg = f"{name} should be a float, got {lr}." + if not ((type(lr) == float) or lr == 'sublinear'): + msg = f"{name} should be a float or 'sublinear', got {lr}." raise ValueError(msg) if not(0. <= momentum <= 1.): - raise ValueError("omentum must be in [0., 1.].") + raise ValueError("momentum must be in [0., 1.].") if not (weight_decay >= 0): raise ValueError("weight_decay must be nonnegative.") @@ -631,25 +646,34 @@ def step(self, closure=None): for p in group['params']: if p.grad is None: continue - grad = p.grad + self.weight_decay * p - + grad = p.grad + state = self.state[p] if grad.is_sparse: raise RuntimeError("We do not yet support sparse gradients.") # Keep track of the step - state = self.state[p] - + grad += group['weight_decay'] * p # Initialization if len(state) == 0: state['step'] = 0. 
# split variable: p = x + y state['x'] = .5 * p.detach().clone() state['y'] = .5 * p.detach().clone() + # initialize grad estimate + state['grad_est'] = grad + # initialize learning rates + state['lr_prox'] = group['lr_prox'] if type(group['lr_prox'] == float) else 0. + state['lr_lmo'] = group['lr_lmo'] if type(group['lr_lmo'] == float) else 0. state['step'] += 1. + state['grad_est'].add_(grad, alpha=1. - group['momentum']) + + for lr in ('lr_prox', 'lr_lmo'): + if group[lr] == 'sublinear': + state[lr] = 2. / (state['step'] + 2) - y_update, max_step_size = self.lmo[idx](-grad, state['y']) - state['lr_lmo'] = torch.minimum(state['lr_lmo'], max_step_size) + y_update, max_step_size = group['lmo'][idx](-state['grad_est'], state['y']) + state['lr_lmo'] = min(state['lr_lmo'], max_step_size) w = y_update + state['y'] - v = self.prox[idx](state['x'] + state['y'] - w - grad / state['lr_prox'], state['lr_prox']) + v = group['prox'][idx](state['x'] + state['y'] - w - state['grad_est'] / state['lr_prox'], group['lr_prox']) state['y'].add_(y_update, alpha=state['lr_lmo']) x_update = v - state['x'] diff --git a/examples/plot_robust_PCA.py b/examples/plot_robust_PCA.py index 2b283b9..d91846f 100644 --- a/examples/plot_robust_PCA.py +++ b/examples/plot_robust_PCA.py @@ -25,8 +25,10 @@ m = 1000 n = 1000 -r_p = [(5, 1e-3), (5, 3e-3), (25, 1e-3), (25, 3e-3), - (25, 3e-2), (130, 1e-2)] +r_p = [(5, 1e-3), + # (5, 3e-3), (25, 1e-3), (25, 3e-3), + # (25, 3e-2), (130, 1e-2) + ] for r, p in r_p: print(f'r={r} and p={p}') @@ -49,7 +51,7 @@ @utils.closure def sqloss(Z): - return .5 * torch.linalg.norm((Z - M).squeeze(), ord='fro') ** 2 + return .5 / M.numel() * torch.linalg.norm((Z - M).squeeze(), ord='fro') ** 2 rnuc = torch.linalg.norm(L.squeeze(), ord='nuc') sL1 = abs(S).sum() @@ -102,6 +104,7 @@ def line_search(kwargs): fig.suptitle(f'r={r} and p={p}') axes[0].plot(f_vals) + axes[0].set_ylim(0, 250) axes[0].set_title("Function values") axes[1].plot(sparse_comp) diff --git a/examples/plot_stochastic_robust_PCA.py b/examples/plot_stochastic_robust_PCA.py new file mode 100644 index 0000000..e8befa3 --- /dev/null +++ b/examples/plot_stochastic_robust_PCA.py @@ -0,0 +1,102 @@ + +""" +Stochastic Robust PCA +=========== + +This example fits a Robust PCA model to data. +It uses a stochastic hybrid Frank-Wolfe and proximal method. +See description in :func:`chop.stochastic.SplittingProxFW`. + + +We reproduce the synthetic experimental setting from `[Garber et al. 2018] `_. +We aim to recover :math:`M = L + S + N`, where :math:`L` is rank :math:`p`, +:math:`S` is :math:`p` sparse, and :math:`N` is standard Gaussian elementwise. 
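+The low-rank component is handled through the nuclear-norm ball's LMO, and the
+sparse component through the :math:`\ell_1` ball's proximal operator.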
+""" + + +import matplotlib.pyplot as plt +import torch +import chop +from chop import utils +from chop.utils.logging import Trace + + +device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + +m = 1000 +n = 1000 + +r_p = [(5, 1e-3), + # (5, 3e-3), (25, 1e-3), (25, 3e-3), + # (25, 3e-2), (130, 1e-2) + ] + +n_epochs = 100 + +for r, p in r_p: + print(f'r={r} and p={p}') + U = torch.normal(torch.zeros(m, r)) + V = torch.normal(torch.zeros(r, n)) + + # Low rank component + L = 10 * utils.bmm(U, V) + + # Sparse component + S = 100 * torch.normal(torch.zeros(m, n)) + + S *= (torch.rand_like(S) <= p) + + # Add noise + N = torch.normal(torch.zeros(m, n)) + + M = L + S + N + M = M.to(device) + + def sqloss(Z, M): + return .5 / M.numel() * torch.linalg.norm((Z - M).squeeze(), ord='fro') ** 2 + + rnuc = torch.linalg.norm(L.squeeze(), ord='nuc') + sL1 = abs(S).sum() + + print(f"Initial L1 norm: {sL1}") + print(f"Initial Nuclear norm: {rnuc}") + + rank_constraint = chop.constraints.NuclearNormBall(rnuc) + sparsity_constraint = chop.constraints.L1Ball(sL1) + + lmo = rank_constraint.lmo + prox = sparsity_constraint.prox + + Z = torch.zeros_like(M, device=device) + Z.requires_grad_(True) + + sampler = torch.utils.data.BatchSampler(torch.utils.data.RandomSampler(range(M.size(0))), + batch_size=100, + drop_last=False) + + optimizer = chop.stochastic.SplittingProxFW([Z], lmo=[lmo], prox=[prox], + lr_lmo='sublinear', + lr_prox='sublinear', + normalization='none') + + train_losses = [] + losses = [] + + for it in range(n_epochs): + for idx in sampler: + optimizer.zero_grad() + loss = sqloss(Z[idx], M[idx]) + # for logging + with torch.no_grad(): + full_loss = sqloss(Z, M) + losses.append(full_loss.item()) + train_losses.append(loss.item()) + loss.backward() + optimizer.step() + + + plt.plot(train_losses, label='training_losses') + plt.plot(losses, label='loss') + plt.ylim(0, 250) + plt.legend() + print("Done.") From 1d7b945c6315715b84bbc863ccc57082cc1208b5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Geoffrey=20N=C3=A9giar?= Date: Thu, 1 Apr 2021 08:53:30 -0700 Subject: [PATCH 03/47] Stochastic Robust PCA example fixed --- chop/optim.py | 25 ++++----- chop/stochastic.py | 18 ++++--- examples/plot_stochastic_robust_PCA.py | 72 +++++++++++++++----------- 3 files changed, 67 insertions(+), 48 deletions(-) diff --git a/chop/optim.py b/chop/optim.py index dbb5399..a05fa14 100644 --- a/chop/optim.py +++ b/chop/optim.py @@ -506,27 +506,25 @@ def minimize_alternating_fw_prox(closure, x0, y0, prox=None, lmo=None, lipschitz # TODO: add error catching for L0 Lt = lipschitz + x.requires_grad_(True) + y.requires_grad_(True) + + fval, grad = closure(x + y) + for it in range(max_iter): if step == 'sublinear': step_size = 2. 
/ (it + 2) * torch.ones(batch_size, device=x.device) - x.requires_grad_(True) - y.requires_grad_(True) - z = x + y - - f_val, grad = closure(z) - # estimate Lipschitz constant with backtracking line search - Lt = utils.init_lipschitz(closure, z, L0=Lt) + Lt = utils.init_lipschitz(closure, x + y, L0=Lt) - y_update, max_step_size = lmo(-grad, y) with torch.no_grad(): + y_update, max_step_size = lmo(-grad, y) w = y_update + y - prox_step_size = utils.bmul(step_size, Lt) - v = prox(z - w - utils.bdiv(grad, prox_step_size), prox_step_size) + prox_step_size = utils.bmul(step_size, Lt) + v = prox(x + y - w - utils.bdiv(grad, prox_step_size), prox_step_size) - with torch.no_grad(): if line_search is None: step_size = torch.min(step_size, max_step_size) else: @@ -540,7 +538,10 @@ def minimize_alternating_fw_prox(closure, x0, y0, prox=None, lmo=None, lipschitz if callback(locals()) is False: break - fval, grad = closure(x + y) + x.requires_grad_(True) + y.requires_grad_(True) + + fval, grad = closure(x + y) # TODO: add a certificate of optimality result = optimize.OptimizeResult(x=x, y=y, nit=it, fval=fval, grad=grad, certificate=None) return result diff --git a/chop/stochastic.py b/chop/stochastic.py index 571301d..0b588e2 100644 --- a/chop/stochastic.py +++ b/chop/stochastic.py @@ -611,8 +611,8 @@ def _lmo(u, x): msg = f"{name} should be a float or 'sublinear', got {lr}." raise ValueError(msg) - if not(0. <= momentum <= 1.): - raise ValueError("momentum must be in [0., 1.].") + if (momentum != 'sublinear') and (not (0. <= momentum <= 1.)): + raise ValueError("momentum must be in [0., 1.] or 'sublinear'.") if not (weight_decay >= 0): raise ValueError("weight_decay must be nonnegative.") @@ -651,7 +651,8 @@ def step(self, closure=None): if grad.is_sparse: raise RuntimeError("We do not yet support sparse gradients.") # Keep track of the step - grad += group['weight_decay'] * p + grad += .5 * group['weight_decay'] * p + # Initialization if len(state) == 0: state['step'] = 0. @@ -659,16 +660,21 @@ def step(self, closure=None): state['x'] = .5 * p.detach().clone() state['y'] = .5 * p.detach().clone() # initialize grad estimate - state['grad_est'] = grad + state['grad_est'] = torch.zeros_like(p) # initialize learning rates state['lr_prox'] = group['lr_prox'] if type(group['lr_prox'] == float) else 0. state['lr_lmo'] = group['lr_lmo'] if type(group['lr_lmo'] == float) else 0. - state['step'] += 1. - state['grad_est'].add_(grad, alpha=1. - group['momentum']) + state['momentum'] = group['momentum'] if type(group['momentum'] == float) else 0. for lr in ('lr_prox', 'lr_lmo'): if group[lr] == 'sublinear': state[lr] = 2. / (state['step'] + 2) + + if group['momentum'] == 'sublinear': + state['momentum'] = 4. / (state['step'] + 8.) ** (2/3) + + state['step'] += 1. + state['grad_est'].add_(grad - state['grad_est'], alpha=1. 
- state['momentum']) y_update, max_step_size = group['lmo'][idx](-state['grad_est'], state['y']) state['lr_lmo'] = min(state['lr_lmo'], max_step_size) diff --git a/examples/plot_stochastic_robust_PCA.py b/examples/plot_stochastic_robust_PCA.py index e8befa3..f2f9213 100644 --- a/examples/plot_stochastic_robust_PCA.py +++ b/examples/plot_stochastic_robust_PCA.py @@ -20,6 +20,7 @@ from chop import utils from chop.utils.logging import Trace +torch.manual_seed(0) device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') @@ -53,7 +54,7 @@ M = M.to(device) def sqloss(Z, M): - return .5 / M.numel() * torch.linalg.norm((Z - M).squeeze(), ord='fro') ** 2 + return .5 / Z.numel() * torch.linalg.norm((Z - M).squeeze(), ord='fro') ** 2 rnuc = torch.linalg.norm(L.squeeze(), ord='nuc') sL1 = abs(S).sum() @@ -70,33 +71,44 @@ def sqloss(Z, M): Z = torch.zeros_like(M, device=device) Z.requires_grad_(True) - sampler = torch.utils.data.BatchSampler(torch.utils.data.RandomSampler(range(M.size(0))), - batch_size=100, - drop_last=False) - - optimizer = chop.stochastic.SplittingProxFW([Z], lmo=[lmo], prox=[prox], - lr_lmo='sublinear', - lr_prox='sublinear', - normalization='none') - - train_losses = [] - losses = [] - - for it in range(n_epochs): - for idx in sampler: - optimizer.zero_grad() - loss = sqloss(Z[idx], M[idx]) - # for logging - with torch.no_grad(): - full_loss = sqloss(Z, M) - losses.append(full_loss.item()) - train_losses.append(loss.item()) - loss.backward() - optimizer.step() - - - plt.plot(train_losses, label='training_losses') - plt.plot(losses, label='loss') - plt.ylim(0, 250) - plt.legend() + batch_sizes = [100, 250, 500, 1000] + fig, axes = plt.subplots(ncols=len(batch_sizes), figsize=(18, 10)) + + for batch_size, ax in zip(batch_sizes, axes): + print(f"Batch size: {batch_size}") + sampler = torch.utils.data.BatchSampler(torch.utils.data.RandomSampler(range(M.size(0))), + batch_size=batch_size, + drop_last=True) + + optimizer = chop.stochastic.SplittingProxFW([Z], lmo=[lmo], prox=[prox], + lr_lmo='sublinear', + lr_prox='sublinear', + normalization='none', + # weight_decay=1e-8, + momentum=.9) + + train_losses = [] + losses = [] + sgrad_avg = 0 + n_it = 0 + for it in range(n_epochs): + for idx in sampler: + n_it += 1 + optimizer.zero_grad() + loss = sqloss(Z[idx], M[idx]) + train_losses.append(loss.item()) + loss.backward() + sgrad = Z.grad.detach().clone() + sgrad_avg += sgrad + # for logging + with torch.no_grad(): + full_loss = sqloss(Z, M) + losses.append(full_loss.item()) + optimizer.step() + ax.set_title(f"b={batch_size}") + ax.plot(train_losses, label='training_losses') + ax.plot(losses, label='loss') + ax.set_ylim(0, 250) + ax.legend() + plt.savefig("robustPCA_stoch.png") print("Done.") From fa60642321e93881d3577afb8c7a5492844f212e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Geoffrey=20N=C3=A9giar?= Date: Fri, 2 Apr 2021 01:57:34 -0700 Subject: [PATCH 04/47] Added layer nuclear norm constraints + tests --- chop/constraints.py | 24 +++++++++++++++++++----- tests/test_constraints.py | 13 +++++++++---- 2 files changed, 28 insertions(+), 9 deletions(-) diff --git a/chop/constraints.py b/chop/constraints.py index 87a7a28..ad0778a 100644 --- a/chop/constraints.py +++ b/chop/constraints.py @@ -10,6 +10,7 @@ from copy import deepcopy from collections import defaultdict +import warnings import torch @@ -30,13 +31,19 @@ def get_avg_init_norm(layer, param_type=None, p=2, repetitions=100): @torch.no_grad() -def make_Lp_model_constraints(model, p=2, value=300, mode='initialization', 
constrain_bias=False): - """Create LpBall constraints for each layer of model, and value depends on mode (either radius or +def make_model_constraints(model, ord=2, value=300, mode='initialization', constrain_bias=False): + """Create Ball constraints for each layer of model. Ball radius depends on mode (either radius or factor to multiply average initialization norm with)""" constraints = [] # Compute average init norms if necessary init_norms = dict() + + if ord == 'nuc' and constrain_bias: + msg = "'nuc' constraints cannot constrain bias." + warnings.warn(msg) + constrain_bias = False + if mode == 'initialization': for layer in model.modules(): if hasattr(layer, 'reset_parameters'): @@ -58,12 +65,19 @@ def make_Lp_model_constraints(model, p=2, value=300, mode='initialization', cons else: print(name) if mode == 'radius': - constraint = make_LpBall(value, p=p) + alpha = value elif mode == 'initialization': alpha = value * init_norms[param.shape] - constraint = make_LpBall(alpha, p=p) else: - raise ValueError(f"Unknown mode {mode}") + msg = f"Unknown mode {mode}." + raise ValueError(msg) + if (type(ord) == int) or (ord == np.inf): + constraint = make_LpBall(alpha, p=ord) + elif ord == 'nuc': + constraint = NuclearNormBall(alpha) + else: + msg = f"ord {ord} is not supported." + raise ValueError(msg) constraints.append(constraint) return constraints diff --git a/tests/test_constraints.py b/tests/test_constraints.py index 20f1a14..92f7f8f 100644 --- a/tests/test_constraints.py +++ b/tests/test_constraints.py @@ -4,7 +4,9 @@ import torch from torch import nn +from torchvision import models import torch.nn.functional as F + import chop from chop import utils from chop.utils.image import group_patches @@ -140,8 +142,9 @@ def test_feasible(Constraint, alpha): pass -@pytest.mark.parametrize('p', [1, 2, np.inf]) -def test_model_constraint_maker(p): +@pytest.mark.parametrize('ord', [1, 2, np.inf, 'nuc']) +@pytest.mark.parametrize('constrain_bias', [True, False]) +def test_model_constraint_maker(ord, constrain_bias): class Net(nn.Module): def __init__(self): @@ -169,11 +172,13 @@ def forward(self, x): return output model = Net() - constraints = chop.constraints.make_Lp_model_constraints(model, p) + constraints = chop.constraints.make_model_constraints(model, ord, constrain_bias=constrain_bias) assert len(constraints) == len(list(model.parameters())) chop.constraints.make_feasible(model, [constraint.prox for constraint in constraints]) for param, constraint in zip(model.parameters(), constraints): - assert torch.allclose(param, constraint.prox(param.unsqueeze(0)).squeeze(0), atol=1e-5) + if constraint: + assert torch.allclose(param, constraint.prox(param.unsqueeze(0)).squeeze(0), atol=1e-5) + From 7d0005e2ad018a55a4b3c951f682bffe1c9aa5f9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Geoffrey=20N=C3=A9giar?= Date: Tue, 6 Apr 2021 06:24:51 -0700 Subject: [PATCH 05/47] Bug fixes; slowly migrating to param_groups API for Optimizer --- chop/constraints.py | 19 +- chop/stochastic.py | 173 +++++++++++------- examples/plot_stochastic_robust_PCA.py | 1 - ...wRank+Sparse_constrained_net_on_CIFAR10.py | 130 +++++++++++++ tests/test_constraints.py | 51 +++--- 5 files changed, 275 insertions(+), 99 deletions(-) create mode 100644 examples/training_LowRank+Sparse_constrained_net_on_CIFAR10.py diff --git a/chop/constraints.py b/chop/constraints.py index ad0778a..6942f21 100644 --- a/chop/constraints.py +++ b/chop/constraints.py @@ -39,7 +39,7 @@ def make_model_constraints(model, ord=2, value=300, mode='initialization', 
const # Compute average init norms if necessary init_norms = dict() - if ord == 'nuc' and constrain_bias: + if (ord == 'nuc') and constrain_bias: msg = "'nuc' constraints cannot constrain bias." warnings.warn(msg) constrain_bias = False @@ -60,7 +60,8 @@ def make_model_constraints(model, ord=2, value=300, mode='initialization', const init_norms[shape] = avg_norm for name, param in model.named_parameters(): - if (not constrain_bias) and ('bias' in name): + is_bias = ('bias' in name) or (param.ndim < 2) + if is_bias: constraint = None else: print(name) @@ -71,13 +72,13 @@ def make_model_constraints(model, ord=2, value=300, mode='initialization', const else: msg = f"Unknown mode {mode}." raise ValueError(msg) - if (type(ord) == int) or (ord == np.inf): - constraint = make_LpBall(alpha, p=ord) - elif ord == 'nuc': - constraint = NuclearNormBall(alpha) - else: - msg = f"ord {ord} is not supported." - raise ValueError(msg) + if (type(ord) == int) or (ord == np.inf): + constraint = make_LpBall(alpha, p=ord) + elif ord == 'nuc': + constraint = NuclearNormBall(alpha) + else: + msg = f"ord {ord} is not supported." + raise ValueError(msg) constraints.append(constraint) return constraints diff --git a/chop/stochastic.py b/chop/stochastic.py index 0b588e2..5863e6f 100644 --- a/chop/stochastic.py +++ b/chop/stochastic.py @@ -66,7 +66,8 @@ def backtracking_step_size( ratio_increase = 2.0 max_ls_iter = 100 if old_f_t is not None: - tmp = (certificate ** 2) / (2 * (old_f_t - f_t) * norm_update_direction) + tmp = (certificate ** 2) / \ + (2 * (old_f_t - f_t) * norm_update_direction) lipschitz_t = max(min(tmp, lipschitz_t), lipschitz_t * ratio_decrease) for _ in range(max_ls_iter): step_size_t = certificate / (norm_update_direction * lipschitz_t) @@ -76,7 +77,8 @@ def backtracking_step_size( step_size_t = max_step_size rhs = ( -step_size_t * certificate - + 0.5 * (step_size_t ** 2) * lipschitz_t * norm_update_direction + + 0.5 * (step_size_t ** 2) * + lipschitz_t * norm_update_direction ) f_next, grad_next = f_grad(x + step_size_t * update_direction) if f_next - f_t <= rhs + EPS: @@ -104,7 +106,7 @@ def normalize_gradient(grad, normalization): grad = grad / torch.norm(grad) return grad - + class PGD(Optimizer): """Proximal Gradient Descent @@ -127,13 +129,15 @@ class PGD(Optimizer): POSSIBLE_NORMALIZATIONS = {'none', 'L2', 'Linf', 'sign'} def __init__(self, params, prox=None, lr=.1, momentum=.9, normalization='none'): + params = list(params) if prox is None: - prox = [None] * len(list(params)) + prox = [None] * len(params) self.prox = [] for prox_el in prox: if prox_el is not None: - self.prox.append(lambda x, s=None: prox_el(x.unsqueeze(0), s).squeeze(0)) + self.prox.append(lambda x, s=None: prox_el( + x.unsqueeze(0), s).squeeze(0)) else: self.prox.append(lambda x, s=None: x) @@ -148,7 +152,8 @@ def __init__(self, params, prox=None, lr=.1, momentum=.9, normalization='none'): if normalization in self.POSSIBLE_NORMALIZATIONS: self.normalization = normalization else: - raise ValueError(f"Normalization must be in {self.POSSIBLE_NORMALIZATIONS}") + raise ValueError( + f"Normalization must be in {self.POSSIBLE_NORMALIZATIONS}") defaults = dict(prox=self.prox, name=self.name, momentum=self.momentum, lr=self.lr, normalization=self.normalization) @@ -171,8 +176,8 @@ def step(self, closure=None): with torch.enable_grad(): loss = closure() idx = 0 - for groups in self.param_groups: - for p in groups['params']: + for group in self.param_groups: + for p in group['params']: if p.grad is None: continue @@ -190,17 +195,19 
@@ def step(self, closure=None): p, memory_format=torch.preserve_format) state['step'] += 1. - state['grad_estimate'].add_(grad - state['grad_estimate'], alpha=1. - self.momentum) + state['grad_estimate'].add_( + grad - state['grad_estimate'], alpha=1. - self.momentum) - grad_est = normalize_gradient(state['grad_estimate'], self.normalization) + grad_est = normalize_gradient( + state['grad_estimate'], group['normalization']) - if self.lr == 'sublinear': - step_size = 1. / (state['step'] + 1.) + if group['lr'] == 'sublinear': + state['lr'] = 1. / (state['step'] + 1.) else: - step_size = state['lr'] + state['lr'] = group['lr'] - new_p = self.prox[idx](p - step_size * grad_est, 1.) - state['certificate'] = torch.norm((p - new_p) / step_size) + new_p = self.prox[idx](p - state['lr'] * grad_est, 1.) + state['certificate'] = torch.norm((p - new_p) / state['lr']) p.copy_(new_p) idx += 1 return loss @@ -243,7 +250,8 @@ def _prox(x, s=None): self.lmo = [] for lmo_el in lmo: def _lmo(u, x): - update_direction, max_step_size = lmo_el(u.unsqueeze(0), x.unsqueeze(0)) + update_direction, max_step_size = lmo_el( + u.unsqueeze(0), x.unsqueeze(0)) return update_direction.squeeze(dim=0), max_step_size self.lmo.append(_lmo) @@ -311,10 +319,10 @@ class S3CM(Optimizer): prox2: [callable or None] or None Proximal operator for second constraint set. - + lr: float > 0 Learning rate - + normalization: str in {'none', 'L2', 'Linf', 'sign'} Normalizes the gradient. 'L2', 'Linf' divide the gradient by the corresponding norm. 'sign' uses the sign of the gradient. @@ -334,7 +342,8 @@ def __init__(self, params, prox1=None, prox2=None, lr=.1, normalization='none'): if normalization in self.POSSIBLE_NORMALIZATIONS: self.normalization = normalization else: - raise ValueError(f"Normalization must be in {self.POSSIBLE_NORMALIZATIONS}") + raise ValueError( + f"Normalization must be in {self.POSSIBLE_NORMALIZATIONS}") if prox1 is None: prox1 = [None] * len(params) @@ -351,14 +360,15 @@ def prox1_(x, s=None): return x if prox2_ is None: def prox2_(x, s=None): return x - self.prox1.append(lambda x, s=None: prox1_(x.unsqueeze(0), s).squeeze(dim=0)) - self.prox2.append(lambda x, s=None: prox2_(x.unsqueeze(0), s).squeeze(dim=0)) + self.prox1.append(lambda x, s=None: prox1_( + x.unsqueeze(0), s).squeeze(dim=0)) + self.prox2.append(lambda x, s=None: prox2_( + x.unsqueeze(0), s).squeeze(dim=0)) defaults = dict(lr=self.lr, prox1=self.prox1, prox2=self.prox2, normalization=self.normalization) super(S3CM, self).__init__(params, defaults) - @torch.no_grad() def step(self, closure=None): loss = None @@ -383,11 +393,14 @@ def step(self, closure=None): state['step'] = 0 state['iterate_1'] = p.clone().detach() state['iterate_2'] = self.prox2[idx](p, self.lr) - state['dual'] = (state['iterate_1'] - state['iterate_2']) / self.lr - - state['iterate_2'] = self.prox2[idx](state['iterate_1'] + self.lr * state['dual'], self.lr) - state['dual'].add_((state['iterate_1'] - state['iterate_2']) / self.lr) - state['iterate_1'] = self.prox1[idx](state['iterate_2'] + state['dual'] = (state['iterate_1'] - + state['iterate_2']) / self.lr + + state['iterate_2'] = self.prox2[idx]( + state['iterate_1'] + self.lr * state['dual'], self.lr) + state['dual'].add_( + (state['iterate_1'] - state['iterate_2']) / self.lr) + state['iterate_1'] = self.prox1[idx](state['iterate_2'] - self.lr * (grad + state['dual']), self.lr) p.copy_(state['iterate_2']) @@ -403,12 +416,14 @@ def __init__(self, params, lmo_pairwise, lr=.1, momentum=.9): raise ValueError("lr must be float or 
'sublinear'.") def _lmo(u, x): - update_direction, max_step_size = lmo_pairwise(u.unsqueeze(0), x.unsqueeze(0)) + update_direction, max_step_size = lmo_pairwise( + u.unsqueeze(0), x.unsqueeze(0)) return update_direction.squeeze(dim=0), max_step_size self.lmo = _lmo self.lr = lr self.momentum = momentum - defaults = dict(lmo=self.lmo, name=self.name, lr=self.lr, momentum=self.momentum) + defaults = dict(lmo=self.lmo, name=self.name, + lr=self.lr, momentum=self.momentum) super(PairwiseFrankWolfe, self).__init__(params, defaults) raise NotImplementedError @@ -456,7 +471,8 @@ def __init__(self, params, lmo, lr=.1, momentum=0., _lmo = None else: def _lmo(u, x): - update_direction, max_step_size = oracle(u.unsqueeze(0), x.unsqueeze(0)) + update_direction, max_step_size = oracle( + u.unsqueeze(0), x.unsqueeze(0)) return update_direction.squeeze(dim=0), max_step_size lmo_candidates.append(_lmo) @@ -484,9 +500,10 @@ def _lmo(u, x): raise ValueError("weight_decay should be nonnegative.") self.weight_decay = weight_decay if normalization not in self.POSSIBLE_NORMALIZATIONS: - raise ValueError(f"Normalization must be in {self.POSSIBLE_NORMALIZATIONS}.") + raise ValueError( + f"Normalization must be in {self.POSSIBLE_NORMALIZATIONS}.") self.normalization = normalization - defaults = dict(lmo=self.lmo, name=self.name, lr=self.lr, + defaults = dict(lmo=self.lmo, name=self.name, lr=self.lr, momentum=self.momentum, weight_decay=weight_decay, normalization=self.normalization) @@ -543,13 +560,16 @@ def step(self, closure=None): state['step'] += 1. - state['grad_estimate'].add_(grad - state['grad_estimate'], alpha=1. - momentum) + state['grad_estimate'].add_( + grad - state['grad_estimate'], alpha=1. - momentum) update_direction, _ = self.lmo[idx](-state['grad_estimate'], p) - state['certificate'] = (-state['grad_estimate'] * update_direction).sum() - if self.normalization == 'gradient': + state['certificate'] = (-state['grad_estimate'] + * update_direction).sum() + if group['normalization'] == 'gradient': grad_norm = torch.norm(state['grad_estimate']) - step_size = min(1., step_size * grad_norm / torch.linalg.norm(update_direction)) - elif self.normalization == 'none': + step_size = min(1., step_size * grad_norm / + torch.linalg.norm(update_direction)) + elif group['normalization'] == 'none': pass p.add_(step_size * update_direction) idx += 1 @@ -568,34 +588,40 @@ def __init__(self, params, lmo, prox=None, lr_prox=.1, momentum=0., weight_decay=0., normalization='none'): - + params = list(params) # initialize proxes if prox is None: - prox = [None] * len(list(params)) + prox = [None] * len(params) prox_candidates = [] - for prox_el in prox: - if prox_el is not None: - prox_candidates.append(lambda x, s=None: prox_el(x.unsqueeze(0), s).squeeze(0)) + + def prox_maker(oracle): + if oracle: + def _prox(x, s=None): + return oracle(x.unsqueeze(0), s).squeeze(0) else: - prox_candidates.append(lambda x, s=None: x) + def _prox(x, s=None): + return x, s + return _prox + + prox_candidates = [prox_maker(oracle) for oracle in prox] # initialize lmos - lmo_candidates = [] - for oracle in lmo: - if oracle is None: - # Then FW will not be used on this parameter - _lmo = None - else: - def _lmo(u, x): - update_direction, max_step_size = oracle(u.unsqueeze(0), x.unsqueeze(0)) - return update_direction.squeeze(dim=0), max_step_size - lmo_candidates.append(_lmo) + + def lmo_maker(oracle): + def _lmo(u, x): + update_direction, max_step_size = oracle( + u.unsqueeze(0), x.unsqueeze(0)) + return update_direction.squeeze(dim=0), 
max_step_size.squeeze(dim=0) + + return _lmo + + lmo_candidates = [lmo_maker(oracle) if oracle else None for oracle in lmo] self.lmo = [] self.prox = [] useable_params = [] for param, lmo_oracle, prox_oracle in zip(params, lmo_candidates, prox_candidates): - if lmo_oracle: + if lmo_oracle is not None: useable_params.append(param) self.lmo.append(lmo_oracle) self.prox.append(prox_oracle) @@ -619,7 +645,8 @@ def _lmo(u, x): self.weight_decay = weight_decay if normalization not in self.POSSIBLE_NORMALIZATIONS: - raise ValueError(f"Normalization must be in {self.POSSIBLE_NORMALIZATIONS}") + raise ValueError( + f"Normalization must be in {self.POSSIBLE_NORMALIZATIONS}") defaults = dict(lmo=self.lmo, prox=self.prox, name=self.name, momentum=momentum, @@ -628,7 +655,7 @@ def _lmo(u, x): weight_decay=weight_decay, normalization=normalization) - super(SplittingProxFW, self).__init__(params, defaults) + super(SplittingProxFW, self).__init__(useable_params, defaults) @torch.no_grad() def step(self, closure=None): @@ -649,7 +676,8 @@ def step(self, closure=None): grad = p.grad state = self.state[p] if grad.is_sparse: - raise RuntimeError("We do not yet support sparse gradients.") + msg = "We do not yet support sparse gradients." + raise RuntimeError(msg) # Keep track of the step grad += .5 * group['weight_decay'] * p @@ -662,24 +690,41 @@ def step(self, closure=None): # initialize grad estimate state['grad_est'] = torch.zeros_like(p) # initialize learning rates - state['lr_prox'] = group['lr_prox'] if type(group['lr_prox'] == float) else 0. - state['lr_lmo'] = group['lr_lmo'] if type(group['lr_lmo'] == float) else 0. - state['momentum'] = group['momentum'] if type(group['momentum'] == float) else 0. + state['lr_prox'] = group['lr_prox'] if type( + group['lr_prox'] == float) else 0. + state['lr_lmo'] = group['lr_lmo'] if type( + group['lr_lmo'] == float) else 0. + state['momentum'] = group['momentum'] if type( + group['momentum'] == float) else 0. for lr in ('lr_prox', 'lr_lmo'): if group[lr] == 'sublinear': state[lr] = 2. / (state['step'] + 2) - + if group['momentum'] == 'sublinear': state['momentum'] = 4. / (state['step'] + 8.) ** (2/3) state['step'] += 1. - state['grad_est'].add_(grad - state['grad_est'], alpha=1. - state['momentum']) + state['grad_est'].add_( + grad - state['grad_est'], alpha=1. 
- state['momentum']) + + y_update, max_step_size = group['lmo'][idx]( + -state['grad_est'], state['y']) + + if group['normalization'] == 'gradient': + grad_norm = torch.norm(state['grad_est']) + for lr, direction in (('lr_prox', state['grad_est']), + ('lr_lmo', y_update)): + state[lr] = min(max_step_size, group[lr] * + grad_norm / torch.linalg.norm(direction)) + elif group['normalization'] == 'none': + for lr in ('lr_prox', 'lr_lmo'): + state[lr] = min(max_step_size, group[lr]) - y_update, max_step_size = group['lmo'][idx](-state['grad_est'], state['y']) state['lr_lmo'] = min(state['lr_lmo'], max_step_size) w = y_update + state['y'] - v = group['prox'][idx](state['x'] + state['y'] - w - state['grad_est'] / state['lr_prox'], group['lr_prox']) + v = group['prox'][idx]( + state['x'] + state['y'] - w - state['grad_est'] / state['lr_prox'], group['lr_prox']) state['y'].add_(y_update, alpha=state['lr_lmo']) x_update = v - state['x'] diff --git a/examples/plot_stochastic_robust_PCA.py b/examples/plot_stochastic_robust_PCA.py index f2f9213..32a52d8 100644 --- a/examples/plot_stochastic_robust_PCA.py +++ b/examples/plot_stochastic_robust_PCA.py @@ -110,5 +110,4 @@ def sqloss(Z, M): ax.plot(losses, label='loss') ax.set_ylim(0, 250) ax.legend() - plt.savefig("robustPCA_stoch.png") print("Done.") diff --git a/examples/training_LowRank+Sparse_constrained_net_on_CIFAR10.py b/examples/training_LowRank+Sparse_constrained_net_on_CIFAR10.py new file mode 100644 index 0000000..dee51b0 --- /dev/null +++ b/examples/training_LowRank+Sparse_constrained_net_on_CIFAR10.py @@ -0,0 +1,130 @@ +""" +Constrained Neural Network Training. +====================================== +Trains a ResNet model on CIFAR10 using constraints on the weights. +This example is inspired by the official PyTorch MNIST example, which +can be found [here](https://github.com/pytorch/examples/blob/master/mnist/main.py). 
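+Here the weights are constrained to an :math:`\ell_1` ball and a nuclear-norm
+ball simultaneously via :func:`chop.stochastic.SplittingProxFW`, while biases
+are trained with a separate PGD optimizer.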
+""" +from tqdm import tqdm + +import numpy as np + +import torch +from torch import nn +from torchvision import models +from torch.nn import functional as F +from easydict import EasyDict + +import chop + +# Setup +torch.manual_seed(0) + +use_cuda = torch.cuda.is_available() +device = torch.device("cuda" if use_cuda else "cpu") + +# Data Loaders +print("Loading data...") +dataset = chop.utils.data.CIFAR10("~/datasets/") +loaders = dataset.loaders() +# Model setup + + +print("Initializing model.") +model = models.resnet18() +model.to(device) + +criterion = nn.CrossEntropyLoss() + +# Outer optimization parameters +nb_epochs = 200 +momentum = .9 +lr_lmo = 0.1 +lr_prox = 0.1 + +# Make constraints +print("Preparing constraints.") +constraints_sparsity = chop.constraints.make_model_constraints(model, + ord=1, + value=10000, + constrain_bias=False) +constraints_low_rank = chop.constraints.make_model_constraints(model, + ord='nuc', + value=1000, + constrain_bias=False) +proxes = [constraint.prox if constraint else None + for constraint in constraints_sparsity] +lmos = [constraint.lmo if constraint else None + for constraint in constraints_low_rank] + +proxes_lr = [constraint.prox if constraint else None + for constraint in constraints_low_rank] + +print("Projecting model parameters in their associated constraint sets.") +chop.constraints.make_feasible(model, proxes) +chop.constraints.make_feasible(model, proxes_lr) + +optimizer = chop.stochastic.SplittingProxFW(model.parameters(), lmos, + proxes, + lr_lmo=lr_lmo, lr_prox=lr_prox, + momentum=momentum, + weight_decay=5e-4, + normalization='gradient') + +scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer) + +bias_params = (param for name, param in model.named_parameters() if 'bias' in name) +bias_opt = chop.stochastic.PGD(bias_params, lr=1e-1) +bias_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(bias_opt) + + +def train(): + model.train() + train_loss = 0. 
+ for data, target in tqdm(loaders.train, desc=f'Training epoch {epoch}/{nb_epochs - 1}'): + data, target = data.to(device), target.to(device) + + optimizer.zero_grad() + bias_opt.zero_grad() + loss = criterion(model(data), target) + loss.backward() + optimizer.step() + bias_opt.step() + + train_loss += loss.item() + train_loss /= len(loaders.train) + print(f'Training loss: {train_loss:.3f}') + return train_loss + + +def eval(): + model.eval() + report = EasyDict(nb_test=0, correct=0, correct_adv_pgd=0, + correct_adv_pgd_madry=0, + correct_adv_fw=0, correct_adv_mfw=0) + val_loss = 0 + with torch.no_grad(): + for data, target in tqdm(loaders.test, desc=f'Val epoch {epoch}/{nb_epochs - 1}'): + data, target = data.to(device), target.to(device) + + # Compute corresponding predictions + logits = model(data) + _, pred = logits.max(1) + val_loss += criterion(logits, target) + # Get clean accuracies + report.nb_test += data.size(0) + report.correct += pred.eq(target).sum().item() + + val_loss /= report.nb_test + print(f'Val acc on clean examples (%): {report.correct / report.nb_test * 100.:.3f}') + return val_loss + + +print("Training...") +# Training loop +for epoch in range(nb_epochs): + # Evaluate on clean and adversarial test data + train() + val_loss = eval() + scheduler.step(val_loss) + bias_scheduler.step(val_loss) diff --git a/tests/test_constraints.py b/tests/test_constraints.py index 92f7f8f..a7551ea 100644 --- a/tests/test_constraints.py +++ b/tests/test_constraints.py @@ -142,35 +142,36 @@ def test_feasible(Constraint, alpha): pass +class Net(nn.Module): + def __init__(self): + super(Net, self).__init__() + self.conv1 = nn.Conv2d(1, 32, 3, 1) + self.conv2 = nn.Conv2d(32, 64, 3, 1) + self.dropout1 = nn.Dropout(0.25) + self.dropout2 = nn.Dropout(0.5) + self.fc1 = nn.Linear(9216, 128) + self.fc2 = nn.Linear(128, 10) + + def forward(self, x): + x = self.conv1(x) + x = F.relu(x) + x = self.conv2(x) + x = F.relu(x) + x = F.max_pool2d(x, 2) + x = self.dropout1(x) + x = torch.flatten(x, 1) + x = self.fc1(x) + x = F.relu(x) + x = self.dropout2(x) + x = self.fc2(x) + output = F.log_softmax(x, dim=1) + return output + + @pytest.mark.parametrize('ord', [1, 2, np.inf, 'nuc']) @pytest.mark.parametrize('constrain_bias', [True, False]) def test_model_constraint_maker(ord, constrain_bias): - class Net(nn.Module): - def __init__(self): - super(Net, self).__init__() - self.conv1 = nn.Conv2d(1, 32, 3, 1) - self.conv2 = nn.Conv2d(32, 64, 3, 1) - self.dropout1 = nn.Dropout(0.25) - self.dropout2 = nn.Dropout(0.5) - self.fc1 = nn.Linear(9216, 128) - self.fc2 = nn.Linear(128, 10) - - def forward(self, x): - x = self.conv1(x) - x = F.relu(x) - x = self.conv2(x) - x = F.relu(x) - x = F.max_pool2d(x, 2) - x = self.dropout1(x) - x = torch.flatten(x, 1) - x = self.fc1(x) - x = F.relu(x) - x = self.dropout2(x) - x = self.fc2(x) - output = F.log_softmax(x, dim=1) - return output - model = Net() constraints = chop.constraints.make_model_constraints(model, ord, constrain_bias=constrain_bias) From aa447e02a5ca49a6118b31340845557e6d4365f6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Geoffrey=20N=C3=A9giar?= Date: Tue, 6 Apr 2021 07:35:57 -0700 Subject: [PATCH 06/47] MNIST example LR+Sparse --- chop/constraints.py | 7 ++- chop/stochastic.py | 4 +- .../training_L1_constrained_net_on_CIFAR10.py | 20 +++++--- ...owRank+Sparse_constrained_net_on_MNIST.py} | 48 +++++++++++++++---- 4 files changed, 60 insertions(+), 19 deletions(-) rename examples/{training_LowRank+Sparse_constrained_net_on_CIFAR10.py => 
training_LowRank+Sparse_constrained_net_on_MNIST.py} (77%) diff --git a/chop/constraints.py b/chop/constraints.py index 6942f21..f0891a9 100644 --- a/chop/constraints.py +++ b/chop/constraints.py @@ -30,6 +30,10 @@ def get_avg_init_norm(layer, param_type=None, p=2, repetitions=100): return float(output) / repetitions +def is_bias(name, param): + return ('bias' in name) or (param.ndim < 2) + + @torch.no_grad() def make_model_constraints(model, ord=2, value=300, mode='initialization', constrain_bias=False): """Create Ball constraints for each layer of model. Ball radius depends on mode (either radius or @@ -60,8 +64,7 @@ def make_model_constraints(model, ord=2, value=300, mode='initialization', const init_norms[shape] = avg_norm for name, param in model.named_parameters(): - is_bias = ('bias' in name) or (param.ndim < 2) - if is_bias: + if is_bias(name, param): constraint = None else: print(name) diff --git a/chop/stochastic.py b/chop/stochastic.py index 5863e6f..f71ed59 100644 --- a/chop/stochastic.py +++ b/chop/stochastic.py @@ -715,11 +715,11 @@ def step(self, closure=None): grad_norm = torch.norm(state['grad_est']) for lr, direction in (('lr_prox', state['grad_est']), ('lr_lmo', y_update)): - state[lr] = min(max_step_size, group[lr] * + state[lr] = min(max_step_size, state[lr] * grad_norm / torch.linalg.norm(direction)) elif group['normalization'] == 'none': for lr in ('lr_prox', 'lr_lmo'): - state[lr] = min(max_step_size, group[lr]) + state[lr] = min(max_step_size, state[lr]) state['lr_lmo'] = min(state['lr_lmo'], max_step_size) w = y_update + state['y'] diff --git a/examples/training_L1_constrained_net_on_CIFAR10.py b/examples/training_L1_constrained_net_on_CIFAR10.py index 51b9e25..5e244e0 100644 --- a/examples/training_L1_constrained_net_on_CIFAR10.py +++ b/examples/training_L1_constrained_net_on_CIFAR10.py @@ -60,9 +60,8 @@ bias_opt = chop.stochastic.PGD(bias_params, lr=1e-1) bias_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(bias_opt) -print("Training...") -# Training loop -for epoch in range(nb_epochs): + +def train(): model.train() train_loss = 0. 
for data, target in tqdm(loaders.train, desc=f'Training epoch {epoch}/{nb_epochs - 1}'): @@ -78,9 +77,10 @@ train_loss += loss.item() train_loss /= len(loaders.train) print(f'Training loss: {train_loss:.3f}') + return train_loss - # Evaluate on clean and adversarial test data +def eval(): model.eval() report = EasyDict(nb_test=0, correct=0, correct_adv_pgd=0, correct_adv_pgd_madry=0, @@ -90,7 +90,7 @@ for data, target in tqdm(loaders.test, desc=f'Val epoch {epoch}/{nb_epochs - 1}'): data, target = data.to(device), target.to(device) - # Compute corresponding predictions + # Compute corresponding predictions logits = model(data) _, pred = logits.max(1) val_loss += criterion(logits, target) @@ -100,6 +100,14 @@ val_loss /= report.nb_test print(f'Val acc on clean examples (%): {report.correct / report.nb_test * 100.:.3f}') + return val_loss + +print("Training...") +# Training loop +for epoch in range(nb_epochs): + # Evaluate on clean and adversarial test data + train() + val_loss = eval() scheduler.step(val_loss) - bias_scheduler.step(val_loss) \ No newline at end of file + bias_scheduler.step(val_loss) diff --git a/examples/training_LowRank+Sparse_constrained_net_on_CIFAR10.py b/examples/training_LowRank+Sparse_constrained_net_on_MNIST.py similarity index 77% rename from examples/training_LowRank+Sparse_constrained_net_on_CIFAR10.py rename to examples/training_LowRank+Sparse_constrained_net_on_MNIST.py index dee51b0..1bbb71a 100644 --- a/examples/training_LowRank+Sparse_constrained_net_on_CIFAR10.py +++ b/examples/training_LowRank+Sparse_constrained_net_on_MNIST.py @@ -25,13 +25,43 @@ # Data Loaders print("Loading data...") -dataset = chop.utils.data.CIFAR10("~/datasets/") +dataset = chop.utils.data.MNIST("~/datasets/") loaders = dataset.loaders() +# print("Loading data...") +# dataset = chop.utils.data.CIFAR10("~/datasets/") +# loaders = dataset.loaders() # Model setup print("Initializing model.") -model = models.resnet18() + +class Net(nn.Module): + def __init__(self): + super(Net, self).__init__() + self.conv1 = nn.Conv2d(1, 32, 3, 1) + self.conv2 = nn.Conv2d(32, 64, 3, 1) + self.dropout1 = nn.Dropout(0.25) + self.dropout2 = nn.Dropout(0.5) + self.fc1 = nn.Linear(9216, 128) + self.fc2 = nn.Linear(128, 10) + + def forward(self, x): + x = self.conv1(x) + x = F.relu(x) + x = self.conv2(x) + x = F.relu(x) + x = F.max_pool2d(x, 2) + x = self.dropout1(x) + x = torch.flatten(x, 1) + x = self.fc1(x) + x = F.relu(x) + x = self.dropout2(x) + x = self.fc2(x) + output = F.log_softmax(x, dim=1) + return output + + +model = Net() model.to(device) criterion = nn.CrossEntropyLoss() @@ -39,8 +69,8 @@ # Outer optimization parameters nb_epochs = 200 momentum = .9 -lr_lmo = 0.1 -lr_prox = 0.1 +lr_lmo = 'sublinear' +lr_prox = 'sublinear' # Make constraints print("Preparing constraints.") @@ -68,12 +98,12 @@ proxes, lr_lmo=lr_lmo, lr_prox=lr_prox, momentum=momentum, - weight_decay=5e-4, - normalization='gradient') + weight_decay=0, + normalization='none') -scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer) +# scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer) -bias_params = (param for name, param in model.named_parameters() if 'bias' in name) +bias_params = (param for name, param in model.named_parameters() if chop.constraints.is_bias(name, param)) bias_opt = chop.stochastic.PGD(bias_params, lr=1e-1) bias_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(bias_opt) @@ -126,5 +156,5 @@ def eval(): # Evaluate on clean and adversarial test data train() val_loss = eval() - 
scheduler.step(val_loss) + # scheduler.step(val_loss) bias_scheduler.step(val_loss) From 578e8b79f5f119af1e8cde52aab1db932e2a07f3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Geoffrey=20N=C3=A9giar?= Date: Tue, 13 Apr 2021 11:26:56 -0700 Subject: [PATCH 07/47] Bug fix NuclearNorm prox --- chop/constraints.py | 26 ++++++++++++++++---------- 1 file changed, 16 insertions(+), 10 deletions(-) diff --git a/chop/constraints.py b/chop/constraints.py index f0891a9..4de00db 100644 --- a/chop/constraints.py +++ b/chop/constraints.py @@ -21,12 +21,13 @@ @torch.no_grad() -def get_avg_init_norm(layer, param_type=None, p=2, repetitions=100): +def get_avg_init_norm(layer, param_type=None, ord=2, repetitions=100): """Computes the average norm of default layer initialization""" output = 0 for _ in range(repetitions): layer.reset_parameters() - output += torch.norm(getattr(layer, param_type), p=p).item() + warnings.warn("torch.norm is deprecated. Think about updating this.") + output += torch.norm(getattr(layer, param_type), p=ord).item() return float(output) / repetitions @@ -56,8 +57,8 @@ def make_model_constraints(model, ord=2, value=300, mode='initialization', const None))]: param = getattr(layer, param_type) shape = param.shape - - avg_norm = get_avg_init_norm(layer, param_type=param_type, p=2) + # TODO: figure out how to set the constraint size for NuclearNormBall constraint + avg_norm = get_avg_init_norm(layer, param_type=param_type, ord=2) if avg_norm == 0.0: # Catch unlikely case that weight/bias is 0-initialized (e.g. BatchNorm does this) avg_norm = 1.0 @@ -449,7 +450,9 @@ def __init__(self, alpha): @torch.no_grad() def prox(self, x, step_size=None): shape = x.shape - x = euclidean_proj_simplex(x.view(-1), self.alpha) + flattened_x = x.view(shape[0], -1) + projected = [euclidean_proj_simplex(row, s=self.alpha) for row in flattened_x] + x = torch.stack(projected) return x.view(*shape) @torch.no_grad() @@ -508,16 +511,19 @@ def prox(self, x, step_size=None): """ Projection operator on the Nuclear Norm constraint set. """ - U, S, V = torch.svd(x) - # Project S on the alpha-simplex - simplex = Simplex(self.alpha) + # Project S on the alpha-L1 ball + ball = L1Ball(self.alpha) + + S_proj = ball.prox(S.view(-1, S.size(-1))).view_as(S) - S_proj = simplex.prox(S.view(-1, S.size(-1))).view_as(S) - VT = V.transpose(-2, -1) return torch.matmul(U, torch.matmul(torch.diag_embed(S_proj), VT)) + def is_feasible(self, x, atol=1e-5, rtol=1e-5): + norms = torch.linalg.norm(x, dim=(-2, -1), ord='nuc') + return (norms <= self.alpha * (1. + rtol) + atol) + class GroupL1Ball: From c921483286e9453942efc6295323ec2194e32a88 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Geoffrey=20N=C3=A9giar?= Date: Tue, 13 Apr 2021 11:30:21 -0700 Subject: [PATCH 08/47] slowly migrating optimizers to correct **params API --- chop/stochastic.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/chop/stochastic.py b/chop/stochastic.py index f71ed59..47d6c2d 100644 --- a/chop/stochastic.py +++ b/chop/stochastic.py @@ -142,7 +142,7 @@ def __init__(self, params, prox=None, lr=.1, momentum=.9, normalization='none'): self.prox.append(lambda x, s=None: x) if not (type(lr) == float or lr == 'sublinear'): - raise ValueError("lr must be float or 'sublinear'.") + raise ValueError(f"lr must be float or 'sublinear', got {lr}.") self.lr = lr if not(0. 
<= momentum <= 1.): @@ -405,6 +405,7 @@ def step(self, closure=None): p.copy_(state['iterate_2']) idx += 1 + return loss class PairwiseFrankWolfe(Optimizer): From a672e64313829de1e68c7c9a94d945e072dd2cce Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Geoffrey=20N=C3=A9giar?= Date: Tue, 13 Apr 2021 11:43:11 -0700 Subject: [PATCH 09/47] Stochastic Robust PCA example --- examples/plot_stochastic_robust_PCA.py | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/examples/plot_stochastic_robust_PCA.py b/examples/plot_stochastic_robust_PCA.py index 32a52d8..ada42ac 100644 --- a/examples/plot_stochastic_robust_PCA.py +++ b/examples/plot_stochastic_robust_PCA.py @@ -9,7 +9,7 @@ We reproduce the synthetic experimental setting from `[Garber et al. 2018] `_. -We aim to recover :math:`M = L + S + N`, where :math:`L` is rank :math:`p`, +We aim to recover :math:`M = L + S + N`, where :math:`L` is rank :math:`r`, :math:`S` is :math:`p` sparse, and :math:`N` is standard Gaussian elementwise. """ @@ -28,11 +28,11 @@ n = 1000 r_p = [(5, 1e-3), - # (5, 3e-3), (25, 1e-3), (25, 3e-3), - # (25, 3e-2), (130, 1e-2) + (5, 3e-3), (25, 1e-3), (25, 3e-3), + (25, 3e-2), (130, 1e-2) ] -n_epochs = 100 +n_epochs = 400 for r, p in r_p: print(f'r={r} and p={p}') @@ -72,7 +72,8 @@ def sqloss(Z, M): Z.requires_grad_(True) batch_sizes = [100, 250, 500, 1000] - fig, axes = plt.subplots(ncols=len(batch_sizes), figsize=(18, 10)) + fig, axes = plt.subplots(ncols=len(batch_sizes), figsize=(18, 10), sharey=True) + fig.suptitle(f'r={r} and p={p}') for batch_size, ax in zip(batch_sizes, axes): print(f"Batch size: {batch_size}") @@ -80,12 +81,14 @@ def sqloss(Z, M): batch_size=batch_size, drop_last=True) + momentum = .9 if batch_size != 1000 else 0. + optimizer = chop.stochastic.SplittingProxFW([Z], lmo=[lmo], prox=[prox], lr_lmo='sublinear', lr_prox='sublinear', normalization='none', # weight_decay=1e-8, - momentum=.9) + momentum=momentum) train_losses = [] losses = [] @@ -108,6 +111,7 @@ def sqloss(Z, M): ax.set_title(f"b={batch_size}") ax.plot(train_losses, label='training_losses') ax.plot(losses, label='loss') - ax.set_ylim(0, 250) + ax.set_yscale('log') ax.legend() - print("Done.") + fig.show() +print("Done.") From 30611089651bfd9fc9ade9727ad5848ca6f96b43 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Geoffrey=20N=C3=A9giar?= Date: Tue, 13 Apr 2021 11:46:27 -0700 Subject: [PATCH 10/47] Deleted obsolete LR + sparse example --- ...LowRank+Sparse_constrained_net_on_MNIST.py | 160 ------------------ 1 file changed, 160 deletions(-) delete mode 100644 examples/training_LowRank+Sparse_constrained_net_on_MNIST.py diff --git a/examples/training_LowRank+Sparse_constrained_net_on_MNIST.py b/examples/training_LowRank+Sparse_constrained_net_on_MNIST.py deleted file mode 100644 index 1bbb71a..0000000 --- a/examples/training_LowRank+Sparse_constrained_net_on_MNIST.py +++ /dev/null @@ -1,160 +0,0 @@ -""" -Constrained Neural Network Training. -====================================== -Trains a ResNet model on CIFAR10 using constraints on the weights. -This example is inspired by the official PyTorch MNIST example, which -can be found [here](https://github.com/pytorch/examples/blob/master/mnist/main.py). 
-""" -from tqdm import tqdm - -import numpy as np - -import torch -from torch import nn -from torchvision import models -from torch.nn import functional as F -from easydict import EasyDict - -import chop - -# Setup -torch.manual_seed(0) - -use_cuda = torch.cuda.is_available() -device = torch.device("cuda" if use_cuda else "cpu") - -# Data Loaders -print("Loading data...") -dataset = chop.utils.data.MNIST("~/datasets/") -loaders = dataset.loaders() -# print("Loading data...") -# dataset = chop.utils.data.CIFAR10("~/datasets/") -# loaders = dataset.loaders() -# Model setup - - -print("Initializing model.") - -class Net(nn.Module): - def __init__(self): - super(Net, self).__init__() - self.conv1 = nn.Conv2d(1, 32, 3, 1) - self.conv2 = nn.Conv2d(32, 64, 3, 1) - self.dropout1 = nn.Dropout(0.25) - self.dropout2 = nn.Dropout(0.5) - self.fc1 = nn.Linear(9216, 128) - self.fc2 = nn.Linear(128, 10) - - def forward(self, x): - x = self.conv1(x) - x = F.relu(x) - x = self.conv2(x) - x = F.relu(x) - x = F.max_pool2d(x, 2) - x = self.dropout1(x) - x = torch.flatten(x, 1) - x = self.fc1(x) - x = F.relu(x) - x = self.dropout2(x) - x = self.fc2(x) - output = F.log_softmax(x, dim=1) - return output - - -model = Net() -model.to(device) - -criterion = nn.CrossEntropyLoss() - -# Outer optimization parameters -nb_epochs = 200 -momentum = .9 -lr_lmo = 'sublinear' -lr_prox = 'sublinear' - -# Make constraints -print("Preparing constraints.") -constraints_sparsity = chop.constraints.make_model_constraints(model, - ord=1, - value=10000, - constrain_bias=False) -constraints_low_rank = chop.constraints.make_model_constraints(model, - ord='nuc', - value=1000, - constrain_bias=False) -proxes = [constraint.prox if constraint else None - for constraint in constraints_sparsity] -lmos = [constraint.lmo if constraint else None - for constraint in constraints_low_rank] - -proxes_lr = [constraint.prox if constraint else None - for constraint in constraints_low_rank] - -print("Projecting model parameters in their associated constraint sets.") -chop.constraints.make_feasible(model, proxes) -chop.constraints.make_feasible(model, proxes_lr) - -optimizer = chop.stochastic.SplittingProxFW(model.parameters(), lmos, - proxes, - lr_lmo=lr_lmo, lr_prox=lr_prox, - momentum=momentum, - weight_decay=0, - normalization='none') - -# scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer) - -bias_params = (param for name, param in model.named_parameters() if chop.constraints.is_bias(name, param)) -bias_opt = chop.stochastic.PGD(bias_params, lr=1e-1) -bias_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(bias_opt) - - -def train(): - model.train() - train_loss = 0. 
- for data, target in tqdm(loaders.train, desc=f'Training epoch {epoch}/{nb_epochs - 1}'): - data, target = data.to(device), target.to(device) - - optimizer.zero_grad() - bias_opt.zero_grad() - loss = criterion(model(data), target) - loss.backward() - optimizer.step() - bias_opt.step() - - train_loss += loss.item() - train_loss /= len(loaders.train) - print(f'Training loss: {train_loss:.3f}') - return train_loss - - -def eval(): - model.eval() - report = EasyDict(nb_test=0, correct=0, correct_adv_pgd=0, - correct_adv_pgd_madry=0, - correct_adv_fw=0, correct_adv_mfw=0) - val_loss = 0 - with torch.no_grad(): - for data, target in tqdm(loaders.test, desc=f'Val epoch {epoch}/{nb_epochs - 1}'): - data, target = data.to(device), target.to(device) - - # Compute corresponding predictions - logits = model(data) - _, pred = logits.max(1) - val_loss += criterion(logits, target) - # Get clean accuracies - report.nb_test += data.size(0) - report.correct += pred.eq(target).sum().item() - - val_loss /= report.nb_test - print(f'Val acc on clean examples (%): {report.correct / report.nb_test * 100.:.3f}') - return val_loss - - -print("Training...") -# Training loop -for epoch in range(nb_epochs): - # Evaluate on clean and adversarial test data - train() - val_loss = eval() - # scheduler.step(val_loss) - bias_scheduler.step(val_loss) From bbe94703f6764e7718c105d71051bb6c8dd66083 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Geoffrey=20N=C3=A9giar?= Date: Sat, 24 Apr 2021 00:25:31 -0700 Subject: [PATCH 11/47] fixes to splitting method -- added gradient normalization --- chop/stochastic.py | 25 +- ...nstrained_net_on_MNIST_hyperparamsweeps.py | 218 ++++++++++++++++++ tests/test_constraints.py | 7 +- 3 files changed, 233 insertions(+), 17 deletions(-) create mode 100644 examples/training_LowRank+Sparse_constrained_net_on_MNIST_hyperparamsweeps.py diff --git a/chop/stochastic.py b/chop/stochastic.py index 47d6c2d..29f96b1 100644 --- a/chop/stochastic.py +++ b/chop/stochastic.py @@ -553,7 +553,7 @@ def step(self, closure=None): else: raise ValueError("lr must be float or 'sublinear'.") - if self.momentum is None: + if self.momentum is None or self.momentum == 'sublinear': rho = (1. / (state['step'] + 1)) ** (1/3) momentum = 1. - rho else: @@ -680,7 +680,7 @@ def step(self, closure=None): msg = "We do not yet support sparse gradients." raise RuntimeError(msg) # Keep track of the step - grad += .5 * group['weight_decay'] * p + grad += group['weight_decay'] * p # Initialization if len(state) == 0: @@ -703,7 +703,8 @@ def step(self, closure=None): state[lr] = 2. / (state['step'] + 2) if group['momentum'] == 'sublinear': - state['momentum'] = 4. / (state['step'] + 8.) ** (2/3) + rho = 4. / (state['step'] + 8.) ** (2/3) + state['momentum'] = 1. - rho state['step'] += 1. 
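Note: a minimal standalone sketch, in plain PyTorch and outside the patch, of the 'sublinear' momentum schedule used in this hunk: the gradient estimator is the exponential average d_t = (1 - rho_t) * d_{t-1} + rho_t * g_t with rho_t = 4 / (t + 8)^(2/3), so the averaging weight decays sublinearly.

import torch

torch.manual_seed(0)
d = torch.zeros(3)                      # running gradient estimate, like state['grad_est']
for t in range(5):
    g = torch.randn(3)                  # stochastic gradient at step t
    rho = 4. / (t + 8.) ** (2 / 3)      # 'sublinear' schedule from this patch
    momentum = 1. - rho
    d.add_(g - d, alpha=1. - momentum)  # same in-place update as in step()
    print(t, round(rho, 3), d)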
state['grad_est'].add_( @@ -712,23 +713,19 @@ def step(self, closure=None): y_update, max_step_size = group['lmo'][idx]( -state['grad_est'], state['y']) - if group['normalization'] == 'gradient': - grad_norm = torch.norm(state['grad_est']) - for lr, direction in (('lr_prox', state['grad_est']), - ('lr_lmo', y_update)): - state[lr] = min(max_step_size, state[lr] * - grad_norm / torch.linalg.norm(direction)) - elif group['normalization'] == 'none': - for lr in ('lr_prox', 'lr_lmo'): - state[lr] = min(max_step_size, state[lr]) + state['lr_lmo'] = min(max_step_size, state['lr_lmo']) + if group['normalization'] == 'gradient': + # Normalize LMO update direction + grad_norm = torch.linalg.norm(state['grad_est']) + y_update *= min(1, grad_norm / torch.linalg.norm(y_update)) state['lr_lmo'] = min(state['lr_lmo'], max_step_size) w = y_update + state['y'] v = group['prox'][idx]( - state['x'] + state['y'] - w - state['grad_est'] / state['lr_prox'], group['lr_prox']) + state['x'] + state['y'] - w - state['grad_est'] / state['lr_prox'], state['lr_prox']) + x_update = v - state['x'] state['y'].add_(y_update, alpha=state['lr_lmo']) - x_update = v - state['x'] state['x'].add_(x_update, alpha=state['lr_lmo']) p.copy_(state['x'] + state['y']) diff --git a/examples/training_LowRank+Sparse_constrained_net_on_MNIST_hyperparamsweeps.py b/examples/training_LowRank+Sparse_constrained_net_on_MNIST_hyperparamsweeps.py new file mode 100644 index 0000000..bf5bfea --- /dev/null +++ b/examples/training_LowRank+Sparse_constrained_net_on_MNIST_hyperparamsweeps.py @@ -0,0 +1,218 @@ +""" +Constrained Neural Network Training. +====================================== +Trains a ResNet model on CIFAR10 using constraints on the weights. +This example is inspired by the official PyTorch MNIST example, which +can be found [here](https://github.com/pytorch/examples/blob/master/mnist/main.py). 
+""" +from __future__ import print_function +import torch.nn.functional as F +import torch.nn as nn +import argparse + +import numpy as np + + +import torch +from torch import nn +from torch.nn import functional as F + +import chop + +import wandb + + +# Hyperparam setup +default_config = { + 'lr': 1e-4, + 'batch_size': 64, + 'momentum': .5, + 'weight_decay': 1e-5, + 'lr_bias': 0.005, + 'grad_norm': 'none', + 'l1_constraint_size': 30, + 'nuc_constraint_size': 1e2, + 'epochs': 2, + 'seed': 1 +} + +wandb.init(project='low-rank_sparse_mnist', config=default_config) +config = wandb.config + + +class Net(nn.Module): + def __init__(self): + super(Net, self).__init__() + self.conv1 = nn.Conv2d(1, 10, kernel_size=5) + self.conv2 = nn.Conv2d(10, 20, kernel_size=5) + self.conv2_drop = nn.Dropout2d() + self.fc1 = nn.Linear(320, 50) + self.fc2 = nn.Linear(50, 10) + + def forward(self, x): + x = F.relu(F.max_pool2d(self.conv1(x), 2)) + x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2)) + x = x.view(-1, 320) + x = F.relu(self.fc1(x)) + x = F.dropout(x, training=self.training) + x = self.fc2(x) + return F.log_softmax(x, dim=1) + + +def get_sparsity_and_rank(opt): + nnzero = 0 + n_params = 0 + total_rank = 0 + max_rank = 0 + + for group in opt.param_groups: + for p in group['params']: + state = opt.state[p] + nnzero += (state['x'] !=0 ).sum() + n_params += p.numel() + ranks = torch.linalg.matrix_rank(state['y']) + total_rank += ranks.sum() + max_rank += min(p.shape) * ranks.numel() + + return nnzero / n_params, total_rank / max_rank + + +def train(args, model, device, train_loader, opt, opt_bias, epoch): + model.train() + for batch_idx, (data, target) in enumerate(train_loader): + data, target = data.to(device), target.to(device) + opt.zero_grad() + opt_bias.zero_grad() + output = model(data) + loss = F.nll_loss(output, target) + loss.backward() + if loss.isnan(): + break + opt.step() + opt_bias.step() + if batch_idx % args.log_interval == 0: + print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format( + epoch, batch_idx * len(data), len(train_loader.dataset), + 100. * batch_idx / len(train_loader), loss.item())) + wandb.log({"Train Loss": loss.item()}) + sparsity, rank = get_sparsity_and_rank(opt) + wandb.log({"Sparsity": sparsity, + "Rank": rank}) + + +def test(args, model, device, test_loader): + model.eval() + test_loss = 0 + correct = 0 + + example_images = [] + with torch.no_grad(): + for data, target in test_loader: + data, target = data.to(device), target.to(device) + output = model(data) + # sum up batch loss + test_loss += F.nll_loss(output, target, reduction='sum').item() + # get the index of the max log-probability + pred = output.max(1, keepdim=True)[1] + correct += pred.eq(target.view_as(pred)).sum().item() + example_images.append(wandb.Image( + data[0], caption="Pred: {} Truth: {}".format(pred[0].item(), target[0]))) + + test_loss /= len(test_loader.dataset) + print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format( + test_loss, correct, len(test_loader.dataset), + 100. * correct / len(test_loader.dataset))) + wandb.log({ + "Examples": example_images, + "Test Accuracy": 100. 
* correct / len(test_loader.dataset), + "Test Loss": test_loss}) + + +def main(): + + wandb.init() + # Training settings + parser = argparse.ArgumentParser(description='PyTorch MNIST Example') + parser.add_argument('--batch_size', type=int, default=64, metavar='N', + help='input batch size for training (default: 64)') + parser.add_argument('--test_batch_size', type=int, default=1000, metavar='N', + help='input batch size for testing (default: 1000)') + parser.add_argument('--epochs', type=int, default=5, metavar='N', + help='number of epochs to train (default: 2)') + parser.add_argument('--lr', default=1e-2, metavar='LR', + help='learning rate (default: "sublinear")') + parser.add_argument('--lr_bias', default=0.005, type=float, metavar='LR_BIAS', + help='learning rate (default: 0.01)') + parser.add_argument('--momentum', type=float, default=0.5, metavar='M', + help='Optimizer momentum (default: 0.5)') + parser.add_argument('--weight_decay', type=float, default=1e-5, metavar='W', + help='Optimizer weight decay (default: 0.)') + parser.add_argument('--grad_norm', type=str, default='gradient', + help='Gradient normalization options') + parser.add_argument('--nuc_constraint_size', type=float, default=1e2, + help='Size of the Nuclear norm Ball constraint') + parser.add_argument('--l1_constraint_size', type=float, default=30, + help='Size of the ell-1 norm Ball constraint') + parser.add_argument('--no_cuda', action='store_true', default=False, + help='disables CUDA training') + parser.add_argument('--seed', type=int, default=1, metavar='S', + help='random seed (default: 1)') + parser.add_argument('--log_interval', type=int, default=10, metavar='N', + help='how many batches to wait before logging training status') + args = parser.parse_args() + use_cuda = not args.no_cuda and torch.cuda.is_available() + + if args.lr != 'sublinear': + args.lr = float(args.lr) + + wandb.config.update(args, allow_val_change=True) + + torch.manual_seed(args.seed) + + device = torch.device("cuda" if use_cuda else "cpu") + + dataset = chop.utils.data.MNIST("~/datasets/") + loaders = dataset.loaders(args.batch_size, args.test_batch_size) + + model = Net().to(device) + constraints_sparsity = chop.constraints.make_model_constraints(model, + ord=1, + value=args.l1_constraint_size, + constrain_bias=False) + constraints_low_rank = chop.constraints.make_model_constraints(model, + ord='nuc', + value=args.nuc_constraint_size, + constrain_bias=False) + proxes = [constraint.prox if constraint else None + for constraint in constraints_sparsity] + lmos = [constraint.lmo if constraint else None + for constraint in constraints_low_rank] + + proxes_lr = [constraint.prox if constraint else None + for constraint in constraints_low_rank] + + chop.constraints.make_feasible(model, proxes) + chop.constraints.make_feasible(model, proxes_lr) + + optimizer = chop.stochastic.SplittingProxFW(model.parameters(), lmos, + proxes, + lr_lmo=args.lr, + lr_prox=args.lr, + momentum=args.momentum, + weight_decay=args.weight_decay, + normalization=args.grad_norm) + + bias_params = (param for name, param in model.named_parameters() + if chop.constraints.is_bias(name, param)) + bias_opt = chop.stochastic.PGD(bias_params, lr=args.lr_bias) + + + wandb.watch(model, log_freq=1, log='all') + + for epoch in range(1, args.epochs + 1): + train(args, model, device, loaders.train, optimizer, bias_opt, epoch) + test(args, model, device, loaders.test) + + +if __name__ == '__main__': + main() diff --git a/tests/test_constraints.py b/tests/test_constraints.py index 
aebf657..656b086 100644 --- a/tests/test_constraints.py +++ b/tests/test_constraints.py @@ -181,8 +181,9 @@ def test_model_constraint_maker(ord, constrain_bias): chop.constraints.make_feasible(model, proxes) - for (name, param), constraint in zip(model.named_parameters(), constraints): + for (name, param), prox in zip(model.named_parameters(), proxes): if chop.constraints.is_bias(name, param) and ord == 'nuc': continue - if constraint: - assert torch.allclose(param, constraint.prox(param.unsqueeze(0)).squeeze(0), atol=1e-5) + if prox: + allclose = torch.allclose(param, prox(param.unsqueeze(0)).squeeze(0), rtol=5e-3, atol=5e-3) + assert allclose \ No newline at end of file From 7671cf768e54b8bd517ab5c3c9bd97264be92c22 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Geoffrey=20N=C3=A9giar?= Date: Wed, 28 Apr 2021 15:10:44 -0700 Subject: [PATCH 12/47] removed redundant code --- chop/stochastic.py | 35 +++++++++++++++++++---------------- 1 file changed, 19 insertions(+), 16 deletions(-) diff --git a/chop/stochastic.py b/chop/stochastic.py index 29f96b1..57b7246 100644 --- a/chop/stochastic.py +++ b/chop/stochastic.py @@ -585,7 +585,7 @@ class SplittingProxFW(Optimizer): POSSIBLE_NORMALIZATIONS = {'none', 'gradient'} def __init__(self, params, lmo, prox=None, - lr_lmo=.1, + lr=.1, lr_prox=.1, momentum=0., weight_decay=0., normalization='none'): @@ -632,7 +632,7 @@ def _lmo(u, x): f"Please pass this parameter to another optimizer.") warnings.warn(msg) - for name, lr in (('lr_lmo', lr_lmo), + for name, lr in (('lr', lr), ('lr_prox', lr_prox)): if not ((type(lr) == float) or lr == 'sublinear'): msg = f"{name} should be a float or 'sublinear', got {lr}." @@ -651,7 +651,7 @@ def _lmo(u, x): defaults = dict(lmo=self.lmo, prox=self.prox, name=self.name, momentum=momentum, - lr_lmo=lr_lmo, + lr=lr, lr_prox=lr_prox, weight_decay=weight_decay, normalization=normalization) @@ -685,20 +685,22 @@ def step(self, closure=None): # Initialization if len(state) == 0: state['step'] = 0. - # split variable: p = x + y - state['x'] = .5 * p.detach().clone() - state['y'] = .5 * p.detach().clone() + state['prox'] = group['prox'][idx] + state['lmo'] = group['lmo'][idx] + # split variable: p = x + y and make feasible + state['x'] = state['prox'](.5 * p.detach().clone()) + state['y'] = state['prox'](.5 * p.detach().clone()) # initialize grad estimate state['grad_est'] = torch.zeros_like(p) # initialize learning rates state['lr_prox'] = group['lr_prox'] if type( group['lr_prox'] == float) else 0. - state['lr_lmo'] = group['lr_lmo'] if type( - group['lr_lmo'] == float) else 0. + state['lr'] = group['lr'] if type( + group['lr'] == float) else 0. state['momentum'] = group['momentum'] if type( group['momentum'] == float) else 0. - for lr in ('lr_prox', 'lr_lmo'): + for lr in ('lr_prox', 'lr'): if group[lr] == 'sublinear': state[lr] = 2. / (state['step'] + 2) @@ -710,23 +712,24 @@ def step(self, closure=None): state['grad_est'].add_( grad - state['grad_est'], alpha=1. 
- state['momentum']) - y_update, max_step_size = group['lmo'][idx]( + y_update, max_step_size = state['lmo']( -state['grad_est'], state['y']) - state['lr_lmo'] = min(max_step_size, state['lr_lmo']) + state['lr'] = min(max_step_size, state['lr']) if group['normalization'] == 'gradient': # Normalize LMO update direction grad_norm = torch.linalg.norm(state['grad_est']) - y_update *= min(1, grad_norm / torch.linalg.norm(y_update)) - state['lr_lmo'] = min(state['lr_lmo'], max_step_size) + y_update_norm = torch.linalg.norm(y_update) + y_update *= min(1, grad_norm / y_update_norm) + w = y_update + state['y'] - v = group['prox'][idx]( + v = state['prox']( state['x'] + state['y'] - w - state['grad_est'] / state['lr_prox'], state['lr_prox']) x_update = v - state['x'] - state['y'].add_(y_update, alpha=state['lr_lmo']) - state['x'].add_(x_update, alpha=state['lr_lmo']) + state['y'].add_(y_update, alpha=state['lr']) + state['x'].add_(x_update, alpha=state['lr']) p.copy_(state['x'] + state['y']) idx += 1 From b41a17b0cd40080f2abf389e33803d8a80256808 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Geoffrey=20N=C3=A9giar?= Date: Fri, 30 Apr 2021 17:27:31 -0700 Subject: [PATCH 13/47] Enforcing relationship between lr, lr_prox and lipschitz --- chop/stochastic.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/chop/stochastic.py b/chop/stochastic.py index 57b7246..7f7a569 100644 --- a/chop/stochastic.py +++ b/chop/stochastic.py @@ -586,7 +586,7 @@ class SplittingProxFW(Optimizer): def __init__(self, params, lmo, prox=None, lr=.1, - lr_prox=.1, + lipschitz =1., momentum=0., weight_decay=0., normalization='none'): params = list(params) @@ -598,7 +598,7 @@ def __init__(self, params, lmo, prox=None, def prox_maker(oracle): if oracle: - def _prox(x, s=None): + def _prox(x, s=None): return oracle(x.unsqueeze(0), s).squeeze(0) else: def _prox(x, s=None): @@ -652,7 +652,7 @@ def _lmo(u, x): name=self.name, momentum=momentum, lr=lr, - lr_prox=lr_prox, + lipschitz=lipschitz, weight_decay=weight_decay, normalization=normalization) @@ -693,16 +693,15 @@ def step(self, closure=None): # initialize grad estimate state['grad_est'] = torch.zeros_like(p) # initialize learning rates - state['lr_prox'] = group['lr_prox'] if type( - group['lr_prox'] == float) else 0. state['lr'] = group['lr'] if type( group['lr'] == float) else 0. + state['lipschitz'] = group['lipschitz'] + state['lr_prox'] = state['lr'] * state['lipschitz'] state['momentum'] = group['momentum'] if type( group['momentum'] == float) else 0. - for lr in ('lr_prox', 'lr'): - if group[lr] == 'sublinear': - state[lr] = 2. / (state['step'] + 2) + if group['lr'] == 'sublinear': + state['lr'] = 2. / (state['step'] + 2) if group['momentum'] == 'sublinear': rho = 4. / (state['step'] + 8.) 
** (2/3) @@ -715,6 +714,7 @@ def step(self, closure=None): y_update, max_step_size = state['lmo']( -state['grad_est'], state['y']) + state['lr_prox'] = state['lr'] * state['lipschitz'] state['lr'] = min(max_step_size, state['lr']) if group['normalization'] == 'gradient': From 227b627c51c70cea08b7c87d557f53f108fa700d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Geoffrey=20N=C3=A9giar?= Date: Mon, 3 May 2021 13:49:56 -0700 Subject: [PATCH 14/47] Removed lr_prox parameter; using lipschitz estimate, consistently with theory --- chop/stochastic.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/chop/stochastic.py b/chop/stochastic.py index 7f7a569..bfa36ba 100644 --- a/chop/stochastic.py +++ b/chop/stochastic.py @@ -586,7 +586,7 @@ class SplittingProxFW(Optimizer): def __init__(self, params, lmo, prox=None, lr=.1, - lipschitz =1., + lipschitz=1., momentum=0., weight_decay=0., normalization='none'): params = list(params) @@ -632,11 +632,9 @@ def _lmo(u, x): f"Please pass this parameter to another optimizer.") warnings.warn(msg) - for name, lr in (('lr', lr), - ('lr_prox', lr_prox)): - if not ((type(lr) == float) or lr == 'sublinear'): - msg = f"{name} should be a float or 'sublinear', got {lr}." - raise ValueError(msg) + if not ((type(lr) == float) or lr == 'sublinear'): + msg = f"lr should be a float or 'sublinear', got {lr}." + raise ValueError(msg) if (momentum != 'sublinear') and (not (0. <= momentum <= 1.)): raise ValueError("momentum must be in [0., 1.] or 'sublinear'.") From ccc405964d3367a326517c429512ecc2b2b93c1f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Geoffrey=20N=C3=A9giar?= Date: Mon, 3 May 2021 23:15:44 -0700 Subject: [PATCH 15/47] state is initialized at every step in hybrid optimizer --- chop/stochastic.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/chop/stochastic.py b/chop/stochastic.py index bfa36ba..3b6ba38 100644 --- a/chop/stochastic.py +++ b/chop/stochastic.py @@ -690,13 +690,14 @@ def step(self, closure=None): state['y'] = state['prox'](.5 * p.detach().clone()) # initialize grad estimate state['grad_est'] = torch.zeros_like(p) - # initialize learning rates - state['lr'] = group['lr'] if type( - group['lr'] == float) else 0. - state['lipschitz'] = group['lipschitz'] - state['lr_prox'] = state['lr'] * state['lipschitz'] - state['momentum'] = group['momentum'] if type( - group['momentum'] == float) else 0. + + # set learning rates + state['lr'] = group['lr'] if type( + group['lr'] == float) else None + state['lipschitz'] = group['lipschitz'] + state['lr_prox'] = state['lr'] * state['lipschitz'] + state['momentum'] = group['momentum'] if type( + group['momentum'] == float) else 0. if group['lr'] == 'sublinear': state['lr'] = 2. 
/ (state['step'] + 2) From 6bf2b138e86b20da6e45ff5f72bac6460117be15 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Geoffrey=20N=C3=A9giar?= Date: Fri, 7 May 2021 01:19:01 -0700 Subject: [PATCH 16/47] Prox/LMO are now Modules, to ensure pickle-ability --- chop/stochastic.py | 91 ++++++++++++++++++++++++++++++++-------------- 1 file changed, 63 insertions(+), 28 deletions(-) diff --git a/chop/stochastic.py b/chop/stochastic.py index 3b6ba38..8876693 100644 --- a/chop/stochastic.py +++ b/chop/stochastic.py @@ -12,6 +12,7 @@ import warnings import torch +from torch import nn from torch.optim import Optimizer import numpy as np @@ -19,6 +20,30 @@ EPS = np.finfo(np.float32).eps +class Prox(nn.Module): + def __init__(self, prox_fun=None): + super().__init__() + self.prox_fun = prox_fun + + def forward(self, x, s=None): + if self.prox_fun is not None: + return self.prox_fun(x.unsqueeze(0)).squeeze(0) + else: + return x + + +class LMO(nn.Module): + def __init__(self, lmo_fun): + super().__init__() + self.lmo_fun = lmo_fun + + def forward(self, u, x): + update_direction, max_step_size = self.lmo_fun( + u.unsqueeze(0), x.unsqueeze(0)) + return update_direction.squeeze(dim=0), max_step_size.squeeze(dim=0) + + + def backtracking_step_size( x, f_t, @@ -578,7 +603,36 @@ def step(self, closure=None): class SplittingProxFW(Optimizer): - # TODO: write docstring! + """ + Stochastic splitting optimization algorithm, using a prox and a LMO primitive. + + Args: + params: + parameters to optimize + + lmo: [callable or None] + LMO oracles corresponding to each parameter in params + + prox: [callable or None] or None + prox oracles corresponding to each parameter in params + + lr: float + learning rate + + lipschitz: float + estimate of the Lipschitz constant of the objective + + momentum: float in [0., 1.] + momentum to apply in the stochastic gradient estimator + + weigth_decay: float > 0 + scale of L2 penalty + + normalization: str + One of {'gradient', 'none'}. Default: 'none'. + If using 'gradient', normalizes the update_direction to have the same magnitude as the gradient, + for the LMO part. + """ name = 'Hybrid Prox FW Splitting' @@ -590,42 +644,23 @@ def __init__(self, params, lmo, prox=None, momentum=0., weight_decay=0., normalization='none'): params = list(params) + # initialize proxes if prox is None: prox = [None] * len(params) + prox_candidates = [Prox(oracle) for oracle in prox] - prox_candidates = [] - - def prox_maker(oracle): - if oracle: - def _prox(x, s=None): - return oracle(x.unsqueeze(0), s).squeeze(0) - else: - def _prox(x, s=None): - return x, s - return _prox - - prox_candidates = [prox_maker(oracle) for oracle in prox] # initialize lmos + lmo_candidates = [LMO(oracle) if oracle else None for oracle in lmo] - def lmo_maker(oracle): - def _lmo(u, x): - update_direction, max_step_size = oracle( - u.unsqueeze(0), x.unsqueeze(0)) - return update_direction.squeeze(dim=0), max_step_size.squeeze(dim=0) - - return _lmo - - lmo_candidates = [lmo_maker(oracle) if oracle else None for oracle in lmo] - - self.lmo = [] - self.prox = [] + lmos = [] + proxes = [] useable_params = [] for param, lmo_oracle, prox_oracle in zip(params, lmo_candidates, prox_candidates): if lmo_oracle is not None: useable_params.append(param) - self.lmo.append(lmo_oracle) - self.prox.append(prox_oracle) + lmos.append(lmo_oracle) + proxes.append(prox_oracle) else: msg = (f"No LMO was provided for parameter {param}. " f"This optimizer will not optimize this parameter. 
" @@ -646,7 +681,7 @@ def _lmo(u, x): if normalization not in self.POSSIBLE_NORMALIZATIONS: raise ValueError( f"Normalization must be in {self.POSSIBLE_NORMALIZATIONS}") - defaults = dict(lmo=self.lmo, prox=self.prox, + defaults = dict(lmo=lmos, prox=proxes, name=self.name, momentum=momentum, lr=lr, From 346dd321dc5a0775bfa77832d4585505551ba712 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Geoffrey=20N=C3=A9giar?= Date: Fri, 7 May 2021 16:24:46 -0700 Subject: [PATCH 17/47] Fixed initialization bug: making y feasible w/ correct projection --- chop/stochastic.py | 31 ++++++++++++++++++++----------- 1 file changed, 20 insertions(+), 11 deletions(-) diff --git a/chop/stochastic.py b/chop/stochastic.py index 8876693..847c7a5 100644 --- a/chop/stochastic.py +++ b/chop/stochastic.py @@ -611,10 +611,14 @@ class SplittingProxFW(Optimizer): parameters to optimize lmo: [callable or None] - LMO oracles corresponding to each parameter in params + LMO oracles corresponding to each parameter in params. Applies to the y variable. - prox: [callable or None] or None - prox oracles corresponding to each parameter in params + prox1: [callable or None] or None + prox oracles corresponding to each parameter in params. This one is for the x variable. + + prox2: [callable or None] or None + prox oracles corresponding to each parameter in params. + Only used for initializing y to be feasible. lr: float learning rate @@ -638,7 +642,7 @@ class SplittingProxFW(Optimizer): POSSIBLE_NORMALIZATIONS = {'none', 'gradient'} - def __init__(self, params, lmo, prox=None, + def __init__(self, params, lmo, prox1=None, prox2=None, lr=.1, lipschitz=1., momentum=0., weight_decay=0., @@ -646,23 +650,26 @@ def __init__(self, params, lmo, prox=None, params = list(params) # initialize proxes - if prox is None: - prox = [None] * len(params) - prox_candidates = [Prox(oracle) for oracle in prox] + if prox1 is None: + prox1 = [None] * len(params) + prox_candidates = [Prox(oracle) for oracle in prox1] # initialize lmos lmo_candidates = [LMO(oracle) if oracle else None for oracle in lmo] + prox_y = [Prox(oracle) for oracle in prox2] lmos = [] proxes = [] useable_params = [] - for param, lmo_oracle, prox_oracle in zip(params, lmo_candidates, prox_candidates): + proxes_y = [] + for k, (param, lmo_oracle, prox_oracle, prox_y_oracle) in enumerate(zip(params, lmo_candidates, prox_candidates, prox_y)): if lmo_oracle is not None: useable_params.append(param) lmos.append(lmo_oracle) proxes.append(prox_oracle) + proxes_y.append(prox_y_oracle) else: - msg = (f"No LMO was provided for parameter {param}. " + msg = (f"No LMO was provided for parameter {k}. " f"This optimizer will not optimize this parameter. " f"Please pass this parameter to another optimizer.") warnings.warn(msg) @@ -682,6 +689,7 @@ def __init__(self, params, lmo, prox=None, raise ValueError( f"Normalization must be in {self.POSSIBLE_NORMALIZATIONS}") defaults = dict(lmo=lmos, prox=proxes, + prox_y=proxes_y, name=self.name, momentum=momentum, lr=lr, @@ -719,16 +727,17 @@ def step(self, closure=None): if len(state) == 0: state['step'] = 0. 
state['prox'] = group['prox'][idx] + state['prox_y'] = group['prox_y'][idx] state['lmo'] = group['lmo'][idx] # split variable: p = x + y and make feasible state['x'] = state['prox'](.5 * p.detach().clone()) - state['y'] = state['prox'](.5 * p.detach().clone()) + state['y'] = state['prox_y'](.5 * p.detach().clone()) # initialize grad estimate state['grad_est'] = torch.zeros_like(p) # set learning rates state['lr'] = group['lr'] if type( - group['lr'] == float) else None + group['lr'] == float) else None state['lipschitz'] = group['lipschitz'] state['lr_prox'] = state['lr'] * state['lipschitz'] state['momentum'] = group['momentum'] if type( From 42007102a1e744314b497fc1795ec209b1085040 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Geoffrey=20N=C3=A9giar?= Date: Sat, 8 May 2021 01:46:36 -0700 Subject: [PATCH 18/47] Fixed stochastic robust PCA example + code snippet for extracting components --- examples/plot_stochastic_robust_PCA.py | 53 ++++++++++++++++++-------- 1 file changed, 38 insertions(+), 15 deletions(-) diff --git a/examples/plot_stochastic_robust_PCA.py b/examples/plot_stochastic_robust_PCA.py index ada42ac..9d1b3a8 100644 --- a/examples/plot_stochastic_robust_PCA.py +++ b/examples/plot_stochastic_robust_PCA.py @@ -19,6 +19,8 @@ import chop from chop import utils from chop.utils.logging import Trace +from time import time + torch.manual_seed(0) @@ -67,15 +69,16 @@ def sqloss(Z, M): lmo = rank_constraint.lmo prox = sparsity_constraint.prox + prox_lr = rank_constraint.prox Z = torch.zeros_like(M, device=device) Z.requires_grad_(True) batch_sizes = [100, 250, 500, 1000] - fig, axes = plt.subplots(ncols=len(batch_sizes), figsize=(18, 10), sharey=True) + fig, axes = plt.subplots(nrows=2, ncols=len(batch_sizes), figsize=(18, 10), sharey=True) fig.suptitle(f'r={r} and p={p}') - for batch_size, ax in zip(batch_sizes, axes): + for batch_size, ax_it, ax_time in zip(batch_sizes, axes[0], axes[1]): print(f"Batch size: {batch_size}") sampler = torch.utils.data.BatchSampler(torch.utils.data.RandomSampler(range(M.size(0))), batch_size=batch_size, @@ -83,35 +86,55 @@ def sqloss(Z, M): momentum = .9 if batch_size != 1000 else 0. 
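Note: a rough sketch (shapes assumed) of the batch convention followed by the prox/lmo oracles this example passes to the optimizer; the Prox/LMO wrapper modules add and drop the leading batch dimension internally.

import torch
import chop

A = torch.randn(50, 40)
ball = chop.constraints.NuclearNormBall(10.)
A_feas = ball.prox(A.unsqueeze(0)).squeeze(0)       # project a single matrix onto the ball
direction, max_step = ball.lmo(-A.unsqueeze(0), A_feas.unsqueeze(0))  # batched LMO call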
- optimizer = chop.stochastic.SplittingProxFW([Z], lmo=[lmo], prox=[prox], - lr_lmo='sublinear', - lr_prox='sublinear', + optimizer = chop.stochastic.SplittingProxFW([Z], lmo=[lmo], + prox1=[prox], + prox2=[prox_lr], + lr='sublinear', + lipschitz=1., normalization='none', - # weight_decay=1e-8, momentum=momentum) train_losses = [] + start = time() + times = [] losses = [] sgrad_avg = 0 n_it = 0 + freq = 10 for it in range(n_epochs): for idx in sampler: n_it += 1 optimizer.zero_grad() loss = sqloss(Z[idx], M[idx]) - train_losses.append(loss.item()) loss.backward() sgrad = Z.grad.detach().clone() sgrad_avg += sgrad # for logging - with torch.no_grad(): - full_loss = sqloss(Z, M) - losses.append(full_loss.item()) + if n_it % freq == 0: + with torch.no_grad(): + times.append(time() - start) + full_loss = sqloss(Z, M) + train_losses.append(loss.item()) + losses.append(full_loss.item()) optimizer.step() - ax.set_title(f"b={batch_size}") - ax.plot(train_losses, label='training_losses') - ax.plot(losses, label='loss') - ax.set_yscale('log') - ax.legend() + + # Get sparse and LR component + state = optimizer.state[Z] + sparse_comp = state['x'] + lr_comp = state['y'] + + # Plots + ax_it.set_title(f"b={batch_size}") + ax_it.plot(train_losses, label='mini-batch loss') + ax_it.plot(losses, label='full loss') + ax_it.set_xlabel('iterations') + ax_time.plot(times, train_losses, label='minibatch loss') + ax_time.plot(times, losses, label='full loss') + ax_time.set_xlabel('time (s)') + ax_it.set_yscale('log') + ax_it.legend() + fig.show() + fig.savefig("robustPCA.png") + break print("Done.") From 22a6c192f145bb247df2b728141b0aeef970569a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Geoffrey=20N=C3=A9giar?= Date: Sat, 8 May 2021 01:47:24 -0700 Subject: [PATCH 19/47] Minor fixes --- chop/stochastic.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/chop/stochastic.py b/chop/stochastic.py index 847c7a5..f27df47 100644 --- a/chop/stochastic.py +++ b/chop/stochastic.py @@ -735,17 +735,17 @@ def step(self, closure=None): # initialize grad estimate state['grad_est'] = torch.zeros_like(p) - # set learning rates - state['lr'] = group['lr'] if type( - group['lr'] == float) else None + # set state parameters state['lipschitz'] = group['lipschitz'] - state['lr_prox'] = state['lr'] * state['lipschitz'] state['momentum'] = group['momentum'] if type( group['momentum'] == float) else 0. if group['lr'] == 'sublinear': state['lr'] = 2. / (state['step'] + 2) + elif type(group['lr']) == float: + state['lr'] = group['lr'] + if group['momentum'] == 'sublinear': rho = 4. / (state['step'] + 8.) ** (2/3) state['momentum'] = 1. 
- rho From 6cb9a32dca543e6626707cf64a8baf1633b1e27b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Geoffrey=20N=C3=A9giar?= Date: Sat, 8 May 2021 17:14:19 -0700 Subject: [PATCH 20/47] Fix when prox2 is None --- chop/stochastic.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/chop/stochastic.py b/chop/stochastic.py index f27df47..aa74621 100644 --- a/chop/stochastic.py +++ b/chop/stochastic.py @@ -652,6 +652,8 @@ def __init__(self, params, lmo, prox1=None, prox2=None, # initialize proxes if prox1 is None: prox1 = [None] * len(params) + if prox2 is None: + prox2 = [None] * len(params) prox_candidates = [Prox(oracle) for oracle in prox1] # initialize lmos From fb6ec7d2f78f441c7e007db268cc829669effc30 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Geoffrey=20N=C3=A9giar?= Date: Sun, 9 May 2021 21:03:39 -0700 Subject: [PATCH 21/47] Fixed view/reshape issue --- chop/constraints.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/chop/constraints.py b/chop/constraints.py index 4de00db..28b6d13 100644 --- a/chop/constraints.py +++ b/chop/constraints.py @@ -368,11 +368,11 @@ def prox(self, x, step_size=None): projection of x onto the L1 ball. """ shape = x.shape - flattened_x = x.view(shape[0], -1) + flattened_x = x.reshape(shape[0], -1) # TODO vectorize this projected = [euclidean_proj_l1ball(row, s=self.alpha) for row in flattened_x] x = torch.stack(projected) - return x.view(*shape) + return x.reshape(*shape) class L2Ball(LpBall): From 28eff8104b12b2b8f094224e7118f06fb26e52b3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Geoffrey=20N=C3=A9giar?= Date: Sun, 9 May 2021 23:42:34 -0700 Subject: [PATCH 22/47] Allowing 0 constraints --- chop/constraints.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/chop/constraints.py b/chop/constraints.py index 28b6d13..f438139 100644 --- a/chop/constraints.py +++ b/chop/constraints.py @@ -158,13 +158,13 @@ def euclidean_proj_l1ball(v, s=1.): Args: - v: (n,) numpy array, + v: (n,) torch tensor, n-dimensional vector to project s: float, optional, default: 1, radius of the L1-ball Returns: - w: (n,) numpy array, + w: (n,) torch tensor, Euclidean projection of v on the L1-ball of radius s Notes ----- @@ -173,9 +173,11 @@ def euclidean_proj_l1ball(v, s=1.): -------- euclidean_proj_simplex """ - assert s > 0, "Radius s must be strictly positive (%d <= 0)" % s + assert s >= 0, "Radius s must be strictly positive (%d <= 0)" % s if len(v.shape) > 1: raise ValueError + if s == 0: + return torch.zeros_like(v) # compute the vector of absolute values u = abs(v) # check if v is already a solution From 34b99a7e6acc01f8e52df134f8effd4de4d44cbf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Geoffrey=20N=C3=A9giar?= Date: Thu, 13 May 2021 19:09:59 -0700 Subject: [PATCH 23/47] Changed dataloader + loss for efficiency --- examples/plot_stochastic_robust_PCA.py | 41 +++++++++++++------------- 1 file changed, 21 insertions(+), 20 deletions(-) diff --git a/examples/plot_stochastic_robust_PCA.py b/examples/plot_stochastic_robust_PCA.py index 9d1b3a8..335a72e 100644 --- a/examples/plot_stochastic_robust_PCA.py +++ b/examples/plot_stochastic_robust_PCA.py @@ -34,7 +34,9 @@ (25, 3e-2), (130, 1e-2) ] -n_epochs = 400 +n_epochs = 200 + +sqloss = torch.nn.MSELoss() for r, p in r_p: print(f'r={r} and p={p}') @@ -42,22 +44,19 @@ V = torch.normal(torch.zeros(r, n)) # Low rank component - L = 10 * utils.bmm(U, V) + L = 10 * utils.bmm(U, V).to(device) # Sparse component - S = 100 * torch.normal(torch.zeros(m, n)) + S = 100 * torch.normal(torch.zeros(m, n)).to(device) 
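Note: patch 22 above makes a zero radius legal in euclidean_proj_l1ball; a quick illustrative check with made-up values.

import torch
import chop

v = torch.tensor([3., -1., .5])
proj = chop.constraints.euclidean_proj_l1ball(v, s=2.)   # usual projection onto the L1 ball
zero = chop.constraints.euclidean_proj_l1ball(v, s=0.)   # now simply returns torch.zeros_like(v)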
S *= (torch.rand_like(S) <= p) # Add noise - N = torch.normal(torch.zeros(m, n)) + N = torch.normal(torch.zeros(m, n)).to(device) M = L + S + N M = M.to(device) - def sqloss(Z, M): - return .5 / Z.numel() * torch.linalg.norm((Z - M).squeeze(), ord='fro') ** 2 - rnuc = torch.linalg.norm(L.squeeze(), ord='nuc') sL1 = abs(S).sum() @@ -71,20 +70,17 @@ def sqloss(Z, M): prox = sparsity_constraint.prox prox_lr = rank_constraint.prox - Z = torch.zeros_like(M, device=device) - Z.requires_grad_(True) - batch_sizes = [100, 250, 500, 1000] fig, axes = plt.subplots(nrows=2, ncols=len(batch_sizes), figsize=(18, 10), sharey=True) fig.suptitle(f'r={r} and p={p}') for batch_size, ax_it, ax_time in zip(batch_sizes, axes[0], axes[1]): + Z = torch.zeros_like(M, device=device) + Z.requires_grad_(True) print(f"Batch size: {batch_size}") - sampler = torch.utils.data.BatchSampler(torch.utils.data.RandomSampler(range(M.size(0))), - batch_size=batch_size, - drop_last=True) - - momentum = .9 if batch_size != 1000 else 0. + dataset = torch.utils.data.TensorDataset(Z, M) + loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True) + momentum = .9 if batch_size != m else 0. optimizer = chop.stochastic.SplittingProxFW([Z], lmo=[lmo], prox1=[prox], @@ -94,18 +90,19 @@ def sqloss(Z, M): normalization='none', momentum=momentum) + train_losses = [] - start = time() times = [] losses = [] sgrad_avg = 0 n_it = 0 freq = 10 + start = time() for it in range(n_epochs): - for idx in sampler: + for zi, mi in loader: n_it += 1 optimizer.zero_grad() - loss = sqloss(Z[idx], M[idx]) + loss = sqloss(zi, mi) loss.backward() sgrad = Z.grad.detach().clone() sgrad_avg += sgrad @@ -114,9 +111,11 @@ def sqloss(Z, M): with torch.no_grad(): times.append(time() - start) full_loss = sqloss(Z, M) + print(full_loss) train_losses.append(loss.item()) losses.append(full_loss.item()) optimizer.step() + donetime = time() # Get sparse and LR component state = optimizer.state[Z] @@ -133,8 +132,10 @@ def sqloss(Z, M): ax_time.set_xlabel('time (s)') ax_it.set_yscale('log') ax_it.legend() - + print(f"Low rank loss: {torch.linalg.norm(L - lr_comp) / torch.linalg.norm(L)}") + print(f"Sparse loss: {torch.linalg.norm(S - sparse_comp) / torch.linalg.norm(S)}") + print(f"Reconstruction loss: {torch.linalg.norm(M - sparse_comp - lr_comp) / torch.linalg.norm(M)}") + print(f"Time: {times[-1]}s") fig.show() fig.savefig("robustPCA.png") - break print("Done.") From 9afbd5d84e65684e7d7b30588a75b5c921df2db8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Geoffrey=20N=C3=A9giar?= Date: Thu, 13 May 2021 19:11:17 -0700 Subject: [PATCH 24/47] Removed debug statements --- examples/plot_stochastic_robust_PCA.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/examples/plot_stochastic_robust_PCA.py b/examples/plot_stochastic_robust_PCA.py index 335a72e..ee23cf7 100644 --- a/examples/plot_stochastic_robust_PCA.py +++ b/examples/plot_stochastic_robust_PCA.py @@ -111,7 +111,6 @@ with torch.no_grad(): times.append(time() - start) full_loss = sqloss(Z, M) - print(full_loss) train_losses.append(loss.item()) losses.append(full_loss.item()) optimizer.step() @@ -137,5 +136,4 @@ print(f"Reconstruction loss: {torch.linalg.norm(M - sparse_comp - lr_comp) / torch.linalg.norm(M)}") print(f"Time: {times[-1]}s") fig.show() - fig.savefig("robustPCA.png") print("Done.") From e95dbeddeb8b8ec6b2a1ba9663470c046787b2f6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Geoffrey=20N=C3=A9giar?= Date: Thu, 13 May 2021 22:24:46 -0700 Subject: [PATCH 25/47] L1 penalty prox w/ stochastic hybrid 
splitting --- chop/penalties.py | 46 ++++++++++++++++++++++++++++++++++++++++++++-- chop/stochastic.py | 6 +++--- 2 files changed, 47 insertions(+), 5 deletions(-) diff --git a/chop/penalties.py b/chop/penalties.py index 5320cf9..e17ada2 100644 --- a/chop/penalties.py +++ b/chop/penalties.py @@ -19,7 +19,7 @@ class L1: - """L1 penalty. Batch-wise function. For each element in the batch, + """L1 Norm penalty. Batch-wise function. For each element in the batch, the L1 penalty is given by ..math:: \Omega(x) = \alpha \|x\|_1 @@ -46,7 +46,49 @@ def __call__(self, x): batch_size = x.size(0) return self.alpha * abs(x.view(batch_size, -1)).sum(dim=-1) - def prox(self, x, step_size=None): + def prox(self, x, step_size): + """Proximal operator for the L1 norm penalty. This is given by soft-thresholding. + + Args: + x: torch.Tensor + x has shape (batch_size, *) + step_size: float or torch.Tensor of shape (batch_size,) + + """ + if isinstance(step_size, Number): + step_size = step_size * torch.ones(x.size(0), device=x.device, dtype=x.dtype) + return utils.bmul(torch.sign(x), F.relu(abs(x) - self.alpha * step_size.view((-1,) + (1,) * (x.dim() - 1)))) + + +class NuclearNorm: + """Nuclear Norm penalty. Batch-wise function. For each element in the batch, + the L1 penalty is given by + ..math:: + \Omega(X) = \alpha \|X\|_* + """ + + def __init__(self, alpha: float): + """ + Args: + alpha: float + Size of the penalty. Must be non-negative. + """ + if alpha < 0: + raise ValueError("alpha must be non negative.") + self.alpha = alpha + + def __call__(self, x): + """ + Returns the value of the penalty on x, batch_size. + + Args: + x: torch.Tensor + x has shape (batch_size, *) + """ + batch_size = x.size(0) + return self.alpha * abs(x.view(batch_size, -1)).sum(dim=-1) + + def prox(self, x, step_size): """Proximal operator for the L1 norm penalty. This is given by soft-thresholding. Args: diff --git a/chop/stochastic.py b/chop/stochastic.py index aa74621..ef5f776 100644 --- a/chop/stochastic.py +++ b/chop/stochastic.py @@ -27,7 +27,7 @@ def __init__(self, prox_fun=None): def forward(self, x, s=None): if self.prox_fun is not None: - return self.prox_fun(x.unsqueeze(0)).squeeze(0) + return self.prox_fun(x.unsqueeze(0), s).squeeze(0) else: return x @@ -732,8 +732,8 @@ def step(self, closure=None): state['prox_y'] = group['prox_y'][idx] state['lmo'] = group['lmo'][idx] # split variable: p = x + y and make feasible - state['x'] = state['prox'](.5 * p.detach().clone()) - state['y'] = state['prox_y'](.5 * p.detach().clone()) + state['x'] = state['prox'](.5 * p.detach().clone(), 1.) + state['y'] = state['prox_y'](.5 * p.detach().clone(), 1.) # initialize grad estimate state['grad_est'] = torch.zeros_like(p) From 601b7386348b775a29507ddc802e6954defca633 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Geoffrey=20N=C3=A9giar?= Date: Fri, 14 May 2021 00:39:05 -0700 Subject: [PATCH 26/47] Penalized version of RobustPCA w/ generalized LMO --- chop/constraints.py | 16 ++++----- chop/penalties.py | 45 +++++++++++++++++++++----- chop/stochastic.py | 16 ++++++--- examples/plot_stochastic_robust_PCA.py | 20 +++++++++--- 4 files changed, 72 insertions(+), 25 deletions(-) diff --git a/chop/constraints.py b/chop/constraints.py index f438139..2c71ca7 100644 --- a/chop/constraints.py +++ b/chop/constraints.py @@ -274,7 +274,7 @@ def prox(self, x, step_size=None): return torch.clamp(x, min=-self.alpha, max=self.alpha) @torch.no_grad() - def lmo(self, grad, iterate): + def lmo(self, grad, iterate, step_size=None): """Linear Maximization Oracle. 
Return s - iterate with s solving the linear problem @@ -323,7 +323,7 @@ class L1Ball(LpBall): p = 1 @torch.no_grad() - def lmo(self, grad, iterate): + def lmo(self, grad, iterate, step_size=None): """Linear Maximization Oracle. Return s - iterate with s solving the linear problem @@ -402,7 +402,7 @@ def prox(self, x, step_size=None): @torch.no_grad() - def lmo(self, grad, iterate): + def lmo(self, grad, iterate, step_size=None): """Linear Maximization Oracle. Return s - iterate with s solving the linear problem @@ -458,7 +458,7 @@ def prox(self, x, step_size=None): return x.view(*shape) @torch.no_grad() - def lmo(self, grad, iterate): + def lmo(self, grad, iterate, step_size=None): batch_size = grad.size(0) shape = iterate.shape max_vals, max_idx = grad.reshape(batch_size, -1).max(-1) @@ -488,7 +488,7 @@ def __init__(self, alpha): self.alpha = alpha @torch.no_grad() - def lmo(self, grad, iterate): + def lmo(self, grad, iterate, step_size=None): """ Computes the LMO for the Nuclear Norm Ball on the last two dimensions. Returns :math: `s - $iterate$` where @@ -504,8 +504,8 @@ def lmo(self, grad, iterate): """ update_direction = -iterate.clone().detach() u, _, v = utils.power_iteration(grad) - outer = u.unsqueeze(-1) * v.unsqueeze(-2) - update_direction += self.alpha * outer + atom = u.unsqueeze(-1) * v.unsqueeze(-2) + update_direction += self.alpha * atom return update_direction, torch.ones(iterate.size(0), device=iterate.device, dtype=iterate.dtype) @torch.no_grad() @@ -558,7 +558,7 @@ def get_group_norms(self, x): return group_norms @torch.no_grad() - def lmo(self, grad, iterate): + def lmo(self, grad, iterate, step_size=None): update_direction = -iterate.detach().clone() # find group with largest L2 norm group_norms = self.get_group_norms(grad) diff --git a/chop/penalties.py b/chop/penalties.py index e17ada2..4d60044 100644 --- a/chop/penalties.py +++ b/chop/penalties.py @@ -16,8 +16,18 @@ import torch.nn.functional as F from chop import utils +from chop import constraints +def penalty_lmo_step(atom, iterate, grad, kwargs): + sqfronorm_atom = torch.linalg.norm(atom, ord='fro', dim=(-2, -1)) ** 2 + step = (atom * iterate).sum(dim=(-2, -1)) \ + - 2 * (self.alpha - kwargs['step_size'] * (atom * grad).sum(dim=(-2, -1))) \ + / (kwargs['lipschitz'] * kwargs['step_size'] ** 2) + step = utils.bdiv(step, sqfronorm_atom) + return step + + class L1: """L1 Norm penalty. Batch-wise function. For each element in the batch, the L1 penalty is given by @@ -79,27 +89,46 @@ def __init__(self, alpha: float): def __call__(self, x): """ - Returns the value of the penalty on x, batch_size. + Returns the value of the penalty on x, batch wize. Args: x: torch.Tensor x has shape (batch_size, *) """ - batch_size = x.size(0) - return self.alpha * abs(x.view(batch_size, -1)).sum(dim=-1) + return self.alpha * torch.linalg.norm(x, ord='nuc', dim=(-2, -1)) + @torch.no_grad() def prox(self, x, step_size): - """Proximal operator for the L1 norm penalty. This is given by soft-thresholding. + """Proximal operator for the Nuclear norm penalty. This is given by soft-thresholding of the singular values. 
Args: x: torch.Tensor - x has shape (batch_size, *) - step_size: float or torch.Tensor of shape (batch_size,) + x has shape (*batch_sizes, m, n) + step_size: float or torch.Tensor of shape (*batch_sizes,) """ + *batch_sizes, m, n = x.shape + if not batch_sizes: + batch_sizes = [1] if isinstance(step_size, Number): - step_size = step_size * torch.ones(x.size(0), device=x.device, dtype=x.dtype) - return utils.bmul(torch.sign(x), F.relu(abs(x) - self.alpha * step_size.view((-1,) + (1,) * (x.dim() - 1)))) + step_size = step_size * torch.ones(*batch_sizes, device=x.device, dtype=x.dtype) + U, S, V = torch.linalg.svd(x) + L1penalty = L1(self.alpha) + S_thresh = L1penalty.prox(S, step_size) + VT = V.transpose(-2, -1) + return U @ torch.diag_embed(S_thresh) @ VT + + @torch.no_grad() + def lmo(self, grad, iterate, **kwargs): + """Generalized LMO for the Nuclear norm penalty""" + batch_sizes = grad.shape[:-2] + if not batch_sizes: + batch_sizes = [1] + ball = constraints.NuclearNormBall(1.) + update_direction, _ = ball.lmo(grad, iterate) + atom = update_direction + iterate + step = penalty_lmo_step(atom, iterate, grad, **kwargs) + return step * atom - iterate, torch.ones(*batch_sizes, dtype=iterate.dtype) class GroupL1: diff --git a/chop/stochastic.py b/chop/stochastic.py index ef5f776..93992ce 100644 --- a/chop/stochastic.py +++ b/chop/stochastic.py @@ -646,7 +646,7 @@ def __init__(self, params, lmo, prox1=None, prox2=None, lr=.1, lipschitz=1., momentum=0., weight_decay=0., - normalization='none'): + normalization='none', generalized_lmo=False): params = list(params) # initialize proxes @@ -697,7 +697,8 @@ def __init__(self, params, lmo, prox1=None, prox2=None, lr=lr, lipschitz=lipschitz, weight_decay=weight_decay, - normalization=normalization) + normalization=normalization, + generalized_lmo=generalized_lmo) super(SplittingProxFW, self).__init__(useable_params, defaults) @@ -756,8 +757,15 @@ def step(self, closure=None): state['grad_est'].add_( grad - state['grad_est'], alpha=1. - state['momentum']) - y_update, max_step_size = state['lmo']( - -state['grad_est'], state['y']) + if group['generalized_lmo']: + y_update, max_step_size = state['lmo']( + -state['grad_est'], state['y'], + {'lipschitz': state['lipschitz'], + 'step_size': state['lr']} + ) + else: + y_update, max_step_size = state['lmo']( + -state['grad_est'], state['y']) state['lr_prox'] = state['lr'] * state['lipschitz'] state['lr'] = min(max_step_size, state['lr']) diff --git a/examples/plot_stochastic_robust_PCA.py b/examples/plot_stochastic_robust_PCA.py index ee23cf7..091d089 100644 --- a/examples/plot_stochastic_robust_PCA.py +++ b/examples/plot_stochastic_robust_PCA.py @@ -36,7 +36,9 @@ n_epochs = 200 -sqloss = torch.nn.MSELoss() +sqloss = torch.nn.MSELoss(reduction='sum') +lam = 1e6 +freq = 50 for r, p in r_p: print(f'r={r} and p={p}') @@ -56,6 +58,9 @@ M = L + S + N M = M.to(device) + + # From Candès paper + mu = (m * n) / (8 * torch.linalg.norm(M, ord='fro') ** 2) rnuc = torch.linalg.norm(L.squeeze(), ord='nuc') sL1 = abs(S).sum() @@ -64,10 +69,13 @@ print(f"Initial Nuclear norm: {rnuc}") rank_constraint = chop.constraints.NuclearNormBall(rnuc) - sparsity_constraint = chop.constraints.L1Ball(sL1) + rank_penalty = chop.penalties.NuclearNorm(1.) 
+ # sparsity_constraint = chop.constraints.L1Ball(sL1) + sparsity_penalty = chop.penalties.L1(lam) lmo = rank_constraint.lmo - prox = sparsity_constraint.prox + # prox = sparsity_constraint.prox + prox = sparsity_penalty.prox prox_lr = rank_constraint.prox batch_sizes = [100, 250, 500, 1000] @@ -96,13 +104,12 @@ losses = [] sgrad_avg = 0 n_it = 0 - freq = 10 start = time() for it in range(n_epochs): for zi, mi in loader: n_it += 1 optimizer.zero_grad() - loss = sqloss(zi, mi) + loss = mu * sqloss(zi, mi) loss.backward() sgrad = Z.grad.detach().clone() sgrad_avg += sgrad @@ -111,6 +118,7 @@ with torch.no_grad(): times.append(time() - start) full_loss = sqloss(Z, M) + print(full_loss) train_losses.append(loss.item()) losses.append(full_loss.item()) optimizer.step() @@ -135,5 +143,7 @@ print(f"Sparse loss: {torch.linalg.norm(S - sparse_comp) / torch.linalg.norm(S)}") print(f"Reconstruction loss: {torch.linalg.norm(M - sparse_comp - lr_comp) / torch.linalg.norm(M)}") print(f"Time: {times[-1]}s") + break fig.show() + fig.savefig("robustPCA.png") print("Done.") From 633eee601f89548522a3624fd5eb7e1464de351d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Geoffrey=20N=C3=A9giar?= Date: Fri, 14 May 2021 19:36:32 -0700 Subject: [PATCH 27/47] Hybrid Prox method now works for penalty LMO --- chop/penalties.py | 33 ++++++++++++++++---------- chop/stochastic.py | 31 +++++++++++++++--------- examples/plot_stochastic_robust_PCA.py | 27 +++++++++++---------- 3 files changed, 55 insertions(+), 36 deletions(-) diff --git a/chop/penalties.py b/chop/penalties.py index 4d60044..ef0383b 100644 --- a/chop/penalties.py +++ b/chop/penalties.py @@ -19,14 +19,6 @@ from chop import constraints -def penalty_lmo_step(atom, iterate, grad, kwargs): - sqfronorm_atom = torch.linalg.norm(atom, ord='fro', dim=(-2, -1)) ** 2 - step = (atom * iterate).sum(dim=(-2, -1)) \ - - 2 * (self.alpha - kwargs['step_size'] * (atom * grad).sum(dim=(-2, -1))) \ - / (kwargs['lipschitz'] * kwargs['step_size'] ** 2) - step = utils.bdiv(step, sqfronorm_atom) - return step - class L1: """L1 Norm penalty. Batch-wise function. For each element in the batch, @@ -56,6 +48,7 @@ def __call__(self, x): batch_size = x.size(0) return self.alpha * abs(x.view(batch_size, -1)).sum(dim=-1) + @torch.no_grad() def prox(self, x, step_size): """Proximal operator for the L1 norm penalty. This is given by soft-thresholding. @@ -69,6 +62,22 @@ def prox(self, x, step_size): step_size = step_size * torch.ones(x.size(0), device=x.device, dtype=x.dtype) return utils.bmul(torch.sign(x), F.relu(abs(x) - self.alpha * step_size.view((-1,) + (1,) * (x.dim() - 1)))) + @torch.no_grad() + def lmo(self, grad, iterate, splitting=False, **kwargs): + """Generalized LMO for the Nuclear norm penalty""" + batch_sizes = grad.shape[:-2] + if not batch_sizes: + batch_sizes = [1] + at = self.atom(grad) + step = penalty_lmo_step(self.alpha, at, splitting, **kwargs) + return step * at - iterate, torch.ones(*batch_sizes, dtype=iterate.dtype) + + @torch.no_grad() + def atom(self, x): + ball = constraints.L1Ball(1.) + update_direction, _ = ball.lmo(grad, iterate) + return update_direction + iterate + class NuclearNorm: """Nuclear Norm penalty. Batch-wise function. For each element in the batch, @@ -119,16 +128,16 @@ def prox(self, x, step_size): return U @ torch.diag_embed(S_thresh) @ VT @torch.no_grad() - def lmo(self, grad, iterate, **kwargs): - """Generalized LMO for the Nuclear norm penalty""" + def lmo(self, grad, iterate): + """Generalized LMO for the Nuclear norm penalty. 
+ This function returns an atom in the constraint set, most aligned with grad.""" batch_sizes = grad.shape[:-2] if not batch_sizes: batch_sizes = [1] ball = constraints.NuclearNormBall(1.) update_direction, _ = ball.lmo(grad, iterate) atom = update_direction + iterate - step = penalty_lmo_step(atom, iterate, grad, **kwargs) - return step * atom - iterate, torch.ones(*batch_sizes, dtype=iterate.dtype) + return atom, self.alpha * torch.ones(*batch_sizes, dtype=grad.dtype, device=grad.device) class GroupL1: diff --git a/chop/stochastic.py b/chop/stochastic.py index 93992ce..5f42917 100644 --- a/chop/stochastic.py +++ b/chop/stochastic.py @@ -14,6 +14,7 @@ import torch from torch import nn from torch.optim import Optimizer +import torch.nn.functional as F import numpy as np @@ -43,7 +44,6 @@ def forward(self, u, x): return update_direction.squeeze(dim=0), max_step_size.squeeze(dim=0) - def backtracking_step_size( x, f_t, @@ -646,7 +646,8 @@ def __init__(self, params, lmo, prox1=None, prox2=None, lr=.1, lipschitz=1., momentum=0., weight_decay=0., - normalization='none', generalized_lmo=False): + normalization='none', generalized_lmo=False, + ): params = list(params) # initialize proxes @@ -698,8 +699,8 @@ def __init__(self, params, lmo, prox1=None, prox2=None, lipschitz=lipschitz, weight_decay=weight_decay, normalization=normalization, - generalized_lmo=generalized_lmo) - + generalized_lmo=generalized_lmo + ) super(SplittingProxFW, self).__init__(useable_params, defaults) @torch.no_grad() @@ -756,19 +757,27 @@ def step(self, closure=None): state['step'] += 1. state['grad_est'].add_( grad - state['grad_est'], alpha=1. - state['momentum']) + state['lr_prox'] = state['lr'] * state['lipschitz'] if group['generalized_lmo']: - y_update, max_step_size = state['lmo']( - -state['grad_est'], state['y'], - {'lipschitz': state['lipschitz'], - 'step_size': state['lr']} + state['lr_prox'] *= 2 + atom, scale = state['lmo']( + -state['grad_est'], state['y'] ) + magnitude = (atom * (.5 * state['grad_est'] / (state['lipschitz'] * state['lr']) - state['y'])).sum() - scale + from icecream import ic + # ic(magnitude) + magnitude /= torch.linalg.norm(atom) ** 2 + magnitude = F.relu(magnitude) + w = magnitude * atom + v = state['prox']( + state['x'] - w - state['grad_est'] / state['lr_prox'], state['lr_prox']) + y_update = w - state['y'] + else: y_update, max_step_size = state['lmo']( -state['grad_est'], state['y']) - - state['lr_prox'] = state['lr'] * state['lipschitz'] - state['lr'] = min(max_step_size, state['lr']) + state['lr'] = min(max_step_size, state['lr']) if group['normalization'] == 'gradient': # Normalize LMO update direction diff --git a/examples/plot_stochastic_robust_PCA.py b/examples/plot_stochastic_robust_PCA.py index 091d089..4c87aa5 100644 --- a/examples/plot_stochastic_robust_PCA.py +++ b/examples/plot_stochastic_robust_PCA.py @@ -18,9 +18,8 @@ import torch import chop from chop import utils -from chop.utils.logging import Trace from time import time - +import numpy as np torch.manual_seed(0) @@ -34,11 +33,11 @@ (25, 3e-2), (130, 1e-2) ] -n_epochs = 200 +n_epochs = 1000 sqloss = torch.nn.MSELoss(reduction='sum') -lam = 1e6 -freq = 50 +lam = 1. 
/ np.sqrt(m) +freq = 100 for r, p in r_p: print(f'r={r} and p={p}') @@ -60,7 +59,7 @@ M = M.to(device) # From Candès paper - mu = (m * n) / (8 * torch.linalg.norm(M, ord='fro') ** 2) + mu = (m * n) / (8 * torch.linalg.norm(M.view(-1), ord=1)) rnuc = torch.linalg.norm(L.squeeze(), ord='nuc') sL1 = abs(S).sum() @@ -70,13 +69,14 @@ rank_constraint = chop.constraints.NuclearNormBall(rnuc) rank_penalty = chop.penalties.NuclearNorm(1.) - # sparsity_constraint = chop.constraints.L1Ball(sL1) + sparsity_constraint = chop.constraints.L1Ball(sL1) sparsity_penalty = chop.penalties.L1(lam) - lmo = rank_constraint.lmo + lmo = rank_penalty.lmo + # lmo = rank_constraint.lmo # prox = sparsity_constraint.prox prox = sparsity_penalty.prox - prox_lr = rank_constraint.prox + prox_lr = rank_penalty.prox batch_sizes = [100, 250, 500, 1000] fig, axes = plt.subplots(nrows=2, ncols=len(batch_sizes), figsize=(18, 10), sharey=True) @@ -96,7 +96,8 @@ lr='sublinear', lipschitz=1., normalization='none', - momentum=momentum) + momentum=momentum, + generalized_lmo=True) train_losses = [] @@ -118,7 +119,7 @@ with torch.no_grad(): times.append(time() - start) full_loss = sqloss(Z, M) - print(full_loss) + print(full_loss / torch.linalg.norm(M)) train_losses.append(loss.item()) losses.append(full_loss.item()) optimizer.step() @@ -143,7 +144,7 @@ print(f"Sparse loss: {torch.linalg.norm(S - sparse_comp) / torch.linalg.norm(S)}") print(f"Reconstruction loss: {torch.linalg.norm(M - sparse_comp - lr_comp) / torch.linalg.norm(M)}") print(f"Time: {times[-1]}s") - break + fig.show() - fig.savefig("robustPCA.png") + fig.savefig(f"robustPCA_{r_p}.png") print("Done.") From 10b7c4ee03e0c22735a345d31d3d772a54cbb692 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Geoffrey=20N=C3=A9giar?= Date: Fri, 14 May 2021 19:51:18 -0700 Subject: [PATCH 28/47] Slight change --- chop/stochastic.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/chop/stochastic.py b/chop/stochastic.py index 5f42917..9c079ae 100644 --- a/chop/stochastic.py +++ b/chop/stochastic.py @@ -764,14 +764,12 @@ def step(self, closure=None): atom, scale = state['lmo']( -state['grad_est'], state['y'] ) - magnitude = (atom * (.5 * state['grad_est'] / (state['lipschitz'] * state['lr']) - state['y'])).sum() - scale - from icecream import ic - # ic(magnitude) + magnitude = (atom * (.5 * state['grad_est'] / (state['lipschitz'] * state['lr']))).sum() - scale magnitude /= torch.linalg.norm(atom) ** 2 magnitude = F.relu(magnitude) w = magnitude * atom v = state['prox']( - state['x'] - w - state['grad_est'] / state['lr_prox'], state['lr_prox']) + state['x'] + state['y'] - w - state['grad_est'] / state['lr_prox'], state['lr_prox']) y_update = w - state['y'] else: From 3c996316a07150965d8e0c5c34f06f1601be3a06 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Geoffrey=20N=C3=A9giar?= Date: Mon, 17 May 2021 17:37:36 -0700 Subject: [PATCH 29/47] Slight algo modification --- chop/stochastic.py | 5 ++--- examples/plot_stochastic_robust_PCA.py | 13 +++++++------ 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/chop/stochastic.py b/chop/stochastic.py index 9c079ae..5239dd1 100644 --- a/chop/stochastic.py +++ b/chop/stochastic.py @@ -764,12 +764,11 @@ def step(self, closure=None): atom, scale = state['lmo']( -state['grad_est'], state['y'] ) - magnitude = (atom * (.5 * state['grad_est'] / (state['lipschitz'] * state['lr']))).sum() - scale + eff_step = .5 / (state['lipschitz'] * state['lr']) + magnitude = (atom * (state['x'] + state['y'] - eff_step * state['grad_est'])).sum() - 
scale * eff_step magnitude /= torch.linalg.norm(atom) ** 2 magnitude = F.relu(magnitude) w = magnitude * atom - v = state['prox']( - state['x'] + state['y'] - w - state['grad_est'] / state['lr_prox'], state['lr_prox']) y_update = w - state['y'] else: diff --git a/examples/plot_stochastic_robust_PCA.py b/examples/plot_stochastic_robust_PCA.py index 4c87aa5..a1b9972 100644 --- a/examples/plot_stochastic_robust_PCA.py +++ b/examples/plot_stochastic_robust_PCA.py @@ -78,8 +78,9 @@ prox = sparsity_penalty.prox prox_lr = rank_penalty.prox - batch_sizes = [100, 250, 500, 1000] - fig, axes = plt.subplots(nrows=2, ncols=len(batch_sizes), figsize=(18, 10), sharey=True) + # batch_sizes = [100, 250, 500, 1000] + batch_sizes = [1000] + fig, axes = plt.subplots(nrows=2, ncols=max(len(batch_sizes), 2), figsize=(18, 10), sharey=True) fig.suptitle(f'r={r} and p={p}') for batch_size, ax_it, ax_time in zip(batch_sizes, axes[0], axes[1]): @@ -110,7 +111,7 @@ for zi, mi in loader: n_it += 1 optimizer.zero_grad() - loss = mu * sqloss(zi, mi) + loss = mu * sqloss(zi, mi) * Z.size(0) / zi.size(0) loss.backward() sgrad = Z.grad.detach().clone() sgrad_avg += sgrad @@ -118,7 +119,7 @@ if n_it % freq == 0: with torch.no_grad(): times.append(time() - start) - full_loss = sqloss(Z, M) + full_loss = mu * sqloss(Z, M) + torch.linalg.norm(optimizer.state[Z]['y'], ord='nuc') + lam * optimizer.state[Z]['x'].abs().sum() print(full_loss / torch.linalg.norm(M)) train_losses.append(loss.item()) losses.append(full_loss.item()) @@ -144,7 +145,7 @@ print(f"Sparse loss: {torch.linalg.norm(S - sparse_comp) / torch.linalg.norm(S)}") print(f"Reconstruction loss: {torch.linalg.norm(M - sparse_comp - lr_comp) / torch.linalg.norm(M)}") print(f"Time: {times[-1]}s") - + break fig.show() - fig.savefig(f"robustPCA_{r_p}.png") + fig.savefig(f"robustPCA_{r, p}.png") print("Done.") From 5106cf3dbb940ca1149913536036dab874fc1a7d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Geoffrey=20N=C3=A9giar?= Date: Mon, 17 May 2021 23:07:50 -0700 Subject: [PATCH 30/47] Fixed penalty to use torch.svd instead of torch.linalg.svd + step size fix in stochastic --- chop/penalties.py | 2 +- chop/stochastic.py | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/chop/penalties.py b/chop/penalties.py index ef0383b..86c00ea 100644 --- a/chop/penalties.py +++ b/chop/penalties.py @@ -121,7 +121,7 @@ def prox(self, x, step_size): batch_sizes = [1] if isinstance(step_size, Number): step_size = step_size * torch.ones(*batch_sizes, device=x.device, dtype=x.dtype) - U, S, V = torch.linalg.svd(x) + U, S, V = torch.svd(x, False) L1penalty = L1(self.alpha) S_thresh = L1penalty.prox(S, step_size) VT = V.transpose(-2, -1) diff --git a/chop/stochastic.py b/chop/stochastic.py index 5239dd1..0751a64 100644 --- a/chop/stochastic.py +++ b/chop/stochastic.py @@ -760,7 +760,6 @@ def step(self, closure=None): state['lr_prox'] = state['lr'] * state['lipschitz'] if group['generalized_lmo']: - state['lr_prox'] *= 2 atom, scale = state['lmo']( -state['grad_est'], state['y'] ) From 2b2f6697bc62df939efdd9e453951d50f4b05df0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Geoffrey=20N=C3=A9giar?= Date: Mon, 17 May 2021 23:12:49 -0700 Subject: [PATCH 31/47] Fixed svd calls --- chop/penalties.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/chop/penalties.py b/chop/penalties.py index 86c00ea..f702c71 100644 --- a/chop/penalties.py +++ b/chop/penalties.py @@ -121,10 +121,9 @@ def prox(self, x, step_size): batch_sizes = [1] if isinstance(step_size, Number): step_size = 
step_size * torch.ones(*batch_sizes, device=x.device, dtype=x.dtype) - U, S, V = torch.svd(x, False) + U, S, VT = torch.linalg.svd(x) L1penalty = L1(self.alpha) S_thresh = L1penalty.prox(S, step_size) - VT = V.transpose(-2, -1) return U @ torch.diag_embed(S_thresh) @ VT @torch.no_grad() From 4291e1ffe1caefbf28fc610b94989bdf34baeaea Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Geoffrey=20N=C3=A9giar?= Date: Mon, 17 May 2021 23:30:43 -0700 Subject: [PATCH 32/47] SVD call updated --- chop/constraints.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/chop/constraints.py b/chop/constraints.py index 2c71ca7..f70614f 100644 --- a/chop/constraints.py +++ b/chop/constraints.py @@ -513,13 +513,12 @@ def prox(self, x, step_size=None): """ Projection operator on the Nuclear Norm constraint set. """ - U, S, V = torch.svd(x) + U, S, VT = torch.linalg.svd(x) # Project S on the alpha-L1 ball ball = L1Ball(self.alpha) S_proj = ball.prox(S.view(-1, S.size(-1))).view_as(S) - VT = V.transpose(-2, -1) return torch.matmul(U, torch.matmul(torch.diag_embed(S_proj), VT)) def is_feasible(self, x, atol=1e-5, rtol=1e-5): From e5671d07dc46e115e6ff03be3bb0f918c8a7e06d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Geoffrey=20N=C3=A9giar?= Date: Tue, 18 May 2021 00:58:32 -0700 Subject: [PATCH 33/47] Fixed penalized training + changes to svd calls --- chop/constraints.py | 5 ++--- chop/penalties.py | 20 ++++++++------------ 2 files changed, 10 insertions(+), 15 deletions(-) diff --git a/chop/constraints.py b/chop/constraints.py index f70614f..b64ee44 100644 --- a/chop/constraints.py +++ b/chop/constraints.py @@ -513,13 +513,12 @@ def prox(self, x, step_size=None): """ Projection operator on the Nuclear Norm constraint set. """ - U, S, VT = torch.linalg.svd(x) + U, S, VT = torch.linalg.svd(x, full_matrices=False) # Project S on the alpha-L1 ball ball = L1Ball(self.alpha) S_proj = ball.prox(S.view(-1, S.size(-1))).view_as(S) - - return torch.matmul(U, torch.matmul(torch.diag_embed(S_proj), VT)) + return U @ torch.diag_embed(S_proj) @ VT def is_feasible(self, x, atol=1e-5, rtol=1e-5): norms = torch.linalg.norm(x, dim=(-2, -1), ord='nuc') diff --git a/chop/penalties.py b/chop/penalties.py index f702c71..5f2193e 100644 --- a/chop/penalties.py +++ b/chop/penalties.py @@ -11,6 +11,7 @@ """ from numbers import Number +import numpy as np from numpy.core.fromnumeric import nonzero import torch import torch.nn.functional as F @@ -63,20 +64,14 @@ def prox(self, x, step_size): return utils.bmul(torch.sign(x), F.relu(abs(x) - self.alpha * step_size.view((-1,) + (1,) * (x.dim() - 1)))) @torch.no_grad() - def lmo(self, grad, iterate, splitting=False, **kwargs): - """Generalized LMO for the Nuclear norm penalty""" - batch_sizes = grad.shape[:-2] + def lmo(self, grad, iterate): + *batch_sizes, m, n = iterate.shape if not batch_sizes: batch_sizes = [1] - at = self.atom(grad) - step = penalty_lmo_step(self.alpha, at, splitting, **kwargs) - return step * at - iterate, torch.ones(*batch_sizes, dtype=iterate.dtype) - - @torch.no_grad() - def atom(self, x): ball = constraints.L1Ball(1.) 
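For reference, a minimal unbatched sketch of the generalized-LMO contract introduced for the L1 penalty: the oracle returns an extreme point of the unit L1 ball most aligned with grad, together with the penalty weight alpha, and the optimizer (SplittingProxFW with generalized_lmo=True) then chooses the step magnitude. The helper name below is illustrative, not part of chop:

import torch

def l1_generalized_lmo_sketch(grad, alpha):
    # Signed coordinate vector at the largest-magnitude entry of grad:
    # the vertex of the unit L1 ball maximizing <atom, grad>.
    flat = grad.reshape(-1)
    idx = torch.argmax(flat.abs())
    atom = torch.zeros_like(flat)
    atom[idx] = torch.sign(flat[idx])
    # alpha is returned alongside the atom; the caller picks the magnitude.
    return atom.reshape_as(grad), alpha
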
update_direction, _ = ball.lmo(grad, iterate) - return update_direction + iterate + atom = update_direction + iterate + return atom, self.alpha * torch.ones(*batch_sizes, dtype=grad.dtype, device=grad.device) class NuclearNorm: @@ -121,9 +116,10 @@ def prox(self, x, step_size): batch_sizes = [1] if isinstance(step_size, Number): step_size = step_size * torch.ones(*batch_sizes, device=x.device, dtype=x.dtype) - U, S, VT = torch.linalg.svd(x) + U, S, VT = torch.linalg.svd(x, full_matrices=False) L1penalty = L1(self.alpha) - S_thresh = L1penalty.prox(S, step_size) + S_thresh = L1penalty.prox(S.reshape(np.prod(batch_sizes), S.size(-1)), step_size) + S_thresh = S_thresh.reshape(*batch_sizes, S.size(-1)) return U @ torch.diag_embed(S_thresh) @ VT @torch.no_grad() From b75436b5a2570351bcdd418888464cde70f01692 Mon Sep 17 00:00:00 2001 From: Francisco Utrera Date: Tue, 18 May 2021 15:55:57 -0700 Subject: [PATCH 34/47] took out print statement --- chop/constraints.py | 1 - 1 file changed, 1 deletion(-) diff --git a/chop/constraints.py b/chop/constraints.py index b64ee44..4e463e9 100644 --- a/chop/constraints.py +++ b/chop/constraints.py @@ -68,7 +68,6 @@ def make_model_constraints(model, ord=2, value=300, mode='initialization', const if is_bias(name, param): constraint = None else: - print(name) if mode == 'radius': alpha = value elif mode == 'initialization': From 69bec20650b4c9323fe31305b65fc5ff5b05006a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Geoffrey=20N=C3=A9giar?= Date: Wed, 19 May 2021 17:42:01 -0700 Subject: [PATCH 35/47] Added penalty initialization --- chop/constraints.py | 22 +++++++++++++++++----- chop/penalties.py | 8 +++++--- 2 files changed, 22 insertions(+), 8 deletions(-) diff --git a/chop/constraints.py b/chop/constraints.py index 4e463e9..ca5c144 100644 --- a/chop/constraints.py +++ b/chop/constraints.py @@ -18,6 +18,7 @@ from scipy.stats import expon from torch.distributions import Laplace, Normal from chop import utils +from chop import penalties @torch.no_grad() @@ -36,8 +37,8 @@ def is_bias(name, param): @torch.no_grad() -def make_model_constraints(model, ord=2, value=300, mode='initialization', constrain_bias=False): - """Create Ball constraints for each layer of model. Ball radius depends on mode (either radius or +def make_model_constraints(model, ord=2, value=300, mode='initialization', constrain_bias=False, penalty=False): + """Create Ball constraints / penalties for each layer of model. Ball radius / penalty depends on mode (either radius or factor to multiply average initialization norm with)""" constraints = [] @@ -71,14 +72,25 @@ def make_model_constraints(model, ord=2, value=300, mode='initialization', const if mode == 'radius': alpha = value elif mode == 'initialization': - alpha = value * init_norms[param.shape] + if penalty: + alpha = value / init_norms[param.shape] + else: + alpha = value * init_norms[param.shape] else: msg = f"Unknown mode {mode}." raise ValueError(msg) if (type(ord) == int) or (ord == np.inf): - constraint = make_LpBall(alpha, p=ord) + if penalty: + if ord != 1: + raise NotImplementedError("Please use ord=1 or ord='nuc'.") + constraint = penalties.L1(alpha) + else: + constraint = make_LpBall(alpha, p=ord) elif ord == 'nuc': - constraint = NuclearNormBall(alpha) + if penalty: + constraint = penalties.NuclearNorm(alpha) + else: + constraint = NuclearNormBall(alpha) else: msg = f"ord {ord} is not supported." 
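A hedged usage sketch for the new penalty option: the toy model and value below are illustrative, and the per-parameter return list (with None entries for unpenalized biases) is assumed from the loop above.

import torch
from torch import nn
import chop

model = nn.Sequential(nn.Linear(20, 50), nn.ReLU(), nn.Linear(50, 10))
# With penalty=True and mode='initialization', alpha = value / (average init norm),
# so each eligible layer gets a NuclearNorm (or L1) penalty instead of a ball constraint.
pens = chop.constraints.make_model_constraints(
    model, ord='nuc', value=1e-2, mode='initialization', penalty=True)
# Per-parameter oracles (None entries correspond to unpenalized parameters, e.g. biases).
lmos = [pen.lmo if pen is not None else None for pen in pens]
proxes = [pen.prox if pen is not None else None for pen in pens]
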
raise ValueError(msg) diff --git a/chop/penalties.py b/chop/penalties.py index 5f2193e..09f741a 100644 --- a/chop/penalties.py +++ b/chop/penalties.py @@ -76,7 +76,7 @@ def lmo(self, grad, iterate): class NuclearNorm: """Nuclear Norm penalty. Batch-wise function. For each element in the batch, - the L1 penalty is given by + the penalty is given by ..math:: \Omega(X) = \alpha \|X\|_* """ @@ -111,16 +111,18 @@ def prox(self, x, step_size): step_size: float or torch.Tensor of shape (*batch_sizes,) """ - *batch_sizes, m, n = x.shape + orig_shape = x.shape + *batch_sizes, m, n = orig_shape if not batch_sizes: batch_sizes = [1] + if isinstance(step_size, Number): step_size = step_size * torch.ones(*batch_sizes, device=x.device, dtype=x.dtype) U, S, VT = torch.linalg.svd(x, full_matrices=False) L1penalty = L1(self.alpha) S_thresh = L1penalty.prox(S.reshape(np.prod(batch_sizes), S.size(-1)), step_size) S_thresh = S_thresh.reshape(*batch_sizes, S.size(-1)) - return U @ torch.diag_embed(S_thresh) @ VT + return (U @ torch.diag_embed(S_thresh) @ VT).reshape(*orig_shape) @torch.no_grad() def lmo(self, grad, iterate): From c0be0ff1f1ea04fbafb4dc8b05c047f53388fe6d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Geoffrey=20N=C3=A9giar?= Date: Thu, 20 May 2021 22:28:03 -0700 Subject: [PATCH 36/47] removed todo --- chop/constraints.py | 1 - 1 file changed, 1 deletion(-) diff --git a/chop/constraints.py b/chop/constraints.py index ca5c144..17d2008 100644 --- a/chop/constraints.py +++ b/chop/constraints.py @@ -58,7 +58,6 @@ def make_model_constraints(model, ord=2, value=300, mode='initialization', const None))]: param = getattr(layer, param_type) shape = param.shape - # TODO: figure out how to set the constraint size for NuclearNormBall constraint avg_norm = get_avg_init_norm(layer, param_type=param_type, ord=2) if avg_norm == 0.0: # Catch unlikely case that weight/bias is 0-initialized (e.g. 
BatchNorm does this) From 7028526fede77bf2d58080fd75e247113a9659e8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Geoffrey=20N=C3=A9giar?= Date: Thu, 20 May 2021 22:29:08 -0700 Subject: [PATCH 37/47] removed print, redundant computation --- chop/stochastic.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/chop/stochastic.py b/chop/stochastic.py index 0751a64..d52e806 100644 --- a/chop/stochastic.py +++ b/chop/stochastic.py @@ -774,18 +774,20 @@ def step(self, closure=None): y_update, max_step_size = state['lmo']( -state['grad_est'], state['y']) state['lr'] = min(max_step_size, state['lr']) + w = y_update + state['y'] if group['normalization'] == 'gradient': # Normalize LMO update direction grad_norm = torch.linalg.norm(state['grad_est']) y_update_norm = torch.linalg.norm(y_update) y_update *= min(1, grad_norm / y_update_norm) + w = y_update + state['y'] - w = y_update + state['y'] v = state['prox']( state['x'] + state['y'] - w - state['grad_est'] / state['lr_prox'], state['lr_prox']) x_update = v - state['x'] + # print(state['lr'], state['lr_prox']) state['y'].add_(y_update, alpha=state['lr']) state['x'].add_(x_update, alpha=state['lr']) From 34409b17692bf5f3bb8e29705edd8a941fa37f96 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Geoffrey=20N=C3=A9giar?= Date: Fri, 21 May 2021 12:30:59 -0700 Subject: [PATCH 38/47] removed comment --- chop/stochastic.py | 1 - 1 file changed, 1 deletion(-) diff --git a/chop/stochastic.py b/chop/stochastic.py index d52e806..80e4bfe 100644 --- a/chop/stochastic.py +++ b/chop/stochastic.py @@ -787,7 +787,6 @@ def step(self, closure=None): state['x'] + state['y'] - w - state['grad_est'] / state['lr_prox'], state['lr_prox']) x_update = v - state['x'] - # print(state['lr'], state['lr_prox']) state['y'].add_(y_update, alpha=state['lr']) state['x'].add_(x_update, alpha=state['lr']) From 88d72b19bbbeb1c780c41bf9bc2d02b8aa911c87 Mon Sep 17 00:00:00 2001 From: Francisco Utrera Date: Sat, 22 May 2021 18:20:33 -0700 Subject: [PATCH 39/47] Updated CIFAR Added a few more transformations on the training loader and also changed the way we normalize a bit. --- chop/utils/data.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/chop/utils/data.py b/chop/utils/data.py index a1a7e7d..a7e70a7 100644 --- a/chop/utils/data.py +++ b/chop/utils/data.py @@ -83,18 +83,19 @@ class CIFAR10(Dataset): def __init__(self, data_dir, normalize=True): """Initializes Dataset""" self.mean = torch.Tensor((0.4914, 0.4822, 0.4465)) - self.mean = self.mean[..., None, None] self.std = torch.Tensor((0.2023, 0.1994, 0.2010)) - self.std = self.std[..., None, None] self.normalize = t.Normalize(self.mean, self.std) self.unnormalize = t.Normalize(-self.mean / self.std, 1. 
/ self.std) - + transforms_train = [ - t.RandomCrop(32, padding=4), - t.RandomHorizontalFlip(), - t.ToTensor(), - ] + t.RandomCrop(32, padding=4), + t.RandomHorizontalFlip(), + t.ColorJitter(.25,.25,.25), + t.RandomRotation(2), + t.ToTensor(), + ] + transforms_test = [t.ToTensor()] if normalize: From 9393c27f5967ba3185eecdb77169891ccec90fb6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Geoffrey=20N=C3=A9giar?= Date: Wed, 26 May 2021 11:42:39 -0700 Subject: [PATCH 40/47] Minor rewrites --- chop/constraints.py | 4 ++-- chop/penalties.py | 9 ++++----- chop/stochastic.py | 6 +++--- 3 files changed, 9 insertions(+), 10 deletions(-) diff --git a/chop/constraints.py b/chop/constraints.py index 17d2008..52efc0b 100644 --- a/chop/constraints.py +++ b/chop/constraints.py @@ -356,9 +356,9 @@ def lmo(self, grad, iterate, step_size=None): update_direction = -iterate.clone().detach() abs_grad = abs(grad) batch_size = iterate.size(0) - flatten_abs_grad = abs_grad.view(batch_size, -1) + flatten_abs_grad = abs_grad.reshape(batch_size, -1) flatten_largest_mask = (flatten_abs_grad == flatten_abs_grad.max(-1, True)[0]) - largest = torch.where(flatten_largest_mask.view_as(abs_grad)) + largest = torch.where(flatten_largest_mask.reshape_as(abs_grad)) update_direction[largest] += self.alpha * torch.sign( grad[largest]) diff --git a/chop/penalties.py b/chop/penalties.py index 09f741a..6589d27 100644 --- a/chop/penalties.py +++ b/chop/penalties.py @@ -20,7 +20,6 @@ from chop import constraints - class L1: """L1 Norm penalty. Batch-wise function. For each element in the batch, the L1 penalty is given by @@ -65,13 +64,13 @@ def prox(self, x, step_size): @torch.no_grad() def lmo(self, grad, iterate): - *batch_sizes, m, n = iterate.shape - if not batch_sizes: - batch_sizes = [1] + """Generalized LMO for the L1 norm penalty. + This function returns an atom in the constraint set, most aligned with grad.""" + batch_size = grad.size(0) ball = constraints.L1Ball(1.) 
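The NuclearNorm penalty in this module follows the same generalized-LMO contract: it returns a rank-1 atom (the top singular pair of grad, obtained through NuclearNormBall.lmo, which uses utils.power_iteration) together with alpha. A rough unbatched sketch using a full SVD instead:

import torch

def nuclear_generalized_lmo_sketch(grad, alpha):
    # Rank-1 extreme point of the unit nuclear-norm ball most aligned with grad:
    # the outer product of its leading singular vectors.
    U, S, Vh = torch.linalg.svd(grad, full_matrices=False)
    atom = torch.outer(U[:, 0], Vh[0, :])
    return atom, alpha
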
update_direction, _ = ball.lmo(grad, iterate) atom = update_direction + iterate - return atom, self.alpha * torch.ones(*batch_sizes, dtype=grad.dtype, device=grad.device) + return atom, self.alpha * torch.ones(batch_size, dtype=grad.dtype, device=grad.device) class NuclearNorm: diff --git a/chop/stochastic.py b/chop/stochastic.py index 80e4bfe..e67d75f 100644 --- a/chop/stochastic.py +++ b/chop/stochastic.py @@ -763,9 +763,9 @@ def step(self, closure=None): atom, scale = state['lmo']( -state['grad_est'], state['y'] ) + atom /= torch.linalg.norm(atom) eff_step = .5 / (state['lipschitz'] * state['lr']) - magnitude = (atom * (state['x'] + state['y'] - eff_step * state['grad_est'])).sum() - scale * eff_step - magnitude /= torch.linalg.norm(atom) ** 2 + magnitude = (atom * (p - eff_step * state['grad_est'])).sum() - scale * eff_step magnitude = F.relu(magnitude) w = magnitude * atom y_update = w - state['y'] @@ -784,7 +784,7 @@ def step(self, closure=None): w = y_update + state['y'] v = state['prox']( - state['x'] + state['y'] - w - state['grad_est'] / state['lr_prox'], state['lr_prox']) + p - w - state['grad_est'] / state['lr_prox'], state['lr_prox']) x_update = v - state['x'] state['y'].add_(y_update, alpha=state['lr']) From 14378dce9eb6d25a739add29361ade760b16beee Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Geoffrey=20N=C3=A9giar?= Date: Wed, 26 May 2021 13:29:47 -0700 Subject: [PATCH 41/47] vectorized L1/Simplex projections --- chop/constraints.py | 25 +++++++++++-------------- 1 file changed, 11 insertions(+), 14 deletions(-) diff --git a/chop/constraints.py b/chop/constraints.py index 52efc0b..a934aec 100644 --- a/chop/constraints.py +++ b/chop/constraints.py @@ -143,17 +143,18 @@ def euclidean_proj_simplex(v, s=1.): http://www.cs.berkeley.edu/~jduchi/projects/DuchiSiShCh08.pdf """ assert s > 0, "Radius s must be strictly positive (%d <= 0)" % s - (n,) = v.shape + b, n = v.shape # check if we are already on the simplex - if v.sum() == s and (v >= 0).all(): + if (v.sum(-1) == s).all() and (v >= 0).all(): return v # get the array of cumulative sums of a sorted (decreasing) copy of v - u, _ = torch.sort(v, descending=True) + u, _ = torch.sort(v, dim=-1, descending=True) cssv = torch.cumsum(u, dim=-1) # get the number of > 0 components of the optimal solution - rho = (u * torch.arange(1, n + 1, device=v.device) > (cssv - s)).sum() - 1 + rho = (u * torch.arange(1, n + 1, device=v.device) > (cssv - s)).sum(-1) - 1 # compute the Lagrange multiplier associated to the simplex constraint - theta = (cssv[rho] - s) / (rho + 1.0) + theta = (cssv[torch.arange(b, device=v.device), rho] - s) / (rho + 1.0) + theta = theta.unsqueeze(-1).expand_as(v) # compute the projection by thresholding v using theta w = torch.clamp(v - theta, min=0) return w @@ -184,14 +185,12 @@ def euclidean_proj_l1ball(v, s=1.): euclidean_proj_simplex """ assert s >= 0, "Radius s must be strictly positive (%d <= 0)" % s - if len(v.shape) > 1: - raise ValueError if s == 0: return torch.zeros_like(v) # compute the vector of absolute values u = abs(v) # check if v is already a solution - if u.sum() <= s: + if (u.sum(-1) <= s).all(): # L1-norm is <= s return v # v is not already a solution: optimum lies on the boundary (norm == s) @@ -382,9 +381,8 @@ def prox(self, x, step_size=None): shape = x.shape flattened_x = x.reshape(shape[0], -1) # TODO vectorize this - projected = [euclidean_proj_l1ball(row, s=self.alpha) for row in flattened_x] - x = torch.stack(projected) - return x.reshape(*shape) + projected = 
euclidean_proj_l1ball(flattened_x, s=self.alpha) + return projected.reshape(*shape) class L2Ball(LpBall): @@ -463,9 +461,8 @@ def __init__(self, alpha): def prox(self, x, step_size=None): shape = x.shape flattened_x = x.view(shape[0], -1) - projected = [euclidean_proj_simplex(row, s=self.alpha) for row in flattened_x] - x = torch.stack(projected) - return x.view(*shape) + projected = euclidean_proj_simplex(flattened_x, s=self.alpha) + return projected.view(*shape) @torch.no_grad() def lmo(self, grad, iterate, step_size=None): From cd283ca09b9da8dfc06439c2e78905b266fb4126 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Geoffrey=20N=C3=A9giar?= Date: Wed, 26 May 2021 18:46:45 -0700 Subject: [PATCH 42/47] special case when nuclear norm is of diameter 0 --- chop/constraints.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/chop/constraints.py b/chop/constraints.py index a934aec..f624e22 100644 --- a/chop/constraints.py +++ b/chop/constraints.py @@ -510,9 +510,10 @@ def lmo(self, grad, iterate, step_size=None): update_direction: torch.Tensor of shape (*, m, n) """ update_direction = -iterate.clone().detach() - u, _, v = utils.power_iteration(grad) - atom = u.unsqueeze(-1) * v.unsqueeze(-2) - update_direction += self.alpha * atom + if self.alpha > 0.: + u, _, v = utils.power_iteration(grad) + atom = u.unsqueeze(-1) * v.unsqueeze(-2) + update_direction += self.alpha * atom return update_direction, torch.ones(iterate.size(0), device=iterate.device, dtype=iterate.dtype) @torch.no_grad() @@ -520,6 +521,8 @@ def prox(self, x, step_size=None): """ Projection operator on the Nuclear Norm constraint set. """ + if self.alpha == 0: + return torch.zeros_like(x) U, S, VT = torch.linalg.svd(x, full_matrices=False) # Project S on the alpha-L1 ball ball = L1Ball(self.alpha) From 429129bdce0df9110925a71a819ac146544219e3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Geoffrey=20N=C3=A9giar?= Date: Wed, 26 May 2021 19:15:21 -0700 Subject: [PATCH 43/47] Made Frank-Wolfe savable --- chop/stochastic.py | 46 +++++++++++++++++++++++----------------------- 1 file changed, 23 insertions(+), 23 deletions(-) diff --git a/chop/stochastic.py b/chop/stochastic.py index e67d75f..2fe0d39 100644 --- a/chop/stochastic.py +++ b/chop/stochastic.py @@ -486,30 +486,27 @@ class FrankWolfe(Optimizer): name = 'Frank-Wolfe' POSSIBLE_NORMALIZATIONS = {'gradient', 'none'} - def __init__(self, params, lmo, lr=.1, momentum=0., + def __init__(self, params, lmo, prox=None, lr=.1, momentum=0., weight_decay=0., normalization='none'): - lmo_candidates = [] - for oracle in lmo: - if oracle is None: - # Then FW will not be used on this parameter - _lmo = None - else: - def _lmo(u, x): - update_direction, max_step_size = oracle( - u.unsqueeze(0), x.unsqueeze(0)) - return update_direction.squeeze(dim=0), max_step_size - lmo_candidates.append(_lmo) + if prox is None: + prox = [None] * len(params) + + lmo_candidates = [LMO(oracle) if oracle else None for oracle in lmo] + prox = [Prox(oracle) for oracle in prox] - self.lmo = [] useable_params = [] - for param, oracle in zip(params, lmo): - if oracle: + lmos = [] + proxes = [] + + for k, (param, lmo_oracle, prox_oracle) in enumerate(zip(params, lmo_candidates, prox)): + if lmo_oracle is not None: useable_params.append(param) - self.lmo.append(oracle) + lmos.append(lmo_oracle) + proxes.append(prox_oracle) else: - msg = (f"No LMO was provided for parameter {param}. " + msg = (f"No LMO was provided for parameter {k}. " f"Frank-Wolfe will not optimize this parameter. 
" f"Please use another optimizer.") warnings.warn(msg) @@ -529,7 +526,7 @@ def _lmo(u, x): raise ValueError( f"Normalization must be in {self.POSSIBLE_NORMALIZATIONS}.") self.normalization = normalization - defaults = dict(lmo=self.lmo, name=self.name, lr=self.lr, + defaults = dict(lmo=lmos, prox=proxes, name=self.name, lr=self.lr, momentum=self.momentum, weight_decay=weight_decay, normalization=self.normalization) @@ -556,8 +553,8 @@ def step(self, closure=None): if closure is not None: with torch.enable_grad(): loss = closure() - idx = 0 for group in self.param_groups: + idx = 0 for p in group['params']: if p.grad is None: continue @@ -570,7 +567,10 @@ def step(self, closure=None): state['step'] = 0 state['grad_estimate'] = torch.zeros_like( p, memory_format=torch.preserve_format) - + state['prox'] = group['prox'][idx] + state['lmo'] = group['lmo'][idx] + # make sure p is in the constraint set + p.copy_(state['prox'](p, 1.)) if self.lr == 'sublinear': step_size = 1. / (state['step'] + 1.) elif type(self.lr) == float: @@ -588,7 +588,7 @@ def step(self, closure=None): state['grad_estimate'].add_( grad - state['grad_estimate'], alpha=1. - momentum) - update_direction, _ = self.lmo[idx](-state['grad_estimate'], p) + update_direction, _ = state['lmo'](-state['grad_estimate'], p) state['certificate'] = (-state['grad_estimate'] * update_direction).sum() if group['normalization'] == 'gradient': @@ -778,13 +778,13 @@ def step(self, closure=None): if group['normalization'] == 'gradient': # Normalize LMO update direction - grad_norm = torch.linalg.norm(state['grad_est']) + grad_norm = torch.linalg.norm(grad) y_update_norm = torch.linalg.norm(y_update) y_update *= min(1, grad_norm / y_update_norm) w = y_update + state['y'] v = state['prox']( - p - w - state['grad_est'] / state['lr_prox'], state['lr_prox']) + p - w - state['grad_est'] / state['lr_prox'], 1. 
/ state['lr_prox']) x_update = v - state['x'] state['y'].add_(y_update, alpha=state['lr']) From d575a09ed90031cbb0adb174d09694ede91971bc Mon Sep 17 00:00:00 2001 From: Francisco Utrera Date: Tue, 1 Jun 2021 01:40:58 -0700 Subject: [PATCH 44/47] Updated ImageNet --- chop/utils/data.py | 59 +++++++++++++++++++++++++++++++++++++++------- 1 file changed, 50 insertions(+), 9 deletions(-) diff --git a/chop/utils/data.py b/chop/utils/data.py index a7e70a7..0c94e20 100644 --- a/chop/utils/data.py +++ b/chop/utils/data.py @@ -12,6 +12,7 @@ from torch import nn import torchvision from torchvision import transforms as t +from torchvision import transforms class Dataset: @@ -31,10 +32,10 @@ def loaders(self, train_batch_size=128, test_batch_size=128, num_workers=2, pin_memory=True, shuffle=True): """Load training and test data.""" - train_loader = torch.utils.data.DataLoader(self.dataset.train, batch_size=train_batch_size, shuffle=shuffle, num_workers=num_workers, - pin_memory=pin_memory) - test_loader = torch.utils.data.DataLoader(self.dataset.test, batch_size=test_batch_size, shuffle=False, num_workers=num_workers, - pin_memory=pin_memory) + train_loader = torch.utils.data.DataLoader(self.dataset.train, batch_size=train_batch_size, + shuffle=shuffle, num_workers=num_workers, pin_memory=pin_memory) + test_loader = torch.utils.data.DataLoader(self.dataset.test, batch_size=test_batch_size, + shuffle=False, num_workers=num_workers, pin_memory=pin_memory) return EasyDict(train=train_loader, test=test_loader) def load_k(self, k, train, device, shuffle=True, num_workers=2, pin_memory=False): @@ -124,6 +125,37 @@ def __init__(self, data_dir, normalize=True): } + +class Lighting(object): + """ + Lighting noise (see https://git.io/fhBOc) + """ + def __init__(self, alphastd, eigval, eigvec): + self.alphastd = alphastd + self.eigval = eigval + self.eigvec = eigvec + + def __call__(self, img): + if self.alphastd == 0: + return img + + alpha = img.new().resize_(3).normal_(0, self.alphastd) + rgb = self.eigvec.type_as(img).clone()\ + .mul(alpha.view(1, 3).expand(3, 3))\ + .mul(self.eigval.view(1, 3).expand(3, 3))\ + .sum(1).squeeze() + + return img.add(rgb.view(3, 1, 1).expand_as(img)) + +IMAGENET_PCA = { + 'eigval':ch.Tensor([0.2175, 0.0188, 0.0045]), + 'eigvec':ch.Tensor([ + [-0.5675, 0.7192, 0.4009], + [-0.5808, -0.0045, -0.8140], + [-0.5836, -0.6948, 0.4203], + ]) +} + class ImageNet(Dataset): def __init__(self, data_dir, normalize=True): @@ -136,12 +168,21 @@ def __init__(self, data_dir, normalize=True): self.unnormalize = t.Normalize(-self.mean / self.std, 1./self.std) transforms_train = [ - t.RandomResizedCrop(224), - t.RandomHorizontalFlip(), - t.ToTensor() - ] + transforms.RandomResizedCrop(224), + transforms.RandomHorizontalFlip(), + transforms.ColorJitter( + brightness=0.1, + contrast=0.1, + saturation=0.1 + ), + transforms.ToTensor(), + Lighting(0.05, IMAGENET_PCA['eigval'], + IMAGENET_PCA['eigvec']) + ] - transforms_test = [t.ToTensor()] + transforms_test = [transforms.Resize(256), + transforms.CenterCrop(224), + transforms.ToTensor()] if normalize: transforms_train.append(self.normalize) From 68129af61bdc71d8dee4fe10fc170ef7f9abffa2 Mon Sep 17 00:00:00 2001 From: Francisco Utrera Date: Thu, 5 Aug 2021 17:36:07 -0700 Subject: [PATCH 45/47] took out the ch --- chop/utils/data.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/chop/utils/data.py b/chop/utils/data.py index 0c94e20..7928080 100644 --- a/chop/utils/data.py +++ b/chop/utils/data.py @@ -148,8 +148,8 @@ def 
__call__(self, img): return img.add(rgb.view(3, 1, 1).expand_as(img)) IMAGENET_PCA = { - 'eigval':ch.Tensor([0.2175, 0.0188, 0.0045]), - 'eigvec':ch.Tensor([ + 'eigval':torch.Tensor([0.2175, 0.0188, 0.0045]), + 'eigvec':torch.Tensor([ [-0.5675, 0.7192, 0.4009], [-0.5808, -0.0045, -0.8140], [-0.5836, -0.6948, 0.4203], From 86ba4bdb307255788beabcf809dc6b22e5a95c33 Mon Sep 17 00:00:00 2001 From: Geoffrey Negiar Date: Fri, 6 Aug 2021 14:42:04 -0700 Subject: [PATCH 46/47] Added tqdm for iters in optim --- chop/optim.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/chop/optim.py b/chop/optim.py index a05fa14..9484d1f 100644 --- a/chop/optim.py +++ b/chop/optim.py @@ -12,8 +12,11 @@ from numbers import Number import warnings +from tqdm import tqdm + import torch import numpy as np + from scipy import optimize from numbers import Number from chop import utils @@ -153,7 +156,7 @@ def prox2(x, s=None, *args): x = prox1(z - utils.bmul(step_size, grad), step_size, *args_prox) u = torch.zeros_like(x) - for it in range(max_iter): + for it in tqdm(range(max_iter)): z.requires_grad_(True) fval, grad = closure(z) with torch.no_grad(): @@ -316,7 +319,7 @@ def prox(x, s=None): else: raise ValueError("step must be float or backtracking or None") - for it in range(max_iter): + for it in tqdm(range(max_iter)): fval, grad = closure(x) x_next = prox(x - utils.bmul(step_size, grad), step_size, *prox_args) @@ -410,7 +413,7 @@ def minimize_frank_wolfe(closure, x0, lmo, step='sublinear', cert = np.inf * torch.ones(batch_size, device=x.device) - for it in range(max_iter): + for it in tqdm(range(max_iter)): x.requires_grad = True fval, grad = closure(x) @@ -511,7 +514,7 @@ def minimize_alternating_fw_prox(closure, x0, y0, prox=None, lmo=None, lipschitz fval, grad = closure(x + y) - for it in range(max_iter): + for it in tqdm(range(max_iter)): if step == 'sublinear': step_size = 2. / (it + 2) * torch.ones(batch_size, device=x.device) From 47aa7450c70286c402bd7773eb424254054b71a3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Geoffrey=20N=C3=A9giar?= Date: Thu, 16 Sep 2021 19:58:56 -0700 Subject: [PATCH 47/47] Bug fix for penalized stochastic FW --- chop/stochastic.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/chop/stochastic.py b/chop/stochastic.py index 2fe0d39..08f58df 100644 --- a/chop/stochastic.py +++ b/chop/stochastic.py @@ -763,12 +763,12 @@ def step(self, closure=None): atom, scale = state['lmo']( -state['grad_est'], state['y'] ) - atom /= torch.linalg.norm(atom) - eff_step = .5 / (state['lipschitz'] * state['lr']) - magnitude = (atom * (p - eff_step * state['grad_est'])).sum() - scale * eff_step + atom_norm2 = torch.linalg.norm(atom) ** 2 + magnitude = (atom * (state['lipschitz'] * state['lr'] * p - state['grad_est'])).sum() - scale magnitude = F.relu(magnitude) + magnitude /= state['lipschitz'] * atom_norm2 w = magnitude * atom - y_update = w - state['y'] + y_update = w - state['lr'] * state['y'] else: y_update, max_step_size = state['lmo'](