The initialization order of modules is crucial and should be explicitly stated.
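Because the test file sets a fixed random seed, the order in which modules are constructed when building the ResNet9 network determines which random values each layer is initialized with. Two implementations with the same architecture but a different construction order therefore start from different weights, and the one that does not match the expected order fails the tests. It would help to state this explicitly in the homework documentation.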
For ResNet9:
Correct:
class ResNet9(ndl.nn.Module):
    def __init__(self, device=None, dtype="float32"):
        super().__init__()
        self.module = nn.Sequential(
            ConvBN(3, 16, 7, 4, device=device, dtype=dtype),
            ConvBN(16, 32, 3, 2, device=device, dtype=dtype),
            ndl.nn.Residual(nn.Sequential(
                ConvBN(32, 32, 3, 1, device=device, dtype=dtype),
                ConvBN(32, 32, 3, 1, device=device, dtype=dtype),
            )),
            ConvBN(32, 64, 3, 2, device=device, dtype=dtype),
            ConvBN(64, 128, 3, 2, device=device, dtype=dtype),
            ndl.nn.Residual(nn.Sequential(
                ConvBN(128, 128, 3, 1, device=device, dtype=dtype),
                ConvBN(128, 128, 3, 1, device=device, dtype=dtype),
            )),
            nn.Flatten(),
            nn.Linear(128, 128, device=device, dtype=dtype),
            nn.ReLU(),
            nn.Linear(128, 10, device=device, dtype=dtype),
        )
Incorrect (the residual blocks are constructed before the other layers; the ResNet9 and train_cifar10 tests will fail):
class ResNet9(ndl.nn.Module):
    def __init__(self, device=None, dtype="float32"):
        super().__init__()
        residual1 = ndl.nn.Residual(nn.Sequential(
            ConvBN(32, 32, 3, 1, device=device, dtype=dtype),
            ConvBN(32, 32, 3, 1, device=device, dtype=dtype),
        ))
        residual2 = ndl.nn.Residual(nn.Sequential(
            ConvBN(128, 128, 3, 1, device=device, dtype=dtype),
            ConvBN(128, 128, 3, 1, device=device, dtype=dtype),
        ))
        self.module = nn.Sequential(
            ConvBN(3, 16, 7, 4, device=device, dtype=dtype),
            ConvBN(16, 32, 3, 2, device=device, dtype=dtype),
            residual1,
            ConvBN(32, 64, 3, 2, device=device, dtype=dtype),
            ConvBN(64, 128, 3, 2, device=device, dtype=dtype),
            residual2,
            nn.Flatten(),
            nn.Linear(128, 128, device=device, dtype=dtype),
            nn.ReLU(),
            nn.Linear(128, 10, device=device, dtype=dtype),
        )
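To make the effect easier to see outside of needle, here is a minimal sketch using only Python's standard `random` module. `ToyLayer`, `build_inline`, and `build_hoisted` are hypothetical names invented for this illustration and are not part of the homework code; the sketch only assumes that a layer draws its initial weights from a global random number generator at construction time.

```python
import random


class ToyLayer:
    """Hypothetical stand-in for a layer that samples its weights when constructed."""

    def __init__(self, size):
        # Each construction consumes `size` values from the global RNG stream.
        self.weight = [random.gauss(0.0, 1.0) for _ in range(size)]


def build_inline(seed=0):
    """Construct layers in the same order as they appear in the network."""
    random.seed(seed)
    first = ToyLayer(4)
    second = ToyLayer(4)
    return [first, second]


def build_hoisted(seed=0):
    """Same network layout, but the second layer is constructed first."""
    random.seed(seed)
    second = ToyLayer(4)  # consumes the values that `first` got in build_inline
    first = ToyLayer(4)
    return [first, second]


inline = build_inline()
hoisted = build_hoisted()

# Same position in the network, different initial weights.
print(inline[0].weight == hoisted[0].weight)
# The weights have simply moved to whichever layer was constructed first.
print(inline[0].weight == hoisted[1].weight)
```

Both builders return layers of the same shapes in the same positions, yet the weights at each position differ, which is the same thing that happens when residual1 and residual2 are created before the first ConvBN layers.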