From ca15a176cd1e102901da873e1c302017cc7a2f17 Mon Sep 17 00:00:00 2001
From: Rashed Almansoori
Date: Fri, 21 Nov 2025 18:27:34 -0500
Subject: [PATCH 1/7] convert preprocessing to PyTorch, add a data download
 script, and convert all of CNN_functions except the parts that rely on the
 VAE (not yet converted)

---
 .gitignore               |   4 +
 CNN_functions_pytorch.py | 976 +++++++++++++++++++++++++++++++++++++++
 download_data.py         |   6 +
 preprocessing_new.py     | 223 +++++++++
 4 files changed, 1209 insertions(+)
 create mode 100644 CNN_functions_pytorch.py
 create mode 100644 download_data.py
 create mode 100644 preprocessing_new.py

diff --git a/.gitignore b/.gitignore
index 3b58d13..24fb5f6 100644
--- a/.gitignore
+++ b/.gitignore
@@ -3,6 +3,10 @@ __pycache__/
 *.py[cod]
 *$py.class
 
+#Rashed
+data/
+chexpert.pkl
+
 # C extensions
 *.so
 
diff --git a/CNN_functions_pytorch.py b/CNN_functions_pytorch.py
new file mode 100644
index 0000000..fd6de12
--- /dev/null
+++ b/CNN_functions_pytorch.py
@@ -0,0 +1,976 @@
+import random
+import psutil
+import GPUtil
+import sys
+import time
+import gc
+import os
+
+import numpy as np
+import torch
+from torch import nn, optim
+from torch.utils.data import Dataset, DataLoader
+
+from sklearn.metrics import accuracy_score, balanced_accuracy_score
+from sklearn import metrics as sk_metrics  # used by train_embedding_cnn for the regression metric
+
+from VAE import Resnet_VAE, make_pretrain_encoder  # assumed PyTorch implementation
+
+
+# -------------------------------------------------------------------------
+# Utility: memory usage (CPU + GPU)
+# -------------------------------------------------------------------------
+
+def get_memory_usage():
+    """Returns the current CPU and GPU memory usage in GB."""
+    process = psutil.Process()
+    cpu_memory = process.memory_info().rss / (1024 ** 3)
+
+    gpus = GPUtil.getGPUs()
+    # GPUtil reports memoryUsed in MB, so convert to GB to match the CPU value
+    gpu_memory = sum(gpu.memoryUsed for gpu in gpus) / 1024 if gpus else 0
+
+    return cpu_memory, gpu_memory
+
+
+# -------------------------------------------------------------------------
+# Torch Dataset for CNN / VAE
+# -------------------------------------------------------------------------
+
+class ChestDataset(Dataset):
+    """
+    Wraps numpy arrays into a PyTorch Dataset.
+
+    X: numpy array, shape (N, H, W, C) or (N, H, W)
+    meta: numpy array, shape (N, >=1), where meta[:,0] is the label.
+          Remaining columns are auxiliary info (sex, AP/PA, age, etc.).
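+
+    A minimal usage sketch, assuming arrays shaped like the ones produced by
+    preprocessing_new.py (labels, sex_codes and view_codes are illustrative
+    1-D integer arrays, not names defined in this repo):
+
+        X = np.zeros((8, 128, 128), dtype=np.float32)
+        meta = np.stack([labels, sex_codes, view_codes], axis=1)  # (8, 3)
+        ds = ChestDataset(X, meta, add_info=True)
+        x, y, other = ds[0]  # (1, 128, 128) image, scalar label, (2,) extras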
+ """ + + def __init__(self, X, meta, add_info=True): + if X.ndim == 3: # (N, H, W) + X = X[..., None] # add channel dimension + + # (N, H, W, C) -> (N, C, H, W) + self.images = torch.from_numpy(X).float().permute(0, 3, 1, 2) + self.labels = torch.from_numpy(meta[:, 0]).long() + + self.add_info = add_info and meta.shape[1] > 1 + if self.add_info: + self.other = torch.from_numpy(meta[:, 1:]).float() + else: + self.other = None + + def __len__(self): + return self.images.shape[0] + + def __getitem__(self, idx): + x = self.images[idx] + y = self.labels[idx] + if self.add_info and self.other is not None: + return x, y, self.other[idx] + else: + return x, y + + +# ------------------------------------------------------------------------- +# Optional discretizer layer (not used by CNN but kept for compatibility) +# ------------------------------------------------------------------------- + +class Discretizer_layer(nn.Module): + def __init__(self, initial_value=1.0, input_shape=16, dtype=torch.float32): + super().__init__() + init = torch.full((input_shape,), float(initial_value), dtype=dtype) + self.bin = nn.Parameter(init) + + def forward(self, x): + # x: (..., input_shape) + # returns 1 where x > bin, else 0 + # broadcast self.bin to match x's shape on last dimension + return (x > self.bin).float() + + +# ------------------------------------------------------------------------- +# CNN definition (replacement for create_CNN) +# ------------------------------------------------------------------------- + +class SimpleCNN(nn.Module): + def __init__(self, grid_params, input_shape, num_classes, add_info): + super().__init__() + num_layer = grid_params.num_layer + dropout = grid_params.dropout + filter_size = grid_params.filter_size + # Use latent_dim as last_num_filters to mirror TF code + last_num_filters = grid_params.latent_dim if hasattr(grid_params, "latent_dim") else grid_params.end_dim_enc + + # input_shape from end_to_end_train is (H, W, C) or (C, H, W). + # We'll assume CNN sees (C, H, W). 
+ if len(input_shape) == 3: + if input_shape[0] in [1, 3]: + in_channels, h, w = input_shape + else: + # assume H, W, C + h, w, in_channels = input_shape + else: + raise ValueError(f"Unexpected input_shape: {input_shape}") + + layers = [] + in_ch = in_channels + + # Convolutional blocks (all but last) + for lay_num in range(num_layer - 1): + out_ch = 16 * (2 ** lay_num) + layers.append(nn.Dropout(p=dropout / (2 ** lay_num))) + layers.append(nn.Conv2d(in_ch, out_ch, kernel_size=filter_size, padding=filter_size // 2)) + layers.append(nn.ReLU(inplace=True)) + layers.append(nn.MaxPool2d(kernel_size=2, stride=2)) + layers.append(nn.BatchNorm2d(out_ch)) + in_ch = out_ch + + # Final conv block (no pooling yet) + layers.append(nn.Dropout(p=dropout / (2 ** num_layer))) + layers.append(nn.Conv2d(in_ch, last_num_filters, kernel_size=filter_size, padding=filter_size // 2)) + layers.append(nn.ReLU(inplace=True)) + layers.append(nn.BatchNorm2d(last_num_filters)) + + self.features = nn.Sequential(*layers) + + # determine spatial size after conv stack + with torch.no_grad(): + dummy = torch.zeros(1, in_channels, h, w) + feat = self.features(dummy) + _, c, h_f, w_f = feat.shape + + # global pooling over remaining spatial extent + self.pool = nn.MaxPool2d(kernel_size=(h_f, w_f)) + self.classifier = nn.Linear(last_num_filters, num_classes) + + def forward(self, x, return_embedding=False): + feat = self.features(x) # (N, C, H, W) + pooled = self.pool(feat) # (N, C, 1, 1) + emb = pooled.view(pooled.size(0), -1) # (N, C) + logits = self.classifier(emb) # (N, num_classes) + if return_embedding: + return emb, logits + return logits + + +def create_CNN(grid_params, input_shape, num_classes, add_info): + """ + PyTorch replacement for the original create_CNN. + """ + model = SimpleCNN(grid_params, input_shape, num_classes, add_info) + print('num trainable params', sum(p.numel() for p in model.parameters() if p.requires_grad)) + return model + + +# ------------------------------------------------------------------------- +# CNN loading & evaluation (replacement for load_CNN) +# ------------------------------------------------------------------------- + +def _build_loaders_from_numpy(train_data, test_data, val_data, add_info, batch_size): + X_train, meta_train = train_data + X_test, meta_test = test_data + X_val, meta_val = val_data + + train_ds = ChestDataset(X_train, meta_train, add_info=add_info) + test_ds = ChestDataset(X_test, meta_test, add_info=add_info) + val_ds = ChestDataset(X_val, meta_val, add_info=add_info) + + train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True) + test_loader = DataLoader(test_ds, batch_size=batch_size, shuffle=False) + val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False) + + return train_loader, test_loader, val_loader + + +def load_CNN(grid_params, input_shape, num_classes, debugging, ckpt, manager, + checkpoint_path, train_data, test_data, val_data, add_info): + """ + PyTorch version of load_CNN. + ckpt and manager arguments are kept for API compatibility, but not used. 
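+
+    Returns (model, None, None, metrics), where metrics is
+    [[train_loss, train_acc], [val_loss, val_acc], [test_loss, test_acc]]
+    and the two None values stand in for the TF ckpt/manager objects.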
+ """ + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + + if grid_params.load_pretrain_model: + model = load_pretrain_costume(grid_params, input_shape, num_classes) + else: + model = create_CNN(grid_params, input_shape, num_classes, add_info) + + model.to(device) + + learning_rate = grid_params.fine_tune_rate + optimizer = optim.Adam(model.parameters(), lr=learning_rate) + criterion = nn.CrossEntropyLoss() + + # load checkpoint if exists + os.makedirs(checkpoint_path, exist_ok=True) + ckpt_file = os.path.join(checkpoint_path, 'cnn.pt') + if os.path.exists(ckpt_file): + state = torch.load(ckpt_file, map_location=device) + model.load_state_dict(state['model']) + optimizer.load_state_dict(state['optimizer']) + print(f"Restored CNN from {ckpt_file}") + + if debugging: + first_param = next(model.parameters()) + print('DEBUGGING CNN weight sample', first_param.view(-1)[:5].detach().cpu().numpy()) + + train_loader, test_loader, val_loader = _build_loaders_from_numpy( + train_data, test_data, val_data, add_info, grid_params.batch_size + ) + + def eval_loader(loader): + model.eval() + total_loss = 0.0 + all_preds = [] + all_labels = [] + with torch.no_grad(): + for batch in loader: + if add_info and len(batch) == 3: + images, labels, _ = batch + else: + images, labels = batch + images = images.to(device) + labels = labels.to(device) + + logits = model(images) + loss = criterion(logits, labels) + total_loss += loss.item() * images.size(0) + + preds = torch.argmax(logits, dim=1) + all_preds.append(preds.cpu()) + all_labels.append(labels.cpu()) + + total_loss /= len(loader.dataset) + all_preds = torch.cat(all_preds).numpy() + all_labels = torch.cat(all_labels).numpy() + acc = accuracy_score(all_labels, all_preds) + return total_loss, acc + + train_loss, train_acc = eval_loader(train_loader) + test_loss, test_acc = eval_loader(test_loader) + val_loss, val_acc = eval_loader(val_loader) + + print('CNN train loss', train_loss, 'train acc', train_acc, + 'test loss', test_loss, 'test acc', test_acc, flush=True) + + # keep return signature: model, ckpt, manager, metrics + return model, None, None, [[train_loss, train_acc], + [val_loss, val_acc], + [test_loss, test_acc]] + + +# ------------------------------------------------------------------------- +# VAE helpers (assume Resnet_VAE is a PyTorch module with similar API) +# ------------------------------------------------------------------------- + +def load_VAE(params, add_info, num_classes, input_shape, checkpoint_path, path=''): + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + + model = Resnet_VAE( + filter_size=params.filter_size, + num_layer=params.num_layer, + input_shape=input_shape, + batchnorm_integration=params.batchnorm_integration, + shortcut=params.shortcut, + activation=params.activation, + num_filter_encoder=params.num_filter_encoder, + strides_encoder=params.strides_encoder, + num_filter_decoder=params.num_filter_decoder, + strides_decoder=params.strides_decoder, + latent_dim=params.latent_dim, + end_dim_enc=params.end_dim_enc, + learning_rate=params.learning_rate, + semi_supervised=params.semi_supervised, + num_classes=num_classes, + dropout=params.dropout, + load_pretrain_model=params.load_pretrain_model, + add_info=add_info, + loss_weights=params.loss_weights, + VAE_fine_tune=params.VAE_fine_tune, + path=path, + use_KLD_anneal=params.use_KLD_anneal + ) + + model.to(device) + + os.makedirs(checkpoint_path, exist_ok=True) + ckpt_file = os.path.join(checkpoint_path, 'vae.pt') + if 
os.path.exists(ckpt_file): + state = torch.load(ckpt_file, map_location=device) + model.load_state_dict(state['model']) + print(f"Restored VAE from {ckpt_file}") + + return model + + +def load_VAE_and_eval(params, input_shape, num_classes, debugging, ckpt, manager, + checkpoint_path, train_dataset, test_dataset, val_dataset, add_info): + """ + PyTorch version of load_VAE_and_eval. + Assumes train_dataset, test_dataset, val_dataset are DataLoaders. + """ + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + model = load_VAE(params, add_info, num_classes, input_shape, checkpoint_path) + + if debugging: + first_param = next(model.parameters()) + print('DEBUGGING show example VAE weight sample', + first_param.view(-1)[:5].detach().cpu().numpy()) + + def eval_loader(loader): + model.eval() + total_losses = None + num_rounds = 0 + with torch.no_grad(): + for batch in loader: + if add_info and len(batch) == 3: + x, y, other = batch + else: + x, y = batch + other = None + x = x.to(device) + y = y.to(device) + if other is not None: + other = other.to(device) + + # expected API from Resnet_VAE: model.evaluate_batch returns + # (loss, rec_loss, kld, clf_loss, mae, mse) or similar + losses = model.evaluate_batch(x, y, other) + losses_np = [float(l) for l in losses] + if total_losses is None: + total_losses = np.zeros(len(losses_np), dtype=np.float64) + total_losses += np.array(losses_np) + num_rounds += 1 + + if total_losses is None: + return [0.0] + total_losses /= max(num_rounds, 1) + return total_losses.tolist() + + train_losses = eval_loader(train_dataset) + test_losses = eval_loader(test_dataset) + val_losses = eval_loader(val_dataset) + + return model, None, None, [[train_losses], [val_losses], [test_losses]] + + +# ------------------------------------------------------------------------- +# Pretrained backbone integration (replacement for load_pretrain_costume) +# ------------------------------------------------------------------------- + +def load_pretrain_costume(grid_params, input_shape, num_classes): + """ + Placeholder for integrating a pretrained backbone in PyTorch. + In the TF version this used EfficientNetB7; here you can load any + torchvision model (e.g. resnet18) and attach a small head. + For now, we just use SimpleCNN as a stand-in. 
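+
+    A minimal sketch of such a swap (assumes torchvision is available;
+    resnet18 and the 1-channel stem replacement are illustrative, not part
+    of this codebase):
+
+        import torchvision.models as tv_models
+        backbone = tv_models.resnet18(weights="IMAGENET1K_V1")
+        # NIH images are grayscale, so adapt the stem to 1 input channel
+        backbone.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2,
+                                   padding=3, bias=False)
+        backbone.fc = nn.Linear(backbone.fc.in_features, num_classes)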
+ """ + print("WARNING: load_pretrain_costume currently uses SimpleCNN as a placeholder.") + model = SimpleCNN(grid_params, input_shape, num_classes, add_info=False) + return model + + +# ------------------------------------------------------------------------- +# CNN training (replacement for train_embedding_cnn) +# ------------------------------------------------------------------------- + +def train_embedding_cnn(grid_params, val_data, input_shape, + num_classes, checkpoint_path, load_data, split, + train_data, add_info, acc_stop=True): + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + + X_train, meta_train = train_data + X_val, meta_val = val_data + + train_ds = ChestDataset(X_train, meta_train, add_info=add_info) + val_ds = ChestDataset(X_val, meta_val, add_info=add_info) + + train_loader = DataLoader(train_ds, batch_size=grid_params.batch_size, shuffle=True) + val_loader = DataLoader(val_ds, batch_size=grid_params.batch_size, shuffle=False) + + if grid_params.load_pretrain_model: + model = load_pretrain_costume(grid_params, input_shape, num_classes) + else: + model = create_CNN(grid_params, input_shape, num_classes, add_info) + + model.to(device) + + learning_rate = grid_params.learning_rate + optimizer = optim.Adam(model.parameters(), lr=learning_rate) + if num_classes == 1: + criterion = nn.MSELoss() + else: + criterion = nn.CrossEntropyLoss() + + os.makedirs(checkpoint_path, exist_ok=True) + ckpt_file = os.path.join(checkpoint_path, f'cnn_split{split}.pt') + + if acc_stop: + best_val_metric = 0.0 + else: + best_val_metric = float('inf') + + if load_data and os.path.exists(ckpt_file): + state = torch.load(ckpt_file, map_location=device) + model.load_state_dict(state['model']) + optimizer.load_state_dict(state['optimizer']) + print(f"Restored CNN from {ckpt_file}") + + for epoch in range(grid_params.epochs): + model.train() + running_loss = 0.0 + + for batch in train_loader: + if add_info and len(batch) == 3: + inputs, labels, _ = batch + else: + inputs, labels = batch + + inputs = inputs.to(device) + labels = labels.to(device) + + optimizer.zero_grad() + logits = model(inputs) + if num_classes == 1: + labels_f = labels.float().view(-1, 1) + loss = criterion(logits, labels_f) + else: + loss = criterion(logits, labels) + loss.backward() + optimizer.step() + + running_loss += loss.item() * inputs.size(0) + + train_loss = running_loss / len(train_loader.dataset) + + # validation + model.eval() + val_loss = 0.0 + all_logits = [] + all_labels = [] + with torch.no_grad(): + for batch in val_loader: + if add_info and len(batch) == 3: + inputs, labels, _ = batch + else: + inputs, labels = batch + + inputs = inputs.to(device) + labels = labels.to(device) + + logits = model(inputs) + if num_classes == 1: + labels_f = labels.float().view(-1, 1) + loss = criterion(logits, labels_f) + else: + loss = criterion(logits, labels) + val_loss += loss.item() * inputs.size(0) + + all_logits.append(logits.cpu()) + all_labels.append(labels.cpu()) + + val_loss /= len(val_loader.dataset) + all_logits = torch.cat(all_logits) + all_labels = torch.cat(all_labels) + if num_classes == 1: + preds = all_logits.view(-1).detach().numpy() + labels_np = all_labels.detach().numpy() + val_acc = -sk_metrics.mean_squared_error(labels_np, preds) + else: + preds = all_logits.argmax(dim=1).numpy() + labels_np = all_labels.numpy() + val_acc = accuracy_score(labels_np, preds) + + print(f"Epoch {epoch+1}/{grid_params.epochs} " + f"- train_loss: {train_loss:.4f} val_loss: {val_loss:.4f} val_acc: 
{val_acc:.4f}", + flush=True) + + improved = False + if acc_stop: + if val_acc > best_val_metric: + best_val_metric = val_acc + improved = True + else: + if val_loss < best_val_metric: + best_val_metric = val_loss + improved = True + + if improved: + torch.save({'model': model.state_dict(), + 'optimizer': optimizer.state_dict()}, + ckpt_file) + first_param = next(model.parameters()) + print('Saved CNN weights sample', first_param.view(-1)[:5].detach().cpu().numpy(), flush=True) + + # keep original return signature (no ckpt/manager objects in PyTorch version) + return model, None, None + + +# ------------------------------------------------------------------------- +# VAE training helpers (validate_data, beta_scheduler, load_model, train_embedding_VAE) +# ------------------------------------------------------------------------- + +def validate_data(val_dataset, add_info, model, beta=0): + """ + Evaluate VAE on a validation DataLoader. + Expects model.execute_net_xy(x, y, other, training=False, beta=beta) + to return: (loss, rec_loss, kld, clf, other_losses, pred, z, extra) + """ + device = next(model.parameters()).device + losses = np.zeros((4 + 2), dtype=np.float64) + num_rounds = 0 + predictions = [] + all_z = [] + gt = [] + + model.eval() + with torch.no_grad(): + for batch in val_dataset: + if add_info and len(batch) == 3: + test_x, test_y, test_other = batch + else: + test_x, test_y = batch + test_other = None + + test_x = test_x.to(device) + test_y = test_y.to(device) + if test_other is not None: + test_other = test_other.to(device) + + loss, rec_loss, kld, clf, other_losses, pred, z, _ = model.execute_net_xy( + test_x, test_y, test_other, training=False, beta=beta + ) + + losses[0] += float(loss) + losses[1] += float(rec_loss) + losses[2] += float(kld) + losses[3] += float(clf) + predictions.extend(pred.cpu().numpy().tolist()) + gt.extend(test_y.cpu().numpy().tolist()) + + for i, tmp_loss in enumerate(other_losses): + losses[4 + i] += float(tmp_loss.mean()) + + all_z.append([z.cpu().numpy(), test_y.cpu().numpy().tolist()]) + num_rounds += 1 + + return all_z, predictions, gt, losses, num_rounds + + +def beta_scheduler(epoch, total_epochs, epoch_steps=100, max_beta=1.0): + # Linear annealing, like original + return min(max_beta, (epoch % epoch_steps) / epoch_steps) + + +def load_model(params, num_classes, add_info, checkpoint_path, checkpoint_path_tmp, + le_warmup, input_shape, encoder, init=True): + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + + model = Resnet_VAE( + filter_size=params.filter_size, + num_layer=params.num_layer, + input_shape=input_shape, + batchnorm_integration=params.batchnorm_integration, + shortcut=params.shortcut, + activation=params.activation, + num_filter_encoder=params.num_filter_encoder, + strides_encoder=params.strides_encoder, + num_filter_decoder=params.num_filter_decoder, + strides_decoder=params.strides_decoder, + latent_dim=params.latent_dim, + end_dim_enc=params.end_dim_enc, + learning_rate=params.learning_rate, + semi_supervised=params.semi_supervised, + num_classes=num_classes, + dropout=params.dropout, + load_pretrain_model=params.load_pretrain_model, + add_info=add_info, + loss_weights=params.loss_weights, + VAE_fine_tune=params.VAE_fine_tune, + use_GAN=params.GAN, + use_KLD_anneal=params.use_KLD_anneal, + le_warmup=le_warmup, + gauss_std=params.gauss_std, + encoder=encoder + ) + + model.to(device) + + os.makedirs(checkpoint_path, exist_ok=True) + os.makedirs(checkpoint_path_tmp, exist_ok=True) + ckpt_file = 
os.path.join(checkpoint_path, 'vae.pt') + ckpt_tmp_file = os.path.join(checkpoint_path_tmp, 'vae_tmp.pt') + + if init and os.path.exists(ckpt_file): + state = torch.load(ckpt_file, map_location=device) + model.load_state_dict(state['model']) + print(f"Restored VAE from {ckpt_file}") + torch.save({'model': model.state_dict()}, ckpt_tmp_file) + elif not init and os.path.exists(ckpt_tmp_file): + state = torch.load(ckpt_tmp_file, map_location=device) + model.load_state_dict(state['model']) + print(f"Restored temp VAE from {ckpt_tmp_file}") + else: + # fresh model, save initial temp + torch.save({'model': model.state_dict()}, ckpt_tmp_file) + + # mimic original return: model, manager, ckpt, manager_tmp, ckpt_tmp + return model, None, None, None, None + + +def train_embedding_VAE(params, train_data, test_data, val_data, input_shape, + num_classes, checkpoint_path, load_data, split, + add_info, checkpoint_path_tmp, acc_stop=True): + """ + PyTorch version of train_embedding_VAE. + + Expects Resnet_VAE with: + - attributes: semi_supervised + - methods: + * train_semi(train_loader, beta) + * train(train_loader) + * evaluate_(loader, verbose) + * execute_net_xy(...) + """ + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + + encoder = None + if params.load_pretrain_model: + encoder = make_pretrain_encoder( + params.filter_size, + num_layer=0, + input_shape=input_shape, + batchnorm_integration=params.batchnorm_integration, + num_filter=0, + shortcut=0, + strides=0, + activation=0, + encoder_name=0, + dtype=0, + dilations=0, + dropout=params.dropout, + end_dim=params.end_dim_enc, + path='' + ) + + model, manager, ckpt, manager_tmp, ckpt_tmp = load_model( + params, num_classes, add_info, checkpoint_path, checkpoint_path_tmp, + params.le_warmup, input_shape, encoder[0] if encoder is not None else None, init=True + ) + + def make_loaders(randomize=False): + X_train, meta_train = train_data + X_test, meta_test = test_data + X_val, meta_val = val_data + + train_ds = ChestDataset(X_train, meta_train, add_info=add_info) + test_ds = ChestDataset(X_test, meta_test, add_info=add_info) + val_ds = ChestDataset(X_val, meta_val, add_info=add_info) + + train_loader = DataLoader(train_ds, batch_size=params.batch_size, + shuffle=randomize or True) + test_loader = DataLoader(test_ds, batch_size=params.batch_size, + shuffle=False) + val_loader = DataLoader(val_ds, batch_size=params.batch_size, + shuffle=False) + return train_loader, test_loader, val_loader + + train_loader, test_loader, val_loader = make_loaders(randomize=False) + + all_val_losses = [] + all_debugging_stuff = [] + if acc_stop: + best_val_metric = 0.0 + else: + print('stop for best loss on L') + best_val_metric = float('inf') + + epochs = params.epochs + + # initial val loss (use model.evaluate_ if implemented) + if hasattr(model, "evaluate_"): + init_val_losses = model.evaluate_(val_loader, verbose=0) + val_loss = init_val_losses[0] if isinstance(init_val_losses, (list, tuple, np.ndarray)) else init_val_losses + else: + val_loss = 0.0 + print('VAE initial val loss:', val_loss) + + for epoch in range(epochs): + # re-create loaders with randomization to mimic TF behavior + del train_loader, val_loader + gc.collect() + train_loader, _, val_loader = make_loaders(randomize=True) + + cpu_mem_before, gpu_mem_before = get_memory_usage() + print(f"Epoch {epoch + 1}/{epochs} " + f"- CPU: {cpu_mem_before:.2f} GB | GPU: {gpu_mem_before:.2f} GB", + flush=True) + + beta = beta_scheduler(epoch, epochs) + start_time_train = time.time() + + if 
getattr(model, "semi_supervised", False): + # expected: model.train_semi(train_loader, beta) -> loss, rec_loss, kld, clf_loss, mae, mse + loss, rec_loss, kld, clf_loss, mae, mse = model.train_semi(train_loader, beta=beta) + train_loss = [loss, rec_loss, kld, clf_loss, mae, mse] + print("Train Semi | Epoch: {:03d} | Loss: {:.4f} | Rec Loss: {:.4f} | " + "KLD: {:.4f} | Clf Loss: {:.4f} | MAE: {:.4f} | MSE: {:.4f}".format( + epoch, loss, rec_loss, kld, clf_loss, mae, mse), + flush=True) + + if np.isnan(loss): + print('Loss is NaN, restoring previous model snapshot (tmp).') + tmp_file = os.path.join(checkpoint_path_tmp, 'vae_tmp.pt') + if os.path.exists(tmp_file): + state = torch.load(tmp_file, map_location=device) + model.load_state_dict(state['model']) + continue + + if epoch % 10 == 0 and not np.isnan(loss): + print('Start VAE validation', flush=True) + + all_z, predictions, gt, losses_arr, num_rounds = validate_data(val_loader, add_info, model, beta=beta) + all_z_train, predictions_train, gt_train, losses_train, num_rounds_train = validate_data( + train_loader, add_info, model, beta=beta + ) + + if num_rounds > 0: + losses_arr = losses_arr / num_rounds + losses_arr = np.round(losses_arr, 3) + val_loss = losses_arr[0] + mae_val = losses_arr[-2] + + # classification accuracy + if len(predictions) > 0: + preds_np = np.array(predictions) + softmax_pred = torch.softmax(torch.from_numpy(preds_np), dim=-1).numpy() + arg_max = np.argmax(softmax_pred, axis=-1) + acc_loss = accuracy_score(gt, arg_max, normalize=True) + balanced_acc = balanced_accuracy_score(gt, arg_max) + else: + acc_loss = 0.0 + balanced_acc = 0.0 + + rec = 0 + if params.VAE_debug: + # use small batch for reconstruction debug + test_x = test_data[0][:16] + train_x = train_data[0][:16] + if add_info: + test_y = test_data[1][:16, 0] + train_y = train_data[1][:16, 0] + test_add_info = test_data[1][:16, 1:] + train_add_info = train_data[1][:16, 1:] + else: + test_y = test_data[1][:16] + train_y = train_data[1][:16] + test_add_info = None + train_add_info = None + + test_x_t = torch.from_numpy(test_x).float().permute(0, 3, 1, 2).to(device) + train_x_t = torch.from_numpy(train_x).float().permute(0, 3, 1, 2).to(device) + test_y_t = torch.from_numpy(test_y).long().to(device) + train_y_t = torch.from_numpy(train_y).long().to(device) + if test_add_info is not None: + test_add_t = torch.from_numpy(test_add_info).float().to(device) + train_add_t = torch.from_numpy(train_add_info).float().to(device) + else: + test_add_t = None + train_add_t = None + + out_rec = model.execute_net_xy(test_x_t, test_y_t, test_add_t, training=False, beta=beta) + out_rec_train = model.execute_net_xy(train_x_t, train_y_t, train_add_t, training=False, beta=beta) + + z, gt_vec = zip(*all_z) + z = np.concatenate(z, axis=0) + gt_vec = np.concatenate(gt_vec, axis=0) + + z_train, gt_train_vec = zip(*all_z_train) + z_train = np.concatenate(z_train, axis=0) + gt_train_vec = np.concatenate(gt_train_vec, axis=0) + + rec = out_rec[-1].cpu().numpy() + all_debugging_stuff.append([ + [losses_arr, z, gt_vec, rec, predictions], + [train_loss, z_train, gt_train_vec, out_rec_train[-1].cpu().numpy(), predictions_train] + ]) + size_in_bytes = sys.getsizeof(all_debugging_stuff) + size_in_mb = size_in_bytes / (1024 ** 2) + print(f"Memory used by all_debugging_stuff: {size_in_mb:.2f} MB") + + del test_x_t, train_x_t, test_y_t, train_y_t + + print( + "Val Results Semi | Epoch: {:03d} | Acc Loss: {:.4f} | Balanced Acc: {:.4f} | " + "MAE: {:.4f} | min_rec: {:.4f} | max_rec: {:.4f}".format( + 
epoch, acc_loss, balanced_acc, mae_val, + float(np.min(rec)) if np.size(rec) else 0.0, + float(np.max(rec)) if np.size(rec) else 0.0 + ), + flush=True, + ) + + improve = False + if acc_stop: + if balanced_acc > best_val_metric: + best_val_metric = balanced_acc + improve = True + else: + if val_loss < best_val_metric: + best_val_metric = val_loss + improve = True + + if improve: + ckpt_file = os.path.join(checkpoint_path, 'vae.pt') + torch.save({'model': model.state_dict()}, ckpt_file) + print('Saved VAE weights', flush=True) + + print( + "VAE Val Loss: {:.4f} | Acc Loss: {:.4f} | Improve: {} | Time: {:.2f} min | Beta: {:.4f}".format( + val_loss, acc_loss, improve, (time.time() - start_time_train) / 60, beta + ), + flush=True, + ) + all_val_losses.append(losses_arr) + else: + # non-semi-supervised mode (not commonly used in paper) + loss, rec_loss, kld, mae, mse = model.train(train_loader) + print( + "Train | Epoch: {:03d} | Loss: {:.4f} | Rec Loss: {:.4f} | " + "KLD: {:.4f} | MAE: {:.4f} | MSE: {:.4f}".format( + epoch, loss, rec_loss, kld, mae, mse + ), + flush=True, + ) + + # validation loop for non-semi-supervised mode + model.eval() + losses_arr = np.zeros((3 + 2), dtype=np.float64) + num_rounds = 0 + with torch.no_grad(): + for batch in val_loader: + if add_info and len(batch) == 3: + test_x, y, test_other = batch + else: + test_x, y = batch + test_other = None + + test_x = test_x.to(device) + if test_other is not None: + test_other = test_other.to(device) + + loss_v, rec_loss_v, kld_v, other_losses = model.execute_net( + test_x, training=False + ) + losses_arr[0] += float(loss_v) + losses_arr[1] += float(rec_loss_v) + losses_arr[2] += float(kld_v) + for i, tmp_loss in enumerate(other_losses): + losses_arr[3 + i] += float(tmp_loss.mean()) + + num_rounds += 1 + + if num_rounds > 0: + losses_arr = losses_arr / num_rounds + val_loss = losses_arr[0] + losses_arr = np.round(losses_arr, 3) + + print('VAE val results', epoch, losses_arr, flush=True) + improve = False + if val_loss < best_val_metric: + best_val_metric = val_loss + improve = True + ckpt_file = os.path.join(checkpoint_path, 'vae.pt') + torch.save({'model': model.state_dict()}, ckpt_file) + print('Saved VAE weights', flush=True) + + print('VAE val loss:', val_loss, 'improve', improve, + 'time:', (time.time() - start_time_train) / 60, flush=True) + all_val_losses.append(losses_arr) + + # save tmp snapshot each epoch + ckpt_tmp_file = os.path.join(checkpoint_path_tmp, 'vae_tmp.pt') + torch.save({'model': model.state_dict()}, ckpt_tmp_file) + print('Epoch time', (time.time() - start_time_train) / 60, flush=True) + + return model, None, None, all_val_losses, all_debugging_stuff + + +# ------------------------------------------------------------------------- +# Embedding extraction (replacement for get_layer_embeddings) +# ------------------------------------------------------------------------- + +def get_layer_embeddings(grid_params, model, layer_names, all_datasets, get_max=True, add_info=False): + """ + PyTorch version of get_layer_embeddings. 
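+    Returns (output_embeds, all_labels, all_other_info), with one entry per
+    dataset in all_datasets.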
+ + In the original code: + - if grid_params.use_VAE: intermediate_layer_model = model.clf_model() + - else: intermediate_layer_model = keras.Model(inputs=model.input, outputs=layer_output) + + Here: + - if use_VAE: assume model.clf_model() returns a module whose forward gives embedding maps + - else: assume CNN model has forward(..., return_embedding=True) that returns (embedding, logits) + """ + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + print('START EMBEDDING', layer_names) + + if getattr(grid_params, "use_VAE", 0): + embedding_model = model.clf_model().to(device) + else: + embedding_model = model.to(device) + + output_embeds = [] + all_labels = [[] for _ in range(len(all_datasets))] + all_other_info = [[] for _ in range(len(all_datasets))] + num_embedding_layers = 1 # we only support one embedding for now + + for ds_idx, input_dataset in enumerate(all_datasets): + all_batch_embeddings = [[] for _ in range(num_embedding_layers)] + for ds_record in input_dataset: + if add_info and len(ds_record) == 3: + images, labels, other = ds_record + all_other_info[ds_idx].append(other.numpy()) + else: + images, labels = ds_record + other = None + + images = images.to(device) + + with torch.no_grad(): + if getattr(grid_params, "use_VAE", 0): + # assume clf_model() forward returns a list/tuple with embeddings as first element + prediction_0 = embedding_model(images)[0] + else: + emb, _ = embedding_model(images, return_embedding=True) + prediction_0 = emb # (N, D) + + all_labels[ds_idx].append(labels.numpy()) + + if len(layer_names) > 1: + for lay in range(num_embedding_layers): + if get_max and prediction_0.dim() == 4: + diff = prediction_0[:, lay].amax(dim=(1, 2)).cpu().numpy() + else: + diff = prediction_0.cpu().numpy() + all_batch_embeddings[lay].extend(diff) + else: + if get_max and prediction_0.dim() == 4: + diff = prediction_0.amax(dim=(2, 3)).cpu().numpy() + else: + diff = prediction_0.cpu().numpy() + all_batch_embeddings[0].extend(diff) + + all_labels[ds_idx] = np.concatenate(all_labels[ds_idx], axis=0) + if add_info and all_other_info[ds_idx]: + all_other_info[ds_idx] = np.concatenate(all_other_info[ds_idx], axis=0) + + curr_embeds = [np.stack(entry, axis=0) for entry in all_batch_embeddings] + if len(curr_embeds) > 1: + concat_embeds = np.concatenate(curr_embeds, axis=-1) + output_embeds.append(concat_embeds) + else: + output_embeds.append(curr_embeds[0]) + print(curr_embeds[0].shape) + + return output_embeds, all_labels, all_other_info diff --git a/download_data.py b/download_data.py new file mode 100644 index 0000000..c4f0a16 --- /dev/null +++ b/download_data.py @@ -0,0 +1,6 @@ +import kagglehub + +# Download latest version +path = kagglehub.dataset_download("nih-chest-xrays/data") + +print("Path to dataset files:", path) \ No newline at end of file diff --git a/preprocessing_new.py b/preprocessing_new.py new file mode 100644 index 0000000..c5cdabe --- /dev/null +++ b/preprocessing_new.py @@ -0,0 +1,223 @@ +import os +import random +import pickle +from collections import defaultdict + +import cv2 +import numpy as np +import pandas as pd +from PIL import Image + + +NIH_ROOT = "data" +CSV_PATH = os.path.join(NIH_ROOT, "Data_Entry_2017.csv") +IMAGE_ROOT = NIH_ROOT + +OUTPUT_PKL = "chexpert.pkl" +RANDOM_SEED = 42 + + + +def process_image(image_path, size=(128, 128)): + """ + Load an image from disk, convert to grayscale, + apply CLAHE, resize, and return as numpy array. 
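+
+    Example (hypothetical path; the assertion reflects the current behavior,
+    which returns raw uint8 values):
+
+        img = process_image("data/images_001/00000001_000.png")
+        assert img.shape == (128, 128) and img.dtype == np.uint8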
+ """ + # Load with PIL + pil_img = Image.open(image_path).convert("RGB") + img = np.array(pil_img) + + # Convert to grayscale + gray_img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY) + + # Apply CLAHE + clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8)) + clahe_img = clahe.apply(gray_img) + + # Resize to target size + resized_img = cv2.resize(clahe_img, size, interpolation=cv2.INTER_AREA) + + # Normalize to [0,1] as float32 + + return resized_img # shape: (H, W) + + +def load_and_filter_metadata(): + df = pd.read_csv(CSV_PATH) + + # Keep only Cardiomegaly or No Finding + mask = (df["Finding Labels"] == "Cardiomegaly") | (df["Finding Labels"] == "No Finding") + df = df[mask].copy() + + # Keep only frontal images (PA or AP) + df = df[df["View Position"].isin(["PA", "AP"])].copy() + + # Map labels: 0 = No Finding, 1 = Cardiomegaly + def label_fn(labels): + if labels == "No Finding": + return 0 + elif labels == "Cardiomegaly": + return 1 + else: + raise ValueError(f"Unexpected label: {labels}") + + df["label"] = df["Finding Labels"].apply(label_fn) + + # Basic cleaning / encoding + # Sex: 'M' / 'F' -> 0 / 1 (or you can flip if you want) + df["Sex_code"] = df["Patient Gender"].map({"M": 0, "F": 1}) + + # View Position: PA=0, AP=1 + df["View_code"] = df["View Position"].map({"PA": 0, "AP": 1}) + + # Age: just use as float + df["Age"] = df["Patient Age"].astype(float) + + # Build full image path + # Most Kaggle NIH dumps have flat images folder; if not, adjust here + image_map = build_image_map(IMAGE_ROOT) + + df["image_path"] = df["Image Index"].map(image_map) + + # Drop any rows where we couldn't find the file (should be few / none) + df = df.dropna(subset=["image_path"]) + + return df + + +def build_image_map(root): + """ + Walk through all subfolders under `root` and build + a dict: { '00000001_000.png': '/full/path/to/images_001/00000001_000.png', ... } + """ + image_map = {} + for dirpath, dirnames, filenames in os.walk(root): + for fname in filenames: + if fname.lower().endswith(".png"): + full_path = os.path.join(dirpath, fname) + image_map[fname] = full_path + return image_map + + +def patient_level_split(df, train_ratio=0.8, val_ratio=0.2, seed=RANDOM_SEED): + """ + Split patients into train/val/test (patient-level). + Train+Val = train_ratio, Test = 1 - train_ratio. + Val is val_ratio * Train. + """ + patient_ids = df["Patient ID"].unique().tolist() + random.seed(seed) + random.shuffle(patient_ids) + + num_patients = len(patient_ids) + train_end = int(train_ratio * num_patients) + + train_patients_tmp = patient_ids[:train_end] + test_patients = patient_ids[train_end:] + + # Split train into train/val + random.seed(seed + 1) + random.shuffle(train_patients_tmp) + val_end = int(len(train_patients_tmp) * val_ratio) + val_patients = train_patients_tmp[:val_end] + train_patients = train_patients_tmp[val_end:] + + def assign_split(pid): + if pid in train_patients: + return "train" + elif pid in val_patients: + return "validation" + else: + return "test" + + df["split"] = df["Patient ID"].apply(assign_split) + return df + + + +def balance_per_split(df): + """ + For each split, downsample majority class so we have ~50/50. 
+ """ + balanced_dfs = {} + for split_name in ["train", "validation", "test"]: + split_df = df[df["split"] == split_name].copy() + + cardio = split_df[split_df["label"] == 1] + no_find = split_df[split_df["label"] == 0] + + n_cardio = len(cardio) + n_no_find = len(no_find) + print(f"{split_name}: {n_cardio} cardiomegaly, {n_no_find} no finding") + + if n_cardio == 0 or n_no_find == 0: + print(f"⚠️ Warning: one class empty in {split_name}, skipping balancing") + balanced_dfs[split_name] = split_df + continue + + target_size = min(n_cardio, n_no_find) + + cardio_bal = cardio.sample(n=target_size, random_state=RANDOM_SEED) + no_find_bal = no_find.sample(n=target_size, random_state=RANDOM_SEED) + + balanced = pd.concat([cardio_bal, no_find_bal]).sample(frac=1.0, random_state=RANDOM_SEED) + balanced_dfs[split_name] = balanced + + print(f"{split_name} balanced: {len(balanced)} total ({target_size} + {target_size})") + + return balanced_dfs + + + +def build_numpy_and_save(balanced_dfs, save_path): + save_data = {} + + for split_name, split_df in balanced_dfs.items(): + X_list = [] + int_list = [] + float_list = [] + + for _, row in split_df.iterrows(): + img_path = row["image_path"] + img = process_image(img_path) # (128,128) float32 + + X_list.append(img) + + # y, sex_code, view_code + y = int(row["label"]) + sex_code = int(row["Sex_code"]) + view_code = int(row["View_code"]) + + int_list.append([y, sex_code, view_code]) + float_list.append([row["Age"]]) + + X = np.stack(X_list) # (N, 128, 128) + int_data = np.stack(int_list) # (N, 3) + float_data = np.stack(float_list) # (N, 1) + + print(f"{split_name}: X shape = {X.shape}, int_data shape = {int_data.shape}, float_data shape = {float_data.shape}") + print(f"{split_name}: cardiomegaly percentage = {np.mean(int_data[:, 0])}") + + save_data[split_name] = [X, int_data, float_data] + + with open(save_path, "wb") as f: + pickle.dump(save_data, f) + + print(f"Saved to {save_path}") + + + +if __name__ == "__main__": + # 1) Load & filter metadata + df = load_and_filter_metadata() + print("After filtering:") + print(df["Finding Labels"].value_counts()) + + # 2) Patient-level split + df = patient_level_split(df) + + # 3) Balance 50/50 per split + balanced = balance_per_split(df) + + # 4) Build numpy arrays & save as pickle + build_numpy_and_save(balanced, OUTPUT_PKL) From 4f701df4686132528cdd78c971ec980cc2e74ef0 Mon Sep 17 00:00:00 2001 From: UYang-121 Date: Mon, 24 Nov 2025 23:39:46 -0500 Subject: [PATCH 2/7] Update VAE.py --- VAE.py | 753 ++++++++++----------------------------------------------- 1 file changed, 125 insertions(+), 628 deletions(-) diff --git a/VAE.py b/VAE.py index ac99074..33fe08d 100644 --- a/VAE.py +++ b/VAE.py @@ -1,638 +1,135 @@ -import tensorflow as tf -import numpy as np -from sklearn.metrics import accuracy_score -from sklearn.metrics import balanced_accuracy_score -from sklearn.metrics import precision_recall_fscore_support -from sklearn import metrics +import torch +import torch.nn as nn +import torch.nn.functional as F +class Encoder(nn.Module): + def __init__(self, in_channels=1, latent_dim=62): + super().__init__() + self.conv = nn.Sequential( + nn.Conv2d(in_channels, 32, 4, 2, 1), + nn.BatchNorm2d(32), + nn.ReLU(inplace=True), + nn.Conv2d(32, 64, 4, 2, 1), + nn.BatchNorm2d(64), + nn.ReLU(inplace=True), -def enc_layer(net,stride,filter_num,batchnorm_integration,shortcut, - dilations,filter_size,trainable,dtype,old_h, old_num_filter,dropout,activation, - name,gauss_std): + nn.Conv2d(64, 128, 4, 2, 1), + nn.BatchNorm2d(128), + 
nn.ReLU(inplace=True), - convolution = tf.keras.layers.Conv2D(filter_num, filter_size, dilation_rate=dilations, strides=stride, padding="same", - data_format='channels_last', use_bias=False, activation=None, name=name+'conv',trainable=trainable,dtype=dtype)(net) - #convolution = tf.squeeze(convolution, axis=-2) + nn.Conv2d(128, 256, 4, 2, 1), + nn.BatchNorm2d(256), + nn.ReLU(inplace=True), - if shortcut: + nn.Conv2d(256, 512, 4, 2, 1), + nn.BatchNorm2d(512), + nn.ReLU(inplace=True), - shortcut = tf.keras.layers.AveragePooling2D(pool_size=filter_size, strides=stride, padding="same", - data_format="channels_last",trainable=trainable)(net) - # reduce shortcut: - #shortcut = tf.squeeze(shortcut, axis=-1) + nn.Conv2d(512, 512, 4, 2, 1), + nn.BatchNorm2d(512), + nn.ReLU(inplace=True), - shortcut= tf.keras.layers.Dense(1, activation=None, use_bias=True, input_shape=(-1, old_h,old_h, old_num_filter),trainable=trainable,dtype=dtype)(shortcut) - convolution = tf.keras.layers.Concatenate(axis=-1)([convolution, shortcut]) - if batchnorm_integration: - convolution = tf.keras.layers.BatchNormalization(trainable=trainable,dtype=dtype)(convolution) - - if dropout: - convolution = tf.keras.layers.Dropout(rate=dropout)(convolution) - - convolution = tf.keras.layers.GaussianNoise(gauss_std)(convolution) - convolution = activation(convolution) - - return convolution -def dec_layer(net,stride,filter_num,batchnorm_integration,shortcut, - dilations,filter_size,trainable,dtype,old_h, old_num_filter,dropout,activation, - name,gauss_std=0.0): - - net = tf.keras.backend.resize_images(net, height_factor=stride, width_factor=stride,data_format="channels_last",interpolation="nearest") - convolution = tf.keras.layers.Conv2D(filter_num, filter_size, dilation_rate=dilations, strides=1, padding="same", - data_format='channels_last', use_bias=False, activation=None)(net) - if shortcut: - shortcut = tf.keras.layers.Dense(1, activation=None, use_bias=True)(net) - convolution = tf.keras.layers.Concatenate(axis=-1)([convolution, shortcut]) - if batchnorm_integration: - convolution = tf.keras.layers.BatchNormalization()(convolution) - net = tf.keras.layers.GaussianNoise(gauss_std)(net) - convolution = activation(convolution) - return convolution - - -def make_costume_encoder(filter_size, num_layer, input_shape, batchnorm_integration, num_filter, shortcut, strides, - activation, encoder_name, dtype, dilations, dropout, end_dim,gauss_std=0.0): - inputs = tf.keras.Input(shape=input_shape, dtype=dtype) - net = inputs # /255.0 - img_size = input_shape[0] - num_filter = [input_shape[2]] + num_filter - for layer_num, stride in enumerate(strides): - old_num_filter = num_filter[layer_num] - if layer_num: - old_num_filter += 1 - net = enc_layer(net, stride, num_filter[layer_num + 1], batchnorm_integration, shortcut, dilations, filter_size, - trainable=True, dtype=dtype, - old_h=img_size, - old_num_filter=num_filter[layer_num], - dropout=dropout, - activation=activation, - name=encoder_name + '_' + str(layer_num),gauss_std=gauss_std) - img_size /= stride - - # net = tf.keras.layers.MaxPooling2D(pool_size=(img_size, img_size))(net) - net = tf.keras.layers.Flatten()(net) - net = tf.keras.layers.Dense(end_dim, activation=None)(net) - net = activation(net) - - model = tf.keras.Model(inputs=inputs, outputs=net) - return model, int(img_size) - - -def make_pretrain_encoder(filter_size, num_layer, input_shape, batchnorm_integration, num_filter, shortcut, strides, - activation, encoder_name, dtype, dilations, dropout, end_dim,path=''): - - - 
base_model = tf.keras.models.load_model(path+'efficientnetb7_saved_model')#(path+'efficientnetb7_saved_model_128.keras')# - layer_name = 'block4a_activation' # 'block7a_project_conv' - partial_model = tf.keras.Model(inputs=base_model.input, outputs=base_model.get_layer(layer_name).output) - - # add input layer with gaussian noise: - net_input = tf.keras.Input(shape=input_shape, dtype=tf.dtypes.float32) - # x= tf.keras.layers.GaussianNoise(0.05)(net_input) - x = partial_model(net_input) - - # x = base_model.get_layer(layer_name).output - # Add new dropout layer with a custom rate - x = tf.keras.layers.Dropout(dropout)(x) # Adjust dropout rate as needed #grid_params.dropout - # Add conv layer to reduce filter size: - - x = tf.keras.layers.Conv2D(end_dim, filter_size, padding='same', activation='relu')(x)#tanh - if batchnorm_integration: - x = tf.keras.layers.BatchNormalization(axis=-1, momentum=0.99, epsilon=0.001)(x) - - next_dropout=max(0.0,dropout-0.1) - x = tf.keras.layers.Dropout(next_dropout)(x) # Adjust dropout rate as needed #grid_params.dropout / 2 - #x = tf.keras.layers.GaussianNoise(0.1)(x) - - x = tf.keras.layers.MaxPooling2D(pool_size=(8, 8))(x) - x = tf.keras.layers.Flatten()(x) - - - model = tf.keras.models.Model(inputs=net_input, outputs=x) - - - return model, 8 - - - - - -def make_encoder(filter_size,num_layer,input_shape, - batchnorm_integration, - num_filter=[16,32,62,124,248,124,124,32], - shortcut=True, - strides=[2,1,2,1,2,1,2,1], - activation=tf.keras.activations.relu, - encoder_name='encoder', - dtype=tf.float32, - dilations=1, - dropout=0.0, - end_dim=128,load_pretrain_model=0,path='',gauss_std=0.0): - if load_pretrain_model: - return make_pretrain_encoder(filter_size,num_layer,input_shape, - batchnorm_integration, - num_filter=num_filter, - shortcut=shortcut, - strides=strides, - activation=activation, - encoder_name=encoder_name, - dtype=dtype, - dilations=dilations, - dropout=dropout, - end_dim=end_dim,path=path)#,gauss_std=gauss_std - else: - return make_costume_encoder(filter_size,num_layer,input_shape, - batchnorm_integration, - num_filter=num_filter, - shortcut=shortcut, - strides=strides, - activation=activation, - encoder_name=encoder_name, - dtype=dtype, - dilations=dilations, - dropout=dropout, - end_dim=end_dim,gauss_std=gauss_std) - - - - -def make_mlp_clf(input_shape,activation,batchnorm_integration,dropout=0.0,num_classes=2,layer=1,name='clf'): - inputs = tf.keras.Input(shape=input_shape, dtype=dtype) - net=inputs - - for i in range(layer): - if dropout: - net=tf.keras.layers.Dropout(dropout/8)(net) - if batchnorm_integration: - net = tf.keras.layers.BatchNormalization()(net) - - #net=tf.keras.layers.Dense(int(input_shape[0]/(2**(i+1))), name=name+str(i))(net) - print(input_shape[0],(2 ** (i + 1)),input_shape[0] / (2 ** (i + 1)),int(input_shape[0] / (2 ** (i + 1)))) - net = tf.keras.layers.Dense(input_shape[0] / (2 ** (i + 1)), name=name + str(i))(net) - net=activation(net) - if dropout: - net = tf.keras.layers.Dropout(dropout/8)(net) - if batchnorm_integration: - net = tf.keras.layers.BatchNormalization()(net) - - net = tf.keras.layers.Dense(num_classes, name=name + str(layer))(net) - - model = tf.keras.Model(inputs=inputs, outputs=net) - return model - -def make_decoder(filter_size,num_layer,input_shape,output_shape, - batchnorm_integration, - num_filter=[1,32,62,124,248,124,124,124], - shortcut=True, - strides=[2,1,2,1,2,1,2,2], - activation=tf.keras.activations.relu, - im_size=14, - encoder_name='encoder', - dtype=tf.float32, - dilations=1, - 
dropout=0.0,semi_supervised=False,gauss_std=0.0): - inputs = tf.keras.Input(shape=input_shape, dtype=dtype) - - net = tf.keras.layers.Dense(im_size * im_size * 128, activation='relu')(inputs) #* 128 - net = tf.keras.layers.Reshape((im_size, im_size, 128))(net)#, 128 - - #im_size=input_shape#[0] - #num_filter=[input_shape[3]]+num_filter - old_num_filter=128 - for layer_num,stride in enumerate(strides): - #''' - if layer_num < len(strides) - 1: - curr_shortcut=shortcut - else: - curr_shortcut=False - gauss_std=0.0 - activation=tf.keras.activations.sigmoid - net=dec_layer(net,stride,num_filter[-(layer_num+1)],batchnorm_integration,curr_shortcut,dilations,filter_size, - trainable=True,dtype=dtype, - old_h=im_size, - old_num_filter=old_num_filter+1, - dropout=dropout, - activation=activation, - name=encoder_name+'_'+str(layer_num),gauss_std=gauss_std) - - - - model = tf.keras.Model(inputs=inputs, outputs=net) - return model -dtype=tf.float32 - - - - -class WarmUpLearningRateSchedule(tf.keras.optimizers.schedules.LearningRateSchedule): - def __init__(self, target_lr, warmup_steps): - """ target_lr: Final learning rate after warmup - warmup_steps: Number of steps to reach target LR - """ - self.target_lr = target_lr - self.warmup_steps = warmup_steps - - def __call__(self, step): - """ Linear warm-up from 0 to target_lr over warmup_steps """ - if self.warmup_steps > 0: - return self.target_lr * tf.cast(step, tf.float32) / tf.cast(self.warmup_steps, tf.float32) - return self.target_lr # Default LR if no warm-up needed - -class Resnet_VAE(tf.keras.Model): - def __init__(self, - filter_size, - num_layer, - input_shape, - batchnorm_integration, - shortcut, - activation, - num_filter_encoder, - strides_encoder, - end_dim_enc, - num_filter_decoder, - strides_decoder, - latent_dim, - learning_rate, - semi_supervised, - num_classes, - dropout, - load_pretrain_model, - add_info, - loss_weights, - use_KLD_anneal, - VAE_fine_tune,path='',use_GAN=False,le_warmup=0,gauss_std=0.0,encoder=0): - super(Resnet_VAE, self).__init__() - #self.dtype=dtype - self.latent_dim=latent_dim - self.input_shape_=input_shape - self.load_pretrain_model=load_pretrain_model - self.add_info=add_info - self.loss_weights=loss_weights - self.VAE_fine_tune=VAE_fine_tune - self.use_KLD_anneal=use_KLD_anneal - print('loss_weights',self.loss_weights) - - if activation == 'relu': - self.activation = tf.keras.layers.LeakyReLU() - elif activation == 'tanh': - self.activation = tf.keras.activations.tanh - - if load_pretrain_model: - self.encoder, im_size=encoder,8 - else: - self.encoder,im_size =make_encoder(filter_size=filter_size, - num_layer=num_layer, - input_shape=input_shape, - batchnorm_integration=batchnorm_integration, - num_filter = num_filter_encoder, - shortcut = shortcut, - strides = strides_encoder, - activation = self.activation, - encoder_name = 'encoder', - dtype = tf.float32, - dilations = 1, - dropout = dropout, - end_dim = end_dim_enc, - load_pretrain_model=load_pretrain_model,path=path,gauss_std=gauss_std) - - self.z_mean = tf.keras.layers.Dense(latent_dim, name='z_mean') - self.z_log_var = tf.keras.layers.Dense(latent_dim, name='z_log_var') - - - self.semi_supervised=semi_supervised - dec_inp_shape = (latent_dim,) - if semi_supervised: - cls_input_shape=(latent_dim,) - if self.add_info: - cls_input_shape = (latent_dim+3,) - - self.classifier=make_mlp_clf(input_shape=cls_input_shape, - activation=self.activation, - batchnorm_integration=batchnorm_integration, - dropout=0.0, - num_classes=2,layer=3) - - self.clf_loss= 
tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True) - if self.VAE_fine_tune==0: - dec_inp_shape=(latent_dim+num_classes,) - else: - dec_inp_shape = (latent_dim,) - self.decoder = make_decoder(filter_size=filter_size, - num_layer=num_layer, - input_shape=dec_inp_shape, - output_shape=input_shape[0], - batchnorm_integration=batchnorm_integration, - num_filter=num_filter_decoder, - shortcut=shortcut, - strides=strides_decoder, - activation=tf.keras.activations.relu, - im_size=im_size, - encoder_name='decoder', - dtype=tf.float32, - dilations=1, - dropout=dropout, - semi_supervised=semi_supervised,gauss_std=gauss_std) - - if le_warmup: - warmup_steps = 5 * 40 # Warmup for first 5 epochs - target_lr = learning_rate#1e-3 - - lr_schedule = WarmUpLearningRateSchedule(target_lr=target_lr, warmup_steps=warmup_steps) - - - else: - lr_schedule =learning_rate - - - #self.optimizer=tf.keras.optimizers.legacy.Adam(learning_rate=lr_schedule) - self.optimizer=tf.keras.optimizers.Adam(learning_rate=lr_schedule)## - self.loss_fkts=[tf.keras.losses.MAE,tf.keras.losses.MSE] - - - - - def train(self,train_gen): - losses=np.zeros((3+2)) - num_rounds=0 - for (X,y) in train_gen: - loss, rec_loss, kld,other_loss=self.train_step(X) - losses[0]+=loss - losses[1]+=rec_loss - losses[2]+=kld - for i,tmp_loss in enumerate(other_loss): - losses[3+i]+=tf.reduce_mean(tmp_loss) - num_rounds+=1 - #print(num_rounds,np.round(losses/num_rounds,decimals=4)) - return losses/num_rounds - - - def train_semi(self,train_gen,beta=0): - losses=np.zeros((4+2)) - num_rounds=0 - if self.add_info: - for (X,y,other_data) in train_gen: - loss, rec_loss, kld,clf_loss,other_loss=self.train_step_semi_label(X,y,other_data,beta=beta) - losses[0]+=loss.numpy() - losses[1]+=rec_loss.numpy() - losses[2]+=kld.numpy() - losses[3] += clf_loss.numpy() - for i,tmp_loss in enumerate(other_loss): - losses[4+i]+=tf.reduce_mean(tmp_loss).numpy() - num_rounds+=1 - #print(num_rounds,np.round(losses/num_rounds,decimals=4)) - else: - for (X,y) in train_gen: - loss, rec_loss, kld,clf_loss,other_loss=self.train_step_semi_label(X,y,None,beta=beta) - losses[0]+=loss.numpy() - losses[1]+=rec_loss.numpy() - losses[2]+=kld.numpy() - losses[3] += clf_loss.numpy() - for i,tmp_loss in enumerate(other_loss): - losses[4+i]+=tf.reduce_mean(tmp_loss).numpy() - num_rounds+=1 - #print(num_rounds,np.round(losses/num_rounds,decimals=4)) - return losses/num_rounds - - - - @tf.function - def train_step(self,X,beta=0): - with tf.GradientTape() as tape: - loss,rec_loss,kld,other_loss=self.execute_net(X,training=True) - #loss = tf.reduce_mean(loss) - - train_vars = self.trainable_variables - grads = tape.gradient(loss, train_vars) - self.optimizer.apply_gradients(zip(grads, train_vars)) - return loss,rec_loss,kld,other_loss - - @tf.function - def train_step_semi_label(self,X,y,add_data,beta=0): - with tf.GradientTape() as tape: - loss,rec_loss,kld,clf_loss,other_loss,_,_,_=self.execute_net_xy(X,y,add_data,training=True,beta=beta) - #loss = tf.reduce_mean(loss) - - train_vars = self.trainable_variables - grads = tape.gradient(loss, train_vars) - self.optimizer.apply_gradients(zip(grads, train_vars)) - return loss,rec_loss,kld,clf_loss,other_loss - - - def encode_input(self,X,training=True): - encoding=self.encoder(X) - cls_input_mean = self.z_mean(encoding, training=training) - cls_input_var = self.z_log_var(encoding, training=training) - exp0 = tf.exp(cls_input_var * .5) - eps0 = tf.random.normal(shape=(cls_input_mean.shape[0], cls_input_mean.shape[1]), mean=0, 
stddev=1,dtype=self.dtype) - cls_input0 = eps0 * exp0+ cls_input_mean - return cls_input_mean,cls_input_var,cls_input0 - def enc_dec(self,X,training=True): - cls_input_mean, cls_input_var, cls_input0 = self.encode_input(X, training=training) - decoding = self.decoder(cls_input0) - return cls_input_mean,cls_input_var,cls_input0,decoding - - def embedding(self,X, training=False): - cls_input_mean, cls_input_var, cls_input0 = self.encode_input(X, training=training) - return cls_input0,cls_input_mean, cls_input_var - - def spn_clf(self,embedding_org, training=False): - y_pred = self.classifier(embedding_org) - return y_pred,None - - - def execute_net(self,X,training=True): - normalized_X=tf.cast(X,tf.float32)#X/255.0 - if self.load_pretrain_model: - normalized_X = tf.expand_dims(normalized_X[:,:,:,0],axis=-1) - normalized_X/=255.0 - - cls_input_mean,cls_input_var,cls_input0,decoding=self.enc_dec(X,training=training) - - cross_ent = tf.nn.sigmoid_cross_entropy_with_logits(logits=decoding, labels=normalized_X) - cros_ent_mean=tf.reduce_mean(cross_ent,axis=(1,2,3)) - kl_loss = -0.5 * tf.reduce_sum(1 + cls_input_var - tf.square(cls_input_mean) - tf.exp(cls_input_var), axis=1) - loss = cros_ent_mean +(kl_loss*0.0001) - l1_loss = tf.reduce_mean(cros_ent_mean) - kl_loss = tf.reduce_mean(kl_loss) - loss = tf.reduce_mean(loss) - sigmoid_vals=tf.keras.activations.sigmoid(decoding) - other_loss=[loss_fkt(sigmoid_vals,normalized_X) for loss_fkt in self.loss_fkts] - - #logpx_z = -tf.reduce_sum(cross_ent, axis=[1, 2, 3]) - #logpz = log_normal_pdf(cls_input0, 0., 0.) - #logqz_x = log_normal_pdf(cls_input0, cls_input_mean, cls_input_var) - #return -tf.reduce_mean(logpx_z + logpz - logqz_x),-tf.reduce_mean(logpx_z),-tf.reduce_mean(logpz - logqz_x) - return loss,l1_loss,kl_loss,other_loss - - def model_execution_X(self,X, add_info_data, training=False): - - normalized_X=tf.cast(X,tf.float32)#X/255.0 - if self.load_pretrain_model: - normalized_X=tf.expand_dims(normalized_X[:,:,:,0],axis=-1) - normalized_X/=255.0 - - cls_input_mean, cls_input_var, cls_input0 = self.encode_input(X, training=training) - - if self.add_info: - cls_input=tf.concat([cls_input0,tf.cast(add_info_data,dtype=tf.float32)],axis=-1) - else: - cls_input=cls_input0 - - y_pred=self.classifier(cls_input) - return y_pred - - def reconstruct(self,z): - decoding = self.decoder(z) - sigmoid_vals = decoding#tf.keras.activations.sigmoid(decoding) - return sigmoid_vals - - - def execute_net_xy(self,X,y,add_info,training=True,beta=0): - normalized_X=tf.cast(X,tf.float32)#X/255.0 - if self.load_pretrain_model: - normalized_X=tf.expand_dims(normalized_X[:,:,:,0],axis=-1) - normalized_X/=255.0 - - cls_input_mean, cls_input_var, cls_input0 = self.encode_input(X, training=training) - if self.VAE_fine_tune == 0: - one_hot_y=tf.one_hot(y,2,dtype=dtype) - - decoder_input=tf.concat([cls_input0,one_hot_y],axis=-1) - else: - decoder_input=cls_input0 - - if self.add_info: - cls_input=tf.concat([cls_input0,tf.cast(add_info,dtype=tf.float32)],axis=-1) - else: - cls_input=cls_input0 - y_pred=self.classifier(cls_input) - decoding = self.decoder(decoder_input) - - clf_loss=self.clf_loss(y,y_pred) - sigmoid_vals =decoding #tf.keras.activations.sigmoid(decoding) - - - rec_loss=tf.reduce_mean(tf.keras.losses.MSE(sigmoid_vals,normalized_X),axis=(1,2)) - MSE=tf.reduce_mean(tf.keras.losses.MSE(sigmoid_vals,normalized_X),axis=(1,2)) - l1_loss = tf.reduce_mean(tf.keras.losses.MAE(sigmoid_vals, normalized_X), axis=(1, 2)) - - kl_loss = -0.5 * tf.reduce_sum(1 + cls_input_var - 
tf.square(cls_input_mean) - tf.exp(cls_input_var), axis=1) - - if self.use_KLD_anneal: - loss = (rec_loss * self.loss_weights[0]) + (kl_loss * beta) + (clf_loss *self.loss_weights[2]) - else: - loss = (rec_loss * self.loss_weights[0]) + (kl_loss * self.loss_weights[1]) + (clf_loss *self.loss_weights[2]) - xentro = tf.reduce_mean(rec_loss) - kl_loss = tf.reduce_mean(kl_loss) - loss = tf.reduce_mean(loss) - clf_loss=tf.reduce_mean(clf_loss) - other_loss=[tf.reduce_mean(l1_loss),tf.reduce_mean(MSE)] - - return loss,xentro,kl_loss,clf_loss,other_loss,y_pred,cls_input0,sigmoid_vals - def evaluate_(self,dataset,verbose=0): - losses = np.zeros((4 + 2)) - num_rounds = 0 - predictions = [] - gt = [] - if self.add_info: - for test_x, y,other_info in dataset: - loss, rec_loss, kld, clf, other_losses, pred,_,_ = self.execute_net_xy(test_x, y,other_info,training=False) - losses[0] += loss - losses[1] += rec_loss - losses[2] += kld - losses[3] += clf - predictions.extend(pred.numpy().tolist()) - gt.extend(y.numpy().tolist()) - for i, tmp_loss in enumerate(other_losses): - losses[4 + i] += tf.reduce_mean(tmp_loss) - - num_rounds += 1 - else: - for test_x, y in dataset: - loss, rec_loss, kld, clf, other_losses, pred,_,_ = self.execute_net_xy(test_x, y,None, training=False) - losses[0] += loss - losses[1] += rec_loss - losses[2] += kld - losses[3] += clf - predictions.extend(pred.numpy().tolist()) - gt.extend(y.numpy().tolist()) - for i, tmp_loss in enumerate(other_losses): - losses[4 + i] += tf.reduce_mean(tmp_loss) - - num_rounds += 1 - # calculate accuracy: - softmax_pred = tf.nn.softmax(predictions, axis=-1) - prediction_exponential = tf.math.exp(predictions) - arg_max = np.argmax(softmax_pred, axis=-1) - acc_loss = accuracy_score(gt, arg_max, normalize=True) - losses = losses / num_rounds - val_loss = losses[0] - balanced_acc=balanced_accuracy_score(gt, arg_max) - [prec,rec,f1,_]=precision_recall_fscore_support(gt, arg_max, average=None) - - pred_exp = np.nan_to_num(prediction_exponential[:, 1], nan=0, posinf=1.0) - fpr, tpr, thresholds = metrics.roc_curve(gt, pred_exp,pos_label=1) - - auc=metrics.auc(fpr, tpr) - - return [acc_loss,val_loss,balanced_acc,prec[1],rec[1],f1[1],auc] - - def clf_model(self): - inputs = tf.keras.Input(shape=self.input_shape_, dtype=dtype) - encoding= self.encoder(inputs) - cls_input_mean = self.z_mean(encoding) - cls_input_var = self.z_log_var(encoding) - clf_out=tf.keras.layers.Lambda(sampling, output_shape=(self.latent_dim,), name='z')([cls_input_mean, cls_input_var]) - - - cnn_embedding = tf.keras.Model(inputs=inputs, outputs=[clf_out,cls_input_mean,cls_input_var]) - return cnn_embedding - - def grad_cam_model(self): - # # Get the target layer - target_layer = self.encoder.get_layer('conv2d') - - input1 = target_layer.input - input2 = tf.keras.Input(shape=(3,), dtype=dtype) - - - conv= target_layer(input1) - - encoding=None - - - - cls_input_mean = self.z_mean(encoding) - cls_input_var = self.z_log_var(encoding) - clf_out = tf.keras.layers.Lambda(sampling, output_shape=(self.latent_dim,), name='z')([cls_input_mean, cls_input_var]) - cls_input=tf.concat([clf_out,tf.cast(input2,dtype=tf.float32)],axis=-1) - output = self.classifier(cls_input) - - # Create a model that maps the input image to the activations of the target layer - grad_model = tf.keras.models.Model( - [input1,input2], [target_layer.output,output] + nn.Conv2d(512, 512, 4, 2, 1), + nn.ReLU(inplace=True) + ) + self.fc_mu = nn.Linear(512 * 1 * 1, latent_dim) + self.fc_logvar = nn.Linear(512 * 1 * 1, latent_dim) 
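+        # NOTE (added annotation): assuming 128x128 inputs (the default
+        # img_size used by this VAE), the stack of stride-2 convolutions above
+        # reduces the spatial map to 1x1, which is why fc_mu / fc_logvar
+        # expect 512 * 1 * 1 input features.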
+
+    def forward(self, x):
+        h = self.conv(x)
+        h = h.view(h.size(0), -1)
+        mu = self.fc_mu(h)
+        logvar = self.fc_logvar(h)
+        return mu, logvar
+
+
+class Decoder(nn.Module):
+    def __init__(self, out_channels=1, latent_dim=62):
+        super().__init__()
+        self.fc = nn.Linear(latent_dim, 512)
+        self.deconv = nn.Sequential(
+            nn.ConvTranspose2d(512, 512, 4, 2, 1),
+            nn.BatchNorm2d(512),
+            nn.ReLU(inplace=True),
+
+            nn.ConvTranspose2d(512, 256, 4, 2, 1),
+            nn.BatchNorm2d(256),
+            nn.ReLU(inplace=True),
+
+            nn.ConvTranspose2d(256, 128, 4, 2, 1),
+            nn.BatchNorm2d(128),
+            nn.ReLU(inplace=True),
+
+            nn.ConvTranspose2d(128, 64, 4, 2, 1),
+            nn.BatchNorm2d(64),
+            nn.ReLU(inplace=True),
+
+            nn.ConvTranspose2d(64, 32, 4, 2, 1),
+            nn.BatchNorm2d(32),
+            nn.ReLU(inplace=True),
+
+            nn.ConvTranspose2d(32, 16, 4, 2, 1),
+            nn.BatchNorm2d(16),
+            nn.ReLU(inplace=True),
+
+            nn.ConvTranspose2d(16, out_channels, 4, 2, 1),
+            nn.Sigmoid()
         )
-        return grad_model
-
-
-
-
-
-def sampling(args):
-    z_mean, z_log_var = args
-    batch = tf.shape(z_mean)[0]
-    dim = tf.shape(z_mean)[1]
-    epsilon = tf.keras.backend.random_normal(shape=(batch, dim))
-    return z_mean + tf.exp(0.5 * z_log_var) * epsilon
-
-
-
-
-
-
-
-
-
+    def forward(self, z):
+        h = self.fc(z)
+        h = h.view(h.size(0), 512, 1, 1)
+        x_recon = self.deconv(h)
+        return x_recon
+
+
+class VAE(nn.Module):
+    def __init__(self, in_channels=1, latent_dim=62, img_size=128):
+        super().__init__()
+        self.encoder = Encoder(in_channels, latent_dim)
+        self.decoder = Decoder(in_channels, latent_dim)
+
+    def reparameterize(self, mu, logvar):
+        std = torch.exp(0.5 * logvar)
+        eps = torch.randn_like(std)
+        return mu + eps * std
+
+    def forward(self, x):
+        mu, logvar = self.encoder(x)
+        z = self.reparameterize(mu, logvar)
+        x_recon = self.decoder(z)
+        return x_recon, mu, logvar
+
+
+def vae_loss(recon_x, x, mu, logvar):
+    recon_loss = F.mse_loss(recon_x, x, reduction='sum')
+    kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
+    return recon_loss + kld
+
+
+def train_vae(model, dataloader, optimizer, device):
+    model.train()
+    total_loss = 0
+    for x, _ in dataloader:
+        x = x.to(device)
+        optimizer.zero_grad()
+        recon_x, mu, logvar = model(x)
+        loss = vae_loss(recon_x, x, mu, logvar)
+        loss.backward()
+        optimizer.step()
+        total_loss += loss.item()
+    return total_loss / len(dataloader.dataset)
+
+
+def test_vae(model, dataloader, device):
+    model.eval()
+    total_loss = 0
+    with torch.no_grad():
+        for x, _ in dataloader:
+            x = x.to(device)
+            recon_x, mu, logvar = model(x)
+            loss = vae_loss(recon_x, x, mu, logvar)
+            total_loss += loss.item()
+    return total_loss / len(dataloader.dataset)

From ea19772b61986866ade977c29a85d785af6aafe8 Mon Sep 17 00:00:00 2001
From: UYang-121
Date: Mon, 24 Nov 2025 23:40:20 -0500
Subject: [PATCH 3/7] Add files via upload

---
 train_vae.py | 38 ++++++++++++++++++++++++++++++++++++++
 1 file changed, 38 insertions(+)
 create mode 100644 train_vae.py

diff --git a/train_vae.py b/train_vae.py
new file mode 100644
index 0000000..72994c1
--- /dev/null
+++ b/train_vae.py
@@ -0,0 +1,38 @@
+import torch
+from torch.utils.data import Dataset, DataLoader
+import numpy as np
+import pickle
+from VAE import VAE, train_vae, test_vae
+
+class ChexpertDataset(Dataset):
+    def __init__(self, split):
+        with open("chexpert.pkl", "rb") as f:
+            data = pickle.load(f)
+        imgs = data[split][0]
+        self.images = torch.tensor(imgs, dtype=torch.float32).unsqueeze(1) / 255.0  # (N,1,128,128)
+        self.labels = torch.tensor(data[split][1], dtype=torch.float32)  # can be ignored; the VAE training loop does not use labels
+    def __len__(self):
+        return
len(self.images) + def __getitem__(self, idx): + return self.images[idx], self.labels[idx] + +def main(): + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + train_set = ChexpertDataset("train") + val_set = ChexpertDataset("validation") + + train_loader = DataLoader(train_set, batch_size=32, shuffle=True) + val_loader = DataLoader(val_set, batch_size=32, shuffle=False) + + model = VAE(in_channels=1, latent_dim=62).to(device) + optimizer = torch.optim.Adam(model.parameters(), lr=1e-4) + + for epoch in range(1, 11): + loss = train_vae(model, train_loader, optimizer, device) + val_loss = test_vae(model, val_loader, device) + print(f"Epoch {epoch}: Train Loss={loss:.4f}, Val Loss={val_loss:.4f}") + + torch.save(model.state_dict(), "vae_trained.pt") + +if __name__ == "__main__": + main() From e385ba4f1781f8e112bd2d830e543465a8fa7c20 Mon Sep 17 00:00:00 2001 From: dwang0120 Date: Sun, 30 Nov 2025 16:06:28 -0500 Subject: [PATCH 4/7] updated VAE.py vae class, and updated analyse_model to use pytorch functions --- VAE.py | 30 ++++- counterfactuals/analyse_model.py | 105 ++++++++++----- counterfactuals/analyse_model_tf.py | 194 ++++++++++++++++++++++++++++ 3 files changed, 296 insertions(+), 33 deletions(-) create mode 100644 counterfactuals/analyse_model_tf.py diff --git a/VAE.py b/VAE.py index 33fe08d..24172a5 100644 --- a/VAE.py +++ b/VAE.py @@ -86,11 +86,14 @@ def forward(self, z): class VAE(nn.Module): - def __init__(self, in_channels=1, latent_dim=62, img_size=128): + def __init__(self, in_channels=1, latent_dim=62, img_size=128, num_classes = 2): super().__init__() + self.latent_dim = latent_dim self.encoder = Encoder(in_channels, latent_dim) self.decoder = Decoder(in_channels, latent_dim) + # classifier head to match TF pipeline - dennis + self.classifier = nn.Linear(latent_dim, num_classes) def reparameterize(self, mu, logvar): std = torch.exp(0.5 * logvar) eps = torch.randn_like(std) @@ -100,8 +103,27 @@ def forward(self, x): mu, logvar = self.encoder(x) z = self.reparameterize(mu, logvar) x_recon = self.decoder(z) - return x_recon, mu, logvar - + logits = self.classifier(mu) + return x_recon, mu, logvar, logits + def clf_model(self): + # provide latent embedding extraction + class LatentExtractor(nn.Module): + def __init__(self, parent): + super().__init__() + self.encoder = parent.encoder + def forward(self, x): + mu, logvar = self.encoder(x) + return mu + return LatentExtractor(self) + def vae_rec(self, x): + # reconstruction-only method. I don't think this'll be used since it's in CNN_SPN but just in case. 
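+        # (added annotation) forward() now returns (x_recon, mu, logvar, logits),
+        # so any caller of this helper has to unpack four values, as done below.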
+        self.eval()
+        with torch.no_grad():
+            x = torch.tensor(x).float()
+            if x.ndim == 3:
+                x = x.unsqueeze(0)
+            recon, _, _, _ = self.forward(x)
+        return recon.cpu().numpy()
 
 def vae_loss(recon_x, x, mu, logvar):
     recon_loss = F.mse_loss(recon_x, x, reduction='sum')
@@ -133,3 +155,5 @@ def test_vae(model, dataloader, device):
         loss = vae_loss(recon_x, x, mu, logvar)
         total_loss += loss.item()
     return total_loss / len(dataloader.dataset)
+
+
diff --git a/counterfactuals/analyse_model.py b/counterfactuals/analyse_model.py
index 29ea9ce..69d5fbe 100644
--- a/counterfactuals/analyse_model.py
+++ b/counterfactuals/analyse_model.py
@@ -1,4 +1,6 @@
-import tensorflow as tf
+#import tensorflow as tf
+import torch
+import torch.nn as nn
 import pickle as pkl
 import matplotlib.pyplot as plt
 import random
@@ -17,40 +19,84 @@ def load_VAE_model( add_info, num_classes, input_shape,data_path_fold,params,path=''):
+    """
+    Loads the VAE model.
+
+    Args:
+        add_info (bool): whether additional info channels are used
+        num_classes (int): number of classes in dataset
+        input_shape (tuple): shape of input images (C, H, W)
+        data_path_fold (str): path to fold directory
+        params (obj): VAE hyperparameters
+        path (str): optional root path
+
+    Returns:
+        vae_model: the loaded PyTorch VAE model
+    """
     vae_save_path=data_path_fold+'vae_checkpoints_copy/tf_ckpts_last'
     print('TODO _copy')
     # 1. load vae:
-    vae_model = load_VAE(params, add_info, num_classes, input_shape, vae_save_path,path)
+    # params,add_info,num_classes,input_shape,checkpoint_path,path='')
+    vae_model = load_VAE(
+        params = params,
+        add_info = add_info,
+        num_classes = num_classes,
+        input_shape = input_shape,
+        checkpoint_path = vae_save_path,
+        path = path)
     return vae_model
 
-
-
 def load_model( add_info, num_classes, input_shape,data_path_fold,params,path=''):
-    vae_save_path=data_path_fold+'vae_checkpoints/tf_ckpts_last'
-    checkpoint_path_cnn_spn = data_path_fold + 'cnn_spn_checkpoints'
-
+    """
+    Loads a PyTorch CNN-SPN model, its corresponding VAE embedding model, and the SPN structure.
+
+    Args:
+        add_info (bool): whether additional info channels are used
+        num_classes (int): number of output classes
+        input_shape (tuple): shape of input images (C, H, W)
+        data_path_fold (str): path to fold directory
+        params (obj): model hyperparameters
+        path (str): optional root path
+
+    Returns:
+        cnn_spn (nn.Module): fully loaded PyTorch CNN+SPN model.
+    """
+    vae_ckpt_path=data_path_fold+'vae_checkpoints/pt_ckpts_last'
+    cnn_spn_ckpt_path = data_path_fold + 'cnn_spn_checkpoints'
     # 1. load vae:
-    vae_model = load_VAE(params, add_info, num_classes, input_shape, vae_save_path,path)
+    vae_model = load_VAE(
+        params = params,
+        add_info = add_info,
+        num_classes = num_classes,
+        input_shape = input_shape,
+        checkpoint_path = vae_ckpt_path,
+        path = path)
     # 2. load spn_structure
     spn_data=pkl.load(open(data_path_fold+ 'spn.pkl', 'rb'))
     spn_clf=spn_data['spn_x']
     spn_input_shape=spn_data['data_shape']
     label_ids=spn_data['label_ids']
 
-    spn_x_copy, all_spn_x_y, all_spn_x_y_dicts, all_prior, all_spn_x_y_model = create_tf_spn_parts(spn_x=spn_clf,
-                                                data_shape=spn_input_shape,
-                                                label_ids=label_ids,
-                                                trainable_leaf=params.fine_tune_leafs)
-
+    spn_x_copy, all_spn_x_y, all_spn_x_y_dicts, all_prior, all_spn_x_y_model = \
+        create_tf_spn_parts(
+            spn_x=spn_clf,
+            data_shape=spn_input_shape,
+            label_ids=label_ids,
+            trainable_leaf=params.fine_tune_leafs)
+
+    # 3.
Prepare embedding network
+    # TensorFlow equivalent commented out; replaced with PyTorch
     decoder = None
     if params.use_VAE:
         cnn_embedding = vae_model.clf_model()
         if params.VAE_fine_tune:
             decoder = vae_model.decoder
     else:
-        cnn_layer_out = vae_model.get_layer('flatten').output
-        cnn_embedding = tf.keras.Model(inputs=vae_model.input, outputs=cnn_layer_out)
+        # PyTorch modules have no get_layer(); the TF flatten-layer embedding
+        # has no direct equivalent here yet.
+        #cnn_embedding = tf.keras.Model(inputs=vae_model.input, outputs=cnn_layer_out)
+        raise NotImplementedError("Non-VAE embedding not yet implemented for PyTorch version. Need to get CNN_functions.load_vae() flatten layer first.")
 
 
     # 3. load cnn_spn
@@ -65,23 +111,22 @@
                            VAE_fine_tune=params.VAE_fine_tune,
                            decoder=decoder,
                            loss_weights=params.loss_weights,
-                           load_pretrain_model=params.load_pretrain_model,clf_mlp=vae_model.classifier)
-
-    ckpt_cnn_spn = tf.train.Checkpoint(step=tf.Variable(1), optimizer=cnn_spn.optimizer, net=cnn_spn)
-    manager_cnn_spn = tf.train.CheckpointManager(ckpt_cnn_spn, checkpoint_path_cnn_spn, max_to_keep=1)
-
-
-    ckpt_cnn_spn.restore(manager_cnn_spn.latest_checkpoint)
-
-
+                           load_pretrain_model=params.load_pretrain_model,
+                           clf_mlp=vae_model.classifier)
+
+    # Load checkpoint
+    #ckpt_cnn_spn = tf.train.Checkpoint(step=tf.Variable(1), optimizer=cnn_spn.optimizer, net=cnn_spn)
+    #manager_cnn_spn = tf.train.CheckpointManager(ckpt_cnn_spn, checkpoint_path_cnn_spn, max_to_keep=1)
+
+    #ckpt_cnn_spn.restore(manager_cnn_spn.latest_checkpoint)
+    try:
+        ckpt_state = torch.load(cnn_spn_ckpt_path, map_location="cpu")
+        cnn_spn.load_state_dict(ckpt_state)
+        print(f"Loaded CNN+SPN checkpoint from {cnn_spn_ckpt_path}")
+    except FileNotFoundError:
+        print("WARNING: No CNN+SPN checkpoint found. Using randomly initialized weights.")
     return cnn_spn
 
 def plot_image_grid(images, ax, title,cmap='gray', alpha=1.0):
     """Helper function to plot a 4x4 grid of images."""
     for i in range(16):
diff --git a/counterfactuals/analyse_model_tf.py b/counterfactuals/analyse_model_tf.py
new file mode 100644
index 0000000..29ea9ce
--- /dev/null
+++ b/counterfactuals/analyse_model_tf.py
@@ -0,0 +1,194 @@
+import tensorflow as tf
+import pickle as pkl
+import matplotlib.pyplot as plt
+import random
+from sklearn.model_selection import KFold
+
+from CNN_SPN import CNN_SPN_Parts, test_model_SPN_MLP
+from CNN_functions import load_VAE
+from end_to_end_train import load_dataset
+from tf2_spn import create_tf_spn_parts
+import numpy as np
+
+
+# 0. load model
+# 1. evaluate on test
+# 2. show example of FP/FP/TP/TN (very certain)
+
+
+def load_VAE_model( add_info, num_classes, input_shape,data_path_fold,params,path=''):
+    vae_save_path=data_path_fold+'vae_checkpoints_copy/tf_ckpts_last'
+    print('TODO _copy')
+    # 1. load vae:
+    vae_model = load_VAE(params, add_info, num_classes, input_shape, vae_save_path,path)
+    return vae_model
+
+
+
+def load_model( add_info, num_classes, input_shape,data_path_fold,params,path=''):
+    vae_save_path=data_path_fold+'vae_checkpoints/tf_ckpts_last'
+    checkpoint_path_cnn_spn = data_path_fold + 'cnn_spn_checkpoints'
+
+
+    # 1. load vae:
+    vae_model = load_VAE(params, add_info, num_classes, input_shape, vae_save_path,path)
+    # 2.
load spn_structure + spn_data=pkl.load(open(data_path_fold+ 'spn.pkl', 'rb')) + spn_clf=spn_data['spn_x'] + spn_input_shape=spn_data['data_shape'] + label_ids=spn_data['label_ids'] + + spn_x_copy, all_spn_x_y, all_spn_x_y_dicts, all_prior, all_spn_x_y_model = create_tf_spn_parts(spn_x=spn_clf, + data_shape=spn_input_shape, + label_ids=label_ids, + trainable_leaf=params.fine_tune_leafs) + + decoder = None + if params.use_VAE: + cnn_embedding = vae_model.clf_model() + if params.VAE_fine_tune: + decoder = vae_model.decoder + else: + cnn_layer_out = vae_model.get_layer('flatten').output + cnn_embedding = tf.keras.Model(inputs=vae_model.input, outputs=cnn_layer_out) + + + # 3. load cnn_spn + cnn_spn = CNN_SPN_Parts(num_classes=num_classes, + learning_rate=params.fine_tune_rate, + all_spn_x_y_model=all_spn_x_y_model, + all_prior=all_prior, + cnn=cnn_embedding, + get_max=0, + gauss_embeds=params.gauss_embeds, + use_add_info=params.use_add_info, + VAE_fine_tune=params.VAE_fine_tune, + decoder=decoder, + loss_weights=params.loss_weights, + load_pretrain_model=params.load_pretrain_model,clf_mlp=vae_model.classifier) + + ckpt_cnn_spn = tf.train.Checkpoint(step=tf.Variable(1), optimizer=cnn_spn.optimizer, net=cnn_spn) + manager_cnn_spn = tf.train.CheckpointManager(ckpt_cnn_spn, checkpoint_path_cnn_spn, max_to_keep=1) + + + ckpt_cnn_spn.restore(manager_cnn_spn.latest_checkpoint) + + + return cnn_spn + + + + + + + +def plot_image_grid(images, ax, title,cmap='gray', alpha=1.0): + """Helper function to plot a 4x4 grid of images.""" + for i in range(16): + ax[i // 4, i % 4].imshow(images[i], cmap=cmap,alpha=alpha) + ax[i // 4, i % 4].axis('off') # Turn off axis for each image + # Set the header for the grid + ax[0, 1].set_title(title, fontsize=16, pad=20) + +def plot_examples(sorted_data, input_data_name,plot_path,cnn_spn_model): + for key,value in sorted_data.items(): + vals=value[:16] + if len(vals)==16: + # create reconstructions: + gt, predictions, pred_exp,images=zip(*vals) + images=np.asarray(list(images)) + + test_rec = cnn_spn_model.vae_rec(images) + + # Create a figure with two sets of 4x4 grids (one 8x8 plot with 16 subplots) + fig, axs = plt.subplots(4, 8, figsize=(12, 6)) + + # First grid (left 4x4) for test_rec + plot_image_grid(test_rec, axs[:, :4], "Reconstructed Images") + + # Second grid (right 4x4) for images + plot_image_grid(images, axs[:, 4:], "Original Images") + + # Add an overall title + plt.suptitle(key+' '+input_data_name, fontsize=18) + + # Adjust layout to prevent overlap + plt.tight_layout(rect=[0, 0, 1, 0.95]) # Leaves space for the overall title + + # Display the combined plot + #plt.show() + plt.savefig(plot_path+key+'_'+input_data_name+'.png') + plt.clf() + + + +def analyse_pipeline(dataset_name,params,num_train_eval_runs,data_path_grid,add_info,fold_idxs,model_name,path='../'): + # 0. 
load model + + + #input_shape=(128,128,3) + # load data: + train, test_data, num_classes = load_dataset(dataset_name, binary=False, + load_net=params.load_pretrain_model, + machine=params.machine,path=path) + statistics_names=['Accuracy','AUC','Rec:MSE'] + result_clf=[] + result_rec=[] + + for f_i in fold_idxs: + data_path_fold = data_path_grid + 'fold_' + str(f_i) + '/' + + kf = KFold(n_splits=num_train_eval_runs, random_state=1, shuffle=True) + + for i,(train_index, val_index) in enumerate(kf.split(train[0])): + if i ==f_i: + + random.seed(0) + random.shuffle(train_index) + random.shuffle(val_index) + train_x, val_x = train[0][train_index], train[0][val_index] + train_y, val_y = train[1][train_index], train[1][val_index] + train_data=(train_x,train_y) + + for input_data , input_data_name in zip([test_data,train_data],['Test','Train']): + print(input_data_name) + + input_shape = train_data[0].shape[1:] + cnn_spn_model=load_model( add_info, num_classes, input_shape,data_path_fold,params,path='../') + + # 1. evaluate on test + + results_MLP,results_SPN,losses = test_model_SPN_MLP(cnn_spn_model, input_data, num_classes=num_classes, + batch_size=params.batch_size, + training=False, add_info=add_info) + + for result, clf_name in zip([results_MLP,results_SPN],['MLP','SPN']): + print(result,flush=True) + result_clf.append({ + 'split':input_data_name, + "dataset": dataset_name, + "clf": clf_name, + "fold_idx": f_i, + "model_name": model_name, + 'Accuracy':result[0], + 'Entropy':result[1], + 'Balanced Accuracy':result[2], + 'Precision':result[3], + 'Recall':result[4], + 'F1-Score':result[5], + 'AUC':result[6] + }) + result_rec.append({ + 'split': input_data_name, + "dataset": dataset_name, + "fold_idx": f_i, + "model_name": model_name, + 'MSE':losses[1], + 'MAE':losses[4], + 'KLD':losses[3], + }) + + return result_clf,result_rec + + + From 2be049fce0997fcfec1321e34d538f1d38a3531b Mon Sep 17 00:00:00 2001 From: dwang0120 Date: Sun, 30 Nov 2025 21:51:52 -0500 Subject: [PATCH 5/7] progress on find_many_counterfactuals.py --- counterfactuals/analyse_model_tf.py | 194 -------------- counterfactuals/find_many_counterfactuals.py | 253 ++++++++++++------- 2 files changed, 156 insertions(+), 291 deletions(-) delete mode 100644 counterfactuals/analyse_model_tf.py diff --git a/counterfactuals/analyse_model_tf.py b/counterfactuals/analyse_model_tf.py deleted file mode 100644 index 29ea9ce..0000000 --- a/counterfactuals/analyse_model_tf.py +++ /dev/null @@ -1,194 +0,0 @@ -import tensorflow as tf -import pickle as pkl -import matplotlib.pyplot as plt -import random -from sklearn.model_selection import KFold - -from CNN_SPN import CNN_SPN_Parts, test_model_SPN_MLP -from CNN_functions import load_VAE -from end_to_end_train import load_dataset -from tf2_spn import create_tf_spn_parts -import numpy as np - - -# 0. load model -# 1. evaluate on test -# 2. show example of FP/FP/TP/TN (very certain) - - -def load_VAE_model( add_info, num_classes, input_shape,data_path_fold,params,path=''): - vae_save_path=data_path_fold+'vae_checkpoints_copy/tf_ckpts_last' - print('TODO _copy') - # 1. load vae: - vae_model = load_VAE(params, add_info, num_classes, input_shape, vae_save_path,path) - return vae_model - - - -def load_model( add_info, num_classes, input_shape,data_path_fold,params,path=''): - vae_save_path=data_path_fold+'vae_checkpoints/tf_ckpts_last' - checkpoint_path_cnn_spn = data_path_fold + 'cnn_spn_checkpoints' - - - # 1. 
load vae: - vae_model = load_VAE(params, add_info, num_classes, input_shape, vae_save_path,path) - # 2. load spn_structure - spn_data=pkl.load(open(data_path_fold+ 'spn.pkl', 'rb')) - spn_clf=spn_data['spn_x'] - spn_input_shape=spn_data['data_shape'] - label_ids=spn_data['label_ids'] - - spn_x_copy, all_spn_x_y, all_spn_x_y_dicts, all_prior, all_spn_x_y_model = create_tf_spn_parts(spn_x=spn_clf, - data_shape=spn_input_shape, - label_ids=label_ids, - trainable_leaf=params.fine_tune_leafs) - - decoder = None - if params.use_VAE: - cnn_embedding = vae_model.clf_model() - if params.VAE_fine_tune: - decoder = vae_model.decoder - else: - cnn_layer_out = vae_model.get_layer('flatten').output - cnn_embedding = tf.keras.Model(inputs=vae_model.input, outputs=cnn_layer_out) - - - # 3. load cnn_spn - cnn_spn = CNN_SPN_Parts(num_classes=num_classes, - learning_rate=params.fine_tune_rate, - all_spn_x_y_model=all_spn_x_y_model, - all_prior=all_prior, - cnn=cnn_embedding, - get_max=0, - gauss_embeds=params.gauss_embeds, - use_add_info=params.use_add_info, - VAE_fine_tune=params.VAE_fine_tune, - decoder=decoder, - loss_weights=params.loss_weights, - load_pretrain_model=params.load_pretrain_model,clf_mlp=vae_model.classifier) - - ckpt_cnn_spn = tf.train.Checkpoint(step=tf.Variable(1), optimizer=cnn_spn.optimizer, net=cnn_spn) - manager_cnn_spn = tf.train.CheckpointManager(ckpt_cnn_spn, checkpoint_path_cnn_spn, max_to_keep=1) - - - ckpt_cnn_spn.restore(manager_cnn_spn.latest_checkpoint) - - - return cnn_spn - - - - - - - -def plot_image_grid(images, ax, title,cmap='gray', alpha=1.0): - """Helper function to plot a 4x4 grid of images.""" - for i in range(16): - ax[i // 4, i % 4].imshow(images[i], cmap=cmap,alpha=alpha) - ax[i // 4, i % 4].axis('off') # Turn off axis for each image - # Set the header for the grid - ax[0, 1].set_title(title, fontsize=16, pad=20) - -def plot_examples(sorted_data, input_data_name,plot_path,cnn_spn_model): - for key,value in sorted_data.items(): - vals=value[:16] - if len(vals)==16: - # create reconstructions: - gt, predictions, pred_exp,images=zip(*vals) - images=np.asarray(list(images)) - - test_rec = cnn_spn_model.vae_rec(images) - - # Create a figure with two sets of 4x4 grids (one 8x8 plot with 16 subplots) - fig, axs = plt.subplots(4, 8, figsize=(12, 6)) - - # First grid (left 4x4) for test_rec - plot_image_grid(test_rec, axs[:, :4], "Reconstructed Images") - - # Second grid (right 4x4) for images - plot_image_grid(images, axs[:, 4:], "Original Images") - - # Add an overall title - plt.suptitle(key+' '+input_data_name, fontsize=18) - - # Adjust layout to prevent overlap - plt.tight_layout(rect=[0, 0, 1, 0.95]) # Leaves space for the overall title - - # Display the combined plot - #plt.show() - plt.savefig(plot_path+key+'_'+input_data_name+'.png') - plt.clf() - - - -def analyse_pipeline(dataset_name,params,num_train_eval_runs,data_path_grid,add_info,fold_idxs,model_name,path='../'): - # 0. 
load model - - - #input_shape=(128,128,3) - # load data: - train, test_data, num_classes = load_dataset(dataset_name, binary=False, - load_net=params.load_pretrain_model, - machine=params.machine,path=path) - statistics_names=['Accuracy','AUC','Rec:MSE'] - result_clf=[] - result_rec=[] - - for f_i in fold_idxs: - data_path_fold = data_path_grid + 'fold_' + str(f_i) + '/' - - kf = KFold(n_splits=num_train_eval_runs, random_state=1, shuffle=True) - - for i,(train_index, val_index) in enumerate(kf.split(train[0])): - if i ==f_i: - - random.seed(0) - random.shuffle(train_index) - random.shuffle(val_index) - train_x, val_x = train[0][train_index], train[0][val_index] - train_y, val_y = train[1][train_index], train[1][val_index] - train_data=(train_x,train_y) - - for input_data , input_data_name in zip([test_data,train_data],['Test','Train']): - print(input_data_name) - - input_shape = train_data[0].shape[1:] - cnn_spn_model=load_model( add_info, num_classes, input_shape,data_path_fold,params,path='../') - - # 1. evaluate on test - - results_MLP,results_SPN,losses = test_model_SPN_MLP(cnn_spn_model, input_data, num_classes=num_classes, - batch_size=params.batch_size, - training=False, add_info=add_info) - - for result, clf_name in zip([results_MLP,results_SPN],['MLP','SPN']): - print(result,flush=True) - result_clf.append({ - 'split':input_data_name, - "dataset": dataset_name, - "clf": clf_name, - "fold_idx": f_i, - "model_name": model_name, - 'Accuracy':result[0], - 'Entropy':result[1], - 'Balanced Accuracy':result[2], - 'Precision':result[3], - 'Recall':result[4], - 'F1-Score':result[5], - 'AUC':result[6] - }) - result_rec.append({ - 'split': input_data_name, - "dataset": dataset_name, - "fold_idx": f_i, - "model_name": model_name, - 'MSE':losses[1], - 'MAE':losses[4], - 'KLD':losses[3], - }) - - return result_clf,result_rec - - - diff --git a/counterfactuals/find_many_counterfactuals.py b/counterfactuals/find_many_counterfactuals.py index 381cb8e..5a67333 100644 --- a/counterfactuals/find_many_counterfactuals.py +++ b/counterfactuals/find_many_counterfactuals.py @@ -6,7 +6,10 @@ import pickle as pkl import matplotlib.patches as patches import matplotlib.pyplot as plt -import tensorflow as tf +#import tensorflow as tf +import torch +import torch.nn as nn +import torch.nn.functional as F import multiprocessing as mp mp.set_start_method("spawn", force=True) from concurrent.futures import ProcessPoolExecutor @@ -15,116 +18,172 @@ def nan_to_num(tensor,nan,posinf): - is_pos_inf = tf.math.is_inf(tensor) & (tensor > 0) - tensor= tf.where(is_pos_inf, posinf, tensor) - return tf.where(tf.math.is_nan(tensor), nan, tensor) + """ + Replace NaNs and positive infinities in a torch tensor + + Args: + tensor (torch.Tensor or np.ndarray): input + nan (float): value to replace NaNs with + posinf (float): value to replace positive infinities with + + Returns: + tensor: cleaned Torch.tensor + """ + #is_pos_inf = tf.math.is_inf(tensor) & (tensor > 0) + #tensor= tf.where(is_pos_inf, posinf, tensor) + #return tf.where(tf.math.is_nan(tensor), nan, tensor) + if not torch.is_tensor(tensor): + tensor = torch.tensor(tensor) + # first dealing with positive infinity values + is_pos_inf = torch.isinf(tensor) & (tensor > 0) + tensor = torch.where(is_pos_inf, torch.tensor(posinf, dtype = tensor.dtype, device = tensor.device), tensor) + # and now dealing with nan's + tensor = torch.where(torch.isnan(tensor), torch.tensor(nan, dtype = tensor.dtype, device = tensor.device), tensor) + return tensor + def 
get_counterfactual_infos(img_idx,cnn_spn_model,X,additional_info,opposite_class,y,model_type,opt_weights=[10,0.005],learning_rate = 0.01,num_steps = 150):
-    # Get the latent variable z for the input
+    """ Get the latent variable z for the input, then generate counterfactual latent vectors z' by optimizing z' so that the
+    classifier's prediction moves to the opposite class while keeping z' close to the original z and the latent likelihood similar.
+
+    Args:
+        img_idx (int): index of image in dataset. Optional, used for logging
+        cnn_spn_model: PyTorch wrapper providing:
+            - embedding(X) --> (z_embed, ... , ...)
+            - spn_clf(embedding) --> (pred_logits, p_z)
+            - reconstruct(z_tensor) --> reconstructed images tensor
+        X (np.ndarray or torch.Tensor): Input batch
+        additional_info (np.ndarray or torch.Tensor): additional info per sample
+        opposite_class (int): new target class to push to (0 or 1)
+        y (np.ndarray or list): original labels (batch or single)
+        model_type (str): 'MLP' or 'SPN' - determines loss formulation
+        opt_weights (list): weights [beta, gamma] for the distance and plausibility terms
+        learning_rate (float): optimizer learning rate for optimizing z'
+        num_steps (int): number of gradient steps on z'
+    Returns:
+        reconstruction_np, rec_z_np, z_prime_np, z_np, title_info, distance_np, arg_max_np, loss_val, log_pred_val, p_z_np, label_switch_step, pred_np
+    """
+    # get device from model if possible
+    device = None
+    try:
+        # prefer model parameters
+        params = next(cnn_spn_model.parameters())
+        device = params.device
+    except StopIteration:
+        device = torch.device('cpu')
+    # make sure we're working with tensors that live on that device
+    if not torch.is_tensor(X):
+        X_t = torch.tensor(X, dtype=torch.float32, device=device)
+    else:
+        X_t = X.to(device)
+    if not torch.is_tensor(additional_info):
+        add_info_t = torch.tensor(additional_info, dtype = torch.float32, device = device)
+    else:
+        add_info_t = additional_info.to(device)
+
+    # New: Get latent embedding z for the input (expecting a torch tensor)
+    # The embedding should return (z_embed, ..., ...)
     #print('image',img_idx)
-    [z_embed, _, _] = cnn_spn_model.embedding(X, training=False)
     # z = vae_encoder(input_x)
     # Define z_prime as a trainable variable initialized from z
-    y_opposite=tf.Variable(np.asarray([opposite_class]*z_embed.shape[0]), trainable=True, dtype=tf.dtypes.int32)
-    z = tf.Variable(z_embed, trainable=False, dtype=tf.dtypes.float32)
-    z_prime = tf.Variable(z_embed, trainable=True, dtype=tf.dtypes.float32)
-    add_info = tf.Variable(additional_info, trainable=False, dtype=tf.dtypes.float32)
     #print('shapes', z_prime.shape, additional_info.shape,'le:',learning_rate)
+    with torch.no_grad():
+        z_embed_tuple = cnn_spn_model.embedding(X_t) #expecting (z, ... , ...)
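+        # (added annotation) the encoder is queried under no_grad on purpose:
+        # only z_prime (created below) is optimized, while the encoder weights
+        # and the original z stay frozen.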
+    # account for if the embedding returns a tuple or list
+    if isinstance(z_embed_tuple, tuple) or isinstance(z_embed_tuple, list):
+        z_embed = z_embed_tuple[0]
+    else:
+        z_embed = z_embed_tuple
+    # New: detach original z
+    z = z_embed.detach().to(device) # new shape expected (batch, latent_dim)
+    # New: create z_prime as a trainable parameter initialized to z
+    z_prime = z.clone().detach().requires_grad_(True)
+    # New: create y_opposite tensor for the classification loss (if we need it)
+    batch_size = z.shape[0]
+    y_opposite = torch.full((batch_size,), opposite_class, dtype = torch.long, device = device)
+
     # Define the optimizer
-    optimizer = tf.keras.optimizers.SGD(learning_rate)  # Adam(learning_rate)
-    myloss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
+    #optimizer = tf.keras.optimizers.SGD(learning_rate) # Adam(learning_rate)
+    #myloss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
+    optimizer = torch.optim.SGD([z_prime], lr=learning_rate)
+    ce_loss = nn.CrossEntropyLoss(reduction='none')

     # Loss function to push z' to the opposite class
-    @tf.function
-    def counterfactual_loss():
-        # Get the predicted probability p(y|z') from the classifier
-        embedding = tf.concat([z_prime, add_info], axis=-1)
-        embedding_org = tf.concat([z, add_info], axis=-1)
-
-        pred_org, p_z_org = cnn_spn_model.spn_clf(embedding_org, training=False)
-
-        pred, p_z = cnn_spn_model.spn_clf(embedding, training=False)  # clf(z_prime)
-        # -pred[:,int(y[0])]
-
-        distance = tf.losses.MSE(z, z_prime)
-        # loss=-test-(p_z*0.5)+distance
-        plausability_diff=distance
-        if model_type=='MLP':
-            # TODO change this to get a result with 2 axis - for later evaluation of log_pred
-            test=myloss(y_opposite,pred)
-            loss = test +( distance * opt_weights[0])
-        elif model_type=='SPN':
-            #shift cls_prob to 0 and 1
-            test = pred[:, opposite_class]
-            #p_z_prob=tf.keras.ops.nan_to_num(tf.math.exp(p_z),nan=0,posinf=1.0)
-            #p_z_org_prob=tf.keras.ops.nan_to_num(tf.math.exp(p_z_org),nan=0,posinf=1.0)
-            '''
-            test=nan_to_num(tf.math.exp(test),nan=0,posinf=1.0)
-            p_z_prob=nan_to_num(tf.math.exp(p_z),nan=0,posinf=1.0)
-            p_z_org_prob=nan_to_num(tf.math.exp(p_z_org),nan=0,posinf=1.0)
-            plausability_diff = tf.reduce_mean(tf.abs(p_z_prob - p_z_org_prob), axis=-1)
-            '''
-            plausability_diff = tf.reduce_sum(tf.abs(p_z - p_z_org), axis=-1)
-            loss = -test + (distance * opt_weights[0]) + (plausability_diff * opt_weights[1])
+
+    # pre-initializing some variables for later
+    title_info = ''
+    arg_max = None
+    loss_val = None
+    log_pred_val = None
+    distance_val = None
+    p_z_val = None
+    label_switch_step = -1
+    switch = False
+    pred = None
+
+    # Precompute embedding_org related values
+    embedding_org = torch.cat([z, add_info_t], dim=-1)
+    with torch.no_grad():
+        pred_org, p_z_org = cnn_spn_model.spn_clf(embedding_org) # pred_org logits, p_z_org maybe log probs
+
+    # optimization loop
+    for step in range(num_steps):
+        optimizer.zero_grad()
+
+        # create embedding from z_prime + additional info
+        embedding = torch.cat([z_prime, add_info_t], dim=-1)
+
+        # get predictions and p(z) from SPN or MLP (expecting logits)
+        pred_logits, p_z = cnn_spn_model.spn_clf(embedding) # pred_logits: (batch, num_classes)
+
+        # compute distance per sample (mean squared on latent dims)
+        distance = torch.mean((z - z_prime) ** 2, dim=-1) # (batch, )
+
+        if model_type == 'MLP':
+            # CrossEntropy expects (batch, classes) logits and (batch,) targets
+            test = ce_loss(pred_logits, y_opposite) # per-sample loss
+            # total per-sample loss = CE + beta * distance
+            per_sample_loss = test + (distance * opt_weights[0])
+            loss = torch.mean(per_sample_loss)
+            # for logging purposes -- set log_pred as negative cross-entropy (higher is better)
+            log_pred = -test.detach()
+        elif model_type == 'SPN':
+            # assuming pred_logits are class-scores (bigger --> more probable); use the
+            # opposite_class column as a per-sample score
+            test = pred_logits[:, opposite_class]
+            # plausibility difference between p_z and p_z_org:
+            # per-sample sum of absolute differences across the last dim
+            try:
+                plaus_diff = torch.sum(torch.abs(p_z - p_z_org), dim=-1)
+            except Exception:
+                # fallback if p_z are scalars
+                plaus_diff = torch.abs(p_z - p_z_org).reshape(-1)
+            per_sample_loss = (-test) + (distance * opt_weights[0]) + (plaus_diff * opt_weights[1])
+            loss = torch.mean(per_sample_loss)
+            log_pred = -test.detach()
+        else:
+            raise ValueError("model_type must be either 'MLP' or 'SPN'")
+
+        # Backprop on z_prime only; the optimizer was built with [z_prime]
-        # print(loss.shape,test.shape,distance.shape,plausability_diff.shape)
-        return loss, -test, distance, pred, plausability_diff ,distance
+        loss.backward()
+        optimizer.step()
+
+        # Evaluate switching condition: record the first step at which the
+        # mean predicted label crosses over to the target class
+        pred = pred_logits.detach()
+        if not switch:
+            arg_max = np.argmax(pred.cpu().numpy(), axis=1)
+            arg_max_mean = np.mean(arg_max)
+            if opposite_class and arg_max_mean >= 0.5:
+                switch = True
+                label_switch_step = step
+            elif (not opposite_class) and arg_max_mean <= 0.5:
+                switch = True
+                label_switch_step = step
-    title_info = ''
-    arg_max, loss, log_pred, distance, z_prime_, p_z = 0, 0, 0, 0, 0, 0
-    # Perform gradient descent on z_prime to achieve the counterfactual
-    label_switch_step=-1
-    switch=0
-    for step in range(num_steps):
-        # Compute the loss and apply gradients
-        with tf.GradientTape() as tape:
-            loss, log_pred, distance, pred, p_z,test = counterfactual_loss()
-
-        grads = tape.gradient(loss, [z_prime])
-        optimizer.apply_gradients(zip(grads, [z_prime]))
-        if not switch:
-            arg_max = np.argmax(pred.numpy(), axis=1)
-            arg_max_mean = np.mean(arg_max)
-            if opposite_class and arg_max_mean>=0.5:
-                switch=True
-                label_switch_step=step
-            elif (not opposite_class) and arg_max_mean<=0.5:
-                switch=True
-                label_switch_step=step
-
-        # Optionally print the progress
-        #'''
-        if step == num_steps-1:
-            arg_max = np.argmax(pred.numpy(), axis=1)
-            loss = loss.numpy()
-            #z_prime_ = np.mean(z_prime.numpy(), axis=0)
-
-            log_pred = log_pred.numpy()
-            distance = distance.numpy()
-            #test=test.numpy()
-            p_z = p_z.numpy()
-
-            #log_pred_mean = np.mean(log_pred)
-            distance_mean = np.mean(distance)
-            #test_mean=np.mean(test)
-            p_z_mean = np.mean(p_z)
-            arg_max_mean = np.mean(arg_max)
-            #loss_mean = np.mean(loss)
-
-            title_info = "New y:{:.0f}; org y:{:.0f}; MSE{:.2f}; log(p(z)):{:.2f}".format(arg_max_mean, y[0],
-                                                                                          distance_mean, p_z_mean)
-            #print('IMG',img_idx,
-            #      f"Step {step}, Loss: {loss_mean},pred_log: {log_pred_mean},distance: {distance_mean}"
-            #      f",log(p(z)): {p_z_mean}, z_prime: {z_prime_[:5]}, mean argmax: {arg_max_mean}, org_label: {y}")
-            #if arg_max_mean==opposite_class:
-            #    break
-        #'''
-    reconstructions = cnn_spn_model.reconstruct(z_prime)
-    rec_z = cnn_spn_model.reconstruct(z)
+
+        # on the last step, gather diagnostics for the caller
+        if step == num_steps - 1:
+            arg_max = np.argmax(pred.cpu().numpy(), axis=1)
+            loss = float(loss.detach().cpu())
+            log_pred = log_pred.cpu().numpy()
+            distance = distance.detach().cpu().numpy()
+            p_z = p_z.detach().cpu().numpy()
+            title_info = "New y:{:.0f}; org y:{:.0f}; MSE{:.2f}; log(p(z)):{:.2f}".format(
+                np.mean(arg_max), y[0], np.mean(distance), np.mean(p_z))
+
+    # After optimization, reconstruct images from z_prime and z
+    with torch.no_grad():
+        reconstructions = cnn_spn_model.reconstruct(z_prime).detach().cpu()
+        rec_z = cnn_spn_model.reconstruct(z).detach().cpu()
+
+    # convert outputs to CPU tensors / numpy for later code that expects numpy
+    z_prime = z_prime.detach().cpu()
+    z = z.cpu()
+    pred = pred.cpu()
+
     print('Rec min max',np.min(rec_z),np.max(rec_z))
     return reconstructions.numpy(),rec_z.numpy(),z_prime.numpy(),z.numpy(),title_info,distance,arg_max,loss,log_pred,p_z,label_switch_step,pred.numpy()

From 4dbc05928585100667ca239818c422325d05ba81 Mon Sep 17 00:00:00 2001
From: Bowei Kou
Date: Mon, 1 Dec 2025 18:32:24 -0500
Subject: [PATCH 6/7] update torch_CNN_SPN torch_spn torch_spn_layers

---
 test_cnn_spn_basic.py |  278 +++++++++++
 test_spn.py           |   58 +++
 tf_SPN_layers.py      |    2 +
 torch_CNN_SPN.py      | 1065 +++++++++++++++++++++++++++++++++++++++++
 torch_spn.py          |  177 +++++++
 torch_spn_layers.py   |  200 ++++++++
 6 files changed, 1780 insertions(+)
 create mode
100644 test_cnn_spn_basic.py create mode 100644 test_spn.py create mode 100644 torch_CNN_SPN.py create mode 100644 torch_spn.py create mode 100644 torch_spn_layers.py diff --git a/test_cnn_spn_basic.py b/test_cnn_spn_basic.py new file mode 100644 index 0000000..a5bb5da --- /dev/null +++ b/test_cnn_spn_basic.py @@ -0,0 +1,278 @@ +import numpy as np +import torch +import torch.nn as nn +from types import SimpleNamespace + +from torch_CNN_SPN import CNN_SPN, CNN_SPN_Parts, train_model_parts + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + +def make_toy_batch_linear(batch_size=64, input_dim=32, seed=0): + + rng = np.random.RandomState(seed) + X_np = rng.randn(batch_size, input_dim).astype("float32") + y_np = (X_np.sum(axis=1) > 0).astype("int64") + X = torch.from_numpy(X_np).to(device) + y = torch.from_numpy(y_np).to(device) + return X, y + + +def make_toy_images( + N=128, H=8, W=8, C=1, num_classes=2, seed=0 +): + rng = np.random.RandomState(seed) + X = rng.randn(N, H, W, C).astype("float32") + + y = (X.sum(axis=(1, 2, 3)) > 0).astype("int64") + return X, y + +class DummyEncoder(nn.Module): + def __init__(self, input_flat_dim, z_dim=16): + super().__init__() + self.fc_mu = nn.Linear(input_flat_dim, z_dim) + self.fc_logvar = nn.Linear(input_flat_dim, z_dim) + + def forward(self, x, training=True): + if x.dim() > 2: + x = x.view(x.size(0), -1) + mu = self.fc_mu(x) + logvar = self.fc_logvar(x) + std = torch.exp(0.5 * logvar) + eps = torch.randn_like(std) + z = mu + eps * std + return z, mu, logvar + + +class DummyDecoder(nn.Module): + def __init__(self, z_dim, out_shape): + super().__init__() + self.out_shape = out_shape + out_flat = 1 + for d in out_shape: + out_flat *= d + self.net = nn.Sequential( + nn.Linear(z_dim, 64), + nn.ReLU(), + nn.Linear(64, out_flat), + ) + + def forward(self, z): + x_hat = self.net(z) + return x_hat.view(z.size(0), *self.out_shape) + + +class DummySPN(nn.Module): + def __init__(self, in_dim): + super().__init__() + self.net = nn.Sequential( + nn.Linear(in_dim, 16), + nn.ReLU(), + nn.Linear(16, 1), + ) + + def forward(self, x, training=True): + return self.net(x) + +class DummyCkpt: + def __init__(self): + self.step = SimpleNamespace(assign_add=lambda x: None) + + +class DummyManager: + def __init__(self, name=""): + self.name = name + + def save(self): + pass + +def test_cnn_spn_basic(): + print("========== [Test 1] CNN_SPN basic ==========") + input_dim = 32 + num_classes = 2 + lr = 1e-2 + + model = CNN_SPN( + num_classes=num_classes, + input_dimensions=(input_dim,), + learning_rate=lr, + spn=None, + cnn=None, + get_max=False, + ).to(device) + + if not hasattr(model, "use_add_info"): + model.use_add_info = False + + X, y = make_toy_batch_linear(batch_size=64, input_dim=input_dim, seed=123) + + with torch.no_grad(): + loss = model.model_execution_X_y(X, y, other=None) + print(f"[CNN_SPN] forward loss: {float(loss):.6f}") + + print("[CNN_SPN] start small train loop...") + for step in range(20): + loss = model.train_step(X, y, other=None) + print(f" step {step:02d}, loss = {float(loss):.6f}") + print() + +def test_cnn_spn_parts_forward(): + print("========== [Test 2] CNN_SPN_Parts forward ==========") + H, W, C = 8, 8, 1 + num_classes = 2 + lr = 1e-3 + z_dim = 16 + + X_np, y_np = make_toy_images(N=16, H=H, W=W, C=C, num_classes=num_classes, seed=0) + + X = torch.from_numpy(X_np).to(device) + y = torch.from_numpy(y_np).to(device) + + input_flat_dim = H * W * C + encoder = DummyEncoder(input_flat_dim=input_flat_dim, z_dim=z_dim) + decoder = 
DummyDecoder(z_dim=z_dim, out_shape=(H, W, C)) + spn_in_dim = 1 + z_dim + all_spn_x_y = [DummySPN(spn_in_dim) for _ in range(num_classes)] + priors = [1.0 / num_classes] * num_classes + + model = CNN_SPN_Parts( + num_classes=num_classes, + learning_rate=lr, + all_spn_x_y_model=all_spn_x_y, + all_prior=priors, + cnn=encoder, + get_max=False, + gauss_embeds=0.0, + use_add_info=False, + VAE_fine_tune=0, + decoder=decoder, + loss_weights=[1.0, 1.0, 1.0], + load_pretrain_model=0, + use_GAN=False, + filter_size=3, + num_layer=4, + input_shape=(H, W, C), + batchnorm_integration=1, + shortcut=0, + activation="relu", + num_filter=[64, 128, 128], + strides=[2, 2, 2, 2], + end_dim=100, + dropout=0.2, + dilations=1, + discriminator_name="gan_discriminator", + clf_mlp=None, + ).to(device) + + if not hasattr(model, "use_add_info"): + model.use_add_info = False + + with torch.no_grad(): + out = model.model_execution_X_y(X, y) + print("[CNN_SPN_Parts] model_execution_X_y output shape:", out.shape) + + with torch.no_grad(): + logits = model.model_execution_X(X, other_data=None, training=False) + print("[CNN_SPN_Parts] model_execution_X output shape:", logits.shape) + + probs = torch.softmax(logits, dim=-1) + preds = probs.argmax(dim=-1) + print("[CNN_SPN_Parts] probs[0]:", probs[0].cpu().numpy()) + print("[CNN_SPN_Parts] preds:", preds.cpu().numpy()) + print() + + +def test_train_model_parts_loop(): + print("========== [Test 3] train_model_parts loop ==========") + + H, W, C = 8, 8, 1 + num_classes = 2 + z_dim = 16 + + X_train, y_train = make_toy_images(N=128, H=H, W=W, C=C, num_classes=num_classes, seed=1) + X_val, y_val = make_toy_images(N=64, H=H, W=W, C=C, num_classes=num_classes, seed=2) + X_test, y_test = make_toy_images(N=64, H=H, W=W, C=C, num_classes=num_classes, seed=3) + + train_data = (X_train, y_train) + val_data = (X_val, y_val) + test_data = (X_test, y_test) + + input_flat_dim = H * W * C + encoder = DummyEncoder(input_flat_dim=input_flat_dim, z_dim=z_dim) + decoder = DummyDecoder(z_dim=z_dim, out_shape=(H, W, C)) + spn_in_dim = 1 + z_dim + all_spn_x_y = [DummySPN(spn_in_dim) for _ in range(num_classes)] + priors = [1.0 / num_classes] * num_classes + + cnn_spn = CNN_SPN_Parts( + num_classes=num_classes, + learning_rate=1e-3, + all_spn_x_y_model=all_spn_x_y, + all_prior=priors, + cnn=encoder, + get_max=False, + gauss_embeds=0.0, + use_add_info=False, + VAE_fine_tune=0, + decoder=decoder, + loss_weights=[1.0, 1.0, 1.0], + load_pretrain_model=0, + use_GAN=False, + filter_size=3, + num_layer=4, + input_shape=(H, W, C), + batchnorm_integration=1, + shortcut=0, + activation="relu", + num_filter=[64, 128, 128], + strides=[2, 2, 2, 2], + end_dim=100, + dropout=0.2, + dilations=1, + discriminator_name="gan_discriminator", + clf_mlp=None, + ).to(device) + + if not hasattr(cnn_spn, "use_add_info"): + cnn_spn.use_add_info = False + + grid_params = SimpleNamespace( + use_add_info=False, + batch_size=16, + VAE_fine_tune=0, + GAN=False, + ) + + ckpt = [DummyCkpt(), DummyCkpt()] + manager = [DummyManager("vae"), DummyManager("spn")] + + val_entropy_init = 0.0 + + eval_after_train, debug_info = train_model_parts( + grid_params=grid_params, + cnn_spn=cnn_spn, + train_data=train_data, + val_data=val_data, + test_data=test_data, + num_iterations=5, + ckpt=ckpt, + manager=manager, + val_entropy=val_entropy_init, + val_acc=0.0, + add_info=False, + ) + + print("[train_model_parts] eval_after_train:", eval_after_train) + print("[train_model_parts] len(debug_info):", len(debug_info)) + print() + + +if __name__ == 
"__main__": + torch.manual_seed(0) + np.random.seed(0) + + test_cnn_spn_basic() + + test_cnn_spn_parts_forward() + + test_train_model_parts_loop() \ No newline at end of file diff --git a/test_spn.py b/test_spn.py new file mode 100644 index 0000000..10bb09d --- /dev/null +++ b/test_spn.py @@ -0,0 +1,58 @@ +import numpy as np +import torch + +from spn.structure.leaves.parametric.Parametric import Gaussian +from spn.structure.Base import Sum, Product + +from spn.algorithms.Inference import log_likelihood + +from torch_spn import create_torch_spn + + +def build_toy_spn(): + + g00 = Gaussian(mean=0.0, stdev=1.0, scope=0) + g01 = Gaussian(mean=3.0, stdev=1.0, scope=0) + + g10 = Gaussian(mean=0.0, stdev=1.0, scope=1) + g11 = Gaussian(mean=-3.0, stdev=1.0, scope=1) + + p0 = Product(children=[g00, g10]) + p1 = Product(children=[g01, g11]) + + # sum 节点(root) + root = Sum(weights=[0.4, 0.6], children=[p0, p1]) + + return root + + +def main(): + spn_root = build_toy_spn() + + torch_spn_model, var_dict, spn_copy = create_torch_spn(spn_root) + + B = 5 + x_np = np.random.randn(B, 2).astype(np.float32) + x_torch = torch.from_numpy(x_np) + + ll_spflow = log_likelihood(spn_copy, x_np) + + torch_spn_model.eval() + with torch.no_grad(): + ll_torch = torch_spn_model(x_torch) + + ll_torch_np = ll_torch.detach().cpu().numpy() + + print("x:") + print(x_np) + print("SPFlow log-likelihood:") + print(ll_spflow) + print("Torch SPN log-likelihood:") + print(ll_torch_np) + + print("Diff = Torch - SPFlow:") + print(ll_torch_np - ll_spflow) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/tf_SPN_layers.py b/tf_SPN_layers.py index 9e9dc96..9dd6445 100644 --- a/tf_SPN_layers.py +++ b/tf_SPN_layers.py @@ -106,3 +106,5 @@ def log_sum_to_tf_graph(node, children, data_placeholder=None, variable_dict=Non return layer(children) + + diff --git a/torch_CNN_SPN.py b/torch_CNN_SPN.py new file mode 100644 index 0000000..0f78008 --- /dev/null +++ b/torch_CNN_SPN.py @@ -0,0 +1,1065 @@ +import time +import random + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from sklearn import metrics +from sklearn.metrics import accuracy_score +from sklearn.metrics import balanced_accuracy_score +from sklearn.metrics import precision_recall_fscore_support + +from torch.utils.data import TensorDataset, DataLoader + + +def _to_tensor(x, dtype=torch.float32, device=None): + if x is None: + return None + if isinstance(x, np.ndarray): + return torch.from_numpy(x).to(dtype=dtype, device=device) + if torch.is_tensor(x): + return x.to(dtype=dtype, device=device) + return torch.tensor(x, dtype=dtype, device=device) + + +def create_embedding_model(input_dimensions): + + in_dim = int(np.prod(input_dimensions)) + + class EmbeddingModel(nn.Module): + def __init__(self): + super().__init__() + self.net = nn.Sequential( + nn.Flatten(), + nn.Linear(in_dim, 10), + nn.ReLU(), + nn.Linear(10, 8), + nn.ReLU(), + nn.Linear(8, 2), + nn.ReLU(), + ) + + def forward(self, x): + return self.net(x) + + return EmbeddingModel() + + +class CNN_SPN(nn.Module): + + def __init__(self, num_classes, input_dimensions, learning_rate, spn=None, cnn=None, get_max=True): + super(CNN_SPN, self).__init__() + self.num_classes = num_classes + self.get_max = get_max + self.use_add_info = False + + if cnn is not None: + self.embedding = cnn + else: + self.embedding = create_embedding_model(input_dimensions) + + if spn is not None: + self.spn_training = True + self.clf = spn + self.clf_loss = None + else: + 
self.spn_training = False + self.clf = nn.Sequential( + nn.Linear(2, 1), + nn.Sigmoid() + ) + self.clf_loss = nn.BCELoss() + + self.optimizer = torch.optim.Adam(self.parameters(), lr=learning_rate) + + def model_execution_X_y(self, X, y, other): + device = next(self.parameters()).device + X = _to_tensor(X, dtype=torch.float32, device=device) + y = _to_tensor(y, dtype=torch.long, device=device) + if other is not None and not isinstance(other, int): + other = _to_tensor(other, dtype=torch.float32, device=device) + else: + other = None + + embedding = self.embedding(X) + + if self.spn_training: + if self.get_max and embedding.dim() > 2: + embedding = torch.amax(embedding, dim=(1, 2)) + if self.use_add_info and other is not None: + embedding = torch.cat([embedding, other], dim=-1) + y_float = y.to(torch.float32).unsqueeze(-1) + spn_input = torch.cat([y_float, embedding], dim=-1) + spn_output = self.clf(spn_input) + loss = -torch.sum(spn_output) + else: + if embedding.dim() > 2: + embedding = torch.amax(embedding, dim=tuple(range(1, embedding.dim()))) + logits = self.clf(embedding).squeeze(-1) + y_float = y.to(torch.float32) + loss = self.clf_loss(logits, y_float) + + return loss + + def train_step(self, x, y, other): + self.train() + self.optimizer.zero_grad() + loss = self.model_execution_X_y(x, y, other) + loss.backward() + self.optimizer.step() + return loss.detach() + + def train_model(self, train_ds, first_loss): + all_losses = 0.0 + counter = 0.0 + for it_c, train_rec in enumerate(train_ds): + other = 0 + if self.use_add_info: + (X, y, other) = train_rec + else: + (X, y) = train_rec + + loss = self.train_step(X, y, other) + all_losses += loss.item() + counter += 1 + if not it_c and first_loss: + print(it_c, 'loss', all_losses / counter) + return all_losses / counter + + def eval_cnn(self, test_X): + self.eval() + with torch.no_grad(): + x = _to_tensor(test_X, dtype=torch.float32, device=next(self.parameters()).device) + return self.embedding(x) + + def get_spn_variables(self): + return list(self.clf.parameters()) if hasattr(self, "clf") else [] + + +def enc_layer(net, stride, filter_num, batchnorm_integration, shortcut, + dilations, filter_size, trainable, dtype, old_h, old_num_filter, dropout, activation, + name): + conv = nn.Conv2d(old_num_filter, filter_num, kernel_size=filter_size, + stride=stride, padding=filter_size // 2, + dilation=dilations, bias=False) + layers = [conv] + + if batchnorm_integration: + layers.append(nn.BatchNorm2d(filter_num)) + + if dropout: + layers.append(nn.Dropout2d(p=dropout)) + + if activation is not None: + layers.append(nn.ReLU() if activation is F.relu else nn.Tanh()) + + block = nn.Sequential(*layers) + net.append(block) + return net, filter_num + + +def make_costume_encoder(filter_size, num_layer, input_shape, batchnorm_integration, num_filter, shortcut, strides, + activation, encoder_name, dtype, dilations, dropout, end_dim): + class CostumeEncoder(nn.Module): + def __init__(self): + super().__init__() + c_in = input_shape[2] + img_size = input_shape[0] + all_filters = [c_in] + num_filter + blocks = [] + in_ch = c_in + for layer_num, stride in enumerate(strides): + out_ch = all_filters[layer_num + 1] + conv = nn.Conv2d(in_ch, out_ch, kernel_size=filter_size, + stride=stride, padding=filter_size // 2, + dilation=dilations, bias=False) + sub_layers = [conv] + if batchnorm_integration: + sub_layers.append(nn.BatchNorm2d(out_ch)) + if dropout: + sub_layers.append(nn.Dropout2d(p=dropout)) + if activation is not None: + if activation is F.relu: + 
sub_layers.append(nn.ReLU()) + else: + sub_layers.append(nn.Tanh()) + blocks.append(nn.Sequential(*sub_layers)) + in_ch = out_ch + img_size = img_size // stride + + self.blocks = nn.ModuleList(blocks) + self.flatten = nn.Flatten() + self.fc1 = nn.Linear(in_ch * img_size * img_size, end_dim) + self.fc2 = nn.Linear(end_dim, 1) + + def forward(self, x): + if x.dim() == 4 and x.shape[1] != 1 and x.shape[1] != 3: + x = x.permute(0, 3, 1, 2).contiguous() + for blk in self.blocks: + x = blk(x) + x = self.flatten(x) + x = self.fc1(x) + x = F.relu(x) + x = self.fc2(x) + return x + + return CostumeEncoder() + + +def gradient_penalty(critic, real_images, fake_images): + device = real_images.device + batch_size = real_images.size(0) + alpha = torch.rand(batch_size, 1, 1, 1, device=device) + interpolated = alpha * real_images + (1 - alpha) * fake_images + interpolated.requires_grad_(True) + + critic_interpolates = critic(interpolated) + if critic_interpolates.dim() > 1: + critic_interpolates = critic_interpolates.view(-1) + + grads = torch.autograd.grad( + outputs=critic_interpolates, + inputs=interpolated, + grad_outputs=torch.ones_like(critic_interpolates), + create_graph=True, + retain_graph=True, + only_inputs=True + )[0] + + grads = grads.view(batch_size, -1) + grad_norm = torch.sqrt(torch.sum(grads ** 2, dim=1) + 1e-8) + penalty = torch.mean((grad_norm - 1.0) ** 2) + return penalty + + +class CNN_SPN_Parts(nn.Module): + def __init__(self, + num_classes, + learning_rate, + all_spn_x_y_model, + all_prior, + cnn=None, + get_max=True, + dtype=torch.float32, + gauss_embeds=0.01, + use_add_info=False, + VAE_fine_tune=0, + decoder=None, + loss_weights=None, + load_pretrain_model=1, + use_GAN=False, + + filter_size=3, + num_layer=4, + input_shape=(128, 128, 1), + batchnorm_integration=1, + shortcut=0, + activation='relu', + num_filter=None, + strides=None, + end_dim=100, + dropout=0.2, + dilations=1, + discriminator_name='gan_discriminator', + clf_mlp=None + + ): + super(CNN_SPN_Parts, self).__init__() + if num_filter is None: + num_filter = [64, 128, 128] + if strides is None: + strides = [2, 2, 2, 2] + if loss_weights is None: + loss_weights = [1.0, 1.0, 1.0] + + self.clf_mlp = clf_mlp + self.load_pretrain_model = load_pretrain_model + self.all_spn_x_y = all_spn_x_y_model + self.use_add_info = use_add_info + self.decoder = decoder + self.VAE_fine_tune = VAE_fine_tune + self.loss_weights = loss_weights + + prior = torch.tensor(all_prior, dtype=dtype) + self.register_buffer("prior_weights", prior) + + self.num_classes = num_classes + self.get_max = get_max + + self.embedding = cnn + + self.spn_training = True + self.clf = all_spn_x_y_model + + self.clf_loss = nn.NLLLoss() + self.gauss_embed = gauss_embeds + self.gauss_std = gauss_embeds + + self.optimizer = torch.optim.Adam(self.parameters(), lr=learning_rate) + + self.use_GAN = use_GAN + if self.use_GAN: + if activation == 'relu': + act_fn = F.relu + elif activation == 'tanh': + act_fn = torch.tanh + else: + act_fn = F.relu + + self.discriminator = make_costume_encoder( + filter_size, num_layer, input_shape, + batchnorm_integration, num_filter, shortcut, strides, + act_fn, discriminator_name, dtype, dilations, dropout, end_dim + ) + + self.gan_discr_optimizer = torch.optim.Adam( + self.discriminator.parameters(), lr=0.0001, betas=(0.5, 0.9) + ) + self.gan_gener_optimizer = torch.optim.Adam( + self.decoder.parameters(), lr=0.0001, betas=(0.5, 0.9) + ) + + def _maybe_gauss(self, embedding, training): + if self.gauss_embed and training: + noise = 
torch.randn_like(embedding) * self.gauss_std + embedding = embedding + noise + return embedding + + def model_execution_X_y(self, X, y): + + device = next(self.parameters()).device + X = _to_tensor(X, dtype=torch.float32, device=device) + y = _to_tensor(y, dtype=torch.long, device=device) + + embedding, _, _ = self.embedding(X) + + if self.get_max: + embedding = torch.amax(embedding, dim=(1, 2)) + + embedding = self._maybe_gauss(embedding, training=self.training) + + y_float = y.to(torch.float32).unsqueeze(-1) + spn_input = torch.cat([y_float, embedding], dim=-1) + + weights = self.prior_weights.view(1, -1) + inputs = [] + for sub_spn in self.all_spn_x_y: + out = sub_spn(spn_input) + if out.dim() > 1: + out = out.squeeze(-1) + inputs.append(out) + children_prob = torch.stack(inputs, dim=1) # [B, K] + log_enumerator = children_prob + torch.log(weights) + p_x = torch.logsumexp(log_enumerator, dim=1) + p_y_x = log_enumerator - p_x.unsqueeze(1) + return p_y_x + + def spn_clf(self, embedding, training): + + device = embedding.device + embedding = embedding.to(device) + + weights = self.prior_weights.view(1, -1) # [1, K] + inputs = [] + for label_id, sub_spn in enumerate(self.all_spn_x_y): + y = torch.full((embedding.size(0), 1), float(label_id), + device=device, dtype=embedding.dtype) + spn_input = torch.cat([y, embedding], dim=-1) + out = sub_spn(spn_input) + if out.dim() > 1: + out = out.squeeze(-1) + inputs.append(out) + children_prob = torch.stack(inputs, dim=1) # [B, K] + log_enumerator = children_prob + torch.log(weights) + p_x = torch.logsumexp(log_enumerator, dim=1, keepdim=True) + p_y_x = log_enumerator - p_x + return p_y_x, p_x + + def model_execution_X(self, X, other_data, training=True): + + self.train(mode=training) + device = next(self.parameters()).device + X = _to_tensor(X, dtype=torch.float32, device=device) + if other_data is not None and not isinstance(other_data, int): + other_data = _to_tensor(other_data, dtype=torch.float32, device=device) + else: + other_data = None + + embedding, _, _ = self.embedding(X) + + if self.get_max: + embedding = torch.amax(embedding, dim=(1, 2)) + + embedding = self._maybe_gauss(embedding, training=training) + + if self.use_add_info and other_data is not None: + embedding = torch.cat([embedding, other_data], dim=-1) + + p_y_x, p_x = self.spn_clf(embedding, training) + return p_y_x + + def train_step(self, x, y, other_data): + + self.train() + device = next(self.parameters()).device + x = _to_tensor(x, dtype=torch.float32, device=device) + y = _to_tensor(y, dtype=torch.long, device=device) + if other_data is not None and not isinstance(other_data, int): + other_data = _to_tensor(other_data, dtype=torch.float32, device=device) + else: + other_data = None + + self.optimizer.zero_grad() + spn_out = self.model_execution_X(x, other_data, training=True) + loss = self.clf_loss(spn_out, y) + loss.backward() + self.optimizer.step() + return [loss.detach()] + + def reconstruct(self, embedding): + reconstruction = self.decoder(embedding) + return reconstruction + + def vae_rec(self, X): + self.eval() + with torch.no_grad(): + device = next(self.parameters()).device + X = _to_tensor(X, dtype=torch.float32, device=device) + embedding, _, _ = self.embedding(X, training=False) + return self.reconstruct(embedding) + + def model_execution_vae(self, X, y_real, other_data, training=True): + + self.train(mode=training) + device = next(self.parameters()).device + X = _to_tensor(X, dtype=torch.float32, device=device) + y_real = _to_tensor(y_real, dtype=torch.long, 
device=device)
+        if other_data is not None and not isinstance(other_data, int):
+            other_data = _to_tensor(other_data, dtype=torch.float32, device=device)
+        else:
+            other_data = None
+
+        normalized_X = X
+        if self.load_pretrain_model:
+            # assumes channels-last input; keep the first channel and rescale to [0, 1]
+            if normalized_X.dim() == 4 and normalized_X.shape[-1] > 1:
+                normalized_X = normalized_X[..., 0:1]
+            normalized_X = normalized_X / 255.0
+
+        embedding_, embed_mean, embed_var = self.embedding(X, training=training)
+        embedding = embedding_.clone()
+        reconstruction = self.decoder(embedding)
+
+        # KL(q(z|x) || N(0, I)) per sample; embed_var holds log-variances
+        kl_loss = -0.5 * torch.sum(1 + embed_var - embed_mean.pow(2) - embed_var.exp(), dim=1)
+
+        if self.get_max:
+            embedding = torch.amax(embedding_, dim=(1, 2))
+
+        embedding = self._maybe_gauss(embedding, training=training)
+
+        if self.use_add_info and other_data is not None:
+            embedding = torch.cat([embedding, other_data], dim=-1)
+
+        weights = self.prior_weights.view(1, -1)
+        inputs = []
+        for label_id, sub_spn in enumerate(self.all_spn_x_y):
+            y = torch.full((embedding.size(0), 1), float(label_id),
+                           device=device, dtype=embedding.dtype)
+            spn_input = torch.cat([y, embedding], dim=-1)
+            out = sub_spn(spn_input)
+            if out.dim() > 1:
+                out = out.squeeze(-1)
+            inputs.append(out)
+        children_prob = torch.stack(inputs, dim=1)  # [B, K]
+        log_enumerator = children_prob + torch.log(weights)
+        p_x = torch.logsumexp(log_enumerator, dim=1, keepdim=True)
+        p_y_x = log_enumerator - p_x
+
+        clf_loss = self.clf_loss(p_y_x, y_real)
+
+        if reconstruction.shape != normalized_X.shape:
+            # decoder emits NCHW while the target is NHWC; align layouts for the loss
+            if reconstruction.dim() == 4 and reconstruction.shape[1] == normalized_X.shape[-1]:
+                reconstruction_for_loss = reconstruction.permute(0, 2, 3, 1)
+            else:
+                reconstruction_for_loss = reconstruction
+        else:
+            reconstruction_for_loss = reconstruction
+
+        # reduce the element-wise MSE to one reconstruction loss per sample
+        rec_loss_per_pixel = F.mse_loss(reconstruction_for_loss, normalized_X, reduction="none")
+        while rec_loss_per_pixel.dim() > 1:
+            rec_loss_per_pixel = rec_loss_per_pixel.mean(dim=-1)
+        rec_loss = rec_loss_per_pixel
+
+        rec_w, kl_w, clf_w = self.loss_weights
+        loss = rec_loss * 2 * rec_w + kl_loss * kl_w + clf_loss * clf_w
+
+        rec_loss = rec_loss.mean()
+        kl_loss = kl_loss.mean()
+        loss = loss.mean()
+
+        return p_y_x, loss, rec_loss, clf_loss, kl_loss, embedding_
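+
+    # Per sample, the objective above is
+    #     loss = 2 * rec_w * MSE(x_hat, x) + kl_w * KL(q(z|x) || N(0, I)) + clf_w * NLL(y),
+    # with KL(q || N(0, I)) = -0.5 * sum(1 + log_var - mean^2 - exp(log_var)),
+    # reduced to scalars only after the weighted per-sample combination.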
+
+    def model_execution_vae_eval(self, X, y_real, other_data):
+        training = False
+        self.eval()
+        device = next(self.parameters()).device
+        X = _to_tensor(X, dtype=torch.float32, device=device)
+        y_real = _to_tensor(y_real, dtype=torch.long, device=device)
+        if other_data is not None and not isinstance(other_data, int):
+            other_data = _to_tensor(other_data, dtype=torch.float32, device=device)
+        else:
+            other_data = None
+
+        normalized_X = X
+        if self.load_pretrain_model:
+            if normalized_X.dim() == 4 and normalized_X.shape[-1] > 1:
+                normalized_X = normalized_X[..., 0:1]
+            normalized_X = normalized_X / 255.0
+
+        with torch.no_grad():
+            embedding_, embed_mean, embed_var = self.embedding(X, training=training)
+            embedding = embedding_.clone()
+            reconstruction = self.decoder(embedding)
+
+            kl_loss = -0.5 * torch.sum(1 + embed_var - embed_mean.pow(2) - embed_var.exp(), dim=1)
+
+            if self.get_max:
+                embedding = torch.amax(embedding_, dim=(1, 2))
+
+            embedding = self._maybe_gauss(embedding, training=training)
+
+            if self.use_add_info and other_data is not None:
+                embedding = torch.cat([embedding, other_data], dim=-1)
+
+            weights = self.prior_weights.view(1, -1)
+            inputs = []
+            for label_id, sub_spn in enumerate(self.all_spn_x_y):
+                y = torch.full((embedding.size(0), 1), float(label_id),
+                               device=device, dtype=embedding.dtype)
+                spn_input = torch.cat([y, embedding], dim=-1)
+                out = sub_spn(spn_input)
+                if out.dim() > 1:
+                    out = out.squeeze(-1)
+                inputs.append(out)
+            children_prob = torch.stack(inputs, dim=1)  # [B, K]
+            log_enumerator = children_prob + torch.log(weights)
+            p_x = torch.logsumexp(log_enumerator, dim=1, keepdim=True)
+            p_y_x = log_enumerator - p_x
+
+            p_y_x_mlp = self.clf_mlp(embedding) if self.clf_mlp is not None else None
+
+            if reconstruction.shape != normalized_X.shape:
+                if reconstruction.dim() == 4 and reconstruction.shape[1] == normalized_X.shape[-1]:
+                    reconstruction_for_loss = reconstruction.permute(0, 2, 3, 1)
+                else:
+                    reconstruction_for_loss = reconstruction
+            else:
+                reconstruction_for_loss = reconstruction
+
+            clf_loss = self.clf_loss(p_y_x, y_real)
+
+            rec_loss_per_pixel = F.mse_loss(reconstruction_for_loss, normalized_X, reduction="none")
+            while rec_loss_per_pixel.dim() > 1:
+                rec_loss_per_pixel = rec_loss_per_pixel.mean(dim=-1)
+            rec_loss = rec_loss_per_pixel
+
+            mae_per_pixel = F.l1_loss(reconstruction_for_loss, normalized_X, reduction="none")
+            while mae_per_pixel.dim() > 1:
+                mae_per_pixel = mae_per_pixel.mean(dim=-1)
+            mae = mae_per_pixel
+
+            rec_w, kl_w, clf_w = self.loss_weights
+            loss = rec_loss * 2 * rec_w + kl_loss * kl_w + clf_loss * clf_w
+
+            rec_loss = rec_loss.mean()
+            kl_loss = kl_loss.mean()
+            loss = loss.mean()
+            mae = mae.mean()
+
+        return p_y_x, loss, rec_loss, clf_loss, kl_loss, mae, embedding_, p_y_x_mlp
+
+    def train_step_vae_one_loss(self, x, y, other_data):
+
+        self.train()
+        device = next(self.parameters()).device
+        x = _to_tensor(x, dtype=torch.float32, device=device)
+        y = _to_tensor(y, dtype=torch.long, device=device)
+        if other_data is not None and not isinstance(other_data, int):
+            other_data = _to_tensor(other_data, dtype=torch.float32, device=device)
+        else:
+            other_data = None
+
+        self.optimizer.zero_grad()
+        spn_out, loss, rec_loss, clf_loss, kl_loss, z = self.model_execution_vae(x, y, other_data, training=True)
+        loss.backward()
+        self.optimizer.step()
+        return [loss.detach(), rec_loss.detach(), clf_loss.detach(), kl_loss.detach()], z.detach(), spn_out.detach()
+
+    def gan_step(self, X, y, other_info, discriminator_loss_old, generator_loss_old):
+
+        device = next(self.parameters()).device
+        X = _to_tensor(X, dtype=torch.float32, device=device)
+
+        for _ in range(5):
+            self.gan_discr_optimizer.zero_grad()
+            # detach the fakes so the critic update does not push gradients
+            # into the encoder/decoder parameters
+            fake_images = self.decoder(self.embedding(X)[0]).detach()
+            real_images = X[..., 0:1] / 255.0
+            fake_out = self.discriminator(torch.sigmoid(fake_images))
+            real_out = self.discriminator(real_images)
+
+            gp = gradient_penalty(self.discriminator, real_images, torch.sigmoid(fake_images))
+            gan_loss = fake_out.mean() - real_out.mean() + 10 * gp
+            gan_loss.backward()
+            self.gan_discr_optimizer.step()
+
+        self.gan_gener_optimizer.zero_grad()
+        fake_images = self.decoder(self.embedding(X)[0])
+        fake_out = self.discriminator(torch.sigmoid(fake_images))
+        generator_loss = -fake_out.mean()
+        generator_loss.backward()
+        self.gan_gener_optimizer.step()
+
+        return gan_loss.detach(), generator_loss.detach()
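+
+    # WGAN-GP reference for gan_step: per generator update the critic takes
+    # five steps minimizing
+    #     E[D(fake)] - E[D(real)] + 10 * E[(||grad D(x_hat)||_2 - 1)^2],
+    # where x_hat are random interpolates between real and reconstructed
+    # images (see gradient_penalty); the generator then maximizes E[D(fake)].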
+
+    def train_gan(self, train_ds, first_loss):
+        counter = 0.0
+        all_losses = np.zeros(6, dtype=np.float64)
+        z_train, gt_train, predictions_train = [], [], []
+        discriminator_loss_old, generator_loss_old = 0, 0
+        for it_c, train_rec in enumerate(train_ds):
+            other_info = 0
+            if self.use_add_info:
+                (X, y, other_info) = train_rec
+            else:
+                (X, y) = train_rec
+
+            loss, z, pred = self.train_step_vae_one_loss(X, y, other_info)
+            discriminator_loss, generator_loss = self.gan_step(
+                X, y, other_info, discriminator_loss_old, generator_loss_old
+            )
+
+            z_train.append(z.cpu().numpy())
+            gt_train.append(y)
+            predictions_train.append(pred.cpu().numpy())
+
+            loss_vals = [entry.cpu().item() for entry in loss]
+            loss_vals.append(discriminator_loss.cpu().item())
+            loss_vals.append(generator_loss.cpu().item())
+            discriminator_loss_old, generator_loss_old = discriminator_loss.cpu().item(), generator_loss.cpu().item()
+
+            all_losses += np.array(loss_vals, dtype=np.float64)
+            counter += 1
+            if not it_c and first_loss:
+                print(it_c, 'loss', all_losses / counter)
+        return [all_losses / counter, z_train, gt_train, predictions_train]
+
+    def train_step_vae_diff_loss(self, x, y, other_data):
+        self.train()
+        device = next(self.parameters()).device
+        x = _to_tensor(x, dtype=torch.float32, device=device)
+        y = _to_tensor(y, dtype=torch.long, device=device)
+        if other_data is not None and not isinstance(other_data, int):
+            other_data = _to_tensor(other_data, dtype=torch.float32, device=device)
+        else:
+            other_data = None
+
+        self.optimizer.zero_grad()
+        spn_out = self.model_execution_X(x, other_data, training=True)
+        loss = self.clf_loss(spn_out, y)
+        loss.backward()
+        self.optimizer.step()
+        return loss.detach()
+
+    def train_model(self, train_ds, first_loss):
+        """Run one training pass over train_ds, dispatching on VAE_fine_tune:
+        0 trains the SPN head only, 1 trains VAE + SPN on the combined loss,
+        and 2 trains on the classification loss alone."""
+        counter = 0.0
+        if self.VAE_fine_tune == 0:
+            all_losses = np.zeros(1, dtype=np.float64)
+            for it_c, train_rec in enumerate(train_ds):
+                other_info = 0
+                if self.use_add_info:
+                    (X, y, other_info) = train_rec
+                else:
+                    (X, y) = train_rec
+
+                loss_list = self.train_step(X, y, other_info)
+                new_loss = [entry.cpu().item() for entry in loss_list]
+                all_losses += np.array(new_loss, dtype=np.float64)
+                counter += 1
+                if not it_c and first_loss:
+                    print(it_c, 'loss', all_losses / counter)
+            return all_losses / counter
+        elif self.VAE_fine_tune == 1:
+            all_losses = np.zeros(4, dtype=np.float64)
+            z_train, gt_train, predictions_train = [], [], []
+            for it_c, train_rec in enumerate(train_ds):
+                other_info = 0
+                if self.use_add_info:
+                    (X, y, other_info) = train_rec
+                else:
+                    (X, y) = train_rec
+
+                loss_list, z, pred = self.train_step_vae_one_loss(X, y, other_info)
+
+                z_train.append(z.cpu().numpy())
+                gt_train.append(y)
+                predictions_train.append(pred.cpu().numpy())
+
+                new_loss = [entry.cpu().item() for entry in loss_list]
+                all_losses += np.array(new_loss, dtype=np.float64)
+                counter += 1
+                if not it_c and first_loss:
+                    print(it_c, 'loss', all_losses / counter)
+            return [all_losses / counter, z_train, gt_train, predictions_train]
+
+        elif self.VAE_fine_tune == 2:
+            # a single loss is tracked here, so use a length-1 accumulator
+            all_losses = np.zeros(1, dtype=np.float64)
+            for it_c, train_rec in enumerate(train_ds):
+                other_info = 0
+                if self.use_add_info:
+                    (X, y, other_info) = train_rec
+                else:
+                    (X, y) = train_rec
+
+                loss = self.train_step_vae_diff_loss(X, y, other_info)
+                new_loss = [loss.cpu().item()]
+                all_losses += np.array(new_loss, dtype=np.float64)
+                counter += 1
+                if not it_c and first_loss:
+                    print(it_c, 'loss', all_losses / counter)
+            return all_losses / counter
+
+    def eval_cnn(self, test_X):
+        self.eval()
+        with torch.no_grad():
+            x = _to_tensor(test_X, dtype=torch.float32, device=next(self.parameters()).device)
+            return self.embedding(x)
+
+    def get_spn_variables(self):
+        vars_list = []
+        for sub_spn in self.all_spn_x_y:
+            if hasattr(sub_spn, "parameters"):
+                vars_list.extend(list(sub_spn.parameters()))
+        return vars_list
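+
+
+# Illustrative helper (hypothetical, not called anywhere in the pipeline):
+# the SPN heads above turn per-class joint log-likelihoods into posteriors
+# via Bayes' rule in log space, which is the inline computation performed by
+# spn_clf and model_execution_X_y.
+def _log_posterior_sketch(class_loglikes, prior_weights):
+    """class_loglikes: [B, K] log p(x | y=k); prior_weights: [K] class priors."""
+    log_joint = class_loglikes + torch.log(prior_weights).view(1, -1)
+    return log_joint - torch.logsumexp(log_joint, dim=1, keepdim=True)  # [B, K] log p(y | x)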
+
+
+def train_model_parts(grid_params, cnn_spn, train_data, val_data, test_data,
+                      num_iterations, ckpt, manager, val_entropy, val_acc=0,
+                      add_info=False):
+    # the module header only imports Dataset/DataLoader, so pull in
+    # TensorDataset here
+    from torch.utils.data import TensorDataset
+
+    device = next(cnn_spn.parameters()).device
+
+    first_loss = True
+
+    num_params = sum(p.numel() for p in cnn_spn.parameters() if p.requires_grad)
+    print('number of trainable variables in cnn spn:', num_params)
+
+    train_start_time = time.time()
+    best_acc_loss = val_acc
+    best_val_reconstruction = 1e8
+    eval_after_train = []
+    all_debugging_stuff = []
+
+    for i in range(num_iterations):
+        idx_train = list(range(train_data[0].shape[0]))
+        random.shuffle(idx_train)
+
+        X = train_data[0][idx_train]
+
+        if add_info:
+            y = train_data[1][idx_train, 0]
+            other = train_data[1][idx_train, 1:]
+
+            X_t = torch.from_numpy(X).to(device=device, dtype=torch.float32)
+            y_t = torch.from_numpy(y).to(device=device, dtype=torch.long)
+            other_t = torch.from_numpy(other).to(device=device, dtype=torch.float32)
+
+            if getattr(grid_params, "use_add_info", False):
+                dataset = TensorDataset(X_t, y_t, other_t)
+            else:
+                dataset = TensorDataset(X_t, y_t)
+        else:
+            y = train_data[1][idx_train]
+            X_t = torch.from_numpy(X).to(device=device, dtype=torch.float32)
+            y_t = torch.from_numpy(y).to(device=device, dtype=torch.long)
+            dataset = TensorDataset(X_t, y_t)
+
+        train_loader = DataLoader(
+            dataset,
+            batch_size=grid_params.batch_size,
+            shuffle=False  # the epoch is already shuffled via idx_train
+        )
+
+        if grid_params.VAE_fine_tune and not getattr(grid_params, "GAN", False):
+
+            train_loss, z_train, gt_train, predictions_train = cnn_spn.train_model(
+                train_loader, first_loss=first_loss
+            )
+        elif grid_params.VAE_fine_tune and getattr(grid_params, "GAN", False):
+
+            train_loss, z_train, gt_train, predictions_train = cnn_spn.train_gan(
+                train_loader, first_loss=first_loss
+            )
+        else:
+
+            train_loss = cnn_spn.train_model(train_loader, first_loss=first_loss)
+
+        first_loss = False
+
+        if i % 3 == 0:
+            improve = False
+
+            if grid_params.VAE_fine_tune:
+
+                val_losses, z, gt, arg_max, pred_exp = test_model_all(
+                    cnn_spn,
+                    val_data,
+                    num_classes=2,
+                    batch_size=grid_params.batch_size,
+                    training=False,
+                    add_info=add_info
+                )
+                curr_acc = val_losses[2]
+
+                if curr_acc > best_acc_loss:
+                    best_acc_loss = curr_acc
+                    improve = True
+
+                    eval_after_train = test_model_no_mpe(
+                        cnn_spn,
+                        [train_data, val_data, test_data],
+                        cnn_spn.num_classes,
+                        grid_params.batch_size,
+                        add_info=add_info
+                    )
+
+            else:
+                [val_losses] = test_model_no_mpe(
+                    cnn_spn,
+                    [val_data],
+                    num_classes=2,
+                    batch_size=grid_params.batch_size,
+                    training=False,
+                    add_info=add_info
+                )
+                curr_acc = val_losses[2]
+
+                if curr_acc > best_acc_loss:
+                    best_acc_loss = curr_acc
+                    improve = True
+
+                    eval_after_train = test_model_no_mpe(
+                        cnn_spn,
+                        [train_data, val_data, test_data],
+                        cnn_spn.num_classes,
+                        grid_params.batch_size,
+                        add_info=add_info
+                    )
+
+            if improve and ckpt is not None and manager is not None:
+                # duck-typed compatibility shim for TF-style checkpoint objects
+                try:
+                    if len(ckpt) > 0 and hasattr(ckpt[0], "step"):
+                        step0 = ckpt[0].step
+                        if hasattr(step0, "assign_add"):
+                            step0.assign_add(1)
+                        elif isinstance(step0, (int, float)):
+                            ckpt[0].step = step0 + 1
+                    if len(manager) > 0 and hasattr(manager[0], "save"):
+                        manager[0].save()
+
+                    if len(ckpt) > 1 and hasattr(ckpt[1], "step"):
+                        step1 = ckpt[1].step
+                        if hasattr(step1, "assign_add"):
+                            step1.assign_add(1)
+                        elif isinstance(step1, (int, float)):
+                            ckpt[1].step = step1 + 1
+
+                    if len(manager) > 1 and hasattr(manager[1], "save"):
+                        manager[1].save()
+                except TypeError:
+                    pass
+
+            print(
+                'val entropy', i,
+                val_losses[1],
+                'improve', improve,
+                'curr acc', val_losses[0], val_losses[2],
+                'best acc', best_acc_loss
+            )
+
+            if grid_params.VAE_fine_tune:
+                print('curr rec N/A', 'best rec:', best_val_reconstruction)
+
+        print('loss', i, end=': ')
+        if isinstance(train_loss, (np.ndarray, list, tuple)):
+            for loss_val in train_loss:
+                loss_val = float(loss_val)
+                print(np.round(loss_val, 5), end=' - ')
+        else:
+            print(np.round(float(train_loss), 5), end=' - ')
+        print()
+
+    print('fine tune train time', (time.time() - train_start_time) // 60)
+    return eval_after_train, all_debugging_stuff
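+
+
+# Data layout assumed by train_model_parts and the test_model_* helpers below:
+# each split is a tuple (X, meta) of numpy arrays, where X holds [N, H, W, C]
+# images and, when add_info=True, meta is [N, 1 + extra] with the label in
+# column 0 and auxiliary features after it; otherwise meta is the [N] labels.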
+
+
+def test_model_no_mpe(tfmodel, data_sets, num_classes, batch_size=25, training=False, add_info=False):
+    device = next(tfmodel.parameters()).device
+    all_evals = []
+    for dataset in data_sets:
+        prediction = []
+        gt = []
+        for i in range(dataset[0].shape[0] // batch_size):
+            X_data = dataset[0][i * batch_size:(i + 1) * batch_size]
+            other_data = 0
+            if add_info:
+                y_data = dataset[1][i * batch_size:(i + 1) * batch_size, 0]
+                other_data = dataset[1][i * batch_size:(i + 1) * batch_size, 1:]
+            else:
+                y_data = dataset[1][i * batch_size:(i + 1) * batch_size]
+
+            X_t = torch.from_numpy(X_data).to(device=device, dtype=torch.float32)
+            if isinstance(y_data, np.ndarray):
+                y_t = torch.from_numpy(y_data).to(device=device, dtype=torch.long)
+            else:
+                y_t = torch.tensor(y_data, device=device, dtype=torch.long)
+
+            if add_info and isinstance(other_data, np.ndarray):
+                other_t = torch.from_numpy(other_data).to(device=device, dtype=torch.float32)
+            else:
+                other_t = None
+
+            pred = tfmodel.model_execution_X(X_t, other_t, training=False)
+            prediction.extend(pred.detach().cpu().numpy().tolist())
+            gt.extend(y_t.detach().cpu().numpy().tolist())
+
+        prediction = np.asarray(prediction)
+        gt_np = np.asarray(gt)
+
+        logits_t = torch.from_numpy(prediction).to(device=device, dtype=torch.float32)
+        gt_t = torch.from_numpy(gt_np).to(device=device, dtype=torch.long)
+        entropy = tfmodel.clf_loss(logits_t, gt_t).item()
+
+        prediction_exponential = np.exp(prediction)
+        arg_max = np.argmax(prediction_exponential, axis=-1)
+        acc = accuracy_score(gt_np, arg_max, normalize=True)
+
+        auc = 0
+        if num_classes == 2:
+            pred_exp = np.nan_to_num(
+                prediction_exponential[:, 1],
+                nan=0.0,
+                posinf=1.0,
+                neginf=0.0,
+            )
+            # sklearn's metrics module is imported as sk_metrics in the header
+            fpr, tpr, thresholds = sk_metrics.roc_curve(gt_np, pred_exp, pos_label=1)
+            auc = sk_metrics.auc(fpr, tpr)
+        balanced_acc = balanced_accuracy_score(gt_np, arg_max)
+        [prec, rec, f1, _] = sk_metrics.precision_recall_fscore_support(gt_np, arg_max, average=None)
+        results = [acc, entropy, balanced_acc, prec[1], rec[1], f1[1], auc]
+        all_evals.append(results)
+        print(results)
+    return all_evals
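+
+
+# Note: for the binary case, the roc_curve/auc pair above is equivalent to
+# sk_metrics.roc_auc_score(gt_np, pred_exp).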
+
+
+def test_model_SPN_MLP(tfmodel, dataset, num_classes, batch_size=25, training=False, add_info=False):
+    mlp_prediction = []
+    prediction = []
+    gt = []
+    z = []
+    losses = np.zeros(5, dtype=np.float64)
+    n_samples = dataset[0].shape[0]
+    for i in range(n_samples // batch_size):
+        X_data = dataset[0][i * batch_size:(i + 1) * batch_size]
+        other_data = 0
+        if add_info:
+            y_data = dataset[1][i * batch_size:(i + 1) * batch_size, 0]
+            other_data = dataset[1][i * batch_size:(i + 1) * batch_size, 1:]
+        else:
+            y_data = dataset[1][i * batch_size:(i + 1) * batch_size]
+
+        pred, loss, rec_loss, clf_loss_val, kl_loss, mae, embedding_, mlp_pred = tfmodel.model_execution_vae_eval(
+            X_data, y_data, other_data
+        )
+        new_loss = np.asarray(
+            [loss.item(), rec_loss.item(), clf_loss_val.item(), kl_loss.item(), mae.item()],
+            dtype=np.float64
+        )
+        losses += new_loss
+        prediction.extend(pred.detach().cpu().numpy().tolist())
+        mlp_prediction.extend(mlp_pred.detach().cpu().numpy().tolist())
+        gt.extend(y_data.tolist())
+        z.extend(embedding_.detach().cpu().numpy().tolist())
+    losses /= n_samples // batch_size
+    prediction = np.asarray(prediction)
+    mlp_prediction = np.asarray(mlp_prediction)
+
+    prediction_exponential = np.exp(prediction)
+
+    # the MLP scores are passed twice on purpose: argmax and ROC ranking are
+    # invariant under the monotone exponentiation, so no exp is needed there
+    results_MLP, _, _ = eval_cls(mlp_prediction, mlp_prediction, gt, tfmodel.clf_loss, num_classes)
+    results_SPN, _, _ = eval_cls(prediction, prediction_exponential, gt, tfmodel.clf_loss, num_classes)
+
+    return results_MLP, results_SPN, losses
+
+
+def test_model_all(tfmodel, dataset, num_classes, batch_size=25, training=False, add_info=False):
+    prediction = []
+    gt = []
+    z = []
+    losses = np.zeros(4, dtype=np.float64)
+    n_samples = dataset[0].shape[0]
+    for i in range(n_samples // batch_size):
+        X_data = dataset[0][i * batch_size:(i + 1) * batch_size]
+        other_data = 0
+        if add_info:
+            y_data = dataset[1][i * batch_size:(i + 1) * batch_size, 0]
+            other_data = dataset[1][i * batch_size:(i + 1) * batch_size, 1:]
+        else:
+            y_data = dataset[1][i * batch_size:(i + 1) * batch_size]
+
+        pred, loss, rec_loss, clf_loss_val, kl_loss, mae, embedding_, mlp_pred = tfmodel.model_execution_vae_eval(
+            X_data, y_data, other_data
+        )
+        new_loss = np.asarray(
+            [loss.item(), rec_loss.item(), clf_loss_val.item(), kl_loss.item()],
+            dtype=np.float64
+        )
+        losses += new_loss
+        prediction.extend(pred.detach().cpu().numpy().tolist())
+        gt.extend(y_data.tolist())
+        z.extend(embedding_.detach().cpu().numpy().tolist())
+    losses /= n_samples // batch_size
+    prediction = np.asarray(prediction)
+
+    prediction_exponential = np.exp(prediction)
+
+    results, pred_exp, arg_max = eval_cls(
+        prediction, prediction_exponential, gt, tfmodel.clf_loss, num_classes
+    )
+    return results, z, gt, arg_max, pred_exp
+
+
+def eval_cls(pred_logits, prediction_exponential, gt, clf_loss, num_classes=2):
+    device = torch.device("cpu")
+    pred_arg_max = np.argmax(prediction_exponential, axis=-1)
+
+    logits_t = torch.from_numpy(pred_logits).to(dtype=torch.float32, device=device)
+    gt_t = torch.tensor(gt, dtype=torch.long, device=device)
+    entropy = clf_loss(logits_t, gt_t).item()
+
+    acc = accuracy_score(gt, pred_arg_max, normalize=True)
+
+    auc = 0
+    pred_exp = None
+    if num_classes == 2:
+        pred_exp = np.nan_to_num(prediction_exponential[:, 1], nan=0, posinf=1.0)
+        fpr, tpr, thresholds = sk_metrics.roc_curve(gt, pred_exp, pos_label=1)
+        auc = sk_metrics.auc(fpr, tpr)
+    balanced_acc = balanced_accuracy_score(gt, pred_arg_max)
+    prec, rec, f1, _ = sk_metrics.precision_recall_fscore_support(gt, pred_arg_max, average=None)
+
+    return [acc, entropy, balanced_acc, prec[1], rec[1], f1[1], auc], pred_exp, pred_arg_max
\ No newline at end of file
diff --git a/torch_spn.py b/torch_spn.py
new file mode 100644
index 0000000..6cbcf43
--- /dev/null
+++ b/torch_spn.py
@@ -0,0 +1,177 @@
+import torch
+import torch.nn as nn
+import numpy as np
+
+from spn.structure.leaves.parametric.Parametric import Gaussian, Categorical
+from spn.structure.Base import Sum, Product, Leaf, get_topological_order
+from spn.algorithms.TransformStructure import Copy
+
+import torch_spn_layers as SPN_lay
+
+
+_node_log_torch_graph = {
+    Sum: SPN_lay.log_sum_to_torch_graph,
+    Product: SPN_lay.log_prod_to_torch_graph,
+    Gaussian: SPN_lay.gaussian_to_torch_graph,
+    Categorical: SPN_lay.categorical_to_torch_graph,
+}
+
+
+def torch_graph_to_sum(node, torchvar):
+
+    with torch.no_grad():
+        weights = torch.softmax(torchvar, dim=0).cpu().numpy().tolist()
+        node.weights = weights
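+
+
+# The torch_graph_to_* functions (torch_graph_to_sum above and the leaf
+# variants below) write learned torch parameters back into the SPFlow nodes,
+# i.e. the inverse direction of the *_to_torch_graph builders in
+# torch_spn_layers: sum logits are re-normalized with a softmax so
+# node.weights stays a valid distribution, and leaf parameters are copied
+# back with small floors applied to stdev and probabilities.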
+
+
+def torch_graph_to_gaussian(node, torchvars):
+    mean_param, stdev_param = torchvars
+    with torch.no_grad():
+        node.mean = mean_param.cpu().numpy()
+        node.stdev = np.maximum(stdev_param.cpu().numpy(), 0.01)
+
+
+def torch_graph_to_categorical(node, torchvar):
+    with torch.no_grad():
+        p = torch.softmax(torchvar, dim=0).cpu().numpy()
+        # add 1e-3 to avoid zero probabilities
+        node.p = (p + 1e-3).tolist()
+
+
+_torch_graph_to_node = {
+    Sum: torch_graph_to_sum,
+    Gaussian: torch_graph_to_gaussian,
+    Categorical: torch_graph_to_categorical,
+}
+
+
+def spn_to_torch_graph(node,
+                       torch_input,
+                       eval_functions=_node_log_torch_graph,
+                       **args):
+
+    all_results = {}
+    nodes = get_topological_order(node)
+    for node_type, func in eval_functions.items():
+        if "_eval_func" not in node_type.__dict__:
+            node_type._eval_func = []
+        node_type._eval_func.append(func)
+        node_type._is_leaf = issubclass(node_type, Leaf)
+
+    leaf_func = eval_functions.get(Leaf, None)
+
+    tmp_children_list = []
+    len_tmp_children_list = 0
+
+    for n_idx, n in enumerate(nodes):
+        try:
+            func = n.__class__._eval_func[-1]
+            n_is_leaf = n.__class__._is_leaf
+        except Exception:
+            if isinstance(n, Leaf) and leaf_func is not None:
+                func = leaf_func
+                n_is_leaf = True
+            else:
+                raise AssertionError(
+                    "No eval function associated with type: %s" % (n.__class__.__name__)
+                )
+
+        if n_is_leaf:
+            result = func(n, torch_input, **args)
+        else:
+            len_children = len(n.children)
+            if len_tmp_children_list < len_children:
+                tmp_children_list.extend([None] * len_children)
+                len_tmp_children_list = len(tmp_children_list)
+            for i in range(len_children):
+                ci = n.children[i]
+                tmp_children_list[i] = all_results[ci]
+
+            result = func(n, tmp_children_list[0:len_children], **args)
+
+        all_results[n] = result
+
+    for node_type, func in eval_functions.items():
+        del node_type._eval_func[-1]
+        if len(node_type._eval_func) == 0:
+            delattr(node_type, "_eval_func")
+
+    torch_graph = all_results[node]
+    return torch_graph
+
+
+class TorchSPN(nn.Module):
+
+    def __init__(self, spn_root, eval_functions=_node_log_torch_graph, trainable_leaf=True):
+        super().__init__()
+        self.spn_copy = Copy(spn_root)
+        self.eval_functions = eval_functions
+        self.trainable_leaf = trainable_leaf
+
+        self.variable_dict = {}
+
+    def forward(self, x):
+        # note: the layer modules (and their nn.Parameters) are rebuilt on
+        # every forward pass; the latest parameters per node live in
+        # self.variable_dict rather than in registered sub-modules
+        log_prob = spn_to_torch_graph(
+            self.spn_copy,
+            x,
+            eval_functions=self.eval_functions,
+            variable_dict=self.variable_dict,
+            trainable_leaf=self.trainable_leaf,
+        )
+        return log_prob
+
+
+def create_torch_spn(spn):
+
+    model = TorchSPN(spn)
+    return model, model.variable_dict, model.spn_copy
+
+
+def create_torch_spn_parts(spn_x, label_ids, trainable_leaf):
+
+    all_spn_x_y = []
+    all_spn_x_y_dicts = []
+    all_prior = []
+    all_spn_x_y_models = []
+
+    spn_x_copy = Copy(spn_x)
+
+    sorted_list = list(sorted(
+        zip(spn_x_copy.children, spn_x_copy.weights, label_ids),
+        key=lambda x: x[2]
+    ))
+    print('SPN ROOT weights:', spn_x_copy.weights)
+
+    for spn_x_y, prior_y, label_id in sorted_list:
+        model = TorchSPN(
+            spn_x_y,
+            eval_functions=_node_log_torch_graph,
+            trainable_leaf=trainable_leaf,
+        )
+        all_spn_x_y.append(spn_x_y)
+        all_spn_x_y_dicts.append(model.variable_dict)
+        all_prior.append(prior_y)
+        all_spn_x_y_models.append(model)
+
+    return spn_x_copy, all_spn_x_y, all_spn_x_y_dicts, all_prior, all_spn_x_y_models
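+
+
+# Assumed usage (sketch only, not exercised in this file):
+#     spn_root, parts, dicts, priors, models = create_torch_spn_parts(spn_x, label_ids, True)
+#     log_p = models[k](batch)       # [B, 1] log-likelihood under class k's sub-SPN
+#     torch_graph_to_spn(dicts[k])   # write the trained parameters back into SPFlow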
+
+
+def torch_graph_to_spn(variable_dict, torch_graph_to_node=_torch_graph_to_node):
+
+    for n, torchvars in variable_dict.items():
+        fn = torch_graph_to_node.get(type(n), None)
+        if fn is not None:
+            fn(n, torchvars)
+
+
+def build_torch_spn(root, num_vars=None, device="cpu", trainable_leaf=True):
+    model = TorchSPN(
+        spn_root=root,
+        eval_functions=_node_log_torch_graph,
+        trainable_leaf=trainable_leaf,
+    )
+    model.to(device)
+    return model
\ No newline at end of file
diff --git a/torch_spn_layers.py b/torch_spn_layers.py
new file mode 100644
index 0000000..e153a2b
--- /dev/null
+++ b/torch_spn_layers.py
@@ -0,0 +1,200 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.distributions import Normal, Categorical
+import numpy as np
+
+class GaussianLayer(nn.Module):
+
+    def __init__(self, mean, stdev,
+                 dtype=torch.float32,
+                 log_space=True,
+                 trainable_nodes=True,
+                 name="Gauss"):
+        super().__init__()
+
+        mean_t = torch.as_tensor(mean, dtype=dtype)
+        stdev_t = torch.as_tensor(stdev, dtype=dtype)
+
+        self.mean = nn.Parameter(mean_t, requires_grad=trainable_nodes)
+        self.stdev = nn.Parameter(stdev_t, requires_grad=trainable_nodes)
+
+        tmp_t = torch.tensor(0.001, dtype=dtype)
+        # std floor: registered as a buffer so the floor itself is never
+        # trained (it was an nn.Parameter before, which would let it drift)
+        self.register_buffer("tmp", tmp_t)
+
+        self.trainable_nodes = trainable_nodes
+
+    def forward(self, inputs, nd_idxs):
+
+        # equivalent of tf.gather_nd(inputs, nd_idxs)
+        x = inputs[nd_idxs[:, 0], nd_idxs[:, 1]]  # [B]
+        x = x.view(inputs.size(0), 1)  # [B, 1]
+
+        if self.trainable_nodes:
+            # equivalent of tf.keras.layers.maximum([self.stdev, self.tmp])
+            stdev = torch.maximum(self.stdev, self.tmp)
+        else:
+            stdev = self.stdev
+
+        dist = Normal(loc=self.mean, scale=stdev)
+        # mean/stdev broadcast over the batch dimension
+        log_prob = dist.log_prob(x)
+
+        return log_prob  # [B, 1]
+
+
+class CategoricalLayer(nn.Module):
+
+    def __init__(self, prob, log_space=True, name="Categorical"):
+        super().__init__()
+        # `prob` actually holds logits (inverse-softmax values); the softmax
+        # in forward() recovers the normalized probabilities
+        prob_t = torch.as_tensor(prob, dtype=torch.float32)
+        self.probs = nn.Parameter(prob_t, requires_grad=False)
+
+    def forward(self, inputs, nd_idxs):
+        x = inputs[nd_idxs[:, 0], nd_idxs[:, 1]]  # [B]
+        x = x.view(inputs.size(0), 1)
+
+        # logits -> probs
+        softmax_probs = F.softmax(self.probs, dim=0)
+        dist = Categorical(probs=softmax_probs)
+
+        log_prob = dist.log_prob(x.squeeze(-1).long())  # [B]
+
+        return log_prob.view(-1, 1)  # [B, 1]
+
+
+def get_batch_idx(node, data_input):
+
+    device = data_input.device
+    batch_size = data_input.size(0)
+
+    idx = node.scope[0]
+    batch_idxs = torch.arange(batch_size, device=device, dtype=torch.long)
+    feat_idxs = torch.full((batch_size,), idx,
+                           device=device, dtype=torch.long)
+
+    nd_idxs = torch.stack([batch_idxs, feat_idxs], dim=1)  # [B, 2]
+    return nd_idxs
+
+
+class LogSumLayer(nn.Module):
+    def __init__(self, softmax_inverse,
+                 dtype=torch.float32,
+                 name="log_sum"):
+        super().__init__()
+        logits = torch.as_tensor(softmax_inverse, dtype=dtype)
+        self.logits = nn.Parameter(logits, requires_grad=True)
+
+    def forward(self, inputs):
+        # [num_children]
+        weights = F.softmax(self.logits, dim=0)
+        log_w = torch.log(weights + 1e-8)
+
+        # children_prob: [B, num_children, 1]
+        children_prob = torch.stack(inputs, dim=1)
+
+        log_w = log_w.view(1, -1, 1)
+
+        out = torch.logsumexp(children_prob + log_w, dim=1)  # [B, 1]
+        return out
+
+
+class LogProdLayer(nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, inputs):
+
+        # equivalent of tf.add_n(inputs)
+        out = inputs[0]
+        for x in inputs[1:]:
+            out = out + x
+        return out
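+
+
+# In log space a Product node is a plain sum of child log-probabilities
+# (LogProdLayer), and a Sum node is a logsumexp over children plus log-weights
+# (LogSumLayer); keeping everything in log space avoids numerical underflow
+# for deep SPNs.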
+
+
+def gaussian_to_torch_graph(node,
+                            data_input=None,
+                            log_space=True,
+                            variable_dict=None,
+                            dtype=torch.float32,
+                            trainable_leaf=True):
+
+    nd_idxs = get_batch_idx(node, data_input)
+    layer = GaussianLayer(node.mean, node.stdev,
+                          dtype=dtype,
+                          trainable_nodes=trainable_leaf,
+                          name=node.__class__.__name__ + str(node.id))
+    if variable_dict is not None:
+        variable_dict[node] = (layer.mean, layer.stdev)
+    return layer(data_input, nd_idxs)
+
+
+def categorical_to_torch_graph(node,
+                               data_placeholder=None,
+                               log_space=True,
+                               variable_dict=None,
+                               dtype=np.float32,
+                               trainable_leaf=False):
+    nd_idxs = get_batch_idx(node, data_placeholder)
+    p = np.array(node.p, dtype=dtype)
+    softmax_inverse = np.log(p / np.max(p)).astype(dtype)
+
+    layer = CategoricalLayer(softmax_inverse,
+                             log_space,
+                             name=node.__class__.__name__ + str(node.id))
+    if variable_dict is not None:
+        variable_dict[node] = layer.probs
+    return layer(data_placeholder, nd_idxs)
+
+
+def log_prod_to_torch_graph(node,
+                            children,
+                            data_placeholder=None,
+                            variable_dict=None,
+                            log_space=True,
+                            dtype=np.float32,
+                            trainable_leaf=False):
+    assert log_space
+    layer = LogProdLayer()
+    return layer(children)
+
+
+def log_sum_to_torch_graph(node,
+                           children,
+                           data_placeholder=None,
+                           variable_dict=None,
+                           log_space=True,
+                           dtype=np.float32,
+                           trainable_leaf=False):
+    assert log_space
+    softmax_inverse = np.log(node.weights / np.max(node.weights)).astype(dtype)
+    layer = LogSumLayer(softmax_inverse,
+                        name=node.__class__.__name__ + str(node.id))
+    if variable_dict is not None:
+        variable_dict[node] = layer.logits
+    return layer(children)
+
+
+if __name__ == "__main__":
+    class DummyNode:
+        def __init__(self, idx, mean, stdev, weights=None, p=None, node_id=0):
+            self.scope = [idx]
+            self.mean = mean
+            self.stdev = stdev
+            self.weights = weights
+            self.p = p
+            self.id = node_id
+
+    B, D = 4, 3
+    x = torch.randn(B, D)
+
+    node_g = DummyNode(idx=1, mean=0.0, stdev=1.0, node_id=1)
+    var_dict = {}
+    logp_g = gaussian_to_torch_graph(node_g, x, variable_dict=var_dict)
+    print("Gaussian logp:", logp_g.shape, logp_g)
+
+    node_s = DummyNode(idx=0, mean=0.0, stdev=1.0, weights=np.array([0.3, 0.7]), node_id=2)
+    child1 = torch.randn(B, 1)
+    child2 = torch.randn(B, 1)
+    logp_sum = log_sum_to_torch_graph(node_s, [child1, child2])
+    print("Log-sum:", logp_sum.shape, logp_sum)
\ No newline at end of file

From d3fa242e6983b6c242a327899155425aedee5c08 Mon Sep 17 00:00:00 2001
From: Bowei Kou
Date: Mon, 1 Dec 2025 18:35:04 -0500
Subject: [PATCH 7/7] update torch_CNN_SPN torch_spn torch_spn_layers

---
 test_spn.py  | 1 -
 torch_spn.py | 1 -
 2 files changed, 2 deletions(-)

diff --git a/test_spn.py b/test_spn.py
index 10bb09d..2f57d31 100644
--- a/test_spn.py
+++ b/test_spn.py
@@ -20,7 +20,6 @@ def build_toy_spn():
     p0 = Product(children=[g00, g10])
     p1 = Product(children=[g01, g11])
 
-    # sum node (root)
    root = Sum(weights=[0.4, 0.6], children=[p0, p1])
 
     return root
diff --git a/torch_spn.py b/torch_spn.py
index 6cbcf43..d9a5180 100644
--- a/torch_spn.py
+++ b/torch_spn.py
@@ -35,7 +35,6 @@ def torch_graph_to_gaussian(node, torchvars):
 def torch_graph_to_categorical(node, torchvar):
     with torch.no_grad():
         p = torch.softmax(torchvar, dim=0).cpu().numpy()
-        # add 1e-3 to avoid zero probabilities
         node.p = (p + 1e-3).tolist()