diff --git a/requirements.txt b/requirements.txt index bedbc67..91cf1fa 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,6 +8,11 @@ yapf==0.40.2 matplotlib==3.9.2 pydantic==2.9.2 scikit-learn==1.5.2 +termcolor==3.1.0 +tiktoken==0.12.0 +diskcache==5.6.3 +azure-identity==1.25.1 +flaml==2.3.6 gdown==5.2.0 open_clip_torch==2.29.0 diff --git a/train_methods/data.py b/train_methods/data.py index b2ba3e1..af7f981 100644 --- a/train_methods/data.py +++ b/train_methods/data.py @@ -755,6 +755,98 @@ def __getitem__(self, i): example["pixel_values"] = torch.from_numpy(image).permute(2, 0, 1) return example + +class COGFDDataset(Dataset): + def __init__( + self, + data_dir: str, + tokenizer: CLIPTokenizer, + size: int=512, + center_crop=False, + use_pooler=False, + task_info=None, + concept_combination=None, + labels=None + ): + self.use_pooler = use_pooler + self.size = size + self.center_crop = center_crop + self.tokenizer = tokenizer + + if task_info is None or len(task_info) != 2: + raise ValueError("task_info must be a list/tuple of length 2 containing [concept, theme]") + + if concept_combination is None or len(concept_combination) == 0: + raise ValueError("concept_combination cannot be None or empty") + + if labels is None or len(labels) == 0: + raise ValueError("labels cannot be None or empty") + + if len(concept_combination) != len(labels): + raise ValueError(f"Length mismatch: concept_combination ({len(concept_combination)}) != labels ({len(labels)})") + + self.instance_images_path = [] + self.instance_prompt = [] + + p = Path(data_dir) + if not p.exists(): + raise ValueError(f"Instance {p} images root doesn't exists.") + + image_paths = list(p.iterdir()) + if len(image_paths) == 0: + raise ValueError(f"No images found in {p}") + + self.instance_images_path += image_paths + + self.prompts = concept_combination + self.labels = labels + + self.num_instance_images = len(self.instance_images_path) + self._length = len(self.prompts) + + self.image_transforms = 
transforms.Compose([ + transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR), + transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size), + transforms.ToTensor(), + transforms.Normalize([0.5], [0.5]), + ]) + + def __len__(self): + return self._length + + def __getitem__(self, index) -> dict: + if index >= self._length: + raise IndexError(f"Index {index} out of range for dataset of length {self._length}") + + example = {} + instance_image = Image.open(self.instance_images_path[index % self.num_instance_images]) + concept = self.prompts[index % self._length] + label = self.labels[index % self._length] + + if not instance_image.mode == "RGB": + instance_image = instance_image.convert("RGB") + example["concept"] = concept + example["label"] = label + example["instance_images"] = self.image_transforms(instance_image) + + example["prompt_ids"] = self.tokenizer( + concept, + truncation=True, + padding="max_length", + max_length=self.tokenizer.model_max_length, + return_tensors="pt", + ).input_ids + + example["attention_mask"] = self.tokenizer( + concept, + truncation=True, + padding="max_length", + max_length=self.tokenizer.model_max_length, + return_tensors="pt", + ).attention_mask + + return example + class MCEDataset(Dataset): def __init__( self, diff --git a/train_methods/legacy_autogen/cache.py b/train_methods/legacy_autogen/cache.py new file mode 100644 index 0000000..1821219 --- /dev/null +++ b/train_methods/legacy_autogen/cache.py @@ -0,0 +1,105 @@ +from pathlib import Path +from types import TracebackType +from typing import Any, Protocol, Self + +import diskcache + +class AbstractCache(Protocol): + + def get(self, key: str, default: Any | None = None) -> Any | None: + ... + + def set(self, key: str, value: Any) -> None: + ... + + def close(self) -> None: + ... + + def __enter__(self) -> Self: + ... 
+ + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + ) -> None: + ... + +class DiskCache(AbstractCache): + def __init__(self, seed: str | int): + self.cache = diskcache.Cache(seed) + + def get(self, key: str, default: Any | None = None) -> Any | None: + return self.cache.get(key, default) + + def set(self, key: str, value: Any) -> None: + self.cache.set(key, value) + + def close(self) -> None: + self.cache.close() + + def __enter__(self) -> Self: + return self + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + ) -> None: + self.close() + +class CacheFactory: + @staticmethod + def cache_factory( + seed: str | int, + cache_path_root: str = ".cache", + ) -> AbstractCache: + path = Path(cache_path_root, str(seed)) + return DiskCache(Path(".", path)) + +class Cache(AbstractCache): + ALLOWED_CONFIG_KEYS = [ + "cache_seed", + "cache_path_root", + ] + + @staticmethod + def disk(cache_seed: str | int = 42, cache_path_root: str = ".cache") -> "Cache": + return Cache({"cache_seed": cache_seed, "cache_path_root": cache_path_root}) + + def __init__(self, config: dict[str, Any]): + self.config = config + # Ensure that the seed is always treated as a string before being passed to any cache factory or stored. 
+ self.config["cache_seed"] = str(self.config.get("cache_seed", 42)) + + # validate config + for key in self.config.keys(): + if key not in self.ALLOWED_CONFIG_KEYS: + raise ValueError(f"Invalid config key: {key}") + # create cache instance + self.cache = CacheFactory.cache_factory( + seed=self.config["cache_seed"], + cache_path_root=self.config.get("cache_path_root", ""), + ) + + def __enter__(self) -> "Cache": + return self.cache.__enter__() + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + ) -> None: + return self.cache.__exit__(exc_type, exc_value, traceback) + + def get(self, key: str, default: Any | None = None) -> Any | None: + return self.cache.get(key, default) + + def set(self, key: str, value: Any) -> None: + self.cache.set(key, value) + + def close(self) -> None: + self.cache.close() diff --git a/train_methods/legacy_autogen/chat.py b/train_methods/legacy_autogen/chat.py new file mode 100644 index 0000000..323c637 --- /dev/null +++ b/train_methods/legacy_autogen/chat.py @@ -0,0 +1,286 @@ +import asyncio +import datetime +from collections import defaultdict +from dataclasses import dataclass +from functools import partial +from typing import Any, Callable + +from termcolor import colored + +from train_methods.legacy_autogen.stream import IOStream + +def consolidate_chat_info(chat_info, uniform_sender=None) -> None: + if isinstance(chat_info, dict): + chat_info = [chat_info] + for c in chat_info: + if uniform_sender is None: + assert "sender" in c, "sender must be provided." + sender = c["sender"] + else: + sender = uniform_sender + assert "recipient" in c, "recipient must be provided." 
+ summary_method = c.get("summary_method") + assert ( + summary_method is None + or isinstance(summary_method, Callable) + or summary_method in ("last_msg", "reflection_with_llm") + ), "summary_method must be a string chosen from 'reflection_with_llm' or 'last_msg' or a callable, or None." + if summary_method == "reflection_with_llm": + assert ( + sender.client is not None or c["recipient"].client is not None + ), "llm client must be set in either the recipient or sender when summary_method is reflection_with_llm." + + +@dataclass +class ChatResult: + + chat_id: int = None + chat_history: list[dict[str, Any]] = None + summary: str = None + cost: dict[str, dict] = None # keys: "usage_including_cached_inference", "usage_excluding_cached_inference" + """The cost of the chat. + The value for each usage type is a dictionary containing cost information for that specific type. + - "usage_including_cached_inference": Cost information on the total usage, including the tokens in cached inference. + - "usage_excluding_cached_inference": Cost information on the usage of tokens, excluding the tokens in cache. No larger than "usage_including_cached_inference". + """ + human_input: list[str] = None + """A list of human input solicited during the chat.""" + + +def _validate_recipients(chat_queue: list[dict[str, Any]]) -> None: + """ + Validate recipients exist and warn repetitive recipients. + """ + receipts_set = set() + for chat_info in chat_queue: + assert "recipient" in chat_info, "recipient must be provided."
+ receipts_set.add(chat_info["recipient"]) + + +def __create_async_prerequisites(chat_queue: list[dict[str, Any]]) -> list[tuple[int, int]]: + """ + Create list of tuple[int, int] (chat_id, prerequisite_chat_id) + """ + prerequisites = [] + for chat_info in chat_queue: + if "chat_id" not in chat_info: + raise ValueError("Each chat must have a unique id for async multi-chat execution.") + chat_id = chat_info["chat_id"] + pre_chats = chat_info.get("prerequisites", []) + for pre_chat_id in pre_chats: + if not isinstance(pre_chat_id, int): + raise ValueError("tuple[int, int] chat id is not int.") + prerequisites.append((chat_id, pre_chat_id)) + return prerequisites + + +def __find_async_chat_order(chat_ids: set[int], prerequisites: list[tuple[int, int]]) -> list[int]: + """Find chat order for async execution based on the prerequisite chats + + args: + chat_ids: set of chat ids + prerequisites: list of tuple[int, int] (chat_id, prerequisite_chat_id) + + returns: + list: a list of chat_id in order.
+ """ + edges = defaultdict(set) + indegree = defaultdict(int) + for pair in prerequisites: + chat, pre = pair[0], pair[1] + if chat not in edges[pre]: + indegree[chat] += 1 + edges[pre].add(chat) + bfs = [i for i in chat_ids if i not in indegree] + chat_order = [] + steps = len(indegree) + for _ in range(steps + 1): + if not bfs: + break + chat_order.extend(bfs) + nxt = [] + for node in bfs: + if node in edges: + for course in edges[node]: + indegree[course] -= 1 + if indegree[course] == 0: + nxt.append(course) + indegree.pop(course) + edges.pop(node) + bfs = nxt + + if indegree: + return [] + return chat_order + + +def _post_process_carryover_item(carryover_item): + if isinstance(carryover_item, str): + return carryover_item + elif isinstance(carryover_item, dict) and "content" in carryover_item: + return str(carryover_item["content"]) + else: + return str(carryover_item) + + +def __post_carryover_processing(chat_info: dict[str, Any]) -> None: + iostream = IOStream.get_default() + + print_carryover = ( + ("\n").join([_post_process_carryover_item(t) for t in chat_info["carryover"]]) + if isinstance(chat_info["carryover"], list) + else chat_info["carryover"] + ) + message = chat_info.get("message") + if isinstance(message, str): + print_message = message + elif callable(message): + print_message = "Callable: " + message.__name__ + elif isinstance(message, dict): + print_message = "dict: " + str(message) + elif message is None: + print_message = "None" + iostream.print(colored("\n" + "*" * 80, "blue"), flush=True, sep="") + iostream.print( + colored( + "Starting a new chat....", + "blue", + ), + flush=True, + ) + if chat_info.get("verbose", False): + iostream.print(colored("Message:\n" + print_message, "blue"), flush=True) + iostream.print(colored("Carryover:\n" + print_carryover, "blue"), flush=True) + iostream.print(colored("\n" + "*" * 80, "blue"), flush=True, sep="") + + +def initiate_chats(chat_queue: list[dict[str, Any]]) -> list[ChatResult]: + """Initiate a 
list of chats. + Args: + chat_queue (list[dict]): A list of dictionaries containing the information about the chats. + + Each dictionary should contain the input arguments for + [`ConversableAgent.initiate_chat`](/docs/reference/agentchat/conversable_agent#initiate_chat). + For example: + - `"sender"` - the sender agent. + - `"recipient"` - the recipient agent. + - `"clear_history"` (bool) - whether to clear the chat history with the agent. + Default is True. + - `"silent"` (bool or None) - (Experimental) whether to print the messages in this + conversation. Default is False. + - `"cache"` (Cache or None) - the cache client to use for this conversation. + Default is None. + - `"max_turns"` (int or None) - maximum number of turns for the chat. If None, the chat + will continue until a termination condition is met. Default is None. + - `"summary_method"` (str or callable) - a string or callable specifying the method to get + a summary from the chat. Default is DEFAULT_summary_method, i.e., "last_msg". + - `"summary_args"` (dict) - a dictionary of arguments to be passed to the summary_method. + Default is {}. + - `"message"` (str, callable or None) - if None, input() will be called to get the + initial message. + - `**context` - additional context information to be passed to the chat. + - `"carryover"` - It can be used to specify the carryover information to be passed + to this chat. If provided, we will combine this carryover with the "message" content when + generating the initial chat message in `generate_init_message`. + - `"finished_chat_indexes_to_exclude_from_carryover"` - It can be used by specifying a list of indexes of the finished_chats list, + from which to exclude the summaries for carryover. If 'finished_chat_indexes_to_exclude_from_carryover' is not provided or an empty list, + then summary from all the finished chats will be taken. + Returns: + (list): a list of ChatResult objects corresponding to the finished chats in the chat_queue. 
+ """ + + consolidate_chat_info(chat_queue) + _validate_recipients(chat_queue) + current_chat_queue = chat_queue.copy() + finished_chats = [] + while current_chat_queue: + chat_info = current_chat_queue.pop(0) + _chat_carryover = chat_info.get("carryover", []) + finished_chat_indexes_to_exclude_from_carryover = chat_info.get( + "finished_chat_indexes_to_exclude_from_carryover", [] + ) + + if isinstance(_chat_carryover, str): + _chat_carryover = [_chat_carryover] + chat_info["carryover"] = _chat_carryover + [ + r.summary for i, r in enumerate(finished_chats) if i not in finished_chat_indexes_to_exclude_from_carryover + ] + __post_carryover_processing(chat_info) + + sender = chat_info["sender"] + chat_res = sender.initiate_chat(**chat_info) + finished_chats.append(chat_res) + return finished_chats + + +def _on_chat_future_done(chat_future: asyncio.Future, chat_id: int): + """ + Update ChatResult when async Task for Chat is completed. + """ + print(f"Update chat {chat_id} result on task completion. System time at {datetime.datetime.now()}.") + chat_result = chat_future.result() + chat_result.chat_id = chat_id + + +async def _dependent_chat_future( + chat_id: int, chat_info: dict[str, Any], prerequisite_chat_futures: dict[int, asyncio.Future] +) -> asyncio.Task: + """ + Create an async Task for each chat. + """ + print(f"Create Task for chat {chat_id}. 
System time at {datetime.datetime.now()}.") + _chat_carryover = chat_info.get("carryover", []) + finished_chat_indexes_to_exclude_from_carryover = chat_info.get( + "finished_chat_indexes_to_exclude_from_carryover", [] + ) + finished_chats = dict() + for chat in prerequisite_chat_futures: + chat_future = prerequisite_chat_futures[chat] + if chat_future.cancelled(): + raise RuntimeError(f"Chat {chat} is cancelled.") + + # wait for prerequisite chat results for the new chat carryover + finished_chats[chat] = await chat_future + + if isinstance(_chat_carryover, str): + _chat_carryover = [_chat_carryover] + data = [ + chat_result.summary + for chat_id, chat_result in finished_chats.items() + if chat_id not in finished_chat_indexes_to_exclude_from_carryover + ] + chat_info["carryover"] = _chat_carryover + data + if not chat_info.get("silent", False): + __post_carryover_processing(chat_info) + + sender = chat_info["sender"] + chat_res_future = asyncio.create_task(sender.a_initiate_chat(**chat_info)) + call_back_with_args = partial(_on_chat_future_done, chat_id=chat_id) + chat_res_future.add_done_callback(call_back_with_args) + print(f"Task for chat {chat_id} created. 
System time at {datetime.datetime.now()}.") + return chat_res_future + + +async def a_initiate_chats(chat_queue: list[dict[str, Any]]) -> dict[int, ChatResult]: + consolidate_chat_info(chat_queue) + _validate_recipients(chat_queue) + chat_book = {chat_info["chat_id"]: chat_info for chat_info in chat_queue} + num_chats = chat_book.keys() + prerequisites = __create_async_prerequisites(chat_queue) + chat_order_by_id = __find_async_chat_order(num_chats, prerequisites) + finished_chat_futures = dict() + for chat_id in chat_order_by_id: + chat_info = chat_book[chat_id] + prerequisite_chat_ids = chat_info.get("prerequisites", []) + pre_chat_futures = dict() + for pre_chat_id in prerequisite_chat_ids: + pre_chat_future = finished_chat_futures[pre_chat_id] + pre_chat_futures[pre_chat_id] = pre_chat_future + current_chat_future = await _dependent_chat_future(chat_id, chat_info, pre_chat_futures) + finished_chat_futures[chat_id] = current_chat_future + await asyncio.gather(*list(finished_chat_futures.values())) + finished_chats = dict() + for chat in finished_chat_futures: + chat_result = finished_chat_futures[chat].result() + finished_chats[chat] = chat_result + return finished_chats diff --git a/train_methods/legacy_autogen/client.py b/train_methods/legacy_autogen/client.py new file mode 100644 index 0000000..1ff9592 --- /dev/null +++ b/train_methods/legacy_autogen/client.py @@ -0,0 +1,921 @@ +import json +import inspect +import time + +from typing import Protocol, Any, Callable + +from openai import APIError, APITimeoutError, AzureOpenAI, OpenAI +from openai.resources import Completions +from openai.types.chat import ChatCompletion +from openai.types.chat.chat_completion import Choice, ChatCompletionMessage +from openai.types.completion import Completion +from openai.types.completion_usage import CompletionUsage +from openai.types.chat.chat_completion_chunk import ( + ChoiceDeltaFunctionCall, + ChoiceDeltaToolCall, + ChoiceDeltaToolCallFunction, +) +from pydantic import 
BaseModel + +from train_methods.legacy_autogen.cache import Cache +from train_methods.legacy_autogen.stream import IOStream + +NON_CACHE_KEY = [ + "api_key", + "base_url", + "api_type", + "api_version", + "azure_ad_token", + "azure_ad_token_provider", + "credentials", + "tool_config", +] + +OAI_PRICE1K = { + # https://openai.com/api/pricing/ + # gpt-4o + "gpt-4o": (0.005, 0.015), + "gpt-4o-2024-05-13": (0.005, 0.015), + "gpt-4o-2024-08-06": (0.0025, 0.01), + # gpt-4-turbo + "gpt-4-turbo-2024-04-09": (0.01, 0.03), + # gpt-4 + "gpt-4": (0.03, 0.06), + "gpt-4-32k": (0.06, 0.12), + # gpt-4o-mini + "gpt-4o-mini": (0.000150, 0.000600), + "gpt-4o-mini-2024-07-18": (0.000150, 0.000600), + # gpt-3.5 turbo + "gpt-3.5-turbo": (0.0005, 0.0015), # default is 0125 + "gpt-3.5-turbo-0125": (0.0005, 0.0015), # 16k + "gpt-3.5-turbo-instruct": (0.0015, 0.002), + # base model + "davinci-002": 0.002, + "babbage-002": 0.0004, + # old model + "gpt-4-0125-preview": (0.01, 0.03), + "gpt-4-1106-preview": (0.01, 0.03), + "gpt-4-1106-vision-preview": (0.01, 0.03), # TODO: support vision pricing of images + "gpt-3.5-turbo-1106": (0.001, 0.002), + "gpt-3.5-turbo-0613": (0.0015, 0.002), + # "gpt-3.5-turbo-16k": (0.003, 0.004), + "gpt-3.5-turbo-16k-0613": (0.003, 0.004), + "gpt-3.5-turbo-0301": (0.0015, 0.002), + "text-ada-001": 0.0004, + "text-babbage-001": 0.0005, + "text-curie-001": 0.002, + "code-cushman-001": 0.024, + "code-davinci-002": 0.1, + "text-davinci-002": 0.02, + "text-davinci-003": 0.02, + "gpt-4-0314": (0.03, 0.06), # deprecate in Sep + "gpt-4-32k-0314": (0.06, 0.12), # deprecate in Sep + "gpt-4-0613": (0.03, 0.06), + "gpt-4-32k-0613": (0.06, 0.12), + "gpt-4-turbo-preview": (0.01, 0.03), + # https://azure.microsoft.com/en-us/pricing/details/cognitive-services/openai-service/#pricing + "gpt-35-turbo": (0.0005, 0.0015), # what's the default? using 0125 here. 
+ "gpt-35-turbo-0125": (0.0005, 0.0015), + "gpt-35-turbo-instruct": (0.0015, 0.002), + "gpt-35-turbo-1106": (0.001, 0.002), + "gpt-35-turbo-0613": (0.0015, 0.002), + "gpt-35-turbo-0301": (0.0015, 0.002), + "gpt-35-turbo-16k": (0.003, 0.004), + "gpt-35-turbo-16k-0613": (0.003, 0.004), +} + + +def get_key(config: dict[str, Any]) -> str: + """Get a unique identifier of a configuration. + + Args: + config (dict): A configuration. + + Returns: + str: A unique identifier which can be used as a key for a dict. + """ + copied = False + for key in NON_CACHE_KEY: + if key in config: + config, copied = config.copy() if not copied else config, True + config.pop(key) + return json.dumps(config, sort_keys=True) + + +class PlaceHolderClient: + def __init__(self, config): + self.config = config + +class ModelClient(Protocol): + """ + A client class must implement the following methods: + - create must return a response object that implements the ModelClientResponseProtocol + - cost must return the cost of the response + - get_usage must return a dict with the following keys: + - prompt_tokens + - completion_tokens + - total_tokens + - cost + - model + + This class is used to create a client that can be used by OpenAIWrapper. + The response returned from create must adhere to the ModelClientResponseProtocol but can be extended however needed. + The message_retrieval method must be implemented to return a list of str or a list of messages from the response. + """ + + RESPONSE_USAGE_KEYS = ["prompt_tokens", "completion_tokens", "total_tokens", "cost", "model"] + + class ModelClientResponseProtocol(Protocol): + class Choice(Protocol): + class Message(Protocol): + content: str | None + + message: Message + + choices: list[Choice] + model: str + + def create(self, params: dict[str, Any]) -> ModelClientResponseProtocol: ...
# pragma: no cover + + def message_retrieval( + self, response: ModelClientResponseProtocol + ) -> list[str] | list[ModelClientResponseProtocol.Choice.Message]: + """ + Retrieve and return a list of strings or a list of Choice.Message from the response. + + NOTE: if a list of Choice.Message is returned, it currently needs to contain the fields of OpenAI's ChatCompletion Message object, + since that is expected for function or tool calling in the rest of the codebase at the moment, unless a custom agent is being used. + """ + ... # pragma: no cover + + def cost(self, response: ModelClientResponseProtocol) -> float: ... # pragma: no cover + + @staticmethod + def get_usage(response: ModelClientResponseProtocol) -> dict: + """Return usage summary of the response using RESPONSE_USAGE_KEYS.""" + ... # pragma: no cover + +class RateLimiter(Protocol): + def sleep(self, *args, **kwargs): ... + +class TimeRateLimiter: + """A class to implement a time-based rate limiter. + + This rate limiter ensures that a certain operation does not exceed a specified frequency. + It can be used to limit the rate of requests sent to a server or the rate of any repeated action. + """ + + def __init__(self, rate: float): + """ + Args: + rate (float): The frequency of the time-based rate limiter (NOT time interval). + """ + self._time_interval_seconds = 1.0 / rate + self._last_time_called = 0.0 + + def sleep(self, *args, **kwargs): + """Synchronously waits until enough time has passed to allow the next operation. + + If the elapsed time since the last operation is less than the required time interval, + this method will block the execution by sleeping for the remaining time.
+ """ + if self._elapsed_time() < self._time_interval_seconds: + time.sleep(self._time_interval_seconds - self._elapsed_time()) + + self._last_time_called = time.perf_counter() + + def _elapsed_time(self): + return time.perf_counter() - self._last_time_called + +class OpenAIClient: + """Follows the Client protocol and wraps the OpenAI client.""" + + def __init__(self, client: OpenAI | AzureOpenAI): + self._oai_client = client + + def message_retrieval( + self, response: ChatCompletion | Completion + ) -> list[str] | list[ChatCompletionMessage]: + """Retrieve the messages from the response.""" + choices = response.choices + if isinstance(response, Completion): + return [choice.text for choice in choices] # type: ignore [union-attr] + + return [ # type: ignore [return-value] + ( + choice.message # type: ignore [union-attr] + if choice.message.function_call is not None or choice.message.tool_calls is not None # type: ignore [union-attr] + else choice.message.content + ) # type: ignore [union-attr] + for choice in choices + ] + + def create(self, params: dict[str, Any]) -> ChatCompletion: + """Create a completion for a given config using openai's client. + + Args: + client: The openai client. + params: The params for the completion. + + Returns: + The completion. + """ + iostream = IOStream.get_default() + + completions: Completions = ( + self._oai_client.chat.completions if "messages" in params else self._oai_client.completions + ) # type: ignore [attr-defined] + # If streaming is enabled and has messages, then iterate over the chunks of the response. 
+ if params.get("stream", False) and "messages" in params: + response_contents = [""] * params.get("n", 1) + finish_reasons = [""] * params.get("n", 1) + completion_tokens = 0 + + # Set the terminal text color to green + iostream.print("\033[32m", end="") + + # Prepare for potential function call + full_function_call: dict[str, Any] | None = None + full_tool_calls: list[dict[str, Any | None]] | None = None + + # Send the chat completion request to OpenAI's API and process the response in chunks + for chunk in completions.create(**params): + if chunk.choices: + for choice in chunk.choices: + content = choice.delta.content + tool_calls_chunks = choice.delta.tool_calls + finish_reasons[choice.index] = choice.finish_reason + + # todo: remove this after function calls are removed from the API + # the code should work regardless of whether function calls are removed or not, but test_chat_functions_stream should fail + # begin block + function_call_chunk = ( + choice.delta.function_call if hasattr(choice.delta, "function_call") else None + ) + # Handle function call + if function_call_chunk: + # Handle function call + if function_call_chunk: + full_function_call, completion_tokens = OpenAIWrapper._update_function_call_from_chunk( + function_call_chunk, full_function_call, completion_tokens + ) + if not content: + continue + # end block + + # Handle tool calls + if tool_calls_chunks: + for tool_calls_chunk in tool_calls_chunks: + # the current tool call to be reconstructed + ix = tool_calls_chunk.index + if full_tool_calls is None: + full_tool_calls = [] + if ix >= len(full_tool_calls): + # in case ix is not sequential + full_tool_calls = full_tool_calls + [None] * (ix - len(full_tool_calls) + 1) + + full_tool_calls[ix], completion_tokens = OpenAIWrapper._update_tool_calls_from_chunk( + tool_calls_chunk, full_tool_calls[ix], completion_tokens + ) + if not content: + continue + + # End handle tool calls + + # If content is present, print it to the terminal and update 
response variables + if content is not None: + iostream.print(content, end="", flush=True) + response_contents[choice.index] += content + completion_tokens += 1 + else: + # iostream.print() + pass + + # Reset the terminal text color + iostream.print("\033[0m\n") + + # Prepare the final ChatCompletion object based on the accumulated data + prompt_tokens = 0 + response = ChatCompletion( + id=chunk.id, + model=chunk.model, + created=chunk.created, + object="chat.completion", + choices=[], + usage=CompletionUsage( + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=prompt_tokens + completion_tokens, + ), + ) + for i in range(len(response_contents)): + choice = Choice( + index=i, + finish_reason=finish_reasons[i], + message=ChatCompletionMessage( + role="assistant", + content=response_contents[i], + function_call=full_function_call, + tool_calls=full_tool_calls, + ), + logprobs=None, + ) + + response.choices.append(choice) + else: + # If streaming is not enabled, send a regular chat completion request + params = params.copy() + params["stream"] = False + response = completions.create(**params) + + return response + + def cost(self, response: ChatCompletion | Completion) -> float: + """Calculate the cost of the response.""" + model = response.model + if model not in OAI_PRICE1K: + # log warning that the model is not found + print( + f'Model {model} is not found. The cost will be 0. In your config_list, add field {{"price" : [prompt_price_per_1k, completion_token_price_per_1k]}} for customized pricing.' 
+ ) + return 0 + + n_input_tokens = response.usage.prompt_tokens if response.usage is not None else 0 # type: ignore [union-attr] + n_output_tokens = response.usage.completion_tokens if response.usage is not None else 0 # type: ignore [union-attr] + if n_output_tokens is None: + n_output_tokens = 0 + tmp_price1K = OAI_PRICE1K[model] + # First value is input token rate, second value is output token rate + if isinstance(tmp_price1K, tuple): + return (tmp_price1K[0] * n_input_tokens + tmp_price1K[1] * n_output_tokens) / 1000 # type: ignore [no-any-return] + return tmp_price1K * (n_input_tokens + n_output_tokens) / 1000 # type: ignore [operator] + + @staticmethod + def get_usage(response: ChatCompletion | Completion) -> dict: + return { + "prompt_tokens": response.usage.prompt_tokens if response.usage is not None else 0, + "completion_tokens": response.usage.completion_tokens if response.usage is not None else 0, + "total_tokens": response.usage.total_tokens if response.usage is not None else 0, + "cost": response.cost if hasattr(response, "cost") else 0, + "model": response.model, + } + +class OpenAIWrapper: + """A wrapper class for openai client.""" + + extra_kwargs = { + "agent", + "cache", + "cache_seed", + "filter_func", + "allow_format_str_template", + "context", + "api_version", + "api_type", + "tags", + "price", + } + + openai_kwargs = set(inspect.getfullargspec(OpenAI.__init__).kwonlyargs) + aopenai_kwargs = set(inspect.getfullargspec(AzureOpenAI.__init__).kwonlyargs) + openai_kwargs = openai_kwargs | aopenai_kwargs + total_usage_summary: dict[str, Any] | None = None + actual_usage_summary: dict[str, Any] | None = None + + def __init__(self, *, config_list: list[dict[str, Any]] | None = None, **base_config: Any): + """ + Args: + config_list: a list of config dicts to override the base_config. + They can contain additional kwargs as allowed in the [create](/docs/reference/oai/client#create) method. + + base_config: base config. 
It can contain both keyword arguments for openai client + and additional kwargs. + When using OpenAI or Azure OpenAI endpoints, please specify a non-empty 'model' either in `base_config` or in each config of `config_list`. + """ + + openai_config, extra_kwargs = self._separate_openai_config(base_config) + # It's OK if "model" is not provided in base_config or config_list + # Because one can provide "model" at `create` time. + + self._clients: list[ModelClient] = [] + self._config_list: list[dict[str, Any]] = [] + self._rate_limiters: list[RateLimiter | None] = [] + + if config_list: + self._initialize_rate_limiters(config_list) + + config_list = [config.copy() for config in config_list] # make a copy before modifying + for config in config_list: + self._register_default_client(config, openai_config) # could modify the config + self._config_list.append( + {**extra_kwargs, **{k: v for k, v in config.items() if k not in self.openai_kwargs}} + ) + else: + self._register_default_client(extra_kwargs, openai_config) + self._config_list = [extra_kwargs] + self.wrapper_id = id(self) + + def _separate_openai_config(self, config: dict[str, Any]) -> tuple[dict[str, Any], dict[str, Any]]: + """Separate the config into openai_config and extra_kwargs.""" + openai_config = {k: v for k, v in config.items() if k in self.openai_kwargs} + extra_kwargs = {k: v for k, v in config.items() if k not in self.openai_kwargs} + return openai_config, extra_kwargs + + def _separate_create_config(self, config: dict[str, Any]) -> tuple[dict[str, Any], dict[str, Any]]: + """Separate the config into create_config and extra_kwargs.""" + create_config = {k: v for k, v in config.items() if k not in self.extra_kwargs} + extra_kwargs = {k: v for k, v in config.items() if k in self.extra_kwargs} + return create_config, extra_kwargs + + def _configure_azure_openai(self, config: dict[str, Any], openai_config: dict[str, Any]) -> None: + openai_config["azure_deployment"] = 
        if openai_config["azure_deployment"] is not None:
            # Azure deployment names cannot contain dots (e.g. "gpt-3.5-turbo" -> "gpt-35-turbo").
            openai_config["azure_deployment"] = openai_config["azure_deployment"].replace(".", "")
        openai_config["azure_endpoint"] = openai_config.get("azure_endpoint", openai_config.pop("base_url", None))

        # Create a default Azure token provider if requested
        if openai_config.get("azure_ad_token_provider") == "DEFAULT":
            import azure.identity

            azure_ad_token_provider_scope = openai_config.get(
                "azure_ad_token_provider_scope", "https://cognitiveservices.azure.com/.default"
            )
            openai_config["azure_ad_token_provider"] = azure.identity.get_bearer_token_provider(
                azure.identity.DefaultAzureCredential(), azure_ad_token_provider_scope
            )

    def _configure_openai_config_for_bedrock(self, config: dict[str, Any], openai_config: dict[str, Any]) -> None:
        """Update openai_config with AWS credentials from config."""
        required_keys = ["aws_access_key", "aws_secret_key", "aws_region"]
        optional_keys = ["aws_session_token", "aws_profile_name"]
        # Copy-if-present for both groups; "required" keys are not validated here.
        for key in required_keys:
            if key in config:
                openai_config[key] = config[key]
        for key in optional_keys:
            if key in config:
                openai_config[key] = config[key]

    def _register_default_client(self, config: dict[str, Any], openai_config: dict[str, Any]) -> None:
        """Create a client with the given config to override openai_config,
        after removing extra kwargs.

        For Azure models/deployment names there's a convenience modification of model removing dots in
        its value (Azure deployment names can't have dots). I.e. if you have Azure deployment name
        "gpt-35-turbo" and define model "gpt-3.5-turbo" in the config the function will remove the dot
        from the name and create a client that connects to "gpt-35-turbo" Azure deployment.
        """
        openai_config = {**openai_config, **{k: v for k, v in config.items() if k in self.openai_kwargs}}
        api_type = config.get("api_type")
        model_client_cls_name = config.get("model_client_cls")
        if model_client_cls_name is not None:
            # a config for a custom client is set
            # adding placeholder until the register_model_client is called with the appropriate class
            self._clients.append(PlaceHolderClient(config))
            print(
                f"Detected custom model client in config: {model_client_cls_name}, model client can not be used until register_model_client is called."
            )
        else:
            if api_type is not None and api_type.startswith("azure"):
                self._configure_azure_openai(config, openai_config)
                client = AzureOpenAI(**openai_config)
                self._clients.append(OpenAIClient(client))
            else:
                client = OpenAI(**openai_config)
                self._clients.append(OpenAIClient(client))

    def register_model_client(self, model_client_cls: ModelClient, **kwargs):
        """Register a model client.

        Args:
            model_client_cls: A custom client class that follows the ModelClient interface
            **kwargs: The kwargs for the custom client class to be initialized with
        """
        existing_client_class = False
        for i, client in enumerate(self._clients):
            if isinstance(client, PlaceHolderClient):
                placeholder_config = client.config

                # Replace the placeholder whose config named this class.
                if placeholder_config.get("model_client_cls") == model_client_cls.__name__:
                    self._clients[i] = model_client_cls(placeholder_config, **kwargs)
                    return
            elif isinstance(client, model_client_cls):
                existing_client_class = True

        if existing_client_class:
            print(
                f"Model client {model_client_cls.__name__} is already registered. Add more entries in the config_list to use multiple model clients."
            )
        else:
            raise ValueError(
                f'Model client "{model_client_cls.__name__}" is being registered but was not found in the config_list. '
                f'Please make sure to include an entry in the config_list with "model_client_cls": "{model_client_cls.__name__}"'
            )

    @classmethod
    def instantiate(
        cls,
        template: str | Callable[[dict[str, Any]], str] | None,
        context: dict[str, Any] | None = None,
        allow_format_str_template: bool | None = False,
    ) -> str | None:
        """Render a prompt template with `context`; callables are invoked, plain
        strings are `.format`ed only when allow_format_str_template is true."""
        if not context or template is None:
            return template  # type: ignore [return-value]
        if isinstance(template, str):
            return template.format(**context) if allow_format_str_template else template
        return template(context)

    def _construct_create_params(self, create_config: dict[str, Any], extra_kwargs: dict[str, Any]) -> dict[str, Any]:
        """Prime the create_config with additional_kwargs."""
        # Validate the config
        prompt: str | None = create_config.get("prompt")
        messages: list[dict[str, Any]] | None = create_config.get("messages")
        if (prompt is None) == (messages is None):
            raise ValueError("Either prompt or messages should be in create config but not both.")
        context = extra_kwargs.get("context")
        if context is None:
            # No need to instantiate if no context is provided.
            return create_config
        # Instantiate the prompt or messages
        allow_format_str_template = extra_kwargs.get("allow_format_str_template", False)
        # Make a copy of the config
        params = create_config.copy()
        if prompt is not None:
            # Instantiate the prompt
            params["prompt"] = self.instantiate(prompt, context, allow_format_str_template)
        elif context:
            # Instantiate the messages
            params["messages"] = [
                (
                    {
                        **m,
                        "content": self.instantiate(m["content"], context, allow_format_str_template),
                    }
                    if m.get("content")
                    else m
                )
                for m in messages  # type: ignore [union-attr]
            ]
        return params

    def create(self, **config: Any) -> ModelClient.ModelClientResponseProtocol:
        """Make a completion for a given config using available clients.
        Besides the kwargs allowed in openai's [or other] client, we allow the following additional kwargs.
        The config in each client will be overridden by the config.

        Args:
        - context (dict | None): The context to instantiate the prompt or messages. Default to None.
            It needs to contain keys that are used by the prompt template or the filter function.
            E.g., `prompt="Complete the following sentence: {prefix}"`, `context={"prefix": "Today I feel"}`.
            The actual prompt will be:
            "Complete the following sentence: Today I feel".
            More examples can be found at [templating](/docs/Use-Cases/enhanced_inference#templating).
        - agent (AbstractAgent | None): The object responsible for creating a completion if an agent.
        - filter_func (Callable | None): A function that takes in the context and the response
            and returns a boolean to indicate whether the response is valid. E.g.,
        - allow_format_str_template (bool | None): Whether to allow format string template in the config. Default to false.
        - api_version (str | None): The api version. Default to None. E.g., "2024-02-01".

        Raises:
        - RuntimeError: If all declared custom model clients are not registered
        - APIError: If any model client create call raises an APIError
        """

        last = len(self._clients) - 1
        # Check if all configs in config list are activated
        non_activated = [
            client.config["model_client_cls"] for client in self._clients if isinstance(client, PlaceHolderClient)
        ]
        if non_activated:
            # NOTE(review): "out form the config list" in this message is a typo for "out from" — fix in a behavioral change.
            raise RuntimeError(
                f"Model client(s) {non_activated} are not activated. Please register the custom model clients using `register_model_client` or filter them out form the config list."
            )
        for i, client in enumerate(self._clients):
            # merge the input config with the i-th config in the config list
            full_config = {**config, **self._config_list[i]}
            # separate the config into create_config and extra_kwargs
            create_config, extra_kwargs = self._separate_create_config(full_config)
            api_type = extra_kwargs.get("api_type")
            if api_type and api_type.startswith("azure") and "model" in create_config:
                create_config["model"] = create_config["model"].replace(".", "")
            # construct the create params
            params = self._construct_create_params(create_config, extra_kwargs)
            # get the cache_seed, filter_func and context
            cache = None
            filter_func = extra_kwargs.get("filter_func")
            context = extra_kwargs.get("context")
            price = extra_kwargs.get("price", None)
            if isinstance(price, list):
                price = tuple(price)
            elif isinstance(price, float) or isinstance(price, int):
                print(
                    "Input price is a float/int. Using the same price for prompt and completion tokens. Use a list/tuple if prompt and completion token prices are different."
                )
                price = (price, price)

            total_usage = None
            actual_usage = None
            # NOTE(review): cache seed 41 and dir ".cache" are hardcoded here, so the
            # "cache"/"cache_seed" extra_kwargs declared on this class are ignored — confirm intended.
            cache_client = Cache.disk(41, ".cache")

            with cache_client as cache:
                key = get_key(params)

                response: ModelClient.ModelClientResponseProtocol = cache.get(key, None)

                if response is not None:
                    response.message_retrieval_function = client.message_retrieval
                    try:
                        response.cost  # type: ignore [attr-defined]
                    except AttributeError:
                        # update attribute if cost is not calculated
                        response.cost = client.cost(response)
                        cache.set(key, response)
                    total_usage = client.get_usage(response)

                    # check the filter
                    pass_filter = filter_func is None or filter_func(context=context, response=response)
                    if pass_filter or i == last:
                        # Return the response if it passes the filter or it is the last client
                        response.config_id = i
                        response.pass_filter = pass_filter
                        self._update_usage(actual_usage=actual_usage, total_usage=total_usage)
                        return response
                    continue  # filter is not passed; try the next config
            try:
                self._throttle_api_calls(i)
                response = client.create(params)
            except APITimeoutError as err:
                if i == last:
                    raise TimeoutError(
                        "OpenAI API call timed out. This could be due to congestion or too small a timeout value. The timeout can be specified by setting the 'timeout' value (in seconds) in the llm_config (if you are using agents) or the OpenAIWrapper constructor (if you are using the OpenAIWrapper directly)."
                    ) from err
            except APIError as err:
                error_code = getattr(err, "code", None)
                if error_code == "content_filter":
                    # raise the error for content_filter
                    raise
                if i == last:
                    raise
            else:
                # add cost calculation before caching no matter filter is passed or not
                if price is not None:
                    response.cost = self._cost_with_customized_price(response, price)
                else:
                    response.cost = client.cost(response)
                actual_usage = client.get_usage(response)
                total_usage = actual_usage.copy() if actual_usage is not None else total_usage
                self._update_usage(actual_usage=actual_usage, total_usage=total_usage)
                with cache_client as cache:
                    cache.set(key, response)

                response.message_retrieval_function = client.message_retrieval
                # check the filter
                pass_filter = filter_func is None or filter_func(context=context, response=response)
                if pass_filter or i == last:
                    # Return the response if it passes the filter or it is the last client
                    response.config_id = i
                    response.pass_filter = pass_filter
                    return response
                continue  # filter is not passed; try the next config
        raise RuntimeError("Should not reach here.")

    @staticmethod
    def _cost_with_customized_price(
        response: ModelClient.ModelClientResponseProtocol, price_1k: tuple[float, float]
    ) -> float:
        """Compute the response cost from a customized (prompt, completion) per-1K-token price pair."""
        n_input_tokens = response.usage.prompt_tokens if response.usage is not None else 0  # type: ignore [union-attr]
        n_output_tokens = response.usage.completion_tokens if response.usage is not None else 0  # type: ignore [union-attr]
        if n_output_tokens is None:
            n_output_tokens = 0
        return (n_input_tokens * price_1k[0] + n_output_tokens * price_1k[1]) / 1000

    @staticmethod
    def _update_dict_from_chunk(chunk: BaseModel, d: dict[str, Any], field: str) -> int:
        """Update the dict from the chunk.

        Reads `chunk.field` and if present updates `d[field]` accordingly:
        string values are appended to the accumulated string, other scalars replace it.

        Args:
            chunk: The chunk.
            d: The dict to be updated in place.
            field: The field.

        Returns:
            1 when a non-string value was assigned (counted as one completion token), else 0.

        """
        completion_tokens = 0
        assert isinstance(d, dict), d
        if hasattr(chunk, field) and getattr(chunk, field) is not None:
            new_value = getattr(chunk, field)
            if isinstance(new_value, list) or isinstance(new_value, dict):
                raise NotImplementedError(
                    f"Field {field} is a list or dict, which is currently not supported. "
                    "Only string and numbers are supported."
                )
            if field not in d:
                d[field] = ""
            if isinstance(new_value, str):
                d[field] += getattr(chunk, field)
            else:
                d[field] = new_value
                completion_tokens = 1

        return completion_tokens

    @staticmethod
    def _update_function_call_from_chunk(
        function_call_chunk: ChoiceDeltaToolCallFunction | ChoiceDeltaFunctionCall,
        full_function_call: dict[str, Any] | None,
        completion_tokens: int,
    ) -> tuple[dict[str, Any], int]:
        """Update the function call from the chunk.

        Args:
            function_call_chunk: The function call chunk.
            full_function_call: The full function call.
            completion_tokens: The number of completion tokens.

        Returns:
            The updated full function call and the updated number of completion tokens.

        """
        # Handle function call
        if function_call_chunk:
            if full_function_call is None:
                full_function_call = {}
            for field in ["name", "arguments"]:
                completion_tokens += OpenAIWrapper._update_dict_from_chunk(
                    function_call_chunk, full_function_call, field
                )

        if full_function_call:
            return full_function_call, completion_tokens
        else:
            raise RuntimeError("Function call is not found, this should not happen.")

    @staticmethod
    def _update_tool_calls_from_chunk(
        tool_calls_chunk: ChoiceDeltaToolCall,
        full_tool_call: dict[str, Any] | None,
        completion_tokens: int,
    ) -> tuple[dict[str, Any], int]:
        """Update the tool call from the chunk.

        Args:
            tool_calls_chunk: The tool call chunk.
            full_tool_call: The full tool call.
            completion_tokens: The number of completion tokens.

        Returns:
            The updated full tool call and the updated number of completion tokens.

        """
        # future proofing for when tool calls other than function calls are supported
        if tool_calls_chunk.type and tool_calls_chunk.type != "function":
            raise NotImplementedError(
                f"Tool call type {tool_calls_chunk.type} is currently not supported. "
                "Only function calls are supported."
            )

        # Handle tool call
        assert full_tool_call is None or isinstance(full_tool_call, dict), full_tool_call
        if tool_calls_chunk:
            if full_tool_call is None:
                full_tool_call = {}
            for field in ["index", "id", "type"]:
                completion_tokens += OpenAIWrapper._update_dict_from_chunk(tool_calls_chunk, full_tool_call, field)

            if hasattr(tool_calls_chunk, "function") and tool_calls_chunk.function:
                if "function" not in full_tool_call:
                    full_tool_call["function"] = None

                full_tool_call["function"], completion_tokens = OpenAIWrapper._update_function_call_from_chunk(
                    tool_calls_chunk.function, full_tool_call["function"], completion_tokens
                )

        if full_tool_call:
            return full_tool_call, completion_tokens
        else:
            raise RuntimeError("Tool call is not found, this should not happen.")

    def _update_usage(self, actual_usage, total_usage):
        """Fold per-response usage dicts into the wrapper's running summaries."""
        def update_usage(usage_summary, response_usage):
            # go through RESPONSE_USAGE_KEYS and check that they are in response_usage and if not just return usage_summary
            for key in ModelClient.RESPONSE_USAGE_KEYS:
                if key not in response_usage:
                    return usage_summary

            model = response_usage["model"]
            cost = response_usage["cost"]
            prompt_tokens = response_usage["prompt_tokens"]
            completion_tokens = response_usage["completion_tokens"]
            if completion_tokens is None:
                completion_tokens = 0
            total_tokens = response_usage["total_tokens"]

            if usage_summary is None:
                usage_summary = {"total_cost": cost}
            else:
                usage_summary["total_cost"] += cost

            usage_summary[model] = {
                "cost": usage_summary.get(model, {}).get("cost", 0) + cost,
                "prompt_tokens": usage_summary.get(model, {}).get("prompt_tokens", 0) + prompt_tokens,
                "completion_tokens": usage_summary.get(model, {}).get("completion_tokens", 0) + completion_tokens,
                "total_tokens": usage_summary.get(model, {}).get("total_tokens", 0) + total_tokens,
            }
            return usage_summary

        if total_usage is not None:
            self.total_usage_summary = update_usage(self.total_usage_summary, total_usage)
        if actual_usage is not None:
            self.actual_usage_summary = update_usage(self.actual_usage_summary, actual_usage)

    def print_usage_summary(self, mode: str | list[str] = ["actual", "total"]) -> None:
        """Print the usage summary."""
        iostream = IOStream.get_default()

        def print_usage(usage_summary: dict[str, Any] | None, usage_type: str = "total") -> None:
            # "total" includes cached completions; "actual" excludes them.
            word_from_type = "including" if usage_type == "total" else "excluding"
            if usage_summary is None:
                iostream.print("No actual cost incurred (all completions are using cache).", flush=True)
                return

            iostream.print(f"Usage summary {word_from_type} cached usage: ", flush=True)
            iostream.print(f"Total cost: {round(usage_summary['total_cost'], 5)}", flush=True)
            for model, counts in usage_summary.items():
                if model == "total_cost":
                    continue  #
                iostream.print(
                    f"* Model '{model}': cost: {round(counts['cost'], 5)}, prompt_tokens: {counts['prompt_tokens']}, completion_tokens: {counts['completion_tokens']}, total_tokens: {counts['total_tokens']}",
                    flush=True,
                )

        if self.total_usage_summary is None:
            iostream.print('No usage summary. Please call "create" first.', flush=True)
            return

        if isinstance(mode, list):
            if len(mode) == 0 or len(mode) > 2:
                raise ValueError(f'Invalid mode: {mode}, choose from "actual", "total", ["actual", "total"]')
            if "actual" in mode and "total" in mode:
                mode = "both"
            elif "actual" in mode:
                mode = "actual"
            elif "total" in mode:
                mode = "total"

        iostream.print("-" * 100, flush=True)
        if mode == "both":
            print_usage(self.actual_usage_summary, "actual")
            iostream.print()
            if self.total_usage_summary != self.actual_usage_summary:
                print_usage(self.total_usage_summary, "total")
            else:
                iostream.print(
                    "All completions are non-cached: the total cost with cached completions is the same as actual cost.",
                    flush=True,
                )
        elif mode == "total":
            print_usage(self.total_usage_summary, "total")
        elif mode == "actual":
            print_usage(self.actual_usage_summary, "actual")
        else:
            raise ValueError(f'Invalid mode: {mode}, choose from "actual", "total", ["actual", "total"]')
        iostream.print("-" * 100, flush=True)

    def clear_usage_summary(self) -> None:
        """Clear the usage summary."""
        self.total_usage_summary = None
        self.actual_usage_summary = None

    @classmethod
    def extract_text_or_completion_object(
        cls, response: ModelClient.ModelClientResponseProtocol
    ) -> list[str] | list[ModelClient.ModelClientResponseProtocol.Choice.Message]:
        """Extract the text or ChatCompletion objects from a completion or chat response.

        Args:
            response (ChatCompletion | Completion): The response from openai.

        Returns:
            A list of text, or a list of ChatCompletion objects if function_call/tool_calls are present.
        """
        return response.message_retrieval_function(response)

    def _throttle_api_calls(self, idx: int) -> None:
        """Rate limit api calls."""
        if idx < len(self._rate_limiters) and self._rate_limiters[idx]:
            limiter = self._rate_limiters[idx]

            assert limiter is not None
            limiter.sleep()

    def _initialize_rate_limiters(self, config_list: list[dict[str, Any]]) -> None:
        """Build one rate limiter slot per config; entries without "api_rate_limit" get None."""
        for config in config_list:
            # Instantiate the rate limiter
            if "api_rate_limit" in config:
                self._rate_limiters.append(TimeRateLimiter(config["api_rate_limit"]))
                # consumed here so it is not forwarded to the openai client
                del config["api_rate_limit"]
            else:
                self._rate_limiters.append(None)


# --- patch boundary: new file train_methods/legacy_autogen/legacy_autogen.py ---
"""Legacy autogen (ver 2.0) for cogfd

"""

import sys
import random
import re
from dataclasses import dataclass, field
from typing import Any, Callable, Literal

from termcolor import colored

from train_methods.legacy_autogen.legacy_autogen_conversable_agent import ConversableAgent, Agent
from train_methods.legacy_autogen.stream import IOStream
from train_methods.legacy_autogen.client import ModelClient
from train_methods.legacy_autogen.utils import content_str


class NoEligibleSpeaker(Exception):
    """Exception raised for early termination of a GroupChat."""

    def __init__(self, message: str = "No eligible speakers."):
        self.message = message
        super().__init__(self.message)


@dataclass
class GroupChat:
    # A multi-agent conversation: the agents, the shared message history,
    # and the policy for picking the next speaker each round.

    agents: list[Agent]
    messages: list[dict]
    max_round: int = 10
    admin_name: str = "Admin"
    func_call_filter: bool = True
    speaker_selection_method: Literal["auto", "manual", "random", "round_robin"] | Callable = "auto"
    max_retries_for_selecting_speaker: int = 2
    allow_repeat_speaker: bool | list[Agent] | None = None
    allowed_or_disallowed_speaker_transitions: dict | None = None
    speaker_transitions_type: Literal["allowed", "disallowed", None] = None
    enable_clear_history: bool = False
    send_introductions: bool = False
    select_speaker_message_template: str = """You are in a role play game. The following roles are available:
{roles}.
Read the following conversation.
Then select the next role from {agentlist} to play. Only return the role."""
    select_speaker_prompt_template: str = (
        "Read the above conversation. Then select the next role from {agentlist} to play. Only return the role."
    )
    select_speaker_auto_multiple_template: str = """You provided more than one name in your text, please return just the name of the next speaker. To determine the speaker use these prioritised rules:
1. If the context refers to themselves as a speaker e.g. "As the..." , choose that speaker's name
2. If it refers to the "next" speaker name, choose that name
3. Otherwise, choose the first provided speaker's name in the context
The names are case-sensitive and should not be abbreviated or changed.
Respond with ONLY the name of the speaker and DO NOT provide a reason."""
    select_speaker_auto_none_template: str = """You didn't choose a speaker. As a reminder, to determine the speaker use these prioritised rules:
1. If the context refers to themselves as a speaker e.g. "As the..." , choose that speaker's name
2. If it refers to the "next" speaker name, choose that name
3. Otherwise, choose the first provided speaker's name in the context
The names are case-sensitive and should not be abbreviated or changed.
The only names that are accepted are {agentlist}.
Respond with ONLY the name of the speaker and DO NOT provide a reason."""
    select_speaker_transform_messages: Any = None
    select_speaker_auto_verbose: bool | None = False
    select_speaker_auto_model_client_cls: ModelClient | list[ModelClient] | None = None
    select_speaker_auto_llm_config: dict | Literal[False] | None = None
    role_for_select_speaker_messages: str | None = "system"

    _VALID_SPEAKER_SELECTION_METHODS = ["auto", "manual", "random", "round_robin"]
    _VALID_SPEAKER_TRANSITIONS_TYPE = ["allowed", "disallowed", None]

    # Define a class attribute for the default introduction message
    DEFAULT_INTRO_MSG = (
        "Hello everyone. We have assembled a great team today to answer questions and solve tasks. In attendance are:"
    )

    allowed_speaker_transitions_dict: dict = field(init=False)

    def __post_init__(self):
        # Post init steers clears of the automatically generated __init__ method from dataclass

        # NOTE(review): this unconditionally overrides any allow_repeat_speaker value
        # passed by the caller — confirm this legacy simplification is intended.
        self.allow_repeat_speaker = True

        self.allowed_speaker_transitions_dict = {}
        # Create a fully connected allowed_speaker_transitions_dict not including self loops
        for agent in self.agents:
            self.allowed_speaker_transitions_dict[agent] = [
                other_agent for other_agent in self.agents if other_agent != agent
            ]

        # then add the self loops back, making the graph fully connected
        for agent in self.agents:
            self.allowed_speaker_transitions_dict[agent].append(agent)

        self._speaker_selection_transforms = None

    @property
    def agent_names(self) -> list[str]:
        """Return the names of the agents in the group chat."""
        return [agent.name for agent in self.agents]

    def reset(self):
        """Reset the group chat."""
        self.messages.clear()

    def append(self, message: dict, speaker: Agent):
        """Append a message to the group chat.
        We cast the content to str here so that it can be managed by text-based
        model.
        """
        # set the name to speaker's name if the role is not function
        # if the role is tool, it is OK to modify the name
        if message["role"] != "function":
            message["name"] = speaker.name
        message["content"] = content_str(message["content"])
        self.messages.append(message)

    def agent_by_name(
        self, name: str, recursive: bool = False, raise_on_name_conflict: bool = False
    ) -> Agent | None:
        """Returns the agent with a given name. If recursive is True, it will search in nested teams."""
        agents = self.nested_agents() if recursive else self.agents
        filtered_agents = [agent for agent in agents if agent.name == name]

        if raise_on_name_conflict and len(filtered_agents) > 1:
            raise ValueError("Found multiple agents with the same name.")

        return filtered_agents[0] if filtered_agents else None

    def nested_agents(self) -> list[Agent]:
        """Returns all agents in the group chat manager."""
        agents = self.agents.copy()
        for agent in agents:
            if isinstance(agent, GroupChatManager):
                # Recursive call for nested teams
                agents.extend(agent.groupchat.nested_agents())
        return agents

    def next_agent(self, agent: Agent, agents: list[Agent] | None = None) -> Agent:
        """Return the next agent in the list."""
        if agents is None:
            agents = self.agents

        # Ensure the provided list of agents is a subset of self.agents
        if not set(agents).issubset(set(self.agents)):
            raise ValueError("The provided agents list does not overlap with agents in the group.")

        # What index is the agent? (-1 if not present)
        idx = self.agent_names.index(agent.name) if agent.name in self.agent_names else -1

        # Return the next agent
        if agents == self.agents:
            return agents[(idx + 1) % len(agents)]
        else:
            # walk forward through self.agents (wrapping) until we hit a member of the subset
            offset = idx + 1
            for i in range(len(self.agents)):
                if self.agents[(offset + i) % len(self.agents)] in agents:
                    return self.agents[(offset + i) % len(self.agents)]

        # Explicitly handle cases where no valid next agent exists in the provided subset.
        raise ValueError("The provided agents list does not overlap with agents in the group.")

    def select_speaker_msg(self, agents: list[Agent] | None = None) -> str:
        """Return the system message for selecting the next speaker. This is always the *first* message in the context."""
        if agents is None:
            agents = self.agents

        roles = self._participant_roles(agents)
        agentlist = f"{[agent.name for agent in agents]}"

        return_msg = self.select_speaker_message_template.format(roles=roles, agentlist=agentlist)
        return return_msg

    def select_speaker_prompt(self, agents: list[Agent] | None = None) -> str | None:
        """Return the floating system prompt selecting the next speaker.
        This is always the *last* message in the context.
        Will return None if the select_speaker_prompt_template is None."""

        if self.select_speaker_prompt_template is None:
            return None

        if agents is None:
            agents = self.agents

        agentlist = f"{[agent.name for agent in agents]}"

        return_prompt = self.select_speaker_prompt_template.format(agentlist=agentlist)
        return return_prompt

    def introductions_msg(self, agents: list[Agent] | None = None) -> str:
        """Return the introduction message listing all participants and their roles."""
        if agents is None:
            agents = self.agents

        # Use the class attribute instead of a hardcoded string
        intro_msg = self.DEFAULT_INTRO_MSG
        participant_roles = self._participant_roles(agents)

        return f"{intro_msg}\n\n{participant_roles}"

    def manual_select_speaker(self, agents: list[Agent] | None = None) -> Agent | None:
        """Manually select the next speaker."""
        iostream = IOStream.get_default()

        if agents is None:
            agents = self.agents

        iostream.print("Please select the next speaker from the following list:")
        _n_agents = len(agents)
        for i in range(_n_agents):
            iostream.print(f"{i+1}: {agents[i].name}")
        try_count = 0
        # Assume the user will enter a valid number within 3 tries, otherwise use auto selection to avoid blocking.
        while try_count <= 3:
            try_count += 1
            if try_count >= 3:
                iostream.print(f"You have tried {try_count} times. The next speaker will be selected automatically.")
                break
            try:
                i = iostream.input(
                    "Enter the number of the next speaker (enter nothing or `q` to use auto selection): "
                )
                if i == "" or i == "q":
                    break
                i = int(i)
                if i > 0 and i <= _n_agents:
                    return agents[i - 1]
                else:
                    raise ValueError
            except ValueError:
                iostream.print(f"Invalid input. Please enter a number between 1 and {_n_agents}.")
        # None signals the caller to fall back to automatic selection.
        return None

    def random_select_speaker(self, agents: list[Agent] | None = None) -> Agent | None:
        """Randomly select the next speaker."""
        if agents is None:
            agents = self.agents
        return random.choice(agents)

    def _prepare_and_select_agents(
        self,
        last_speaker: Agent,
    ) -> tuple[Agent | None, list[Agent], list[dict]]:
        # If self.speaker_selection_method is a callable, call it to get the next speaker.
        # If self.speaker_selection_method is a string, return it.
        speaker_selection_method = self.speaker_selection_method
        if isinstance(self.speaker_selection_method, Callable):
            selected_agent = self.speaker_selection_method(last_speaker, self)
            if selected_agent is None:
                raise NoEligibleSpeaker("Custom speaker selection function returned None. Terminating conversation.")
            elif isinstance(selected_agent, Agent):
                if selected_agent in self.agents:
                    return selected_agent, self.agents, None
                else:
                    raise ValueError(
                        f"Custom speaker selection function returned an agent {selected_agent.name} not in the group chat."
                    )
            elif isinstance(selected_agent, str):
                # If returned a string, assume it is a speaker selection method
                speaker_selection_method = selected_agent
            else:
                raise ValueError(
                    f"Custom speaker selection function returned an object of type {type(selected_agent)} instead of Agent or str."
                )

        if speaker_selection_method.lower() not in self._VALID_SPEAKER_SELECTION_METHODS:
            raise ValueError(
                f"GroupChat speaker_selection_method is set to '{speaker_selection_method}'. "
                f"It should be one of {self._VALID_SPEAKER_SELECTION_METHODS} (case insensitive). "
            )

        # If provided a list, make sure the agent is in the list
        allow_repeat_speaker = (
            self.allow_repeat_speaker
            if isinstance(self.allow_repeat_speaker, bool) or self.allow_repeat_speaker is None
            else last_speaker in self.allow_repeat_speaker
        )

        agents = self.agents
        n_agents = len(agents)
        # Warn if GroupChat is underpopulated
        if n_agents < 2:
            raise ValueError(
                f"GroupChat is underpopulated with {n_agents} agents. "
                "Please add more agents to the GroupChat or use direct communication instead."
            )
        elif n_agents == 2 and speaker_selection_method.lower() != "round_robin" and allow_repeat_speaker:
            print(
                f"GroupChat is underpopulated with {n_agents} agents. "
                "Consider setting speaker_selection_method to 'round_robin' or allow_repeat_speaker to False, "
                "or use direct communication, unless repeated speaker is desired."
            )

        # If the last message carried a function/tool call, route to an agent that can execute it.
        if (
            self.func_call_filter
            and self.messages
            and ("function_call" in self.messages[-1] or "tool_calls" in self.messages[-1])
        ):
            funcs = []
            if "function_call" in self.messages[-1]:
                funcs += [self.messages[-1]["function_call"]["name"]]
            if "tool_calls" in self.messages[-1]:
                funcs += [
                    tool["function"]["name"] for tool in self.messages[-1]["tool_calls"] if tool["type"] == "function"
                ]

            # find agents with the right function_map which contains the function name
            agents = [agent for agent in self.agents if agent.can_execute_function(funcs)]
            if len(agents) == 1:
                # only one agent can execute the function
                return agents[0], agents, None
            elif not agents:
                # find all the agents with function_map
                agents = [agent for agent in self.agents if agent.function_map]
                if len(agents) == 1:
                    return agents[0], agents, None
                elif not agents:
                    raise ValueError(
                        f"No agent can execute the function {', '.join(funcs)}. "
                        "Please check the function_map of the agents."
                    )
        # remove the last speaker from the list to avoid selecting the same speaker if allow_repeat_speaker is False
        agents = [agent for agent in agents if agent != last_speaker] if allow_repeat_speaker is False else agents

        # Filter agents with allowed_speaker_transitions_dict

        is_last_speaker_in_group = last_speaker in self.agents

        # this condition means last_speaker is a sink in the graph, then no agents are eligible
        if last_speaker not in self.allowed_speaker_transitions_dict and is_last_speaker_in_group:
            raise NoEligibleSpeaker(f"Last speaker {last_speaker.name} is not in the allowed_speaker_transitions_dict.")
        # last_speaker is not in the group, so all agents are eligible
        elif last_speaker not in self.allowed_speaker_transitions_dict and not is_last_speaker_in_group:
            graph_eligible_agents = []
        else:
            # Extract agent names from the list of agents
            graph_eligible_agents = [
                agent for agent in agents if agent in self.allowed_speaker_transitions_dict[last_speaker]
            ]

        # If there is only one eligible agent, just return it to avoid the speaker selection prompt
        if len(graph_eligible_agents) == 1:
            return graph_eligible_agents[0], graph_eligible_agents, None

        # If there are no eligible agents, return None, which means all agents will be taken into consideration in the next step
        if len(graph_eligible_agents) == 0:
            graph_eligible_agents = None

        # Use the selected speaker selection method
        select_speaker_messages = None
        if speaker_selection_method.lower() == "manual":
            selected_agent = self.manual_select_speaker(graph_eligible_agents)
        elif speaker_selection_method.lower() == "round_robin":
            selected_agent = self.next_agent(last_speaker, graph_eligible_agents)
        elif speaker_selection_method.lower() == "random":
            selected_agent = self.random_select_speaker(graph_eligible_agents)
        else:  # auto
            selected_agent = None
            select_speaker_messages = self.messages.copy()
            # If last message is a tool call or function call, blank the call so the api doesn't throw
            if select_speaker_messages[-1].get("function_call", False):
                select_speaker_messages[-1] = dict(select_speaker_messages[-1], function_call=None)
            if select_speaker_messages[-1].get("tool_calls", False):
                select_speaker_messages[-1] = dict(select_speaker_messages[-1], tool_calls=None)
        return selected_agent, graph_eligible_agents, select_speaker_messages

    def select_speaker(self, last_speaker: Agent, selector: ConversableAgent) -> Agent:
        """Select the next speaker (with requery)."""

        # Prepare the list of available agents and select an agent if selection method allows (non-auto)
        selected_agent, agents, messages = self._prepare_and_select_agents(last_speaker)
        if selected_agent:
            return selected_agent
        elif self.speaker_selection_method == "manual":
            # An agent has not been selected while in manual mode, so move to the next agent
            return self.next_agent(last_speaker)

        # auto speaker selection with 2-agent chat
        return self._auto_select_speaker(last_speaker, selector, messages, agents)

    async def a_select_speaker(self, last_speaker: Agent, selector: ConversableAgent) -> Agent:
        """Select the next speaker (with requery), asynchronously."""

        selected_agent, agents, messages = self._prepare_and_select_agents(last_speaker)
        if selected_agent:
            return selected_agent
        elif self.speaker_selection_method == "manual":
            # An agent has not been selected while in manual mode, so move to the next agent
            return self.next_agent(last_speaker)

        # auto speaker selection with 2-agent chat
        return await self.a_auto_select_speaker(last_speaker, selector, messages, agents)

    def _finalize_speaker(self, last_speaker: Agent, final: bool, name: str, agents: list[Agent] | None) -> Agent:
        if not final:
            # the LLM client is None, thus no reply is generated. Use round robin instead.
            return self.next_agent(last_speaker, agents)

        # If exactly one agent is mentioned, use it.
Otherwise, leave the OAI response unmodified + mentions = self._mentioned_agents(name, agents) + if len(mentions) == 1: + name = next(iter(mentions)) + else: + print( + f"GroupChat select_speaker failed to resolve the next speaker's name. This is because the speaker selection OAI call returned:\n{name}" + ) + + # Return the result + agent = self.agent_by_name(name) + return agent if agent else self.next_agent(last_speaker, agents) + + def _register_client_from_config(self, agent: Agent, config: dict): + model_client_cls_to_match = config.get("model_client_cls") + if model_client_cls_to_match: + if not self.select_speaker_auto_model_client_cls: + raise ValueError( + "A custom model was detected in the config but no 'model_client_cls' " + "was supplied for registration in GroupChat." + ) + + if isinstance(self.select_speaker_auto_model_client_cls, list): + # Register the first custom model client class matching the name specified in the config + matching_model_cls = [ + client_cls + for client_cls in self.select_speaker_auto_model_client_cls + if client_cls.__name__ == model_client_cls_to_match + ] + if len(set(matching_model_cls)) > 1: + raise RuntimeError( + f"More than one unique 'model_client_cls' with __name__ '{model_client_cls_to_match}'." + ) + if not matching_model_cls: + raise ValueError( + "No model's __name__ matches the model client class " + f"'{model_client_cls_to_match}' specified in select_speaker_auto_llm_config." 
+ ) + select_speaker_auto_model_client_cls = matching_model_cls[0] + else: + # Register the only custom model client + select_speaker_auto_model_client_cls = self.select_speaker_auto_model_client_cls + + agent.register_model_client(select_speaker_auto_model_client_cls) + + def _register_custom_model_clients(self, agent: ConversableAgent): + if not self.select_speaker_auto_llm_config: + return + + config_format_is_list = "config_list" in self.select_speaker_auto_llm_config.keys() + if config_format_is_list: + for config in self.select_speaker_auto_llm_config["config_list"]: + self._register_client_from_config(agent, config) + elif not config_format_is_list: + self._register_client_from_config(agent, self.select_speaker_auto_llm_config) + + def _create_internal_agents( + self, agents, max_attempts, messages, validate_speaker_name, selector: ConversableAgent | None = None + ): + checking_agent = ConversableAgent("checking_agent", default_auto_reply=max_attempts) + + # Register the speaker validation function with the checking agent + checking_agent.register_reply( + [ConversableAgent, None], + reply_func=validate_speaker_name, # Validate each response + remove_other_reply_funcs=True, + ) + + # Override the selector's config if one was passed as a parameter to this class + speaker_selection_llm_config = self.select_speaker_auto_llm_config or selector.llm_config + + # Agent for selecting a single agent name from the response + speaker_selection_agent = ConversableAgent( + "speaker_selection_agent", + system_message=self.select_speaker_msg(agents), + chat_messages={checking_agent: messages}, + llm_config=speaker_selection_llm_config, + human_input_mode="NEVER", + # Suppresses some extra terminal outputs, outputs will be handled by select_speaker_auto_verbose + ) + + # Register any custom model passed in select_speaker_auto_llm_config with the speaker_selection_agent + self._register_custom_model_clients(speaker_selection_agent) + + return checking_agent, 
speaker_selection_agent
+
+    def _auto_select_speaker(
+        self,
+        last_speaker: Agent,
+        selector: ConversableAgent,
+        messages: list[dict],
+        agents: list[Agent] | None,
+    ) -> Agent:
+        """Selects next speaker for the "auto" speaker selection method. Utilises its own two-agent chat to determine the next speaker and supports requerying.
+
+        Speaker selection for "auto" speaker selection method:
+        1. Create a two-agent chat with a speaker selector agent and a speaker validator agent, like a nested chat
+        2. Inject the group messages into the new chat
+        3. Run the two-agent chat, evaluating the result of response from the speaker selector agent:
+        - If a single agent is provided then we return it and finish. If not, we add an additional message to this nested chat in an attempt to guide the LLM to a single agent response
+        4. Chat continues until a single agent is nominated or there are no more attempts left
+        5. If we run out of turns and no single agent can be determined, the next speaker in the list of agents is returned
+
+        Args:
+            last_speaker Agent: The previous speaker in the group chat
+            selector ConversableAgent: Agent whose llm_config is used for the internal selection chat (unless select_speaker_auto_llm_config overrides it)
+            messages list[dict]: Current chat messages
+            agents list[Agent] | None: Valid list of agents for speaker selection
+
+        Returns:
+            Agent: The agent selected to speak next.
+        """
+
+        # If no agents are passed in, assign all the group chat's agents
+        if agents is None:
+            agents = self.agents
+
+        # The maximum number of speaker selection attempts (including requeries)
+        # is the initial speaker selection attempt plus the maximum number of retries.
+        # We track these and use them in the validation function as we can't
+        # access the max_turns from within validate_speaker_name.
+ max_attempts = 1 + self.max_retries_for_selecting_speaker + attempts_left = max_attempts + attempt = 0 + + # Registered reply function for checking_agent, checks the result of the response for agent names + def validate_speaker_name(recipient, messages, sender, config) -> tuple[bool, str | dict | None]: + # The number of retries left, starting at max_retries_for_selecting_speaker + nonlocal attempts_left + nonlocal attempt + + attempt = attempt + 1 + attempts_left = attempts_left - 1 + + return self._validate_speaker_name(recipient, messages, sender, config, attempts_left, attempt, agents) + + # Two-agent chat for speaker selection + + # Agent for checking the response from the speaker_select_agent + checking_agent, speaker_selection_agent = self._create_internal_agents( + agents, max_attempts, messages, validate_speaker_name, selector + ) + + # Create the starting message + if self.select_speaker_prompt_template is not None: + start_message = { + "content": self.select_speaker_prompt(agents), + "name": "checking_agent", + "override_role": self.role_for_select_speaker_messages, + } + else: + start_message = messages[-1] + + # Add the message transforms, if any, to the speaker selection agent + if self._speaker_selection_transforms is not None: + self._speaker_selection_transforms.add_to_agent(speaker_selection_agent) + + # Run the speaker selection chat + result = checking_agent.initiate_chat( + speaker_selection_agent, + cache=None, # don't use caching for the speaker selection chat + message=start_message, + max_turns=2 + * max(1, max_attempts), # Limiting the chat to the number of attempts, including the initial one + clear_history=False, + silent=not self.select_speaker_auto_verbose, # Base silence on the verbose attribute + ) + + return self._process_speaker_selection_result(result, last_speaker, agents) + + async def a_auto_select_speaker( + self, + last_speaker: Agent, + selector: ConversableAgent, + messages: list[dict], + agents: list[Agent] | None, + ) 
-> Agent:
+        """(Asynchronous) Selects next speaker for the "auto" speaker selection method. Utilises its own two-agent chat to determine the next speaker and supports requerying.
+
+        Speaker selection for "auto" speaker selection method:
+        1. Create a two-agent chat with a speaker selector agent and a speaker validator agent, like a nested chat
+        2. Inject the group messages into the new chat
+        3. Run the two-agent chat, evaluating the result of response from the speaker selector agent:
+        - If a single agent is provided then we return it and finish. If not, we add an additional message to this nested chat in an attempt to guide the LLM to a single agent response
+        4. Chat continues until a single agent is nominated or there are no more attempts left
+        5. If we run out of turns and no single agent can be determined, the next speaker in the list of agents is returned
+
+        Args:
+            last_speaker Agent: The previous speaker in the group chat
+            selector ConversableAgent: Agent whose llm_config is used for the internal selection chat (unless select_speaker_auto_llm_config overrides it)
+            messages list[dict]: Current chat messages
+            agents list[Agent] | None: Valid list of agents for speaker selection
+
+        Returns:
+            Agent: The agent selected to speak next.
+        """
+
+        # If no agents are passed in, assign all the group chat's agents
+        if agents is None:
+            agents = self.agents
+
+        # The maximum number of speaker selection attempts (including requeries)
+        # We track these and use them in the validation function as we can't
+        # access the max_turns from within validate_speaker_name
+        max_attempts = 1 + self.max_retries_for_selecting_speaker
+        attempts_left = max_attempts
+        attempt = 0
+
+        # Registered reply function for checking_agent, checks the result of the response for agent names
+        def validate_speaker_name(recipient, messages, sender, config) -> tuple[bool, str | dict | None]:
+            # The number of retries left, starting at max_retries_for_selecting_speaker
+            nonlocal attempts_left
+            nonlocal attempt
+
+            attempt = attempt + 1
+            attempts_left = attempts_left - 1
+
+            return self._validate_speaker_name(recipient, messages, sender, config, attempts_left, attempt, agents)
+
+        # Two-agent chat for speaker selection
+
+        # Agent for checking the response from the speaker_select_agent
+        checking_agent, speaker_selection_agent = self._create_internal_agents(
+            agents, max_attempts, messages, validate_speaker_name, selector
+        )
+
+        # Create the starting message
+        if self.select_speaker_prompt_template is not None:
+            start_message = {
+                "content": self.select_speaker_prompt(agents),
+                "name": "checking_agent", "override_role": self.role_for_select_speaker_messages,
+            }
+        else:
+            start_message = messages[-1]
+
+        # Add the message transforms, if any, to the speaker selection agent
+        if self._speaker_selection_transforms is not None:
+            self._speaker_selection_transforms.add_to_agent(speaker_selection_agent)
+
+        # Run the speaker selection chat
+        result = await checking_agent.a_initiate_chat(
+            speaker_selection_agent,
+            cache=None,  # don't use caching for the speaker selection chat
+            message=start_message,
+            max_turns=2
+            * max(1, max_attempts),  # Limiting the chat to the number of attempts, including the initial one
+            clear_history=False,
+            silent=not 
self.select_speaker_auto_verbose, # Base silence on the verbose attribute + ) + + return self._process_speaker_selection_result(result, last_speaker, agents) + + def _validate_speaker_name( + self, recipient, messages: list[dict[str, str]], sender, config, attempts_left, attempt, agents + ) -> tuple[bool, str | dict | None]: + """Validates the speaker response for each round in the internal 2-agent + chat within the auto select speaker method. + + Used by auto_select_speaker and a_auto_select_speaker. + """ + + # Output the query and requery results + if self.select_speaker_auto_verbose: + iostream = IOStream.get_default() + + # Validate the speaker name selected + select_name = messages[-1]["content"].strip() + + mentions = self._mentioned_agents(select_name, agents) + + if len(mentions) == 1: + # Success on retry, we have just one name mentioned + selected_agent_name = next(iter(mentions)) + + # Add the selected agent to the response so we can return it + messages.append({"role": "user", "content": f"[AGENT SELECTED]{selected_agent_name}"}) + + if self.select_speaker_auto_verbose: + iostream.print( + colored( + f">>>>>>>> Select speaker attempt {attempt} of {attempt + attempts_left} successfully selected: {selected_agent_name}", + "green", + ), + flush=True, + ) + + elif len(mentions) > 1: + # More than one name on requery so add additional reminder prompt for next retry + + if self.select_speaker_auto_verbose: + iostream.print( + colored( + f">>>>>>>> Select speaker attempt {attempt} of {attempt + attempts_left} failed as it included multiple agent names.", + "red", + ), + flush=True, + ) + + if attempts_left: + # Message to return to the chat for the next attempt + agentlist = f"{[agent.name for agent in agents]}" + + return True, { + "content": self.select_speaker_auto_multiple_template.format(agentlist=agentlist), + "name": "checking_agent", + "override_role": self.role_for_select_speaker_messages, + } + else: + # Final failure, no attempts left + 
messages.append( + { + "role": "user", + "content": f"[AGENT SELECTION FAILED]Select speaker attempt #{attempt} of {attempt + attempts_left} failed as it returned multiple names.", + } + ) + + else: + # No names at all on requery so add additional reminder prompt for next retry + + if self.select_speaker_auto_verbose: + iostream.print( + colored( + f">>>>>>>> Select speaker attempt #{attempt} failed as it did not include any agent names.", + "red", + ), + flush=True, + ) + + if attempts_left: + # Message to return to the chat for the next attempt + agentlist = f"{[agent.name for agent in agents]}" + + return True, { + "content": self.select_speaker_auto_none_template.format(agentlist=agentlist), + "name": "checking_agent", + "override_role": self.role_for_select_speaker_messages, + } + else: + # Final failure, no attempts left + messages.append( + { + "role": "user", + "content": f"[AGENT SELECTION FAILED]Select speaker attempt #{attempt} of {attempt + attempts_left} failed as it did not include any agent names.", + } + ) + + return True, None + + def _process_speaker_selection_result(self, result, last_speaker: ConversableAgent, agents: list[Agent] | None): + """Checks the result of the auto_select_speaker function, returning the + agent to speak. 
+
+        Used by auto_select_speaker and a_auto_select_speaker."""
+        if len(result.chat_history) > 0:
+            # Use the final message, which will have the selected agent or reason for failure
+            final_message = result.chat_history[-1]["content"]
+
+            if "[AGENT SELECTED]" in final_message:
+                # Have successfully selected an agent, return it
+                return self.agent_by_name(final_message.replace("[AGENT SELECTED]", ""))
+
+            else:  # "[AGENT SELECTION FAILED]"
+                # Failed to select an agent, so we'll select the next agent in the list
+                next_agent = self.next_agent(last_speaker, agents)
+
+                # Fall back to the next agent in the rotation
+                return next_agent
+
+    def _participant_roles(self, agents: list["Agent"] | None = None) -> str:
+        # Default to all agents registered
+        if agents is None:
+            agents = self.agents
+
+        roles = []
+        for agent in agents:
+            if agent.description.strip() == "":
+                IOStream.get_default().print(
+                    f"The agent '{agent.name}' has an empty description, and may not work well with GroupChat."
+                )
+            roles.append(f"{agent.name}: {agent.description}".strip())
+        return "\n".join(roles)
+
+    def _mentioned_agents(self, message_content: str | list, agents: list[Agent] | None) -> dict:
+        """Counts the number of times each agent is mentioned in the provided message content.
+        Agent names will match under any of the following conditions (all case-sensitive):
+        - Exact name match
+        - If the agent name has underscores it will match with spaces instead (e.g. 'Story_writer' == 'Story writer')
+        - If the agent name has underscores it will match with '\\_' instead of '_' (e.g. 'Story_writer' == 'Story\\_writer')
+
+        Args:
+            message_content (Union[str, List]): The content of the message, either as a single string or a list of strings.
+            agents (List[Agent]): A list of Agent objects, each having a 'name' attribute to be searched in the message content.
+
+        Returns:
+            Dict: a counter for mentioned agents.
+ """ + if agents is None: + agents = self.agents + + # Cast message content to str + if isinstance(message_content, dict): + message_content = message_content["content"] + message_content = content_str(message_content) + + mentions = dict() + for agent in agents: + # Finds agent mentions, taking word boundaries into account, + # accommodates escaping underscores and underscores as spaces + regex = ( + r"(?<=\W)(" + + re.escape(agent.name) + + r"|" + + re.escape(agent.name.replace("_", " ")) + + r"|" + + re.escape(agent.name.replace("_", r"\_")) + + r")(?=\W)" + ) + count = len(re.findall(regex, f" {message_content} ")) # Pad the message to help with matching + if count > 0: + mentions[agent.name] = count + return mentions + + +class GroupChatManager(ConversableAgent): + + def __init__( + self, + groupchat: GroupChat, + name: str | None = "chat_manager", + max_consecutive_auto_reply: int | None = sys.maxsize, + human_input_mode: Literal["ALWAYS", "NEVER", "TERMINATE"] = "NEVER", + system_message: str | list | None = "Group chat manager.", + silent: bool = False, + **kwargs, + ): + if ( + kwargs.get("llm_config") + and isinstance(kwargs["llm_config"], dict) + and (kwargs["llm_config"].get("functions") or kwargs["llm_config"].get("tools")) + ): + raise ValueError( + "GroupChatManager is not allowed to make function/tool calls. Please remove the 'functions' or 'tools' config in 'llm_config' you passed in." + ) + + super().__init__( + name=name, + max_consecutive_auto_reply=max_consecutive_auto_reply, + human_input_mode=human_input_mode, + system_message=system_message, + **kwargs, + ) + + # Store groupchat + self._groupchat = groupchat + + self._last_speaker = None + self._silent = silent + + # Order of register_reply is important. 
+ # Allow sync chat if initiated using initiate_chat + self.register_reply(Agent, GroupChatManager.run_chat, config=groupchat, reset_config=GroupChat.reset) + # Allow async chat if initiated using a_initiate_chat + self.register_reply( + Agent, + GroupChatManager.a_run_chat, + config=groupchat, + reset_config=GroupChat.reset, + ignore_async_in_sync_chat=True, + ) + + @property + def groupchat(self) -> GroupChat: + """Returns the group chat managed by the group chat manager.""" + return self._groupchat + + def _prepare_chat( + self, + recipient: ConversableAgent, + clear_history: bool, + prepare_recipient: bool = True, + reply_at_receive: bool = True, + ) -> None: + super()._prepare_chat(recipient, clear_history, prepare_recipient, reply_at_receive) + + if clear_history: + self._groupchat.reset() + + for agent in self._groupchat.agents: + if (recipient != agent or prepare_recipient) and isinstance(agent, ConversableAgent): + agent._prepare_chat(self, clear_history, False, reply_at_receive) + + @property + def last_speaker(self) -> Agent: + """Return the agent who sent the last message to group chat manager. + + In a group chat, an agent will always send a message to the group chat manager, and the group chat manager will + send the message to all other agents in the group chat. So, when an agent receives a message, it will always be + from the group chat manager. With this property, the agent receiving the message can know who actually sent the + message. 
+ """ + return self._last_speaker + + def run_chat( + self, + messages: list[dict] | None = None, + sender: Agent | None = None, + config: GroupChat | None = None, + ) -> tuple[bool, str | None]: + """Run a group chat.""" + if messages is None: + messages = self._oai_messages[sender] + message = messages[-1] + speaker = sender + groupchat = config + send_introductions = getattr(groupchat, "send_introductions", False) + silent = getattr(self, "_silent", False) + + if send_introductions: + # Broadcast the intro + intro = groupchat.introductions_msg() + for agent in groupchat.agents: + self.send(intro, agent, request_reply=False, silent=True) + # NOTE: We do not also append to groupchat.messages, + # since groupchat handles its own introductions + + if self.client_cache is not None: + for a in groupchat.agents: + a.previous_cache = a.client_cache + a.client_cache = self.client_cache + for i in range(groupchat.max_round): + self._last_speaker = speaker + groupchat.append(message, speaker) + # broadcast the message to all agents except the speaker + for agent in groupchat.agents: + if agent != speaker: + self.send(message, agent, request_reply=False, silent=True) + if self._is_termination_msg(message) or i == groupchat.max_round - 1: + # The conversation is over or it's the last round + break + try: + # select the next speaker + speaker = groupchat.select_speaker(speaker, self) + if not silent: + iostream = IOStream.get_default() + iostream.print(colored(f"\nNext speaker: {speaker.name}\n", "green"), flush=True) + # let the speaker speak + reply = speaker.generate_reply(sender=self) + except KeyboardInterrupt: + # let the admin agent speak if interrupted + if groupchat.admin_name in groupchat.agent_names: + # admin agent is one of the participants + speaker = groupchat.agent_by_name(groupchat.admin_name) + reply = speaker.generate_reply(sender=self) + else: + # admin agent is not found in the participants + raise + except NoEligibleSpeaker: + # No eligible speaker, 
terminate the conversation + print("No eligible speaker found. Terminating the conversation.") + break + + if reply is None: + # no reply is generated, exit the chat + break + + # check for "clear history" phrase in reply and activate clear history function if found + if ( + groupchat.enable_clear_history + and isinstance(reply, dict) + and reply["content"] + and "CLEAR HISTORY" in reply["content"].upper() + ): + reply["content"] = self.clear_agents_history(reply, groupchat) + + # The speaker sends the message without requesting a reply + speaker.send(reply, self, request_reply=False, silent=silent) + message = self.last_message(speaker) + if self.client_cache is not None: + for a in groupchat.agents: + a.client_cache = a.previous_cache + a.previous_cache = None + return True, None + + async def a_run_chat( + self, + messages: list[dict] | None = None, + sender: Agent | None = None, + config: GroupChat | None = None, + ): + """Run a group chat asynchronously.""" + if messages is None: + messages = self._oai_messages[sender] + message = messages[-1] + speaker = sender + groupchat = config + send_introductions = getattr(groupchat, "send_introductions", False) + silent = getattr(self, "_silent", False) + + if send_introductions: + # Broadcast the intro + intro = groupchat.introductions_msg() + for agent in groupchat.agents: + await self.a_send(intro, agent, request_reply=False, silent=True) + # NOTE: We do not also append to groupchat.messages, + # since groupchat handles its own introductions + + if self.client_cache is not None: + for a in groupchat.agents: + a.previous_cache = a.client_cache + a.client_cache = self.client_cache + for i in range(groupchat.max_round): + groupchat.append(message, speaker) + + if self._is_termination_msg(message): + # The conversation is over + break + + # broadcast the message to all agents except the speaker + for agent in groupchat.agents: + if agent != speaker: + await self.a_send(message, agent, request_reply=False, silent=True) + 
if i == groupchat.max_round - 1: + # the last round + break + try: + # select the next speaker + speaker = await groupchat.a_select_speaker(speaker, self) + # let the speaker speak + reply = await speaker.a_generate_reply(sender=self) + except KeyboardInterrupt: + # let the admin agent speak if interrupted + if groupchat.admin_name in groupchat.agent_names: + # admin agent is one of the participants + speaker = groupchat.agent_by_name(groupchat.admin_name) + reply = await speaker.a_generate_reply(sender=self) + else: + # admin agent is not found in the participants + raise + except NoEligibleSpeaker: + # No eligible speaker, terminate the conversation + print("No eligible speaker found. Terminating the conversation.") + break + + if reply is None: + break + # The speaker sends the message without requesting a reply + await speaker.a_send(reply, self, request_reply=False, silent=silent) + message = self.last_message(speaker) + if self.client_cache is not None: + for a in groupchat.agents: + a.client_cache = a.previous_cache + a.previous_cache = None + return True, None + + + def _raise_exception_on_async_reply_functions(self) -> None: + """Raise an exception if any async reply functions are registered. + + Raises: + RuntimeError: if any async reply functions are registered. + """ + super()._raise_exception_on_async_reply_functions() + + for agent in self._groupchat.agents: + agent._raise_exception_on_async_reply_functions() + + def clear_agents_history(self, reply: dict, groupchat: GroupChat) -> str: + """Clears history of messages for all agents or selected one. Can preserve selected number of last messages. + That function is called when user manually provide "clear history" phrase in his reply. + When "clear history" is provided, the history of messages for all agents is cleared. + When "clear history " is provided, the history of messages for selected agent is cleared. 
+        When "clear history <nr_of_messages_to_preserve>" is provided, the history of messages for all agents is cleared
+        except last <nr_of_messages_to_preserve> messages.
+        When "clear history <agent_name> <nr_of_messages_to_preserve>" is provided, the history of messages for selected
+        agent is cleared except last <nr_of_messages_to_preserve> messages.
+        Phrase "clear history" and optional arguments are cut out from the reply before it is passed to the chat.
+
+        Args:
+            reply (dict): reply message dict to analyze.
+            groupchat (GroupChat): GroupChat object.
+        """
+        iostream = IOStream.get_default()
+
+        reply_content = reply["content"]
+        # Split the reply into words
+        words = reply_content.split()
+        # Find the position of "clear" to determine where to start processing
+        clear_word_index = next(i for i in reversed(range(len(words))) if words[i].upper() == "CLEAR")
+        # Extract potential agent name and steps
+        words_to_check = words[clear_word_index + 2 : clear_word_index + 4]
+        nr_messages_to_preserve = None
+        nr_messages_to_preserve_provided = False
+        agent_to_memory_clear = None
+
+        for word in words_to_check:
+            if word.isdigit():
+                nr_messages_to_preserve = int(word)
+                nr_messages_to_preserve_provided = True
+            elif word[:-1].isdigit():  # for the case when number of messages is followed by dot or other sign
+                nr_messages_to_preserve = int(word[:-1])
+                nr_messages_to_preserve_provided = True
+            else:
+                for agent in groupchat.agents:
+                    if agent.name == word:
+                        agent_to_memory_clear = agent
+                        break
+                    elif agent.name == word[:-1]:  # for the case when agent name is followed by dot or other sign
+                        agent_to_memory_clear = agent
+                        break
+        # preserve last tool call message if clear history called inside of tool response
+        if "tool_responses" in reply and not nr_messages_to_preserve:
+            nr_messages_to_preserve = 1
+            iostream.print(
+                "The last tool call message will be saved to prevent errors caused by tool response without tool call."
+ ) + # clear history + if agent_to_memory_clear: + if nr_messages_to_preserve: + iostream.print( + f"Clearing history for {agent_to_memory_clear.name} except last {nr_messages_to_preserve} messages." + ) + else: + iostream.print(f"Clearing history for {agent_to_memory_clear.name}.") + agent_to_memory_clear.clear_history(nr_messages_to_preserve=nr_messages_to_preserve) + else: + if nr_messages_to_preserve: + iostream.print(f"Clearing history for all agents except last {nr_messages_to_preserve} messages.") + # clearing history for groupchat here + temp = groupchat.messages[-nr_messages_to_preserve:] + groupchat.messages.clear() + groupchat.messages.extend(temp) + else: + iostream.print("Clearing history for all agents.") + # clearing history for groupchat here + groupchat.messages.clear() + # clearing history for agents + for agent in groupchat.agents: + agent.clear_history(nr_messages_to_preserve=nr_messages_to_preserve) + + # Reconstruct the reply without the "clear history" command and parameters + skip_words_number = 2 + int(bool(agent_to_memory_clear)) + int(nr_messages_to_preserve_provided) + reply_content = " ".join(words[:clear_word_index] + words[clear_word_index + skip_words_number :]) + + return reply_content diff --git a/train_methods/legacy_autogen/legacy_autogen_conversable_agent.py b/train_methods/legacy_autogen/legacy_autogen_conversable_agent.py new file mode 100644 index 0000000..26f39de --- /dev/null +++ b/train_methods/legacy_autogen/legacy_autogen_conversable_agent.py @@ -0,0 +1,1667 @@ +import asyncio +import contextvars +import copy +import functools +import inspect +import re +import warnings +from collections import defaultdict +from typing import Any, Callable, Coroutine, Literal, Type, Protocol, TypeVar + +from pydantic import BaseModel +from termcolor import colored + +from train_methods.legacy_autogen.chat import ChatResult, a_initiate_chats, initiate_chats, _post_process_carryover_item, consolidate_chat_info +from 
train_methods.legacy_autogen.client import ModelClient, OpenAIWrapper
from train_methods.legacy_autogen.stream import IOStream
from train_methods.legacy_autogen.utils import (
    content_str,
    load_basemodels_if_needed,
    serialize_to_str
)

__all__ = ("ConversableAgent",)

# Type variable for decorator-style helpers that wrap arbitrary callables.
F = TypeVar("F", bound=Callable[..., Any])

def model_dump(model: BaseModel) -> dict[str, Any]:
    """Serialize a pydantic v2 model to a plain dict (thin wrapper over model_dump)."""
    return model.model_dump()

class SenderRequired(Exception):
    """Exception raised when the sender is required but not provided."""

    def __init__(self, message: str = "Sender is required but not provided."):
        self.message = message
        super().__init__(self.message)

class InvalidCarryOverType(Exception):
    """Exception raised when the carryover type is invalid."""

    def __init__(
        self, message: str = "Carryover should be a string or a list of strings. Not adding carryover to the message."
    ):
        self.message = message
        super().__init__(self.message)


class Agent(Protocol):
    """(In preview) A protocol for Agent.

    An agent can communicate with other agents and perform actions.
    Different agents can differ in what actions they perform in the `receive` method.
    """

    @property
    def name(self) -> str:
        """The name of the agent."""
        ...

    @property
    def description(self) -> str:
        """The description of the agent. Used for the agent's introduction in a group chat setting."""
        ...

    def send(
        self,
        message: dict[str, Any] | str,
        recipient: "Agent",
        request_reply: bool | None = None,
    ) -> None:
        """Send a message to another agent.

        Args:
            message (dict or str): the message to send. If a dict, it should be
                a JSON-serializable and follows the OpenAI's ChatCompletion schema.
            recipient (Agent): the recipient of the message.
            request_reply (bool): whether to request a reply from the recipient.
        """
        ...

    async def a_send(
        self,
        message: dict[str, Any] | str,
        recipient: "Agent",
        request_reply: bool | None = None,
    ) -> None:
        """(Async) Send a message to another agent.

        Args:
            message (dict or str): the message to send. If a dict, it should be
                a JSON-serializable and follows the OpenAI's ChatCompletion schema.
            recipient (Agent): the recipient of the message.
            request_reply (bool): whether to request a reply from the recipient.
        """
        ...

    def receive(
        self,
        message: dict[str, Any] | str,
        sender: "Agent",
        request_reply: bool | None = None,
    ) -> None:
        """Receive a message from another agent.

        Args:
            message (dict or str): the message received. If a dict, it should be
                a JSON-serializable and follows the OpenAI's ChatCompletion schema.
            sender (Agent): the sender of the message.
            request_reply (bool): whether the sender requests a reply.
        """

    async def a_receive(
        self,
        message: dict[str, Any] | str,
        sender: "Agent",
        request_reply: bool | None = None,
    ) -> None:
        """(Async) Receive a message from another agent.

        Args:
            message (dict or str): the message received. If a dict, it should be
                a JSON-serializable and follows the OpenAI's ChatCompletion schema.
            sender (Agent): the sender of the message.
            request_reply (bool): whether the sender requests a reply.
        """
        ...

    # NOTE(review): `Literal["Agent"]` below looks like it was meant to be the
    # forward reference "Agent" (i.e. an Agent instance), not the literal string
    # "Agent" -- verify against callers before changing the annotation.
    def generate_reply(
        self,
        messages: list[dict[str, Any]] | None = None,
        sender: Literal["Agent"] | None = None,
        **kwargs: Any,
    ) -> str | dict[str, Any] | None:
        """Generate a reply based on the received messages.

        Args:
            messages (list[dict]): a list of messages received from other agents.
                The messages are dictionaries that are JSON-serializable and
                follows the OpenAI's ChatCompletion schema.
            sender: sender of an Agent instance.

        Returns:
            str or dict or None: the generated reply. If None, no reply is generated.
        """

    async def a_generate_reply(
        self,
        messages: list[dict[str, Any]] | None = None,
        sender: Literal["Agent"] | None = None,
        **kwargs: Any,
    ) -> str | dict[str, Any] | None:
        """(Async) Generate a reply based on the received messages.

        Args:
            messages (list[dict]): a list of messages received from other agents.
                The messages are dictionaries that are JSON-serializable and
                follows the OpenAI's ChatCompletion schema.
            sender: sender of an Agent instance.

        Returns:
            str or dict or None: the generated reply. If None, no reply is generated.
        """

class LLMAgent(Agent, Protocol):
    """(In preview) A protocol for an LLM agent."""

    @property
    def system_message(self) -> str:
        """The system message of this agent."""


class ConversableAgent(LLMAgent):
    """A class for generic conversable agents which can be configured as assistant or user proxy.

    After receiving each message, the agent will send a reply to the sender unless the msg is a termination msg.
    For example, AssistantAgent and UserProxyAgent are subclasses of this class,
    configured with different default settings.
    """

    DEFAULT_CONFIG = False  # False or dict, the default config for llm inference
    MAX_CONSECUTIVE_AUTO_REPLY = 100  # maximum number of consecutive auto replies (subject to future change)

    DEFAULT_SUMMARY_PROMPT = "Summarize the takeaway from the conversation. Do not add any introductory phrases."
    DEFAULT_SUMMARY_METHOD = "last_msg"
    # Set by _validate_llm_config: False disables llm-based auto reply.
    llm_config: dict | Literal[False]

    def __init__(
        self,
        name: str,
        system_message: str | list | None = "You are a helpful AI Assistant.",
        is_termination_msg: Callable[[dict], bool] | None = None,
        llm_config: dict | Literal[False] | None = None,
    ):
        """
        Args:
            name (str): name of the agent.
            system_message (str or list): system message for the ChatCompletion inference.
            is_termination_msg (function): a function that takes a message in the form of a dictionary
                and returns a boolean value indicating if this received message is a termination message.
                The dict can contain the following keys: "content", "role" and "name".
            llm_config (dict or False or None): llm inference configuration.
                Please refer to OpenAIWrapper.create for available options.
                When using OpenAI or Azure OpenAI endpoints, please specify a non-empty 'model' either in `llm_config` or in each config of 'config_list' in `llm_config`.
                To disable llm-based auto reply, set to False.
                When set to None, will use self.DEFAULT_CONFIG, which defaults to False.
        """

        self._name = name
        # Per-peer conversation history: {Agent: [oai message dict, ...]}.
        self._oai_messages = defaultdict(list)

        self._oai_system_message = [{"content": system_message, "role": "system"}]
        # Description defaults to the raw system message (may be a list).
        self._description = system_message
        self._is_termination_msg = (
            is_termination_msg
            if is_termination_msg is not None
            else (lambda x: content_str(x.get("content")) == "TERMINATE")
        )
        # Take a copy to avoid modifying the given dict
        if isinstance(llm_config, dict):
            llm_config = copy.deepcopy(llm_config)

        self._validate_llm_config(llm_config)

        self.client_cache = None
        self._max_consecutive_auto_reply = self.MAX_CONSECUTIVE_AUTO_REPLY
        self._consecutive_auto_reply_counter = defaultdict(int)
        # default_factory is the bound method, called with no args -> class-wide limit.
        self._max_consecutive_auto_reply_dict = defaultdict(self.max_consecutive_auto_reply)
        self._reply_func_list = []
        self._human_input = []
        self.reply_at_receive = defaultdict(bool)
        # generate_oai_reply / check_termination_and_human_reply (and async variants)
        # are defined further down in this file, outside this excerpt.
        self.register_reply([Agent, None], ConversableAgent.generate_oai_reply)
        self.register_reply([Agent, None], ConversableAgent.a_generate_oai_reply, ignore_async_in_sync_chat=True)

        self.register_reply([Agent, None], ConversableAgent.check_termination_and_human_reply)
        self.register_reply(
            [Agent, None], ConversableAgent.a_check_termination_and_human_reply, ignore_async_in_sync_chat=True
        )

        # Registered hooks are kept in lists, indexed by hookable method, to be called in their order of registration.
        # New hookable methods should be added to this list as required to support new agent capabilities.
        self.hook_lists: dict[str, list[Callable | Callable[..., Coroutine]]] = {
            "process_last_received_message": [],
            "a_process_last_received_message": [],
            "process_all_messages_before_reply": [],
            "a_process_all_messages_before_reply": [],
            "process_message_before_send": [],
            "a_process_message_before_send": [],
        }

    def _validate_llm_config(self, llm_config):
        """Normalize and sanity-check `llm_config`, then build the OpenAI client (or None)."""
        assert llm_config in (None, False) or isinstance(
            llm_config, dict
        ), "llm_config must be a dict or False or None."
        # NOTE(review): this first reassignment is dead code -- the next line repeats
        # the None check on the (possibly already overwritten) local.
        if llm_config is None:
            llm_config = self.DEFAULT_CONFIG
        self.llm_config = self.DEFAULT_CONFIG if llm_config is None else llm_config
        # TODO: more complete validity check
        if self.llm_config in [{}, {"config_list": []}, {"config_list": [{"model": ""}]}]:
            raise ValueError(
                "When using OpenAI or Azure OpenAI endpoints, specify a non-empty 'model' either in 'llm_config' or in each config of 'config_list'."
            )
        self.client = None if self.llm_config is False else OpenAIWrapper(**self.llm_config)

    @property
    def name(self) -> str:
        """Get the name of the agent."""
        return self._name

    @property
    def description(self) -> str:
        """Get the description of the agent."""
        return self._description

    @description.setter
    def description(self, description: str):
        """Set the description of the agent."""
        self._description = description

    def register_reply(
        self,
        trigger: Type[Agent] | str | Agent | Callable[[Agent], bool] | list,
        reply_func: Callable,
        position: int = 0,
        config: Any | None = None,
        reset_config: Callable | None = None,
        *,
        ignore_async_in_sync_chat: bool = False,
        remove_other_reply_funcs: bool = False,
    ):
        """Register a reply function, called when `trigger` matches the sender.

        Functions registered later are checked earlier by default; pass a positive
        `position` to change the order. Sync reply functions fire in both sync and
        async chats; async ones fire only in async chats unless ignored.

        Args:
            trigger: Agent subclass, agent-name str, Agent instance, predicate callable,
                or a list of any of these. `None` matches `sender=None`.
            reply_func (Callable): callable(recipient, messages, sender, config) -> reply.
            position (int): insertion index in the reply-function list (0 = checked first).
            config (Any): config passed to the reply function; restored on agent reset.
            reset_config (Callable): optional `def reset_config(config)` run on reset.
            ignore_async_in_sync_chat (bool): if True, an async reply_func is skipped in
                sync chats instead of raising an exception.
            remove_other_reply_funcs (bool): drop all previously registered reply functions.
        """
        if not isinstance(trigger, (type, str, Agent, Callable, list)):
            raise ValueError("trigger must be a class, a string, an agent, a callable or a list.")
        if remove_other_reply_funcs:
            self._reply_func_list.clear()
        self._reply_func_list.insert(
            position,
            {
                "trigger": trigger,
                "reply_func": reply_func,
                "config": copy.copy(config),
                "init_config": config,
                "reset_config": reset_config,
                # Only coroutine functions can be silently skipped in sync chats.
                "ignore_async_in_sync_chat": ignore_async_in_sync_chat and inspect.iscoroutinefunction(reply_func),
            },
        )


    @property
    def system_message(self) -> str:
        """Return the system message."""
        return self._oai_system_message[0]["content"]

    def max_consecutive_auto_reply(self, sender: Agent | None = None) -> int:
        """The maximum number of consecutive auto replies."""
        return self._max_consecutive_auto_reply if sender is None else self._max_consecutive_auto_reply_dict[sender]

    @property
    def chat_messages(self) -> dict[Agent, list[dict]]:
        """A dictionary of conversations from agent to list of messages."""
        return self._oai_messages

    def last_message(self, agent: Agent | None = None) -> dict | None:
        """The last message exchanged with the agent.

        Args:
            agent (Agent): The agent in the conversation.
                If None and more than one agent's conversations are found, an error will be raised.
                If None and only one conversation is found, the last message of the only conversation will be returned.

        Returns:
            The last message exchanged with the agent.
        """
        if agent is None:
            n_conversations = len(self._oai_messages)
            if n_conversations == 0:
                return None
            if n_conversations == 1:
                for conversation in self._oai_messages.values():
                    return conversation[-1]
            raise ValueError("More than one conversation is found. Please specify the sender to get the last message.")
        if agent not in self._oai_messages.keys():
            raise KeyError(
                f"The agent '{agent.name}' is not present in any conversation. No history available for this agent."
            )
        return self._oai_messages[agent][-1]

    @staticmethod
    def _message_to_dict(message: dict | str) -> dict:
        """Convert a message to a dictionary.

        The message can be a string or a dictionary. The string will be put in the "content" field of the new dictionary.
        """
        if isinstance(message, str):
            return {"content": message}
        elif isinstance(message, dict):
            return message
        else:
            # Fallback: best-effort coercion (e.g. a mapping-like object).
            return dict(message)

    @staticmethod
    def _normalize_name(name):
        """Replace characters invalid in an OpenAI name with "_" and cap at 64 chars.

        Used on LLM-produced names; prefer _assert_valid_name for validating user
        configuration or input.
        """
        return re.sub(r"[^a-zA-Z0-9_-]", "_", name)[:64]

    @staticmethod
    def _assert_valid_name(name):
        """Ensure that configured names are valid; raises ValueError if not.

        For munging LLM responses use _normalize_name so LLM-specified names don't break the API.
        """
        if not re.match(r"^[a-zA-Z0-9_-]+$", name):
            raise ValueError(f"Invalid name: {name}. Only letters, numbers, '_' and '-' are allowed.")
        if len(name) > 64:
            raise ValueError(f"Invalid name: {name}. Name must be less than 64 characters.")
        return name

    def _append_oai_message(self, message: dict | str, role, conversation_id: Agent, is_sending: bool) -> bool:
        """Append a message to the ChatCompletion conversation.

        A str message is wrapped as {"content": str}. A dict without a "content"
        field is rejected (returns False) since it is not a valid ChatCompletion
        message in this stripped-down copy.

        Args:
            message (dict or str): message to be appended to the ChatCompletion conversation.
            role (str): role of the message, can be "assistant" or "function".
            conversation_id (Agent): id of the conversation, should be the recipient or sender.
            is_sending (bool): If the agent (aka self) is sending to the conversation_id agent, otherwise receiving.

        Returns:
            bool: whether the message is appended to the ChatCompletion conversation.
        """
        message = self._message_to_dict(message)
        # create oai message to be appended to the oai conversation that can be passed to oai directly.
        # NOTE(review): only content/name/context survive; "function_call"/"tool_calls"
        # keys are dropped here even though error messages elsewhere mention them -- verify.
        oai_message = {
            k: message[k]
            for k in ("content", "name", "context")
            if k in message and message[k] is not None
        }
        if "content" not in oai_message:
            return False

        if "override_role" in message:
            # If we have a direction to override the role then set the
            # role accordingly. Used to customise the role for the
            # select speaker prompt.
            oai_message["role"] = message.get("override_role")
        else:
            oai_message["role"] = role

        if "name" not in oai_message:
            # If we don't have a name field, append it
            if is_sending:
                oai_message["name"] = self.name
            else:
                oai_message["name"] = conversation_id.name

        self._oai_messages[conversation_id].append(oai_message)

        return True

    def _process_message_before_send(
        self, message: dict | str, recipient: Agent
    ) -> dict | str:
        """Process the message before sending it to the recipient (sync hooks only)."""
        hook_list = self.hook_lists["process_message_before_send"]
        for hook in hook_list:
            # Async hooks are skipped here; they run in _a_process_message_before_send.
            if inspect.iscoroutinefunction(hook):
                continue
            message = hook(
                sender=self, message=message, recipient=recipient, silent=False
            )
        return message

    async def _a_process_message_before_send(
        self, message: dict | str, recipient: Agent
    ) -> dict | str:
        """(async) Process the message before sending it to the recipient (async hooks only)."""
        hook_list = self.hook_lists["a_process_message_before_send"]
        for hook in hook_list:
            # Sync hooks are skipped here; they run in _process_message_before_send.
            if not inspect.iscoroutinefunction(hook):
                continue
            message = await hook(sender=self, message=message, recipient=recipient, silent=False)
        return message

    def send(
        self,
        message: dict | str,
        recipient: Agent,
        request_reply: bool | None = None,
        silent: bool | None = False,
    ):
        """Send a message to another agent.

        Args:
            message (dict or str): message to be sent. A dict may contain:
                - content (str or List): Required, the content of the message. (Can be None)
                - name (str): the name of the function to be called.
                - role (str): any role that is not "function" will be modified to "assistant".
                - context (dict): the context of the message, passed to OpenAIWrapper.create.
            recipient (Agent): the recipient of the message.
            request_reply (bool or None): whether to request a reply from the recipient.
            silent (bool or None): (Experimental) whether to print the message sent.

        Raises:
            ValueError: if the message can't be converted into a valid ChatCompletion message.
        """
        message = self._process_message_before_send(message, recipient)
        # When the agent composes and sends the message, the role of the message is "assistant"
        # unless it's "function".
        valid = self._append_oai_message(message, "assistant", recipient, is_sending=True)
        if valid:
            recipient.receive(message, self, request_reply, silent)
        else:
            raise ValueError(
                "Message can't be converted into a valid ChatCompletion message. Either content or function_call must be provided."
            )

    async def a_send(
        self,
        message: dict | str,
        recipient: Agent,
        request_reply: bool | None = None,
        silent: bool | None = False,
    ):
        """(async) Send a message to another agent."""
        message = await self._a_process_message_before_send(message, recipient)
        # When the agent composes and sends the message, the role of the message is "assistant"
        # unless it's "function".
        valid = self._append_oai_message(message, "assistant", recipient, is_sending=True)
        if valid:
            await recipient.a_receive(message, self, request_reply, silent)
        else:
            raise ValueError(
                "Message can't be converted into a valid ChatCompletion message. Either content or function_call must be provided."
            )

    def _print_received_message(self, message: dict | str, sender: Agent):
        """Pretty-print an incoming message via the configured IOStream."""
        iostream = IOStream.get_default()
        # print the message received
        iostream.print(colored(sender.name, "yellow"), "(to", f"{self.name}):\n", flush=True)
        message = self._message_to_dict(message)
        content = message.get("content")
        if content is not None:
            if "context" in message:
                # Fill format-string templates from the message context when allowed.
                content = OpenAIWrapper.instantiate(
                    content,
                    message["context"],
                    self.llm_config and self.llm_config.get("allow_format_str_template", False),
                )
            iostream.print(content_str(content), flush=True)
        iostream.print("\n", "-" * 80, flush=True, sep="")

    def _process_received_message(self, message: dict | str, sender: Agent):
        """Record an incoming message in the history (role "user") and print it."""
        # When the agent receives a message, the role of the message is "user". (If 'role' exists and is 'function', it will remain unchanged.)
        valid = self._append_oai_message(message, "user", sender, is_sending=False)

        if not valid:
            raise ValueError(
                "Received message can't be converted into a valid ChatCompletion message. Either content or function_call must be provided."
            )

        self._print_received_message(message, sender)

    def receive(
        self,
        message: dict | str,
        sender: Agent,
        request_reply: bool | None = None,
        silent: bool | None = False,
    ):
        """Receive a message from another agent, then reply or stop.

        The reply can be generated automatically or entered manually by a human.

        Args:
            message (dict or str): message from the sender. A dict may contain
                "content" (required, can be None), "role", "name", and "context"
                (passed to OpenAIWrapper.create).
            sender: sender of an Agent instance.
            request_reply (bool or None): whether a reply is requested from the sender.
                If None, the value is determined by `self.reply_at_receive[sender]`.
            silent (bool or None): (Experimental) whether to print the message received.

        Raises:
            ValueError: if the message can't be converted into a valid ChatCompletion message.
        """
        self._process_received_message(message, sender)
        if request_reply is False or request_reply is None and self.reply_at_receive[sender] is False:
            return
        reply = self.generate_reply(messages=self.chat_messages[sender], sender=sender)
        if reply is not None:
            self.send(reply, sender, silent=silent)

    async def a_receive(
        self,
        message: dict | str,
        sender: Agent,
        request_reply: bool | None = None,
        silent: bool | None = False,
    ):
        """(async) Receive a message from another agent.

        Once a message is received, this function sends a reply to the sender or stop.
        The reply can be generated automatically or entered manually by a human.

        Args:
            message (dict or str): message from the sender. If the type is dict, it may contain the following reserved fields (either content need to be provided).
                1. "content": content of the message, can be None.
                4. "role": role of the message, can be "assistant", "user", "function".
                    This field is only needed to distinguish between "function" or "assistant"/"user".
                5. "name": In most cases, this field is not needed.
When the role is "function", this field is needed to indicate the function name. + 6. "context" (dict): the context of the message, which will be passed to + [OpenAIWrapper.create](../oai/client#create). + sender: sender of an Agent instance. + request_reply (bool or None): whether a reply is requested from the sender. + If None, the value is determined by `self.reply_at_receive[sender]`. + silent (bool or None): (Experimental) whether to print the message received. + + Raises: + ValueError: if the message can't be converted into a valid ChatCompletion message. + """ + self._process_received_message(message, sender, silent) + if request_reply is False or request_reply is None and self.reply_at_receive[sender] is False: + return + reply = await self.a_generate_reply(sender=sender) + if reply is not None: + await self.a_send(reply, sender, silent=silent) + + def _prepare_chat( + self, + recipient: "ConversableAgent", + clear_history: bool, + prepare_recipient: bool = True, + reply_at_receive: bool = True, + ) -> None: + self.reset_consecutive_auto_reply_counter(recipient) + self.reply_at_receive[recipient] = reply_at_receive + if clear_history: + self.clear_history(recipient) + self._human_input = [] + if prepare_recipient: + recipient._prepare_chat(self, clear_history, False, reply_at_receive) + + def _raise_exception_on_async_reply_functions(self) -> None: + """Raise an exception if any async reply functions are registered. + + Raises: + RuntimeError: if any async reply functions are registered. + """ + reply_functions = { + f["reply_func"] for f in self._reply_func_list if not f.get("ignore_async_in_sync_chat", False) + } + + async_reply_functions = [f for f in reply_functions if inspect.iscoroutinefunction(f)] + if async_reply_functions: + msg = ( + "Async reply functions can only be used with ConversableAgent.a_initiate_chat(). 
The following async reply functions are found: " + + ", ".join([f.__name__ for f in async_reply_functions]) + ) + + raise RuntimeError(msg) + + def initiate_chat( + self, + recipient: "ConversableAgent", + clear_history: bool = True, + silent: bool | None = False, + max_turns: int | None = None, + summary_method: str = DEFAULT_SUMMARY_METHOD, + summary_args: dict = {}, + message: dict | str | Callable | None = None, + **kwargs, + ) -> ChatResult: + """Initiate a chat with the recipient agent. + + Reset the consecutive auto reply counter. + If `clear_history` is True, the chat history with the recipient agent will be cleared. + + + Args: + recipient: the recipient agent. + clear_history (bool): whether to clear the chat history with the agent. Default is True. + silent (bool or None): (Experimental) whether to print the messages for this conversation. Default is False. + max_turns (int or None): the maximum number of turns for the chat between the two agents. One turn means one conversation round trip. Note that this is different from + [max_consecutive_auto_reply](#max_consecutive_auto_reply) which is the maximum number of consecutive auto replies; and it is also different from [max_rounds in GroupChat](./groupchat#groupchat-objects) which is the maximum number of rounds in a group chat session. + If max_turns is set to None, the chat will continue until a termination condition is met. Default is None. + summary_method (str or callable): a method to get a summary from the chat. Default is DEFAULT_SUMMARY_METHOD, i.e., "last_msg". + + Supported strings are "last_msg" and "reflection_with_llm": + - when set to "last_msg", it returns the last message of the dialog as the summary. + - when set to "reflection_with_llm", it returns a summary extracted using an llm client. + `llm_config` must be set in either the recipient or sender. + + A callable summary_method should take the recipient and sender agent in a chat as input and return a string of summary. 
            summary_args (dict): a dictionary of arguments to be passed to the summary_method.
                One example key is "summary_prompt", and value is a string of text used to prompt a LLM-based agent (the sender or receiver agent) to reflect
                on the conversation and extract a summary when summary_method is "reflection_with_llm".
                The default summary_prompt is DEFAULT_SUMMARY_PROMPT, i.e., "Summarize takeaway from the conversation. Do not add any introductory phrases. If the intended request is NOT properly addressed, please point it out."
                Another available key is "summary_role", which is the role of the message sent to the agent in charge of summarizing. Default is "system".
            message (str, dict or Callable): the initial message to be sent to the recipient. Needs to be provided. Otherwise, input() will be called to get the initial message.
                - If a string or a dict is provided, it will be used as the initial message. `generate_init_message` is called to generate the initial message for the agent based on this string and the context.
                    If dict, it may contain the following reserved fields (either content or tool_calls need to be provided).

                    1. "content": content of the message, can be None.
                    4. "role": role of the message, can be "assistant", "user", "function".
                        This field is only needed to distinguish between "function" or "assistant"/"user".
                    5. "name": In most cases, this field is not needed. When the role is "function", this field is needed to indicate the function name.
                    6. "context" (dict): the context of the message, which will be passed to
                        [OpenAIWrapper.create](../oai/client#create).

                - If a callable is provided, it will be called to get the initial message in the form of a string or a dict.
                    If the returned type is dict, it may contain the reserved fields mentioned above.

            **kwargs: any additional information. It has the following reserved fields:
                - "carryover": a string or a list of string to specify the carryover information to be passed to this chat.
                    If provided, we will combine this carryover (by attaching a "context: " string and the carryover content after the message content) with the "message" content when generating the initial chat
                    message in `generate_init_message`.
                - "verbose": a boolean to specify whether to print the message and carryover in a chat. Default is False.

        Raises:
            RuntimeError: if any async reply functions are registered and not ignored in sync chat.

        Returns:
            ChatResult: an ChatResult object.
        """
        # NOTE(review): summary_args uses a mutable default ({}); benign here since it
        # is never mutated, but worth changing to None at the next interface revision.
        # locals() snapshots all arguments for consolidate_chat_info / message callables.
        _chat_info = locals().copy()
        _chat_info["sender"] = self
        consolidate_chat_info(_chat_info, uniform_sender=self)
        for agent in [self, recipient]:
            agent._raise_exception_on_async_reply_functions()
            # Disable client caches for the duration of the chat; restored below.
            agent.previous_cache = agent.client_cache
            agent.client_cache = None
        if isinstance(max_turns, int):
            # Bounded chat: drive the turns explicitly, replies on receive disabled.
            self._prepare_chat(recipient, clear_history, reply_at_receive=False)
            for _ in range(max_turns):
                if _ == 0:
                    if isinstance(message, Callable):
                        msg2send = message(_chat_info["sender"], _chat_info["recipient"], kwargs)
                    else:
                        msg2send = self.generate_init_message(message, **kwargs)
                else:
                    msg2send = self.generate_reply(messages=self.chat_messages[recipient], sender=recipient)
                if msg2send is None:
                    break
                self.send(msg2send, recipient, request_reply=True, silent=silent)
        else:
            # Unbounded chat: send once and let receive-side auto replies drive it.
            self._prepare_chat(recipient, clear_history)
            if isinstance(message, Callable):
                msg2send = message(_chat_info["sender"], _chat_info["recipient"], kwargs)
            else:
                msg2send = self.generate_init_message(message, **kwargs)
            self.send(msg2send, recipient, silent=silent)
        # NOTE(review): summary_method/summary_args are accepted but unused in this
        # stripped-down copy -- _summarize_chat always uses the last-message summary.
        summary = self._summarize_chat(
            recipient,
        )
        for agent in [self, recipient]:
            # Restore the caches disabled above.
            agent.client_cache = agent.previous_cache
            agent.previous_cache = None
        chat_result = ChatResult(
            chat_history=self.chat_messages[recipient],
            summary=summary,
            cost=None,
            human_input=self._human_input,
        )
        return chat_result

    async def a_initiate_chat(
        self,
        recipient: "ConversableAgent",
        clear_history: bool = True,
        silent: bool | None = False,
        max_turns: int | None = None,
        summary_method: str = DEFAULT_SUMMARY_METHOD,
        summary_args: dict = {},
        message: str | Callable | None = None,
        **kwargs,
    ) -> ChatResult:
        """(async) Initiate a chat with the recipient agent.

        Reset the consecutive auto reply counter.
        If `clear_history` is True, the chat history with the recipient agent will be cleared.
        `a_generate_init_message` is called to generate the initial message for the agent.

        Args: Please refer to `initiate_chat`.

        Returns:
            ChatResult: an ChatResult object.
        """
        _chat_info = locals().copy()
        _chat_info["sender"] = self
        consolidate_chat_info(_chat_info, uniform_sender=self)
        # NOTE(review): unlike initiate_chat, no _raise_exception_on_async_reply_functions
        # check here -- async reply functions are valid in async chats.
        for agent in [self, recipient]:
            agent.previous_cache = agent.client_cache
            agent.client_cache = None
        if isinstance(max_turns, int):
            self._prepare_chat(recipient, clear_history, reply_at_receive=False)
            for _ in range(max_turns):
                if _ == 0:
                    if isinstance(message, Callable):
                        msg2send = message(_chat_info["sender"], _chat_info["recipient"], kwargs)
                    else:
                        msg2send = await self.a_generate_init_message(message, **kwargs)
                else:
                    msg2send = await self.a_generate_reply(messages=self.chat_messages[recipient], sender=recipient)
                if msg2send is None:
                    break
                await self.a_send(msg2send, recipient, request_reply=True, silent=silent)
        else:
            self._prepare_chat(recipient, clear_history)
            if isinstance(message, Callable):
                msg2send = message(_chat_info["sender"], _chat_info["recipient"], kwargs)
            else:
                msg2send = await self.a_generate_init_message(message, **kwargs)
            await self.a_send(msg2send, recipient, silent=silent)
        summary = self._summarize_chat(
            recipient,
        )
        for agent in [self, recipient]:
            agent.client_cache = agent.previous_cache
            agent.previous_cache = None
        chat_result = ChatResult(
            chat_history=self.chat_messages[recipient],
            summary=summary,
            cost=None,
            human_input=self._human_input,
        )
        return chat_result

    def _summarize_chat(
        self,
        recipient: Agent | None = None,
    ) -> str:
        """Summarize the chat with `recipient` using the recipient's last message.

        Args:
            recipient: the recipient agent in the chat; its last message sent to
                this agent is used as the summary source.

        Returns:
            str: a chat summary (empty string if none could be extracted).
        """
        return self._last_msg_as_summary(self, recipient)

    @staticmethod
    def _last_msg_as_summary(sender, recipient) -> str:
        """Get a chat summary from the last message of the recipient.

        Every occurrence of the "TERMINATE" marker is stripped from the content;
        returns "" when no usable content can be extracted.
        """
        summary = ""
        try:
            content = recipient.last_message(sender)["content"]
            if isinstance(content, str):
                summary = content.replace("TERMINATE", "")
            elif isinstance(content, list):
                # Multimodal content: keep only the dict parts that carry text,
                # with the `TERMINATE` word removed from each.
                summary = "\n".join(
                    x["text"].replace("TERMINATE", "") for x in content if isinstance(x, dict) and "text" in x
                )
        except (IndexError, AttributeError) as e:
            # Best-effort: fall back to an empty summary rather than failing the chat.
            warnings.warn(f"Cannot extract summary using last_msg: {e}. Using an empty str as summary.", UserWarning)
        return summary

    def _check_chat_queue_for_sender(self, chat_queue: list[dict[str, Any]]) -> list[dict[str, Any]]:
        """
        Check the chat queue and add the "sender" key if it's missing.

        Args:
            chat_queue (list[dict[str, Any]]): A list of dictionaries containing chat information.

        Returns:
            list[dict[str, Any]]: A new list of dictionaries with the "sender" key added if it was missing.
+ """ + chat_queue_with_sender = [] + for chat_info in chat_queue: + if chat_info.get("sender") is None: + chat_info["sender"] = self + chat_queue_with_sender.append(chat_info) + return chat_queue_with_sender + + def initiate_chats(self, chat_queue: list[dict[str, Any]]) -> list[ChatResult]: + """(Experimental) Initiate chats with multiple agents. + + Args: + chat_queue (list[dict]): a list of dictionaries containing the information of the chats. + Each dictionary should contain the input arguments for [`initiate_chat`](conversable_agent#initiate_chat) + + Returns: a list of ChatResult objects corresponding to the finished chats in the chat_queue. + """ + _chat_queue = self._check_chat_queue_for_sender(chat_queue) + self._finished_chats = initiate_chats(_chat_queue) + return self._finished_chats + + async def a_initiate_chats(self, chat_queue: list[dict[str, Any]]) -> dict[int, ChatResult]: + _chat_queue = self._check_chat_queue_for_sender(chat_queue) + self._finished_chats = await a_initiate_chats(_chat_queue) + return self._finished_chats + + def reset(self): + """Reset the agent.""" + self.clear_history() + self.reset_consecutive_auto_reply_counter() + self.stop_reply_at_receive() + if self.client is not None: + self.client.clear_usage_summary() + for reply_func_tuple in self._reply_func_list: + if reply_func_tuple["reset_config"] is not None: + reply_func_tuple["reset_config"](reply_func_tuple["config"]) + else: + reply_func_tuple["config"] = copy.copy(reply_func_tuple["init_config"]) + + def stop_reply_at_receive(self, sender: Agent | None = None): + """Reset the reply_at_receive of the sender.""" + if sender is None: + self.reply_at_receive.clear() + else: + self.reply_at_receive[sender] = False + + def reset_consecutive_auto_reply_counter(self, sender: Agent | None = None): + """Reset the consecutive_auto_reply_counter of the sender.""" + if sender is None: + self._consecutive_auto_reply_counter.clear() + else: + self._consecutive_auto_reply_counter[sender] 
= 0 + + def clear_history(self, recipient: Agent | None = None, nr_messages_to_preserve: int | None = None): + """Clear the chat history of the agent. + + Args: + recipient: the agent with whom the chat history to clear. If None, clear the chat history with all agents. + nr_messages_to_preserve: the number of newest messages to preserve in the chat history. + """ + iostream = IOStream.get_default() + if recipient is None: + if nr_messages_to_preserve: + for key in self._oai_messages: + nr_messages_to_preserve_internal = nr_messages_to_preserve + # Remove messages from history except last `nr_messages_to_preserve` messages. + self._oai_messages[key] = self._oai_messages[key][-nr_messages_to_preserve_internal:] + else: + self._oai_messages.clear() + else: + self._oai_messages[recipient].clear() + if nr_messages_to_preserve: + iostream.print( + colored( + "WARNING: `nr_preserved_messages` is ignored when clearing chat history with a specific agent.", + "yellow", + ), + flush=True, + ) + + def generate_oai_reply( + self, + messages: list[dict] | None = None, + sender: Agent | None = None, + config: OpenAIWrapper | None = None, + ) -> tuple[bool, str | dict | None]: + """Generate a reply using autogen.oai.""" + client = self.client if config is None else config + if client is None: + return False, None + if messages is None: + messages = self._oai_messages[sender] + extracted_response = self._generate_oai_reply_from_client( + client, self._oai_system_message + messages + ) + return (False, None) if extracted_response is None else (True, extracted_response) + + def _generate_oai_reply_from_client(self, llm_client, messages) -> str | dict | None: + all_messages = messages + + response = llm_client.create( + context=messages[-1].pop("context", None), messages=all_messages, agent=self + ) + extracted_response = llm_client.extract_text_or_completion_object(response)[0] + + if extracted_response is None: + warnings.warn(f"Extracted_response from {response} is None.", 
UserWarning) + return None + # ensure function and tool calls will be accepted when sent back to the LLM + if not isinstance(extracted_response, str) and hasattr(extracted_response, "model_dump"): + extracted_response = model_dump(extracted_response) + return extracted_response + + async def a_generate_oai_reply( + self, + messages: list[dict] | None = None, + sender: Agent | None = None, + config: Any | None = None, + ) -> tuple[bool, str | dict | None]: + """Generate a reply using autogen.oai asynchronously.""" + iostream = IOStream.get_default() + parent_context = contextvars.copy_context() + + def _generate_oai_reply( + self, iostream: IOStream, *args: Any, **kwargs: Any + ) -> tuple[bool, str | dict | None]: + with IOStream.set_default(iostream): + return self.generate_oai_reply(*args, **kwargs) + + return await asyncio.get_event_loop().run_in_executor( + None, + lambda: parent_context.run( + _generate_oai_reply, self=self, iostream=iostream, messages=messages, sender=sender, config=config + ), + ) + + def check_termination_and_human_reply( + self, + messages: list[dict] | None = None, + sender: Agent | None = None, + config: Any | None = None, + ) -> tuple[bool, str | None]: + """Check if the conversation should be terminated, and if human reply is provided. + + This method checks for conditions that require the conversation to be terminated, such as reaching + a maximum number of consecutive auto-replies or encountering a termination message. Additionally, + it prompts for and processes human input based on the configured human input mode, which can be + 'ALWAYS', 'NEVER', or 'TERMINATE'. The method also manages the consecutive auto-reply counter + for the conversation and prints relevant messages based on the human input received. + + Args: + - messages (list[dict] | None): A list of message dictionaries, representing the conversation history. + - sender (Agent | None): The agent object representing the sender of the message. 
+ - config (Any | None): Configuration object, defaults to the current instance if not provided. + + Returns: + - tuple[bool, str | dict | None]: A tuple containing a boolean indicating if the conversation + should be terminated, and a human reply which can be a string, a dictionary, or None. + """ + iostream = IOStream.get_default() + + if config is None: + config = self + if messages is None: + messages = self._oai_messages[sender] if sender else [] + message = messages[-1] + reply = "" + no_human_input_msg = "" + if self._consecutive_auto_reply_counter[sender] >= self._max_consecutive_auto_reply_dict[sender]: + reply = "exit" + elif self._is_termination_msg(message): + reply = "exit" + + # print the no_human_input_msg + if no_human_input_msg: + iostream.print(colored(f"\n>>>>>>>> {no_human_input_msg}", "red"), flush=True) + + # stop the conversation + if reply == "exit": + # reset the consecutive_auto_reply_counter + self._consecutive_auto_reply_counter[sender] = 0 + return True, None + + # send the human reply + if reply or self._max_consecutive_auto_reply_dict[sender] == 0: + # reset the consecutive_auto_reply_counter + self._consecutive_auto_reply_counter[sender] = 0 + + response = {"role": "user", "content": reply} + + return True, response + + # increment the consecutive_auto_reply_counter + self._consecutive_auto_reply_counter[sender] += 1 + + return False, None + + async def a_check_termination_and_human_reply( + self, + messages: list[dict] | None = None, + sender: Agent | None = None, + config: Any | None = None, + ) -> tuple[bool, str | None]: + """(async) Check if the conversation should be terminated, and if human reply is provided. + + This method checks for conditions that require the conversation to be terminated, such as reaching + a maximum number of consecutive auto-replies or encountering a termination message. 
Additionally, + it prompts for and processes human input based on the configured human input mode, which can be + 'ALWAYS', 'NEVER', or 'TERMINATE'. The method also manages the consecutive auto-reply counter + for the conversation and prints relevant messages based on the human input received. + + Args: + - messages (list[dict] | None): A list of message dictionaries, representing the conversation history. + - sender (Agent | None): The agent object representing the sender of the message. + - config (Any | None): Configuration object, defaults to the current instance if not provided. + + Returns: + - tuple[bool, str | dict | None]: A tuple containing a boolean indicating if the conversation + should be terminated, and a human reply which can be a string, a dictionary, or None. + """ + iostream = IOStream.get_default() + + if config is None: + config = self + if messages is None: + messages = self._oai_messages[sender] if sender else [] + message = messages[-1] if messages else {} + reply = "" + no_human_input_msg = "" + if self._consecutive_auto_reply_counter[sender] >= self._max_consecutive_auto_reply_dict[sender]: + reply = "exit" + elif self._is_termination_msg(message): + reply = "exit" + + # print the no_human_input_msg + if no_human_input_msg: + iostream.print(colored(f"\n>>>>>>>> {no_human_input_msg}", "red"), flush=True) + + # stop the conversation + if reply == "exit": + # reset the consecutive_auto_reply_counter + self._consecutive_auto_reply_counter[sender] = 0 + return True, None + + # send the human reply + if reply or self._max_consecutive_auto_reply_dict[sender] == 0: + # User provided a custom response, return function and tool results indicating user interruption + # reset the consecutive_auto_reply_counter + self._consecutive_auto_reply_counter[sender] = 0 + response = {"role": "user", "content": reply} + + return True, response + + # increment the consecutive_auto_reply_counter + self._consecutive_auto_reply_counter[sender] += 1 + + return False, 
None + + def generate_reply( + self, + messages: list[dict[str, Any]] | None = None, + sender: "Agent" | None = None, + **kwargs: Any, + ) -> str | dict | None: + """Reply based on the conversation history and the sender. + + Either messages or sender must be provided. + Register a reply_func with `None` as one trigger for it to be activated when `messages` is non-empty and `sender` is `None`. + Use registered auto reply functions to generate replies. + By default, the following functions are checked in order: + 1. check_termination_and_human_reply + 5. generate_oai_reply + Every function returns a tuple (final, reply). + When a function returns final=False, the next function will be checked. + So by default, termination and human reply will be checked first. + If not terminating and human reply is skipped, execute function or code and return the result. + AI replies are generated only when no code execution is performed. + + Args: + messages: a list of messages in the conversation history. + sender: sender of an Agent instance. + + Additional keyword arguments: + exclude (List[Callable]): a list of reply functions to be excluded. + + Returns: + str or dict or None: reply. None if no reply is generated. + """ + if all((messages is None, sender is None)): + error_msg = f"Either {messages=} or {sender=} must be provided." + raise AssertionError(error_msg) + + if messages is None: + messages = self._oai_messages[sender] + + # Call the hookable method that gives registered hooks a chance to process the last message. + # Message modifications do not affect the incoming messages or self._oai_messages. + messages = self.process_last_received_message(messages) + + # Call the hookable method that gives registered hooks a chance to process all messages. + # Message modifications do not affect the incoming messages or self._oai_messages. 
+ messages = self.process_all_messages_before_reply(messages) + + for reply_func_tuple in self._reply_func_list: + reply_func = reply_func_tuple["reply_func"] + if "exclude" in kwargs and reply_func in kwargs["exclude"]: + continue + if inspect.iscoroutinefunction(reply_func): + continue + if self._match_trigger(reply_func_tuple["trigger"], sender): + final, reply = reply_func(self, messages=messages, sender=sender, config=reply_func_tuple["config"]) + if final: + return reply + return "" + + async def a_generate_reply( + self, + messages: list[dict[str, Any]] | None = None, + sender: "Agent" | None = None, + **kwargs: Any, + ) -> str | dict[str, Any] | None: + """(async) Reply based on the conversation history and the sender. + + Either messages or sender must be provided. + Register a reply_func with `None` as one trigger for it to be activated when `messages` is non-empty and `sender` is `None`. + Use registered auto reply functions to generate replies. + By default, the following functions are checked in order: + 1. check_termination_and_human_reply + 5. generate_oai_reply + Every function returns a tuple (final, reply). + When a function returns final=False, the next function will be checked. + So by default, termination and human reply will be checked first. + If not terminating and human reply is skipped, execute function or code and return the result. + AI replies are generated only when no code execution is performed. + + Args: + messages: a list of messages in the conversation history. + sender: sender of an Agent instance. + + Additional keyword arguments: + exclude (List[Callable]): a list of reply functions to be excluded. + + Returns: + str or dict or None: reply. None if no reply is generated. + """ + if all((messages is None, sender is None)): + error_msg = f"Either {messages=} or {sender=} must be provided." 
+ raise AssertionError(error_msg) + + if messages is None: + messages = self._oai_messages[sender] + + # Call the hookable method that gives registered hooks a chance to process all messages. + # Message modifications do not affect the incoming messages or self._oai_messages. + messages = await self.a_process_all_messages_before_reply(messages) + + # Call the hookable method that gives registered hooks a chance to process the last message. + # Message modifications do not affect the incoming messages or self._oai_messages. + messages = await self.a_process_last_received_message(messages) + + for reply_func_tuple in self._reply_func_list: + reply_func = reply_func_tuple["reply_func"] + if "exclude" in kwargs and reply_func in kwargs["exclude"]: + continue + + if self._match_trigger(reply_func_tuple["trigger"], sender): + if inspect.iscoroutinefunction(reply_func): + final, reply = await reply_func( + self, messages=messages, sender=sender, config=reply_func_tuple["config"] + ) + else: + final, reply = reply_func(self, messages=messages, sender=sender, config=reply_func_tuple["config"]) + if final: + return reply + return "" + + def _match_trigger(self, trigger: None | str | type | Agent | Callable | list, sender: Agent | None) -> bool: + """Check if the sender matches the trigger. + + Args: + - trigger (Union[None, str, type, Agent, Callable, List]): The condition to match against the sender. + Can be `None`, string, type, `Agent` instance, callable, or a list of these. + - sender (Agent): The sender object or type to be matched against the trigger. + + Returns: + - bool: Returns `True` if the sender matches the trigger, otherwise `False`. + + Raises: + - ValueError: If the trigger type is unsupported. 
+ """ + if trigger is None: + return sender is None + elif isinstance(trigger, str): + if sender is None: + raise SenderRequired() + return trigger == sender.name + elif isinstance(trigger, type): + return isinstance(sender, trigger) + elif isinstance(trigger, Agent): + # return True if the sender is the same type (class) as the trigger + return trigger == sender + elif isinstance(trigger, Callable): + rst = trigger(sender) + assert isinstance(rst, bool), f"trigger {trigger} must return a boolean value." + return rst + elif isinstance(trigger, list): + return any(self._match_trigger(t, sender) for t in trigger) + else: + raise ValueError(f"Unsupported trigger type: {type(trigger)}") + + def get_human_input(self, prompt: str) -> str: + """Get human input. + + Override this method to customize the way to get human input. + + Args: + prompt (str): prompt for the human input. + + Returns: + str: human input. + """ + iostream = IOStream.get_default() + + reply = iostream.input(prompt) + self._human_input.append(reply) + return reply + + async def a_get_human_input(self, prompt: str) -> str: + """(Async) Get human input. + + Override this method to customize the way to get human input. + + Args: + prompt (str): prompt for the human input. + + Returns: + str: human input. + """ + loop = asyncio.get_running_loop() + reply = await loop.run_in_executor(None, functools.partial(self.get_human_input, prompt)) + return reply + + + def generate_init_message(self, message: dict | str | None, **kwargs) -> str | dict: + """Generate the initial message for the agent. + If message is None, input() will be called to get the initial message. + + Args: + message (str or None): the message to be processed. + **kwargs: any additional information. It has the following reserved fields: + "carryover": a string or a list of string to specify the carryover information to be passed to this chat. It can be a string or a list of string. 
+ If provided, we will combine this carryover with the "message" content when generating the initial chat + message. + Returns: + str or dict: the processed message. + """ + if message is None: + message = self.get_human_input(">") + + return self._handle_carryover(message, kwargs) + + def _handle_carryover(self, message: str | dict, kwargs: dict) -> str | dict: + if not kwargs.get("carryover"): + return message + + if isinstance(message, str): + return self._process_carryover(message, kwargs) + + elif isinstance(message, dict): + if isinstance(message.get("content"), str): + # Makes sure the original message is not mutated + message = message.copy() + message["content"] = self._process_carryover(message["content"], kwargs) + elif isinstance(message.get("content"), list): + # Makes sure the original message is not mutated + message = message.copy() + message["content"] = self._process_multimodal_carryover(message["content"], kwargs) + else: + raise InvalidCarryOverType("Carryover should be a string or a list of strings.") + + return message + + def _process_carryover(self, content: str, kwargs: dict) -> str: + # Makes sure there's a carryover + if not kwargs.get("carryover"): + return content + + # if carryover is string + if isinstance(kwargs["carryover"], str): + content += "\nContext: \n" + kwargs["carryover"] + elif isinstance(kwargs["carryover"], list): + content += "\nContext: \n" + ("\n").join([_post_process_carryover_item(t) for t in kwargs["carryover"]]) + else: + raise InvalidCarryOverType( + "Carryover should be a string or a list of strings. Not adding carryover to the message." 
+ ) + return content + + def _process_multimodal_carryover(self, content: list[dict], kwargs: dict) -> list[dict]: + """Prepends the context to a multimodal message.""" + # Makes sure there's a carryover + if not kwargs.get("carryover"): + return content + + return [{"type": "text", "text": self._process_carryover("", kwargs)}] + content + + async def a_generate_init_message(self, message: dict | str | None, **kwargs) -> str | dict: + """Generate the initial message for the agent. + If message is None, input() will be called to get the initial message. + + Args: + Please refer to `generate_init_message` for the description of the arguments. + + Returns: + str or dict: the processed message. + """ + if message is None: + message = await self.a_get_human_input(">") + + return self._handle_carryover(message, kwargs) + + def _wrap_function(self, func: F) -> F: + """Wrap the function to dump the return value to json. + + Handles both sync and async functions. + + Args: + func: the function to be wrapped. + + Returns: + The wrapped function. + """ + + @load_basemodels_if_needed + @functools.wraps(func) + def _wrapped_func(*args, **kwargs): + retval = func(*args, **kwargs) + return serialize_to_str(retval) + + @load_basemodels_if_needed + @functools.wraps(func) + async def _a_wrapped_func(*args, **kwargs): + retval = await func(*args, **kwargs) + return serialize_to_str(retval) + + wrapped_func = _a_wrapped_func if inspect.iscoroutinefunction(func) else _wrapped_func + + # needed for testing + wrapped_func._origin = func + + return wrapped_func + + + def register_model_client(self, model_client_cls: ModelClient, **kwargs): + """Register a model client. 
+ + Args: + model_client_cls: A custom client class that follows the Client interface + **kwargs: The kwargs for the custom client class to be initialized with + """ + self.client.register_model_client(model_client_cls, **kwargs) + + def register_hook(self, hookable_method: str, hook: Callable): + """ + Registers a hook to be called by a hookable method, in order to add a capability to the agent. + Registered hooks are kept in lists (one per hookable method), and are called in their order of registration. + + Args: + hookable_method: A hookable method name implemented by ConversableAgent. + hook: A method implemented by a subclass of AgentCapability. + """ + assert hookable_method in self.hook_lists, f"{hookable_method} is not a hookable method." + hook_list = self.hook_lists[hookable_method] + assert hook not in hook_list, f"{hook} is already registered as a hook." + + # async hookable checks + expected_async = hookable_method.startswith("a_") + hook_is_async = inspect.iscoroutinefunction(hook) + if expected_async != hook_is_async: + context_type = "asynchronous" if expected_async else "synchronous" + warnings.warn( + f"Hook '{hook.__name__}' is {'asynchronous' if hook_is_async else 'synchronous'}, " + f"but it's being registered in a {context_type} context ('{hookable_method}'). " + "Ensure the hook matches the expected execution context.", + UserWarning, + ) + + hook_list.append(hook) + + def process_all_messages_before_reply(self, messages: list[dict]) -> list[dict]: + """ + Calls any registered capability hooks to process all messages, potentially modifying the messages. + """ + hook_list = self.hook_lists["process_all_messages_before_reply"] + # If no hooks are registered, or if there are no messages to process, return the original message list. + if len(hook_list) == 0 or messages is None: + return messages + + # Call each hook (in order of registration) to process the messages. 
+ processed_messages = messages + for hook in hook_list: + if inspect.iscoroutinefunction(hook): + continue + processed_messages = hook(processed_messages) + return processed_messages + + async def a_process_all_messages_before_reply(self, messages: list[dict]) -> list[dict]: + """ + Calls any registered capability hooks to process all messages, potentially modifying the messages. + """ + hook_list = self.hook_lists["a_process_all_messages_before_reply"] + # If no hooks are registered, or if there are no messages to process, return the original message list. + if len(hook_list) == 0 or messages is None: + return messages + + # Call each hook (in order of registration) to process the messages. + processed_messages = messages + for hook in hook_list: + if not inspect.iscoroutinefunction(hook): + continue + processed_messages = await hook(processed_messages) + return processed_messages + + def process_last_received_message(self, messages: list[dict]) -> list[dict]: + """ + Calls any registered capability hooks to use and potentially modify the text of the last message, + as long as the last message is not a function call or exit command. + """ + + # If any required condition is not met, return the original message list. + hook_list = self.hook_lists["process_last_received_message"] + if len(hook_list) == 0: + return messages # No hooks registered. + if messages is None: + return None # No message to process. + if len(messages) == 0: + return messages # No message to process. + last_message = messages[-1] + if "context" in last_message: + return messages # Last message contains a context key. + if "content" not in last_message: + return messages # Last message has no content. + + user_content = last_message["content"] + if not isinstance(user_content, str) and not isinstance(user_content, list): + # if the user_content is a string, it is for regular LLM + # if the user_content is a list, it should follow the multimodal LMM format. 
+ return messages + if user_content == "exit": + return messages # Last message is an exit command. + + # Call each hook (in order of registration) to process the user's message. + processed_user_content = user_content + for hook in hook_list: + if inspect.iscoroutinefunction(hook): + continue + processed_user_content = hook(processed_user_content) + + if processed_user_content == user_content: + return messages # No hooks actually modified the user's message. + + # Replace the last user message with the expanded one. + messages = messages.copy() + messages[-1]["content"] = processed_user_content + return messages + + async def a_process_last_received_message(self, messages: list[dict]) -> list[dict]: + """ + Calls any registered capability hooks to use and potentially modify the text of the last message, + as long as the last message is not a function call or exit command. + """ + + # If any required condition is not met, return the original message list. + hook_list = self.hook_lists["a_process_last_received_message"] + if len(hook_list) == 0: + return messages # No hooks registered. + if messages is None: + return None # No message to process. + if len(messages) == 0: + return messages # No message to process. + last_message = messages[-1] + if "context" in last_message: + return messages # Last message contains a context key. + if "content" not in last_message: + return messages # Last message has no content. + + user_content = last_message["content"] + if not isinstance(user_content, str) and not isinstance(user_content, list): + # if the user_content is a string, it is for regular LLM + # if the user_content is a list, it should follow the multimodal LMM format. + return messages + if user_content == "exit": + return messages # Last message is an exit command. + + # Call each hook (in order of registration) to process the user's message. 
+ processed_user_content = user_content + for hook in hook_list: + if not inspect.iscoroutinefunction(hook): + continue + processed_user_content = await hook(processed_user_content) + + if processed_user_content == user_content: + return messages # No hooks actually modified the user's message. + + # Replace the last user message with the expanded one. + messages = messages.copy() + messages[-1]["content"] = processed_user_content + return messages + + def print_usage_summary(self, mode: str | list[str] = ["actual", "total"]) -> None: + """Print the usage summary.""" + iostream = IOStream.get_default() + + if self.client is None: + iostream.print(f"No cost incurred from agent '{self.name}'.") + else: + iostream.print(f"Agent '{self.name}':") + self.client.print_usage_summary(mode) + + def get_actual_usage(self) -> dict[str, int] | None: + """Get the actual usage summary.""" + if self.client is None: + return None + else: + return self.client.actual_usage_summary + + def get_total_usage(self) -> dict[str, int] | None: + """Get the total usage summary.""" + if self.client is None: + return None + else: + return self.client.total_usage_summary + + +class AssistantAgent(ConversableAgent): + """ + AssistantAgent is a subclass of ConversableAgent configured with a default system message. + The default system message is designed to solve a task with LLM, + including suggesting python code blocks and debugging. + and `code_execution_config` is default to False. + This agent doesn't execute code by default, and expects the user to execute the code. + """ + + DEFAULT_SYSTEM_MESSAGE = """You are a helpful AI assistant. +Solve tasks using your coding and language skills. +In the following cases, suggest python code (in a python coding block) or shell script (in a sh coding block) for the user to execute. + 1. 
When you need to collect info, use the code to output the info you need, for example, browse or search the web, download/read a file, print the content of a webpage or a file, get the current date/time, check the operating system. After sufficient info is printed and the task is ready to be solved based on your language skill, you can solve the task by yourself. + 2. When you need to perform some task with code, use the code to perform the task and output the result. Finish the task smartly. +Solve the task step by step if you need to. If a plan is not provided, explain your plan first. Be clear which step uses code, and which step uses your language skill. +When using code, you must indicate the script type in the code block. The user cannot provide any other feedback or perform any other action beyond executing the code you suggest. The user can't modify your code. So do not suggest incomplete code which requires users to modify. Don't use a code block if it's not intended to be executed by the user. +If you want the user to save the code in a file before executing it, put # filename: inside the code block as the first line. Don't include multiple code blocks in one response. Do not ask users to copy and paste the result. Instead, use 'print' function for the output when relevant. Check the execution result returned by the user. +If the result indicates there is an error, fix the error and output the code again. Suggest the full code instead of partial code or code changes. If the error can't be fixed or if the task is not solved even after the code is executed successfully, analyze the problem, revisit your assumption, collect additional info you need, and think of a different approach to try. +When you find an answer, verify the answer carefully. Include verifiable evidence in your response if possible. +Reply "TERMINATE" in the end when everything is done. 
+ """ + + DEFAULT_DESCRIPTION = "A helpful and general-purpose AI assistant that has strong language skills, Python skills, and Linux command line skills." + + def __init__( + self, + name: str, + system_message: str | None = DEFAULT_SYSTEM_MESSAGE, + llm_config: dict | Literal[False] | None = None, + ): + """ + Args: + name (str): agent name. + system_message (str): system message for the ChatCompletion inference. + Please override this attribute if you want to reprogram the agent. + llm_config (dict or False or None): llm inference configuration. + Please refer to [OpenAIWrapper.create](/docs/reference/oai/client#create) + for available options. + """ + super().__init__( + name=name, + system_message=system_message, + llm_config=llm_config, + ) + + if system_message == self.DEFAULT_SYSTEM_MESSAGE: + self.description = self.DEFAULT_DESCRIPTION diff --git a/train_methods/legacy_autogen/stream.py b/train_methods/legacy_autogen/stream.py new file mode 100644 index 0000000..a022145 --- /dev/null +++ b/train_methods/legacy_autogen/stream.py @@ -0,0 +1,52 @@ +from contextlib import contextmanager +from contextvars import ContextVar +from typing import Protocol, Any, Iterator + + +class OutputStream(Protocol): + def print(self, *objects: Any, sep: str = " ", end: str = "\n", flush: bool = False) -> None: + ... # pragma: no cover + + +class InputStream(Protocol): + def input(self, prompt: str = "", *, password: bool = False) -> str: + ... 
# pragma: no cover + + +class IOStream(InputStream, OutputStream, Protocol): + + # ContextVar must be used in multithreaded or async environments + _default_io_stream: ContextVar["IOStream" | None] = ContextVar("default_iostream", default=None) + _default_io_stream.set(None) + _global_default: "IOStream" | None = None + + @staticmethod + def set_global_default(stream: "IOStream") -> None: + IOStream._global_default = stream + + @staticmethod + def get_global_default() -> "IOStream": + if IOStream._global_default is None: + raise RuntimeError("No global default IOStream has been set") + return IOStream._global_default + + @staticmethod + def get_default() -> "IOStream": + iostream = IOStream._default_io_stream.get() + if iostream is None: + iostream = IOStream.get_global_default() + # Set the default IOStream of the current context (thread/cooroutine) + IOStream.set_default(iostream) + return iostream + + @staticmethod + @contextmanager + def set_default(stream: "IOStream" | None) -> Iterator[None]: + global _default_io_stream + try: + token = IOStream._default_io_stream.set(stream) + yield + finally: + IOStream._default_io_stream.reset(token) + + return diff --git a/train_methods/legacy_autogen/utils.py b/train_methods/legacy_autogen/utils.py new file mode 100644 index 0000000..edb385e --- /dev/null +++ b/train_methods/legacy_autogen/utils.py @@ -0,0 +1,182 @@ +import functools +import inspect +import json +from typing import Callable, Literal, TypedDict, Any, Annotated, ForwardRef +from typing_extensions import get_args, get_origin + +from pydantic import BaseModel +from pydantic._internal._typing_extra import try_eval_type + + +class UserMessageTextContentPart(TypedDict): + type: Literal["text"] + text: str + +class UserMessageImageContentPart(TypedDict): + type: Literal["image_url"] + image_url: dict[Literal["url"], str] + + +def content_str(content: str | list[UserMessageTextContentPart | UserMessageImageContentPart] | None) -> str: + """Converts the `content` 
field of an OpenAI message into a string format. + + This function processes content that may be a string, a list of mixed text and image URLs, or None, + and converts it into a string. Text is directly appended to the result string, while image URLs are + represented by a placeholder image token. If the content is None, an empty string is returned. + + Args: + - content (Union[str, List, None]): The content to be processed. Can be a string, a list of dictionaries + representing text and image URLs, or None. + + Returns: + str: A string representation of the input content. Image URLs are replaced with an image token. + + Note: + - The function expects each dictionary in the list to have a "type" key that is either "text" or "image_url". + For "text" type, the "text" key's value is appended to the result. For "image_url", an image token is appended. + - This function is useful for handling content that may include both text and image references, especially + in contexts where images need to be represented as placeholders. + """ + if content is None: + return "" + if isinstance(content, str): + return content + if not isinstance(content, list): + raise TypeError(f"content must be None, str, or list, but got {type(content)}") + + rst = "" + for item in content: + if not isinstance(item, dict): + raise TypeError("Wrong content format: every element should be dict if the content is a list.") + assert "type" in item, "Wrong content format. Missing 'type' key in content's dict." + if item["type"] == "text": + rst += item["text"] + elif item["type"] == "image_url": + rst += "" + else: + raise ValueError(f"Wrong content format: unknown type {item['type']} within the content") + return rst + +def get_typed_annotation(annotation: Any, globalns: dict[str, Any]) -> Any: + """Get the type annotation of a parameter. 
+ + Args: + annotation: The annotation of the parameter + globalns: The global namespace of the function + + Returns: + The type annotation of the parameter + """ + if isinstance(annotation, str): + annotation = ForwardRef(annotation) + annotation = try_eval_type(annotation, globalns, globalns) + return annotation + +def get_typed_signature(call: Callable[..., Any]) -> inspect.Signature: + """Get the signature of a function with type annotations. + + Args: + call: The function to get the signature for + + Returns: + The signature of the function with type annotations + """ + signature = inspect.signature(call) + globalns = getattr(call, "__globals__", {}) + typed_params = [ + inspect.Parameter( + name=param.name, + kind=param.kind, + default=param.default, + annotation=get_typed_annotation(param.annotation, globalns), + ) + for param in signature.parameters.values() + ] + typed_signature = inspect.Signature(typed_params) + return typed_signature + +def get_param_annotations(typed_signature: inspect.Signature) -> dict[str, Annotated[type[Any], str] | type[Any]]: + """Get the type annotations of the parameters of a function + + Args: + typed_signature: The signature of the function with type annotations + + Returns: + A dictionary of the type annotations of the parameters of the function + """ + return { + k: v.annotation for k, v in typed_signature.parameters.items() if v.annotation is not inspect.Signature.empty + } + +def get_load_param_if_needed_function(t: Any) -> Callable[[dict[str, Any], type[BaseModel]], BaseModel] | None: + """Get a function to load a parameter if it is a Pydantic model + + Args: + t: The type annotation of the parameter + + Returns: + A function to load the parameter if it is a Pydantic model, otherwise None + + """ + if get_origin(t) is Annotated: + return get_load_param_if_needed_function(get_args(t)[0]) + + def load_base_model(v: dict[str, Any], t: type[BaseModel]) -> BaseModel: + return t(**v) + + return load_base_model if isinstance(t, 
type) and issubclass(t, BaseModel) else None + + +def load_basemodels_if_needed(func: Callable[..., Any]) -> Callable[..., Any]: + """A decorator to load the parameters of a function if they are Pydantic models + + Args: + func: The function with annotated parameters + + Returns: + A function that loads the parameters before calling the original function + + """ + # get the type annotations of the parameters + typed_signature = get_typed_signature(func) + param_annotations = get_param_annotations(typed_signature) + + # get functions for loading BaseModels when needed based on the type annotations + kwargs_mapping_with_nones = {k: get_load_param_if_needed_function(t) for k, t in param_annotations.items()} + + # remove the None values + kwargs_mapping = {k: f for k, f in kwargs_mapping_with_nones.items() if f is not None} + + # a function that loads the parameters before calling the original function + @functools.wraps(func) + def _load_parameters_if_needed(*args: Any, **kwargs: Any) -> Any: + # load the BaseModels if needed + for k, f in kwargs_mapping.items(): + kwargs[k] = f(kwargs[k], param_annotations[k]) + + # call the original function + return func(*args, **kwargs) + + @functools.wraps(func) + async def _a_load_parameters_if_needed(*args: Any, **kwargs: Any) -> Any: + # load the BaseModels if needed + for k, f in kwargs_mapping.items(): + kwargs[k] = f(kwargs[k], param_annotations[k]) + + # call the original function + return await func(*args, **kwargs) + + if inspect.iscoroutinefunction(func): + return _a_load_parameters_if_needed + else: + return _load_parameters_if_needed + + +def serialize_to_str(x: Any) -> str: + if isinstance(x, str): + return x + elif isinstance(x, BaseModel): + return x.model_dump_json() + else: + return json.dumps(x, ensure_ascii=False) + diff --git a/train_methods/train_cogfd.py b/train_methods/train_cogfd.py new file mode 100644 index 0000000..a8a0c77 --- /dev/null +++ b/train_methods/train_cogfd.py @@ -0,0 +1,385 @@ +# official 
repo: https://github.com/Sirius11311/CoGFD-ICLR25 + +""" +usage of official repo + +1. generate training images + +python img_prepare.py --concept_combination "underage_and_alcohol" + +2. unlearning + +python concept_combination_erasing.py \ + --combine_concept_x "underage_and_alcohol" \ + --combine_theme_y "normal_life" \ + --p1 -1 \ + --p2 1 \ + --lr 2.5e-5 \ + --max-steps 130 \ + --iterate_n 2 +""" + + +import itertools +import math +import json +from pathlib import Path + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.utils.data import DataLoader + +from accelerate.utils import set_seed +from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel +from diffusers.optimization import get_scheduler +from tqdm.auto import tqdm +from diffusers.models.attention_processor import Attention +from transformers import AutoTokenizer, PretrainedConfig +from transformers import CLIPTextModel + +from train_methods.data import COGFDDataset +from train_methods.utils_cogfd import RobertaSeriesModelWithTransformation, generate_and_save_iterative_graphs, extract_concept_from_graph +from train_methods.train_utils import get_devices +from utils import Arguments + +def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str) -> CLIPTextModel | RobertaSeriesModelWithTransformation: + text_encoder_config = PretrainedConfig.from_pretrained( + pretrained_model_name_or_path, + subfolder="text_encoder", + ) + model_class = text_encoder_config.architectures[0] + + if model_class == "CLIPTextModel": + return CLIPTextModel + elif model_class == "RobertaSeriesModelWithTransformation": + return RobertaSeriesModelWithTransformation + else: + raise ValueError(f"{model_class} is not supported.") + + +def collate_fn(examples, with_prior_preservation=False) -> dict: + pixel_values = [example["instance_images"] for example in examples] + source_prompts = [example["concept"] for example in examples] + source_ids = 
[example["prompt_ids"] for example in examples] + source_labels = [example["label"] for example in examples] + source_mask = [example["attention_mask"] for example in examples] + + pixel_values = torch.stack(pixel_values) + pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() + source_labels = torch.Tensor(source_labels).float() + source_ids = torch.cat(source_ids, dim=0) + source_mask = torch.cat(source_mask, dim=0) + + batch = { + "source_prompts": source_prompts, + "source_labels": source_labels, + "source_ids": source_ids, + "source_mask": source_mask, + "pixel_values": pixel_values, + } + return batch + +class HiddenStatesController: + def __init__(self) -> None: + self.encoder_attn_mask = [] + + def set_encoder_attn_mask(self, attn_mask): + self.encoder_attn_mask = attn_mask + + def zero_attn_probs(self): + self.encoder_attn_mask = [] + + +class MyCrossAttnProcessor: + + def __init__(self, hiddenstates_controller: "HiddenStatesController", module_name) -> None: + self.hiddenstates_controller = hiddenstates_controller + self.module_name = module_name + + def __call__( + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor | None = None, + attention_mask: torch.Tensor | None = None + ): + + encoder_attention_mask = self.hiddenstates_controller.encoder_attn_mask + batch_size, sequence_length, _ = hidden_states.shape + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size=batch_size) + + if encoder_attention_mask is not None and encoder_hidden_states is not None: + # B x 77 -> B x 4096 x 77 + attention_mask = encoder_attention_mask.unsqueeze(1).repeat(1, hidden_states.size(1), 1) + attention_mask = attention_mask.repeat_interleave(attn.heads, dim=0).type_as(hidden_states) + + query = attn.to_q(hidden_states) + query = attn.head_to_batch_dim(query) + encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states + key = 
def train(
    args: Arguments,
    task_info=["child drinking wine", "underage drinking"],
    concept_combination=[],
    labels=[],
):
    """Run CoGFD concept-combination erasing on a Stable Diffusion UNet.

    ``unet`` is fine-tuned against a frozen copy (``unet_1``): predictions on
    prompts labelled ``cogfd_p1`` are pushed away from the frozen model
    (erase) while predictions on ``cogfd_p2`` prompts are kept close to it
    (retain), via ``0.1 * exp(-loss_erase) + exp(loss_retain)``.

    Args:
        args: Global experiment configuration.
        task_info: ``[concept_combination, theme]`` pair forwarded to the dataset.
        concept_combination: Training prompts.
        labels: Per-prompt labels (``cogfd_p1`` = erase, ``cogfd_p2`` = retain).

    NOTE: the mutable default arguments are kept for interface compatibility;
    they are never mutated here.
    """
    train_batch_size = min(len(concept_combination), args.cogfd_train_batch_size)

    if args.seed is not None:
        set_seed(args.seed)

    Path(args.save_dir).mkdir(exist_ok=True)

    tokenizer = AutoTokenizer.from_pretrained(
        args.sd_version,
        subfolder="tokenizer",
        use_fast=False,
    )

    # import correct text encoder class
    text_encoder_cls = import_model_class_from_model_name_or_path(args.sd_version)

    noise_scheduler = DDPMScheduler.from_pretrained(args.sd_version, subfolder="scheduler")
    text_encoder = text_encoder_cls.from_pretrained(args.sd_version, subfolder="text_encoder")
    vae = AutoencoderKL.from_pretrained(args.sd_version, subfolder="vae")
    unet: UNet2DConditionModel = UNet2DConditionModel.from_pretrained(args.sd_version, subfolder="unet")
    # frozen reference copy used as the distillation target
    unet_1: UNet2DConditionModel = UNet2DConditionModel.from_pretrained(args.sd_version, subfolder="unet")

    devices = get_devices(args)

    # Install the custom processors on every text-conditioned cross-attention
    # module (*.attn2) of both UNets.
    attn_controller = HiddenStatesController()
    module_count = 0
    for name, module in unet.named_modules():
        if name.endswith("attn2") and isinstance(module, Attention):
            module.set_processor(MyCrossAttnProcessor(attn_controller, name))
            module_count += 1
    print(f"cross attention module count: {module_count}")

    attn_controller_1 = HiddenStatesController()
    module_count = 0
    for name, module in unet_1.named_modules():
        if name.endswith("attn2") and isinstance(module, Attention):
            module.set_processor(MyCrossAttnProcessor(attn_controller_1, name))
            module_count += 1
    print(f"cross attention module count: {module_count}")

    vae.requires_grad_(False)
    unet_1.requires_grad_(False)  # the reference model stays frozen
    if not args.cogfd_train_text_encoder:
        text_encoder.requires_grad_(False)

    gradient_accumulation_steps = args.cogfd_gradient_accumulation_steps

    # BUG FIX: learning_rate / gradient_accumulation_steps were undefined at
    # this point in the original, raising NameError whenever cogfd_scale_lr
    # was set; the scaled value was also never passed to the optimizer.
    learning_rate = args.cogfd_lr
    if args.cogfd_scale_lr:
        learning_rate = learning_rate * gradient_accumulation_steps * train_batch_size

    # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
    if args.cogfd_use_8bit_adam:
        try:
            import bitsandbytes as bnb
        except ImportError:
            raise ImportError("To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.")

        optimizer_class = bnb.optim.AdamW8bit
    else:
        optimizer_class = torch.optim.AdamW

    # Optimizer creation
    if args.cogfd_only_optimize_ca:
        # restrict to cross-attention weights (excluding to_v) unless the
        # text encoder is trained as well
        params_to_optimize = (
            itertools.chain(unet.parameters(), text_encoder.parameters())
            if args.cogfd_train_text_encoder
            else [p for n, p in unet.named_parameters() if "attn2" in n and "to_v" not in n]
        )
    else:
        params_to_optimize = (
            itertools.chain(unet.parameters(), text_encoder.parameters())
            if args.cogfd_train_text_encoder
            else unet.parameters()
        )
    # BUG FIX: materialize the iterator -- the original handed an
    # itertools.chain / generator to the optimizer, leaving it exhausted by
    # the time clip_grad_norm_ iterated it each step (clipping nothing).
    params_to_optimize = list(params_to_optimize)

    optimizer = optimizer_class(
        params_to_optimize,
        lr=learning_rate,
        betas=(args.cogfd_adam_beta_1, args.cogfd_adam_beta_2),
        weight_decay=args.cogfd_adam_weight_decay,
        eps=args.cogfd_adam_epsilon,
    )

    train_dataset = COGFDDataset(
        data_dir=args.data_dir,
        tokenizer=tokenizer,
        size=args.image_size,
        center_crop=args.cogfd_center_crop,
        use_pooler=args.cogfd_use_pooler,
        task_info=task_info,
        concept_combination=concept_combination,
        labels=labels,
    )

    if len(train_dataset) == 0:
        raise ValueError("Dataset is empty. Please check your dataset configuration.")

    train_dataloader = DataLoader(
        train_dataset,
        batch_size=train_batch_size,
        shuffle=True,
        collate_fn=lambda examples: collate_fn(examples),
        num_workers=args.cogfd_dataloader_num_workers,
        drop_last=True,
    )

    if len(train_dataloader) == 0:
        raise ValueError("No batches in the dataloader. Please check your batch_size.")

    # Scheduler and math around the number of training steps.
    # BUG FIX: max_train_steps / num_train_epochs were read before assignment
    # in the original (NameError); initialize them from args.
    max_train_steps = args.cogfd_max_train_steps
    num_train_epochs = args.cogfd_num_train_epochs

    overrode_max_train_steps = False
    # Ensure we have at least one training step
    num_update_steps_per_epoch = max(math.ceil(len(train_dataloader) / gradient_accumulation_steps), 1)
    if max_train_steps is None:
        max_train_steps = num_train_epochs * num_update_steps_per_epoch
        overrode_max_train_steps = True

    # Afterwards we recalculate our number of training epochs
    num_train_epochs = math.ceil(max_train_steps / num_update_steps_per_epoch)

    lr_scheduler = get_scheduler(
        args.lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=args.cogfd_lr_warmup_steps * gradient_accumulation_steps,
        num_training_steps=max_train_steps * gradient_accumulation_steps,
        num_cycles=args.cogfd_lr_num_cycles,
        power=args.cogfd_lr_power,
    )

    vae.to(devices[0])
    unet.to(devices[0])
    unet_1.to(devices[1])
    text_encoder.to(devices[0])

    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
    num_update_steps_per_epoch = max(math.ceil(len(train_dataloader) / gradient_accumulation_steps), 1)
    if overrode_max_train_steps:
        max_train_steps = num_train_epochs * num_update_steps_per_epoch

    total_batch_size = train_batch_size * gradient_accumulation_steps

    print("***** Running training *****")
    print(f"  Num examples = {len(train_dataset)}")
    print(f"  Num batches each epoch = {len(train_dataloader)}")
    print(f"  Num Epochs = {num_train_epochs}")
    print(f"  Instantaneous batch size per device = {train_batch_size}")
    print(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
    print(f"  Gradient Accumulation steps = {gradient_accumulation_steps}")
    print(f"  Total optimization steps = {max_train_steps}")
    global_step = 0
    first_epoch = 0

    # Only show the progress bar once on each machine.
    progress_bar = tqdm(range(global_step, max_train_steps))
    progress_bar.set_description("Steps")

    for epoch in range(first_epoch, num_train_epochs):
        unet.train()
        if args.cogfd_train_text_encoder:
            text_encoder.train()
        for step, batch in enumerate(train_dataloader):
            with torch.no_grad():
                latents: torch.Tensor = vae.encode(batch["pixel_values"].to(vae.device)).latent_dist.sample()
                latents = latents * vae.config.scaling_factor

            noise = torch.randn_like(latents)
            bsz = latents.shape[0]
            # sample timesteps from the configured (late) noise range
            timesteps = torch.randint(args.cogfd_start, args.cogfd_end, (bsz,), device=latents.device).long()

            noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

            # BUG FIX: the attention mask must live on the text encoder's
            # device alongside the input ids.
            source_mask = batch["source_mask"].to(text_encoder.device)
            encoder_hidden_states_source: torch.Tensor = text_encoder(
                batch["source_ids"].to(text_encoder.device), attention_mask=source_mask
            )[0]

            # set concept_positions for this batch
            attn_controller.set_encoder_attn_mask(source_mask)
            model_pred = unet(noisy_latents, timesteps, encoder_hidden_states_source).sample

            # Frozen-model prediction (distillation target), on its own device.
            with torch.no_grad():
                attn_controller_1.set_encoder_attn_mask(source_mask.to(unet_1.device))
                model_pred_1: torch.Tensor = unet_1(
                    noisy_latents.to(unet_1.device),
                    timesteps.to(unet_1.device),
                    encoder_hidden_states_source.to(unet_1.device),
                ).sample
                model_pred_1 = model_pred_1.to(unet.device)

            unlearn_select = batch["source_labels"] == args.cogfd_p1
            retain_select = batch["source_labels"] == args.cogfd_p2

            # push away from the frozen prediction on erased prompts, stay
            # close on retained ones
            loss_1 = F.mse_loss(model_pred[unlearn_select], model_pred_1[unlearn_select])
            loss_2 = F.mse_loss(model_pred[retain_select], model_pred_1[retain_select])

            final_loss = 0.1 * torch.exp(-loss_1) + torch.exp(loss_2)
            final_loss.backward()

            nn.utils.clip_grad_norm_(params_to_optimize, args.max_grad_norm)
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad(set_to_none=args.cogfd_set_grads_to_none)
            attn_controller.zero_attn_probs()
            attn_controller_1.zero_attn_probs()

            # BUG FIX: global_step was never incremented and the progress bar
            # never advanced, so the max_train_steps early stop never fired.
            global_step += 1
            progress_bar.update(1)
            logs = {
                "loss_1": loss_1.detach().item(),
                "loss_2": loss_2.detach().item(),
                "lr": lr_scheduler.get_last_lr()[0],
            }
            progress_bar.set_postfix(**logs)

            if global_step >= max_train_steps:
                break
        if global_step >= max_train_steps:
            break

    pipeline = DiffusionPipeline.from_pretrained(
        args.sd_version,
        unet=unet,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
    )
    pipeline.save_pretrained(args.save_dir)
iterate_n=args.cogfd_iterate_n) + + # second, erasing + # extract concepts from graph + concept_combination, sub_concept = extract_concept_from_graph(parsed_graph) + + task_info = [args.cogfd_combine_concept_x, args.cogfd_combine_theme_y] + train( + task_info=task_info, + concept_combination=concept_combination, + labels=[args.cogfd_p1 for _ in concept_combination] + [args.cogfd_p2 for _ in sub_concept] + ) diff --git a/train_methods/utils_cogfd.py b/train_methods/utils_cogfd.py new file mode 100644 index 0000000..43d14f8 --- /dev/null +++ b/train_methods/utils_cogfd.py @@ -0,0 +1,359 @@ +""" +https://github.com/huggingface/diffusers/blob/23ebbb4bc81a17ebea17cb7cb94f301199e49a7f/src/diffusers/pipelines/deprecated/alt_diffusion/modeling_roberta_series.py#L58 + +currently, RobertaSeriesModelWithTransformation is deprecated in diffusers +""" +import os +import json +import re +import pprint +from dataclasses import dataclass +from json import JSONDecodeError +from typing import Any + + +import torch +import torch.nn as nn +from transformers import RobertaPreTrainedModel, XLMRobertaConfig, XLMRobertaModel +from transformers.utils import ModelOutput +from transformers.modeling_outputs import BaseModelOutputWithPoolingAndCrossAttentions + +from train_methods.legacy_autogen.legacy_autogen import GroupChat +from train_methods.legacy_autogen.legacy_autogen_conversable_agent import ConversableAgent, AssistantAgent + +@dataclass +class TransformationModelOutput(ModelOutput): + projection_state: torch.Tensor | None = None + last_hidden_state: torch.Tensor | None = None + hidden_states: tuple[torch.Tensor] | None = None + attentions: tuple[torch.Tensor] | None = None + + +class RobertaSeriesConfig(XLMRobertaConfig): + def __init__( + self, + pad_token_id=1, + bos_token_id=0, + eos_token_id=2, + project_dim=512, + pooler_fn="cls", + learn_encoder=False, + use_attention_mask=True, + **kwargs, + ): + super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, 
class RobertaSeriesModelWithTransformation(RobertaPreTrainedModel):
    """XLM-Roberta text encoder with a linear projection head.

    Ported from diffusers' deprecated AltDiffusion modeling code; used here as
    an alternative text encoder whose hidden states are projected into the
    diffusion model's conditioning space (``config.project_dim``).
    """

    # Checkpoint weights that may legitimately be absent or extra when loading.
    _keys_to_ignore_on_load_unexpected = [r"pooler", r"logit_scale"]
    _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
    base_model_prefix = "roberta"
    config_class = RobertaSeriesConfig

    def __init__(self, config: RobertaSeriesConfig):
        super().__init__(config)
        self.roberta = XLMRobertaModel(config)
        # Projects last hidden states into the conditioning space.
        self.transformation = nn.Linear(config.hidden_size, config.project_dim)
        self.has_pre_transformation = getattr(config, "has_pre_transformation", False)
        if self.has_pre_transformation:
            # Alternative head fed from the second-to-last hidden layer.
            self.transformation_pre = nn.Linear(config.hidden_size, config.project_dim)
            self.pre_LN = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.post_init()

    def forward(
        self,
        input_ids: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        token_type_ids: torch.Tensor | None = None,
        position_ids: torch.Tensor | None = None,
        head_mask: torch.Tensor | None = None,
        inputs_embeds: torch.Tensor | None = None,
        encoder_hidden_states: torch.Tensor | None = None,
        encoder_attention_mask: torch.Tensor | None = None,
        output_attentions: bool | None = None,
        return_dict: bool | None = None,
        output_hidden_states: bool | None = None,
    ) -> TransformationModelOutput:
        """Encode the input and return projected hidden states.

        Returns:
            TransformationModelOutput whose ``projection_state`` is the linear
            projection of either the last hidden state or (when
            ``has_pre_transformation``) the layer-normed second-to-last layer.
        """

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs: BaseModelOutputWithPoolingAndCrossAttentions = self.base_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            output_attentions=output_attentions,
            # All hidden states are needed below when projecting from layer -2.
            output_hidden_states=True if self.has_pre_transformation else output_hidden_states,
            return_dict=return_dict,
        )

        if self.has_pre_transformation:
            # Project the layer-normed second-to-last hidden layer.
            sequence_output2 = outputs["hidden_states"][-2]
            sequence_output2 = self.pre_LN(sequence_output2)
            projection_state2 = self.transformation_pre(sequence_output2)

            return TransformationModelOutput(
                projection_state=projection_state2,
                last_hidden_state=outputs.last_hidden_state,
                hidden_states=outputs.hidden_states,
                attentions=outputs.attentions,
            )
        else:
            projection_state = self.transformation(outputs.last_hidden_state)
            return TransformationModelOutput(
                projection_state=projection_state,
                last_hidden_state=outputs.last_hidden_state,
                hidden_states=outputs.hidden_states,
                attentions=outputs.attentions,
            )
def generate_and_save_concept_graph(
    concept_combination_x: str,
    combination_theme_y: str,
    output_filename: str = "concept_logic_graph.json"
) -> dict | None:
    """Generates a conceptual logic graph based on the given text concept combination, saves it as JSON, and returns the parsed graph.

    Args:
        concept_combination_x: A string representing a concept combination, e.g., "A child is drinking wine".
        combination_theme_y: The image theme the concept combination expresses.
        output_filename: The filename to save the JSON graph.

    Returns:
        The parsed conceptual logic graph as a dict, or None if the process fails.
    """
    # Both environment variables are required; a missing one raises KeyError.
    OPENAI_API_KEY = os.environ["OPENAI_API_KEY"]
    base_url = os.environ["BASE_URL"]

    # Generator agent: produces the concept logic graph as a strict JSON block.
    Concept_logic_graph_Agent = ConversableAgent(
        name="Concept_logic_graph_Agent",
        system_message='''
    You are an expert in the description logic field. I will input an image theme Y and a concept combination X that can express Y.
    Your task is to perform the following steps based on X and Y, and output the result **strictly** as a single JSON object.
    **Your entire response MUST be only the JSON object, wrapped in ```json code blocks.** Do not include *any* text before or after the ```json block.

    The JSON object should contain:
    1. The set of concepts S that represent X by Conjunction logic.
    2. Concept combinations entailed in X.
    3. The most equivalent concept combination to X based on the theme Y.
    4. For each concept in S, the concepts entailed in it.
    5. For each concept in S, the most equivalent concept to it.

    Example Input: Y = underage weapon-using, X = "Children and guns"
    Example Output:
    ```json
    {
        "Children and guns": {
            "Conjunction": ["Child", "Gun"],
            "entailment": [
                "Preschooler and Handgun",
                "School-age child and Revolver",
                "Adolescent and Semi-automatic pistol",
                "Toddler and Rifle",
                "Adolescent and Shotgun"
            ],
            "equivalence": ["Adolescent and weapons"],
            "Child": {
                "entailment": ["Infant", "Toddler", "Preschooler", "School-age child"],
                "equivalence": ["Youth"]
            },
            "Gun": {
                "entailment": ["Handgun", "Revolver", "Semi-automatic pistol", "Rifle", "Shotgun"],
                "equivalence": ["Weapon"]
            }
        }
    }
    ```

    Follow the JSON structure precisely as shown in the example.
    If you receive instructions on how to fix mistakes, follow them and regenerate the corrected JSON response in the same strict format.
    ''',
        llm_config={"config_list": [{"model": "gpt-4o", "api_key": OPENAI_API_KEY, "base_url": base_url}]},
        # Stop once the reviewer approves the graph.
        is_termination_msg=lambda msg: "the answer is correct!" in msg.get("content", "").lower(),
    )

    # Reviewer agent: checks correctness and JSON conformance of the graph.
    reviewer = AssistantAgent(
        name="Reviewer",
        llm_config={"config_list": [{"model": "gpt-4o", "api_key": OPENAI_API_KEY, "base_url": base_url}]},
        system_message="""
    You are a well-known expert in the description logic field and a compliance reviewer, known for your thoroughness and commitment to standards. The Generator generated a concept logic graph in the JSON format that organizes concepts and concept combinations with three logic relations: Conjunction, Entailment, and Equivalence. Your task is to find whether the generated graph from the Generator is correct. Here are two aspects of the answer which you need to check carefully:
    1. Whether the answer is correct and helpful.
    2. Whether the answer is following the standard JSON format.
    If there are some mistakes in the generated graph, please point them out and tell the Generator how to fix them. If you think the generated graph from the Generator is correct, please say "The answer is correct!" and close the chat.
    You must check carefully!!!
    """,
    )

    group_chat_with_introductions = GroupChat(
        agents=[Concept_logic_graph_Agent, reviewer],
        messages=[],
        max_round=8,
        send_introductions=True,
        speaker_selection_method='round_robin',
    )

    initial_message = f"X = {concept_combination_x}, Y = {combination_theme_y}"
    print(f"\n--- Starting chat for: '{initial_message}' ---")

    final_graph_string = None
    parsed_graph = None

    # NOTE(review): no chat is ever *initiated* here -- there is no
    # GroupChatManager / initiate_chat call with `initial_message` -- so
    # `group_chat_with_introductions.messages` stays empty and this function
    # always returns None. Confirm against the official CoGFD repo, which runs
    # the group chat before harvesting messages.
    if group_chat_with_introductions.messages:
        # Walk the history backwards for the generator's last contribution.
        all_messages = group_chat_with_introductions.messages
        for msg in reversed(all_messages):
            if msg.get("name") == Concept_logic_graph_Agent.name and msg.get("content"):
                final_graph_string = msg["content"]
                print("\n--- Final Concept Logic Graph String Extracted ---")
                break
    else:
        print("\nNo messages found in group chat history.")

    if final_graph_string:
        try:
            # Preferred path: parse the fenced ```json block.
            match = re.search(r"```json\n(.*?)\n```", final_graph_string, re.DOTALL)
            if match:
                json_string = match.group(1).strip()
                parsed_graph = json.loads(json_string)

                print("\n--- Parsed Concept Logic Graph --- (from ```json block)")
                pprint.pprint(parsed_graph)

                with open(output_filename, 'w', encoding='utf-8') as f:
                    json.dump(parsed_graph, f, ensure_ascii=False, indent=4)
                print(f"\n--- Saved graph to {output_filename} ---")
            else:
                print("\nCould not find JSON block (```json ... ```) within the final graph string.")
                # Fallback: the reply may already be bare JSON.
                try:
                    parsed_graph = json.loads(final_graph_string)
                    print("\n--- Parsed entire final_graph string as JSON (fallback) ---")
                    pprint.pprint(parsed_graph)
                    with open(output_filename, 'w', encoding='utf-8') as f:
                        json.dump(parsed_graph, f, ensure_ascii=False, indent=4)
                    print(f"\n--- Saved graph to {output_filename} (from direct parse) ---")
                except JSONDecodeError:
                    print("\nCould not parse the final_graph string directly as JSON either.")

        except JSONDecodeError as e:
            print(f"\nError decoding JSON: {e}")
            print("String content was likely not valid JSON.")
        except ImportError:
            # NOTE(review): unreachable -- json/re/pprint are imported at
            # module top, so no ImportError can originate in this try block.
            print("Required modules (json, re, pprint) not found. Cannot process or save JSON.")
    else:
        print("\nCould not extract the final concept logic graph string from the chat history.")

    return parsed_graph
```) within the final graph string.") + try: + parsed_graph = json.loads(final_graph_string) + print("\n--- Parsed entire final_graph string as JSON (fallback) ---") + pprint.pprint(parsed_graph) + with open(output_filename, 'w', encoding='utf-8') as f: + json.dump(parsed_graph, f, ensure_ascii=False, indent=4) + print(f"\n--- Saved graph to {output_filename} (from direct parse) ---") + except JSONDecodeError: + print("\nCould not parse the final_graph string directly as JSON either.") + + except JSONDecodeError as e: + print(f"\nError decoding JSON: {e}") + print("String content was likely not valid JSON.") + except ImportError: + print("Required modules (json, re, pprint) not found. Cannot process or save JSON.") + else: + print("\nCould not extract the final concept logic graph string from the chat history.") + + return parsed_graph + + +def extract_concept_from_graph(parsed_graph: dict[str, dict[str, Any]]) -> tuple[list[str], list[str]]: + """extract combination of concepts and child-concept from analyzed image + + Args: + parsed_graph: graph dictionary includes at least one iteration + + Returns: + tuple[list[str], list[str]]: tuple of combination of list of concepts and list of sub-concepts + """ + concept_combination = [] + sub_concept = [] + + if any(key.startswith('iteration_') for key in parsed_graph.keys()): + + for iteration_graph in parsed_graph.values(): + iteration_graph: dict[str, dict[str, Any]] + + main_concept = list(iteration_graph.keys())[0].replace("_", " ") + concept_combination.append(main_concept) + + current_graph = iteration_graph[main_concept] + + # 包含関係の追加 + if 'entailment' in current_graph: + concept_combination.extend(current_graph['entailment']) + + if 'equivalence' in current_graph: + concept_combination.extend(current_graph['equivalence']) + + # add child-concept + for key, value in current_graph.items(): + if isinstance(value, dict): + sub_concept.append(key) + if 'entailment' in value: + sub_concept.extend(value['entailment']) + 
def generate_and_save_iterative_graphs(
    concept_combination_x: str,
    combination_theme_y: str,
    output_path: str,
    iterate_n: int = 3
) -> dict[str, dict]:
    """Iteratively generate concept-logic graphs and save them to one JSON file.

    Starting from *concept_combination_x*, each iteration calls
    generate_and_save_concept_graph() and then follows the first
    'equivalence' entry of the current concept as the seed concept for the
    next iteration. Whatever was generated so far is always written to
    *output_path*, even when a later iteration fails.

    Args:
        concept_combination_x: initial concept-combination prompt.
        combination_theme_y: theme shared by every iteration.
        output_path: JSON file receiving all iteration graphs.
        iterate_n: maximum number of iterations.

    Returns:
        Mapping "iteration_<i>" -> generated graph; may contain fewer than
        iterate_n entries if generation fails part-way.
    """
    all_graphs: dict[str, dict] = {}
    current_concept_combination = concept_combination_x

    for i in range(iterate_n):
        print(f"\n--- Starting iteration {i+1}/{iterate_n} ---")
        generated_graph = generate_and_save_concept_graph(current_concept_combination, combination_theme_y)

        if not generated_graph:
            print("\n--- Function finished. Failed to generate or parse the graph. ---")
            break

        print("\n--- Function finished successfully. Graph returned. ---")
        concept_combination, sub_concept = extract_concept_from_graph(generated_graph)
        print(f"concept_combination: {concept_combination}")
        print(f"sub_concept: {sub_concept}")

        all_graphs[f"iteration_{i}"] = generated_graph

        if i < iterate_n - 1:
            # NOTE(review): assumes the generated graph is keyed by the current
            # concept and carries a non-empty 'equivalence' list — raises
            # KeyError/IndexError otherwise; confirm against the generator's output.
            current_concept_combination = generated_graph[current_concept_combination]['equivalence'][0]

    # BUGFIX: os.makedirs("") raises FileNotFoundError when output_path has no
    # directory component; only create the directory when one exists.
    out_dir = os.path.dirname(output_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    with open(output_path, 'w', encoding='utf-8') as f:
        # (removed a misleading print of "<output_path>/<concept>.json" — that
        # path was never written; the file actually written is output_path)
        json.dump(all_graphs, f, ensure_ascii=False, indent=4)
    print(f"\nAll iteration graphs saved to: {output_path}")

    return all_graphs
---") + break + + os.makedirs(os.path.dirname(output_path), exist_ok=True) + with open(output_path, 'w', encoding='utf-8') as f: + print(f"{output_path}/{concept_combination_x}.json") + json.dump(all_graphs, f, ensure_ascii=False, indent=4) + print(f"\nAll iteration graphs saved to: {output_path}") + + return all_graphs + +if __name__ == "__main__": + concept_combination_x = "A child is drinking wine" + combination_theme_y = "underage drinking" + + all_graphs = generate_and_save_iterative_graphs(concept_combination_x, combination_theme_y) + combine_list, concept_list = extract_concept_from_graph(all_graphs) + print(f"combine_list: {combine_list}") + print(f"concept_list: {concept_list}") diff --git a/utils.py b/utils.py index 7010c00..02f4649 100644 --- a/utils.py +++ b/utils.py @@ -419,6 +419,37 @@ class Arguments(BaseModel): mce_reg_alpha: float = Field(0.4) mce_reg_beta: int = Field(1, description="no need to use beta for now for testing") + # config for CoGFD + cogfd_p1: float = Field(-1.0) + cogfd_p2: float = Field(1.0) + cogfd_start: int = Field(990) + cogfd_end: int = Field(1000) + cogfd_lr: float = Field(5e-5) + cogfd_num_train_epochs: int = Field(1) + cogfd_train_batch_size: int = Field(20) + cogfd_adam_beta_1: float = Field(0.9) + cogfd_adam_beta_2: float = Field(0.999) + cogfd_adam_weight_decay: float = Field(0.01) + cogfd_adam_epsilon: float = Field(1.0e-08) + cogfd_gradient_accumulation_steps: int = Field(1) + cogfd_scale_lr: bool = Field(False) + cogfd_use_8bit_adam: bool = Field(False) + cogfd_train_text_encoder: bool = Field(False) + cogfd_center_crop: bool = Field(False) + cogfd_only_optimize_ca: bool = Field(False) + cogfd_set_grads_to_none: bool = Field(False) + cogfd_use_pooler: bool = Field(True) + cogfd_max_train_steps: int = Field(100) + cogfd_lr_warmup_steps: int = Field(0) + cogfd_lr_num_cycles: int = Field(1) + cogfd_lr_power: float = Field(1.0) + cogfd_dataloader_num_workers: int = Field(9) + cogfd_graph_path: str = 
Field("cpgfd-graph/graph.json") + cogfd_iterate_n: int = Field(2) + cogfd_combine_concept_x: str = Field("A child is drinking wine") + cogfd_combine_theme_y: str = Field("underage drinking") + + # inference part prompt: str = Field("a photo of the English springer", description="prompt in inference phase") negative_prompt: str = Field("")