diff --git a/.coverage b/.coverage
index 0b5c8725..eee57be0 100644
Binary files a/.coverage and b/.coverage differ
diff --git a/mplc/contributivity.py b/mplc/contributivity.py
index 848d1553..cd3f77a4 100644
--- a/mplc/contributivity.py
+++ b/mplc/contributivity.py
@@ -12,12 +12,13 @@ from timeit import default_timer as timer
 
 import numpy as np
+import tensorflow as tf
 from loguru import logger
 from scipy.stats import norm
 from sklearn.linear_model import LinearRegression
 
 from . import constants
-from .multi_partner_learning import basic_mpl
+from .multi_partner_learning import basic_mpl, fast_mpl
 
 
 class KrigingModel:
@@ -84,9 +85,16 @@ def __str__(self):
             + str(self.first_charac_fct_calls_count)
             + "\n"
         )
-        output += f"Contributivity scores: {np.round(self.contributivity_scores, 3)}\n"
-        output += f"Std of the contributivity scores: {np.round(self.scores_std, 3)}\n"
-        output += f"Normalized contributivity scores: {np.round(self.normalized_scores, 3)}\n"
+        if isinstance(self.contributivity_scores, dict):
+            for key, value in self.contributivity_scores.items():
+                output += f"Metric: {key}\n"
+                output += f"Contributivity scores: {np.round(value, 3)}\n"
+                output += f"Std of the contributivity scores: {np.round(self.scores_std[key], 3)}\n"
+                output += f"Normalized contributivity scores: {np.round(self.normalized_scores[key], 3)}\n"
+        else:
+            output += f"Contributivity scores: {np.round(self.contributivity_scores, 3)}\n"
+            output += f"Std of the contributivity scores: {np.round(self.scores_std, 3)}\n"
+            output += f"Normalized contributivity scores: {np.round(self.normalized_scores, 3)}\n"
 
         return output
 
@@ -1113,23 +1121,41 @@ def compute_relative_perf_matrix(self):
 
         return relative_perf_matrix
 
-    def s_model(self):  # TOD refacto
+    def statistical_distances_via_smodel(self):
+
         start = timer()
-        mpl = basic_mpl.FedAvgSmodel(self.scenario)
+        mpl = fast_mpl.FastFedAvgSmodel(self.scenario, **self.scenario.mpl_kwargs)
         mpl.fit()
-        theta_estimated = np.zeros((mpl.partners_count,
-                                    mpl.dataset.num_classes,
-                                    mpl.dataset.num_classes))
+        cross_entropy = tf.keras.metrics.CategoricalCrossentropy()
+        self.contributivity_scores = {'Kullback-Leibler divergence': [0 for _ in mpl.partners_list],
+                                      'Bhattacharyya distance': [0 for _ in mpl.partners_list],
+                                      'Hellinger metric': [0 for _ in mpl.partners_list]}
+        self.scores_std = {'Kullback-Leibler divergence': [0 for _ in mpl.partners_list],
+                           'Bhattacharyya distance': [0 for _ in mpl.partners_list],
+                           'Hellinger metric': [0 for _ in mpl.partners_list]}
+        # TODO: The variance of these estimates remains to be estimated.
+
         for i, partnerMpl in enumerate(mpl.partners_list):
-            theta_estimated[i] = (np.exp(partnerMpl.noise_layer_weights) / np.sum(
-                np.exp(partnerMpl.noise_layer_weights), axis=2))
-        self.contributivity_scores = np.exp(- np.array([np.linalg.norm(
-            theta_estimated[i] - np.identity(mpl.dataset.num_classes)
-        ) for i in range(len(self.scenario.partners_list))]))
-
-        self.name = "S-Model"
-        self.scores_std = np.zeros(mpl.partners_count)
-        self.normalized_scores = self.contributivity_scores / np.sum(self.contributivity_scores)
+            y_global = mpl.model.predict(partnerMpl.x_train)
+            y_local = mpl.smodel_list[i].predict(y_global)
+            cross_entropy.update_state(y_global, y_local)
+            cs = cross_entropy.result().numpy()  # H(p, q)
+            cross_entropy.reset_states()
+            cross_entropy.update_state(y_global, y_global)
+            e = cross_entropy.result().numpy()  # H(p, p) = H(p)
+            cross_entropy.reset_states()
+            BC = 0  # Bhattacharyya coefficient, averaged over the samples below
+            for y_g, y_l in zip(y_global, y_local):
+                BC += np.sum(np.sqrt(y_g * y_l))
+            BC /= len(y_global)
+            self.contributivity_scores['Kullback-Leibler divergence'][i] = cs - e  # KL(p||q) = H(p, q) - H(p)
+            self.contributivity_scores['Bhattacharyya distance'][i] = - np.log(BC)
+            self.contributivity_scores['Hellinger metric'][i] = np.sqrt(1 - BC)
+
+        self.name = "Statistical metrics via S-model"
+        self.normalized_scores = {}
+        for key, value in self.contributivity_scores.items():
+            self.normalized_scores[key] = value / np.sum(value)
 
         end = timer()
         self.computation_time_sec = end - start
@@ -1195,7 +1221,7 @@ def compute_contributivity(
             # Contributivity 10: Partner valuation by reinforcement learning
             self.PVRL(learning_rate=0.2)
         elif method_to_compute == "S-Model":
-            self.s_model()
+            self.statistical_distances_via_smodel()
         else:
             logger.warning("Unrecognized name of method, statement ignored!")
 
diff --git a/mplc/doc/documentation.md b/mplc/doc/documentation.md
index 535fdedb..ccad33be 100644
--- a/mplc/doc/documentation.md
+++ b/mplc/doc/documentation.md
@@ -314,6 +314,32 @@ There are several parameters influencing how the collaborative and distributed l
     - `'seqavg'`: stands for sequential averaging
       ![Schema seqavg](../../img/collaborative_rounds_seqavg.png)
+
+  The methods above are implemented so as to be agnostic to the model used. However, some of them are also implemented at a lower level, directly against the TensorFlow interface. These implementations are usually faster, especially on a GPU. They are only compatible with `tensorflow.keras`-based models, so the mplc-native dataset `Titanic` cannot be used with them.
+
+  Available methods:
+    - `'fast-fedavg'`: equivalent to FedAvg.
+    - `'fast-fedgrads'`: equivalent to FedGrad.
+    - `'fast-fedavg-smodel'`: equivalent to FedAvg, with an S-Model.
+    - `'fast-fedgrad-smodel'`: equivalent to FedGrad, with an S-Model.
+    - `'fast-fedgdo'`: stands for Federated averaging with double optimizers. This method is inspired by FedGrad, but modifies the local computation of the gradient.
+     A local, partner-specific optimizer performs several minimization steps of the local loss (one per local batch) during a global minibatch.
+     The sum of these weight updates is used as the gradient sent to the global optimizer,
+     which aggregates the gradients sent by all partners
+     and performs an optimization step with this aggregated gradient. See the usage sketch after this list.
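+
+  For example, a minimal scenario using one of these low-level implementations (a sketch; the parameter values are illustrative):
+
+  ```python
+  from mplc.scenario import Scenario
+
+  # 3 partners; the default dataset (MNIST) is tensorflow.keras-based
+  my_scenario = Scenario(partners_count=3,
+                         amounts_per_partner=[0.2, 0.3, 0.5],
+                         multi_partner_learning_approach='fast-fedavg')
+  my_scenario.run()
+  ```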
+
 
 Example: `multi_partner_learning_approach='seqavg'`
 
diff --git a/mplc/multi_partner_learning/__init__.py b/mplc/multi_partner_learning/__init__.py
index 74a02708..d46f15b4 100644
--- a/mplc/multi_partner_learning/__init__.py
+++ b/mplc/multi_partner_learning/__init__.py
@@ -14,7 +14,8 @@
     'fast-fedavg': fast_mpl.FastFedAvg,
     'fast-fedgrads': fast_mpl.FastFedGrad,
     'fast-fedavg-smodel': fast_mpl.FastFedAvgSmodel,
-    'fast-fedgrad-smodel': fast_mpl.FastGradSmodel
+    'fast-fedgrad-smodel': fast_mpl.FastGradSmodel,
+    'fast-fedgdo': fast_mpl.FastFedGDO
 }
 
 MULTI_PARTNER_LEARNING_APPROACHES = BASIC_MPL_APPROACHES.copy()
diff --git a/mplc/multi_partner_learning/basic_mpl.py b/mplc/multi_partner_learning/basic_mpl.py
index bed2e384..8c618e9d 100644
--- a/mplc/multi_partner_learning/basic_mpl.py
+++ b/mplc/multi_partner_learning/basic_mpl.py
@@ -544,7 +544,8 @@ def fit(self):
             for p in self.partners_list:
                 confusion = confusion_matrix(np.argmax(p.y_train, axis=1),
                                              np.argmax(pretrain_model.predict(p.x_train), axis=1),
-                                             normalize='pred')
+                                             normalize='pred',
+                                             labels=list(range(p.y_train.shape[1])))  # include classes missing from this partner
                 p.noise_layer_weights = [np.log(confusion.T + 1e-8)]
             self.model_weights[:-1] = self.pretrain_mpl.model_weights[:-1]
         else:
diff --git a/mplc/multi_partner_learning/fast_mpl.py b/mplc/multi_partner_learning/fast_mpl.py
index edeaf8cf..5471470b 100644
--- a/mplc/multi_partner_learning/fast_mpl.py
+++ b/mplc/multi_partner_learning/fast_mpl.py
@@ -376,7 +376,7 @@ def fit_minibatch(model, partners_minibatches, partners_optimizers, partners_wei
         for p in self.partners_list:
             confusion = confusion_matrix(np.argmax(p.y_train, axis=1),
                                          np.argmax(self.model.predict(p.x_train), axis=1),
-                                         normalize='pred')
+                                         normalize='pred', labels=list(range(p.y_train.shape[1])))
             p.noise_layer_weights = [np.log(confusion.T + 1e-8)]
     else:
         for p in self.partners_list:
@@ -549,7 +549,7 @@ def fit_epoch(model, train_dataset, partners_grads, smodel_list, global_grad, ag
         for p in self.partners_list:
             confusion = confusion_matrix(np.argmax(p.y_train, axis=1),
                                          np.argmax(self.model.predict(p.x_train), axis=1),
-                                         normalize='pred')
+                                         normalize='pred', labels=list(range(p.y_train.shape[1])))
             p.noise_layer_weights = [np.log(confusion.T + 1e-8)]
     else:
         for p in self.partners_list:
@@ -575,3 +575,95 @@ def fit_epoch(model, train_dataset, partners_grads, smodel_list, global_grad, ag
             break
 
         self.log_end_training()
+
+
+class FastFedGDO(FastFedAvg):
+    """
+    This method is inspired by FedGrad, but modifies the local computation of the gradient.
+    In this version a local, partner-specific optimizer performs several minimization steps of the local loss
+    during a minibatch. The sum of these weight updates is used as the gradient sent to the global optimizer.
+    The global optimizer aggregates these gradient-like updates sent by the partners
+    and performs an optimization step with the aggregated gradient.
+    """
+    name = 'FastFedGDO'
+
+    def __init__(self, scenario, reset_local_optims=False, global_optimiser=None, **kwargs):
+        self.global_optimiser = global_optimiser
+        self.reset_local_optims = reset_local_optims
+        super(FastFedGDO, self).__init__(scenario, **kwargs)
+
+    def init_specific_tf_variable(self):
+        # generate tf Variables in which we will store the model weights
+        self.model_stateholder = [tf.Variable(initial_value=w.read_value()) for w in self.model.trainable_weights]
+        self.partners_grads = [[tf.Variable(initial_value=w.read_value()) for w in self.model.trainable_weights]
+                               for _ in self.partners_list]
+        self.global_grad = [tf.Variable(initial_value=w.read_value()) for w in self.model.trainable_weights]
+        self.partners_optimizers = [self.model.optimizer.from_config(self.model.optimizer.get_config()) for _ in
+                                    self.partners_list]
+        if self.global_optimiser:
+            self.model.compile(optimizer=self.global_optimiser)
+
+    def fit(self):
+        # TF function definition
+        @tf.function
+        def fit_minibatch(model, model_stateholder, partners_minibatches, partners_optimizers, partners_grads,
+                          global_grad, aggregation_weights):
+            for model_w, old_w in zip(model.trainable_weights, model_stateholder):  # store the model weights
+                old_w.assign(model_w.read_value())
+
+            for p_id, minibatch in enumerate(partners_minibatches):  # minibatch == (x, y)
+                # minibatch[0] is a tensor of shape (number of batches, batch size, *image shape).
+                # We cannot iterate over a tensor inside a tf.function, so we convert it to a list of
+                # *number of batches* tensors of shape (batch size, *image shape)
+                x_minibatch = tf.unstack(minibatch[0], axis=0)
+                y_minibatch = tf.unstack(minibatch[1], axis=0)  # same here, with the labels
+
+                for x, y in zip(x_minibatch, y_minibatch):  # iterate over batches
+                    with tf.GradientTape() as tape:
+                        y_pred = model(x)
+                        loss = model.compiled_loss(y, y_pred)
+                        model.compiled_metrics.update_state(y, y_pred)  # log the loss and accuracy
+                    partners_optimizers[p_id].minimize(loss, model.trainable_weights,
+                                                       tape=tape)  # perform the local optimization steps
+                # get the pseudo-gradient as theta_before_minibatch - theta_after_minibatch
+                for grad_per_layer, w_old, w_new in zip(partners_grads[p_id], model_stateholder,
+                                                        model.trainable_weights):
+                    grad_per_layer.assign((w_old - w_new))
+
+                for model_w, old_w in zip(model.trainable_weights,
+                                          model_stateholder):  # reset the model's weights for the next partner
+                    model_w.assign(old_w.read_value())
+
+            # at the end of the minibatch, aggregate all the local grads
+            for i, grads_per_layer in enumerate(zip(*partners_grads)):
+                global_grad[i].assign(tf.tensordot(grads_per_layer, aggregation_weights, [0, 0]))
+
+            # perform one optimization update using the aggregated gradient
+            model.optimizer.apply_gradients(
+                zip(global_grad, model.trainable_weights))
+
+        # Execution
+
+        self.timer = time.time()
+        for e in range(self.epoch_count):
+            self.epoch_timer = time.time()
+            for partners_minibatches in zip(*self.train_dataset):  # <- partners_minibatches == [(x, y)] * nb_partners
+                if self.reset_local_optims:
+                    self.partners_optimizers = [self.model.optimizer.from_config(self.model.optimizer.get_config())
+                                                for _ in self.partners_list]
+                    # reset the partner-specific local optimizers
+                fit_minibatch(self.model,
+                              self.model_stateholder,
+                              partners_minibatches,
+                              self.partners_optimizers,
+                              self.partners_grads,
+                              self.global_grad,
+                              self.aggregation_weights)
+            epoch_history = self.get_epoch_history()  # compute train and val loss and accuracy
+            # add the epoch history to self.history, and log the epoch number and metric values
+            self.log_epoch(e, epoch_history)
+            self.epochs_index += 1
+            if self.early_stop():
+                break
+
+        self.log_end_training()
diff --git a/mplc/scenario.py b/mplc/scenario.py
index ace252ed..9661832c 100644
--- a/mplc/scenario.py
+++ b/mplc/scenario.py
@@ -537,22 +537,39 @@ def to_dataframe(self):
             df = df.append(dict_results, ignore_index=True)
 
         for contrib in self.contributivity_list:
-            # Contributivity data
-            dict_results["contributivity_method"] = contrib.name
-            dict_results["contributivity_scores"] = contrib.contributivity_scores
-            dict_results["contributivity_stds"] = contrib.scores_std
-            dict_results["computation_time_sec"] = contrib.computation_time_sec
-            dict_results["first_characteristic_calls_count"] = contrib.first_charac_fct_calls_count
-
-            for i in range(self.partners_count):
-                # Partner-specific data
-                dict_results["partner_id"] = i
-                dict_results["dataset_fraction_of_partner"] = self.amounts_per_partner[i]
-                dict_results["contributivity_score"] = contrib.contributivity_scores[i]
-                dict_results["contributivity_std"] = contrib.scores_std[i]
-
-                df = df.append(dict_results, ignore_index=True)
+            if isinstance(contrib.contributivity_scores, dict):
+                for key, value in contrib.contributivity_scores.items():
+                    dict_results["contributivity_method"] = f'{contrib.name} - {key}'
+                    dict_results["contributivity_scores"] = value
+                    dict_results["contributivity_stds"] = contrib.scores_std[key]
+                    dict_results["computation_time_sec"] = contrib.computation_time_sec
+                    dict_results["first_characteristic_calls_count"] = contrib.first_charac_fct_calls_count
+
+                    for i in range(self.partners_count):
+                        # Partner-specific data
+                        dict_results["partner_id"] = i
+                        dict_results["dataset_fraction_of_partner"] = self.amounts_per_partner[i]
+                        dict_results["contributivity_score"] = value[i]
+                        dict_results["contributivity_std"] = contrib.scores_std[key][i]
+
+                        df = df.append(dict_results, ignore_index=True)
+
+            else:
+                dict_results["contributivity_method"] = contrib.name
+                dict_results["contributivity_scores"] = contrib.contributivity_scores
+                dict_results["contributivity_stds"] = contrib.scores_std
+                dict_results["computation_time_sec"] = contrib.computation_time_sec
+                dict_results["first_characteristic_calls_count"] = contrib.first_charac_fct_calls_count
+
+                for i in range(self.partners_count):
+                    # Partner-specific data
+                    dict_results["partner_id"] = i
+                    dict_results["dataset_fraction_of_partner"] = self.amounts_per_partner[i]
+                    dict_results["contributivity_score"] = contrib.contributivity_scores[i]
+                    dict_results["contributivity_std"] = contrib.scores_std[i]
+
+                    df = df.append(dict_results, ignore_index=True)
 
         return df
diff --git a/tests/contrib_end_to_end_test.py b/tests/contrib_end_to_end_test.py
index 83b46f29..7781331e 100644
--- a/tests/contrib_end_to_end_test.py
+++ b/tests/contrib_end_to_end_test.py
@@ -71,7 +71,7 @@ def test_all_contrib_methods(self):
         exp.run()
 
         df = exp.result
-        assert len(df) == 2 * len(all_methods)
+        assert len(df) == 2 * (len(all_methods) + 2)  # S-Model now yields 3 result sets (one per metric) instead of 1
 
     def test_IS_reg_S_contrib(self):
         """
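
For reference, the three statistical distances computed in `statistical_distances_via_smodel` can be reproduced standalone. The sketch below uses toy predicted distributions and plain NumPy/TensorFlow (no mplc dependency); it shows the Kullback-Leibler divergence obtained as cross-entropy minus entropy, the Bhattacharyya distance, and the Hellinger metric:

```python
import numpy as np
import tensorflow as tf

# Toy predicted distributions: y_global from the global model,
# y_local after the partner's S-model noise layer
y_global = np.array([[0.7, 0.2, 0.1], [0.1, 0.8, 0.1]])
y_local = np.array([[0.5, 0.3, 0.2], [0.2, 0.6, 0.2]])

cce = tf.keras.metrics.CategoricalCrossentropy()
cce.update_state(y_global, y_local)
cross_entropy = cce.result().numpy()  # H(p, q)
cce.reset_states()
cce.update_state(y_global, y_global)
entropy = cce.result().numpy()        # H(p, p) = H(p)
cce.reset_states()

# KL(p || q) = H(p, q) - H(p), averaged over samples
kl = cross_entropy - entropy

# Bhattacharyya coefficient, averaged over samples
bc = np.mean([np.sum(np.sqrt(p * q)) for p, q in zip(y_global, y_local)])
bhattacharyya_distance = -np.log(bc)
hellinger_metric = np.sqrt(1 - bc)

print(kl, bhattacharyya_distance, hellinger_metric)
```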
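The "double optimizer" mechanism of `FastFedGDO` can likewise be illustrated outside mplc. A minimal single-partner sketch (toy model and data, illustrative names, not mplc code): the local optimizer takes several steps, and the resulting weight delta serves as a pseudo-gradient for the global optimizer.

```python
import numpy as np
import tensorflow as tf

# Toy model and data (shapes are illustrative)
model = tf.keras.Sequential([tf.keras.layers.Dense(2, activation='softmax', input_shape=(4,))])
model.compile(optimizer='adam', loss='categorical_crossentropy')
local_opt = tf.keras.optimizers.SGD(0.1)  # partner-specific optimizer
x = tf.random.uniform((8, 4))
y = tf.one_hot(np.random.randint(2, size=8), 2)

# Store the weights at the start of the (global) minibatch
old_weights = [w.numpy() for w in model.trainable_weights]

# Several local minimization steps with the local optimizer
for _ in range(3):
    with tf.GradientTape() as tape:
        loss = model.compiled_loss(y, model(x))
    local_opt.minimize(loss, model.trainable_weights, tape=tape)

# Pseudo-gradient = weights before the minibatch - weights after
pseudo_grad = [tf.constant(w_old - w_new.numpy())
               for w_old, w_new in zip(old_weights, model.trainable_weights)]

# Reset the weights, then let the global optimizer apply the
# (here: single-partner) aggregated pseudo-gradient
for w, w_old in zip(model.trainable_weights, old_weights):
    w.assign(w_old)
model.optimizer.apply_gradients(zip(pseudo_grad, model.trainable_weights))
```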