89 changes: 89 additions & 0 deletions aif_gen/_cli/commands/clean.py
Member

What's the use case?

Member

To reiterate, is this necessary to merge, or just useful as a pre-processing step for our data? By the way, I would rather we not use this in excess (or at all), since it's unclear how removing specific words might alter the latent preference that we are aiming to model.
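For illustration, a minimal sketch of how the command could be used as a one-off pre-processing step. The subcommand name `clean-dataset` (click's default naming for the `clean_dataset` function), the file paths, and the word list are assumptions and not part of this diff:

```sh
# Hypothetical invocation: strip "Sure" and "Certainly" from every prompt,
# chosen response, and rejected response, writing the cleaned copy to a new file.
uv run aif clean-dataset \
    data/continual_data.json \
    data/continual_data_clean.json \
    "Sure Certainly" \
    --random_seed 0
```

Note that the cleanup uses plain `str.replace`, so each word is removed wherever it appears as a substring (removing `cat` would also affect `category`), which is worth keeping in mind for the latent-preference concern above.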

@@ -0,0 +1,89 @@
import logging
import pathlib
from typing import Optional

import click

from aif_gen.dataset.continual_alignment_dataset import (
    ContinualAlignmentDataset,
)
from aif_gen.util.hf import download_from_hf, upload_to_hf
from aif_gen.util.seed import seed_everything


@click.command(context_settings={'show_default': True})
@click.argument(
    'input_data_file',
    type=click.Path(exists=True, dir_okay=False, path_type=pathlib.Path),
)
@click.argument(
    'output_data_file',
    type=click.Path(dir_okay=False, path_type=pathlib.Path),
)
@click.argument(
    'words',
    type=click.STRING,
)
@click.option(
    '--random_seed',
    type=int,
    help='Random seed for data generation.',
    default=0,
)
@click.option(
    '--hf-repo-id',
    type=click.STRING,
    default=None,
    help='If not None, push the generated input_dataset to a HuggingFace remote repository with the associated repo-id.',
)
def clean_dataset(
    input_data_file: pathlib.Path,
    output_data_file: pathlib.Path,
    words: str,
    random_seed: int,
    hf_repo_id: Optional[str],
) -> None:
    r"""Clean a ContinualAlignmentDataset given a space-separated string of words.

    INPUT_DATA_FILE: Path to the input dataset.
    OUTPUT_DATA_FILE: Path to the output dataset.
    WORDS: Space-separated string of words to clean the dataset.
    """
    if hf_repo_id is not None:
        input_data_file = download_from_hf(hf_repo_id, input_data_file)

    logging.info(f'Reading input_dataset from: {input_data_file}')
    input_dataset = ContinualAlignmentDataset.from_json(input_data_file)
    logging.info(f'Read {len(input_dataset)} samples from: {input_data_file}')

    if not len(input_dataset):
        logging.warning('No samples found in dataset, skipping clean up.')
        return

    logging.info(f'Using words: {words}')
    logging.info(f'Random seed: {random_seed}')
    seed_everything(random_seed)

    output_data_file.parent.mkdir(parents=True, exist_ok=True)

    words_list = words.split(' ')
    if len(words_list) == 0:
        logging.warning('No words found in words string, skipping clean up.')
        return

    # clean up each data point in the dataset
    for dataset in input_dataset.datasets:
        for sample in dataset.samples:
            for word in words_list:
                sample.prompt = sample.prompt.replace(word, '')
                sample.chosen = sample.chosen.replace(word, '')
                sample.rejected = sample.rejected.replace(word, '')

    logging.info(f'Finished cleaning dataset.')

    logging.info(f'Writing {len(dataset)} samples to {output_data_file}')
    input_dataset.to_json(output_data_file)
    logging.info(f'Wrote {len(dataset)} samples to {output_data_file}')
Comment on lines +84 to +86
Copilot AI Apr 24, 2025

The variable 'dataset' refers to the last iterated dataset rather than the aggregate number of cleaned samples; consider using the total sample count or an appropriately aggregated value to accurately reflect the number of samples written.

Suggested change
    logging.info(f'Writing {len(dataset)} samples to {output_data_file}')
    input_dataset.to_json(output_data_file)
    logging.info(f'Wrote {len(dataset)} samples to {output_data_file}')
    total_samples = sum(len(dataset.samples) for dataset in input_dataset.datasets)
    logging.info(f'Writing {total_samples} samples to {output_data_file}')
    input_dataset.to_json(output_data_file)
    logging.info(f'Wrote {total_samples} samples to {output_data_file}')

Member

Copilot is right here


    if hf_repo_id is not None:
        upload_to_hf(repo_id=hf_repo_id, local_path=output_data_file)
2 changes: 2 additions & 0 deletions aif_gen/_cli/main.py
@@ -2,6 +2,7 @@

import click

from aif_gen._cli.commands.clean import clean_dataset
from aif_gen._cli.commands.filter import filter_dataset
from aif_gen._cli.commands.generate import generate
from aif_gen._cli.commands.merge import merge
@@ -47,6 +48,7 @@ def cli(log_file: pathlib.Path) -> None:
cli.add_command(sample)
cli.add_command(transmute)
cli.add_command(filter_dataset)
cli.add_command(clean_dataset)

if __name__ == '__main__':
    cli()
4 changes: 2 additions & 2 deletions aif_gen/api/response_mapper/response_mapper.py
Member

Assuming we are finalized on this approach for our first release and paper, we should move this to config (in a subsequent PR)

Member

Also, it could be worth fixing #140 while we are here, unless you'd rather do it separately for a clean git history.

@@ -15,8 +15,8 @@ class ResponseMapper(ResponseMapperBase):
"""

NUMBER_OF_PREFERENCE_AXES_SAMPLED: int = 3
TASK_PREFERENCE_INCLUSION_PROBABILITY_POSIIVE: float = 0.5
TASK_PREFERENCE_INCLUSION_PROBABILITY_NEGATIVE: float = 0.5
TASK_PREFERENCE_INCLUSION_PROBABILITY_POSIIVE: float = 0.4
TASK_PREFERENCE_INCLUSION_PROBABILITY_NEGATIVE: float = 0.4

def __init__(self, suffix_context: Optional[str] = None) -> None:
self._suffix_context = suffix_context
11 changes: 4 additions & 7 deletions aif_gen/generate/service.py
@@ -533,13 +533,10 @@ class ResponsePair(pydantic.BaseModel, extra='forbid'):
async with async_semaphore:
if cache is not None:
output = await cache.get(task_prompt + task_prompt_second)
if output is None:
raise ValueError(
f'No cached response for task prompt: {task_prompt + task_prompt_second}'
)
structured_output = ResponsePair.model_validate_json(output)
output1_str: str = structured_output.chosen
output2_str: str = structured_output.rejected
if output is not None:
structured_output = ResponsePair.model_validate_json(output)
output1_str: str = structured_output.chosen
output2_str: str = structured_output.rejected
else:
output = None

2 changes: 1 addition & 1 deletion benchmarks/README.md
@@ -15,7 +15,7 @@ uv sync --group benchmarks
```sh
uv run benchmarks/reward_modeling.py \
--model_name_or_path Qwen/Qwen2-0.5B-Instruct \
--dataset_name preference_axes.json \
--dataset_name benchmarks/continual_data_debug.json \
--dataset_index 0 \
--output_dir Qwen2-0.5B-Reward \
--per_device_train_batch_size 8 \
169 changes: 169 additions & 0 deletions jobs/generate_all_downsampled.sh
@@ -0,0 +1,169 @@
#!/bin/bash
#SBATCH --job-name=generate_static_all_70B_final
#SBATCH --partition=main
#SBATCH --mem=48G
#SBATCH --cpus-per-task=6
#SBATCH --time=24:00:00
#SBATCH --output=slurm-%j.out
#SBATCH --error=slurm-%j.err
#SBATCH --mail-type=ALL
#SBATCH --mail-user=

# set -euo pipefail
source .env

# 1) start the vllm server in the background
uvx vllm serve meta-llama/Meta-Llama-3-70B-Instruct \
--dtype auto \
--api-key openai \
--tensor-parallel-size 2 &
SERVER_PID=$!
echo "⏳ Waiting for VLLM server (PID=$SERVER_PID) to come up…"

# replace fixed sleep with a health‐check loop
export UV_VLLM_SERVER_URL="http://127.0.0.1:8000" # tell `uv run` where to send requests
for i in $(seq 1 600); do
if curl -fs "${UV_VLLM_SERVER_URL}/health"; then
echo "✅ VLLM up after $((i*5))s"
break
fi
echo "…still waiting ($i/600)…"
sleep 5
done

# helper to run one job (the name `run` is assumed; the helper is not called below)
run() { echo "➡️ $*"; eval "$*"; }


# 2) run all generation jobs sequentially
uv run aif generate \
--include-preference-axes \
--max_concurrency 256 \
--output_file "data/70B_generation/education_qna_direct/data.json" \
config/static_copy/education_qna_direct.yaml \
Meta-Llama-3.1-70B-Instruct

uv run aif generate \
--include-preference-axes \
--max_concurrency 256 \
--output_file "data/70B_generation/education_qna_eli5/data.json" \
config/static_copy/education_qna_eli5.yaml \
Meta-Llama-3.1-70B-Instruct

uv run aif generate \
--include-preference-axes \
--max_concurrency 256 \
--output_file "data/70B_generation/education_qna_expert/data.json" \
config/static_copy/education_qna_expert.yaml \
Meta-Llama-3.1-70B-Instruct

uv run aif generate \
--include-preference-axes \
--max_concurrency 256 \
--output_file "data/70B_generation/education_qna_hinted/data.json" \
config/static_copy/education_qna_hinted.yaml \
Meta-Llama-3.1-70B-Instruct

uv run aif generate \
--include-preference-axes \
--max_concurrency 256 \
--output_file "data/70B_generation/education_summary_eli5/data.json" \
config/static_copy/education_summary_eli5.yaml \
Meta-Llama-3.1-70B-Instruct

uv run aif generate \
--include-preference-axes \
--max_concurrency 256 \
--output_file "data/70B_generation/education_summary_expert/data.json" \
config/static_copy/education_summary_expert.yaml \
Meta-Llama-3.1-70B-Instruct

uv run aif generate \
--include-preference-axes \
--max_concurrency 256 \
--output_file "data/70B_generation/politics_generate_formal/data.json" \
config/static_copy/politics_generate_formal.yaml \
Meta-Llama-3.1-70B-Instruct

uv run aif generate \
--include-preference-axes \
--max_concurrency 256 \
--output_file "data/70B_generation/politics_generate_rapper/data.json" \
config/static_copy/politics_generate_rapper.yaml \
Meta-Llama-3.1-70B-Instruct

uv run aif generate \
--include-preference-axes \
--max_concurrency 256 \
--output_file "data/70B_generation/politics_generate_shakespeare/data.json" \
config/static_copy/politics_generate_shakespeare.yaml \
Meta-Llama-3.1-70B-Instruct

uv run aif generate \
--include-preference-axes \
--max_concurrency 256 \
--output_file "data/70B_generation/politics_qna_eli5/data.json" \
config/static_copy/politics_qna_eli5.yaml \
Meta-Llama-3.1-70B-Instruct

uv run aif generate \
--include-preference-axes \
--max_concurrency 256 \
--output_file "data/70B_generation/politics_qna_expert/data.json" \
config/static_copy/politics_qna_expert.yaml \
Meta-Llama-3.1-70B-Instruct

uv run aif generate \
--include-preference-axes \
--max_concurrency 256 \
--output_file "data/70B_generation/politics_summary_eli5/data.json" \
config/static_copy/politics_summary_eli5.yaml \
Meta-Llama-3.1-70B-Instruct

uv run aif generate \
--include-preference-axes \
--max_concurrency 256 \
--output_file "data/70B_generation/politics_summary_expert/data.json" \
config/static_copy/politics_summary_expert.yaml \
Meta-Llama-3.1-70B-Instruct

uv run aif generate \
--include-preference-axes \
--max_concurrency 256 \
--output_file "data/70B_generation/tech_healthcare_qna_eli5/data.json" \
config/static_copy/tech_healthcare_qna_eli5.yaml \
Meta-Llama-3.1-70B-Instruct

uv run aif generate \
--include-preference-axes \
--max_concurrency 256 \
--output_file "data/70B_generation/tech_healthcare_qna_expert/data.json" \
config/static_copy/tech_healthcare_qna_expert.yaml \
Meta-Llama-3.1-70B-Instruct

uv run aif generate \
--include-preference-axes \
--max_concurrency 256 \
--output_file "data/70B_generation/tech_physics_summary_eli5/data.json" \
config/static_copy/tech_physics_summary_eli5.yaml \
Meta-Llama-3.1-70B-Instruct

uv run aif generate \
--include-preference-axes \
--max_concurrency 256 \
--output_file "data/70B_generation/tech_physics_summary_expert/data.json" \
config/static_copy/tech_physics_summary_expert.yaml \
Meta-Llama-3.1-70B-Instruct

uv run aif generate \
--include-preference-axes \
--max_concurrency 256 \
--output_file "data/70B_generation/tech_physics_summary_highschool/data.json" \
config/static_copy/tech_physics_summary_highschool.yaml \
Meta-Llama-3.1-70B-Instruct

# 3) shutdown the server when done
echo "✅ All jobs finished. Shutting down VLLM server (PID=$SERVER_PID)…"
kill $SERVER_PID
wait $SERVER_PID 2>/dev/null || true
echo "🛑 Server stopped."