diff --git a/src/lighteval/main_inspect.py b/src/lighteval/main_inspect.py
index b1cf3215a..377119774 100644
--- a/src/lighteval/main_inspect.py
+++ b/src/lighteval/main_inspect.py
@@ -52,6 +52,11 @@ def get_inspect_ai_task(
     name = lighteval_task_config.name
     sample_fields = lighteval_task_config.sample_fields

+    if sample_fields is None:
+        raise ValueError(
+            f"Task {name} is not supported by inspect_ai yet. You can either define it or use a different backend, `lighteval --help`"
+        )
+
     dataset_repo = lighteval_task_config.hf_repo
     dataset_subset = lighteval_task_config.hf_subset
     dataset_split = lighteval_task_config.evaluation_splits[0]
diff --git a/src/lighteval/tasks/tasks/aimo.py b/src/lighteval/tasks/tasks/aimo.py
index fdfc5ff95..03d86268a 100644
--- a/src/lighteval/tasks/tasks/aimo.py
+++ b/src/lighteval/tasks/tasks/aimo.py
@@ -17,7 +17,10 @@ paper:
 """

-from lighteval.metrics.metrics import Metrics
+from inspect_ai.dataset import Sample
+from inspect_ai.solver import generate
+
+from lighteval.metrics.metrics import Metrics, math_scorer
 from lighteval.metrics.normalizations import math_normalizer
 from lighteval.tasks.lighteval_task import LightevalTaskConfig
 from lighteval.tasks.requests import Doc
@@ -32,9 +35,16 @@ def aimo_prompt(line, task_name: str = None):
     )


+def record_to_sample(record):
+    return Sample(input=record["problem"], target=str(record["answer"]))
+
+
 task = LightevalTaskConfig(
     name="aimo_progress_prize_1",
     prompt_function=aimo_prompt,
+    sample_fields=record_to_sample,
+    solver=[generate(cache=True)],
+    scorer=math_scorer(),
     hf_subset="",
     hf_repo="lighteval/aimo_progress_prize_1",
     hf_avail_splits=["train"],
diff --git a/src/lighteval/tasks/tasks/anli.py b/src/lighteval/tasks/tasks/anli.py
index 86a0a9d65..0179f5ef4 100644
--- a/src/lighteval/tasks/tasks/anli.py
+++ b/src/lighteval/tasks/tasks/anli.py
@@ -22,6 +22,12 @@ https://arxiv.org/abs/1910.14599
 """

+from string import ascii_uppercase
+
+from inspect_ai.dataset import Sample
+from inspect_ai.scorer import choice
+from inspect_ai.solver import multiple_choice
+
 from lighteval.metrics.metrics import Metrics
 from lighteval.tasks.lighteval_task import LightevalTaskConfig
 from lighteval.tasks.requests import Doc
@@ -36,6 +42,12 @@ def anli_prompt(line, task_name: str = None):
     )


+def record_to_sample(record):
+    choices = ["True", "Neither", "False"]
+    query = f"{record['premise']}\nQuestion: {record['hypothesis']}"
+    return Sample(input=query, target=ascii_uppercase[record["label"]], choices=choices)
+
+
 anli_r1 = LightevalTaskConfig(
     name="anli:r1",
     prompt_function=anli_prompt,
@@ -49,6 +61,9 @@ def anli_prompt(line, task_name: str = None):
     metrics=[Metrics.loglikelihood_acc],
     stop_sequence=["\n"],
     version=0,
+    sample_fields=record_to_sample,
+    solver=[multiple_choice(cache=True)],
+    scorer=choice(),
 )


@@ -65,6 +80,9 @@ def anli_prompt(line, task_name: str = None):
     metrics=[Metrics.loglikelihood_acc],
     stop_sequence=["\n"],
     version=0,
+    sample_fields=record_to_sample,
+    solver=[multiple_choice(cache=True)],
+    scorer=choice(),
 )


@@ -81,6 +99,9 @@ def anli_prompt(line, task_name: str = None):
     metrics=[Metrics.loglikelihood_acc],
     stop_sequence=["\n"],
     version=0,
