Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
77 commits
Select commit Hold shift + click to select a range
fb83f37
adding files from muawi_resnets
vivianwhite Apr 21, 2024
0f10d2c
removing comments and unused lines
vivianwhite Apr 21, 2024
e3ab5ba
remove comments and unused lines
vivianwhite Apr 21, 2024
8411f74
fixed and tested conv_modules
vivianwhite Apr 21, 2024
fe6cbc7
add whitespace
vivianwhite Apr 21, 2024
a472f7d
add whitespace
vivianwhite Apr 21, 2024
2ec009e
trying to setup define_models
vivianwhite Apr 21, 2024
122c758
fixed models dir
vivianwhite Apr 21, 2024
28841fe
Delete models directory
vivianwhite Apr 21, 2024
7ddc12c
add utils file
vivianwhite Apr 21, 2024
55670ba
fix white spaces
vivianwhite Apr 21, 2024
a779085
fixing V1 cifar10 model
vivianwhite Apr 21, 2024
b873254
add replace_linears function
vivianwhite Apr 22, 2024
204bf3a
adding recursive functions script
vivianwhite Apr 22, 2024
2aeaf98
some updates, more to come tomorrow
vivianwhite Apr 23, 2024
6a236c5
Delete refactor/resnet.py
vivianwhite Apr 22, 2024
a9317a7
update define_model options
vivianwhite Apr 23, 2024
11d9477
normalization back to normal
vivianwhite Apr 23, 2024
6905850
updates and adding turn_off_grad function
vivianwhite Apr 23, 2024
aa28aab
adding V1-init functionality
vivianwhite Apr 23, 2024
e537fda
updating v1-init
vivianwhite Apr 23, 2024
15b8cd5
cleanup
vivianwhite Apr 24, 2024
29978b4
initial commit of refactored rainbow (have not run anything yet)
MuawizChaudhary Apr 24, 2024
82180c4
minimal changes to ensure this script runs
MuawizChaudhary Apr 24, 2024
09aade2
cleanup, modified print statements, added seed argument, removed widt…
MuawizChaudhary Apr 24, 2024
6c6dcc9
removed affine argument
MuawizChaudhary Apr 24, 2024
15422ab
moved recursive modules around
MuawizChaudhary Apr 24, 2024
1397314
added recurse_preorder. made each replace function a wrapper that ret…
MuawizChaudhary Apr 24, 2024
44da047
removed fact_2_conv
MuawizChaudhary Apr 24, 2024
a4bf746
changes
MuawizChaudhary Apr 24, 2024
fff0899
did TODOs
vivianwhite Apr 24, 2024
f36a459
noting origin of code
vivianwhite May 1, 2024
38de471
added class conditional decorators. added looping over multiple callb…
MuawizChaudhary Apr 25, 2024
0cca6a1
Revert to 09f4833
MuawizChaudhary Apr 26, 2024
42f7542
separated rainbow functionality from runner script
MuawizChaudhary Apr 26, 2024
fff82a5
changes to rainbow sampling so we can sample multiple times and calcu…
MuawizChaudhary Apr 26, 2024
1dd2048
refactored a decent amount of rainbow sampling
MuawizChaudhary Apr 27, 2024
5f44973
fact-conv changes
MuawizChaudhary May 1, 2024
657db4f
reduce amount of code in rainbow
MuawizChaudhary May 1, 2024
c680aea
initialize with empty_like instead
MuawizChaudhary May 1, 2024
edda0b4
added verbose option
MuawizChaudhary May 1, 2024
68e84e4
removed my beloved og
MuawizChaudhary May 2, 2024
07803e0
deleted learnable_cov.py which had been renamed to V1_covariance.py
MuawizChaudhary May 2, 2024
6594292
width scaling should be first
MuawizChaudhary May 2, 2024
c078cfc
making rainbow sampling more like a properly called class. similar to…
MuawizChaudhary May 2, 2024
642c480
removed specification of device from init, rely instead on layer to p…
MuawizChaudhary May 2, 2024
d3116ec
formatting changes
MuawizChaudhary May 2, 2024
992e98b
formatting
MuawizChaudhary May 2, 2024
0d0d60f
added README.md
MuawizChaudhary May 2, 2024
2b48a03
added citation to README.md
MuawizChaudhary May 2, 2024
9dc5666
Update README.md
MuawizChaudhary May 2, 2024
01b39c8
renaming folder
vivianwhite May 3, 2024
c92f6ac
Delete layers directory
vivianwhite May 3, 2024
e05bbcd
Delete refactor directory
vivianwhite May 3, 2024
0a1c34b
moving scripts
vivianwhite May 3, 2024
428ae58
Delete V1-models/scripts/s_f_sweep directory
vivianwhite May 3, 2024
c6fa1a4
updating LC RSN scripts
vivianwhite May 4, 2024
345a598
removing hardcoded stuff
vivianwhite May 4, 2024
1c5295f
fixed save dir
vivianwhite May 4, 2024
1b080e9
fixed save dir
vivianwhite May 4, 2024
2ae3414
new experiments
MuawizChaudhary May 8, 2024
378cbd9
implementing diag factconv2d
vivianwhite May 8, 2024
2ffc071
minor changes
MuawizChaudhary May 8, 2024
74aa5a3
conv_modules.py
vivianwhite May 8, 2024
2366ef8
add diagchan, call by setting args.net='resnet18-diagchan'
vivianwhite May 8, 2024
c720d95
added low-rank setup
vivianwhite Jun 25, 2024
4a04e1d
low-rank experiments
vivianwhite Jun 26, 2024
0342b2e
adding covariance module
vivianwhite Jun 26, 2024
a8c7b5d
adding learnable low-rank diagonal covariance
vivianwhite Jun 28, 2024
ec927a4
adding low rank experiments
vivianwhite Jun 28, 2024
5cceccb
updating covariance modules
vivianwhite Jun 29, 2024
5513509
logging param count
vivianwhite Jul 1, 2024
f5f70f9
update turn_off_covar_grad function
vivianwhite Jul 1, 2024
0a2e945
adding LearnableCov.py back
vivianwhite Jul 1, 2024
1f9f43c
adding nonlinearities study
vivianwhite Jul 15, 2024
db9ba1b
adding resnext model
vivianwhite Jul 20, 2024
de9820b
added diagonally dominant experiment and updated scripts
vivianwhite Aug 22, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
350 changes: 321 additions & 29 deletions FactConv/conv_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
from torch.nn.parameter import Parameter
from torch.nn.common_types import _size_2_t
from typing import Optional, List, Tuple, Union
from cov import Covariance, LowRankCovariance, LowRankPlusDiagCovariance,\
LowRankK1Covariance, OffDiagCovariance, DiagCovariance,\
LowRankK1DiagCovariance, DiagonallyDominantCovariance

