diff --git a/evaluation/scripts/hotpot/hotpot_eval.py b/evaluation/scripts/hotpot/hotpot_eval.py new file mode 100644 index 00000000..036f7a69 --- /dev/null +++ b/evaluation/scripts/hotpot/hotpot_eval.py @@ -0,0 +1,246 @@ +import importlib.util +import json +import os + +from concurrent.futures import ThreadPoolExecutor, as_completed + +from datasets import load_dataset +from dotenv import load_dotenv +from tqdm import tqdm + +from memos.configs.mem_cube import GeneralMemCubeConfig +from memos.configs.mem_os import MOSConfig +from memos.mem_cube.general import GeneralMemCube +from memos.mem_os.main import MOS + + +load_dotenv() + +db_name = "stx-hotpot-001" + +openapi_config = { + "model_name_or_path": "gpt-4o", + "top_k": 50, + "remove_think_prefix": True, + "api_key": os.getenv("OPENAI_API_KEY", "sk-xxxxx"), + "api_base": os.getenv("OPENAI_API_BASE", "https://api.openai.com/v1"), +} +neo4j_uri = os.getenv("NEO4J_URI", "bolt://localhost:7687") +data = load_dataset("hotpotqa/hotpot_qa", "distractor") +base_config = { + "chat_model": { + "backend": "openai", + "config": openapi_config, + }, + "mem_reader": { + "backend": "simple_struct", + "config": { + "llm": {"backend": "openai", "config": openapi_config}, + "embedder": { + "backend": "universal_api", + "config": { + "provider": "openai", + "api_key": os.getenv("OPENAI_API_KEY", "sk-xxxxx"), + "model_name_or_path": "text-embedding-3-large", + "base_url": os.getenv("OPENAI_API_BASE", "https://api.openai.com/v1"), + }, + }, + "chunker": { + "backend": "sentence", + "config": { + "tokenizer_or_token_counter": "gpt2", + "chunk_size": 512, + "chunk_overlap": 128, + "min_sentences_per_chunk": 1, + }, + }, + }, + }, + "max_turns_window": 20, + "top_k": 5, + "enable_textual_memory": True, + "enable_activation_memory": False, + "enable_parametric_memory": False, +} + + +def init_mos_and_cube(user_name: str) -> MOS: + cfg = dict(base_config) + cfg["user_id"] = user_name + mos_config = MOSConfig(**cfg) + mos = MOS(mos_config) + cube_conf = GeneralMemCubeConfig.model_validate( + { + "user_id": user_name, + "cube_id": f"{user_name}", + "text_mem": { + "backend": "tree_text", + "config": { + "extractor_llm": {"backend": "openai", "config": openapi_config}, + "dispatcher_llm": {"backend": "openai", "config": openapi_config}, + "graph_db": { + "backend": "neo4j", + "config": { + "uri": neo4j_uri, + "user": "neo4j", + "password": "iaarlichunyu", + "db_name": db_name, + "user_name": user_name, + "use_multi_db": False, + "auto_create": True, + }, + }, + "embedder": { + "backend": "universal_api", + "config": { + "provider": "openai", + "api_key": os.getenv("OPENAI_API_KEY", "sk-xxxxx"), + "model_name_or_path": "text-embedding-3-large", + "base_url": os.getenv("OPENAI_API_BASE", "https://api.openai.com/v1"), + }, + }, + "reorganize": False, + }, + }, + "act_mem": {}, + "para_mem": {}, + } + ) + mem_cube = GeneralMemCube(cube_conf) + temp_dir = "tmp/" + user_name + if not os.path.exists(temp_dir) or not os.listdir(temp_dir): + mem_cube.dump(temp_dir) + mos.register_mem_cube(temp_dir, mem_cube_id=user_name) + return mos + + +def build_context_text(context_list): + parts = [] + for title, sentences in context_list: + text = " ".join(s.strip() for s in sentences if s.strip()) + parts.append(f"{title}: {text}") + return "\n".join(parts) + + +def build_and_ask(item): + qid = item.get("_id") or item.get("id") + question = item["question"] + mos = init_mos_and_cube(qid) + ctx = item.get("context") + if isinstance(ctx, dict): + titles = ctx.get("title") or [] + 
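+        # HotpotQA's HF "distractor" config stores context as two parallel
+        # lists ({"title": [...], "sentences": [[...], ...]}), so each title
+        # below is zipped with its own list of sentences.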
sentences_list = ctx.get("sentences") or []
+        for title, sentences in zip(titles, sentences_list, strict=False):
+            text = " ".join(s.strip() for s in sentences if isinstance(s, str) and s.strip())
+            if title or text:
+                mos.add(memory_content=f"{title}: {text}")
+    else:
+        for entry in ctx or []:
+            if isinstance(entry, list) and len(entry) >= 2:
+                title, sentences = entry[0], entry[1]
+            elif isinstance(entry, dict):
+                title = entry.get("title", "")
+                sentences = entry.get("sentences", [])
+            else:
+                continue
+            text = " ".join(s.strip() for s in sentences if isinstance(s, str) and s.strip())
+            if title or text:
+                mos.add(memory_content=f"{title}: {text}")
+    answer = mos.chat(question, qid).strip()
+    print("question:", question)
+    print("answer:", answer)
+    return qid, answer
+
+
+pred_answers = {}
+output_dir = "evaluation/data/hotpot/output"
+os.makedirs(output_dir, exist_ok=True)
+pred_path = os.path.join(output_dir, "dev_distractor_pred.json")
+gold_path = os.path.join(output_dir, "dev_distractor_gold.json")
+
+
+def write_gold(data):
+    split = data.get("validation")
+    items_list = [split[i] for i in range(10)]
+    out = []
+    for it in items_list:
+        qid = it.get("_id") or it.get("id")
+        sp = it.get("supporting_facts")
+        if isinstance(sp, dict):
+            titles = sp.get("title") or []
+            sent_ids = sp.get("sent_id") or []
+            sp_list = [[t, s] for t, s in zip(titles, sent_ids, strict=False)]
+        else:
+            sp_list = sp or []
+        ctx = it.get("context")
+        if isinstance(ctx, dict):
+            titles = ctx.get("title") or []
+            sentences = ctx.get("sentences") or []
+            ctx_list = [[t, s] for t, s in zip(titles, sentences, strict=False)]
+        else:
+            ctx_list = ctx or []
+        out.append(
+            {
+                "_id": qid,
+                "question": it.get("question"),
+                "answer": it.get("answer"),
+                "supporting_facts": sp_list,
+                "context": ctx_list,
+            }
+        )
+    with open(gold_path, "w", encoding="utf-8") as f:
+        json.dump(out, f, ensure_ascii=False)
+
+
+def run_eval():
+    spec = importlib.util.spec_from_file_location(
+        "hotpot_eval_v1", "evaluation/scripts/hotpot/hotpot_evaluate_v1.py"
+    )
+    m = importlib.util.module_from_spec(spec)
+    spec.loader.exec_module(m)
+    print("Evaluation scores:")
+    m.eval(pred_path, gold_path)
+
+
+def main():
+    interval = 50
+    split = data.get("validation")
+    items_list = [split[i] for i in range(10)]
+    # run_eval() reads gold_path, so write the gold file before any evaluation
+    write_gold(data)
+    if os.path.exists(pred_path):
+        try:
+            with open(pred_path, encoding="utf-8") as f:
+                prev = json.load(f)
+            if isinstance(prev, dict) and isinstance(prev.get("answer"), dict):
+                pred_answers.update(prev["answer"])
+        except Exception:
+            pass
+    processed = len(pred_answers)
+    print("Starting evaluation, total samples:", len(items_list))
+    print("Existing predictions:", processed)
+    pending_items = []
+    for it in items_list:
+        qid = it.get("_id") or it.get("id")
+        if qid not in pred_answers:
+            pending_items.append(it)
+    with ThreadPoolExecutor(max_workers=4) as executor:
+        futures = {
+            executor.submit(build_and_ask, item): idx for idx, item in enumerate(pending_items)
+        }
+        for future in tqdm(as_completed(futures), total=len(futures)):
+            qid, answer = future.result()
+            pred_answers[qid] = answer
+            processed += 1
+            if processed % 10 == 0:
+                print("Completed:", processed, "Remaining:", len(items_list) - processed)
+            with open(pred_path, "w", encoding="utf-8") as f:
+                json.dump({"answer": pred_answers, "sp": {}}, f, ensure_ascii=False, indent=2)
+            if processed % interval == 0:
+                print("Interim evaluation, progress:", processed)
+                run_eval()
+
+    print("Final evaluation:")
+    run_eval()
+
+
+if __name__ == "__main__":
+    main()
diff --git a/evaluation/scripts/hotpot/hotpot_evaluate_v1.py b/evaluation/scripts/hotpot/hotpot_evaluate_v1.py
new file mode 100644 index 00000000..2002d8dc --- /dev/null +++ b/evaluation/scripts/hotpot/hotpot_evaluate_v1.py @@ -0,0 +1,149 @@ +import re +import string +import sys + +from collections import Counter + +import ujson as json + + +def normalize_answer(s): + def remove_articles(text): + return re.sub(r"\b(a|an|the)\b", " ", text) + + def white_space_fix(text): + return " ".join(text.split()) + + def remove_punc(text): + exclude = set(string.punctuation) + return "".join(ch for ch in text if ch not in exclude) + + def lower(text): + return text.lower() + + return white_space_fix(remove_articles(remove_punc(lower(s)))) + + +def f1_score(prediction, ground_truth): + normalized_prediction = normalize_answer(prediction) + normalized_ground_truth = normalize_answer(ground_truth) + + zero_metric = (0, 0, 0) + + if ( + normalized_prediction in ["yes", "no", "noanswer"] + and normalized_prediction != normalized_ground_truth + ): + return zero_metric + if ( + normalized_ground_truth in ["yes", "no", "noanswer"] + and normalized_prediction != normalized_ground_truth + ): + return zero_metric + + prediction_tokens = normalized_prediction.split() + ground_truth_tokens = normalized_ground_truth.split() + common = Counter(prediction_tokens) & Counter(ground_truth_tokens) + num_same = sum(common.values()) + if num_same == 0: + return zero_metric + precision = 1.0 * num_same / len(prediction_tokens) + recall = 1.0 * num_same / len(ground_truth_tokens) + f1 = (2 * precision * recall) / (precision + recall) + return f1, precision, recall + + +def exact_match_score(prediction, ground_truth): + return normalize_answer(prediction) == normalize_answer(ground_truth) + + +def update_answer(metrics, prediction, gold): + em = exact_match_score(prediction, gold) + f1, prec, recall = f1_score(prediction, gold) + metrics["em"] += float(em) + metrics["f1"] += f1 + metrics["prec"] += prec + metrics["recall"] += recall + return em, prec, recall + + +def update_sp(metrics, prediction, gold): + cur_sp_pred = set(map(tuple, prediction)) + gold_sp_pred = set(map(tuple, gold)) + tp, fp, fn = 0, 0, 0 + for e in cur_sp_pred: + if e in gold_sp_pred: + tp += 1 + else: + fp += 1 + for e in gold_sp_pred: + if e not in cur_sp_pred: + fn += 1 + prec = 1.0 * tp / (tp + fp) if tp + fp > 0 else 0.0 + recall = 1.0 * tp / (tp + fn) if tp + fn > 0 else 0.0 + f1 = 2 * prec * recall / (prec + recall) if prec + recall > 0 else 0.0 + em = 1.0 if fp + fn == 0 else 0.0 + metrics["sp_em"] += em + metrics["sp_f1"] += f1 + metrics["sp_prec"] += prec + metrics["sp_recall"] += recall + return em, prec, recall + + +def eval(prediction_file, gold_file): + with open(prediction_file) as f: + prediction = json.load(f) + with open(gold_file) as f: + gold = json.load(f) + + metrics = { + "em": 0, + "f1": 0, + "prec": 0, + "recall": 0, + "sp_em": 0, + "sp_f1": 0, + "sp_prec": 0, + "sp_recall": 0, + "joint_em": 0, + "joint_f1": 0, + "joint_prec": 0, + "joint_recall": 0, + } + for dp in gold: + cur_id = dp["_id"] + can_eval_joint = True + if cur_id not in prediction["answer"]: + can_eval_joint = False + else: + em, prec, recall = update_answer(metrics, prediction["answer"][cur_id], dp["answer"]) + if cur_id not in prediction["sp"]: + can_eval_joint = False + else: + sp_em, sp_prec, sp_recall = update_sp( + metrics, prediction["sp"][cur_id], dp["supporting_facts"] + ) + + if can_eval_joint: + joint_prec = prec * sp_prec + joint_recall = recall * sp_recall + if joint_prec + joint_recall > 0: + joint_f1 = 2 * joint_prec * joint_recall / (joint_prec + 
joint_recall) + else: + joint_f1 = 0.0 + joint_em = em * sp_em + + metrics["joint_em"] += joint_em + metrics["joint_f1"] += joint_f1 + metrics["joint_prec"] += joint_prec + metrics["joint_recall"] += joint_recall + + n = len(gold) + for k in metrics: + metrics[k] /= n + + print(metrics) + + +if __name__ == "__main__": + eval(sys.argv[1], sys.argv[2]) diff --git a/evaluation/scripts/longbenchV2/import_data.py b/evaluation/scripts/longbenchV2/import_data.py new file mode 100644 index 00000000..ff17cd27 --- /dev/null +++ b/evaluation/scripts/longbenchV2/import_data.py @@ -0,0 +1,20 @@ +from datasets import load_dataset + + +dataset = load_dataset("zai-org/LongBench-v2", split="train") +print(dataset) + + +def truncate(value, max_len=200): + if isinstance(value, str) and len(value) > max_len: + return value[:max_len] + "... [TRUNCATED]" + return value + + +for i in range(10): + sample = dataset[i] + print(f"========== Sample {i} ==========") + for key, value in sample.items(): + print(f"{key}: {truncate(value)}") + + print("\n") diff --git a/evaluation/scripts/mmlongbench/eval/__init__.py b/evaluation/scripts/mmlongbench/eval/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/evaluation/scripts/mmlongbench/eval/eval_score.py b/evaluation/scripts/mmlongbench/eval/eval_score.py new file mode 100644 index 00000000..02ef6eb5 --- /dev/null +++ b/evaluation/scripts/mmlongbench/eval/eval_score.py @@ -0,0 +1,246 @@ +import re + +from collections import defaultdict +from math import isclose + + +def levenshtein_distance(s1, s2): + if len(s1) > len(s2): + s1, s2 = s2, s1 + + distances = range(len(s1) + 1) + for i2, c2 in enumerate(s2): + distances_ = [i2 + 1] + for i1, c1 in enumerate(s1): + if c1 == c2: + distances_.append(distances[i1]) + else: + distances_.append(1 + min((distances[i1], distances[i1 + 1], distances_[-1]))) + distances = distances_ + return distances[-1] + + +def anls_compute(groundtruth, prediction, threshold=0.5): + dist = levenshtein_distance(groundtruth, prediction) + length = max(len(groundtruth.upper()), len(prediction.upper())) + value = 0.0 if length == 0 else float(dist) / float(length) + anls = 1.0 - value + if anls <= threshold: + anls = 0.0 + return anls + + +def is_float_equal( + reference, prediction, include_percentage: bool = False, is_close: float = False +) -> bool: + def get_precision(gt_ans: float) -> int: + precision = 3 + if "." 
in str(gt_ans): + precision = len(str(gt_ans).split(".")[-1]) + return precision + + reference = float(str(reference).strip().rstrip("%").strip()) + try: + prediction = float(str(prediction).strip().rstrip("%").strip()) + except Exception: + return False + + gt_result = [reference / 100, reference, reference * 100] if include_percentage else [reference] + for item in gt_result: + try: + if is_close and isclose(item, prediction, rel_tol=0.01): + return True + precision = max(min(get_precision(prediction), get_precision(item)), 2) + if round(prediction, precision) == round(item, precision): + return True + except Exception: + continue + return False + + +def get_clean_string(s): + s = str(s).lower().strip() + + for suffix in ["mile", "miles", "million"]: + if s.endswith(suffix): + s = s[: -len(suffix)].strip() + + s = re.sub(r"\s*\([^)]*\)", "", s).strip() + s = re.sub(r"^['\"]|['\"]$", "", s).strip() + s = s.lstrip("$").rstrip("%").strip() + + return s + + +def is_exact_match(s): + flag = False + # Website + if "https://" in s: + flag = True + # code file + if s.endswith((".py", ".ipynb")) or s.startswith("page"): + flag = True + # telephone number + if re.fullmatch(r"\b\d+(-\d+|\s\d+)?\b", s): + flag = True + # time + if "a.m." in s or "p.m." in s: + flag = True + # YYYY-MM-DD + if re.fullmatch(r"\b\d{4}[-\s]\d{2}[-\s]\d{2}\b", s): + flag = True + # YYYY-MM + if re.fullmatch(r"\b\d{4}[-\s]\d{2}\b", s): + flag = True + # Email address + if re.fullmatch(r"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}", s): + flag = True + return flag + + +def isfloat(num): + try: + float(num) + return True + except ValueError: + return False + + +def eval_score(gt, pred, answer_type): + if answer_type == "Int": + try: + gt, pred = int(gt), int(float(pred)) + except Exception: + pred = "" + score = gt == pred + elif answer_type == "Float": + try: + gt = float(get_clean_string(str(gt))) + pred = float(get_clean_string(str(pred))) + except Exception: + pred = "" + score = is_float_equal(gt, pred, include_percentage=True, is_close=True) + elif answer_type in ["Str", "None"]: + gt = get_clean_string(gt) + pred = get_clean_string(pred) + score = gt == pred if is_exact_match(gt) else anls_compute(gt, pred) + else: + if isinstance(gt, str) and gt.startswith("["): + gt = eval(gt) + if not isinstance(gt, list): + gt = [gt] + if isinstance(pred, str) and pred.startswith("["): + pred = eval(pred) + if not isinstance(pred, list): + pred = [pred] + print(len(gt), len(pred)) + if len(gt) != len(pred): + score = 0.0 + else: + gt = sorted([get_clean_string(a) for a in gt]) + pred = sorted([get_clean_string(a) for a in pred]) + print(gt, pred) + if isfloat(gt[0]) or is_exact_match(gt[0]): + score = "-".join(gt) == "-".join(pred) + else: + score = min( + [anls_compute(gt_v, pred_v) for gt_v, pred_v in zip(gt, pred, strict=False)] + ) + + return float(score) + + +def eval_acc_and_f1(samples): + evaluated_samples = [sample for sample in samples if "score" in sample] + if not evaluated_samples: + return 0.0, 0.0 + + acc = sum([sample["score"] for sample in evaluated_samples]) / len(evaluated_samples) + try: + recall = sum( + [ + sample["score"] + for sample in evaluated_samples + if sample["answer"] != "Not answerable" + ] + ) / len([sample for sample in evaluated_samples if sample["answer"] != "Not answerable"]) + precision = sum( + [ + sample["score"] + for sample in evaluated_samples + if sample["answer"] != "Not answerable" + ] + ) / len([sample for sample in evaluated_samples if sample["pred"] != "Not answerable"]) + f1 = 2 
* recall * precision / (recall + precision) if (recall + precision) > 0.0 else 0.0 + except Exception: + f1 = 0.0 + + return acc, f1 + + +def show_results(samples, show_path=None): + for sample in samples: + sample["evidence_pages"] = eval(sample["evidence_pages"]) + sample["evidence_sources"] = eval(sample["evidence_sources"]) + + with open(show_path, "w") as f: + acc, f1 = eval_acc_and_f1(samples) + f.write(f"Overall Acc: {acc} | Question Number: {len(samples)}\n") + f.write(f"Overall F1-score: {f1} | Question Number: {len(samples)}\n") + f.write("-----------------------\n") + + acc_single_page, _ = eval_acc_and_f1( + [sample for sample in samples if len(sample["evidence_pages"]) == 1] + ) + acc_multi_page, _ = eval_acc_and_f1( + [ + sample + for sample in samples + if len(sample["evidence_pages"]) != 1 and sample["answer"] != "Not answerable" + ] + ) + acc_neg, _ = eval_acc_and_f1( + [sample for sample in samples if sample["answer"] == "Not answerable"] + ) + + f.write( + "Single-page | Accuracy: {} | Question Number: {}\n".format( + acc_single_page, + len([sample for sample in samples if len(sample["evidence_pages"]) == 1]), + ) + ) + f.write( + "Cross-page | Accuracy: {} | Question Number: {}\n".format( + acc_multi_page, + len( + [ + sample + for sample in samples + if len(sample["evidence_pages"]) != 1 + and sample["answer"] != "Not answerable" + ] + ), + ) + ) + f.write( + "Unanswerable | Accuracy: {} | Question Number: {}\n".format( + acc_neg, len([sample for sample in samples if sample["answer"] == "Not answerable"]) + ) + ) + f.write("-----------------------\n") + + source_sample_dict, document_type_dict = defaultdict(list), defaultdict(list) + for sample in samples: + for answer_source in sample["evidence_sources"]: + source_sample_dict[answer_source].append(sample) + document_type_dict[sample["doc_type"]].append(sample) + for type, sub_samples in source_sample_dict.items(): + f.write( + f"Evidence Sources: {type} | Accuracy: {eval_acc_and_f1(sub_samples)[0]} | Question Number: {len(sub_samples)}\n" + ) + + f.write("-----------------------\n") + for type, sub_samples in document_type_dict.items(): + f.write( + f"Document Type: {type} | Accuracy: {eval_acc_and_f1(sub_samples)[0]} | Question Number: {len(sub_samples)}\n" + ) diff --git a/evaluation/scripts/mmlongbench/eval/extract_answer.py b/evaluation/scripts/mmlongbench/eval/extract_answer.py new file mode 100644 index 00000000..b7f7e686 --- /dev/null +++ b/evaluation/scripts/mmlongbench/eval/extract_answer.py @@ -0,0 +1,33 @@ +import os + +import openai + +from dotenv import load_dotenv + + +load_dotenv() +client = openai.Client( + api_key=os.getenv("OPENAI_API_KEY", "sk-xxxxx"), + base_url=os.getenv("OPENAI_API_BASE", "https://api.openai.com/v1"), +) + + +def extract_answer(question, output, prompt, model_name="gpt-4o"): + response = client.chat.completions.create( + model=model_name, + messages=[ + { + "role": "user", + "content": prompt, + }, + {"role": "assistant", "content": f"\n\nQuestion:{question}\nAnalysis:{output}\n"}, + ], + temperature=0.0, + max_tokens=256, + top_p=1, + frequency_penalty=0, + presence_penalty=0, + ) + response = response.choices[0].message.content + + return response diff --git a/evaluation/scripts/mmlongbench/eval/prompt_for_answer_extraction.md b/evaluation/scripts/mmlongbench/eval/prompt_for_answer_extraction.md new file mode 100644 index 00000000..a309c093 --- /dev/null +++ b/evaluation/scripts/mmlongbench/eval/prompt_for_answer_extraction.md @@ -0,0 +1,35 @@ +Given the question and analysis, 
you are tasked to extract answers with required formats from the free-form analysis.
+- Your extracted answers should be one of the following formats: (1) Integer, (2) Float, (3) String and (4) List. If you find from the analysis that the question cannot be answered from the given documents, type "Not answerable". Exception: If the analysis only tells you that it cannot read/understand the images or documents, type "Fail to answer".
+- Please make your response as concise as possible. Also note that your response should be formatted as below:
+```
+Extracted answer: [answer]
+Answer format: [answer format]
+```
+
+Please read the following example, then extract the answer from the model response and type it at the end of the prompt.
+
+---
+Question: List the primary questions asked about the services in this report.
+Analysis: The primary questions asked about the services in the report for The Limes Residential Home are:\n\n1. Is the service safe?\n2. Is the service effective?\n3. Is the service caring?\n4. Is the service responsive?\n5. Is the service well-led?
+Extracted answer: ['Is the service safe?', 'Is the service effective?', 'Is the service caring?', 'Is the service responsive?', 'Is the service well-led?']
+Answer format: List
+
+---
+Question: How many regulations of the HSCA 2008 are breached in all according to this report?
+Analysis: According to the report, the provider breached 10 Health and Social Care Act 2008 (Regulated Activities) Regulations in total. Here are the specifics:\n\n1. Regulation 13: Safeguarding service users from abuse and improper treatment\n2. Regulation 12: Safe care and treatment\n3. Regulation 18: Staffing\n4. Regulation 11: Need for consent\n5. Regulation 10: Dignity and respect\n6. Regulation 9: Person-centred care\n7. Regulation 17: Good governance\n8. Regulation 18 (CQC Registration Regulations 2009): Notification of other incidents\n9. Regulation 18: Failure to maintain an accurate and up-to-date care plan\n10. Regulation 11: Failure to implement the Mental Capacity Act 2005 code of practice effectively\n\nThese breaches involve issues concerning staffing, safeguarding, medicines management, dignity and respect, consent, care planning, governance, and failure to notify the CQC of incidents.
+Extracted answer: 10
+Answer format: Integer
+
+---
+Question: According to the survey, what is the percentage of Chinese who are paying more or about the same attention to politics after Trump's election?
+Analysis: The survey provided does not specify the percentage of Chinese individuals specifically who are paying more or about the same attention to politics after Trump's election. The report focuses primarily on American demographics and does not include specific details about the Chinese population in relation to this question. If you need information about a different demographic or a summary of the findings from the American demographic, I can certainly help with that!
+Extracted answer: Not answerable
+Answer format: String
+
+---
+Question: How many quotations from male respondents over 50 years old are included in this report?
+Analysis: The image you've provided appears to be a screenshot of a document with multiple charts. However, the text is too small and blurry to read accurately. If you can provide a clearer image or more context, I might be able to help you with your question.
+Extracted answer: Fail to answer +Answer format: String + +--- diff --git a/evaluation/scripts/mmlongbench/eval_docs.py b/evaluation/scripts/mmlongbench/eval_docs.py new file mode 100644 index 00000000..510a0b1e --- /dev/null +++ b/evaluation/scripts/mmlongbench/eval_docs.py @@ -0,0 +1,265 @@ +import csv +import json +import os +import re +import traceback + +from concurrent.futures import ThreadPoolExecutor, as_completed + +from dotenv import load_dotenv +from eval.eval_score import eval_acc_and_f1, eval_score, show_results +from eval.extract_answer import extract_answer +from tqdm import tqdm + +from memos.configs.mem_cube import GeneralMemCubeConfig +from memos.configs.mem_os import MOSConfig +from memos.mem_cube.general import GeneralMemCube +from memos.mem_os.main import MOS + + +load_dotenv() +openapi_config = { + "model_name_or_path": "gpt-4o", + "top_k": 50, + "remove_think_prefix": True, + "api_key": os.getenv("OPENAI_API_KEY", "sk-xxxxx"), + "api_base": os.getenv("OPENAI_API_BASE", "https://api.openai.com/v1"), +} +neo4j_uri = os.getenv("NEO4J_URI", "bolt://47.117.41.207:7687") +db_name = "stx-mmlongbench-004" + +doc_paths = [ + f + for f in os.listdir("evaluation/data/mmlongbench/documents") + if os.path.isfile(os.path.join("evaluation/data/mmlongbench/documents", f)) +] + +with open("evaluation/data/mmlongbench/samples.json") as f: + samples = json.load(f) + +RESULTS_PATH = "evaluation/data/mmlongbench/test_results.json" +completed_pairs: set[tuple[str, str]] = set() + + +def _load_existing_results(): + global completed_pairs + if os.path.exists(RESULTS_PATH): + try: + with open(RESULTS_PATH, encoding="utf-8") as f: + existing = json.load(f) + for r in existing: + did = r.get("doc_id") + q = r.get("question") + if did and q: + completed_pairs.add((did, q)) + return existing + except Exception: + return [] + return [] + + +def _doc_has_pending(doc_file: str) -> bool: + for s in samples: + if s.get("doc_id") == doc_file and (doc_file, s.get("question")) not in completed_pairs: + return True + return False + + +def get_user_name(doc_file): + csv_path = "evaluation/data/mmlongbench/user_doc_map.csv" + if os.path.exists(csv_path): + with open(csv_path, newline="", encoding="utf-8") as f: + reader = csv.reader(f) + for row in reader: + uid, path = row[0], row[1] + base = os.path.basename(path) + if base == doc_file or os.path.splitext(base)[0] == os.path.splitext(doc_file)[0]: + return uid + return "" + + +def process_doc(doc_file): + user_name = get_user_name(doc_file) + print(user_name, doc_file) + config = { + "user_id": user_name, + "chat_model": { + "backend": "openai", + "config": openapi_config, + }, + "mem_reader": { + "backend": "simple_struct", + "config": { + "llm": {"backend": "openai", "config": openapi_config}, + "embedder": { + "backend": "universal_api", + "config": { + "provider": "openai", + "api_key": os.getenv("OPENAI_API_KEY", "sk-xxxxx"), + "model_name_or_path": "text-embedding-3-large", + "base_url": os.getenv("OPENAI_API_BASE", "https://api.openai.com/v1"), + }, + }, + "chunker": { + "backend": "sentence", + "config": { + "tokenizer_or_token_counter": "gpt2", + "chunk_size": 512, + "chunk_overlap": 128, + "min_sentences_per_chunk": 1, + }, + }, + }, + }, + "max_turns_window": 20, + "top_k": 5, + "enable_textual_memory": True, + "enable_activation_memory": False, + "enable_parametric_memory": False, + } + mos_config = MOSConfig(**config) + mos = MOS(mos_config) + + mem_cube_config = GeneralMemCubeConfig.model_validate( + { + "user_id": user_name, + "cube_id": 
user_name,
            "text_mem": {
                "backend": "tree_text",
                "config": {
                    "extractor_llm": {"backend": "openai", "config": openapi_config},
                    "dispatcher_llm": {"backend": "openai", "config": openapi_config},
                    "graph_db": {
                        "backend": "neo4j",
                        "config": {
                            "uri": neo4j_uri,
                            "user": "neo4j",
                            "password": "iaarlichunyu",
                            "db_name": db_name,
                            "user_name": user_name,
                            "use_multi_db": False,
                            "auto_create": True,
                            "embedding_dimension": 3072,
                        },
                    },
                    "embedder": {
                        "backend": "universal_api",
                        "config": {
                            "provider": "openai",
                            "api_key": os.getenv("OPENAI_API_KEY", "sk-xxxxx"),
                            "model_name_or_path": "text-embedding-3-large",
                            "base_url": os.getenv("OPENAI_API_BASE", "https://api.openai.com/v1"),
                        },
                    },
                    "reorganize": False,
                },
            },
            "act_mem": {},
            "para_mem": {},
        }
    )
    mem_cube = GeneralMemCube(mem_cube_config)

    temp_dir = "tmp/" + doc_file
    if not os.path.exists(temp_dir) or not os.listdir(temp_dir):
        mem_cube.dump(temp_dir)

    mos.register_mem_cube(temp_dir, mem_cube_id=user_name)

    with open("evaluation/scripts/mmlongbench/eval/prompt_for_answer_extraction.md") as f:
        prompt = f.read()

    samples_res = []
    doc_samples = [s for s in samples if s.get("doc_id") == doc_file]
    if len(doc_samples) == 0:
        return []

    for sample in tqdm(doc_samples, desc=f"Processing {doc_file}"):
        if (doc_file, sample.get("question")) in completed_pairs:
            continue
        messages = sample["question"]
        try_cnt, is_success = 0, False

        while True:
            try:
                mos.clear_messages()
                response = mos.chat(messages, user_name)
                is_success = True
            except Exception as e:
                print(f"[{doc_file}] Error:", e)
                traceback.print_exc()
                try_cnt += 1
                response = "Failed"
            if is_success or try_cnt > 5:
                break

        sample["response"] = response
        extracted_res = extract_answer(sample["question"], response, prompt)
        sample["extracted_res"] = extracted_res

        pred_ans = extracted_res.split("Answer format:")[0].split("Extracted answer:")[1].strip()
        score = eval_score(sample["answer"], pred_ans, sample["answer_format"])

        sample["pred"] = pred_ans
        sample["score"] = score
        samples_res.append(sample)

        print("--------------------------------------")
        print(f"Question: {sample['question']}")
        print(f"Response: {sample['response']}")
        print(f"Ground truth: {sample['answer']}\tPred: {sample['pred']}\tScore: {sample['score']}")

    return samples_res


if __name__ == "__main__":
    results = _load_existing_results()
    total_samples = len(samples)
    processed_samples = len(completed_pairs)
    pending_samples = total_samples - processed_samples
    sample_doc_ids = [s.get("doc_id") for s in samples if s.get("doc_id")]
    all_docs_in_samples = set(sample_doc_ids)
    processed_docs = {d for d, _ in completed_pairs}
    with ThreadPoolExecutor(max_workers=4) as executor:
        pending_docs = [d for d in doc_paths if _doc_has_pending(d)]
        print("\n" + "=" * 80)
        print("📊 Evaluation progress")
        print("=" * 80)
        print(f"✅ Loaded existing results: {len(results)}")
        print(f"📂 Total samples in dataset: {total_samples}")
        print(f"🧪 Completed samples: {processed_samples}")
        print(f"⏳ Pending samples: {pending_samples}")
        print(f"📄 Total documents in dataset: {len(all_docs_in_samples)}")
        print(f"✔️ Completed documents: {len(processed_docs)}")
        print(f"➡️ Pending documents (to run now): {len(pending_docs)}")
        print("=" * 80 + "\n")
        future_to_doc = {
            executor.submit(process_doc, doc_file): doc_file for doc_file in pending_docs
        }

        for future in as_completed(future_to_doc):
            doc_file = future_to_doc[future]
            try:
                res = future.result()
                results.extend(res)

                if len(res) > 0:
                    acc, f1 = 
eval_acc_and_f1(results) + print() + print(f"Avg acc: {acc}") + print(f"Avg f1: {f1}") + + with open(RESULTS_PATH, "w", encoding="utf-8") as f: + json.dump(results, f, ensure_ascii=False, indent=2) + except Exception as e: + print(f"[{doc_file}] failed with {e}") + + acc, f1 = eval_acc_and_f1(results) + print("--------------------------------------") + print(f"Final avg acc: {acc}") + print(f"Final avg f1: {f1}") + + show_results( + results, + show_path=re.sub(r"\.json$", ".txt", "evaluation/data/mmlongbench/test_results_report.txt"), + ) diff --git a/evaluation/scripts/mmlongbench/import_docs.py b/evaluation/scripts/mmlongbench/import_docs.py new file mode 100644 index 00000000..540c8f96 --- /dev/null +++ b/evaluation/scripts/mmlongbench/import_docs.py @@ -0,0 +1,88 @@ +import asyncio +import os +import traceback +import uuid + +from memos import log +from memos.configs.mem_reader import SimpleStructMemReaderConfig +from memos.configs.memory import TreeTextMemoryConfig +from memos.mem_reader.simple_struct import SimpleStructMemReader +from memos.memories.textual.tree import TreeTextMemory + + +logger = log.get_logger(__name__) +db_name = "stx-mmlongbench-004" +# Create a memory reader instance +reader_config = SimpleStructMemReaderConfig.from_json_file( + "examples/data/config/simple_struct_reader_config.json" +) +reader = SimpleStructMemReader(reader_config) + +tree_config = TreeTextMemoryConfig.from_json_file( + "examples/data/config/tree_config_shared_database.json" +) +tree_config.graph_db.config.db_name = db_name +# Processing Documents +existing_names = { + d for d in os.listdir("ppt_test_result") if os.path.isdir(os.path.join("ppt_test_result", d)) +} +doc_paths = [] +for f in os.listdir("evaluation/data/mmlongbench/documents"): + fp = os.path.join("evaluation/data/mmlongbench/documents", f) + if os.path.isfile(fp): + name = os.path.splitext(f)[0] + if name in existing_names: + continue + doc_paths.append(fp) + +print("existing_names length:", len(existing_names)) +print("doc_paths length:", len(doc_paths)) + + +async def process_doc(doc_path): + print(f"🔄 Processing document: {doc_path}") + # Generate random user id: 'user_' + random short hex + user_id = "user_" + uuid.uuid4().hex[:8] + # Persist mapping between user_id and doc_path + try: + os.makedirs("evaluation/data/mmlongbench", exist_ok=True) + with open("evaluation/data/mmlongbench/user_doc_map.csv", "a", encoding="utf-8") as f: + f.write(f"{user_id},{doc_path}\n") + except Exception as e: + logger.error(f"Failed to write user-doc mapping: {e}") + + tree_config.graph_db.config.user_name = user_id + my_tree_textual_memory = TreeTextMemory(tree_config) + doc_memory = await reader.get_memory( + [doc_path], "doc", info={"user_id": user_id, "session_id": "session_" + str(uuid.uuid4())} + ) + + count = 0 + for m_list in doc_memory: + count += len(m_list) + my_tree_textual_memory.add(m_list) + print("total memories: ", count) + + return doc_path + + +async def main(): + batch_size = 4 + for i in range(0, len(doc_paths), batch_size): + batch = doc_paths[i : i + batch_size] + print(f"🚀 Starting batch {i // batch_size + 1} with {len(batch)} docs") + + tasks = [process_doc(p) for p in batch] + results = await asyncio.gather(*tasks, return_exceptions=True) + + for p, result in zip(batch, results, strict=False): + if isinstance(result, Exception): + print(f"❌ Error processing {p}: {result}") + tb_text = "".join(traceback.TracebackException.from_exception(result).format()) + print(tb_text) + else: + print(f"✅ Finished {result}") + + +if 
__name__ == "__main__":
+    asyncio.run(main())
diff --git a/evaluation/scripts/mmlongbench/models/__init__.py b/evaluation/scripts/mmlongbench/models/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/evaluation/scripts/mmlongbench/models/internlm_xc2_4khd.py b/evaluation/scripts/mmlongbench/models/internlm_xc2_4khd.py
new file mode 100644
index 00000000..ae62eec9
--- /dev/null
+++ b/evaluation/scripts/mmlongbench/models/internlm_xc2_4khd.py
@@ -0,0 +1,128 @@
+import torch
+import torch.nn.functional as func
+
+from transformers import AutoModel, AutoTokenizer
+
+
+torch.set_grad_enabled(False)
+
+
+try:
+    from transformers.generation.streamers import BaseStreamer
+except Exception:
+    BaseStreamer = None
+
+
+def chat(
+    model,
+    tokenizer,
+    query: str,
+    image: str | list | None = None,
+    hd_num: int = 25,
+    history: list[tuple[str, str]] | None = None,
+    streamer: BaseStreamer | None = None,
+    max_new_tokens: int = 1024,
+    temperature: float = 1.0,
+    top_p: float = 0.8,
+    repetition_penalty: float = 1.005,
+    meta_instruction: str = "You are an AI assistant whose name is InternLM-XComposer (浦语·灵笔).\n"
+    "- InternLM-XComposer (浦语·灵笔) is a multi-modality conversational language model that is developed by Shanghai AI Laboratory (上海人工智能实验室). It is designed to be helpful, honest, and harmless.\n"
+    "- InternLM-XComposer (浦语·灵笔) can understand and communicate fluently in the language chosen by the user such as English and 中文.\n"
+    "- InternLM-XComposer (浦语·灵笔) is capable of comprehending and articulating responses effectively based on the provided image.",
+    **kwargs,
+):
+    if history is None:
+        history = []
+    if image is None:
+        inputs = model.build_inputs(tokenizer, query, history, meta_instruction)
+        im_mask = torch.zeros(inputs["input_ids"].shape[:2]).cuda().bool()
+    else:
+        if isinstance(image, str):
+            # a single image path: encode it and wrap it into the prompt
+            with torch.cuda.amp.autocast():
+                image = model.encode_img(image, hd_num=hd_num)
+            inputs, im_mask = model.interleav_wrap_chat(
+                tokenizer, query, image, history, meta_instruction
+            )
+        elif isinstance(image, list):
+            # a list of image paths: encode each page, padding mismatched
+            # embedding widths, then concatenate before wrapping
+            image_list = []
+            with torch.cuda.amp.autocast():
+                for image_path in image:
+                    tmp = model.encode_img(image_path, hd_num=hd_num)
+                    image_list.append(tmp)
+                    if len(image_list) > 1 and image_list[-1].shape[1] != image_list[-2].shape[1]:
+                        image_list[-1] = func.interpolate(
+                            image_list[-1].unsqueeze(1), size=image_list[-2].shape[1:], mode="bilinear"
+                        ).squeeze(1)
+            image = torch.cat(image_list, dim=0)
+            with torch.cuda.amp.autocast():
+                inputs, im_mask = model.interleav_wrap_chat(
+                    tokenizer, query, image, history, meta_instruction
+                )
+        else:
+            raise NotImplementedError
+    inputs = {k: v.to(model.device) for k, v in inputs.items() if torch.is_tensor(v)}
+    # also add end-of-assistant token in eos token id to avoid unnecessary generation
+    eos_token_id = [
+        tokenizer.eos_token_id,
+        tokenizer.convert_tokens_to_ids(["[UNUSED_TOKEN_145]"])[0],
+    ]
+    with torch.cuda.amp.autocast():
+        outputs = model.generate(
+            **inputs,
+            streamer=streamer,
+            max_new_tokens=max_new_tokens,
+            do_sample=temperature != 0.0,
+            temperature=temperature,
+            top_p=top_p,
+            eos_token_id=eos_token_id,
+            repetition_penalty=repetition_penalty,
+            im_mask=im_mask,
+            **kwargs,
+        )
+    if image is None:
+        outputs = outputs[0].cpu().tolist()[len(inputs["input_ids"][0]) :]
+    else:
+        outputs = outputs[0].cpu().tolist()
+
+    response = tokenizer.decode(outputs, skip_special_tokens=True)
+    response = response.partition("[UNUSED_TOKEN_145]")[0]
+
+    history = [*history, (query, response)]
+    return response, history
+
+
+def init_model(cache_path):
+    model_path = (
+        cache_path
+        if (cache_path is not None and cache_path != "None")
+        else "internlm/internlm-xcomposer2-4khd-7b"
+    )
+    model = AutoModel.from_pretrained(
+        model_path,
+        torch_dtype=torch.bfloat16,
+        low_cpu_mem_usage=True,
+        trust_remote_code=True,
+        device_map="auto",
+    ).eval()
+
+    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+    model.tokenizer = tokenizer
+    return model
+
+
+def get_response_concat(model, question, image_path_list, max_new_tokens=1024, temperature=1.0):
+    # One "<ImageHere>" placeholder per image so interleav_wrap_chat can splice
+    # the encoded pages into the prompt at the right positions.
+    query = "<ImageHere> " * len(image_path_list) + question
+    try:
+        response, _ = chat(
+            model,
+            model.tokenizer,
+            query=query,
+            image=image_path_list,
+            max_new_tokens=max_new_tokens,
+            hd_num=16,
+            temperature=temperature,
+        )
+    except Exception as e:
+        print(e)
+        response = "Failed"
+    return response
diff --git a/evaluation/scripts/mmlongbench/models/internvl_chat.py b/evaluation/scripts/mmlongbench/models/internvl_chat.py
new file mode 100644
index 00000000..793f1d35
--- /dev/null
+++ b/evaluation/scripts/mmlongbench/models/internvl_chat.py
@@ -0,0 +1,137 @@
+import torch
+import torchvision.transforms as tf
+
+from PIL import Image
+from torchvision.transforms.functional import InterpolationMode
+from transformers import AutoModel, AutoTokenizer
+
+
+IMAGENET_MEAN = (0.485, 0.456, 0.406)
+IMAGENET_STD = (0.229, 0.224, 0.225)
+
+
+def build_transform(input_size, mean=IMAGENET_MEAN, std=IMAGENET_STD):
+    return tf.Compose(
+        [
+            tf.Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img),
+            tf.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
+            tf.ToTensor(),
+            tf.Normalize(mean=mean, std=std),
+        ]
+    )
+
+
+def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
+    best_ratio_diff = float("inf")
+    best_ratio = (1, 1)
+    area = width * height
+    for ratio in target_ratios:
+        target_aspect_ratio = ratio[0] / ratio[1]
+        ratio_diff = abs(aspect_ratio - target_aspect_ratio)
+        if ratio_diff < best_ratio_diff:
+            best_ratio_diff = ratio_diff
+            best_ratio = ratio
+        elif ratio_diff == best_ratio_diff:
+            if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
+                best_ratio = ratio
+    return best_ratio
+
+
+def dynamic_preprocess(image, min_num=1, max_num=6, image_size=448, use_thumbnail=False):
+    orig_width, orig_height = image.size
+    aspect_ratio = orig_width / orig_height
+
+    # calculate the existing image aspect ratio
+    target_ratios = {
+        (i, j)
+        for n in range(min_num, max_num + 1)
+        for i in range(1, n + 1)
+        for j in range(1, n + 1)
+        if min_num <= i * j <= max_num
+    }
+    target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
+
+    # find the closest aspect ratio to the target
+    target_aspect_ratio = find_closest_aspect_ratio(
+        aspect_ratio, target_ratios, orig_width, orig_height, image_size
+    )
+
+    # calculate the target width and height
+    target_width = image_size * target_aspect_ratio[0]
+    target_height = image_size * target_aspect_ratio[1]
+    blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
+
+    # resize the image
+    resized_img = image.resize((target_width, target_height))
+    processed_images = []
+    for i in range(blocks):
+        box = (
+            (i % (target_width // image_size)) * image_size,
+            (i // (target_width // image_size)) * image_size,
+            ((i % (target_width // image_size)) + 1) * image_size,
+            ((i // (target_width // image_size)) + 1) * image_size,
+        )
+        # split the image
+        split_img = resized_img.crop(box)
+        processed_images.append(split_img)
+    assert 
len(processed_images) == blocks + if use_thumbnail and len(processed_images) != 1: + thumbnail_img = image.resize((image_size, image_size)) + processed_images.append(thumbnail_img) + return processed_images + + +def load_image(image_file, input_size=448, max_num=6): + image = Image.open(image_file).convert("RGB") + transform = build_transform(input_size=input_size) + images = dynamic_preprocess(image, image_size=input_size, use_thumbnail=True, max_num=max_num) + pixel_values = [transform(image) for image in images] + pixel_values = torch.stack(pixel_values) + return pixel_values + + +def init_model(cache_path): + import os + + os.environ["CUDA_LAUNCH_BLOCKING"] = "1" + + model_path = ( + cache_path + if (cache_path is not None and cache_path != "None") + else "OpenGVLab/InternVL-Chat-V1-5" + ) + model = AutoModel.from_pretrained( + model_path, + torch_dtype=torch.bfloat16, + low_cpu_mem_usage=True, + trust_remote_code=True, + device_map="auto", + ).eval() + tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) + model.tokenizer = tokenizer + return model + + +def get_response_concat( + model, question, image_path_list, max_new_tokens=1024, temperature=1.0, max_num=6 +): + generation_config = { + "num_beams": 1, + "max_new_tokens": max_new_tokens, + "do_sample": temperature != 0.0, + "temperature": temperature, + } + pixel_values_list = [ + load_image(image_path, max_num=max_num).to(torch.bfloat16).cuda() + for image_path in image_path_list + ] + pixel_values = torch.cat(pixel_values_list, dim=0) + response, _ = model.chat( + model.tokenizer, + pixel_values, + question, + generation_config, + history=None, + return_history=True, + ) + return response diff --git a/evaluation/scripts/mmlongbench/models/minicpm_llama3.py b/evaluation/scripts/mmlongbench/models/minicpm_llama3.py new file mode 100644 index 00000000..7f6d4b74 --- /dev/null +++ b/evaluation/scripts/mmlongbench/models/minicpm_llama3.py @@ -0,0 +1,56 @@ +import torch + +from PIL import Image +from transformers import AutoModel, AutoTokenizer + + +def init_model(cache_path): + model_path = ( + cache_path + if (cache_path is not None and cache_path != "None") + else "openbmb/MiniCPM-Llama3-V-2_5" + ) + model = AutoModel.from_pretrained( + model_path, + torch_dtype=torch.bfloat16, + low_cpu_mem_usage=True, + trust_remote_code=True, + device_map="auto", + ).eval() + + tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) + model.tokenizer = tokenizer + return model + + +def get_response_concat(model, question, image_path_list, max_new_tokens=1024, temperature=1.0): + msgs = [] + system_prompt = "Answer in detail." 
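+    # Build an interleaved message list: the optional text prompt first, then
+    # one entry per image, then the question; it is flattened into MiniCPM's
+    # chat format below.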
+ if system_prompt: + msgs.append({"type": "text", "value": system_prompt}) + if isinstance(image_path_list, list): + msgs.extend([{"type": "image", "value": p} for p in image_path_list]) + else: + msgs = [{"type": "image", "value": image_path_list}] + msgs.append({"type": "text", "value": question}) + + content = [] + for x in msgs: + if x["type"] == "text": + content.append(x["value"]) + elif x["type"] == "image": + image = Image.open(x["value"]).convert("RGB") + content.append(image) + msgs = [{"role": "user", "content": content}] + + with torch.cuda.amp.autocast(): + res = model.chat( + msgs=msgs, + context=None, + image=None, + max_new_tokens=max_new_tokens, + temperature=temperature, + do_sample=temperature != 0.0, + tokenizer=model.tokenizer, + ) + return res diff --git a/evaluation/scripts/mmlongbench/multimodal_test.py b/evaluation/scripts/mmlongbench/multimodal_test.py new file mode 100644 index 00000000..92921522 --- /dev/null +++ b/evaluation/scripts/mmlongbench/multimodal_test.py @@ -0,0 +1,185 @@ +import os +import shutil + +from dotenv import load_dotenv + +from memos.configs.mem_cube import GeneralMemCubeConfig +from memos.configs.mem_os import MOSConfig +from memos.mem_cube.general import GeneralMemCube +from memos.mem_os.main import MOS + + +load_dotenv() + +db_name = "stx-mmlongbench-002" +user_id = "user_dc812220" + +# 1.1 Set openai config +openapi_config = { + "model_name_or_path": "gpt-4o", + "top_k": 50, + "remove_think_prefix": True, + "api_key": os.getenv("OPENAI_API_KEY", "sk-xxxxx"), + "api_base": os.getenv("OPENAI_API_BASE", "https://api.openai.com/v1"), +} +# 1.2 Set neo4j config +neo4j_uri = os.getenv("NEO4J_URI", "bolt://localhost:7687") + +# 1.3 Create MOS Config +config = { + "user_id": user_id, + "chat_model": { + "backend": "openai", + "config": openapi_config, + }, + "mem_reader": { + "backend": "simple_struct", + "config": { + "llm": { + "backend": "openai", + "config": openapi_config, + }, + "embedder": { + "backend": "universal_api", + "config": { + "provider": "openai", + "api_key": os.getenv("OPENAI_API_KEY", "sk-xxxxx"), + "model_name_or_path": "text-embedding-3-large", + "base_url": os.getenv("OPENAI_API_BASE", "https://api.openai.com/v1"), + }, + }, + "chunker": { + "backend": "sentence", + "config": { + "tokenizer_or_token_counter": "gpt2", + "chunk_size": 512, + "chunk_overlap": 128, + "min_sentences_per_chunk": 1, + }, + }, + }, + }, + "max_turns_window": 20, + "top_k": 5, + "enable_textual_memory": True, + "enable_activation_memory": False, + "enable_parametric_memory": False, +} + +mos_config = MOSConfig(**config) +mos = MOS(mos_config) + +config = GeneralMemCubeConfig.model_validate( + { + "user_id": user_id, + "cube_id": f"{user_id}", + "text_mem": { + "backend": "tree_text", + "config": { + "extractor_llm": { + "backend": "openai", + "config": openapi_config, + }, + "dispatcher_llm": { + "backend": "openai", + "config": openapi_config, + }, + "graph_db": { + "backend": "neo4j", + "config": { + "uri": neo4j_uri, + "user": "neo4j", + "password": "iaarlichunyu", + "db_name": db_name, + "user_name": user_id, + "use_multi_db": False, + "auto_create": True, + "embedding_dimension": 3072, + }, + }, + "embedder": { + "backend": "universal_api", + "config": { + "provider": "openai", + "api_key": os.getenv("OPENAI_API_KEY", "sk-xxxxx"), + "model_name_or_path": "text-embedding-3-large", + "base_url": os.getenv("OPENAI_API_BASE", "https://api.openai.com/v1"), + }, + }, + "reorganize": False, + }, + }, + "act_mem": {}, + "para_mem": {}, + }, +) + + 
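+# A hypothetical usage sketch for the helper below (the exact return shape of
+# mos.get_all() is assumed here, not confirmed):
+#
+#   raw = mos.get_all()
+#   printable = filter_memory_data(raw)
+#   for group in printable.get("text_mem", []):
+#       print(group["cube_id"], len(group["memories"]))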
+# Filter out embedding fields, keeping only necessary fields +def filter_memory_data(memories_data): + filtered_data = {} + for key, value in memories_data.items(): + if key == "text_mem": + filtered_data[key] = [] + for mem_group in value: + # Check if it's the new data structure (list of TextualMemoryItem objects) + if "memories" in mem_group and isinstance(mem_group["memories"], list): + # New data structure: directly a list of TextualMemoryItem objects + filtered_memories = [] + for memory_item in mem_group["memories"]: + # Create filtered dictionary + filtered_item = { + "id": memory_item.id, + "memory": memory_item.memory, + "metadata": {}, + } + # Filter metadata, excluding embedding + if hasattr(memory_item, "metadata") and memory_item.metadata: + for attr_name in dir(memory_item.metadata): + if not attr_name.startswith("_") and attr_name != "embedding": + attr_value = getattr(memory_item.metadata, attr_name) + if not callable(attr_value): + filtered_item["metadata"][attr_name] = attr_value + filtered_memories.append(filtered_item) + + filtered_group = { + "cube_id": mem_group.get("cube_id", ""), + "memories": filtered_memories, + } + filtered_data[key].append(filtered_group) + else: + # Old data structure: dictionary with nodes and edges + filtered_group = { + "memories": {"nodes": [], "edges": mem_group["memories"].get("edges", [])} + } + for node in mem_group["memories"].get("nodes", []): + filtered_node = { + "id": node.get("id"), + "memory": node.get("memory"), + "metadata": { + k: v + for k, v in node.get("metadata", {}).items() + if k != "embedding" + }, + } + filtered_group["memories"]["nodes"].append(filtered_node) + filtered_data[key].append(filtered_group) + else: + filtered_data[key] = value + return filtered_data + + +mem_cube = GeneralMemCube(config) + +temp_dir = f"/tmp/{user_id}" +if os.path.exists(temp_dir): + shutil.rmtree(temp_dir) +mem_cube.dump(temp_dir) +mos.register_mem_cube(temp_dir, mem_cube_id=user_id) + + +print("start answering...") +user_query = "图8美股变化的影响是什么" +print(f"👤 User query: {user_query}") +response = mos.chat(user_query) +print(f"🤖 Response: {response}") diff --git a/evaluation/scripts/utils/mem0_local.py b/evaluation/scripts/utils/mem0_local.py new file mode 100644 index 00000000..62b9d905 --- /dev/null +++ b/evaluation/scripts/utils/mem0_local.py @@ -0,0 +1,191 @@ +from typing import Any + +import requests + + +class Mem0Client: + def __init__(self, base_url: str = "http://localhost:8000"): + self.base_url = base_url + + def add( + self, + messages: list[dict], + timestamp: str | None = None, + user_id: str | None = None, + agent_id: str | None = None, + run_id: str | None = None, + metadata: dict[str, Any] | None = None, + ): + """Create memories.""" + url = f"{self.base_url}/memories" + + if metadata is None: + metadata = {} + + if user_id is None and agent_id is None and run_id is None: + raise ValueError("At least one of user_id, agent_id, or run_id must be provided.") + + if user_id: + metadata["user_id"] = user_id + if agent_id: + metadata["agent_id"] = agent_id + if run_id: + metadata["run_id"] = run_id + + metadata["timestamp"] = timestamp + + data = { + "messages": messages, + "user_id": user_id, + "agent_id": agent_id, + "run_id": run_id, + "metadata": metadata, + } + + response = requests.post(url, json=data) + response.raise_for_status() + return response.json() + + def search( + self, + query: str, + user_id: str | None = None, + agent_id: str | None = None, + run_id: str | None = None, + filters: dict[str, Any] | None = None, + 
top_k: int = 10, + ): + """Search memories.""" + url = f"{self.base_url}/search" + + if filters is None: + filters = {} + + data = { + "query": query, + "user_id": user_id, + "agent_id": agent_id, + "run_id": run_id, + "filters": filters, + } + + response = requests.post(url, json=data) + response.raise_for_status() + + results = response.json().get("results", []) + top_k_results = results[:top_k] if len(results) > top_k else results + + relations = response.json().get("relations", []) + top_k_relations = relations[:top_k] if len(relations) > top_k else relations + + return {"results": top_k_results, "relations": top_k_relations} + + def get_all( + self, user_id: str | None = None, agent_id: str | None = None, run_id: str | None = None + ): + """Retrieve all memories.""" + url = f"{self.base_url}/memories" + + params = {} + if user_id: + params["user_id"] = user_id + if agent_id: + params["agent_id"] = agent_id + if run_id: + params["run_id"] = run_id + + response = requests.get(url, params=params) + response.raise_for_status() + return response.json() + + def get(self, memory_id: str): + """Retrieve a specific memory by ID.""" + url = f"{self.base_url}/memories/{memory_id}" + + response = requests.get(url) + response.raise_for_status() + return response.json() + + def delete(self, memory_id: str): + """Delete a specific memory by ID.""" + url = f"{self.base_url}/memories/{memory_id}" + + response = requests.delete(url) + response.raise_for_status() + return response.json() + + def delete_all( + self, user_id: str | None = None, agent_id: str | None = None, run_id: str | None = None + ): + """Delete all memories for a user, agent, or run.""" + url = f"{self.base_url}/memories" + + params = {} + if user_id: + params["user_id"] = user_id + if agent_id: + params["agent_id"] = agent_id + if run_id: + params["run_id"] = run_id + + response = requests.delete(url, params=params) + response.raise_for_status() + return response.json() + + def reset(self): + """Reset the memory store.""" + url = f"{self.base_url}/reset" + + response = requests.post(url) + response.raise_for_status() + return response.json() + + +if __name__ == "__main__": + client = Mem0Client(base_url="http://localhost:9999") + + # Example usage + print("Adding memories...") + add_result_a = client.add( + messages=[{"role": "user", "content": "I like drinking coffee in the morning"}], + user_id="alice", + ) + print(add_result_a) + + add_result_b = client.add( + messages=[{"role": "user", "content": "I enjoy reading books in the evening"}], + user_id="alice", + ) + print(add_result_b) + + print("\nSearching memories...") + search_result = client.search( + query="When did Melanie paint a sunrise?", user_id="alice", top_k=10 + ) + print(search_result) + print(len(search_result.get("results", []))) + + print("\nRetrieving all memories...") + all_memories = client.get_all(user_id="alice") + print(all_memories) + print(len(all_memories.get("results", []))) + + print("\nRetrieving a specific memory...") + if all_memories and "results" in all_memories and len(all_memories["results"]) > 0: + memory_id = all_memories["results"][0]["id"] + specific_memory = client.get(memory_id) + print(specific_memory) + + print("\nDeleting a specific memory...") + if all_memories and "results" in all_memories and len(all_memories["results"]) > 0: + memory_id = all_memories["results"][0]["id"] + delete_result = client.delete(memory_id) + print(delete_result) + + print("\nDeleting all memories for user 'alice'...") + delete_all_result = 
client.delete_all(user_id="alice") + print(delete_all_result) + + print("\nResetting the memory store...") + reset_result = client.reset() + print(reset_result) diff --git a/evaluation/scripts/utils/memobase_utils.py b/evaluation/scripts/utils/memobase_utils.py new file mode 100644 index 00000000..dcf06ea3 --- /dev/null +++ b/evaluation/scripts/utils/memobase_utils.py @@ -0,0 +1,46 @@ +import time +import uuid + +from memobase import ChatBlob + + +def string_to_uuid(s: str, salt="memobase_client") -> str: + return str(uuid.uuid5(uuid.NAMESPACE_DNS, s + salt)) + + +def memobase_add_memory(user, message, retries=3): + for attempt in range(retries): + try: + _ = user.insert(ChatBlob(messages=message), sync=True) + return + except Exception as e: + if attempt < retries - 1: + time.sleep(1) + continue + else: + raise e + + +def memobase_search_memory( + client, user_id, query, max_memory_context_size, max_retries=3, retry_delay=1 +): + retries = 0 + real_uid = string_to_uuid(user_id) + u = client.get_user(real_uid, no_get=True) + + while retries < max_retries: + try: + memories = u.context( + max_token_size=max_memory_context_size, + chats=[{"role": "user", "content": query}], + event_similarity_threshold=0.2, + fill_window_with_events=True, + ) + return memories + except Exception as e: + print(f"Error during memory search: {e}") + print("Retrying...") + retries += 1 + if retries >= max_retries: + raise e + time.sleep(retry_delay) diff --git a/evaluation/scripts/utils/memos_api.py b/evaluation/scripts/utils/memos_api.py new file mode 100644 index 00000000..7b7f2a06 --- /dev/null +++ b/evaluation/scripts/utils/memos_api.py @@ -0,0 +1,63 @@ +import json + +import requests + + +class MemOSAPI: + def __init__(self, base_url: str = "http://localhost:8000"): + self.base_url = base_url + self.headers = {"Content-Type": "application/json"} + + def user_register(self, user_id: str): + """Register a user.""" + url = f"{self.base_url}/users/register" + payload = json.dumps({"user_id": user_id}) + response = requests.request("POST", url, data=payload, headers=self.headers) + return response.text + + def add(self, messages: list[dict], user_id: str | None = None): + """Create memories.""" + register_res = json.loads(self.user_register(user_id)) + cube_id = register_res["data"]["mem_cube_id"] + url = f"{self.base_url}/add" + payload = json.dumps({"messages": messages, "user_id": user_id, "mem_cube_id": cube_id}) + + response = requests.request("POST", url, data=payload, headers=self.headers) + return response.text + + def search(self, query: str, user_id: str | None = None, top_k: int = 10): + """Search memories.""" + url = f"{self.base_url}/search" + payload = json.dumps( + { + "query": query, + "user_id": user_id, + } + ) + + response = requests.request("POST", url, data=payload, headers=self.headers) + if response.status_code != 200: + response.raise_for_status() + else: + result = json.loads(response.text)["data"]["text_mem"][0]["memories"] + text_memories = [item["memory"] for item in result][:top_k] + return text_memories + + +if __name__ == "__main__": + client = MemOSAPI(base_url="http://localhost:8000") + # Example usage + try: + messages = [ + { + "role": "user", + "content": "I went to the store and bought a red apple.", + "chat_time": "2023-10-01T12:00:00Z", + } + ] + add_response = client.add(messages, user_id="user789") + print("Add memory response:", add_response) + search_response = client.search("red apple", user_id="user789", top_k=1) + print("Search memory response:", search_response) + 
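+        # Network failures and raise_for_status() in search() surface as
+        # requests.RequestException subclasses, handled below.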
except requests.RequestException as e: + print("An error occurred:", e) diff --git a/evaluation/scripts/utils/memos_filters.py b/evaluation/scripts/utils/memos_filters.py new file mode 100644 index 00000000..815f3143 --- /dev/null +++ b/evaluation/scripts/utils/memos_filters.py @@ -0,0 +1,51 @@ +def filter_memory_data(memories_data): + filtered_data = {} + for key, value in memories_data.items(): + if key == "text_mem": + filtered_data[key] = [] + for mem_group in value: + # Check if it's the new data structure (list of TextualMemoryItem objects) + if "memories" in mem_group and isinstance(mem_group["memories"], list): + # New data structure: directly a list of TextualMemoryItem objects + filtered_memories = [] + for memory_item in mem_group["memories"]: + # Create filtered dictionary + filtered_item = { + "id": memory_item.id, + "memory": memory_item.memory, + "metadata": {}, + } + # Filter metadata, excluding embedding + if hasattr(memory_item, "metadata") and memory_item.metadata: + for attr_name in dir(memory_item.metadata): + if not attr_name.startswith("_") and attr_name != "embedding": + attr_value = getattr(memory_item.metadata, attr_name) + if not callable(attr_value): + filtered_item["metadata"][attr_name] = attr_value + filtered_memories.append(filtered_item) + + filtered_group = { + "cube_id": mem_group.get("cube_id", ""), + "memories": filtered_memories, + } + filtered_data[key].append(filtered_group) + else: + # Old data structure: dictionary with nodes and edges + filtered_group = { + "memories": {"nodes": [], "edges": mem_group["memories"].get("edges", [])} + } + for node in mem_group["memories"].get("nodes", []): + filtered_node = { + "id": node.get("id"), + "memory": node.get("memory"), + "metadata": { + k: v + for k, v in node.get("metadata", {}).items() + if k != "embedding" + }, + } + filtered_group["memories"]["nodes"].append(filtered_node) + filtered_data[key].append(filtered_group) + else: + filtered_data[key] = value + return filtered_data diff --git a/evaluation/scripts/xinyu/eval/__init__.py b/evaluation/scripts/xinyu/eval/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/evaluation/scripts/xinyu/eval/eval_score_llm.py b/evaluation/scripts/xinyu/eval/eval_score_llm.py new file mode 100644 index 00000000..f5764ce3 --- /dev/null +++ b/evaluation/scripts/xinyu/eval/eval_score_llm.py @@ -0,0 +1,279 @@ +import os +import re +import traceback + +from collections import defaultdict +from math import isclose + +from memos.configs.mem_os import MOSConfig +from memos.llms.factory import LLMFactory + + +openapi_config = { + "model_name_or_path": "gpt-5-nano", + "top_k": 50, + "remove_think_prefix": True, + "api_key": os.getenv("OPENAI_API_KEY", "sk-xxxxx"), + "api_base": os.getenv("OPENAI_API_BASE", "https://api.openai.com/v1"), +} +config = { + "user_id": "user_name", + "chat_model": { + "backend": "openai", + "config": openapi_config, + }, + "mem_reader": { + "backend": "simple_struct", + "config": { + "llm": {"backend": "openai", "config": openapi_config}, + "embedder": { + "backend": "universal_api", + "config": { + "provider": "openai", + "api_key": os.getenv("OPENAI_API_KEY", "sk-xxxxx"), + "model_name_or_path": "text-embedding-3-large", + "base_url": os.getenv("OPENAI_API_BASE", "https://api.openai.com/v1"), + }, + }, + "chunker": { + "backend": "sentence", + "config": { + "tokenizer_or_token_counter": "gpt2", + "chunk_size": 512, + "chunk_overlap": 128, + "min_sentences_per_chunk": 1, + }, + }, + }, + }, + "max_turns_window": 20, + "top_k": 5, + 
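    # NOTE: only the chat_model entry of this config is actually consumed in
+    # this module (via LLMFactory below); the memory switches that follow are
+    # carried along solely to satisfy MOSConfig validation, since no MOS
+    # instance is created here.
+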
"enable_textual_memory": True, + "enable_activation_memory": False, + "enable_parametric_memory": False, +} +mos_config = MOSConfig(**config) +chat_llm = LLMFactory.from_config(mos_config.chat_model) + + +def is_float_equal( + reference, prediction, include_percentage: bool = False, is_close: float = False +) -> bool: + def get_precision(gt_ans: float) -> int: + precision = 3 + if "." in str(gt_ans): + precision = len(str(gt_ans).split(".")[-1]) + return precision + + reference = float(str(reference).strip().rstrip("%").strip()) + try: + prediction = float(str(prediction).strip().rstrip("%").strip()) + except Exception: + return False + + gt_result = [reference / 100, reference, reference * 100] if include_percentage else [reference] + for item in gt_result: + try: + if is_close and isclose(item, prediction, rel_tol=0.01): + return True + precision = max(min(get_precision(prediction), get_precision(item)), 2) + if round(prediction, precision) == round(item, precision): + return True + except Exception: + continue + return False + + +def get_clean_string(s): + s = str(s).lower().strip() + + for suffix in ["mile", "miles", "million"]: + if s.endswith(suffix): + s = s[: -len(suffix)].strip() + + s = re.sub(r"\s*\([^)]*\)", "", s).strip() + s = re.sub(r"^['\"]|['\"]$", "", s).strip() + s = s.lstrip("$").rstrip("%").strip() + + return s + + +def is_exact_match(s): + flag = False + # Website + if "https://" in s: + flag = True + # code file + if s.endswith((".py", ".ipynb")) or s.startswith("page"): + flag = True + # telephone number + if re.fullmatch(r"\b\d+(-\d+|\s\d+)?\b", s): + flag = True + # time + if "a.m." in s or "p.m." in s: + flag = True + # YYYY-MM-DD + if re.fullmatch(r"\b\d{4}[-\s]\d{2}[-\s]\d{2}\b", s): + flag = True + # YYYY-MM + if re.fullmatch(r"\b\d{4}[-\s]\d{2}\b", s): + flag = True + # Email address + if re.fullmatch(r"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}", s): + flag = True + return flag + + +def isfloat(num): + try: + float(num) + return True + except ValueError: + return False + + +def eval_score(question, gt, pred): + prompt = """ + 你是一个评委,根据问题和标准答案对学生的答案进行打分。打分规则如下: + + 完全不对(0分): + 学生答案与问题无关,未展示出任何相关概念或知识。 + 对了一部分(0.5分): + 学生答案提供了一些相关信息,但未能直接回答问题。 + 答案中包含部分正确内容,但缺乏关键信息,导致整体理解不清。 + 基本正确(0.7分): + 学生答案提供了大部分关键信息,不过依然距离标准答案有一定缺失。 + 答案中包含部分关键内容,但缺乏部分信息,导致不够完整。 + 完全正确(1分): + 学生答案准确地回答了问题,涵盖所有关键信息。 + 表达清晰,逻辑合理,直接且有效地回应了问题。 + + 问题:{} + + 标准答案:{} + + 学生答案:{} + """ + + max_try = 20 + try_i = 0 + while try_i < max_try: + try: + llm_input_prompt_score = ( + prompt.format(question, gt, pred) + + """请返回给我一个json: + { + "分数": 1, + "理由": "xxxx" + }""" + ) + score = chat_llm.generate( + [ + {"role": "user", "content": llm_input_prompt_score}, + ] + ) + + print(f"score: {score}") + score_real = eval(score.replace("json", "").replace("\n", "").replace("```", "")) + return float(score_real["分数"]) + except Exception: + traceback.print_exc() + print(f"trying num {try_i}") + try_i += 1 + return -1 + + +def eval_acc_and_f1(samples): + evaluated_samples = [sample for sample in samples if "score" in sample] + if not evaluated_samples: + return 0.0, 0.0 + + acc = sum([sample["score"] for sample in evaluated_samples]) / len(evaluated_samples) + try: + recall = sum( + [ + sample["score"] + for sample in evaluated_samples + if sample["answer"] != "Not answerable" + ] + ) / len([sample for sample in evaluated_samples if sample["answer"] != "Not answerable"]) + precision = sum( + [ + sample["score"] + for sample in evaluated_samples + if sample["answer"] != "Not answerable" + ] + ) / 
len([sample for sample in evaluated_samples if sample["pred"] != "Not answerable"]) + f1 = 2 * recall * precision / (recall + precision) if (recall + precision) > 0.0 else 0.0 + except Exception: + f1 = 0.0 + + return acc, f1 + + +def show_results(samples, show_path=None): + for sample in samples: + sample["evidence_pages"] = eval(sample["evidence_pages"]) + sample["evidence_sources"] = eval(sample["evidence_sources"]) + + with open(show_path, "w") as f: + acc, f1 = eval_acc_and_f1(samples) + f.write(f"Overall Acc: {acc} | Question Number: {len(samples)}\n") + f.write(f"Overall F1-score: {f1} | Question Number: {len(samples)}\n") + f.write("-----------------------\n") + + acc_single_page, _ = eval_acc_and_f1( + [sample for sample in samples if len(sample["evidence_pages"]) == 1] + ) + acc_multi_page, _ = eval_acc_and_f1( + [ + sample + for sample in samples + if len(sample["evidence_pages"]) != 1 and sample["answer"] != "Not answerable" + ] + ) + acc_neg, _ = eval_acc_and_f1( + [sample for sample in samples if sample["answer"] == "Not answerable"] + ) + + f.write( + "Single-page | Accuracy: {} | Question Number: {}\n".format( + acc_single_page, + len([sample for sample in samples if len(sample["evidence_pages"]) == 1]), + ) + ) + f.write( + "Cross-page | Accuracy: {} | Question Number: {}\n".format( + acc_multi_page, + len( + [ + sample + for sample in samples + if len(sample["evidence_pages"]) != 1 + and sample["answer"] != "Not answerable" + ] + ), + ) + ) + f.write( + "Unanswerable | Accuracy: {} | Question Number: {}\n".format( + acc_neg, len([sample for sample in samples if sample["answer"] == "Not answerable"]) + ) + ) + f.write("-----------------------\n") + + source_sample_dict, document_type_dict = defaultdict(list), defaultdict(list) + for sample in samples: + for answer_source in sample["evidence_sources"]: + source_sample_dict[answer_source].append(sample) + document_type_dict[sample["doc_type"]].append(sample) + for type, sub_samples in source_sample_dict.items(): + f.write( + f"Evidence Sources: {type} | Accuracy: {eval_acc_and_f1(sub_samples)[0]} | Question Number: {len(sub_samples)}\n" + ) + + f.write("-----------------------\n") + for type, sub_samples in document_type_dict.items(): + f.write( + f"Document Type: {type} | Accuracy: {eval_acc_and_f1(sub_samples)[0]} | Question Number: {len(sub_samples)}\n" + ) diff --git a/evaluation/scripts/xinyu/eval/extract_answer.py b/evaluation/scripts/xinyu/eval/extract_answer.py new file mode 100644 index 00000000..b7f7e686 --- /dev/null +++ b/evaluation/scripts/xinyu/eval/extract_answer.py @@ -0,0 +1,33 @@ +import os + +import openai + +from dotenv import load_dotenv + + +load_dotenv() +client = openai.Client( + api_key=os.getenv("OPENAI_API_KEY", "sk-xxxxx"), + base_url=os.getenv("OPENAI_API_BASE", "https://api.openai.com/v1"), +) + + +def extract_answer(question, output, prompt, model_name="gpt-4o"): + response = client.chat.completions.create( + model=model_name, + messages=[ + { + "role": "user", + "content": prompt, + }, + {"role": "assistant", "content": f"\n\nQuestion:{question}\nAnalysis:{output}\n"}, + ], + temperature=0.0, + max_tokens=256, + top_p=1, + frequency_penalty=0, + presence_penalty=0, + ) + response = response.choices[0].message.content + + return response diff --git a/evaluation/scripts/xinyu/eval_docs.py b/evaluation/scripts/xinyu/eval_docs.py new file mode 100644 index 00000000..807e0906 --- /dev/null +++ b/evaluation/scripts/xinyu/eval_docs.py @@ -0,0 +1,228 @@ +import csv +import json +import os +import re 
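+# eval_docs.py answers each document's questions against a per-document
+# MemCube (built in process_doc below) and grades the responses with the
+# LLM rubric in eval_score plus the answer extractor in extract_answer.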
+import traceback + +from concurrent.futures import ThreadPoolExecutor, as_completed + +from dotenv import load_dotenv +from eval.eval_score_llm import eval_acc_and_f1, eval_score, show_results +from eval.extract_answer import extract_answer + +from memos.configs.mem_cube import GeneralMemCubeConfig +from memos.configs.mem_os import MOSConfig +from memos.mem_cube.general import GeneralMemCube +from memos.mem_os.main import MOS + + +load_dotenv() +openapi_config = { + "model_name_or_path": "gpt-4o", + "top_k": 50, + "remove_think_prefix": True, + "api_key": os.getenv("OPENAI_API_KEY", "sk-xxxxx"), + "api_base": os.getenv("OPENAI_API_BASE", "https://api.openai.com/v1"), +} +neo4j_uri = os.getenv("NEO4J_URI", "bolt://47.117.41.207:7687") +db_name = "stx-mmlongbench-003" +doc_paths = [ + f + for f in os.listdir("evaluation/data/xinyu/documents") + if os.path.isfile(os.path.join("evaluation/data/xinyu/documents", f)) +] + +with open("evaluation/data/xinyu/all_samples_with_gt.json") as f: + samples = json.load(f) + + +def get_user_name(doc_file): + csv_path = "evaluation/data/xinyu/user_doc_map.csv" + if os.path.exists(csv_path): + with open(csv_path, newline="", encoding="utf-8") as f: + reader = csv.reader(f) + for row in reader: + uid, path = row[0], row[1] + base = os.path.basename(path) + if base == doc_file or os.path.splitext(base)[0] == os.path.splitext(doc_file)[0]: + return uid + return "" + + +def process_doc(doc_file): + user_name = get_user_name(doc_file) + print(user_name, doc_file) + config = { + "user_id": user_name, + "chat_model": { + "backend": "openai", + "config": openapi_config, + }, + "mem_reader": { + "backend": "simple_struct", + "config": { + "llm": {"backend": "openai", "config": openapi_config}, + "embedder": { + "backend": "universal_api", + "config": { + "provider": "openai", + "api_key": os.getenv("OPENAI_API_KEY", "sk-xxxxx"), + "model_name_or_path": "text-embedding-3-large", + "base_url": os.getenv("OPENAI_API_BASE", "https://api.openai.com/v1"), + }, + }, + "chunker": { + "backend": "sentence", + "config": { + "tokenizer_or_token_counter": "gpt2", + "chunk_size": 512, + "chunk_overlap": 128, + "min_sentences_per_chunk": 1, + }, + }, + }, + }, + "max_turns_window": 20, + "top_k": 5, + "enable_textual_memory": True, + "enable_activation_memory": False, + "enable_parametric_memory": False, + } + mos_config = MOSConfig(**config) + mos = MOS(mos_config) + + mem_cube_config = GeneralMemCubeConfig.model_validate( + { + "user_id": user_name, + "cube_id": user_name, + "text_mem": { + "backend": "tree_text", + "config": { + "extractor_llm": {"backend": "openai", "config": openapi_config}, + "dispatcher_llm": {"backend": "openai", "config": openapi_config}, + "graph_db": { + "backend": "neo4j", + "config": { + "uri": neo4j_uri, + "user": "neo4j", + "password": "iaarlichunyu", + "db_name": db_name, + "user_name": user_name, + "use_multi_db": False, + "auto_create": True, + "embedding_dimension": 3072, + }, + }, + "embedder": { + "backend": "universal_api", + "config": { + "provider": "openai", + "api_key": os.getenv("OPENAI_API_KEY", "sk-xxxxx"), + "model_name_or_path": "text-embedding-3-large", + "base_url": os.getenv("OPENAI_API_BASE", "https://api.openai.com/v1"), + }, + }, + "reorganize": False, + }, + }, + "act_mem": {}, + "para_mem": {}, + } + ) + mem_cube = GeneralMemCube(mem_cube_config) + + temp_dir = os.path.join("tmp", doc_file) + + if (not os.path.exists(temp_dir)) or (not os.listdir(temp_dir)): + mem_cube.dump(temp_dir) + + mos.register_mem_cube(temp_dir, 
mem_cube_id=user_name)
+
+    with open("evaluation/scripts/mmlongbench/eval/prompt_for_answer_extraction.md") as f:
+        prompt = f.read()
+
+    samples_res = []
+    doc_samples = [s for s in samples if s.get("doc_id") == doc_file]
+
+    if len(doc_samples) == 0:
+        return []
+
+    sample = doc_samples[0]
+    question_list = sample["question"]
+    answer_list = sample["answer"]
+
+    for idx, question in enumerate(question_list):
+        gt = answer_list.get(str(idx))
+
+        # Retry the chat up to six times; keep "Failed" if it never succeeds.
+        response, try_cnt = "Failed", 0
+        while try_cnt <= 5:
+            try:
+                mos.clear_messages()
+                response = mos.chat(question, user_name)
+                break
+            except Exception as e:
+                print(f"[{doc_file}] Error:", e)
+                traceback.print_exc()
+                try_cnt += 1
+
+        sample_item = dict(sample)
+        sample_item["question"] = question
+        sample_item["answer"] = gt
+        sample_item["response"] = response
+
+        extracted_res = extract_answer(sample_item["question"], response, prompt)
+        sample_item["extracted_res"] = extracted_res
+
+        print("--------------------------------------")
+        pred_ans = extracted_res.split("Answer format:")[0].split("Extracted answer:")[1].strip()
+        score = eval_score(question, gt, response)
+
+        sample_item["pred"] = pred_ans
+        sample_item["score"] = score
+        samples_res.append(sample_item)
+
+        print(f"Question: {question}")
+        print(f"Response: {sample_item['response']}")
+        print(f"Ground truth: {gt}\tPred: {sample_item['pred']}\tScore: {sample_item['score']}")
+
+    print("samples_res length: ", len(samples_res))
+    return samples_res
+
+
+if __name__ == "__main__":
+    results = []
+
+    with ThreadPoolExecutor(max_workers=4) as executor:
+        future_to_doc = {executor.submit(process_doc, doc_file): doc_file for doc_file in doc_paths}
+
+        for future in as_completed(future_to_doc):
+            doc_file = future_to_doc[future]
+            try:
+                res = future.result()
+                results.extend(res)
+
+                if len(res) > 0:
+                    acc, f1 = eval_acc_and_f1(results)
+                    print()
+                    print(f"Avg acc: {acc}")
+                    print(f"Avg f1: {f1}")
+
+                    with open("evaluation/data/xinyu/test_results.json", "w", encoding="utf-8") as f:
+                        json.dump(results, f, ensure_ascii=False, indent=2)
+
+            except Exception as e:
+                print(f"[{doc_file}] failed with {e}")
+                traceback.print_exc()
+
+    acc, f1 = eval_acc_and_f1(results)
+    print("--------------------------------------")
+    print(f"Final avg acc: {acc}")
+    print(f"Final avg f1: {f1}")
+
+    show_results(
+        results,
+        show_path="evaluation/data/xinyu/test_results_report.txt",
+    )
diff --git a/evaluation/scripts/xinyu/import_docs.py b/evaluation/scripts/xinyu/import_docs.py
new file mode 100644
index 00000000..6fe2f4e3
--- /dev/null
+++ b/evaluation/scripts/xinyu/import_docs.py
@@ -0,0 +1,87 @@
+import asyncio
+import os
+import traceback
+import uuid
+
+from memos import log
+from memos.configs.mem_reader import SimpleStructMemReaderConfig
+from memos.configs.memory import TreeTextMemoryConfig
+from memos.mem_reader.simple_struct import SimpleStructMemReader
+from memos.memories.textual.tree import TreeTextMemory
+
+
+logger = log.get_logger(__name__)
+db_name = "stx-mmlongbench-003"
+# Create a memory reader instance
+reader_config = SimpleStructMemReaderConfig.from_json_file(
+    "examples/data/config/simple_struct_reader_config.json"
+)
+reader = SimpleStructMemReader(reader_config)
+
+tree_config = TreeTextMemoryConfig.from_json_file(
+    "examples/data/config/tree_config_shared_database.json"
+)
+tree_config.graph_db.config.db_name = db_name
+# Processing documents: skip any whose name already has a result directory
+existing_names = {
+    d for d in
os.listdir("ppt_test_result") if os.path.isdir(os.path.join("ppt_test_result", d)) +} +doc_paths = [] +for f in os.listdir("evaluation/data/xinyu/documents"): + fp = os.path.join("evaluation/data/xinyu/documents", f) + if os.path.isfile(fp): + name = os.path.splitext(f)[0] + if name in existing_names: + continue + doc_paths.append(fp) +print(f"existing_names length: {len(existing_names)}") +print(f"doc_paths length: {len(doc_paths)}") + + +async def process_doc(doc_path): + print(f"🔄 Processing document: {doc_path}") + doc_file = doc_path.split("/")[-1].rsplit(".", 1)[0] + + # Generate random user id: 'user_' + random short hex + user_id = "user_" + uuid.uuid4().hex[:8] + # Persist mapping between user_id and doc_path + with open("evaluation/data/xinyu/user_doc_map.csv", "a", encoding="utf-8") as f: + f.write(f"{user_id},{doc_path}\n") + + tree_config.graph_db.config.user_name = user_id + temp_dir = "tmp/" + doc_file + my_tree_textual_memory = TreeTextMemory(tree_config) + doc_memory = await reader.get_memory( + [doc_path], "doc", info={"user_id": user_id, "session_id": "session_" + str(uuid.uuid4())} + ) + + count = 0 + for m_list in doc_memory: + count += len(m_list) + my_tree_textual_memory.add(m_list) + print("total memories: ", count) + + my_tree_textual_memory.dump(temp_dir) + return doc_path + + +async def main(): + batch_size = 2 + for i in range(0, len(doc_paths), batch_size): + batch = doc_paths[i : i + batch_size] + print(f"🚀 Starting batch {i // batch_size + 1} with {len(batch)} docs") + + tasks = [process_doc(p) for p in batch] + results = await asyncio.gather(*tasks, return_exceptions=True) + + for p, result in zip(batch, results, strict=False): + if isinstance(result, Exception): + print(f"❌ Error processing {p}: {result}") + tb_text = "".join(traceback.TracebackException.from_exception(result).format()) + print(tb_text) + else: + print(f"✅ Finished {result}") + + +if __name__ == "__main__": + asyncio.run(main())