Conversation
…h 8, removed evaluation on affine models, etc.
…urns the modified network (though the network is modified in place, so this is redundant)
has been moved to FactConv directory
MuawizChaudhary
left a comment
Just a high-level overview; I'll do additional passes.
    help='filename for saved model')
    parser.add_argument('--seed', default=0, type=int, help='seed to use')
    parser.add_argument('--width', type=float, default=1, help='resnet width scale factor')
    parser.add_argument('--spatial_k', type=float, default=1, help='%spatial low-rank')
Consider removing.
FactConv/pytorch_cifar.py
Outdated
    print("Num Learnable Params: ", sum(p.numel() for p in net.parameters() if
        p.requires_grad))

    run = wandb.init(project="factconv", config=args, group="lowrankplusdiag", name=run_name, dir=wandb_dir)
Might be worth logging the number of parameters as well.
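A minimal sketch of this suggestion. The `count_learnable_params` helper and the toy `nn.Linear` model are illustrative stand-ins, not code from the PR:

```python
import torch.nn as nn

def count_learnable_params(net: nn.Module) -> int:
    """Count parameters that will receive gradients."""
    return sum(p.numel() for p in net.parameters() if p.requires_grad)

net = nn.Linear(4, 3)  # toy stand-in for the ResNet built in the script
num_params = count_learnable_params(net)  # 4*3 weights + 3 biases = 15

# In the training script the count could be recorded alongside the rest of
# the run config (assuming `run` is the object returned by wandb.init above):
# run.config.update({"num_learnable_params": num_params})
```

Putting the count in the run config (rather than only printing it) makes it filterable and comparable across runs in the wandb UI.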
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer,\
        T_max=args.num_epochs)

@@ -1,201 +0,0 @@
    import torch
    from torch import Tensor
Why is this script removed here?
Good question. It appears to be a mistake I made during rebasing.
    class FactConv2d(nn.Conv2d):
        def __init__(
Throughout this script you're reusing a lot of code. I wonder if you could implement the FactConv extensions as subclasses of this FactConv module.
As you mentioned, eventually we should refactor the other FactConvs to use the Covariance module.
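A hedged sketch of the proposed refactor: a single `FactConv2d` base class (here a thin `nn.Conv2d` subclass; the real one factorizes its weight) with each variant as a subclass. The class bodies are placeholders, not the PR's implementation:

```python
import torch.nn as nn

class FactConv2d(nn.Conv2d):
    """Placeholder base; the real module factorizes its weight covariance."""
    pass

class DiagFactConv2d(FactConv2d):
    """Placeholder variant with a diagonal covariance."""
    pass

class LowRankPlusDiagFactConv2d(FactConv2d):
    """Placeholder variant with a low-rank-plus-diagonal covariance."""
    pass

# With a shared base class, one isinstance check covers every variant:
modules = [DiagFactConv2d(3, 8, 3), LowRankPlusDiagFactConv2d(3, 8, 3)]
all_are_factconvs = all(isinstance(m, FactConv2d) for m in modules)
```

This is what would let the long `isinstance(...) or isinstance(...)` chains elsewhere in the PR collapse to a single check.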
    from typing import Optional, List, Tuple, Union
    import math

    class Covariance(nn.Module):
There's some repeated code here too; I wonder if we could condense it.
FactConv/models/function_utils.py
Outdated
    if isinstance(module, FactConv2d) or isinstance(module, DiagFactConv2d)\
        or isinstance(module, DiagChanFactConv2d) or isinstance(module, LowRankFactConv2d)\
        or isinstance(module, LowRankPlusDiagFactConv2d):
If all these FactConv variations subclassed a common FactConv base, you could just check whether the module is a FactConv instance.
FactConv/models/function_utils.py
Outdated
        or isinstance(module, DiagChanFactConv2d) or isinstance(module, LowRankFactConv2d)\
        or isinstance(module, LowRankPlusDiagFactConv2d):
        for name, mod in module.named_modules():
            if isinstance(mod, Covariance):
Actually, I think it's better to remove the FactConv check entirely: just check whether the module is a Covariance and disable gradients that way.
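A minimal sketch of this suggestion, assuming the PR's `Covariance` module (the `Covariance` class and `freeze_covariances` helper below are illustrative stand-ins): walk the model once and freeze any `Covariance` submodule, with no per-variant FactConv checks needed.

```python
import torch
import torch.nn as nn

class Covariance(nn.Module):
    """Placeholder for the PR's Covariance module, with one toy parameter."""
    def __init__(self):
        super().__init__()
        self.scale = nn.Parameter(torch.ones(4))

def freeze_covariances(net: nn.Module) -> None:
    """Disable gradients for every Covariance submodule, wherever it lives."""
    for mod in net.modules():
        if isinstance(mod, Covariance):
            for p in mod.parameters():
                p.requires_grad = False

# Toy model: a trainable layer plus a Covariance submodule.
net = nn.Sequential(nn.Linear(4, 4), Covariance())
freeze_covariances(net)
```

Because `modules()` recurses through the whole tree, this finds Covariance instances nested inside any FactConv variant without enumerating the variants.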
Are we ready to merge?
Adding Covariance module and low-rank covariance experiments