diff --git a/src/open_r1/rewards.py b/src/open_r1/rewards.py index d91c63ed2..11017995f 100644 --- a/src/open_r1/rewards.py +++ b/src/open_r1/rewards.py @@ -3,7 +3,9 @@ import json import math import re -from typing import Dict +from abc import ABC, abstractmethod +from concurrent.futures import ThreadPoolExecutor, as_completed +from typing import Dict, List from latex2sympy2_extended import NormalizationConfig from math_verify import LatexExtractionConfig, parse, verify @@ -18,20 +20,112 @@ load_dotenv() -def accuracy_reward(completions, solution, **kwargs): +class BaseRewardFunction(ABC): + """Reward function base class""" + + def __init__(self, max_workers: int = 1, *args, **kwargs) -> None: + self.max_workers = max_workers + super().__init__() + + @abstractmethod + def reward_on_single_completion(self, completion: str, **kwargs) -> float: + """Process a single completion and return its score + + Args: + completion: Single completion content to evaluate + + Returns: + float: Reward score for the completion + """ + raise NotImplementedError("reward_on_single_completion must be implemented by a subclass") + + def _single_thread_call(self, completions: List[Dict[str, str]], **kwargs) -> List[float]: + results = [] + for idx, completion in enumerate(completions): + # prepare per-completion kwargs + per_completion_kwargs = {} + for key, value in kwargs.items(): + if isinstance(value, list): + per_completion_kwargs[key] = value[idx] + else: + per_completion_kwargs[key] = value + results.append(self.reward_on_single_completion(completion, **per_completion_kwargs)) + return results + + def __call__(self, completions: List[Dict[str, str]], **kwargs) -> List[float]: + """Process and score multiple model completions in parallel. + + Args: + completions: List of model completions, where each completion is a list of message dictionaries whose first entry's 'content' key stores the completion text + **kwargs: Additional keyword arguments that can include: + - max_workers: Optional int, maximum number of parallel workers + - Any other arguments needed by reward_on_single_completion() + + Returns: + List[float]: A list of reward scores, one for each completion, + computed by reward_on_single_completion() + + Raises: + RuntimeError: If processing any completion fails + """ + completions = [completion[0]["content"] for completion in completions] + + if "max_workers" in kwargs: + max_workers = kwargs["max_workers"] + else: + max_workers = self.max_workers + + if max_workers == 1: + return self._single_thread_call(completions, **kwargs) + + with ThreadPoolExecutor(max_workers=max_workers) as executor: + future_to_idx = {} + for idx, completion in enumerate(completions): + # prepare per-completion kwargs + per_completion_kwargs = {} + for key, value in kwargs.items(): + if isinstance(value, list): + per_completion_kwargs[key] = value[idx] + else: + per_completion_kwargs[key] = value + + future = executor.submit(self.reward_on_single_completion, completion, **per_completion_kwargs) + future_to_idx[future] = idx + + results = [None] * len(completions) + for future in as_completed(future_to_idx): + idx = future_to_idx[future] + try: + results[idx] = future.result() + except Exception as e: + raise RuntimeError(f"Error processing completion {idx}: {e}") from e + + return results + + +class AccuracyReward(BaseRewardFunction): """Reward function that checks if the completion is the same as the ground truth.""" - contents = [completion[0]["content"] for completion in completions] - rewards = [] - for content, sol in zip(contents, solution): + + def reward_on_single_completion(self,
completion: str, solution: str, **kwargs) -> float: + """Process a single completion and return its score + + Args: + completion: Single completion content to evaluate + solution: Ground truth solution to compare against + + Returns: + float: 1.0 if the completion matches the ground truth, 0.0 otherwise + """ gold_parsed = parse( - sol, + solution, extraction_mode="first_match", extraction_config=[LatexExtractionConfig()], ) + if len(gold_parsed) != 0: # We require the answer to be provided in correct latex (no malformed operators) answer_parsed = parse( - content, + completion, extraction_config=[ LatexExtractionConfig( normalization_config=NormalizationConfig( @@ -54,21 +148,30 @@ def accuracy_reward(completions, solution, **kwargs): else: # If the gold solution is not parseable, we reward 1 to skip this example reward = 1.0 - print("Failed to parse gold solution: ", sol) - rewards.append(reward) + print("Failed to parse gold solution: ", solution) - return rewards + return reward -def format_reward(completions, **kwargs): +class FormatReward(BaseRewardFunction): """Reward function that checks if the reasoning process is enclosed within <think> and </think> tags, while the final answer is enclosed within <answer> and </answer> tags.""" - pattern = r"^<think>.*?</think>\s*<answer>.*?</answer>$" - completion_contents = [completion[0]["content"] for completion in completions] - matches = [re.match(pattern, content, re.DOTALL | re.MULTILINE) for content in completion_contents] - return [1.0 if match else 0.0 for match in matches] + + def reward_on_single_completion(self, completion: str, **kwargs) -> float: + """Process a single completion and return its score + + Args: + completion: Single completion content to evaluate + **kwargs: Additional arguments (unused) + + Returns: + float: 1.0 if the completion has correct format, 0.0 otherwise + """ + pattern = r"^<think>.*?</think>\s*<answer>.*?</answer>$" + match = re.match(pattern, completion, re.DOTALL | re.MULTILINE) + return 1.0 if match else 0.0 -def reasoning_steps_reward(completions, **kwargs): +class ReasoningStepsReward(BaseRewardFunction): r"""Reward function that checks for clear step-by-step reasoning. Regex pattern: Step \d+: - matches "Step 1:", "Step 2:", etc. @@ -77,118 +180,55 @@ def reasoning_steps_reward(completions, **kwargs): \n\* - matches bullet points with asterisks First,|Second,|Next,|Finally, - matches transition words """ - pattern = r"(Step \d+:|^\d+\.|\n-|\n\*|First,|Second,|Next,|Finally,)" - completion_contents = [completion[0]["content"] for completion in completions] - matches = [len(re.findall(pattern, content)) for content in completion_contents] - - # Magic nubmer 3 to encourage 3 steps and more, otherwise partial reward - return [min(1.0, count / 3) for count in matches] - - -def len_reward(completions: list[Dict[str, str]], solution: list[str], **kwargs) -> float: - """Compute length-based rewards to discourage overthinking and promote token efficiency. - - Taken from from the Kimi 1.5 tech report: https://arxiv.org/abs/2501.12599 - Args: - completions: List of model completions - solution: List of ground truth solutions + def reward_on_single_completion(self, completion: str, **kwargs) -> float: + """Process a single completion and return its score based on number of reasoning steps.
- Returns: - List of rewards where: - - For correct answers: reward = 0.5 - (len - min_len)/(max_len - min_len) - - For incorrect answers: reward = min(0, 0.5 - (len - min_len)/(max_len - min_len)) - """ - contents = [completion[0]["content"] for completion in completions] - - # First check correctness of answers - correctness = [] - for content, sol in zip(contents, solution): - gold_parsed = parse( - sol, - extraction_mode="first_match", - extraction_config=[LatexExtractionConfig()], - ) - if len(gold_parsed) == 0: - # Skip unparseable examples - correctness.append(True) # Treat as correct to avoid penalizing - print("Failed to parse gold solution: ", sol) - continue - - answer_parsed = parse( - content, - extraction_config=[ - LatexExtractionConfig( - normalization_config=NormalizationConfig( - nits=False, - malformed_operators=False, - basic_latex=True, - equations=True, - boxed=True, - units=True, - ), - boxed_match_priority=0, - try_extract_without_anchor=False, - ) - ], - extraction_mode="first_match", - ) - correctness.append(verify(answer_parsed, gold_parsed)) - - # Calculate lengths - lengths = [len(content) for content in contents] - min_len = min(lengths) - max_len = max(lengths) - - # If all responses have the same length, return zero rewards - if max_len == min_len: - return [0.0] * len(completions) - - rewards = [] - for length, is_correct in zip(lengths, correctness): - lambda_val = 0.5 - (length - min_len) / (max_len - min_len) + Args: + completion: Single completion content to evaluate + **kwargs: Additional arguments (unused) - if is_correct: - reward = lambda_val - else: - reward = min(0, lambda_val) + Returns: + float: Score between 0.0 and 1.0 based on number of reasoning steps found + """ + pattern = r"(Step \d+:|^\d+\.|\n-|\n\*|First,|Second,|Next,|Finally,)" + matches = len(re.findall(pattern, completion)) - rewards.append(float(reward)) + # Magic number 3 to encourage 3 steps and more, otherwise partial reward + return min(1.0, matches / 3) - return rewards +class LengthReward(BaseRewardFunction): + def reward_on_single_completion(self, completion: str, **kwargs) -> float: + raise NotImplementedError("LengthReward does not need to implement reward_on_single_completion") -def get_cosine_scaled_reward( - min_value_wrong: float = -1.0, - max_value_wrong: float = -0.5, - min_value_correct: float = 0.5, - max_value_correct: float = 1.0, - max_len: int = 1000, -): - def cosine_scaled_reward(completions, solution, **kwargs): - """Reward function that scales based on completion length using a cosine schedule. + def __call__(self, completions: List[Dict[str, str]], solution: list[str], **kwargs) -> List[float]: + """Compute length-based rewards to discourage overthinking and promote token efficiency. - Shorter correct solutions are rewarded more than longer ones. - Longer incorrect solutions are penalized less than shorter ones.
+ Taken from from the Kimi 1.5 tech report: https://arxiv.org/abs/2501.12599 Args: completions: List of model completions solution: List of ground truth solutions - This function is parameterized by the following arguments: - min_value_wrong: Minimum reward for wrong answers - max_value_wrong: Maximum reward for wrong answers - min_value_correct: Minimum reward for correct answers - max_value_correct: Maximum reward for correct answers - max_len: Maximum length for scaling + Returns: + List of rewards where: + - For correct answers: reward = 0.5 - (len - min_len)/(max_len - min_len) + - For incorrect answers: reward = min(0, 0.5 - (len - min_len)/(max_len - min_len)) """ contents = [completion[0]["content"] for completion in completions] - rewards = [] + # First check correctness of answers + correctness = [] for content, sol in zip(contents, solution): - gold_parsed = parse(sol, extraction_mode="first_match", extraction_config=[LatexExtractionConfig()]) + gold_parsed = parse( + sol, + extraction_mode="first_match", + extraction_config=[LatexExtractionConfig()], + ) if len(gold_parsed) == 0: - rewards.append(1.0) # Skip unparseable examples + # Skip unparseable examples + correctness.append(True) # Treat as correct to avoid penalizing print("Failed to parse gold solution: ", sol) continue @@ -210,167 +250,251 @@ def cosine_scaled_reward(completions, solution, **kwargs): ], extraction_mode="first_match", ) + correctness.append(verify(answer_parsed, gold_parsed)) - is_correct = verify(answer_parsed, gold_parsed) - gen_len = len(content) + # Calculate lengths + lengths = [len(content) for content in contents] + min_len = min(lengths) + max_len = max(lengths) - # Apply cosine scaling based on length - progress = gen_len / max_len - cosine = math.cos(progress * math.pi) + # If all responses have the same length, return zero rewards + if max_len == min_len: + return [0.0] * len(completions) + + rewards = [] + for length, is_correct in zip(lengths, correctness): + lambda_val = 0.5 - (length - min_len) / (max_len - min_len) if is_correct: - min_value = min_value_correct - max_value = max_value_correct + reward = lambda_val else: - # Swap min/max for incorrect answers - min_value = max_value_wrong - max_value = min_value_wrong + reward = min(0, lambda_val) - reward = min_value + 0.5 * (max_value - min_value) * (1.0 + cosine) rewards.append(float(reward)) return rewards - return cosine_scaled_reward + +class CosineScaledReward(BaseRewardFunction): + def __init__( + self, + min_value_wrong: float = -1.0, + max_value_wrong: float = -0.5, + min_value_correct: float = 0.5, + max_value_correct: float = 1.0, + max_len: int = 1000, + max_workers: int = 1, + ): + """Initialize CosineScaledReward. 
+ + Args: + min_value_wrong: Minimum reward for wrong answers + max_value_wrong: Maximum reward for wrong answers + min_value_correct: Minimum reward for correct answers + max_value_correct: Maximum reward for correct answers + max_len: Maximum length for scaling + """ + super().__init__(max_workers=max_workers) + self.min_value_wrong = min_value_wrong + self.max_value_wrong = max_value_wrong + self.min_value_correct = min_value_correct + self.max_value_correct = max_value_correct + self.max_len = max_len + + def reward_on_single_completion(self, completion: str, solution: str, **kwargs) -> float: + gold_parsed = parse(solution, extraction_mode="first_match", extraction_config=[LatexExtractionConfig()]) + if len(gold_parsed) == 0: + print("Failed to parse gold solution: ", solution) + return 1.0 # Skip unparseable examples + + answer_parsed = parse( + completion, + extraction_config=[ + LatexExtractionConfig( + normalization_config=NormalizationConfig( + nits=False, + malformed_operators=False, + basic_latex=True, + equations=True, + boxed=True, + units=True, + ), + boxed_match_priority=0, + try_extract_without_anchor=False, + ) + ], + extraction_mode="first_match", + ) + + is_correct = verify(answer_parsed, gold_parsed) + gen_len = len(completion) + + # Apply cosine scaling based on length + progress = gen_len / self.max_len + cosine = math.cos(progress * math.pi) + + if is_correct: + min_value = self.min_value_correct + max_value = self.max_value_correct + else: + # Swap min/max for incorrect answers + min_value = self.max_value_wrong + max_value = self.min_value_wrong + + reward = min_value + 0.5 * (max_value - min_value) * (1.0 + cosine) + return float(reward) -def get_repetition_penalty_reward(ngram_size: int, max_penalty: float): +class RepetitionPenaltyReward(BaseRewardFunction): """ Computes N-gram repetition penalty as described in Appendix C.2 of https://arxiv.org/abs/2502.03373. 
Reference implementation from: https://github.com/eddycmu/demystify-long-cot/blob/release/openrlhf/openrlhf/reward/repetition.py - - Args: - ngram_size: size of the n-grams - max_penalty: Maximum (negative) penalty for wrong answers """ - if max_penalty > 0: - raise ValueError(f"max_penalty {max_penalty} should not be positive") - def zipngram(text: str, ngram_size: int): + def __init__(self, ngram_size: int, max_penalty: float, max_workers: int = 1): + """ + Args: + ngram_size: size of the n-grams + max_penalty: Maximum (negative) penalty for wrong answers + """ + super().__init__(max_workers=max_workers) + if max_penalty > 0: + raise ValueError(f"max_penalty {max_penalty} should not be positive") + self.ngram_size = ngram_size + self.max_penalty = max_penalty + + def _zipngram(self, text: str, ngram_size: int): words = text.lower().split() return zip(*[words[i:] for i in range(ngram_size)]) - def repetition_penalty_reward(completions, **kwargs) -> float: + def reward_on_single_completion(self, completion: str, **kwargs) -> float: """ - reward function the penalizes repetitions - ref implementation: https://github.com/eddycmu/demystify-long-cot/blob/release/openrlhf/openrlhf/reward/repetition.py + Reward function that penalizes repetitions Args: - completions: List of model completions + completion: Model completion text """ + if completion == "": + return 0.0 + + if len(completion.split()) < self.ngram_size: + return 0.0 + + ngrams = set() + total = 0 + for ng in self._zipngram(completion, self.ngram_size): + ngrams.add(ng) + total += 1 + + scaling = 1 - len(ngrams) / total + reward = scaling * self.max_penalty + return float(reward) + + +class CodeReward(BaseRewardFunction): + """Reward function that evaluates code snippets using the E2B code interpreter.""" + + evaluation_script_template = """ + import subprocess + import json + + def evaluate_code(code, test_cases): + passed = 0 + total = len(test_cases) + exec_timeout = 5 + + for case in test_cases: + process = subprocess.run( + ["python3", "-c", code], + input=case["input"], + text=True, + capture_output=True, + timeout=exec_timeout + ) - contents = [completion[0]["content"] for completion in completions] - rewards = [] - for completion in contents: - if completion == "": - rewards.append(0.0) + if process.returncode != 0: # Error in execution continue - if len(completion.split()) < ngram_size: - rewards.append(0.0) - continue - - ngrams = set() - total = 0 - for ng in zipngram(completion, ngram_size): - ngrams.add(ng) - total += 1 - - scaling = 1 - len(ngrams) / total - reward = scaling * max_penalty - rewards.append(reward) - return rewards - - return repetition_penalty_reward + output = process.stdout.strip() + if output.strip() == case["output"].strip(): + passed += 1 -def extract_code(completion: str) -> str: - pattern = re.compile(r"```python\n(.*?)```", re.DOTALL) - matches = pattern.findall(completion) - extracted_answer = matches[-1] if len(matches) >= 1 else "" - return extracted_answer + success_rate = (passed / total) + return success_rate + code_snippet = {code} + test_cases = json.loads({test_cases}) -def code_reward(completions, **kwargs) -> list[float]: - """Reward function that evaluates code snippets using the E2B code interpreter. - - Assumes the dataset contains a `verification_info` column with test cases. + evaluate_code(code_snippet, test_cases) """ - if not is_e2b_available(): - raise ImportError( - "E2B is not available and required for this reward function. 
Please install E2B with " "`pip install e2b-code-interpreter` and add an API key to a `.env` file." ) - - rewards = [] - # TODO: add support for other languages in E2B: https://e2b.dev/docs/code-interpreting/supported-languages - try: - """Returns a reward function that evaluates code snippets in a sandbox.""" - evaluation_script_template = """ - import subprocess - import json - - def evaluate_code(code, test_cases): - passed = 0 - total = len(test_cases) - exec_timeout = 5 - - for case in test_cases: - process = subprocess.run( - ["python3", "-c", code], - input=case["input"], - text=True, - capture_output=True, - timeout=exec_timeout - ) - if process.returncode != 0: # Error in execution - continue + def __init__(self, max_workers: int = 1): + super().__init__(max_workers=max_workers) + if not is_e2b_available(): + raise ImportError( + "E2B is not available and required for this reward function. Please install E2B with " + "`pip install e2b-code-interpreter` and add an API key to a `.env` file." + ) - output = process.stdout.strip() - if output.strip() == case["output"].strip(): - passed += 1 + def _extract_code(self, completion: str) -> str: + pattern = re.compile(r"```python\n(.*?)```", re.DOTALL) + matches = pattern.findall(completion) + extracted_answer = matches[-1] if len(matches) >= 1 else "" + return extracted_answer - success_rate = (passed / total) - return success_rate + def reward_on_single_completion(self, completion: str, **kwargs) -> float: + """ + Evaluate a code snippet using test cases. - code_snippet = {code} - test_cases = json.loads({test_cases}) + Args: + completion: Single completion content to evaluate + **kwargs: Must contain 'verification_info' with test cases - evaluate_code(code_snippet, test_cases) + Returns: + float: Fraction of test cases passed, between 0.0 and 1.0 """ - code_snippets = [extract_code(completion[-1]["content"]) for completion in completions] + + code = self._extract_code(completion) verification_info = kwargs["verification_info"] - scripts = [ - evaluation_script_template.format( - code=json.dumps(code), test_cases=json.dumps(json.dumps(info["test_cases"])) - ) - for code, info in zip(code_snippets, verification_info) - ] - with Sandbox(timeout=30, request_timeout=3) as sbx: - for script in scripts: + script = self.evaluation_script_template.format( + code=json.dumps(code), test_cases=json.dumps(json.dumps(verification_info["test_cases"])) + ) + + try: + with Sandbox(timeout=30, request_timeout=3) as sbx: execution = sbx.run_code(script, language=verification_info["language"]) try: - output = float(execution.text) + score = float(execution.text) except (TypeError, ValueError): - output = 0.0 - rewards.append(output) - except Exception as e: - print(f"Error from E2B executor: {e}") - rewards = [0.0] * len(completions) - return rewards + score = 0.0 + return score + except Exception as e: + print(f"Error from E2B executor: {e}") + return 0.0 -def get_code_format_reward(language: str = "python"): - """Format reward function specifically for code responses. +class CodeFormatReward(BaseRewardFunction): + """Format reward function specifically for code responses.""" - Args: - language: Programming language supported by E2B https://e2b.dev/docs/code-interpreting/supported-languages - """ - pattern = rf"^<think>.*?</think>\s*<answer>.*?```{language}\n.*?```.*?</answer>$" + def __init__(self, language: str = "python", max_workers: int = 1): + """ + Initialize the code format reward function.
- def code_format_reward(completions, **kwargs): - completion_contents = [completion[0]["content"] for completion in completions] - matches = [re.match(pattern, content, re.DOTALL | re.MULTILINE) for content in completion_contents] - return [1.0 if match else 0.0 for match in matches] + Args: + language: Programming language supported by E2B + """ + super().__init__(max_workers=max_workers) + self.pattern = rf"^<think>.*?</think>\s*<answer>.*?```{language}\n.*?```.*?</answer>$" + + def reward_on_single_completion(self, completion: str, **kwargs) -> float: + """ + Check if the completion matches the expected code format. - return code_format_reward + Args: + completion: Single completion content to evaluate + + Returns: + float: 1.0 if the format matches, 0.0 otherwise + """ + match = re.match(self.pattern, completion, re.DOTALL | re.MULTILINE) + return 1.0 if match else 0.0 diff --git a/tests/test_rewards.py b/tests/test_rewards.py index 32d0f1137..4e270774c 100644 --- a/tests/test_rewards.py +++ b/tests/test_rewards.py @@ -1,13 +1,13 @@ import unittest from open_r1.rewards import ( - accuracy_reward, - format_reward, - get_code_format_reward, - get_cosine_scaled_reward, - get_repetition_penalty_reward, - len_reward, - reasoning_steps_reward, + AccuracyReward, + CodeFormatReward, + CosineScaledReward, + FormatReward, + LengthReward, + ReasoningStepsReward, + RepetitionPenaltyReward, ) @@ -17,7 +17,7 @@ def test_accuracy_reward_correct_answer(self): completion = [[{"content": r"\boxed{\frac{63}{400}}"}]] solution = [r"\frac{63}{400}"] - rewards = accuracy_reward(completion, solution) + rewards = AccuracyReward()(completion, solution=solution) self.assertEqual(rewards[0], 1.0) def test_accuracy_reward_wrong_answer(self): @@ -25,13 +25,13 @@ def test_accuracy_reward_wrong_answer(self): completion = [[{"content": r"\boxed{\frac{64}{400}}"}]] solution = [r"\frac{63}{400}"] - rewards = accuracy_reward(completion, solution) + rewards = AccuracyReward()(completion, solution=solution) self.assertEqual(rewards[0], 0.0) def test_format_reward_correct(self): """Test format_reward with correct format.""" completion = [[{"content": "<think>Some reasoning</think><answer>The answer</answer>"}]] - rewards = format_reward(completion) + rewards = FormatReward()(completion) self.assertEqual(rewards[0], 1.0) def test_format_reward_incorrect(self): @@ -46,7 +46,7 @@ def test_format_reward_incorrect(self): for fmt in incorrect_formats: completion = [[{"content": fmt}]] - rewards = format_reward(completion) + rewards = FormatReward()(completion) self.assertEqual(rewards[0], 0.0) def test_reasoning_steps_reward(self): @@ -64,7 +64,7 @@ def test_reasoning_steps_reward(self): for content, expected_reward in test_cases: completion = [[{"content": content}]] - rewards = reasoning_steps_reward(completion) + rewards = ReasoningStepsReward()(completion) self.assertAlmostEqual(rewards[0], expected_reward) def test_multiple_completions(self): @@ -72,7 +72,7 @@ def test_multiple_completions(self): completions = [[{"content": r"\boxed{\frac{63}{400}}"}], [{"content": r"\boxed{\frac{64}{400}}"}]] solutions = [r"\frac{63}{400}", r"\frac{63}{400}"] - rewards = accuracy_reward(completions, solutions) + rewards = AccuracyReward()(completions, solution=solutions) self.assertEqual(len(rewards), 2) self.assertEqual(rewards[0], 1.0) self.assertEqual(rewards[1], 0.0) @@ -102,14 +102,14 @@ def test_cosine_scaled_reward(self): padded_content = content + " " * (content_len - len(content)) completion = [[{"content": padded_content}]] - rewards = get_cosine_scaled_reward(**test_params)(completion, [solution]) + rewards
= CosineScaledReward(**test_params)(completion, solution=[solution]) self.assertAlmostEqual(rewards[0], expected_reward, places=2) def test_format_reward_specific_multiline(self): """Test format_reward with a specific multiline input.""" inputs = "<think>\nI will count each distinct object in the image:\n1. Purple scooter\n2. Red bicycle\n3. Green motorcycle\n4. Gray sedan\n5. Yellow school bus\n6. Small green double-decker bus\n7. Small red car\n8. Small purple car\n9. Small gray dirt bike\n\nThere are 9 distinct objects in total.\n</think>\n<answer>\n9\n</answer>" completion = [[{"content": inputs}]] - rewards = format_reward(completion) + rewards = FormatReward()(completion) self.assertEqual(rewards[0], 1.0) def test_same_length_responses(self): @@ -117,7 +117,7 @@ def test_same_length_responses(self): completions = [[{"content": r"\boxed{\frac{63}{400}}"}], [{"content": r"\boxed{\frac{64}{400}}"}]] solutions = [r"\frac{63}{400}", r"\frac{63}{400}"] - rewards = len_reward(completions, solutions) + rewards = LengthReward()(completions, solutions) self.assertEqual(rewards, [0.0, 0.0]) def test_different_lengths_correct_answers(self): @@ -128,7 +128,7 @@ def test_different_lengths_correct_answers(self): ] solutions = [r"\frac{63}{400}", r"\frac{63}{400}"] - rewards = len_reward(completions, solutions) + rewards = LengthReward()(completions, solutions) self.assertGreater(rewards[0], rewards[1]) # shorter answer should get higher reward self.assertAlmostEqual(rewards[0], 0.5) # shortest correct answer gets maximum reward @@ -140,7 +140,7 @@ def test_different_lengths_incorrect_answers(self): ] solutions = [r"\frac{63}{400}", r"\frac{63}{400}"] - rewards = len_reward(completions, solutions) + rewards = LengthReward()(completions, solutions) self.assertLessEqual(rewards[0], 0.0) # incorrect answers should get non-positive rewards self.assertLessEqual(rewards[1], 0.0) self.assertGreater(rewards[0], rewards[1]) # shorter answer should still be penalized less @@ -155,7 +155,7 @@ def test_mixed_correctness(self): ] solutions = [r"\frac{63}{400}"] * 4 - rewards = len_reward(completions, solutions) + rewards = LengthReward()(completions, solutions) # Shortest correct answer should get positive reward self.assertGreater(rewards[0], 0.0) @@ -177,7 +177,7 @@ def test_unparseable_solution(self): completions = [[{"content": r"\boxed{answer}"}], [{"content": r"\boxed{answer} " + "x" * 10}]] solutions = ["unparseable_latex", "unparseable_latex"] - rewards = len_reward(completions, solutions) + rewards = LengthReward()(completions, solutions) self.assertGreater(rewards[0], rewards[1]) # shorter answer should still get better reward self.assertAlmostEqual(rewards[0], 0.5) # treated as correct, shortest gets maximum reward @@ -185,18 +185,18 @@ class TestRepetitionPenaltyReward(unittest.TestCase): def test_positive_max_penalty_raises_value_error(self): with self.assertRaises(ValueError): - get_repetition_penalty_reward(ngram_size=2, max_penalty=1.0) + RepetitionPenaltyReward(ngram_size=2, max_penalty=1.0) with self.assertRaisesRegex(ValueError, "max_penalty 1.5 should not be positive"): - get_repetition_penalty_reward(ngram_size=2, max_penalty=1.5) + RepetitionPenaltyReward(ngram_size=2, max_penalty=1.5) def test_no_repetition(self): - reward_fn = get_repetition_penalty_reward(ngram_size=2, max_penalty=-1.0) + reward_fn = RepetitionPenaltyReward(ngram_size=2, max_penalty=-1.0) completions = [[{"content": "this is a test sentence"}]] rewards = reward_fn(completions) self.assertEqual(rewards, [0.0]) def
test_full_repetition(self): - reward_fn = get_repetition_penalty_reward(ngram_size=2, max_penalty=-1.0) + reward_fn = RepetitionPenaltyReward(ngram_size=2, max_penalty=-1.0) completions = [[{"content": "this this this this this"}]] rewards = reward_fn(completions) @@ -204,7 +204,7 @@ def test_full_repetition(self): self.assertEqual(rewards, [-0.75]) def test_partial_repetition(self): - reward_fn = get_repetition_penalty_reward(ngram_size=2, max_penalty=-1.0) + reward_fn = RepetitionPenaltyReward(ngram_size=2, max_penalty=-1.0) completions = [[{"content": "this is a this is a test"}]] rewards = reward_fn(completions) @@ -213,7 +213,7 @@ def test_partial_repetition(self): self.assertAlmostEqual(rewards[0], -1 / 3) def test_multiple_completions(self): - reward_fn = get_repetition_penalty_reward(ngram_size=3, max_penalty=-0.5) + reward_fn = RepetitionPenaltyReward(ngram_size=3, max_penalty=-0.5) completions = [ [{"content": "this is a test"}], [{"content": "test test test test"}], @@ -226,20 +226,20 @@ def test_multiple_completions(self): self.assertAlmostEqual(rewards[1], -0.25) def test_empty_completion(self): - reward_fn = get_repetition_penalty_reward(ngram_size=2, max_penalty=-1.0) + reward_fn = RepetitionPenaltyReward(ngram_size=2, max_penalty=-1.0) completions = [[{"content": ""}]] rewards = reward_fn(completions) self.assertEqual(rewards, [0.0]) def test_different_ngram_size(self): - reward_fn = get_repetition_penalty_reward(ngram_size=3, max_penalty=-2.0) + reward_fn = RepetitionPenaltyReward(ngram_size=3, max_penalty=-2.0) completions = [[{"content": "this is a this is a test"}]] rewards = reward_fn(completions) self.assertAlmostEqual(rewards[0], -0.4) def test_mixed_case(self): - reward_fn = get_repetition_penalty_reward(ngram_size=2, max_penalty=-1.0) + reward_fn = RepetitionPenaltyReward(ngram_size=2, max_penalty=-1.0) completions = [ [{"content": "This is A Test"}], [{"content": "this IS a test"}], @@ -250,35 +250,35 @@ def test_mixed_case(self): self.assertAlmostEqual(rewards[0], rewards[1]) def test_one_word_completion(self): - reward_fn = get_repetition_penalty_reward(ngram_size=3, max_penalty=-1.0) + reward_fn = RepetitionPenaltyReward(ngram_size=3, max_penalty=-1.0) completions = [[{"content": "word"}]] rewards = reward_fn(completions) self.assertEqual(rewards, [0.0]) def test_two_word_completion(self): - reward_fn = get_repetition_penalty_reward(ngram_size=3, max_penalty=-1.0) + reward_fn = RepetitionPenaltyReward(ngram_size=3, max_penalty=-1.0) completions = [[{"content": "two words"}]] rewards = reward_fn(completions) self.assertEqual(rewards, [0.0]) def test_three_word_completion(self): - reward_fn = get_repetition_penalty_reward(ngram_size=3, max_penalty=-1.0) + reward_fn = RepetitionPenaltyReward(ngram_size=3, max_penalty=-1.0) completions = [[{"content": "three different words"}]] rewards = reward_fn(completions) self.assertEqual(rewards, [0.0]) def test_three_word_repetition_completion(self): - reward_fn = get_repetition_penalty_reward(ngram_size=3, max_penalty=-1.0) + reward_fn = RepetitionPenaltyReward(ngram_size=3, max_penalty=-1.0) completions = [[{"content": "word word word word"}]] rewards = reward_fn(completions) self.assertEqual(rewards, [-0.5]) def test_four_word_completion_with_repetition(self): - reward_fn = get_repetition_penalty_reward(ngram_size=3, max_penalty=-1.0) + reward_fn = RepetitionPenaltyReward(ngram_size=3, max_penalty=-1.0) completions = [[{"content": "one two one two"}]] rewards = reward_fn(completions) @@ -286,7 +286,7 @@ def 
test_four_word_completion_with_repetition(self): self.assertEqual(rewards, [0.0]) def test_five_word_completion_with_repetition(self): - reward_fn = get_repetition_penalty_reward(ngram_size=3, max_penalty=-0.5) + reward_fn = RepetitionPenaltyReward(ngram_size=3, max_penalty=-0.5) completions = [[{"content": "A B C A B"}]] rewards = reward_fn(completions) @@ -294,20 +294,20 @@ def test_five_word_completion_with_repetition(self): self.assertEqual(rewards, [0.0]) def test_six_word_completion_with_repetition(self): - reward_fn = get_repetition_penalty_reward(ngram_size=3, max_penalty=-1.0) + reward_fn = RepetitionPenaltyReward(ngram_size=3, max_penalty=-1.0) completions = [[{"content": "A B C A B C"}]] rewards = reward_fn(completions) self.assertEqual(rewards, [-0.25]) def test_long_completion_with_repetition(self): - reward_fn = get_repetition_penalty_reward(ngram_size=3, max_penalty=-1.0) + reward_fn = RepetitionPenaltyReward(ngram_size=3, max_penalty=-1.0) completions = [[{"content": "A B C A B C E F G A B C A B C"}]] rewards = reward_fn(completions) self.assertAlmostEqual(rewards[0], -0.3846, places=4) def test_long_completion_without_repetition(self): - reward_fn = get_repetition_penalty_reward(ngram_size=3, max_penalty=-1.0) + reward_fn = RepetitionPenaltyReward(ngram_size=3, max_penalty=-1.0) completions = [[{"content": "A B C D E F G H I J K L"}]] rewards = reward_fn(completions) @@ -324,7 +324,7 @@ def test_correct_python_format(self): } ] ] - reward_fn = get_code_format_reward(language="python") + reward_fn = CodeFormatReward(language="python") rewards = reward_fn(completion) self.assertEqual(rewards[0], 1.0) @@ -343,7 +343,7 @@ def test_incorrect_formats(self): "```python\ndef hello(): pass\n```Analysis", ] - reward_fn = get_code_format_reward(language="python") + reward_fn = CodeFormatReward(language="python") for fmt in incorrect_formats: completion = [[{"content": fmt}]] rewards = reward_fn(completion) @@ -358,7 +358,7 @@ def test_multiple_code_blocks(self): } ] ] - reward_fn = get_code_format_reward(language="python") + reward_fn = CodeFormatReward(language="python") rewards = reward_fn(completion) self.assertEqual(rewards[0], 1.0) @@ -369,12 +369,12 @@ def test_different_languages(self): ] # Test with JavaScript - js_reward_fn = get_code_format_reward(language="javascript") + js_reward_fn = CodeFormatReward(language="javascript") rewards = js_reward_fn(completion) self.assertEqual(rewards[0], 1.0) # Same completion should fail for Python - py_reward_fn = get_code_format_reward(language="python") + py_reward_fn = CodeFormatReward(language="python") rewards = py_reward_fn(completion) self.assertEqual(rewards[0], 0.0) @@ -387,7 +387,7 @@ def test_multiline_code(self): } ] ] - reward_fn = get_code_format_reward(language="python") + reward_fn = CodeFormatReward(language="python") rewards = reward_fn(completion) self.assertEqual(rewards[0], 1.0)
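A minimal usage sketch of the refactored class-based API, mirroring the calls in the updated tests (the completion structure and expected rewards follow test_multiple_completions; the max_workers argument is an assumption about the thread-pool path exposed by BaseRewardFunction):

from open_r1.rewards import AccuracyReward, RepetitionPenaltyReward

# Completions use the same chat-style structure as the tests: a list of message
# lists, where the first message's "content" field holds the completion text.
completions = [[{"content": r"\boxed{\frac{63}{400}}"}], [{"content": r"\boxed{\frac{64}{400}}"}]]
solutions = [r"\frac{63}{400}", r"\frac{63}{400}"]

# Per-sample arguments such as `solution` are passed as keyword lists and split
# per completion inside BaseRewardFunction.__call__.
rewards = AccuracyReward()(completions, solution=solutions)  # expected [1.0, 0.0]

# Parameterized rewards take their configuration in __init__ instead of a factory function.
repetition_rewards = RepetitionPenaltyReward(ngram_size=2, max_penalty=-1.0)(completions)

# max_workers > 1 routes scoring through a ThreadPoolExecutor instead of the single-threaded path.
parallel_rewards = AccuracyReward(max_workers=4)(completions, solution=solutions)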