Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion lib/python/.gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,5 @@ flame.egg-info
data
wandb

examples/fwdllm/expts/run_tc_expts/mpi_host_file
examples/fwdllm/expts/run_tc_expts/mpi_host_file
lib/python/examples/fwdllm/expts/run_tc_expts/mpi_host_file
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def __init__(
)
self.trainer_event_dict = None
if (
self.track_trainer_avail["enabled"]
self.track_trainer_avail["enabled"] != None
and self.track_trainer_avail["type"] == "ORACULAR"
):
self.trainer_event_dict = self.read_trainer_unavailability(
Expand Down
8 changes: 4 additions & 4 deletions lib/python/examples/fwdllm/aggregator/fl_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,15 +83,15 @@ def post_complete_message(tc_args):
parser.add_argument("--log_level", type=str, default="INFO", required=False)
args = parser.parse_args()

logger.setLevel(args.log_level)
logger.setLevel(args.log_level.upper())
config = Config(args.config)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# customize the log format
logging.basicConfig(
level=logging._nameToLevel[args.log_level],
format="%(process)s %(asctime)s.%(msecs)03d - {%(module)s.py (%(lineno)d)} - %(funcName)s(): %(message)s",
datefmt="%Y-%m-%d,%H:%M:%S",
level=logging._nameToLevel[args.log_level.upper()],
format="%(asctime)s | %(filename)s:%(lineno)d | %(levelname)s | %(threadName)s | %(funcName)s | %(message)s",
force=True,
)
logger.info(config)
set_seed(config.hyperparameters.manual_seed)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,9 @@ def post_complete_message(tc_args):

# customize the log format
logging.basicConfig(
level=logging.INFO,
format="%(process)s %(asctime)s.%(msecs)03d - {%(module)s.py (%(lineno)d)} - %(funcName)s(): %(message)s",
datefmt="%Y-%m-%d,%H:%M:%S",
level=getattr(logging, getattr(args, "log_level", "INFO").upper(), logging.INFO),
format="%(asctime)s | %(filename)s:%(lineno)d | %(levelname)s | %(threadName)s | %(funcName)s | %(message)s",
force=True,
)
logging.info(args)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,8 @@
"roundNudgeType": "last_train",
"minInitialTrainers": 50,
"k": 10,
"is_async": true
"is_async": true,
"stat_utility": "partial"
}
},
"optimizer": {
Expand All @@ -123,11 +124,14 @@
"use_oort_lr": "True",
"dataset_name": "google-speech",
"agg_rate_conf": {
"type": "old",
"scale": 1.0,
"type": "new",
"scale": 0.0,
"a_exp": 0.2,
"b_exp": 0.25
}
"b_exp": 0.25,
"alpha_type": "polynomial",
"beta_type": "exponential"
},
"stat_utility": "partial"
}
},
"maxRunTime": 600,
Expand Down
Original file line number Diff line number Diff line change
@@ -1 +1 @@
wash.cc.gatech.edu
jayne.cc.gatech.edu
69 changes: 69 additions & 0 deletions lib/python/examples/fwdllm/extract_stat_utility.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
import sys
import re
import csv
import os

