diff --git a/synapses/SET_layer.py b/synapses/SET_layer.py index 259996c..915c952 100644 --- a/synapses/SET_layer.py +++ b/synapses/SET_layer.py @@ -258,5 +258,5 @@ def reset_parameters(self): def forward(self, x): k = x[:, self.inds] k = k * self.weight - z = scatter_add(k, self.inds_out) + z = scatter_add(k, self.inds_out, dim_size=self.outdim) return z + self.bias \ No newline at end of file