Conversation
ajratner
left a comment
There was a problem hiding this comment.
Hey @khaledsaab sorry for the delay reviewing this! I think it looks great overall, see comments for requested changes. Also, to pass tests, need to run make dev, then make fix and make check (see Developer guidelines in README)
| # The following identity module is to essentially replace the last FC layer | ||
| # in the resnet model by the FC in MeTal | ||
|
|
||
| class IdentityModule(nn.Module): |
There was a problem hiding this comment.
This class already exists in metal.modules.identity_module
|
|
||
| # Here we create a dataloader that transforms CIFAR labels from 0-9, to 1-10, | ||
| # We do this because MeTal treats a 0 label as abstain | ||
| class MetalDataset(Dataset): |
There was a problem hiding this comment.
Since we already have a MetalDataset, I would rename this OneIndexedDataset? Also clean up commented out stuff
There was a problem hiding this comment.
And this can go in metal.utils
| class BasicBlock(nn.Module): | ||
| expansion = 1 | ||
|
|
||
| def __init__(self, in_planes, planes, stride=1): |
There was a problem hiding this comment.
Nit, but in_channels etc. is more standard; in_planes might be confusing?
| class ResNet(nn.Module): | ||
| def __init__(self, block, num_blocks, num_classes=10): | ||
| super(ResNet, self).__init__() | ||
| self.in_planes = 64 |
There was a problem hiding this comment.
This is defined but not used? Would be ideal to make this a kwarg also?
| @@ -0,0 +1,117 @@ | |||
| '''ResNet in PyTorch. | |||
There was a problem hiding this comment.
I think we should put this in metal.contrib.modules
No description provided.