Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 14 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,16 @@
adni_storage/
abide_storage/
cobre_storage/
corr_storage/
dlbs_storage/
fcon1000_storage/
ixi_storage/
oasis1_storage/
sald_storage/
ants/
model_dumps/
model_dumps/
pretrained/
medvit/
processed_slices/
vit_trained_bak/
camcan_storage/
Empty file added 3dres.py
Empty file.
Binary file added Papers/s41598-024-59578-3.pdf
Binary file not shown.
Binary file added __pycache__/classify_fnn.cpython-312.pyc
Binary file not shown.
Binary file added __pycache__/classify_fnn_bag.cpython-312.pyc
Binary file not shown.
Binary file added __pycache__/cnn_gn.cpython-312.pyc
Binary file not shown.
Binary file added __pycache__/cnn_mx.cpython-312.pyc
Binary file not shown.
Binary file added __pycache__/cnn_mx_att.cpython-312.pyc
Binary file not shown.
Binary file added __pycache__/cnn_mx_att_bl.cpython-312.pyc
Binary file not shown.
Binary file added __pycache__/cnn_mx_att_gate2.cpython-312.pyc
Binary file not shown.
Binary file added __pycache__/cnn_mx_att_res.cpython-312.pyc
Binary file not shown.
Binary file added __pycache__/cnn_mx_cbam.cpython-312.pyc
Binary file not shown.
Binary file added __pycache__/cnn_mx_elu.cpython-312.pyc
Binary file not shown.
Binary file added __pycache__/cnn_mx_lrelu.cpython-312.pyc
Binary file not shown.
Binary file added __pycache__/cnn_mx_multi_att.cpython-312.pyc
Binary file not shown.
Binary file added __pycache__/cnn_mx_res.cpython-312.pyc
Binary file not shown.
Binary file added __pycache__/cnn_mx_spd.cpython-312.pyc
Binary file not shown.
Binary file added __pycache__/cnn_mx_sw.cpython-312.pyc
Binary file not shown.
Binary file added __pycache__/cnn_mx_vl.cpython-312.pyc
Binary file not shown.
Binary file added __pycache__/dataset_cls.cpython-312.pyc
Binary file not shown.
7 changes: 4 additions & 3 deletions adni_preprocess.sh
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,12 @@
count=0

# Define source and target directories
source_folder="adni_storage/ADNI_nii_gz_stripped"
target_folder="adni_storage/ADNI_nii_gz_bias_corrected"
source_folder="camcan_storage/CamCAN_nii_gz_stripped"
target_folder="camcan_storage/CamCAN_nii_gz_bias_corrected"

# Path to N4BiasFieldCorrection binary
n4bias_path="ants/ants-2.5.4/bin/N4BiasFieldCorrection"
# n4bias_path="ants/ants-2.5.4/bin/N4BiasFieldCorrection"
# Use $HOME rather than "~": a tilde inside double quotes is not expanded by
# the shell, so the path would otherwise stay literal and the binary would not be found.
n4bias_path="$HOME/miniconda3/envs/db3/bin/N4BiasFieldCorrection"

