diff --git a/compy/models/graphs/pytorch_geom_model.py b/compy/models/graphs/pytorch_geom_model.py index be95a4b..b04dead 100644 --- a/compy/models/graphs/pytorch_geom_model.py +++ b/compy/models/graphs/pytorch_geom_model.py @@ -18,11 +18,11 @@ def __init__(self, config): annotation_size = config["hidden_size_orig"] hidden_size = config["gnn_h_size"] n_steps = config["num_timesteps"] - num_cls = 2 + num_cls = config["num_cls"] self.reduce = nn.Linear(annotation_size, hidden_size) self.conv = GatedGraphConv(hidden_size, n_steps) - self.agg = GlobalAttention(nn.Linear(hidden_size, 1), nn.Linear(hidden_size, 2)) + self.agg = GlobalAttention(nn.Linear(hidden_size, 1), nn.Linear(hidden_size, num_cls)) self.lin = nn.Linear(hidden_size, num_cls) def forward( @@ -51,6 +51,7 @@ def __init__(self, config=None, num_types=None): "learning_rate": 0.001, "batch_size": 64, "num_epochs": 1000, + "num_cls": 2 } super().__init__(config)