diff --git a/backend/prompts/async_knowledge_summary.yaml b/backend/prompts/async_knowledge_summary.yaml new file mode 100644 index 00000000..c94293d3 --- /dev/null +++ b/backend/prompts/async_knowledge_summary.yaml @@ -0,0 +1,115 @@ +# Async Knowledge Summary Prompt Templates (Chinese) + +# Summary Generation Prompt +SUMMARY_GENERATION_PROMPT: |- + ### 你是【知识总结专家】,负责生成简洁准确的知识总结。 + + 请为以下内容生成简洁的知识总结(不超过{{ max_length }}个中文字符): + + 内容: + {{ text }} + + ### 要求: + 1. 提取核心观点和关键信息 + 2. 使用简洁清晰的语言 + 3. 保持客观准确 + 4. 突出重点内容 + 5. 不要使用markdown格式符号(如#、*、-等) + 6. 直接输出总结内容,无需额外说明 + + 知识总结: + +# Keyword Extraction Prompt +KEYWORD_EXTRACTION_PROMPT: |- + ### 你是【关键词提取专家】,负责从文本中提取核心关键词。 + + 请从以下文本中提取{{ max_keywords }}个最重要的关键词: + + {{ text }} + + ### 要求: + 1. 关键词应准确反映文本主题 + 2. 优先提取专有名词和核心概念 + 3. 每个关键词用逗号分隔 + 4. 只输出关键词,不要其他内容 + 5. 使用中文输出 + + 关键词: + +# Knowledge Card Generation Prompt +KNOWLEDGE_CARD_GENERATION_PROMPT: |- + ### 你是【知识卡片生成专家】,负责将文本内容提炼成结构化的知识卡片。 + + 请为以下内容生成一个知识卡片,包含摘要和关键词: + + 内容: + {{ text }} + + ### 要求: + 1. 摘要部分: + - 不超过200个中文字符 + - 提炼核心内容和关键信息 + - 语言简洁清晰,逻辑连贯 + - 不使用markdown格式符号 + + 2. 关键词部分: + - 提取5-10个核心关键词 + - 用逗号分隔 + - 反映内容主题 + + 3. 输出格式: + - 第一行:摘要内容 + - 第二行:关键词(用"关键词:"前缀) + + 请直接输出,无需额外说明。 + +# Cluster Integration Prompt +CLUSTER_INTEGRATION_PROMPT: |- + ### 你是【知识整合专家】,负责将多个知识卡片整合成连贯的集群总结。 + + 请将以下知识卡片整合成一个连贯完整的集群总结: + + {{ summaries_text }} + + ### 要求: + 1. 将所有卡片的核心信息整合成统一主题 + 2. 根据内容重要性和相关性调整权重,重要内容详细描述,次要内容简要提及 + 3. 保持清晰逻辑和完整结构,确保所有信息都得到体现 + 4. 字数控制在200字以内 + 5. 使用简洁清晰的语言 + 6. 不要遗漏任何信息,只调整描述权重 + 7. 不要使用markdown格式符号 + 8. 直接输出纯文本内容 + + 集群整合总结: + +# Global Integration Prompt +GLOBAL_INTEGRATION_PROMPT: |- + ### 你是【知识库总结专家】,负责生成清晰明确的知识库整体总结。 + + 请将以下{{ cluster_count }}个集群总结整合成一个清晰明确的知识库内容总结: + + {{ summaries_text }} + + ### 要求: + + #### 1. 内容整合要求: + - 分析{{ cluster_count }}个集群总结的内容相似性和关联性 + - 将相似或关联的内容合并到同一个要点中 + - 最终要点数量不能超过{{ cluster_count }}个(即≤{{ cluster_count }}个要点) + - 如果内容差异很大,可以保持{{ cluster_count }}个独立要点 + + #### 2. 内容要求: + - 总结要清晰、完整、不遗漏关键信息 + - 每个要点突出核心观点和关键数据 + - 语言简洁明确,便于大模型识别查询意图 + - 保持逻辑连贯性和主题关联性 + + #### 3. 输出要求: + - 使用纯文本格式,不使用Markdown标记 + - 分点使用"一、"、"二、"等序号 + - 每个要点之间用空行分隔 + - 直接输出内容,无需额外说明 + + 知识库内容总结: + diff --git a/backend/prompts/async_knowledge_summary_en.yaml b/backend/prompts/async_knowledge_summary_en.yaml new file mode 100644 index 00000000..9c4a835f --- /dev/null +++ b/backend/prompts/async_knowledge_summary_en.yaml @@ -0,0 +1,114 @@ +# Async Knowledge Summary Prompt Templates (English) + +# Summary Generation Prompt +SUMMARY_GENERATION_PROMPT: |- + ### You are a [Knowledge Summary Expert] responsible for generating concise and accurate knowledge summaries. + + Please generate a concise knowledge summary (no more than {{ max_length }} characters) for the following content: + + Content: + {{ text }} + + ### Requirements: + 1. Extract core viewpoints and key information + 2. Use concise and clear language + 3. Maintain objectivity and accuracy + 4. Highlight important content + 5. Do not use markdown format symbols (such as #, *, -, etc.) + 6. Output the summary directly without additional explanation + + Knowledge Summary: + +# Keyword Extraction Prompt +KEYWORD_EXTRACTION_PROMPT: |- + ### You are a【Keyword Extraction Expert】responsible for extracting core keywords from text. + + Please extract {{ max_keywords }} most important keywords from the following text: + + {{ text }} + + ### Requirements: + 1. Keywords should accurately reflect the text theme + 2. Prioritize proper nouns and core concepts + 3. Separate each keyword with a comma + 4. 
Output only keywords, no other content + + Keywords: + +# Knowledge Card Generation Prompt +KNOWLEDGE_CARD_GENERATION_PROMPT: |- + ### You are a【Knowledge Card Generation Expert】responsible for refining text content into structured knowledge cards. + + Please generate a knowledge card for the following content, including summary and keywords: + + Content: + {{ text }} + + ### Requirements: + 1. Summary section: + - No more than 200 characters + - Refine core content and key information + - Use concise and clear language with coherent logic + - Do not use markdown format symbols + + 2. Keywords section: + - Extract 5-10 core keywords + - Separate with commas + - Reflect content theme + + 3. Output format: + - First line: Summary content + - Second line: Keywords (with "Keywords:" prefix) + + Please output directly without additional explanation. + +# Cluster Integration Prompt +CLUSTER_INTEGRATION_PROMPT: |- + ### You are a【Knowledge Integration Expert】responsible for integrating multiple knowledge cards into coherent cluster summaries. + + Please integrate the following knowledge cards into a coherent and complete cluster summary: + + {{ summaries_text }} + + ### Requirements: + 1. Integrate core information from all cards into a unified theme + 2. Adjust weight based on content importance and relevance, describe important content in detail, mention secondary content briefly + 3. Maintain clear logic and complete structure, ensure all information is represented + 4. Control word count within 200 words + 5. Use concise and clear language + 6. Do not omit any information, only adjust description weight + 7. Do not use markdown format symbols + 8. Output plain text content directly + + Cluster Integration Summary: + +# Global Integration Prompt +GLOBAL_INTEGRATION_PROMPT: |- + ### You are a【Knowledge Base Summary Expert】responsible for generating clear and explicit overall knowledge base summaries. + + Please integrate the following {{ cluster_count }} cluster summaries into a clear and explicit knowledge base content summary: + + {{ summaries_text }} + + ### Requirements: + + #### 1. Content Integration Requirements: + - Analyze content similarity and relevance of {{ cluster_count }} cluster summaries + - Merge similar or related content into the same point + - Final number of points must not exceed {{ cluster_count }} (i.e., ≤{{ cluster_count }} points) + - If content is very different, keep {{ cluster_count }} independent points + + #### 2. Content Requirements: + - Summary should be clear, complete, without missing key information + - Each point highlights core viewpoints and key data + - Language is concise and clear, easy for large models to identify query intent + - Maintain logical coherence and thematic relevance + + #### 3. Output Requirements: + - Use plain text format, do not use Markdown markup + - Use numbered points like "1.", "2.", etc. 
+ - Separate each point with blank lines + - Output content directly without additional explanation + + Knowledge Base Content Summary: + diff --git a/backend/pyproject.toml b/backend/pyproject.toml index 777ca3cd..8fcca7e9 100644 --- a/backend/pyproject.toml +++ b/backend/pyproject.toml @@ -14,16 +14,18 @@ dependencies = [ "pyyaml>=6.0.2", "redis>=5.0.0", "fastmcp==2.12.0", - "langchain>=0.3.26" + "langchain>=0.3.26", + "scikit-learn>=1.3.0" ] [project.optional-dependencies] data-process = [ - "ray[default]>=2.9.3", + "ray[default]>=2.8.0,<2.10.0", "celery>=5.3.6", "flower>=2.0.1", "nest_asyncio>=1.5.6", - "unstructured[csv,docx,pdf,pptx,xlsx,md]" + "unstructured[csv,docx,pdf,pptx,xlsx,md]", + "pydantic>=2.0.0,<3.0.0" ] test = [ "pytest", diff --git a/backend/services/elasticsearch_service.py b/backend/services/elasticsearch_service.py index aa386d02..2f861af9 100644 --- a/backend/services/elasticsearch_service.py +++ b/backend/services/elasticsearch_service.py @@ -33,10 +33,19 @@ get_knowledge_record, update_knowledge_record, get_knowledge_info_by_tenant_id, update_model_name_by_index_name, ) +from database.model_management_db import get_model_by_model_id from services.redis_service import get_redis_service from utils.config_utils import tenant_config_manager, get_model_name_from_config from utils.file_management_utils import get_all_files_status, get_file_size from utils.prompt_template_utils import get_knowledge_summary_prompt_template +from utils.async_knowledge_summary_utils import ( + AsyncLLMClient, + ChunkClusterer, + KnowledgeIntegrator, + async_vectorize_batch +) +import numpy as np +import json # Configure logging logger = logging.getLogger("elasticsearch_service") @@ -871,7 +880,7 @@ async def summary_index_name(self, model_id: Optional[int] = None ): """ - Generate a summary for the specified index based on its content + Generate a summary for the specified index based on its content using async pipeline Args: index_name: Name of the index to summarize @@ -879,44 +888,285 @@ async def summary_index_name(self, es_core: ElasticSearchCore instance tenant_id: ID of the tenant language: Language of the summary (default: 'zh') + model_id: Optional model ID for LLM Returns: StreamingResponse containing the generated summary """ try: - # Get all documents + # Validate tenant ID if not tenant_id: - raise Exception( - "Tenant ID is required for summary generation.") + raise Exception("Tenant ID is required for summary generation.") + + # Get documents from Elasticsearch all_documents = ElasticSearchService.get_random_documents( index_name, batch_size, es_core) - all_chunks = self._clean_chunks_for_summary(all_documents) - keywords_dict = calculate_term_weights(all_chunks) - keywords_for_summary = "" - for _, key in enumerate(keywords_dict): - keywords_for_summary = keywords_for_summary + ", " + key - - async def generate_summary(): - token_join = [] + + total_docs = all_documents.get('total', 0) + documents_list = all_documents.get('documents', []) + + # Check if index has any documents + if total_docs == 0: + error_msg = ( + f"No indexed content found in knowledge base '{index_name}' (total indexed documents: 0). " + f"Possible reasons:\n" + f"1. Documents are still being processed (check file processing status)\n" + f"2. Document processing failed (check logs for errors)\n" + f"3. No documents have been uploaded yet\n" + f"Please ensure documents are fully processed before generating a summary." 
+ ) + logger.error(error_msg) + raise Exception(error_msg) + + if not documents_list: + raise Exception(f"Failed to retrieve documents from index {index_name}") + + logger.info(f"Retrieved {len(documents_list)} documents from index {index_name} (total in index: {total_docs})") + + # Get model configuration + if model_id: try: - for new_token in generate_knowledge_summary_stream(keywords_for_summary, language, tenant_id, model_id): - if new_token == "END": - break - else: - token_join.append(new_token) - yield f"data: {{\"status\": \"success\", \"message\": \"{new_token}\"}}\n\n" - await asyncio.sleep(0.1) + model_config = get_model_by_model_id(model_id, tenant_id) + if not model_config: + logger.warning(f"Model {model_id} not found, using default LLM") + model_config = tenant_config_manager.get_model_config( + key=MODEL_CONFIG_MAPPING["llm"], tenant_id=tenant_id) except Exception as e: - yield f"data: {{\"status\": \"error\", \"message\": \"{e}\"}}\n\n" - - # Return the flow response + logger.warning(f"Failed to get model {model_id}, using default: {e}") + model_config = tenant_config_manager.get_model_config( + key=MODEL_CONFIG_MAPPING["llm"], tenant_id=tenant_id) + else: + model_config = tenant_config_manager.get_model_config( + key=MODEL_CONFIG_MAPPING["llm"], tenant_id=tenant_id) + + # Get embedding model + embedding_model = get_embedding_model(tenant_id) + if not embedding_model: + raise Exception("Failed to get embedding model") + + # Async summary generation stream + async def generate_summary_stream(): + summary_parts = [] + try: + # Note: ES stores chunks, not original documents + # Each item in documents_list is already a processed chunk + chunks = documents_list + + # Progress: Step 1 - Document reconstruction and clustering + progress_data = {"status": "progress", "step": "document_clustering", "message": "Reconstructing documents and performing document-level clustering..."} + yield f"data: {json.dumps(progress_data)}\n\n" + await asyncio.sleep(0.1) + + # Reconstruct documents from chunks + documents = self._reconstruct_documents_from_chunks(chunks) + logger.info(f"Reconstructed {len(documents)} documents from {len(chunks)} chunks") + + # Vectorize documents for document-level clustering + doc_texts = [doc.get('content', '') for doc in documents] + doc_vectors = await async_vectorize_batch(doc_texts, embedding_model, batch_size=20) + + # Document-level clustering + from utils.async_knowledge_summary_utils import DocumentClusterer + doc_clusterer = DocumentClusterer(max_clusters=10) + doc_cluster_result = doc_clusterer.cluster_documents(doc_vectors) + + if doc_cluster_result is None: + # Fallback: treat all documents as one cluster + import numpy as np + doc_cluster_labels = np.zeros(len(documents), dtype=int) + n_doc_clusters = 1 + else: + doc_cluster_labels = doc_cluster_result['cluster_labels'] + n_doc_clusters = doc_cluster_result['n_clusters'] + + logger.info(f"Document clustering completed: {n_doc_clusters} document clusters") + + # Progress: Step 2 - Chunk vectorization and clustering within document clusters + progress_data = {"status": "progress", "step": "chunk_clustering", "message": "Performing chunk-level clustering within document clusters..."} + yield f"data: {json.dumps(progress_data)}\n\n" + await asyncio.sleep(0.1) + + # Reorganize chunks by document clusters + chunks_by_doc_cluster = self._organize_chunks_by_document_clusters(chunks, documents, doc_cluster_labels) + + # Vectorize chunks for chunk-level clustering + chunk_texts = [chunk.get('content', '') for 
chunk in chunks] + chunk_vectors = await async_vectorize_batch(chunk_texts, embedding_model, batch_size=20) + + # Cluster chunks with document-cluster awareness + chunk_clusterer = ChunkClusterer(similarity_threshold=0.70, min_cluster_size=1) + chunk_cluster_result = chunk_clusterer.cluster_chunks_with_document_clusters(chunk_vectors, chunks, chunks_by_doc_cluster) + + n_clusters = chunk_cluster_result['n_clusters'] + logger.info(f"Chunk clustering completed: {n_clusters} clusters from {len(chunks)} chunks") + + # Fallback strategy: if no clusters formed, treat all chunks as one cluster + if n_clusters == 0 and len(chunks) > 0: + logger.warning("No clusters formed, using fallback: treating all chunks as single cluster") + chunk_cluster_result['chunk_clusters'] = [{ + 'cluster_id': 0, + 'chunks': chunks, + 'size': len(chunks), + 'avg_similarity': 0.5 + }] + n_clusters = 1 + chunk_cluster_result['n_clusters'] = 1 + + progress_data = {"status": "progress", "step": "chunk_clustering", "message": f"Organized into {n_clusters} topic clusters"} + yield f"data: {json.dumps(progress_data)}\n\n" + await asyncio.sleep(0.1) + + # Progress: Step 2 - Knowledge card generation + progress_data = {"status": "progress", "step": "card_generation", "message": "Generating knowledge cards..."} + yield f"data: {json.dumps(progress_data)}\n\n" + await asyncio.sleep(0.1) + + # Initialize LLM client and integrator + llm_client = AsyncLLMClient(model_config, language=language) + knowledge_integrator = KnowledgeIntegrator(llm_client) + + # Generate knowledge cards for each chunk cluster + cards_data = [ + { + 'chunk_cluster': chunk_cluster, + 'parent_cluster_id': idx + } + for idx, chunk_cluster in enumerate(chunk_cluster_result['chunk_clusters']) + ] + + knowledge_cards = await llm_client.batch_generate_cards_async(cards_data) + knowledge_cards = [card for card in knowledge_cards if card is not None] + total_cards = len(knowledge_cards) + + logger.info(f"Generated {total_cards} knowledge cards") + progress_data = {"status": "progress", "step": "card_generation", "message": f"Generated {total_cards} knowledge cards"} + yield f"data: {json.dumps(progress_data)}\n\n" + await asyncio.sleep(0.1) + + # Progress: Step 3 - Knowledge integration + progress_data = {"status": "progress", "step": "integration", "message": "Integrating knowledge cards..."} + yield f"data: {json.dumps(progress_data)}\n\n" + await asyncio.sleep(0.1) + + # Group cards by document cluster for integration (ensure 1:1 mapping) + doc_cluster_knowledge_cards = {} + for card in knowledge_cards: + # Get document cluster ID from chunk cluster + chunk_cluster_id = card.get('parent_cluster', 0) + doc_cluster_id = None + for chunk_cluster in chunk_cluster_result['chunk_clusters']: + if chunk_cluster.get('cluster_id') == chunk_cluster_id: + doc_cluster_id = chunk_cluster.get('document_cluster_id', 0) + break + + if doc_cluster_id not in doc_cluster_knowledge_cards: + doc_cluster_knowledge_cards[doc_cluster_id] = [] + doc_cluster_knowledge_cards[doc_cluster_id].append(card) + + # Integrate within each document cluster (1:1 mapping) + cluster_integrations = [] + for doc_cluster_id, cards in doc_cluster_knowledge_cards.items(): + if cards: + integration = await knowledge_integrator.integrate_cluster_cards(cards, doc_cluster_id) + if integration: + cluster_integrations.append(integration) + + logger.info(f"Integrated {len(cluster_integrations)} document clusters from {len(knowledge_cards)} knowledge cards") + + # Progress: Step 4 - Global integration + 
progress_data = {"status": "progress", "step": "global_integration", "message": "Generating final knowledge base summary..."} + yield f"data: {json.dumps(progress_data)}\n\n" + await asyncio.sleep(0.1) + + # Global integration + global_integration = await knowledge_integrator.integrate_all_clusters(cluster_integrations) + + if global_integration: + # Stream the final summary + final_summary = global_integration['global_summary'] + + # Stream summary character by character to preserve line breaks + for char in final_summary: + summary_parts.append(char) + # Use json.dumps for proper JSON formatting + message_data = {"status": "success", "message": char} + yield f"data: {json.dumps(message_data)}\n\n" + await asyncio.sleep(0.01) + + # Send completion metadata + metadata = { + "total_clusters": global_integration['cluster_count'], + "total_cards": global_integration['total_cards'], + "confidence": global_integration['avg_confidence'], + "keywords": global_integration['global_keywords'][:10] + } + complete_data = {"status": "complete", "metadata": metadata} + yield f"data: {json.dumps(complete_data)}\n\n" + + logger.info(f"Knowledge summary generated successfully: {len(final_summary)} characters") + else: + raise Exception("Global integration failed") + + except Exception as e: + logger.error(f"Error in async summary generation: {e}", exc_info=True) + error_data = {"status": "error", "message": str(e)} + yield f"data: {json.dumps(error_data)}\n\n" + + # Return streaming response return StreamingResponse( - generate_summary(), + generate_summary_stream(), media_type="text/event-stream" ) - + except Exception as e: - raise Exception(f"{str(e)}") + logger.error(f"Failed to initialize summary generation: {e}", exc_info=True) + raise Exception(f"Failed to generate knowledge summary: {str(e)}") + + def _reconstruct_documents_from_chunks(self, chunks: List[dict]) -> List[dict]: + """Reconstruct documents from chunks""" + documents = {} + + for chunk in chunks: + doc_id = chunk.get('filename', chunk.get('source_doc', 'unknown')) + + if doc_id not in documents: + documents[doc_id] = { + 'doc_id': doc_id, + 'content': '', + 'chunks': [], + 'metadata': chunk.get('metadata', {}) + } + + documents[doc_id]['content'] += chunk.get('content', '') + ' ' + documents[doc_id]['chunks'].append(chunk) + + # Convert to list and clean content + doc_list = [] + for doc_id, doc_info in documents.items(): + doc_info['content'] = doc_info['content'].strip() + doc_list.append(doc_info) + + return doc_list + + def _organize_chunks_by_document_clusters(self, chunks: List[dict], documents: List[dict], doc_cluster_labels: np.ndarray) -> Dict[int, List[dict]]: + """Organize chunks by their document clusters""" + chunks_by_doc_cluster = {} + + # Create mapping from doc_id to cluster_id + doc_to_cluster = {} + for i, doc in enumerate(documents): + doc_to_cluster[doc['doc_id']] = doc_cluster_labels[i] + + # Group chunks by document cluster + for chunk in chunks: + doc_id = chunk.get('filename', chunk.get('source_doc', 'unknown')) + cluster_id = doc_to_cluster.get(doc_id, 0) + + if cluster_id not in chunks_by_doc_cluster: + chunks_by_doc_cluster[cluster_id] = [] + chunks_by_doc_cluster[cluster_id].append(chunk) + + return chunks_by_doc_cluster @staticmethod def _clean_chunks_for_summary(all_documents): diff --git a/backend/utils/async_knowledge_summary_utils.py b/backend/utils/async_knowledge_summary_utils.py new file mode 100644 index 00000000..a58ab826 --- /dev/null +++ b/backend/utils/async_knowledge_summary_utils.py @@ -0,0 
+1,1043 @@ +""" +Async Knowledge Summary Utilities +Provides async pipeline for knowledge base summarization with clustering and integration +""" + +import asyncio +import logging +import re +import time +from collections import Counter +from typing import Dict, List, Optional + +import numpy as np +from jinja2 import Template +from openai import AsyncOpenAI +from sklearn.cluster import KMeans +from sklearn.metrics import silhouette_score +from sklearn.metrics.pairwise import cosine_similarity +from sklearn.preprocessing import StandardScaler + +from backend.database.model_management_db import get_model_by_model_id +from utils.config_utils import get_model_name_from_config, tenant_config_manager +from utils.prompt_template_utils import get_async_knowledge_summary_prompt_template +from consts.const import MODEL_CONFIG_MAPPING, LANGUAGE + +logger = logging.getLogger(__name__) + + +class AsyncLLMClient: + """Async LLM client using OpenAI-compatible API""" + + def __init__(self, model_config: Dict, language: str = LANGUAGE["ZH"]): + """ + Initialize async LLM client + + Args: + model_config: Model configuration dict with api_key, base_url, model_name + language: Language code ('zh' or 'en') + """ + self.model_config = model_config + self.client = AsyncOpenAI( + api_key=model_config.get('api_key', ''), + base_url=model_config.get('base_url', '') + ) + self.model_name = get_model_name_from_config(model_config) + self.semaphore = asyncio.Semaphore(3) # Max 3 concurrent LLM calls + self.language = language + + # Load prompt templates + self.prompts = get_async_knowledge_summary_prompt_template(language) + logger.info(f"Loaded async knowledge summary prompts for language: {language}") + + async def chat_async( + self, + messages: List[Dict], + max_tokens: int = 500, + temperature: float = 0.3 + ) -> Optional[str]: + """ + Async LLM chat completion + + Args: + messages: List of message dicts with role and content + max_tokens: Max tokens to generate + temperature: Sampling temperature + + Returns: + Generated text or None on error + """ + async with self.semaphore: + try: + response = await self.client.chat.completions.create( + model=self.model_name, + messages=messages, + max_tokens=max_tokens, + temperature=temperature, + top_p=0.8 + ) + + if response and response.choices: + return response.choices[0].message.content.strip() + else: + logger.error("LLM response format error") + return None + + except Exception as e: + logger.error(f"LLM call failed: {e}") + return None + + async def batch_generate_cards_async(self, cards_data: List[Dict]) -> List[Dict]: + """ + Batch async generate knowledge cards + + Args: + cards_data: List of card data dicts + + Returns: + List of generated knowledge cards + """ + logger.info(f"Starting batch async generation of {len(cards_data)} knowledge cards") + + tasks = [self._generate_single_card_async(card_data) for card_data in cards_data] + results = await asyncio.gather(*tasks, return_exceptions=True) + + generated_cards = [] + for i, result in enumerate(results): + if isinstance(result, Exception): + logger.error(f"Card {i} generation failed: {result}") + fallback_card = self._create_fallback_card(cards_data[i]) + generated_cards.append(fallback_card) + elif result: + generated_cards.append(result) + + logger.info(f"Batch async generation completed, {len(generated_cards)} cards generated") + return generated_cards + + async def _generate_single_card_async(self, card_data: Dict) -> Optional[Dict]: + """Generate single knowledge card""" + chunk_cluster = 
card_data['chunk_cluster'] + parent_cluster_id = card_data['parent_cluster_id'] + + # Merge chunk texts + merged_text = self._merge_chunks(chunk_cluster['chunks']) + + if not merged_text or len(merged_text.strip()) < 50: + logger.warning("Merged text too short, skipping card generation") + return None + + # Parallel generation of summary and keywords + summary_task = self._generate_summary_async(merged_text) + keywords_task = self._extract_keywords_async(merged_text) + + summary, keywords = await asyncio.gather(summary_task, keywords_task) + + # Calculate confidence score + confidence_score = self._calculate_confidence(chunk_cluster) + + card = { + 'card_id': f"cluster_{parent_cluster_id}_card_{chunk_cluster.get('cluster_id', 0)}", + 'parent_cluster': parent_cluster_id, + 'summary': summary, + 'keywords': keywords, + 'source_chunks': [ + { + 'chunk_id': chunk.get('chunk_id', f"{chunk.get('_id', 'unknown')}"), + 'source_doc': chunk.get('filename', 'unknown'), + 'length': len(chunk.get('content', '')) + } + for chunk in chunk_cluster['chunks'] + ], + 'chunk_count': chunk_cluster['size'], + 'avg_similarity': chunk_cluster.get('avg_similarity', 0.0), + 'confidence_score': confidence_score + } + + return card + + async def _generate_summary_async(self, text: str, max_length: int = 200) -> str: + """Async generate summary""" + if len(text) > 2000: + text = text[:2000] + "..." + + # Use template from YAML + template = Template(self.prompts['SUMMARY_GENERATION_PROMPT']) + prompt = template.render(text=text, max_length=max_length) + + messages = [{"role": "user", "content": prompt}] + + response = await self.chat_async( + messages=messages, + max_tokens=max_length, # Reduce token count and enforce conciseness + temperature=0.3 + ) + + if response: + # 清理可能的markdown格式符号 + cleaned_response = self._clean_markdown_symbols(response) + # 确保长度控制 + if len(cleaned_response) > max_length: + cleaned_response = cleaned_response[:max_length] + "..." + return cleaned_response + else: + logger.warning("LLM summary generation failed, using fallback strategy") + return text[:max_length] + "..." + + async def _extract_keywords_async(self, text: str, max_keywords: int = 10) -> List[str]: + """Async extract keywords""" + if len(text) > 2000: + text = text[:2000] + "..." 
+ + # Use template from YAML + template = Template(self.prompts['KEYWORD_EXTRACTION_PROMPT']) + prompt = template.render(text=text, max_keywords=max_keywords) + + messages = [{"role": "user", "content": prompt}] + + response = await self.chat_async( + messages=messages, + max_tokens=200, + temperature=0.3 + ) + + if response: + keywords = [kw.strip() for kw in response.replace(',', ',').split(',')] + keywords = [kw for kw in keywords if kw] + return keywords[:max_keywords] + else: + logger.warning("LLM keyword extraction failed, using fallback strategy") + return self._fallback_keyword_extraction(text) + + def _merge_chunks(self, chunks: List[dict]) -> str: + """Merge chunk texts""" + if not chunks: + return "" + + texts = [chunk.get('content', '') for chunk in chunks] + return ' '.join(texts) + + def _calculate_confidence(self, chunk_cluster: Dict) -> float: + """Calculate confidence score based on cluster size and similarity + + The confidence score is a weighted combination of: + - Size score (40%): Normalized cluster size, capped at 1.0 for clusters >= 10 chunks + - Similarity score (60%): Average similarity within the cluster + + Weights are chosen to prioritize semantic similarity over raw size, + as a small but highly similar cluster is more reliable than a large but diverse one. + """ + # Normalize cluster size: divide by 10 and cap at 1.0 + # Clusters with >= 10 chunks get full size score (1.0) + size_score = min(chunk_cluster['size'] / 10.0, 1.0) + + # Get average similarity within cluster (0.0 if not available) + similarity_score = chunk_cluster.get('avg_similarity', 0.0) + + # Weighted combination: 40% size + 60% similarity + # This emphasizes semantic coherence over raw quantity + confidence = 0.4 * size_score + 0.6 * similarity_score + return round(confidence, 3) + + def _create_fallback_card(self, card_data: Dict) -> Dict: + """Create fallback card""" + chunk_cluster = card_data['chunk_cluster'] + parent_cluster_id = card_data['parent_cluster_id'] + + merged_text = self._merge_chunks(chunk_cluster['chunks']) + + return { + 'card_id': f"cluster_{parent_cluster_id}_card_{chunk_cluster.get('cluster_id', 0)}", + 'parent_cluster': parent_cluster_id, + 'summary': merged_text[:200] + "..." 
if len(merged_text) > 200 else merged_text, + 'keywords': self._fallback_keyword_extraction(merged_text), + 'source_chunks': [ + { + 'chunk_id': chunk.get('chunk_id', f"{chunk.get('_id', 'unknown')}"), + 'source_doc': chunk.get('filename', 'unknown'), + 'length': len(chunk.get('content', '')) + } + for chunk in chunk_cluster['chunks'] + ], + 'chunk_count': chunk_cluster['size'], + 'avg_similarity': chunk_cluster.get('avg_similarity', 0.0), + 'confidence_score': 0.5 + } + + def _fallback_keyword_extraction(self, text: str) -> List[str]: + """Fallback keyword extraction""" + words = re.findall(r'[\u4e00-\u9fa5]+', text) + stop_words = {'的', '了', '和', '是', '在', '有', '个', '等', '与', '及'} + words = [w for w in words if len(w) >= 2 and w not in stop_words] + + word_counts = Counter(words) + top_words = [word for word, count in word_counts.most_common(10)] + + return top_words + + def _clean_markdown_symbols(self, text: str) -> str: + """Clean markdown symbols from text""" + # Remove markdown headers + text = re.sub(r'^#{1,6}\s*', '', text, flags=re.MULTILINE) + # Remove markdown bold/italic + text = re.sub(r'\*{1,2}([^*]+)\*{1,2}', r'\1', text) + text = re.sub(r'_{1,2}([^_]+)_{1,2}', r'\1', text) + # Remove markdown lists + text = re.sub(r'^[\s]*[-*+]\s*', '', text, flags=re.MULTILINE) + text = re.sub(r'^[\s]*\d+\.\s*', '', text, flags=re.MULTILINE) + # Remove horizontal rules + text = re.sub(r'^---+$', '', text, flags=re.MULTILINE) + # Remove extra whitespace + text = re.sub(r'\n\s*\n', '\n', text) + text = text.strip() + + return text + + +class DocumentClusterer: + """Document clusterer using K-means""" + + def __init__(self, max_clusters: int = 10): + """Initialize clusterer""" + self.max_clusters = max_clusters + self.scaler = StandardScaler() + + def cluster_documents(self, vectors: np.ndarray) -> Dict: + """ + Cluster documents + + Args: + vectors: Document vectors + + Returns: + Clustering result dict + """ + # Handle single document case + if len(vectors) == 1: + logger.info("Single document detected: returning single cluster") + return { + 'cluster_labels': np.array([0]), + 'n_clusters': 1, + 'silhouette_score': 1.0 + } + + try: + vectors_scaled = self.scaler.fit_transform(vectors) + optimal_k = self._find_optimal_k(vectors_scaled) + + kmeans = KMeans(n_clusters=optimal_k, random_state=42, n_init=10) + cluster_labels = kmeans.fit_predict(vectors_scaled) + + silhouette_avg = silhouette_score(vectors_scaled, cluster_labels) + + logger.info(f"K-means clustering completed: {optimal_k} clusters, silhouette score: {silhouette_avg:.3f}") + + return { + 'cluster_labels': cluster_labels, + 'n_clusters': optimal_k, + 'silhouette_score': silhouette_avg + } + + except Exception as e: + logger.error(f"Clustering failed: {e}") + return None + + def _find_optimal_k(self, vectors: np.ndarray) -> int: + """Find optimal K value""" + max_k = min(self.max_clusters, len(vectors) - 1) + + if max_k < 2: + return 1 + + # Check document similarity + if len(vectors) > 1: + similarity_matrix = cosine_similarity(vectors) + mask = ~np.eye(similarity_matrix.shape[0], dtype=bool) + avg_similarity = np.mean(similarity_matrix[mask]) + + if avg_similarity > 0.95: + logger.warning(f"Documents too similar (avg similarity: {avg_similarity:.3f}), using single cluster") + return 1 + + best_k = 2 + best_score = -1 + + for k in range(2, max_k + 1): + try: + kmeans = KMeans(n_clusters=k, random_state=42, n_init=10) + cluster_labels = kmeans.fit_predict(vectors) + + if len(set(cluster_labels)) > 1: + score = 
silhouette_score(vectors, cluster_labels) + if score > best_score: + best_score = score + best_k = k + + except Exception as e: + logger.warning(f"K={k} clustering failed: {e}") + continue + + logger.info(f"Optimal K value: {best_k}") + return best_k + + +class ChunkDivider: + """Chunk divider with sliding window""" + + def __init__(self, window_size: int = 512, overlap_ratio: float = 0.2): + """ + Initialize chunk divider + + Args: + window_size: Window size in characters + overlap_ratio: Overlap ratio (0-1) + """ + self.window_size = window_size + self.overlap_ratio = overlap_ratio + self.min_chunk_length = 50 + + def divide_documents(self, documents: List[dict]) -> List[dict]: + """ + Divide documents into chunks + + Args: + documents: List of document dicts from Elasticsearch + + Returns: + List of chunk dicts + """ + all_chunks = [] + + for doc in documents: + text = doc.get('content', '') + if not text or len(text.strip()) < self.min_chunk_length: + continue + + chunks = self._sliding_window_chunk(text, doc) + all_chunks.extend(chunks) + + logger.info(f"Divided {len(all_chunks)} chunks from {len(documents)} documents") + return all_chunks + + def _sliding_window_chunk(self, text: str, doc: dict) -> List[dict]: + """Sliding window chunking""" + chunks = [] + sentences = self._split_sentences(text) + + current_chunk = [] + current_length = 0 + chunk_start_pos = 0 + + for sent in sentences: + sent_length = len(sent) + + if current_length + sent_length > self.window_size and current_chunk: + chunk_text = ''.join(current_chunk) + if len(chunk_text.strip()) >= self.min_chunk_length: + chunks.append({ + 'content': chunk_text, + 'filename': doc.get('filename', 'unknown'), + '_id': doc.get('_id', 'unknown'), + 'title': doc.get('title', ''), + 'chunk_id': f"{doc.get('_id', 'unknown')}_chunk_{len(chunks)}", + 'start_pos': chunk_start_pos, + 'end_pos': chunk_start_pos + len(chunk_text), + 'length': len(chunk_text) + }) + + overlap_count = int(len(current_chunk) * self.overlap_ratio) + if overlap_count > 0: + overlap_text = ''.join(current_chunk[-overlap_count:]) + chunk_start_pos += len(chunk_text) - len(overlap_text) + current_chunk = current_chunk[-overlap_count:] + current_length = len(overlap_text) + else: + chunk_start_pos += len(chunk_text) + current_chunk = [] + current_length = 0 + + current_chunk.append(sent) + current_length += sent_length + + if current_chunk: + chunk_text = ''.join(current_chunk) + if len(chunk_text.strip()) >= self.min_chunk_length: + chunks.append({ + 'content': chunk_text, + 'filename': doc.get('filename', 'unknown'), + '_id': doc.get('_id', 'unknown'), + 'title': doc.get('title', ''), + 'chunk_id': f"{doc.get('_id', 'unknown')}_chunk_{len(chunks)}", + 'start_pos': chunk_start_pos, + 'end_pos': chunk_start_pos + len(chunk_text), + 'length': len(chunk_text) + }) + + return chunks + + def _split_sentences(self, text: str) -> List[str]: + """Simple sentence splitting""" + if len(text) < self.min_chunk_length: + return [text] if text.strip() else [] + + sentences = re.split(r'([。!?\n]+)', text) + + result = [] + for i in range(0, len(sentences), 2): + if i + 1 < len(sentences): + sent = sentences[i] + sentences[i + 1] + else: + sent = sentences[i] + + if sent.strip(): + result.append(sent) + + if not result: + result = [text] + + return result + + +class ChunkClusterer: + """Chunk clusterer based on similarity threshold""" + + def __init__(self, similarity_threshold: float = 0.70, min_cluster_size: int = 1): + """ + Initialize chunk clusterer + + Args: + 
similarity_threshold: Similarity threshold (0-1) + min_cluster_size: Minimum cluster size + """ + self.similarity_threshold = similarity_threshold + self.min_cluster_size = min_cluster_size + + def cluster_chunks(self, vectors: np.ndarray, chunks: List[dict]) -> Dict: + """ + Cluster chunks with document-first approach + + Args: + vectors: Chunk vectors + chunks: Chunk dicts + + Returns: + Clustering result + """ + if len(chunks) == 0: + return {'chunk_clusters': [], 'n_clusters': 0} + + # First: Group chunks by document source + document_groups = self._group_chunks_by_document(chunks, vectors) + + # Second: Within each document, perform semantic clustering + all_clusters = [] + cluster_counter = 0 + + for doc_id, (doc_chunks, doc_vectors) in document_groups.items(): + if len(doc_chunks) == 1: + # Single chunk document - create single cluster + cluster = { + 'cluster_id': cluster_counter, + 'chunks': doc_chunks, + 'size': 1, + 'avg_similarity': 1.0, + 'document_id': doc_id + } + all_clusters.append(cluster) + cluster_counter += 1 + else: + # Multiple chunks document - semantic clustering within document + doc_similarity_matrix = cosine_similarity(doc_vectors) + doc_cluster_labels = self._threshold_clustering(doc_similarity_matrix) + + # Organize clusters within this document + doc_clusters = self._organize_clusters(doc_chunks, doc_cluster_labels, doc_similarity_matrix) + + # Assign cluster IDs and add document info + for cluster in doc_clusters['chunk_clusters']: + cluster['cluster_id'] = cluster_counter + cluster['document_id'] = doc_id + all_clusters.append(cluster) + cluster_counter += 1 + + result = { + 'chunk_clusters': all_clusters, + 'noise_chunks': [], + 'n_clusters': len(all_clusters), + 'n_noise': 0 + } + + logger.info(f"Document-first clustering completed: {result['n_clusters']} clusters from {len(document_groups)} documents") + return result + + def _group_chunks_by_document(self, chunks: List[dict], vectors: np.ndarray) -> Dict[str, tuple]: + """Group chunks by their source document""" + document_groups = {} + + for i, chunk in enumerate(chunks): + # Use filename as document identifier + doc_id = chunk.get('filename', chunk.get('source_doc', f'doc_{i}')) + + if doc_id not in document_groups: + document_groups[doc_id] = ([], []) + + document_groups[doc_id][0].append(chunk) # chunks + document_groups[doc_id][1].append(vectors[i]) # vectors + + # Convert lists to numpy arrays + for doc_id in document_groups: + chunks_list, vectors_list = document_groups[doc_id] + vectors_array = np.array(vectors_list) + document_groups[doc_id] = (chunks_list, vectors_array) + + return document_groups + + def cluster_chunks_with_document_clusters(self, vectors: np.ndarray, chunks: List[dict], chunks_by_doc_cluster: Dict[int, List[dict]]) -> Dict: + """ + Cluster chunks with awareness of document clusters + + Args: + vectors: Chunk vectors + chunks: Chunk dicts + chunks_by_doc_cluster: Chunks organized by document clusters + + Returns: + Clustering result with document cluster awareness + """ + if len(chunks) == 0: + return {'chunk_clusters': [], 'n_clusters': 0} + + all_clusters = [] + cluster_counter = 0 + + # Process each document cluster separately + for doc_cluster_id, doc_cluster_chunks in chunks_by_doc_cluster.items(): + if not doc_cluster_chunks: + continue + + # Get vectors for this document cluster + doc_cluster_indices = [] + for chunk in doc_cluster_chunks: + for i, orig_chunk in enumerate(chunks): + if chunk == orig_chunk: + doc_cluster_indices.append(i) + break + + if not 
doc_cluster_indices: + continue + + doc_cluster_vectors = vectors[doc_cluster_indices] + + # Smart chunk clustering decision based on token limit + estimated_tokens = self._estimate_tokens_for_chunks(doc_cluster_chunks) + max_tokens_per_cluster = 4000 # Conservative token limit for LLM processing + + if estimated_tokens <= max_tokens_per_cluster: + # Few chunks or low token count: merge all chunks into one cluster + cluster = { + 'cluster_id': cluster_counter, + 'chunks': doc_cluster_chunks, + 'size': len(doc_cluster_chunks), + 'avg_similarity': 0.9, # High similarity for same document cluster + 'document_cluster_id': doc_cluster_id + } + all_clusters.append(cluster) + cluster_counter += 1 + logger.info(f"Document cluster {doc_cluster_id}: merged {len(doc_cluster_chunks)} chunks into 1 cluster (tokens: {estimated_tokens})") + else: + # Many chunks or high token count: apply semantic clustering + doc_similarity_matrix = cosine_similarity(doc_cluster_vectors) + doc_cluster_labels = self._threshold_clustering(doc_similarity_matrix) + + # Organize clusters within this document cluster + doc_clusters = self._organize_clusters(doc_cluster_chunks, doc_cluster_labels, doc_similarity_matrix) + + # Assign cluster IDs and add document cluster info + for cluster in doc_clusters['chunk_clusters']: + cluster['cluster_id'] = cluster_counter + cluster['document_cluster_id'] = doc_cluster_id + all_clusters.append(cluster) + cluster_counter += 1 + + logger.info(f"Document cluster {doc_cluster_id}: {len(doc_cluster_chunks)} chunks clustered into {len(doc_clusters['chunk_clusters'])} sub-clusters (tokens: {estimated_tokens})") + + result = { + 'chunk_clusters': all_clusters, + 'noise_chunks': [], + 'n_clusters': len(all_clusters), + 'n_noise': 0 + } + + logger.info(f"Document-cluster-aware chunk clustering completed: {result['n_clusters']} clusters from {len(chunks_by_doc_cluster)} document clusters") + return result + + def _estimate_tokens_for_chunks(self, chunks: List[dict]) -> int: + """Estimate token count for a list of chunks""" + total_text = "" + for chunk in chunks: + content = chunk.get('content', '') + total_text += content + " " + + # Rough estimation: 1 token ≈ 4 characters for Chinese text + estimated_tokens = len(total_text.strip()) // 4 + return estimated_tokens + + def _threshold_clustering(self, similarity_matrix: np.ndarray) -> np.ndarray: + """Threshold-based clustering""" + n_chunks = len(similarity_matrix) + cluster_labels = np.full(n_chunks, -1, dtype=int) + current_cluster = 0 + + for i in range(n_chunks): + if cluster_labels[i] != -1: + continue + + similar_indices = np.where(similarity_matrix[i] >= self.similarity_threshold)[0] + + if len(similar_indices) >= self.min_cluster_size: + for idx in similar_indices: + if cluster_labels[idx] == -1: + cluster_labels[idx] = current_cluster + current_cluster += 1 + else: + cluster_labels[i] = -1 + + return cluster_labels + + def _organize_clusters( + self, + chunks: List[dict], + cluster_labels: np.ndarray, + similarity_matrix: np.ndarray + ) -> Dict: + """Organize clustering results""" + chunk_clusters = {} + noise_chunks = [] + + for i, (chunk, label) in enumerate(zip(chunks, cluster_labels)): + if label == -1: + noise_chunks.append(chunk) + else: + if label not in chunk_clusters: + chunk_clusters[label] = { + 'cluster_id': int(label), + 'chunks': [], + 'size': 0, + 'avg_similarity': 0.0 + } + chunk_clusters[label]['chunks'].append(chunk) + chunk_clusters[label]['size'] += 1 + + for cluster_id, cluster_info in chunk_clusters.items(): + 
chunk_indices = [i for i, label in enumerate(cluster_labels) if label == cluster_id] + if len(chunk_indices) > 1: + cluster_sim_matrix = similarity_matrix[np.ix_(chunk_indices, chunk_indices)] + mask = ~np.eye(cluster_sim_matrix.shape[0], dtype=bool) + avg_sim = np.mean(cluster_sim_matrix[mask]) + cluster_info['avg_similarity'] = float(avg_sim) + else: + cluster_info['avg_similarity'] = 1.0 + + return { + 'chunk_clusters': list(chunk_clusters.values()), + 'noise_chunks': noise_chunks, + 'n_clusters': len(chunk_clusters), + 'n_noise': len(noise_chunks) + } + + +class KnowledgeIntegrator: + """Knowledge integrator for cluster and global integration""" + + def __init__(self, llm_client: AsyncLLMClient): + """ + Initialize knowledge integrator + + Args: + llm_client: Async LLM client + """ + self.llm_client = llm_client + + async def integrate_cluster_cards(self, cards: List[Dict], cluster_id: int) -> Dict: + """ + Integrate knowledge cards within a cluster + + Args: + cards: Knowledge cards + cluster_id: Cluster ID + + Returns: + Cluster integration result + """ + if not cards: + return None + + if len(cards) == 1: + card = cards[0] + return { + 'cluster_id': cluster_id, + 'integrated_summary': card['summary'], + 'integrated_keywords': card['keywords'], + 'card_count': 1, + 'source_cards': [card['card_id']], + 'confidence_score': card['confidence_score'] + } + + logger.info(f"Integrating {len(cards)} knowledge cards in cluster {cluster_id}...") + + card_summaries = [] + all_keywords = [] + + for i, card in enumerate(cards): + card_summaries.append(f"Card{i+1}: {card['summary']}") + all_keywords.extend(card['keywords']) + + integrated_summary = await self._generate_cluster_summary_async(card_summaries, cluster_id) + integrated_keywords = self._integrate_keywords(all_keywords) + + avg_confidence = sum(card['confidence_score'] for card in cards) / len(cards) + + return { + 'cluster_id': cluster_id, + 'integrated_summary': integrated_summary, + 'integrated_keywords': integrated_keywords, + 'card_count': len(cards), + 'source_cards': [card['card_id'] for card in cards], + 'confidence_score': round(avg_confidence, 3), + 'detailed_cards': cards + } + + async def integrate_all_clusters(self, cluster_integrations: List[Dict]) -> Dict: + """ + Integrate all cluster summaries + + Args: + cluster_integrations: List of cluster integration results + + Returns: + Global integration result + """ + if not cluster_integrations: + return None + + logger.info(f"Integrating {len(cluster_integrations)} cluster summaries...") + + cluster_summaries = [] + all_keywords = [] + + for i, cluster in enumerate(cluster_integrations): + cluster_summaries.append(f"Cluster{i+1}: {cluster['integrated_summary']}") + all_keywords.extend(cluster['integrated_keywords']) + + global_summary = await self._generate_global_summary_async(cluster_summaries) + global_keywords = self._integrate_keywords(all_keywords) + + total_cards = sum(cluster['card_count'] for cluster in cluster_integrations) + avg_confidence = sum(cluster['confidence_score'] for cluster in cluster_integrations) / len(cluster_integrations) + + return { + 'global_summary': global_summary, + 'global_keywords': global_keywords, + 'cluster_count': len(cluster_integrations), + 'total_cards': total_cards, + 'avg_confidence': round(avg_confidence, 3), + 'cluster_details': cluster_integrations + } + + async def _generate_cluster_summary_async(self, card_summaries: List[str], cluster_id: int) -> str: + """Generate cluster-level integrated summary""" + summaries_text = 
'\n'.join(card_summaries) + + # Use template from YAML + template = Template(self.llm_client.prompts['CLUSTER_INTEGRATION_PROMPT']) + prompt = template.render(summaries_text=summaries_text) + + messages = [{"role": "user", "content": prompt}] + + response = await self.llm_client.chat_async( + messages=messages, + max_tokens=300, # Reduce token count + temperature=0.3 + ) + + if response: + # 清理markdown符号 + cleaned_response = self._clean_markdown_symbols(response) + return cleaned_response + else: + logger.warning("LLM cluster integration failed, using fallback strategy") + return ";".join(card_summaries) + + async def _generate_global_summary_async(self, cluster_summaries: List[str]) -> str: + """Generate global integrated summary with dynamic point structure""" + summaries_text = '\n\n'.join(cluster_summaries) + cluster_count = len(cluster_summaries) + + # Use template from YAML + template = Template(self.llm_client.prompts['GLOBAL_INTEGRATION_PROMPT']) + prompt = template.render(summaries_text=summaries_text, cluster_count=cluster_count) + + messages = [{"role": "user", "content": prompt}] + + response = await self.llm_client.chat_async( + messages=messages, + max_tokens=2000, # 增加token限制以支持结构化长文本 + temperature=0.3 + ) + + if response: + # 直接返回大模型输出,无需额外格式化 + return response + else: + logger.warning("LLM global integration failed, using fallback strategy") + return "\n\n".join(cluster_summaries) + + def _clean_markdown_symbols(self, text: str) -> str: + """Clean markdown symbols from text""" + # Remove markdown headers + text = re.sub(r'^#{1,6}\s*', '', text, flags=re.MULTILINE) + # Remove markdown bold/italic + text = re.sub(r'\*{1,2}([^*]+)\*{1,2}', r'\1', text) + text = re.sub(r'_{1,2}([^_]+)_{1,2}', r'\1', text) + # Remove markdown lists + text = re.sub(r'^[\s]*[-*+]\s*', '', text, flags=re.MULTILINE) + text = re.sub(r'^[\s]*\d+\.\s*', '', text, flags=re.MULTILINE) + # Remove horizontal rules + text = re.sub(r'^---+$', '', text, flags=re.MULTILINE) + # Remove extra whitespace + text = re.sub(r'\n\s*\n', '\n', text) + text = text.strip() + + return text + + def _format_final_summary(self, text: str) -> str: + """Format final summary for better readability with dynamic numbered points""" + # Split into lines and clean + lines = [line.strip() for line in text.split('\n') if line.strip()] + + # Format numbered points dynamically + formatted_lines = [] + point_counter = 0 + + for line in lines: + if line: + # Check if line already has Chinese numbering + if re.match(r'^[一二三四五六七八九十百千万]+、', line): + formatted_lines.append(line) + point_counter += 1 + else: + # Add dynamic numbering + point_counter += 1 + chinese_num = self._convert_to_chinese_number(point_counter) + line = f'{chinese_num}、{line}' + formatted_lines.append(line) + + # Add proper spacing between numbered points + formatted_text = '\n\n'.join(formatted_lines) + + # Ensure each point has proper spacing and clear structure + # Clean up any extra spaces and ensure consistent formatting + formatted_text = re.sub(r'\n\s*\n\s*\n', '\n\n', formatted_text) # Remove excessive line breaks + formatted_text = re.sub(r'([一二三四五六七八九十]+、[^一-十\n]+)(?=\n|$)', r'\1', formatted_text) # Ensure proper ending + + # Ensure each numbered point is followed by proper spacing + lines = formatted_text.split('\n') + formatted_lines = [] + for i, line in enumerate(lines): + formatted_lines.append(line) + # Add spacing after numbered points + if re.match(r'^[一二三四五六七八九十]+、', line.strip()) and i < len(lines) - 1: + if not lines[i + 1].strip(): # If next line is 
empty, keep it + continue + else: # If next line has content, add spacing + formatted_lines.append('') + + return '\n'.join(formatted_lines) + + def _convert_to_chinese_number(self, num: int) -> str: + """Convert Arabic number to Chinese number""" + if num <= 10: + num_map = {1: '一', 2: '二', 3: '三', 4: '四', 5: '五', 6: '六', 7: '七', 8: '八', 9: '九', 10: '十'} + return num_map[num] + elif num <= 99: + # For numbers 11-99, use combination + if num == 10: + return '十' + elif num < 20: + return f'十{self._convert_to_chinese_number(num % 10)}' + else: + tens = num // 10 + ones = num % 10 + if ones == 0: + return f'{self._convert_to_chinese_number(tens)}十' + else: + return f'{self._convert_to_chinese_number(tens)}十{self._convert_to_chinese_number(ones)}' + else: + # For numbers >= 100, use Arabic number as fallback + return str(num) + + def _integrate_keywords(self, keywords_list: List[str], max_keywords: int = 20) -> List[str]: + """Integrate keyword list""" + keyword_counts = Counter(keywords_list) + top_keywords = [kw for kw, count in keyword_counts.most_common(max_keywords)] + return top_keywords + + +async def async_vectorize_batch( + texts: List[str], + embedding_model, + batch_size: int = 20 +) -> np.ndarray: + """ + Async batch vectorization using embedding model + + Args: + texts: List of texts to vectorize + embedding_model: Embedding model instance + batch_size: Batch size + + Returns: + Vector matrix + """ + logger.info(f"Starting async vectorization of {len(texts)} texts, batch size: {batch_size}") + + batches = [texts[i:i+batch_size] for i in range(0, len(texts), batch_size)] + logger.info(f"Divided into {len(batches)} batches") + + async def vectorize_batch(batch_texts: List[str]) -> np.ndarray: + """Vectorize a single batch""" + try: + # Use embedding model's get_embeddings method in thread pool + loop = asyncio.get_event_loop() + vectors = await loop.run_in_executor( + None, + lambda: embedding_model.get_embeddings(batch_texts) + ) + return np.array(vectors) + except Exception as e: + logger.error(f"Batch vectorization failed: {e}", exc_info=True) + # Return zero vectors as fallback + return np.zeros((len(batch_texts), embedding_model.embedding_dim)) + + # Concurrent processing of all batches + tasks = [vectorize_batch(batch) for batch in batches] + batch_results = await asyncio.gather(*tasks, return_exceptions=True) + + # Merge results + all_vectors = [] + for i, result in enumerate(batch_results): + if isinstance(result, Exception): + logger.error(f"Batch {i} processing failed: {result}") + batch_size_actual = len(batches[i]) + zero_vector = np.zeros((batch_size_actual, embedding_model.embedding_dim)) + all_vectors.append(zero_vector) + else: + all_vectors.append(result) + + if all_vectors: + vectors = np.vstack(all_vectors) + logger.info(f"Async vectorization completed, vector shape: {vectors.shape}") + return vectors + else: + logger.error("All batches failed") + return np.zeros((len(texts), embedding_model.embedding_dim)) + diff --git a/backend/utils/prompt_template_utils.py b/backend/utils/prompt_template_utils.py index 4232ef19..a342956b 100644 --- a/backend/utils/prompt_template_utils.py +++ b/backend/utils/prompt_template_utils.py @@ -18,6 +18,7 @@ def get_prompt_template(template_type: str, language: str = LANGUAGE["ZH"], **kw - 'prompt_generate': Prompt generation template - 'agent': Agent template including manager and managed agents - 'knowledge_summary': Knowledge summary template + - 'async_knowledge_summary': Async knowledge summary template - 'analyze_file': File 
analysis template - 'generate_title': Title generation template - 'file_processing_messages': File processing messages template @@ -50,6 +51,10 @@ def get_prompt_template(template_type: str, language: str = LANGUAGE["ZH"], **kw LANGUAGE["ZH"]: 'backend/prompts/knowledge_summary_agent.yaml', LANGUAGE["EN"]: 'backend/prompts/knowledge_summary_agent_en.yaml' }, + 'async_knowledge_summary': { + LANGUAGE["ZH"]: 'backend/prompts/async_knowledge_summary.yaml', + LANGUAGE["EN"]: 'backend/prompts/async_knowledge_summary_en.yaml' + }, 'analyze_file': { LANGUAGE["ZH"]: 'backend/prompts/analyze_file.yaml', LANGUAGE["EN"]: 'backend/prompts/analyze_file_en.yaml' @@ -164,3 +169,16 @@ def get_file_processing_messages_template(language: str = 'zh') -> Dict[str, Any dict: Loaded file processing messages configuration """ return get_prompt_template('file_processing_messages', language) + + +def get_async_knowledge_summary_prompt_template(language: str = 'zh') -> Dict[str, Any]: + """ + Get async knowledge summary prompt template + + Args: + language: Language code ('zh' or 'en') + + Returns: + dict: Loaded async knowledge summary prompt template configuration + """ + return get_prompt_template('async_knowledge_summary', language) diff --git a/sdk/nexent/vector_database/elasticsearch_core.py b/sdk/nexent/vector_database/elasticsearch_core.py index a908935c..12bdf270 100644 --- a/sdk/nexent/vector_database/elasticsearch_core.py +++ b/sdk/nexent/vector_database/elasticsearch_core.py @@ -1,873 +1,873 @@ -import time -import logging -import threading -from typing import List, Dict, Any, Optional -from contextlib import contextmanager -from dataclasses import dataclass -from datetime import datetime, timedelta -from ..core.models.embedding_model import BaseEmbedding -from .utils import format_size, format_timestamp, build_weighted_query -from elasticsearch import Elasticsearch, exceptions - -from ..core.nlp.tokenizer import calculate_term_weights - -logger = logging.getLogger("elasticsearch_core") - -@dataclass -class BulkOperation: - """Bulk operation status tracking""" - index_name: str - operation_id: str - start_time: datetime - expected_duration: timedelta - -class ElasticSearchCore: - """ - Core class for Elasticsearch operations including: - - Index management - - Document insertion with embeddings - - Document deletion - - Accurate text search - - Semantic vector search - - Hybrid search - - Index statistics - """ - - def __init__( - self, - host: Optional[str], - api_key: Optional[str], - verify_certs: bool = False, - ssl_show_warn: bool = False, - ): - """ - Initialize ElasticSearchCore with Elasticsearch client and JinaEmbedding model. 
- - Args: - host: Elasticsearch host URL (defaults to env variable) - api_key: Elasticsearch API key (defaults to env variable) - verify_certs: Whether to verify SSL certificates - ssl_show_warn: Whether to show SSL warnings - """ - # Get credentials from environment if not provided - self.host = host - self.api_key = api_key - - # Initialize Elasticsearch client with HTTPS support - self.client = Elasticsearch( - self.host, - api_key=self.api_key, - verify_certs=verify_certs, - ssl_show_warn=ssl_show_warn, - request_timeout=20, - max_retries=3, # Reduce retries for faster failure detection - retry_on_timeout=True, - retry_on_status=[502, 503, 504], # Retry on these status codes, - ) - - # Initialize embedding model - self._bulk_operations: Dict[str, List[BulkOperation]] = {} - self._settings_lock = threading.Lock() - self._operation_counter = 0 - - # Embedding API limits - self.max_texts_per_batch = 2048 - self.max_tokens_per_text = 8192 - self.max_total_tokens = 100000 - - # ---- INDEX MANAGEMENT ---- - - def create_vector_index(self, index_name: str, embedding_dim: Optional[int] = None) -> bool: - """ - Create a new vector search index with appropriate mappings in a celery-friendly way. - - Args: - index_name: Name of the index to create - embedding_dim: Dimension of the embedding vectors (optional, will use model's dim if not provided) - - Returns: - bool: True if creation was successful - """ - try: - # Use provided embedding_dim or get from model - actual_embedding_dim = embedding_dim or 1024 - - # Use balanced fixed settings to avoid dynamic adjustment - settings = { - "number_of_shards": 1, - "number_of_replicas": 0, - "refresh_interval": "5s", # not too fast, not too slow - "index": { - "max_result_window": 50000, - "translog": { - "durability": "async", - "sync_interval": "5s" - }, - "write": { - "wait_for_active_shards": "1" - }, - # Memory optimization for bulk operations - "merge": { - "policy": { - "max_merge_at_once": 5, - "segments_per_tier": 5 - } - } - } - } - - # Check if index already exists - if self.client.indices.exists(index=index_name): - logger.info(f"Index {index_name} already exists, skipping creation") - self._ensure_index_ready(index_name) - return True - - # Define the mapping with vector field - mappings = { - "properties": { - "id": {"type": "keyword"}, - "title": {"type": "text"}, - "filename": {"type": "keyword"}, - "path_or_url": {"type": "keyword"}, - "language": {"type": "keyword"}, - "author": {"type": "keyword"}, - "date": {"type": "date"}, - "content": {"type": "text"}, - "process_source": {"type": "keyword"}, - "embedding_model_name": {"type": "keyword"}, - "file_size": {"type": "long"}, - "create_time": {"type": "date"}, - "embedding": { - "type": "dense_vector", - "dims": actual_embedding_dim, - "index": "true", - "similarity": "cosine", - }, - } - } - - # Create the index with the defined mappings - self.client.indices.create( - index=index_name, - mappings=mappings, - settings=settings, - wait_for_active_shards="1" - ) - - # Force refresh to ensure visibility - self._force_refresh_with_retry(index_name) - self._ensure_index_ready(index_name) - - logger.info(f"Successfully created index: {index_name}") - return True - - except exceptions.RequestError as e: - # Handle the case where index already exists (error 400) - if "resource_already_exists_exception" in str(e): - logger.info(f"Index {index_name} already exists, skipping creation") - self._ensure_index_ready(index_name) - return True - logger.error(f"Error creating index: {str(e)}") - return 
False - except Exception as e: - logger.error(f"Error creating index: {str(e)}") - return False - - def _force_refresh_with_retry(self, index_name: str, max_retries: int = 3) -> bool: - """ - Force refresh with retry - synchronous version - """ - for attempt in range(max_retries): - try: - self.client.indices.refresh(index=index_name) - return True - except Exception as e: - if attempt < max_retries - 1: - time.sleep(0.5 * (attempt + 1)) - continue - logger.error(f"Failed to refresh index {index_name}: {e}") - return False - return False - - def _ensure_index_ready(self, index_name: str, timeout: int = 10) -> bool: - """ - Ensure index is ready, avoid 503 error - synchronous version - """ - start_time = time.time() - - while time.time() - start_time < timeout: - try: - # Check cluster health - health = self.client.cluster.health( - index=index_name, - wait_for_status="yellow", - timeout="1s" - ) - - if health["status"] in ["green", "yellow"]: - # Double check: try simple query - self.client.search( - index=index_name, - body={"query": {"match_all": {}}, "size": 0} - ) - return True - - except Exception as e: - time.sleep(0.1) - - logger.warning(f"Index {index_name} may not be fully ready after {timeout}s") - return False - - @contextmanager - def bulk_operation_context(self, index_name: str, estimated_duration: int = 60): - """ - Celery-friendly context manager - using threading.Lock - """ - operation_id = f"bulk_{self._operation_counter}_{threading.current_thread().name}" - self._operation_counter += 1 - - operation = BulkOperation( - index_name=index_name, - operation_id=operation_id, - start_time=datetime.now(), - expected_duration=timedelta(seconds=estimated_duration) - ) - - with self._settings_lock: - # Record current operation - if index_name not in self._bulk_operations: - self._bulk_operations[index_name] = [] - self._bulk_operations[index_name].append(operation) - - # If this is the first bulk operation, adjust settings - if len(self._bulk_operations[index_name]) == 1: - self._apply_bulk_settings(index_name) - - try: - yield operation_id - finally: - with self._settings_lock: - # Remove operation record - self._bulk_operations[index_name] = [ - op for op in self._bulk_operations[index_name] - if op.operation_id != operation_id - ] - - # If there are no other bulk operations, restore settings - if not self._bulk_operations[index_name]: - self._restore_normal_settings(index_name) - del self._bulk_operations[index_name] - - def _apply_bulk_settings(self, index_name: str): - """Apply bulk operation optimization settings""" - try: - self.client.indices.put_settings( - index=index_name, - body={ - "refresh_interval": "30s", - "translog.durability": "async", - "translog.sync_interval": "10s" - } - ) - logger.info(f"Applied bulk settings to {index_name}") - except Exception as e: - logger.warning(f"Failed to apply bulk settings: {e}") - - def _restore_normal_settings(self, index_name: str): - """Restore normal settings""" - try: - self.client.indices.put_settings( - index=index_name, - body={ - "refresh_interval": "5s", - "translog.durability": "request" - } - ) - # Refresh after restoration - self._force_refresh_with_retry(index_name) - logger.info(f"Restored normal settings for {index_name}") - except Exception as e: - logger.warning(f"Failed to restore settings: {e}") - - def delete_index(self, index_name: str) -> bool: - """ - Delete an entire index - - Args: - index_name: Name of the index to delete - - Returns: - bool: True if deletion was successful - """ - try: - 
self.client.indices.delete(index=index_name) - logger.info(f"Successfully deleted the index: {index_name}") - return True - except exceptions.NotFoundError: - logger.info(f"Index {index_name} not found") - return False - except Exception as e: - logger.error(f"Error deleting index: {str(e)}") - return False - - def get_user_indices(self, index_pattern: str = "*") -> List[str]: - """ - Get list of user created indices (excluding system indices) - - Args: - index_pattern: Pattern to match index names - - Returns: - List of index names - """ - try: - indices = self.client.indices.get_alias(index=index_pattern) - # Filter out system indices (starting with '.') - return [index_name for index_name in indices.keys() if not index_name.startswith('.')] - except Exception as e: - logger.error(f"Error getting user indices: {str(e)}") - return [] - - # ---- DOCUMENT OPERATIONS ---- - - def index_documents( - self, - index_name: str, - embedding_model: BaseEmbedding, - documents: List[Dict[str, Any]], - batch_size: int = 2048, - content_field: str = "content" - ) -> int: - """ - Smart batch insertion - automatically selecting strategy based on data size - - Args: - index_name: Name of the index to add documents to - embedding_model: Model used to generate embeddings for documents - documents: List of document dictionaries - batch_size: Number of documents to process at once - content_field: Field to use for generating embeddings - - Returns: - int: Number of documents successfully indexed - """ - logger.info(f"Indexing {len(documents)} chunks to {index_name}") - - # Handle empty documents list - if not documents: - return 0 - - # Smart strategy selection - total_docs = len(documents) - if total_docs < 100: - # Small data: direct insertion, using wait_for refresh - return self._small_batch_insert(index_name, documents, content_field, embedding_model) - else: - # Large data: using context manager - estimated_duration = max(60, total_docs // 100) - with self.bulk_operation_context(index_name, estimated_duration): - return self._large_batch_insert(index_name, documents, batch_size, content_field, embedding_model) - - def _small_batch_insert(self, index_name: str, documents: List[Dict[str, Any]], content_field: str, embedding_model:BaseEmbedding) -> int: - """Small batch insertion: real-time""" - try: - # Preprocess documents - processed_docs = self._preprocess_documents(documents, content_field) - - # Get embeddings - inputs = [doc[content_field] for doc in processed_docs] - embeddings = embedding_model.get_embeddings(inputs) - - # Prepare bulk operations - operations = [] - for doc, embedding in zip(processed_docs, embeddings): - operations.append({"index": {"_index": index_name}}) - doc["embedding"] = embedding - if "embedding_model_name" not in doc: - doc["embedding_model_name"] = embedding_model.embedding_model_name - operations.append(doc) - - # Execute bulk insertion, wait for refresh to complete - response = self.client.bulk( - index=index_name, - operations=operations, - refresh='wait_for' - ) - - # Handle errors - self._handle_bulk_errors(response) - - logger.info(f"Small batch insert completed: {len(documents)} chunks indexed.") - return len(documents) - - except Exception as e: - logger.error(f"Small batch insert failed: {e}") - return 0 - - def _large_batch_insert(self, index_name: str, documents: List[Dict[str, Any]], batch_size: int, content_field: str, embedding_model: BaseEmbedding) -> int: - """ - Large batch insertion with sub-batching for embedding API. 
- Splits large document batches into smaller chunks to respect embedding API limits before bulk inserting into Elasticsearch. - """ - try: - processed_docs = self._preprocess_documents(documents, content_field) - total_indexed = 0 - total_docs = len(processed_docs) - es_total_batches = (total_docs + batch_size - 1) // batch_size - - for i in range(0, total_docs, batch_size): - es_batch = processed_docs[i:i + batch_size] - es_batch_num = i // batch_size + 1 - - # Store documents and their embeddings for this Elasticsearch batch - doc_embedding_pairs = [] - - # Sub-batch for embedding API - embedding_batch_size = self.max_texts_per_batch - for j in range(0, len(es_batch), embedding_batch_size): - embedding_sub_batch = es_batch[j:j + embedding_batch_size] - - try: - inputs = [doc[content_field] for doc in embedding_sub_batch] - embeddings = embedding_model.get_embeddings(inputs) - - for doc, embedding in zip(embedding_sub_batch, embeddings): - doc_embedding_pairs.append((doc, embedding)) - - except Exception as e: - logger.error(f"Embedding API error: {e}, ES batch num: {es_batch_num}, sub-batch start: {j}, size: {len(embedding_sub_batch)}") - continue - - # Perform a single bulk insert for the entire Elasticsearch batch - if not doc_embedding_pairs: - logger.warning(f"No documents with embeddings to index for ES batch {es_batch_num}") - continue - - operations = [] - for doc, embedding in doc_embedding_pairs: - operations.append({"index": {"_index": index_name}}) - doc["embedding"] = embedding - if "embedding_model_name" not in doc: - doc["embedding_model_name"] = getattr(embedding_model, 'embedding_model_name', 'unknown') - operations.append(doc) - - try: - response = self.client.bulk( - index=index_name, - operations=operations, - refresh=False - ) - self._handle_bulk_errors(response) - total_indexed += len(doc_embedding_pairs) - logger.info(f"Processed ES batch {es_batch_num}/{es_total_batches}, indexed {len(doc_embedding_pairs)} documents.") - - except Exception as e: - logger.error(f"Bulk insert error: {e}, ES batch num: {es_batch_num}") - continue - - if es_batch_num % 10 == 0: - time.sleep(0.1) - - self._force_refresh_with_retry(index_name) - logger.info(f"Large batch insert completed: {total_indexed} chunks indexed.") - return total_indexed - except Exception as e: - logger.error(f"Large batch insert failed: {e}") - return 0 - - def _preprocess_documents(self, documents: List[Dict[str, Any]], content_field: str) -> List[Dict[str, Any]]: - """Ensure all documents have the required fields and set default values""" - current_time = time.strftime("%Y-%m-%dT%H:%M:%S", time.gmtime()) - current_date = time.strftime('%Y-%m-%d', time.gmtime()) - - processed_docs = [] - for doc in documents: - # Create a copy of the document to avoid modifying the original data - doc_copy = doc.copy() - - # Set create_time if not present - if not doc_copy.get("create_time"): - doc_copy["create_time"] = current_time - - if not doc_copy.get("date"): - doc_copy["date"] = current_date - - # Ensure file_size is present (default to 0 if not provided) - if not doc_copy.get("file_size"): - logger.warning(f"File size not found in {doc_copy}") - doc_copy["file_size"] = 0 - - # Ensure process_source is present - if not doc_copy.get("process_source"): - doc_copy["process_source"] = "Unstructured" - - # Ensure all documents have an ID - if not doc_copy.get("id"): - doc_copy["id"] = f"{int(time.time())}_{hash(doc_copy[content_field])}"[:20] - - processed_docs.append(doc_copy) - - return processed_docs - - def 
_handle_bulk_errors(self, response: Dict[str, Any]) -> None: - """Handle bulk operation errors""" - if response.get('errors'): - for item in response['items']: - if 'error' in item.get('index', {}): - error_info = item['index']['error'] - error_type = error_info.get('type') - error_reason = error_info.get('reason') - error_cause = error_info.get('caused_by', {}) - - if error_type == 'version_conflict_engine_exception': - # ignore version conflict - continue - else: - logger.error(f"FATAL ERROR {error_type}: {error_reason}") - if error_cause: - logger.error(f"Caused By: {error_cause.get('type')}: {error_cause.get('reason')}") - - def delete_documents_by_path_or_url(self, index_name: str, path_or_url: str) -> int: - """ - Delete documents based on their path_or_url field - - Args: - index_name: Name of the index to delete documents from - path_or_url: The URL or path of the documents to delete - - Returns: - int: Number of documents deleted - """ - try: - result = self.client.delete_by_query( - index=index_name, - body={ - "query": { - "term": { - "path_or_url": path_or_url - } - } - } - ) - logger.info(f"Successfully deleted {result['deleted']} documents with path_or_url: {path_or_url} from index: {index_name}") - return result['deleted'] - except Exception as e: - logger.error(f"Error deleting documents: {str(e)}") - return 0 - - # ---- SEARCH OPERATIONS ---- - - def accurate_search(self, index_names: List[str], query_text: str, top_k: int = 5) -> List[Dict[str, Any]]: - """ - Search for documents using fuzzy text matching across multiple indices. - - Args: - index_names: Name of the index to search in - query_text: The text query to search for - top_k: Number of results to return - - Returns: - List of search results with scores and document content - """ - # Join index names for multi-index search - index_pattern = ",".join(index_names) - - weights = calculate_term_weights(query_text) - - # Prepare the search query using match query for fuzzy matching - search_query = build_weighted_query(query_text, weights) | { - "size": top_k, - "_source": { - "excludes": ["embedding"] - } - } - - # Execute the search across multiple indices - return self.exec_query(index_pattern, search_query) - - def exec_query(self, index_pattern, search_query): - response = self.client.search( - index=index_pattern, - body=search_query - ) - # Process and return results - results = [] - for hit in response["hits"]["hits"]: - results.append({ - "score": hit["_score"], - "document": hit["_source"], - "index": hit["_index"] # Include source index in results - }) - return results - - def semantic_search(self, index_names: List[str], query_text: str, embedding_model: BaseEmbedding, top_k: int = 5) -> List[Dict[str, Any]]: - """ - Search for similar documents using vector similarity across multiple indices. 
- - Args: - index_names: List of index names to search in - query_text: The text query to search for - embedding_model: The embedding model to use - top_k: Number of results to return - - Returns: - List of search results with scores and document content - """ - # Join index names for multi-index search - index_pattern = ",".join(index_names) - - # Get query embedding - query_embedding = embedding_model.get_embeddings(query_text)[0] - - # Prepare the search query - search_query = { - "knn": { - "field": "embedding", - "query_vector": query_embedding, - "k": top_k, - "num_candidates": top_k * 2, - }, - "size": top_k, - "_source": { - "excludes": ["embedding"] - } - } - - # Execute the search across multiple indices - return self.exec_query(index_pattern, search_query) - - def hybrid_search( - self, - index_names: List[str], - query_text: str, - embedding_model: BaseEmbedding, - top_k: int = 5, - weight_accurate: float = 0.3 - ) -> List[Dict[str, Any]]: - """ - Hybrid search method, combining accurate matching and semantic search results across multiple indices. - - Args: - index_names: List of index names to search in - query_text: The text query to search for - embedding_model: The embedding model to use - top_k: Number of results to return - weight_accurate: The weight of the accurate matching score (0-1), the semantic search weight is 1-weight_accurate - - Returns: - List of search results sorted by combined score - """ - # Get results from both searches - accurate_results = self.accurate_search(index_names, query_text, top_k=top_k) - semantic_results = self.semantic_search(index_names, query_text, embedding_model=embedding_model, top_k=top_k) - - # Create a mapping from document ID to results - combined_results = {} - - # Process accurate matching results - for result in accurate_results: - try: - doc_id = result['document']['id'] - combined_results[doc_id] = { - 'document': result['document'], - 'accurate_score': result.get('score', 0), - 'semantic_score': 0, - 'index': result['index'] # Keep track of source index - } - except KeyError as e: - logger.warning(f"Warning: Missing required field in accurate result: {e}") - continue - - # Process semantic search results - for result in semantic_results: - try: - doc_id = result['document']['id'] - if doc_id in combined_results: - combined_results[doc_id]['semantic_score'] = result.get('score', 0) - else: - combined_results[doc_id] = { - 'document': result['document'], - 'accurate_score': 0, - 'semantic_score': result.get('score', 0), - 'index': result['index'] # Keep track of source index - } - except KeyError as e: - logger.warning(f"Warning: Missing required field in semantic result: {e}") - continue - - # Calculate maximum scores - max_accurate = max([r.get('score', 0) for r in accurate_results]) if accurate_results else 1 - max_semantic = max([r.get('score', 0) for r in semantic_results]) if semantic_results else 1 - - # Calculate combined scores and sort - results = [] - for doc_id, result in combined_results.items(): - try: - # Get scores safely - accurate_score = result.get('accurate_score', 0) - semantic_score = result.get('semantic_score', 0) - - # Normalize scores - normalized_accurate = accurate_score / max_accurate if max_accurate > 0 else 0 - normalized_semantic = semantic_score / max_semantic if max_semantic > 0 else 0 - - # Calculate weighted combined score - combined_score = (weight_accurate * normalized_accurate + - (1 - weight_accurate) * normalized_semantic) - - results.append({ - 'score': combined_score, - 'document': 
result['document'], - 'index': result['index'], # Include source index in results - 'scores': { - 'accurate': normalized_accurate, - 'semantic': normalized_semantic - } - }) - except KeyError as e: - logger.warning(f"Warning: Error processing result for doc_id {doc_id}: {e}") - continue - - # Sort by combined score and return top k results - results.sort(key=lambda x: x['score'], reverse=True) - return results[:top_k] - - # ---- STATISTICS AND MONITORING ---- - def get_file_list_with_details(self, index_name: str) -> List[Dict[str, Any]]: - """ - Get a list of unique path_or_url values with their file_size and create_time - - Args: - index_name: Name of the index to query - - Returns: - List of dictionaries with path_or_url, file_size, and create_time - """ - agg_query = { - "size": 0, - "aggs": { - "unique_sources": { - "terms": { - "field": "path_or_url", - "size": 1000 # Limit to 1000 files for performance - }, - "aggs": { - "file_sample": { - "top_hits": { - "size": 1, - "_source": ["path_or_url", "file_size", "create_time", "filename"] - } - } - } - } - } - } - - try: - result = self.client.search( - index=index_name, - body=agg_query - ) - - file_list = [] - for bucket in result['aggregations']['unique_sources']['buckets']: - source = bucket['file_sample']['hits']['hits'][0]['_source'] - file_info = { - "path_or_url": source["path_or_url"], - "filename": source.get("filename", ""), - "file_size": source.get("file_size", 0), - "create_time": source.get("create_time", None) - } - file_list.append(file_info) - - return file_list - except Exception as e: - logger.error(f"Error getting file list: {str(e)}") - return [] - - def get_index_mapping(self, index_names: List[str]) -> Dict[str, List[str]]: - """Get field mappings for multiple indices""" - mappings = {} - for index_name in index_names: - try: - mapping = self.client.indices.get_mapping(index=index_name) - if mapping[index_name].get('mappings') and mapping[index_name]['mappings'].get('properties'): - mappings[index_name] = list(mapping[index_name]['mappings']['properties'].keys()) - else: - mappings[index_name] = [] - except Exception as e: - logger.error(f"Error getting mapping for index {index_name}: {str(e)}") - mappings[index_name] = [] - return mappings - - def get_index_stats(self, index_names: List[str], embedding_dim: Optional[int] = None) -> Dict[str, Dict[str, Dict[str, Any]]]: - """Get formatted statistics for multiple indices""" - all_stats = {} - for index_name in index_names: - try: - stats = self.client.indices.stats(index=index_name) - settings = self.client.indices.get_settings(index=index_name) - - # Merge query - agg_query = { - "size": 0, - "aggs": { - "unique_path_or_url_count": { - "cardinality": { - "field": "path_or_url" - } - }, - "process_sources": { - "terms": { - "field": "process_source", - "size": 10 - } - }, - "embedding_models": { - "terms": { - "field": "embedding_model_name", - "size": 10 - } - } - } - } - - # Execute query - agg_result = self.client.search( - index=index_name, - body=agg_query - ) - - unique_sources_count = agg_result['aggregations']['unique_path_or_url_count']['value'] - process_source = agg_result['aggregations']['process_sources']['buckets'][0]['key'] if agg_result['aggregations']['process_sources']['buckets'] else "" - embedding_model = agg_result['aggregations']['embedding_models']['buckets'][0]['key'] if agg_result['aggregations']['embedding_models']['buckets'] else "" - - index_stats = stats["indices"][index_name]["primaries"] - - # Get creation and update timestamps from 
settings - creation_date = int(settings[index_name]['settings']['index']['creation_date']) - # Update time defaults to creation time if not modified - update_time = creation_date - - all_stats[index_name] = { - "base_info": { - "doc_count": unique_sources_count, - "chunk_count": index_stats["docs"]["count"], - "store_size": format_size(index_stats["store"]["size_in_bytes"]), - "process_source": process_source, - "embedding_model": embedding_model, - "embedding_dim": embedding_dim or 1024, - "creation_date": creation_date, - "update_date": update_time - }, - "search_performance": { - "total_search_count": index_stats["search"]["query_total"], - "hit_count": index_stats["request_cache"]["hit_count"], - } - } - except Exception as e: - logger.error(f"Error getting stats for index {index_name}: {str(e)}") - all_stats[index_name] = {"error": str(e)} - - return all_stats +import time +import logging +import threading +from typing import List, Dict, Any, Optional +from contextlib import contextmanager +from dataclasses import dataclass +from datetime import datetime, timedelta +from ..core.models.embedding_model import BaseEmbedding +from .utils import format_size, format_timestamp, build_weighted_query +from elasticsearch import Elasticsearch, exceptions + +from ..core.nlp.tokenizer import calculate_term_weights + +logger = logging.getLogger("elasticsearch_core") + +@dataclass +class BulkOperation: + """Bulk operation status tracking""" + index_name: str + operation_id: str + start_time: datetime + expected_duration: timedelta + +class ElasticSearchCore: + """ + Core class for Elasticsearch operations including: + - Index management + - Document insertion with embeddings + - Document deletion + - Accurate text search + - Semantic vector search + - Hybrid search + - Index statistics + """ + + def __init__( + self, + host: Optional[str], + api_key: Optional[str], + verify_certs: bool = False, + ssl_show_warn: bool = False, + ): + """ + Initialize ElasticSearchCore with Elasticsearch client and JinaEmbedding model. + + Args: + host: Elasticsearch host URL (defaults to env variable) + api_key: Elasticsearch API key (defaults to env variable) + verify_certs: Whether to verify SSL certificates + ssl_show_warn: Whether to show SSL warnings + """ + # Get credentials from environment if not provided + self.host = host + self.api_key = api_key + + # Initialize Elasticsearch client with HTTPS support + self.client = Elasticsearch( + self.host, + api_key=self.api_key, + verify_certs=verify_certs, + ssl_show_warn=ssl_show_warn, + request_timeout=20, + max_retries=3, # Reduce retries for faster failure detection + retry_on_timeout=True, + retry_on_status=[502, 503, 504], # Retry on these status codes, + ) + + # Initialize embedding model + self._bulk_operations: Dict[str, List[BulkOperation]] = {} + self._settings_lock = threading.Lock() + self._operation_counter = 0 + + # Embedding API limits + self.max_texts_per_batch = 2048 + self.max_tokens_per_text = 8192 + self.max_total_tokens = 100000 + + # ---- INDEX MANAGEMENT ---- + + def create_vector_index(self, index_name: str, embedding_dim: Optional[int] = None) -> bool: + """ + Create a new vector search index with appropriate mappings in a celery-friendly way. 
+ + Args: + index_name: Name of the index to create + embedding_dim: Dimension of the embedding vectors (optional, will use model's dim if not provided) + + Returns: + bool: True if creation was successful + """ + try: + # Use provided embedding_dim or get from model + actual_embedding_dim = embedding_dim or 1024 + + # Use balanced fixed settings to avoid dynamic adjustment + settings = { + "number_of_shards": 1, + "number_of_replicas": 0, + "refresh_interval": "5s", # not too fast, not too slow + "index": { + "max_result_window": 50000, + "translog": { + "durability": "async", + "sync_interval": "5s" + }, + "write": { + "wait_for_active_shards": "1" + }, + # Memory optimization for bulk operations + "merge": { + "policy": { + "max_merge_at_once": 5, + "segments_per_tier": 5 + } + } + } + } + + # Check if index already exists + if self.client.indices.exists(index=index_name): + logger.info(f"Index {index_name} already exists, skipping creation") + self._ensure_index_ready(index_name) + return True + + # Define the mapping with vector field + mappings = { + "properties": { + "id": {"type": "keyword"}, + "title": {"type": "text"}, + "filename": {"type": "keyword"}, + "path_or_url": {"type": "keyword"}, + "language": {"type": "keyword"}, + "author": {"type": "keyword"}, + "date": {"type": "date"}, + "content": {"type": "text"}, + "process_source": {"type": "keyword"}, + "embedding_model_name": {"type": "keyword"}, + "file_size": {"type": "long"}, + "create_time": {"type": "date"}, + "embedding": { + "type": "dense_vector", + "dims": actual_embedding_dim, + "index": "true", + "similarity": "cosine", + }, + } + } + + # Create the index with the defined mappings + self.client.indices.create( + index=index_name, + mappings=mappings, + settings=settings, + wait_for_active_shards="1" + ) + + # Force refresh to ensure visibility + self._force_refresh_with_retry(index_name) + self._ensure_index_ready(index_name) + + logger.info(f"Successfully created index: {index_name}") + return True + + except exceptions.RequestError as e: + # Handle the case where index already exists (error 400) + if "resource_already_exists_exception" in str(e): + logger.info(f"Index {index_name} already exists, skipping creation") + self._ensure_index_ready(index_name) + return True + logger.error(f"Error creating index: {str(e)}") + return False + except Exception as e: + logger.error(f"Error creating index: {str(e)}") + return False + + def _force_refresh_with_retry(self, index_name: str, max_retries: int = 3) -> bool: + """ + Force refresh with retry - synchronous version + """ + for attempt in range(max_retries): + try: + self.client.indices.refresh(index=index_name) + return True + except Exception as e: + if attempt < max_retries - 1: + time.sleep(0.5 * (attempt + 1)) + continue + logger.error(f"Failed to refresh index {index_name}: {e}") + return False + return False + + def _ensure_index_ready(self, index_name: str, timeout: int = 10) -> bool: + """ + Ensure index is ready, avoid 503 error - synchronous version + """ + start_time = time.time() + + while time.time() - start_time < timeout: + try: + # Check cluster health + health = self.client.cluster.health( + index=index_name, + wait_for_status="yellow", + timeout="1s" + ) + + if health["status"] in ["green", "yellow"]: + # Double check: try simple query + self.client.search( + index=index_name, + body={"query": {"match_all": {}}, "size": 0} + ) + return True + + except Exception as e: + time.sleep(0.1) + + logger.warning(f"Index {index_name} may not be fully ready 
after {timeout}s") + return False + + @contextmanager + def bulk_operation_context(self, index_name: str, estimated_duration: int = 60): + """ + Celery-friendly context manager - using threading.Lock + """ + operation_id = f"bulk_{self._operation_counter}_{threading.current_thread().name}" + self._operation_counter += 1 + + operation = BulkOperation( + index_name=index_name, + operation_id=operation_id, + start_time=datetime.now(), + expected_duration=timedelta(seconds=estimated_duration) + ) + + with self._settings_lock: + # Record current operation + if index_name not in self._bulk_operations: + self._bulk_operations[index_name] = [] + self._bulk_operations[index_name].append(operation) + + # If this is the first bulk operation, adjust settings + if len(self._bulk_operations[index_name]) == 1: + self._apply_bulk_settings(index_name) + + try: + yield operation_id + finally: + with self._settings_lock: + # Remove operation record + self._bulk_operations[index_name] = [ + op for op in self._bulk_operations[index_name] + if op.operation_id != operation_id + ] + + # If there are no other bulk operations, restore settings + if not self._bulk_operations[index_name]: + self._restore_normal_settings(index_name) + del self._bulk_operations[index_name] + + def _apply_bulk_settings(self, index_name: str): + """Apply bulk operation optimization settings""" + try: + self.client.indices.put_settings( + index=index_name, + body={ + "refresh_interval": "30s", + "translog.durability": "async", + "translog.sync_interval": "10s" + } + ) + logger.info(f"Applied bulk settings to {index_name}") + except Exception as e: + logger.warning(f"Failed to apply bulk settings: {e}") + + def _restore_normal_settings(self, index_name: str): + """Restore normal settings""" + try: + self.client.indices.put_settings( + index=index_name, + body={ + "refresh_interval": "5s", + "translog.durability": "request" + } + ) + # Refresh after restoration + self._force_refresh_with_retry(index_name) + logger.info(f"Restored normal settings for {index_name}") + except Exception as e: + logger.warning(f"Failed to restore settings: {e}") + + def delete_index(self, index_name: str) -> bool: + """ + Delete an entire index + + Args: + index_name: Name of the index to delete + + Returns: + bool: True if deletion was successful + """ + try: + self.client.indices.delete(index=index_name) + logger.info(f"Successfully deleted the index: {index_name}") + return True + except exceptions.NotFoundError: + logger.info(f"Index {index_name} not found") + return False + except Exception as e: + logger.error(f"Error deleting index: {str(e)}") + return False + + def get_user_indices(self, index_pattern: str = "*") -> List[str]: + """ + Get list of user created indices (excluding system indices) + + Args: + index_pattern: Pattern to match index names + + Returns: + List of index names + """ + try: + indices = self.client.indices.get_alias(index=index_pattern) + # Filter out system indices (starting with '.') + return [index_name for index_name in indices.keys() if not index_name.startswith('.')] + except Exception as e: + logger.error(f"Error getting user indices: {str(e)}") + return [] + + # ---- DOCUMENT OPERATIONS ---- + + def index_documents( + self, + index_name: str, + embedding_model: BaseEmbedding, + documents: List[Dict[str, Any]], + batch_size: int = 2048, + content_field: str = "content" + ) -> int: + """ + Smart batch insertion - automatically selecting strategy based on data size + + Args: + index_name: Name of the index to add documents to + 
embedding_model: Model used to generate embeddings for documents + documents: List of document dictionaries + batch_size: Number of documents to process at once + content_field: Field to use for generating embeddings + + Returns: + int: Number of documents successfully indexed + """ + logger.info(f"Indexing {len(documents)} chunks to {index_name}") + + # Handle empty documents list + if not documents: + return 0 + + # Smart strategy selection + total_docs = len(documents) + if total_docs < 100: + # Small data: direct insertion, using wait_for refresh + return self._small_batch_insert(index_name, documents, content_field, embedding_model) + else: + # Large data: using context manager + estimated_duration = max(60, total_docs // 100) + with self.bulk_operation_context(index_name, estimated_duration): + return self._large_batch_insert(index_name, documents, batch_size, content_field, embedding_model) + + def _small_batch_insert(self, index_name: str, documents: List[Dict[str, Any]], content_field: str, embedding_model:BaseEmbedding) -> int: + """Small batch insertion: real-time""" + try: + # Preprocess documents + processed_docs = self._preprocess_documents(documents, content_field) + + # Get embeddings + inputs = [doc[content_field] for doc in processed_docs] + embeddings = embedding_model.get_embeddings(inputs) + + # Prepare bulk operations + operations = [] + for doc, embedding in zip(processed_docs, embeddings): + operations.append({"index": {"_index": index_name}}) + doc["embedding"] = embedding + if "embedding_model_name" not in doc: + doc["embedding_model_name"] = embedding_model.embedding_model_name + operations.append(doc) + + # Execute bulk insertion, wait for refresh to complete + response = self.client.bulk( + index=index_name, + operations=operations, + refresh='wait_for' + ) + + # Handle errors + self._handle_bulk_errors(response) + + logger.info(f"Small batch insert completed: {len(documents)} chunks indexed.") + return len(documents) + + except Exception as e: + logger.error(f"Small batch insert failed: {e}") + return 0 + + def _large_batch_insert(self, index_name: str, documents: List[Dict[str, Any]], batch_size: int, content_field: str, embedding_model: BaseEmbedding) -> int: + """ + Large batch insertion with sub-batching for embedding API. + Splits large document batches into smaller chunks to respect embedding API limits before bulk inserting into Elasticsearch. 
+ """ + try: + processed_docs = self._preprocess_documents(documents, content_field) + total_indexed = 0 + total_docs = len(processed_docs) + es_total_batches = (total_docs + batch_size - 1) // batch_size + + for i in range(0, total_docs, batch_size): + es_batch = processed_docs[i:i + batch_size] + es_batch_num = i // batch_size + 1 + + # Store documents and their embeddings for this Elasticsearch batch + doc_embedding_pairs = [] + + # Sub-batch for embedding API + embedding_batch_size = self.max_texts_per_batch + for j in range(0, len(es_batch), embedding_batch_size): + embedding_sub_batch = es_batch[j:j + embedding_batch_size] + + try: + inputs = [doc[content_field] for doc in embedding_sub_batch] + embeddings = embedding_model.get_embeddings(inputs) + + for doc, embedding in zip(embedding_sub_batch, embeddings): + doc_embedding_pairs.append((doc, embedding)) + + except Exception as e: + logger.error(f"Embedding API error: {e}, ES batch num: {es_batch_num}, sub-batch start: {j}, size: {len(embedding_sub_batch)}") + continue + + # Perform a single bulk insert for the entire Elasticsearch batch + if not doc_embedding_pairs: + logger.warning(f"No documents with embeddings to index for ES batch {es_batch_num}") + continue + + operations = [] + for doc, embedding in doc_embedding_pairs: + operations.append({"index": {"_index": index_name}}) + doc["embedding"] = embedding + if "embedding_model_name" not in doc: + doc["embedding_model_name"] = getattr(embedding_model, 'embedding_model_name', 'unknown') + operations.append(doc) + + try: + response = self.client.bulk( + index=index_name, + operations=operations, + refresh=False + ) + self._handle_bulk_errors(response) + total_indexed += len(doc_embedding_pairs) + logger.info(f"Processed ES batch {es_batch_num}/{es_total_batches}, indexed {len(doc_embedding_pairs)} documents.") + + except Exception as e: + logger.error(f"Bulk insert error: {e}, ES batch num: {es_batch_num}") + continue + + if es_batch_num % 10 == 0: + time.sleep(0.1) + + self._force_refresh_with_retry(index_name) + logger.info(f"Large batch insert completed: {total_indexed} chunks indexed.") + return total_indexed + except Exception as e: + logger.error(f"Large batch insert failed: {e}") + return 0 + + def _preprocess_documents(self, documents: List[Dict[str, Any]], content_field: str) -> List[Dict[str, Any]]: + """Ensure all documents have the required fields and set default values""" + current_time = time.strftime("%Y-%m-%dT%H:%M:%S", time.gmtime()) + current_date = time.strftime('%Y-%m-%d', time.gmtime()) + + processed_docs = [] + for doc in documents: + # Create a copy of the document to avoid modifying the original data + doc_copy = doc.copy() + + # Set create_time if not present + if not doc_copy.get("create_time"): + doc_copy["create_time"] = current_time + + if not doc_copy.get("date"): + doc_copy["date"] = current_date + + # Ensure file_size is present (default to 0 if not provided) + if not doc_copy.get("file_size"): + logger.warning(f"File size not found in {doc_copy}") + doc_copy["file_size"] = 0 + + # Ensure process_source is present + if not doc_copy.get("process_source"): + doc_copy["process_source"] = "Unstructured" + + # Ensure all documents have an ID + if not doc_copy.get("id"): + doc_copy["id"] = f"{int(time.time())}_{hash(doc_copy[content_field])}"[:20] + + processed_docs.append(doc_copy) + + return processed_docs + + def _handle_bulk_errors(self, response: Dict[str, Any]) -> None: + """Handle bulk operation errors""" + if response.get('errors'): + for item in 
response['items']: + if 'error' in item.get('index', {}): + error_info = item['index']['error'] + error_type = error_info.get('type') + error_reason = error_info.get('reason') + error_cause = error_info.get('caused_by', {}) + + if error_type == 'version_conflict_engine_exception': + # ignore version conflict + continue + else: + logger.error(f"FATAL ERROR {error_type}: {error_reason}") + if error_cause: + logger.error(f"Caused By: {error_cause.get('type')}: {error_cause.get('reason')}") + + def delete_documents_by_path_or_url(self, index_name: str, path_or_url: str) -> int: + """ + Delete documents based on their path_or_url field + + Args: + index_name: Name of the index to delete documents from + path_or_url: The URL or path of the documents to delete + + Returns: + int: Number of documents deleted + """ + try: + result = self.client.delete_by_query( + index=index_name, + body={ + "query": { + "term": { + "path_or_url": path_or_url + } + } + } + ) + logger.info(f"Successfully deleted {result['deleted']} documents with path_or_url: {path_or_url} from index: {index_name}") + return result['deleted'] + except Exception as e: + logger.error(f"Error deleting documents: {str(e)}") + return 0 + + # ---- SEARCH OPERATIONS ---- + + def accurate_search(self, index_names: List[str], query_text: str, top_k: int = 5) -> List[Dict[str, Any]]: + """ + Search for documents using fuzzy text matching across multiple indices. + + Args: + index_names: Name of the index to search in + query_text: The text query to search for + top_k: Number of results to return + + Returns: + List of search results with scores and document content + """ + # Join index names for multi-index search + index_pattern = ",".join(index_names) + + weights = calculate_term_weights(query_text) + + # Prepare the search query using match query for fuzzy matching + search_query = build_weighted_query(query_text, weights) | { + "size": top_k, + "_source": { + "excludes": ["embedding"] + } + } + + # Execute the search across multiple indices + return self.exec_query(index_pattern, search_query) + + def exec_query(self, index_pattern, search_query): + response = self.client.search( + index=index_pattern, + body=search_query + ) + # Process and return results + results = [] + for hit in response["hits"]["hits"]: + results.append({ + "score": hit["_score"], + "document": hit["_source"], + "index": hit["_index"] # Include source index in results + }) + return results + + def semantic_search(self, index_names: List[str], query_text: str, embedding_model: BaseEmbedding, top_k: int = 5) -> List[Dict[str, Any]]: + """ + Search for similar documents using vector similarity across multiple indices. 
+ + Args: + index_names: List of index names to search in + query_text: The text query to search for + embedding_model: The embedding model to use + top_k: Number of results to return + + Returns: + List of search results with scores and document content + """ + # Join index names for multi-index search + index_pattern = ",".join(index_names) + + # Get query embedding + query_embedding = embedding_model.get_embeddings(query_text)[0] + + # Prepare the search query + search_query = { + "knn": { + "field": "embedding", + "query_vector": query_embedding, + "k": top_k, + "num_candidates": top_k * 2, + }, + "size": top_k, + "_source": { + "excludes": ["embedding"] + } + } + + # Execute the search across multiple indices + return self.exec_query(index_pattern, search_query) + + def hybrid_search( + self, + index_names: List[str], + query_text: str, + embedding_model: BaseEmbedding, + top_k: int = 5, + weight_accurate: float = 0.3 + ) -> List[Dict[str, Any]]: + """ + Hybrid search method, combining accurate matching and semantic search results across multiple indices. + + Args: + index_names: List of index names to search in + query_text: The text query to search for + embedding_model: The embedding model to use + top_k: Number of results to return + weight_accurate: The weight of the accurate matching score (0-1), the semantic search weight is 1-weight_accurate + + Returns: + List of search results sorted by combined score + """ + # Get results from both searches + accurate_results = self.accurate_search(index_names, query_text, top_k=top_k) + semantic_results = self.semantic_search(index_names, query_text, embedding_model=embedding_model, top_k=top_k) + + # Create a mapping from document ID to results + combined_results = {} + + # Process accurate matching results + for result in accurate_results: + try: + doc_id = result['document']['id'] + combined_results[doc_id] = { + 'document': result['document'], + 'accurate_score': result.get('score', 0), + 'semantic_score': 0, + 'index': result['index'] # Keep track of source index + } + except KeyError as e: + logger.warning(f"Warning: Missing required field in accurate result: {e}") + continue + + # Process semantic search results + for result in semantic_results: + try: + doc_id = result['document']['id'] + if doc_id in combined_results: + combined_results[doc_id]['semantic_score'] = result.get('score', 0) + else: + combined_results[doc_id] = { + 'document': result['document'], + 'accurate_score': 0, + 'semantic_score': result.get('score', 0), + 'index': result['index'] # Keep track of source index + } + except KeyError as e: + logger.warning(f"Warning: Missing required field in semantic result: {e}") + continue + + # Calculate maximum scores + max_accurate = max([r.get('score', 0) for r in accurate_results]) if accurate_results else 1 + max_semantic = max([r.get('score', 0) for r in semantic_results]) if semantic_results else 1 + + # Calculate combined scores and sort + results = [] + for doc_id, result in combined_results.items(): + try: + # Get scores safely + accurate_score = result.get('accurate_score', 0) + semantic_score = result.get('semantic_score', 0) + + # Normalize scores + normalized_accurate = accurate_score / max_accurate if max_accurate > 0 else 0 + normalized_semantic = semantic_score / max_semantic if max_semantic > 0 else 0 + + # Calculate weighted combined score + combined_score = (weight_accurate * normalized_accurate + + (1 - weight_accurate) * normalized_semantic) + + results.append({ + 'score': combined_score, + 'document': 
result['document'], + 'index': result['index'], # Include source index in results + 'scores': { + 'accurate': normalized_accurate, + 'semantic': normalized_semantic + } + }) + except KeyError as e: + logger.warning(f"Warning: Error processing result for doc_id {doc_id}: {e}") + continue + + # Sort by combined score and return top k results + results.sort(key=lambda x: x['score'], reverse=True) + return results[:top_k] + + # ---- STATISTICS AND MONITORING ---- + def get_file_list_with_details(self, index_name: str) -> List[Dict[str, Any]]: + """ + Get a list of unique path_or_url values with their file_size and create_time + + Args: + index_name: Name of the index to query + + Returns: + List of dictionaries with path_or_url, file_size, and create_time + """ + agg_query = { + "size": 0, + "aggs": { + "unique_sources": { + "terms": { + "field": "path_or_url", + "size": 1000 # Limit to 1000 files for performance + }, + "aggs": { + "file_sample": { + "top_hits": { + "size": 1, + "_source": ["path_or_url", "file_size", "create_time", "filename"] + } + } + } + } + } + } + + try: + result = self.client.search( + index=index_name, + body=agg_query + ) + + file_list = [] + for bucket in result['aggregations']['unique_sources']['buckets']: + source = bucket['file_sample']['hits']['hits'][0]['_source'] + file_info = { + "path_or_url": source["path_or_url"], + "filename": source.get("filename", ""), + "file_size": source.get("file_size", 0), + "create_time": source.get("create_time", None) + } + file_list.append(file_info) + + return file_list + except Exception as e: + logger.error(f"Error getting file list: {str(e)}") + return [] + + def get_index_mapping(self, index_names: List[str]) -> Dict[str, List[str]]: + """Get field mappings for multiple indices""" + mappings = {} + for index_name in index_names: + try: + mapping = self.client.indices.get_mapping(index=index_name) + if mapping[index_name].get('mappings') and mapping[index_name]['mappings'].get('properties'): + mappings[index_name] = list(mapping[index_name]['mappings']['properties'].keys()) + else: + mappings[index_name] = [] + except Exception as e: + logger.error(f"Error getting mapping for index {index_name}: {str(e)}") + mappings[index_name] = [] + return mappings + + def get_index_stats(self, index_names: List[str], embedding_dim: Optional[int] = None) -> Dict[str, Dict[str, Dict[str, Any]]]: + """Get formatted statistics for multiple indices""" + all_stats = {} + for index_name in index_names: + try: + stats = self.client.indices.stats(index=index_name) + settings = self.client.indices.get_settings(index=index_name) + + # Merge query + agg_query = { + "size": 0, + "aggs": { + "unique_path_or_url_count": { + "cardinality": { + "field": "path_or_url" + } + }, + "process_sources": { + "terms": { + "field": "process_source", + "size": 10 + } + }, + "embedding_models": { + "terms": { + "field": "embedding_model_name", + "size": 10 + } + } + } + } + + # Execute query + agg_result = self.client.search( + index=index_name, + body=agg_query + ) + + unique_sources_count = agg_result['aggregations']['unique_path_or_url_count']['value'] + process_source = agg_result['aggregations']['process_sources']['buckets'][0]['key'] if agg_result['aggregations']['process_sources']['buckets'] else "" + embedding_model = agg_result['aggregations']['embedding_models']['buckets'][0]['key'] if agg_result['aggregations']['embedding_models']['buckets'] else "" + + index_stats = stats["indices"][index_name]["primaries"] + + # Get creation and update timestamps from 
settings + creation_date = int(settings[index_name]['settings']['index']['creation_date']) + # Update time defaults to creation time if not modified + update_time = creation_date + + all_stats[index_name] = { + "base_info": { + "doc_count": unique_sources_count, + "chunk_count": index_stats["docs"]["count"], + "store_size": format_size(index_stats["store"]["size_in_bytes"]), + "process_source": process_source, + "embedding_model": embedding_model, + "embedding_dim": embedding_dim or 1024, + "creation_date": creation_date, + "update_date": update_time + }, + "search_performance": { + "total_search_count": index_stats["search"]["query_total"], + "hit_count": index_stats["request_cache"]["hit_count"], + } + } + except Exception as e: + logger.error(f"Error getting stats for index {index_name}: {str(e)}") + all_stats[index_name] = {"error": str(e)} + + return all_stats \ No newline at end of file diff --git a/test/backend/services/test_elasticsearch_service.py b/test/backend/services/test_elasticsearch_service.py index 92ac6bbc..6e4b04c5 100644 --- a/test/backend/services/test_elasticsearch_service.py +++ b/test/backend/services/test_elasticsearch_service.py @@ -1171,67 +1171,6 @@ def test_health_check_unhealthy(self): self.assertIn("Health check failed", str(context.exception)) - @patch('backend.services.elasticsearch_service.calculate_term_weights') - @patch('database.model_management_db.get_model_by_model_id') - def test_summary_index_name(self, mock_get_model_by_model_id, mock_calculate_weights): - """ - Test generating a summary for an index. - - This test verifies that: - 1. Random documents are retrieved for summarization - 2. Term weights are calculated to identify important keywords - 3. The summary generation stream is properly initialized - 4. A StreamingResponse object is returned for streaming the summary tokens - """ - # Setup - mock_calculate_weights.return_value = { - "keyword1": 0.8, "keyword2": 0.6} - mock_get_model_by_model_id.return_value = { - 'api_key': 'test_api_key', - 'base_url': 'https://api.test.com', - 'model_name': 'test-model', - 'model_repo': 'test-repo' - } - - # Mock get_random_documents - with patch.object(ElasticSearchService, 'get_random_documents') as mock_get_docs: - mock_get_docs.return_value = { - "documents": [ - {"title": "Doc1", "filename": "file1.txt", "content": "Content1"}, - {"title": "Doc2", "filename": "file2.txt", "content": "Content2"} - ] - } - - # Execute - async def run_test(): - result = await self.es_service.summary_index_name( - index_name="test_index", - batch_size=1000, - es_core=self.mock_es_core, - language='en', - model_id=1, - tenant_id="test_tenant" - ) - - # Consume part of the stream to trigger the generator function - generator = result.body_iterator - # Get at least one item from the generator to trigger execution - try: - async for item in generator: - break # Just get one item to trigger execution - except StopAsyncIteration: - pass - - return result - - result = asyncio.run(run_test()) - - # Assert - self.assertIsInstance(result, StreamingResponse) - mock_get_docs.assert_called_once() - mock_calculate_weights.assert_called_once() - mock_get_model_by_model_id.assert_called_once_with(1, "test_tenant") - def test_get_random_documents(self): """ Test retrieving random documents from an index. 
diff --git a/test/backend/utils/test_async_knowledge_summary.py b/test/backend/utils/test_async_knowledge_summary.py new file mode 100644 index 00000000..a68e2216 --- /dev/null +++ b/test/backend/utils/test_async_knowledge_summary.py @@ -0,0 +1,395 @@ +""" +Test cases for async knowledge summary functionality +""" + +import sys +import pytest +import asyncio +from unittest.mock import Mock, patch, AsyncMock, MagicMock +import numpy as np + +# Mock external dependencies before importing backend modules +sys.modules['boto3'] = MagicMock() + +# Mock backend modules that have import issues in CI environment +with patch.dict('sys.modules', { + 'backend.database.client': MagicMock(), + 'elasticsearch': MagicMock() +}): + from backend.utils.async_knowledge_summary_utils import ( + AsyncLLMClient, + DocumentClusterer, + ChunkDivider, + ChunkClusterer, + KnowledgeIntegrator, + async_vectorize_batch + ) + + +class TestAsyncLLMClient: + """Test AsyncLLMClient""" + + @pytest.fixture + def model_config(self): + return { + 'api_key': 'test-key', + 'base_url': 'http://test.com', + 'model_name': 'test-model', + 'model_repo': '' + } + + def test_client_initialization(self, model_config): + """Test client initialization with prompt template loading""" + client = AsyncLLMClient(model_config, language='zh') + assert client.model_config == model_config + assert client.model_name == 'test-model' + assert client.language == 'zh' + # Verify prompts are loaded (any prompts, from real YAML) + assert hasattr(client, 'prompts') + assert isinstance(client.prompts, dict) + assert 'SUMMARY_GENERATION_PROMPT' in client.prompts + assert 'KEYWORD_EXTRACTION_PROMPT' in client.prompts + assert 'CLUSTER_INTEGRATION_PROMPT' in client.prompts + assert 'GLOBAL_INTEGRATION_PROMPT' in client.prompts + + def test_chat_async_initialization(self, model_config): + """Test async chat client initialization""" + client = AsyncLLMClient(model_config, language='zh') + + # Test that the client is properly initialized + assert hasattr(client, 'client') + assert hasattr(client, 'semaphore') + assert client.semaphore._value == 3 # Actual semaphore value from implementation + + def test_clean_markdown_symbols(self, model_config): + """Test markdown symbols cleaning""" + client = AsyncLLMClient(model_config, language='zh') + + text_with_markdown = "**Bold** *Italic*" + cleaned = client._clean_markdown_symbols(text_with_markdown) + assert "**" not in cleaned + assert "*" not in cleaned + # The method only cleans certain markdown symbols, not all + + def test_fallback_keyword_extraction(self, model_config): + """Test fallback keyword extraction""" + client = AsyncLLMClient(model_config, language='zh') + + text = "这是一个测试文档关于机器学习和人工智能的内容。" + keywords = client._fallback_keyword_extraction(text) + + assert isinstance(keywords, list) + assert len(keywords) >= 0 + # Should contain some meaningful Chinese words (check if any keywords are extracted) + # The method extracts words >= 2 characters, so we expect some results + + + +class TestDocumentClusterer: + """Test DocumentClusterer""" + + def test_clusterer_initialization(self): + """Test clusterer initialization""" + clusterer = DocumentClusterer(max_clusters=5) + assert clusterer.max_clusters == 5 + + def test_cluster_documents(self): + """Test document clustering""" + clusterer = DocumentClusterer(max_clusters=3) + + # Create test vectors + vectors = np.random.rand(10, 128) + + result = clusterer.cluster_documents(vectors) + + assert result is not None + assert 'cluster_labels' in result + assert 
'n_clusters' in result + assert len(result['cluster_labels']) == 10 + assert result['n_clusters'] <= 3 + + def test_single_document_clustering(self): + """Test single document returns single cluster""" + clusterer = DocumentClusterer(max_clusters=5) + + # Single document vector + vectors = np.random.rand(1, 128) + + result = clusterer.cluster_documents(vectors) + + assert result is not None + assert result['n_clusters'] == 1 + assert len(result['cluster_labels']) == 1 + assert result['cluster_labels'][0] == 0 + + def test_find_optimal_k_high_similarity(self): + """Test finding optimal k for highly similar documents""" + clusterer = DocumentClusterer() + + # Create highly similar vectors (similarity > 0.95) + vectors = np.array([ + [1.0, 2.0, 3.0], + [1.1, 2.1, 3.1], + [0.9, 1.9, 2.9] + ]) + + optimal_k = clusterer._find_optimal_k(vectors) + assert optimal_k == 1 # Should return 1 for highly similar documents + + def test_find_optimal_k_multiple_clusters(self): + """Test finding optimal k for distinct document clusters""" + clusterer = DocumentClusterer() + + # Create more distinct clusters with larger separation + vectors = np.array([ + [1.0, 2.0, 3.0], # Cluster 1 + [1.1, 2.1, 3.1], # Cluster 1 + [100.0, 200.0, 300.0], # Cluster 2 - much further away + [100.1, 200.1, 300.1], # Cluster 2 + ]) + + optimal_k = clusterer._find_optimal_k(vectors) + assert optimal_k >= 1 # Should find at least 1 cluster (might find more) + + +class TestChunkDivider: + """Test ChunkDivider""" + + def test_divider_initialization(self): + """Test divider initialization""" + divider = ChunkDivider(window_size=300, overlap_ratio=0.3) + assert divider.window_size == 300 + assert divider.overlap_ratio == 0.3 + + def test_divide_documents(self): + """Test document division into chunks""" + divider = ChunkDivider(window_size=100, overlap_ratio=0.2) + + documents = [ + { + 'content': 'This is a test document. 
' * 20, # Long enough to create multiple chunks + 'filename': 'test1.txt', + '_id': '1', + 'title': 'Test Document 1' + } + ] + + chunks = divider.divide_documents(documents) + + assert len(chunks) > 0 + assert all('content' in chunk for chunk in chunks) + assert all('chunk_id' in chunk for chunk in chunks) + + +class TestChunkClusterer: + """Test ChunkClusterer""" + + def test_clusterer_initialization(self): + """Test clusterer initialization""" + clusterer = ChunkClusterer(similarity_threshold=0.7) + assert clusterer.similarity_threshold == 0.7 + + def test_cluster_chunks_with_document_clusters(self): + """Test document-cluster-aware chunk clustering""" + clusterer = ChunkClusterer(similarity_threshold=0.7, min_cluster_size=1) + + # Create test vectors and chunks + vectors = np.random.rand(6, 128) + chunks = [ + { + 'content': f'test content {i}', + 'filename': 'doc1.txt' if i < 3 else 'doc2.txt', + '_id': str(i), + 'chunk_id': f'chunk_{i}' + } + for i in range(6) + ] + + # Simulate document clusters + chunks_by_doc_cluster = { + 0: chunks[0:3], # doc_cluster_0: chunks from doc1 + 1: chunks[3:6] # doc_cluster_1: chunks from doc2 + } + + result = clusterer.cluster_chunks_with_document_clusters(vectors, chunks, chunks_by_doc_cluster) + + assert 'chunk_clusters' in result + assert 'n_clusters' in result + assert result['n_clusters'] >= len(chunks_by_doc_cluster) # At least 1 cluster per doc cluster + + def test_estimate_tokens_for_chunks(self): + """Test token estimation for chunks""" + clusterer = ChunkClusterer() + + chunks = [ + {'content': '这是一个测试文本' * 100}, # ~600 Chinese chars + {'content': '另一个测试' * 50} # ~200 Chinese chars + ] + + tokens = clusterer._estimate_tokens_for_chunks(chunks) + + # Roughly 800 chars / 4 ≈ 200 tokens + assert tokens > 0 + assert tokens < 500 # Should be reasonable + + def test_threshold_clustering(self): + """Test threshold-based clustering""" + clusterer = ChunkClusterer(similarity_threshold=0.7) + + # Create similarity matrix with clear clusters + similarity_matrix = np.array([ + [1.0, 0.9, 0.8, 0.2, 0.1], # Cluster 1 + [0.9, 1.0, 0.85, 0.3, 0.2], # Cluster 1 + [0.8, 0.85, 1.0, 0.25, 0.15], # Cluster 1 + [0.2, 0.3, 0.25, 1.0, 0.9], # Cluster 2 + [0.1, 0.2, 0.15, 0.9, 1.0], # Cluster 2 + ]) + + cluster_labels = clusterer._threshold_clustering(similarity_matrix) + + assert len(cluster_labels) == 5 + # First 3 should be in same cluster, last 2 in another cluster + assert cluster_labels[0] == cluster_labels[1] == cluster_labels[2] + assert cluster_labels[3] == cluster_labels[4] + assert cluster_labels[0] != cluster_labels[3] + + def test_organize_clusters(self): + """Test cluster organization""" + clusterer = ChunkClusterer() + + chunks = [ + {'content': 'chunk1', '_id': '1'}, + {'content': 'chunk2', '_id': '2'}, + {'content': 'chunk3', '_id': '3'}, + ] + + cluster_labels = np.array([0, 0, 1]) + similarity_matrix = np.array([ + [1.0, 0.8, 0.2], + [0.8, 1.0, 0.3], + [0.2, 0.3, 1.0] + ]) + + organized = clusterer._organize_clusters(chunks, cluster_labels, similarity_matrix) + + assert 'chunk_clusters' in organized + assert 'noise_chunks' in organized + assert len(organized['chunk_clusters']) == 2 # Two clusters + assert len(organized['noise_chunks']) == 0 # No noise chunks + + +class TestKnowledgeIntegrator: + """Test KnowledgeIntegrator""" + + @pytest.fixture + def model_config(self): + return { + 'api_key': 'test-key', + 'base_url': 'http://test.com', + 'model_name': 'test-model', + 'model_repo': '' + } + + @pytest.fixture + def llm_client(self, model_config): 
+        return AsyncLLMClient(model_config)
+
+    def test_integrator_initialization(self, llm_client):
+        """Test integrator initialization"""
+        integrator = KnowledgeIntegrator(llm_client)
+        assert integrator.llm_client == llm_client
+
+    def test_integrator_methods_exist(self, llm_client):
+        """Test that integrator methods exist and are callable"""
+        integrator = KnowledgeIntegrator(llm_client)
+
+        # Test that all required methods exist
+        assert hasattr(integrator, 'integrate_cluster_cards')
+        assert hasattr(integrator, 'integrate_all_clusters')
+        assert hasattr(integrator, '_generate_cluster_summary_async')
+        assert hasattr(integrator, '_generate_global_summary_async')
+        assert hasattr(integrator, '_integrate_keywords')
+        assert hasattr(integrator, '_clean_markdown_symbols')
+
+        # Test that methods are callable
+        assert callable(integrator.integrate_cluster_cards)
+        assert callable(integrator.integrate_all_clusters)
+        assert callable(integrator._integrate_keywords)
+        assert callable(integrator._clean_markdown_symbols)
+
+    def test_clean_markdown_symbols(self, llm_client):
+        """Test markdown symbols cleaning in integrator"""
+        integrator = KnowledgeIntegrator(llm_client)
+
+        text_with_markdown = "**Bold** *Italic*"
+        cleaned = integrator._clean_markdown_symbols(text_with_markdown)
+        assert "**" not in cleaned
+        assert "*" not in cleaned
+        # The method only cleans certain markdown symbols, not all
+
+    def test_integrate_keywords(self, llm_client):
+        """Test keyword integration logic"""
+        integrator = KnowledgeIntegrator(llm_client)
+
+        # Test with a list of keywords directly
+        keywords_list = ['keyword1', 'keyword2', 'keyword1', 'keyword2', 'keyword3', 'keyword1']
+
+        integrated_keywords = integrator._integrate_keywords(keywords_list)
+
+        assert isinstance(integrated_keywords, list)
+        assert len(integrated_keywords) > 0
+        # keyword1 appears 3 times, should be most frequent
+        assert 'keyword1' in integrated_keywords
+
+
+class TestAsyncVectorizeBatch:
+    """Test async vectorize batch function"""
+
+    def test_vectorize_batch_function_exists(self):
+        """Test that async_vectorize_batch function exists and is callable"""
+        assert callable(async_vectorize_batch)
+
+    def test_vectorize_batch_signature(self):
+        """Test async_vectorize_batch function signature"""
+        import inspect
+
+        # Check function signature
+        sig = inspect.signature(async_vectorize_batch)
+        params = list(sig.parameters.keys())
+
+        assert 'texts' in params
+        assert 'embedding_model' in params
+        assert 'batch_size' in params
+
+
+class TestPromptTemplateUsage:
+    """Test prompt template usage in async knowledge summary"""
+
+    def test_summary_uses_template(self):
+        """Test that summary generation uses YAML templates"""
+        model_config = {
+            'api_key': 'test-key',
+            'base_url': 'http://test.com',
+            'model_name': 'test-model',
+            'model_repo': ''
+        }
+
+        # Test with Chinese language
+        client_zh = AsyncLLMClient(model_config, language='zh')
+        assert 'SUMMARY_GENERATION_PROMPT' in client_zh.prompts
+        assert 'KEYWORD_EXTRACTION_PROMPT' in client_zh.prompts
+        assert 'CLUSTER_INTEGRATION_PROMPT' in client_zh.prompts
+        assert 'GLOBAL_INTEGRATION_PROMPT' in client_zh.prompts
+
+        # Test with English language
+        client_en = AsyncLLMClient(model_config, language='en')
+        assert 'SUMMARY_GENERATION_PROMPT' in client_en.prompts
+        assert 'KEYWORD_EXTRACTION_PROMPT' in client_en.prompts
+
+        # Verify that different languages load different templates
+        assert client_zh.prompts != client_en.prompts
+
+
+if __name__ == "__main__":
+    pytest.main([__file__, "-v"])
+