diff --git a/docker/.env.example b/.env.example similarity index 100% rename from docker/.env.example rename to .env.example diff --git a/evaluation/scripts/hotpot/hotpot_eval.py b/evaluation/scripts/hotpot/hotpot_eval.py new file mode 100644 index 000000000..05ff52349 --- /dev/null +++ b/evaluation/scripts/hotpot/hotpot_eval.py @@ -0,0 +1,224 @@ +import json +import os +import uuid + +from concurrent.futures import ThreadPoolExecutor, as_completed + +from dotenv import load_dotenv +from tqdm import tqdm + +from memos.configs.mem_cube import GeneralMemCubeConfig +from memos.configs.mem_os import MOSConfig +from memos.mem_cube.general import GeneralMemCube +from memos.mem_os.main import MOS + + +load_dotenv() + +db_name = "stx-hotpot-001" + + +user_name = str(uuid.uuid4()) + +# 1.1 Set openai config +openapi_config = { + "model_name_or_path": "gpt-4o-mini", + "temperature": 0.8, + "max_tokens": 1024, + "top_p": 0.9, + "top_k": 50, + "remove_think_prefix": True, + "api_key": os.getenv("OPENAI_API_KEY", "sk-xxxxx"), + "api_base": os.getenv("OPENAI_API_BASE", "https://api.openai.com/v1"), +} +# 1.2 Set neo4j config +neo4j_uri = os.getenv("NEO4J_URI", "bolt://localhost:7687") + +# 1.3 Create MOS Config +config = { + "user_id": user_name, + "chat_model": { + "backend": "openai", + "config": openapi_config, + }, + "mem_reader": { + "backend": "simple_struct", + "config": { + "llm": { + "backend": "openai", + "config": openapi_config, + }, + "embedder": { + "backend": "universal_api", + "config": { + "provider": "openai", + "api_key": os.getenv("OPENAI_API_KEY", "sk-xxxxx"), + "model_name_or_path": "text-embedding-3-large", + "base_url": os.getenv("OPENAI_API_BASE", "https://api.openai.com/v1"), + }, + }, + "chunker": { + "backend": "sentence", + "config": { + "tokenizer_or_token_counter": "gpt2", + "chunk_size": 512, + "chunk_overlap": 128, + "min_sentences_per_chunk": 1, + }, + }, + }, + }, + "max_turns_window": 20, + "top_k": 5, + "enable_textual_memory": True, + "enable_activation_memory": False, + "enable_parametric_memory": False, +} + +mos_config = MOSConfig(**config) +# you can set PRO_MODE to True to enable CoT enhancement mos_config.PRO_MODE = True +mos = MOS(mos_config) + + +# Filter out embedding fields, keeping only necessary fields +def filter_memory_data(memories_data): + filtered_data = {} + for key, value in memories_data.items(): + if key == "text_mem": + filtered_data[key] = [] + for mem_group in value: + # Check if it's the new data structure (list of TextualMemoryItem objects) + if "memories" in mem_group and isinstance(mem_group["memories"], list): + # New data structure: directly a list of TextualMemoryItem objects + filtered_memories = [] + for memory_item in mem_group["memories"]: + # Create filtered dictionary + filtered_item = { + "id": memory_item.id, + "memory": memory_item.memory, + "metadata": {}, + } + # Filter metadata, excluding embedding + if hasattr(memory_item, "metadata") and memory_item.metadata: + for attr_name in dir(memory_item.metadata): + if not attr_name.startswith("_") and attr_name != "embedding": + attr_value = getattr(memory_item.metadata, attr_name) + if not callable(attr_value): + filtered_item["metadata"][attr_name] = attr_value + filtered_memories.append(filtered_item) + + filtered_group = { + "cube_id": mem_group.get("cube_id", ""), + "memories": filtered_memories, + } + filtered_data[key].append(filtered_group) + else: + # Old data structure: dictionary with nodes and edges + filtered_group = { + "memories": {"nodes": [], "edges": 
mem_group["memories"].get("edges", [])} + } + for node in mem_group["memories"].get("nodes", []): + filtered_node = { + "id": node.get("id"), + "memory": node.get("memory"), + "metadata": { + k: v + for k, v in node.get("metadata", {}).items() + if k != "embedding" + }, + } + filtered_group["memories"]["nodes"].append(filtered_node) + filtered_data[key].append(filtered_group) + else: + filtered_data[key] = value + return filtered_data + + +config = GeneralMemCubeConfig.model_validate( + { + "user_id": user_name, + "cube_id": f"{user_name}", + "text_mem": { + "backend": "tree_text", + "config": { + "extractor_llm": { + "backend": "openai", + "config": openapi_config, + }, + "dispatcher_llm": { + "backend": "openai", + "config": openapi_config, + }, + "graph_db": { + "backend": "neo4j", + "config": { + "uri": neo4j_uri, + "user": "neo4j", + "password": "iaarlichunyu", + "db_name": db_name, + "auto_create": True, + }, + }, + "embedder": { + "backend": "universal_api", + "config": { + "provider": "openai", + "api_key": os.getenv("OPENAI_API_KEY", "sk-xxxxx"), + "model_name_or_path": "text-embedding-3-large", + "base_url": os.getenv("OPENAI_API_BASE", "https://api.openai.com/v1"), + }, + }, + "reorganize": True, + }, + }, + "act_mem": {}, + "para_mem": {}, + }, +) + +mem_cube = GeneralMemCube(config) + + +mos.register_mem_cube(f"/tmp/{user_name}", mem_cube_id=user_name) + + +with open("evaluation/data/hotpot/hotpot_dev_distractor_v1.json") as f: + data = json.load(f) + + +def build_context_text(context_list): + parts = [] + for title, sentences in context_list: + text = " ".join(s.strip() for s in sentences if s.strip()) + parts.append(f"{title}: {text}") + return "\n".join(parts) + + +def build_and_ask(item): + qid = item["_id"] + question = item["question"] + + for title, sentences in item["context"]: + text = " ".join(s.strip() for s in sentences if s.strip()) + memory_content = f"{title}: {text}" + mos.add(memory_content=memory_content) + + answer = mos.chat(question).strip() + return qid, answer + + +pred_answers = {} + +with ThreadPoolExecutor(max_workers=5) as executor: + futures = {executor.submit(build_and_ask, item): item for item in data} + for future in tqdm(as_completed(futures), total=len(futures)): + try: + qid, answer = future.result() + pred_answers[qid] = answer + except Exception as e: + print(f"Error: {e}") + +predictions = {"answer": pred_answers, "sp": []} + +with open("evaluation/data/hotpot/output/dev_distractor_pred.json", "w") as f: + json.dump(predictions, f, ensure_ascii=False, indent=2) diff --git a/evaluation/scripts/hotpot/hotpot_evaluate_v1.py b/evaluation/scripts/hotpot/hotpot_evaluate_v1.py new file mode 100644 index 000000000..d4d6e71e1 --- /dev/null +++ b/evaluation/scripts/hotpot/hotpot_evaluate_v1.py @@ -0,0 +1,151 @@ +import re +import string +import sys + +from collections import Counter + +import ujson as json + + +def normalize_answer(s): + def remove_articles(text): + return re.sub(r"\b(a|an|the)\b", " ", text) + + def white_space_fix(text): + return " ".join(text.split()) + + def remove_punc(text): + exclude = set(string.punctuation) + return "".join(ch for ch in text if ch not in exclude) + + def lower(text): + return text.lower() + + return white_space_fix(remove_articles(remove_punc(lower(s)))) + + +def f1_score(prediction, ground_truth): + normalized_prediction = normalize_answer(prediction) + normalized_ground_truth = normalize_answer(ground_truth) + + zero_metric = (0, 0, 0) + + if ( + normalized_prediction in ["yes", "no", "noanswer"] + and 
normalized_prediction != normalized_ground_truth + ): + return zero_metric + if ( + normalized_ground_truth in ["yes", "no", "noanswer"] + and normalized_prediction != normalized_ground_truth + ): + return zero_metric + + prediction_tokens = normalized_prediction.split() + ground_truth_tokens = normalized_ground_truth.split() + common = Counter(prediction_tokens) & Counter(ground_truth_tokens) + num_same = sum(common.values()) + if num_same == 0: + return zero_metric + precision = 1.0 * num_same / len(prediction_tokens) + recall = 1.0 * num_same / len(ground_truth_tokens) + f1 = (2 * precision * recall) / (precision + recall) + return f1, precision, recall + + +def exact_match_score(prediction, ground_truth): + return normalize_answer(prediction) == normalize_answer(ground_truth) + + +def update_answer(metrics, prediction, gold): + em = exact_match_score(prediction, gold) + f1, prec, recall = f1_score(prediction, gold) + metrics["em"] += float(em) + metrics["f1"] += f1 + metrics["prec"] += prec + metrics["recall"] += recall + return em, prec, recall + + +def update_sp(metrics, prediction, gold): + cur_sp_pred = set(map(tuple, prediction)) + gold_sp_pred = set(map(tuple, gold)) + tp, fp, fn = 0, 0, 0 + for e in cur_sp_pred: + if e in gold_sp_pred: + tp += 1 + else: + fp += 1 + for e in gold_sp_pred: + if e not in cur_sp_pred: + fn += 1 + prec = 1.0 * tp / (tp + fp) if tp + fp > 0 else 0.0 + recall = 1.0 * tp / (tp + fn) if tp + fn > 0 else 0.0 + f1 = 2 * prec * recall / (prec + recall) if prec + recall > 0 else 0.0 + em = 1.0 if fp + fn == 0 else 0.0 + metrics["sp_em"] += em + metrics["sp_f1"] += f1 + metrics["sp_prec"] += prec + metrics["sp_recall"] += recall + return em, prec, recall + + +def eval(prediction_file, gold_file): + with open(prediction_file) as f: + prediction = json.load(f) + with open(gold_file) as f: + gold = json.load(f) + + metrics = { + "em": 0, + "f1": 0, + "prec": 0, + "recall": 0, + "sp_em": 0, + "sp_f1": 0, + "sp_prec": 0, + "sp_recall": 0, + "joint_em": 0, + "joint_f1": 0, + "joint_prec": 0, + "joint_recall": 0, + } + for dp in gold: + cur_id = dp["_id"] + can_eval_joint = True + if cur_id not in prediction["answer"]: + print(f"missing answer {cur_id}") + can_eval_joint = False + else: + em, prec, recall = update_answer(metrics, prediction["answer"][cur_id], dp["answer"]) + if cur_id not in prediction["sp"]: + print(f"missing sp fact {cur_id}") + can_eval_joint = False + else: + sp_em, sp_prec, sp_recall = update_sp( + metrics, prediction["sp"][cur_id], dp["supporting_facts"] + ) + + if can_eval_joint: + joint_prec = prec * sp_prec + joint_recall = recall * sp_recall + if joint_prec + joint_recall > 0: + joint_f1 = 2 * joint_prec * joint_recall / (joint_prec + joint_recall) + else: + joint_f1 = 0.0 + joint_em = em * sp_em + + metrics["joint_em"] += joint_em + metrics["joint_f1"] += joint_f1 + metrics["joint_prec"] += joint_prec + metrics["joint_recall"] += joint_recall + + n = len(gold) + for k in metrics: + metrics[k] /= n + + print(metrics) + + +if __name__ == "__main__": + eval(sys.argv[1], sys.argv[2]) diff --git a/evaluation/scripts/longbenchV2/import_data.py b/evaluation/scripts/longbenchV2/import_data.py new file mode 100644 index 000000000..ff17cd275 --- /dev/null +++ b/evaluation/scripts/longbenchV2/import_data.py @@ -0,0 +1,20 @@ +from datasets import load_dataset + + +dataset = load_dataset("zai-org/LongBench-v2", split="train") +print(dataset) + + +def truncate(value, max_len=200): + if isinstance(value, str) and len(value) > max_len: + return 
value[:max_len] + "... [TRUNCATED]" + return value + + +for i in range(10): + sample = dataset[i] + print(f"========== Sample {i} ==========") + for key, value in sample.items(): + print(f"{key}: {truncate(value)}") + + print("\n") diff --git a/evaluation/scripts/mmlongbench/eval/__init__.py b/evaluation/scripts/mmlongbench/eval/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/evaluation/scripts/mmlongbench/eval/eval_score.py b/evaluation/scripts/mmlongbench/eval/eval_score.py new file mode 100644 index 000000000..02ef6eb53 --- /dev/null +++ b/evaluation/scripts/mmlongbench/eval/eval_score.py @@ -0,0 +1,246 @@ +import re + +from collections import defaultdict +from math import isclose + + +def levenshtein_distance(s1, s2): + if len(s1) > len(s2): + s1, s2 = s2, s1 + + distances = range(len(s1) + 1) + for i2, c2 in enumerate(s2): + distances_ = [i2 + 1] + for i1, c1 in enumerate(s1): + if c1 == c2: + distances_.append(distances[i1]) + else: + distances_.append(1 + min((distances[i1], distances[i1 + 1], distances_[-1]))) + distances = distances_ + return distances[-1] + + +def anls_compute(groundtruth, prediction, threshold=0.5): + dist = levenshtein_distance(groundtruth, prediction) + length = max(len(groundtruth.upper()), len(prediction.upper())) + value = 0.0 if length == 0 else float(dist) / float(length) + anls = 1.0 - value + if anls <= threshold: + anls = 0.0 + return anls + + +def is_float_equal( + reference, prediction, include_percentage: bool = False, is_close: float = False +) -> bool: + def get_precision(gt_ans: float) -> int: + precision = 3 + if "." in str(gt_ans): + precision = len(str(gt_ans).split(".")[-1]) + return precision + + reference = float(str(reference).strip().rstrip("%").strip()) + try: + prediction = float(str(prediction).strip().rstrip("%").strip()) + except Exception: + return False + + gt_result = [reference / 100, reference, reference * 100] if include_percentage else [reference] + for item in gt_result: + try: + if is_close and isclose(item, prediction, rel_tol=0.01): + return True + precision = max(min(get_precision(prediction), get_precision(item)), 2) + if round(prediction, precision) == round(item, precision): + return True + except Exception: + continue + return False + + +def get_clean_string(s): + s = str(s).lower().strip() + + for suffix in ["mile", "miles", "million"]: + if s.endswith(suffix): + s = s[: -len(suffix)].strip() + + s = re.sub(r"\s*\([^)]*\)", "", s).strip() + s = re.sub(r"^['\"]|['\"]$", "", s).strip() + s = s.lstrip("$").rstrip("%").strip() + + return s + + +def is_exact_match(s): + flag = False + # Website + if "https://" in s: + flag = True + # code file + if s.endswith((".py", ".ipynb")) or s.startswith("page"): + flag = True + # telephone number + if re.fullmatch(r"\b\d+(-\d+|\s\d+)?\b", s): + flag = True + # time + if "a.m." in s or "p.m." 
in s: + flag = True + # YYYY-MM-DD + if re.fullmatch(r"\b\d{4}[-\s]\d{2}[-\s]\d{2}\b", s): + flag = True + # YYYY-MM + if re.fullmatch(r"\b\d{4}[-\s]\d{2}\b", s): + flag = True + # Email address + if re.fullmatch(r"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}", s): + flag = True + return flag + + +def isfloat(num): + try: + float(num) + return True + except ValueError: + return False + + +def eval_score(gt, pred, answer_type): + if answer_type == "Int": + try: + gt, pred = int(gt), int(float(pred)) + except Exception: + pred = "" + score = gt == pred + elif answer_type == "Float": + try: + gt = float(get_clean_string(str(gt))) + pred = float(get_clean_string(str(pred))) + except Exception: + pred = "" + score = is_float_equal(gt, pred, include_percentage=True, is_close=True) + elif answer_type in ["Str", "None"]: + gt = get_clean_string(gt) + pred = get_clean_string(pred) + score = gt == pred if is_exact_match(gt) else anls_compute(gt, pred) + else: + if isinstance(gt, str) and gt.startswith("["): + gt = eval(gt) + if not isinstance(gt, list): + gt = [gt] + if isinstance(pred, str) and pred.startswith("["): + pred = eval(pred) + if not isinstance(pred, list): + pred = [pred] + print(len(gt), len(pred)) + if len(gt) != len(pred): + score = 0.0 + else: + gt = sorted([get_clean_string(a) for a in gt]) + pred = sorted([get_clean_string(a) for a in pred]) + print(gt, pred) + if isfloat(gt[0]) or is_exact_match(gt[0]): + score = "-".join(gt) == "-".join(pred) + else: + score = min( + [anls_compute(gt_v, pred_v) for gt_v, pred_v in zip(gt, pred, strict=False)] + ) + + return float(score) + + +def eval_acc_and_f1(samples): + evaluated_samples = [sample for sample in samples if "score" in sample] + if not evaluated_samples: + return 0.0, 0.0 + + acc = sum([sample["score"] for sample in evaluated_samples]) / len(evaluated_samples) + try: + recall = sum( + [ + sample["score"] + for sample in evaluated_samples + if sample["answer"] != "Not answerable" + ] + ) / len([sample for sample in evaluated_samples if sample["answer"] != "Not answerable"]) + precision = sum( + [ + sample["score"] + for sample in evaluated_samples + if sample["answer"] != "Not answerable" + ] + ) / len([sample for sample in evaluated_samples if sample["pred"] != "Not answerable"]) + f1 = 2 * recall * precision / (recall + precision) if (recall + precision) > 0.0 else 0.0 + except Exception: + f1 = 0.0 + + return acc, f1 + + +def show_results(samples, show_path=None): + for sample in samples: + sample["evidence_pages"] = eval(sample["evidence_pages"]) + sample["evidence_sources"] = eval(sample["evidence_sources"]) + + with open(show_path, "w") as f: + acc, f1 = eval_acc_and_f1(samples) + f.write(f"Overall Acc: {acc} | Question Number: {len(samples)}\n") + f.write(f"Overall F1-score: {f1} | Question Number: {len(samples)}\n") + f.write("-----------------------\n") + + acc_single_page, _ = eval_acc_and_f1( + [sample for sample in samples if len(sample["evidence_pages"]) == 1] + ) + acc_multi_page, _ = eval_acc_and_f1( + [ + sample + for sample in samples + if len(sample["evidence_pages"]) != 1 and sample["answer"] != "Not answerable" + ] + ) + acc_neg, _ = eval_acc_and_f1( + [sample for sample in samples if sample["answer"] == "Not answerable"] + ) + + f.write( + "Single-page | Accuracy: {} | Question Number: {}\n".format( + acc_single_page, + len([sample for sample in samples if len(sample["evidence_pages"]) == 1]), + ) + ) + f.write( + "Cross-page | Accuracy: {} | Question Number: {}\n".format( + acc_multi_page, + len( + [ + 
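+            # "Cross-page" subset: evidence is not confined to a single page and the question is answerable.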
sample + for sample in samples + if len(sample["evidence_pages"]) != 1 + and sample["answer"] != "Not answerable" + ] + ), + ) + ) + f.write( + "Unanswerable | Accuracy: {} | Question Number: {}\n".format( + acc_neg, len([sample for sample in samples if sample["answer"] == "Not answerable"]) + ) + ) + f.write("-----------------------\n") + + source_sample_dict, document_type_dict = defaultdict(list), defaultdict(list) + for sample in samples: + for answer_source in sample["evidence_sources"]: + source_sample_dict[answer_source].append(sample) + document_type_dict[sample["doc_type"]].append(sample) + for type, sub_samples in source_sample_dict.items(): + f.write( + f"Evidence Sources: {type} | Accuracy: {eval_acc_and_f1(sub_samples)[0]} | Question Number: {len(sub_samples)}\n" + ) + + f.write("-----------------------\n") + for type, sub_samples in document_type_dict.items(): + f.write( + f"Document Type: {type} | Accuracy: {eval_acc_and_f1(sub_samples)[0]} | Question Number: {len(sub_samples)}\n" + ) diff --git a/evaluation/scripts/mmlongbench/eval/extract_answer.py b/evaluation/scripts/mmlongbench/eval/extract_answer.py new file mode 100644 index 000000000..b7f7e6863 --- /dev/null +++ b/evaluation/scripts/mmlongbench/eval/extract_answer.py @@ -0,0 +1,33 @@ +import os + +import openai + +from dotenv import load_dotenv + + +load_dotenv() +client = openai.Client( + api_key=os.getenv("OPENAI_API_KEY", "sk-xxxxx"), + base_url=os.getenv("OPENAI_API_BASE", "https://api.openai.com/v1"), +) + + +def extract_answer(question, output, prompt, model_name="gpt-4o"): + response = client.chat.completions.create( + model=model_name, + messages=[ + { + "role": "user", + "content": prompt, + }, + {"role": "assistant", "content": f"\n\nQuestion:{question}\nAnalysis:{output}\n"}, + ], + temperature=0.0, + max_tokens=256, + top_p=1, + frequency_penalty=0, + presence_penalty=0, + ) + response = response.choices[0].message.content + + return response diff --git a/evaluation/scripts/mmlongbench/eval/prompt_for_answer_extraction.md b/evaluation/scripts/mmlongbench/eval/prompt_for_answer_extraction.md new file mode 100644 index 000000000..a309c0935 --- /dev/null +++ b/evaluation/scripts/mmlongbench/eval/prompt_for_answer_extraction.md @@ -0,0 +1,35 @@ +Given the question and analysis, you are tasked to extract answers with required formats from the free-form analysis. +- Your extracted answers should be one of the following formats: (1) Integer, (2) Float, (3) String and (4) List. If you find the analysis the question can not be answered from the given documents, type "Not answerable". Exception: If the analysis only tells you that it can not read/understand the images or documents, type "Fail to answer". +- Please make your response as concise as possible. Also note that your response should be formatted as below: +``` +Extracted answer: [answer] +Answer format: [answer format] +``` + +Please read the following example, then extract the answer from the model response and type it at the end of the prompt. + +--- +Question: List the primary questions asked about the services in this report. +Analysis: The primary questions asked about the services in the report for The Limes Residential Home are:\n\n1. Is the service safe?\n2. Is the service effective?\n3. Is the service caring?\n4. Is the service responsive?\n5. Is the service well-led? 
+Extracted answer: ['Is the servife safe?', 'Is the service effective', 'Is the serve caring?', 'Is the service responsive?', 'Is the service well-led?'] +Answer format: List + +--- +Question: How many regulations of the HSCA 2008 are breached in all according to this report? +Analysis: According to the report, the provider breached 10 Health and Social Care Act 2008 (Regulated Activities) Regulations in total. Here are the specifics:\n\n1. Regulation 13: Safeguarding service users from abuse and improper treatment\n2. Regulation 12: Safe care and treatment\n3. Regulation 18: Staffing\n4. Regulation 11: Need for consent\n5. Regulation 10: Dignity and respect\n6. Regulation 9: Person-centred care\n7. Regulation 17: Good governance\n8. Regulation 18 (CQC Registration Regulations 2009): Notification of other incidents\n9. Regulation 18: Failure to maintain an accurate and up-to-date care plan\n10. Regulation 11: Failure to implement the Mental Capacity Act 2005 code of practice effectively\n\nThese breaches involve issues concerning staffing, safeguarding, medicines management, dignity and respect, consent, care planning, governance, and failure to notify the CQC of incidents. +Extracted answer: 10 +Answer format: Integer + +--- +Question: According to the survey that is the percentage of Chinese who are paying more or about the same attention to politics after Trump's election? +Analysis: The survey provided does not specify the percentage of Chinese individuals specifically who are paying more or about the same attention to politics after Trump's election. The report focuses primarily on American demographics and does not include specific details about the Chinese population in relation to this question. If you need information about a different demographic or a summary of the findings from the American demographic, I can certainly help with that! +Extracted answer: Not answerable +Answer format: String + +--- +Question: How many quotations from male respondent over 50 years old are included in this report? +Analysis: The image you've provided appears to be a screenshot of a document with multiple charts. However, the text is too small and blurry to read accurately. If you can provide a clearer image or more context, I might be able to help you with your question. 
+Extracted answer: Fail to answer +Answer format: String + +--- diff --git a/evaluation/scripts/mmlongbench/eval_docs.py b/evaluation/scripts/mmlongbench/eval_docs.py new file mode 100644 index 000000000..510a0b1ed --- /dev/null +++ b/evaluation/scripts/mmlongbench/eval_docs.py @@ -0,0 +1,265 @@ +import csv +import json +import os +import re +import traceback + +from concurrent.futures import ThreadPoolExecutor, as_completed + +from dotenv import load_dotenv +from eval.eval_score import eval_acc_and_f1, eval_score, show_results +from eval.extract_answer import extract_answer +from tqdm import tqdm + +from memos.configs.mem_cube import GeneralMemCubeConfig +from memos.configs.mem_os import MOSConfig +from memos.mem_cube.general import GeneralMemCube +from memos.mem_os.main import MOS + + +load_dotenv() +openapi_config = { + "model_name_or_path": "gpt-4o", + "top_k": 50, + "remove_think_prefix": True, + "api_key": os.getenv("OPENAI_API_KEY", "sk-xxxxx"), + "api_base": os.getenv("OPENAI_API_BASE", "https://api.openai.com/v1"), +} +neo4j_uri = os.getenv("NEO4J_URI", "bolt://47.117.41.207:7687") +db_name = "stx-mmlongbench-004" + +doc_paths = [ + f + for f in os.listdir("evaluation/data/mmlongbench/documents") + if os.path.isfile(os.path.join("evaluation/data/mmlongbench/documents", f)) +] + +with open("evaluation/data/mmlongbench/samples.json") as f: + samples = json.load(f) + +RESULTS_PATH = "evaluation/data/mmlongbench/test_results.json" +completed_pairs: set[tuple[str, str]] = set() + + +def _load_existing_results(): + global completed_pairs + if os.path.exists(RESULTS_PATH): + try: + with open(RESULTS_PATH, encoding="utf-8") as f: + existing = json.load(f) + for r in existing: + did = r.get("doc_id") + q = r.get("question") + if did and q: + completed_pairs.add((did, q)) + return existing + except Exception: + return [] + return [] + + +def _doc_has_pending(doc_file: str) -> bool: + for s in samples: + if s.get("doc_id") == doc_file and (doc_file, s.get("question")) not in completed_pairs: + return True + return False + + +def get_user_name(doc_file): + csv_path = "evaluation/data/mmlongbench/user_doc_map.csv" + if os.path.exists(csv_path): + with open(csv_path, newline="", encoding="utf-8") as f: + reader = csv.reader(f) + for row in reader: + uid, path = row[0], row[1] + base = os.path.basename(path) + if base == doc_file or os.path.splitext(base)[0] == os.path.splitext(doc_file)[0]: + return uid + return "" + + +def process_doc(doc_file): + user_name = get_user_name(doc_file) + print(user_name, doc_file) + config = { + "user_id": user_name, + "chat_model": { + "backend": "openai", + "config": openapi_config, + }, + "mem_reader": { + "backend": "simple_struct", + "config": { + "llm": {"backend": "openai", "config": openapi_config}, + "embedder": { + "backend": "universal_api", + "config": { + "provider": "openai", + "api_key": os.getenv("OPENAI_API_KEY", "sk-xxxxx"), + "model_name_or_path": "text-embedding-3-large", + "base_url": os.getenv("OPENAI_API_BASE", "https://api.openai.com/v1"), + }, + }, + "chunker": { + "backend": "sentence", + "config": { + "tokenizer_or_token_counter": "gpt2", + "chunk_size": 512, + "chunk_overlap": 128, + "min_sentences_per_chunk": 1, + }, + }, + }, + }, + "max_turns_window": 20, + "top_k": 5, + "enable_textual_memory": True, + "enable_activation_memory": False, + "enable_parametric_memory": False, + } + mos_config = MOSConfig(**config) + mos = MOS(mos_config) + + mem_cube_config = GeneralMemCubeConfig.model_validate( + { + "user_id": user_name, + 
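+            # cube_id reuses the per-document user id from user_doc_map.csv, so each
+            # document's memories are registered under their own mem cube.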
"cube_id": user_name, + "text_mem": { + "backend": "tree_text", + "config": { + "extractor_llm": {"backend": "openai", "config": openapi_config}, + "dispatcher_llm": {"backend": "openai", "config": openapi_config}, + "graph_db": { + "backend": "neo4j", + "config": { + "uri": neo4j_uri, + "user": "neo4j", + "password": "iaarlichunyu", + "db_name": db_name, + "user_name": user_name, + "use_multi_db": False, + "auto_create": True, + "embedding_dimension": 3072, + }, + }, + "embedder": { + "backend": "universal_api", + "config": { + "provider": "openai", + "api_key": os.getenv("OPENAI_API_KEY", "sk-xxxxx"), + "model_name_or_path": "text-embedding-3-large", + "base_url": os.getenv("OPENAI_API_BASE", "https://api.openai.com/v1"), + }, + }, + "reorganize": False, + }, + }, + "act_mem": {}, + "para_mem": {}, + } + ) + mem_cube = GeneralMemCube(mem_cube_config) + + temp_dir = "tmp/" + doc_file + if not os.path.exists(temp_dir) or not os.listdir(temp_dir): + mem_cube.dump(temp_dir) + + mos.register_mem_cube(temp_dir, mem_cube_id=user_name) + + with open("evaluation/scripts/mmlongbench/eval/prompt_for_answer_extraction.md") as f: + prompt = f.read() + + samples_res = [] + doc_samples = [s for s in samples if s.get("doc_id") == doc_file] + if len(doc_samples) == 0: + return [] + + for sample in tqdm(doc_samples, desc=f"Processing {doc_file}"): + if (doc_file, sample.get("question")) in completed_pairs: + continue + messages = sample["question"] + try_cnt, is_success = 0, False + + while True: + try: + mos.clear_messages() + response = mos.chat(messages, user_name) + is_success = True + except Exception as e: + print(f"[{doc_file}] Error:", e) + traceback.print_exc() + try_cnt += 1 + response = "Failed" + if is_success or try_cnt > 5: + break + + sample["response"] = response + extracted_res = extract_answer(sample["question"], response, prompt) + sample["extracted_res"] = extracted_res + + pred_ans = extracted_res.split("Answer format:")[0].split("Extracted answer:")[1].strip() + score = eval_score(sample["answer"], pred_ans, sample["answer_format"]) + + sample["pred"] = pred_ans + sample["score"] = score + samples_res.append(sample) + + print("--------------------------------------") + print(f"Question: {sample['question']}") + print(f"Response: {sample['response']}") + print(f"Ground true: {sample['answer']}\tPred: {sample['pred']}\tScore: {sample['score']}") + + return samples_res + + +if __name__ == "__main__": + results = _load_existing_results() + total_samples = len(samples) + processed_samples = len(completed_pairs) + pending_samples = total_samples - processed_samples + sample_doc_ids = [s.get("doc_id") for s in samples if s.get("doc_id")] + all_docs_in_samples = set(sample_doc_ids) + processed_docs = {d for d, _ in completed_pairs} + with ThreadPoolExecutor(max_workers=4) as executor: + pending_docs = [d for d in doc_paths if _doc_has_pending(d)] + print("\n" + "=" * 80) + print("📊 评测进度统计") + print("=" * 80) + print(f"✅ 已加载历史结果: {len(results)} 条") + print(f"📂 数据集总样本: {total_samples}") + print(f"🧪 已完成样本: {processed_samples}") + print(f"⏳ 待处理样本: {pending_samples}") + print(f"📄 数据集中总文档: {len(all_docs_in_samples)}") + print(f"✔️ 已完成文档: {len(processed_docs)}") + print(f"➡️ 待处理文档(本次将运行): {len(pending_docs)}") + print("=" * 80 + "\n") + future_to_doc = { + executor.submit(process_doc, doc_file): doc_file for doc_file in pending_docs + } + + for future in as_completed(future_to_doc): + doc_file = future_to_doc[future] + try: + res = future.result() + results.extend(res) + + if len(res) > 0: + acc, 
f1 = eval_acc_and_f1(results) + print() + print(f"Avg acc: {acc}") + print(f"Avg f1: {f1}") + + with open(RESULTS_PATH, "w", encoding="utf-8") as f: + json.dump(results, f, ensure_ascii=False, indent=2) + except Exception as e: + print(f"[{doc_file}] failed with {e}") + + acc, f1 = eval_acc_and_f1(results) + print("--------------------------------------") + print(f"Final avg acc: {acc}") + print(f"Final avg f1: {f1}") + + show_results( + results, + show_path=re.sub(r"\.json$", ".txt", "evaluation/data/mmlongbench/test_results_report.txt"), + ) diff --git a/evaluation/scripts/mmlongbench/import_docs.py b/evaluation/scripts/mmlongbench/import_docs.py new file mode 100644 index 000000000..540c8f960 --- /dev/null +++ b/evaluation/scripts/mmlongbench/import_docs.py @@ -0,0 +1,88 @@ +import asyncio +import os +import traceback +import uuid + +from memos import log +from memos.configs.mem_reader import SimpleStructMemReaderConfig +from memos.configs.memory import TreeTextMemoryConfig +from memos.mem_reader.simple_struct import SimpleStructMemReader +from memos.memories.textual.tree import TreeTextMemory + + +logger = log.get_logger(__name__) +db_name = "stx-mmlongbench-004" +# Create a memory reader instance +reader_config = SimpleStructMemReaderConfig.from_json_file( + "examples/data/config/simple_struct_reader_config.json" +) +reader = SimpleStructMemReader(reader_config) + +tree_config = TreeTextMemoryConfig.from_json_file( + "examples/data/config/tree_config_shared_database.json" +) +tree_config.graph_db.config.db_name = db_name +# Processing Documents +existing_names = { + d for d in os.listdir("ppt_test_result") if os.path.isdir(os.path.join("ppt_test_result", d)) +} +doc_paths = [] +for f in os.listdir("evaluation/data/mmlongbench/documents"): + fp = os.path.join("evaluation/data/mmlongbench/documents", f) + if os.path.isfile(fp): + name = os.path.splitext(f)[0] + if name in existing_names: + continue + doc_paths.append(fp) + +print("existing_names length:", len(existing_names)) +print("doc_paths length:", len(doc_paths)) + + +async def process_doc(doc_path): + print(f"🔄 Processing document: {doc_path}") + # Generate random user id: 'user_' + random short hex + user_id = "user_" + uuid.uuid4().hex[:8] + # Persist mapping between user_id and doc_path + try: + os.makedirs("evaluation/data/mmlongbench", exist_ok=True) + with open("evaluation/data/mmlongbench/user_doc_map.csv", "a", encoding="utf-8") as f: + f.write(f"{user_id},{doc_path}\n") + except Exception as e: + logger.error(f"Failed to write user-doc mapping: {e}") + + tree_config.graph_db.config.user_name = user_id + my_tree_textual_memory = TreeTextMemory(tree_config) + doc_memory = await reader.get_memory( + [doc_path], "doc", info={"user_id": user_id, "session_id": "session_" + str(uuid.uuid4())} + ) + + count = 0 + for m_list in doc_memory: + count += len(m_list) + my_tree_textual_memory.add(m_list) + print("total memories: ", count) + + return doc_path + + +async def main(): + batch_size = 4 + for i in range(0, len(doc_paths), batch_size): + batch = doc_paths[i : i + batch_size] + print(f"🚀 Starting batch {i // batch_size + 1} with {len(batch)} docs") + + tasks = [process_doc(p) for p in batch] + results = await asyncio.gather(*tasks, return_exceptions=True) + + for p, result in zip(batch, results, strict=False): + if isinstance(result, Exception): + print(f"❌ Error processing {p}: {result}") + tb_text = "".join(traceback.TracebackException.from_exception(result).format()) + print(tb_text) + else: + print(f"✅ Finished {result}") 
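+# Note: this importer writes memories straight to the shared Neo4j database configured above;
+# nothing is dumped to local disk here. Run it from the repository root so the relative
+# evaluation/ and examples/ config paths resolve, e.g.:
+#   python evaluation/scripts/mmlongbench/import_docs.py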
+ + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/evaluation/scripts/mmlongbench/models/__init__.py b/evaluation/scripts/mmlongbench/models/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/evaluation/scripts/mmlongbench/models/internlm_xc2_4khd.py b/evaluation/scripts/mmlongbench/models/internlm_xc2_4khd.py new file mode 100644 index 000000000..ae62eec9a --- /dev/null +++ b/evaluation/scripts/mmlongbench/models/internlm_xc2_4khd.py @@ -0,0 +1,128 @@ +import torch +import torch.nn.functional as func + +from transformers import AutoModel, AutoTokenizer + + +torch.set_grad_enabled(False) + + +try: + from transformers.generation.streamers import BaseStreamer +except Exception: + BaseStreamer = None + + +def chat( + model, + tokenizer, + query: str, + image: None, + hd_num: int = 25, + history: list[tuple[str, str]] | None = None, + streamer: BaseStreamer | None = None, + max_new_tokens: int = 1024, + temperature: float = 1.0, + top_p: float = 0.8, + repetition_penalty: float = 1.005, + meta_instruction: str = "You are an AI assistant whose name is InternLM-XComposer (浦语·灵笔).\n" + "- InternLM-XComposer (浦语·灵笔) is a multi-modality conversational language model that is developed by Shanghai AI Laboratory (上海人工智能实验室). It is designed to be helpful, honest, and harmless.\n" + "- InternLM-XComposer (浦语·灵笔) can understand and communicate fluently in the language chosen by the user such as English and 中文.\n" + "- InternLM-XComposer (浦语·灵笔) is capable of comprehending and articulating responses effectively based on the provided image.", + **kwargs, +): + if history is None: + history = [] + if image is None: + inputs = model.build_inputs(tokenizer, query, history, meta_instruction) + im_mask = torch.zeros(inputs["input_ids"].shape[:2]).cuda().bool() + else: + if isinstance(image, str): + with torch.cuda.amp.autocast(): + image = model.encode_img(image, hd_num=hd_num) + inputs, im_mask = model.interleav_wrap_chat( + tokenizer, query, image, history, meta_instruction + ) + if isinstance(image, str): + image_list = [] + with torch.cuda.amp.autocast(): + for image_path in image: + tmp = model.encode_img(image_path, hd_num=hd_num) + image_list.append(tmp) + if len(image_list) > 1 and image_list[-1].shape[1] != image_list[-2].shape[1]: + image_list[-1] = func.interpolate( + image_list[-1].unsqueeze(1), size=image_list[-2].shape[1:], mode="bilinear" + ).squeeze(1) + image = torch.cat(image_list, dim=0) + with torch.cuda.amp.autocast(): + inputs, im_mask = model.interleav_wrap_chat( + tokenizer, query, image, history, meta_instruction + ) + else: + raise NotImplementedError + inputs = {k: v.to(model.device) for k, v in inputs.items() if torch.is_tensor(v)} + # also add end-of-assistant token in eos token id to avoid unnecessary generation + eos_token_id = [ + tokenizer.eos_token_id, + tokenizer.convert_tokens_to_ids(["[UNUSED_TOKEN_145]"])[0], + ] + with torch.cuda.amp.autocast(): + outputs = model.generate( + **inputs, + streamer=streamer, + max_new_tokens=max_new_tokens, + do_sample=temperature != 0.0, + temperature=temperature, + top_p=top_p, + eos_token_id=eos_token_id, + repetition_penalty=repetition_penalty, + im_mask=im_mask, + **kwargs, + ) + if image is None: + outputs = outputs[0].cpu().tolist()[len(inputs["input_ids"][0]) :] + else: + outputs = outputs[0].cpu().tolist() + + response = tokenizer.decode(outputs, skip_special_tokens=True) + response = response.partition("[UNUSED_TOKEN_145]")[0] + + history = [*history, (query, response)] + return response, 
history + + +def init_model(cache_path): + model_path = ( + cache_path + if (cache_path is not None and cache_path != "None") + else "internlm/internlm-xcomposer2-4khd-7b" + ) + model = AutoModel.from_pretrained( + model_path, + torch_dtype=torch.bfloat16, + low_cpu_mem_usage=True, + trust_remote_code=True, + device_map="auto", + ).eval() + + tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) + model.tokenizer = tokenizer + return model + + +def get_response_concat(model, question, image_path_list, max_new_tokens=1024, temperature=1.0): + query = " " * len(image_path_list) + question + try: + response, _ = chat( + model, + model.tokenizer, + query=query, + image=image_path_list, + max_new_tokens=max_new_tokens, + hd_num=16, + temperature=temperature, + ) + except Exception as e: + print(e) + response = "Failed" + return response diff --git a/evaluation/scripts/mmlongbench/models/internvl_chat.py b/evaluation/scripts/mmlongbench/models/internvl_chat.py new file mode 100644 index 000000000..793f1d35a --- /dev/null +++ b/evaluation/scripts/mmlongbench/models/internvl_chat.py @@ -0,0 +1,137 @@ +import torch +import torchvision.transforms as tf + +from PIL import Image +from torchvision.transforms.functional import InterpolationMode +from transformers import AutoModel, AutoTokenizer + + +IMAGENET_MEAN = (0.485, 0.456, 0.406) +IMAGENET_STD = (0.229, 0.224, 0.225) + + +def build_transform(input_size, mean=IMAGENET_MEAN, std=IMAGENET_STD): + return tf.Compose( + [ + tf.Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img), + tf.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC), + tf.ToTensor(), + tf.Normalize(mean=mean, std=std), + ] + ) + + +def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size): + best_ratio_diff = float("inf") + best_ratio = (1, 1) + area = width * height + for ratio in target_ratios: + target_aspect_ratio = ratio[0] / ratio[1] + ratio_diff = abs(aspect_ratio - target_aspect_ratio) + if ratio_diff < best_ratio_diff: + best_ratio_diff = ratio_diff + best_ratio = ratio + elif ratio_diff == best_ratio_diff: + if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]: + best_ratio = ratio + return best_ratio + + +def dynamic_preprocess(image, min_num=1, max_num=6, image_size=448, use_thumbnail=False): + orig_width, orig_height = image.size + aspect_ratio = orig_width / orig_height + + # calculate the existing image aspect ratio + target_ratios = { + (i, j) + for n in range(min_num, max_num + 1) + for i in range(1, n + 1) + for j in range(1, n + 1) + if min_num <= i * j <= max_num + } + target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1]) + + # find the closest aspect ratio to the target + target_aspect_ratio = find_closest_aspect_ratio( + aspect_ratio, target_ratios, orig_width, orig_height, image_size + ) + + # calculate the target width and height + target_width = image_size * target_aspect_ratio[0] + target_height = image_size * target_aspect_ratio[1] + blocks = target_aspect_ratio[0] * target_aspect_ratio[1] + + # resize the image + resized_img = image.resize((target_width, target_height)) + processed_images = [] + for i in range(blocks): + box = ( + (i % (target_width // image_size)) * image_size, + (i // (target_width // image_size)) * image_size, + ((i % (target_width // image_size)) + 1) * image_size, + ((i // (target_width // image_size)) + 1) * image_size, + ) + # split the image + split_img = resized_img.crop(box) + processed_images.append(split_img) + assert 
len(processed_images) == blocks + if use_thumbnail and len(processed_images) != 1: + thumbnail_img = image.resize((image_size, image_size)) + processed_images.append(thumbnail_img) + return processed_images + + +def load_image(image_file, input_size=448, max_num=6): + image = Image.open(image_file).convert("RGB") + transform = build_transform(input_size=input_size) + images = dynamic_preprocess(image, image_size=input_size, use_thumbnail=True, max_num=max_num) + pixel_values = [transform(image) for image in images] + pixel_values = torch.stack(pixel_values) + return pixel_values + + +def init_model(cache_path): + import os + + os.environ["CUDA_LAUNCH_BLOCKING"] = "1" + + model_path = ( + cache_path + if (cache_path is not None and cache_path != "None") + else "OpenGVLab/InternVL-Chat-V1-5" + ) + model = AutoModel.from_pretrained( + model_path, + torch_dtype=torch.bfloat16, + low_cpu_mem_usage=True, + trust_remote_code=True, + device_map="auto", + ).eval() + tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) + model.tokenizer = tokenizer + return model + + +def get_response_concat( + model, question, image_path_list, max_new_tokens=1024, temperature=1.0, max_num=6 +): + generation_config = { + "num_beams": 1, + "max_new_tokens": max_new_tokens, + "do_sample": temperature != 0.0, + "temperature": temperature, + } + pixel_values_list = [ + load_image(image_path, max_num=max_num).to(torch.bfloat16).cuda() + for image_path in image_path_list + ] + pixel_values = torch.cat(pixel_values_list, dim=0) + response, _ = model.chat( + model.tokenizer, + pixel_values, + question, + generation_config, + history=None, + return_history=True, + ) + return response diff --git a/evaluation/scripts/mmlongbench/models/minicpm_llama3.py b/evaluation/scripts/mmlongbench/models/minicpm_llama3.py new file mode 100644 index 000000000..7f6d4b743 --- /dev/null +++ b/evaluation/scripts/mmlongbench/models/minicpm_llama3.py @@ -0,0 +1,56 @@ +import torch + +from PIL import Image +from transformers import AutoModel, AutoTokenizer + + +def init_model(cache_path): + model_path = ( + cache_path + if (cache_path is not None and cache_path != "None") + else "openbmb/MiniCPM-Llama3-V-2_5" + ) + model = AutoModel.from_pretrained( + model_path, + torch_dtype=torch.bfloat16, + low_cpu_mem_usage=True, + trust_remote_code=True, + device_map="auto", + ).eval() + + tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) + model.tokenizer = tokenizer + return model + + +def get_response_concat(model, question, image_path_list, max_new_tokens=1024, temperature=1.0): + msgs = [] + system_prompt = "Answer in detail." 
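+    # Message order (when a list of image paths is given): the system prompt first,
+    # then one image entry per page, then the question text.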
+ if system_prompt: + msgs.append({"type": "text", "value": system_prompt}) + if isinstance(image_path_list, list): + msgs.extend([{"type": "image", "value": p} for p in image_path_list]) + else: + msgs = [{"type": "image", "value": image_path_list}] + msgs.append({"type": "text", "value": question}) + + content = [] + for x in msgs: + if x["type"] == "text": + content.append(x["value"]) + elif x["type"] == "image": + image = Image.open(x["value"]).convert("RGB") + content.append(image) + msgs = [{"role": "user", "content": content}] + + with torch.cuda.amp.autocast(): + res = model.chat( + msgs=msgs, + context=None, + image=None, + max_new_tokens=max_new_tokens, + temperature=temperature, + do_sample=temperature != 0.0, + tokenizer=model.tokenizer, + ) + return res diff --git a/evaluation/scripts/mmlongbench/multimodal_test.py b/evaluation/scripts/mmlongbench/multimodal_test.py new file mode 100644 index 000000000..929215229 --- /dev/null +++ b/evaluation/scripts/mmlongbench/multimodal_test.py @@ -0,0 +1,185 @@ +import os +import shutil + +from dotenv import load_dotenv + +from memos.configs.mem_cube import GeneralMemCubeConfig +from memos.configs.mem_os import MOSConfig +from memos.mem_cube.general import GeneralMemCube +from memos.mem_os.main import MOS + + +load_dotenv() + +db_name = "stx-mmlongbench-002" +user_id = "user_dc812220" + +# 1.1 Set openai config +openapi_config = { + "model_name_or_path": "gpt-4o", + "top_k": 50, + "remove_think_prefix": True, + "api_key": os.getenv("OPENAI_API_KEY", "sk-xxxxx"), + "api_base": os.getenv("OPENAI_API_BASE", "https://api.openai.com/v1"), +} +# 1.2 Set neo4j config +neo4j_uri = os.getenv("NEO4J_URI", "bolt://localhost:7687") + +# 1.3 Create MOS Config +config = { + "user_id": user_id, + "chat_model": { + "backend": "openai", + "config": openapi_config, + }, + "mem_reader": { + "backend": "simple_struct", + "config": { + "llm": { + "backend": "openai", + "config": openapi_config, + }, + "embedder": { + "backend": "universal_api", + "config": { + "provider": "openai", + "api_key": os.getenv("OPENAI_API_KEY", "sk-xxxxx"), + "model_name_or_path": "text-embedding-3-large", + "base_url": os.getenv("OPENAI_API_BASE", "https://api.openai.com/v1"), + }, + }, + "chunker": { + "backend": "sentence", + "config": { + "tokenizer_or_token_counter": "gpt2", + "chunk_size": 512, + "chunk_overlap": 128, + "min_sentences_per_chunk": 1, + }, + }, + }, + }, + "max_turns_window": 20, + "top_k": 5, + "enable_textual_memory": True, + "enable_activation_memory": False, + "enable_parametric_memory": False, +} + +mos_config = MOSConfig(**config) +mos = MOS(mos_config) + +config = GeneralMemCubeConfig.model_validate( + { + "user_id": user_id, + "cube_id": f"{user_id}", + "text_mem": { + "backend": "tree_text", + "config": { + "extractor_llm": { + "backend": "openai", + "config": openapi_config, + }, + "dispatcher_llm": { + "backend": "openai", + "config": openapi_config, + }, + "graph_db": { + "backend": "neo4j", + "config": { + "uri": neo4j_uri, + "user": "neo4j", + "password": "iaarlichunyu", + "db_name": db_name, + "user_name": user_id, + "use_multi_db": False, + "auto_create": True, + "embedding_dimension": 3072, + }, + }, + "embedder": { + "backend": "universal_api", + "config": { + "provider": "openai", + "api_key": os.getenv("OPENAI_API_KEY", "sk-xxxxx"), + "model_name_or_path": "text-embedding-3-large", + "base_url": os.getenv("OPENAI_API_BASE", "https://api.openai.com/v1"), + }, + }, + "reorganize": False, + }, + }, + "act_mem": {}, + "para_mem": {}, + }, +) + + 
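+# This smoke test assumes that memories for `user_dc812220` were already ingested into the
+# shared Neo4j database by a prior import run. The filter_memory_data helper defined below is
+# not called in this script; a possible use (assuming MOS exposes a get_all() accessor, as in
+# the MemOS examples) would be:
+#   memories = mos.get_all()
+#   print(filter_memory_data(memories))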
+# Filter out embedding fields, keeping only necessary fields +def filter_memory_data(memories_data): + filtered_data = {} + for key, value in memories_data.items(): + if key == "text_mem": + filtered_data[key] = [] + for mem_group in value: + # Check if it's the new data structure (list of TextualMemoryItem objects) + if "memories" in mem_group and isinstance(mem_group["memories"], list): + # New data structure: directly a list of TextualMemoryItem objects + filtered_memories = [] + for memory_item in mem_group["memories"]: + # Create filtered dictionary + filtered_item = { + "id": memory_item.id, + "memory": memory_item.memory, + "metadata": {}, + } + # Filter metadata, excluding embedding + if hasattr(memory_item, "metadata") and memory_item.metadata: + for attr_name in dir(memory_item.metadata): + if not attr_name.startswith("_") and attr_name != "embedding": + attr_value = getattr(memory_item.metadata, attr_name) + if not callable(attr_value): + filtered_item["metadata"][attr_name] = attr_value + filtered_memories.append(filtered_item) + + filtered_group = { + "cube_id": mem_group.get("cube_id", ""), + "memories": filtered_memories, + } + filtered_data[key].append(filtered_group) + else: + # Old data structure: dictionary with nodes and edges + filtered_group = { + "memories": {"nodes": [], "edges": mem_group["memories"].get("edges", [])} + } + for node in mem_group["memories"].get("nodes", []): + filtered_node = { + "id": node.get("id"), + "memory": node.get("memory"), + "metadata": { + k: v + for k, v in node.get("metadata", {}).items() + if k != "embedding" + }, + } + filtered_group["memories"]["nodes"].append(filtered_node) + filtered_data[key].append(filtered_group) + else: + filtered_data[key] = value + return filtered_data + + +mem_cube = GeneralMemCube(config) + +temp_dir = f"/tmp/{user_id}" +if os.path.exists(temp_dir): + shutil.rmtree(temp_dir) +mem_cube.dump(temp_dir) +mos.register_mem_cube(temp_dir, mem_cube_id=user_id) + + +print("start answering...") +user_query = "图8美股变化的影响是什么" +print(f"👤 User query: {user_query}") +response = mos.chat(user_query) +print(f"🤖 Response: {response}") diff --git a/evaluation/scripts/xinyu/eval/__init__.py b/evaluation/scripts/xinyu/eval/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/evaluation/scripts/xinyu/eval/eval_score_llm.py b/evaluation/scripts/xinyu/eval/eval_score_llm.py new file mode 100644 index 000000000..f5764ce39 --- /dev/null +++ b/evaluation/scripts/xinyu/eval/eval_score_llm.py @@ -0,0 +1,279 @@ +import os +import re +import traceback + +from collections import defaultdict +from math import isclose + +from memos.configs.mem_os import MOSConfig +from memos.llms.factory import LLMFactory + + +openapi_config = { + "model_name_or_path": "gpt-5-nano", + "top_k": 50, + "remove_think_prefix": True, + "api_key": os.getenv("OPENAI_API_KEY", "sk-xxxxx"), + "api_base": os.getenv("OPENAI_API_BASE", "https://api.openai.com/v1"), +} +config = { + "user_id": "user_name", + "chat_model": { + "backend": "openai", + "config": openapi_config, + }, + "mem_reader": { + "backend": "simple_struct", + "config": { + "llm": {"backend": "openai", "config": openapi_config}, + "embedder": { + "backend": "universal_api", + "config": { + "provider": "openai", + "api_key": os.getenv("OPENAI_API_KEY", "sk-xxxxx"), + "model_name_or_path": "text-embedding-3-large", + "base_url": os.getenv("OPENAI_API_BASE", "https://api.openai.com/v1"), + }, + }, + "chunker": { + "backend": "sentence", + "config": { + "tokenizer_or_token_counter": 
"gpt2", + "chunk_size": 512, + "chunk_overlap": 128, + "min_sentences_per_chunk": 1, + }, + }, + }, + }, + "max_turns_window": 20, + "top_k": 5, + "enable_textual_memory": True, + "enable_activation_memory": False, + "enable_parametric_memory": False, +} +mos_config = MOSConfig(**config) +chat_llm = LLMFactory.from_config(mos_config.chat_model) + + +def is_float_equal( + reference, prediction, include_percentage: bool = False, is_close: float = False +) -> bool: + def get_precision(gt_ans: float) -> int: + precision = 3 + if "." in str(gt_ans): + precision = len(str(gt_ans).split(".")[-1]) + return precision + + reference = float(str(reference).strip().rstrip("%").strip()) + try: + prediction = float(str(prediction).strip().rstrip("%").strip()) + except Exception: + return False + + gt_result = [reference / 100, reference, reference * 100] if include_percentage else [reference] + for item in gt_result: + try: + if is_close and isclose(item, prediction, rel_tol=0.01): + return True + precision = max(min(get_precision(prediction), get_precision(item)), 2) + if round(prediction, precision) == round(item, precision): + return True + except Exception: + continue + return False + + +def get_clean_string(s): + s = str(s).lower().strip() + + for suffix in ["mile", "miles", "million"]: + if s.endswith(suffix): + s = s[: -len(suffix)].strip() + + s = re.sub(r"\s*\([^)]*\)", "", s).strip() + s = re.sub(r"^['\"]|['\"]$", "", s).strip() + s = s.lstrip("$").rstrip("%").strip() + + return s + + +def is_exact_match(s): + flag = False + # Website + if "https://" in s: + flag = True + # code file + if s.endswith((".py", ".ipynb")) or s.startswith("page"): + flag = True + # telephone number + if re.fullmatch(r"\b\d+(-\d+|\s\d+)?\b", s): + flag = True + # time + if "a.m." in s or "p.m." 
in s: + flag = True + # YYYY-MM-DD + if re.fullmatch(r"\b\d{4}[-\s]\d{2}[-\s]\d{2}\b", s): + flag = True + # YYYY-MM + if re.fullmatch(r"\b\d{4}[-\s]\d{2}\b", s): + flag = True + # Email address + if re.fullmatch(r"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}", s): + flag = True + return flag + + +def isfloat(num): + try: + float(num) + return True + except ValueError: + return False + + +def eval_score(question, gt, pred): + prompt = """ + 你是一个评委,根据问题和标准答案对学生的答案进行打分。打分规则如下: + + 完全不对(0分): + 学生答案与问题无关,未展示出任何相关概念或知识。 + 对了一部分(0.5分): + 学生答案提供了一些相关信息,但未能直接回答问题。 + 答案中包含部分正确内容,但缺乏关键信息,导致整体理解不清。 + 基本正确(0.7分): + 学生答案提供了大部分关键信息,不过依然距离标准答案有一定缺失。 + 答案中包含部分关键内容,但缺乏部分信息,导致不够完整。 + 完全正确(1分): + 学生答案准确地回答了问题,涵盖所有关键信息。 + 表达清晰,逻辑合理,直接且有效地回应了问题。 + + 问题:{} + + 标准答案:{} + + 学生答案:{} + """ + + max_try = 20 + try_i = 0 + while try_i < max_try: + try: + llm_input_prompt_score = ( + prompt.format(question, gt, pred) + + """请返回给我一个json: + { + "分数": 1, + "理由": "xxxx" + }""" + ) + score = chat_llm.generate( + [ + {"role": "user", "content": llm_input_prompt_score}, + ] + ) + + print(f"score: {score}") + score_real = eval(score.replace("json", "").replace("\n", "").replace("```", "")) + return float(score_real["分数"]) + except Exception: + traceback.print_exc() + print(f"trying num {try_i}") + try_i += 1 + return -1 + + +def eval_acc_and_f1(samples): + evaluated_samples = [sample for sample in samples if "score" in sample] + if not evaluated_samples: + return 0.0, 0.0 + + acc = sum([sample["score"] for sample in evaluated_samples]) / len(evaluated_samples) + try: + recall = sum( + [ + sample["score"] + for sample in evaluated_samples + if sample["answer"] != "Not answerable" + ] + ) / len([sample for sample in evaluated_samples if sample["answer"] != "Not answerable"]) + precision = sum( + [ + sample["score"] + for sample in evaluated_samples + if sample["answer"] != "Not answerable" + ] + ) / len([sample for sample in evaluated_samples if sample["pred"] != "Not answerable"]) + f1 = 2 * recall * precision / (recall + precision) if (recall + precision) > 0.0 else 0.0 + except Exception: + f1 = 0.0 + + return acc, f1 + + +def show_results(samples, show_path=None): + for sample in samples: + sample["evidence_pages"] = eval(sample["evidence_pages"]) + sample["evidence_sources"] = eval(sample["evidence_sources"]) + + with open(show_path, "w") as f: + acc, f1 = eval_acc_and_f1(samples) + f.write(f"Overall Acc: {acc} | Question Number: {len(samples)}\n") + f.write(f"Overall F1-score: {f1} | Question Number: {len(samples)}\n") + f.write("-----------------------\n") + + acc_single_page, _ = eval_acc_and_f1( + [sample for sample in samples if len(sample["evidence_pages"]) == 1] + ) + acc_multi_page, _ = eval_acc_and_f1( + [ + sample + for sample in samples + if len(sample["evidence_pages"]) != 1 and sample["answer"] != "Not answerable" + ] + ) + acc_neg, _ = eval_acc_and_f1( + [sample for sample in samples if sample["answer"] == "Not answerable"] + ) + + f.write( + "Single-page | Accuracy: {} | Question Number: {}\n".format( + acc_single_page, + len([sample for sample in samples if len(sample["evidence_pages"]) == 1]), + ) + ) + f.write( + "Cross-page | Accuracy: {} | Question Number: {}\n".format( + acc_multi_page, + len( + [ + sample + for sample in samples + if len(sample["evidence_pages"]) != 1 + and sample["answer"] != "Not answerable" + ] + ), + ) + ) + f.write( + "Unanswerable | Accuracy: {} | Question Number: {}\n".format( + acc_neg, len([sample for sample in samples if sample["answer"] == "Not answerable"]) + ) + ) + 
f.write("-----------------------\n") + + source_sample_dict, document_type_dict = defaultdict(list), defaultdict(list) + for sample in samples: + for answer_source in sample["evidence_sources"]: + source_sample_dict[answer_source].append(sample) + document_type_dict[sample["doc_type"]].append(sample) + for type, sub_samples in source_sample_dict.items(): + f.write( + f"Evidence Sources: {type} | Accuracy: {eval_acc_and_f1(sub_samples)[0]} | Question Number: {len(sub_samples)}\n" + ) + + f.write("-----------------------\n") + for type, sub_samples in document_type_dict.items(): + f.write( + f"Document Type: {type} | Accuracy: {eval_acc_and_f1(sub_samples)[0]} | Question Number: {len(sub_samples)}\n" + ) diff --git a/evaluation/scripts/xinyu/eval/extract_answer.py b/evaluation/scripts/xinyu/eval/extract_answer.py new file mode 100644 index 000000000..b7f7e6863 --- /dev/null +++ b/evaluation/scripts/xinyu/eval/extract_answer.py @@ -0,0 +1,33 @@ +import os + +import openai + +from dotenv import load_dotenv + + +load_dotenv() +client = openai.Client( + api_key=os.getenv("OPENAI_API_KEY", "sk-xxxxx"), + base_url=os.getenv("OPENAI_API_BASE", "https://api.openai.com/v1"), +) + + +def extract_answer(question, output, prompt, model_name="gpt-4o"): + response = client.chat.completions.create( + model=model_name, + messages=[ + { + "role": "user", + "content": prompt, + }, + {"role": "assistant", "content": f"\n\nQuestion:{question}\nAnalysis:{output}\n"}, + ], + temperature=0.0, + max_tokens=256, + top_p=1, + frequency_penalty=0, + presence_penalty=0, + ) + response = response.choices[0].message.content + + return response diff --git a/evaluation/scripts/xinyu/eval_docs.py b/evaluation/scripts/xinyu/eval_docs.py new file mode 100644 index 000000000..03a333201 --- /dev/null +++ b/evaluation/scripts/xinyu/eval_docs.py @@ -0,0 +1,228 @@ +import csv +import json +import os +import re +import traceback + +from concurrent.futures import ThreadPoolExecutor, as_completed + +from dotenv import load_dotenv + +from evaluation.scripts.mmlongbench.eval.extract_answer import extract_answer +from evaluation.scripts.xinyu.eval.eval_score_llm import eval_acc_and_f1, eval_score, show_results +from memos.configs.mem_cube import GeneralMemCubeConfig +from memos.configs.mem_os import MOSConfig +from memos.mem_cube.general import GeneralMemCube +from memos.mem_os.main import MOS + + +load_dotenv() +openapi_config = { + "model_name_or_path": "gpt-4o", + "top_k": 50, + "remove_think_prefix": True, + "api_key": os.getenv("OPENAI_API_KEY", "sk-xxxxx"), + "api_base": os.getenv("OPENAI_API_BASE", "https://api.openai.com/v1"), +} +neo4j_uri = os.getenv("NEO4J_URI", "bolt://47.117.41.207:7687") +db_name = "stx-mmlongbench-003" +doc_paths = [ + f + for f in os.listdir("evaluation/data/xinyu/documents") + if os.path.isfile(os.path.join("evaluation/data/xinyu/documents", f)) +] + +with open("evaluation/data/xinyu/all_samples_with_gt.json") as f: + samples = json.load(f) + + +def get_user_name(doc_file): + csv_path = "evaluation/data/xinyu/user_doc_map.csv" + if os.path.exists(csv_path): + with open(csv_path, newline="", encoding="utf-8") as f: + reader = csv.reader(f) + for row in reader: + uid, path = row[0], row[1] + base = os.path.basename(path) + if base == doc_file or os.path.splitext(base)[0] == os.path.splitext(doc_file)[0]: + return uid + return "" + + +def process_doc(doc_file): + user_name = get_user_name(doc_file) + print(user_name, doc_file) + config = { + "user_id": user_name, + "chat_model": { + "backend": "openai", + 
"config": openapi_config, + }, + "mem_reader": { + "backend": "simple_struct", + "config": { + "llm": {"backend": "openai", "config": openapi_config}, + "embedder": { + "backend": "universal_api", + "config": { + "provider": "openai", + "api_key": os.getenv("OPENAI_API_KEY", "sk-xxxxx"), + "model_name_or_path": "text-embedding-3-large", + "base_url": os.getenv("OPENAI_API_BASE", "https://api.openai.com/v1"), + }, + }, + "chunker": { + "backend": "sentence", + "config": { + "tokenizer_or_token_counter": "gpt2", + "chunk_size": 512, + "chunk_overlap": 128, + "min_sentences_per_chunk": 1, + }, + }, + }, + }, + "max_turns_window": 20, + "top_k": 5, + "enable_textual_memory": True, + "enable_activation_memory": False, + "enable_parametric_memory": False, + } + mos_config = MOSConfig(**config) + mos = MOS(mos_config) + + mem_cube_config = GeneralMemCubeConfig.model_validate( + { + "user_id": user_name, + "cube_id": user_name, + "text_mem": { + "backend": "tree_text", + "config": { + "extractor_llm": {"backend": "openai", "config": openapi_config}, + "dispatcher_llm": {"backend": "openai", "config": openapi_config}, + "graph_db": { + "backend": "neo4j", + "config": { + "uri": neo4j_uri, + "user": "neo4j", + "password": "iaarlichunyu", + "db_name": db_name, + "user_name": user_name, + "use_multi_db": False, + "auto_create": True, + "embedding_dimension": 3072, + }, + }, + "embedder": { + "backend": "universal_api", + "config": { + "provider": "openai", + "api_key": os.getenv("OPENAI_API_KEY", "sk-xxxxx"), + "model_name_or_path": "text-embedding-3-large", + "base_url": os.getenv("OPENAI_API_BASE", "https://api.openai.com/v1"), + }, + }, + "reorganize": False, + }, + }, + "act_mem": {}, + "para_mem": {}, + } + ) + mem_cube = GeneralMemCube(mem_cube_config) + + temp_dir = os.path.join("tmp", doc_file) + + if (not os.path.exists(temp_dir)) or (not os.listdir(temp_dir)): + mem_cube.dump(temp_dir) + + mos.register_mem_cube(temp_dir, mem_cube_id=user_name) + + with open("evaluation/scripts/mmlongbench/eval/prompt_for_answer_extraction.md") as f: + prompt = f.read() + + samples_res = [] + doc_samples = [s for s in samples if s.get("doc_id") == doc_file] + + if len(doc_samples) == 0: + return [] + + sample = doc_samples[0] + question_list = sample["question"] + answer_list = sample["answer"] + + for idx, question in enumerate(question_list): + gt = answer_list.get(str(idx)) + + try_cnt, is_success = 0, False + while True: + try: + mos.clear_messages() + response = mos.chat(question, user_name) + is_success = True + except Exception as e: + print(f"[{doc_file}] Error:", e) + traceback.print_exc() + try_cnt += 1 + response = "Failed" + if is_success or try_cnt > 5: + break + + sample_item = dict(sample) + sample_item["question"] = question + sample_item["answer"] = gt + sample_item["response"] = response + + extracted_res = extract_answer(sample_item["question"], response, prompt) + sample_item["extracted_res"] = extracted_res + + print("--------------------------------------") + pred_ans = extracted_res.split("Answer format:")[0].split("Extracted answer:")[1].strip() + score = eval_score(question, gt, response) + + sample_item["pred"] = pred_ans + sample_item["score"] = score + samples_res.append(sample_item) + + print(f"Question: {question}") + print(f"Response: {sample_item['response']}") + print(f"Ground true: {gt}\tPred: {sample_item['pred']}\tScore: {sample_item['score']}") + + print("samples_res length: ", len(samples_res)) + return samples_res + + +if __name__ == "__main__": + results = [] + + with 
ThreadPoolExecutor(max_workers=4) as executor: + future_to_doc = {executor.submit(process_doc, doc_file): doc_file for doc_file in doc_paths} + + for future in as_completed(future_to_doc): + doc_file = future_to_doc[future] + try: + res = future.result() + results.extend(res) + + if len(res) > 0: + acc, f1 = eval_acc_and_f1(results) + print() + print(f"Avg acc: {acc}") + print(f"Avg f1: {f1}") + + with open("evaluation/data/xinyu/test_results.json", "w", encoding="utf-8") as f: + json.dump(results, f, ensure_ascii=False, indent=2) + + except Exception as e: + print(f"[{doc_file}] failed with {e}") + traceback.print_exc() + + acc, f1 = eval_acc_and_f1(results) + print("--------------------------------------") + print(f"Final avg acc: {acc}") + print(f"Final avg f1: {f1}") + + show_results( + results, + show_path=re.sub(r"\.json$", ".txt", "evaluation/data/xinyu/test_results_report.txt"), + ) diff --git a/evaluation/scripts/xinyu/import_docs.py b/evaluation/scripts/xinyu/import_docs.py new file mode 100644 index 000000000..6fe2f4e34 --- /dev/null +++ b/evaluation/scripts/xinyu/import_docs.py @@ -0,0 +1,87 @@ +import asyncio +import os +import traceback +import uuid + +from memos import log +from memos.configs.mem_reader import SimpleStructMemReaderConfig +from memos.configs.memory import TreeTextMemoryConfig +from memos.mem_reader.simple_struct import SimpleStructMemReader +from memos.memories.textual.tree import TreeTextMemory + + +logger = log.get_logger(__name__) +db_name = "stx-mmlongbench-003" +# Create a memory reader instance +reader_config = SimpleStructMemReaderConfig.from_json_file( + "examples/data/config/simple_struct_reader_config.json" +) +reader = SimpleStructMemReader(reader_config) + +tree_config = TreeTextMemoryConfig.from_json_file( + "examples/data/config/tree_config_shared_database.json" +) +tree_config.graph_db.config.db_name = db_name +# Processing Documents +existing_names = { + d for d in os.listdir("ppt_test_result") if os.path.isdir(os.path.join("ppt_test_result", d)) +} +doc_paths = [] +for f in os.listdir("evaluation/data/xinyu/documents"): + fp = os.path.join("evaluation/data/xinyu/documents", f) + if os.path.isfile(fp): + name = os.path.splitext(f)[0] + if name in existing_names: + continue + doc_paths.append(fp) +print(f"existing_names length: {len(existing_names)}") +print(f"doc_paths length: {len(doc_paths)}") + + +async def process_doc(doc_path): + print(f"🔄 Processing document: {doc_path}") + doc_file = doc_path.split("/")[-1].rsplit(".", 1)[0] + + # Generate random user id: 'user_' + random short hex + user_id = "user_" + uuid.uuid4().hex[:8] + # Persist mapping between user_id and doc_path + with open("evaluation/data/xinyu/user_doc_map.csv", "a", encoding="utf-8") as f: + f.write(f"{user_id},{doc_path}\n") + + tree_config.graph_db.config.user_name = user_id + temp_dir = "tmp/" + doc_file + my_tree_textual_memory = TreeTextMemory(tree_config) + doc_memory = await reader.get_memory( + [doc_path], "doc", info={"user_id": user_id, "session_id": "session_" + str(uuid.uuid4())} + ) + + count = 0 + for m_list in doc_memory: + count += len(m_list) + my_tree_textual_memory.add(m_list) + print("total memories: ", count) + + my_tree_textual_memory.dump(temp_dir) + return doc_path + + +async def main(): + batch_size = 2 + for i in range(0, len(doc_paths), batch_size): + batch = doc_paths[i : i + batch_size] + print(f"🚀 Starting batch {i // batch_size + 1} with {len(batch)} docs") + + tasks = [process_doc(p) for p in batch] + results = await asyncio.gather(*tasks, 
return_exceptions=True) + + for p, result in zip(batch, results, strict=False): + if isinstance(result, Exception): + print(f"❌ Error processing {p}: {result}") + tb_text = "".join(traceback.TracebackException.from_exception(result).format()) + print(tb_text) + else: + print(f"✅ Finished {result}") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/core_memories/tree_textual_memory.py b/examples/core_memories/tree_textual_memory.py index d2e197e5b..f27c69b6f 100644 --- a/examples/core_memories/tree_textual_memory.py +++ b/examples/core_memories/tree_textual_memory.py @@ -1,5 +1,3 @@ -import time - from memos import log from memos.configs.mem_reader import SimpleStructMemReaderConfig from memos.configs.memory import TreeTextMemoryConfig @@ -8,8 +6,6 @@ logger = log.get_logger(__name__) - - tree_config = TreeTextMemoryConfig.from_json_file( "examples/data/config/tree_config_shared_database.json" ) @@ -22,232 +18,16 @@ ) reader = SimpleStructMemReader(reader_config) -scene_data = [ - [ - {"role": "user", "chat_time": "3 May 2025", "content": "I’m feeling a bit down today."}, - { - "role": "assistant", - "chat_time": "3 May 2025", - "content": "I’m sorry to hear that. Do you want to talk about what’s been going on?", - }, - { - "role": "user", - "chat_time": "3 May 2025", - "content": "It’s just been a tough couple of days, you know? Everything feels a bit overwhelming, and I just can’t seem to shake it off.", - }, - { - "role": "assistant", - "chat_time": "3 May 2025", - "content": "It sounds like you're going through a lot right now. Sometimes it helps to talk things out. Is there something specific that's been weighing on you, or is it more of a general feeling?", - }, - { - "role": "user", - "chat_time": "3 May 2025", - "content": "It’s a mix, I guess. Work’s been really stressful, and on top of that, I’ve been feeling kinda disconnected from the people around me.", - }, - { - "role": "assistant", - "chat_time": "3 May 2025", - "content": "That can be really draining, especially when you’re feeling isolated on top of the stress. Do you think there’s something from your past that’s contributing to how you’re feeling now? Sometimes our emotions are tied to older experiences.", - }, - { - "role": "user", - "chat_time": "3 May 2025", - "content": "Yeah, now that you mention it… I’ve been thinking a lot about my childhood lately. I didn’t have the easiest upbringing. My parents were always busy, and I often felt like I had to figure things out on my own. I remember one winter, I got really sick with a high fever. My parents were both working, and I didn’t want to bother them. I just lay in bed, crying by myself, not knowing what to do.", - }, - { - "role": "assistant", - "chat_time": "3 May 2025", - "content": "That must have been very lonely. Growing up without that support can really make you feel like you have to handle everything on your own. How do you think that experience has affected you now?", - }, - { - "role": "user", - "chat_time": "3 May 2025", - "content": "Yeah, it was really lonely. I remember spending a lot of time doing things on my own, like drawing, because it was my only way to escape. I used to draw animals, and I created my own little world on paper. It was peaceful in a way, just me and my imagination, with no outside distractions.", - }, - { - "role": "assistant", - "chat_time": "3 May 2025", - "content": "Drawing sounds like a beautiful way to express yourself and find some peace in that solitude. 
It must have been a very quiet and private world for you. Do you think that sense of loneliness from back then has shaped how you connect with others today?", - }, - { - "role": "user", - "chat_time": "3 May 2025", - "content": "I think it has. I’ve always found it hard to let others in. I feel like I have to handle everything on my own. If I rely on someone else, I feel vulnerable, and I start to feel that loneliness again.", - }, - { - "role": "assistant", - "chat_time": "3 May 2025", - "content": "It sounds like this was a protective mechanism that helped you cope with your feelings of loneliness back then. But as you’ve grown, it seems like this pattern might not be serving you as well anymore. Have you thought about trying to break this cycle, or finding new ways to connect with others?", - }, - { - "role": "user", - "chat_time": "3 May 2025", - "content": "I’ve thought about it, though it’s hard. But I also know I can’t keep going like this forever. Maybe it’s time to seek some help.", - }, - { - "role": "assistant", - "chat_time": "3 May 2025", - "content": "It’s a big step, but a positive one. Exploring your emotions and past experiences could really help you understand how they’ve shaped your present. Therapy could be a great place to start, as it would help you connect with both yourself and others. Would you be open to trying it?", - }, - { - "role": "user", - "chat_time": "3 May 2025", - "content": "I’m open to trying. Thank you for encouraging me. It feels like it might be a good starting point for me.", - }, - { - "role": "assistant", - "chat_time": "3 May 2025", - "content": "I’m really glad to hear that. Taking that first step is always difficult, but you’ve made an important decision. If you need help finding the right therapist, or just want to talk, I’m here for you.", - }, - ], - [ - { - "role": "user", - "chat_time": "3 May 2025", - "content": "This week I’ve been feeling a bit off, especially when playing tennis. My body just doesn’t feel right.", - }, - { - "role": "assistant", - "chat_time": "3 May 2025", - "content": "It sounds like you've been having some physical discomfort lately. Where specifically are you feeling unwell? Is there something particular about tennis that's been difficult?", - }, - { - "role": "user", - "chat_time": "3 May 2025", - "content": "I just feel a bit fatigued, and every time I play tennis, my movements feel off. I used to play pretty well, but lately, I’ve been feeling weak and my movements aren’t as coordinated as they used to be.", - }, - { - "role": "assistant", - "chat_time": "3 May 2025", - "content": "Physical discomfort can definitely affect performance, especially in a sport like tennis that requires quick reactions and coordination. Have you noticed anything specific that might be causing these changes? Could it be overtraining, or is there another physical issue making you feel off?", - }, - { - "role": "user", - "chat_time": "3 May 2025", - "content": "I think it might be due to stress and lack of sleep recently. I just feel drained all the time. Plus, I’ve been frustrated with my tennis performance, which makes me feel even worse, like I’m stuck in a vicious cycle.", - }, - { - "role": "assistant", - "chat_time": "3 May 2025", - "content": "Stress and poor sleep quality can definitely affect your physical performance, especially in sports. When you’re also feeling down about your performance, it’s easy to fall into a negative cycle. 
Have you tried anything to relieve some of the stress, or to improve your sleep?", - }, - { - "role": "user", - "chat_time": "3 May 2025", - "content": "I’ve tried some relaxation techniques, like deep breathing and simple meditation, but it doesn’t seem to help much. I still can’t focus during tennis, and I feel like my mind is distracted by other things.", - }, - { - "role": "assistant", - "chat_time": "3 May 2025", - "content": "Relaxation techniques like deep breathing and meditation can help reduce stress, but sometimes we need a more structured approach to improve both our physical and mental state. For example, you might want to focus on specific aspects of your tennis training, or adjust your rest and recovery time. Have you thought about setting smaller goals for yourself? Like focusing on a specific tennis move each day, rather than expecting perfection right away? That might help you gradually regain confidence.", - }, - { - "role": "user", - "chat_time": "3 May 2025", - "content": "That’s a good idea. I think I’ve been expecting too much from myself, which just makes me more anxious. Maybe I can start with small practice sessions and gradually find my rhythm again.", - }, - { - "role": "assistant", - "chat_time": "3 May 2025", - "content": "Exactly! Setting small and clear goals will help reduce pressure and give you a sense of accomplishment. And don’t forget that rest is just as important—taking time to recover will help you feel better physically and mentally. Both physical and mental recovery are key to improving your performance and health.", - }, - { - "role": "user", - "chat_time": "3 May 2025", - "content": "Thanks, I feel like I need to rethink a few things. Maybe I really need to be kinder to myself and take things step by step.", - }, - { - "role": "assistant", - "chat_time": "3 May 2025", - "content": "You’re on the right track. Rethinking things and giving yourself space and patience is so important. You’ve already taken the first step, now just keep moving forward, one step at a time. 
If you need anything, I’m always here to help.", - }, - ], -] - -# Acquiring memories -memory = reader.get_memory(scene_data, type="chat", info={"user_id": "1234", "session_id": "2222"}) - -for m_list in memory: - added_ids = my_tree_textual_memory.add(m_list) - for i, id in enumerate(added_ids): - print(f"{i}'th added result is:" + my_tree_textual_memory.get(id).memory) - my_tree_textual_memory.memory_manager.wait_reorganizer() - -time.sleep(60) - -init_time = time.time() -results = my_tree_textual_memory.search( - "Talk about the user's childhood story?", - top_k=10, - info={ - "query": "Talk about the user's childhood story?", - "user_id": "111", - "session_id": "2234", - "chat_history": [{"role": "user", "content": "xxxxx"}], - }, -) -for i, r in enumerate(results): - r = r.to_dict() - print(f"{i}'th similar result is: " + str(r["memory"])) -print(f"Successfully search {len(results)} memories in {round(time.time() - init_time)}s") - -# try this when use 'fine' mode (Note that you should pass the internet Config, refer to examples/core_memories/textual_internet_memoy.py) -init_time = time.time() -results_fine_search = my_tree_textual_memory.search( - "Recent news in the first city you've mentioned.", - top_k=10, - mode="fine", - info={ - "query": "Recent news in NewYork", - "user_id": "111", - "session_id": "2234", - "chat_history": [ - {"role": "user", "content": "I want to know three beautiful cities"}, - {"role": "assistant", "content": "New York, London, and Shanghai"}, - ], - }, -) - -for i, r in enumerate(results_fine_search): - r = r.to_dict() - print(f"{i}'th similar result is: " + str(r["memory"])) -print( - f"Successfully search {len(results_fine_search)} memories in {round(time.time() - init_time)}s" -) - -# find related nodes -related_nodes = my_tree_textual_memory.get_relevant_subgraph("Painting") - -# get current memory_size -print(f"Current Memory Size is {my_tree_textual_memory.get_current_memory_size()}") - -logger.info("Start doc search example...") # Processing Documents doc_paths = [ - "./text1.txt", - "./text2.txt", + "evaluation/data/mmlongbench/documents/0b85477387a9d0cc33fca0f4becaa0e5.pdf", + "evaluation/data/mmlongbench/documents/0e94b4197b10096b1f4c699701570fbf.pdf", ] # Acquiring memories from documents doc_memory = reader.get_memory(doc_paths, "doc", info={"user_id": "1111", "session_id": "2222"}) for m_list in doc_memory: added_ids = my_tree_textual_memory.add(m_list) - my_tree_textual_memory.memory_manager.wait_reorganizer() - -results = my_tree_textual_memory.search( - "Tell me about what memos consist of?", - top_k=30, - info={"query": "Tell me about what memos consist of?", "user_id": "111", "session": "2234"}, -) -for i, r in enumerate(results): - r = r.to_dict() - print(f"{i}'th similar result is: " + str(r["memory"])) -print(f"Successfully search {len(results)} memories") - -# close the synchronous thread in memory manager -my_tree_textual_memory.memory_manager.close() # my_tree_textual_memory.dump my_tree_textual_memory.dump("tmp/my_tree_textual_memory") -my_tree_textual_memory.drop() diff --git a/examples/data/config/simple_struct_reader_config.json b/examples/data/config/simple_struct_reader_config.json index 8a27de457..b0cc139c2 100644 --- a/examples/data/config/simple_struct_reader_config.json +++ b/examples/data/config/simple_struct_reader_config.json @@ -1,17 +1,24 @@ { "llm": { - "backend": "ollama", + "backend": "openai", "config": { - "model_name_or_path": "qwen3:0.6b", - "temperature": 0.0, + "model_name_or_path": "gpt-4o-mini", + 
"temperature": 0.8, + "max_tokens": 1024, + "top_p": 0.9, + "top_k": 50, "remove_think_prefix": true, - "max_tokens": 8192 + "api_key": "sk-5hTdYsmX7ErNsLQKJ8jUTahHzyDjQ3H5mpgAeTfUfeRyIUYV", + "api_base": "http://123.129.219.111:3000/v1" } }, "embedder": { - "backend": "ollama", + "backend": "universal_api", "config": { - "model_name_or_path": "nomic-embed-text:latest" + "provider": "openai", + "api_key": "sk-5hTdYsmX7ErNsLQKJ8jUTahHzyDjQ3H5mpgAeTfUfeRyIUYV", + "model_name_or_path": "text-embedding-3-large", + "base_url": "http://123.129.219.111:3000/v1" } }, "chunker": { diff --git a/examples/data/config/tree_config.json b/examples/data/config/tree_config.json index bcb07b3a4..2fbcce164 100644 --- a/examples/data/config/tree_config.json +++ b/examples/data/config/tree_config.json @@ -18,9 +18,12 @@ } }, "embedder": { - "backend": "ollama", + "backend": "universal_api", "config": { - "model_name_or_path": "nomic-embed-text:latest" + "provider": "openai", + "api_key": "sk-xxxxxxxxx", + "model_name_or_path": "text-embedding-3-large", + "base_url": "http://xxx.xxx.xxx.xxx:xxxx/v1" } }, "graph_db": { @@ -31,8 +34,8 @@ "password": "12345678", "db_name": "user08alice", "auto_create": true, - "embedding_dimension": 768 + "embedding_dimension": 3072 } }, - "reorganize": false + "reorganize": true } diff --git a/examples/data/config/tree_config_shared_database.json b/examples/data/config/tree_config_shared_database.json index 914f3c725..cb4dffb28 100644 --- a/examples/data/config/tree_config_shared_database.json +++ b/examples/data/config/tree_config_shared_database.json @@ -1,40 +1,48 @@ { "extractor_llm": { - "backend": "ollama", + "backend": "openai", "config": { - "model_name_or_path": "qwen3:0.6b", - "temperature": 0.0, + "model_name_or_path": "gpt-4o-mini", + "temperature": 0.8, + "max_tokens": 1024, + "top_p": 0.9, + "top_k": 50, "remove_think_prefix": true, - "max_tokens": 8192 + "api_key": "sk-5hTdYsmX7ErNsLQKJ8jUTahHzyDjQ3H5mpgAeTfUfeRyIUYV", + "api_base": "http://123.129.219.111:3000/v1" } }, "dispatcher_llm": { - "backend": "ollama", + "backend": "openai", "config": { - "model_name_or_path": "qwen3:0.6b", - "temperature": 0.0, + "model_name_or_path": "gpt-4o-mini", + "temperature": 0.8, + "max_tokens": 1024, + "top_p": 0.9, + "top_k": 50, "remove_think_prefix": true, - "max_tokens": 8192 + "api_key": "sk-5hTdYsmX7ErNsLQKJ8jUTahHzyDjQ3H5mpgAeTfUfeRyIUYV", + "api_base": "http://123.129.219.111:3000/v1" } }, "embedder": { - "backend": "ollama", + "backend": "universal_api", "config": { - "model_name_or_path": "nomic-embed-text:latest" + "provider": "openai", + "api_key": "sk-5hTdYsmX7ErNsLQKJ8jUTahHzyDjQ3H5mpgAeTfUfeRyIUYV", + "model_name_or_path": "text-embedding-3-large", + "base_url": "http://123.129.219.111:3000/v1" } }, "graph_db": { "backend": "neo4j", "config": { - "uri": "bolt://localhost:7687", + "uri": "bolt://47.117.45.189:7687", "user": "neo4j", - "password": "12345678", - "db_name": "shared-tree-textual-memory", - "user_name": "alice", - "auto_create": true, - "use_multi_db": false, - "embedding_dimension": 768 + "password": "iaarlichunyu", + "db_name": "", + "auto_create": true } }, - "reorganize": true + "reorganize": false } diff --git a/examples/mem_os/simple_openapi_memos.py b/examples/mem_os/simple_openapi_memos.py index a9faa3dc8..cf26c17f9 100644 --- a/examples/mem_os/simple_openapi_memos.py +++ b/examples/mem_os/simple_openapi_memos.py @@ -14,6 +14,7 @@ load_dotenv() +db_name = "stx-pdf-doc-test" # 1. 
Create MOS Config and set openai config print(f"🚀 [{datetime.now().strftime('%H:%M:%S')}] Starting to create MOS configuration...") start_time = time.time() @@ -50,9 +51,12 @@ "config": openapi_config, }, "embedder": { - "backend": "ollama", + "backend": "universal_api", "config": { - "model_name_or_path": "nomic-embed-text:latest", + "provider": "openai", + "api_key": os.getenv("OPENAI_API_KEY", "sk-xxxxx"), + "model_name_or_path": "text-embedding-3-large", + "base_url": os.getenv("OPENAI_API_BASE", "https://api.openai.com/v1"), }, }, "chunker": { @@ -105,17 +109,21 @@ "config": { "uri": neo4j_uri, "user": "neo4j", - "password": "12345678", - "db_name": f"db{user_name.replace('-', '')}", + "password": "iaarlichunyu", + "db_name": db_name, "auto_create": True, }, }, "embedder": { - "backend": "ollama", + "backend": "universal_api", "config": { - "model_name_or_path": "nomic-embed-text:latest", + "provider": "openai", + "api_key": os.getenv("OPENAI_API_KEY", "sk-xxxxx"), + "model_name_or_path": "text-embedding-3-large", + "base_url": os.getenv("OPENAI_API_BASE", "https://api.openai.com/v1"), }, }, + "reorganize": True, }, }, "act_mem": {}, @@ -259,11 +267,12 @@ def filter_memory_data(memories_data): # 7. Add document print(f"🚀 [{datetime.now().strftime('%H:%M:%S')}] Starting to add document...") start_time = time.time() -## 7.1 add pdf for ./tmp/data if use doc mem mos.add(doc_path="./tmp/data/") -start_time = time.time() +# 7.1 add pdf for ./tmp/data if use doc mem mos.add(doc_path="./tmp/data/") +mos.add(doc_path=["examples/data/docs/test/0b85477387a9d0cc33fca0f4becaa0e5.pdf"]) get_all_results = mos.get_all() filtered_results = filter_memory_data(get_all_results) + print(f"Get all results after add doc: {filtered_results}") print( diff --git a/src/memos/api/config.py b/src/memos/api/config.py index c62cd3b08..19222f680 100644 --- a/src/memos/api/config.py +++ b/src/memos/api/config.py @@ -507,6 +507,7 @@ def get_neo4j_community_config(user_id: str | None = None) -> dict[str, Any]: @staticmethod def get_neo4j_config(user_id: str | None = None) -> dict[str, Any]: """Get Neo4j configuration.""" + print("MOS_NEO4J_SHARED_DB: ", os.getenv("MOS_NEO4J_SHARED_DB", "false")) if os.getenv("MOS_NEO4J_SHARED_DB", "false").lower() == "true": return APIConfig.get_neo4j_shared_config(user_id) else: diff --git a/src/memos/chunkers/sentence_chunker.py b/src/memos/chunkers/sentence_chunker.py index 080962482..57f157967 100644 --- a/src/memos/chunkers/sentence_chunker.py +++ b/src/memos/chunkers/sentence_chunker.py @@ -1,3 +1,5 @@ +from langchain_text_splitters import Language, RecursiveCharacterTextSplitter + from memos.configs.chunker import SentenceChunkerConfig from memos.dependency import require_python_package from memos.log import get_logger @@ -17,24 +19,21 @@ class SentenceChunker(BaseChunker): install_link="https://docs.chonkie.ai/python-sdk/getting-started/installation", ) def __init__(self, config: SentenceChunkerConfig): - from chonkie import SentenceChunker as ChonkieSentenceChunker - self.config = config - self.chunker = ChonkieSentenceChunker( - tokenizer_or_token_counter=config.tokenizer_or_token_counter, + self.chunker = RecursiveCharacterTextSplitter.from_language( + language=Language.PYTHON, chunk_size=config.chunk_size, chunk_overlap=config.chunk_overlap, - min_sentences_per_chunk=config.min_sentences_per_chunk, ) logger.info(f"Initialized SentenceChunker with config: {config}") def chunk(self, text: str) -> list[str] | list[Chunk]: """Chunk the given text into smaller chunks based on 
sentences.""" - chonkie_chunks = self.chunker.chunk(text) + chonkie_chunks = self.chunker.split_text(text) chunks = [] for c in chonkie_chunks: - chunk = Chunk(text=c.text, token_count=c.token_count, sentences=c.sentences) + chunk = Chunk(text=c, token_count=-1, sentences=[]) chunks.append(chunk) logger.debug(f"Generated {len(chunks)} chunks from input text") diff --git a/src/memos/configs/embedder.py b/src/memos/configs/embedder.py index d88b6005e..7578e3469 100644 --- a/src/memos/configs/embedder.py +++ b/src/memos/configs/embedder.py @@ -49,10 +49,11 @@ class UniversalAPIEmbedderConfig(BaseEmbedderConfig): OpenAI, etc. """ - provider: str = Field(..., description="Provider name, e.g., 'openai'") - api_key: str = Field(..., description="API key for the embedding provider") + provider: str = Field(default="openai", description="Provider name, e.g., 'openai'") + api_key: str = Field(default="EMPTY", description="API key for the embedding provider") base_url: str | None = Field( - default=None, description="Optional base URL for custom or proxied endpoint" + default="http://111.31.225.48:8078/v1", + description="Optional base URL for custom or proxied endpoint", ) diff --git a/src/memos/configs/graph_db.py b/src/memos/configs/graph_db.py index ce180606b..99f5b34cc 100644 --- a/src/memos/configs/graph_db.py +++ b/src/memos/configs/graph_db.py @@ -68,7 +68,7 @@ class Neo4jGraphDBConfig(BaseGraphDBConfig): ), ) - embedding_dimension: int = Field(default=768, description="Dimension of vector embedding") + embedding_dimension: int = Field(default=3072, description="Dimension of vector embedding") @model_validator(mode="after") def validate_config(self): diff --git a/src/memos/configs/mem_os.py b/src/memos/configs/mem_os.py index 549e55792..756066a09 100644 --- a/src/memos/configs/mem_os.py +++ b/src/memos/configs/mem_os.py @@ -43,7 +43,7 @@ class MOSConfig(BaseConfig): description="Maximum number of turns to keep in the conversation history", ) top_k: int = Field( - default=5, + default=50, description="Maximum number of memories to retrieve for each query", ) enable_textual_memory: bool = Field( diff --git a/src/memos/mem_os/core.py b/src/memos/mem_os/core.py index f11b3a44c..2c5132f1e 100644 --- a/src/memos/mem_os/core.py +++ b/src/memos/mem_os/core.py @@ -1,8 +1,12 @@ +import base64 import json +import mimetypes import os +import re import time from datetime import datetime +from functools import lru_cache from pathlib import Path from threading import Lock from typing import Any, Literal @@ -34,6 +38,169 @@ logger = get_logger(__name__) +# ----------------------- +# Helper functions for image handling in chat +# ----------------------- + + +def _encode_image_to_data_url(image_path: str) -> str | None: + """Encode local image file to base64 data URL for OpenAI-compatible image messages. + + Returns a data URL like: data:image/jpeg;base64,<...> + """ + try: + mime, _ = mimetypes.guess_type(image_path) + if not mime: + # default to jpeg + mime = "image/jpeg" + with open(image_path, "rb") as f: + b64 = base64.b64encode(f.read()).decode("ascii") + return f"data:{mime};base64,{b64}" + except Exception as e: + logger.warning(f"Failed to encode image '{image_path}' to data URL: {e}") + return None + + +@lru_cache(maxsize=1) +def _build_images_index() -> dict[str, str]: + """Scan `./ppt_test_result` recursively and index images by filename. + + New structure example: + ./ppt_test_result//extracted/file_*//auto/images/*.{png,jpg,jpeg,webp,gif} + + Also compatible with previous layouts. 
Returns mapping: + basename (e.g. img_123.jpg) -> absolute path + """ + base_dir = Path("./ppt_test_result") + index: dict[str, str] = {} + if not base_dir.exists(): + return index + + # Recursively find any `auto/images` directories under ppt_test_result + for images_dir in base_dir.rglob("auto/images"): + if images_dir.is_dir(): + for img_file in images_dir.iterdir(): + if img_file.is_file(): + index[img_file.name] = str(img_file.resolve()) + logger.info(f"Image index built with {len(index)} entries") + return index + + +def get_images(sources: list[str]) -> list[str]: + """Extract image absolute paths from metadata sources. + + Supports patterns like: ![](images/.jpg) or any 'images/...jpg' substring. + Falls back to scanning the ppt_test_result index to resolve basenames. + """ + if not sources: + return [] + + # Ensure index exists + index = _build_images_index() + found: list[str] = [] + + md_img_pattern = re.compile(r"!\[[^\]]*\]\(([^\)]+)\)") + images_substr_pattern = re.compile(r"images/[^\s)]+\.(?:png|jpg|jpeg|webp)", re.IGNORECASE) + + for src in sources: + if not src: + continue + # 1) markdown image syntax + for m in md_img_pattern.findall(src): + candidate = m.strip() + # if it's a relative like 'images/xxx.jpg', resolve via index + basename = os.path.basename(candidate) + if basename in index: + found.append(index[basename]) + else: + # try direct path (absolute or relative) + p = Path(candidate) + if not p.is_absolute(): + p = Path.cwd() / p + if p.exists(): + found.append(str(p.resolve())) + + # 2) any 'images/xxx.jpg' substring + for m in images_substr_pattern.findall(src): + candidate = m.strip() + basename = os.path.basename(candidate) + if basename in index: + found.append(index[basename]) + else: + p = Path(candidate) + if not p.is_absolute(): + p = Path.cwd() / p + if p.exists(): + found.append(str(p.resolve())) + + # Deduplicate preserving order + dedup: list[str] = [] + seen = set() + for path in found: + if path not in seen: + dedup.append(path) + seen.add(path) + return dedup + + +def add_images_context( + current_messages: list[dict[str, Any]], images: list[str] +) -> list[dict[str, Any]]: + """Append images in OpenAI-compatible multi-part format and ensure message structure. + + - Deduplicates image paths. + - Ensures a system message exists with a concise CN vision instruction. + - Ensures the last user message has multi-part content: [text, image_url...]. + - Uses base64 data URLs. Limits to 6 images. + - In-place modification of `current_messages`. 
+ """ + if not images: + return current_messages + + # Deduplicate images while preserving order + unique_images: list[str] = [] + seen_paths: set[str] = set() + for p in images: + if p not in seen_paths: + unique_images.append(p) + seen_paths.add(p) + + # Locate or create the last user message + user_idx = None + for i in range(len(current_messages) - 1, -1, -1): + if current_messages[i].get("role") == "user": + user_idx = i + break + + user_msg = current_messages[user_idx] + orig_content = user_msg.get("content", "") + + # Normalize user content to multi-part format using original query as text (no fallback) + content_parts: list[dict[str, Any]] + if isinstance(orig_content, str): + content_parts = [{"type": "text", "text": orig_content}] + elif isinstance(orig_content, list): + content_parts = orig_content + else: + content_parts = [{"type": "text", "text": str(orig_content)}] + + # 5) Append up to 3 images as data URLs + limit = 6 + count = 0 + + for img_path in unique_images: + if count >= limit: + break + data_url = _encode_image_to_data_url(img_path) + if data_url: + content_parts.append({"type": "image_url", "image_url": {"url": data_url}}) + count += 1 + + user_msg["content"] = content_parts + current_messages[user_idx] = user_msg + logger.info(f"Attached {count} images to user message (deduplicated from {len(images)})") + return current_messages + class MOSCore: """ @@ -262,6 +429,7 @@ def chat(self, query: str, user_id: str | None = None, base_prompt: str | None = Returns: str: The response from the MOS. """ + global images_all target_user_id = user_id if user_id is not None else self.user_id accessible_cubes = self.user_manager.get_user_cubes(target_user_id) user_cube_ids = [cube.cube_id for cube in accessible_cubes] @@ -272,12 +440,12 @@ def chat(self, query: str, user_id: str | None = None, base_prompt: str | None = if self.config.enable_textual_memory and self.mem_cubes: memories_all = [] + images_all = [] for mem_cube_id, mem_cube in self.mem_cubes.items(): if mem_cube_id not in user_cube_ids: continue if not mem_cube.text_mem: continue - # submit message to scheduler if self.enable_mem_scheduler and self.mem_scheduler is not None: message_item = ScheduleMessageItem( @@ -291,7 +459,7 @@ def chat(self, query: str, user_id: str | None = None, base_prompt: str | None = memories = mem_cube.text_mem.search( query, - top_k=self.config.top_k, + top_k=100, info={ "user_id": target_user_id, "session_id": self.session_id, @@ -299,15 +467,22 @@ def chat(self, query: str, user_id: str | None = None, base_prompt: str | None = }, ) memories_all.extend(memories) + for memory in memories: + images_list = get_images(memory.metadata.sources) + if len(images_list): + images_all.extend(images_list) + logger.info(f"🧠 [Memory] Searched memories:\n{self._str_memories(memories_all)}\n") system_prompt = self._build_system_prompt(memories_all, base_prompt=base_prompt) else: system_prompt = self._build_system_prompt(base_prompt=base_prompt) + current_messages = [ {"role": "system", "content": system_prompt}, *chat_history.chat_history, {"role": "user", "content": query}, ] + current_messages = add_images_context(current_messages, images_all) past_key_values = None if self.config.enable_activation_memory: @@ -338,6 +513,7 @@ def chat(self, query: str, user_id: str | None = None, base_prompt: str | None = # submit message to scheduler for accessible_mem_cube in accessible_cubes: mem_cube_id = accessible_mem_cube.cube_id + mem_cube_id = mem_cube_id.replace("cube_", "") mem_cube = 
self.mem_cubes[mem_cube_id]
             if self.enable_mem_scheduler and self.mem_scheduler is not None:
                 message_item = ScheduleMessageItem(
@@ -683,7 +859,7 @@ def add(
         self,
         messages: MessageList | None = None,
         memory_content: str | None = None,
-        doc_path: str | None = None,
+        doc_path: list[str] | None = None,
         mem_cube_id: str | None = None,
         user_id: str | None = None,
         session_id: str | None = None,
@@ -857,7 +1033,7 @@ def process_preference_memory():
                     info={"user_id": target_user_id, "session_id": target_session_id},
                     mode="fast" if sync_mode == "async" else "fine",
                 )
-
+                print("memories length:", len(memories))
                 mem_ids = []
                 for mem in memories:
                     mem_id_list: list[str] = self.mem_cubes[mem_cube_id].text_mem.add(mem)
@@ -897,7 +1073,7 @@ def process_preference_memory():
                 and self.config.enable_textual_memory
                 and self.mem_cubes[mem_cube_id].text_mem
             ):
-                documents = self._get_all_documents(doc_path)
+                documents = doc_path
                 doc_memories = self.mem_reader.get_memory(
                     documents,
                     type="doc",
diff --git a/src/memos/parsers/markitdown.py b/src/memos/parsers/markitdown.py
index 02e75355f..822020bd4 100644
--- a/src/memos/parsers/markitdown.py
+++ b/src/memos/parsers/markitdown.py
@@ -1,11 +1,191 @@
+import asyncio
+import shlex
+import time
+import traceback
+
+from pathlib import Path
+
+import httpx
+
 from memos.configs.parser import MarkItDownParserConfig
-from memos.dependency import require_python_package
 from memos.log import get_logger
 from memos.parsers.base import BaseParser
 
 logger = get_logger(__name__)
 
+BASE_URL = "http://106.75.235.231:8001"
+LOCAL_TOKEN = "local_only_a8f3d2c1b5e7f9a6c4d8e2b1f7c3a9e5d2f8b4c6a1e9d7f3c5b8a2e4f6d9c1a3"
+HEADERS = {"Authorization": f"Bearer {LOCAL_TOKEN}"}
+
+
+async def check_service_health(client: httpx.AsyncClient):
+    try:
+        response = await client.get(f"{BASE_URL}/health", timeout=5)
+        response.raise_for_status()
+        print("✅ API service is running normally")
+        return True
+    except (httpx.ConnectError, httpx.HTTPStatusError) as e:
+        print(f"❌ API service error: {e}")
+        return False
+
+
+async def upload_ppt_file(client: httpx.AsyncClient, file_path: Path) -> str:
+    print(f"\n📤 Uploading PPT file: {file_path}")
+    print(f"   File size: {file_path.stat().st_size / 1024 / 1024:.1f} MB")
+
+    external_file_id = f"ppt-test-{int(time.time())}"
+
+    print("🔄 Using the direct file upload endpoint...")
+
+    with open(file_path, "rb") as f:
+        files = {"file": (file_path.name, f, "application/vnd.ms-powerpoint")}
+        data = {
+            "fileId": external_file_id,
+            "force_ocr": "false",
+            "ocr_all_images": "true",
+        }
+
+        print("📡 Sending upload request to /api/file/upload...")
+        response = await client.post(
+            f"{BASE_URL}/api/file/upload", files=files, data=data, headers=HEADERS
+        )
+
+    response.raise_for_status()
+
+    result = response.json()
+    print(f"📄 Upload response: {result}")
+
+    if result["code"] != 200:
+        raise Exception(f"API returned an error: {result.get('message', 'unknown error')}")
+
+    generated_ids = result["data"]["generated_ids"]
+    if len(generated_ids) == 0:
+        raise Exception(f"Server did not generate a file ID, response: {result}")
+
+    file_id = generated_ids[0]
+    print(f"🆔 Obtained internal file_id: {file_id}")
+    return file_id
+
+
+async def poll_status(client: httpx.AsyncClient, file_id: str):
+    print(f"\n⏳ Monitoring processing status: {file_id}")
+    print("   Note: PPT files take longer to process (LibreOffice conversion + AI processing)")
+
+    max_retries = 720
+    poll_interval = 5
+
+    for i in range(max_retries):
+        response = await client.get(f"{BASE_URL}/api/v6/status/{file_id}", headers=HEADERS)
+        response.raise_for_status()
+
+        status_data = response.json()
+        status = status_data.get("status")
+        print(f"   - Attempt {i + 1}/{max_retries}: status '{status}'")
+
+        if status == "completed":
+            print("✅ Processing complete!")
+            return
+        elif status == "failed":
+            error_msg = status_data.get("error_message", "unknown error")
+            raise Exception(f"Processing failed: {error_msg}")
+
+        await asyncio.sleep(poll_interval)
+
+    raise Exception("Status polling timed out")
+
+
+async def download_and_verify_ppt(client: httpx.AsyncClient, file_id: str, temp_dir: Path):
+    """Download and verify the PPT processing result."""
+    print(f"\n📥 Downloading processing result: {file_id}")
+
+    temp_dir.mkdir(exist_ok=True, parents=True)
+
+    archive_path = temp_dir / f"{file_id}.tar.zstd"
+    curl_cmd = [
+        "curl",
+        "-o",
+        str(archive_path),
+        "-H",
+        f"Authorization: Bearer {LOCAL_TOKEN}",
+        f"{BASE_URL}/api/v6/download/{file_id}",
+    ]
+
+    proc = await asyncio.create_subprocess_exec(*curl_cmd)
+    await proc.wait()
+
+    if proc.returncode != 0:
+        raise Exception(f"curl download failed, return code: {proc.returncode}")
+
+    file_size = archive_path.stat().st_size
+    print(f"📦 Archive saved: {archive_path} ({file_size} bytes)")
+
+    extract_dir = temp_dir / "extracted"
+    extract_dir.mkdir()
+
+    print("🔄 Decompressing file (macOS-compatible approach)...")
+    decompress_cmd = (
+        f"zstd -d < {shlex.quote(str(archive_path))} | tar -xf - -C {shlex.quote(str(extract_dir))}"
+    )
+
+    proc = await asyncio.create_subprocess_shell(
+        decompress_cmd, stdout=asyncio.subprocess.DEVNULL, stderr=asyncio.subprocess.PIPE
+    )
+    stdout, stderr = await proc.communicate()
+
+    if proc.returncode != 0:
+        stderr_text = stderr.decode() if stderr else "unknown error"
+        print(f"⚠️ Piped decompression failed, falling back to two-step extraction: {stderr_text}")
+
+        temp_tar = temp_dir / f"{file_id}.tar"
+
+        zstd_cmd = ["zstd", "-d", str(archive_path), "-o", str(temp_tar)]
+        proc = await asyncio.create_subprocess_exec(*zstd_cmd)
+        await proc.wait()
+
+        if proc.returncode != 0:
+            raise Exception(f"zstd decompression failed, return code: {proc.returncode}")
+
+        tar_cmd = ["tar", "-xf", str(temp_tar), "-C", str(extract_dir)]
+        proc = await asyncio.create_subprocess_exec(*tar_cmd)
+        await proc.wait()
+
+        if proc.returncode != 0:
+            raise Exception(f"tar extraction failed, return code: {proc.returncode}")
+
+        temp_tar.unlink(missing_ok=True)
+
+    print(f"📂 Files extracted to: {extract_dir}")
+
+    result_content_dir = extract_dir / file_id
+
+    md_files = list(result_content_dir.glob("**/*.md"))
+    if not md_files:
+        raise Exception("No markdown result file found")
+
+    result_file = md_files[0]
+    print(f"✔️ Found result file: {result_file}")
+
+    content = result_file.read_text(encoding="utf-8")
+    print(f"📄 Extracted content length: {len(content)} characters")
+
+    lines = content.split("\n")
+
+    text_content = ""
+    for line in lines:
+        if line.strip():
+            text_content += line
+
+    print("👍 PPT file processing verification complete!")
+    return text_content
+
+
+async def cleanup_server_file(client: httpx.AsyncClient, file_id: str):
+    print(f"\n🧹 Cleaning up server file: {file_id}")
+    response = await client.delete(f"{BASE_URL}/api/v6/delete/{file_id}", headers=HEADERS)
+    response.raise_for_status()
+    print("✅ Server cleanup complete")
+
 
 class MarkItDownParser(BaseParser):
     """MarkItDown Parser class."""
@@ -13,16 +193,51 @@ class MarkItDownParser(BaseParser):
     def __init__(self, config: MarkItDownParserConfig):
         self.config = config
 
-    @require_python_package(
-        import_name="markitdown",
-        install_command="pip install markitdown[all]",
-        install_link="https://github.com/microsoft/markitdown",
-    )
-    def parse(self, file_path: str) -> str:
-        from markitdown import MarkItDown
+    async def parse(self, file_path: str) -> str:
+        ppt_file = Path(file_path)
+
+        if not ppt_file.exists():
+            print(f"❌ PPT file does not exist: {ppt_file}")
+            return ""
+        file_stem = ppt_file.stem
+        temp_dir = Path("./ppt_test_result/" + file_stem)
+        temp_dir.mkdir(exist_ok=True)
+        file_id: str | None = None
+        text_content: str = ""
+
+        try:
+            async with httpx.AsyncClient(timeout=30.0) as client:
+                # 0. Health check
+                if not await check_service_health(client):
+                    print("Aborting: service is unavailable")
+                    return ""
+
+                file_id = await upload_ppt_file(client, ppt_file)
+
+                await poll_status(client, file_id)
+
+                text_content = await download_and_verify_ppt(client, file_id, temp_dir)
+
+        except Exception as e:
+            print(f"\n❌ Processing failed: {e}")
+            tb_text = "".join(traceback.TracebackException.from_exception(e).format())
+            print(tb_text)
+            print(f"🔍 Files kept in the working directory for inspection: ./{temp_dir.name}/")
+            return ""
+
+        try:
+            if file_id:
+                async with httpx.AsyncClient() as client:
+                    await cleanup_server_file(client, file_id)
+
+            print(f"📂 Downloaded files kept in the working directory: ./{temp_dir.name}/")
+            print(f"   - Archive: ./{temp_dir.name}/file_*.tar.zstd")
+            print(f"   - Extracted content: ./{temp_dir.name}/extracted/")
+            print("💡 You can inspect the processing results and file quality directly in the working directory")
 
-        """Parse the file at the given path and return its content as a MarkDown string."""
-        md = MarkItDown(enable_plugins=False)
-        result = md.convert(file_path)
+        except Exception as e:
+            print(f"⚠️ Server cleanup failed: {e}")
+            tb_text = "".join(traceback.TracebackException.from_exception(e).format())
+            print(tb_text)
 
-        return result.text_content
+        return text_content
diff --git a/src/memos/templates/mos_prompts.py b/src/memos/templates/mos_prompts.py
index 357a9f1bd..ba68878c7 100644
--- a/src/memos/templates/mos_prompts.py
+++ b/src/memos/templates/mos_prompts.py
@@ -30,22 +30,6 @@
 ✅ Chain of Thought (CoT) enhancement is now enabled by default
 ✅ Complex queries will be automatically decomposed and enhanced
 
-🌐 To enable Internet search capabilities:
-   1. Go to your cube's textual memory configuration
-   2. Set the backend to 'google' in the internet_retriever section
-   3. Configure the following parameters:
-      - api_key: Your Google Search API key
-      - cse_id: Your Custom Search Engine ID
-      - num_results: Number of search results (default: 5)
-
-📝 Example configuration at cube config for tree_text_memory :
-   internet_retriever:
-      backend: 'google'
-      config:
-         api_key: 'your_google_api_key_here'
-         cse_id: 'your_custom_search_engine_id'
-         num_results: 5
-details: https://github.com/memos-ai/memos/blob/main/examples/core_memories/tree_textual_w_internet_memoy.py
 ============================================================
 """
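
A quick way to sanity-check the SentenceChunker change in this patch is to drive it directly. The sketch below is illustrative only: it assumes SentenceChunkerConfig accepts the same field names used in the JSON configs above, that the chunker's optional dependencies are installed, and it simply shows that chunks now come back from LangChain's RecursiveCharacterTextSplitter with token_count fixed at -1 and no sentence metadata.

# Illustrative check of the patched SentenceChunker (config fields are assumptions).
from memos.chunkers.sentence_chunker import SentenceChunker
from memos.configs.chunker import SentenceChunkerConfig

config = SentenceChunkerConfig(
    tokenizer_or_token_counter="gpt2",
    chunk_size=512,
    chunk_overlap=128,
    min_sentences_per_chunk=1,
)
chunker = SentenceChunker(config)

chunks = chunker.chunk("MemOS keeps long-term memories in a tree. " * 200)
for chunk in chunks[:3]:
    # After this patch the splitter returns raw strings, so token_count is
    # always -1 and sentences is always an empty list.
    print(len(chunk.text), chunk.token_count, chunk.sentences)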
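
The image-context helpers added to memos.mem_os.core are easiest to follow from a small driver. This is a hypothetical usage sketch, not part of the patch: the sources string and the image filename are made up, and image parts are only attached when get_images can resolve a referenced file under ./ppt_test_result or on disk.

# Hypothetical driver for the image-context helpers introduced in this patch.
from memos.mem_os.core import add_images_context, get_images

# `sources` mimics memory metadata that embeds a markdown image reference.
sources = ["Quarterly revenue chart: ![](images/img_001.jpg)"]
image_paths = get_images(sources)  # resolved via the ./ppt_test_result index

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "What does the revenue chart show?"},
]
# When image_paths is non-empty, the last user message is rewritten into
# OpenAI-style multi-part content: one text part plus up to six base64
# image_url parts; with no resolvable images the messages are returned as-is.
messages = add_images_context(messages, image_paths)
print(messages[-1]["content"])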
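
Because MarkItDownParser.parse becomes a coroutine that talks to the remote conversion service, call sites now have to await it. The snippet below is a sketch under that assumption; the .pptx path is illustrative, and it further assumes MarkItDownParserConfig can be constructed with its defaults.

# Sketch of driving the now-asynchronous MarkItDownParser.parse.
import asyncio

from memos.configs.parser import MarkItDownParserConfig
from memos.parsers.markitdown import MarkItDownParser


async def main() -> None:
    parser = MarkItDownParser(MarkItDownParserConfig())
    # parse() uploads the file, polls the remote service for completion, then
    # downloads the result archive and flattens the markdown into one string.
    text = await parser.parse("evaluation/data/xinyu/documents/example.pptx")
    print(f"Extracted {len(text)} characters")


if __name__ == "__main__":
    asyncio.run(main())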