diff --git a/models/SeSGCN_student.py b/models/SeSGCN_student.py index 7a9701a..880cce5 100644 --- a/models/SeSGCN_student.py +++ b/models/SeSGCN_student.py @@ -333,9 +333,10 @@ def forward(self, x, maskA, maskT): x= x.permute(0,2,1,3) # prepare the input for the Time-Extrapolator-CNN (NCTV->NTCV) - x=self.prelus[0](self.txcnns[0](x)) - - for i in range(1,self.n_txcnn_layers): - x = self.prelus[i](self.txcnns[i](x)) +x # residual connection + for i, (prelu, txcnn) in enumerate(zip(self.prelus, self.txcnns)): + if i == 0: + x = prelu(txcnn(x)) + else: + x = prelu(txcnn(x)) + x return x \ No newline at end of file diff --git a/models/SeSGCN_teacher.py b/models/SeSGCN_teacher.py index af7dc5d..2be50c9 100644 --- a/models/SeSGCN_teacher.py +++ b/models/SeSGCN_teacher.py @@ -262,9 +262,10 @@ def forward(self, x): x= x.permute(0,2,1,3) # prepare the input for the Time-Extrapolator-CNN (NCTV->NTCV) - x=self.prelus[0](self.txcnns[0](x)) - - for i in range(1,self.n_txcnn_layers): - x = self.prelus[i](self.txcnns[i](x)) +x # residual connection + for i, (prelu, txcnn) in enumerate(zip(self.prelus, self.txcnns)): + if i == 0: + x = prelu(txcnn(x)) + else: + x = prelu(txcnn(x)) + x return x \ No newline at end of file