diff --git a/GraphNorm_ws/ogbg_ws/src/dgl_model/norm.py b/GraphNorm_ws/ogbg_ws/src/dgl_model/norm.py index 24ac344..926a7f6 100644 --- a/GraphNorm_ws/ogbg_ws/src/dgl_model/norm.py +++ b/GraphNorm_ws/ogbg_ws/src/dgl_model/norm.py @@ -20,7 +20,7 @@ def forward(self, graph, tensor, print_=False): return self.norm(tensor) elif self.norm is None: return tensor - batch_list = graph.batch_num_nodes + batch_list = graph.batch_num_nodes() batch_size = len(batch_list) batch_list = torch.Tensor(batch_list).long().to(tensor.device) batch_index = torch.arange(batch_size).to(tensor.device).repeat_interleave(batch_list) @@ -35,4 +35,4 @@ def forward(self, graph, tensor, print_=False): std = ((std.T / batch_list).T + 1e-6).sqrt() std = std.repeat_interleave(batch_list, dim=0) # return sub / std - return self.weight * sub / std + self.bias \ No newline at end of file + return self.weight * sub / std + self.bias