From 1a5d41a9f78dbfb3d91495f5564dee3db2ba1ae8 Mon Sep 17 00:00:00 2001 From: OmkarKabadagi5823 Date: Sun, 25 Sep 2022 21:11:35 +0530 Subject: [PATCH] update ModuleList indexing to enumeration --- models/SeSGCN_student.py | 9 +++++---- models/SeSGCN_teacher.py | 9 +++++---- 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/models/SeSGCN_student.py b/models/SeSGCN_student.py index 7a9701a..880cce5 100644 --- a/models/SeSGCN_student.py +++ b/models/SeSGCN_student.py @@ -333,9 +333,10 @@ def forward(self, x, maskA, maskT): x= x.permute(0,2,1,3) # prepare the input for the Time-Extrapolator-CNN (NCTV->NTCV) - x=self.prelus[0](self.txcnns[0](x)) - - for i in range(1,self.n_txcnn_layers): - x = self.prelus[i](self.txcnns[i](x)) +x # residual connection + for i, (prelu, txcnn) in enumerate(zip(self.prelus, self.txcnns)): + if i == 0: + x = prelu(txcnn(x)) + else: + x = prelu(txcnn(x)) + x return x \ No newline at end of file diff --git a/models/SeSGCN_teacher.py b/models/SeSGCN_teacher.py index af7dc5d..2be50c9 100644 --- a/models/SeSGCN_teacher.py +++ b/models/SeSGCN_teacher.py @@ -262,9 +262,10 @@ def forward(self, x): x= x.permute(0,2,1,3) # prepare the input for the Time-Extrapolator-CNN (NCTV->NTCV) - x=self.prelus[0](self.txcnns[0](x)) - - for i in range(1,self.n_txcnn_layers): - x = self.prelus[i](self.txcnns[i](x)) +x # residual connection + for i, (prelu, txcnn) in enumerate(zip(self.prelus, self.txcnns)): + if i == 0: + x = prelu(txcnn(x)) + else: + x = prelu(txcnn(x)) + x return x \ No newline at end of file