def extract_utilities(log_file, output_file):
    """Parse a FLAME training log and export stat_utility metrics to a CSV file.

    Scans every line of ``log_file`` for three known log formats
    (top aggregator, fwdllm aggregator, async-oort selector) and writes one
    CSV row per match to ``output_file``. Exits the process with status 1 if
    ``log_file`` does not exist.

    Args:
        log_file: Path to the training log to scan.
        output_file: Path of the CSV file to (over)write.
    """
    if not os.path.exists(log_file):
        print(f"Error: Log file not found at {log_file}")
        sys.exit(1)

    print(f"Extracting stat_utility metrics from {log_file}...")

    # Regex patterns for matching top_aggregator and fwdllm_aggregator logs.
    # BUG FIX: top_agg_pattern was commented out while still being referenced
    # in the loop below, which raised a NameError on the very first line.
    top_agg_pattern = re.compile(r"Received weights from (\S+)\. It was trained on model version (\d+), with (\d+) samples\. Returned partial stat utility ([\d\.]+) and full stat utility ([\d\.]+)")
    fwdllm_agg_pattern = re.compile(r"Aggregated utilities for (\S+)\. Partial stat utility used for FedBuff: ([\d\.]+)\. Full stat utility stored for Oort: ([\d\.]+)")

    # Regex for async_oort selector debug logs (if enabled)
    oort_selector_pattern = re.compile(r"Trainer (\S+) full_dataset_stat_utility: unnormalized = ([\d\.]+), normalized = ([\d\.]+), partial_dataset_stat_utility: ([\d\.]+)")

    results = []

    with open(log_file, 'r') as f:
        for line in f:
            # Match top aggregator logs
            m_top = top_agg_pattern.search(line)
            if m_top:
                tid = m_top.group(1)
                version = m_top.group(2)
                samples = m_top.group(3)
                partial_util = m_top.group(4)
                full_util = m_top.group(5)
                results.append([tid, version, samples, partial_util, full_util, "N/A", "top_aggregator"])
                continue

            # Match fwdllm aggregator logs
            m_fwd = fwdllm_agg_pattern.search(line)
            if m_fwd:
                tid = m_fwd.group(1)
                partial_util = m_fwd.group(2)
                full_util = m_fwd.group(3)
                results.append([tid, "N/A", "N/A", partial_util, full_util, "N/A", "fwdllm_aggregator"])
                continue

            # Match async oort selector logs
            m_oort = oort_selector_pattern.search(line)
            if m_oort:
                tid = m_oort.group(1)
                unnorm_full = m_oort.group(2)
                norm_full = m_oort.group(3)
                partial_util = m_oort.group(4)
                results.append([tid, "N/A", "N/A", partial_util, unnorm_full, norm_full, "async_oort_selector"])

    # Write results to CSV (header row first, then one row per match).
    with open(output_file, 'w', newline='') as out_csv:
        writer = csv.writer(out_csv)
        writer.writerow(["TrainerID", "ModelVersion", "Samples", "PartialStatUtility", "FullStatUtility_Unnormalized", "FullStatUtility_Normalized", "Source"])
        writer.writerows(results)

    print(f"Done! {len(results)} records extracted. Output saved to {output_file}")

if __name__ == "__main__":
if len(sys.argv) < 2:
print(f"Usage: python {sys.argv[0]} <path_to_log_file> [output_csv_file]")
sys.exit(1)

log_path = sys.argv[1]
out_path = sys.argv[2] if len(sys.argv) > 2 else "stat_utility_comparison.csv"

extract_utilities(log_path, out_path)
24 changes: 24 additions & 0 deletions lib/python/examples/fwdllm/extract_stat_utility.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
#!/bin/bash

# Thin wrapper around extract_stat_utility.py: parses a FLAME training log
# and writes the extracted stat_utility metrics to a CSV file.

log_path="$1"
csv_path="${2:-stat_utility_comparison.csv}"

# A log file argument is mandatory.
if [ -z "$log_path" ]; then
    echo "Usage: $0 <path_to_log_file> [output_csv_file]"
    exit 1
fi

# Bail out early if the given log file does not exist.
if [ ! -f "$log_path" ]; then
    echo "Error: Log file not found at $log_path"
    exit 1
fi

