From fb83f37d2cb30d9ff7e264bbb3e9511f38388704 Mon Sep 17 00:00:00 2001 From: vivianwhite Date: Sun, 21 Apr 2024 14:36:09 -0400 Subject: [PATCH 01/77] adding files from muawi_resnets --- refactor/conv_modules.py | 384 +++++++++++++++++++++++++++++++++++ refactor/og_pytorch_cifar.py | 263 ++++++++++++++++++++++++ refactor/resnet.py | 136 +++++++++++++ 3 files changed, 783 insertions(+) create mode 100644 refactor/conv_modules.py create mode 100644 refactor/og_pytorch_cifar.py create mode 100644 refactor/resnet.py diff --git a/refactor/conv_modules.py b/refactor/conv_modules.py new file mode 100644 index 0000000..55ac0e9 --- /dev/null +++ b/refactor/conv_modules.py @@ -0,0 +1,384 @@ +import torch +from torch import Tensor +import torch.nn as nn +from torch.nn.parameter import Parameter, UninitializedParameter +from torch.nn.common_types import _size_2_t +from typing import Optional, List, Tuple, Union +import time +import copy + + +def _contract(tensor, matrix, axis): + """tensor is (..., D, ...), matrix is (P, D), returns (..., P, ...).""" + t = torch.moveaxis(tensor, source=axis, destination=-1) # (..., D) + r = t @ matrix.T # (..., P) + return torch.moveaxis(r, source=-1, destination=axis) # (..., P, ...) 
+ +class FactConv2dPostExp(nn.Conv2d): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: _size_2_t, + stride: _size_2_t = 1, + padding: Union[str, _size_2_t] = 0, + dilation: _size_2_t = 1, + groups: int = 1, + bias: bool = True, + padding_mode: str = 'zeros', # TODO: refine this type + device=torch.device('cuda' if torch.cuda.is_available() else 'cpu'), + dtype=None + ) -> None: + # init as Conv2d + super().__init__( + in_channels, out_channels, kernel_size, stride, padding, dilation, + groups, bias, padding_mode, device, dtype) + + factory_kwargs = {'device': device, 'dtype': dtype} + self.factory_kwargs = factory_kwargs + + # weight shape: (out_channels, in_channels // groups, *kernel_size) + weight_shape = self.weight.shape + del self.weight # remove Parameter, create buffer + self.register_buffer("weight", torch.empty(weight_shape, **factory_kwargs)) + nn.init.kaiming_normal_(self.weight) + + self.in_features = self.in_channels // self.groups * \ + self.kernel_size[0] * self.kernel_size[1] + triu1 = torch.triu_indices(self.in_channels // self.groups, + self.in_channels // self.groups) + triu2 = torch.triu_indices(self.kernel_size[0] * self.kernel_size[1], + self.kernel_size[0] + * self.kernel_size[1]) + triu1_len = triu1.shape[1] + triu2_len = triu2.shape[1] + tri1_vec = torch.zeros((triu1_len,), + **factory_kwargs) + self.tri1_vec = Parameter(tri1_vec) + + tri2_vec = torch.zeros((triu2_len,), **factory_kwargs) + self.tri2_vec = Parameter(tri2_vec) + + + def forward(self, input: Tensor) -> Tensor: + U1 = self._tri_vec_to_mat(self.tri1_vec, self.in_channels // self.groups) + U2 = self._tri_vec_to_mat(self.tri2_vec, self.kernel_size[0] * self.kernel_size[1]) + U = torch.kron(U1, U2) + U = self._exp_diag(U) + + matrix_shape = (self.out_channels, self.in_features) + composite_weight = torch.reshape( + torch.reshape(self.weight, matrix_shape) @ U, + self.weight.shape + ) + + return self._conv_forward(input, composite_weight, self.bias) 
+ + def _tri_vec_to_mat(self, vec, n): + U = torch.zeros((n, n), **self.factory_kwargs) + U[torch.triu_indices(n, n, **self.factory_kwargs).tolist()] = vec + return U + + def _exp_diag(self, mat): + exp_diag = torch.exp(torch.diagonal(mat)) + n = mat.shape[0] + mat[range(n), range(n)] = exp_diag + return mat + +class FactConv2dPreExp(nn.Conv2d): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: _size_2_t, + stride: _size_2_t = 1, + padding: Union[str, _size_2_t] = 0, + dilation: _size_2_t = 1, + groups: int = 1, + bias: bool = True, + padding_mode: str = 'zeros', # TODO: refine this type + device=torch.device('cuda' if torch.cuda.is_available() else 'cpu'), + dtype=None + ) -> None: + # init as Conv2d + super().__init__( + in_channels, out_channels, kernel_size, stride, padding, dilation, + groups, bias, padding_mode, device, dtype) + + factory_kwargs = {'device': device, 'dtype': dtype} + self.factory_kwargs = factory_kwargs + + # weight shape: (out_channels, in_channels // groups, *kernel_size) + weight_shape = self.weight.shape + del self.weight # remove Parameter, create buffer + self.register_buffer("weight", torch.empty(weight_shape, **factory_kwargs)) + nn.init.kaiming_normal_(self.weight) + + self.in_features = self.in_channels // self.groups * \ + self.kernel_size[0] * self.kernel_size[1] + triu1 = torch.triu_indices(self.in_channels // self.groups, + self.in_channels // self.groups, + **factory_kwargs) + self.scat_idx1=triu1[0]*self.in_channels//self.groups + triu1[1] + triu2 = torch.triu_indices(self.kernel_size[0] * self.kernel_size[1], + self.kernel_size[0] + * self.kernel_size[1], + **factory_kwargs) + + self.scat_idx2=triu2[0]*self.kernel_size[0]*self.kernel_size[1] + triu2[1] + triu1_len = triu1.shape[1] + triu2_len = triu2.shape[1] + tri1_vec = torch.zeros((triu1_len,), + **factory_kwargs) + + self.tri1_vec = Parameter(tri1_vec) + + tri2_vec = torch.zeros((triu2_len,), **factory_kwargs) + self.tri2_vec = 
Parameter(tri2_vec) + + def construct_Us(self): + self.tri1_vec = Parameter(self._tri_vec_to_mat(self.tri1_vec, self.in_channels // + self.groups,self.scat_idx1)) + self.tri2_vec = Parameter(self._tri_vec_to_mat(self.tri2_vec, self.kernel_size[0] * self.kernel_size[1], + self.scat_idx2)) + + + def forward(self, input: Tensor) -> Tensor: + U1 = self._tri_vec_to_mat(self.tri1_vec, self.in_channels // + self.groups, self.scat_idx1) + U2 = self._tri_vec_to_mat(self.tri2_vec, self.kernel_size[0] * self.kernel_size[1], + self.scat_idx2 + ) + #.reshape([self.kernel_size[0], self.kernel_size[1], + # self.kernel_size[0] , self.kernel_size[1]]) + # flatten over filter dims and contract + composite_weight = _contract(self.weight, U1.T, 1) + composite_weight = _contract( + torch.flatten(composite_weight, -2, -1), U2.T, -1 + ).reshape(self.weight.shape) + # composite_weight = torch.einsum("ijkl, jm -> imkl", self.weight + return self._conv_forward(input, composite_weight, self.bias) + +# def forward(self, input: Tensor) -> Tensor: +# #U1 = self.tri1_vec +# #U2 = self.tri2_vec +# U1 = self._tri_vec_to_mat(self.tri1_vec, self.in_channels // +# self.groups,self.scat_idx1) +# # print(self.in_channels//self.groups) +# U2 = self._tri_vec_to_mat(self.tri2_vec, self.kernel_size[0] * self.kernel_size[1], +# self.scat_idx2) +# # +# U = torch.kron(U1, U2) +# #s1 = time.time() +# #U = torch.kron(U1, U2) +# #torch.cuda.synchronize() +# #print("torch.kron",time.time()-s1) +# #U = self._kron(U1, U2) +# #s1 = time.time() +# #U = self._kron(U1, U2) +# #torch.cuda.synchronize() +# #print("self.kron", time.time()-s1) + +# matrix_shape = (self.out_channels, self.in_features) +# composite_weight = torch.reshape( +# torch.reshape(self.weight, matrix_shape) @ U, +# self.weight.shape +# ) +# output = self._conv_forward(input, composite_weight, self.bias) +# return output + + + + def _kron(self, a, b): + a_shape = a.shape + b_shape = b.shape + c_shape = (a.shape[0]*b.shape[0], a.shape[1]*b.shape[1]) + 
+ a = a.reshape(-1, 1) + b = b.reshape(1, -1) + + product = a@b + product = product.reshape(a_shape[0], a_shape[1], b.shape[0], b.shape[1]) + product = product.permute(0, 2, 1, 3) + product = product.reshape(c_shape[0], c_shape[1]) + return product + + + + def _tri_vec_to_mat(self, vec, n, scat_idx): + U = torch.zeros((n* n), + **self.factory_kwargs).scatter_(0,scat_idx,vec).view(n,n) + #U[torch.triu_indices(n, n, **self.factory_kwargs).tolist()] = vec + U = torch.diagonal_scatter(U,U.diagonal().exp_()) + #self._exp_diag(U) + return U + + #def _tri_vec_to_mat(self, vec, n): + # U = torch.zeros((n, n), **self.factory_kwargs) + # U[torch.triu_indices(n, n, **self.factory_kwargs).tolist()] = vec + # U = self._exp_diag(U) + # return U + + + def _exp_diag(self, mat): + exp_diag = torch.exp(torch.diagonal(mat)) + n = mat.shape[0] + mat[range(n), range(n)] = exp_diag + return mat + +# +# def forward(self, input: Tensor) -> Tensor: +# s1 = time.time() +# U1 = self._tri_vec_to_mat(self.tri1_vec, self.in_channels // +# self.groups,self.scat_idx1) +# print(time.time()-s1) +# print(self.in_channels//self.groups) +# U2 = self._tri_vec_to_mat(self.tri2_vec, self.kernel_size[0] * self.kernel_size[1], +# self.scat_idx2) +# U = torch.kron(U1, U2) +# matrix_shape = (self.out_channels, self.in_features) +# composite_weight = torch.reshape( +# torch.reshape(self.weight, matrix_shape) @ U, +# self.weight.shape +# ) +# output = self._conv_forward(input, composite_weight, self.bias) +# return output +# +# def _tri_vec_to_mat(self, vec, n, scat_idx): +# U = torch.zeros((n* n), +# **self.factory_kwargs).scatter_(0,scat_idx,vec).view(n,n) +# #U[torch.triu_indices(n, n, **self.factory_kwargs).tolist()] = vec +# U = torch.diagonal_scatter(U,U.diagonal().exp_()) +# #self._exp_diag(U) +# return U +# +# +class FactConv2dK(nn.Conv2d): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: _size_2_t, + stride: _size_2_t = 1, + padding: Union[str, _size_2_t] = 0, + 
dilation: _size_2_t = 1, + groups: int = 1, + k: int = 1, + bias: bool = True, + padding_mode: str = 'zeros', # TODO: refine this type + device=torch.device('cuda' if torch.cuda.is_available() else 'cpu'), + dtype=None + ) -> None: + # init as Conv2d + super().__init__( + in_channels, out_channels, kernel_size, stride, padding, dilation, + groups, bias, padding_mode, device, dtype) + + factory_kwargs = {'device': device, 'dtype': dtype} + self.factory_kwargs = factory_kwargs + self.k = k + # weight shape: (out_channels, in_channels // groups, *kernel_size) + weight_shape = self.weight.shape + del self.weight # remove Parameter, create buffer + self.register_buffer("weight", torch.empty(weight_shape, **factory_kwargs)) + nn.init.kaiming_normal_(self.weight) + + self.in_features = self.in_channels // self.groups * \ + self.kernel_size[0] * self.kernel_size[1] + + triu1 = torch.triu_indices(self.in_channels // self.groups, + self.in_channels // self.groups, + **factory_kwargs) + self.scat_idx1=triu1[0]*self.in_channels//self.groups + triu1[1] + triu2 = torch.triu_indices(self.kernel_size[0] * self.kernel_size[1], + self.kernel_size[0] + * self.kernel_size[1], + **factory_kwargs) + + self.scat_idx2=triu2[0]*self.kernel_size[0]*self.kernel_size[1] + triu2[1] + triu1_len = triu1.shape[1] + triu2_len = triu2.shape[1] + tri1_vec = torch.zeros((triu1_len,), **factory_kwargs) + tri2_vec = torch.zeros((triu2_len,), **factory_kwargs) + + self.tri1_vecs\ + = nn.ParameterList([nn.Parameter(copy.deepcopy(tri1_vec)) for i in + range(self.k)]) + self.tri2_vecs\ + = nn.ParameterList([nn.Parameter(copy.deepcopy(tri2_vec)) for i in + range(self.k)]) + + for param in self.tri1_vecs: + nn.init.trunc_normal_(param, std=0.02) + for param in self.tri2_vecs: + nn.init.trunc_normal_(param, std=0.02) + + + def construct_Us(self): + self.tri1_vec = Parameter(self._tri_vec_to_mat(self.tri1_vec, self.in_channels // + self.groups,self.scat_idx1)) + self.tri2_vec = 
Parameter(self._tri_vec_to_mat(self.tri2_vec, self.kernel_size[0] * self.kernel_size[1], + self.scat_idx2)) + + + def forward(self, input: Tensor) -> Tensor: + krons = [] + comp_weights = [] + for i in range(self.k): + U1 = self._tri_vec_to_mat(self.tri1_vecs[i], self.in_channels // + self.groups,self.scat_idx1) + U2 = self._tri_vec_to_mat(self.tri2_vecs[i], self.kernel_size[0] * self.kernel_size[1], + self.scat_idx2) + U = torch.kron(U1, U2) + krons.append(U) + + matrix_shape = (self.out_channels, self.in_features) + comp_weight = torch.reshape( + torch.reshape(self.weight, matrix_shape) @ U, + self.weight.shape) + comp_weights.append(comp_weight) + + arr = torch.stack(krons, dim=0) + U = torch.mean(arr, dim=0) + arr2 = torch.stack(comp_weights, dim=0) + composite_weight = torch.mean(arr2, dim=0) + + # matrix_shape = (self.out_channels, self.in_features) + # composite_weight = torch.reshape( + # torch.reshape(self.weight, matrix_shape) @ U, + # self.weight.shape + # ) + + output = self._conv_forward(input, composite_weight, self.bias) + return output + + def _kron(self, a, b): + a_shape = a.shape + b_shape = b.shape + c_shape = (a.shape[0]*b.shape[0], a.shape[1]*b.shape[1]) + + a = a.reshape(-1, 1) + b = b.reshape(1, -1) + + product = a@b + product = product.reshape(a_shape[0], a_shape[1], b.shape[0], b.shape[1]) + product = product.permute(0, 2, 1, 3) + product = product.reshape(c_shape[0], c_shape[1]) + return product + + + + def _tri_vec_to_mat(self, vec, n, scat_idx): + U = torch.zeros((n* n), + **self.factory_kwargs).scatter_(0,scat_idx,vec).view(n,n) + #U[torch.triu_indices(n, n, **self.factory_kwargs).tolist()] = vec + U = torch.diagonal_scatter(U,U.diagonal().exp_()) + #self._exp_diag(U) + return U + + + def _exp_diag(self, mat): + exp_diag = torch.exp(torch.diagonal(mat)) + n = mat.shape[0] + mat[range(n), range(n)] = exp_diag + return mat diff --git a/refactor/og_pytorch_cifar.py b/refactor/og_pytorch_cifar.py new file mode 100644 index 0000000..bf311cc 
--- /dev/null +++ b/refactor/og_pytorch_cifar.py @@ -0,0 +1,263 @@ +'''Train CIFAR10 with PyTorch.''' +import torch +import torch.nn as nn +import torch.optim as optim +import torch.nn.functional as F +import torch.backends.cudnn as cudnn +from torch.profiler import profile, record_function, ProfilerActivity + +import torchvision +import torchvision.transforms as transforms + +import os +import argparse + +from pytorch_cifar_utils import progress_bar, set_seeds + +from test_models_safety import PostExp, PreExp +from hooks import wandb_forwards_hook, wandb_backwards_hook + +import wandb + +from distutils.util import strtobool + +from resnet import ResNet18 +from vgg import VGG + +from test_models_safety import PostExp, PreExp + +def save_model(args, model): + src = "/home/mila/m/muawiz.chaudhary/scratch/v1-models/saved-models/CIFAR10_pytorch_tta/" + model_dir = src + args.name + os.makedirs(model_dir, exist_ok=True) + os.chdir(model_dir) + + #saves loss & accuracy in the trial directory -- all trials + trial_dir = model_dir + "/trial_" + str(1) + os.makedirs(trial_dir, exist_ok=True) + os.chdir(trial_dir) + + torch.save(model.state_dict(), trial_dir+ "/model.pt") + torch.save(args, trial_dir+ "/args.pt") + + +parser = argparse.ArgumentParser(description='PyTorch CIFAR10 Training') +parser.add_argument('--lr', default=0.1, type=float, help='learning rate') +parser.add_argument('--resume', '-r', action='store_true', + help='resume from checkpoint') +parser.add_argument('--net', type=str, default='vgg', choices=['vgg', 'vggbn', + 'resnet', 'factnetv1', 'factnetdefault', 'vggfact', 'vggbnfact'], help="which convmodule to use") +parser.add_argument('--freeze_spatial', dest='freeze_spatial', + type=lambda x: bool(strtobool(x)), default=True, + help="freeze spatial filters for LearnableCov models") +parser.add_argument('--freeze_channel', dest='freeze_channel', + type=lambda x: bool(strtobool(x)), default=False, + help="freeze channels for LearnableCov models") 
+parser.add_argument('--spatial_init', type=str, default='V1', choices=['default', 'V1'], + help="initialization for spatial filters for LearnableCov models") +parser.add_argument('--s', type=int, default=2, help='V1 size') +parser.add_argument('--f', type=float, default=0.1, help='V1 spatial frequency') +parser.add_argument('--scale', type=int, default=1, help='V1 scale') +parser.add_argument('--name', type=str, default='TESTING_VGG', + help='filename for saved model') +parser.add_argument('--bias', dest='bias', type=lambda x: bool(strtobool(x)), + default=False, help='bias=True or False') + +args = parser.parse_args() + +device = 'cuda' if torch.cuda.is_available() else 'cpu' +best_acc = 0 # best test accuracy +start_epoch = 0 # start from epoch 0 or last checkpoint epoch + +# Data +print('==> Preparing data..') +transform_train = transforms.Compose([ + transforms.RandomCrop(32, padding=4), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), +]) + +transform_test = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), +]) + +trainset = torchvision.datasets.CIFAR10( + root='./data', train=True, download=True, transform=transform_train) +trainloader = torch.utils.data.DataLoader( + trainset, batch_size=128, shuffle=True, num_workers=8) + +testset = torchvision.datasets.CIFAR10( + root='./data', train=False, download=True, transform=transform_test) +testloader = torch.utils.data.DataLoader( + testset, batch_size=1000, shuffle=False, num_workers=8) + +classes = ('plane', 'car', 'bird', 'cat', 'deer', + 'dog', 'frog', 'horse', 'ship', 'truck') + +# Model +print('==> Building model..') +# net = VGG('VGG19') +# net = ResNet18() +# net = PreActResNet18() +# net = GoogLeNet() +# net = DenseNet121() +# net = ResNeXt29_2x64d() +# net = MobileNet() +# net = MobileNetV2() +# net = DPN92() +# net = ShuffleNetG2() +# net = 
SENet18() +# net = ShuffleNetV2(1) +# net = EfficientNetB0() +# net = RegNetX_200MF() +# +# +# +from ConvModules import FactConv2dPreExp +def replace_layers_keep_weight(model): + for n, module in model.named_children(): + if len(list(module.children())) > 0: + ## compound module, go inside it + replace_layers_keep_weight(module) + if isinstance(module, nn.Conv2d): + ## simple module + new_module = FactConv2dPreExp( + in_channels=module.in_channels, + out_channels=module.out_channels, + kernel_size=module.kernel_size, + stride=module.stride, padding=module.padding, + bias=True if module.bias is not None else False) + old_sd = module.state_dict() + new_sd = new_module.state_dict() + new_sd['weight'] = old_sd['weight'] + if module.bias is not None: + new_sd['bias'] = old_sd['bias'] + new_module.load_state_dict(new_sd) + #new_module.tri1_vec = nn.Parameter(int(new_module.tri1_vec * scale)) + setattr(model, n, new_module) + +set_seeds(0) +if args.net == "vgg": + net=VGG("VGG11", False) + run_name = "OGVGG" +elif args.net == "vggbn": + net=VGG("VGG11", True) + run_name = "OGVGGBN" +elif args.net == "resnet": + net = ResNet18() + run_name = "Resnet" +elif args.net == "factnetv1": + net = PreExp(1, args.s, args.f, args.scale, args.bias, args.freeze_spatial, args.freeze_channel, "V1").to(device) + run_name ="FactnetV1" +elif args.net == "factnetdefault": + net = PreExp(1, args.s, args.f, args.scale, args.bias, args.freeze_spatial, args.freeze_channel, "default").to(device) + run_name ="Factnetdefault" +elif args.net == "vggfact": + net=VGG("VGG11", False) + replace_layers_keep_weight(net) + run_name ="vggfactdefault" +elif args.net == "vggbnfact": + net=VGG("VGG11", True) + replace_layers_keep_weight(net) + run_name ="vggbnnfactdefault" + + +set_seeds(0) +set_seeds(0) + +net = net.to(device) +#if device == 'cuda': + #net = torch.nn.DataParallel(net) + #cudnn.benchmark = True +wandb_dir = "/home/mila/m/muawiz.chaudhary/scratch/v1-models/wandb" +os.makedirs(wandb_dir, 
exist_ok=True) +os.chdir(wandb_dir) +run_name = "OGVGG" + +run = wandb.init(project="random_project", config=args, + group="pytorch_cifar_better_tracked_og", name=run_name, dir=wandb_dir) +#wandb.watch(net, log='all', log_freq=1) + +if args.resume: + # Load checkpoint. + print('==> Resuming from checkpoint..') + assert os.path.isdir('checkpoint'), 'Error: no checkpoint directory found!' + checkpoint = torch.load('./checkpoint/ckpt.pth') + net.load_state_dict(checkpoint['net']) + best_acc = checkpoint['acc'] + start_epoch = checkpoint['epoch'] + +criterion = nn.CrossEntropyLoss() +optimizer = optim.SGD(net.parameters(), lr=args.lr, + momentum=0.9, weight_decay=5e-4) +scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=200) + + +# Training +def train(epoch): + print('\nEpoch: %d' % epoch) + net.train() + train_loss = 0 + correct = 0 + total = 0 + for batch_idx, (inputs, targets) in enumerate(trainloader): + inputs, targets = inputs.to(device), targets.to(device) + optimizer.zero_grad() + outputs = net(inputs) + loss = criterion(outputs, targets) + loss.backward() + optimizer.step() + train_loss += loss.item() + _, predicted = outputs.max(1) + total += targets.size(0) + correct += predicted.eq(targets).sum().item() + + progress_bar(batch_idx, len(trainloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)' + % (train_loss/(batch_idx+1), 100.*correct/total, correct, total)) + + +def test(epoch): + global best_acc + net.eval() + test_loss = 0 + correct = 0 + total = 0 + with torch.no_grad(): + for batch_idx, (inputs, targets) in enumerate(testloader): + inputs, targets = inputs.to(device), targets.to(device) + outputs = net(inputs) + loss = criterion(outputs, targets) + + test_loss += loss.item() + _, predicted = outputs.max(1) + total += targets.size(0) + correct += predicted.eq(targets).sum().item() + + progress_bar(batch_idx, len(testloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)' + % (test_loss/(batch_idx+1), 100.*correct/total, correct, total)) + + # Save 
checkpoint. + acc = 100.*correct/total + run.log({"accuracy":acc}) + if acc > best_acc: + print('Saving..') + state = { + 'net': net.state_dict(), + 'acc': acc, + 'epoch': epoch, + } + #if not os.path.isdir('checkpoint'): + # os.mkdir('checkpoint') + #torch.save(state, './checkpoint/ckpt.pth') + save_model(args, net) + best_acc = acc + + +for epoch in range(start_epoch, start_epoch+200):#00 + train(epoch) + test(epoch) + scheduler.step() +args.name += "final" +save_model(args, net) diff --git a/refactor/resnet.py b/refactor/resnet.py new file mode 100644 index 0000000..4ed155e --- /dev/null +++ b/refactor/resnet.py @@ -0,0 +1,136 @@ +'''ResNet in PyTorch. + +For Pre-activation ResNet, see 'preact_resnet.py'. + +Reference: +[1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun + Deep Residual Learning for Image Recognition. arXiv:1512.03385 +''' +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class BasicBlock(nn.Module): + expansion = 1 + + def __init__(self, in_planes, planes, stride=1): + super(BasicBlock, self).__init__() + self.conv1 = nn.Conv2d( + in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) + self.bn1 = nn.BatchNorm2d(planes) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, + stride=1, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(planes) + + self.shortcut = nn.Sequential() + if stride != 1 or in_planes != self.expansion*planes: + self.shortcut = nn.Sequential( + nn.Conv2d(in_planes, self.expansion*planes, + kernel_size=1, stride=stride, bias=False), + nn.BatchNorm2d(self.expansion*planes) + ) + + def forward(self, x): + out = F.relu(self.bn1(self.conv1(x))) + out = self.bn2(self.conv2(out)) + out += self.shortcut(x) + out = F.relu(out) + return out + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, in_planes, planes, stride=1): + super(Bottleneck, self).__init__() + self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) + self.bn1 = nn.BatchNorm2d(planes) + 
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, + stride=stride, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(planes) + self.conv3 = nn.Conv2d(planes, self.expansion * + planes, kernel_size=1, bias=False) + self.bn3 = nn.BatchNorm2d(self.expansion*planes) + + self.shortcut = nn.Sequential() + if stride != 1 or in_planes != self.expansion*planes: + self.shortcut = nn.Sequential( + nn.Conv2d(in_planes, self.expansion*planes, + kernel_size=1, stride=stride, bias=False), + nn.BatchNorm2d(self.expansion*planes) + ) + + def forward(self, x): + out = F.relu(self.bn1(self.conv1(x))) + out = F.relu(self.bn2(self.conv2(out))) + out = self.bn3(self.conv3(out)) + out += self.shortcut(x) + out = F.relu(out) + return out + + +class ResNet(nn.Module): + def __init__(self, block, num_blocks, num_classes=10): + super(ResNet, self).__init__() + self.in_planes = 64 + + self.conv1 = nn.Conv2d(3, 64, kernel_size=3, + stride=1, padding=1, bias=False) + self.bn1 = nn.BatchNorm2d(64) + self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) + self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) + self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) + self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) + self.linear = nn.Linear(512*block.expansion, num_classes) + + def _make_layer(self, block, planes, num_blocks, stride): + strides = [stride] + [1]*(num_blocks-1) + layers = [] + for stride in strides: + layers.append(block(self.in_planes, planes, stride)) + self.in_planes = planes * block.expansion + return nn.Sequential(*layers) + + def forward(self, x): + out = F.relu(self.bn1(self.conv1(x))) + out = self.layer1(out) + out = self.layer2(out) + out = self.layer3(out) + out = self.layer4(out) + out = F.avg_pool2d(out, 4) + out = out.view(out.size(0), -1) + out = self.linear(out) + return out + + +def ResNet18(): + return ResNet(BasicBlock, [2, 2, 2, 2]) + + +def ResNet18_Class100(): + return ResNet(BasicBlock, [2, 2, 2, 
2], 100) + + +def ResNet34(): + return ResNet(BasicBlock, [3, 4, 6, 3]) + + +def ResNet50(): + return ResNet(Bottleneck, [3, 4, 6, 3]) + + +def ResNet101(): + return ResNet(Bottleneck, [3, 4, 23, 3]) + + +def ResNet152(): + return ResNet(Bottleneck, [3, 8, 36, 3]) + + +def test(): + net = ResNet18() + y = net(torch.randn(1, 3, 32, 32)) + print(y.size()) + +#test() From 0f10d2c8702e1cb1f4a848852af611fdc40f5d42 Mon Sep 17 00:00:00 2001 From: vivianwhite Date: Sun, 21 Apr 2024 15:18:50 -0400 Subject: [PATCH 02/77] removing comments and unused lines --- refactor/conv_modules.py | 223 --------------------------------------- 1 file changed, 223 deletions(-) diff --git a/refactor/conv_modules.py b/refactor/conv_modules.py index 55ac0e9..d5dc688 100644 --- a/refactor/conv_modules.py +++ b/refactor/conv_modules.py @@ -149,236 +149,13 @@ def forward(self, input: Tensor) -> Tensor: U2 = self._tri_vec_to_mat(self.tri2_vec, self.kernel_size[0] * self.kernel_size[1], self.scat_idx2 ) - #.reshape([self.kernel_size[0], self.kernel_size[1], - # self.kernel_size[0] , self.kernel_size[1]]) - # flatten over filter dims and contract - composite_weight = _contract(self.weight, U1.T, 1) composite_weight = _contract( torch.flatten(composite_weight, -2, -1), U2.T, -1 ).reshape(self.weight.shape) - # composite_weight = torch.einsum("ijkl, jm -> imkl", self.weight return self._conv_forward(input, composite_weight, self.bias) -# def forward(self, input: Tensor) -> Tensor: -# #U1 = self.tri1_vec -# #U2 = self.tri2_vec -# U1 = self._tri_vec_to_mat(self.tri1_vec, self.in_channels // -# self.groups,self.scat_idx1) -# # print(self.in_channels//self.groups) -# U2 = self._tri_vec_to_mat(self.tri2_vec, self.kernel_size[0] * self.kernel_size[1], -# self.scat_idx2) -# # -# U = torch.kron(U1, U2) -# #s1 = time.time() -# #U = torch.kron(U1, U2) -# #torch.cuda.synchronize() -# #print("torch.kron",time.time()-s1) -# #U = self._kron(U1, U2) -# #s1 = time.time() -# #U = self._kron(U1, U2) -# 
#torch.cuda.synchronize() -# #print("self.kron", time.time()-s1) - -# matrix_shape = (self.out_channels, self.in_features) -# composite_weight = torch.reshape( -# torch.reshape(self.weight, matrix_shape) @ U, -# self.weight.shape -# ) -# output = self._conv_forward(input, composite_weight, self.bias) -# return output - - - - def _kron(self, a, b): - a_shape = a.shape - b_shape = b.shape - c_shape = (a.shape[0]*b.shape[0], a.shape[1]*b.shape[1]) - - a = a.reshape(-1, 1) - b = b.reshape(1, -1) - - product = a@b - product = product.reshape(a_shape[0], a_shape[1], b.shape[0], b.shape[1]) - product = product.permute(0, 2, 1, 3) - product = product.reshape(c_shape[0], c_shape[1]) - return product - - def _tri_vec_to_mat(self, vec, n, scat_idx): - U = torch.zeros((n* n), - **self.factory_kwargs).scatter_(0,scat_idx,vec).view(n,n) - #U[torch.triu_indices(n, n, **self.factory_kwargs).tolist()] = vec U = torch.diagonal_scatter(U,U.diagonal().exp_()) - #self._exp_diag(U) return U - #def _tri_vec_to_mat(self, vec, n): - # U = torch.zeros((n, n), **self.factory_kwargs) - # U[torch.triu_indices(n, n, **self.factory_kwargs).tolist()] = vec - # U = self._exp_diag(U) - # return U - - - def _exp_diag(self, mat): - exp_diag = torch.exp(torch.diagonal(mat)) - n = mat.shape[0] - mat[range(n), range(n)] = exp_diag - return mat - -# -# def forward(self, input: Tensor) -> Tensor: -# s1 = time.time() -# U1 = self._tri_vec_to_mat(self.tri1_vec, self.in_channels // -# self.groups,self.scat_idx1) -# print(time.time()-s1) -# print(self.in_channels//self.groups) -# U2 = self._tri_vec_to_mat(self.tri2_vec, self.kernel_size[0] * self.kernel_size[1], -# self.scat_idx2) -# U = torch.kron(U1, U2) -# matrix_shape = (self.out_channels, self.in_features) -# composite_weight = torch.reshape( -# torch.reshape(self.weight, matrix_shape) @ U, -# self.weight.shape -# ) -# output = self._conv_forward(input, composite_weight, self.bias) -# return output -# -# def _tri_vec_to_mat(self, vec, n, scat_idx): -# U 
= torch.zeros((n* n), -# **self.factory_kwargs).scatter_(0,scat_idx,vec).view(n,n) -# #U[torch.triu_indices(n, n, **self.factory_kwargs).tolist()] = vec -# U = torch.diagonal_scatter(U,U.diagonal().exp_()) -# #self._exp_diag(U) -# return U -# -# -class FactConv2dK(nn.Conv2d): - def __init__( - self, - in_channels: int, - out_channels: int, - kernel_size: _size_2_t, - stride: _size_2_t = 1, - padding: Union[str, _size_2_t] = 0, - dilation: _size_2_t = 1, - groups: int = 1, - k: int = 1, - bias: bool = True, - padding_mode: str = 'zeros', # TODO: refine this type - device=torch.device('cuda' if torch.cuda.is_available() else 'cpu'), - dtype=None - ) -> None: - # init as Conv2d - super().__init__( - in_channels, out_channels, kernel_size, stride, padding, dilation, - groups, bias, padding_mode, device, dtype) - - factory_kwargs = {'device': device, 'dtype': dtype} - self.factory_kwargs = factory_kwargs - self.k = k - # weight shape: (out_channels, in_channels // groups, *kernel_size) - weight_shape = self.weight.shape - del self.weight # remove Parameter, create buffer - self.register_buffer("weight", torch.empty(weight_shape, **factory_kwargs)) - nn.init.kaiming_normal_(self.weight) - - self.in_features = self.in_channels // self.groups * \ - self.kernel_size[0] * self.kernel_size[1] - - triu1 = torch.triu_indices(self.in_channels // self.groups, - self.in_channels // self.groups, - **factory_kwargs) - self.scat_idx1=triu1[0]*self.in_channels//self.groups + triu1[1] - triu2 = torch.triu_indices(self.kernel_size[0] * self.kernel_size[1], - self.kernel_size[0] - * self.kernel_size[1], - **factory_kwargs) - - self.scat_idx2=triu2[0]*self.kernel_size[0]*self.kernel_size[1] + triu2[1] - triu1_len = triu1.shape[1] - triu2_len = triu2.shape[1] - tri1_vec = torch.zeros((triu1_len,), **factory_kwargs) - tri2_vec = torch.zeros((triu2_len,), **factory_kwargs) - - self.tri1_vecs\ - = nn.ParameterList([nn.Parameter(copy.deepcopy(tri1_vec)) for i in - range(self.k)]) - 
self.tri2_vecs\ - = nn.ParameterList([nn.Parameter(copy.deepcopy(tri2_vec)) for i in - range(self.k)]) - - for param in self.tri1_vecs: - nn.init.trunc_normal_(param, std=0.02) - for param in self.tri2_vecs: - nn.init.trunc_normal_(param, std=0.02) - - - def construct_Us(self): - self.tri1_vec = Parameter(self._tri_vec_to_mat(self.tri1_vec, self.in_channels // - self.groups,self.scat_idx1)) - self.tri2_vec = Parameter(self._tri_vec_to_mat(self.tri2_vec, self.kernel_size[0] * self.kernel_size[1], - self.scat_idx2)) - - - def forward(self, input: Tensor) -> Tensor: - krons = [] - comp_weights = [] - for i in range(self.k): - U1 = self._tri_vec_to_mat(self.tri1_vecs[i], self.in_channels // - self.groups,self.scat_idx1) - U2 = self._tri_vec_to_mat(self.tri2_vecs[i], self.kernel_size[0] * self.kernel_size[1], - self.scat_idx2) - U = torch.kron(U1, U2) - krons.append(U) - - matrix_shape = (self.out_channels, self.in_features) - comp_weight = torch.reshape( - torch.reshape(self.weight, matrix_shape) @ U, - self.weight.shape) - comp_weights.append(comp_weight) - - arr = torch.stack(krons, dim=0) - U = torch.mean(arr, dim=0) - arr2 = torch.stack(comp_weights, dim=0) - composite_weight = torch.mean(arr2, dim=0) - - # matrix_shape = (self.out_channels, self.in_features) - # composite_weight = torch.reshape( - # torch.reshape(self.weight, matrix_shape) @ U, - # self.weight.shape - # ) - - output = self._conv_forward(input, composite_weight, self.bias) - return output - - def _kron(self, a, b): - a_shape = a.shape - b_shape = b.shape - c_shape = (a.shape[0]*b.shape[0], a.shape[1]*b.shape[1]) - - a = a.reshape(-1, 1) - b = b.reshape(1, -1) - - product = a@b - product = product.reshape(a_shape[0], a_shape[1], b.shape[0], b.shape[1]) - product = product.permute(0, 2, 1, 3) - product = product.reshape(c_shape[0], c_shape[1]) - return product - - - - def _tri_vec_to_mat(self, vec, n, scat_idx): - U = torch.zeros((n* n), - **self.factory_kwargs).scatter_(0,scat_idx,vec).view(n,n) - 
#U[torch.triu_indices(n, n, **self.factory_kwargs).tolist()] = vec - U = torch.diagonal_scatter(U,U.diagonal().exp_()) - #self._exp_diag(U) - return U - - - def _exp_diag(self, mat): - exp_diag = torch.exp(torch.diagonal(mat)) - n = mat.shape[0] - mat[range(n), range(n)] = exp_diag - return mat From e3ab5ba4cbfcc0613468716810f1ddf981fd97b3 Mon Sep 17 00:00:00 2001 From: vivianwhite Date: Sun, 21 Apr 2024 15:29:00 -0400 Subject: [PATCH 03/77] remove comments and unused lines --- refactor/og_pytorch_cifar.py | 72 +++++------------------------------- 1 file changed, 9 insertions(+), 63 deletions(-) diff --git a/refactor/og_pytorch_cifar.py b/refactor/og_pytorch_cifar.py index bf311cc..70f6662 100644 --- a/refactor/og_pytorch_cifar.py +++ b/refactor/og_pytorch_cifar.py @@ -14,7 +14,6 @@ from pytorch_cifar_utils import progress_bar, set_seeds -from test_models_safety import PostExp, PreExp from hooks import wandb_forwards_hook, wandb_backwards_hook import wandb @@ -22,10 +21,9 @@ from distutils.util import strtobool from resnet import ResNet18 -from vgg import VGG -from test_models_safety import PostExp, PreExp - +# TODO: import define_models function + def save_model(args, model): src = "/home/mila/m/muawiz.chaudhary/scratch/v1-models/saved-models/CIFAR10_pytorch_tta/" model_dir = src + args.name @@ -98,23 +96,8 @@ def save_model(args, model): # Model print('==> Building model..') -# net = VGG('VGG19') -# net = ResNet18() -# net = PreActResNet18() -# net = GoogLeNet() -# net = DenseNet121() -# net = ResNeXt29_2x64d() -# net = MobileNet() -# net = MobileNetV2() -# net = DPN92() -# net = ShuffleNetG2() -# net = SENet18() -# net = ShuffleNetV2(1) -# net = EfficientNetB0() -# net = RegNetX_200MF() -# -# -# + + from ConvModules import FactConv2dPreExp def replace_layers_keep_weight(model): for n, module in model.named_children(): @@ -138,39 +121,13 @@ def replace_layers_keep_weight(model): #new_module.tri1_vec = nn.Parameter(int(new_module.tri1_vec * scale)) setattr(model, 
n, new_module) -set_seeds(0) -if args.net == "vgg": - net=VGG("VGG11", False) - run_name = "OGVGG" -elif args.net == "vggbn": - net=VGG("VGG11", True) - run_name = "OGVGGBN" -elif args.net == "resnet": - net = ResNet18() - run_name = "Resnet" -elif args.net == "factnetv1": - net = PreExp(1, args.s, args.f, args.scale, args.bias, args.freeze_spatial, args.freeze_channel, "V1").to(device) - run_name ="FactnetV1" -elif args.net == "factnetdefault": - net = PreExp(1, args.s, args.f, args.scale, args.bias, args.freeze_spatial, args.freeze_channel, "default").to(device) - run_name ="Factnetdefault" -elif args.net == "vggfact": - net=VGG("VGG11", False) - replace_layers_keep_weight(net) - run_name ="vggfactdefault" -elif args.net == "vggbnfact": - net=VGG("VGG11", True) - replace_layers_keep_weight(net) - run_name ="vggbnnfactdefault" - - -set_seeds(0) -set_seeds(0) +set_seeds(args.seed) +net = define_models(args.net) +run_name = args.net + +set_seeds(args.seed) net = net.to(device) -#if device == 'cuda': - #net = torch.nn.DataParallel(net) - #cudnn.benchmark = True wandb_dir = "/home/mila/m/muawiz.chaudhary/scratch/v1-models/wandb" os.makedirs(wandb_dir, exist_ok=True) os.chdir(wandb_dir) @@ -180,14 +137,6 @@ def replace_layers_keep_weight(model): group="pytorch_cifar_better_tracked_og", name=run_name, dir=wandb_dir) #wandb.watch(net, log='all', log_freq=1) -if args.resume: - # Load checkpoint. - print('==> Resuming from checkpoint..') - assert os.path.isdir('checkpoint'), 'Error: no checkpoint directory found!' 
- checkpoint = torch.load('./checkpoint/ckpt.pth') - net.load_state_dict(checkpoint['net']) - best_acc = checkpoint['acc'] - start_epoch = checkpoint['epoch'] criterion = nn.CrossEntropyLoss() optimizer = optim.SGD(net.parameters(), lr=args.lr, @@ -248,9 +197,6 @@ def test(epoch): 'acc': acc, 'epoch': epoch, } - #if not os.path.isdir('checkpoint'): - # os.mkdir('checkpoint') - #torch.save(state, './checkpoint/ckpt.pth') save_model(args, net) best_acc = acc From 8411f744c3611c7b8eaf9058f8e4fe86c4ebcf76 Mon Sep 17 00:00:00 2001 From: vivianwhite Date: Sun, 21 Apr 2024 15:53:08 -0400 Subject: [PATCH 04/77] fixed and tested conv_modules --- refactor/conv_modules.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/refactor/conv_modules.py b/refactor/conv_modules.py index d5dc688..a16ebc3 100644 --- a/refactor/conv_modules.py +++ b/refactor/conv_modules.py @@ -149,6 +149,8 @@ def forward(self, input: Tensor) -> Tensor: U2 = self._tri_vec_to_mat(self.tri2_vec, self.kernel_size[0] * self.kernel_size[1], self.scat_idx2 ) + # flatten over filter dims and contract + composite_weight = _contract(self.weight, U1.T, 1) composite_weight = _contract( torch.flatten(composite_weight, -2, -1), U2.T, -1 ).reshape(self.weight.shape) @@ -156,6 +158,8 @@ def forward(self, input: Tensor) -> Tensor: def _tri_vec_to_mat(self, vec, n, scat_idx): + U = torch.zeros((n* n), + **self.factory_kwargs).scatter_(0,scat_idx,vec).view(n,n) U = torch.diagonal_scatter(U,U.diagonal().exp_()) return U From fe6cbc77c1bc042b23bcd74a2f326a60edf33f20 Mon Sep 17 00:00:00 2001 From: vivianwhite Date: Sun, 21 Apr 2024 17:04:55 -0400 Subject: [PATCH 05/77] add whitespace --- refactor/conv_modules.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/refactor/conv_modules.py b/refactor/conv_modules.py index a16ebc3..dc51453 100644 --- a/refactor/conv_modules.py +++ b/refactor/conv_modules.py @@ -14,6 +14,7 @@ def _contract(tensor, matrix, axis): r = t @ matrix.T # (..., P) return torch.moveaxis(r, 
source=-1, destination=axis) # (..., P, ...) + class FactConv2dPostExp(nn.Conv2d): def __init__( self, @@ -74,17 +75,20 @@ def forward(self, input: Tensor) -> Tensor: return self._conv_forward(input, composite_weight, self.bias) + def _tri_vec_to_mat(self, vec, n): U = torch.zeros((n, n), **self.factory_kwargs) U[torch.triu_indices(n, n, **self.factory_kwargs).tolist()] = vec return U + def _exp_diag(self, mat): exp_diag = torch.exp(torch.diagonal(mat)) n = mat.shape[0] mat[range(n), range(n)] = exp_diag return mat + class FactConv2dPreExp(nn.Conv2d): def __init__( self, @@ -136,6 +140,7 @@ def __init__( tri2_vec = torch.zeros((triu2_len,), **factory_kwargs) self.tri2_vec = Parameter(tri2_vec) + def construct_Us(self): self.tri1_vec = Parameter(self._tri_vec_to_mat(self.tri1_vec, self.in_channels // self.groups,self.scat_idx1)) @@ -163,3 +168,4 @@ def _tri_vec_to_mat(self, vec, n, scat_idx): U = torch.diagonal_scatter(U,U.diagonal().exp_()) return U + From a472f7d53312b379ec8be1d67c6b8401f0a80279 Mon Sep 17 00:00:00 2001 From: vivianwhite Date: Sun, 21 Apr 2024 16:08:57 -0400 Subject: [PATCH 06/77] add whitespace --- refactor/conv_modules.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/refactor/conv_modules.py b/refactor/conv_modules.py index dc51453..d6b008c 100644 --- a/refactor/conv_modules.py +++ b/refactor/conv_modules.py @@ -11,7 +11,9 @@ def _contract(tensor, matrix, axis): """tensor is (..., D, ...), matrix is (P, D), returns (..., P, ...).""" t = torch.moveaxis(tensor, source=axis, destination=-1) # (..., D) + r = t @ matrix.T # (..., P) + return torch.moveaxis(r, source=-1, destination=axis) # (..., P, ...) 
@@ -84,7 +86,9 @@ def _tri_vec_to_mat(self, vec, n): def _exp_diag(self, mat): exp_diag = torch.exp(torch.diagonal(mat)) + n = mat.shape[0] + mat[range(n), range(n)] = exp_diag return mat From 2ec009e031e3ab82b6b026e7eb93a204231b2edf Mon Sep 17 00:00:00 2001 From: vivianwhite Date: Sun, 21 Apr 2024 17:10:59 -0400 Subject: [PATCH 07/77] trying to setup define_models --- models/V1_models.py | 385 +++++++++++++++++++++++++++++++++++ models/__init__.py | 1 + models/define_models.py | 6 + models/resnet.py | 137 +++++++++++++ refactor/og_pytorch_cifar.py | 14 +- 5 files changed, 537 insertions(+), 6 deletions(-) create mode 100644 models/V1_models.py create mode 100644 models/__init__.py create mode 100644 models/define_models.py create mode 100644 models/resnet.py diff --git a/models/V1_models.py b/models/V1_models.py new file mode 100644 index 0000000..c16d661 --- /dev/null +++ b/models/V1_models.py @@ -0,0 +1,385 @@ +import torch.nn as nn +import torch +import torch.nn.functional as F +import sys +sys.path.insert(0, '/research/harris/vivian/structured_random_features/') +from src.models.init_weights import V1_init, classical_init, V1_weights +import gc +import LearnableCov + +def train(model, device, train_loader, optimizer, epoch): + model.train() + for batch_idx, (data, target) in enumerate(train_loader): + data, target = data.to(device), target.to(device) + optimizer.zero_grad() + output = model(data) + loss = F.cross_entropy(output, target) + loss.backward() + optimizer.step() + + +def test(model, device, test_loader, epoch): + model.eval() + test_loss = 0 + correct = 0 + with torch.no_grad(): + for data, target in test_loader: + data, target = data.to(device), target.to(device) + output = model(data) + test_loss += F.cross_entropy(output, target, reduction='sum').item() # sum up batch loss + pred = output.max(1, keepdim=True)[1] # get the index of the max log-probability + correct += pred.eq(target.view_as(pred)).sum().item() + + test_loss /= 
len(test_loader.dataset) + accuracy = 100. * correct / len(test_loader.dataset) + + print('Test Epoch: {}\t Avg Loss: {:.4f}\t Accuracy: {:.2f}%'.format( + epoch, test_loss, accuracy)) + + return test_loss, accuracy + +class V1_CIFAR10(nn.Module): + def __init__(self, hidden_dim, size, spatial_freq, scale, bias, seed=None): + super().__init__() + + # fixed feature layers + self.v1_layer = nn.Conv2d(in_channels=3, out_channels=hidden_dim, + kernel_size=7, stride=1, padding=3, bias=bias) + self.v1_layer2 = nn.Conv2d(in_channels=hidden_dim + 3, out_channels=hidden_dim, + kernel_size=7, stride=1, padding=3, bias=bias) + self.relu = nn.ReLU() + + # unsupervised layers + self.bn_x = nn.BatchNorm2d(3) + self.bn_h1 = nn.BatchNorm2d(hidden_dim + 3) + self.bn_h2 = nn.BatchNorm2d(hidden_dim * 2 + 3) + + # supervised layers + self.clf = nn.Linear((8 ** 2) * (hidden_dim * 2 + 3), 10) + + scale1 = 1 / (3 * 7 * 7) + scale2 = 1 / (hidden_dim * 7 * 7) + center = (3., 3.) + + V1_init(self.v1_layer, size, spatial_freq, center, scale1, bias, seed) + self.v1_layer.weight.requires_grad = False + + V1_init(self.v1_layer2, size, spatial_freq, center, scale2, bias, seed) + self.v1_layer2.weight.requires_grad = False + + if bias: + self.v1_layer.bias.requires_grad = False + self.v1_layer2.bias.requires_grad = False + + def forward(self, x): + # methods + # smooth = nn.AvgPool2d(kernel_size=5, stride=1, padding=2) + smooth = nn.AvgPool2d(kernel_size=3, stride=1, padding=1) + pool = nn.AvgPool2d(kernel_size=4, stride=4, padding=1) + flatten = nn.Flatten() + + x = self.bn_x(x) + h = torch.cat((self.relu(self.v1_layer(x)), smooth(x)), 1) + h = self.bn_h1(h) + h = torch.cat((self.relu(self.v1_layer2(h)), smooth(h)), 1) + h = self.bn_h2(h) + h = flatten(pool(h)) + return self.clf(h) + + +class Rand_Scat_Block(nn.Module): + def __init__(self, in_chan, num_filt, size, spatial_freq, + kernel_size=7, stride=1, padding=3, scale=None, bias=True, seed=None): + super().__init__() + + out_chan = in_chan + 
num_filt + + self.v1 = nn.Conv2d(in_channels=in_chan, out_channels=num_filt, + kernel_size=kernel_size, stride=stride, padding=padding, + scale=scale, bias=bias) + self.bn = nn.BatchNorm2d(num_filt) + self.relu = nn.ReLU() + + # V1 params + if scale is None: + scale = 1 / (in_chan * np.prod(kernel_size)) + center = ((kernel_size - 1) / 2, (kernel_size - 1) / 2) + + # init weights + V1_init(self.v1, size, spatial_freq, center, scale, bias, seed) + self.v1.weight.requires_grad = False + if bias: + self.v1.bias.requires_grad = False + + def forward(self, x): + smooth = nn.AvgPool2d(kernel_size=5, stride=1, padding=2) + h = self.relu(self.v1(x)) + concat = torch.cat((h, smooth(x)), 1) # concatenate with smoothed input + return self.bn(concat) + +class Rand_Scat_CIFAR10(nn.Module): + pass + +class Learned_Rand_Scat_CIFAR10(nn.Module): + def __init__(self, num_filt, size, spatial_freq, scale, bias, seed=None): + super().__init__() + + # channel dimensions + dims = [3, 64, 128] + + # fixed feature layers + self.v1_layer = nn.Conv2d(in_channels=dims[0], out_channels=num_filt, + kernel_size=7, stride=1, padding=3, bias=bias) + self.v1_layer2 = nn.Conv2d(in_channels=dims[1], out_channels=num_filt, + kernel_size=7, stride=1, padding=3, bias=bias) + self.relu = nn.ReLU() + + # unsupervised layers + self.bn_x = nn.BatchNorm2d(3) + self.bn_h1 = nn.LazyBatchNorm2d() + self.bn_h2 = nn.LazyBatchNorm2d() + + # supervised layers + self.L1 = nn.LazyConv2d(out_channels=dims[1], kernel_size=1, bias=False) + self.L2 = nn.LazyConv2d(out_channels=dims[2], kernel_size=1, bias=False) + self.clf = nn.LazyLinear(10) + + # init fixed weights + scale1 = 1 / (3 * 7 * 7) + scale2 = 1 / (num_filt * 7 * 7) + center = (3., 3.) 
+ + V1_init(self.v1_layer, size, spatial_freq, center, scale1, bias, seed) + self.v1_layer.weight.requires_grad = False + + V1_init(self.v1_layer2, size, spatial_freq, center, scale2, bias, seed) + self.v1_layer2.weight.requires_grad = False + + if bias==True: + self.v1_layer.bias.requires_grad = False + self.v1_layer2.bias.requires_grad = False + +# OLD +# def forward(self, x): +# h1 = self.relu(self.v1_layer(self.bn_x(x))) +# h2 = self.relu(self.v1_layer2(self.bn_h1(h1))) + +# pool = nn.AvgPool2d(kernel_size=4, stride=4, padding=1) +# x_pool = self.bn0(pool(x)) +# h1_pool = self.bn1(pool(h1)) +# h2_pool = self.bn2(pool(h2)) + +# x_flat = x_pool.view(x_pool.size(0), -1) #view +# h1_flat = h1_pool.view(h1_pool.size(0), -1) #view +# h2_flat = h2_pool.view(h2_pool.size(0), -1) #view + + +# concat = torch.cat((x_flat, h1_flat, h2_flat), 1) + + def forward(self, x): + # methods + smooth = nn.AvgPool2d(kernel_size=5, stride=1, padding=2) + pool = nn.AvgPool2d(kernel_size=4, stride=4, padding=1) + flatten = nn.Flatten() + + # do pass + x = self.bn_x(x) + h1 = self.relu(self.v1_layer(x)) + h1 = torch.cat((h1, smooth(x)), 1) + h1 = self.bn_h1(h1) # chan: 3 + num_filt + h2 = self.relu(self.v1_layer2(self.L1(h1))) + h2 = torch.cat((h2, smooth(h1)), 1) # chan: 3 + num_filt + dims[1] + h2 = self.bn_h2(h2) + + concat = flatten(pool(h2)) + + beta = self.clf(concat) + + return beta + + +class V1_CIFAR100(nn.Module): + def __init__(self, hidden_dim, size, spatial_freq, scale, bias, seed=None): + super(V1_CIFAR100, self).__init__() + self.v1_layer = nn.Conv2d(in_channels=3, out_channels=hidden_dim, kernel_size=7, stride=1, padding=3, + bias=bias) + self.v1_layer2 = nn.Conv2d(in_channels=hidden_dim, out_channels=hidden_dim, kernel_size=7, stride=1, padding=3, + bias=bias) + self.clf = nn.Linear((3 * (8 ** 2)) + (hidden_dim * (8 ** 2)) + (hidden_dim * (8 ** 2)), 100) + self.relu = nn.ReLU() + self.bn = nn.BatchNorm2d(3) + self.bn0 = nn.BatchNorm2d(3) + self.bn1 = 
nn.BatchNorm2d(hidden_dim) + self.bn2 = nn.BatchNorm2d(hidden_dim) + self.bn_h1 = nn.BatchNorm2d(hidden_dim) + + scale1 = 1 / (3 * 7 * 7) + scale2 = 1 / (hidden_dim * 7 * 7) + center = None + + V1_init(self.v1_layer, size, spatial_freq, center, scale1, bias, seed) + self.v1_layer.weight.requires_grad = False + + V1_init(self.v1_layer2, size, spatial_freq, center, scale2, bias, seed) + self.v1_layer2.weight.requires_grad = False + + if bias==True: + self.v1_layer.bias.requires_grad = False + self.v1_layer2.bias.requires_grad = False + + def forward(self, x): #[128, 3, 32, 32] + h1 = self.relu(self.v1_layer(self.bn(x))) + h2 = self.relu(self.v1_layer2(self.bn_h1(h1))) + + pool = nn.AvgPool2d(kernel_size=4, stride=4, padding=1) + x_pool = self.bn0(pool(x)) + h1_pool = self.bn1(pool(h1)) + h2_pool = self.bn2(pool(h2)) + + x_flat = x_pool.view(x_pool.size(0), -1) + h1_flat = h1_pool.view(h1_pool.size(0), -1) + h2_flat = h2_pool.view(h2_pool.size(0), -1) + + concat = torch.cat((x_flat, h1_flat, h2_flat), 1) + + beta = self.clf(concat) + return beta + +class V1_MNIST(nn.Module): + def __init__(self, hidden_dim, size, spatial_freq, scale, bias, seed=None): + super(V1_MNIST, self).__init__() + self.v1_layer = nn.Conv2d(in_channels=1, out_channels=hidden_dim, kernel_size=7, stride=1, padding=3, + bias=bias) + self.v1_layer2 = nn.Conv2d(in_channels=hidden_dim, out_channels=hidden_dim, kernel_size=7, stride=1, padding=3, + bias=bias) + self.clf = nn.Linear((1 * (8 ** 2)) + (hidden_dim * (8 ** 2)) + (hidden_dim * (8 ** 2)), 100) + self.relu = nn.ReLU() + self.bn = nn.BatchNorm2d(1) + self.bn0 = nn.BatchNorm2d(1) + self.bn1 = nn.BatchNorm2d(hidden_dim) + self.bn2 = nn.BatchNorm2d(hidden_dim) + + scale1 = 1 / (3 * 7 * 7) + scale2 = 1 / (hidden_dim * 7 * 7) + center = None + + V1_init(self.v1_layer, size, spatial_freq, center, scale1, bias, seed) + self.v1_layer.weight.requires_grad = False + + V1_init(self.v1_layer2, size, spatial_freq, center, scale2, bias, seed) + 
self.v1_layer2.weight.requires_grad = False + + if bias==True: + self.v1_layer.bias.requires_grad = False + self.v1_layer2.bias.requires_grad = False + + def forward(self, x): #[128, 1, 28, 28] + h1 = self.relu(self.v1_layer(self.bn(x))) + h2 = self.relu(self.v1_layer2(h1)) + + pool = nn.AvgPool2d(kernel_size=4, stride=4, padding=2) + x_pool = self.bn0(pool(x)) + h1_pool = self.bn1(pool(h1)) + h2_pool = self.bn2(pool(h2)) + + x_flat = x_pool.view(x_pool.size(0), -1) + h1_flat = h1_pool.view(h1_pool.size(0), -1) + h2_flat = h2_pool.view(h2_pool.size(0), -1) + + concat = torch.cat((x_flat, h1_flat, h2_flat), 1) + + beta = self.clf(concat) + return beta + +class Scattering_V1_MNIST(nn.Module): + def __init__(self, hidden_dim, size, spatial_freq, scale, bias, seed=None): + super(Scattering_V1_MNIST, self).__init__() + self.v1_layer = nn.Conv2d(in_channels=1, out_channels=hidden_dim, kernel_size=7, stride=1, padding=3, + bias=bias) + self.v1_layer2 = nn.Conv2d(in_channels=hidden_dim, out_channels=hidden_dim, kernel_size=7, stride=1, padding=3, + bias=bias) + self.relu = nn.ReLU() + self.bn = nn.BatchNorm2d(1) + self.bn0 = nn.BatchNorm2d(1) + self.bn1 = nn.BatchNorm2d(hidden_dim) + self.bn2 = nn.BatchNorm2d(hidden_dim) + + scale1 = 1 / (3 * 7 * 7) + scale2 = 1 / (hidden_dim * 7 * 7) + center = None + + V1_init(self.v1_layer, size, spatial_freq, center, scale1, bias, seed) + self.v1_layer.weight.requires_grad = False + + V1_init(self.v1_layer2, size, spatial_freq, center, scale2, bias, seed) + self.v1_layer2.weight.requires_grad = False + + if bias==True: + self.v1_layer.bias.requires_grad = False + self.v1_layer2.bias.requires_grad = False + + def forward(self, x): #[128, 1, 28, 28] + h1 = self.relu(self.v1_layer(self.bn(x))) + h2 = self.relu(self.v1_layer2(h1)) + + pool = nn.AvgPool2d(kernel_size=4, stride=4, padding=2) + x_pool = self.bn0(pool(x)) + h1_pool = self.bn1(pool(h1)) + h2_pool = self.bn2(pool(h2)) + + x_flat = x_pool.view(x_pool.size(0), -1) + h1_flat = 
h1_pool.view(h1_pool.size(0), -1) + h2_flat = h2_pool.view(h2_pool.size(0), -1) + + concat = torch.cat((x_flat, h1_flat, h2_flat), 1) + return concat + +class Scattering_V1_celeba(nn.Module): + def __init__(self, hidden_dim, size, spatial_freq, scale, bias, seed=None): + super().__init__() + + # fixed feature layers + self.v1_layer = nn.Conv2d(in_channels=3, out_channels=hidden_dim, + kernel_size=7, stride=1, padding=3, bias=bias) + self.v1_layer2 = nn.Conv2d(in_channels=hidden_dim + 3, out_channels=hidden_dim, + kernel_size=7, stride=1, padding=3, bias=bias) + self.relu = nn.ReLU() + + # unsupervised layers + + self.bn_x = nn.BatchNorm2d(3) + self.bn_h1 = nn.BatchNorm2d(hidden_dim + 3) + self.bn_h2 = nn.BatchNorm2d(hidden_dim * 2 + 3) + + + # supervised layers + + scale1 = 1 / (3 * 7 * 7) + scale2 = 1 / (hidden_dim * 7 * 7) + center = (3., 3.) + + V1_init(self.v1_layer, size, spatial_freq, center, scale1, bias, seed) + self.v1_layer.weight.requires_grad = False + + V1_init(self.v1_layer2, size, spatial_freq, center, scale2, bias, seed) + self.v1_layer2.weight.requires_grad = False + + if bias: + self.v1_layer.bias.requires_grad = False + self.v1_layer2.bias.requires_grad = False + + def forward(self, x): + # methods + # smooth = nn.AvgPool2d(kernel_size=5, stride=1, padding=2) + smooth = nn.AvgPool2d(kernel_size=3, stride=1, padding=1) + pool = nn.AvgPool2d(kernel_size=4, stride=4, padding=1) + flatten = nn.Flatten() + + x = self.bn_x(x) + h = torch.cat((self.relu(self.v1_layer(x)), smooth(x)), 1) + h = self.bn_h1(h) + h = torch.cat((self.relu(self.v1_layer2(h)), smooth(h)), 1) + h = self.bn_h2(h) + h = flatten(pool(h)) + return h + diff --git a/models/__init__.py b/models/__init__.py new file mode 100644 index 0000000..1e59b87 --- /dev/null +++ b/models/__init__.py @@ -0,0 +1 @@ +from resnet import ResNet18 diff --git a/models/define_models.py b/models/define_models.py new file mode 100644 index 0000000..34de128 --- /dev/null +++ b/models/define_models.py @@ -0,0 
+1,6 @@ +from resnet import ResNet18 + +def define_models(args): + if args.net == 'resnet18': + model = ResNet18() + return model diff --git a/models/resnet.py b/models/resnet.py new file mode 100644 index 0000000..9177d5f --- /dev/null +++ b/models/resnet.py @@ -0,0 +1,137 @@ +'''ResNet in PyTorch. + +For Pre-activation ResNet, see 'preact_resnet.py'. + +Reference: +[1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun + Deep Residual Learning for Image Recognition. arXiv:1512.03385 +''' +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class BasicBlock(nn.Module): + expansion = 1 + + def __init__(self, in_planes, planes, stride=1): + super(BasicBlock, self).__init__() + self.conv1 = nn.Conv2d( + in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) + self.bn1 = nn.BatchNorm2d(planes) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, + stride=1, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(planes) + + self.shortcut = nn.Sequential() + if stride != 1 or in_planes != self.expansion*planes: + self.shortcut = nn.Sequential( + nn.Conv2d(in_planes, self.expansion*planes, + kernel_size=1, stride=stride, bias=False), + nn.BatchNorm2d(self.expansion*planes) + ) + + def forward(self, x): + out = F.relu(self.bn1(self.conv1(x))) + out = self.bn2(self.conv2(out)) + out += self.shortcut(x) + out = F.relu(out) + return out + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, in_planes, planes, stride=1): + super(Bottleneck, self).__init__() + self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) + self.bn1 = nn.BatchNorm2d(planes) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, + stride=stride, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(planes) + self.conv3 = nn.Conv2d(planes, self.expansion * + planes, kernel_size=1, bias=False) + self.bn3 = nn.BatchNorm2d(self.expansion*planes) + + self.shortcut = nn.Sequential() + if stride != 1 or in_planes != self.expansion*planes: + 
self.shortcut = nn.Sequential( + nn.Conv2d(in_planes, self.expansion*planes, + kernel_size=1, stride=stride, bias=False), + nn.BatchNorm2d(self.expansion*planes) + ) + + def forward(self, x): + out = F.relu(self.bn1(self.conv1(x))) + out = F.relu(self.bn2(self.conv2(out))) + out = self.bn3(self.conv3(out)) + out += self.shortcut(x) + out = F.relu(out) + return out + + +class ResNet(nn.Module): + def __init__(self, block, num_blocks, num_classes=10): + super(ResNet, self).__init__() + self.in_planes = 64 + + self.conv1 = nn.Conv2d(3, 64, kernel_size=3, + stride=1, padding=1, bias=False) + self.bn1 = nn.BatchNorm2d(64) + self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) + self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) + self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) + self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) + self.linear = nn.Linear(512*block.expansion, num_classes) + + def _make_layer(self, block, planes, num_blocks, stride): + strides = [stride] + [1]*(num_blocks-1) + layers = [] + for stride in strides: + layers.append(block(self.in_planes, planes, stride)) + self.in_planes = planes * block.expansion + return nn.Sequential(*layers) + + def forward(self, x): + out = F.relu(self.bn1(self.conv1(x))) + out = self.layer1(out) + out = self.layer2(out) + out = self.layer3(out) + out = self.layer4(out) + out = F.avg_pool2d(out, 4) + out = out.view(out.size(0), -1) + out = self.linear(out) + return out + + +def ResNet18(): + return ResNet(BasicBlock, [2, 2, 2, 2]) + + +def ResNet18_Class100(): + return ResNet(BasicBlock, [2, 2, 2, 2], 100) + +# TODO: use a recursive function to replace all linears with final dim 100 + +def ResNet34(): + return ResNet(BasicBlock, [3, 4, 6, 3]) + + +def ResNet50(): + return ResNet(Bottleneck, [3, 4, 6, 3]) + + +def ResNet101(): + return ResNet(Bottleneck, [3, 4, 23, 3]) + + +def ResNet152(): + return ResNet(Bottleneck, [3, 8, 36, 3]) + + +def test(): + net 
= ResNet18() + y = net(torch.randn(1, 3, 32, 32)) + print(y.size()) + +#test() diff --git a/refactor/og_pytorch_cifar.py b/refactor/og_pytorch_cifar.py index 70f6662..7f118d5 100644 --- a/refactor/og_pytorch_cifar.py +++ b/refactor/og_pytorch_cifar.py @@ -20,12 +20,13 @@ from distutils.util import strtobool -from resnet import ResNet18 +from conv_modules import FactConv2dPreExp # TODO: import define_models function +from models.define_models import define_models def save_model(args, model): - src = "/home/mila/m/muawiz.chaudhary/scratch/v1-models/saved-models/CIFAR10_pytorch_tta/" + src= "/home/mila/v/vivian.white/scratch/v1-models/saved-models/test_refactor/" model_dir = src + args.name os.makedirs(model_dir, exist_ok=True) os.chdir(model_dir) @@ -43,7 +44,7 @@ def save_model(args, model): parser.add_argument('--lr', default=0.1, type=float, help='learning rate') parser.add_argument('--resume', '-r', action='store_true', help='resume from checkpoint') -parser.add_argument('--net', type=str, default='vgg', choices=['vgg', 'vggbn', +parser.add_argument('--net', type=str, default='resnet18', choices=['vgg', 'vggbn', 'resnet', 'factnetv1', 'factnetdefault', 'vggfact', 'vggbnfact'], help="which convmodule to use") parser.add_argument('--freeze_spatial', dest='freeze_spatial', type=lambda x: bool(strtobool(x)), default=True, @@ -60,6 +61,7 @@ def save_model(args, model): help='filename for saved model') parser.add_argument('--bias', dest='bias', type=lambda x: bool(strtobool(x)), default=False, help='bias=True or False') +parser.add_argument('--seed', default=0, type=int, help='seed to use') args = parser.parse_args() @@ -98,7 +100,6 @@ def save_model(args, model): print('==> Building model..') -from ConvModules import FactConv2dPreExp def replace_layers_keep_weight(model): for n, module in model.named_children(): if len(list(module.children())) > 0: @@ -123,12 +124,13 @@ def replace_layers_keep_weight(model): set_seeds(args.seed) net = define_models(args.net) 
+replace_layers_keep_weight(net) run_name = args.net - +print("Model Built!") set_seeds(args.seed) net = net.to(device) -wandb_dir = "/home/mila/m/muawiz.chaudhary/scratch/v1-models/wandb" +wandb_dir = "/home/mila/v/vivian.white/scratch/v1-models/wandb" os.makedirs(wandb_dir, exist_ok=True) os.chdir(wandb_dir) run_name = "OGVGG" From 122c7582f87bc1e0e48aaeb764f714ee7b88ecd3 Mon Sep 17 00:00:00 2001 From: vivianwhite Date: Sun, 21 Apr 2024 17:23:56 -0400 Subject: [PATCH 08/77] fixed models dir --- refactor/models/V1_models.py | 385 +++++++++++++++++++++++++++++++ refactor/models/__init__.py | 1 + refactor/models/define_models.py | 6 + refactor/models/resnet.py | 137 +++++++++++ refactor/og_pytorch_cifar.py | 2 +- 5 files changed, 530 insertions(+), 1 deletion(-) create mode 100644 refactor/models/V1_models.py create mode 100644 refactor/models/__init__.py create mode 100644 refactor/models/define_models.py create mode 100644 refactor/models/resnet.py diff --git a/refactor/models/V1_models.py b/refactor/models/V1_models.py new file mode 100644 index 0000000..c16d661 --- /dev/null +++ b/refactor/models/V1_models.py @@ -0,0 +1,385 @@ +import torch.nn as nn +import torch +import torch.nn.functional as F +import sys +sys.path.insert(0, '/research/harris/vivian/structured_random_features/') +from src.models.init_weights import V1_init, classical_init, V1_weights +import gc +import LearnableCov + +def train(model, device, train_loader, optimizer, epoch): + model.train() + for batch_idx, (data, target) in enumerate(train_loader): + data, target = data.to(device), target.to(device) + optimizer.zero_grad() + output = model(data) + loss = F.cross_entropy(output, target) + loss.backward() + optimizer.step() + + +def test(model, device, test_loader, epoch): + model.eval() + test_loss = 0 + correct = 0 + with torch.no_grad(): + for data, target in test_loader: + data, target = data.to(device), target.to(device) + output = model(data) + test_loss += F.cross_entropy(output, target, 
reduction='sum').item() # sum up batch loss + pred = output.max(1, keepdim=True)[1] # get the index of the max log-probability + correct += pred.eq(target.view_as(pred)).sum().item() + + test_loss /= len(test_loader.dataset) + accuracy = 100. * correct / len(test_loader.dataset) + + print('Test Epoch: {}\t Avg Loss: {:.4f}\t Accuracy: {:.2f}%'.format( + epoch, test_loss, accuracy)) + + return test_loss, accuracy + +class V1_CIFAR10(nn.Module): + def __init__(self, hidden_dim, size, spatial_freq, scale, bias, seed=None): + super().__init__() + + # fixed feature layers + self.v1_layer = nn.Conv2d(in_channels=3, out_channels=hidden_dim, + kernel_size=7, stride=1, padding=3, bias=bias) + self.v1_layer2 = nn.Conv2d(in_channels=hidden_dim + 3, out_channels=hidden_dim, + kernel_size=7, stride=1, padding=3, bias=bias) + self.relu = nn.ReLU() + + # unsupervised layers + self.bn_x = nn.BatchNorm2d(3) + self.bn_h1 = nn.BatchNorm2d(hidden_dim + 3) + self.bn_h2 = nn.BatchNorm2d(hidden_dim * 2 + 3) + + # supervised layers + self.clf = nn.Linear((8 ** 2) * (hidden_dim * 2 + 3), 10) + + scale1 = 1 / (3 * 7 * 7) + scale2 = 1 / (hidden_dim * 7 * 7) + center = (3., 3.) 
+ + V1_init(self.v1_layer, size, spatial_freq, center, scale1, bias, seed) + self.v1_layer.weight.requires_grad = False + + V1_init(self.v1_layer2, size, spatial_freq, center, scale2, bias, seed) + self.v1_layer2.weight.requires_grad = False + + if bias: + self.v1_layer.bias.requires_grad = False + self.v1_layer2.bias.requires_grad = False + + def forward(self, x): + # methods + # smooth = nn.AvgPool2d(kernel_size=5, stride=1, padding=2) + smooth = nn.AvgPool2d(kernel_size=3, stride=1, padding=1) + pool = nn.AvgPool2d(kernel_size=4, stride=4, padding=1) + flatten = nn.Flatten() + + x = self.bn_x(x) + h = torch.cat((self.relu(self.v1_layer(x)), smooth(x)), 1) + h = self.bn_h1(h) + h = torch.cat((self.relu(self.v1_layer2(h)), smooth(h)), 1) + h = self.bn_h2(h) + h = flatten(pool(h)) + return self.clf(h) + + +class Rand_Scat_Block(nn.Module): + def __init__(self, in_chan, num_filt, size, spatial_freq, + kernel_size=7, stride=1, padding=3, scale=None, bias=True, seed=None): + super().__init__() + + out_chan = in_chan + num_filt + + self.v1 = nn.Conv2d(in_channels=in_chan, out_channels=num_filt, + kernel_size=kernel_size, stride=stride, padding=padding, + scale=scale, bias=bias) + self.bn = nn.BatchNorm2d(num_filt) + self.relu = nn.ReLU() + + # V1 params + if scale is None: + scale = 1 / (in_chan * np.prod(kernel_size)) + center = ((kernel_size - 1) / 2, (kernel_size - 1) / 2) + + # init weights + V1_init(self.v1, size, spatial_freq, center, scale, bias, seed) + self.v1.weight.requires_grad = False + if bias: + self.v1.bias.requires_grad = False + + def forward(self, x): + smooth = nn.AvgPool2d(kernel_size=5, stride=1, padding=2) + h = self.relu(self.v1(x)) + concat = torch.cat((h, smooth(x)), 1) # concatenate with smoothed input + return self.bn(concat) + +class Rand_Scat_CIFAR10(nn.Module): + pass + +class Learned_Rand_Scat_CIFAR10(nn.Module): + def __init__(self, num_filt, size, spatial_freq, scale, bias, seed=None): + super().__init__() + + # channel dimensions + 
dims = [3, 64, 128] + + # fixed feature layers + self.v1_layer = nn.Conv2d(in_channels=dims[0], out_channels=num_filt, + kernel_size=7, stride=1, padding=3, bias=bias) + self.v1_layer2 = nn.Conv2d(in_channels=dims[1], out_channels=num_filt, + kernel_size=7, stride=1, padding=3, bias=bias) + self.relu = nn.ReLU() + + # unsupervised layers + self.bn_x = nn.BatchNorm2d(3) + self.bn_h1 = nn.LazyBatchNorm2d() + self.bn_h2 = nn.LazyBatchNorm2d() + + # supervised layers + self.L1 = nn.LazyConv2d(out_channels=dims[1], kernel_size=1, bias=False) + self.L2 = nn.LazyConv2d(out_channels=dims[2], kernel_size=1, bias=False) + self.clf = nn.LazyLinear(10) + + # init fixed weights + scale1 = 1 / (3 * 7 * 7) + scale2 = 1 / (num_filt * 7 * 7) + center = (3., 3.) + + V1_init(self.v1_layer, size, spatial_freq, center, scale1, bias, seed) + self.v1_layer.weight.requires_grad = False + + V1_init(self.v1_layer2, size, spatial_freq, center, scale2, bias, seed) + self.v1_layer2.weight.requires_grad = False + + if bias==True: + self.v1_layer.bias.requires_grad = False + self.v1_layer2.bias.requires_grad = False + +# OLD +# def forward(self, x): +# h1 = self.relu(self.v1_layer(self.bn_x(x))) +# h2 = self.relu(self.v1_layer2(self.bn_h1(h1))) + +# pool = nn.AvgPool2d(kernel_size=4, stride=4, padding=1) +# x_pool = self.bn0(pool(x)) +# h1_pool = self.bn1(pool(h1)) +# h2_pool = self.bn2(pool(h2)) + +# x_flat = x_pool.view(x_pool.size(0), -1) #view +# h1_flat = h1_pool.view(h1_pool.size(0), -1) #view +# h2_flat = h2_pool.view(h2_pool.size(0), -1) #view + + +# concat = torch.cat((x_flat, h1_flat, h2_flat), 1) + + def forward(self, x): + # methods + smooth = nn.AvgPool2d(kernel_size=5, stride=1, padding=2) + pool = nn.AvgPool2d(kernel_size=4, stride=4, padding=1) + flatten = nn.Flatten() + + # do pass + x = self.bn_x(x) + h1 = self.relu(self.v1_layer(x)) + h1 = torch.cat((h1, smooth(x)), 1) + h1 = self.bn_h1(h1) # chan: 3 + num_filt + h2 = self.relu(self.v1_layer2(self.L1(h1))) + h2 = 
torch.cat((h2, smooth(h1)), 1) # chan: 3 + num_filt + dims[1] + h2 = self.bn_h2(h2) + + concat = flatten(pool(h2)) + + beta = self.clf(concat) + + return beta + + +class V1_CIFAR100(nn.Module): + def __init__(self, hidden_dim, size, spatial_freq, scale, bias, seed=None): + super(V1_CIFAR100, self).__init__() + self.v1_layer = nn.Conv2d(in_channels=3, out_channels=hidden_dim, kernel_size=7, stride=1, padding=3, + bias=bias) + self.v1_layer2 = nn.Conv2d(in_channels=hidden_dim, out_channels=hidden_dim, kernel_size=7, stride=1, padding=3, + bias=bias) + self.clf = nn.Linear((3 * (8 ** 2)) + (hidden_dim * (8 ** 2)) + (hidden_dim * (8 ** 2)), 100) + self.relu = nn.ReLU() + self.bn = nn.BatchNorm2d(3) + self.bn0 = nn.BatchNorm2d(3) + self.bn1 = nn.BatchNorm2d(hidden_dim) + self.bn2 = nn.BatchNorm2d(hidden_dim) + self.bn_h1 = nn.BatchNorm2d(hidden_dim) + + scale1 = 1 / (3 * 7 * 7) + scale2 = 1 / (hidden_dim * 7 * 7) + center = None + + V1_init(self.v1_layer, size, spatial_freq, center, scale1, bias, seed) + self.v1_layer.weight.requires_grad = False + + V1_init(self.v1_layer2, size, spatial_freq, center, scale2, bias, seed) + self.v1_layer2.weight.requires_grad = False + + if bias==True: + self.v1_layer.bias.requires_grad = False + self.v1_layer2.bias.requires_grad = False + + def forward(self, x): #[128, 3, 32, 32] + h1 = self.relu(self.v1_layer(self.bn(x))) + h2 = self.relu(self.v1_layer2(self.bn_h1(h1))) + + pool = nn.AvgPool2d(kernel_size=4, stride=4, padding=1) + x_pool = self.bn0(pool(x)) + h1_pool = self.bn1(pool(h1)) + h2_pool = self.bn2(pool(h2)) + + x_flat = x_pool.view(x_pool.size(0), -1) + h1_flat = h1_pool.view(h1_pool.size(0), -1) + h2_flat = h2_pool.view(h2_pool.size(0), -1) + + concat = torch.cat((x_flat, h1_flat, h2_flat), 1) + + beta = self.clf(concat) + return beta + +class V1_MNIST(nn.Module): + def __init__(self, hidden_dim, size, spatial_freq, scale, bias, seed=None): + super(V1_MNIST, self).__init__() + self.v1_layer = nn.Conv2d(in_channels=1, 
out_channels=hidden_dim, kernel_size=7, stride=1, padding=3, + bias=bias) + self.v1_layer2 = nn.Conv2d(in_channels=hidden_dim, out_channels=hidden_dim, kernel_size=7, stride=1, padding=3, + bias=bias) + self.clf = nn.Linear((1 * (8 ** 2)) + (hidden_dim * (8 ** 2)) + (hidden_dim * (8 ** 2)), 100) + self.relu = nn.ReLU() + self.bn = nn.BatchNorm2d(1) + self.bn0 = nn.BatchNorm2d(1) + self.bn1 = nn.BatchNorm2d(hidden_dim) + self.bn2 = nn.BatchNorm2d(hidden_dim) + + scale1 = 1 / (3 * 7 * 7) + scale2 = 1 / (hidden_dim * 7 * 7) + center = None + + V1_init(self.v1_layer, size, spatial_freq, center, scale1, bias, seed) + self.v1_layer.weight.requires_grad = False + + V1_init(self.v1_layer2, size, spatial_freq, center, scale2, bias, seed) + self.v1_layer2.weight.requires_grad = False + + if bias==True: + self.v1_layer.bias.requires_grad = False + self.v1_layer2.bias.requires_grad = False + + def forward(self, x): #[128, 1, 28, 28] + h1 = self.relu(self.v1_layer(self.bn(x))) + h2 = self.relu(self.v1_layer2(h1)) + + pool = nn.AvgPool2d(kernel_size=4, stride=4, padding=2) + x_pool = self.bn0(pool(x)) + h1_pool = self.bn1(pool(h1)) + h2_pool = self.bn2(pool(h2)) + + x_flat = x_pool.view(x_pool.size(0), -1) + h1_flat = h1_pool.view(h1_pool.size(0), -1) + h2_flat = h2_pool.view(h2_pool.size(0), -1) + + concat = torch.cat((x_flat, h1_flat, h2_flat), 1) + + beta = self.clf(concat) + return beta + +class Scattering_V1_MNIST(nn.Module): + def __init__(self, hidden_dim, size, spatial_freq, scale, bias, seed=None): + super(Scattering_V1_MNIST, self).__init__() + self.v1_layer = nn.Conv2d(in_channels=1, out_channels=hidden_dim, kernel_size=7, stride=1, padding=3, + bias=bias) + self.v1_layer2 = nn.Conv2d(in_channels=hidden_dim, out_channels=hidden_dim, kernel_size=7, stride=1, padding=3, + bias=bias) + self.relu = nn.ReLU() + self.bn = nn.BatchNorm2d(1) + self.bn0 = nn.BatchNorm2d(1) + self.bn1 = nn.BatchNorm2d(hidden_dim) + self.bn2 = nn.BatchNorm2d(hidden_dim) + + scale1 = 1 / (3 * 7 * 
7) + scale2 = 1 / (hidden_dim * 7 * 7) + center = None + + V1_init(self.v1_layer, size, spatial_freq, center, scale1, bias, seed) + self.v1_layer.weight.requires_grad = False + + V1_init(self.v1_layer2, size, spatial_freq, center, scale2, bias, seed) + self.v1_layer2.weight.requires_grad = False + + if bias==True: + self.v1_layer.bias.requires_grad = False + self.v1_layer2.bias.requires_grad = False + + def forward(self, x): #[128, 1, 28, 28] + h1 = self.relu(self.v1_layer(self.bn(x))) + h2 = self.relu(self.v1_layer2(h1)) + + pool = nn.AvgPool2d(kernel_size=4, stride=4, padding=2) + x_pool = self.bn0(pool(x)) + h1_pool = self.bn1(pool(h1)) + h2_pool = self.bn2(pool(h2)) + + x_flat = x_pool.view(x_pool.size(0), -1) + h1_flat = h1_pool.view(h1_pool.size(0), -1) + h2_flat = h2_pool.view(h2_pool.size(0), -1) + + concat = torch.cat((x_flat, h1_flat, h2_flat), 1) + return concat + +class Scattering_V1_celeba(nn.Module): + def __init__(self, hidden_dim, size, spatial_freq, scale, bias, seed=None): + super().__init__() + + # fixed feature layers + self.v1_layer = nn.Conv2d(in_channels=3, out_channels=hidden_dim, + kernel_size=7, stride=1, padding=3, bias=bias) + self.v1_layer2 = nn.Conv2d(in_channels=hidden_dim + 3, out_channels=hidden_dim, + kernel_size=7, stride=1, padding=3, bias=bias) + self.relu = nn.ReLU() + + # unsupervised layers + + self.bn_x = nn.BatchNorm2d(3) + self.bn_h1 = nn.BatchNorm2d(hidden_dim + 3) + self.bn_h2 = nn.BatchNorm2d(hidden_dim * 2 + 3) + + + # supervised layers + + scale1 = 1 / (3 * 7 * 7) + scale2 = 1 / (hidden_dim * 7 * 7) + center = (3., 3.) 
+ + V1_init(self.v1_layer, size, spatial_freq, center, scale1, bias, seed) + self.v1_layer.weight.requires_grad = False + + V1_init(self.v1_layer2, size, spatial_freq, center, scale2, bias, seed) + self.v1_layer2.weight.requires_grad = False + + if bias: + self.v1_layer.bias.requires_grad = False + self.v1_layer2.bias.requires_grad = False + + def forward(self, x): + # methods + # smooth = nn.AvgPool2d(kernel_size=5, stride=1, padding=2) + smooth = nn.AvgPool2d(kernel_size=3, stride=1, padding=1) + pool = nn.AvgPool2d(kernel_size=4, stride=4, padding=1) + flatten = nn.Flatten() + + x = self.bn_x(x) + h = torch.cat((self.relu(self.v1_layer(x)), smooth(x)), 1) + h = self.bn_h1(h) + h = torch.cat((self.relu(self.v1_layer2(h)), smooth(h)), 1) + h = self.bn_h2(h) + h = flatten(pool(h)) + return h + diff --git a/refactor/models/__init__.py b/refactor/models/__init__.py new file mode 100644 index 0000000..9c7178b --- /dev/null +++ b/refactor/models/__init__.py @@ -0,0 +1 @@ +from .resnet import ResNet18 diff --git a/refactor/models/define_models.py b/refactor/models/define_models.py new file mode 100644 index 0000000..904af30 --- /dev/null +++ b/refactor/models/define_models.py @@ -0,0 +1,6 @@ +from .resnet import ResNet18 + +def define_models(args): + if args.net == 'resnet18': + model = ResNet18() + return model diff --git a/refactor/models/resnet.py b/refactor/models/resnet.py new file mode 100644 index 0000000..9177d5f --- /dev/null +++ b/refactor/models/resnet.py @@ -0,0 +1,137 @@ +'''ResNet in PyTorch. + +For Pre-activation ResNet, see 'preact_resnet.py'. + +Reference: +[1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun + Deep Residual Learning for Image Recognition. 
arXiv:1512.03385 +''' +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class BasicBlock(nn.Module): + expansion = 1 + + def __init__(self, in_planes, planes, stride=1): + super(BasicBlock, self).__init__() + self.conv1 = nn.Conv2d( + in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) + self.bn1 = nn.BatchNorm2d(planes) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, + stride=1, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(planes) + + self.shortcut = nn.Sequential() + if stride != 1 or in_planes != self.expansion*planes: + self.shortcut = nn.Sequential( + nn.Conv2d(in_planes, self.expansion*planes, + kernel_size=1, stride=stride, bias=False), + nn.BatchNorm2d(self.expansion*planes) + ) + + def forward(self, x): + out = F.relu(self.bn1(self.conv1(x))) + out = self.bn2(self.conv2(out)) + out += self.shortcut(x) + out = F.relu(out) + return out + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, in_planes, planes, stride=1): + super(Bottleneck, self).__init__() + self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) + self.bn1 = nn.BatchNorm2d(planes) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, + stride=stride, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(planes) + self.conv3 = nn.Conv2d(planes, self.expansion * + planes, kernel_size=1, bias=False) + self.bn3 = nn.BatchNorm2d(self.expansion*planes) + + self.shortcut = nn.Sequential() + if stride != 1 or in_planes != self.expansion*planes: + self.shortcut = nn.Sequential( + nn.Conv2d(in_planes, self.expansion*planes, + kernel_size=1, stride=stride, bias=False), + nn.BatchNorm2d(self.expansion*planes) + ) + + def forward(self, x): + out = F.relu(self.bn1(self.conv1(x))) + out = F.relu(self.bn2(self.conv2(out))) + out = self.bn3(self.conv3(out)) + out += self.shortcut(x) + out = F.relu(out) + return out + + +class ResNet(nn.Module): + def __init__(self, block, num_blocks, num_classes=10): + super(ResNet, 
self).__init__() + self.in_planes = 64 + + self.conv1 = nn.Conv2d(3, 64, kernel_size=3, + stride=1, padding=1, bias=False) + self.bn1 = nn.BatchNorm2d(64) + self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) + self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) + self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) + self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) + self.linear = nn.Linear(512*block.expansion, num_classes) + + def _make_layer(self, block, planes, num_blocks, stride): + strides = [stride] + [1]*(num_blocks-1) + layers = [] + for stride in strides: + layers.append(block(self.in_planes, planes, stride)) + self.in_planes = planes * block.expansion + return nn.Sequential(*layers) + + def forward(self, x): + out = F.relu(self.bn1(self.conv1(x))) + out = self.layer1(out) + out = self.layer2(out) + out = self.layer3(out) + out = self.layer4(out) + out = F.avg_pool2d(out, 4) + out = out.view(out.size(0), -1) + out = self.linear(out) + return out + + +def ResNet18(): + return ResNet(BasicBlock, [2, 2, 2, 2]) + + +def ResNet18_Class100(): + return ResNet(BasicBlock, [2, 2, 2, 2], 100) + +# TODO: use a recursive function to replace all linears with final dim 100 + +def ResNet34(): + return ResNet(BasicBlock, [3, 4, 6, 3]) + + +def ResNet50(): + return ResNet(Bottleneck, [3, 4, 6, 3]) + + +def ResNet101(): + return ResNet(Bottleneck, [3, 4, 23, 3]) + + +def ResNet152(): + return ResNet(Bottleneck, [3, 8, 36, 3]) + + +def test(): + net = ResNet18() + y = net(torch.randn(1, 3, 32, 32)) + print(y.size()) + +#test() diff --git a/refactor/og_pytorch_cifar.py b/refactor/og_pytorch_cifar.py index 7f118d5..c75e23f 100644 --- a/refactor/og_pytorch_cifar.py +++ b/refactor/og_pytorch_cifar.py @@ -123,7 +123,7 @@ def replace_layers_keep_weight(model): setattr(model, n, new_module) set_seeds(args.seed) -net = define_models(args.net) +net = define_models(args) replace_layers_keep_weight(net) run_name = 
args.net print("Model Built!") From 28841fea97286f8e7b3ef4abf9a13293d42e31b0 Mon Sep 17 00:00:00 2001 From: vivianwhite <66977221+vivianwhite@users.noreply.github.com> Date: Sun, 21 Apr 2024 14:32:08 -0700 Subject: [PATCH 09/77] Delete models directory --- models/V1_models.py | 385 ---------------------------------------- models/__init__.py | 1 - models/define_models.py | 6 - models/resnet.py | 137 -------------- 4 files changed, 529 deletions(-) delete mode 100644 models/V1_models.py delete mode 100644 models/__init__.py delete mode 100644 models/define_models.py delete mode 100644 models/resnet.py diff --git a/models/V1_models.py b/models/V1_models.py deleted file mode 100644 index c16d661..0000000 --- a/models/V1_models.py +++ /dev/null @@ -1,385 +0,0 @@ -import torch.nn as nn -import torch -import torch.nn.functional as F -import sys -sys.path.insert(0, '/research/harris/vivian/structured_random_features/') -from src.models.init_weights import V1_init, classical_init, V1_weights -import gc -import LearnableCov - -def train(model, device, train_loader, optimizer, epoch): - model.train() - for batch_idx, (data, target) in enumerate(train_loader): - data, target = data.to(device), target.to(device) - optimizer.zero_grad() - output = model(data) - loss = F.cross_entropy(output, target) - loss.backward() - optimizer.step() - - -def test(model, device, test_loader, epoch): - model.eval() - test_loss = 0 - correct = 0 - with torch.no_grad(): - for data, target in test_loader: - data, target = data.to(device), target.to(device) - output = model(data) - test_loss += F.cross_entropy(output, target, reduction='sum').item() # sum up batch loss - pred = output.max(1, keepdim=True)[1] # get the index of the max log-probability - correct += pred.eq(target.view_as(pred)).sum().item() - - test_loss /= len(test_loader.dataset) - accuracy = 100. 
* correct / len(test_loader.dataset) - - print('Test Epoch: {}\t Avg Loss: {:.4f}\t Accuracy: {:.2f}%'.format( - epoch, test_loss, accuracy)) - - return test_loss, accuracy - -class V1_CIFAR10(nn.Module): - def __init__(self, hidden_dim, size, spatial_freq, scale, bias, seed=None): - super().__init__() - - # fixed feature layers - self.v1_layer = nn.Conv2d(in_channels=3, out_channels=hidden_dim, - kernel_size=7, stride=1, padding=3, bias=bias) - self.v1_layer2 = nn.Conv2d(in_channels=hidden_dim + 3, out_channels=hidden_dim, - kernel_size=7, stride=1, padding=3, bias=bias) - self.relu = nn.ReLU() - - # unsupervised layers - self.bn_x = nn.BatchNorm2d(3) - self.bn_h1 = nn.BatchNorm2d(hidden_dim + 3) - self.bn_h2 = nn.BatchNorm2d(hidden_dim * 2 + 3) - - # supervised layers - self.clf = nn.Linear((8 ** 2) * (hidden_dim * 2 + 3), 10) - - scale1 = 1 / (3 * 7 * 7) - scale2 = 1 / (hidden_dim * 7 * 7) - center = (3., 3.) - - V1_init(self.v1_layer, size, spatial_freq, center, scale1, bias, seed) - self.v1_layer.weight.requires_grad = False - - V1_init(self.v1_layer2, size, spatial_freq, center, scale2, bias, seed) - self.v1_layer2.weight.requires_grad = False - - if bias: - self.v1_layer.bias.requires_grad = False - self.v1_layer2.bias.requires_grad = False - - def forward(self, x): - # methods - # smooth = nn.AvgPool2d(kernel_size=5, stride=1, padding=2) - smooth = nn.AvgPool2d(kernel_size=3, stride=1, padding=1) - pool = nn.AvgPool2d(kernel_size=4, stride=4, padding=1) - flatten = nn.Flatten() - - x = self.bn_x(x) - h = torch.cat((self.relu(self.v1_layer(x)), smooth(x)), 1) - h = self.bn_h1(h) - h = torch.cat((self.relu(self.v1_layer2(h)), smooth(h)), 1) - h = self.bn_h2(h) - h = flatten(pool(h)) - return self.clf(h) - - -class Rand_Scat_Block(nn.Module): - def __init__(self, in_chan, num_filt, size, spatial_freq, - kernel_size=7, stride=1, padding=3, scale=None, bias=True, seed=None): - super().__init__() - - out_chan = in_chan + num_filt - - self.v1 = 
nn.Conv2d(in_channels=in_chan, out_channels=num_filt, - kernel_size=kernel_size, stride=stride, padding=padding, - scale=scale, bias=bias) - self.bn = nn.BatchNorm2d(num_filt) - self.relu = nn.ReLU() - - # V1 params - if scale is None: - scale = 1 / (in_chan * np.prod(kernel_size)) - center = ((kernel_size - 1) / 2, (kernel_size - 1) / 2) - - # init weights - V1_init(self.v1, size, spatial_freq, center, scale, bias, seed) - self.v1.weight.requires_grad = False - if bias: - self.v1.bias.requires_grad = False - - def forward(self, x): - smooth = nn.AvgPool2d(kernel_size=5, stride=1, padding=2) - h = self.relu(self.v1(x)) - concat = torch.cat((h, smooth(x)), 1) # concatenate with smoothed input - return self.bn(concat) - -class Rand_Scat_CIFAR10(nn.Module): - pass - -class Learned_Rand_Scat_CIFAR10(nn.Module): - def __init__(self, num_filt, size, spatial_freq, scale, bias, seed=None): - super().__init__() - - # channel dimensions - dims = [3, 64, 128] - - # fixed feature layers - self.v1_layer = nn.Conv2d(in_channels=dims[0], out_channels=num_filt, - kernel_size=7, stride=1, padding=3, bias=bias) - self.v1_layer2 = nn.Conv2d(in_channels=dims[1], out_channels=num_filt, - kernel_size=7, stride=1, padding=3, bias=bias) - self.relu = nn.ReLU() - - # unsupervised layers - self.bn_x = nn.BatchNorm2d(3) - self.bn_h1 = nn.LazyBatchNorm2d() - self.bn_h2 = nn.LazyBatchNorm2d() - - # supervised layers - self.L1 = nn.LazyConv2d(out_channels=dims[1], kernel_size=1, bias=False) - self.L2 = nn.LazyConv2d(out_channels=dims[2], kernel_size=1, bias=False) - self.clf = nn.LazyLinear(10) - - # init fixed weights - scale1 = 1 / (3 * 7 * 7) - scale2 = 1 / (num_filt * 7 * 7) - center = (3., 3.) 
- - V1_init(self.v1_layer, size, spatial_freq, center, scale1, bias, seed) - self.v1_layer.weight.requires_grad = False - - V1_init(self.v1_layer2, size, spatial_freq, center, scale2, bias, seed) - self.v1_layer2.weight.requires_grad = False - - if bias==True: - self.v1_layer.bias.requires_grad = False - self.v1_layer2.bias.requires_grad = False - -# OLD -# def forward(self, x): -# h1 = self.relu(self.v1_layer(self.bn_x(x))) -# h2 = self.relu(self.v1_layer2(self.bn_h1(h1))) - -# pool = nn.AvgPool2d(kernel_size=4, stride=4, padding=1) -# x_pool = self.bn0(pool(x)) -# h1_pool = self.bn1(pool(h1)) -# h2_pool = self.bn2(pool(h2)) - -# x_flat = x_pool.view(x_pool.size(0), -1) #view -# h1_flat = h1_pool.view(h1_pool.size(0), -1) #view -# h2_flat = h2_pool.view(h2_pool.size(0), -1) #view - - -# concat = torch.cat((x_flat, h1_flat, h2_flat), 1) - - def forward(self, x): - # methods - smooth = nn.AvgPool2d(kernel_size=5, stride=1, padding=2) - pool = nn.AvgPool2d(kernel_size=4, stride=4, padding=1) - flatten = nn.Flatten() - - # do pass - x = self.bn_x(x) - h1 = self.relu(self.v1_layer(x)) - h1 = torch.cat((h1, smooth(x)), 1) - h1 = self.bn_h1(h1) # chan: 3 + num_filt - h2 = self.relu(self.v1_layer2(self.L1(h1))) - h2 = torch.cat((h2, smooth(h1)), 1) # chan: 3 + num_filt + dims[1] - h2 = self.bn_h2(h2) - - concat = flatten(pool(h2)) - - beta = self.clf(concat) - - return beta - - -class V1_CIFAR100(nn.Module): - def __init__(self, hidden_dim, size, spatial_freq, scale, bias, seed=None): - super(V1_CIFAR100, self).__init__() - self.v1_layer = nn.Conv2d(in_channels=3, out_channels=hidden_dim, kernel_size=7, stride=1, padding=3, - bias=bias) - self.v1_layer2 = nn.Conv2d(in_channels=hidden_dim, out_channels=hidden_dim, kernel_size=7, stride=1, padding=3, - bias=bias) - self.clf = nn.Linear((3 * (8 ** 2)) + (hidden_dim * (8 ** 2)) + (hidden_dim * (8 ** 2)), 100) - self.relu = nn.ReLU() - self.bn = nn.BatchNorm2d(3) - self.bn0 = nn.BatchNorm2d(3) - self.bn1 = 
nn.BatchNorm2d(hidden_dim) - self.bn2 = nn.BatchNorm2d(hidden_dim) - self.bn_h1 = nn.BatchNorm2d(hidden_dim) - - scale1 = 1 / (3 * 7 * 7) - scale2 = 1 / (hidden_dim * 7 * 7) - center = None - - V1_init(self.v1_layer, size, spatial_freq, center, scale1, bias, seed) - self.v1_layer.weight.requires_grad = False - - V1_init(self.v1_layer2, size, spatial_freq, center, scale2, bias, seed) - self.v1_layer2.weight.requires_grad = False - - if bias==True: - self.v1_layer.bias.requires_grad = False - self.v1_layer2.bias.requires_grad = False - - def forward(self, x): #[128, 3, 32, 32] - h1 = self.relu(self.v1_layer(self.bn(x))) - h2 = self.relu(self.v1_layer2(self.bn_h1(h1))) - - pool = nn.AvgPool2d(kernel_size=4, stride=4, padding=1) - x_pool = self.bn0(pool(x)) - h1_pool = self.bn1(pool(h1)) - h2_pool = self.bn2(pool(h2)) - - x_flat = x_pool.view(x_pool.size(0), -1) - h1_flat = h1_pool.view(h1_pool.size(0), -1) - h2_flat = h2_pool.view(h2_pool.size(0), -1) - - concat = torch.cat((x_flat, h1_flat, h2_flat), 1) - - beta = self.clf(concat) - return beta - -class V1_MNIST(nn.Module): - def __init__(self, hidden_dim, size, spatial_freq, scale, bias, seed=None): - super(V1_MNIST, self).__init__() - self.v1_layer = nn.Conv2d(in_channels=1, out_channels=hidden_dim, kernel_size=7, stride=1, padding=3, - bias=bias) - self.v1_layer2 = nn.Conv2d(in_channels=hidden_dim, out_channels=hidden_dim, kernel_size=7, stride=1, padding=3, - bias=bias) - self.clf = nn.Linear((1 * (8 ** 2)) + (hidden_dim * (8 ** 2)) + (hidden_dim * (8 ** 2)), 100) - self.relu = nn.ReLU() - self.bn = nn.BatchNorm2d(1) - self.bn0 = nn.BatchNorm2d(1) - self.bn1 = nn.BatchNorm2d(hidden_dim) - self.bn2 = nn.BatchNorm2d(hidden_dim) - - scale1 = 1 / (3 * 7 * 7) - scale2 = 1 / (hidden_dim * 7 * 7) - center = None - - V1_init(self.v1_layer, size, spatial_freq, center, scale1, bias, seed) - self.v1_layer.weight.requires_grad = False - - V1_init(self.v1_layer2, size, spatial_freq, center, scale2, bias, seed) - 
self.v1_layer2.weight.requires_grad = False - - if bias==True: - self.v1_layer.bias.requires_grad = False - self.v1_layer2.bias.requires_grad = False - - def forward(self, x): #[128, 1, 28, 28] - h1 = self.relu(self.v1_layer(self.bn(x))) - h2 = self.relu(self.v1_layer2(h1)) - - pool = nn.AvgPool2d(kernel_size=4, stride=4, padding=2) - x_pool = self.bn0(pool(x)) - h1_pool = self.bn1(pool(h1)) - h2_pool = self.bn2(pool(h2)) - - x_flat = x_pool.view(x_pool.size(0), -1) - h1_flat = h1_pool.view(h1_pool.size(0), -1) - h2_flat = h2_pool.view(h2_pool.size(0), -1) - - concat = torch.cat((x_flat, h1_flat, h2_flat), 1) - - beta = self.clf(concat) - return beta - -class Scattering_V1_MNIST(nn.Module): - def __init__(self, hidden_dim, size, spatial_freq, scale, bias, seed=None): - super(Scattering_V1_MNIST, self).__init__() - self.v1_layer = nn.Conv2d(in_channels=1, out_channels=hidden_dim, kernel_size=7, stride=1, padding=3, - bias=bias) - self.v1_layer2 = nn.Conv2d(in_channels=hidden_dim, out_channels=hidden_dim, kernel_size=7, stride=1, padding=3, - bias=bias) - self.relu = nn.ReLU() - self.bn = nn.BatchNorm2d(1) - self.bn0 = nn.BatchNorm2d(1) - self.bn1 = nn.BatchNorm2d(hidden_dim) - self.bn2 = nn.BatchNorm2d(hidden_dim) - - scale1 = 1 / (3 * 7 * 7) - scale2 = 1 / (hidden_dim * 7 * 7) - center = None - - V1_init(self.v1_layer, size, spatial_freq, center, scale1, bias, seed) - self.v1_layer.weight.requires_grad = False - - V1_init(self.v1_layer2, size, spatial_freq, center, scale2, bias, seed) - self.v1_layer2.weight.requires_grad = False - - if bias==True: - self.v1_layer.bias.requires_grad = False - self.v1_layer2.bias.requires_grad = False - - def forward(self, x): #[128, 1, 28, 28] - h1 = self.relu(self.v1_layer(self.bn(x))) - h2 = self.relu(self.v1_layer2(h1)) - - pool = nn.AvgPool2d(kernel_size=4, stride=4, padding=2) - x_pool = self.bn0(pool(x)) - h1_pool = self.bn1(pool(h1)) - h2_pool = self.bn2(pool(h2)) - - x_flat = x_pool.view(x_pool.size(0), -1) - h1_flat = 
h1_pool.view(h1_pool.size(0), -1) - h2_flat = h2_pool.view(h2_pool.size(0), -1) - - concat = torch.cat((x_flat, h1_flat, h2_flat), 1) - return concat - -class Scattering_V1_celeba(nn.Module): - def __init__(self, hidden_dim, size, spatial_freq, scale, bias, seed=None): - super().__init__() - - # fixed feature layers - self.v1_layer = nn.Conv2d(in_channels=3, out_channels=hidden_dim, - kernel_size=7, stride=1, padding=3, bias=bias) - self.v1_layer2 = nn.Conv2d(in_channels=hidden_dim + 3, out_channels=hidden_dim, - kernel_size=7, stride=1, padding=3, bias=bias) - self.relu = nn.ReLU() - - # unsupervised layers - - self.bn_x = nn.BatchNorm2d(3) - self.bn_h1 = nn.BatchNorm2d(hidden_dim + 3) - self.bn_h2 = nn.BatchNorm2d(hidden_dim * 2 + 3) - - - # supervised layers - - scale1 = 1 / (3 * 7 * 7) - scale2 = 1 / (hidden_dim * 7 * 7) - center = (3., 3.) - - V1_init(self.v1_layer, size, spatial_freq, center, scale1, bias, seed) - self.v1_layer.weight.requires_grad = False - - V1_init(self.v1_layer2, size, spatial_freq, center, scale2, bias, seed) - self.v1_layer2.weight.requires_grad = False - - if bias: - self.v1_layer.bias.requires_grad = False - self.v1_layer2.bias.requires_grad = False - - def forward(self, x): - # methods - # smooth = nn.AvgPool2d(kernel_size=5, stride=1, padding=2) - smooth = nn.AvgPool2d(kernel_size=3, stride=1, padding=1) - pool = nn.AvgPool2d(kernel_size=4, stride=4, padding=1) - flatten = nn.Flatten() - - x = self.bn_x(x) - h = torch.cat((self.relu(self.v1_layer(x)), smooth(x)), 1) - h = self.bn_h1(h) - h = torch.cat((self.relu(self.v1_layer2(h)), smooth(h)), 1) - h = self.bn_h2(h) - h = flatten(pool(h)) - return h - diff --git a/models/__init__.py b/models/__init__.py deleted file mode 100644 index 1e59b87..0000000 --- a/models/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from resnet import ResNet18 diff --git a/models/define_models.py b/models/define_models.py deleted file mode 100644 index 34de128..0000000 --- a/models/define_models.py +++ /dev/null 
@@ -1,6 +0,0 @@ -from resnet import ResNet18 - -def define_models(args): - if args.net == 'resnet18': - model = ResNet18() - return model diff --git a/models/resnet.py b/models/resnet.py deleted file mode 100644 index 9177d5f..0000000 --- a/models/resnet.py +++ /dev/null @@ -1,137 +0,0 @@ -'''ResNet in PyTorch. - -For Pre-activation ResNet, see 'preact_resnet.py'. - -Reference: -[1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun - Deep Residual Learning for Image Recognition. arXiv:1512.03385 -''' -import torch -import torch.nn as nn -import torch.nn.functional as F - - -class BasicBlock(nn.Module): - expansion = 1 - - def __init__(self, in_planes, planes, stride=1): - super(BasicBlock, self).__init__() - self.conv1 = nn.Conv2d( - in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) - self.bn1 = nn.BatchNorm2d(planes) - self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, - stride=1, padding=1, bias=False) - self.bn2 = nn.BatchNorm2d(planes) - - self.shortcut = nn.Sequential() - if stride != 1 or in_planes != self.expansion*planes: - self.shortcut = nn.Sequential( - nn.Conv2d(in_planes, self.expansion*planes, - kernel_size=1, stride=stride, bias=False), - nn.BatchNorm2d(self.expansion*planes) - ) - - def forward(self, x): - out = F.relu(self.bn1(self.conv1(x))) - out = self.bn2(self.conv2(out)) - out += self.shortcut(x) - out = F.relu(out) - return out - - -class Bottleneck(nn.Module): - expansion = 4 - - def __init__(self, in_planes, planes, stride=1): - super(Bottleneck, self).__init__() - self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) - self.bn1 = nn.BatchNorm2d(planes) - self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, - stride=stride, padding=1, bias=False) - self.bn2 = nn.BatchNorm2d(planes) - self.conv3 = nn.Conv2d(planes, self.expansion * - planes, kernel_size=1, bias=False) - self.bn3 = nn.BatchNorm2d(self.expansion*planes) - - self.shortcut = nn.Sequential() - if stride != 1 or in_planes != 
self.expansion*planes: - self.shortcut = nn.Sequential( - nn.Conv2d(in_planes, self.expansion*planes, - kernel_size=1, stride=stride, bias=False), - nn.BatchNorm2d(self.expansion*planes) - ) - - def forward(self, x): - out = F.relu(self.bn1(self.conv1(x))) - out = F.relu(self.bn2(self.conv2(out))) - out = self.bn3(self.conv3(out)) - out += self.shortcut(x) - out = F.relu(out) - return out - - -class ResNet(nn.Module): - def __init__(self, block, num_blocks, num_classes=10): - super(ResNet, self).__init__() - self.in_planes = 64 - - self.conv1 = nn.Conv2d(3, 64, kernel_size=3, - stride=1, padding=1, bias=False) - self.bn1 = nn.BatchNorm2d(64) - self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) - self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) - self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) - self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) - self.linear = nn.Linear(512*block.expansion, num_classes) - - def _make_layer(self, block, planes, num_blocks, stride): - strides = [stride] + [1]*(num_blocks-1) - layers = [] - for stride in strides: - layers.append(block(self.in_planes, planes, stride)) - self.in_planes = planes * block.expansion - return nn.Sequential(*layers) - - def forward(self, x): - out = F.relu(self.bn1(self.conv1(x))) - out = self.layer1(out) - out = self.layer2(out) - out = self.layer3(out) - out = self.layer4(out) - out = F.avg_pool2d(out, 4) - out = out.view(out.size(0), -1) - out = self.linear(out) - return out - - -def ResNet18(): - return ResNet(BasicBlock, [2, 2, 2, 2]) - - -def ResNet18_Class100(): - return ResNet(BasicBlock, [2, 2, 2, 2], 100) - -# TODO: use a recursive function to replace all linears with final dim 100 - -def ResNet34(): - return ResNet(BasicBlock, [3, 4, 6, 3]) - - -def ResNet50(): - return ResNet(Bottleneck, [3, 4, 6, 3]) - - -def ResNet101(): - return ResNet(Bottleneck, [3, 4, 23, 3]) - - -def ResNet152(): - return ResNet(Bottleneck, [3, 8, 36, 
3]) - - -def test(): - net = ResNet18() - y = net(torch.randn(1, 3, 32, 32)) - print(y.size()) - -#test() From 7ddc12c2f9b4febe2f784742f6e760655b2bd014 Mon Sep 17 00:00:00 2001 From: vivianwhite Date: Sun, 21 Apr 2024 17:36:13 -0400 Subject: [PATCH 10/77] add utils file --- refactor/pytorch_cifar_utils.py | 135 ++++++++++++++++++++++++++++++++ 1 file changed, 135 insertions(+) create mode 100644 refactor/pytorch_cifar_utils.py diff --git a/refactor/pytorch_cifar_utils.py b/refactor/pytorch_cifar_utils.py new file mode 100644 index 0000000..08b9e15 --- /dev/null +++ b/refactor/pytorch_cifar_utils.py @@ -0,0 +1,135 @@ +'''Some helper functions for PyTorch, including: + - get_mean_and_std: calculate the mean and std value of dataset. + - msr_init: net parameter initialization. + - progress_bar: progress bar mimic xlua.progress. +''' +import os +import sys +import time +import math + +import torch +import torch.nn as nn +import torch.nn.init as init +import random +import numpy as np + +def set_seeds(seed): + torch.manual_seed(seed) + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + random.seed(seed) + np.random.seed(seed) + + + +def get_mean_and_std(dataset): + '''Compute the mean and std value of dataset.''' + dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=2) + mean = torch.zeros(3) + std = torch.zeros(3) + print('==> Computing mean and std..') + for inputs, targets in dataloader: + for i in range(3): + mean[i] += inputs[:,i,:,:].mean() + std[i] += inputs[:,i,:,:].std() + mean.div_(len(dataset)) + std.div_(len(dataset)) + return mean, std + +def init_params(net): + '''Init layer parameters.''' + for m in net.modules(): + if isinstance(m, nn.Conv2d): + init.kaiming_normal(m.weight, mode='fan_out') + if m.bias: + init.constant(m.bias, 0) + elif isinstance(m, nn.BatchNorm2d): + init.constant(m.weight, 1) + init.constant(m.bias, 0) + elif isinstance(m, nn.Linear): + init.normal(m.weight, 
std=1e-3) + if m.bias: + init.constant(m.bias, 0) + + +#_, term_width = os.popen('stty size', 'r').read().split() +term_width = 80 #int(term_width) + +TOTAL_BAR_LENGTH = 65. +last_time = time.time() +begin_time = last_time +def progress_bar(current, total, msg=None): + global last_time, begin_time + if current == 0: + begin_time = time.time() # Reset for new bar. + + cur_len = int(TOTAL_BAR_LENGTH*current/total) + rest_len = int(TOTAL_BAR_LENGTH - cur_len) - 1 + + sys.stdout.write(' [') + for i in range(cur_len): + sys.stdout.write('=') + sys.stdout.write('>') + for i in range(rest_len): + sys.stdout.write('.') + sys.stdout.write(']') + + cur_time = time.time() + step_time = cur_time - last_time + last_time = cur_time + tot_time = cur_time - begin_time + + L = [] + L.append(' Step: %s' % format_time(step_time)) + L.append(' | Tot: %s' % format_time(tot_time)) + if msg: + L.append(' | ' + msg) + + msg = ''.join(L) + sys.stdout.write(msg) + for i in range(term_width-int(TOTAL_BAR_LENGTH)-len(msg)-3): + sys.stdout.write(' ') + + # Go back to the center of the bar. 
+ for i in range(term_width-int(TOTAL_BAR_LENGTH/2)+2): + sys.stdout.write('\b') + sys.stdout.write(' %d/%d ' % (current+1, total)) + + if current < total-1: + sys.stdout.write('\r') + else: + sys.stdout.write('\n') + sys.stdout.flush() + +def format_time(seconds): + days = int(seconds / 3600/24) + seconds = seconds - days*3600*24 + hours = int(seconds / 3600) + seconds = seconds - hours*3600 + minutes = int(seconds / 60) + seconds = seconds - minutes*60 + secondsf = int(seconds) + seconds = seconds - secondsf + millis = int(seconds*1000) + + f = '' + i = 1 + if days > 0: + f += str(days) + 'D' + i += 1 + if hours > 0 and i <= 2: + f += str(hours) + 'h' + i += 1 + if minutes > 0 and i <= 2: + f += str(minutes) + 'm' + i += 1 + if secondsf > 0 and i <= 2: + f += str(secondsf) + 's' + i += 1 + if millis > 0 and i <= 2: + f += str(millis) + 'ms' + i += 1 + if f == '': + f = '0ms' + return f From 55670bae4edd13ad14f24de3425dad3359a63ba6 Mon Sep 17 00:00:00 2001 From: vivianwhite Date: Sun, 21 Apr 2024 18:35:03 -0400 Subject: [PATCH 11/77] fix white spaces --- refactor/conv_modules.py | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/refactor/conv_modules.py b/refactor/conv_modules.py index d6b008c..1eb6d5f 100644 --- a/refactor/conv_modules.py +++ b/refactor/conv_modules.py @@ -11,9 +11,7 @@ def _contract(tensor, matrix, axis): """tensor is (..., D, ...), matrix is (P, D), returns (..., P, ...).""" t = torch.moveaxis(tensor, source=axis, destination=-1) # (..., D) - r = t @ matrix.T # (..., P) - return torch.moveaxis(r, source=-1, destination=axis) # (..., P, ...) 
@@ -62,7 +60,6 @@ def __init__( tri2_vec = torch.zeros((triu2_len,), **factory_kwargs) self.tri2_vec = Parameter(tri2_vec) - def forward(self, input: Tensor) -> Tensor: U1 = self._tri_vec_to_mat(self.tri1_vec, self.in_channels // self.groups) U2 = self._tri_vec_to_mat(self.tri2_vec, self.kernel_size[0] * self.kernel_size[1]) @@ -77,18 +74,14 @@ def forward(self, input: Tensor) -> Tensor: return self._conv_forward(input, composite_weight, self.bias) - def _tri_vec_to_mat(self, vec, n): U = torch.zeros((n, n), **self.factory_kwargs) U[torch.triu_indices(n, n, **self.factory_kwargs).tolist()] = vec return U - def _exp_diag(self, mat): exp_diag = torch.exp(torch.diagonal(mat)) - n = mat.shape[0] - mat[range(n), range(n)] = exp_diag return mat @@ -144,14 +137,12 @@ def __init__( tri2_vec = torch.zeros((triu2_len,), **factory_kwargs) self.tri2_vec = Parameter(tri2_vec) - def construct_Us(self): self.tri1_vec = Parameter(self._tri_vec_to_mat(self.tri1_vec, self.in_channels // self.groups,self.scat_idx1)) self.tri2_vec = Parameter(self._tri_vec_to_mat(self.tri2_vec, self.kernel_size[0] * self.kernel_size[1], self.scat_idx2)) - def forward(self, input: Tensor) -> Tensor: U1 = self._tri_vec_to_mat(self.tri1_vec, self.in_channels // self.groups, self.scat_idx1) @@ -165,11 +156,9 @@ def forward(self, input: Tensor) -> Tensor: ).reshape(self.weight.shape) return self._conv_forward(input, composite_weight, self.bias) - def _tri_vec_to_mat(self, vec, n, scat_idx): U = torch.zeros((n* n), **self.factory_kwargs).scatter_(0,scat_idx,vec).view(n,n) U = torch.diagonal_scatter(U,U.diagonal().exp_()) return U - From a779085803f75fde9635b962415d3f1519e6c721 Mon Sep 17 00:00:00 2001 From: vivianwhite Date: Sun, 21 Apr 2024 18:39:39 -0400 Subject: [PATCH 12/77] fixing V1 cifar10 model --- refactor/models/V1_models.py | 326 ++----------------------------- refactor/models/__init__.py | 11 ++ refactor/models/define_models.py | 6 - 3 files changed, 31 insertions(+), 312 deletions(-) delete 
mode 100644 refactor/models/define_models.py diff --git a/refactor/models/V1_models.py b/refactor/models/V1_models.py index c16d661..bc4bfaa 100644 --- a/refactor/models/V1_models.py +++ b/refactor/models/V1_models.py @@ -2,59 +2,26 @@ import torch import torch.nn.functional as F import sys -sys.path.insert(0, '/research/harris/vivian/structured_random_features/') +sys.path.insert(0, '/home/mila/v/vivian.white/structured-random-features') from src.models.init_weights import V1_init, classical_init, V1_weights import gc -import LearnableCov - -def train(model, device, train_loader, optimizer, epoch): - model.train() - for batch_idx, (data, target) in enumerate(train_loader): - data, target = data.to(device), target.to(device) - optimizer.zero_grad() - output = model(data) - loss = F.cross_entropy(output, target) - loss.backward() - optimizer.step() - - -def test(model, device, test_loader, epoch): - model.eval() - test_loss = 0 - correct = 0 - with torch.no_grad(): - for data, target in test_loader: - data, target = data.to(device), target.to(device) - output = model(data) - test_loss += F.cross_entropy(output, target, reduction='sum').item() # sum up batch loss - pred = output.max(1, keepdim=True)[1] # get the index of the max log-probability - correct += pred.eq(target.view_as(pred)).sum().item() - - test_loss /= len(test_loader.dataset) - accuracy = 100. 
* correct / len(test_loader.dataset) - - print('Test Epoch: {}\t Avg Loss: {:.4f}\t Accuracy: {:.2f}%'.format( - epoch, test_loss, accuracy)) - - return test_loss, accuracy class V1_CIFAR10(nn.Module): - def __init__(self, hidden_dim, size, spatial_freq, scale, bias, seed=None): + def __init__(self, hidden_dim, size, spatial_freq, scale, bias=False, seed=None): super().__init__() - # fixed feature layers + self.bn_x = nn.BatchNorm2d(3) self.v1_layer = nn.Conv2d(in_channels=3, out_channels=hidden_dim, kernel_size=7, stride=1, padding=3, bias=bias) + self.relu = nn.ReLU() + self.smooth = nn.AvgPool2d(kernel_size=3, stride=1, padding=1) + self.bn_h1 = nn.BatchNorm2d(hidden_dim + 3) self.v1_layer2 = nn.Conv2d(in_channels=hidden_dim + 3, out_channels=hidden_dim, kernel_size=7, stride=1, padding=3, bias=bias) - self.relu = nn.ReLU() - # unsupervised layers - self.bn_x = nn.BatchNorm2d(3) - self.bn_h1 = nn.BatchNorm2d(hidden_dim + 3) self.bn_h2 = nn.BatchNorm2d(hidden_dim * 2 + 3) - - # supervised layers + self.pool = nn.AvgPool2d(kernel_size=4, stride=4, padding=1) + self.flatten = nn.Flatten() self.clf = nn.Linear((8 ** 2) * (hidden_dim * 2 + 3), 10) scale1 = 1 / (3 * 7 * 7) @@ -72,147 +39,33 @@ def __init__(self, hidden_dim, size, spatial_freq, scale, bias, seed=None): self.v1_layer2.bias.requires_grad = False def forward(self, x): - # methods - # smooth = nn.AvgPool2d(kernel_size=5, stride=1, padding=2) - smooth = nn.AvgPool2d(kernel_size=3, stride=1, padding=1) - pool = nn.AvgPool2d(kernel_size=4, stride=4, padding=1) - flatten = nn.Flatten() - x = self.bn_x(x) - h = torch.cat((self.relu(self.v1_layer(x)), smooth(x)), 1) + h = torch.cat((self.relu(self.v1_layer(x)), self.smooth(x)), 1) h = self.bn_h1(h) - h = torch.cat((self.relu(self.v1_layer2(h)), smooth(h)), 1) + h = torch.cat((self.relu(self.v1_layer2(h)), self.smooth(h)), 1) h = self.bn_h2(h) - h = flatten(pool(h)) + h = self.flatten(self.pool(h)) return self.clf(h) -class Rand_Scat_Block(nn.Module): - def 
__init__(self, in_chan, num_filt, size, spatial_freq, - kernel_size=7, stride=1, padding=3, scale=None, bias=True, seed=None): - super().__init__() - - out_chan = in_chan + num_filt - - self.v1 = nn.Conv2d(in_channels=in_chan, out_channels=num_filt, - kernel_size=kernel_size, stride=stride, padding=padding, - scale=scale, bias=bias) - self.bn = nn.BatchNorm2d(num_filt) - self.relu = nn.ReLU() - - # V1 params - if scale is None: - scale = 1 / (in_chan * np.prod(kernel_size)) - center = ((kernel_size - 1) / 2, (kernel_size - 1) / 2) - - # init weights - V1_init(self.v1, size, spatial_freq, center, scale, bias, seed) - self.v1.weight.requires_grad = False - if bias: - self.v1.bias.requires_grad = False - - def forward(self, x): - smooth = nn.AvgPool2d(kernel_size=5, stride=1, padding=2) - h = self.relu(self.v1(x)) - concat = torch.cat((h, smooth(x)), 1) # concatenate with smoothed input - return self.bn(concat) - -class Rand_Scat_CIFAR10(nn.Module): - pass - -class Learned_Rand_Scat_CIFAR10(nn.Module): - def __init__(self, num_filt, size, spatial_freq, scale, bias, seed=None): - super().__init__() - - # channel dimensions - dims = [3, 64, 128] - - # fixed feature layers - self.v1_layer = nn.Conv2d(in_channels=dims[0], out_channels=num_filt, - kernel_size=7, stride=1, padding=3, bias=bias) - self.v1_layer2 = nn.Conv2d(in_channels=dims[1], out_channels=num_filt, - kernel_size=7, stride=1, padding=3, bias=bias) - self.relu = nn.ReLU() - - # unsupervised layers - self.bn_x = nn.BatchNorm2d(3) - self.bn_h1 = nn.LazyBatchNorm2d() - self.bn_h2 = nn.LazyBatchNorm2d() - - # supervised layers - self.L1 = nn.LazyConv2d(out_channels=dims[1], kernel_size=1, bias=False) - self.L2 = nn.LazyConv2d(out_channels=dims[2], kernel_size=1, bias=False) - self.clf = nn.LazyLinear(10) - - # init fixed weights - scale1 = 1 / (3 * 7 * 7) - scale2 = 1 / (num_filt * 7 * 7) - center = (3., 3.) 
- - V1_init(self.v1_layer, size, spatial_freq, center, scale1, bias, seed) - self.v1_layer.weight.requires_grad = False - - V1_init(self.v1_layer2, size, spatial_freq, center, scale2, bias, seed) - self.v1_layer2.weight.requires_grad = False - - if bias==True: - self.v1_layer.bias.requires_grad = False - self.v1_layer2.bias.requires_grad = False - -# OLD -# def forward(self, x): -# h1 = self.relu(self.v1_layer(self.bn_x(x))) -# h2 = self.relu(self.v1_layer2(self.bn_h1(h1))) - -# pool = nn.AvgPool2d(kernel_size=4, stride=4, padding=1) -# x_pool = self.bn0(pool(x)) -# h1_pool = self.bn1(pool(h1)) -# h2_pool = self.bn2(pool(h2)) - -# x_flat = x_pool.view(x_pool.size(0), -1) #view -# h1_flat = h1_pool.view(h1_pool.size(0), -1) #view -# h2_flat = h2_pool.view(h2_pool.size(0), -1) #view - - -# concat = torch.cat((x_flat, h1_flat, h2_flat), 1) - - def forward(self, x): - # methods - smooth = nn.AvgPool2d(kernel_size=5, stride=1, padding=2) - pool = nn.AvgPool2d(kernel_size=4, stride=4, padding=1) - flatten = nn.Flatten() - - # do pass - x = self.bn_x(x) - h1 = self.relu(self.v1_layer(x)) - h1 = torch.cat((h1, smooth(x)), 1) - h1 = self.bn_h1(h1) # chan: 3 + num_filt - h2 = self.relu(self.v1_layer2(self.L1(h1))) - h2 = torch.cat((h2, smooth(h1)), 1) # chan: 3 + num_filt + dims[1] - h2 = self.bn_h2(h2) - - concat = flatten(pool(h2)) - - beta = self.clf(concat) - - return beta - class V1_CIFAR100(nn.Module): def __init__(self, hidden_dim, size, spatial_freq, scale, bias, seed=None): super(V1_CIFAR100, self).__init__() + + self.bn_x = nn.BatchNorm2d(3) self.v1_layer = nn.Conv2d(in_channels=3, out_channels=hidden_dim, kernel_size=7, stride=1, padding=3, bias=bias) + self.relu = nn.ReLU() self.v1_layer2 = nn.Conv2d(in_channels=hidden_dim, out_channels=hidden_dim, kernel_size=7, stride=1, padding=3, bias=bias) - self.clf = nn.Linear((3 * (8 ** 2)) + (hidden_dim * (8 ** 2)) + (hidden_dim * (8 ** 2)), 100) - self.relu = nn.ReLU() - self.bn = nn.BatchNorm2d(3) self.bn0 = 
nn.BatchNorm2d(3) self.bn1 = nn.BatchNorm2d(hidden_dim) self.bn2 = nn.BatchNorm2d(hidden_dim) self.bn_h1 = nn.BatchNorm2d(hidden_dim) + self.pool = nn.AvgPool2d(kernel_size=4, stride=4, padding=1) + self.clf = nn.Linear((3 * (8 ** 2)) + (hidden_dim * (8 ** 2)) + (hidden_dim * (8 ** 2)), 100) scale1 = 1 / (3 * 7 * 7) scale2 = 1 / (hidden_dim * 7 * 7) center = None @@ -228,13 +81,12 @@ def __init__(self, hidden_dim, size, spatial_freq, scale, bias, seed=None): self.v1_layer2.bias.requires_grad = False def forward(self, x): #[128, 3, 32, 32] - h1 = self.relu(self.v1_layer(self.bn(x))) + h1 = self.relu(self.v1_layer(self.bn_x(x))) h2 = self.relu(self.v1_layer2(self.bn_h1(h1))) - pool = nn.AvgPool2d(kernel_size=4, stride=4, padding=1) - x_pool = self.bn0(pool(x)) - h1_pool = self.bn1(pool(h1)) - h2_pool = self.bn2(pool(h2)) + x_pool = self.bn0(self.pool(x)) + h1_pool = self.bn1(self.pool(h1)) + h2_pool = self.bn2(self.pool(h2)) x_flat = x_pool.view(x_pool.size(0), -1) h1_flat = h1_pool.view(h1_pool.size(0), -1) @@ -245,141 +97,3 @@ def forward(self, x): #[128, 3, 32, 32] beta = self.clf(concat) return beta -class V1_MNIST(nn.Module): - def __init__(self, hidden_dim, size, spatial_freq, scale, bias, seed=None): - super(V1_MNIST, self).__init__() - self.v1_layer = nn.Conv2d(in_channels=1, out_channels=hidden_dim, kernel_size=7, stride=1, padding=3, - bias=bias) - self.v1_layer2 = nn.Conv2d(in_channels=hidden_dim, out_channels=hidden_dim, kernel_size=7, stride=1, padding=3, - bias=bias) - self.clf = nn.Linear((1 * (8 ** 2)) + (hidden_dim * (8 ** 2)) + (hidden_dim * (8 ** 2)), 100) - self.relu = nn.ReLU() - self.bn = nn.BatchNorm2d(1) - self.bn0 = nn.BatchNorm2d(1) - self.bn1 = nn.BatchNorm2d(hidden_dim) - self.bn2 = nn.BatchNorm2d(hidden_dim) - - scale1 = 1 / (3 * 7 * 7) - scale2 = 1 / (hidden_dim * 7 * 7) - center = None - - V1_init(self.v1_layer, size, spatial_freq, center, scale1, bias, seed) - self.v1_layer.weight.requires_grad = False - - V1_init(self.v1_layer2, size, 
spatial_freq, center, scale2, bias, seed) - self.v1_layer2.weight.requires_grad = False - - if bias==True: - self.v1_layer.bias.requires_grad = False - self.v1_layer2.bias.requires_grad = False - - def forward(self, x): #[128, 1, 28, 28] - h1 = self.relu(self.v1_layer(self.bn(x))) - h2 = self.relu(self.v1_layer2(h1)) - - pool = nn.AvgPool2d(kernel_size=4, stride=4, padding=2) - x_pool = self.bn0(pool(x)) - h1_pool = self.bn1(pool(h1)) - h2_pool = self.bn2(pool(h2)) - - x_flat = x_pool.view(x_pool.size(0), -1) - h1_flat = h1_pool.view(h1_pool.size(0), -1) - h2_flat = h2_pool.view(h2_pool.size(0), -1) - - concat = torch.cat((x_flat, h1_flat, h2_flat), 1) - - beta = self.clf(concat) - return beta - -class Scattering_V1_MNIST(nn.Module): - def __init__(self, hidden_dim, size, spatial_freq, scale, bias, seed=None): - super(Scattering_V1_MNIST, self).__init__() - self.v1_layer = nn.Conv2d(in_channels=1, out_channels=hidden_dim, kernel_size=7, stride=1, padding=3, - bias=bias) - self.v1_layer2 = nn.Conv2d(in_channels=hidden_dim, out_channels=hidden_dim, kernel_size=7, stride=1, padding=3, - bias=bias) - self.relu = nn.ReLU() - self.bn = nn.BatchNorm2d(1) - self.bn0 = nn.BatchNorm2d(1) - self.bn1 = nn.BatchNorm2d(hidden_dim) - self.bn2 = nn.BatchNorm2d(hidden_dim) - - scale1 = 1 / (3 * 7 * 7) - scale2 = 1 / (hidden_dim * 7 * 7) - center = None - - V1_init(self.v1_layer, size, spatial_freq, center, scale1, bias, seed) - self.v1_layer.weight.requires_grad = False - - V1_init(self.v1_layer2, size, spatial_freq, center, scale2, bias, seed) - self.v1_layer2.weight.requires_grad = False - - if bias==True: - self.v1_layer.bias.requires_grad = False - self.v1_layer2.bias.requires_grad = False - - def forward(self, x): #[128, 1, 28, 28] - h1 = self.relu(self.v1_layer(self.bn(x))) - h2 = self.relu(self.v1_layer2(h1)) - - pool = nn.AvgPool2d(kernel_size=4, stride=4, padding=2) - x_pool = self.bn0(pool(x)) - h1_pool = self.bn1(pool(h1)) - h2_pool = self.bn2(pool(h2)) - - x_flat = 
x_pool.view(x_pool.size(0), -1) - h1_flat = h1_pool.view(h1_pool.size(0), -1) - h2_flat = h2_pool.view(h2_pool.size(0), -1) - - concat = torch.cat((x_flat, h1_flat, h2_flat), 1) - return concat - -class Scattering_V1_celeba(nn.Module): - def __init__(self, hidden_dim, size, spatial_freq, scale, bias, seed=None): - super().__init__() - - # fixed feature layers - self.v1_layer = nn.Conv2d(in_channels=3, out_channels=hidden_dim, - kernel_size=7, stride=1, padding=3, bias=bias) - self.v1_layer2 = nn.Conv2d(in_channels=hidden_dim + 3, out_channels=hidden_dim, - kernel_size=7, stride=1, padding=3, bias=bias) - self.relu = nn.ReLU() - - # unsupervised layers - - self.bn_x = nn.BatchNorm2d(3) - self.bn_h1 = nn.BatchNorm2d(hidden_dim + 3) - self.bn_h2 = nn.BatchNorm2d(hidden_dim * 2 + 3) - - - # supervised layers - - scale1 = 1 / (3 * 7 * 7) - scale2 = 1 / (hidden_dim * 7 * 7) - center = (3., 3.) - - V1_init(self.v1_layer, size, spatial_freq, center, scale1, bias, seed) - self.v1_layer.weight.requires_grad = False - - V1_init(self.v1_layer2, size, spatial_freq, center, scale2, bias, seed) - self.v1_layer2.weight.requires_grad = False - - if bias: - self.v1_layer.bias.requires_grad = False - self.v1_layer2.bias.requires_grad = False - - def forward(self, x): - # methods - # smooth = nn.AvgPool2d(kernel_size=5, stride=1, padding=2) - smooth = nn.AvgPool2d(kernel_size=3, stride=1, padding=1) - pool = nn.AvgPool2d(kernel_size=4, stride=4, padding=1) - flatten = nn.Flatten() - - x = self.bn_x(x) - h = torch.cat((self.relu(self.v1_layer(x)), smooth(x)), 1) - h = self.bn_h1(h) - h = torch.cat((self.relu(self.v1_layer2(h)), smooth(h)), 1) - h = self.bn_h2(h) - h = flatten(pool(h)) - return h - diff --git a/refactor/models/__init__.py b/refactor/models/__init__.py index 9c7178b..ab9aef3 100644 --- a/refactor/models/__init__.py +++ b/refactor/models/__init__.py @@ -1 +1,12 @@ from .resnet import ResNet18 +from .V1_models import V1_CIFAR10, V1_CIFAR100 + +def define_models(args): + if 
args.net == 'resnet18': + model = ResNet18() + elif args.net == 'rsn_cifar10': + model = V1_CIFAR10(hidden_dim=100, size=args.s, + spatial_freq=args.f, scale=args.scale) + elif args.net == 'rsn_cifar100': + model = V1_CIFAR100() + return model diff --git a/refactor/models/define_models.py b/refactor/models/define_models.py deleted file mode 100644 index 904af30..0000000 --- a/refactor/models/define_models.py +++ /dev/null @@ -1,6 +0,0 @@ -from .resnet import ResNet18 - -def define_models(args): - if args.net == 'resnet18': - model = ResNet18() - return model From b8732541f64bcce6db896a97a8fdcc7ed99e572d Mon Sep 17 00:00:00 2001 From: vivianwhite Date: Mon, 22 Apr 2024 01:03:16 -0400 Subject: [PATCH 13/77] add replace_linears function --- refactor/models/V1_models.py | 22 +++++++++++++++++++++- refactor/models/__init__.py | 6 ++++-- 2 files changed, 25 insertions(+), 3 deletions(-) diff --git a/refactor/models/V1_models.py b/refactor/models/V1_models.py index bc4bfaa..dbae415 100644 --- a/refactor/models/V1_models.py +++ b/refactor/models/V1_models.py @@ -6,6 +6,26 @@ from src.models.init_weights import V1_init, classical_init, V1_weights import gc +def replace_linear_layer(model): + for n, module in model.named_children(): + if len(list(module.children())) > 0: + ## compound module, go inside it + replace_layers_keep_weight(module) + if isinstance(module, nn.Linear): + ## simple module + new_module = nn.Linear( + in_features=module.in_features, + out_features=100, + bias=True if module.bias is not None else False) + old_sd = module.state_dict() + new_sd = new_module.state_dict() + #new_sd['weight'] = old_sd['weight'] + #if module.bias is not None: + # new_sd['bias'] = old_sd['bias'] + new_module.load_state_dict(new_sd) + setattr(model, n, new_module) + + class V1_CIFAR10(nn.Module): def __init__(self, hidden_dim, size, spatial_freq, scale, bias=False, seed=None): super().__init__() @@ -50,7 +70,7 @@ def forward(self, x): class V1_CIFAR100(nn.Module): - def 
__init__(self, hidden_dim, size, spatial_freq, scale, bias, seed=None): + def __init__(self, hidden_dim, size, spatial_freq, scale, bias=False, seed=None): super(V1_CIFAR100, self).__init__() self.bn_x = nn.BatchNorm2d(3) diff --git a/refactor/models/__init__.py b/refactor/models/__init__.py index ab9aef3..7e50696 100644 --- a/refactor/models/__init__.py +++ b/refactor/models/__init__.py @@ -1,5 +1,5 @@ from .resnet import ResNet18 -from .V1_models import V1_CIFAR10, V1_CIFAR100 +from .V1_models import V1_CIFAR10, V1_CIFAR100, replace_linear_layer def define_models(args): if args.net == 'resnet18': @@ -8,5 +8,7 @@ def define_models(args): model = V1_CIFAR10(hidden_dim=100, size=args.s, spatial_freq=args.f, scale=args.scale) elif args.net == 'rsn_cifar100': - model = V1_CIFAR100() + model = V1_CIFAR10(hidden_dim=100, size=args.s, + spatial_freq=args.f, scale=args.scale) + replace_linear_layer(model) return model From 204bf3af8de9b2870a4df32e4023ebe1ccc100d5 Mon Sep 17 00:00:00 2001 From: vivianwhite Date: Mon, 22 Apr 2024 01:19:45 -0400 Subject: [PATCH 14/77] adding recursive functions script --- refactor/function_utils.py | 76 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 76 insertions(+) create mode 100644 refactor/function_utils.py diff --git a/refactor/function_utils.py b/refactor/function_utils.py new file mode 100644 index 0000000..f35ae5e --- /dev/null +++ b/refactor/function_utils.py @@ -0,0 +1,76 @@ +import torch +import torch.nn as nn +from conv_modules import FactConv2dPreExp + +def replace_layers_keep_weight(model): + ''' + Replace nn.Conv2d layers with FactConv2d + ''' + for n, module in model.named_children(): + if len(list(module.children())) > 0: + ## compound module, go inside it + replace_layers_keep_weight(module) + if isinstance(module, nn.Conv2d): + ## simple module + new_module = FactConv2dPreExp( + in_channels=module.in_channels, + out_channels=module.out_channels, + kernel_size=module.kernel_size, + stride=module.stride, 
padding=module.padding, + bias=True if module.bias is not None else False) + old_sd = module.state_dict() + new_sd = new_module.state_dict() + new_sd['weight'] = old_sd['weight'] + if module.bias is not None: + new_sd['bias'] = old_sd['bias'] + new_module.load_state_dict(new_sd) + #new_module.tri1_vec = nn.Parameter(int(new_module.tri1_vec * scale)) + setattr(model, n, new_module) + +def replace_affines(model): + ''' + Set BatchNorm2d layers to have 'affine=False' + ''' + for n, module in model.named_children(): + if len(list(module.children())) > 0: + ## compound module, go inside it + replace_affines(module) + if isinstance(module, nn.BatchNorm2d): + ## simple module + new_module = nn.BatchNorm2d( + num_features=module.num_features, + eps=module.eps, momentum=module.momentum, + affine=False, + track_running_stats=module.track_running_stats) + setattr(model, n, new_module) + +def replace_layers_agnostic(model, scale=1): + ''' + Replace nn.Conv2d layers with a different scale + ''' + prev_out_ch = 0 + for n, module in model.named_children(): + if len(list(module.children())) > 0: + ## compound module, go inside it + replace_layers_agnostic(module,scale) + if isinstance(module, nn.Conv2d): + if module.in_channels == 3: + in_scale = 1 + else: + in_scale = scale + ## simple module + new_module = nn.Conv2d( + in_channels=int(module.in_channels*in_scale), + out_channels=int(module.out_channels*scale), + kernel_size=module.kernel_size, + stride=module.stride, padding=module.padding, + groups = module.groups, + bias=True if module.bias is not None else False) + setattr(model, n, new_module) + prev_out_ch = new_module.out_channels + if isinstance(module, nn.BatchNorm2d): + new_module = nn.BatchNorm2d(prev_out_ch) + setattr(model, n, new_module) + if isinstance(module, nn.Linear): + new_module = nn.Linear(int(512 * scale), 10) + setattr(model, n, new_module) From 2aeaf983caef676bcd217dc84f6604b7c947a44b Mon Sep 17 00:00:00 2001 From: vivianwhite Date: Tue, 23 Apr 2024 
01:28:53 -0400 Subject: [PATCH 15/77] some updates, more to come tomorrow --- refactor/conv_modules.py | 73 +-------------- refactor/models/V1_models.py | 119 ------------------------ refactor/models/__init__.py | 26 +++--- refactor/{ => models}/function_utils.py | 14 +-- refactor/models/resnet.py | 1 - refactor/og_pytorch_cifar.py | 67 ++++--------- refactor/pytorch_cifar_utils.py | 33 ------- 7 files changed, 43 insertions(+), 290 deletions(-) delete mode 100644 refactor/models/V1_models.py rename refactor/{ => models}/function_utils.py (89%) diff --git a/refactor/conv_modules.py b/refactor/conv_modules.py index 1eb6d5f..0dd0f49 100644 --- a/refactor/conv_modules.py +++ b/refactor/conv_modules.py @@ -15,78 +15,7 @@ def _contract(tensor, matrix, axis): return torch.moveaxis(r, source=-1, destination=axis) # (..., P, ...) -class FactConv2dPostExp(nn.Conv2d): - def __init__( - self, - in_channels: int, - out_channels: int, - kernel_size: _size_2_t, - stride: _size_2_t = 1, - padding: Union[str, _size_2_t] = 0, - dilation: _size_2_t = 1, - groups: int = 1, - bias: bool = True, - padding_mode: str = 'zeros', # TODO: refine this type - device=torch.device('cuda' if torch.cuda.is_available() else 'cpu'), - dtype=None - ) -> None: - # init as Conv2d - super().__init__( - in_channels, out_channels, kernel_size, stride, padding, dilation, - groups, bias, padding_mode, device, dtype) - - factory_kwargs = {'device': device, 'dtype': dtype} - self.factory_kwargs = factory_kwargs - - # weight shape: (out_channels, in_channels // groups, *kernel_size) - weight_shape = self.weight.shape - del self.weight # remove Parameter, create buffer - self.register_buffer("weight", torch.empty(weight_shape, **factory_kwargs)) - nn.init.kaiming_normal_(self.weight) - - self.in_features = self.in_channels // self.groups * \ - self.kernel_size[0] * self.kernel_size[1] - triu1 = torch.triu_indices(self.in_channels // self.groups, - self.in_channels // self.groups) - triu2 = 
torch.triu_indices(self.kernel_size[0] * self.kernel_size[1], - self.kernel_size[0] - * self.kernel_size[1]) - triu1_len = triu1.shape[1] - triu2_len = triu2.shape[1] - tri1_vec = torch.zeros((triu1_len,), - **factory_kwargs) - self.tri1_vec = Parameter(tri1_vec) - - tri2_vec = torch.zeros((triu2_len,), **factory_kwargs) - self.tri2_vec = Parameter(tri2_vec) - - def forward(self, input: Tensor) -> Tensor: - U1 = self._tri_vec_to_mat(self.tri1_vec, self.in_channels // self.groups) - U2 = self._tri_vec_to_mat(self.tri2_vec, self.kernel_size[0] * self.kernel_size[1]) - U = torch.kron(U1, U2) - U = self._exp_diag(U) - - matrix_shape = (self.out_channels, self.in_features) - composite_weight = torch.reshape( - torch.reshape(self.weight, matrix_shape) @ U, - self.weight.shape - ) - - return self._conv_forward(input, composite_weight, self.bias) - - def _tri_vec_to_mat(self, vec, n): - U = torch.zeros((n, n), **self.factory_kwargs) - U[torch.triu_indices(n, n, **self.factory_kwargs).tolist()] = vec - return U - - def _exp_diag(self, mat): - exp_diag = torch.exp(torch.diagonal(mat)) - n = mat.shape[0] - mat[range(n), range(n)] = exp_diag - return mat - - -class FactConv2dPreExp(nn.Conv2d): +class FactConv2d(nn.Conv2d): def __init__( self, in_channels: int, diff --git a/refactor/models/V1_models.py b/refactor/models/V1_models.py deleted file mode 100644 index dbae415..0000000 --- a/refactor/models/V1_models.py +++ /dev/null @@ -1,119 +0,0 @@ -import torch.nn as nn -import torch -import torch.nn.functional as F -import sys -sys.path.insert(0, '/home/mila/v/vivian.white/structured-random-features') -from src.models.init_weights import V1_init, classical_init, V1_weights -import gc - -def replace_linear_layer(model): - for n, module in model.named_children(): - if len(list(module.children())) > 0: - ## compound module, go inside it - replace_layers_keep_weight(module) - if isinstance(module, nn.Linear): - ## simple module - new_module = nn.Linear( - 
in_features=module.in_features, - out_features=100, - bias=True if module.bias is not None else False) - old_sd = module.state_dict() - new_sd = new_module.state_dict() - #new_sd['weight'] = old_sd['weight'] - #if module.bias is not None: - # new_sd['bias'] = old_sd['bias'] - new_module.load_state_dict(new_sd) - setattr(model, n, new_module) - - -class V1_CIFAR10(nn.Module): - def __init__(self, hidden_dim, size, spatial_freq, scale, bias=False, seed=None): - super().__init__() - - self.bn_x = nn.BatchNorm2d(3) - self.v1_layer = nn.Conv2d(in_channels=3, out_channels=hidden_dim, - kernel_size=7, stride=1, padding=3, bias=bias) - self.relu = nn.ReLU() - self.smooth = nn.AvgPool2d(kernel_size=3, stride=1, padding=1) - self.bn_h1 = nn.BatchNorm2d(hidden_dim + 3) - self.v1_layer2 = nn.Conv2d(in_channels=hidden_dim + 3, out_channels=hidden_dim, - kernel_size=7, stride=1, padding=3, bias=bias) - - self.bn_h2 = nn.BatchNorm2d(hidden_dim * 2 + 3) - self.pool = nn.AvgPool2d(kernel_size=4, stride=4, padding=1) - self.flatten = nn.Flatten() - self.clf = nn.Linear((8 ** 2) * (hidden_dim * 2 + 3), 10) - - scale1 = 1 / (3 * 7 * 7) - scale2 = 1 / (hidden_dim * 7 * 7) - center = (3., 3.) 
- - V1_init(self.v1_layer, size, spatial_freq, center, scale1, bias, seed) - self.v1_layer.weight.requires_grad = False - - V1_init(self.v1_layer2, size, spatial_freq, center, scale2, bias, seed) - self.v1_layer2.weight.requires_grad = False - - if bias: - self.v1_layer.bias.requires_grad = False - self.v1_layer2.bias.requires_grad = False - - def forward(self, x): - x = self.bn_x(x) - h = torch.cat((self.relu(self.v1_layer(x)), self.smooth(x)), 1) - h = self.bn_h1(h) - h = torch.cat((self.relu(self.v1_layer2(h)), self.smooth(h)), 1) - h = self.bn_h2(h) - h = self.flatten(self.pool(h)) - return self.clf(h) - - - -class V1_CIFAR100(nn.Module): - def __init__(self, hidden_dim, size, spatial_freq, scale, bias=False, seed=None): - super(V1_CIFAR100, self).__init__() - - self.bn_x = nn.BatchNorm2d(3) - self.v1_layer = nn.Conv2d(in_channels=3, out_channels=hidden_dim, kernel_size=7, stride=1, padding=3, - bias=bias) - self.relu = nn.ReLU() - self.v1_layer2 = nn.Conv2d(in_channels=hidden_dim, out_channels=hidden_dim, kernel_size=7, stride=1, padding=3, - bias=bias) - self.bn0 = nn.BatchNorm2d(3) - self.bn1 = nn.BatchNorm2d(hidden_dim) - self.bn2 = nn.BatchNorm2d(hidden_dim) - self.bn_h1 = nn.BatchNorm2d(hidden_dim) - - self.pool = nn.AvgPool2d(kernel_size=4, stride=4, padding=1) - self.clf = nn.Linear((3 * (8 ** 2)) + (hidden_dim * (8 ** 2)) + (hidden_dim * (8 ** 2)), 100) - scale1 = 1 / (3 * 7 * 7) - scale2 = 1 / (hidden_dim * 7 * 7) - center = None - - V1_init(self.v1_layer, size, spatial_freq, center, scale1, bias, seed) - self.v1_layer.weight.requires_grad = False - - V1_init(self.v1_layer2, size, spatial_freq, center, scale2, bias, seed) - self.v1_layer2.weight.requires_grad = False - - if bias==True: - self.v1_layer.bias.requires_grad = False - self.v1_layer2.bias.requires_grad = False - - def forward(self, x): #[128, 3, 32, 32] - h1 = self.relu(self.v1_layer(self.bn_x(x))) - h2 = self.relu(self.v1_layer2(self.bn_h1(h1))) - - x_pool = self.bn0(self.pool(x)) - 
h1_pool = self.bn1(self.pool(h1)) - h2_pool = self.bn2(self.pool(h2)) - - x_flat = x_pool.view(x_pool.size(0), -1) - h1_flat = h1_pool.view(h1_pool.size(0), -1) - h2_flat = h2_pool.view(h2_pool.size(0), -1) - - concat = torch.cat((x_flat, h1_flat, h2_flat), 1) - - beta = self.clf(concat) - return beta - diff --git a/refactor/models/__init__.py b/refactor/models/__init__.py index 7e50696..d2cc284 100644 --- a/refactor/models/__init__.py +++ b/refactor/models/__init__.py @@ -1,14 +1,18 @@ from .resnet import ResNet18 -from .V1_models import V1_CIFAR10, V1_CIFAR100, replace_linear_layer - +from .function_utils import replace_layers_keep_weight, turn_off_grada, replace_layers_scale def define_models(args): - if args.net == 'resnet18': - model = ResNet18() - elif args.net == 'rsn_cifar10': - model = V1_CIFAR10(hidden_dim=100, size=args.s, - spatial_freq=args.f, scale=args.scale) - elif args.net == 'rsn_cifar100': - model = V1_CIFAR10(hidden_dim=100, size=args.s, - spatial_freq=args.f, scale=args.scale) - replace_linear_layer(model) + if 'resnet18' in args.net: + model = ResNet18() + if 'fact' in args.net: + replace_layers_keep_weight(model) + if "v1" in args.net: + # TODO: import V1_init function from structured-random-features + V1_init(model) + if "us" in args.net: + # TODO: make turn_off_grad function + turn_off_grad(model, "spatial") + if "uc" in args.net: + turn_off_grad(model, "channel") + if args.width != 1: + replace_layers_scale(model, args.width) return model diff --git a/refactor/function_utils.py b/refactor/models/function_utils.py similarity index 89% rename from refactor/function_utils.py rename to refactor/models/function_utils.py index f35ae5e..ac49a26 100644 --- a/refactor/function_utils.py +++ b/refactor/models/function_utils.py @@ -1,6 +1,6 @@ import torch import torch.nn as nn -from conv_modules import FactConv2dPreExp +from conv_modules import FactConv2d def replace_layers_keep_weight(model): ''' @@ -12,7 +12,7 @@ def 
replace_layers_keep_weight(model): replace_layers_keep_weight(module) if isinstance(module, nn.Conv2d): ## simple module - new_module = FactConv2dPreExp( + new_module = FactConv2d( in_channels=module.in_channels, out_channels=module.out_channels, kernel_size=module.kernel_size, @@ -24,7 +24,6 @@ def replace_layers_keep_weight(model): if module.bias is not None: new_sd['bias'] = old_sd['bias'] new_module.load_state_dict(new_sd) - #new_module.tri1_vec = nn.Parameter(int(new_module.tri1_vec * scale)) setattr(model, n, new_module) def replace_affines(model): @@ -44,7 +43,7 @@ def replace_affines(model): track_running_stats=module.track_running_stats) setattr(model, n, new_module) -def replace_layers_agnostic(model, scale=1): +def replace_layers_scale(model, scale=1): ''' Replace nn.Conv2d layers with a different scale ''' @@ -52,7 +51,7 @@ def replace_layers_agnostic(model, scale=1): for n, module in model.named_children(): if len(list(module.children())) > 0: ## compound module, go inside it - replace_layers_agnostic(module,scale) + replace_layers_scale(module,scale) if isinstance(module, nn.Conv2d): if module.in_channels == 3: in_scale = 1 @@ -72,5 +71,8 @@ def replace_layers_agnostic(model, scale=1): new_module = nn.BatchNorm2d(prev_out_ch) setattr(model, n, new_module) if isinstance(module, nn.Linear): - new_module = nn.Linear(int(512 * scale), 10) + new_module = nn.Linear(int(module.in_features * scale), 10) setattr(model, n, new_module) + +def turn_off_grad(model, covariance): + #TODO: turn off gradients diff --git a/refactor/models/resnet.py b/refactor/models/resnet.py index 9177d5f..4ed155e 100644 --- a/refactor/models/resnet.py +++ b/refactor/models/resnet.py @@ -111,7 +111,6 @@ def ResNet18(): def ResNet18_Class100(): return ResNet(BasicBlock, [2, 2, 2, 2], 100) -# TODO: use a recursive function to replace all linears with final dim 100 def ResNet34(): return ResNet(BasicBlock, [3, 4, 6, 3]) diff --git a/refactor/og_pytorch_cifar.py 
b/refactor/og_pytorch_cifar.py index c75e23f..a383d0f 100644 --- a/refactor/og_pytorch_cifar.py +++ b/refactor/og_pytorch_cifar.py @@ -4,26 +4,14 @@ import torch.optim as optim import torch.nn.functional as F import torch.backends.cudnn as cudnn -from torch.profiler import profile, record_function, ProfilerActivity - import torchvision import torchvision.transforms as transforms - import os import argparse - from pytorch_cifar_utils import progress_bar, set_seeds - -from hooks import wandb_forwards_hook, wandb_backwards_hook - import wandb - from distutils.util import strtobool - -from conv_modules import FactConv2dPreExp - -# TODO: import define_models function -from models.define_models import define_models +from models import define_models def save_model(args, model): src= "/home/mila/v/vivian.white/scratch/v1-models/saved-models/test_refactor/" @@ -44,8 +32,12 @@ def save_model(args, model): parser.add_argument('--lr', default=0.1, type=float, help='learning rate') parser.add_argument('--resume', '-r', action='store_true', help='resume from checkpoint') -parser.add_argument('--net', type=str, default='resnet18', choices=['vgg', 'vggbn', - 'resnet', 'factnetv1', 'factnetdefault', 'vggfact', 'vggbnfact'], help="which convmodule to use") +parser.add_argument('--net', type=str, default='resnet18', choices=['resnet18', + 'rsn_cifar10', 'rsn_cifar100'], + help="which model to use") +parser.add_argument('--num_epochs', type=int, default=90, help='number of trainepochs') +parser.add_argument('--hidden_dim', type=int, default=100, + help='number of hidden dimensions in model') parser.add_argument('--freeze_spatial', dest='freeze_spatial', type=lambda x: bool(strtobool(x)), default=True, help="freeze spatial filters for LearnableCov models") @@ -62,6 +54,7 @@ def save_model(args, model): parser.add_argument('--bias', dest='bias', type=lambda x: bool(strtobool(x)), default=False, help='bias=True or False') parser.add_argument('--seed', default=0, type=int, help='seed to 
use') +parser.add_argument('--width', type=float, default=1, help='resnet width scale factor') args = parser.parse_args() @@ -75,12 +68,17 @@ def save_model(args, model): transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip(), transforms.ToTensor(), - transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), + #transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), + # normalization slightly different from old training setup + transforms.Normalize(mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225]) ]) transform_test = transforms.Compose([ transforms.ToTensor(), - transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), + #transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), + transforms.Normalize(mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225]) ]) trainset = torchvision.datasets.CIFAR10( @@ -99,44 +97,18 @@ def save_model(args, model): # Model print('==> Building model..') - -def replace_layers_keep_weight(model): - for n, module in model.named_children(): - if len(list(module.children())) > 0: - ## compound module, go inside it - replace_layers_keep_weight(module) - if isinstance(module, nn.Conv2d): - ## simple module - new_module = FactConv2dPreExp( - in_channels=module.in_channels, - out_channels=module.out_channels, - kernel_size=module.kernel_size, - stride=module.stride, padding=module.padding, - bias=True if module.bias is not None else False) - old_sd = module.state_dict() - new_sd = new_module.state_dict() - new_sd['weight'] = old_sd['weight'] - if module.bias is not None: - new_sd['bias'] = old_sd['bias'] - new_module.load_state_dict(new_sd) - #new_module.tri1_vec = nn.Parameter(int(new_module.tri1_vec * scale)) - setattr(model, n, new_module) - -set_seeds(args.seed) net = define_models(args) -replace_layers_keep_weight(net) run_name = args.net -print("Model Built!") +print("Model Built! 
", net) set_seeds(args.seed) net = net.to(device) wandb_dir = "/home/mila/v/vivian.white/scratch/v1-models/wandb" os.makedirs(wandb_dir, exist_ok=True) os.chdir(wandb_dir) -run_name = "OGVGG" -run = wandb.init(project="random_project", config=args, - group="pytorch_cifar_better_tracked_og", name=run_name, dir=wandb_dir) +run = wandb.init(project="refactoring", config=args, + group="pytorch_cifar", name=run_name, dir=wandb_dir) #wandb.watch(net, log='all', log_freq=1) @@ -145,7 +117,6 @@ def replace_layers_keep_weight(model): momentum=0.9, weight_decay=5e-4) scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=200) - # Training def train(epoch): print('\nEpoch: %d' % epoch) @@ -203,7 +174,7 @@ def test(epoch): best_acc = acc -for epoch in range(start_epoch, start_epoch+200):#00 +for epoch in range(start_epoch, start_epoch+args.num_epochs): train(epoch) test(epoch) scheduler.step() diff --git a/refactor/pytorch_cifar_utils.py b/refactor/pytorch_cifar_utils.py index 08b9e15..dc3ccff 100644 --- a/refactor/pytorch_cifar_utils.py +++ b/refactor/pytorch_cifar_utils.py @@ -1,6 +1,4 @@ '''Some helper functions for PyTorch, including: - - get_mean_and_std: calculate the mean and std value of dataset. - - msr_init: net parameter initialization. - progress_bar: progress bar mimic xlua.progress. 
''' import os @@ -22,37 +20,6 @@ def set_seeds(seed): np.random.seed(seed) - -def get_mean_and_std(dataset): - '''Compute the mean and std value of dataset.''' - dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=2) - mean = torch.zeros(3) - std = torch.zeros(3) - print('==> Computing mean and std..') - for inputs, targets in dataloader: - for i in range(3): - mean[i] += inputs[:,i,:,:].mean() - std[i] += inputs[:,i,:,:].std() - mean.div_(len(dataset)) - std.div_(len(dataset)) - return mean, std - -def init_params(net): - '''Init layer parameters.''' - for m in net.modules(): - if isinstance(m, nn.Conv2d): - init.kaiming_normal(m.weight, mode='fan_out') - if m.bias: - init.constant(m.bias, 0) - elif isinstance(m, nn.BatchNorm2d): - init.constant(m.weight, 1) - init.constant(m.bias, 0) - elif isinstance(m, nn.Linear): - init.normal(m.weight, std=1e-3) - if m.bias: - init.constant(m.bias, 0) - - #_, term_width = os.popen('stty size', 'r').read().split() term_width = 80 #int(term_width) From 6a236c58885065b563856724cf63530db1a46fc2 Mon Sep 17 00:00:00 2001 From: vivianwhite <66977221+vivianwhite@users.noreply.github.com> Date: Sun, 21 Apr 2024 22:22:33 -0700 Subject: [PATCH 16/77] Delete refactor/resnet.py --- refactor/resnet.py | 136 --------------------------------------------- 1 file changed, 136 deletions(-) delete mode 100644 refactor/resnet.py diff --git a/refactor/resnet.py b/refactor/resnet.py deleted file mode 100644 index 4ed155e..0000000 --- a/refactor/resnet.py +++ /dev/null @@ -1,136 +0,0 @@ -'''ResNet in PyTorch. - -For Pre-activation ResNet, see 'preact_resnet.py'. - -Reference: -[1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun - Deep Residual Learning for Image Recognition. 
arXiv:1512.03385 -''' -import torch -import torch.nn as nn -import torch.nn.functional as F - - -class BasicBlock(nn.Module): - expansion = 1 - - def __init__(self, in_planes, planes, stride=1): - super(BasicBlock, self).__init__() - self.conv1 = nn.Conv2d( - in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) - self.bn1 = nn.BatchNorm2d(planes) - self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, - stride=1, padding=1, bias=False) - self.bn2 = nn.BatchNorm2d(planes) - - self.shortcut = nn.Sequential() - if stride != 1 or in_planes != self.expansion*planes: - self.shortcut = nn.Sequential( - nn.Conv2d(in_planes, self.expansion*planes, - kernel_size=1, stride=stride, bias=False), - nn.BatchNorm2d(self.expansion*planes) - ) - - def forward(self, x): - out = F.relu(self.bn1(self.conv1(x))) - out = self.bn2(self.conv2(out)) - out += self.shortcut(x) - out = F.relu(out) - return out - - -class Bottleneck(nn.Module): - expansion = 4 - - def __init__(self, in_planes, planes, stride=1): - super(Bottleneck, self).__init__() - self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) - self.bn1 = nn.BatchNorm2d(planes) - self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, - stride=stride, padding=1, bias=False) - self.bn2 = nn.BatchNorm2d(planes) - self.conv3 = nn.Conv2d(planes, self.expansion * - planes, kernel_size=1, bias=False) - self.bn3 = nn.BatchNorm2d(self.expansion*planes) - - self.shortcut = nn.Sequential() - if stride != 1 or in_planes != self.expansion*planes: - self.shortcut = nn.Sequential( - nn.Conv2d(in_planes, self.expansion*planes, - kernel_size=1, stride=stride, bias=False), - nn.BatchNorm2d(self.expansion*planes) - ) - - def forward(self, x): - out = F.relu(self.bn1(self.conv1(x))) - out = F.relu(self.bn2(self.conv2(out))) - out = self.bn3(self.conv3(out)) - out += self.shortcut(x) - out = F.relu(out) - return out - - -class ResNet(nn.Module): - def __init__(self, block, num_blocks, num_classes=10): - super(ResNet, 
self).__init__() - self.in_planes = 64 - - self.conv1 = nn.Conv2d(3, 64, kernel_size=3, - stride=1, padding=1, bias=False) - self.bn1 = nn.BatchNorm2d(64) - self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) - self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) - self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) - self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) - self.linear = nn.Linear(512*block.expansion, num_classes) - - def _make_layer(self, block, planes, num_blocks, stride): - strides = [stride] + [1]*(num_blocks-1) - layers = [] - for stride in strides: - layers.append(block(self.in_planes, planes, stride)) - self.in_planes = planes * block.expansion - return nn.Sequential(*layers) - - def forward(self, x): - out = F.relu(self.bn1(self.conv1(x))) - out = self.layer1(out) - out = self.layer2(out) - out = self.layer3(out) - out = self.layer4(out) - out = F.avg_pool2d(out, 4) - out = out.view(out.size(0), -1) - out = self.linear(out) - return out - - -def ResNet18(): - return ResNet(BasicBlock, [2, 2, 2, 2]) - - -def ResNet18_Class100(): - return ResNet(BasicBlock, [2, 2, 2, 2], 100) - - -def ResNet34(): - return ResNet(BasicBlock, [3, 4, 6, 3]) - - -def ResNet50(): - return ResNet(Bottleneck, [3, 4, 6, 3]) - - -def ResNet101(): - return ResNet(Bottleneck, [3, 4, 23, 3]) - - -def ResNet152(): - return ResNet(Bottleneck, [3, 8, 36, 3]) - - -def test(): - net = ResNet18() - y = net(torch.randn(1, 3, 32, 32)) - print(y.size()) - -#test() From a9317a7c7805d87e51bf4e7d41d09babc9212c94 Mon Sep 17 00:00:00 2001 From: vivianwhite Date: Tue, 23 Apr 2024 01:34:59 -0400 Subject: [PATCH 17/77] update define_model options --- refactor/models/__init__.py | 6 +++--- refactor/og_pytorch_cifar.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/refactor/models/__init__.py b/refactor/models/__init__.py index d2cc284..95e38a6 100644 --- a/refactor/models/__init__.py +++ 
b/refactor/models/__init__.py @@ -5,13 +5,13 @@ def define_models(args): model = ResNet18() if 'fact' in args.net: replace_layers_keep_weight(model) - if "v1" in args.net: + if args.spatial_init == 'V1': # TODO: import V1_init function from structured-random-features V1_init(model) - if "us" in args.net: + if args.freeze_spatial == True: # TODO: make turn_off_grad function turn_off_grad(model, "spatial") - if "uc" in args.net: + if args.freeze_channel == True: turn_off_grad(model, "channel") if args.width != 1: replace_layers_scale(model, args.width) diff --git a/refactor/og_pytorch_cifar.py b/refactor/og_pytorch_cifar.py index a383d0f..2bc4efe 100644 --- a/refactor/og_pytorch_cifar.py +++ b/refactor/og_pytorch_cifar.py @@ -33,7 +33,7 @@ def save_model(args, model): parser.add_argument('--resume', '-r', action='store_true', help='resume from checkpoint') parser.add_argument('--net', type=str, default='resnet18', choices=['resnet18', - 'rsn_cifar10', 'rsn_cifar100'], + 'resnet18-fact'], help="which model to use") parser.add_argument('--num_epochs', type=int, default=90, help='number of trainepochs') parser.add_argument('--hidden_dim', type=int, default=100, From 11d947786b6b3b4efb215ba8294e1d8501aa8541 Mon Sep 17 00:00:00 2001 From: vivianwhite Date: Tue, 23 Apr 2024 01:37:19 -0400 Subject: [PATCH 18/77] normalization back to normal --- refactor/og_pytorch_cifar.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/refactor/og_pytorch_cifar.py b/refactor/og_pytorch_cifar.py index 2bc4efe..d295a1b 100644 --- a/refactor/og_pytorch_cifar.py +++ b/refactor/og_pytorch_cifar.py @@ -68,17 +68,12 @@ def save_model(args, model): transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip(), transforms.ToTensor(), - #transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), - # normalization slightly different from old training setup - transforms.Normalize(mean=[0.485, 0.456, 0.406], - std=[0.229, 0.224, 0.225]) + 
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), ]) transform_test = transforms.Compose([ transforms.ToTensor(), - #transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), - transforms.Normalize(mean=[0.485, 0.456, 0.406], - std=[0.229, 0.224, 0.225]) + transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), ]) trainset = torchvision.datasets.CIFAR10( @@ -139,7 +134,6 @@ def train(epoch): progress_bar(batch_idx, len(trainloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)' % (train_loss/(batch_idx+1), 100.*correct/total, correct, total)) - def test(epoch): global best_acc net.eval() From 690585013ff2607c235a1ea1e040a97ffbf005f8 Mon Sep 17 00:00:00 2001 From: vivianwhite Date: Tue, 23 Apr 2024 15:49:29 -0400 Subject: [PATCH 19/77] updates anda adding turn_off_grad function --- refactor/models/__init__.py | 12 ++++++------ refactor/models/function_utils.py | 22 +++++++++++++++++++--- refactor/models/resnet.py | 4 ---- refactor/og_pytorch_cifar.py | 23 ++++------------------- 4 files changed, 29 insertions(+), 32 deletions(-) diff --git a/refactor/models/__init__.py b/refactor/models/__init__.py index 95e38a6..bbe0b81 100644 --- a/refactor/models/__init__.py +++ b/refactor/models/__init__.py @@ -1,17 +1,17 @@ from .resnet import ResNet18 -from .function_utils import replace_layers_keep_weight, turn_off_grada, replace_layers_scale +from .function_utils import replace_layers_factconv2d, turn_off_grad, replace_layers_scale def define_models(args): if 'resnet18' in args.net: model = ResNet18() if 'fact' in args.net: - replace_layers_keep_weight(model) - if args.spatial_init == 'V1': + replace_layers_factconv2d(model) + if "v1" in args.net: # TODO: import V1_init function from structured-random-features + # TODO: make srf code pip installable V1_init(model) - if args.freeze_spatial == True: - # TODO: make turn_off_grad function + if "us" in args.net: turn_off_grad(model, "spatial") - if args.freeze_channel == True: + if 
"uc" in args.net: turn_off_grad(model, "channel") if args.width != 1: replace_layers_scale(model, args.width) diff --git a/refactor/models/function_utils.py b/refactor/models/function_utils.py index ac49a26..9df0af2 100644 --- a/refactor/models/function_utils.py +++ b/refactor/models/function_utils.py @@ -2,14 +2,14 @@ import torch.nn as nn from conv_modules import FactConv2d -def replace_layers_keep_weight(model): +def replace_layers_factconv2d(model): ''' Replace nn.Conv2d layers with FactConv2d ''' for n, module in model.named_children(): if len(list(module.children())) > 0: ## compound module, go inside it - replace_layers_keep_weight(module) + replace_layers_factconv2d(module) if isinstance(module, nn.Conv2d): ## simple module new_module = FactConv2d( @@ -75,4 +75,20 @@ def replace_layers_scale(model, scale=1): setattr(model, n, new_module) def turn_off_grad(model, covariance): - #TODO: turn off gradients + ''' + Turn off gradients in tri1_vec or tri2_vec to turn off + channel or spatial covariance learning + ''' + for n, module in model.named_children(): + if len(list(module.children())) > 0: + ## compound module, go inside it + turn_off_grad(module, covariance) + if isinstance(module, FactConv2d): + for name, param in module.named_parameters(): + if covariance == "channel": + if "tri1_vec" in name: + param.requires_grad = False + if covariance == "spatial": + if "tri2_vec" in name: + param.requires_grad = False + diff --git a/refactor/models/resnet.py b/refactor/models/resnet.py index 4ed155e..beb18f9 100644 --- a/refactor/models/resnet.py +++ b/refactor/models/resnet.py @@ -108,10 +108,6 @@ def ResNet18(): return ResNet(BasicBlock, [2, 2, 2, 2]) -def ResNet18_Class100(): - return ResNet(BasicBlock, [2, 2, 2, 2], 100) - - def ResNet34(): return ResNet(BasicBlock, [3, 4, 6, 3]) diff --git a/refactor/og_pytorch_cifar.py b/refactor/og_pytorch_cifar.py index d295a1b..6581924 100644 --- a/refactor/og_pytorch_cifar.py +++ b/refactor/og_pytorch_cifar.py @@ -32,27 
+32,11 @@ def save_model(args, model): parser.add_argument('--lr', default=0.1, type=float, help='learning rate') parser.add_argument('--resume', '-r', action='store_true', help='resume from checkpoint') -parser.add_argument('--net', type=str, default='resnet18', choices=['resnet18', - 'resnet18-fact'], +parser.add_argument('--net', type=str, default='resnet18', #choices=['resnet18','resnet18-fact'], help="which model to use") -parser.add_argument('--num_epochs', type=int, default=90, help='number of trainepochs') -parser.add_argument('--hidden_dim', type=int, default=100, - help='number of hidden dimensions in model') -parser.add_argument('--freeze_spatial', dest='freeze_spatial', - type=lambda x: bool(strtobool(x)), default=True, - help="freeze spatial filters for LearnableCov models") -parser.add_argument('--freeze_channel', dest='freeze_channel', - type=lambda x: bool(strtobool(x)), default=False, - help="freeze channels for LearnableCov models") -parser.add_argument('--spatial_init', type=str, default='V1', choices=['default', 'V1'], - help="initialization for spatial filters for LearnableCov models") -parser.add_argument('--s', type=int, default=2, help='V1 size') -parser.add_argument('--f', type=float, default=0.1, help='V1 spatial frequency') -parser.add_argument('--scale', type=int, default=1, help='V1 scale') +parser.add_argument('--num_epochs', type=int, default=200, help='number of trainepochs') parser.add_argument('--name', type=str, default='TESTING_VGG', help='filename for saved model') -parser.add_argument('--bias', dest='bias', type=lambda x: bool(strtobool(x)), - default=False, help='bias=True or False') parser.add_argument('--seed', default=0, type=int, help='seed to use') parser.add_argument('--width', type=float, default=1, help='resnet width scale factor') @@ -94,7 +78,8 @@ def save_model(args, model): net = define_models(args) run_name = args.net -print("Model Built! 
", net) +print("Args.net: ", args.net) +print("Net: ", net) set_seeds(args.seed) net = net.to(device) From aa28aab5b0f06dbef9f1689d8e3f3d4430527110 Mon Sep 17 00:00:00 2001 From: vivianwhite Date: Tue, 23 Apr 2024 18:55:28 -0400 Subject: [PATCH 20/77] adding V1-init functionality --- refactor/conv_modules.py | 4 +- refactor/learnable_cov.py | 83 +++++++++++++++++++++++++++++++ refactor/models/__init__.py | 7 ++- refactor/models/function_utils.py | 22 ++++++++ 4 files changed, 109 insertions(+), 7 deletions(-) create mode 100644 refactor/learnable_cov.py diff --git a/refactor/conv_modules.py b/refactor/conv_modules.py index 0dd0f49..847bab9 100644 --- a/refactor/conv_modules.py +++ b/refactor/conv_modules.py @@ -1,11 +1,9 @@ import torch from torch import Tensor import torch.nn as nn -from torch.nn.parameter import Parameter, UninitializedParameter +from torch.nn.parameter import Parameter from torch.nn.common_types import _size_2_t from typing import Optional, List, Tuple, Union -import time -import copy def _contract(tensor, matrix, axis): diff --git a/refactor/learnable_cov.py b/refactor/learnable_cov.py new file mode 100644 index 0000000..089f951 --- /dev/null +++ b/refactor/learnable_cov.py @@ -0,0 +1,83 @@ +import torch +from torch import Tensor +import torch.nn as nn +from torch.nn.parameter import Parameter, UninitializedParameter +from torch.nn.common_types import _size_2_t +from typing import Optional, List, Tuple, Union +import numpy as np +import numpy.linalg as la +from scipy.spatial.distance import pdist, squareform + + +def V1_covariance_matrix(dim, size, spatial_freq, center, scale=1): + """ + Generates the covariance matrix for Gaussian Process with non-stationary + covariance. This matrix will be used to generate random + features inspired from the receptive-fields of V1 neurons. 
+ + C(x, y) = exp(-|x - y|/(2 * spatial_freq))^2 * exp(-|x - m| / (2 * size))^2 * exp(-|y - m| / (2 * size))^2 + + Parameters + ---------- + + dim : tuple of shape (2, 1) + Dimension of random features. + + size : float + Determines the size of the random weights + + spatial_freq : float + Determines the spatial frequency of the random weights + + center : tuple of shape (2, 1) + Location of the center of the random weights. + + scale: float, default=1 + Normalization factor for Tr norm of cov matrix + + Returns + ------- + + C : array-like of shape (dim[0] * dim[1], dim[0] * dim[1]) + covariance matrix w/ Tr norm = scale * dim[0] * dim[1] + """ + + x = np.arange(dim[0]) + y = np.arange(dim[1]) + yy, xx = np.meshgrid(y, x) + grid = np.column_stack((xx.flatten(), yy.flatten())) + + a = squareform(pdist(grid, 'sqeuclidean')) + b = la.norm(grid - center, axis=1) ** 2 + c = b.reshape(-1, 1) + C = np.exp(-a / (2 * spatial_freq ** 2)) * np.exp(-b / (2 * size ** 2)) * np.exp(-c / (2 * size ** 2)) \ + + 1e-5 * np.eye(dim[0] * dim[1]) + C *= scale * dim[0] * dim[1] / np.trace(C) + return C + +def V1_init(layer, size, spatial_freq, center, scale=1., bias=False, seed=None, + device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')): + ''' + Initialization for FactConv2d + ''' + + classname = layer.__class__.__name__ + assert classname.find('FactConv2d') != -1, 'This init only works for FactConv2d layers' + assert center is not None, "center needed" + + out_channels, in_channels, xdim, ydim = layer.weight.shape + dim = (xdim, ydim) + + C_patch = Tensor(V1_covariance_matrix(dim, size, spatial_freq, center, scale)).to(device) + U_patch = torch.linalg.cholesky(C_patch, upper=True) + n = U_patch.shape[0] + # replace diagonal with logarithm for parameterization + log_diag = torch.log(torch.diagonal(U_patch)) + U_patch[range(n), range(n)] = log_diag + # form vector of upper triangular entries + tri_vec = U_patch[torch.triu_indices(n, n, device=device).tolist()].ravel() + 
with torch.no_grad(): + layer.tri2_vec.copy_(tri_vec) + + if bias == False: + layer.bias = None diff --git a/refactor/models/__init__.py b/refactor/models/__init__.py index bbe0b81..4b252df 100644 --- a/refactor/models/__init__.py +++ b/refactor/models/__init__.py @@ -1,14 +1,13 @@ from .resnet import ResNet18 -from .function_utils import replace_layers_factconv2d, turn_off_grad, replace_layers_scale +from .function_utils import replace_layers_factconv2d, turn_off_grad, replace_layers_scale, init_V1_layers + def define_models(args): if 'resnet18' in args.net: model = ResNet18() if 'fact' in args.net: replace_layers_factconv2d(model) if "v1" in args.net: - # TODO: import V1_init function from structured-random-features - # TODO: make srf code pip installable - V1_init(model) + init_V1_layers(model, bias=False) if "us" in args.net: turn_off_grad(model, "spatial") if "uc" in args.net: diff --git a/refactor/models/function_utils.py b/refactor/models/function_utils.py index 9df0af2..18c5184 100644 --- a/refactor/models/function_utils.py +++ b/refactor/models/function_utils.py @@ -1,6 +1,7 @@ import torch import torch.nn as nn from conv_modules import FactConv2d +from learnable_cov import V1_init def replace_layers_factconv2d(model): ''' @@ -92,3 +93,24 @@ def turn_off_grad(model, covariance): if "tri2_vec" in name: param.requires_grad = False +def init_V1_layers(model, bias): + ''' + Initialize every FactConv2d layer with V1-inspired + spatial weight init + ''' + for n, module in model.named_children(): + if len(list(module.children())) > 0: + ## compound module, go inside it + init_V1_layers(module, bias) + if isinstance(module, FactConv2d): + kernel_size = 3 + center = ((kernel_size - 1) / 2, (kernel_size - 1) / 2) + V1_init(module, size=2, spatial_freq=0.1, scale=1, center=center) + for name, param in module.named_parameters(): + if "weight" in name: + param.requires_grad = False + + if bias: + if "bias" in name: + param.requires_grad = False + From 
e537fdab7b38d7e4988420620417a95bcfc62813 Mon Sep 17 00:00:00 2001 From: vivianwhite Date: Tue, 23 Apr 2024 18:59:17 -0400 Subject: [PATCH 21/77] updating v1-init --- refactor/models/function_utils.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/refactor/models/function_utils.py b/refactor/models/function_utils.py index 18c5184..d2a1927 100644 --- a/refactor/models/function_utils.py +++ b/refactor/models/function_utils.py @@ -103,8 +103,7 @@ def init_V1_layers(model, bias): ## compound module, go inside it init_V1_layers(module, bias) if isinstance(module, FactConv2d): - kernel_size = 3 - center = ((kernel_size - 1) / 2, (kernel_size - 1) / 2) + center = ((module.kernel_size[0] - 1) / 2, (module.kernel_size[1] - 1) / 2) V1_init(module, size=2, spatial_freq=0.1, scale=1, center=center) for name, param in module.named_parameters(): if "weight" in name: From 15b8cd51c6b383b81d338a4830fc4a40888dafdc Mon Sep 17 00:00:00 2001 From: vivianwhite Date: Tue, 23 Apr 2024 20:11:04 -0400 Subject: [PATCH 22/77] cleanup --- refactor/learnable_cov.py | 5 +---- refactor/models/__init__.py | 1 + refactor/models/function_utils.py | 4 ++++ refactor/og_pytorch_cifar.py | 3 +-- refactor/pytorch_cifar_utils.py | 5 +++-- 5 files changed, 10 insertions(+), 8 deletions(-) diff --git a/refactor/learnable_cov.py b/refactor/learnable_cov.py index 089f951..69fcf4d 100644 --- a/refactor/learnable_cov.py +++ b/refactor/learnable_cov.py @@ -1,9 +1,5 @@ import torch from torch import Tensor -import torch.nn as nn -from torch.nn.parameter import Parameter, UninitializedParameter -from torch.nn.common_types import _size_2_t -from typing import Optional, List, Tuple, Union import numpy as np import numpy.linalg as la from scipy.spatial.distance import pdist, squareform @@ -55,6 +51,7 @@ def V1_covariance_matrix(dim, size, spatial_freq, center, scale=1): C *= scale * dim[0] * dim[1] / np.trace(C) return C + def V1_init(layer, size, spatial_freq, center, scale=1., bias=False, 
seed=None, device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')): ''' diff --git a/refactor/models/__init__.py b/refactor/models/__init__.py index 4b252df..c314e22 100644 --- a/refactor/models/__init__.py +++ b/refactor/models/__init__.py @@ -1,6 +1,7 @@ from .resnet import ResNet18 from .function_utils import replace_layers_factconv2d, turn_off_grad, replace_layers_scale, init_V1_layers + def define_models(args): if 'resnet18' in args.net: model = ResNet18() diff --git a/refactor/models/function_utils.py b/refactor/models/function_utils.py index d2a1927..06f415a 100644 --- a/refactor/models/function_utils.py +++ b/refactor/models/function_utils.py @@ -27,6 +27,7 @@ def replace_layers_factconv2d(model): new_module.load_state_dict(new_sd) setattr(model, n, new_module) + def replace_affines(model): ''' Set BatchNorm2d layers to have 'affine=False' @@ -44,6 +45,7 @@ def replace_affines(model): track_running_stats=module.track_running_stats) setattr(model, n, new_module) + def replace_layers_scale(model, scale=1): ''' Replace nn.Conv2d layers with a different scale @@ -75,6 +77,7 @@ def replace_layers_scale(model, scale=1): new_module = nn.Linear(int(module.in_features * scale), 10) setattr(model, n, new_module) + def turn_off_grad(model, covariance): ''' Turn off gradients in tri1_vec or tri2_vec to turn off @@ -93,6 +96,7 @@ def turn_off_grad(model, covariance): if "tri2_vec" in name: param.requires_grad = False + def init_V1_layers(model, bias): ''' Initialize every FactConv2d layer with V1-inspired diff --git a/refactor/og_pytorch_cifar.py b/refactor/og_pytorch_cifar.py index 6581924..22c5489 100644 --- a/refactor/og_pytorch_cifar.py +++ b/refactor/og_pytorch_cifar.py @@ -32,8 +32,7 @@ def save_model(args, model): parser.add_argument('--lr', default=0.1, type=float, help='learning rate') parser.add_argument('--resume', '-r', action='store_true', help='resume from checkpoint') -parser.add_argument('--net', type=str, default='resnet18', 
#choices=['resnet18','resnet18-fact'], - help="which model to use") +parser.add_argument('--net', type=str, default='resnet18', help="which model to use") parser.add_argument('--num_epochs', type=int, default=200, help='number of trainepochs') parser.add_argument('--name', type=str, default='TESTING_VGG', help='filename for saved model') diff --git a/refactor/pytorch_cifar_utils.py b/refactor/pytorch_cifar_utils.py index dc3ccff..ef51ed9 100644 --- a/refactor/pytorch_cifar_utils.py +++ b/refactor/pytorch_cifar_utils.py @@ -20,12 +20,12 @@ def set_seeds(seed): np.random.seed(seed) -#_, term_width = os.popen('stty size', 'r').read().split() -term_width = 80 #int(term_width) +term_width = 80 TOTAL_BAR_LENGTH = 65. last_time = time.time() begin_time = last_time + def progress_bar(current, total, msg=None): global last_time, begin_time if current == 0: @@ -69,6 +69,7 @@ def progress_bar(current, total, msg=None): sys.stdout.write('\n') sys.stdout.flush() + def format_time(seconds): days = int(seconds / 3600/24) seconds = seconds - days*3600*24 From 29978b4987532d87cc0c3c1d85a55a9533adaf82 Mon Sep 17 00:00:00 2001 From: Muawiz Chaudhary Date: Wed, 24 Apr 2024 07:50:26 -0400 Subject: [PATCH 23/77] initial commit of refactored rainbow (have not run anything yet) --- refactor/refactor_rainbow.py | 803 +++++++++++++++++++++++++++++++++++ 1 file changed, 803 insertions(+) create mode 100644 refactor/refactor_rainbow.py diff --git a/refactor/refactor_rainbow.py b/refactor/refactor_rainbow.py new file mode 100644 index 0000000..e326d87 --- /dev/null +++ b/refactor/refactor_rainbow.py @@ -0,0 +1,803 @@ +'''Train CIFAR10 with PyTorch.''' +import torch +import torch.nn as nn +import torch.optim as optim +import torch.nn.functional as F +import torch.backends.cudnn as cudnn +from torch.profiler import profile, record_function, ProfilerActivity +import torchvision +import torchvision.transforms as transforms +import time +import os +import argparse +import copy + +from 
pytorch_cifar_utils import progress_bar, set_seeds +from test_models_safety import PostExp, PreExp +from layers_model import ThreeLayer_CIFAR10, Sequential_ThreeLayer_CIFAR10 +import wandb +from distutils.util import strtobool +from resnet import ResNet18 +from vgg import VGG +import numpy as np +import gc +#torch.backends.cudnn.allow_tf32 = True +#torch.backends.cuda.matmul.allow_tf32 = True +#torch.backends.cuda.preferred_linalg_library('magma') + +def save_model(args, model): + #assert False + src = "/home/mila/m/muawiz.chaudhary/scratch/v1-models/saved-models/eigh_final_refactor_covar_new_testing_rainbow_models/" + model_dir = src + args.name + os.makedirs(model_dir, exist_ok=True) + os.chdir(model_dir) + + #saves loss & accuracy in the trial directory -- all trials + + torch.save(model.state_dict(), model_dir+ "/model.pt") + torch.save(args, model_dir+ "/args.pt") + + +parser = argparse.ArgumentParser(description='PyTorch CIFAR10 Training') +parser.add_argument('--lr', default=0.1, type=float, help='learning rate') +parser.add_argument('--epochs', default=10, type=int, help='number of epochs') +parser.add_argument('--name', type=str, default='TESTING_VGG', + help='filename for saved model') +parser.add_argument('--affine', type=lambda x: bool(strtobool(x)), + default=True, help='Batch Norm affine True or False') +parser.add_argument('--aca', type=lambda x: bool(strtobool(x)), + default=True, help='Activation Cross-Covariance Alignment') +parser.add_argument('--wa', type=lambda x: bool(strtobool(x)), + default=True, help='Weight alignment True=Yes False=No') +parser.add_argument('--in_wa', type=lambda x: bool(strtobool(x)), + default=True, help='input=True output=False') +parser.add_argument('--fact', type=lambda x: bool(strtobool(x)), + default=True, help='FactNet True or False') +parser.add_argument('--width', default=0.125, type=float, help='width') +parser.add_argument('--sampling', type=str, default='ours', + choices=['ours', 'theirs'], help="which 
sampling to use") +args = parser.parse_args() +if args.width == 1.0: + args.width = 1 +if args.width == 2.0: + args.width = 2 +if args.width == 4.0: + args.width = 4 +if args.width == 8.0: + args.width = 8 +print("Sampling: {} Width: {} Fact: {} ACA: {} WA: {} In_WA: {}".format(args.sampling, + args.width, args.fact, args.aca, args.wa, args.in_wa)) + +device = 'cuda' if torch.cuda.is_available() else 'cpu' +best_acc = 0 # best test accuracy +start_epoch = 0 # start from epoch 0 or last checkpoint epoch + +# Data +print('==> Preparing data..') +transform_train = transforms.Compose([ + transforms.RandomCrop(32, padding=4), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), +]) + +transform_test = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), +]) + +trainset = torchvision.datasets.CIFAR10( + root='./data', train=True, download=True, transform=transform_train)#transform_train) +trainloader = torch.utils.data.DataLoader( + trainset, batch_size=128, shuffle=True, num_workers=4, drop_last=True) + +testset = torchvision.datasets.CIFAR10( + root='./data', train=False, download=True, transform=transform_test) +testloader = torch.utils.data.DataLoader( + testset, batch_size=1000, shuffle=False, num_workers=8) + +classes = ('plane', 'car', 'bird', 'cat', 'deer', + 'dog', 'frog', 'horse', 'ship', 'truck') + +# Model +print('==> Building model..') +from ConvModules import FactConv2dPreExp + +def replace_layers_agnostic(model, scale=1): + prev_out_ch = 0 + for n, module in model.named_children(): + if len(list(module.children())) > 0: + ## compound module, go inside it + replace_layers_agnostic(module,scale) + if isinstance(module, nn.Conv2d): + if module.in_channels == 3: + in_scale = 1 + else: + in_scale = scale + ## simple module + new_module = nn.Conv2d( + in_channels=int(module.in_channels*in_scale), + 
out_channels=int(module.out_channels*scale), + kernel_size=module.kernel_size, + stride=module.stride, padding=module.padding, + groups = module.groups, + bias=True if module.bias is not None else False) + setattr(model, n, new_module) + prev_out_ch = new_module.out_channels + if isinstance(module, nn.BatchNorm2d): + new_module = nn.BatchNorm2d(prev_out_ch) + setattr(model, n, new_module) + if isinstance(module, nn.Linear): + new_module = nn.Linear(int(512 * scale), 10) + setattr(model, n, new_module) + +def replace_affines(model): + for n, module in model.named_children(): + if len(list(module.children())) > 0: + ## compound module, go inside it + replace_affines(module) + if isinstance(module, nn.BatchNorm2d): + ## simple module + new_module = nn.BatchNorm2d( + num_features=module.num_features, + eps=module.eps, momentum=module.momentum, + affine=False, + track_running_stats=module.track_running_stats) + setattr(model, n, new_module) +def replace_layers(model): + for n, module in model.named_children(): + if len(list(module.children())) > 0: + ## compound module, go inside it + replace_layers(module) + if isinstance(module, nn.Conv2d): + ## simple module + new_module = FactConv2dPreExp( + in_channels=module.in_channels, + out_channels=module.out_channels, + kernel_size=module.kernel_size, + stride=module.stride, padding=module.padding, + bias=True if module.bias is not None else False) + setattr(model, n, new_module) + +def replace_layers_fact(model): + for n, module in model.named_children(): + if len(list(module.children())) > 0: + ## compound module, go inside it + replace_layers_fact(module) + if isinstance(module, FactConv2dPreExp): + ## simple module + new_module = nn.Conv2d( + in_channels=module.in_channels, + out_channels=module.out_channels, + kernel_size=module.kernel_size, + stride=module.stride, padding=module.padding, + bias=True if module.bias is not None else False) + old_sd = module.state_dict() + new_sd = new_module.state_dict() + 
#new_sd['weight'] = old_sd['weight'] + if module.bias is not None: + new_sd['bias'] = old_sd['bias'] + U1 = module._tri_vec_to_mat(module.tri1_vec, module.in_channels // + module.groups,module.scat_idx1) + U2 = module._tri_vec_to_mat(module.tri2_vec, module.kernel_size[0] * module.kernel_size[1], + module.scat_idx2) + U = torch.kron(U1, U2) + matrix_shape = (module.out_channels, module.in_features) + composite_weight = torch.reshape( + torch.reshape(module.weight, matrix_shape) @ U, + module.weight.shape + ) + #output = self._conv_forward(input, composite_weight, self.bias) + new_sd['weight'] = composite_weight + new_module.load_state_dict(new_sd) + setattr(model, n, new_module) + + +net=ResNet18() +#net=Sequential_ThreeLayer_CIFAR10(100,False) +net.to(device) +replace_layers_agnostic(net, args.width) +if args.fact: + replace_layers(net) +#if not args.affine: +# replace_affines(net) +#sd=torch.load("/network/scratch/v/vivian.white/v1-models/saved-models/three-layer-models-sequential/TESTING_3Layer_final/{}_model.pt".format("fact" if args.fact else "conv")) +if args.fact and args.affine: + if args.width == 8: + sd=torch.load("/home/mila/m/muawiz.chaudhary/scratch/v1-models/saved-models/width_8/8scale_final/fact_model.pt") + else: + sd=torch.load("/network/scratch/v/vivian.white/v1-models/saved-models/affine_1/{}scale_final/fact_model.pt".format(args.width)) +elif not args.fact and args.affine: + if args.width == 8: + sd=torch.load("/home/mila/m/muawiz.chaudhary/scratch/v1-models/saved-models/width_8/8scale_final/conv_model.pt") + else: + sd=torch.load("/network/scratch/v/vivian.white/v1-models/saved-models/affine_1/{}scale_final/conv_model.pt".format(args.width)) +elif args.fact and not args.affine: + sd=torch.load("/network/scratch/v/vivian.white/v1-models/saved-models/affine_2/{}scale_final/fact_model.pt".format(args.width)) +elif not args.fact and not args.affine: + 
sd=torch.load("/network/scratch/v/vivian.white/v1-models/saved-models/affine_2/{}scale_final/conv_model.pt".format(args.width)) +net.load_state_dict(sd) +net.to(device) +net_new = copy.deepcopy(net) +net_new.to(device) +print(net_new) +replace_layers_fact(net) +net.to(device) + +net.train() +net_new.train() + +#traditional way of calculating svd. can be a bit unstable sometimes tho +def calc_svd(A, name=''): + u, s, vh = torch.linalg.svd( + A, full_matrices=False, + # driver="gesvd" + ) # (C_in_reference, R), (R,), (R, C_in_generated) + alignment = u @ vh # (C_in_reference, C_in_generated) + return alignment + +#i've been finding this way of calculating svd to be more stable. +def calc_svd_eigh(A, name=''): + A_T_A = A.T@A + V_val, Vn = torch.linalg.eigh(A_T_A) + V_val = V_val.flip(0) + Vn = Vn.fliplr().T + Sn = (1e-6 + V_val.abs()).sqrt() + Sn_inv = (1/Sn).diag() + Un = A @ Vn.T @ Sn_inv + alignment = Un @ Vn + return alignment + +#used in activation cross-covariance calculation +#input align hook +def return_hook(): + def hook(mod, inputs): + shape = inputs[0].shape + inputs_permute = inputs[0].permute(1,0,2,3).reshape(inputs[0].shape[1], -1) + reshape = (mod.input_align@inputs_permute).reshape(shape[1], + shape[0], shape[2], + shape[3]).permute(1, 0, 2, 3) + return reshape + return hook + + +#settings +# our, wa true, aca true (fact, conv) +# our wa false, aca true (fact, conv) +# our wa false aca false (fact, conv) [Just Random] +# our was true aca false (fact, conv) +# +# theirs wa true, aca true (conv) +# theirs wa false, aca true (conv) +# theirs wa false aca false (conv) [Just Random] +# theirs was true aca false (conv) +# +# this function does not do an explicit specification of the colored covariance +@torch.no_grad() +def our_rainbow_sampling(model, new_model): + for (n1, m1), (n2, m2) in zip(model.named_children(), new_model.named_children()): + if len(list(m1.children())) > 0: + our_rainbow_sampling(m1, m2) + if isinstance(m1, nn.Conv2d): + 
print("conv") + if isinstance(m2, FactConv2dPreExp): + new_module = FactConv2dPreExp( + in_channels=m2.in_channels, + out_channels=m2.out_channels, + kernel_size=m2.kernel_size, + stride=m2.stride, padding=m2.padding, + bias=True if m2.bias is not None else False) + else: + new_module = nn.Conv2d( + in_channels=int(m2.in_channels), + out_channels=int(m2.out_channels), + kernel_size=m2.kernel_size, + stride=m2.stride, padding=m2.padding, + groups = m2.groups, + bias=True if m2.bias is not None else False).to(device) + + if args.sampling == 'ours' and args.wa: + # right now this function does not do an explicit specification of the colored covariance + new_module = weight_Alignment(m1, m2, new_module, in_dim=args.in_wa) + if args.sampling == 'theirs': + # for conv only + if args.wa: + new_module = weight_Alignment_With_CC(m1, m2, new_module) + else: + new_module = colored_Covariance_Specification(m1, m2, new_module) + # this step calculates the activation cross-covariance alignment (ACA) + if m1.in_channels != 3 and args.aca: + new_module = conv_ACA(m1, m2, new_module) + # converts fact conv to conv. this is for sake of speed. + #if isinstance(new_module, FactConv2dPreExp): + # new_module = fact_2_conv(new_module) + #changes the network module + setattr(new_model, n1, new_module) + + #only computes the ACA + if isinstance(m1, nn.Linear) and args.aca: + new_module = linear_ACA(m1, m2, new_model) + setattr(new_model, n1, new_module) + #just run stats through + if isinstance(m1, nn.BatchNorm2d): + batchNorm_stats_recalc(m1, m2) + + +def weight_Alignment(m1, m2, new_module, in_dim=True): + # reference model state dict + print("we go here") + ref_sd = m1.state_dict() + # generated model state dict - uses reference model weights. for now + gen_sd = m2.state_dict() + + # module with random init - to be loaded to model + loading_sd = new_module.state_dict() + new_gaussian = loading_sd['weight'] + print(new_gaussian.shape) + + # carry over old bias. 
only matters when we work with no batchnorm networks + if m1.bias is not None: + loading_sd['bias'] = ref_sd['bias'] + # carry over old colored covariance. only matters with fact-convs + if "tri1_vec" in gen_sd.keys(): + loading_sd['tri1_vec']=gen_sd['tri1_vec'] + loading_sd['tri2_vec']=gen_sd['tri2_vec'] + #this is the spot where + # we can do weight alignment + # for fact net, this means aligning with the random noise + # for conv net, this could mean aligning with a. W OR b. U + # we can do colored-covariance specification + # for fact net, this means just using it's R matrix + # for conv net, this could mean doing nothing (if aligning with W), or use S and V if we did b. + # in this function, we just align with W and don't specify the mulit-color covariance + + # IF FACT: we align the generated factnet with the reference fact net's noise + # IF CONV: we align the generated convnet with the reference conv net's weight matrix + reference_weight = gen_sd['weight'] + generated_weight = new_gaussian + + #reshape to outdim x indim*spatial + reference_weight = reference_weight.reshape(reference_weight.shape[0], -1) + generated_weight = generated_weight.reshape(generated_weight.shape[0], -1) + #compute transpose, giving indim*spatial x outdim + #generated_weight = generated_weight.T + + #compute weight cross-covariance indim*spatial x indim*spatial + #TODO REFACTOR TO HAVE REF FIRST. 
OUTDIM x OUTDIM + if in_dim: + print("Input Alignment") + weight_cov = (generated_weight.T@reference_weight) + #weight_cov = (reference_weight@generated_weight.T) + alignment = calc_svd(weight_cov, name="Weight alignment") + + # outdim x indim x spatial + final_gen_weight = new_gaussian + # outdim x indim*spatial + final_gen_weight = final_gen_weight.reshape(final_gen_weight.shape[0], -1) + # outdim x indim*spatial + final_gen_weight = final_gen_weight@alignment + #final_gen_weight = alignment@final_gen_weight + # outdim x indim x spatial + else: + print("Output Alignment") + weight_cov = (reference_weight@generated_weight.T) + #weight_cov = (reference_weight@generated_weight.T) + alignment = calc_svd(weight_cov, name="Weight alignment") + + # outdim x indim x spatial + final_gen_weight = new_gaussian + # outdim x indim*spatial + final_gen_weight = final_gen_weight.reshape(final_gen_weight.shape[0], -1) + # outdim x indim*spatial + final_gen_weight = alignment@final_gen_weight + #final_gen_weight = alignment@final_gen_weight + # outdim x indim x spatial + loading_sd['weight'] = final_gen_weight.reshape(ref_sd['weight'].shape) + loading_sd['weight_align'] = alignment + new_module.register_buffer("weight_align", alignment) + new_module.load_state_dict(loading_sd) + return new_module + + + + +def conv_ACA(m1, m2, new_module): + activation = [] + other_activation = [] + print("in convACA") + + # this hook grabs the input activations of the conv layer + # rearanges the vector so that the width by height dim is + # additional samples to the covariance + # bwh x c + def define_hook(m): + def store_hook(mod, inputs, outputs): + #inputs[0] = b x c x w x h + #inputs[0].permute(0,2,3,1).reshape(-1, inputs[0].shape[1])) + #from bonner lab tutorial + x = inputs[0] + x = x.permute(0, 2, 3, 1) + x = x.reshape((-1, x.shape[-1])) + activation.append(x) + raise Exception("Done") + return store_hook + + print(m1) + print(m2) + + hook_handle_1 = 
m1.register_forward_hook(define_hook(m1)) + hook_handle_2 = m2.register_forward_hook(define_hook(m2)) + + covar = None + total = 0 + for batch_idx, (inputs, targets) in enumerate(trainloader): + inputs, targets = inputs.to(device), targets.to(device) + try: + outputs1 = net(inputs) + except Exception: + pass + try: + outputs2 = net_new(inputs) + except Exception: + pass + total+= inputs.shape[0] + if covar is None: + #activation is bwh x c + covar = activation[0].T @ activation[1] + assert (covar.isfinite().all()) + else: + #activation is bwh x c + covar += activation[0].T @ activation[1] + assert (covar.isfinite().all()) + activation = [] + other_activation = [] + #c x c + covar /= total + hook_handle_1.remove() + hook_handle_2.remove() + print("done with covariance_calc") + align = calc_svd(covar, name="Cross-Covariance") + new_module.register_buffer("input_align", align) + # this hook takes the input to the conv, aligns, then returns + # to the conv the aligned inputs + hook_handle_pre_forward = new_module.register_forward_pre_hook(return_hook()) + return new_module + + + +def linear_ACA(m1, m2, new_model): + print("linear") + new_module = nn.Linear(m1.in_features, m1.out_features, bias=True + if m1.bias is not None else False).to(device) + ref_sd = m1.state_dict() + loading_sd = new_module.state_dict() + loading_sd['weight'] = ref_sd['weight'] + if m1.bias is not None: + loading_sd['bias'] = ref_sd['bias'] + activation = [] + other_activation = [] + + hook_handle_1 = m1.register_forward_hook(lambda mod, inputs, outputs: + activation.append(inputs[0])) + + hook_handle_2 = m2.register_forward_hook(lambda mod, inputs, outputs: + other_activation.append(inputs[0])) + covar = None + total = 0 + print("starting covariance_calc") + for batch_idx, (inputs, targets) in enumerate(trainloader): + inputs, targets = inputs.to(device), targets.to(device) + outputs1 = net(inputs) + outputs2 = net_new(inputs) + total+= inputs.shape[0] + if covar is None: + covar = 
activation[0].T @ other_activation[0] + else: + covar += activation[0].T @ other_activation[0] + activation = [] + other_activation = [] + covar /= total + #print("done with covariance_calc") + + hook_handle_1.remove() + hook_handle_2.remove() + + align = calc_svd(covar, name="Cross-Covariance") + new_weight = loading_sd['weight'] + new_weight = torch.moveaxis(new_weight, source=1, + destination=-1) + new_weight = new_weight@align + loading_sd['weight'] = torch.moveaxis(new_weight, source=-1, + destination=1) + new_module.load_state_dict(loading_sd) + return new_module + + +def batchNorm_stats_recalc(m1, m2): + print("BatchieNormie") + m1.train() + m2.train() + m1.reset_running_stats() + m2.reset_running_stats() + handle_1 = m1.register_forward_hook(lambda mod, inputs, outputs: Exception("Done")) + handle_2 = m2.register_forward_hook(lambda mod, inputs, outputs: Exception("Done")) + m1.to(device) + m2.to(device) + for batch_idx, (inputs, targets) in enumerate(trainloader): + inputs, targets = inputs.to(device), targets.to(device) + try: + outputs1 = net(inputs) + except Exception: + pass + try: + outputs2 = net_new(inputs) + except Exception: + pass + handle_1.remove() + handle_2.remove() + m1.eval() + m2.eval() + +def weight_Alignment_With_CC(m1, m2, new_module, Un=None, Sn=None, Vn=None): + print("NOT SUPPOSED TO BE HERE") + # reference model state dict + ref_sd = m1.state_dict() + # generated model state dict - uses reference model weights. for now + gen_sd = m2.state_dict() + + # module with random init - to be loaded to model + loading_sd = new_module.state_dict() + new_gaussian = loading_sd['weight'] + + # carry over old bias. only matters when we work with no batchnorm networks + if m1.bias is not None: + loading_sd['bias'] = ref_sd['bias'] + # carry over old colored covariance. 
only matters with fact-convs + if "tri1_vec" in gen_sd.keys(): + loading_sd['tri1_vec']=gen_sd['tri1_vec'] + loading_sd['tri2_vec']=gen_sd['tri2_vec'] + old_weight = ref_sd['weight'] + A = old_weight.reshape(old_weight.shape[0], -1) + A_T_A = A.T@A + V_val, Vn = torch.linalg.eigh(A_T_A) + del A_T_A + V_val = V_val.flip(0) + Vn = Vn.fliplr().T + Sn = (1e-6 + V_val.abs()).sqrt() + Sn_inv = (1/Sn).diag() + Un = A @ Vn.T @ Sn_inv + white_gaussian = torch.randn_like(Un) + copy_weight = Un + copy_weight_gen = white_gaussian + copy_weight = copy_weight.reshape(copy_weight.shape[0], -1) + copy_weight_gen = copy_weight_gen.reshape(copy_weight_gen.shape[0], -1).T + weight_cov = (copy_weight_gen@copy_weight) + alignment = calc_svd(weight_cov, name="Weight") + new_weight = white_gaussian + new_weight = new_weight.reshape(new_weight.shape[0], -1) + new_weight = new_weight@alignment # C_in_reference to C_in_generated + + new_module.register_buffer("weight_align", alignment) + loading_sd['weight_align'] = alignment + colored_gaussian = white_gaussian @ (Sn[:,None]* Vn)#(Sn[:,None]* Vn) + loading_sd['weight'] = colored_gaussian.reshape(old_weight.shape) + new_module.load_state_dict(loading_sd) + return new_module + +# this function does not do an explicit specification of the colored covariance +@torch.no_grad() +def colored_Covariance_Specification(m1, m2, new_module, Un=None, Sn=None, Vn=None): + print("NOT HERE") + # reference model state dict + ref_sd = m1.state_dict() + # generated model state dict - uses reference model weights. for now + gen_sd = m2.state_dict() + + # module with random init - to be loaded to model + loading_sd = new_module.state_dict() + new_gaussian = loading_sd['weight'] + + # carry over old bias. only matters when we work with no batchnorm networks + if m1.bias is not None: + loading_sd['bias'] = ref_sd['bias'] + # carry over old colored covariance. 
only matters with fact-convs + if "tri1_vec" in gen_sd.keys(): + loading_sd['tri1_vec']=gen_sd['tri1_vec'] + loading_sd['tri2_vec']=gen_sd['tri2_vec'] + old_weight = ref_sd['weight'] + A = old_weight.reshape(old_weight.shape[0], -1) + A_T_A = A.T@A + V_val, Vn = torch.linalg.eigh(A_T_A) + del A_T_A + V_val = V_val.flip(0) + Vn = Vn.fliplr().T + Sn = (1e-6 + V_val.abs()).sqrt() + Sn_inv = (1/Sn).diag() + Un = A @ Vn.T @ Sn_inv + white_gaussian = torch.randn_like(Un) + colored_gaussian = white_gaussian @ (Sn[:,None]* Vn)#(Sn[:,None]* Vn) + loading_sd['weight'] = colored_gaussian.reshape(old_weight.shape) + new_module.load_state_dict(loading_sd) + return new_module + +def fact_2_conv(new_module): + ## simple module + print("TESTING FACT REPLACEMENT") + fact_module = nn.Conv2d( + in_channels=new_module.in_channels, + out_channels=new_module.out_channels, + kernel_size=new_module.kernel_size, + stride=new_module.stride, padding=new_module.padding, + bias=True if new_module.bias is not None else False) + + old_sd = new_module.state_dict() + new_sd = fact_module.state_dict() + if new_module.bias is not None: + new_sd['bias'] = old_sd['bias'] + U1 = new_module._tri_vec_to_mat(new_module.tri1_vec, new_module.in_channels // + new_module.groups, new_module.scat_idx1) + U2 = new_module._tri_vec_to_mat(new_module.tri2_vec, + new_module.kernel_size[0] * new_module.kernel_size[1], + new_module.scat_idx2) + U = torch.kron(U1, U2) + matrix_shape = (new_module.out_channels, new_module.in_features) + composite_weight = torch.reshape( + torch.reshape(new_module.weight, matrix_shape) @ U, + new_module.weight.shape + ) + new_sd['weight'] = composite_weight + if 'weight_align' in old_sd.keys(): + new_sd['weight_align'] = old_sd['weight_align'] + shape = fact_module.in_channels*fact_module.kernel_size[0]*fact_module.kernel_size[1] + fact_module.register_buffer("weight_align",torch.zeros((shape, shape))) + if 'input_align' in old_sd.keys(): + new_sd['input_align'] = old_sd['input_align'] + 
out_shape = fact_module.in_channels + fact_module.register_buffer("input_align",torch.zeros((out_shape, out_shape))) + if new_module.in_channels != 3: + #fact check this + for key in list(new_module._forward_pre_hooks.keys()): + del new_module._forward_pre_hooks[key] + hook_handle_pre_forward = fact_module.register_forward_pre_hook(return_hook()) + fact_module.load_state_dict(new_sd) + fact_module.to(device) + new_module = fact_module + print("FACT REPLACEMENT:", new_module) + return new_module + + +def turn_off_grads(model): + for n, module in model.named_children(): + if len(list(module.children())) > 0: + ## compound module, go inside it + turn_off_grads(module) + else: + if isinstance(module, nn.Linear) and module.out_features == 10: + grad=True + else: + grad=False + for param in module.parameters(): + param.requires_grad = grad + +set_seeds(1) +#net=VGG("VGG11", args.bn_on) +criterion = nn.CrossEntropyLoss() +def train(epoch, net): + print('\nEpoch: %d' % epoch) + net.train() + train_loss = 0 + correct = 0 + total = 0 + for batch_idx, (inputs, targets) in enumerate(trainloader): + inputs, targets = inputs.to(device), targets.to(device) + optimizer.zero_grad() + outputs = net(inputs) + loss = criterion(outputs, targets) + loss.backward() + optimizer.step() + train_loss += loss.item() + _, predicted = outputs.max(1) + total += targets.size(0) + correct += predicted.eq(targets).sum().item() + + progress_bar(batch_idx, len(trainloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)' + % (train_loss/(batch_idx+1), 100.*correct/total, correct, total)) + + +def test(epoch, net): + global best_acc + net.eval() + test_loss = 0 + correct = 0 + total = 0 + with torch.no_grad(): + for batch_idx, (inputs, targets) in enumerate(testloader): + inputs, targets = inputs.to(device), targets.to(device) + outputs = net(inputs) + loss = criterion(outputs, targets) + + test_loss += loss.item() + _, predicted = outputs.max(1) + total += targets.size(0) + correct += 
predicted.eq(targets).sum().item() + + progress_bar(batch_idx, len(testloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)' + % (test_loss/(batch_idx+1), 100.*correct/total, correct, total)) + + + # Save checkpoint. + acc = 100.*correct/total + print("accuracy:", acc) + return acc, test_loss + #run.log({"accuracy":acc}) + +print("testing Res{}Net18 with width of {}".format("Fact" if args.fact else "Conv", args.width)) +pretrained_acc, og_loss = test(0, net) + +set_seeds(1) +s=time.time() +our_rainbow_sampling(net, net_new) +net_new.train() +for batch_idx, (inputs, targets) in enumerate(trainloader): + inputs, targets = inputs.to(device), targets.to(device) + outputs = net_new(inputs) +run_name= "eigh_final_{}_Res{}Net18_Width_{}_{}_affine_ACA_{}_WA_{}_InputWA_{}".format(args.sampling.capitalize(), "Fact" + if args.fact else "Conv", str(args.width), "No" if not args.affine else + "Yes", "On" if args.aca else "Off", "On" if args.wa else "Off", + args.in_wa) + + +print("TOTAL TIME:", time.time()-s) +turn_off_grads(net_new) +## +## +## +# +optimizer = optim.SGD(filter(lambda param: param.requires_grad, net_new.parameters()), lr=args.lr, + momentum=0.9, weight_decay=5e-4) +#print("testing rainbow sampling") +print("testing {} sampling at width {}".format(args.sampling, args.width)) +net_new.eval() + +#run_name= "correctly_saved_network_{}_Res{}Net18_Width_{}_{}_affine".format(args.sampling.capitalize(), "Fact" +# if args.fact else "Conv", str(args.width), "No" if not args.affine else +# "Yes") +# +args.name = run_name +print(net_new) +sampled_acc, sampled_loss = test(0, net_new) +#assert False +#print("training rainbow sampling classifier head for 10 epochs") +save_model(args, net_new) +accs = [] +test_losses= [] +print("training classifier head of {} sampled model for {} epochs".format(args.sampling, args.epochs)) +for i in range(0, args.epochs): + net_new.train() + train(i, net_new) + net_new.eval() + acc, loss_test =test(i, net_new) + test_losses.append(loss_test) + 
accs.append(acc) +logger ={"pretrained_acc": pretrained_acc, "sampled_acc": sampled_acc, + "first_epoch_acc":accs[0], "third_epoch_acc": accs[2], + "tenth_epoch_acc":accs[args.epochs-1], 'width':args.width, + "og_loss":og_loss, "sampled_loss":sampled_loss, + "first_epoch_loss":test_losses[0], "third_epoch_loss": test_losses[2], + "tenth_epoch_loss":test_losses[args.epochs-1], 'width':args.width} + + +wandb_dir = "/home/mila/m/muawiz.chaudhary/scratch/v1-models/wandb" +os.makedirs(wandb_dir, exist_ok=True) +os.chdir(wandb_dir) +group_string = "eigh_final_iclr_ACA_{}_WA_{}_{}_{}_rainbow_sampling".format("On" if + args.aca else "Off", "Input" if args.in_wa else "Output", "On" if + args.wa else "Off", "Fact" if args.fact else "Conv") +run = wandb.init(project="random_project", config=args, + group=group_string, name=run_name, dir=wandb_dir) +run.log(logger) + + +args.name += "_trained_classifier_head" +save_model(args, net_new) +#wandb.watch(net, log='all', log_freq=1) From 82180c4e41c146556f1dec4588e3cd3748e14e00 Mon Sep 17 00:00:00 2001 From: Muawiz Chaudhary Date: Wed, 24 Apr 2024 08:17:58 -0400 Subject: [PATCH 24/77] minimal changes to ensure this script runs --- refactor/refactor_rainbow.py | 22 ++++++++++------------ 1 file changed, 10 insertions(+), 12 deletions(-) diff --git a/refactor/refactor_rainbow.py b/refactor/refactor_rainbow.py index e326d87..66b1cb5 100644 --- a/refactor/refactor_rainbow.py +++ b/refactor/refactor_rainbow.py @@ -13,12 +13,9 @@ import copy from pytorch_cifar_utils import progress_bar, set_seeds -from test_models_safety import PostExp, PreExp -from layers_model import ThreeLayer_CIFAR10, Sequential_ThreeLayer_CIFAR10 import wandb from distutils.util import strtobool -from resnet import ResNet18 -from vgg import VGG +from models.resnet import ResNet18 import numpy as np import gc #torch.backends.cudnn.allow_tf32 = True @@ -27,7 +24,7 @@ def save_model(args, model): #assert False - src = 
"/home/mila/m/muawiz.chaudhary/scratch/v1-models/saved-models/eigh_final_refactor_covar_new_testing_rainbow_models/" + src = "/home/mila/m/muawiz.chaudhary/scratch/v1-models/saved-models/refactoring/" model_dir = src + args.name os.makedirs(model_dir, exist_ok=True) os.chdir(model_dir) @@ -101,7 +98,7 @@ def save_model(args, model): # Model print('==> Building model..') -from ConvModules import FactConv2dPreExp +from conv_modules import FactConv2d def replace_layers_agnostic(model, scale=1): prev_out_ch = 0 @@ -151,7 +148,7 @@ def replace_layers(model): replace_layers(module) if isinstance(module, nn.Conv2d): ## simple module - new_module = FactConv2dPreExp( + new_module = FactConv2d( in_channels=module.in_channels, out_channels=module.out_channels, kernel_size=module.kernel_size, @@ -164,7 +161,7 @@ def replace_layers_fact(model): if len(list(module.children())) > 0: ## compound module, go inside it replace_layers_fact(module) - if isinstance(module, FactConv2dPreExp): + if isinstance(module, FactConv2d): ## simple module new_module = nn.Conv2d( in_channels=module.in_channels, @@ -194,7 +191,6 @@ def replace_layers_fact(model): net=ResNet18() -#net=Sequential_ThreeLayer_CIFAR10(100,False) net.to(device) replace_layers_agnostic(net, args.width) if args.fact: @@ -280,8 +276,8 @@ def our_rainbow_sampling(model, new_model): our_rainbow_sampling(m1, m2) if isinstance(m1, nn.Conv2d): print("conv") - if isinstance(m2, FactConv2dPreExp): - new_module = FactConv2dPreExp( + if isinstance(m2, FactConv2d): + new_module = FactConv2d( in_channels=m2.in_channels, out_channels=m2.out_channels, kernel_size=m2.kernel_size, @@ -309,7 +305,7 @@ def our_rainbow_sampling(model, new_model): if m1.in_channels != 3 and args.aca: new_module = conv_ACA(m1, m2, new_module) # converts fact conv to conv. this is for sake of speed. 
- #if isinstance(new_module, FactConv2dPreExp): + #if isinstance(new_module, FactConv2d): # new_module = fact_2_conv(new_module) #changes the network module setattr(new_model, n1, new_module) @@ -763,6 +759,7 @@ def test(epoch, net): # if args.fact else "Conv", str(args.width), "No" if not args.affine else # "Yes") # +run_name = "refactor" args.name = run_name print(net_new) sampled_acc, sampled_loss = test(0, net_new) @@ -793,6 +790,7 @@ def test(epoch, net): group_string = "eigh_final_iclr_ACA_{}_WA_{}_{}_{}_rainbow_sampling".format("On" if args.aca else "Off", "Input" if args.in_wa else "Output", "On" if args.wa else "Off", "Fact" if args.fact else "Conv") +group_name = "refactor" run = wandb.init(project="random_project", config=args, group=group_string, name=run_name, dir=wandb_dir) run.log(logger) From 09aade208cf013ed1ad663d5984e44915755379c Mon Sep 17 00:00:00 2001 From: Muawiz Chaudhary Date: Wed, 24 Apr 2024 09:15:01 -0400 Subject: [PATCH 25/77] cleanup, modified print statements, added seed argument, removed width 8, removed evaluation on affine models, etc etc --- refactor/refactor_rainbow.py | 152 ++++++++++++++--------------------- 1 file changed, 62 insertions(+), 90 deletions(-) diff --git a/refactor/refactor_rainbow.py b/refactor/refactor_rainbow.py index 66b1cb5..f7f9f0c 100644 --- a/refactor/refactor_rainbow.py +++ b/refactor/refactor_rainbow.py @@ -4,26 +4,24 @@ import torch.optim as optim import torch.nn.functional as F import torch.backends.cudnn as cudnn -from torch.profiler import profile, record_function, ProfilerActivity import torchvision import torchvision.transforms as transforms + +import wandb +import numpy as np + import time import os import argparse import copy +import gc +from distutils.util import strtobool from pytorch_cifar_utils import progress_bar, set_seeds -import wandb -from distutils.util import strtobool from models.resnet import ResNet18 -import numpy as np -import gc -#torch.backends.cudnn.allow_tf32 = True 
-#torch.backends.cuda.matmul.allow_tf32 = True -#torch.backends.cuda.preferred_linalg_library('magma') - +from conv_modules import FactConv2d + def save_model(args, model): - #assert False src = "/home/mila/m/muawiz.chaudhary/scratch/v1-models/saved-models/refactoring/" model_dir = src + args.name os.makedirs(model_dir, exist_ok=True) @@ -38,6 +36,7 @@ def save_model(args, model): parser = argparse.ArgumentParser(description='PyTorch CIFAR10 Training') parser.add_argument('--lr', default=0.1, type=float, help='learning rate') parser.add_argument('--epochs', default=10, type=int, help='number of epochs') +parser.add_argument('--seed', default=1, type=int, help='seed') parser.add_argument('--name', type=str, default='TESTING_VGG', help='filename for saved model') parser.add_argument('--affine', type=lambda x: bool(strtobool(x)), @@ -53,7 +52,9 @@ def save_model(args, model): parser.add_argument('--width', default=0.125, type=float, help='width') parser.add_argument('--sampling', type=str, default='ours', choices=['ours', 'theirs'], help="which sampling to use") + args = parser.parse_args() + if args.width == 1.0: args.width = 1 if args.width == 2.0: @@ -62,6 +63,7 @@ def save_model(args, model): args.width = 4 if args.width == 8.0: args.width = 8 + print("Sampling: {} Width: {} Fact: {} ACA: {} WA: {} In_WA: {}".format(args.sampling, args.width, args.fact, args.aca, args.wa, args.in_wa)) @@ -84,7 +86,7 @@ def save_model(args, model): ]) trainset = torchvision.datasets.CIFAR10( - root='./data', train=True, download=True, transform=transform_train)#transform_train) + root='./data', train=True, download=True, transform=transform_train) trainloader = torch.utils.data.DataLoader( trainset, batch_size=128, shuffle=True, num_workers=4, drop_last=True) @@ -98,7 +100,7 @@ def save_model(args, model): # Model print('==> Building model..') -from conv_modules import FactConv2d + def replace_layers_agnostic(model, scale=1): prev_out_ch = 0 @@ -128,6 +130,7 @@ def 
replace_layers_agnostic(model, scale=1): new_module = nn.Linear(int(512 * scale), 10) setattr(model, n, new_module) + def replace_affines(model): for n, module in model.named_children(): if len(list(module.children())) > 0: @@ -141,6 +144,8 @@ def replace_affines(model): affine=False, track_running_stats=module.track_running_stats) setattr(model, n, new_module) + + def replace_layers(model): for n, module in model.named_children(): if len(list(module.children())) > 0: @@ -156,6 +161,7 @@ def replace_layers(model): bias=True if module.bias is not None else False) setattr(model, n, new_module) + def replace_layers_fact(model): for n, module in model.named_children(): if len(list(module.children())) > 0: @@ -171,7 +177,6 @@ def replace_layers_fact(model): bias=True if module.bias is not None else False) old_sd = module.state_dict() new_sd = new_module.state_dict() - #new_sd['weight'] = old_sd['weight'] if module.bias is not None: new_sd['bias'] = old_sd['bias'] U1 = module._tri_vec_to_mat(module.tri1_vec, module.in_channels // @@ -184,7 +189,6 @@ def replace_layers_fact(model): torch.reshape(module.weight, matrix_shape) @ U, module.weight.shape ) - #output = self._conv_forward(input, composite_weight, self.bias) new_sd['weight'] = composite_weight new_module.load_state_dict(new_sd) setattr(model, n, new_module) @@ -197,21 +201,10 @@ def replace_layers_fact(model): replace_layers(net) #if not args.affine: # replace_affines(net) -#sd=torch.load("/network/scratch/v/vivian.white/v1-models/saved-models/three-layer-models-sequential/TESTING_3Layer_final/{}_model.pt".format("fact" if args.fact else "conv")) if args.fact and args.affine: - if args.width == 8: - sd=torch.load("/home/mila/m/muawiz.chaudhary/scratch/v1-models/saved-models/width_8/8scale_final/fact_model.pt") - else: - sd=torch.load("/network/scratch/v/vivian.white/v1-models/saved-models/affine_1/{}scale_final/fact_model.pt".format(args.width)) + 
sd=torch.load("/network/scratch/v/vivian.white/v1-models/saved-models/affine_1/{}scale_final/fact_model.pt".format(args.width)) elif not args.fact and args.affine: - if args.width == 8: - sd=torch.load("/home/mila/m/muawiz.chaudhary/scratch/v1-models/saved-models/width_8/8scale_final/conv_model.pt") - else: - sd=torch.load("/network/scratch/v/vivian.white/v1-models/saved-models/affine_1/{}scale_final/conv_model.pt".format(args.width)) -elif args.fact and not args.affine: - sd=torch.load("/network/scratch/v/vivian.white/v1-models/saved-models/affine_2/{}scale_final/fact_model.pt".format(args.width)) -elif not args.fact and not args.affine: - sd=torch.load("/network/scratch/v/vivian.white/v1-models/saved-models/affine_2/{}scale_final/conv_model.pt".format(args.width)) + sd=torch.load("/network/scratch/v/vivian.white/v1-models/saved-models/affine_1/{}scale_final/conv_model.pt".format(args.width)) net.load_state_dict(sd) net.to(device) net_new = copy.deepcopy(net) @@ -223,15 +216,19 @@ def replace_layers_fact(model): net.train() net_new.train() +set_seeds(args.seed) +criterion = nn.CrossEntropyLoss() + + #traditional way of calculating svd. can be a bit unstable sometimes tho def calc_svd(A, name=''): u, s, vh = torch.linalg.svd( A, full_matrices=False, - # driver="gesvd" ) # (C_in_reference, R), (R,), (R, C_in_generated) alignment = u @ vh # (C_in_reference, C_in_generated) return alignment + #i've been finding this way of calculating svd to be more stable. 
def calc_svd_eigh(A, name=''): A_T_A = A.T@A @@ -244,6 +241,7 @@ def calc_svd_eigh(A, name=''): alignment = Un @ Vn return alignment + #used in activation cross-covariance calculation #input align hook def return_hook(): @@ -275,7 +273,6 @@ def our_rainbow_sampling(model, new_model): if len(list(m1.children())) > 0: our_rainbow_sampling(m1, m2) if isinstance(m1, nn.Conv2d): - print("conv") if isinstance(m2, FactConv2d): new_module = FactConv2d( in_channels=m2.in_channels, @@ -321,7 +318,6 @@ def our_rainbow_sampling(model, new_model): def weight_Alignment(m1, m2, new_module, in_dim=True): # reference model state dict - print("we go here") ref_sd = m1.state_dict() # generated model state dict - uses reference model weights. for now gen_sd = m2.state_dict() @@ -329,7 +325,6 @@ def weight_Alignment(m1, m2, new_module, in_dim=True): # module with random init - to be loaded to model loading_sd = new_module.state_dict() new_gaussian = loading_sd['weight'] - print(new_gaussian.shape) # carry over old bias. only matters when we work with no batchnorm networks if m1.bias is not None: @@ -338,6 +333,7 @@ def weight_Alignment(m1, m2, new_module, in_dim=True): if "tri1_vec" in gen_sd.keys(): loading_sd['tri1_vec']=gen_sd['tri1_vec'] loading_sd['tri2_vec']=gen_sd['tri2_vec'] + #this is the spot where # we can do weight alignment # for fact net, this means aligning with the random noise @@ -356,14 +352,12 @@ def weight_Alignment(m1, m2, new_module, in_dim=True): reference_weight = reference_weight.reshape(reference_weight.shape[0], -1) generated_weight = generated_weight.reshape(generated_weight.shape[0], -1) #compute transpose, giving indim*spatial x outdim - #generated_weight = generated_weight.T #compute weight cross-covariance indim*spatial x indim*spatial #TODO REFACTOR TO HAVE REF FIRST. 
OUTDIM x OUTDIM if in_dim: - print("Input Alignment") + print("Input Weight Alignment") weight_cov = (generated_weight.T@reference_weight) - #weight_cov = (reference_weight@generated_weight.T) alignment = calc_svd(weight_cov, name="Weight alignment") # outdim x indim x spatial @@ -372,12 +366,9 @@ def weight_Alignment(m1, m2, new_module, in_dim=True): final_gen_weight = final_gen_weight.reshape(final_gen_weight.shape[0], -1) # outdim x indim*spatial final_gen_weight = final_gen_weight@alignment - #final_gen_weight = alignment@final_gen_weight - # outdim x indim x spatial else: - print("Output Alignment") + print("Output Weight Alignment") weight_cov = (reference_weight@generated_weight.T) - #weight_cov = (reference_weight@generated_weight.T) alignment = calc_svd(weight_cov, name="Weight alignment") # outdim x indim x spatial @@ -386,8 +377,7 @@ def weight_Alignment(m1, m2, new_module, in_dim=True): final_gen_weight = final_gen_weight.reshape(final_gen_weight.shape[0], -1) # outdim x indim*spatial final_gen_weight = alignment@final_gen_weight - #final_gen_weight = alignment@final_gen_weight - # outdim x indim x spatial + loading_sd['weight'] = final_gen_weight.reshape(ref_sd['weight'].shape) loading_sd['weight_align'] = alignment new_module.register_buffer("weight_align", alignment) @@ -395,21 +385,16 @@ def weight_Alignment(m1, m2, new_module, in_dim=True): return new_module - - def conv_ACA(m1, m2, new_module): + print("Convolutional Input Activations Alignment") activation = [] other_activation = [] - print("in convACA") - # this hook grabs the input activations of the conv layer # rearanges the vector so that the width by height dim is # additional samples to the covariance # bwh x c def define_hook(m): def store_hook(mod, inputs, outputs): - #inputs[0] = b x c x w x h - #inputs[0].permute(0,2,3,1).reshape(-1, inputs[0].shape[1])) #from bonner lab tutorial x = inputs[0] x = x.permute(0, 2, 3, 1) @@ -418,12 +403,10 @@ def store_hook(mod, inputs, outputs): raise 
Exception("Done") return store_hook - print(m1) - print(m2) - hook_handle_1 = m1.register_forward_hook(define_hook(m1)) hook_handle_2 = m2.register_forward_hook(define_hook(m2)) + print("Starting Sample Cross-Covariance Calculation") covar = None total = 0 for batch_idx, (inputs, targets) in enumerate(trainloader): @@ -447,22 +430,23 @@ def store_hook(mod, inputs, outputs): assert (covar.isfinite().all()) activation = [] other_activation = [] + #c x c covar /= total hook_handle_1.remove() hook_handle_2.remove() - print("done with covariance_calc") + print("Sample Cross-Covariance Calculation finished") align = calc_svd(covar, name="Cross-Covariance") new_module.register_buffer("input_align", align) + # this hook takes the input to the conv, aligns, then returns # to the conv the aligned inputs hook_handle_pre_forward = new_module.register_forward_pre_hook(return_hook()) return new_module - def linear_ACA(m1, m2, new_model): - print("linear") + print("Linear Input Activations Alignment") new_module = nn.Linear(m1.in_features, m1.out_features, bias=True if m1.bias is not None else False).to(device) ref_sd = m1.state_dict() @@ -480,7 +464,7 @@ def linear_ACA(m1, m2, new_model): other_activation.append(inputs[0])) covar = None total = 0 - print("starting covariance_calc") + print("Starting Sample Cross-Covariance Calculation") for batch_idx, (inputs, targets) in enumerate(trainloader): inputs, targets = inputs.to(device), targets.to(device) outputs1 = net(inputs) @@ -493,11 +477,11 @@ def linear_ACA(m1, m2, new_model): activation = [] other_activation = [] covar /= total - #print("done with covariance_calc") hook_handle_1.remove() hook_handle_2.remove() + print("Sample Cross-Covariance Calculation finished") align = calc_svd(covar, name="Cross-Covariance") new_weight = loading_sd['weight'] new_weight = torch.moveaxis(new_weight, source=1, @@ -510,7 +494,7 @@ def linear_ACA(m1, m2, new_model): def batchNorm_stats_recalc(m1, m2): - print("BatchieNormie") + 
print("Calculating Batch Statistics") m1.train() m2.train() m1.reset_running_stats() @@ -533,6 +517,8 @@ def batchNorm_stats_recalc(m1, m2): handle_2.remove() m1.eval() m2.eval() + print("Batch Statistics Calculation Finished") + def weight_Alignment_With_CC(m1, m2, new_module, Un=None, Sn=None, Vn=None): print("NOT SUPPOSED TO BE HERE") @@ -552,6 +538,7 @@ def weight_Alignment_With_CC(m1, m2, new_module, Un=None, Sn=None, Vn=None): if "tri1_vec" in gen_sd.keys(): loading_sd['tri1_vec']=gen_sd['tri1_vec'] loading_sd['tri2_vec']=gen_sd['tri2_vec'] + old_weight = ref_sd['weight'] A = old_weight.reshape(old_weight.shape[0], -1) A_T_A = A.T@A @@ -563,11 +550,13 @@ def weight_Alignment_With_CC(m1, m2, new_module, Un=None, Sn=None, Vn=None): Sn_inv = (1/Sn).diag() Un = A @ Vn.T @ Sn_inv white_gaussian = torch.randn_like(Un) + copy_weight = Un copy_weight_gen = white_gaussian copy_weight = copy_weight.reshape(copy_weight.shape[0], -1) copy_weight_gen = copy_weight_gen.reshape(copy_weight_gen.shape[0], -1).T weight_cov = (copy_weight_gen@copy_weight) + alignment = calc_svd(weight_cov, name="Weight") new_weight = white_gaussian new_weight = new_weight.reshape(new_weight.shape[0], -1) @@ -575,11 +564,12 @@ def weight_Alignment_With_CC(m1, m2, new_module, Un=None, Sn=None, Vn=None): new_module.register_buffer("weight_align", alignment) loading_sd['weight_align'] = alignment - colored_gaussian = white_gaussian @ (Sn[:,None]* Vn)#(Sn[:,None]* Vn) + colored_gaussian = white_gaussian @ (Sn[:,None]* Vn) loading_sd['weight'] = colored_gaussian.reshape(old_weight.shape) new_module.load_state_dict(loading_sd) return new_module - + + # this function does not do an explicit specification of the colored covariance @torch.no_grad() def colored_Covariance_Specification(m1, m2, new_module, Un=None, Sn=None, Vn=None): @@ -600,6 +590,7 @@ def colored_Covariance_Specification(m1, m2, new_module, Un=None, Sn=None, Vn=No if "tri1_vec" in gen_sd.keys(): loading_sd['tri1_vec']=gen_sd['tri1_vec'] 
loading_sd['tri2_vec']=gen_sd['tri2_vec'] + old_weight = ref_sd['weight'] A = old_weight.reshape(old_weight.shape[0], -1) A_T_A = A.T@A @@ -611,14 +602,16 @@ def colored_Covariance_Specification(m1, m2, new_module, Un=None, Sn=None, Vn=No Sn_inv = (1/Sn).diag() Un = A @ Vn.T @ Sn_inv white_gaussian = torch.randn_like(Un) - colored_gaussian = white_gaussian @ (Sn[:,None]* Vn)#(Sn[:,None]* Vn) + + colored_gaussian = white_gaussian @ (Sn[:,None]* Vn) loading_sd['weight'] = colored_gaussian.reshape(old_weight.shape) new_module.load_state_dict(loading_sd) return new_module - + + def fact_2_conv(new_module): ## simple module - print("TESTING FACT REPLACEMENT") + print("Replacing FactConv") fact_module = nn.Conv2d( in_channels=new_module.in_channels, out_channels=new_module.out_channels, @@ -628,19 +621,23 @@ def fact_2_conv(new_module): old_sd = new_module.state_dict() new_sd = fact_module.state_dict() + if new_module.bias is not None: new_sd['bias'] = old_sd['bias'] + U1 = new_module._tri_vec_to_mat(new_module.tri1_vec, new_module.in_channels // new_module.groups, new_module.scat_idx1) U2 = new_module._tri_vec_to_mat(new_module.tri2_vec, new_module.kernel_size[0] * new_module.kernel_size[1], new_module.scat_idx2) U = torch.kron(U1, U2) + matrix_shape = (new_module.out_channels, new_module.in_features) composite_weight = torch.reshape( torch.reshape(new_module.weight, matrix_shape) @ U, new_module.weight.shape ) + new_sd['weight'] = composite_weight if 'weight_align' in old_sd.keys(): new_sd['weight_align'] = old_sd['weight_align'] @@ -675,9 +672,7 @@ def turn_off_grads(model): for param in module.parameters(): param.requires_grad = grad -set_seeds(1) -#net=VGG("VGG11", args.bn_on) -criterion = nn.CrossEntropyLoss() + def train(epoch, net): print('\nEpoch: %d' % epoch) net.train() @@ -719,52 +714,34 @@ def test(epoch, net): progress_bar(batch_idx, len(testloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)' % (test_loss/(batch_idx+1), 100.*correct/total, correct, total)) - - # 
Save checkpoint. acc = 100.*correct/total print("accuracy:", acc) return acc, test_loss - #run.log({"accuracy":acc}) + print("testing Res{}Net18 with width of {}".format("Fact" if args.fact else "Conv", args.width)) pretrained_acc, og_loss = test(0, net) -set_seeds(1) +set_seeds(args.seed) s=time.time() our_rainbow_sampling(net, net_new) net_new.train() for batch_idx, (inputs, targets) in enumerate(trainloader): inputs, targets = inputs.to(device), targets.to(device) outputs = net_new(inputs) -run_name= "eigh_final_{}_Res{}Net18_Width_{}_{}_affine_ACA_{}_WA_{}_InputWA_{}".format(args.sampling.capitalize(), "Fact" - if args.fact else "Conv", str(args.width), "No" if not args.affine else - "Yes", "On" if args.aca else "Off", "On" if args.wa else "Off", - args.in_wa) - - print("TOTAL TIME:", time.time()-s) turn_off_grads(net_new) -## -## -## -# optimizer = optim.SGD(filter(lambda param: param.requires_grad, net_new.parameters()), lr=args.lr, momentum=0.9, weight_decay=5e-4) -#print("testing rainbow sampling") print("testing {} sampling at width {}".format(args.sampling, args.width)) net_new.eval() -#run_name= "correctly_saved_network_{}_Res{}Net18_Width_{}_{}_affine".format(args.sampling.capitalize(), "Fact" -# if args.fact else "Conv", str(args.width), "No" if not args.affine else -# "Yes") -# run_name = "refactor" args.name = run_name print(net_new) + sampled_acc, sampled_loss = test(0, net_new) -#assert False -#print("training rainbow sampling classifier head for 10 epochs") save_model(args, net_new) accs = [] test_losses= [] @@ -776,6 +753,7 @@ def test(epoch, net): acc, loss_test =test(i, net_new) test_losses.append(loss_test) accs.append(acc) + logger ={"pretrained_acc": pretrained_acc, "sampled_acc": sampled_acc, "first_epoch_acc":accs[0], "third_epoch_acc": accs[2], "tenth_epoch_acc":accs[args.epochs-1], 'width':args.width, @@ -783,19 +761,13 @@ def test(epoch, net): "first_epoch_loss":test_losses[0], "third_epoch_loss": test_losses[2], 
"tenth_epoch_loss":test_losses[args.epochs-1], 'width':args.width} - wandb_dir = "/home/mila/m/muawiz.chaudhary/scratch/v1-models/wandb" os.makedirs(wandb_dir, exist_ok=True) os.chdir(wandb_dir) -group_string = "eigh_final_iclr_ACA_{}_WA_{}_{}_{}_rainbow_sampling".format("On" if - args.aca else "Off", "Input" if args.in_wa else "Output", "On" if - args.wa else "Off", "Fact" if args.fact else "Conv") group_name = "refactor" run = wandb.init(project="random_project", config=args, group=group_string, name=run_name, dir=wandb_dir) run.log(logger) - args.name += "_trained_classifier_head" save_model(args, net_new) -#wandb.watch(net, log='all', log_freq=1) From 6c6dcc9542e5ee42bd74746982cc9e3fde44a8ec Mon Sep 17 00:00:00 2001 From: Muawiz Chaudhary Date: Wed, 24 Apr 2024 09:19:28 -0400 Subject: [PATCH 26/77] removed affine argument --- refactor/refactor_rainbow.py | 23 ++--------------------- 1 file changed, 2 insertions(+), 21 deletions(-) diff --git a/refactor/refactor_rainbow.py b/refactor/refactor_rainbow.py index f7f9f0c..9986893 100644 --- a/refactor/refactor_rainbow.py +++ b/refactor/refactor_rainbow.py @@ -39,8 +39,6 @@ def save_model(args, model): parser.add_argument('--seed', default=1, type=int, help='seed') parser.add_argument('--name', type=str, default='TESTING_VGG', help='filename for saved model') -parser.add_argument('--affine', type=lambda x: bool(strtobool(x)), - default=True, help='Batch Norm affine True or False') parser.add_argument('--aca', type=lambda x: bool(strtobool(x)), default=True, help='Activation Cross-Covariance Alignment') parser.add_argument('--wa', type=lambda x: bool(strtobool(x)), @@ -131,21 +129,6 @@ def replace_layers_agnostic(model, scale=1): setattr(model, n, new_module) -def replace_affines(model): - for n, module in model.named_children(): - if len(list(module.children())) > 0: - ## compound module, go inside it - replace_affines(module) - if isinstance(module, nn.BatchNorm2d): - ## simple module - new_module = nn.BatchNorm2d( 
- num_features=module.num_features, - eps=module.eps, momentum=module.momentum, - affine=False, - track_running_stats=module.track_running_stats) - setattr(model, n, new_module) - - def replace_layers(model): for n, module in model.named_children(): if len(list(module.children())) > 0: @@ -199,11 +182,9 @@ def replace_layers_fact(model): replace_layers_agnostic(net, args.width) if args.fact: replace_layers(net) -#if not args.affine: -# replace_affines(net) -if args.fact and args.affine: +if args.fact: sd=torch.load("/network/scratch/v/vivian.white/v1-models/saved-models/affine_1/{}scale_final/fact_model.pt".format(args.width)) -elif not args.fact and args.affine: +elif not args.fact: sd=torch.load("/network/scratch/v/vivian.white/v1-models/saved-models/affine_1/{}scale_final/conv_model.pt".format(args.width)) net.load_state_dict(sd) net.to(device) From 15422ab63da4f5b6d4cce75b261da0f0cff775dc Mon Sep 17 00:00:00 2001 From: Muawiz Chaudhary Date: Wed, 24 Apr 2024 10:57:44 -0400 Subject: [PATCH 27/77] moved recursive modules around --- refactor/models/__init__.py | 6 +- refactor/models/function_utils.py | 56 +++++++++++++++- refactor/refactor_rainbow.py | 107 +++--------------------------- 3 files changed, 68 insertions(+), 101 deletions(-) diff --git a/refactor/models/__init__.py b/refactor/models/__init__.py index c314e22..0588e64 100644 --- a/refactor/models/__init__.py +++ b/refactor/models/__init__.py @@ -1,5 +1,5 @@ from .resnet import ResNet18 -from .function_utils import replace_layers_factconv2d, turn_off_grad, replace_layers_scale, init_V1_layers +from .function_utils import replace_layers_factconv2d, turn_off_covar_grad, replace_layers_scale, init_V1_layers def define_models(args): @@ -10,9 +10,9 @@ def define_models(args): if "v1" in args.net: init_V1_layers(model, bias=False) if "us" in args.net: - turn_off_grad(model, "spatial") + turn_off_covar_grad(model, "spatial") if "uc" in args.net: - turn_off_grad(model, "channel") + turn_off_covar_grad(model, 
"channel") if args.width != 1: replace_layers_scale(model, args.width) return model diff --git a/refactor/models/function_utils.py b/refactor/models/function_utils.py index 06f415a..34bf580 100644 --- a/refactor/models/function_utils.py +++ b/refactor/models/function_utils.py @@ -78,7 +78,42 @@ def replace_layers_scale(model, scale=1): setattr(model, n, new_module) -def turn_off_grad(model, covariance): +def replace_layers_fact_with_conv(model): + ''' + Replace FactConv2d layers with nn.Conv2d + ''' + for n, module in model.named_children(): + if len(list(module.children())) > 0: + ## compound module, go inside it + replace_layers_fact_with_conv(module) + if isinstance(module, FactConv2d): + ## simple module + new_module = nn.Conv2d( + in_channels=module.in_channels, + out_channels=module.out_channels, + kernel_size=module.kernel_size, + stride=module.stride, padding=module.padding, + bias=True if module.bias is not None else False) + old_sd = module.state_dict() + new_sd = new_module.state_dict() + if module.bias is not None: + new_sd['bias'] = old_sd['bias'] + U1 = module._tri_vec_to_mat(module.tri1_vec, module.in_channels // + module.groups,module.scat_idx1) + U2 = module._tri_vec_to_mat(module.tri2_vec, module.kernel_size[0] * module.kernel_size[1], + module.scat_idx2) + U = torch.kron(U1, U2) + matrix_shape = (module.out_channels, module.in_features) + composite_weight = torch.reshape( + torch.reshape(module.weight, matrix_shape) @ U, + module.weight.shape + ) + new_sd['weight'] = composite_weight + new_module.load_state_dict(new_sd) + setattr(model, n, new_module) + + +def turn_off_covar_grad(model, covariance): ''' Turn off gradients in tri1_vec or tri2_vec to turn off channel or spatial covariance learning @@ -86,7 +121,7 @@ def turn_off_grad(model, covariance): for n, module in model.named_children(): if len(list(module.children())) > 0: ## compound module, go inside it - turn_off_grad(module, covariance) + turn_off_covar_grad(module, covariance) if 
isinstance(module, FactConv2d): for name, param in module.named_parameters(): if covariance == "channel": @@ -96,6 +131,23 @@ def turn_off_grad(model, covariance): if "tri2_vec" in name: param.requires_grad = False + +def turn_off_backbone_grad(model): + ''' + Turn off gradients in backbone. For tuning just classifier layer + ''' + for n, module in model.named_children(): + if len(list(module.children())) > 0: + ## compound module, go inside it + turn_off_backbone_grad(module) + #else: + if isinstance(module, nn.Linear) and module.out_features == 10: + grad=True + else: + grad=False + for param in module.parameters(): + param.requires_grad = grad + def init_V1_layers(model, bias): ''' diff --git a/refactor/refactor_rainbow.py b/refactor/refactor_rainbow.py index 9986893..4ec99ab 100644 --- a/refactor/refactor_rainbow.py +++ b/refactor/refactor_rainbow.py @@ -20,6 +20,8 @@ from pytorch_cifar_utils import progress_bar, set_seeds from models.resnet import ResNet18 from conv_modules import FactConv2d +from models.function_utils import replace_layers_factconv2d,\ +replace_layers_scale, replace_layers_fact_with_conv, turn_off_backbone_grad def save_model(args, model): src = "/home/mila/m/muawiz.chaudhary/scratch/v1-models/saved-models/refactoring/" @@ -100,98 +102,25 @@ def save_model(args, model): print('==> Building model..') -def replace_layers_agnostic(model, scale=1): - prev_out_ch = 0 - for n, module in model.named_children(): - if len(list(module.children())) > 0: - ## compound module, go inside it - replace_layers_agnostic(module,scale) - if isinstance(module, nn.Conv2d): - if module.in_channels == 3: - in_scale = 1 - else: - in_scale = scale - ## simple module - new_module = nn.Conv2d( - in_channels=int(module.in_channels*in_scale), - out_channels=int(module.out_channels*scale), - kernel_size=module.kernel_size, - stride=module.stride, padding=module.padding, - groups = module.groups, - bias=True if module.bias is not None else False) - setattr(model, n, 
new_module) - prev_out_ch = new_module.out_channels - if isinstance(module, nn.BatchNorm2d): - new_module = nn.BatchNorm2d(prev_out_ch) - setattr(model, n, new_module) - if isinstance(module, nn.Linear): - new_module = nn.Linear(int(512 * scale), 10) - setattr(model, n, new_module) - - -def replace_layers(model): - for n, module in model.named_children(): - if len(list(module.children())) > 0: - ## compound module, go inside it - replace_layers(module) - if isinstance(module, nn.Conv2d): - ## simple module - new_module = FactConv2d( - in_channels=module.in_channels, - out_channels=module.out_channels, - kernel_size=module.kernel_size, - stride=module.stride, padding=module.padding, - bias=True if module.bias is not None else False) - setattr(model, n, new_module) - - -def replace_layers_fact(model): - for n, module in model.named_children(): - if len(list(module.children())) > 0: - ## compound module, go inside it - replace_layers_fact(module) - if isinstance(module, FactConv2d): - ## simple module - new_module = nn.Conv2d( - in_channels=module.in_channels, - out_channels=module.out_channels, - kernel_size=module.kernel_size, - stride=module.stride, padding=module.padding, - bias=True if module.bias is not None else False) - old_sd = module.state_dict() - new_sd = new_module.state_dict() - if module.bias is not None: - new_sd['bias'] = old_sd['bias'] - U1 = module._tri_vec_to_mat(module.tri1_vec, module.in_channels // - module.groups,module.scat_idx1) - U2 = module._tri_vec_to_mat(module.tri2_vec, module.kernel_size[0] * module.kernel_size[1], - module.scat_idx2) - U = torch.kron(U1, U2) - matrix_shape = (module.out_channels, module.in_features) - composite_weight = torch.reshape( - torch.reshape(module.weight, matrix_shape) @ U, - module.weight.shape - ) - new_sd['weight'] = composite_weight - new_module.load_state_dict(new_sd) - setattr(model, n, new_module) - - net=ResNet18() net.to(device) -replace_layers_agnostic(net, args.width) +replace_layers_scale(net, 
args.width) if args.fact: - replace_layers(net) + replace_layers_factconv2d(net) + + if args.fact: sd=torch.load("/network/scratch/v/vivian.white/v1-models/saved-models/affine_1/{}scale_final/fact_model.pt".format(args.width)) elif not args.fact: sd=torch.load("/network/scratch/v/vivian.white/v1-models/saved-models/affine_1/{}scale_final/conv_model.pt".format(args.width)) net.load_state_dict(sd) net.to(device) + net_new = copy.deepcopy(net) net_new.to(device) print(net_new) -replace_layers_fact(net) + +replace_layers_fact_with_conv(net) net.to(device) net.train() @@ -639,20 +568,6 @@ def fact_2_conv(new_module): print("FACT REPLACEMENT:", new_module) return new_module - -def turn_off_grads(model): - for n, module in model.named_children(): - if len(list(module.children())) > 0: - ## compound module, go inside it - turn_off_grads(module) - else: - if isinstance(module, nn.Linear) and module.out_features == 10: - grad=True - else: - grad=False - for param in module.parameters(): - param.requires_grad = grad - def train(epoch, net): print('\nEpoch: %d' % epoch) @@ -712,7 +627,7 @@ def test(epoch, net): inputs, targets = inputs.to(device), targets.to(device) outputs = net_new(inputs) print("TOTAL TIME:", time.time()-s) -turn_off_grads(net_new) +turn_off_backbone_grad(net_new) optimizer = optim.SGD(filter(lambda param: param.requires_grad, net_new.parameters()), lr=args.lr, momentum=0.9, weight_decay=5e-4) print("testing {} sampling at width {}".format(args.sampling, args.width)) @@ -745,7 +660,7 @@ def test(epoch, net): wandb_dir = "/home/mila/m/muawiz.chaudhary/scratch/v1-models/wandb" os.makedirs(wandb_dir, exist_ok=True) os.chdir(wandb_dir) -group_name = "refactor" +group_string = "refactor" run = wandb.init(project="random_project", config=args, group=group_string, name=run_name, dir=wandb_dir) run.log(logger) From 1397314cd86ca9bf011b3509a8d918d0d81b065c Mon Sep 17 00:00:00 2001 From: Muawiz Chaudhary Date: Wed, 24 Apr 2024 12:36:46 -0400 Subject: [PATCH 28/77] 
added recurse_preorder. made each replace function a wrapper that returns the modified network (though the network is modified in place, so this is redundant) --- refactor/models/function_utils.py | 109 ++++++++++++++++++++---------- refactor/refactor_rainbow.py | 10 +-- 2 files changed, 79 insertions(+), 40 deletions(-) diff --git a/refactor/models/function_utils.py b/refactor/models/function_utils.py index 34bf580..cd509c4 100644 --- a/refactor/models/function_utils.py +++ b/refactor/models/function_utils.py @@ -3,14 +3,23 @@ from conv_modules import FactConv2d from learnable_cov import V1_init + +def recurse_preorder(model, callback): + r = callback(model) + if r is not model and r is not None: + return r + for n, module in model.named_children(): + r = recurse_preorder(module, callback) + if r is not module and r is not None: + setattr(model, n, r) + return model + + def replace_layers_factconv2d(model): ''' Replace nn.Conv2d layers with FactConv2d ''' - for n, module in model.named_children(): - if len(list(module.children())) > 0: - ## compound module, go inside it - replace_layers_factconv2d(module) + def _replace_layers_factconv2d(module): if isinstance(module, nn.Conv2d): ## simple module new_module = FactConv2d( @@ -25,17 +34,17 @@ def replace_layers_factconv2d(model): if module.bias is not None: new_sd['bias'] = old_sd['bias'] new_module.load_state_dict(new_sd) - setattr(model, n, new_module) + return new_module + return module + return recurse_preorder(model, _replace_layers_factconv2d) +#TODO: VIVIAN TEST THIS def replace_affines(model): ''' Set BatchNorm2d layers to have 'affine=False' ''' - for n, module in model.named_children(): - if len(list(module.children())) > 0: - ## compound module, go inside it - replace_affines(module) + def _replace_affines(module): if isinstance(module, nn.BatchNorm2d): ## simple module new_module = nn.BatchNorm2d( @@ -43,18 +52,16 @@ def replace_affines(model): eps=module.eps, momentum=module.momentum, affine=False,
track_running_stats=module.track_running_stats) - setattr(model, n, new_module) + return new_module + return module + return recurse_preorder(model, _replace_affines) def replace_layers_scale(model, scale=1): ''' Replace nn.Conv2d layers with a different scale ''' - prev_out_ch = 0 - for n, module in model.named_children(): - if len(list(module.children())) > 0: - ## compound module, go inside it - replace_layers_scale(module,scale) + def _replace_layers_scale(module): if isinstance(module, nn.Conv2d): if module.in_channels == 3: in_scale = 1 @@ -68,24 +75,39 @@ def replace_layers_scale(model, scale=1): stride=module.stride, padding=module.padding, groups = module.groups, bias=True if module.bias is not None else False) - setattr(model, n, new_module) - prev_out_ch = new_module.out_channels + return new_module if isinstance(module, nn.BatchNorm2d): - new_module = nn.BatchNorm2d(prev_out_ch) - setattr(model, n, new_module) + new_module = nn.BatchNorm2d(int(module.num_features*scale), + affine=module.affine) + return new_module if isinstance(module, nn.Linear): new_module = nn.Linear(int(module.in_features * scale), 10) - setattr(model, n, new_module) + return new_module + return module + return recurse_preorder(model, _replace_layers_scale) + +#used in activation cross-covariance calculation +#input align hook +def return_hook(): + def hook(mod, inputs): + shape = inputs[0].shape + inputs_permute = inputs[0].permute(1,0,2,3).reshape(inputs[0].shape[1], -1) + reshape = (mod.input_align@inputs_permute).reshape(shape[1], + shape[0], shape[2], + shape[3]).permute(1, 0, 2, 3) + return reshape + return hook + + + +#TODO: MUAWIZ TEST THIS def replace_layers_fact_with_conv(model): ''' Replace FactConv2d layers with nn.Conv2d ''' - for n, module in model.named_children(): - if len(list(module.children())) > 0: - ## compound module, go inside it - replace_layers_fact_with_conv(module) + def _replace_layers_fact_with_conv(module): if isinstance(module, FactConv2d): ## simple 
module new_module = nn.Conv2d( @@ -109,19 +131,33 @@ def replace_layers_fact_with_conv(model): module.weight.shape ) new_sd['weight'] = composite_weight + if 'weight_align' in old_sd.keys(): + new_sd['weight_align'] = old_sd['weight_align'] + shape = new_module.in_channels*new_module.kernel_size[0]*new_module.kernel_size[1] + new_module.register_buffer("weight_align",torch.zeros((shape, shape))) + if 'input_align' in old_sd.keys(): + new_sd['input_align'] = old_sd['input_align'] + out_shape = new_module.in_channels + new_module.register_buffer("input_align",torch.zeros((out_shape, out_shape))) + if module.in_channels != 3: + #fact check this + for key in list(module._forward_pre_hooks.keys()): + del module._forward_pre_hooks[key] + hook_handle_pre_forward = new_module.register_forward_pre_hook(return_hook()) new_module.load_state_dict(new_sd) - setattr(model, n, new_module) - + new_module.to(old_sd['weight'].device) + #new_module.load_state_dict(new_sd) + return new_module + return module + return recurse_preorder(model, _replace_layers_fact_with_conv) +#TODO: VIVIAN TEST THIS def turn_off_covar_grad(model, covariance): ''' Turn off gradients in tri1_vec or tri2_vec to turn off channel or spatial covariance learning ''' - for n, module in model.named_children(): - if len(list(module.children())) > 0: - ## compound module, go inside it - turn_off_covar_grad(module, covariance) + def _turn_off_covar_grad(module): if isinstance(module, FactConv2d): for name, param in module.named_parameters(): if covariance == "channel": @@ -130,25 +166,26 @@ def turn_off_covar_grad(model, covariance): if covariance == "spatial": if "tri2_vec" in name: param.requires_grad = False - + return module + return recurse_preorder(model, _turn_off_covar_grad) + def turn_off_backbone_grad(model): ''' Turn off gradients in backbone. 
For tuning just classifier layer ''' - for n, module in model.named_children(): - if len(list(module.children())) > 0: - ## compound module, go inside it - turn_off_backbone_grad(module) - #else: + def _turn_off_backbone_grad(module): if isinstance(module, nn.Linear) and module.out_features == 10: grad=True else: grad=False for param in module.parameters(): param.requires_grad = grad + return None + return recurse_preorder(model, _turn_off_backbone_grad) +#TODO: VIVIAN MODIFY THIS THEN TEST IT def init_V1_layers(model, bias): ''' Initialize every FactConv2d layer with V1-inspired diff --git a/refactor/refactor_rainbow.py b/refactor/refactor_rainbow.py index 4ec99ab..2f73871 100644 --- a/refactor/refactor_rainbow.py +++ b/refactor/refactor_rainbow.py @@ -21,7 +21,8 @@ from models.resnet import ResNet18 from conv_modules import FactConv2d from models.function_utils import replace_layers_factconv2d,\ -replace_layers_scale, replace_layers_fact_with_conv, turn_off_backbone_grad +replace_layers_scale, replace_layers_fact_with_conv, turn_off_backbone_grad, \ +recurse_preorder def save_model(args, model): src = "/home/mila/m/muawiz.chaudhary/scratch/v1-models/saved-models/refactoring/" @@ -103,7 +104,6 @@ def save_model(args, model): net=ResNet18() -net.to(device) replace_layers_scale(net, args.width) if args.fact: replace_layers_factconv2d(net) @@ -122,6 +122,7 @@ def save_model(args, model): replace_layers_fact_with_conv(net) net.to(device) +print(net) net.train() net_new.train() @@ -213,8 +214,9 @@ def our_rainbow_sampling(model, new_model): new_module = conv_ACA(m1, m2, new_module) # converts fact conv to conv. this is for sake of speed. 
#if isinstance(new_module, FactConv2d): - # new_module = fact_2_conv(new_module) - #changes the network module + # new_module = replace_layers_fact_with_conv(new_module) + # #new_module = fact_2_conv(new_module) + # changes the network module setattr(new_model, n1, new_module) #only computes the ACA From 44da047a0f4fc11a313b53f2855833e10f698c88 Mon Sep 17 00:00:00 2001 From: Muawiz Chaudhary Date: Wed, 24 Apr 2024 12:45:36 -0400 Subject: [PATCH 29/77] removed fact_2_conv --- refactor/refactor_rainbow.py | 50 ------------------------------------ 1 file changed, 50 deletions(-) diff --git a/refactor/refactor_rainbow.py b/refactor/refactor_rainbow.py index 2f73871..23f459d 100644 --- a/refactor/refactor_rainbow.py +++ b/refactor/refactor_rainbow.py @@ -521,56 +521,6 @@ def colored_Covariance_Specification(m1, m2, new_module, Un=None, Sn=None, Vn=No return new_module -def fact_2_conv(new_module): - ## simple module - print("Replacing FactConv") - fact_module = nn.Conv2d( - in_channels=new_module.in_channels, - out_channels=new_module.out_channels, - kernel_size=new_module.kernel_size, - stride=new_module.stride, padding=new_module.padding, - bias=True if new_module.bias is not None else False) - - old_sd = new_module.state_dict() - new_sd = fact_module.state_dict() - - if new_module.bias is not None: - new_sd['bias'] = old_sd['bias'] - - U1 = new_module._tri_vec_to_mat(new_module.tri1_vec, new_module.in_channels // - new_module.groups, new_module.scat_idx1) - U2 = new_module._tri_vec_to_mat(new_module.tri2_vec, - new_module.kernel_size[0] * new_module.kernel_size[1], - new_module.scat_idx2) - U = torch.kron(U1, U2) - - matrix_shape = (new_module.out_channels, new_module.in_features) - composite_weight = torch.reshape( - torch.reshape(new_module.weight, matrix_shape) @ U, - new_module.weight.shape - ) - - new_sd['weight'] = composite_weight - if 'weight_align' in old_sd.keys(): - new_sd['weight_align'] = old_sd['weight_align'] - shape = 
fact_module.in_channels*fact_module.kernel_size[0]*fact_module.kernel_size[1] - fact_module.register_buffer("weight_align",torch.zeros((shape, shape))) - if 'input_align' in old_sd.keys(): - new_sd['input_align'] = old_sd['input_align'] - out_shape = fact_module.in_channels - fact_module.register_buffer("input_align",torch.zeros((out_shape, out_shape))) - if new_module.in_channels != 3: - #fact check this - for key in list(new_module._forward_pre_hooks.keys()): - del new_module._forward_pre_hooks[key] - hook_handle_pre_forward = fact_module.register_forward_pre_hook(return_hook()) - fact_module.load_state_dict(new_sd) - fact_module.to(device) - new_module = fact_module - print("FACT REPLACEMENT:", new_module) - return new_module - - def train(epoch, net): print('\nEpoch: %d' % epoch) net.train() From a4bf746eed6dd5bbd1ab04dc7f5c7bd67876e5b2 Mon Sep 17 00:00:00 2001 From: Muawiz Chaudhary Date: Wed, 24 Apr 2024 12:57:21 -0400 Subject: [PATCH 30/77] changes --- refactor/models/function_utils.py | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/refactor/models/function_utils.py b/refactor/models/function_utils.py index cd509c4..789ad4b 100644 --- a/refactor/models/function_utils.py +++ b/refactor/models/function_utils.py @@ -35,7 +35,6 @@ def _replace_layers_factconv2d(module): new_sd['bias'] = old_sd['bias'] new_module.load_state_dict(new_sd) return new_module - return module return recurse_preorder(model, _replace_layers_factconv2d) @@ -53,7 +52,6 @@ def _replace_affines(module): affine=False, track_running_stats=module.track_running_stats) return new_module - return module return recurse_preorder(model, _replace_affines) @@ -83,7 +81,6 @@ def _replace_layers_scale(module): if isinstance(module, nn.Linear): new_module = nn.Linear(int(module.in_features * scale), 10) return new_module - return module return recurse_preorder(model, _replace_layers_scale) @@ -101,8 +98,6 @@ def hook(mod, inputs): return hook - -#TODO: MUAWIZ TEST THIS def 
replace_layers_fact_with_conv(model): ''' Replace FactConv2d layers with nn.Conv2d @@ -146,11 +141,10 @@ def _replace_layers_fact_with_conv(module): hook_handle_pre_forward = new_module.register_forward_pre_hook(return_hook()) new_module.load_state_dict(new_sd) new_module.to(old_sd['weight'].device) - #new_module.load_state_dict(new_sd) return new_module - return module return recurse_preorder(model, _replace_layers_fact_with_conv) + #TODO: VIVIAN TEST THIS def turn_off_covar_grad(model, covariance): ''' @@ -166,7 +160,6 @@ def _turn_off_covar_grad(module): if covariance == "spatial": if "tri2_vec" in name: param.requires_grad = False - return module return recurse_preorder(model, _turn_off_covar_grad) @@ -181,7 +174,6 @@ def _turn_off_backbone_grad(module): grad=False for param in module.parameters(): param.requires_grad = grad - return None return recurse_preorder(model, _turn_off_backbone_grad) From fff08993a2a2f84e76da76b34b65275cc23d365e Mon Sep 17 00:00:00 2001 From: vivianwhite Date: Wed, 24 Apr 2024 16:42:22 -0400 Subject: [PATCH 31/77] did TODOs --- refactor/models/function_utils.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/refactor/models/function_utils.py b/refactor/models/function_utils.py index 789ad4b..366d01c 100644 --- a/refactor/models/function_utils.py +++ b/refactor/models/function_utils.py @@ -38,7 +38,6 @@ def _replace_layers_factconv2d(module): return recurse_preorder(model, _replace_layers_factconv2d) -#TODO: VIVIAN TEST THIS def replace_affines(model): ''' Set BatchNorm2d layers to have 'affine=False' @@ -145,7 +144,6 @@ def _replace_layers_fact_with_conv(module): return recurse_preorder(model, _replace_layers_fact_with_conv) -#TODO: VIVIAN TEST THIS def turn_off_covar_grad(model, covariance): ''' Turn off gradients in tri1_vec or tri2_vec to turn off @@ -177,24 +175,20 @@ def _turn_off_backbone_grad(module): return recurse_preorder(model, _turn_off_backbone_grad) -#TODO: VIVIAN MODIFY THIS THEN TEST IT def 
init_V1_layers(model, bias): ''' Initialize every FactConv2d layer with V1-inspired spatial weight init ''' - for n, module in model.named_children(): - if len(list(module.children())) > 0: - ## compound module, go inside it - init_V1_layers(module, bias) + def _init_V1_layers(module): if isinstance(module, FactConv2d): center = ((module.kernel_size[0] - 1) / 2, (module.kernel_size[1] - 1) / 2) V1_init(module, size=2, spatial_freq=0.1, scale=1, center=center) for name, param in module.named_parameters(): if "weight" in name: param.requires_grad = False - if bias: if "bias" in name: param.requires_grad = False + return recurse_preorder(model, _init_V1_layers) From f36a459c68ad619082e035cfbf59eadd22fed849 Mon Sep 17 00:00:00 2001 From: vivianwhite Date: Tue, 30 Apr 2024 22:54:39 -0400 Subject: [PATCH 32/77] noting origin of code --- refactor/V1_covariance.py | 100 ++++++++++++++++++++++++++++++ refactor/conv_modules.py | 5 +- refactor/models/function_utils.py | 2 +- 3 files changed, 105 insertions(+), 2 deletions(-) create mode 100644 refactor/V1_covariance.py diff --git a/refactor/V1_covariance.py b/refactor/V1_covariance.py new file mode 100644 index 0000000..448647d --- /dev/null +++ b/refactor/V1_covariance.py @@ -0,0 +1,100 @@ +""" +The following code is copied from the Structured Random Features library, +https://github.com/glomerulus-lab/structured-random-features, +used under the following license: + +The MIT License (MIT) +Copyright (c) 2021, Biraj Pandey + +Permission is hereby granted, free of charge, to any person obtaining a copy of +this software and associated documentation files (the "Software"), to deal in +the Software without restriction, including without limitation the rights to +use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies +of the Software, and to permit persons to whom the Software is furnished to do +so, subject to the following conditions: + + The above copyright notice and this permission notice shall be 
included in + all copies or substantial portions of the Software. + +""" + +import torch +from torch import Tensor +import numpy as np +import numpy.linalg as la +from scipy.spatial.distance import pdist, squareform + + +def V1_covariance_matrix(dim, size, spatial_freq, center, scale=1): + """ + Generates the covariance matrix for Gaussian Process with non-stationary + covariance. This matrix will be used to generate random + features inspired from the receptive-fields of V1 neurons. + + C(x, y) = exp(-|x - y|/(2 * spatial_freq))^2 * exp(-|x - m| / (2 * size))^2 * exp(-|y - m| / (2 * size))^2 + + Parameters + ---------- + + dim : tuple of shape (2, 1) + Dimension of random features. + + size : float + Determines the size of the random weights + + spatial_freq : float + Determines the spatial frequency of the random weights + + center : tuple of shape (2, 1) + Location of the center of the random weights. + + scale: float, default=1 + Normalization factor for Tr norm of cov matrix + + Returns + ------- + + C : array-like of shape (dim[0] * dim[1], dim[0] * dim[1]) + covariance matrix w/ Tr norm = scale * dim[0] * dim[1] + """ + + x = np.arange(dim[0]) + y = np.arange(dim[1]) + yy, xx = np.meshgrid(y, x) + grid = np.column_stack((xx.flatten(), yy.flatten())) + + a = squareform(pdist(grid, 'sqeuclidean')) + b = la.norm(grid - center, axis=1) ** 2 + c = b.reshape(-1, 1) + C = np.exp(-a / (2 * spatial_freq ** 2)) * np.exp(-b / (2 * size ** 2)) * np.exp(-c / (2 * size ** 2)) \ + + 1e-5 * np.eye(dim[0] * dim[1]) + C *= scale * dim[0] * dim[1] / np.trace(C) + return C + + +def V1_init(layer, size, spatial_freq, center, scale=1., bias=False, seed=None, + device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')): + ''' + Initialization for FactConv2d + ''' + + classname = layer.__class__.__name__ + assert classname.find('FactConv2d') != -1, 'This init only works for FactConv2d layers' + assert center is not None, "center needed" + + out_channels, in_channels, 
xdim, ydim = layer.weight.shape + dim = (xdim, ydim) + + C_patch = Tensor(V1_covariance_matrix(dim, size, spatial_freq, center, scale)).to(device) + U_patch = torch.linalg.cholesky(C_patch, upper=True) + n = U_patch.shape[0] + # replace diagonal with logarithm for parameterization + log_diag = torch.log(torch.diagonal(U_patch)) + U_patch[range(n), range(n)] = log_diag + # form vector of upper triangular entries + tri_vec = U_patch[torch.triu_indices(n, n, device=device).tolist()].ravel() + with torch.no_grad(): + layer.tri2_vec.copy_(tri_vec) + + if bias == False: + layer.bias = None diff --git a/refactor/conv_modules.py b/refactor/conv_modules.py index 847bab9..98b2a49 100644 --- a/refactor/conv_modules.py +++ b/refactor/conv_modules.py @@ -5,7 +5,10 @@ from torch.nn.common_types import _size_2_t from typing import Optional, List, Tuple, Union - +""" +The function below is copied directly from +https://bonnerlab.github.io/ccn-tutorial/pages/analyzing_neural_networks.html +""" def _contract(tensor, matrix, axis): """tensor is (..., D, ...), matrix is (P, D), returns (..., P, ...).""" t = torch.moveaxis(tensor, source=axis, destination=-1) # (..., D) diff --git a/refactor/models/function_utils.py b/refactor/models/function_utils.py index 366d01c..6614ec0 100644 --- a/refactor/models/function_utils.py +++ b/refactor/models/function_utils.py @@ -1,7 +1,7 @@ import torch import torch.nn as nn from conv_modules import FactConv2d -from learnable_cov import V1_init +from V1_covariance import V1_init def recurse_preorder(model, callback): From 38de47115ce80ee5ac799af3a4f1a5b3a9973fe8 Mon Sep 17 00:00:00 2001 From: Muawiz Chaudhary Date: Thu, 25 Apr 2024 17:46:33 -0400 Subject: [PATCH 33/77] added class conditional decerators. 
added looping over multiple callbacks, could be better coded --- refactor/models/function_utils.py | 130 ++++++++++++++++++++---------- 1 file changed, 86 insertions(+), 44 deletions(-) diff --git a/refactor/models/function_utils.py b/refactor/models/function_utils.py index 6614ec0..cf54465 100644 --- a/refactor/models/function_utils.py +++ b/refactor/models/function_utils.py @@ -15,27 +15,47 @@ def recurse_preorder(model, callback): return model -def replace_layers_factconv2d(model): +def recurse_preorder_v2(callback): + def _preorder_recursive_invoker(model): + r = callback(model) + if r is not model and r is not None: + return r + for n, module in model.named_children(): + r = _preorder_recursive_invoker(module) + if r is not module and r is not None: + setattr(model, n, r) + return model + return _preorder_recursive_invoker + + +def ifisinstance(klass): + def _make_conditional(callback): + def _conditional_invoker(model): + if isinstance(model, klass): + return callback(model) + return _conditional_invoker + return _make_conditional + + +@recurse_preorder_v2 +@ifisinstance(nn.Conv2d) +def replace_layers_factconv2d(module): ''' Replace nn.Conv2d layers with FactConv2d ''' - def _replace_layers_factconv2d(module): - if isinstance(module, nn.Conv2d): - ## simple module - new_module = FactConv2d( - in_channels=module.in_channels, - out_channels=module.out_channels, - kernel_size=module.kernel_size, - stride=module.stride, padding=module.padding, - bias=True if module.bias is not None else False) - old_sd = module.state_dict() - new_sd = new_module.state_dict() - new_sd['weight'] = old_sd['weight'] - if module.bias is not None: - new_sd['bias'] = old_sd['bias'] - new_module.load_state_dict(new_sd) - return new_module - return recurse_preorder(model, _replace_layers_factconv2d) + new_module = FactConv2d( + in_channels=module.in_channels, + out_channels=module.out_channels, + kernel_size=module.kernel_size, + stride=module.stride, padding=module.padding, + bias=True 
if module.bias is not None else False) + old_sd = module.state_dict() + new_sd = new_module.state_dict() + new_sd['weight'] = old_sd['weight'] + if module.bias is not None: + new_sd['bias'] = old_sd['bias'] + new_module.load_state_dict(new_sd) + return new_module def replace_affines(model): @@ -44,7 +64,6 @@ def replace_affines(model): ''' def _replace_affines(module): if isinstance(module, nn.BatchNorm2d): - ## simple module new_module = nn.BatchNorm2d( num_features=module.num_features, eps=module.eps, momentum=module.momentum, @@ -54,35 +73,58 @@ def _replace_affines(module): return recurse_preorder(model, _replace_affines) +def replace_layers_conv_scale(scale): + @ifisinstance(nn.Conv2d) + def _replace_layers_conv_scale(module): + if module.in_channels == 3: + in_scale = 1 + else: + in_scale = scale + new_module = nn.Conv2d( + in_channels=int(module.in_channels*in_scale), + out_channels=int(module.out_channels*scale), + kernel_size=module.kernel_size, + stride=module.stride, padding=module.padding, + groups = module.groups, + bias=True if module.bias is not None else False) + return new_module + return _replace_layers_conv_scale + + +def replace_layers_bn_scale(scale): + @ifisinstance(nn.BatchNorm2d) + def _replace_layers_bn_scale(module): + new_module = nn.BatchNorm2d(int(module.num_features*scale), affine=module.affine) + return new_module + return _replace_layers_bn_scale + + +def replace_layers_linear_scale(scale): + @ifisinstance(nn.Linear) + def _replace_layers_linear_scale(module): + new_module = nn.Linear(int(module.in_features * scale), 10) + return new_module + return _replace_layers_linear_scale + + + def replace_layers_scale(model, scale=1): ''' Replace nn.Conv2d layers with a different scale ''' - def _replace_layers_scale(module): - if isinstance(module, nn.Conv2d): - if module.in_channels == 3: - in_scale = 1 - else: - in_scale = scale - ## simple module - new_module = nn.Conv2d( - in_channels=int(module.in_channels*in_scale), - 
out_channels=int(module.out_channels*scale), - kernel_size=module.kernel_size, - stride=module.stride, padding=module.padding, - groups = module.groups, - bias=True if module.bias is not None else False) - return new_module - if isinstance(module, nn.BatchNorm2d): - new_module = nn.BatchNorm2d(int(module.num_features*scale), - affine=module.affine) - return new_module - if isinstance(module, nn.Linear): - new_module = nn.Linear(int(module.in_features * scale), 10) - return new_module - return recurse_preorder(model, _replace_layers_scale) - - + callback_list = [replace_layers_conv_scale(scale), + replace_layers_bn_scale(scale), replace_layers_linear_scale(scale)] + + @recurse_preorder_v2 + def _replace_layers_scale(model): + for callback in callback_list: + r = callback(model) + if r is not None and r is not model: + return r + return model + + return _replace_layers_scale(model) + #used in activation cross-covariance calculation #input align hook From 0cca6a1efe744ef5c2ead22aaac5b4c1dbb8de9d Mon Sep 17 00:00:00 2001 From: Muawiz Chaudhary Date: Fri, 26 Apr 2024 06:51:12 -0400 Subject: [PATCH 34/77] Revert to 09f4833 --- refactor/models/function_utils.py | 130 ++++++++++-------------------- 1 file changed, 44 insertions(+), 86 deletions(-) diff --git a/refactor/models/function_utils.py b/refactor/models/function_utils.py index cf54465..6614ec0 100644 --- a/refactor/models/function_utils.py +++ b/refactor/models/function_utils.py @@ -15,47 +15,27 @@ def recurse_preorder(model, callback): return model -def recurse_preorder_v2(callback): - def _preorder_recursive_invoker(model): - r = callback(model) - if r is not model and r is not None: - return r - for n, module in model.named_children(): - r = _preorder_recursive_invoker(module) - if r is not module and r is not None: - setattr(model, n, r) - return model - return _preorder_recursive_invoker - - -def ifisinstance(klass): - def _make_conditional(callback): - def _conditional_invoker(model): - if isinstance(model, 
klass): - return callback(model) - return _conditional_invoker - return _make_conditional - - -@recurse_preorder_v2 -@ifisinstance(nn.Conv2d) -def replace_layers_factconv2d(module): +def replace_layers_factconv2d(model): ''' Replace nn.Conv2d layers with FactConv2d ''' - new_module = FactConv2d( - in_channels=module.in_channels, - out_channels=module.out_channels, - kernel_size=module.kernel_size, - stride=module.stride, padding=module.padding, - bias=True if module.bias is not None else False) - old_sd = module.state_dict() - new_sd = new_module.state_dict() - new_sd['weight'] = old_sd['weight'] - if module.bias is not None: - new_sd['bias'] = old_sd['bias'] - new_module.load_state_dict(new_sd) - return new_module + def _replace_layers_factconv2d(module): + if isinstance(module, nn.Conv2d): + ## simple module + new_module = FactConv2d( + in_channels=module.in_channels, + out_channels=module.out_channels, + kernel_size=module.kernel_size, + stride=module.stride, padding=module.padding, + bias=True if module.bias is not None else False) + old_sd = module.state_dict() + new_sd = new_module.state_dict() + new_sd['weight'] = old_sd['weight'] + if module.bias is not None: + new_sd['bias'] = old_sd['bias'] + new_module.load_state_dict(new_sd) + return new_module + return recurse_preorder(model, _replace_layers_factconv2d) def replace_affines(model): @@ -64,6 +44,7 @@ def replace_affines(model): ''' def _replace_affines(module): if isinstance(module, nn.BatchNorm2d): + ## simple module new_module = nn.BatchNorm2d( num_features=module.num_features, eps=module.eps, momentum=module.momentum, @@ -73,58 +54,35 @@ def _replace_affines(module): return recurse_preorder(model, _replace_affines) -def replace_layers_conv_scale(scale): - @ifisinstance(nn.Conv2d) - def _replace_layers_conv_scale(module): - if module.in_channels == 3: - in_scale = 1 - else: - in_scale = scale - new_module = nn.Conv2d( - in_channels=int(module.in_channels*in_scale), - 
out_channels=int(module.out_channels*scale), - kernel_size=module.kernel_size, - stride=module.stride, padding=module.padding, - groups = module.groups, - bias=True if module.bias is not None else False) - return new_module - return _replace_layers_conv_scale - - -def replace_layers_bn_scale(scale): - @ifisinstance(nn.BatchNorm2d) - def _replace_layers_bn_scale(module): - new_module = nn.BatchNorm2d(int(module.num_features*scale), affine=module.affine) - return new_module - return _replace_layers_bn_scale - - -def replace_layers_linear_scale(scale): - @ifisinstance(nn.Linear) - def _replace_layers_linear_scale(module): - new_module = nn.Linear(int(module.in_features * scale), 10) - return new_module - return _replace_layers_linear_scale - - - def replace_layers_scale(model, scale=1): ''' Replace nn.Conv2d layers with a different scale ''' - callback_list = [replace_layers_conv_scale(scale), - replace_layers_bn_scale(scale), replace_layers_linear_scale(scale)] - - @recurse_preorder_v2 - def _replace_layers_scale(model): - for callback in callback_list: - r = callback(model) - if r is not None and r is not model: - return r - return model - - return _replace_layers_scale(model) - + def _replace_layers_scale(module): + if isinstance(module, nn.Conv2d): + if module.in_channels == 3: + in_scale = 1 + else: + in_scale = scale + ## simple module + new_module = nn.Conv2d( + in_channels=int(module.in_channels*in_scale), + out_channels=int(module.out_channels*scale), + kernel_size=module.kernel_size, + stride=module.stride, padding=module.padding, + groups = module.groups, + bias=True if module.bias is not None else False) + return new_module + if isinstance(module, nn.BatchNorm2d): + new_module = nn.BatchNorm2d(int(module.num_features*scale), + affine=module.affine) + return new_module + if isinstance(module, nn.Linear): + new_module = nn.Linear(int(module.in_features * scale), 10) + return new_module + return recurse_preorder(model, _replace_layers_scale) + + #used in 
activation cross-covariance calculation #input align hook From 42f7542356e49e58b177e42cf2875b0a0c6a6b06 Mon Sep 17 00:00:00 2001 From: Muawiz Chaudhary Date: Fri, 26 Apr 2024 11:57:41 -0400 Subject: [PATCH 35/77] seperated rainbow functionality from runner script --- refactor/rainbow.py | 404 +++++++++++++++++++++++++++++++++++ refactor/refactor_rainbow.py | 400 +--------------------------------- 2 files changed, 412 insertions(+), 392 deletions(-) create mode 100644 refactor/rainbow.py diff --git a/refactor/rainbow.py b/refactor/rainbow.py new file mode 100644 index 0000000..b26ce1c --- /dev/null +++ b/refactor/rainbow.py @@ -0,0 +1,404 @@ +import torch +import torch.nn as nn + +import copy + + + +from conv_modules import FactConv2d +#traditional way of calculating svd. can be a bit unstable sometimes tho +def calc_svd(A, name=''): + u, s, vh = torch.linalg.svd( + A, full_matrices=False, + ) # (C_in_reference, R), (R,), (R, C_in_generated) + alignment = u @ vh # (C_in_reference, C_in_generated) + return alignment, (u, s, vh) + + +#used in activation cross-covariance calculation +#input align hook +def return_hook(): + def hook(mod, inputs): + shape = inputs[0].shape + inputs_permute = inputs[0].permute(1,0,2,3).reshape(inputs[0].shape[1], -1) + reshape = (mod.input_align@inputs_permute).reshape(shape[1], + shape[0], shape[2], + shape[3]).permute(1, 0, 2, 3) + return reshape + return hook + + +#settings +# our, wa true, aca true (fact, conv) +# our wa false, aca true (fact, conv) +# our wa false aca false (fact, conv) [Just Random] +# our was true aca false (fact, conv) +# +# theirs wa true, aca true (conv) +# theirs wa false, aca true (conv) +# theirs wa false aca false (conv) [Just Random] +# theirs was true aca false (conv) +# +# this function does not do an explicit specification of the colored covariance +class rainbow_sampler: + def __init__(self, net, net_new, args, device, trainloader): + self.net = copy.deepcopy(net) + self.net_new = copy.deepcopy(net_new) 
+ self.sampling = args.sampling + self.wa = args.wa + self.in_wa = args.in_wa + self.aca = args.aca + self.device = device + self.trainloader = trainloader + + def do_rainbow_sampling(self): + self.our_rainbow_sampling(self.net, self.net_new) + + @torch.no_grad() + def our_rainbow_sampling(self, model, new_model): + for (n1, m1), (n2, m2) in zip(model.named_children(), new_model.named_children()): + if len(list(m1.children())) > 0: + self.our_rainbow_sampling(m1, m2) + if isinstance(m1, nn.Conv2d): + if isinstance(m2, FactConv2d): + new_module = FactConv2d( + in_channels=m2.in_channels, + out_channels=m2.out_channels, + kernel_size=m2.kernel_size, + stride=m2.stride, padding=m2.padding, + bias=True if m2.bias is not None else + False).to(self.device) + else: + new_module = nn.Conv2d( + in_channels=int(m2.in_channels), + out_channels=int(m2.out_channels), + kernel_size=m2.kernel_size, + stride=m2.stride, padding=m2.padding, + groups = m2.groups, + bias=True if m2.bias is not None else False).to(self.device) + + if self.sampling == 'ours' and self.wa: + # right now this function does not do an explicit specification of the colored covariance + new_module = self.weight_Alignment(m1, m2, new_module, in_dim=self.in_wa) + if self.sampling == 'theirs': + # for conv only + if self.wa: + new_module = weight_Alignment_With_CC(m1, m2, new_module) + else: + new_module = colored_Covariance_Specification(m1, m2, new_module) + # this step calculates the activation cross-covariance alignment (ACA) + if m1.in_channels != 3 and self.aca: + new_module = self.conv_ACA(m1, m2, new_module) + # converts fact conv to conv. this is for sake of speed. 
+ #if isinstance(new_module, FactConv2d): + # new_module = replace_layers_fact_with_conv(new_module) + # #new_module = fact_2_conv(new_module) + # changes the network module + setattr(new_model, n1, new_module) + + #only computes the ACA + if isinstance(m1, nn.Linear) and self.aca: + new_module = self.linear_ACA(m1, m2, new_model) + setattr(new_model, n1, new_module) + ##just run stats through + if isinstance(m1, nn.BatchNorm2d): + self.batchNorm_stats_recalc(m1, m2) + + + def conv_ACA(self, m1, m2, new_module): + print("Convolutional Input Activations Alignment") + activation = [] + other_activation = [] + # this hook grabs the input activations of the conv layer + # rearanges the vector so that the width by height dim is + # additional samples to the covariance + # bwh x c + def define_hook(m): + def store_hook(mod, inputs, outputs): + #from bonner lab tutorial + x = inputs[0] + x = x.permute(0, 2, 3, 1) + x = x.reshape((-1, x.shape[-1])) + activation.append(x) + raise Exception("Done") + return store_hook + + hook_handle_1 = m1.register_forward_hook(define_hook(m1)) + hook_handle_2 = m2.register_forward_hook(define_hook(m2)) + + print("Starting Sample Cross-Covariance Calculation") + covar = None + total = 0 + for batch_idx, (inputs, targets) in enumerate(self.trainloader): + inputs, targets = inputs.to(self.device), targets.to(self.device) + try: + outputs1 = self.net(inputs) + except Exception: + pass + try: + outputs2 = self.net_new(inputs) + except Exception: + pass + total+= inputs.shape[0] + if covar is None: + #activation is bwh x c + covar = activation[0].T @ activation[1] + assert (covar.isfinite().all()) + else: + #activation is bwh x c + covar += activation[0].T @ activation[1] + assert (covar.isfinite().all()) + activation = [] + other_activation = [] + + #c x c + covar /= total + hook_handle_1.remove() + hook_handle_2.remove() + print("Sample Cross-Covariance Calculation finished") + align, _ = calc_svd(covar, name="Cross-Covariance") + 
new_module.register_buffer("input_align", align) + + # this hook takes the input to the conv, aligns, then returns + # to the conv the aligned inputs + hook_handle_pre_forward = new_module.register_forward_pre_hook(return_hook()) + return new_module + + + def batchNorm_stats_recalc(self, m1, m2): + print("Calculating Batch Statistics") + m1.train() + m2.train() + m1.reset_running_stats() + m2.reset_running_stats() + handle_1 = m1.register_forward_hook(lambda mod, inputs, outputs: Exception("Done")) + handle_2 = m2.register_forward_hook(lambda mod, inputs, outputs: Exception("Done")) + m1.to(self.device) + m2.to(self.device) + for batch_idx, (inputs, targets) in enumerate(self.trainloader): + inputs, targets = inputs.to(self.device), targets.to(self.device) + try: + outputs1 = self.net(inputs) + except Exception: + pass + try: + outputs2 = self.net_new(inputs) + except Exception: + pass + handle_1.remove() + handle_2.remove() + m1.eval() + m2.eval() + print("Batch Statistics Calculation Finished") + + + def linear_ACA(self, m1, m2, new_model): + print("Linear Input Activations Alignment") + new_module = nn.Linear(m1.in_features, m1.out_features, bias=True + if m1.bias is not None else False).to(self.device) + ref_sd = m1.state_dict() + loading_sd = new_module.state_dict() + loading_sd['weight'] = ref_sd['weight'] + if m1.bias is not None: + loading_sd['bias'] = ref_sd['bias'] + activation = [] + other_activation = [] + + hook_handle_1 = m1.register_forward_hook(lambda mod, inputs, outputs: + activation.append(inputs[0])) + + hook_handle_2 = m2.register_forward_hook(lambda mod, inputs, outputs: + other_activation.append(inputs[0])) + covar = None + total = 0 + print("Starting Sample Cross-Covariance Calculation") + for batch_idx, (inputs, targets) in enumerate(self.trainloader): + inputs, targets = inputs.to(self.device), targets.to(self.device) + outputs1 = self.net(inputs) + outputs2 = self.net_new(inputs) + total+= inputs.shape[0] + if covar is None: + covar = 
activation[0].T @ other_activation[0] + else: + covar += activation[0].T @ other_activation[0] + activation = [] + other_activation = [] + covar /= total + + hook_handle_1.remove() + hook_handle_2.remove() + + print("Sample Cross-Covariance Calculation finished") + align, _ = calc_svd(covar, name="Cross-Covariance") + new_weight = loading_sd['weight'] + new_weight = torch.moveaxis(new_weight, source=1, + destination=-1) + new_weight = new_weight@align + loading_sd['weight'] = torch.moveaxis(new_weight, source=-1, + destination=1) + new_module.load_state_dict(loading_sd) + return new_module + + + + def weight_Alignment(self,m1, m2, new_module, in_dim=True): + # reference model state dict + ref_sd = m1.state_dict() + # generated model state dict - uses reference model weights. for now + gen_sd = m2.state_dict() + + # module with random init - to be loaded to model + loading_sd = new_module.state_dict() + new_gaussian = loading_sd['weight'] + + # carry over old bias. only matters when we work with no batchnorm networks + if m1.bias is not None: + loading_sd['bias'] = ref_sd['bias'] + # carry over old colored covariance. only matters with fact-convs + if "tri1_vec" in gen_sd.keys(): + loading_sd['tri1_vec']=gen_sd['tri1_vec'] + loading_sd['tri2_vec']=gen_sd['tri2_vec'] + + #this is the spot where + # we can do weight alignment + # for fact net, this means aligning with the random noise + # for conv net, this could mean aligning with a. W OR b. U + # we can do colored-covariance specification + # for fact net, this means just using it's R matrix + # for conv net, this could mean doing nothing (if aligning with W), or use S and V if we did b. 
+ # in this function, we just align with W and don't specify the mulit-color covariance + + # IF FACT: we align the generated factnet with the reference fact net's noise + # IF CONV: we align the generated convnet with the reference conv net's weight matrix + reference_weight = gen_sd['weight'] + generated_weight = new_gaussian + + #reshape to outdim x indim*spatial + reference_weight = reference_weight.reshape(reference_weight.shape[0], -1) + generated_weight = generated_weight.reshape(generated_weight.shape[0], -1) + #compute transpose, giving indim*spatial x outdim + + #compute weight cross-covariance indim*spatial x indim*spatial + #TODO REFACTOR TO HAVE REF FIRST. OUTDIM x OUTDIM + if in_dim: + print("Input Weight Alignment") + weight_cov = (generated_weight.T@reference_weight) + alignment, _ = calc_svd(weight_cov, name="Weight alignment") + + # outdim x indim x spatial + final_gen_weight = new_gaussian + # outdim x indim*spatial + final_gen_weight = final_gen_weight.reshape(final_gen_weight.shape[0], -1) + # outdim x indim*spatial + final_gen_weight = final_gen_weight@alignment + else: + print("Output Weight Alignment") + weight_cov = (reference_weight@generated_weight.T) + alignment, _ = calc_svd(weight_cov, name="Weight alignment") + + # outdim x indim x spatial + final_gen_weight = new_gaussian + # outdim x indim*spatial + final_gen_weight = final_gen_weight.reshape(final_gen_weight.shape[0], -1) + # outdim x indim*spatial + final_gen_weight = alignment@final_gen_weight + + loading_sd['weight'] = final_gen_weight.reshape(ref_sd['weight'].shape) + loading_sd['weight_align'] = alignment + new_module.register_buffer("weight_align", alignment) + new_module.load_state_dict(loading_sd) + return new_module + + + + + def weight_Alignment_With_CC(self, m1, m2, new_module, Un=None, Sn=None, Vn=None): + print("NOT SUPPOSED TO BE HERE") + # reference model state dict + ref_sd = m1.state_dict() + # generated model state dict - uses reference model weights. 
for now + gen_sd = m2.state_dict() + + # module with random init - to be loaded to model + loading_sd = new_module.state_dict() + new_gaussian = loading_sd['weight'] + + # carry over old bias. only matters when we work with no batchnorm networks + if m1.bias is not None: + loading_sd['bias'] = ref_sd['bias'] + # carry over old colored covariance. only matters with fact-convs + if "tri1_vec" in gen_sd.keys(): + loading_sd['tri1_vec']=gen_sd['tri1_vec'] + loading_sd['tri2_vec']=gen_sd['tri2_vec'] + + old_weight = ref_sd['weight'] + A = old_weight.reshape(old_weight.shape[0], -1) + A_T_A = A.T@A + V_val, Vn = torch.linalg.eigh(A_T_A) + del A_T_A + V_val = V_val.flip(0) + Vn = Vn.fliplr().T + Sn = (1e-6 + V_val.abs()).sqrt() + Sn_inv = (1/Sn).diag() + Un = A @ Vn.T @ Sn_inv + white_gaussian = torch.randn_like(Un) + + copy_weight = Un + copy_weight_gen = white_gaussian + copy_weight = copy_weight.reshape(copy_weight.shape[0], -1) + copy_weight_gen = copy_weight_gen.reshape(copy_weight_gen.shape[0], -1).T + weight_cov = (copy_weight_gen@copy_weight) + + alignment = calc_svd(weight_cov, name="Weight") + new_weight = white_gaussian + new_weight = new_weight.reshape(new_weight.shape[0], -1) + new_weight = new_weight@alignment # C_in_reference to C_in_generated + + new_module.register_buffer("weight_align", alignment) + loading_sd['weight_align'] = alignment + colored_gaussian = white_gaussian @ (Sn[:,None]* Vn) + loading_sd['weight'] = colored_gaussian.reshape(old_weight.shape) + new_module.load_state_dict(loading_sd) + return new_module + + + # this function does not do an explicit specification of the colored covariance + @torch.no_grad() + def colored_Covariance_Specification(self, m1, m2, new_module, Un=None, Sn=None, Vn=None): + print("NOT HERE") + # reference model state dict + ref_sd = m1.state_dict() + # generated model state dict - uses reference model weights. 
for now + gen_sd = m2.state_dict() + + # module with random init - to be loaded to model + loading_sd = new_module.state_dict() + new_gaussian = loading_sd['weight'] + + # carry over old bias. only matters when we work with no batchnorm networks + if m1.bias is not None: + loading_sd['bias'] = ref_sd['bias'] + # carry over old colored covariance. only matters with fact-convs + if "tri1_vec" in gen_sd.keys(): + loading_sd['tri1_vec']=gen_sd['tri1_vec'] + loading_sd['tri2_vec']=gen_sd['tri2_vec'] + + old_weight = ref_sd['weight'] + A = old_weight.reshape(old_weight.shape[0], -1) + A_T_A = A.T@A + V_val, Vn = torch.linalg.eigh(A_T_A) + del A_T_A + V_val = V_val.flip(0) + Vn = Vn.fliplr().T + Sn = (1e-6 + V_val.abs()).sqrt() + Sn_inv = (1/Sn).diag() + Un = A @ Vn.T @ Sn_inv + white_gaussian = torch.randn_like(Un) + + colored_gaussian = white_gaussian @ (Sn[:,None]* Vn) + loading_sd['weight'] = colored_gaussian.reshape(old_weight.shape) + new_module.load_state_dict(loading_sd) + return new_module + + + diff --git a/refactor/refactor_rainbow.py b/refactor/refactor_rainbow.py index 23f459d..121d5d7 100644 --- a/refactor/refactor_rainbow.py +++ b/refactor/refactor_rainbow.py @@ -23,6 +23,7 @@ from models.function_utils import replace_layers_factconv2d,\ replace_layers_scale, replace_layers_fact_with_conv, turn_off_backbone_grad, \ recurse_preorder +from rainbow import calc_svd, return_hook, rainbow_sampler def save_model(args, model): src = "/home/mila/m/muawiz.chaudhary/scratch/v1-models/saved-models/refactoring/" @@ -130,397 +131,6 @@ def save_model(args, model): set_seeds(args.seed) criterion = nn.CrossEntropyLoss() - -#traditional way of calculating svd. can be a bit unstable sometimes tho -def calc_svd(A, name=''): - u, s, vh = torch.linalg.svd( - A, full_matrices=False, - ) # (C_in_reference, R), (R,), (R, C_in_generated) - alignment = u @ vh # (C_in_reference, C_in_generated) - return alignment - - -#i've been finding this way of calculating svd to be more stable. 
-def calc_svd_eigh(A, name=''): - A_T_A = A.T@A - V_val, Vn = torch.linalg.eigh(A_T_A) - V_val = V_val.flip(0) - Vn = Vn.fliplr().T - Sn = (1e-6 + V_val.abs()).sqrt() - Sn_inv = (1/Sn).diag() - Un = A @ Vn.T @ Sn_inv - alignment = Un @ Vn - return alignment - - -#used in activation cross-covariance calculation -#input align hook -def return_hook(): - def hook(mod, inputs): - shape = inputs[0].shape - inputs_permute = inputs[0].permute(1,0,2,3).reshape(inputs[0].shape[1], -1) - reshape = (mod.input_align@inputs_permute).reshape(shape[1], - shape[0], shape[2], - shape[3]).permute(1, 0, 2, 3) - return reshape - return hook - - -#settings -# our, wa true, aca true (fact, conv) -# our wa false, aca true (fact, conv) -# our wa false aca false (fact, conv) [Just Random] -# our was true aca false (fact, conv) -# -# theirs wa true, aca true (conv) -# theirs wa false, aca true (conv) -# theirs wa false aca false (conv) [Just Random] -# theirs was true aca false (conv) -# -# this function does not do an explicit specification of the colored covariance -@torch.no_grad() -def our_rainbow_sampling(model, new_model): - for (n1, m1), (n2, m2) in zip(model.named_children(), new_model.named_children()): - if len(list(m1.children())) > 0: - our_rainbow_sampling(m1, m2) - if isinstance(m1, nn.Conv2d): - if isinstance(m2, FactConv2d): - new_module = FactConv2d( - in_channels=m2.in_channels, - out_channels=m2.out_channels, - kernel_size=m2.kernel_size, - stride=m2.stride, padding=m2.padding, - bias=True if m2.bias is not None else False) - else: - new_module = nn.Conv2d( - in_channels=int(m2.in_channels), - out_channels=int(m2.out_channels), - kernel_size=m2.kernel_size, - stride=m2.stride, padding=m2.padding, - groups = m2.groups, - bias=True if m2.bias is not None else False).to(device) - - if args.sampling == 'ours' and args.wa: - # right now this function does not do an explicit specification of the colored covariance - new_module = weight_Alignment(m1, m2, new_module, 
in_dim=args.in_wa) - if args.sampling == 'theirs': - # for conv only - if args.wa: - new_module = weight_Alignment_With_CC(m1, m2, new_module) - else: - new_module = colored_Covariance_Specification(m1, m2, new_module) - # this step calculates the activation cross-covariance alignment (ACA) - if m1.in_channels != 3 and args.aca: - new_module = conv_ACA(m1, m2, new_module) - # converts fact conv to conv. this is for sake of speed. - #if isinstance(new_module, FactConv2d): - # new_module = replace_layers_fact_with_conv(new_module) - # #new_module = fact_2_conv(new_module) - # changes the network module - setattr(new_model, n1, new_module) - - #only computes the ACA - if isinstance(m1, nn.Linear) and args.aca: - new_module = linear_ACA(m1, m2, new_model) - setattr(new_model, n1, new_module) - #just run stats through - if isinstance(m1, nn.BatchNorm2d): - batchNorm_stats_recalc(m1, m2) - - -def weight_Alignment(m1, m2, new_module, in_dim=True): - # reference model state dict - ref_sd = m1.state_dict() - # generated model state dict - uses reference model weights. for now - gen_sd = m2.state_dict() - - # module with random init - to be loaded to model - loading_sd = new_module.state_dict() - new_gaussian = loading_sd['weight'] - - # carry over old bias. only matters when we work with no batchnorm networks - if m1.bias is not None: - loading_sd['bias'] = ref_sd['bias'] - # carry over old colored covariance. only matters with fact-convs - if "tri1_vec" in gen_sd.keys(): - loading_sd['tri1_vec']=gen_sd['tri1_vec'] - loading_sd['tri2_vec']=gen_sd['tri2_vec'] - - #this is the spot where - # we can do weight alignment - # for fact net, this means aligning with the random noise - # for conv net, this could mean aligning with a. W OR b. U - # we can do colored-covariance specification - # for fact net, this means just using it's R matrix - # for conv net, this could mean doing nothing (if aligning with W), or use S and V if we did b. 
- # in this function, we just align with W and don't specify the mulit-color covariance - - # IF FACT: we align the generated factnet with the reference fact net's noise - # IF CONV: we align the generated convnet with the reference conv net's weight matrix - reference_weight = gen_sd['weight'] - generated_weight = new_gaussian - - #reshape to outdim x indim*spatial - reference_weight = reference_weight.reshape(reference_weight.shape[0], -1) - generated_weight = generated_weight.reshape(generated_weight.shape[0], -1) - #compute transpose, giving indim*spatial x outdim - - #compute weight cross-covariance indim*spatial x indim*spatial - #TODO REFACTOR TO HAVE REF FIRST. OUTDIM x OUTDIM - if in_dim: - print("Input Weight Alignment") - weight_cov = (generated_weight.T@reference_weight) - alignment = calc_svd(weight_cov, name="Weight alignment") - - # outdim x indim x spatial - final_gen_weight = new_gaussian - # outdim x indim*spatial - final_gen_weight = final_gen_weight.reshape(final_gen_weight.shape[0], -1) - # outdim x indim*spatial - final_gen_weight = final_gen_weight@alignment - else: - print("Output Weight Alignment") - weight_cov = (reference_weight@generated_weight.T) - alignment = calc_svd(weight_cov, name="Weight alignment") - - # outdim x indim x spatial - final_gen_weight = new_gaussian - # outdim x indim*spatial - final_gen_weight = final_gen_weight.reshape(final_gen_weight.shape[0], -1) - # outdim x indim*spatial - final_gen_weight = alignment@final_gen_weight - - loading_sd['weight'] = final_gen_weight.reshape(ref_sd['weight'].shape) - loading_sd['weight_align'] = alignment - new_module.register_buffer("weight_align", alignment) - new_module.load_state_dict(loading_sd) - return new_module - - -def conv_ACA(m1, m2, new_module): - print("Convolutional Input Activations Alignment") - activation = [] - other_activation = [] - # this hook grabs the input activations of the conv layer - # rearanges the vector so that the width by height dim is - # 
additional samples to the covariance - # bwh x c - def define_hook(m): - def store_hook(mod, inputs, outputs): - #from bonner lab tutorial - x = inputs[0] - x = x.permute(0, 2, 3, 1) - x = x.reshape((-1, x.shape[-1])) - activation.append(x) - raise Exception("Done") - return store_hook - - hook_handle_1 = m1.register_forward_hook(define_hook(m1)) - hook_handle_2 = m2.register_forward_hook(define_hook(m2)) - - print("Starting Sample Cross-Covariance Calculation") - covar = None - total = 0 - for batch_idx, (inputs, targets) in enumerate(trainloader): - inputs, targets = inputs.to(device), targets.to(device) - try: - outputs1 = net(inputs) - except Exception: - pass - try: - outputs2 = net_new(inputs) - except Exception: - pass - total+= inputs.shape[0] - if covar is None: - #activation is bwh x c - covar = activation[0].T @ activation[1] - assert (covar.isfinite().all()) - else: - #activation is bwh x c - covar += activation[0].T @ activation[1] - assert (covar.isfinite().all()) - activation = [] - other_activation = [] - - #c x c - covar /= total - hook_handle_1.remove() - hook_handle_2.remove() - print("Sample Cross-Covariance Calculation finished") - align = calc_svd(covar, name="Cross-Covariance") - new_module.register_buffer("input_align", align) - - # this hook takes the input to the conv, aligns, then returns - # to the conv the aligned inputs - hook_handle_pre_forward = new_module.register_forward_pre_hook(return_hook()) - return new_module - - -def linear_ACA(m1, m2, new_model): - print("Linear Input Activations Alignment") - new_module = nn.Linear(m1.in_features, m1.out_features, bias=True - if m1.bias is not None else False).to(device) - ref_sd = m1.state_dict() - loading_sd = new_module.state_dict() - loading_sd['weight'] = ref_sd['weight'] - if m1.bias is not None: - loading_sd['bias'] = ref_sd['bias'] - activation = [] - other_activation = [] - - hook_handle_1 = m1.register_forward_hook(lambda mod, inputs, outputs: - activation.append(inputs[0])) - - 
hook_handle_2 = m2.register_forward_hook(lambda mod, inputs, outputs: - other_activation.append(inputs[0])) - covar = None - total = 0 - print("Starting Sample Cross-Covariance Calculation") - for batch_idx, (inputs, targets) in enumerate(trainloader): - inputs, targets = inputs.to(device), targets.to(device) - outputs1 = net(inputs) - outputs2 = net_new(inputs) - total+= inputs.shape[0] - if covar is None: - covar = activation[0].T @ other_activation[0] - else: - covar += activation[0].T @ other_activation[0] - activation = [] - other_activation = [] - covar /= total - - hook_handle_1.remove() - hook_handle_2.remove() - - print("Sample Cross-Covariance Calculation finished") - align = calc_svd(covar, name="Cross-Covariance") - new_weight = loading_sd['weight'] - new_weight = torch.moveaxis(new_weight, source=1, - destination=-1) - new_weight = new_weight@align - loading_sd['weight'] = torch.moveaxis(new_weight, source=-1, - destination=1) - new_module.load_state_dict(loading_sd) - return new_module - - -def batchNorm_stats_recalc(m1, m2): - print("Calculating Batch Statistics") - m1.train() - m2.train() - m1.reset_running_stats() - m2.reset_running_stats() - handle_1 = m1.register_forward_hook(lambda mod, inputs, outputs: Exception("Done")) - handle_2 = m2.register_forward_hook(lambda mod, inputs, outputs: Exception("Done")) - m1.to(device) - m2.to(device) - for batch_idx, (inputs, targets) in enumerate(trainloader): - inputs, targets = inputs.to(device), targets.to(device) - try: - outputs1 = net(inputs) - except Exception: - pass - try: - outputs2 = net_new(inputs) - except Exception: - pass - handle_1.remove() - handle_2.remove() - m1.eval() - m2.eval() - print("Batch Statistics Calculation Finished") - - -def weight_Alignment_With_CC(m1, m2, new_module, Un=None, Sn=None, Vn=None): - print("NOT SUPPOSED TO BE HERE") - # reference model state dict - ref_sd = m1.state_dict() - # generated model state dict - uses reference model weights. 
for now - gen_sd = m2.state_dict() - - # module with random init - to be loaded to model - loading_sd = new_module.state_dict() - new_gaussian = loading_sd['weight'] - - # carry over old bias. only matters when we work with no batchnorm networks - if m1.bias is not None: - loading_sd['bias'] = ref_sd['bias'] - # carry over old colored covariance. only matters with fact-convs - if "tri1_vec" in gen_sd.keys(): - loading_sd['tri1_vec']=gen_sd['tri1_vec'] - loading_sd['tri2_vec']=gen_sd['tri2_vec'] - - old_weight = ref_sd['weight'] - A = old_weight.reshape(old_weight.shape[0], -1) - A_T_A = A.T@A - V_val, Vn = torch.linalg.eigh(A_T_A) - del A_T_A - V_val = V_val.flip(0) - Vn = Vn.fliplr().T - Sn = (1e-6 + V_val.abs()).sqrt() - Sn_inv = (1/Sn).diag() - Un = A @ Vn.T @ Sn_inv - white_gaussian = torch.randn_like(Un) - - copy_weight = Un - copy_weight_gen = white_gaussian - copy_weight = copy_weight.reshape(copy_weight.shape[0], -1) - copy_weight_gen = copy_weight_gen.reshape(copy_weight_gen.shape[0], -1).T - weight_cov = (copy_weight_gen@copy_weight) - - alignment = calc_svd(weight_cov, name="Weight") - new_weight = white_gaussian - new_weight = new_weight.reshape(new_weight.shape[0], -1) - new_weight = new_weight@alignment # C_in_reference to C_in_generated - - new_module.register_buffer("weight_align", alignment) - loading_sd['weight_align'] = alignment - colored_gaussian = white_gaussian @ (Sn[:,None]* Vn) - loading_sd['weight'] = colored_gaussian.reshape(old_weight.shape) - new_module.load_state_dict(loading_sd) - return new_module - - -# this function does not do an explicit specification of the colored covariance -@torch.no_grad() -def colored_Covariance_Specification(m1, m2, new_module, Un=None, Sn=None, Vn=None): - print("NOT HERE") - # reference model state dict - ref_sd = m1.state_dict() - # generated model state dict - uses reference model weights. 
for now - gen_sd = m2.state_dict() - - # module with random init - to be loaded to model - loading_sd = new_module.state_dict() - new_gaussian = loading_sd['weight'] - - # carry over old bias. only matters when we work with no batchnorm networks - if m1.bias is not None: - loading_sd['bias'] = ref_sd['bias'] - # carry over old colored covariance. only matters with fact-convs - if "tri1_vec" in gen_sd.keys(): - loading_sd['tri1_vec']=gen_sd['tri1_vec'] - loading_sd['tri2_vec']=gen_sd['tri2_vec'] - - old_weight = ref_sd['weight'] - A = old_weight.reshape(old_weight.shape[0], -1) - A_T_A = A.T@A - V_val, Vn = torch.linalg.eigh(A_T_A) - del A_T_A - V_val = V_val.flip(0) - Vn = Vn.fliplr().T - Sn = (1e-6 + V_val.abs()).sqrt() - Sn_inv = (1/Sn).diag() - Un = A @ Vn.T @ Sn_inv - white_gaussian = torch.randn_like(Un) - - colored_gaussian = white_gaussian @ (Sn[:,None]* Vn) - loading_sd['weight'] = colored_gaussian.reshape(old_weight.shape) - new_module.load_state_dict(loading_sd) - return new_module - - def train(epoch, net): print('\nEpoch: %d' % epoch) net.train() @@ -573,8 +183,14 @@ def test(epoch, net): set_seeds(args.seed) s=time.time() -our_rainbow_sampling(net, net_new) + + +rainbow = rainbow_sampler(net, net_new, args, device, trainloader) +rainbow.do_rainbow_sampling()#rainbow.net, rainbow.net_new) + +net_new = rainbow.net_new net_new.train() + for batch_idx, (inputs, targets) in enumerate(trainloader): inputs, targets = inputs.to(device), targets.to(device) outputs = net_new(inputs) From fff82a5f9b983a8e4de976a330eb0a68d227a152 Mon Sep 17 00:00:00 2001 From: Muawiz Chaudhary Date: Fri, 26 Apr 2024 15:03:20 -0400 Subject: [PATCH 36/77] changes to rainbow sampling so we can sample multiple times and calculate variance --- refactor/rainbow.py | 7 ++++ refactor/refactor_rainbow.py | 67 ++++++++++++++++++------------------ 2 files changed, 40 insertions(+), 34 deletions(-) diff --git a/refactor/rainbow.py b/refactor/rainbow.py index b26ce1c..370f422 100644 --- 
a/refactor/rainbow.py +++ b/refactor/rainbow.py @@ -2,6 +2,7 @@ import torch.nn as nn import copy +from pytorch_cifar_utils import set_seeds @@ -44,6 +45,7 @@ class rainbow_sampler: def __init__(self, net, net_new, args, device, trainloader): self.net = copy.deepcopy(net) self.net_new = copy.deepcopy(net_new) + self.seed = args.seed self.sampling = args.sampling self.wa = args.wa self.in_wa = args.in_wa @@ -52,6 +54,11 @@ def __init__(self, net, net_new, args, device, trainloader): self.trainloader = trainloader def do_rainbow_sampling(self): + set_seeds(self.seed) + self.net.train() + self.net_new = copy.deepcopy(self.net) + self.net_new.train() + print("With seed {}".format(self.seed)) self.our_rainbow_sampling(self.net, self.net_new) @torch.no_grad() diff --git a/refactor/refactor_rainbow.py b/refactor/refactor_rainbow.py index 121d5d7..624a31e 100644 --- a/refactor/refactor_rainbow.py +++ b/refactor/refactor_rainbow.py @@ -121,7 +121,7 @@ def save_model(args, model): net_new.to(device) print(net_new) -replace_layers_fact_with_conv(net) +#replace_layers_fact_with_conv(net) net.to(device) print(net) @@ -182,41 +182,40 @@ def test(epoch, net): pretrained_acc, og_loss = test(0, net) set_seeds(args.seed) -s=time.time() - - -rainbow = rainbow_sampler(net, net_new, args, device, trainloader) -rainbow.do_rainbow_sampling()#rainbow.net, rainbow.net_new) - -net_new = rainbow.net_new -net_new.train() - -for batch_idx, (inputs, targets) in enumerate(trainloader): - inputs, targets = inputs.to(device), targets.to(device) - outputs = net_new(inputs) -print("TOTAL TIME:", time.time()-s) -turn_off_backbone_grad(net_new) -optimizer = optim.SGD(filter(lambda param: param.requires_grad, net_new.parameters()), lr=args.lr, - momentum=0.9, weight_decay=5e-4) -print("testing {} sampling at width {}".format(args.sampling, args.width)) -net_new.eval() - -run_name = "refactor" -args.name = run_name -print(net_new) - -sampled_acc, sampled_loss = test(0, net_new) -save_model(args, 
net_new) -accs = [] -test_losses= [] -print("training classifier head of {} sampled model for {} epochs".format(args.sampling, args.epochs)) -for i in range(0, args.epochs): +for i in range(0, 5): + s=time.time() + args.seed = i + rainbow = rainbow_sampler(net, net_new, args, device, trainloader) + rainbow.do_rainbow_sampling()#rainbow.net, rainbow.net_new) + net_new = rainbow.net_new net_new.train() - train(i, net_new) + + for batch_idx, (inputs, targets) in enumerate(trainloader): + inputs, targets = inputs.to(device), targets.to(device) + outputs = net_new(inputs) + print("TOTAL TIME:", time.time()-s) + turn_off_backbone_grad(net_new) + optimizer = optim.SGD(filter(lambda param: param.requires_grad, net_new.parameters()), lr=args.lr, + momentum=0.9, weight_decay=5e-4) + print("testing {} sampling at width {}".format(args.sampling, args.width)) net_new.eval() - acc, loss_test =test(i, net_new) - test_losses.append(loss_test) - accs.append(acc) + + run_name = "refactor" + args.name = run_name + print(net_new) + + sampled_acc, sampled_loss = test(0, net_new) + save_model(args, net_new) + accs = [] + test_losses= [] + print("training classifier head of {} sampled model for {} epochs".format(args.sampling, args.epochs)) + for i in range(0, args.epochs): + net_new.train() + train(i, net_new) + net_new.eval() + acc, loss_test =test(i, net_new) + test_losses.append(loss_test) + accs.append(acc) logger ={"pretrained_acc": pretrained_acc, "sampled_acc": sampled_acc, "first_epoch_acc":accs[0], "third_epoch_acc": accs[2], From 1dd2048ed089f5cfa6c1f1e1f940856f5024b453 Mon Sep 17 00:00:00 2001 From: Muawiz Chaudhary Date: Fri, 26 Apr 2024 21:04:30 -0400 Subject: [PATCH 37/77] refactored a decent amount of rainbow sampling --- refactor/rainbow.py | 173 ++++++++++------------------------- refactor/refactor_rainbow.py | 94 +++++++++---------- 2 files changed, 97 insertions(+), 170 deletions(-) diff --git a/refactor/rainbow.py b/refactor/rainbow.py index 370f422..9c5d6e2 100644 
--- a/refactor/rainbow.py +++ b/refactor/rainbow.py @@ -4,14 +4,10 @@ import copy from pytorch_cifar_utils import set_seeds - - from conv_modules import FactConv2d #traditional way of calculating svd. can be a bit unstable sometimes tho def calc_svd(A, name=''): - u, s, vh = torch.linalg.svd( - A, full_matrices=False, - ) # (C_in_reference, R), (R,), (R, C_in_generated) + u, s, vh = torch.linalg.svd(A, full_matrices=False) # (C_in_reference, R), (R,), (R, C_in_generated) alignment = u @ vh # (C_in_reference, C_in_generated) return alignment, (u, s, vh) @@ -39,8 +35,6 @@ def hook(mod, inputs): # theirs wa false, aca true (conv) # theirs wa false aca false (conv) [Just Random] # theirs was true aca false (conv) -# -# this function does not do an explicit specification of the colored covariance class rainbow_sampler: def __init__(self, net, net_new, args, device, trainloader): self.net = copy.deepcopy(net) @@ -61,6 +55,25 @@ def do_rainbow_sampling(self): print("With seed {}".format(self.seed)) self.our_rainbow_sampling(self.net, self.net_new) + def load_state_dicts(self, m1, m2, new_module): + # reference model state dict + ref_sd = m1.state_dict() + # generated model state dict - uses reference model weights. for now + gen_sd = m2.state_dict() + + # module with random init - to be loaded to model + loading_sd = new_module.state_dict() + + # carry over old bias. only matters when we work with no batchnorm networks + if m1.bias is not None: + loading_sd['bias'] = ref_sd['bias'] + # carry over old colored covariance. 
only matters with fact-convs + if "tri1_vec" in ref_sd.keys(): + loading_sd['tri1_vec'] = ref_sd['tri1_vec'] + loading_sd['tri2_vec'] = ref_sd['tri2_vec'] + + return ref_sd, gen_sd, loading_sd + @torch.no_grad() def our_rainbow_sampling(self, model, new_model): for (n1, m1), (n2, m2) in zip(model.named_children(), new_model.named_children()): @@ -90,16 +103,12 @@ def our_rainbow_sampling(self, model, new_model): if self.sampling == 'theirs': # for conv only if self.wa: - new_module = weight_Alignment_With_CC(m1, m2, new_module) + new_module = self.weight_Alignment_With_CC(m1, m2, new_module) else: - new_module = colored_Covariance_Specification(m1, m2, new_module) + new_module = self.colored_Covariance_Specification(m1, m2, new_module) # this step calculates the activation cross-covariance alignment (ACA) if m1.in_channels != 3 and self.aca: new_module = self.conv_ACA(m1, m2, new_module) - # converts fact conv to conv. this is for sake of speed. - #if isinstance(new_module, FactConv2d): - # new_module = replace_layers_fact_with_conv(new_module) - # #new_module = fact_2_conv(new_module) # changes the network module setattr(new_model, n1, new_module) @@ -111,7 +120,6 @@ def our_rainbow_sampling(self, model, new_model): if isinstance(m1, nn.BatchNorm2d): self.batchNorm_stats_recalc(m1, m2) - def conv_ACA(self, m1, m2, new_module): print("Convolutional Input Activations Alignment") activation = [] @@ -171,7 +179,6 @@ def store_hook(mod, inputs, outputs): hook_handle_pre_forward = new_module.register_forward_pre_hook(return_hook()) return new_module - def batchNorm_stats_recalc(self, m1, m2): print("Calculating Batch Statistics") m1.train() @@ -198,7 +205,6 @@ def batchNorm_stats_recalc(self, m1, m2): m2.eval() print("Batch Statistics Calculation Finished") - def linear_ACA(self, m1, m2, new_model): print("Linear Input Activations Alignment") new_module = nn.Linear(m1.in_features, m1.out_features, bias=True @@ -246,44 +252,16 @@ def linear_ACA(self, m1, m2, new_model): 
new_module.load_state_dict(loading_sd) return new_module - - - def weight_Alignment(self,m1, m2, new_module, in_dim=True): - # reference model state dict - ref_sd = m1.state_dict() - # generated model state dict - uses reference model weights. for now - gen_sd = m2.state_dict() - - # module with random init - to be loaded to model - loading_sd = new_module.state_dict() - new_gaussian = loading_sd['weight'] - - # carry over old bias. only matters when we work with no batchnorm networks - if m1.bias is not None: - loading_sd['bias'] = ref_sd['bias'] - # carry over old colored covariance. only matters with fact-convs - if "tri1_vec" in gen_sd.keys(): - loading_sd['tri1_vec']=gen_sd['tri1_vec'] - loading_sd['tri2_vec']=gen_sd['tri2_vec'] - - #this is the spot where - # we can do weight alignment - # for fact net, this means aligning with the random noise - # for conv net, this could mean aligning with a. W OR b. U - # we can do colored-covariance specification - # for fact net, this means just using it's R matrix - # for conv net, this could mean doing nothing (if aligning with W), or use S and V if we did b. 
- # in this function, we just align with W and don't specify the mulit-color covariance - - # IF FACT: we align the generated factnet with the reference fact net's noise - # IF CONV: we align the generated convnet with the reference conv net's weight matrix - reference_weight = gen_sd['weight'] - generated_weight = new_gaussian + # IF FACT: we align the generated factnet with the reference fact net's noise + # IF CONV: we align the generated convnet with the reference conv net's weight matrix + def weight_Alignment(self, m1, m2, new_module, in_dim=True): + ref_sd, gen_sd, loading_sd = self.load_state_dicts(m1, m2, new_module) + reference_weight = ref_sd['weight'] + generated_weight = loading_sd['weight'] #reshape to outdim x indim*spatial reference_weight = reference_weight.reshape(reference_weight.shape[0], -1) generated_weight = generated_weight.reshape(generated_weight.shape[0], -1) - #compute transpose, giving indim*spatial x outdim #compute weight cross-covariance indim*spatial x indim*spatial #TODO REFACTOR TO HAVE REF FIRST. 
OUTDIM x OUTDIM @@ -292,62 +270,31 @@ def weight_Alignment(self,m1, m2, new_module, in_dim=True): weight_cov = (generated_weight.T@reference_weight) alignment, _ = calc_svd(weight_cov, name="Weight alignment") - # outdim x indim x spatial - final_gen_weight = new_gaussian # outdim x indim*spatial - final_gen_weight = final_gen_weight.reshape(final_gen_weight.shape[0], -1) - # outdim x indim*spatial - final_gen_weight = final_gen_weight@alignment + final_gen_weight = generated_weight@alignment else: print("Output Weight Alignment") weight_cov = (reference_weight@generated_weight.T) alignment, _ = calc_svd(weight_cov, name="Weight alignment") - # outdim x indim x spatial - final_gen_weight = new_gaussian - # outdim x indim*spatial - final_gen_weight = final_gen_weight.reshape(final_gen_weight.shape[0], -1) # outdim x indim*spatial - final_gen_weight = alignment@final_gen_weight + final_gen_weight = alignment@generated_weight - loading_sd['weight'] = final_gen_weight.reshape(ref_sd['weight'].shape) loading_sd['weight_align'] = alignment new_module.register_buffer("weight_align", alignment) + + loading_sd['weight'] = final_gen_weight.reshape(ref_sd['weight'].shape) new_module.load_state_dict(loading_sd) return new_module - - - - def weight_Alignment_With_CC(self, m1, m2, new_module, Un=None, Sn=None, Vn=None): - print("NOT SUPPOSED TO BE HERE") - # reference model state dict - ref_sd = m1.state_dict() - # generated model state dict - uses reference model weights. for now - gen_sd = m2.state_dict() - - # module with random init - to be loaded to model - loading_sd = new_module.state_dict() - new_gaussian = loading_sd['weight'] - - # carry over old bias. only matters when we work with no batchnorm networks - if m1.bias is not None: - loading_sd['bias'] = ref_sd['bias'] - # carry over old colored covariance. 
only matters with fact-convs - if "tri1_vec" in gen_sd.keys(): - loading_sd['tri1_vec']=gen_sd['tri1_vec'] - loading_sd['tri2_vec']=gen_sd['tri2_vec'] - + def weight_Alignment_With_CC(self, m1, m2, new_module): + print("Weight alignment with Colored Covariance") + ref_sd, gen_sd, loading_sd = self.load_state_dicts(m1, m2, new_module) + old_weight = ref_sd['weight'] A = old_weight.reshape(old_weight.shape[0], -1) - A_T_A = A.T@A - V_val, Vn = torch.linalg.eigh(A_T_A) - del A_T_A - V_val = V_val.flip(0) - Vn = Vn.fliplr().T - Sn = (1e-6 + V_val.abs()).sqrt() - Sn_inv = (1/Sn).diag() - Un = A @ Vn.T @ Sn_inv + + _, (Un, Sn, Vn) = calc_svd(A) white_gaussian = torch.randn_like(Un) copy_weight = Un @@ -356,7 +303,7 @@ def weight_Alignment_With_CC(self, m1, m2, new_module, Un=None, Sn=None, Vn=None copy_weight_gen = copy_weight_gen.reshape(copy_weight_gen.shape[0], -1).T weight_cov = (copy_weight_gen@copy_weight) - alignment = calc_svd(weight_cov, name="Weight") + alignment, _ = calc_svd(weight_cov, name="Weight") new_weight = white_gaussian new_weight = new_weight.reshape(new_weight.shape[0], -1) new_weight = new_weight@alignment # C_in_reference to C_in_generated @@ -364,48 +311,26 @@ def weight_Alignment_With_CC(self, m1, m2, new_module, Un=None, Sn=None, Vn=None new_module.register_buffer("weight_align", alignment) loading_sd['weight_align'] = alignment colored_gaussian = white_gaussian @ (Sn[:,None]* Vn) + loading_sd['weight'] = colored_gaussian.reshape(old_weight.shape) new_module.load_state_dict(loading_sd) return new_module - # this function does not do an explicit specification of the colored covariance @torch.no_grad() - def colored_Covariance_Specification(self, m1, m2, new_module, Un=None, Sn=None, Vn=None): - print("NOT HERE") - # reference model state dict - ref_sd = m1.state_dict() - # generated model state dict - uses reference model weights. 
for now - gen_sd = m2.state_dict() - - # module with random init - to be loaded to model - loading_sd = new_module.state_dict() - new_gaussian = loading_sd['weight'] - - # carry over old bias. only matters when we work with no batchnorm networks - if m1.bias is not None: - loading_sd['bias'] = ref_sd['bias'] - # carry over old colored covariance. only matters with fact-convs - if "tri1_vec" in gen_sd.keys(): - loading_sd['tri1_vec']=gen_sd['tri1_vec'] - loading_sd['tri2_vec']=gen_sd['tri2_vec'] - + def colored_Covariance_Specification(self, m1, m2, new_module): + print("Colored Covariance") + ref_sd, gen_sd, loading_sd = self.load_state_dicts(m1, m2, new_module) + old_weight = ref_sd['weight'] A = old_weight.reshape(old_weight.shape[0], -1) - A_T_A = A.T@A - V_val, Vn = torch.linalg.eigh(A_T_A) - del A_T_A - V_val = V_val.flip(0) - Vn = Vn.fliplr().T - Sn = (1e-6 + V_val.abs()).sqrt() - Sn_inv = (1/Sn).diag() - Un = A @ Vn.T @ Sn_inv + + _, (Un, Sn, Vn) = calc_svd(A) white_gaussian = torch.randn_like(Un) - colored_gaussian = white_gaussian @ (Sn[:,None]* Vn) + loading_sd['weight'] = colored_gaussian.reshape(old_weight.shape) new_module.load_state_dict(loading_sd) return new_module - - - + + diff --git a/refactor/refactor_rainbow.py b/refactor/refactor_rainbow.py index 624a31e..3a9cbee 100644 --- a/refactor/refactor_rainbow.py +++ b/refactor/refactor_rainbow.py @@ -40,7 +40,7 @@ def save_model(args, model): parser = argparse.ArgumentParser(description='PyTorch CIFAR10 Training') parser.add_argument('--lr', default=0.1, type=float, help='learning rate') parser.add_argument('--epochs', default=10, type=int, help='number of epochs') -parser.add_argument('--seed', default=1, type=int, help='seed') +parser.add_argument('--seed', default=0, type=int, help='seed') parser.add_argument('--name', type=str, default='TESTING_VGG', help='filename for saved model') parser.add_argument('--aca', type=lambda x: bool(strtobool(x)), @@ -104,31 +104,6 @@ def save_model(args, model): 
print('==> Building model..') -net=ResNet18() -replace_layers_scale(net, args.width) -if args.fact: - replace_layers_factconv2d(net) - - -if args.fact: - sd=torch.load("/network/scratch/v/vivian.white/v1-models/saved-models/affine_1/{}scale_final/fact_model.pt".format(args.width)) -elif not args.fact: - sd=torch.load("/network/scratch/v/vivian.white/v1-models/saved-models/affine_1/{}scale_final/conv_model.pt".format(args.width)) -net.load_state_dict(sd) -net.to(device) - -net_new = copy.deepcopy(net) -net_new.to(device) -print(net_new) - -#replace_layers_fact_with_conv(net) -net.to(device) -print(net) - -net.train() -net_new.train() - -set_seeds(args.seed) criterion = nn.CrossEntropyLoss() def train(epoch, net): @@ -177,16 +152,41 @@ def test(epoch, net): print("accuracy:", acc) return acc, test_loss - -print("testing Res{}Net18 with width of {}".format("Fact" if args.fact else "Conv", args.width)) -pretrained_acc, og_loss = test(0, net) - +logger ={'width':args.width}#, } set_seeds(args.seed) for i in range(0, 5): + net=ResNet18() + replace_layers_scale(net, args.width) + if args.fact: + replace_layers_factconv2d(net) + + + if args.fact: + sd=torch.load("/network/scratch/v/vivian.white/v1-models/saved-models/affine_1/{}scale_final/fact_model.pt".format(args.width)) + elif not args.fact: + sd=torch.load("/network/scratch/v/vivian.white/v1-models/saved-models/affine_1/{}scale_final/conv_model.pt".format(args.width)) + net.load_state_dict(sd) + net.to(device) + + net_new = copy.deepcopy(net) + net_new.to(device) + print(net_new) + + net.to(device) + print(net) + + net.train() + net_new.train() + + set_seeds(i) + print("testing Res{}Net18 with width of {}".format("Fact" if args.fact else "Conv", args.width)) + pretrained_acc, og_loss = test(0, net) + + s=time.time() args.seed = i rainbow = rainbow_sampler(net, net_new, args, device, trainloader) - rainbow.do_rainbow_sampling()#rainbow.net, rainbow.net_new) + rainbow.do_rainbow_sampling() net_new = rainbow.net_new 
net_new.train() @@ -200,8 +200,6 @@ def test(epoch, net): print("testing {} sampling at width {}".format(args.sampling, args.width)) net_new.eval() - run_name = "refactor" - args.name = run_name print(net_new) sampled_acc, sampled_loss = test(0, net_new) @@ -209,28 +207,32 @@ def test(epoch, net): accs = [] test_losses= [] print("training classifier head of {} sampled model for {} epochs".format(args.sampling, args.epochs)) - for i in range(0, args.epochs): + for j in range(0, args.epochs): net_new.train() - train(i, net_new) + train(j, net_new) net_new.eval() - acc, loss_test =test(i, net_new) + acc, loss_test =test(j, net_new) test_losses.append(loss_test) accs.append(acc) -logger ={"pretrained_acc": pretrained_acc, "sampled_acc": sampled_acc, - "first_epoch_acc":accs[0], "third_epoch_acc": accs[2], - "tenth_epoch_acc":accs[args.epochs-1], 'width':args.width, - "og_loss":og_loss, "sampled_loss":sampled_loss, - "first_epoch_loss":test_losses[0], "third_epoch_loss": test_losses[2], - "tenth_epoch_loss":test_losses[args.epochs-1], 'width':args.width} + new_logger ={"sampled_acc_{}".format(i): sampled_acc,"pretrained_acc_{}".format(i): + pretrained_acc, "og_loss_{}".format(i): og_loss, + "first_epoch_acc_{}".format(i):accs[0], "third_epoch_acc_{}".format(i): accs[2], + "tenth_epoch_acc_{}".format(i):accs[args.epochs-1], + "sampled_loss_{}".format(i):sampled_loss, + "first_epoch_loss_{}".format(i):test_losses[0], "third_epoch_loss_{}".format(i): test_losses[2], + "tenth_epoch_loss_{}".format(i):test_losses[args.epochs-1]} + logger = {**logger, **new_logger} wandb_dir = "/home/mila/m/muawiz.chaudhary/scratch/v1-models/wandb" os.makedirs(wandb_dir, exist_ok=True) os.chdir(wandb_dir) -group_string = "refactor" +#group_string = "refactor" +group_string = "variance_runs" + +#run_name = "refactor" +run_name= "width_{}_sampling_{}_fact_{}_ACA_{}_WA_{}_inWA_{}".format(args.width, args.sampling, args.fact, args.aca, args.wa, args.in_wa) +args.name = run_name run = 
wandb.init(project="random_project", config=args, group=group_string, name=run_name, dir=wandb_dir) run.log(logger) - -args.name += "_trained_classifier_head" -save_model(args, net_new) From 5f449732958d6ce2a570b69a1f2c9eb8bcaad024 Mon Sep 17 00:00:00 2001 From: Muawiz Chaudhary Date: Wed, 1 May 2024 14:05:35 -0400 Subject: [PATCH 38/77] fact-conv changes --- refactor/conv_modules.py | 42 +++++++++++++++++----------------------- 1 file changed, 18 insertions(+), 24 deletions(-) diff --git a/refactor/conv_modules.py b/refactor/conv_modules.py index 98b2a49..76506a5 100644 --- a/refactor/conv_modules.py +++ b/refactor/conv_modules.py @@ -28,57 +28,52 @@ def __init__( groups: int = 1, bias: bool = True, padding_mode: str = 'zeros', # TODO: refine this type - device=torch.device('cuda' if torch.cuda.is_available() else 'cpu'), + device=None, dtype=None ) -> None: # init as Conv2d super().__init__( in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias, padding_mode, device, dtype) - - factory_kwargs = {'device': device, 'dtype': dtype} - self.factory_kwargs = factory_kwargs - # weight shape: (out_channels, in_channels // groups, *kernel_size) weight_shape = self.weight.shape + new_weight = self.weight.new_empty(weight_shape) del self.weight # remove Parameter, create buffer - self.register_buffer("weight", torch.empty(weight_shape, **factory_kwargs)) + self.register_buffer("weight", new_weight) nn.init.kaiming_normal_(self.weight) self.in_features = self.in_channels // self.groups * \ self.kernel_size[0] * self.kernel_size[1] triu1 = torch.triu_indices(self.in_channels // self.groups, self.in_channels // self.groups, - **factory_kwargs) - self.scat_idx1=triu1[0]*self.in_channels//self.groups + triu1[1] + device=self.weight.device, + dtype=torch.long) + scat_idx1 = triu1[0]*self.in_channels//self.groups + triu1[1] + self.register_buffer("scat_idx1", scat_idx1, persistent=False) + triu2 = torch.triu_indices(self.kernel_size[0] * 
self.kernel_size[1], self.kernel_size[0] * self.kernel_size[1], - **factory_kwargs) + device=self.weight.device, + dtype=torch.long) + scat_idx2 = triu2[0]*self.kernel_size[0]*self.kernel_size[1] + triu2[1] + self.register_buffer("scat_idx2", scat_idx2, persistent=False) - self.scat_idx2=triu2[0]*self.kernel_size[0]*self.kernel_size[1] + triu2[1] triu1_len = triu1.shape[1] triu2_len = triu2.shape[1] - tri1_vec = torch.zeros((triu1_len,), - **factory_kwargs) + tri1_vec = self.weight.new_zeros((triu1_len,)) self.tri1_vec = Parameter(tri1_vec) - tri2_vec = torch.zeros((triu2_len,), **factory_kwargs) + tri2_vec = self.weight.new_zeros((triu2_len,)) self.tri2_vec = Parameter(tri2_vec) - def construct_Us(self): - self.tri1_vec = Parameter(self._tri_vec_to_mat(self.tri1_vec, self.in_channels // - self.groups,self.scat_idx1)) - self.tri2_vec = Parameter(self._tri_vec_to_mat(self.tri2_vec, self.kernel_size[0] * self.kernel_size[1], - self.scat_idx2)) def forward(self, input: Tensor) -> Tensor: U1 = self._tri_vec_to_mat(self.tri1_vec, self.in_channels // - self.groups, self.scat_idx1) + self.groups, self.scat_idx1) U2 = self._tri_vec_to_mat(self.tri2_vec, self.kernel_size[0] * self.kernel_size[1], - self.scat_idx2 - ) + self.scat_idx2) # flatten over filter dims and contract composite_weight = _contract(self.weight, U1.T, 1) composite_weight = _contract( @@ -87,8 +82,7 @@ def forward(self, input: Tensor) -> Tensor: return self._conv_forward(input, composite_weight, self.bias) def _tri_vec_to_mat(self, vec, n, scat_idx): - U = torch.zeros((n* n), - **self.factory_kwargs).scatter_(0,scat_idx,vec).view(n,n) - U = torch.diagonal_scatter(U,U.diagonal().exp_()) + U = self.weight.new_zeros((n*n)).scatter_(0, scat_idx, vec).view(n, n) + U = torch.diagonal_scatter(U, U.diagonal().exp_()) return U From 657db4f27507bcec8037d706da394ad6a3f2b00b Mon Sep 17 00:00:00 2001 From: Muawiz Chaudhary Date: Wed, 1 May 2024 15:09:39 -0400 Subject: [PATCH 39/77] reduce amount of code in rainbow --- 
refactor/rainbow.py | 118 ++++++++++++++--------------------- refactor/refactor_rainbow.py | 18 ++---- 2 files changed, 53 insertions(+), 83 deletions(-) diff --git a/refactor/rainbow.py b/refactor/rainbow.py index 9c5d6e2..6c150dc 100644 --- a/refactor/rainbow.py +++ b/refactor/rainbow.py @@ -36,7 +36,7 @@ def hook(mod, inputs): # theirs wa false aca false (conv) [Just Random] # theirs was true aca false (conv) class rainbow_sampler: - def __init__(self, net, net_new, args, device, trainloader): + def __init__(self, net, net_new, args, device, trainloader, num_classes=10): self.net = copy.deepcopy(net) self.net_new = copy.deepcopy(net_new) self.seed = args.seed @@ -46,6 +46,7 @@ def __init__(self, net, net_new, args, device, trainloader): self.aca = args.aca self.device = device self.trainloader = trainloader + self.num_classes = num_classes def do_rainbow_sampling(self): set_seeds(self.seed) @@ -97,10 +98,10 @@ def our_rainbow_sampling(self, model, new_model): groups = m2.groups, bias=True if m2.bias is not None else False).to(self.device) - if self.sampling == 'ours' and self.wa: + if self.sampling == 'structured_alignment' and self.wa: # right now this function does not do an explicit specification of the colored covariance new_module = self.weight_Alignment(m1, m2, new_module, in_dim=self.in_wa) - if self.sampling == 'theirs': + if self.sampling == 'cc_specification': # for conv only if self.wa: new_module = self.weight_Alignment_With_CC(m1, m2, new_module) @@ -119,11 +120,40 @@ def our_rainbow_sampling(self, model, new_model): ##just run stats through if isinstance(m1, nn.BatchNorm2d): self.batchNorm_stats_recalc(m1, m2) - + + def run_forward(self, calc_covar=False): + if calc_covar: + covar = None + total = 0 + for batch_idx, (inputs, targets) in enumerate(self.trainloader): + inputs, targets = inputs.to(self.device), targets.to(self.device) + try: + outputs1 = self.net(inputs) + except Exception: + pass + try: + outputs2 = self.net_new(inputs) + except 
Exception: + pass + if calc_covar: + total+= inputs.shape[0] + if covar is None: + #activation is bwh x c + covar = self.activation[0].T @ self.activation[1] + assert (covar.isfinite().all()) + else: + #activation is bwh x c + covar += self.activation[0].T @ self.activation[1] + assert (covar.isfinite().all()) + self.activation = [] + if calc_covar: + #c x c + covar /= total + return covar + def conv_ACA(self, m1, m2, new_module): print("Convolutional Input Activations Alignment") - activation = [] - other_activation = [] + self.activation = [] # this hook grabs the input activations of the conv layer # rearanges the vector so that the width by height dim is # additional samples to the covariance @@ -134,46 +164,18 @@ def store_hook(mod, inputs, outputs): x = inputs[0] x = x.permute(0, 2, 3, 1) x = x.reshape((-1, x.shape[-1])) - activation.append(x) + self.activation.append(x) raise Exception("Done") return store_hook - hook_handle_1 = m1.register_forward_hook(define_hook(m1)) hook_handle_2 = m2.register_forward_hook(define_hook(m2)) - print("Starting Sample Cross-Covariance Calculation") - covar = None - total = 0 - for batch_idx, (inputs, targets) in enumerate(self.trainloader): - inputs, targets = inputs.to(self.device), targets.to(self.device) - try: - outputs1 = self.net(inputs) - except Exception: - pass - try: - outputs2 = self.net_new(inputs) - except Exception: - pass - total+= inputs.shape[0] - if covar is None: - #activation is bwh x c - covar = activation[0].T @ activation[1] - assert (covar.isfinite().all()) - else: - #activation is bwh x c - covar += activation[0].T @ activation[1] - assert (covar.isfinite().all()) - activation = [] - other_activation = [] - - #c x c - covar /= total + covar = self.run_forward(calc_covar=True) + print("Sample Cross-Covariance Calculation finished") hook_handle_1.remove() hook_handle_2.remove() - print("Sample Cross-Covariance Calculation finished") align, _ = calc_svd(covar, name="Cross-Covariance") 
new_module.register_buffer("input_align", align) - # this hook takes the input to the conv, aligns, then returns # to the conv the aligned inputs hook_handle_pre_forward = new_module.register_forward_pre_hook(return_hook()) @@ -189,16 +191,7 @@ def batchNorm_stats_recalc(self, m1, m2): handle_2 = m2.register_forward_hook(lambda mod, inputs, outputs: Exception("Done")) m1.to(self.device) m2.to(self.device) - for batch_idx, (inputs, targets) in enumerate(self.trainloader): - inputs, targets = inputs.to(self.device), targets.to(self.device) - try: - outputs1 = self.net(inputs) - except Exception: - pass - try: - outputs2 = self.net_new(inputs) - except Exception: - pass + self.run_forward(calc_covar=False) handle_1.remove() handle_2.remove() m1.eval() @@ -211,36 +204,19 @@ def linear_ACA(self, m1, m2, new_model): if m1.bias is not None else False).to(self.device) ref_sd = m1.state_dict() loading_sd = new_module.state_dict() - loading_sd['weight'] = ref_sd['weight'] + if m1.out_features == self.num_classes: + loading_sd['weight'] = ref_sd['weight'] if m1.bias is not None: loading_sd['bias'] = ref_sd['bias'] - activation = [] - other_activation = [] - + self.activation = [] hook_handle_1 = m1.register_forward_hook(lambda mod, inputs, outputs: - activation.append(inputs[0])) - + self.activation.append(inputs[0])) hook_handle_2 = m2.register_forward_hook(lambda mod, inputs, outputs: - other_activation.append(inputs[0])) - covar = None - total = 0 + self.activation.append(inputs[0])) print("Starting Sample Cross-Covariance Calculation") - for batch_idx, (inputs, targets) in enumerate(self.trainloader): - inputs, targets = inputs.to(self.device), targets.to(self.device) - outputs1 = self.net(inputs) - outputs2 = self.net_new(inputs) - total+= inputs.shape[0] - if covar is None: - covar = activation[0].T @ other_activation[0] - else: - covar += activation[0].T @ other_activation[0] - activation = [] - other_activation = [] - covar /= total - + covar = 
self.run_forward(calc_covar=True) hook_handle_1.remove() hook_handle_2.remove() - print("Sample Cross-Covariance Calculation finished") align, _ = calc_svd(covar, name="Cross-Covariance") new_weight = loading_sd['weight'] @@ -248,7 +224,7 @@ def linear_ACA(self, m1, m2, new_model): destination=-1) new_weight = new_weight@align loading_sd['weight'] = torch.moveaxis(new_weight, source=-1, - destination=1) + destination=1) new_module.load_state_dict(loading_sd) return new_module diff --git a/refactor/refactor_rainbow.py b/refactor/refactor_rainbow.py index 3a9cbee..0d4c63e 100644 --- a/refactor/refactor_rainbow.py +++ b/refactor/refactor_rainbow.py @@ -52,19 +52,13 @@ def save_model(args, model): parser.add_argument('--fact', type=lambda x: bool(strtobool(x)), default=True, help='FactNet True or False') parser.add_argument('--width', default=0.125, type=float, help='width') -parser.add_argument('--sampling', type=str, default='ours', - choices=['ours', 'theirs'], help="which sampling to use") +parser.add_argument('--sampling', type=str, default='structured_alignment', + choices=['structured_alignment', 'cc_specification'], help="which sampling to use") args = parser.parse_args() -if args.width == 1.0: - args.width = 1 -if args.width == 2.0: - args.width = 2 -if args.width == 4.0: - args.width = 4 -if args.width == 8.0: - args.width = 8 +if int(args.width) == args.width: + args.width = int(args.width) print("Sampling: {} Width: {} Fact: {} ACA: {} WA: {} In_WA: {}".format(args.sampling, args.width, args.fact, args.aca, args.wa, args.in_wa)) @@ -154,7 +148,7 @@ def test(epoch, net): logger ={'width':args.width}#, } set_seeds(args.seed) -for i in range(0, 5): +for i in range(0, 1): net=ResNet18() replace_layers_scale(net, args.width) if args.fact: @@ -228,7 +222,7 @@ def test(epoch, net): os.makedirs(wandb_dir, exist_ok=True) os.chdir(wandb_dir) #group_string = "refactor" -group_string = "variance_runs" +group_string = "IGNOREvariance_runs" #run_name = "refactor" 
run_name= "width_{}_sampling_{}_fact_{}_ACA_{}_WA_{}_inWA_{}".format(args.width, args.sampling, args.fact, args.aca, args.wa, args.in_wa) From c680aea78277e02ab65c2c797e0f71b85e421461 Mon Sep 17 00:00:00 2001 From: Muawiz Chaudhary Date: Wed, 1 May 2024 15:26:14 -0400 Subject: [PATCH 40/77] initalize with empty_like instead --- refactor/conv_modules.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/refactor/conv_modules.py b/refactor/conv_modules.py index 76506a5..5274173 100644 --- a/refactor/conv_modules.py +++ b/refactor/conv_modules.py @@ -36,8 +36,7 @@ def __init__( in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias, padding_mode, device, dtype) # weight shape: (out_channels, in_channels // groups, *kernel_size) - weight_shape = self.weight.shape - new_weight = self.weight.new_empty(weight_shape) + new_weight = torch.empty_like(self.weight) del self.weight # remove Parameter, create buffer self.register_buffer("weight", new_weight) nn.init.kaiming_normal_(self.weight) From edda0b447fd2a4354037fb198884df59b1faede0 Mon Sep 17 00:00:00 2001 From: Muawiz Chaudhary Date: Wed, 1 May 2024 15:26:52 -0400 Subject: [PATCH 41/77] added verbose option --- refactor/rainbow.py | 33 ++++++++++++++++++--------------- 1 file changed, 18 insertions(+), 15 deletions(-) diff --git a/refactor/rainbow.py b/refactor/rainbow.py index 6c150dc..b9307fc 100644 --- a/refactor/rainbow.py +++ b/refactor/rainbow.py @@ -2,8 +2,9 @@ import torch.nn as nn import copy -from pytorch_cifar_utils import set_seeds +import logging +from pytorch_cifar_utils import set_seeds from conv_modules import FactConv2d #traditional way of calculating svd. 
can be a bit unstable sometimes tho def calc_svd(A, name=''): @@ -36,7 +37,7 @@ def hook(mod, inputs): # theirs wa false aca false (conv) [Just Random] # theirs was true aca false (conv) class rainbow_sampler: - def __init__(self, net, net_new, args, device, trainloader, num_classes=10): + def __init__(self, net, net_new, args, device, trainloader, num_classes=10, verbose=True): self.net = copy.deepcopy(net) self.net_new = copy.deepcopy(net_new) self.seed = args.seed @@ -47,13 +48,15 @@ def __init__(self, net, net_new, args, device, trainloader, num_classes=10): self.device = device self.trainloader = trainloader self.num_classes = num_classes + logging.basicConfig(level=logging.INFO if verbose else logging.WARNING, + format='%(message)s') def do_rainbow_sampling(self): set_seeds(self.seed) self.net.train() self.net_new = copy.deepcopy(self.net) self.net_new.train() - print("With seed {}".format(self.seed)) + logging.info("With seed {}".format(self.seed)) self.our_rainbow_sampling(self.net, self.net_new) def load_state_dicts(self, m1, m2, new_module): @@ -152,7 +155,7 @@ def run_forward(self, calc_covar=False): return covar def conv_ACA(self, m1, m2, new_module): - print("Convolutional Input Activations Alignment") + logging.info("Convolutional Input Activations Alignment") self.activation = [] # this hook grabs the input activations of the conv layer # rearanges the vector so that the width by height dim is @@ -169,9 +172,9 @@ def store_hook(mod, inputs, outputs): return store_hook hook_handle_1 = m1.register_forward_hook(define_hook(m1)) hook_handle_2 = m2.register_forward_hook(define_hook(m2)) - print("Starting Sample Cross-Covariance Calculation") + logging.info("Starting Sample Cross-Covariance Calculation") covar = self.run_forward(calc_covar=True) - print("Sample Cross-Covariance Calculation finished") + logging.info("Sample Cross-Covariance Calculation finished") hook_handle_1.remove() hook_handle_2.remove() align, _ = calc_svd(covar, 
name="Cross-Covariance") @@ -182,7 +185,7 @@ def store_hook(mod, inputs, outputs): return new_module def batchNorm_stats_recalc(self, m1, m2): - print("Calculating Batch Statistics") + logging.info("Calculating Batch Statistics") m1.train() m2.train() m1.reset_running_stats() @@ -196,10 +199,10 @@ def batchNorm_stats_recalc(self, m1, m2): handle_2.remove() m1.eval() m2.eval() - print("Batch Statistics Calculation Finished") + logging.info("Batch Statistics Calculation Finished") def linear_ACA(self, m1, m2, new_model): - print("Linear Input Activations Alignment") + logging.info("Linear Input Activations Alignment") new_module = nn.Linear(m1.in_features, m1.out_features, bias=True if m1.bias is not None else False).to(self.device) ref_sd = m1.state_dict() @@ -213,11 +216,11 @@ def linear_ACA(self, m1, m2, new_model): self.activation.append(inputs[0])) hook_handle_2 = m2.register_forward_hook(lambda mod, inputs, outputs: self.activation.append(inputs[0])) - print("Starting Sample Cross-Covariance Calculation") + logging.info("Starting Sample Cross-Covariance Calculation") covar = self.run_forward(calc_covar=True) hook_handle_1.remove() hook_handle_2.remove() - print("Sample Cross-Covariance Calculation finished") + logging.info("Sample Cross-Covariance Calculation finished") align, _ = calc_svd(covar, name="Cross-Covariance") new_weight = loading_sd['weight'] new_weight = torch.moveaxis(new_weight, source=1, @@ -242,14 +245,14 @@ def weight_Alignment(self, m1, m2, new_module, in_dim=True): #compute weight cross-covariance indim*spatial x indim*spatial #TODO REFACTOR TO HAVE REF FIRST. 
OUTDIM x OUTDIM if in_dim: - print("Input Weight Alignment") + logging.info("Input Weight Alignment") weight_cov = (generated_weight.T@reference_weight) alignment, _ = calc_svd(weight_cov, name="Weight alignment") # outdim x indim*spatial final_gen_weight = generated_weight@alignment else: - print("Output Weight Alignment") + logging.info("Output Weight Alignment") weight_cov = (reference_weight@generated_weight.T) alignment, _ = calc_svd(weight_cov, name="Weight alignment") @@ -264,7 +267,7 @@ def weight_Alignment(self, m1, m2, new_module, in_dim=True): return new_module def weight_Alignment_With_CC(self, m1, m2, new_module): - print("Weight alignment with Colored Covariance") + logging.info("Weight alignment with Colored Covariance") ref_sd, gen_sd, loading_sd = self.load_state_dicts(m1, m2, new_module) old_weight = ref_sd['weight'] @@ -295,7 +298,7 @@ def weight_Alignment_With_CC(self, m1, m2, new_module): # this function does not do an explicit specification of the colored covariance @torch.no_grad() def colored_Covariance_Specification(self, m1, m2, new_module): - print("Colored Covariance") + logging.info("Colored Covariance") ref_sd, gen_sd, loading_sd = self.load_state_dicts(m1, m2, new_module) old_weight = ref_sd['weight'] From 68e84e45c0197e4fdd73801f08c915f984c4db67 Mon Sep 17 00:00:00 2001 From: Muawiz Chaudhary Date: Thu, 2 May 2024 12:45:09 -0400 Subject: [PATCH 42/77] removed my beloved og --- refactor/{og_pytorch_cifar.py => pytorch_cifar.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename refactor/{og_pytorch_cifar.py => pytorch_cifar.py} (100%) diff --git a/refactor/og_pytorch_cifar.py b/refactor/pytorch_cifar.py similarity index 100% rename from refactor/og_pytorch_cifar.py rename to refactor/pytorch_cifar.py From 07803e0de977eec6f2379b31490adf80991a3883 Mon Sep 17 00:00:00 2001 From: Muawiz Chaudhary Date: Thu, 2 May 2024 12:53:22 -0400 Subject: [PATCH 43/77] deleted learnable_cov.py which had been renamed to V1_covariance.py --- 
refactor/learnable_cov.py | 80 --------------------------------------- 1 file changed, 80 deletions(-) delete mode 100644 refactor/learnable_cov.py diff --git a/refactor/learnable_cov.py b/refactor/learnable_cov.py deleted file mode 100644 index 69fcf4d..0000000 --- a/refactor/learnable_cov.py +++ /dev/null @@ -1,80 +0,0 @@ -import torch -from torch import Tensor -import numpy as np -import numpy.linalg as la -from scipy.spatial.distance import pdist, squareform - - -def V1_covariance_matrix(dim, size, spatial_freq, center, scale=1): - """ - Generates the covariance matrix for Gaussian Process with non-stationary - covariance. This matrix will be used to generate random - features inspired from the receptive-fields of V1 neurons. - - C(x, y) = exp(-|x - y|/(2 * spatial_freq))^2 * exp(-|x - m| / (2 * size))^2 * exp(-|y - m| / (2 * size))^2 - - Parameters - ---------- - - dim : tuple of shape (2, 1) - Dimension of random features. - - size : float - Determines the size of the random weights - - spatial_freq : float - Determines the spatial frequency of the random weights - - center : tuple of shape (2, 1) - Location of the center of the random weights. 
- - scale: float, default=1 - Normalization factor for Tr norm of cov matrix - - Returns - ------- - - C : array-like of shape (dim[0] * dim[1], dim[0] * dim[1]) - covariance matrix w/ Tr norm = scale * dim[0] * dim[1] - """ - - x = np.arange(dim[0]) - y = np.arange(dim[1]) - yy, xx = np.meshgrid(y, x) - grid = np.column_stack((xx.flatten(), yy.flatten())) - - a = squareform(pdist(grid, 'sqeuclidean')) - b = la.norm(grid - center, axis=1) ** 2 - c = b.reshape(-1, 1) - C = np.exp(-a / (2 * spatial_freq ** 2)) * np.exp(-b / (2 * size ** 2)) * np.exp(-c / (2 * size ** 2)) \ - + 1e-5 * np.eye(dim[0] * dim[1]) - C *= scale * dim[0] * dim[1] / np.trace(C) - return C - - -def V1_init(layer, size, spatial_freq, center, scale=1., bias=False, seed=None, - device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')): - ''' - Initialization for FactConv2d - ''' - - classname = layer.__class__.__name__ - assert classname.find('FactConv2d') != -1, 'This init only works for FactConv2d layers' - assert center is not None, "center needed" - - out_channels, in_channels, xdim, ydim = layer.weight.shape - dim = (xdim, ydim) - - C_patch = Tensor(V1_covariance_matrix(dim, size, spatial_freq, center, scale)).to(device) - U_patch = torch.linalg.cholesky(C_patch, upper=True) - n = U_patch.shape[0] - # replace diagonal with logarithm for parameterization - log_diag = torch.log(torch.diagonal(U_patch)) - U_patch[range(n), range(n)] = log_diag - # form vector of upper triangular entries - tri_vec = U_patch[torch.triu_indices(n, n, device=device).tolist()].ravel() - with torch.no_grad(): - layer.tri2_vec.copy_(tri_vec) - - if bias == False: - layer.bias = None From 659429242f9b2850e178c535962675e6a78cacb5 Mon Sep 17 00:00:00 2001 From: Muawiz Chaudhary Date: Thu, 2 May 2024 12:58:35 -0400 Subject: [PATCH 44/77] width scaling should be first --- refactor/models/__init__.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/refactor/models/__init__.py 
b/refactor/models/__init__.py index 0588e64..434ce4f 100644 --- a/refactor/models/__init__.py +++ b/refactor/models/__init__.py @@ -5,6 +5,8 @@ def define_models(args): if 'resnet18' in args.net: model = ResNet18() + if args.width != 1: + replace_layers_scale(model, args.width) if 'fact' in args.net: replace_layers_factconv2d(model) if "v1" in args.net: @@ -13,6 +15,5 @@ def define_models(args): turn_off_covar_grad(model, "spatial") if "uc" in args.net: turn_off_covar_grad(model, "channel") - if args.width != 1: - replace_layers_scale(model, args.width) + return model From c078cfcb8d907f0efc11b681ad5ebc33d1646bf1 Mon Sep 17 00:00:00 2001 From: Muawiz Chaudhary Date: Thu, 2 May 2024 13:48:34 -0400 Subject: [PATCH 45/77] making rainbow sampling more like a properly called class. similar to a kymatio front end. --- refactor/rainbow.py | 35 +++++++++++++++--------------- refactor/refactor_rainbow.py | 42 ++++++++++++++---------------------- 2 files changed, 34 insertions(+), 43 deletions(-) diff --git a/refactor/rainbow.py b/refactor/rainbow.py index b9307fc..6fb5744 100644 --- a/refactor/rainbow.py +++ b/refactor/rainbow.py @@ -36,28 +36,29 @@ def hook(mod, inputs): # theirs wa false, aca true (conv) # theirs wa false aca false (conv) [Just Random] # theirs was true aca false (conv) -class rainbow_sampler: - def __init__(self, net, net_new, args, device, trainloader, num_classes=10, verbose=True): - self.net = copy.deepcopy(net) - self.net_new = copy.deepcopy(net_new) - self.seed = args.seed - self.sampling = args.sampling - self.wa = args.wa - self.in_wa = args.in_wa - self.aca = args.aca - self.device = device +class RainbowSampler: + def __init__(self, ref_net, trainloader, seed=0, sampling='structured_alignment', wa=True, in_wa=True, aca=True, device=None, num_classes=10, verbose=True): + self.ref_net = copy.deepcopy(ref_net) + self.gen_net = copy.deepcopy(ref_net) self.trainloader = trainloader + self.seed = seed + self.sampling = sampling + self.wa = wa + 
self.in_wa = in_wa + self.aca = aca + self.device = torch.get_default_device() if device is None else torch.device(device) self.num_classes = num_classes logging.basicConfig(level=logging.INFO if verbose else logging.WARNING, format='%(message)s') - def do_rainbow_sampling(self): + def sample(self): set_seeds(self.seed) - self.net.train() - self.net_new = copy.deepcopy(self.net) - self.net_new.train() + self.ref_net.train() + self.gen_net = copy.deepcopy(self.ref_net) + self.gen_net.train() logging.info("With seed {}".format(self.seed)) - self.our_rainbow_sampling(self.net, self.net_new) + self.our_rainbow_sampling(self.ref_net, self.gen_net) + return self.gen_net def load_state_dicts(self, m1, m2, new_module): # reference model state dict @@ -131,11 +132,11 @@ def run_forward(self, calc_covar=False): for batch_idx, (inputs, targets) in enumerate(self.trainloader): inputs, targets = inputs.to(self.device), targets.to(self.device) try: - outputs1 = self.net(inputs) + outputs1 = self.ref_net(inputs) except Exception: pass try: - outputs2 = self.net_new(inputs) + outputs2 = self.gen_net(inputs) except Exception: pass if calc_covar: diff --git a/refactor/refactor_rainbow.py b/refactor/refactor_rainbow.py index 0d4c63e..96bcfcc 100644 --- a/refactor/refactor_rainbow.py +++ b/refactor/refactor_rainbow.py @@ -23,7 +23,7 @@ from models.function_utils import replace_layers_factconv2d,\ replace_layers_scale, replace_layers_fact_with_conv, turn_off_backbone_grad, \ recurse_preorder -from rainbow import calc_svd, return_hook, rainbow_sampler +from rainbow import RainbowSampler def save_model(args, model): src = "/home/mila/m/muawiz.chaudhary/scratch/v1-models/saved-models/refactoring/" @@ -160,18 +160,9 @@ def test(epoch, net): elif not args.fact: sd=torch.load("/network/scratch/v/vivian.white/v1-models/saved-models/affine_1/{}scale_final/conv_model.pt".format(args.width)) net.load_state_dict(sd) - net.to(device) - - net_new = copy.deepcopy(net) - net_new.to(device) - 
print(net_new) - net.to(device) print(net) - - net.train() - net_new.train() - + set_seeds(i) print("testing Res{}Net18 with width of {}".format("Fact" if args.fact else "Conv", args.width)) pretrained_acc, og_loss = test(0, net) @@ -179,33 +170,32 @@ def test(epoch, net): s=time.time() args.seed = i - rainbow = rainbow_sampler(net, net_new, args, device, trainloader) - rainbow.do_rainbow_sampling() - net_new = rainbow.net_new - net_new.train() + rainbow = RainbowSampler(net, trainloader, args.seed, args.sampling, args.wa, args.in_wa, args.aca, device) + rainbow_net = rainbow.sample() + rainbow_net.train() for batch_idx, (inputs, targets) in enumerate(trainloader): inputs, targets = inputs.to(device), targets.to(device) - outputs = net_new(inputs) + outputs = rainbow_net(inputs) print("TOTAL TIME:", time.time()-s) - turn_off_backbone_grad(net_new) - optimizer = optim.SGD(filter(lambda param: param.requires_grad, net_new.parameters()), lr=args.lr, + turn_off_backbone_grad(rainbow_net) + optimizer = optim.SGD(filter(lambda param: param.requires_grad, rainbow_net.parameters()), lr=args.lr, momentum=0.9, weight_decay=5e-4) print("testing {} sampling at width {}".format(args.sampling, args.width)) - net_new.eval() + rainbow_net.eval() - print(net_new) + print(rainbow_net) - sampled_acc, sampled_loss = test(0, net_new) - save_model(args, net_new) + sampled_acc, sampled_loss = test(0, rainbow_net) + save_model(args, rainbow_net) accs = [] test_losses= [] print("training classifier head of {} sampled model for {} epochs".format(args.sampling, args.epochs)) for j in range(0, args.epochs): - net_new.train() - train(j, net_new) - net_new.eval() - acc, loss_test =test(j, net_new) + rainbow_net.train() + train(j, rainbow_net) + rainbow_net.eval() + acc, loss_test =test(j, rainbow_net) test_losses.append(loss_test) accs.append(acc) From 642c480e676608a1b3a445559d857a2092db2cbf Mon Sep 17 00:00:00 2001 From: Muawiz Chaudhary Date: Thu, 2 May 2024 13:59:29 -0400 Subject: [PATCH 
46/77] removed specification of device from init, rely instead on layer to provide that. --- refactor/V1_covariance.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/refactor/V1_covariance.py b/refactor/V1_covariance.py index 448647d..936142c 100644 --- a/refactor/V1_covariance.py +++ b/refactor/V1_covariance.py @@ -72,8 +72,7 @@ def V1_covariance_matrix(dim, size, spatial_freq, center, scale=1): return C -def V1_init(layer, size, spatial_freq, center, scale=1., bias=False, seed=None, - device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')): +def V1_init(layer, size, spatial_freq, center, scale=1., bias=False, seed=None): ''' Initialization for FactConv2d ''' @@ -85,14 +84,14 @@ def V1_init(layer, size, spatial_freq, center, scale=1., bias=False, seed=None, out_channels, in_channels, xdim, ydim = layer.weight.shape dim = (xdim, ydim) - C_patch = Tensor(V1_covariance_matrix(dim, size, spatial_freq, center, scale)).to(device) + C_patch = Tensor(V1_covariance_matrix(dim, size, spatial_freq, center, scale)).to(layer.weight.device) U_patch = torch.linalg.cholesky(C_patch, upper=True) n = U_patch.shape[0] # replace diagonal with logarithm for parameterization log_diag = torch.log(torch.diagonal(U_patch)) U_patch[range(n), range(n)] = log_diag # form vector of upper triangular entries - tri_vec = U_patch[torch.triu_indices(n, n, device=device).tolist()].ravel() + tri_vec = U_patch[torch.triu_indices(n, n, device=layer.weight.device).tolist()].ravel() with torch.no_grad(): layer.tri2_vec.copy_(tri_vec) From d3116ecda4b1df4dd9d5c25f623b88302aea1bd9 Mon Sep 17 00:00:00 2001 From: Muawiz Chaudhary Date: Thu, 2 May 2024 14:04:37 -0400 Subject: [PATCH 47/77] formatting changes --- refactor/rainbow.py | 18 +++++------------- 1 file changed, 5 insertions(+), 13 deletions(-) diff --git a/refactor/rainbow.py b/refactor/rainbow.py index 6fb5744..911642b 100644 --- a/refactor/rainbow.py +++ b/refactor/rainbow.py @@ -6,6 +6,8 @@ from 
pytorch_cifar_utils import set_seeds from conv_modules import FactConv2d + + #traditional way of calculating svd. can be a bit unstable sometimes tho def calc_svd(A, name=''): u, s, vh = torch.linalg.svd(A, full_matrices=False) # (C_in_reference, R), (R,), (R, C_in_generated) @@ -26,16 +28,6 @@ def hook(mod, inputs): return hook -#settings -# our, wa true, aca true (fact, conv) -# our wa false, aca true (fact, conv) -# our wa false aca false (fact, conv) [Just Random] -# our was true aca false (fact, conv) -# -# theirs wa true, aca true (conv) -# theirs wa false, aca true (conv) -# theirs wa false aca false (conv) [Just Random] -# theirs was true aca false (conv) class RainbowSampler: def __init__(self, ref_net, trainloader, seed=0, sampling='structured_alignment', wa=True, in_wa=True, aca=True, device=None, num_classes=10, verbose=True): self.ref_net = copy.deepcopy(ref_net) @@ -48,8 +40,7 @@ def __init__(self, ref_net, trainloader, seed=0, sampling='structured_alignment' self.aca = aca self.device = torch.get_default_device() if device is None else torch.device(device) self.num_classes = num_classes - logging.basicConfig(level=logging.INFO if verbose else logging.WARNING, - format='%(message)s') + logging.basicConfig(level=logging.INFO if verbose else logging.WARNING, format='%(message)s') def sample(self): set_seeds(self.seed) @@ -158,6 +149,7 @@ def run_forward(self, calc_covar=False): def conv_ACA(self, m1, m2, new_module): logging.info("Convolutional Input Activations Alignment") self.activation = [] + # this hook grabs the input activations of the conv layer # rearanges the vector so that the width by height dim is # additional samples to the covariance @@ -171,6 +163,7 @@ def store_hook(mod, inputs, outputs): self.activation.append(x) raise Exception("Done") return store_hook + hook_handle_1 = m1.register_forward_hook(define_hook(m1)) hook_handle_2 = m2.register_forward_hook(define_hook(m2)) logging.info("Starting Sample Cross-Covariance Calculation") @@ 
-296,7 +289,6 @@ def weight_Alignment_With_CC(self, m1, m2, new_module): new_module.load_state_dict(loading_sd) return new_module - # this function does not do an explicit specification of the colored covariance @torch.no_grad() def colored_Covariance_Specification(self, m1, m2, new_module): logging.info("Colored Covariance") From 992e98b4ac6d7026ced2122078001dd7937a0dd8 Mon Sep 17 00:00:00 2001 From: Muawiz Chaudhary Date: Thu, 2 May 2024 14:09:02 -0400 Subject: [PATCH 48/77] formatting --- refactor/models/function_utils.py | 1 - refactor/pytorch_cifar_utils.py | 2 -- 2 files changed, 3 deletions(-) diff --git a/refactor/models/function_utils.py b/refactor/models/function_utils.py index 6614ec0..c642039 100644 --- a/refactor/models/function_utils.py +++ b/refactor/models/function_utils.py @@ -83,7 +83,6 @@ def _replace_layers_scale(module): return recurse_preorder(model, _replace_layers_scale) - #used in activation cross-covariance calculation #input align hook def return_hook(): diff --git a/refactor/pytorch_cifar_utils.py b/refactor/pytorch_cifar_utils.py index ef51ed9..9bb3a6c 100644 --- a/refactor/pytorch_cifar_utils.py +++ b/refactor/pytorch_cifar_utils.py @@ -21,11 +21,9 @@ def set_seeds(seed): term_width = 80 - TOTAL_BAR_LENGTH = 65. 
last_time = time.time() begin_time = last_time - def progress_bar(current, total, msg=None): global last_time, begin_time if current == 0: From 0d0d60f89a18eb26bf41ea15c6f59f89dae766a1 Mon Sep 17 00:00:00 2001 From: Muawiz Chaudhary Date: Thu, 2 May 2024 14:18:49 -0400 Subject: [PATCH 49/77] added README.md --- refactor/README.md | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) create mode 100644 refactor/README.md diff --git a/refactor/README.md b/refactor/README.md new file mode 100644 index 0000000..5acf0e8 --- /dev/null +++ b/refactor/README.md @@ -0,0 +1,22 @@ +# Learning and Aligning Structured Random Feature Networks + +Our Factorized Covariance module is located in `conv_modules.py` and can be called similarly to a nn.Conv2d module, like so `m = FactConv2d(in_channels=3, out_channels=32, kernel_size=(3,3))`. + +V1 initalization based on the receptive V1 field of mice is located in `V1_covariance.py` and can be called like so; `V1_init(m, size=2, spatial_freq=0.1, scale=1, center=center)` where `center=((m.kernel_size[0]-1)/2, (m.kernel_size[1]-1)/2)`. + +Rainbow sampling can be done with our `RainbowSampler` class object as so: + +``` +from rainbow import RainbowSampler +R = RainbowSampler(net, trainloader) +rainbow_net = R.sample() +``` + +To use our factorized ResNet in our rainbow sampling procedure as outlined in "Learning and Aligning Structured Random Feature Networks" by White et al., specify `RainbowSampler(..., sampling='structured_alignment', wa=True, in_wa=True, aca=True)`. This can be specified for both FactConv2d and nn.Conv2d modules. + +To do the rainbow sampling procedure of "A Rainbow in Deep Network Black Boxes" by Guth et al., specify `RainbowSampler(..., sampling='cc_specification', wa=False, aca=True)`. This is specified specifically for networks using nn.Conv2d modules. + +We provide our trained ResNets and Fact-Conv variants in this google drive link. 
+ +Run `python3 setup.py install` to install the Factored Covariance module + From 2b48a03f44660b6a998dfa34ab4c933c278c4af4 Mon Sep 17 00:00:00 2001 From: Muawiz Chaudhary Date: Thu, 2 May 2024 14:25:28 -0400 Subject: [PATCH 50/77] added citation to README.md --- refactor/README.md | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/refactor/README.md b/refactor/README.md index 5acf0e8..3c32ebd 100644 --- a/refactor/README.md +++ b/refactor/README.md @@ -20,3 +20,19 @@ We provide our trained ResNets and Fact-Conv variants in this google drive link. Run `python3 setup.py install` to install the Factored Covariance module + +# Cite Us + +If you found this repository or our paper helpful, please cite us as shown below: + +```bibtex +@inproceedings{ +white2024learning, +title={Learning and Aligning Structured Random Feature Networks}, +author={Vivian White and Muawiz Sajjad Chaudhary and Guy Wolf and Guillaume Lajoie and Kameron Decker Harris}, +booktitle={ICLR 2024 Workshop on Representational Alignment}, +year={2024}, +url={https://openreview.net/forum?id=vWhUQXQoFF} +} +``` + From 9dc5666e65ae5adc7ec38cd466509d07bd03957d Mon Sep 17 00:00:00 2001 From: Muawiz Sajjad Chaudhary <39755015+MuawizChaudhary@users.noreply.github.com> Date: Thu, 2 May 2024 14:40:34 -0400 Subject: [PATCH 51/77] Update README.md --- refactor/README.md | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/refactor/README.md b/refactor/README.md index 3c32ebd..ef6d557 100644 --- a/refactor/README.md +++ b/refactor/README.md @@ -1,9 +1,14 @@ # Learning and Aligning Structured Random Feature Networks +This is the repo for the Learning and Aligning Structured Random Feature Networks paper, accepted at the ICLR Re-Align 2024 workshop. 
+## Factorized Random Feature Convolutions Our Factorized Covariance module is located in `conv_modules.py` and can be called similarly to a nn.Conv2d module, like so `m = FactConv2d(in_channels=3, out_channels=32, kernel_size=(3,3))`. V1 initalization based on the receptive V1 field of mice is located in `V1_covariance.py` and can be called like so; `V1_init(m, size=2, spatial_freq=0.1, scale=1, center=center)` where `center=((m.kernel_size[0]-1)/2, (m.kernel_size[1]-1)/2)`. +Run `python3 setup.py install` to install the Factored Covariance module + +## Rainbow Sampling Rainbow sampling can be done with our `RainbowSampler` class object as so: ``` @@ -18,9 +23,6 @@ To do the rainbow sampling procedure of "A Rainbow in Deep Network Black Boxes" We provide our trained ResNets and Fact-Conv variants in this google drive link. -Run `python3 setup.py install` to install the Factored Covariance module - - # Cite Us If you found this repository or our paper helpful, please cite us as shown below: From 01b39c8cc6f8d7440d30b253ca1e7b6d09c763c1 Mon Sep 17 00:00:00 2001 From: vivianwhite Date: Fri, 3 May 2024 17:30:00 -0400 Subject: [PATCH 52/77] renaming folder --- FactConv/pytorch_cifar.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/FactConv/pytorch_cifar.py b/FactConv/pytorch_cifar.py index 893f38a..60263c3 100644 --- a/FactConv/pytorch_cifar.py +++ b/FactConv/pytorch_cifar.py @@ -14,12 +14,18 @@ from models import define_models def save_model(args, model): - src= "../saved-models/ResNets/" + src= "/home/mila/v/vivian.white/scratch/v1-models/saved-models/test_refactor/" model_dir = src + args.name os.makedirs(model_dir, exist_ok=True) + os.chdir(model_dir) - torch.save(model.state_dict(), model_dir+ "/model.pt") - torch.save(args, model_dir+ "/args.pt") + #saves loss & accuracy in the trial directory -- all trials + trial_dir = model_dir + "/trial_" + str(1) + os.makedirs(trial_dir, exist_ok=True) + os.chdir(trial_dir) + + 
torch.save(model.state_dict(), trial_dir+ "/model.pt") + torch.save(args, trial_dir+ "/args.pt") parser = argparse.ArgumentParser(description='PyTorch CIFAR10 Training') @@ -76,7 +82,7 @@ def save_model(args, model): set_seeds(args.seed) net = net.to(device) -wandb_dir = "../../wandb" +wandb_dir = "/home/mila/v/vivian.white/scratch/v1-models/wandb" os.makedirs(wandb_dir, exist_ok=True) os.chdir(wandb_dir) From c92f6ac4727da2c3068d86c124c2f92070a2a3fd Mon Sep 17 00:00:00 2001 From: vivianwhite <66977221+vivianwhite@users.noreply.github.com> Date: Fri, 3 May 2024 14:39:32 -0700 Subject: [PATCH 53/77] Delete layers directory --- RSN_experiments/LearnableCov.py | 201 -------------------------------- 1 file changed, 201 deletions(-) delete mode 100644 RSN_experiments/LearnableCov.py diff --git a/RSN_experiments/LearnableCov.py b/RSN_experiments/LearnableCov.py deleted file mode 100644 index dc1fd07..0000000 --- a/RSN_experiments/LearnableCov.py +++ /dev/null @@ -1,201 +0,0 @@ -import torch -from torch import Tensor -import torch.nn as nn -from torch.nn.parameter import Parameter, UninitializedParameter -from torch.nn.common_types import _size_2_t -from typing import Optional, List, Tuple, Union - -class Linear(nn.Module): - in_features: int - out_features: int - weight: torch.Tensor - - def __init__(self, in_features: int, out_features: int, bias: bool = True, - device=None, dtype=None) -> None: - factory_kwargs = {'device': device, 'dtype': dtype} - super().__init__() - - self.in_features = in_features - self.out_features = out_features - self.factory_kwargs = factory_kwargs - - self.register_buffer("weight", - torch.empty((out_features, in_features), - **factory_kwargs)) - - triu_len = torch.triu_indices(in_features, in_features).shape[1] - self.tri_vec = Parameter(torch.empty((triu_len,), **factory_kwargs)) - if bias: - self.bias = Parameter(torch.empty(out_features, **factory_kwargs)) - else: - self.register_parameter('bias', None) - self.reset_parameters() - - def 
reset_parameters(self) -> None: - nn.init.constant_(self.tri_vec, 0.) - # Setting a=sqrt(5) in kaiming_uniform is the same as initializing with - # uniform(-1/sqrt(in_features), 1/sqrt(in_features)). For details, see - # https://github.com/pytorch/pytorch/issues/57109 - nn.init.kaiming_normal_(self.weight, a=math.sqrt(5)) - if self.bias is not None: - fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight) - bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 - nn.init.uniform_(self.bias, -bound, bound) - - def forward(self, input: Tensor) -> Tensor: - U = torch.zeros((self.in_features, self.in_features), - **self.factory_kwargs) - U[torch.triu_indices(self.in_features, self.in_features).tolist()] \ - = self.tri_vec - exp_diag = torch.exp(torch.diagonal(U)) - U[range(self.in_features), range(self.in_features)] = exp_diag - composite_weight = self.weight @ U - - return F.linear(input, composite_weight, self.bias) - -class Conv2d(nn.Conv2d): - def __init__( - self, - in_channels: int, - out_channels: int, - kernel_size: _size_2_t, - stride: _size_2_t = 1, - padding: Union[str, _size_2_t] = 0, - dilation: _size_2_t = 1, - groups: int = 1, - bias: bool = True, - padding_mode: str = 'zeros', # TODO: refine this type - device=torch.device('cuda' if torch.cuda.is_available() else 'cpu'), - dtype=None - ) -> None: - # init as Conv2d - super().__init__( - in_channels, out_channels, kernel_size, stride, padding, dilation, - groups, bias, padding_mode, device, dtype) - - factory_kwargs = {'device': device, 'dtype': dtype} - print("Device: ", device) - self.factory_kwargs = factory_kwargs - - # weight shape: (out_channels, in_channels // groups, *kernel_size) - weight_shape = self.weight.shape - del self.weight # remove Parameter, create buffer - self.register_buffer("weight", torch.empty(weight_shape, **factory_kwargs)) - nn.init.kaiming_normal_(self.weight) - - self.in_features = self.in_channels // self.groups * \ - self.kernel_size[0] * self.kernel_size[1] - triu_len 
= torch.triu_indices(self.in_features, - self.in_features).shape[1] - self.tri_vec = Parameter(torch.zeros((triu_len,), **factory_kwargs)) - - def forward(self, input: Tensor) -> Tensor: - U = torch.zeros((self.in_features, self.in_features), - **self.factory_kwargs) - U[torch.triu_indices(self.in_features, self.in_features).tolist()] \ - = self.tri_vec - exp_diag = torch.exp(torch.diagonal(U)) - U[range(self.in_features), range(self.in_features)] = exp_diag - - matrix_shape = (self.out_channels, self.in_features) - composite_weight = torch.reshape( - torch.reshape(self.weight, matrix_shape) @ U, - self.weight.shape - ) - - return self._conv_forward(input, composite_weight, self.bias) - -class FactConv2d(nn.Conv2d): - def __init__( - self, - in_channels: int, - out_channels: int, - kernel_size: _size_2_t, - stride: _size_2_t = 1, - padding: Union[str, _size_2_t] = 0, - dilation: _size_2_t = 1, - groups: int = 1, - bias: bool = True, - padding_mode: str = 'zeros', # TODO: refine this type - device=torch.device('cuda' if torch.cuda.is_available() else 'cpu'), - dtype=None - ) -> None: - # init as Conv2d - super().__init__( - in_channels, out_channels, kernel_size, stride, padding, dilation, - groups, bias, padding_mode, device, dtype) - - factory_kwargs = {'device': device, 'dtype': dtype} - self.factory_kwargs = factory_kwargs - - # weight shape: (out_channels, in_channels // groups, *kernel_size) - weight_shape = self.weight.shape - del self.weight # remove Parameter, create buffer - self.register_buffer("weight", torch.empty(weight_shape, **factory_kwargs)) - nn.init.kaiming_normal_(self.weight) - - self.in_features = self.in_channels // self.groups * \ - self.kernel_size[0] * self.kernel_size[1] - triu1_len = torch.triu_indices(self.in_channels // self.groups, - self.in_channels // self.groups).shape[1] - triu2_len = torch.triu_indices(self.kernel_size[0] * self.kernel_size[1], - self.kernel_size[0] * self.kernel_size[1]).shape[1] - self.tri1_vec = 
Parameter(torch.zeros((triu1_len,), **factory_kwargs)) - self.tri2_vec = Parameter(torch.zeros((triu2_len,), **factory_kwargs)) - - def forward(self, input: Tensor) -> Tensor: - U1 = self._tri_vec_to_mat(self.tri1_vec, self.in_channels // self.groups) - U2 = self._tri_vec_to_mat(self.tri2_vec, self.kernel_size[0] * self.kernel_size[1]) - U = torch.kron(U1, U2) - U = self._exp_diag(U) - - matrix_shape = (self.out_channels, self.in_features) - composite_weight = torch.reshape( - torch.reshape(self.weight, matrix_shape) @ U, - self.weight.shape - ) - - return self._conv_forward(input, composite_weight, self.bias) - - def _tri_vec_to_mat(self, vec, n): - U = torch.zeros((n, n), **self.factory_kwargs) - U[torch.triu_indices(n, n, **self.factory_kwargs).tolist()] = vec - # TODO(kamdh): experiment with this placement versus after kron - # U = self._exp_diag(U) - return U - - def _exp_diag(self, mat): - exp_diag = torch.exp(torch.diagonal(mat)) - n = mat.shape[0] - mat[range(n), range(n)] = exp_diag - return mat - -def V1_init(layer, size, spatial_freq, center, scale=1., bias=False, seed=None, - device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')): - ''' - Initialization for FactConv2d - ''' - import sys - sys.path.insert(0, '/research/harris/vivian/structured_random_features/') - from src.models.weights import V1_covariance_matrix - - classname = layer.__class__.__name__ - assert classname.find('FactConv2d') != -1, 'This init only works for FactConv2d layers' - assert center is not None, "center needed" - - out_channels, in_channels, xdim, ydim = layer.weight.shape - dim = (xdim, ydim) - - C_patch = Tensor(V1_covariance_matrix(dim, size, spatial_freq, center, scale)).to(device) - U_patch = torch.linalg.cholesky(C_patch, upper=True) - n = U_patch.shape[0] - # replace diagonal with logarithm for parameterization - log_diag = torch.log(torch.diagonal(U_patch)) - U_patch[range(n), range(n)] = log_diag - # form vector of upper triangular entries - tri_vec = 
U_patch[torch.triu_indices(n, n, device=device).tolist()].ravel() - with torch.no_grad(): - layer.tri2_vec.copy_(tri_vec) - - if bias == False: - layer.bias = None From e05bbcd9ad44173c91e510c2f841dc2d383de358 Mon Sep 17 00:00:00 2001 From: vivianwhite <66977221+vivianwhite@users.noreply.github.com> Date: Fri, 3 May 2024 15:28:19 -0700 Subject: [PATCH 54/77] Delete refactor directory has been moved to FactConv directory --- refactor/README.md | 40 ---- refactor/V1_covariance.py | 99 ---------- refactor/conv_modules.py | 87 --------- refactor/models/__init__.py | 19 -- refactor/models/function_utils.py | 193 ------------------- refactor/models/resnet.py | 132 ------------- refactor/pytorch_cifar.py | 160 ---------------- refactor/pytorch_cifar_utils.py | 101 ---------- refactor/rainbow.py | 308 ------------------------------ refactor/refactor_rainbow.py | 222 --------------------- 10 files changed, 1361 deletions(-) delete mode 100644 refactor/README.md delete mode 100644 refactor/V1_covariance.py delete mode 100644 refactor/conv_modules.py delete mode 100644 refactor/models/__init__.py delete mode 100644 refactor/models/function_utils.py delete mode 100644 refactor/models/resnet.py delete mode 100644 refactor/pytorch_cifar.py delete mode 100644 refactor/pytorch_cifar_utils.py delete mode 100644 refactor/rainbow.py delete mode 100644 refactor/refactor_rainbow.py diff --git a/refactor/README.md b/refactor/README.md deleted file mode 100644 index ef6d557..0000000 --- a/refactor/README.md +++ /dev/null @@ -1,40 +0,0 @@ -# Learning and Aligning Structured Random Feature Networks -This is the repo for the Learning and Aligning Structured Random Feature Networks paper, accepted at the ICLR Re-Align 2024 workshop. - -## Factorized Random Feature Convolutions -Our Factorized Covariance module is located in `conv_modules.py` and can be called similarly to a nn.Conv2d module, like so `m = FactConv2d(in_channels=3, out_channels=32, kernel_size=(3,3))`. 
- -V1 initalization based on the receptive V1 field of mice is located in `V1_covariance.py` and can be called like so; `V1_init(m, size=2, spatial_freq=0.1, scale=1, center=center)` where `center=((m.kernel_size[0]-1)/2, (m.kernel_size[1]-1)/2)`. - -Run `python3 setup.py install` to install the Factored Covariance module - -## Rainbow Sampling -Rainbow sampling can be done with our `RainbowSampler` class object as so: - -``` -from rainbow import RainbowSampler -R = RainbowSampler(net, trainloader) -rainbow_net = R.sample() -``` - -To use our factorized ResNet in our rainbow sampling procedure as outlined in "Learning and Aligning Structured Random Feature Networks" by White et al., specify `RainbowSampler(..., sampling='structured_alignment', wa=True, in_wa=True, aca=True)`. This can be specified for both FactConv2d and nn.Conv2d modules. - -To do the rainbow sampling procedure of "A Rainbow in Deep Network Black Boxes" by Guth et al., specify `RainbowSampler(..., sampling='cc_specification', wa=False, aca=True)`. This is specified specifically for networks using nn.Conv2d modules. - -We provide our trained ResNets and Fact-Conv variants in this google drive link. 
- -# Cite Us - -If you found this repository or our paper helpful, please cite us as shown below: - -```bibtex -@inproceedings{ -white2024learning, -title={Learning and Aligning Structured Random Feature Networks}, -author={Vivian White and Muawiz Sajjad Chaudhary and Guy Wolf and Guillaume Lajoie and Kameron Decker Harris}, -booktitle={ICLR 2024 Workshop on Representational Alignment}, -year={2024}, -url={https://openreview.net/forum?id=vWhUQXQoFF} -} -``` - diff --git a/refactor/V1_covariance.py b/refactor/V1_covariance.py deleted file mode 100644 index 936142c..0000000 --- a/refactor/V1_covariance.py +++ /dev/null @@ -1,99 +0,0 @@ -""" -The following code is copied from the Structured Random Features library, -https://github.com/glomerulus-lab/structured-random-features, -used under the following license: - -The MIT License (MIT) -Copyright (c) 2021, Biraj Pandey - -Permission is hereby granted, free of charge, to any person obtaining a copy of -this software and associated documentation files (the "Software"), to deal in -the Software without restriction, including without limitation the rights to -use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies -of the Software, and to permit persons to whom the Software is furnished to do -so, subject to the following conditions: - - The above copyright notice and this permission notice shall be included in - all copies or substantial portions of the Software. - -""" - -import torch -from torch import Tensor -import numpy as np -import numpy.linalg as la -from scipy.spatial.distance import pdist, squareform - - -def V1_covariance_matrix(dim, size, spatial_freq, center, scale=1): - """ - Generates the covariance matrix for Gaussian Process with non-stationary - covariance. This matrix will be used to generate random - features inspired from the receptive-fields of V1 neurons. 
- - C(x, y) = exp(-|x - y|/(2 * spatial_freq))^2 * exp(-|x - m| / (2 * size))^2 * exp(-|y - m| / (2 * size))^2 - - Parameters - ---------- - - dim : tuple of shape (2, 1) - Dimension of random features. - - size : float - Determines the size of the random weights - - spatial_freq : float - Determines the spatial frequency of the random weights - - center : tuple of shape (2, 1) - Location of the center of the random weights. - - scale: float, default=1 - Normalization factor for Tr norm of cov matrix - - Returns - ------- - - C : array-like of shape (dim[0] * dim[1], dim[0] * dim[1]) - covariance matrix w/ Tr norm = scale * dim[0] * dim[1] - """ - - x = np.arange(dim[0]) - y = np.arange(dim[1]) - yy, xx = np.meshgrid(y, x) - grid = np.column_stack((xx.flatten(), yy.flatten())) - - a = squareform(pdist(grid, 'sqeuclidean')) - b = la.norm(grid - center, axis=1) ** 2 - c = b.reshape(-1, 1) - C = np.exp(-a / (2 * spatial_freq ** 2)) * np.exp(-b / (2 * size ** 2)) * np.exp(-c / (2 * size ** 2)) \ - + 1e-5 * np.eye(dim[0] * dim[1]) - C *= scale * dim[0] * dim[1] / np.trace(C) - return C - - -def V1_init(layer, size, spatial_freq, center, scale=1., bias=False, seed=None): - ''' - Initialization for FactConv2d - ''' - - classname = layer.__class__.__name__ - assert classname.find('FactConv2d') != -1, 'This init only works for FactConv2d layers' - assert center is not None, "center needed" - - out_channels, in_channels, xdim, ydim = layer.weight.shape - dim = (xdim, ydim) - - C_patch = Tensor(V1_covariance_matrix(dim, size, spatial_freq, center, scale)).to(layer.weight.device) - U_patch = torch.linalg.cholesky(C_patch, upper=True) - n = U_patch.shape[0] - # replace diagonal with logarithm for parameterization - log_diag = torch.log(torch.diagonal(U_patch)) - U_patch[range(n), range(n)] = log_diag - # form vector of upper triangular entries - tri_vec = U_patch[torch.triu_indices(n, n, device=layer.weight.device).tolist()].ravel() - with torch.no_grad(): - 
layer.tri2_vec.copy_(tri_vec) - - if bias == False: - layer.bias = None diff --git a/refactor/conv_modules.py b/refactor/conv_modules.py deleted file mode 100644 index 5274173..0000000 --- a/refactor/conv_modules.py +++ /dev/null @@ -1,87 +0,0 @@ -import torch -from torch import Tensor -import torch.nn as nn -from torch.nn.parameter import Parameter -from torch.nn.common_types import _size_2_t -from typing import Optional, List, Tuple, Union - -""" -The function below is copied directly from -https://bonnerlab.github.io/ccn-tutorial/pages/analyzing_neural_networks.html -""" -def _contract(tensor, matrix, axis): - """tensor is (..., D, ...), matrix is (P, D), returns (..., P, ...).""" - t = torch.moveaxis(tensor, source=axis, destination=-1) # (..., D) - r = t @ matrix.T # (..., P) - return torch.moveaxis(r, source=-1, destination=axis) # (..., P, ...) - - -class FactConv2d(nn.Conv2d): - def __init__( - self, - in_channels: int, - out_channels: int, - kernel_size: _size_2_t, - stride: _size_2_t = 1, - padding: Union[str, _size_2_t] = 0, - dilation: _size_2_t = 1, - groups: int = 1, - bias: bool = True, - padding_mode: str = 'zeros', # TODO: refine this type - device=None, - dtype=None - ) -> None: - # init as Conv2d - super().__init__( - in_channels, out_channels, kernel_size, stride, padding, dilation, - groups, bias, padding_mode, device, dtype) - # weight shape: (out_channels, in_channels // groups, *kernel_size) - new_weight = torch.empty_like(self.weight) - del self.weight # remove Parameter, create buffer - self.register_buffer("weight", new_weight) - nn.init.kaiming_normal_(self.weight) - - self.in_features = self.in_channels // self.groups * \ - self.kernel_size[0] * self.kernel_size[1] - triu1 = torch.triu_indices(self.in_channels // self.groups, - self.in_channels // self.groups, - device=self.weight.device, - dtype=torch.long) - scat_idx1 = triu1[0]*self.in_channels//self.groups + triu1[1] - self.register_buffer("scat_idx1", scat_idx1, persistent=False) - 
- triu2 = torch.triu_indices(self.kernel_size[0] * self.kernel_size[1], - self.kernel_size[0] - * self.kernel_size[1], - device=self.weight.device, - dtype=torch.long) - scat_idx2 = triu2[0]*self.kernel_size[0]*self.kernel_size[1] + triu2[1] - self.register_buffer("scat_idx2", scat_idx2, persistent=False) - - triu1_len = triu1.shape[1] - triu2_len = triu2.shape[1] - - tri1_vec = self.weight.new_zeros((triu1_len,)) - self.tri1_vec = Parameter(tri1_vec) - - tri2_vec = self.weight.new_zeros((triu2_len,)) - self.tri2_vec = Parameter(tri2_vec) - - - def forward(self, input: Tensor) -> Tensor: - U1 = self._tri_vec_to_mat(self.tri1_vec, self.in_channels // - self.groups, self.scat_idx1) - U2 = self._tri_vec_to_mat(self.tri2_vec, self.kernel_size[0] * self.kernel_size[1], - self.scat_idx2) - # flatten over filter dims and contract - composite_weight = _contract(self.weight, U1.T, 1) - composite_weight = _contract( - torch.flatten(composite_weight, -2, -1), U2.T, -1 - ).reshape(self.weight.shape) - return self._conv_forward(input, composite_weight, self.bias) - - def _tri_vec_to_mat(self, vec, n, scat_idx): - U = self.weight.new_zeros((n*n)).scatter_(0, scat_idx, vec).view(n, n) - U = torch.diagonal_scatter(U, U.diagonal().exp_()) - return U - diff --git a/refactor/models/__init__.py b/refactor/models/__init__.py deleted file mode 100644 index 434ce4f..0000000 --- a/refactor/models/__init__.py +++ /dev/null @@ -1,19 +0,0 @@ -from .resnet import ResNet18 -from .function_utils import replace_layers_factconv2d, turn_off_covar_grad, replace_layers_scale, init_V1_layers - - -def define_models(args): - if 'resnet18' in args.net: - model = ResNet18() - if args.width != 1: - replace_layers_scale(model, args.width) - if 'fact' in args.net: - replace_layers_factconv2d(model) - if "v1" in args.net: - init_V1_layers(model, bias=False) - if "us" in args.net: - turn_off_covar_grad(model, "spatial") - if "uc" in args.net: - turn_off_covar_grad(model, "channel") - - return model diff --git 
a/refactor/models/function_utils.py b/refactor/models/function_utils.py deleted file mode 100644 index c642039..0000000 --- a/refactor/models/function_utils.py +++ /dev/null @@ -1,193 +0,0 @@ -import torch -import torch.nn as nn -from conv_modules import FactConv2d -from V1_covariance import V1_init - - -def recurse_preorder(model, callback): - r = callback(model) - if r is not model and r is not None: - return r - for n, module in model.named_children(): - r = recurse_preorder(module, callback) - if r is not module and r is not None: - setattr(model, n, r) - return model - - -def replace_layers_factconv2d(model): - ''' - Replace nn.Conv2d layers with FactConv2d - ''' - def _replace_layers_factconv2d(module): - if isinstance(module, nn.Conv2d): - ## simple module - new_module = FactConv2d( - in_channels=module.in_channels, - out_channels=module.out_channels, - kernel_size=module.kernel_size, - stride=module.stride, padding=module.padding, - bias=True if module.bias is not None else False) - old_sd = module.state_dict() - new_sd = new_module.state_dict() - new_sd['weight'] = old_sd['weight'] - if module.bias is not None: - new_sd['bias'] = old_sd['bias'] - new_module.load_state_dict(new_sd) - return new_module - return recurse_preorder(model, _replace_layers_factconv2d) - - -def replace_affines(model): - ''' - Set BatchNorm2d layers to have 'affine=False' - ''' - def _replace_affines(module): - if isinstance(module, nn.BatchNorm2d): - ## simple module - new_module = nn.BatchNorm2d( - num_features=module.num_features, - eps=module.eps, momentum=module.momentum, - affine=False, - track_running_stats=module.track_running_stats) - return new_module - return recurse_preorder(model, _replace_affines) - - -def replace_layers_scale(model, scale=1): - ''' - Replace nn.Conv2d layers with a different scale - ''' - def _replace_layers_scale(module): - if isinstance(module, nn.Conv2d): - if module.in_channels == 3: - in_scale = 1 - else: - in_scale = scale - ## simple module - 
new_module = nn.Conv2d( - in_channels=int(module.in_channels*in_scale), - out_channels=int(module.out_channels*scale), - kernel_size=module.kernel_size, - stride=module.stride, padding=module.padding, - groups = module.groups, - bias=True if module.bias is not None else False) - return new_module - if isinstance(module, nn.BatchNorm2d): - new_module = nn.BatchNorm2d(int(module.num_features*scale), - affine=module.affine) - return new_module - if isinstance(module, nn.Linear): - new_module = nn.Linear(int(module.in_features * scale), 10) - return new_module - return recurse_preorder(model, _replace_layers_scale) - - -#used in activation cross-covariance calculation -#input align hook -def return_hook(): - def hook(mod, inputs): - shape = inputs[0].shape - inputs_permute = inputs[0].permute(1,0,2,3).reshape(inputs[0].shape[1], -1) - reshape = (mod.input_align@inputs_permute).reshape(shape[1], - shape[0], shape[2], - shape[3]).permute(1, 0, 2, 3) - return reshape - return hook - - -def replace_layers_fact_with_conv(model): - ''' - Replace FactConv2d layers with nn.Conv2d - ''' - def _replace_layers_fact_with_conv(module): - if isinstance(module, FactConv2d): - ## simple module - new_module = nn.Conv2d( - in_channels=module.in_channels, - out_channels=module.out_channels, - kernel_size=module.kernel_size, - stride=module.stride, padding=module.padding, - bias=True if module.bias is not None else False) - old_sd = module.state_dict() - new_sd = new_module.state_dict() - if module.bias is not None: - new_sd['bias'] = old_sd['bias'] - U1 = module._tri_vec_to_mat(module.tri1_vec, module.in_channels // - module.groups,module.scat_idx1) - U2 = module._tri_vec_to_mat(module.tri2_vec, module.kernel_size[0] * module.kernel_size[1], - module.scat_idx2) - U = torch.kron(U1, U2) - matrix_shape = (module.out_channels, module.in_features) - composite_weight = torch.reshape( - torch.reshape(module.weight, matrix_shape) @ U, - module.weight.shape - ) - new_sd['weight'] = 
composite_weight - if 'weight_align' in old_sd.keys(): - new_sd['weight_align'] = old_sd['weight_align'] - shape = new_module.in_channels*new_module.kernel_size[0]*new_module.kernel_size[1] - new_module.register_buffer("weight_align",torch.zeros((shape, shape))) - if 'input_align' in old_sd.keys(): - new_sd['input_align'] = old_sd['input_align'] - out_shape = new_module.in_channels - new_module.register_buffer("input_align",torch.zeros((out_shape, out_shape))) - if module.in_channels != 3: - #fact check this - for key in list(module._forward_pre_hooks.keys()): - del module._forward_pre_hooks[key] - hook_handle_pre_forward = new_module.register_forward_pre_hook(return_hook()) - new_module.load_state_dict(new_sd) - new_module.to(old_sd['weight'].device) - return new_module - return recurse_preorder(model, _replace_layers_fact_with_conv) - - -def turn_off_covar_grad(model, covariance): - ''' - Turn off gradients in tri1_vec or tri2_vec to turn off - channel or spatial covariance learning - ''' - def _turn_off_covar_grad(module): - if isinstance(module, FactConv2d): - for name, param in module.named_parameters(): - if covariance == "channel": - if "tri1_vec" in name: - param.requires_grad = False - if covariance == "spatial": - if "tri2_vec" in name: - param.requires_grad = False - return recurse_preorder(model, _turn_off_covar_grad) - - -def turn_off_backbone_grad(model): - ''' - Turn off gradients in backbone. 
For tuning just classifier layer - ''' - def _turn_off_backbone_grad(module): - if isinstance(module, nn.Linear) and module.out_features == 10: - grad=True - else: - grad=False - for param in module.parameters(): - param.requires_grad = grad - return recurse_preorder(model, _turn_off_backbone_grad) - - -def init_V1_layers(model, bias): - ''' - Initialize every FactConv2d layer with V1-inspired - spatial weight init - ''' - def _init_V1_layers(module): - if isinstance(module, FactConv2d): - center = ((module.kernel_size[0] - 1) / 2, (module.kernel_size[1] - 1) / 2) - V1_init(module, size=2, spatial_freq=0.1, scale=1, center=center) - for name, param in module.named_parameters(): - if "weight" in name: - param.requires_grad = False - if bias: - if "bias" in name: - param.requires_grad = False - return recurse_preorder(model, _init_V1_layers) - diff --git a/refactor/models/resnet.py b/refactor/models/resnet.py deleted file mode 100644 index beb18f9..0000000 --- a/refactor/models/resnet.py +++ /dev/null @@ -1,132 +0,0 @@ -'''ResNet in PyTorch. - -For Pre-activation ResNet, see 'preact_resnet.py'. - -Reference: -[1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun - Deep Residual Learning for Image Recognition. 
arXiv:1512.03385 -''' -import torch -import torch.nn as nn -import torch.nn.functional as F - - -class BasicBlock(nn.Module): - expansion = 1 - - def __init__(self, in_planes, planes, stride=1): - super(BasicBlock, self).__init__() - self.conv1 = nn.Conv2d( - in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) - self.bn1 = nn.BatchNorm2d(planes) - self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, - stride=1, padding=1, bias=False) - self.bn2 = nn.BatchNorm2d(planes) - - self.shortcut = nn.Sequential() - if stride != 1 or in_planes != self.expansion*planes: - self.shortcut = nn.Sequential( - nn.Conv2d(in_planes, self.expansion*planes, - kernel_size=1, stride=stride, bias=False), - nn.BatchNorm2d(self.expansion*planes) - ) - - def forward(self, x): - out = F.relu(self.bn1(self.conv1(x))) - out = self.bn2(self.conv2(out)) - out += self.shortcut(x) - out = F.relu(out) - return out - - -class Bottleneck(nn.Module): - expansion = 4 - - def __init__(self, in_planes, planes, stride=1): - super(Bottleneck, self).__init__() - self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) - self.bn1 = nn.BatchNorm2d(planes) - self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, - stride=stride, padding=1, bias=False) - self.bn2 = nn.BatchNorm2d(planes) - self.conv3 = nn.Conv2d(planes, self.expansion * - planes, kernel_size=1, bias=False) - self.bn3 = nn.BatchNorm2d(self.expansion*planes) - - self.shortcut = nn.Sequential() - if stride != 1 or in_planes != self.expansion*planes: - self.shortcut = nn.Sequential( - nn.Conv2d(in_planes, self.expansion*planes, - kernel_size=1, stride=stride, bias=False), - nn.BatchNorm2d(self.expansion*planes) - ) - - def forward(self, x): - out = F.relu(self.bn1(self.conv1(x))) - out = F.relu(self.bn2(self.conv2(out))) - out = self.bn3(self.conv3(out)) - out += self.shortcut(x) - out = F.relu(out) - return out - - -class ResNet(nn.Module): - def __init__(self, block, num_blocks, num_classes=10): - super(ResNet, 
self).__init__() - self.in_planes = 64 - - self.conv1 = nn.Conv2d(3, 64, kernel_size=3, - stride=1, padding=1, bias=False) - self.bn1 = nn.BatchNorm2d(64) - self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) - self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) - self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) - self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) - self.linear = nn.Linear(512*block.expansion, num_classes) - - def _make_layer(self, block, planes, num_blocks, stride): - strides = [stride] + [1]*(num_blocks-1) - layers = [] - for stride in strides: - layers.append(block(self.in_planes, planes, stride)) - self.in_planes = planes * block.expansion - return nn.Sequential(*layers) - - def forward(self, x): - out = F.relu(self.bn1(self.conv1(x))) - out = self.layer1(out) - out = self.layer2(out) - out = self.layer3(out) - out = self.layer4(out) - out = F.avg_pool2d(out, 4) - out = out.view(out.size(0), -1) - out = self.linear(out) - return out - - -def ResNet18(): - return ResNet(BasicBlock, [2, 2, 2, 2]) - - -def ResNet34(): - return ResNet(BasicBlock, [3, 4, 6, 3]) - - -def ResNet50(): - return ResNet(Bottleneck, [3, 4, 6, 3]) - - -def ResNet101(): - return ResNet(Bottleneck, [3, 4, 23, 3]) - - -def ResNet152(): - return ResNet(Bottleneck, [3, 8, 36, 3]) - - -def test(): - net = ResNet18() - y = net(torch.randn(1, 3, 32, 32)) - print(y.size()) - -#test() diff --git a/refactor/pytorch_cifar.py b/refactor/pytorch_cifar.py deleted file mode 100644 index 22c5489..0000000 --- a/refactor/pytorch_cifar.py +++ /dev/null @@ -1,160 +0,0 @@ -'''Train CIFAR10 with PyTorch.''' -import torch -import torch.nn as nn -import torch.optim as optim -import torch.nn.functional as F -import torch.backends.cudnn as cudnn -import torchvision -import torchvision.transforms as transforms -import os -import argparse -from pytorch_cifar_utils import progress_bar, set_seeds -import wandb -from distutils.util import 
strtobool -from models import define_models - -def save_model(args, model): - src= "/home/mila/v/vivian.white/scratch/v1-models/saved-models/test_refactor/" - model_dir = src + args.name - os.makedirs(model_dir, exist_ok=True) - os.chdir(model_dir) - - #saves loss & accuracy in the trial directory -- all trials - trial_dir = model_dir + "/trial_" + str(1) - os.makedirs(trial_dir, exist_ok=True) - os.chdir(trial_dir) - - torch.save(model.state_dict(), trial_dir+ "/model.pt") - torch.save(args, trial_dir+ "/args.pt") - - -parser = argparse.ArgumentParser(description='PyTorch CIFAR10 Training') -parser.add_argument('--lr', default=0.1, type=float, help='learning rate') -parser.add_argument('--resume', '-r', action='store_true', - help='resume from checkpoint') -parser.add_argument('--net', type=str, default='resnet18', help="which model to use") -parser.add_argument('--num_epochs', type=int, default=200, help='number of trainepochs') -parser.add_argument('--name', type=str, default='TESTING_VGG', - help='filename for saved model') -parser.add_argument('--seed', default=0, type=int, help='seed to use') -parser.add_argument('--width', type=float, default=1, help='resnet width scale factor') - -args = parser.parse_args() - -device = 'cuda' if torch.cuda.is_available() else 'cpu' -best_acc = 0 # best test accuracy -start_epoch = 0 # start from epoch 0 or last checkpoint epoch - -# Data -print('==> Preparing data..') -transform_train = transforms.Compose([ - transforms.RandomCrop(32, padding=4), - transforms.RandomHorizontalFlip(), - transforms.ToTensor(), - transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), -]) - -transform_test = transforms.Compose([ - transforms.ToTensor(), - transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), -]) - -trainset = torchvision.datasets.CIFAR10( - root='./data', train=True, download=True, transform=transform_train) -trainloader = torch.utils.data.DataLoader( - trainset, batch_size=128, 
shuffle=True, num_workers=8) - -testset = torchvision.datasets.CIFAR10( - root='./data', train=False, download=True, transform=transform_test) -testloader = torch.utils.data.DataLoader( - testset, batch_size=1000, shuffle=False, num_workers=8) - -classes = ('plane', 'car', 'bird', 'cat', 'deer', - 'dog', 'frog', 'horse', 'ship', 'truck') - -# Model -print('==> Building model..') - -net = define_models(args) -run_name = args.net -print("Args.net: ", args.net) -print("Net: ", net) -set_seeds(args.seed) - -net = net.to(device) -wandb_dir = "/home/mila/v/vivian.white/scratch/v1-models/wandb" -os.makedirs(wandb_dir, exist_ok=True) -os.chdir(wandb_dir) - -run = wandb.init(project="refactoring", config=args, - group="pytorch_cifar", name=run_name, dir=wandb_dir) -#wandb.watch(net, log='all', log_freq=1) - - -criterion = nn.CrossEntropyLoss() -optimizer = optim.SGD(net.parameters(), lr=args.lr, - momentum=0.9, weight_decay=5e-4) -scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=200) - -# Training -def train(epoch): - print('\nEpoch: %d' % epoch) - net.train() - train_loss = 0 - correct = 0 - total = 0 - for batch_idx, (inputs, targets) in enumerate(trainloader): - inputs, targets = inputs.to(device), targets.to(device) - optimizer.zero_grad() - outputs = net(inputs) - loss = criterion(outputs, targets) - loss.backward() - optimizer.step() - train_loss += loss.item() - _, predicted = outputs.max(1) - total += targets.size(0) - correct += predicted.eq(targets).sum().item() - - progress_bar(batch_idx, len(trainloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)' - % (train_loss/(batch_idx+1), 100.*correct/total, correct, total)) - -def test(epoch): - global best_acc - net.eval() - test_loss = 0 - correct = 0 - total = 0 - with torch.no_grad(): - for batch_idx, (inputs, targets) in enumerate(testloader): - inputs, targets = inputs.to(device), targets.to(device) - outputs = net(inputs) - loss = criterion(outputs, targets) - - test_loss += loss.item() - _, 
predicted = outputs.max(1) - total += targets.size(0) - correct += predicted.eq(targets).sum().item() - - progress_bar(batch_idx, len(testloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)' - % (test_loss/(batch_idx+1), 100.*correct/total, correct, total)) - - # Save checkpoint. - acc = 100.*correct/total - run.log({"accuracy":acc}) - if acc > best_acc: - print('Saving..') - state = { - 'net': net.state_dict(), - 'acc': acc, - 'epoch': epoch, - } - save_model(args, net) - best_acc = acc - - -for epoch in range(start_epoch, start_epoch+args.num_epochs): - train(epoch) - test(epoch) - scheduler.step() -args.name += "final" -save_model(args, net) diff --git a/refactor/pytorch_cifar_utils.py b/refactor/pytorch_cifar_utils.py deleted file mode 100644 index 9bb3a6c..0000000 --- a/refactor/pytorch_cifar_utils.py +++ /dev/null @@ -1,101 +0,0 @@ -'''Some helper functions for PyTorch, including: - - progress_bar: progress bar mimic xlua.progress. -''' -import os -import sys -import time -import math - -import torch -import torch.nn as nn -import torch.nn.init as init -import random -import numpy as np - -def set_seeds(seed): - torch.manual_seed(seed) - torch.backends.cudnn.deterministic = True - torch.backends.cudnn.benchmark = False - random.seed(seed) - np.random.seed(seed) - - -term_width = 80 -TOTAL_BAR_LENGTH = 65. -last_time = time.time() -begin_time = last_time -def progress_bar(current, total, msg=None): - global last_time, begin_time - if current == 0: - begin_time = time.time() # Reset for new bar. 
- - cur_len = int(TOTAL_BAR_LENGTH*current/total) - rest_len = int(TOTAL_BAR_LENGTH - cur_len) - 1 - - sys.stdout.write(' [') - for i in range(cur_len): - sys.stdout.write('=') - sys.stdout.write('>') - for i in range(rest_len): - sys.stdout.write('.') - sys.stdout.write(']') - - cur_time = time.time() - step_time = cur_time - last_time - last_time = cur_time - tot_time = cur_time - begin_time - - L = [] - L.append(' Step: %s' % format_time(step_time)) - L.append(' | Tot: %s' % format_time(tot_time)) - if msg: - L.append(' | ' + msg) - - msg = ''.join(L) - sys.stdout.write(msg) - for i in range(term_width-int(TOTAL_BAR_LENGTH)-len(msg)-3): - sys.stdout.write(' ') - - # Go back to the center of the bar. - for i in range(term_width-int(TOTAL_BAR_LENGTH/2)+2): - sys.stdout.write('\b') - sys.stdout.write(' %d/%d ' % (current+1, total)) - - if current < total-1: - sys.stdout.write('\r') - else: - sys.stdout.write('\n') - sys.stdout.flush() - - -def format_time(seconds): - days = int(seconds / 3600/24) - seconds = seconds - days*3600*24 - hours = int(seconds / 3600) - seconds = seconds - hours*3600 - minutes = int(seconds / 60) - seconds = seconds - minutes*60 - secondsf = int(seconds) - seconds = seconds - secondsf - millis = int(seconds*1000) - - f = '' - i = 1 - if days > 0: - f += str(days) + 'D' - i += 1 - if hours > 0 and i <= 2: - f += str(hours) + 'h' - i += 1 - if minutes > 0 and i <= 2: - f += str(minutes) + 'm' - i += 1 - if secondsf > 0 and i <= 2: - f += str(secondsf) + 's' - i += 1 - if millis > 0 and i <= 2: - f += str(millis) + 'ms' - i += 1 - if f == '': - f = '0ms' - return f diff --git a/refactor/rainbow.py b/refactor/rainbow.py deleted file mode 100644 index 911642b..0000000 --- a/refactor/rainbow.py +++ /dev/null @@ -1,308 +0,0 @@ -import torch -import torch.nn as nn - -import copy -import logging - -from pytorch_cifar_utils import set_seeds -from conv_modules import FactConv2d - - -#traditional way of calculating svd. 
can be a bit unstable sometimes tho -def calc_svd(A, name=''): - u, s, vh = torch.linalg.svd(A, full_matrices=False) # (C_in_reference, R), (R,), (R, C_in_generated) - alignment = u @ vh # (C_in_reference, C_in_generated) - return alignment, (u, s, vh) - - -#used in activation cross-covariance calculation -#input align hook -def return_hook(): - def hook(mod, inputs): - shape = inputs[0].shape - inputs_permute = inputs[0].permute(1,0,2,3).reshape(inputs[0].shape[1], -1) - reshape = (mod.input_align@inputs_permute).reshape(shape[1], - shape[0], shape[2], - shape[3]).permute(1, 0, 2, 3) - return reshape - return hook - - -class RainbowSampler: - def __init__(self, ref_net, trainloader, seed=0, sampling='structured_alignment', wa=True, in_wa=True, aca=True, device=None, num_classes=10, verbose=True): - self.ref_net = copy.deepcopy(ref_net) - self.gen_net = copy.deepcopy(ref_net) - self.trainloader = trainloader - self.seed = seed - self.sampling = sampling - self.wa = wa - self.in_wa = in_wa - self.aca = aca - self.device = torch.get_default_device() if device is None else torch.device(device) - self.num_classes = num_classes - logging.basicConfig(level=logging.INFO if verbose else logging.WARNING, format='%(message)s') - - def sample(self): - set_seeds(self.seed) - self.ref_net.train() - self.gen_net = copy.deepcopy(self.ref_net) - self.gen_net.train() - logging.info("With seed {}".format(self.seed)) - self.our_rainbow_sampling(self.ref_net, self.gen_net) - return self.gen_net - - def load_state_dicts(self, m1, m2, new_module): - # reference model state dict - ref_sd = m1.state_dict() - # generated model state dict - uses reference model weights. for now - gen_sd = m2.state_dict() - - # module with random init - to be loaded to model - loading_sd = new_module.state_dict() - - # carry over old bias. only matters when we work with no batchnorm networks - if m1.bias is not None: - loading_sd['bias'] = ref_sd['bias'] - # carry over old colored covariance. 
only matters with fact-convs - if "tri1_vec" in ref_sd.keys(): - loading_sd['tri1_vec'] = ref_sd['tri1_vec'] - loading_sd['tri2_vec'] = ref_sd['tri2_vec'] - - return ref_sd, gen_sd, loading_sd - - @torch.no_grad() - def our_rainbow_sampling(self, model, new_model): - for (n1, m1), (n2, m2) in zip(model.named_children(), new_model.named_children()): - if len(list(m1.children())) > 0: - self.our_rainbow_sampling(m1, m2) - if isinstance(m1, nn.Conv2d): - if isinstance(m2, FactConv2d): - new_module = FactConv2d( - in_channels=m2.in_channels, - out_channels=m2.out_channels, - kernel_size=m2.kernel_size, - stride=m2.stride, padding=m2.padding, - bias=True if m2.bias is not None else - False).to(self.device) - else: - new_module = nn.Conv2d( - in_channels=int(m2.in_channels), - out_channels=int(m2.out_channels), - kernel_size=m2.kernel_size, - stride=m2.stride, padding=m2.padding, - groups = m2.groups, - bias=True if m2.bias is not None else False).to(self.device) - - if self.sampling == 'structured_alignment' and self.wa: - # right now this function does not do an explicit specification of the colored covariance - new_module = self.weight_Alignment(m1, m2, new_module, in_dim=self.in_wa) - if self.sampling == 'cc_specification': - # for conv only - if self.wa: - new_module = self.weight_Alignment_With_CC(m1, m2, new_module) - else: - new_module = self.colored_Covariance_Specification(m1, m2, new_module) - # this step calculates the activation cross-covariance alignment (ACA) - if m1.in_channels != 3 and self.aca: - new_module = self.conv_ACA(m1, m2, new_module) - # changes the network module - setattr(new_model, n1, new_module) - - #only computes the ACA - if isinstance(m1, nn.Linear) and self.aca: - new_module = self.linear_ACA(m1, m2, new_model) - setattr(new_model, n1, new_module) - ##just run stats through - if isinstance(m1, nn.BatchNorm2d): - self.batchNorm_stats_recalc(m1, m2) - - def run_forward(self, calc_covar=False): - if calc_covar: - covar = None - total = 0 
- for batch_idx, (inputs, targets) in enumerate(self.trainloader): - inputs, targets = inputs.to(self.device), targets.to(self.device) - try: - outputs1 = self.ref_net(inputs) - except Exception: - pass - try: - outputs2 = self.gen_net(inputs) - except Exception: - pass - if calc_covar: - total+= inputs.shape[0] - if covar is None: - #activation is bwh x c - covar = self.activation[0].T @ self.activation[1] - assert (covar.isfinite().all()) - else: - #activation is bwh x c - covar += self.activation[0].T @ self.activation[1] - assert (covar.isfinite().all()) - self.activation = [] - if calc_covar: - #c x c - covar /= total - return covar - - def conv_ACA(self, m1, m2, new_module): - logging.info("Convolutional Input Activations Alignment") - self.activation = [] - - # this hook grabs the input activations of the conv layer - # rearanges the vector so that the width by height dim is - # additional samples to the covariance - # bwh x c - def define_hook(m): - def store_hook(mod, inputs, outputs): - #from bonner lab tutorial - x = inputs[0] - x = x.permute(0, 2, 3, 1) - x = x.reshape((-1, x.shape[-1])) - self.activation.append(x) - raise Exception("Done") - return store_hook - - hook_handle_1 = m1.register_forward_hook(define_hook(m1)) - hook_handle_2 = m2.register_forward_hook(define_hook(m2)) - logging.info("Starting Sample Cross-Covariance Calculation") - covar = self.run_forward(calc_covar=True) - logging.info("Sample Cross-Covariance Calculation finished") - hook_handle_1.remove() - hook_handle_2.remove() - align, _ = calc_svd(covar, name="Cross-Covariance") - new_module.register_buffer("input_align", align) - # this hook takes the input to the conv, aligns, then returns - # to the conv the aligned inputs - hook_handle_pre_forward = new_module.register_forward_pre_hook(return_hook()) - return new_module - - def batchNorm_stats_recalc(self, m1, m2): - logging.info("Calculating Batch Statistics") - m1.train() - m2.train() - m1.reset_running_stats() - 
m2.reset_running_stats() - handle_1 = m1.register_forward_hook(lambda mod, inputs, outputs: Exception("Done")) - handle_2 = m2.register_forward_hook(lambda mod, inputs, outputs: Exception("Done")) - m1.to(self.device) - m2.to(self.device) - self.run_forward(calc_covar=False) - handle_1.remove() - handle_2.remove() - m1.eval() - m2.eval() - logging.info("Batch Statistics Calculation Finished") - - def linear_ACA(self, m1, m2, new_model): - logging.info("Linear Input Activations Alignment") - new_module = nn.Linear(m1.in_features, m1.out_features, bias=True - if m1.bias is not None else False).to(self.device) - ref_sd = m1.state_dict() - loading_sd = new_module.state_dict() - if m1.out_features == self.num_classes: - loading_sd['weight'] = ref_sd['weight'] - if m1.bias is not None: - loading_sd['bias'] = ref_sd['bias'] - self.activation = [] - hook_handle_1 = m1.register_forward_hook(lambda mod, inputs, outputs: - self.activation.append(inputs[0])) - hook_handle_2 = m2.register_forward_hook(lambda mod, inputs, outputs: - self.activation.append(inputs[0])) - logging.info("Starting Sample Cross-Covariance Calculation") - covar = self.run_forward(calc_covar=True) - hook_handle_1.remove() - hook_handle_2.remove() - logging.info("Sample Cross-Covariance Calculation finished") - align, _ = calc_svd(covar, name="Cross-Covariance") - new_weight = loading_sd['weight'] - new_weight = torch.moveaxis(new_weight, source=1, - destination=-1) - new_weight = new_weight@align - loading_sd['weight'] = torch.moveaxis(new_weight, source=-1, - destination=1) - new_module.load_state_dict(loading_sd) - return new_module - - # IF FACT: we align the generated factnet with the reference fact net's noise - # IF CONV: we align the generated convnet with the reference conv net's weight matrix - def weight_Alignment(self, m1, m2, new_module, in_dim=True): - ref_sd, gen_sd, loading_sd = self.load_state_dicts(m1, m2, new_module) - reference_weight = ref_sd['weight'] - generated_weight = 
loading_sd['weight'] - - #reshape to outdim x indim*spatial - reference_weight = reference_weight.reshape(reference_weight.shape[0], -1) - generated_weight = generated_weight.reshape(generated_weight.shape[0], -1) - - #compute weight cross-covariance indim*spatial x indim*spatial - #TODO REFACTOR TO HAVE REF FIRST. OUTDIM x OUTDIM - if in_dim: - logging.info("Input Weight Alignment") - weight_cov = (generated_weight.T@reference_weight) - alignment, _ = calc_svd(weight_cov, name="Weight alignment") - - # outdim x indim*spatial - final_gen_weight = generated_weight@alignment - else: - logging.info("Output Weight Alignment") - weight_cov = (reference_weight@generated_weight.T) - alignment, _ = calc_svd(weight_cov, name="Weight alignment") - - # outdim x indim*spatial - final_gen_weight = alignment@generated_weight - - loading_sd['weight_align'] = alignment - new_module.register_buffer("weight_align", alignment) - - loading_sd['weight'] = final_gen_weight.reshape(ref_sd['weight'].shape) - new_module.load_state_dict(loading_sd) - return new_module - - def weight_Alignment_With_CC(self, m1, m2, new_module): - logging.info("Weight alignment with Colored Covariance") - ref_sd, gen_sd, loading_sd = self.load_state_dicts(m1, m2, new_module) - - old_weight = ref_sd['weight'] - A = old_weight.reshape(old_weight.shape[0], -1) - - _, (Un, Sn, Vn) = calc_svd(A) - white_gaussian = torch.randn_like(Un) - - copy_weight = Un - copy_weight_gen = white_gaussian - copy_weight = copy_weight.reshape(copy_weight.shape[0], -1) - copy_weight_gen = copy_weight_gen.reshape(copy_weight_gen.shape[0], -1).T - weight_cov = (copy_weight_gen@copy_weight) - - alignment, _ = calc_svd(weight_cov, name="Weight") - new_weight = white_gaussian - new_weight = new_weight.reshape(new_weight.shape[0], -1) - new_weight = new_weight@alignment # C_in_reference to C_in_generated - - new_module.register_buffer("weight_align", alignment) - loading_sd['weight_align'] = alignment - colored_gaussian = white_gaussian @ 
(Sn[:,None]* Vn) - - loading_sd['weight'] = colored_gaussian.reshape(old_weight.shape) - new_module.load_state_dict(loading_sd) - return new_module - - @torch.no_grad() - def colored_Covariance_Specification(self, m1, m2, new_module): - logging.info("Colored Covariance") - ref_sd, gen_sd, loading_sd = self.load_state_dicts(m1, m2, new_module) - - old_weight = ref_sd['weight'] - A = old_weight.reshape(old_weight.shape[0], -1) - - _, (Un, Sn, Vn) = calc_svd(A) - white_gaussian = torch.randn_like(Un) - colored_gaussian = white_gaussian @ (Sn[:,None]* Vn) - - loading_sd['weight'] = colored_gaussian.reshape(old_weight.shape) - new_module.load_state_dict(loading_sd) - return new_module - - diff --git a/refactor/refactor_rainbow.py b/refactor/refactor_rainbow.py deleted file mode 100644 index 96bcfcc..0000000 --- a/refactor/refactor_rainbow.py +++ /dev/null @@ -1,222 +0,0 @@ -'''Train CIFAR10 with PyTorch.''' -import torch -import torch.nn as nn -import torch.optim as optim -import torch.nn.functional as F -import torch.backends.cudnn as cudnn -import torchvision -import torchvision.transforms as transforms - -import wandb -import numpy as np - -import time -import os -import argparse -import copy -import gc -from distutils.util import strtobool - -from pytorch_cifar_utils import progress_bar, set_seeds -from models.resnet import ResNet18 -from conv_modules import FactConv2d -from models.function_utils import replace_layers_factconv2d,\ -replace_layers_scale, replace_layers_fact_with_conv, turn_off_backbone_grad, \ -recurse_preorder -from rainbow import RainbowSampler - -def save_model(args, model): - src = "/home/mila/m/muawiz.chaudhary/scratch/v1-models/saved-models/refactoring/" - model_dir = src + args.name - os.makedirs(model_dir, exist_ok=True) - os.chdir(model_dir) - - #saves loss & accuracy in the trial directory -- all trials - - torch.save(model.state_dict(), model_dir+ "/model.pt") - torch.save(args, model_dir+ "/args.pt") - - -parser = 
argparse.ArgumentParser(description='PyTorch CIFAR10 Training') -parser.add_argument('--lr', default=0.1, type=float, help='learning rate') -parser.add_argument('--epochs', default=10, type=int, help='number of epochs') -parser.add_argument('--seed', default=0, type=int, help='seed') -parser.add_argument('--name', type=str, default='TESTING_VGG', - help='filename for saved model') -parser.add_argument('--aca', type=lambda x: bool(strtobool(x)), - default=True, help='Activation Cross-Covariance Alignment') -parser.add_argument('--wa', type=lambda x: bool(strtobool(x)), - default=True, help='Weight alignment True=Yes False=No') -parser.add_argument('--in_wa', type=lambda x: bool(strtobool(x)), - default=True, help='input=True output=False') -parser.add_argument('--fact', type=lambda x: bool(strtobool(x)), - default=True, help='FactNet True or False') -parser.add_argument('--width', default=0.125, type=float, help='width') -parser.add_argument('--sampling', type=str, default='structured_alignment', - choices=['structured_alignment', 'cc_specification'], help="which sampling to use") - -args = parser.parse_args() - -if int(args.width) == args.width: - args.width = int(args.width) - -print("Sampling: {} Width: {} Fact: {} ACA: {} WA: {} In_WA: {}".format(args.sampling, - args.width, args.fact, args.aca, args.wa, args.in_wa)) - -device = 'cuda' if torch.cuda.is_available() else 'cpu' -best_acc = 0 # best test accuracy -start_epoch = 0 # start from epoch 0 or last checkpoint epoch - -# Data -print('==> Preparing data..') -transform_train = transforms.Compose([ - transforms.RandomCrop(32, padding=4), - transforms.RandomHorizontalFlip(), - transforms.ToTensor(), - transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), -]) - -transform_test = transforms.Compose([ - transforms.ToTensor(), - transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), -]) - -trainset = torchvision.datasets.CIFAR10( - root='./data', train=True, download=True, 
transform=transform_train) -trainloader = torch.utils.data.DataLoader( - trainset, batch_size=128, shuffle=True, num_workers=4, drop_last=True) - -testset = torchvision.datasets.CIFAR10( - root='./data', train=False, download=True, transform=transform_test) -testloader = torch.utils.data.DataLoader( - testset, batch_size=1000, shuffle=False, num_workers=8) - -classes = ('plane', 'car', 'bird', 'cat', 'deer', - 'dog', 'frog', 'horse', 'ship', 'truck') - -# Model -print('==> Building model..') - - -criterion = nn.CrossEntropyLoss() - -def train(epoch, net): - print('\nEpoch: %d' % epoch) - net.train() - train_loss = 0 - correct = 0 - total = 0 - for batch_idx, (inputs, targets) in enumerate(trainloader): - inputs, targets = inputs.to(device), targets.to(device) - optimizer.zero_grad() - outputs = net(inputs) - loss = criterion(outputs, targets) - loss.backward() - optimizer.step() - train_loss += loss.item() - _, predicted = outputs.max(1) - total += targets.size(0) - correct += predicted.eq(targets).sum().item() - - progress_bar(batch_idx, len(trainloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)' - % (train_loss/(batch_idx+1), 100.*correct/total, correct, total)) - - -def test(epoch, net): - global best_acc - net.eval() - test_loss = 0 - correct = 0 - total = 0 - with torch.no_grad(): - for batch_idx, (inputs, targets) in enumerate(testloader): - inputs, targets = inputs.to(device), targets.to(device) - outputs = net(inputs) - loss = criterion(outputs, targets) - - test_loss += loss.item() - _, predicted = outputs.max(1) - total += targets.size(0) - correct += predicted.eq(targets).sum().item() - - progress_bar(batch_idx, len(testloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)' - % (test_loss/(batch_idx+1), 100.*correct/total, correct, total)) - # Save checkpoint. 
- acc = 100.*correct/total - print("accuracy:", acc) - return acc, test_loss - -logger ={'width':args.width}#, } -set_seeds(args.seed) -for i in range(0, 1): - net=ResNet18() - replace_layers_scale(net, args.width) - if args.fact: - replace_layers_factconv2d(net) - - - if args.fact: - sd=torch.load("/network/scratch/v/vivian.white/v1-models/saved-models/affine_1/{}scale_final/fact_model.pt".format(args.width)) - elif not args.fact: - sd=torch.load("/network/scratch/v/vivian.white/v1-models/saved-models/affine_1/{}scale_final/conv_model.pt".format(args.width)) - net.load_state_dict(sd) - net.to(device) - print(net) - - set_seeds(i) - print("testing Res{}Net18 with width of {}".format("Fact" if args.fact else "Conv", args.width)) - pretrained_acc, og_loss = test(0, net) - - - s=time.time() - args.seed = i - rainbow = RainbowSampler(net, trainloader, args.seed, args.sampling, args.wa, args.in_wa, args.aca, device) - rainbow_net = rainbow.sample() - rainbow_net.train() - - for batch_idx, (inputs, targets) in enumerate(trainloader): - inputs, targets = inputs.to(device), targets.to(device) - outputs = rainbow_net(inputs) - print("TOTAL TIME:", time.time()-s) - turn_off_backbone_grad(rainbow_net) - optimizer = optim.SGD(filter(lambda param: param.requires_grad, rainbow_net.parameters()), lr=args.lr, - momentum=0.9, weight_decay=5e-4) - print("testing {} sampling at width {}".format(args.sampling, args.width)) - rainbow_net.eval() - - print(rainbow_net) - - sampled_acc, sampled_loss = test(0, rainbow_net) - save_model(args, rainbow_net) - accs = [] - test_losses= [] - print("training classifier head of {} sampled model for {} epochs".format(args.sampling, args.epochs)) - for j in range(0, args.epochs): - rainbow_net.train() - train(j, rainbow_net) - rainbow_net.eval() - acc, loss_test =test(j, rainbow_net) - test_losses.append(loss_test) - accs.append(acc) - - new_logger ={"sampled_acc_{}".format(i): sampled_acc,"pretrained_acc_{}".format(i): - pretrained_acc, 
"og_loss_{}".format(i): og_loss, - "first_epoch_acc_{}".format(i):accs[0], "third_epoch_acc_{}".format(i): accs[2], - "tenth_epoch_acc_{}".format(i):accs[args.epochs-1], - "sampled_loss_{}".format(i):sampled_loss, - "first_epoch_loss_{}".format(i):test_losses[0], "third_epoch_loss_{}".format(i): test_losses[2], - "tenth_epoch_loss_{}".format(i):test_losses[args.epochs-1]} - logger = {**logger, **new_logger} - -wandb_dir = "/home/mila/m/muawiz.chaudhary/scratch/v1-models/wandb" -os.makedirs(wandb_dir, exist_ok=True) -os.chdir(wandb_dir) -#group_string = "refactor" -group_string = "IGNOREvariance_runs" - -#run_name = "refactor" -run_name= "width_{}_sampling_{}_fact_{}_ACA_{}_WA_{}_inWA_{}".format(args.width, args.sampling, args.fact, args.aca, args.wa, args.in_wa) -args.name = run_name -run = wandb.init(project="random_project", config=args, - group=group_string, name=run_name, dir=wandb_dir) -run.log(logger) From 0a1c34b15819960f24c8ac5de6ab218ac1c356f1 Mon Sep 17 00:00:00 2001 From: vivianwhite Date: Fri, 3 May 2024 18:37:02 -0400 Subject: [PATCH 55/77] moving scripts --- FactConv/RSN-scripts/LC_CIFAR10.py | 10 +- FactConv/RSN-scripts/LC_CIFAR100.py | 2 +- .../RSN-scripts/LC_CIFAR10_Small_Sample.py | 2 +- .../scripts/s_f_sweep/generate_params.py | 18 ++ V1-models/scripts/s_f_sweep/run_model.py | 195 ++++++++++++++++++ V1-models/scripts/s_f_sweep/script1 | 73 +++++++ V1-models/scripts/s_f_sweep/script2 | 76 +++++++ V1-models/scripts/s_f_sweep/sweep_data.py | 50 +++++ V1-models/scripts/s_f_sweep/sweep_data.txt | 50 +++++ .../scripts/s_f_sweep/sweep_data_script.py | 6 + 10 files changed, 476 insertions(+), 6 deletions(-) create mode 100644 V1-models/scripts/s_f_sweep/generate_params.py create mode 100644 V1-models/scripts/s_f_sweep/run_model.py create mode 100755 V1-models/scripts/s_f_sweep/script1 create mode 100755 V1-models/scripts/s_f_sweep/script2 create mode 100644 V1-models/scripts/s_f_sweep/sweep_data.py create mode 100644 
V1-models/scripts/s_f_sweep/sweep_data.txt create mode 100644 V1-models/scripts/s_f_sweep/sweep_data_script.py diff --git a/FactConv/RSN-scripts/LC_CIFAR10.py b/FactConv/RSN-scripts/LC_CIFAR10.py index 9d5cc33..c246682 100644 --- a/FactConv/RSN-scripts/LC_CIFAR10.py +++ b/FactConv/RSN-scripts/LC_CIFAR10.py @@ -43,7 +43,7 @@ def test(model, device, test_loader, epoch): return test_loss, accuracy def save_model(args, model, loss, accuracy): - src = "../../saved-models/CIFAR10/" + src = "/research/harris/vivian/v1-models/saved-models/CIFAR10/" model_dir = src + args.name if not os.path.exists(model_dir): os.makedirs(model_dir) @@ -110,19 +110,21 @@ def save_model(args, model, loss, accuracy): train_loader = torch.utils.data.DataLoader( datasets.CIFAR10( - root=scattering_datasets.get_dataset_dir('CIFAR'), + #root=scattering_datasets.get_dataset_dir('CIFAR'), + root="/research/harris/vivian/v1-models/datasets/new_CIFAR10", train=True, transform=transforms.Compose([ transforms.RandomHorizontalFlip(), transforms.RandomCrop(32, 4), transforms.ToTensor(), normalize, - ]), download=True), + ]), download=False), batch_size=512, shuffle=True, num_workers=num_workers, pin_memory=pin_memory) test_loader = torch.utils.data.DataLoader( datasets.CIFAR10( - root=scattering_datasets.get_dataset_dir('CIFAR'), + root="/research/harris/vivian/v1-models/datasets/new_CIFAR10", + #root=scattering_datasets.get_dataset_dir('CIFAR'), train=False, transform=transforms.Compose([ transforms.ToTensor(), diff --git a/FactConv/RSN-scripts/LC_CIFAR100.py b/FactConv/RSN-scripts/LC_CIFAR100.py index 8beef5b..273f9dd 100644 --- a/FactConv/RSN-scripts/LC_CIFAR100.py +++ b/FactConv/RSN-scripts/LC_CIFAR100.py @@ -44,7 +44,7 @@ def test(model, device, test_loader, epoch): return test_loss, accuracy def save_model(args, model, loss, accuracy): - src = "../../saved-models/CIFAR100/" + src = "/research/harris/vivian/v1-models/saved-models/CIFAR100/" model_dir = src + args.name if not 
os.path.exists(model_dir): os.makedirs(model_dir) diff --git a/FactConv/RSN-scripts/LC_CIFAR10_Small_Sample.py b/FactConv/RSN-scripts/LC_CIFAR10_Small_Sample.py index 5f8bf14..cad992b 100644 --- a/FactConv/RSN-scripts/LC_CIFAR10_Small_Sample.py +++ b/FactConv/RSN-scripts/LC_CIFAR10_Small_Sample.py @@ -46,7 +46,7 @@ def test(model, device, test_loader, epoch): return test_loss, accuracy def save_model(args, model, loss, accuracy): - src = "../../saved-models/CIFAR10_50_Samples/" + src = "/research/harris/vivian/v1-models/saved-models/CIFAR10_50_Samples/" model_dir = src + args.name if not os.path.exists(model_dir): os.makedirs(model_dir) diff --git a/V1-models/scripts/s_f_sweep/generate_params.py b/V1-models/scripts/s_f_sweep/generate_params.py new file mode 100644 index 0000000..0a50aac --- /dev/null +++ b/V1-models/scripts/s_f_sweep/generate_params.py @@ -0,0 +1,18 @@ +import numpy as np + +f_arr = np.array([0.1, 0.5, 1, 2, 3, 5, 7]) +s_arr = np.array([1, 2, 3, 4, 5, 6, 10]) + +num_combinations = len(f_arr) * len(s_arr) +i = 0 +for f in f_arr: + for s in s_arr: + i += 1 + if i <= num_combinations / 2: + for trial in range(3): + print("Trial: {} s: {} f: {}".format(trial+1, s, f)) + print("python3 run_model.py --device=0 --s={} --f={} --trial={} --name='s_{}_f_{}'".format(s, f, trial+1, s, f)) + else: + for trial in range(3): + print("Trial: {} s: {} f: {}".format(trial+1, s, f)) + print("python3 run_model.py --device=1 --s={} --f={} --trial={} --name='s_{}_f_{}'".format(s, f, trial+1, s, f)) \ No newline at end of file diff --git a/V1-models/scripts/s_f_sweep/run_model.py b/V1-models/scripts/s_f_sweep/run_model.py new file mode 100644 index 0000000..f81f76f --- /dev/null +++ b/V1-models/scripts/s_f_sweep/run_model.py @@ -0,0 +1,195 @@ +#runs from script1 and script2 from generate_params.py + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim +from torchvision import datasets, transforms +from kymatio.torch import Scattering2D 
+import kymatio.datasets as scattering_datasets +import matplotlib.pyplot as plt +from datetime import datetime +import os +from distutils.util import strtobool +import argparse +import sys +sys.path.insert(0, '/research/harris/vivian/structured_random_features/') +from src.models.init_weights import V1_init, classical_init, V1_weights + + +class BN_V1_V1_LinearLayer(nn.Module): + def __init__(self, hidden_dim, size, spatial_freq, scale, bias, seed=None): + super(BN_V1_V1_LinearLayer, self).__init__() + self.v1_layer = nn.Conv2d(in_channels=3, out_channels=hidden_dim, kernel_size=7, stride=1, padding=3, + bias=bias) + self.v1_layer2 = nn.Conv2d(in_channels=hidden_dim, out_channels=hidden_dim, kernel_size=7, stride=1, padding=3, + bias=bias) + self.clf = nn.Linear((3 * (8 ** 2)) + (hidden_dim * (8 ** 2)) + (hidden_dim * (8 ** 2)), 10) + self.relu = nn.ReLU() + self.bn = nn.BatchNorm2d(3) + self.bn0 = nn.BatchNorm2d(3) + self.bn1 = nn.BatchNorm2d(hidden_dim) + self.bn2 = nn.BatchNorm2d(hidden_dim) + + scale1 = hidden_dim / ((3 * (32 * 32) ** 2) ) + scale2 = hidden_dim / ((hidden_dim * (32 * 32) ** 2)) + center = None + + V1_init(self.v1_layer, size, spatial_freq, center, scale1, bias, seed) + self.v1_layer.weight.requires_grad = False + + V1_init(self.v1_layer2, size, spatial_freq, center, scale2, bias, seed) + self.v1_layer2.weight.requires_grad = False + + if bias==True: + self.v1_layer.bias.requires_grad = False + self.v1_layer2.bias.requires_grad = False + + def forward(self, x): #[128, 3, 32, 32] + x = self.bn(x) + + layer1 = self.v1_layer(x) #[128, hidden_dim, 32, 32] w/ k=7, s=1, p=3 + layer1bn = self.bn1(layer1) + + layer2 = self.v1_layer2(layer1bn) #COULD ADD BATCH NORM HERE #[128, hidden_dim, 32, 32] w/ k=7, s=1, p=3 + layer2bn = self.bn2(layer2) + + x1 = self.relu(x) + h1 = self.relu(layer1bn) + h2 = self.relu(layer2bn) + + pool = nn.AvgPool2d(kernel_size=4, stride=4, padding=1) + x_pool = pool(x1) #[128, 3, 8, 8] + h1_pool = pool(h1) #[128, hidden_dim, 8, 
8] + h2_pool = pool(h2) #[128, hidden_dim, 8, 8] + + x_flat = x_pool.view(x_pool.size(0), -1) #[128, 192] = [128, 3 * 8 * 8] std ~1, mean ~0 + h1_flat = h1_pool.view(h1_pool.size(0), -1) #[128, hidden_dim * 8 * 8] std ~1, mean ~0 + h2_flat = h2_pool.view(h2_pool.size(0), -1) #[128, hidden_dim * 8 * 8] std ~1, mean ~0 + + concat = torch.cat((x_flat, h1_flat, h2_flat), 1) #[128, (3 * 8 * 8) + (hidden_dim * 8 * 8) + (hidden_dim * 8 * 8) + + beta = self.clf(concat) #[128, 10] + return beta + +def train(model, device, train_loader, optimizer, epoch): + model.train() + for batch_idx, (data, target) in enumerate(train_loader): + data, target = data.to(device), target.to(device) + optimizer.zero_grad() + output = model(data) + loss = F.cross_entropy(output, target) + loss.backward() + optimizer.step() + + +def test(model, device, test_loader, epoch): + model.eval() + test_loss = 0 + correct = 0 + with torch.no_grad(): + for data, target in test_loader: + data, target = data.to(device), target.to(device) + output = model(data) + test_loss += F.cross_entropy(output, target, reduction='sum').item() # sum up batch loss + pred = output.max(1, keepdim=True)[1] # get the index of the max log-probability + correct += pred.eq(target.view_as(pred)).sum().item() + + test_loss /= len(test_loader.dataset) + accuracy = 100. 
* correct / len(test_loader.dataset) + + if epoch == 90: + print('Loss: {:.4f}\t Accuracy: {:.2f}%'.format(test_loss, accuracy)) + + return test_loss, accuracy + +def save_model(args, model): + src = "/research/harris/vivian/v1-models/saved-models/BN_V1_V1_Linear/" + model_dir = src + "s_f_sweep/" + args.name + if not os.path.exists(model_dir): + os.makedirs(model_dir) + os.chdir(model_dir) + + trial_dir = model_dir + "/trial_" + str(args.trial) + if not os.path.exists(trial_dir): + os.makedirs(trial_dir) + os.chdir(trial_dir) + + torch.save(test_loss, "loss.pt") + torch.save(test_accuracy, "accuracy.pt") + torch.save(model, "model.pt") + torch.save(args, "args.pt") + + +if __name__ == '__main__': + print("STARTING MODEL") + parser = argparse.ArgumentParser(description='CIFAR scattering + hybrid examples') + parser.add_argument('--hidden_dim', type=int, default=100, help='number of hidden dimensions in model') + parser.add_argument('--num_epoch', type=int, default=90, help='number of epochs') + parser.add_argument('--lr', type=float, default=0.01, help='learning rate') + parser.add_argument('--s', type=int, default=3, help='V1 size') + parser.add_argument('--f', type=float, default=1.0, help='V1 spatial frequency') + parser.add_argument('--scale', type=int, default=1, help='V1 scale') + parser.add_argument('--name', type=str, default='new model', help='filename for saved model') + parser.add_argument('--trial', type=int, default=1, help='trial number') + parser.add_argument('--bias', dest='bias', type=lambda x: bool(strtobool(x)), default=False, help='bias=True or False') + parser.add_argument('--device', type=int, default=0, help="which device to use (0 or 1)") + args = parser.parse_args() + initial_lr = args.lr + + use_cuda = torch.cuda.is_available() + device = torch.device("cuda:" + str(args.device) if use_cuda else "cpu") + + start = datetime.now() + + model = BN_V1_V1_LinearLayer(args.hidden_dim, args.s, args.f, args.scale, args.bias).to(device) + + # 
DataLoaders + if use_cuda: + num_workers = 4 + pin_memory = True + else: + num_workers = 0 + pin_memory = False + + normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225]) + + train_loader = torch.utils.data.DataLoader( + datasets.CIFAR10(root=scattering_datasets.get_dataset_dir('CIFAR'), train=True, transform=transforms.Compose([ + transforms.RandomHorizontalFlip(), + transforms.RandomCrop(32, 4), + transforms.ToTensor(), + normalize, + ]), download=False), + batch_size=128, shuffle=True, num_workers=num_workers, pin_memory=pin_memory) + + test_loader = torch.utils.data.DataLoader( + datasets.CIFAR10(root=scattering_datasets.get_dataset_dir('CIFAR'), train=False, transform=transforms.Compose([ + transforms.ToTensor(), + normalize, + ])), + batch_size=128, shuffle=False, num_workers=num_workers, pin_memory=pin_memory) + + + test_loss = [] + test_accuracy = [] + epoch_list = [] + + for epoch in range(0, args.num_epoch): + if epoch%20==0: + optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=0.9,weight_decay=0.0005, nesterov=True) + args.lr*=0.2 + + train(model, device, train_loader, optimizer, epoch+1) + loss, accuracy = test(model, device, test_loader, epoch+1) + test_loss.append(loss) + test_accuracy.append(accuracy) + epoch_list.append(epoch) + + end = datetime.now() + print("Trial {} time (HH:MM:SS): {}".format(args.trial, end-start)) + #print("Hidden dim: {}\t Learning rate: {}".format(args.hidden_dim, initial_lr)) + + save_model(args, model) + diff --git a/V1-models/scripts/s_f_sweep/script1 b/V1-models/scripts/s_f_sweep/script1 new file mode 100755 index 0000000..eedafe3 --- /dev/null +++ b/V1-models/scripts/s_f_sweep/script1 @@ -0,0 +1,73 @@ +#!/bin/bash +#python3 run_model.py --device=0 --s=1 --f=0.1 --trial=1 --name='s_1_f_0.1' +#python3 run_model.py --device=0 --s=1 --f=0.1 --trial=2 --name='s_1_f_0.1' +#python3 run_model.py --device=0 --s=1 --f=0.1 --trial=3 --name='s_1_f_0.1' +#python3 run_model.py 
--device=0 --s=2 --f=0.1 --trial=1 --name='s_2_f_0.1' +#python3 run_model.py --device=0 --s=2 --f=0.1 --trial=2 --name='s_2_f_0.1' +#python3 run_model.py --device=0 --s=2 --f=0.1 --trial=3 --name='s_2_f_0.1' +#python3 run_model.py --device=0 --s=3 --f=0.1 --trial=1 --name='s_3_f_0.1' +#python3 run_model.py --device=0 --s=3 --f=0.1 --trial=2 --name='s_3_f_0.1' +#python3 run_model.py --device=0 --s=3 --f=0.1 --trial=3 --name='s_3_f_0.1' +#python3 run_model.py --device=0 --s=4 --f=0.1 --trial=1 --name='s_4_f_0.1' +#python3 run_model.py --device=0 --s=4 --f=0.1 --trial=2 --name='s_4_f_0.1' +#python3 run_model.py --device=0 --s=4 --f=0.1 --trial=3 --name='s_4_f_0.1' +python3 run_model.py --device=0 --s=5 --f=0.1 --trial=1 --name='s_5_f_0.1' +python3 run_model.py --device=0 --s=5 --f=0.1 --trial=2 --name='s_5_f_0.1' +python3 run_model.py --device=0 --s=5 --f=0.1 --trial=3 --name='s_5_f_0.1' +python3 run_model.py --device=0 --s=6 --f=0.1 --trial=1 --name='s_6_f_0.1' +python3 run_model.py --device=0 --s=6 --f=0.1 --trial=2 --name='s_6_f_0.1' +python3 run_model.py --device=0 --s=6 --f=0.1 --trial=3 --name='s_6_f_0.1' +python3 run_model.py --device=0 --s=10 --f=0.1 --trial=1 --name='s_10_f_0.1' +python3 run_model.py --device=0 --s=10 --f=0.1 --trial=2 --name='s_10_f_0.1' +python3 run_model.py --device=0 --s=10 --f=0.1 --trial=3 --name='s_10_f_0.1' +python3 run_model.py --device=0 --s=1 --f=0.5 --trial=1 --name='s_1_f_0.5' +python3 run_model.py --device=0 --s=1 --f=0.5 --trial=2 --name='s_1_f_0.5' +python3 run_model.py --device=0 --s=1 --f=0.5 --trial=3 --name='s_1_f_0.5' +python3 run_model.py --device=0 --s=2 --f=0.5 --trial=1 --name='s_2_f_0.5' +python3 run_model.py --device=0 --s=2 --f=0.5 --trial=2 --name='s_2_f_0.5' +python3 run_model.py --device=0 --s=2 --f=0.5 --trial=3 --name='s_2_f_0.5' +python3 run_model.py --device=0 --s=3 --f=0.5 --trial=1 --name='s_3_f_0.5' +python3 run_model.py --device=0 --s=3 --f=0.5 --trial=2 --name='s_3_f_0.5' +python3 run_model.py 
--device=0 --s=3 --f=0.5 --trial=3 --name='s_3_f_0.5' +python3 run_model.py --device=0 --s=4 --f=0.5 --trial=1 --name='s_4_f_0.5' +python3 run_model.py --device=0 --s=4 --f=0.5 --trial=2 --name='s_4_f_0.5' +python3 run_model.py --device=0 --s=4 --f=0.5 --trial=3 --name='s_4_f_0.5' +python3 run_model.py --device=0 --s=5 --f=0.5 --trial=1 --name='s_5_f_0.5' +python3 run_model.py --device=0 --s=5 --f=0.5 --trial=2 --name='s_5_f_0.5' +python3 run_model.py --device=0 --s=5 --f=0.5 --trial=3 --name='s_5_f_0.5' +python3 run_model.py --device=0 --s=6 --f=0.5 --trial=1 --name='s_6_f_0.5' +python3 run_model.py --device=0 --s=6 --f=0.5 --trial=2 --name='s_6_f_0.5' +python3 run_model.py --device=0 --s=6 --f=0.5 --trial=3 --name='s_6_f_0.5' +python3 run_model.py --device=0 --s=10 --f=0.5 --trial=1 --name='s_10_f_0.5' +python3 run_model.py --device=0 --s=10 --f=0.5 --trial=2 --name='s_10_f_0.5' +python3 run_model.py --device=0 --s=10 --f=0.5 --trial=3 --name='s_10_f_0.5' +python3 run_model.py --device=0 --s=1 --f=1.0 --trial=1 --name='s_1_f_1.0' +python3 run_model.py --device=0 --s=1 --f=1.0 --trial=2 --name='s_1_f_1.0' +python3 run_model.py --device=0 --s=1 --f=1.0 --trial=3 --name='s_1_f_1.0' +python3 run_model.py --device=0 --s=2 --f=1.0 --trial=1 --name='s_2_f_1.0' +python3 run_model.py --device=0 --s=2 --f=1.0 --trial=2 --name='s_2_f_1.0' +python3 run_model.py --device=0 --s=2 --f=1.0 --trial=3 --name='s_2_f_1.0' +python3 run_model.py --device=0 --s=3 --f=1.0 --trial=1 --name='s_3_f_1.0' +python3 run_model.py --device=0 --s=3 --f=1.0 --trial=2 --name='s_3_f_1.0' +python3 run_model.py --device=0 --s=3 --f=1.0 --trial=3 --name='s_3_f_1.0' +python3 run_model.py --device=0 --s=4 --f=1.0 --trial=1 --name='s_4_f_1.0' +python3 run_model.py --device=0 --s=4 --f=1.0 --trial=2 --name='s_4_f_1.0' +python3 run_model.py --device=0 --s=4 --f=1.0 --trial=3 --name='s_4_f_1.0' +python3 run_model.py --device=0 --s=5 --f=1.0 --trial=1 --name='s_5_f_1.0' +python3 run_model.py --device=0 --s=5 
--f=1.0 --trial=2 --name='s_5_f_1.0' +python3 run_model.py --device=0 --s=5 --f=1.0 --trial=3 --name='s_5_f_1.0' +python3 run_model.py --device=0 --s=6 --f=1.0 --trial=1 --name='s_6_f_1.0' +python3 run_model.py --device=0 --s=6 --f=1.0 --trial=2 --name='s_6_f_1.0' +python3 run_model.py --device=0 --s=6 --f=1.0 --trial=3 --name='s_6_f_1.0' +python3 run_model.py --device=0 --s=10 --f=1.0 --trial=1 --name='s_10_f_1.0' +python3 run_model.py --device=0 --s=10 --f=1.0 --trial=2 --name='s_10_f_1.0' +python3 run_model.py --device=0 --s=10 --f=1.0 --trial=3 --name='s_10_f_1.0' +python3 run_model.py --device=0 --s=1 --f=2.0 --trial=1 --name='s_1_f_2.0' +python3 run_model.py --device=0 --s=1 --f=2.0 --trial=2 --name='s_1_f_2.0' +python3 run_model.py --device=0 --s=1 --f=2.0 --trial=3 --name='s_1_f_2.0' +python3 run_model.py --device=0 --s=2 --f=2.0 --trial=1 --name='s_2_f_2.0' +python3 run_model.py --device=0 --s=2 --f=2.0 --trial=2 --name='s_2_f_2.0' +python3 run_model.py --device=0 --s=2 --f=2.0 --trial=3 --name='s_2_f_2.0' +python3 run_model.py --device=0 --s=3 --f=2.0 --trial=1 --name='s_3_f_2.0' +python3 run_model.py --device=0 --s=3 --f=2.0 --trial=2 --name='s_3_f_2.0' +python3 run_model.py --device=0 --s=3 --f=2.0 --trial=3 --name='s_3_f_2.0' diff --git a/V1-models/scripts/s_f_sweep/script2 b/V1-models/scripts/s_f_sweep/script2 new file mode 100755 index 0000000..fcf1b66 --- /dev/null +++ b/V1-models/scripts/s_f_sweep/script2 @@ -0,0 +1,76 @@ +#!/bin/bash +#python3 run_model.py --device=1 --s=4 --f=2.0 --trial=1 --name='s_4_f_2.0' +#python3 run_model.py --device=1 --s=4 --f=2.0 --trial=2 --name='s_4_f_2.0' +#python3 run_model.py --device=1 --s=4 --f=2.0 --trial=3 --name='s_4_f_2.0' +#python3 run_model.py --device=1 --s=5 --f=2.0 --trial=1 --name='s_5_f_2.0' +#python3 run_model.py --device=1 --s=5 --f=2.0 --trial=2 --name='s_5_f_2.0' +#python3 run_model.py --device=1 --s=5 --f=2.0 --trial=3 --name='s_5_f_2.0' +#python3 run_model.py --device=1 --s=6 --f=2.0 --trial=1 
--name='s_6_f_2.0' +#python3 run_model.py --device=1 --s=6 --f=2.0 --trial=2 --name='s_6_f_2.0' +#python3 run_model.py --device=1 --s=6 --f=2.0 --trial=3 --name='s_6_f_2.0' +#python3 run_model.py --device=1 --s=10 --f=2.0 --trial=1 --name='s_10_f_2.0' +#python3 run_model.py --device=1 --s=10 --f=2.0 --trial=2 --name='s_10_f_2.0' +#python3 run_model.py --device=1 --s=10 --f=2.0 --trial=3 --name='s_10_f_2.0' +python3 run_model.py --device=1 --s=1 --f=3.0 --trial=1 --name='s_1_f_3.0' +python3 run_model.py --device=1 --s=1 --f=3.0 --trial=2 --name='s_1_f_3.0' +python3 run_model.py --device=1 --s=1 --f=3.0 --trial=3 --name='s_1_f_3.0' +python3 run_model.py --device=1 --s=2 --f=3.0 --trial=1 --name='s_2_f_3.0' +python3 run_model.py --device=1 --s=2 --f=3.0 --trial=2 --name='s_2_f_3.0' +python3 run_model.py --device=1 --s=2 --f=3.0 --trial=3 --name='s_2_f_3.0' +python3 run_model.py --device=1 --s=3 --f=3.0 --trial=1 --name='s_3_f_3.0' +python3 run_model.py --device=1 --s=3 --f=3.0 --trial=2 --name='s_3_f_3.0' +python3 run_model.py --device=1 --s=3 --f=3.0 --trial=3 --name='s_3_f_3.0' +python3 run_model.py --device=1 --s=4 --f=3.0 --trial=1 --name='s_4_f_3.0' +python3 run_model.py --device=1 --s=4 --f=3.0 --trial=2 --name='s_4_f_3.0' +python3 run_model.py --device=1 --s=4 --f=3.0 --trial=3 --name='s_4_f_3.0' +python3 run_model.py --device=1 --s=5 --f=3.0 --trial=1 --name='s_5_f_3.0' +python3 run_model.py --device=1 --s=5 --f=3.0 --trial=2 --name='s_5_f_3.0' +python3 run_model.py --device=1 --s=5 --f=3.0 --trial=3 --name='s_5_f_3.0' +python3 run_model.py --device=1 --s=6 --f=3.0 --trial=1 --name='s_6_f_3.0' +python3 run_model.py --device=1 --s=6 --f=3.0 --trial=2 --name='s_6_f_3.0' +python3 run_model.py --device=1 --s=6 --f=3.0 --trial=3 --name='s_6_f_3.0' +python3 run_model.py --device=1 --s=10 --f=3.0 --trial=1 --name='s_10_f_3.0' +python3 run_model.py --device=1 --s=10 --f=3.0 --trial=2 --name='s_10_f_3.0' +python3 run_model.py --device=1 --s=10 --f=3.0 --trial=3 
--name='s_10_f_3.0' +python3 run_model.py --device=1 --s=1 --f=5.0 --trial=1 --name='s_1_f_5.0' +python3 run_model.py --device=1 --s=1 --f=5.0 --trial=2 --name='s_1_f_5.0' +python3 run_model.py --device=1 --s=1 --f=5.0 --trial=3 --name='s_1_f_5.0' +python3 run_model.py --device=1 --s=2 --f=5.0 --trial=1 --name='s_2_f_5.0' +python3 run_model.py --device=1 --s=2 --f=5.0 --trial=2 --name='s_2_f_5.0' +python3 run_model.py --device=1 --s=2 --f=5.0 --trial=3 --name='s_2_f_5.0' +python3 run_model.py --device=1 --s=3 --f=5.0 --trial=1 --name='s_3_f_5.0' +python3 run_model.py --device=1 --s=3 --f=5.0 --trial=2 --name='s_3_f_5.0' +python3 run_model.py --device=1 --s=3 --f=5.0 --trial=3 --name='s_3_f_5.0' +python3 run_model.py --device=1 --s=4 --f=5.0 --trial=1 --name='s_4_f_5.0' +python3 run_model.py --device=1 --s=4 --f=5.0 --trial=2 --name='s_4_f_5.0' +python3 run_model.py --device=1 --s=4 --f=5.0 --trial=3 --name='s_4_f_5.0' +python3 run_model.py --device=1 --s=5 --f=5.0 --trial=1 --name='s_5_f_5.0' +python3 run_model.py --device=1 --s=5 --f=5.0 --trial=2 --name='s_5_f_5.0' +python3 run_model.py --device=1 --s=5 --f=5.0 --trial=3 --name='s_5_f_5.0' +python3 run_model.py --device=1 --s=6 --f=5.0 --trial=1 --name='s_6_f_5.0' +python3 run_model.py --device=1 --s=6 --f=5.0 --trial=2 --name='s_6_f_5.0' +python3 run_model.py --device=1 --s=6 --f=5.0 --trial=3 --name='s_6_f_5.0' +python3 run_model.py --device=1 --s=10 --f=5.0 --trial=1 --name='s_10_f_5.0' +python3 run_model.py --device=1 --s=10 --f=5.0 --trial=2 --name='s_10_f_5.0' +python3 run_model.py --device=1 --s=10 --f=5.0 --trial=3 --name='s_10_f_5.0' +python3 run_model.py --device=1 --s=1 --f=7.0 --trial=1 --name='s_1_f_7.0' +python3 run_model.py --device=1 --s=1 --f=7.0 --trial=2 --name='s_1_f_7.0' +python3 run_model.py --device=1 --s=1 --f=7.0 --trial=3 --name='s_1_f_7.0' +python3 run_model.py --device=1 --s=2 --f=7.0 --trial=1 --name='s_2_f_7.0' +python3 run_model.py --device=1 --s=2 --f=7.0 --trial=2 
--name='s_2_f_7.0' +python3 run_model.py --device=1 --s=2 --f=7.0 --trial=3 --name='s_2_f_7.0' +python3 run_model.py --device=1 --s=3 --f=7.0 --trial=1 --name='s_3_f_7.0' +python3 run_model.py --device=1 --s=3 --f=7.0 --trial=2 --name='s_3_f_7.0' +python3 run_model.py --device=1 --s=3 --f=7.0 --trial=3 --name='s_3_f_7.0' +python3 run_model.py --device=1 --s=4 --f=7.0 --trial=1 --name='s_4_f_7.0' +python3 run_model.py --device=1 --s=4 --f=7.0 --trial=2 --name='s_4_f_7.0' +python3 run_model.py --device=1 --s=4 --f=7.0 --trial=3 --name='s_4_f_7.0' +python3 run_model.py --device=1 --s=5 --f=7.0 --trial=1 --name='s_5_f_7.0' +python3 run_model.py --device=1 --s=5 --f=7.0 --trial=2 --name='s_5_f_7.0' +python3 run_model.py --device=1 --s=5 --f=7.0 --trial=3 --name='s_5_f_7.0' +python3 run_model.py --device=1 --s=6 --f=7.0 --trial=1 --name='s_6_f_7.0' +python3 run_model.py --device=1 --s=6 --f=7.0 --trial=2 --name='s_6_f_7.0' +python3 run_model.py --device=1 --s=6 --f=7.0 --trial=3 --name='s_6_f_7.0' +python3 run_model.py --device=1 --s=10 --f=7.0 --trial=1 --name='s_10_f_7.0' +python3 run_model.py --device=1 --s=10 --f=7.0 --trial=2 --name='s_10_f_7.0' +python3 run_model.py --device=1 --s=10 --f=7.0 --trial=3 --name='s_10_f_7.0' diff --git a/V1-models/scripts/s_f_sweep/sweep_data.py b/V1-models/scripts/s_f_sweep/sweep_data.py new file mode 100644 index 0000000..77eef1d --- /dev/null +++ b/V1-models/scripts/s_f_sweep/sweep_data.py @@ -0,0 +1,50 @@ +#Python script to take the average of 3 trials of BN_V1_V1_Linear run +import numpy as np +import argparse +import os +import torch +from scipy.stats import sem + +if __name__ == '__main__': + f_arr = np.array([0.1, 0.5, 1, 2, 3, 5, 7]) + s_arr = np.array([1, 2, 3, 4, 5, 6, 10]) + + file = open("sweep_data.txt", 'w') + header = "s f loss loss_err accuracy accuracy_err" + file.write(header) + file.write("\n") + + for f in f_arr: + for s in s_arr: + + src = "/research/harris/vivian/v1-models/saved-models/BN_V1_V1_Linear/s_f_sweep/" 
+ "s_" + str(s) + "_f_" + str(f) + loss = [] + accuracy = [] + for i in range(3): + n=i+1 + os.chdir(src + "/trial_" + str(n)) + + loss_trial = torch.load("loss.pt")[-1] + accuracy_trial = torch.load("accuracy.pt")[-1] + + loss.append(loss_trial) + accuracy.append(accuracy_trial) + + + avg_loss = np.round_(np.mean(loss), decimals=4) + avg_acc = np.round_(np.mean(accuracy), decimals=2) + loss_err = np.round_(sem(loss), decimals=4) + acc_err = np.round_(sem(accuracy), decimals=2) + + + string = str(s) + " " + str(f) + " " + str(avg_loss) + " " + "+/-" + str(loss_err) + " " + str(avg_acc) + " " + "+/-" + str(acc_err) + file.write(string) + file.write("\n") + + + + + + + + \ No newline at end of file diff --git a/V1-models/scripts/s_f_sweep/sweep_data.txt b/V1-models/scripts/s_f_sweep/sweep_data.txt new file mode 100644 index 0000000..2531fc4 --- /dev/null +++ b/V1-models/scripts/s_f_sweep/sweep_data.txt @@ -0,0 +1,50 @@ +s f loss loss_err accuracy accuracy_err +1 0.1 1.0122 +/-0.0062 65.75 +/-0.27 +2 0.1 1.0045 +/-0.0043 65.86 +/-0.06 +3 0.1 1.0125 +/-0.0035 65.51 +/-0.12 +4 0.1 1.0315 +/-0.022 64.66 +/-0.75 +5 0.1 1.0101 +/-0.0078 65.77 +/-0.23 +6 0.1 1.0105 +/-0.006 65.7 +/-0.17 +10 0.1 1.0081 +/-0.0025 65.77 +/-0.03 +1 0.5 1.0295 +/-0.0053 65.09 +/-0.21 +2 0.5 1.0257 +/-0.0023 65.26 +/-0.09 +3 0.5 1.017 +/-0.0065 65.31 +/-0.39 +4 0.5 1.0134 +/-0.0063 65.37 +/-0.21 +5 0.5 1.0158 +/-0.0079 65.31 +/-0.34 +6 0.5 1.0201 +/-0.0047 65.29 +/-0.23 +10 0.5 1.0134 +/-0.004 65.5 +/-0.09 +1 1.0 1.0411 +/-0.0084 64.25 +/-0.32 +2 1.0 1.0378 +/-0.0067 64.49 +/-0.14 +3 1.0 1.0399 +/-0.0019 64.37 +/-0.26 +4 1.0 1.0436 +/-0.0068 64.13 +/-0.29 +5 1.0 1.0438 +/-0.0094 64.3 +/-0.53 +6 1.0 1.0523 +/-0.0044 63.95 +/-0.12 +10 1.0 1.0473 +/-0.0069 64.21 +/-0.37 +1 2.0 1.0539 +/-0.0016 63.78 +/-0.08 +2 2.0 1.0626 +/-0.0025 63.67 +/-0.23 +3 2.0 1.0705 +/-0.0078 63.24 +/-0.51 +4 2.0 1.0723 +/-0.0024 62.95 +/-0.12 +5 2.0 1.0695 +/-0.0054 63.17 +/-0.19 +6 2.0 1.0797 +/-0.0065 62.96 +/-0.23 +10 
2.0 1.1128 +/-0.0192 61.44 +/-0.75 +1 3.0 1.0621 +/-0.0022 63.42 +/-0.11 +2 3.0 1.0766 +/-0.0047 63.07 +/-0.23 +3 3.0 1.0916 +/-0.0048 62.2 +/-0.16 +4 3.0 1.101 +/-0.0093 62.12 +/-0.42 +5 3.0 1.0913 +/-0.0055 62.67 +/-0.03 +6 3.0 1.1036 +/-0.0021 61.88 +/-0.06 +10 3.0 1.0994 +/-0.0175 61.78 +/-0.59 +1 5.0 1.0879 +/-0.0011 62.63 +/-0.15 +2 5.0 1.1076 +/-0.0012 61.54 +/-0.05 +3 5.0 1.1116 +/-0.0053 61.52 +/-0.24 +4 5.0 1.1178 +/-0.0058 61.38 +/-0.17 +5 5.0 1.1131 +/-0.0034 61.53 +/-0.18 +6 5.0 1.1122 +/-0.0054 61.4 +/-0.15 +10 5.0 1.1212 +/-0.0062 61.38 +/-0.11 +1 7.0 1.0922 +/-0.0062 62.48 +/-0.26 +2 7.0 1.1276 +/-0.0033 60.94 +/-0.09 +3 7.0 1.1162 +/-0.0108 61.41 +/-0.43 +4 7.0 1.1275 +/-0.0112 61.01 +/-0.38 +5 7.0 1.1298 +/-0.0059 60.79 +/-0.18 +6 7.0 1.1261 +/-0.0055 61.19 +/-0.11 +10 7.0 1.1242 +/-0.0026 61.27 +/-0.09 diff --git a/V1-models/scripts/s_f_sweep/sweep_data_script.py b/V1-models/scripts/s_f_sweep/sweep_data_script.py new file mode 100644 index 0000000..8dbbe52 --- /dev/null +++ b/V1-models/scripts/s_f_sweep/sweep_data_script.py @@ -0,0 +1,6 @@ +f_list = [0.1, 2.0] +s_list = [1, 2, 3] +for f in f_list: + for s in s_list: + new_name = "s_"+str(s)+"_f_"+str(f) + print("python3 sweep_data.py --s={} --f={} ".format(s, f)) \ No newline at end of file From 428ae583ceaff47228ddf19ada3013ab7f841d64 Mon Sep 17 00:00:00 2001 From: vivianwhite <66977221+vivianwhite@users.noreply.github.com> Date: Fri, 3 May 2024 15:49:08 -0700 Subject: [PATCH 56/77] Delete V1-models/scripts/s_f_sweep directory --- .../scripts/s_f_sweep/generate_params.py | 18 -- V1-models/scripts/s_f_sweep/run_model.py | 195 ------------------ V1-models/scripts/s_f_sweep/script1 | 73 ------- V1-models/scripts/s_f_sweep/script2 | 76 ------- V1-models/scripts/s_f_sweep/sweep_data.py | 50 ----- V1-models/scripts/s_f_sweep/sweep_data.txt | 50 ----- .../scripts/s_f_sweep/sweep_data_script.py | 6 - 7 files changed, 468 deletions(-) delete mode 100644 V1-models/scripts/s_f_sweep/generate_params.py 
delete mode 100644 V1-models/scripts/s_f_sweep/run_model.py delete mode 100755 V1-models/scripts/s_f_sweep/script1 delete mode 100755 V1-models/scripts/s_f_sweep/script2 delete mode 100644 V1-models/scripts/s_f_sweep/sweep_data.py delete mode 100644 V1-models/scripts/s_f_sweep/sweep_data.txt delete mode 100644 V1-models/scripts/s_f_sweep/sweep_data_script.py diff --git a/V1-models/scripts/s_f_sweep/generate_params.py b/V1-models/scripts/s_f_sweep/generate_params.py deleted file mode 100644 index 0a50aac..0000000 --- a/V1-models/scripts/s_f_sweep/generate_params.py +++ /dev/null @@ -1,18 +0,0 @@ -import numpy as np - -f_arr = np.array([0.1, 0.5, 1, 2, 3, 5, 7]) -s_arr = np.array([1, 2, 3, 4, 5, 6, 10]) - -num_combinations = len(f_arr) * len(s_arr) -i = 0 -for f in f_arr: - for s in s_arr: - i += 1 - if i <= num_combinations / 2: - for trial in range(3): - print("Trial: {} s: {} f: {}".format(trial+1, s, f)) - print("python3 run_model.py --device=0 --s={} --f={} --trial={} --name='s_{}_f_{}'".format(s, f, trial+1, s, f)) - else: - for trial in range(3): - print("Trial: {} s: {} f: {}".format(trial+1, s, f)) - print("python3 run_model.py --device=1 --s={} --f={} --trial={} --name='s_{}_f_{}'".format(s, f, trial+1, s, f)) \ No newline at end of file diff --git a/V1-models/scripts/s_f_sweep/run_model.py b/V1-models/scripts/s_f_sweep/run_model.py deleted file mode 100644 index f81f76f..0000000 --- a/V1-models/scripts/s_f_sweep/run_model.py +++ /dev/null @@ -1,195 +0,0 @@ -#runs from script1 and script2 from generate_params.py - -import torch -import torch.nn as nn -import torch.nn.functional as F -import torch.optim -from torchvision import datasets, transforms -from kymatio.torch import Scattering2D -import kymatio.datasets as scattering_datasets -import matplotlib.pyplot as plt -from datetime import datetime -import os -from distutils.util import strtobool -import argparse -import sys -sys.path.insert(0, '/research/harris/vivian/structured_random_features/') -from 
src.models.init_weights import V1_init, classical_init, V1_weights - - -class BN_V1_V1_LinearLayer(nn.Module): - def __init__(self, hidden_dim, size, spatial_freq, scale, bias, seed=None): - super(BN_V1_V1_LinearLayer, self).__init__() - self.v1_layer = nn.Conv2d(in_channels=3, out_channels=hidden_dim, kernel_size=7, stride=1, padding=3, - bias=bias) - self.v1_layer2 = nn.Conv2d(in_channels=hidden_dim, out_channels=hidden_dim, kernel_size=7, stride=1, padding=3, - bias=bias) - self.clf = nn.Linear((3 * (8 ** 2)) + (hidden_dim * (8 ** 2)) + (hidden_dim * (8 ** 2)), 10) - self.relu = nn.ReLU() - self.bn = nn.BatchNorm2d(3) - self.bn0 = nn.BatchNorm2d(3) - self.bn1 = nn.BatchNorm2d(hidden_dim) - self.bn2 = nn.BatchNorm2d(hidden_dim) - - scale1 = hidden_dim / ((3 * (32 * 32) ** 2) ) - scale2 = hidden_dim / ((hidden_dim * (32 * 32) ** 2)) - center = None - - V1_init(self.v1_layer, size, spatial_freq, center, scale1, bias, seed) - self.v1_layer.weight.requires_grad = False - - V1_init(self.v1_layer2, size, spatial_freq, center, scale2, bias, seed) - self.v1_layer2.weight.requires_grad = False - - if bias==True: - self.v1_layer.bias.requires_grad = False - self.v1_layer2.bias.requires_grad = False - - def forward(self, x): #[128, 3, 32, 32] - x = self.bn(x) - - layer1 = self.v1_layer(x) #[128, hidden_dim, 32, 32] w/ k=7, s=1, p=3 - layer1bn = self.bn1(layer1) - - layer2 = self.v1_layer2(layer1bn) #COULD ADD BATCH NORM HERE #[128, hidden_dim, 32, 32] w/ k=7, s=1, p=3 - layer2bn = self.bn2(layer2) - - x1 = self.relu(x) - h1 = self.relu(layer1bn) - h2 = self.relu(layer2bn) - - pool = nn.AvgPool2d(kernel_size=4, stride=4, padding=1) - x_pool = pool(x1) #[128, 3, 8, 8] - h1_pool = pool(h1) #[128, hidden_dim, 8, 8] - h2_pool = pool(h2) #[128, hidden_dim, 8, 8] - - x_flat = x_pool.view(x_pool.size(0), -1) #[128, 192] = [128, 3 * 8 * 8] std ~1, mean ~0 - h1_flat = h1_pool.view(h1_pool.size(0), -1) #[128, hidden_dim * 8 * 8] std ~1, mean ~0 - h2_flat = 
h2_pool.view(h2_pool.size(0), -1) #[128, hidden_dim * 8 * 8] std ~1, mean ~0 - - concat = torch.cat((x_flat, h1_flat, h2_flat), 1) #[128, (3 * 8 * 8) + (hidden_dim * 8 * 8) + (hidden_dim * 8 * 8) - - beta = self.clf(concat) #[128, 10] - return beta - -def train(model, device, train_loader, optimizer, epoch): - model.train() - for batch_idx, (data, target) in enumerate(train_loader): - data, target = data.to(device), target.to(device) - optimizer.zero_grad() - output = model(data) - loss = F.cross_entropy(output, target) - loss.backward() - optimizer.step() - - -def test(model, device, test_loader, epoch): - model.eval() - test_loss = 0 - correct = 0 - with torch.no_grad(): - for data, target in test_loader: - data, target = data.to(device), target.to(device) - output = model(data) - test_loss += F.cross_entropy(output, target, reduction='sum').item() # sum up batch loss - pred = output.max(1, keepdim=True)[1] # get the index of the max log-probability - correct += pred.eq(target.view_as(pred)).sum().item() - - test_loss /= len(test_loader.dataset) - accuracy = 100. 
* correct / len(test_loader.dataset) - - if epoch == 90: - print('Loss: {:.4f}\t Accuracy: {:.2f}%'.format(test_loss, accuracy)) - - return test_loss, accuracy - -def save_model(args, model): - src = "/research/harris/vivian/v1-models/saved-models/BN_V1_V1_Linear/" - model_dir = src + "s_f_sweep/" + args.name - if not os.path.exists(model_dir): - os.makedirs(model_dir) - os.chdir(model_dir) - - trial_dir = model_dir + "/trial_" + str(args.trial) - if not os.path.exists(trial_dir): - os.makedirs(trial_dir) - os.chdir(trial_dir) - - torch.save(test_loss, "loss.pt") - torch.save(test_accuracy, "accuracy.pt") - torch.save(model, "model.pt") - torch.save(args, "args.pt") - - -if __name__ == '__main__': - print("STARTING MODEL") - parser = argparse.ArgumentParser(description='CIFAR scattering + hybrid examples') - parser.add_argument('--hidden_dim', type=int, default=100, help='number of hidden dimensions in model') - parser.add_argument('--num_epoch', type=int, default=90, help='number of epochs') - parser.add_argument('--lr', type=float, default=0.01, help='learning rate') - parser.add_argument('--s', type=int, default=3, help='V1 size') - parser.add_argument('--f', type=float, default=1.0, help='V1 spatial frequency') - parser.add_argument('--scale', type=int, default=1, help='V1 scale') - parser.add_argument('--name', type=str, default='new model', help='filename for saved model') - parser.add_argument('--trial', type=int, default=1, help='trial number') - parser.add_argument('--bias', dest='bias', type=lambda x: bool(strtobool(x)), default=False, help='bias=True or False') - parser.add_argument('--device', type=int, default=0, help="which device to use (0 or 1)") - args = parser.parse_args() - initial_lr = args.lr - - use_cuda = torch.cuda.is_available() - device = torch.device("cuda:" + str(args.device) if use_cuda else "cpu") - - start = datetime.now() - - model = BN_V1_V1_LinearLayer(args.hidden_dim, args.s, args.f, args.scale, args.bias).to(device) - - # 
DataLoaders - if use_cuda: - num_workers = 4 - pin_memory = True - else: - num_workers = 0 - pin_memory = False - - normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], - std=[0.229, 0.224, 0.225]) - - train_loader = torch.utils.data.DataLoader( - datasets.CIFAR10(root=scattering_datasets.get_dataset_dir('CIFAR'), train=True, transform=transforms.Compose([ - transforms.RandomHorizontalFlip(), - transforms.RandomCrop(32, 4), - transforms.ToTensor(), - normalize, - ]), download=False), - batch_size=128, shuffle=True, num_workers=num_workers, pin_memory=pin_memory) - - test_loader = torch.utils.data.DataLoader( - datasets.CIFAR10(root=scattering_datasets.get_dataset_dir('CIFAR'), train=False, transform=transforms.Compose([ - transforms.ToTensor(), - normalize, - ])), - batch_size=128, shuffle=False, num_workers=num_workers, pin_memory=pin_memory) - - - test_loss = [] - test_accuracy = [] - epoch_list = [] - - for epoch in range(0, args.num_epoch): - if epoch%20==0: - optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=0.9,weight_decay=0.0005, nesterov=True) - args.lr*=0.2 - - train(model, device, train_loader, optimizer, epoch+1) - loss, accuracy = test(model, device, test_loader, epoch+1) - test_loss.append(loss) - test_accuracy.append(accuracy) - epoch_list.append(epoch) - - end = datetime.now() - print("Trial {} time (HH:MM:SS): {}".format(args.trial, end-start)) - #print("Hidden dim: {}\t Learning rate: {}".format(args.hidden_dim, initial_lr)) - - save_model(args, model) - diff --git a/V1-models/scripts/s_f_sweep/script1 b/V1-models/scripts/s_f_sweep/script1 deleted file mode 100755 index eedafe3..0000000 --- a/V1-models/scripts/s_f_sweep/script1 +++ /dev/null @@ -1,73 +0,0 @@ -#!/bin/bash -#python3 run_model.py --device=0 --s=1 --f=0.1 --trial=1 --name='s_1_f_0.1' -#python3 run_model.py --device=0 --s=1 --f=0.1 --trial=2 --name='s_1_f_0.1' -#python3 run_model.py --device=0 --s=1 --f=0.1 --trial=3 --name='s_1_f_0.1' -#python3 run_model.py 
--device=0 --s=2 --f=0.1 --trial=1 --name='s_2_f_0.1' -#python3 run_model.py --device=0 --s=2 --f=0.1 --trial=2 --name='s_2_f_0.1' -#python3 run_model.py --device=0 --s=2 --f=0.1 --trial=3 --name='s_2_f_0.1' -#python3 run_model.py --device=0 --s=3 --f=0.1 --trial=1 --name='s_3_f_0.1' -#python3 run_model.py --device=0 --s=3 --f=0.1 --trial=2 --name='s_3_f_0.1' -#python3 run_model.py --device=0 --s=3 --f=0.1 --trial=3 --name='s_3_f_0.1' -#python3 run_model.py --device=0 --s=4 --f=0.1 --trial=1 --name='s_4_f_0.1' -#python3 run_model.py --device=0 --s=4 --f=0.1 --trial=2 --name='s_4_f_0.1' -#python3 run_model.py --device=0 --s=4 --f=0.1 --trial=3 --name='s_4_f_0.1' -python3 run_model.py --device=0 --s=5 --f=0.1 --trial=1 --name='s_5_f_0.1' -python3 run_model.py --device=0 --s=5 --f=0.1 --trial=2 --name='s_5_f_0.1' -python3 run_model.py --device=0 --s=5 --f=0.1 --trial=3 --name='s_5_f_0.1' -python3 run_model.py --device=0 --s=6 --f=0.1 --trial=1 --name='s_6_f_0.1' -python3 run_model.py --device=0 --s=6 --f=0.1 --trial=2 --name='s_6_f_0.1' -python3 run_model.py --device=0 --s=6 --f=0.1 --trial=3 --name='s_6_f_0.1' -python3 run_model.py --device=0 --s=10 --f=0.1 --trial=1 --name='s_10_f_0.1' -python3 run_model.py --device=0 --s=10 --f=0.1 --trial=2 --name='s_10_f_0.1' -python3 run_model.py --device=0 --s=10 --f=0.1 --trial=3 --name='s_10_f_0.1' -python3 run_model.py --device=0 --s=1 --f=0.5 --trial=1 --name='s_1_f_0.5' -python3 run_model.py --device=0 --s=1 --f=0.5 --trial=2 --name='s_1_f_0.5' -python3 run_model.py --device=0 --s=1 --f=0.5 --trial=3 --name='s_1_f_0.5' -python3 run_model.py --device=0 --s=2 --f=0.5 --trial=1 --name='s_2_f_0.5' -python3 run_model.py --device=0 --s=2 --f=0.5 --trial=2 --name='s_2_f_0.5' -python3 run_model.py --device=0 --s=2 --f=0.5 --trial=3 --name='s_2_f_0.5' -python3 run_model.py --device=0 --s=3 --f=0.5 --trial=1 --name='s_3_f_0.5' -python3 run_model.py --device=0 --s=3 --f=0.5 --trial=2 --name='s_3_f_0.5' -python3 run_model.py 
--device=0 --s=3 --f=0.5 --trial=3 --name='s_3_f_0.5' -python3 run_model.py --device=0 --s=4 --f=0.5 --trial=1 --name='s_4_f_0.5' -python3 run_model.py --device=0 --s=4 --f=0.5 --trial=2 --name='s_4_f_0.5' -python3 run_model.py --device=0 --s=4 --f=0.5 --trial=3 --name='s_4_f_0.5' -python3 run_model.py --device=0 --s=5 --f=0.5 --trial=1 --name='s_5_f_0.5' -python3 run_model.py --device=0 --s=5 --f=0.5 --trial=2 --name='s_5_f_0.5' -python3 run_model.py --device=0 --s=5 --f=0.5 --trial=3 --name='s_5_f_0.5' -python3 run_model.py --device=0 --s=6 --f=0.5 --trial=1 --name='s_6_f_0.5' -python3 run_model.py --device=0 --s=6 --f=0.5 --trial=2 --name='s_6_f_0.5' -python3 run_model.py --device=0 --s=6 --f=0.5 --trial=3 --name='s_6_f_0.5' -python3 run_model.py --device=0 --s=10 --f=0.5 --trial=1 --name='s_10_f_0.5' -python3 run_model.py --device=0 --s=10 --f=0.5 --trial=2 --name='s_10_f_0.5' -python3 run_model.py --device=0 --s=10 --f=0.5 --trial=3 --name='s_10_f_0.5' -python3 run_model.py --device=0 --s=1 --f=1.0 --trial=1 --name='s_1_f_1.0' -python3 run_model.py --device=0 --s=1 --f=1.0 --trial=2 --name='s_1_f_1.0' -python3 run_model.py --device=0 --s=1 --f=1.0 --trial=3 --name='s_1_f_1.0' -python3 run_model.py --device=0 --s=2 --f=1.0 --trial=1 --name='s_2_f_1.0' -python3 run_model.py --device=0 --s=2 --f=1.0 --trial=2 --name='s_2_f_1.0' -python3 run_model.py --device=0 --s=2 --f=1.0 --trial=3 --name='s_2_f_1.0' -python3 run_model.py --device=0 --s=3 --f=1.0 --trial=1 --name='s_3_f_1.0' -python3 run_model.py --device=0 --s=3 --f=1.0 --trial=2 --name='s_3_f_1.0' -python3 run_model.py --device=0 --s=3 --f=1.0 --trial=3 --name='s_3_f_1.0' -python3 run_model.py --device=0 --s=4 --f=1.0 --trial=1 --name='s_4_f_1.0' -python3 run_model.py --device=0 --s=4 --f=1.0 --trial=2 --name='s_4_f_1.0' -python3 run_model.py --device=0 --s=4 --f=1.0 --trial=3 --name='s_4_f_1.0' -python3 run_model.py --device=0 --s=5 --f=1.0 --trial=1 --name='s_5_f_1.0' -python3 run_model.py --device=0 --s=5 
--f=1.0 --trial=2 --name='s_5_f_1.0' -python3 run_model.py --device=0 --s=5 --f=1.0 --trial=3 --name='s_5_f_1.0' -python3 run_model.py --device=0 --s=6 --f=1.0 --trial=1 --name='s_6_f_1.0' -python3 run_model.py --device=0 --s=6 --f=1.0 --trial=2 --name='s_6_f_1.0' -python3 run_model.py --device=0 --s=6 --f=1.0 --trial=3 --name='s_6_f_1.0' -python3 run_model.py --device=0 --s=10 --f=1.0 --trial=1 --name='s_10_f_1.0' -python3 run_model.py --device=0 --s=10 --f=1.0 --trial=2 --name='s_10_f_1.0' -python3 run_model.py --device=0 --s=10 --f=1.0 --trial=3 --name='s_10_f_1.0' -python3 run_model.py --device=0 --s=1 --f=2.0 --trial=1 --name='s_1_f_2.0' -python3 run_model.py --device=0 --s=1 --f=2.0 --trial=2 --name='s_1_f_2.0' -python3 run_model.py --device=0 --s=1 --f=2.0 --trial=3 --name='s_1_f_2.0' -python3 run_model.py --device=0 --s=2 --f=2.0 --trial=1 --name='s_2_f_2.0' -python3 run_model.py --device=0 --s=2 --f=2.0 --trial=2 --name='s_2_f_2.0' -python3 run_model.py --device=0 --s=2 --f=2.0 --trial=3 --name='s_2_f_2.0' -python3 run_model.py --device=0 --s=3 --f=2.0 --trial=1 --name='s_3_f_2.0' -python3 run_model.py --device=0 --s=3 --f=2.0 --trial=2 --name='s_3_f_2.0' -python3 run_model.py --device=0 --s=3 --f=2.0 --trial=3 --name='s_3_f_2.0' diff --git a/V1-models/scripts/s_f_sweep/script2 b/V1-models/scripts/s_f_sweep/script2 deleted file mode 100755 index fcf1b66..0000000 --- a/V1-models/scripts/s_f_sweep/script2 +++ /dev/null @@ -1,76 +0,0 @@ -#!/bin/bash -#python3 run_model.py --device=1 --s=4 --f=2.0 --trial=1 --name='s_4_f_2.0' -#python3 run_model.py --device=1 --s=4 --f=2.0 --trial=2 --name='s_4_f_2.0' -#python3 run_model.py --device=1 --s=4 --f=2.0 --trial=3 --name='s_4_f_2.0' -#python3 run_model.py --device=1 --s=5 --f=2.0 --trial=1 --name='s_5_f_2.0' -#python3 run_model.py --device=1 --s=5 --f=2.0 --trial=2 --name='s_5_f_2.0' -#python3 run_model.py --device=1 --s=5 --f=2.0 --trial=3 --name='s_5_f_2.0' -#python3 run_model.py --device=1 --s=6 --f=2.0 --trial=1 
--name='s_6_f_2.0' -#python3 run_model.py --device=1 --s=6 --f=2.0 --trial=2 --name='s_6_f_2.0' -#python3 run_model.py --device=1 --s=6 --f=2.0 --trial=3 --name='s_6_f_2.0' -#python3 run_model.py --device=1 --s=10 --f=2.0 --trial=1 --name='s_10_f_2.0' -#python3 run_model.py --device=1 --s=10 --f=2.0 --trial=2 --name='s_10_f_2.0' -#python3 run_model.py --device=1 --s=10 --f=2.0 --trial=3 --name='s_10_f_2.0' -python3 run_model.py --device=1 --s=1 --f=3.0 --trial=1 --name='s_1_f_3.0' -python3 run_model.py --device=1 --s=1 --f=3.0 --trial=2 --name='s_1_f_3.0' -python3 run_model.py --device=1 --s=1 --f=3.0 --trial=3 --name='s_1_f_3.0' -python3 run_model.py --device=1 --s=2 --f=3.0 --trial=1 --name='s_2_f_3.0' -python3 run_model.py --device=1 --s=2 --f=3.0 --trial=2 --name='s_2_f_3.0' -python3 run_model.py --device=1 --s=2 --f=3.0 --trial=3 --name='s_2_f_3.0' -python3 run_model.py --device=1 --s=3 --f=3.0 --trial=1 --name='s_3_f_3.0' -python3 run_model.py --device=1 --s=3 --f=3.0 --trial=2 --name='s_3_f_3.0' -python3 run_model.py --device=1 --s=3 --f=3.0 --trial=3 --name='s_3_f_3.0' -python3 run_model.py --device=1 --s=4 --f=3.0 --trial=1 --name='s_4_f_3.0' -python3 run_model.py --device=1 --s=4 --f=3.0 --trial=2 --name='s_4_f_3.0' -python3 run_model.py --device=1 --s=4 --f=3.0 --trial=3 --name='s_4_f_3.0' -python3 run_model.py --device=1 --s=5 --f=3.0 --trial=1 --name='s_5_f_3.0' -python3 run_model.py --device=1 --s=5 --f=3.0 --trial=2 --name='s_5_f_3.0' -python3 run_model.py --device=1 --s=5 --f=3.0 --trial=3 --name='s_5_f_3.0' -python3 run_model.py --device=1 --s=6 --f=3.0 --trial=1 --name='s_6_f_3.0' -python3 run_model.py --device=1 --s=6 --f=3.0 --trial=2 --name='s_6_f_3.0' -python3 run_model.py --device=1 --s=6 --f=3.0 --trial=3 --name='s_6_f_3.0' -python3 run_model.py --device=1 --s=10 --f=3.0 --trial=1 --name='s_10_f_3.0' -python3 run_model.py --device=1 --s=10 --f=3.0 --trial=2 --name='s_10_f_3.0' -python3 run_model.py --device=1 --s=10 --f=3.0 --trial=3 
--name='s_10_f_3.0' -python3 run_model.py --device=1 --s=1 --f=5.0 --trial=1 --name='s_1_f_5.0' -python3 run_model.py --device=1 --s=1 --f=5.0 --trial=2 --name='s_1_f_5.0' -python3 run_model.py --device=1 --s=1 --f=5.0 --trial=3 --name='s_1_f_5.0' -python3 run_model.py --device=1 --s=2 --f=5.0 --trial=1 --name='s_2_f_5.0' -python3 run_model.py --device=1 --s=2 --f=5.0 --trial=2 --name='s_2_f_5.0' -python3 run_model.py --device=1 --s=2 --f=5.0 --trial=3 --name='s_2_f_5.0' -python3 run_model.py --device=1 --s=3 --f=5.0 --trial=1 --name='s_3_f_5.0' -python3 run_model.py --device=1 --s=3 --f=5.0 --trial=2 --name='s_3_f_5.0' -python3 run_model.py --device=1 --s=3 --f=5.0 --trial=3 --name='s_3_f_5.0' -python3 run_model.py --device=1 --s=4 --f=5.0 --trial=1 --name='s_4_f_5.0' -python3 run_model.py --device=1 --s=4 --f=5.0 --trial=2 --name='s_4_f_5.0' -python3 run_model.py --device=1 --s=4 --f=5.0 --trial=3 --name='s_4_f_5.0' -python3 run_model.py --device=1 --s=5 --f=5.0 --trial=1 --name='s_5_f_5.0' -python3 run_model.py --device=1 --s=5 --f=5.0 --trial=2 --name='s_5_f_5.0' -python3 run_model.py --device=1 --s=5 --f=5.0 --trial=3 --name='s_5_f_5.0' -python3 run_model.py --device=1 --s=6 --f=5.0 --trial=1 --name='s_6_f_5.0' -python3 run_model.py --device=1 --s=6 --f=5.0 --trial=2 --name='s_6_f_5.0' -python3 run_model.py --device=1 --s=6 --f=5.0 --trial=3 --name='s_6_f_5.0' -python3 run_model.py --device=1 --s=10 --f=5.0 --trial=1 --name='s_10_f_5.0' -python3 run_model.py --device=1 --s=10 --f=5.0 --trial=2 --name='s_10_f_5.0' -python3 run_model.py --device=1 --s=10 --f=5.0 --trial=3 --name='s_10_f_5.0' -python3 run_model.py --device=1 --s=1 --f=7.0 --trial=1 --name='s_1_f_7.0' -python3 run_model.py --device=1 --s=1 --f=7.0 --trial=2 --name='s_1_f_7.0' -python3 run_model.py --device=1 --s=1 --f=7.0 --trial=3 --name='s_1_f_7.0' -python3 run_model.py --device=1 --s=2 --f=7.0 --trial=1 --name='s_2_f_7.0' -python3 run_model.py --device=1 --s=2 --f=7.0 --trial=2 
--name='s_2_f_7.0' -python3 run_model.py --device=1 --s=2 --f=7.0 --trial=3 --name='s_2_f_7.0' -python3 run_model.py --device=1 --s=3 --f=7.0 --trial=1 --name='s_3_f_7.0' -python3 run_model.py --device=1 --s=3 --f=7.0 --trial=2 --name='s_3_f_7.0' -python3 run_model.py --device=1 --s=3 --f=7.0 --trial=3 --name='s_3_f_7.0' -python3 run_model.py --device=1 --s=4 --f=7.0 --trial=1 --name='s_4_f_7.0' -python3 run_model.py --device=1 --s=4 --f=7.0 --trial=2 --name='s_4_f_7.0' -python3 run_model.py --device=1 --s=4 --f=7.0 --trial=3 --name='s_4_f_7.0' -python3 run_model.py --device=1 --s=5 --f=7.0 --trial=1 --name='s_5_f_7.0' -python3 run_model.py --device=1 --s=5 --f=7.0 --trial=2 --name='s_5_f_7.0' -python3 run_model.py --device=1 --s=5 --f=7.0 --trial=3 --name='s_5_f_7.0' -python3 run_model.py --device=1 --s=6 --f=7.0 --trial=1 --name='s_6_f_7.0' -python3 run_model.py --device=1 --s=6 --f=7.0 --trial=2 --name='s_6_f_7.0' -python3 run_model.py --device=1 --s=6 --f=7.0 --trial=3 --name='s_6_f_7.0' -python3 run_model.py --device=1 --s=10 --f=7.0 --trial=1 --name='s_10_f_7.0' -python3 run_model.py --device=1 --s=10 --f=7.0 --trial=2 --name='s_10_f_7.0' -python3 run_model.py --device=1 --s=10 --f=7.0 --trial=3 --name='s_10_f_7.0' diff --git a/V1-models/scripts/s_f_sweep/sweep_data.py b/V1-models/scripts/s_f_sweep/sweep_data.py deleted file mode 100644 index 77eef1d..0000000 --- a/V1-models/scripts/s_f_sweep/sweep_data.py +++ /dev/null @@ -1,50 +0,0 @@ -#Python script to take the average of 3 trials of BN_V1_V1_Linear run -import numpy as np -import argparse -import os -import torch -from scipy.stats import sem - -if __name__ == '__main__': - f_arr = np.array([0.1, 0.5, 1, 2, 3, 5, 7]) - s_arr = np.array([1, 2, 3, 4, 5, 6, 10]) - - file = open("sweep_data.txt", 'w') - header = "s f loss loss_err accuracy accuracy_err" - file.write(header) - file.write("\n") - - for f in f_arr: - for s in s_arr: - - src = 
"/research/harris/vivian/v1-models/saved-models/BN_V1_V1_Linear/s_f_sweep/" + "s_" + str(s) + "_f_" + str(f) - loss = [] - accuracy = [] - for i in range(3): - n=i+1 - os.chdir(src + "/trial_" + str(n)) - - loss_trial = torch.load("loss.pt")[-1] - accuracy_trial = torch.load("accuracy.pt")[-1] - - loss.append(loss_trial) - accuracy.append(accuracy_trial) - - - avg_loss = np.round_(np.mean(loss), decimals=4) - avg_acc = np.round_(np.mean(accuracy), decimals=2) - loss_err = np.round_(sem(loss), decimals=4) - acc_err = np.round_(sem(accuracy), decimals=2) - - - string = str(s) + " " + str(f) + " " + str(avg_loss) + " " + "+/-" + str(loss_err) + " " + str(avg_acc) + " " + "+/-" + str(acc_err) - file.write(string) - file.write("\n") - - - - - - - - \ No newline at end of file diff --git a/V1-models/scripts/s_f_sweep/sweep_data.txt b/V1-models/scripts/s_f_sweep/sweep_data.txt deleted file mode 100644 index 2531fc4..0000000 --- a/V1-models/scripts/s_f_sweep/sweep_data.txt +++ /dev/null @@ -1,50 +0,0 @@ -s f loss loss_err accuracy accuracy_err -1 0.1 1.0122 +/-0.0062 65.75 +/-0.27 -2 0.1 1.0045 +/-0.0043 65.86 +/-0.06 -3 0.1 1.0125 +/-0.0035 65.51 +/-0.12 -4 0.1 1.0315 +/-0.022 64.66 +/-0.75 -5 0.1 1.0101 +/-0.0078 65.77 +/-0.23 -6 0.1 1.0105 +/-0.006 65.7 +/-0.17 -10 0.1 1.0081 +/-0.0025 65.77 +/-0.03 -1 0.5 1.0295 +/-0.0053 65.09 +/-0.21 -2 0.5 1.0257 +/-0.0023 65.26 +/-0.09 -3 0.5 1.017 +/-0.0065 65.31 +/-0.39 -4 0.5 1.0134 +/-0.0063 65.37 +/-0.21 -5 0.5 1.0158 +/-0.0079 65.31 +/-0.34 -6 0.5 1.0201 +/-0.0047 65.29 +/-0.23 -10 0.5 1.0134 +/-0.004 65.5 +/-0.09 -1 1.0 1.0411 +/-0.0084 64.25 +/-0.32 -2 1.0 1.0378 +/-0.0067 64.49 +/-0.14 -3 1.0 1.0399 +/-0.0019 64.37 +/-0.26 -4 1.0 1.0436 +/-0.0068 64.13 +/-0.29 -5 1.0 1.0438 +/-0.0094 64.3 +/-0.53 -6 1.0 1.0523 +/-0.0044 63.95 +/-0.12 -10 1.0 1.0473 +/-0.0069 64.21 +/-0.37 -1 2.0 1.0539 +/-0.0016 63.78 +/-0.08 -2 2.0 1.0626 +/-0.0025 63.67 +/-0.23 -3 2.0 1.0705 +/-0.0078 63.24 +/-0.51 -4 2.0 1.0723 +/-0.0024 62.95 +/-0.12 
-5 2.0 1.0695 +/-0.0054 63.17 +/-0.19 -6 2.0 1.0797 +/-0.0065 62.96 +/-0.23 -10 2.0 1.1128 +/-0.0192 61.44 +/-0.75 -1 3.0 1.0621 +/-0.0022 63.42 +/-0.11 -2 3.0 1.0766 +/-0.0047 63.07 +/-0.23 -3 3.0 1.0916 +/-0.0048 62.2 +/-0.16 -4 3.0 1.101 +/-0.0093 62.12 +/-0.42 -5 3.0 1.0913 +/-0.0055 62.67 +/-0.03 -6 3.0 1.1036 +/-0.0021 61.88 +/-0.06 -10 3.0 1.0994 +/-0.0175 61.78 +/-0.59 -1 5.0 1.0879 +/-0.0011 62.63 +/-0.15 -2 5.0 1.1076 +/-0.0012 61.54 +/-0.05 -3 5.0 1.1116 +/-0.0053 61.52 +/-0.24 -4 5.0 1.1178 +/-0.0058 61.38 +/-0.17 -5 5.0 1.1131 +/-0.0034 61.53 +/-0.18 -6 5.0 1.1122 +/-0.0054 61.4 +/-0.15 -10 5.0 1.1212 +/-0.0062 61.38 +/-0.11 -1 7.0 1.0922 +/-0.0062 62.48 +/-0.26 -2 7.0 1.1276 +/-0.0033 60.94 +/-0.09 -3 7.0 1.1162 +/-0.0108 61.41 +/-0.43 -4 7.0 1.1275 +/-0.0112 61.01 +/-0.38 -5 7.0 1.1298 +/-0.0059 60.79 +/-0.18 -6 7.0 1.1261 +/-0.0055 61.19 +/-0.11 -10 7.0 1.1242 +/-0.0026 61.27 +/-0.09 diff --git a/V1-models/scripts/s_f_sweep/sweep_data_script.py b/V1-models/scripts/s_f_sweep/sweep_data_script.py deleted file mode 100644 index 8dbbe52..0000000 --- a/V1-models/scripts/s_f_sweep/sweep_data_script.py +++ /dev/null @@ -1,6 +0,0 @@ -f_list = [0.1, 2.0] -s_list = [1, 2, 3] -for f in f_list: - for s in s_list: - new_name = "s_"+str(s)+"_f_"+str(f) - print("python3 sweep_data.py --s={} --f={} ".format(s, f)) \ No newline at end of file From c6fa1a4e181d23fd45b14194e6bfc0093dcfb858 Mon Sep 17 00:00:00 2001 From: vivianwhite Date: Fri, 3 May 2024 20:24:48 -0400 Subject: [PATCH 57/77] updating LC RSN scripts --- FactConv/RSN-scripts/LC_CIFAR10.py | 10 ++++------ FactConv/RSN-scripts/LC_CIFAR100.py | 2 +- FactConv/RSN-scripts/LC_CIFAR10_Small_Sample.py | 2 +- 3 files changed, 6 insertions(+), 8 deletions(-) diff --git a/FactConv/RSN-scripts/LC_CIFAR10.py b/FactConv/RSN-scripts/LC_CIFAR10.py index c246682..9d5cc33 100644 --- a/FactConv/RSN-scripts/LC_CIFAR10.py +++ b/FactConv/RSN-scripts/LC_CIFAR10.py @@ -43,7 +43,7 @@ def test(model, device, test_loader, epoch): 
return test_loss, accuracy def save_model(args, model, loss, accuracy): - src = "/research/harris/vivian/v1-models/saved-models/CIFAR10/" + src = "../../saved-models/CIFAR10/" model_dir = src + args.name if not os.path.exists(model_dir): os.makedirs(model_dir) @@ -110,21 +110,19 @@ def save_model(args, model, loss, accuracy): train_loader = torch.utils.data.DataLoader( datasets.CIFAR10( - #root=scattering_datasets.get_dataset_dir('CIFAR'), - root="/research/harris/vivian/v1-models/datasets/new_CIFAR10", + root=scattering_datasets.get_dataset_dir('CIFAR'), train=True, transform=transforms.Compose([ transforms.RandomHorizontalFlip(), transforms.RandomCrop(32, 4), transforms.ToTensor(), normalize, - ]), download=False), + ]), download=True), batch_size=512, shuffle=True, num_workers=num_workers, pin_memory=pin_memory) test_loader = torch.utils.data.DataLoader( datasets.CIFAR10( - root="/research/harris/vivian/v1-models/datasets/new_CIFAR10", - #root=scattering_datasets.get_dataset_dir('CIFAR'), + root=scattering_datasets.get_dataset_dir('CIFAR'), train=False, transform=transforms.Compose([ transforms.ToTensor(), diff --git a/FactConv/RSN-scripts/LC_CIFAR100.py b/FactConv/RSN-scripts/LC_CIFAR100.py index 273f9dd..8beef5b 100644 --- a/FactConv/RSN-scripts/LC_CIFAR100.py +++ b/FactConv/RSN-scripts/LC_CIFAR100.py @@ -44,7 +44,7 @@ def test(model, device, test_loader, epoch): return test_loss, accuracy def save_model(args, model, loss, accuracy): - src = "/research/harris/vivian/v1-models/saved-models/CIFAR100/" + src = "../../saved-models/CIFAR100/" model_dir = src + args.name if not os.path.exists(model_dir): os.makedirs(model_dir) diff --git a/FactConv/RSN-scripts/LC_CIFAR10_Small_Sample.py b/FactConv/RSN-scripts/LC_CIFAR10_Small_Sample.py index cad992b..5f8bf14 100644 --- a/FactConv/RSN-scripts/LC_CIFAR10_Small_Sample.py +++ b/FactConv/RSN-scripts/LC_CIFAR10_Small_Sample.py @@ -46,7 +46,7 @@ def test(model, device, test_loader, epoch): return test_loss, accuracy def 
save_model(args, model, loss, accuracy): - src = "/research/harris/vivian/v1-models/saved-models/CIFAR10_50_Samples/" + src = "../../saved-models/CIFAR10_50_Samples/" model_dir = src + args.name if not os.path.exists(model_dir): os.makedirs(model_dir) From 345a598fd02b3be1c510b52fe7571b954421adfc Mon Sep 17 00:00:00 2001 From: vivianwhite Date: Fri, 3 May 2024 20:27:46 -0400 Subject: [PATCH 58/77] removing hardcoded stuff --- FactConv/pytorch_cifar.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/FactConv/pytorch_cifar.py b/FactConv/pytorch_cifar.py index 60263c3..4e8ad77 100644 --- a/FactConv/pytorch_cifar.py +++ b/FactConv/pytorch_cifar.py @@ -14,7 +14,7 @@ from models import define_models def save_model(args, model): - src= "/home/mila/v/vivian.white/scratch/v1-models/saved-models/test_refactor/" + src= "../../saved-models/ResNets/" model_dir = src + args.name os.makedirs(model_dir, exist_ok=True) os.chdir(model_dir) @@ -82,7 +82,7 @@ def save_model(args, model): set_seeds(args.seed) net = net.to(device) -wandb_dir = "/home/mila/v/vivian.white/scratch/v1-models/wandb" +wandb_dir = "../../wandb" os.makedirs(wandb_dir, exist_ok=True) os.chdir(wandb_dir) From 1c5295f52ddb4078da8425419e103f17618dd7fe Mon Sep 17 00:00:00 2001 From: vivianwhite Date: Fri, 3 May 2024 20:43:27 -0400 Subject: [PATCH 59/77] fixed save dir --- FactConv/pytorch_cifar.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/FactConv/pytorch_cifar.py b/FactConv/pytorch_cifar.py index 4e8ad77..7de8e15 100644 --- a/FactConv/pytorch_cifar.py +++ b/FactConv/pytorch_cifar.py @@ -14,7 +14,7 @@ from models import define_models def save_model(args, model): - src= "../../saved-models/ResNets/" + src= "../saved-models/ResNets/" model_dir = src + args.name os.makedirs(model_dir, exist_ok=True) os.chdir(model_dir) From 1b080e97e9164097d8bcb424845a9ba85b3c9c74 Mon Sep 17 00:00:00 2001 From: vivianwhite Date: Fri, 3 May 2024 21:46:52 -0400 Subject: [PATCH 60/77] fixed 
save dir --- FactConv/pytorch_cifar.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/FactConv/pytorch_cifar.py b/FactConv/pytorch_cifar.py index 7de8e15..893f38a 100644 --- a/FactConv/pytorch_cifar.py +++ b/FactConv/pytorch_cifar.py @@ -17,15 +17,9 @@ def save_model(args, model): src= "../saved-models/ResNets/" model_dir = src + args.name os.makedirs(model_dir, exist_ok=True) - os.chdir(model_dir) - #saves loss & accuracy in the trial directory -- all trials - trial_dir = model_dir + "/trial_" + str(1) - os.makedirs(trial_dir, exist_ok=True) - os.chdir(trial_dir) - - torch.save(model.state_dict(), trial_dir+ "/model.pt") - torch.save(args, trial_dir+ "/args.pt") + torch.save(model.state_dict(), model_dir+ "/model.pt") + torch.save(args, model_dir+ "/args.pt") parser = argparse.ArgumentParser(description='PyTorch CIFAR10 Training') From 2ae341419df85115e7e56960a35813acf747bf17 Mon Sep 17 00:00:00 2001 From: Muawiz Chaudhary Date: Wed, 8 May 2024 03:07:33 -0400 Subject: [PATCH 61/77] new experiments --- FactConv/launched.sh | 21 +++++++++++++++++++++ FactConv/pytorch_cifar.py | 4 ++-- FactConv/setoff.sh | 13 +++++++++++++ 3 files changed, 36 insertions(+), 2 deletions(-) create mode 100644 FactConv/launched.sh create mode 100644 FactConv/setoff.sh diff --git a/FactConv/launched.sh b/FactConv/launched.sh new file mode 100644 index 0000000..e27470c --- /dev/null +++ b/FactConv/launched.sh @@ -0,0 +1,21 @@ +#!/bin/bash +width=(0.125 0.25 0.5 1.0 2.0 4.0) +seed=(0 1 2) + + +for i in ${width[@]} +do + for j in ${seed[@]} + do + sbatch setoff.sh --width $i --seed $j --net resnet18 + sbatch setoff.sh --width $i --seed $j --net fact_resnet18 + sbatch setoff.sh --width $i --seed $j --net fact_us_resnet18 + sbatch setoff.sh --width $i --seed $j --net fact_uc_resnet18 + sbatch setoff.sh --width $i --seed $j --net fact_us_uc_resnet18 + # WHERE IS THIS VIVIAN 🧐🧐🤨🤨 + #sbatch setoff.sh --width $i --seed $j --net fact_diag_us_resnet18 + #sbatch 
setoff.sh --width $i --seed $j --net fact_diag_uc_resnet18 + #sbatch setoff.sh --width $i --seed $j --net fact_diag_us_uc_resnet18 + done +done + diff --git a/FactConv/pytorch_cifar.py b/FactConv/pytorch_cifar.py index 893f38a..264223e 100644 --- a/FactConv/pytorch_cifar.py +++ b/FactConv/pytorch_cifar.py @@ -70,7 +70,7 @@ def save_model(args, model): print('==> Building model..') net = define_models(args) -run_name = args.net +run_name = "{}_width_{}_seed_{}".format(args.net, args.width, args.seed) print("Args.net: ", args.net) print("Net: ", net) set_seeds(args.seed) @@ -81,7 +81,7 @@ def save_model(args, model): os.chdir(wandb_dir) run = wandb.init(project="FactConv", config=args, - group="pytorch_cifar", name=run_name, dir=wandb_dir) + group="testing", name=run_name, dir=wandb_dir) #wandb.watch(net, log='all', log_freq=1) diff --git a/FactConv/setoff.sh b/FactConv/setoff.sh new file mode 100644 index 0000000..0a3e069 --- /dev/null +++ b/FactConv/setoff.sh @@ -0,0 +1,13 @@ +#!/bin/bash +#SBATCH --gres=gpu:a100l:1 +#SBATCH --constraint="ampere" +#SBATCH -c 8 +#SBATCH --mem=20G +#SBATCH -t 20:00:00 +#SBATCH --output slurm/%j.out +#SBATCH --partition long + +module load python/3.8 +source ../refactor/env/bin/activate +echo "$@" +CUDA_VISIBLE_DEVICES=0 python pytorch_cifar.py "$@" From 378cbd9c5f43dab314e0399679be120eb56d7d60 Mon Sep 17 00:00:00 2001 From: vivianwhite Date: Wed, 8 May 2024 04:42:29 -0400 Subject: [PATCH 62/77] implementing diag factconv2d --- FactConv/conv_modules.py | 72 +++++++++++++++++++++++++++++++ FactConv/models/__init__.py | 5 ++- FactConv/models/function_utils.py | 25 ++++++++++- 3 files changed, 100 insertions(+), 2 deletions(-) diff --git a/FactConv/conv_modules.py b/FactConv/conv_modules.py index 5274173..eb8c944 100644 --- a/FactConv/conv_modules.py +++ b/FactConv/conv_modules.py @@ -85,3 +85,75 @@ def _tri_vec_to_mat(self, vec, n, scat_idx): U = torch.diagonal_scatter(U, U.diagonal().exp_()) return U +class DiagFactConv2d(nn.Conv2d): + 
def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: _size_2_t, + stride: _size_2_t = 1, + padding: Union[str, _size_2_t] = 0, + dilation: _size_2_t = 1, + groups: int = 1, + bias: bool = True, + padding_mode: str = 'zeros', # TODO: refine this type + device=None, + dtype=None + ) -> None: + # init as Conv2d + super().__init__( + in_channels, out_channels, kernel_size, stride, padding, dilation, + groups, bias, padding_mode, device, dtype) + # weight shape: (out_channels, in_channels // groups, *kernel_size) + new_weight = torch.empty_like(self.weight) + del self.weight # remove Parameter, create buffer + self.register_buffer("weight", new_weight) + nn.init.kaiming_normal_(self.weight) + + self.in_features = self.in_channels // self.groups * \ + self.kernel_size[0] * self.kernel_size[1] + triu1 = torch.triu_indices(self.in_channels // self.groups, + self.in_channels // self.groups, + device=self.weight.device, + dtype=torch.long) + mask = triu1[0] == triu1[1] + scat_idx1 = triu1[0][mask]*self.in_channels//self.groups + triu1[1][mask] + self.register_buffer("scat_idx1", scat_idx1, persistent=False) + + triu2 = torch.triu_indices(self.kernel_size[0] * self.kernel_size[1], + self.kernel_size[0] + * self.kernel_size[1], + device=self.weight.device, + dtype=torch.long) + mask = triu2[0] == triu2[1] + scat_idx2 = triu2[0][mask]*self.kernel_size[0]*self.kernel_size[1] + triu2[1][mask] + + self.register_buffer("scat_idx2", scat_idx2, persistent=False) + + triu1_len = scat_idx1.shape[0] + triu2_len = scat_idx2.shape[0] + + tri1_vec = self.weight.new_zeros((triu1_len,)) + self.tri1_vec = Parameter(tri1_vec) + + tri2_vec = self.weight.new_zeros((triu2_len,)) + self.tri2_vec = Parameter(tri2_vec) + + + def forward(self, input: Tensor) -> Tensor: + U1 = self._tri_vec_to_mat(self.tri1_vec, self.in_channels // + self.groups, self.scat_idx1) + U2 = self._tri_vec_to_mat(self.tri2_vec, self.kernel_size[0] * self.kernel_size[1], + self.scat_idx2) + # flatten 
over filter dims and contract + composite_weight = _contract(self.weight, U1.T, 1) + composite_weight = _contract( + torch.flatten(composite_weight, -2, -1), U2.T, -1 + ).reshape(self.weight.shape) + return self._conv_forward(input, composite_weight, self.bias) + + def _tri_vec_to_mat(self, vec, n, scat_idx): + U = self.weight.new_zeros((n*n)).scatter_(0, scat_idx, vec).view(n, n) + U = torch.diagonal_scatter(U, U.diagonal().exp_()) + return U + diff --git a/FactConv/models/__init__.py b/FactConv/models/__init__.py index 434ce4f..5498049 100644 --- a/FactConv/models/__init__.py +++ b/FactConv/models/__init__.py @@ -1,5 +1,6 @@ from .resnet import ResNet18 -from .function_utils import replace_layers_factconv2d, turn_off_covar_grad, replace_layers_scale, init_V1_layers +from .function_utils import replace_layers_factconv2d,\ +replace_layers_diagfactconv2d, turn_off_covar_grad, replace_layers_scale, init_V1_layers def define_models(args): @@ -9,6 +10,8 @@ def define_models(args): replace_layers_scale(model, args.width) if 'fact' in args.net: replace_layers_factconv2d(model) + if 'diag' in args.net: + replace_layers_diagfactconv2d(model) if "v1" in args.net: init_V1_layers(model, bias=False) if "us" in args.net: diff --git a/FactConv/models/function_utils.py b/FactConv/models/function_utils.py index c642039..7207210 100644 --- a/FactConv/models/function_utils.py +++ b/FactConv/models/function_utils.py @@ -1,6 +1,6 @@ import torch import torch.nn as nn -from conv_modules import FactConv2d +from conv_modules import FactConv2d, DiagFactConv2d from V1_covariance import V1_init @@ -38,6 +38,29 @@ def _replace_layers_factconv2d(module): return recurse_preorder(model, _replace_layers_factconv2d) +def replace_layers_diagfactconv2d(model): + ''' + Replace nn.Conv2d layers with DiagFactConv2d + ''' + def _replace_layers_diagfactconv2d(module): + if isinstance(module, nn.Conv2d): + ## simple module + new_module = DiagFactConv2d( + in_channels=module.in_channels, + 
out_channels=module.out_channels, + kernel_size=module.kernel_size, + stride=module.stride, padding=module.padding, + bias=True if module.bias is not None else False) + old_sd = module.state_dict() + new_sd = new_module.state_dict() + new_sd['weight'] = old_sd['weight'] + if module.bias is not None: + new_sd['bias'] = old_sd['bias'] + new_module.load_state_dict(new_sd) + return new_module + return recurse_preorder(model, _replace_layers_diagfactconv2d) + + def replace_affines(model): ''' Set BatchNorm2d layers to have 'affine=False' From 2ffc07190d566f7e8753ecba29669fb48a267ebd Mon Sep 17 00:00:00 2001 From: Muawiz Chaudhary Date: Wed, 8 May 2024 04:56:29 -0400 Subject: [PATCH 63/77] minor changes --- FactConv/launched.sh | 18 ++++++++++-------- FactConv/models/function_utils.py | 2 +- 2 files changed, 11 insertions(+), 9 deletions(-) diff --git a/FactConv/launched.sh b/FactConv/launched.sh index e27470c..3904685 100644 --- a/FactConv/launched.sh +++ b/FactConv/launched.sh @@ -7,15 +7,17 @@ for i in ${width[@]} do for j in ${seed[@]} do - sbatch setoff.sh --width $i --seed $j --net resnet18 - sbatch setoff.sh --width $i --seed $j --net fact_resnet18 - sbatch setoff.sh --width $i --seed $j --net fact_us_resnet18 - sbatch setoff.sh --width $i --seed $j --net fact_uc_resnet18 - sbatch setoff.sh --width $i --seed $j --net fact_us_uc_resnet18 + #sbatch setoff.sh --width $i --seed $j --net resnet18 + #sbatch setoff.sh --width $i --seed $j --net fact_resnet18 + #sbatch setoff.sh --width $i --seed $j --net fact_us_resnet18 + #sbatch setoff.sh --width $i --seed $j --net fact_uc_resnet18 + #sbatch setoff.sh --width $i --seed $j --net fact_us_uc_resnet18 # WHERE IS THIS VIVIAN 🧐🧐🤨🤨 - #sbatch setoff.sh --width $i --seed $j --net fact_diag_us_resnet18 - #sbatch setoff.sh --width $i --seed $j --net fact_diag_uc_resnet18 - #sbatch setoff.sh --width $i --seed $j --net fact_diag_us_uc_resnet18 + # NVM good job Vivian + sbatch setoff.sh --width $i --seed $j --net fact_diag_resnet18 
+ sbatch setoff.sh --width $i --seed $j --net fact_diag_us_resnet18 + sbatch setoff.sh --width $i --seed $j --net fact_diag_uc_resnet18 + sbatch setoff.sh --width $i --seed $j --net fact_diag_us_uc_resnet18 done done diff --git a/FactConv/models/function_utils.py b/FactConv/models/function_utils.py index 7207210..0c732aa 100644 --- a/FactConv/models/function_utils.py +++ b/FactConv/models/function_utils.py @@ -172,7 +172,7 @@ def turn_off_covar_grad(model, covariance): channel or spatial covariance learning ''' def _turn_off_covar_grad(module): - if isinstance(module, FactConv2d): + if isinstance(module, FactConv2d) or isinstance(module, DiagFactConv2d): for name, param in module.named_parameters(): if covariance == "channel": if "tri1_vec" in name: From 74aa5a3ec40dc14d7cb3b63f64ab2e52765c0823 Mon Sep 17 00:00:00 2001 From: vivianwhite Date: Wed, 8 May 2024 08:42:20 -0400 Subject: [PATCH 64/77] conv_modules.py --- FactConv/models/__init__.py | 4 +++- FactConv/models/function_utils.py | 24 +++++++++++++++++++++++- 2 files changed, 26 insertions(+), 2 deletions(-) diff --git a/FactConv/models/__init__.py b/FactConv/models/__init__.py index 5498049..d5550f2 100644 --- a/FactConv/models/__init__.py +++ b/FactConv/models/__init__.py @@ -1,6 +1,6 @@ from .resnet import ResNet18 from .function_utils import replace_layers_factconv2d,\ -replace_layers_diagfactconv2d, turn_off_covar_grad, replace_layers_scale, init_V1_layers +replace_layers_diagfactconv2d, replace_layers_diagchanfactconv2d, turn_off_covar_grad, replace_layers_scale, init_V1_layers def define_models(args): @@ -12,6 +12,8 @@ def define_models(args): replace_layers_factconv2d(model) if 'diag' in args.net: replace_layers_diagfactconv2d(model) + if 'diagchan' in args.net: + replace_layers_diagchanfactconv2d(model) if "v1" in args.net: init_V1_layers(model, bias=False) if "us" in args.net: diff --git a/FactConv/models/function_utils.py b/FactConv/models/function_utils.py index 0c732aa..4d207e0 100644 --- 
a/FactConv/models/function_utils.py +++ b/FactConv/models/function_utils.py @@ -1,6 +1,6 @@ import torch import torch.nn as nn -from conv_modules import FactConv2d, DiagFactConv2d +from conv_modules import FactConv2d, DiagFactConv2d, DiagChanFactConv2d from V1_covariance import V1_init @@ -60,6 +60,28 @@ def _replace_layers_diagfactconv2d(module): return new_module return recurse_preorder(model, _replace_layers_diagfactconv2d) +def replace_layers_diagchanfactconv2d(model): + ''' + Replace nn.Conv2d layers with DiagChanFactConv2d + ''' + def _replace_layers_diagchanfactconv2d(module): + if isinstance(module, nn.Conv2d): + ## simple module + new_module = DiagChanFactConv2d( + in_channels=module.in_channels, + out_channels=module.out_channels, + kernel_size=module.kernel_size, + stride=module.stride, padding=module.padding, + bias=True if module.bias is not None else False) + old_sd = module.state_dict() + new_sd = new_module.state_dict() + new_sd['weight'] = old_sd['weight'] + if module.bias is not None: + new_sd['bias'] = old_sd['bias'] + new_module.load_state_dict(new_sd) + return new_module + return recurse_preorder(model, _replace_layers_diagchanfactconv2d) + def replace_affines(model): ''' From 2366ef82629908fc65bf3a77d695dee91b2879f7 Mon Sep 17 00:00:00 2001 From: vivianwhite Date: Wed, 8 May 2024 08:44:30 -0400 Subject: [PATCH 65/77] add diagchan, call by setting args.net='resnet18-diagchan' --- FactConv/conv_modules.py | 70 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 70 insertions(+) diff --git a/FactConv/conv_modules.py b/FactConv/conv_modules.py index eb8c944..60db42c 100644 --- a/FactConv/conv_modules.py +++ b/FactConv/conv_modules.py @@ -157,3 +157,73 @@ def _tri_vec_to_mat(self, vec, n, scat_idx): U = torch.diagonal_scatter(U, U.diagonal().exp_()) return U +class DiagChanFactConv2d(nn.Conv2d): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: _size_2_t, + stride: _size_2_t = 1, + padding: Union[str, _size_2_t] = 
0, + dilation: _size_2_t = 1, + groups: int = 1, + bias: bool = True, + padding_mode: str = 'zeros', # TODO: refine this type + device=None, + dtype=None + ) -> None: + # init as Conv2d + super().__init__( + in_channels, out_channels, kernel_size, stride, padding, dilation, + groups, bias, padding_mode, device, dtype) + # weight shape: (out_channels, in_channels // groups, *kernel_size) + new_weight = torch.empty_like(self.weight) + del self.weight # remove Parameter, create buffer + self.register_buffer("weight", new_weight) + nn.init.kaiming_normal_(self.weight) + + self.in_features = self.in_channels // self.groups * \ + self.kernel_size[0] * self.kernel_size[1] + triu1 = torch.triu_indices(self.in_channels // self.groups, + self.in_channels // self.groups, + device=self.weight.device, + dtype=torch.long) + mask = triu1[0] == triu1[1] + scat_idx1 = triu1[0][mask]*self.in_channels//self.groups + triu1[1][mask] + self.register_buffer("scat_idx1", scat_idx1, persistent=False) + + triu2 = torch.triu_indices(self.kernel_size[0] * self.kernel_size[1], + self.kernel_size[0] + * self.kernel_size[1], + device=self.weight.device, + dtype=torch.long) + scat_idx2 = triu2[0]*self.kernel_size[0]*self.kernel_size[1] + triu2[1] + self.register_buffer("scat_idx2", scat_idx2, persistent=False) + + triu1_len = scat_idx1.shape[0] + triu2_len = triu2.shape[1] + + tri1_vec = self.weight.new_zeros((triu1_len,)) + self.tri1_vec = Parameter(tri1_vec) + + tri2_vec = self.weight.new_zeros((triu2_len,)) + self.tri2_vec = Parameter(tri2_vec) + + + def forward(self, input: Tensor) -> Tensor: + U1 = self._tri_vec_to_mat(self.tri1_vec, self.in_channels // + self.groups, self.scat_idx1) + U2 = self._tri_vec_to_mat(self.tri2_vec, self.kernel_size[0] * self.kernel_size[1], + self.scat_idx2) + # flatten over filter dims and contract + composite_weight = _contract(self.weight, U1.T, 1) + composite_weight = _contract( + torch.flatten(composite_weight, -2, -1), U2.T, -1 + ).reshape(self.weight.shape) 
+ return self._conv_forward(input, composite_weight, self.bias) + + def _tri_vec_to_mat(self, vec, n, scat_idx): + U = self.weight.new_zeros((n*n)).scatter_(0, scat_idx, vec).view(n, n) + U = torch.diagonal_scatter(U, U.diagonal().exp_()) + return U + From c720d95c2bf757a794a1be049de9415536896c64 Mon Sep 17 00:00:00 2001 From: vivianwhite Date: Tue, 25 Jun 2024 17:08:25 -0400 Subject: [PATCH 66/77] added low-rank setup --- FactConv/conv_modules.py | 105 ++++++++++++++++++++---------- FactConv/models/__init__.py | 16 ++++- FactConv/models/function_utils.py | 53 ++++++++++++--- FactConv/pytorch_cifar.py | 20 ++++-- 4 files changed, 142 insertions(+), 52 deletions(-) diff --git a/FactConv/conv_modules.py b/FactConv/conv_modules.py index 60db42c..0d80a25 100644 --- a/FactConv/conv_modules.py +++ b/FactConv/conv_modules.py @@ -4,6 +4,7 @@ from torch.nn.parameter import Parameter from torch.nn.common_types import _size_2_t from typing import Optional, List, Tuple, Union +from cov import Covariance, LowRankCovariance """ The function below is copied directly from @@ -41,38 +42,71 @@ def __init__( self.register_buffer("weight", new_weight) nn.init.kaiming_normal_(self.weight) - self.in_features = self.in_channels // self.groups * \ - self.kernel_size[0] * self.kernel_size[1] - triu1 = torch.triu_indices(self.in_channels // self.groups, - self.in_channels // self.groups, - device=self.weight.device, - dtype=torch.long) - scat_idx1 = triu1[0]*self.in_channels//self.groups + triu1[1] - self.register_buffer("scat_idx1", scat_idx1, persistent=False) + # self.in_features = self.in_channels // self.groups * \ + # self.kernel_size[0] * self.kernel_size[1] - triu2 = torch.triu_indices(self.kernel_size[0] * self.kernel_size[1], - self.kernel_size[0] - * self.kernel_size[1], - device=self.weight.device, - dtype=torch.long) - scat_idx2 = triu2[0]*self.kernel_size[0]*self.kernel_size[1] + triu2[1] - self.register_buffer("scat_idx2", scat_idx2, persistent=False) + channel_triu_size = 
self.in_channels // self.groups + spatial_triu_size = self.kernel_size[0] * self.kernel_size[1] - triu1_len = triu1.shape[1] - triu2_len = triu2.shape[1] + self.channel = Covariance(channel_triu_size) + self.spatial = Covariance(spatial_triu_size) - tri1_vec = self.weight.new_zeros((triu1_len,)) - self.tri1_vec = Parameter(tri1_vec) + def forward(self, input: Tensor) -> Tensor: + U1 = self.channel._tri_vec_to_mat(self.channel.tri_vec, + self.channel.triu_size, self.channel.scat_idx) + U2 = self.spatial._tri_vec_to_mat(self.spatial.tri_vec, + self.spatial.triu_size, self.spatial.scat_idx) - tri2_vec = self.weight.new_zeros((triu2_len,)) - self.tri2_vec = Parameter(tri2_vec) + # flatten over filter dims and contract + composite_weight = _contract(self.weight, U1.T, 1) + composite_weight = _contract( + torch.flatten(composite_weight, -2, -1), U2.T, -1 + ).reshape(self.weight.shape) + return self._conv_forward(input, composite_weight, self.bias) + + +class LowRankFactConv2d(nn.Conv2d): + def __init__( + self, + spatial_k, + channel_k, + in_channels: int, + out_channels: int, + kernel_size: _size_2_t, + stride: _size_2_t = 1, + padding: Union[str, _size_2_t] = 0, + dilation: _size_2_t = 1, + groups: int = 1, + bias: bool = True, + padding_mode: str = 'zeros', # TODO: refine this type + device=None, + dtype=None, + ) -> None: + # init as Conv2d + super().__init__( + in_channels, out_channels, kernel_size, stride, padding, dilation, + groups, bias, padding_mode, device, dtype) + self.spatial_k = spatial_k + self.channel_k = channel_k + # weight shape: (out_channels, in_channels // groups, *kernel_size) + new_weight = torch.empty_like(self.weight) + del self.weight # remove Parameter, create buffer + self.register_buffer("weight", new_weight) + nn.init.kaiming_normal_(self.weight) + + channel_triu_size = self.in_channels // self.groups + spatial_triu_size = self.kernel_size[0] * self.kernel_size[1] + + self.channel = LowRankCovariance(channel_triu_size, self.channel_k) + 
self.spatial = LowRankCovariance(spatial_triu_size, self.spatial_k) def forward(self, input: Tensor) -> Tensor: - U1 = self._tri_vec_to_mat(self.tri1_vec, self.in_channels // - self.groups, self.scat_idx1) - U2 = self._tri_vec_to_mat(self.tri2_vec, self.kernel_size[0] * self.kernel_size[1], - self.scat_idx2) + U1 = self.channel._tri_vec_to_mat(self.channel.tri_vec, + self.channel.triu_size, self.channel.scat_idx) + U2 = self.spatial._tri_vec_to_mat(self.spatial.tri_vec, + self.spatial.triu_size, self.spatial.scat_idx) + # flatten over filter dims and contract composite_weight = _contract(self.weight, U1.T, 1) composite_weight = _contract( @@ -80,11 +114,6 @@ def forward(self, input: Tensor) -> Tensor: ).reshape(self.weight.shape) return self._conv_forward(input, composite_weight, self.bias) - def _tri_vec_to_mat(self, vec, n, scat_idx): - U = self.weight.new_zeros((n*n)).scatter_(0, scat_idx, vec).view(n, n) - U = torch.diagonal_scatter(U, U.diagonal().exp_()) - return U - class DiagFactConv2d(nn.Conv2d): def __init__( self, @@ -184,12 +213,16 @@ def __init__( self.in_features = self.in_channels // self.groups * \ self.kernel_size[0] * self.kernel_size[1] - triu1 = torch.triu_indices(self.in_channels // self.groups, - self.in_channels // self.groups, - device=self.weight.device, - dtype=torch.long) - mask = triu1[0] == triu1[1] - scat_idx1 = triu1[0][mask]*self.in_channels//self.groups + triu1[1][mask] + #triu1 = torch.triu_indices(self.in_channels // self.groups, + # self.in_channels // self.groups, + # device=self.weight.device, + # dtype=torch.long) + #mask = triu1[0] == triu1[1] + #scat_idx1 = triu1[0][mask]*self.in_channels//self.groups + triu1[1][mask] + matrix = torch.arange(self.in_channels // self.groups \ + * self.in_channels // self.groups).view(self.in_channels // \ + self.groups, self.in_channels // self.groups) + scat_idx1 = torch.diagonal(matrix) self.register_buffer("scat_idx1", scat_idx1, persistent=False) triu2 = 
torch.triu_indices(self.kernel_size[0] * self.kernel_size[1], diff --git a/FactConv/models/__init__.py b/FactConv/models/__init__.py index d5550f2..c245851 100644 --- a/FactConv/models/__init__.py +++ b/FactConv/models/__init__.py @@ -1,11 +1,19 @@ from .resnet import ResNet18 +from .LC_models import LC_CIFAR10 from .function_utils import replace_layers_factconv2d,\ -replace_layers_diagfactconv2d, replace_layers_diagchanfactconv2d, turn_off_covar_grad, replace_layers_scale, init_V1_layers +replace_layers_diagfactconv2d, replace_layers_diagchanfactconv2d,\ +turn_off_covar_grad, replace_layers_scale, init_V1_layers,\ +replace_layers_lowrank def define_models(args): if 'resnet18' in args.net: model = ResNet18() + if 'rsn' in args.net: + model = LC_CIFAR10(hidden_dim=100, size=2, spatial_freq=0.1, scale=1, + bias=True, freeze_spatial=False, freeze_channel=False, + spatial_init='V1') + print("RSN model") if args.width != 1: replace_layers_scale(model, args.width) if 'fact' in args.net: @@ -14,11 +22,17 @@ def define_models(args): replace_layers_diagfactconv2d(model) if 'diagchan' in args.net: replace_layers_diagchanfactconv2d(model) + print("Diag Chan") + if 'lowrank' in args.net: + replace_layers_lowrank(model, args.spatial_k, args.channel_k) + print("Low rank") if "v1" in args.net: init_V1_layers(model, bias=False) if "us" in args.net: turn_off_covar_grad(model, "spatial") + print("US") if "uc" in args.net: turn_off_covar_grad(model, "channel") + print("UC") return model diff --git a/FactConv/models/function_utils.py b/FactConv/models/function_utils.py index 4d207e0..2dfc62d 100644 --- a/FactConv/models/function_utils.py +++ b/FactConv/models/function_utils.py @@ -1,6 +1,8 @@ import torch import torch.nn as nn -from conv_modules import FactConv2d, DiagFactConv2d, DiagChanFactConv2d +from conv_modules import FactConv2d, DiagFactConv2d, DiagChanFactConv2d,\ +LowRankFactConv2d +from cov import Covariance, LowRankCovariance from V1_covariance import V1_init @@ -189,19 
+191,30 @@ def _replace_layers_fact_with_conv(module): def turn_off_covar_grad(model, covariance): + print("Unlearnable ", covariance) ''' Turn off gradients in tri1_vec or tri2_vec to turn off channel or spatial covariance learning ''' def _turn_off_covar_grad(module): - if isinstance(module, FactConv2d) or isinstance(module, DiagFactConv2d): - for name, param in module.named_parameters(): - if covariance == "channel": - if "tri1_vec" in name: - param.requires_grad = False - if covariance == "spatial": - if "tri2_vec" in name: - param.requires_grad = False + if isinstance(module, FactConv2d) or isinstance(module, DiagFactConv2d)\ + or isinstance(module, DiagChanFactConv2d) or isinstance(module, LowRankFactConv2d): + for name, mod in module.named_modules(): + if isinstance(mod, Covariance): + if covariance == name: + for name, param in mod.named_parameters(): + if "tri_vec" in name: + param.requires_grad = False + + # for name, param in module.named_parameters(): + # if covariance == "channel": + # if "tri1_vec" in name: + # param.requires_grad = False + # print("Unlearnable Channel") + # if covariance == "spatial": + # if "tri2_vec" in name: + # param.requires_grad = False + # print("Unlearnable Spatial") return recurse_preorder(model, _turn_off_covar_grad) @@ -236,3 +249,25 @@ def _init_V1_layers(module): param.requires_grad = False return recurse_preorder(model, _init_V1_layers) +def replace_layers_lowrank(model, spatial_k, channel_k): + ''' + Replace nn.Conv2d layers with LowRankFactConv2d + ''' + def _replace_layers_lowrank(module): + if isinstance(module, nn.Conv2d): + ## simple module + new_module = LowRankFactConv2d( + in_channels=module.in_channels, + out_channels=module.out_channels, + kernel_size=module.kernel_size, + stride=module.stride, padding=module.padding, + spatial_k=spatial_k, channel_k=channel_k, + bias=True if module.bias is not None else False) + old_sd = module.state_dict() + new_sd = new_module.state_dict() + new_sd['weight'] = 
old_sd['weight'] + if module.bias is not None: + new_sd['bias'] = old_sd['bias'] + new_module.load_state_dict(new_sd) + return new_module + return recurse_preorder(model, _replace_layers_lowrank) diff --git a/FactConv/pytorch_cifar.py b/FactConv/pytorch_cifar.py index 264223e..9f600ad 100644 --- a/FactConv/pytorch_cifar.py +++ b/FactConv/pytorch_cifar.py @@ -14,8 +14,10 @@ from models import define_models def save_model(args, model): - src= "../saved-models/ResNets/" - model_dir = src + args.name + #src= "../saved-models/ResNets/" + src="/home/mila/v/vivian.white/scratch/v1-models/saved-models/low-rank/" + model_dir = src + args.net + "-seed" + str(args.seed) + print("Model dir: ", model_dir) os.makedirs(model_dir, exist_ok=True) torch.save(model.state_dict(), model_dir+ "/model.pt") @@ -32,6 +34,8 @@ def save_model(args, model): help='filename for saved model') parser.add_argument('--seed', default=0, type=int, help='seed to use') parser.add_argument('--width', type=float, default=1, help='resnet width scale factor') +parser.add_argument('--spatial_k', type=float, default=1, help='%spatial low-rank') +parser.add_argument('--channel_k', type=float, default=1, help='%channel low-rank') args = parser.parse_args() @@ -73,22 +77,26 @@ def save_model(args, model): run_name = "{}_width_{}_seed_{}".format(args.net, args.width, args.seed) print("Args.net: ", args.net) print("Net: ", net) + set_seeds(args.seed) net = net.to(device) -wandb_dir = "../../wandb" +wandb_dir = "/home/mila/v/vivian.white/scratch/v1-models/wandb" os.makedirs(wandb_dir, exist_ok=True) os.chdir(wandb_dir) -run = wandb.init(project="FactConv", config=args, - group="testing", name=run_name, dir=wandb_dir) +print("Num Learnable Params: ", sum(p.numel() for p in net.parameters() if + p.requires_grad)) + +run = wandb.init(project="factconv", config=args, group="covtest", name=run_name, dir=wandb_dir) #wandb.watch(net, log='all', log_freq=1) criterion = nn.CrossEntropyLoss() optimizer = 
optim.SGD(net.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4) -scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=200) +scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer,\ + T_max=args.num_epochs) # Training def train(epoch): From 4a04e1db4b5585e0bc1a1f7277d94ca0da7f69fb Mon Sep 17 00:00:00 2001 From: vivianwhite Date: Wed, 26 Jun 2024 14:06:18 -0400 Subject: [PATCH 67/77] low-rank experiments --- FactConv/pytorch_cifar.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/FactConv/pytorch_cifar.py b/FactConv/pytorch_cifar.py index 9f600ad..5dbad88 100644 --- a/FactConv/pytorch_cifar.py +++ b/FactConv/pytorch_cifar.py @@ -74,7 +74,7 @@ def save_model(args, model): print('==> Building model..') net = define_models(args) -run_name = "{}_width_{}_seed_{}".format(args.net, args.width, args.seed) +run_name = "{}_{}_seed_{}".format(args.net, args.channel_k, args.seed) print("Args.net: ", args.net) print("Net: ", net) @@ -88,7 +88,7 @@ def save_model(args, model): print("Num Learnable Params: ", sum(p.numel() for p in net.parameters() if p.requires_grad)) -run = wandb.init(project="factconv", config=args, group="covtest", name=run_name, dir=wandb_dir) +run = wandb.init(project="factconv", config=args, group="lowrank", name=run_name, dir=wandb_dir) #wandb.watch(net, log='all', log_freq=1) From 0342b2e24ff24ad7219704c2a9f04fa20e06e77e Mon Sep 17 00:00:00 2001 From: vivianwhite Date: Wed, 26 Jun 2024 14:07:47 -0400 Subject: [PATCH 68/77] adding covariance module --- FactConv/cov.py | 55 +++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 55 insertions(+) create mode 100644 FactConv/cov.py diff --git a/FactConv/cov.py b/FactConv/cov.py new file mode 100644 index 0000000..bddc8f9 --- /dev/null +++ b/FactConv/cov.py @@ -0,0 +1,55 @@ +import torch +from torch import Tensor +import torch.nn as nn +from torch.nn.parameter import Parameter +from torch.nn.common_types import _size_2_t +from typing import 
Optional, List, Tuple, Union +import math + +class Covariance(nn.Module): + def __init__(self, triu_size): + super().__init__() + self.triu_size = triu_size + + # for channel cov, triu_size is in_channels // groups + # for spatial cov, triu_size is kernel_size[0] * kernel_size[1] + + triu = torch.triu_indices(self.triu_size, self.triu_size, + dtype=torch.long) + scat_idx = triu[0] * self.triu_size + triu[1] + self.register_buffer("scat_idx", scat_idx, persistent=False) + + triu_len = triu.shape[1] + #tri_vec = self.weight.new_zeros((triu_len,)) + tri_vec = torch.zeros((triu_len,)) + self.tri_vec = Parameter(tri_vec) + + + def cov(self, R): + return R.T @ R + + def sqrt(self, R): + return R + + def _tri_vec_to_mat(self, vec, n, scat_idx): + #r = self.weight.new_zeros((n*n)).scatter_(0, scat_idx, vec).view(n, n) + r = torch.zeros((n*n), device=self.tri_vec.device, + dtype=self.tri_vec.dtype).scatter_(0, scat_idx, vec).view(n, n) + r = torch.diagonal_scatter(r, r.diagonal().exp_()) + return r + +class LowRankCovariance(Covariance): + def __init__(self, triu_size, rank): + super().__init__(triu_size) + self.k = math.ceil(rank * triu_size) +# print("Rank: ", self.k) + triu = torch.triu_indices(self.triu_size, self.triu_size, + dtype=torch.long) + mask = triu[0] < self.k + scat_idx = triu[0][mask] * self.triu_size + triu[1][mask] + self.register_buffer("scat_idx", scat_idx, persistent=False) + + triu_len = scat_idx.shape[0] + tri_vec = torch.zeros((triu_len,)) + self.tri_vec = Parameter(tri_vec) + From a8c7b5d2f8c933f25dffe782303140316b12553e Mon Sep 17 00:00:00 2001 From: vivianwhite Date: Fri, 28 Jun 2024 15:47:03 -0400 Subject: [PATCH 69/77] adding learnable low-rank diagonal covariance --- FactConv/conv_modules.py | 65 +++++++++++++++++++++++++++++++++------- FactConv/cov.py | 49 ++++++++++++++++++++++++------ 2 files changed, 95 insertions(+), 19 deletions(-) diff --git a/FactConv/conv_modules.py b/FactConv/conv_modules.py index 0d80a25..fb4a8b0 100644 --- 
a/FactConv/conv_modules.py +++ b/FactConv/conv_modules.py @@ -4,7 +4,7 @@ from torch.nn.parameter import Parameter from torch.nn.common_types import _size_2_t from typing import Optional, List, Tuple, Union -from cov import Covariance, LowRankCovariance +from cov import Covariance, LowRankCovariance, LowRankPlusDiagCovariance """ The function below is copied directly from @@ -52,10 +52,8 @@ def __init__( self.spatial = Covariance(spatial_triu_size) def forward(self, input: Tensor) -> Tensor: - U1 = self.channel._tri_vec_to_mat(self.channel.tri_vec, - self.channel.triu_size, self.channel.scat_idx) - U2 = self.spatial._tri_vec_to_mat(self.spatial.tri_vec, - self.spatial.triu_size, self.spatial.scat_idx) + U1 = self.channel.sqrt() + U2 = self.spatial.sqrt() # flatten over filter dims and contract composite_weight = _contract(self.weight, U1.T, 1) @@ -98,14 +96,61 @@ def __init__( spatial_triu_size = self.kernel_size[0] * self.kernel_size[1] self.channel = LowRankCovariance(channel_triu_size, self.channel_k) - self.spatial = LowRankCovariance(spatial_triu_size, self.spatial_k) + #self.spatial = LowRankCovariance(spatial_triu_size, self.spatial_k) + self.spatial = Covariance(spatial_triu_size) + + + def forward(self, input: Tensor) -> Tensor: + U1 = self.channel.sqrt() + U2 = self.spatial.sqrt() + + # flatten over filter dims and contract + composite_weight = _contract(self.weight, U1.T, 1) + composite_weight = _contract( + torch.flatten(composite_weight, -2, -1), U2.T, -1 + ).reshape(self.weight.shape) + return self._conv_forward(input, composite_weight, self.bias) + +class LowRankPlusDiagFactConv2d(nn.Conv2d): + def __init__( + self, + spatial_k, + channel_k, + in_channels: int, + out_channels: int, + kernel_size: _size_2_t, + stride: _size_2_t = 1, + padding: Union[str, _size_2_t] = 0, + dilation: _size_2_t = 1, + groups: int = 1, + bias: bool = True, + padding_mode: str = 'zeros', # TODO: refine this type + device=None, + dtype=None, + ) -> None: + # init as Conv2d 
+ super().__init__( + in_channels, out_channels, kernel_size, stride, padding, dilation, + groups, bias, padding_mode, device, dtype) + self.spatial_k = spatial_k + self.channel_k = channel_k + # weight shape: (out_channels, in_channels // groups, *kernel_size) + new_weight = torch.empty_like(self.weight) + del self.weight # remove Parameter, create buffer + self.register_buffer("weight", new_weight) + nn.init.kaiming_normal_(self.weight) + + channel_triu_size = self.in_channels // self.groups + spatial_triu_size = self.kernel_size[0] * self.kernel_size[1] + + self.channel = LowRankPlusDiagCovariance(channel_triu_size, self.channel_k) + # self.spatial = LowRankCovariance(spatial_triu_size, self.spatial_k) + self.spatial = Covariance(spatial_triu_size) def forward(self, input: Tensor) -> Tensor: - U1 = self.channel._tri_vec_to_mat(self.channel.tri_vec, - self.channel.triu_size, self.channel.scat_idx) - U2 = self.spatial._tri_vec_to_mat(self.spatial.tri_vec, - self.spatial.triu_size, self.spatial.scat_idx) + U1 = self.channel.sqrt() + U2 = self.spatial.sqrt() # flatten over filter dims and contract composite_weight = _contract(self.weight, U1.T, 1) diff --git a/FactConv/cov.py b/FactConv/cov.py index bddc8f9..9fabd30 100644 --- a/FactConv/cov.py +++ b/FactConv/cov.py @@ -7,42 +7,52 @@ import math class Covariance(nn.Module): + """ Module for learnable weight covariance matrices + Input: the size of the covariance matrix + For channel covariances, triu_size is in_channels // groups + For spatial covariances, triu_size is kernel_size[0] * kernel_size[1] + """ def __init__(self, triu_size): super().__init__() self.triu_size = triu_size - # for channel cov, triu_size is in_channels // groups - # for spatial cov, triu_size is kernel_size[0] * kernel_size[1] - triu = torch.triu_indices(self.triu_size, self.triu_size, dtype=torch.long) scat_idx = triu[0] * self.triu_size + triu[1] self.register_buffer("scat_idx", scat_idx, persistent=False) triu_len = triu.shape[1] - 
#tri_vec = self.weight.new_zeros((triu_len,)) tri_vec = torch.zeros((triu_len,)) self.tri_vec = Parameter(tri_vec) - def cov(self, R): + def cov(self): + R = self._tri_vec_to_mat(self.tri_vec, self.triu_size, self.scat_idx) return R.T @ R - def sqrt(self, R): + def sqrt(self): + R = self._tri_vec_to_mat(self.tri_vec, self.triu_size, self.scat_idx) return R def _tri_vec_to_mat(self, vec, n, scat_idx): - #r = self.weight.new_zeros((n*n)).scatter_(0, scat_idx, vec).view(n, n) r = torch.zeros((n*n), device=self.tri_vec.device, dtype=self.tri_vec.dtype).scatter_(0, scat_idx, vec).view(n, n) r = torch.diagonal_scatter(r, r.diagonal().exp_()) return r class LowRankCovariance(Covariance): + """ Module for learnable low-rank covariance matrices + Input: the size of the covariance matrix and the rank as a percentage + """ def __init__(self, triu_size, rank): super().__init__(triu_size) - self.k = math.ceil(rank * triu_size) -# print("Rank: ", self.k) + if triu_size == 3: + self.k = 3 + print(self.k) + else: + self.k = math.ceil(rank * triu_size) + print(self.k) + triu = torch.triu_indices(self.triu_size, self.triu_size, dtype=torch.long) mask = triu[0] < self.k @@ -53,3 +63,24 @@ def __init__(self, triu_size, rank): tri_vec = torch.zeros((triu_len,)) self.tri_vec = Parameter(tri_vec) +class LowRankPlusDiagCovariance(Covariance): + """ Module for learning low-rank covariances plus the main diagonal + input: size of covariance matrix and the percentage rank (rows kept in upper triangular covaraince)""" + def __init__(self, triu_size, rank): + super().__init__(triu_size) + if triu_size == 3: + self.k = 3 + print(self.k) + else: + self.k = math.ceil(rank * triu_size) + print(self.k) + triu = torch.triu_indices(self.triu_size, self.triu_size, + dtype=torch.long) + mask = (triu[0] < self.k)|(triu[0]==triu[1]) + scat_idx = triu[0][mask] * self.triu_size + triu[1][mask] + self.register_buffer("scat_idx", scat_idx, persistent=False) + + triu_len = scat_idx.shape[0] + tri_vec = 
torch.zeros((triu_len,)) + self.tri_vec = Parameter(tri_vec) + From ec927a4b4a34893338031f9f1fcc3cebaf8a6c0c Mon Sep 17 00:00:00 2001 From: vivianwhite Date: Fri, 28 Jun 2024 15:49:39 -0400 Subject: [PATCH 70/77] adding low rank experiments --- FactConv/launched.sh | 21 +++++++++++++-------- FactConv/models/__init__.py | 5 ++++- FactConv/models/function_utils.py | 31 ++++++++++++++++++++++++++++--- FactConv/pytorch_cifar.py | 2 +- FactConv/setoff.sh | 8 ++++---- 5 files changed, 50 insertions(+), 17 deletions(-) mode change 100644 => 100755 FactConv/launched.sh diff --git a/FactConv/launched.sh b/FactConv/launched.sh old mode 100644 new mode 100755 index 3904685..06277c0 --- a/FactConv/launched.sh +++ b/FactConv/launched.sh @@ -1,9 +1,13 @@ #!/bin/bash -width=(0.125 0.25 0.5 1.0 2.0 4.0) +#width=(0.125 0.25 0.5 1.0 2.0 4.0) +#width=(1) +#rank=(1.0 0.9 0.65 0.75 0.5 0.35 0.25 0.15 0.1 0.05) +rank=(1.0 0.75 0.5 0.25 0.1) +#seed=(1 2) +#width=(0.125) seed=(0 1 2) - -for i in ${width[@]} +for i in ${rank[@]} do for j in ${seed[@]} do @@ -14,10 +18,11 @@ do #sbatch setoff.sh --width $i --seed $j --net fact_us_uc_resnet18 # WHERE IS THIS VIVIAN 🧐🧐🤨🤨 # NVM good job Vivian - sbatch setoff.sh --width $i --seed $j --net fact_diag_resnet18 - sbatch setoff.sh --width $i --seed $j --net fact_diag_us_resnet18 - sbatch setoff.sh --width $i --seed $j --net fact_diag_uc_resnet18 - sbatch setoff.sh --width $i --seed $j --net fact_diag_us_uc_resnet18 - done + sbatch setoff.sh $i $j resnet18_fact_lr-diag + sbatch setoff.sh $i $j resnet18_fact_lowrank + #sbatch setoff.sh $i $j fact_diagchan_us_resnet18 + #sbatch setoff.sh $i $j fact_diagchan_uc_resnet18 + #sbatch setoff.sh $i $j fact_diagchan_us_uc_resnet18 + done done diff --git a/FactConv/models/__init__.py b/FactConv/models/__init__.py index c245851..55a178f 100644 --- a/FactConv/models/__init__.py +++ b/FactConv/models/__init__.py @@ -3,7 +3,7 @@ from .function_utils import replace_layers_factconv2d,\ replace_layers_diagfactconv2d, 
replace_layers_diagchanfactconv2d,\ turn_off_covar_grad, replace_layers_scale, init_V1_layers,\ -replace_layers_lowrank +replace_layers_lowrank, replace_layers_lowrankplusdiag def define_models(args): @@ -26,6 +26,9 @@ def define_models(args): if 'lowrank' in args.net: replace_layers_lowrank(model, args.spatial_k, args.channel_k) print("Low rank") + if 'lr-diag' in args.net: + replace_layers_lowrankplusdiag(model, args.spatial_k, args.channel_k) + print("Low rank plus diag") if "v1" in args.net: init_V1_layers(model, bias=False) if "us" in args.net: diff --git a/FactConv/models/function_utils.py b/FactConv/models/function_utils.py index 2dfc62d..695526e 100644 --- a/FactConv/models/function_utils.py +++ b/FactConv/models/function_utils.py @@ -1,8 +1,8 @@ import torch import torch.nn as nn from conv_modules import FactConv2d, DiagFactConv2d, DiagChanFactConv2d,\ -LowRankFactConv2d -from cov import Covariance, LowRankCovariance +LowRankFactConv2d, LowRankPlusDiagFactConv2d +from cov import Covariance, LowRankCovariance, LowRankPlusDiagCovariance from V1_covariance import V1_init @@ -198,7 +198,8 @@ def turn_off_covar_grad(model, covariance): ''' def _turn_off_covar_grad(module): if isinstance(module, FactConv2d) or isinstance(module, DiagFactConv2d)\ - or isinstance(module, DiagChanFactConv2d) or isinstance(module, LowRankFactConv2d): + or isinstance(module, DiagChanFactConv2d) or isinstance(module,LowRankFactConv2d)\ + or isinstance(module, LowRankPlusDiagFactConv2d): for name, mod in module.named_modules(): if isinstance(mod, Covariance): if covariance == name: @@ -271,3 +272,27 @@ def _replace_layers_lowrank(module): new_module.load_state_dict(new_sd) return new_module return recurse_preorder(model, _replace_layers_lowrank) + + +def replace_layers_lowrankplusdiag(model, spatial_k, channel_k): + ''' + Replace nn.Conv2d layers with LowRankFactConv2d + ''' + def _replace_layers_lowrankplusdiag(module): + if isinstance(module, nn.Conv2d): + ## simple module + 
new_module = LowRankPlusDiagFactConv2d( + in_channels=module.in_channels, + out_channels=module.out_channels, + kernel_size=module.kernel_size, + stride=module.stride, padding=module.padding, + spatial_k=spatial_k, channel_k=channel_k, + bias=True if module.bias is not None else False) + old_sd = module.state_dict() + new_sd = new_module.state_dict() + new_sd['weight'] = old_sd['weight'] + if module.bias is not None: + new_sd['bias'] = old_sd['bias'] + new_module.load_state_dict(new_sd) + return new_module + return recurse_preorder(model, _replace_layers_lowrankplusdiag) diff --git a/FactConv/pytorch_cifar.py b/FactConv/pytorch_cifar.py index 5dbad88..08a864d 100644 --- a/FactConv/pytorch_cifar.py +++ b/FactConv/pytorch_cifar.py @@ -88,7 +88,7 @@ def save_model(args, model): print("Num Learnable Params: ", sum(p.numel() for p in net.parameters() if p.requires_grad)) -run = wandb.init(project="factconv", config=args, group="lowrank", name=run_name, dir=wandb_dir) +run = wandb.init(project="factconv", config=args, group="lowrankplusdiag", name=run_name, dir=wandb_dir) #wandb.watch(net, log='all', log_freq=1) diff --git a/FactConv/setoff.sh b/FactConv/setoff.sh index 0a3e069..10972e4 100644 --- a/FactConv/setoff.sh +++ b/FactConv/setoff.sh @@ -7,7 +7,7 @@ #SBATCH --output slurm/%j.out #SBATCH --partition long -module load python/3.8 -source ../refactor/env/bin/activate -echo "$@" -CUDA_VISIBLE_DEVICES=0 python pytorch_cifar.py "$@" +module load anaconda/3 +conda activate random_features + +python pytorch_cifar.py --channel_k $1 --seed $2 --net $3 From 5cceccb81ce9a8ac1f736beed002aa753f8b5b4c Mon Sep 17 00:00:00 2001 From: vivianwhite Date: Sat, 29 Jun 2024 16:53:00 -0400 Subject: [PATCH 71/77] updating covariance modules --- FactConv/conv_modules.py | 185 +++++++++++++++++------------- FactConv/cov.py | 73 ++++++++++++ FactConv/models/__init__.py | 27 ++--- FactConv/models/function_utils.py | 65 ++++++++--- 4 files changed, 243 insertions(+), 107 deletions(-) diff 
--git a/FactConv/conv_modules.py b/FactConv/conv_modules.py index fb4a8b0..2fe6bac 100644 --- a/FactConv/conv_modules.py +++ b/FactConv/conv_modules.py @@ -4,7 +4,8 @@ from torch.nn.parameter import Parameter from torch.nn.common_types import _size_2_t from typing import Optional, List, Tuple, Union -from cov import Covariance, LowRankCovariance, LowRankPlusDiagCovariance +from cov import Covariance, LowRankCovariance, LowRankPlusDiagCovariance,\ +LowRankK1Covariance, OffDiagCovariance, DiagCovariance """ The function below is copied directly from @@ -96,8 +97,8 @@ def __init__( spatial_triu_size = self.kernel_size[0] * self.kernel_size[1] self.channel = LowRankCovariance(channel_triu_size, self.channel_k) - #self.spatial = LowRankCovariance(spatial_triu_size, self.spatial_k) - self.spatial = Covariance(spatial_triu_size) + self.spatial = LowRankCovariance(spatial_triu_size, self.spatial_k) + #self.spatial = Covariance(spatial_triu_size) def forward(self, input: Tensor) -> Tensor: @@ -114,7 +115,6 @@ def forward(self, input: Tensor) -> Tensor: class LowRankPlusDiagFactConv2d(nn.Conv2d): def __init__( self, - spatial_k, channel_k, in_channels: int, out_channels: int, @@ -132,7 +132,6 @@ def __init__( super().__init__( in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias, padding_mode, device, dtype) - self.spatial_k = spatial_k self.channel_k = channel_k # weight shape: (out_channels, in_channels // groups, *kernel_size) new_weight = torch.empty_like(self.weight) @@ -144,7 +143,6 @@ def __init__( spatial_triu_size = self.kernel_size[0] * self.kernel_size[1] self.channel = LowRankPlusDiagCovariance(channel_triu_size, self.channel_k) - # self.spatial = LowRankCovariance(spatial_triu_size, self.spatial_k) self.spatial = Covariance(spatial_triu_size) @@ -159,9 +157,10 @@ def forward(self, input: Tensor) -> Tensor: ).reshape(self.weight.shape) return self._conv_forward(input, composite_weight, self.bias) -class DiagFactConv2d(nn.Conv2d): +class 
LowRankK1FactConv2d(nn.Conv2d): def __init__( self, + channel_k, in_channels: int, out_channels: int, kernel_size: _size_2_t, @@ -172,7 +171,7 @@ def __init__( bias: bool = True, padding_mode: str = 'zeros', # TODO: refine this type device=None, - dtype=None + dtype=None, ) -> None: # init as Conv2d super().__init__( @@ -184,41 +183,61 @@ def __init__( self.register_buffer("weight", new_weight) nn.init.kaiming_normal_(self.weight) - self.in_features = self.in_channels // self.groups * \ - self.kernel_size[0] * self.kernel_size[1] - triu1 = torch.triu_indices(self.in_channels // self.groups, - self.in_channels // self.groups, - device=self.weight.device, - dtype=torch.long) - mask = triu1[0] == triu1[1] - scat_idx1 = triu1[0][mask]*self.in_channels//self.groups + triu1[1][mask] - self.register_buffer("scat_idx1", scat_idx1, persistent=False) - - triu2 = torch.triu_indices(self.kernel_size[0] * self.kernel_size[1], - self.kernel_size[0] - * self.kernel_size[1], - device=self.weight.device, - dtype=torch.long) - mask = triu2[0] == triu2[1] - scat_idx2 = triu2[0][mask]*self.kernel_size[0]*self.kernel_size[1] + triu2[1][mask] - - self.register_buffer("scat_idx2", scat_idx2, persistent=False) + channel_triu_size = self.in_channels // self.groups + spatial_triu_size = self.kernel_size[0] * self.kernel_size[1] + + self.channel = LowRankK1Covariance(channel_triu_size, self.channel_k) + # self.spatial = LowRankCovariance(spatial_triu_size, self.spatial_k) + self.spatial = Covariance(spatial_triu_size) - triu1_len = scat_idx1.shape[0] - triu2_len = scat_idx2.shape[0] - tri1_vec = self.weight.new_zeros((triu1_len,)) - self.tri1_vec = Parameter(tri1_vec) + def forward(self, input: Tensor) -> Tensor: + U1 = self.channel.sqrt() + U2 = self.spatial.sqrt() + + # flatten over filter dims and contract + composite_weight = _contract(self.weight, U1.T, 1) + composite_weight = _contract( + torch.flatten(composite_weight, -2, -1), U2.T, -1 + ).reshape(self.weight.shape) + return 
self._conv_forward(input, composite_weight, self.bias) + +class OffDiagFactConv2d(nn.Conv2d): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: _size_2_t, + stride: _size_2_t = 1, + padding: Union[str, _size_2_t] = 0, + dilation: _size_2_t = 1, + groups: int = 1, + bias: bool = True, + padding_mode: str = 'zeros', # TODO: refine this type + device=None, + dtype=None, + ) -> None: + # init as Conv2d + super().__init__( + in_channels, out_channels, kernel_size, stride, padding, dilation, + groups, bias, padding_mode, device, dtype) + # weight shape: (out_channels, in_channels // groups, *kernel_size) + new_weight = torch.empty_like(self.weight) + del self.weight # remove Parameter, create buffer + self.register_buffer("weight", new_weight) + nn.init.kaiming_normal_(self.weight) + + channel_triu_size = self.in_channels // self.groups + spatial_triu_size = self.kernel_size[0] * self.kernel_size[1] - tri2_vec = self.weight.new_zeros((triu2_len,)) - self.tri2_vec = Parameter(tri2_vec) + self.channel = OffDiagCovariance(channel_triu_size) + self.spatial = Covariance(spatial_triu_size) def forward(self, input: Tensor) -> Tensor: - U1 = self._tri_vec_to_mat(self.tri1_vec, self.in_channels // - self.groups, self.scat_idx1) - U2 = self._tri_vec_to_mat(self.tri2_vec, self.kernel_size[0] * self.kernel_size[1], - self.scat_idx2) + U1 = self.channel.sqrt() + U2 = self.spatial.sqrt() + # flatten over filter dims and contract composite_weight = _contract(self.weight, U1.T, 1) composite_weight = _contract( @@ -226,12 +245,7 @@ def forward(self, input: Tensor) -> Tensor: ).reshape(self.weight.shape) return self._conv_forward(input, composite_weight, self.bias) - def _tri_vec_to_mat(self, vec, n, scat_idx): - U = self.weight.new_zeros((n*n)).scatter_(0, scat_idx, vec).view(n, n) - U = torch.diagonal_scatter(U, U.diagonal().exp_()) - return U - -class DiagChanFactConv2d(nn.Conv2d): +class DiagFactConv2d(nn.Conv2d): def __init__( self, in_channels: int, @@ 
-256,43 +270,17 @@ def __init__( self.register_buffer("weight", new_weight) nn.init.kaiming_normal_(self.weight) - self.in_features = self.in_channels // self.groups * \ - self.kernel_size[0] * self.kernel_size[1] - #triu1 = torch.triu_indices(self.in_channels // self.groups, - # self.in_channels // self.groups, - # device=self.weight.device, - # dtype=torch.long) - #mask = triu1[0] == triu1[1] - #scat_idx1 = triu1[0][mask]*self.in_channels//self.groups + triu1[1][mask] - matrix = torch.arange(self.in_channels // self.groups \ - * self.in_channels // self.groups).view(self.in_channels // \ - self.groups, self.in_channels // self.groups) - scat_idx1 = torch.diagonal(matrix) - self.register_buffer("scat_idx1", scat_idx1, persistent=False) - - triu2 = torch.triu_indices(self.kernel_size[0] * self.kernel_size[1], - self.kernel_size[0] - * self.kernel_size[1], - device=self.weight.device, - dtype=torch.long) - scat_idx2 = triu2[0]*self.kernel_size[0]*self.kernel_size[1] + triu2[1] - self.register_buffer("scat_idx2", scat_idx2, persistent=False) - - triu1_len = scat_idx1.shape[0] - triu2_len = triu2.shape[1] - - tri1_vec = self.weight.new_zeros((triu1_len,)) - self.tri1_vec = Parameter(tri1_vec) - - tri2_vec = self.weight.new_zeros((triu2_len,)) - self.tri2_vec = Parameter(tri2_vec) + channel_triu_size = self.in_channels // self.groups + spatial_triu_size = self.kernel_size[0] * self.kernel_size[1] + + self.channel = DiagCovariance(channel_triu_size) + self.spatial = DiagCovariance(spatial_triu_size) def forward(self, input: Tensor) -> Tensor: - U1 = self._tri_vec_to_mat(self.tri1_vec, self.in_channels // - self.groups, self.scat_idx1) - U2 = self._tri_vec_to_mat(self.tri2_vec, self.kernel_size[0] * self.kernel_size[1], - self.scat_idx2) + U1 = self.channel.sqrt() + U2 = self.spatial.sqrt() + # flatten over filter dims and contract composite_weight = _contract(self.weight, U1.T, 1) composite_weight = _contract( @@ -300,8 +288,43 @@ def forward(self, input: Tensor) -> 
Tensor: ).reshape(self.weight.shape) return self._conv_forward(input, composite_weight, self.bias) - def _tri_vec_to_mat(self, vec, n, scat_idx): - U = self.weight.new_zeros((n*n)).scatter_(0, scat_idx, vec).view(n, n) - U = torch.diagonal_scatter(U, U.diagonal().exp_()) - return U +class DiagChanFactConv2d(nn.Conv2d): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: _size_2_t, + stride: _size_2_t = 1, + padding: Union[str, _size_2_t] = 0, + dilation: _size_2_t = 1, + groups: int = 1, + bias: bool = True, + padding_mode: str = 'zeros', # TODO: refine this type + device=None, + dtype=None + ) -> None: + # init as Conv2d + super().__init__( + in_channels, out_channels, kernel_size, stride, padding, dilation, + groups, bias, padding_mode, device, dtype) + # weight shape: (out_channels, in_channels // groups, *kernel_size) + new_weight = torch.empty_like(self.weight) + del self.weight # remove Parameter, create buffer + self.register_buffer("weight", new_weight) + nn.init.kaiming_normal_(self.weight) + + channel_triu_size = self.in_channels // self.groups + spatial_triu_size = self.kernel_size[0] * self.kernel_size[1] + + self.channel = DiagCovariance(channel_triu_size) + self.spatial = Covariance(spatial_triu_size) + def forward(self, input: Tensor) -> Tensor: + U1 = self.channel.sqrt() + U2 = self.spatial.sqrt() + # flatten over filter dims and contract + composite_weight = _contract(self.weight, U1.T, 1) + composite_weight = _contract( + torch.flatten(composite_weight, -2, -1), U2.T, -1 + ).reshape(self.weight.shape) + return self._conv_forward(input, composite_weight, self.bias) diff --git a/FactConv/cov.py b/FactConv/cov.py index 9fabd30..b0da90a 100644 --- a/FactConv/cov.py +++ b/FactConv/cov.py @@ -41,6 +41,25 @@ def _tri_vec_to_mat(self, vec, n, scat_idx): return r class LowRankCovariance(Covariance): + """ Module for learnable low-rank covariance matrices + Input: the size of the covariance matrix and the rank as a percentage + 
""" + def __init__(self, triu_size, rank): + super().__init__(triu_size) + self.k = math.ceil(rank * triu_size) + print(self.k) + + triu = torch.triu_indices(self.triu_size, self.triu_size, + dtype=torch.long) + mask = triu[0] < self.k + scat_idx = triu[0][mask] * self.triu_size + triu[1][mask] + self.register_buffer("scat_idx", scat_idx, persistent=False) + + triu_len = scat_idx.shape[0] + tri_vec = torch.zeros((triu_len,)) + self.tri_vec = Parameter(tri_vec) + +class LowRank1stFullCovariance(Covariance): """ Module for learnable low-rank covariance matrices Input: the size of the covariance matrix and the rank as a percentage """ @@ -63,6 +82,7 @@ def __init__(self, triu_size, rank): tri_vec = torch.zeros((triu_len,)) self.tri_vec = Parameter(tri_vec) + class LowRankPlusDiagCovariance(Covariance): """ Module for learning low-rank covariances plus the main diagonal input: size of covariance matrix and the percentage rank (rows kept in upper triangular covaraince)""" @@ -84,3 +104,56 @@ def __init__(self, triu_size, rank): tri_vec = torch.zeros((triu_len,)) self.tri_vec = Parameter(tri_vec) +class LowRankK1Covariance(Covariance): + """ Module for learnable low-rank covariance matrices where + each layer has a rank of 1 + Input: the size of the covariance matrix + """ + def __init__(self, triu_size, rank): + super().__init__(triu_size) + + self.k = 1 + print(self.k) + + triu = torch.triu_indices(self.triu_size, self.triu_size, + dtype=torch.long) + mask = triu[0] < self.k + scat_idx = triu[0][mask] * self.triu_size + triu[1][mask] + self.register_buffer("scat_idx", scat_idx, persistent=False) + + triu_len = scat_idx.shape[0] + tri_vec = torch.zeros((triu_len,)) + self.tri_vec = Parameter(tri_vec) + + +class OffDiagCovariance(Covariance): + """ Module for learning off-diagonal of weight covariance matrix + input: size of covariance matrix""" + def __init__(self, triu_size): + super().__init__(triu_size) + + triu = torch.triu_indices(self.triu_size, self.triu_size, + 
dtype=torch.long) + mask = (triu[0]+1) == triu[1] + scat_idx = triu[0][mask] * self.triu_size + triu[1][mask] + self.register_buffer("scat_idx", scat_idx, persistent=False) + + triu_len = scat_idx.shape[0] + tri_vec = torch.zeros((triu_len,)) + self.tri_vec = Parameter(tri_vec) + +class DiagCovariance(Covariance): + """ Module for learning diagonal of weight covariance matrix + input: size of covariance matrix""" + def __init__(self, triu_size): + super().__init__(triu_size) + + triu = torch.triu_indices(self.triu_size, self.triu_size, + dtype=torch.long) + mask = triu[0] == triu[1] + scat_idx = triu[0][mask] * self.triu_size + triu[1][mask] + self.register_buffer("scat_idx", scat_idx, persistent=False) + + triu_len = scat_idx.shape[0] + tri_vec = torch.zeros((triu_len,)) + self.tri_vec = Parameter(tri_vec) diff --git a/FactConv/models/__init__.py b/FactConv/models/__init__.py index 55a178f..2553089 100644 --- a/FactConv/models/__init__.py +++ b/FactConv/models/__init__.py @@ -3,7 +3,8 @@ from .function_utils import replace_layers_factconv2d,\ replace_layers_diagfactconv2d, replace_layers_diagchanfactconv2d,\ turn_off_covar_grad, replace_layers_scale, init_V1_layers,\ -replace_layers_lowrank, replace_layers_lowrankplusdiag +replace_layers_lowrank, replace_layers_lowrankplusdiag,\ +replace_layers_lowrankK1, replace_layers_offdiag def define_models(args): @@ -13,7 +14,6 @@ def define_models(args): model = LC_CIFAR10(hidden_dim=100, size=2, spatial_freq=0.1, scale=1, bias=True, freeze_spatial=False, freeze_channel=False, spatial_init='V1') - print("RSN model") if args.width != 1: replace_layers_scale(model, args.width) if 'fact' in args.net: @@ -22,20 +22,21 @@ def define_models(args): replace_layers_diagfactconv2d(model) if 'diagchan' in args.net: replace_layers_diagchanfactconv2d(model) - print("Diag Chan") if 'lowrank' in args.net: replace_layers_lowrank(model, args.spatial_k, args.channel_k) - print("Low rank") if 'lr-diag' in args.net: - 
replace_layers_lowrankplusdiag(model, args.spatial_k, args.channel_k) - print("Low rank plus diag") - if "v1" in args.net: + replace_layers_lowrankplusdiag(model) + if 'lr-K1' in args.net: + replace_layers_lowrankK1(model, args.channel_k) + if 'offdiag' in args.net: + replace_layers_offdiag(model) + if 'v1' in args.net: init_V1_layers(model, bias=False) - if "us" in args.net: - turn_off_covar_grad(model, "spatial") - print("US") - if "uc" in args.net: - turn_off_covar_grad(model, "channel") - print("UC") + if 'us' in args.net: + turn_off_covar_grad(model, 'spatial') + print('US') + if 'uc' in args.net: + turn_off_covar_grad(model, 'channel') + print('UC') return model diff --git a/FactConv/models/function_utils.py b/FactConv/models/function_utils.py index 695526e..ad3f860 100644 --- a/FactConv/models/function_utils.py +++ b/FactConv/models/function_utils.py @@ -1,8 +1,10 @@ import torch import torch.nn as nn from conv_modules import FactConv2d, DiagFactConv2d, DiagChanFactConv2d,\ -LowRankFactConv2d, LowRankPlusDiagFactConv2d -from cov import Covariance, LowRankCovariance, LowRankPlusDiagCovariance +LowRankFactConv2d, LowRankPlusDiagFactConv2d, LowRankK1FactConv2d, \ +OffDiagFactConv2d +from cov import Covariance, LowRankCovariance, LowRankPlusDiagCovariance,\ +LowRankK1Covariance, OffDiagCovariance, DiagCovariance from V1_covariance import V1_init @@ -196,6 +198,7 @@ def turn_off_covar_grad(model, covariance): Turn off gradients in tri1_vec or tri2_vec to turn off channel or spatial covariance learning ''' + # TODO: update to check if Covariance module, not Fact def _turn_off_covar_grad(module): if isinstance(module, FactConv2d) or isinstance(module, DiagFactConv2d)\ or isinstance(module, DiagChanFactConv2d) or isinstance(module,LowRankFactConv2d)\ @@ -207,15 +210,6 @@ def _turn_off_covar_grad(module): if "tri_vec" in name: param.requires_grad = False - # for name, param in module.named_parameters(): - # if covariance == "channel": - # if "tri1_vec" in name: - # 
param.requires_grad = False - # print("Unlearnable Channel") - # if covariance == "spatial": - # if "tri2_vec" in name: - # param.requires_grad = False - # print("Unlearnable Spatial") return recurse_preorder(model, _turn_off_covar_grad) @@ -274,7 +268,7 @@ def _replace_layers_lowrank(module): return recurse_preorder(model, _replace_layers_lowrank) -def replace_layers_lowrankplusdiag(model, spatial_k, channel_k): +def replace_layers_lowrankplusdiag(model, channel_k): ''' Replace nn.Conv2d layers with LowRankFactConv2d ''' @@ -286,7 +280,7 @@ def _replace_layers_lowrankplusdiag(module): out_channels=module.out_channels, kernel_size=module.kernel_size, stride=module.stride, padding=module.padding, - spatial_k=spatial_k, channel_k=channel_k, + channel_k=channel_k, bias=True if module.bias is not None else False) old_sd = module.state_dict() new_sd = new_module.state_dict() @@ -296,3 +290,48 @@ def _replace_layers_lowrankplusdiag(module): new_module.load_state_dict(new_sd) return new_module return recurse_preorder(model, _replace_layers_lowrankplusdiag) + +def replace_layers_lowrankK1(model, channel_k): + ''' + Replace nn.Conv2d layers with LowRankK1FactConv2d + ''' + def _replace_layers_lowrankK1(module): + if isinstance(module, nn.Conv2d): + ## simple module + new_module = LowRankK1FactConv2d( + in_channels=module.in_channels, + out_channels=module.out_channels, + kernel_size=module.kernel_size, + stride=module.stride, padding=module.padding, + channel_k=channel_k, + bias=True if module.bias is not None else False) + old_sd = module.state_dict() + new_sd = new_module.state_dict() + new_sd['weight'] = old_sd['weight'] + if module.bias is not None: + new_sd['bias'] = old_sd['bias'] + new_module.load_state_dict(new_sd) + return new_module + return recurse_preorder(model, _replace_layers_lowrankK1) + +def replace_layers_offdiag(model): + ''' + Replace nn.Conv2d layers with OffDiagFactConv2d + ''' + def _replace_layers_offdiag(module): + if isinstance(module, nn.Conv2d): 
+ ## simple module + new_module = OffDiagFactConv2d( + in_channels=module.in_channels, + out_channels=module.out_channels, + kernel_size=module.kernel_size, + stride=module.stride, padding=module.padding, + bias=True if module.bias is not None else False) + old_sd = module.state_dict() + new_sd = new_module.state_dict() + new_sd['weight'] = old_sd['weight'] + if module.bias is not None: + new_sd['bias'] = old_sd['bias'] + new_module.load_state_dict(new_sd) + return new_module + return recurse_preorder(model, _replace_layers_offdiag) From 55135097a6819fe0f875141b58a587d078c3a83e Mon Sep 17 00:00:00 2001 From: vivianwhite Date: Mon, 1 Jul 2024 13:37:45 -0400 Subject: [PATCH 72/77] logging param count --- FactConv/pytorch_cifar.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/FactConv/pytorch_cifar.py b/FactConv/pytorch_cifar.py index 08a864d..9fd1eed 100644 --- a/FactConv/pytorch_cifar.py +++ b/FactConv/pytorch_cifar.py @@ -14,7 +14,6 @@ from models import define_models def save_model(args, model): - #src= "../saved-models/ResNets/" src="/home/mila/v/vivian.white/scratch/v1-models/saved-models/low-rank/" model_dir = src + args.net + "-seed" + str(args.seed) print("Model dir: ", model_dir) @@ -85,12 +84,15 @@ def save_model(args, model): os.makedirs(wandb_dir, exist_ok=True) os.chdir(wandb_dir) -print("Num Learnable Params: ", sum(p.numel() for p in net.parameters() if - p.requires_grad)) +param_count = sum(p.numel() for p in net.parameters() if p.requires_grad) + +print("Num Learnable Params: ", param_count) + run = wandb.init(project="factconv", config=args, group="lowrankplusdiag", name=run_name, dir=wandb_dir) #wandb.watch(net, log='all', log_freq=1) +run.log({'param_count':param_count}) criterion = nn.CrossEntropyLoss() optimizer = optim.SGD(net.parameters(), lr=args.lr, @@ -119,6 +121,9 @@ def train(epoch): progress_bar(batch_idx, len(trainloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)' % (train_loss/(batch_idx+1), 
100.*correct/total, correct, total)) + + acc = 100.*correct/total + run.log({"train_accuracy":acc}) def test(epoch): global best_acc @@ -142,7 +147,7 @@ def test(epoch): # Save checkpoint. acc = 100.*correct/total - run.log({"accuracy":acc}) + run.log({"test_accuracy":acc}) if acc > best_acc: print('Saving..') state = { From f5f70f928e08d82fbc51e2ffb31f721216e1364f Mon Sep 17 00:00:00 2001 From: vivianwhite Date: Mon, 1 Jul 2024 13:56:38 -0400 Subject: [PATCH 73/77] update turn_off_covar_grad function --- FactConv/models/function_utils.py | 21 +++++++++------------ 1 file changed, 9 insertions(+), 12 deletions(-) diff --git a/FactConv/models/function_utils.py b/FactConv/models/function_utils.py index ad3f860..a3575eb 100644 --- a/FactConv/models/function_utils.py +++ b/FactConv/models/function_utils.py @@ -195,20 +195,17 @@ def _replace_layers_fact_with_conv(module): def turn_off_covar_grad(model, covariance): print("Unlearnable ", covariance) ''' - Turn off gradients in tri1_vec or tri2_vec to turn off - channel or spatial covariance learning + Turn off gradients in the tri_vec param in + the Covariance module to disable channel + or spatial covariance learning ''' - # TODO: update to check if Covariance module, not Fact def _turn_off_covar_grad(module): - if isinstance(module, FactConv2d) or isinstance(module, DiagFactConv2d)\ - or isinstance(module, DiagChanFactConv2d) or isinstance(module,LowRankFactConv2d)\ - or isinstance(module, LowRankPlusDiagFactConv2d): - for name, mod in module.named_modules(): - if isinstance(mod, Covariance): - if covariance == name: - for name, param in mod.named_parameters(): - if "tri_vec" in name: - param.requires_grad = False + for name, mod in module.named_modules(): + if isinstance(mod, Covariance): + if covariance == name: + for name, param in mod.named_parameters(): + if "tri_vec" in name: + param.requires_grad = False return recurse_preorder(model, _turn_off_covar_grad) From 0a2e945a6843f360bea9577d13cbe8caa60174f6 Mon Sep 17 
00:00:00 2001 From: vivianwhite Date: Mon, 1 Jul 2024 15:38:56 -0400 Subject: [PATCH 74/77] adding LearnableCov.py back --- RSN_experiments/LearnableCov.py | 201 ++++++++++++++++++++++++++++++++ 1 file changed, 201 insertions(+) create mode 100644 RSN_experiments/LearnableCov.py diff --git a/RSN_experiments/LearnableCov.py b/RSN_experiments/LearnableCov.py new file mode 100644 index 0000000..dc1fd07 --- /dev/null +++ b/RSN_experiments/LearnableCov.py @@ -0,0 +1,201 @@ +import torch +from torch import Tensor +import torch.nn as nn +from torch.nn.parameter import Parameter, UninitializedParameter +from torch.nn.common_types import _size_2_t +from typing import Optional, List, Tuple, Union + +class Linear(nn.Module): + in_features: int + out_features: int + weight: torch.Tensor + + def __init__(self, in_features: int, out_features: int, bias: bool = True, + device=None, dtype=None) -> None: + factory_kwargs = {'device': device, 'dtype': dtype} + super().__init__() + + self.in_features = in_features + self.out_features = out_features + self.factory_kwargs = factory_kwargs + + self.register_buffer("weight", + torch.empty((out_features, in_features), + **factory_kwargs)) + + triu_len = torch.triu_indices(in_features, in_features).shape[1] + self.tri_vec = Parameter(torch.empty((triu_len,), **factory_kwargs)) + if bias: + self.bias = Parameter(torch.empty(out_features, **factory_kwargs)) + else: + self.register_parameter('bias', None) + self.reset_parameters() + + def reset_parameters(self) -> None: + nn.init.constant_(self.tri_vec, 0.) + # Setting a=sqrt(5) in kaiming_uniform is the same as initializing with + # uniform(-1/sqrt(in_features), 1/sqrt(in_features)). 
For details, see + # https://github.com/pytorch/pytorch/issues/57109 + nn.init.kaiming_normal_(self.weight, a=math.sqrt(5)) + if self.bias is not None: + fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight) + bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 + nn.init.uniform_(self.bias, -bound, bound) + + def forward(self, input: Tensor) -> Tensor: + U = torch.zeros((self.in_features, self.in_features), + **self.factory_kwargs) + U[torch.triu_indices(self.in_features, self.in_features).tolist()] \ + = self.tri_vec + exp_diag = torch.exp(torch.diagonal(U)) + U[range(self.in_features), range(self.in_features)] = exp_diag + composite_weight = self.weight @ U + + return F.linear(input, composite_weight, self.bias) + +class Conv2d(nn.Conv2d): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: _size_2_t, + stride: _size_2_t = 1, + padding: Union[str, _size_2_t] = 0, + dilation: _size_2_t = 1, + groups: int = 1, + bias: bool = True, + padding_mode: str = 'zeros', # TODO: refine this type + device=torch.device('cuda' if torch.cuda.is_available() else 'cpu'), + dtype=None + ) -> None: + # init as Conv2d + super().__init__( + in_channels, out_channels, kernel_size, stride, padding, dilation, + groups, bias, padding_mode, device, dtype) + + factory_kwargs = {'device': device, 'dtype': dtype} + print("Device: ", device) + self.factory_kwargs = factory_kwargs + + # weight shape: (out_channels, in_channels // groups, *kernel_size) + weight_shape = self.weight.shape + del self.weight # remove Parameter, create buffer + self.register_buffer("weight", torch.empty(weight_shape, **factory_kwargs)) + nn.init.kaiming_normal_(self.weight) + + self.in_features = self.in_channels // self.groups * \ + self.kernel_size[0] * self.kernel_size[1] + triu_len = torch.triu_indices(self.in_features, + self.in_features).shape[1] + self.tri_vec = Parameter(torch.zeros((triu_len,), **factory_kwargs)) + + def forward(self, input: Tensor) -> Tensor: + U = 
torch.zeros((self.in_features, self.in_features), + **self.factory_kwargs) + U[torch.triu_indices(self.in_features, self.in_features).tolist()] \ + = self.tri_vec + exp_diag = torch.exp(torch.diagonal(U)) + U[range(self.in_features), range(self.in_features)] = exp_diag + + matrix_shape = (self.out_channels, self.in_features) + composite_weight = torch.reshape( + torch.reshape(self.weight, matrix_shape) @ U, + self.weight.shape + ) + + return self._conv_forward(input, composite_weight, self.bias) + +class FactConv2d(nn.Conv2d): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: _size_2_t, + stride: _size_2_t = 1, + padding: Union[str, _size_2_t] = 0, + dilation: _size_2_t = 1, + groups: int = 1, + bias: bool = True, + padding_mode: str = 'zeros', # TODO: refine this type + device=torch.device('cuda' if torch.cuda.is_available() else 'cpu'), + dtype=None + ) -> None: + # init as Conv2d + super().__init__( + in_channels, out_channels, kernel_size, stride, padding, dilation, + groups, bias, padding_mode, device, dtype) + + factory_kwargs = {'device': device, 'dtype': dtype} + self.factory_kwargs = factory_kwargs + + # weight shape: (out_channels, in_channels // groups, *kernel_size) + weight_shape = self.weight.shape + del self.weight # remove Parameter, create buffer + self.register_buffer("weight", torch.empty(weight_shape, **factory_kwargs)) + nn.init.kaiming_normal_(self.weight) + + self.in_features = self.in_channels // self.groups * \ + self.kernel_size[0] * self.kernel_size[1] + triu1_len = torch.triu_indices(self.in_channels // self.groups, + self.in_channels // self.groups).shape[1] + triu2_len = torch.triu_indices(self.kernel_size[0] * self.kernel_size[1], + self.kernel_size[0] * self.kernel_size[1]).shape[1] + self.tri1_vec = Parameter(torch.zeros((triu1_len,), **factory_kwargs)) + self.tri2_vec = Parameter(torch.zeros((triu2_len,), **factory_kwargs)) + + def forward(self, input: Tensor) -> Tensor: + U1 = 
self._tri_vec_to_mat(self.tri1_vec, self.in_channels // self.groups) + U2 = self._tri_vec_to_mat(self.tri2_vec, self.kernel_size[0] * self.kernel_size[1]) + U = torch.kron(U1, U2) + U = self._exp_diag(U) + + matrix_shape = (self.out_channels, self.in_features) + composite_weight = torch.reshape( + torch.reshape(self.weight, matrix_shape) @ U, + self.weight.shape + ) + + return self._conv_forward(input, composite_weight, self.bias) + + def _tri_vec_to_mat(self, vec, n): + U = torch.zeros((n, n), **self.factory_kwargs) + U[torch.triu_indices(n, n, **self.factory_kwargs).tolist()] = vec + # TODO(kamdh): experiment with this placement versus after kron + # U = self._exp_diag(U) + return U + + def _exp_diag(self, mat): + exp_diag = torch.exp(torch.diagonal(mat)) + n = mat.shape[0] + mat[range(n), range(n)] = exp_diag + return mat + +def V1_init(layer, size, spatial_freq, center, scale=1., bias=False, seed=None, + device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')): + ''' + Initialization for FactConv2d + ''' + import sys + sys.path.insert(0, '/research/harris/vivian/structured_random_features/') + from src.models.weights import V1_covariance_matrix + + classname = layer.__class__.__name__ + assert classname.find('FactConv2d') != -1, 'This init only works for FactConv2d layers' + assert center is not None, "center needed" + + out_channels, in_channels, xdim, ydim = layer.weight.shape + dim = (xdim, ydim) + + C_patch = Tensor(V1_covariance_matrix(dim, size, spatial_freq, center, scale)).to(device) + U_patch = torch.linalg.cholesky(C_patch, upper=True) + n = U_patch.shape[0] + # replace diagonal with logarithm for parameterization + log_diag = torch.log(torch.diagonal(U_patch)) + U_patch[range(n), range(n)] = log_diag + # form vector of upper triangular entries + tri_vec = U_patch[torch.triu_indices(n, n, device=device).tolist()].ravel() + with torch.no_grad(): + layer.tri2_vec.copy_(tri_vec) + + if bias == False: + layer.bias = None From 
1f9f43cd1ade6c54820c9e8718ad09dd7cc1d3d7 Mon Sep 17 00:00:00 2001 From: vivianwhite Date: Mon, 15 Jul 2024 14:10:33 -0400 Subject: [PATCH 75/77] adding nonlinearities study --- FactConv/conv_modules.py | 31 +++++---- FactConv/cov.py | 111 ++++++++++++++++++++++++++---- FactConv/models/__init__.py | 15 ++-- FactConv/models/function_utils.py | 16 +++-- FactConv/pytorch_cifar.py | 8 ++- 5 files changed, 139 insertions(+), 42 deletions(-) diff --git a/FactConv/conv_modules.py b/FactConv/conv_modules.py index 2fe6bac..f7ce681 100644 --- a/FactConv/conv_modules.py +++ b/FactConv/conv_modules.py @@ -5,7 +5,8 @@ from torch.nn.common_types import _size_2_t from typing import Optional, List, Tuple, Union from cov import Covariance, LowRankCovariance, LowRankPlusDiagCovariance,\ -LowRankK1Covariance, OffDiagCovariance, DiagCovariance +LowRankK1Covariance, OffDiagCovariance, DiagCovariance,\ +LowRankK1DiagCovariance """ The function below is copied directly from @@ -21,6 +22,7 @@ def _contract(tensor, matrix, axis): class FactConv2d(nn.Conv2d): def __init__( self, + nonlinearity, in_channels: int, out_channels: int, kernel_size: _size_2_t, @@ -49,8 +51,8 @@ def __init__( channel_triu_size = self.in_channels // self.groups spatial_triu_size = self.kernel_size[0] * self.kernel_size[1] - self.channel = Covariance(channel_triu_size) - self.spatial = Covariance(spatial_triu_size) + self.channel = Covariance(channel_triu_size, nonlinearity) + self.spatial = Covariance(spatial_triu_size, nonlinearity) def forward(self, input: Tensor) -> Tensor: U1 = self.channel.sqrt() @@ -161,6 +163,7 @@ class LowRankK1FactConv2d(nn.Conv2d): def __init__( self, channel_k, + nonlinearity, in_channels: int, out_channels: int, kernel_size: _size_2_t, @@ -183,12 +186,13 @@ def __init__( self.register_buffer("weight", new_weight) nn.init.kaiming_normal_(self.weight) + self.channel_k = channel_k channel_triu_size = self.in_channels // self.groups spatial_triu_size = self.kernel_size[0] * 
self.kernel_size[1] - self.channel = LowRankK1Covariance(channel_triu_size, self.channel_k) - # self.spatial = LowRankCovariance(spatial_triu_size, self.spatial_k) - self.spatial = Covariance(spatial_triu_size) + self.channel = LowRankK1Covariance(channel_triu_size, self.channel_k,\ + nonlinearity) + self.spatial = Covariance(spatial_triu_size, nonlinearity) def forward(self, input: Tensor) -> Tensor: @@ -205,6 +209,7 @@ def forward(self, input: Tensor) -> Tensor: class OffDiagFactConv2d(nn.Conv2d): def __init__( self, + nonlinearity, in_channels: int, out_channels: int, kernel_size: _size_2_t, @@ -230,8 +235,8 @@ def __init__( channel_triu_size = self.in_channels // self.groups spatial_triu_size = self.kernel_size[0] * self.kernel_size[1] - self.channel = OffDiagCovariance(channel_triu_size) - self.spatial = Covariance(spatial_triu_size) + self.channel = OffDiagCovariance(channel_triu_size, nonlinearity) + self.spatial = Covariance(spatial_triu_size, nonlinearity) def forward(self, input: Tensor) -> Tensor: @@ -248,6 +253,7 @@ def forward(self, input: Tensor) -> Tensor: class DiagFactConv2d(nn.Conv2d): def __init__( self, + nonlinearity, in_channels: int, out_channels: int, kernel_size: _size_2_t, @@ -273,8 +279,8 @@ def __init__( channel_triu_size = self.in_channels // self.groups spatial_triu_size = self.kernel_size[0] * self.kernel_size[1] - self.channel = DiagCovariance(channel_triu_size) - self.spatial = DiagCovariance(spatial_triu_size) + self.channel = DiagCovariance(channel_triu_size, nonlinearity) + self.spatial = DiagCovariance(spatial_triu_size, nonlinearity) def forward(self, input: Tensor) -> Tensor: @@ -291,6 +297,7 @@ def forward(self, input: Tensor) -> Tensor: class DiagChanFactConv2d(nn.Conv2d): def __init__( self, + nonlinearity, in_channels: int, out_channels: int, kernel_size: _size_2_t, @@ -316,8 +323,8 @@ def __init__( channel_triu_size = self.in_channels // self.groups spatial_triu_size = self.kernel_size[0] * self.kernel_size[1] - 
self.channel = DiagCovariance(channel_triu_size) - self.spatial = Covariance(spatial_triu_size) + self.channel = DiagCovariance(channel_triu_size, nonlinearity) + self.spatial = Covariance(spatial_triu_size, nonlinearity) def forward(self, input: Tensor) -> Tensor: U1 = self.channel.sqrt() diff --git a/FactConv/cov.py b/FactConv/cov.py index b0da90a..4cf8006 100644 --- a/FactConv/cov.py +++ b/FactConv/cov.py @@ -8,11 +8,11 @@ class Covariance(nn.Module): """ Module for learnable weight covariance matrices - Input: the size of the covariance matrix + Input: the size of the covariance matrix and the nonlinearity For channel covariances, triu_size is in_channels // groups For spatial covariances, triu_size is kernel_size[0] * kernel_size[1] """ - def __init__(self, triu_size): + def __init__(self, triu_size, nonlinearity): super().__init__() self.triu_size = triu_size @@ -23,6 +23,16 @@ def __init__(self, triu_size): triu_len = triu.shape[1] tri_vec = torch.zeros((triu_len,)) + + self.nl = nonlinearity + self.dict1 = {'abs': torch.abs, 'exp': torch.exp, 'x^2': torch.square, + '1/x^2': lambda x: torch.div(1, torch.square(x)), 'relu': + torch.relu, 'softplus': nn.Softplus(), 'identity': nn.Identity()} + self.dict2 = {'abs': 1, 'exp': 0, 'x^2': 1, '1/x^2': 1, 'relu': 1, + 'softplus': 0.5414, 'identity': 1} + self.func = self.dict1[self.nl] + diag = triu[0] == triu[1] + tri_vec[diag] = self.dict2[self.nl] self.tri_vec = Parameter(tri_vec) @@ -36,13 +46,17 @@ def sqrt(self): def _tri_vec_to_mat(self, vec, n, scat_idx): r = torch.zeros((n*n), device=self.tri_vec.device, - dtype=self.tri_vec.dtype).scatter_(0, scat_idx, vec).view(n, n) - r = torch.diagonal_scatter(r, r.diagonal().exp_()) + dtype=self.tri_vec.dtype) + r = r.view(n, n).fill_diagonal_(self.dict2[self.nl]).view(n*n) + r = r.scatter_(0, scat_idx, vec).view(n, n) + r = torch.diagonal_scatter(r, self.func(r.diagonal())) + return r class LowRankCovariance(Covariance): """ Module for learnable low-rank covariance 
matrices Input: the size of the covariance matrix and the rank as a percentage + (Relative rank) """ def __init__(self, triu_size, rank): super().__init__(triu_size) @@ -62,6 +76,7 @@ def __init__(self, triu_size, rank): class LowRank1stFullCovariance(Covariance): """ Module for learnable low-rank covariance matrices Input: the size of the covariance matrix and the rank as a percentage + (Relative rank) """ def __init__(self, triu_size, rank): super().__init__(triu_size) @@ -85,7 +100,10 @@ def __init__(self, triu_size, rank): class LowRankPlusDiagCovariance(Covariance): """ Module for learning low-rank covariances plus the main diagonal - input: size of covariance matrix and the percentage rank (rows kept in upper triangular covaraince)""" + input: size of covariance matrix and the percentage rank (rows kept in upper triangular covaraince) + (Relative rank) + """ + def __init__(self, triu_size, rank): super().__init__(triu_size) if triu_size == 3: @@ -106,14 +124,22 @@ def __init__(self, triu_size, rank): class LowRankK1Covariance(Covariance): """ Module for learnable low-rank covariance matrices where - each layer has a rank of 1 - Input: the size of the covariance matrix + each layer has a rank integer + Input: the size of the covariance matrix and the rank (int) + (Absolute rank) """ - def __init__(self, triu_size, rank): - super().__init__(triu_size) + def __init__(self, triu_size, rank, nonlinearity): + super().__init__(triu_size, nonlinearity) - self.k = 1 - print(self.k) + if triu_size == 3: + self.k = 3 + print(self.k) + elif triu_size < rank: + self.k = triu_size + print(self.k) + else: + self.k = rank + print(self.k) triu = torch.triu_indices(self.triu_size, self.triu_size, dtype=torch.long) @@ -123,14 +149,53 @@ def __init__(self, triu_size, rank): triu_len = scat_idx.shape[0] tri_vec = torch.zeros((triu_len,)) + + self.nl = nonlinearity + self.dict1 = {'abs': torch.abs, 'exp': torch.exp, 'x^2': torch.square, + '1/x^2': lambda x: torch.div(1, 
torch.square(x)), 'relu': + torch.relu, 'softplus': nn.Softplus(), 'identity': nn.Identity()} + self.dict2 = {'abs': 1, 'exp': 0, 'x^2': 1, '1/x^2': 1, 'relu': 1, + 'softplus': 0.5414, 'identity': 1} + self.func = self.dict1[self.nl] + diag = triu[0][mask] == triu[1][mask] + tri_vec[diag] = self.dict2[self.nl] + self.tri_vec = Parameter(tri_vec) +class LowRankK1DiagCovariance(Covariance): + """ Module for learnable low-rank covariance matrices where + each layer has a rank integer plus full diagonal + Input: the size of the covariance matrix and the rank (int) + (Absolute rank) + """ + def __init__(self, triu_size, rank): + super().__init__(triu_size) + + if triu_size == 3: + self.k = 3 + print(self.k) + elif triu_size < rank: + self.k = triu_size + print(self.k) + else: + self.k = rank + print(self.k) + + triu = torch.triu_indices(self.triu_size, self.triu_size, + dtype=torch.long) + mask = (triu[0] < self.k)|(triu[0]==triu[1]) + scat_idx = triu[0][mask] * self.triu_size + triu[1][mask] + self.register_buffer("scat_idx", scat_idx, persistent=False) + + triu_len = scat_idx.shape[0] + tri_vec = torch.zeros((triu_len,)) + self.tri_vec = Parameter(tri_vec) class OffDiagCovariance(Covariance): """ Module for learning off-diagonal of weight covariance matrix input: size of covariance matrix""" - def __init__(self, triu_size): - super().__init__(triu_size) + def __init__(self, triu_size, nonlinearity): + super().__init__(triu_size, nonlinearity) triu = torch.triu_indices(self.triu_size, self.triu_size, dtype=torch.long) @@ -140,13 +205,21 @@ def __init__(self, triu_size): triu_len = scat_idx.shape[0] tri_vec = torch.zeros((triu_len,)) + self.nl = nonlinearity + self.dict1 = {'abs': torch.abs, 'exp': torch.exp, 'x^2': torch.square, + '1/x^2': lambda x: torch.div(1, torch.square(x)), 'relu': + torch.relu, 'softplus': nn.Softplus(), 'identity': nn.Identity()} + self.dict2 = {'abs': 1, 'exp': 0, 'x^2': 1, '1/x^2': 1, 'relu': 1, + 'softplus': 0.5414, 'identity': 1} + 
self.func = self.dict1[self.nl] + tri_vec.fill_(self.dict2[self.nl]) self.tri_vec = Parameter(tri_vec) class DiagCovariance(Covariance): """ Module for learning diagonal of weight covariance matrix input: size of covariance matrix""" - def __init__(self, triu_size): - super().__init__(triu_size) + def __init__(self, triu_size, nonlinearity): + super().__init__(triu_size, nonlinearity) triu = torch.triu_indices(self.triu_size, self.triu_size, dtype=torch.long) @@ -156,4 +229,12 @@ def __init__(self, triu_size): triu_len = scat_idx.shape[0] tri_vec = torch.zeros((triu_len,)) + self.nl = nonlinearity + self.dict1 = {'abs': torch.abs, 'exp': torch.exp, 'x^2': torch.square, + '1/x^2': lambda x: torch.div(1, torch.square(x)), 'relu': + torch.relu, 'softplus': nn.Softplus(), 'identity': nn.Identity()} + self.dict2 = {'abs': 1, 'exp': 0, 'x^2': 1, '1/x^2': 1, 'relu': 1, + 'softplus': 0.5414, 'identity': 1} + self.func = self.dict1[self.nl] + tri_vec.fill_(self.dict2[self.nl]) self.tri_vec = Parameter(tri_vec) diff --git a/FactConv/models/__init__.py b/FactConv/models/__init__.py index 2553089..e3a175e 100644 --- a/FactConv/models/__init__.py +++ b/FactConv/models/__init__.py @@ -4,7 +4,8 @@ replace_layers_diagfactconv2d, replace_layers_diagchanfactconv2d,\ turn_off_covar_grad, replace_layers_scale, init_V1_layers,\ replace_layers_lowrank, replace_layers_lowrankplusdiag,\ -replace_layers_lowrankK1, replace_layers_offdiag +replace_layers_lowrankK1, replace_layers_offdiag,\ +replace_affines def define_models(args): @@ -17,19 +18,21 @@ def define_models(args): if args.width != 1: replace_layers_scale(model, args.width) if 'fact' in args.net: - replace_layers_factconv2d(model) + replace_layers_factconv2d(model, args.nonlinearity) if 'diag' in args.net: - replace_layers_diagfactconv2d(model) + replace_layers_diagfactconv2d(model, args.nonlinearity) if 'diagchan' in args.net: - replace_layers_diagchanfactconv2d(model) + replace_layers_diagchanfactconv2d(model, args.nonlinearity) 
if 'lowrank' in args.net: replace_layers_lowrank(model, args.spatial_k, args.channel_k) if 'lr-diag' in args.net: replace_layers_lowrankplusdiag(model) if 'lr-K1' in args.net: - replace_layers_lowrankK1(model, args.channel_k) + replace_layers_lowrankK1(model, args.channel_k, args.nonlinearity) + if 'no-affines' in args.net: + replace_affines(model) if 'offdiag' in args.net: - replace_layers_offdiag(model) + replace_layers_offdiag(model, args.nonlinearity) if 'v1' in args.net: init_V1_layers(model, bias=False) if 'us' in args.net: diff --git a/FactConv/models/function_utils.py b/FactConv/models/function_utils.py index a3575eb..a34c769 100644 --- a/FactConv/models/function_utils.py +++ b/FactConv/models/function_utils.py @@ -19,7 +19,7 @@ def recurse_preorder(model, callback): return model -def replace_layers_factconv2d(model): +def replace_layers_factconv2d(model, nonlinearity): ''' Replace nn.Conv2d layers with FactConv2d ''' @@ -28,6 +28,7 @@ def _replace_layers_factconv2d(module): ## simple module new_module = FactConv2d( in_channels=module.in_channels, + nonlinearity=nonlinearity, out_channels=module.out_channels, kernel_size=module.kernel_size, stride=module.stride, padding=module.padding, @@ -42,7 +43,7 @@ def _replace_layers_factconv2d(module): return recurse_preorder(model, _replace_layers_factconv2d) -def replace_layers_diagfactconv2d(model): +def replace_layers_diagfactconv2d(model, nonlinearity): ''' Replace nn.Conv2d layers with DiagFactConv2d ''' @@ -51,6 +52,7 @@ def _replace_layers_diagfactconv2d(module): ## simple module new_module = DiagFactConv2d( in_channels=module.in_channels, + nonlinearity=nonlinearity, out_channels=module.out_channels, kernel_size=module.kernel_size, stride=module.stride, padding=module.padding, @@ -64,7 +66,7 @@ def _replace_layers_diagfactconv2d(module): return new_module return recurse_preorder(model, _replace_layers_diagfactconv2d) -def replace_layers_diagchanfactconv2d(model): +def 
replace_layers_diagchanfactconv2d(model,nonlinearity): ''' Replace nn.Conv2d layers with DiagChanFactConv2d ''' @@ -75,6 +77,7 @@ def _replace_layers_diagchanfactconv2d(module): in_channels=module.in_channels, out_channels=module.out_channels, kernel_size=module.kernel_size, + nonlinearity=nonlinearity, stride=module.stride, padding=module.padding, bias=True if module.bias is not None else False) old_sd = module.state_dict() @@ -288,7 +291,7 @@ def _replace_layers_lowrankplusdiag(module): return new_module return recurse_preorder(model, _replace_layers_lowrankplusdiag) -def replace_layers_lowrankK1(model, channel_k): +def replace_layers_lowrankK1(model, channel_k, nonlinearity): ''' Replace nn.Conv2d layers with LowRankK1FactConv2d ''' @@ -300,7 +303,7 @@ def _replace_layers_lowrankK1(module): out_channels=module.out_channels, kernel_size=module.kernel_size, stride=module.stride, padding=module.padding, - channel_k=channel_k, + channel_k=channel_k, nonlinearity=nonlinearity, bias=True if module.bias is not None else False) old_sd = module.state_dict() new_sd = new_module.state_dict() @@ -311,7 +314,7 @@ def _replace_layers_lowrankK1(module): return new_module return recurse_preorder(model, _replace_layers_lowrankK1) -def replace_layers_offdiag(model): +def replace_layers_offdiag(model, nonlinearity): ''' Replace nn.Conv2d layers with OffDiagFactConv2d ''' @@ -323,6 +326,7 @@ def _replace_layers_offdiag(module): out_channels=module.out_channels, kernel_size=module.kernel_size, stride=module.stride, padding=module.padding, + nonlinearity=nonlinearity, bias=True if module.bias is not None else False) old_sd = module.state_dict() new_sd = new_module.state_dict() diff --git a/FactConv/pytorch_cifar.py b/FactConv/pytorch_cifar.py index 9fd1eed..d182ebd 100644 --- a/FactConv/pytorch_cifar.py +++ b/FactConv/pytorch_cifar.py @@ -15,7 +15,7 @@ def save_model(args, model): src="/home/mila/v/vivian.white/scratch/v1-models/saved-models/low-rank/" - model_dir = src + args.net + 
"-seed" + str(args.seed) + model_dir = src + args.net + "_" + args.nonlinearity+ "-seed" + str(args.seed) print("Model dir: ", model_dir) os.makedirs(model_dir, exist_ok=True) @@ -35,6 +35,7 @@ def save_model(args, model): parser.add_argument('--width', type=float, default=1, help='resnet width scale factor') parser.add_argument('--spatial_k', type=float, default=1, help='%spatial low-rank') parser.add_argument('--channel_k', type=float, default=1, help='%channel low-rank') +parser.add_argument('--nonlinearity', type=str, default='exp') args = parser.parse_args() @@ -73,7 +74,8 @@ def save_model(args, model): print('==> Building model..') net = define_models(args) -run_name = "{}_{}_seed_{}".format(args.net, args.channel_k, args.seed) +run_name = "{}_rank{}_seed{}_width{}_{}".format(args.net, args.channel_k, args.seed,\ + args.width, args.nonlinearity) print("Args.net: ", args.net) print("Net: ", net) @@ -89,7 +91,7 @@ def save_model(args, model): print("Num Learnable Params: ", param_count) -run = wandb.init(project="factconv", config=args, group="lowrankplusdiag", name=run_name, dir=wandb_dir) +run = wandb.init(project="factconv", config=args, group="lowrankabs", name=run_name, dir=wandb_dir) #wandb.watch(net, log='all', log_freq=1) run.log({'param_count':param_count}) From db9ba1ba439767d88c12b3c1d0a3451cb3ea5f08 Mon Sep 17 00:00:00 2001 From: vivianwhite Date: Sat, 20 Jul 2024 18:22:56 -0400 Subject: [PATCH 76/77] adding resnext model --- FactConv/models/__init__.py | 9 +- FactConv/models/function_utils.py | 42 +++++++- FactConv/models/initializer.py | 26 +++++ FactConv/models/resnext.py | 172 ++++++++++++++++++++++++++++++ 4 files changed, 242 insertions(+), 7 deletions(-) create mode 100644 FactConv/models/initializer.py create mode 100644 FactConv/models/resnext.py diff --git a/FactConv/models/__init__.py b/FactConv/models/__init__.py index e3a175e..ecb2a68 100644 --- a/FactConv/models/__init__.py +++ b/FactConv/models/__init__.py @@ -1,16 +1,19 @@ from 
.resnet import ResNet18 +from .resnext import Network from .LC_models import LC_CIFAR10 from .function_utils import replace_layers_factconv2d,\ replace_layers_diagfactconv2d, replace_layers_diagchanfactconv2d,\ turn_off_covar_grad, replace_layers_scale, init_V1_layers,\ replace_layers_lowrank, replace_layers_lowrankplusdiag,\ replace_layers_lowrankK1, replace_layers_offdiag,\ -replace_affines +replace_affines, replace_layers_diagdom def define_models(args): if 'resnet18' in args.net: model = ResNet18() + if 'resnext' in args.net: + model = Network() if 'rsn' in args.net: model = LC_CIFAR10(hidden_dim=100, size=2, spatial_freq=0.1, scale=1, bias=True, freeze_spatial=False, freeze_channel=False, @@ -26,9 +29,11 @@ def define_models(args): if 'lowrank' in args.net: replace_layers_lowrank(model, args.spatial_k, args.channel_k) if 'lr-diag' in args.net: - replace_layers_lowrankplusdiag(model) + replace_layers_lowrankplusdiag(model, args.channel_k, args.nonlinearity) if 'lr-K1' in args.net: replace_layers_lowrankK1(model, args.channel_k, args.nonlinearity) + if 'diagdom' in args.net: + replace_layers_diagdom(model) if 'no-affines' in args.net: replace_affines(model) if 'offdiag' in args.net: diff --git a/FactConv/models/function_utils.py b/FactConv/models/function_utils.py index a34c769..8595888 100644 --- a/FactConv/models/function_utils.py +++ b/FactConv/models/function_utils.py @@ -2,9 +2,10 @@ import torch.nn as nn from conv_modules import FactConv2d, DiagFactConv2d, DiagChanFactConv2d,\ LowRankFactConv2d, LowRankPlusDiagFactConv2d, LowRankK1FactConv2d, \ -OffDiagFactConv2d +OffDiagFactConv2d, DiagDomFactConv2d from cov import Covariance, LowRankCovariance, LowRankPlusDiagCovariance,\ -LowRankK1Covariance, OffDiagCovariance, DiagCovariance +LowRankK1Covariance, OffDiagCovariance, DiagCovariance,\ +DiagonallyDominantCovariance from V1_covariance import V1_init @@ -31,6 +32,7 @@ def _replace_layers_factconv2d(module): nonlinearity=nonlinearity, 
out_channels=module.out_channels, kernel_size=module.kernel_size, + groups=module.groups, stride=module.stride, padding=module.padding, bias=True if module.bias is not None else False) old_sd = module.state_dict() @@ -52,6 +54,7 @@ def _replace_layers_diagfactconv2d(module): ## simple module new_module = DiagFactConv2d( in_channels=module.in_channels, + groups=module.groups, nonlinearity=nonlinearity, out_channels=module.out_channels, kernel_size=module.kernel_size, @@ -76,6 +79,7 @@ def _replace_layers_diagchanfactconv2d(module): new_module = DiagChanFactConv2d( in_channels=module.in_channels, out_channels=module.out_channels, + groups=module.groups, kernel_size=module.kernel_size, nonlinearity=nonlinearity, stride=module.stride, padding=module.padding, @@ -122,7 +126,7 @@ def _replace_layers_scale(module): out_channels=int(module.out_channels*scale), kernel_size=module.kernel_size, stride=module.stride, padding=module.padding, - groups = module.groups, + groups=module.groups, bias=True if module.bias is not None else False) return new_module if isinstance(module, nn.BatchNorm2d): @@ -159,6 +163,7 @@ def _replace_layers_fact_with_conv(module): in_channels=module.in_channels, out_channels=module.out_channels, kernel_size=module.kernel_size, + groups=module.groups, stride=module.stride, padding=module.padding, bias=True if module.bias is not None else False) old_sd = module.state_dict() @@ -255,6 +260,7 @@ def _replace_layers_lowrank(module): in_channels=module.in_channels, out_channels=module.out_channels, kernel_size=module.kernel_size, + groups=module.groups, stride=module.stride, padding=module.padding, spatial_k=spatial_k, channel_k=channel_k, bias=True if module.bias is not None else False) @@ -268,7 +274,7 @@ def _replace_layers_lowrank(module): return recurse_preorder(model, _replace_layers_lowrank) -def replace_layers_lowrankplusdiag(model, channel_k): +def replace_layers_lowrankplusdiag(model, channel_k, nonlinearity): ''' Replace nn.Conv2d layers with 
LowRankFactConv2d ''' @@ -280,7 +286,8 @@ def _replace_layers_lowrankplusdiag(module): out_channels=module.out_channels, kernel_size=module.kernel_size, stride=module.stride, padding=module.padding, - channel_k=channel_k, + groups=module.groups, + channel_k=channel_k, nonlinearity=nonlinearity, bias=True if module.bias is not None else False) old_sd = module.state_dict() new_sd = new_module.state_dict() @@ -302,6 +309,7 @@ def _replace_layers_lowrankK1(module): in_channels=module.in_channels, out_channels=module.out_channels, kernel_size=module.kernel_size, + groups=module.groups, stride=module.stride, padding=module.padding, channel_k=channel_k, nonlinearity=nonlinearity, bias=True if module.bias is not None else False) @@ -325,6 +333,7 @@ def _replace_layers_offdiag(module): in_channels=module.in_channels, out_channels=module.out_channels, kernel_size=module.kernel_size, + groups=module.groups, stride=module.stride, padding=module.padding, nonlinearity=nonlinearity, bias=True if module.bias is not None else False) @@ -336,3 +345,26 @@ def _replace_layers_offdiag(module): new_module.load_state_dict(new_sd) return new_module return recurse_preorder(model, _replace_layers_offdiag) + +def replace_layers_diagdom(model): + ''' + Replace nn.Conv2d layers with DiagDomFactConv2d + ''' + def _replace_layers_diagdom(module): + if isinstance(module, nn.Conv2d): + ## simple module + new_module = DiagDomFactConv2d( + in_channels=module.in_channels, + out_channels=module.out_channels, + kernel_size=module.kernel_size, + stride=module.stride, padding=module.padding, + groups=module.groups, + bias=True if module.bias is not None else False) + old_sd = module.state_dict() + new_sd = new_module.state_dict() + new_sd['weight'] = old_sd['weight'] + if module.bias is not None: + new_sd['bias'] = old_sd['bias'] + new_module.load_state_dict(new_sd) + return new_module + return recurse_preorder(model, _replace_layers_diagdom) diff --git a/FactConv/models/initializer.py 
b/FactConv/models/initializer.py new file mode 100644 index 0000000..f2c5148 --- /dev/null +++ b/FactConv/models/initializer.py @@ -0,0 +1,26 @@ +from typing import Callable + +import torch.nn as nn + + +def create_initializer(mode: str) -> Callable: + if mode in ['kaiming_fan_out', 'kaiming_fan_in']: + mode = mode[8:] + + def initializer(module): + if isinstance(module, nn.Conv2d): + nn.init.kaiming_normal_(module.weight.data, + mode=mode, + nonlinearity='relu') + elif isinstance(module, nn.BatchNorm2d): + nn.init.ones_(module.weight.data) + nn.init.zeros_(module.bias.data) + elif isinstance(module, nn.Linear): + nn.init.kaiming_normal_(module.weight.data, + mode=mode, + nonlinearity='relu') + nn.init.zeros_(module.bias.data) + else: + raise ValueError() + + return initializer diff --git a/FactConv/models/resnext.py b/FactConv/models/resnext.py new file mode 100644 index 0000000..9349dbb --- /dev/null +++ b/FactConv/models/resnext.py @@ -0,0 +1,172 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .initializer import create_initializer + + +class BottleneckBlock(nn.Module): + expansion = 4 + + def __init__(self, in_channels, out_channels, stride, stage_index, + base_channels, cardinality): + super().__init__() + + bottleneck_channels = cardinality * base_channels * 2**stage_index + + self.conv1 = nn.Conv2d(in_channels, + bottleneck_channels, + kernel_size=1, + stride=1, + padding=0, + bias=False) + self.bn1 = nn.BatchNorm2d(bottleneck_channels) + + self.conv2 = nn.Conv2d( + bottleneck_channels, + bottleneck_channels, + kernel_size=3, + stride=stride, # downsample with 3x3 conv + padding=1, + groups=cardinality, + bias=False) + self.bn2 = nn.BatchNorm2d(bottleneck_channels) + + self.conv3 = nn.Conv2d(bottleneck_channels, + out_channels, + kernel_size=1, + stride=1, + padding=0, + bias=False) + self.bn3 = nn.BatchNorm2d(out_channels) + + self.shortcut = nn.Sequential() # identity + if in_channels != out_channels: + 
self.shortcut.add_module( + 'conv', + nn.Conv2d( + in_channels, + out_channels, + kernel_size=1, + stride=stride, # downsample + padding=0, + bias=False)) + self.shortcut.add_module('bn', nn.BatchNorm2d(out_channels)) # BN + + def forward(self, x): + y = F.relu(self.bn1(self.conv1(x)), inplace=True) + y = F.relu(self.bn2(self.conv2(y)), inplace=True) + y = self.bn3(self.conv3(y)) # not apply ReLU + y += self.shortcut(x) + y = F.relu(y, inplace=True) # apply ReLU after addition + return y + + +class Network(nn.Module): + def __init__(self): + super().__init__() + + #model_config = config.model.resnext + #depth = model_config.depth + #initial_channels = model_config.initial_channels + #self.base_channels = model_config.base_channels + #self.cardinality = model_config.cardinality + depth = 29 + initial_channels = 64 + self.cardinality = 8 + self.base_channels = 64 + + n_blocks_per_stage = (depth - 2) // 9 + assert n_blocks_per_stage * 9 + 2 == depth + block = BottleneckBlock + + n_channels = [ + initial_channels, + initial_channels * block.expansion, + initial_channels * 2 * block.expansion, + initial_channels * 4 * block.expansion, + ] + + #self.conv = nn.Conv2d(config.dataset.n_channels, + self.conv = nn.Conv2d(3, + n_channels[0], + kernel_size=3, + stride=1, + padding=1, + bias=False) + self.bn = nn.BatchNorm2d(n_channels[0]) + + self.stage1 = self._make_stage(n_channels[0], + n_channels[1], + n_blocks_per_stage, + 0, + stride=1) + self.stage2 = self._make_stage(n_channels[1], + n_channels[2], + n_blocks_per_stage, + 1, + stride=2) + self.stage3 = self._make_stage(n_channels[2], + n_channels[3], + n_blocks_per_stage, + 2, + stride=2) + + # compute conv feature size + with torch.no_grad(): + dummy_data = torch.zeros( + #(1, config.dataset.n_channels, config.dataset.image_size, + (1, 3, 32, 32), + # config.dataset.image_size), + dtype=torch.float32) + self.feature_size = self._forward_conv(dummy_data).view( + -1).shape[0] + + #self.fc = nn.Linear(self.feature_size, 
config.dataset.n_classes) + self.fc = nn.Linear(self.feature_size, 10) + + # initialize weights + #initializer = create_initializer(config.model.init_mode) + initializer = create_initializer("kaiming_fan_out") + self.apply(initializer) + + def _make_stage(self, in_channels, out_channels, n_blocks, stage_index, + stride): + stage = nn.Sequential() + for index in range(n_blocks): + block_name = f'block{index + 1}' + if index == 0: + stage.add_module( + block_name, + BottleneckBlock( + in_channels, + out_channels, + stride, # downsample + stage_index, + self.base_channels, + self.cardinality)) + else: + stage.add_module( + block_name, + BottleneckBlock( + out_channels, + out_channels, + 1, # no downsampling + stage_index, + self.base_channels, + self.cardinality)) + return stage + + def _forward_conv(self, x): + x = F.relu(self.bn(self.conv(x)), inplace=True) + x = self.stage1(x) + x = self.stage2(x) + x = self.stage3(x) + x = F.adaptive_avg_pool2d(x, output_size=1) + return x + + def forward(self, x): + x = self._forward_conv(x) + x = x.view(x.size(0), -1) + x = self.fc(x) + return x From de9820b317ce314767b66a71ef18bebc88989a08 Mon Sep 17 00:00:00 2001 From: vivianwhite Date: Thu, 22 Aug 2024 16:31:02 -0400 Subject: [PATCH 77/77] added diagonally dominant experiment and updated scripts --- FactConv/conv_modules.py | 48 ++++++++++++++++++++++++++++++++++--- FactConv/cov.py | 45 ++++++++++++++++++++++++++++++++-- FactConv/launched.sh | 28 +++++++++++++++------- FactConv/models/__init__.py | 9 +++++++ FactConv/pytorch_cifar.py | 10 ++++---- FactConv/setoff.sh | 2 +- 6 files changed, 123 insertions(+), 19 deletions(-) diff --git a/FactConv/conv_modules.py b/FactConv/conv_modules.py index f7ce681..93d325e 100644 --- a/FactConv/conv_modules.py +++ b/FactConv/conv_modules.py @@ -6,7 +6,7 @@ from typing import Optional, List, Tuple, Union from cov import Covariance, LowRankCovariance, LowRankPlusDiagCovariance,\ LowRankK1Covariance, OffDiagCovariance, DiagCovariance,\ 
-LowRankK1DiagCovariance +LowRankK1DiagCovariance, DiagonallyDominantCovariance """ The function below is copied directly from @@ -118,6 +118,7 @@ class LowRankPlusDiagFactConv2d(nn.Conv2d): def __init__( self, channel_k, + nonlinearity, in_channels: int, out_channels: int, kernel_size: _size_2_t, @@ -144,8 +145,8 @@ def __init__( channel_triu_size = self.in_channels // self.groups spatial_triu_size = self.kernel_size[0] * self.kernel_size[1] - self.channel = LowRankPlusDiagCovariance(channel_triu_size, self.channel_k) - self.spatial = Covariance(spatial_triu_size) + self.channel = LowRankK1DiagCovariance(channel_triu_size, self.channel_k, nonlinearity) + self.spatial = Covariance(spatial_triu_size, nonlinearity) def forward(self, input: Tensor) -> Tensor: @@ -335,3 +336,44 @@ def forward(self, input: Tensor) -> Tensor: torch.flatten(composite_weight, -2, -1), U2.T, -1 ).reshape(self.weight.shape) return self._conv_forward(input, composite_weight, self.bias) + +class DiagDomFactConv2d(nn.Conv2d): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: _size_2_t, + stride: _size_2_t = 1, + padding: Union[str, _size_2_t] = 0, + dilation: _size_2_t = 1, + groups: int = 1, + bias: bool = True, + padding_mode: str = 'zeros', # TODO: refine this type + device=None, + dtype=None + ) -> None: + # init as Conv2d + super().__init__( + in_channels, out_channels, kernel_size, stride, padding, dilation, + groups, bias, padding_mode, device, dtype) + # weight shape: (out_channels, in_channels // groups, *kernel_size) + new_weight = torch.empty_like(self.weight) + del self.weight # remove Parameter, create buffer + self.register_buffer("weight", new_weight) + nn.init.kaiming_normal_(self.weight) + + channel_triu_size = self.in_channels // self.groups + spatial_triu_size = self.kernel_size[0] * self.kernel_size[1] + + self.channel = DiagonallyDominantCovariance(channel_triu_size) + self.spatial = Covariance(spatial_triu_size, "abs") + + def forward(self, 
input: Tensor) -> Tensor: + U1 = self.channel.sqrt() + U2 = self.spatial.sqrt() + # flatten over filter dims and contract + composite_weight = _contract(self.weight, U1.T, 1) + composite_weight = _contract( + torch.flatten(composite_weight, -2, -1), U2.T, -1 + ).reshape(self.weight.shape) + return self._conv_forward(input, composite_weight, self.bias) diff --git a/FactConv/cov.py b/FactConv/cov.py index 4cf8006..32bb69a 100644 --- a/FactConv/cov.py +++ b/FactConv/cov.py @@ -168,8 +168,8 @@ class LowRankK1DiagCovariance(Covariance): Input: the size of the covariance matrix and the rank (int) (Absolute rank) """ - def __init__(self, triu_size, rank): - super().__init__(triu_size) + def __init__(self, triu_size, rank, nonlinearity): + super().__init__(triu_size, nonlinearity) if triu_size == 3: self.k = 3 @@ -189,6 +189,17 @@ def __init__(self, triu_size, rank): triu_len = scat_idx.shape[0] tri_vec = torch.zeros((triu_len,)) + + self.nl = nonlinearity + self.dict1 = {'abs': torch.abs, 'exp': torch.exp, 'x^2': torch.square, + '1/x^2': lambda x: torch.div(1, torch.square(x)), 'relu': + torch.relu, 'softplus': nn.Softplus(), 'identity': nn.Identity()} + self.dict2 = {'abs': 1, 'exp': 0, 'x^2': 1, '1/x^2': 1, 'relu': 1, + 'softplus': 0.5414, 'identity': 1} + self.func = self.dict1[self.nl] + + diag = triu[0][mask] == triu[1][mask] + tri_vec[diag] = self.dict2[self.nl] self.tri_vec = Parameter(tri_vec) class OffDiagCovariance(Covariance): @@ -238,3 +249,33 @@ def __init__(self, triu_size, nonlinearity): self.func = self.dict1[self.nl] tri_vec.fill_(self.dict2[self.nl]) self.tri_vec = Parameter(tri_vec) + + +class DiagonallyDominantCovariance(Covariance): + """ Module for learnable diagonally dominant weight covariance matrices + Input: the size of the covariance matrix + """ + def __init__(self, triu_size): + super().__init__(triu_size, "abs") + self.triu_size = triu_size + triu = torch.triu_indices(self.triu_size, self.triu_size, + dtype=torch.long) + scat_idx = triu[0] * 
self.triu_size + triu[1] + self.register_buffer("scat_idx", scat_idx, persistent=False) + + triu_len = triu.shape[1] + tri_vec = torch.zeros((triu_len,)) + + diag = triu[0] == triu[1] + tri_vec[diag] = 1 + self.tri_vec = Parameter(tri_vec) + + def _tri_vec_to_mat(self, vec, n, scat_idx): + r = torch.zeros((n*n), device=self.tri_vec.device, + dtype=self.tri_vec.dtype) + r = r.view(n, n).fill_diagonal_(1).view(n*n) + r = r.scatter_(0, scat_idx, vec).view(n, n) + r = torch.diagonal_scatter(r, torch.abs(torch.sum(r, dim=1))) +# print(r) + return r + diff --git a/FactConv/launched.sh b/FactConv/launched.sh index 06277c0..484fa99 100755 --- a/FactConv/launched.sh +++ b/FactConv/launched.sh @@ -1,16 +1,17 @@ #!/bin/bash #width=(0.125 0.25 0.5 1.0 2.0 4.0) -#width=(1) -#rank=(1.0 0.9 0.65 0.75 0.5 0.35 0.25 0.15 0.1 0.05) -rank=(1.0 0.75 0.5 0.25 0.1) -#seed=(1 2) -#width=(0.125) +#nonlins=("exp" "abs" "x^2" "1/x^2" "softplus" "identity" "relu") +width=(1) +rank=(1 2 4 8 16 32 64 128 256 512 1024) seed=(0 1 2) +#seed=(3) -for i in ${rank[@]} +for i in ${width[@]} do - for j in ${seed[@]} + for j in ${rank[@]} do + for k in ${seed[@]} + do #sbatch setoff.sh --width $i --seed $j --net resnet18 #sbatch setoff.sh --width $i --seed $j --net fact_resnet18 #sbatch setoff.sh --width $i --seed $j --net fact_us_resnet18 @@ -18,11 +19,20 @@ do #sbatch setoff.sh --width $i --seed $j --net fact_us_uc_resnet18 # WHERE IS THIS VIVIAN 🧐🧐🤨🤨 # NVM good job Vivian - sbatch setoff.sh $i $j resnet18_fact_lr-diag - sbatch setoff.sh $i $j resnet18_fact_lowrank +# sbatch setoff.sh $i $j $k resnet18-fact-lr-K1-no-affines +# sbatch setoff.sh $i $j $k resnet18-fact-lr-K1 +# sbatch setoff.sh $i $j $k resnet18-diagdom +# sbatch setoff.sh $i $j $k resnet18-diagchan +# sbatch setoff.sh $i $j $k resnet18-lr-diag + sbatch setoff.sh $i $j $k wrn-lr-K1 + sbatch setoff.sh $i $j $k wrn-lr-diag +# sbatch setoff.sh $i $j $k resnet18-fact-us +# sbatch setoff.sh $i $j $k resnet18-fact-uc +# sbatch setoff.sh $i $j 
$k resnet18-fact-usuc #sbatch setoff.sh $i $j fact_diagchan_us_resnet18 #sbatch setoff.sh $i $j fact_diagchan_uc_resnet18 #sbatch setoff.sh $i $j fact_diagchan_us_uc_resnet18 + done done done diff --git a/FactConv/models/__init__.py b/FactConv/models/__init__.py index ecb2a68..ac51710 100644 --- a/FactConv/models/__init__.py +++ b/FactConv/models/__init__.py @@ -1,5 +1,8 @@ from .resnet import ResNet18 from .resnext import Network +from .wideresnet import WideResNet +from .convnext import convnext_base +from .conv2next import conv2next_base from .LC_models import LC_CIFAR10 from .function_utils import replace_layers_factconv2d,\ replace_layers_diagfactconv2d, replace_layers_diagchanfactconv2d,\ @@ -14,6 +17,12 @@ def define_models(args): model = ResNet18() if 'resnext' in args.net: model = Network() + if 'wrn' in args.net: + model = WideResNet(depth=28, num_classes=10, widen_factor=16, dropRate=0) + if 'convnext' in args.net: + model = convnext_base() + if 'conv2next' in args.net: + model = conv2next_base() if 'rsn' in args.net: model = LC_CIFAR10(hidden_dim=100, size=2, spatial_freq=0.1, scale=1, bias=True, freeze_spatial=False, freeze_channel=False, diff --git a/FactConv/pytorch_cifar.py b/FactConv/pytorch_cifar.py index d182ebd..d90aadd 100644 --- a/FactConv/pytorch_cifar.py +++ b/FactConv/pytorch_cifar.py @@ -15,7 +15,7 @@ def save_model(args, model): src="/home/mila/v/vivian.white/scratch/v1-models/saved-models/low-rank/" - model_dir = src + args.net + "_" + args.nonlinearity+ "-seed" + str(args.seed) + model_dir = src + args.net + "_rank" + str(args.channel_k) + "_width" + str(args.width) + "_" + args.nonlinearity+ "-seed" + str(args.seed) print("Model dir: ", model_dir) os.makedirs(model_dir, exist_ok=True) @@ -35,7 +35,7 @@ def save_model(args, model): parser.add_argument('--width', type=float, default=1, help='resnet width scale factor') parser.add_argument('--spatial_k', type=float, default=1, help='%spatial low-rank') parser.add_argument('--channel_k', 
type=float, default=1, help='%channel low-rank') -parser.add_argument('--nonlinearity', type=str, default='exp') +parser.add_argument('--nonlinearity', type=str, default='abs') args = parser.parse_args() @@ -78,6 +78,10 @@ def save_model(args, model): args.width, args.nonlinearity) print("Args.net: ", args.net) print("Net: ", net) +print("Run name: ", run_name) + +param_count = sum(p.numel() for p in net.parameters() if p.requires_grad) +print("Num Learnable Params: ", param_count) set_seeds(args.seed) @@ -86,9 +90,7 @@ def save_model(args, model): os.makedirs(wandb_dir, exist_ok=True) os.chdir(wandb_dir) -param_count = sum(p.numel() for p in net.parameters() if p.requires_grad) -print("Num Learnable Params: ", param_count) run = wandb.init(project="factconv", config=args, group="lowrankabs", name=run_name, dir=wandb_dir) diff --git a/FactConv/setoff.sh b/FactConv/setoff.sh index 10972e4..2aa95b1 100644 --- a/FactConv/setoff.sh +++ b/FactConv/setoff.sh @@ -10,4 +10,4 @@ module load anaconda/3 conda activate random_features -python pytorch_cifar.py --channel_k $1 --seed $2 --net $3 +python pytorch_cifar.py --nonlinearity "abs" --width $1 --channel_k $2 --seed $3 --net $4