From 1ce7fa3d6e9c33380118acb2f566e1d24b6f5bc3 Mon Sep 17 00:00:00 2001
From: sadhamanus
Date: Fri, 24 Jan 2025 08:55:05 -0500
Subject: [PATCH 1/3] Updated CoFrNet files

Better default param settings, debugged, best model returned
---
 aix360/algorithms/cofrnet/CoFrNet.py          | 310 +++++++++++++-----
 .../cofrnet/Customized_Linear_Classes.py      | 143 ++++++++
 aix360/algorithms/cofrnet/utils.py            | 195 ++---------
 3 files changed, 411 insertions(+), 237 deletions(-)
 create mode 100644 aix360/algorithms/cofrnet/Customized_Linear_Classes.py

diff --git a/aix360/algorithms/cofrnet/CoFrNet.py b/aix360/algorithms/cofrnet/CoFrNet.py
index 52eec44..79fac35 100644
--- a/aix360/algorithms/cofrnet/CoFrNet.py
+++ b/aix360/algorithms/cofrnet/CoFrNet.py
@@ -1,10 +1,10 @@
 '''
-author: @ishapuri101
+author: @ishapuri101, @sadhamanus
 '''
-import sys
 import pandas as pd #loading data in table form
 import numpy as np # linear algebra
 import torch # import main library
 import torch.nn as nn # import modules
 from torch.autograd import Function # import Function to create custom activations
@@ -12,18 +12,41 @@ import torch.nn.functional as F # import torch functions
 from sklearn.preprocessing import MinMaxScaler
 from sklearn.model_selection import train_test_split
+import os
+from aix360.algorithms.cofrnet.Customized_Linear_Classes import CustomizedLinearFunction
+from aix360.algorithms.cofrnet.Customized_Linear_Classes import CustomizedLinear
+from aix360.algorithms.cofrnet.utils import generate_connections
+from aix360.algorithms.cofrnet.utils import process_data
 
-from aix360.algorithms.cofrnet.utils import generate_connections
-from aix360.algorithms.cofrnet.utils import process_data
-from aix360.algorithms.cofrnet.CustomizedLinearClasses import CustomizedLinearFunction
-from aix360.algorithms.cofrnet.CustomizedLinearClasses import CustomizedLinear
+from torch.utils.data import DataLoader
 
-from aix360.algorithms.die import DIExplainer
+from tqdm import tqdm
+from torch.utils.data import Dataset
 
+class OnlyTabularDataset(Dataset):
+    def __init__(self, values, label):
+        self.values = values
+        self.label = label
+
+    def __len__(self):
+        return len(self.label)
+
+    def __getitem__(self, index):
+        return {
+            'tabular': torch.tensor(self.values[index], dtype=torch.float),
+            'target' : torch.tensor(self.label[index], dtype=torch.long)
+            }
+
 class CoFrNet_Model(nn.Module):
     """
     CoFrNet_Model is the base class for Continued Fractions Nets (CoFrNets).
@@ -53,58 +76,201 @@ def __init__(self, connections):
         else:
             self.num_total_parameters = self.num_total_parameters + np.count_nonzero(np.asarray(self.connections[i]))
         self.layers.append(CustomizedLinear(torch.tensor(self.connections[i])))
-
+
+
+        self.BatchNorm = nn.BatchNorm1d(self.input_features) # 1d, since inputs are (batch_size, num_features)
 
-    def modified_reciprocal_activation(self, Wx, epsilon = .1):
+    def modified_reciprocal_activation(self, Wx, epsilon = .01):
         '''
         Activation function that uses capped 1/x described in paper. 
        Takes in Wx, returns modified activation function of Wx
         '''
         epsilonMatrix = torch.mul(torch.full_like(Wx, epsilon), torch.sign(Wx))
         denom = torch.where(torch.abs(Wx) < epsilon, epsilonMatrix, Wx)
+        denom = torch.nan_to_num(denom, nan=epsilon)
+        #print(epsilonMatrix, Wx)
+        if torch.any(torch.isnan(torch.reciprocal(denom))):
+            print(f'nans present when doing 1/x')
         return torch.reciprocal(denom)
+
+
 
     def forward(self, x):
         '''
         Customized forward function. 
         '''
+
+        #l_output --> layer output, a_output --> activation output
         for i in range(len(self.layers)):
             if (i == 0):
-                l_input = x
+                #print(f'input size: {x.size()}')
+                l_input = x #self.BatchNorm(x) #x
                 l_output = self.layers[i](l_input)
                 a_output = self.modified_reciprocal_activation(l_output)
+                #print("self.layers[i].output_features", self.layers[i].output_features)
+                batchNorm = nn.BatchNorm1d(self.layers[i].output_features) # note: created anew on each forward pass, so its parameters are never registered or trained
+                a_output = batchNorm(a_output)
             elif ((i > 0) and (i != len(self.layers) - 1)):
-                l_input = x
+                l_input = x #self.BatchNorm(x) #x
                 l_output = self.layers[i](l_input) + prev_output
                 #l_output = self.dropout(l_output)
                 a_output = self.modified_reciprocal_activation(l_output)
+                #batchNorm = nn.BatchNorm1d(self.layers[i].output_features)
+                #a_output = batchNorm(a_output)
             else:
                 l_input = prev_output
                 #l_input = self.dropout(l_input)
                 l_output = self.layers[i](l_input)
                 a_output = l_output
             prev_output = a_output
+        #print(f'output size: {a_output.size()}')
         return a_output
 
+ckpt_path = 'ckpt/cofrnet.pt'
+
+
+class CoFrNet_Explainer():
+    def __init__(self, num_layers, data_input_size, data_output_size, which_variant, tensor_x_train, tensor_y_train, tensor_x_val, tensor_y_val, tensor_x_test, y_test, num_nodes):
+        self.num_layers = num_layers
+        self.data_input_size = data_input_size
+        self.data_output_size = data_output_size
+        self.which_variant = which_variant
+        self.model = CoFrNet_Model(generate_connections(self.num_layers,
+                                                        self.data_input_size,
+                                                        self.data_output_size,
+                                                        self.which_variant, num_nodes))
+        self.tensor_x_train = tensor_x_train
+        self.tensor_y_train = tensor_y_train
+        self.tensor_x_val = tensor_x_val
+        self.tensor_y_val = tensor_y_val
+        self.tensor_x_test = tensor_x_test
+        self.y_test = y_test
+
+        self.train_dataset = OnlyTabularDataset(self.tensor_x_train,
+                                                self.tensor_y_train)
 
+        def collate_fn(batch):
+            batch = torch.cat([sample[0].unsqueeze(0) for sample in batch], dim=0)
+            return batch
 
-class CoFrNet_Explainer(DIExplainer):
-    def __init__(self, cofrnet_model):
-        self.model = cofrnet_model
+        self.dataloader = DataLoader(self.train_dataset, data_input_size)
+
+        self.x_train_dl = DataLoader(tensor_x_train, data_input_size)
+        self.y_train_dl = DataLoader(tensor_y_train, data_input_size)
+        self.x_val_dl = DataLoader(tensor_x_val, data_input_size)
+        self.y_val_dl = DataLoader(tensor_y_val, data_input_size)
+        self.x_test_dl = DataLoader(tensor_x_test, data_input_size) # fixed: was tensor_x_val
+        self.y_test_dl = DataLoader(y_test, data_input_size)
+
+
+    def evaluate(self, model, dataloader, loss_fn):
+        model.eval()
+
+        val_accuracy = []
+        val_loss = []
+
+        for batch in dataloader:
+            # batches are the dicts produced by OnlyTabularDataset
+            target = batch['target']
+            with torch.no_grad():
+                logits = model(batch['tabular'])
+                loss = loss_fn(logits, target)
+            val_loss.append(loss.item())
+            preds = torch.argmax(logits, dim=1).flatten()
+            accuracy = (preds == target).cpu().numpy().mean() * 100
+            val_accuracy.append(accuracy)
 
-    def set_params(self, *argv, **kwargs):
-        """
-        Set parameters for the explainer. 
- """ - pass + val_loss = np.mean(val_loss) + val_accuracy = np.mean(val_accuracy) + + return val_loss, val_accuracy + + + + def fit(self, weight_decay = 0, patience = float('Inf'), min_delta = .0001, learning_rate = 1e-2, num_epoch = 100): + + self.model.train() + criterion = nn.CrossEntropyLoss() + EPOCHS = num_epoch + optm = torch.optim.Adam(self.model.parameters(), lr=learning_rate, weight_decay=weight_decay) + + last_loss = float('Inf') + min_loss = float('Inf') + + trigger_times = 0 - def print_accuracy(self, xtest, ytest): - results = self.model(xtest).detach().numpy() + + for epoch in range(EPOCHS): + epoch_loss = 0 + correct = 0 + for bidx, batch in tqdm(enumerate(self.dataloader)): + x_train = self.x_train_dl + y_train = self.y_train_dl + #print(f'batch size: {len(batch)}') + loss, predictions = self.train(self.model, + batch['tabular'], + batch['target'], + optm, + criterion) + if loss > last_loss: + trigger_times += 1 + + + if trigger_times >= patience: + print('Early stopping!\nStart to test process.') + if not os.path.exists(ckpt_path): + raise RuntimeError(f'\'{ckpt_path}\' does not exist') + self.model.load_state_dict(torch.load(ckpt_path)) + return self.model + else: + trigger_times = 0 + + last_loss = loss + + #self.model.eval() + for idx, i in enumerate(predictions): + predictions_max = torch.max(i) + index_of_max = list(i).index(max(list(i))) + + if index_of_max == self.tensor_y_train[idx]: + correct += 1 + + acc = (correct/len(self.y_train_dl)) + epoch_loss += loss + + if epoch_loss < min_loss: + min_loss = epoch_loss + torch.save(self.model.state_dict(), ckpt_path) + + print('Epoch {} Accuracy : {}'.format(epoch+1, acc*100)) + print('Current best loss: {}'.format(min_loss)) + print('Epoch {} Loss : {}'.format((epoch+1),epoch_loss)) + + + def train(self, model, x, y, optimizer, criterion): + optimizer.zero_grad() + output = model(x) + loss = criterion(output,y) + loss.backward() + optimizer.step() + + + return loss, output + + + def predict(self, input_x_tensor): + if not os.path.exists(ckpt_path): + raise RuntimeError(f'\'{ckpt_path}\' does not exist') + self.model.load_state_dict(torch.load(ckpt_path)) + return(self.model(input_x_tensor)) + + def print_accuracy(self): + results = self.predict(self.tensor_x_test).detach().numpy() idx = np.argmax(results, axis = -1) results = np.zeros(results.shape) results[ np.arange(results.shape[0]), idx] = 1 @@ -113,67 +279,55 @@ def print_accuracy(self, xtest, ytest): numTotal = 0 numCorrect = 0 for i in range(0, len(results)): - if results[i] == ytest[i]: + if results[i] == self.y_test[i]: numCorrect = numCorrect + 1 numTotal = numTotal + 1 print("Accuracy: ", numCorrect/numTotal) accuracy = float(numCorrect/numTotal) - - def explain(self, explain_mode, max_layer_num = 10, var_num = 6): - ''' - Provides Explanations of CoFrNet Model - - Args: - explain_mode: either "importances" or "print_co_fr", will raise exception if not one of these two options - max_layer_num: For "print_co_fr": Choose Depth of Ladder to Show, Default 10 - var_num: For "print_co_fr": Variable (index of input feature) for Which to Display Ladder, Default 6 - ''' - - def importances(self): - final_layer_weights = vars(self.model.layers[-1])['_parameters']['weight'].data.numpy() - weights_by_node = final_layer_weights.T - averaged = np.average(weights_by_node, axis = 1) - copy_averaged = averaged.copy() - print(copy_averaged) - num_important_to_print = 3 - for x in range(0, num_important_to_print): - min_idx = np.argmax(copy_averaged) - print("The number " + 
str(x+1) + " most important input feature was the " + str(min_idx+1) + "th one.")
-                copy_averaged[np.argmax(copy_averaged)] = copy_averaged[np.argmin(copy_averaged)]
-                #print(vars(self.model.layers[-1])['_parameters']['weight'].data.numpy().T)
 
+    def importances(self):
+        if not os.path.exists(ckpt_path):
+            raise RuntimeError(f'\'{ckpt_path}\' does not exist')
+        self.model.load_state_dict(torch.load(ckpt_path))
+        final_layer_weights = vars(self.model.layers[-1])['_parameters']['weight'].data.numpy()
+        weights_by_node = final_layer_weights.T
+        averaged = np.average(weights_by_node, axis = 1)
+        copy_averaged = averaged.copy()
+        print(copy_averaged)
+        num_important_to_print = 3
+        for x in range(0, num_important_to_print):
+            max_idx = np.argmax(copy_averaged) # renamed from min_idx: this is an argmax
+            print("The number " + str(x+1) + " most important input feature was the " + str(max_idx+1) + "th one.")
+            copy_averaged[np.argmax(copy_averaged)] = copy_averaged[np.argmin(copy_averaged)]
 
-        def print_co_fr(self, max_layer_num = 10, var_num = 6): 
-            #max_layer_num = chosen depth of ladder to show (10 layers, index would be 9)
-            #var_num = variable for which to display ladder
-            thingToPrint = ""
-            for layerNum in range(0, max_layer_num-1):
-                temp = vars(self.model.layers[layerNum])
-                print()
-                print("LayerNum: ", layerNum)
-                val = (temp['_parameters']['weight'].data[var_num][var_num]).numpy()
-                print("Val: ", val)
-                bias = temp['_parameters']['bias'].data[var_num].numpy()
-                print("Bias: ", bias)
-                if (bias > (.01 * val)):
-                    print(str(bias))
-                    combined = "("+str(val) + "*x + " + str(bias)+")"
-                    print("Combined: ", combined)
-                else:
-                    print("hi")
-                    combined = "(" + str(val)+"*x" + "+0)"
-                    print("Combined: ", combined)
-                print()
-                thingToPrint = "1/(" + combined + " + (" + thingToPrint + "))"
-
-            print(thingToPrint)
-            return thingToPrint
-
-        if explain_mode == "importances":
-            importances()
-        elif explain_mode == "print_co_fr":
-            print_co_fr(max_layer_num, var_num)
-        else:
-            raise Exception("explain_mode must be either 'importances' or 'print_co_fr'")
-
+    def explain(self, max_layer_num = 10, var_num = 6):
+        #max_layer_num = chosen depth of ladder to show (10 layers, index would be 9)
+        #var_num = variable for which to display ladder
+        thingToPrint = ""
+        if not os.path.exists(ckpt_path):
+            raise RuntimeError(f'\'{ckpt_path}\' does not exist')
+        self.model.load_state_dict(torch.load(ckpt_path))
+        for layerNum in range(0, max_layer_num-1):
+            temp = vars(self.model.layers[layerNum])
+            print()
+            print("LayerNum: ", layerNum)
+            val = (temp['_parameters']['weight'].data[var_num][var_num]).numpy()
+            print("Val: ", val)
+            bias = temp['_parameters']['bias'].data[var_num].numpy()
+            print("Bias: ", bias)
+            if (bias > (.01 * val)):
+                combined = "("+str(val) + "*x + " + str(bias)+")"
+                print("Combined: ", combined)
+                #thingToPrint = "\n 1/("+str(val) + "x + " + str(bias)+")" + thingToPrint
+            else:
+                combined = "(" + str(val)+"*x" + "+0)"
+                print("Combined: ", combined)
+                #thingToPrint = "\n 1/(" + str(val)+"x" + "+0)" + thingToPrint
+            print()
+            thingToPrint = "1/(" + combined + " + (" + thingToPrint + "))"
+
+        print(thingToPrint)
+        return thingToPrint
diff --git a/aix360/algorithms/cofrnet/Customized_Linear_Classes.py b/aix360/algorithms/cofrnet/Customized_Linear_Classes.py
new file mode 100644
index 0000000..d2ce17b
--- /dev/null
+++ b/aix360/algorithms/cofrnet/Customized_Linear_Classes.py
@@ -0,0 +1,143 @@
+'''
+author: @ishapuri101, @sadhamanus
+'''
+
+# coding: utf-8
+import pandas as pd #loading data in table form 
+import numpy as np # linear algebra
+import torch # import main library
+import torch.nn as nn # import modules
+from torch.autograd import Function # import Function to create custom activations
+from torch.nn.parameter import Parameter # import Parameter to create custom activations with learnable parameters
+import torch.nn.functional as F # import torch functions
+import math
+
+
+#check whether cuda is available
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+'''
+CUSTOM LINEAR CLASS
+'''
+"""
+extended torch.nn module which customizes connections.
+This code is based on https://pytorch.org/docs/stable/notes/extending.html
+"""
+
+#################################
+# Define custom autograd function for masked connection.
+
+class CustomizedLinearFunction(torch.autograd.Function):
+    """
+    autograd function which masks its weights by 'mask'.
+    """
+
+    # Note that both forward and backward are @staticmethods
+    @staticmethod
+    # bias, mask is an optional argument
+    def forward(ctx, input, weight, bias=None, mask=None):
+        if mask is not None:
+            # change weight to 0 where mask == 0
+            weight = weight * mask
+        output = input.mm(weight.t())
+        if bias is not None:
+            output += bias.unsqueeze(0).expand_as(output)
+        ctx.save_for_backward(input, weight, bias, mask)
+        return output
+
+    # This function has only a single output, so it gets only one gradient
+    @staticmethod
+    def backward(ctx, grad_output):
+        # This is a pattern that is very convenient - at the top of backward
+        # unpack saved_tensors and initialize all gradients w.r.t. inputs to
+        # None. Thanks to the fact that additional trailing Nones are
+        # ignored, the return statement is simple even when the function has
+        # optional inputs.
+        input, weight, bias, mask = ctx.saved_tensors
+        grad_input = grad_weight = grad_bias = grad_mask = None
+
+        # These needs_input_grad checks are optional and there only to
+        # improve efficiency. If you want to make your code simpler, you can
+        # skip them. Returning gradients for inputs that don't require it is
+        # not an error.
+        if ctx.needs_input_grad[0]:
+            grad_input = grad_output.mm(weight)
+        if ctx.needs_input_grad[1]:
+            grad_weight = grad_output.t().mm(input)
+            if mask is not None:
+                # change grad_weight to 0 where mask == 0
+                grad_weight = grad_weight * mask
+        if ctx.needs_input_grad[2]:
+            grad_bias = grad_output.sum(0).squeeze(0)
+
+        return grad_input, grad_weight, grad_bias, grad_mask
+
+
+class CustomizedLinear(nn.Module):
+    def __init__(self, mask, bias=True):
+        """
+        extended torch.nn module which masks connections.
+        Arguments
+        ------------------
+        mask [torch.tensor]:
+            the shape is (n_input_feature, n_output_feature).
+            the elements are 0 or 1 which declare un-connected or
+            connected.
+        bias [bool]:
+            flag of bias. 
+ """ + super(CustomizedLinear, self).__init__() + self.input_features = mask.shape[0] + self.output_features = mask.shape[1] + #print("self.output_features hi hi hi", self.output_features) + if isinstance(mask, torch.Tensor): + self.mask = mask.type(torch.float).t() + else: + self.mask = torch.tensor(mask, dtype=torch.float).t() + + self.mask = nn.Parameter(self.mask, requires_grad=False) + + # nn.Parameter is a special kind of Tensor, that will get + # automatically registered as Module's parameter once it's assigned + # as an attribute. Parameters and buffers need to be registered, or + # they won't appear in .parameters() (doesn't apply to buffers), and + # won't be converted when e.g. .cuda() is called. You can use + # .register_buffer() to register buffers. + # nn.Parameters require gradients by default. + self.weight = nn.Parameter(torch.Tensor(self.output_features, self.input_features)) + + if bias: + self.bias = nn.Parameter(torch.Tensor(self.output_features)) + else: + # You should always register all possible parameters, but the + # optional ones can be None if you want. + self.register_parameter('bias', None) + self.reset_parameters() + + # mask weight + # commented out may 5 self.weight.data = self.weight.data * self.mask + + def reset_parameters(self): + stdv = 1. / math.sqrt(self.weight.size(1)) + self.weight.data.uniform_(-stdv, stdv) + if self.bias is not None: + self.bias.data.uniform_(-stdv, stdv) + + + def forward(self, input): + # See the autograd section for explanation of what happens here. + return CustomizedLinearFunction.apply(input, self.weight, self.bias, self.mask) + + def extra_repr(self): + # (Optional)Set the extra information about this module. You can test + # it by printing an object of this class. + return 'input_features={}, output_features={}, bias={}'.format( + self.input_features, self.output_features, self.bias is not None + ) diff --git a/aix360/algorithms/cofrnet/utils.py b/aix360/algorithms/cofrnet/utils.py index 48ac255..039638a 100644 --- a/aix360/algorithms/cofrnet/utils.py +++ b/aix360/algorithms/cofrnet/utils.py @@ -1,39 +1,32 @@ ''' -author: @ishapuri101 +author: @ishapuri101, @sadhamanus ''' import numpy as np # linear algebra from sklearn.preprocessing import MinMaxScaler -from tqdm import tqdm -import pandas as pd #loading data in table form -from torch.utils.data import Dataset import torch # import main library -import torch.nn as nn # import modules -from torch.autograd import Function # import Function to create custom activations -from torch.nn.parameter import Parameter # import Parameter to create custom activations with learnable parameters -import torch.nn.functional as F # import torch functions -from sklearn.preprocessing import MinMaxScaler -from sklearn.model_selection import train_test_split -import torch.optim as optim -#import random -from torch.utils.data import DataLoader +import pandas as pd #loading data in table form + +def generate_connections(num_total_layers: int, input_size: int, output_size: int, which_variant: str, num_nodes = float('Inf')): + ''' + Args: + num_total_layers: depth of the ladders + input_size: number of input features + output_size: number of output features + which_variant: choose one of CoFrNet Variants to work with: fully_connected, diagonalized, ladder_of_ladders, or diag_ladder_of_ladder_combined + Returns: + a 3D Matrix of 1s and 0s, where a 1 signifies that a connection exists between two nodes, and a 0 signifies that it doesn't. 
+    '''
 
-def generate_connections(num_total_layers: int, 
-                        input_size: int, 
-                        output_size: int, 
-                        which_variant: str, 
-                        num_nodes = 0, 
-                        feature_index = 0, 
-                        features_to_use = []): 
     def fully_connected_constant():
-        num_total_layers = num_total_layers - 1 #not including output layer
+        numLayers_notInclOutput = num_total_layers - 1 #not including output layer
         genConns = []
-        for i in range(0, num_total_layers):
+        for i in range(0, numLayers_notInclOutput):
             genConns.append(np.ones([input_size, num_nodes]).tolist())
 
         genConns.append(np.ones([num_nodes, output_size]).tolist())
@@ -64,60 +57,26 @@ def ladder_of_ladders():
         return getConns
 
-    def diagonalized_ladder_of_ladders_combined():
+    def diagonalized_ladder_of_ladders_combined(): # fixed typo: was "..._combinedd"
         #numLayers DOES include output layer
         #numLayers = numLadders in this case
         getConns = []
-        numLayers_notIncOutput = num_total_layers - 1
-        #numLayers_notIncOutput = input_size
-
+        numLayers_notIncOutput = min(num_nodes, num_total_layers - 1)
+        #print(f'Max number of full ladders: {numLayers_notIncOutput}')
+
         for i in range(0, numLayers_notIncOutput):
             ladderOfLadders = np.ones([input_size, numLayers_notIncOutput])
             ladderOfLadders[:, 0:i] = 0
             toAppend = np.append(np.eye(input_size), ladderOfLadders, axis = 1)
             getConns.append(toAppend.tolist())
+        #print(len(getConns[-1][0]))
         getConns.append(np.ones([len(getConns[-1][0]), output_size]).tolist())
+        #getConns.append(np.ones([numLayers_notIncOutput, output_size]).tolist())
 
         return getConns
 
-    def one_feature_diagonalized():
-        #numLayers DOES include output layer
-
-        numLayers_notInclOutput = num_total_layers - 1
-        genConns = []
-
-        for i in range(0, numLayers_notInclOutput):
-            toAppend = np.zeros([input_size, 1])
-            toAppend[feature_index][0] = 1
-            genConns.append(toAppend.tolist())
-
-        genConns.append(np.ones([1, output_size]).tolist())
-
-        return genConns
-
-    def n_feature_fully_connected():
-        #numLayers DOES include output layer
-
-        numLayers_notInclOutput = num_total_layers - 1
-        features_not_to_use = []
-        for i in range(0, input_size):
-            if i not in features_to_use:
-                features_not_to_use.append(i)
-        genConns = []
-
-        for i in range(0, numLayers_notInclOutput):
-            toAppend = np.ones([input_size, input_size])
-            for num in features_not_to_use:
-                for j in range(0, input_size):
-                    toAppend[j][num] = 0
-                    toAppend[num][j] = 0
-            genConns.append(toAppend.tolist())
-
-        genConns.append(np.ones([input_size, output_size]).tolist())
-
-        return genConns
-
     if which_variant == "fully_connected":
         return fully_connected_constant()
@@ -126,13 +85,9 @@ def n_feature_fully_connected():
     elif which_variant == "ladder_of_ladders":
         return ladder_of_ladders()
     elif which_variant == "diag_ladder_of_ladder_combined":
-        return diagonalized_ladder_of_ladders_combined()
-    elif which_variant == "one_feature_diag":
-        return one_feature_diagonalized()
-    elif which_variant == 'n_feature_fully_connected':
-        return n_feature_fully_connected()
+        return diagonalized_ladder_of_ladders_combined()
     else:
-        raise Exception("You must choose one of the following four choices for which_variant: fully_connected, diagonalized, ladder_of_ladders, diag_ladder_of_ladder_combined, or one_feature_diag")
+        raise Exception("You must choose one of the following four choices for which_variant: fully_connected, diagonalized, ladder_of_ladders, or diag_ladder_of_ladder_combined")
 
 
 
 
-def process_data(first_column_csv, last_column_csv, web_link = None, data_filename = None):
-    #data_filename: filename of data source
-    #first_column_csv: 
index (starting from 0) of first column to include in dataset - #last_column_csv: index (starting from 0) of last column to include in dataset. Use -1 if you want to include all of the columns. +def process_data(data_filename, first_column_csv, last_column_csv): + ''' + Args: + data_filename: filename of data source + first_column_csv: index (starting from 0) of first column to include in dataset + last_column_csv: index (starting from 0) of last column to include in dataset. Use -1 if you want to include all of the columns. + Returns: + tensor_x_train, tensor_y_train, tensor_x_val, tensor_y_val, tensor_x_test, y_test + ''' + import pandas as pd - - if web_link is not None: - pathname = web_link - else: - pathname = data_filename #'datasets/' + data_filename - df=pd.read_csv(pathname, sep=',',header=0, lineterminator='\r') + df=pd.read_csv('datasets/' + data_filename, sep=',',header=0) if last_column_csv != -1: last_column_csv = last_column_csv + 1 - df = df.sample(frac = 1) X = df.iloc[:, first_column_csv : last_column_csv].values - y = df.iloc[:,-1].values.T - - - from sklearn.preprocessing import LabelEncoder le = LabelEncoder() y = le.fit_transform(y) @@ -176,16 +127,8 @@ def process_data(first_column_csv, last_column_csv, web_link = None, data_filena seed = seeds[2] from sklearn.model_selection import train_test_split - X_train, X_test, y_train, y_test = train_test_split(X, - y, - test_size = 0.3, - random_state = seed, - shuffle = True) - X_train, X_val, y_train, y_val = train_test_split(X_train, - y_train, - test_size=0.05, - random_state=seed, - shuffle = True) + X_train, X_test, y_train, y_test = train_test_split(X, y, test_size = 0.3, random_state = seed) + X_train, X_val, y_train, y_val = train_test_split(X_train, y_train, test_size=0.05, random_state=seed) #CONVERTING TO TENSOR tensor_x_train = torch.Tensor(X_train) @@ -199,69 +142,3 @@ def process_data(first_column_csv, last_column_csv, web_link = None, data_filena return tensor_x_train, tensor_y_train, tensor_x_val, tensor_y_val, tensor_x_test, y_test -def onehot_encoding(label, n_classes): - """Conduct one-hot encoding on a label vector.""" - label = label.view(-1) - onehot = torch.zeros(label.size(0), n_classes).float().to(label.device) - onehot.scatter_(1, label.view(-1, 1), 1) - - return onehot - - - - - - - - -def train(model, dataloader, num_classes, lr = 0.001, momentum = 0.9, epochs = 20): - #criterion = nn.CrossEntropyLoss() - criterion = nn.MSELoss(reduction="sum") - #optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum) - optimizer = optim.Adam(model.parameters(), lr=lr) - - EPOCHS = epochs - for epoch in range(EPOCHS): # loop over the dataset multiple times - print("Epoch: ", epoch) - running_loss = 0.0 - #for i, data in enumerate(trainloader, 0): - for i, batch in tqdm(enumerate(dataloader)): - # get the inputs; data is a list of [inputs, labels] - # forward + backward + optimize - - batch['tabular'].requires_grad=True - - - outputs = model(batch['tabular']) - - one_hot_encoded_target = onehot_encoding(batch['target'], num_classes) - - #loss = criterion(outputs, batch['target']) - loss = criterion(outputs, one_hot_encoded_target) - - # zero the parameter gradients - optimizer.zero_grad() - - loss.backward() - optimizer.step() - - # print statistics - running_loss += loss.item() - #print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 2000:.3f}') - print("Loss: ", running_loss) - - print('Finished Training') - -class OnlyTabularDataset(Dataset): - def __init__(self, values, label): - self.values = 
values
-        self.label = label
-
-    def __len__(self):
-        return len(self.label)
-
-    def __getitem__(self, index):
-        return {
-            'tabular': torch.tensor(self.values[index], dtype=torch.float),
-            'target' : torch.tensor(self.label[index], dtype=torch.long)
-            }
\ No newline at end of file

From b83073f33632d7a5fe948eb86b916f0165dcdc06 Mon Sep 17 00:00:00 2001
From: sadhamanus
Date: Fri, 24 Jan 2025 09:08:12 -0500
Subject: [PATCH 2/3] Update cofrnet_example.ipynb

Updated custom linear call
---
 examples/cofrnet/cofrnet_example.ipynb | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/examples/cofrnet/cofrnet_example.ipynb b/examples/cofrnet/cofrnet_example.ipynb
index 182ea39..9be6b7b 100644
--- a/examples/cofrnet/cofrnet_example.ipynb
+++ b/examples/cofrnet/cofrnet_example.ipynb
@@ -21,8 +21,8 @@
    "outputs": [],
    "source": [
     "# Imports and Seeds\n",
-    "from aix360.algorithms.cofrnet.CustomizedLinearClasses import CustomizedLinearFunction\n",
-    "from aix360.algorithms.cofrnet.CustomizedLinearClasses import CustomizedLinear\n",
+    "from aix360.algorithms.cofrnet.Customized_Linear_Classes import CustomizedLinearFunction\n",
+    "from aix360.algorithms.cofrnet.Customized_Linear_Classes import CustomizedLinear\n",
     "from aix360.algorithms.cofrnet.utils import generate_connections\n",
     "from aix360.algorithms.cofrnet.utils import process_data\n",
     "from aix360.algorithms.cofrnet.utils import train\n",

From d4562fdfaed5c63ca4e700f9feba21a1a4baa8bf Mon Sep 17 00:00:00 2001
From: sadhamanus
Date: Tue, 4 Feb 2025 11:46:56 -0500
Subject: [PATCH 3/3] Update CoFrNet.py

---
 aix360/algorithms/cofrnet/CoFrNet.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/aix360/algorithms/cofrnet/CoFrNet.py b/aix360/algorithms/cofrnet/CoFrNet.py
index 79fac35..258922c 100644
--- a/aix360/algorithms/cofrnet/CoFrNet.py
+++ b/aix360/algorithms/cofrnet/CoFrNet.py
@@ -86,7 +86,7 @@ def modified_reciprocal_activation(self, Wx, epsilon = .01):
         '''
         epsilonMatrix = torch.mul(torch.full_like(Wx, epsilon), torch.sign(Wx))
         denom = torch.where(torch.abs(Wx) < epsilon, epsilonMatrix, Wx)
-        denom = torch.nan_to_num(denom, nan=epsilon)
+        denom = torch.where(denom == 0, torch.full_like(denom, epsilon), denom) # full_like keeps this compatible with older torch versions that reject scalar arguments
         #print(epsilonMatrix, Wx)
         if torch.any(torch.isnan(torch.reciprocal(denom))):
            print(f'nans present when doing 1/x')
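
Reviewer note, not part of the patches above: a minimal end-to-end usage sketch of the CoFrNet_Explainer API added in PATCH 1/3, for anyone testing this branch. The CSV name, column indices, depth, width cap, and class count below are placeholder assumptions; the function and constructor signatures follow the patched code.

# usage sketch (assumes the patched branch is installed and a labeled CSV
# sits at datasets/my_data.csv; all numeric settings are illustrative)
from aix360.algorithms.cofrnet.CoFrNet import CoFrNet_Explainer
from aix360.algorithms.cofrnet.utils import process_data

# process_data reads datasets/<data_filename> and returns train/val/test splits
tensor_x_train, tensor_y_train, tensor_x_val, tensor_y_val, tensor_x_test, y_test = \
    process_data(data_filename='my_data.csv', first_column_csv=0, last_column_csv=-1)

explainer = CoFrNet_Explainer(num_layers=6,                               # placeholder depth
                              data_input_size=tensor_x_train.shape[1],
                              data_output_size=2,                         # placeholder class count
                              which_variant='diag_ladder_of_ladder_combined',
                              tensor_x_train=tensor_x_train,
                              tensor_y_train=tensor_y_train,
                              tensor_x_val=tensor_x_val,
                              tensor_y_val=tensor_y_val,
                              tensor_x_test=tensor_x_test,
                              y_test=y_test,
                              num_nodes=10)                               # placeholder width cap

explainer.fit(num_epoch=50)   # trains and checkpoints the best model to ckpt/cofrnet.pt
explainer.print_accuracy()    # scores the checkpointed model on the test split
explainer.importances()       # top-3 input features by final-layer weight
ladder = explainer.explain(max_layer_num=5, var_num=0)  # continued-fraction ladder for feature 0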
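A second sketch, also hypothetical and not part of the patch, showing how the connection matrices from generate_connections drive the masked CustomizedLinear layers; the sizes are arbitrary illustration values, and the variant used ('fully_connected') is the one whose mask construction is fully visible in the diff.

import torch
from aix360.algorithms.cofrnet.utils import generate_connections
from aix360.algorithms.cofrnet.Customized_Linear_Classes import CustomizedLinear

# one mask per layer; for 'fully_connected' each hidden-layer mask is all ones
conns = generate_connections(num_total_layers=3, input_size=5, output_size=2,
                             which_variant='fully_connected', num_nodes=7)
layer0 = CustomizedLinear(torch.tensor(conns[0]))  # mask shape (5, 7): 5 inputs -> 7 nodes
x = torch.rand(8, 5)                               # batch of 8 samples
print(layer0(x).shape)                             # torch.Size([8, 7])
# zeroed mask entries sever connections: both the weight and its gradient are
# multiplied by the mask inside CustomizedLinearFunction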