From ddba906eecc43ff55ceaa4ea25f6c8bd3fdd2b03 Mon Sep 17 00:00:00 2001 From: fmuuly <64724985+fmp453@users.noreply.github.com> Date: Sun, 19 Oct 2025 20:13:47 +0900 Subject: [PATCH 01/25] add CoGFD (yet completed) --- train_methods/data.py | 94 +++++++- train_methods/train_cogfd.py | 413 ++++++++++++++++++++++++++++++++++ train_methods/utils_cogfd.py | 423 +++++++++++++++++++++++++++++++++++ utils.py | 26 +++ 4 files changed, 955 insertions(+), 1 deletion(-) create mode 100644 train_methods/train_cogfd.py create mode 100644 train_methods/utils_cogfd.py diff --git a/train_methods/data.py b/train_methods/data.py index 7931467..991fe29 100644 --- a/train_methods/data.py +++ b/train_methods/data.py @@ -638,7 +638,7 @@ class TextualInversionDataset(Dataset): def __init__( self, data_root, - tokenizer, + tokenizer: CLIPTokenizer, learnable_property="object", # [object, style] size=512, repeats=100, @@ -744,3 +744,95 @@ def __getitem__(self, i): example["pixel_values"] = torch.from_numpy(image).permute(2, 0, 1) return example + + +class COGFDDataset(Dataset): + def __init__( + self, + data_dir: str, + tokenizer: CLIPTokenizer, + size: int=512, + center_crop=False, + use_pooler=False, + task_info=None, + concept_combination=None, + labels=None + ): + self.use_pooler = use_pooler + self.size = size + self.center_crop = center_crop + self.tokenizer = tokenizer + + if task_info is None or len(task_info) != 2: + raise ValueError("task_info must be a list/tuple of length 2 containing [concept, theme]") + + if concept_combination is None or len(concept_combination) == 0: + raise ValueError("concept_combination cannot be None or empty") + + if labels is None or len(labels) == 0: + raise ValueError("labels cannot be None or empty") + + if len(concept_combination) != len(labels): + raise ValueError(f"Length mismatch: concept_combination ({len(concept_combination)}) != labels ({len(labels)})") + + self.instance_images_path = [] + self.instance_prompt = [] + + p = Path(data_dir) + if not p.exists(): + raise ValueError(f"Instance {p} images root doesn't exists.") + + image_paths = list(p.iterdir()) + if len(image_paths) == 0: + raise ValueError(f"No images found in {p}") + + self.instance_images_path += image_paths + + self.prompts = concept_combination + self.labels = labels + + self.num_instance_images = len(self.instance_images_path) + self._length = len(self.prompts) + + self.image_transforms = transforms.Compose([ + transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR), + transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size), + transforms.ToTensor(), + transforms.Normalize([0.5], [0.5]), + ]) + + def __len__(self): + return self._length + + def __getitem__(self, index): + if index >= self._length: + raise IndexError(f"Index {index} out of range for dataset of length {self._length}") + + example = {} + instance_image = Image.open(self.instance_images_path[index % self.num_instance_images]) + concept = self.prompts[index % self._length] + label = self.labels[index % self._length] + + if not instance_image.mode == "RGB": + instance_image = instance_image.convert("RGB") + example["concept"] = concept + example["label"] = label + example["instance_images"] = self.image_transforms(instance_image) + + example["prompt_ids"] = self.tokenizer( + concept, + truncation=True, + padding="max_length", + max_length=self.tokenizer.model_max_length, + return_tensors="pt", + ).input_ids + + example["attention_mask"] = self.tokenizer( + concept, + truncation=True, + padding="max_length", + max_length=self.tokenizer.model_max_length, + return_tensors="pt", + ).attention_mask + + return example diff --git a/train_methods/train_cogfd.py b/train_methods/train_cogfd.py new file mode 100644 index 0000000..465aea7 --- /dev/null +++ b/train_methods/train_cogfd.py @@ -0,0 +1,413 @@ +# official repo: https://github.com/Sirius11311/CoGFD-ICLR25 + +""" +usage of official repo + +1. generate training images + +python img_prepare.py --concept_combination "underage_and_alcohol" + +2. unlearning + +python concept_combination_erasing.py \ + --combine_concept_x "underage_and_alcohol" \ + --combine_theme_y "normal_life" \ + --p1 -1 \ + --p2 1 \ + --lr 2.5e-5 \ + --max-steps 130 \ + --iterate_n 2 +""" + + +import itertools +import math +import json +import os + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.utils.data import DataLoader + +from accelerate.utils import set_seed +from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel +from diffusers.optimization import get_scheduler +from tqdm.auto import tqdm +from diffusers.models.attention_processor import Attention +from transformers import AutoTokenizer, PretrainedConfig +from transformers import CLIPTextModel +import argparse + +from train_methods.data import COGFDDataset +from train_methods.utils_cogfd import RobertaSeriesModelWithTransformation, generate_and_save_iterative_graphs, extract_concept_from_graph +from train_methods.train_utils import get_devices +from utils import Arguments + +def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str) -> CLIPTextModel | RobertaSeriesModelWithTransformation: + text_encoder_config = PretrainedConfig.from_pretrained( + pretrained_model_name_or_path, + subfolder="text_encoder", + ) + model_class = text_encoder_config.architectures[0] + + if model_class == "CLIPTextModel": + return CLIPTextModel + elif model_class == "RobertaSeriesModelWithTransformation": + return RobertaSeriesModelWithTransformation + else: + raise ValueError(f"{model_class} is not supported.") + + +def collate_fn(examples, with_prior_preservation=False) -> dict: + pixel_values = [example["instance_images"] for example in examples] + source_prompts = [example["concept"] for example in examples] + source_ids = [example["prompt_ids"] for example in examples] + source_labels = [example["label"] for example in examples] + source_mask = [example["attention_mask"] for example in examples] + + pixel_values = torch.stack(pixel_values) + pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() + source_labels = torch.Tensor(source_labels).float() + source_ids = torch.cat(source_ids, dim=0) + source_mask = torch.cat(source_mask, dim=0) + + batch = { + "source_prompts": source_prompts, + "source_labels": source_labels, + "source_ids": source_ids, + "source_mask": source_mask, + "pixel_values": pixel_values, + } + return batch + +class HiddenStatesController: + def __init__(self) -> None: + self.encoder_attn_mask = [] + + def set_encoder_attn_mask(self, attn_mask): + self.encoder_attn_mask = attn_mask + + def zero_attn_probs(self): + self.encoder_attn_mask = [] + + +class MyCrossAttnProcessor: + + def __init__(self, hiddenstates_controller: "HiddenStatesController", module_name) -> None: + self.hiddenstates_controller = hiddenstates_controller + self.module_name = module_name + + def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None): + + encoder_attention_mask = self.hiddenstates_controller.encoder_attn_mask + batch_size, sequence_length, _ = hidden_states.shape + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size=batch_size) + + if encoder_attention_mask is not None and encoder_hidden_states is not None: + # B x 77 -> B x 4096 x 77 + attention_mask = encoder_attention_mask.unsqueeze(1).repeat(1, hidden_states.size(1), 1) + attention_mask = attention_mask.repeat_interleave(attn.heads, dim=0).type_as(hidden_states) + + query = attn.to_q(hidden_states) + query = attn.head_to_batch_dim(query) + encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + key = attn.head_to_batch_dim(key) + value = attn.head_to_batch_dim(value) + + attention_probs = attn.get_attention_scores(query, key, attention_mask) + hidden_states = torch.bmm(attention_probs, value) + hidden_states = attn.batch_to_head_dim(hidden_states) + + hidden_states = attn.to_out[0](hidden_states) + hidden_states = attn.to_out[1](hidden_states) + + return hidden_states + + +def train( + args: Arguments, + task_info=["child drinking wine", "underage drinking"], + concept_combination=[], + labels=[], +): + train_batch_size = min(len(concept_combination), args.cogfd_train_batch_size) + + if args.seed is not None: + set_seed(args.seed) + + os.makedirs(args.save_dir, exist_ok=True) + + tokenizer = AutoTokenizer.from_pretrained( + args.sd_version, + subfolder="tokenizer", + use_fast=False, + ) + + # import correct text encoder class + text_encoder_cls = import_model_class_from_model_name_or_path(args.sd_version) + + noise_scheduler = DDPMScheduler.from_pretrained(args.sd_version, subfolder="scheduler") + text_encoder = text_encoder_cls.from_pretrained(args.sd_version, subfolder="text_encoder") + vae = AutoencoderKL.from_pretrained(args.sd_version, subfolder="vae") + unet: UNet2DConditionModel = UNet2DConditionModel.from_pretrained(args.sd_version, subfolder="unet") + unet_1: UNet2DConditionModel = UNet2DConditionModel.from_pretrained(args.sd_version, subfolder="unet") + + # unet_1 on device 1 + devices = get_devices(args)[0] + + attn_controller = HiddenStatesController() + module_count = 0 + for name, module in unet.named_modules(): + if name.endswith('attn2'): + module.set_processor(MyCrossAttnProcessor(attn_controller, name)) + module_count += 1 + print(f"cross attention module count: {module_count}") + + attn_controller_1 = HiddenStatesController() + module_count = 0 + for name, module in unet_1.named_modules(): + if name.endswith('attn2'): + module.set_processor(MyCrossAttnProcessor(attn_controller_1, name)) + module_count += 1 + print(f"cross attention module count: {module_count}") + + vae.requires_grad_(False) + if not args.cogfd_train_text_encoder: + text_encoder.requires_grad_(False) + + if args.cogfd_scale_lr: + learning_rate = (learning_rate * gradient_accumulation_steps * train_batch_size) + + # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs + if args.cogfd_use_8bit_adam: + try: + import bitsandbytes as bnb + except ImportError: + raise ImportError("To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.") + + optimizer_class = bnb.optim.AdamW8bit + else: + optimizer_class = torch.optim.AdamW + + # Optimizer creation + if args.cogfd_only_optimize_ca: + params_to_optimize = (itertools.chain(unet.parameters(), text_encoder.parameters()) if args.cogfd_train_text_encoder else [p for n, p in unet.named_parameters() if 'attn2' in n and 'to_v' not in n]) + else: + params_to_optimize = (itertools.chain(unet.parameters(), text_encoder.parameters()) if args.cogfd_train_text_encoder else unet.parameters()) + + optimizer = optimizer_class( + params_to_optimize, + lr=args.cogfd_lr, + betas=(args.cogfd_adam_beta_1, args.cogfd_adam_beta_2), + weight_decay=args.cogfd_adam_weight_decay, + eps=args.cogfd_adam_epsilon, + ) + + train_dataset = COGFDDataset( + tokenizer=tokenizer, + size=args.image_size, + center_crop=args.cogfd_center_crop, + use_pooler=args.cogfd_use_pooler, + task_info=task_info, + concept_combination=concept_combination, + labels=labels, + ) + + if len(train_dataset) == 0: + raise ValueError("Dataset is empty. Please check your dataset configuration.") + + train_dataloader = DataLoader( + train_dataset, + batch_size=train_batch_size, + shuffle=True, + collate_fn=lambda examples: collate_fn(examples), + num_workers=args.cogfd_dataloader_num_workers, + drop_last=True + ) + + if len(train_dataloader) == 0: + raise ValueError("No batches in the dataloader. Please check your batch_size.") + + + gradient_accumulation_steps = args.cogfd_gradient_accumulation_steps + + # Scheduler and math around the number of training steps. + overrode_max_train_steps = False + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps) + if max_train_steps is None: + max_train_steps = num_train_epochs * num_update_steps_per_epoch + overrode_max_train_steps = True + + # Ensure we have at least one training step + num_update_steps_per_epoch = max(num_update_steps_per_epoch, 1) + + # Afterwards we recalculate our number of training epochs + num_train_epochs = math.ceil(max_train_steps / num_update_steps_per_epoch) + + lr_scheduler = get_scheduler( + args.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=args.cogfd_lr_warmup_steps * gradient_accumulation_steps, + num_training_steps=max_train_steps * gradient_accumulation_steps, + num_cycles=args.cogfd_lr_num_cycles, + power=args.cogfd_lr_power, + ) + + vae.to(devices[0]) + unet.to(devices[0]) + unet_1.to(devices[1]) + text_encoder.to(devices[0]) + + # We need to recalculate our total training steps as the size of the training dataloader may have changed. + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps) + if overrode_max_train_steps: + max_train_steps = num_train_epochs * num_update_steps_per_epoch + + total_batch_size = train_batch_size * gradient_accumulation_steps + + print("***** Running training *****") + print(f" Num examples = {len(train_dataset)}") + print(f" Num batches each epoch = {len(train_dataloader)}") + print(f" Num Epochs = {num_train_epochs}") + print(f" Instantaneous batch size per device = {train_batch_size}") + print(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") + print(f" Gradient Accumulation steps = {gradient_accumulation_steps}") + print(f" Total optimization steps = {max_train_steps}") + global_step = 0 + first_epoch = 0 + + # Only show the progress bar once on each machine. + progress_bar = tqdm(range(global_step, max_train_steps)) + progress_bar.set_description("Steps") + + for epoch in range(first_epoch, num_train_epochs): + unet.train() + if args.cogfd_train_text_encoder: + text_encoder.train() + for step, batch in enumerate(train_dataloader): + + with torch.no_grad(): + latents: torch.Tensor = vae.encode(batch["pixel_values"].to(vae.device)).latent_dist.sample() + latents = latents * vae.config.scaling_factor + + # Sample noise that we'll add to the latents + noise = torch.randn_like(latents) + bsz = latents.shape[0] + # Sample a random timestep for each image + timesteps = torch.randint(args.cogfd_start, args.cogfd_end, (bsz, ), device=latents.device) + timesteps = timesteps.long() + + # Add noise to the latents according to the noise magnitude at each timestep + # (this is the forward diffusion process) + noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) + + # Get the text embedding for conditioning + encoder_hidden_states_source = text_encoder(batch["source_ids"].to(text_encoder.device), attention_mask=batch["source_mask"])[0] + + # set concept_positions for this batch + attn_controller.set_encoder_attn_mask(batch["source_mask"]) + model_pred = unet( + noisy_latents, + timesteps, + encoder_hidden_states_source, + ).sample + + # Predict the noise residual + with torch.no_grad(): + attn_controller_1.set_encoder_attn_mask(batch["source_mask"]) + noisy_latents_1 = noisy_latents.to(unet_1.device) + timesteps_1 = timesteps.to(unet_1.device) + encoder_hidden_states_1 = encoder_hidden_states_source.to(unet_1.device) + + model_pred_1: torch.Tensor = unet_1(noisy_latents_1, timesteps_1, encoder_hidden_states_1).sample + model_pred_1 = model_pred_1.to(unet.device) + + unlearn_select = batch["source_labels"] == args.cogfd_p1 + retain_select = batch["source_labels"] == args.cogfd_p2 + + # Ensure all tensors are on the same device for loss computation + loss_1 = F.mse_loss(model_pred[unlearn_select], model_pred_1[unlearn_select]) + loss_2 = F.mse_loss(model_pred[retain_select], model_pred_1[retain_select]) + + # Compute final loss on the same device + final_loss = 0.1 * torch.exp(-loss_1) + torch.exp(loss_2) + final_loss.backward() + + params_to_clip = params_to_optimize + nn.utils.clip_grad_norm_(params_to_clip, args.max_grad_norm) + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad(set_to_none=args.cogfd_set_grads_to_none) + attn_controller.zero_attn_probs() + attn_controller_1.zero_attn_probs() + + logs = { + "loss_1": loss_1.detach().item(), + "loss_2": loss_2.detach().item(), + "lr": lr_scheduler.get_last_lr()[0] + } + progress_bar.set_postfix(**logs) + + if global_step >= max_train_steps: + break + + pipeline = DiffusionPipeline.from_pretrained( + args.sd_version, + unet=unet, + text_encoder=text_encoder, + tokenizer=tokenizer + ) + pipeline.save_pretrained(args.save_dir) + + +def main(args: Arguments): + # first, generate concept logic graph + # second, erasing + pass + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('--theme', type=str) + parser.add_argument('--combine_concept_x', type=str, default="A child is drinking wine") + parser.add_argument('--combine_theme_y', type=str, default="underage drinking") + parser.add_argument('--iterate_n', type=int, default=1) + + args = parser.parse_args() + combine_concept = args.combine_concept_x + OUTPUT_DIR = "" + LOGICGRAPH_DIR = OUTPUT_DIR + "/concept_logic_graph" + PREPARED_DATA_DIR = OUTPUT_DIR + "/data" + args.prepared_data_dir = PREPARED_DATA_DIR.format(concept_combination=combine_concept) + args.graph_output_dir = LOGICGRAPH_DIR.formt(concept_combination=combine_concept) + + + combine_theme = args.combine_theme_y + task_info = [combine_concept, combine_theme] + + graph_path = os.path.join(args.graph_output_dir, f"{combine_concept}.json") + # generate concept logic graph + try: + with open(graph_path, 'r') as f: + parsed_graph = json.load(f) + except FileNotFoundError: + print(f"File {graph_path} not found. Generating concept logic graph...") + combine_concept_x = args.combine_concept_x.replace("_", " ") + combine_theme_y = args.combine_theme_y.replace("_", " ") + parsed_graph = generate_and_save_iterative_graphs(combine_concept_x, combine_theme_y, graph_path, iterate_n=args.iterate_n) + + + # extract concepts from graph + concept_combination, sub_concept = extract_concept_from_graph(parsed_graph) + + concepts = concept_combination + sub_concept + labels = [args.p1 for i in concept_combination] + [args.p2 for i in sub_concept] + print(concepts) + print(labels) + + train(task_info=task_info, + concept_combination=concepts, + labels=labels, + ) diff --git a/train_methods/utils_cogfd.py b/train_methods/utils_cogfd.py new file mode 100644 index 0000000..b197a3d --- /dev/null +++ b/train_methods/utils_cogfd.py @@ -0,0 +1,423 @@ +""" +https://github.com/huggingface/diffusers/blob/23ebbb4bc81a17ebea17cb7cb94f301199e49a7f/src/diffusers/pipelines/deprecated/alt_diffusion/modeling_roberta_series.py#L58 + +currently, RobertaSeriesModelWithTransformation is deprecated in diffusers +""" +import os +import json +import re +import pprint +from dataclasses import dataclass +from typing import Optional, Any + + +import torch +import autogen +from autogen import ConversableAgent, GroupChat +from torch import nn +from transformers import RobertaPreTrainedModel, XLMRobertaConfig, XLMRobertaModel +from transformers.utils import ModelOutput + + +@dataclass +class TransformationModelOutput(ModelOutput): + """ + Base class for text model's outputs that also contains a pooling of the last hidden states. + + Args: + text_embeds (`torch.Tensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`): + The text embeddings obtained by applying the projection layer to the pooler_output. + last_hidden_state (`torch.Tensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + hidden_states (`tuple(torch.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.Tensor` (one for the output of the embeddings, if the model has an embedding layer, + one + for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + projection_state: Optional[torch.Tensor] = None + last_hidden_state: torch.Tensor = None + hidden_states: Optional[tuple[torch.Tensor]] = None + attentions: Optional[tuple[torch.Tensor]] = None + + +class RobertaSeriesConfig(XLMRobertaConfig): + def __init__( + self, + pad_token_id=1, + bos_token_id=0, + eos_token_id=2, + project_dim=512, + pooler_fn="cls", + learn_encoder=False, + use_attention_mask=True, + **kwargs, + ): + super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs) + self.project_dim = project_dim + self.pooler_fn = pooler_fn + self.learn_encoder = learn_encoder + self.use_attention_mask = use_attention_mask + + +class RobertaSeriesModelWithTransformation(RobertaPreTrainedModel): + _keys_to_ignore_on_load_unexpected = [r"pooler", r"logit_scale"] + _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"] + base_model_prefix = "roberta" + config_class = RobertaSeriesConfig + + def __init__(self, config): + super().__init__(config) + self.roberta = XLMRobertaModel(config) + self.transformation = nn.Linear(config.hidden_size, config.project_dim) + self.has_pre_transformation = getattr(config, "has_pre_transformation", False) + if self.has_pre_transformation: + self.transformation_pre = nn.Linear(config.hidden_size, config.project_dim) + self.pre_LN = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.post_init() + + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + return_dict: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + ): + r""" """ + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.base_model( + input_ids=input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + output_attentions=output_attentions, + output_hidden_states=True if self.has_pre_transformation else output_hidden_states, + return_dict=return_dict, + ) + + if self.has_pre_transformation: + sequence_output2 = outputs["hidden_states"][-2] + sequence_output2 = self.pre_LN(sequence_output2) + projection_state2 = self.transformation_pre(sequence_output2) + + return TransformationModelOutput( + projection_state=projection_state2, + last_hidden_state=outputs.last_hidden_state, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + else: + projection_state = self.transformation(outputs.last_hidden_state) + return TransformationModelOutput( + projection_state=projection_state, + last_hidden_state=outputs.last_hidden_state, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +""" +generating concept logic graph +""" + + +def generate_and_save_concept_graph( + concept_combination_x: str, + combination_theme_y: str, + output_filename: str = "concept_logic_graph.json" +) -> dict | None: + """根据输入的文本概念组合生成概念逻辑图, 保存为JSON并返回解析后的图谱。 + + Args: + concept_combination_x: 形如 "A child is drinking wine" 的概念组合字符串。 + output_filename: 保存 JSON 图谱的文件名。 + + Returns: + 解析后的概念逻辑图 (dict),如果失败则返回 None。 + """ + + + Concept_logic_graph_Agent = ConversableAgent( + name="Concept_logic_graph_Agent", + system_message=''' + You are an expert in the description logic field. I will input an image theme Y and a concept combination X that can express Y. + Your task is to perform the following steps based on X and Y, and output the result **strictly** as a single JSON object. + **Your entire response MUST be only the JSON object, wrapped in ```json code blocks.** Do not include *any* text before or after the ```json block. + + The JSON object should contain: + 1. The set of concepts S that represent X by Conjunction logic. + 2. Concept combinations entailed in X. + 3. The most equivalent concept combination to X based on the theme Y. + 4. For each concept in S, the concepts entailed in it. + 5. For each concept in S, the most equivalent concept to it. + + Example Input: Y = underage weapon-using, X = "Children and guns" + Example Output: + ```json + { + "Children and guns": { + "Conjunction": ["Child", "Gun"], + "entailment": [ + "Preschooler and Handgun", + "School-age child and Revolver", + "Adolescent and Semi-automatic pistol", + "Toddler and Rifle", + "Adolescent and Shotgun" + ], + "equivalence": ["Adolescent and weapons"], + "Child": { + "entailment": ["Infant", "Toddler", "Preschooler", "School-age child"], + "equivalence": ["Youth"] + }, + "Gun": { + "entailment": ["Handgun", "Revolver", "Semi-automatic pistol", "Rifle", "Shotgun"], + "equivalence": ["Weapon"] + } + } + } + ``` + + Follow the JSON structure precisely as shown in the example. + If you receive instructions on how to fix mistakes, follow them and regenerate the corrected JSON response in the same strict format. + ''', + llm_config={"config_list": [{"model": "gpt-4o", "api_key": OPENAI_API_KEY, "base_url": BASE_URL}]}, + is_termination_msg=lambda msg: "the answer is correct!" in msg.get("content", "").lower(), # Use .get for safety + human_input_mode="NEVER", # 设置为 "NEVER" 以避免提示用户输入 + ) + + reviewer = autogen.AssistantAgent( + name="Reviewer", + llm_config={"config_list": [{"model": "gpt-4o", "api_key": OPENAI_API_KEY, "base_url": BASE_URL}]}, + system_message=""" + You are a well-known expert in the description logic field and a compliance reviewer, known for your thoroughness and commitment to standards. The Generator generated a concept logic graph in the JSON format that organizes concepts and concept combinations with three logic relations: Conjunction, Entailment, and Equivalence. Your task is to find whether the generated graph from the Generator is correct. Here are two aspects of the answer which you need to check carefully: + 1. Whether the answer is correct and helpful. + 2. Whether the answer is following the standard JSON format. + If there are some mistakes in the generated graph, please point them out and tell the Generator how to fix them. If you think the generated graph from the Generator is correct, please say "The answer is correct!" and close the chat. + You must check carefully!!! + """, + human_input_mode="NEVER", # 设置为 "NEVER" 以避免提示用户输入 + ) + + # --- 群聊和管理器设置 --- + group_chat_with_introductions = GroupChat( + agents=[Concept_logic_graph_Agent, reviewer], + messages=[], + max_round=8, + send_introductions=True, + speaker_selection_method='round_robin', # 确保轮流发言 + ) + + # --- 启动聊天 --- + # 构建传递给 agent 的消息 + initial_message = f"X = {concept_combination_x}, Y = {combination_theme_y}" + print(f"\n--- Starting chat for: '{initial_message}' ---") + + # Automatically trigger the chat to end after the initial response or based on specific conditions + def auto_end_chat(): + # Trigger to end the conversation after the response is received + print("Automatically ending the conversation.") + return "exit" # or any other appropriate method to end the conversation + + # Call the function after some condition or time has passed + auto_end_chat() + + + # --- 提取、解析和保存结果 --- + final_graph_string = None + parsed_graph = None + + # 检查聊天是否有历史记录 + if group_chat_with_introductions.messages: + all_messages = group_chat_with_introductions.messages + for msg in reversed(all_messages): + if msg.get("name") == Concept_logic_graph_Agent.name and msg.get("content"): + final_graph_string = msg["content"] + print("\n--- Final Concept Logic Graph String Extracted ---") + break + else: + print("\nNo messages found in group chat history.") + + if final_graph_string: + # 尝试从 final_graph_string 中提取 JSON 部分 + try: + match = re.search(r"```json\n(.*?)\n```", final_graph_string, re.DOTALL) + if match: + json_string = match.group(1).strip() + parsed_graph = json.loads(json_string) + + print("\n--- Parsed Concept Logic Graph --- (from ```json block)") + pprint.pprint(parsed_graph) + + # 保存到 JSON 文件 + with open(output_filename, 'w', encoding='utf-8') as f: + json.dump(parsed_graph, f, ensure_ascii=False, indent=4) + print(f"\n--- Saved graph to {output_filename} ---") + else: + print("\nCould not find JSON block (```json ... ```) within the final graph string.") + # 尝试直接解析整个字符串作为备选 + try: + parsed_graph = json.loads(final_graph_string) + print("\n--- Parsed entire final_graph string as JSON (fallback) ---") + pprint.pprint(parsed_graph) + # 也可以在这里保存 + with open(output_filename, 'w', encoding='utf-8') as f: + json.dump(parsed_graph, f, ensure_ascii=False, indent=4) + print(f"\n--- Saved graph to {output_filename} (from direct parse) ---") + except json.JSONDecodeError: + print("\nCould not parse the final_graph string directly as JSON either.") + + except json.JSONDecodeError as e: + print(f"\nError decoding JSON: {e}") + print("String content was likely not valid JSON.") + except ImportError: + print("Required modules (json, re, pprint) not found. Cannot process or save JSON.") + else: + print("\nCould not extract the final concept logic graph string from the chat history.") + + return parsed_graph + + +def extract_concept_from_graph(parsed_graph: dict[str, Any]) -> tuple[list[str], list[str]]: + """从解析的图谱中提取概念组合和子概念。 + + Args: + parsed_graph: 包含一个或多个迭代的图谱字典 + + Returns: + tuple[List[str], List[str]]: 包含概念组合列表和子概念列表的元组 + """ + concept_combination = [] + sub_concept = [] + + # 检查是否是迭代格式的图谱 + if any(key.startswith('iteration_') for key in parsed_graph.keys()): + # 处理迭代格式 + for iteration_graph in parsed_graph.values(): + # 获取当前迭代的主要概念 + main_concept = list(iteration_graph.keys())[0].replace("_", " ") + concept_combination.append(main_concept) + + # 处理当前迭代的图谱 + current_graph = iteration_graph[main_concept] + + # 添加蕴含关系 + if 'entailment' in current_graph: + concept_combination.extend(current_graph['entailment']) + + # 添加等价关系 + if 'equivalence' in current_graph: + concept_combination.extend(current_graph['equivalence']) + + # 添加子概念 + for key, value in current_graph.items(): + if isinstance(value, dict): + sub_concept.append(key) + if 'entailment' in value: + sub_concept.extend(value['entailment']) + if 'equivalence' in value: + sub_concept.extend(value['equivalence']) + else: + # 处理单个图谱格式 + main_concept = list(parsed_graph.keys())[0].replace("_", " ") + concept_combination.append(main_concept) + + # 添加蕴含关系 + if 'entailment' in parsed_graph[main_concept]: + concept_combination.extend(parsed_graph[main_concept]['entailment']) + + # 添加等价关系 + if 'equivalence' in parsed_graph[main_concept]: + concept_combination.extend(parsed_graph[main_concept]['equivalence']) + + # 添加子概念 + for key, value in parsed_graph[main_concept].items(): + if isinstance(value, dict): + sub_concept.append(key) + if 'entailment' in value: + sub_concept.extend(value['entailment']) + if 'equivalence' in value: + sub_concept.extend(value['equivalence']) + + # 去重并返回 + return list(set(concept_combination)), list(set(sub_concept)) + +def generate_and_save_iterative_graphs( + concept_combination_x: str, + combination_theme_y: str, + output_path: str, + iterate_n: int = 3 +) -> dict[str, Any]: + """生成并保存迭代的概念图谱。 + + Args: + concept_combination_x: 初始概念组合 + combination_theme_y: 主题 + iterate_n: 迭代次数, 默认为3 + output_dir: 输出目录路径 + + Returns: + dict[str, Any]: 包含所有迭代图谱的字典 + """ + all_graphs = {} # 用于存储所有迭代生成的graph + current_concept_combination = concept_combination_x + + for i in range(iterate_n): + print(f"\n--- Starting iteration {i+1}/{iterate_n} ---") + generated_graph = generate_and_save_concept_graph(current_concept_combination, combination_theme_y) + + if generated_graph: + print("\n--- Function finished successfully. Graph returned. ---") + concept_combination, sub_concept = extract_concept_from_graph(generated_graph) + print(f"concept_combination: {concept_combination}") + print(f"sub_concept: {sub_concept}") + + # 将当前迭代的graph添加到all_graphs中 + all_graphs[f"iteration_{i}"] = generated_graph + + # 更新下一个迭代的概念 + if i < iterate_n - 1: # 如果不是最后一次迭代 + current_concept_combination = generated_graph[current_concept_combination]['equivalence'][0] + else: + print("\n--- Function finished. Failed to generate or parse the graph. ---") + break + + # 保存所有迭代的graph到JSON文件 + print(output_path) + os.makedirs(os.path.dirname(output_path), exist_ok=True) + with open(output_path, 'w', encoding='utf-8') as f: + print(output_path+f"/{concept_combination_x}.json") + json.dump(all_graphs, f, ensure_ascii=False, indent=4) + print(f"\nAll iteration graphs saved to: {output_path}") + + return all_graphs + +# --- 主执行块 --- # +if __name__ == "__main__": + concept_combination_x = "A child is drinking wine" + combination_theme_y = "underage drinking" + + # 使用新函数生成迭代图谱 + all_graphs = generate_and_save_iterative_graphs(concept_combination_x, combination_theme_y) + combine_list, concept_list = extract_concept_from_graph(all_graphs) + print(f"combine_list: {combine_list}") + print(f"concept_list: {concept_list}") \ No newline at end of file diff --git a/utils.py b/utils.py index 24127ae..c203826 100644 --- a/utils.py +++ b/utils.py @@ -356,6 +356,32 @@ class Arguments(BaseModel): ant_mask_path: Optional[str] = Field(None) ant_if_gradient: Optional[bool] = Field(True) + # config for CoGFD + cogfd_p1: Optional[float] = Field(-1.0) + cogfd_p2: Optional[float] = Field(1.0) + cogfd_start: Optional[int] = Field(990) + cogfd_end: Optional[int] = Field(1000) + cogfd_lr: Optional[float] = Field(5e-5) + cogfd_num_train_epochs: Optional[int] = Field(1) + cogfd_train_batch_size: Optional[int] = Field(20) + cogfd_adam_beta_1: Optional[float] = Field(0.9) + cogfd_adam_beta_2: Optional[float] = Field(0.999) + cogfd_adam_weight_decay: Optional[float] = Field(0.01) + cogfd_adam_epsilon: Optional[float] = Field(1.0e-08) + cogfd_gradient_accumulation_steps: Optional[int] = Field(1) + cogfd_scale_lr: Optional[bool] = Field(False) + cogfd_use_8bit_adam: Optional[bool] = Field(False) + cogfd_train_text_encoder: Optional[bool] = Field(False) + cogfd_center_crop: Optional[bool] = Field(False) + cogfd_only_optimize_ca: Optional[bool] = Field(False) + cogfd_set_grads_to_none: Optional[bool] = Field(False) + cogfd_use_pooler: Optional[bool] = Field(True) + cogfd_max_train_steps: Optional[int] = Field(100) + cogfd_lr_warmup_steps: Optional[int] = Field(0) + cogfd_lr_num_cycles: Optional[int] = Field(1) + cogfd_lr_power: Optional[float] = Field(1.0) + cogfd_dataloader_num_workers: Optional[int] = Field(9) + # inference part prompt: Optional[str] = Field("a photo of the English springer", description="prompt in inference phase") negative_prompt: Optional[str] = Field("") From fcd481563435963aa3b3ef960dde6200f9f772f2 Mon Sep 17 00:00:00 2001 From: fmuuly <64724985+fmp453@users.noreply.github.com> Date: Sat, 25 Oct 2025 19:35:25 +0900 Subject: [PATCH 02/25] add autogen ver 0.2.0 methods --- requirements.txt | 1 + train_methods/data.py | 2 +- train_methods/legacy_autogen.py | 1675 +++++++++ .../legacy_autogen_conversable_agent.py | 3130 +++++++++++++++++ train_methods/train_ac.py | 4 +- train_methods/train_cogfd.py | 10 +- train_methods/utils_cogfd.py | 81 +- 7 files changed, 4842 insertions(+), 61 deletions(-) create mode 100644 train_methods/legacy_autogen.py create mode 100644 train_methods/legacy_autogen_conversable_agent.py diff --git a/requirements.txt b/requirements.txt index a1b2aae..e606e3c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,6 +8,7 @@ yapf==0.40.2 matplotlib==3.9.2 pydantic==2.9.2 scikit-learn==1.5.2 +termcolor==3.1.0 open_clip_torch==2.29.0 bitsandbytes==0.44.1 diff --git a/train_methods/data.py b/train_methods/data.py index 991fe29..1e250c8 100644 --- a/train_methods/data.py +++ b/train_methods/data.py @@ -804,7 +804,7 @@ def __init__( def __len__(self): return self._length - def __getitem__(self, index): + def __getitem__(self, index) -> dict: if index >= self._length: raise IndexError(f"Index {index} out of range for dataset of length {self._length}") diff --git a/train_methods/legacy_autogen.py b/train_methods/legacy_autogen.py new file mode 100644 index 0000000..94e2da1 --- /dev/null +++ b/train_methods/legacy_autogen.py @@ -0,0 +1,1675 @@ +"""Legacy autogen (ver 2.0) for cogfd + +""" +import json +import sys +import logging +import random +import re +from copy import deepcopy +from contextlib import contextmanager +from contextvars import ContextVar +from dataclasses import dataclass, field +from typing import Any, Callable, Literal, Union, Protocol, TypedDict, Iterator + +from termcolor import colored + +from train_methods.legacy_autogen_conversable_agent import ConversableAgent, Agent + +logger = logging.getLogger(__name__) + + +class OutputStream(Protocol): + def print(self, *objects: Any, sep: str = " ", end: str = "\n", flush: bool = False) -> None: + """Print data to the output stream. + + Args: + objects (any): The data to print. + sep (str, optional): The separator between objects. Defaults to " ". + end (str, optional): The end of the output. Defaults to "\n". + flush (bool, optional): Whether to flush the output. Defaults to False. + """ + ... # pragma: no cover + + +class InputStream(Protocol): + def input(self, prompt: str = "", *, password: bool = False) -> str: + """Read a line from the input stream. + + Args: + prompt (str, optional): The prompt to display. Defaults to "". + password (bool, optional): Whether to read a password. Defaults to False. + + Returns: + str: The line read from the input stream. + + """ + ... # pragma: no cover + + +class IOStream(InputStream, OutputStream, Protocol): + """A protocol for input/output streams.""" + + # ContextVar must be used in multithreaded or async environments + _default_io_stream: ContextVar["IOStream" | None] = ContextVar("default_iostream", default=None) + _default_io_stream.set(None) + _global_default: "IOStream" | None = None + + @staticmethod + def set_global_default(stream: "IOStream") -> None: + """Set the default input/output stream. + + Args: + stream (IOStream): The input/output stream to set as the default. + """ + IOStream._global_default = stream + + @staticmethod + def get_global_default() -> "IOStream": + """Get the default input/output stream. + + Returns: + IOStream: The default input/output stream. + """ + if IOStream._global_default is None: + raise RuntimeError("No global default IOStream has been set") + return IOStream._global_default + + @staticmethod + def get_default() -> "IOStream": + """Get the default input/output stream. + + Returns: + IOStream: The default input/output stream. + """ + iostream = IOStream._default_io_stream.get() + if iostream is None: + iostream = IOStream.get_global_default() + # Set the default IOStream of the current context (thread/cooroutine) + IOStream.set_default(iostream) + return iostream + + @staticmethod + @contextmanager + def set_default(stream: "IOStream" | None) -> Iterator[None]: + """Set the default input/output stream. + + Args: + stream (IOStream): The input/output stream to set as the default. + """ + global _default_io_stream + try: + token = IOStream._default_io_stream.set(stream) + yield + finally: + IOStream._default_io_stream.reset(token) + + return + +class UserMessageTextContentPart(TypedDict): + type: Literal["text"] + text: str + +class UserMessageImageContentPart(TypedDict): + type: Literal["image_url"] + image_url: dict[Literal["url"], str] + +def content_str(content: str | list[UserMessageTextContentPart| UserMessageImageContentPart] | None) -> str: + """Converts the `content` field of an OpenAI message into a string format. + + This function processes content that may be a string, a list of mixed text and image URLs, or None, + and converts it into a string. Text is directly appended to the result string, while image URLs are + represented by a placeholder image token. If the content is None, an empty string is returned. + + Args: + - content (Union[str, List, None]): The content to be processed. Can be a string, a list of dictionaries representing text and image URLs, or None. + + Returns: + str: A string representation of the input content. Image URLs are replaced with an image token. + + Note: + - The function expects each dictionary in the list to have a "type" key that is either "text" or "image_url". + For "text" type, the "text" key's value is appended to the result. For "image_url", an image token is appended. + - This function is useful for handling content that may include both text and image references, especially + in contexts where images need to be represented as placeholders. + """ + if content is None: + return "" + if isinstance(content, str): + return content + if not isinstance(content, list): + raise TypeError(f"content must be None, str, or list, but got {type(content)}") + + rst = "" + for item in content: + if not isinstance(item, dict): + raise TypeError("Wrong content format: every element should be dict if the content is a list.") + assert "type" in item, "Wrong content format. Missing 'type' key in content's dict." + if item["type"] == "text": + rst += item["text"] + elif item["type"] == "image_url": + rst += "" + else: + raise ValueError(f"Wrong content format: unknown type {item['type']} within the content") + return rst + +class AgentNameConflict(Exception): + def __init__(self, msg: str = "Found multiple agents with the same name.", *args: Any, **kwargs: Any): + super().__init__(msg, *args, **kwargs) + +class NoEligibleSpeaker(Exception): + """Exception raised for early termination of a GroupChat.""" + + def __init__(self, message: str = "No eligible speakers."): + self.message = message + super().__init__(self.message) + +class UndefinedNextAgent(Exception): + """Exception raised when the provided next agents list does not overlap with agents in the group.""" + + def __init__(self, message: str = "The provided agents list does not overlap with agents in the group."): + self.message = message + super().__init__(self.message) + + +class ModelClient(Protocol): + """ + A client class must implement the following methods: + - create must return a response object that implements the ModelClientResponseProtocol + - cost must return the cost of the response + - get_usage must return a dict with the following keys: + - prompt_tokens + - completion_tokens + - total_tokens + - cost + - model + + This class is used to create a client that can be used by OpenAIWrapper. + The response returned from create must adhere to the ModelClientResponseProtocol but can be extended however needed. + The message_retrieval method must be implemented to return a list of str or a list of messages from the response. + """ + + RESPONSE_USAGE_KEYS = ["prompt_tokens", "completion_tokens", "total_tokens", "cost", "model"] + + class ModelClientResponseProtocol(Protocol): + class Choice(Protocol): + class Message(Protocol): + content: str | None + + message: Message + + choices: list[Choice] + model: str + + def create(self, params: dict[str, Any]) -> ModelClientResponseProtocol: ... # pragma: no cover + + def message_retrieval( + self, response: ModelClientResponseProtocol + ) -> list[str] | list[ModelClientResponseProtocol.Choice.Message]: + """ + Retrieve and return a list of strings or a list of Choice.Message from the response. + + NOTE: if a list of Choice.Message is returned, it currently needs to contain the fields of OpenAI's ChatCompletion Message object, + since that is expected for function or tool calling in the rest of the codebase at the moment, unless a custom agent is being used. + """ + ... # pragma: no cover + + def cost(self, response: ModelClientResponseProtocol) -> float: ... # pragma: no cover + + @staticmethod + def get_usage(response: ModelClientResponseProtocol) -> dict: + """Return usage summary of the response using RESPONSE_USAGE_KEYS.""" + ... # pragma: no cover + + +@dataclass +class GroupChat: + + agents: list[Agent] + messages: list[dict] + max_round: int = 10 + admin_name: str = "Admin" + func_call_filter: bool = True + speaker_selection_method: Literal["auto", "manual", "random", "round_robin"] | Callable = "auto" + max_retries_for_selecting_speaker: int = 2 + allow_repeat_speaker: bool | list[Agent] | None = None + allowed_or_disallowed_speaker_transitions: dict | None = None + speaker_transitions_type: Literal["allowed", "disallowed", None] = None + enable_clear_history: bool = False + send_introductions: bool = False + select_speaker_message_template: str = """You are in a role play game. The following roles are available: + {roles}. + Read the following conversation. + Then select the next role from {agentlist} to play. Only return the role.""" + select_speaker_prompt_template: str = ( + "Read the above conversation. Then select the next role from {agentlist} to play. Only return the role." + ) + select_speaker_auto_multiple_template: str = """You provided more than one name in your text, please return just the name of the next speaker. To determine the speaker use these prioritised rules: + 1. If the context refers to themselves as a speaker e.g. "As the..." , choose that speaker's name + 2. If it refers to the "next" speaker name, choose that name + 3. Otherwise, choose the first provided speaker's name in the context + The names are case-sensitive and should not be abbreviated or changed. + Respond with ONLY the name of the speaker and DO NOT provide a reason.""" + select_speaker_auto_none_template: str = """You didn't choose a speaker. As a reminder, to determine the speaker use these prioritised rules: + 1. If the context refers to themselves as a speaker e.g. "As the..." , choose that speaker's name + 2. If it refers to the "next" speaker name, choose that name + 3. Otherwise, choose the first provided speaker's name in the context + The names are case-sensitive and should not be abbreviated or changed. + The only names that are accepted are {agentlist}. + Respond with ONLY the name of the speaker and DO NOT provide a reason.""" + select_speaker_transform_messages: Any = None + select_speaker_auto_verbose: bool | None = False + select_speaker_auto_model_client_cls: ModelClient | list[ModelClient] | None = None + select_speaker_auto_llm_config: dict | Literal[False] | None = None + role_for_select_speaker_messages: str | None = "system" + + _VALID_SPEAKER_SELECTION_METHODS = ["auto", "manual", "random", "round_robin"] + _VALID_SPEAKER_TRANSITIONS_TYPE = ["allowed", "disallowed", None] + + # Define a class attribute for the default introduction message + DEFAULT_INTRO_MSG = ( + "Hello everyone. We have assembled a great team today to answer questions and solve tasks. In attendance are:" + ) + + allowed_speaker_transitions_dict: dict = field(init=False) + + def __post_init__(self): + # Post init steers clears of the automatically generated __init__ method from dataclass + + self.allow_repeat_speaker = True + + self.allowed_speaker_transitions_dict = {} + # Create a fully connected allowed_speaker_transitions_dict not including self loops + for agent in self.agents: + self.allowed_speaker_transitions_dict[agent] = [ + other_agent for other_agent in self.agents if other_agent != agent + ] + + for agent in self.agents: + self.allowed_speaker_transitions_dict[agent].append(agent) + + self._speaker_selection_transforms = None + + @property + def agent_names(self) -> list[str]: + """Return the names of the agents in the group chat.""" + return [agent.name for agent in self.agents] + + def reset(self): + """Reset the group chat.""" + self.messages.clear() + + def append(self, message: dict, speaker: Agent): + """Append a message to the group chat. + We cast the content to str here so that it can be managed by text-based + model. + """ + # set the name to speaker's name if the role is not function + # if the role is tool, it is OK to modify the name + if message["role"] != "function": + message["name"] = speaker.name + message["content"] = content_str(message["content"]) + self.messages.append(message) + + def agent_by_name( + self, name: str, recursive: bool = False, raise_on_name_conflict: bool = False + ) -> Agent | None: + """Returns the agent with a given name. If recursive is True, it will search in nested teams.""" + agents = self.nested_agents() if recursive else self.agents + filtered_agents = [agent for agent in agents if agent.name == name] + + if raise_on_name_conflict and len(filtered_agents) > 1: + raise AgentNameConflict() + + return filtered_agents[0] if filtered_agents else None + + def nested_agents(self) -> list[Agent]: + """Returns all agents in the group chat manager.""" + agents = self.agents.copy() + for agent in agents: + if isinstance(agent, GroupChatManager): + # Recursive call for nested teams + agents.extend(agent.groupchat.nested_agents()) + return agents + + def next_agent(self, agent: Agent, agents: list[Agent] | None = None) -> Agent: + """Return the next agent in the list.""" + if agents is None: + agents = self.agents + + # Ensure the provided list of agents is a subset of self.agents + if not set(agents).issubset(set(self.agents)): + raise UndefinedNextAgent() + + # What index is the agent? (-1 if not present) + idx = self.agent_names.index(agent.name) if agent.name in self.agent_names else -1 + + # Return the next agent + if agents == self.agents: + return agents[(idx + 1) % len(agents)] + else: + offset = idx + 1 + for i in range(len(self.agents)): + if self.agents[(offset + i) % len(self.agents)] in agents: + return self.agents[(offset + i) % len(self.agents)] + + # Explicitly handle cases where no valid next agent exists in the provided subset. + raise UndefinedNextAgent() + + def select_speaker_msg(self, agents: list[Agent] | None = None) -> str: + """Return the system message for selecting the next speaker. This is always the *first* message in the context.""" + if agents is None: + agents = self.agents + + roles = self._participant_roles(agents) + agentlist = f"{[agent.name for agent in agents]}" + + return_msg = self.select_speaker_message_template.format(roles=roles, agentlist=agentlist) + return return_msg + + def select_speaker_prompt(self, agents: list[Agent] | None = None) -> str: + """Return the floating system prompt selecting the next speaker. + This is always the *last* message in the context. + Will return None if the select_speaker_prompt_template is None.""" + + if self.select_speaker_prompt_template is None: + return None + + if agents is None: + agents = self.agents + + agentlist = f"{[agent.name for agent in agents]}" + + return_prompt = self.select_speaker_prompt_template.format(agentlist=agentlist) + return return_prompt + + def introductions_msg(self, agents: list[Agent] | None = None) -> str: + """Return the system message for selecting the next speaker. This is always the *first* message in the context.""" + if agents is None: + agents = self.agents + + # Use the class attribute instead of a hardcoded string + intro_msg = self.DEFAULT_INTRO_MSG + participant_roles = self._participant_roles(agents) + + return f"{intro_msg}\n\n{participant_roles}" + + def manual_select_speaker(self, agents: list[Agent] | None = None) -> Agent | None: + """Manually select the next speaker.""" + iostream = IOStream.get_default() + + if agents is None: + agents = self.agents + + iostream.print("Please select the next speaker from the following list:") + _n_agents = len(agents) + for i in range(_n_agents): + iostream.print(f"{i+1}: {agents[i].name}") + try_count = 0 + # Assume the user will enter a valid number within 3 tries, otherwise use auto selection to avoid blocking. + while try_count <= 3: + try_count += 1 + if try_count >= 3: + iostream.print(f"You have tried {try_count} times. The next speaker will be selected automatically.") + break + try: + i = iostream.input( + "Enter the number of the next speaker (enter nothing or `q` to use auto selection): " + ) + if i == "" or i == "q": + break + i = int(i) + if i > 0 and i <= _n_agents: + return agents[i - 1] + else: + raise ValueError + except ValueError: + iostream.print(f"Invalid input. Please enter a number between 1 and {_n_agents}.") + return None + + def random_select_speaker(self, agents: list[Agent] | None = None) -> Agent | None: + """Randomly select the next speaker.""" + if agents is None: + agents = self.agents + return random.choice(agents) + + def _prepare_and_select_agents( + self, + last_speaker: Agent, + ) -> tuple[Agent | None, list[Agent], list[dict]]: + # If self.speaker_selection_method is a callable, call it to get the next speaker. + # If self.speaker_selection_method is a string, return it. + speaker_selection_method = self.speaker_selection_method + if isinstance(self.speaker_selection_method, Callable): + selected_agent = self.speaker_selection_method(last_speaker, self) + if selected_agent is None: + raise NoEligibleSpeaker("Custom speaker selection function returned None. Terminating conversation.") + elif isinstance(selected_agent, Agent): + if selected_agent in self.agents: + return selected_agent, self.agents, None + else: + raise ValueError( + f"Custom speaker selection function returned an agent {selected_agent.name} not in the group chat." + ) + elif isinstance(selected_agent, str): + # If returned a string, assume it is a speaker selection method + speaker_selection_method = selected_agent + else: + raise ValueError( + f"Custom speaker selection function returned an object of type {type(selected_agent)} instead of Agent or str." + ) + + if speaker_selection_method.lower() not in self._VALID_SPEAKER_SELECTION_METHODS: + raise ValueError( + f"GroupChat speaker_selection_method is set to '{speaker_selection_method}'. " + f"It should be one of {self._VALID_SPEAKER_SELECTION_METHODS} (case insensitive). " + ) + + # If provided a list, make sure the agent is in the list + allow_repeat_speaker = ( + self.allow_repeat_speaker + if isinstance(self.allow_repeat_speaker, bool) or self.allow_repeat_speaker is None + else last_speaker in self.allow_repeat_speaker + ) + + agents = self.agents + n_agents = len(agents) + # Warn if GroupChat is underpopulated + if n_agents < 2: + raise ValueError( + f"GroupChat is underpopulated with {n_agents} agents. " + "Please add more agents to the GroupChat or use direct communication instead." + ) + elif n_agents == 2 and speaker_selection_method.lower() != "round_robin" and allow_repeat_speaker: + logger.warning( + f"GroupChat is underpopulated with {n_agents} agents. " + "Consider setting speaker_selection_method to 'round_robin' or allow_repeat_speaker to False, " + "or use direct communication, unless repeated speaker is desired." + ) + + if ( + self.func_call_filter + and self.messages + and ("function_call" in self.messages[-1] or "tool_calls" in self.messages[-1]) + ): + funcs = [] + if "function_call" in self.messages[-1]: + funcs += [self.messages[-1]["function_call"]["name"]] + if "tool_calls" in self.messages[-1]: + funcs += [ + tool["function"]["name"] for tool in self.messages[-1]["tool_calls"] if tool["type"] == "function" + ] + + # find agents with the right function_map which contains the function name + agents = [agent for agent in self.agents if agent.can_execute_function(funcs)] + if len(agents) == 1: + # only one agent can execute the function + return agents[0], agents, None + elif not agents: + # find all the agents with function_map + agents = [agent for agent in self.agents if agent.function_map] + if len(agents) == 1: + return agents[0], agents, None + elif not agents: + raise ValueError( + f"No agent can execute the function {', '.join(funcs)}. " + "Please check the function_map of the agents." + ) + # remove the last speaker from the list to avoid selecting the same speaker if allow_repeat_speaker is False + agents = [agent for agent in agents if agent != last_speaker] if allow_repeat_speaker is False else agents + + # Filter agents with allowed_speaker_transitions_dict + + is_last_speaker_in_group = last_speaker in self.agents + + # this condition means last_speaker is a sink in the graph, then no agents are eligible + if last_speaker not in self.allowed_speaker_transitions_dict and is_last_speaker_in_group: + raise NoEligibleSpeaker(f"Last speaker {last_speaker.name} is not in the allowed_speaker_transitions_dict.") + # last_speaker is not in the group, so all agents are eligible + elif last_speaker not in self.allowed_speaker_transitions_dict and not is_last_speaker_in_group: + graph_eligible_agents = [] + else: + # Extract agent names from the list of agents + graph_eligible_agents = [ + agent for agent in agents if agent in self.allowed_speaker_transitions_dict[last_speaker] + ] + + # If there is only one eligible agent, just return it to avoid the speaker selection prompt + if len(graph_eligible_agents) == 1: + return graph_eligible_agents[0], graph_eligible_agents, None + + # If there are no eligible agents, return None, which means all agents will be taken into consideration in the next step + if len(graph_eligible_agents) == 0: + graph_eligible_agents = None + + # Use the selected speaker selection method + select_speaker_messages = None + if speaker_selection_method.lower() == "manual": + selected_agent = self.manual_select_speaker(graph_eligible_agents) + elif speaker_selection_method.lower() == "round_robin": + selected_agent = self.next_agent(last_speaker, graph_eligible_agents) + elif speaker_selection_method.lower() == "random": + selected_agent = self.random_select_speaker(graph_eligible_agents) + else: # auto + selected_agent = None + select_speaker_messages = self.messages.copy() + # If last message is a tool call or function call, blank the call so the api doesn't throw + if select_speaker_messages[-1].get("function_call", False): + select_speaker_messages[-1] = dict(select_speaker_messages[-1], function_call=None) + if select_speaker_messages[-1].get("tool_calls", False): + select_speaker_messages[-1] = dict(select_speaker_messages[-1], tool_calls=None) + return selected_agent, graph_eligible_agents, select_speaker_messages + + def select_speaker(self, last_speaker: Agent, selector: ConversableAgent) -> Agent: + """Select the next speaker (with requery).""" + + # Prepare the list of available agents and select an agent if selection method allows (non-auto) + selected_agent, agents, messages = self._prepare_and_select_agents(last_speaker) + if selected_agent: + return selected_agent + elif self.speaker_selection_method == "manual": + # An agent has not been selected while in manual mode, so move to the next agent + return self.next_agent(last_speaker) + + # auto speaker selection with 2-agent chat + return self._auto_select_speaker(last_speaker, selector, messages, agents) + + async def a_select_speaker(self, last_speaker: Agent, selector: ConversableAgent) -> Agent: + """Select the next speaker (with requery), asynchronously.""" + + selected_agent, agents, messages = self._prepare_and_select_agents(last_speaker) + if selected_agent: + return selected_agent + elif self.speaker_selection_method == "manual": + # An agent has not been selected while in manual mode, so move to the next agent + return self.next_agent(last_speaker) + + # auto speaker selection with 2-agent chat + return await self.a_auto_select_speaker(last_speaker, selector, messages, agents) + + def _finalize_speaker(self, last_speaker: Agent, final: bool, name: str, agents: list[Agent] | None) -> Agent: + if not final: + # the LLM client is None, thus no reply is generated. Use round robin instead. + return self.next_agent(last_speaker, agents) + + # If exactly one agent is mentioned, use it. Otherwise, leave the OAI response unmodified + mentions = self._mentioned_agents(name, agents) + if len(mentions) == 1: + name = next(iter(mentions)) + else: + logger.warning( + f"GroupChat select_speaker failed to resolve the next speaker's name. This is because the speaker selection OAI call returned:\n{name}" + ) + + # Return the result + agent = self.agent_by_name(name) + return agent if agent else self.next_agent(last_speaker, agents) + + def _register_client_from_config(self, agent: Agent, config: dict): + model_client_cls_to_match = config.get("model_client_cls") + if model_client_cls_to_match: + if not self.select_speaker_auto_model_client_cls: + raise ValueError( + "A custom model was detected in the config but no 'model_client_cls' " + "was supplied for registration in GroupChat." + ) + + if isinstance(self.select_speaker_auto_model_client_cls, list): + # Register the first custom model client class matching the name specified in the config + matching_model_cls = [ + client_cls + for client_cls in self.select_speaker_auto_model_client_cls + if client_cls.__name__ == model_client_cls_to_match + ] + if len(set(matching_model_cls)) > 1: + raise RuntimeError( + f"More than one unique 'model_client_cls' with __name__ '{model_client_cls_to_match}'." + ) + if not matching_model_cls: + raise ValueError( + "No model's __name__ matches the model client class " + f"'{model_client_cls_to_match}' specified in select_speaker_auto_llm_config." + ) + select_speaker_auto_model_client_cls = matching_model_cls[0] + else: + # Register the only custom model client + select_speaker_auto_model_client_cls = self.select_speaker_auto_model_client_cls + + agent.register_model_client(select_speaker_auto_model_client_cls) + + def _register_custom_model_clients(self, agent: ConversableAgent): + if not self.select_speaker_auto_llm_config: + return + + config_format_is_list = "config_list" in self.select_speaker_auto_llm_config.keys() + if config_format_is_list: + for config in self.select_speaker_auto_llm_config["config_list"]: + self._register_client_from_config(agent, config) + elif not config_format_is_list: + self._register_client_from_config(agent, self.select_speaker_auto_llm_config) + + def _create_internal_agents( + self, agents, max_attempts, messages, validate_speaker_name, selector: ConversableAgent | None = None + ): + checking_agent = ConversableAgent("checking_agent", default_auto_reply=max_attempts) + + # Register the speaker validation function with the checking agent + checking_agent.register_reply( + [ConversableAgent, None], + reply_func=validate_speaker_name, # Validate each response + remove_other_reply_funcs=True, + ) + + # Override the selector's config if one was passed as a parameter to this class + speaker_selection_llm_config = self.select_speaker_auto_llm_config or selector.llm_config + + # Agent for selecting a single agent name from the response + speaker_selection_agent = ConversableAgent( + "speaker_selection_agent", + system_message=self.select_speaker_msg(agents), + chat_messages={checking_agent: messages}, + llm_config=speaker_selection_llm_config, + human_input_mode="NEVER", + # Suppresses some extra terminal outputs, outputs will be handled by select_speaker_auto_verbose + ) + + # Register any custom model passed in select_speaker_auto_llm_config with the speaker_selection_agent + self._register_custom_model_clients(speaker_selection_agent) + + return checking_agent, speaker_selection_agent + + def _auto_select_speaker( + self, + last_speaker: Agent, + selector: ConversableAgent, + messages: list[dict], + agents: list[Agent] | None, + ) -> Agent: + """Selects next speaker for the "auto" speaker selection method. Utilises its own two-agent chat to determine the next speaker and supports requerying. + + Speaker selection for "auto" speaker selection method: + 1. Create a two-agent chat with a speaker selector agent and a speaker validator agent, like a nested chat + 2. Inject the group messages into the new chat + 3. Run the two-agent chat, evaluating the result of response from the speaker selector agent: + - If a single agent is provided then we return it and finish. If not, we add an additional message to this nested chat in an attempt to guide the LLM to a single agent response + 4. Chat continues until a single agent is nominated or there are no more attempts left + 5. If we run out of turns and no single agent can be determined, the next speaker in the list of agents is returned + + Args: + last_speaker Agent: The previous speaker in the group chat + selector ConversableAgent: + messages list[dict]: Current chat messages + agents list[Agent] | None: Valid list of agents for speaker selection + + Returns: + Dict: a counter for mentioned agents. + """ + + # If no agents are passed in, assign all the group chat's agents + if agents is None: + agents = self.agents + + # The maximum number of speaker selection attempts (including requeries) + # is the initial speaker selection attempt plus the maximum number of retries. + # We track these and use them in the validation function as we can't + # access the max_turns from within validate_speaker_name. + max_attempts = 1 + self.max_retries_for_selecting_speaker + attempts_left = max_attempts + attempt = 0 + + # Registered reply function for checking_agent, checks the result of the response for agent names + def validate_speaker_name(recipient, messages, sender, config) -> tuple[bool, str | dict | None]: + # The number of retries left, starting at max_retries_for_selecting_speaker + nonlocal attempts_left + nonlocal attempt + + attempt = attempt + 1 + attempts_left = attempts_left - 1 + + return self._validate_speaker_name(recipient, messages, sender, config, attempts_left, attempt, agents) + + # Two-agent chat for speaker selection + + # Agent for checking the response from the speaker_select_agent + checking_agent, speaker_selection_agent = self._create_internal_agents( + agents, max_attempts, messages, validate_speaker_name, selector + ) + + # Create the starting message + if self.select_speaker_prompt_template is not None: + start_message = { + "content": self.select_speaker_prompt(agents), + "name": "checking_agent", + "override_role": self.role_for_select_speaker_messages, + } + else: + start_message = messages[-1] + + # Add the message transforms, if any, to the speaker selection agent + if self._speaker_selection_transforms is not None: + self._speaker_selection_transforms.add_to_agent(speaker_selection_agent) + + # Run the speaker selection chat + result = checking_agent.initiate_chat( + speaker_selection_agent, + cache=None, # don't use caching for the speaker selection chat + message=start_message, + max_turns=2 + * max(1, max_attempts), # Limiting the chat to the number of attempts, including the initial one + clear_history=False, + silent=not self.select_speaker_auto_verbose, # Base silence on the verbose attribute + ) + + return self._process_speaker_selection_result(result, last_speaker, agents) + + async def a_auto_select_speaker( + self, + last_speaker: Agent, + selector: ConversableAgent, + messages: list[dict], + agents: list[Agent] | None, + ) -> Agent: + """(Asynchronous) Selects next speaker for the "auto" speaker selection method. Utilises its own two-agent chat to determine the next speaker and supports requerying. + + Speaker selection for "auto" speaker selection method: + 1. Create a two-agent chat with a speaker selector agent and a speaker validator agent, like a nested chat + 2. Inject the group messages into the new chat + 3. Run the two-agent chat, evaluating the result of response from the speaker selector agent: + - If a single agent is provided then we return it and finish. If not, we add an additional message to this nested chat in an attempt to guide the LLM to a single agent response + 4. Chat continues until a single agent is nominated or there are no more attempts left + 5. If we run out of turns and no single agent can be determined, the next speaker in the list of agents is returned + + Args: + last_speaker Agent: The previous speaker in the group chat + selector ConversableAgent: + messages list[dict]: Current chat messages + agents list[Agent] | None: Valid list of agents for speaker selection + + Returns: + Dict: a counter for mentioned agents. + """ + + # If no agents are passed in, assign all the group chat's agents + if agents is None: + agents = self.agents + + # The maximum number of speaker selection attempts (including requeries) + # We track these and use them in the validation function as we can't + # access the max_turns from within validate_speaker_name + max_attempts = 1 + self.max_retries_for_selecting_speaker + attempts_left = max_attempts + attempt = 0 + + # Registered reply function for checking_agent, checks the result of the response for agent names + def validate_speaker_name(recipient, messages, sender, config) -> tuple[bool, str | dict | None]: + # The number of retries left, starting at max_retries_for_selecting_speaker + nonlocal attempts_left + nonlocal attempt + + attempt = attempt + 1 + attempts_left = attempts_left - 1 + + return self._validate_speaker_name(recipient, messages, sender, config, attempts_left, attempt, agents) + + # Two-agent chat for speaker selection + + # Agent for checking the response from the speaker_select_agent + checking_agent, speaker_selection_agent = self._create_internal_agents( + agents, max_attempts, messages, validate_speaker_name, selector + ) + + # Create the starting message + if self.select_speaker_prompt_template is not None: + start_message = { + "content": self.select_speaker_prompt(agents), + "override_role": self.role_for_select_speaker_messages, + } + else: + start_message = messages[-1] + + # Add the message transforms, if any, to the speaker selection agent + if self._speaker_selection_transforms is not None: + self._speaker_selection_transforms.add_to_agent(speaker_selection_agent) + + # Run the speaker selection chat + result = await checking_agent.a_initiate_chat( + speaker_selection_agent, + cache=None, # don't use caching for the speaker selection chat + message=start_message, + max_turns=2 + * max(1, max_attempts), # Limiting the chat to the number of attempts, including the initial one + clear_history=False, + silent=not self.select_speaker_auto_verbose, # Base silence on the verbose attribute + ) + + return self._process_speaker_selection_result(result, last_speaker, agents) + + def _validate_speaker_name( + self, recipient, messages: list[dict[str, str]], sender, config, attempts_left, attempt, agents + ) -> tuple[bool, str | dict | None]: + """Validates the speaker response for each round in the internal 2-agent + chat within the auto select speaker method. + + Used by auto_select_speaker and a_auto_select_speaker. + """ + + # Output the query and requery results + if self.select_speaker_auto_verbose: + iostream = IOStream.get_default() + + # Validate the speaker name selected + select_name = messages[-1]["content"].strip() + + mentions = self._mentioned_agents(select_name, agents) + + if len(mentions) == 1: + # Success on retry, we have just one name mentioned + selected_agent_name = next(iter(mentions)) + + # Add the selected agent to the response so we can return it + messages.append({"role": "user", "content": f"[AGENT SELECTED]{selected_agent_name}"}) + + if self.select_speaker_auto_verbose: + iostream.print( + colored( + f">>>>>>>> Select speaker attempt {attempt} of {attempt + attempts_left} successfully selected: {selected_agent_name}", + "green", + ), + flush=True, + ) + + elif len(mentions) > 1: + # More than one name on requery so add additional reminder prompt for next retry + + if self.select_speaker_auto_verbose: + iostream.print( + colored( + f">>>>>>>> Select speaker attempt {attempt} of {attempt + attempts_left} failed as it included multiple agent names.", + "red", + ), + flush=True, + ) + + if attempts_left: + # Message to return to the chat for the next attempt + agentlist = f"{[agent.name for agent in agents]}" + + return True, { + "content": self.select_speaker_auto_multiple_template.format(agentlist=agentlist), + "name": "checking_agent", + "override_role": self.role_for_select_speaker_messages, + } + else: + # Final failure, no attempts left + messages.append( + { + "role": "user", + "content": f"[AGENT SELECTION FAILED]Select speaker attempt #{attempt} of {attempt + attempts_left} failed as it returned multiple names.", + } + ) + + else: + # No names at all on requery so add additional reminder prompt for next retry + + if self.select_speaker_auto_verbose: + iostream.print( + colored( + f">>>>>>>> Select speaker attempt #{attempt} failed as it did not include any agent names.", + "red", + ), + flush=True, + ) + + if attempts_left: + # Message to return to the chat for the next attempt + agentlist = f"{[agent.name for agent in agents]}" + + return True, { + "content": self.select_speaker_auto_none_template.format(agentlist=agentlist), + "name": "checking_agent", + "override_role": self.role_for_select_speaker_messages, + } + else: + # Final failure, no attempts left + messages.append( + { + "role": "user", + "content": f"[AGENT SELECTION FAILED]Select speaker attempt #{attempt} of {attempt + attempts_left} failed as it did not include any agent names.", + } + ) + + return True, None + + def _process_speaker_selection_result(self, result, last_speaker: ConversableAgent, agents: list[Agent] | None): + """Checks the result of the auto_select_speaker function, returning the + agent to speak. + + Used by auto_select_speaker and a_auto_select_speaker.""" + if len(result.chat_history) > 0: + # Use the final message, which will have the selected agent or reason for failure + final_message = result.chat_history[-1]["content"] + + if "[AGENT SELECTED]" in final_message: + # Have successfully selected an agent, return it + return self.agent_by_name(final_message.replace("[AGENT SELECTED]", "")) + + else: # "[AGENT SELECTION FAILED]" + # Failed to select an agent, so we'll select the next agent in the list + next_agent = self.next_agent(last_speaker, agents) + + # No agent, return the failed reason + return next_agent + + def _participant_roles(self, agents: list["Agent"] | None = None) -> str: + # Default to all agents registered + if agents is None: + agents = self.agents + + roles = [] + for agent in agents: + if agent.description.strip() == "": + logger.warning( + f"The agent '{agent.name}' has an empty description, and may not work well with GroupChat." + ) + roles.append(f"{agent.name}: {agent.description}".strip()) + return "\n".join(roles) + + def _mentioned_agents(self, message_content: str | list, agents: list[Agent] | None) -> dict: + """Counts the number of times each agent is mentioned in the provided message content. + Agent names will match under any of the following conditions (all case-sensitive): + - Exact name match + - If the agent name has underscores it will match with spaces instead (e.g. 'Story_writer' == 'Story writer') + - If the agent name has underscores it will match with '\\_' instead of '_' (e.g. 'Story_writer' == 'Story\\_writer') + + Args: + message_content (Union[str, List]): The content of the message, either as a single string or a list of strings. + agents (List[Agent]): A list of Agent objects, each having a 'name' attribute to be searched in the message content. + + Returns: + Dict: a counter for mentioned agents. + """ + if agents is None: + agents = self.agents + + # Cast message content to str + if isinstance(message_content, dict): + message_content = message_content["content"] + message_content = content_str(message_content) + + mentions = dict() + for agent in agents: + # Finds agent mentions, taking word boundaries into account, + # accommodates escaping underscores and underscores as spaces + regex = ( + r"(?<=\W)(" + + re.escape(agent.name) + + r"|" + + re.escape(agent.name.replace("_", " ")) + + r"|" + + re.escape(agent.name.replace("_", r"\_")) + + r")(?=\W)" + ) + count = len(re.findall(regex, f" {message_content} ")) # Pad the message to help with matching + if count > 0: + mentions[agent.name] = count + return mentions + + +class GroupChatManager(ConversableAgent): + + def __init__( + self, + groupchat: GroupChat, + name: str | None = "chat_manager", + max_consecutive_auto_reply: int | None = sys.maxsize, + human_input_mode: Literal["ALWAYS", "NEVER", "TERMINATE"] = "NEVER", + system_message: str | list | None = "Group chat manager.", + silent: bool = False, + **kwargs, + ): + if ( + kwargs.get("llm_config") + and isinstance(kwargs["llm_config"], dict) + and (kwargs["llm_config"].get("functions") or kwargs["llm_config"].get("tools")) + ): + raise ValueError( + "GroupChatManager is not allowed to make function/tool calls. Please remove the 'functions' or 'tools' config in 'llm_config' you passed in." + ) + + super().__init__( + name=name, + max_consecutive_auto_reply=max_consecutive_auto_reply, + human_input_mode=human_input_mode, + system_message=system_message, + **kwargs, + ) + + # Store groupchat + self._groupchat = groupchat + + self._last_speaker = None + self._silent = silent + + # Order of register_reply is important. + # Allow sync chat if initiated using initiate_chat + self.register_reply(Agent, GroupChatManager.run_chat, config=groupchat, reset_config=GroupChat.reset) + # Allow async chat if initiated using a_initiate_chat + self.register_reply( + Agent, + GroupChatManager.a_run_chat, + config=groupchat, + reset_config=GroupChat.reset, + ignore_async_in_sync_chat=True, + ) + + @property + def groupchat(self) -> GroupChat: + """Returns the group chat managed by the group chat manager.""" + return self._groupchat + + def chat_messages_for_summary(self, agent: Agent) -> list[dict]: + """The list of messages in the group chat as a conversation to summarize. + The agent is ignored. + """ + return self._groupchat.messages + + def _prepare_chat( + self, + recipient: ConversableAgent, + clear_history: bool, + prepare_recipient: bool = True, + reply_at_receive: bool = True, + ) -> None: + super()._prepare_chat(recipient, clear_history, prepare_recipient, reply_at_receive) + + if clear_history: + self._groupchat.reset() + + for agent in self._groupchat.agents: + if (recipient != agent or prepare_recipient) and isinstance(agent, ConversableAgent): + agent._prepare_chat(self, clear_history, False, reply_at_receive) + + @property + def last_speaker(self) -> Agent: + """Return the agent who sent the last message to group chat manager. + + In a group chat, an agent will always send a message to the group chat manager, and the group chat manager will + send the message to all other agents in the group chat. So, when an agent receives a message, it will always be + from the group chat manager. With this property, the agent receiving the message can know who actually sent the + message. + + Example: + ```python + from autogen import ConversableAgent + from autogen import GroupChat, GroupChatManager + + + def print_messages(recipient, messages, sender, config): + # Print the message immediately + print( + f"Sender: {sender.name} | Recipient: {recipient.name} | Message: {messages[-1].get('content')}" + ) + print(f"Real Sender: {sender.last_speaker.name}") + assert sender.last_speaker.name in messages[-1].get("content") + return False, None # Required to ensure the agent communication flow continues + + + agent_a = ConversableAgent("agent A", default_auto_reply="I'm agent A.") + agent_b = ConversableAgent("agent B", default_auto_reply="I'm agent B.") + agent_c = ConversableAgent("agent C", default_auto_reply="I'm agent C.") + for agent in [agent_a, agent_b, agent_c]: + agent.register_reply( + [ConversableAgent, None], reply_func=print_messages, config=None + ) + group_chat = GroupChat( + [agent_a, agent_b, agent_c], + messages=[], + max_round=6, + speaker_selection_method="random", + allow_repeat_speaker=True, + ) + chat_manager = GroupChatManager(group_chat) + groupchat_result = agent_a.initiate_chat( + chat_manager, message="Hi, there, I'm agent A." + ) + ``` + """ + return self._last_speaker + + def run_chat( + self, + messages: list[dict] | None = None, + sender: Agent | None = None, + config: GroupChat | None = None, + ) -> tuple[bool, str | None]: + """Run a group chat.""" + if messages is None: + messages = self._oai_messages[sender] + message = messages[-1] + speaker = sender + groupchat = config + send_introductions = getattr(groupchat, "send_introductions", False) + silent = getattr(self, "_silent", False) + + if send_introductions: + # Broadcast the intro + intro = groupchat.introductions_msg() + for agent in groupchat.agents: + self.send(intro, agent, request_reply=False, silent=True) + # NOTE: We do not also append to groupchat.messages, + # since groupchat handles its own introductions + + if self.client_cache is not None: + for a in groupchat.agents: + a.previous_cache = a.client_cache + a.client_cache = self.client_cache + for i in range(groupchat.max_round): + self._last_speaker = speaker + groupchat.append(message, speaker) + # broadcast the message to all agents except the speaker + for agent in groupchat.agents: + if agent != speaker: + self.send(message, agent, request_reply=False, silent=True) + if self._is_termination_msg(message) or i == groupchat.max_round - 1: + # The conversation is over or it's the last round + break + try: + # select the next speaker + speaker = groupchat.select_speaker(speaker, self) + if not silent: + iostream = IOStream.get_default() + iostream.print(colored(f"\nNext speaker: {speaker.name}\n", "green"), flush=True) + # let the speaker speak + reply = speaker.generate_reply(sender=self) + except KeyboardInterrupt: + # let the admin agent speak if interrupted + if groupchat.admin_name in groupchat.agent_names: + # admin agent is one of the participants + speaker = groupchat.agent_by_name(groupchat.admin_name) + reply = speaker.generate_reply(sender=self) + else: + # admin agent is not found in the participants + raise + except NoEligibleSpeaker: + # No eligible speaker, terminate the conversation + logger.warning("No eligible speaker found. Terminating the conversation.") + break + + if reply is None: + # no reply is generated, exit the chat + break + + # check for "clear history" phrase in reply and activate clear history function if found + if ( + groupchat.enable_clear_history + and isinstance(reply, dict) + and reply["content"] + and "CLEAR HISTORY" in reply["content"].upper() + ): + reply["content"] = self.clear_agents_history(reply, groupchat) + + # The speaker sends the message without requesting a reply + speaker.send(reply, self, request_reply=False, silent=silent) + message = self.last_message(speaker) + if self.client_cache is not None: + for a in groupchat.agents: + a.client_cache = a.previous_cache + a.previous_cache = None + return True, None + + async def a_run_chat( + self, + messages: list[dict] | None = None, + sender: Agent | None = None, + config: GroupChat | None = None, + ): + """Run a group chat asynchronously.""" + if messages is None: + messages = self._oai_messages[sender] + message = messages[-1] + speaker = sender + groupchat = config + send_introductions = getattr(groupchat, "send_introductions", False) + silent = getattr(self, "_silent", False) + + if send_introductions: + # Broadcast the intro + intro = groupchat.introductions_msg() + for agent in groupchat.agents: + await self.a_send(intro, agent, request_reply=False, silent=True) + # NOTE: We do not also append to groupchat.messages, + # since groupchat handles its own introductions + + if self.client_cache is not None: + for a in groupchat.agents: + a.previous_cache = a.client_cache + a.client_cache = self.client_cache + for i in range(groupchat.max_round): + groupchat.append(message, speaker) + + if self._is_termination_msg(message): + # The conversation is over + break + + # broadcast the message to all agents except the speaker + for agent in groupchat.agents: + if agent != speaker: + await self.a_send(message, agent, request_reply=False, silent=True) + if i == groupchat.max_round - 1: + # the last round + break + try: + # select the next speaker + speaker = await groupchat.a_select_speaker(speaker, self) + # let the speaker speak + reply = await speaker.a_generate_reply(sender=self) + except KeyboardInterrupt: + # let the admin agent speak if interrupted + if groupchat.admin_name in groupchat.agent_names: + # admin agent is one of the participants + speaker = groupchat.agent_by_name(groupchat.admin_name) + reply = await speaker.a_generate_reply(sender=self) + else: + # admin agent is not found in the participants + raise + except NoEligibleSpeaker: + # No eligible speaker, terminate the conversation + logger.warning("No eligible speaker found. Terminating the conversation.") + break + + if reply is None: + break + # The speaker sends the message without requesting a reply + await speaker.a_send(reply, self, request_reply=False, silent=silent) + message = self.last_message(speaker) + if self.client_cache is not None: + for a in groupchat.agents: + a.client_cache = a.previous_cache + a.previous_cache = None + return True, None + + def resume( + self, + messages: Union[list[dict], str], + remove_termination_string: Union[str, Callable[[str], str]] = None, + silent: bool | None = False, + ) -> tuple[ConversableAgent, dict]: + """Resumes a group chat using the previous messages as a starting point. Requires the agents, group chat, and group chat manager to be established + as per the original group chat. + + Args: + - messages Union[list[dict], str]: The content of the previous chat's messages, either as a Json string or a list of message dictionaries. + - remove_termination_string (str or function): Remove the termination string from the last message to prevent immediate termination + If a string is provided, this string will be removed from last message. + If a function is provided, the last message will be passed to this function. + - silent (bool or None): (Experimental) whether to print the messages for this conversation. Default is False. + + Returns: + - Tuple[ConversableAgent, Dict]: A tuple containing the last agent who spoke and their message + """ + + # Convert messages from string to messages list, if needed + if isinstance(messages, str): + messages = self.messages_from_string(messages) + elif isinstance(messages, list) and all(isinstance(item, dict) for item in messages): + messages = deepcopy(messages) + else: + raise Exception("Messages is not of type str or list[dict]") + + # Clean up the objects, ensuring there are no messages in the agents and group chat + + # Clear agent message history + for agent in self._groupchat.agents: + if isinstance(agent, ConversableAgent): + agent.clear_history() + + # Clear Manager message history + self.clear_history() + + # Clear GroupChat messages + self._groupchat.reset() + + # Validation of message and agents + + try: + self._valid_resume_messages(messages) + except: + raise + + # Load the messages into the group chat + for i, message in enumerate(messages): + if "name" in message: + message_speaker_agent = self._groupchat.agent_by_name(message["name"]) + else: + # If there's no name, assign the group chat manager (this is an indication the ChatResult messages was used instead of groupchat.messages as state) + message_speaker_agent = self + message["name"] = self.name + + # If it wasn't an agent speaking, it may be the manager + if not message_speaker_agent and message["name"] == self.name: + message_speaker_agent = self + + # Add previous messages to each agent (except the last message, as we'll kick off the conversation with it) + if i != len(messages) - 1: + for agent in self._groupchat.agents: + self.send(message, self._groupchat.agent_by_name(agent.name), request_reply=False, silent=True) + + # Add previous message to the new groupchat, if it's an admin message the name may not match so add the message directly + if message_speaker_agent: + self._groupchat.append(message, message_speaker_agent) + else: + self._groupchat.messages.append(message) + + # Last speaker agent + last_speaker_name = message["name"] + + # Last message to check for termination (we could avoid this by ignoring termination check for resume in the future) + last_message = message + + # Get last speaker as an agent + previous_last_agent = self._groupchat.agent_by_name(name=last_speaker_name) + + # If we didn't match a last speaker agent, we check that it's the group chat's admin name and assign the manager, if so + if not previous_last_agent and ( + last_speaker_name == self._groupchat.admin_name or last_speaker_name == self.name + ): + previous_last_agent = self + + # Termination removal and check + self._process_resume_termination(remove_termination_string, messages) + + if not silent: + iostream = IOStream.get_default() + iostream.print( + f"Prepared group chat with {len(messages)} messages, the last speaker is", + colored(last_speaker_name, "yellow"), + flush=True, + ) + + # Update group chat settings for resuming + self._groupchat.send_introductions = False + + return previous_last_agent, last_message + + async def a_resume( + self, + messages: Union[list[dict], str], + remove_termination_string: Union[str, Callable[[str], str]] = None, + silent: bool | None = False, + ) -> tuple[ConversableAgent, dict]: + """Resumes a group chat using the previous messages as a starting point, asynchronously. Requires the agents, group chat, and group chat manager to be established + as per the original group chat. + + Args: + - messages Union[list[dict], str]: The content of the previous chat's messages, either as a Json string or a list of message dictionaries. + - remove_termination_string (str or function): Remove the termination string from the last message to prevent immediate termination + If a string is provided, this string will be removed from last message. + If a function is provided, the last message will be passed to this function, and the function returns the string after processing. + - silent (bool or None): (Experimental) whether to print the messages for this conversation. Default is False. + + Returns: + - Tuple[ConversableAgent, Dict]: A tuple containing the last agent who spoke and their message + """ + + # Convert messages from string to messages list, if needed + if isinstance(messages, str): + messages = self.messages_from_string(messages) + elif isinstance(messages, list) and all(isinstance(item, dict) for item in messages): + messages = deepcopy(messages) + else: + raise Exception("Messages is not of type str or list[dict]") + + # Clean up the objects, ensuring there are no messages in the agents and group chat + + # Clear agent message history + for agent in self._groupchat.agents: + if isinstance(agent, ConversableAgent): + agent.clear_history() + + # Clear Manager message history + self.clear_history() + + # Clear GroupChat messages + self._groupchat.reset() + + # Validation of message and agents + + try: + self._valid_resume_messages(messages) + except: + raise + + # Load the messages into the group chat + for i, message in enumerate(messages): + if "name" in message: + message_speaker_agent = self._groupchat.agent_by_name(message["name"]) + else: + # If there's no name, assign the group chat manager (this is an indication the ChatResult messages was used instead of groupchat.messages as state) + message_speaker_agent = self + message["name"] = self.name + + # If it wasn't an agent speaking, it may be the manager + if not message_speaker_agent and message["name"] == self.name: + message_speaker_agent = self + + # Add previous messages to each agent (except the last message, as we'll kick off the conversation with it) + if i != len(messages) - 1: + for agent in self._groupchat.agents: + await self.a_send( + message, self._groupchat.agent_by_name(agent.name), request_reply=False, silent=True + ) + + # Add previous message to the new groupchat, if it's an admin message the name may not match so add the message directly + if message_speaker_agent: + self._groupchat.append(message, message_speaker_agent) + else: + self._groupchat.messages.append(message) + + # Last speaker agent + last_speaker_name = message["name"] + + # Last message to check for termination (we could avoid this by ignoring termination check for resume in the future) + last_message = message + + # Get last speaker as an agent + previous_last_agent = self._groupchat.agent_by_name(name=last_speaker_name) + + # If we didn't match a last speaker agent, we check that it's the group chat's admin name and assign the manager, if so + if not previous_last_agent and ( + last_speaker_name == self._groupchat.admin_name or last_speaker_name == self.name + ): + previous_last_agent = self + + # Termination removal and check + self._process_resume_termination(remove_termination_string, messages) + + if not silent: + iostream = IOStream.get_default() + iostream.print( + f"Prepared group chat with {len(messages)} messages, the last speaker is", + colored(last_speaker_name, "yellow"), + flush=True, + ) + + # Update group chat settings for resuming + self._groupchat.send_introductions = False + + return previous_last_agent, last_message + + def _valid_resume_messages(self, messages: list[dict]): + """Validates the messages used for resuming + + args: + messages (list[dict]): list of messages to resume with + + returns: + - bool: Whether they are valid for resuming + """ + # Must have messages to start with, otherwise they should run run_chat + if not messages: + raise Exception( + "Cannot resume group chat as no messages were provided. Use GroupChatManager.run_chat or ConversableAgent.initiate_chat to start a new chat." + ) + + # Check that all agents in the chat messages exist in the group chat + for message in messages: + if message.get("name"): + if ( + not self._groupchat.agent_by_name(message["name"]) + and not message["name"] == self._groupchat.admin_name # ignore group chat's name + and not message["name"] == self.name # ignore group chat manager's name + ): + raise Exception(f"Agent name in message doesn't exist as agent in group chat: {message['name']}") + + def _process_resume_termination( + self, remove_termination_string: str | Callable[[str], str], messages: list[dict] + ): + """Removes termination string, if required, and checks if termination may occur. + + args: + remove_termination_string (str or function): Remove the termination string from the last message to prevent immediate termination + If a string is provided, this string will be removed from last message. + If a function is provided, the last message will be passed to this function, and the function returns the string after processing. + + returns: + None + """ + + last_message = messages[-1] + + # Replace any given termination string in the last message + if isinstance(remove_termination_string, str): + + def _remove_termination_string(content: str) -> str: + return content.replace(remove_termination_string, "") + + else: + _remove_termination_string = remove_termination_string + + if _remove_termination_string: + if messages[-1].get("content"): + messages[-1]["content"] = _remove_termination_string(messages[-1]["content"]) + + # Check if the last message meets termination (if it has one) + if self._is_termination_msg: + if self._is_termination_msg(last_message): + logger.warning("WARNING: Last message meets termination criteria and this may terminate the chat.") + + def messages_from_string(self, message_string: str) -> list[dict]: + """Reads the saved state of messages in Json format for resume and returns as a messages list + + args: + - message_string: Json string, the saved state + + returns: + - list[dict]: List of messages + """ + try: + state = json.loads(message_string) + except json.JSONDecodeError: + raise Exception("Messages string is not a valid JSON string") + + return state + + def messages_to_string(self, messages: list[dict]) -> str: + """Converts the provided messages into a Json string that can be used for resuming the chat. + The state is made up of a list of messages + + args: + - messages (list[dict]): set of messages to convert to a string + + returns: + - str: Json representation of the messages which can be persisted for resuming later + """ + + return json.dumps(messages) + + def _raise_exception_on_async_reply_functions(self) -> None: + """Raise an exception if any async reply functions are registered. + + Raises: + RuntimeError: if any async reply functions are registered. + """ + super()._raise_exception_on_async_reply_functions() + + for agent in self._groupchat.agents: + agent._raise_exception_on_async_reply_functions() + + def clear_agents_history(self, reply: dict, groupchat: GroupChat) -> str: + """Clears history of messages for all agents or selected one. Can preserve selected number of last messages. + That function is called when user manually provide "clear history" phrase in his reply. + When "clear history" is provided, the history of messages for all agents is cleared. + When "clear history " is provided, the history of messages for selected agent is cleared. + When "clear history " is provided, the history of messages for all agents is cleared + except last messages. + When "clear history " is provided, the history of messages for selected + agent is cleared except last messages. + Phrase "clear history" and optional arguments are cut out from the reply before it passed to the chat. + + Args: + reply (dict): reply message dict to analyze. + groupchat (GroupChat): GroupChat object. + """ + iostream = IOStream.get_default() + + reply_content = reply["content"] + # Split the reply into words + words = reply_content.split() + # Find the position of "clear" to determine where to start processing + clear_word_index = next(i for i in reversed(range(len(words))) if words[i].upper() == "CLEAR") + # Extract potential agent name and steps + words_to_check = words[clear_word_index + 2 : clear_word_index + 4] + nr_messages_to_preserve = None + nr_messages_to_preserve_provided = False + agent_to_memory_clear = None + + for word in words_to_check: + if word.isdigit(): + nr_messages_to_preserve = int(word) + nr_messages_to_preserve_provided = True + elif word[:-1].isdigit(): # for the case when number of messages is followed by dot or other sign + nr_messages_to_preserve = int(word[:-1]) + nr_messages_to_preserve_provided = True + else: + for agent in groupchat.agents: + if agent.name == word: + agent_to_memory_clear = agent + break + elif agent.name == word[:-1]: # for the case when agent name is followed by dot or other sign + agent_to_memory_clear = agent + break + # preserve last tool call message if clear history called inside of tool response + if "tool_responses" in reply and not nr_messages_to_preserve: + nr_messages_to_preserve = 1 + logger.warning( + "The last tool call message will be saved to prevent errors caused by tool response without tool call." + ) + # clear history + if agent_to_memory_clear: + if nr_messages_to_preserve: + iostream.print( + f"Clearing history for {agent_to_memory_clear.name} except last {nr_messages_to_preserve} messages." + ) + else: + iostream.print(f"Clearing history for {agent_to_memory_clear.name}.") + agent_to_memory_clear.clear_history(nr_messages_to_preserve=nr_messages_to_preserve) + else: + if nr_messages_to_preserve: + iostream.print(f"Clearing history for all agents except last {nr_messages_to_preserve} messages.") + # clearing history for groupchat here + temp = groupchat.messages[-nr_messages_to_preserve:] + groupchat.messages.clear() + groupchat.messages.extend(temp) + else: + iostream.print("Clearing history for all agents.") + # clearing history for groupchat here + groupchat.messages.clear() + # clearing history for agents + for agent in groupchat.agents: + agent.clear_history(nr_messages_to_preserve=nr_messages_to_preserve) + + # Reconstruct the reply without the "clear history" command and parameters + skip_words_number = 2 + int(bool(agent_to_memory_clear)) + int(nr_messages_to_preserve_provided) + reply_content = " ".join(words[:clear_word_index] + words[clear_word_index + skip_words_number :]) + + return reply_content \ No newline at end of file diff --git a/train_methods/legacy_autogen_conversable_agent.py b/train_methods/legacy_autogen_conversable_agent.py new file mode 100644 index 0000000..9c47293 --- /dev/null +++ b/train_methods/legacy_autogen_conversable_agent.py @@ -0,0 +1,3130 @@ +import asyncio +import contextvars +import copy +import functools +import inspect +import json +import logging +import re +import warnings +from collections import defaultdict +from typing import Any, Callable, Coroutine, Dict, List, Literal, Optional, Tuple, Type, TypeVar, Union, Protocol + +from openai import BadRequestError + +from autogen.agentchat.chat import _post_process_carryover_item +from autogen.exception_utils import InvalidCarryOverType, SenderRequired + +from .._pydantic import model_dump +from ..cache.cache import AbstractCache +from ..code_utils import ( + PYTHON_VARIANTS, + UNKNOWN, + check_can_use_docker_or_throw, + content_str, + decide_use_docker, + execute_code, + extract_code, + infer_lang, +) +from ..coding.base import CodeExecutor +from ..coding.factory import CodeExecutorFactory +from ..formatting_utils import colored +from ..function_utils import get_function_schema, load_basemodels_if_needed, serialize_to_str +from ..io.base import IOStream +from ..oai.client import ModelClient, OpenAIWrapper +from ..runtime_logging import log_event, log_function_use, log_new_agent, logging_enabled +from .agent import Agent, LLMAgent +from .chat import ChatResult, a_initiate_chats, initiate_chats +from .utils import consolidate_chat_info, gather_usage_summary + +__all__ = ("ConversableAgent",) + +logger = logging.getLogger(__name__) + +class Agent(Protocol): + """(In preview) A protocol for Agent. + + An agent can communicate with other agents and perform actions. + Different agents can differ in what actions they perform in the `receive` method. + """ + + @property + def name(self) -> str: + """The name of the agent.""" + ... + + @property + def description(self) -> str: + """The description of the agent. Used for the agent's introduction in a group chat setting.""" + ... + + def send( + self, + message: dict[str, Any] | str, + recipient: "Agent", + request_reply: bool | None = None, + ) -> None: + """Send a message to another agent. + + Args: + message (dict or str): the message to send. If a dict, it should be + a JSON-serializable and follows the OpenAI's ChatCompletion schema. + recipient (Agent): the recipient of the message. + request_reply (bool): whether to request a reply from the recipient. + """ + ... + + async def a_send( + self, + message: dict[str, Any] | str, + recipient: "Agent", + request_reply: bool | None = None, + ) -> None: + """(Async) Send a message to another agent. + + Args: + message (dict or str): the message to send. If a dict, it should be + a JSON-serializable and follows the OpenAI's ChatCompletion schema. + recipient (Agent): the recipient of the message. + request_reply (bool): whether to request a reply from the recipient. + """ + ... + + def receive( + self, + message: dict[str, Any] | str, + sender: "Agent", + request_reply: bool | None = None, + ) -> None: + """Receive a message from another agent. + + Args: + message (dict or str): the message received. If a dict, it should be + a JSON-serializable and follows the OpenAI's ChatCompletion schema. + sender (Agent): the sender of the message. + request_reply (bool): whether the sender requests a reply. + """ + + async def a_receive( + self, + message: dict[str, Any] | str, + sender: "Agent", + request_reply: bool | None = None, + ) -> None: + """(Async) Receive a message from another agent. + + Args: + message (dict or str): the message received. If a dict, it should be + a JSON-serializable and follows the OpenAI's ChatCompletion schema. + sender (Agent): the sender of the message. + request_reply (bool): whether the sender requests a reply. + """ + ... + + def generate_reply( + self, + messages: list[dict[str, Any]] | None = None, + sender: Literal["Agent"] | None = None, + **kwargs: Any, + ) -> str | dict[str, Any] | None: + """Generate a reply based on the received messages. + + Args: + messages (list[dict]): a list of messages received from other agents. + The messages are dictionaries that are JSON-serializable and + follows the OpenAI's ChatCompletion schema. + sender: sender of an Agent instance. + + Returns: + str or dict or None: the generated reply. If None, no reply is generated. + """ + + async def a_generate_reply( + self, + messages: list[dict[str, Any]] | None = None, + sender: Literal["Agent"] | None = None, + **kwargs: Any, + ) -> str | dict[str, Any] | None: + """(Async) Generate a reply based on the received messages. + + Args: + messages (list[dict]): a list of messages received from other agents. + The messages are dictionaries that are JSON-serializable and + follows the OpenAI's ChatCompletion schema. + sender: sender of an Agent instance. + + Returns: + str or dict or None: the generated reply. If None, no reply is generated. + """ + + +class LLMAgent(Agent, Protocol): + """(In preview) A protocol for an LLM agent.""" + + @property + def system_message(self) -> str: + """The system message of this agent.""" + + def update_system_message(self, system_message: str) -> None: + """Update this agent's system message. + + Args: + system_message (str): system message for inference. + """ + + + +F = TypeVar("F", bound=Callable[..., Any]) + + +class ConversableAgent(LLMAgent): + """(In preview) A class for generic conversable agents which can be configured as assistant or user proxy. + + After receiving each message, the agent will send a reply to the sender unless the msg is a termination msg. + For example, AssistantAgent and UserProxyAgent are subclasses of this class, + configured with different default settings. + + To modify auto reply, override `generate_reply` method. + To disable/enable human response in every turn, set `human_input_mode` to "NEVER" or "ALWAYS". + To modify the way to get human input, override `get_human_input` method. + To modify the way to execute code blocks, single code block, or function call, override `execute_code_blocks`, + `run_code`, and `execute_function` methods respectively. + """ + + DEFAULT_CONFIG = False # False or dict, the default config for llm inference + MAX_CONSECUTIVE_AUTO_REPLY = 100 # maximum number of consecutive auto replies (subject to future change) + + DEFAULT_SUMMARY_PROMPT = "Summarize the takeaway from the conversation. Do not add any introductory phrases." + DEFAULT_SUMMARY_METHOD = "last_msg" + llm_config: Union[Dict, Literal[False]] + + def __init__( + self, + name: str, + system_message: Optional[Union[str, List]] = "You are a helpful AI Assistant.", + is_termination_msg: Optional[Callable[[Dict], bool]] = None, + max_consecutive_auto_reply: Optional[int] = None, + human_input_mode: Literal["ALWAYS", "NEVER", "TERMINATE"] = "TERMINATE", + function_map: Optional[Dict[str, Callable]] = None, + code_execution_config: Union[Dict, Literal[False]] = False, + llm_config: Optional[Union[Dict, Literal[False]]] = None, + default_auto_reply: Union[str, Dict] = "", + description: Optional[str] = None, + chat_messages: Optional[Dict[Agent, List[Dict]]] = None, + silent: Optional[bool] = None, + ): + """ + Args: + name (str): name of the agent. + system_message (str or list): system message for the ChatCompletion inference. + is_termination_msg (function): a function that takes a message in the form of a dictionary + and returns a boolean value indicating if this received message is a termination message. + The dict can contain the following keys: "content", "role", "name", "function_call". + max_consecutive_auto_reply (int): the maximum number of consecutive auto replies. + default to None (no limit provided, class attribute MAX_CONSECUTIVE_AUTO_REPLY will be used as the limit in this case). + When set to 0, no auto reply will be generated. + human_input_mode (str): whether to ask for human inputs every time a message is received. + Possible values are "ALWAYS", "TERMINATE", "NEVER". + (1) When "ALWAYS", the agent prompts for human input every time a message is received. + Under this mode, the conversation stops when the human input is "exit", + or when is_termination_msg is True and there is no human input. + (2) When "TERMINATE", the agent only prompts for human input only when a termination message is received or + the number of auto reply reaches the max_consecutive_auto_reply. + (3) When "NEVER", the agent will never prompt for human input. Under this mode, the conversation stops + when the number of auto reply reaches the max_consecutive_auto_reply or when is_termination_msg is True. + function_map (dict[str, callable]): Mapping function names (passed to openai) to callable functions, also used for tool calls. + code_execution_config (dict or False): config for the code execution. + To disable code execution, set to False. Otherwise, set to a dictionary with the following keys: + - work_dir (Optional, str): The working directory for the code execution. + If None, a default working directory will be used. + The default working directory is the "extensions" directory under + "path_to_autogen". + - use_docker (Optional, list, str or bool): The docker image to use for code execution. + Default is True, which means the code will be executed in a docker container. A default list of images will be used. + If a list or a str of image name(s) is provided, the code will be executed in a docker container + with the first image successfully pulled. + If False, the code will be executed in the current environment. + We strongly recommend using docker for code execution. + - timeout (Optional, int): The maximum execution time in seconds. + - last_n_messages (Experimental, int or str): The number of messages to look back for code execution. + If set to 'auto', it will scan backwards through all messages arriving since the agent last spoke, which is typically the last time execution was attempted. (Default: auto) + llm_config (dict or False or None): llm inference configuration. + Please refer to [OpenAIWrapper.create](/docs/reference/oai/client#create) + for available options. + When using OpenAI or Azure OpenAI endpoints, please specify a non-empty 'model' either in `llm_config` or in each config of 'config_list' in `llm_config`. + To disable llm-based auto reply, set to False. + When set to None, will use self.DEFAULT_CONFIG, which defaults to False. + default_auto_reply (str or dict): default auto reply when no code execution or llm-based reply is generated. + description (str): a short description of the agent. This description is used by other agents + (e.g. the GroupChatManager) to decide when to call upon this agent. (Default: system_message) + chat_messages (dict or None): the previous chat messages that this agent had in the past with other agents. + Can be used to give the agent a memory by providing the chat history. This will allow the agent to + resume previous had conversations. Defaults to an empty chat history. + silent (bool or None): (Experimental) whether to print the message sent. If None, will use the value of + silent in each function. + """ + # we change code_execution_config below and we have to make sure we don't change the input + # in case of UserProxyAgent, without this we could even change the default value {} + code_execution_config = ( + code_execution_config.copy() if hasattr(code_execution_config, "copy") else code_execution_config + ) + + self._name = name + # a dictionary of conversations, default value is list + if chat_messages is None: + self._oai_messages = defaultdict(list) + else: + self._oai_messages = chat_messages + + self._oai_system_message = [{"content": system_message, "role": "system"}] + self._description = description if description is not None else system_message + self._is_termination_msg = ( + is_termination_msg + if is_termination_msg is not None + else (lambda x: content_str(x.get("content")) == "TERMINATE") + ) + self.silent = silent + # Take a copy to avoid modifying the given dict + if isinstance(llm_config, dict): + try: + llm_config = copy.deepcopy(llm_config) + except TypeError as e: + raise TypeError( + "Please implement __deepcopy__ method for each value class in llm_config to support deepcopy." + " Refer to the docs for more details: https://microsoft.github.io/autogen/docs/topics/llm_configuration#adding-http-client-in-llm_config-for-proxy" + ) from e + + self._validate_llm_config(llm_config) + + if logging_enabled(): + log_new_agent(self, locals()) + + # Initialize standalone client cache object. + self.client_cache = None + + self.human_input_mode = human_input_mode + self._max_consecutive_auto_reply = ( + max_consecutive_auto_reply if max_consecutive_auto_reply is not None else self.MAX_CONSECUTIVE_AUTO_REPLY + ) + self._consecutive_auto_reply_counter = defaultdict(int) + self._max_consecutive_auto_reply_dict = defaultdict(self.max_consecutive_auto_reply) + self._function_map = ( + {} + if function_map is None + else {name: callable for name, callable in function_map.items() if self._assert_valid_name(name)} + ) + self._default_auto_reply = default_auto_reply + self._reply_func_list = [] + self._human_input = [] + self.reply_at_receive = defaultdict(bool) + self.register_reply([Agent, None], ConversableAgent.generate_oai_reply) + self.register_reply([Agent, None], ConversableAgent.a_generate_oai_reply, ignore_async_in_sync_chat=True) + + # Setting up code execution. + # Do not register code execution reply if code execution is disabled. + if code_execution_config is not False: + # If code_execution_config is None, set it to an empty dict. + if code_execution_config is None: + warnings.warn( + "Using None to signal a default code_execution_config is deprecated. " + "Use {} to use default or False to disable code execution.", + stacklevel=2, + ) + code_execution_config = {} + if not isinstance(code_execution_config, dict): + raise ValueError("code_execution_config must be a dict or False.") + + # We have got a valid code_execution_config. + self._code_execution_config = code_execution_config + + if self._code_execution_config.get("executor") is not None: + if "use_docker" in self._code_execution_config: + raise ValueError( + "'use_docker' in code_execution_config is not valid when 'executor' is set. Use the appropriate arg in the chosen executor instead." + ) + + if "work_dir" in self._code_execution_config: + raise ValueError( + "'work_dir' in code_execution_config is not valid when 'executor' is set. Use the appropriate arg in the chosen executor instead." + ) + + if "timeout" in self._code_execution_config: + raise ValueError( + "'timeout' in code_execution_config is not valid when 'executor' is set. Use the appropriate arg in the chosen executor instead." + ) + + # Use the new code executor. + self._code_executor = CodeExecutorFactory.create(self._code_execution_config) + self.register_reply([Agent, None], ConversableAgent._generate_code_execution_reply_using_executor) + else: + # Legacy code execution using code_utils. + use_docker = self._code_execution_config.get("use_docker", None) + use_docker = decide_use_docker(use_docker) + check_can_use_docker_or_throw(use_docker) + self._code_execution_config["use_docker"] = use_docker + self.register_reply([Agent, None], ConversableAgent.generate_code_execution_reply) + else: + # Code execution is disabled. + self._code_execution_config = False + + self.register_reply([Agent, None], ConversableAgent.generate_tool_calls_reply) + self.register_reply([Agent, None], ConversableAgent.a_generate_tool_calls_reply, ignore_async_in_sync_chat=True) + self.register_reply([Agent, None], ConversableAgent.generate_function_call_reply) + self.register_reply( + [Agent, None], ConversableAgent.a_generate_function_call_reply, ignore_async_in_sync_chat=True + ) + self.register_reply([Agent, None], ConversableAgent.check_termination_and_human_reply) + self.register_reply( + [Agent, None], ConversableAgent.a_check_termination_and_human_reply, ignore_async_in_sync_chat=True + ) + + # Registered hooks are kept in lists, indexed by hookable method, to be called in their order of registration. + # New hookable methods should be added to this list as required to support new agent capabilities. + self.hook_lists: Dict[str, List[Union[Callable, Callable[..., Coroutine]]]] = { + "process_last_received_message": [], + "a_process_last_received_message": [], + "process_all_messages_before_reply": [], + "a_process_all_messages_before_reply": [], + "process_message_before_send": [], + "a_process_message_before_send": [], + } + + def _validate_llm_config(self, llm_config): + assert llm_config in (None, False) or isinstance( + llm_config, dict + ), "llm_config must be a dict or False or None." + if llm_config is None: + llm_config = self.DEFAULT_CONFIG + self.llm_config = self.DEFAULT_CONFIG if llm_config is None else llm_config + # TODO: more complete validity check + if self.llm_config in [{}, {"config_list": []}, {"config_list": [{"model": ""}]}]: + raise ValueError( + "When using OpenAI or Azure OpenAI endpoints, specify a non-empty 'model' either in 'llm_config' or in each config of 'config_list'." + ) + self.client = None if self.llm_config is False else OpenAIWrapper(**self.llm_config) + + @staticmethod + def _is_silent(agent: Agent, silent: Optional[bool] = False) -> bool: + return agent.silent if agent.silent is not None else silent + + @property + def name(self) -> str: + """Get the name of the agent.""" + return self._name + + @property + def description(self) -> str: + """Get the description of the agent.""" + return self._description + + @description.setter + def description(self, description: str): + """Set the description of the agent.""" + self._description = description + + @property + def code_executor(self) -> Optional[CodeExecutor]: + """The code executor used by this agent. Returns None if code execution is disabled.""" + if not hasattr(self, "_code_executor"): + return None + return self._code_executor + + def register_reply( + self, + trigger: Union[Type[Agent], str, Agent, Callable[[Agent], bool], List], + reply_func: Callable, + position: int = 0, + config: Optional[Any] = None, + reset_config: Optional[Callable] = None, + *, + ignore_async_in_sync_chat: bool = False, + remove_other_reply_funcs: bool = False, + ): + """Register a reply function. + + The reply function will be called when the trigger matches the sender. + The function registered later will be checked earlier by default. + To change the order, set the position to a positive integer. + + Both sync and async reply functions can be registered. The sync reply function will be triggered + from both sync and async chats. However, an async reply function will only be triggered from async + chats (initiated with `ConversableAgent.a_initiate_chat`). If an `async` reply function is registered + and a chat is initialized with a sync function, `ignore_async_in_sync_chat` determines the behaviour as follows: + if `ignore_async_in_sync_chat` is set to `False` (default value), an exception will be raised, and + if `ignore_async_in_sync_chat` is set to `True`, the reply function will be ignored. + + Args: + trigger (Agent class, str, Agent instance, callable, or list): the trigger. + If a class is provided, the reply function will be called when the sender is an instance of the class. + If a string is provided, the reply function will be called when the sender's name matches the string. + If an agent instance is provided, the reply function will be called when the sender is the agent instance. + If a callable is provided, the reply function will be called when the callable returns True. + If a list is provided, the reply function will be called when any of the triggers in the list is activated. + If None is provided, the reply function will be called only when the sender is None. + Note: Be sure to register `None` as a trigger if you would like to trigger an auto-reply function with non-empty messages and `sender=None`. + reply_func (Callable): the reply function. + The function takes a recipient agent, a list of messages, a sender agent and a config as input and returns a reply message. + + ```python + def reply_func( + recipient: ConversableAgent, + messages: Optional[List[Dict]] = None, + sender: Optional[Agent] = None, + config: Optional[Any] = None, + ) -> Tuple[bool, Union[str, Dict, None]]: + ``` + position (int): the position of the reply function in the reply function list. + The function registered later will be checked earlier by default. + To change the order, set the position to a positive integer. + config (Any): the config to be passed to the reply function. + When an agent is reset, the config will be reset to the original value. + reset_config (Callable): the function to reset the config. + The function returns None. Signature: ```def reset_config(config: Any)``` + ignore_async_in_sync_chat (bool): whether to ignore the async reply function in sync chats. If `False`, an exception + will be raised if an async reply function is registered and a chat is initialized with a sync + function. + remove_other_reply_funcs (bool): whether to remove other reply functions when registering this reply function. + """ + if not isinstance(trigger, (type, str, Agent, Callable, list)): + raise ValueError("trigger must be a class, a string, an agent, a callable or a list.") + if remove_other_reply_funcs: + self._reply_func_list.clear() + self._reply_func_list.insert( + position, + { + "trigger": trigger, + "reply_func": reply_func, + "config": copy.copy(config), + "init_config": config, + "reset_config": reset_config, + "ignore_async_in_sync_chat": ignore_async_in_sync_chat and inspect.iscoroutinefunction(reply_func), + }, + ) + + def replace_reply_func(self, old_reply_func: Callable, new_reply_func: Callable): + """Replace a registered reply function with a new one. + + Args: + old_reply_func (Callable): the old reply function to be replaced. + new_reply_func (Callable): the new reply function to replace the old one. + """ + for f in self._reply_func_list: + if f["reply_func"] == old_reply_func: + f["reply_func"] = new_reply_func + + @staticmethod + def _get_chats_to_run( + chat_queue: List[Dict[str, Any]], recipient: Agent, messages: Union[str, Callable], sender: Agent, config: Any + ) -> List[Dict[str, Any]]: + """A simple chat reply function. + This function initiate one or a sequence of chats between the "recipient" and the agents in the + chat_queue. + + It extracts and returns a summary from the nested chat based on the "summary_method" in each chat in chat_queue. + + Returns: + Tuple[bool, str]: A tuple where the first element indicates the completion of the chat, and the second element contains the summary of the last chat if any chats were initiated. + """ + last_msg = messages[-1].get("content") + chat_to_run = [] + for i, c in enumerate(chat_queue): + current_c = c.copy() + if current_c.get("sender") is None: + current_c["sender"] = recipient + message = current_c.get("message") + # If message is not provided in chat_queue, we by default use the last message from the original chat history as the first message in this nested chat (for the first chat in the chat queue). + # NOTE: This setting is prone to change. + if message is None and i == 0: + message = last_msg + if callable(message): + message = message(recipient, messages, sender, config) + # We only run chat that has a valid message. NOTE: This is prone to change dependin on applications. + if message: + current_c["message"] = message + chat_to_run.append(current_c) + return chat_to_run + + @staticmethod + def _summary_from_nested_chats( + chat_queue: List[Dict[str, Any]], recipient: Agent, messages: Union[str, Callable], sender: Agent, config: Any + ) -> Tuple[bool, Union[str, None]]: + """A simple chat reply function. + This function initiate one or a sequence of chats between the "recipient" and the agents in the + chat_queue. + + It extracts and returns a summary from the nested chat based on the "summary_method" in each chat in chat_queue. + + Returns: + Tuple[bool, str]: A tuple where the first element indicates the completion of the chat, and the second element contains the summary of the last chat if any chats were initiated. + """ + chat_to_run = ConversableAgent._get_chats_to_run(chat_queue, recipient, messages, sender, config) + if not chat_to_run: + return True, None + res = initiate_chats(chat_to_run) + return True, res[-1].summary + + @staticmethod + async def _a_summary_from_nested_chats( + chat_queue: List[Dict[str, Any]], recipient: Agent, messages: Union[str, Callable], sender: Agent, config: Any + ) -> Tuple[bool, Union[str, None]]: + """A simple chat reply function. + This function initiate one or a sequence of chats between the "recipient" and the agents in the + chat_queue. + + It extracts and returns a summary from the nested chat based on the "summary_method" in each chat in chat_queue. + + Returns: + Tuple[bool, str]: A tuple where the first element indicates the completion of the chat, and the second element contains the summary of the last chat if any chats were initiated. + """ + chat_to_run = ConversableAgent._get_chats_to_run(chat_queue, recipient, messages, sender, config) + if not chat_to_run: + return True, None + res = await a_initiate_chats(chat_to_run) + index_of_last_chat = chat_to_run[-1]["chat_id"] + return True, res[index_of_last_chat].summary + + def register_nested_chats( + self, + chat_queue: List[Dict[str, Any]], + trigger: Union[Type[Agent], str, Agent, Callable[[Agent], bool], List], + reply_func_from_nested_chats: Union[str, Callable] = "summary_from_nested_chats", + position: int = 2, + use_async: Union[bool, None] = None, + **kwargs, + ) -> None: + """Register a nested chat reply function. + Args: + chat_queue (list): a list of chat objects to be initiated. If use_async is used, then all messages in chat_queue must have a chat-id associated with them. + trigger (Agent class, str, Agent instance, callable, or list): refer to `register_reply` for details. + reply_func_from_nested_chats (Callable, str): the reply function for the nested chat. + The function takes a chat_queue for nested chat, recipient agent, a list of messages, a sender agent and a config as input and returns a reply message. + Default to "summary_from_nested_chats", which corresponds to a built-in reply function that get summary from the nested chat_queue. + ```python + def reply_func_from_nested_chats( + chat_queue: List[Dict], + recipient: ConversableAgent, + messages: Optional[List[Dict]] = None, + sender: Optional[Agent] = None, + config: Optional[Any] = None, + ) -> Tuple[bool, Union[str, Dict, None]]: + ``` + position (int): Ref to `register_reply` for details. Default to 2. It means we first check the termination and human reply, then check the registered nested chat reply. + use_async: Uses a_initiate_chats internally to start nested chats. If the original chat is initiated with a_initiate_chats, you may set this to true so nested chats do not run in sync. + kwargs: Ref to `register_reply` for details. + """ + if use_async: + for chat in chat_queue: + if chat.get("chat_id") is None: + raise ValueError("chat_id is required for async nested chats") + + if use_async: + if reply_func_from_nested_chats == "summary_from_nested_chats": + reply_func_from_nested_chats = self._a_summary_from_nested_chats + if not callable(reply_func_from_nested_chats) or not inspect.iscoroutinefunction( + reply_func_from_nested_chats + ): + raise ValueError("reply_func_from_nested_chats must be a callable and a coroutine") + + async def wrapped_reply_func(recipient, messages=None, sender=None, config=None): + return await reply_func_from_nested_chats(chat_queue, recipient, messages, sender, config) + + else: + if reply_func_from_nested_chats == "summary_from_nested_chats": + reply_func_from_nested_chats = self._summary_from_nested_chats + if not callable(reply_func_from_nested_chats): + raise ValueError("reply_func_from_nested_chats must be a callable") + + def wrapped_reply_func(recipient, messages=None, sender=None, config=None): + return reply_func_from_nested_chats(chat_queue, recipient, messages, sender, config) + + functools.update_wrapper(wrapped_reply_func, reply_func_from_nested_chats) + + self.register_reply( + trigger, + wrapped_reply_func, + position, + kwargs.get("config"), + kwargs.get("reset_config"), + ignore_async_in_sync_chat=( + not use_async if use_async is not None else kwargs.get("ignore_async_in_sync_chat") + ), + ) + + @property + def system_message(self) -> str: + """Return the system message.""" + return self._oai_system_message[0]["content"] + + def update_system_message(self, system_message: str) -> None: + """Update the system message. + + Args: + system_message (str): system message for the ChatCompletion inference. + """ + self._oai_system_message[0]["content"] = system_message + + def update_max_consecutive_auto_reply(self, value: int, sender: Optional[Agent] = None): + """Update the maximum number of consecutive auto replies. + + Args: + value (int): the maximum number of consecutive auto replies. + sender (Agent): when the sender is provided, only update the max_consecutive_auto_reply for that sender. + """ + if sender is None: + self._max_consecutive_auto_reply = value + for k in self._max_consecutive_auto_reply_dict: + self._max_consecutive_auto_reply_dict[k] = value + else: + self._max_consecutive_auto_reply_dict[sender] = value + + def max_consecutive_auto_reply(self, sender: Optional[Agent] = None) -> int: + """The maximum number of consecutive auto replies.""" + return self._max_consecutive_auto_reply if sender is None else self._max_consecutive_auto_reply_dict[sender] + + @property + def chat_messages(self) -> Dict[Agent, List[Dict]]: + """A dictionary of conversations from agent to list of messages.""" + return self._oai_messages + + def chat_messages_for_summary(self, agent: Agent) -> List[Dict]: + """A list of messages as a conversation to summarize.""" + return self._oai_messages[agent] + + def last_message(self, agent: Optional[Agent] = None) -> Optional[Dict]: + """The last message exchanged with the agent. + + Args: + agent (Agent): The agent in the conversation. + If None and more than one agent's conversations are found, an error will be raised. + If None and only one conversation is found, the last message of the only conversation will be returned. + + Returns: + The last message exchanged with the agent. + """ + if agent is None: + n_conversations = len(self._oai_messages) + if n_conversations == 0: + return None + if n_conversations == 1: + for conversation in self._oai_messages.values(): + return conversation[-1] + raise ValueError("More than one conversation is found. Please specify the sender to get the last message.") + if agent not in self._oai_messages.keys(): + raise KeyError( + f"The agent '{agent.name}' is not present in any conversation. No history available for this agent." + ) + return self._oai_messages[agent][-1] + + @property + def use_docker(self) -> Union[bool, str, None]: + """Bool value of whether to use docker to execute the code, + or str value of the docker image name to use, or None when code execution is disabled. + """ + return None if self._code_execution_config is False else self._code_execution_config.get("use_docker") + + @staticmethod + def _message_to_dict(message: Union[Dict, str]) -> Dict: + """Convert a message to a dictionary. + + The message can be a string or a dictionary. The string will be put in the "content" field of the new dictionary. + """ + if isinstance(message, str): + return {"content": message} + elif isinstance(message, dict): + return message + else: + return dict(message) + + @staticmethod + def _normalize_name(name): + """ + LLMs sometimes ask functions while ignoring their own format requirements, this function should be used to replace invalid characters with "_". + + Prefer _assert_valid_name for validating user configuration or input + """ + return re.sub(r"[^a-zA-Z0-9_-]", "_", name)[:64] + + @staticmethod + def _assert_valid_name(name): + """ + Ensure that configured names are valid, raises ValueError if not. + + For munging LLM responses use _normalize_name to ensure LLM specified names don't break the API. + """ + if not re.match(r"^[a-zA-Z0-9_-]+$", name): + raise ValueError(f"Invalid name: {name}. Only letters, numbers, '_' and '-' are allowed.") + if len(name) > 64: + raise ValueError(f"Invalid name: {name}. Name must be less than 64 characters.") + return name + + def _append_oai_message(self, message: Union[Dict, str], role, conversation_id: Agent, is_sending: bool) -> bool: + """Append a message to the ChatCompletion conversation. + + If the message received is a string, it will be put in the "content" field of the new dictionary. + If the message received is a dictionary but does not have any of the three fields "content", "function_call", or "tool_calls", + this message is not a valid ChatCompletion message. + If only "function_call" or "tool_calls" is provided, "content" will be set to None if not provided, and the role of the message will be forced "assistant". + + Args: + message (dict or str): message to be appended to the ChatCompletion conversation. + role (str): role of the message, can be "assistant" or "function". + conversation_id (Agent): id of the conversation, should be the recipient or sender. + is_sending (bool): If the agent (aka self) is sending to the conversation_id agent, otherwise receiving. + + Returns: + bool: whether the message is appended to the ChatCompletion conversation. + """ + message = self._message_to_dict(message) + # create oai message to be appended to the oai conversation that can be passed to oai directly. + oai_message = { + k: message[k] + for k in ("content", "function_call", "tool_calls", "tool_responses", "tool_call_id", "name", "context") + if k in message and message[k] is not None + } + if "content" not in oai_message: + if "function_call" in oai_message or "tool_calls" in oai_message: + oai_message["content"] = None # if only function_call is provided, content will be set to None. + else: + return False + + if message.get("role") in ["function", "tool"]: + oai_message["role"] = message.get("role") + elif "override_role" in message: + # If we have a direction to override the role then set the + # role accordingly. Used to customise the role for the + # select speaker prompt. + oai_message["role"] = message.get("override_role") + else: + oai_message["role"] = role + + if oai_message.get("function_call", False) or oai_message.get("tool_calls", False): + oai_message["role"] = "assistant" # only messages with role 'assistant' can have a function call. + elif "name" not in oai_message: + # If we don't have a name field, append it + if is_sending: + oai_message["name"] = self.name + else: + oai_message["name"] = conversation_id.name + + self._oai_messages[conversation_id].append(oai_message) + + return True + + def _process_message_before_send( + self, message: Union[Dict, str], recipient: Agent, silent: bool + ) -> Union[Dict, str]: + """Process the message before sending it to the recipient.""" + hook_list = self.hook_lists["process_message_before_send"] + for hook in hook_list: + if inspect.iscoroutinefunction(hook): + continue + message = hook( + sender=self, message=message, recipient=recipient, silent=ConversableAgent._is_silent(self, silent) + ) + return message + + async def _a_process_message_before_send( + self, message: Union[Dict, str], recipient: Agent, silent: bool + ) -> Union[Dict, str]: + """(async) Process the message before sending it to the recipient.""" + hook_list = self.hook_lists["a_process_message_before_send"] + for hook in hook_list: + if not inspect.iscoroutinefunction(hook): + continue + message = await hook(sender=self, message=message, recipient=recipient, silent=silent) + return message + + def send( + self, + message: Union[Dict, str], + recipient: Agent, + request_reply: Optional[bool] = None, + silent: Optional[bool] = False, + ): + """Send a message to another agent. + + Args: + message (dict or str): message to be sent. + The message could contain the following fields: + - content (str or List): Required, the content of the message. (Can be None) + - function_call (str): the name of the function to be called. + - name (str): the name of the function to be called. + - role (str): the role of the message, any role that is not "function" + will be modified to "assistant". + - context (dict): the context of the message, which will be passed to + [OpenAIWrapper.create](../oai/client#create). + For example, one agent can send a message A as: + ```python + { + "content": lambda context: context["use_tool_msg"], + "context": { + "use_tool_msg": "Use tool X if they are relevant." + } + } + ``` + Next time, one agent can send a message B with a different "use_tool_msg". + Then the content of message A will be refreshed to the new "use_tool_msg". + So effectively, this provides a way for an agent to send a "link" and modify + the content of the "link" later. + recipient (Agent): the recipient of the message. + request_reply (bool or None): whether to request a reply from the recipient. + silent (bool or None): (Experimental) whether to print the message sent. + + Raises: + ValueError: if the message can't be converted into a valid ChatCompletion message. + """ + message = self._process_message_before_send(message, recipient, ConversableAgent._is_silent(self, silent)) + # When the agent composes and sends the message, the role of the message is "assistant" + # unless it's "function". + valid = self._append_oai_message(message, "assistant", recipient, is_sending=True) + if valid: + recipient.receive(message, self, request_reply, silent) + else: + raise ValueError( + "Message can't be converted into a valid ChatCompletion message. Either content or function_call must be provided." + ) + + async def a_send( + self, + message: Union[Dict, str], + recipient: Agent, + request_reply: Optional[bool] = None, + silent: Optional[bool] = False, + ): + """(async) Send a message to another agent. + + Args: + message (dict or str): message to be sent. + The message could contain the following fields: + - content (str or List): Required, the content of the message. (Can be None) + - function_call (str): the name of the function to be called. + - name (str): the name of the function to be called. + - role (str): the role of the message, any role that is not "function" + will be modified to "assistant". + - context (dict): the context of the message, which will be passed to + [OpenAIWrapper.create](../oai/client#create). + For example, one agent can send a message A as: + ```python + { + "content": lambda context: context["use_tool_msg"], + "context": { + "use_tool_msg": "Use tool X if they are relevant." + } + } + ``` + Next time, one agent can send a message B with a different "use_tool_msg". + Then the content of message A will be refreshed to the new "use_tool_msg". + So effectively, this provides a way for an agent to send a "link" and modify + the content of the "link" later. + recipient (Agent): the recipient of the message. + request_reply (bool or None): whether to request a reply from the recipient. + silent (bool or None): (Experimental) whether to print the message sent. + + Raises: + ValueError: if the message can't be converted into a valid ChatCompletion message. + """ + message = await self._a_process_message_before_send( + message, recipient, ConversableAgent._is_silent(self, silent) + ) + # When the agent composes and sends the message, the role of the message is "assistant" + # unless it's "function". + valid = self._append_oai_message(message, "assistant", recipient, is_sending=True) + if valid: + await recipient.a_receive(message, self, request_reply, silent) + else: + raise ValueError( + "Message can't be converted into a valid ChatCompletion message. Either content or function_call must be provided." + ) + + def _print_received_message(self, message: Union[Dict, str], sender: Agent): + iostream = IOStream.get_default() + # print the message received + iostream.print(colored(sender.name, "yellow"), "(to", f"{self.name}):\n", flush=True) + message = self._message_to_dict(message) + + if message.get("tool_responses"): # Handle tool multi-call responses + for tool_response in message["tool_responses"]: + self._print_received_message(tool_response, sender) + if message.get("role") == "tool": + return # If role is tool, then content is just a concatenation of all tool_responses + + if message.get("role") in ["function", "tool"]: + if message["role"] == "function": + id_key = "name" + else: + id_key = "tool_call_id" + id = message.get(id_key, "No id found") + func_print = f"***** Response from calling {message['role']} ({id}) *****" + iostream.print(colored(func_print, "green"), flush=True) + iostream.print(message["content"], flush=True) + iostream.print(colored("*" * len(func_print), "green"), flush=True) + else: + content = message.get("content") + if content is not None: + if "context" in message: + content = OpenAIWrapper.instantiate( + content, + message["context"], + self.llm_config and self.llm_config.get("allow_format_str_template", False), + ) + iostream.print(content_str(content), flush=True) + if "function_call" in message and message["function_call"]: + function_call = dict(message["function_call"]) + func_print = ( + f"***** Suggested function call: {function_call.get('name', '(No function name found)')} *****" + ) + iostream.print(colored(func_print, "green"), flush=True) + iostream.print( + "Arguments: \n", + function_call.get("arguments", "(No arguments found)"), + flush=True, + sep="", + ) + iostream.print(colored("*" * len(func_print), "green"), flush=True) + if "tool_calls" in message and message["tool_calls"]: + for tool_call in message["tool_calls"]: + id = tool_call.get("id", "No tool call id found") + function_call = dict(tool_call.get("function", {})) + func_print = f"***** Suggested tool call ({id}): {function_call.get('name', '(No function name found)')} *****" + iostream.print(colored(func_print, "green"), flush=True) + iostream.print( + "Arguments: \n", + function_call.get("arguments", "(No arguments found)"), + flush=True, + sep="", + ) + iostream.print(colored("*" * len(func_print), "green"), flush=True) + + iostream.print("\n", "-" * 80, flush=True, sep="") + + def _process_received_message(self, message: Union[Dict, str], sender: Agent, silent: bool): + # When the agent receives a message, the role of the message is "user". (If 'role' exists and is 'function', it will remain unchanged.) + valid = self._append_oai_message(message, "user", sender, is_sending=False) + if logging_enabled(): + log_event(self, "received_message", message=message, sender=sender.name, valid=valid) + + if not valid: + raise ValueError( + "Received message can't be converted into a valid ChatCompletion message. Either content or function_call must be provided." + ) + + if not ConversableAgent._is_silent(sender, silent): + self._print_received_message(message, sender) + + def receive( + self, + message: Union[Dict, str], + sender: Agent, + request_reply: Optional[bool] = None, + silent: Optional[bool] = False, + ): + """Receive a message from another agent. + + Once a message is received, this function sends a reply to the sender or stop. + The reply can be generated automatically or entered manually by a human. + + Args: + message (dict or str): message from the sender. If the type is dict, it may contain the following reserved fields (either content or function_call need to be provided). + 1. "content": content of the message, can be None. + 2. "function_call": a dictionary containing the function name and arguments. (deprecated in favor of "tool_calls") + 3. "tool_calls": a list of dictionaries containing the function name and arguments. + 4. "role": role of the message, can be "assistant", "user", "function", "tool". + This field is only needed to distinguish between "function" or "assistant"/"user". + 5. "name": In most cases, this field is not needed. When the role is "function", this field is needed to indicate the function name. + 6. "context" (dict): the context of the message, which will be passed to + [OpenAIWrapper.create](../oai/client#create). + sender: sender of an Agent instance. + request_reply (bool or None): whether a reply is requested from the sender. + If None, the value is determined by `self.reply_at_receive[sender]`. + silent (bool or None): (Experimental) whether to print the message received. + + Raises: + ValueError: if the message can't be converted into a valid ChatCompletion message. + """ + self._process_received_message(message, sender, silent) + if request_reply is False or request_reply is None and self.reply_at_receive[sender] is False: + return + reply = self.generate_reply(messages=self.chat_messages[sender], sender=sender) + if reply is not None: + self.send(reply, sender, silent=silent) + + async def a_receive( + self, + message: Union[Dict, str], + sender: Agent, + request_reply: Optional[bool] = None, + silent: Optional[bool] = False, + ): + """(async) Receive a message from another agent. + + Once a message is received, this function sends a reply to the sender or stop. + The reply can be generated automatically or entered manually by a human. + + Args: + message (dict or str): message from the sender. If the type is dict, it may contain the following reserved fields (either content or function_call need to be provided). + 1. "content": content of the message, can be None. + 2. "function_call": a dictionary containing the function name and arguments. (deprecated in favor of "tool_calls") + 3. "tool_calls": a list of dictionaries containing the function name and arguments. + 4. "role": role of the message, can be "assistant", "user", "function". + This field is only needed to distinguish between "function" or "assistant"/"user". + 5. "name": In most cases, this field is not needed. When the role is "function", this field is needed to indicate the function name. + 6. "context" (dict): the context of the message, which will be passed to + [OpenAIWrapper.create](../oai/client#create). + sender: sender of an Agent instance. + request_reply (bool or None): whether a reply is requested from the sender. + If None, the value is determined by `self.reply_at_receive[sender]`. + silent (bool or None): (Experimental) whether to print the message received. + + Raises: + ValueError: if the message can't be converted into a valid ChatCompletion message. + """ + self._process_received_message(message, sender, silent) + if request_reply is False or request_reply is None and self.reply_at_receive[sender] is False: + return + reply = await self.a_generate_reply(sender=sender) + if reply is not None: + await self.a_send(reply, sender, silent=silent) + + def _prepare_chat( + self, + recipient: "ConversableAgent", + clear_history: bool, + prepare_recipient: bool = True, + reply_at_receive: bool = True, + ) -> None: + self.reset_consecutive_auto_reply_counter(recipient) + self.reply_at_receive[recipient] = reply_at_receive + if clear_history: + self.clear_history(recipient) + self._human_input = [] + if prepare_recipient: + recipient._prepare_chat(self, clear_history, False, reply_at_receive) + + def _raise_exception_on_async_reply_functions(self) -> None: + """Raise an exception if any async reply functions are registered. + + Raises: + RuntimeError: if any async reply functions are registered. + """ + reply_functions = { + f["reply_func"] for f in self._reply_func_list if not f.get("ignore_async_in_sync_chat", False) + } + + async_reply_functions = [f for f in reply_functions if inspect.iscoroutinefunction(f)] + if async_reply_functions: + msg = ( + "Async reply functions can only be used with ConversableAgent.a_initiate_chat(). The following async reply functions are found: " + + ", ".join([f.__name__ for f in async_reply_functions]) + ) + + raise RuntimeError(msg) + + def initiate_chat( + self, + recipient: "ConversableAgent", + clear_history: bool = True, + silent: Optional[bool] = False, + cache: Optional[AbstractCache] = None, + max_turns: Optional[int] = None, + summary_method: Optional[Union[str, Callable]] = DEFAULT_SUMMARY_METHOD, + summary_args: Optional[dict] = {}, + message: Optional[Union[Dict, str, Callable]] = None, + **kwargs, + ) -> ChatResult: + """Initiate a chat with the recipient agent. + + Reset the consecutive auto reply counter. + If `clear_history` is True, the chat history with the recipient agent will be cleared. + + + Args: + recipient: the recipient agent. + clear_history (bool): whether to clear the chat history with the agent. Default is True. + silent (bool or None): (Experimental) whether to print the messages for this conversation. Default is False. + cache (AbstractCache or None): the cache client to be used for this conversation. Default is None. + max_turns (int or None): the maximum number of turns for the chat between the two agents. One turn means one conversation round trip. Note that this is different from + [max_consecutive_auto_reply](#max_consecutive_auto_reply) which is the maximum number of consecutive auto replies; and it is also different from [max_rounds in GroupChat](./groupchat#groupchat-objects) which is the maximum number of rounds in a group chat session. + If max_turns is set to None, the chat will continue until a termination condition is met. Default is None. + summary_method (str or callable): a method to get a summary from the chat. Default is DEFAULT_SUMMARY_METHOD, i.e., "last_msg". + + Supported strings are "last_msg" and "reflection_with_llm": + - when set to "last_msg", it returns the last message of the dialog as the summary. + - when set to "reflection_with_llm", it returns a summary extracted using an llm client. + `llm_config` must be set in either the recipient or sender. + + A callable summary_method should take the recipient and sender agent in a chat as input and return a string of summary. E.g., + + ```python + def my_summary_method( + sender: ConversableAgent, + recipient: ConversableAgent, + summary_args: dict, + ): + return recipient.last_message(sender)["content"] + ``` + summary_args (dict): a dictionary of arguments to be passed to the summary_method. + One example key is "summary_prompt", and value is a string of text used to prompt a LLM-based agent (the sender or receiver agent) to reflect + on the conversation and extract a summary when summary_method is "reflection_with_llm". + The default summary_prompt is DEFAULT_SUMMARY_PROMPT, i.e., "Summarize takeaway from the conversation. Do not add any introductory phrases. If the intended request is NOT properly addressed, please point it out." + Another available key is "summary_role", which is the role of the message sent to the agent in charge of summarizing. Default is "system". + message (str, dict or Callable): the initial message to be sent to the recipient. Needs to be provided. Otherwise, input() will be called to get the initial message. + - If a string or a dict is provided, it will be used as the initial message. `generate_init_message` is called to generate the initial message for the agent based on this string and the context. + If dict, it may contain the following reserved fields (either content or tool_calls need to be provided). + + 1. "content": content of the message, can be None. + 2. "function_call": a dictionary containing the function name and arguments. (deprecated in favor of "tool_calls") + 3. "tool_calls": a list of dictionaries containing the function name and arguments. + 4. "role": role of the message, can be "assistant", "user", "function". + This field is only needed to distinguish between "function" or "assistant"/"user". + 5. "name": In most cases, this field is not needed. When the role is "function", this field is needed to indicate the function name. + 6. "context" (dict): the context of the message, which will be passed to + [OpenAIWrapper.create](../oai/client#create). + + - If a callable is provided, it will be called to get the initial message in the form of a string or a dict. + If the returned type is dict, it may contain the reserved fields mentioned above. + + Example of a callable message (returning a string): + + ```python + def my_message(sender: ConversableAgent, recipient: ConversableAgent, context: dict) -> Union[str, Dict]: + carryover = context.get("carryover", "") + if isinstance(message, list): + carryover = carryover[-1] + final_msg = "Write a blogpost." + "\\nContext: \\n" + carryover + return final_msg + ``` + + Example of a callable message (returning a dict): + + ```python + def my_message(sender: ConversableAgent, recipient: ConversableAgent, context: dict) -> Union[str, Dict]: + final_msg = {} + carryover = context.get("carryover", "") + if isinstance(message, list): + carryover = carryover[-1] + final_msg["content"] = "Write a blogpost." + "\\nContext: \\n" + carryover + final_msg["context"] = {"prefix": "Today I feel"} + return final_msg + ``` + **kwargs: any additional information. It has the following reserved fields: + - "carryover": a string or a list of string to specify the carryover information to be passed to this chat. + If provided, we will combine this carryover (by attaching a "context: " string and the carryover content after the message content) with the "message" content when generating the initial chat + message in `generate_init_message`. + - "verbose": a boolean to specify whether to print the message and carryover in a chat. Default is False. + + Raises: + RuntimeError: if any async reply functions are registered and not ignored in sync chat. + + Returns: + ChatResult: an ChatResult object. + """ + _chat_info = locals().copy() + _chat_info["sender"] = self + consolidate_chat_info(_chat_info, uniform_sender=self) + for agent in [self, recipient]: + agent._raise_exception_on_async_reply_functions() + agent.previous_cache = agent.client_cache + agent.client_cache = cache + if isinstance(max_turns, int): + self._prepare_chat(recipient, clear_history, reply_at_receive=False) + for _ in range(max_turns): + if _ == 0: + if isinstance(message, Callable): + msg2send = message(_chat_info["sender"], _chat_info["recipient"], kwargs) + else: + msg2send = self.generate_init_message(message, **kwargs) + else: + msg2send = self.generate_reply(messages=self.chat_messages[recipient], sender=recipient) + if msg2send is None: + break + self.send(msg2send, recipient, request_reply=True, silent=silent) + else: + self._prepare_chat(recipient, clear_history) + if isinstance(message, Callable): + msg2send = message(_chat_info["sender"], _chat_info["recipient"], kwargs) + else: + msg2send = self.generate_init_message(message, **kwargs) + self.send(msg2send, recipient, silent=silent) + summary = self._summarize_chat( + summary_method, + summary_args, + recipient, + cache=cache, + ) + for agent in [self, recipient]: + agent.client_cache = agent.previous_cache + agent.previous_cache = None + chat_result = ChatResult( + chat_history=self.chat_messages[recipient], + summary=summary, + cost=gather_usage_summary([self, recipient]), + human_input=self._human_input, + ) + return chat_result + + async def a_initiate_chat( + self, + recipient: "ConversableAgent", + clear_history: bool = True, + silent: Optional[bool] = False, + cache: Optional[AbstractCache] = None, + max_turns: Optional[int] = None, + summary_method: Optional[Union[str, Callable]] = DEFAULT_SUMMARY_METHOD, + summary_args: Optional[dict] = {}, + message: Optional[Union[str, Callable]] = None, + **kwargs, + ) -> ChatResult: + """(async) Initiate a chat with the recipient agent. + + Reset the consecutive auto reply counter. + If `clear_history` is True, the chat history with the recipient agent will be cleared. + `a_generate_init_message` is called to generate the initial message for the agent. + + Args: Please refer to `initiate_chat`. + + Returns: + ChatResult: an ChatResult object. + """ + _chat_info = locals().copy() + _chat_info["sender"] = self + consolidate_chat_info(_chat_info, uniform_sender=self) + for agent in [self, recipient]: + agent.previous_cache = agent.client_cache + agent.client_cache = cache + if isinstance(max_turns, int): + self._prepare_chat(recipient, clear_history, reply_at_receive=False) + for _ in range(max_turns): + if _ == 0: + if isinstance(message, Callable): + msg2send = message(_chat_info["sender"], _chat_info["recipient"], kwargs) + else: + msg2send = await self.a_generate_init_message(message, **kwargs) + else: + msg2send = await self.a_generate_reply(messages=self.chat_messages[recipient], sender=recipient) + if msg2send is None: + break + await self.a_send(msg2send, recipient, request_reply=True, silent=silent) + else: + self._prepare_chat(recipient, clear_history) + if isinstance(message, Callable): + msg2send = message(_chat_info["sender"], _chat_info["recipient"], kwargs) + else: + msg2send = await self.a_generate_init_message(message, **kwargs) + await self.a_send(msg2send, recipient, silent=silent) + summary = self._summarize_chat( + summary_method, + summary_args, + recipient, + cache=cache, + ) + for agent in [self, recipient]: + agent.client_cache = agent.previous_cache + agent.previous_cache = None + chat_result = ChatResult( + chat_history=self.chat_messages[recipient], + summary=summary, + cost=gather_usage_summary([self, recipient]), + human_input=self._human_input, + ) + return chat_result + + def _summarize_chat( + self, + summary_method, + summary_args, + recipient: Optional[Agent] = None, + cache: Optional[AbstractCache] = None, + ) -> str: + """Get a chat summary from an agent participating in a chat. + + Args: + summary_method (str or callable): the summary_method to get the summary. + The callable summary_method should take the recipient and sender agent in a chat as input and return a string of summary. E.g, + ```python + def my_summary_method( + sender: ConversableAgent, + recipient: ConversableAgent, + summary_args: dict, + ): + return recipient.last_message(sender)["content"] + ``` + summary_args (dict): a dictionary of arguments to be passed to the summary_method. + recipient: the recipient agent in a chat. + prompt (str): the prompt used to get a summary when summary_method is "reflection_with_llm". + + Returns: + str: a chat summary from the agent. + """ + summary = "" + if summary_method is None: + return summary + if "cache" not in summary_args: + summary_args["cache"] = cache + if summary_method == "reflection_with_llm": + summary_method = self._reflection_with_llm_as_summary + elif summary_method == "last_msg": + summary_method = self._last_msg_as_summary + + if isinstance(summary_method, Callable): + summary = summary_method(self, recipient, summary_args) + else: + raise ValueError( + "If not None, the summary_method must be a string from [`reflection_with_llm`, `last_msg`] or a callable." + ) + return summary + + @staticmethod + def _last_msg_as_summary(sender, recipient, summary_args) -> str: + """Get a chat summary from the last message of the recipient.""" + summary = "" + try: + content = recipient.last_message(sender)["content"] + if isinstance(content, str): + summary = content.replace("TERMINATE", "") + elif isinstance(content, list): + # Remove the `TERMINATE` word in the content list. + summary = "\n".join( + x["text"].replace("TERMINATE", "") for x in content if isinstance(x, dict) and "text" in x + ) + except (IndexError, AttributeError) as e: + warnings.warn(f"Cannot extract summary using last_msg: {e}. Using an empty str as summary.", UserWarning) + return summary + + @staticmethod + def _reflection_with_llm_as_summary(sender, recipient, summary_args): + prompt = summary_args.get("summary_prompt") + prompt = ConversableAgent.DEFAULT_SUMMARY_PROMPT if prompt is None else prompt + if not isinstance(prompt, str): + raise ValueError("The summary_prompt must be a string.") + msg_list = recipient.chat_messages_for_summary(sender) + agent = sender if recipient is None else recipient + role = summary_args.get("summary_role", None) + if role and not isinstance(role, str): + raise ValueError("The summary_role in summary_arg must be a string.") + try: + summary = sender._reflection_with_llm( + prompt, msg_list, llm_agent=agent, cache=summary_args.get("cache"), role=role + ) + except BadRequestError as e: + warnings.warn( + f"Cannot extract summary using reflection_with_llm: {e}. Using an empty str as summary.", UserWarning + ) + summary = "" + return summary + + def _reflection_with_llm( + self, + prompt, + messages, + llm_agent: Optional[Agent] = None, + cache: Optional[AbstractCache] = None, + role: Union[str, None] = None, + ) -> str: + """Get a chat summary using reflection with an llm client based on the conversation history. + + Args: + prompt (str): The prompt (in this method it is used as system prompt) used to get the summary. + messages (list): The messages generated as part of a chat conversation. + llm_agent: the agent with an llm client. + cache (AbstractCache or None): the cache client to be used for this conversation. + role (str): the role of the message, usually "system" or "user". Default is "system". + """ + if not role: + role = "system" + + system_msg = [ + { + "role": role, + "content": prompt, + } + ] + + messages = messages + system_msg + if llm_agent and llm_agent.client is not None: + llm_client = llm_agent.client + elif self.client is not None: + llm_client = self.client + else: + raise ValueError("No OpenAIWrapper client is found.") + response = self._generate_oai_reply_from_client(llm_client=llm_client, messages=messages, cache=cache) + return response + + def _check_chat_queue_for_sender(self, chat_queue: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + """ + Check the chat queue and add the "sender" key if it's missing. + + Args: + chat_queue (List[Dict[str, Any]]): A list of dictionaries containing chat information. + + Returns: + List[Dict[str, Any]]: A new list of dictionaries with the "sender" key added if it was missing. + """ + chat_queue_with_sender = [] + for chat_info in chat_queue: + if chat_info.get("sender") is None: + chat_info["sender"] = self + chat_queue_with_sender.append(chat_info) + return chat_queue_with_sender + + def initiate_chats(self, chat_queue: List[Dict[str, Any]]) -> List[ChatResult]: + """(Experimental) Initiate chats with multiple agents. + + Args: + chat_queue (List[Dict]): a list of dictionaries containing the information of the chats. + Each dictionary should contain the input arguments for [`initiate_chat`](conversable_agent#initiate_chat) + + Returns: a list of ChatResult objects corresponding to the finished chats in the chat_queue. + """ + _chat_queue = self._check_chat_queue_for_sender(chat_queue) + self._finished_chats = initiate_chats(_chat_queue) + return self._finished_chats + + async def a_initiate_chats(self, chat_queue: List[Dict[str, Any]]) -> Dict[int, ChatResult]: + _chat_queue = self._check_chat_queue_for_sender(chat_queue) + self._finished_chats = await a_initiate_chats(_chat_queue) + return self._finished_chats + + def get_chat_results(self, chat_index: Optional[int] = None) -> Union[List[ChatResult], ChatResult]: + """A summary from the finished chats of particular agents.""" + if chat_index is not None: + return self._finished_chats[chat_index] + else: + return self._finished_chats + + def reset(self): + """Reset the agent.""" + self.clear_history() + self.reset_consecutive_auto_reply_counter() + self.stop_reply_at_receive() + if self.client is not None: + self.client.clear_usage_summary() + for reply_func_tuple in self._reply_func_list: + if reply_func_tuple["reset_config"] is not None: + reply_func_tuple["reset_config"](reply_func_tuple["config"]) + else: + reply_func_tuple["config"] = copy.copy(reply_func_tuple["init_config"]) + + def stop_reply_at_receive(self, sender: Optional[Agent] = None): + """Reset the reply_at_receive of the sender.""" + if sender is None: + self.reply_at_receive.clear() + else: + self.reply_at_receive[sender] = False + + def reset_consecutive_auto_reply_counter(self, sender: Optional[Agent] = None): + """Reset the consecutive_auto_reply_counter of the sender.""" + if sender is None: + self._consecutive_auto_reply_counter.clear() + else: + self._consecutive_auto_reply_counter[sender] = 0 + + def clear_history(self, recipient: Optional[Agent] = None, nr_messages_to_preserve: Optional[int] = None): + """Clear the chat history of the agent. + + Args: + recipient: the agent with whom the chat history to clear. If None, clear the chat history with all agents. + nr_messages_to_preserve: the number of newest messages to preserve in the chat history. + """ + iostream = IOStream.get_default() + if recipient is None: + if nr_messages_to_preserve: + for key in self._oai_messages: + nr_messages_to_preserve_internal = nr_messages_to_preserve + # if breaking history between function call and function response, save function call message + # additionally, otherwise openai will return error + first_msg_to_save = self._oai_messages[key][-nr_messages_to_preserve_internal] + if "tool_responses" in first_msg_to_save: + nr_messages_to_preserve_internal += 1 + iostream.print( + f"Preserving one more message for {self.name} to not divide history between tool call and " + f"tool response." + ) + # Remove messages from history except last `nr_messages_to_preserve` messages. + self._oai_messages[key] = self._oai_messages[key][-nr_messages_to_preserve_internal:] + else: + self._oai_messages.clear() + else: + self._oai_messages[recipient].clear() + if nr_messages_to_preserve: + iostream.print( + colored( + "WARNING: `nr_preserved_messages` is ignored when clearing chat history with a specific agent.", + "yellow", + ), + flush=True, + ) + + def generate_oai_reply( + self, + messages: Optional[List[Dict]] = None, + sender: Optional[Agent] = None, + config: Optional[OpenAIWrapper] = None, + ) -> Tuple[bool, Union[str, Dict, None]]: + """Generate a reply using autogen.oai.""" + client = self.client if config is None else config + if client is None: + return False, None + if messages is None: + messages = self._oai_messages[sender] + extracted_response = self._generate_oai_reply_from_client( + client, self._oai_system_message + messages, self.client_cache + ) + return (False, None) if extracted_response is None else (True, extracted_response) + + def _generate_oai_reply_from_client(self, llm_client, messages, cache) -> Union[str, Dict, None]: + # unroll tool_responses + all_messages = [] + for message in messages: + tool_responses = message.get("tool_responses", []) + if tool_responses: + all_messages += tool_responses + # tool role on the parent message means the content is just concatenation of all of the tool_responses + if message.get("role") != "tool": + all_messages.append({key: message[key] for key in message if key != "tool_responses"}) + else: + all_messages.append(message) + + # TODO: #1143 handle token limit exceeded error + response = llm_client.create( + context=messages[-1].pop("context", None), messages=all_messages, cache=cache, agent=self + ) + extracted_response = llm_client.extract_text_or_completion_object(response)[0] + + if extracted_response is None: + warnings.warn(f"Extracted_response from {response} is None.", UserWarning) + return None + # ensure function and tool calls will be accepted when sent back to the LLM + if not isinstance(extracted_response, str) and hasattr(extracted_response, "model_dump"): + extracted_response = model_dump(extracted_response) + if isinstance(extracted_response, dict): + if extracted_response.get("function_call"): + extracted_response["function_call"]["name"] = self._normalize_name( + extracted_response["function_call"]["name"] + ) + for tool_call in extracted_response.get("tool_calls") or []: + tool_call["function"]["name"] = self._normalize_name(tool_call["function"]["name"]) + # Remove id and type if they are not present. + # This is to make the tool call object compatible with Mistral API. + if tool_call.get("id") is None: + tool_call.pop("id") + if tool_call.get("type") is None: + tool_call.pop("type") + return extracted_response + + async def a_generate_oai_reply( + self, + messages: Optional[List[Dict]] = None, + sender: Optional[Agent] = None, + config: Optional[Any] = None, + ) -> Tuple[bool, Union[str, Dict, None]]: + """Generate a reply using autogen.oai asynchronously.""" + iostream = IOStream.get_default() + parent_context = contextvars.copy_context() + + def _generate_oai_reply( + self, iostream: IOStream, *args: Any, **kwargs: Any + ) -> Tuple[bool, Union[str, Dict, None]]: + with IOStream.set_default(iostream): + return self.generate_oai_reply(*args, **kwargs) + + return await asyncio.get_event_loop().run_in_executor( + None, + lambda: parent_context.run( + _generate_oai_reply, self=self, iostream=iostream, messages=messages, sender=sender, config=config + ), + ) + + def _generate_code_execution_reply_using_executor( + self, + messages: Optional[List[Dict]] = None, + sender: Optional[Agent] = None, + config: Optional[Union[Dict, Literal[False]]] = None, + ): + """Generate a reply using code executor.""" + iostream = IOStream.get_default() + + if config is not None: + raise ValueError("config is not supported for _generate_code_execution_reply_using_executor.") + if self._code_execution_config is False: + return False, None + if messages is None: + messages = self._oai_messages[sender] + last_n_messages = self._code_execution_config.get("last_n_messages", "auto") + + if not (isinstance(last_n_messages, (int, float)) and last_n_messages >= 0) and last_n_messages != "auto": + raise ValueError("last_n_messages must be either a non-negative integer, or the string 'auto'.") + + num_messages_to_scan = last_n_messages + if last_n_messages == "auto": + # Find when the agent last spoke + num_messages_to_scan = 0 + for message in reversed(messages): + if "role" not in message: + break + elif message["role"] != "user": + break + else: + num_messages_to_scan += 1 + num_messages_to_scan = min(len(messages), num_messages_to_scan) + messages_to_scan = messages[-num_messages_to_scan:] + + # iterate through the last n messages in reverse + # if code blocks are found, execute the code blocks and return the output + # if no code blocks are found, continue + for message in reversed(messages_to_scan): + if not message["content"]: + continue + code_blocks = self._code_executor.code_extractor.extract_code_blocks(message["content"]) + if len(code_blocks) == 0: + continue + + num_code_blocks = len(code_blocks) + if num_code_blocks == 1: + iostream.print( + colored( + f"\n>>>>>>>> EXECUTING CODE BLOCK (inferred language is {code_blocks[0].language})...", + "red", + ), + flush=True, + ) + else: + iostream.print( + colored( + f"\n>>>>>>>> EXECUTING {num_code_blocks} CODE BLOCKS (inferred languages are [{', '.join([x.language for x in code_blocks])}])...", + "red", + ), + flush=True, + ) + + # found code blocks, execute code. + code_result = self._code_executor.execute_code_blocks(code_blocks) + exitcode2str = "execution succeeded" if code_result.exit_code == 0 else "execution failed" + return True, f"exitcode: {code_result.exit_code} ({exitcode2str})\nCode output: {code_result.output}" + + return False, None + + def generate_code_execution_reply( + self, + messages: Optional[List[Dict]] = None, + sender: Optional[Agent] = None, + config: Optional[Union[Dict, Literal[False]]] = None, + ): + """Generate a reply using code execution.""" + code_execution_config = config if config is not None else self._code_execution_config + if code_execution_config is False: + return False, None + if messages is None: + messages = self._oai_messages[sender] + last_n_messages = code_execution_config.pop("last_n_messages", "auto") + + if not (isinstance(last_n_messages, (int, float)) and last_n_messages >= 0) and last_n_messages != "auto": + raise ValueError("last_n_messages must be either a non-negative integer, or the string 'auto'.") + + messages_to_scan = last_n_messages + if last_n_messages == "auto": + # Find when the agent last spoke + messages_to_scan = 0 + for i in range(len(messages)): + message = messages[-(i + 1)] + if "role" not in message: + break + elif message["role"] != "user": + break + else: + messages_to_scan += 1 + + # iterate through the last n messages in reverse + # if code blocks are found, execute the code blocks and return the output + # if no code blocks are found, continue + for i in range(min(len(messages), messages_to_scan)): + message = messages[-(i + 1)] + if not message["content"]: + continue + code_blocks = extract_code(message["content"]) + if len(code_blocks) == 1 and code_blocks[0][0] == UNKNOWN: + continue + + # found code blocks, execute code and push "last_n_messages" back + exitcode, logs = self.execute_code_blocks(code_blocks) + code_execution_config["last_n_messages"] = last_n_messages + exitcode2str = "execution succeeded" if exitcode == 0 else "execution failed" + return True, f"exitcode: {exitcode} ({exitcode2str})\nCode output: {logs}" + + # no code blocks are found, push last_n_messages back and return. + code_execution_config["last_n_messages"] = last_n_messages + + return False, None + + def generate_function_call_reply( + self, + messages: Optional[List[Dict]] = None, + sender: Optional[Agent] = None, + config: Optional[Any] = None, + ) -> Tuple[bool, Union[Dict, None]]: + """ + Generate a reply using function call. + + "function_call" replaced by "tool_calls" as of [OpenAI API v1.1.0](https://github.com/openai/openai-python/releases/tag/v1.1.0) + See https://platform.openai.com/docs/api-reference/chat/create#chat-create-functions + """ + if config is None: + config = self + if messages is None: + messages = self._oai_messages[sender] + message = messages[-1] + if "function_call" in message and message["function_call"]: + func_call = message["function_call"] + func = self._function_map.get(func_call.get("name", None), None) + if inspect.iscoroutinefunction(func): + try: + # get the running loop if it was already created + loop = asyncio.get_running_loop() + close_loop = False + except RuntimeError: + # create a loop if there is no running loop + loop = asyncio.new_event_loop() + close_loop = True + + _, func_return = loop.run_until_complete(self.a_execute_function(func_call)) + if close_loop: + loop.close() + else: + _, func_return = self.execute_function(message["function_call"]) + return True, func_return + return False, None + + async def a_generate_function_call_reply( + self, + messages: Optional[List[Dict]] = None, + sender: Optional[Agent] = None, + config: Optional[Any] = None, + ) -> Tuple[bool, Union[Dict, None]]: + """ + Generate a reply using async function call. + + "function_call" replaced by "tool_calls" as of [OpenAI API v1.1.0](https://github.com/openai/openai-python/releases/tag/v1.1.0) + See https://platform.openai.com/docs/api-reference/chat/create#chat-create-functions + """ + if config is None: + config = self + if messages is None: + messages = self._oai_messages[sender] + message = messages[-1] + func_call = message.get("function_call") + if func_call: + func_name = func_call.get("name", "") + func = self._function_map.get(func_name, None) + if func and inspect.iscoroutinefunction(func): + _, func_return = await self.a_execute_function(func_call) + else: + _, func_return = self.execute_function(func_call) + return True, func_return + + return False, None + + def _str_for_tool_response(self, tool_response): + return str(tool_response.get("content", "")) + + def generate_tool_calls_reply( + self, + messages: Optional[List[Dict]] = None, + sender: Optional[Agent] = None, + config: Optional[Any] = None, + ) -> Tuple[bool, Union[Dict, None]]: + """Generate a reply using tool call.""" + if config is None: + config = self + if messages is None: + messages = self._oai_messages[sender] + message = messages[-1] + tool_returns = [] + for tool_call in message.get("tool_calls", []): + function_call = tool_call.get("function", {}) + func = self._function_map.get(function_call.get("name", None), None) + if inspect.iscoroutinefunction(func): + try: + # get the running loop if it was already created + loop = asyncio.get_running_loop() + close_loop = False + except RuntimeError: + # create a loop if there is no running loop + loop = asyncio.new_event_loop() + close_loop = True + + _, func_return = loop.run_until_complete(self.a_execute_function(function_call)) + if close_loop: + loop.close() + else: + _, func_return = self.execute_function(function_call) + content = func_return.get("content", "") + if content is None: + content = "" + tool_call_id = tool_call.get("id", None) + if tool_call_id is not None: + tool_call_response = { + "tool_call_id": tool_call_id, + "role": "tool", + "content": content, + } + else: + # Do not include tool_call_id if it is not present. + # This is to make the tool call object compatible with Mistral API. + tool_call_response = { + "role": "tool", + "content": content, + } + tool_returns.append(tool_call_response) + if tool_returns: + return True, { + "role": "tool", + "tool_responses": tool_returns, + "content": "\n\n".join([self._str_for_tool_response(tool_return) for tool_return in tool_returns]), + } + return False, None + + async def _a_execute_tool_call(self, tool_call): + id = tool_call["id"] + function_call = tool_call.get("function", {}) + _, func_return = await self.a_execute_function(function_call) + return { + "tool_call_id": id, + "role": "tool", + "content": func_return.get("content", ""), + } + + async def a_generate_tool_calls_reply( + self, + messages: Optional[List[Dict]] = None, + sender: Optional[Agent] = None, + config: Optional[Any] = None, + ) -> Tuple[bool, Union[Dict, None]]: + """Generate a reply using async function call.""" + if config is None: + config = self + if messages is None: + messages = self._oai_messages[sender] + message = messages[-1] + async_tool_calls = [] + for tool_call in message.get("tool_calls", []): + async_tool_calls.append(self._a_execute_tool_call(tool_call)) + if async_tool_calls: + tool_returns = await asyncio.gather(*async_tool_calls) + return True, { + "role": "tool", + "tool_responses": tool_returns, + "content": "\n\n".join([self._str_for_tool_response(tool_return) for tool_return in tool_returns]), + } + + return False, None + + def check_termination_and_human_reply( + self, + messages: Optional[List[Dict]] = None, + sender: Optional[Agent] = None, + config: Optional[Any] = None, + ) -> Tuple[bool, Union[str, None]]: + """Check if the conversation should be terminated, and if human reply is provided. + + This method checks for conditions that require the conversation to be terminated, such as reaching + a maximum number of consecutive auto-replies or encountering a termination message. Additionally, + it prompts for and processes human input based on the configured human input mode, which can be + 'ALWAYS', 'NEVER', or 'TERMINATE'. The method also manages the consecutive auto-reply counter + for the conversation and prints relevant messages based on the human input received. + + Args: + - messages (Optional[List[Dict]]): A list of message dictionaries, representing the conversation history. + - sender (Optional[Agent]): The agent object representing the sender of the message. + - config (Optional[Any]): Configuration object, defaults to the current instance if not provided. + + Returns: + - Tuple[bool, Union[str, Dict, None]]: A tuple containing a boolean indicating if the conversation + should be terminated, and a human reply which can be a string, a dictionary, or None. + """ + iostream = IOStream.get_default() + + if config is None: + config = self + if messages is None: + messages = self._oai_messages[sender] if sender else [] + message = messages[-1] + reply = "" + no_human_input_msg = "" + sender_name = "the sender" if sender is None else sender.name + if self.human_input_mode == "ALWAYS": + reply = self.get_human_input( + f"Replying as {self.name}. Provide feedback to {sender_name}. Press enter to skip and use auto-reply, or type 'exit' to end the conversation: " + ) + no_human_input_msg = "NO HUMAN INPUT RECEIVED." if not reply else "" + # if the human input is empty, and the message is a termination message, then we will terminate the conversation + reply = reply if reply or not self._is_termination_msg(message) else "exit" + else: + if self._consecutive_auto_reply_counter[sender] >= self._max_consecutive_auto_reply_dict[sender]: + if self.human_input_mode == "NEVER": + reply = "exit" + else: + # self.human_input_mode == "TERMINATE": + terminate = self._is_termination_msg(message) + reply = self.get_human_input( + f"Please give feedback to {sender_name}. Press enter or type 'exit' to stop the conversation: " + if terminate + else f"Please give feedback to {sender_name}. Press enter to skip and use auto-reply, or type 'exit' to stop the conversation: " + ) + no_human_input_msg = "NO HUMAN INPUT RECEIVED." if not reply else "" + # if the human input is empty, and the message is a termination message, then we will terminate the conversation + reply = reply if reply or not terminate else "exit" + elif self._is_termination_msg(message): + if self.human_input_mode == "NEVER": + reply = "exit" + else: + # self.human_input_mode == "TERMINATE": + reply = self.get_human_input( + f"Please give feedback to {sender_name}. Press enter or type 'exit' to stop the conversation: " + ) + no_human_input_msg = "NO HUMAN INPUT RECEIVED." if not reply else "" + # if the human input is empty, and the message is a termination message, then we will terminate the conversation + reply = reply or "exit" + + # print the no_human_input_msg + if no_human_input_msg: + iostream.print(colored(f"\n>>>>>>>> {no_human_input_msg}", "red"), flush=True) + + # stop the conversation + if reply == "exit": + # reset the consecutive_auto_reply_counter + self._consecutive_auto_reply_counter[sender] = 0 + return True, None + + # send the human reply + if reply or self._max_consecutive_auto_reply_dict[sender] == 0: + # reset the consecutive_auto_reply_counter + self._consecutive_auto_reply_counter[sender] = 0 + # User provided a custom response, return function and tool failures indicating user interruption + tool_returns = [] + if message.get("function_call", False): + tool_returns.append( + { + "role": "function", + "name": message["function_call"].get("name", ""), + "content": "USER INTERRUPTED", + } + ) + + if message.get("tool_calls", False): + tool_returns.extend( + [ + {"role": "tool", "tool_call_id": tool_call.get("id", ""), "content": "USER INTERRUPTED"} + for tool_call in message["tool_calls"] + ] + ) + + response = {"role": "user", "content": reply} + if tool_returns: + response["tool_responses"] = tool_returns + + return True, response + + # increment the consecutive_auto_reply_counter + self._consecutive_auto_reply_counter[sender] += 1 + if self.human_input_mode != "NEVER": + iostream.print(colored("\n>>>>>>>> USING AUTO REPLY...", "red"), flush=True) + + return False, None + + async def a_check_termination_and_human_reply( + self, + messages: Optional[List[Dict]] = None, + sender: Optional[Agent] = None, + config: Optional[Any] = None, + ) -> Tuple[bool, Union[str, None]]: + """(async) Check if the conversation should be terminated, and if human reply is provided. + + This method checks for conditions that require the conversation to be terminated, such as reaching + a maximum number of consecutive auto-replies or encountering a termination message. Additionally, + it prompts for and processes human input based on the configured human input mode, which can be + 'ALWAYS', 'NEVER', or 'TERMINATE'. The method also manages the consecutive auto-reply counter + for the conversation and prints relevant messages based on the human input received. + + Args: + - messages (Optional[List[Dict]]): A list of message dictionaries, representing the conversation history. + - sender (Optional[Agent]): The agent object representing the sender of the message. + - config (Optional[Any]): Configuration object, defaults to the current instance if not provided. + + Returns: + - Tuple[bool, Union[str, Dict, None]]: A tuple containing a boolean indicating if the conversation + should be terminated, and a human reply which can be a string, a dictionary, or None. + """ + iostream = IOStream.get_default() + + if config is None: + config = self + if messages is None: + messages = self._oai_messages[sender] if sender else [] + message = messages[-1] if messages else {} + reply = "" + no_human_input_msg = "" + sender_name = "the sender" if sender is None else sender.name + if self.human_input_mode == "ALWAYS": + reply = await self.a_get_human_input( + f"Replying as {self.name}. Provide feedback to {sender_name}. Press enter to skip and use auto-reply, or type 'exit' to end the conversation: " + ) + no_human_input_msg = "NO HUMAN INPUT RECEIVED." if not reply else "" + # if the human input is empty, and the message is a termination message, then we will terminate the conversation + reply = reply if reply or not self._is_termination_msg(message) else "exit" + else: + if self._consecutive_auto_reply_counter[sender] >= self._max_consecutive_auto_reply_dict[sender]: + if self.human_input_mode == "NEVER": + reply = "exit" + else: + # self.human_input_mode == "TERMINATE": + terminate = self._is_termination_msg(message) + reply = await self.a_get_human_input( + f"Please give feedback to {sender_name}. Press enter or type 'exit' to stop the conversation: " + if terminate + else f"Please give feedback to {sender_name}. Press enter to skip and use auto-reply, or type 'exit' to stop the conversation: " + ) + no_human_input_msg = "NO HUMAN INPUT RECEIVED." if not reply else "" + # if the human input is empty, and the message is a termination message, then we will terminate the conversation + reply = reply if reply or not terminate else "exit" + elif self._is_termination_msg(message): + if self.human_input_mode == "NEVER": + reply = "exit" + else: + # self.human_input_mode == "TERMINATE": + reply = await self.a_get_human_input( + f"Please give feedback to {sender_name}. Press enter or type 'exit' to stop the conversation: " + ) + no_human_input_msg = "NO HUMAN INPUT RECEIVED." if not reply else "" + # if the human input is empty, and the message is a termination message, then we will terminate the conversation + reply = reply or "exit" + + # print the no_human_input_msg + if no_human_input_msg: + iostream.print(colored(f"\n>>>>>>>> {no_human_input_msg}", "red"), flush=True) + + # stop the conversation + if reply == "exit": + # reset the consecutive_auto_reply_counter + self._consecutive_auto_reply_counter[sender] = 0 + return True, None + + # send the human reply + if reply or self._max_consecutive_auto_reply_dict[sender] == 0: + # User provided a custom response, return function and tool results indicating user interruption + # reset the consecutive_auto_reply_counter + self._consecutive_auto_reply_counter[sender] = 0 + tool_returns = [] + if message.get("function_call", False): + tool_returns.append( + { + "role": "function", + "name": message["function_call"].get("name", ""), + "content": "USER INTERRUPTED", + } + ) + + if message.get("tool_calls", False): + tool_returns.extend( + [ + {"role": "tool", "tool_call_id": tool_call.get("id", ""), "content": "USER INTERRUPTED"} + for tool_call in message["tool_calls"] + ] + ) + + response = {"role": "user", "content": reply} + if tool_returns: + response["tool_responses"] = tool_returns + + return True, response + + # increment the consecutive_auto_reply_counter + self._consecutive_auto_reply_counter[sender] += 1 + if self.human_input_mode != "NEVER": + iostream.print(colored("\n>>>>>>>> USING AUTO REPLY...", "red"), flush=True) + + return False, None + + def generate_reply( + self, + messages: Optional[List[Dict[str, Any]]] = None, + sender: Optional["Agent"] = None, + **kwargs: Any, + ) -> Union[str, Dict, None]: + """Reply based on the conversation history and the sender. + + Either messages or sender must be provided. + Register a reply_func with `None` as one trigger for it to be activated when `messages` is non-empty and `sender` is `None`. + Use registered auto reply functions to generate replies. + By default, the following functions are checked in order: + 1. check_termination_and_human_reply + 2. generate_function_call_reply (deprecated in favor of tool_calls) + 3. generate_tool_calls_reply + 4. generate_code_execution_reply + 5. generate_oai_reply + Every function returns a tuple (final, reply). + When a function returns final=False, the next function will be checked. + So by default, termination and human reply will be checked first. + If not terminating and human reply is skipped, execute function or code and return the result. + AI replies are generated only when no code execution is performed. + + Args: + messages: a list of messages in the conversation history. + sender: sender of an Agent instance. + + Additional keyword arguments: + exclude (List[Callable]): a list of reply functions to be excluded. + + Returns: + str or dict or None: reply. None if no reply is generated. + """ + if all((messages is None, sender is None)): + error_msg = f"Either {messages=} or {sender=} must be provided." + logger.error(error_msg) + raise AssertionError(error_msg) + + if messages is None: + messages = self._oai_messages[sender] + + # Call the hookable method that gives registered hooks a chance to process the last message. + # Message modifications do not affect the incoming messages or self._oai_messages. + messages = self.process_last_received_message(messages) + + # Call the hookable method that gives registered hooks a chance to process all messages. + # Message modifications do not affect the incoming messages or self._oai_messages. + messages = self.process_all_messages_before_reply(messages) + + for reply_func_tuple in self._reply_func_list: + reply_func = reply_func_tuple["reply_func"] + if "exclude" in kwargs and reply_func in kwargs["exclude"]: + continue + if inspect.iscoroutinefunction(reply_func): + continue + if self._match_trigger(reply_func_tuple["trigger"], sender): + final, reply = reply_func(self, messages=messages, sender=sender, config=reply_func_tuple["config"]) + if logging_enabled(): + log_event( + self, + "reply_func_executed", + reply_func_module=reply_func.__module__, + reply_func_name=reply_func.__name__, + final=final, + reply=reply, + ) + if final: + return reply + return self._default_auto_reply + + async def a_generate_reply( + self, + messages: Optional[List[Dict[str, Any]]] = None, + sender: Optional["Agent"] = None, + **kwargs: Any, + ) -> Union[str, Dict[str, Any], None]: + """(async) Reply based on the conversation history and the sender. + + Either messages or sender must be provided. + Register a reply_func with `None` as one trigger for it to be activated when `messages` is non-empty and `sender` is `None`. + Use registered auto reply functions to generate replies. + By default, the following functions are checked in order: + 1. check_termination_and_human_reply + 2. generate_function_call_reply + 3. generate_tool_calls_reply + 4. generate_code_execution_reply + 5. generate_oai_reply + Every function returns a tuple (final, reply). + When a function returns final=False, the next function will be checked. + So by default, termination and human reply will be checked first. + If not terminating and human reply is skipped, execute function or code and return the result. + AI replies are generated only when no code execution is performed. + + Args: + messages: a list of messages in the conversation history. + sender: sender of an Agent instance. + + Additional keyword arguments: + exclude (List[Callable]): a list of reply functions to be excluded. + + Returns: + str or dict or None: reply. None if no reply is generated. + """ + if all((messages is None, sender is None)): + error_msg = f"Either {messages=} or {sender=} must be provided." + logger.error(error_msg) + raise AssertionError(error_msg) + + if messages is None: + messages = self._oai_messages[sender] + + # Call the hookable method that gives registered hooks a chance to process all messages. + # Message modifications do not affect the incoming messages or self._oai_messages. + messages = await self.a_process_all_messages_before_reply(messages) + + # Call the hookable method that gives registered hooks a chance to process the last message. + # Message modifications do not affect the incoming messages or self._oai_messages. + messages = await self.a_process_last_received_message(messages) + + for reply_func_tuple in self._reply_func_list: + reply_func = reply_func_tuple["reply_func"] + if "exclude" in kwargs and reply_func in kwargs["exclude"]: + continue + + if self._match_trigger(reply_func_tuple["trigger"], sender): + if inspect.iscoroutinefunction(reply_func): + final, reply = await reply_func( + self, messages=messages, sender=sender, config=reply_func_tuple["config"] + ) + else: + final, reply = reply_func(self, messages=messages, sender=sender, config=reply_func_tuple["config"]) + if final: + return reply + return self._default_auto_reply + + def _match_trigger(self, trigger: Union[None, str, type, Agent, Callable, List], sender: Optional[Agent]) -> bool: + """Check if the sender matches the trigger. + + Args: + - trigger (Union[None, str, type, Agent, Callable, List]): The condition to match against the sender. + Can be `None`, string, type, `Agent` instance, callable, or a list of these. + - sender (Agent): The sender object or type to be matched against the trigger. + + Returns: + - bool: Returns `True` if the sender matches the trigger, otherwise `False`. + + Raises: + - ValueError: If the trigger type is unsupported. + """ + if trigger is None: + return sender is None + elif isinstance(trigger, str): + if sender is None: + raise SenderRequired() + return trigger == sender.name + elif isinstance(trigger, type): + return isinstance(sender, trigger) + elif isinstance(trigger, Agent): + # return True if the sender is the same type (class) as the trigger + return trigger == sender + elif isinstance(trigger, Callable): + rst = trigger(sender) + assert isinstance(rst, bool), f"trigger {trigger} must return a boolean value." + return rst + elif isinstance(trigger, list): + return any(self._match_trigger(t, sender) for t in trigger) + else: + raise ValueError(f"Unsupported trigger type: {type(trigger)}") + + def get_human_input(self, prompt: str) -> str: + """Get human input. + + Override this method to customize the way to get human input. + + Args: + prompt (str): prompt for the human input. + + Returns: + str: human input. + """ + iostream = IOStream.get_default() + + reply = iostream.input(prompt) + self._human_input.append(reply) + return reply + + async def a_get_human_input(self, prompt: str) -> str: + """(Async) Get human input. + + Override this method to customize the way to get human input. + + Args: + prompt (str): prompt for the human input. + + Returns: + str: human input. + """ + loop = asyncio.get_running_loop() + reply = await loop.run_in_executor(None, functools.partial(self.get_human_input, prompt)) + return reply + + def run_code(self, code, **kwargs): + """Run the code and return the result. + + Override this function to modify the way to run the code. + Args: + code (str): the code to be executed. + **kwargs: other keyword arguments. + + Returns: + A tuple of (exitcode, logs, image). + exitcode (int): the exit code of the code execution. + logs (str): the logs of the code execution. + image (str or None): the docker image used for the code execution. + """ + return execute_code(code, **kwargs) + + def execute_code_blocks(self, code_blocks): + """Execute the code blocks and return the result.""" + iostream = IOStream.get_default() + + logs_all = "" + for i, code_block in enumerate(code_blocks): + lang, code = code_block + if not lang: + lang = infer_lang(code) + iostream.print( + colored( + f"\n>>>>>>>> EXECUTING CODE BLOCK {i} (inferred language is {lang})...", + "red", + ), + flush=True, + ) + if lang in ["bash", "shell", "sh"]: + exitcode, logs, image = self.run_code(code, lang=lang, **self._code_execution_config) + elif lang in PYTHON_VARIANTS: + if code.startswith("# filename: "): + filename = code[11 : code.find("\n")].strip() + else: + filename = None + exitcode, logs, image = self.run_code( + code, + lang="python", + filename=filename, + **self._code_execution_config, + ) + else: + # In case the language is not supported, we return an error message. + exitcode, logs, image = ( + 1, + f"unknown language {lang}", + None, + ) + # raise NotImplementedError + if image is not None: + self._code_execution_config["use_docker"] = image + logs_all += "\n" + logs + if exitcode != 0: + return exitcode, logs_all + return exitcode, logs_all + + @staticmethod + def _format_json_str(jstr): + """Remove newlines outside of quotes, and handle JSON escape sequences. + + 1. this function removes the newline in the query outside of quotes otherwise json.loads(s) will fail. + Ex 1: + "{\n"tool": "python",\n"query": "print('hello')\nprint('world')"\n}" -> "{"tool": "python","query": "print('hello')\nprint('world')"}" + Ex 2: + "{\n \"location\": \"Boston, MA\"\n}" -> "{"location": "Boston, MA"}" + + 2. this function also handles JSON escape sequences inside quotes. + Ex 1: + '{"args": "a\na\na\ta"}' -> '{"args": "a\\na\\na\\ta"}' + """ + result = [] + inside_quotes = False + last_char = " " + for char in jstr: + if last_char != "\\" and char == '"': + inside_quotes = not inside_quotes + last_char = char + if not inside_quotes and char == "\n": + continue + if inside_quotes and char == "\n": + char = "\\n" + if inside_quotes and char == "\t": + char = "\\t" + result.append(char) + return "".join(result) + + def execute_function(self, func_call, verbose: bool = False) -> Tuple[bool, Dict[str, str]]: + """Execute a function call and return the result. + + Override this function to modify the way to execute function and tool calls. + + Args: + func_call: a dictionary extracted from openai message at "function_call" or "tool_calls" with keys "name" and "arguments". + + Returns: + A tuple of (is_exec_success, result_dict). + is_exec_success (boolean): whether the execution is successful. + result_dict: a dictionary with keys "name", "role", and "content". Value of "role" is "function". + + "function_call" deprecated as of [OpenAI API v1.1.0](https://github.com/openai/openai-python/releases/tag/v1.1.0) + See https://platform.openai.com/docs/api-reference/chat/create#chat-create-function_call + """ + iostream = IOStream.get_default() + + func_name = func_call.get("name", "") + func = self._function_map.get(func_name, None) + + is_exec_success = False + if func is not None: + # Extract arguments from a json-like string and put it into a dict. + input_string = self._format_json_str(func_call.get("arguments", "{}")) + try: + arguments = json.loads(input_string) + except json.JSONDecodeError as e: + arguments = None + content = f"Error: {e}\n The argument must be in JSON format." + + # Try to execute the function + if arguments is not None: + iostream.print( + colored(f"\n>>>>>>>> EXECUTING FUNCTION {func_name}...", "magenta"), + flush=True, + ) + try: + content = func(**arguments) + is_exec_success = True + except Exception as e: + content = f"Error: {e}" + else: + content = f"Error: Function {func_name} not found." + + if verbose: + iostream.print( + colored(f"\nInput arguments: {arguments}\nOutput:\n{content}", "magenta"), + flush=True, + ) + + return is_exec_success, { + "name": func_name, + "role": "function", + "content": str(content), + } + + async def a_execute_function(self, func_call): + """Execute an async function call and return the result. + + Override this function to modify the way async functions and tools are executed. + + Args: + func_call: a dictionary extracted from openai message at key "function_call" or "tool_calls" with keys "name" and "arguments". + + Returns: + A tuple of (is_exec_success, result_dict). + is_exec_success (boolean): whether the execution is successful. + result_dict: a dictionary with keys "name", "role", and "content". Value of "role" is "function". + + "function_call" deprecated as of [OpenAI API v1.1.0](https://github.com/openai/openai-python/releases/tag/v1.1.0) + See https://platform.openai.com/docs/api-reference/chat/create#chat-create-function_call + """ + iostream = IOStream.get_default() + + func_name = func_call.get("name", "") + func = self._function_map.get(func_name, None) + + is_exec_success = False + if func is not None: + # Extract arguments from a json-like string and put it into a dict. + input_string = self._format_json_str(func_call.get("arguments", "{}")) + try: + arguments = json.loads(input_string) + except json.JSONDecodeError as e: + arguments = None + content = f"Error: {e}\n The argument must be in JSON format." + + # Try to execute the function + if arguments is not None: + iostream.print( + colored(f"\n>>>>>>>> EXECUTING ASYNC FUNCTION {func_name}...", "magenta"), + flush=True, + ) + try: + if inspect.iscoroutinefunction(func): + content = await func(**arguments) + else: + # Fallback to sync function if the function is not async + content = func(**arguments) + is_exec_success = True + except Exception as e: + content = f"Error: {e}" + else: + content = f"Error: Function {func_name} not found." + + return is_exec_success, { + "name": func_name, + "role": "function", + "content": str(content), + } + + def generate_init_message(self, message: Union[Dict, str, None], **kwargs) -> Union[str, Dict]: + """Generate the initial message for the agent. + If message is None, input() will be called to get the initial message. + + Args: + message (str or None): the message to be processed. + **kwargs: any additional information. It has the following reserved fields: + "carryover": a string or a list of string to specify the carryover information to be passed to this chat. It can be a string or a list of string. + If provided, we will combine this carryover with the "message" content when generating the initial chat + message. + Returns: + str or dict: the processed message. + """ + if message is None: + message = self.get_human_input(">") + + return self._handle_carryover(message, kwargs) + + def _handle_carryover(self, message: Union[str, Dict], kwargs: dict) -> Union[str, Dict]: + if not kwargs.get("carryover"): + return message + + if isinstance(message, str): + return self._process_carryover(message, kwargs) + + elif isinstance(message, dict): + if isinstance(message.get("content"), str): + # Makes sure the original message is not mutated + message = message.copy() + message["content"] = self._process_carryover(message["content"], kwargs) + elif isinstance(message.get("content"), list): + # Makes sure the original message is not mutated + message = message.copy() + message["content"] = self._process_multimodal_carryover(message["content"], kwargs) + else: + raise InvalidCarryOverType("Carryover should be a string or a list of strings.") + + return message + + def _process_carryover(self, content: str, kwargs: dict) -> str: + # Makes sure there's a carryover + if not kwargs.get("carryover"): + return content + + # if carryover is string + if isinstance(kwargs["carryover"], str): + content += "\nContext: \n" + kwargs["carryover"] + elif isinstance(kwargs["carryover"], list): + content += "\nContext: \n" + ("\n").join([_post_process_carryover_item(t) for t in kwargs["carryover"]]) + else: + raise InvalidCarryOverType( + "Carryover should be a string or a list of strings. Not adding carryover to the message." + ) + return content + + def _process_multimodal_carryover(self, content: List[Dict], kwargs: dict) -> List[Dict]: + """Prepends the context to a multimodal message.""" + # Makes sure there's a carryover + if not kwargs.get("carryover"): + return content + + return [{"type": "text", "text": self._process_carryover("", kwargs)}] + content + + async def a_generate_init_message(self, message: Union[Dict, str, None], **kwargs) -> Union[str, Dict]: + """Generate the initial message for the agent. + If message is None, input() will be called to get the initial message. + + Args: + Please refer to `generate_init_message` for the description of the arguments. + + Returns: + str or dict: the processed message. + """ + if message is None: + message = await self.a_get_human_input(">") + + return self._handle_carryover(message, kwargs) + + def register_function(self, function_map: Dict[str, Union[Callable, None]]): + """Register functions to the agent. + + Args: + function_map: a dictionary mapping function names to functions. if function_map[name] is None, the function will be removed from the function_map. + """ + for name, func in function_map.items(): + self._assert_valid_name(name) + if func is None and name not in self._function_map.keys(): + warnings.warn(f"The function {name} to remove doesn't exist", name) + if name in self._function_map: + warnings.warn(f"Function '{name}' is being overridden.", UserWarning) + self._function_map.update(function_map) + self._function_map = {k: v for k, v in self._function_map.items() if v is not None} + + def update_function_signature(self, func_sig: Union[str, Dict], is_remove: None): + """update a function_signature in the LLM configuration for function_call. + + Args: + func_sig (str or dict): description/name of the function to update/remove to the model. See: https://platform.openai.com/docs/api-reference/chat/create#chat/create-functions + is_remove: whether removing the function from llm_config with name 'func_sig' + + Deprecated as of [OpenAI API v1.1.0](https://github.com/openai/openai-python/releases/tag/v1.1.0) + See https://platform.openai.com/docs/api-reference/chat/create#chat-create-function_call + """ + + if not isinstance(self.llm_config, dict): + error_msg = "To update a function signature, agent must have an llm_config" + logger.error(error_msg) + raise AssertionError(error_msg) + + if is_remove: + if "functions" not in self.llm_config.keys(): + error_msg = "The agent config doesn't have function {name}.".format(name=func_sig) + logger.error(error_msg) + raise AssertionError(error_msg) + else: + self.llm_config["functions"] = [ + func for func in self.llm_config["functions"] if func["name"] != func_sig + ] + else: + if not isinstance(func_sig, dict): + raise ValueError( + f"The function signature must be of the type dict. Received function signature type {type(func_sig)}" + ) + + self._assert_valid_name(func_sig["name"]) + if "functions" in self.llm_config.keys(): + if any(func["name"] == func_sig["name"] for func in self.llm_config["functions"]): + warnings.warn(f"Function '{func_sig['name']}' is being overridden.", UserWarning) + + self.llm_config["functions"] = [ + func for func in self.llm_config["functions"] if func.get("name") != func_sig["name"] + ] + [func_sig] + else: + self.llm_config["functions"] = [func_sig] + + if len(self.llm_config["functions"]) == 0: + del self.llm_config["functions"] + + self.client = OpenAIWrapper(**self.llm_config) + + def update_tool_signature(self, tool_sig: Union[str, Dict], is_remove: None): + """update a tool_signature in the LLM configuration for tool_call. + + Args: + tool_sig (str or dict): description/name of the tool to update/remove to the model. See: https://platform.openai.com/docs/api-reference/chat/create#chat-create-tools + is_remove: whether removing the tool from llm_config with name 'tool_sig' + """ + + if not self.llm_config: + error_msg = "To update a tool signature, agent must have an llm_config" + logger.error(error_msg) + raise AssertionError(error_msg) + + if is_remove: + if "tools" not in self.llm_config.keys(): + error_msg = "The agent config doesn't have tool {name}.".format(name=tool_sig) + logger.error(error_msg) + raise AssertionError(error_msg) + else: + self.llm_config["tools"] = [ + tool for tool in self.llm_config["tools"] if tool["function"]["name"] != tool_sig + ] + else: + if not isinstance(tool_sig, dict): + raise ValueError( + f"The tool signature must be of the type dict. Received tool signature type {type(tool_sig)}" + ) + self._assert_valid_name(tool_sig["function"]["name"]) + if "tools" in self.llm_config: + if any(tool["function"]["name"] == tool_sig["function"]["name"] for tool in self.llm_config["tools"]): + warnings.warn(f"Function '{tool_sig['function']['name']}' is being overridden.", UserWarning) + self.llm_config["tools"] = [ + tool + for tool in self.llm_config["tools"] + if tool.get("function", {}).get("name") != tool_sig["function"]["name"] + ] + [tool_sig] + else: + self.llm_config["tools"] = [tool_sig] + + if len(self.llm_config["tools"]) == 0: + del self.llm_config["tools"] + + self.client = OpenAIWrapper(**self.llm_config) + + def can_execute_function(self, name: Union[List[str], str]) -> bool: + """Whether the agent can execute the function.""" + names = name if isinstance(name, list) else [name] + return all([n in self._function_map for n in names]) + + @property + def function_map(self) -> Dict[str, Callable]: + """Return the function map.""" + return self._function_map + + def _wrap_function(self, func: F) -> F: + """Wrap the function to dump the return value to json. + + Handles both sync and async functions. + + Args: + func: the function to be wrapped. + + Returns: + The wrapped function. + """ + + @load_basemodels_if_needed + @functools.wraps(func) + def _wrapped_func(*args, **kwargs): + retval = func(*args, **kwargs) + if logging_enabled(): + log_function_use(self, func, kwargs, retval) + return serialize_to_str(retval) + + @load_basemodels_if_needed + @functools.wraps(func) + async def _a_wrapped_func(*args, **kwargs): + retval = await func(*args, **kwargs) + if logging_enabled(): + log_function_use(self, func, kwargs, retval) + return serialize_to_str(retval) + + wrapped_func = _a_wrapped_func if inspect.iscoroutinefunction(func) else _wrapped_func + + # needed for testing + wrapped_func._origin = func + + return wrapped_func + + def register_for_llm( + self, + *, + name: Optional[str] = None, + description: Optional[str] = None, + api_style: Literal["function", "tool"] = "tool", + ) -> Callable[[F], F]: + """Decorator factory for registering a function to be used by an agent. + + It's return value is used to decorate a function to be registered to the agent. The function uses type hints to + specify the arguments and return type. The function name is used as the default name for the function, + but a custom name can be provided. The function description is used to describe the function in the + agent's configuration. + + Args: + name (optional(str)): name of the function. If None, the function name will be used (default: None). + description (optional(str)): description of the function (default: None). It is mandatory + for the initial decorator, but the following ones can omit it. + api_style: (literal): the API style for function call. + For Azure OpenAI API, use version 2023-12-01-preview or later. + `"function"` style will be deprecated. For earlier version use + `"function"` if `"tool"` doesn't work. + See [Azure OpenAI documentation](https://learn.microsoft.com/en-us/azure/ai-services/openai/how-to/function-calling?tabs=python) for details. + + Returns: + The decorator for registering a function to be used by an agent. + + Examples: + ``` + @user_proxy.register_for_execution() + @agent2.register_for_llm() + @agent1.register_for_llm(description="This is a very useful function") + def my_function(a: Annotated[str, "description of a parameter"] = "a", b: int, c=3.14) -> str: + return a + str(b * c) + ``` + + For Azure OpenAI versions prior to 2023-12-01-preview, set `api_style` + to `"function"` if `"tool"` doesn't work: + ``` + @agent2.register_for_llm(api_style="function") + def my_function(a: Annotated[str, "description of a parameter"] = "a", b: int, c=3.14) -> str: + return a + str(b * c) + ``` + + """ + + def _decorator(func: F) -> F: + """Decorator for registering a function to be used by an agent. + + Args: + func: the function to be registered. + + Returns: + The function to be registered, with the _description attribute set to the function description. + + Raises: + ValueError: if the function description is not provided and not propagated by a previous decorator. + RuntimeError: if the LLM config is not set up before registering a function. + + """ + # name can be overwritten by the parameter, by default it is the same as function name + if name: + func._name = name + elif not hasattr(func, "_name"): + func._name = func.__name__ + + # description is propagated from the previous decorator, but it is mandatory for the first one + if description: + func._description = description + else: + if not hasattr(func, "_description"): + raise ValueError("Function description is required, none found.") + + # get JSON schema for the function + f = get_function_schema(func, name=func._name, description=func._description) + + # register the function to the agent if there is LLM config, raise an exception otherwise + if self.llm_config is None: + raise RuntimeError("LLM config must be setup before registering a function for LLM.") + + if api_style == "function": + f = f["function"] + self.update_function_signature(f, is_remove=False) + elif api_style == "tool": + self.update_tool_signature(f, is_remove=False) + else: + raise ValueError(f"Unsupported API style: {api_style}") + + return func + + return _decorator + + def register_for_execution( + self, + name: Optional[str] = None, + ) -> Callable[[F], F]: + """Decorator factory for registering a function to be executed by an agent. + + It's return value is used to decorate a function to be registered to the agent. + + Args: + name (optional(str)): name of the function. If None, the function name will be used (default: None). + + Returns: + The decorator for registering a function to be used by an agent. + + Examples: + ``` + @user_proxy.register_for_execution() + @agent2.register_for_llm() + @agent1.register_for_llm(description="This is a very useful function") + def my_function(a: Annotated[str, "description of a parameter"] = "a", b: int, c=3.14): + return a + str(b * c) + ``` + + """ + + def _decorator(func: F) -> F: + """Decorator for registering a function to be used by an agent. + + Args: + func: the function to be registered. + + Returns: + The function to be registered, with the _description attribute set to the function description. + + Raises: + ValueError: if the function description is not provided and not propagated by a previous decorator. + + """ + # name can be overwritten by the parameter, by default it is the same as function name + if name: + func._name = name + elif not hasattr(func, "_name"): + func._name = func.__name__ + + self.register_function({func._name: self._wrap_function(func)}) + + return func + + return _decorator + + def register_model_client(self, model_client_cls: ModelClient, **kwargs): + """Register a model client. + + Args: + model_client_cls: A custom client class that follows the Client interface + **kwargs: The kwargs for the custom client class to be initialized with + """ + self.client.register_model_client(model_client_cls, **kwargs) + + def register_hook(self, hookable_method: str, hook: Callable): + """ + Registers a hook to be called by a hookable method, in order to add a capability to the agent. + Registered hooks are kept in lists (one per hookable method), and are called in their order of registration. + + Args: + hookable_method: A hookable method name implemented by ConversableAgent. + hook: A method implemented by a subclass of AgentCapability. + """ + assert hookable_method in self.hook_lists, f"{hookable_method} is not a hookable method." + hook_list = self.hook_lists[hookable_method] + assert hook not in hook_list, f"{hook} is already registered as a hook." + + # async hookable checks + expected_async = hookable_method.startswith("a_") + hook_is_async = inspect.iscoroutinefunction(hook) + if expected_async != hook_is_async: + context_type = "asynchronous" if expected_async else "synchronous" + warnings.warn( + f"Hook '{hook.__name__}' is {'asynchronous' if hook_is_async else 'synchronous'}, " + f"but it's being registered in a {context_type} context ('{hookable_method}'). " + "Ensure the hook matches the expected execution context.", + UserWarning, + ) + + hook_list.append(hook) + + def process_all_messages_before_reply(self, messages: List[Dict]) -> List[Dict]: + """ + Calls any registered capability hooks to process all messages, potentially modifying the messages. + """ + hook_list = self.hook_lists["process_all_messages_before_reply"] + # If no hooks are registered, or if there are no messages to process, return the original message list. + if len(hook_list) == 0 or messages is None: + return messages + + # Call each hook (in order of registration) to process the messages. + processed_messages = messages + for hook in hook_list: + if inspect.iscoroutinefunction(hook): + continue + processed_messages = hook(processed_messages) + return processed_messages + + async def a_process_all_messages_before_reply(self, messages: List[Dict]) -> List[Dict]: + """ + Calls any registered capability hooks to process all messages, potentially modifying the messages. + """ + hook_list = self.hook_lists["a_process_all_messages_before_reply"] + # If no hooks are registered, or if there are no messages to process, return the original message list. + if len(hook_list) == 0 or messages is None: + return messages + + # Call each hook (in order of registration) to process the messages. + processed_messages = messages + for hook in hook_list: + if not inspect.iscoroutinefunction(hook): + continue + processed_messages = await hook(processed_messages) + return processed_messages + + def process_last_received_message(self, messages: List[Dict]) -> List[Dict]: + """ + Calls any registered capability hooks to use and potentially modify the text of the last message, + as long as the last message is not a function call or exit command. + """ + + # If any required condition is not met, return the original message list. + hook_list = self.hook_lists["process_last_received_message"] + if len(hook_list) == 0: + return messages # No hooks registered. + if messages is None: + return None # No message to process. + if len(messages) == 0: + return messages # No message to process. + last_message = messages[-1] + if "function_call" in last_message: + return messages # Last message is a function call. + if "context" in last_message: + return messages # Last message contains a context key. + if "content" not in last_message: + return messages # Last message has no content. + + user_content = last_message["content"] + if not isinstance(user_content, str) and not isinstance(user_content, list): + # if the user_content is a string, it is for regular LLM + # if the user_content is a list, it should follow the multimodal LMM format. + return messages + if user_content == "exit": + return messages # Last message is an exit command. + + # Call each hook (in order of registration) to process the user's message. + processed_user_content = user_content + for hook in hook_list: + if inspect.iscoroutinefunction(hook): + continue + processed_user_content = hook(processed_user_content) + + if processed_user_content == user_content: + return messages # No hooks actually modified the user's message. + + # Replace the last user message with the expanded one. + messages = messages.copy() + messages[-1]["content"] = processed_user_content + return messages + + async def a_process_last_received_message(self, messages: List[Dict]) -> List[Dict]: + """ + Calls any registered capability hooks to use and potentially modify the text of the last message, + as long as the last message is not a function call or exit command. + """ + + # If any required condition is not met, return the original message list. + hook_list = self.hook_lists["a_process_last_received_message"] + if len(hook_list) == 0: + return messages # No hooks registered. + if messages is None: + return None # No message to process. + if len(messages) == 0: + return messages # No message to process. + last_message = messages[-1] + if "function_call" in last_message: + return messages # Last message is a function call. + if "context" in last_message: + return messages # Last message contains a context key. + if "content" not in last_message: + return messages # Last message has no content. + + user_content = last_message["content"] + if not isinstance(user_content, str) and not isinstance(user_content, list): + # if the user_content is a string, it is for regular LLM + # if the user_content is a list, it should follow the multimodal LMM format. + return messages + if user_content == "exit": + return messages # Last message is an exit command. + + # Call each hook (in order of registration) to process the user's message. + processed_user_content = user_content + for hook in hook_list: + if not inspect.iscoroutinefunction(hook): + continue + processed_user_content = await hook(processed_user_content) + + if processed_user_content == user_content: + return messages # No hooks actually modified the user's message. + + # Replace the last user message with the expanded one. + messages = messages.copy() + messages[-1]["content"] = processed_user_content + return messages + + def print_usage_summary(self, mode: Union[str, List[str]] = ["actual", "total"]) -> None: + """Print the usage summary.""" + iostream = IOStream.get_default() + + if self.client is None: + iostream.print(f"No cost incurred from agent '{self.name}'.") + else: + iostream.print(f"Agent '{self.name}':") + self.client.print_usage_summary(mode) + + def get_actual_usage(self) -> Union[None, Dict[str, int]]: + """Get the actual usage summary.""" + if self.client is None: + return None + else: + return self.client.actual_usage_summary + + def get_total_usage(self) -> Union[None, Dict[str, int]]: + """Get the total usage summary.""" + if self.client is None: + return None + else: + return self.client.total_usage_summary + + +def register_function( + f: Callable[..., Any], + *, + caller: ConversableAgent, + executor: ConversableAgent, + name: Optional[str] = None, + description: str, +) -> None: + """Register a function to be proposed by an agent and executed for an executor. + + This function can be used instead of function decorators `@ConversationAgent.register_for_llm` and + `@ConversationAgent.register_for_execution`. + + Args: + f: the function to be registered. + caller: the agent calling the function, typically an instance of ConversableAgent. + executor: the agent executing the function, typically an instance of UserProxy. + name: name of the function. If None, the function name will be used (default: None). + description: description of the function. The description is used by LLM to decode whether the function + is called. Make sure the description is properly describing what the function does or it might not be + called by LLM when needed. + + """ + f = caller.register_for_llm(name=name, description=description)(f) + executor.register_for_execution(name=name)(f) \ No newline at end of file diff --git a/train_methods/train_ac.py b/train_methods/train_ac.py index 65f28ce..08de344 100644 --- a/train_methods/train_ac.py +++ b/train_methods/train_ac.py @@ -79,7 +79,7 @@ def train(args: Arguments): batch = next(iter(dataloader)) with torch.no_grad(): - latents = vae.encode(batch["pixel_values"].to(device)).latent_dist.sample() + latents: torch.Tensor = vae.encode(batch["pixel_values"].to(device)).latent_dist.sample() text_embedding = text_encoder(batch["input_ids"].to(device))[0] anchor_embedding = text_encoder(batch["input_anchor_ids"].to(device))[0] latents = latents * vae.config.scaling_factor @@ -95,7 +95,7 @@ def train(args: Arguments): with torch.no_grad(): anchor_pred = unet(noisy_latens[:anchor_embedding.size(0)], timesteps[:anchor_embedding.size(0)], anchor_embedding).sample - mask = batch["mask"].to(device) + mask: torch.Tensor = batch["mask"].to(device) loss: torch.Tensor = F.mse_loss(noise_pred, anchor_pred, reduction="none") loss = ((loss * mask).sum([1, 2, 3]) / mask.sum([1, 2, 3])).mean() diff --git a/train_methods/train_cogfd.py b/train_methods/train_cogfd.py index 465aea7..fa225a1 100644 --- a/train_methods/train_cogfd.py +++ b/train_methods/train_cogfd.py @@ -293,19 +293,13 @@ def train( latents: torch.Tensor = vae.encode(batch["pixel_values"].to(vae.device)).latent_dist.sample() latents = latents * vae.config.scaling_factor - # Sample noise that we'll add to the latents noise = torch.randn_like(latents) bsz = latents.shape[0] - # Sample a random timestep for each image - timesteps = torch.randint(args.cogfd_start, args.cogfd_end, (bsz, ), device=latents.device) + timesteps: torch.Tensor = torch.randint(args.cogfd_start, args.cogfd_end, (bsz, ), device=latents.device) timesteps = timesteps.long() - # Add noise to the latents according to the noise magnitude at each timestep - # (this is the forward diffusion process) noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) - - # Get the text embedding for conditioning - encoder_hidden_states_source = text_encoder(batch["source_ids"].to(text_encoder.device), attention_mask=batch["source_mask"])[0] + encoder_hidden_states_source: torch.Tensor = text_encoder(batch["source_ids"].to(text_encoder.device), attention_mask=batch["source_mask"])[0] # set concept_positions for this batch attn_controller.set_encoder_attn_mask(batch["source_mask"]) diff --git a/train_methods/utils_cogfd.py b/train_methods/utils_cogfd.py index b197a3d..9da66ae 100644 --- a/train_methods/utils_cogfd.py +++ b/train_methods/utils_cogfd.py @@ -8,16 +8,18 @@ import re import pprint from dataclasses import dataclass +from json import JSONDecodeError from typing import Optional, Any import torch -import autogen -from autogen import ConversableAgent, GroupChat + from torch import nn from transformers import RobertaPreTrainedModel, XLMRobertaConfig, XLMRobertaModel from transformers.utils import ModelOutput +from train_methods.legacy_autogen import GroupChat +from train_methods.legacy_autogen_conversable_agent import ConversableAgent @dataclass class TransformationModelOutput(ModelOutput): @@ -146,14 +148,14 @@ def generate_and_save_concept_graph( combination_theme_y: str, output_filename: str = "concept_logic_graph.json" ) -> dict | None: - """根据输入的文本概念组合生成概念逻辑图, 保存为JSON并返回解析后的图谱。 + """Generates a conceptual logic graph based on the given text concept combination, saves it as JSON, and returns the parsed graph. Args: - concept_combination_x: 形如 "A child is drinking wine" 的概念组合字符串。 - output_filename: 保存 JSON 图谱的文件名。 + concept_combination_x: A string representing a concept combination, e.g., "A child is drinking wine". + output_filename: The filename to save the JSON graph. Returns: - 解析后的概念逻辑图 (dict),如果失败则返回 None。 + The parsed conceptual logic graph as a dict, or None if the process fails. """ @@ -201,11 +203,11 @@ def generate_and_save_concept_graph( If you receive instructions on how to fix mistakes, follow them and regenerate the corrected JSON response in the same strict format. ''', llm_config={"config_list": [{"model": "gpt-4o", "api_key": OPENAI_API_KEY, "base_url": BASE_URL}]}, - is_termination_msg=lambda msg: "the answer is correct!" in msg.get("content", "").lower(), # Use .get for safety - human_input_mode="NEVER", # 设置为 "NEVER" 以避免提示用户输入 + is_termination_msg=lambda msg: "the answer is correct!" in msg.get("content", "").lower(), + human_input_mode="NEVER", ) - reviewer = autogen.AssistantAgent( + reviewer = AssistantAgent( name="Reviewer", llm_config={"config_list": [{"model": "gpt-4o", "api_key": OPENAI_API_KEY, "base_url": BASE_URL}]}, system_message=""" @@ -215,16 +217,15 @@ def generate_and_save_concept_graph( If there are some mistakes in the generated graph, please point them out and tell the Generator how to fix them. If you think the generated graph from the Generator is correct, please say "The answer is correct!" and close the chat. You must check carefully!!! """, - human_input_mode="NEVER", # 设置为 "NEVER" 以避免提示用户输入 + human_input_mode="NEVER", ) - # --- 群聊和管理器设置 --- group_chat_with_introductions = GroupChat( agents=[Concept_logic_graph_Agent, reviewer], messages=[], max_round=8, send_introductions=True, - speaker_selection_method='round_robin', # 确保轮流发言 + speaker_selection_method='round_robin', ) # --- 启动聊天 --- @@ -242,7 +243,6 @@ def auto_end_chat(): auto_end_chat() - # --- 提取、解析和保存结果 --- final_graph_string = None parsed_graph = None @@ -283,10 +283,10 @@ def auto_end_chat(): with open(output_filename, 'w', encoding='utf-8') as f: json.dump(parsed_graph, f, ensure_ascii=False, indent=4) print(f"\n--- Saved graph to {output_filename} (from direct parse) ---") - except json.JSONDecodeError: + except JSONDecodeError: print("\nCould not parse the final_graph string directly as JSON either.") - except json.JSONDecodeError as e: + except JSONDecodeError as e: print(f"\nError decoding JSON: {e}") print("String content was likely not valid JSON.") except ImportError: @@ -297,38 +297,36 @@ def auto_end_chat(): return parsed_graph -def extract_concept_from_graph(parsed_graph: dict[str, Any]) -> tuple[list[str], list[str]]: - """从解析的图谱中提取概念组合和子概念。 +def extract_concept_from_graph(parsed_graph: dict[str, dict[str, Any]]) -> tuple[list[str], list[str]]: + """extract combination of concepts and child-concept from analyzed image Args: - parsed_graph: 包含一个或多个迭代的图谱字典 + parsed_graph: graph dictionary includes at least one iteration Returns: - tuple[List[str], List[str]]: 包含概念组合列表和子概念列表的元组 + tuple[list[str], list[str]]: tuple of combination of list of concepts and list of sub-concepts """ concept_combination = [] sub_concept = [] - # 检查是否是迭代格式的图谱 if any(key.startswith('iteration_') for key in parsed_graph.keys()): - # 处理迭代格式 + for iteration_graph in parsed_graph.values(): - # 获取当前迭代的主要概念 + iteration_graph: dict[str, dict[str, Any]] + main_concept = list(iteration_graph.keys())[0].replace("_", " ") concept_combination.append(main_concept) - # 处理当前迭代的图谱 current_graph = iteration_graph[main_concept] - # 添加蕴含关系 + # 包含関係の追加 if 'entailment' in current_graph: concept_combination.extend(current_graph['entailment']) - # 添加等价关系 if 'equivalence' in current_graph: concept_combination.extend(current_graph['equivalence']) - # 添加子概念 + # add child-concept for key, value in current_graph.items(): if isinstance(value, dict): sub_concept.append(key) @@ -337,19 +335,16 @@ def extract_concept_from_graph(parsed_graph: dict[str, Any]) -> tuple[list[str], if 'equivalence' in value: sub_concept.extend(value['equivalence']) else: - # 处理单个图谱格式 + main_concept = list(parsed_graph.keys())[0].replace("_", " ") concept_combination.append(main_concept) - # 添加蕴含关系 if 'entailment' in parsed_graph[main_concept]: concept_combination.extend(parsed_graph[main_concept]['entailment']) - # 添加等价关系 if 'equivalence' in parsed_graph[main_concept]: concept_combination.extend(parsed_graph[main_concept]['equivalence']) - # 添加子概念 for key, value in parsed_graph[main_concept].items(): if isinstance(value, dict): sub_concept.append(key) @@ -358,27 +353,17 @@ def extract_concept_from_graph(parsed_graph: dict[str, Any]) -> tuple[list[str], if 'equivalence' in value: sub_concept.extend(value['equivalence']) - # 去重并返回 return list(set(concept_combination)), list(set(sub_concept)) + def generate_and_save_iterative_graphs( concept_combination_x: str, combination_theme_y: str, output_path: str, iterate_n: int = 3 -) -> dict[str, Any]: - """生成并保存迭代的概念图谱。 - - Args: - concept_combination_x: 初始概念组合 - combination_theme_y: 主题 - iterate_n: 迭代次数, 默认为3 - output_dir: 输出目录路径 - - Returns: - dict[str, Any]: 包含所有迭代图谱的字典 - """ - all_graphs = {} # 用于存储所有迭代生成的graph +) -> dict[str, dict]: + + all_graphs = {} current_concept_combination = concept_combination_x for i in range(iterate_n): @@ -391,21 +376,17 @@ def generate_and_save_iterative_graphs( print(f"concept_combination: {concept_combination}") print(f"sub_concept: {sub_concept}") - # 将当前迭代的graph添加到all_graphs中 all_graphs[f"iteration_{i}"] = generated_graph - # 更新下一个迭代的概念 - if i < iterate_n - 1: # 如果不是最后一次迭代 + if i < iterate_n - 1: current_concept_combination = generated_graph[current_concept_combination]['equivalence'][0] else: print("\n--- Function finished. Failed to generate or parse the graph. ---") break - # 保存所有迭代的graph到JSON文件 - print(output_path) os.makedirs(os.path.dirname(output_path), exist_ok=True) with open(output_path, 'w', encoding='utf-8') as f: - print(output_path+f"/{concept_combination_x}.json") + print(f"{output_path}/{concept_combination_x}.json") json.dump(all_graphs, f, ensure_ascii=False, indent=4) print(f"\nAll iteration graphs saved to: {output_path}") From 3a7b04bdf7bf16235401a5f07cb3adbddc86b600 Mon Sep 17 00:00:00 2001 From: fmuuly <64724985+fmp453@users.noreply.github.com> Date: Sun, 26 Oct 2025 13:56:08 +0900 Subject: [PATCH 03/25] make legacy autogen directory --- requirements.txt | 7 + train_methods/legacy_autogen/cache.py | 632 +++++++++ train_methods/legacy_autogen/chat.py | 318 +++++ train_methods/legacy_autogen/client.py | 1119 ++++++++++++++++ train_methods/legacy_autogen/completion.py | 1151 +++++++++++++++++ .../{ => legacy_autogen}/legacy_autogen.py | 211 +-- .../legacy_autogen_conversable_agent.py | 425 +++--- train_methods/legacy_autogen/stream.py | 91 ++ train_methods/legacy_autogen/utils.py | 726 +++++++++++ train_methods/utils_cogfd.py | 4 +- 10 files changed, 4263 insertions(+), 421 deletions(-) create mode 100644 train_methods/legacy_autogen/cache.py create mode 100644 train_methods/legacy_autogen/chat.py create mode 100644 train_methods/legacy_autogen/client.py create mode 100644 train_methods/legacy_autogen/completion.py rename train_methods/{ => legacy_autogen}/legacy_autogen.py (89%) rename train_methods/{ => legacy_autogen}/legacy_autogen_conversable_agent.py (92%) create mode 100644 train_methods/legacy_autogen/stream.py create mode 100644 train_methods/legacy_autogen/utils.py diff --git a/requirements.txt b/requirements.txt index e606e3c..a598a6e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -9,6 +9,13 @@ matplotlib==3.9.2 pydantic==2.9.2 scikit-learn==1.5.2 termcolor==3.1.0 +tiktoken==0.12.0 +diskcache==5.6.3 +redis==7.0.0 +azure-cosmos==4.14.0 +azure-identity==1.25.1 +docker==7.1.0 +flaml==2.3.6 open_clip_torch==2.29.0 bitsandbytes==0.44.1 diff --git a/train_methods/legacy_autogen/cache.py b/train_methods/legacy_autogen/cache.py new file mode 100644 index 0000000..d086c20 --- /dev/null +++ b/train_methods/legacy_autogen/cache.py @@ -0,0 +1,632 @@ +import os +import pickle +from types import TracebackType +from typing import Any, Protocol, Self, TypedDict + +import diskcache +import redis +from azure.cosmos import CosmosClient, PartitionKey +from azure.cosmos.exceptions import CosmosResourceNotFoundError + +class AbstractCache(Protocol): + """ + This protocol defines the basic interface for cache operations. + Implementing classes should provide concrete implementations for + these methods to handle caching mechanisms. + """ + + def get(self, key: str, default: Any | None = None) -> Any | None: + """ + Retrieve an item from the cache. + + Args: + key (str): The key identifying the item in the cache. + default (optional): The default value to return if the key is not found. + Defaults to None. + + Returns: + The value associated with the key if found, else the default value. + """ + ... + + def set(self, key: str, value: Any) -> None: + """ + Set an item in the cache. + + Args: + key (str): The key under which the item is to be stored. + value: The value to be stored in the cache. + """ + ... + + def close(self) -> None: + """ + Close the cache. Perform any necessary cleanup, such as closing network connections or + releasing resources. + """ + ... + + def __enter__(self) -> Self: + """ + Enter the runtime context related to this object. + + The with statement will bind this method's return value to the target(s) + specified in the as clause of the statement, if any. + """ + ... + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + ) -> None: + """ + Exit the runtime context and close the cache. + + Args: + exc_type: The exception type if an exception was raised in the context. + exc_value: The exception value if an exception was raised in the context. + traceback: The traceback if an exception was raised in the context. + """ + ... + +class DiskCache(AbstractCache): + """ + Implementation of AbstractCache using the DiskCache library. + + This class provides a concrete implementation of the AbstractCache + interface using the diskcache library for caching data on disk. + + Attributes: + cache (diskcache.Cache): The DiskCache instance used for caching. + + Methods: + __init__(self, seed): Initializes the DiskCache with the given seed. + get(self, key, default=None): Retrieves an item from the cache. + set(self, key, value): Sets an item in the cache. + close(self): Closes the cache. + __enter__(self): Context management entry. + __exit__(self, exc_type, exc_value, traceback): Context management exit. + """ + + def __init__(self, seed: str | int): + """ + Initialize the DiskCache instance. + + Args: + seed (str | int): A seed or namespace for the cache. This is used to create + a unique storage location for the cache data. + + """ + self.cache = diskcache.Cache(seed) + + def get(self, key: str, default: Any | None = None) -> Any | None: + """ + Retrieve an item from the cache. + + Args: + key (str): The key identifying the item in the cache. + default (optional): The default value to return if the key is not found. + Defaults to None. + + Returns: + The value associated with the key if found, else the default value. + """ + return self.cache.get(key, default) + + def set(self, key: str, value: Any) -> None: + """ + Set an item in the cache. + + Args: + key (str): The key under which the item is to be stored. + value: The value to be stored in the cache. + """ + self.cache.set(key, value) + + def close(self) -> None: + """ + Close the cache. + + Perform any necessary cleanup, such as closing file handles or + releasing resources. + """ + self.cache.close() + + def __enter__(self) -> Self: + """ + Enter the runtime context related to the object. + + Returns: + self: The instance itself. + """ + return self + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + ) -> None: + """ + Exit the runtime context related to the object. + + Perform cleanup actions such as closing the cache. + + Args: + exc_type: The exception type if an exception was raised in the context. + exc_value: The exception value if an exception was raised in the context. + traceback: The traceback if an exception was raised in the context. + """ + self.close() + +class RedisCache(AbstractCache): + """ + Implementation of AbstractCache using the Redis database. + + This class provides a concrete implementation of the AbstractCache + interface using the Redis database for caching data. + + Attributes: + seed (str | int): A seed or namespace used as a prefix for cache keys. + cache (redis.Redis): The Redis client used for caching. + + Methods: + __init__(self, seed, redis_url): Initializes the RedisCache with the given seed and Redis URL. + _prefixed_key(self, key): Internal method to get a namespaced cache key. + get(self, key, default=None): Retrieves an item from the cache. + set(self, key, value): Sets an item in the cache. + close(self): Closes the Redis client. + __enter__(self): Context management entry. + __exit__(self, exc_type, exc_value, traceback): Context management exit. + """ + + def __init__(self, seed: str | int, redis_url: str): + """ + Initialize the RedisCache instance. + + Args: + seed (str | int): A seed or namespace for the cache. This is used as a prefix for all cache keys. + redis_url (str): The URL for the Redis server. + + """ + self.seed = seed + self.cache = redis.Redis.from_url(redis_url) + + def _prefixed_key(self, key: str) -> str: + """ + Get a namespaced key for the cache. + + Args: + key (str): The original key. + + Returns: + str: The namespaced key. + """ + return f"autogen:{self.seed}:{key}" + + def get(self, key: str, default: Any | None = None) -> Any | None: + """ + Retrieve an item from the Redis cache. + + Args: + key (str): The key identifying the item in the cache. + default (optional): The default value to return if the key is not found. + Defaults to None. + + Returns: + The deserialized value associated with the key if found, else the default value. + """ + result = self.cache.get(self._prefixed_key(key)) + if result is None: + return default + return pickle.loads(result) + + def set(self, key: str, value: Any) -> None: + """ + Set an item in the Redis cache. + + Args: + key (str): The key under which the item is to be stored. + value: The value to be stored in the cache. + + Notes: + The value is serialized using pickle before being stored in Redis. + """ + serialized_value = pickle.dumps(value) + self.cache.set(self._prefixed_key(key), serialized_value) + + def close(self) -> None: + """ + Close the Redis client. + + Perform any necessary cleanup, such as closing network connections. + """ + self.cache.close() + + def __enter__(self) -> Self: + """ + Enter the runtime context related to the object. + + Returns: + self: The instance itself. + """ + return self + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None + ) -> None: + """ + Exit the runtime context related to the object. + + Perform cleanup actions such as closing the Redis client. + + Args: + exc_type: The exception type if an exception was raised in the context. + exc_value: The exception value if an exception was raised in the context. + traceback: The traceback if an exception was raised in the context. + """ + self.close() + +class CosmosDBConfig(TypedDict, total=False): + connection_string: str + database_id: str + container_id: str + cache_seed: str | int | None + client: CosmosClient | None + +class CosmosDBCache(AbstractCache): + """ + Synchronous implementation of AbstractCache using Azure Cosmos DB NoSQL API. + + This class provides a concrete implementation of the AbstractCache + interface using Azure Cosmos DB for caching data, with synchronous operations. + + Attributes: + seed (str | int): A seed or namespace used as a partition key. + client (CosmosClient): The Cosmos DB client used for caching. + container: The container instance used for caching. + """ + + def __init__(self, seed: str | int, cosmosdb_config: CosmosDBConfig): + """ + Initialize the CosmosDBCache instance. + + Args: + seed (str | int): A seed or namespace for the cache, used as a partition key. + connection_string (str): The connection string for the Cosmos DB account. + container_id (str): The container ID to be used for caching. + client (Optional[CosmosClient]): An existing CosmosClient instance to be used for caching. + """ + self.seed = str(seed) + self.client = cosmosdb_config.get("client") or CosmosClient.from_connection_string( + cosmosdb_config["connection_string"] + ) + database_id = cosmosdb_config.get("database_id", "autogen_cache") + self.database = self.client.get_database_client(database_id) + container_id = cosmosdb_config.get("container_id") + self.container = self.database.create_container_if_not_exists( + id=container_id, partition_key=PartitionKey(path="/partitionKey") + ) + + @classmethod + def create_cache(cls, seed: str | int, cosmosdb_config: CosmosDBConfig): + """ + Factory method to create a CosmosDBCache instance based on the provided configuration. + This method decides whether to use an existing CosmosClient or create a new one. + """ + if "client" in cosmosdb_config and isinstance(cosmosdb_config["client"], CosmosClient): + return cls.from_existing_client(seed, **cosmosdb_config) + else: + return cls.from_config(seed, cosmosdb_config) + + @classmethod + def from_config(cls, seed: str | int, cosmosdb_config: CosmosDBConfig): + return cls(str(seed), cosmosdb_config) + + @classmethod + def from_connection_string(cls, seed: str | int, connection_string: str, database_id: str, container_id: str): + config = {"connection_string": connection_string, "database_id": database_id, "container_id": container_id} + return cls(str(seed), config) + + @classmethod + def from_existing_client(cls, seed: str | int, client: CosmosClient, database_id: str, container_id: str): + config = {"client": client, "database_id": database_id, "container_id": container_id} + return cls(str(seed), config) + + def get(self, key: str, default: Any | None = None) -> Any | None: + """ + Retrieve an item from the Cosmos DB cache. + + Args: + key (str): The key identifying the item in the cache. + default (optional): The default value to return if the key is not found. + + Returns: + The deserialized value associated with the key if found, else the default value. + """ + try: + response = self.container.read_item(item=key, partition_key=str(self.seed)) + return pickle.loads(response["data"]) + except CosmosResourceNotFoundError: + return default + except Exception as e: + # Log the exception or rethrow after logging if needed + # Consider logging or handling the error appropriately here + raise e + + def set(self, key: str, value: Any) -> None: + """ + Set an item in the Cosmos DB cache. + + Args: + key (str): The key under which the item is to be stored. + value: The value to be stored in the cache. + + Notes: + The value is serialized using pickle before being stored. + """ + try: + serialized_value = pickle.dumps(value) + item = {"id": key, "partitionKey": str(self.seed), "data": serialized_value} + self.container.upsert_item(item) + except Exception as e: + # Log or handle exception + raise e + + def close(self) -> None: + """ + Close the Cosmos DB client. + + Perform any necessary cleanup, such as closing network connections. + """ + # CosmosClient doesn"t require explicit close in the current SDK + # If you created the client inside this class, you should close it if necessary + pass + + def __enter__(self): + """ + Context management entry. + + Returns: + self: The instance itself. + """ + return self + + def __exit__( + self, + exc_type: type | None, + exc_value: Exception | None, + traceback: Any | None, + ) -> None: + """ + Context management exit. + + Perform cleanup actions such as closing the Cosmos DB client. + """ + self.close() + +class CacheFactory: + @staticmethod + def cache_factory( + seed: str | int, + redis_url: str | None = None, + cache_path_root: str = ".cache", + cosmosdb_config: dict[str, Any] | None = None, + ) -> AbstractCache: + """ + Factory function for creating cache instances. + + This function decides whether to create a RedisCache, DiskCache, or CosmosDBCache instance + based on the provided parameters. If RedisCache is available and a redis_url is provided, + a RedisCache instance is created. If connection_string, database_id, and container_id + are provided, a CosmosDBCache is created. Otherwise, a DiskCache instance is used. + + Args: + seed (str | int): Used as a seed or namespace for the cache. + redis_url (str | None): URL for the Redis server. + cache_path_root (str): Root path for the disk cache. + cosmosdb_config (Optional[Dict[str, str]]): Dictionary containing 'connection_string', 'database_id', and 'container_id' for Cosmos DB cache. + + Returns: + An instance of RedisCache, DiskCache, or CosmosDBCache. + + Examples: + + Creating a Redis cache + + ```python + redis_cache = cache_factory("myseed", "redis://localhost:6379/0") + ``` + Creating a Disk cache + + ```python + disk_cache = cache_factory("myseed", None) + ``` + + Creating a Cosmos DB cache: + ```python + cosmos_cache = cache_factory("myseed", cosmosdb_config={ + "connection_string": "your_connection_string", + "database_id": "your_database_id", + "container_id": "your_container_id"} + ) + ``` + + """ + if redis_url: + return RedisCache(seed, redis_url) + + if cosmosdb_config: + return CosmosDBCache.create_cache(seed, cosmosdb_config) + + # Default to DiskCache if neither Redis nor Cosmos DB configurations are provided + path = os.path.join(cache_path_root, str(seed)) + return DiskCache(os.path.join(".", path)) + + +class Cache(AbstractCache): + """ + A wrapper class for managing cache configuration and instances. + + This class provides a unified interface for creating and interacting with + different types of cache (e.g., Redis, Disk). It abstracts the underlying + cache implementation details, providing methods for cache operations. + + Attributes: + config (dict[str, Any]): A dictionary containing cache configuration. + cache: The cache instance created based on the provided configuration. + """ + + ALLOWED_CONFIG_KEYS = [ + "cache_seed", + "redis_url", + "cache_path_root", + "cosmos_db_config", + ] + + @staticmethod + def redis(cache_seed: str | int = 42, redis_url: str = "redis://localhost:6379/0") -> "Cache": + """ + Create a Redis cache instance. + + Args: + cache_seed (str | int, optional): A seed for the cache. Defaults to 42. + redis_url (str, optional): The URL for the Redis server. Defaults to "redis://localhost:6379/0". + + Returns: + Cache: A Cache instance configured for Redis. + """ + return Cache({"cache_seed": cache_seed, "redis_url": redis_url}) + + @staticmethod + def disk(cache_seed: str | int = 42, cache_path_root: str = ".cache") -> "Cache": + """ + Create a Disk cache instance. + + Args: + cache_seed (str | int, optional): A seed for the cache. Defaults to 42. + cache_path_root (str, optional): The root path for the disk cache. Defaults to ".cache". + + Returns: + Cache: A Cache instance configured for Disk caching. + """ + return Cache({"cache_seed": cache_seed, "cache_path_root": cache_path_root}) + + @staticmethod + def cosmos_db( + connection_string: str | None = None, + container_id: str | None = None, + cache_seed: str | int = 42, + client: Any | None = None, + ) -> "Cache": + """ + Create a Cosmos DB cache instance with 'autogen_cache' as database ID. + + Args: + connection_string (str, optional): Connection string to the Cosmos DB account. + container_id (str, optional): The container ID for the Cosmos DB account. + cache_seed (str | int, optional): A seed for the cache. + client: Optional[CosmosClient]: Pass an existing Cosmos DB client. + Returns: + Cache: A Cache instance configured for Cosmos DB. + """ + cosmos_db_config = { + "connection_string": connection_string, + "database_id": "autogen_cache", + "container_id": container_id, + "client": client, + } + return Cache({"cache_seed": str(cache_seed), "cosmos_db_config": cosmos_db_config}) + + def __init__(self, config: dict[str, Any]): + """ + Initialize the Cache with the given configuration. + + Validates the configuration keys and creates the cache instance. + + Args: + config (dict[str, Any]): A dictionary containing the cache configuration. + + Raises: + ValueError: If an invalid configuration key is provided. + """ + self.config = config + # Ensure that the seed is always treated as a string before being passed to any cache factory or stored. + self.config["cache_seed"] = str(self.config.get("cache_seed", 42)) + + # validate config + for key in self.config.keys(): + if key not in self.ALLOWED_CONFIG_KEYS: + raise ValueError(f"Invalid config key: {key}") + # create cache instance + self.cache = CacheFactory.cache_factory( + seed=self.config["cache_seed"], + redis_url=self.config.get("redis_url"), + cache_path_root=self.config.get("cache_path_root"), + cosmosdb_config=self.config.get("cosmos_db_config"), + ) + + def __enter__(self) -> "Cache": + """ + Enter the runtime context related to the cache object. + + Returns: + The cache instance for use within a context block. + """ + return self.cache.__enter__() + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + ) -> None: + """ + Exit the runtime context related to the cache object. + + Cleans up the cache instance and handles any exceptions that occurred + within the context. + + Args: + exc_type: The exception type if an exception was raised in the context. + exc_value: The exception value if an exception was raised in the context. + traceback: The traceback if an exception was raised in the context. + """ + return self.cache.__exit__(exc_type, exc_value, traceback) + + def get(self, key: str, default: Any | None = None) -> Any | None: + """ + Retrieve an item from the cache. + + Args: + key (str): The key identifying the item in the cache. + default (optional): The default value to return if the key is not found. + Defaults to None. + + Returns: + The value associated with the key if found, else the default value. + """ + return self.cache.get(key, default) + + def set(self, key: str, value: Any) -> None: + """ + Set an item in the cache. + + Args: + key (str): The key under which the item is to be stored. + value: The value to be stored in the cache. + """ + self.cache.set(key, value) + + def close(self) -> None: + """ + Close the cache. + + Perform any necessary cleanup, such as closing connections or releasing resources. + """ + self.cache.close() diff --git a/train_methods/legacy_autogen/chat.py b/train_methods/legacy_autogen/chat.py new file mode 100644 index 0000000..c90dfd0 --- /dev/null +++ b/train_methods/legacy_autogen/chat.py @@ -0,0 +1,318 @@ +import asyncio +import datetime +import warnings +from collections import defaultdict +from dataclasses import dataclass +from functools import partial +from typing import Any, Callable + +from termcolor import colored + +from train_methods.legacy_autogen.stream import IOStream + +def consolidate_chat_info(chat_info, uniform_sender=None) -> None: + if isinstance(chat_info, dict): + chat_info = [chat_info] + for c in chat_info: + if uniform_sender is None: + assert "sender" in c, "sender must be provided." + sender = c["sender"] + else: + sender = uniform_sender + assert "recipient" in c, "recipient must be provided." + summary_method = c.get("summary_method") + assert ( + summary_method is None + or isinstance(summary_method, Callable) + or summary_method in ("last_msg", "reflection_with_llm") + ), "summary_method must be a string chosen from 'reflection_with_llm' or 'last_msg' or a callable, or None." + if summary_method == "reflection_with_llm": + assert ( + sender.client is not None or c["recipient"].client is not None + ), "llm client must be set in either the recipient or sender when summary_method is reflection_with_llm." + + +Prerequisite = tuple[int, int] + +@dataclass +class ChatResult: + + chat_id: int = None + """chat id""" + chat_history: list[dict[str, Any]] = None + """The chat history.""" + summary: str = None + """A summary obtained from the chat.""" + cost: dict[str, dict] = None # keys: "usage_including_cached_inference", "usage_excluding_cached_inference" + """The cost of the chat. + The value for each usage type is a dictionary containing cost information for that specific type. + - "usage_including_cached_inference": Cost information on the total usage, including the tokens in cached inference. + - "usage_excluding_cached_inference": Cost information on the usage of tokens, excluding the tokens in cache. No larger than "usage_including_cached_inference". + """ + human_input: list[str] = None + """A list of human input solicited during the chat.""" + + +def _validate_recipients(chat_queue: list[dict[str, Any]]) -> None: + """ + Validate recipients exits and warn repetitive recipients. + """ + receipts_set = set() + for chat_info in chat_queue: + assert "recipient" in chat_info, "recipient must be provided." + receipts_set.add(chat_info["recipient"]) + if len(receipts_set) < len(chat_queue): + warnings.warn( + "Repetitive recipients detected: The chat history will be cleared by default if a recipient appears more than once. To retain the chat history, please set 'clear_history=False' in the configuration of the repeating agent.", + UserWarning, + ) + + +def __create_async_prerequisites(chat_queue: list[dict[str, Any]]) -> list[Prerequisite]: + """ + Create list of Prerequisite (prerequisite_chat_id, chat_id) + """ + prerequisites = [] + for chat_info in chat_queue: + if "chat_id" not in chat_info: + raise ValueError("Each chat must have a unique id for async multi-chat execution.") + chat_id = chat_info["chat_id"] + pre_chats = chat_info.get("prerequisites", []) + for pre_chat_id in pre_chats: + if not isinstance(pre_chat_id, int): + raise ValueError("Prerequisite chat id is not int.") + prerequisites.append((chat_id, pre_chat_id)) + return prerequisites + + +def __find_async_chat_order(chat_ids: set[int], prerequisites: list[Prerequisite]) -> list[int]: + """Find chat order for async execution based on the prerequisite chats + + args: + num_chats: number of chats + prerequisites: list of Prerequisite (prerequisite_chat_id, chat_id) + + returns: + list: a list of chat_id in order. + """ + edges = defaultdict(set) + indegree = defaultdict(int) + for pair in prerequisites: + chat, pre = pair[0], pair[1] + if chat not in edges[pre]: + indegree[chat] += 1 + edges[pre].add(chat) + bfs = [i for i in chat_ids if i not in indegree] + chat_order = [] + steps = len(indegree) + for _ in range(steps + 1): + if not bfs: + break + chat_order.extend(bfs) + nxt = [] + for node in bfs: + if node in edges: + for course in edges[node]: + indegree[course] -= 1 + if indegree[course] == 0: + nxt.append(course) + indegree.pop(course) + edges.pop(node) + bfs = nxt + + if indegree: + return [] + return chat_order + + +def _post_process_carryover_item(carryover_item): + if isinstance(carryover_item, str): + return carryover_item + elif isinstance(carryover_item, dict) and "content" in carryover_item: + return str(carryover_item["content"]) + else: + return str(carryover_item) + + +def __post_carryover_processing(chat_info: dict[str, Any]) -> None: + iostream = IOStream.get_default() + + if "message" not in chat_info: + warnings.warn( + "message is not provided in a chat_queue entry. input() will be called to get the initial message.", + UserWarning, + ) + print_carryover = ( + ("\n").join([_post_process_carryover_item(t) for t in chat_info["carryover"]]) + if isinstance(chat_info["carryover"], list) + else chat_info["carryover"] + ) + message = chat_info.get("message") + if isinstance(message, str): + print_message = message + elif callable(message): + print_message = "Callable: " + message.__name__ + elif isinstance(message, dict): + print_message = "dict: " + str(message) + elif message is None: + print_message = "None" + iostream.print(colored("\n" + "*" * 80, "blue"), flush=True, sep="") + iostream.print( + colored( + "Starting a new chat....", + "blue", + ), + flush=True, + ) + if chat_info.get("verbose", False): + iostream.print(colored("Message:\n" + print_message, "blue"), flush=True) + iostream.print(colored("Carryover:\n" + print_carryover, "blue"), flush=True) + iostream.print(colored("\n" + "*" * 80, "blue"), flush=True, sep="") + + +def initiate_chats(chat_queue: list[dict[str, Any]]) -> list[ChatResult]: + """Initiate a list of chats. + Args: + chat_queue (list[dict]): A list of dictionaries containing the information about the chats. + + Each dictionary should contain the input arguments for + [`ConversableAgent.initiate_chat`](/docs/reference/agentchat/conversable_agent#initiate_chat). + For example: + - `"sender"` - the sender agent. + - `"recipient"` - the recipient agent. + - `"clear_history"` (bool) - whether to clear the chat history with the agent. + Default is True. + - `"silent"` (bool or None) - (Experimental) whether to print the messages in this + conversation. Default is False. + - `"cache"` (Cache or None) - the cache client to use for this conversation. + Default is None. + - `"max_turns"` (int or None) - maximum number of turns for the chat. If None, the chat + will continue until a termination condition is met. Default is None. + - `"summary_method"` (str or callable) - a string or callable specifying the method to get + a summary from the chat. Default is DEFAULT_summary_method, i.e., "last_msg". + - `"summary_args"` (dict) - a dictionary of arguments to be passed to the summary_method. + Default is {}. + - `"message"` (str, callable or None) - if None, input() will be called to get the + initial message. + - `**context` - additional context information to be passed to the chat. + - `"carryover"` - It can be used to specify the carryover information to be passed + to this chat. If provided, we will combine this carryover with the "message" content when + generating the initial chat message in `generate_init_message`. + - `"finished_chat_indexes_to_exclude_from_carryover"` - It can be used by specifying a list of indexes of the finished_chats list, + from which to exclude the summaries for carryover. If 'finished_chat_indexes_to_exclude_from_carryover' is not provided or an empty list, + then summary from all the finished chats will be taken. + Returns: + (list): a list of ChatResult objects corresponding to the finished chats in the chat_queue. + """ + + consolidate_chat_info(chat_queue) + _validate_recipients(chat_queue) + current_chat_queue = chat_queue.copy() + finished_chats = [] + while current_chat_queue: + chat_info = current_chat_queue.pop(0) + _chat_carryover = chat_info.get("carryover", []) + finished_chat_indexes_to_exclude_from_carryover = chat_info.get( + "finished_chat_indexes_to_exclude_from_carryover", [] + ) + + if isinstance(_chat_carryover, str): + _chat_carryover = [_chat_carryover] + chat_info["carryover"] = _chat_carryover + [ + r.summary for i, r in enumerate(finished_chats) if i not in finished_chat_indexes_to_exclude_from_carryover + ] + + if not chat_info.get("silent", False): + __post_carryover_processing(chat_info) + + sender = chat_info["sender"] + chat_res = sender.initiate_chat(**chat_info) + finished_chats.append(chat_res) + return finished_chats + + +def __system_now_str(): + ct = datetime.datetime.now() + return f" System time at {ct}. " + + +def _on_chat_future_done(chat_future: asyncio.Future, chat_id: int): + """ + Update ChatResult when async Task for Chat is completed. + """ + print(f"Update chat {chat_id} result on task completion." + __system_now_str()) + chat_result = chat_future.result() + chat_result.chat_id = chat_id + + +async def _dependent_chat_future( + chat_id: int, chat_info: dict[str, Any], prerequisite_chat_futures: dict[int, asyncio.Future] +) -> asyncio.Task: + """ + Create an async Task for each chat. + """ + print(f"Create Task for chat {chat_id}." + __system_now_str()) + _chat_carryover = chat_info.get("carryover", []) + finished_chat_indexes_to_exclude_from_carryover = chat_info.get( + "finished_chat_indexes_to_exclude_from_carryover", [] + ) + finished_chats = dict() + for chat in prerequisite_chat_futures: + chat_future = prerequisite_chat_futures[chat] + if chat_future.cancelled(): + raise RuntimeError(f"Chat {chat} is cancelled.") + + # wait for prerequisite chat results for the new chat carryover + finished_chats[chat] = await chat_future + + if isinstance(_chat_carryover, str): + _chat_carryover = [_chat_carryover] + data = [ + chat_result.summary + for chat_id, chat_result in finished_chats.items() + if chat_id not in finished_chat_indexes_to_exclude_from_carryover + ] + chat_info["carryover"] = _chat_carryover + data + if not chat_info.get("silent", False): + __post_carryover_processing(chat_info) + + sender = chat_info["sender"] + chat_res_future = asyncio.create_task(sender.a_initiate_chat(**chat_info)) + call_back_with_args = partial(_on_chat_future_done, chat_id=chat_id) + chat_res_future.add_done_callback(call_back_with_args) + print(f"Task for chat {chat_id} created." + __system_now_str()) + return chat_res_future + + +async def a_initiate_chats(chat_queue: list[dict[str, Any]]) -> dict[int, ChatResult]: + """(async) Initiate a list of chats. + + args: + - Please refer to `initiate_chats`. + + + returns: + - (dict): a dict of ChatId: ChatResult corresponding to the finished chats in the chat_queue. + """ + consolidate_chat_info(chat_queue) + _validate_recipients(chat_queue) + chat_book = {chat_info["chat_id"]: chat_info for chat_info in chat_queue} + num_chats = chat_book.keys() + prerequisites = __create_async_prerequisites(chat_queue) + chat_order_by_id = __find_async_chat_order(num_chats, prerequisites) + finished_chat_futures = dict() + for chat_id in chat_order_by_id: + chat_info = chat_book[chat_id] + prerequisite_chat_ids = chat_info.get("prerequisites", []) + pre_chat_futures = dict() + for pre_chat_id in prerequisite_chat_ids: + pre_chat_future = finished_chat_futures[pre_chat_id] + pre_chat_futures[pre_chat_id] = pre_chat_future + current_chat_future = await _dependent_chat_future(chat_id, chat_info, pre_chat_futures) + finished_chat_futures[chat_id] = current_chat_future + await asyncio.gather(*list(finished_chat_futures.values())) + finished_chats = dict() + for chat in finished_chat_futures: + chat_result = finished_chat_futures[chat].result() + finished_chats[chat] = chat_result + return finished_chats diff --git a/train_methods/legacy_autogen/client.py b/train_methods/legacy_autogen/client.py new file mode 100644 index 0000000..bae85a9 --- /dev/null +++ b/train_methods/legacy_autogen/client.py @@ -0,0 +1,1119 @@ +import json +import inspect +import time + +from typing import Protocol, Any, Callable + +import tiktoken +from openai import APIError, APITimeoutError, AzureOpenAI, OpenAI +from openai.resources import Completions +from openai.types.chat import ChatCompletion +from openai.types.chat.chat_completion import Choice, ChatCompletionMessage +from openai.types.completion import Completion +from openai.types.completion_usage import CompletionUsage +from openai.types.chat.chat_completion_chunk import ( + ChoiceDeltaFunctionCall, + ChoiceDeltaToolCall, + ChoiceDeltaToolCallFunction, +) +from pydantic import BaseModel + +from train_methods.legacy_autogen.cache import Cache +from train_methods.legacy_autogen.stream import IOStream + +NON_CACHE_KEY = [ + "api_key", + "base_url", + "api_type", + "api_version", + "azure_ad_token", + "azure_ad_token_provider", + "credentials", + "tool_config", +] + +OAI_PRICE1K = { + # https://openai.com/api/pricing/ + # gpt-4o + "gpt-4o": (0.005, 0.015), + "gpt-4o-2024-05-13": (0.005, 0.015), + "gpt-4o-2024-08-06": (0.0025, 0.01), + # gpt-4-turbo + "gpt-4-turbo-2024-04-09": (0.01, 0.03), + # gpt-4 + "gpt-4": (0.03, 0.06), + "gpt-4-32k": (0.06, 0.12), + # gpt-4o-mini + "gpt-4o-mini": (0.000150, 0.000600), + "gpt-4o-mini-2024-07-18": (0.000150, 0.000600), + # gpt-3.5 turbo + "gpt-3.5-turbo": (0.0005, 0.0015), # default is 0125 + "gpt-3.5-turbo-0125": (0.0005, 0.0015), # 16k + "gpt-3.5-turbo-instruct": (0.0015, 0.002), + # base model + "davinci-002": 0.002, + "babbage-002": 0.0004, + # old model + "gpt-4-0125-preview": (0.01, 0.03), + "gpt-4-1106-preview": (0.01, 0.03), + "gpt-4-1106-vision-preview": (0.01, 0.03), # TODO: support vision pricing of images + "gpt-3.5-turbo-1106": (0.001, 0.002), + "gpt-3.5-turbo-0613": (0.0015, 0.002), + # "gpt-3.5-turbo-16k": (0.003, 0.004), + "gpt-3.5-turbo-16k-0613": (0.003, 0.004), + "gpt-3.5-turbo-0301": (0.0015, 0.002), + "text-ada-001": 0.0004, + "text-babbage-001": 0.0005, + "text-curie-001": 0.002, + "code-cushman-001": 0.024, + "code-davinci-002": 0.1, + "text-davinci-002": 0.02, + "text-davinci-003": 0.02, + "gpt-4-0314": (0.03, 0.06), # deprecate in Sep + "gpt-4-32k-0314": (0.06, 0.12), # deprecate in Sep + "gpt-4-0613": (0.03, 0.06), + "gpt-4-32k-0613": (0.06, 0.12), + "gpt-4-turbo-preview": (0.01, 0.03), + # https://azure.microsoft.com/en-us/pricing/details/cognitive-services/openai-service/#pricing + "gpt-35-turbo": (0.0005, 0.0015), # what's the default? using 0125 here. + "gpt-35-turbo-0125": (0.0005, 0.0015), + "gpt-35-turbo-instruct": (0.0015, 0.002), + "gpt-35-turbo-1106": (0.001, 0.002), + "gpt-35-turbo-0613": (0.0015, 0.002), + "gpt-35-turbo-0301": (0.0015, 0.002), + "gpt-35-turbo-16k": (0.003, 0.004), + "gpt-35-turbo-16k-0613": (0.003, 0.004), +} + + +def get_key(config: dict[str, Any]) -> str: + """Get a unique identifier of a configuration. + + Args: + config (dict or list): A configuration. + + Returns: + tuple: A unique identifier which can be used as a key for a dict. + """ + copied = False + for key in NON_CACHE_KEY: + if key in config: + config, copied = config.copy() if not copied else config, True + config.pop(key) + return json.dumps(config, sort_keys=True) + +def _num_token_from_text(text: str, model: str = "gpt-3.5-turbo-0613"): + """Return the number of tokens used by a string.""" + try: + encoding = tiktoken.encoding_for_model(model) + except KeyError: + print(f"Model {model} not found. Using cl100k_base encoding.") + encoding = tiktoken.get_encoding("cl100k_base") + return len(encoding.encode(text)) + +def _num_token_from_messages(messages: list | dict, model="gpt-3.5-turbo-0613"): + """Return the number of tokens used by a list of messages. + + retrieved from https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb/ + """ + if isinstance(messages, dict): + messages = [messages] + + try: + encoding = tiktoken.encoding_for_model(model) + except KeyError: + print(f"Model {model} not found. Using cl100k_base encoding.") + encoding = tiktoken.get_encoding("cl100k_base") + if model in { + "gpt-3.5-turbo-0613", + "gpt-3.5-turbo-16k-0613", + "gpt-4-0314", + "gpt-4-32k-0314", + "gpt-4-0613", + "gpt-4-32k-0613", + }: + tokens_per_message = 3 + tokens_per_name = 1 + elif model == "gpt-3.5-turbo-0301": + tokens_per_message = 4 # every message follows <|start|>{role/name}\n{content}<|end|>\n + tokens_per_name = -1 # if there's a name, the role is omitted + elif "gpt-3.5-turbo" in model: + print("gpt-3.5-turbo may update over time. Returning num tokens assuming gpt-3.5-turbo-0613.") + return _num_token_from_messages(messages, model="gpt-3.5-turbo-0613") + elif "gpt-4" in model: + print("gpt-4 may update over time. Returning num tokens assuming gpt-4-0613.") + return _num_token_from_messages(messages, model="gpt-4-0613") + elif "gemini" in model: + print("Gemini is not supported in tiktoken. Returning num tokens assuming gpt-4-0613.") + return _num_token_from_messages(messages, model="gpt-4-0613") + elif "claude" in model: + print("Claude is not supported in tiktoken. Returning num tokens assuming gpt-4-0613.") + return _num_token_from_messages(messages, model="gpt-4-0613") + elif "mistral-" in model or "mixtral-" in model: + print("Mistral.AI models are not supported in tiktoken. Returning num tokens assuming gpt-4-0613.") + return _num_token_from_messages(messages, model="gpt-4-0613") + else: + raise NotImplementedError( + f"""_num_token_from_messages() is not implemented for model {model}. See https://github.com/openai/openai-python/blob/main/chatml.md for information on how messages are converted to tokens.""" + ) + num_tokens = 0 + for message in messages: + num_tokens += tokens_per_message + for key, value in message.items(): + if value is None: + continue + + # function calls + if not isinstance(value, str): + try: + value = json.dumps(value) + except TypeError: + print( + f"Value {value} is not a string and cannot be converted to json. It is a type: {type(value)} Skipping." + ) + continue + + num_tokens += len(encoding.encode(value)) + if key == "name": + num_tokens += tokens_per_name + num_tokens += 3 # every reply is primed with <|start|>assistant<|message|> + return num_tokens + +def count_token(input: str | list | dict, model: str = "gpt-3.5-turbo-0613") -> int: + """Count number of tokens used by an OpenAI model. + Args: + input: (str, list, dict): Input to the model. + model: (str): Model name. + + Returns: + int: Number of tokens from the input. + """ + if isinstance(input, str): + return _num_token_from_text(input, model=model) + elif isinstance(input, list) or isinstance(input, dict): + return _num_token_from_messages(input, model=model) + else: + raise ValueError(f"input must be str, list or dict, but we got {type(input)}") + + +class PlaceHolderClient: + def __init__(self, config): + self.config = config + +class ModelClient(Protocol): + """ + A client class must implement the following methods: + - create must return a response object that implements the ModelClientResponseProtocol + - cost must return the cost of the response + - get_usage must return a dict with the following keys: + - prompt_tokens + - completion_tokens + - total_tokens + - cost + - model + + This class is used to create a client that can be used by OpenAIWrapper. + The response returned from create must adhere to the ModelClientResponseProtocol but can be extended however needed. + The message_retrieval method must be implemented to return a list of str or a list of messages from the response. + """ + + RESPONSE_USAGE_KEYS = ["prompt_tokens", "completion_tokens", "total_tokens", "cost", "model"] + + class ModelClientResponseProtocol(Protocol): + class Choice(Protocol): + class Message(Protocol): + content: str | None + + message: Message + + choices: list[Choice] + model: str + + def create(self, params: dict[str, Any]) -> ModelClientResponseProtocol: ... # pragma: no cover + + def message_retrieval( + self, response: ModelClientResponseProtocol + ) -> list[str] | list[ModelClientResponseProtocol.Choice.Message]: + """ + Retrieve and return a list of strings or a list of Choice.Message from the response. + + NOTE: if a list of Choice.Message is returned, it currently needs to contain the fields of OpenAI's ChatCompletion Message object, + since that is expected for function or tool calling in the rest of the codebase at the moment, unless a custom agent is being used. + """ + ... # pragma: no cover + + def cost(self, response: ModelClientResponseProtocol) -> float: ... # pragma: no cover + + @staticmethod + def get_usage(response: ModelClientResponseProtocol) -> dict: + """Return usage summary of the response using RESPONSE_USAGE_KEYS.""" + ... # pragma: no cover + +class RateLimiter(Protocol): + def sleep(self, *args, **kwargs): ... + +class TimeRateLimiter: + """A class to implement a time-based rate limiter. + + This rate limiter ensures that a certain operation does not exceed a specified frequency. + It can be used to limit the rate of requests sent to a server or the rate of any repeated action. + """ + + def __init__(self, rate: float): + """ + Args: + rate (int): The frequency of the time-based rate limiter (NOT time interval). + """ + self._time_interval_seconds = 1.0 / rate + self._last_time_called = 0.0 + + def sleep(self, *args, **kwargs): + """Synchronously waits until enough time has passed to allow the next operation. + + If the elapsed time since the last operation is less than the required time interval, + this method will block the execution by sleeping for the remaining time. + """ + if self._elapsed_time() < self._time_interval_seconds: + time.sleep(self._time_interval_seconds - self._elapsed_time()) + + self._last_time_called = time.perf_counter() + + def _elapsed_time(self): + return time.perf_counter() - self._last_time_called + +class OpenAIClient: + """Follows the Client protocol and wraps the OpenAI client.""" + + def __init__(self, client: OpenAI | AzureOpenAI): + self._oai_client = client + + def message_retrieval( + self, response: ChatCompletion | Completion + ) -> list[str] | list[ChatCompletionMessage]: + """Retrieve the messages from the response.""" + choices = response.choices + if isinstance(response, Completion): + return [choice.text for choice in choices] # type: ignore [union-attr] + + return [ # type: ignore [return-value] + ( + choice.message # type: ignore [union-attr] + if choice.message.function_call is not None or choice.message.tool_calls is not None # type: ignore [union-attr] + else choice.message.content + ) # type: ignore [union-attr] + for choice in choices + ] + + def create(self, params: dict[str, Any]) -> ChatCompletion: + """Create a completion for a given config using openai's client. + + Args: + client: The openai client. + params: The params for the completion. + + Returns: + The completion. + """ + iostream = IOStream.get_default() + + completions: Completions = ( + self._oai_client.chat.completions if "messages" in params else self._oai_client.completions + ) # type: ignore [attr-defined] + # If streaming is enabled and has messages, then iterate over the chunks of the response. + if params.get("stream", False) and "messages" in params: + response_contents = [""] * params.get("n", 1) + finish_reasons = [""] * params.get("n", 1) + completion_tokens = 0 + + # Set the terminal text color to green + iostream.print("\033[32m", end="") + + # Prepare for potential function call + full_function_call: dict[str, Any] | None = None + full_tool_calls: list[dict[str, Any | None]] | None = None + + # Send the chat completion request to OpenAI's API and process the response in chunks + for chunk in completions.create(**params): + if chunk.choices: + for choice in chunk.choices: + content = choice.delta.content + tool_calls_chunks = choice.delta.tool_calls + finish_reasons[choice.index] = choice.finish_reason + + # todo: remove this after function calls are removed from the API + # the code should work regardless of whether function calls are removed or not, but test_chat_functions_stream should fail + # begin block + function_call_chunk = ( + choice.delta.function_call if hasattr(choice.delta, "function_call") else None + ) + # Handle function call + if function_call_chunk: + # Handle function call + if function_call_chunk: + full_function_call, completion_tokens = OpenAIWrapper._update_function_call_from_chunk( + function_call_chunk, full_function_call, completion_tokens + ) + if not content: + continue + # end block + + # Handle tool calls + if tool_calls_chunks: + for tool_calls_chunk in tool_calls_chunks: + # the current tool call to be reconstructed + ix = tool_calls_chunk.index + if full_tool_calls is None: + full_tool_calls = [] + if ix >= len(full_tool_calls): + # in case ix is not sequential + full_tool_calls = full_tool_calls + [None] * (ix - len(full_tool_calls) + 1) + + full_tool_calls[ix], completion_tokens = OpenAIWrapper._update_tool_calls_from_chunk( + tool_calls_chunk, full_tool_calls[ix], completion_tokens + ) + if not content: + continue + + # End handle tool calls + + # If content is present, print it to the terminal and update response variables + if content is not None: + iostream.print(content, end="", flush=True) + response_contents[choice.index] += content + completion_tokens += 1 + else: + # iostream.print() + pass + + # Reset the terminal text color + iostream.print("\033[0m\n") + + # Prepare the final ChatCompletion object based on the accumulated data + model = chunk.model.replace("gpt-35", "gpt-3.5") # hack for Azure API + try: + prompt_tokens = count_token(params["messages"], model) + except NotImplementedError as e: + # Catch token calculation error if streaming with customized models. + print(str(e)) + prompt_tokens = 0 + response = ChatCompletion( + id=chunk.id, + model=chunk.model, + created=chunk.created, + object="chat.completion", + choices=[], + usage=CompletionUsage( + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=prompt_tokens + completion_tokens, + ), + ) + for i in range(len(response_contents)): + choice = Choice( + index=i, + finish_reason=finish_reasons[i], + message=ChatCompletionMessage( + role="assistant", + content=response_contents[i], + function_call=full_function_call, + tool_calls=full_tool_calls, + ), + logprobs=None, + ) + + response.choices.append(choice) + else: + # If streaming is not enabled, send a regular chat completion request + params = params.copy() + params["stream"] = False + response = completions.create(**params) + + return response + + def cost(self, response: ChatCompletion | Completion) -> float: + """Calculate the cost of the response.""" + model = response.model + if model not in OAI_PRICE1K: + # log warning that the model is not found + print( + f'Model {model} is not found. The cost will be 0. In your config_list, add field {{"price" : [prompt_price_per_1k, completion_token_price_per_1k]}} for customized pricing.' + ) + return 0 + + n_input_tokens = response.usage.prompt_tokens if response.usage is not None else 0 # type: ignore [union-attr] + n_output_tokens = response.usage.completion_tokens if response.usage is not None else 0 # type: ignore [union-attr] + if n_output_tokens is None: + n_output_tokens = 0 + tmp_price1K = OAI_PRICE1K[model] + # First value is input token rate, second value is output token rate + if isinstance(tmp_price1K, tuple): + return (tmp_price1K[0] * n_input_tokens + tmp_price1K[1] * n_output_tokens) / 1000 # type: ignore [no-any-return] + return tmp_price1K * (n_input_tokens + n_output_tokens) / 1000 # type: ignore [operator] + + @staticmethod + def get_usage(response: ChatCompletion | Completion) -> dict: + return { + "prompt_tokens": response.usage.prompt_tokens if response.usage is not None else 0, + "completion_tokens": response.usage.completion_tokens if response.usage is not None else 0, + "total_tokens": response.usage.total_tokens if response.usage is not None else 0, + "cost": response.cost if hasattr(response, "cost") else 0, + "model": response.model, + } + +class OpenAIWrapper: + """A wrapper class for openai client.""" + + extra_kwargs = { + "agent", + "cache", + "cache_seed", + "filter_func", + "allow_format_str_template", + "context", + "api_version", + "api_type", + "tags", + "price", + } + + openai_kwargs = set(inspect.getfullargspec(OpenAI.__init__).kwonlyargs) + aopenai_kwargs = set(inspect.getfullargspec(AzureOpenAI.__init__).kwonlyargs) + openai_kwargs = openai_kwargs | aopenai_kwargs + total_usage_summary: dict[str, Any] | None = None + actual_usage_summary: dict[str, Any] | None = None + + def __init__(self, *, config_list: list[dict[str, Any]] | None = None, **base_config: Any): + """ + Args: + config_list: a list of config dicts to override the base_config. + They can contain additional kwargs as allowed in the [create](/docs/reference/oai/client#create) method. E.g., + + ```python + config_list=[ + { + "model": "gpt-4", + "api_key": os.environ.get("AZURE_OPENAI_API_KEY"), + "api_type": "azure", + "base_url": os.environ.get("AZURE_OPENAI_API_BASE"), + "api_version": "2024-02-01", + }, + { + "model": "gpt-3.5-turbo", + "api_key": os.environ.get("OPENAI_API_KEY"), + "api_type": "openai", + "base_url": "https://api.openai.com/v1", + }, + { + "model": "llama-7B", + "base_url": "http://127.0.0.1:8080", + } + ] + ``` + + base_config: base config. It can contain both keyword arguments for openai client + and additional kwargs. + When using OpenAI or Azure OpenAI endpoints, please specify a non-empty 'model' either in `base_config` or in each config of `config_list`. + """ + + openai_config, extra_kwargs = self._separate_openai_config(base_config) + # It's OK if "model" is not provided in base_config or config_list + # Because one can provide "model" at `create` time. + + self._clients: list[ModelClient] = [] + self._config_list: list[dict[str, Any]] = [] + self._rate_limiters: list[RateLimiter | None] = [] + + if config_list: + self._initialize_rate_limiters(config_list) + + config_list = [config.copy() for config in config_list] # make a copy before modifying + for config in config_list: + self._register_default_client(config, openai_config) # could modify the config + self._config_list.append( + {**extra_kwargs, **{k: v for k, v in config.items() if k not in self.openai_kwargs}} + ) + else: + self._register_default_client(extra_kwargs, openai_config) + self._config_list = [extra_kwargs] + self.wrapper_id = id(self) + + def _separate_openai_config(self, config: dict[str, Any]) -> tuple[dict[str, Any], dict[str, Any]]: + """Separate the config into openai_config and extra_kwargs.""" + openai_config = {k: v for k, v in config.items() if k in self.openai_kwargs} + extra_kwargs = {k: v for k, v in config.items() if k not in self.openai_kwargs} + return openai_config, extra_kwargs + + def _separate_create_config(self, config: dict[str, Any]) -> tuple[dict[str, Any], dict[str, Any]]: + """Separate the config into create_config and extra_kwargs.""" + create_config = {k: v for k, v in config.items() if k not in self.extra_kwargs} + extra_kwargs = {k: v for k, v in config.items() if k in self.extra_kwargs} + return create_config, extra_kwargs + + def _configure_azure_openai(self, config: dict[str, Any], openai_config: dict[str, Any]) -> None: + openai_config["azure_deployment"] = openai_config.get("azure_deployment", config.get("model")) + if openai_config["azure_deployment"] is not None: + openai_config["azure_deployment"] = openai_config["azure_deployment"].replace(".", "") + openai_config["azure_endpoint"] = openai_config.get("azure_endpoint", openai_config.pop("base_url", None)) + + # Create a default Azure token provider if requested + if openai_config.get("azure_ad_token_provider") == "DEFAULT": + import azure.identity + + azure_ad_token_provider_scope = openai_config.get( + "azure_ad_token_provider_scope", "https://cognitiveservices.azure.com/.default" + ) + openai_config["azure_ad_token_provider"] = azure.identity.get_bearer_token_provider( + azure.identity.DefaultAzureCredential(), azure_ad_token_provider_scope + ) + + def _configure_openai_config_for_bedrock(self, config: dict[str, Any], openai_config: dict[str, Any]) -> None: + """Update openai_config with AWS credentials from config.""" + required_keys = ["aws_access_key", "aws_secret_key", "aws_region"] + optional_keys = ["aws_session_token", "aws_profile_name"] + for key in required_keys: + if key in config: + openai_config[key] = config[key] + for key in optional_keys: + if key in config: + openai_config[key] = config[key] + + def _register_default_client(self, config: dict[str, Any], openai_config: dict[str, Any]) -> None: + """Create a client with the given config to override openai_config, + after removing extra kwargs. + + For Azure models/deployment names there's a convenience modification of model removing dots in + the it's value (Azure deployment names can't have dots). I.e. if you have Azure deployment name + "gpt-35-turbo" and define model "gpt-3.5-turbo" in the config the function will remove the dot + from the name and create a client that connects to "gpt-35-turbo" Azure deployment. + """ + openai_config = {**openai_config, **{k: v for k, v in config.items() if k in self.openai_kwargs}} + api_type = config.get("api_type") + model_client_cls_name = config.get("model_client_cls") + if model_client_cls_name is not None: + # a config for a custom client is set + # adding placeholder until the register_model_client is called with the appropriate class + self._clients.append(PlaceHolderClient(config)) + print( + f"Detected custom model client in config: {model_client_cls_name}, model client can not be used until register_model_client is called." + ) + else: + if api_type is not None and api_type.startswith("azure"): + self._configure_azure_openai(config, openai_config) + client = AzureOpenAI(**openai_config) + self._clients.append(OpenAIClient(client)) + elif api_type is not None and api_type.startswith("cerebras"): + if cerebras_import_exception: + raise ImportError("Please install `cerebras_cloud_sdk` to use Cerebras OpenAI API.") + client = CerebrasClient(**openai_config) + self._clients.append(client) + elif api_type is not None and api_type.startswith("google"): + if gemini_import_exception: + raise ImportError("Please install `google-generativeai` to use Google OpenAI API.") + client = GeminiClient(**openai_config) + self._clients.append(client) + elif api_type is not None and api_type.startswith("anthropic"): + if "api_key" not in config: + self._configure_openai_config_for_bedrock(config, openai_config) + if anthropic_import_exception: + raise ImportError("Please install `anthropic` to use Anthropic API.") + client = AnthropicClient(**openai_config) + self._clients.append(client) + elif api_type is not None and api_type.startswith("mistral"): + if mistral_import_exception: + raise ImportError("Please install `mistralai` to use the Mistral.AI API.") + client = MistralAIClient(**openai_config) + self._clients.append(client) + elif api_type is not None and api_type.startswith("together"): + if together_import_exception: + raise ImportError("Please install `together` to use the Together.AI API.") + client = TogetherClient(**openai_config) + self._clients.append(client) + elif api_type is not None and api_type.startswith("groq"): + if groq_import_exception: + raise ImportError("Please install `groq` to use the Groq API.") + client = GroqClient(**openai_config) + self._clients.append(client) + elif api_type is not None and api_type.startswith("cohere"): + if cohere_import_exception: + raise ImportError("Please install `cohere` to use the Cohere API.") + client = CohereClient(**openai_config) + self._clients.append(client) + elif api_type is not None and api_type.startswith("ollama"): + if ollama_import_exception: + raise ImportError("Please install with `[ollama]` option to use the Ollama API.") + client = OllamaClient(**openai_config) + self._clients.append(client) + elif api_type is not None and api_type.startswith("bedrock"): + self._configure_openai_config_for_bedrock(config, openai_config) + if bedrock_import_exception: + raise ImportError("Please install `boto3` to use the Amazon Bedrock API.") + client = BedrockClient(**openai_config) + self._clients.append(client) + else: + client = OpenAI(**openai_config) + self._clients.append(OpenAIClient(client)) + + def register_model_client(self, model_client_cls: ModelClient, **kwargs): + """Register a model client. + + Args: + model_client_cls: A custom client class that follows the ModelClient interface + **kwargs: The kwargs for the custom client class to be initialized with + """ + existing_client_class = False + for i, client in enumerate(self._clients): + if isinstance(client, PlaceHolderClient): + placeholder_config = client.config + + if placeholder_config.get("model_client_cls") == model_client_cls.__name__: + self._clients[i] = model_client_cls(placeholder_config, **kwargs) + return + elif isinstance(client, model_client_cls): + existing_client_class = True + + if existing_client_class: + print( + f"Model client {model_client_cls.__name__} is already registered. Add more entries in the config_list to use multiple model clients." + ) + else: + raise ValueError( + f'Model client "{model_client_cls.__name__}" is being registered but was not found in the config_list. ' + f'Please make sure to include an entry in the config_list with "model_client_cls": "{model_client_cls.__name__}"' + ) + + @classmethod + def instantiate( + cls, + template: str | Callable[[dict[str, Any]], str] | None, + context: dict[str, Any] | None = None, + allow_format_str_template: bool | None = False, + ) -> str | None: + if not context or template is None: + return template # type: ignore [return-value] + if isinstance(template, str): + return template.format(**context) if allow_format_str_template else template + return template(context) + + def _construct_create_params(self, create_config: dict[str, Any], extra_kwargs: dict[str, Any]) -> dict[str, Any]: + """Prime the create_config with additional_kwargs.""" + # Validate the config + prompt: str | None = create_config.get("prompt") + messages: list[dict[str, Any]] | None = create_config.get("messages") + if (prompt is None) == (messages is None): + raise ValueError("Either prompt or messages should be in create config but not both.") + context = extra_kwargs.get("context") + if context is None: + # No need to instantiate if no context is provided. + return create_config + # Instantiate the prompt or messages + allow_format_str_template = extra_kwargs.get("allow_format_str_template", False) + # Make a copy of the config + params = create_config.copy() + if prompt is not None: + # Instantiate the prompt + params["prompt"] = self.instantiate(prompt, context, allow_format_str_template) + elif context: + # Instantiate the messages + params["messages"] = [ + ( + { + **m, + "content": self.instantiate(m["content"], context, allow_format_str_template), + } + if m.get("content") + else m + ) + for m in messages # type: ignore [union-attr] + ] + return params + + def create(self, **config: Any) -> ModelClient.ModelClientResponseProtocol: + """Make a completion for a given config using available clients. + Besides the kwargs allowed in openai's [or other] client, we allow the following additional kwargs. + The config in each client will be overridden by the config. + + Args: + - context (Dict | None): The context to instantiate the prompt or messages. Default to None. + It needs to contain keys that are used by the prompt template or the filter function. + E.g., `prompt="Complete the following sentence: {prefix}, context={"prefix": "Today I feel"}`. + The actual prompt will be: + "Complete the following sentence: Today I feel". + More examples can be found at [templating](/docs/Use-Cases/enhanced_inference#templating). + - cache (AbstractCache | None): A Cache object to use for response cache. Default to None. + Note that the cache argument overrides the legacy cache_seed argument: if this argument is provided, + then the cache_seed argument is ignored. If this argument is not provided or None, + then the cache_seed argument is used. + - agent (AbstractAgent | None): The object responsible for creating a completion if an agent. + - (Legacy) cache_seed (int | None) for using the DiskCache. Default to 41. + An integer cache_seed is useful when implementing "controlled randomness" for the completion. + None for no caching. + Note: this is a legacy argument. It is only used when the cache argument is not provided. + - filter_func (Callable | None): A function that takes in the context and the response + and returns a boolean to indicate whether the response is valid. E.g., + + ```python + def yes_or_no_filter(context, response): + return context.get("yes_or_no_choice", False) is False or any( + text in ["Yes.", "No."] for text in client.extract_text_or_completion_object(response) + ) + ``` + + - allow_format_str_template (bool | None): Whether to allow format string template in the config. Default to false. + - api_version (str | None): The api version. Default to None. E.g., "2024-02-01". + Raises: + - RuntimeError: If all declared custom model clients are not registered + - APIError: If any model client create call raises an APIError + """ + + last = len(self._clients) - 1 + # Check if all configs in config list are activated + non_activated = [ + client.config["model_client_cls"] for client in self._clients if isinstance(client, PlaceHolderClient) + ] + if non_activated: + raise RuntimeError( + f"Model client(s) {non_activated} are not activated. Please register the custom model clients using `register_model_client` or filter them out form the config list." + ) + for i, client in enumerate(self._clients): + # merge the input config with the i-th config in the config list + full_config = {**config, **self._config_list[i]} + # separate the config into create_config and extra_kwargs + create_config, extra_kwargs = self._separate_create_config(full_config) + api_type = extra_kwargs.get("api_type") + if api_type and api_type.startswith("azure") and "model" in create_config: + create_config["model"] = create_config["model"].replace(".", "") + # construct the create params + params = self._construct_create_params(create_config, extra_kwargs) + # get the cache_seed, filter_func and context + cache_seed = extra_kwargs.get("cache_seed", 41) + cache = extra_kwargs.get("cache") + filter_func = extra_kwargs.get("filter_func") + context = extra_kwargs.get("context") + price = extra_kwargs.get("price", None) + if isinstance(price, list): + price = tuple(price) + elif isinstance(price, float) or isinstance(price, int): + print( + "Input price is a float/int. Using the same price for prompt and completion tokens. Use a list/tuple if prompt and completion token prices are different." + ) + price = (price, price) + + total_usage = None + actual_usage = None + + cache_client = None + if cache is not None: + # Use the cache object if provided. + cache_client = cache + elif cache_seed is not None: + # Legacy cache behavior, if cache_seed is given, use DiskCache. + cache_client = Cache.disk(cache_seed, ".cache") + + if cache_client is not None: + with cache_client as cache: + # Try to get the response from cache + key = get_key(params) + + response: ModelClient.ModelClientResponseProtocol = cache.get(key, None) + + if response is not None: + response.message_retrieval_function = client.message_retrieval + try: + response.cost # type: ignore [attr-defined] + except AttributeError: + # update attribute if cost is not calculated + response.cost = client.cost(response) + cache.set(key, response) + total_usage = client.get_usage(response) + + # check the filter + pass_filter = filter_func is None or filter_func(context=context, response=response) + if pass_filter or i == last: + # Return the response if it passes the filter or it is the last client + response.config_id = i + response.pass_filter = pass_filter + self._update_usage(actual_usage=actual_usage, total_usage=total_usage) + return response + continue # filter is not passed; try the next config + try: + self._throttle_api_calls(i) + response = client.create(params) + except APITimeoutError as err: + if i == last: + raise TimeoutError( + "OpenAI API call timed out. This could be due to congestion or too small a timeout value. The timeout can be specified by setting the 'timeout' value (in seconds) in the llm_config (if you are using agents) or the OpenAIWrapper constructor (if you are using the OpenAIWrapper directly)." + ) from err + except APIError as err: + error_code = getattr(err, "code", None) + if error_code == "content_filter": + # raise the error for content_filter + raise + if i == last: + raise + else: + # add cost calculation before caching no matter filter is passed or not + if price is not None: + response.cost = self._cost_with_customized_price(response, price) + else: + response.cost = client.cost(response) + actual_usage = client.get_usage(response) + total_usage = actual_usage.copy() if actual_usage is not None else total_usage + self._update_usage(actual_usage=actual_usage, total_usage=total_usage) + if cache_client is not None: + # Cache the response + with cache_client as cache: + cache.set(key, response) + + response.message_retrieval_function = client.message_retrieval + # check the filter + pass_filter = filter_func is None or filter_func(context=context, response=response) + if pass_filter or i == last: + # Return the response if it passes the filter or it is the last client + response.config_id = i + response.pass_filter = pass_filter + return response + continue # filter is not passed; try the next config + raise RuntimeError("Should not reach here.") + + @staticmethod + def _cost_with_customized_price( + response: ModelClient.ModelClientResponseProtocol, price_1k: tuple[float, float] + ) -> None: + """If a customized cost is passed, overwrite the cost in the response.""" + n_input_tokens = response.usage.prompt_tokens if response.usage is not None else 0 # type: ignore [union-attr] + n_output_tokens = response.usage.completion_tokens if response.usage is not None else 0 # type: ignore [union-attr] + if n_output_tokens is None: + n_output_tokens = 0 + return (n_input_tokens * price_1k[0] + n_output_tokens * price_1k[1]) / 1000 + + @staticmethod + def _update_dict_from_chunk(chunk: BaseModel, d: dict[str, Any], field: str) -> int: + """Update the dict from the chunk. + + Reads `chunk.field` and if present updates `d[field]` accordingly. + + Args: + chunk: The chunk. + d: The dict to be updated in place. + field: The field. + + Returns: + The updated dict. + + """ + completion_tokens = 0 + assert isinstance(d, dict), d + if hasattr(chunk, field) and getattr(chunk, field) is not None: + new_value = getattr(chunk, field) + if isinstance(new_value, list) or isinstance(new_value, dict): + raise NotImplementedError( + f"Field {field} is a list or dict, which is currently not supported. " + "Only string and numbers are supported." + ) + if field not in d: + d[field] = "" + if isinstance(new_value, str): + d[field] += getattr(chunk, field) + else: + d[field] = new_value + completion_tokens = 1 + + return completion_tokens + + @staticmethod + def _update_function_call_from_chunk( + function_call_chunk: ChoiceDeltaToolCallFunction | ChoiceDeltaFunctionCall, + full_function_call: dict[str, Any] | None, + completion_tokens: int, + ) -> tuple[dict[str, Any], int]: + """Update the function call from the chunk. + + Args: + function_call_chunk: The function call chunk. + full_function_call: The full function call. + completion_tokens: The number of completion tokens. + + Returns: + The updated full function call and the updated number of completion tokens. + + """ + # Handle function call + if function_call_chunk: + if full_function_call is None: + full_function_call = {} + for field in ["name", "arguments"]: + completion_tokens += OpenAIWrapper._update_dict_from_chunk( + function_call_chunk, full_function_call, field + ) + + if full_function_call: + return full_function_call, completion_tokens + else: + raise RuntimeError("Function call is not found, this should not happen.") + + @staticmethod + def _update_tool_calls_from_chunk( + tool_calls_chunk: ChoiceDeltaToolCall, + full_tool_call: dict[str, Any] | None, + completion_tokens: int, + ) -> tuple[dict[str, Any], int]: + """Update the tool call from the chunk. + + Args: + tool_call_chunk: The tool call chunk. + full_tool_call: The full tool call. + completion_tokens: The number of completion tokens. + + Returns: + The updated full tool call and the updated number of completion tokens. + + """ + # future proofing for when tool calls other than function calls are supported + if tool_calls_chunk.type and tool_calls_chunk.type != "function": + raise NotImplementedError( + f"Tool call type {tool_calls_chunk.type} is currently not supported. " + "Only function calls are supported." + ) + + # Handle tool call + assert full_tool_call is None or isinstance(full_tool_call, dict), full_tool_call + if tool_calls_chunk: + if full_tool_call is None: + full_tool_call = {} + for field in ["index", "id", "type"]: + completion_tokens += OpenAIWrapper._update_dict_from_chunk(tool_calls_chunk, full_tool_call, field) + + if hasattr(tool_calls_chunk, "function") and tool_calls_chunk.function: + if "function" not in full_tool_call: + full_tool_call["function"] = None + + full_tool_call["function"], completion_tokens = OpenAIWrapper._update_function_call_from_chunk( + tool_calls_chunk.function, full_tool_call["function"], completion_tokens + ) + + if full_tool_call: + return full_tool_call, completion_tokens + else: + raise RuntimeError("Tool call is not found, this should not happen.") + + def _update_usage(self, actual_usage, total_usage): + def update_usage(usage_summary, response_usage): + # go through RESPONSE_USAGE_KEYS and check that they are in response_usage and if not just return usage_summary + for key in ModelClient.RESPONSE_USAGE_KEYS: + if key not in response_usage: + return usage_summary + + model = response_usage["model"] + cost = response_usage["cost"] + prompt_tokens = response_usage["prompt_tokens"] + completion_tokens = response_usage["completion_tokens"] + if completion_tokens is None: + completion_tokens = 0 + total_tokens = response_usage["total_tokens"] + + if usage_summary is None: + usage_summary = {"total_cost": cost} + else: + usage_summary["total_cost"] += cost + + usage_summary[model] = { + "cost": usage_summary.get(model, {}).get("cost", 0) + cost, + "prompt_tokens": usage_summary.get(model, {}).get("prompt_tokens", 0) + prompt_tokens, + "completion_tokens": usage_summary.get(model, {}).get("completion_tokens", 0) + completion_tokens, + "total_tokens": usage_summary.get(model, {}).get("total_tokens", 0) + total_tokens, + } + return usage_summary + + if total_usage is not None: + self.total_usage_summary = update_usage(self.total_usage_summary, total_usage) + if actual_usage is not None: + self.actual_usage_summary = update_usage(self.actual_usage_summary, actual_usage) + + def print_usage_summary(self, mode: str | list[str] = ["actual", "total"]) -> None: + """Print the usage summary.""" + iostream = IOStream.get_default() + + def print_usage(usage_summary: dict[str, Any] | None, usage_type: str = "total") -> None: + word_from_type = "including" if usage_type == "total" else "excluding" + if usage_summary is None: + iostream.print("No actual cost incurred (all completions are using cache).", flush=True) + return + + iostream.print(f"Usage summary {word_from_type} cached usage: ", flush=True) + iostream.print(f"Total cost: {round(usage_summary['total_cost'], 5)}", flush=True) + for model, counts in usage_summary.items(): + if model == "total_cost": + continue # + iostream.print( + f"* Model '{model}': cost: {round(counts['cost'], 5)}, prompt_tokens: {counts['prompt_tokens']}, completion_tokens: {counts['completion_tokens']}, total_tokens: {counts['total_tokens']}", + flush=True, + ) + + if self.total_usage_summary is None: + iostream.print('No usage summary. Please call "create" first.', flush=True) + return + + if isinstance(mode, list): + if len(mode) == 0 or len(mode) > 2: + raise ValueError(f'Invalid mode: {mode}, choose from "actual", "total", ["actual", "total"]') + if "actual" in mode and "total" in mode: + mode = "both" + elif "actual" in mode: + mode = "actual" + elif "total" in mode: + mode = "total" + + iostream.print("-" * 100, flush=True) + if mode == "both": + print_usage(self.actual_usage_summary, "actual") + iostream.print() + if self.total_usage_summary != self.actual_usage_summary: + print_usage(self.total_usage_summary, "total") + else: + iostream.print( + "All completions are non-cached: the total cost with cached completions is the same as actual cost.", + flush=True, + ) + elif mode == "total": + print_usage(self.total_usage_summary, "total") + elif mode == "actual": + print_usage(self.actual_usage_summary, "actual") + else: + raise ValueError(f'Invalid mode: {mode}, choose from "actual", "total", ["actual", "total"]') + iostream.print("-" * 100, flush=True) + + def clear_usage_summary(self) -> None: + """Clear the usage summary.""" + self.total_usage_summary = None + self.actual_usage_summary = None + + @classmethod + def extract_text_or_completion_object( + cls, response: ModelClient.ModelClientResponseProtocol + ) -> list[str] | list[ModelClient.ModelClientResponseProtocol.Choice.Message]: + """Extract the text or ChatCompletion objects from a completion or chat response. + + Args: + response (ChatCompletion | Completion): The response from openai. + + Returns: + A list of text, or a list of ChatCompletion objects if function_call/tool_calls are present. + """ + return response.message_retrieval_function(response) + + def _throttle_api_calls(self, idx: int) -> None: + """Rate limit api calls.""" + if idx < len(self._rate_limiters) and self._rate_limiters[idx]: + limiter = self._rate_limiters[idx] + + assert limiter is not None + limiter.sleep() + + def _initialize_rate_limiters(self, config_list: list[dict[str, Any]]) -> None: + for config in config_list: + # Instantiate the rate limiter + if "api_rate_limit" in config: + self._rate_limiters.append(TimeRateLimiter(config["api_rate_limit"])) + del config["api_rate_limit"] + else: + self._rate_limiters.append(None) diff --git a/train_methods/legacy_autogen/completion.py b/train_methods/legacy_autogen/completion.py new file mode 100644 index 0000000..f892bc7 --- /dev/null +++ b/train_methods/legacy_autogen/completion.py @@ -0,0 +1,1151 @@ +import logging +import shutil +import time +from collections import defaultdict +from time import sleep +from typing import Callable, Dict, List, Optional, Union + +import diskcache +import openai +import numpy as np +from flaml import BlendSearch, tune +from flaml.tune.space import is_constant + +from openai.types.completion import Completion as openai_Completion +from openai.types.chat import ChatCompletion +from openai import APIError, APIConnectionError, BadRequestError, Timeout, RateLimitError, AuthenticationError + +from train_methods.legacy_autogen.client import get_key + + +class Completion(openai_Completion): + """(openai<1) A class for OpenAI completion API. + + It also supports: ChatCompletion, Azure OpenAI API. + """ + + # set of models that support chat completion + chat_models = { + "gpt-3.5-turbo", + "gpt-3.5-turbo-0301", # deprecate in Sep + "gpt-3.5-turbo-0613", + "gpt-3.5-turbo-16k", + "gpt-3.5-turbo-16k-0613", + "gpt-35-turbo", + "gpt-35-turbo-16k", + "gpt-4", + "gpt-4-32k", + "gpt-4-32k-0314", # deprecate in Sep + "gpt-4-0314", # deprecate in Sep + "gpt-4-0613", + "gpt-4-32k-0613", + } + + # price per 1k tokens + price1K = { + "text-ada-001": 0.0004, + "text-babbage-001": 0.0005, + "text-curie-001": 0.002, + "code-cushman-001": 0.024, + "code-davinci-002": 0.1, + "text-davinci-002": 0.02, + "text-davinci-003": 0.02, + "gpt-3.5-turbo": (0.0015, 0.002), + "gpt-3.5-turbo-instruct": (0.0015, 0.002), + "gpt-3.5-turbo-0301": (0.0015, 0.002), # deprecate in Sep + "gpt-3.5-turbo-0613": (0.0015, 0.002), + "gpt-3.5-turbo-16k": (0.003, 0.004), + "gpt-3.5-turbo-16k-0613": (0.003, 0.004), + "gpt-35-turbo": (0.0015, 0.002), + "gpt-35-turbo-16k": (0.003, 0.004), + "gpt-35-turbo-instruct": (0.0015, 0.002), + "gpt-4": (0.03, 0.06), + "gpt-4-32k": (0.06, 0.12), + "gpt-4-0314": (0.03, 0.06), # deprecate in Sep + "gpt-4-32k-0314": (0.06, 0.12), # deprecate in Sep + "gpt-4-0613": (0.03, 0.06), + "gpt-4-32k-0613": (0.06, 0.12), + } + + default_search_space = { + "model": tune.choice( + [ + "text-ada-001", + "text-babbage-001", + "text-davinci-003", + "gpt-3.5-turbo", + "gpt-4", + ] + ), + "temperature_or_top_p": tune.choice( + [ + {"temperature": tune.uniform(0, 2)}, + {"top_p": tune.uniform(0, 1)}, + ] + ), + "max_tokens": tune.lograndint(50, 1000), + "n": tune.randint(1, 100), + "prompt": "{prompt}", + } + + cache_seed = 41 + cache_path = f".cache/{cache_seed}" + # retry after this many seconds + retry_wait_time = 10 + # fail a request after hitting RateLimitError for this many seconds + max_retry_period = 120 + # time out for request to openai server + request_timeout = 60 + + openai_completion_class = openai_Completion + _total_cost = 0 + optimization_budget = None + + _history_dict = _count_create = None + + @classmethod + def set_cache(cls, seed: Optional[int] = 41, cache_path_root: Optional[str] = ".cache"): + """Set cache path. + + Args: + seed (int, Optional): The integer identifier for the pseudo seed. + Results corresponding to different seeds will be cached in different places. + cache_path (str, Optional): The root path for the cache. + The complete cache path will be {cache_path_root}/{seed}. + """ + cls.cache_seed = seed + cls.cache_path = f"{cache_path_root}/{seed}" + + @classmethod + def clear_cache(cls, seed: Optional[int] = None, cache_path_root: Optional[str] = ".cache"): + """Clear cache. + + Args: + seed (int, Optional): The integer identifier for the pseudo seed. + If omitted, all caches under cache_path_root will be cleared. + cache_path (str, Optional): The root path for the cache. + The complete cache path will be {cache_path_root}/{seed}. + """ + if seed is None: + shutil.rmtree(cache_path_root, ignore_errors=True) + return + with diskcache.Cache(f"{cache_path_root}/{seed}") as cache: + cache.clear() + + @classmethod + def _book_keeping(cls, config: Dict, response): + """Book keeping for the created completions.""" + if response != -1 and "cost" not in response: + response["cost"] = cls.cost(response) + if cls._history_dict is None: + return + if cls._history_compact: + value = { + "created_at": [], + "cost": [], + "token_count": [], + } + if "messages" in config: + messages = config["messages"] + if len(messages) > 1 and messages[-1]["role"] != "assistant": + existing_key = get_key(messages[:-1]) + value = cls._history_dict.pop(existing_key, value) + key = get_key(messages + [choice["message"] for choice in response["choices"]]) + else: + key = get_key([config["prompt"]] + [choice.get("text") for choice in response["choices"]]) + value["created_at"].append(cls._count_create) + value["cost"].append(response["cost"]) + value["token_count"].append( + { + "model": response["model"], + "prompt_tokens": response["usage"]["prompt_tokens"], + "completion_tokens": response["usage"].get("completion_tokens", 0), + "total_tokens": response["usage"]["total_tokens"], + } + ) + cls._history_dict[key] = value + cls._count_create += 1 + return + cls._history_dict[cls._count_create] = { + "request": config, + "response": response.to_dict_recursive(), + } + cls._count_create += 1 + + @classmethod + def _get_response(cls, config: Dict, raise_on_ratelimit_or_timeout=False, use_cache=True): + """Get the response from the openai api call. + + Try cache first. If not found, call the openai api. If the api call fails, retry after retry_wait_time. + """ + config = config.copy() + key = get_key(config) + if use_cache: + response = cls._cache.get(key, None) + if response is not None and (response != -1 or not raise_on_ratelimit_or_timeout): + # print("using cached response") + cls._book_keeping(config, response) + return response + openai_completion = ( + ChatCompletion + if config["model"].replace("gpt-35-turbo", "gpt-3.5-turbo") in cls.chat_models + or issubclass(cls, ChatCompletion) + else openai_Completion + ) + start_time = time.time() + request_timeout = cls.request_timeout + max_retry_period = config.pop("max_retry_period", cls.max_retry_period) + retry_wait_time = config.pop("retry_wait_time", cls.retry_wait_time) + while True: + try: + if "request_timeout" in config: + response = openai_completion.create(**config) + else: + response = openai_completion.create(request_timeout=request_timeout, **config) + except APIConnectionError: + # transient error + print(f"retrying in {retry_wait_time} seconds...", exc_info=1) + sleep(retry_wait_time) + except APIError as err: + error_code = err and err.json_body and isinstance(err.json_body, dict) and err.json_body.get("error") + if isinstance(error_code, dict): + error_code = error_code.get("code") + if error_code == "content_filter": + raise + # transient error + print(f"retrying in {retry_wait_time} seconds...", exc_info=1) + sleep(retry_wait_time) + except (RateLimitError, Timeout) as err: + time_left = max_retry_period - (time.time() - start_time + retry_wait_time) + if ( + time_left > 0 + and isinstance(err, RateLimitError) + or time_left > request_timeout + and isinstance(err, Timeout) + and "request_timeout" not in config + ): + if isinstance(err, Timeout): + request_timeout <<= 1 + request_timeout = min(request_timeout, time_left) + print(f"retrying in {retry_wait_time} seconds...", exc_info=1) + sleep(retry_wait_time) + elif raise_on_ratelimit_or_timeout: + raise + else: + response = -1 + if use_cache and isinstance(err, Timeout): + cls._cache.set(key, response) + print( + f"Failed to get response from openai api due to getting RateLimitError or Timeout for {max_retry_period} seconds." + ) + return response + except BadRequestError: + if "azure" in config.get("api_type", openai.api_type) and "model" in config: + # azure api uses "engine" instead of "model" + config["engine"] = config.pop("model").replace("gpt-3.5-turbo", "gpt-35-turbo") + else: + raise + else: + if use_cache: + cls._cache.set(key, response) + cls._book_keeping(config, response) + return response + + @classmethod + def _get_max_valid_n(cls, key, max_tokens): + # find the max value in max_valid_n_per_max_tokens + # whose key is equal or larger than max_tokens + return max( + (value for k, value in cls._max_valid_n_per_max_tokens.get(key, {}).items() if k >= max_tokens), + default=1, + ) + + @classmethod + def _get_min_invalid_n(cls, key, max_tokens): + # find the min value in min_invalid_n_per_max_tokens + # whose key is equal or smaller than max_tokens + return min( + (value for k, value in cls._min_invalid_n_per_max_tokens.get(key, {}).items() if k <= max_tokens), + default=None, + ) + + @classmethod + def _get_region_key(cls, config): + # get a key for the valid/invalid region corresponding to the given config + config = cls._pop_subspace(config, always_copy=False) + return ( + config["model"], + config.get("prompt", config.get("messages")), + config.get("stop"), + ) + + @classmethod + def _update_invalid_n(cls, prune, region_key, max_tokens, num_completions): + if prune: + # update invalid n and prune this config + cls._min_invalid_n_per_max_tokens[region_key] = invalid_n = cls._min_invalid_n_per_max_tokens.get( + region_key, {} + ) + invalid_n[max_tokens] = min(num_completions, invalid_n.get(max_tokens, np.inf)) + + @classmethod + def _pop_subspace(cls, config, always_copy=True): + if "subspace" in config: + config = config.copy() + config.update(config.pop("subspace")) + return config.copy() if always_copy else config + + @classmethod + def _get_params_for_create(cls, config: Dict) -> Dict: + """Get the params for the openai api call from a config in the search space.""" + params = cls._pop_subspace(config) + if cls._prompts: + params["prompt"] = cls._prompts[config["prompt"]] + else: + params["messages"] = cls._messages[config["messages"]] + if "stop" in params: + params["stop"] = cls._stops and cls._stops[params["stop"]] + temperature_or_top_p = params.pop("temperature_or_top_p", None) + if temperature_or_top_p: + params.update(temperature_or_top_p) + if cls._config_list and "config_list" not in params: + params["config_list"] = cls._config_list + return params + + @classmethod + def _eval(cls, config: dict, prune=True, eval_only=False): + """Evaluate the given config as the hyperparameter setting for the openai api call. + + Args: + config (dict): Hyperparameter setting for the openai api call. + prune (bool, optional): Whether to enable pruning. Defaults to True. + eval_only (bool, optional): Whether to evaluate only + (ignore the inference budget and do not raise error when a request fails). + Defaults to False. + + Returns: + dict: Evaluation results. + """ + cost = 0 + data = cls.data + params = cls._get_params_for_create(config) + model = params["model"] + data_length = len(data) + price = cls.price1K.get(model) + price_input, price_output = price if isinstance(price, tuple) else (price, price) + inference_budget = getattr(cls, "inference_budget", None) + prune_hp = getattr(cls, "_prune_hp", "n") + metric = cls._metric + config_n = params.get(prune_hp, 1) # default value in OpenAI is 1 + max_tokens = params.get( + "max_tokens", np.inf if model in cls.chat_models or issubclass(cls, ChatCompletion) else 16 + ) + target_output_tokens = None + if not cls.avg_input_tokens: + input_tokens = [None] * data_length + prune = prune and inference_budget and not eval_only + if prune: + region_key = cls._get_region_key(config) + max_valid_n = cls._get_max_valid_n(region_key, max_tokens) + if cls.avg_input_tokens: + target_output_tokens = (inference_budget * 1000 - cls.avg_input_tokens * price_input) / price_output + # max_tokens bounds the maximum tokens + # so using it we can calculate a valid n according to the avg # input tokens + max_valid_n = max( + max_valid_n, + int(target_output_tokens // max_tokens), + ) + if config_n <= max_valid_n: + start_n = config_n + else: + min_invalid_n = cls._get_min_invalid_n(region_key, max_tokens) + if min_invalid_n is not None and config_n >= min_invalid_n: + # prune this config + return { + "inference_cost": np.inf, + metric: np.inf if cls._mode == "min" else -np.inf, + "cost": cost, + } + start_n = max_valid_n + 1 + else: + start_n = config_n + region_key = None + num_completions, previous_num_completions = start_n, 0 + n_tokens_list, result, responses_list = [], {}, [] + while True: # n <= config_n + params[prune_hp] = num_completions - previous_num_completions + data_limit = 1 if prune else data_length + prev_data_limit = 0 + data_early_stop = False # whether data early stop happens for this n + while True: # data_limit <= data_length + # limit the number of data points to avoid rate limit + for i in range(prev_data_limit, data_limit): + data_i = data[i] + response = cls.create(data_i, raise_on_ratelimit_or_timeout=eval_only, **params) + if response == -1: # rate limit/timeout error, treat as invalid + cls._update_invalid_n(prune, region_key, max_tokens, num_completions) + result[metric] = 0 + result["cost"] = cost + return result + # evaluate the quality of the responses + responses = cls.extract_text_or_function_call(response) + usage = response["usage"] + n_input_tokens = usage["prompt_tokens"] + n_output_tokens = usage.get("completion_tokens", 0) + if not cls.avg_input_tokens and not input_tokens[i]: + # store the # input tokens + input_tokens[i] = n_input_tokens + query_cost = response["cost"] + cls._total_cost += query_cost + cost += query_cost + if cls.optimization_budget and cls._total_cost >= cls.optimization_budget and not eval_only: + # limit the total tuning cost + return { + metric: 0, + "total_cost": cls._total_cost, + "cost": cost, + } + if previous_num_completions: + n_tokens_list[i] += n_output_tokens + responses_list[i].extend(responses) + # Assumption 1: assuming requesting n1, n2 responses separately then combining them + # is the same as requesting (n1+n2) responses together + else: + n_tokens_list.append(n_output_tokens) + responses_list.append(responses) + avg_n_tokens = np.mean(n_tokens_list[:data_limit]) + rho = ( + (1 - data_limit / data_length) * (1 + 1 / data_limit) + if data_limit << 1 > data_length + else (1 - (data_limit - 1) / data_length) + ) + # Hoeffding-Serfling bound + ratio = 0.1 * np.sqrt(rho / data_limit) + if target_output_tokens and avg_n_tokens > target_output_tokens * (1 + ratio) and not eval_only: + cls._update_invalid_n(prune, region_key, max_tokens, num_completions) + result[metric] = 0 + result["total_cost"] = cls._total_cost + result["cost"] = cost + return result + if ( + prune + and target_output_tokens + and avg_n_tokens <= target_output_tokens * (1 - ratio) + and (num_completions < config_n or num_completions == config_n and data_limit == data_length) + ): + # update valid n + cls._max_valid_n_per_max_tokens[region_key] = valid_n = cls._max_valid_n_per_max_tokens.get( + region_key, {} + ) + valid_n[max_tokens] = max(num_completions, valid_n.get(max_tokens, 0)) + if num_completions < config_n: + # valid already, skip the rest of the data + data_limit = data_length + data_early_stop = True + break + prev_data_limit = data_limit + if data_limit < data_length: + data_limit = min(data_limit << 1, data_length) + else: + break + # use exponential search to increase n + if num_completions == config_n: + for i in range(data_limit): + data_i = data[i] + responses = responses_list[i] + metrics = cls._eval_func(responses, **data_i) + if result: + for key, value in metrics.items(): + if isinstance(value, (float, int)): + result[key] += value + else: + result = metrics + for key in result.keys(): + if isinstance(result[key], (float, int)): + result[key] /= data_limit + result["total_cost"] = cls._total_cost + result["cost"] = cost + if not cls.avg_input_tokens: + cls.avg_input_tokens = np.mean(input_tokens) + if prune: + target_output_tokens = ( + inference_budget * 1000 - cls.avg_input_tokens * price_input + ) / price_output + result["inference_cost"] = (avg_n_tokens * price_output + cls.avg_input_tokens * price_input) / 1000 + break + else: + if data_early_stop: + previous_num_completions = 0 + n_tokens_list.clear() + responses_list.clear() + else: + previous_num_completions = num_completions + num_completions = min(num_completions << 1, config_n) + return result + + @classmethod + def tune( + cls, + data: List[Dict], + metric: str, + mode: str, + eval_func: Callable, + log_file_name: Optional[str] = None, + inference_budget: Optional[float] = None, + optimization_budget: Optional[float] = None, + num_samples: Optional[int] = 1, + logging_level: Optional[int] = logging.WARNING, + **config, + ): + """Tune the parameters for the OpenAI API call. + + TODO: support parallel tuning with ray or spark. + TODO: support agg_method as in test + + Args: + data (list): The list of data points. + metric (str): The metric to optimize. + mode (str): The optimization mode, "min" or "max. + eval_func (Callable): The evaluation function for responses. + The function should take a list of responses and a data point as input, + and return a dict of metrics. For example, + + ```python + def eval_func(responses, **data): + solution = data["solution"] + success_list = [] + n = len(responses) + for i in range(n): + response = responses[i] + succeed = is_equiv_chain_of_thought(response, solution) + success_list.append(succeed) + return { + "expected_success": 1 - pow(1 - sum(success_list) / n, n), + "success": any(s for s in success_list), + } + ``` + + log_file_name (str, optional): The log file. + inference_budget (float, optional): The inference budget, dollar per instance. + optimization_budget (float, optional): The optimization budget, dollar in total. + num_samples (int, optional): The number of samples to evaluate. + -1 means no hard restriction in the number of trials + and the actual number is decided by optimization_budget. Defaults to 1. + logging_level (optional): logging level. Defaults to logging.WARNING. + **config (dict): The search space to update over the default search. + For prompt, please provide a string/Callable or a list of strings/Callables. + - If prompt is provided for chat models, it will be converted to messages under role "user". + - Do not provide both prompt and messages for chat models, but provide either of them. + - A string template will be used to generate a prompt for each data instance + using `prompt.format(**data)`. + - A callable template will be used to generate a prompt for each data instance + using `prompt(data)`. + For stop, please provide a string, a list of strings, or a list of lists of strings. + For messages (chat models only), please provide a list of messages (for a single chat prefix) + or a list of lists of messages (for multiple choices of chat prefix to choose from). + Each message should be a dict with keys "role" and "content". The value of "content" can be a string/Callable template. + + Returns: + dict: The optimized hyperparameter setting. + tune.ExperimentAnalysis: The tuning results. + """ + print( + "tuning via Completion.tune is deprecated in pyautogen v0.2 and openai>=1. " + "flaml.tune supports tuning more generically." + ) + space = cls.default_search_space.copy() + if config is not None: + space.update(config) + if "messages" in space: + space.pop("prompt", None) + temperature = space.pop("temperature", None) + top_p = space.pop("top_p", None) + if temperature is not None and top_p is None: + space["temperature_or_top_p"] = {"temperature": temperature} + elif temperature is None and top_p is not None: + space["temperature_or_top_p"] = {"top_p": top_p} + elif temperature is not None and top_p is not None: + space.pop("temperature_or_top_p") + space["temperature"] = temperature + space["top_p"] = top_p + print("temperature and top_p are not recommended to vary together.") + cls._max_valid_n_per_max_tokens, cls._min_invalid_n_per_max_tokens = {}, {} + cls.optimization_budget = optimization_budget + cls.inference_budget = inference_budget + cls._prune_hp = "best_of" if space.get("best_of", 1) != 1 else "n" + cls._prompts = space.get("prompt") + if cls._prompts is None: + cls._messages = space.get("messages") + if not all((isinstance(cls._messages, list), isinstance(cls._messages[0], (dict, list)))): + error_msg = "messages must be a list of dicts or a list of lists." + raise AssertionError(error_msg) + if isinstance(cls._messages[0], dict): + cls._messages = [cls._messages] + space["messages"] = tune.choice(list(range(len(cls._messages)))) + else: + if space.get("messages") is not None: + error_msg = "messages and prompt cannot be provided at the same time." + raise AssertionError(error_msg) + if not isinstance(cls._prompts, (str, list)): + error_msg = "prompt must be a string or a list of strings." + raise AssertionError(error_msg) + if isinstance(cls._prompts, str): + cls._prompts = [cls._prompts] + space["prompt"] = tune.choice(list(range(len(cls._prompts)))) + cls._stops = space.get("stop") + if cls._stops: + if not isinstance(cls._stops, (str, list)): + error_msg = "stop must be a string, a list of strings, or a list of lists of strings." + raise AssertionError(error_msg) + if not (isinstance(cls._stops, list) and isinstance(cls._stops[0], list)): + cls._stops = [cls._stops] + space["stop"] = tune.choice(list(range(len(cls._stops)))) + cls._config_list = space.get("config_list") + if cls._config_list is not None: + is_const = is_constant(cls._config_list) + if is_const: + space.pop("config_list") + cls._metric, cls._mode = metric, mode + cls._total_cost = 0 # total optimization cost + cls._eval_func = eval_func + cls.data = data + cls.avg_input_tokens = None + + space_model = space["model"] + if not isinstance(space_model, str) and len(space_model) > 1: + # make a hierarchical search space + subspace = {} + if "max_tokens" in space: + subspace["max_tokens"] = space.pop("max_tokens") + if "temperature_or_top_p" in space: + subspace["temperature_or_top_p"] = space.pop("temperature_or_top_p") + if "best_of" in space: + subspace["best_of"] = space.pop("best_of") + if "n" in space: + subspace["n"] = space.pop("n") + choices = [] + for model in space["model"]: + choices.append({"model": model, **subspace}) + space["subspace"] = tune.choice(choices) + space.pop("model") + # start all the models with the same hp config + search_alg = BlendSearch( + cost_attr="cost", + cost_budget=optimization_budget, + metric=metric, + mode=mode, + space=space, + ) + config0 = search_alg.suggest("t0") + points_to_evaluate = [config0] + for model in space_model: + if model != config0["subspace"]["model"]: + point = config0.copy() + point["subspace"] = point["subspace"].copy() + point["subspace"]["model"] = model + points_to_evaluate.append(point) + search_alg = BlendSearch( + cost_attr="cost", + cost_budget=optimization_budget, + metric=metric, + mode=mode, + space=space, + points_to_evaluate=points_to_evaluate, + ) + else: + search_alg = BlendSearch( + cost_attr="cost", + cost_budget=optimization_budget, + metric=metric, + mode=mode, + space=space, + ) + with diskcache.Cache(cls.cache_path) as cls._cache: + analysis = tune.run( + cls._eval, + search_alg=search_alg, + num_samples=num_samples, + log_file_name=log_file_name, + verbose=3, + ) + config = analysis.best_config + params = cls._get_params_for_create(config) + if cls._config_list is not None and is_const: + params.pop("config_list") + return params, analysis + + @classmethod + def create( + cls, + context: Optional[Dict] = None, + use_cache: Optional[bool] = True, + config_list: Optional[List[Dict]] = None, + filter_func: Optional[Callable[[Dict, Dict], bool]] = None, + raise_on_ratelimit_or_timeout: Optional[bool] = True, + allow_format_str_template: Optional[bool] = False, + **config, + ): + """Make a completion for a given context. + + Args: + context (Dict, Optional): The context to instantiate the prompt. + It needs to contain keys that are used by the prompt template or the filter function. + E.g., `prompt="Complete the following sentence: {prefix}, context={"prefix": "Today I feel"}`. + The actual prompt will be: + "Complete the following sentence: Today I feel". + More examples can be found at [templating](https://microsoft.github.io/autogen/docs/Use-Cases/enhanced_inference#templating). + use_cache (bool, Optional): Whether to use cached responses. + config_list (List, Optional): List of configurations for the completion to try. + The first one that does not raise an error will be used. + Only the differences from the default config need to be provided. + E.g., + + ```python + response = oai.Completion.create( + config_list=[ + { + "model": "gpt-4", + "api_key": os.environ.get("AZURE_OPENAI_API_KEY"), + "api_type": "azure", + "base_url": os.environ.get("AZURE_OPENAI_API_BASE"), + "api_version": "2024-02-01", + }, + { + "model": "gpt-3.5-turbo", + "api_key": os.environ.get("OPENAI_API_KEY"), + "api_type": "openai", + "base_url": "https://api.openai.com/v1", + }, + { + "model": "llama-7B", + "base_url": "http://127.0.0.1:8080", + "api_type": "openai", + } + ], + prompt="Hi", + ) + ``` + + filter_func (Callable, Optional): A function that takes in the context and the response and returns a boolean to indicate whether the response is valid. E.g., + + ```python + def yes_or_no_filter(context, config, response): + return context.get("yes_or_no_choice", False) is False or any( + text in ["Yes.", "No."] for text in oai.Completion.extract_text(response) + ) + ``` + + raise_on_ratelimit_or_timeout (bool, Optional): Whether to raise RateLimitError or Timeout when all configs fail. + When set to False, -1 will be returned when all configs fail. + allow_format_str_template (bool, Optional): Whether to allow format string template in the config. + **config: Configuration for the openai API call. This is used as parameters for calling openai API. + The "prompt" or "messages" parameter can contain a template (str or Callable) which will be instantiated with the context. + Besides the parameters for the openai API call, it can also contain: + - `max_retry_period` (int): the total time (in seconds) allowed for retrying failed requests. + - `retry_wait_time` (int): the time interval to wait (in seconds) before retrying a failed request. + - `cache_seed` (int) for the cache. This is useful when implementing "controlled randomness" for the completion. + + Returns: + Responses from OpenAI API, with additional fields. + - `cost`: the total cost. + When `config_list` is provided, the response will contain a few more fields: + - `config_id`: the index of the config in the config_list that is used to generate the response. + - `pass_filter`: whether the response passes the filter function. None if no filter is provided. + """ + print( + "Completion.create is deprecated in pyautogen v0.2 and openai>=1. " + "The new openai requires initiating a client for inference. " + "Please refer to https://microsoft.github.io/autogen/docs/Use-Cases/enhanced_inference#api-unification" + ) + + # Warn if a config list was provided but was empty + if isinstance(config_list, list) and len(config_list) == 0: + print( + "Completion was provided with a config_list, but the list was empty. Adopting default OpenAI behavior, which reads from the 'model' parameter instead." + ) + + if config_list: + last = len(config_list) - 1 + cost = 0 + for i, each_config in enumerate(config_list): + base_config = config.copy() + base_config["allow_format_str_template"] = allow_format_str_template + base_config.update(each_config) + if i < last and filter_func is None and "max_retry_period" not in base_config: + # max_retry_period = 0 to avoid retrying when no filter is given + base_config["max_retry_period"] = 0 + try: + response = cls.create( + context, + use_cache, + raise_on_ratelimit_or_timeout=i < last or raise_on_ratelimit_or_timeout, + **base_config, + ) + if response == -1: + return response + pass_filter = filter_func is None or filter_func(context=context, response=response) + if pass_filter or i == last: + response["cost"] = cost + response["cost"] + response["config_id"] = i + response["pass_filter"] = pass_filter + return response + cost += response["cost"] + except (AuthenticationError, RateLimitError, Timeout, BadRequestError): + if i == last: + raise + params = cls._construct_params(context, config, allow_format_str_template=allow_format_str_template) + if not use_cache: + return cls._get_response( + params, raise_on_ratelimit_or_timeout=raise_on_ratelimit_or_timeout, use_cache=False + ) + cache_seed = cls.cache_seed + if "cache_seed" in params: + cls.set_cache(params.pop("cache_seed")) + with diskcache.Cache(cls.cache_path) as cls._cache: + cls.set_cache(cache_seed) + return cls._get_response(params, raise_on_ratelimit_or_timeout=raise_on_ratelimit_or_timeout) + + @classmethod + def instantiate( + cls, + template: Union[str, None], + context: Optional[Dict] = None, + allow_format_str_template: Optional[bool] = False, + ): + if not context or template is None: + return template + if isinstance(template, str): + return template.format(**context) if allow_format_str_template else template + return template(context) + + @classmethod + def _construct_params(cls, context, config, prompt=None, messages=None, allow_format_str_template=False): + params = config.copy() + model = config["model"] + prompt = config.get("prompt") if prompt is None else prompt + messages = config.get("messages") if messages is None else messages + # either "prompt" should be in config (for being compatible with non-chat models) + # or "messages" should be in config (for tuning chat models only) + if prompt is None and (model in cls.chat_models or issubclass(cls, ChatCompletion)): + if messages is None: + raise ValueError("Either prompt or messages should be in config for chat models.") + if prompt is None: + params["messages"] = ( + [ + ( + { + **m, + "content": cls.instantiate(m["content"], context, allow_format_str_template), + } + if m.get("content") + else m + ) + for m in messages + ] + if context + else messages + ) + elif model in cls.chat_models or issubclass(cls, ChatCompletion): + # convert prompt to messages + params["messages"] = [ + { + "role": "user", + "content": cls.instantiate(prompt, context, allow_format_str_template), + }, + ] + params.pop("prompt", None) + else: + params["prompt"] = cls.instantiate(prompt, context, allow_format_str_template) + return params + + @classmethod + def test( + cls, + data, + eval_func=None, + use_cache=True, + agg_method="avg", + return_responses_and_per_instance_result=False, + logging_level=logging.WARNING, + **config, + ): + """Evaluate the responses created with the config for the OpenAI API call. + + Args: + data (list): The list of test data points. + eval_func (Callable): The evaluation function for responses per data instance. + The function should take a list of responses and a data point as input, + and return a dict of metrics. You need to either provide a valid callable + eval_func; or do not provide one (set None) but call the test function after + calling the tune function in which a eval_func is provided. + In the latter case we will use the eval_func provided via tune function. + Defaults to None. + + ```python + def eval_func(responses, **data): + solution = data["solution"] + success_list = [] + n = len(responses) + for i in range(n): + response = responses[i] + succeed = is_equiv_chain_of_thought(response, solution) + success_list.append(succeed) + return { + "expected_success": 1 - pow(1 - sum(success_list) / n, n), + "success": any(s for s in success_list), + } + ``` + use_cache (bool, Optional): Whether to use cached responses. Defaults to True. + agg_method (str, Callable or a dict of Callable): Result aggregation method (across + multiple instances) for each of the metrics. Defaults to 'avg'. + An example agg_method in str: + + ```python + agg_method = 'median' + ``` + An example agg_method in a Callable: + + ```python + agg_method = np.median + ``` + + An example agg_method in a dict of Callable: + + ```python + agg_method={'median_success': np.median, 'avg_success': np.mean} + ``` + + return_responses_and_per_instance_result (bool): Whether to also return responses + and per instance results in addition to the aggregated results. + logging_level (optional): logging level. Defaults to logging.WARNING. + **config (dict): parameters passed to the openai api call `create()`. + + Returns: + None when no valid eval_func is provided in either test or tune; + Otherwise, a dict of aggregated results, responses and per instance results if `return_responses_and_per_instance_result` is True; + Otherwise, a dict of aggregated results (responses and per instance results are not returned). + """ + result_agg, responses_list, result_list = {}, [], [] + metric_keys = None + cost = 0 + for i, data_i in enumerate(data): + print(f"evaluating data instance {i}") + response = cls.create(data_i, use_cache, **config) + cost += response["cost"] + # evaluate the quality of the responses + responses = cls.extract_text_or_function_call(response) + if eval_func is not None: + metrics = eval_func(responses, **data_i) + elif hasattr(cls, "_eval_func"): + metrics = cls._eval_func(responses, **data_i) + else: + print( + "Please either provide a valid eval_func or do the test after the tune function is called." + ) + return + if not metric_keys: + metric_keys = [] + for k in metrics.keys(): + try: + _ = float(metrics[k]) + metric_keys.append(k) + except ValueError: + pass + result_list.append(metrics) + if return_responses_and_per_instance_result: + responses_list.append(responses) + if isinstance(agg_method, str): + if agg_method in ["avg", "average"]: + for key in metric_keys: + result_agg[key] = np.mean([r[key] for r in result_list]) + elif agg_method == "median": + for key in metric_keys: + result_agg[key] = np.median([r[key] for r in result_list]) + else: + print( + f"Aggregation method {agg_method} not supported. Please write your own aggregation method as a callable(s)." + ) + elif callable(agg_method): + for key in metric_keys: + result_agg[key] = agg_method([r[key] for r in result_list]) + elif isinstance(agg_method, dict): + for key in metric_keys: + metric_agg_method = agg_method[key] + if not callable(metric_agg_method): + error_msg = "please provide a callable for each metric" + raise AssertionError(error_msg) + result_agg[key] = metric_agg_method([r[key] for r in result_list]) + else: + raise ValueError( + "agg_method needs to be a string ('avg' or 'median'),\ + or a callable, or a dictionary of callable." + ) + # should we also return the result_list and responses_list or not? + if "cost" not in result_agg: + result_agg["cost"] = cost + if "inference_cost" not in result_agg: + result_agg["inference_cost"] = cost / len(data) + if return_responses_and_per_instance_result: + return result_agg, result_list, responses_list + else: + return result_agg + + @classmethod + def cost(cls, response: dict): + """Compute the cost of an API call. + + Args: + response (dict): The response from OpenAI API. + + Returns: + The cost in USD. 0 if the model is not supported. + """ + model = response.get("model") + if model not in cls.price1K: + return 0 + # raise ValueError(f"Unknown model: {model}") + usage = response["usage"] + n_input_tokens = usage["prompt_tokens"] + n_output_tokens = usage.get("completion_tokens", 0) + price1K = cls.price1K[model] + if isinstance(price1K, tuple): + return (price1K[0] * n_input_tokens + price1K[1] * n_output_tokens) / 1000 + return price1K * (n_input_tokens + n_output_tokens) / 1000 + + @classmethod + def extract_text(cls, response: dict) -> List[str]: + """Extract the text from a completion or chat response. + + Args: + response (dict): The response from OpenAI API. + + Returns: + A list of text in the responses. + """ + choices = response["choices"] + if "text" in choices[0]: + return [choice["text"] for choice in choices] + return [choice["message"].get("content", "") for choice in choices] + + @classmethod + def extract_text_or_function_call(cls, response: dict) -> List[str]: + """Extract the text or function calls from a completion or chat response. + + Args: + response (dict): The response from OpenAI API. + + Returns: + A list of text or function calls in the responses. + """ + choices = response["choices"] + if "text" in choices[0]: + return [choice["text"] for choice in choices] + return [ + choice["message"] if "function_call" in choice["message"] else choice["message"].get("content", "") + for choice in choices + ] + + @classmethod + @property + def logged_history(cls) -> Dict: + """Return the book keeping dictionary.""" + return cls._history_dict + + @classmethod + def print_usage_summary(cls) -> Dict: + """Return the usage summary.""" + if cls._history_dict is None: + print("No usage summary available.", flush=True) + + token_count_summary = defaultdict(lambda: {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}) + + if not cls._history_compact: + source = cls._history_dict.values() + total_cost = sum(msg_pair["response"]["cost"] for msg_pair in source) + else: + # source = cls._history_dict["token_count"] + # total_cost = sum(cls._history_dict['cost']) + total_cost = sum(sum(value_list["cost"]) for value_list in cls._history_dict.values()) + source = ( + token_data for value_list in cls._history_dict.values() for token_data in value_list["token_count"] + ) + + for entry in source: + if not cls._history_compact: + model = entry["response"]["model"] + token_data = entry["response"]["usage"] + else: + model = entry["model"] + token_data = entry + + token_count_summary[model]["prompt_tokens"] += token_data["prompt_tokens"] + token_count_summary[model]["completion_tokens"] += token_data["completion_tokens"] + token_count_summary[model]["total_tokens"] += token_data["total_tokens"] + + print(f"Total cost: {total_cost}", flush=True) + for model, counts in token_count_summary.items(): + print( + f"Token count summary for model {model}: prompt_tokens: {counts['prompt_tokens']}, completion_tokens: {counts['completion_tokens']}, total_tokens: {counts['total_tokens']}", + flush=True, + ) + + @classmethod + def start_logging( + cls, history_dict: Optional[Dict] = None, compact: Optional[bool] = True, reset_counter: Optional[bool] = True + ): + """Start book keeping. + + Args: + history_dict (Dict): A dictionary for book keeping. + If no provided, a new one will be created. + compact (bool): Whether to keep the history dictionary compact. + Compact history contains one key per conversation, and the value is a dictionary + like: + ```python + { + "create_at": [0, 1], + "cost": [0.1, 0.2], + } + ``` + where "created_at" is the index of API calls indicating the order of all the calls, + and "cost" is the cost of each call. This example shows that the conversation is based + on two API calls. The compact format is useful for condensing the history of a conversation. + If compact is False, the history dictionary will contain all the API calls: the key + is the index of the API call, and the value is a dictionary like: + ```python + { + "request": request_dict, + "response": response_dict, + } + ``` + where request_dict is the request sent to OpenAI API, and response_dict is the response. + For a conversation containing two API calls, the non-compact history dictionary will be like: + ```python + { + 0: { + "request": request_dict_0, + "response": response_dict_0, + }, + 1: { + "request": request_dict_1, + "response": response_dict_1, + }, + ``` + The first request's messages plus the response is equal to the second request's messages. + For a conversation with many turns, the non-compact history dictionary has a quadratic size + while the compact history dict has a linear size. + reset_counter (bool): whether to reset the counter of the number of API calls. + """ + print( + "logging via Completion.start_logging is deprecated in pyautogen v0.2. " + "logging via OpenAIWrapper will be added back in a future release." + ) + cls._history_dict = {} if history_dict is None else history_dict + cls._history_compact = compact + cls._count_create = 0 if reset_counter or cls._count_create is None else cls._count_create + + @classmethod + def stop_logging(cls): + """End book keeping.""" + cls._history_dict = cls._count_create = None + diff --git a/train_methods/legacy_autogen.py b/train_methods/legacy_autogen/legacy_autogen.py similarity index 89% rename from train_methods/legacy_autogen.py rename to train_methods/legacy_autogen/legacy_autogen.py index 94e2da1..d9b9a61 100644 --- a/train_methods/legacy_autogen.py +++ b/train_methods/legacy_autogen/legacy_autogen.py @@ -3,155 +3,19 @@ """ import json import sys -import logging import random import re from copy import deepcopy -from contextlib import contextmanager -from contextvars import ContextVar from dataclasses import dataclass, field -from typing import Any, Callable, Literal, Union, Protocol, TypedDict, Iterator +from typing import Any, Callable, Literal, Union from termcolor import colored -from train_methods.legacy_autogen_conversable_agent import ConversableAgent, Agent +from train_methods.legacy_autogen.legacy_autogen_conversable_agent import ConversableAgent, Agent +from train_methods.legacy_autogen.stream import IOStream +from train_methods.legacy_autogen.client import ModelClient +from train_methods.legacy_autogen.utils import content_str -logger = logging.getLogger(__name__) - - -class OutputStream(Protocol): - def print(self, *objects: Any, sep: str = " ", end: str = "\n", flush: bool = False) -> None: - """Print data to the output stream. - - Args: - objects (any): The data to print. - sep (str, optional): The separator between objects. Defaults to " ". - end (str, optional): The end of the output. Defaults to "\n". - flush (bool, optional): Whether to flush the output. Defaults to False. - """ - ... # pragma: no cover - - -class InputStream(Protocol): - def input(self, prompt: str = "", *, password: bool = False) -> str: - """Read a line from the input stream. - - Args: - prompt (str, optional): The prompt to display. Defaults to "". - password (bool, optional): Whether to read a password. Defaults to False. - - Returns: - str: The line read from the input stream. - - """ - ... # pragma: no cover - - -class IOStream(InputStream, OutputStream, Protocol): - """A protocol for input/output streams.""" - - # ContextVar must be used in multithreaded or async environments - _default_io_stream: ContextVar["IOStream" | None] = ContextVar("default_iostream", default=None) - _default_io_stream.set(None) - _global_default: "IOStream" | None = None - - @staticmethod - def set_global_default(stream: "IOStream") -> None: - """Set the default input/output stream. - - Args: - stream (IOStream): The input/output stream to set as the default. - """ - IOStream._global_default = stream - - @staticmethod - def get_global_default() -> "IOStream": - """Get the default input/output stream. - - Returns: - IOStream: The default input/output stream. - """ - if IOStream._global_default is None: - raise RuntimeError("No global default IOStream has been set") - return IOStream._global_default - - @staticmethod - def get_default() -> "IOStream": - """Get the default input/output stream. - - Returns: - IOStream: The default input/output stream. - """ - iostream = IOStream._default_io_stream.get() - if iostream is None: - iostream = IOStream.get_global_default() - # Set the default IOStream of the current context (thread/cooroutine) - IOStream.set_default(iostream) - return iostream - - @staticmethod - @contextmanager - def set_default(stream: "IOStream" | None) -> Iterator[None]: - """Set the default input/output stream. - - Args: - stream (IOStream): The input/output stream to set as the default. - """ - global _default_io_stream - try: - token = IOStream._default_io_stream.set(stream) - yield - finally: - IOStream._default_io_stream.reset(token) - - return - -class UserMessageTextContentPart(TypedDict): - type: Literal["text"] - text: str - -class UserMessageImageContentPart(TypedDict): - type: Literal["image_url"] - image_url: dict[Literal["url"], str] - -def content_str(content: str | list[UserMessageTextContentPart| UserMessageImageContentPart] | None) -> str: - """Converts the `content` field of an OpenAI message into a string format. - - This function processes content that may be a string, a list of mixed text and image URLs, or None, - and converts it into a string. Text is directly appended to the result string, while image URLs are - represented by a placeholder image token. If the content is None, an empty string is returned. - - Args: - - content (Union[str, List, None]): The content to be processed. Can be a string, a list of dictionaries representing text and image URLs, or None. - - Returns: - str: A string representation of the input content. Image URLs are replaced with an image token. - - Note: - - The function expects each dictionary in the list to have a "type" key that is either "text" or "image_url". - For "text" type, the "text" key's value is appended to the result. For "image_url", an image token is appended. - - This function is useful for handling content that may include both text and image references, especially - in contexts where images need to be represented as placeholders. - """ - if content is None: - return "" - if isinstance(content, str): - return content - if not isinstance(content, list): - raise TypeError(f"content must be None, str, or list, but got {type(content)}") - - rst = "" - for item in content: - if not isinstance(item, dict): - raise TypeError("Wrong content format: every element should be dict if the content is a list.") - assert "type" in item, "Wrong content format. Missing 'type' key in content's dict." - if item["type"] == "text": - rst += item["text"] - elif item["type"] == "image_url": - rst += "" - else: - raise ValueError(f"Wrong content format: unknown type {item['type']} within the content") - return rst class AgentNameConflict(Exception): def __init__(self, msg: str = "Found multiple agents with the same name.", *args: Any, **kwargs: Any): @@ -171,57 +35,6 @@ def __init__(self, message: str = "The provided agents list does not overlap wit self.message = message super().__init__(self.message) - -class ModelClient(Protocol): - """ - A client class must implement the following methods: - - create must return a response object that implements the ModelClientResponseProtocol - - cost must return the cost of the response - - get_usage must return a dict with the following keys: - - prompt_tokens - - completion_tokens - - total_tokens - - cost - - model - - This class is used to create a client that can be used by OpenAIWrapper. - The response returned from create must adhere to the ModelClientResponseProtocol but can be extended however needed. - The message_retrieval method must be implemented to return a list of str or a list of messages from the response. - """ - - RESPONSE_USAGE_KEYS = ["prompt_tokens", "completion_tokens", "total_tokens", "cost", "model"] - - class ModelClientResponseProtocol(Protocol): - class Choice(Protocol): - class Message(Protocol): - content: str | None - - message: Message - - choices: list[Choice] - model: str - - def create(self, params: dict[str, Any]) -> ModelClientResponseProtocol: ... # pragma: no cover - - def message_retrieval( - self, response: ModelClientResponseProtocol - ) -> list[str] | list[ModelClientResponseProtocol.Choice.Message]: - """ - Retrieve and return a list of strings or a list of Choice.Message from the response. - - NOTE: if a list of Choice.Message is returned, it currently needs to contain the fields of OpenAI's ChatCompletion Message object, - since that is expected for function or tool calling in the rest of the codebase at the moment, unless a custom agent is being used. - """ - ... # pragma: no cover - - def cost(self, response: ModelClientResponseProtocol) -> float: ... # pragma: no cover - - @staticmethod - def get_usage(response: ModelClientResponseProtocol) -> dict: - """Return usage summary of the response using RESPONSE_USAGE_KEYS.""" - ... # pragma: no cover - - @dataclass class GroupChat: @@ -481,7 +294,7 @@ def _prepare_and_select_agents( "Please add more agents to the GroupChat or use direct communication instead." ) elif n_agents == 2 and speaker_selection_method.lower() != "round_robin" and allow_repeat_speaker: - logger.warning( + print( f"GroupChat is underpopulated with {n_agents} agents. " "Consider setting speaker_selection_method to 'round_robin' or allow_repeat_speaker to False, " "or use direct communication, unless repeated speaker is desired." @@ -597,7 +410,7 @@ def _finalize_speaker(self, last_speaker: Agent, final: bool, name: str, agents: if len(mentions) == 1: name = next(iter(mentions)) else: - logger.warning( + print( f"GroupChat select_speaker failed to resolve the next speaker's name. This is because the speaker selection OAI call returned:\n{name}" ) @@ -968,7 +781,7 @@ def _participant_roles(self, agents: list["Agent"] | None = None) -> str: roles = [] for agent in agents: if agent.description.strip() == "": - logger.warning( + print( f"The agent '{agent.name}' has an empty description, and may not work well with GroupChat." ) roles.append(f"{agent.name}: {agent.description}".strip()) @@ -1192,7 +1005,7 @@ def run_chat( raise except NoEligibleSpeaker: # No eligible speaker, terminate the conversation - logger.warning("No eligible speaker found. Terminating the conversation.") + print("No eligible speaker found. Terminating the conversation.") break if reply is None: @@ -1274,7 +1087,7 @@ async def a_run_chat( raise except NoEligibleSpeaker: # No eligible speaker, terminate the conversation - logger.warning("No eligible speaker found. Terminating the conversation.") + print("No eligible speaker found. Terminating the conversation.") break if reply is None: @@ -1553,7 +1366,7 @@ def _remove_termination_string(content: str) -> str: # Check if the last message meets termination (if it has one) if self._is_termination_msg: if self._is_termination_msg(last_message): - logger.warning("WARNING: Last message meets termination criteria and this may terminate the chat.") + print("WARNING: Last message meets termination criteria and this may terminate the chat.") def messages_from_string(self, message_string: str) -> list[dict]: """Reads the saved state of messages in Json format for resume and returns as a messages list @@ -1641,7 +1454,7 @@ def clear_agents_history(self, reply: dict, groupchat: GroupChat) -> str: # preserve last tool call message if clear history called inside of tool response if "tool_responses" in reply and not nr_messages_to_preserve: nr_messages_to_preserve = 1 - logger.warning( + print( "The last tool call message will be saved to prevent errors caused by tool response without tool call." ) # clear history diff --git a/train_methods/legacy_autogen_conversable_agent.py b/train_methods/legacy_autogen/legacy_autogen_conversable_agent.py similarity index 92% rename from train_methods/legacy_autogen_conversable_agent.py rename to train_methods/legacy_autogen/legacy_autogen_conversable_agent.py index 9c47293..173ab26 100644 --- a/train_methods/legacy_autogen_conversable_agent.py +++ b/train_methods/legacy_autogen/legacy_autogen_conversable_agent.py @@ -4,22 +4,26 @@ import functools import inspect import json -import logging import re import warnings from collections import defaultdict -from typing import Any, Callable, Coroutine, Dict, List, Literal, Optional, Tuple, Type, TypeVar, Union, Protocol +from typing import Any, Callable, Coroutine, Literal, Type, Protocol, TypeVar from openai import BadRequestError +from pydantic import BaseModel +from termcolor import colored + +from ..coding.base import CodeExecutor +from ..coding.factory import CodeExecutorFactory +from ..function_utils import get_function_schema, load_basemodels_if_needed, serialize_to_str +from .utils import consolidate_chat_info, gather_usage_summary -from autogen.agentchat.chat import _post_process_carryover_item -from autogen.exception_utils import InvalidCarryOverType, SenderRequired -from .._pydantic import model_dump -from ..cache.cache import AbstractCache -from ..code_utils import ( - PYTHON_VARIANTS, - UNKNOWN, +from train_methods.legacy_autogen.cache import AbstractCache +from train_methods.legacy_autogen.chat import ChatResult, a_initiate_chats, initiate_chats, _post_process_carryover_item +from train_methods.legacy_autogen.client import ModelClient, OpenAIWrapper +from train_methods.legacy_autogen.stream import IOStream +from train_methods.legacy_autogen.utils import ( check_can_use_docker_or_throw, content_str, decide_use_docker, @@ -27,20 +31,31 @@ extract_code, infer_lang, ) -from ..coding.base import CodeExecutor -from ..coding.factory import CodeExecutorFactory -from ..formatting_utils import colored -from ..function_utils import get_function_schema, load_basemodels_if_needed, serialize_to_str -from ..io.base import IOStream -from ..oai.client import ModelClient, OpenAIWrapper -from ..runtime_logging import log_event, log_function_use, log_new_agent, logging_enabled -from .agent import Agent, LLMAgent -from .chat import ChatResult, a_initiate_chats, initiate_chats -from .utils import consolidate_chat_info, gather_usage_summary __all__ = ("ConversableAgent",) -logger = logging.getLogger(__name__) +F = TypeVar("F", bound=Callable[..., Any]) +PYTHON_VARIANTS = ["python", "Python", "py"] +UNKNOWN = "unknown" + +def model_dump(model: BaseModel) -> dict[str, Any]: + return model.model_dump() + +class SenderRequired(Exception): + """Exception raised when the sender is required but not provided.""" + + def __init__(self, message: str = "Sender is required but not provided."): + self.message = message + super().__init__(self.message) + +class InvalidCarryOverType(Exception): + """Exception raised when the carryover type is invalid.""" + + def __init__( + self, message: str = "Carryover should be a string or a list of strings. Not adding carryover to the message." + ): + self.message = message + super().__init__(self.message) class Agent(Protocol): """(In preview) A protocol for Agent. @@ -158,7 +173,6 @@ async def a_generate_reply( str or dict or None: the generated reply. If None, no reply is generated. """ - class LLMAgent(Agent, Protocol): """(In preview) A protocol for an LLM agent.""" @@ -173,11 +187,6 @@ def update_system_message(self, system_message: str) -> None: system_message (str): system message for inference. """ - - -F = TypeVar("F", bound=Callable[..., Any]) - - class ConversableAgent(LLMAgent): """(In preview) A class for generic conversable agents which can be configured as assistant or user proxy. @@ -197,22 +206,22 @@ class ConversableAgent(LLMAgent): DEFAULT_SUMMARY_PROMPT = "Summarize the takeaway from the conversation. Do not add any introductory phrases." DEFAULT_SUMMARY_METHOD = "last_msg" - llm_config: Union[Dict, Literal[False]] + llm_config: dict | Literal[False] def __init__( self, name: str, - system_message: Optional[Union[str, List]] = "You are a helpful AI Assistant.", - is_termination_msg: Optional[Callable[[Dict], bool]] = None, - max_consecutive_auto_reply: Optional[int] = None, + system_message: str | list | None = "You are a helpful AI Assistant.", + is_termination_msg: Callable[[dict], bool] | None = None, + max_consecutive_auto_reply: int | None = None, human_input_mode: Literal["ALWAYS", "NEVER", "TERMINATE"] = "TERMINATE", - function_map: Optional[Dict[str, Callable]] = None, - code_execution_config: Union[Dict, Literal[False]] = False, - llm_config: Optional[Union[Dict, Literal[False]]] = None, - default_auto_reply: Union[str, Dict] = "", - description: Optional[str] = None, - chat_messages: Optional[Dict[Agent, List[Dict]]] = None, - silent: Optional[bool] = None, + function_map: dict[str, Callable] | None = None, + code_execution_config: dict | Literal[False] = False, + llm_config: dict | Literal[False] | None = None, + default_auto_reply: str | dict = "", + description: str | None = None, + chat_messages: dict[Agent, list[dict]] | None = None, + silent: bool | None = None, ): """ Args: @@ -297,10 +306,6 @@ def __init__( self._validate_llm_config(llm_config) - if logging_enabled(): - log_new_agent(self, locals()) - - # Initialize standalone client cache object. self.client_cache = None self.human_input_mode = human_input_mode @@ -381,7 +386,7 @@ def __init__( # Registered hooks are kept in lists, indexed by hookable method, to be called in their order of registration. # New hookable methods should be added to this list as required to support new agent capabilities. - self.hook_lists: Dict[str, List[Union[Callable, Callable[..., Coroutine]]]] = { + self.hook_lists: dict[str, list[Callable | Callable[..., Coroutine]]] = { "process_last_received_message": [], "a_process_last_received_message": [], "process_all_messages_before_reply": [], @@ -405,7 +410,7 @@ def _validate_llm_config(self, llm_config): self.client = None if self.llm_config is False else OpenAIWrapper(**self.llm_config) @staticmethod - def _is_silent(agent: Agent, silent: Optional[bool] = False) -> bool: + def _is_silent(agent: Agent, silent: bool | None = False) -> bool: return agent.silent if agent.silent is not None else silent @property @@ -424,7 +429,7 @@ def description(self, description: str): self._description = description @property - def code_executor(self) -> Optional[CodeExecutor]: + def code_executor(self) -> CodeExecutor | None: """The code executor used by this agent. Returns None if code execution is disabled.""" if not hasattr(self, "_code_executor"): return None @@ -432,11 +437,11 @@ def code_executor(self) -> Optional[CodeExecutor]: def register_reply( self, - trigger: Union[Type[Agent], str, Agent, Callable[[Agent], bool], List], + trigger: Type[Agent] | str | Agent | Callable[[Agent], bool] | list, reply_func: Callable, position: int = 0, - config: Optional[Any] = None, - reset_config: Optional[Callable] = None, + config: Any | None = None, + reset_config: Callable | None = None, *, ignore_async_in_sync_chat: bool = False, remove_other_reply_funcs: bool = False, @@ -469,10 +474,10 @@ def register_reply( ```python def reply_func( recipient: ConversableAgent, - messages: Optional[List[Dict]] = None, - sender: Optional[Agent] = None, - config: Optional[Any] = None, - ) -> Tuple[bool, Union[str, Dict, None]]: + messages: list[dict] | None = None, + sender: Agent | None = None, + config: Any | None = None, + ) -> Tuple[bool, str | dict | None]: ``` position (int): the position of the reply function in the reply function list. The function registered later will be checked earlier by default. @@ -515,8 +520,8 @@ def replace_reply_func(self, old_reply_func: Callable, new_reply_func: Callable) @staticmethod def _get_chats_to_run( - chat_queue: List[Dict[str, Any]], recipient: Agent, messages: Union[str, Callable], sender: Agent, config: Any - ) -> List[Dict[str, Any]]: + chat_queue: list[dict[str, Any]], recipient: Agent, messages: str | Callable, sender: Agent, config: Any + ) -> list[dict[str, Any]]: """A simple chat reply function. This function initiate one or a sequence of chats between the "recipient" and the agents in the chat_queue. @@ -547,8 +552,8 @@ def _get_chats_to_run( @staticmethod def _summary_from_nested_chats( - chat_queue: List[Dict[str, Any]], recipient: Agent, messages: Union[str, Callable], sender: Agent, config: Any - ) -> Tuple[bool, Union[str, None]]: + chat_queue: list[dict[str, Any]], recipient: Agent, messages: str | Callable, sender: Agent, config: Any + ) -> tuple[bool, str | None]: """A simple chat reply function. This function initiate one or a sequence of chats between the "recipient" and the agents in the chat_queue. @@ -556,7 +561,7 @@ def _summary_from_nested_chats( It extracts and returns a summary from the nested chat based on the "summary_method" in each chat in chat_queue. Returns: - Tuple[bool, str]: A tuple where the first element indicates the completion of the chat, and the second element contains the summary of the last chat if any chats were initiated. + tuple[bool, str]: A tuple where the first element indicates the completion of the chat, and the second element contains the summary of the last chat if any chats were initiated. """ chat_to_run = ConversableAgent._get_chats_to_run(chat_queue, recipient, messages, sender, config) if not chat_to_run: @@ -566,8 +571,8 @@ def _summary_from_nested_chats( @staticmethod async def _a_summary_from_nested_chats( - chat_queue: List[Dict[str, Any]], recipient: Agent, messages: Union[str, Callable], sender: Agent, config: Any - ) -> Tuple[bool, Union[str, None]]: + chat_queue: list[dict[str, Any]], recipient: Agent, messages: str | Callable, sender: Agent, config: Any + ) -> tuple[bool, str | None]: """A simple chat reply function. This function initiate one or a sequence of chats between the "recipient" and the agents in the chat_queue. @@ -575,7 +580,7 @@ async def _a_summary_from_nested_chats( It extracts and returns a summary from the nested chat based on the "summary_method" in each chat in chat_queue. Returns: - Tuple[bool, str]: A tuple where the first element indicates the completion of the chat, and the second element contains the summary of the last chat if any chats were initiated. + tuple[bool, str]: A tuple where the first element indicates the completion of the chat, and the second element contains the summary of the last chat if any chats were initiated. """ chat_to_run = ConversableAgent._get_chats_to_run(chat_queue, recipient, messages, sender, config) if not chat_to_run: @@ -586,11 +591,11 @@ async def _a_summary_from_nested_chats( def register_nested_chats( self, - chat_queue: List[Dict[str, Any]], - trigger: Union[Type[Agent], str, Agent, Callable[[Agent], bool], List], - reply_func_from_nested_chats: Union[str, Callable] = "summary_from_nested_chats", + chat_queue: list[dict[str, Any]], + trigger: Type[Agent] | str | Agent | Callable[[Agent], bool] | list, + reply_func_from_nested_chats: str | Callable = "summary_from_nested_chats", position: int = 2, - use_async: Union[bool, None] = None, + use_async: bool | None = None, **kwargs, ) -> None: """Register a nested chat reply function. @@ -602,12 +607,12 @@ def register_nested_chats( Default to "summary_from_nested_chats", which corresponds to a built-in reply function that get summary from the nested chat_queue. ```python def reply_func_from_nested_chats( - chat_queue: List[Dict], + chat_queue: list[dict], recipient: ConversableAgent, - messages: Optional[List[Dict]] = None, - sender: Optional[Agent] = None, - config: Optional[Any] = None, - ) -> Tuple[bool, Union[str, Dict, None]]: + messages: list[dict] | None = None, + sender: Agent | None = None, + config: Any | None = None, + ) -> tuple[bool, str | dict | None]: ``` position (int): Ref to `register_reply` for details. Default to 2. It means we first check the termination and human reply, then check the registered nested chat reply. use_async: Uses a_initiate_chats internally to start nested chats. If the original chat is initiated with a_initiate_chats, you may set this to true so nested chats do not run in sync. @@ -664,7 +669,7 @@ def update_system_message(self, system_message: str) -> None: """ self._oai_system_message[0]["content"] = system_message - def update_max_consecutive_auto_reply(self, value: int, sender: Optional[Agent] = None): + def update_max_consecutive_auto_reply(self, value: int, sender: Agent | None = None): """Update the maximum number of consecutive auto replies. Args: @@ -678,20 +683,20 @@ def update_max_consecutive_auto_reply(self, value: int, sender: Optional[Agent] else: self._max_consecutive_auto_reply_dict[sender] = value - def max_consecutive_auto_reply(self, sender: Optional[Agent] = None) -> int: + def max_consecutive_auto_reply(self, sender: Agent | None = None) -> int: """The maximum number of consecutive auto replies.""" return self._max_consecutive_auto_reply if sender is None else self._max_consecutive_auto_reply_dict[sender] @property - def chat_messages(self) -> Dict[Agent, List[Dict]]: + def chat_messages(self) -> dict[Agent, list[dict]]: """A dictionary of conversations from agent to list of messages.""" return self._oai_messages - def chat_messages_for_summary(self, agent: Agent) -> List[Dict]: + def chat_messages_for_summary(self, agent: Agent) -> list[dict]: """A list of messages as a conversation to summarize.""" return self._oai_messages[agent] - def last_message(self, agent: Optional[Agent] = None) -> Optional[Dict]: + def last_message(self, agent: Agent | None = None) -> dict | None: """The last message exchanged with the agent. Args: @@ -717,14 +722,14 @@ def last_message(self, agent: Optional[Agent] = None) -> Optional[Dict]: return self._oai_messages[agent][-1] @property - def use_docker(self) -> Union[bool, str, None]: + def use_docker(self) -> bool | str | None: """Bool value of whether to use docker to execute the code, or str value of the docker image name to use, or None when code execution is disabled. """ return None if self._code_execution_config is False else self._code_execution_config.get("use_docker") @staticmethod - def _message_to_dict(message: Union[Dict, str]) -> Dict: + def _message_to_dict(message: dict | str) -> dict: """Convert a message to a dictionary. The message can be a string or a dictionary. The string will be put in the "content" field of the new dictionary. @@ -758,7 +763,7 @@ def _assert_valid_name(name): raise ValueError(f"Invalid name: {name}. Name must be less than 64 characters.") return name - def _append_oai_message(self, message: Union[Dict, str], role, conversation_id: Agent, is_sending: bool) -> bool: + def _append_oai_message(self, message: dict | str, role, conversation_id: Agent, is_sending: bool) -> bool: """Append a message to the ChatCompletion conversation. If the message received is a string, it will be put in the "content" field of the new dictionary. @@ -812,8 +817,8 @@ def _append_oai_message(self, message: Union[Dict, str], role, conversation_id: return True def _process_message_before_send( - self, message: Union[Dict, str], recipient: Agent, silent: bool - ) -> Union[Dict, str]: + self, message: dict | str, recipient: Agent, silent: bool + ) -> dict | str: """Process the message before sending it to the recipient.""" hook_list = self.hook_lists["process_message_before_send"] for hook in hook_list: @@ -825,8 +830,8 @@ def _process_message_before_send( return message async def _a_process_message_before_send( - self, message: Union[Dict, str], recipient: Agent, silent: bool - ) -> Union[Dict, str]: + self, message: dict | str, recipient: Agent, silent: bool + ) -> dict | str: """(async) Process the message before sending it to the recipient.""" hook_list = self.hook_lists["a_process_message_before_send"] for hook in hook_list: @@ -837,10 +842,10 @@ async def _a_process_message_before_send( def send( self, - message: Union[Dict, str], + message: dict | str, recipient: Agent, - request_reply: Optional[bool] = None, - silent: Optional[bool] = False, + request_reply: bool | None = None, + silent: bool | None = False, ): """Send a message to another agent. @@ -887,10 +892,10 @@ def send( async def a_send( self, - message: Union[Dict, str], + message: dict | str, recipient: Agent, - request_reply: Optional[bool] = None, - silent: Optional[bool] = False, + request_reply: bool | None = None, + silent: bool | None = False, ): """(async) Send a message to another agent. @@ -937,7 +942,7 @@ async def a_send( "Message can't be converted into a valid ChatCompletion message. Either content or function_call must be provided." ) - def _print_received_message(self, message: Union[Dict, str], sender: Agent): + def _print_received_message(self, message: dict | str, sender: Agent): iostream = IOStream.get_default() # print the message received iostream.print(colored(sender.name, "yellow"), "(to", f"{self.name}):\n", flush=True) @@ -998,11 +1003,9 @@ def _print_received_message(self, message: Union[Dict, str], sender: Agent): iostream.print("\n", "-" * 80, flush=True, sep="") - def _process_received_message(self, message: Union[Dict, str], sender: Agent, silent: bool): + def _process_received_message(self, message: dict | str, sender: Agent, silent: bool): # When the agent receives a message, the role of the message is "user". (If 'role' exists and is 'function', it will remain unchanged.) valid = self._append_oai_message(message, "user", sender, is_sending=False) - if logging_enabled(): - log_event(self, "received_message", message=message, sender=sender.name, valid=valid) if not valid: raise ValueError( @@ -1014,10 +1017,10 @@ def _process_received_message(self, message: Union[Dict, str], sender: Agent, si def receive( self, - message: Union[Dict, str], + message: dict | str, sender: Agent, - request_reply: Optional[bool] = None, - silent: Optional[bool] = False, + request_reply: bool | None = None, + silent: bool | None = False, ): """Receive a message from another agent. @@ -1051,10 +1054,10 @@ def receive( async def a_receive( self, - message: Union[Dict, str], + message: dict | str, sender: Agent, - request_reply: Optional[bool] = None, - silent: Optional[bool] = False, + request_reply: bool | None = None, + silent: bool | None = False, ): """(async) Receive a message from another agent. @@ -1124,12 +1127,12 @@ def initiate_chat( self, recipient: "ConversableAgent", clear_history: bool = True, - silent: Optional[bool] = False, - cache: Optional[AbstractCache] = None, - max_turns: Optional[int] = None, - summary_method: Optional[Union[str, Callable]] = DEFAULT_SUMMARY_METHOD, - summary_args: Optional[dict] = {}, - message: Optional[Union[Dict, str, Callable]] = None, + silent: bool | None = False, + cache: AbstractCache | None = None, + max_turns: int | None = None, + summary_method: str | Callable | None = DEFAULT_SUMMARY_METHOD, + summary_args: dict | None = {}, + message: dict | str | Callable | None = None, **kwargs, ) -> ChatResult: """Initiate a chat with the recipient agent. @@ -1187,7 +1190,7 @@ def my_summary_method( Example of a callable message (returning a string): ```python - def my_message(sender: ConversableAgent, recipient: ConversableAgent, context: dict) -> Union[str, Dict]: + def my_message(sender: ConversableAgent, recipient: ConversableAgent, context: dict) -> str | dict: carryover = context.get("carryover", "") if isinstance(message, list): carryover = carryover[-1] @@ -1198,7 +1201,7 @@ def my_message(sender: ConversableAgent, recipient: ConversableAgent, context: d Example of a callable message (returning a dict): ```python - def my_message(sender: ConversableAgent, recipient: ConversableAgent, context: dict) -> Union[str, Dict]: + def my_message(sender: ConversableAgent, recipient: ConversableAgent, context: dict) -> str | dict: final_msg = {} carryover = context.get("carryover", "") if isinstance(message, list): @@ -1267,12 +1270,12 @@ async def a_initiate_chat( self, recipient: "ConversableAgent", clear_history: bool = True, - silent: Optional[bool] = False, - cache: Optional[AbstractCache] = None, - max_turns: Optional[int] = None, - summary_method: Optional[Union[str, Callable]] = DEFAULT_SUMMARY_METHOD, - summary_args: Optional[dict] = {}, - message: Optional[Union[str, Callable]] = None, + silent: bool | None = False, + cache: AbstractCache | None = None, + max_turns: int | None = None, + summary_method: str | Callable | None = DEFAULT_SUMMARY_METHOD, + summary_args: dict | None = {}, + message: str | Callable | None = None, **kwargs, ) -> ChatResult: """(async) Initiate a chat with the recipient agent. @@ -1333,8 +1336,8 @@ def _summarize_chat( self, summary_method, summary_args, - recipient: Optional[Agent] = None, - cache: Optional[AbstractCache] = None, + recipient: Agent | None = None, + cache: AbstractCache | None = None, ) -> str: """Get a chat summary from an agent participating in a chat. @@ -1417,9 +1420,9 @@ def _reflection_with_llm( self, prompt, messages, - llm_agent: Optional[Agent] = None, - cache: Optional[AbstractCache] = None, - role: Union[str, None] = None, + llm_agent: Agent | None = None, + cache: AbstractCache | None = None, + role: str | None = None, ) -> str: """Get a chat summary using reflection with an llm client based on the conversation history. @@ -1450,15 +1453,15 @@ def _reflection_with_llm( response = self._generate_oai_reply_from_client(llm_client=llm_client, messages=messages, cache=cache) return response - def _check_chat_queue_for_sender(self, chat_queue: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + def _check_chat_queue_for_sender(self, chat_queue: list[dict[str, Any]]) -> list[dict[str, Any]]: """ Check the chat queue and add the "sender" key if it's missing. Args: - chat_queue (List[Dict[str, Any]]): A list of dictionaries containing chat information. + chat_queue (list[dict[str, Any]]): A list of dictionaries containing chat information. Returns: - List[Dict[str, Any]]: A new list of dictionaries with the "sender" key added if it was missing. + list[dict[str, Any]]: A new list of dictionaries with the "sender" key added if it was missing. """ chat_queue_with_sender = [] for chat_info in chat_queue: @@ -1467,11 +1470,11 @@ def _check_chat_queue_for_sender(self, chat_queue: List[Dict[str, Any]]) -> List chat_queue_with_sender.append(chat_info) return chat_queue_with_sender - def initiate_chats(self, chat_queue: List[Dict[str, Any]]) -> List[ChatResult]: + def initiate_chats(self, chat_queue: list[dict[str, Any]]) -> list[ChatResult]: """(Experimental) Initiate chats with multiple agents. Args: - chat_queue (List[Dict]): a list of dictionaries containing the information of the chats. + chat_queue (list[dict]): a list of dictionaries containing the information of the chats. Each dictionary should contain the input arguments for [`initiate_chat`](conversable_agent#initiate_chat) Returns: a list of ChatResult objects corresponding to the finished chats in the chat_queue. @@ -1480,12 +1483,12 @@ def initiate_chats(self, chat_queue: List[Dict[str, Any]]) -> List[ChatResult]: self._finished_chats = initiate_chats(_chat_queue) return self._finished_chats - async def a_initiate_chats(self, chat_queue: List[Dict[str, Any]]) -> Dict[int, ChatResult]: + async def a_initiate_chats(self, chat_queue: list[dict[str, Any]]) -> dict[int, ChatResult]: _chat_queue = self._check_chat_queue_for_sender(chat_queue) self._finished_chats = await a_initiate_chats(_chat_queue) return self._finished_chats - def get_chat_results(self, chat_index: Optional[int] = None) -> Union[List[ChatResult], ChatResult]: + def get_chat_results(self, chat_index: int | None = None) -> list[ChatResult] | ChatResult: """A summary from the finished chats of particular agents.""" if chat_index is not None: return self._finished_chats[chat_index] @@ -1505,21 +1508,21 @@ def reset(self): else: reply_func_tuple["config"] = copy.copy(reply_func_tuple["init_config"]) - def stop_reply_at_receive(self, sender: Optional[Agent] = None): + def stop_reply_at_receive(self, sender: Agent | None = None): """Reset the reply_at_receive of the sender.""" if sender is None: self.reply_at_receive.clear() else: self.reply_at_receive[sender] = False - def reset_consecutive_auto_reply_counter(self, sender: Optional[Agent] = None): + def reset_consecutive_auto_reply_counter(self, sender: Agent | None = None): """Reset the consecutive_auto_reply_counter of the sender.""" if sender is None: self._consecutive_auto_reply_counter.clear() else: self._consecutive_auto_reply_counter[sender] = 0 - def clear_history(self, recipient: Optional[Agent] = None, nr_messages_to_preserve: Optional[int] = None): + def clear_history(self, recipient: Agent | None = None, nr_messages_to_preserve: int | None = None): """Clear the chat history of the agent. Args: @@ -1557,10 +1560,10 @@ def clear_history(self, recipient: Optional[Agent] = None, nr_messages_to_preser def generate_oai_reply( self, - messages: Optional[List[Dict]] = None, - sender: Optional[Agent] = None, - config: Optional[OpenAIWrapper] = None, - ) -> Tuple[bool, Union[str, Dict, None]]: + messages: list[dict] | None = None, + sender: Agent | None = None, + config: OpenAIWrapper | None = None, + ) -> tuple[bool, str | dict | None]: """Generate a reply using autogen.oai.""" client = self.client if config is None else config if client is None: @@ -1572,7 +1575,7 @@ def generate_oai_reply( ) return (False, None) if extracted_response is None else (True, extracted_response) - def _generate_oai_reply_from_client(self, llm_client, messages, cache) -> Union[str, Dict, None]: + def _generate_oai_reply_from_client(self, llm_client, messages, cache) -> str | dict | None: # unroll tool_responses all_messages = [] for message in messages: @@ -1585,7 +1588,6 @@ def _generate_oai_reply_from_client(self, llm_client, messages, cache) -> Union[ else: all_messages.append(message) - # TODO: #1143 handle token limit exceeded error response = llm_client.create( context=messages[-1].pop("context", None), messages=all_messages, cache=cache, agent=self ) @@ -1614,17 +1616,17 @@ def _generate_oai_reply_from_client(self, llm_client, messages, cache) -> Union[ async def a_generate_oai_reply( self, - messages: Optional[List[Dict]] = None, - sender: Optional[Agent] = None, - config: Optional[Any] = None, - ) -> Tuple[bool, Union[str, Dict, None]]: + messages: list[dict] | None = None, + sender: Agent | None = None, + config: Any | None = None, + ) -> tuple[bool, str | dict | None]: """Generate a reply using autogen.oai asynchronously.""" iostream = IOStream.get_default() parent_context = contextvars.copy_context() def _generate_oai_reply( self, iostream: IOStream, *args: Any, **kwargs: Any - ) -> Tuple[bool, Union[str, Dict, None]]: + ) -> tuple[bool, str | dict | None]: with IOStream.set_default(iostream): return self.generate_oai_reply(*args, **kwargs) @@ -1637,9 +1639,9 @@ def _generate_oai_reply( def _generate_code_execution_reply_using_executor( self, - messages: Optional[List[Dict]] = None, - sender: Optional[Agent] = None, - config: Optional[Union[Dict, Literal[False]]] = None, + messages: list[dict] | None = None, + sender: Agent | None = None, + config: dict | Literal[False] | None = None, ): """Generate a reply using code executor.""" iostream = IOStream.get_default() @@ -1706,9 +1708,9 @@ def _generate_code_execution_reply_using_executor( def generate_code_execution_reply( self, - messages: Optional[List[Dict]] = None, - sender: Optional[Agent] = None, - config: Optional[Union[Dict, Literal[False]]] = None, + messages: list[dict] | None = None, + sender: Agent | None = None, + config: dict | Literal[False] | None = None, ): """Generate a reply using code execution.""" code_execution_config = config if config is not None else self._code_execution_config @@ -1758,10 +1760,10 @@ def generate_code_execution_reply( def generate_function_call_reply( self, - messages: Optional[List[Dict]] = None, - sender: Optional[Agent] = None, - config: Optional[Any] = None, - ) -> Tuple[bool, Union[Dict, None]]: + messages: list[dict] | None = None, + sender: Agent | None = None, + config: Any | None = None, + ) -> tuple[bool, dict | None]: """ Generate a reply using function call. @@ -1796,10 +1798,10 @@ def generate_function_call_reply( async def a_generate_function_call_reply( self, - messages: Optional[List[Dict]] = None, - sender: Optional[Agent] = None, - config: Optional[Any] = None, - ) -> Tuple[bool, Union[Dict, None]]: + messages: list[dict] | None = None, + sender: Agent | None = None, + config: Any | None = None, + ) -> tuple[bool, dict | None]: """ Generate a reply using async function call. @@ -1828,10 +1830,10 @@ def _str_for_tool_response(self, tool_response): def generate_tool_calls_reply( self, - messages: Optional[List[Dict]] = None, - sender: Optional[Agent] = None, - config: Optional[Any] = None, - ) -> Tuple[bool, Union[Dict, None]]: + messages: list[dict] | None = None, + sender: Agent | None = None, + config: Any | None = None, + ) -> tuple[bool, dict | None]: """Generate a reply using tool call.""" if config is None: config = self @@ -1895,10 +1897,10 @@ async def _a_execute_tool_call(self, tool_call): async def a_generate_tool_calls_reply( self, - messages: Optional[List[Dict]] = None, - sender: Optional[Agent] = None, - config: Optional[Any] = None, - ) -> Tuple[bool, Union[Dict, None]]: + messages: list[dict] | None = None, + sender: Agent | None = None, + config: Any | None = None, + ) -> tuple[bool, dict | None]: """Generate a reply using async function call.""" if config is None: config = self @@ -1920,10 +1922,10 @@ async def a_generate_tool_calls_reply( def check_termination_and_human_reply( self, - messages: Optional[List[Dict]] = None, - sender: Optional[Agent] = None, - config: Optional[Any] = None, - ) -> Tuple[bool, Union[str, None]]: + messages: list[dict] | None = None, + sender: Agent | None = None, + config: Any | None = None, + ) -> tuple[bool, str | None]: """Check if the conversation should be terminated, and if human reply is provided. This method checks for conditions that require the conversation to be terminated, such as reaching @@ -1933,12 +1935,12 @@ def check_termination_and_human_reply( for the conversation and prints relevant messages based on the human input received. Args: - - messages (Optional[List[Dict]]): A list of message dictionaries, representing the conversation history. - - sender (Optional[Agent]): The agent object representing the sender of the message. - - config (Optional[Any]): Configuration object, defaults to the current instance if not provided. + - messages (list[dict] | None): A list of message dictionaries, representing the conversation history. + - sender (Agent | None): The agent object representing the sender of the message. + - config (Any | None): Configuration object, defaults to the current instance if not provided. Returns: - - Tuple[bool, Union[str, Dict, None]]: A tuple containing a boolean indicating if the conversation + - tuple[bool, str | dict | None]: A tuple containing a boolean indicating if the conversation should be terminated, and a human reply which can be a string, a dictionary, or None. """ iostream = IOStream.get_default() @@ -2033,10 +2035,10 @@ def check_termination_and_human_reply( async def a_check_termination_and_human_reply( self, - messages: Optional[List[Dict]] = None, - sender: Optional[Agent] = None, - config: Optional[Any] = None, - ) -> Tuple[bool, Union[str, None]]: + messages: list[dict] | None = None, + sender: Agent | None = None, + config: Any | None = None, + ) -> tuple[bool, str | None]: """(async) Check if the conversation should be terminated, and if human reply is provided. This method checks for conditions that require the conversation to be terminated, such as reaching @@ -2046,12 +2048,12 @@ async def a_check_termination_and_human_reply( for the conversation and prints relevant messages based on the human input received. Args: - - messages (Optional[List[Dict]]): A list of message dictionaries, representing the conversation history. - - sender (Optional[Agent]): The agent object representing the sender of the message. - - config (Optional[Any]): Configuration object, defaults to the current instance if not provided. + - messages (list[dict] | None): A list of message dictionaries, representing the conversation history. + - sender (Agent | None): The agent object representing the sender of the message. + - config (Any | None): Configuration object, defaults to the current instance if not provided. Returns: - - Tuple[bool, Union[str, Dict, None]]: A tuple containing a boolean indicating if the conversation + - tuple[bool, str | dict | None]: A tuple containing a boolean indicating if the conversation should be terminated, and a human reply which can be a string, a dictionary, or None. """ iostream = IOStream.get_default() @@ -2146,10 +2148,10 @@ async def a_check_termination_and_human_reply( def generate_reply( self, - messages: Optional[List[Dict[str, Any]]] = None, - sender: Optional["Agent"] = None, + messages: list[dict[str, Any]] | None = None, + sender: "Agent" | None = None, **kwargs: Any, - ) -> Union[str, Dict, None]: + ) -> str | dict | None: """Reply based on the conversation history and the sender. Either messages or sender must be provided. @@ -2179,7 +2181,6 @@ def generate_reply( """ if all((messages is None, sender is None)): error_msg = f"Either {messages=} or {sender=} must be provided." - logger.error(error_msg) raise AssertionError(error_msg) if messages is None: @@ -2201,25 +2202,16 @@ def generate_reply( continue if self._match_trigger(reply_func_tuple["trigger"], sender): final, reply = reply_func(self, messages=messages, sender=sender, config=reply_func_tuple["config"]) - if logging_enabled(): - log_event( - self, - "reply_func_executed", - reply_func_module=reply_func.__module__, - reply_func_name=reply_func.__name__, - final=final, - reply=reply, - ) if final: return reply return self._default_auto_reply async def a_generate_reply( self, - messages: Optional[List[Dict[str, Any]]] = None, - sender: Optional["Agent"] = None, + messages: list[dict[str, Any]] | None = None, + sender: "Agent" | None = None, **kwargs: Any, - ) -> Union[str, Dict[str, Any], None]: + ) -> str | dict[str, Any] | None: """(async) Reply based on the conversation history and the sender. Either messages or sender must be provided. @@ -2249,7 +2241,6 @@ async def a_generate_reply( """ if all((messages is None, sender is None)): error_msg = f"Either {messages=} or {sender=} must be provided." - logger.error(error_msg) raise AssertionError(error_msg) if messages is None: @@ -2279,7 +2270,7 @@ async def a_generate_reply( return reply return self._default_auto_reply - def _match_trigger(self, trigger: Union[None, str, type, Agent, Callable, List], sender: Optional[Agent]) -> bool: + def _match_trigger(self, trigger: None | str | type | Agent | Callable | list, sender: Agent | None) -> bool: """Check if the sender matches the trigger. Args: @@ -2435,7 +2426,7 @@ def _format_json_str(jstr): result.append(char) return "".join(result) - def execute_function(self, func_call, verbose: bool = False) -> Tuple[bool, Dict[str, str]]: + def execute_function(self, func_call, verbose: bool = False) -> tuple[bool, dict[str, str]]: """Execute a function call and return the result. Override this function to modify the way to execute function and tool calls. @@ -2547,7 +2538,7 @@ async def a_execute_function(self, func_call): "content": str(content), } - def generate_init_message(self, message: Union[Dict, str, None], **kwargs) -> Union[str, Dict]: + def generate_init_message(self, message: dict | str | None, **kwargs) -> str | dict: """Generate the initial message for the agent. If message is None, input() will be called to get the initial message. @@ -2565,7 +2556,7 @@ def generate_init_message(self, message: Union[Dict, str, None], **kwargs) -> Un return self._handle_carryover(message, kwargs) - def _handle_carryover(self, message: Union[str, Dict], kwargs: dict) -> Union[str, Dict]: + def _handle_carryover(self, message: str | dict, kwargs: dict) -> str | dict: if not kwargs.get("carryover"): return message @@ -2602,7 +2593,7 @@ def _process_carryover(self, content: str, kwargs: dict) -> str: ) return content - def _process_multimodal_carryover(self, content: List[Dict], kwargs: dict) -> List[Dict]: + def _process_multimodal_carryover(self, content: list[dict], kwargs: dict) -> list[dict]: """Prepends the context to a multimodal message.""" # Makes sure there's a carryover if not kwargs.get("carryover"): @@ -2610,7 +2601,7 @@ def _process_multimodal_carryover(self, content: List[Dict], kwargs: dict) -> Li return [{"type": "text", "text": self._process_carryover("", kwargs)}] + content - async def a_generate_init_message(self, message: Union[Dict, str, None], **kwargs) -> Union[str, Dict]: + async def a_generate_init_message(self, message: dict | str | None, **kwargs) -> str | dict: """Generate the initial message for the agent. If message is None, input() will be called to get the initial message. @@ -2625,7 +2616,7 @@ async def a_generate_init_message(self, message: Union[Dict, str, None], **kwarg return self._handle_carryover(message, kwargs) - def register_function(self, function_map: Dict[str, Union[Callable, None]]): + def register_function(self, function_map: dict[str, Callable | None]): """Register functions to the agent. Args: @@ -2640,7 +2631,7 @@ def register_function(self, function_map: Dict[str, Union[Callable, None]]): self._function_map.update(function_map) self._function_map = {k: v for k, v in self._function_map.items() if v is not None} - def update_function_signature(self, func_sig: Union[str, Dict], is_remove: None): + def update_function_signature(self, func_sig: str | dict, is_remove: None): """update a function_signature in the LLM configuration for function_call. Args: @@ -2653,7 +2644,6 @@ def update_function_signature(self, func_sig: Union[str, Dict], is_remove: None) if not isinstance(self.llm_config, dict): error_msg = "To update a function signature, agent must have an llm_config" - logger.error(error_msg) raise AssertionError(error_msg) if is_remove: @@ -2687,7 +2677,7 @@ def update_function_signature(self, func_sig: Union[str, Dict], is_remove: None) self.client = OpenAIWrapper(**self.llm_config) - def update_tool_signature(self, tool_sig: Union[str, Dict], is_remove: None): + def update_tool_signature(self, tool_sig: str | dict, is_remove: None): """update a tool_signature in the LLM configuration for tool_call. Args: @@ -2697,7 +2687,6 @@ def update_tool_signature(self, tool_sig: Union[str, Dict], is_remove: None): if not self.llm_config: error_msg = "To update a tool signature, agent must have an llm_config" - logger.error(error_msg) raise AssertionError(error_msg) if is_remove: @@ -2731,13 +2720,13 @@ def update_tool_signature(self, tool_sig: Union[str, Dict], is_remove: None): self.client = OpenAIWrapper(**self.llm_config) - def can_execute_function(self, name: Union[List[str], str]) -> bool: + def can_execute_function(self, name: list[str] | str) -> bool: """Whether the agent can execute the function.""" names = name if isinstance(name, list) else [name] return all([n in self._function_map for n in names]) @property - def function_map(self) -> Dict[str, Callable]: + def function_map(self) -> dict[str, Callable]: """Return the function map.""" return self._function_map @@ -2757,16 +2746,12 @@ def _wrap_function(self, func: F) -> F: @functools.wraps(func) def _wrapped_func(*args, **kwargs): retval = func(*args, **kwargs) - if logging_enabled(): - log_function_use(self, func, kwargs, retval) return serialize_to_str(retval) @load_basemodels_if_needed @functools.wraps(func) async def _a_wrapped_func(*args, **kwargs): retval = await func(*args, **kwargs) - if logging_enabled(): - log_function_use(self, func, kwargs, retval) return serialize_to_str(retval) wrapped_func = _a_wrapped_func if inspect.iscoroutinefunction(func) else _wrapped_func @@ -2779,8 +2764,8 @@ async def _a_wrapped_func(*args, **kwargs): def register_for_llm( self, *, - name: Optional[str] = None, - description: Optional[str] = None, + name: str | None = None, + description: str | None = None, api_style: Literal["function", "tool"] = "tool", ) -> Callable[[F], F]: """Decorator factory for registering a function to be used by an agent. @@ -2870,7 +2855,7 @@ def _decorator(func: F) -> F: def register_for_execution( self, - name: Optional[str] = None, + name: str | None = None, ) -> Callable[[F], F]: """Decorator factory for registering a function to be executed by an agent. @@ -2954,7 +2939,7 @@ def register_hook(self, hookable_method: str, hook: Callable): hook_list.append(hook) - def process_all_messages_before_reply(self, messages: List[Dict]) -> List[Dict]: + def process_all_messages_before_reply(self, messages: list[dict]) -> list[dict]: """ Calls any registered capability hooks to process all messages, potentially modifying the messages. """ @@ -2971,7 +2956,7 @@ def process_all_messages_before_reply(self, messages: List[Dict]) -> List[Dict]: processed_messages = hook(processed_messages) return processed_messages - async def a_process_all_messages_before_reply(self, messages: List[Dict]) -> List[Dict]: + async def a_process_all_messages_before_reply(self, messages: list[dict]) -> list[dict]: """ Calls any registered capability hooks to process all messages, potentially modifying the messages. """ @@ -2988,7 +2973,7 @@ async def a_process_all_messages_before_reply(self, messages: List[Dict]) -> Lis processed_messages = await hook(processed_messages) return processed_messages - def process_last_received_message(self, messages: List[Dict]) -> List[Dict]: + def process_last_received_message(self, messages: list[dict]) -> list[dict]: """ Calls any registered capability hooks to use and potentially modify the text of the last message, as long as the last message is not a function call or exit command. @@ -3033,7 +3018,7 @@ def process_last_received_message(self, messages: List[Dict]) -> List[Dict]: messages[-1]["content"] = processed_user_content return messages - async def a_process_last_received_message(self, messages: List[Dict]) -> List[Dict]: + async def a_process_last_received_message(self, messages: list[dict]) -> list[dict]: """ Calls any registered capability hooks to use and potentially modify the text of the last message, as long as the last message is not a function call or exit command. @@ -3078,7 +3063,7 @@ async def a_process_last_received_message(self, messages: List[Dict]) -> List[Di messages[-1]["content"] = processed_user_content return messages - def print_usage_summary(self, mode: Union[str, List[str]] = ["actual", "total"]) -> None: + def print_usage_summary(self, mode: str | list[str] = ["actual", "total"]) -> None: """Print the usage summary.""" iostream = IOStream.get_default() @@ -3088,14 +3073,14 @@ def print_usage_summary(self, mode: Union[str, List[str]] = ["actual", "total"]) iostream.print(f"Agent '{self.name}':") self.client.print_usage_summary(mode) - def get_actual_usage(self) -> Union[None, Dict[str, int]]: + def get_actual_usage(self) -> dict[str, int] | None: """Get the actual usage summary.""" if self.client is None: return None else: return self.client.actual_usage_summary - def get_total_usage(self) -> Union[None, Dict[str, int]]: + def get_total_usage(self) -> dict[str, int] | None: """Get the total usage summary.""" if self.client is None: return None @@ -3108,7 +3093,7 @@ def register_function( *, caller: ConversableAgent, executor: ConversableAgent, - name: Optional[str] = None, + name: str | None = None, description: str, ) -> None: """Register a function to be proposed by an agent and executed for an executor. @@ -3127,4 +3112,4 @@ def register_function( """ f = caller.register_for_llm(name=name, description=description)(f) - executor.register_for_execution(name=name)(f) \ No newline at end of file + executor.register_for_execution(name=name)(f) diff --git a/train_methods/legacy_autogen/stream.py b/train_methods/legacy_autogen/stream.py new file mode 100644 index 0000000..7ae567b --- /dev/null +++ b/train_methods/legacy_autogen/stream.py @@ -0,0 +1,91 @@ +from contextlib import contextmanager +from contextvars import ContextVar +from typing import Protocol, Any, Iterator + + +class OutputStream(Protocol): + def print(self, *objects: Any, sep: str = " ", end: str = "\n", flush: bool = False) -> None: + """Print data to the output stream. + + Args: + objects (any): The data to print. + sep (str, optional): The separator between objects. Defaults to " ". + end (str, optional): The end of the output. Defaults to "\n". + flush (bool, optional): Whether to flush the output. Defaults to False. + """ + ... # pragma: no cover + + +class InputStream(Protocol): + def input(self, prompt: str = "", *, password: bool = False) -> str: + """Read a line from the input stream. + + Args: + prompt (str, optional): The prompt to display. Defaults to "". + password (bool, optional): Whether to read a password. Defaults to False. + + Returns: + str: The line read from the input stream. + + """ + ... # pragma: no cover + + +class IOStream(InputStream, OutputStream, Protocol): + """A protocol for input/output streams.""" + + # ContextVar must be used in multithreaded or async environments + _default_io_stream: ContextVar["IOStream" | None] = ContextVar("default_iostream", default=None) + _default_io_stream.set(None) + _global_default: "IOStream" | None = None + + @staticmethod + def set_global_default(stream: "IOStream") -> None: + """Set the default input/output stream. + + Args: + stream (IOStream): The input/output stream to set as the default. + """ + IOStream._global_default = stream + + @staticmethod + def get_global_default() -> "IOStream": + """Get the default input/output stream. + + Returns: + IOStream: The default input/output stream. + """ + if IOStream._global_default is None: + raise RuntimeError("No global default IOStream has been set") + return IOStream._global_default + + @staticmethod + def get_default() -> "IOStream": + """Get the default input/output stream. + + Returns: + IOStream: The default input/output stream. + """ + iostream = IOStream._default_io_stream.get() + if iostream is None: + iostream = IOStream.get_global_default() + # Set the default IOStream of the current context (thread/cooroutine) + IOStream.set_default(iostream) + return iostream + + @staticmethod + @contextmanager + def set_default(stream: "IOStream" | None) -> Iterator[None]: + """Set the default input/output stream. + + Args: + stream (IOStream): The input/output stream to set as the default. + """ + global _default_io_stream + try: + token = IOStream._default_io_stream.set(stream) + yield + finally: + IOStream._default_io_stream.reset(token) + + return diff --git a/train_methods/legacy_autogen/utils.py b/train_methods/legacy_autogen/utils.py new file mode 100644 index 0000000..f3d66a4 --- /dev/null +++ b/train_methods/legacy_autogen/utils.py @@ -0,0 +1,726 @@ +import logging +import os +import pathlib +import re +import string +import subprocess +import sys +import time +import venv +from concurrent.futures import ThreadPoolExecutor, TimeoutError +from hashlib import md5 +from types import SimpleNamespace +from typing import Callable, Literal, TypedDict + +import docker + +from train_methods.legacy_autogen.completion import Completion + +SENTINEL = object() +DEFAULT_MODEL = "gpt-4" +FAST_MODEL = "gpt-3.5-turbo" +CODE_BLOCK_PATTERN = r"```[ \t]*(\w+)?[ \t]*\r?\n(.*?)\r?\n[ \t]*```" +WORKING_DIR = os.path.join(os.path.dirname(os.path.realpath(__file__)), "extensions") +UNKNOWN = "unknown" +TIMEOUT_MSG = "Timeout" +DEFAULT_TIMEOUT = 600 +WIN32 = sys.platform == "win32" +PATH_SEPARATOR = WIN32 and "\\" or "/" +PYTHON_VARIANTS = ["python", "Python", "py"] + +logger = logging.getLogger(__name__) + +class UserMessageTextContentPart(TypedDict): + type: Literal["text"] + text: str + +class UserMessageImageContentPart(TypedDict): + type: Literal["image_url"] + image_url: dict[Literal["url"], str] + + +def content_str(content: str | list[UserMessageTextContentPart | UserMessageImageContentPart] | None) -> str: + """Converts the `content` field of an OpenAI message into a string format. + + This function processes content that may be a string, a list of mixed text and image URLs, or None, + and converts it into a string. Text is directly appended to the result string, while image URLs are + represented by a placeholder image token. If the content is None, an empty string is returned. + + Args: + - content (Union[str, List, None]): The content to be processed. Can be a string, a list of dictionaries + representing text and image URLs, or None. + + Returns: + str: A string representation of the input content. Image URLs are replaced with an image token. + + Note: + - The function expects each dictionary in the list to have a "type" key that is either "text" or "image_url". + For "text" type, the "text" key's value is appended to the result. For "image_url", an image token is appended. + - This function is useful for handling content that may include both text and image references, especially + in contexts where images need to be represented as placeholders. + """ + if content is None: + return "" + if isinstance(content, str): + return content + if not isinstance(content, list): + raise TypeError(f"content must be None, str, or list, but got {type(content)}") + + rst = "" + for item in content: + if not isinstance(item, dict): + raise TypeError("Wrong content format: every element should be dict if the content is a list.") + assert "type" in item, "Wrong content format. Missing 'type' key in content's dict." + if item["type"] == "text": + rst += item["text"] + elif item["type"] == "image_url": + rst += "" + else: + raise ValueError(f"Wrong content format: unknown type {item['type']} within the content") + return rst + + +def infer_lang(code: str) -> str: + """infer the language for the code. + TODO: make it robust. + """ + if code.startswith("python ") or code.startswith("pip") or code.startswith("python3 "): + return "sh" + + # check if code is a valid python code + try: + compile(code, "test", "exec") + return "python" + except SyntaxError: + # not a valid python code + return UNKNOWN + + +# TODO: In the future move, to better support https://spec.commonmark.org/0.30/#fenced-code-blocks +# perhaps by using a full Markdown parser. +def extract_code( + text: str | list, pattern: str = CODE_BLOCK_PATTERN, detect_single_line_code: bool = False +) -> list[tuple[str, str]]: + """Extract code from a text. + + Args: + text (str or List): The content to extract code from. The content can be + a string or a list, as returned by standard GPT or multimodal GPT. + pattern (str, optional): The regular expression pattern for finding the + code block. Defaults to CODE_BLOCK_PATTERN. + detect_single_line_code (bool, optional): Enable the new feature for + extracting single line code. Defaults to False. + + Returns: + list: A list of tuples, each containing the language and the code. + If there is no code block in the input text, the language would be "unknown". + If there is code block but the language is not specified, the language would be "". + """ + text = content_str(text) + if not detect_single_line_code: + match = re.findall(pattern, text, flags=re.DOTALL) + return match if match else [(UNKNOWN, text)] + + # Extract both multi-line and single-line code block, separated by the | operator + # `([^`]+)`: Matches inline code. + code_pattern = re.compile(CODE_BLOCK_PATTERN + r"|`([^`]+)`") + code_blocks = code_pattern.findall(text) + + # Extract the individual code blocks and languages from the matched groups + extracted = [] + for lang, group1, group2 in code_blocks: + if group1: + extracted.append((lang.strip(), group1.strip())) + elif group2: + extracted.append(("", group2.strip())) + + return extracted + + +def generate_code(pattern: str = CODE_BLOCK_PATTERN, **config) -> tuple[str, float]: + """(openai<1) Generate code. + + Args: + pattern (Optional, str): The regular expression pattern for finding the code block. + The default pattern is for finding a code block in a markdown file. + config (Optional, dict): The configuration for the API call. + + Returns: + str: The generated code. + float: The cost of the generation. + """ + response = Completion.create(**config) + return extract_code(Completion.extract_text(response)[0], pattern), response["cost"] + + +_IMPROVE_FUNCTION_CONFIG = { + "prompt": """Improve the function '{func_name}' to achieve the objective '{objective}'. +The current implementation of the function is as follows: +{file_string}""", + "model": DEFAULT_MODEL, + "request_timeout": 600, +} + + +def improve_function(file_name, func_name, objective, **config): + """(openai<1) Improve the function to achieve the objective.""" + params = {**_IMPROVE_FUNCTION_CONFIG, **config} + # read the entire file into a str + with open(file_name, "r") as f: + file_string = f.read() + response = Completion.create( + {"func_name": func_name, "objective": objective, "file_string": file_string}, **params + ) + return Completion.extract_text(response)[0], response["cost"] + + +_IMPROVE_CODE_CONFIG = { + "prompt": """Analyze the code in the following files and return a list of suggestions for improvement{followup}, to achieve the objective of '{objective}'. +{code} +""", + "model": DEFAULT_MODEL, + "request_timeout": 900, +} + + +def improve_code(files, objective, suggest_only=True, **config): + """(openai<1) Improve the code to achieve a given objective. + + Args: + files (list): A list of file names containing the source code. + objective (str): The objective to achieve. + suggest_only (bool): Whether to return only the suggestions or the improved code. + config (Optional, dict): The configuration for the API call. + + Returns: + str: The improved code if suggest_only=False; a list of suggestions if suggest_only=True (default). + float: The cost of the generation. + """ + code = "" + for file_name in files: + # read the entire file into a string + with open(file_name, "r") as f: + file_string = f.read() + code += f"""{file_name}: +{file_string} + +""" + params = {**_IMPROVE_CODE_CONFIG, **config} + followup = "" if suggest_only else " followed by the improved code" + response = Completion.create({"objective": objective, "code": code, "followup": followup}, **params) + return Completion.extract_text(response)[0], response["cost"] + + +def timeout_handler(signum, frame): + raise TimeoutError("Timed out!") + + +def get_powershell_command(): + try: + result = subprocess.run(["powershell", "$PSVersionTable.PSVersion.Major"], capture_output=True, text=True) + if result.returncode == 0: + return "powershell" + except (FileNotFoundError, NotADirectoryError): + # This means that 'powershell' command is not found so now we try looking for 'pwsh' + try: + result = subprocess.run( + ["pwsh", "-Command", "$PSVersionTable.PSVersion.Major"], capture_output=True, text=True + ) + if result.returncode == 0: + return "pwsh" + except FileExistsError as e: + raise FileNotFoundError( + "Neither powershell.exe nor pwsh.exe is present in the system. " + "Please install PowerShell and try again. " + ) from e + except NotADirectoryError as e: + raise NotADirectoryError( + "PowerShell is either not installed or its path is not given " + "properly in the environment variable PATH. Please check the " + "path and try again. " + ) from e + except PermissionError as e: + raise PermissionError("No permission to run powershell.") from e + + +def _cmd(lang: str) -> str: + if lang in PYTHON_VARIANTS: + return "python" + if lang.startswith("python") or lang in ["bash", "sh"]: + return lang + if lang in ["shell"]: + return "sh" + if lang == "javascript": + return "node" + if lang in ["ps1", "pwsh", "powershell"]: + powershell_command = get_powershell_command() + return powershell_command + + raise NotImplementedError(f"{lang} not recognized in code execution") + + +def is_docker_running() -> bool: + """Check if docker is running. + + Returns: + bool: True if docker is running; False otherwise. + """ + try: + client = docker.from_env() + client.ping() + return True + except docker.errors.DockerException: + return False + + +def in_docker_container() -> bool: + """Check if the code is running in a docker container. + + Returns: + bool: True if the code is running in a docker container; False otherwise. + """ + return os.path.exists("/.dockerenv") + + +def decide_use_docker(use_docker: bool | None) -> bool | None: + if use_docker is None: + env_var_use_docker = os.environ.get("AUTOGEN_USE_DOCKER", "True") + + truthy_values = {"1", "true", "yes", "t"} + falsy_values = {"0", "false", "no", "f"} + + # Convert the value to lowercase for case-insensitive comparison + env_var_use_docker_lower = env_var_use_docker.lower() + + # Determine the boolean value based on the environment variable + if env_var_use_docker_lower in truthy_values: + use_docker = True + elif env_var_use_docker_lower in falsy_values: + use_docker = False + elif env_var_use_docker_lower == "none": # Special case for 'None' as a string + use_docker = None + else: + # Raise an error for any unrecognized value + raise ValueError( + f'Invalid value for AUTOGEN_USE_DOCKER: {env_var_use_docker}. Please set AUTOGEN_USE_DOCKER to "1/True/yes", "0/False/no", or "None".' + ) + return use_docker + + +def check_can_use_docker_or_throw(use_docker) -> None: + if use_docker is not None: + inside_docker = in_docker_container() + docker_installed_and_running = is_docker_running() + if use_docker and not inside_docker and not docker_installed_and_running: + raise RuntimeError( + "Code execution is set to be run in docker (default behaviour) but docker is not running.\n" + "The options available are:\n" + "- Make sure docker is running (advised approach for code execution)\n" + '- Set "use_docker": False in code_execution_config\n' + '- Set AUTOGEN_USE_DOCKER to "0/False/no" in your environment variables' + ) + + +def _sanitize_filename_for_docker_tag(filename: str) -> str: + """Convert a filename to a valid docker tag. + See https://docs.docker.com/engine/reference/commandline/tag/ for valid tag + format. + + Args: + filename (str): The filename to be converted. + + Returns: + str: The sanitized Docker tag. + """ + # Replace any character not allowed with an underscore + allowed_chars = set(string.ascii_letters + string.digits + "_.-") + sanitized = "".join(char if char in allowed_chars else "_" for char in filename) + + # Ensure it does not start with a period or a dash + if sanitized.startswith(".") or sanitized.startswith("-"): + sanitized = "_" + sanitized[1:] + + # Truncate if longer than 128 characters + return sanitized[:128] + + +def execute_code( + code: str | None = None, + timeout: int | None = None, + filename: str | None = None, + work_dir: str | None = None, + use_docker: list[str] | str | bool = SENTINEL, + lang: str | None = "python", +) -> tuple[int, str, str | None]: + """Execute code in a docker container. + This function is not tested on MacOS. + + Args: + code (Optional, str): The code to execute. + If None, the code from the file specified by filename will be executed. + Either code or filename must be provided. + timeout (Optional, int): The maximum execution time in seconds. + If None, a default timeout will be used. The default timeout is 600 seconds. On Windows, the timeout is not enforced when use_docker=False. + filename (Optional, str): The file name to save the code or where the code is stored when `code` is None. + If None, a file with a randomly generated name will be created. + The randomly generated file will be deleted after execution. + The file name must be a relative path. Relative paths are relative to the working directory. + work_dir (Optional, str): The working directory for the code execution. + If None, a default working directory will be used. + The default working directory is the "extensions" directory under + "path_to_autogen". + use_docker (list, str or bool): The docker image to use for code execution. + Default is True, which means the code will be executed in a docker container. A default list of images will be used. + If a list or a str of image name(s) is provided, the code will be executed in a docker container + with the first image successfully pulled. + If False, the code will be executed in the current environment. + Expected behaviour: + - If `use_docker` is not set (i.e. left default to True) or is explicitly set to True and the docker package is available, the code will run in a Docker container. + - If `use_docker` is not set (i.e. left default to True) or is explicitly set to True but the Docker package is missing or docker isn't running, an error will be raised. + - If `use_docker` is explicitly set to False, the code will run natively. + If the code is executed in the current environment, + the code must be trusted. + lang (Optional, str): The language of the code. Default is "python". + + Returns: + int: 0 if the code executes successfully. + str: The error message if the code fails to execute; the stdout otherwise. + image: The docker image name after container run when docker is used. + """ + if all((code is None, filename is None)): + error_msg = f"Either {code=} or {filename=} must be provided." + logger.error(error_msg) + raise AssertionError(error_msg) + + running_inside_docker = in_docker_container() + docker_running = is_docker_running() + + # SENTINEL is used to indicate that the user did not explicitly set the argument + if use_docker is SENTINEL: + use_docker = decide_use_docker(use_docker=None) + check_can_use_docker_or_throw(use_docker) + + timeout = timeout or DEFAULT_TIMEOUT + original_filename = filename + if WIN32 and lang in ["sh", "shell"] and (not use_docker): + lang = "ps1" + if filename is None: + code_hash = md5(code.encode()).hexdigest() + # create a file with a automatically generated name + filename = f"tmp_code_{code_hash}.{'py' if lang.startswith('python') else lang}" + if work_dir is None: + work_dir = WORKING_DIR + + filepath = os.path.join(work_dir, filename) + file_dir = os.path.dirname(filepath) + os.makedirs(file_dir, exist_ok=True) + + if code is not None: + with open(filepath, "w", encoding="utf-8") as fout: + fout.write(code) + + if not use_docker or running_inside_docker: + # already running in a docker container + cmd = [ + sys.executable if lang.startswith("python") else _cmd(lang), + f".\\{filename}" if WIN32 else filename, + ] + with ThreadPoolExecutor(max_workers=1) as executor: + future = executor.submit( + subprocess.run, + cmd, + cwd=work_dir, + capture_output=True, + text=True, + ) + try: + result = future.result(timeout=timeout) + except TimeoutError: + if original_filename is None: + os.remove(filepath) + return 1, TIMEOUT_MSG, None + if original_filename is None: + os.remove(filepath) + if result.returncode: + logs = result.stderr + if original_filename is None: + abs_path = str(pathlib.Path(filepath).absolute()) + logs = logs.replace(str(abs_path), "").replace(filename, "") + else: + abs_path = str(pathlib.Path(work_dir).absolute()) + PATH_SEPARATOR + logs = logs.replace(str(abs_path), "") + else: + logs = result.stdout + return result.returncode, logs, None + + # create a docker client + if use_docker and not docker_running: + raise RuntimeError( + "Docker package is missing or docker is not running. Please make sure docker is running or set use_docker=False." + ) + + client = docker.from_env() + + image_list = ( + ["python:3-slim", "python:3", "python:3-windowsservercore"] + if use_docker is True + else [use_docker] if isinstance(use_docker, str) else use_docker + ) + for image in image_list: + # check if the image exists + try: + client.images.get(image) + break + except docker.errors.ImageNotFound: + # pull the image + print("Pulling image", image) + try: + client.images.pull(image) + break + except docker.errors.DockerException: + print("Failed to pull image", image) + # get a randomized str based on current time to wrap the exit code + exit_code_str = f"exitcode{time.time()}" + abs_path = pathlib.Path(work_dir).absolute() + cmd = [ + "sh", + "-c", + f'{_cmd(lang)} "{filename}"; exit_code=$?; echo -n {exit_code_str}; echo -n $exit_code; echo {exit_code_str}', + ] + # create a docker container + container = client.containers.run( + image, + command=cmd, + working_dir="/workspace", + detach=True, + # get absolute path to the working directory + volumes={abs_path: {"bind": "/workspace", "mode": "rw"}}, + ) + start_time = time.time() + while container.status != "exited" and time.time() - start_time < timeout: + # Reload the container object + container.reload() + if container.status != "exited": + container.stop() + container.remove() + if original_filename is None: + os.remove(filepath) + return 1, TIMEOUT_MSG, image + # get the container logs + logs = container.logs().decode("utf-8").rstrip() + # commit the image + tag = _sanitize_filename_for_docker_tag(filename) + container.commit(repository="python", tag=tag) + # remove the container + container.remove() + # check if the code executed successfully + exit_code = container.attrs["State"]["ExitCode"] + if exit_code == 0: + # extract the exit code from the logs + pattern = re.compile(f"{exit_code_str}(\\d+){exit_code_str}") + match = pattern.search(logs) + exit_code = 1 if match is None else int(match.group(1)) + # remove the exit code from the logs + logs = logs if match is None else pattern.sub("", logs) + + if original_filename is None: + os.remove(filepath) + if exit_code: + logs = logs.replace(f"/workspace/{filename if original_filename is None else ''}", "") + # return the exit code, logs and image + return exit_code, logs, f"python:{tag}" + + +_GENERATE_ASSERTIONS_CONFIG = { + "prompt": """Given the signature and docstring, write the exactly same number of assertion(s) for the provided example(s) in the docstring, without assertion messages. + +func signature: +{definition} +assertions:""", + "model": FAST_MODEL, + "max_tokens": 256, + "stop": "\n\n", +} + + +def generate_assertions(definition: str, **config) -> tuple[str, float]: + """(openai<1) Generate assertions for a function. + + Args: + definition (str): The function definition, including the signature and docstr. + config (Optional, dict): The configuration for the API call. + + Returns: + str: The generated assertions. + float: The cost of the generation. + """ + params = {**_GENERATE_ASSERTIONS_CONFIG, **config} + response = Completion.create( + {"definition": definition}, + **params, + ) + assertions = Completion.extract_text(response)[0] + return assertions, response["cost"] + + +def _remove_check(response): + """Remove the check function from the response.""" + # find the position of the check function + pos = response.find("def check(") + if pos == -1: + return response + return response[:pos] + + +def eval_function_completions( + responses: list[str], + definition: str, + test: str | None = None, + entry_point: str | None = None, + assertions: str | Callable[[str], tuple[str, float]] | None = None, + timeout: float | None = 3, + use_docker: bool | None = True, +) -> dict: + """(openai<1) Select a response from a list of responses for the function completion task (using generated assertions), and/or evaluate if the task is successful using a gold test. + + Args: + responses (list): The list of responses. + definition (str): The input definition. + test (Optional, str): The test code. + entry_point (Optional, str): The name of the function. + assertions (Optional, str or Callable): The assertion code which serves as a filter of the responses, or an assertion generator. + When provided, only the responses that pass the assertions will be considered for the actual test (if provided). + timeout (Optional, float): The timeout for executing the code. + + Returns: + dict: The success metrics. + """ + n = len(responses) + if assertions is None: + # no assertion filter + success_list = [] + for i in range(n): + response = _remove_check(responses[i]) + code = ( + f"{response}\n{test}\ncheck({entry_point})" + if response.startswith("def") + else f"{definition}{response}\n{test}\ncheck({entry_point})" + ) + success = execute_code(code, timeout=timeout, use_docker=use_docker)[0] == 0 + success_list.append(success) + return { + "expected_success": 1 - pow(1 - sum(success_list) / n, n), + "success": any(s for s in success_list), + } + if callable(assertions) and n > 1: + # assertion generator + assertions, gen_cost = assertions(definition) + else: + assertions, gen_cost = None, 0 + if n > 1 or test is None: + for i in range(n): + response = responses[i] = _remove_check(responses[i]) + code = ( + f"{response}\n{assertions}" if response.startswith("def") else f"{definition}{response}\n{assertions}" + ) + succeed_assertions = execute_code(code, timeout=timeout, use_docker=use_docker)[0] == 0 + if succeed_assertions: + break + else: + # just test, no need to check assertions + succeed_assertions = False + i, response = 0, responses[0] + if test is None: + # no test code + return { + "index_selected": i, + "succeed_assertions": succeed_assertions, + "gen_cost": gen_cost, + "assertions": assertions, + } + code_test = ( + f"{response}\n{test}\ncheck({entry_point})" + if response.startswith("def") + else f"{definition}{response}\n{test}\ncheck({entry_point})" + ) + success = execute_code(code_test, timeout=timeout, use_docker=use_docker)[0] == 0 + return { + "index_selected": i, + "succeed_assertions": succeed_assertions, + "success": success, + "gen_cost": gen_cost, + "assertions": assertions, + } + + +_FUNC_COMPLETION_PROMPT = "# Python 3{definition}" +_FUNC_COMPLETION_STOP = ["\nclass", "\ndef", "\nif", "\nprint"] +_IMPLEMENT_CONFIGS = [ + {"model": FAST_MODEL, "prompt": _FUNC_COMPLETION_PROMPT, "temperature": 0, "cache_seed": 0}, + {"model": FAST_MODEL, "prompt": _FUNC_COMPLETION_PROMPT, "stop": _FUNC_COMPLETION_STOP, "n": 7, "cache_seed": 0}, + {"model": DEFAULT_MODEL, "prompt": _FUNC_COMPLETION_PROMPT, "temperature": 0, "cache_seed": 1}, + {"model": DEFAULT_MODEL, "prompt": _FUNC_COMPLETION_PROMPT, "stop": _FUNC_COMPLETION_STOP, "n": 2, "cache_seed": 2}, + {"model": DEFAULT_MODEL, "prompt": _FUNC_COMPLETION_PROMPT, "stop": _FUNC_COMPLETION_STOP, "n": 1, "cache_seed": 2}, +] + + +class PassAssertionFilter: + def __init__(self, assertions): + self._assertions = assertions + self.cost = 0 + self.metrics = self.responses = None + + def pass_assertions(self, context, response, **_): + """(openai<1) Check if the response passes the assertions.""" + responses = Completion.extract_text(response) + metrics = eval_function_completions(responses, context["definition"], assertions=self._assertions) + self._assertions = metrics["assertions"] + self.cost += metrics["gen_cost"] + self.metrics = metrics + self.responses = responses + return metrics["succeed_assertions"] + + +def implement( + definition: str, + configs: list[dict] | None = None, + assertions: str | Callable[[str], tuple[str, float]] | None = generate_assertions, +) -> tuple[str, float]: + """(openai<1) Implement a function from a definition. + + Args: + definition (str): The function definition, including the signature and docstr. + configs (list): The list of configurations for completion. + assertions (Optional, str or Callable): The assertion code which serves as a filter of the responses, or an assertion generator. + + Returns: + str: The implementation. + float: The cost of the implementation. + int: The index of the configuration which generates the implementation. + """ + cost = 0 + configs = configs or _IMPLEMENT_CONFIGS + if len(configs) > 1 and callable(assertions): + assertions, cost = assertions(definition) + assertion_filter = PassAssertionFilter(assertions) + response = Completion.create( + {"definition": definition}, config_list=configs, filter_func=assertion_filter.pass_assertions + ) + cost += assertion_filter.cost + response["cost"] + return assertion_filter.responses[assertion_filter.metrics["index_selected"]], cost, response["config_id"] + + +def create_virtual_env(dir_path: str, **env_args) -> SimpleNamespace: + """Creates a python virtual environment and returns the context. + + Args: + dir_path (str): Directory path where the env will be created. + **env_args: Any extra args to pass to the `EnvBuilder` + + Returns: + SimpleNamespace: the virtual env context object.""" + if not env_args: + env_args = {"with_pip": True} + env_builder = venv.EnvBuilder(**env_args) + env_builder.create(dir_path) + return env_builder.ensure_directories(dir_path) diff --git a/train_methods/utils_cogfd.py b/train_methods/utils_cogfd.py index 9da66ae..a9647ba 100644 --- a/train_methods/utils_cogfd.py +++ b/train_methods/utils_cogfd.py @@ -18,8 +18,8 @@ from transformers import RobertaPreTrainedModel, XLMRobertaConfig, XLMRobertaModel from transformers.utils import ModelOutput -from train_methods.legacy_autogen import GroupChat -from train_methods.legacy_autogen_conversable_agent import ConversableAgent +from train_methods.legacy_autogen.legacy_autogen import GroupChat +from train_methods.legacy_autogen.legacy_autogen_conversable_agent import ConversableAgent @dataclass class TransformationModelOutput(ModelOutput): From 18bd110f4378584406263492a6e6b21ab941b2dd Mon Sep 17 00:00:00 2001 From: fmuuly <64724985+fmp453@users.noreply.github.com> Date: Mon, 27 Oct 2025 01:33:34 +0900 Subject: [PATCH 04/25] maybe completed autogen --- requirements.txt | 1 + train_methods/legacy_autogen/coding.py | 848 ++++++++++++++++++ .../legacy_autogen/legacy_autogen.py | 2 +- .../legacy_autogen_conversable_agent.py | 154 +++- train_methods/legacy_autogen/utils.py | 344 ++++++- train_methods/utils_cogfd.py | 12 +- 6 files changed, 1342 insertions(+), 19 deletions(-) create mode 100644 train_methods/legacy_autogen/coding.py diff --git a/requirements.txt b/requirements.txt index a598a6e..1ce58e0 100644 --- a/requirements.txt +++ b/requirements.txt @@ -16,6 +16,7 @@ azure-cosmos==4.14.0 azure-identity==1.25.1 docker==7.1.0 flaml==2.3.6 +jupyter-client==8.6.3 open_clip_torch==2.29.0 bitsandbytes==0.44.1 diff --git a/train_methods/legacy_autogen/coding.py b/train_methods/legacy_autogen/coding.py new file mode 100644 index 0000000..fb82d42 --- /dev/null +++ b/train_methods/legacy_autogen/coding.py @@ -0,0 +1,848 @@ +import base64 +import json +import os +import re +import inspect +import importlib +import subprocess +import sys +import uuid +import warnings +from dataclasses import dataclass, field +from pathlib import Path +from hashlib import md5 +from string import Template +from importlib.abc import SourceLoader +from queue import Empty +from textwrap import indent, dedent +from types import SimpleNamespace +from typing import Protocol, Literal, TypedDict, Mapping, Any, ClassVar, Callable, Generic, TypeVar +from typing_extensions import ParamSpec + +from jupyter_client import KernelManager +from jupyter_client.kernelspec import KernelSpecManager +from pydantic import BaseModel, Field, field_validator + + +from pydantic import BaseModel, Field + +from train_methods.legacy_autogen.utils import UserMessageImageContentPart, UserMessageTextContentPart, content_str, infer_lang, UNKNOWN, CODE_BLOCK_PATTERN, PYTHON_VARIANTS, WIN32, TIMEOUT_MSG, _cmd + +A = ParamSpec("A") +T = TypeVar("T") +P = ParamSpec("P") + +class CodeBlock(BaseModel): + """(Experimental) A class that represents a code block.""" + + code: str = Field(description="The code to execute.") + + language: str = Field(description="The language of the code.") + +class CodeResult(BaseModel): + """(Experimental) A class that represents the result of a code execution.""" + + exit_code: int = Field(description="The exit code of the code execution.") + + output: str = Field(description="The output of the code execution.") + + +class CodeExtractor(Protocol): + """(Experimental) A code extractor class that extracts code blocks from a message.""" + + def extract_code_blocks( + self, message: str | list[UserMessageTextContentPart | UserMessageImageContentPart] | None + ) -> list[CodeBlock]: + """(Experimental) Extract code blocks from a message. + + Args: + message (str): The message to extract code blocks from. + + Returns: + List[CodeBlock]: The extracted code blocks. + """ + ... # pragma: no cover + +class CodeExecutor(Protocol): + """(Experimental) A code executor class that executes code blocks and returns the result.""" + + @property + def code_extractor(self) -> CodeExtractor: + """(Experimental) The code extractor used by this code executor.""" + ... # pragma: no cover + + def execute_code_blocks(self, code_blocks: list[CodeBlock]) -> CodeResult: + """(Experimental) Execute code blocks and return the result. + + This method should be implemented by the code executor. + + Args: + code_blocks (List[CodeBlock]): The code blocks to execute. + + Returns: + CodeResult: The result of the code execution. + """ + ... # pragma: no cover + + def restart(self) -> None: + """(Experimental) Restart the code executor. + + This method should be implemented by the code executor. + + This method is called when the agent is reset. + """ + ... # pragma: no cover + + +class IPythonCodeResult(CodeResult): + """(Experimental) A code result class for IPython code executor.""" + + output_files: list[str] = Field( + default_factory=list, + description="The list of files that the executed code blocks generated.", + ) + +class MarkdownCodeExtractor(CodeExtractor): + """(Experimental) A class that extracts code blocks from a message using Markdown syntax.""" + + def extract_code_blocks( + self, message: str | list[UserMessageTextContentPart | UserMessageImageContentPart] | None + ) -> list[CodeBlock]: + """(Experimental) Extract code blocks from a message. If no code blocks are found, + return an empty list. + + Args: + message (str): The message to extract code blocks from. + + Returns: + List[CodeBlock]: The extracted code blocks or an empty list. + """ + + text = content_str(message) + match = re.findall(CODE_BLOCK_PATTERN, text, flags=re.DOTALL) + if not match: + return [] + code_blocks = [] + for lang, code in match: + if lang == "": + lang = infer_lang(code) + if lang == UNKNOWN: + lang = "" + code_blocks.append(CodeBlock(code=code, language=lang)) + return code_blocks + + +CodeExecutionConfig = TypedDict( + "CodeExecutionConfig", + { + "executor": Literal["ipython-embedded", "commandline-local"] | CodeExecutor, + "last_n_messages": int | Literal["auto"], + "timeout": int, + "use_docker": bool | str | list[str], + "work_dir": str, + "ipython-embedded": Mapping[str, Any], + "commandline-local": Mapping[str, Any], + }, + total=False, +) + + +class EmbeddedIPythonCodeExecutor(BaseModel): + """(Experimental) A code executor class that executes code statefully using an embedded + IPython kernel managed by this class. + + **This will execute LLM generated code on the local machine.** + + Each execution is stateful and can access variables created from previous + executions in the same session. The kernel must be installed before using + this class. The kernel can be installed using the following command: + `python -m ipykernel install --user --name {kernel_name}` + where `kernel_name` is the name of the kernel to install. + + Args: + timeout (int): The timeout for code execution, by default 60. + kernel_name (str): The kernel name to use. Make sure it is installed. + By default, it is "python3". + output_dir (str): The directory to save output files, by default ".". + """ + + timeout: int = Field(default=60, ge=1, description="The timeout for code execution.") + kernel_name: str = Field(default="python3", description="The kernel name to use. Make sure it is installed.") + output_dir: str = Field(default=".", description="The directory to save output files.") + + @field_validator("output_dir") + @classmethod + def _output_dir_must_exist(cls, value: str) -> str: + if not os.path.exists(value): + raise ValueError(f"Output directory {value} does not exist.") + return value + + def __init__(self, **kwargs: Any): + super().__init__(**kwargs) + # Check if the kernel is installed. + if self.kernel_name not in KernelSpecManager().find_kernel_specs(): + raise ValueError( + f"Kernel {self.kernel_name} is not installed. " + "Please first install it with " + f"`python -m ipykernel install --user --name {self.kernel_name}`." + ) + self._kernel_manager = KernelManager(kernel_name=self.kernel_name) + self._kernel_manager.start_kernel() + self._kernel_client = self._kernel_manager.client() + self._kernel_client.start_channels() + self._timeout = self.timeout + self._kernel_name = self.kernel_name + self._output_dir = Path(self.output_dir) + + @property + def code_extractor(self) -> CodeExtractor: + """(Experimental) Export a code extractor that can be used by an agent.""" + return MarkdownCodeExtractor() + + def execute_code_blocks(self, code_blocks: list[CodeBlock]) -> IPythonCodeResult: + """(Experimental) Execute a list of code blocks and return the result. + + This method executes a list of code blocks as cells in an IPython kernel + managed by this class. + See: https://jupyter-client.readthedocs.io/en/stable/messaging.html + for the message protocol. + + Args: + code_blocks (List[CodeBlock]): A list of code blocks to execute. + + Returns: + IPythonCodeResult: The result of the code execution. + """ + self._kernel_client.wait_for_ready() + outputs = [] + output_files = [] + for code_block in code_blocks: + code = self._process_code(code_block.code) + self._kernel_client.execute(code, store_history=True) + while True: + try: + msg = self._kernel_client.get_iopub_msg(timeout=self._timeout) + msg_type = msg["msg_type"] + content = msg["content"] + if msg_type in ["execute_result", "display_data"]: + for data_type, data in content["data"].items(): + if data_type == "text/plain": + # Output is a text. + outputs.append(data) + elif data_type.startswith("image/"): + # Output is an image. + path = self._save_image(data) + outputs.append(f"Image data saved to {path}") + output_files.append(path) + elif data_type == "text/html": + # Output is an html. + path = self._save_html(data) + outputs.append(f"HTML data saved to {path}") + output_files.append(path) + else: + # Output raw data. + outputs.append(json.dumps(data)) + elif msg_type == "stream": + # Output is a text. + outputs.append(content["text"]) + elif msg_type == "error": + # Output is an error. + return IPythonCodeResult( + exit_code=1, + output=f"ERROR: {content['ename']}: {content['evalue']}\n{content['traceback']}", + ) + if msg_type == "status" and content["execution_state"] == "idle": + break + # handle time outs. + except Empty: + return IPythonCodeResult( + exit_code=1, + output=f"ERROR: Timeout waiting for output from code block: {code_block.code}", + ) + # We return the full output. + return IPythonCodeResult( + exit_code=0, output="\n".join([str(output) for output in outputs]), output_files=output_files + ) + + def restart(self) -> None: + """(Experimental) Restart a new session.""" + self._kernel_client.stop_channels() + self._kernel_manager.shutdown_kernel() + self._kernel_manager = KernelManager(kernel_name=self.kernel_name) + self._kernel_manager.start_kernel() + self._kernel_client = self._kernel_manager.client() + self._kernel_client.start_channels() + + def _save_image(self, image_data_base64: str) -> str: + """Save image data to a file.""" + image_data = base64.b64decode(image_data_base64) + # Randomly generate a filename. + filename = f"{uuid.uuid4().hex}.png" + path = os.path.join(self.output_dir, filename) + with open(path, "wb") as f: + f.write(image_data) + return os.path.abspath(path) + + def _save_html(self, html_data: str) -> str: + """Save html data to a file.""" + # Randomly generate a filename. + filename = f"{uuid.uuid4().hex}.html" + path = os.path.join(self.output_dir, filename) + with open(path, "w") as f: + f.write(html_data) + return os.path.abspath(path) + + def _process_code(self, code: str) -> str: + """Process code before execution.""" + # Find lines that start with `! pip install` and make sure "-qqq" flag is added. + lines = code.split("\n") + for i, line in enumerate(lines): + # use regex to find lines that start with `! pip install` or `!pip install`. + match = re.search(r"^! ?pip install", line) + if match is not None: + if "-qqq" not in line: + lines[i] = line.replace(match.group(0), match.group(0) + " -qqq") + return "\n".join(lines) + +class CommandLineCodeResult(CodeResult): + """(Experimental) A code result class for command line code executor.""" + + code_file: str | None = Field( + default=None, + description="The file that the executed code block was saved to.", + ) + + +@dataclass +class Alias: + name: str + alias: str + + +@dataclass +class ImportFromModule: + module: str + imports: list[str | Alias] + +Import = str | ImportFromModule | Alias + + +class _StringLoader(SourceLoader): + def __init__(self, data: str): + self.data = data + + def get_source(self, fullname: str) -> str: + return self.data + + def get_data(self, path: str) -> bytes: + return self.data.encode("utf-8") + + def get_filename(self, fullname: str) -> str: + return "/" + fullname + ".py" + +@dataclass +class FunctionWithRequirementsStr: + func: str + _compiled_func: Callable[..., Any] + _func_name: str + python_packages: list[str] = field(default_factory=list) + global_imports: list[Import] = field(default_factory=list) + + def __init__(self, func: str, python_packages: list[str] = [], global_imports: list[Import] = []): + self.func = func + self.python_packages = python_packages + self.global_imports = global_imports + + module_name = "func_module" + loader = _StringLoader(func) + spec = importlib.util.spec_from_loader(module_name, loader) + if spec is None: + raise ValueError("Could not create spec") + module = importlib.util.module_from_spec(spec) + if spec.loader is None: + raise ValueError("Could not create loader") + + try: + spec.loader.exec_module(module) + except Exception as e: + raise ValueError(f"Could not compile function: {e}") from e + + functions = inspect.getmembers(module, inspect.isfunction) + if len(functions) != 1: + raise ValueError("The string must contain exactly one function") + + self._func_name, self._compiled_func = functions[0] + + def __call__(self, *args: Any, **kwargs: Any) -> None: + raise NotImplementedError("String based function with requirement objects are not directly callable") + + +@dataclass +class FunctionWithRequirements(Generic[T, P]): + func: Callable[P, T] + python_packages: list[str] = field(default_factory=list) + global_imports: list[Import] = field(default_factory=list) + + @classmethod + def from_callable( + cls, func: Callable[P, T], python_packages: list[str] = [], global_imports: list[Import] = [] + ) -> "FunctionWithRequirements"[T, P]: + return cls(python_packages=python_packages, global_imports=global_imports, func=func) + + @staticmethod + def from_str( + func: str, python_packages: list[str] = [], global_imports: list[Import] = [] + ) -> FunctionWithRequirementsStr: + return FunctionWithRequirementsStr(func=func, python_packages=python_packages, global_imports=global_imports) + + # Type this based on F + def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T: + return self.func(*args, **kwargs) + +def to_stub(func: Callable[..., Any] | FunctionWithRequirementsStr) -> str: + """Generate a stub for a function as a string + + Args: + func (Callable[..., Any]): The function to generate a stub for + + Returns: + str: The stub for the function + """ + if isinstance(func, FunctionWithRequirementsStr): + return to_stub(func._compiled_func) + + content = f"def {func.__name__}{inspect.signature(func)}:\n" + docstring = func.__doc__ + + if docstring: + docstring = dedent(docstring) + docstring = '"""' + docstring + '"""' + docstring = indent(docstring, " ") + content += docstring + "\n" + + content += " ..." + return content + +def _to_code(func: FunctionWithRequirements[T, P] | Callable[P, T] | FunctionWithRequirementsStr) -> str: + if isinstance(func, FunctionWithRequirementsStr): + return func.func + + code = inspect.getsource(func) + # Strip the decorator + if code.startswith("@"): + code = code[code.index("\n") + 1 :] + return code + +def _import_to_str(im: Import) -> str: + if isinstance(im, str): + return f"import {im}" + elif isinstance(im, Alias): + return f"import {im.name} as {im.alias}" + else: + + def to_str(i: str | Alias) -> str: + if isinstance(i, str): + return i + else: + return f"{i.name} as {i.alias}" + + imports = ", ".join(map(to_str, im.imports)) + return f"from {im.module} import {imports}" + +def _build_python_functions_file( + funcs: list[FunctionWithRequirements[Any, P] | Callable[..., Any] | FunctionWithRequirementsStr] +) -> str: + # First collect all global imports + global_imports: set[str] = set() + for func in funcs: + if isinstance(func, (FunctionWithRequirements, FunctionWithRequirementsStr)): + global_imports.update(map(_import_to_str, func.global_imports)) + + content = "\n".join(global_imports) + "\n\n" + + for func in funcs: + content += _to_code(func) + "\n\n" + + return content + +filename_patterns = [ + re.compile(r"^", re.DOTALL), + re.compile(r"^/\* (filename:)?(.+?) \*/", re.DOTALL), + re.compile(r"^// (filename:)?(.+?)$", re.DOTALL), + re.compile(r"^# (filename:)?(.+?)$", re.DOTALL), +] + +def _get_file_name_from_content(code: str, workspace_path: Path) -> str | None: + first_line = code.split("\n")[0].strip() + # TODO - support other languages + for pattern in filename_patterns: + matches = pattern.match(first_line) + if matches is not None: + filename = matches.group(2).strip() + + # Handle relative paths in the filename + path = Path(filename) + if not path.is_absolute(): + path = workspace_path / path + path = path.resolve() + # Throws an error if the file is not in the workspace + relative = path.relative_to(workspace_path.resolve()) + return str(relative) + return None + +def silence_pip(code: str, lang: str) -> str: + """Apply -qqq flag to pip install commands.""" + if lang == "python": + regex = r"^! ?pip install" + elif lang in ["bash", "shell", "sh", "pwsh", "powershell", "ps1"]: + regex = r"^pip install" + else: + return code + + # Find lines that start with pip install and make sure "-qqq" flag is added. + lines = code.split("\n") + for i, line in enumerate(lines): + # use regex to find lines that start with pip install. + match = re.search(regex, line) + if match is not None: + if "-qqq" not in line: + lines[i] = line.replace(match.group(0), match.group(0) + " -qqq") + return "\n".join(lines) + +class LocalCommandLineCodeExecutor(CodeExecutor): + SUPPORTED_LANGUAGES: ClassVar[list[str]] = [ + "bash", + "shell", + "sh", + "pwsh", + "powershell", + "ps1", + "python", + "javascript", + "html", + "css", + ] + DEFAULT_EXECUTION_POLICY: ClassVar[dict[str, bool]] = { + "bash": True, + "shell": True, + "sh": True, + "pwsh": True, + "powershell": True, + "ps1": True, + "python": True, + "javascript": False, + "html": False, + "css": False, + } + + FUNCTION_PROMPT_TEMPLATE: ClassVar[ + str + ] = """You have access to the following user defined functions. They can be accessed from the module called `$module_name` by their function names. + +For example, if there was a function called `foo` you could import it by writing `from $module_name import foo` + +$functions""" + + def __init__( + self, + timeout: int = 60, + virtual_env_context: SimpleNamespace | None = None, + work_dir: Path | str = Path("."), + functions: list[FunctionWithRequirements[Any, A] | Callable[..., Any] | FunctionWithRequirementsStr] = [], + functions_module: str = "functions", + execution_policies: dict[str, bool] | None = None, + ): + """(Experimental) A code executor class that executes or saves LLM generated code a local command line + environment. + + **This will execute or save LLM generated code on the local machine.** + + Each code block is saved as a file in the working directory. Depending on the execution policy, + the code may be executed in a separate process. + The code blocks are executed or save in the order they are received. + Command line code is sanitized against a list of dangerous commands to prevent self-destructive commands from being executed, + which could potentially affect the user's environment. Supported languages include Python, shell scripts (bash, shell, sh), + PowerShell (pwsh, powershell, ps1), HTML, CSS, and JavaScript. + Execution policies determine whether each language's code blocks are executed or saved only. + + ## Execution with a Python virtual environment + A python virtual env can be used to execute code and install dependencies. This has the added benefit of not polluting the + base environment with unwanted modules. + ```python + from autogen.code_utils import create_virtual_env + from autogen.coding import LocalCommandLineCodeExecutor + + venv_dir = ".venv" + venv_context = create_virtual_env(venv_dir) + + executor = LocalCommandLineCodeExecutor(virtual_env_context=venv_context) + ``` + + Args: + timeout (int): The timeout for code execution, default is 60 seconds. + virtual_env_context (Optional[SimpleNamespace]): The virtual environment context to use. + work_dir (Union[Path, str]): The working directory for code execution, defaults to the current directory. + functions (List[Union[FunctionWithRequirements[Any, A], Callable[..., Any], FunctionWithRequirementsStr]]): A list of callable functions available to the executor. + functions_module (str): The module name under which functions are accessible. + execution_policies (Optional[Dict[str, bool]]): A dictionary mapping languages to execution policies (True for execution, False for saving only). Defaults to class-wide DEFAULT_EXECUTION_POLICY. + """ + + if timeout < 1: + raise ValueError("Timeout must be greater than or equal to 1.") + + if isinstance(work_dir, str): + work_dir = Path(work_dir) + + if not functions_module.isidentifier(): + raise ValueError("Module name must be a valid Python identifier") + + self._functions_module = functions_module + + work_dir.mkdir(exist_ok=True) + + self._timeout = timeout + self._work_dir: Path = work_dir + self._virtual_env_context: SimpleNamespace | None = virtual_env_context + + self._functions = functions + # Setup could take some time so we intentionally wait for the first code block to do it. + if len(functions) > 0: + self._setup_functions_complete = False + else: + self._setup_functions_complete = True + + self.execution_policies = self.DEFAULT_EXECUTION_POLICY.copy() + if execution_policies is not None: + self.execution_policies.update(execution_policies) + + def format_functions_for_prompt(self, prompt_template: str = FUNCTION_PROMPT_TEMPLATE) -> str: + """(Experimental) Format the functions for a prompt. + + The template includes two variables: + - `$module_name`: The module name. + - `$functions`: The functions formatted as stubs with two newlines between each function. + + Args: + prompt_template (str): The prompt template. Default is the class default. + + Returns: + str: The formatted prompt. + """ + template = Template(prompt_template) + return template.substitute( + module_name=self._functions_module, + functions="\n\n".join([to_stub(func) for func in self._functions]), + ) + + @property + def functions_module(self) -> str: + """(Experimental) The module name for the functions.""" + return self._functions_module + + @property + def functions( + self, + ) -> list[FunctionWithRequirements[Any, A] | Callable[..., Any] | FunctionWithRequirementsStr]: + """(Experimental) The functions that are available to the code executor.""" + return self._functions + + @property + def timeout(self) -> int: + """(Experimental) The timeout for code execution.""" + return self._timeout + + @property + def work_dir(self) -> Path: + """(Experimental) The working directory for the code execution.""" + return self._work_dir + + @property + def code_extractor(self) -> CodeExtractor: + """(Experimental) Export a code extractor that can be used by an agent.""" + return MarkdownCodeExtractor() + + @staticmethod + def sanitize_command(lang: str, code: str) -> None: + """ + Sanitize the code block to prevent dangerous commands. + This approach acknowledges that while Docker or similar + containerization/sandboxing technologies provide a robust layer of security, + not all users may have Docker installed or may choose not to use it. + Therefore, having a baseline level of protection helps mitigate risks for users who, + either out of choice or necessity, run code outside of a sandboxed environment. + """ + dangerous_patterns = [ + (r"\brm\s+-rf\b", "Use of 'rm -rf' command is not allowed."), + (r"\bmv\b.*?\s+/dev/null", "Moving files to /dev/null is not allowed."), + (r"\bdd\b", "Use of 'dd' command is not allowed."), + (r">\s*/dev/sd[a-z][1-9]?", "Overwriting disk blocks directly is not allowed."), + (r":\(\)\{\s*:\|\:&\s*\};:", "Fork bombs are not allowed."), + ] + if lang in ["bash", "shell", "sh"]: + for pattern, message in dangerous_patterns: + if re.search(pattern, code): + raise ValueError(f"Potentially dangerous command detected: {message}") + + def _setup_functions(self) -> None: + func_file_content = _build_python_functions_file(self._functions) + func_file = self._work_dir / f"{self._functions_module}.py" + func_file.write_text(func_file_content) + + # Collect requirements + lists_of_packages = [x.python_packages for x in self._functions if isinstance(x, FunctionWithRequirements)] + flattened_packages = [item for sublist in lists_of_packages for item in sublist] + required_packages = list(set(flattened_packages)) + if len(required_packages) > 0: + print("Ensuring packages are installed in executor.") + if self._virtual_env_context: + py_executable = self._virtual_env_context.env_exe + else: + py_executable = sys.executable + cmd = [py_executable, "-m", "pip", "install"] + required_packages + try: + result = subprocess.run( + cmd, + cwd=self._work_dir, + capture_output=True, + text=True, + timeout=float(self._timeout), + encoding="utf-8", + ) + except subprocess.TimeoutExpired as e: + raise ValueError("Pip install timed out") from e + if result.returncode != 0: + raise ValueError(f"Pip install failed. {result.stdout}, {result.stderr}") + # Attempt to load the function file to check for syntax errors, imports etc. + exec_result = self._execute_code_dont_check_setup([CodeBlock(code=func_file_content, language="python")]) + if exec_result.exit_code != 0: + raise ValueError(f"Functions failed to load: {exec_result.output}") + self._setup_functions_complete = True + + def execute_code_blocks(self, code_blocks: list[CodeBlock]) -> CommandLineCodeResult: + """(Experimental) Execute the code blocks and return the result. + + Args: + code_blocks (List[CodeBlock]): The code blocks to execute. + + Returns: + CommandLineCodeResult: The result of the code execution.""" + if not self._setup_functions_complete: + self._setup_functions() + return self._execute_code_dont_check_setup(code_blocks) + + def _execute_code_dont_check_setup(self, code_blocks: list[CodeBlock]) -> CommandLineCodeResult: + logs_all = "" + file_names = [] + for code_block in code_blocks: + lang, code = code_block.language, code_block.code + lang = lang.lower() + + LocalCommandLineCodeExecutor.sanitize_command(lang, code) + code = silence_pip(code, lang) + + if lang in PYTHON_VARIANTS: + lang = "python" + + if WIN32 and lang in ["sh", "shell"]: + lang = "ps1" + + if lang not in self.SUPPORTED_LANGUAGES: + # In case the language is not supported, we return an error message. + exitcode = 1 + logs_all += "\n" + f"unknown language {lang}" + break + + execute_code = self.execution_policies.get(lang, False) + try: + # Check if there is a filename comment + filename = _get_file_name_from_content(code, self._work_dir) + except ValueError: + return CommandLineCodeResult(exit_code=1, output="Filename is not in the workspace") + + if filename is None: + # create a file with an automatically generated name + code_hash = md5(code.encode()).hexdigest() + filename = f"tmp_code_{code_hash}.{'py' if lang.startswith('python') else lang}" + written_file = (self._work_dir / filename).resolve() + with written_file.open("w", encoding="utf-8") as f: + f.write(code) + file_names.append(written_file) + + if not execute_code: + # Just return a message that the file is saved. + logs_all += f"Code saved to {str(written_file)}\n" + exitcode = 0 + continue + + program = _cmd(lang) + cmd = [program, str(written_file.absolute())] + env = os.environ.copy() + + if self._virtual_env_context: + virtual_env_abs_path = os.path.abspath(self._virtual_env_context.bin_path) + path_with_virtualenv = rf"{virtual_env_abs_path}{os.pathsep}{env['PATH']}" + env["PATH"] = path_with_virtualenv + if WIN32: + activation_script = os.path.join(virtual_env_abs_path, "activate.bat") + cmd = [activation_script, "&&", *cmd] + + try: + result = subprocess.run( + cmd, + cwd=self._work_dir, + capture_output=True, + text=True, + timeout=float(self._timeout), + env=env, + encoding="utf-8", + ) + except subprocess.TimeoutExpired: + logs_all += "\n" + TIMEOUT_MSG + # Same exit code as the timeout command on linux. + exitcode = 124 + break + + logs_all += result.stderr + logs_all += result.stdout + exitcode = result.returncode + + if exitcode != 0: + break + + code_file = str(file_names[0]) if len(file_names) > 0 else None + return CommandLineCodeResult(exit_code=exitcode, output=logs_all, code_file=code_file) + + def restart(self) -> None: + """(Experimental) Restart the code executor.""" + warnings.warn("Restarting local command line code executor is not supported. No action is taken.") + +class CodeExecutorFactory: + """(Experimental) A factory class for creating code executors.""" + + @staticmethod + def create(code_execution_config: CodeExecutionConfig) -> CodeExecutor: + """(Experimental) Get a code executor based on the code execution config. + + Args: + code_execution_config (Dict): The code execution config, + which is a dictionary that must contain the key "executor". + The value of the key "executor" can be either a string + or an instance of CodeExecutor, in which case the code + executor is returned directly. + + Returns: + CodeExecutor: The code executor. + + Raises: + ValueError: If the code executor is unknown or not specified. + """ + executor = code_execution_config.get("executor") + if isinstance(executor, CodeExecutor): + # If the executor is already an instance of CodeExecutor, return it. + return executor + if executor == "ipython-embedded": + return EmbeddedIPythonCodeExecutor(**code_execution_config.get("ipython-embedded", {})) + elif executor == "commandline-local": + return LocalCommandLineCodeExecutor(**code_execution_config.get("commandline-local", {})) + else: + raise ValueError(f"Unknown code executor {executor}") diff --git a/train_methods/legacy_autogen/legacy_autogen.py b/train_methods/legacy_autogen/legacy_autogen.py index d9b9a61..0f27dd3 100644 --- a/train_methods/legacy_autogen/legacy_autogen.py +++ b/train_methods/legacy_autogen/legacy_autogen.py @@ -1485,4 +1485,4 @@ def clear_agents_history(self, reply: dict, groupchat: GroupChat) -> str: skip_words_number = 2 + int(bool(agent_to_memory_clear)) + int(nr_messages_to_preserve_provided) reply_content = " ".join(words[:clear_word_index] + words[clear_word_index + skip_words_number :]) - return reply_content \ No newline at end of file + return reply_content diff --git a/train_methods/legacy_autogen/legacy_autogen_conversable_agent.py b/train_methods/legacy_autogen/legacy_autogen_conversable_agent.py index 173ab26..ec164d5 100644 --- a/train_methods/legacy_autogen/legacy_autogen_conversable_agent.py +++ b/train_methods/legacy_autogen/legacy_autogen_conversable_agent.py @@ -13,15 +13,10 @@ from pydantic import BaseModel from termcolor import colored -from ..coding.base import CodeExecutor -from ..coding.factory import CodeExecutorFactory -from ..function_utils import get_function_schema, load_basemodels_if_needed, serialize_to_str -from .utils import consolidate_chat_info, gather_usage_summary - - from train_methods.legacy_autogen.cache import AbstractCache -from train_methods.legacy_autogen.chat import ChatResult, a_initiate_chats, initiate_chats, _post_process_carryover_item +from train_methods.legacy_autogen.chat import ChatResult, a_initiate_chats, initiate_chats, _post_process_carryover_item, consolidate_chat_info from train_methods.legacy_autogen.client import ModelClient, OpenAIWrapper +from train_methods.legacy_autogen.coding import CodeExecutor, CodeExecutorFactory from train_methods.legacy_autogen.stream import IOStream from train_methods.legacy_autogen.utils import ( check_can_use_docker_or_throw, @@ -30,6 +25,9 @@ execute_code, extract_code, infer_lang, + load_basemodels_if_needed, + serialize_to_str, + get_function_schema ) __all__ = ("ConversableAgent",) @@ -187,6 +185,76 @@ def update_system_message(self, system_message: str) -> None: system_message (str): system message for inference. """ + +def gather_usage_summary(agents: list[Agent]) -> dict[dict[str, dict], dict[str, dict]]: + r"""Gather usage summary from all agents. + + Args: + agents: (list): List of agents. + + Returns: + dictionary: A dictionary containing two keys: + - "usage_including_cached_inference": Cost information on the total usage, including the tokens in cached inference. + - "usage_excluding_cached_inference": Cost information on the usage of tokens, excluding the tokens in cache. No larger than "usage_including_cached_inference". + + Example: + + ```python + { + "usage_including_cached_inference" : { + "total_cost": 0.0006090000000000001, + "gpt-35-turbo": { + "cost": 0.0006090000000000001, + "prompt_tokens": 242, + "completion_tokens": 123, + "total_tokens": 365 + }, + }, + + "usage_excluding_cached_inference" : { + "total_cost": 0.0006090000000000001, + "gpt-35-turbo": { + "cost": 0.0006090000000000001, + "prompt_tokens": 242, + "completion_tokens": 123, + "total_tokens": 365 + }, + } + } + ``` + + Note: + + If none of the agents incurred any cost (not having a client), then the usage_including_cached_inference and usage_excluding_cached_inference will be `{'total_cost': 0}`. + """ + + def aggregate_summary(usage_summary: dict[str, Any], agent_summary: dict[str, Any]) -> None: + if agent_summary is None: + return + usage_summary["total_cost"] += agent_summary.get("total_cost", 0) + for model, data in agent_summary.items(): + if model != "total_cost": + if model not in usage_summary: + usage_summary[model] = data.copy() + else: + usage_summary[model]["cost"] += data.get("cost", 0) + usage_summary[model]["prompt_tokens"] += data.get("prompt_tokens", 0) + usage_summary[model]["completion_tokens"] += data.get("completion_tokens", 0) + usage_summary[model]["total_tokens"] += data.get("total_tokens", 0) + + usage_including_cached_inference = {"total_cost": 0} + usage_excluding_cached_inference = {"total_cost": 0} + + for agent in agents: + if getattr(agent, "client", None): + aggregate_summary(usage_including_cached_inference, agent.client.total_usage_summary) + aggregate_summary(usage_excluding_cached_inference, agent.client.actual_usage_summary) + + return { + "usage_including_cached_inference": usage_including_cached_inference, + "usage_excluding_cached_inference": usage_excluding_cached_inference, + } + class ConversableAgent(LLMAgent): """(In preview) A class for generic conversable agents which can be configured as assistant or user proxy. @@ -3113,3 +3181,75 @@ def register_function( """ f = caller.register_for_llm(name=name, description=description)(f) executor.register_for_execution(name=name)(f) + + +class AssistantAgent(ConversableAgent): + """(In preview) Assistant agent, designed to solve a task with LLM. + + AssistantAgent is a subclass of ConversableAgent configured with a default system message. + The default system message is designed to solve a task with LLM, + including suggesting python code blocks and debugging. + `human_input_mode` is default to "NEVER" + and `code_execution_config` is default to False. + This agent doesn't execute code by default, and expects the user to execute the code. + """ + + DEFAULT_SYSTEM_MESSAGE = """You are a helpful AI assistant. +Solve tasks using your coding and language skills. +In the following cases, suggest python code (in a python coding block) or shell script (in a sh coding block) for the user to execute. + 1. When you need to collect info, use the code to output the info you need, for example, browse or search the web, download/read a file, print the content of a webpage or a file, get the current date/time, check the operating system. After sufficient info is printed and the task is ready to be solved based on your language skill, you can solve the task by yourself. + 2. When you need to perform some task with code, use the code to perform the task and output the result. Finish the task smartly. +Solve the task step by step if you need to. If a plan is not provided, explain your plan first. Be clear which step uses code, and which step uses your language skill. +When using code, you must indicate the script type in the code block. The user cannot provide any other feedback or perform any other action beyond executing the code you suggest. The user can't modify your code. So do not suggest incomplete code which requires users to modify. Don't use a code block if it's not intended to be executed by the user. +If you want the user to save the code in a file before executing it, put # filename: inside the code block as the first line. Don't include multiple code blocks in one response. Do not ask users to copy and paste the result. Instead, use 'print' function for the output when relevant. Check the execution result returned by the user. +If the result indicates there is an error, fix the error and output the code again. Suggest the full code instead of partial code or code changes. If the error can't be fixed or if the task is not solved even after the code is executed successfully, analyze the problem, revisit your assumption, collect additional info you need, and think of a different approach to try. +When you find an answer, verify the answer carefully. Include verifiable evidence in your response if possible. +Reply "TERMINATE" in the end when everything is done. + """ + + DEFAULT_DESCRIPTION = "A helpful and general-purpose AI assistant that has strong language skills, Python skills, and Linux command line skills." + + def __init__( + self, + name: str, + system_message: str | None = DEFAULT_SYSTEM_MESSAGE, + llm_config: dict | Literal[False] | None = None, + is_termination_msg: Callable[[dict], bool] | None = None, + max_consecutive_auto_reply: int | None = None, + human_input_mode: Literal["ALWAYS", "NEVER", "TERMINATE"] = "NEVER", + description: str | None = None, + **kwargs, + ): + """ + Args: + name (str): agent name. + system_message (str): system message for the ChatCompletion inference. + Please override this attribute if you want to reprogram the agent. + llm_config (dict or False or None): llm inference configuration. + Please refer to [OpenAIWrapper.create](/docs/reference/oai/client#create) + for available options. + is_termination_msg (function): a function that takes a message in the form of a dictionary + and returns a boolean value indicating if this received message is a termination message. + The dict can contain the following keys: "content", "role", "name", "function_call". + max_consecutive_auto_reply (int): the maximum number of consecutive auto replies. + default to None (no limit provided, class attribute MAX_CONSECUTIVE_AUTO_REPLY will be used as the limit in this case). + The limit only plays a role when human_input_mode is not "ALWAYS". + **kwargs (dict): Please refer to other kwargs in + [ConversableAgent](conversable_agent#__init__). + """ + super().__init__( + name, + system_message, + is_termination_msg, + max_consecutive_auto_reply, + human_input_mode, + llm_config=llm_config, + description=description, + **kwargs, + ) + + # Update the provided description if None, and we are using the default system_message, + # then use the default description. + if description is None: + if system_message == self.DEFAULT_SYSTEM_MESSAGE: + self.description = self.DEFAULT_DESCRIPTION diff --git a/train_methods/legacy_autogen/utils.py b/train_methods/legacy_autogen/utils.py index f3d66a4..6bc0bd2 100644 --- a/train_methods/legacy_autogen/utils.py +++ b/train_methods/legacy_autogen/utils.py @@ -1,4 +1,6 @@ -import logging +import functools +import inspect +import json import os import pathlib import re @@ -10,9 +12,13 @@ from concurrent.futures import ThreadPoolExecutor, TimeoutError from hashlib import md5 from types import SimpleNamespace -from typing import Callable, Literal, TypedDict +from typing import Callable, Literal, TypedDict, Any, Annotated, ForwardRef +from typing_extensions import get_args, get_origin import docker +from pydantic import BaseModel, Field, TypeAdapter +from pydantic._internal._typing_extra import try_eval_type +from pydantic.json_schema import JsonSchemaValue from train_methods.legacy_autogen.completion import Completion @@ -28,8 +34,6 @@ PATH_SEPARATOR = WIN32 and "\\" or "/" PYTHON_VARIANTS = ["python", "Python", "py"] -logger = logging.getLogger(__name__) - class UserMessageTextContentPart(TypedDict): type: Literal["text"] text: str @@ -389,7 +393,6 @@ def execute_code( """ if all((code is None, filename is None)): error_msg = f"Either {code=} or {filename=} must be provided." - logger.error(error_msg) raise AssertionError(error_msg) running_inside_docker = in_docker_container() @@ -724,3 +727,334 @@ def create_virtual_env(dir_path: str, **env_args) -> SimpleNamespace: env_builder = venv.EnvBuilder(**env_args) env_builder.create(dir_path) return env_builder.ensure_directories(dir_path) + +def get_typed_annotation(annotation: Any, globalns: dict[str, Any]) -> Any: + """Get the type annotation of a parameter. + + Args: + annotation: The annotation of the parameter + globalns: The global namespace of the function + + Returns: + The type annotation of the parameter + """ + if isinstance(annotation, str): + annotation = ForwardRef(annotation) + annotation = try_eval_type(annotation, globalns, globalns) + return annotation + +def get_typed_signature(call: Callable[..., Any]) -> inspect.Signature: + """Get the signature of a function with type annotations. + + Args: + call: The function to get the signature for + + Returns: + The signature of the function with type annotations + """ + signature = inspect.signature(call) + globalns = getattr(call, "__globals__", {}) + typed_params = [ + inspect.Parameter( + name=param.name, + kind=param.kind, + default=param.default, + annotation=get_typed_annotation(param.annotation, globalns), + ) + for param in signature.parameters.values() + ] + typed_signature = inspect.Signature(typed_params) + return typed_signature + +def get_param_annotations(typed_signature: inspect.Signature) -> dict[str, Annotated[type[Any], str] | type[Any]]: + """Get the type annotations of the parameters of a function + + Args: + typed_signature: The signature of the function with type annotations + + Returns: + A dictionary of the type annotations of the parameters of the function + """ + return { + k: v.annotation for k, v in typed_signature.parameters.items() if v.annotation is not inspect.Signature.empty + } + +def get_load_param_if_needed_function(t: Any) -> Callable[[dict[str, Any], type[BaseModel]], BaseModel] | None: + """Get a function to load a parameter if it is a Pydantic model + + Args: + t: The type annotation of the parameter + + Returns: + A function to load the parameter if it is a Pydantic model, otherwise None + + """ + if get_origin(t) is Annotated: + return get_load_param_if_needed_function(get_args(t)[0]) + + def load_base_model(v: dict[str, Any], t: type[BaseModel]) -> BaseModel: + return t(**v) + + return load_base_model if isinstance(t, type) and issubclass(t, BaseModel) else None + + +def load_basemodels_if_needed(func: Callable[..., Any]) -> Callable[..., Any]: + """A decorator to load the parameters of a function if they are Pydantic models + + Args: + func: The function with annotated parameters + + Returns: + A function that loads the parameters before calling the original function + + """ + # get the type annotations of the parameters + typed_signature = get_typed_signature(func) + param_annotations = get_param_annotations(typed_signature) + + # get functions for loading BaseModels when needed based on the type annotations + kwargs_mapping_with_nones = {k: get_load_param_if_needed_function(t) for k, t in param_annotations.items()} + + # remove the None values + kwargs_mapping = {k: f for k, f in kwargs_mapping_with_nones.items() if f is not None} + + # a function that loads the parameters before calling the original function + @functools.wraps(func) + def _load_parameters_if_needed(*args: Any, **kwargs: Any) -> Any: + # load the BaseModels if needed + for k, f in kwargs_mapping.items(): + kwargs[k] = f(kwargs[k], param_annotations[k]) + + # call the original function + return func(*args, **kwargs) + + @functools.wraps(func) + async def _a_load_parameters_if_needed(*args: Any, **kwargs: Any) -> Any: + # load the BaseModels if needed + for k, f in kwargs_mapping.items(): + kwargs[k] = f(kwargs[k], param_annotations[k]) + + # call the original function + return await func(*args, **kwargs) + + if inspect.iscoroutinefunction(func): + return _a_load_parameters_if_needed + else: + return _load_parameters_if_needed + + +def serialize_to_str(x: Any) -> str: + if isinstance(x, str): + return x + elif isinstance(x, BaseModel): + return x.model_dump_json() + else: + return json.dumps(x, ensure_ascii=False) + +def get_required_params(typed_signature: inspect.Signature) -> list[str]: + """Get the required parameters of a function + + Args: + signature: The signature of the function as returned by inspect.signature + + Returns: + A list of the required parameters of the function + """ + return [k for k, v in typed_signature.parameters.items() if v.default == inspect.Signature.empty] + +def get_default_values(typed_signature: inspect.Signature) -> dict[str, Any]: + """Get default values of parameters of a function + + Args: + signature: The signature of the function as returned by inspect.signature + + Returns: + A dictionary of the default values of the parameters of the function + """ + return {k: v.default for k, v in typed_signature.parameters.items() if v.default != inspect.Signature.empty} + +def get_typed_return_annotation(call: Callable[..., Any]) -> Any: + """Get the return annotation of a function. + + Args: + call: The function to get the return annotation for + + Returns: + The return annotation of the function + """ + signature = inspect.signature(call) + annotation = signature.return_annotation + + if annotation is inspect.Signature.empty: + return None + + globalns = getattr(call, "__globals__", {}) + return get_typed_annotation(annotation, globalns) + +def get_missing_annotations(typed_signature: inspect.Signature, required: list[str]) -> tuple[set[str], set[str]]: + """Get the missing annotations of a function + + Ignores the parameters with default values as they are not required to be annotated, but logs a warning. + Args: + typed_signature: The signature of the function with type annotations + required: The required parameters of the function + + Returns: + A set of the missing annotations of the function + """ + all_missing = {k for k, v in typed_signature.parameters.items() if v.annotation is inspect.Signature.empty} + missing = all_missing.intersection(set(required)) + unannotated_with_default = all_missing.difference(missing) + return missing, unannotated_with_default + +class Parameters(BaseModel): + """Parameters of a function as defined by the OpenAI API""" + + type: Literal["object"] = "object" + properties: dict[str, JsonSchemaValue] + required: list[str] + + +class Function(BaseModel): + """A function as defined by the OpenAI API""" + + description: Annotated[str, Field(description="Description of the function")] + name: Annotated[str, Field(description="Name of the function")] + parameters: Annotated[Parameters, Field(description="Parameters of the function")] + + +class ToolFunction(BaseModel): + """A function under tool as defined by the OpenAI API.""" + + type: Literal["function"] = "function" + function: Annotated[Function, Field(description="Function under tool")] + + +def get_parameter_json_schema(k: str, v: Any, default_values: dict[str, Any]) -> JsonSchemaValue: + """Get a JSON schema for a parameter as defined by the OpenAI API + + Args: + k: The name of the parameter + v: The type of the parameter + default_values: The default values of the parameters of the function + + Returns: + A Pydanitc model for the parameter + """ + + def type2description(k: str, v: Annotated[type[Any], str] | type[Any]) -> str: + # handles Annotated + if hasattr(v, "__metadata__"): + retval = v.__metadata__[0] + if isinstance(retval, str): + return retval + else: + raise ValueError(f"Invalid description {retval} for parameter {k}, should be a string.") + else: + return k + + schema = TypeAdapter(v).json_schema() + if k in default_values: + dv = default_values[k] + schema["default"] = dv + + schema["description"] = type2description(k, v) + + return schema + +def get_parameters( + required: list[str], + param_annotations: dict[str, Annotated[type[Any], str] | type[Any]], + default_values: dict[str, Any], +) -> Parameters: + """Get the parameters of a function as defined by the OpenAI API + + Args: + required: The required parameters of the function + hints: The type hints of the function as returned by typing.get_type_hints + + Returns: + A Pydantic model for the parameters of the function + """ + return Parameters( + properties={ + k: get_parameter_json_schema(k, v, default_values) + for k, v in param_annotations.items() + if v is not inspect.Signature.empty + }, + required=required, + ) + +def get_function_schema(f: Callable[..., Any], *, name: str | None = None, description: str) -> dict[str, Any]: + """Get a JSON schema for a function as defined by the OpenAI API + + Args: + f: The function to get the JSON schema for + name: The name of the function + description: The description of the function + + Returns: + A JSON schema for the function + + Raises: + TypeError: If the function is not annotated + + Examples: + + ```python + def f(a: Annotated[str, "Parameter a"], b: int = 2, c: Annotated[float, "Parameter c"] = 0.1) -> None: + pass + + get_function_schema(f, description="function f") + + # {'type': 'function', + # 'function': {'description': 'function f', + # 'name': 'f', + # 'parameters': {'type': 'object', + # 'properties': {'a': {'type': 'str', 'description': 'Parameter a'}, + # 'b': {'type': 'int', 'description': 'b'}, + # 'c': {'type': 'float', 'description': 'Parameter c'}}, + # 'required': ['a']}}} + ``` + + """ + typed_signature = get_typed_signature(f) + required = get_required_params(typed_signature) + default_values = get_default_values(typed_signature) + param_annotations = get_param_annotations(typed_signature) + return_annotation = get_typed_return_annotation(f) + missing, unannotated_with_default = get_missing_annotations(typed_signature, required) + + if return_annotation is None: + print( + f"The return type of the function '{f.__name__}' is not annotated. Although annotating it is " + + "optional, the function should return either a string, a subclass of 'pydantic.BaseModel'." + ) + + if unannotated_with_default != set(): + unannotated_with_default_s = [f"'{k}'" for k in sorted(unannotated_with_default)] + print( + f"The following parameters of the function '{f.__name__}' with default values are not annotated: " + + f"{', '.join(unannotated_with_default_s)}." + ) + + if missing != set(): + missing_s = [f"'{k}'" for k in sorted(missing)] + raise TypeError( + f"All parameters of the function '{f.__name__}' without default values must be annotated. " + + f"The annotations are missing for the following parameters: {', '.join(missing_s)}" + ) + + fname = name if name else f.__name__ + + parameters = get_parameters(required, param_annotations, default_values=default_values) + + function = ToolFunction( + function=Function( + description=description, + name=fname, + parameters=parameters, + ) + ) + + return function.model_dump() diff --git a/train_methods/utils_cogfd.py b/train_methods/utils_cogfd.py index a9647ba..fa0e39a 100644 --- a/train_methods/utils_cogfd.py +++ b/train_methods/utils_cogfd.py @@ -19,7 +19,7 @@ from transformers.utils import ModelOutput from train_methods.legacy_autogen.legacy_autogen import GroupChat -from train_methods.legacy_autogen.legacy_autogen_conversable_agent import ConversableAgent +from train_methods.legacy_autogen.legacy_autogen_conversable_agent import ConversableAgent, AssistantAgent @dataclass class TransformationModelOutput(ModelOutput): @@ -142,10 +142,10 @@ def forward( generating concept logic graph """ - def generate_and_save_concept_graph( concept_combination_x: str, combination_theme_y: str, + base_url: str, output_filename: str = "concept_logic_graph.json" ) -> dict | None: """Generates a conceptual logic graph based on the given text concept combination, saves it as JSON, and returns the parsed graph. @@ -157,7 +157,7 @@ def generate_and_save_concept_graph( Returns: The parsed conceptual logic graph as a dict, or None if the process fails. """ - + OPENAI_API_KEY = os.environ["OPENAI_API_KEY"] Concept_logic_graph_Agent = ConversableAgent( name="Concept_logic_graph_Agent", @@ -202,14 +202,14 @@ def generate_and_save_concept_graph( Follow the JSON structure precisely as shown in the example. If you receive instructions on how to fix mistakes, follow them and regenerate the corrected JSON response in the same strict format. ''', - llm_config={"config_list": [{"model": "gpt-4o", "api_key": OPENAI_API_KEY, "base_url": BASE_URL}]}, + llm_config={"config_list": [{"model": "gpt-4o", "api_key": OPENAI_API_KEY, "base_url": base_url}]}, is_termination_msg=lambda msg: "the answer is correct!" in msg.get("content", "").lower(), human_input_mode="NEVER", ) reviewer = AssistantAgent( name="Reviewer", - llm_config={"config_list": [{"model": "gpt-4o", "api_key": OPENAI_API_KEY, "base_url": BASE_URL}]}, + llm_config={"config_list": [{"model": "gpt-4o", "api_key": OPENAI_API_KEY, "base_url": base_url}]}, system_message=""" You are a well-known expert in the description logic field and a compliance reviewer, known for your thoroughness and commitment to standards. The Generator generated a concept logic graph in the JSON format that organizes concepts and concept combinations with three logic relations: Conjunction, Entailment, and Equivalence. Your task is to find whether the generated graph from the Generator is correct. Here are two aspects of the answer which you need to check carefully: 1. Whether the answer is correct and helpful. @@ -401,4 +401,4 @@ def generate_and_save_iterative_graphs( all_graphs = generate_and_save_iterative_graphs(concept_combination_x, combination_theme_y) combine_list, concept_list = extract_concept_from_graph(all_graphs) print(f"combine_list: {combine_list}") - print(f"concept_list: {concept_list}") \ No newline at end of file + print(f"concept_list: {concept_list}") From 46e9d3385b646a52131300950868b92ed9072400 Mon Sep 17 00:00:00 2001 From: fmuuly <64724985+fmp453@users.noreply.github.com> Date: Tue, 28 Oct 2025 16:17:33 +0900 Subject: [PATCH 05/25] completed (mabye need to be refactored) --- train_methods/train_cogfd.py | 56 ++++++++++-------------------------- train_methods/utils_cogfd.py | 12 +------- utils.py | 5 ++++ 3 files changed, 21 insertions(+), 52 deletions(-) diff --git a/train_methods/train_cogfd.py b/train_methods/train_cogfd.py index fa225a1..f39170e 100644 --- a/train_methods/train_cogfd.py +++ b/train_methods/train_cogfd.py @@ -24,6 +24,7 @@ import math import json import os +from pathlib import Path import torch import torch.nn as nn @@ -37,7 +38,6 @@ from diffusers.models.attention_processor import Attention from transformers import AutoTokenizer, PretrainedConfig from transformers import CLIPTextModel -import argparse from train_methods.data import COGFDDataset from train_methods.utils_cogfd import RobertaSeriesModelWithTransformation, generate_and_save_iterative_graphs, extract_concept_from_graph @@ -207,6 +207,7 @@ def train( ) train_dataset = COGFDDataset( + data_dir=args.data_dir, tokenizer=tokenizer, size=args.image_size, center_crop=args.cogfd_center_crop, @@ -359,49 +360,22 @@ def train( def main(args: Arguments): # first, generate concept logic graph - # second, erasing - pass - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument('--theme', type=str) - parser.add_argument('--combine_concept_x', type=str, default="A child is drinking wine") - parser.add_argument('--combine_theme_y', type=str, default="underage drinking") - parser.add_argument('--iterate_n', type=int, default=1) - - args = parser.parse_args() - combine_concept = args.combine_concept_x - OUTPUT_DIR = "" - LOGICGRAPH_DIR = OUTPUT_DIR + "/concept_logic_graph" - PREPARED_DATA_DIR = OUTPUT_DIR + "/data" - args.prepared_data_dir = PREPARED_DATA_DIR.format(concept_combination=combine_concept) - args.graph_output_dir = LOGICGRAPH_DIR.formt(concept_combination=combine_concept) - - - combine_theme = args.combine_theme_y - task_info = [combine_concept, combine_theme] - - graph_path = os.path.join(args.graph_output_dir, f"{combine_concept}.json") - # generate concept logic graph - try: + graph_path = args.cogfd_graph_path + if Path(graph_path).exists(): with open(graph_path, 'r') as f: parsed_graph = json.load(f) - except FileNotFoundError: - print(f"File {graph_path} not found. Generating concept logic graph...") - combine_concept_x = args.combine_concept_x.replace("_", " ") - combine_theme_y = args.combine_theme_y.replace("_", " ") - parsed_graph = generate_and_save_iterative_graphs(combine_concept_x, combine_theme_y, graph_path, iterate_n=args.iterate_n) - - + else: + combine_concept_x = args.cogfd_combine_concept_x.replace("_", " ") + combine_theme_y = args.cogfd_combine_theme_y.replace("_", " ") + parsed_graph = generate_and_save_iterative_graphs(combine_concept_x, combine_theme_y, graph_path, iterate_n=args.cogfd_iterate_n) + + # second, erasing # extract concepts from graph concept_combination, sub_concept = extract_concept_from_graph(parsed_graph) - concepts = concept_combination + sub_concept - labels = [args.p1 for i in concept_combination] + [args.p2 for i in sub_concept] - print(concepts) - print(labels) - - train(task_info=task_info, - concept_combination=concepts, - labels=labels, + task_info = [args.cogfd_combine_concept_x, args.cogfd_combine_theme_y] + train( + task_info=task_info, + concept_combination=concept_combination, + labels=[args.cogfd_p1 for _ in concept_combination] + [args.cogfd_p2 for _ in sub_concept] ) diff --git a/train_methods/utils_cogfd.py b/train_methods/utils_cogfd.py index fa0e39a..1663117 100644 --- a/train_methods/utils_cogfd.py +++ b/train_methods/utils_cogfd.py @@ -228,8 +228,6 @@ def generate_and_save_concept_graph( speaker_selection_method='round_robin', ) - # --- 启动聊天 --- - # 构建传递给 agent 的消息 initial_message = f"X = {concept_combination_x}, Y = {combination_theme_y}" print(f"\n--- Starting chat for: '{initial_message}' ---") @@ -242,11 +240,9 @@ def auto_end_chat(): # Call the function after some condition or time has passed auto_end_chat() - final_graph_string = None parsed_graph = None - # 检查聊天是否有历史记录 if group_chat_with_introductions.messages: all_messages = group_chat_with_introductions.messages for msg in reversed(all_messages): @@ -258,7 +254,6 @@ def auto_end_chat(): print("\nNo messages found in group chat history.") if final_graph_string: - # 尝试从 final_graph_string 中提取 JSON 部分 try: match = re.search(r"```json\n(.*?)\n```", final_graph_string, re.DOTALL) if match: @@ -268,18 +263,15 @@ def auto_end_chat(): print("\n--- Parsed Concept Logic Graph --- (from ```json block)") pprint.pprint(parsed_graph) - # 保存到 JSON 文件 with open(output_filename, 'w', encoding='utf-8') as f: json.dump(parsed_graph, f, ensure_ascii=False, indent=4) print(f"\n--- Saved graph to {output_filename} ---") else: print("\nCould not find JSON block (```json ... ```) within the final graph string.") - # 尝试直接解析整个字符串作为备选 try: parsed_graph = json.loads(final_graph_string) print("\n--- Parsed entire final_graph string as JSON (fallback) ---") pprint.pprint(parsed_graph) - # 也可以在这里保存 with open(output_filename, 'w', encoding='utf-8') as f: json.dump(parsed_graph, f, ensure_ascii=False, indent=4) print(f"\n--- Saved graph to {output_filename} (from direct parse) ---") @@ -318,7 +310,7 @@ def extract_concept_from_graph(parsed_graph: dict[str, dict[str, Any]]) -> tuple concept_combination.append(main_concept) current_graph = iteration_graph[main_concept] - + # 包含関係の追加 if 'entailment' in current_graph: concept_combination.extend(current_graph['entailment']) @@ -392,12 +384,10 @@ def generate_and_save_iterative_graphs( return all_graphs -# --- 主执行块 --- # if __name__ == "__main__": concept_combination_x = "A child is drinking wine" combination_theme_y = "underage drinking" - # 使用新函数生成迭代图谱 all_graphs = generate_and_save_iterative_graphs(concept_combination_x, combination_theme_y) combine_list, concept_list = extract_concept_from_graph(all_graphs) print(f"combine_list: {combine_list}") diff --git a/utils.py b/utils.py index c203826..ec2caff 100644 --- a/utils.py +++ b/utils.py @@ -381,6 +381,11 @@ class Arguments(BaseModel): cogfd_lr_num_cycles: Optional[int] = Field(1) cogfd_lr_power: Optional[float] = Field(1.0) cogfd_dataloader_num_workers: Optional[int] = Field(9) + cogfd_graph_path: Optional[str] = Field("cpgfd-graph/graph.json") + cogfd_iterate_n: Optional[int] = Field(2) + cogfd_combine_concept_x: Optional[str] = Field("A child is drinking wine") + cogfd_combine_theme_y: Optional[str] = Field("underage drinking") + # inference part prompt: Optional[str] = Field("a photo of the English springer", description="prompt in inference phase") From e115da22278fe24eebb9d98127e40d25bca6f143 Mon Sep 17 00:00:00 2001 From: fmuuly <64724985+fmp453@users.noreply.github.com> Date: Tue, 28 Oct 2025 16:18:54 +0900 Subject: [PATCH 06/25] fic --- train_methods/utils_cogfd.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/train_methods/utils_cogfd.py b/train_methods/utils_cogfd.py index 1663117..5de1115 100644 --- a/train_methods/utils_cogfd.py +++ b/train_methods/utils_cogfd.py @@ -145,7 +145,6 @@ def forward( def generate_and_save_concept_graph( concept_combination_x: str, combination_theme_y: str, - base_url: str, output_filename: str = "concept_logic_graph.json" ) -> dict | None: """Generates a conceptual logic graph based on the given text concept combination, saves it as JSON, and returns the parsed graph. @@ -158,6 +157,7 @@ def generate_and_save_concept_graph( The parsed conceptual logic graph as a dict, or None if the process fails. """ OPENAI_API_KEY = os.environ["OPENAI_API_KEY"] + base_url = os.environ["BASE_URL"] Concept_logic_graph_Agent = ConversableAgent( name="Concept_logic_graph_Agent", From b4a592b7d48ef4322e73f092ec04103d460afad2 Mon Sep 17 00:00:00 2001 From: fmuuly <64724985+fmp453@users.noreply.github.com> Date: Sun, 1 Feb 2026 17:20:21 +0900 Subject: [PATCH 07/25] minor fix --- train_methods/train_cogfd.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/train_methods/train_cogfd.py b/train_methods/train_cogfd.py index f39170e..a8a0c77 100644 --- a/train_methods/train_cogfd.py +++ b/train_methods/train_cogfd.py @@ -23,7 +23,6 @@ import itertools import math import json -import os from pathlib import Path import torch @@ -98,7 +97,13 @@ def __init__(self, hiddenstates_controller: "HiddenStatesController", module_nam self.hiddenstates_controller = hiddenstates_controller self.module_name = module_name - def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None): + def __call__( + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor | None = None, + attention_mask: torch.Tensor | None = None + ): encoder_attention_mask = self.hiddenstates_controller.encoder_attn_mask batch_size, sequence_length, _ = hidden_states.shape @@ -138,7 +143,7 @@ def train( if args.seed is not None: set_seed(args.seed) - os.makedirs(args.save_dir, exist_ok=True) + Path(args.save_dir).mkdir(exist_ok=True) tokenizer = AutoTokenizer.from_pretrained( args.sd_version, @@ -155,8 +160,7 @@ def train( unet: UNet2DConditionModel = UNet2DConditionModel.from_pretrained(args.sd_version, subfolder="unet") unet_1: UNet2DConditionModel = UNet2DConditionModel.from_pretrained(args.sd_version, subfolder="unet") - # unet_1 on device 1 - devices = get_devices(args)[0] + devices = get_devices(args) attn_controller = HiddenStatesController() module_count = 0 @@ -169,7 +173,7 @@ def train( attn_controller_1 = HiddenStatesController() module_count = 0 for name, module in unet_1.named_modules(): - if name.endswith('attn2'): + if name.endswith('attn2') and isinstance(module, Attention): module.set_processor(MyCrossAttnProcessor(attn_controller_1, name)) module_count += 1 print(f"cross attention module count: {module_count}") From f6c9f710c022744c34937fcdffbca1ce0acbb31a Mon Sep 17 00:00:00 2001 From: fmuuly <64724985+fmp453@users.noreply.github.com> Date: Sun, 1 Feb 2026 17:23:36 +0900 Subject: [PATCH 08/25] remove Optional --- utils.py | 56 ++++++++++++++++++++++++++++---------------------------- 1 file changed, 28 insertions(+), 28 deletions(-) diff --git a/utils.py b/utils.py index a17f06a..02f4649 100644 --- a/utils.py +++ b/utils.py @@ -420,34 +420,34 @@ class Arguments(BaseModel): mce_reg_beta: int = Field(1, description="no need to use beta for now for testing") # config for CoGFD - cogfd_p1: Optional[float] = Field(-1.0) - cogfd_p2: Optional[float] = Field(1.0) - cogfd_start: Optional[int] = Field(990) - cogfd_end: Optional[int] = Field(1000) - cogfd_lr: Optional[float] = Field(5e-5) - cogfd_num_train_epochs: Optional[int] = Field(1) - cogfd_train_batch_size: Optional[int] = Field(20) - cogfd_adam_beta_1: Optional[float] = Field(0.9) - cogfd_adam_beta_2: Optional[float] = Field(0.999) - cogfd_adam_weight_decay: Optional[float] = Field(0.01) - cogfd_adam_epsilon: Optional[float] = Field(1.0e-08) - cogfd_gradient_accumulation_steps: Optional[int] = Field(1) - cogfd_scale_lr: Optional[bool] = Field(False) - cogfd_use_8bit_adam: Optional[bool] = Field(False) - cogfd_train_text_encoder: Optional[bool] = Field(False) - cogfd_center_crop: Optional[bool] = Field(False) - cogfd_only_optimize_ca: Optional[bool] = Field(False) - cogfd_set_grads_to_none: Optional[bool] = Field(False) - cogfd_use_pooler: Optional[bool] = Field(True) - cogfd_max_train_steps: Optional[int] = Field(100) - cogfd_lr_warmup_steps: Optional[int] = Field(0) - cogfd_lr_num_cycles: Optional[int] = Field(1) - cogfd_lr_power: Optional[float] = Field(1.0) - cogfd_dataloader_num_workers: Optional[int] = Field(9) - cogfd_graph_path: Optional[str] = Field("cpgfd-graph/graph.json") - cogfd_iterate_n: Optional[int] = Field(2) - cogfd_combine_concept_x: Optional[str] = Field("A child is drinking wine") - cogfd_combine_theme_y: Optional[str] = Field("underage drinking") + cogfd_p1: float = Field(-1.0) + cogfd_p2: float = Field(1.0) + cogfd_start: int = Field(990) + cogfd_end: int = Field(1000) + cogfd_lr: float = Field(5e-5) + cogfd_num_train_epochs: int = Field(1) + cogfd_train_batch_size: int = Field(20) + cogfd_adam_beta_1: float = Field(0.9) + cogfd_adam_beta_2: float = Field(0.999) + cogfd_adam_weight_decay: float = Field(0.01) + cogfd_adam_epsilon: float = Field(1.0e-08) + cogfd_gradient_accumulation_steps: int = Field(1) + cogfd_scale_lr: bool = Field(False) + cogfd_use_8bit_adam: bool = Field(False) + cogfd_train_text_encoder: bool = Field(False) + cogfd_center_crop: bool = Field(False) + cogfd_only_optimize_ca: bool = Field(False) + cogfd_set_grads_to_none: bool = Field(False) + cogfd_use_pooler: bool = Field(True) + cogfd_max_train_steps: int = Field(100) + cogfd_lr_warmup_steps: int = Field(0) + cogfd_lr_num_cycles: int = Field(1) + cogfd_lr_power: float = Field(1.0) + cogfd_dataloader_num_workers: int = Field(9) + cogfd_graph_path: str = Field("cpgfd-graph/graph.json") + cogfd_iterate_n: int = Field(2) + cogfd_combine_concept_x: str = Field("A child is drinking wine") + cogfd_combine_theme_y: str = Field("underage drinking") # inference part From 5409af788f137c932f24fde118018447eaab6add Mon Sep 17 00:00:00 2001 From: fmuuly <64724985+fmp453@users.noreply.github.com> Date: Sun, 1 Feb 2026 17:29:17 +0900 Subject: [PATCH 09/25] remove Optional --- train_methods/utils_cogfd.py | 66 ++++++++++++------------------------ 1 file changed, 21 insertions(+), 45 deletions(-) diff --git a/train_methods/utils_cogfd.py b/train_methods/utils_cogfd.py index 5de1115..7d1411b 100644 --- a/train_methods/utils_cogfd.py +++ b/train_methods/utils_cogfd.py @@ -9,45 +9,24 @@ import pprint from dataclasses import dataclass from json import JSONDecodeError -from typing import Optional, Any +from typing import Any import torch - -from torch import nn +import torch.nn as nn from transformers import RobertaPreTrainedModel, XLMRobertaConfig, XLMRobertaModel from transformers.utils import ModelOutput +from transformers.modeling_outputs import BaseModelOutputWithPoolingAndCrossAttentions from train_methods.legacy_autogen.legacy_autogen import GroupChat from train_methods.legacy_autogen.legacy_autogen_conversable_agent import ConversableAgent, AssistantAgent @dataclass class TransformationModelOutput(ModelOutput): - """ - Base class for text model's outputs that also contains a pooling of the last hidden states. - - Args: - text_embeds (`torch.Tensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`): - The text embeddings obtained by applying the projection layer to the pooler_output. - last_hidden_state (`torch.Tensor` of shape `(batch_size, sequence_length, hidden_size)`): - Sequence of hidden-states at the output of the last layer of the model. - hidden_states (`tuple(torch.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `torch.Tensor` (one for the output of the embeddings, if the model has an embedding layer, + one - for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. - attentions (`tuple(torch.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `torch.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. - """ - - projection_state: Optional[torch.Tensor] = None - last_hidden_state: torch.Tensor = None - hidden_states: Optional[tuple[torch.Tensor]] = None - attentions: Optional[tuple[torch.Tensor]] = None + projection_state: torch.Tensor | None = None + last_hidden_state: torch.Tensor | None = None + hidden_states: tuple[torch.Tensor] | None = None + attentions: tuple[torch.Tensor] | None = None class RobertaSeriesConfig(XLMRobertaConfig): @@ -75,7 +54,7 @@ class RobertaSeriesModelWithTransformation(RobertaPreTrainedModel): base_model_prefix = "roberta" config_class = RobertaSeriesConfig - def __init__(self, config): + def __init__(self, config: RobertaSeriesConfig): super().__init__(config) self.roberta = XLMRobertaModel(config) self.transformation = nn.Linear(config.hidden_size, config.project_dim) @@ -87,23 +66,22 @@ def __init__(self, config): def forward( self, - input_ids: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - token_type_ids: Optional[torch.Tensor] = None, - position_ids: Optional[torch.Tensor] = None, - head_mask: Optional[torch.Tensor] = None, - inputs_embeds: Optional[torch.Tensor] = None, - encoder_hidden_states: Optional[torch.Tensor] = None, - encoder_attention_mask: Optional[torch.Tensor] = None, - output_attentions: Optional[bool] = None, - return_dict: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, + input_ids: torch.Tensor | None = None, + attention_mask: torch.Tensor | None = None, + token_type_ids: torch.Tensor | None = None, + position_ids: torch.Tensor | None = None, + head_mask: torch.Tensor | None = None, + inputs_embeds: torch.Tensor | None = None, + encoder_hidden_states: torch.Tensor | None = None, + encoder_attention_mask: torch.Tensor | None = None, + output_attentions: bool | None = None, + return_dict: bool | None = None, + output_hidden_states: bool | None = None, ): - r""" """ return_dict = return_dict if return_dict is not None else self.config.use_return_dict - outputs = self.base_model( + outputs: BaseModelOutputWithPoolingAndCrossAttentions = self.base_model( input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, @@ -138,9 +116,7 @@ def forward( ) -""" -generating concept logic graph -""" +# generating concept logic graph def generate_and_save_concept_graph( concept_combination_x: str, From bb186f029591aaf0202056ace203406e90209260 Mon Sep 17 00:00:00 2001 From: fmuuly <64724985+fmp453@users.noreply.github.com> Date: Sun, 1 Feb 2026 17:34:21 +0900 Subject: [PATCH 10/25] fix --- .../legacy_autogen/legacy_autogen_conversable_agent.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/train_methods/legacy_autogen/legacy_autogen_conversable_agent.py b/train_methods/legacy_autogen/legacy_autogen_conversable_agent.py index ec164d5..92d48af 100644 --- a/train_methods/legacy_autogen/legacy_autogen_conversable_agent.py +++ b/train_methods/legacy_autogen/legacy_autogen_conversable_agent.py @@ -55,6 +55,7 @@ def __init__( self.message = message super().__init__(self.message) + class Agent(Protocol): """(In preview) A protocol for Agent. @@ -256,7 +257,7 @@ def aggregate_summary(usage_summary: dict[str, Any], agent_summary: dict[str, An } class ConversableAgent(LLMAgent): - """(In preview) A class for generic conversable agents which can be configured as assistant or user proxy. + """A class for generic conversable agents which can be configured as assistant or user proxy. After receiving each message, the agent will send a reply to the sender unless the msg is a termination msg. For example, AssistantAgent and UserProxyAgent are subclasses of this class, @@ -3184,8 +3185,7 @@ def register_function( class AssistantAgent(ConversableAgent): - """(In preview) Assistant agent, designed to solve a task with LLM. - + """ AssistantAgent is a subclass of ConversableAgent configured with a default system message. The default system message is designed to solve a task with LLM, including suggesting python code blocks and debugging. From 0356789154233037ba5dcec9d78ee7cad8fc815f Mon Sep 17 00:00:00 2001 From: fmuuly <64724985+fmp453@users.noreply.github.com> Date: Sun, 1 Feb 2026 17:40:35 +0900 Subject: [PATCH 11/25] remove human_input_mode --- .../legacy_autogen_conversable_agent.py | 97 ++----------------- 1 file changed, 9 insertions(+), 88 deletions(-) diff --git a/train_methods/legacy_autogen/legacy_autogen_conversable_agent.py b/train_methods/legacy_autogen/legacy_autogen_conversable_agent.py index 92d48af..7e59338 100644 --- a/train_methods/legacy_autogen/legacy_autogen_conversable_agent.py +++ b/train_methods/legacy_autogen/legacy_autogen_conversable_agent.py @@ -264,7 +264,6 @@ class ConversableAgent(LLMAgent): configured with different default settings. To modify auto reply, override `generate_reply` method. - To disable/enable human response in every turn, set `human_input_mode` to "NEVER" or "ALWAYS". To modify the way to get human input, override `get_human_input` method. To modify the way to execute code blocks, single code block, or function call, override `execute_code_blocks`, `run_code`, and `execute_function` methods respectively. @@ -283,7 +282,6 @@ def __init__( system_message: str | list | None = "You are a helpful AI Assistant.", is_termination_msg: Callable[[dict], bool] | None = None, max_consecutive_auto_reply: int | None = None, - human_input_mode: Literal["ALWAYS", "NEVER", "TERMINATE"] = "TERMINATE", function_map: dict[str, Callable] | None = None, code_execution_config: dict | Literal[False] = False, llm_config: dict | Literal[False] | None = None, @@ -302,15 +300,6 @@ def __init__( max_consecutive_auto_reply (int): the maximum number of consecutive auto replies. default to None (no limit provided, class attribute MAX_CONSECUTIVE_AUTO_REPLY will be used as the limit in this case). When set to 0, no auto reply will be generated. - human_input_mode (str): whether to ask for human inputs every time a message is received. - Possible values are "ALWAYS", "TERMINATE", "NEVER". - (1) When "ALWAYS", the agent prompts for human input every time a message is received. - Under this mode, the conversation stops when the human input is "exit", - or when is_termination_msg is True and there is no human input. - (2) When "TERMINATE", the agent only prompts for human input only when a termination message is received or - the number of auto reply reaches the max_consecutive_auto_reply. - (3) When "NEVER", the agent will never prompt for human input. Under this mode, the conversation stops - when the number of auto reply reaches the max_consecutive_auto_reply or when is_termination_msg is True. function_map (dict[str, callable]): Mapping function names (passed to openai) to callable functions, also used for tool calls. code_execution_config (dict or False): config for the code execution. To disable code execution, set to False. Otherwise, set to a dictionary with the following keys: @@ -377,7 +366,6 @@ def __init__( self.client_cache = None - self.human_input_mode = human_input_mode self._max_consecutive_auto_reply = ( max_consecutive_auto_reply if max_consecutive_auto_reply is not None else self.MAX_CONSECUTIVE_AUTO_REPLY ) @@ -2021,40 +2009,10 @@ def check_termination_and_human_reply( message = messages[-1] reply = "" no_human_input_msg = "" - sender_name = "the sender" if sender is None else sender.name - if self.human_input_mode == "ALWAYS": - reply = self.get_human_input( - f"Replying as {self.name}. Provide feedback to {sender_name}. Press enter to skip and use auto-reply, or type 'exit' to end the conversation: " - ) - no_human_input_msg = "NO HUMAN INPUT RECEIVED." if not reply else "" - # if the human input is empty, and the message is a termination message, then we will terminate the conversation - reply = reply if reply or not self._is_termination_msg(message) else "exit" - else: - if self._consecutive_auto_reply_counter[sender] >= self._max_consecutive_auto_reply_dict[sender]: - if self.human_input_mode == "NEVER": - reply = "exit" - else: - # self.human_input_mode == "TERMINATE": - terminate = self._is_termination_msg(message) - reply = self.get_human_input( - f"Please give feedback to {sender_name}. Press enter or type 'exit' to stop the conversation: " - if terminate - else f"Please give feedback to {sender_name}. Press enter to skip and use auto-reply, or type 'exit' to stop the conversation: " - ) - no_human_input_msg = "NO HUMAN INPUT RECEIVED." if not reply else "" - # if the human input is empty, and the message is a termination message, then we will terminate the conversation - reply = reply if reply or not terminate else "exit" - elif self._is_termination_msg(message): - if self.human_input_mode == "NEVER": - reply = "exit" - else: - # self.human_input_mode == "TERMINATE": - reply = self.get_human_input( - f"Please give feedback to {sender_name}. Press enter or type 'exit' to stop the conversation: " - ) - no_human_input_msg = "NO HUMAN INPUT RECEIVED." if not reply else "" - # if the human input is empty, and the message is a termination message, then we will terminate the conversation - reply = reply or "exit" + if self._consecutive_auto_reply_counter[sender] >= self._max_consecutive_auto_reply_dict[sender]: + reply = "exit" + elif self._is_termination_msg(message): + reply = "exit" # print the no_human_input_msg if no_human_input_msg: @@ -2097,8 +2055,6 @@ def check_termination_and_human_reply( # increment the consecutive_auto_reply_counter self._consecutive_auto_reply_counter[sender] += 1 - if self.human_input_mode != "NEVER": - iostream.print(colored("\n>>>>>>>> USING AUTO REPLY...", "red"), flush=True) return False, None @@ -2134,40 +2090,10 @@ async def a_check_termination_and_human_reply( message = messages[-1] if messages else {} reply = "" no_human_input_msg = "" - sender_name = "the sender" if sender is None else sender.name - if self.human_input_mode == "ALWAYS": - reply = await self.a_get_human_input( - f"Replying as {self.name}. Provide feedback to {sender_name}. Press enter to skip and use auto-reply, or type 'exit' to end the conversation: " - ) - no_human_input_msg = "NO HUMAN INPUT RECEIVED." if not reply else "" - # if the human input is empty, and the message is a termination message, then we will terminate the conversation - reply = reply if reply or not self._is_termination_msg(message) else "exit" - else: - if self._consecutive_auto_reply_counter[sender] >= self._max_consecutive_auto_reply_dict[sender]: - if self.human_input_mode == "NEVER": - reply = "exit" - else: - # self.human_input_mode == "TERMINATE": - terminate = self._is_termination_msg(message) - reply = await self.a_get_human_input( - f"Please give feedback to {sender_name}. Press enter or type 'exit' to stop the conversation: " - if terminate - else f"Please give feedback to {sender_name}. Press enter to skip and use auto-reply, or type 'exit' to stop the conversation: " - ) - no_human_input_msg = "NO HUMAN INPUT RECEIVED." if not reply else "" - # if the human input is empty, and the message is a termination message, then we will terminate the conversation - reply = reply if reply or not terminate else "exit" - elif self._is_termination_msg(message): - if self.human_input_mode == "NEVER": - reply = "exit" - else: - # self.human_input_mode == "TERMINATE": - reply = await self.a_get_human_input( - f"Please give feedback to {sender_name}. Press enter or type 'exit' to stop the conversation: " - ) - no_human_input_msg = "NO HUMAN INPUT RECEIVED." if not reply else "" - # if the human input is empty, and the message is a termination message, then we will terminate the conversation - reply = reply or "exit" + if self._consecutive_auto_reply_counter[sender] >= self._max_consecutive_auto_reply_dict[sender]: + reply = "exit" + elif self._is_termination_msg(message): + reply = "exit" # print the no_human_input_msg if no_human_input_msg: @@ -2210,8 +2136,6 @@ async def a_check_termination_and_human_reply( # increment the consecutive_auto_reply_counter self._consecutive_auto_reply_counter[sender] += 1 - if self.human_input_mode != "NEVER": - iostream.print(colored("\n>>>>>>>> USING AUTO REPLY...", "red"), flush=True) return False, None @@ -3189,7 +3113,6 @@ class AssistantAgent(ConversableAgent): AssistantAgent is a subclass of ConversableAgent configured with a default system message. The default system message is designed to solve a task with LLM, including suggesting python code blocks and debugging. - `human_input_mode` is default to "NEVER" and `code_execution_config` is default to False. This agent doesn't execute code by default, and expects the user to execute the code. """ @@ -3216,7 +3139,6 @@ def __init__( llm_config: dict | Literal[False] | None = None, is_termination_msg: Callable[[dict], bool] | None = None, max_consecutive_auto_reply: int | None = None, - human_input_mode: Literal["ALWAYS", "NEVER", "TERMINATE"] = "NEVER", description: str | None = None, **kwargs, ): @@ -3233,7 +3155,7 @@ def __init__( The dict can contain the following keys: "content", "role", "name", "function_call". max_consecutive_auto_reply (int): the maximum number of consecutive auto replies. default to None (no limit provided, class attribute MAX_CONSECUTIVE_AUTO_REPLY will be used as the limit in this case). - The limit only plays a role when human_input_mode is not "ALWAYS". + The limit only plays a role. **kwargs (dict): Please refer to other kwargs in [ConversableAgent](conversable_agent#__init__). """ @@ -3242,7 +3164,6 @@ def __init__( system_message, is_termination_msg, max_consecutive_auto_reply, - human_input_mode, llm_config=llm_config, description=description, **kwargs, From 5732ad83d580ef8db81d5d7cc5de69947a125fb3 Mon Sep 17 00:00:00 2001 From: fmuuly <64724985+fmp453@users.noreply.github.com> Date: Sun, 1 Feb 2026 17:40:48 +0900 Subject: [PATCH 12/25] remove human_input_mode --- train_methods/utils_cogfd.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/train_methods/utils_cogfd.py b/train_methods/utils_cogfd.py index 7d1411b..04031fd 100644 --- a/train_methods/utils_cogfd.py +++ b/train_methods/utils_cogfd.py @@ -180,7 +180,6 @@ def generate_and_save_concept_graph( ''', llm_config={"config_list": [{"model": "gpt-4o", "api_key": OPENAI_API_KEY, "base_url": base_url}]}, is_termination_msg=lambda msg: "the answer is correct!" in msg.get("content", "").lower(), - human_input_mode="NEVER", ) reviewer = AssistantAgent( @@ -193,7 +192,6 @@ def generate_and_save_concept_graph( If there are some mistakes in the generated graph, please point them out and tell the Generator how to fix them. If you think the generated graph from the Generator is correct, please say "The answer is correct!" and close the chat. You must check carefully!!! """, - human_input_mode="NEVER", ) group_chat_with_introductions = GroupChat( From b4182d499dabbfac999a2c60da9bb5fa9226d40e Mon Sep 17 00:00:00 2001 From: fmuuly <64724985+fmp453@users.noreply.github.com> Date: Sun, 1 Feb 2026 17:44:29 +0900 Subject: [PATCH 13/25] remove arguments from AssistantAgent --- .../legacy_autogen_conversable_agent.py | 59 ++----------------- 1 file changed, 5 insertions(+), 54 deletions(-) diff --git a/train_methods/legacy_autogen/legacy_autogen_conversable_agent.py b/train_methods/legacy_autogen/legacy_autogen_conversable_agent.py index 7e59338..55b8131 100644 --- a/train_methods/legacy_autogen/legacy_autogen_conversable_agent.py +++ b/train_methods/legacy_autogen/legacy_autogen_conversable_agent.py @@ -286,7 +286,6 @@ def __init__( code_execution_config: dict | Literal[False] = False, llm_config: dict | Literal[False] | None = None, default_auto_reply: str | dict = "", - description: str | None = None, chat_messages: dict[Agent, list[dict]] | None = None, silent: bool | None = None, ): @@ -323,8 +322,6 @@ def __init__( To disable llm-based auto reply, set to False. When set to None, will use self.DEFAULT_CONFIG, which defaults to False. default_auto_reply (str or dict): default auto reply when no code execution or llm-based reply is generated. - description (str): a short description of the agent. This description is used by other agents - (e.g. the GroupChatManager) to decide when to call upon this agent. (Default: system_message) chat_messages (dict or None): the previous chat messages that this agent had in the past with other agents. Can be used to give the agent a memory by providing the chat history. This will allow the agent to resume previous had conversations. Defaults to an empty chat history. @@ -345,7 +342,7 @@ def __init__( self._oai_messages = chat_messages self._oai_system_message = [{"content": system_message, "role": "system"}] - self._description = description if description is not None else system_message + self._description = system_message self._is_termination_msg = ( is_termination_msg if is_termination_msg is not None @@ -2780,24 +2777,6 @@ def register_for_llm( Returns: The decorator for registering a function to be used by an agent. - - Examples: - ``` - @user_proxy.register_for_execution() - @agent2.register_for_llm() - @agent1.register_for_llm(description="This is a very useful function") - def my_function(a: Annotated[str, "description of a parameter"] = "a", b: int, c=3.14) -> str: - return a + str(b * c) - ``` - - For Azure OpenAI versions prior to 2023-12-01-preview, set `api_style` - to `"function"` if `"tool"` doesn't work: - ``` - @agent2.register_for_llm(api_style="function") - def my_function(a: Annotated[str, "description of a parameter"] = "a", b: int, c=3.14) -> str: - return a + str(b * c) - ``` - """ def _decorator(func: F) -> F: @@ -2860,15 +2839,6 @@ def register_for_execution( Returns: The decorator for registering a function to be used by an agent. - Examples: - ``` - @user_proxy.register_for_execution() - @agent2.register_for_llm() - @agent1.register_for_llm(description="This is a very useful function") - def my_function(a: Annotated[str, "description of a parameter"] = "a", b: int, c=3.14): - return a + str(b * c) - ``` - """ def _decorator(func: F) -> F: @@ -3137,10 +3107,6 @@ def __init__( name: str, system_message: str | None = DEFAULT_SYSTEM_MESSAGE, llm_config: dict | Literal[False] | None = None, - is_termination_msg: Callable[[dict], bool] | None = None, - max_consecutive_auto_reply: int | None = None, - description: str | None = None, - **kwargs, ): """ Args: @@ -3150,27 +3116,12 @@ def __init__( llm_config (dict or False or None): llm inference configuration. Please refer to [OpenAIWrapper.create](/docs/reference/oai/client#create) for available options. - is_termination_msg (function): a function that takes a message in the form of a dictionary - and returns a boolean value indicating if this received message is a termination message. - The dict can contain the following keys: "content", "role", "name", "function_call". - max_consecutive_auto_reply (int): the maximum number of consecutive auto replies. - default to None (no limit provided, class attribute MAX_CONSECUTIVE_AUTO_REPLY will be used as the limit in this case). - The limit only plays a role. - **kwargs (dict): Please refer to other kwargs in - [ConversableAgent](conversable_agent#__init__). """ super().__init__( - name, - system_message, - is_termination_msg, - max_consecutive_auto_reply, + name=name, + system_message=system_message, llm_config=llm_config, - description=description, - **kwargs, ) - # Update the provided description if None, and we are using the default system_message, - # then use the default description. - if description is None: - if system_message == self.DEFAULT_SYSTEM_MESSAGE: - self.description = self.DEFAULT_DESCRIPTION + if system_message == self.DEFAULT_SYSTEM_MESSAGE: + self.description = self.DEFAULT_DESCRIPTION From 6872c8519dfdcba8f2370555fad3d72e62603247 Mon Sep 17 00:00:00 2001 From: fmuuly <64724985+fmp453@users.noreply.github.com> Date: Sun, 1 Feb 2026 17:50:03 +0900 Subject: [PATCH 14/25] remove code_exec --- .../legacy_autogen_conversable_agent.py | 204 +----------------- 1 file changed, 3 insertions(+), 201 deletions(-) diff --git a/train_methods/legacy_autogen/legacy_autogen_conversable_agent.py b/train_methods/legacy_autogen/legacy_autogen_conversable_agent.py index 55b8131..16e4852 100644 --- a/train_methods/legacy_autogen/legacy_autogen_conversable_agent.py +++ b/train_methods/legacy_autogen/legacy_autogen_conversable_agent.py @@ -283,11 +283,8 @@ def __init__( is_termination_msg: Callable[[dict], bool] | None = None, max_consecutive_auto_reply: int | None = None, function_map: dict[str, Callable] | None = None, - code_execution_config: dict | Literal[False] = False, llm_config: dict | Literal[False] | None = None, - default_auto_reply: str | dict = "", chat_messages: dict[Agent, list[dict]] | None = None, - silent: bool | None = None, ): """ Args: @@ -300,39 +297,16 @@ def __init__( default to None (no limit provided, class attribute MAX_CONSECUTIVE_AUTO_REPLY will be used as the limit in this case). When set to 0, no auto reply will be generated. function_map (dict[str, callable]): Mapping function names (passed to openai) to callable functions, also used for tool calls. - code_execution_config (dict or False): config for the code execution. - To disable code execution, set to False. Otherwise, set to a dictionary with the following keys: - - work_dir (Optional, str): The working directory for the code execution. - If None, a default working directory will be used. - The default working directory is the "extensions" directory under - "path_to_autogen". - - use_docker (Optional, list, str or bool): The docker image to use for code execution. - Default is True, which means the code will be executed in a docker container. A default list of images will be used. - If a list or a str of image name(s) is provided, the code will be executed in a docker container - with the first image successfully pulled. - If False, the code will be executed in the current environment. - We strongly recommend using docker for code execution. - - timeout (Optional, int): The maximum execution time in seconds. - - last_n_messages (Experimental, int or str): The number of messages to look back for code execution. - If set to 'auto', it will scan backwards through all messages arriving since the agent last spoke, which is typically the last time execution was attempted. (Default: auto) llm_config (dict or False or None): llm inference configuration. Please refer to [OpenAIWrapper.create](/docs/reference/oai/client#create) for available options. When using OpenAI or Azure OpenAI endpoints, please specify a non-empty 'model' either in `llm_config` or in each config of 'config_list' in `llm_config`. To disable llm-based auto reply, set to False. When set to None, will use self.DEFAULT_CONFIG, which defaults to False. - default_auto_reply (str or dict): default auto reply when no code execution or llm-based reply is generated. chat_messages (dict or None): the previous chat messages that this agent had in the past with other agents. Can be used to give the agent a memory by providing the chat history. This will allow the agent to resume previous had conversations. Defaults to an empty chat history. - silent (bool or None): (Experimental) whether to print the message sent. If None, will use the value of - silent in each function. """ - # we change code_execution_config below and we have to make sure we don't change the input - # in case of UserProxyAgent, without this we could even change the default value {} - code_execution_config = ( - code_execution_config.copy() if hasattr(code_execution_config, "copy") else code_execution_config - ) self._name = name # a dictionary of conversations, default value is list @@ -348,7 +322,6 @@ def __init__( if is_termination_msg is not None else (lambda x: content_str(x.get("content")) == "TERMINATE") ) - self.silent = silent # Take a copy to avoid modifying the given dict if isinstance(llm_config, dict): try: @@ -373,60 +346,12 @@ def __init__( if function_map is None else {name: callable for name, callable in function_map.items() if self._assert_valid_name(name)} ) - self._default_auto_reply = default_auto_reply self._reply_func_list = [] self._human_input = [] self.reply_at_receive = defaultdict(bool) self.register_reply([Agent, None], ConversableAgent.generate_oai_reply) self.register_reply([Agent, None], ConversableAgent.a_generate_oai_reply, ignore_async_in_sync_chat=True) - # Setting up code execution. - # Do not register code execution reply if code execution is disabled. - if code_execution_config is not False: - # If code_execution_config is None, set it to an empty dict. - if code_execution_config is None: - warnings.warn( - "Using None to signal a default code_execution_config is deprecated. " - "Use {} to use default or False to disable code execution.", - stacklevel=2, - ) - code_execution_config = {} - if not isinstance(code_execution_config, dict): - raise ValueError("code_execution_config must be a dict or False.") - - # We have got a valid code_execution_config. - self._code_execution_config = code_execution_config - - if self._code_execution_config.get("executor") is not None: - if "use_docker" in self._code_execution_config: - raise ValueError( - "'use_docker' in code_execution_config is not valid when 'executor' is set. Use the appropriate arg in the chosen executor instead." - ) - - if "work_dir" in self._code_execution_config: - raise ValueError( - "'work_dir' in code_execution_config is not valid when 'executor' is set. Use the appropriate arg in the chosen executor instead." - ) - - if "timeout" in self._code_execution_config: - raise ValueError( - "'timeout' in code_execution_config is not valid when 'executor' is set. Use the appropriate arg in the chosen executor instead." - ) - - # Use the new code executor. - self._code_executor = CodeExecutorFactory.create(self._code_execution_config) - self.register_reply([Agent, None], ConversableAgent._generate_code_execution_reply_using_executor) - else: - # Legacy code execution using code_utils. - use_docker = self._code_execution_config.get("use_docker", None) - use_docker = decide_use_docker(use_docker) - check_can_use_docker_or_throw(use_docker) - self._code_execution_config["use_docker"] = use_docker - self.register_reply([Agent, None], ConversableAgent.generate_code_execution_reply) - else: - # Code execution is disabled. - self._code_execution_config = False - self.register_reply([Agent, None], ConversableAgent.generate_tool_calls_reply) self.register_reply([Agent, None], ConversableAgent.a_generate_tool_calls_reply, ignore_async_in_sync_chat=True) self.register_reply([Agent, None], ConversableAgent.generate_function_call_reply) @@ -780,7 +705,7 @@ def use_docker(self) -> bool | str | None: """Bool value of whether to use docker to execute the code, or str value of the docker image name to use, or None when code execution is disabled. """ - return None if self._code_execution_config is False else self._code_execution_config.get("use_docker") + return None @staticmethod def _message_to_dict(message: dict | str) -> dict: @@ -1691,127 +1616,6 @@ def _generate_oai_reply( ), ) - def _generate_code_execution_reply_using_executor( - self, - messages: list[dict] | None = None, - sender: Agent | None = None, - config: dict | Literal[False] | None = None, - ): - """Generate a reply using code executor.""" - iostream = IOStream.get_default() - - if config is not None: - raise ValueError("config is not supported for _generate_code_execution_reply_using_executor.") - if self._code_execution_config is False: - return False, None - if messages is None: - messages = self._oai_messages[sender] - last_n_messages = self._code_execution_config.get("last_n_messages", "auto") - - if not (isinstance(last_n_messages, (int, float)) and last_n_messages >= 0) and last_n_messages != "auto": - raise ValueError("last_n_messages must be either a non-negative integer, or the string 'auto'.") - - num_messages_to_scan = last_n_messages - if last_n_messages == "auto": - # Find when the agent last spoke - num_messages_to_scan = 0 - for message in reversed(messages): - if "role" not in message: - break - elif message["role"] != "user": - break - else: - num_messages_to_scan += 1 - num_messages_to_scan = min(len(messages), num_messages_to_scan) - messages_to_scan = messages[-num_messages_to_scan:] - - # iterate through the last n messages in reverse - # if code blocks are found, execute the code blocks and return the output - # if no code blocks are found, continue - for message in reversed(messages_to_scan): - if not message["content"]: - continue - code_blocks = self._code_executor.code_extractor.extract_code_blocks(message["content"]) - if len(code_blocks) == 0: - continue - - num_code_blocks = len(code_blocks) - if num_code_blocks == 1: - iostream.print( - colored( - f"\n>>>>>>>> EXECUTING CODE BLOCK (inferred language is {code_blocks[0].language})...", - "red", - ), - flush=True, - ) - else: - iostream.print( - colored( - f"\n>>>>>>>> EXECUTING {num_code_blocks} CODE BLOCKS (inferred languages are [{', '.join([x.language for x in code_blocks])}])...", - "red", - ), - flush=True, - ) - - # found code blocks, execute code. - code_result = self._code_executor.execute_code_blocks(code_blocks) - exitcode2str = "execution succeeded" if code_result.exit_code == 0 else "execution failed" - return True, f"exitcode: {code_result.exit_code} ({exitcode2str})\nCode output: {code_result.output}" - - return False, None - - def generate_code_execution_reply( - self, - messages: list[dict] | None = None, - sender: Agent | None = None, - config: dict | Literal[False] | None = None, - ): - """Generate a reply using code execution.""" - code_execution_config = config if config is not None else self._code_execution_config - if code_execution_config is False: - return False, None - if messages is None: - messages = self._oai_messages[sender] - last_n_messages = code_execution_config.pop("last_n_messages", "auto") - - if not (isinstance(last_n_messages, (int, float)) and last_n_messages >= 0) and last_n_messages != "auto": - raise ValueError("last_n_messages must be either a non-negative integer, or the string 'auto'.") - - messages_to_scan = last_n_messages - if last_n_messages == "auto": - # Find when the agent last spoke - messages_to_scan = 0 - for i in range(len(messages)): - message = messages[-(i + 1)] - if "role" not in message: - break - elif message["role"] != "user": - break - else: - messages_to_scan += 1 - - # iterate through the last n messages in reverse - # if code blocks are found, execute the code blocks and return the output - # if no code blocks are found, continue - for i in range(min(len(messages), messages_to_scan)): - message = messages[-(i + 1)] - if not message["content"]: - continue - code_blocks = extract_code(message["content"]) - if len(code_blocks) == 1 and code_blocks[0][0] == UNKNOWN: - continue - - # found code blocks, execute code and push "last_n_messages" back - exitcode, logs = self.execute_code_blocks(code_blocks) - code_execution_config["last_n_messages"] = last_n_messages - exitcode2str = "execution succeeded" if exitcode == 0 else "execution failed" - return True, f"exitcode: {exitcode} ({exitcode2str})\nCode output: {logs}" - - # no code blocks are found, push last_n_messages back and return. - code_execution_config["last_n_messages"] = last_n_messages - - return False, None - def generate_function_call_reply( self, messages: list[dict] | None = None, @@ -2151,7 +1955,6 @@ def generate_reply( 1. check_termination_and_human_reply 2. generate_function_call_reply (deprecated in favor of tool_calls) 3. generate_tool_calls_reply - 4. generate_code_execution_reply 5. generate_oai_reply Every function returns a tuple (final, reply). When a function returns final=False, the next function will be checked. @@ -2194,7 +1997,7 @@ def generate_reply( final, reply = reply_func(self, messages=messages, sender=sender, config=reply_func_tuple["config"]) if final: return reply - return self._default_auto_reply + return "" async def a_generate_reply( self, @@ -2211,7 +2014,6 @@ async def a_generate_reply( 1. check_termination_and_human_reply 2. generate_function_call_reply 3. generate_tool_calls_reply - 4. generate_code_execution_reply 5. generate_oai_reply Every function returns a tuple (final, reply). When a function returns final=False, the next function will be checked. @@ -2258,7 +2060,7 @@ async def a_generate_reply( final, reply = reply_func(self, messages=messages, sender=sender, config=reply_func_tuple["config"]) if final: return reply - return self._default_auto_reply + return "" def _match_trigger(self, trigger: None | str | type | Agent | Callable | list, sender: Agent | None) -> bool: """Check if the sender matches the trigger. From f6cd4ff596671f8bb21716311f17866f99cc7eae Mon Sep 17 00:00:00 2001 From: fmuuly <64724985+fmp453@users.noreply.github.com> Date: Sun, 1 Feb 2026 18:00:16 +0900 Subject: [PATCH 15/25] remove unused args --- train_methods/legacy_autogen/coding.py | 31 -- .../legacy_autogen_conversable_agent.py | 277 +----------------- train_methods/legacy_autogen/utils.py | 115 -------- 3 files changed, 10 insertions(+), 413 deletions(-) diff --git a/train_methods/legacy_autogen/coding.py b/train_methods/legacy_autogen/coding.py index fb82d42..a6c9cf7 100644 --- a/train_methods/legacy_autogen/coding.py +++ b/train_methods/legacy_autogen/coding.py @@ -815,34 +815,3 @@ def _execute_code_dont_check_setup(self, code_blocks: list[CodeBlock]) -> Comman def restart(self) -> None: """(Experimental) Restart the code executor.""" warnings.warn("Restarting local command line code executor is not supported. No action is taken.") - -class CodeExecutorFactory: - """(Experimental) A factory class for creating code executors.""" - - @staticmethod - def create(code_execution_config: CodeExecutionConfig) -> CodeExecutor: - """(Experimental) Get a code executor based on the code execution config. - - Args: - code_execution_config (Dict): The code execution config, - which is a dictionary that must contain the key "executor". - The value of the key "executor" can be either a string - or an instance of CodeExecutor, in which case the code - executor is returned directly. - - Returns: - CodeExecutor: The code executor. - - Raises: - ValueError: If the code executor is unknown or not specified. - """ - executor = code_execution_config.get("executor") - if isinstance(executor, CodeExecutor): - # If the executor is already an instance of CodeExecutor, return it. - return executor - if executor == "ipython-embedded": - return EmbeddedIPythonCodeExecutor(**code_execution_config.get("ipython-embedded", {})) - elif executor == "commandline-local": - return LocalCommandLineCodeExecutor(**code_execution_config.get("commandline-local", {})) - else: - raise ValueError(f"Unknown code executor {executor}") diff --git a/train_methods/legacy_autogen/legacy_autogen_conversable_agent.py b/train_methods/legacy_autogen/legacy_autogen_conversable_agent.py index 16e4852..980900c 100644 --- a/train_methods/legacy_autogen/legacy_autogen_conversable_agent.py +++ b/train_methods/legacy_autogen/legacy_autogen_conversable_agent.py @@ -3,7 +3,6 @@ import copy import functools import inspect -import json import re import warnings from collections import defaultdict @@ -16,15 +15,10 @@ from train_methods.legacy_autogen.cache import AbstractCache from train_methods.legacy_autogen.chat import ChatResult, a_initiate_chats, initiate_chats, _post_process_carryover_item, consolidate_chat_info from train_methods.legacy_autogen.client import ModelClient, OpenAIWrapper -from train_methods.legacy_autogen.coding import CodeExecutor, CodeExecutorFactory +from train_methods.legacy_autogen.coding import CodeExecutor from train_methods.legacy_autogen.stream import IOStream from train_methods.legacy_autogen.utils import ( - check_can_use_docker_or_throw, content_str, - decide_use_docker, - execute_code, - extract_code, - infer_lang, load_basemodels_if_needed, serialize_to_str, get_function_schema @@ -262,11 +256,6 @@ class ConversableAgent(LLMAgent): After receiving each message, the agent will send a reply to the sender unless the msg is a termination msg. For example, AssistantAgent and UserProxyAgent are subclasses of this class, configured with different default settings. - - To modify auto reply, override `generate_reply` method. - To modify the way to get human input, override `get_human_input` method. - To modify the way to execute code blocks, single code block, or function call, override `execute_code_blocks`, - `run_code`, and `execute_function` methods respectively. """ DEFAULT_CONFIG = False # False or dict, the default config for llm inference @@ -281,10 +270,7 @@ def __init__( name: str, system_message: str | list | None = "You are a helpful AI Assistant.", is_termination_msg: Callable[[dict], bool] | None = None, - max_consecutive_auto_reply: int | None = None, - function_map: dict[str, Callable] | None = None, llm_config: dict | Literal[False] | None = None, - chat_messages: dict[Agent, list[dict]] | None = None, ): """ Args: @@ -296,24 +282,16 @@ def __init__( max_consecutive_auto_reply (int): the maximum number of consecutive auto replies. default to None (no limit provided, class attribute MAX_CONSECUTIVE_AUTO_REPLY will be used as the limit in this case). When set to 0, no auto reply will be generated. - function_map (dict[str, callable]): Mapping function names (passed to openai) to callable functions, also used for tool calls. llm_config (dict or False or None): llm inference configuration. Please refer to [OpenAIWrapper.create](/docs/reference/oai/client#create) for available options. When using OpenAI or Azure OpenAI endpoints, please specify a non-empty 'model' either in `llm_config` or in each config of 'config_list' in `llm_config`. To disable llm-based auto reply, set to False. When set to None, will use self.DEFAULT_CONFIG, which defaults to False. - chat_messages (dict or None): the previous chat messages that this agent had in the past with other agents. - Can be used to give the agent a memory by providing the chat history. This will allow the agent to - resume previous had conversations. Defaults to an empty chat history. """ self._name = name - # a dictionary of conversations, default value is list - if chat_messages is None: - self._oai_messages = defaultdict(list) - else: - self._oai_messages = chat_messages + self._oai_messages = defaultdict(list) self._oai_system_message = [{"content": system_message, "role": "system"}] self._description = system_message @@ -324,28 +302,14 @@ def __init__( ) # Take a copy to avoid modifying the given dict if isinstance(llm_config, dict): - try: - llm_config = copy.deepcopy(llm_config) - except TypeError as e: - raise TypeError( - "Please implement __deepcopy__ method for each value class in llm_config to support deepcopy." - " Refer to the docs for more details: https://microsoft.github.io/autogen/docs/topics/llm_configuration#adding-http-client-in-llm_config-for-proxy" - ) from e + llm_config = copy.deepcopy(llm_config) self._validate_llm_config(llm_config) self.client_cache = None - - self._max_consecutive_auto_reply = ( - max_consecutive_auto_reply if max_consecutive_auto_reply is not None else self.MAX_CONSECUTIVE_AUTO_REPLY - ) + self._max_consecutive_auto_reply = self.MAX_CONSECUTIVE_AUTO_REPLY self._consecutive_auto_reply_counter = defaultdict(int) self._max_consecutive_auto_reply_dict = defaultdict(self.max_consecutive_auto_reply) - self._function_map = ( - {} - if function_map is None - else {name: callable for name, callable in function_map.items() if self._assert_valid_name(name)} - ) self._reply_func_list = [] self._human_input = [] self.reply_at_receive = defaultdict(bool) @@ -390,7 +354,7 @@ def _validate_llm_config(self, llm_config): @staticmethod def _is_silent(agent: Agent, silent: bool | None = False) -> bool: - return agent.silent if agent.silent is not None else silent + return silent @property def name(self) -> str: @@ -700,13 +664,6 @@ def last_message(self, agent: Agent | None = None) -> dict | None: ) return self._oai_messages[agent][-1] - @property - def use_docker(self) -> bool | str | None: - """Bool value of whether to use docker to execute the code, - or str value of the docker image name to use, or None when code execution is disabled. - """ - return None - @staticmethod def _message_to_dict(message: dict | str) -> dict: """Convert a message to a dictionary. @@ -1635,7 +1592,7 @@ def generate_function_call_reply( message = messages[-1] if "function_call" in message and message["function_call"]: func_call = message["function_call"] - func = self._function_map.get(func_call.get("name", None), None) + func = None if inspect.iscoroutinefunction(func): try: # get the running loop if it was already created @@ -1673,12 +1630,7 @@ async def a_generate_function_call_reply( message = messages[-1] func_call = message.get("function_call") if func_call: - func_name = func_call.get("name", "") - func = self._function_map.get(func_name, None) - if func and inspect.iscoroutinefunction(func): - _, func_return = await self.a_execute_function(func_call) - else: - _, func_return = self.execute_function(func_call) + _, func_return = self.execute_function(func_call) return True, func_return return False, None @@ -1701,7 +1653,7 @@ def generate_tool_calls_reply( tool_returns = [] for tool_call in message.get("tool_calls", []): function_call = tool_call.get("function", {}) - func = self._function_map.get(function_call.get("name", None), None) + func = None if inspect.iscoroutinefunction(func): try: # get the running loop if it was already created @@ -2128,66 +2080,6 @@ async def a_get_human_input(self, prompt: str) -> str: reply = await loop.run_in_executor(None, functools.partial(self.get_human_input, prompt)) return reply - def run_code(self, code, **kwargs): - """Run the code and return the result. - - Override this function to modify the way to run the code. - Args: - code (str): the code to be executed. - **kwargs: other keyword arguments. - - Returns: - A tuple of (exitcode, logs, image). - exitcode (int): the exit code of the code execution. - logs (str): the logs of the code execution. - image (str or None): the docker image used for the code execution. - """ - return execute_code(code, **kwargs) - - def execute_code_blocks(self, code_blocks): - """Execute the code blocks and return the result.""" - iostream = IOStream.get_default() - - logs_all = "" - for i, code_block in enumerate(code_blocks): - lang, code = code_block - if not lang: - lang = infer_lang(code) - iostream.print( - colored( - f"\n>>>>>>>> EXECUTING CODE BLOCK {i} (inferred language is {lang})...", - "red", - ), - flush=True, - ) - if lang in ["bash", "shell", "sh"]: - exitcode, logs, image = self.run_code(code, lang=lang, **self._code_execution_config) - elif lang in PYTHON_VARIANTS: - if code.startswith("# filename: "): - filename = code[11 : code.find("\n")].strip() - else: - filename = None - exitcode, logs, image = self.run_code( - code, - lang="python", - filename=filename, - **self._code_execution_config, - ) - else: - # In case the language is not supported, we return an error message. - exitcode, logs, image = ( - 1, - f"unknown language {lang}", - None, - ) - # raise NotImplementedError - if image is not None: - self._code_execution_config["use_docker"] = image - logs_all += "\n" + logs - if exitcode != 0: - return exitcode, logs_all - return exitcode, logs_all - @staticmethod def _format_json_str(jstr): """Remove newlines outside of quotes, and handle JSON escape sequences. @@ -2234,40 +2126,11 @@ def execute_function(self, func_call, verbose: bool = False) -> tuple[bool, dict "function_call" deprecated as of [OpenAI API v1.1.0](https://github.com/openai/openai-python/releases/tag/v1.1.0) See https://platform.openai.com/docs/api-reference/chat/create#chat-create-function_call """ - iostream = IOStream.get_default() func_name = func_call.get("name", "") - func = self._function_map.get(func_name, None) is_exec_success = False - if func is not None: - # Extract arguments from a json-like string and put it into a dict. - input_string = self._format_json_str(func_call.get("arguments", "{}")) - try: - arguments = json.loads(input_string) - except json.JSONDecodeError as e: - arguments = None - content = f"Error: {e}\n The argument must be in JSON format." - - # Try to execute the function - if arguments is not None: - iostream.print( - colored(f"\n>>>>>>>> EXECUTING FUNCTION {func_name}...", "magenta"), - flush=True, - ) - try: - content = func(**arguments) - is_exec_success = True - except Exception as e: - content = f"Error: {e}" - else: - content = f"Error: Function {func_name} not found." - - if verbose: - iostream.print( - colored(f"\nInput arguments: {arguments}\nOutput:\n{content}", "magenta"), - flush=True, - ) + content = f"Error: Function {func_name} not found." return is_exec_success, { "name": func_name, @@ -2291,38 +2154,11 @@ async def a_execute_function(self, func_call): "function_call" deprecated as of [OpenAI API v1.1.0](https://github.com/openai/openai-python/releases/tag/v1.1.0) See https://platform.openai.com/docs/api-reference/chat/create#chat-create-function_call """ - iostream = IOStream.get_default() func_name = func_call.get("name", "") - func = self._function_map.get(func_name, None) is_exec_success = False - if func is not None: - # Extract arguments from a json-like string and put it into a dict. - input_string = self._format_json_str(func_call.get("arguments", "{}")) - try: - arguments = json.loads(input_string) - except json.JSONDecodeError as e: - arguments = None - content = f"Error: {e}\n The argument must be in JSON format." - - # Try to execute the function - if arguments is not None: - iostream.print( - colored(f"\n>>>>>>>> EXECUTING ASYNC FUNCTION {func_name}...", "magenta"), - flush=True, - ) - try: - if inspect.iscoroutinefunction(func): - content = await func(**arguments) - else: - # Fallback to sync function if the function is not async - content = func(**arguments) - is_exec_success = True - except Exception as e: - content = f"Error: {e}" - else: - content = f"Error: Function {func_name} not found." + content = f"Error: Function {func_name} not found." return is_exec_success, { "name": func_name, @@ -2408,21 +2244,6 @@ async def a_generate_init_message(self, message: dict | str | None, **kwargs) -> return self._handle_carryover(message, kwargs) - def register_function(self, function_map: dict[str, Callable | None]): - """Register functions to the agent. - - Args: - function_map: a dictionary mapping function names to functions. if function_map[name] is None, the function will be removed from the function_map. - """ - for name, func in function_map.items(): - self._assert_valid_name(name) - if func is None and name not in self._function_map.keys(): - warnings.warn(f"The function {name} to remove doesn't exist", name) - if name in self._function_map: - warnings.warn(f"Function '{name}' is being overridden.", UserWarning) - self._function_map.update(function_map) - self._function_map = {k: v for k, v in self._function_map.items() if v is not None} - def update_function_signature(self, func_sig: str | dict, is_remove: None): """update a function_signature in the LLM configuration for function_call. @@ -2512,16 +2333,6 @@ def update_tool_signature(self, tool_sig: str | dict, is_remove: None): self.client = OpenAIWrapper(**self.llm_config) - def can_execute_function(self, name: list[str] | str) -> bool: - """Whether the agent can execute the function.""" - names = name if isinstance(name, list) else [name] - return all([n in self._function_map for n in names]) - - @property - def function_map(self) -> dict[str, Callable]: - """Return the function map.""" - return self._function_map - def _wrap_function(self, func: F) -> F: """Wrap the function to dump the return value to json. @@ -2627,47 +2438,6 @@ def _decorator(func: F) -> F: return _decorator - def register_for_execution( - self, - name: str | None = None, - ) -> Callable[[F], F]: - """Decorator factory for registering a function to be executed by an agent. - - It's return value is used to decorate a function to be registered to the agent. - - Args: - name (optional(str)): name of the function. If None, the function name will be used (default: None). - - Returns: - The decorator for registering a function to be used by an agent. - - """ - - def _decorator(func: F) -> F: - """Decorator for registering a function to be used by an agent. - - Args: - func: the function to be registered. - - Returns: - The function to be registered, with the _description attribute set to the function description. - - Raises: - ValueError: if the function description is not provided and not propagated by a previous decorator. - - """ - # name can be overwritten by the parameter, by default it is the same as function name - if name: - func._name = name - elif not hasattr(func, "_name"): - func._name = func.__name__ - - self.register_function({func._name: self._wrap_function(func)}) - - return func - - return _decorator - def register_model_client(self, model_client_cls: ModelClient, **kwargs): """Register a model client. @@ -2853,33 +2623,6 @@ def get_total_usage(self) -> dict[str, int] | None: return self.client.total_usage_summary -def register_function( - f: Callable[..., Any], - *, - caller: ConversableAgent, - executor: ConversableAgent, - name: str | None = None, - description: str, -) -> None: - """Register a function to be proposed by an agent and executed for an executor. - - This function can be used instead of function decorators `@ConversationAgent.register_for_llm` and - `@ConversationAgent.register_for_execution`. - - Args: - f: the function to be registered. - caller: the agent calling the function, typically an instance of ConversableAgent. - executor: the agent executing the function, typically an instance of UserProxy. - name: name of the function. If None, the function name will be used (default: None). - description: description of the function. The description is used by LLM to decode whether the function - is called. Make sure the description is properly describing what the function does or it might not be - called by LLM when needed. - - """ - f = caller.register_for_llm(name=name, description=description)(f) - executor.register_for_execution(name=name)(f) - - class AssistantAgent(ConversableAgent): """ AssistantAgent is a subclass of ConversableAgent configured with a default system message. diff --git a/train_methods/legacy_autogen/utils.py b/train_methods/legacy_autogen/utils.py index 6bc0bd2..dbfd697 100644 --- a/train_methods/legacy_autogen/utils.py +++ b/train_methods/legacy_autogen/utils.py @@ -84,78 +84,6 @@ def content_str(content: str | list[UserMessageTextContentPart | UserMessageImag return rst -def infer_lang(code: str) -> str: - """infer the language for the code. - TODO: make it robust. - """ - if code.startswith("python ") or code.startswith("pip") or code.startswith("python3 "): - return "sh" - - # check if code is a valid python code - try: - compile(code, "test", "exec") - return "python" - except SyntaxError: - # not a valid python code - return UNKNOWN - - -# TODO: In the future move, to better support https://spec.commonmark.org/0.30/#fenced-code-blocks -# perhaps by using a full Markdown parser. -def extract_code( - text: str | list, pattern: str = CODE_BLOCK_PATTERN, detect_single_line_code: bool = False -) -> list[tuple[str, str]]: - """Extract code from a text. - - Args: - text (str or List): The content to extract code from. The content can be - a string or a list, as returned by standard GPT or multimodal GPT. - pattern (str, optional): The regular expression pattern for finding the - code block. Defaults to CODE_BLOCK_PATTERN. - detect_single_line_code (bool, optional): Enable the new feature for - extracting single line code. Defaults to False. - - Returns: - list: A list of tuples, each containing the language and the code. - If there is no code block in the input text, the language would be "unknown". - If there is code block but the language is not specified, the language would be "". - """ - text = content_str(text) - if not detect_single_line_code: - match = re.findall(pattern, text, flags=re.DOTALL) - return match if match else [(UNKNOWN, text)] - - # Extract both multi-line and single-line code block, separated by the | operator - # `([^`]+)`: Matches inline code. - code_pattern = re.compile(CODE_BLOCK_PATTERN + r"|`([^`]+)`") - code_blocks = code_pattern.findall(text) - - # Extract the individual code blocks and languages from the matched groups - extracted = [] - for lang, group1, group2 in code_blocks: - if group1: - extracted.append((lang.strip(), group1.strip())) - elif group2: - extracted.append(("", group2.strip())) - - return extracted - - -def generate_code(pattern: str = CODE_BLOCK_PATTERN, **config) -> tuple[str, float]: - """(openai<1) Generate code. - - Args: - pattern (Optional, str): The regular expression pattern for finding the code block. - The default pattern is for finding a code block in a markdown file. - config (Optional, dict): The configuration for the API call. - - Returns: - str: The generated code. - float: The cost of the generation. - """ - response = Completion.create(**config) - return extract_code(Completion.extract_text(response)[0], pattern), response["cost"] - _IMPROVE_FUNCTION_CONFIG = { "prompt": """Improve the function '{func_name}' to achieve the objective '{objective}'. @@ -286,44 +214,6 @@ def in_docker_container() -> bool: return os.path.exists("/.dockerenv") -def decide_use_docker(use_docker: bool | None) -> bool | None: - if use_docker is None: - env_var_use_docker = os.environ.get("AUTOGEN_USE_DOCKER", "True") - - truthy_values = {"1", "true", "yes", "t"} - falsy_values = {"0", "false", "no", "f"} - - # Convert the value to lowercase for case-insensitive comparison - env_var_use_docker_lower = env_var_use_docker.lower() - - # Determine the boolean value based on the environment variable - if env_var_use_docker_lower in truthy_values: - use_docker = True - elif env_var_use_docker_lower in falsy_values: - use_docker = False - elif env_var_use_docker_lower == "none": # Special case for 'None' as a string - use_docker = None - else: - # Raise an error for any unrecognized value - raise ValueError( - f'Invalid value for AUTOGEN_USE_DOCKER: {env_var_use_docker}. Please set AUTOGEN_USE_DOCKER to "1/True/yes", "0/False/no", or "None".' - ) - return use_docker - - -def check_can_use_docker_or_throw(use_docker) -> None: - if use_docker is not None: - inside_docker = in_docker_container() - docker_installed_and_running = is_docker_running() - if use_docker and not inside_docker and not docker_installed_and_running: - raise RuntimeError( - "Code execution is set to be run in docker (default behaviour) but docker is not running.\n" - "The options available are:\n" - "- Make sure docker is running (advised approach for code execution)\n" - '- Set "use_docker": False in code_execution_config\n' - '- Set AUTOGEN_USE_DOCKER to "0/False/no" in your environment variables' - ) - def _sanitize_filename_for_docker_tag(filename: str) -> str: """Convert a filename to a valid docker tag. @@ -398,11 +288,6 @@ def execute_code( running_inside_docker = in_docker_container() docker_running = is_docker_running() - # SENTINEL is used to indicate that the user did not explicitly set the argument - if use_docker is SENTINEL: - use_docker = decide_use_docker(use_docker=None) - check_can_use_docker_or_throw(use_docker) - timeout = timeout or DEFAULT_TIMEOUT original_filename = filename if WIN32 and lang in ["sh", "shell"] and (not use_docker): From a3f63ce7a485788299bb82a6b1ae2c9ef1e0179f Mon Sep 17 00:00:00 2001 From: fmuuly <64724985+fmp453@users.noreply.github.com> Date: Sun, 1 Feb 2026 18:07:42 +0900 Subject: [PATCH 16/25] remove unused code --- .../legacy_autogen_conversable_agent.py | 94 ++----------------- 1 file changed, 7 insertions(+), 87 deletions(-) diff --git a/train_methods/legacy_autogen/legacy_autogen_conversable_agent.py b/train_methods/legacy_autogen/legacy_autogen_conversable_agent.py index 980900c..872086b 100644 --- a/train_methods/legacy_autogen/legacy_autogen_conversable_agent.py +++ b/train_methods/legacy_autogen/legacy_autogen_conversable_agent.py @@ -27,8 +27,6 @@ __all__ = ("ConversableAgent",) F = TypeVar("F", bound=Callable[..., Any]) -PYTHON_VARIANTS = ["python", "Python", "py"] -UNKNOWN = "unknown" def model_dump(model: BaseModel) -> dict[str, Any]: return model.model_dump() @@ -1591,23 +1589,7 @@ def generate_function_call_reply( messages = self._oai_messages[sender] message = messages[-1] if "function_call" in message and message["function_call"]: - func_call = message["function_call"] - func = None - if inspect.iscoroutinefunction(func): - try: - # get the running loop if it was already created - loop = asyncio.get_running_loop() - close_loop = False - except RuntimeError: - # create a loop if there is no running loop - loop = asyncio.new_event_loop() - close_loop = True - - _, func_return = loop.run_until_complete(self.a_execute_function(func_call)) - if close_loop: - loop.close() - else: - _, func_return = self.execute_function(message["function_call"]) + func_return = self.execute_function(message["function_call"]) return True, func_return return False, None @@ -1630,7 +1612,7 @@ async def a_generate_function_call_reply( message = messages[-1] func_call = message.get("function_call") if func_call: - _, func_return = self.execute_function(func_call) + func_return = self.execute_function(func_call) return True, func_return return False, None @@ -1653,22 +1635,7 @@ def generate_tool_calls_reply( tool_returns = [] for tool_call in message.get("tool_calls", []): function_call = tool_call.get("function", {}) - func = None - if inspect.iscoroutinefunction(func): - try: - # get the running loop if it was already created - loop = asyncio.get_running_loop() - close_loop = False - except RuntimeError: - # create a loop if there is no running loop - loop = asyncio.new_event_loop() - close_loop = True - - _, func_return = loop.run_until_complete(self.a_execute_function(function_call)) - if close_loop: - loop.close() - else: - _, func_return = self.execute_function(function_call) + func_return = self.execute_function(function_call) content = func_return.get("content", "") if content is None: content = "" @@ -1698,7 +1665,7 @@ def generate_tool_calls_reply( async def _a_execute_tool_call(self, tool_call): id = tool_call["id"] function_call = tool_call.get("function", {}) - _, func_return = await self.a_execute_function(function_call) + func_return = await self.a_execute_function(function_call) return { "tool_call_id": id, "role": "tool", @@ -2080,37 +2047,7 @@ async def a_get_human_input(self, prompt: str) -> str: reply = await loop.run_in_executor(None, functools.partial(self.get_human_input, prompt)) return reply - @staticmethod - def _format_json_str(jstr): - """Remove newlines outside of quotes, and handle JSON escape sequences. - - 1. this function removes the newline in the query outside of quotes otherwise json.loads(s) will fail. - Ex 1: - "{\n"tool": "python",\n"query": "print('hello')\nprint('world')"\n}" -> "{"tool": "python","query": "print('hello')\nprint('world')"}" - Ex 2: - "{\n \"location\": \"Boston, MA\"\n}" -> "{"location": "Boston, MA"}" - - 2. this function also handles JSON escape sequences inside quotes. - Ex 1: - '{"args": "a\na\na\ta"}' -> '{"args": "a\\na\\na\\ta"}' - """ - result = [] - inside_quotes = False - last_char = " " - for char in jstr: - if last_char != "\\" and char == '"': - inside_quotes = not inside_quotes - last_char = char - if not inside_quotes and char == "\n": - continue - if inside_quotes and char == "\n": - char = "\\n" - if inside_quotes and char == "\t": - char = "\\t" - result.append(char) - return "".join(result) - - def execute_function(self, func_call, verbose: bool = False) -> tuple[bool, dict[str, str]]: + def execute_function(self, func_call: dict) -> dict[str, str]: """Execute a function call and return the result. Override this function to modify the way to execute function and tool calls. @@ -2118,21 +2055,12 @@ def execute_function(self, func_call, verbose: bool = False) -> tuple[bool, dict Args: func_call: a dictionary extracted from openai message at "function_call" or "tool_calls" with keys "name" and "arguments". - Returns: - A tuple of (is_exec_success, result_dict). - is_exec_success (boolean): whether the execution is successful. - result_dict: a dictionary with keys "name", "role", and "content". Value of "role" is "function". - - "function_call" deprecated as of [OpenAI API v1.1.0](https://github.com/openai/openai-python/releases/tag/v1.1.0) - See https://platform.openai.com/docs/api-reference/chat/create#chat-create-function_call """ func_name = func_call.get("name", "") - - is_exec_success = False content = f"Error: Function {func_name} not found." - return is_exec_success, { + return { "name": func_name, "role": "function", "content": str(content), @@ -2146,21 +2074,13 @@ async def a_execute_function(self, func_call): Args: func_call: a dictionary extracted from openai message at key "function_call" or "tool_calls" with keys "name" and "arguments". - Returns: - A tuple of (is_exec_success, result_dict). - is_exec_success (boolean): whether the execution is successful. - result_dict: a dictionary with keys "name", "role", and "content". Value of "role" is "function". - - "function_call" deprecated as of [OpenAI API v1.1.0](https://github.com/openai/openai-python/releases/tag/v1.1.0) - See https://platform.openai.com/docs/api-reference/chat/create#chat-create-function_call """ func_name = func_call.get("name", "") - is_exec_success = False content = f"Error: Function {func_name} not found." - return is_exec_success, { + return { "name": func_name, "role": "function", "content": str(content), From f0856c4cf4ca1b4563cd7554a7d8662d9e7e2aa1 Mon Sep 17 00:00:00 2001 From: fmuuly <64724985+fmp453@users.noreply.github.com> Date: Sun, 1 Feb 2026 18:25:56 +0900 Subject: [PATCH 17/25] remove code exec func --- requirements.txt | 2 - train_methods/legacy_autogen/coding.py | 786 ----------- train_methods/legacy_autogen/completion.py | 1151 ----------------- .../legacy_autogen_conversable_agent.py | 86 +- train_methods/legacy_autogen/utils.py | 767 +---------- 5 files changed, 6 insertions(+), 2786 deletions(-) delete mode 100644 train_methods/legacy_autogen/completion.py diff --git a/requirements.txt b/requirements.txt index a4087eb..7e5a398 100644 --- a/requirements.txt +++ b/requirements.txt @@ -14,9 +14,7 @@ diskcache==5.6.3 redis==7.0.0 azure-cosmos==4.14.0 azure-identity==1.25.1 -docker==7.1.0 flaml==2.3.6 -jupyter-client==8.6.3 gdown==5.2.0 open_clip_torch==2.29.0 diff --git a/train_methods/legacy_autogen/coding.py b/train_methods/legacy_autogen/coding.py index a6c9cf7..c1f4bca 100644 --- a/train_methods/legacy_autogen/coding.py +++ b/train_methods/legacy_autogen/coding.py @@ -24,794 +24,8 @@ from pydantic import BaseModel, Field, field_validator -from pydantic import BaseModel, Field - -from train_methods.legacy_autogen.utils import UserMessageImageContentPart, UserMessageTextContentPart, content_str, infer_lang, UNKNOWN, CODE_BLOCK_PATTERN, PYTHON_VARIANTS, WIN32, TIMEOUT_MSG, _cmd A = ParamSpec("A") T = TypeVar("T") P = ParamSpec("P") -class CodeBlock(BaseModel): - """(Experimental) A class that represents a code block.""" - - code: str = Field(description="The code to execute.") - - language: str = Field(description="The language of the code.") - -class CodeResult(BaseModel): - """(Experimental) A class that represents the result of a code execution.""" - - exit_code: int = Field(description="The exit code of the code execution.") - - output: str = Field(description="The output of the code execution.") - - -class CodeExtractor(Protocol): - """(Experimental) A code extractor class that extracts code blocks from a message.""" - - def extract_code_blocks( - self, message: str | list[UserMessageTextContentPart | UserMessageImageContentPart] | None - ) -> list[CodeBlock]: - """(Experimental) Extract code blocks from a message. - - Args: - message (str): The message to extract code blocks from. - - Returns: - List[CodeBlock]: The extracted code blocks. - """ - ... # pragma: no cover - -class CodeExecutor(Protocol): - """(Experimental) A code executor class that executes code blocks and returns the result.""" - - @property - def code_extractor(self) -> CodeExtractor: - """(Experimental) The code extractor used by this code executor.""" - ... # pragma: no cover - - def execute_code_blocks(self, code_blocks: list[CodeBlock]) -> CodeResult: - """(Experimental) Execute code blocks and return the result. - - This method should be implemented by the code executor. - - Args: - code_blocks (List[CodeBlock]): The code blocks to execute. - - Returns: - CodeResult: The result of the code execution. - """ - ... # pragma: no cover - - def restart(self) -> None: - """(Experimental) Restart the code executor. - - This method should be implemented by the code executor. - - This method is called when the agent is reset. - """ - ... # pragma: no cover - - -class IPythonCodeResult(CodeResult): - """(Experimental) A code result class for IPython code executor.""" - - output_files: list[str] = Field( - default_factory=list, - description="The list of files that the executed code blocks generated.", - ) - -class MarkdownCodeExtractor(CodeExtractor): - """(Experimental) A class that extracts code blocks from a message using Markdown syntax.""" - - def extract_code_blocks( - self, message: str | list[UserMessageTextContentPart | UserMessageImageContentPart] | None - ) -> list[CodeBlock]: - """(Experimental) Extract code blocks from a message. If no code blocks are found, - return an empty list. - - Args: - message (str): The message to extract code blocks from. - - Returns: - List[CodeBlock]: The extracted code blocks or an empty list. - """ - - text = content_str(message) - match = re.findall(CODE_BLOCK_PATTERN, text, flags=re.DOTALL) - if not match: - return [] - code_blocks = [] - for lang, code in match: - if lang == "": - lang = infer_lang(code) - if lang == UNKNOWN: - lang = "" - code_blocks.append(CodeBlock(code=code, language=lang)) - return code_blocks - - -CodeExecutionConfig = TypedDict( - "CodeExecutionConfig", - { - "executor": Literal["ipython-embedded", "commandline-local"] | CodeExecutor, - "last_n_messages": int | Literal["auto"], - "timeout": int, - "use_docker": bool | str | list[str], - "work_dir": str, - "ipython-embedded": Mapping[str, Any], - "commandline-local": Mapping[str, Any], - }, - total=False, -) - - -class EmbeddedIPythonCodeExecutor(BaseModel): - """(Experimental) A code executor class that executes code statefully using an embedded - IPython kernel managed by this class. - - **This will execute LLM generated code on the local machine.** - - Each execution is stateful and can access variables created from previous - executions in the same session. The kernel must be installed before using - this class. The kernel can be installed using the following command: - `python -m ipykernel install --user --name {kernel_name}` - where `kernel_name` is the name of the kernel to install. - - Args: - timeout (int): The timeout for code execution, by default 60. - kernel_name (str): The kernel name to use. Make sure it is installed. - By default, it is "python3". - output_dir (str): The directory to save output files, by default ".". - """ - - timeout: int = Field(default=60, ge=1, description="The timeout for code execution.") - kernel_name: str = Field(default="python3", description="The kernel name to use. Make sure it is installed.") - output_dir: str = Field(default=".", description="The directory to save output files.") - - @field_validator("output_dir") - @classmethod - def _output_dir_must_exist(cls, value: str) -> str: - if not os.path.exists(value): - raise ValueError(f"Output directory {value} does not exist.") - return value - - def __init__(self, **kwargs: Any): - super().__init__(**kwargs) - # Check if the kernel is installed. - if self.kernel_name not in KernelSpecManager().find_kernel_specs(): - raise ValueError( - f"Kernel {self.kernel_name} is not installed. " - "Please first install it with " - f"`python -m ipykernel install --user --name {self.kernel_name}`." - ) - self._kernel_manager = KernelManager(kernel_name=self.kernel_name) - self._kernel_manager.start_kernel() - self._kernel_client = self._kernel_manager.client() - self._kernel_client.start_channels() - self._timeout = self.timeout - self._kernel_name = self.kernel_name - self._output_dir = Path(self.output_dir) - - @property - def code_extractor(self) -> CodeExtractor: - """(Experimental) Export a code extractor that can be used by an agent.""" - return MarkdownCodeExtractor() - - def execute_code_blocks(self, code_blocks: list[CodeBlock]) -> IPythonCodeResult: - """(Experimental) Execute a list of code blocks and return the result. - - This method executes a list of code blocks as cells in an IPython kernel - managed by this class. - See: https://jupyter-client.readthedocs.io/en/stable/messaging.html - for the message protocol. - - Args: - code_blocks (List[CodeBlock]): A list of code blocks to execute. - - Returns: - IPythonCodeResult: The result of the code execution. - """ - self._kernel_client.wait_for_ready() - outputs = [] - output_files = [] - for code_block in code_blocks: - code = self._process_code(code_block.code) - self._kernel_client.execute(code, store_history=True) - while True: - try: - msg = self._kernel_client.get_iopub_msg(timeout=self._timeout) - msg_type = msg["msg_type"] - content = msg["content"] - if msg_type in ["execute_result", "display_data"]: - for data_type, data in content["data"].items(): - if data_type == "text/plain": - # Output is a text. - outputs.append(data) - elif data_type.startswith("image/"): - # Output is an image. - path = self._save_image(data) - outputs.append(f"Image data saved to {path}") - output_files.append(path) - elif data_type == "text/html": - # Output is an html. - path = self._save_html(data) - outputs.append(f"HTML data saved to {path}") - output_files.append(path) - else: - # Output raw data. - outputs.append(json.dumps(data)) - elif msg_type == "stream": - # Output is a text. - outputs.append(content["text"]) - elif msg_type == "error": - # Output is an error. - return IPythonCodeResult( - exit_code=1, - output=f"ERROR: {content['ename']}: {content['evalue']}\n{content['traceback']}", - ) - if msg_type == "status" and content["execution_state"] == "idle": - break - # handle time outs. - except Empty: - return IPythonCodeResult( - exit_code=1, - output=f"ERROR: Timeout waiting for output from code block: {code_block.code}", - ) - # We return the full output. - return IPythonCodeResult( - exit_code=0, output="\n".join([str(output) for output in outputs]), output_files=output_files - ) - - def restart(self) -> None: - """(Experimental) Restart a new session.""" - self._kernel_client.stop_channels() - self._kernel_manager.shutdown_kernel() - self._kernel_manager = KernelManager(kernel_name=self.kernel_name) - self._kernel_manager.start_kernel() - self._kernel_client = self._kernel_manager.client() - self._kernel_client.start_channels() - - def _save_image(self, image_data_base64: str) -> str: - """Save image data to a file.""" - image_data = base64.b64decode(image_data_base64) - # Randomly generate a filename. - filename = f"{uuid.uuid4().hex}.png" - path = os.path.join(self.output_dir, filename) - with open(path, "wb") as f: - f.write(image_data) - return os.path.abspath(path) - - def _save_html(self, html_data: str) -> str: - """Save html data to a file.""" - # Randomly generate a filename. - filename = f"{uuid.uuid4().hex}.html" - path = os.path.join(self.output_dir, filename) - with open(path, "w") as f: - f.write(html_data) - return os.path.abspath(path) - - def _process_code(self, code: str) -> str: - """Process code before execution.""" - # Find lines that start with `! pip install` and make sure "-qqq" flag is added. - lines = code.split("\n") - for i, line in enumerate(lines): - # use regex to find lines that start with `! pip install` or `!pip install`. - match = re.search(r"^! ?pip install", line) - if match is not None: - if "-qqq" not in line: - lines[i] = line.replace(match.group(0), match.group(0) + " -qqq") - return "\n".join(lines) - -class CommandLineCodeResult(CodeResult): - """(Experimental) A code result class for command line code executor.""" - - code_file: str | None = Field( - default=None, - description="The file that the executed code block was saved to.", - ) - - -@dataclass -class Alias: - name: str - alias: str - - -@dataclass -class ImportFromModule: - module: str - imports: list[str | Alias] - -Import = str | ImportFromModule | Alias - - -class _StringLoader(SourceLoader): - def __init__(self, data: str): - self.data = data - - def get_source(self, fullname: str) -> str: - return self.data - - def get_data(self, path: str) -> bytes: - return self.data.encode("utf-8") - - def get_filename(self, fullname: str) -> str: - return "/" + fullname + ".py" - -@dataclass -class FunctionWithRequirementsStr: - func: str - _compiled_func: Callable[..., Any] - _func_name: str - python_packages: list[str] = field(default_factory=list) - global_imports: list[Import] = field(default_factory=list) - - def __init__(self, func: str, python_packages: list[str] = [], global_imports: list[Import] = []): - self.func = func - self.python_packages = python_packages - self.global_imports = global_imports - - module_name = "func_module" - loader = _StringLoader(func) - spec = importlib.util.spec_from_loader(module_name, loader) - if spec is None: - raise ValueError("Could not create spec") - module = importlib.util.module_from_spec(spec) - if spec.loader is None: - raise ValueError("Could not create loader") - - try: - spec.loader.exec_module(module) - except Exception as e: - raise ValueError(f"Could not compile function: {e}") from e - - functions = inspect.getmembers(module, inspect.isfunction) - if len(functions) != 1: - raise ValueError("The string must contain exactly one function") - - self._func_name, self._compiled_func = functions[0] - - def __call__(self, *args: Any, **kwargs: Any) -> None: - raise NotImplementedError("String based function with requirement objects are not directly callable") - - -@dataclass -class FunctionWithRequirements(Generic[T, P]): - func: Callable[P, T] - python_packages: list[str] = field(default_factory=list) - global_imports: list[Import] = field(default_factory=list) - - @classmethod - def from_callable( - cls, func: Callable[P, T], python_packages: list[str] = [], global_imports: list[Import] = [] - ) -> "FunctionWithRequirements"[T, P]: - return cls(python_packages=python_packages, global_imports=global_imports, func=func) - - @staticmethod - def from_str( - func: str, python_packages: list[str] = [], global_imports: list[Import] = [] - ) -> FunctionWithRequirementsStr: - return FunctionWithRequirementsStr(func=func, python_packages=python_packages, global_imports=global_imports) - - # Type this based on F - def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T: - return self.func(*args, **kwargs) - -def to_stub(func: Callable[..., Any] | FunctionWithRequirementsStr) -> str: - """Generate a stub for a function as a string - - Args: - func (Callable[..., Any]): The function to generate a stub for - - Returns: - str: The stub for the function - """ - if isinstance(func, FunctionWithRequirementsStr): - return to_stub(func._compiled_func) - - content = f"def {func.__name__}{inspect.signature(func)}:\n" - docstring = func.__doc__ - - if docstring: - docstring = dedent(docstring) - docstring = '"""' + docstring + '"""' - docstring = indent(docstring, " ") - content += docstring + "\n" - - content += " ..." - return content - -def _to_code(func: FunctionWithRequirements[T, P] | Callable[P, T] | FunctionWithRequirementsStr) -> str: - if isinstance(func, FunctionWithRequirementsStr): - return func.func - - code = inspect.getsource(func) - # Strip the decorator - if code.startswith("@"): - code = code[code.index("\n") + 1 :] - return code - -def _import_to_str(im: Import) -> str: - if isinstance(im, str): - return f"import {im}" - elif isinstance(im, Alias): - return f"import {im.name} as {im.alias}" - else: - - def to_str(i: str | Alias) -> str: - if isinstance(i, str): - return i - else: - return f"{i.name} as {i.alias}" - - imports = ", ".join(map(to_str, im.imports)) - return f"from {im.module} import {imports}" - -def _build_python_functions_file( - funcs: list[FunctionWithRequirements[Any, P] | Callable[..., Any] | FunctionWithRequirementsStr] -) -> str: - # First collect all global imports - global_imports: set[str] = set() - for func in funcs: - if isinstance(func, (FunctionWithRequirements, FunctionWithRequirementsStr)): - global_imports.update(map(_import_to_str, func.global_imports)) - - content = "\n".join(global_imports) + "\n\n" - - for func in funcs: - content += _to_code(func) + "\n\n" - - return content - -filename_patterns = [ - re.compile(r"^", re.DOTALL), - re.compile(r"^/\* (filename:)?(.+?) \*/", re.DOTALL), - re.compile(r"^// (filename:)?(.+?)$", re.DOTALL), - re.compile(r"^# (filename:)?(.+?)$", re.DOTALL), -] - -def _get_file_name_from_content(code: str, workspace_path: Path) -> str | None: - first_line = code.split("\n")[0].strip() - # TODO - support other languages - for pattern in filename_patterns: - matches = pattern.match(first_line) - if matches is not None: - filename = matches.group(2).strip() - - # Handle relative paths in the filename - path = Path(filename) - if not path.is_absolute(): - path = workspace_path / path - path = path.resolve() - # Throws an error if the file is not in the workspace - relative = path.relative_to(workspace_path.resolve()) - return str(relative) - return None - -def silence_pip(code: str, lang: str) -> str: - """Apply -qqq flag to pip install commands.""" - if lang == "python": - regex = r"^! ?pip install" - elif lang in ["bash", "shell", "sh", "pwsh", "powershell", "ps1"]: - regex = r"^pip install" - else: - return code - - # Find lines that start with pip install and make sure "-qqq" flag is added. - lines = code.split("\n") - for i, line in enumerate(lines): - # use regex to find lines that start with pip install. - match = re.search(regex, line) - if match is not None: - if "-qqq" not in line: - lines[i] = line.replace(match.group(0), match.group(0) + " -qqq") - return "\n".join(lines) - -class LocalCommandLineCodeExecutor(CodeExecutor): - SUPPORTED_LANGUAGES: ClassVar[list[str]] = [ - "bash", - "shell", - "sh", - "pwsh", - "powershell", - "ps1", - "python", - "javascript", - "html", - "css", - ] - DEFAULT_EXECUTION_POLICY: ClassVar[dict[str, bool]] = { - "bash": True, - "shell": True, - "sh": True, - "pwsh": True, - "powershell": True, - "ps1": True, - "python": True, - "javascript": False, - "html": False, - "css": False, - } - - FUNCTION_PROMPT_TEMPLATE: ClassVar[ - str - ] = """You have access to the following user defined functions. They can be accessed from the module called `$module_name` by their function names. - -For example, if there was a function called `foo` you could import it by writing `from $module_name import foo` - -$functions""" - - def __init__( - self, - timeout: int = 60, - virtual_env_context: SimpleNamespace | None = None, - work_dir: Path | str = Path("."), - functions: list[FunctionWithRequirements[Any, A] | Callable[..., Any] | FunctionWithRequirementsStr] = [], - functions_module: str = "functions", - execution_policies: dict[str, bool] | None = None, - ): - """(Experimental) A code executor class that executes or saves LLM generated code a local command line - environment. - - **This will execute or save LLM generated code on the local machine.** - - Each code block is saved as a file in the working directory. Depending on the execution policy, - the code may be executed in a separate process. - The code blocks are executed or save in the order they are received. - Command line code is sanitized against a list of dangerous commands to prevent self-destructive commands from being executed, - which could potentially affect the user's environment. Supported languages include Python, shell scripts (bash, shell, sh), - PowerShell (pwsh, powershell, ps1), HTML, CSS, and JavaScript. - Execution policies determine whether each language's code blocks are executed or saved only. - - ## Execution with a Python virtual environment - A python virtual env can be used to execute code and install dependencies. This has the added benefit of not polluting the - base environment with unwanted modules. - ```python - from autogen.code_utils import create_virtual_env - from autogen.coding import LocalCommandLineCodeExecutor - - venv_dir = ".venv" - venv_context = create_virtual_env(venv_dir) - - executor = LocalCommandLineCodeExecutor(virtual_env_context=venv_context) - ``` - - Args: - timeout (int): The timeout for code execution, default is 60 seconds. - virtual_env_context (Optional[SimpleNamespace]): The virtual environment context to use. - work_dir (Union[Path, str]): The working directory for code execution, defaults to the current directory. - functions (List[Union[FunctionWithRequirements[Any, A], Callable[..., Any], FunctionWithRequirementsStr]]): A list of callable functions available to the executor. - functions_module (str): The module name under which functions are accessible. - execution_policies (Optional[Dict[str, bool]]): A dictionary mapping languages to execution policies (True for execution, False for saving only). Defaults to class-wide DEFAULT_EXECUTION_POLICY. - """ - - if timeout < 1: - raise ValueError("Timeout must be greater than or equal to 1.") - - if isinstance(work_dir, str): - work_dir = Path(work_dir) - - if not functions_module.isidentifier(): - raise ValueError("Module name must be a valid Python identifier") - - self._functions_module = functions_module - - work_dir.mkdir(exist_ok=True) - - self._timeout = timeout - self._work_dir: Path = work_dir - self._virtual_env_context: SimpleNamespace | None = virtual_env_context - - self._functions = functions - # Setup could take some time so we intentionally wait for the first code block to do it. - if len(functions) > 0: - self._setup_functions_complete = False - else: - self._setup_functions_complete = True - - self.execution_policies = self.DEFAULT_EXECUTION_POLICY.copy() - if execution_policies is not None: - self.execution_policies.update(execution_policies) - - def format_functions_for_prompt(self, prompt_template: str = FUNCTION_PROMPT_TEMPLATE) -> str: - """(Experimental) Format the functions for a prompt. - - The template includes two variables: - - `$module_name`: The module name. - - `$functions`: The functions formatted as stubs with two newlines between each function. - - Args: - prompt_template (str): The prompt template. Default is the class default. - - Returns: - str: The formatted prompt. - """ - template = Template(prompt_template) - return template.substitute( - module_name=self._functions_module, - functions="\n\n".join([to_stub(func) for func in self._functions]), - ) - - @property - def functions_module(self) -> str: - """(Experimental) The module name for the functions.""" - return self._functions_module - - @property - def functions( - self, - ) -> list[FunctionWithRequirements[Any, A] | Callable[..., Any] | FunctionWithRequirementsStr]: - """(Experimental) The functions that are available to the code executor.""" - return self._functions - - @property - def timeout(self) -> int: - """(Experimental) The timeout for code execution.""" - return self._timeout - - @property - def work_dir(self) -> Path: - """(Experimental) The working directory for the code execution.""" - return self._work_dir - - @property - def code_extractor(self) -> CodeExtractor: - """(Experimental) Export a code extractor that can be used by an agent.""" - return MarkdownCodeExtractor() - - @staticmethod - def sanitize_command(lang: str, code: str) -> None: - """ - Sanitize the code block to prevent dangerous commands. - This approach acknowledges that while Docker or similar - containerization/sandboxing technologies provide a robust layer of security, - not all users may have Docker installed or may choose not to use it. - Therefore, having a baseline level of protection helps mitigate risks for users who, - either out of choice or necessity, run code outside of a sandboxed environment. - """ - dangerous_patterns = [ - (r"\brm\s+-rf\b", "Use of 'rm -rf' command is not allowed."), - (r"\bmv\b.*?\s+/dev/null", "Moving files to /dev/null is not allowed."), - (r"\bdd\b", "Use of 'dd' command is not allowed."), - (r">\s*/dev/sd[a-z][1-9]?", "Overwriting disk blocks directly is not allowed."), - (r":\(\)\{\s*:\|\:&\s*\};:", "Fork bombs are not allowed."), - ] - if lang in ["bash", "shell", "sh"]: - for pattern, message in dangerous_patterns: - if re.search(pattern, code): - raise ValueError(f"Potentially dangerous command detected: {message}") - - def _setup_functions(self) -> None: - func_file_content = _build_python_functions_file(self._functions) - func_file = self._work_dir / f"{self._functions_module}.py" - func_file.write_text(func_file_content) - - # Collect requirements - lists_of_packages = [x.python_packages for x in self._functions if isinstance(x, FunctionWithRequirements)] - flattened_packages = [item for sublist in lists_of_packages for item in sublist] - required_packages = list(set(flattened_packages)) - if len(required_packages) > 0: - print("Ensuring packages are installed in executor.") - if self._virtual_env_context: - py_executable = self._virtual_env_context.env_exe - else: - py_executable = sys.executable - cmd = [py_executable, "-m", "pip", "install"] + required_packages - try: - result = subprocess.run( - cmd, - cwd=self._work_dir, - capture_output=True, - text=True, - timeout=float(self._timeout), - encoding="utf-8", - ) - except subprocess.TimeoutExpired as e: - raise ValueError("Pip install timed out") from e - if result.returncode != 0: - raise ValueError(f"Pip install failed. {result.stdout}, {result.stderr}") - # Attempt to load the function file to check for syntax errors, imports etc. - exec_result = self._execute_code_dont_check_setup([CodeBlock(code=func_file_content, language="python")]) - if exec_result.exit_code != 0: - raise ValueError(f"Functions failed to load: {exec_result.output}") - self._setup_functions_complete = True - - def execute_code_blocks(self, code_blocks: list[CodeBlock]) -> CommandLineCodeResult: - """(Experimental) Execute the code blocks and return the result. - - Args: - code_blocks (List[CodeBlock]): The code blocks to execute. - - Returns: - CommandLineCodeResult: The result of the code execution.""" - if not self._setup_functions_complete: - self._setup_functions() - return self._execute_code_dont_check_setup(code_blocks) - - def _execute_code_dont_check_setup(self, code_blocks: list[CodeBlock]) -> CommandLineCodeResult: - logs_all = "" - file_names = [] - for code_block in code_blocks: - lang, code = code_block.language, code_block.code - lang = lang.lower() - - LocalCommandLineCodeExecutor.sanitize_command(lang, code) - code = silence_pip(code, lang) - - if lang in PYTHON_VARIANTS: - lang = "python" - - if WIN32 and lang in ["sh", "shell"]: - lang = "ps1" - - if lang not in self.SUPPORTED_LANGUAGES: - # In case the language is not supported, we return an error message. - exitcode = 1 - logs_all += "\n" + f"unknown language {lang}" - break - - execute_code = self.execution_policies.get(lang, False) - try: - # Check if there is a filename comment - filename = _get_file_name_from_content(code, self._work_dir) - except ValueError: - return CommandLineCodeResult(exit_code=1, output="Filename is not in the workspace") - - if filename is None: - # create a file with an automatically generated name - code_hash = md5(code.encode()).hexdigest() - filename = f"tmp_code_{code_hash}.{'py' if lang.startswith('python') else lang}" - written_file = (self._work_dir / filename).resolve() - with written_file.open("w", encoding="utf-8") as f: - f.write(code) - file_names.append(written_file) - - if not execute_code: - # Just return a message that the file is saved. - logs_all += f"Code saved to {str(written_file)}\n" - exitcode = 0 - continue - - program = _cmd(lang) - cmd = [program, str(written_file.absolute())] - env = os.environ.copy() - - if self._virtual_env_context: - virtual_env_abs_path = os.path.abspath(self._virtual_env_context.bin_path) - path_with_virtualenv = rf"{virtual_env_abs_path}{os.pathsep}{env['PATH']}" - env["PATH"] = path_with_virtualenv - if WIN32: - activation_script = os.path.join(virtual_env_abs_path, "activate.bat") - cmd = [activation_script, "&&", *cmd] - - try: - result = subprocess.run( - cmd, - cwd=self._work_dir, - capture_output=True, - text=True, - timeout=float(self._timeout), - env=env, - encoding="utf-8", - ) - except subprocess.TimeoutExpired: - logs_all += "\n" + TIMEOUT_MSG - # Same exit code as the timeout command on linux. - exitcode = 124 - break - - logs_all += result.stderr - logs_all += result.stdout - exitcode = result.returncode - - if exitcode != 0: - break - - code_file = str(file_names[0]) if len(file_names) > 0 else None - return CommandLineCodeResult(exit_code=exitcode, output=logs_all, code_file=code_file) - - def restart(self) -> None: - """(Experimental) Restart the code executor.""" - warnings.warn("Restarting local command line code executor is not supported. No action is taken.") diff --git a/train_methods/legacy_autogen/completion.py b/train_methods/legacy_autogen/completion.py deleted file mode 100644 index f892bc7..0000000 --- a/train_methods/legacy_autogen/completion.py +++ /dev/null @@ -1,1151 +0,0 @@ -import logging -import shutil -import time -from collections import defaultdict -from time import sleep -from typing import Callable, Dict, List, Optional, Union - -import diskcache -import openai -import numpy as np -from flaml import BlendSearch, tune -from flaml.tune.space import is_constant - -from openai.types.completion import Completion as openai_Completion -from openai.types.chat import ChatCompletion -from openai import APIError, APIConnectionError, BadRequestError, Timeout, RateLimitError, AuthenticationError - -from train_methods.legacy_autogen.client import get_key - - -class Completion(openai_Completion): - """(openai<1) A class for OpenAI completion API. - - It also supports: ChatCompletion, Azure OpenAI API. - """ - - # set of models that support chat completion - chat_models = { - "gpt-3.5-turbo", - "gpt-3.5-turbo-0301", # deprecate in Sep - "gpt-3.5-turbo-0613", - "gpt-3.5-turbo-16k", - "gpt-3.5-turbo-16k-0613", - "gpt-35-turbo", - "gpt-35-turbo-16k", - "gpt-4", - "gpt-4-32k", - "gpt-4-32k-0314", # deprecate in Sep - "gpt-4-0314", # deprecate in Sep - "gpt-4-0613", - "gpt-4-32k-0613", - } - - # price per 1k tokens - price1K = { - "text-ada-001": 0.0004, - "text-babbage-001": 0.0005, - "text-curie-001": 0.002, - "code-cushman-001": 0.024, - "code-davinci-002": 0.1, - "text-davinci-002": 0.02, - "text-davinci-003": 0.02, - "gpt-3.5-turbo": (0.0015, 0.002), - "gpt-3.5-turbo-instruct": (0.0015, 0.002), - "gpt-3.5-turbo-0301": (0.0015, 0.002), # deprecate in Sep - "gpt-3.5-turbo-0613": (0.0015, 0.002), - "gpt-3.5-turbo-16k": (0.003, 0.004), - "gpt-3.5-turbo-16k-0613": (0.003, 0.004), - "gpt-35-turbo": (0.0015, 0.002), - "gpt-35-turbo-16k": (0.003, 0.004), - "gpt-35-turbo-instruct": (0.0015, 0.002), - "gpt-4": (0.03, 0.06), - "gpt-4-32k": (0.06, 0.12), - "gpt-4-0314": (0.03, 0.06), # deprecate in Sep - "gpt-4-32k-0314": (0.06, 0.12), # deprecate in Sep - "gpt-4-0613": (0.03, 0.06), - "gpt-4-32k-0613": (0.06, 0.12), - } - - default_search_space = { - "model": tune.choice( - [ - "text-ada-001", - "text-babbage-001", - "text-davinci-003", - "gpt-3.5-turbo", - "gpt-4", - ] - ), - "temperature_or_top_p": tune.choice( - [ - {"temperature": tune.uniform(0, 2)}, - {"top_p": tune.uniform(0, 1)}, - ] - ), - "max_tokens": tune.lograndint(50, 1000), - "n": tune.randint(1, 100), - "prompt": "{prompt}", - } - - cache_seed = 41 - cache_path = f".cache/{cache_seed}" - # retry after this many seconds - retry_wait_time = 10 - # fail a request after hitting RateLimitError for this many seconds - max_retry_period = 120 - # time out for request to openai server - request_timeout = 60 - - openai_completion_class = openai_Completion - _total_cost = 0 - optimization_budget = None - - _history_dict = _count_create = None - - @classmethod - def set_cache(cls, seed: Optional[int] = 41, cache_path_root: Optional[str] = ".cache"): - """Set cache path. - - Args: - seed (int, Optional): The integer identifier for the pseudo seed. - Results corresponding to different seeds will be cached in different places. - cache_path (str, Optional): The root path for the cache. - The complete cache path will be {cache_path_root}/{seed}. - """ - cls.cache_seed = seed - cls.cache_path = f"{cache_path_root}/{seed}" - - @classmethod - def clear_cache(cls, seed: Optional[int] = None, cache_path_root: Optional[str] = ".cache"): - """Clear cache. - - Args: - seed (int, Optional): The integer identifier for the pseudo seed. - If omitted, all caches under cache_path_root will be cleared. - cache_path (str, Optional): The root path for the cache. - The complete cache path will be {cache_path_root}/{seed}. - """ - if seed is None: - shutil.rmtree(cache_path_root, ignore_errors=True) - return - with diskcache.Cache(f"{cache_path_root}/{seed}") as cache: - cache.clear() - - @classmethod - def _book_keeping(cls, config: Dict, response): - """Book keeping for the created completions.""" - if response != -1 and "cost" not in response: - response["cost"] = cls.cost(response) - if cls._history_dict is None: - return - if cls._history_compact: - value = { - "created_at": [], - "cost": [], - "token_count": [], - } - if "messages" in config: - messages = config["messages"] - if len(messages) > 1 and messages[-1]["role"] != "assistant": - existing_key = get_key(messages[:-1]) - value = cls._history_dict.pop(existing_key, value) - key = get_key(messages + [choice["message"] for choice in response["choices"]]) - else: - key = get_key([config["prompt"]] + [choice.get("text") for choice in response["choices"]]) - value["created_at"].append(cls._count_create) - value["cost"].append(response["cost"]) - value["token_count"].append( - { - "model": response["model"], - "prompt_tokens": response["usage"]["prompt_tokens"], - "completion_tokens": response["usage"].get("completion_tokens", 0), - "total_tokens": response["usage"]["total_tokens"], - } - ) - cls._history_dict[key] = value - cls._count_create += 1 - return - cls._history_dict[cls._count_create] = { - "request": config, - "response": response.to_dict_recursive(), - } - cls._count_create += 1 - - @classmethod - def _get_response(cls, config: Dict, raise_on_ratelimit_or_timeout=False, use_cache=True): - """Get the response from the openai api call. - - Try cache first. If not found, call the openai api. If the api call fails, retry after retry_wait_time. - """ - config = config.copy() - key = get_key(config) - if use_cache: - response = cls._cache.get(key, None) - if response is not None and (response != -1 or not raise_on_ratelimit_or_timeout): - # print("using cached response") - cls._book_keeping(config, response) - return response - openai_completion = ( - ChatCompletion - if config["model"].replace("gpt-35-turbo", "gpt-3.5-turbo") in cls.chat_models - or issubclass(cls, ChatCompletion) - else openai_Completion - ) - start_time = time.time() - request_timeout = cls.request_timeout - max_retry_period = config.pop("max_retry_period", cls.max_retry_period) - retry_wait_time = config.pop("retry_wait_time", cls.retry_wait_time) - while True: - try: - if "request_timeout" in config: - response = openai_completion.create(**config) - else: - response = openai_completion.create(request_timeout=request_timeout, **config) - except APIConnectionError: - # transient error - print(f"retrying in {retry_wait_time} seconds...", exc_info=1) - sleep(retry_wait_time) - except APIError as err: - error_code = err and err.json_body and isinstance(err.json_body, dict) and err.json_body.get("error") - if isinstance(error_code, dict): - error_code = error_code.get("code") - if error_code == "content_filter": - raise - # transient error - print(f"retrying in {retry_wait_time} seconds...", exc_info=1) - sleep(retry_wait_time) - except (RateLimitError, Timeout) as err: - time_left = max_retry_period - (time.time() - start_time + retry_wait_time) - if ( - time_left > 0 - and isinstance(err, RateLimitError) - or time_left > request_timeout - and isinstance(err, Timeout) - and "request_timeout" not in config - ): - if isinstance(err, Timeout): - request_timeout <<= 1 - request_timeout = min(request_timeout, time_left) - print(f"retrying in {retry_wait_time} seconds...", exc_info=1) - sleep(retry_wait_time) - elif raise_on_ratelimit_or_timeout: - raise - else: - response = -1 - if use_cache and isinstance(err, Timeout): - cls._cache.set(key, response) - print( - f"Failed to get response from openai api due to getting RateLimitError or Timeout for {max_retry_period} seconds." - ) - return response - except BadRequestError: - if "azure" in config.get("api_type", openai.api_type) and "model" in config: - # azure api uses "engine" instead of "model" - config["engine"] = config.pop("model").replace("gpt-3.5-turbo", "gpt-35-turbo") - else: - raise - else: - if use_cache: - cls._cache.set(key, response) - cls._book_keeping(config, response) - return response - - @classmethod - def _get_max_valid_n(cls, key, max_tokens): - # find the max value in max_valid_n_per_max_tokens - # whose key is equal or larger than max_tokens - return max( - (value for k, value in cls._max_valid_n_per_max_tokens.get(key, {}).items() if k >= max_tokens), - default=1, - ) - - @classmethod - def _get_min_invalid_n(cls, key, max_tokens): - # find the min value in min_invalid_n_per_max_tokens - # whose key is equal or smaller than max_tokens - return min( - (value for k, value in cls._min_invalid_n_per_max_tokens.get(key, {}).items() if k <= max_tokens), - default=None, - ) - - @classmethod - def _get_region_key(cls, config): - # get a key for the valid/invalid region corresponding to the given config - config = cls._pop_subspace(config, always_copy=False) - return ( - config["model"], - config.get("prompt", config.get("messages")), - config.get("stop"), - ) - - @classmethod - def _update_invalid_n(cls, prune, region_key, max_tokens, num_completions): - if prune: - # update invalid n and prune this config - cls._min_invalid_n_per_max_tokens[region_key] = invalid_n = cls._min_invalid_n_per_max_tokens.get( - region_key, {} - ) - invalid_n[max_tokens] = min(num_completions, invalid_n.get(max_tokens, np.inf)) - - @classmethod - def _pop_subspace(cls, config, always_copy=True): - if "subspace" in config: - config = config.copy() - config.update(config.pop("subspace")) - return config.copy() if always_copy else config - - @classmethod - def _get_params_for_create(cls, config: Dict) -> Dict: - """Get the params for the openai api call from a config in the search space.""" - params = cls._pop_subspace(config) - if cls._prompts: - params["prompt"] = cls._prompts[config["prompt"]] - else: - params["messages"] = cls._messages[config["messages"]] - if "stop" in params: - params["stop"] = cls._stops and cls._stops[params["stop"]] - temperature_or_top_p = params.pop("temperature_or_top_p", None) - if temperature_or_top_p: - params.update(temperature_or_top_p) - if cls._config_list and "config_list" not in params: - params["config_list"] = cls._config_list - return params - - @classmethod - def _eval(cls, config: dict, prune=True, eval_only=False): - """Evaluate the given config as the hyperparameter setting for the openai api call. - - Args: - config (dict): Hyperparameter setting for the openai api call. - prune (bool, optional): Whether to enable pruning. Defaults to True. - eval_only (bool, optional): Whether to evaluate only - (ignore the inference budget and do not raise error when a request fails). - Defaults to False. - - Returns: - dict: Evaluation results. - """ - cost = 0 - data = cls.data - params = cls._get_params_for_create(config) - model = params["model"] - data_length = len(data) - price = cls.price1K.get(model) - price_input, price_output = price if isinstance(price, tuple) else (price, price) - inference_budget = getattr(cls, "inference_budget", None) - prune_hp = getattr(cls, "_prune_hp", "n") - metric = cls._metric - config_n = params.get(prune_hp, 1) # default value in OpenAI is 1 - max_tokens = params.get( - "max_tokens", np.inf if model in cls.chat_models or issubclass(cls, ChatCompletion) else 16 - ) - target_output_tokens = None - if not cls.avg_input_tokens: - input_tokens = [None] * data_length - prune = prune and inference_budget and not eval_only - if prune: - region_key = cls._get_region_key(config) - max_valid_n = cls._get_max_valid_n(region_key, max_tokens) - if cls.avg_input_tokens: - target_output_tokens = (inference_budget * 1000 - cls.avg_input_tokens * price_input) / price_output - # max_tokens bounds the maximum tokens - # so using it we can calculate a valid n according to the avg # input tokens - max_valid_n = max( - max_valid_n, - int(target_output_tokens // max_tokens), - ) - if config_n <= max_valid_n: - start_n = config_n - else: - min_invalid_n = cls._get_min_invalid_n(region_key, max_tokens) - if min_invalid_n is not None and config_n >= min_invalid_n: - # prune this config - return { - "inference_cost": np.inf, - metric: np.inf if cls._mode == "min" else -np.inf, - "cost": cost, - } - start_n = max_valid_n + 1 - else: - start_n = config_n - region_key = None - num_completions, previous_num_completions = start_n, 0 - n_tokens_list, result, responses_list = [], {}, [] - while True: # n <= config_n - params[prune_hp] = num_completions - previous_num_completions - data_limit = 1 if prune else data_length - prev_data_limit = 0 - data_early_stop = False # whether data early stop happens for this n - while True: # data_limit <= data_length - # limit the number of data points to avoid rate limit - for i in range(prev_data_limit, data_limit): - data_i = data[i] - response = cls.create(data_i, raise_on_ratelimit_or_timeout=eval_only, **params) - if response == -1: # rate limit/timeout error, treat as invalid - cls._update_invalid_n(prune, region_key, max_tokens, num_completions) - result[metric] = 0 - result["cost"] = cost - return result - # evaluate the quality of the responses - responses = cls.extract_text_or_function_call(response) - usage = response["usage"] - n_input_tokens = usage["prompt_tokens"] - n_output_tokens = usage.get("completion_tokens", 0) - if not cls.avg_input_tokens and not input_tokens[i]: - # store the # input tokens - input_tokens[i] = n_input_tokens - query_cost = response["cost"] - cls._total_cost += query_cost - cost += query_cost - if cls.optimization_budget and cls._total_cost >= cls.optimization_budget and not eval_only: - # limit the total tuning cost - return { - metric: 0, - "total_cost": cls._total_cost, - "cost": cost, - } - if previous_num_completions: - n_tokens_list[i] += n_output_tokens - responses_list[i].extend(responses) - # Assumption 1: assuming requesting n1, n2 responses separately then combining them - # is the same as requesting (n1+n2) responses together - else: - n_tokens_list.append(n_output_tokens) - responses_list.append(responses) - avg_n_tokens = np.mean(n_tokens_list[:data_limit]) - rho = ( - (1 - data_limit / data_length) * (1 + 1 / data_limit) - if data_limit << 1 > data_length - else (1 - (data_limit - 1) / data_length) - ) - # Hoeffding-Serfling bound - ratio = 0.1 * np.sqrt(rho / data_limit) - if target_output_tokens and avg_n_tokens > target_output_tokens * (1 + ratio) and not eval_only: - cls._update_invalid_n(prune, region_key, max_tokens, num_completions) - result[metric] = 0 - result["total_cost"] = cls._total_cost - result["cost"] = cost - return result - if ( - prune - and target_output_tokens - and avg_n_tokens <= target_output_tokens * (1 - ratio) - and (num_completions < config_n or num_completions == config_n and data_limit == data_length) - ): - # update valid n - cls._max_valid_n_per_max_tokens[region_key] = valid_n = cls._max_valid_n_per_max_tokens.get( - region_key, {} - ) - valid_n[max_tokens] = max(num_completions, valid_n.get(max_tokens, 0)) - if num_completions < config_n: - # valid already, skip the rest of the data - data_limit = data_length - data_early_stop = True - break - prev_data_limit = data_limit - if data_limit < data_length: - data_limit = min(data_limit << 1, data_length) - else: - break - # use exponential search to increase n - if num_completions == config_n: - for i in range(data_limit): - data_i = data[i] - responses = responses_list[i] - metrics = cls._eval_func(responses, **data_i) - if result: - for key, value in metrics.items(): - if isinstance(value, (float, int)): - result[key] += value - else: - result = metrics - for key in result.keys(): - if isinstance(result[key], (float, int)): - result[key] /= data_limit - result["total_cost"] = cls._total_cost - result["cost"] = cost - if not cls.avg_input_tokens: - cls.avg_input_tokens = np.mean(input_tokens) - if prune: - target_output_tokens = ( - inference_budget * 1000 - cls.avg_input_tokens * price_input - ) / price_output - result["inference_cost"] = (avg_n_tokens * price_output + cls.avg_input_tokens * price_input) / 1000 - break - else: - if data_early_stop: - previous_num_completions = 0 - n_tokens_list.clear() - responses_list.clear() - else: - previous_num_completions = num_completions - num_completions = min(num_completions << 1, config_n) - return result - - @classmethod - def tune( - cls, - data: List[Dict], - metric: str, - mode: str, - eval_func: Callable, - log_file_name: Optional[str] = None, - inference_budget: Optional[float] = None, - optimization_budget: Optional[float] = None, - num_samples: Optional[int] = 1, - logging_level: Optional[int] = logging.WARNING, - **config, - ): - """Tune the parameters for the OpenAI API call. - - TODO: support parallel tuning with ray or spark. - TODO: support agg_method as in test - - Args: - data (list): The list of data points. - metric (str): The metric to optimize. - mode (str): The optimization mode, "min" or "max. - eval_func (Callable): The evaluation function for responses. - The function should take a list of responses and a data point as input, - and return a dict of metrics. For example, - - ```python - def eval_func(responses, **data): - solution = data["solution"] - success_list = [] - n = len(responses) - for i in range(n): - response = responses[i] - succeed = is_equiv_chain_of_thought(response, solution) - success_list.append(succeed) - return { - "expected_success": 1 - pow(1 - sum(success_list) / n, n), - "success": any(s for s in success_list), - } - ``` - - log_file_name (str, optional): The log file. - inference_budget (float, optional): The inference budget, dollar per instance. - optimization_budget (float, optional): The optimization budget, dollar in total. - num_samples (int, optional): The number of samples to evaluate. - -1 means no hard restriction in the number of trials - and the actual number is decided by optimization_budget. Defaults to 1. - logging_level (optional): logging level. Defaults to logging.WARNING. - **config (dict): The search space to update over the default search. - For prompt, please provide a string/Callable or a list of strings/Callables. - - If prompt is provided for chat models, it will be converted to messages under role "user". - - Do not provide both prompt and messages for chat models, but provide either of them. - - A string template will be used to generate a prompt for each data instance - using `prompt.format(**data)`. - - A callable template will be used to generate a prompt for each data instance - using `prompt(data)`. - For stop, please provide a string, a list of strings, or a list of lists of strings. - For messages (chat models only), please provide a list of messages (for a single chat prefix) - or a list of lists of messages (for multiple choices of chat prefix to choose from). - Each message should be a dict with keys "role" and "content". The value of "content" can be a string/Callable template. - - Returns: - dict: The optimized hyperparameter setting. - tune.ExperimentAnalysis: The tuning results. - """ - print( - "tuning via Completion.tune is deprecated in pyautogen v0.2 and openai>=1. " - "flaml.tune supports tuning more generically." - ) - space = cls.default_search_space.copy() - if config is not None: - space.update(config) - if "messages" in space: - space.pop("prompt", None) - temperature = space.pop("temperature", None) - top_p = space.pop("top_p", None) - if temperature is not None and top_p is None: - space["temperature_or_top_p"] = {"temperature": temperature} - elif temperature is None and top_p is not None: - space["temperature_or_top_p"] = {"top_p": top_p} - elif temperature is not None and top_p is not None: - space.pop("temperature_or_top_p") - space["temperature"] = temperature - space["top_p"] = top_p - print("temperature and top_p are not recommended to vary together.") - cls._max_valid_n_per_max_tokens, cls._min_invalid_n_per_max_tokens = {}, {} - cls.optimization_budget = optimization_budget - cls.inference_budget = inference_budget - cls._prune_hp = "best_of" if space.get("best_of", 1) != 1 else "n" - cls._prompts = space.get("prompt") - if cls._prompts is None: - cls._messages = space.get("messages") - if not all((isinstance(cls._messages, list), isinstance(cls._messages[0], (dict, list)))): - error_msg = "messages must be a list of dicts or a list of lists." - raise AssertionError(error_msg) - if isinstance(cls._messages[0], dict): - cls._messages = [cls._messages] - space["messages"] = tune.choice(list(range(len(cls._messages)))) - else: - if space.get("messages") is not None: - error_msg = "messages and prompt cannot be provided at the same time." - raise AssertionError(error_msg) - if not isinstance(cls._prompts, (str, list)): - error_msg = "prompt must be a string or a list of strings." - raise AssertionError(error_msg) - if isinstance(cls._prompts, str): - cls._prompts = [cls._prompts] - space["prompt"] = tune.choice(list(range(len(cls._prompts)))) - cls._stops = space.get("stop") - if cls._stops: - if not isinstance(cls._stops, (str, list)): - error_msg = "stop must be a string, a list of strings, or a list of lists of strings." - raise AssertionError(error_msg) - if not (isinstance(cls._stops, list) and isinstance(cls._stops[0], list)): - cls._stops = [cls._stops] - space["stop"] = tune.choice(list(range(len(cls._stops)))) - cls._config_list = space.get("config_list") - if cls._config_list is not None: - is_const = is_constant(cls._config_list) - if is_const: - space.pop("config_list") - cls._metric, cls._mode = metric, mode - cls._total_cost = 0 # total optimization cost - cls._eval_func = eval_func - cls.data = data - cls.avg_input_tokens = None - - space_model = space["model"] - if not isinstance(space_model, str) and len(space_model) > 1: - # make a hierarchical search space - subspace = {} - if "max_tokens" in space: - subspace["max_tokens"] = space.pop("max_tokens") - if "temperature_or_top_p" in space: - subspace["temperature_or_top_p"] = space.pop("temperature_or_top_p") - if "best_of" in space: - subspace["best_of"] = space.pop("best_of") - if "n" in space: - subspace["n"] = space.pop("n") - choices = [] - for model in space["model"]: - choices.append({"model": model, **subspace}) - space["subspace"] = tune.choice(choices) - space.pop("model") - # start all the models with the same hp config - search_alg = BlendSearch( - cost_attr="cost", - cost_budget=optimization_budget, - metric=metric, - mode=mode, - space=space, - ) - config0 = search_alg.suggest("t0") - points_to_evaluate = [config0] - for model in space_model: - if model != config0["subspace"]["model"]: - point = config0.copy() - point["subspace"] = point["subspace"].copy() - point["subspace"]["model"] = model - points_to_evaluate.append(point) - search_alg = BlendSearch( - cost_attr="cost", - cost_budget=optimization_budget, - metric=metric, - mode=mode, - space=space, - points_to_evaluate=points_to_evaluate, - ) - else: - search_alg = BlendSearch( - cost_attr="cost", - cost_budget=optimization_budget, - metric=metric, - mode=mode, - space=space, - ) - with diskcache.Cache(cls.cache_path) as cls._cache: - analysis = tune.run( - cls._eval, - search_alg=search_alg, - num_samples=num_samples, - log_file_name=log_file_name, - verbose=3, - ) - config = analysis.best_config - params = cls._get_params_for_create(config) - if cls._config_list is not None and is_const: - params.pop("config_list") - return params, analysis - - @classmethod - def create( - cls, - context: Optional[Dict] = None, - use_cache: Optional[bool] = True, - config_list: Optional[List[Dict]] = None, - filter_func: Optional[Callable[[Dict, Dict], bool]] = None, - raise_on_ratelimit_or_timeout: Optional[bool] = True, - allow_format_str_template: Optional[bool] = False, - **config, - ): - """Make a completion for a given context. - - Args: - context (Dict, Optional): The context to instantiate the prompt. - It needs to contain keys that are used by the prompt template or the filter function. - E.g., `prompt="Complete the following sentence: {prefix}, context={"prefix": "Today I feel"}`. - The actual prompt will be: - "Complete the following sentence: Today I feel". - More examples can be found at [templating](https://microsoft.github.io/autogen/docs/Use-Cases/enhanced_inference#templating). - use_cache (bool, Optional): Whether to use cached responses. - config_list (List, Optional): List of configurations for the completion to try. - The first one that does not raise an error will be used. - Only the differences from the default config need to be provided. - E.g., - - ```python - response = oai.Completion.create( - config_list=[ - { - "model": "gpt-4", - "api_key": os.environ.get("AZURE_OPENAI_API_KEY"), - "api_type": "azure", - "base_url": os.environ.get("AZURE_OPENAI_API_BASE"), - "api_version": "2024-02-01", - }, - { - "model": "gpt-3.5-turbo", - "api_key": os.environ.get("OPENAI_API_KEY"), - "api_type": "openai", - "base_url": "https://api.openai.com/v1", - }, - { - "model": "llama-7B", - "base_url": "http://127.0.0.1:8080", - "api_type": "openai", - } - ], - prompt="Hi", - ) - ``` - - filter_func (Callable, Optional): A function that takes in the context and the response and returns a boolean to indicate whether the response is valid. E.g., - - ```python - def yes_or_no_filter(context, config, response): - return context.get("yes_or_no_choice", False) is False or any( - text in ["Yes.", "No."] for text in oai.Completion.extract_text(response) - ) - ``` - - raise_on_ratelimit_or_timeout (bool, Optional): Whether to raise RateLimitError or Timeout when all configs fail. - When set to False, -1 will be returned when all configs fail. - allow_format_str_template (bool, Optional): Whether to allow format string template in the config. - **config: Configuration for the openai API call. This is used as parameters for calling openai API. - The "prompt" or "messages" parameter can contain a template (str or Callable) which will be instantiated with the context. - Besides the parameters for the openai API call, it can also contain: - - `max_retry_period` (int): the total time (in seconds) allowed for retrying failed requests. - - `retry_wait_time` (int): the time interval to wait (in seconds) before retrying a failed request. - - `cache_seed` (int) for the cache. This is useful when implementing "controlled randomness" for the completion. - - Returns: - Responses from OpenAI API, with additional fields. - - `cost`: the total cost. - When `config_list` is provided, the response will contain a few more fields: - - `config_id`: the index of the config in the config_list that is used to generate the response. - - `pass_filter`: whether the response passes the filter function. None if no filter is provided. - """ - print( - "Completion.create is deprecated in pyautogen v0.2 and openai>=1. " - "The new openai requires initiating a client for inference. " - "Please refer to https://microsoft.github.io/autogen/docs/Use-Cases/enhanced_inference#api-unification" - ) - - # Warn if a config list was provided but was empty - if isinstance(config_list, list) and len(config_list) == 0: - print( - "Completion was provided with a config_list, but the list was empty. Adopting default OpenAI behavior, which reads from the 'model' parameter instead." - ) - - if config_list: - last = len(config_list) - 1 - cost = 0 - for i, each_config in enumerate(config_list): - base_config = config.copy() - base_config["allow_format_str_template"] = allow_format_str_template - base_config.update(each_config) - if i < last and filter_func is None and "max_retry_period" not in base_config: - # max_retry_period = 0 to avoid retrying when no filter is given - base_config["max_retry_period"] = 0 - try: - response = cls.create( - context, - use_cache, - raise_on_ratelimit_or_timeout=i < last or raise_on_ratelimit_or_timeout, - **base_config, - ) - if response == -1: - return response - pass_filter = filter_func is None or filter_func(context=context, response=response) - if pass_filter or i == last: - response["cost"] = cost + response["cost"] - response["config_id"] = i - response["pass_filter"] = pass_filter - return response - cost += response["cost"] - except (AuthenticationError, RateLimitError, Timeout, BadRequestError): - if i == last: - raise - params = cls._construct_params(context, config, allow_format_str_template=allow_format_str_template) - if not use_cache: - return cls._get_response( - params, raise_on_ratelimit_or_timeout=raise_on_ratelimit_or_timeout, use_cache=False - ) - cache_seed = cls.cache_seed - if "cache_seed" in params: - cls.set_cache(params.pop("cache_seed")) - with diskcache.Cache(cls.cache_path) as cls._cache: - cls.set_cache(cache_seed) - return cls._get_response(params, raise_on_ratelimit_or_timeout=raise_on_ratelimit_or_timeout) - - @classmethod - def instantiate( - cls, - template: Union[str, None], - context: Optional[Dict] = None, - allow_format_str_template: Optional[bool] = False, - ): - if not context or template is None: - return template - if isinstance(template, str): - return template.format(**context) if allow_format_str_template else template - return template(context) - - @classmethod - def _construct_params(cls, context, config, prompt=None, messages=None, allow_format_str_template=False): - params = config.copy() - model = config["model"] - prompt = config.get("prompt") if prompt is None else prompt - messages = config.get("messages") if messages is None else messages - # either "prompt" should be in config (for being compatible with non-chat models) - # or "messages" should be in config (for tuning chat models only) - if prompt is None and (model in cls.chat_models or issubclass(cls, ChatCompletion)): - if messages is None: - raise ValueError("Either prompt or messages should be in config for chat models.") - if prompt is None: - params["messages"] = ( - [ - ( - { - **m, - "content": cls.instantiate(m["content"], context, allow_format_str_template), - } - if m.get("content") - else m - ) - for m in messages - ] - if context - else messages - ) - elif model in cls.chat_models or issubclass(cls, ChatCompletion): - # convert prompt to messages - params["messages"] = [ - { - "role": "user", - "content": cls.instantiate(prompt, context, allow_format_str_template), - }, - ] - params.pop("prompt", None) - else: - params["prompt"] = cls.instantiate(prompt, context, allow_format_str_template) - return params - - @classmethod - def test( - cls, - data, - eval_func=None, - use_cache=True, - agg_method="avg", - return_responses_and_per_instance_result=False, - logging_level=logging.WARNING, - **config, - ): - """Evaluate the responses created with the config for the OpenAI API call. - - Args: - data (list): The list of test data points. - eval_func (Callable): The evaluation function for responses per data instance. - The function should take a list of responses and a data point as input, - and return a dict of metrics. You need to either provide a valid callable - eval_func; or do not provide one (set None) but call the test function after - calling the tune function in which a eval_func is provided. - In the latter case we will use the eval_func provided via tune function. - Defaults to None. - - ```python - def eval_func(responses, **data): - solution = data["solution"] - success_list = [] - n = len(responses) - for i in range(n): - response = responses[i] - succeed = is_equiv_chain_of_thought(response, solution) - success_list.append(succeed) - return { - "expected_success": 1 - pow(1 - sum(success_list) / n, n), - "success": any(s for s in success_list), - } - ``` - use_cache (bool, Optional): Whether to use cached responses. Defaults to True. - agg_method (str, Callable or a dict of Callable): Result aggregation method (across - multiple instances) for each of the metrics. Defaults to 'avg'. - An example agg_method in str: - - ```python - agg_method = 'median' - ``` - An example agg_method in a Callable: - - ```python - agg_method = np.median - ``` - - An example agg_method in a dict of Callable: - - ```python - agg_method={'median_success': np.median, 'avg_success': np.mean} - ``` - - return_responses_and_per_instance_result (bool): Whether to also return responses - and per instance results in addition to the aggregated results. - logging_level (optional): logging level. Defaults to logging.WARNING. - **config (dict): parameters passed to the openai api call `create()`. - - Returns: - None when no valid eval_func is provided in either test or tune; - Otherwise, a dict of aggregated results, responses and per instance results if `return_responses_and_per_instance_result` is True; - Otherwise, a dict of aggregated results (responses and per instance results are not returned). - """ - result_agg, responses_list, result_list = {}, [], [] - metric_keys = None - cost = 0 - for i, data_i in enumerate(data): - print(f"evaluating data instance {i}") - response = cls.create(data_i, use_cache, **config) - cost += response["cost"] - # evaluate the quality of the responses - responses = cls.extract_text_or_function_call(response) - if eval_func is not None: - metrics = eval_func(responses, **data_i) - elif hasattr(cls, "_eval_func"): - metrics = cls._eval_func(responses, **data_i) - else: - print( - "Please either provide a valid eval_func or do the test after the tune function is called." - ) - return - if not metric_keys: - metric_keys = [] - for k in metrics.keys(): - try: - _ = float(metrics[k]) - metric_keys.append(k) - except ValueError: - pass - result_list.append(metrics) - if return_responses_and_per_instance_result: - responses_list.append(responses) - if isinstance(agg_method, str): - if agg_method in ["avg", "average"]: - for key in metric_keys: - result_agg[key] = np.mean([r[key] for r in result_list]) - elif agg_method == "median": - for key in metric_keys: - result_agg[key] = np.median([r[key] for r in result_list]) - else: - print( - f"Aggregation method {agg_method} not supported. Please write your own aggregation method as a callable(s)." - ) - elif callable(agg_method): - for key in metric_keys: - result_agg[key] = agg_method([r[key] for r in result_list]) - elif isinstance(agg_method, dict): - for key in metric_keys: - metric_agg_method = agg_method[key] - if not callable(metric_agg_method): - error_msg = "please provide a callable for each metric" - raise AssertionError(error_msg) - result_agg[key] = metric_agg_method([r[key] for r in result_list]) - else: - raise ValueError( - "agg_method needs to be a string ('avg' or 'median'),\ - or a callable, or a dictionary of callable." - ) - # should we also return the result_list and responses_list or not? - if "cost" not in result_agg: - result_agg["cost"] = cost - if "inference_cost" not in result_agg: - result_agg["inference_cost"] = cost / len(data) - if return_responses_and_per_instance_result: - return result_agg, result_list, responses_list - else: - return result_agg - - @classmethod - def cost(cls, response: dict): - """Compute the cost of an API call. - - Args: - response (dict): The response from OpenAI API. - - Returns: - The cost in USD. 0 if the model is not supported. - """ - model = response.get("model") - if model not in cls.price1K: - return 0 - # raise ValueError(f"Unknown model: {model}") - usage = response["usage"] - n_input_tokens = usage["prompt_tokens"] - n_output_tokens = usage.get("completion_tokens", 0) - price1K = cls.price1K[model] - if isinstance(price1K, tuple): - return (price1K[0] * n_input_tokens + price1K[1] * n_output_tokens) / 1000 - return price1K * (n_input_tokens + n_output_tokens) / 1000 - - @classmethod - def extract_text(cls, response: dict) -> List[str]: - """Extract the text from a completion or chat response. - - Args: - response (dict): The response from OpenAI API. - - Returns: - A list of text in the responses. - """ - choices = response["choices"] - if "text" in choices[0]: - return [choice["text"] for choice in choices] - return [choice["message"].get("content", "") for choice in choices] - - @classmethod - def extract_text_or_function_call(cls, response: dict) -> List[str]: - """Extract the text or function calls from a completion or chat response. - - Args: - response (dict): The response from OpenAI API. - - Returns: - A list of text or function calls in the responses. - """ - choices = response["choices"] - if "text" in choices[0]: - return [choice["text"] for choice in choices] - return [ - choice["message"] if "function_call" in choice["message"] else choice["message"].get("content", "") - for choice in choices - ] - - @classmethod - @property - def logged_history(cls) -> Dict: - """Return the book keeping dictionary.""" - return cls._history_dict - - @classmethod - def print_usage_summary(cls) -> Dict: - """Return the usage summary.""" - if cls._history_dict is None: - print("No usage summary available.", flush=True) - - token_count_summary = defaultdict(lambda: {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}) - - if not cls._history_compact: - source = cls._history_dict.values() - total_cost = sum(msg_pair["response"]["cost"] for msg_pair in source) - else: - # source = cls._history_dict["token_count"] - # total_cost = sum(cls._history_dict['cost']) - total_cost = sum(sum(value_list["cost"]) for value_list in cls._history_dict.values()) - source = ( - token_data for value_list in cls._history_dict.values() for token_data in value_list["token_count"] - ) - - for entry in source: - if not cls._history_compact: - model = entry["response"]["model"] - token_data = entry["response"]["usage"] - else: - model = entry["model"] - token_data = entry - - token_count_summary[model]["prompt_tokens"] += token_data["prompt_tokens"] - token_count_summary[model]["completion_tokens"] += token_data["completion_tokens"] - token_count_summary[model]["total_tokens"] += token_data["total_tokens"] - - print(f"Total cost: {total_cost}", flush=True) - for model, counts in token_count_summary.items(): - print( - f"Token count summary for model {model}: prompt_tokens: {counts['prompt_tokens']}, completion_tokens: {counts['completion_tokens']}, total_tokens: {counts['total_tokens']}", - flush=True, - ) - - @classmethod - def start_logging( - cls, history_dict: Optional[Dict] = None, compact: Optional[bool] = True, reset_counter: Optional[bool] = True - ): - """Start book keeping. - - Args: - history_dict (Dict): A dictionary for book keeping. - If no provided, a new one will be created. - compact (bool): Whether to keep the history dictionary compact. - Compact history contains one key per conversation, and the value is a dictionary - like: - ```python - { - "create_at": [0, 1], - "cost": [0.1, 0.2], - } - ``` - where "created_at" is the index of API calls indicating the order of all the calls, - and "cost" is the cost of each call. This example shows that the conversation is based - on two API calls. The compact format is useful for condensing the history of a conversation. - If compact is False, the history dictionary will contain all the API calls: the key - is the index of the API call, and the value is a dictionary like: - ```python - { - "request": request_dict, - "response": response_dict, - } - ``` - where request_dict is the request sent to OpenAI API, and response_dict is the response. - For a conversation containing two API calls, the non-compact history dictionary will be like: - ```python - { - 0: { - "request": request_dict_0, - "response": response_dict_0, - }, - 1: { - "request": request_dict_1, - "response": response_dict_1, - }, - ``` - The first request's messages plus the response is equal to the second request's messages. - For a conversation with many turns, the non-compact history dictionary has a quadratic size - while the compact history dict has a linear size. - reset_counter (bool): whether to reset the counter of the number of API calls. - """ - print( - "logging via Completion.start_logging is deprecated in pyautogen v0.2. " - "logging via OpenAIWrapper will be added back in a future release." - ) - cls._history_dict = {} if history_dict is None else history_dict - cls._history_compact = compact - cls._count_create = 0 if reset_counter or cls._count_create is None else cls._count_create - - @classmethod - def stop_logging(cls): - """End book keeping.""" - cls._history_dict = cls._count_create = None - diff --git a/train_methods/legacy_autogen/legacy_autogen_conversable_agent.py b/train_methods/legacy_autogen/legacy_autogen_conversable_agent.py index 872086b..b130248 100644 --- a/train_methods/legacy_autogen/legacy_autogen_conversable_agent.py +++ b/train_methods/legacy_autogen/legacy_autogen_conversable_agent.py @@ -15,13 +15,11 @@ from train_methods.legacy_autogen.cache import AbstractCache from train_methods.legacy_autogen.chat import ChatResult, a_initiate_chats, initiate_chats, _post_process_carryover_item, consolidate_chat_info from train_methods.legacy_autogen.client import ModelClient, OpenAIWrapper -from train_methods.legacy_autogen.coding import CodeExecutor from train_methods.legacy_autogen.stream import IOStream from train_methods.legacy_autogen.utils import ( content_str, load_basemodels_if_needed, - serialize_to_str, - get_function_schema + serialize_to_str ) __all__ = ("ConversableAgent",) @@ -370,11 +368,8 @@ def description(self, description: str): self._description = description @property - def code_executor(self) -> CodeExecutor | None: - """The code executor used by this agent. Returns None if code execution is disabled.""" - if not hasattr(self, "_code_executor"): - return None - return self._code_executor + def code_executor(self) -> None: + return None def register_reply( self, @@ -2284,80 +2279,7 @@ async def _a_wrapped_func(*args, **kwargs): return wrapped_func - def register_for_llm( - self, - *, - name: str | None = None, - description: str | None = None, - api_style: Literal["function", "tool"] = "tool", - ) -> Callable[[F], F]: - """Decorator factory for registering a function to be used by an agent. - - It's return value is used to decorate a function to be registered to the agent. The function uses type hints to - specify the arguments and return type. The function name is used as the default name for the function, - but a custom name can be provided. The function description is used to describe the function in the - agent's configuration. - - Args: - name (optional(str)): name of the function. If None, the function name will be used (default: None). - description (optional(str)): description of the function (default: None). It is mandatory - for the initial decorator, but the following ones can omit it. - api_style: (literal): the API style for function call. - For Azure OpenAI API, use version 2023-12-01-preview or later. - `"function"` style will be deprecated. For earlier version use - `"function"` if `"tool"` doesn't work. - See [Azure OpenAI documentation](https://learn.microsoft.com/en-us/azure/ai-services/openai/how-to/function-calling?tabs=python) for details. - - Returns: - The decorator for registering a function to be used by an agent. - """ - - def _decorator(func: F) -> F: - """Decorator for registering a function to be used by an agent. - - Args: - func: the function to be registered. - - Returns: - The function to be registered, with the _description attribute set to the function description. - - Raises: - ValueError: if the function description is not provided and not propagated by a previous decorator. - RuntimeError: if the LLM config is not set up before registering a function. - - """ - # name can be overwritten by the parameter, by default it is the same as function name - if name: - func._name = name - elif not hasattr(func, "_name"): - func._name = func.__name__ - - # description is propagated from the previous decorator, but it is mandatory for the first one - if description: - func._description = description - else: - if not hasattr(func, "_description"): - raise ValueError("Function description is required, none found.") - - # get JSON schema for the function - f = get_function_schema(func, name=func._name, description=func._description) - - # register the function to the agent if there is LLM config, raise an exception otherwise - if self.llm_config is None: - raise RuntimeError("LLM config must be setup before registering a function for LLM.") - - if api_style == "function": - f = f["function"] - self.update_function_signature(f, is_remove=False) - elif api_style == "tool": - self.update_tool_signature(f, is_remove=False) - else: - raise ValueError(f"Unsupported API style: {api_style}") - - return func - - return _decorator - + def register_model_client(self, model_client_cls: ModelClient, **kwargs): """Register a model client. diff --git a/train_methods/legacy_autogen/utils.py b/train_methods/legacy_autogen/utils.py index dbfd697..edb385e 100644 --- a/train_methods/legacy_autogen/utils.py +++ b/train_methods/legacy_autogen/utils.py @@ -1,38 +1,12 @@ import functools import inspect import json -import os -import pathlib -import re -import string -import subprocess -import sys -import time -import venv -from concurrent.futures import ThreadPoolExecutor, TimeoutError -from hashlib import md5 -from types import SimpleNamespace from typing import Callable, Literal, TypedDict, Any, Annotated, ForwardRef from typing_extensions import get_args, get_origin -import docker -from pydantic import BaseModel, Field, TypeAdapter +from pydantic import BaseModel from pydantic._internal._typing_extra import try_eval_type -from pydantic.json_schema import JsonSchemaValue - -from train_methods.legacy_autogen.completion import Completion - -SENTINEL = object() -DEFAULT_MODEL = "gpt-4" -FAST_MODEL = "gpt-3.5-turbo" -CODE_BLOCK_PATTERN = r"```[ \t]*(\w+)?[ \t]*\r?\n(.*?)\r?\n[ \t]*```" -WORKING_DIR = os.path.join(os.path.dirname(os.path.realpath(__file__)), "extensions") -UNKNOWN = "unknown" -TIMEOUT_MSG = "Timeout" -DEFAULT_TIMEOUT = 600 -WIN32 = sys.platform == "win32" -PATH_SEPARATOR = WIN32 and "\\" or "/" -PYTHON_VARIANTS = ["python", "Python", "py"] + class UserMessageTextContentPart(TypedDict): type: Literal["text"] @@ -83,536 +57,6 @@ def content_str(content: str | list[UserMessageTextContentPart | UserMessageImag raise ValueError(f"Wrong content format: unknown type {item['type']} within the content") return rst - - -_IMPROVE_FUNCTION_CONFIG = { - "prompt": """Improve the function '{func_name}' to achieve the objective '{objective}'. -The current implementation of the function is as follows: -{file_string}""", - "model": DEFAULT_MODEL, - "request_timeout": 600, -} - - -def improve_function(file_name, func_name, objective, **config): - """(openai<1) Improve the function to achieve the objective.""" - params = {**_IMPROVE_FUNCTION_CONFIG, **config} - # read the entire file into a str - with open(file_name, "r") as f: - file_string = f.read() - response = Completion.create( - {"func_name": func_name, "objective": objective, "file_string": file_string}, **params - ) - return Completion.extract_text(response)[0], response["cost"] - - -_IMPROVE_CODE_CONFIG = { - "prompt": """Analyze the code in the following files and return a list of suggestions for improvement{followup}, to achieve the objective of '{objective}'. -{code} -""", - "model": DEFAULT_MODEL, - "request_timeout": 900, -} - - -def improve_code(files, objective, suggest_only=True, **config): - """(openai<1) Improve the code to achieve a given objective. - - Args: - files (list): A list of file names containing the source code. - objective (str): The objective to achieve. - suggest_only (bool): Whether to return only the suggestions or the improved code. - config (Optional, dict): The configuration for the API call. - - Returns: - str: The improved code if suggest_only=False; a list of suggestions if suggest_only=True (default). - float: The cost of the generation. - """ - code = "" - for file_name in files: - # read the entire file into a string - with open(file_name, "r") as f: - file_string = f.read() - code += f"""{file_name}: -{file_string} - -""" - params = {**_IMPROVE_CODE_CONFIG, **config} - followup = "" if suggest_only else " followed by the improved code" - response = Completion.create({"objective": objective, "code": code, "followup": followup}, **params) - return Completion.extract_text(response)[0], response["cost"] - - -def timeout_handler(signum, frame): - raise TimeoutError("Timed out!") - - -def get_powershell_command(): - try: - result = subprocess.run(["powershell", "$PSVersionTable.PSVersion.Major"], capture_output=True, text=True) - if result.returncode == 0: - return "powershell" - except (FileNotFoundError, NotADirectoryError): - # This means that 'powershell' command is not found so now we try looking for 'pwsh' - try: - result = subprocess.run( - ["pwsh", "-Command", "$PSVersionTable.PSVersion.Major"], capture_output=True, text=True - ) - if result.returncode == 0: - return "pwsh" - except FileExistsError as e: - raise FileNotFoundError( - "Neither powershell.exe nor pwsh.exe is present in the system. " - "Please install PowerShell and try again. " - ) from e - except NotADirectoryError as e: - raise NotADirectoryError( - "PowerShell is either not installed or its path is not given " - "properly in the environment variable PATH. Please check the " - "path and try again. " - ) from e - except PermissionError as e: - raise PermissionError("No permission to run powershell.") from e - - -def _cmd(lang: str) -> str: - if lang in PYTHON_VARIANTS: - return "python" - if lang.startswith("python") or lang in ["bash", "sh"]: - return lang - if lang in ["shell"]: - return "sh" - if lang == "javascript": - return "node" - if lang in ["ps1", "pwsh", "powershell"]: - powershell_command = get_powershell_command() - return powershell_command - - raise NotImplementedError(f"{lang} not recognized in code execution") - - -def is_docker_running() -> bool: - """Check if docker is running. - - Returns: - bool: True if docker is running; False otherwise. - """ - try: - client = docker.from_env() - client.ping() - return True - except docker.errors.DockerException: - return False - - -def in_docker_container() -> bool: - """Check if the code is running in a docker container. - - Returns: - bool: True if the code is running in a docker container; False otherwise. - """ - return os.path.exists("/.dockerenv") - - - -def _sanitize_filename_for_docker_tag(filename: str) -> str: - """Convert a filename to a valid docker tag. - See https://docs.docker.com/engine/reference/commandline/tag/ for valid tag - format. - - Args: - filename (str): The filename to be converted. - - Returns: - str: The sanitized Docker tag. - """ - # Replace any character not allowed with an underscore - allowed_chars = set(string.ascii_letters + string.digits + "_.-") - sanitized = "".join(char if char in allowed_chars else "_" for char in filename) - - # Ensure it does not start with a period or a dash - if sanitized.startswith(".") or sanitized.startswith("-"): - sanitized = "_" + sanitized[1:] - - # Truncate if longer than 128 characters - return sanitized[:128] - - -def execute_code( - code: str | None = None, - timeout: int | None = None, - filename: str | None = None, - work_dir: str | None = None, - use_docker: list[str] | str | bool = SENTINEL, - lang: str | None = "python", -) -> tuple[int, str, str | None]: - """Execute code in a docker container. - This function is not tested on MacOS. - - Args: - code (Optional, str): The code to execute. - If None, the code from the file specified by filename will be executed. - Either code or filename must be provided. - timeout (Optional, int): The maximum execution time in seconds. - If None, a default timeout will be used. The default timeout is 600 seconds. On Windows, the timeout is not enforced when use_docker=False. - filename (Optional, str): The file name to save the code or where the code is stored when `code` is None. - If None, a file with a randomly generated name will be created. - The randomly generated file will be deleted after execution. - The file name must be a relative path. Relative paths are relative to the working directory. - work_dir (Optional, str): The working directory for the code execution. - If None, a default working directory will be used. - The default working directory is the "extensions" directory under - "path_to_autogen". - use_docker (list, str or bool): The docker image to use for code execution. - Default is True, which means the code will be executed in a docker container. A default list of images will be used. - If a list or a str of image name(s) is provided, the code will be executed in a docker container - with the first image successfully pulled. - If False, the code will be executed in the current environment. - Expected behaviour: - - If `use_docker` is not set (i.e. left default to True) or is explicitly set to True and the docker package is available, the code will run in a Docker container. - - If `use_docker` is not set (i.e. left default to True) or is explicitly set to True but the Docker package is missing or docker isn't running, an error will be raised. - - If `use_docker` is explicitly set to False, the code will run natively. - If the code is executed in the current environment, - the code must be trusted. - lang (Optional, str): The language of the code. Default is "python". - - Returns: - int: 0 if the code executes successfully. - str: The error message if the code fails to execute; the stdout otherwise. - image: The docker image name after container run when docker is used. - """ - if all((code is None, filename is None)): - error_msg = f"Either {code=} or {filename=} must be provided." - raise AssertionError(error_msg) - - running_inside_docker = in_docker_container() - docker_running = is_docker_running() - - timeout = timeout or DEFAULT_TIMEOUT - original_filename = filename - if WIN32 and lang in ["sh", "shell"] and (not use_docker): - lang = "ps1" - if filename is None: - code_hash = md5(code.encode()).hexdigest() - # create a file with a automatically generated name - filename = f"tmp_code_{code_hash}.{'py' if lang.startswith('python') else lang}" - if work_dir is None: - work_dir = WORKING_DIR - - filepath = os.path.join(work_dir, filename) - file_dir = os.path.dirname(filepath) - os.makedirs(file_dir, exist_ok=True) - - if code is not None: - with open(filepath, "w", encoding="utf-8") as fout: - fout.write(code) - - if not use_docker or running_inside_docker: - # already running in a docker container - cmd = [ - sys.executable if lang.startswith("python") else _cmd(lang), - f".\\{filename}" if WIN32 else filename, - ] - with ThreadPoolExecutor(max_workers=1) as executor: - future = executor.submit( - subprocess.run, - cmd, - cwd=work_dir, - capture_output=True, - text=True, - ) - try: - result = future.result(timeout=timeout) - except TimeoutError: - if original_filename is None: - os.remove(filepath) - return 1, TIMEOUT_MSG, None - if original_filename is None: - os.remove(filepath) - if result.returncode: - logs = result.stderr - if original_filename is None: - abs_path = str(pathlib.Path(filepath).absolute()) - logs = logs.replace(str(abs_path), "").replace(filename, "") - else: - abs_path = str(pathlib.Path(work_dir).absolute()) + PATH_SEPARATOR - logs = logs.replace(str(abs_path), "") - else: - logs = result.stdout - return result.returncode, logs, None - - # create a docker client - if use_docker and not docker_running: - raise RuntimeError( - "Docker package is missing or docker is not running. Please make sure docker is running or set use_docker=False." - ) - - client = docker.from_env() - - image_list = ( - ["python:3-slim", "python:3", "python:3-windowsservercore"] - if use_docker is True - else [use_docker] if isinstance(use_docker, str) else use_docker - ) - for image in image_list: - # check if the image exists - try: - client.images.get(image) - break - except docker.errors.ImageNotFound: - # pull the image - print("Pulling image", image) - try: - client.images.pull(image) - break - except docker.errors.DockerException: - print("Failed to pull image", image) - # get a randomized str based on current time to wrap the exit code - exit_code_str = f"exitcode{time.time()}" - abs_path = pathlib.Path(work_dir).absolute() - cmd = [ - "sh", - "-c", - f'{_cmd(lang)} "{filename}"; exit_code=$?; echo -n {exit_code_str}; echo -n $exit_code; echo {exit_code_str}', - ] - # create a docker container - container = client.containers.run( - image, - command=cmd, - working_dir="/workspace", - detach=True, - # get absolute path to the working directory - volumes={abs_path: {"bind": "/workspace", "mode": "rw"}}, - ) - start_time = time.time() - while container.status != "exited" and time.time() - start_time < timeout: - # Reload the container object - container.reload() - if container.status != "exited": - container.stop() - container.remove() - if original_filename is None: - os.remove(filepath) - return 1, TIMEOUT_MSG, image - # get the container logs - logs = container.logs().decode("utf-8").rstrip() - # commit the image - tag = _sanitize_filename_for_docker_tag(filename) - container.commit(repository="python", tag=tag) - # remove the container - container.remove() - # check if the code executed successfully - exit_code = container.attrs["State"]["ExitCode"] - if exit_code == 0: - # extract the exit code from the logs - pattern = re.compile(f"{exit_code_str}(\\d+){exit_code_str}") - match = pattern.search(logs) - exit_code = 1 if match is None else int(match.group(1)) - # remove the exit code from the logs - logs = logs if match is None else pattern.sub("", logs) - - if original_filename is None: - os.remove(filepath) - if exit_code: - logs = logs.replace(f"/workspace/{filename if original_filename is None else ''}", "") - # return the exit code, logs and image - return exit_code, logs, f"python:{tag}" - - -_GENERATE_ASSERTIONS_CONFIG = { - "prompt": """Given the signature and docstring, write the exactly same number of assertion(s) for the provided example(s) in the docstring, without assertion messages. - -func signature: -{definition} -assertions:""", - "model": FAST_MODEL, - "max_tokens": 256, - "stop": "\n\n", -} - - -def generate_assertions(definition: str, **config) -> tuple[str, float]: - """(openai<1) Generate assertions for a function. - - Args: - definition (str): The function definition, including the signature and docstr. - config (Optional, dict): The configuration for the API call. - - Returns: - str: The generated assertions. - float: The cost of the generation. - """ - params = {**_GENERATE_ASSERTIONS_CONFIG, **config} - response = Completion.create( - {"definition": definition}, - **params, - ) - assertions = Completion.extract_text(response)[0] - return assertions, response["cost"] - - -def _remove_check(response): - """Remove the check function from the response.""" - # find the position of the check function - pos = response.find("def check(") - if pos == -1: - return response - return response[:pos] - - -def eval_function_completions( - responses: list[str], - definition: str, - test: str | None = None, - entry_point: str | None = None, - assertions: str | Callable[[str], tuple[str, float]] | None = None, - timeout: float | None = 3, - use_docker: bool | None = True, -) -> dict: - """(openai<1) Select a response from a list of responses for the function completion task (using generated assertions), and/or evaluate if the task is successful using a gold test. - - Args: - responses (list): The list of responses. - definition (str): The input definition. - test (Optional, str): The test code. - entry_point (Optional, str): The name of the function. - assertions (Optional, str or Callable): The assertion code which serves as a filter of the responses, or an assertion generator. - When provided, only the responses that pass the assertions will be considered for the actual test (if provided). - timeout (Optional, float): The timeout for executing the code. - - Returns: - dict: The success metrics. - """ - n = len(responses) - if assertions is None: - # no assertion filter - success_list = [] - for i in range(n): - response = _remove_check(responses[i]) - code = ( - f"{response}\n{test}\ncheck({entry_point})" - if response.startswith("def") - else f"{definition}{response}\n{test}\ncheck({entry_point})" - ) - success = execute_code(code, timeout=timeout, use_docker=use_docker)[0] == 0 - success_list.append(success) - return { - "expected_success": 1 - pow(1 - sum(success_list) / n, n), - "success": any(s for s in success_list), - } - if callable(assertions) and n > 1: - # assertion generator - assertions, gen_cost = assertions(definition) - else: - assertions, gen_cost = None, 0 - if n > 1 or test is None: - for i in range(n): - response = responses[i] = _remove_check(responses[i]) - code = ( - f"{response}\n{assertions}" if response.startswith("def") else f"{definition}{response}\n{assertions}" - ) - succeed_assertions = execute_code(code, timeout=timeout, use_docker=use_docker)[0] == 0 - if succeed_assertions: - break - else: - # just test, no need to check assertions - succeed_assertions = False - i, response = 0, responses[0] - if test is None: - # no test code - return { - "index_selected": i, - "succeed_assertions": succeed_assertions, - "gen_cost": gen_cost, - "assertions": assertions, - } - code_test = ( - f"{response}\n{test}\ncheck({entry_point})" - if response.startswith("def") - else f"{definition}{response}\n{test}\ncheck({entry_point})" - ) - success = execute_code(code_test, timeout=timeout, use_docker=use_docker)[0] == 0 - return { - "index_selected": i, - "succeed_assertions": succeed_assertions, - "success": success, - "gen_cost": gen_cost, - "assertions": assertions, - } - - -_FUNC_COMPLETION_PROMPT = "# Python 3{definition}" -_FUNC_COMPLETION_STOP = ["\nclass", "\ndef", "\nif", "\nprint"] -_IMPLEMENT_CONFIGS = [ - {"model": FAST_MODEL, "prompt": _FUNC_COMPLETION_PROMPT, "temperature": 0, "cache_seed": 0}, - {"model": FAST_MODEL, "prompt": _FUNC_COMPLETION_PROMPT, "stop": _FUNC_COMPLETION_STOP, "n": 7, "cache_seed": 0}, - {"model": DEFAULT_MODEL, "prompt": _FUNC_COMPLETION_PROMPT, "temperature": 0, "cache_seed": 1}, - {"model": DEFAULT_MODEL, "prompt": _FUNC_COMPLETION_PROMPT, "stop": _FUNC_COMPLETION_STOP, "n": 2, "cache_seed": 2}, - {"model": DEFAULT_MODEL, "prompt": _FUNC_COMPLETION_PROMPT, "stop": _FUNC_COMPLETION_STOP, "n": 1, "cache_seed": 2}, -] - - -class PassAssertionFilter: - def __init__(self, assertions): - self._assertions = assertions - self.cost = 0 - self.metrics = self.responses = None - - def pass_assertions(self, context, response, **_): - """(openai<1) Check if the response passes the assertions.""" - responses = Completion.extract_text(response) - metrics = eval_function_completions(responses, context["definition"], assertions=self._assertions) - self._assertions = metrics["assertions"] - self.cost += metrics["gen_cost"] - self.metrics = metrics - self.responses = responses - return metrics["succeed_assertions"] - - -def implement( - definition: str, - configs: list[dict] | None = None, - assertions: str | Callable[[str], tuple[str, float]] | None = generate_assertions, -) -> tuple[str, float]: - """(openai<1) Implement a function from a definition. - - Args: - definition (str): The function definition, including the signature and docstr. - configs (list): The list of configurations for completion. - assertions (Optional, str or Callable): The assertion code which serves as a filter of the responses, or an assertion generator. - - Returns: - str: The implementation. - float: The cost of the implementation. - int: The index of the configuration which generates the implementation. - """ - cost = 0 - configs = configs or _IMPLEMENT_CONFIGS - if len(configs) > 1 and callable(assertions): - assertions, cost = assertions(definition) - assertion_filter = PassAssertionFilter(assertions) - response = Completion.create( - {"definition": definition}, config_list=configs, filter_func=assertion_filter.pass_assertions - ) - cost += assertion_filter.cost + response["cost"] - return assertion_filter.responses[assertion_filter.metrics["index_selected"]], cost, response["config_id"] - - -def create_virtual_env(dir_path: str, **env_args) -> SimpleNamespace: - """Creates a python virtual environment and returns the context. - - Args: - dir_path (str): Directory path where the env will be created. - **env_args: Any extra args to pass to the `EnvBuilder` - - Returns: - SimpleNamespace: the virtual env context object.""" - if not env_args: - env_args = {"with_pip": True} - env_builder = venv.EnvBuilder(**env_args) - env_builder.create(dir_path) - return env_builder.ensure_directories(dir_path) - def get_typed_annotation(annotation: Any, globalns: dict[str, Any]) -> Any: """Get the type annotation of a parameter. @@ -736,210 +180,3 @@ def serialize_to_str(x: Any) -> str: else: return json.dumps(x, ensure_ascii=False) -def get_required_params(typed_signature: inspect.Signature) -> list[str]: - """Get the required parameters of a function - - Args: - signature: The signature of the function as returned by inspect.signature - - Returns: - A list of the required parameters of the function - """ - return [k for k, v in typed_signature.parameters.items() if v.default == inspect.Signature.empty] - -def get_default_values(typed_signature: inspect.Signature) -> dict[str, Any]: - """Get default values of parameters of a function - - Args: - signature: The signature of the function as returned by inspect.signature - - Returns: - A dictionary of the default values of the parameters of the function - """ - return {k: v.default for k, v in typed_signature.parameters.items() if v.default != inspect.Signature.empty} - -def get_typed_return_annotation(call: Callable[..., Any]) -> Any: - """Get the return annotation of a function. - - Args: - call: The function to get the return annotation for - - Returns: - The return annotation of the function - """ - signature = inspect.signature(call) - annotation = signature.return_annotation - - if annotation is inspect.Signature.empty: - return None - - globalns = getattr(call, "__globals__", {}) - return get_typed_annotation(annotation, globalns) - -def get_missing_annotations(typed_signature: inspect.Signature, required: list[str]) -> tuple[set[str], set[str]]: - """Get the missing annotations of a function - - Ignores the parameters with default values as they are not required to be annotated, but logs a warning. - Args: - typed_signature: The signature of the function with type annotations - required: The required parameters of the function - - Returns: - A set of the missing annotations of the function - """ - all_missing = {k for k, v in typed_signature.parameters.items() if v.annotation is inspect.Signature.empty} - missing = all_missing.intersection(set(required)) - unannotated_with_default = all_missing.difference(missing) - return missing, unannotated_with_default - -class Parameters(BaseModel): - """Parameters of a function as defined by the OpenAI API""" - - type: Literal["object"] = "object" - properties: dict[str, JsonSchemaValue] - required: list[str] - - -class Function(BaseModel): - """A function as defined by the OpenAI API""" - - description: Annotated[str, Field(description="Description of the function")] - name: Annotated[str, Field(description="Name of the function")] - parameters: Annotated[Parameters, Field(description="Parameters of the function")] - - -class ToolFunction(BaseModel): - """A function under tool as defined by the OpenAI API.""" - - type: Literal["function"] = "function" - function: Annotated[Function, Field(description="Function under tool")] - - -def get_parameter_json_schema(k: str, v: Any, default_values: dict[str, Any]) -> JsonSchemaValue: - """Get a JSON schema for a parameter as defined by the OpenAI API - - Args: - k: The name of the parameter - v: The type of the parameter - default_values: The default values of the parameters of the function - - Returns: - A Pydanitc model for the parameter - """ - - def type2description(k: str, v: Annotated[type[Any], str] | type[Any]) -> str: - # handles Annotated - if hasattr(v, "__metadata__"): - retval = v.__metadata__[0] - if isinstance(retval, str): - return retval - else: - raise ValueError(f"Invalid description {retval} for parameter {k}, should be a string.") - else: - return k - - schema = TypeAdapter(v).json_schema() - if k in default_values: - dv = default_values[k] - schema["default"] = dv - - schema["description"] = type2description(k, v) - - return schema - -def get_parameters( - required: list[str], - param_annotations: dict[str, Annotated[type[Any], str] | type[Any]], - default_values: dict[str, Any], -) -> Parameters: - """Get the parameters of a function as defined by the OpenAI API - - Args: - required: The required parameters of the function - hints: The type hints of the function as returned by typing.get_type_hints - - Returns: - A Pydantic model for the parameters of the function - """ - return Parameters( - properties={ - k: get_parameter_json_schema(k, v, default_values) - for k, v in param_annotations.items() - if v is not inspect.Signature.empty - }, - required=required, - ) - -def get_function_schema(f: Callable[..., Any], *, name: str | None = None, description: str) -> dict[str, Any]: - """Get a JSON schema for a function as defined by the OpenAI API - - Args: - f: The function to get the JSON schema for - name: The name of the function - description: The description of the function - - Returns: - A JSON schema for the function - - Raises: - TypeError: If the function is not annotated - - Examples: - - ```python - def f(a: Annotated[str, "Parameter a"], b: int = 2, c: Annotated[float, "Parameter c"] = 0.1) -> None: - pass - - get_function_schema(f, description="function f") - - # {'type': 'function', - # 'function': {'description': 'function f', - # 'name': 'f', - # 'parameters': {'type': 'object', - # 'properties': {'a': {'type': 'str', 'description': 'Parameter a'}, - # 'b': {'type': 'int', 'description': 'b'}, - # 'c': {'type': 'float', 'description': 'Parameter c'}}, - # 'required': ['a']}}} - ``` - - """ - typed_signature = get_typed_signature(f) - required = get_required_params(typed_signature) - default_values = get_default_values(typed_signature) - param_annotations = get_param_annotations(typed_signature) - return_annotation = get_typed_return_annotation(f) - missing, unannotated_with_default = get_missing_annotations(typed_signature, required) - - if return_annotation is None: - print( - f"The return type of the function '{f.__name__}' is not annotated. Although annotating it is " - + "optional, the function should return either a string, a subclass of 'pydantic.BaseModel'." - ) - - if unannotated_with_default != set(): - unannotated_with_default_s = [f"'{k}'" for k in sorted(unannotated_with_default)] - print( - f"The following parameters of the function '{f.__name__}' with default values are not annotated: " - + f"{', '.join(unannotated_with_default_s)}." - ) - - if missing != set(): - missing_s = [f"'{k}'" for k in sorted(missing)] - raise TypeError( - f"All parameters of the function '{f.__name__}' without default values must be annotated. " - + f"The annotations are missing for the following parameters: {', '.join(missing_s)}" - ) - - fname = name if name else f.__name__ - - parameters = get_parameters(required, param_annotations, default_values=default_values) - - function = ToolFunction( - function=Function( - description=description, - name=fname, - parameters=parameters, - ) - ) - - return function.model_dump() From 24c0f4152c6980b5e9a8cb5d0fbe6ed772add53f Mon Sep 17 00:00:00 2001 From: fmuuly <64724985+fmp453@users.noreply.github.com> Date: Sun, 1 Feb 2026 18:27:05 +0900 Subject: [PATCH 18/25] remove unused codes --- train_methods/legacy_autogen/client.py | 48 -------------------------- train_methods/legacy_autogen/coding.py | 31 ----------------- 2 files changed, 79 deletions(-) delete mode 100644 train_methods/legacy_autogen/coding.py diff --git a/train_methods/legacy_autogen/client.py b/train_methods/legacy_autogen/client.py index bae85a9..ef9f445 100644 --- a/train_methods/legacy_autogen/client.py +++ b/train_methods/legacy_autogen/client.py @@ -601,54 +601,6 @@ def _register_default_client(self, config: dict[str, Any], openai_config: dict[s self._configure_azure_openai(config, openai_config) client = AzureOpenAI(**openai_config) self._clients.append(OpenAIClient(client)) - elif api_type is not None and api_type.startswith("cerebras"): - if cerebras_import_exception: - raise ImportError("Please install `cerebras_cloud_sdk` to use Cerebras OpenAI API.") - client = CerebrasClient(**openai_config) - self._clients.append(client) - elif api_type is not None and api_type.startswith("google"): - if gemini_import_exception: - raise ImportError("Please install `google-generativeai` to use Google OpenAI API.") - client = GeminiClient(**openai_config) - self._clients.append(client) - elif api_type is not None and api_type.startswith("anthropic"): - if "api_key" not in config: - self._configure_openai_config_for_bedrock(config, openai_config) - if anthropic_import_exception: - raise ImportError("Please install `anthropic` to use Anthropic API.") - client = AnthropicClient(**openai_config) - self._clients.append(client) - elif api_type is not None and api_type.startswith("mistral"): - if mistral_import_exception: - raise ImportError("Please install `mistralai` to use the Mistral.AI API.") - client = MistralAIClient(**openai_config) - self._clients.append(client) - elif api_type is not None and api_type.startswith("together"): - if together_import_exception: - raise ImportError("Please install `together` to use the Together.AI API.") - client = TogetherClient(**openai_config) - self._clients.append(client) - elif api_type is not None and api_type.startswith("groq"): - if groq_import_exception: - raise ImportError("Please install `groq` to use the Groq API.") - client = GroqClient(**openai_config) - self._clients.append(client) - elif api_type is not None and api_type.startswith("cohere"): - if cohere_import_exception: - raise ImportError("Please install `cohere` to use the Cohere API.") - client = CohereClient(**openai_config) - self._clients.append(client) - elif api_type is not None and api_type.startswith("ollama"): - if ollama_import_exception: - raise ImportError("Please install with `[ollama]` option to use the Ollama API.") - client = OllamaClient(**openai_config) - self._clients.append(client) - elif api_type is not None and api_type.startswith("bedrock"): - self._configure_openai_config_for_bedrock(config, openai_config) - if bedrock_import_exception: - raise ImportError("Please install `boto3` to use the Amazon Bedrock API.") - client = BedrockClient(**openai_config) - self._clients.append(client) else: client = OpenAI(**openai_config) self._clients.append(OpenAIClient(client)) diff --git a/train_methods/legacy_autogen/coding.py b/train_methods/legacy_autogen/coding.py deleted file mode 100644 index c1f4bca..0000000 --- a/train_methods/legacy_autogen/coding.py +++ /dev/null @@ -1,31 +0,0 @@ -import base64 -import json -import os -import re -import inspect -import importlib -import subprocess -import sys -import uuid -import warnings -from dataclasses import dataclass, field -from pathlib import Path -from hashlib import md5 -from string import Template -from importlib.abc import SourceLoader -from queue import Empty -from textwrap import indent, dedent -from types import SimpleNamespace -from typing import Protocol, Literal, TypedDict, Mapping, Any, ClassVar, Callable, Generic, TypeVar -from typing_extensions import ParamSpec - -from jupyter_client import KernelManager -from jupyter_client.kernelspec import KernelSpecManager -from pydantic import BaseModel, Field, field_validator - - - -A = ParamSpec("A") -T = TypeVar("T") -P = ParamSpec("P") - From 509bf0d407a23571dd297181a45cd7f5fcf7c210 Mon Sep 17 00:00:00 2001 From: fmuuly <64724985+fmp453@users.noreply.github.com> Date: Sun, 1 Feb 2026 18:34:07 +0900 Subject: [PATCH 19/25] remove unused codes --- train_methods/legacy_autogen/client.py | 102 +----- .../legacy_autogen/legacy_autogen.py | 333 +----------------- .../legacy_autogen_conversable_agent.py | 77 +--- 3 files changed, 4 insertions(+), 508 deletions(-) diff --git a/train_methods/legacy_autogen/client.py b/train_methods/legacy_autogen/client.py index ef9f445..0280dca 100644 --- a/train_methods/legacy_autogen/client.py +++ b/train_methods/legacy_autogen/client.py @@ -4,7 +4,6 @@ from typing import Protocol, Any, Callable -import tiktoken from openai import APIError, APITimeoutError, AzureOpenAI, OpenAI from openai.resources import Completions from openai.types.chat import ChatCompletion @@ -102,99 +101,6 @@ def get_key(config: dict[str, Any]) -> str: config.pop(key) return json.dumps(config, sort_keys=True) -def _num_token_from_text(text: str, model: str = "gpt-3.5-turbo-0613"): - """Return the number of tokens used by a string.""" - try: - encoding = tiktoken.encoding_for_model(model) - except KeyError: - print(f"Model {model} not found. Using cl100k_base encoding.") - encoding = tiktoken.get_encoding("cl100k_base") - return len(encoding.encode(text)) - -def _num_token_from_messages(messages: list | dict, model="gpt-3.5-turbo-0613"): - """Return the number of tokens used by a list of messages. - - retrieved from https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb/ - """ - if isinstance(messages, dict): - messages = [messages] - - try: - encoding = tiktoken.encoding_for_model(model) - except KeyError: - print(f"Model {model} not found. Using cl100k_base encoding.") - encoding = tiktoken.get_encoding("cl100k_base") - if model in { - "gpt-3.5-turbo-0613", - "gpt-3.5-turbo-16k-0613", - "gpt-4-0314", - "gpt-4-32k-0314", - "gpt-4-0613", - "gpt-4-32k-0613", - }: - tokens_per_message = 3 - tokens_per_name = 1 - elif model == "gpt-3.5-turbo-0301": - tokens_per_message = 4 # every message follows <|start|>{role/name}\n{content}<|end|>\n - tokens_per_name = -1 # if there's a name, the role is omitted - elif "gpt-3.5-turbo" in model: - print("gpt-3.5-turbo may update over time. Returning num tokens assuming gpt-3.5-turbo-0613.") - return _num_token_from_messages(messages, model="gpt-3.5-turbo-0613") - elif "gpt-4" in model: - print("gpt-4 may update over time. Returning num tokens assuming gpt-4-0613.") - return _num_token_from_messages(messages, model="gpt-4-0613") - elif "gemini" in model: - print("Gemini is not supported in tiktoken. Returning num tokens assuming gpt-4-0613.") - return _num_token_from_messages(messages, model="gpt-4-0613") - elif "claude" in model: - print("Claude is not supported in tiktoken. Returning num tokens assuming gpt-4-0613.") - return _num_token_from_messages(messages, model="gpt-4-0613") - elif "mistral-" in model or "mixtral-" in model: - print("Mistral.AI models are not supported in tiktoken. Returning num tokens assuming gpt-4-0613.") - return _num_token_from_messages(messages, model="gpt-4-0613") - else: - raise NotImplementedError( - f"""_num_token_from_messages() is not implemented for model {model}. See https://github.com/openai/openai-python/blob/main/chatml.md for information on how messages are converted to tokens.""" - ) - num_tokens = 0 - for message in messages: - num_tokens += tokens_per_message - for key, value in message.items(): - if value is None: - continue - - # function calls - if not isinstance(value, str): - try: - value = json.dumps(value) - except TypeError: - print( - f"Value {value} is not a string and cannot be converted to json. It is a type: {type(value)} Skipping." - ) - continue - - num_tokens += len(encoding.encode(value)) - if key == "name": - num_tokens += tokens_per_name - num_tokens += 3 # every reply is primed with <|start|>assistant<|message|> - return num_tokens - -def count_token(input: str | list | dict, model: str = "gpt-3.5-turbo-0613") -> int: - """Count number of tokens used by an OpenAI model. - Args: - input: (str, list, dict): Input to the model. - model: (str): Model name. - - Returns: - int: Number of tokens from the input. - """ - if isinstance(input, str): - return _num_token_from_text(input, model=model) - elif isinstance(input, list) or isinstance(input, dict): - return _num_token_from_messages(input, model=model) - else: - raise ValueError(f"input must be str, list or dict, but we got {type(input)}") - class PlaceHolderClient: def __init__(self, config): @@ -389,13 +295,7 @@ def create(self, params: dict[str, Any]) -> ChatCompletion: iostream.print("\033[0m\n") # Prepare the final ChatCompletion object based on the accumulated data - model = chunk.model.replace("gpt-35", "gpt-3.5") # hack for Azure API - try: - prompt_tokens = count_token(params["messages"], model) - except NotImplementedError as e: - # Catch token calculation error if streaming with customized models. - print(str(e)) - prompt_tokens = 0 + prompt_tokens = 0 response = ChatCompletion( id=chunk.id, model=chunk.model, diff --git a/train_methods/legacy_autogen/legacy_autogen.py b/train_methods/legacy_autogen/legacy_autogen.py index 0f27dd3..bee239d 100644 --- a/train_methods/legacy_autogen/legacy_autogen.py +++ b/train_methods/legacy_autogen/legacy_autogen.py @@ -910,42 +910,6 @@ def last_speaker(self) -> Agent: send the message to all other agents in the group chat. So, when an agent receives a message, it will always be from the group chat manager. With this property, the agent receiving the message can know who actually sent the message. - - Example: - ```python - from autogen import ConversableAgent - from autogen import GroupChat, GroupChatManager - - - def print_messages(recipient, messages, sender, config): - # Print the message immediately - print( - f"Sender: {sender.name} | Recipient: {recipient.name} | Message: {messages[-1].get('content')}" - ) - print(f"Real Sender: {sender.last_speaker.name}") - assert sender.last_speaker.name in messages[-1].get("content") - return False, None # Required to ensure the agent communication flow continues - - - agent_a = ConversableAgent("agent A", default_auto_reply="I'm agent A.") - agent_b = ConversableAgent("agent B", default_auto_reply="I'm agent B.") - agent_c = ConversableAgent("agent C", default_auto_reply="I'm agent C.") - for agent in [agent_a, agent_b, agent_c]: - agent.register_reply( - [ConversableAgent, None], reply_func=print_messages, config=None - ) - group_chat = GroupChat( - [agent_a, agent_b, agent_c], - messages=[], - max_round=6, - speaker_selection_method="random", - allow_repeat_speaker=True, - ) - chat_manager = GroupChatManager(group_chat) - groupchat_result = agent_a.initiate_chat( - chat_manager, message="Hi, there, I'm agent A." - ) - ``` """ return self._last_speaker @@ -1101,302 +1065,7 @@ async def a_run_chat( a.previous_cache = None return True, None - def resume( - self, - messages: Union[list[dict], str], - remove_termination_string: Union[str, Callable[[str], str]] = None, - silent: bool | None = False, - ) -> tuple[ConversableAgent, dict]: - """Resumes a group chat using the previous messages as a starting point. Requires the agents, group chat, and group chat manager to be established - as per the original group chat. - - Args: - - messages Union[list[dict], str]: The content of the previous chat's messages, either as a Json string or a list of message dictionaries. - - remove_termination_string (str or function): Remove the termination string from the last message to prevent immediate termination - If a string is provided, this string will be removed from last message. - If a function is provided, the last message will be passed to this function. - - silent (bool or None): (Experimental) whether to print the messages for this conversation. Default is False. - - Returns: - - Tuple[ConversableAgent, Dict]: A tuple containing the last agent who spoke and their message - """ - - # Convert messages from string to messages list, if needed - if isinstance(messages, str): - messages = self.messages_from_string(messages) - elif isinstance(messages, list) and all(isinstance(item, dict) for item in messages): - messages = deepcopy(messages) - else: - raise Exception("Messages is not of type str or list[dict]") - - # Clean up the objects, ensuring there are no messages in the agents and group chat - - # Clear agent message history - for agent in self._groupchat.agents: - if isinstance(agent, ConversableAgent): - agent.clear_history() - - # Clear Manager message history - self.clear_history() - - # Clear GroupChat messages - self._groupchat.reset() - - # Validation of message and agents - - try: - self._valid_resume_messages(messages) - except: - raise - - # Load the messages into the group chat - for i, message in enumerate(messages): - if "name" in message: - message_speaker_agent = self._groupchat.agent_by_name(message["name"]) - else: - # If there's no name, assign the group chat manager (this is an indication the ChatResult messages was used instead of groupchat.messages as state) - message_speaker_agent = self - message["name"] = self.name - - # If it wasn't an agent speaking, it may be the manager - if not message_speaker_agent and message["name"] == self.name: - message_speaker_agent = self - - # Add previous messages to each agent (except the last message, as we'll kick off the conversation with it) - if i != len(messages) - 1: - for agent in self._groupchat.agents: - self.send(message, self._groupchat.agent_by_name(agent.name), request_reply=False, silent=True) - - # Add previous message to the new groupchat, if it's an admin message the name may not match so add the message directly - if message_speaker_agent: - self._groupchat.append(message, message_speaker_agent) - else: - self._groupchat.messages.append(message) - - # Last speaker agent - last_speaker_name = message["name"] - - # Last message to check for termination (we could avoid this by ignoring termination check for resume in the future) - last_message = message - - # Get last speaker as an agent - previous_last_agent = self._groupchat.agent_by_name(name=last_speaker_name) - - # If we didn't match a last speaker agent, we check that it's the group chat's admin name and assign the manager, if so - if not previous_last_agent and ( - last_speaker_name == self._groupchat.admin_name or last_speaker_name == self.name - ): - previous_last_agent = self - - # Termination removal and check - self._process_resume_termination(remove_termination_string, messages) - - if not silent: - iostream = IOStream.get_default() - iostream.print( - f"Prepared group chat with {len(messages)} messages, the last speaker is", - colored(last_speaker_name, "yellow"), - flush=True, - ) - - # Update group chat settings for resuming - self._groupchat.send_introductions = False - - return previous_last_agent, last_message - - async def a_resume( - self, - messages: Union[list[dict], str], - remove_termination_string: Union[str, Callable[[str], str]] = None, - silent: bool | None = False, - ) -> tuple[ConversableAgent, dict]: - """Resumes a group chat using the previous messages as a starting point, asynchronously. Requires the agents, group chat, and group chat manager to be established - as per the original group chat. - - Args: - - messages Union[list[dict], str]: The content of the previous chat's messages, either as a Json string or a list of message dictionaries. - - remove_termination_string (str or function): Remove the termination string from the last message to prevent immediate termination - If a string is provided, this string will be removed from last message. - If a function is provided, the last message will be passed to this function, and the function returns the string after processing. - - silent (bool or None): (Experimental) whether to print the messages for this conversation. Default is False. - - Returns: - - Tuple[ConversableAgent, Dict]: A tuple containing the last agent who spoke and their message - """ - - # Convert messages from string to messages list, if needed - if isinstance(messages, str): - messages = self.messages_from_string(messages) - elif isinstance(messages, list) and all(isinstance(item, dict) for item in messages): - messages = deepcopy(messages) - else: - raise Exception("Messages is not of type str or list[dict]") - - # Clean up the objects, ensuring there are no messages in the agents and group chat - - # Clear agent message history - for agent in self._groupchat.agents: - if isinstance(agent, ConversableAgent): - agent.clear_history() - - # Clear Manager message history - self.clear_history() - - # Clear GroupChat messages - self._groupchat.reset() - - # Validation of message and agents - - try: - self._valid_resume_messages(messages) - except: - raise - - # Load the messages into the group chat - for i, message in enumerate(messages): - if "name" in message: - message_speaker_agent = self._groupchat.agent_by_name(message["name"]) - else: - # If there's no name, assign the group chat manager (this is an indication the ChatResult messages was used instead of groupchat.messages as state) - message_speaker_agent = self - message["name"] = self.name - - # If it wasn't an agent speaking, it may be the manager - if not message_speaker_agent and message["name"] == self.name: - message_speaker_agent = self - - # Add previous messages to each agent (except the last message, as we'll kick off the conversation with it) - if i != len(messages) - 1: - for agent in self._groupchat.agents: - await self.a_send( - message, self._groupchat.agent_by_name(agent.name), request_reply=False, silent=True - ) - - # Add previous message to the new groupchat, if it's an admin message the name may not match so add the message directly - if message_speaker_agent: - self._groupchat.append(message, message_speaker_agent) - else: - self._groupchat.messages.append(message) - - # Last speaker agent - last_speaker_name = message["name"] - - # Last message to check for termination (we could avoid this by ignoring termination check for resume in the future) - last_message = message - - # Get last speaker as an agent - previous_last_agent = self._groupchat.agent_by_name(name=last_speaker_name) - - # If we didn't match a last speaker agent, we check that it's the group chat's admin name and assign the manager, if so - if not previous_last_agent and ( - last_speaker_name == self._groupchat.admin_name or last_speaker_name == self.name - ): - previous_last_agent = self - - # Termination removal and check - self._process_resume_termination(remove_termination_string, messages) - - if not silent: - iostream = IOStream.get_default() - iostream.print( - f"Prepared group chat with {len(messages)} messages, the last speaker is", - colored(last_speaker_name, "yellow"), - flush=True, - ) - - # Update group chat settings for resuming - self._groupchat.send_introductions = False - - return previous_last_agent, last_message - - def _valid_resume_messages(self, messages: list[dict]): - """Validates the messages used for resuming - - args: - messages (list[dict]): list of messages to resume with - - returns: - - bool: Whether they are valid for resuming - """ - # Must have messages to start with, otherwise they should run run_chat - if not messages: - raise Exception( - "Cannot resume group chat as no messages were provided. Use GroupChatManager.run_chat or ConversableAgent.initiate_chat to start a new chat." - ) - - # Check that all agents in the chat messages exist in the group chat - for message in messages: - if message.get("name"): - if ( - not self._groupchat.agent_by_name(message["name"]) - and not message["name"] == self._groupchat.admin_name # ignore group chat's name - and not message["name"] == self.name # ignore group chat manager's name - ): - raise Exception(f"Agent name in message doesn't exist as agent in group chat: {message['name']}") - - def _process_resume_termination( - self, remove_termination_string: str | Callable[[str], str], messages: list[dict] - ): - """Removes termination string, if required, and checks if termination may occur. - - args: - remove_termination_string (str or function): Remove the termination string from the last message to prevent immediate termination - If a string is provided, this string will be removed from last message. - If a function is provided, the last message will be passed to this function, and the function returns the string after processing. - - returns: - None - """ - - last_message = messages[-1] - - # Replace any given termination string in the last message - if isinstance(remove_termination_string, str): - - def _remove_termination_string(content: str) -> str: - return content.replace(remove_termination_string, "") - - else: - _remove_termination_string = remove_termination_string - - if _remove_termination_string: - if messages[-1].get("content"): - messages[-1]["content"] = _remove_termination_string(messages[-1]["content"]) - - # Check if the last message meets termination (if it has one) - if self._is_termination_msg: - if self._is_termination_msg(last_message): - print("WARNING: Last message meets termination criteria and this may terminate the chat.") - - def messages_from_string(self, message_string: str) -> list[dict]: - """Reads the saved state of messages in Json format for resume and returns as a messages list - - args: - - message_string: Json string, the saved state - - returns: - - list[dict]: List of messages - """ - try: - state = json.loads(message_string) - except json.JSONDecodeError: - raise Exception("Messages string is not a valid JSON string") - - return state - - def messages_to_string(self, messages: list[dict]) -> str: - """Converts the provided messages into a Json string that can be used for resuming the chat. - The state is made up of a list of messages - - args: - - messages (list[dict]): set of messages to convert to a string - - returns: - - str: Json representation of the messages which can be persisted for resuming later - """ - - return json.dumps(messages) - + def _raise_exception_on_async_reply_functions(self) -> None: """Raise an exception if any async reply functions are registered. diff --git a/train_methods/legacy_autogen/legacy_autogen_conversable_agent.py b/train_methods/legacy_autogen/legacy_autogen_conversable_agent.py index b130248..8a43d36 100644 --- a/train_methods/legacy_autogen/legacy_autogen_conversable_agent.py +++ b/train_methods/legacy_autogen/legacy_autogen_conversable_agent.py @@ -177,75 +177,6 @@ def update_system_message(self, system_message: str) -> None: """ -def gather_usage_summary(agents: list[Agent]) -> dict[dict[str, dict], dict[str, dict]]: - r"""Gather usage summary from all agents. - - Args: - agents: (list): List of agents. - - Returns: - dictionary: A dictionary containing two keys: - - "usage_including_cached_inference": Cost information on the total usage, including the tokens in cached inference. - - "usage_excluding_cached_inference": Cost information on the usage of tokens, excluding the tokens in cache. No larger than "usage_including_cached_inference". - - Example: - - ```python - { - "usage_including_cached_inference" : { - "total_cost": 0.0006090000000000001, - "gpt-35-turbo": { - "cost": 0.0006090000000000001, - "prompt_tokens": 242, - "completion_tokens": 123, - "total_tokens": 365 - }, - }, - - "usage_excluding_cached_inference" : { - "total_cost": 0.0006090000000000001, - "gpt-35-turbo": { - "cost": 0.0006090000000000001, - "prompt_tokens": 242, - "completion_tokens": 123, - "total_tokens": 365 - }, - } - } - ``` - - Note: - - If none of the agents incurred any cost (not having a client), then the usage_including_cached_inference and usage_excluding_cached_inference will be `{'total_cost': 0}`. - """ - - def aggregate_summary(usage_summary: dict[str, Any], agent_summary: dict[str, Any]) -> None: - if agent_summary is None: - return - usage_summary["total_cost"] += agent_summary.get("total_cost", 0) - for model, data in agent_summary.items(): - if model != "total_cost": - if model not in usage_summary: - usage_summary[model] = data.copy() - else: - usage_summary[model]["cost"] += data.get("cost", 0) - usage_summary[model]["prompt_tokens"] += data.get("prompt_tokens", 0) - usage_summary[model]["completion_tokens"] += data.get("completion_tokens", 0) - usage_summary[model]["total_tokens"] += data.get("total_tokens", 0) - - usage_including_cached_inference = {"total_cost": 0} - usage_excluding_cached_inference = {"total_cost": 0} - - for agent in agents: - if getattr(agent, "client", None): - aggregate_summary(usage_including_cached_inference, agent.client.total_usage_summary) - aggregate_summary(usage_excluding_cached_inference, agent.client.actual_usage_summary) - - return { - "usage_including_cached_inference": usage_including_cached_inference, - "usage_excluding_cached_inference": usage_excluding_cached_inference, - } - class ConversableAgent(LLMAgent): """A class for generic conversable agents which can be configured as assistant or user proxy. @@ -367,10 +298,6 @@ def description(self, description: str): """Set the description of the agent.""" self._description = description - @property - def code_executor(self) -> None: - return None - def register_reply( self, trigger: Type[Agent] | str | Agent | Callable[[Agent], bool] | list, @@ -1190,7 +1117,7 @@ def my_message(sender: ConversableAgent, recipient: ConversableAgent, context: d chat_result = ChatResult( chat_history=self.chat_messages[recipient], summary=summary, - cost=gather_usage_summary([self, recipient]), + cost=None, human_input=self._human_input, ) return chat_result @@ -1256,7 +1183,7 @@ async def a_initiate_chat( chat_result = ChatResult( chat_history=self.chat_messages[recipient], summary=summary, - cost=gather_usage_summary([self, recipient]), + cost=None, human_input=self._human_input, ) return chat_result From 37f716c37903222814d7de10b2a45bd2d4cd16f4 Mon Sep 17 00:00:00 2001 From: fmuuly <64724985+fmp453@users.noreply.github.com> Date: Mon, 2 Feb 2026 17:00:19 +0900 Subject: [PATCH 20/25] minor fix --- train_methods/legacy_autogen/cache.py | 27 ++--------------- train_methods/legacy_autogen/chat.py | 29 +++---------------- .../legacy_autogen/legacy_autogen.py | 21 ++++---------- train_methods/utils_cogfd.py | 9 ------ 4 files changed, 12 insertions(+), 74 deletions(-) diff --git a/train_methods/legacy_autogen/cache.py b/train_methods/legacy_autogen/cache.py index d086c20..bdb79c0 100644 --- a/train_methods/legacy_autogen/cache.py +++ b/train_methods/legacy_autogen/cache.py @@ -435,26 +435,6 @@ def cache_factory( Returns: An instance of RedisCache, DiskCache, or CosmosDBCache. - Examples: - - Creating a Redis cache - - ```python - redis_cache = cache_factory("myseed", "redis://localhost:6379/0") - ``` - Creating a Disk cache - - ```python - disk_cache = cache_factory("myseed", None) - ``` - - Creating a Cosmos DB cache: - ```python - cosmos_cache = cache_factory("myseed", cosmosdb_config={ - "connection_string": "your_connection_string", - "database_id": "your_database_id", - "container_id": "your_container_id"} - ) ``` """ @@ -468,7 +448,6 @@ def cache_factory( path = os.path.join(cache_path_root, str(seed)) return DiskCache(os.path.join(".", path)) - class Cache(AbstractCache): """ A wrapper class for managing cache configuration and instances. @@ -566,9 +545,9 @@ def __init__(self, config: dict[str, Any]): # create cache instance self.cache = CacheFactory.cache_factory( seed=self.config["cache_seed"], - redis_url=self.config.get("redis_url"), - cache_path_root=self.config.get("cache_path_root"), - cosmosdb_config=self.config.get("cosmos_db_config"), + redis_url=self.config.get("redis_url", ""), + cache_path_root=self.config.get("cache_path_root", ""), + cosmosdb_config=self.config.get("cosmos_db_config", ""), ) def __enter__(self) -> "Cache": diff --git a/train_methods/legacy_autogen/chat.py b/train_methods/legacy_autogen/chat.py index c90dfd0..5879aa5 100644 --- a/train_methods/legacy_autogen/chat.py +++ b/train_methods/legacy_autogen/chat.py @@ -137,11 +137,6 @@ def _post_process_carryover_item(carryover_item): def __post_carryover_processing(chat_info: dict[str, Any]) -> None: iostream = IOStream.get_default() - if "message" not in chat_info: - warnings.warn( - "message is not provided in a chat_queue entry. input() will be called to get the initial message.", - UserWarning, - ) print_carryover = ( ("\n").join([_post_process_carryover_item(t) for t in chat_info["carryover"]]) if isinstance(chat_info["carryover"], list) @@ -221,9 +216,7 @@ def initiate_chats(chat_queue: list[dict[str, Any]]) -> list[ChatResult]: chat_info["carryover"] = _chat_carryover + [ r.summary for i, r in enumerate(finished_chats) if i not in finished_chat_indexes_to_exclude_from_carryover ] - - if not chat_info.get("silent", False): - __post_carryover_processing(chat_info) + __post_carryover_processing(chat_info) sender = chat_info["sender"] chat_res = sender.initiate_chat(**chat_info) @@ -231,16 +224,11 @@ def initiate_chats(chat_queue: list[dict[str, Any]]) -> list[ChatResult]: return finished_chats -def __system_now_str(): - ct = datetime.datetime.now() - return f" System time at {ct}. " - - def _on_chat_future_done(chat_future: asyncio.Future, chat_id: int): """ Update ChatResult when async Task for Chat is completed. """ - print(f"Update chat {chat_id} result on task completion." + __system_now_str()) + print(f"Update chat {chat_id} result on task completion. System time at {datetime.datetime.now()}.") chat_result = chat_future.result() chat_result.chat_id = chat_id @@ -251,7 +239,7 @@ async def _dependent_chat_future( """ Create an async Task for each chat. """ - print(f"Create Task for chat {chat_id}." + __system_now_str()) + print(f"Create Task for chat {chat_id}. System time at {datetime.datetime.now()}.") _chat_carryover = chat_info.get("carryover", []) finished_chat_indexes_to_exclude_from_carryover = chat_info.get( "finished_chat_indexes_to_exclude_from_carryover", [] @@ -280,20 +268,11 @@ async def _dependent_chat_future( chat_res_future = asyncio.create_task(sender.a_initiate_chat(**chat_info)) call_back_with_args = partial(_on_chat_future_done, chat_id=chat_id) chat_res_future.add_done_callback(call_back_with_args) - print(f"Task for chat {chat_id} created." + __system_now_str()) + print(f"Task for chat {chat_id} created. System time at {datetime.datetime.now()}.") return chat_res_future async def a_initiate_chats(chat_queue: list[dict[str, Any]]) -> dict[int, ChatResult]: - """(async) Initiate a list of chats. - - args: - - Please refer to `initiate_chats`. - - - returns: - - (dict): a dict of ChatId: ChatResult corresponding to the finished chats in the chat_queue. - """ consolidate_chat_info(chat_queue) _validate_recipients(chat_queue) chat_book = {chat_info["chat_id"]: chat_info for chat_info in chat_queue} diff --git a/train_methods/legacy_autogen/legacy_autogen.py b/train_methods/legacy_autogen/legacy_autogen.py index bee239d..91d4ed7 100644 --- a/train_methods/legacy_autogen/legacy_autogen.py +++ b/train_methods/legacy_autogen/legacy_autogen.py @@ -1,13 +1,12 @@ """Legacy autogen (ver 2.0) for cogfd """ -import json + import sys import random import re -from copy import deepcopy from dataclasses import dataclass, field -from typing import Any, Callable, Literal, Union +from typing import Any, Callable, Literal from termcolor import colored @@ -17,10 +16,6 @@ from train_methods.legacy_autogen.utils import content_str -class AgentNameConflict(Exception): - def __init__(self, msg: str = "Found multiple agents with the same name.", *args: Any, **kwargs: Any): - super().__init__(msg, *args, **kwargs) - class NoEligibleSpeaker(Exception): """Exception raised for early termination of a GroupChat.""" @@ -28,12 +23,6 @@ def __init__(self, message: str = "No eligible speakers."): self.message = message super().__init__(self.message) -class UndefinedNextAgent(Exception): - """Exception raised when the provided next agents list does not overlap with agents in the group.""" - - def __init__(self, message: str = "The provided agents list does not overlap with agents in the group."): - self.message = message - super().__init__(self.message) @dataclass class GroupChat: @@ -132,7 +121,7 @@ def agent_by_name( filtered_agents = [agent for agent in agents if agent.name == name] if raise_on_name_conflict and len(filtered_agents) > 1: - raise AgentNameConflict() + raise ValueError("Found multiple agents with the same name.") return filtered_agents[0] if filtered_agents else None @@ -152,7 +141,7 @@ def next_agent(self, agent: Agent, agents: list[Agent] | None = None) -> Agent: # Ensure the provided list of agents is a subset of self.agents if not set(agents).issubset(set(self.agents)): - raise UndefinedNextAgent() + raise ValueError("The provided agents list does not overlap with agents in the group.") # What index is the agent? (-1 if not present) idx = self.agent_names.index(agent.name) if agent.name in self.agent_names else -1 @@ -167,7 +156,7 @@ def next_agent(self, agent: Agent, agents: list[Agent] | None = None) -> Agent: return self.agents[(offset + i) % len(self.agents)] # Explicitly handle cases where no valid next agent exists in the provided subset. - raise UndefinedNextAgent() + raise ValueError("The provided agents list does not overlap with agents in the group.") def select_speaker_msg(self, agents: list[Agent] | None = None) -> str: """Return the system message for selecting the next speaker. This is always the *first* message in the context.""" diff --git a/train_methods/utils_cogfd.py b/train_methods/utils_cogfd.py index 04031fd..43d14f8 100644 --- a/train_methods/utils_cogfd.py +++ b/train_methods/utils_cogfd.py @@ -204,15 +204,6 @@ def generate_and_save_concept_graph( initial_message = f"X = {concept_combination_x}, Y = {combination_theme_y}" print(f"\n--- Starting chat for: '{initial_message}' ---") - - # Automatically trigger the chat to end after the initial response or based on specific conditions - def auto_end_chat(): - # Trigger to end the conversation after the response is received - print("Automatically ending the conversation.") - return "exit" # or any other appropriate method to end the conversation - - # Call the function after some condition or time has passed - auto_end_chat() final_graph_string = None parsed_graph = None From 1feb9943c482a5921f6bed361f25b7925a98d7e8 Mon Sep 17 00:00:00 2001 From: fmuuly <64724985+fmp453@users.noreply.github.com> Date: Mon, 2 Feb 2026 17:14:38 +0900 Subject: [PATCH 21/25] remove redis and azure-cosmos cache --- requirements.txt | 2 - train_methods/legacy_autogen/cache.py | 321 +----------------- train_methods/legacy_autogen/chat.py | 21 +- train_methods/legacy_autogen/client.py | 74 ++-- .../legacy_autogen_conversable_agent.py | 26 +- 5 files changed, 44 insertions(+), 400 deletions(-) diff --git a/requirements.txt b/requirements.txt index 7e5a398..91cf1fa 100644 --- a/requirements.txt +++ b/requirements.txt @@ -11,8 +11,6 @@ scikit-learn==1.5.2 termcolor==3.1.0 tiktoken==0.12.0 diskcache==5.6.3 -redis==7.0.0 -azure-cosmos==4.14.0 azure-identity==1.25.1 flaml==2.3.6 gdown==5.2.0 diff --git a/train_methods/legacy_autogen/cache.py b/train_methods/legacy_autogen/cache.py index bdb79c0..decfe5b 100644 --- a/train_methods/legacy_autogen/cache.py +++ b/train_methods/legacy_autogen/cache.py @@ -1,12 +1,9 @@ import os -import pickle +from pathlib import Path from types import TracebackType -from typing import Any, Protocol, Self, TypedDict +from typing import Any, Protocol, Self import diskcache -import redis -from azure.cosmos import CosmosClient, PartitionKey -from azure.cosmos.exceptions import CosmosResourceNotFoundError class AbstractCache(Protocol): """ @@ -161,291 +158,25 @@ def __exit__( """ self.close() -class RedisCache(AbstractCache): - """ - Implementation of AbstractCache using the Redis database. - - This class provides a concrete implementation of the AbstractCache - interface using the Redis database for caching data. - - Attributes: - seed (str | int): A seed or namespace used as a prefix for cache keys. - cache (redis.Redis): The Redis client used for caching. - - Methods: - __init__(self, seed, redis_url): Initializes the RedisCache with the given seed and Redis URL. - _prefixed_key(self, key): Internal method to get a namespaced cache key. - get(self, key, default=None): Retrieves an item from the cache. - set(self, key, value): Sets an item in the cache. - close(self): Closes the Redis client. - __enter__(self): Context management entry. - __exit__(self, exc_type, exc_value, traceback): Context management exit. - """ - - def __init__(self, seed: str | int, redis_url: str): - """ - Initialize the RedisCache instance. - - Args: - seed (str | int): A seed or namespace for the cache. This is used as a prefix for all cache keys. - redis_url (str): The URL for the Redis server. - - """ - self.seed = seed - self.cache = redis.Redis.from_url(redis_url) - - def _prefixed_key(self, key: str) -> str: - """ - Get a namespaced key for the cache. - - Args: - key (str): The original key. - - Returns: - str: The namespaced key. - """ - return f"autogen:{self.seed}:{key}" - - def get(self, key: str, default: Any | None = None) -> Any | None: - """ - Retrieve an item from the Redis cache. - - Args: - key (str): The key identifying the item in the cache. - default (optional): The default value to return if the key is not found. - Defaults to None. - - Returns: - The deserialized value associated with the key if found, else the default value. - """ - result = self.cache.get(self._prefixed_key(key)) - if result is None: - return default - return pickle.loads(result) - - def set(self, key: str, value: Any) -> None: - """ - Set an item in the Redis cache. - - Args: - key (str): The key under which the item is to be stored. - value: The value to be stored in the cache. - - Notes: - The value is serialized using pickle before being stored in Redis. - """ - serialized_value = pickle.dumps(value) - self.cache.set(self._prefixed_key(key), serialized_value) - - def close(self) -> None: - """ - Close the Redis client. - - Perform any necessary cleanup, such as closing network connections. - """ - self.cache.close() - - def __enter__(self) -> Self: - """ - Enter the runtime context related to the object. - - Returns: - self: The instance itself. - """ - return self - - def __exit__( - self, - exc_type: type[BaseException] | None, - exc_val: BaseException | None, - exc_tb: TracebackType | None - ) -> None: - """ - Exit the runtime context related to the object. - - Perform cleanup actions such as closing the Redis client. - - Args: - exc_type: The exception type if an exception was raised in the context. - exc_value: The exception value if an exception was raised in the context. - traceback: The traceback if an exception was raised in the context. - """ - self.close() - -class CosmosDBConfig(TypedDict, total=False): - connection_string: str - database_id: str - container_id: str - cache_seed: str | int | None - client: CosmosClient | None - -class CosmosDBCache(AbstractCache): - """ - Synchronous implementation of AbstractCache using Azure Cosmos DB NoSQL API. - - This class provides a concrete implementation of the AbstractCache - interface using Azure Cosmos DB for caching data, with synchronous operations. - - Attributes: - seed (str | int): A seed or namespace used as a partition key. - client (CosmosClient): The Cosmos DB client used for caching. - container: The container instance used for caching. - """ - - def __init__(self, seed: str | int, cosmosdb_config: CosmosDBConfig): - """ - Initialize the CosmosDBCache instance. - - Args: - seed (str | int): A seed or namespace for the cache, used as a partition key. - connection_string (str): The connection string for the Cosmos DB account. - container_id (str): The container ID to be used for caching. - client (Optional[CosmosClient]): An existing CosmosClient instance to be used for caching. - """ - self.seed = str(seed) - self.client = cosmosdb_config.get("client") or CosmosClient.from_connection_string( - cosmosdb_config["connection_string"] - ) - database_id = cosmosdb_config.get("database_id", "autogen_cache") - self.database = self.client.get_database_client(database_id) - container_id = cosmosdb_config.get("container_id") - self.container = self.database.create_container_if_not_exists( - id=container_id, partition_key=PartitionKey(path="/partitionKey") - ) - - @classmethod - def create_cache(cls, seed: str | int, cosmosdb_config: CosmosDBConfig): - """ - Factory method to create a CosmosDBCache instance based on the provided configuration. - This method decides whether to use an existing CosmosClient or create a new one. - """ - if "client" in cosmosdb_config and isinstance(cosmosdb_config["client"], CosmosClient): - return cls.from_existing_client(seed, **cosmosdb_config) - else: - return cls.from_config(seed, cosmosdb_config) - - @classmethod - def from_config(cls, seed: str | int, cosmosdb_config: CosmosDBConfig): - return cls(str(seed), cosmosdb_config) - - @classmethod - def from_connection_string(cls, seed: str | int, connection_string: str, database_id: str, container_id: str): - config = {"connection_string": connection_string, "database_id": database_id, "container_id": container_id} - return cls(str(seed), config) - - @classmethod - def from_existing_client(cls, seed: str | int, client: CosmosClient, database_id: str, container_id: str): - config = {"client": client, "database_id": database_id, "container_id": container_id} - return cls(str(seed), config) - - def get(self, key: str, default: Any | None = None) -> Any | None: - """ - Retrieve an item from the Cosmos DB cache. - - Args: - key (str): The key identifying the item in the cache. - default (optional): The default value to return if the key is not found. - - Returns: - The deserialized value associated with the key if found, else the default value. - """ - try: - response = self.container.read_item(item=key, partition_key=str(self.seed)) - return pickle.loads(response["data"]) - except CosmosResourceNotFoundError: - return default - except Exception as e: - # Log the exception or rethrow after logging if needed - # Consider logging or handling the error appropriately here - raise e - - def set(self, key: str, value: Any) -> None: - """ - Set an item in the Cosmos DB cache. - - Args: - key (str): The key under which the item is to be stored. - value: The value to be stored in the cache. - - Notes: - The value is serialized using pickle before being stored. - """ - try: - serialized_value = pickle.dumps(value) - item = {"id": key, "partitionKey": str(self.seed), "data": serialized_value} - self.container.upsert_item(item) - except Exception as e: - # Log or handle exception - raise e - - def close(self) -> None: - """ - Close the Cosmos DB client. - - Perform any necessary cleanup, such as closing network connections. - """ - # CosmosClient doesn"t require explicit close in the current SDK - # If you created the client inside this class, you should close it if necessary - pass - - def __enter__(self): - """ - Context management entry. - - Returns: - self: The instance itself. - """ - return self - - def __exit__( - self, - exc_type: type | None, - exc_value: Exception | None, - traceback: Any | None, - ) -> None: - """ - Context management exit. - - Perform cleanup actions such as closing the Cosmos DB client. - """ - self.close() - class CacheFactory: @staticmethod def cache_factory( seed: str | int, - redis_url: str | None = None, cache_path_root: str = ".cache", - cosmosdb_config: dict[str, Any] | None = None, ) -> AbstractCache: """ Factory function for creating cache instances. - This function decides whether to create a RedisCache, DiskCache, or CosmosDBCache instance - based on the provided parameters. If RedisCache is available and a redis_url is provided, - a RedisCache instance is created. If connection_string, database_id, and container_id - are provided, a CosmosDBCache is created. Otherwise, a DiskCache instance is used. - Args: seed (str | int): Used as a seed or namespace for the cache. - redis_url (str | None): URL for the Redis server. cache_path_root (str): Root path for the disk cache. - cosmosdb_config (Optional[Dict[str, str]]): Dictionary containing 'connection_string', 'database_id', and 'container_id' for Cosmos DB cache. Returns: - An instance of RedisCache, DiskCache, or CosmosDBCache. - - ``` + An instance of DiskCache """ - if redis_url: - return RedisCache(seed, redis_url) - - if cosmosdb_config: - return CosmosDBCache.create_cache(seed, cosmosdb_config) - # Default to DiskCache if neither Redis nor Cosmos DB configurations are provided - path = os.path.join(cache_path_root, str(seed)) + path = Path(cache_path_root, str(seed)) return DiskCache(os.path.join(".", path)) class Cache(AbstractCache): @@ -463,25 +194,9 @@ class Cache(AbstractCache): ALLOWED_CONFIG_KEYS = [ "cache_seed", - "redis_url", "cache_path_root", - "cosmos_db_config", ] - @staticmethod - def redis(cache_seed: str | int = 42, redis_url: str = "redis://localhost:6379/0") -> "Cache": - """ - Create a Redis cache instance. - - Args: - cache_seed (str | int, optional): A seed for the cache. Defaults to 42. - redis_url (str, optional): The URL for the Redis server. Defaults to "redis://localhost:6379/0". - - Returns: - Cache: A Cache instance configured for Redis. - """ - return Cache({"cache_seed": cache_seed, "redis_url": redis_url}) - @staticmethod def disk(cache_seed: str | int = 42, cache_path_root: str = ".cache") -> "Cache": """ @@ -496,32 +211,6 @@ def disk(cache_seed: str | int = 42, cache_path_root: str = ".cache") -> "Cache" """ return Cache({"cache_seed": cache_seed, "cache_path_root": cache_path_root}) - @staticmethod - def cosmos_db( - connection_string: str | None = None, - container_id: str | None = None, - cache_seed: str | int = 42, - client: Any | None = None, - ) -> "Cache": - """ - Create a Cosmos DB cache instance with 'autogen_cache' as database ID. - - Args: - connection_string (str, optional): Connection string to the Cosmos DB account. - container_id (str, optional): The container ID for the Cosmos DB account. - cache_seed (str | int, optional): A seed for the cache. - client: Optional[CosmosClient]: Pass an existing Cosmos DB client. - Returns: - Cache: A Cache instance configured for Cosmos DB. - """ - cosmos_db_config = { - "connection_string": connection_string, - "database_id": "autogen_cache", - "container_id": container_id, - "client": client, - } - return Cache({"cache_seed": str(cache_seed), "cosmos_db_config": cosmos_db_config}) - def __init__(self, config: dict[str, Any]): """ Initialize the Cache with the given configuration. @@ -545,9 +234,7 @@ def __init__(self, config: dict[str, Any]): # create cache instance self.cache = CacheFactory.cache_factory( seed=self.config["cache_seed"], - redis_url=self.config.get("redis_url", ""), cache_path_root=self.config.get("cache_path_root", ""), - cosmosdb_config=self.config.get("cosmos_db_config", ""), ) def __enter__(self) -> "Cache": diff --git a/train_methods/legacy_autogen/chat.py b/train_methods/legacy_autogen/chat.py index 5879aa5..323c637 100644 --- a/train_methods/legacy_autogen/chat.py +++ b/train_methods/legacy_autogen/chat.py @@ -1,6 +1,5 @@ import asyncio import datetime -import warnings from collections import defaultdict from dataclasses import dataclass from functools import partial @@ -32,17 +31,12 @@ def consolidate_chat_info(chat_info, uniform_sender=None) -> None: ), "llm client must be set in either the recipient or sender when summary_method is reflection_with_llm." -Prerequisite = tuple[int, int] - @dataclass class ChatResult: chat_id: int = None - """chat id""" chat_history: list[dict[str, Any]] = None - """The chat history.""" summary: str = None - """A summary obtained from the chat.""" cost: dict[str, dict] = None # keys: "usage_including_cached_inference", "usage_excluding_cached_inference" """The cost of the chat. The value for each usage type is a dictionary containing cost information for that specific type. @@ -61,16 +55,11 @@ def _validate_recipients(chat_queue: list[dict[str, Any]]) -> None: for chat_info in chat_queue: assert "recipient" in chat_info, "recipient must be provided." receipts_set.add(chat_info["recipient"]) - if len(receipts_set) < len(chat_queue): - warnings.warn( - "Repetitive recipients detected: The chat history will be cleared by default if a recipient appears more than once. To retain the chat history, please set 'clear_history=False' in the configuration of the repeating agent.", - UserWarning, - ) -def __create_async_prerequisites(chat_queue: list[dict[str, Any]]) -> list[Prerequisite]: +def __create_async_prerequisites(chat_queue: list[dict[str, Any]]) -> list[tuple[int, int]]: """ - Create list of Prerequisite (prerequisite_chat_id, chat_id) + Create list of tuple[int, int] (prerequisite_chat_id, chat_id) """ prerequisites = [] for chat_info in chat_queue: @@ -80,17 +69,17 @@ def __create_async_prerequisites(chat_queue: list[dict[str, Any]]) -> list[Prere pre_chats = chat_info.get("prerequisites", []) for pre_chat_id in pre_chats: if not isinstance(pre_chat_id, int): - raise ValueError("Prerequisite chat id is not int.") + raise ValueError("tuple[int, int] chat id is not int.") prerequisites.append((chat_id, pre_chat_id)) return prerequisites -def __find_async_chat_order(chat_ids: set[int], prerequisites: list[Prerequisite]) -> list[int]: +def __find_async_chat_order(chat_ids: set[int], prerequisites: list[tuple[int, int]]) -> list[int]: """Find chat order for async execution based on the prerequisite chats args: num_chats: number of chats - prerequisites: list of Prerequisite (prerequisite_chat_id, chat_id) + prerequisites: list of tuple[int, int] (prerequisite_chat_id, chat_id) returns: list: a list of chat_id in order. diff --git a/train_methods/legacy_autogen/client.py b/train_methods/legacy_autogen/client.py index 0280dca..7bcfe16 100644 --- a/train_methods/legacy_autogen/client.py +++ b/train_methods/legacy_autogen/client.py @@ -591,15 +591,7 @@ def create(self, **config: Any) -> ModelClient.ModelClientResponseProtocol: The actual prompt will be: "Complete the following sentence: Today I feel". More examples can be found at [templating](/docs/Use-Cases/enhanced_inference#templating). - - cache (AbstractCache | None): A Cache object to use for response cache. Default to None. - Note that the cache argument overrides the legacy cache_seed argument: if this argument is provided, - then the cache_seed argument is ignored. If this argument is not provided or None, - then the cache_seed argument is used. - agent (AbstractAgent | None): The object responsible for creating a completion if an agent. - - (Legacy) cache_seed (int | None) for using the DiskCache. Default to 41. - An integer cache_seed is useful when implementing "controlled randomness" for the completion. - None for no caching. - Note: this is a legacy argument. It is only used when the cache argument is not provided. - filter_func (Callable | None): A function that takes in the context and the response and returns a boolean to indicate whether the response is valid. E.g., @@ -637,8 +629,7 @@ def yes_or_no_filter(context, response): # construct the create params params = self._construct_create_params(create_config, extra_kwargs) # get the cache_seed, filter_func and context - cache_seed = extra_kwargs.get("cache_seed", 41) - cache = extra_kwargs.get("cache") + cache = None filter_func = extra_kwargs.get("filter_func") context = extra_kwargs.get("context") price = extra_kwargs.get("price", None) @@ -652,41 +643,32 @@ def yes_or_no_filter(context, response): total_usage = None actual_usage = None + cache_client = Cache.disk(41, ".cache") - cache_client = None - if cache is not None: - # Use the cache object if provided. - cache_client = cache - elif cache_seed is not None: - # Legacy cache behavior, if cache_seed is given, use DiskCache. - cache_client = Cache.disk(cache_seed, ".cache") + with cache_client as cache: + key = get_key(params) - if cache_client is not None: - with cache_client as cache: - # Try to get the response from cache - key = get_key(params) - - response: ModelClient.ModelClientResponseProtocol = cache.get(key, None) - - if response is not None: - response.message_retrieval_function = client.message_retrieval - try: - response.cost # type: ignore [attr-defined] - except AttributeError: - # update attribute if cost is not calculated - response.cost = client.cost(response) - cache.set(key, response) - total_usage = client.get_usage(response) - - # check the filter - pass_filter = filter_func is None or filter_func(context=context, response=response) - if pass_filter or i == last: - # Return the response if it passes the filter or it is the last client - response.config_id = i - response.pass_filter = pass_filter - self._update_usage(actual_usage=actual_usage, total_usage=total_usage) - return response - continue # filter is not passed; try the next config + response: ModelClient.ModelClientResponseProtocol = cache.get(key, None) + + if response is not None: + response.message_retrieval_function = client.message_retrieval + try: + response.cost # type: ignore [attr-defined] + except AttributeError: + # update attribute if cost is not calculated + response.cost = client.cost(response) + cache.set(key, response) + total_usage = client.get_usage(response) + + # check the filter + pass_filter = filter_func is None or filter_func(context=context, response=response) + if pass_filter or i == last: + # Return the response if it passes the filter or it is the last client + response.config_id = i + response.pass_filter = pass_filter + self._update_usage(actual_usage=actual_usage, total_usage=total_usage) + return response + continue # filter is not passed; try the next config try: self._throttle_api_calls(i) response = client.create(params) @@ -711,10 +693,8 @@ def yes_or_no_filter(context, response): actual_usage = client.get_usage(response) total_usage = actual_usage.copy() if actual_usage is not None else total_usage self._update_usage(actual_usage=actual_usage, total_usage=total_usage) - if cache_client is not None: - # Cache the response - with cache_client as cache: - cache.set(key, response) + with cache_client as cache: + cache.set(key, response) response.message_retrieval_function = client.message_retrieval # check the filter diff --git a/train_methods/legacy_autogen/legacy_autogen_conversable_agent.py b/train_methods/legacy_autogen/legacy_autogen_conversable_agent.py index 8a43d36..baebfe3 100644 --- a/train_methods/legacy_autogen/legacy_autogen_conversable_agent.py +++ b/train_methods/legacy_autogen/legacy_autogen_conversable_agent.py @@ -12,7 +12,6 @@ from pydantic import BaseModel from termcolor import colored -from train_methods.legacy_autogen.cache import AbstractCache from train_methods.legacy_autogen.chat import ChatResult, a_initiate_chats, initiate_chats, _post_process_carryover_item, consolidate_chat_info from train_methods.legacy_autogen.client import ModelClient, OpenAIWrapper from train_methods.legacy_autogen.stream import IOStream @@ -984,7 +983,6 @@ def initiate_chat( recipient: "ConversableAgent", clear_history: bool = True, silent: bool | None = False, - cache: AbstractCache | None = None, max_turns: int | None = None, summary_method: str | Callable | None = DEFAULT_SUMMARY_METHOD, summary_args: dict | None = {}, @@ -1001,7 +999,6 @@ def initiate_chat( recipient: the recipient agent. clear_history (bool): whether to clear the chat history with the agent. Default is True. silent (bool or None): (Experimental) whether to print the messages for this conversation. Default is False. - cache (AbstractCache or None): the cache client to be used for this conversation. Default is None. max_turns (int or None): the maximum number of turns for the chat between the two agents. One turn means one conversation round trip. Note that this is different from [max_consecutive_auto_reply](#max_consecutive_auto_reply) which is the maximum number of consecutive auto replies; and it is also different from [max_rounds in GroupChat](./groupchat#groupchat-objects) which is the maximum number of rounds in a group chat session. If max_turns is set to None, the chat will continue until a termination condition is met. Default is None. @@ -1084,7 +1081,7 @@ def my_message(sender: ConversableAgent, recipient: ConversableAgent, context: d for agent in [self, recipient]: agent._raise_exception_on_async_reply_functions() agent.previous_cache = agent.client_cache - agent.client_cache = cache + agent.client_cache = None if isinstance(max_turns, int): self._prepare_chat(recipient, clear_history, reply_at_receive=False) for _ in range(max_turns): @@ -1109,7 +1106,6 @@ def my_message(sender: ConversableAgent, recipient: ConversableAgent, context: d summary_method, summary_args, recipient, - cache=cache, ) for agent in [self, recipient]: agent.client_cache = agent.previous_cache @@ -1127,7 +1123,6 @@ async def a_initiate_chat( recipient: "ConversableAgent", clear_history: bool = True, silent: bool | None = False, - cache: AbstractCache | None = None, max_turns: int | None = None, summary_method: str | Callable | None = DEFAULT_SUMMARY_METHOD, summary_args: dict | None = {}, @@ -1150,7 +1145,7 @@ async def a_initiate_chat( consolidate_chat_info(_chat_info, uniform_sender=self) for agent in [self, recipient]: agent.previous_cache = agent.client_cache - agent.client_cache = cache + agent.client_cache = None if isinstance(max_turns, int): self._prepare_chat(recipient, clear_history, reply_at_receive=False) for _ in range(max_turns): @@ -1175,7 +1170,6 @@ async def a_initiate_chat( summary_method, summary_args, recipient, - cache=cache, ) for agent in [self, recipient]: agent.client_cache = agent.previous_cache @@ -1193,7 +1187,6 @@ def _summarize_chat( summary_method, summary_args, recipient: Agent | None = None, - cache: AbstractCache | None = None, ) -> str: """Get a chat summary from an agent participating in a chat. @@ -1219,7 +1212,7 @@ def my_summary_method( if summary_method is None: return summary if "cache" not in summary_args: - summary_args["cache"] = cache + summary_args["cache"] = None if summary_method == "reflection_with_llm": summary_method = self._reflection_with_llm_as_summary elif summary_method == "last_msg": @@ -1263,7 +1256,7 @@ def _reflection_with_llm_as_summary(sender, recipient, summary_args): raise ValueError("The summary_role in summary_arg must be a string.") try: summary = sender._reflection_with_llm( - prompt, msg_list, llm_agent=agent, cache=summary_args.get("cache"), role=role + prompt, msg_list, llm_agent=agent, role=role ) except BadRequestError as e: warnings.warn( @@ -1277,7 +1270,6 @@ def _reflection_with_llm( prompt, messages, llm_agent: Agent | None = None, - cache: AbstractCache | None = None, role: str | None = None, ) -> str: """Get a chat summary using reflection with an llm client based on the conversation history. @@ -1286,7 +1278,6 @@ def _reflection_with_llm( prompt (str): The prompt (in this method it is used as system prompt) used to get the summary. messages (list): The messages generated as part of a chat conversation. llm_agent: the agent with an llm client. - cache (AbstractCache or None): the cache client to be used for this conversation. role (str): the role of the message, usually "system" or "user". Default is "system". """ if not role: @@ -1306,7 +1297,7 @@ def _reflection_with_llm( llm_client = self.client else: raise ValueError("No OpenAIWrapper client is found.") - response = self._generate_oai_reply_from_client(llm_client=llm_client, messages=messages, cache=cache) + response = self._generate_oai_reply_from_client(llm_client=llm_client, messages=messages) return response def _check_chat_queue_for_sender(self, chat_queue: list[dict[str, Any]]) -> list[dict[str, Any]]: @@ -1427,12 +1418,11 @@ def generate_oai_reply( if messages is None: messages = self._oai_messages[sender] extracted_response = self._generate_oai_reply_from_client( - client, self._oai_system_message + messages, self.client_cache + client, self._oai_system_message + messages ) return (False, None) if extracted_response is None else (True, extracted_response) - def _generate_oai_reply_from_client(self, llm_client, messages, cache) -> str | dict | None: - # unroll tool_responses + def _generate_oai_reply_from_client(self, llm_client, messages) -> str | dict | None: all_messages = [] for message in messages: tool_responses = message.get("tool_responses", []) @@ -1445,7 +1435,7 @@ def _generate_oai_reply_from_client(self, llm_client, messages, cache) -> str | all_messages.append(message) response = llm_client.create( - context=messages[-1].pop("context", None), messages=all_messages, cache=cache, agent=self + context=messages[-1].pop("context", None), messages=all_messages, agent=self ) extracted_response = llm_client.extract_text_or_completion_object(response)[0] From 6c43420be74ec7d5c77e1a895156de8d32fd0436 Mon Sep 17 00:00:00 2001 From: fmuuly <64724985+fmp453@users.noreply.github.com> Date: Mon, 2 Feb 2026 17:24:23 +0900 Subject: [PATCH 22/25] remove reflection summarize --- train_methods/legacy_autogen/client.py | 34 +--- .../legacy_autogen_conversable_agent.py | 155 +----------------- 2 files changed, 7 insertions(+), 182 deletions(-) diff --git a/train_methods/legacy_autogen/client.py b/train_methods/legacy_autogen/client.py index 7bcfe16..1ff9592 100644 --- a/train_methods/legacy_autogen/client.py +++ b/train_methods/legacy_autogen/client.py @@ -386,29 +386,7 @@ def __init__(self, *, config_list: list[dict[str, Any]] | None = None, **base_co """ Args: config_list: a list of config dicts to override the base_config. - They can contain additional kwargs as allowed in the [create](/docs/reference/oai/client#create) method. E.g., - - ```python - config_list=[ - { - "model": "gpt-4", - "api_key": os.environ.get("AZURE_OPENAI_API_KEY"), - "api_type": "azure", - "base_url": os.environ.get("AZURE_OPENAI_API_BASE"), - "api_version": "2024-02-01", - }, - { - "model": "gpt-3.5-turbo", - "api_key": os.environ.get("OPENAI_API_KEY"), - "api_type": "openai", - "base_url": "https://api.openai.com/v1", - }, - { - "model": "llama-7B", - "base_url": "http://127.0.0.1:8080", - } - ] - ``` + They can contain additional kwargs as allowed in the [create](/docs/reference/oai/client#create) method. base_config: base config. It can contain both keyword arguments for openai client and additional kwargs. @@ -585,7 +563,7 @@ def create(self, **config: Any) -> ModelClient.ModelClientResponseProtocol: The config in each client will be overridden by the config. Args: - - context (Dict | None): The context to instantiate the prompt or messages. Default to None. + - context (dict | None): The context to instantiate the prompt or messages. Default to None. It needs to contain keys that are used by the prompt template or the filter function. E.g., `prompt="Complete the following sentence: {prefix}, context={"prefix": "Today I feel"}`. The actual prompt will be: @@ -594,14 +572,6 @@ def create(self, **config: Any) -> ModelClient.ModelClientResponseProtocol: - agent (AbstractAgent | None): The object responsible for creating a completion if an agent. - filter_func (Callable | None): A function that takes in the context and the response and returns a boolean to indicate whether the response is valid. E.g., - - ```python - def yes_or_no_filter(context, response): - return context.get("yes_or_no_choice", False) is False or any( - text in ["Yes.", "No."] for text in client.extract_text_or_completion_object(response) - ) - ``` - - allow_format_str_template (bool | None): Whether to allow format string template in the config. Default to false. - api_version (str | None): The api version. Default to None. E.g., "2024-02-01". Raises: diff --git a/train_methods/legacy_autogen/legacy_autogen_conversable_agent.py b/train_methods/legacy_autogen/legacy_autogen_conversable_agent.py index baebfe3..ac0d564 100644 --- a/train_methods/legacy_autogen/legacy_autogen_conversable_agent.py +++ b/train_methods/legacy_autogen/legacy_autogen_conversable_agent.py @@ -8,7 +8,6 @@ from collections import defaultdict from typing import Any, Callable, Coroutine, Literal, Type, Protocol, TypeVar -from openai import BadRequestError from pydantic import BaseModel from termcolor import colored @@ -332,15 +331,6 @@ def register_reply( Note: Be sure to register `None` as a trigger if you would like to trigger an auto-reply function with non-empty messages and `sender=None`. reply_func (Callable): the reply function. The function takes a recipient agent, a list of messages, a sender agent and a config as input and returns a reply message. - - ```python - def reply_func( - recipient: ConversableAgent, - messages: list[dict] | None = None, - sender: Agent | None = None, - config: Any | None = None, - ) -> Tuple[bool, str | dict | None]: - ``` position (int): the position of the reply function in the reply function list. The function registered later will be checked earlier by default. To change the order, set the position to a positive integer. @@ -465,17 +455,7 @@ def register_nested_chats( chat_queue (list): a list of chat objects to be initiated. If use_async is used, then all messages in chat_queue must have a chat-id associated with them. trigger (Agent class, str, Agent instance, callable, or list): refer to `register_reply` for details. reply_func_from_nested_chats (Callable, str): the reply function for the nested chat. - The function takes a chat_queue for nested chat, recipient agent, a list of messages, a sender agent and a config as input and returns a reply message. - Default to "summary_from_nested_chats", which corresponds to a built-in reply function that get summary from the nested chat_queue. - ```python - def reply_func_from_nested_chats( - chat_queue: list[dict], - recipient: ConversableAgent, - messages: list[dict] | None = None, - sender: Agent | None = None, - config: Any | None = None, - ) -> tuple[bool, str | dict | None]: - ``` + The function takes a chat_queue for nested chat, recipient agent, a list of messages, a sender agent and a config as input and returns a reply message. Default to "summary_from_nested_chats", which corresponds to a built-in reply function that get summary from the nested chat_queue. position (int): Ref to `register_reply` for details. Default to 2. It means we first check the termination and human reply, then check the registered nested chat reply. use_async: Uses a_initiate_chats internally to start nested chats. If the original chat is initiated with a_initiate_chats, you may set this to true so nested chats do not run in sync. kwargs: Ref to `register_reply` for details. @@ -714,15 +694,6 @@ def send( will be modified to "assistant". - context (dict): the context of the message, which will be passed to [OpenAIWrapper.create](../oai/client#create). - For example, one agent can send a message A as: - ```python - { - "content": lambda context: context["use_tool_msg"], - "context": { - "use_tool_msg": "Use tool X if they are relevant." - } - } - ``` Next time, one agent can send a message B with a different "use_tool_msg". Then the content of message A will be refreshed to the new "use_tool_msg". So effectively, this provides a way for an agent to send a "link" and modify @@ -752,38 +723,7 @@ async def a_send( request_reply: bool | None = None, silent: bool | None = False, ): - """(async) Send a message to another agent. - - Args: - message (dict or str): message to be sent. - The message could contain the following fields: - - content (str or List): Required, the content of the message. (Can be None) - - function_call (str): the name of the function to be called. - - name (str): the name of the function to be called. - - role (str): the role of the message, any role that is not "function" - will be modified to "assistant". - - context (dict): the context of the message, which will be passed to - [OpenAIWrapper.create](../oai/client#create). - For example, one agent can send a message A as: - ```python - { - "content": lambda context: context["use_tool_msg"], - "context": { - "use_tool_msg": "Use tool X if they are relevant." - } - } - ``` - Next time, one agent can send a message B with a different "use_tool_msg". - Then the content of message A will be refreshed to the new "use_tool_msg". - So effectively, this provides a way for an agent to send a "link" and modify - the content of the "link" later. - recipient (Agent): the recipient of the message. - request_reply (bool or None): whether to request a reply from the recipient. - silent (bool or None): (Experimental) whether to print the message sent. - - Raises: - ValueError: if the message can't be converted into a valid ChatCompletion message. - """ + """(async) Send a message to another agent.""" message = await self._a_process_message_before_send( message, recipient, ConversableAgent._is_silent(self, silent) ) @@ -1009,16 +949,7 @@ def initiate_chat( - when set to "reflection_with_llm", it returns a summary extracted using an llm client. `llm_config` must be set in either the recipient or sender. - A callable summary_method should take the recipient and sender agent in a chat as input and return a string of summary. E.g., - - ```python - def my_summary_method( - sender: ConversableAgent, - recipient: ConversableAgent, - summary_args: dict, - ): - return recipient.last_message(sender)["content"] - ``` + A callable summary_method should take the recipient and sender agent in a chat as input and return a string of summary. summary_args (dict): a dictionary of arguments to be passed to the summary_method. One example key is "summary_prompt", and value is a string of text used to prompt a LLM-based agent (the sender or receiver agent) to reflect on the conversation and extract a summary when summary_method is "reflection_with_llm". @@ -1040,29 +971,6 @@ def my_summary_method( - If a callable is provided, it will be called to get the initial message in the form of a string or a dict. If the returned type is dict, it may contain the reserved fields mentioned above. - Example of a callable message (returning a string): - - ```python - def my_message(sender: ConversableAgent, recipient: ConversableAgent, context: dict) -> str | dict: - carryover = context.get("carryover", "") - if isinstance(message, list): - carryover = carryover[-1] - final_msg = "Write a blogpost." + "\\nContext: \\n" + carryover - return final_msg - ``` - - Example of a callable message (returning a dict): - - ```python - def my_message(sender: ConversableAgent, recipient: ConversableAgent, context: dict) -> str | dict: - final_msg = {} - carryover = context.get("carryover", "") - if isinstance(message, list): - carryover = carryover[-1] - final_msg["content"] = "Write a blogpost." + "\\nContext: \\n" + carryover - final_msg["context"] = {"prefix": "Today I feel"} - return final_msg - ``` **kwargs: any additional information. It has the following reserved fields: - "carryover": a string or a list of string to specify the carryover information to be passed to this chat. If provided, we will combine this carryover (by attaching a "context: " string and the carryover content after the message content) with the "message" content when generating the initial chat @@ -1103,8 +1011,6 @@ def my_message(sender: ConversableAgent, recipient: ConversableAgent, context: d msg2send = self.generate_init_message(message, **kwargs) self.send(msg2send, recipient, silent=silent) summary = self._summarize_chat( - summary_method, - summary_args, recipient, ) for agent in [self, recipient]: @@ -1167,8 +1073,6 @@ async def a_initiate_chat( msg2send = await self.a_generate_init_message(message, **kwargs) await self.a_send(msg2send, recipient, silent=silent) summary = self._summarize_chat( - summary_method, - summary_args, recipient, ) for agent in [self, recipient]: @@ -1184,23 +1088,12 @@ async def a_initiate_chat( def _summarize_chat( self, - summary_method, - summary_args, recipient: Agent | None = None, ) -> str: """Get a chat summary from an agent participating in a chat. Args: summary_method (str or callable): the summary_method to get the summary. - The callable summary_method should take the recipient and sender agent in a chat as input and return a string of summary. E.g, - ```python - def my_summary_method( - sender: ConversableAgent, - recipient: ConversableAgent, - summary_args: dict, - ): - return recipient.last_message(sender)["content"] - ``` summary_args (dict): a dictionary of arguments to be passed to the summary_method. recipient: the recipient agent in a chat. prompt (str): the prompt used to get a summary when summary_method is "reflection_with_llm". @@ -1208,26 +1101,10 @@ def my_summary_method( Returns: str: a chat summary from the agent. """ - summary = "" - if summary_method is None: - return summary - if "cache" not in summary_args: - summary_args["cache"] = None - if summary_method == "reflection_with_llm": - summary_method = self._reflection_with_llm_as_summary - elif summary_method == "last_msg": - summary_method = self._last_msg_as_summary - - if isinstance(summary_method, Callable): - summary = summary_method(self, recipient, summary_args) - else: - raise ValueError( - "If not None, the summary_method must be a string from [`reflection_with_llm`, `last_msg`] or a callable." - ) - return summary + return self._last_msg_as_summary(self, recipient) @staticmethod - def _last_msg_as_summary(sender, recipient, summary_args) -> str: + def _last_msg_as_summary(sender, recipient) -> str: """Get a chat summary from the last message of the recipient.""" summary = "" try: @@ -1243,28 +1120,6 @@ def _last_msg_as_summary(sender, recipient, summary_args) -> str: warnings.warn(f"Cannot extract summary using last_msg: {e}. Using an empty str as summary.", UserWarning) return summary - @staticmethod - def _reflection_with_llm_as_summary(sender, recipient, summary_args): - prompt = summary_args.get("summary_prompt") - prompt = ConversableAgent.DEFAULT_SUMMARY_PROMPT if prompt is None else prompt - if not isinstance(prompt, str): - raise ValueError("The summary_prompt must be a string.") - msg_list = recipient.chat_messages_for_summary(sender) - agent = sender if recipient is None else recipient - role = summary_args.get("summary_role", None) - if role and not isinstance(role, str): - raise ValueError("The summary_role in summary_arg must be a string.") - try: - summary = sender._reflection_with_llm( - prompt, msg_list, llm_agent=agent, role=role - ) - except BadRequestError as e: - warnings.warn( - f"Cannot extract summary using reflection_with_llm: {e}. Using an empty str as summary.", UserWarning - ) - summary = "" - return summary - def _reflection_with_llm( self, prompt, From 324ec7d0af451c8c5dc7a5813a08b2da149d5979 Mon Sep 17 00:00:00 2001 From: fmuuly <64724985+fmp453@users.noreply.github.com> Date: Mon, 2 Feb 2026 17:29:22 +0900 Subject: [PATCH 23/25] remove silent --- .../legacy_autogen/legacy_autogen.py | 6 - .../legacy_autogen_conversable_agent.py | 196 +----------------- 2 files changed, 9 insertions(+), 193 deletions(-) diff --git a/train_methods/legacy_autogen/legacy_autogen.py b/train_methods/legacy_autogen/legacy_autogen.py index 91d4ed7..8c71b0a 100644 --- a/train_methods/legacy_autogen/legacy_autogen.py +++ b/train_methods/legacy_autogen/legacy_autogen.py @@ -869,12 +869,6 @@ def groupchat(self) -> GroupChat: """Returns the group chat managed by the group chat manager.""" return self._groupchat - def chat_messages_for_summary(self, agent: Agent) -> list[dict]: - """The list of messages in the group chat as a conversation to summarize. - The agent is ignored. - """ - return self._groupchat.messages - def _prepare_chat( self, recipient: ConversableAgent, diff --git a/train_methods/legacy_autogen/legacy_autogen_conversable_agent.py b/train_methods/legacy_autogen/legacy_autogen_conversable_agent.py index ac0d564..892d5c1 100644 --- a/train_methods/legacy_autogen/legacy_autogen_conversable_agent.py +++ b/train_methods/legacy_autogen/legacy_autogen_conversable_agent.py @@ -167,13 +167,6 @@ class LLMAgent(Agent, Protocol): def system_message(self) -> str: """The system message of this agent.""" - def update_system_message(self, system_message: str) -> None: - """Update this agent's system message. - - Args: - system_message (str): system message for inference. - """ - class ConversableAgent(LLMAgent): """A class for generic conversable agents which can be configured as assistant or user proxy. @@ -277,10 +270,6 @@ def _validate_llm_config(self, llm_config): ) self.client = None if self.llm_config is False else OpenAIWrapper(**self.llm_config) - @staticmethod - def _is_silent(agent: Agent, silent: bool | None = False) -> bool: - return silent - @property def name(self) -> str: """Get the name of the agent.""" @@ -359,172 +348,12 @@ def register_reply( }, ) - def replace_reply_func(self, old_reply_func: Callable, new_reply_func: Callable): - """Replace a registered reply function with a new one. - - Args: - old_reply_func (Callable): the old reply function to be replaced. - new_reply_func (Callable): the new reply function to replace the old one. - """ - for f in self._reply_func_list: - if f["reply_func"] == old_reply_func: - f["reply_func"] = new_reply_func - - @staticmethod - def _get_chats_to_run( - chat_queue: list[dict[str, Any]], recipient: Agent, messages: str | Callable, sender: Agent, config: Any - ) -> list[dict[str, Any]]: - """A simple chat reply function. - This function initiate one or a sequence of chats between the "recipient" and the agents in the - chat_queue. - - It extracts and returns a summary from the nested chat based on the "summary_method" in each chat in chat_queue. - - Returns: - Tuple[bool, str]: A tuple where the first element indicates the completion of the chat, and the second element contains the summary of the last chat if any chats were initiated. - """ - last_msg = messages[-1].get("content") - chat_to_run = [] - for i, c in enumerate(chat_queue): - current_c = c.copy() - if current_c.get("sender") is None: - current_c["sender"] = recipient - message = current_c.get("message") - # If message is not provided in chat_queue, we by default use the last message from the original chat history as the first message in this nested chat (for the first chat in the chat queue). - # NOTE: This setting is prone to change. - if message is None and i == 0: - message = last_msg - if callable(message): - message = message(recipient, messages, sender, config) - # We only run chat that has a valid message. NOTE: This is prone to change dependin on applications. - if message: - current_c["message"] = message - chat_to_run.append(current_c) - return chat_to_run - - @staticmethod - def _summary_from_nested_chats( - chat_queue: list[dict[str, Any]], recipient: Agent, messages: str | Callable, sender: Agent, config: Any - ) -> tuple[bool, str | None]: - """A simple chat reply function. - This function initiate one or a sequence of chats between the "recipient" and the agents in the - chat_queue. - - It extracts and returns a summary from the nested chat based on the "summary_method" in each chat in chat_queue. - - Returns: - tuple[bool, str]: A tuple where the first element indicates the completion of the chat, and the second element contains the summary of the last chat if any chats were initiated. - """ - chat_to_run = ConversableAgent._get_chats_to_run(chat_queue, recipient, messages, sender, config) - if not chat_to_run: - return True, None - res = initiate_chats(chat_to_run) - return True, res[-1].summary - - @staticmethod - async def _a_summary_from_nested_chats( - chat_queue: list[dict[str, Any]], recipient: Agent, messages: str | Callable, sender: Agent, config: Any - ) -> tuple[bool, str | None]: - """A simple chat reply function. - This function initiate one or a sequence of chats between the "recipient" and the agents in the - chat_queue. - - It extracts and returns a summary from the nested chat based on the "summary_method" in each chat in chat_queue. - - Returns: - tuple[bool, str]: A tuple where the first element indicates the completion of the chat, and the second element contains the summary of the last chat if any chats were initiated. - """ - chat_to_run = ConversableAgent._get_chats_to_run(chat_queue, recipient, messages, sender, config) - if not chat_to_run: - return True, None - res = await a_initiate_chats(chat_to_run) - index_of_last_chat = chat_to_run[-1]["chat_id"] - return True, res[index_of_last_chat].summary - - def register_nested_chats( - self, - chat_queue: list[dict[str, Any]], - trigger: Type[Agent] | str | Agent | Callable[[Agent], bool] | list, - reply_func_from_nested_chats: str | Callable = "summary_from_nested_chats", - position: int = 2, - use_async: bool | None = None, - **kwargs, - ) -> None: - """Register a nested chat reply function. - Args: - chat_queue (list): a list of chat objects to be initiated. If use_async is used, then all messages in chat_queue must have a chat-id associated with them. - trigger (Agent class, str, Agent instance, callable, or list): refer to `register_reply` for details. - reply_func_from_nested_chats (Callable, str): the reply function for the nested chat. - The function takes a chat_queue for nested chat, recipient agent, a list of messages, a sender agent and a config as input and returns a reply message. Default to "summary_from_nested_chats", which corresponds to a built-in reply function that get summary from the nested chat_queue. - position (int): Ref to `register_reply` for details. Default to 2. It means we first check the termination and human reply, then check the registered nested chat reply. - use_async: Uses a_initiate_chats internally to start nested chats. If the original chat is initiated with a_initiate_chats, you may set this to true so nested chats do not run in sync. - kwargs: Ref to `register_reply` for details. - """ - if use_async: - for chat in chat_queue: - if chat.get("chat_id") is None: - raise ValueError("chat_id is required for async nested chats") - - if use_async: - if reply_func_from_nested_chats == "summary_from_nested_chats": - reply_func_from_nested_chats = self._a_summary_from_nested_chats - if not callable(reply_func_from_nested_chats) or not inspect.iscoroutinefunction( - reply_func_from_nested_chats - ): - raise ValueError("reply_func_from_nested_chats must be a callable and a coroutine") - - async def wrapped_reply_func(recipient, messages=None, sender=None, config=None): - return await reply_func_from_nested_chats(chat_queue, recipient, messages, sender, config) - - else: - if reply_func_from_nested_chats == "summary_from_nested_chats": - reply_func_from_nested_chats = self._summary_from_nested_chats - if not callable(reply_func_from_nested_chats): - raise ValueError("reply_func_from_nested_chats must be a callable") - - def wrapped_reply_func(recipient, messages=None, sender=None, config=None): - return reply_func_from_nested_chats(chat_queue, recipient, messages, sender, config) - - functools.update_wrapper(wrapped_reply_func, reply_func_from_nested_chats) - - self.register_reply( - trigger, - wrapped_reply_func, - position, - kwargs.get("config"), - kwargs.get("reset_config"), - ignore_async_in_sync_chat=( - not use_async if use_async is not None else kwargs.get("ignore_async_in_sync_chat") - ), - ) @property def system_message(self) -> str: """Return the system message.""" return self._oai_system_message[0]["content"] - def update_system_message(self, system_message: str) -> None: - """Update the system message. - - Args: - system_message (str): system message for the ChatCompletion inference. - """ - self._oai_system_message[0]["content"] = system_message - - def update_max_consecutive_auto_reply(self, value: int, sender: Agent | None = None): - """Update the maximum number of consecutive auto replies. - - Args: - value (int): the maximum number of consecutive auto replies. - sender (Agent): when the sender is provided, only update the max_consecutive_auto_reply for that sender. - """ - if sender is None: - self._max_consecutive_auto_reply = value - for k in self._max_consecutive_auto_reply_dict: - self._max_consecutive_auto_reply_dict[k] = value - else: - self._max_consecutive_auto_reply_dict[sender] = value - def max_consecutive_auto_reply(self, sender: Agent | None = None) -> int: """The maximum number of consecutive auto replies.""" return self._max_consecutive_auto_reply if sender is None else self._max_consecutive_auto_reply_dict[sender] @@ -534,10 +363,6 @@ def chat_messages(self) -> dict[Agent, list[dict]]: """A dictionary of conversations from agent to list of messages.""" return self._oai_messages - def chat_messages_for_summary(self, agent: Agent) -> list[dict]: - """A list of messages as a conversation to summarize.""" - return self._oai_messages[agent] - def last_message(self, agent: Agent | None = None) -> dict | None: """The last message exchanged with the agent. @@ -652,7 +477,7 @@ def _append_oai_message(self, message: dict | str, role, conversation_id: Agent, return True def _process_message_before_send( - self, message: dict | str, recipient: Agent, silent: bool + self, message: dict | str, recipient: Agent ) -> dict | str: """Process the message before sending it to the recipient.""" hook_list = self.hook_lists["process_message_before_send"] @@ -660,19 +485,19 @@ def _process_message_before_send( if inspect.iscoroutinefunction(hook): continue message = hook( - sender=self, message=message, recipient=recipient, silent=ConversableAgent._is_silent(self, silent) + sender=self, message=message, recipient=recipient, silent=False ) return message async def _a_process_message_before_send( - self, message: dict | str, recipient: Agent, silent: bool + self, message: dict | str, recipient: Agent ) -> dict | str: """(async) Process the message before sending it to the recipient.""" hook_list = self.hook_lists["a_process_message_before_send"] for hook in hook_list: if not inspect.iscoroutinefunction(hook): continue - message = await hook(sender=self, message=message, recipient=recipient, silent=silent) + message = await hook(sender=self, message=message, recipient=recipient, silent=False) return message def send( @@ -705,7 +530,7 @@ def send( Raises: ValueError: if the message can't be converted into a valid ChatCompletion message. """ - message = self._process_message_before_send(message, recipient, ConversableAgent._is_silent(self, silent)) + message = self._process_message_before_send(message, recipient) # When the agent composes and sends the message, the role of the message is "assistant" # unless it's "function". valid = self._append_oai_message(message, "assistant", recipient, is_sending=True) @@ -724,9 +549,7 @@ async def a_send( silent: bool | None = False, ): """(async) Send a message to another agent.""" - message = await self._a_process_message_before_send( - message, recipient, ConversableAgent._is_silent(self, silent) - ) + message = await self._a_process_message_before_send(message, recipient) # When the agent composes and sends the message, the role of the message is "assistant" # unless it's "function". valid = self._append_oai_message(message, "assistant", recipient, is_sending=True) @@ -798,7 +621,7 @@ def _print_received_message(self, message: dict | str, sender: Agent): iostream.print("\n", "-" * 80, flush=True, sep="") - def _process_received_message(self, message: dict | str, sender: Agent, silent: bool): + def _process_received_message(self, message: dict | str, sender: Agent): # When the agent receives a message, the role of the message is "user". (If 'role' exists and is 'function', it will remain unchanged.) valid = self._append_oai_message(message, "user", sender, is_sending=False) @@ -807,8 +630,7 @@ def _process_received_message(self, message: dict | str, sender: Agent, silent: "Received message can't be converted into a valid ChatCompletion message. Either content or function_call must be provided." ) - if not ConversableAgent._is_silent(sender, silent): - self._print_received_message(message, sender) + self._print_received_message(message, sender) def receive( self, @@ -840,7 +662,7 @@ def receive( Raises: ValueError: if the message can't be converted into a valid ChatCompletion message. """ - self._process_received_message(message, sender, silent) + self._process_received_message(message, sender) if request_reply is False or request_reply is None and self.reply_at_receive[sender] is False: return reply = self.generate_reply(messages=self.chat_messages[sender], sender=sender) From b423f0b0a59541e7ca63cc706516444866467014 Mon Sep 17 00:00:00 2001 From: fmuuly <64724985+fmp453@users.noreply.github.com> Date: Mon, 2 Feb 2026 17:41:00 +0900 Subject: [PATCH 24/25] remove function / tool call --- .../legacy_autogen_conversable_agent.py | 489 +----------------- 1 file changed, 24 insertions(+), 465 deletions(-) diff --git a/train_methods/legacy_autogen/legacy_autogen_conversable_agent.py b/train_methods/legacy_autogen/legacy_autogen_conversable_agent.py index 892d5c1..26f39de 100644 --- a/train_methods/legacy_autogen/legacy_autogen_conversable_agent.py +++ b/train_methods/legacy_autogen/legacy_autogen_conversable_agent.py @@ -196,7 +196,7 @@ def __init__( system_message (str or list): system message for the ChatCompletion inference. is_termination_msg (function): a function that takes a message in the form of a dictionary and returns a boolean value indicating if this received message is a termination message. - The dict can contain the following keys: "content", "role", "name", "function_call". + The dict can contain the following keys: "content", "role" and "name". max_consecutive_auto_reply (int): the maximum number of consecutive auto replies. default to None (no limit provided, class attribute MAX_CONSECUTIVE_AUTO_REPLY will be used as the limit in this case). When set to 0, no auto reply will be generated. @@ -234,12 +234,6 @@ def __init__( self.register_reply([Agent, None], ConversableAgent.generate_oai_reply) self.register_reply([Agent, None], ConversableAgent.a_generate_oai_reply, ignore_async_in_sync_chat=True) - self.register_reply([Agent, None], ConversableAgent.generate_tool_calls_reply) - self.register_reply([Agent, None], ConversableAgent.a_generate_tool_calls_reply, ignore_async_in_sync_chat=True) - self.register_reply([Agent, None], ConversableAgent.generate_function_call_reply) - self.register_reply( - [Agent, None], ConversableAgent.a_generate_function_call_reply, ignore_async_in_sync_chat=True - ) self.register_reply([Agent, None], ConversableAgent.check_termination_and_human_reply) self.register_reply( [Agent, None], ConversableAgent.a_check_termination_and_human_reply, ignore_async_in_sync_chat=True @@ -427,9 +421,7 @@ def _append_oai_message(self, message: dict | str, role, conversation_id: Agent, """Append a message to the ChatCompletion conversation. If the message received is a string, it will be put in the "content" field of the new dictionary. - If the message received is a dictionary but does not have any of the three fields "content", "function_call", or "tool_calls", - this message is not a valid ChatCompletion message. - If only "function_call" or "tool_calls" is provided, "content" will be set to None if not provided, and the role of the message will be forced "assistant". + If the message received is a dictionary but does not have any of the three fields "content", this message is not a valid ChatCompletion message. Args: message (dict or str): message to be appended to the ChatCompletion conversation. @@ -444,18 +436,13 @@ def _append_oai_message(self, message: dict | str, role, conversation_id: Agent, # create oai message to be appended to the oai conversation that can be passed to oai directly. oai_message = { k: message[k] - for k in ("content", "function_call", "tool_calls", "tool_responses", "tool_call_id", "name", "context") + for k in ("content", "name", "context") if k in message and message[k] is not None } if "content" not in oai_message: - if "function_call" in oai_message or "tool_calls" in oai_message: - oai_message["content"] = None # if only function_call is provided, content will be set to None. - else: - return False + return False - if message.get("role") in ["function", "tool"]: - oai_message["role"] = message.get("role") - elif "override_role" in message: + if "override_role" in message: # If we have a direction to override the role then set the # role accordingly. Used to customise the role for the # select speaker prompt. @@ -463,9 +450,7 @@ def _append_oai_message(self, message: dict | str, role, conversation_id: Agent, else: oai_message["role"] = role - if oai_message.get("function_call", False) or oai_message.get("tool_calls", False): - oai_message["role"] = "assistant" # only messages with role 'assistant' can have a function call. - elif "name" not in oai_message: + if "name" not in oai_message: # If we don't have a name field, append it if is_sending: oai_message["name"] = self.name @@ -513,7 +498,6 @@ def send( message (dict or str): message to be sent. The message could contain the following fields: - content (str or List): Required, the content of the message. (Can be None) - - function_call (str): the name of the function to be called. - name (str): the name of the function to be called. - role (str): the role of the message, any role that is not "function" will be modified to "assistant". @@ -565,60 +549,15 @@ def _print_received_message(self, message: dict | str, sender: Agent): # print the message received iostream.print(colored(sender.name, "yellow"), "(to", f"{self.name}):\n", flush=True) message = self._message_to_dict(message) - - if message.get("tool_responses"): # Handle tool multi-call responses - for tool_response in message["tool_responses"]: - self._print_received_message(tool_response, sender) - if message.get("role") == "tool": - return # If role is tool, then content is just a concatenation of all tool_responses - - if message.get("role") in ["function", "tool"]: - if message["role"] == "function": - id_key = "name" - else: - id_key = "tool_call_id" - id = message.get(id_key, "No id found") - func_print = f"***** Response from calling {message['role']} ({id}) *****" - iostream.print(colored(func_print, "green"), flush=True) - iostream.print(message["content"], flush=True) - iostream.print(colored("*" * len(func_print), "green"), flush=True) - else: - content = message.get("content") - if content is not None: - if "context" in message: - content = OpenAIWrapper.instantiate( - content, - message["context"], - self.llm_config and self.llm_config.get("allow_format_str_template", False), - ) - iostream.print(content_str(content), flush=True) - if "function_call" in message and message["function_call"]: - function_call = dict(message["function_call"]) - func_print = ( - f"***** Suggested function call: {function_call.get('name', '(No function name found)')} *****" - ) - iostream.print(colored(func_print, "green"), flush=True) - iostream.print( - "Arguments: \n", - function_call.get("arguments", "(No arguments found)"), - flush=True, - sep="", + content = message.get("content") + if content is not None: + if "context" in message: + content = OpenAIWrapper.instantiate( + content, + message["context"], + self.llm_config and self.llm_config.get("allow_format_str_template", False), ) - iostream.print(colored("*" * len(func_print), "green"), flush=True) - if "tool_calls" in message and message["tool_calls"]: - for tool_call in message["tool_calls"]: - id = tool_call.get("id", "No tool call id found") - function_call = dict(tool_call.get("function", {})) - func_print = f"***** Suggested tool call ({id}): {function_call.get('name', '(No function name found)')} *****" - iostream.print(colored(func_print, "green"), flush=True) - iostream.print( - "Arguments: \n", - function_call.get("arguments", "(No arguments found)"), - flush=True, - sep="", - ) - iostream.print(colored("*" * len(func_print), "green"), flush=True) - + iostream.print(content_str(content), flush=True) iostream.print("\n", "-" * 80, flush=True, sep="") def _process_received_message(self, message: dict | str, sender: Agent): @@ -645,10 +584,8 @@ def receive( The reply can be generated automatically or entered manually by a human. Args: - message (dict or str): message from the sender. If the type is dict, it may contain the following reserved fields (either content or function_call need to be provided). + message (dict or str): message from the sender. If the type is dict, it may contain the following reserved fields (either content need to be provided). 1. "content": content of the message, can be None. - 2. "function_call": a dictionary containing the function name and arguments. (deprecated in favor of "tool_calls") - 3. "tool_calls": a list of dictionaries containing the function name and arguments. 4. "role": role of the message, can be "assistant", "user", "function", "tool". This field is only needed to distinguish between "function" or "assistant"/"user". 5. "name": In most cases, this field is not needed. When the role is "function", this field is needed to indicate the function name. @@ -682,10 +619,8 @@ async def a_receive( The reply can be generated automatically or entered manually by a human. Args: - message (dict or str): message from the sender. If the type is dict, it may contain the following reserved fields (either content or function_call need to be provided). + message (dict or str): message from the sender. If the type is dict, it may contain the following reserved fields (either content need to be provided). 1. "content": content of the message, can be None. - 2. "function_call": a dictionary containing the function name and arguments. (deprecated in favor of "tool_calls") - 3. "tool_calls": a list of dictionaries containing the function name and arguments. 4. "role": role of the message, can be "assistant", "user", "function". This field is only needed to distinguish between "function" or "assistant"/"user". 5. "name": In most cases, this field is not needed. When the role is "function", this field is needed to indicate the function name. @@ -746,8 +681,8 @@ def initiate_chat( clear_history: bool = True, silent: bool | None = False, max_turns: int | None = None, - summary_method: str | Callable | None = DEFAULT_SUMMARY_METHOD, - summary_args: dict | None = {}, + summary_method: str = DEFAULT_SUMMARY_METHOD, + summary_args: dict = {}, message: dict | str | Callable | None = None, **kwargs, ) -> ChatResult: @@ -782,8 +717,6 @@ def initiate_chat( If dict, it may contain the following reserved fields (either content or tool_calls need to be provided). 1. "content": content of the message, can be None. - 2. "function_call": a dictionary containing the function name and arguments. (deprecated in favor of "tool_calls") - 3. "tool_calls": a list of dictionaries containing the function name and arguments. 4. "role": role of the message, can be "assistant", "user", "function". This field is only needed to distinguish between "function" or "assistant"/"user". 5. "name": In most cases, this field is not needed. When the role is "function", this field is needed to indicate the function name. @@ -852,8 +785,8 @@ async def a_initiate_chat( clear_history: bool = True, silent: bool | None = False, max_turns: int | None = None, - summary_method: str | Callable | None = DEFAULT_SUMMARY_METHOD, - summary_args: dict | None = {}, + summary_method: str = DEFAULT_SUMMARY_METHOD, + summary_args: dict = {}, message: str | Callable | None = None, **kwargs, ) -> ChatResult: @@ -942,41 +875,6 @@ def _last_msg_as_summary(sender, recipient) -> str: warnings.warn(f"Cannot extract summary using last_msg: {e}. Using an empty str as summary.", UserWarning) return summary - def _reflection_with_llm( - self, - prompt, - messages, - llm_agent: Agent | None = None, - role: str | None = None, - ) -> str: - """Get a chat summary using reflection with an llm client based on the conversation history. - - Args: - prompt (str): The prompt (in this method it is used as system prompt) used to get the summary. - messages (list): The messages generated as part of a chat conversation. - llm_agent: the agent with an llm client. - role (str): the role of the message, usually "system" or "user". Default is "system". - """ - if not role: - role = "system" - - system_msg = [ - { - "role": role, - "content": prompt, - } - ] - - messages = messages + system_msg - if llm_agent and llm_agent.client is not None: - llm_client = llm_agent.client - elif self.client is not None: - llm_client = self.client - else: - raise ValueError("No OpenAIWrapper client is found.") - response = self._generate_oai_reply_from_client(llm_client=llm_client, messages=messages) - return response - def _check_chat_queue_for_sender(self, chat_queue: list[dict[str, Any]]) -> list[dict[str, Any]]: """ Check the chat queue and add the "sender" key if it's missing. @@ -1012,13 +910,6 @@ async def a_initiate_chats(self, chat_queue: list[dict[str, Any]]) -> dict[int, self._finished_chats = await a_initiate_chats(_chat_queue) return self._finished_chats - def get_chat_results(self, chat_index: int | None = None) -> list[ChatResult] | ChatResult: - """A summary from the finished chats of particular agents.""" - if chat_index is not None: - return self._finished_chats[chat_index] - else: - return self._finished_chats - def reset(self): """Reset the agent.""" self.clear_history() @@ -1058,15 +949,6 @@ def clear_history(self, recipient: Agent | None = None, nr_messages_to_preserve: if nr_messages_to_preserve: for key in self._oai_messages: nr_messages_to_preserve_internal = nr_messages_to_preserve - # if breaking history between function call and function response, save function call message - # additionally, otherwise openai will return error - first_msg_to_save = self._oai_messages[key][-nr_messages_to_preserve_internal] - if "tool_responses" in first_msg_to_save: - nr_messages_to_preserve_internal += 1 - iostream.print( - f"Preserving one more message for {self.name} to not divide history between tool call and " - f"tool response." - ) # Remove messages from history except last `nr_messages_to_preserve` messages. self._oai_messages[key] = self._oai_messages[key][-nr_messages_to_preserve_internal:] else: @@ -1100,16 +982,7 @@ def generate_oai_reply( return (False, None) if extracted_response is None else (True, extracted_response) def _generate_oai_reply_from_client(self, llm_client, messages) -> str | dict | None: - all_messages = [] - for message in messages: - tool_responses = message.get("tool_responses", []) - if tool_responses: - all_messages += tool_responses - # tool role on the parent message means the content is just concatenation of all of the tool_responses - if message.get("role") != "tool": - all_messages.append({key: message[key] for key in message if key != "tool_responses"}) - else: - all_messages.append(message) + all_messages = messages response = llm_client.create( context=messages[-1].pop("context", None), messages=all_messages, agent=self @@ -1122,19 +995,6 @@ def _generate_oai_reply_from_client(self, llm_client, messages) -> str | dict | # ensure function and tool calls will be accepted when sent back to the LLM if not isinstance(extracted_response, str) and hasattr(extracted_response, "model_dump"): extracted_response = model_dump(extracted_response) - if isinstance(extracted_response, dict): - if extracted_response.get("function_call"): - extracted_response["function_call"]["name"] = self._normalize_name( - extracted_response["function_call"]["name"] - ) - for tool_call in extracted_response.get("tool_calls") or []: - tool_call["function"]["name"] = self._normalize_name(tool_call["function"]["name"]) - # Remove id and type if they are not present. - # This is to make the tool call object compatible with Mistral API. - if tool_call.get("id") is None: - tool_call.pop("id") - if tool_call.get("type") is None: - tool_call.pop("type") return extracted_response async def a_generate_oai_reply( @@ -1160,132 +1020,6 @@ def _generate_oai_reply( ), ) - def generate_function_call_reply( - self, - messages: list[dict] | None = None, - sender: Agent | None = None, - config: Any | None = None, - ) -> tuple[bool, dict | None]: - """ - Generate a reply using function call. - - "function_call" replaced by "tool_calls" as of [OpenAI API v1.1.0](https://github.com/openai/openai-python/releases/tag/v1.1.0) - See https://platform.openai.com/docs/api-reference/chat/create#chat-create-functions - """ - if config is None: - config = self - if messages is None: - messages = self._oai_messages[sender] - message = messages[-1] - if "function_call" in message and message["function_call"]: - func_return = self.execute_function(message["function_call"]) - return True, func_return - return False, None - - async def a_generate_function_call_reply( - self, - messages: list[dict] | None = None, - sender: Agent | None = None, - config: Any | None = None, - ) -> tuple[bool, dict | None]: - """ - Generate a reply using async function call. - - "function_call" replaced by "tool_calls" as of [OpenAI API v1.1.0](https://github.com/openai/openai-python/releases/tag/v1.1.0) - See https://platform.openai.com/docs/api-reference/chat/create#chat-create-functions - """ - if config is None: - config = self - if messages is None: - messages = self._oai_messages[sender] - message = messages[-1] - func_call = message.get("function_call") - if func_call: - func_return = self.execute_function(func_call) - return True, func_return - - return False, None - - def _str_for_tool_response(self, tool_response): - return str(tool_response.get("content", "")) - - def generate_tool_calls_reply( - self, - messages: list[dict] | None = None, - sender: Agent | None = None, - config: Any | None = None, - ) -> tuple[bool, dict | None]: - """Generate a reply using tool call.""" - if config is None: - config = self - if messages is None: - messages = self._oai_messages[sender] - message = messages[-1] - tool_returns = [] - for tool_call in message.get("tool_calls", []): - function_call = tool_call.get("function", {}) - func_return = self.execute_function(function_call) - content = func_return.get("content", "") - if content is None: - content = "" - tool_call_id = tool_call.get("id", None) - if tool_call_id is not None: - tool_call_response = { - "tool_call_id": tool_call_id, - "role": "tool", - "content": content, - } - else: - # Do not include tool_call_id if it is not present. - # This is to make the tool call object compatible with Mistral API. - tool_call_response = { - "role": "tool", - "content": content, - } - tool_returns.append(tool_call_response) - if tool_returns: - return True, { - "role": "tool", - "tool_responses": tool_returns, - "content": "\n\n".join([self._str_for_tool_response(tool_return) for tool_return in tool_returns]), - } - return False, None - - async def _a_execute_tool_call(self, tool_call): - id = tool_call["id"] - function_call = tool_call.get("function", {}) - func_return = await self.a_execute_function(function_call) - return { - "tool_call_id": id, - "role": "tool", - "content": func_return.get("content", ""), - } - - async def a_generate_tool_calls_reply( - self, - messages: list[dict] | None = None, - sender: Agent | None = None, - config: Any | None = None, - ) -> tuple[bool, dict | None]: - """Generate a reply using async function call.""" - if config is None: - config = self - if messages is None: - messages = self._oai_messages[sender] - message = messages[-1] - async_tool_calls = [] - for tool_call in message.get("tool_calls", []): - async_tool_calls.append(self._a_execute_tool_call(tool_call)) - if async_tool_calls: - tool_returns = await asyncio.gather(*async_tool_calls) - return True, { - "role": "tool", - "tool_responses": tool_returns, - "content": "\n\n".join([self._str_for_tool_response(tool_return) for tool_return in tool_returns]), - } - - return False, None - def check_termination_and_human_reply( self, messages: list[dict] | None = None, @@ -1337,29 +1071,9 @@ def check_termination_and_human_reply( if reply or self._max_consecutive_auto_reply_dict[sender] == 0: # reset the consecutive_auto_reply_counter self._consecutive_auto_reply_counter[sender] = 0 - # User provided a custom response, return function and tool failures indicating user interruption - tool_returns = [] - if message.get("function_call", False): - tool_returns.append( - { - "role": "function", - "name": message["function_call"].get("name", ""), - "content": "USER INTERRUPTED", - } - ) - - if message.get("tool_calls", False): - tool_returns.extend( - [ - {"role": "tool", "tool_call_id": tool_call.get("id", ""), "content": "USER INTERRUPTED"} - for tool_call in message["tool_calls"] - ] - ) - + response = {"role": "user", "content": reply} - if tool_returns: - response["tool_responses"] = tool_returns - + return True, response # increment the consecutive_auto_reply_counter @@ -1418,28 +1132,8 @@ async def a_check_termination_and_human_reply( if reply or self._max_consecutive_auto_reply_dict[sender] == 0: # User provided a custom response, return function and tool results indicating user interruption # reset the consecutive_auto_reply_counter - self._consecutive_auto_reply_counter[sender] = 0 - tool_returns = [] - if message.get("function_call", False): - tool_returns.append( - { - "role": "function", - "name": message["function_call"].get("name", ""), - "content": "USER INTERRUPTED", - } - ) - - if message.get("tool_calls", False): - tool_returns.extend( - [ - {"role": "tool", "tool_call_id": tool_call.get("id", ""), "content": "USER INTERRUPTED"} - for tool_call in message["tool_calls"] - ] - ) - + self._consecutive_auto_reply_counter[sender] = 0 response = {"role": "user", "content": reply} - if tool_returns: - response["tool_responses"] = tool_returns return True, response @@ -1461,8 +1155,6 @@ def generate_reply( Use registered auto reply functions to generate replies. By default, the following functions are checked in order: 1. check_termination_and_human_reply - 2. generate_function_call_reply (deprecated in favor of tool_calls) - 3. generate_tool_calls_reply 5. generate_oai_reply Every function returns a tuple (final, reply). When a function returns final=False, the next function will be checked. @@ -1520,8 +1212,6 @@ async def a_generate_reply( Use registered auto reply functions to generate replies. By default, the following functions are checked in order: 1. check_termination_and_human_reply - 2. generate_function_call_reply - 3. generate_tool_calls_reply 5. generate_oai_reply Every function returns a tuple (final, reply). When a function returns final=False, the next function will be checked. @@ -1636,44 +1326,6 @@ async def a_get_human_input(self, prompt: str) -> str: reply = await loop.run_in_executor(None, functools.partial(self.get_human_input, prompt)) return reply - def execute_function(self, func_call: dict) -> dict[str, str]: - """Execute a function call and return the result. - - Override this function to modify the way to execute function and tool calls. - - Args: - func_call: a dictionary extracted from openai message at "function_call" or "tool_calls" with keys "name" and "arguments". - - """ - - func_name = func_call.get("name", "") - content = f"Error: Function {func_name} not found." - - return { - "name": func_name, - "role": "function", - "content": str(content), - } - - async def a_execute_function(self, func_call): - """Execute an async function call and return the result. - - Override this function to modify the way async functions and tools are executed. - - Args: - func_call: a dictionary extracted from openai message at key "function_call" or "tool_calls" with keys "name" and "arguments". - - """ - - func_name = func_call.get("name", "") - - content = f"Error: Function {func_name} not found." - - return { - "name": func_name, - "role": "function", - "content": str(content), - } def generate_init_message(self, message: dict | str | None, **kwargs) -> str | dict: """Generate the initial message for the agent. @@ -1753,95 +1405,6 @@ async def a_generate_init_message(self, message: dict | str | None, **kwargs) -> return self._handle_carryover(message, kwargs) - def update_function_signature(self, func_sig: str | dict, is_remove: None): - """update a function_signature in the LLM configuration for function_call. - - Args: - func_sig (str or dict): description/name of the function to update/remove to the model. See: https://platform.openai.com/docs/api-reference/chat/create#chat/create-functions - is_remove: whether removing the function from llm_config with name 'func_sig' - - Deprecated as of [OpenAI API v1.1.0](https://github.com/openai/openai-python/releases/tag/v1.1.0) - See https://platform.openai.com/docs/api-reference/chat/create#chat-create-function_call - """ - - if not isinstance(self.llm_config, dict): - error_msg = "To update a function signature, agent must have an llm_config" - raise AssertionError(error_msg) - - if is_remove: - if "functions" not in self.llm_config.keys(): - error_msg = "The agent config doesn't have function {name}.".format(name=func_sig) - logger.error(error_msg) - raise AssertionError(error_msg) - else: - self.llm_config["functions"] = [ - func for func in self.llm_config["functions"] if func["name"] != func_sig - ] - else: - if not isinstance(func_sig, dict): - raise ValueError( - f"The function signature must be of the type dict. Received function signature type {type(func_sig)}" - ) - - self._assert_valid_name(func_sig["name"]) - if "functions" in self.llm_config.keys(): - if any(func["name"] == func_sig["name"] for func in self.llm_config["functions"]): - warnings.warn(f"Function '{func_sig['name']}' is being overridden.", UserWarning) - - self.llm_config["functions"] = [ - func for func in self.llm_config["functions"] if func.get("name") != func_sig["name"] - ] + [func_sig] - else: - self.llm_config["functions"] = [func_sig] - - if len(self.llm_config["functions"]) == 0: - del self.llm_config["functions"] - - self.client = OpenAIWrapper(**self.llm_config) - - def update_tool_signature(self, tool_sig: str | dict, is_remove: None): - """update a tool_signature in the LLM configuration for tool_call. - - Args: - tool_sig (str or dict): description/name of the tool to update/remove to the model. See: https://platform.openai.com/docs/api-reference/chat/create#chat-create-tools - is_remove: whether removing the tool from llm_config with name 'tool_sig' - """ - - if not self.llm_config: - error_msg = "To update a tool signature, agent must have an llm_config" - raise AssertionError(error_msg) - - if is_remove: - if "tools" not in self.llm_config.keys(): - error_msg = "The agent config doesn't have tool {name}.".format(name=tool_sig) - logger.error(error_msg) - raise AssertionError(error_msg) - else: - self.llm_config["tools"] = [ - tool for tool in self.llm_config["tools"] if tool["function"]["name"] != tool_sig - ] - else: - if not isinstance(tool_sig, dict): - raise ValueError( - f"The tool signature must be of the type dict. Received tool signature type {type(tool_sig)}" - ) - self._assert_valid_name(tool_sig["function"]["name"]) - if "tools" in self.llm_config: - if any(tool["function"]["name"] == tool_sig["function"]["name"] for tool in self.llm_config["tools"]): - warnings.warn(f"Function '{tool_sig['function']['name']}' is being overridden.", UserWarning) - self.llm_config["tools"] = [ - tool - for tool in self.llm_config["tools"] - if tool.get("function", {}).get("name") != tool_sig["function"]["name"] - ] + [tool_sig] - else: - self.llm_config["tools"] = [tool_sig] - - if len(self.llm_config["tools"]) == 0: - del self.llm_config["tools"] - - self.client = OpenAIWrapper(**self.llm_config) - def _wrap_function(self, func: F) -> F: """Wrap the function to dump the return value to json. @@ -1959,8 +1522,6 @@ def process_last_received_message(self, messages: list[dict]) -> list[dict]: if len(messages) == 0: return messages # No message to process. last_message = messages[-1] - if "function_call" in last_message: - return messages # Last message is a function call. if "context" in last_message: return messages # Last message contains a context key. if "content" not in last_message: @@ -2004,8 +1565,6 @@ async def a_process_last_received_message(self, messages: list[dict]) -> list[di if len(messages) == 0: return messages # No message to process. last_message = messages[-1] - if "function_call" in last_message: - return messages # Last message is a function call. if "context" in last_message: return messages # Last message contains a context key. if "content" not in last_message: From bc184eb21b43fff2a1b81f40641847992d2fcb59 Mon Sep 17 00:00:00 2001 From: fmuuly <64724985+fmp453@users.noreply.github.com> Date: Mon, 2 Feb 2026 17:48:14 +0900 Subject: [PATCH 25/25] remove docstring --- train_methods/legacy_autogen/cache.py | 195 +------------------------ train_methods/legacy_autogen/stream.py | 39 ----- 2 files changed, 1 insertion(+), 233 deletions(-) diff --git a/train_methods/legacy_autogen/cache.py b/train_methods/legacy_autogen/cache.py index decfe5b..1821219 100644 --- a/train_methods/legacy_autogen/cache.py +++ b/train_methods/legacy_autogen/cache.py @@ -1,4 +1,3 @@ -import os from pathlib import Path from types import TracebackType from typing import Any, Protocol, Self @@ -6,50 +5,17 @@ import diskcache class AbstractCache(Protocol): - """ - This protocol defines the basic interface for cache operations. - Implementing classes should provide concrete implementations for - these methods to handle caching mechanisms. - """ def get(self, key: str, default: Any | None = None) -> Any | None: - """ - Retrieve an item from the cache. - - Args: - key (str): The key identifying the item in the cache. - default (optional): The default value to return if the key is not found. - Defaults to None. - - Returns: - The value associated with the key if found, else the default value. - """ ... def set(self, key: str, value: Any) -> None: - """ - Set an item in the cache. - - Args: - key (str): The key under which the item is to be stored. - value: The value to be stored in the cache. - """ ... def close(self) -> None: - """ - Close the cache. Perform any necessary cleanup, such as closing network connections or - releasing resources. - """ ... def __enter__(self) -> Self: - """ - Enter the runtime context related to this object. - - The with statement will bind this method's return value to the target(s) - specified in the as clause of the statement, if any. - """ ... def __exit__( @@ -58,86 +24,22 @@ def __exit__( exc_value: BaseException | None, traceback: TracebackType | None, ) -> None: - """ - Exit the runtime context and close the cache. - - Args: - exc_type: The exception type if an exception was raised in the context. - exc_value: The exception value if an exception was raised in the context. - traceback: The traceback if an exception was raised in the context. - """ ... class DiskCache(AbstractCache): - """ - Implementation of AbstractCache using the DiskCache library. - - This class provides a concrete implementation of the AbstractCache - interface using the diskcache library for caching data on disk. - - Attributes: - cache (diskcache.Cache): The DiskCache instance used for caching. - - Methods: - __init__(self, seed): Initializes the DiskCache with the given seed. - get(self, key, default=None): Retrieves an item from the cache. - set(self, key, value): Sets an item in the cache. - close(self): Closes the cache. - __enter__(self): Context management entry. - __exit__(self, exc_type, exc_value, traceback): Context management exit. - """ - def __init__(self, seed: str | int): - """ - Initialize the DiskCache instance. - - Args: - seed (str | int): A seed or namespace for the cache. This is used to create - a unique storage location for the cache data. - - """ self.cache = diskcache.Cache(seed) def get(self, key: str, default: Any | None = None) -> Any | None: - """ - Retrieve an item from the cache. - - Args: - key (str): The key identifying the item in the cache. - default (optional): The default value to return if the key is not found. - Defaults to None. - - Returns: - The value associated with the key if found, else the default value. - """ return self.cache.get(key, default) def set(self, key: str, value: Any) -> None: - """ - Set an item in the cache. - - Args: - key (str): The key under which the item is to be stored. - value: The value to be stored in the cache. - """ self.cache.set(key, value) def close(self) -> None: - """ - Close the cache. - - Perform any necessary cleanup, such as closing file handles or - releasing resources. - """ self.cache.close() def __enter__(self) -> Self: - """ - Enter the runtime context related to the object. - - Returns: - self: The instance itself. - """ return self def __exit__( @@ -146,16 +48,6 @@ def __exit__( exc_value: BaseException | None, traceback: TracebackType | None, ) -> None: - """ - Exit the runtime context related to the object. - - Perform cleanup actions such as closing the cache. - - Args: - exc_type: The exception type if an exception was raised in the context. - exc_value: The exception value if an exception was raised in the context. - traceback: The traceback if an exception was raised in the context. - """ self.close() class CacheFactory: @@ -164,34 +56,10 @@ def cache_factory( seed: str | int, cache_path_root: str = ".cache", ) -> AbstractCache: - """ - Factory function for creating cache instances. - - Args: - seed (str | int): Used as a seed or namespace for the cache. - cache_path_root (str): Root path for the disk cache. - - Returns: - An instance of DiskCache - - """ - # Default to DiskCache if neither Redis nor Cosmos DB configurations are provided path = Path(cache_path_root, str(seed)) - return DiskCache(os.path.join(".", path)) + return DiskCache(Path(".", path)) class Cache(AbstractCache): - """ - A wrapper class for managing cache configuration and instances. - - This class provides a unified interface for creating and interacting with - different types of cache (e.g., Redis, Disk). It abstracts the underlying - cache implementation details, providing methods for cache operations. - - Attributes: - config (dict[str, Any]): A dictionary containing cache configuration. - cache: The cache instance created based on the provided configuration. - """ - ALLOWED_CONFIG_KEYS = [ "cache_seed", "cache_path_root", @@ -199,30 +67,9 @@ class Cache(AbstractCache): @staticmethod def disk(cache_seed: str | int = 42, cache_path_root: str = ".cache") -> "Cache": - """ - Create a Disk cache instance. - - Args: - cache_seed (str | int, optional): A seed for the cache. Defaults to 42. - cache_path_root (str, optional): The root path for the disk cache. Defaults to ".cache". - - Returns: - Cache: A Cache instance configured for Disk caching. - """ return Cache({"cache_seed": cache_seed, "cache_path_root": cache_path_root}) def __init__(self, config: dict[str, Any]): - """ - Initialize the Cache with the given configuration. - - Validates the configuration keys and creates the cache instance. - - Args: - config (dict[str, Any]): A dictionary containing the cache configuration. - - Raises: - ValueError: If an invalid configuration key is provided. - """ self.config = config # Ensure that the seed is always treated as a string before being passed to any cache factory or stored. self.config["cache_seed"] = str(self.config.get("cache_seed", 42)) @@ -238,12 +85,6 @@ def __init__(self, config: dict[str, Any]): ) def __enter__(self) -> "Cache": - """ - Enter the runtime context related to the cache object. - - Returns: - The cache instance for use within a context block. - """ return self.cache.__enter__() def __exit__( @@ -252,47 +93,13 @@ def __exit__( exc_value: BaseException | None, traceback: TracebackType | None, ) -> None: - """ - Exit the runtime context related to the cache object. - - Cleans up the cache instance and handles any exceptions that occurred - within the context. - - Args: - exc_type: The exception type if an exception was raised in the context. - exc_value: The exception value if an exception was raised in the context. - traceback: The traceback if an exception was raised in the context. - """ return self.cache.__exit__(exc_type, exc_value, traceback) def get(self, key: str, default: Any | None = None) -> Any | None: - """ - Retrieve an item from the cache. - - Args: - key (str): The key identifying the item in the cache. - default (optional): The default value to return if the key is not found. - Defaults to None. - - Returns: - The value associated with the key if found, else the default value. - """ return self.cache.get(key, default) def set(self, key: str, value: Any) -> None: - """ - Set an item in the cache. - - Args: - key (str): The key under which the item is to be stored. - value: The value to be stored in the cache. - """ self.cache.set(key, value) def close(self) -> None: - """ - Close the cache. - - Perform any necessary cleanup, such as closing connections or releasing resources. - """ self.cache.close() diff --git a/train_methods/legacy_autogen/stream.py b/train_methods/legacy_autogen/stream.py index 7ae567b..a022145 100644 --- a/train_methods/legacy_autogen/stream.py +++ b/train_methods/legacy_autogen/stream.py @@ -5,34 +5,15 @@ class OutputStream(Protocol): def print(self, *objects: Any, sep: str = " ", end: str = "\n", flush: bool = False) -> None: - """Print data to the output stream. - - Args: - objects (any): The data to print. - sep (str, optional): The separator between objects. Defaults to " ". - end (str, optional): The end of the output. Defaults to "\n". - flush (bool, optional): Whether to flush the output. Defaults to False. - """ ... # pragma: no cover class InputStream(Protocol): def input(self, prompt: str = "", *, password: bool = False) -> str: - """Read a line from the input stream. - - Args: - prompt (str, optional): The prompt to display. Defaults to "". - password (bool, optional): Whether to read a password. Defaults to False. - - Returns: - str: The line read from the input stream. - - """ ... # pragma: no cover class IOStream(InputStream, OutputStream, Protocol): - """A protocol for input/output streams.""" # ContextVar must be used in multithreaded or async environments _default_io_stream: ContextVar["IOStream" | None] = ContextVar("default_iostream", default=None) @@ -41,31 +22,16 @@ class IOStream(InputStream, OutputStream, Protocol): @staticmethod def set_global_default(stream: "IOStream") -> None: - """Set the default input/output stream. - - Args: - stream (IOStream): The input/output stream to set as the default. - """ IOStream._global_default = stream @staticmethod def get_global_default() -> "IOStream": - """Get the default input/output stream. - - Returns: - IOStream: The default input/output stream. - """ if IOStream._global_default is None: raise RuntimeError("No global default IOStream has been set") return IOStream._global_default @staticmethod def get_default() -> "IOStream": - """Get the default input/output stream. - - Returns: - IOStream: The default input/output stream. - """ iostream = IOStream._default_io_stream.get() if iostream is None: iostream = IOStream.get_global_default() @@ -76,11 +42,6 @@ def get_default() -> "IOStream": @staticmethod @contextmanager def set_default(stream: "IOStream" | None) -> Iterator[None]: - """Set the default input/output stream. - - Args: - stream (IOStream): The input/output stream to set as the default. - """ global _default_io_stream try: token = IOStream._default_io_stream.set(stream)