"""
The function below is copied directly from
Expand All @@ -19,6 +22,7 @@ def _contract(tensor, matrix, axis):
class FactConv2d(nn.Conv2d):
def __init__(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Throughout this entire script, you're re-using a lot of code. I wonder if you could subclass FactConv extensions using this factconv module.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

as you mentioned, eventually we should refactor the other fact convs to use the covariance module.

self,
nonlinearity,
in_channels: int,
out_channels: int,
kernel_size: _size_2_t,
Expand All @@ -41,47 +45,335 @@ def __init__(
self.register_buffer("weight", new_weight)
nn.init.kaiming_normal_(self.weight)

self.in_features = self.in_channels // self.groups * \
self.kernel_size[0] * self.kernel_size[1]
triu1 = torch.triu_indices(self.in_channels // self.groups,
self.in_channels // self.groups,
device=self.weight.device,
dtype=torch.long)
scat_idx1 = triu1[0]*self.in_channels//self.groups + triu1[1]
self.register_buffer("scat_idx1", scat_idx1, persistent=False)
# self.in_features = self.in_channels // self.groups * \
# self.kernel_size[0] * self.kernel_size[1]

triu2 = torch.triu_indices(self.kernel_size[0] * self.kernel_size[1],
self.kernel_size[0]
* self.kernel_size[1],
device=self.weight.device,
dtype=torch.long)
scat_idx2 = triu2[0]*self.kernel_size[0]*self.kernel_size[1] + triu2[1]
self.register_buffer("scat_idx2", scat_idx2, persistent=False)
channel_triu_size = self.in_channels // self.groups
spatial_triu_size = self.kernel_size[0] * self.kernel_size[1]

triu1_len = triu1.shape[1]
triu2_len = triu2.shape[1]
self.channel = Covariance(channel_triu_size, nonlinearity)
self.spatial = Covariance(spatial_triu_size, nonlinearity)

tri1_vec = self.weight.new_zeros((triu1_len,))
self.tri1_vec = Parameter(tri1_vec)
def forward(self, input: Tensor) -> Tensor:
U1 = self.channel.sqrt()
U2 = self.spatial.sqrt()

# flatten over filter dims and contract
composite_weight = _contract(self.weight, U1.T, 1)
composite_weight = _contract(
torch.flatten(composite_weight, -2, -1), U2.T, -1
).reshape(self.weight.shape)
return self._conv_forward(input, composite_weight, self.bias)

tri2_vec = self.weight.new_zeros((triu2_len,))
self.tri2_vec = Parameter(tri2_vec)

class LowRankFactConv2d(nn.Conv2d):
    """Conv2d whose effective filters are coloured white noise.

    The usual learnable conv weight is demoted to a fixed buffer (a frozen
    Gaussian filter bank); at forward time it is coloured by the square
    roots of two learnable low-rank covariance factors — one over input
    channels (rank ``channel_k``) and one over the spatial kernel
    (rank ``spatial_k``).
    """

    def __init__(
        self,
        spatial_k,
        channel_k,
        in_channels: int,
        out_channels: int,
        kernel_size: _size_2_t,
        stride: _size_2_t = 1,
        padding: Union[str, _size_2_t] = 0,
        dilation: _size_2_t = 1,
        groups: int = 1,
        bias: bool = True,
        padding_mode: str = 'zeros',  # TODO: refine this type
        device=None,
        dtype=None,
    ) -> None:
        # Initialise exactly as a standard Conv2d first.
        super().__init__(
            in_channels, out_channels, kernel_size, stride, padding, dilation,
            groups, bias, padding_mode, device, dtype)
        self.spatial_k = spatial_k
        self.channel_k = channel_k

        # Replace the learnable weight Parameter with a fixed buffer and
        # re-draw it; only the covariance factors below are trained.
        frozen = torch.empty_like(self.weight)
        del self.weight
        self.register_buffer("weight", frozen)
        nn.init.kaiming_normal_(self.weight)

        # One covariance over input channels, one over kernel positions.
        n_chan = self.in_channels // self.groups
        n_spatial = self.kernel_size[0] * self.kernel_size[1]
        self.channel = LowRankCovariance(n_chan, self.channel_k)
        self.spatial = LowRankCovariance(n_spatial, self.spatial_k)

    def forward(self, input: Tensor) -> Tensor:
        """Colour the frozen filters, then run the standard convolution."""
        chan_root = self.channel.sqrt()
        spat_root = self.spatial.sqrt()
        # Contract the channel axis, then the flattened spatial axes.
        coloured = _contract(self.weight, chan_root.T, 1)
        coloured = _contract(torch.flatten(coloured, -2, -1), spat_root.T, -1)
        coloured = coloured.reshape(self.weight.shape)
        return self._conv_forward(input, coloured, self.bias)

def _tri_vec_to_mat(self, vec, n, scat_idx):
    """Expand a packed upper-triangular vector into an n x n factor.

    `vec` holds the upper-triangular entries and `scat_idx` maps each one
    to its flat index in the n*n matrix. The diagonal is exponentiated in
    place, so a zero-initialised `vec` yields the identity and the
    diagonal of the result is always strictly positive.
    """
    # Scatter the packed entries into a zeroed n*n buffer, then reshape.
    U = self.weight.new_zeros((n*n)).scatter_(0, scat_idx, vec).view(n, n)
    U = torch.diagonal_scatter(U, U.diagonal().exp_())
    return U
class LowRankPlusDiagFactConv2d(nn.Conv2d):
    """FactConv variant with a low-rank-plus-diagonal channel covariance.

    The conv weight is a frozen Gaussian buffer; the trainable parameters
    live in a rank-``channel_k`` plus diagonal covariance over input
    channels and a full covariance over the spatial kernel.

    NOTE(review): despite the class name, the channel factor is
    ``LowRankK1DiagCovariance`` while ``LowRankPlusDiagCovariance`` is
    imported but unused here — confirm this is intentional.
    """

    def __init__(
        self,
        channel_k,
        nonlinearity,
        in_channels: int,
        out_channels: int,
        kernel_size: _size_2_t,
        stride: _size_2_t = 1,
        padding: Union[str, _size_2_t] = 0,
        dilation: _size_2_t = 1,
        groups: int = 1,
        bias: bool = True,
        padding_mode: str = 'zeros',  # TODO: refine this type
        device=None,
        dtype=None,
    ) -> None:
        # Initialise exactly as a standard Conv2d first.
        super().__init__(
            in_channels, out_channels, kernel_size, stride, padding, dilation,
            groups, bias, padding_mode, device, dtype)
        self.channel_k = channel_k

        # Freeze the weight into a buffer and re-draw it.
        frozen = torch.empty_like(self.weight)
        del self.weight
        self.register_buffer("weight", frozen)
        nn.init.kaiming_normal_(self.weight)

        n_chan = self.in_channels // self.groups
        n_spatial = self.kernel_size[0] * self.kernel_size[1]
        self.channel = LowRankK1DiagCovariance(n_chan, self.channel_k,
                                               nonlinearity)
        self.spatial = Covariance(n_spatial, nonlinearity)

    def forward(self, input: Tensor) -> Tensor:
        """Colour the frozen filters, then run the standard convolution."""
        chan_root = self.channel.sqrt()
        spat_root = self.spatial.sqrt()
        # Contract the channel axis, then the flattened spatial axes.
        coloured = _contract(self.weight, chan_root.T, 1)
        coloured = _contract(torch.flatten(coloured, -2, -1), spat_root.T, -1)
        coloured = coloured.reshape(self.weight.shape)
        return self._conv_forward(input, coloured, self.bias)

class LowRankK1FactConv2d(nn.Conv2d):
    """FactConv variant with a rank-``channel_k`` channel covariance.

    The conv weight is a frozen Gaussian buffer; the trainable parameters
    are a ``LowRankK1Covariance`` over input channels and a full
    ``Covariance`` over the spatial kernel.
    """

    def __init__(
        self,
        channel_k,
        nonlinearity,
        in_channels: int,
        out_channels: int,
        kernel_size: _size_2_t,
        stride: _size_2_t = 1,
        padding: Union[str, _size_2_t] = 0,
        dilation: _size_2_t = 1,
        groups: int = 1,
        bias: bool = True,
        padding_mode: str = 'zeros',  # TODO: refine this type
        device=None,
        dtype=None,
    ) -> None:
        # Initialise exactly as a standard Conv2d first.
        super().__init__(
            in_channels, out_channels, kernel_size, stride, padding, dilation,
            groups, bias, padding_mode, device, dtype)

        # Freeze the weight into a buffer and re-draw it.
        frozen = torch.empty_like(self.weight)
        del self.weight
        self.register_buffer("weight", frozen)
        nn.init.kaiming_normal_(self.weight)

        self.channel_k = channel_k
        n_chan = self.in_channels // self.groups
        n_spatial = self.kernel_size[0] * self.kernel_size[1]
        self.channel = LowRankK1Covariance(n_chan, self.channel_k,
                                           nonlinearity)
        self.spatial = Covariance(n_spatial, nonlinearity)

    def forward(self, input: Tensor) -> Tensor:
        """Colour the frozen filters, then run the standard convolution."""
        chan_root = self.channel.sqrt()
        spat_root = self.spatial.sqrt()
        # Contract the channel axis, then the flattened spatial axes.
        coloured = _contract(self.weight, chan_root.T, 1)
        coloured = _contract(torch.flatten(coloured, -2, -1), spat_root.T, -1)
        coloured = coloured.reshape(self.weight.shape)
        return self._conv_forward(input, coloured, self.bias)

class OffDiagFactConv2d(nn.Conv2d):
    """FactConv variant with an off-diagonal channel covariance.

    The conv weight is a frozen Gaussian buffer; the trainable parameters
    are an ``OffDiagCovariance`` over input channels and a full
    ``Covariance`` over the spatial kernel.
    """

    def __init__(
        self,
        nonlinearity,
        in_channels: int,
        out_channels: int,
        kernel_size: _size_2_t,
        stride: _size_2_t = 1,
        padding: Union[str, _size_2_t] = 0,
        dilation: _size_2_t = 1,
        groups: int = 1,
        bias: bool = True,
        padding_mode: str = 'zeros',  # TODO: refine this type
        device=None,
        dtype=None,
    ) -> None:
        # Initialise exactly as a standard Conv2d first.
        super().__init__(
            in_channels, out_channels, kernel_size, stride, padding, dilation,
            groups, bias, padding_mode, device, dtype)

        # Freeze the weight into a buffer and re-draw it.
        frozen = torch.empty_like(self.weight)
        del self.weight
        self.register_buffer("weight", frozen)
        nn.init.kaiming_normal_(self.weight)

        n_chan = self.in_channels // self.groups
        n_spatial = self.kernel_size[0] * self.kernel_size[1]
        self.channel = OffDiagCovariance(n_chan, nonlinearity)
        self.spatial = Covariance(n_spatial, nonlinearity)

    def forward(self, input: Tensor) -> Tensor:
        """Colour the frozen filters, then run the standard convolution."""
        chan_root = self.channel.sqrt()
        spat_root = self.spatial.sqrt()
        # Contract the channel axis, then the flattened spatial axes.
        coloured = _contract(self.weight, chan_root.T, 1)
        coloured = _contract(torch.flatten(coloured, -2, -1), spat_root.T, -1)
        coloured = coloured.reshape(self.weight.shape)
        return self._conv_forward(input, coloured, self.bias)

class DiagFactConv2d(nn.Conv2d):
    """FactConv variant where both covariance factors are diagonal.

    The conv weight is a frozen Gaussian buffer; the trainable parameters
    are a ``DiagCovariance`` over input channels and another over the
    spatial kernel.
    """

    def __init__(
        self,
        nonlinearity,
        in_channels: int,
        out_channels: int,
        kernel_size: _size_2_t,
        stride: _size_2_t = 1,
        padding: Union[str, _size_2_t] = 0,
        dilation: _size_2_t = 1,
        groups: int = 1,
        bias: bool = True,
        padding_mode: str = 'zeros',  # TODO: refine this type
        device=None,
        dtype=None
    ) -> None:
        # Initialise exactly as a standard Conv2d first.
        super().__init__(
            in_channels, out_channels, kernel_size, stride, padding, dilation,
            groups, bias, padding_mode, device, dtype)

        # Freeze the weight into a buffer and re-draw it.
        frozen = torch.empty_like(self.weight)
        del self.weight
        self.register_buffer("weight", frozen)
        nn.init.kaiming_normal_(self.weight)

        n_chan = self.in_channels // self.groups
        n_spatial = self.kernel_size[0] * self.kernel_size[1]
        self.channel = DiagCovariance(n_chan, nonlinearity)
        self.spatial = DiagCovariance(n_spatial, nonlinearity)

    def forward(self, input: Tensor) -> Tensor:
        """Colour the frozen filters, then run the standard convolution."""
        chan_root = self.channel.sqrt()
        spat_root = self.spatial.sqrt()
        # Contract the channel axis, then the flattened spatial axes.
        coloured = _contract(self.weight, chan_root.T, 1)
        coloured = _contract(torch.flatten(coloured, -2, -1), spat_root.T, -1)
        coloured = coloured.reshape(self.weight.shape)
        return self._conv_forward(input, coloured, self.bias)

class DiagChanFactConv2d(nn.Conv2d):
    """FactConv variant with a diagonal channel covariance only.

    The conv weight is a frozen Gaussian buffer; the trainable parameters
    are a ``DiagCovariance`` over input channels and a full ``Covariance``
    over the spatial kernel.
    """

    def __init__(
        self,
        nonlinearity,
        in_channels: int,
        out_channels: int,
        kernel_size: _size_2_t,
        stride: _size_2_t = 1,
        padding: Union[str, _size_2_t] = 0,
        dilation: _size_2_t = 1,
        groups: int = 1,
        bias: bool = True,
        padding_mode: str = 'zeros',  # TODO: refine this type
        device=None,
        dtype=None
    ) -> None:
        # Initialise exactly as a standard Conv2d first.
        super().__init__(
            in_channels, out_channels, kernel_size, stride, padding, dilation,
            groups, bias, padding_mode, device, dtype)

        # Freeze the weight into a buffer and re-draw it.
        frozen = torch.empty_like(self.weight)
        del self.weight
        self.register_buffer("weight", frozen)
        nn.init.kaiming_normal_(self.weight)

        n_chan = self.in_channels // self.groups
        n_spatial = self.kernel_size[0] * self.kernel_size[1]
        self.channel = DiagCovariance(n_chan, nonlinearity)
        self.spatial = Covariance(n_spatial, nonlinearity)

    def forward(self, input: Tensor) -> Tensor:
        """Colour the frozen filters, then run the standard convolution."""
        chan_root = self.channel.sqrt()
        spat_root = self.spatial.sqrt()
        # Contract the channel axis, then the flattened spatial axes.
        coloured = _contract(self.weight, chan_root.T, 1)
        coloured = _contract(torch.flatten(coloured, -2, -1), spat_root.T, -1)
        coloured = coloured.reshape(self.weight.shape)
        return self._conv_forward(input, coloured, self.bias)

class DiagDomFactConv2d(nn.Conv2d):
    """FactConv variant with a diagonally dominant channel covariance.

    The conv weight is a frozen Gaussian buffer; the trainable parameters
    are a ``DiagonallyDominantCovariance`` over input channels and a full
    ``Covariance`` (with the "abs" nonlinearity) over the spatial kernel.
    Unlike the sibling classes, no ``nonlinearity`` argument is taken.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: _size_2_t,
        stride: _size_2_t = 1,
        padding: Union[str, _size_2_t] = 0,
        dilation: _size_2_t = 1,
        groups: int = 1,
        bias: bool = True,
        padding_mode: str = 'zeros',  # TODO: refine this type
        device=None,
        dtype=None
    ) -> None:
        # Initialise exactly as a standard Conv2d first.
        super().__init__(
            in_channels, out_channels, kernel_size, stride, padding, dilation,
            groups, bias, padding_mode, device, dtype)

        # Freeze the weight into a buffer and re-draw it.
        frozen = torch.empty_like(self.weight)
        del self.weight
        self.register_buffer("weight", frozen)
        nn.init.kaiming_normal_(self.weight)

        n_chan = self.in_channels // self.groups
        n_spatial = self.kernel_size[0] * self.kernel_size[1]
        self.channel = DiagonallyDominantCovariance(n_chan)
        self.spatial = Covariance(n_spatial, "abs")

    def forward(self, input: Tensor) -> Tensor:
        """Colour the frozen filters, then run the standard convolution."""
        chan_root = self.channel.sqrt()
        spat_root = self.spatial.sqrt()
        # Contract the channel axis, then the flattened spatial axes.
        coloured = _contract(self.weight, chan_root.T, 1)
        coloured = _contract(torch.flatten(coloured, -2, -1), spat_root.T, -1)
        coloured = coloured.reshape(self.weight.shape)
        return self._conv_forward(input, coloured, self.bias)
Loading