# Loop through all .nii.gz files in the source folder
for input_file in "${source_folder}"/*.stripped.nii.gz; do
Expand Down
1 change: 0 additions & 1 deletion adni_storage

This file was deleted.

Binary file added age_distribution.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
4 changes: 2 additions & 2 deletions calculate_agegap.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def plot_with_metrics(data, x_col, y_col, hue_col, title, x_lim):
plt.show()

# For training set
dfres_train = pd.read_csv("predicted_ages_train.csv", sep=",", index_col=0).reset_index()
dfres_train = pd.read_csv("model_dumps/cnn_mx_elu_predicted_ages_train.csv", sep=",", index_col=0).reset_index()
dfres_train = calculate_lowess_yhat_and_agegap(dfres_train)

# Keep only the row with the smallest Age for each SubjectID
Expand All @@ -50,7 +50,7 @@ def plot_with_metrics(data, x_col, y_col, hue_col, title, x_lim):
title="Age gap predictions (Train Set)", x_lim=(40, 100))

# For validation set
dfres_val = pd.read_csv("predicted_ages_val.csv", sep=",", index_col=0).reset_index()
dfres_val = pd.read_csv("model_dumps/cnn_mx_elu_predicted_ages_val.csv", sep=",", index_col=0).reset_index()
dfres_val = calculate_lowess_yhat_and_agegap(dfres_val)

# Keep only the row with the smallest Age for each SubjectID
Expand Down
186 changes: 186 additions & 0 deletions calculate_agegap_da.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,186 @@
import pandas as pd
import numpy as np
import statsmodels.api as sm
import seaborn as sns
from scipy.stats import norm
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
from scipy.interpolate import make_interp_spline, interp1d
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.svm import SVC # Import SVM
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import classification_report, confusion_matrix
from sklearn.mixture import GaussianMixture
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

def calculate_lowess_yhat_and_agegap(dfres):
    """Add a LOWESS-expected predicted age and an absolute age gap to a copy of ``dfres``.

    A LOWESS curve of ``Predicted_Age`` against ``Age`` is fitted on the rows
    of ``dfres`` itself and evaluated at every row's ``Age``; the result is
    stored in ``yhat_lowess``. ``AgeGap`` is the absolute difference between
    ``Predicted_Age`` and ``yhat_lowess``.

    Parameters
    ----------
    dfres : pandas.DataFrame
        Must provide the columns ``Age`` and ``Predicted_Age``.

    Returns
    -------
    pandas.DataFrame
        A new frame (the input is not modified) with the extra columns
        ``yhat_lowess`` and ``AgeGap``. Rows whose ``yhat_lowess`` is NaN are
        removed, and their count is printed.
    """
    result = dfres.copy()

    # Smooth predicted age as a function of chronological age.
    smoothed = sm.nonparametric.lowess(
        result.Predicted_Age.to_numpy(),
        result.Age.to_numpy(),
        frac=0.8,
        it=3,
    )
    fitted_age = smoothed[:, 0]
    fitted_prediction = smoothed[:, 1]

    # Piecewise-linear lookup of the smoothed curve; ages below/above the
    # fitted range map to the constants 0 and 150 respectively.
    expected_prediction = interp1d(
        fitted_age,
        fitted_prediction,
        bounds_error=False,
        kind='linear',
        fill_value=(0, 150),
    )
    result["yhat_lowess"] = expected_prediction(result.Age)

    missing_rows = result.loc[result.yhat_lowess.isna()]
    if len(missing_rows) > 0:
        print("Could not predict lowess yhat in " + str(len(missing_rows)) + " samples")
        result = result.dropna(subset="yhat_lowess")

    # The gap is stored unsigned (magnitude only).
    result["AgeGap"] = (result["Predicted_Age"] - result["yhat_lowess"]).abs()
    return result

def plot_with_metrics(data, x_col, y_col, hue_col, title, x_lim):
    """Show a scatter plot of ``y_col`` vs ``x_col`` with MAE and R² in the title.

    Parameters
    ----------
    data : pandas.DataFrame
        Frame holding the columns named by ``x_col``, ``y_col`` and ``hue_col``.
    x_col, y_col : str
        Columns plotted on the x and y axes; the metrics are computed with
        ``x_col`` as the first argument and ``y_col`` as the second.
    hue_col : str
        Column used to colour the points (``coolwarm`` palette, colour range
        fixed to -12..12).
    title : str
        First line of the plot title; the metrics go on the second line.
    x_lim : sequence
        Arguments forwarded to ``plt.xlim``.
    """
    x_values = data[x_col]
    y_values = data[y_col]
    mae = mean_absolute_error(x_values, y_values)
    r2 = r2_score(x_values, y_values)

    sns.scatterplot(
        data=data,
        x=x_col,
        y=y_col,
        hue=hue_col,
        palette='coolwarm',
        hue_norm=(-12, 12),
    )
    plt.xlim(*x_lim)
    heading = f"{title}\nMAE: {mae:.2f}, R²: {r2:.2f}"
    plt.title(heading)
    plt.xlabel(x_col)
    plt.ylabel(y_col)
    plt.show()

# ---------------------------------------------------------------------------
# Data preparation: load the predicted ages, derive AgeGap, keep one row per
# subject, build the binary AD target and wrap everything in DataLoaders.
# ---------------------------------------------------------------------------

# Training set: read the prediction dump (the first CSV column is read as the
# index and then restored as a regular column) and add yhat_lowess / AgeGap.
dfres_train = pd.read_csv("model_dumps/cnn_mx_elu_predicted_ages_train.csv", sep=",", index_col=0).reset_index()
dfres_train = calculate_lowess_yhat_and_agegap(dfres_train)

# Keep only the row with the smallest Age for each SubjectID
dfres_train = dfres_train.loc[dfres_train.groupby('SubjectID')['Age'].idxmin()]
dfres_train = dfres_train.reset_index(drop=True)

# Validation set: same processing as the training set.
# NOTE(review): calculate_lowess_yhat_and_agegap fits its LOWESS curve on the
# frame it receives, so the validation AgeGap is relative to a curve fitted on
# the validation rows themselves, not on the training cohort -- confirm intended.
dfres_val = pd.read_csv("model_dumps/cnn_mx_elu_predicted_ages_val.csv", sep=",", index_col=0).reset_index()
dfres_val = calculate_lowess_yhat_and_agegap(dfres_val)

# Keep only the row with the smallest Age for each SubjectID
dfres_val = dfres_val.loc[dfres_val.groupby('SubjectID')['Age'].idxmin()]
dfres_val = dfres_val.reset_index(drop=True)


# Step 1: Encode the 'Sex' column as M -> 0, F -> 1 (any other value becomes NaN).
# The encoded column is not used as a feature below; only AgeGap is.
dfres_train['Sex'] = dfres_train['Sex'].map({'M': 0, 'F': 1})
dfres_val['Sex'] = dfres_val['Sex'].map({'M': 0, 'F': 1})

# Step 2: Convert the 'Group' column to binary (AD -> 1, everything else -> 0)
dfres_train['Group_binary'] = dfres_train['Group'].apply(lambda x: 1 if x == 'AD' else 0)
dfres_val['Group_binary'] = dfres_val['Group'].apply(lambda x: 1 if x == 'AD' else 0)

# Step 3: Take the binary target column as a pandas Series (no LabelEncoder is
# involved); these Series are only used for the diagnostic print below.
y_train = dfres_train['Group_binary']
y_val = dfres_val['Group_binary']

print(f"Binary labels for training set: {y_train.unique()}") # To verify the binary classification

# Step 4: Select AgeGap as the single feature column (nothing is dropped);
# these DataFrames are only used for the diagnostic print below.
X_train = dfres_train[['AgeGap']]
X_val = dfres_val[['AgeGap']]

print(f"Features for training set:\n{X_train.head()}")

# Step 5: Re-bind X_*/y_* as float32 tensors for PyTorch. This replaces the
# pandas objects created above; features come from the one-column AgeGap frame.
X_train = torch.tensor(dfres_train[['AgeGap']].values, dtype=torch.float32) # Feature
y_train = torch.tensor(dfres_train['Group_binary'].values, dtype=torch.float32) # Target (binary)

