171 commits
8c5497b
Implement TimedGRPOTrainer to log roll-out batch durations
mdvillagra Jul 22, 2025
e91d56a
Add venv to .gitignore to exclude virtual environment files
mdvillagra Jul 22, 2025
b990ba2
Add timing logs for set_model_params in FullModelLLMTrainer and FullM…
mdvillagra Jul 22, 2025
26f9668
Update run_fedml_client_custom.sh and run_fedml_server_custom.sh to i…
mdvillagra Jul 22, 2025
74a395f
Add periodic checkpointing and per-round checkpoint configuration to …
mdvillagra Jul 22, 2025
6c3ea9c
Remove commented-out code blocks in FullModelLLMTrainer for clarity
mdvillagra Jul 22, 2025
499104a
Add logging for global update frequency in FedMLServerManager
mdvillagra Jul 23, 2025
a2e576e
Add Nesterov momentum support in FullModelLLMAggregator and update lo…
mdvillagra Jul 23, 2025
579dd31
Refactor checkpoint saving logic in FullModelLLMTrainer and FullModel…
mdvillagra Jul 23, 2025
5d885ae
Merge remote changes from origin/marcos/benchmarks
mdvillagra Jul 23, 2025
b2c987e
Add RewardFunction class for evaluating model responses with correctn…
mdvillagra Jul 23, 2025
37ad46e
Update logging format in FedMLServerManager to include total rounds a…
mdvillagra Jul 23, 2025
b5fb686
Refactor correctness_reward and combined_reward methods in RewardFunc…
mdvillagra Jul 23, 2025
91642a4
Remove the format_reward method from the RewardFunction class to stre…
mdvillagra Jul 23, 2025
eaa7344
Increase max_new_tokens in FullModelLLMTrainer from 512 to 1024 for i…
mdvillagra Jul 23, 2025
afc697c
Add report_to parameter for Weights & Biases integration in FullModel…
mdvillagra Jul 23, 2025
99434ed
Refactor reward function usage in FullModelLLMTrainer to utilize rewa…
mdvillagra Jul 24, 2025
ee38a16
Refactor reward function parameters in FullModelLLMTrainer to enhance…
mdvillagra Jul 24, 2025
f4b4f1d
Update gradient accumulation steps and batch sizes in grpo_gsm8k_test…
mdvillagra Jul 24, 2025
740eed7
Update grpo_batch_size in grpo_gsm8k_test_config.yaml from 32 to 2 fo…
mdvillagra Jul 24, 2025
cd16d02
Update FullModelLLMTrainer to correctly handle model state dict loadi…
mdvillagra Jul 24, 2025
f8261bb
Refactor state dict loading in FullModelLLMTrainer to correctly acces…
mdvillagra Jul 24, 2025
c7848a3
Enhance checkpoint saving in FullModelLLMAggregator to utilize Huggin…
mdvillagra Jul 24, 2025
8732a29
Update communication rounds in grpo_gsm8k_test_config.yaml from 3 to …
mdvillagra Jul 24, 2025
85220ca
Add debugging breakpoint in reward function of FullModelLLMTrainer fo…
mdvillagra Jul 24, 2025
dd8a896
Update client configuration in grpo_gsm8k_test_config.yaml to use a s…
mdvillagra Jul 24, 2025
304b809
Enhance debugging in reward function of FullModelLLMTrainer by adding…
mdvillagra Jul 24, 2025
97e3a90
Add methods to handle boxed content and convert strings to numbers in…
mdvillagra Jul 24, 2025
09291ee
Fix reward function in FullModelLLMTrainer by updating variable names…
mdvillagra Jul 24, 2025
8c1ed2d
Update grpo_gsm8k_test_config.yaml to adjust client configuration and…
mdvillagra Jul 24, 2025
7d9a745
Refactor docstring in FullModelLLMAggregator's aggregate method for c…
mdvillagra Jul 24, 2025
3042670
Implement KL divergence logging in TimedGRPOTrainer and adjust learni…
mdvillagra Jul 25, 2025
0b28e4d
Update generation parameters and add scale_rewards option in FullMode…
mdvillagra Jul 25, 2025
3c9bff7
Adjust temperature parameter in generation settings of FullModelLLMTr…
mdvillagra Jul 25, 2025
d0d5c5a
Adjust top_p parameter in generation settings of FullModelLLMTrainer …
mdvillagra Jul 25, 2025
0700f72
Refactor generation parameters in FullModelLLMTrainer by adjusting te…
mdvillagra Jul 25, 2025
01aae96
Update learning rate in FullModelLLMTrainer to 5e-6 and uncomment sca…
mdvillagra Jul 25, 2025
4da64e2
Update generation parameters in FullModelLLMTrainer to include temper…
mdvillagra Jul 25, 2025
682fb55
Adjust max_completion_length in FullModelLLMTrainer to 100 and add ep…
mdvillagra Jul 25, 2025
850cdc2
Refactor custom trainer to integrate TrainingMetricsLogger and GRPOMe…
mdvillagra Jul 25, 2025
035704b
Remove wandb_entity parameter from FullModelLLMTrainer configuration …
mdvillagra Jul 25, 2025
f90c4ec
Increase max_completion_length in FullModelLLMTrainer from 100 to 256…
mdvillagra Jul 25, 2025
f9d0833
Comment out report_to parameter in FullModelLLMTrainer to disable Wei…
mdvillagra Jul 25, 2025
91eda1d
Update FullModelLLMTrainer to dynamically set run_name based on clien…
mdvillagra Jul 25, 2025
1a478c4
Update FullModelLLMTrainer to modify run_name format and change wandb…
mdvillagra Jul 25, 2025
5e53db2
Update FullModelLLMTrainer and grpo_gsm8k_test_config.yaml to enhance…
mdvillagra Jul 25, 2025
6791c4a
Update grpo_gsm8k_test_config.yaml to increase comm_round from 3 to 1…
mdvillagra Jul 25, 2025
2bd0c73
Enhance FullModelLLMAggregator with WandB logging for server statisti…
mdvillagra Jul 25, 2025
086c4bb
Update FullModelLLMTrainer to set logging_steps to 1 for more frequen…
mdvillagra Jul 25, 2025
4ac8dd0
Refactor TrainingMetricsLogger to use instance method for moving aver…
mdvillagra Jul 25, 2025
f2ccd9c
Enhance TrainingMetricsLogger and FullModelLLMAggregator with improve…
mdvillagra Jul 25, 2025
8692e47
Update FullModelLLMTrainer to disable log_completions for cleaner log…
mdvillagra Jul 26, 2025
8c57b88
Add checkpoint cleanup functionality to FullModelLLMTrainer and FullM…
mdvillagra Jul 26, 2025
262f45e
Update FullModelLLMTrainer to use a time-based seed for reproducibili…
mdvillagra Jul 26, 2025
af28fb6
Fix seed assignment in FullModelLLMTrainer to ensure proper formattin…
mdvillagra Jul 26, 2025
1adc9ed
Refactor aggregate method documentation in FullModelLLMAggregator to …
mdvillagra Jul 27, 2025
77222e1
Add method to cleanup old round checkpoints in FullModelLLMAggregator
mdvillagra Jul 27, 2025
09fb8b6
Add evaluation script for Qwen3-0.6B on GSM8K test split
mdvillagra Jul 27, 2025
daef416
Refactor aggregate method documentation in FullModelLLMAggregator and…
mdvillagra Jul 27, 2025
1da9db9
Update model name in save_initial_checkpoint.py from Qwen3-0.6B to Qw…
mdvillagra Jul 27, 2025
cf97042
Adjust GRPO batch size in grpo_gsm8k_test_config.yaml from 4 to 2 for…
mdvillagra Jul 27, 2025
8ee82b4
Update gradient_checkpointing settings in FullModelLLMTrainer and grp…
mdvillagra Jul 27, 2025
852d836
Add optimizer configuration in FullModelLLMTrainer
mdvillagra Jul 27, 2025
36bae8c
Add initial checkpoint saving in run_fedml_server_custom.sh
mdvillagra Jul 27, 2025
fed4829
Update training configuration and model parameters in FullModelLLMTra…
mdvillagra Jul 27, 2025
5870af8
Enhance custom trainer and configuration for improved logging and per…
mdvillagra Jul 27, 2025
5f4fd91
Reorganize environment variable settings for HF Transformers in custo…
mdvillagra Jul 27, 2025
c4bfc95
Disable flash attention in grpo_gsm8k_test_config.yaml to revert to s…
mdvillagra Jul 27, 2025
ad6c779
Enhance model loading and tokenizer initialization in custom trainer …
mdvillagra Jul 27, 2025
b672c1a
Add warnings filter in custom trainer to suppress advisory messages
mdvillagra Jul 27, 2025
a6930a9
Add gradient check for NaN/Inf values in training process
mdvillagra Jul 27, 2025
4f97311
Refactor training configuration and model handling in custom trainer
mdvillagra Jul 28, 2025
79b980f
Refactor checkpoint saving logic in run_fedllm.py
mdvillagra Jul 28, 2025
7667d99
Refactor checkpoint saving logic in run_fedllm.py
mdvillagra Jul 28, 2025
4481c0d
Enhance checkpoint saving logic in run_fedllm_custom.py
mdvillagra Jul 28, 2025
6b11954
Update GRPO configuration in grpo_gsm8k_test_config.yaml for testing
mdvillagra Jul 28, 2025
c358d92
Ensure model parameters are on CPU and clear CUDA cache in FullModelL…
mdvillagra Jul 28, 2025
134196d
Enhance timing and logging in TimedGRPOTrainer and FullModelLLMTrainer
mdvillagra Jul 29, 2025
56a701b
Add evaluation script for Qwen3-0.6B on GSM8K and enhance logging in …
mdvillagra Jul 29, 2025
a3a3240
Update GRPO configuration in grpo_gsm8k_test_config.yaml to switch fp…
mdvillagra Jul 29, 2025
ba8ac29
Add checkpoint cleanup functionality in FullModelLLMAggregator
mdvillagra Jul 29, 2025
c2311f6
Refactor average completion time logging in TimedGRPOTrainer
mdvillagra Jul 29, 2025
675dae5
Comment out the 'optim' parameter in FullModelLLMTrainer's GRPO confi…
mdvillagra Jul 29, 2025
078b48b
Enhance average completion time tracking in TrainingMetricsLogger
mdvillagra Jul 29, 2025
bced90b
Update model configuration and enhance logging in TrainingMetricsLogger
mdvillagra Jul 29, 2025
e19acfd
Update wallclock checkpoint retention policy in FullModelLLMAggregator
mdvillagra Jul 29, 2025
e8553c9
Enhance logging of average completion time in TimedGRPOTrainer
mdvillagra Jul 29, 2025
547cb65
Improve formatting of average completion time log in TimedGRPOTrainer
mdvillagra Jul 29, 2025
ba40e91
Update GRPO configuration and model training parameters
mdvillagra Jul 29, 2025
6c8c7ec
Refactor experience generation method in TimedGRPOTrainer
mdvillagra Jul 29, 2025
edf407d
Update model parameters and configurations for training optimization
mdvillagra Jul 29, 2025
2d8c1ab
Update model configurations and training parameters for improved perf…
mdvillagra Jul 30, 2025
b3e564c
Adjust max completion length and new token limit in FullModelLLMTrain…
mdvillagra Jul 30, 2025
eb2d32a
Enhance validation and memory management in training process
mdvillagra Jul 30, 2025
415cbcb
Update model configurations and training parameters for consistency a…
mdvillagra Jul 30, 2025
a2b5b98
Update model name in configuration files for consistency
mdvillagra Jul 30, 2025
2a68cb0
Update training parameters and memory management in FullModelLLMTrainer
mdvillagra Jul 30, 2025
a0ca721
Merge remote-tracking branch 'refs/remotes/origin/marcos/benchmarks' …
mdvillagra Jul 30, 2025
411b5ed
Update model configurations and training parameters for consistency a…
mdvillagra Jul 30, 2025
9145989
Update model configurations and training parameters for consistency
mdvillagra Jul 30, 2025
deef150
Update max completion length and new token limit in FullModelLLMTrain…
mdvillagra Jul 30, 2025
7272908
Merge remote-tracking branch 'refs/remotes/origin/marcos/benchmarks' …
mdvillagra Jul 30, 2025
fe57962
Comment out optim parameter in FullModelLLMTrainer to disable 8-bit A…
mdvillagra Jul 30, 2025
7bd11ba
Update model configurations and training parameters for improved perf…
mdvillagra Jul 30, 2025
df96bf0
Update client configuration parameters in grpo_gsm8k_test_config.yaml…
mdvillagra Jul 30, 2025
6e44080
Enhance TimedGRPOTrainer initialization and update model configurations
mdvillagra Jul 31, 2025
a097481
Update broadcast_object_list call in FullModelLLMTrainer to specify d…
mdvillagra Jul 31, 2025
a14707e
Refactor FullModelLLMTrainer to use float16 and enhance model paramet…
mdvillagra Jul 31, 2025
c7e89e7
Refactor broadcast_object_list call in FullModelLLMTrainer to remove …
mdvillagra Jul 31, 2025
7f02a25
Comment out CPU transfer for reference model in TimedGRPOTrainer and …
mdvillagra Jul 31, 2025
e9108d4
Refactor model initialization in TimedGRPOTrainer to retrieve model_i…
mdvillagra Jul 31, 2025
e80eedb
Update reference model configuration in TimedGRPOTrainer to use GPTQ-…
mdvillagra Jul 31, 2025
f5cc03c
Add docstring to TimedGRPOTrainer class for improved documentation
mdvillagra Jul 31, 2025
1b3500d
Enhance TimedGRPOTrainer with dropout control and reference model syn…
mdvillagra Jul 31, 2025
f457548
Update beta parameter in FullModelLLMTrainer to improve model perform…
mdvillagra Jul 31, 2025
4bf2a06
Update client configuration in grpo_gsm8k_test_config.yaml to support…
mdvillagra Aug 1, 2025
2ab2835
Adjust max completion length and new tokens in FullModelLLMTrainer fo…
mdvillagra Aug 1, 2025
403cac4
Update optimizer in FullModelLLMTrainer to use paged_adamw_8bit for e…
mdvillagra Aug 1, 2025
fc9f3ba
Enhance memory management in FullModelLLMTrainer by adding garbage co…
mdvillagra Aug 1, 2025
d9d6bed
Enhance TimedGRPOTrainer with fallback utilities and dropout management
mdvillagra Aug 1, 2025
315188c
Update optimizer and clean up memory management in FullModelLLMTrainer
mdvillagra Aug 1, 2025
cd0b61d
Comment out garbage collection import in custom_trainer.py to streaml…
mdvillagra Aug 1, 2025
a03174e
Update optimizer in FullModelLLMTrainer to paged_lion_8bit for improv…
mdvillagra Aug 1, 2025
cf6ecb2
Fix syntax error in optimizer assignment in FullModelLLMTrainer
mdvillagra Aug 1, 2025
9e1291f
Update gradient checkpointing settings in FullModelLLMTrainer and con…
mdvillagra Aug 1, 2025
841b6f0
Update gradient checkpointing settings in FullModelLLMTrainer and con…
mdvillagra Aug 1, 2025
e6c0ebb
Update max completion length and batch size in FullModelLLMTrainer an…
mdvillagra Aug 1, 2025
7a9ada3
Update max completion length and batch size in FullModelLLMTrainer an…
mdvillagra Aug 1, 2025
25c894a
Update max completion length and logging settings in FullModelLLMTrainer
mdvillagra Aug 1, 2025
3a56c1b
Update GRPO configuration for testing with increased epochs and batch…
mdvillagra Aug 1, 2025
28db52f
Update GRPO configuration in FullModelLLMTrainer for enhanced perform…
mdvillagra Aug 1, 2025
f0555a8
Update GRPO configuration for multi-client setup in grpo_gsm8k_test_c…
mdvillagra Aug 1, 2025
e84b11f
Update model configuration in GRPO test files to use Qwen3-0.6B
mdvillagra Aug 1, 2025
9a8ce66
Update model configuration and logging settings for GRPO testing
mdvillagra Aug 1, 2025
ba26160
Enhance validation script for custom model weights and output handling
mdvillagra Aug 2, 2025
5f19d21
Add paired permutation test script for model reward evaluation
mdvillagra Aug 2, 2025
655bf9e
Update GRPO configuration for reduced batch size and completion length
mdvillagra Aug 2, 2025
77cd14b
Enable SGD optimization in GRPO configuration for FullModelLLMTrainer
mdvillagra Aug 2, 2025
5622116
Reduce max completion length and new tokens in FullModelLLMTrainer fr…
mdvillagra Aug 2, 2025
aca1d32
Reduce max completion length and new tokens in FullModelLLMTrainer fr…
mdvillagra Aug 2, 2025
140fbda
Update max completion length and new tokens in FullModelLLMTrainer fr…
mdvillagra Aug 2, 2025
52f6122
Increase max completion length and new tokens in FullModelLLMTrainer …
mdvillagra Aug 4, 2025
133bb78
Refactor reference model loading in TimedGRPOTrainer to use AutoModel…
mdvillagra Aug 4, 2025
83f7051
Move reference model to the same device as the policy in TimedGRPOTra…
mdvillagra Aug 4, 2025
8da9d85
Update reference model in TimedGRPOTrainer to use Qwen/Qwen3-0.6B for…
mdvillagra Aug 4, 2025
0417ee8
Update reference model in TimedGRPOTrainer to use Qwen/Qwen3-0.6B-GPT…
mdvillagra Aug 4, 2025
2329b2c
Reduce max completion length and new tokens in FullModelLLMTrainer fr…
mdvillagra Aug 4, 2025
f5344ad
Increase max completion length and new tokens in FullModelLLMTrainer …
mdvillagra Aug 4, 2025
9d3c905
Remove deprecated SyncRefModelCallback implementation from custom_tra…
mdvillagra Aug 5, 2025
efd2e18
Remove fallback stub for prepare_fsdp in custom_trainer.py to streaml…
mdvillagra Aug 5, 2025
48ec150
Update reference model in TimedGRPOTrainer to use Qwen/Qwen3-1.7B for…
mdvillagra Aug 5, 2025
44e075f
Add device compatibility in TimedGRPOTrainer by moving batch tensors …
mdvillagra Aug 5, 2025
c08bc3b
Enhance batch handling in TimedGRPOTrainer by adding support for sing…
mdvillagra Aug 5, 2025
bfcf77c
Refactor batch tensor handling in TimedGRPOTrainer by removing unnece…
mdvillagra Aug 5, 2025
69194e9
Enhance tensor handling in TimedGRPOTrainer by ensuring log probabili…
mdvillagra Aug 5, 2025
800f566
Improve tensor device handling in TimedGRPOTrainer by adding checks f…
mdvillagra Aug 5, 2025
a3d7b7b
Refine tensor device alignment in TimedGRPOTrainer by implementing a …
mdvillagra Aug 5, 2025
e6e2c15
Optimize tensor device management in TimedGRPOTrainer by converting l…
mdvillagra Aug 5, 2025
92b942e
Adjust max completion length and new tokens in FullModelLLMTrainer to…
mdvillagra Aug 5, 2025
33f6ed2
Refactor TimedGRPOTrainer by adding docstrings for improved code docu…
mdvillagra Aug 5, 2025
760493c
Update reference model initialization in TimedGRPOTrainer to include …
mdvillagra Aug 5, 2025
a29d3eb
Update beta parameter in FullModelLLMTrainer to 0.1 for improved trai…
mdvillagra Aug 5, 2025
5cc0537
Update generation count in FullModelLLMTrainer to 4 and increase batc…
mdvillagra Aug 6, 2025
28d7513
Update generation count in FullModelLLMTrainer from 4 to 2 for optimi…
mdvillagra Aug 6, 2025
5c68b92
Reduce grpo_max_steps in GRPO test configuration from 50 to 20 for op…
mdvillagra Aug 6, 2025
2a7da13
Update client configuration in GRPO test setup to support 4 clients f…
mdvillagra Aug 6, 2025
79929f7
Update GRPO test configuration to reduce batch size from 2 to 1 for f…
mdvillagra Aug 6, 2025
8c8739f
Update reference model in TimedGRPOTrainer from "Qwen/Qwen3-1.7B" to …
mdvillagra Aug 6, 2025
19e0aad
Update reference model in TimedGRPOTrainer from "Qwen/Qwen3-0.6" to "…
mdvillagra Aug 6, 2025
e57d3e2
Refactor TimedGRPOTrainer to improve code documentation with added do…
mdvillagra Aug 7, 2025
3b7a416
Increase the number of retained old wallclock checkpoints from 6 to 1…
mdvillagra Aug 7, 2025
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -277,3 +277,6 @@ python/examples/launch/hello_world/fedml_job_entry_pack.bat
**mpi_host_file
/python/fedml/workflow/driver_example/customized_job_example/train_job/bootstrap.bat
/python/fedml/workflow/driver_example/customized_job_example/train_job/fedml_job_entry_pack.bat


venv
17 changes: 16 additions & 1 deletion python/fedml/cross_silo/server/fedml_server_manager.py
@@ -246,7 +246,22 @@ def handle_message_receive_model_from_client(self, msg_params):
if self.is_main_process():
mlops.log_aggregated_model_info(self.args.round_idx, model_url=global_model_url)

logging.info("\n\n==========end {}-th round training===========\n".format(self.args.round_idx))
# --------------------------------------------------
# Log global-update frequency in wall-clock terms
# --------------------------------------------------
current_ts = time.time()
# Compute and print only if this is not the very first round
if hasattr(self, "_last_round_end_ts") and self._last_round_end_ts is not None:
delta = current_ts - self._last_round_end_ts
if delta > 0:
freq = 1.0 / delta
logging.info(
f"Global update frequency: {freq:.4f} updates/sec ({delta:.2f} s per round)"
)
# Record timestamp for the next round
self._last_round_end_ts = current_ts

logging.info("\n\n==========end {}/{}-th round training===========\n".format(self.args.round_idx, self.round_num))
if self.args.round_idx < self.round_num:
mlops.event("server.wait", event_started=True, event_value=str(self.args.round_idx))

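As a standalone sketch, the wall-clock frequency logging added to `FedMLServerManager` in this diff amounts to the following (the class name `RoundFrequencyLogger` is hypothetical; the attribute and log format mirror the diff):

```python
import logging
import time


class RoundFrequencyLogger:
    """Tracks seconds-per-round and updates/sec across aggregation rounds,
    mirroring the `_last_round_end_ts` logic added in the diff above."""

    def __init__(self):
        self._last_round_end_ts = None

    def on_round_end(self):
        current_ts = time.time()
        freq = None
        # Compute and log only if this is not the very first round.
        if self._last_round_end_ts is not None:
            delta = current_ts - self._last_round_end_ts
            if delta > 0:
                freq = 1.0 / delta
                logging.info(
                    f"Global update frequency: {freq:.4f} updates/sec ({delta:.2f} s per round)"
                )
        # Record the timestamp for the next round.
        self._last_round_end_ts = current_ts
        return freq


tracker = RoundFrequencyLogger()
tracker.on_round_end()         # first round: nothing to report yet, returns None
# ... a training round runs here ...
print(tracker.on_round_end())  # later rounds: updates/sec since the previous round
```

Note the first call only records a timestamp; a frequency is reported from the second round onward.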
1,029 changes: 982 additions & 47 deletions python/spotlight_prj/fedllm/custom_trainer.py

Large diffs are not rendered by default.

158 changes: 158 additions & 0 deletions python/spotlight_prj/fedllm/data_formatting.py
@@ -0,0 +1,158 @@

from datasets import load_dataset


class DataFormatting:

    def __init__(self):
        self.system_prompt = """
Respond in the following format:

<reasoning>
...
</reasoning>

<answer>
...
</answer>
"""



    def extract_answer_from_model_output(self, text):
        """
        Extracts the value from the last <answer> tag in the text.

        Args:
            text (str): The model-generated text containing XML-style <answer> tags.

        Returns:
            str or None: The content inside the <answer> tags, or None if no valid answer is found.

        Explanation:
            1. Splits the text on the <answer> tag to isolate content after the tag.
            2. Checks if at least one <answer> tag exists in the text.
            3. For the last <answer> segment:
               - Verifies it contains a closing </answer> tag.
               - Extracts only the content between the tags.
            4. Returns None if the answer is empty (just "...") or if tags are missing.
        """
        # Split on <answer> and take everything after the last occurrence.
        parts = text.split("<answer>")
        if len(parts) < 2:  # No <answer> tag found
            return None

        last_part = parts[-1]

        # Extract the content up to </answer>.
        if "</answer>" not in last_part:
            return None

        answer = last_part.split("</answer>")[0].strip()
        return None if answer == "..." else answer


    def extract_answer_from_dataset(self, text):
        """
        Extracts the answer from a GSM8K dataset example.

        Args:
            text (str): The dataset example text containing a question and answer.

        Returns:
            str or None: The extracted answer after the '####' delimiter, or None.

        Explanation:
            1. Checks if the text contains the '####' delimiter that separates questions from answers.
            2. If found, splits the text at this delimiter and returns the second part.
            3. The answer is stripped of leading and trailing whitespace.
            4. Returns None if no delimiter is present.
        """
        if "####" not in text:
            return None
        return text.split("####")[1].strip()



    def prepare_dataset(self, split="train"):
        """
        Load and prepare the GSM8K dataset for training with string prompts.

        Args:
            split (str): The dataset split to load ("train" or "test"). Defaults to "train".

        Returns:
            list: A list of formatted examples, each containing a prompt string and an answer.

        Explanation:
            1. Loads the GSM8K dataset from the Hugging Face dataset hub.
            2. For each example in the dataset:
               - Creates a list of messages with the system prompt and the question.
               - Converts this list into a single string prompt using build_prompt().
               - Extracts the answer from the dataset example.
               - Builds a formatted example containing the prompt and answer.
            3. Returns the list of formatted examples ready for model training or evaluation.
        """
        data = load_dataset('openai/gsm8k', 'main')[split]

        formatted_data = []
        for example in data:
            # Convert the list of messages to a single string prompt.
            prompt_str = self.build_prompt([
                {"role": "system", "content": self.system_prompt},
                {"role": "user", "content": example["question"]},
            ])
            formatted_example = {
                "prompt": prompt_str,  # a string rather than a list
                "answer": self.extract_answer_from_dataset(example["answer"]),
            }
            formatted_data.append(formatted_example)

        return formatted_data



    def build_prompt(self, messages):
        """
        Build a single prompt string from a list of messages.

        Args:
            messages (list): A list of message dictionaries, each with 'role' and 'content'.

        Returns:
            str: A concatenated string of all message content.

        Explanation:
            1. Takes a list of message dictionaries in the typical chat format.
            2. Extracts the 'content' field from each message and strips whitespace.
            3. Joins all content strings with newlines to create a single prompt.
            4. This preserves the training format while converting from structured messages.
        """
        return "\n".join(msg["content"].strip() for msg in messages)
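A minimal sketch of the round trip these helpers implement — extracting the gold answer from a GSM8K-style record and parsing a completion back out of the XML-style tags. The two extractors are restated inline as plain functions so the snippet runs standalone; the example strings are illustrative:

```python
def extract_answer_from_model_output(text):
    # Content of the last <answer>...</answer> pair, or None.
    parts = text.split("<answer>")
    if len(parts) < 2:
        return None
    last_part = parts[-1]
    if "</answer>" not in last_part:
        return None
    answer = last_part.split("</answer>")[0].strip()
    return None if answer == "..." else answer


def extract_answer_from_dataset(text):
    # GSM8K puts the gold answer after a '####' delimiter.
    return text.split("####")[1].strip() if "####" in text else None


gold = extract_answer_from_dataset("In May she sold 48 / 2 = 24 clips.\n#### 72")
pred = extract_answer_from_model_output(
    "<reasoning>48 + 24 = 72</reasoning>\n<answer>72</answer>"
)
print(gold, pred, gold == pred)  # → 72 72 True
```

Note that a completion that still contains the literal `...` placeholder from the system prompt parses to None rather than an empty answer.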

206 changes: 206 additions & 0 deletions python/spotlight_prj/fedllm/evaluation.py
@@ -0,0 +1,206 @@

import re

import torch
from data_formatting import DataFormatting


class Evaluation:

    def __init__(self):
        self.dat_fmt = DataFormatting()
    def extract_last_number(self, text):
        """
        Extracts the last number appearing in the text.

        Args:
            text (str): The text to extract a number from.

        Returns:
            float or None: The last number in the text, or None if no number is found.

        Explanation:
            1. Removes dollar signs and percentage symbols from the text.
            2. Uses regex to find a number that appears at the end of the text.
            3. The pattern matches numbers that appear at the end of the string.
            4. Returns the found number as a float, or None if no match is found.
        """
        text = text.replace('$', '').replace('%', '')
        pattern = r'(?:^|\s|=)\s*(-?\d*\.?\d+)\s*$'
        match = re.search(pattern, text)
        return float(match.group(1)) if match else None




    def extract_single_number(self, text):
        """
        Extracts a single number from text if exactly one number is present.

        Args:
            text (str): The text to extract a number from.

        Returns:
            float or None: The single number in the text, or None if zero or multiple numbers are found.

        Explanation:
            1. Uses regex to find all numbers in the text, including negative numbers.
            2. If exactly one number is found, returns it as a float.
            3. If zero or multiple numbers are found, returns None.
        """
        numbers = re.findall(r'-?\d*\.?\d+', text)
        if len(numbers) == 1:
            return float(numbers[0])
        return None



    def evaluate_model(self, model, tokenizer, eval_samples, device):
        """
        Evaluates the model on a set of examples and prints detailed results.

        Args:
            model: The language model to evaluate.
            tokenizer: The tokenizer for encoding inputs and decoding outputs.
            eval_samples (list): List of evaluation examples, each containing "prompt" and "answer".
            device: The device (CPU or GPU) to run evaluation on.

        Returns:
            float: The accuracy percentage (correct predictions / total examples * 100).

        Explanation:
            1. Sets the model to evaluation mode.
            2. For each example in the evaluation set:
               - Encodes the prompt and generates a response using the model.
               - Extracts the predicted answer from the generated response.
               - Compares the predicted answer with the expected answer using multiple methods:
                 a. Exact string matching.
                 b. Single-number extraction and comparison.
                 c. Last-number extraction and comparison.
               - Prints detailed information about each example.
            3. Calculates and returns the overall accuracy.
            4. Returns the model to training mode.
        """
        model.eval()

        correct = 0
        total = len(eval_samples)

        print("\n" + "=" * 50)
        print("EVALUATION ON", total, "EXAMPLES")
        print("=" * 50)

        for example in eval_samples:
            # Get the prompt and expected answer.
            full_prompt = example["prompt"]
            expected = example["answer"]

            # Tokenize and generate a response.
            inputs = tokenizer(
                full_prompt,
                return_tensors='pt',
                padding=False,
                truncation=False,
                return_attention_mask=True,
            ).to(device)

            with torch.no_grad():
                outputs = model.generate(
                    input_ids=inputs["input_ids"],
                    attention_mask=inputs["attention_mask"],
                    max_new_tokens=512,
                    do_sample=True,  # required for temperature to take effect
                    temperature=0.7,
                    num_return_sequences=1,
                    pad_token_id=tokenizer.pad_token_id,
                    eos_token_id=tokenizer.eos_token_id,
                    forced_eos_token_id=tokenizer.eos_token_id,
                    early_stopping=False,
                )

            response = tokenizer.decode(outputs[0], skip_special_tokens=True)

            try:
                # Extract answers and check correctness.
                predicted = self.dat_fmt.extract_answer_from_model_output(response)

                # Try the different matching methods in order.
                if predicted == expected:  # Exact match
                    is_correct = True
                else:
                    # Try single-number matching.
                    pred_num = self.extract_single_number(str(predicted))
                    exp_num = self.extract_single_number(str(expected))
                    if pred_num is not None and exp_num is not None and pred_num == exp_num:
                        is_correct = True
                    else:
                        # Fall back to last-number matching.
                        pred_num = self.extract_last_number(str(predicted))
                        exp_num = self.extract_last_number(str(expected))
                        is_correct = (pred_num is not None and exp_num is not None and pred_num == exp_num)

                if is_correct:
                    correct += 1

                # Print evaluation results.
                print("\nPrompt:")
                print(full_prompt)
                print("\nExpected Answer:")
                print(expected)
                print("\nExtracted Answer:")
                print(predicted)
                print("\nFull Generated Response:")
                print(response)
                print("\nCorrect:", "✓" if is_correct else "✗")
                print("-" * 50)

            except Exception as e:
                print("\nFailed to parse the model output from prompt:")
                print(full_prompt)
                print('Error:', e)
                print('-' * 50)

        accuracy = (correct / total) * 100
        print(f"\nAccuracy: {accuracy:.2f}% ({correct}/{total})")

        # Return the model to training mode.
        model.train()

        return accuracy
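The three-stage answer-matching cascade used in `evaluate_model` can be sketched in isolation. The two helper regexes are restated inline as plain functions so the snippet runs standalone; `answers_match` is a hypothetical wrapper name, not a method of the class:

```python
import re


def extract_single_number(text):
    # Return the number if the text contains exactly one, else None.
    numbers = re.findall(r'-?\d*\.?\d+', text)
    return float(numbers[0]) if len(numbers) == 1 else None


def extract_last_number(text):
    # Return the trailing number of the text, else None.
    text = text.replace('$', '').replace('%', '')
    match = re.search(r'(?:^|\s|=)\s*(-?\d*\.?\d+)\s*$', text)
    return float(match.group(1)) if match else None


def answers_match(predicted, expected):
    """Exact string match, then single-number match, then last-number match."""
    if predicted == expected:
        return True
    p, e = extract_single_number(str(predicted)), extract_single_number(str(expected))
    if p is not None and e is not None and p == e:
        return True
    p, e = extract_last_number(str(predicted)), extract_last_number(str(expected))
    return p is not None and e is not None and p == e


print(answers_match("72", "72"))            # exact match → True
print(answers_match("$72", "72"))           # single-number match → True
print(answers_match("48 + 24 = 72", "72"))  # last-number match → True
```

The cascade is deliberately lenient: a completion like "48 + 24 = 72" contains three numbers, so single-number matching abstains and the last-number rule decides.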

