Skip to content

Commit 3ed4816

Browse files
committed
documentation
1 parent 80b3857 commit 3ed4816

File tree

17 files changed

+642
-4
lines changed

17 files changed

+642
-4
lines changed

promptolution/callbacks.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,15 @@ def on_train_end(self, logs=None):
1616

1717

1818
class LoggerCallback(Callback):
19+
"""
20+
Callback for logging optimization progress.
21+
22+
This callback logs information about each step, epoch, and the end of training.
23+
24+
Attributes:
25+
logger: The logger object to use for logging.
26+
step (int): The current step number.
27+
"""
1928
def __init__(self, logger):
2029
# TODO check what's up with logging levels
2130
self.logger = logger
@@ -36,6 +45,15 @@ def on_train_end(self, logs=None):
3645

3746

3847
class CSVCallback(Callback):
48+
"""
49+
Callback for saving optimization progress to a CSV file.
50+
51+
This callback saves prompts and scores at each step to a CSV file.
52+
53+
Attributes:
54+
path (str): The path to the CSV file.
55+
step (int): The current step number.
56+
"""
3957
def __init__(self, path):
4058
# if dir does not exist
4159
if not os.path.exists(os.path.dirname(path)):
@@ -62,6 +80,15 @@ def on_train_end(self, logs=None):
6280

6381

6482
class BestPromptCallback(Callback):
83+
"""
84+
Callback for tracking the best prompt during optimization.
85+
86+
This callback keeps track of the prompt with the highest score.
87+
88+
Attributes:
89+
best_prompt (str): The prompt with the highest score so far.
90+
best_score (float): The highest score achieved so far.
91+
"""
6592
def __init__(self):
6693
self.best_prompt = ""
6794
self.best_score = -99999
@@ -76,6 +103,14 @@ def get_best_prompt(self):
76103

77104

78105
class ProgressBarCallback(Callback):
106+
"""
107+
Callback for displaying a progress bar during optimization.
108+
109+
This callback uses tqdm to display a progress bar that updates at each step.
110+
111+
Attributes:
112+
pbar (tqdm): The tqdm progress bar object.
113+
"""
79114
def __init__(self, total_steps):
80115
self.pbar = tqdm(total=total_steps)
81116

promptolution/config.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,27 @@
44

55
@dataclass
66
class Config:
7+
"""
8+
Configuration class for the promptolution library.
9+
10+
This class handles loading and parsing of configuration settings,
11+
either from a config file or from keyword arguments.
12+
13+
Attributes:
14+
task_name (str): Name of the task.
15+
ds_path (str): Path to the dataset.
16+
n_steps (int): Number of optimization steps.
17+
optimizer (str): Name of the optimizer to use.
18+
meta_prompt_path (str): Path to the meta prompt file.
19+
meta_llms (str): Name of the meta language model.
20+
downstream_llm (str): Name of the downstream language model.
21+
evaluation_llm (str): Name of the evaluation language model.
22+
init_pop_size (int): Initial population size. Defaults to 10.
23+
logging_dir (str): Path used for logging output; note the default is a CSV file path, not a directory. Defaults to "logs/run.csv".
24+
experiment_name (str): Name of the experiment. Defaults to "experiment".
25+
include_task_desc (bool): Whether to include task description. Defaults to False.
26+
random_seed (int): Random seed for reproducibility. Defaults to 42.
27+
"""
728
task_name: str
829
ds_path: str
930
n_steps: int

promptolution/llms/__init__.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,25 @@
44

55

66
def get_llm(model_id: str, *args, **kwargs):
7+
"""
8+
Factory function to create and return a language model instance based on the provided model_id.
9+
10+
This function supports three types of language models:
11+
1. DummyLLM: A mock LLM for testing purposes.
12+
2. LocalLLM: For running models locally (identified by 'local' in the model_id).
13+
3. APILLM: For API-based models (default if not matching other types).
14+
15+
Args:
16+
model_id (str): Identifier for the model to use. Special cases:
17+
- "dummy" for DummyLLM
18+
- Any string containing "local" (e.g. "local-{model_name}") for LocalLLM
19+
- Any other string for APILLM
20+
*args: Variable length argument list passed to the LLM constructor.
21+
**kwargs: Arbitrary keyword arguments passed to the LLM constructor.
22+
23+
Returns:
24+
An instance of DummyLLM, LocalLLM, or APILLM based on the model_id.
25+
"""
726
if model_id == "dummy":
827
return DummyLLM(*args, **kwargs)
928
if "local" in model_id:

