diff --git a/inference/act_model.py b/inference/act_model.py index e0e2442..496adbf 100755 --- a/inference/act_model.py +++ b/inference/act_model.py @@ -50,8 +50,8 @@ def __init__(self, rate=0.0, alpha=0.4, device="cuda:0"): super(ActivityModel, self).__init__() self.alpha = alpha - self.kcat_model = KcatModel().to(device) - self.Km_model = KmModel().to(device) + self.kcat_model = KcatModel(device = device).to(device) + self.Km_model = KmModel(device = device).to(device) self.prot_norm = nn.BatchNorm1d(1024).to(device) self.molt5_norm = nn.BatchNorm1d(768).to(device)