+    sample_fields=record_to_sample,
+    solver=[multiple_choice(cache=True)],
+    scorer=choice(),
 )

 TASKS_TABLE = [
diff --git a/src/lighteval/tasks/tasks/arc.py b/src/lighteval/tasks/tasks/arc.py
index 39f8b8827..d00e89d3e 100644
--- a/src/lighteval/tasks/tasks/arc.py
+++ b/src/lighteval/tasks/tasks/arc.py
@@ -22,6 +22,10 @@ https://arxiv.org/abs/1803.05457
 """
+from inspect_ai.dataset import Sample +from inspect_ai.scorer import choice +from inspect_ai.solver import multiple_choice + from lighteval.metrics.metrics import Metrics from lighteval.tasks.lighteval_task import LightevalTaskConfig from lighteval.tasks.requests import Doc @@ -36,6 +40,14 @@ def arc_prompt(line, task_name: str = None): ) +def record_to_sample(record): + query = record["question"].strip() + target = record["answerKey"] + choices = record["choices"]["text"] + + return Sample(input=query, target=target, choices=choices) + + arc_challenge = LightevalTaskConfig( name="arc:challenge", prompt_function=arc_prompt, @@ -51,6 +63,9 @@ def arc_prompt(line, task_name: str = None): ], stop_sequence=["\n"], version=0, + sample_fields=record_to_sample, + solver=[multiple_choice(cache=True)], + scorer=choice(), ) arc_easy = LightevalTaskConfig( @@ -68,6 +83,9 @@ def arc_prompt(line, task_name: str = None): ], stop_sequence=["\n"], version=0, + sample_fields=record_to_sample, + solver=[multiple_choice(cache=True)], + scorer=choice(), ) TASKS_TABLE = [arc_challenge, arc_easy] diff --git a/src/lighteval/tasks/tasks/arithmetic.py b/src/lighteval/tasks/tasks/arithmetic.py index 48a290435..d913e3b3e 100644 --- a/src/lighteval/tasks/tasks/arithmetic.py +++ b/src/lighteval/tasks/tasks/arithmetic.py @@ -19,15 +19,25 @@ https://arxiv.org/abs/2005.14165 """ -from lighteval.metrics.metrics import Metrics +from inspect_ai.dataset import Sample +from inspect_ai.solver import generate + +from lighteval.metrics.metrics import Metrics, math_scorer from lighteval.tasks.lighteval_task import LightevalTaskConfig from lighteval.tasks.requests import Doc +# TODO: convert dataset to parquet + + def arithmetic_prompt(line, task_name: str = None): return Doc(task_name=task_name, query=line["context"], choices=[line["completion"]], gold_index=[0]) +def record_to_sample(record): + return Sample(input=record["context"], target=record["completion"]) + + arithmetic_1dc = LightevalTaskConfig( name="arithmetic:1dc", prompt_function=arithmetic_prompt, @@ -41,6 +51,9 @@ def arithmetic_prompt(line, task_name: str = None): metrics=[Metrics.exact_match], stop_sequence=["\n"], version=0, + sample_fields=record_to_sample, + solver=[generate(cache=True)], + scorer=math_scorer(), ) arithmetic_2da = LightevalTaskConfig( @@ -56,6 +69,9 @@ def arithmetic_prompt(line, task_name: str = None): metrics=[Metrics.exact_match], stop_sequence=["\n"], version=0, + sample_fields=record_to_sample, + solver=[generate(cache=True)], + scorer=math_scorer(), ) arithmetic_2dm = LightevalTaskConfig( @@ -71,6 +87,9 @@ def arithmetic_prompt(line, task_name: str = None): metrics=[Metrics.exact_match], stop_sequence=["\n"], version=0, + sample_fields=record_to_sample, + solver=[generate(cache=True)], + scorer=math_scorer(), ) arithmetic_2ds = LightevalTaskConfig( @@ -86,6 +105,9 @@ def arithmetic_prompt(line, task_name: str = None): metrics=[Metrics.exact_match], stop_sequence=["\n"], version=0, + sample_fields=record_to_sample, + solver=[generate(cache=True)], + scorer=math_scorer(), ) arithmetic_3da = LightevalTaskConfig( @@ -101,6 +123,9 @@ def arithmetic_prompt(line, task_name: str = None): metrics=[Metrics.exact_match], stop_sequence=["\n"], version=0, + sample_fields=record_to_sample, + solver=[generate(cache=True)], + scorer=math_scorer(), ) arithmetic_3ds = LightevalTaskConfig( @@ -116,6 +141,9 @@ def arithmetic_prompt(line, task_name: str = None): metrics=[Metrics.exact_match], stop_sequence=["\n"], version=0, + sample_fields=record_to_sample, 
+ solver=[generate(cache=True)], + scorer=math_scorer(), ) arithmetic_4da = LightevalTaskConfig( @@ -131,6 +159,9 @@ def arithmetic_prompt(line, task_name: str = None): metrics=[Metrics.exact_match], stop_sequence=["\n"], version=0, + sample_fields=record_to_sample, + solver=[generate(cache=True)], + scorer=math_scorer(), ) arithmetic_4ds = LightevalTaskConfig( @@ -146,6 +177,9 @@ def arithmetic_prompt(line, task_name: str = None): metrics=[Metrics.exact_match], stop_sequence=["\n"], version=0, + sample_fields=record_to_sample, + solver=[generate(cache=True)], + scorer=math_scorer(), ) arithmetic_5da = LightevalTaskConfig( @@ -161,6 +195,9 @@ def arithmetic_prompt(line, task_name: str = None): metrics=[Metrics.exact_match], stop_sequence=["\n"], version=0, + sample_fields=record_to_sample, + solver=[generate(cache=True)], + scorer=math_scorer(), ) arithmetic_5ds = LightevalTaskConfig( @@ -176,6 +213,9 @@ def arithmetic_prompt(line, task_name: str = None): metrics=[Metrics.exact_match], stop_sequence=["\n"], version=0, + sample_fields=record_to_sample, + solver=[generate(cache=True)], + scorer=math_scorer(), ) TASKS_TABLE = [ diff --git a/src/lighteval/tasks/tasks/asdiv.py b/src/lighteval/tasks/tasks/asdiv.py index 4fd34df36..fe0632655 100644 --- a/src/lighteval/tasks/tasks/asdiv.py +++ b/src/lighteval/tasks/tasks/asdiv.py @@ -19,7 +19,10 @@ https://arxiv.org/abs/2410.12853 """ -from lighteval.metrics.metrics import Metrics +from inspect_ai.dataset import Sample +from inspect_ai.solver import generate + +from lighteval.metrics.metrics import Metrics, math_scorer from lighteval.tasks.lighteval_task import LightevalTaskConfig from lighteval.tasks.requests import Doc @@ -33,6 +36,12 @@ def asdiv_prompt(line, task_name: str = None): ) +def record_to_sample(record): + query = f"{record['body']}\n{record['question']}" + target = record["answer"].split(" (")[0] + return Sample(input=query, target=target) + + asdiv = LightevalTaskConfig( name="asdiv", prompt_function=asdiv_prompt, @@ -46,6 +55,9 @@ def asdiv_prompt(line, task_name: str = None): metrics=[Metrics.exact_match], stop_sequence=["\n"], version=0, + sample_fields=record_to_sample, + solver=[generate(cache=True)], + scorer=math_scorer(), ) TASKS_TABLE = [asdiv] diff --git a/src/lighteval/tasks/tasks/babi_qa.py b/src/lighteval/tasks/tasks/babi_qa.py index 3a16c4fb3..eaba3d84b 100644 --- a/src/lighteval/tasks/tasks/babi_qa.py +++ b/src/lighteval/tasks/tasks/babi_qa.py @@ -26,6 +26,9 @@ from lighteval.tasks.requests import Doc +# TODO: clean dataset and convert to inspect-ai + + def babi_qa_prompt(line, task_name: str = None): def process_path(path: str) -> str: steps = path.split(",") diff --git a/src/lighteval/tasks/tasks/bbq.py b/src/lighteval/tasks/tasks/bbq.py index eb5fb1d45..b0bbc48c0 100644 --- a/src/lighteval/tasks/tasks/bbq.py +++ b/src/lighteval/tasks/tasks/bbq.py @@ -21,6 +21,10 @@ from string import ascii_uppercase +from inspect_ai.dataset import Sample +from inspect_ai.scorer import choice +from inspect_ai.solver import multiple_choice + from lighteval.metrics.metrics import Metrics from lighteval.tasks.lighteval_task import LightevalTaskConfig from lighteval.tasks.requests import Doc @@ -38,6 +42,13 @@ def bbq_prompt(line, task_name: str = None): ) +def record_to_sample(record): + query = f"{record['context']}\n{record['question']}" + choices = record["choices"] + target = ascii_uppercase[record["gold_index"]] + return Sample(input=query, target=target, choices=choices) + + bbq = LightevalTaskConfig( name="bbq", 
prompt_function=bbq_prompt, @@ -51,6 +62,9 @@ def bbq_prompt(line, task_name: str = None): metrics=[Metrics.exact_match], stop_sequence=["\n"], version=0, + sample_fields=record_to_sample, + solver=[multiple_choice(cache=True)], + scorer=choice(), ) bbq_Age = LightevalTaskConfig( @@ -66,6 +80,9 @@ def bbq_prompt(line, task_name: str = None): metrics=[Metrics.exact_match], stop_sequence=["\n"], version=0, + sample_fields=record_to_sample, + solver=[multiple_choice(cache=True)], + scorer=choice(), ) bbq_Disability_status = LightevalTaskConfig( @@ -81,6 +98,9 @@ def bbq_prompt(line, task_name: str = None): metrics=[Metrics.exact_match], stop_sequence=["\n"], version=0, + sample_fields=record_to_sample, + solver=[multiple_choice(cache=True)], + scorer=choice(), ) bbq_Gender_identity = LightevalTaskConfig( @@ -96,6 +116,9 @@ def bbq_prompt(line, task_name: str = None): metrics=[Metrics.exact_match], stop_sequence=["\n"], version=0, + sample_fields=record_to_sample, + solver=[multiple_choice(cache=True)], + scorer=choice(), ) bbq_Nationality = LightevalTaskConfig( @@ -111,6 +134,9 @@ def bbq_prompt(line, task_name: str = None): metrics=[Metrics.exact_match], stop_sequence=["\n"], version=0, + sample_fields=record_to_sample, + solver=[multiple_choice(cache=True)], + scorer=choice(), ) bbq_Physical_appearance = LightevalTaskConfig( @@ -126,6 +152,9 @@ def bbq_prompt(line, task_name: str = None): metrics=[Metrics.exact_match], stop_sequence=["\n"], version=0, + sample_fields=record_to_sample, + solver=[multiple_choice(cache=True)], + scorer=choice(), ) bbq_Race_ethnicity = LightevalTaskConfig( @@ -141,6 +170,9 @@ def bbq_prompt(line, task_name: str = None): metrics=[Metrics.exact_match], stop_sequence=["\n"], version=0, + sample_fields=record_to_sample, + solver=[multiple_choice(cache=True)], + scorer=choice(), ) bbq_Race_x_SES = LightevalTaskConfig( @@ -156,6 +188,9 @@ def bbq_prompt(line, task_name: str = None): metrics=[Metrics.exact_match], stop_sequence=["\n"], version=0, + sample_fields=record_to_sample, + solver=[multiple_choice(cache=True)], + scorer=choice(), ) bbq_Race_x_gender = LightevalTaskConfig( @@ -171,6 +206,9 @@ def bbq_prompt(line, task_name: str = None): metrics=[Metrics.exact_match], stop_sequence=["\n"], version=0, + sample_fields=record_to_sample, + solver=[multiple_choice(cache=True)], + scorer=choice(), ) bbq_Religion = LightevalTaskConfig( @@ -186,6 +224,9 @@ def bbq_prompt(line, task_name: str = None): metrics=[Metrics.exact_match], stop_sequence=["\n"], version=0, + sample_fields=record_to_sample, + solver=[multiple_choice(cache=True)], + scorer=choice(), ) bbq_SES = LightevalTaskConfig( @@ -201,6 +242,9 @@ def bbq_prompt(line, task_name: str = None): metrics=[Metrics.exact_match], stop_sequence=["\n"], version=0, + sample_fields=record_to_sample, + solver=[multiple_choice(cache=True)], + scorer=choice(), ) bbq_Sexual_orientation = LightevalTaskConfig( @@ -216,6 +260,9 @@ def bbq_prompt(line, task_name: str = None): metrics=[Metrics.exact_match], stop_sequence=["\n"], version=0, + sample_fields=record_to_sample, + solver=[multiple_choice(cache=True)], + scorer=choice(), ) TASKS_TABLE = [ diff --git a/src/lighteval/tasks/tasks/bigbench.py b/src/lighteval/tasks/tasks/bigbench.py index e18057f33..9c9487dcf 100644 --- a/src/lighteval/tasks/tasks/bigbench.py +++ b/src/lighteval/tasks/tasks/bigbench.py @@ -19,6 +19,12 @@ https://arxiv.org/abs/2206.04615 """ +from string import ascii_uppercase + +from inspect_ai.dataset import Sample +from inspect_ai.scorer import choice +from 
inspect_ai.solver import multiple_choice + from lighteval.metrics.metrics import Metrics from lighteval.tasks.lighteval_task import LightevalTaskConfig from lighteval.tasks.requests import Doc @@ -83,6 +89,13 @@ def bigbench_prompt(line, task_name: str = None): return Doc(task_name=task_name, query=line["inputs"], choices=choices, gold_index=gold_index) +def record_to_sample(record): + query = record["inputs"] + choices = record["multiple_choice_targets"] + target = ascii_uppercase[record["multiple_choice_scores"].index(1)] + return Sample(input=query, target=target, choices=choices) + + abstract_narrative_understanding = LightevalTaskConfig( name="bigbench:abstract_narrative_understanding", prompt_function=bigbench_prompt, @@ -96,6 +109,9 @@ def bigbench_prompt(line, task_name: str = None): metrics=[Metrics.loglikelihood_acc], stop_sequence=["\n"], version=0, + sample_fields=record_to_sample, + solver=[multiple_choice(cache=True)], + scorer=choice(), ) anachronisms = LightevalTaskConfig( @@ -111,6 +127,9 @@ def bigbench_prompt(line, task_name: str = None): metrics=[Metrics.loglikelihood_acc], stop_sequence=["\n"], version=0, + sample_fields=record_to_sample, + solver=[multiple_choice(cache=True)], + scorer=choice(), ) analogical_similarity = LightevalTaskConfig( @@ -126,6 +145,9 @@ def bigbench_prompt(line, task_name: str = None): metrics=[Metrics.loglikelihood_acc], stop_sequence=["\n"], version=0, + sample_fields=record_to_sample, + solver=[multiple_choice(cache=True)], + scorer=choice(), ) analytic_entailment = LightevalTaskConfig( @@ -141,6 +163,9 @@ def bigbench_prompt(line, task_name: str = None): metrics=[Metrics.loglikelihood_acc], stop_sequence=["\n"], version=0, + sample_fields=record_to_sample, + solver=[multiple_choice(cache=True)], + scorer=choice(), ) arithmetic_bb = LightevalTaskConfig( @@ -171,6 +196,9 @@ def bigbench_prompt(line, task_name: str = None): metrics=[Metrics.exact_match(sample_params={"strip_strings": False})], stop_sequence=["\n"], version=0, + sample_fields=record_to_sample, + solver=[multiple_choice(cache=True)], + scorer=choice(), ) authorship_verification = LightevalTaskConfig( @@ -186,6 +214,9 @@ def bigbench_prompt(line, task_name: str = None): metrics=[Metrics.loglikelihood_acc], stop_sequence=["\n"], version=0, + sample_fields=record_to_sample, + solver=[multiple_choice(cache=True)], + scorer=choice(), ) auto_categorization = LightevalTaskConfig( @@ -231,6 +262,9 @@ def bigbench_prompt(line, task_name: str = None): metrics=[Metrics.loglikelihood_acc], stop_sequence=["\n"], version=0, + sample_fields=record_to_sample, + solver=[multiple_choice(cache=True)], + scorer=choice(), ) bridging_anaphora_resolution_barqa = LightevalTaskConfig( @@ -246,6 +280,9 @@ def bigbench_prompt(line, task_name: str = None): metrics=[Metrics.exact_match(sample_params={"strip_strings": False})], stop_sequence=["\n"], version=0, + sample_fields=record_to_sample, + solver=[multiple_choice(cache=True)], + scorer=choice(), ) causal_judgment = LightevalTaskConfig( @@ -261,6 +298,9 @@ def bigbench_prompt(line, task_name: str = None): metrics=[Metrics.loglikelihood_acc], stop_sequence=["\n"], version=0, + sample_fields=record_to_sample, + solver=[multiple_choice(cache=True)], + scorer=choice(), ) cause_and_effect = LightevalTaskConfig( @@ -276,6 +316,9 @@ def bigbench_prompt(line, task_name: str = None): metrics=[Metrics.loglikelihood_acc], stop_sequence=["\n"], version=0, + sample_fields=record_to_sample, + solver=[multiple_choice(cache=True)], + scorer=choice(), ) 
checkmate_in_one = LightevalTaskConfig( @@ -336,6 +379,9 @@ def bigbench_prompt(line, task_name: str = None): metrics=[Metrics.loglikelihood_acc], stop_sequence=["\n"], version=0, + sample_fields=record_to_sample, + solver=[multiple_choice(cache=True)], + scorer=choice(), ) code_line_description = LightevalTaskConfig( @@ -351,6 +397,9 @@ def bigbench_prompt(line, task_name: str = None): metrics=[Metrics.loglikelihood_acc], stop_sequence=["\n"], version=0, + sample_fields=record_to_sample, + solver=[multiple_choice(cache=True)], + scorer=choice(), ) codenames = LightevalTaskConfig( @@ -401,6 +450,9 @@ def bigbench_prompt(line, task_name: str = None): metrics=[Metrics.loglikelihood_acc], stop_sequence=["\n"], version=0, + sample_fields=record_to_sample, + solver=[multiple_choice(cache=True)], + scorer=choice(), ) conceptual_combinations = LightevalTaskConfig( @@ -416,6 +468,9 @@ def bigbench_prompt(line, task_name: str = None): metrics=[Metrics.loglikelihood_acc], stop_sequence=["\n"], version=0, + sample_fields=record_to_sample, + solver=[multiple_choice(cache=True)], + scorer=choice(), ) conlang_translation = LightevalTaskConfig( @@ -461,6 +516,9 @@ def bigbench_prompt(line, task_name: str = None): metrics=[Metrics.loglikelihood_acc], stop_sequence=["\n"], version=0, + sample_fields=record_to_sample, + solver=[multiple_choice(cache=True)], + scorer=choice(), ) crass_ai = LightevalTaskConfig( @@ -476,6 +534,9 @@ def bigbench_prompt(line, task_name: str = None): metrics=[Metrics.loglikelihood_acc], stop_sequence=["\n"], version=0, + sample_fields=record_to_sample, + solver=[multiple_choice(cache=True)], + scorer=choice(), ) cryobiology_spanish = LightevalTaskConfig( @@ -491,6 +552,9 @@ def bigbench_prompt(line, task_name: str = None): metrics=[Metrics.loglikelihood_acc], stop_sequence=["\n"], version=0, + sample_fields=record_to_sample, + solver=[multiple_choice(cache=True)], + scorer=choice(), ) cryptonite = LightevalTaskConfig( @@ -521,6 +585,9 @@ def bigbench_prompt(line, task_name: str = None): metrics=[Metrics.loglikelihood_acc], stop_sequence=["\n"], version=0, + sample_fields=record_to_sample, + solver=[multiple_choice(cache=True)], + scorer=choice(), ) dark_humor_detection = LightevalTaskConfig( @@ -536,6 +603,9 @@ def bigbench_prompt(line, task_name: str = None): metrics=[Metrics.loglikelihood_acc], stop_sequence=["\n"], version=0, + sample_fields=record_to_sample, + solver=[multiple_choice(cache=True)], + scorer=choice(), ) date_understanding = LightevalTaskConfig( @@ -551,6 +621,9 @@ def bigbench_prompt(line, task_name: str = None): metrics=[Metrics.loglikelihood_acc], stop_sequence=["\n"], version=0, + sample_fields=record_to_sample, + solver=[multiple_choice(cache=True)], + scorer=choice(), ) disambiguation_qa = LightevalTaskConfig( @@ -566,6 +639,9 @@ def bigbench_prompt(line, task_name: str = None): metrics=[Metrics.loglikelihood_acc], stop_sequence=["\n"], version=0, + sample_fields=record_to_sample, + solver=[multiple_choice(cache=True)], + scorer=choice(), ) discourse_marker_prediction = LightevalTaskConfig( @@ -581,6 +657,9 @@ def bigbench_prompt(line, task_name: str = None): metrics=[Metrics.loglikelihood_acc], stop_sequence=["\n"], version=0, + sample_fields=record_to_sample, + solver=[multiple_choice(cache=True)], + scorer=choice(), ) disfl_qa = LightevalTaskConfig( @@ -611,6 +690,9 @@ def bigbench_prompt(line, task_name: str = None): metrics=[Metrics.loglikelihood_acc], stop_sequence=["\n"], version=0, + sample_fields=record_to_sample, + 
solver=[multiple_choice(cache=True)], + scorer=choice(), ) elementary_math_qa = LightevalTaskConfig( @@ -626,6 +708,9 @@ def bigbench_prompt(line, task_name: str = None): metrics=[Metrics.loglikelihood_acc], stop_sequence=["\n"], version=0, + sample_fields=record_to_sample, + solver=[multiple_choice(cache=True)], + scorer=choice(), ) emoji_movie = LightevalTaskConfig( @@ -661,6 +746,9 @@ def bigbench_prompt(line, task_name: str = None): metrics=[Metrics.loglikelihood_acc], stop_sequence=["\n"], version=0, + sample_fields=record_to_sample, + solver=[multiple_choice(cache=True)], + scorer=choice(), ) empirical_judgments = LightevalTaskConfig( @@ -676,6 +764,9 @@ def bigbench_prompt(line, task_name: str = None): metrics=[Metrics.loglikelihood_acc], stop_sequence=["\n"], version=0, + sample_fields=record_to_sample, + solver=[multiple_choice(cache=True)], + scorer=choice(), ) english_proverbs = LightevalTaskConfig( @@ -691,6 +782,9 @@ def bigbench_prompt(line, task_name: str = None): metrics=[Metrics.loglikelihood_acc], stop_sequence=["\n"], version=0, + sample_fields=record_to_sample, + solver=[multiple_choice(cache=True)], + scorer=choice(), ) english_russian_proverbs = LightevalTaskConfig( @@ -706,6 +800,9 @@ def bigbench_prompt(line, task_name: str = None): metrics=[Metrics.loglikelihood_acc], stop_sequence=["\n"], version=0, + sample_fields=record_to_sample, + solver=[multiple_choice(cache=True)], + scorer=choice(), ) entailed_polarity = LightevalTaskConfig( @@ -856,6 +953,9 @@ def bigbench_prompt(line, task_name: str = None): metrics=[Metrics.bleu, Metrics.rouge_t5], stop_sequence=["\n"], version=0, + sample_fields=record_to_sample, + solver=[multiple_choice(cache=True)], + scorer=choice(), ) gender_inclusive_sentences_german = LightevalTaskConfig( @@ -871,6 +971,9 @@ def bigbench_prompt(line, task_name: str = None): metrics=[Metrics.exact_match(sample_params={"strip_strings": False})], stop_sequence=["\n"], version=0, + sample_fields=record_to_sample, + solver=[multiple_choice(cache=True)], + scorer=choice(), ) general_knowledge = LightevalTaskConfig( @@ -886,6 +989,9 @@ def bigbench_prompt(line, task_name: str = None): metrics=[Metrics.loglikelihood_acc], stop_sequence=["\n"], version=0, + sample_fields=record_to_sample, + solver=[multiple_choice(cache=True)], + scorer=choice(), ) geometric_shapes = LightevalTaskConfig( @@ -1266,6 +1372,9 @@ def bigbench_prompt(line, task_name: str = None): metrics=[Metrics.exact_match(sample_params={"strip_strings": False})], stop_sequence=["\n"], version=0, + sample_fields=record_to_sample, + solver=[multiple_choice(cache=True)], + scorer=choice(), ) linguistics_puzzles = LightevalTaskConfig( @@ -1296,6 +1405,9 @@ def bigbench_prompt(line, task_name: str = None): metrics=[Metrics.loglikelihood_acc], stop_sequence=["\n"], version=0, + sample_fields=record_to_sample, + solver=[multiple_choice(cache=True)], + scorer=choice(), ) logical_args = LightevalTaskConfig( @@ -1386,6 +1498,9 @@ def bigbench_prompt(line, task_name: str = None): metrics=[Metrics.exact_match(sample_params={"strip_strings": False})], stop_sequence=["\n"], version=0, + sample_fields=record_to_sample, + solver=[multiple_choice(cache=True)], + scorer=choice(), ) metaphor_boolean = LightevalTaskConfig( @@ -1491,6 +1606,9 @@ def bigbench_prompt(line, task_name: str = None): metrics=[Metrics.exact_match(sample_params={"strip_strings": False})], stop_sequence=["\n"], version=0, + sample_fields=record_to_sample, + solver=[multiple_choice(cache=True)], + scorer=choice(), ) moral_permissibility 
= LightevalTaskConfig( @@ -1551,6 +1669,9 @@ def bigbench_prompt(line, task_name: str = None): metrics=[Metrics.exact_match(sample_params={"strip_strings": False})], stop_sequence=["\n"], version=0, + sample_fields=record_to_sample, + solver=[multiple_choice(cache=True)], + scorer=choice(), ) navigate = LightevalTaskConfig( @@ -1566,6 +1687,9 @@ def bigbench_prompt(line, task_name: str = None): metrics=[Metrics.loglikelihood_acc], stop_sequence=["\n"], version=0, + sample_fields=record_to_sample, + solver=[multiple_choice(cache=True)], + scorer=choice(), ) nonsense_words_grammar = LightevalTaskConfig( @@ -2071,6 +2195,9 @@ def bigbench_prompt(line, task_name: str = None): metrics=[Metrics.bleu, Metrics.rouge_t5, Metrics.loglikelihood_acc], stop_sequence=["\n"], version=0, + sample_fields=record_to_sample, + solver=[multiple_choice(cache=True)], + scorer=choice(), ) simp_turing_concept = LightevalTaskConfig( @@ -2221,6 +2348,9 @@ def bigbench_prompt(line, task_name: str = None): metrics=[Metrics.f1_score_macro], stop_sequence=["\n"], version=0, + sample_fields=record_to_sample, + solver=[multiple_choice(cache=True)], + scorer=choice(), ) sports_understanding = LightevalTaskConfig( @@ -2416,6 +2546,9 @@ def bigbench_prompt(line, task_name: str = None): metrics=[Metrics.bleu, Metrics.rouge_t5, Metrics.loglikelihood_acc, Metrics.bleurt], stop_sequence=["\n"], version=0, + sample_fields=record_to_sample, + solver=[multiple_choice(cache=True)], + scorer=choice(), ) tracking_shuffled_objects = LightevalTaskConfig( diff --git a/src/lighteval/tasks/tasks/bigbench_hard.py b/src/lighteval/tasks/tasks/bigbench_hard.py index 4b930d1b9..18ebef8c0 100644 --- a/src/lighteval/tasks/tasks/bigbench_hard.py +++ b/src/lighteval/tasks/tasks/bigbench_hard.py @@ -17,6 +17,10 @@ from string import ascii_uppercase +from inspect_ai.dataset import Sample +from inspect_ai.scorer import choice +from inspect_ai.solver import multiple_choice + from lighteval.metrics.metrics import Metrics from lighteval.tasks.lighteval_task import LightevalTaskConfig from lighteval.tasks.requests import Doc @@ -41,6 +45,14 @@ def bbh_prompt(line, task_name: str = None): ) +def record_to_sample(record): + query = f"{record.get('task_prefix', '')}\n{record['input']}" + target = ascii_uppercase[record["target_idx"]] + choices = record["choices"] + + return Sample(input=query, target=target, choices=choices) + + causal_judgment = LightevalTaskConfig( name="bigbench_hard:causal_judgment", prompt_function=bbh_prompt, @@ -54,6 +66,9 @@ def bbh_prompt(line, task_name: str = None): metrics=[Metrics.loglikelihood_acc], stop_sequence=["", "Q=", "\n\n"], version=0, + sample_fields=record_to_sample, + solver=[multiple_choice(cache=True)], + scorer=choice(), ) date_understanding = LightevalTaskConfig( @@ -69,6 +84,9 @@ def bbh_prompt(line, task_name: str = None): metrics=[Metrics.loglikelihood_acc], stop_sequence=["", "Q=", "\n\n"], version=0, + sample_fields=record_to_sample, + solver=[multiple_choice(cache=True)], + scorer=choice(), ) disambiguation_qa = LightevalTaskConfig( @@ -84,6 +102,9 @@ def bbh_prompt(line, task_name: str = None): metrics=[Metrics.loglikelihood_acc], stop_sequence=["", "Q=", "\n\n"], version=0, + sample_fields=record_to_sample, + solver=[multiple_choice(cache=True)], + scorer=choice(), ) geometric_shapes = LightevalTaskConfig( @@ -99,6 +120,9 @@ def bbh_prompt(line, task_name: str = None): metrics=[Metrics.loglikelihood_acc], stop_sequence=["", "Q=", "\n\n"], version=0, + sample_fields=record_to_sample, + 
solver=[multiple_choice(cache=True)], + scorer=choice(), ) logical_deduction_five_objects = LightevalTaskConfig( @@ -114,6 +138,9 @@ def bbh_prompt(line, task_name: str = None): metrics=[Metrics.loglikelihood_acc], stop_sequence=["", "Q=", "\n\n"], version=0, + sample_fields=record_to_sample, + solver=[multiple_choice(cache=True)], + scorer=choice(), ) logical_deduction_seven_objects = LightevalTaskConfig( @@ -129,6 +156,9 @@ def bbh_prompt(line, task_name: str = None): metrics=[Metrics.loglikelihood_acc], stop_sequence=["", "Q=", "\n\n"], version=0, + sample_fields=record_to_sample, + solver=[multiple_choice(cache=True)], + scorer=choice(), ) logical_deduction_three_objects = LightevalTaskConfig( @@ -144,6 +174,9 @@ def bbh_prompt(line, task_name: str = None): metrics=[Metrics.loglikelihood_acc], stop_sequence=["", "Q=", "\n\n"], version=0, + sample_fields=record_to_sample, + solver=[multiple_choice(cache=True)], + scorer=choice(), ) movie_recommendation = LightevalTaskConfig( @@ -159,6 +192,9 @@ def bbh_prompt(line, task_name: str = None): metrics=[Metrics.loglikelihood_acc], stop_sequence=["", "Q=", "\n\n"], version=0, + sample_fields=record_to_sample, + solver=[multiple_choice(cache=True)], + scorer=choice(), ) navigate = LightevalTaskConfig( @@ -174,6 +210,9 @@ def bbh_prompt(line, task_name: str = None): metrics=[Metrics.loglikelihood_acc], stop_sequence=["", "Q=", "\n\n"], version=0, + sample_fields=record_to_sample, + solver=[multiple_choice(cache=True)], + scorer=choice(), ) reasoning_about_colored_objects = LightevalTaskConfig( @@ -189,6 +228,9 @@ def bbh_prompt(line, task_name: str = None): metrics=[Metrics.loglikelihood_acc], stop_sequence=["", "Q=", "\n\n"], version=0, + sample_fields=record_to_sample, + solver=[multiple_choice(cache=True)], + scorer=choice(), ) ruin_names = LightevalTaskConfig( @@ -204,6 +246,9 @@ def bbh_prompt(line, task_name: str = None): metrics=[Metrics.loglikelihood_acc], stop_sequence=["", "Q=", "\n\n"], version=0, + sample_fields=record_to_sample, + solver=[multiple_choice(cache=True)], + scorer=choice(), ) salient_translation_error_detection = LightevalTaskConfig( @@ -219,6 +264,9 @@ def bbh_prompt(line, task_name: str = None): metrics=[Metrics.loglikelihood_acc], stop_sequence=["", "Q=", "\n\n"], version=0, + sample_fields=record_to_sample, + solver=[multiple_choice(cache=True)], + scorer=choice(), ) snarks = LightevalTaskConfig( @@ -234,6 +282,9 @@ def bbh_prompt(line, task_name: str = None): metrics=[Metrics.loglikelihood_acc], stop_sequence=["", "Q=", "\n\n"], version=0, + sample_fields=record_to_sample, + solver=[multiple_choice(cache=True)], + scorer=choice(), ) sports_understanding = LightevalTaskConfig( @@ -249,6 +300,9 @@ def bbh_prompt(line, task_name: str = None): metrics=[Metrics.loglikelihood_acc], stop_sequence=["", "Q=", "\n\n"], version=0, + sample_fields=record_to_sample, + solver=[multiple_choice(cache=True)], + scorer=choice(), ) temporal_sequences = LightevalTaskConfig( @@ -264,6 +318,9 @@ def bbh_prompt(line, task_name: str = None): metrics=[Metrics.loglikelihood_acc], stop_sequence=["", "Q=", "\n\n"], version=0, + sample_fields=record_to_sample, + solver=[multiple_choice(cache=True)], + scorer=choice(), ) tracking_shuffled_objects_five_objects = LightevalTaskConfig( @@ -279,6 +336,9 @@ def bbh_prompt(line, task_name: str = None): metrics=[Metrics.loglikelihood_acc], stop_sequence=["", "Q=", "\n\n"], version=0, + sample_fields=record_to_sample, + solver=[multiple_choice(cache=True)], + scorer=choice(), ) 
tracking_shuffled_objects_seven_objects = LightevalTaskConfig( @@ -294,6 +354,9 @@ def bbh_prompt(line, task_name: str = None): metrics=[Metrics.loglikelihood_acc], stop_sequence=["", "Q=", "\n\n"], version=0, + sample_fields=record_to_sample, + solver=[multiple_choice(cache=True)], + scorer=choice(), ) tracking_shuffled_objects_three_objects = LightevalTaskConfig( @@ -309,6 +372,9 @@ def bbh_prompt(line, task_name: str = None): metrics=[Metrics.loglikelihood_acc], stop_sequence=["", "Q=", "\n\n"], version=0, + sample_fields=record_to_sample, + solver=[multiple_choice(cache=True)], + scorer=choice(), ) TASKS_TABLE = [ diff --git a/src/lighteval/tasks/tasks/blimp.py b/src/lighteval/tasks/tasks/blimp.py index 1b9278a32..08dc09d65 100644 --- a/src/lighteval/tasks/tasks/blimp.py +++ b/src/lighteval/tasks/tasks/blimp.py @@ -27,6 +27,9 @@ from lighteval.tasks.requests import Doc +# TODO: Convert to inspect-ai + + def blimp_prompt(line, task_name: str = None): return Doc(task_name=task_name, query="", choices=[line["sentence_good"], line["sentence_bad"]], gold_index=0) diff --git a/src/lighteval/tasks/tasks/bold.py b/src/lighteval/tasks/tasks/bold.py index 2ecc52c05..ad4dd7b15 100644 --- a/src/lighteval/tasks/tasks/bold.py +++ b/src/lighteval/tasks/tasks/bold.py @@ -19,6 +19,10 @@ https://dl.acm.org/doi/10.1145/3442188.3445924 """ +from inspect_ai.dataset import Sample +from inspect_ai.scorer import exact +from inspect_ai.solver import generate + from lighteval.metrics.metrics import Metrics from lighteval.tasks.lighteval_task import LightevalTaskConfig from lighteval.tasks.requests import Doc @@ -28,6 +32,12 @@ def bold_prompt(line, task_name: str = None): return Doc(task_name=task_name, query=line["text"], choices=None, gold_index=None) +def record_to_sample(record): + query = record["text"] + target = "" + return Sample(input=query, target=target) + + bold = LightevalTaskConfig( name="bold", prompt_function=bold_prompt, @@ -41,6 +51,9 @@ def bold_prompt(line, task_name: str = None): metrics=[Metrics.prediction_perplexity], stop_sequence=["\n"], version=0, + sample_fields=record_to_sample, + solver=[generate(cache=True)], + scorer=exact(), ) bold_gender = LightevalTaskConfig( @@ -56,6 +69,9 @@ def bold_prompt(line, task_name: str = None): metrics=[Metrics.prediction_perplexity], stop_sequence=["\n"], version=0, + sample_fields=record_to_sample, + solver=[generate(cache=True)], + scorer=exact(), ) bold_political_ideology = LightevalTaskConfig( @@ -71,6 +87,9 @@ def bold_prompt(line, task_name: str = None): metrics=[Metrics.prediction_perplexity], stop_sequence=["\n"], version=0, + sample_fields=record_to_sample, + solver=[generate(cache=True)], + scorer=exact(), ) bold_profession = LightevalTaskConfig( @@ -86,6 +105,9 @@ def bold_prompt(line, task_name: str = None): metrics=[Metrics.prediction_perplexity], stop_sequence=["\n"], version=0, + sample_fields=record_to_sample, + solver=[generate(cache=True)], + scorer=exact(), ) bold_race = LightevalTaskConfig( @@ -101,6 +123,9 @@ def bold_prompt(line, task_name: str = None): metrics=[Metrics.prediction_perplexity], stop_sequence=["\n"], version=0, + sample_fields=record_to_sample, + solver=[generate(cache=True)], + scorer=exact(), ) bold_religious_ideology = LightevalTaskConfig( @@ -116,6 +141,9 @@ def bold_prompt(line, task_name: str = None): metrics=[Metrics.prediction_perplexity], stop_sequence=["\n"], version=0, + sample_fields=record_to_sample, + solver=[generate(cache=True)], + scorer=exact(), ) TASKS_TABLE = [ diff --git 
a/src/lighteval/tasks/tasks/boolq.py b/src/lighteval/tasks/tasks/boolq.py index df854b9e6..f927f9491 100644 --- a/src/lighteval/tasks/tasks/boolq.py +++ b/src/lighteval/tasks/tasks/boolq.py @@ -18,6 +18,12 @@ https://arxiv.org/abs/1905.11946 """ +from string import ascii_uppercase + +from inspect_ai.dataset import Sample +from inspect_ai.scorer import choice +from inspect_ai.solver import multiple_choice + from lighteval.metrics.metrics import Metrics from lighteval.tasks.lighteval_task import LightevalTaskConfig from lighteval.tasks.requests import Doc @@ -48,6 +54,24 @@ def boolq_contrastset_prompt(line, task_name: str = None): ][0] +def record_to_sample(record): + choices = ["Yes", "No"] + query = f"{record['passage']}\n{record['question']}" + target = ascii_uppercase[choices.index(record["answer"])] + return Sample(input=query, target=target, choices=choices) + + +def record_to_sample_contrastset(record): + if record["contrast_inputs"] in [None, ""]: + return record_to_sample(record) + + choices = ["Yes", "No"] + query = f"{record['contrast_inputs']['passage']}\n{record['contrast_inputs']['question']}" + target = ascii_uppercase[choices.index(record["answer"])] + + return Sample(input=query, target=target, choices=choices) + + boolq = LightevalTaskConfig( name="boolq", prompt_function=boolq_prompt, @@ -63,6 +87,9 @@ def boolq_contrastset_prompt(line, task_name: str = None): ], stop_sequence=["\n"], version=0, + sample_fields=record_to_sample, + solver=[multiple_choice(cache=True)], + scorer=choice(), ) @@ -81,6 +108,9 @@ def boolq_contrastset_prompt(line, task_name: str = None): ], stop_sequence=["\n"], version=0, + sample_fields=record_to_sample_contrastset, + solver=[multiple_choice(cache=True)], + scorer=choice(), ) TASKS_TABLE = [ diff --git a/src/lighteval/tasks/tasks/civil_comments.py b/src/lighteval/tasks/tasks/civil_comments.py index 79f96ff8e..6c8fe51ec 100644 --- a/src/lighteval/tasks/tasks/civil_comments.py +++ b/src/lighteval/tasks/tasks/civil_comments.py @@ -18,6 +18,12 @@ https://arxiv.org/abs/1903.04561 """ +from string import ascii_uppercase + +from inspect_ai.dataset import Sample +from inspect_ai.scorer import choice +from inspect_ai.solver import multiple_choice + from lighteval.metrics.metrics import Metrics from lighteval.tasks.lighteval_task import LightevalTaskConfig from lighteval.tasks.requests import Doc @@ -32,20 +38,12 @@ def civil_comments_prompt(line, task_name: str = None): ) -civil_comments = LightevalTaskConfig( - name="civil_comments", - prompt_function=civil_comments_prompt, - hf_repo="lighteval/civil_comments_helm", - hf_subset="all", - hf_avail_splits=["train", "test"], - evaluation_splits=["test"], - few_shots_split=None, - few_shots_select=None, - generation_size=100, - metrics=[Metrics.exact_match], - stop_sequence=["\n"], - version=0, -) +def record_to_sample(record): + query = f"{record['text']}\nIs the passage above toxic?" 
+ choices = ["True", "False"] + target = ascii_uppercase[choices.index(record["gold"])] + return Sample(input=query, target=target, choices=choices) + civil_comments_LGBTQ = LightevalTaskConfig( name="civil_comments:LGBTQ", @@ -60,6 +58,9 @@ def civil_comments_prompt(line, task_name: str = None): metrics=[Metrics.exact_match], stop_sequence=["\n"], version=0, + sample_fields=record_to_sample, + solver=[multiple_choice(cache=True)], + scorer=choice(), ) civil_comments_black = LightevalTaskConfig( @@ -75,6 +76,9 @@ def civil_comments_prompt(line, task_name: str = None): metrics=[Metrics.exact_match], stop_sequence=["\n"], version=0, + sample_fields=record_to_sample, + solver=[multiple_choice(cache=True)], + scorer=choice(), ) civil_comments_christian = LightevalTaskConfig( @@ -90,6 +94,9 @@ def civil_comments_prompt(line, task_name: str = None): metrics=[Metrics.exact_match], stop_sequence=["\n"], version=0, + sample_fields=record_to_sample, + solver=[multiple_choice(cache=True)], + scorer=choice(), ) civil_comments_female = LightevalTaskConfig( @@ -105,6 +112,9 @@ def civil_comments_prompt(line, task_name: str = None): metrics=[Metrics.exact_match], stop_sequence=["\n"], version=0, + sample_fields=record_to_sample, + solver=[multiple_choice(cache=True)], + scorer=choice(), ) civil_comments_male = LightevalTaskConfig( @@ -120,6 +130,9 @@ def civil_comments_prompt(line, task_name: str = None): metrics=[Metrics.exact_match], stop_sequence=["\n"], version=0, + sample_fields=record_to_sample, + solver=[multiple_choice(cache=True)], + scorer=choice(), ) civil_comments_muslim = LightevalTaskConfig( @@ -135,6 +148,9 @@ def civil_comments_prompt(line, task_name: str = None): metrics=[Metrics.exact_match], stop_sequence=["\n"], version=0, + sample_fields=record_to_sample, + solver=[multiple_choice(cache=True)], + scorer=choice(), ) civil_comments_other_religions = LightevalTaskConfig( @@ -150,6 +166,9 @@ def civil_comments_prompt(line, task_name: str = None): metrics=[Metrics.exact_match], stop_sequence=["\n"], version=0, + sample_fields=record_to_sample, + solver=[multiple_choice(cache=True)], + scorer=choice(), ) civil_comments_white = LightevalTaskConfig( @@ -165,10 +184,12 @@ def civil_comments_prompt(line, task_name: str = None): metrics=[Metrics.exact_match], stop_sequence=["\n"], version=0, + sample_fields=record_to_sample, + solver=[multiple_choice(cache=True)], + scorer=choice(), ) TASKS_TABLE = [ - civil_comments, civil_comments_LGBTQ, civil_comments_black, civil_comments_christian, diff --git a/src/lighteval/tasks/tasks/commonsenseqa.py b/src/lighteval/tasks/tasks/commonsenseqa.py index 639be22d0..59a1e6cdb 100644 --- a/src/lighteval/tasks/tasks/commonsenseqa.py +++ b/src/lighteval/tasks/tasks/commonsenseqa.py @@ -25,6 +25,10 @@ from string import ascii_uppercase +from inspect_ai.dataset import Sample +from inspect_ai.scorer import choice +from inspect_ai.solver import multiple_choice + from lighteval.metrics.metrics import Metrics from lighteval.tasks.lighteval_task import LightevalTaskConfig from lighteval.tasks.requests import Doc @@ -46,6 +50,13 @@ def commonsenseqa_prompt(line, task_name: str = None): ) +def record_to_sample(record): + query = record["question"] + choices = record["choices"]["text"] + target = record["answerKey"] + return Sample(input=query, target=target, choices=choices) + + commonsenseqa = LightevalTaskConfig( name="commonsenseqa", prompt_function=commonsenseqa_prompt, @@ -59,6 +70,9 @@ def commonsenseqa_prompt(line, task_name: str = None): 
     metrics=[Metrics.exact_match],
     stop_sequence=["\n"],
     version=0,
+    sample_fields=record_to_sample,
+    solver=[multiple_choice(cache=True)],
+    scorer=choice(),
 )

 TASKS_TABLE = [
diff --git a/src/lighteval/tasks/tasks/covid_dialogue.py b/src/lighteval/tasks/tasks/covid_dialogue.py
index 3446ac17d..4275dcaf8 100644
--- a/src/lighteval/tasks/tasks/covid_dialogue.py
+++ b/src/lighteval/tasks/tasks/covid_dialogue.py
@@ -19,11 +19,18 @@ https://arxiv.org/abs/2004.06561
 """

+from inspect_ai.dataset import Sample
+from inspect_ai.scorer import model_graded_fact
+from inspect_ai.solver import generate, system_message
+
 from lighteval.metrics.metrics import Metrics
 from lighteval.tasks.lighteval_task import LightevalTaskConfig
 from lighteval.tasks.requests import Doc


+PROMPT = "Generate a response given a patient's questions and concerns."
+
+
 def covid_dialogue_prompt(line, task_name: str = None):
     return Doc(
         task_name=task_name,
@@ -34,6 +41,12 @@ def covid_dialogue_prompt(line, task_name: str = None):
     )


+def record_to_sample(record):
+    query = record["query"]
+    target = record["answer"]
+    return Sample(input=query, target=target)
+
+
 covid_dialogue = LightevalTaskConfig(
     name="covid_dialogue",
     prompt_function=covid_dialogue_prompt,
@@ -47,6 +60,9 @@ def covid_dialogue_prompt(line, task_name: str = None):
     metrics=[Metrics.exact_match],
     stop_sequence=["\n"],
     version=0,
+    sample_fields=record_to_sample,
+    solver=[system_message(PROMPT), generate(cache=True)],
+    scorer=model_graded_fact(),
 )

 TASKS_TABLE = [
diff --git a/src/lighteval/tasks/tasks/ifbench/main.py b/src/lighteval/tasks/tasks/ifbench/main.py
index db6aa9820..45fb76820 100644
--- a/src/lighteval/tasks/tasks/ifbench/main.py
+++ b/src/lighteval/tasks/tasks/ifbench/main.py
@@ -19,12 +19,10 @@
 """

 import numpy as np
-from aenum import extend_enum
 from inspect_ai.dataset import Sample
 from inspect_ai.scorer import Score, Target, accuracy, scorer, stderr
 from inspect_ai.solver import TaskState, generate

-from lighteval.metrics.metrics import Metrics
 from lighteval.metrics.metrics_sample import SampleLevelComputation
 from lighteval.metrics.utils.metric_utils import (
     SampleLevelMetricGrouping,
@@ -183,5 +181,3 @@ async def score(state: TaskState, target: Target):
 )

 TASKS_TABLE = [ifbench_test, ifbench_multiturn]
-
-extend_enum(Metrics, "ifbench_metric", ifbench_metrics)
diff --git a/src/lighteval/tasks/tasks/math_500.py b/src/lighteval/tasks/tasks/math_500.py
index 5c075121d..e84f633c5 100644
--- a/src/lighteval/tasks/tasks/math_500.py
+++ b/src/lighteval/tasks/tasks/math_500.py
@@ -19,19 +19,25 @@ https://arxiv.org/abs/2305.20050
 """

+from inspect_ai.dataset import Sample
+from inspect_ai.scorer import model_graded_fact
+from inspect_ai.solver import generate, prompt_template
+
 from lighteval.metrics.metrics import Metrics
 from lighteval.tasks.lighteval_task import LightevalTaskConfig
 from lighteval.tasks.requests import Doc


-def math_500_prompt(line, task_name: str = None):
-    MATH_QUERY_TEMPLATE = """
+MATH_QUERY_TEMPLATE = """
 Solve the following problem. The final line of your response MUST be of the following format:
 "ANSWER: $ANSWER" (without quotes) where $ANSWER is the final answer.
 Think step by step before answering.

-{Question}
+{prompt}
 """.strip()
-    query = MATH_QUERY_TEMPLATE.format(Question=line["problem"])
+
+
+def math_500_prompt(line, task_name: str = None):
+    query = MATH_QUERY_TEMPLATE.format(prompt=line["problem"])
     return Doc(
         task_name=task_name,
         query=query,
@@ -40,6 +46,12 @@ def math_500_prompt(line, task_name: str = None):
     )


+def record_to_sample(record):
+    query = record["problem"]
+    target = record["answer"]
+    return Sample(input=query, target=target)
+
+
 math_500 = LightevalTaskConfig(
     name="math_500",
     prompt_function=math_500_prompt,
@@ -54,6 +66,9 @@ def math_500_prompt(line, task_name: str = None):
         Metrics.pass_at_k_math(sample_params={"k": 1, "n": 1}),
     ],
     version=2,
+    sample_fields=record_to_sample,
+    solver=[prompt_template(MATH_QUERY_TEMPLATE), generate(cache=True)],
+    scorer=model_graded_fact(),
 )

 TASKS_TABLE = [
diff --git a/src/lighteval/tasks/tasks/mix_eval/main.py b/src/lighteval/tasks/tasks/mix_eval/main.py
index 73a9c00c5..e842fee9d 100644
--- a/src/lighteval/tasks/tasks/mix_eval/main.py
+++ b/src/lighteval/tasks/tasks/mix_eval/main.py
@@ -24,8 +24,12 @@
 import logging
 import re
+from string import ascii_uppercase

 import numpy as np
+from inspect_ai.dataset import Sample
+from inspect_ai.scorer import choice, model_graded_fact
+from inspect_ai.solver import generate, multiple_choice

 from lighteval.metrics.metrics_sample import JudgeLLMMixEval
 from lighteval.metrics.utils.metric_utils import SampleLevelMetricGrouping
@@ -76,6 +80,19 @@ def mixeval_multichoice_prompt(line, task_name: str = ""):
     )


+def record_to_sample_freeform(record):
+    query = record["prompt"]
+    target = record["target"][0]
+    return Sample(input=query, target=target)
+
+
+def record_to_sample_multichoice(record):
+    query = record["prompt"]
+    choices = record["options"]
+    target = ascii_uppercase[int(record["target"][0])]
+    return Sample(input=query, target=target, choices=choices)
+
+
 def process_judge_response(x):
     try:
         search = re.search(r"\s(\d)\s", x)
@@ -190,6 +207,9 @@ def mean_dv_5(x):
     generation_size=100,
     stop_sequence=[],  # no stop sequence, will use eot token
     version="0.1",
+    sample_fields=record_to_sample_freeform,
+    solver=[generate(cache=True)],
+    scorer=model_graded_fact(),
 )


@@ -206,6 +226,9 @@ def mean_dv_5(x):
     generation_size=100,
     stop_sequence=[],  # no stop sequence, will use eot token
     version="0.1",
+    sample_fields=record_to_sample_multichoice,
+    solver=[multiple_choice(cache=True)],
+    scorer=choice(),
 )

 mixeval_freeform_hard = LightevalTaskConfig(
@@ -221,6 +244,9 @@ def mean_dv_5(x):
     generation_size=100,
     stop_sequence=[],  # no stop sequence, will use eot token
     version="0.1",
+    sample_fields=record_to_sample_freeform,
+    solver=[generate(cache=True)],
+    scorer=model_graded_fact(),
 )