# Resolve the directory containing this script so the python helper is
# found regardless of the caller's working directory.
here="$( cd "$( dirname "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"

python "$here/extract_stat_utility.py" "$log_path" "$csv_path"
6 changes: 3 additions & 3 deletions lib/python/examples/fwdllm/trainer/fl_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,9 +77,9 @@ def post_complete_message(tc_args):

# customize the log format
logging.basicConfig(
level=logging._nameToLevel[args.log_level],
format="%(process)s %(asctime)s.%(msecs)03d - {%(module)s.py (%(lineno)d)} - %(funcName)s(): %(message)s",
datefmt="%Y-%m-%d,%H:%M:%S",
level=logging._nameToLevel[args.log_level.upper()],
format="%(asctime)s | %(filename)s:%(lineno)d | %(levelname)s | %(threadName)s | %(funcName)s | %(message)s",
force=True,
)
logging.debug(config)
set_seed(config.hyperparameters.manual_seed)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,7 @@ def __init__(
self.params = None
self.buffers = None
self.grad_for_var_check = None
self.required_stat_utilities = []

# def initialize(self) -> None: """Initialize role.""" self.device =
# torch.device("cuda" if torch.cuda.is_available() else "cpu")
Expand Down Expand Up @@ -380,6 +381,10 @@ def _select_optimal_perturbations(device, logging_state):
def _train_one_batch(self, device, batch, epoch, batch_idx, v_buffer):
@timer_decorator
def _compute_batch_stat_utility(device, x, labels):
if "partial" not in self.required_stat_utilities:
if batch_idx == 0 and epoch == 0:
logging.info("Skipping partial stat_utility computation as per config.")
return

with torch.no_grad():
pred = self.model(x)
Expand All @@ -391,7 +396,7 @@ def _compute_batch_stat_utility(device, x, labels):
logits = pred
loss = self.base_trainer.oort_loss(logits, labels.view(-1), epoch=0, batch_idx=0, reduction="mean")
# Optimization: Lazy logging & removed .item() to avoid Host-device sync
logging.debug("stat_utility for trainerId: %s is %s, loss: %s", self.trainer_id, self.base_trainer._stat_utility, loss.mean())
logging.debug("partial_stat_utility for trainerId: %s is %s, loss: %s", self.trainer_id, self.base_trainer.partial_stat_utility, loss.mean())

@timer_decorator
def _prepare_perturbation_tensors(device, v_buffer):
Expand Down Expand Up @@ -488,15 +493,70 @@ def _accumulate_and_extract_grads(device, jvp, v_params):
# Optimization: Remove GC & buffer flushes from the batch loop
# self._force_cuda_memory_cleanup(device, f"epoch{epoch}_batch{batch_idx}_end")

self.base_trainer.normalize_stat_utility(epoch)
logging.debug(
f"stat_utility - normalized for trainerId: {self.trainer_id} = {self.base_trainer._stat_utility}"
)
if "partial" in self.required_stat_utilities:
self.base_trainer.normalize_stat_utility(epoch)
logging.debug(
f"stat_utility - normalized for trainerId: {self.trainer_id} = {self.base_trainer.partial_stat_utility}"
)

del x, labels, jvp, v_params
# self._force_cuda_memory_cleanup(device, f"epoch{epoch}_batch{batch_idx}_end")
return loss

@timer_decorator
def calculate_full_dataset_stat_utility(self, all_data_bins, device=None):
if not device:
device = self.device

self.model.to(device)
self.model.eval()

total_squared_loss = 0.0
total_samples = 0

# We need the original loss function without reduction
if hasattr(self, 'base_trainer') and self.base_trainer is not None:
criterion = self.base_trainer.loss_fn(reduction="none", **{k: v for k, v in self.base_trainer.config.hyperparameters.__dict__.items() if k in self.base_trainer.loss_fn.__init__.__code__.co_varnames and k != 'reduction'})
else:
criterion = CrossEntropyLoss(reduction="none")

from torch.cuda.amp import autocast
autocast_cm = autocast() if self.args.fp16 else contextlib.nullcontext()

with torch.no_grad(), autocast_cm:
for data_bin in all_data_bins:
# We expect data_bin to be a DataLoader.
# If we need to force eval_batch_size, we might need a workaround, but typically
# for these evaluation passes we can just iterate over the existing batches.
# If mem is constrained, we should rely on the user's config batch size.
for batch in data_bin:
x = batch[1].to(device, non_blocking=True)
labels = batch[4].to(device, non_blocking=True)

pred = self.model(x)
if hasattr(pred, "logits"):
logits = pred.logits
elif isinstance(pred, (tuple, list)):
logits = pred[0]
else:
logits = pred

loss_list = criterion(logits.view(-1, self.num_labels), labels.view(-1))

total_squared_loss += torch.square(loss_list).sum().item()
total_samples += len(loss_list)

del x, labels, pred, logits, loss_list

# Calculate full stat utility: sqrt(N * sum(loss^2))
if total_samples > 0:
full_stat_utility = math.sqrt(total_samples * total_squared_loss)
else:
full_stat_utility = 0.0

logging.info(f"full_dataset_stat_utility for trainerId: {self.trainer_id} is {full_stat_utility} over {total_samples} samples")
return full_stat_utility

@timer_decorator
def _training_loop(self, device, v_buffer):
global_step = 0
Expand Down
26 changes: 19 additions & 7 deletions lib/python/flame/mode/horizontal/asyncfl/top_aggregator.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@
PROP_LAST_EVAL_ROUND,
PROP_ROUND_DURATION,
PROP_ROUND_START_TIME,
PROP_PARTIAL_DATASET_STAT_UTILITY,
PROP_FULL_DATASET_STAT_UTILITY,
PROP_STAT_UTILITY,
PROP_UPDATE_COUNT,
)
Expand Down Expand Up @@ -466,21 +468,31 @@ def _aggregate_weights(self, tag: str) -> None:
if MessageType.MODEL_VERSION in msg:
version = msg[MessageType.MODEL_VERSION]

stat_utility = 0 # default
if MessageType.STAT_UTILITY in msg:
partial_stat_utility = 0 # default

if MessageType.PARTIAL_DATASET_STAT_UTILITY in msg:
channel.set_end_property(
end, PROP_STAT_UTILITY, msg[MessageType.STAT_UTILITY]
end, PROP_PARTIAL_DATASET_STAT_UTILITY, msg[MessageType.PARTIAL_DATASET_STAT_UTILITY]
)
# Extracted directly from the message since fedbuff uses this
partial_stat_utility = msg[MessageType.PARTIAL_DATASET_STAT_UTILITY]

if MessageType.FULL_DATASET_STAT_UTILITY in msg:
channel.set_end_property(
end,
PROP_FULL_DATASET_STAT_UTILITY,
msg[MessageType.FULL_DATASET_STAT_UTILITY],
)
stat_utility = msg[MessageType.STAT_UTILITY]

logger.info(
f"Received weights from {end}. It was trained on model version {version}, with {count} samples. Returned stat utility {stat_utility}"
f"Received weights from {end}. It was trained on model version {version}, with {count} samples. "
f"Returned partial stat utility {partial_stat_utility} and full stat utility {msg.get(MessageType.FULL_DATASET_STAT_UTILITY, 0)}"
)

if (
weights is not None and count > 0
): # SC_TS: count = 0 means no data (it was trained on!), so ignore!
tres = TrainResult(weights, count, version, stat_utility)
tres = TrainResult(weights, count, version, partial_stat_utility)
# save training result from trainer in a disk cache
self.cache[end] = tres
logger.debug(f"received {len(self.cache)} trainer updates in cache")
Expand All @@ -491,7 +503,7 @@ def _aggregate_weights(self, tag: str) -> None:

# Populate round statistics vars
self._round_update_values["staleness"].append(update_staleness_val)
self._round_update_values["stat_utility"].append(stat_utility)
self._round_update_values["stat_utility"].append(partial_stat_utility)
self._round_update_values["trainer_speed"].append(
channel.get_end_property(
end_id=end, key=PROP_ROUND_DURATION
Expand Down
Loading