diff --git a/.coverage b/.coverage
index 7ac1d3f..44c7c72 100644
Binary files a/.coverage and b/.coverage differ
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 22c8d88..f08867b 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -14,6 +14,11 @@ repos:
rev: 6.0.0
hooks:
- id: flake8
+ - repo: https://github.com/astral-sh/ruff-pre-commit
+ rev: v0.14.6
+ hooks:
+ - id: ruff-check
+ args: [ --fix ]
- repo: https://github.com/pycqa/isort
rev: 5.12.0
hooks:
diff --git a/README.md b/README.md
index 9288857..8f33e84 100644
--- a/README.md
+++ b/README.md
@@ -1,4 +1,3 @@
-

[](https://github.com/finitearth/promptolution/actions/workflows/ci.yml)
@@ -7,104 +6,93 @@

[](https://colab.research.google.com/github/finitearth/promptolution/blob/main/tutorials/getting_started.ipynb)
-Promptolution is a library that provides a modular and extensible framework for implementing prompt tuning for single tasks and larger experiments. It offers a user-friendly interface to assemble the core components for various prompt optimization tasks.
+
+
+
+
+
+
+
+
+
+
+## What is Promptolution?
-This project was developed by [Timo Heiß](https://www.linkedin.com/in/timo-heiss/), [Moritz Schlager](https://www.linkedin.com/in/moritz-schlager/) and [Tom Zehle](https://www.linkedin.com/in/tom-zehle/) as part of a study program at LMU Munich.
+**Promptolution** is a unified, modular framework for prompt optimization, built for researchers and advanced practitioners who want full control over their experimental setup. Unlike highly abstracted end-to-end application frameworks, Promptolution focuses exclusively on the optimization stage and provides a clean, transparent, and extensible API. It scales from simple prompt optimization for a single task up to large-scale, reproducible benchmark experiments.
+
+
+
+### Key Features
-## Installation
+* Implementations of many current prompt optimizers, available out of the box.
+* A unified LLM backend supporting API-based models, local LLMs, and vLLM clusters.
+* Built-in response caching to save costs and parallelized inference for speed.
+* Detailed logging and token usage tracking for granular post-hoc analysis.
-Use pip to install our library:
+Have a look at our [Release Notes](https://finitearth.github.io/promptolution/release-notes/) for the latest updates to promptolution.
+
+## Installation
```
pip install promptolution[api]
```
-If you want to run your prompt optimization locally, either via transformers or vLLM, consider running:
+Local inference via vLLM or transformers:
```
pip install promptolution[vllm,transformers]
```
-Alternatively, clone the repository, run
+From source:
```
+git clone https://github.com/finitearth/promptolution.git
+cd promptolution
poetry install
```
-to install the necessary dependencies. You might need to install [pipx](https://pipx.pypa.io/stable/installation/) and [poetry](https://python-poetry.org/docs/) first.
-
-## Usage
-
-To get started right away, take a look at our [getting started notebook](https://github.com/finitearth/promptolution/blob/main/tutorials/getting_started.ipynb) and our [other demos and tutorials](https://github.com/finitearth/promptolution/blob/main/tutorials).
-For more details, a comprehensive **documentation** with API reference is availabe at https://finitearth.github.io/promptolution/.
+## Quickstart
-### Featured Optimizers
+Start with the **Getting Started tutorial**:
+[https://github.com/finitearth/promptolution/blob/main/tutorials/getting_started.ipynb](https://github.com/finitearth/promptolution/blob/main/tutorials/getting_started.ipynb)
-| **Name** | **Paper** | **init prompts** | **Exploration** | **Costs** | **Parallelizable** | **Utilizes Fewshot Examples** |
-| :-----------: | :----------------------------------------------: | :--------------: | :-------------: | :-------: | :-------------------: | :---------------------------: |
-| `CAPO` | [Zehle et al.](https://arxiv.org/abs/2504.16005) | _required_ | π | π² | β | β |
-| `EvoPromptDE` | [Guo et al.](https://arxiv.org/abs/2309.08532) | _required_ | π | π²π² | β | β |
-| `EvoPromptGA` | [Guo et al.](https://arxiv.org/abs/2309.08532) | _required_ | π | π²π² | β | β |
-| `OPRO` | [Yang et al.](https://arxiv.org/abs/2309.03409) | _optional_ | π | π²π² | β | β |
+Full docs:
+[https://finitearth.github.io/promptolution/](https://finitearth.github.io/promptolution/)
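+
+If you prefer a plain script over the notebook, the overall flow looks roughly like the sketch below. It is based on the `run_experiment` helper and `ExperimentConfig` from this repository; the LLM-related fields (`model_id`, `api_url`, `api_key`) are illustrative assumptions, so check the tutorial for the exact configuration options.
+
+```
+import pandas as pd
+
+from promptolution.helpers import run_experiment
+from promptolution.utils.config import ExperimentConfig
+
+# Dataset with input texts and target labels (column conventions depend on your task setup).
+df = pd.read_csv("data.csv")
+
+config = ExperimentConfig()
+config.optimizer = "capo"  # or "evopromptde", "evopromptga", "opro"
+config.task_description = "Classify the sentiment of a tweet as positive or negative."
+config.n_steps = 10
+# Hypothetical LLM fields -- the exact names are defined by ExperimentConfig and get_llm:
+config.model_id = "gpt-4o-mini"
+config.api_url = "https://api.openai.com/v1"
+config.api_key = "sk-..."
+
+# Optimizes prompts on a train split and returns per-prompt scores on a held-out test split.
+df_prompt_scores = run_experiment(df, config)
+print(df_prompt_scores.head())
+```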
-### Core Components
-
-- `Task`: Encapsulates initial prompts, dataset features, targets, and evaluation methods.
-- `Predictor`: Implements the prediction logic, interfacing between the `Task` and `LLM` components.
-- `LLM`: Unifies the process of obtaining responses from language models, whether locally hosted or accessed via API.
-- `Optimizer`: Implements prompt optimization algorithms, utilizing the other components during the optimization process.
-
-### Key Features
-- Modular and object-oriented design
-- Extensible architecture
-- Easy-to-use interface for assembling experiments
-- Parallelized LLM requests for improved efficiency
-- Integration with langchain for standardized LLM API calls
-- Detailed logging and callback system for optimization analysis
+## Featured Optimizers
-## Changelog
+| **Name** | **Paper** | **Init prompts** | **Exploration** | **Costs** | **Parallelizable** | **Few-shot** |
+| ---- | ---- | ---- | ---- | ---- | ---- | ---- |
+| `CAPO` | [Zehle et al., 2025](https://arxiv.org/abs/2504.16005) | required | π | π² | β | β |
+| `EvoPromptDE` | [Guo et al., 2023](https://arxiv.org/abs/2309.08532) | required | π | π²π² | β | β |
+| `EvoPromptGA` | [Guo et al., 2023](https://arxiv.org/abs/2309.08532) | required | π | π²π² | β | β |
+| `OPRO` | [Yang et al., 2023](https://arxiv.org/abs/2309.03409) | optional | π | π²π² | β | β |
-Release notes for each version of the library can be found [here](https://finitearth.github.io/promptolution/release-notes/)
+## Components
-## Contributing
+* **`Task`** – Manages the dataset, evaluation metrics, and subsampling.
+* **`Predictor`** – Defines how the answer is extracted from the model's response.
+* **`LLM`** – A unified interface handling inference, token counting, and concurrency.
+* **`Optimizer`** – The core component implementing the algorithms that refine prompts.
+* **`ExperimentConfig`** – A configuration abstraction to streamline and parametrize large-scale scientific experiments (see the assembly sketch below).
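+
+For finer-grained control, the components can be assembled directly via the factory helpers in `promptolution.helpers`. The sketch below follows the calls used in `run_optimization` (`get_llm`, `get_predictor`, `get_task`, `get_optimizer`, `optimizer.optimize`); the initial prompt, the choice of optimizer, and the exact `ExperimentConfig` fields shown here are illustrative.
+
+```
+import pandas as pd
+
+from promptolution.helpers import get_llm, get_optimizer, get_predictor, get_task
+from promptolution.utils.config import ExperimentConfig
+from promptolution.utils.prompt import Prompt
+
+df = pd.read_csv("data.csv")
+
+config = ExperimentConfig()
+config.task_description = "Classify the sentiment of a tweet as positive or negative."
+config.prompts = [Prompt("Decide whether the tweet is positive or negative.")]
+
+llm = get_llm(config=config)                   # downstream and meta LLM
+predictor = get_predictor(llm, config=config)  # marker-based predictor by default
+task = get_task(df, config, judge_llm=llm)
+optimizer = get_optimizer(
+    predictor=predictor, meta_llm=llm, task=task, optimizer="evopromptga", config=config
+)
+
+prompts = optimizer.optimize(n_steps=10)       # returns a list of Prompt objects
+```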
-The first step to contributing is to open an issue describing the bug, feature, or enhancements. Ensure the issue is clearly described, assigned, and properly tagged. All work should be linked to an open issue.
+## Contributing
-### Code Style and Linting
+Open an issue → create a branch → PR → CI → review → merge.
+Branch naming: `feature/...`, `fix/...`, `chore/...`, `refactor/...`.
-We use Black for code formatting, Flake8 for linting, pydocstyle for docstring conventions (Google format), and isort to sort imports. All these checks are enforced via pre-commit hooks, which automatically run on every commit. Install the pre-commit hooks to ensure that all checks run automatically:
+Please install the pre-commit hooks, which help keep the code quality high:
```
pre-commit install
-```
-
-To run all checks manually:
-
-```
pre-commit run --all-files
```
-
-### Branch Protection and Merging Guidelines
-
-- The main branch is protected. No direct commits are allowed for non-administrators.
-- Rebase your branch on main before opening a pull request.
-- All contributions must be made on dedicated branches linked to specific issues.
-- Name the branch according to {prefix}/{description} with one of the prefixes fix, feature, chore, or refactor.
-- A pull request must have at least one approval from a code owner before it can be merged into main.
-- CI checks must pass before a pull request can be merged.
-- New releases will only be created by code owners.
-
-### Testing
-
-We use pytest to run tests, and coverage to track code coverage. Tests automatically run on pull requests and pushes to the main branch, but please ensure they also pass locally before pushing!
-To run the tests with coverage locally, use the following commands or your IDE's test runner:
+We also encourage every contributor to write tests that automatically check whether the implementation works as expected:
```
poetry run python -m coverage run -m pytest
-```
-
-To see the coverage report run:
-```
poetry run python -m coverage report
```
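+
+As an illustration, a test for the new `Prompt` class could look like the following sketch (hypothetical test; it assumes `construct_prompt()` embeds the instruction and few-shot examples in the rendered string):
+
+```
+from promptolution.utils.prompt import Prompt
+
+
+def test_prompt_renders_instruction_and_examples():
+    prompt = Prompt("Classify the sentiment.", ["Input: great movie! Output: positive"])
+    rendered = prompt.construct_prompt()
+    # Assumed behaviour: the rendered prompt contains the instruction and the few-shot example.
+    assert "Classify the sentiment." in rendered
+    assert "great movie!" in rendered
+```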
+
+Developed by **Timo Heiß**, **Moritz Schlager**, and **Tom Zehle** (LMU Munich, MCML, ELLIS, TUM, Uni Freiburg).
diff --git a/docs/index.md b/docs/index.md
index c562b8c..5496305 100644
--- a/docs/index.md
+++ b/docs/index.md
@@ -29,5 +29,6 @@ Or clone our GitHub repository:
- [Optimizers](api/optimizers.md)
- [Predictors](api/predictors.md)
- [Tasks](api/tasks.md)
-- [Callbacks](api/callbacks.md)
-- [Config](api/config.md)
+- [Helpers](api/helpers.md)
+- [Utils](api/utils.md)
+- [Exemplar Selectors](api/examplar_selectors.md)
diff --git a/docs/release-notes/v2.2.0.md b/docs/release-notes/v2.2.0.md
new file mode 100644
index 0000000..8724a41
--- /dev/null
+++ b/docs/release-notes/v2.2.0.md
@@ -0,0 +1,13 @@
+## Release v2.2.0
+### What's changed
+
+#### Added features:
+* Extended the APILLM interface to allow passing kwargs to the API
+* Improved asynchronous parallelization of LLM calls, shortening inference times
+* Introduced a `Prompt` class to encapsulate instructions and few-shot examples
+
+#### Further changes:
+* Improved error handling
+* Improved task-description infusion mechanism for meta-prompts
+
+**Full Changelog**: [here](https://github.com/finitearth/promptolution/compare/2.1.0...v2.2.0)
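+
+A minimal sketch of the new `Prompt` class as it is used elsewhere in this release (the constructor and `construct_prompt()` follow the code changes; the example strings are illustrative):
+
+```
+from promptolution.utils.prompt import Prompt
+
+# Instruction-only prompt
+p = Prompt("Summarize the given article in one sentence.")
+
+# Prompt with few-shot examples attached
+p_fs = Prompt(p.instruction, ["Article: ... Summary: ..."])
+
+# Render the instruction plus few-shot examples into a single prompt string
+print(p_fs.construct_prompt())
+```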
diff --git a/mkdocs.yml b/mkdocs.yml
index 57cde7a..ac377fb 100644
--- a/mkdocs.yml
+++ b/mkdocs.yml
@@ -47,6 +47,7 @@ nav:
- Home: index.md
- Release Notes:
- Overview: release-notes.md
+ - v2.2.0: release-notes/v2.2.0.md
- v2.1.0: release-notes/v2.1.0.md
- v2.0.1: release-notes/v2.0.1.md
- v2.0.0: release-notes/v2.0.0.md
diff --git a/promptolution/exemplar_selectors/__init__.py b/promptolution/exemplar_selectors/__init__.py
index 62e6c9a..e948a3a 100644
--- a/promptolution/exemplar_selectors/__init__.py
+++ b/promptolution/exemplar_selectors/__init__.py
@@ -2,3 +2,8 @@
from promptolution.exemplar_selectors.random_search_selector import RandomSearchSelector
from promptolution.exemplar_selectors.random_selector import RandomSelector
+
+__all__ = [
+ "RandomSelector",
+ "RandomSearchSelector",
+]
diff --git a/promptolution/exemplar_selectors/base_exemplar_selector.py b/promptolution/exemplar_selectors/base_exemplar_selector.py
index bb2ee21..5d77647 100644
--- a/promptolution/exemplar_selectors/base_exemplar_selector.py
+++ b/promptolution/exemplar_selectors/base_exemplar_selector.py
@@ -5,6 +5,8 @@
from typing import TYPE_CHECKING, Optional
+from promptolution.utils.prompt import Prompt
+
if TYPE_CHECKING: # pragma: no cover
from promptolution.predictors.base_predictor import BasePredictor
from promptolution.tasks.base_task import BaseTask
@@ -33,11 +35,11 @@ def __init__(self, task: "BaseTask", predictor: "BasePredictor", config: Optiona
config.apply_to(self)
@abstractmethod
- def select_exemplars(self, prompt: str, n_examples: int = 5) -> str:
+ def select_exemplars(self, prompt: Prompt, n_examples: int = 5) -> Prompt:
"""Select exemplars based on the given prompt.
Args:
- prompt (str): The input prompt to base the exemplar selection on.
+ prompt (Prompt): The input prompt to base the exemplar selection on.
n_examples (int, optional): The number of exemplars to select. Defaults to 5.
Returns:
diff --git a/promptolution/exemplar_selectors/random_search_selector.py b/promptolution/exemplar_selectors/random_search_selector.py
index 7a88b08..b8cb6ee 100644
--- a/promptolution/exemplar_selectors/random_search_selector.py
+++ b/promptolution/exemplar_selectors/random_search_selector.py
@@ -1,6 +1,7 @@
"""Random search exemplar selector."""
from promptolution.exemplar_selectors.base_exemplar_selector import BaseExemplarSelector
+from promptolution.utils.prompt import Prompt
class RandomSearchSelector(BaseExemplarSelector):
@@ -10,7 +11,7 @@ class RandomSearchSelector(BaseExemplarSelector):
evaluates their performance, and selects the best performing set.
"""
- def select_exemplars(self, prompt: str, n_trials: int = 5) -> str:
+ def select_exemplars(self, prompt: Prompt, n_trials: int = 5) -> Prompt:
"""Select exemplars using a random search strategy.
This method generates multiple sets of random examples, evaluates their performance
@@ -21,7 +22,7 @@ def select_exemplars(self, prompt: str, n_trials: int = 5) -> str:
n_trials (int, optional): The number of random trials to perform. Defaults to 5.
Returns:
- str: The best performing prompt, which includes the original prompt and the selected exemplars.
+ Prompt: The best performing prompt, which includes the original prompt and the selected exemplars.
"""
best_score = 0.0
best_prompt = prompt
@@ -30,7 +31,7 @@ def select_exemplars(self, prompt: str, n_trials: int = 5) -> str:
_, seq = self.task.evaluate(
prompt, self.predictor, eval_strategy="subsample", return_seq=True, return_agg_scores=False
)
- prompt_with_examples = "\n\n".join([prompt] + [seq[0][0]]) + "\n\n"
+ prompt_with_examples = Prompt(prompt.instruction, [seq[0][0]])
# evaluate prompts as few shot prompt
score = self.task.evaluate(prompt_with_examples, self.predictor, eval_strategy="subsample")[0]
if score > best_score:
diff --git a/promptolution/exemplar_selectors/random_selector.py b/promptolution/exemplar_selectors/random_selector.py
index a6a4b72..7b0ae0f 100644
--- a/promptolution/exemplar_selectors/random_selector.py
+++ b/promptolution/exemplar_selectors/random_selector.py
@@ -5,6 +5,7 @@
from typing import TYPE_CHECKING, List, Optional
from promptolution.exemplar_selectors.base_exemplar_selector import BaseExemplarSelector
+from promptolution.utils.prompt import Prompt
if TYPE_CHECKING: # pragma: no cover
from promptolution.predictors.base_predictor import BasePredictor
@@ -37,18 +38,18 @@ def __init__(
self.desired_score = desired_score
super().__init__(task, predictor, config)
- def select_exemplars(self, prompt: str, n_examples: int = 5) -> str:
+ def select_exemplars(self, prompt: Prompt, n_examples: int = 5) -> Prompt:
"""Select exemplars using a random selection strategy.
This method generates random examples and selects those that are evaluated as correct
(score == self.desired_score) until the desired number of exemplars is reached.
Args:
- prompt (str): The input prompt to base the exemplar selection on.
+ prompt (Prompt): The input prompt to base the exemplar selection on.
n_examples (int, optional): The number of exemplars to select. Defaults to 5.
Returns:
- str: A new prompt that includes the original prompt and the selected exemplars.
+ Prompt: A new prompt that includes the original prompt and the selected exemplars.
"""
examples: List[str] = []
while len(examples) < n_examples:
@@ -59,4 +60,4 @@ def select_exemplars(self, prompt: str, n_examples: int = 5) -> str:
seq = seqs[0][0]
if score == self.desired_score:
examples.append(seq)
- return "\n\n".join([prompt] + examples) + "\n\n"
+ return Prompt(prompt.instruction, examples)
diff --git a/promptolution/helpers.py b/promptolution/helpers.py
index 2594609..a25c008 100644
--- a/promptolution/helpers.py
+++ b/promptolution/helpers.py
@@ -1,10 +1,11 @@
"""Helper functions for the usage of the libary."""
-
-from typing import TYPE_CHECKING, Callable, List, Literal, Optional
+from typing import TYPE_CHECKING, Callable, List, Literal, Optional, Union, cast
from promptolution.tasks.judge_tasks import JudgeTask
from promptolution.tasks.reward_tasks import RewardTask
+from promptolution.utils.prompt import Prompt
+from promptolution.utils.prompt_creation import create_prompts_from_task_description
if TYPE_CHECKING: # pragma: no cover
from promptolution.exemplar_selectors.base_exemplar_selector import BaseExemplarSelector
@@ -28,17 +29,8 @@
from promptolution.optimizers.evoprompt_de import EvoPromptDE
from promptolution.optimizers.evoprompt_ga import EvoPromptGA
from promptolution.optimizers.opro import OPRO
-from promptolution.optimizers.templates import (
- CAPO_CROSSOVER_TEMPLATE,
- CAPO_MUTATION_TEMPLATE,
- EVOPROMPT_DE_TEMPLATE,
- EVOPROMPT_DE_TEMPLATE_TD,
- EVOPROMPT_GA_TEMPLATE,
- EVOPROMPT_GA_TEMPLATE_TD,
- OPRO_TEMPLATE,
- OPRO_TEMPLATE_TD,
-)
-from promptolution.predictors.classifier import FirstOccurrenceClassifier, MarkerBasedClassifier
+from promptolution.predictors.first_occurrence_predictor import FirstOccurrencePredictor
+from promptolution.predictors.maker_based_predictor import MarkerBasedPredictor
from promptolution.tasks.classification_tasks import ClassificationTask
from promptolution.utils.logging import get_logger
@@ -59,12 +51,13 @@ def run_experiment(df: pd.DataFrame, config: "ExperimentConfig") -> pd.DataFrame
train_df = df.sample(frac=0.8, random_state=42)
test_df = df.drop(train_df.index)
prompts = run_optimization(train_df, config)
- df_prompt_scores = run_evaluation(test_df, config, prompts)
+ prompts_str = [p.construct_prompt() for p in prompts]
+ df_prompt_scores = run_evaluation(test_df, config, prompts_str)
return df_prompt_scores
-def run_optimization(df: pd.DataFrame, config: "ExperimentConfig") -> List[str]:
+def run_optimization(df: pd.DataFrame, config: "ExperimentConfig") -> List[Prompt]:
"""Run the optimization phase of the experiment.
Configures all LLMs (downstream, meta, and judge) to use
@@ -74,12 +67,18 @@ def run_optimization(df: pd.DataFrame, config: "ExperimentConfig") -> List[str]:
config (Config): Configuration object for the experiment.
Returns:
- List[str]: The optimized list of prompts.
+ List[Prompt]: The optimized list of prompts.
"""
llm = get_llm(config=config)
predictor = get_predictor(llm, config=config)
- config.task_description = (config.task_description or "") + " " + (predictor.extraction_description or "")
+ if getattr(config, "prompts", None) is None:
+ initial_prompts = create_prompts_from_task_description(
+ task_description=config.task_description,
+ llm=llm,
+ )
+ config.prompts = [Prompt(p) for p in initial_prompts]
+
if config.optimizer == "capo" and (config.eval_strategy is None or "block" not in config.eval_strategy):
logger.warning("π CAPO requires block evaluation strategy. Setting it to 'sequential_block'.")
config.eval_strategy = "sequential_block"
@@ -94,14 +93,15 @@ def run_optimization(df: pd.DataFrame, config: "ExperimentConfig") -> List[str]:
logger.warning("π₯ Starting optimization...")
prompts = optimizer.optimize(n_steps=config.n_steps)
- if hasattr(config, "prepend_exemplars") and config.prepend_exemplars:
+ if hasattr(config, "posthoc_exemplar_selection") and config.posthoc_exemplar_selection:
selector = get_exemplar_selector(config.exemplar_selector, task, predictor)
prompts = [selector.select_exemplars(p, n_examples=config.n_exemplars) for p in prompts]
-
return prompts
-def run_evaluation(df: pd.DataFrame, config: "ExperimentConfig", prompts: List[str]) -> pd.DataFrame:
+def run_evaluation(
+ df: pd.DataFrame, config: "ExperimentConfig", prompts: Union[List[Prompt], List[str]]
+) -> pd.DataFrame:
"""Run the evaluation phase of the experiment.
Configures all LLMs (downstream, meta, and judge) to use
@@ -119,8 +119,13 @@ def run_evaluation(df: pd.DataFrame, config: "ExperimentConfig", prompts: List[s
task = get_task(df, config, judge_llm=llm)
predictor = get_predictor(llm, config=config)
logger.warning("π Starting evaluation...")
+ if isinstance(prompts[0], str):
+ str_prompts = cast(List[str], prompts)
+ prompts = [Prompt(p) for p in str_prompts]
+ else:
+ str_prompts = [p.construct_prompt() for p in cast(List[Prompt], prompts)]
scores = task.evaluate(prompts, predictor, eval_strategy="full")
- df = pd.DataFrame(dict(prompt=prompts, score=scores))
+ df = pd.DataFrame(dict(prompt=str_prompts, score=scores))
df = df.sort_values("score", ascending=False, ignore_index=True)
return df
@@ -220,50 +225,27 @@ def get_optimizer(
ValueError: If an unknown optimizer type is specified
"""
final_optimizer = optimizer or (config.optimizer if config else None)
- final_task_description = task_description or (config.task_description if config else None)
+ if config is None:
+ config = ExperimentConfig()
+ if task_description is not None:
+ config.task_description = task_description
if final_optimizer == "capo":
- crossover_template = (
- CAPO_CROSSOVER_TEMPLATE.replace("", final_task_description)
- if final_task_description
- else CAPO_CROSSOVER_TEMPLATE
- )
- mutation_template = (
- CAPO_MUTATION_TEMPLATE.replace("", final_task_description)
- if final_task_description
- else CAPO_MUTATION_TEMPLATE
- )
-
return CAPO(
predictor=predictor,
meta_llm=meta_llm,
task=task,
- crossover_template=crossover_template,
- mutation_template=mutation_template,
config=config,
)
if final_optimizer == "evopromptde":
- template = (
- EVOPROMPT_DE_TEMPLATE_TD.replace("", final_task_description)
- if final_task_description
- else EVOPROMPT_DE_TEMPLATE
- )
- return EvoPromptDE(predictor=predictor, meta_llm=meta_llm, task=task, prompt_template=template, config=config)
+ return EvoPromptDE(predictor=predictor, meta_llm=meta_llm, task=task, config=config)
if final_optimizer == "evopromptga":
- template = (
- EVOPROMPT_GA_TEMPLATE_TD.replace("", final_task_description)
- if final_task_description
- else EVOPROMPT_GA_TEMPLATE
- )
- return EvoPromptGA(predictor=predictor, meta_llm=meta_llm, task=task, prompt_template=template, config=config)
+ return EvoPromptGA(predictor=predictor, meta_llm=meta_llm, task=task, config=config)
if final_optimizer == "opro":
- template = (
- OPRO_TEMPLATE_TD.replace("", final_task_description) if final_task_description else OPRO_TEMPLATE
- )
- return OPRO(predictor=predictor, meta_llm=meta_llm, task=task, prompt_template=template, config=config)
+ return OPRO(predictor=predictor, meta_llm=meta_llm, task=task, config=config)
raise ValueError(f"Unknown optimizer: {final_optimizer}")
@@ -296,23 +278,23 @@ def get_predictor(downstream_llm=None, type: "PredictorType" = "marker", *args,
"""Factory function to create and return a predictor instance.
This function supports three types of predictors:
- 1. FirstOccurrenceClassifier: A predictor that classifies based on first occurrence of the label.
- 2. MarkerBasedClassifier: A predictor that classifies based on a marker.
+ 1. FirstOccurrencePredictor: A predictor that classifies based on first occurrence of the label.
+ 2. MarkerBasedPredictor: A predictor that classifies based on a marker.
Args:
downstream_llm: The language model to use for prediction.
type (Literal["first_occurrence", "marker"]): The type of predictor to create:
- - "first_occurrence" for FirstOccurrenceClassifier
- - "marker" (default) for MarkerBasedClassifier
+ - "first_occurrence" for FirstOccurrencePredictor
+ - "marker" (default) for MarkerBasedPredictor
*args: Variable length argument list passed to the predictor constructor.
**kwargs: Arbitrary keyword arguments passed to the predictor constructor.
Returns:
- An instance of FirstOccurrenceClassifier or MarkerBasedClassifier.
+ An instance of FirstOccurrencePredictor or MarkerBasedPredictor.
"""
if type == "first_occurrence":
- return FirstOccurrenceClassifier(downstream_llm, *args, **kwargs)
+ return FirstOccurrencePredictor(downstream_llm, *args, **kwargs)
elif type == "marker":
- return MarkerBasedClassifier(downstream_llm, *args, **kwargs)
+ return MarkerBasedPredictor(downstream_llm, *args, **kwargs)
else:
raise ValueError(f"Invalid predictor type: '{type}'")
diff --git a/promptolution/llms/__init__.py b/promptolution/llms/__init__.py
index 7fd7b97..8110f87 100644
--- a/promptolution/llms/__init__.py
+++ b/promptolution/llms/__init__.py
@@ -1,6 +1,11 @@
"""Module for Large Language Models."""
-
from promptolution.llms.api_llm import APILLM
from promptolution.llms.local_llm import LocalLLM
from promptolution.llms.vllm import VLLM
+
+__all__ = [
+ "APILLM",
+ "LocalLLM",
+ "VLLM",
+]
diff --git a/promptolution/llms/api_llm.py b/promptolution/llms/api_llm.py
index 093478e..c6971a6 100644
--- a/promptolution/llms/api_llm.py
+++ b/promptolution/llms/api_llm.py
@@ -1,141 +1,241 @@
"""Module to interface with various language models through their respective APIs."""
-try:
- import asyncio
+import asyncio
+import threading
+from concurrent.futures import TimeoutError as FuturesTimeout
- from openai import AsyncOpenAI
- from openai.types.chat import ChatCompletion, ChatCompletionMessageParam
+from openai import AsyncOpenAI
+from openai.types.chat import ChatCompletion
- import_successful = True
-except ImportError:
- import_successful = False
-
-
-from typing import TYPE_CHECKING, Dict, List, Optional
+from typing import Any, Dict, List, Optional
from promptolution.llms.base_llm import BaseLLM
-
-if TYPE_CHECKING: # pragma: no cover
- from promptolution.utils.config import ExperimentConfig
-
+from promptolution.utils.config import ExperimentConfig
from promptolution.utils.logging import get_logger
logger = get_logger(__name__)
-async def _invoke_model(
- prompt: str,
- system_prompt: str,
- max_tokens: int,
- model_id: str,
- client: AsyncOpenAI,
- semaphore: asyncio.Semaphore,
- max_retries: int = 20,
- retry_delay: float = 5,
-) -> ChatCompletion:
- async with semaphore:
- messages: List[ChatCompletionMessageParam] = [
- {"role": "system", "content": system_prompt},
- {"role": "user", "content": prompt},
- ]
-
- for attempt in range(max_retries + 1): # +1 for the initial attempt
- try:
- response = await client.chat.completions.create(
- model=model_id,
- messages=messages,
- max_tokens=max_tokens,
- )
- return response
- except Exception as e:
- if attempt < max_retries:
- # Calculate exponential backoff with jitter
- logger.warning(
- f"β οΈ API call failed (attempt {attempt + 1} / {max_retries + 1}): {str(e)}. "
- f"Retrying in {retry_delay:.2f} seconds..."
- )
- await asyncio.sleep(retry_delay)
- else:
- # Log the final failure and re-raise the exception
- logger.error(f"β API call failed after {max_retries + 1} attempts: {str(e)}")
- raise # Re-raise the exception after all retries fail
- raise RuntimeError("Failed to get response after multiple retries.")
-
-
class APILLM(BaseLLM):
- """A class to interface with language models through their respective APIs.
-
- This class provides a unified interface for making API calls to language models
- using the OpenAI client library. It handles rate limiting through semaphores
- and supports both synchronous and asynchronous operations.
-
- Attributes:
- model_id (str): Identifier for the model to use.
- client (AsyncOpenAI): The initialized API client.
- max_tokens (int): Maximum number of tokens in model responses.
- semaphore (asyncio.Semaphore): Semaphore to limit concurrent API calls.
- """
+ """Persistent asynchronous LLM wrapper using a background event loop."""
def __init__(
self,
api_url: Optional[str] = None,
model_id: Optional[str] = None,
api_key: Optional[str] = None,
- max_concurrent_calls: int = 50,
- max_tokens: int = 512,
+ max_concurrent_calls: int = 32,
+ max_tokens: int = 4096,
+ call_timeout_s: float = 200.0, # per request
+ gather_timeout_s: float = 500.0, # whole batch
+ max_retries: int = 5,
+ retry_base_delay_s: float = 1,
+ client_kwargs: Optional[Dict[str, Any]] = None,
+ call_kwargs: Optional[Dict[str, Any]] = None,
config: Optional["ExperimentConfig"] = None,
) -> None:
- """Initialize the APILLM with a specific model and API configuration.
+ """Initialize the APILLM.
Args:
- api_url (str): The base URL for the API endpoint.
- model_id (str): Identifier for the model to use.
- api_key (str, optional): API key for authentication. Defaults to None.
- max_concurrent_calls (int, optional): Maximum number of concurrent API calls. Defaults to 50.
- max_tokens (int, optional): Maximum number of tokens in model responses. Defaults to 512.
- config (ExperimentConfig, optional): Configuration for the LLM, overriding defaults.
-
- Raises:
- ImportError: If required libraries are not installed.
+ api_url (Optional[str]): Base URL for the API endpoint.
+ model_id (Optional[str]): Identifier of the model to call. Must be set.
+ api_key (Optional[str]): API key/token for authentication.
+ max_concurrent_calls (int): Maximum number of concurrent API calls.
+ max_tokens (int): Default maximum number of tokens in model responses.
+ call_timeout_s (float): Per-call timeout in seconds.
+ gather_timeout_s (float): Timeout in seconds for the entire batch.
+ max_retries (int): Number of retry attempts per prompt in addition to the initial call.
+ retry_base_delay_s (float): Base delay in seconds for exponential backoff between retries.
+ client_kwargs (Optional[Dict[str, Any]]): Additional keyword arguments passed to `AsyncOpenAI(...)`.
+ call_kwargs (Optional[Dict[str, Any]]): Additional keyword arguments passed to `client.chat.completions.create(...)`.
+ config (Optional[ExperimentConfig]): Configuration for the LLM, overriding defaults.
"""
- if not import_successful:
- raise ImportError(
- "Could not import at least one of the required libraries: openai, asyncio. "
- "Please ensure they are installed in your environment."
- )
-
self.api_url = api_url
self.model_id = model_id
self.api_key = api_key
- self.max_concurrent_calls = max_concurrent_calls
self.max_tokens = max_tokens
+ self.call_timeout_s = call_timeout_s
+ self.gather_timeout_s = gather_timeout_s
+ self.max_retries = max_retries
+ self.retry_base_delay_s = retry_base_delay_s
+
+ # extra kwargs
+ self._client_kwargs: Dict[str, Any] = dict(client_kwargs or {})
+ self._call_kwargs: Dict[str, Any] = dict(call_kwargs or {})
+ self.max_concurrent_calls = max_concurrent_calls
super().__init__(config=config)
- self.client = AsyncOpenAI(base_url=self.api_url, api_key=self.api_key)
- self.semaphore = asyncio.Semaphore(self.max_concurrent_calls)
+
+ # --- persistent loop + semaphore ---
+ self._loop = asyncio.new_event_loop()
+ self._sem = asyncio.Semaphore(self.max_concurrent_calls)
+
+ def _run_loop() -> None:
+ """Run the background event loop forever."""
+ asyncio.set_event_loop(self._loop)
+ self._loop.run_forever()
+
+ self._thread = threading.Thread(target=_run_loop, name="APILLMLoop", daemon=True)
+ self._thread.start()
+
+ # Create client once; can still be customised via client_kwargs.
+ self.client = AsyncOpenAI(
+ base_url=self.api_url,
+ api_key=self.api_key,
+ timeout=self.call_timeout_s,
+ **self._client_kwargs,
+ )
+
+ # ---------- async bits that run inside the loop ----------
+ async def _ainvoke_once(self, prompt: str, system_prompt: str) -> ChatCompletion:
+ """Perform a single API call with a per-call timeout.
+
+ Args:
+ prompt (str): User prompt content.
+ system_prompt (str): System-level instructions for the model.
+
+ Returns:
+ ChatCompletion: Raw completion response from the API.
+
+ Raises:
+ asyncio.TimeoutError: If the call exceeds `call_timeout_s`.
+ Exception: Any exception raised by the underlying client call.
+ """
+ messages = [
+ {"role": "system", "content": str(system_prompt)},
+ {"role": "user", "content": str(prompt)},
+ ]
+
+ # base kwargs; user can override via call_kwargs
+ kwargs: Dict[str, Any] = {
+ "model": self.model_id,
+ "messages": messages,
+ "max_tokens": self.max_tokens,
+ }
+ kwargs.update(self._call_kwargs)
+
+ async with self._sem:
+ # per-call timeout enforces failure instead of hang
+ return await asyncio.wait_for(
+ self.client.chat.completions.create(**kwargs),
+ timeout=self.call_timeout_s,
+ )
+
+ async def _ainvoke_with_retries(self, prompt: str, system_prompt: str) -> str:
+ """Invoke the model with retries and exponential backoff.
+
+ Args:
+ prompt (str): User prompt content.
+ system_prompt (str): System-level instructions for the model.
+
+ Returns:
+ str: The message content of the first choice in the completion.
+
+ Raises:
+ Exception: The last exception encountered after all retries are exhausted.
+ """
+ last_err: Optional[Exception] = None
+ for attempt in range(self.max_retries + 1):
+ try:
+ r = await self._ainvoke_once(prompt, system_prompt)
+ content = r.choices[0].message.content
+ if content is None:
+ raise RuntimeError("Empty content from model")
+ return content
+ except Exception as e:
+ last_err = e
+ if attempt < self.max_retries:
+ delay = self.retry_base_delay_s * (2**attempt)
+ logger.error(
+ f"LLM call failed ({attempt + 1}/{self.max_retries + 1}): β retrying in {delay}s", exc_info=e
+ )
+ await asyncio.sleep(delay)
+ assert last_err is not None
+ raise last_err
+
+ async def _aget_batch(self, prompts: List[str], system_prompts: List[str]) -> List[str]:
+ """Execute a batch of prompts concurrently and collect responses.
+
+ Args:
+ prompts (List[str]): List of user prompts.
+ system_prompts (List[str]): List of system prompts; must match `prompts` in length.
+
+ Returns:
+ List[str]: List of model outputs. For failed entries, an empty string is inserted.
+
+ Raises:
+ TimeoutError: If the entire batch exceeds `gather_timeout_s`.
+ RuntimeError: If any of the tasks fails; the first exception is propagated.
+ """
+ tasks = [asyncio.create_task(self._ainvoke_with_retries(p, s)) for p, s in zip(prompts, system_prompts)]
+
+ try:
+ results = await asyncio.wait_for(
+ asyncio.gather(*tasks, return_exceptions=True),
+ timeout=self.gather_timeout_s,
+ )
+ except asyncio.TimeoutError:
+ for t in tasks:
+ t.cancel()
+ raise TimeoutError(f"LLM batch timed out after {self.gather_timeout_s}s")
+
+ outs: List[str] = []
+ first_exc: Optional[BaseException] = None
+ for r in results:
+ if isinstance(r, BaseException):
+ if first_exc is None:
+ first_exc = r
+ outs.append("")
+ else:
+ outs.append(r)
+
+ if first_exc:
+ for t in tasks:
+ if not t.done():
+ t.cancel()
+ raise RuntimeError(f"LLM batch failed: {first_exc}") from first_exc
+
+ return outs
+
+ # ---------- sync API used by the threads ----------
+ def _submit(self, coro):
+ """Submit a coroutine to the background event loop.
+
+ Args:
+ coro: Coroutine object to be scheduled on the loop.
+
+ Returns:
+ concurrent.futures.Future: Future representing the coroutine result.
+ """
+ return asyncio.run_coroutine_threadsafe(coro, self._loop)
def _get_response(self, prompts: List[str], system_prompts: List[str]) -> List[str]:
- # Setup for async execution in sync context
+ """Synchronously obtain responses for a batch of prompts.
+
+ This is the main entrypoint used by external callers. It handles system
+ prompt broadcasting and delegates the actual work to the async batch
+ execution on the background loop.
+
+ Args:
+ prompts (List[str]): List of user prompts.
+ system_prompts (List[str]): List of system prompts. If a single system
+ prompt is provided and multiple prompts are given, the system
+ prompt is broadcast to all prompts. Otherwise, the list is
+ normalized to match the length of `prompts`.
+
+ Returns:
+ List[str]: List of model responses corresponding to `prompts`.
+
+ Raises:
+ TimeoutError: If waiting on the batch future exceeds `gather_timeout_s + 5.0`.
+ Exception: Any underlying error from the async batch execution.
+ """
+ fut = self._submit(self._aget_batch(prompts, system_prompts))
try:
- loop = asyncio.get_running_loop()
- except RuntimeError: # 'get_running_loop' raises a RuntimeError if there is no running loop
- loop = asyncio.new_event_loop()
- asyncio.set_event_loop(loop)
-
- responses = loop.run_until_complete(self._get_response_async(prompts, system_prompts))
- return responses
-
- async def _get_response_async(self, prompts: List[str], system_prompts: List[str]) -> List[str]:
- assert self.model_id is not None, "model_id must be set"
- tasks = [
- _invoke_model(prompt, system_prompt, self.max_tokens, self.model_id, self.client, self.semaphore)
- for prompt, system_prompt in zip(prompts, system_prompts)
- ]
- messages = await asyncio.gather(*tasks)
- responses = []
- for message in messages:
- response = message.choices[0].message.content
- if response is None:
- raise ValueError("Received None response from the API.")
- responses.append(response)
- return responses
+ r = fut.result(timeout=self.gather_timeout_s + 5.0)
+ return r
+ except FuturesTimeout:
+ fut.cancel()
+ raise TimeoutError(f"LLM batch (future) timed out after {self.gather_timeout_s + 5.0}s")
+ except Exception:
+ raise
diff --git a/promptolution/llms/base_llm.py b/promptolution/llms/base_llm.py
index 2fe43f9..2007a10 100644
--- a/promptolution/llms/base_llm.py
+++ b/promptolution/llms/base_llm.py
@@ -9,8 +9,8 @@
from promptolution.utils.config import ExperimentConfig
from transformers import PreTrainedTokenizer
-from promptolution.optimizers.templates import DEFAULT_SYS_PROMPT
from promptolution.utils.logging import get_logger
+from promptolution.utils.templates import DEFAULT_SYS_PROMPT
logger = get_logger(__name__)
@@ -42,7 +42,7 @@ def __init__(self, config: Optional["ExperimentConfig"] = None):
# Initialize token counters
self.input_token_count = 0
self.output_token_count = 0
- self.tokenizer: Optional[PreTrainedTokenizer] = None
+ self.tokenizer: Optional["PreTrainedTokenizer"] = None
def get_token_count(self) -> Dict[str, int]:
"""Get the current count of input and output tokens.
diff --git a/promptolution/llms/vllm.py b/promptolution/llms/vllm.py
index 1df5121..f22ff52 100644
--- a/promptolution/llms/vllm.py
+++ b/promptolution/llms/vllm.py
@@ -1,10 +1,10 @@
"""Module for running language models locally using the vLLM library."""
-
from typing import TYPE_CHECKING, Any, Dict, List, Optional
if TYPE_CHECKING: # pragma: no cover
from promptolution.utils.config import ExperimentConfig
+ from transformers import PreTrainedTokenizer
from promptolution.llms.base_llm import BaseLLM
@@ -14,7 +14,6 @@
try:
from transformers import AutoTokenizer # type: ignore
- from transformers import PreTrainedTokenizer
from vllm import LLM, SamplingParams
imports_successful = True
@@ -38,7 +37,7 @@ class VLLM(BaseLLM):
update_token_count: Update the token count based on the given inputs and outputs.
"""
- tokenizer: PreTrainedTokenizer
+ tokenizer: "PreTrainedTokenizer"
def __init__(
self,
diff --git a/promptolution/optimizers/__init__.py b/promptolution/optimizers/__init__.py
index 47f78a3..4b7a7db 100644
--- a/promptolution/optimizers/__init__.py
+++ b/promptolution/optimizers/__init__.py
@@ -1,23 +1,13 @@
"""Module for prompt optimizers."""
-
from promptolution.optimizers.capo import CAPO
from promptolution.optimizers.evoprompt_de import EvoPromptDE
from promptolution.optimizers.evoprompt_ga import EvoPromptGA
from promptolution.optimizers.opro import OPRO
-from promptolution.optimizers.templates import (
- CAPO_CROSSOVER_TEMPLATE,
- CAPO_DOWNSTREAM_TEMPLATE,
- CAPO_FEWSHOT_TEMPLATE,
- CAPO_MUTATION_TEMPLATE,
- DEFAULT_SYS_PROMPT,
- EVOPROMPT_DE_TEMPLATE,
- EVOPROMPT_DE_TEMPLATE_TD,
- EVOPROMPT_GA_TEMPLATE,
- EVOPROMPT_GA_TEMPLATE_TD,
- OPRO_TEMPLATE,
- OPRO_TEMPLATE_TD,
- PROMPT_CREATION_TEMPLATE,
- PROMPT_CREATION_TEMPLATE_TD,
- PROMPT_VARIATION_TEMPLATE,
-)
+
+__all__ = [
+ "CAPO",
+ "EvoPromptDE",
+ "EvoPromptGA",
+ "OPRO",
+]
diff --git a/promptolution/optimizers/base_optimizer.py b/promptolution/optimizers/base_optimizer.py
index ded87e5..7264f6f 100644
--- a/promptolution/optimizers/base_optimizer.py
+++ b/promptolution/optimizers/base_optimizer.py
@@ -1,6 +1,5 @@
"""Base module for optimizers in the promptolution library."""
-
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, List, Literal, Optional
@@ -12,6 +11,7 @@
from promptolution.utils.callbacks import BaseCallback
from promptolution.utils.logging import get_logger
+from promptolution.utils.prompt import Prompt
logger = get_logger(__name__)
@@ -49,7 +49,7 @@ def __init__(
config (ExperimentConfig, optional): Configuration for the optimizer, overriding defaults.
"""
# Set up optimizer state
- self.prompts: List[str] = initial_prompts or []
+ self.prompts: List[Prompt] = [Prompt(p) for p in initial_prompts] if initial_prompts else []
self.task = task
self.callbacks: List["BaseCallback"] = callbacks or []
self.predictor = predictor
@@ -60,7 +60,7 @@ def __init__(
self.config = config
- def optimize(self, n_steps: int) -> List[str]:
+ def optimize(self, n_steps: int) -> List[Prompt]:
"""Perform the optimization process.
This method should be implemented by concrete optimizer classes to define
@@ -82,8 +82,7 @@ def optimize(self, n_steps: int) -> List[str]:
self.prompts = self._step()
except Exception as e:
# exit training loop and gracefully fail
- logger.error(f"β Error during optimization step: {e}")
- logger.error("β οΈ Exiting optimization loop.")
+ logger.error("β Error during optimization step! β οΈ Exiting optimization loop.", exc_info=e)
break
# Callbacks at the end of each step
@@ -105,7 +104,7 @@ def _pre_optimization_loop(self) -> None:
pass
@abstractmethod
- def _step(self) -> List[str]:
+ def _step(self) -> List[Prompt]:
"""Perform a single optimization step.
This method should be implemented by concrete optimizer classes to define
@@ -129,3 +128,15 @@ def _on_train_end(self) -> None:
"""Call all registered callbacks at the end of the entire optimization process."""
for callback in self.callbacks:
callback.on_train_end(self)
+
+ def _initialize_meta_template(self, template: str) -> str:
+ task_description = getattr(self.task, "task_description", None)
+ extraction_description = getattr(self.predictor, "extraction_description", None)
+ if self.config is not None and getattr(self.config, "task_description", None) is not None:
+ task_description = self.config.task_description
+ if task_description is None:
+ logger.warning("Task description is not provided. Please make sure to include relevant task details.")
+ task_description = ""
+ if extraction_description is not None:
+ task_description += "\n" + extraction_description
+ return template.replace("", task_description)
diff --git a/promptolution/optimizers/capo.py b/promptolution/optimizers/capo.py
index bcfa275..3c5955a 100644
--- a/promptolution/optimizers/capo.py
+++ b/promptolution/optimizers/capo.py
@@ -1,15 +1,12 @@
"""Implementation of the CAPO (Cost-Aware Prompt Optimization) algorithm."""
import random
-from itertools import compress
import numpy as np
import pandas as pd
from typing import TYPE_CHECKING, Any, List, Optional, Tuple
-from promptolution.utils.formatting import extract_from_tag
-
if TYPE_CHECKING: # pragma: no cover
from promptolution.utils.callbacks import BaseCallback
from promptolution.llms.base_llm import BaseLLM
@@ -19,52 +16,16 @@
from promptolution.utils.test_statistics import TestStatistics
from promptolution.optimizers.base_optimizer import BaseOptimizer
-from promptolution.optimizers.templates import (
- CAPO_CROSSOVER_TEMPLATE,
- CAPO_DOWNSTREAM_TEMPLATE,
- CAPO_FEWSHOT_TEMPLATE,
- CAPO_MUTATION_TEMPLATE,
-)
+from promptolution.utils.formatting import extract_from_tag
from promptolution.utils.logging import get_logger
+from promptolution.utils.prompt import Prompt, sort_prompts_by_scores
+from promptolution.utils.templates import CAPO_CROSSOVER_TEMPLATE, CAPO_FEWSHOT_TEMPLATE, CAPO_MUTATION_TEMPLATE
from promptolution.utils.test_statistics import get_test_statistic_func
from promptolution.utils.token_counter import get_token_counter
logger = get_logger(__name__)
-class CAPOPrompt:
- """Represents a prompt consisting of an instruction and few-shot examples."""
-
- def __init__(self, instruction_text: str, few_shots: List[str]) -> None:
- """Initializes the Prompt with an instruction and associated examples.
-
- Args:
- instruction_text (str): The instruction or prompt text.
- few_shots (List[str]): List of examples as string.
- """
- self.instruction_text = instruction_text.strip()
- self.few_shots = few_shots
-
- def construct_prompt(self) -> str:
- """Constructs the full prompt string by replacing placeholders in the template with the instruction and formatted examples.
-
- Returns:
- str: The constructed prompt string.
- """
- few_shot_str = "\n\n".join(self.few_shots).strip()
- prompt = (
- CAPO_DOWNSTREAM_TEMPLATE.replace("", self.instruction_text)
- .replace("", few_shot_str)
- .replace("\n\n\n\n", "\n\n") # replace extra newlines if no few shots are provided
- .strip()
- )
- return prompt
-
- def __str__(self) -> str:
- """Returns the string representation of the prompt."""
- return self.construct_prompt()
-
-
class CAPO(BaseOptimizer):
"""CAPO: Cost-Aware Prompt Optimization.
@@ -80,6 +41,8 @@ def __init__(
task: "BaseTask",
meta_llm: "BaseLLM",
initial_prompts: Optional[List[str]] = None,
+ crossover_template: Optional[str] = None,
+ mutation_template: Optional[str] = None,
crossovers_per_iter: int = 4,
upper_shots: int = 5,
max_n_blocks_eval: int = 10,
@@ -89,8 +52,6 @@ def __init__(
check_fs_accuracy: bool = True,
create_fs_reasoning: bool = True,
df_few_shots: Optional[pd.DataFrame] = None,
- crossover_template: Optional[str] = None,
- mutation_template: Optional[str] = None,
callbacks: Optional[List["BaseCallback"]] = None,
config: Optional["ExperimentConfig"] = None,
) -> None:
@@ -101,6 +62,8 @@ def __init__(
task (BaseTask): The task instance containing dataset and description.
meta_llm (BaseLLM): The meta language model for crossover/mutation.
initial_prompts (List[str]): Initial prompt instructions.
+ crossover_template (str, optional): Template for crossover instructions.
+ mutation_template (str, optional): Template for mutation instructions.
crossovers_per_iter (int): Number of crossover operations per iteration.
upper_shots (int): Maximum number of few-shot examples per prompt.
p_few_shot_reasoning (float): Probability of generating llm-reasoning for few-shot examples, instead of simply using input-output pairs.
@@ -113,17 +76,12 @@ def __init__(
create_fs_reasoning (bool): Whether to create reasoning for few-shot examples using the downstream model,
instead of simply using input-output pairs from the few shots DataFrame. Default is True.
df_few_shots (pd.DataFrame): DataFrame containing few-shot examples. If None, will pop 10% of datapoints from task.
- crossover_template (str, optional): Template for crossover instructions.
- mutation_template (str, optional): Template for mutation instructions.
callbacks (List[Callable], optional): Callbacks for optimizer events.
config (ExperimentConfig, optional): Configuration for the optimizer.
"""
self.meta_llm = meta_llm
self.downstream_llm = predictor.llm
- self.crossover_template = crossover_template or CAPO_CROSSOVER_TEMPLATE
- self.mutation_template = mutation_template or CAPO_MUTATION_TEMPLATE
-
self.crossovers_per_iter = crossovers_per_iter
self.upper_shots = upper_shots
self.max_n_blocks_eval = max_n_blocks_eval
@@ -136,8 +94,11 @@ def __init__(
self.check_fs_accuracy = check_fs_accuracy
self.create_fs_reasoning = create_fs_reasoning
- self.scores: List[float] = []
super().__init__(predictor, task, initial_prompts, callbacks, config)
+
+ self.crossover_template = self._initialize_meta_template(crossover_template or CAPO_CROSSOVER_TEMPLATE)
+ self.mutation_template = self._initialize_meta_template(mutation_template or CAPO_MUTATION_TEMPLATE)
+
self.df_few_shots = df_few_shots if df_few_shots is not None else task.pop_datapoints(frac=0.1)
if self.max_n_blocks_eval > self.task.n_blocks:
logger.warning(
@@ -145,6 +106,11 @@ def __init__(
f" Setting max_n_blocks_eval to {self.task.n_blocks}."
)
self.max_n_blocks_eval = self.task.n_blocks
+ if "block" not in self.task.eval_strategy:
+ logger.warning(
+ f"βΉοΈ CAPO requires 'block' in the eval_strategy, but got {self.task.eval_strategy}. Setting eval_strategy to 'sequential_block'."
+ )
+ self.task.eval_strategy = "sequential_block"
self.population_size = len(self.prompts)
if hasattr(self.predictor, "begin_marker") and hasattr(self.predictor, "end_marker"):
@@ -154,7 +120,7 @@ def __init__(
self.target_begin_marker = ""
self.target_end_marker = ""
- def _initialize_population(self, initial_prompts: List[str]) -> List[CAPOPrompt]:
+ def _initialize_population(self, initial_prompts: List[Prompt]) -> List[Prompt]:
"""Initializes the population of Prompt objects from initial instructions.
Args:
@@ -164,10 +130,10 @@ def _initialize_population(self, initial_prompts: List[str]) -> List[CAPOPrompt]
List[Prompt]: Initialized population of prompts with few-shot examples.
"""
population = []
- for instruction_text in initial_prompts:
+ for prompt in initial_prompts:
num_examples = random.randint(0, self.upper_shots)
- few_shots = self._create_few_shot_examples(instruction_text, num_examples)
- population.append(CAPOPrompt(instruction_text, few_shots))
+ few_shots = self._create_few_shot_examples(prompt.instruction, num_examples)
+ population.append(Prompt(prompt.instruction, few_shots))
return population
@@ -202,18 +168,18 @@ def _create_few_shot_examples(self, instruction: str, num_examples: int) -> List
# Check which predictions are correct and get a single one per example
for j in range(num_examples):
# Process and clean up the generated sequences
- seqs[j] = seqs[j].replace(sample_inputs[j], "").strip()
+ seqs[j] = seqs[j].replace(sample_inputs[j], "", 1).strip()
# Check if the prediction is correct and add reasoning if so
if preds[j] == sample_targets[j] or not self.check_fs_accuracy:
few_shots[j] = CAPO_FEWSHOT_TEMPLATE.replace("", sample_inputs[j]).replace("