@@ -237,6 +263,9 @@ def mean_dv_5(x):
     generation_size=100,
     stop_sequence=[],  # no stop sequence, will use eot token
     version="0.1",
+    sample_fields=record_to_sample_multichoice,
+    solver=[multiple_choice(cache=True)],
+    scorer=choice(),
 )


diff --git a/src/lighteval/tasks/tasks/musr.py b/src/lighteval/tasks/tasks/musr.py
index fa2671e2d..d0054838d 100644
--- a/src/lighteval/tasks/tasks/musr.py
+++ b/src/lighteval/tasks/tasks/musr.py
@@ -21,6 +21,11 @@
 """

 import ast
+from string import ascii_uppercase
+
+from inspect_ai.dataset import Sample
+from inspect_ai.scorer import choice
+from inspect_ai.solver import multiple_choice

 from lighteval.metrics.metrics import Metrics
 from lighteval.tasks.lighteval_task import LightevalTaskConfig
@@ -32,13 +37,20 @@ def musr_prompt(line, task_name: str = None):
     query = line["narrative"] + "\n\n"
     query += line["question"] + "\n\n"

-    for i, choice in enumerate(choices):
-        query += f"{i + 1} - {choice}\n"
+    for i, choice_ in enumerate(choices):
+        query += f"{i + 1} - {choice_}\n"

     query += "Answer:"
     return Doc(task_name=task_name, query=query, choices=choices, gold_index=line["answer_index"])


+def record_to_sample(record):
+    query = record["narrative"] + "\n\n" + record["question"]
+    choices = ast.literal_eval(record["choices"])
+    target = ascii_uppercase[record["answer_index"]]
+    return Sample(input=query, target=target, choices=choices)
+
+
 musr_murder_mysteries = LightevalTaskConfig(
     name="musr:murder_mysteries",
     prompt_function=musr_prompt,
@@ -52,6 +64,9 @@ def musr_prompt(line, task_name: str = None):
     metrics=[Metrics.loglikelihood_acc],
     stop_sequence=["\n"],
     version=0,
+    sample_fields=record_to_sample,
+    solver=[multiple_choice(cache=True)],
+    scorer=choice(),
 )


@@ -68,6 +83,9 @@ def musr_prompt(line, task_name: str = None):
     metrics=[Metrics.loglikelihood_acc],
     stop_sequence=["\n"],
     version=0,
+    sample_fields=record_to_sample,
+    solver=[multiple_choice(cache=True)],
+    scorer=choice(),
 )


@@ -84,6 +102,9 @@ def musr_prompt(line, task_name: str = None):
     metrics=[Metrics.loglikelihood_acc],
     stop_sequence=["\n"],
     version=0,
+    sample_fields=record_to_sample,
+    solver=[multiple_choice(cache=True)],
+    scorer=choice(),
 )

 TASKS_TABLE = [
diff --git a/src/lighteval/tasks/tasks/simpleqa.py b/src/lighteval/tasks/tasks/simpleqa.py
index 602ba9727..08fcc29f0 100644
--- a/src/lighteval/tasks/tasks/simpleqa.py
+++ b/src/lighteval/tasks/tasks/simpleqa.py
@@ -19,6 +19,10 @@ https://openai.com/index/introducing-simpleqa/
 """

+from inspect_ai.dataset import Sample
+from inspect_ai.scorer import model_graded_fact
+from inspect_ai.solver import generate
+
 from lighteval.metrics.metrics import Metrics
 from lighteval.tasks.lighteval_task import LightevalTaskConfig
 from lighteval.tasks.requests import Doc
@@ -38,6 +42,12 @@ def simpleqa_prompt(line, task_name: str = None):
     )


+def record_to_sample(record):
+    query = record["problem"]
+    target = record["answer"]
+    return Sample(input=query, target=target)
+
+
 simpleqa = LightevalTaskConfig(
     name="simpleqa",
     prompt_function=simpleqa_prompt,
@@ -51,6 +61,9 @@ def simpleqa_prompt(line, task_name: str = None):
     metrics=[Metrics.exact_match],
     stop_sequence=["\n"],
     version=0,
+    sample_fields=record_to_sample,
+    solver=[generate(cache=True)],
+    scorer=model_graded_fact(),
 )

 TASKS_TABLE = [