|
14 | 14 | from torch.autograd import Variable |
15 | 15 | from torch.utils.serialization import load_lua |
16 | 16 |
|
17 | | - |
18 | | -class LambdaBase(nn.Sequential): |
19 | | - def __init__(self, fn, *args): |
20 | | - super(LambdaBase, self).__init__(*args) |
21 | | - self.lambda_func = fn |
22 | | - |
23 | | - def forward_prepare(self, input): |
24 | | - output = [] |
25 | | - for module in self._modules.values(): |
26 | | - output.append(module(input)) |
27 | | - return output if output else input |
28 | | - |
29 | | - |
30 | | -class Lambda(LambdaBase): |
31 | | - def forward(self, input): |
32 | | - return self.lambda_func(self.forward_prepare(input)) |
33 | | - |
34 | | - |
35 | | -class LambdaMap(LambdaBase): |
36 | | - def forward(self, input): |
37 | | - # result is Variables list [Variable1, Variable2, ...] |
38 | | - return list(map(self.lambda_func, self.forward_prepare(input))) |
39 | | - |
40 | | - |
41 | | -class LambdaReduce(LambdaBase): |
42 | | - def forward(self, input): |
43 | | - # result is a Variable |
44 | | - return reduce(self.lambda_func, self.forward_prepare(input)) |
| 17 | +from header import LambdaBase, Lambda, LambdaMap, LambdaReduce |
45 | 18 |
|
46 | 19 |
|
47 | 20 | def copy_param(m, n): |
48 | | - if m.weight is not None: n.weight.data.copy_(m.weight) |
49 | | - if m.bias is not None: n.bias.data.copy_(m.bias) |
50 | | - if hasattr(n, 'running_mean'): n.running_mean.copy_(m.running_mean) |
51 | | - if hasattr(n, 'running_var'): n.running_var.copy_(m.running_var) |
| 21 | + if m.weight is not None: |
| 22 | + n.weight.data.copy_(m.weight, broadcast=False) |
| 23 | + if hasattr(m, 'bias') and m.bias is not None: |
| 24 | + n.bias.data.copy_(m.bias, broadcast=False) |
| 25 | + if hasattr(n, 'running_mean'): |
| 26 | + n.running_mean.copy_(m.running_mean, broadcast=False) |
| 27 | + if hasattr(n, 'running_var'): |
| 28 | + n.running_var.copy_(m.running_var, broadcast=False) |
52 | 29 |
|
53 | 30 |
|
54 | 31 | def add_submodule(seq, *args): |
@@ -165,7 +142,6 @@ def lua_recursive_source(module): |
165 | 142 | s = [] |
166 | 143 | for m in module.modules: |
167 | 144 | name = type(m).__name__ |
168 | | - real = m |
169 | 145 | if name == 'TorchObject': |
170 | 146 | name = m._typename.replace('cudnn.', '') |
171 | 147 | m = m._obj |
|
0 commit comments