diff --git a/FactConv/conv_modules.py b/FactConv/conv_modules.py index 5274173..93d325e 100644 --- a/FactConv/conv_modules.py +++ b/FactConv/conv_modules.py @@ -4,6 +4,9 @@ from torch.nn.parameter import Parameter from torch.nn.common_types import _size_2_t from typing import Optional, List, Tuple, Union +from cov import Covariance, LowRankCovariance, LowRankPlusDiagCovariance,\ +LowRankK1Covariance, OffDiagCovariance, DiagCovariance,\ +LowRankK1DiagCovariance, DiagonallyDominantCovariance """ The function below is copied directly from @@ -19,6 +22,7 @@ def _contract(tensor, matrix, axis): class FactConv2d(nn.Conv2d): def __init__( self, + nonlinearity, in_channels: int, out_channels: int, kernel_size: _size_2_t, @@ -41,38 +45,68 @@ def __init__( self.register_buffer("weight", new_weight) nn.init.kaiming_normal_(self.weight) - self.in_features = self.in_channels // self.groups * \ - self.kernel_size[0] * self.kernel_size[1] - triu1 = torch.triu_indices(self.in_channels // self.groups, - self.in_channels // self.groups, - device=self.weight.device, - dtype=torch.long) - scat_idx1 = triu1[0]*self.in_channels//self.groups + triu1[1] - self.register_buffer("scat_idx1", scat_idx1, persistent=False) + # self.in_features = self.in_channels // self.groups * \ + # self.kernel_size[0] * self.kernel_size[1] - triu2 = torch.triu_indices(self.kernel_size[0] * self.kernel_size[1], - self.kernel_size[0] - * self.kernel_size[1], - device=self.weight.device, - dtype=torch.long) - scat_idx2 = triu2[0]*self.kernel_size[0]*self.kernel_size[1] + triu2[1] - self.register_buffer("scat_idx2", scat_idx2, persistent=False) + channel_triu_size = self.in_channels // self.groups + spatial_triu_size = self.kernel_size[0] * self.kernel_size[1] - triu1_len = triu1.shape[1] - triu2_len = triu2.shape[1] + self.channel = Covariance(channel_triu_size, nonlinearity) + self.spatial = Covariance(spatial_triu_size, nonlinearity) - tri1_vec = self.weight.new_zeros((triu1_len,)) - self.tri1_vec = Parameter(tri1_vec) + def forward(self, input: Tensor) -> Tensor: + U1 = self.channel.sqrt() + U2 = self.spatial.sqrt() + + # flatten over filter dims and contract + composite_weight = _contract(self.weight, U1.T, 1) + composite_weight = _contract( + torch.flatten(composite_weight, -2, -1), U2.T, -1 + ).reshape(self.weight.shape) + return self._conv_forward(input, composite_weight, self.bias) - tri2_vec = self.weight.new_zeros((triu2_len,)) - self.tri2_vec = Parameter(tri2_vec) + +class LowRankFactConv2d(nn.Conv2d): + def __init__( + self, + spatial_k, + channel_k, + in_channels: int, + out_channels: int, + kernel_size: _size_2_t, + stride: _size_2_t = 1, + padding: Union[str, _size_2_t] = 0, + dilation: _size_2_t = 1, + groups: int = 1, + bias: bool = True, + padding_mode: str = 'zeros', # TODO: refine this type + device=None, + dtype=None, + ) -> None: + # init as Conv2d + super().__init__( + in_channels, out_channels, kernel_size, stride, padding, dilation, + groups, bias, padding_mode, device, dtype) + self.spatial_k = spatial_k + self.channel_k = channel_k + # weight shape: (out_channels, in_channels // groups, *kernel_size) + new_weight = torch.empty_like(self.weight) + del self.weight # remove Parameter, create buffer + self.register_buffer("weight", new_weight) + nn.init.kaiming_normal_(self.weight) + + channel_triu_size = self.in_channels // self.groups + spatial_triu_size = self.kernel_size[0] * self.kernel_size[1] + + self.channel = LowRankCovariance(channel_triu_size, self.channel_k) + self.spatial = LowRankCovariance(spatial_triu_size, self.spatial_k) + #self.spatial = Covariance(spatial_triu_size) def forward(self, input: Tensor) -> Tensor: - U1 = self._tri_vec_to_mat(self.tri1_vec, self.in_channels // - self.groups, self.scat_idx1) - U2 = self._tri_vec_to_mat(self.tri2_vec, self.kernel_size[0] * self.kernel_size[1], - self.scat_idx2) + U1 = self.channel.sqrt() + U2 = self.spatial.sqrt() + # flatten over filter dims and contract composite_weight = _contract(self.weight, U1.T, 1) composite_weight = _contract( @@ -80,8 +114,266 @@ def forward(self, input: Tensor) -> Tensor: ).reshape(self.weight.shape) return self._conv_forward(input, composite_weight, self.bias) - def _tri_vec_to_mat(self, vec, n, scat_idx): - U = self.weight.new_zeros((n*n)).scatter_(0, scat_idx, vec).view(n, n) - U = torch.diagonal_scatter(U, U.diagonal().exp_()) - return U +class LowRankPlusDiagFactConv2d(nn.Conv2d): + def __init__( + self, + channel_k, + nonlinearity, + in_channels: int, + out_channels: int, + kernel_size: _size_2_t, + stride: _size_2_t = 1, + padding: Union[str, _size_2_t] = 0, + dilation: _size_2_t = 1, + groups: int = 1, + bias: bool = True, + padding_mode: str = 'zeros', # TODO: refine this type + device=None, + dtype=None, + ) -> None: + # init as Conv2d + super().__init__( + in_channels, out_channels, kernel_size, stride, padding, dilation, + groups, bias, padding_mode, device, dtype) + self.channel_k = channel_k + # weight shape: (out_channels, in_channels // groups, *kernel_size) + new_weight = torch.empty_like(self.weight) + del self.weight # remove Parameter, create buffer + self.register_buffer("weight", new_weight) + nn.init.kaiming_normal_(self.weight) + + channel_triu_size = self.in_channels // self.groups + spatial_triu_size = self.kernel_size[0] * self.kernel_size[1] + + self.channel = LowRankK1DiagCovariance(channel_triu_size, self.channel_k, nonlinearity) + self.spatial = Covariance(spatial_triu_size, nonlinearity) + + def forward(self, input: Tensor) -> Tensor: + U1 = self.channel.sqrt() + U2 = self.spatial.sqrt() + + # flatten over filter dims and contract + composite_weight = _contract(self.weight, U1.T, 1) + composite_weight = _contract( + torch.flatten(composite_weight, -2, -1), U2.T, -1 + ).reshape(self.weight.shape) + return self._conv_forward(input, composite_weight, self.bias) + +class LowRankK1FactConv2d(nn.Conv2d): + def __init__( + self, + channel_k, + nonlinearity, + in_channels: int, + out_channels: int, + kernel_size: _size_2_t, + stride: _size_2_t = 1, + padding: Union[str, _size_2_t] = 0, + dilation: _size_2_t = 1, + groups: int = 1, + bias: bool = True, + padding_mode: str = 'zeros', # TODO: refine this type + device=None, + dtype=None, + ) -> None: + # init as Conv2d + super().__init__( + in_channels, out_channels, kernel_size, stride, padding, dilation, + groups, bias, padding_mode, device, dtype) + # weight shape: (out_channels, in_channels // groups, *kernel_size) + new_weight = torch.empty_like(self.weight) + del self.weight # remove Parameter, create buffer + self.register_buffer("weight", new_weight) + nn.init.kaiming_normal_(self.weight) + + self.channel_k = channel_k + channel_triu_size = self.in_channels // self.groups + spatial_triu_size = self.kernel_size[0] * self.kernel_size[1] + + self.channel = LowRankK1Covariance(channel_triu_size, self.channel_k,\ + nonlinearity) + self.spatial = Covariance(spatial_triu_size, nonlinearity) + + + def forward(self, input: Tensor) -> Tensor: + U1 = self.channel.sqrt() + U2 = self.spatial.sqrt() + + # flatten over filter dims and contract + composite_weight = _contract(self.weight, U1.T, 1) + composite_weight = _contract( + torch.flatten(composite_weight, -2, -1), U2.T, -1 + ).reshape(self.weight.shape) + return self._conv_forward(input, composite_weight, self.bias) + +class OffDiagFactConv2d(nn.Conv2d): + def __init__( + self, + nonlinearity, + in_channels: int, + out_channels: int, + kernel_size: _size_2_t, + stride: _size_2_t = 1, + padding: Union[str, _size_2_t] = 0, + dilation: _size_2_t = 1, + groups: int = 1, + bias: bool = True, + padding_mode: str = 'zeros', # TODO: refine this type + device=None, + dtype=None, + ) -> None: + # init as Conv2d + super().__init__( + in_channels, out_channels, kernel_size, stride, padding, dilation, + groups, bias, padding_mode, device, dtype) + # weight shape: (out_channels, in_channels // groups, *kernel_size) + new_weight = torch.empty_like(self.weight) + del self.weight # remove Parameter, create buffer + self.register_buffer("weight", new_weight) + nn.init.kaiming_normal_(self.weight) + + channel_triu_size = self.in_channels // self.groups + spatial_triu_size = self.kernel_size[0] * self.kernel_size[1] + + self.channel = OffDiagCovariance(channel_triu_size, nonlinearity) + self.spatial = Covariance(spatial_triu_size, nonlinearity) + + + def forward(self, input: Tensor) -> Tensor: + U1 = self.channel.sqrt() + U2 = self.spatial.sqrt() + + # flatten over filter dims and contract + composite_weight = _contract(self.weight, U1.T, 1) + composite_weight = _contract( + torch.flatten(composite_weight, -2, -1), U2.T, -1 + ).reshape(self.weight.shape) + return self._conv_forward(input, composite_weight, self.bias) + +class DiagFactConv2d(nn.Conv2d): + def __init__( + self, + nonlinearity, + in_channels: int, + out_channels: int, + kernel_size: _size_2_t, + stride: _size_2_t = 1, + padding: Union[str, _size_2_t] = 0, + dilation: _size_2_t = 1, + groups: int = 1, + bias: bool = True, + padding_mode: str = 'zeros', # TODO: refine this type + device=None, + dtype=None + ) -> None: + # init as Conv2d + super().__init__( + in_channels, out_channels, kernel_size, stride, padding, dilation, + groups, bias, padding_mode, device, dtype) + # weight shape: (out_channels, in_channels // groups, *kernel_size) + new_weight = torch.empty_like(self.weight) + del self.weight # remove Parameter, create buffer + self.register_buffer("weight", new_weight) + nn.init.kaiming_normal_(self.weight) + + channel_triu_size = self.in_channels // self.groups + spatial_triu_size = self.kernel_size[0] * self.kernel_size[1] + + self.channel = DiagCovariance(channel_triu_size, nonlinearity) + self.spatial = DiagCovariance(spatial_triu_size, nonlinearity) + + + def forward(self, input: Tensor) -> Tensor: + U1 = self.channel.sqrt() + U2 = self.spatial.sqrt() + + # flatten over filter dims and contract + composite_weight = _contract(self.weight, U1.T, 1) + composite_weight = _contract( + torch.flatten(composite_weight, -2, -1), U2.T, -1 + ).reshape(self.weight.shape) + return self._conv_forward(input, composite_weight, self.bias) + +class DiagChanFactConv2d(nn.Conv2d): + def __init__( + self, + nonlinearity, + in_channels: int, + out_channels: int, + kernel_size: _size_2_t, + stride: _size_2_t = 1, + padding: Union[str, _size_2_t] = 0, + dilation: _size_2_t = 1, + groups: int = 1, + bias: bool = True, + padding_mode: str = 'zeros', # TODO: refine this type + device=None, + dtype=None + ) -> None: + # init as Conv2d + super().__init__( + in_channels, out_channels, kernel_size, stride, padding, dilation, + groups, bias, padding_mode, device, dtype) + # weight shape: (out_channels, in_channels // groups, *kernel_size) + new_weight = torch.empty_like(self.weight) + del self.weight # remove Parameter, create buffer + self.register_buffer("weight", new_weight) + nn.init.kaiming_normal_(self.weight) + + channel_triu_size = self.in_channels // self.groups + spatial_triu_size = self.kernel_size[0] * self.kernel_size[1] + + self.channel = DiagCovariance(channel_triu_size, nonlinearity) + self.spatial = Covariance(spatial_triu_size, nonlinearity) + + def forward(self, input: Tensor) -> Tensor: + U1 = self.channel.sqrt() + U2 = self.spatial.sqrt() + # flatten over filter dims and contract + composite_weight = _contract(self.weight, U1.T, 1) + composite_weight = _contract( + torch.flatten(composite_weight, -2, -1), U2.T, -1 + ).reshape(self.weight.shape) + return self._conv_forward(input, composite_weight, self.bias) + +class DiagDomFactConv2d(nn.Conv2d): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: _size_2_t, + stride: _size_2_t = 1, + padding: Union[str, _size_2_t] = 0, + dilation: _size_2_t = 1, + groups: int = 1, + bias: bool = True, + padding_mode: str = 'zeros', # TODO: refine this type + device=None, + dtype=None + ) -> None: + # init as Conv2d + super().__init__( + in_channels, out_channels, kernel_size, stride, padding, dilation, + groups, bias, padding_mode, device, dtype) + # weight shape: (out_channels, in_channels // groups, *kernel_size) + new_weight = torch.empty_like(self.weight) + del self.weight # remove Parameter, create buffer + self.register_buffer("weight", new_weight) + nn.init.kaiming_normal_(self.weight) + + channel_triu_size = self.in_channels // self.groups + spatial_triu_size = self.kernel_size[0] * self.kernel_size[1] + + self.channel = DiagonallyDominantCovariance(channel_triu_size) + self.spatial = Covariance(spatial_triu_size, "abs") + + def forward(self, input: Tensor) -> Tensor: + U1 = self.channel.sqrt() + U2 = self.spatial.sqrt() + # flatten over filter dims and contract + composite_weight = _contract(self.weight, U1.T, 1) + composite_weight = _contract( + torch.flatten(composite_weight, -2, -1), U2.T, -1 + ).reshape(self.weight.shape) + return self._conv_forward(input, composite_weight, self.bias) diff --git a/FactConv/cov.py b/FactConv/cov.py new file mode 100644 index 0000000..32bb69a --- /dev/null +++ b/FactConv/cov.py @@ -0,0 +1,281 @@ +import torch +from torch import Tensor +import torch.nn as nn +from torch.nn.parameter import Parameter +from torch.nn.common_types import _size_2_t +from typing import Optional, List, Tuple, Union +import math + +class Covariance(nn.Module): + """ Module for learnable weight covariance matrices + Input: the size of the covariance matrix and the nonlinearity + For channel covariances, triu_size is in_channels // groups + For spatial covariances, triu_size is kernel_size[0] * kernel_size[1] + """ + def __init__(self, triu_size, nonlinearity): + super().__init__() + self.triu_size = triu_size + + triu = torch.triu_indices(self.triu_size, self.triu_size, + dtype=torch.long) + scat_idx = triu[0] * self.triu_size + triu[1] + self.register_buffer("scat_idx", scat_idx, persistent=False) + + triu_len = triu.shape[1] + tri_vec = torch.zeros((triu_len,)) + + self.nl = nonlinearity + self.dict1 = {'abs': torch.abs, 'exp': torch.exp, 'x^2': torch.square, + '1/x^2': lambda x: torch.div(1, torch.square(x)), 'relu': + torch.relu, 'softplus': nn.Softplus(), 'identity': nn.Identity()} + self.dict2 = {'abs': 1, 'exp': 0, 'x^2': 1, '1/x^2': 1, 'relu': 1, + 'softplus': 0.5414, 'identity': 1} + self.func = self.dict1[self.nl] + diag = triu[0] == triu[1] + tri_vec[diag] = self.dict2[self.nl] + self.tri_vec = Parameter(tri_vec) + + + def cov(self): + R = self._tri_vec_to_mat(self.tri_vec, self.triu_size, self.scat_idx) + return R.T @ R + + def sqrt(self): + R = self._tri_vec_to_mat(self.tri_vec, self.triu_size, self.scat_idx) + return R + + def _tri_vec_to_mat(self, vec, n, scat_idx): + r = torch.zeros((n*n), device=self.tri_vec.device, + dtype=self.tri_vec.dtype) + r = r.view(n, n).fill_diagonal_(self.dict2[self.nl]).view(n*n) + r = r.scatter_(0, scat_idx, vec).view(n, n) + r = torch.diagonal_scatter(r, self.func(r.diagonal())) + + return r + +class LowRankCovariance(Covariance): + """ Module for learnable low-rank covariance matrices + Input: the size of the covariance matrix and the rank as a percentage + (Relative rank) + """ + def __init__(self, triu_size, rank): + super().__init__(triu_size) + self.k = math.ceil(rank * triu_size) + print(self.k) + + triu = torch.triu_indices(self.triu_size, self.triu_size, + dtype=torch.long) + mask = triu[0] < self.k + scat_idx = triu[0][mask] * self.triu_size + triu[1][mask] + self.register_buffer("scat_idx", scat_idx, persistent=False) + + triu_len = scat_idx.shape[0] + tri_vec = torch.zeros((triu_len,)) + self.tri_vec = Parameter(tri_vec) + +class LowRank1stFullCovariance(Covariance): + """ Module for learnable low-rank covariance matrices + Input: the size of the covariance matrix and the rank as a percentage + (Relative rank) + """ + def __init__(self, triu_size, rank): + super().__init__(triu_size) + if triu_size == 3: + self.k = 3 + print(self.k) + else: + self.k = math.ceil(rank * triu_size) + print(self.k) + + triu = torch.triu_indices(self.triu_size, self.triu_size, + dtype=torch.long) + mask = triu[0] < self.k + scat_idx = triu[0][mask] * self.triu_size + triu[1][mask] + self.register_buffer("scat_idx", scat_idx, persistent=False) + + triu_len = scat_idx.shape[0] + tri_vec = torch.zeros((triu_len,)) + self.tri_vec = Parameter(tri_vec) + + +class LowRankPlusDiagCovariance(Covariance): + """ Module for learning low-rank covariances plus the main diagonal + input: size of covariance matrix and the percentage rank (rows kept in upper triangular covaraince) + (Relative rank) + """ + + def __init__(self, triu_size, rank): + super().__init__(triu_size) + if triu_size == 3: + self.k = 3 + print(self.k) + else: + self.k = math.ceil(rank * triu_size) + print(self.k) + triu = torch.triu_indices(self.triu_size, self.triu_size, + dtype=torch.long) + mask = (triu[0] < self.k)|(triu[0]==triu[1]) + scat_idx = triu[0][mask] * self.triu_size + triu[1][mask] + self.register_buffer("scat_idx", scat_idx, persistent=False) + + triu_len = scat_idx.shape[0] + tri_vec = torch.zeros((triu_len,)) + self.tri_vec = Parameter(tri_vec) + +class LowRankK1Covariance(Covariance): + """ Module for learnable low-rank covariance matrices where + each layer has a rank integer + Input: the size of the covariance matrix and the rank (int) + (Absolute rank) + """ + def __init__(self, triu_size, rank, nonlinearity): + super().__init__(triu_size, nonlinearity) + + if triu_size == 3: + self.k = 3 + print(self.k) + elif triu_size < rank: + self.k = triu_size + print(self.k) + else: + self.k = rank + print(self.k) + + triu = torch.triu_indices(self.triu_size, self.triu_size, + dtype=torch.long) + mask = triu[0] < self.k + scat_idx = triu[0][mask] * self.triu_size + triu[1][mask] + self.register_buffer("scat_idx", scat_idx, persistent=False) + + triu_len = scat_idx.shape[0] + tri_vec = torch.zeros((triu_len,)) + + self.nl = nonlinearity + self.dict1 = {'abs': torch.abs, 'exp': torch.exp, 'x^2': torch.square, + '1/x^2': lambda x: torch.div(1, torch.square(x)), 'relu': + torch.relu, 'softplus': nn.Softplus(), 'identity': nn.Identity()} + self.dict2 = {'abs': 1, 'exp': 0, 'x^2': 1, '1/x^2': 1, 'relu': 1, + 'softplus': 0.5414, 'identity': 1} + self.func = self.dict1[self.nl] + diag = triu[0][mask] == triu[1][mask] + tri_vec[diag] = self.dict2[self.nl] + + self.tri_vec = Parameter(tri_vec) + +class LowRankK1DiagCovariance(Covariance): + """ Module for learnable low-rank covariance matrices where + each layer has a rank integer plus full diagonal + Input: the size of the covariance matrix and the rank (int) + (Absolute rank) + """ + def __init__(self, triu_size, rank, nonlinearity): + super().__init__(triu_size, nonlinearity) + + if triu_size == 3: + self.k = 3 + print(self.k) + elif triu_size < rank: + self.k = triu_size + print(self.k) + else: + self.k = rank + print(self.k) + + triu = torch.triu_indices(self.triu_size, self.triu_size, + dtype=torch.long) + mask = (triu[0] < self.k)|(triu[0]==triu[1]) + scat_idx = triu[0][mask] * self.triu_size + triu[1][mask] + self.register_buffer("scat_idx", scat_idx, persistent=False) + + triu_len = scat_idx.shape[0] + tri_vec = torch.zeros((triu_len,)) + + self.nl = nonlinearity + self.dict1 = {'abs': torch.abs, 'exp': torch.exp, 'x^2': torch.square, + '1/x^2': lambda x: torch.div(1, torch.square(x)), 'relu': + torch.relu, 'softplus': nn.Softplus(), 'identity': nn.Identity()} + self.dict2 = {'abs': 1, 'exp': 0, 'x^2': 1, '1/x^2': 1, 'relu': 1, + 'softplus': 0.5414, 'identity': 1} + self.func = self.dict1[self.nl] + + diag = triu[0][mask] == triu[1][mask] + tri_vec[diag] = self.dict2[self.nl] + self.tri_vec = Parameter(tri_vec) + +class OffDiagCovariance(Covariance): + """ Module for learning off-diagonal of weight covariance matrix + input: size of covariance matrix""" + def __init__(self, triu_size, nonlinearity): + super().__init__(triu_size, nonlinearity) + + triu = torch.triu_indices(self.triu_size, self.triu_size, + dtype=torch.long) + mask = (triu[0]+1) == triu[1] + scat_idx = triu[0][mask] * self.triu_size + triu[1][mask] + self.register_buffer("scat_idx", scat_idx, persistent=False) + + triu_len = scat_idx.shape[0] + tri_vec = torch.zeros((triu_len,)) + self.nl = nonlinearity + self.dict1 = {'abs': torch.abs, 'exp': torch.exp, 'x^2': torch.square, + '1/x^2': lambda x: torch.div(1, torch.square(x)), 'relu': + torch.relu, 'softplus': nn.Softplus(), 'identity': nn.Identity()} + self.dict2 = {'abs': 1, 'exp': 0, 'x^2': 1, '1/x^2': 1, 'relu': 1, + 'softplus': 0.5414, 'identity': 1} + self.func = self.dict1[self.nl] + tri_vec.fill_(self.dict2[self.nl]) + self.tri_vec = Parameter(tri_vec) + +class DiagCovariance(Covariance): + """ Module for learning diagonal of weight covariance matrix + input: size of covariance matrix""" + def __init__(self, triu_size, nonlinearity): + super().__init__(triu_size, nonlinearity) + + triu = torch.triu_indices(self.triu_size, self.triu_size, + dtype=torch.long) + mask = triu[0] == triu[1] + scat_idx = triu[0][mask] * self.triu_size + triu[1][mask] + self.register_buffer("scat_idx", scat_idx, persistent=False) + + triu_len = scat_idx.shape[0] + tri_vec = torch.zeros((triu_len,)) + self.nl = nonlinearity + self.dict1 = {'abs': torch.abs, 'exp': torch.exp, 'x^2': torch.square, + '1/x^2': lambda x: torch.div(1, torch.square(x)), 'relu': + torch.relu, 'softplus': nn.Softplus(), 'identity': nn.Identity()} + self.dict2 = {'abs': 1, 'exp': 0, 'x^2': 1, '1/x^2': 1, 'relu': 1, + 'softplus': 0.5414, 'identity': 1} + self.func = self.dict1[self.nl] + tri_vec.fill_(self.dict2[self.nl]) + self.tri_vec = Parameter(tri_vec) + + +class DiagonallyDominantCovariance(Covariance): + """ Module for learnable diagonally dominant weight covariance matrices + Input: the size of the covariance matrix + """ + def __init__(self, triu_size): + super().__init__(triu_size, "abs") + self.triu_size = triu_size + triu = torch.triu_indices(self.triu_size, self.triu_size, + dtype=torch.long) + scat_idx = triu[0] * self.triu_size + triu[1] + self.register_buffer("scat_idx", scat_idx, persistent=False) + + triu_len = triu.shape[1] + tri_vec = torch.zeros((triu_len,)) + + diag = triu[0] == triu[1] + tri_vec[diag] = 1 + self.tri_vec = Parameter(tri_vec) + + def _tri_vec_to_mat(self, vec, n, scat_idx): + r = torch.zeros((n*n), device=self.tri_vec.device, + dtype=self.tri_vec.dtype) + r = r.view(n, n).fill_diagonal_(1).view(n*n) + r = r.scatter_(0, scat_idx, vec).view(n, n) + r = torch.diagonal_scatter(r, torch.abs(torch.sum(r, dim=1))) +# print(r) + return r + diff --git a/FactConv/launched.sh b/FactConv/launched.sh new file mode 100755 index 0000000..484fa99 --- /dev/null +++ b/FactConv/launched.sh @@ -0,0 +1,38 @@ +#!/bin/bash +#width=(0.125 0.25 0.5 1.0 2.0 4.0) +#nonlins=("exp" "abs" "x^2" "1/x^2" "softplus" "identity" "relu") +width=(1) +rank=(1 2 4 8 16 32 64 128 256 512 1024) +seed=(0 1 2) +#seed=(3) + +for i in ${width[@]} +do + for j in ${rank[@]} + do + for k in ${seed[@]} + do + #sbatch setoff.sh --width $i --seed $j --net resnet18 + #sbatch setoff.sh --width $i --seed $j --net fact_resnet18 + #sbatch setoff.sh --width $i --seed $j --net fact_us_resnet18 + #sbatch setoff.sh --width $i --seed $j --net fact_uc_resnet18 + #sbatch setoff.sh --width $i --seed $j --net fact_us_uc_resnet18 + # WHERE IS THIS VIVIAN 🧐🧐🤨🤨 + # NVM good job Vivian +# sbatch setoff.sh $i $j $k resnet18-fact-lr-K1-no-affines +# sbatch setoff.sh $i $j $k resnet18-fact-lr-K1 +# sbatch setoff.sh $i $j $k resnet18-diagdom +# sbatch setoff.sh $i $j $k resnet18-diagchan +# sbatch setoff.sh $i $j $k resnet18-lr-diag + sbatch setoff.sh $i $j $k wrn-lr-K1 + sbatch setoff.sh $i $j $k wrn-lr-diag +# sbatch setoff.sh $i $j $k resnet18-fact-us +# sbatch setoff.sh $i $j $k resnet18-fact-uc +# sbatch setoff.sh $i $j $k resnet18-fact-usuc + #sbatch setoff.sh $i $j fact_diagchan_us_resnet18 + #sbatch setoff.sh $i $j fact_diagchan_uc_resnet18 + #sbatch setoff.sh $i $j fact_diagchan_us_uc_resnet18 + done + done +done + diff --git a/FactConv/models/__init__.py b/FactConv/models/__init__.py index 434ce4f..ac51710 100644 --- a/FactConv/models/__init__.py +++ b/FactConv/models/__init__.py @@ -1,19 +1,59 @@ from .resnet import ResNet18 -from .function_utils import replace_layers_factconv2d, turn_off_covar_grad, replace_layers_scale, init_V1_layers +from .resnext import Network +from .wideresnet import WideResNet +from .convnext import convnext_base +from .conv2next import conv2next_base +from .LC_models import LC_CIFAR10 +from .function_utils import replace_layers_factconv2d,\ +replace_layers_diagfactconv2d, replace_layers_diagchanfactconv2d,\ +turn_off_covar_grad, replace_layers_scale, init_V1_layers,\ +replace_layers_lowrank, replace_layers_lowrankplusdiag,\ +replace_layers_lowrankK1, replace_layers_offdiag,\ +replace_affines, replace_layers_diagdom def define_models(args): if 'resnet18' in args.net: model = ResNet18() + if 'resnext' in args.net: + model = Network() + if 'wrn' in args.net: + model = WideResNet(depth=28, num_classes=10, widen_factor=16, dropRate=0) + if 'convnext' in args.net: + model = convnext_base() + if 'conv2next' in args.net: + model = conv2next_base() + if 'rsn' in args.net: + model = LC_CIFAR10(hidden_dim=100, size=2, spatial_freq=0.1, scale=1, + bias=True, freeze_spatial=False, freeze_channel=False, + spatial_init='V1') if args.width != 1: replace_layers_scale(model, args.width) if 'fact' in args.net: - replace_layers_factconv2d(model) - if "v1" in args.net: + replace_layers_factconv2d(model, args.nonlinearity) + if 'diag' in args.net: + replace_layers_diagfactconv2d(model, args.nonlinearity) + if 'diagchan' in args.net: + replace_layers_diagchanfactconv2d(model, args.nonlinearity) + if 'lowrank' in args.net: + replace_layers_lowrank(model, args.spatial_k, args.channel_k) + if 'lr-diag' in args.net: + replace_layers_lowrankplusdiag(model, args.channel_k, args.nonlinearity) + if 'lr-K1' in args.net: + replace_layers_lowrankK1(model, args.channel_k, args.nonlinearity) + if 'diagdom' in args.net: + replace_layers_diagdom(model) + if 'no-affines' in args.net: + replace_affines(model) + if 'offdiag' in args.net: + replace_layers_offdiag(model, args.nonlinearity) + if 'v1' in args.net: init_V1_layers(model, bias=False) - if "us" in args.net: - turn_off_covar_grad(model, "spatial") - if "uc" in args.net: - turn_off_covar_grad(model, "channel") + if 'us' in args.net: + turn_off_covar_grad(model, 'spatial') + print('US') + if 'uc' in args.net: + turn_off_covar_grad(model, 'channel') + print('UC') return model diff --git a/FactConv/models/function_utils.py b/FactConv/models/function_utils.py index c642039..8595888 100644 --- a/FactConv/models/function_utils.py +++ b/FactConv/models/function_utils.py @@ -1,6 +1,11 @@ import torch import torch.nn as nn -from conv_modules import FactConv2d +from conv_modules import FactConv2d, DiagFactConv2d, DiagChanFactConv2d,\ +LowRankFactConv2d, LowRankPlusDiagFactConv2d, LowRankK1FactConv2d, \ +OffDiagFactConv2d, DiagDomFactConv2d +from cov import Covariance, LowRankCovariance, LowRankPlusDiagCovariance,\ +LowRankK1Covariance, OffDiagCovariance, DiagCovariance,\ +DiagonallyDominantCovariance from V1_covariance import V1_init @@ -15,7 +20,7 @@ def recurse_preorder(model, callback): return model -def replace_layers_factconv2d(model): +def replace_layers_factconv2d(model, nonlinearity): ''' Replace nn.Conv2d layers with FactConv2d ''' @@ -24,8 +29,10 @@ def _replace_layers_factconv2d(module): ## simple module new_module = FactConv2d( in_channels=module.in_channels, + nonlinearity=nonlinearity, out_channels=module.out_channels, kernel_size=module.kernel_size, + groups=module.groups, stride=module.stride, padding=module.padding, bias=True if module.bias is not None else False) old_sd = module.state_dict() @@ -38,6 +45,55 @@ def _replace_layers_factconv2d(module): return recurse_preorder(model, _replace_layers_factconv2d) +def replace_layers_diagfactconv2d(model, nonlinearity): + ''' + Replace nn.Conv2d layers with DiagFactConv2d + ''' + def _replace_layers_diagfactconv2d(module): + if isinstance(module, nn.Conv2d): + ## simple module + new_module = DiagFactConv2d( + in_channels=module.in_channels, + groups=module.groups, + nonlinearity=nonlinearity, + out_channels=module.out_channels, + kernel_size=module.kernel_size, + stride=module.stride, padding=module.padding, + bias=True if module.bias is not None else False) + old_sd = module.state_dict() + new_sd = new_module.state_dict() + new_sd['weight'] = old_sd['weight'] + if module.bias is not None: + new_sd['bias'] = old_sd['bias'] + new_module.load_state_dict(new_sd) + return new_module + return recurse_preorder(model, _replace_layers_diagfactconv2d) + +def replace_layers_diagchanfactconv2d(model,nonlinearity): + ''' + Replace nn.Conv2d layers with DiagChanFactConv2d + ''' + def _replace_layers_diagchanfactconv2d(module): + if isinstance(module, nn.Conv2d): + ## simple module + new_module = DiagChanFactConv2d( + in_channels=module.in_channels, + out_channels=module.out_channels, + groups=module.groups, + kernel_size=module.kernel_size, + nonlinearity=nonlinearity, + stride=module.stride, padding=module.padding, + bias=True if module.bias is not None else False) + old_sd = module.state_dict() + new_sd = new_module.state_dict() + new_sd['weight'] = old_sd['weight'] + if module.bias is not None: + new_sd['bias'] = old_sd['bias'] + new_module.load_state_dict(new_sd) + return new_module + return recurse_preorder(model, _replace_layers_diagchanfactconv2d) + + def replace_affines(model): ''' Set BatchNorm2d layers to have 'affine=False' @@ -70,7 +126,7 @@ def _replace_layers_scale(module): out_channels=int(module.out_channels*scale), kernel_size=module.kernel_size, stride=module.stride, padding=module.padding, - groups = module.groups, + groups=module.groups, bias=True if module.bias is not None else False) return new_module if isinstance(module, nn.BatchNorm2d): @@ -107,6 +163,7 @@ def _replace_layers_fact_with_conv(module): in_channels=module.in_channels, out_channels=module.out_channels, kernel_size=module.kernel_size, + groups=module.groups, stride=module.stride, padding=module.padding, bias=True if module.bias is not None else False) old_sd = module.state_dict() @@ -144,19 +201,20 @@ def _replace_layers_fact_with_conv(module): def turn_off_covar_grad(model, covariance): + print("Unlearnable ", covariance) ''' - Turn off gradients in tri1_vec or tri2_vec to turn off - channel or spatial covariance learning + Turn off gradients in the tri_vec param in + the Covariance module to disable channel + or spatial covariance learning ''' def _turn_off_covar_grad(module): - if isinstance(module, FactConv2d): - for name, param in module.named_parameters(): - if covariance == "channel": - if "tri1_vec" in name: - param.requires_grad = False - if covariance == "spatial": - if "tri2_vec" in name: - param.requires_grad = False + for name, mod in module.named_modules(): + if isinstance(mod, Covariance): + if covariance == name: + for name, param in mod.named_parameters(): + if "tri_vec" in name: + param.requires_grad = False + return recurse_preorder(model, _turn_off_covar_grad) @@ -191,3 +249,122 @@ def _init_V1_layers(module): param.requires_grad = False return recurse_preorder(model, _init_V1_layers) +def replace_layers_lowrank(model, spatial_k, channel_k): + ''' + Replace nn.Conv2d layers with LowRankFactConv2d + ''' + def _replace_layers_lowrank(module): + if isinstance(module, nn.Conv2d): + ## simple module + new_module = LowRankFactConv2d( + in_channels=module.in_channels, + out_channels=module.out_channels, + kernel_size=module.kernel_size, + groups=module.groups, + stride=module.stride, padding=module.padding, + spatial_k=spatial_k, channel_k=channel_k, + bias=True if module.bias is not None else False) + old_sd = module.state_dict() + new_sd = new_module.state_dict() + new_sd['weight'] = old_sd['weight'] + if module.bias is not None: + new_sd['bias'] = old_sd['bias'] + new_module.load_state_dict(new_sd) + return new_module + return recurse_preorder(model, _replace_layers_lowrank) + + +def replace_layers_lowrankplusdiag(model, channel_k, nonlinearity): + ''' + Replace nn.Conv2d layers with LowRankFactConv2d + ''' + def _replace_layers_lowrankplusdiag(module): + if isinstance(module, nn.Conv2d): + ## simple module + new_module = LowRankPlusDiagFactConv2d( + in_channels=module.in_channels, + out_channels=module.out_channels, + kernel_size=module.kernel_size, + stride=module.stride, padding=module.padding, + groups=module.groups, + channel_k=channel_k, nonlinearity=nonlinearity, + bias=True if module.bias is not None else False) + old_sd = module.state_dict() + new_sd = new_module.state_dict() + new_sd['weight'] = old_sd['weight'] + if module.bias is not None: + new_sd['bias'] = old_sd['bias'] + new_module.load_state_dict(new_sd) + return new_module + return recurse_preorder(model, _replace_layers_lowrankplusdiag) + +def replace_layers_lowrankK1(model, channel_k, nonlinearity): + ''' + Replace nn.Conv2d layers with LowRankK1FactConv2d + ''' + def _replace_layers_lowrankK1(module): + if isinstance(module, nn.Conv2d): + ## simple module + new_module = LowRankK1FactConv2d( + in_channels=module.in_channels, + out_channels=module.out_channels, + kernel_size=module.kernel_size, + groups=module.groups, + stride=module.stride, padding=module.padding, + channel_k=channel_k, nonlinearity=nonlinearity, + bias=True if module.bias is not None else False) + old_sd = module.state_dict() + new_sd = new_module.state_dict() + new_sd['weight'] = old_sd['weight'] + if module.bias is not None: + new_sd['bias'] = old_sd['bias'] + new_module.load_state_dict(new_sd) + return new_module + return recurse_preorder(model, _replace_layers_lowrankK1) + +def replace_layers_offdiag(model, nonlinearity): + ''' + Replace nn.Conv2d layers with OffDiagFactConv2d + ''' + def _replace_layers_offdiag(module): + if isinstance(module, nn.Conv2d): + ## simple module + new_module = OffDiagFactConv2d( + in_channels=module.in_channels, + out_channels=module.out_channels, + kernel_size=module.kernel_size, + groups=module.groups, + stride=module.stride, padding=module.padding, + nonlinearity=nonlinearity, + bias=True if module.bias is not None else False) + old_sd = module.state_dict() + new_sd = new_module.state_dict() + new_sd['weight'] = old_sd['weight'] + if module.bias is not None: + new_sd['bias'] = old_sd['bias'] + new_module.load_state_dict(new_sd) + return new_module + return recurse_preorder(model, _replace_layers_offdiag) + +def replace_layers_diagdom(model): + ''' + Replace nn.Conv2d layers with DiagDomFactConv2d + ''' + def _replace_layers_diagdom(module): + if isinstance(module, nn.Conv2d): + ## simple module + new_module = DiagDomFactConv2d( + in_channels=module.in_channels, + out_channels=module.out_channels, + kernel_size=module.kernel_size, + stride=module.stride, padding=module.padding, + groups=module.groups, + bias=True if module.bias is not None else False) + old_sd = module.state_dict() + new_sd = new_module.state_dict() + new_sd['weight'] = old_sd['weight'] + if module.bias is not None: + new_sd['bias'] = old_sd['bias'] + new_module.load_state_dict(new_sd) + return new_module + return recurse_preorder(model, _replace_layers_diagdom) diff --git a/FactConv/models/initializer.py b/FactConv/models/initializer.py new file mode 100644 index 0000000..f2c5148 --- /dev/null +++ b/FactConv/models/initializer.py @@ -0,0 +1,26 @@ +from typing import Callable + +import torch.nn as nn + + +def create_initializer(mode: str) -> Callable: + if mode in ['kaiming_fan_out', 'kaiming_fan_in']: + mode = mode[8:] + + def initializer(module): + if isinstance(module, nn.Conv2d): + nn.init.kaiming_normal_(module.weight.data, + mode=mode, + nonlinearity='relu') + elif isinstance(module, nn.BatchNorm2d): + nn.init.ones_(module.weight.data) + nn.init.zeros_(module.bias.data) + elif isinstance(module, nn.Linear): + nn.init.kaiming_normal_(module.weight.data, + mode=mode, + nonlinearity='relu') + nn.init.zeros_(module.bias.data) + else: + raise ValueError() + + return initializer diff --git a/FactConv/models/resnext.py b/FactConv/models/resnext.py new file mode 100644 index 0000000..9349dbb --- /dev/null +++ b/FactConv/models/resnext.py @@ -0,0 +1,172 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .initializer import create_initializer + + +class BottleneckBlock(nn.Module): + expansion = 4 + + def __init__(self, in_channels, out_channels, stride, stage_index, + base_channels, cardinality): + super().__init__() + + bottleneck_channels = cardinality * base_channels * 2**stage_index + + self.conv1 = nn.Conv2d(in_channels, + bottleneck_channels, + kernel_size=1, + stride=1, + padding=0, + bias=False) + self.bn1 = nn.BatchNorm2d(bottleneck_channels) + + self.conv2 = nn.Conv2d( + bottleneck_channels, + bottleneck_channels, + kernel_size=3, + stride=stride, # downsample with 3x3 conv + padding=1, + groups=cardinality, + bias=False) + self.bn2 = nn.BatchNorm2d(bottleneck_channels) + + self.conv3 = nn.Conv2d(bottleneck_channels, + out_channels, + kernel_size=1, + stride=1, + padding=0, + bias=False) + self.bn3 = nn.BatchNorm2d(out_channels) + + self.shortcut = nn.Sequential() # identity + if in_channels != out_channels: + self.shortcut.add_module( + 'conv', + nn.Conv2d( + in_channels, + out_channels, + kernel_size=1, + stride=stride, # downsample + padding=0, + bias=False)) + self.shortcut.add_module('bn', nn.BatchNorm2d(out_channels)) # BN + + def forward(self, x): + y = F.relu(self.bn1(self.conv1(x)), inplace=True) + y = F.relu(self.bn2(self.conv2(y)), inplace=True) + y = self.bn3(self.conv3(y)) # not apply ReLU + y += self.shortcut(x) + y = F.relu(y, inplace=True) # apply ReLU after addition + return y + + +class Network(nn.Module): + def __init__(self): + super().__init__() + + #model_config = config.model.resnext + #depth = model_config.depth + #initial_channels = model_config.initial_channels + #self.base_channels = model_config.base_channels + #self.cardinality = model_config.cardinality + depth = 29 + initial_channels = 64 + self.cardinality = 8 + self.base_channels = 64 + + n_blocks_per_stage = (depth - 2) // 9 + assert n_blocks_per_stage * 9 + 2 == depth + block = BottleneckBlock + + n_channels = [ + initial_channels, + initial_channels * block.expansion, + initial_channels * 2 * block.expansion, + initial_channels * 4 * block.expansion, + ] + + #self.conv = nn.Conv2d(config.dataset.n_channels, + self.conv = nn.Conv2d(3, + n_channels[0], + kernel_size=3, + stride=1, + padding=1, + bias=False) + self.bn = nn.BatchNorm2d(n_channels[0]) + + self.stage1 = self._make_stage(n_channels[0], + n_channels[1], + n_blocks_per_stage, + 0, + stride=1) + self.stage2 = self._make_stage(n_channels[1], + n_channels[2], + n_blocks_per_stage, + 1, + stride=2) + self.stage3 = self._make_stage(n_channels[2], + n_channels[3], + n_blocks_per_stage, + 2, + stride=2) + + # compute conv feature size + with torch.no_grad(): + dummy_data = torch.zeros( + #(1, config.dataset.n_channels, config.dataset.image_size, + (1, 3, 32, 32), + # config.dataset.image_size), + dtype=torch.float32) + self.feature_size = self._forward_conv(dummy_data).view( + -1).shape[0] + + #self.fc = nn.Linear(self.feature_size, config.dataset.n_classes) + self.fc = nn.Linear(self.feature_size, 10) + + # initialize weights + #initializer = create_initializer(config.model.init_mode) + initializer = create_initializer("kaiming_fan_out") + self.apply(initializer) + + def _make_stage(self, in_channels, out_channels, n_blocks, stage_index, + stride): + stage = nn.Sequential() + for index in range(n_blocks): + block_name = f'block{index + 1}' + if index == 0: + stage.add_module( + block_name, + BottleneckBlock( + in_channels, + out_channels, + stride, # downsample + stage_index, + self.base_channels, + self.cardinality)) + else: + stage.add_module( + block_name, + BottleneckBlock( + out_channels, + out_channels, + 1, # no downsampling + stage_index, + self.base_channels, + self.cardinality)) + return stage + + def _forward_conv(self, x): + x = F.relu(self.bn(self.conv(x)), inplace=True) + x = self.stage1(x) + x = self.stage2(x) + x = self.stage3(x) + x = F.adaptive_avg_pool2d(x, output_size=1) + return x + + def forward(self, x): + x = self._forward_conv(x) + x = x.view(x.size(0), -1) + x = self.fc(x) + return x diff --git a/FactConv/pytorch_cifar.py b/FactConv/pytorch_cifar.py index 893f38a..d90aadd 100644 --- a/FactConv/pytorch_cifar.py +++ b/FactConv/pytorch_cifar.py @@ -14,8 +14,9 @@ from models import define_models def save_model(args, model): - src= "../saved-models/ResNets/" - model_dir = src + args.name + src="/home/mila/v/vivian.white/scratch/v1-models/saved-models/low-rank/" + model_dir = src + args.net + "_rank" + str(args.channel_k) + "_width" + str(args.width) + "_" + args.nonlinearity+ "-seed" + str(args.seed) + print("Model dir: ", model_dir) os.makedirs(model_dir, exist_ok=True) torch.save(model.state_dict(), model_dir+ "/model.pt") @@ -32,6 +33,9 @@ def save_model(args, model): help='filename for saved model') parser.add_argument('--seed', default=0, type=int, help='seed to use') parser.add_argument('--width', type=float, default=1, help='resnet width scale factor') +parser.add_argument('--spatial_k', type=float, default=1, help='%spatial low-rank') +parser.add_argument('--channel_k', type=float, default=1, help='%channel low-rank') +parser.add_argument('--nonlinearity', type=str, default='abs') args = parser.parse_args() @@ -70,25 +74,35 @@ def save_model(args, model): print('==> Building model..') net = define_models(args) -run_name = args.net +run_name = "{}_rank{}_seed{}_width{}_{}".format(args.net, args.channel_k, args.seed,\ + args.width, args.nonlinearity) print("Args.net: ", args.net) print("Net: ", net) +print("Run name: ", run_name) + +param_count = sum(p.numel() for p in net.parameters() if p.requires_grad) +print("Num Learnable Params: ", param_count) + set_seeds(args.seed) net = net.to(device) -wandb_dir = "../../wandb" +wandb_dir = "/home/mila/v/vivian.white/scratch/v1-models/wandb" os.makedirs(wandb_dir, exist_ok=True) os.chdir(wandb_dir) -run = wandb.init(project="FactConv", config=args, - group="pytorch_cifar", name=run_name, dir=wandb_dir) + + + +run = wandb.init(project="factconv", config=args, group="lowrankabs", name=run_name, dir=wandb_dir) #wandb.watch(net, log='all', log_freq=1) +run.log({'param_count':param_count}) criterion = nn.CrossEntropyLoss() optimizer = optim.SGD(net.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4) -scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=200) +scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer,\ + T_max=args.num_epochs) # Training def train(epoch): @@ -111,6 +125,9 @@ def train(epoch): progress_bar(batch_idx, len(trainloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)' % (train_loss/(batch_idx+1), 100.*correct/total, correct, total)) + + acc = 100.*correct/total + run.log({"train_accuracy":acc}) def test(epoch): global best_acc @@ -134,7 +151,7 @@ def test(epoch): # Save checkpoint. acc = 100.*correct/total - run.log({"accuracy":acc}) + run.log({"test_accuracy":acc}) if acc > best_acc: print('Saving..') state = { diff --git a/FactConv/setoff.sh b/FactConv/setoff.sh new file mode 100644 index 0000000..2aa95b1 --- /dev/null +++ b/FactConv/setoff.sh @@ -0,0 +1,13 @@ +#!/bin/bash +#SBATCH --gres=gpu:a100l:1 +#SBATCH --constraint="ampere" +#SBATCH -c 8 +#SBATCH --mem=20G +#SBATCH -t 20:00:00 +#SBATCH --output slurm/%j.out +#SBATCH --partition long + +module load anaconda/3 +conda activate random_features + +python pytorch_cifar.py --nonlinearity "abs" --width $1 --channel_k $2 --seed $3 --net $4