promptolution/llms/api_llm.py

Lines changed: 66 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
import openai
55
from logging import INFO, Logger
66

7+
from typing import List
8+
79
from langchain_anthropic import ChatAnthropic
810
from langchain_community.chat_models.deepinfra import ChatDeepInfraException
911
from langchain_core.messages import HumanMessage
@@ -17,6 +19,20 @@
1719

1820

1921
async def invoke_model(prompt, model, semaphore):
22+
"""
23+
Asynchronously invoke a language model with retry logic.
24+
25+
Args:
26+
prompt (str): The input prompt for the model.
27+
model: The language model to invoke.
28+
semaphore (asyncio.Semaphore): Semaphore to limit concurrent calls.
29+
30+
Returns:
31+
str: The model's response content.
32+
33+
Raises:
34+
ChatDeepInfraException: If all retry attempts fail.
35+
"""
2036
async with semaphore:
2137
max_retries = 100
2238
delay = 3
@@ -33,7 +49,30 @@ async def invoke_model(prompt, model, semaphore):
3349

3450

3551
class APILLM:
52+
"""
53+
A class to interface with various language models through their respective APIs.
54+
55+
This class supports Claude (Anthropic), GPT (OpenAI), and LLaMA (DeepInfra) models.
56+
It handles API key management, model initialization, and provides methods for
57+
both synchronous and asynchronous inference.
58+
59+
Attributes:
60+
model: The initialized language model instance.
61+
62+
Methods:
63+
get_response: Synchronously get responses for a list of prompts.
64+
_get_response: Asynchronously get responses for a list of prompts.
65+
"""
3666
def __init__(self, model_id: str):
67+
"""
68+
Initialize the APILLM with a specific model.
69+
70+
Args:
71+
model_id (str): Identifier for the model to use.
72+
73+
Raises:
74+
ValueError: If an unknown model identifier is provided.
75+
"""
3776
if "claude" in model_id:
3877
ANTHROPIC_API_KEY = open("anthropictoken.txt", "r").read()
3978
self.model = ChatAnthropic(model=model_id, api_key=ANTHROPIC_API_KEY)
@@ -46,7 +85,21 @@ def __init__(self, model_id: str):
4685
else:
4786
raise ValueError(f"Unknown model: {model_id}")
4887

49-
def get_response(self, prompts: list[str]) -> list[str]:
88+
def get_response(self, prompts: List[str]) -> List[str]:
89+
"""
90+
Synchronously get responses for a list of prompts.
91+
92+
This method includes retry logic for handling connection errors and rate limits.
93+
94+
Args:
95+
prompts (List[str]): List of input prompts.
96+
97+
Returns:
98+
List[str]: List of model responses.
99+
100+
Raises:
101+
requests.exceptions.ConnectionError: If max retries are exceeded.
102+
"""
50103
max_retries = 100
51104
delay = 3
52105
attempts = 0
@@ -74,6 +127,18 @@ def get_response(self, prompts: list[str]) -> list[str]:
74127
async def _get_response(
75128
self, prompts: list[str], max_concurrent_calls=200
76129
) -> list[str]: # TODO change name of method
130+
"""
131+
Asynchronously get responses for a list of prompts.
132+
133+
This method uses a semaphore to limit the number of concurrent API calls.
134+
135+
Args:
136+
prompts (list[str]): List of input prompts.
137+
max_concurrent_calls (int): Maximum number of concurrent API calls allowed.
138+
139+
Returns:
140+
list[str]: List of model responses.
141+
"""
77142
semaphore = asyncio.Semaphore(max_concurrent_calls) # Limit the number of concurrent calls
78143
tasks = []
79144

promptolution/llms/base_llm.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,19 +5,60 @@
55

66

77
class BaseLLM(ABC):
8+
"""
9+
Abstract base class for Language Models in the promptolution library.
10+
11+
This class defines the interface that all concrete LLM implementations should follow.
12+
13+
Methods:
14+
get_response: An abstract method that should be implemented by subclasses
15+
to generate responses for given prompts.
16+
"""
817
def __init__(self, *args, **kwargs):
918
pass
1019

1120
@abstractmethod
1221
def get_response(self, prompts: List[str]) -> List[str]:
22+
"""
23+
Generate responses for the given prompts.
24+
25+
This method should be implemented by subclasses to define how
26+
the LLM generates responses.
27+
28+
Args:
29+
prompts (List[str]): A list of input prompts.
30+
31+
Returns:
32+
List[str]: A list of generated responses corresponding to the input prompts.
33+
"""
1334
pass
1435

