diff --git a/src/open_r1/rewards.py b/src/open_r1/rewards.py
index d91c63ed2..11017995f 100644
--- a/src/open_r1/rewards.py
+++ b/src/open_r1/rewards.py
@@ -3,7 +3,9 @@
import json
import math
import re
-from typing import Dict
+from abc import ABC, abstractmethod
+from concurrent.futures import ThreadPoolExecutor, as_completed
+from typing import Dict, List
from latex2sympy2_extended import NormalizationConfig
from math_verify import LatexExtractionConfig, parse, verify
@@ -18,20 +20,112 @@
load_dotenv()
-def accuracy_reward(completions, solution, **kwargs):
+class BaseRewardFunction(ABC):
+ """Reward function base class"""
+
+ def __init__(self, max_workers: int = 1, *args, **kwargs) -> None:
+ self.max_workers = max_workers
+ super().__init__()
+
+ @abstractmethod
+ def reward_on_single_completion(self, completion: str, **kwargs) -> float:
+ """Process a single completion and return its score
+
+ Args:
+ completion: Single completion content to evaluate
+
+ Returns:
+            float: The reward score for this completion
+ """
+        raise NotImplementedError("reward_on_single_completion must be implemented by a subclass")
+
+    def _single_thread_call(self, completions: List[str], **kwargs) -> List[float]:
+ results = []
+ for idx, completion in enumerate(completions):
+ # prepare per-completion kwargs
+ per_completion_kwargs = {}
+ for key, value in kwargs.items():
+ if isinstance(value, list):
+ per_completion_kwargs[key] = value[idx]
+ else:
+ per_completion_kwargs[key] = value
+ results.append(self.reward_on_single_completion(completion, **per_completion_kwargs))
+ return results
+
+    def __call__(self, completions: List[List[Dict[str, str]]], **kwargs) -> List[float]:
+        """Score multiple model completions, optionally in parallel.
+
+        Args:
+            completions: List of model completions, where each completion is a list of messages whose first message's 'content' holds the completion text
+ **kwargs: Additional keyword arguments that can include:
+ - max_workers: Optional int, maximum number of parallel workers
+ - Any other arguments needed by reward_on_single_completion()
+
+ Returns:
+ List[float]: A list of reward scores, one for each completion,
+ computed by reward_on_single_completion()
+
+ Raises:
+ RuntimeError: If processing any completion fails
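+
+        Example (illustrative sketch; ``ConstantReward`` is hypothetical, not part of this module):
+
+            >>> class ConstantReward(BaseRewardFunction):
+            ...     def reward_on_single_completion(self, completion: str, **kwargs) -> float:
+            ...         return 1.0
+            >>> ConstantReward()([[{"content": "hi"}]])
+            [1.0]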
+ """
+ completions = [completion[0]["content"] for completion in completions]
+
+ if "max_workers" in kwargs:
+ max_workers = kwargs["max_workers"]
+ else:
+ max_workers = self.max_workers
+
+ if max_workers == 1:
+ return self._single_thread_call(completions, **kwargs)
+
+ with ThreadPoolExecutor(max_workers=max_workers) as executor:
+ future_to_idx = {}
+ for idx, completion in enumerate(completions):
+ # prepare per-completion kwargs
+ per_completion_kwargs = {}
+ for key, value in kwargs.items():
+ if isinstance(value, list):
+ per_completion_kwargs[key] = value[idx]
+ else:
+ per_completion_kwargs[key] = value
+
+ future = executor.submit(self.reward_on_single_completion, completion, **per_completion_kwargs)
+ future_to_idx[future] = idx
+
+ results = [None] * len(completions)
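+            # Results are written back by original index, so output order matches the
+            # input order even though futures may complete out of order.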
+ for future in as_completed(future_to_idx):
+ idx = future_to_idx[future]
+ try:
+ results[idx] = future.result()
+ except Exception as e:
+ raise RuntimeError(f"Error processing completion {idx}: {e}") from e
+
+ return results
+
+
+class AccuracyReward(BaseRewardFunction):
"""Reward function that checks if the completion is the same as the ground truth."""
- contents = [completion[0]["content"] for completion in completions]
- rewards = []
- for content, sol in zip(contents, solution):
+
+ def reward_on_single_completion(self, completion: str, solution: str, **kwargs) -> float:
+ """Process a single completion and return its score
+
+ Args:
+ completion: Single completion content to evaluate
+            solution: Ground truth solution to compare against
+
+ Returns:
+ float: 1.0 if the completion matches the ground truth, 0.0 otherwise
+ """
gold_parsed = parse(
- sol,
+ solution,
extraction_mode="first_match",
extraction_config=[LatexExtractionConfig()],
)
+
if len(gold_parsed) != 0:
# We require the answer to be provided in correct latex (no malformed operators)
answer_parsed = parse(
- content,
+ completion,
extraction_config=[
LatexExtractionConfig(
normalization_config=NormalizationConfig(
@@ -54,21 +148,30 @@ def accuracy_reward(completions, solution, **kwargs):
else:
# If the gold solution is not parseable, we reward 1 to skip this example
reward = 1.0
- print("Failed to parse gold solution: ", sol)
- rewards.append(reward)
+ print("Failed to parse gold solution: ", solution)
- return rewards
+ return reward
-def format_reward(completions, **kwargs):
+class FormatReward(BaseRewardFunction):
"""Reward function that checks if the reasoning process is enclosed within and tags, while the final answer is enclosed within and tags."""
-    pattern = r"^<think>.*?</think>\s*<answer>.*?</answer>$"
- completion_contents = [completion[0]["content"] for completion in completions]
- matches = [re.match(pattern, content, re.DOTALL | re.MULTILINE) for content in completion_contents]
- return [1.0 if match else 0.0 for match in matches]
+
+ def reward_on_single_completion(self, completion: str, **kwargs) -> float:
+ """Process a single completion and return its score
+
+ Args:
+ completion: Single completion content to evaluate
+ **kwargs: Additional arguments (unused)
+
+ Returns:
+ float: 1.0 if the completion has correct format, 0.0 otherwise
+ """
+        pattern = r"^<think>.*?</think>\s*<answer>.*?</answer>$"
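+        # i.e. a <think>...</think> block, optional whitespace, then an <answer>...</answer> block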
+ match = re.match(pattern, completion, re.DOTALL | re.MULTILINE)
+ return 1.0 if match else 0.0
-def reasoning_steps_reward(completions, **kwargs):
+class ReasoningStepsReward(BaseRewardFunction):
r"""Reward function that checks for clear step-by-step reasoning.
Regex pattern:
Step \d+: - matches "Step 1:", "Step 2:", etc.
@@ -77,118 +180,55 @@ def reasoning_steps_reward(completions, **kwargs):
\n\* - matches bullet points with asterisks
First,|Second,|Next,|Finally, - matches transition words
"""
- pattern = r"(Step \d+:|^\d+\.|\n-|\n\*|First,|Second,|Next,|Finally,)"
- completion_contents = [completion[0]["content"] for completion in completions]
- matches = [len(re.findall(pattern, content)) for content in completion_contents]
-
- # Magic nubmer 3 to encourage 3 steps and more, otherwise partial reward
- return [min(1.0, count / 3) for count in matches]
-
-
-def len_reward(completions: list[Dict[str, str]], solution: list[str], **kwargs) -> float:
- """Compute length-based rewards to discourage overthinking and promote token efficiency.
-
- Taken from from the Kimi 1.5 tech report: https://arxiv.org/abs/2501.12599
- Args:
- completions: List of model completions
- solution: List of ground truth solutions
+ def reward_on_single_completion(self, completion: str, **kwargs) -> float:
+ """Process a single completion and return its score based on number of reasoning steps.
- Returns:
- List of rewards where:
- - For correct answers: reward = 0.5 - (len - min_len)/(max_len - min_len)
- - For incorrect answers: reward = min(0, 0.5 - (len - min_len)/(max_len - min_len))
- """
- contents = [completion[0]["content"] for completion in completions]
-
- # First check correctness of answers
- correctness = []
- for content, sol in zip(contents, solution):
- gold_parsed = parse(
- sol,
- extraction_mode="first_match",
- extraction_config=[LatexExtractionConfig()],
- )
- if len(gold_parsed) == 0:
- # Skip unparseable examples
- correctness.append(True) # Treat as correct to avoid penalizing
- print("Failed to parse gold solution: ", sol)
- continue
-
- answer_parsed = parse(
- content,
- extraction_config=[
- LatexExtractionConfig(
- normalization_config=NormalizationConfig(
- nits=False,
- malformed_operators=False,
- basic_latex=True,
- equations=True,
- boxed=True,
- units=True,
- ),
- boxed_match_priority=0,
- try_extract_without_anchor=False,
- )
- ],
- extraction_mode="first_match",
- )
- correctness.append(verify(answer_parsed, gold_parsed))
-
- # Calculate lengths
- lengths = [len(content) for content in contents]
- min_len = min(lengths)
- max_len = max(lengths)
-
- # If all responses have the same length, return zero rewards
- if max_len == min_len:
- return [0.0] * len(completions)
-
- rewards = []
- for length, is_correct in zip(lengths, correctness):
- lambda_val = 0.5 - (length - min_len) / (max_len - min_len)
+ Args:
+ completion: Single completion content to evaluate
+ **kwargs: Additional arguments (unused)
- if is_correct:
- reward = lambda_val
- else:
- reward = min(0, lambda_val)
+ Returns:
+ float: Score between 0.0 and 1.0 based on number of reasoning steps found
+ """
+ pattern = r"(Step \d+:|^\d+\.|\n-|\n\*|First,|Second,|Next,|Finally,)"
+ matches = len(re.findall(pattern, completion))
- rewards.append(float(reward))
+        # Magic number 3 to encourage 3 steps and more, otherwise partial reward
+ return min(1.0, matches / 3)
- return rewards
+class LengthReward(BaseRewardFunction):
+ def reward_on_single_completion(self, completion: str, **kwargs) -> float:
+        raise NotImplementedError("LengthReward overrides __call__ directly and does not implement reward_on_single_completion")
-def get_cosine_scaled_reward(
- min_value_wrong: float = -1.0,
- max_value_wrong: float = -0.5,
- min_value_correct: float = 0.5,
- max_value_correct: float = 1.0,
- max_len: int = 1000,
-):
- def cosine_scaled_reward(completions, solution, **kwargs):
- """Reward function that scales based on completion length using a cosine schedule.
+    def __call__(self, completions: List[List[Dict[str, str]]], solution: List[str], **kwargs) -> List[float]:
+ """Compute length-based rewards to discourage overthinking and promote token efficiency.
- Shorter correct solutions are rewarded more than longer ones.
- Longer incorrect solutions are penalized less than shorter ones.
+    Taken from the Kimi 1.5 tech report: https://arxiv.org/abs/2501.12599
Args:
completions: List of model completions
solution: List of ground truth solutions
- This function is parameterized by the following arguments:
- min_value_wrong: Minimum reward for wrong answers
- max_value_wrong: Maximum reward for wrong answers
- min_value_correct: Minimum reward for correct answers
- max_value_correct: Maximum reward for correct answers
- max_len: Maximum length for scaling
+ Returns:
+ List of rewards where:
+ - For correct answers: reward = 0.5 - (len - min_len)/(max_len - min_len)
+ - For incorrect answers: reward = min(0, 0.5 - (len - min_len)/(max_len - min_len))
"""
contents = [completion[0]["content"] for completion in completions]
- rewards = []
+ # First check correctness of answers
+ correctness = []
for content, sol in zip(contents, solution):
- gold_parsed = parse(sol, extraction_mode="first_match", extraction_config=[LatexExtractionConfig()])
+ gold_parsed = parse(
+ sol,
+ extraction_mode="first_match",
+ extraction_config=[LatexExtractionConfig()],
+ )
if len(gold_parsed) == 0:
- rewards.append(1.0) # Skip unparseable examples
+ # Skip unparseable examples
+ correctness.append(True) # Treat as correct to avoid penalizing
print("Failed to parse gold solution: ", sol)
continue
@@ -210,167 +250,251 @@ def cosine_scaled_reward(completions, solution, **kwargs):
],
extraction_mode="first_match",
)
+ correctness.append(verify(answer_parsed, gold_parsed))
- is_correct = verify(answer_parsed, gold_parsed)
- gen_len = len(content)
+ # Calculate lengths
+ lengths = [len(content) for content in contents]
+ min_len = min(lengths)
+ max_len = max(lengths)
- # Apply cosine scaling based on length
- progress = gen_len / max_len
- cosine = math.cos(progress * math.pi)
+ # If all responses have the same length, return zero rewards
+ if max_len == min_len:
+ return [0.0] * len(completions)
+
+ rewards = []
+ for length, is_correct in zip(lengths, correctness):
+ lambda_val = 0.5 - (length - min_len) / (max_len - min_len)
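+        # lambda_val runs from 0.5 for the shortest response down to -0.5 for the longest.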
if is_correct:
- min_value = min_value_correct
- max_value = max_value_correct
+ reward = lambda_val
else:
- # Swap min/max for incorrect answers
- min_value = max_value_wrong
- max_value = min_value_wrong
+ reward = min(0, lambda_val)
- reward = min_value + 0.5 * (max_value - min_value) * (1.0 + cosine)
rewards.append(float(reward))
return rewards
- return cosine_scaled_reward
+
+class CosineScaledReward(BaseRewardFunction):
+ def __init__(
+ self,
+ min_value_wrong: float = -1.0,
+ max_value_wrong: float = -0.5,
+ min_value_correct: float = 0.5,
+ max_value_correct: float = 1.0,
+ max_len: int = 1000,
+ max_workers: int = 1,
+ ):
+ """Initialize CosineScaledReward.
+
+ Args:
+ min_value_wrong: Minimum reward for wrong answers
+ max_value_wrong: Maximum reward for wrong answers
+ min_value_correct: Minimum reward for correct answers
+ max_value_correct: Maximum reward for correct answers
+ max_len: Maximum length for scaling
+ """
+ super().__init__(max_workers=max_workers)
+ self.min_value_wrong = min_value_wrong
+ self.max_value_wrong = max_value_wrong
+ self.min_value_correct = min_value_correct
+ self.max_value_correct = max_value_correct
+ self.max_len = max_len
+
+ def reward_on_single_completion(self, completion: str, solution: str, **kwargs) -> float:
+ gold_parsed = parse(solution, extraction_mode="first_match", extraction_config=[LatexExtractionConfig()])
+ if len(gold_parsed) == 0:
+ print("Failed to parse gold solution: ", solution)
+ return 1.0 # Skip unparseable examples
+
+ answer_parsed = parse(
+ completion,
+ extraction_config=[
+ LatexExtractionConfig(
+ normalization_config=NormalizationConfig(
+ nits=False,
+ malformed_operators=False,
+ basic_latex=True,
+ equations=True,
+ boxed=True,
+ units=True,
+ ),
+ boxed_match_priority=0,
+ try_extract_without_anchor=False,
+ )
+ ],
+ extraction_mode="first_match",
+ )
+
+ is_correct = verify(answer_parsed, gold_parsed)
+ gen_len = len(completion)
+
+ # Apply cosine scaling based on length
+ progress = gen_len / self.max_len
+ cosine = math.cos(progress * math.pi)
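+        # cosine falls from 1 at length 0 to -1 at length == max_len, so correct answers
+        # are scaled from max_value_correct down toward min_value_correct (and wrong
+        # answers from min_value_wrong toward max_value_wrong) as completions get longer.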
+
+ if is_correct:
+ min_value = self.min_value_correct
+ max_value = self.max_value_correct
+ else:
+ # Swap min/max for incorrect answers
+ min_value = self.max_value_wrong
+ max_value = self.min_value_wrong
+
+ reward = min_value + 0.5 * (max_value - min_value) * (1.0 + cosine)
+ return float(reward)
-def get_repetition_penalty_reward(ngram_size: int, max_penalty: float):
+class RepetitionPenaltyReward(BaseRewardFunction):
"""
Computes N-gram repetition penalty as described in Appendix C.2 of https://arxiv.org/abs/2502.03373.
Reference implementation from: https://github.com/eddycmu/demystify-long-cot/blob/release/openrlhf/openrlhf/reward/repetition.py
-
- Args:
- ngram_size: size of the n-grams
- max_penalty: Maximum (negative) penalty for wrong answers
"""
- if max_penalty > 0:
- raise ValueError(f"max_penalty {max_penalty} should not be positive")
- def zipngram(text: str, ngram_size: int):
+ def __init__(self, ngram_size: int, max_penalty: float, max_workers: int = 1):
+ """
+ Args:
+ ngram_size: size of the n-grams
+ max_penalty: Maximum (negative) penalty for wrong answers
+ """
+ super().__init__(max_workers=max_workers)
+ if max_penalty > 0:
+ raise ValueError(f"max_penalty {max_penalty} should not be positive")
+ self.ngram_size = ngram_size
+ self.max_penalty = max_penalty
+
+ def _zipngram(self, text: str, ngram_size: int):
words = text.lower().split()
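+        # e.g. "a b c d" with ngram_size=2 yields ("a", "b"), ("b", "c"), ("c", "d")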
return zip(*[words[i:] for i in range(ngram_size)])
- def repetition_penalty_reward(completions, **kwargs) -> float:
+ def reward_on_single_completion(self, completion: str, **kwargs) -> float:
"""
- reward function the penalizes repetitions
- ref implementation: https://github.com/eddycmu/demystify-long-cot/blob/release/openrlhf/openrlhf/reward/repetition.py
+ Reward function that penalizes repetitions
Args:
- completions: List of model completions
+ completion: Model completion text
"""
+ if completion == "":
+ return 0.0
+
+ if len(completion.split()) < self.ngram_size:
+ return 0.0
+
+ ngrams = set()
+ total = 0
+ for ng in self._zipngram(completion, self.ngram_size):
+ ngrams.add(ng)
+ total += 1
+
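+        # scaling is the fraction of duplicated n-grams: 0.0 when every n-gram is unique,
+        # approaching 1.0 as the completion collapses into repetition.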
+ scaling = 1 - len(ngrams) / total
+ reward = scaling * self.max_penalty
+ return float(reward)
+
+
+class CodeReward(BaseRewardFunction):
+ """Reward function that evaluates code snippets using the E2B code interpreter."""
+
+ evaluation_script_template = """
+ import subprocess
+ import json
+
+ def evaluate_code(code, test_cases):
+ passed = 0
+ total = len(test_cases)
+ exec_timeout = 5
+
+ for case in test_cases:
+ process = subprocess.run(
+ ["python3", "-c", code],
+ input=case["input"],
+ text=True,
+ capture_output=True,
+ timeout=exec_timeout
+ )
- contents = [completion[0]["content"] for completion in completions]
- rewards = []
- for completion in contents:
- if completion == "":
- rewards.append(0.0)
+ if process.returncode != 0: # Error in execution
continue
- if len(completion.split()) < ngram_size:
- rewards.append(0.0)
- continue
-
- ngrams = set()
- total = 0
- for ng in zipngram(completion, ngram_size):
- ngrams.add(ng)
- total += 1
-
- scaling = 1 - len(ngrams) / total
- reward = scaling * max_penalty
- rewards.append(reward)
- return rewards
-
- return repetition_penalty_reward
+ output = process.stdout.strip()
+ if output.strip() == case["output"].strip():
+ passed += 1
-def extract_code(completion: str) -> str:
- pattern = re.compile(r"```python\n(.*?)```", re.DOTALL)
- matches = pattern.findall(completion)
- extracted_answer = matches[-1] if len(matches) >= 1 else ""
- return extracted_answer
+ success_rate = (passed / total)
+ return success_rate
+ code_snippet = {code}
+ test_cases = json.loads({test_cases})
-def code_reward(completions, **kwargs) -> list[float]:
- """Reward function that evaluates code snippets using the E2B code interpreter.
-
- Assumes the dataset contains a `verification_info` column with test cases.
+ evaluate_code(code_snippet, test_cases)
"""
- if not is_e2b_available():
- raise ImportError(
- "E2B is not available and required for this reward function. Please install E2B with "
- "`pip install e2b-code-interpreter` and add an API key to a `.env` file."
- )
-
- rewards = []
- # TODO: add support for other languages in E2B: https://e2b.dev/docs/code-interpreting/supported-languages
- try:
- """Returns a reward function that evaluates code snippets in a sandbox."""
- evaluation_script_template = """
- import subprocess
- import json
-
- def evaluate_code(code, test_cases):
- passed = 0
- total = len(test_cases)
- exec_timeout = 5
-
- for case in test_cases:
- process = subprocess.run(
- ["python3", "-c", code],
- input=case["input"],
- text=True,
- capture_output=True,
- timeout=exec_timeout
- )
- if process.returncode != 0: # Error in execution
- continue
+ def __init__(self, max_workers: int = 1):
+ super().__init__(max_workers=max_workers)
+ if not is_e2b_available():
+ raise ImportError(
+ "E2B is not available and required for this reward function. Please install E2B with "
+ "`pip install e2b-code-interpreter` and add an API key to a `.env` file."
+ )
- output = process.stdout.strip()
- if output.strip() == case["output"].strip():
- passed += 1
+ def _extract_code(self, completion: str) -> str:
+ pattern = re.compile(r"```python\n(.*?)```", re.DOTALL)
+ matches = pattern.findall(completion)
+ extracted_answer = matches[-1] if len(matches) >= 1 else ""
+ return extracted_answer
- success_rate = (passed / total)
- return success_rate
+ def reward_on_single_completion(self, completion: str, **kwargs) -> float:
+ """
+        Evaluate the extracted code snippet against its test cases.
- code_snippet = {code}
- test_cases = json.loads({test_cases})
+ Args:
+            completion: Single completion content containing a fenced code block
+            **kwargs: Must contain 'verification_info' with 'test_cases' (each with 'input' and 'output') and 'language'
- evaluate_code(code_snippet, test_cases)
+ Returns:
+            float: Fraction of test cases passed, between 0 and 1
"""
- code_snippets = [extract_code(completion[-1]["content"]) for completion in completions]
+
+ code = self._extract_code(completion)
verification_info = kwargs["verification_info"]
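+        # Assumed shape, inferred from the usage below:
+        # {"language": "python", "test_cases": [{"input": "...", "output": "..."}, ...]}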
- scripts = [
- evaluation_script_template.format(
- code=json.dumps(code), test_cases=json.dumps(json.dumps(info["test_cases"]))
- )
- for code, info in zip(code_snippets, verification_info)
- ]
- with Sandbox(timeout=30, request_timeout=3) as sbx:
- for script in scripts:
+ script = self.evaluation_script_template.format(
+ code=json.dumps(code), test_cases=json.dumps(json.dumps(verification_info["test_cases"]))
+ )
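+        # test_cases is JSON-encoded twice so the substituted value is a quoted string
+        # literal that json.loads() inside the generated script decodes back into a list.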
+
+ try:
+ with Sandbox(timeout=30, request_timeout=3) as sbx:
execution = sbx.run_code(script, language=verification_info["language"])
try:
- output = float(execution.text)
+ score = float(execution.text)
except (TypeError, ValueError):
- output = 0.0
- rewards.append(output)
- except Exception as e:
- print(f"Error from E2B executor: {e}")
- rewards = [0.0] * len(completions)
- return rewards
+ score = 0.0
+ return score
+ except Exception as e:
+ print(f"Error from E2B executor: {e}")
+            return 0.0
-def get_code_format_reward(language: str = "python"):
- """Format reward function specifically for code responses.
+class CodeFormatReward(BaseRewardFunction):
+ """Format reward function specifically for code responses."""
- Args:
- language: Programming language supported by E2B https://e2b.dev/docs/code-interpreting/supported-languages
- """
-    pattern = rf"^<think>.*?</think>\s*<answer>.*?```{language}\n.*?```.*?</answer>$"
+ def __init__(self, language: str = "python", max_workers: int = 1):
+ """
+ Initialize the code format reward function.
- def code_format_reward(completions, **kwargs):
- completion_contents = [completion[0]["content"] for completion in completions]
- matches = [re.match(pattern, content, re.DOTALL | re.MULTILINE) for content in completion_contents]
- return [1.0 if match else 0.0 for match in matches]
+ Args:
+ language: Programming language supported by E2B
+ """
+ super().__init__(max_workers=max_workers)
+        self.pattern = rf"^<think>.*?</think>\s*<answer>.*?```{language}\n.*?```.*?</answer>$"
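+        # Requires the same think/answer layout as FormatReward, with a fenced {language} code block inside the answer.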
+
+ def reward_on_single_completion(self, completion: str, **kwargs) -> float:
+ """
+        Check whether the completion matches the expected code format.
- return code_format_reward
+ Args:
+            completion: Single completion content to evaluate
+
+ Returns:
+            float: 1.0 if the completion matches the format, 0.0 otherwise
+ """
+ match = re.match(self.pattern, completion, re.DOTALL | re.MULTILINE)
+ return 1.0 if match else 0.0
diff --git a/tests/test_rewards.py b/tests/test_rewards.py
index 32d0f1137..4e270774c 100644
--- a/tests/test_rewards.py
+++ b/tests/test_rewards.py
@@ -1,13 +1,13 @@
import unittest
from open_r1.rewards import (
- accuracy_reward,
- format_reward,
- get_code_format_reward,
- get_cosine_scaled_reward,
- get_repetition_penalty_reward,
- len_reward,
- reasoning_steps_reward,
+ AccuracyReward,
+ CodeFormatReward,
+ CosineScaledReward,
+ FormatReward,
+ LengthReward,
+ ReasoningStepsReward,
+ RepetitionPenaltyReward,
)
@@ -17,7 +17,7 @@ def test_accuracy_reward_correct_answer(self):
completion = [[{"content": r"\boxed{\frac{63}{400}}"}]]
solution = [r"\frac{63}{400}"]
- rewards = accuracy_reward(completion, solution)
+ rewards = AccuracyReward()(completion, solution=solution)
self.assertEqual(rewards[0], 1.0)
def test_accuracy_reward_wrong_answer(self):
@@ -25,13 +25,13 @@ def test_accuracy_reward_wrong_answer(self):
completion = [[{"content": r"\boxed{\frac{64}{400}}"}]]
solution = [r"\frac{63}{400}"]
- rewards = accuracy_reward(completion, solution)
+ rewards = AccuracyReward()(completion, solution=solution)
self.assertEqual(rewards[0], 0.0)
def test_format_reward_correct(self):
"""Test format_reward with correct format."""
completion = [[{"content": "Some reasoningThe answer"}]]
- rewards = format_reward(completion)
+ rewards = FormatReward()(completion)
self.assertEqual(rewards[0], 1.0)
def test_format_reward_incorrect(self):
@@ -46,7 +46,7 @@ def test_format_reward_incorrect(self):
for fmt in incorrect_formats:
completion = [[{"content": fmt}]]
- rewards = format_reward(completion)
+ rewards = FormatReward()(completion)
self.assertEqual(rewards[0], 0.0)
def test_reasoning_steps_reward(self):
@@ -64,7 +64,7 @@ def test_reasoning_steps_reward(self):
for content, expected_reward in test_cases:
completion = [[{"content": content}]]
- rewards = reasoning_steps_reward(completion)
+ rewards = ReasoningStepsReward()(completion)
self.assertAlmostEqual(rewards[0], expected_reward)
def test_multiple_completions(self):
@@ -72,7 +72,7 @@ def test_multiple_completions(self):
completions = [[{"content": r"\boxed{\frac{63}{400}}"}], [{"content": r"\boxed{\frac{64}{400}}"}]]
solutions = [r"\frac{63}{400}", r"\frac{63}{400}"]
- rewards = accuracy_reward(completions, solutions)
+ rewards = AccuracyReward()(completions, solution=solutions)
self.assertEqual(len(rewards), 2)
self.assertEqual(rewards[0], 1.0)
self.assertEqual(rewards[1], 0.0)
@@ -102,14 +102,14 @@ def test_cosine_scaled_reward(self):
padded_content = content + " " * (content_len - len(content))
completion = [[{"content": padded_content}]]
- rewards = get_cosine_scaled_reward(**test_params)(completion, [solution])
+ rewards = CosineScaledReward(**test_params)(completion, solution=[solution])
self.assertAlmostEqual(rewards[0], expected_reward, places=2)
def test_format_reward_specific_multiline(self):
"""Test format_reward with a specific multiline input."""
inputs = "\nI will count each distinct object in the image:\n1. Purple scooter\n2. Red bicycle\n3. Green motorcycle\n4. Gray sedan\n5. Yellow school bus\n6. Small green double-decker bus\n7. Small red car\n8. Small purple car\n9. Small gray dirt bike\n\nThere are 9 distinct objects in total.\n\n9"
completion = [[{"content": inputs}]]
- rewards = format_reward(completion)
+ rewards = FormatReward()(completion)
self.assertEqual(rewards[0], 1.0)
def test_same_length_responses(self):
@@ -117,7 +117,7 @@ def test_same_length_responses(self):
completions = [[{"content": r"\boxed{\frac{63}{400}}"}], [{"content": r"\boxed{\frac{64}{400}}"}]]
solutions = [r"\frac{63}{400}", r"\frac{63}{400}"]
- rewards = len_reward(completions, solutions)
+ rewards = LengthReward()(completions, solutions)
self.assertEqual(rewards, [0.0, 0.0])
def test_different_lengths_correct_answers(self):
@@ -128,7 +128,7 @@ def test_different_lengths_correct_answers(self):
]
solutions = [r"\frac{63}{400}", r"\frac{63}{400}"]
- rewards = len_reward(completions, solutions)
+ rewards = LengthReward()(completions, solutions)
self.assertGreater(rewards[0], rewards[1]) # shorter answer should get higher reward
self.assertAlmostEqual(rewards[0], 0.5) # shortest correct answer gets maximum reward
@@ -140,7 +140,7 @@ def test_different_lengths_incorrect_answers(self):
]
solutions = [r"\frac{63}{400}", r"\frac{63}{400}"]
- rewards = len_reward(completions, solutions)
+ rewards = LengthReward()(completions, solutions)
self.assertLessEqual(rewards[0], 0.0) # incorrect answers should get non-positive rewards
self.assertLessEqual(rewards[1], 0.0)
self.assertGreater(rewards[0], rewards[1]) # shorter answer should still be penalized less
@@ -155,7 +155,7 @@ def test_mixed_correctness(self):
]
solutions = [r"\frac{63}{400}"] * 4
- rewards = len_reward(completions, solutions)
+ rewards = LengthReward()(completions, solutions)
# Shortest correct answer should get positive reward
self.assertGreater(rewards[0], 0.0)
@@ -177,7 +177,7 @@ def test_unparseable_solution(self):
completions = [[{"content": r"\boxed{answer}"}], [{"content": r"\boxed{answer} " + "x" * 10}]]
solutions = ["unparseable_latex", "unparseable_latex"]
- rewards = len_reward(completions, solutions)
+ rewards = LengthReward()(completions, solutions)
self.assertGreater(rewards[0], rewards[1]) # shorter answer should still get better reward
self.assertAlmostEqual(rewards[0], 0.5) # treated as correct, shortest gets maximum reward
@@ -185,18 +185,18 @@ def test_unparseable_solution(self):
class TestRepetitionPenaltyReward(unittest.TestCase):
def test_positive_max_penalty_raises_value_error(self):
with self.assertRaises(ValueError):
- get_repetition_penalty_reward(ngram_size=2, max_penalty=1.0)
+ RepetitionPenaltyReward(ngram_size=2, max_penalty=1.0)
with self.assertRaisesRegex(ValueError, "max_penalty 1.5 should not be positive"):
- get_repetition_penalty_reward(ngram_size=2, max_penalty=1.5)
+ RepetitionPenaltyReward(ngram_size=2, max_penalty=1.5)
def test_no_repetition(self):
- reward_fn = get_repetition_penalty_reward(ngram_size=2, max_penalty=-1.0)
+ reward_fn = RepetitionPenaltyReward(ngram_size=2, max_penalty=-1.0)
completions = [[{"content": "this is a test sentence"}]]
rewards = reward_fn(completions)
self.assertEqual(rewards, [0.0])
def test_full_repetition(self):
- reward_fn = get_repetition_penalty_reward(ngram_size=2, max_penalty=-1.0)
+ reward_fn = RepetitionPenaltyReward(ngram_size=2, max_penalty=-1.0)
completions = [[{"content": "this this this this this"}]]
rewards = reward_fn(completions)
@@ -204,7 +204,7 @@ def test_full_repetition(self):
self.assertEqual(rewards, [-0.75])
def test_partial_repetition(self):
- reward_fn = get_repetition_penalty_reward(ngram_size=2, max_penalty=-1.0)
+ reward_fn = RepetitionPenaltyReward(ngram_size=2, max_penalty=-1.0)
completions = [[{"content": "this is a this is a test"}]]
rewards = reward_fn(completions)
@@ -213,7 +213,7 @@ def test_partial_repetition(self):
self.assertAlmostEqual(rewards[0], -1 / 3)
def test_multiple_completions(self):
- reward_fn = get_repetition_penalty_reward(ngram_size=3, max_penalty=-0.5)
+ reward_fn = RepetitionPenaltyReward(ngram_size=3, max_penalty=-0.5)
completions = [
[{"content": "this is a test"}],
[{"content": "test test test test"}],
@@ -226,20 +226,20 @@ def test_multiple_completions(self):
self.assertAlmostEqual(rewards[1], -0.25)
def test_empty_completion(self):
- reward_fn = get_repetition_penalty_reward(ngram_size=2, max_penalty=-1.0)
+ reward_fn = RepetitionPenaltyReward(ngram_size=2, max_penalty=-1.0)
completions = [[{"content": ""}]]
rewards = reward_fn(completions)
self.assertEqual(rewards, [0.0])
def test_different_ngram_size(self):
- reward_fn = get_repetition_penalty_reward(ngram_size=3, max_penalty=-2.0)
+ reward_fn = RepetitionPenaltyReward(ngram_size=3, max_penalty=-2.0)
completions = [[{"content": "this is a this is a test"}]]
rewards = reward_fn(completions)
self.assertAlmostEqual(rewards[0], -0.4)
def test_mixed_case(self):
- reward_fn = get_repetition_penalty_reward(ngram_size=2, max_penalty=-1.0)
+ reward_fn = RepetitionPenaltyReward(ngram_size=2, max_penalty=-1.0)
completions = [
[{"content": "This is A Test"}],
[{"content": "this IS a test"}],
@@ -250,35 +250,35 @@ def test_mixed_case(self):
self.assertAlmostEqual(rewards[0], rewards[1])
def test_one_word_completion(self):
- reward_fn = get_repetition_penalty_reward(ngram_size=3, max_penalty=-1.0)
+ reward_fn = RepetitionPenaltyReward(ngram_size=3, max_penalty=-1.0)
completions = [[{"content": "word"}]]
rewards = reward_fn(completions)
self.assertEqual(rewards, [0.0])
def test_two_word_completion(self):
- reward_fn = get_repetition_penalty_reward(ngram_size=3, max_penalty=-1.0)
+ reward_fn = RepetitionPenaltyReward(ngram_size=3, max_penalty=-1.0)
completions = [[{"content": "two words"}]]
rewards = reward_fn(completions)
self.assertEqual(rewards, [0.0])
def test_three_word_completion(self):
- reward_fn = get_repetition_penalty_reward(ngram_size=3, max_penalty=-1.0)
+ reward_fn = RepetitionPenaltyReward(ngram_size=3, max_penalty=-1.0)
completions = [[{"content": "three different words"}]]
rewards = reward_fn(completions)
self.assertEqual(rewards, [0.0])
def test_three_word_repetition_completion(self):
- reward_fn = get_repetition_penalty_reward(ngram_size=3, max_penalty=-1.0)
+ reward_fn = RepetitionPenaltyReward(ngram_size=3, max_penalty=-1.0)
completions = [[{"content": "word word word word"}]]
rewards = reward_fn(completions)
self.assertEqual(rewards, [-0.5])
def test_four_word_completion_with_repetition(self):
- reward_fn = get_repetition_penalty_reward(ngram_size=3, max_penalty=-1.0)
+ reward_fn = RepetitionPenaltyReward(ngram_size=3, max_penalty=-1.0)
completions = [[{"content": "one two one two"}]]
rewards = reward_fn(completions)
@@ -286,7 +286,7 @@ def test_four_word_completion_with_repetition(self):
self.assertEqual(rewards, [0.0])
def test_five_word_completion_with_repetition(self):
- reward_fn = get_repetition_penalty_reward(ngram_size=3, max_penalty=-0.5)
+ reward_fn = RepetitionPenaltyReward(ngram_size=3, max_penalty=-0.5)
completions = [[{"content": "A B C A B"}]]
rewards = reward_fn(completions)
@@ -294,20 +294,20 @@ def test_five_word_completion_with_repetition(self):
self.assertEqual(rewards, [0.0])
def test_six_word_completion_with_repetition(self):
- reward_fn = get_repetition_penalty_reward(ngram_size=3, max_penalty=-1.0)
+ reward_fn = RepetitionPenaltyReward(ngram_size=3, max_penalty=-1.0)
completions = [[{"content": "A B C A B C"}]]
rewards = reward_fn(completions)
self.assertEqual(rewards, [-0.25])
def test_long_completion_with_repetition(self):
- reward_fn = get_repetition_penalty_reward(ngram_size=3, max_penalty=-1.0)
+ reward_fn = RepetitionPenaltyReward(ngram_size=3, max_penalty=-1.0)
completions = [[{"content": "A B C A B C E F G A B C A B C"}]]
rewards = reward_fn(completions)
self.assertAlmostEqual(rewards[0], -0.3846, places=4)
def test_long_completion_without_repetition(self):
- reward_fn = get_repetition_penalty_reward(ngram_size=3, max_penalty=-1.0)
+ reward_fn = RepetitionPenaltyReward(ngram_size=3, max_penalty=-1.0)
completions = [[{"content": "A B C D E F G H I J K L"}]]
rewards = reward_fn(completions)
@@ -324,7 +324,7 @@ def test_correct_python_format(self):
}
]
]
- reward_fn = get_code_format_reward(language="python")
+ reward_fn = CodeFormatReward(language="python")
rewards = reward_fn(completion)
self.assertEqual(rewards[0], 1.0)
@@ -343,7 +343,7 @@ def test_incorrect_formats(self):
"```python\ndef hello(): pass\n```Analysis",
]
- reward_fn = get_code_format_reward(language="python")
+ reward_fn = CodeFormatReward(language="python")
for fmt in incorrect_formats:
completion = [[{"content": fmt}]]
rewards = reward_fn(completion)
@@ -358,7 +358,7 @@ def test_multiple_code_blocks(self):
}
]
]
- reward_fn = get_code_format_reward(language="python")
+ reward_fn = CodeFormatReward(language="python")
rewards = reward_fn(completion)
self.assertEqual(rewards[0], 1.0)
@@ -369,12 +369,12 @@ def test_different_languages(self):
]
# Test with JavaScript
- js_reward_fn = get_code_format_reward(language="javascript")
+ js_reward_fn = CodeFormatReward(language="javascript")
rewards = js_reward_fn(completion)
self.assertEqual(rewards[0], 1.0)
# Same completion should fail for Python
- py_reward_fn = get_code_format_reward(language="python")
+ py_reward_fn = CodeFormatReward(language="python")
rewards = py_reward_fn(completion)
self.assertEqual(rewards[0], 0.0)
@@ -387,7 +387,7 @@ def test_multiline_code(self):
}
]
]
- reward_fn = get_code_format_reward(language="python")
+ reward_fn = CodeFormatReward(language="python")
rewards = reward_fn(completion)
self.assertEqual(rewards[0], 1.0)