# Add GNN-Based Predictor with DAG Preprocessing #430
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Status: Open

antotu wants to merge 54 commits into munich-quantum-toolkit:main from antotu:gnn-branch.
## Commits (54)
- `12fad57` added function related to training and for GNN, needed to define GNN …
- `be734dd` Added the gnn part, must be fine-tuned hyper-params, no test
- `61c6824` Removed the barriers in the creation of the DAG
- `75875ff` 🎨 pre-commit fixes (pre-commit-ci[bot])
- `081651c` coded tested and fixed, need to add a cross validation module
- `b82dc01` Merge branch 'gnn-branch' of https://github.com/antotu/predictor-gnn … (antotu)
- `5ebd202` fixed the problem of the predict_device_for_figure_of_merits
- `857cd6f` 🎨 pre-commit fixes (pre-commit-ci[bot])
- `6081f6b` Hellinger test done: success
- `7c54da6` Merge branch 'gnn-branch' of https://github.com/antotu/predictor-gnn …
- `bb4da24` GNN predictor fixed with optuna and tested
- `10bb52c` 🎨 pre-commit fixes (pre-commit-ci[bot])
- `06be0d6` GNN predictor fixed with optuna and tested
- `ce990e3` Modified the tolm for running on the MacOS
- `96ca75b` Problems modified TPESampler and not TYPESampler
- `a64a082` Problems modified TPESampler and not TYPESampler
- `f8c99b5` 🎨 pre-commit fixes (pre-commit-ci[bot])
- `e4e2742` Problems modified TPESampler and not TYPESampler
- `5784ff7` Problems modified TPESampler and not TYPESampler
- `7e17379` Test modified with number of epochs as parameter
- `082de05` Eliminated trained model
- `5ed00a9` Changed the test estimated hellinger for windows
- `e59a941` 🎨 pre-commit fixes (pre-commit-ci[bot])
- `3a9f16c` Changed the test estimated hellinger for windows
- `c43ee01` Merge branch 'gnn-branch' of https://github.com/antotu/predictor-gnn …
- `92eda99` Changed the test estimated hellinger for windows
- `dc1aa55` Problem with windows solved eliminating warning
- `6809ccb` Files modified according suggestion
- `8c77598` Fixed the comments related to test hellinger distance and utils (antotu)
- `dc0a824` 🎨 pre-commit fixes (pre-commit-ci[bot])
- `2419952` Fixed modification also with pre-commit (antotu)
- `5335241` Fixed modification also with pre-commit (antotu)
- `96096a0` Refactor the test ml predictor considering to join function related M… (antotu)
- `4613012` Modified part of helper in order to solve problems code (antotu)
- `1c728e2` Pre-commit has substituted Wille in Will (antotu)
- `c31cb46` Update tests/device_selection/test_predictor_ml.py (antotu)
- `13cf0f4` 🎨 pre-commit fixes (pre-commit-ci[bot])
- `713343f` first round fixes (antotu)
- `17c6575` 🎨 pre-commit fixes (pre-commit-ci[bot])
- `169b00e` pre-commit fixes (antotu)
- `2248081` pre-commit fixes (antotu)
- `8f90b12` Update src/mqt/predictor/ml/predictor.py (antotu)
- `74ec34b` Update src/mqt/predictor/ml/predictor.py (antotu)
- `f99e17b` 🎨 pre-commit fixes (pre-commit-ci[bot])
- `57b1a29` Partial modification (antotu)
- `96232f0` Merge branch 'gnn-branch' of github.com:antotu/predictor-gnn into gnn… (antotu)
- `61965d8` 🎨 pre-commit fixes (pre-commit-ci[bot])
- `93f5414` fixed comments repo (antotu)
- `95f5359` Merge branch 'gnn-branch' of github.com:antotu/predictor-gnn into gnn… (antotu)
- `4fb7112` Modified the gates accepted (antotu)
- `5ea1720` Modified list (antotu)
- `312e6ea` Fixed bug Swap and Cswap gates (antotu)
- `156b7e6` Edit for saving memory GPU (antotu)
- `77c9f5c` Added patience as variable (antotu)
## Files changed

The first added file is generated by setuptools-scm (40 added lines):
```python
# Copyright (c) 2023 - 2025 Chair for Design Automation, TUM
# Copyright (c) 2025 Munich Quantum Software Company GmbH
# All rights reserved.
#
# SPDX-License-Identifier: MIT
#
# Licensed under the MIT License

# file generated by setuptools-scm
# don't change, don't track in version control
from __future__ import annotations

__all__ = [
    "__commit_id__",
    "__version__",
    "__version_tuple__",
    "commit_id",
    "version",
    "version_tuple",
]

TYPE_CHECKING = False
if TYPE_CHECKING:
    VERSION_TUPLE = tuple[int | str, ...]
    COMMIT_ID = str | None
else:
    VERSION_TUPLE = object
    COMMIT_ID = object

version: str
__version__: str
__version_tuple__: VERSION_TUPLE
version_tuple: VERSION_TUPLE
commit_id: COMMIT_ID
__commit_id__: COMMIT_ID

__version__ = version = "0.1.dev719+g5ea17201a.d20250908"
__version_tuple__ = version_tuple = (0, 1, "dev719", "g5ea17201a.d20250908")

__commit_id__ = commit_id = None
```

Code scanning (CodeQL) flags the body of the `if TYPE_CHECKING:` branch in this file as unreachable code.
The second added file contains the GNN module (139 added lines):
```python
# Copyright (c) 2023 - 2025 Chair for Design Automation, TUM
# Copyright (c) 2025 Munich Quantum Software Company GmbH
# All rights reserved.
#
# SPDX-License-Identifier: MIT
#
# Licensed under the MIT License

"""This module contains the GNN module for graph neural networks."""

from __future__ import annotations

from typing import TYPE_CHECKING, Any

import torch
import torch.nn as nn
import torch.nn.functional as functional
from torch_geometric.nn import SAGEConv, global_mean_pool

if TYPE_CHECKING:
    from collections.abc import Callable

    from torch_geometric.data import Data


class GraphConvolutionSage(nn.Module):
    """Graph convolutional encoder using SAGEConv layers."""

    def __init__(
        self,
        in_feats: int,
        hidden_dim: int,
        num_resnet_layers: int,
        *,
        conv_activation: Callable[..., torch.Tensor] = functional.leaky_relu,
        conv_act_kwargs: dict[str, Any] | None = None,
    ) -> None:
        """A flexible SAGEConv graph encoder.

        Args:
            in_feats: dimensionality of the node features
            hidden_dim: output size of each SAGEConv layer
            num_resnet_layers: number of SAGEConv layers (with residual
                connections) to stack after the two initial layers
            conv_activation: activation function applied after each graph layer
            conv_act_kwargs: extra keyword arguments for conv_activation
        """
        super().__init__()
        self.conv_activation = conv_activation
        self.conv_act_kwargs = conv_act_kwargs or {}

        # --- GRAPH ENCODER ---
        self.convs = nn.ModuleList()
        # Two initial convolutions without residual connections;
        # this count could be generalized into a parameter.
        self.convs.append(SAGEConv(in_feats, hidden_dim))
        self.convs.append(SAGEConv(hidden_dim, hidden_dim))

        for _ in range(num_resnet_layers):
            self.convs.append(SAGEConv(hidden_dim, hidden_dim))

    def forward(self, data: Data) -> torch.Tensor:
        """Encode the input graph into a fixed-size graph embedding."""
        x, edge_index, batch = data.x, data.edge_index, data.batch
        # 1) Graph stack with residuals
        for i, conv in enumerate(self.convs):
            x_new = conv(x, edge_index)
            x_new = self.conv_activation(x_new, **self.conv_act_kwargs)
            # The threshold 2 reflects the two initial convolutions applied
            # without residual connections; all later layers add their
            # input back as a residual.
            x = x_new if i < 2 else x + x_new

        # 2) Global pooling
        return global_mean_pool(x, batch)
```
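The residual scheme in `GraphConvolutionSage.forward` (the first two layer outputs replace the running state, every later output is added to it) can be sketched without PyTorch. The helper `stack_with_residuals` and the doubling "layers" below are hypothetical stand-ins for illustration only:

```python
def stack_with_residuals(x, layers, skip_first=2):
    """Apply `layers` in sequence: the first `skip_first` outputs replace
    the running state, later outputs are added to it as residuals
    (mirroring the `x = x_new if i < 2 else x + x_new` line above)."""
    for i, layer in enumerate(layers):
        x_new = layer(x)
        x = x_new if i < skip_first else x + x_new
    return x


# Toy stand-in layers that simply double their input.
layers = [lambda v: v * 2] * 4
# i=0: 2 (replace), i=1: 4 (replace), i=2: 4 + 8 = 12, i=3: 12 + 24 = 36
print(stack_with_residuals(1, layers))  # prints 36
```

With tensors in place of numbers, this is exactly the update rule the encoder applies after each `SAGEConv` plus activation.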
```python
class GNN(nn.Module):
    """Architecture composed of a graph convolutional part with SAGE convolution modules, followed by an MLP."""

    def __init__(
        self,
        in_feats: int,
        hidden_dim: int,
        num_resnet_layers: int,
        mlp_units: list[int],
        *,
        conv_activation: Callable[..., torch.Tensor] = functional.leaky_relu,
        conv_act_kwargs: dict[str, Any] | None = None,
        mlp_activation: Callable[..., torch.Tensor] = functional.leaky_relu,
        mlp_act_kwargs: dict[str, Any] | None = None,
        classes: list[str] | None = None,
        output_dim: int = 1,
    ) -> None:
        """Initialize the GNN.

        Arguments:
            in_feats: dimension of the input node features
            hidden_dim: dimension of the hidden output channels of the convolutional part
            num_resnet_layers: number of residual layers
            mlp_units: list of units for each layer of the final MLP
            conv_activation: activation function applied after each graph layer
            conv_act_kwargs: extra keyword arguments for conv_activation
            mlp_activation: activation function applied after each MLP layer
            mlp_act_kwargs: extra keyword arguments for mlp_activation
            classes: list of class names for classification tasks
            output_dim: dimension of the output; defaults to 1 for regression tasks
        """
        super().__init__()
        # Convolutional part
        self.graph_conv = GraphConvolutionSage(
            in_feats, hidden_dim, num_resnet_layers, conv_activation=conv_activation, conv_act_kwargs=conv_act_kwargs
        )

        # MLP architecture
        self.mlp_activation = mlp_activation
        self.mlp_act_kwargs = mlp_act_kwargs or {}
        self.classes = classes
        self.fcs = nn.ModuleList()
        last_dim = hidden_dim
        for out_dim in mlp_units:
            self.fcs.append(nn.Linear(last_dim, out_dim))
            last_dim = out_dim
        self.out = nn.Linear(last_dim, output_dim)

    def forward(self, data: Data) -> torch.Tensor:
        """Run the forward pass on the input graph.

        Arguments:
            data: The input graph data.
        """
        # Apply the graph convolutional encoder
        x = self.graph_conv(data)
        # Apply the MLP
        for fc in self.fcs:
            x = self.mlp_activation(fc(x), **self.mlp_act_kwargs)
        return self.out(x)
```
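`GNN.__init__` chains `nn.Linear` layers through the `mlp_units` list, starting from `hidden_dim` and ending in an `output_dim`-wide head. A small helper (hypothetical, for illustration only) makes the resulting layer shapes explicit:

```python
def mlp_dims(hidden_dim, mlp_units, output_dim=1):
    """Return the (in_features, out_features) pairs of the Linear layers
    that GNN.__init__ would create for the given configuration."""
    dims = []
    last_dim = hidden_dim
    for out_dim in mlp_units:
        dims.append((last_dim, out_dim))
        last_dim = out_dim
    dims.append((last_dim, output_dim))  # the final `self.out` head
    return dims


print(mlp_dims(64, [32, 16]))         # [(64, 32), (32, 16), (16, 1)]
print(mlp_dims(8, [], output_dim=3))  # [(8, 3)]
```

An empty `mlp_units` list therefore yields a single linear head mapping the pooled graph embedding directly to the output.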
Review comment (on the generated setuptools-scm version file): "This file should not be tracked and can be removed."