diff --git a/lib/python/.gitignore b/lib/python/.gitignore index 29963fa27..922355ed3 100644 --- a/lib/python/.gitignore +++ b/lib/python/.gitignore @@ -8,4 +8,5 @@ flame.egg-info data wandb -examples/fwdllm/expts/run_tc_expts/mpi_host_file \ No newline at end of file +examples/fwdllm/expts/run_tc_expts/mpi_host_file +lib/python/examples/fwdllm/expts/run_tc_expts/mpi_host_file \ No newline at end of file diff --git a/lib/python/examples/async_cifar10/aggregator/pytorch/main.py b/lib/python/examples/async_cifar10/aggregator/pytorch/main.py index 6ef0f5c0c..531c3f8d6 100644 --- a/lib/python/examples/async_cifar10/aggregator/pytorch/main.py +++ b/lib/python/examples/async_cifar10/aggregator/pytorch/main.py @@ -124,7 +124,7 @@ def __init__( ) self.trainer_event_dict = None if ( - self.track_trainer_avail["enabled"] + self.track_trainer_avail["enabled"] is not None and self.track_trainer_avail["type"] == "ORACULAR" ): self.trainer_event_dict = self.read_trainer_unavailability( diff --git a/lib/python/examples/fwdllm/aggregator/fl_main.py b/lib/python/examples/fwdllm/aggregator/fl_main.py index 3d7f897a4..4cef9a1e9 100755 --- a/lib/python/examples/fwdllm/aggregator/fl_main.py +++ b/lib/python/examples/fwdllm/aggregator/fl_main.py @@ -83,15 +83,15 @@ def post_complete_message(tc_args): parser.add_argument("--log_level", type=str, default="INFO", required=False) args = parser.parse_args() - logger.setLevel(args.log_level) + logger.setLevel(args.log_level.upper()) config = Config(args.config) device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # customize the log format logging.basicConfig( - level=logging._nameToLevel[args.log_level], - format="%(process)s %(asctime)s.%(msecs)03d - {%(module)s.py (%(lineno)d)} - %(funcName)s(): %(message)s", - datefmt="%Y-%m-%d,%H:%M:%S", + level=logging._nameToLevel[args.log_level.upper()], + format="%(asctime)s | %(filename)s:%(lineno)d | %(levelname)s | %(threadName)s | %(funcName)s | %(message)s", + force=True, )
logger.info(config) set_seed(config.hyperparameters.manual_seed) diff --git a/lib/python/examples/fwdllm/expts/run_tc_expts/fedavg_main_tc.py b/lib/python/examples/fwdllm/expts/run_tc_expts/fedavg_main_tc.py index 147a1cba9..d8f53ce31 100755 --- a/lib/python/examples/fwdllm/expts/run_tc_expts/fedavg_main_tc.py +++ b/lib/python/examples/fwdllm/expts/run_tc_expts/fedavg_main_tc.py @@ -62,9 +62,9 @@ def post_complete_message(tc_args): # customize the log format logging.basicConfig( - level=logging.INFO, - format="%(process)s %(asctime)s.%(msecs)03d - {%(module)s.py (%(lineno)d)} - %(funcName)s(): %(message)s", - datefmt="%Y-%m-%d,%H:%M:%S", + level=getattr(logging, getattr(args, "log_level", "INFO").upper(), logging.INFO), + format="%(asctime)s | %(filename)s:%(lineno)d | %(levelname)s | %(threadName)s | %(funcName)s | %(message)s", + force=True, ) logging.info(args) diff --git a/lib/python/examples/fwdllm/expts/run_tc_expts/json_scripts/aggregator.json b/lib/python/examples/fwdllm/expts/run_tc_expts/json_scripts/aggregator.json index 3ea80a758..58df2e339 100644 --- a/lib/python/examples/fwdllm/expts/run_tc_expts/json_scripts/aggregator.json +++ b/lib/python/examples/fwdllm/expts/run_tc_expts/json_scripts/aggregator.json @@ -114,7 +114,8 @@ "roundNudgeType": "last_train", "minInitialTrainers": 50, "k": 10, - "is_async": true + "is_async": true, + "stat_utility": "partial" } }, "optimizer": { @@ -123,11 +124,14 @@ "use_oort_lr": "True", "dataset_name": "google-speech", "agg_rate_conf": { - "type": "old", - "scale": 1.0, + "type": "new", + "scale": 0.0, "a_exp": 0.2, - "b_exp": 0.25 - } + "b_exp": 0.25, + "alpha_type": "polynomial", + "beta_type": "exponential" + }, + "stat_utility": "partial" } }, "maxRunTime": 600, diff --git a/lib/python/examples/fwdllm/expts/run_tc_expts/mpi_host_file b/lib/python/examples/fwdllm/expts/run_tc_expts/mpi_host_file index 73991df21..1aa403983 100755 --- a/lib/python/examples/fwdllm/expts/run_tc_expts/mpi_host_file +++ 
b/lib/python/examples/fwdllm/expts/run_tc_expts/mpi_host_file @@ -1 +1 @@ -wash.cc.gatech.edu +jayne.cc.gatech.edu diff --git a/lib/python/examples/fwdllm/extract_stat_utility.py b/lib/python/examples/fwdllm/extract_stat_utility.py new file mode 100644 index 000000000..c219efb41 --- /dev/null +++ b/lib/python/examples/fwdllm/extract_stat_utility.py @@ -0,0 +1,69 @@ +import sys +import re +import csv +import os + +def extract_utilities(log_file, output_file): + if not os.path.exists(log_file): + print(f"Error: Log file not found at {log_file}") + sys.exit(1) + + print(f"Extracting stat_utility metrics from {log_file}...") + + # Regex patterns for matching top_aggregator and fwdllm_aggregator logs + top_agg_pattern = re.compile(r"Received weights from (\S+)\. It was trained on model version (\d+), with (\d+) samples\. Returned partial stat utility ([\d\.]+) and full stat utility ([\d\.]+)") + fwdllm_agg_pattern = re.compile(r"Aggregated utilities for (\S+)\. Partial stat utility used for FedBuff: ([\d\.]+)\. 
Full stat utility stored for Oort: ([\d\.]+)") + + # Regex for async_oort selector debug logs (if enabled) + oort_selector_pattern = re.compile(r"Trainer (\S+) full_dataset_stat_utility: unnormalized = ([\d\.]+), normalized = ([\d\.]+), partial_dataset_stat_utility: ([\d\.]+)") + + results = [] + + with open(log_file, 'r') as f: + for line in f: + # Match top aggregator logs + m_top = top_agg_pattern.search(line) + if m_top: + tid = m_top.group(1) + version = m_top.group(2) + samples = m_top.group(3) + partial_util = m_top.group(4) + full_util = m_top.group(5) + results.append([tid, version, samples, partial_util, full_util, "N/A", "top_aggregator"]) + continue + + # Match fwdllm aggregator logs + m_fwd = fwdllm_agg_pattern.search(line) + if m_fwd: + tid = m_fwd.group(1) + partial_util = m_fwd.group(2) + full_util = m_fwd.group(3) + results.append([tid, "N/A", "N/A", partial_util, full_util, "N/A", "fwdllm_aggregator"]) + continue + + # Match async oort selector logs + m_oort = oort_selector_pattern.search(line) + if m_oort: + tid = m_oort.group(1) + unnorm_full = m_oort.group(2) + norm_full = m_oort.group(3) + partial_util = m_oort.group(4) + results.append([tid, "N/A", "N/A", partial_util, unnorm_full, norm_full, "async_oort_selector"]) + + # Write results to CSV + with open(output_file, 'w', newline='') as out_csv: + writer = csv.writer(out_csv) + writer.writerow(["TrainerID", "ModelVersion", "Samples", "PartialStatUtility", "FullStatUtility_Unnormalized", "FullStatUtility_Normalized", "Source"]) + writer.writerows(results) + + print(f"Done! {len(results)} records extracted. 
Output saved to {output_file}") + +if __name__ == "__main__": + if len(sys.argv) < 2: + print(f"Usage: python {sys.argv[0]} <log_file> [output_csv_file]") + sys.exit(1) + + log_path = sys.argv[1] + out_path = sys.argv[2] if len(sys.argv) > 2 else "stat_utility_comparison.csv" + + extract_utilities(log_path, out_path) diff --git a/lib/python/examples/fwdllm/extract_stat_utility.sh b/lib/python/examples/fwdllm/extract_stat_utility.sh new file mode 100755 index 000000000..d6af37724 --- /dev/null +++ b/lib/python/examples/fwdllm/extract_stat_utility.sh @@ -0,0 +1,24 @@ +#!/bin/bash + +# Script to parse FLAME training logs and extract stat_utility metrics +# This is a wrapper around the extract_stat_utility.py script. + +if [ -z "$1" ]; then + echo "Usage: $0 <log_file> [output_csv_file]" + exit 1 +fi + +LOG_FILE=$1 +OUTPUT_CSV=${2:-"stat_utility_comparison.csv"} + +# Check if log file exists +if [ ! -f "$LOG_FILE" ]; then + echo "Error: Log file not found at $LOG_FILE" + exit 1 +fi + +# Run the python script located in the same directory +SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )" +PYTHON_SCRIPT="$SCRIPT_DIR/extract_stat_utility.py" + +python "$PYTHON_SCRIPT" "$LOG_FILE" "$OUTPUT_CSV" diff --git a/lib/python/examples/fwdllm/trainer/fl_main.py b/lib/python/examples/fwdllm/trainer/fl_main.py index 99b0c96b6..ac5685097 100644 --- a/lib/python/examples/fwdllm/trainer/fl_main.py +++ b/lib/python/examples/fwdllm/trainer/fl_main.py @@ -77,9 +77,9 @@ def post_complete_message(tc_args): # customize the log format logging.basicConfig( - level=logging._nameToLevel[args.log_level], - format="%(process)s %(asctime)s.%(msecs)03d - {%(module)s.py (%(lineno)d)} - %(funcName)s(): %(message)s", - datefmt="%Y-%m-%d,%H:%M:%S", + level=logging._nameToLevel[args.log_level.upper()], + format="%(asctime)s | %(filename)s:%(lineno)d | %(levelname)s | %(threadName)s | %(funcName)s | %(message)s", + force=True, ) logging.debug(config) set_seed(config.hyperparameters.manual_seed) diff 
--git a/lib/python/examples/fwdllm/trainer/forward_training/tc_transformer_trainer_distribute.py b/lib/python/examples/fwdllm/trainer/forward_training/tc_transformer_trainer_distribute.py index 0c3eb328b..7f72f8831 100755 --- a/lib/python/examples/fwdllm/trainer/forward_training/tc_transformer_trainer_distribute.py +++ b/lib/python/examples/fwdllm/trainer/forward_training/tc_transformer_trainer_distribute.py @@ -225,6 +225,7 @@ def __init__( self.params = None self.buffers = None self.grad_for_var_check = None + self.required_stat_utilities = [] # def initialize(self) -> None: """Initialize role.""" self.device = # torch.device("cuda" if torch.cuda.is_available() else "cpu") @@ -380,6 +381,10 @@ def _select_optimal_perturbations(device, logging_state): def _train_one_batch(self, device, batch, epoch, batch_idx, v_buffer): @timer_decorator def _compute_batch_stat_utility(device, x, labels): + if "partial" not in self.required_stat_utilities: + if batch_idx == 0 and epoch == 0: + logging.info("Skipping partial stat_utility computation as per config.") + return with torch.no_grad(): pred = self.model(x) @@ -391,7 +396,7 @@ def _compute_batch_stat_utility(device, x, labels): logits = pred loss = self.base_trainer.oort_loss(logits, labels.view(-1), epoch=0, batch_idx=0, reduction="mean") # Optimization: Lazy logging & removed .item() to avoid Host-device sync - logging.debug("stat_utility for trainerId: %s is %s, loss: %s", self.trainer_id, self.base_trainer._stat_utility, loss.mean()) + logging.debug("partial_stat_utility for trainerId: %s is %s, loss: %s", self.trainer_id, self.base_trainer.partial_stat_utility, loss.mean()) @timer_decorator def _prepare_perturbation_tensors(device, v_buffer): @@ -488,15 +493,70 @@ def _accumulate_and_extract_grads(device, jvp, v_params): # Optimization: Remove GC & buffer flushes from the batch loop # self._force_cuda_memory_cleanup(device, f"epoch{epoch}_batch{batch_idx}_end") - self.base_trainer.normalize_stat_utility(epoch) - 
logging.debug( - f"stat_utility - normalized for trainerId: {self.trainer_id} = {self.base_trainer._stat_utility}" - ) + if "partial" in self.required_stat_utilities: + self.base_trainer.normalize_stat_utility(epoch) + logging.debug( + f"stat_utility - normalized for trainerId: {self.trainer_id} = {self.base_trainer.partial_stat_utility}" + ) del x, labels, jvp, v_params # self._force_cuda_memory_cleanup(device, f"epoch{epoch}_batch{batch_idx}_end") return loss + @timer_decorator + def calculate_full_dataset_stat_utility(self, all_data_bins, device=None): + if not device: + device = self.device + + self.model.to(device) + self.model.eval() + + total_squared_loss = 0.0 + total_samples = 0 + + # We need the original loss function without reduction + if hasattr(self, 'base_trainer') and self.base_trainer is not None: + criterion = self.base_trainer.loss_fn(reduction="none", **{k: v for k, v in self.base_trainer.config.hyperparameters.__dict__.items() if k in self.base_trainer.loss_fn.__init__.__code__.co_varnames and k != 'reduction'}) + else: + criterion = CrossEntropyLoss(reduction="none") + + from torch.cuda.amp import autocast + autocast_cm = autocast() if self.args.fp16 else contextlib.nullcontext() + + with torch.no_grad(), autocast_cm: + for data_bin in all_data_bins: + # We expect data_bin to be a DataLoader. + # If we need to force eval_batch_size, we might need a workaround, but typically + # for these evaluation passes we can just iterate over the existing batches. + # If mem is constrained, we should rely on the user's config batch size. 
+ for batch in data_bin: + x = batch[1].to(device, non_blocking=True) + labels = batch[4].to(device, non_blocking=True) + + pred = self.model(x) + if hasattr(pred, "logits"): + logits = pred.logits + elif isinstance(pred, (tuple, list)): + logits = pred[0] + else: + logits = pred + + loss_list = criterion(logits.view(-1, self.num_labels), labels.view(-1)) + + total_squared_loss += torch.square(loss_list).sum().item() + total_samples += len(loss_list) + + del x, labels, pred, logits, loss_list + + # Calculate full stat utility: sqrt(N * sum(loss^2)) + if total_samples > 0: + full_stat_utility = math.sqrt(total_samples * total_squared_loss) + else: + full_stat_utility = 0.0 + + logging.info(f"full_dataset_stat_utility for trainerId: {self.trainer_id} is {full_stat_utility} over {total_samples} samples") + return full_stat_utility + @timer_decorator def _training_loop(self, device, v_buffer): global_step = 0 diff --git a/lib/python/flame/mode/horizontal/asyncfl/top_aggregator.py b/lib/python/flame/mode/horizontal/asyncfl/top_aggregator.py index 7f3495f8f..4a1759d87 100644 --- a/lib/python/flame/mode/horizontal/asyncfl/top_aggregator.py +++ b/lib/python/flame/mode/horizontal/asyncfl/top_aggregator.py @@ -39,6 +39,8 @@ PROP_LAST_EVAL_ROUND, PROP_ROUND_DURATION, PROP_ROUND_START_TIME, + PROP_PARTIAL_DATASET_STAT_UTILITY, + PROP_FULL_DATASET_STAT_UTILITY, PROP_STAT_UTILITY, PROP_UPDATE_COUNT, ) @@ -466,21 +468,31 @@ def _aggregate_weights(self, tag: str) -> None: if MessageType.MODEL_VERSION in msg: version = msg[MessageType.MODEL_VERSION] - stat_utility = 0 # default - if MessageType.STAT_UTILITY in msg: + partial_stat_utility = 0 # default + + if MessageType.PARTIAL_DATASET_STAT_UTILITY in msg: channel.set_end_property( - end, PROP_STAT_UTILITY, msg[MessageType.STAT_UTILITY] + end, PROP_PARTIAL_DATASET_STAT_UTILITY, msg[MessageType.PARTIAL_DATASET_STAT_UTILITY] + ) + # Extracted directly from the message since fedbuff uses this + partial_stat_utility = 
msg[MessageType.PARTIAL_DATASET_STAT_UTILITY] + + if MessageType.FULL_DATASET_STAT_UTILITY in msg: + channel.set_end_property( + end, + PROP_FULL_DATASET_STAT_UTILITY, + msg[MessageType.FULL_DATASET_STAT_UTILITY], ) - stat_utility = msg[MessageType.STAT_UTILITY] logger.info( - f"Received weights from {end}. It was trained on model version {version}, with {count} samples. Returned stat utility {stat_utility}" + f"Received weights from {end}. It was trained on model version {version}, with {count} samples. " + f"Returned partial stat utility {partial_stat_utility} and full stat utility {msg.get(MessageType.FULL_DATASET_STAT_UTILITY, 0)}" ) if ( weights is not None and count > 0 ): # SC_TS: count = 0 means no data (it was trained on!), so ignore! - tres = TrainResult(weights, count, version, stat_utility) + tres = TrainResult(weights, count, version, partial_stat_utility) # save training result from trainer in a disk cache self.cache[end] = tres logger.debug(f"received {len(self.cache)} trainer updates in cache") @@ -491,7 +503,7 @@ def _aggregate_weights(self, tag: str) -> None: # Populate round statistics vars self._round_update_values["staleness"].append(update_staleness_val) - self._round_update_values["stat_utility"].append(stat_utility) + self._round_update_values["stat_utility"].append(partial_stat_utility) self._round_update_values["trainer_speed"].append( channel.get_end_property( end_id=end, key=PROP_ROUND_DURATION diff --git a/lib/python/flame/mode/horizontal/syncfl/fwdllm_aggregator.py b/lib/python/flame/mode/horizontal/syncfl/fwdllm_aggregator.py index 78565c312..13880cd99 100644 --- a/lib/python/flame/mode/horizontal/syncfl/fwdllm_aggregator.py +++ b/lib/python/flame/mode/horizontal/syncfl/fwdllm_aggregator.py @@ -53,6 +53,8 @@ PROP_LAST_EVAL_ROUND, PROP_ROUND_DURATION, PROP_ROUND_START_TIME, + PROP_PARTIAL_DATASET_STAT_UTILITY, + PROP_FULL_DATASET_STAT_UTILITY, PROP_STAT_UTILITY, PROP_UPDATE_COUNT, ) @@ -85,7 +87,7 @@ def _calculate_hash(tensor): def 
log_error_distribution(probs, labels): """ - Analyzes error distribution across 4 classes and logs via logging.info. + Analyzes error distribution across 4 classes and logs via logging.debug. probs: ndarray [N, 4] - Softmax probabilities labels: ndarray [N] - Integer ground truth """ @@ -102,12 +104,12 @@ def log_error_distribution(probs, labels): wrong_mask = preds != actual_labels if not np.any(wrong_mask): - logging.info("Accuracy is 100%.") + logging.debug("Accuracy is 100%.") return # 3. Filter for wrong predictions (Now safely NumPy) wrong_probs = probs[wrong_mask] - logging.info(f"wrong probs len = {len(wrong_probs)}") + logging.debug("wrong probs len = %d", len(wrong_probs)) # Now this will work perfectly confidences = np.max(wrong_probs, axis=1) @@ -143,7 +145,7 @@ def log_error_distribution(probs, labels): # 4. Summary Statistics for "Confidently Wrong" samples high_margin_count = np.sum(margins > 0.5) logging.debug( - f"Summary: {high_margin_count} errors have a margin > 0.5 (Confidently Wrong)." 
+ "Summary: %d errors have a margin > 0.5 (Confidently Wrong).", high_margin_count ) @@ -225,6 +227,7 @@ def internal_init(self) -> None: self._model_version_trainer_stats = { "train_duration": [], "partial_stat_utility": [], + "full_stat_utility": [], } self._per_round_staleness_list = [] @@ -459,6 +462,9 @@ def read_trainer_unavailability(self, trace=None) -> None: with open(file_path) as f: trainer_json = json.load(f) curr_trainer_id = trainer_json["taskid"] + logger.info( + f"Processing file {file_path} for trainer {curr_trainer_id}" + ) event_list = ast.literal_eval(trainer_json["hyperparameters"][trace]) # SortedDict for efficient timestamp lookup @@ -520,14 +526,16 @@ def aggregate_grads_from_trainers( scale_val = self.optimizer.agg_rate_conf["scale"] a_exp_val = self.optimizer.agg_rate_conf["a_exp"] b_exp_val = self.optimizer.agg_rate_conf["b_exp"] + alpha_type = self.optimizer.agg_rate_conf["alpha_type"] + beta_type = self.optimizer.agg_rate_conf["beta_type"] rate = self.optimizer.weight_factor( scale=scale_val, staleness=staleness_val, a_exp=a_exp_val, loss=stat_utility, b_exp=b_exp_val, - alpha_type="polynomial", - beta_type="polynomial_upshift", + alpha_type=alpha_type, + beta_type=beta_type, ) except Exception as e: logger.warning( @@ -695,13 +703,47 @@ def _process_single_trainer_message(self, channel, msg, end, timestamp): if MessageType.GRADIENTS_FOR_VAR_CHECK in msg else None ) + + partial_stat_utility = 0 + if MessageType.PARTIAL_DATASET_STAT_UTILITY in msg: + channel.set_end_property( + end, + PROP_PARTIAL_DATASET_STAT_UTILITY, + msg[MessageType.PARTIAL_DATASET_STAT_UTILITY], + ) + partial_stat_utility = msg[MessageType.PARTIAL_DATASET_STAT_UTILITY] + + if MessageType.FULL_DATASET_STAT_UTILITY in msg: + # Still populating PROP_STAT_UTILITY for backward compatibility + channel.set_end_property( + end, + PROP_STAT_UTILITY, + msg[MessageType.FULL_DATASET_STAT_UTILITY], + ) + channel.set_end_property( + end, + PROP_FULL_DATASET_STAT_UTILITY, + 
msg[MessageType.FULL_DATASET_STAT_UTILITY], + ) + else: + raise Exception( + f"End {end} does not have the required value for msg[MessageType.FULL_DATASET_STAT_UTILITY]" + ) + # TODO(GD): Failing fast to catch config/logic errors. Double check if we can catch this in async_oort.py without this exception & take care of the case where we select before we get the first gradients back. + + logger.info( + f"Aggregated utilities for {end}. " + f"Partial stat utility used for FedBuff: {partial_stat_utility}. " + f"Full stat utility stored for Oort: {msg.get(MessageType.FULL_DATASET_STAT_UTILITY, 0)}" + ) + logger.debug( f"Calling aggregate_grads_for_trainers with grad_for_var_check: {_calculate_hash(grad_for_var_check)}" ) self.aggregate_grads_from_trainers( trainer_gradients, version_for_rate=version_for_rate, - stat_utility=channel.get_end_property(end, PROP_STAT_UTILITY), + stat_utility=partial_stat_utility, grad_for_var_check=grad_for_var_check, ) @@ -752,17 +794,19 @@ def compute_percentiles(values, reverse=False): # This of this as the equivalent of sorting the array in reverse order where p1, p5, p20, p30, p50, p75, p90, p99 = ( - (99, 95, 80, 70, 50, 25, 10, 1) if reverse else (1, 5, 20, 30, 50, 75, 90, 99) + (99, 95, 80, 70, 50, 25, 10, 1) + if reverse + else (1, 5, 20, 30, 50, 75, 90, 99) ) return ( - float(np.percentile(arr, p1, method='lower')), - float(np.percentile(arr, p5, method='lower')), - float(np.percentile(arr, p20, method='lower')), - float(np.percentile(arr, p30, method='lower')), - float(np.percentile(arr, p50, method='lower')), - float(np.percentile(arr, p75, method='lower')), - float(np.percentile(arr, p90, method='lower')), - float(np.percentile(arr, p99, method='lower')), + float(np.percentile(arr, p1, method="lower")), + float(np.percentile(arr, p5, method="lower")), + float(np.percentile(arr, p20, method="lower")), + float(np.percentile(arr, p30, method="lower")), + float(np.percentile(arr, p50, method="lower")), + float(np.percentile(arr, 
p75, method="lower")), + float(np.percentile(arr, p90, method="lower")), + float(np.percentile(arr, p99, method="lower")), ) n_unique = len(self._model_version_unique_trainers) @@ -774,11 +818,17 @@ def compute_percentiles(values, reverse=False): self._model_version_trainer_stats["partial_stat_utility"], reverse=True ) ) + fsu_p1, fsu_p5, fsu_p20, fsu_p30, fsu_p50, fsu_p75, fsu_p90, fsu_p99 = ( + compute_percentiles( + self._model_version_trainer_stats["full_stat_utility"], reverse=True + ) + ) logger.info( f"==== Model version incremented to {self._curr_agg_version} with updates from {n_unique} unique trainers. Stats of participating trainers: \n" f"p1, p5, p20, p30, p50, p75, p90, p99 of train duration \n{rd_p1:.3f}, {rd_p5:.3f}, {rd_p20:.3f}, {rd_p30:.3f}, {rd_p50:.3f}, {rd_p75:.3f}, {rd_p90:.3f}, {rd_p99:.3f} \n" - f"p1, p5, p20, p30, p50, p75, p90, p99 of partial stat utilities \n{su_p1:.4f}, {su_p5:.4f}, {su_p20:.4f}, {su_p30:.4f}, {su_p50:.4f}, {su_p75:.4f}, {su_p90:.4f}, {su_p99:.4f}" + f"p1, p5, p20, p30, p50, p75, p90, p99 of partial stat utilities \n{su_p1:.4f}, {su_p5:.4f}, {su_p20:.4f}, {su_p30:.4f}, {su_p50:.4f}, {su_p75:.4f}, {su_p90:.4f}, {su_p99:.4f} \n" + f"p1, p5, p20, p30, p50, p75, p90, p99 of full stat utilities \n{fsu_p1:.4f}, {fsu_p5:.4f}, {fsu_p20:.4f}, {fsu_p30:.4f}, {fsu_p50:.4f}, {fsu_p75:.4f}, {fsu_p90:.4f}, {fsu_p99:.4f}" ) # Reset accumulators for the next model version window @@ -786,6 +836,7 @@ def compute_percentiles(values, reverse=False): self._model_version_trainer_stats = { "train_duration": [], "partial_stat_utility": [], + "full_stat_utility": [], } @timer_decorator @@ -805,12 +856,19 @@ def _process_aggregation_goal_met(self, tag, channel, is_async=False): train_duration.total_seconds() ) partial_stat_utility = channel.get_end_property( - trainer_update, PROP_STAT_UTILITY + trainer_update, PROP_PARTIAL_DATASET_STAT_UTILITY ) if partial_stat_utility is not None: 
self._model_version_trainer_stats["partial_stat_utility"].append( partial_stat_utility ) + full_stat_utility = channel.get_end_property( + trainer_update, PROP_FULL_DATASET_STAT_UTILITY + ) + if full_stat_utility is not None: + self._model_version_trainer_stats["full_stat_utility"].append( + full_stat_utility + ) self.grad_pool.append(self.grad) format_hash = lambda d: [_calculate_hash(v) for v in d] @@ -1153,6 +1211,17 @@ def check_trainer_availability(self, end: str) -> bool: @timer_decorator def _prepare_distribution_payload(self, task_to_perform: str, force_weights: bool = False): + selector_stat_util = self.config.selector.kwargs.get("stat_utility") + optimizer_stat_util = self.config.optimizer.kwargs.get("stat_utility") + + if selector_stat_util is None or optimizer_stat_util is None: + logger.error("stat_utility configuration is missing in aggregator.json. Please set it to 'full', 'partial', or 'none'. In FwdLLM, 'partial' is a good default value.") + raise ValueError("Missing stat_utility configuration in aggregator.json.") + + required_utils_set = {selector_stat_util.lower(), optimizer_stat_util.lower()} + required_utils_set.discard("none") + required_stat_utilities = list(required_utils_set) + if self.var: logger.info( f"self.var = {self.var}, self.var_threshold = {self.var_threshold}" @@ -1170,6 +1239,7 @@ def _prepare_distribution_payload(self, task_to_perform: str, force_weights: boo MessageType.TASK_TO_PERFORM: task_to_perform, MessageType.DATA_ID: self.data_id, MessageType.ITERATION_PER_DATA_ID: self.iteration_per_data_id, + MessageType.REQUIRED_STAT_UTILITIES: required_stat_utilities, } logger.info( @@ -1205,6 +1275,7 @@ def _prepare_distribution_payload(self, task_to_perform: str, force_weights: boo MessageType.TASK_TO_PERFORM: task_to_perform, MessageType.DATA_ID: self.data_id, MessageType.ITERATION_PER_DATA_ID: self.iteration_per_data_id, + MessageType.REQUIRED_STAT_UTILITIES: required_stat_utilities, } return payload diff --git 
a/lib/python/flame/mode/horizontal/syncfl/fwdllm_trainer.py b/lib/python/flame/mode/horizontal/syncfl/fwdllm_trainer.py index da5136f7b..18cf6d89e 100644 --- a/lib/python/flame/mode/horizontal/syncfl/fwdllm_trainer.py +++ b/lib/python/flame/mode/horizontal/syncfl/fwdllm_trainer.py @@ -69,8 +69,10 @@ def recv_wrapper(self, channel, end_id): return channel.recv(end_id) + import hashlib + def _calculate_hash(tensor): if tensor is None: return "" @@ -78,6 +80,7 @@ def _calculate_hash(tensor): """Calculate a hash for a tensor for logging.""" return hashlib.sha256(tensor.detach().cpu().numpy().tobytes()).hexdigest() + class Trainer(Role, metaclass=ABCMeta): """Trainer implements an ML training role.""" @@ -157,7 +160,8 @@ def internal_init(self) -> None: self.task_to_perform = "train" self.iteration_per_data_id = None self.abort_training = False - self._stat_utility = 0 + self.partial_stat_utility = 0 + self.required_stat_utilities = [] def get(self, tag: str) -> None: """Get data from remote role(s).""" @@ -304,8 +308,10 @@ def _fetch_weights(self, tag: str) -> None: # Helper lambda for a cleaner log format_hash = lambda d: {k: _calculate_hash(v)[:8] for k, v in d.items()} - logging.debug(f"Trainer Id : {self.trainer_id} received weights (hashed): {format_hash(self.model.state_dict())}") - + logging.debug( + f"Trainer Id : {self.trainer_id} received weights (hashed): {format_hash(self.model.state_dict())}" + ) + if MessageType.DATA_ID in msg: logger.info( f"Trainer id {self.trainer_id} received data id for training : {msg[MessageType.DATA_ID]}" @@ -336,12 +342,16 @@ def _fetch_weights(self, tag: str) -> None: full_grad.append( torch.zeros_like(param, device="cpu") ) - + if partial_grad is not None: format_hash = lambda d: [_calculate_hash(v)[:8] for v in d] - logger.debug(f"Trainer: {self.trainer_id} - old_grad: {format_hash(partial_grad)}") + logger.debug( + f"Trainer: {self.trainer_id} - old_grad: {format_hash(partial_grad)}" + ) else: - logger.debug(f"Trainer: 
{self.trainer_id} - old_grad: None") + logger.debug( + f"Trainer: {self.trainer_id} - old_grad: None" + ) if self.data_id % 2: logger.debug(f"using old grad for : {self.data_id}") @@ -370,6 +380,10 @@ def _fetch_weights(self, tag: str) -> None: logger.debug(f"Found task_to_perform in msg: {self.task_to_perform}") else: logger.info(f"Didn't find TASK_TO_PERFORM in msg") + + if MessageType.REQUIRED_STAT_UTILITIES in msg: + self.required_stat_utilities = msg[MessageType.REQUIRED_STAT_UTILITIES] + self.trainer.model_trainer.required_stat_utilities = self.required_stat_utilities self.fetch_success = True @@ -462,6 +476,22 @@ def _send_grads(self, tag: str) -> None: # one aggregator is sufficient end = channel.one_end(VAL_CH_STATE_SEND) + # We assume self.trainer.model_trainer is present and has the required method + # Pass the entire list so that the first loop in calculate_full_dataset_stat_utility + # correctly identifies the 'bins'. + full_stat_utility = 0.0 + if "full" in self.required_stat_utilities: + full_stat_utility = ( + self.trainer.model_trainer.calculate_full_dataset_stat_utility( + self.train_local_list, self.device + ) + ) + logger.info( + f"Trainer {self.trainer_id} full_dataset_stat_utility calculation complete: {full_stat_utility}" + ) + else: + logger.info("Skipping full stat_utility computation as per config.") + if self.task_to_perform == "train": # trainer is expected to train and it is also available to train - # best case self._update_weights() @@ -496,8 +526,12 @@ def _send_grads(self, tag: str) -> None: f"({size_mb:.2f} MB)." 
) - format_hash = lambda d: {k: _calculate_hash(v)[:8] for k, v in d.items()} - logger.info(f"Sending grads from Trainer: {self.trainer_id} - model version: {self._model_version} - grad: {format_hash(grad_dict)} - grad_for_var_check: {_calculate_hash(self.grad_for_var_check)}") + format_hash = lambda d: { + k: _calculate_hash(v)[:8] for k, v in d.items() + } + logger.info( + f"Sending grads from Trainer: {self.trainer_id} - model version: {self._model_version} - grad: {format_hash(grad_dict)} - grad_for_var_check: {_calculate_hash(self.grad_for_var_check)}" + ) else: logger.info("No gradients exist; sending an empty dictionary.") @@ -507,15 +541,18 @@ def _send_grads(self, tag: str) -> None: MessageType.DATASET_SIZE: self.dataset_size, MessageType.MODEL_VERSION: self._model_version, MessageType.DATASAMPLER_METADATA: self.datasampler.get_metadata(), - MessageType.STAT_UTILITY: self._stat_utility, - # - rn FedSgdTrainer has no utility MessageType.TOTAL_DATA_BINS: self.total_data_bins, } + if self.required_stat_utilities: + msg[MessageType.PARTIAL_DATASET_STAT_UTILITY] = self.partial_stat_utility + msg[MessageType.FULL_DATASET_STAT_UTILITY] = full_stat_utility else: msg = { MessageType.MODEL_VERSION: self._model_version, - MessageType.STAT_UTILITY: self._stat_utility, } + if self.required_stat_utilities: + msg[MessageType.PARTIAL_DATASET_STAT_UTILITY] = self.partial_stat_utility + msg[MessageType.FULL_DATASET_STAT_UTILITY] = full_stat_utility channel.send(end, msg) @@ -546,7 +583,7 @@ def _send_grads(self, tag: str) -> None: channel._selector._cleanup_send_ends() # Optimization: Perform GC and CUDA memory cleanup after sending gradients. - # This moves the "stop the world" synchronous flushes out of the measured + # This moves the "stop the world" synchronous flushes out of the measured # training/evaluation phases and into the idle time between rounds. 
gc.collect() torch.cuda.empty_cache() @@ -688,7 +725,7 @@ def send_heartbeat_to_agg(self) -> None: # #### ADDED OORT RELATED FUNCTIONALITY def init_oort_variables(self) -> None: """Initialize Oort variables.""" - self._stat_utility = 0 + self.partial_stat_utility = 0 self._batch_size = 0 if "reduction" not in inspect.signature(self.loss_fn).parameters: @@ -722,7 +759,7 @@ def oort_loss( loss_list = criterion(output, target) self._batch_size = len(loss_list) logger.debug(f"batch size: {len(loss_list)}") - self._stat_utility += torch.square(loss_list).sum() + self.partial_stat_utility += torch.square(loss_list).sum() if reduction == "mean": loss = loss_list.mean() @@ -737,13 +774,13 @@ def normalize_stat_utility(self, epoch) -> None: """ # incase of oort - stat utility is calculated only at the beginning (epoch = 0, batch = 0) # but in fwdllm, we want to calculate it with every update - self._stat_utility = self._batch_size * math.sqrt( - self._stat_utility / self._batch_size + self.partial_stat_utility = self._batch_size * math.sqrt( + self.partial_stat_utility / self._batch_size ) def reset_stat_utility(self) -> None: """Reset the trainer's statistical utility to zero.""" - self._stat_utility = 0 + self.partial_stat_utility = 0 @timer_decorator def pause_execution(self): diff --git a/lib/python/flame/mode/message.py b/lib/python/flame/mode/message.py index c24c189a5..5689b22e7 100644 --- a/lib/python/flame/mode/message.py +++ b/lib/python/flame/mode/message.py @@ -40,6 +40,10 @@ class MessageType(Enum): STAT_UTILITY = 11 # measured utility of a trainer based on Oort + FULL_DATASET_STAT_UTILITY = 32 # measured utility of a trainer over its entire local dataset + PARTIAL_DATASET_STAT_UTILITY = 33 # measured utility of a trainer over the data it trained on + REQUIRED_STAT_UTILITIES = 34 # list of stat utilities that the trainer is requested to compute + COORDINATED_ENDS = 12 # ends coordinated by a coordinator DATASAMPLER_METADATA = 13 # datasampler metadata diff 
--git a/lib/python/flame/optimizer/fedbuff.py b/lib/python/flame/optimizer/fedbuff.py index 862ec3d92..b7a6055f9 100644 --- a/lib/python/flame/optimizer/fedbuff.py +++ b/lib/python/flame/optimizer/fedbuff.py @@ -96,6 +96,9 @@ def beta_polynomial_upshift(self, loss, b_exp): def beta_exponential(self, loss, b_exp): return 1 - np.exp(-b_exp * loss) + def beta_exponential_upshift(self, loss, b_exp): + return 1 - np.exp(-b_exp * loss) + 0.25 + def beta_exponential_custom(self, loss, b_exp): decay_constant = 500 / math.log(2) # Adjusting the decay constant return math.exp(-loss / decay_constant) @@ -125,6 +128,8 @@ def weight_factor( beta = self.beta_polynomial_upshift(loss, b_exp) elif beta_type == "exponential_custom": beta = self.beta_exponential_custom(loss, b_exp) + elif beta_type == "exponential_upshift": + beta = self.beta_exponential_upshift(loss, b_exp) else: raise ValueError("Invalid beta type") @@ -173,6 +178,10 @@ def do( rate = 1 / math.sqrt(1 + version - tres.version) elif self.agg_rate_conf["type"] == "new": + if getattr(tres, "stat_utility", None) is None: + logger.error("FedBuff optimizer configured for 'new' agg_rate expects 'stat_utility' from the trainer, but it is missing.") + raise ValueError("FedBuff optimizer expected 'stat_utility' in message from Trainer but got None. 
Check aggregator / trainer config.") + # New rate that trades off staleness and statistical # utility diff --git a/lib/python/flame/selector/async_oort.py b/lib/python/flame/selector/async_oort.py index 5e47bb8a8..c2bbde777 100644 --- a/lib/python/flame/selector/async_oort.py +++ b/lib/python/flame/selector/async_oort.py @@ -44,6 +44,8 @@ PROP_SELECTED_COUNT = "selected_count" PROP_ROUND_START_TIME = "round_start_time" PROP_ROUND_DURATION = "round_duration" +PROP_FULL_DATASET_STAT_UTILITY = "full_dataset_stat_utility" +PROP_PARTIAL_DATASET_STAT_UTILITY = "partial_dataset_stat_utility" PROP_STAT_UTILITY = "stat_utility" PROP_DATASET_SIZE = "dataset_size" PROP_UPDATE_COUNT = "update_count" @@ -119,9 +121,15 @@ def __init__(self, **kwargs): self.exploration_factor = 0.9 self.exploration_factor_decay = 0.98 - self.min_exploration_factor = 0.2 - + self.min_exploration_factor = 0.1 self.exploitation_util_history = [] + + stat_utility_config = kwargs.get("stat_utility", "partial") + self._target_stat_utility_prop = ( + PROP_FULL_DATASET_STAT_UTILITY + if stat_utility_config == "full" + else PROP_PARTIAL_DATASET_STAT_UTILITY + ) # Assuming a max round duration of 99999 seconds (~1.2 days) self.round_preferred_duration = timedelta(seconds=99999) @@ -328,7 +336,9 @@ def select( self._select_run_counter += 1 for selected_end_id in results.keys(): - end_stat_util = ends[selected_end_id].get_property(PROP_STAT_UTILITY) + end_stat_util = ends[selected_end_id].get_property( + self._target_stat_utility_prop + ) end_speed = ends[selected_end_id].get_property(PROP_ROUND_DURATION) end_last_round = ends[selected_end_id].get_property( PROP_LAST_EVAL_ROUND @@ -514,7 +524,13 @@ def fetch_statistical_utility( if (end_id not in blocklist_end_ids) and ( end_id not in trainer_unavail_list ): - end_utility = ends[end_id].get_property(PROP_STAT_UTILITY) + end_utility = ends[end_id].get_property(self._target_stat_utility_prop) + has_trained = ends[end_id].get_property(PROP_ROUND_DURATION) is 
not None + + if has_trained and end_utility is None: + logger.error(f"Selector expected {self._target_stat_utility_prop} for end {end_id} but it was missing.") + raise ValueError(f"Crucial state gap: Expected {self._target_stat_utility_prop} for {end_id} in selector, but got None.") + if end_utility is not None: utility_list.append( {PROP_END_ID: end_id, PROP_UTILITY: end_utility} @@ -664,7 +680,7 @@ def save_exploited_utility_history( exploited_utility = 0 for exploit_end_id in exploit_end_ids: exploited_utility += ends[exploit_end_id].get_property( - PROP_STAT_UTILITY + self._target_stat_utility_prop ) exploited_utility /= len(exploit_end_ids) self.exploitation_util_history.append(exploited_utility) @@ -1331,7 +1347,7 @@ def _handle_send_state( # all the ends for end_id, end in ends.items(): logger.debug( - f"End ID: {end_id}, Last Eval Round: {end.get_property(PROP_LAST_EVAL_ROUND)}, Statistical Utility: {end.get_property(PROP_STAT_UTILITY)}" + f"End ID: {end_id}, Last Eval Round: {end.get_property(PROP_LAST_EVAL_ROUND)}, Statistical Utility: {end.get_property(self._target_stat_utility_prop)}" ) # NOTE: (DG) Assuming that shuffled_end_ids is not needed @@ -1836,7 +1852,10 @@ def _handle_recv_state( curr_end_state = end.get_property(KEY_END_STATE) # candidates[end_id] = end if end_id not in self.all_selected.keys(): - if curr_end_state != VAL_END_STATE_NONE: + if ( + curr_end_state is not None # TODO: (GD) revert this + and curr_end_state != VAL_END_STATE_NONE + ): logging.info( f"end_id {end_id} not in all_selected and in state: {curr_end_state}, adding " f"to candidates: key {end_id}, val: {end}" diff --git a/lib/python/flame/selector/oort.py b/lib/python/flame/selector/oort.py index 952b158c0..fef8e9bee 100644 --- a/lib/python/flame/selector/oort.py +++ b/lib/python/flame/selector/oort.py @@ -35,6 +35,8 @@ PROP_SELECTED_COUNT = "selected_count" PROP_ROUND_START_TIME = "round_start_time" PROP_ROUND_DURATION = "round_duration" +PROP_PARTIAL_DATASET_STAT_UTILITY 
= "partial_dataset_stat_utility" +PROP_FULL_DATASET_STAT_UTILITY = "full_dataset_stat_utility" PROP_STAT_UTILITY = "stat_utility" PROP_DATASET_SIZE = "dataset_size" PROP_UPDATE_COUNT = "update_count" @@ -332,7 +334,7 @@ def select( self._select_run_counter += 1 for selected_end_id in self.selected_ends: - end_stat_util = ends[selected_end_id].get_property(PROP_STAT_UTILITY) + end_stat_util = ends[selected_end_id].get_property(PROP_FULL_DATASET_STAT_UTILITY) end_speed = ends[selected_end_id].get_property(PROP_ROUND_DURATION) end_last_round = ends[selected_end_id].get_property(PROP_LAST_EVAL_ROUND) # Insert to queues tracking stat_util, speed, round data @@ -498,7 +500,7 @@ def fetch_statistical_utility( if (end_id not in blocklist_end_ids) and ( end_id not in trainer_unavail_list ): - end_utility = ends[end_id].get_property(PROP_STAT_UTILITY) + end_utility = ends[end_id].get_property(PROP_FULL_DATASET_STAT_UTILITY) if end_utility is not None: utility_list.append( {PROP_END_ID: end_id, PROP_UTILITY: end_utility} @@ -597,7 +599,7 @@ def save_exploited_utility_history( exploited_utility = 0 for exploit_end_id in exploit_end_ids: exploited_utility += ends[exploit_end_id].get_property( - PROP_STAT_UTILITY + PROP_FULL_DATASET_STAT_UTILITY ) exploited_utility /= len(exploit_end_ids) self.exploitation_util_history.append(exploited_utility) diff --git a/scripts/benchmark_fwdllm.py b/scripts/benchmark_fwdllm.py index 6f7675054..52e749b14 100644 --- a/scripts/benchmark_fwdllm.py +++ b/scripts/benchmark_fwdllm.py @@ -204,8 +204,7 @@ def main(): # Log setup logging.basicConfig( level=logging.INFO, - format="%(asctime)s.%(msecs)03d - %(funcName)s(): %(message)s", - datefmt="%Y-%m-%d,%H:%M:%S", + format="%(asctime)s | %(filename)s:%(lineno)d | %(levelname)s | %(threadName)s | %(funcName)s | %(message)s", ) config = Config(args.config) diff --git a/scripts/calculate_seq_len_cdf.py b/scripts/calculate_seq_len_cdf.py index 1db15b564..192f391e2 100644 --- 
a/scripts/calculate_seq_len_cdf.py +++ b/scripts/calculate_seq_len_cdf.py @@ -34,7 +34,7 @@ def run_analysis(data_path, partition_path, model_name, method): partition_data = partition_file[method]["partition_data"] trainer_batch_max_lengths = [] - batch_size = 8 + batch_size = 1 for client_idx in tqdm(partition_data.keys(), desc="Trainer (Partitions)"): indices = partition_data[client_idx]["train"][()] @@ -55,8 +55,8 @@ def run_analysis(data_path, partition_path, model_name, method): if __name__ == "__main__": # Parameters based on your input - DATA = "/Users/gaurav/Projects/fednlp_data/data_files/agnews_data.h5" - PARTITION = "/Users/gaurav/Projects/fednlp_data/partition_files/agnews_partition.h5" + DATA = "/coc/scratch/dgarg/fl_datasets/fwdllm/fednlp_data/data_files/agnews_data.h5" + PARTITION = "/coc/scratch/dgarg/fl_datasets/fwdllm/fednlp_data/partition_files/agnews_partition.h5" MODEL = "distilbert-base-uncased" METHOD = "uniform"