diff --git a/mellea/stdlib/sampling/budget_forcing.py b/mellea/stdlib/sampling/budget_forcing.py
new file mode 100644
index 00000000..a1be15a2
--- /dev/null
+++ b/mellea/stdlib/sampling/budget_forcing.py
@@ -0,0 +1,244 @@
+"""Sampling Strategies for budget forcing generation."""
+
+from copy import deepcopy
+
+import tqdm
+
+from mellea.backends import Backend, BaseModelSubclass
+from mellea.backends.ollama import OllamaModelBackend
+from mellea.helpers.fancy_logger import FancyLogger
+from mellea.stdlib import funcs as mfuncs
+from mellea.stdlib.base import ModelOutputThunk
+from mellea.stdlib.requirement import Requirement, ValidationResult
+from mellea.stdlib.sampling import RejectionSamplingStrategy, SamplingResult
+from mellea.stdlib.sampling.base import Component, Context
+from mellea.stdlib.sampling_algos.budget_forcing_alg import think_budget_forcing
+
+
+class BudgetForcingSamplingStrategy(RejectionSamplingStrategy):
+ """Budget forcing sampling class."""
+
+ think_max_tokens: int | None
+ answer_max_tokens: int | None
+ start_think_token: str | None
+ end_think_token: str | None
+ begin_response_token: str | None
+ end_response_token: str
+ think_more_suffix: str | None
+ answer_suffix: str | None
+
+ def __init__(
+ self,
+ *,
+ think_max_tokens: int | None = 4096,
+ answer_max_tokens: int | None = None,
+ start_think_token: str | None = "",
+ end_think_token: str | None = "",
+ begin_response_token: str | None = "",
+ end_response_token: str = "",
+ think_more_suffix: str | None = "",
+ answer_suffix: str | None = "",
+ loop_budget: int = 1,
+ requirements: list[Requirement] | None,
+ ):
+ r"""Initialize class.
+
+ Inherits from RejectionSamplingStrategy.
+
+ Args:
+            think_max_tokens: Maximum number of tokens allocated to the think block.
+            answer_max_tokens: Maximum number of tokens allocated to the answer portion; if None, the answer length is unbounded.
+            start_think_token: Special token marking the start of the think block, e.g. '<think>'. Defaults to ''.
+            end_think_token: Special token marking the end of the think block, e.g. '</think>'. Defaults to ''.
+            begin_response_token: Special token marking the start of the response block, used by certain models, e.g. '<response>'. Defaults to ''.
+            end_response_token: Special token marking the end of the response block, e.g. '</response>'. Defaults to ''.
+            think_more_suffix: Suffix used to force continued thinking, e.g. "\nWait let's think more carefully". If set to "", no forced thinking is applied and the token budget becomes an upper bound. Defaults to ''.
+            answer_suffix: Suffix appended to elicit the final answer, e.g. "\nThe final answer is:". Defaults to ''.
+            loop_budget: Number of times to iterate through the process. Must be greater than 0.
+            requirements: List of requirements to test against. If None, test all requirements attached to the given instruction.
+
+ Raises:
+ AssertionError: If loop_budget is not greater than 0.
+ """
+ super().__init__(loop_budget=loop_budget, requirements=requirements)
+ self.think_max_tokens = think_max_tokens
+ self.answer_max_tokens = answer_max_tokens
+ self.start_think_token = start_think_token
+ self.end_think_token = end_think_token
+ self.begin_response_token = begin_response_token
+ self.end_response_token = end_response_token
+ self.think_more_suffix = think_more_suffix
+ self.answer_suffix = answer_suffix
+
+ async def sample(
+ self,
+ action: Component,
+ context: Context,
+ backend: Backend,
+ requirements: list[Requirement] | None,
+ *,
+ validation_ctx: Context | None = None,
+ format: type[BaseModelSubclass] | None = None,
+ model_options: dict | None = None,
+ tool_calls: bool = False,
+ show_progress: bool = True,
+ ) -> SamplingResult:
+ """This method performs a sampling operation based on the given instruction.
+
+ Args:
+            action: The action object to be sampled.
+            context: The context to be passed to the sampling strategy.
+            backend: The backend used for generating samples.
+            requirements: List of requirements to test against. Requirements already attached to this strategy take precedence over this list.
+            validation_ctx: Optional context to use for validation. If None, the sampling context is used.
+            format: Output format for structured outputs.
+            model_options: Model options to pass to the backend during generation / validation.
+            tool_calls: True if tool calls should be used during this sampling strategy.
+            show_progress: If True, a tqdm progress bar is used; otherwise, progress messages are still sent to the logger.
+
+ Returns:
+ SamplingResult: A result object indicating the success or failure of the sampling process.
+
+ Raises:
+ AssertionError: Asserts that all required components (repair, select_from_failure, validate, and generate) are provided before proceeding with the sampling.
+ """
+ validation_ctx = validation_ctx if validation_ctx is not None else context
+
+ flog = FancyLogger.get_logger()
+
+ sampled_results: list[ModelOutputThunk] = []
+ sampled_scores: list[list[tuple[Requirement, ValidationResult]]] = []
+ sampled_actions: list[Component] = []
+ sample_contexts: list[Context] = []
+
+ # The `logging_redirect_tqdm` approach did not work, so instead we will use the show_progress
+ # flag to determine whether we should show the pbar.
+ show_progress = show_progress and flog.getEffectiveLevel() <= FancyLogger.INFO
+
+ reqs = []
+ # global requirements supersede local requirements (global requirements can be defined by user)
+ # Todo: re-evaluate if this makes sense
+ if self.requirements is not None:
+ reqs += self.requirements
+ elif requirements is not None:
+ reqs += requirements
+ reqs = list(set(reqs))
+
+ loop_count = 0
+ loop_budget_range_iterator = (
+ tqdm.tqdm(range(self.loop_budget)) # type: ignore
+ if show_progress
+ else range(self.loop_budget) # type: ignore
+ )
+
+ next_action = deepcopy(action)
+ next_context = context
+ for _ in loop_budget_range_iterator: # type: ignore
+ loop_count += 1
+ if not show_progress:
+ flog.info(f"Running loop {loop_count} of {self.loop_budget}")
+
+            # TODO: add tool-call support; it is currently unsupported for budget forcing.
+            assert tool_calls is False, (
+                "tool_calls is not supported with budget forcing"
+            )
+            # TODO: add support for backends other than Ollama.
+            assert isinstance(backend, OllamaModelBackend), (
+                "Only the Ollama backend is supported with budget forcing"
+            )
+ # run a generation pass with budget forcing
+ result = think_budget_forcing(
+ backend,
+ next_action,
+ ctx=context,
+ format=format,
+ tool_calls=tool_calls,
+ think_max_tokens=self.think_max_tokens,
+ answer_max_tokens=self.answer_max_tokens,
+ start_think_token=self.start_think_token,
+                end_think_token=self.end_think_token,
+                begin_response_token=self.begin_response_token,
+ think_more_suffix=self.think_more_suffix,
+ answer_suffix=self.answer_suffix,
+ model_options=model_options,
+ )
+ result_ctx = next_context
+
+ # validation pass
+ val_scores_co = mfuncs.avalidate(
+ reqs=reqs,
+ context=result_ctx,
+ backend=backend,
+ output=result,
+ format=format,
+ model_options=model_options,
+ # tool_calls=tool_calls # Don't support using tool calls in validation strategies.
+ )
+ val_scores = await val_scores_co
+
+ # match up reqs with scores
+ constraint_scores = list(zip(reqs, val_scores))
+
+ # collect all data
+ sampled_results.append(result)
+ sampled_scores.append(constraint_scores)
+ sampled_actions.append(next_action)
+ sample_contexts.append(result_ctx)
+
+ # if all vals are true -- break and return success
+ if all(bool(s[1]) for s in constraint_scores):
+ flog.info("SUCCESS")
+ assert (
+ result._generate_log is not None
+ ) # Cannot be None after generation.
+ result._generate_log.is_final_result = True
+
+ # SUCCESS !!!!
+ return SamplingResult(
+ result_index=len(sampled_results) - 1,
+ success=True,
+ sample_generations=sampled_results,
+ sample_validations=sampled_scores,
+ sample_contexts=sample_contexts,
+ sample_actions=sampled_actions,
+ )
+
+ else:
+ # log partial success and continue
+ count_valid = len([s for s in constraint_scores if bool(s[1])])
+ flog.info(f"FAILED. Valid: {count_valid}/{len(constraint_scores)}")
+
+ # If we did not pass all constraints, update the instruction and try again.
+ next_action, next_context = self.repair(
+ next_context,
+ result_ctx,
+ sampled_actions,
+ sampled_results,
+ sampled_scores,
+ )
+
+ flog.info(
+ f"Invoking select_from_failure after {len(sampled_results)} failed attempts."
+ )
+
+ # if no valid result could be determined, find a last resort.
+ best_failed_index = self.select_from_failure(
+ sampled_actions, sampled_results, sampled_scores
+ )
+        assert best_failed_index < len(sampled_results), (
+            "The select_from_failure method did not return a valid index. It must select from the failed results."
+        )
+
+ assert (
+ sampled_results[best_failed_index]._generate_log is not None
+ ) # Cannot be None after generation.
+ sampled_results[best_failed_index]._generate_log.is_final_result = True # type: ignore
+
+ return SamplingResult(
+ result_index=best_failed_index,
+ success=False,
+ sample_generations=sampled_results,
+ sample_validations=sampled_scores,
+ sample_actions=sampled_actions,
+ sample_contexts=sample_contexts,
+ )
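For orientation, a minimal usage sketch of the new strategy (not part of the diff). It mirrors the test file added later in this patch; the model id, token budgets, and the `<think>`-style tag strings are illustrative assumptions rather than values fixed by the implementation:

# Illustrative sketch only: model id, budgets, and tag strings are assumptions.
from mellea import start_session
from mellea.stdlib.sampling.budget_forcing import BudgetForcingSamplingStrategy

m = start_session("ollama", model_id="granite3.3:8b")  # any think-capable Ollama model

strategy = BudgetForcingSamplingStrategy(
    think_max_tokens=1024,        # cap on the think block
    answer_max_tokens=256,        # cap on the final answer; None means unbounded
    start_think_token="<think>",  # assumed reasoning tags; match your model's template
    end_think_token="</think>",
    think_more_suffix="\nWait, let's think more carefully",  # forces extra thinking
    answer_suffix="\nThe final answer is:",
    loop_budget=2,                # rejection-sampling retries
    requirements=None,            # fall back to the instruction's own requirements
)

result = m.instruct(
    "Compute 1+1. Reason step by step and put the final answer in \\boxed{}.",
    strategy=strategy,
)
print(result)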
diff --git a/mellea/stdlib/sampling_algos/budget_forcing_alg.py b/mellea/stdlib/sampling_algos/budget_forcing_alg.py
new file mode 100644
index 00000000..eae78e6d
--- /dev/null
+++ b/mellea/stdlib/sampling_algos/budget_forcing_alg.py
@@ -0,0 +1,175 @@
+"""Budget forcing implementation."""
+
+import re
+from typing import Any
+
+from mellea.backends import BaseModelSubclass, ModelOption
+from mellea.backends.ollama import OllamaModelBackend
+from mellea.stdlib.base import CBlock, Component, Context, GenerateLog, ModelOutputThunk
+
+
+def think_budget_forcing( # noqa: D417
+ backend: OllamaModelBackend,
+ action: CBlock | Component,
+ *,
+ ctx: Context,
+ format: type[BaseModelSubclass] | None = None,
+ tool_calls: bool = False,
+ think_max_tokens: int | None = 4096,
+ answer_max_tokens: int | None = None,
+ start_think_token: str | None = "",
+ end_think_token: str | None = "",
+ begin_response_token: str | None = "",
+ think_more_suffix: str | None = "",
+ answer_suffix: str | None = "",
+ model_options: dict | None = None,
+) -> ModelOutputThunk:
+ r"""Generate with budget forcing using the completions APIs.
+
+    This relies on raw autocompletion and assumes the model's output is structured as: '<think> ... </think> summary answer'.
+    The budget forcing method is proposed in the paper: https://arxiv.org/abs/2501.19393
+    This implementation follows the key outline of the paper while ensuring stable, fail-safe operation.
+    Generation proceeds in multiple steps: the model is called repeatedly until the budget requirements are met, and the final response is assembled from the partial generations.
+
+ Args:
+        backend: The OllamaModelBackend used for raw completion calls.
+        action: The last item of the context should be passed in as an `action` instead of as part of the `ctx`. See `docs/dev/generate_signature_decisions.md`.
+        think_max_tokens: Budget in number of tokens allocated for the think block.
+        answer_max_tokens: Budget in number of tokens allocated for the summary and answer block; None indicates an unbounded answer, generating until EOS.
+        start_think_token: String indicating the start of the think block, e.g. '<think>'.
+        end_think_token: String indicating the end of the think block, e.g. '</think>'.
+        begin_response_token: String indicating the start of the response block, used by certain models, e.g. '<response>'.
+        think_more_suffix: String appended to force continued thinking, e.g. "\nWait". If None (or ""), no additional thinking is forced and the token budget acts as an upper bound.
+        answer_suffix: String appended to force a final answer; if None, the answer pass is skipped and the assembled think output is returned as-is.
+ model_options: Any model options to upsert into the defaults for this call.
+
+    Assumptions:
+        - The chat template has already been applied to the prompt, with think mode enabled
+        - The model has thinking activated
+        - Enabling prefix caching improves performance
+
+    Limitations:
+        - Does not support batching
+
+    Returns:
+        A ModelOutputThunk containing the assembled response (think block plus answer).
+ """
+ responses = []
+ prompt = backend.formatter.print(action)
+ if start_think_token:
+ prompt += start_think_token
+ responses.append(start_think_token)
+
+ # Generate thinking portion
+ if model_options is None:
+ model_options = dict()
+ model_options["n"] = 1
+ if think_max_tokens is None:
+ think_max_tokens = 0
+ rem_toks = think_max_tokens
+ model_options[ModelOption.MAX_NEW_TOKENS] = rem_toks
+ gen_tok_count = 0
+ curr_prompt = prompt
+ _generate_logs: list[GenerateLog | None] = []
+ _meta_logs: list[dict[str, Any]] = []
+ min_char_len = 10
+
+    # Think block: multi-step generation loop until the user's token budget is satisfied.
+ while True:
+ if rem_toks <= 0: # zero-think case
+ break
+
+ model_options[ModelOption.MAX_NEW_TOKENS] = rem_toks
+ result = backend.generate_from_raw(
+ [CBlock(value=curr_prompt)],
+ model_options=model_options,
+ ctx=ctx,
+ tool_calls=tool_calls,
+ format=format,
+ )
+ _generate_logs.append(result[0]._generate_log)
+ if result[0]._meta is None:
+ raise Exception("Requires meta information in generation results.")
+
+ _meta_logs.append(result[0]._meta)
+ gen_tok_count += result[0]._meta["usage"]["completion_tokens"]
+ rem_toks = think_max_tokens - gen_tok_count
+ response = result[0].value if result[0].value else ""
+
+ if think_more_suffix is None or think_more_suffix == "":
+ # non-strict budget form
+ responses.append(response)
+ break
+
+ if rem_toks <= 0:
+ responses.append(response)
+ break
+
+        else:
+            # Keep only the text before the end-of-think token; if no end token is
+            # configured, treat the whole response as the thinking step.
+            step = response.split(end_think_token)[0] if end_think_token else response
+            # The model produced (almost) no further thoughts, so stop forcing.
+            if len(step.strip()) <= min_char_len:
+                responses.append(response)
+                break
+
+            # Request more thinking by appending the think-more suffix.
+            step = f"{step} {think_more_suffix}"
+            responses.append(step)
+            curr_prompt += step
+
+ response = "".join(responses)
+
+ if answer_suffix is None:
+ # create response ModelOutputThunk object
+ _meta = _meta_logs[-1] # Initialize using the last meta object
+ _meta["usage"]["completion_tokens"] = gen_tok_count
+ # the first prompt is the true prompt
+ _meta["usage"]["prompt_tokens"] = _meta_logs[0]["usage"]["prompt_tokens"]
+ _meta["usage"]["total_tokens"] = (
+ _meta["usage"]["prompt_tokens"] + _meta["usage"]["completion_tokens"]
+ )
+ _res = ModelOutputThunk(value=response, meta=_meta)
+        # Use the last log as a representative log; merging the individual logs is not yet supported.
+ _res._generate_log = _generate_logs[-1]
+ return _res
+
+ # One more round of generate to get an answer
+ if end_think_token and end_think_token not in response:
+ response += f" {end_think_token}"
+
+ if begin_response_token and begin_response_token not in response:
+ response += f" {begin_response_token}"
+
+ if answer_suffix:
+ response += f" {answer_suffix}"
+
+ # update original curr_prompt with assembled response
+ curr_prompt += response
+ if answer_max_tokens is not None:
+ model_options[ModelOption.MAX_NEW_TOKENS] = answer_max_tokens
+
+ else:
+        model_options.pop(ModelOption.MAX_NEW_TOKENS, None)  # no token cap: generate until EOS
+
+ # model_options["logprobs"] = 1 # To get number of generated tokens
+ result = backend.generate_from_raw(
+        [CBlock(value=curr_prompt)],
+ model_options=model_options,
+ ctx=ctx,
+ tool_calls=tool_calls,
+ format=format,
+ )
+    _generate_logs.append(result[0]._generate_log)
+    response += result[0].value if result[0].value else ""
+    if result[0]._meta is None:
+        raise Exception("Requires meta information in generation results.")
+    _meta_logs.append(result[0]._meta)
+    gen_tok_count += result[0]._meta["usage"]["completion_tokens"]
+ # create response ModelOutputThunk object
+ _meta = _meta_logs[-1] # Initialize using the last meta object
+ _meta["usage"]["completion_tokens"] = gen_tok_count
+ # the first prompt is the true prompt
+ _meta["usage"]["prompt_tokens"] = _meta_logs[0]["usage"]["prompt_tokens"]
+ _meta["usage"]["total_tokens"] = (
+ _meta["usage"]["prompt_tokens"] + _meta["usage"]["completion_tokens"]
+ )
+ _res = ModelOutputThunk(value=response, meta=_meta)
+    # Use the last log as a representative log; merging the individual logs is not yet supported.
+ _res._generate_log = _generate_logs[-1]
+ return _res
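The heart of `think_budget_forcing` above is the running token budget: each raw completion call reports its completion_tokens, the remaining budget is recomputed, and the think-more suffix is re-appended until the budget is exhausted or the model stops producing new thoughts. Below is a self-contained sketch of that accounting; `fake_generate` and all constants are stand-ins invented for illustration, not mellea APIs:

# Stand-alone sketch of the budget-forcing think loop. `fake_generate` is a stub
# standing in for backend.generate_from_raw; it is not part of mellea.
def fake_generate(prompt: str, max_new_tokens: int) -> tuple[str, int]:
    # Pretend the model emits some reasoning and closes its think block early.
    return (
        " considering the problem from another angle ... </think>",
        min(max_new_tokens, 50),
    )


def forced_think(
    prompt: str,
    think_max_tokens: int = 120,
    end_think_token: str = "</think>",
    think_more_suffix: str = "\nWait, let's think more carefully",
) -> str:
    responses: list[str] = []
    used = 0
    while think_max_tokens - used > 0:
        text, n_tokens = fake_generate(prompt, think_max_tokens - used)
        used += n_tokens
        if think_max_tokens - used <= 0:
            # Budget exhausted: keep whatever the model produced and stop.
            responses.append(text)
            break
        step = text.split(end_think_token)[0]  # drop the premature end-of-think tag
        if len(step.strip()) <= 10:
            # The model has nothing more to add; stop forcing.
            responses.append(text)
            break
        step = f"{step} {think_more_suffix}"  # ask the model to keep thinking
        responses.append(step)
        prompt += step
    return "".join(responses)


print(forced_think("<think>"))  # three calls: 50 + 50 + 20 tokens against a 120-token budget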
diff --git a/test/stdlib_basics/test_think_budget_forcing.py b/test/stdlib_basics/test_think_budget_forcing.py
new file mode 100644
index 00000000..f6645f66
--- /dev/null
+++ b/test/stdlib_basics/test_think_budget_forcing.py
@@ -0,0 +1,85 @@
+"""Testing functions for budget forcing generation."""
+
+import pytest
+
+from mellea import MelleaSession, start_session
+from mellea.backends import ModelOption
+from mellea.backends.model_ids import OPENAI_GPT_OSS_20B
+from mellea.stdlib.base import CBlock
+from mellea.stdlib.sampling.budget_forcing import BudgetForcingSamplingStrategy
+
+MODEL_ID = OPENAI_GPT_OSS_20B
+
+
+@pytest.fixture(scope="module")
+def m_session(gh_run):
+ """Start default Mellea's session."""
+ if gh_run == 1: # on github
+ m = start_session(
+ "ollama",
+ model_id=MODEL_ID,
+ model_options={ModelOption.MAX_NEW_TOKENS: 5},
+ )
+ else:
+ m = start_session(
+ "ollama",
+ model_id=MODEL_ID,
+ )
+ yield m
+ del m
+
+
+def test_think_big(m_session: MelleaSession, gh_run: int):
+ """Tests big thinking budget."""
+    # Big thinking runs are too slow for GitHub workflows, so skip them there.
+    if gh_run == 1:
+        pytest.skip("Skipping big_thinking runs in gh workflows.")
+
+ prompt = "What is the smallest positive integer $n$ such that all the roots of $z^4 + z^2 + 1 = 0$ are $n^{\\text{th}}$ roots of unity?"
+ prompt_suffix = "\nPlease reason step by step, use \n\n to end each step, and put your final answer within \\boxed{}."
+ action = CBlock(value=prompt + prompt_suffix)
+ THINK_MAX_TOKENS = 2048
+ ANSWER_MAX_TOKENS = 512
+
+ strategy = BudgetForcingSamplingStrategy(
+ think_max_tokens=THINK_MAX_TOKENS,
+ answer_max_tokens=ANSWER_MAX_TOKENS,
+ start_think_token="",
+ end_think_token="",
+ think_more_suffix="\nWait, let's think more carefully",
+ answer_suffix="The final answer is:",
+ requirements=None
+ )
+ result = m_session.instruct(action, strategy=strategy)
+
+ print("\n******\nThink big:")
+ print(str(result))
+
+
+def test_think_little(m_session: MelleaSession, gh_run: int):
+ """Tests little thinking budget."""
+ prompt = "Compute 1+1?"
+ prompt_suffix = "\nPlease reason step by step, use \n\n to end each step, and put your final answer within \\boxed{}."
+ action = CBlock(value=prompt + prompt_suffix)
+ THINK_MAX_TOKENS = 16
+ ANSWER_MAX_TOKENS = 8
+ if gh_run == 1: # on github
+ THINK_MAX_TOKENS = 0
+ ANSWER_MAX_TOKENS = 5
+
+ strategy = BudgetForcingSamplingStrategy(
+ think_max_tokens=THINK_MAX_TOKENS,
+ answer_max_tokens=ANSWER_MAX_TOKENS,
+ start_think_token="",
+ end_think_token="",
+ answer_suffix="The final answer is: \\boxed{",
+ requirements=None
+ )
+ result = m_session.instruct(action, strategy=strategy)
+
+ print("\n******\nThink little:")
+ print(str(result))
+
+
+if __name__ == "__main__":
+ pytest.main(["-s", __file__])
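The `gh_run` fixture used above is not defined in this file and presumably comes from the repository's shared conftest.py. For readers running the test standalone, a hypothetical stand-in could look like the following; the real fixture may read a CLI option or a different environment variable:

# Hypothetical conftest.py stand-in for the `gh_run` fixture; not part of this diff.
import os

import pytest


@pytest.fixture(scope="session")
def gh_run() -> int:
    # 1 when running inside a GitHub Actions workflow, 0 otherwise.
    return 1 if os.environ.get("GITHUB_ACTIONS") == "true" else 0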