1536

1637
class DummyLLM(BaseLLM):
38+
"""
39+
A dummy implementation of the BaseLLM for testing purposes.
40+
41+
This class generates random responses for given prompts, simulating
42+
the behavior of a language model without actually performing any
43+
complex natural language processing.
44+
"""
1745
def __init__(self, *args, **kwargs):
1846
pass
1947

2048
def get_response(self, prompts: str) -> str:
49+
"""
50+
Generate random responses for the given prompts.
51+
52+
This method creates silly, random responses enclosed in <prompt> tags.
53+
It's designed for testing and demonstration purposes.
54+
55+
Args:
56+
prompts (str or List[str]): Input prompt(s). If a single string is provided,
57+
it's converted to a list containing that string.
58+
59+
Returns:
60+
List[str]: A list of randomly generated responses, one for each input prompt.
61+
"""
2162
if isinstance(prompts, str):
2263
prompts = [prompts]
2364
results = []

promptolution/llms/deepinfra.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from __future__ import annotations
2-
2+
# TODO delete?
33
from typing import (
44
Any,
55
AsyncIterator,

promptolution/llms/local_llm.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,30 @@
99

1010

1111
class LocalLLM:
12+
"""
13+
A class for running language models locally using the Hugging Face Transformers library.
14+
15+
This class sets up a text generation pipeline with specified model parameters
16+
and provides a method to generate responses for given prompts.
17+
18+
Attributes:
19+
pipeline (transformers.Pipeline): The text generation pipeline.
20+
21+
Methods:
22+
get_response: Generate responses for a list of prompts.
23+
"""
1224
def __init__(self, model_id: str, batch_size=8):
25+
"""
26+
Initialize the LocalLLM with a specific model.
27+
28+
Args:
29+
model_id (str): The identifier of the model to use (e.g., "gpt2", "facebook/opt-1.3b").
30+
batch_size (int, optional): The batch size for text generation. Defaults to 8.
31+
32+
Note:
33+
This method sets up a text generation pipeline with bfloat16 precision,
34+
automatic device mapping, and specific generation parameters.
35+
"""
1336
self.pipeline = transformers.pipeline(
1437
"text-generation",
1538
model=model_id,
@@ -24,6 +47,19 @@ def __init__(self, model_id: str, batch_size=8):
2447
self.pipeline.tokenizer.padding_side = "left"
2548

2649
def get_response(self, prompts: list[str]):
50+
"""
51+
Generate responses for a list of prompts using the local language model.
52+
53+
Args:
54+
prompts (list[str]): A list of input prompts.
55+
56+
Returns:
57+
list[str]: A list of generated responses corresponding to the input prompts.
58+
59+
Note:
60+
This method uses torch.no_grad() for inference to reduce memory usage.
61+
It handles both single and batch inputs, ensuring consistent output format.
62+
"""
2763
with torch.no_grad():
2864
response = self.pipeline(prompts, pad_token_id=self.pipeline.tokenizer.eos_token_id)
2965

@@ -32,3 +68,10 @@ def get_response(self, prompts: list[str]):
3268

3369
response = [r["generated_text"] for r in response]
3470
return response
71+
72+
def __del__(self):
73+
try:
74+
del self.pipeline
75+
torch.cuda.empty_cache()
76+
except Exception as e:
77+
logger.warning(f"Error during LocalLLM cleanup: {e}")

promptolution/optimizers/__init__.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,26 @@
44

55

66
def get_optimizer(config, *args, **kwargs):
7+
"""
8+
Factory function to create and return an optimizer instance based on the provided configuration.
9+
10+
This function selects and instantiates the appropriate optimizer class based on the
11+
'optimizer' field in the config object. It supports three types of optimizers:
12+
'dummy', 'evopromptde', and 'evopromptga'.
13+
14+
Args:
15+
config: A configuration object that must have an 'optimizer' attribute.
16+
For 'evopromptde', it should also have a 'donor_random' attribute.
17+
For 'evopromptga', it should also have a 'selection_mode' attribute.
18+
*args: Variable length argument list passed to the optimizer constructor.
19+
**kwargs: Arbitrary keyword arguments passed to the optimizer constructor.
20+
21+
Returns:
22+
An instance of the specified optimizer class.
23+
24+
Raises:
25+
ValueError: If an unknown optimizer type is specified in the config.
26+
"""
727
if config.optimizer == "dummy":
828
return DummyOptimizer(*args, **kwargs)
929
if config.optimizer == "evopromptde":

0 commit comments

Comments
 (0)