This project tackles a biological knowledge graph challenge: predicting drug-disease relationships. The dataset consists of nodes representing drugs, diseases, and other biomedical entities (e.g., genes), and edges representing the relationships between them. The primary goal is to train and evaluate a classifier to predict drug-disease links.
- Works with Python 3.10.7 and should work with Python > 3.10 (but is not compatible with Python 3.13)
- Ensure all dependencies are installed by creating a Python environment called `ec_challenge` in `/path/to/env/`:

  ```sh
  python -m venv /path/to/env/ec_challenge
  source /path/to/env/ec_challenge/bin/activate
  pip install networkx[default]
  pip3 install -U scikit-learn
  pip install pandas
  pip install ipykernel
  pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
  pip install torch_geometric
  pip3 install jupyter
  python -m ipykernel install --user --name ec_challenge
  ```

- Download data from here
- Place the data files in the `./data/` directory: `Edges.csv`, `Nodes.csv`, `Ground Truth.csv`, `Embeddings.csv`
- Open a Jupyter session with `jupyter notebook` and run `link_prediction_RF.ipynb`
The provided dataset consists of:
- `Edges.csv`: describes relationships between nodes.
- `Nodes.csv`: contains metadata about nodes.
- `Ground Truth.csv`: specifies drug-disease pairs with binary labels (`1` for a known link, `0` if no link).
- `Embeddings.csv`: precomputed topological embeddings for each node.
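
A minimal loading sketch for these files is shown below; the column names referenced in later sketches (e.g., `source`, `target`, `label`) are assumptions for illustration and should be checked against the actual files.

```python
import pandas as pd

# Load the four provided files from the ./data/ directory described above.
edges = pd.read_csv("data/Edges.csv")                # relationships between nodes
nodes = pd.read_csv("data/Nodes.csv")                # node metadata (types, identifiers, ...)
ground_truth = pd.read_csv("data/Ground Truth.csv")  # drug-disease pairs with 0/1 labels
embeddings = pd.read_csv("data/Embeddings.csv")      # precomputed topological embeddings

# Inspect the first rows to confirm the real column names before reusing the
# hypothetical ones (source/target/label) assumed in the sketches below.
print(edges.head())
print(ground_truth.head())
```
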
- **Data Exploration and Preparation:**
  - Based on the initial exploration, the data include multiple types of entities (e.g., drugs, diseases, genes) and relationships (e.g., "biolink:treats", "biolink:same_as"). A heterogeneous graph was used to model the data, as it explicitly considers these contextual differences.
- **Graph Representation:**
  - The knowledge graph was represented using PyTorch Geometric's `HeteroData` object (a construction sketch follows this block):
    - Nodes: each node was mapped to a unique index.
    - Edges: relationships between nodes were encoded as sparse `edge_index` tensors.
  - The graph was modeled as undirected.
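
A minimal construction sketch, assuming Biolink-style column names in `Nodes.csv` (`id`, `category`) and `Edges.csv` (`subject`, `predicate`, `object`, plus per-endpoint category columns); these names are assumptions, not confirmed by the dataset.

```python
import numpy as np
import pandas as pd
import torch
import torch_geometric.transforms as T
from torch_geometric.data import HeteroData

nodes = pd.read_csv("data/Nodes.csv")
edges = pd.read_csv("data/Edges.csv")

data = HeteroData()

# Map each node ID to a contiguous index within its node type.
node_index = {}
for ntype, group in nodes.groupby("category"):  # assumed column name
    node_index[ntype] = {nid: i for i, nid in enumerate(group["id"])}
    data[ntype].num_nodes = len(group)

# Encode every relation type as an edge_index tensor of shape [2, num_edges].
grouped = edges.groupby(["subject_category", "predicate", "object_category"])  # assumed columns
for (src_t, rel, dst_t), group in grouped:
    src = group["subject"].map(node_index[src_t]).to_numpy()
    dst = group["object"].map(node_index[dst_t]).to_numpy()
    data[(src_t, rel, dst_t)].edge_index = torch.from_numpy(
        np.stack([src, dst]).astype(np.int64))

# Add reverse edges so message passing is bidirectional (undirected modelling).
data = T.ToUndirected()(data)
```
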
- **Combining Drug-Disease Node Embeddings:**
  - A dictionary was created to map each node ID to its embedding.
  - Drug-disease pairs in `Ground Truth.csv` were represented using their node embeddings.
  - The Hadamard product (element-wise multiplication) of the source and target embeddings was chosen to combine each pair into a single feature vector that captures relational patterns (see the sketch after this block).
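
A sketch of this step, assuming `Ground Truth.csv` exposes `source`, `target`, and `label` columns and that `Embeddings.csv` is indexed by node ID (both assumptions):

```python
import numpy as np
import pandas as pd

embeddings = pd.read_csv("data/Embeddings.csv", index_col=0)
pairs = pd.read_csv("data/Ground Truth.csv")

# Dictionary mapping each node ID to its embedding vector.
emb = {node_id: row.to_numpy() for node_id, row in embeddings.iterrows()}

# Hadamard product (element-wise multiplication) of drug and disease embeddings.
X = np.stack([emb[s] * emb[t] for s, t in zip(pairs["source"], pairs["target"])])
y = pairs["label"].to_numpy()
```
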
- **Model Training and Hyperparameter Tuning:**
  - The classes are moderately imbalanced, so I tracked class-specific metrics such as precision, recall, and F1-score.
  - Stratified splits were used to distribute the values of `y` evenly across the training, validation, and test sets (70% / 15% / 15%).
  - A Random Forest classifier was selected. Hyperparameters were tuned with `RandomizedSearchCV` using 5-fold cross-validation, optimizing for ROC AUC (sketched below).
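
A sketch of the split and search, reusing the `X`/`y` features built above; the parameter grid is illustrative rather than the exact search space used in the notebook:

```python
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import RandomizedSearchCV, train_test_split

# 70% train, 15% validation, 15% test, stratified on the label.
X_train, X_tmp, y_train, y_tmp = train_test_split(
    X, y, test_size=0.30, stratify=y, random_state=42)
X_val, X_test, y_val, y_test = train_test_split(
    X_tmp, y_tmp, test_size=0.50, stratify=y_tmp, random_state=42)

param_distributions = {
    "n_estimators": [200, 400, 800],
    "max_depth": [None, 10, 20, 40],
    "min_samples_leaf": [1, 2, 5],
    "max_features": ["sqrt", "log2"],
}
search = RandomizedSearchCV(
    RandomForestClassifier(class_weight="balanced", random_state=42),
    param_distributions=param_distributions,
    n_iter=20, cv=5, scoring="roc_auc", n_jobs=-1, random_state=42)
search.fit(X_train, y_train)
```
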
- **Final Model:**
  - The best hyperparameters were used to retrain the model on the combined training and validation data.
  - The model was evaluated on the unseen test set (see the sketch after this block).
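
A sketch of the final refit and evaluation, reusing `search` and the splits from the previous sketch:

```python
import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import roc_auc_score

# Refit the best configuration on the combined training + validation data.
final_model = RandomForestClassifier(
    class_weight="balanced", random_state=42, **search.best_params_)
final_model.fit(np.vstack([X_train, X_val]), np.concatenate([y_train, y_val]))

# Evaluate once on the held-out test set.
test_proba = final_model.predict_proba(X_test)[:, 1]
print("Test ROC AUC:", roc_auc_score(y_test, test_proba))
```
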
- **Heterogeneous Graph Representation:**
  - The data include multiple types of entities (e.g., drugs, diseases, genes) and relationships (e.g., "drug-treats-disease", "gene-translates_to-protein").
  - Undirected graph choice: while some relationships are inherently directional (e.g., "biolink:treats", "biolink:prevents", "biolink:causes"), others (e.g., "biolink:same_as", "biolink:gene_associated_with_condition") are undirected. Using an undirected graph allows the model to propagate information bidirectionally across all relationships and simplifies the architecture, but I may lose key directional features of the graph.
- **Embedding Combination Method:**
  - The Hadamard product was used as it preserves relational properties and is computationally efficient (the feature dimension stays equal to the embedding dimension).
  - Alternative methods such as concatenation were considered, but they increase the number of features.
- **Classifier Choice:**
  - The goal here is to predict links for drug-disease pairs; for simplicity and speed, I selected a Random Forest classifier.
  - Class weighting was applied to handle imbalance (`class_weight='balanced'`, illustrated after this block).
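
For reference, `class_weight='balanced'` weights each class by `n_samples / (n_classes * n_samples_in_class)`, so the minority class contributes proportionally more to the split criterion. The snippet below, reusing the `y` labels built earlier, prints the effective weights:

```python
import numpy as np
from sklearn.utils.class_weight import compute_class_weight

classes = np.unique(y)
weights = compute_class_weight(class_weight="balanced", classes=classes, y=y)
print(dict(zip(classes, weights)))  # e.g., {0: w0, 1: w1}
```
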
- **Metrics:**
  - ROC AUC: threshold-independent metric for evaluating probabilistic predictions.
  - Precision-Recall AUC: summarizes the precision/recall trade-off and better reflects performance on the positive (minority) class.
  - Normalized confusion matrix: displays class-specific prediction proportions, normalized by true class size (all three are computed in the sketch after this list).
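
A minimal evaluation sketch computing these three metrics, assuming the `final_model` and the test split from the earlier sketches:

```python
from sklearn.metrics import average_precision_score, confusion_matrix, roc_auc_score

test_proba = final_model.predict_proba(X_test)[:, 1]
test_pred = final_model.predict(X_test)

print("ROC AUC:", roc_auc_score(y_test, test_proba))
print("PR AUC :", average_precision_score(y_test, test_proba))

# Confusion matrix with rows normalized by true class size.
print(confusion_matrix(y_test, test_pred, normalize="true"))
```
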
- **Graph Assumed Undirected:**
  - As mentioned, while some relationships (e.g., "biolink:treats", "biolink:prevents", "biolink:causes") are directed, others (e.g., "biolink:same_as", "biolink:gene_associated_with_condition") are undirected.
  - Using an undirected graph simplifies the model by ensuring bidirectional message passing, but it may lead to information loss for directed relationships (e.g., "biolink:treats").
- **Precomputed Node Embeddings:**
  - The current implementation uses precomputed node embeddings, which limits the model's ability to learn drug-disease-specific embeddings from the graph structure and relationships. It is also unclear whether the precomputed embeddings were generated on the whole graph; if so, this may introduce data leakage into the evaluation. See the Future Work section for suggestions on generating node embeddings.
I could extend this work as follows:
- Leverage the heterogeneous graph to generate node embeddings with a graph neural network (GNN) framework, e.g., `torch_geometric.nn.HeteroConv` layers, or use an encoder to learn node embeddings directly on the graph.
- Split the heterogeneous graph into training and testing edges based on the drug-disease edge types for link prediction. However, it is unclear which `["Drug", RELATIONSHIP, "Disease"].edge_label` to choose for drug-disease link prediction, as considering all links between all drugs and diseases might create a sparse graph (one possible starting point is sketched below).