X_val = torch.tensor(dfres_val[['AgeGap']].values, dtype=torch.float32) # Feature
y_val = torch.tensor(dfres_val['Group_binary'].values, dtype=torch.float32) # Target (binary)

# Step 6: Wrap the tensors in datasets and loaders (batches of 32; only the
# training loader shuffles).
train_dataset = TensorDataset(X_train, y_train)
val_dataset = TensorDataset(X_val, y_val)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

# Step 3: Define the Model with Attention Mechanism
class AttentionDNN(nn.Module):
    """Feed-forward binary classifier with a single-head self-attention branch.

    Pipeline (see ``forward``): Linear -> ReLU -> Dropout -> self-attention ->
    concatenate(hidden, attention output) -> Linear -> ReLU -> Dropout ->
    Linear -> sigmoid.

    Args:
        input_dim: Number of input features per sample. Defaults to 1.
        hidden_dim: Width of the first dense layer and of the attention
            embedding. Defaults to 64.
        dropout: Dropout probability applied after both dense layers.
            Defaults to 0.2.

    NOTE(review): ``forward`` inserts a sequence axis of length 1 before the
    attention layer, so every sample attends over a single token and the
    attention weights are trivially 1. The branch therefore acts as an extra
    learned projection of the hidden features rather than attention across
    several positions.
    """

    def __init__(self, input_dim: int = 1, hidden_dim: int = 64, dropout: float = 0.2) -> None:
        super().__init__()
        self.dense1 = nn.Linear(input_dim, hidden_dim)
        self.dropout1 = nn.Dropout(dropout)

        # Single-head self-attention; batch_first so inputs are (batch, seq, feature).
        self.attention = nn.MultiheadAttention(embed_dim=hidden_dim, num_heads=1, batch_first=True)

        # Input width doubles because hidden features and attention output are concatenated.
        self.dense2 = nn.Linear(2 * hidden_dim, 32)
        self.dropout2 = nn.Dropout(dropout)
        self.output = nn.Linear(32, 1)  # Single probability per sample

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return per-sample probabilities in [0, 1].

        Args:
            x: Input features, expected shape ``(batch, input_dim)``.

        Returns:
            Tensor of shape ``(batch, 1)`` holding sigmoid outputs.
        """
        hidden = torch.relu(self.dense1(x))
        # Add a length-1 sequence axis: (batch, hidden) -> (batch, 1, hidden).
        hidden = self.dropout1(hidden).unsqueeze(1)

        # Self-attention: the same tensor is used as query, key and value.
        attn_output, _ = self.attention(hidden, hidden, hidden)

        # Concatenate along the feature axis: (batch, 1, 2 * hidden).
        combined = torch.cat((hidden, attn_output), dim=-1)
        combined = self.dropout2(torch.relu(self.dense2(combined)))

        # Drop the sequence axis before the output layer: (batch, 32) -> (batch, 1).
        return torch.sigmoid(self.output(combined.squeeze(1)))

# Step 4: Initialize the model, loss function, and optimizer
model = AttentionDNN()
criterion = nn.BCELoss()  # The model already applies a sigmoid, so plain BCE is used
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Step 5: Train the model
num_epochs = 100
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    for inputs, labels in train_loader:
        optimizer.zero_grad()

        outputs = model(inputs)
        # Labels arrive as (batch,); BCELoss needs them to match the (batch, 1) outputs.
        loss = criterion(outputs, labels.view(-1, 1))
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    # Mean of the per-batch losses for this epoch.
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(train_loader)}")

# Step 6: Evaluate the model
model.eval()
y_pred_prob = []
y_true = []

with torch.no_grad():
    for inputs, labels in val_loader:
        outputs = model(inputs)
        y_pred_prob.append(outputs)
        y_true.append(labels)

y_pred_prob = torch.cat(y_pred_prob)
y_true = torch.cat(y_true)

# Convert to binary labels (0 or 1)
y_pred = (y_pred_prob > 0.5).float()

# scikit-learn expects 1-D label arrays: y_pred is (N, 1) float here, so
# flatten it and cast both sides to integer class ids.
y_true_labels = y_true.view(-1).numpy().astype(int)
y_pred_labels = y_pred.view(-1).numpy().astype(int)

# Passing the label set explicitly keeps the report and the 2x2 matrix
# aligned with class_names even when one class is absent from the
# validation labels and predictions.
class_ids = [0, 1]
class_names = ['Not AD', 'AD']

# Classification Report
print("Classification Report:")
print(classification_report(y_true_labels, y_pred_labels, labels=class_ids,
                            target_names=class_names, zero_division=0))

# Confusion Matrix
conf_matrix = confusion_matrix(y_true_labels, y_pred_labels, labels=class_ids)
plt.figure(figsize=(8, 6))
sns.heatmap(conf_matrix, annot=True, fmt='d', cmap='Blues', xticklabels=class_names, yticklabels=class_names)
plt.title('Confusion Matrix')
plt.xlabel('Predicted Labels')
plt.ylabel('True Labels')
plt.show()
114 changes: 114 additions & 0 deletions calculate_agegap_gmm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
import pandas as pd
import numpy as np
import statsmodels.api as sm
import seaborn as sns
from scipy.stats import norm
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
from scipy.interpolate import make_interp_spline, interp1d
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.svm import SVC # Import SVM
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import classification_report, confusion_matrix
from sklearn.mixture import GaussianMixture
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score



def calculate_lowess_yhat_and_agegap(dfres):
    """Return a copy of ``dfres`` extended with ``yhat_lowess`` and ``AgeGap``.

    ``yhat_lowess`` is the value of a LOWESS curve (``Predicted_Age`` smoothed
    over ``Age``, fitted on the rows of ``dfres``) at each row's ``Age``.
    ``AgeGap`` is ``|Predicted_Age - yhat_lowess|``.

    Parameters
    ----------
    dfres : pandas.DataFrame
        Must provide the columns ``Age`` and ``Predicted_Age``.

    Returns
    -------
    pandas.DataFrame
        New frame with the two extra columns. Rows where ``yhat_lowess`` is
        NaN are dropped after printing how many were affected.
    """
    frame = dfres.copy()

    # LOWESS of predicted age against chronological age; column 0 of the
    # result holds the ages, column 1 the smoothed predictions.
    curve = sm.nonparametric.lowess(
        frame.Predicted_Age.to_numpy(),
        frame.Age.to_numpy(),
        frac=0.8,
        it=3,
    )
    curve_lookup = interp1d(
        curve[:, 0],
        curve[:, 1],
        bounds_error=False,
        kind='linear',
        fill_value=(0, 150),  # constants used below / above the fitted age range
    )
    frame["yhat_lowess"] = curve_lookup(frame.Age)

    unresolved = frame.loc[frame.yhat_lowess.isna()]
    if len(unresolved) > 0:
        print("Could not predict lowess yhat in " + str(len(unresolved)) + " samples")
        frame = frame.dropna(subset="yhat_lowess")

    signed_gap = frame["Predicted_Age"] - frame["yhat_lowess"]
    frame["AgeGap"] = signed_gap.abs()  # only the magnitude is kept
    return frame

def plot_with_metrics(data, x_col, y_col, hue_col, title, x_lim):
    """Draw and show a scatter plot annotated with MAE and R².

    Parameters
    ----------
    data : pandas.DataFrame
        Source of the ``x_col``, ``y_col`` and ``hue_col`` columns.
    x_col, y_col : str
        Axis columns. Both metrics are called as ``metric(data[x_col], data[y_col])``.
    hue_col : str
        Column that drives point colour (``coolwarm``, normalised to -12..12).
    title : str
        Title text; a second line with the metrics is appended.
    x_lim : sequence
        Unpacked into ``plt.xlim``.
    """
    reference = data[x_col]
    estimate = data[y_col]
    mae = mean_absolute_error(reference, estimate)
    r2 = r2_score(reference, estimate)

    scatter_options = {
        "data": data,
        "x": x_col,
        "y": y_col,
        "hue": hue_col,
        "palette": 'coolwarm',
        "hue_norm": (-12, 12),
    }
    sns.scatterplot(**scatter_options)

    plt.xlim(*x_lim)
    plt.title(f"{title}\nMAE: {mae:.2f}, R²: {r2:.2f}")
    plt.xlabel(x_col)
    plt.ylabel(y_col)
    plt.show()

# ---------------------------------------------------------------------------
# Data preparation: load the predicted ages, derive AgeGap, keep one row per
# subject and build the binary AD target plus the AgeGap feature frame.
# ---------------------------------------------------------------------------

# Training set: read the prediction dump (the first CSV column is read as the
# index and then restored as a regular column) and add yhat_lowess / AgeGap.
dfres_train = pd.read_csv("model_dumps/cnn_mx_elu_predicted_ages_train.csv", sep=",", index_col=0).reset_index()
dfres_train = calculate_lowess_yhat_and_agegap(dfres_train)

# Keep only the row with the smallest Age for each SubjectID
dfres_train = dfres_train.loc[dfres_train.groupby('SubjectID')['Age'].idxmin()]
dfres_train = dfres_train.reset_index(drop=True)

# Validation set: same processing as the training set.
# NOTE(review): calculate_lowess_yhat_and_agegap fits its LOWESS curve on the
# frame it receives, so the validation AgeGap is relative to a curve fitted on
# the validation rows themselves, not on the training cohort -- confirm intended.
dfres_val = pd.read_csv("model_dumps/cnn_mx_elu_predicted_ages_val.csv", sep=",", index_col=0).reset_index()
dfres_val = calculate_lowess_yhat_and_agegap(dfres_val)

# Keep only the row with the smallest Age for each SubjectID
dfres_val = dfres_val.loc[dfres_val.groupby('SubjectID')['Age'].idxmin()]
dfres_val = dfres_val.reset_index(drop=True)


# Step 1: Encode the 'Sex' column as M -> 0, F -> 1 (any other value becomes NaN).
# The encoded column is not used as a feature below; only AgeGap is.
dfres_train['Sex'] = dfres_train['Sex'].map({'M': 0, 'F': 1})
dfres_val['Sex'] = dfres_val['Sex'].map({'M': 0, 'F': 1})

# Step 2: Convert the 'Group' column to binary (AD -> 1, everything else -> 0)
dfres_train['Group_binary'] = dfres_train['Group'].apply(lambda x: 1 if x == 'AD' else 0)
dfres_val['Group_binary'] = dfres_val['Group'].apply(lambda x: 1 if x == 'AD' else 0)

# Step 3: Take the binary target column as a pandas Series (no LabelEncoder is
# involved).
y_train = dfres_train['Group_binary']
y_val = dfres_val['Group_binary']

print(f"Binary labels for training set: {y_train.unique()}") # To verify the binary classification

# Step 4: Select AgeGap as the single feature column (nothing is dropped from
# the source frames).
X_train = dfres_train[['AgeGap']]
X_val = dfres_val[['AgeGap']]

print(f"Features for training set:\n{X_train.head()}")

# Step 5: Train the Gaussian Mixture Model (GMM)
# The mixture is fitted without labels, on the AgeGap feature only.
gmm = GaussianMixture(n_components=2, covariance_type='full', random_state=42)
gmm.fit(X_train)

# Step 6: Predict component probabilities and classify
# predict_proba returns one column per mixture component, not per class.
y_pred_prob = gmm.predict_proba(X_val)
val_components = np.argmax(y_pred_prob, axis=1)

# Mixture components carry no inherent class meaning and their order is
# arbitrary, so component index 1 is not guaranteed to be "AD". Use the
# training labels to decide which component represents AD: the one whose
# training members have the highest fraction of AD subjects.
train_components = gmm.predict(X_train)
train_labels = np.asarray(y_train)
ad_fraction = [
    train_labels[train_components == k].mean() if np.any(train_components == k) else 0.0
    for k in range(gmm.n_components)
]
ad_component = int(np.argmax(ad_fraction))
y_pred = (val_components == ad_component).astype(int)  # 1 = AD, 0 = not AD

# Get the class labels (binary: 0 = not AD, 1 = AD)
# Passing the label set explicitly keeps the report and the 2x2 matrix
# aligned with class_names even when one class is absent.
class_ids = [0, 1]
class_names = ['Not AD', 'AD']

# Step 7: Evaluate the GMM model
print("Classification Report:")
print(classification_report(y_val, y_pred, labels=class_ids,
                            target_names=class_names, zero_division=0))

print("Confusion Matrix:")
conf_matrix = confusion_matrix(y_val, y_pred, labels=class_ids)

# Create a heatmap with seaborn
plt.figure(figsize=(8, 6))
sns.heatmap(conf_matrix, annot=True, fmt='d', cmap='Blues', xticklabels=class_names, yticklabels=class_names)
plt.title('Confusion Matrix')
plt.xlabel('Predicted Labels')
plt.ylabel('True Labels')
plt.show()

# Step 8: (Optional) Evaluate overall accuracy
accuracy = accuracy_score(y_val, y_pred)
print(f"Accuracy: {accuracy:.2f}")
Loading