diff --git a/fastapi_app/main.py b/fastapi_app/main.py index 90bc146..3bbbb45 100644 --- a/fastapi_app/main.py +++ b/fastapi_app/main.py @@ -35,11 +35,14 @@ from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import FileResponse -from fastapi_app.routers import kb, kb_embedding, files, paper2drawio, paper2ppt, auth +from fastapi_app.routers import kb, kb_embedding, files, paper2drawio, paper2ppt, auth, data_insight from fastapi_app.middleware.api_key import APIKeyMiddleware from fastapi_app.middleware.logging import LoggingMiddleware from workflow_engine.utils import get_project_root +# 导入workflow模块以触发所有workflow注册 +from workflow_engine import workflow + # 本地 Embedding 服务端口(Octen-Embedding-0.6B) LOCAL_EMBEDDING_PORT = 26210 LOCAL_EMBEDDING_URL = f"http://127.0.0.1:{LOCAL_EMBEDDING_PORT}/v1/embeddings" @@ -152,6 +155,7 @@ def create_app() -> FastAPI: app.include_router(paper2drawio.router, prefix="/api/v1", tags=["Paper2Drawio"]) app.include_router(paper2ppt.router, prefix="/api/v1", tags=["Paper2PPT"]) app.include_router(auth.router, prefix="/api/v1", tags=["Auth"]) + app.include_router(data_insight.router, prefix="/api/v1", tags=["Data Insight"]) # 静态文件:/outputs 下的文件(兼容 URL 中 %40 与 磁盘 @ 两种路径) project_root = get_project_root() diff --git a/fastapi_app/routers/data_insight.py b/fastapi_app/routers/data_insight.py new file mode 100644 index 0000000..b7ac8e9 --- /dev/null +++ b/fastapi_app/routers/data_insight.py @@ -0,0 +1,247 @@ +""" +Data Insight Discovery API +Multi-dataset insight analysis using DM framework. +""" +import json +import tempfile +from pathlib import Path +from typing import List, Optional, Dict, Any +from fastapi import APIRouter, Form, HTTPException, UploadFile, File +from fastapi.responses import FileResponse +from pydantic import BaseModel + +import pandas as pd + +from workflow_engine.logger import get_logger +from fastapi_app.services.data_insight_service import DataInsightService + +log = get_logger(__name__) +router = APIRouter(prefix="/data_insight", tags=["data_insight"]) + + +# ==================== Pydantic Models ==================== +class DataInsightResponse(BaseModel): + """Response model for data insight analysis""" + status: str + synthesized_insights: List[str] + raw_insights: List[str] + summary: str + detailed_appendix: Dict[str, Any] = {} + result_path: str = "" + error: Optional[str] = None + + +class ErrorResponse(BaseModel): + """Standard error response""" + error: str + code: str = "INTERNAL_ERROR" + details: Optional[Dict] = None + + +# ==================== API Endpoints ==================== +@router.post( + "/analyze", + response_model=DataInsightResponse, + responses={400: {"model": ErrorResponse}, 500: {"model": ErrorResponse}}, +) +async def analyze_datasets( + chat_api_url: str = Form(..., description="LLM API URL"), + api_key: str = Form(..., description="LLM API key"), + model: str = Form("deepseek-v3.2", description="Model name"), + output_mode: str = Form("concise", description="Output mode: concise or detailed"), + language: str = Form("en", description="Language preference"), + files: List[UploadFile] = File(..., description="Data files (CSV, Excel)"), + analysis_goal: Optional[str] = Form(None, description="Custom analysis goal"), + email: Optional[str] = Form(None, description="User email"), +): + """ + Analyze multiple datasets and discover insights. + + Accepts CSV, Excel files. + Returns synthesized insights and summary. + """ + try: + # Validate inputs + if not files: + raise HTTPException(status_code=400, detail="No files provided") + + if not api_key or not chat_api_url: + raise HTTPException(status_code=400, detail="API key and URL required") + + # Call service + service = DataInsightService() + result = await service.analyze_datasets( + chat_api_url=chat_api_url, + api_key=api_key, + model=model, + output_mode=output_mode, + analysis_goal=analysis_goal, + language=language, + email=email, + files=files, + ) + + # Check for errors + if result.get("status") == "error": + raise HTTPException( + status_code=500, + detail=result.get("error", "Analysis failed") + ) + + # Convert raw_insights from dict to string if needed + raw_insights = result.get("raw_insights", []) + if raw_insights and isinstance(raw_insights[0], dict): + # Convert dict format to string representation + result["raw_insights"] = [ + f"[{item.get('source', 'unknown')}] {item.get('insight', str(item))}" + for item in raw_insights + ] + + return DataInsightResponse(**result) + + except HTTPException: + raise + except Exception as e: + log.error(f"Unexpected error in analyze_datasets: {e}", exc_info=True) + raise HTTPException(status_code=500, detail=str(e)) + + +# ==================== Helper Functions ==================== +def generate_markdown_report( + synthesized_insights: List[str], + raw_insights: List[str], + summary: str, + detailed_appendix: Dict[str, Any], + language: str = "en" +) -> str: + """ + Generate a markdown report from analysis results. + + Args: + synthesized_insights: List of synthesized insights + raw_insights: List of raw insights from individual agents + summary: Overall summary + detailed_appendix: Detailed appendix data + language: Language preference (en/zh) + + Returns: + Markdown formatted report content + """ + lang = language.lower() + is_zh = lang == "zh" + + # Headers + title = "📊 Data Insight Report" if is_zh else "📊 Data Insight Report" + summary_header = "📝 Summary" if is_zh else "📝 Summary" + insights_header = "💡 Key Insights" if is_zh else "💡 Key Insights" + raw_header = "📋 Raw Analysis" if is_zh else "📋 Raw Analysis" + appendix_header = "📎 Detailed Appendix" if is_zh else "📎 Detailed Appendix" + footer = f"*Generated on {pd.Timestamp.now().strftime('%Y-%m-%d %H:%M:%S')}*" if is_zh else f"*Generated on {pd.Timestamp.now().strftime('%Y-%m-%d %H:%M:%S')}*" + + # Build report + report_lines = [ + f"# {title}", + "", + f"## {summary_header}", + "", + summary, + "", + f"## {insights_header}", + "" + ] + + # Add synthesized insights + for i, insight in enumerate(synthesized_insights, 1): + if is_zh: + report_lines.append(f"### Insight {i}") + else: + report_lines.append(f"### Insight {i}") + report_lines.append("") + report_lines.append(insight) + report_lines.append("") + + # Add raw insights if available + if raw_insights: + report_lines.append(f"## {raw_header}") + report_lines.append("") + for i, insight in enumerate(raw_insights, 1): + report_lines.append(f"**{i}.** {insight}") + report_lines.append("") + + # Add detailed appendix if available + if detailed_appendix: + report_lines.append(f"## {appendix_header}") + report_lines.append("") + for key, value in detailed_appendix.items(): + report_lines.append(f"### {key}") + report_lines.append("") + if isinstance(value, dict): + for k, v in value.items(): + report_lines.append(f"- **{k}:** {v}") + elif isinstance(value, list): + for item in value: + report_lines.append(f"- {item}") + else: + report_lines.append(str(value)) + report_lines.append("") + + # Add footer + report_lines.append("---") + report_lines.append("") + report_lines.append(footer) + + return "\n".join(report_lines) + + +# ==================== New API Endpoints ==================== +@router.post( + "/generate_report", + responses={400: {"model": ErrorResponse}, 500: {"model": ErrorResponse}}, +) +async def generate_report( + synthesized_insights: str = Form(..., description="JSON string of synthesized insights"), + raw_insights: str = Form(..., description="JSON string of raw insights"), + summary: str = Form(..., description="Analysis summary"), + detailed_appendix: str = Form("{}", description="JSON string of detailed appendix"), + language: str = Form("en", description="Language preference"), +): + """ + Generate a markdown report from analysis results. + """ + try: + # Parse JSON strings + synthesized = json.loads(synthesized_insights) if synthesized_insights else [] + raw = json.loads(raw_insights) if raw_insights else [] + appendix = json.loads(detailed_appendix) if detailed_appendix else {} + + # Generate markdown report + report_content = generate_markdown_report( + synthesized_insights=synthesized, + raw_insights=raw, + summary=summary, + detailed_appendix=appendix, + language=language + ) + + # Save to temporary file + temp_dir = Path(tempfile.gettempdir()) / "data_insight_reports" + temp_dir.mkdir(parents=True, exist_ok=True) + + report_filename = f"insight_report_{pd.Timestamp.now().strftime('%Y%m%d_%H%M%S')}.md" + report_path = temp_dir / report_filename + report_path.write_text(report_content, encoding='utf-8') + + log.info(f"Generated markdown report: {report_path}") + + return FileResponse( + path=str(report_path), + filename=report_filename, + media_type='text/markdown' + ) + + except json.JSONDecodeError as e: + log.error(f"JSON decode error: {e}") + raise HTTPException(status_code=400, detail="Invalid JSON format in request") + except Exception as e: + log.error(f"Error generating report: {e}", exc_info=True) + raise HTTPException(status_code=500, detail=str(e)) diff --git a/fastapi_app/routers/kb.py b/fastapi_app/routers/kb.py index 9d97fea..7be07eb 100644 --- a/fastapi_app/routers/kb.py +++ b/fastapi_app/routers/kb.py @@ -118,7 +118,7 @@ def _text_to_pdf(text: str, output_path: str) -> None: doc.close() -ALLOWED_EXTENSIONS = {".pdf", ".docx", ".pptx", ".png", ".jpg", ".jpeg", ".mp4", ".md"} +ALLOWED_EXTENSIONS = {".pdf", ".docx", ".pptx", ".png", ".jpg", ".jpeg", ".mp4", ".md", ".csv", ".txt", ".db", ".json", ".jsonl", ".xlsx", ".xls", ".parquet", ".ndjson"} IMAGE_EXTENSIONS = {".png", ".jpg", ".jpeg"} DOC_EXTENSIONS = {".pdf", ".docx", ".doc", ".pptx", ".ppt", ".md", ".markdown"} @@ -329,14 +329,14 @@ def _append_images_to_pptx(pptx_path: Path, image_paths: List[Path]) -> None: @router.post("/upload") async def upload_kb_file( - file: UploadFile = File(...), + files: List[UploadFile] = File(...), email: str = Form(...), user_id: str = Form(...), notebook_id: Optional[str] = Form(None), notebook_title: Optional[str] = Form(None), ): """ - Upload a file to the notebook's knowledge base directory. + Upload multiple files to the notebook's knowledge base directory. New layout: outputs/{title}_{id}/sources/{stem}/original/ Fallback: also writes to legacy kb_data path for backward compat. """ @@ -345,77 +345,89 @@ async def upload_kb_file( if not notebook_id: raise HTTPException(status_code=400, detail="notebook_id is required for per-notebook storage") - file_ext = Path(file.filename).suffix.lower() - if file_ext not in ALLOWED_EXTENSIONS: - raise HTTPException( - status_code=400, - detail=f"Unsupported file type: {file_ext}. Allowed: {', '.join(ALLOWED_EXTENSIONS)}" - ) + uploaded_files = [] + paths = get_notebook_paths(notebook_id, notebook_title or "", email or user_id) + mgr = SourceManager(paths) - try: - filename = file.filename or f"unnamed_{user_id}" - filename = os.path.basename(filename) + for file in files: + file_ext = Path(file.filename).suffix.lower() + if file_ext not in ALLOWED_EXTENSIONS: + log.warning(f"Skipping unsupported file: {file.filename}") + continue - # --- New notebook-centric layout --- - paths = get_notebook_paths(notebook_id, notebook_title or "", email or user_id) - mgr = SourceManager(paths) + try: + filename = file.filename or f"unnamed_{user_id}" + filename = os.path.basename(filename) - # Save uploaded bytes to a temp location first, then import - tmp_dir = paths.root / "_tmp" - tmp_dir.mkdir(parents=True, exist_ok=True) - tmp_path = tmp_dir / filename - with open(tmp_path, "wb") as buffer: - shutil.copyfileobj(file.file, buffer) + # Save uploaded bytes to a temp location first, then import + tmp_dir = paths.root / "_tmp" + tmp_dir.mkdir(parents=True, exist_ok=True) + tmp_path = tmp_dir / filename + with open(tmp_path, "wb") as buffer: + shutil.copyfileobj(file.file, buffer) - source_info = await mgr.import_file(tmp_path, filename) + source_info = await mgr.import_file(tmp_path, filename) - # Clean up temp - try: - tmp_path.unlink(missing_ok=True) - except Exception: - pass + # Clean up temp + try: + tmp_path.unlink(missing_ok=True) + except Exception: + pass - # Build static URL from the original path in new layout - project_root = get_project_root() - rel = source_info.original_path.relative_to(project_root) - static_path = "/" + rel.as_posix() - - # --- Also write to legacy path for backward compat --- - legacy_dir = _notebook_dir(email, notebook_id) - legacy_dir.mkdir(parents=True, exist_ok=True) - legacy_path = legacy_dir / filename - if not legacy_path.exists(): - shutil.copy2(str(source_info.original_path), str(legacy_path)) + # Build static URL from the original path in new layout + project_root = get_project_root() + rel = source_info.original_path.relative_to(project_root) + static_path = "/" + rel.as_posix() - # Auto-embed using new vector_store path - embedded = False - try: - vector_base = str(paths.vector_store_dir) - mineru_base = str(paths.source_mineru_dir(filename)) - file_list = [{"path": str(source_info.original_path)}] - await process_knowledge_base_files( - file_list=file_list, - base_dir=vector_base, - mineru_output_base=mineru_base, - ) - embedded = True - log.info("[upload] auto-embedding done: %s", filename) - except Exception as emb_err: - log.warning("[upload] auto-embedding failed for %s: %s", filename, emb_err) + # --- Also write to legacy path for backward compat --- + legacy_dir = _notebook_dir(email, notebook_id) + legacy_dir.mkdir(parents=True, exist_ok=True) + legacy_path = legacy_dir / filename + if not legacy_path.exists(): + shutil.copy2(str(source_info.original_path), str(legacy_path)) - return { - "success": True, - "filename": filename, - "file_size": os.path.getsize(source_info.original_path), - "storage_path": str(source_info.original_path), - "static_url": static_path, - "file_type": file.content_type, - "embedded": embedded, - } + # Auto-embed using new vector_store path + embedded = False + try: + vector_base = str(paths.vector_store_dir) + mineru_base = str(paths.source_mineru_dir(filename)) + file_list = [{"path": str(source_info.original_path)}] + # 使用本地embedding服务 + local_embedding_url = os.getenv("EMBEDDING_API_URL", "http://127.0.0.1:26210/v1/embeddings") + await process_knowledge_base_files( + file_list=file_list, + base_dir=vector_base, + mineru_output_base=mineru_base, + api_url=local_embedding_url, + ) + embedded = True + log.info("[upload] auto-embedding done: %s", filename) + except Exception as emb_err: + log.warning("[upload] auto-embedding failed for %s: %s", filename, emb_err) + + uploaded_files.append({ + "filename": filename, + "file_size": os.path.getsize(source_info.original_path), + "storage_path": str(source_info.original_path), + "static_url": static_path, + "file_type": file.content_type, + "embedded": embedded, + }) - except Exception as e: - print(f"Error uploading file: {e}") - raise HTTPException(status_code=500, detail=str(e)) + except Exception as e: + log.error(f"Error uploading file {file.filename}: {e}") + uploaded_files.append({ + "filename": file.filename, + "error": str(e), + "success": False + }) + + return { + "success": True, + "files": uploaded_files, + "total_uploaded": len([f for f in uploaded_files if "error" not in f]), + "total_failed": len([f for f in uploaded_files if "error" in f]), + } def _sanitize_md_filename(title: str, prefix: str = "doc") -> str: diff --git a/fastapi_app/services/data_insight_service.py b/fastapi_app/services/data_insight_service.py new file mode 100644 index 0000000..930f45b --- /dev/null +++ b/fastapi_app/services/data_insight_service.py @@ -0,0 +1,82 @@ +""" +Data Insight Service +Handles file upload and calls adapter. +""" +import time +from pathlib import Path +from typing import Any, Dict, List, Optional +from fastapi import UploadFile + +from workflow_engine.logger import get_logger +from workflow_engine.utils import get_project_root +from fastapi_app.workflow_adapters.wa_data_insight import DataInsightAdapter + +log = get_logger(__name__) + + +class DataInsightService: + """Data insight analysis service""" + + def _create_upload_dir(self, email: Optional[str]) -> Path: + """Create directory for uploaded files.""" + ts = int(time.time()) + root = get_project_root() + upload_dir = root / "outputs" / "data_insights" / (email or "default") / f"{ts}_upload" + upload_dir.mkdir(parents=True, exist_ok=True) + return upload_dir + + async def analyze_datasets( + self, + chat_api_url: str, + api_key: str, + model: str, + output_mode: str, + analysis_goal: Optional[str], + language: str, + email: Optional[str], + files: List[UploadFile], + ) -> Dict[str, Any]: + """ + Execute insight analysis workflow. + + Args: + chat_api_url: LLM API URL + api_key: LLM API key + model: Model name + output_mode: "concise" or "detailed" + analysis_goal: Optional custom goal + language: Language preference + email: User email + files: Uploaded data files + + Returns: + Analysis results dict + """ + # Save uploaded files + upload_dir = self._create_upload_dir(email) + file_paths = [] + + for file in files: + file_path = upload_dir / (file.filename or f"file_{len(file_paths)}.csv") + content = await file.read() + file_path.write_bytes(content) + file_paths.append(str(file_path)) + log.info(f"Uploaded: {file.filename}") + + # Build request dict + request_data = { + "file_ids": file_paths, + "model": model, + "api_key": api_key, + "chat_api_url": chat_api_url, + "output_mode": output_mode, + "analysis_goal": analysis_goal, + "language": language, + "email": email + } + + # Call adapter (NOT workflow directly) + adapter = DataInsightAdapter() + result = await adapter.execute(request_data) + + return result diff --git a/fastapi_app/workflow_adapters/wa_data_insight.py b/fastapi_app/workflow_adapters/wa_data_insight.py new file mode 100644 index 0000000..74131d7 --- /dev/null +++ b/fastapi_app/workflow_adapters/wa_data_insight.py @@ -0,0 +1,108 @@ +""" +Data Insight Workflow Adapter +Mandatory isolation layer between Service and Workflow. +""" +from __future__ import annotations +from typing import Dict, Any +from workflow_engine.state import DataInsightState, DataInsightRequest +from workflow_engine.workflow.registry import RuntimeRegistry +from workflow_engine.logger import get_logger + +log = get_logger(__name__) + + +class DataInsightAdapter: + """ + Adapter for data insight workflow. + Converts API request dict to workflow state and executes workflow. + """ + + async def execute(self, request_data: Dict[str, Any]) -> Dict[str, Any]: + """ + Execute data insight workflow. + + Args: + request_data: Dict with keys: + - file_ids: List[str] + - model: str + - api_key: str + - chat_api_url: str + - output_mode: str + - analysis_goal: Optional[str] + - language: str + - email: Optional[str] + + Returns: + Dict with keys: + - status: "success" | "error" + - synthesized_insights: List[str] + - raw_insights: List[str] + - summary: str + - detailed_appendix: Dict (if detailed mode) + - result_path: str + - error: str (if error) + """ + try: + # Build workflow request + wf_request = DataInsightRequest( + file_ids=request_data.get("file_ids", []), + output_mode=request_data.get("output_mode", "concise"), + analysis_goal=request_data.get("analysis_goal"), + model=request_data.get("model", "deepseek-v3.2"), + api_key=request_data.get("api_key", ""), + chat_api_url=request_data.get("chat_api_url", ""), + language=request_data.get("language", "en") + ) + + # Add email if provided + if request_data.get("email"): + wf_request.email = request_data["email"] + + # Build workflow state + state = DataInsightState(request=wf_request) + + # Execute workflow + log.info("Executing data_insight workflow") + factory = RuntimeRegistry.get("data_insight") + builder = factory() + graph = builder.build() + + result_state = await graph.ainvoke(state) + + # Handle both dict and dataclass returns + if isinstance(result_state, dict): + # Result is a dict + synthesized_insights = result_state.get("synthesized_insights", []) + raw_insights = result_state.get("raw_insights", []) + summary = result_state.get("summary", "") + detailed_appendix = result_state.get("detailed_appendix", {}) + result_path = result_state.get("result_path", "") + else: + # Result is a DataInsightState object + synthesized_insights = result_state.synthesized_insights + raw_insights = result_state.raw_insights + summary = result_state.summary + detailed_appendix = result_state.detailed_appendix + result_path = result_state.result_path + + # Format response + return { + "status": "success", + "synthesized_insights": synthesized_insights, + "raw_insights": raw_insights, + "summary": summary, + "detailed_appendix": detailed_appendix, + "result_path": result_path + } + + except Exception as e: + log.error(f"Adapter execution failed: {e}", exc_info=True) + return { + "status": "error", + "error": str(e), + "synthesized_insights": [], + "raw_insights": [], + "summary": f"Analysis failed: {str(e)}", + "detailed_appendix": {}, + "result_path": "" + } diff --git a/frontend_en/Dockerfile b/frontend_en/Dockerfile index a1bee6e..022ef93 100644 --- a/frontend_en/Dockerfile +++ b/frontend_en/Dockerfile @@ -9,7 +9,7 @@ COPY frontend-v2/ ./ ARG VITE_API_KEY=df-internal-2024-workflow-key ARG VITE_DEFAULT_LLM_API_URL=https://api.apiyi.com/v1 -ARG VITE_LLM_API_URLS=https://api.apiyi.com/v1,http://b.apiyi.com:16888/v1,http://123.129.219.111:3000/v1 +ARG VITE_LLM_API_URLS=https://api.apiyi.com/v1,http://b.apiyi.com:16888/v1,http://172.96.160.199:3000/v1 ARG VITE_API_BASE_URL= ENV VITE_API_KEY=$VITE_API_KEY \ diff --git a/workflow_engine/state.py b/workflow_engine/state.py index 4ec4e64..fa98b55 100644 --- a/workflow_engine/state.py +++ b/workflow_engine/state.py @@ -544,3 +544,27 @@ class Paper2DrawioState(MainState): output_xml_path: str = "" # XML 文件路径 output_png_path: str = "" # PNG 导出路径 output_svg_path: str = "" # SVG 导出路径 + + +# ==================== Data Insight Request ==================== +@dataclass +class DataInsightRequest(MainRequest): + """Data insight discovery request""" + file_ids: List[str] = field(default_factory=list) # Uploaded data file paths + output_mode: str = "concise" # "concise" | "detailed" + analysis_goal: Optional[str] = None # Optional custom analysis goal + + +# ==================== Data Insight State ==================== +@dataclass +class DataInsightState(MainState): + """Data insight discovery state""" + request: DataInsightRequest = field(default_factory=DataInsightRequest) + result_path: str = "" + + # Results from DM insight analysis + synthesized_insights: List[str] = field(default_factory=list) + raw_insights: List[str] = field(default_factory=list) + summary: str = "" + detailed_appendix: Dict[str, Any] = field(default_factory=dict) + diff --git a/workflow_engine/toolkits/insight_tool/__init__.py b/workflow_engine/toolkits/insight_tool/__init__.py new file mode 100644 index 0000000..c58b653 --- /dev/null +++ b/workflow_engine/toolkits/insight_tool/__init__.py @@ -0,0 +1,6 @@ +""" +Data Insight Toolkit - Wrapper for DM insight framework +""" +from .insight_wrapper import InsightToolkit + +__all__ = ["InsightToolkit"] diff --git a/workflow_engine/toolkits/insight_tool/dm_components/__init__.py b/workflow_engine/toolkits/insight_tool/dm_components/__init__.py new file mode 100644 index 0000000..7ef6a7b --- /dev/null +++ b/workflow_engine/toolkits/insight_tool/dm_components/__init__.py @@ -0,0 +1,13 @@ +# # Import submodules +# from . import agents +# from . import datasets +# from . import metrics +# from . import utils + +# # You can also define any package-level variables or functions here +# __version__ = "0.1.0" + +# # If you want to make certain classes or functions directly accessible +# # when someone imports insight, you can add them here +# from .agents import BaseAgent, HumanAgent, LLMAgent +# from .metrics import calculate_metrics diff --git a/workflow_engine/toolkits/insight_tool/dm_components/agents/__init__.py b/workflow_engine/toolkits/insight_tool/dm_components/agents/__init__.py new file mode 100644 index 0000000..75dec4f --- /dev/null +++ b/workflow_engine/toolkits/insight_tool/dm_components/agents/__init__.py @@ -0,0 +1 @@ +# DM Components - Agents diff --git a/workflow_engine/toolkits/insight_tool/dm_components/agents/base_agent.py b/workflow_engine/toolkits/insight_tool/dm_components/agents/base_agent.py new file mode 100644 index 0000000..9b79248 --- /dev/null +++ b/workflow_engine/toolkits/insight_tool/dm_components/agents/base_agent.py @@ -0,0 +1,371 @@ +import os +import json +import copy +import tempfile +from PIL import Image + + +from dm_components import prompts +from dm_components.config import logger +from dm_components.utils import agent_utils as au +from dm_components.utils.dataloader_utils import DataSourceReader + +from langchain.schema import HumanMessage, SystemMessage + + + +class AgentBase: + def __init__( + self, + savedir=None, + context="This is a dataset that could potentially consist of interesting insights", + model_name="gpt-3.5-turbo-0613", + goal="I want to find interesting trends in this dataset", + verbose=False, + temperature=0, + n_retries=2, + dataset_path=None, + api_key=None, + base_url=None + ): + self.goal = goal + if savedir is None: + savedir = tempfile.mkdtemp() + self.savedir = savedir + self.context = context + + self.model_name = model_name + self.temperature = temperature + self.api_key = api_key + self.base_url = base_url + + self.insights_history = [] + self.verbose = verbose + self.n_retries = n_retries + self.schema = None + self.dataset_path = dataset_path + self.multi_schema = None + self.multi_dataset_path = None + self.multi_profile = None # NEW: Support for profile information in multi-dataset scenarios + + def set_table( + self, + table=None, + multi_table=None, + dataset_path=None, # 从 dataset_csv_path 重命名 + multi_dataset_path=None, # 从 multi_dataset_csv_path 重命名 + dataset_read_kwargs=None, + multi_dataset_read_kwargs=None, + ): + + if dataset_read_kwargs is None: + dataset_read_kwargs = {} + if multi_dataset_read_kwargs is None: + multi_dataset_read_kwargs = {} + + # 1. 始终存储路径。它们对执行上下文是必要的。 + # 保持变量名称(self.dataset_path)不变 + # 因为 answer_question 函数依赖于它们。 + self.dataset_path = dataset_path + self.multi_dataset_path = multi_dataset_path + + # --- 主表逻辑 --- + if table is not None: + self.table = table + elif dataset_path is not None: + # 优先级 2: 如果没有传递 DataFrame,则从路径加载 + logger.info(f"未提供 DataFrame,正在从路径加载: {dataset_path}") + try: + self.table = DataSourceReader.read_data(dataset_path, **dataset_read_kwargs) # + except Exception as e: + logger.error(f"从 {dataset_path} 读取数据失败: {e}") + raise + else: + self.table = None # 未提供数据 + + if self.table is None: + raise ValueError("AgentBase.set_table: no 'table' provided.") + + self.schema = au.get_schema(self.table) + + + + def summarize(self, pred_insights, method="list", prompt_summarize_method="basic"): + if method == "list": + chat = au.get_chat_model(self.model_name, self.temperature, self.api_key, self.base_url) + + # Function to format the data + def format_data(data): + result = "" + for i, item in enumerate(data): + question_tag = f"{item['question']}\n" + answer_tag = f"{item['answer']}\n\n" + result += f"{question_tag} {answer_tag}\n" + return result + + # Format the data and print + formatted_history = format_data(pred_insights) + + # summary = agent.summarize_insights(method="list") + content_prompt, system_prompt = prompts.get_summarize_prompt( + method=prompt_summarize_method + ) + messages = [ + SystemMessage(content=system_prompt), + HumanMessage( + content=content_prompt.format( + context=self.context, + goal=self.goal, + history=formatted_history, + ) + ), + ] + + def _validate_tasks(out): + isights = au.extract_html_tags(out, ["insight"]) + + # Check that there are insights generated + if "insight" not in isights: + return ( + out, + False, + f"Error: you did not generate insights within the tags.", + ) + isights = isights["insight"] + return (isights, out), True, "" + + insight_list, message = au.chat_and_retry( + chat, messages, n_retry=3, parser=_validate_tasks + ) + + insights = "\n".join(insight_list) + + return insights + + def select_a_question(self, questions): + """ + Select a question from the list of questions + """ + return au.select_a_question( + questions, + self.context, + self.goal, + [o["question"] for o in self.insights_history], + self.model_name, + prompts.SELECT_A_QUESTION_TEMPLATE, + prompts.SELECT_A_QUESTION_SYSTEM_MESSAGE, + ) + + def generate_notebook(): + pass + + def generate_report(): + pass + + def recommend_questions( + self, + n_questions=3, + insights_history=None, + prompt_method=None, + question_type=None, + ): + """ + Suggest Next Best Questions + """ + if self.verbose: + print(f"Generating {n_questions} Questions using {self.model_name}...") + + if insights_history is None: + + # Generate Root Questions + questions = au.get_questions( + prompt_method=prompt_method, + context=self.context, + goal=self.goal, + messages=[], + schema=self.schema, + max_questions=n_questions, + model_name=self.model_name, + temperature=self.temperature, + ) + else: + # Generate Follow Up Questions + last_insight = insights_history[-1] + questions = au.get_follow_up_questions( + context=self.context, + goal=self.goal, + question=last_insight["question"], + answer=last_insight["answer"], + schema=self.schema, + max_questions=n_questions, + model_name=self.model_name, + prompt_method=prompt_method, + question_type=question_type, + temperature=self.temperature, + ) + if self.verbose: + print( + "\nFollowing up on the last insight:\n---------------------------------" + ) + print(f"Question: {last_insight['question']}\n") + print(f"Answer: {last_insight['answer']}\n") + + if self.verbose: + print("\nNext Best Questions:\n-------------------") + for idx, question in enumerate(questions): + print(f"{idx+1}. {question}") + print() + + return questions + + def answer_question( + self, + question, + n_retries=2, + return_insight_dict=True, + prompt_code_method="single", + prompt_interpret_method="interpret", + ): + n_retries = self.n_retries + if self.verbose: + print(f"Generating Code...") + + code_output_folder = os.path.join( + self.savedir, f"question_{str(len(self.insights_history))}" + ) + + if self.verbose: + print(f"Interpreting Solution...") + print(f"Results saved at: {self.savedir}") + + multi_path_processed = None + if self.multi_dataset_path is not None: + if isinstance(self.multi_dataset_path, list): + # 如果是路径列表,对每个路径应用 abspath + multi_path_processed = [os.path.abspath(p) for p in self.multi_dataset_path] + elif isinstance(self.multi_dataset_path, str): + # 如果是单个路径字符串,直接应用 abspath (保持向后兼容) + multi_path_processed = os.path.abspath(self.multi_dataset_path) + + solution = au.generate_code( + schema=self.schema, + multi_schema=self.multi_schema, + goal=self.goal, + question=question, + database_path=os.path.abspath(self.dataset_path) if self.dataset_path else None, + # 使用上面处理好的变量 + multi_database_path=multi_path_processed, + output_folder=code_output_folder, + model_name=self.model_name, + n_retries=n_retries, + prompt_method=prompt_code_method, + temperature=self.temperature, + multi_profile=self.multi_profile, # NEW: Pass profile information + ) + + # Prompt 4: Interpret Solution + interpretation_dict = au.interpret_solution( + solution=solution, + model_name=self.model_name, + schema=self.schema, + n_retries=n_retries, + prompt_method=prompt_interpret_method, + temperature=self.temperature, + ) + answer = interpretation_dict["interpretation"]["answer"] + + if self.verbose: + print("\nSolution\n---------") + print(f"Question: {question}\n") + print(f"Answer: {answer}\n") + print( + f"Justification: {interpretation_dict['interpretation']['justification']}\n" + ) + + insight_dict = { + "question": question, + "answer": answer, + "insight": interpretation_dict["interpretation"]["insight"], + "justification": interpretation_dict["interpretation"]["justification"], + "output_folder": code_output_folder, + } + + # Save into the savedir + os.makedirs(code_output_folder, exist_ok=True) # 确保目录存在 + with open(os.path.join(code_output_folder, "insight.json"), "w", encoding='utf-8') as json_file: + json.dump(insight_dict, json_file, indent=4, sort_keys=True, ensure_ascii=False) + + # add to insights + self.insights_history += [insight_dict] + + insight_dict = copy.deepcopy(insight_dict) + insight_dict.update(self.get_insight_objects(insight_dict)) + + if return_insight_dict: + return answer, insight_dict + + return answer["answer"] + + + def get_insight_objects(self, insight_dict): + """ + Get Insight Objects + """ + if os.path.exists(os.path.join(insight_dict["output_folder"], "plot.jpg")): + # get plot.jpg + plot = Image.open(os.path.join(insight_dict["output_folder"], "plot.jpg")) + else: + plot = None + + if os.path.exists(os.path.join(insight_dict["output_folder"], "x_axis.jpg")): + # get x_axis.json + x_axis = json.load( + open(os.path.join(insight_dict["output_folder"], "x_axis.json"), "r") + ) + else: + x_axis = None + + if os.path.exists(os.path.join(insight_dict["output_folder"], "y_axis.json")): + # get y_axis.json + y_axis = json.load( + open(os.path.join(insight_dict["output_folder"], "y_axis.json"), "r") + ) + else: + y_axis = None + + if os.path.exists(os.path.join(insight_dict["output_folder"], "stat.json")): + try: + # get stat.json + stat = json.load( + open(os.path.join(insight_dict["output_folder"], "stat.json"), "r") + ) + except: + stat = None + else: + stat = None + + # get code.py + if os.path.exists(os.path.join(insight_dict["output_folder"], "code.py")): + code = open( + os.path.join(insight_dict["output_folder"], "code.py"), "r" + ).read() + else: + code = None + + insight_object = { + "plot": plot, + "x_axis": x_axis, + "y_axis": y_axis, + "stat": stat, + "code": code, + } + return insight_object + + def save_state_dict(self, fname): + with open(fname, "w", encoding='utf-8') as f: + json.dump(self.insights_history, f, indent=4, ensure_ascii=False) + + def load_state_dict(self, fname): + with open(fname, "r") as f: + self.insights_history = json.load(f) + diff --git a/workflow_engine/toolkits/insight_tool/dm_components/agents/datasource_agent.py b/workflow_engine/toolkits/insight_tool/dm_components/agents/datasource_agent.py new file mode 100644 index 0000000..f384558 --- /dev/null +++ b/workflow_engine/toolkits/insight_tool/dm_components/agents/datasource_agent.py @@ -0,0 +1,267 @@ +import os +import json +import copy +import tempfile +import pandas as pd +from PIL import Image + + +from dm_components import prompts +from dm_components.config import logger +from dm_components.workflows.insight_workflow import InsightWorkflow +from dm_components.agents.base_agent import AgentBase +from dm_components.utils import agent_utils as au +from dm_components.utils.dataloader_utils import DataSourceReader + + +from typing import TypedDict, List, Dict, Optional, Any +from langchain.schema import HumanMessage, SystemMessage +from langchain_openai import ChatOpenAI +from langgraph.graph import StateGraph, END +from langgraph.checkpoint.sqlite import SqliteSaver + + + +class DataSourceAgent: + """ + Agent representing a single data source. + Wraps an AgentBase instance internally for analysis capabilities. + """ + + def __init__(self, name: str, data: pd.DataFrame, original_file_path: str, + external_knowledge: str, agent_config: Dict[str, Any], + global_goal: str = ""): + """ + Initialize a DataSourceAgent. + + Args: + name: Agent identifier + data: DataFrame containing the dataset + original_file_path: Path to the original data file + external_knowledge: Domain knowledge description for this agent + agent_config: Configuration dictionary for agent behavior + global_goal: Overall analysis objective + """ + self.name = name + self.data = data + self.external_knowledge = external_knowledge + self.agent_config = agent_config + self.original_file_path = original_file_path + self.global_goal = global_goal + + # Initialize labels and metadata + self.profile = au.get_enhanced_data_profile(self.data) + self.importance_label = "Secondary" # Default value + self.preliminary_priority = "Medium" + self.final_priority = "Medium" + self.summary = "" + self.insights = [] + + # Create agent-specific directory for outputs + agent_save_path = os.path.join(agent_config['base_savedir'], self.name.replace(' ', '_')) + os.makedirs(agent_save_path, exist_ok=True) + + # Initialize the underlying AgentBase instance + self.agent_base = AgentBase( + model_name=agent_config['model_name'], + savedir=agent_save_path, + goal=f"Finding trends related to '{global_goal}' in {self.name} dataset", + verbose=True, + temperature=agent_config['temperature'], + n_retries=agent_config['n_retries'], + api_key=agent_config.get('api_key'), + base_url=agent_config.get('base_url') + ) + self.agent_base.set_table(table=self.data, dataset_path=self.original_file_path) + + # Initialize utilities + self.schema_str = au.schema_to_str(self.agent_base.schema) + self.chat_model = au.get_chat_model( + agent_config['model_name'], + agent_config['temperature'], + api_key=agent_config.get('api_key'), + base_url=agent_config.get('base_url') + ) + self.summary = "" + self.insights = [] + + def analyze_self(self) -> Dict[str, Any]: + """ + Phase 1: Independent analysis. Delegates to InsightWorkflow. + + Returns: + Dictionary containing analysis report + """ + logger.info(f"[{self.name} Agent]: Starting Phase 1: Independent Analysis...") + + try: + workflow = InsightWorkflow( + agent_base=self.agent_base, + branch_depth=self.agent_config.get('branch_depth', 2) + ) + + final_state = workflow.run( + initial_goal=self.agent_base.goal, + max_questions=self.agent_config.get('max_questions', 2) + ) + + self.insights = final_state.get('insights_history', []) + self.summary = final_state.get('final_summary', 'Analysis completed but no summary generated.') + + except Exception as e: + logger.error(f"[{self.name} Agent]: Independent analysis failed: {e}", exc_info=True) + self.summary = f"Analysis failed: {e}" + self.insights = [] + + report = { + "agent_name": self.name, + "summary": self.summary, + "key_metrics": self.insights, + "annotations": [], + } + + logger.info(f"[{self.name} Agent]: Analysis completed. Summary: {self.summary[:100]}...") + return report + + def annotate_other_agent_summary( + self, + report_to_annotate: Dict[str, Any], + max_summary_length: int = 1000, + max_insights_count: int = 3, + n_retries: int = 2 + ) -> Dict[str, str]: + """ + Core of background crossover: Generate annotations on another agent's report. + + Improved stability features: + - Input length limiting to prevent token overflow + - Retry logic for failed attempts + - Better response parsing with fallback + + Args: + report_to_annotate: Report from another agent to annotate + max_summary_length: Maximum characters for summary input + max_insights_count: Maximum number of insights to include + n_retries: Number of retry attempts on failure + + Returns: + Dictionary containing annotation information + """ + target_name = report_to_annotate['agent_name'] + + # Limit input lengths to prevent token overflow + target_summary = report_to_annotate.get('summary', '') + if len(target_summary) > max_summary_length: + target_summary = target_summary[:max_summary_length] + "... [truncated]" + + # Limit insights count and convert to string representation + target_insights = report_to_annotate.get('key_metrics', []) + if isinstance(target_insights, list) and len(target_insights) > max_insights_count: + target_insights = target_insights[:max_insights_count] + + # Convert insights to safe string representation + try: + if isinstance(target_insights, list): + insights_str = "\n".join([ + f"- {insight.get('question', 'N/A')}: {insight.get('answer', 'N/A')[:200]}" + if isinstance(insight, dict) else str(insight)[:300] + for insight in target_insights[:max_insights_count] + ]) + else: + insights_str = str(target_insights)[:1000] + except Exception: + insights_str = "[Insights unavailable]" + + # Also limit schema string + schema_str_limited = self.schema_str[:1500] if len(self.schema_str) > 1500 else self.schema_str + + def _attempt_annotation() -> str: + """Single annotation attempt.""" + prompt = prompts.ANNOTATION_PROMPT_TEMPLATE.format( + annotator_name=self.name, + annotator_knowledge=self.external_knowledge[:500], # Limit knowledge too + annotator_schema=schema_str_limited, + target_agent_name=target_name, + target_insight=insights_str, + target_summary=target_summary + ) + + response = self.chat_model(prompt) + comment = response.content if hasattr(response, 'content') else str(response) + return comment.strip() + + def _parse_comment(raw_response: str) -> str: + """Parse comment from response with fallback.""" + # Try to extract from tags + tags = au.extract_html_tags(raw_response, ["comment"]) + if tags and "comment" in tags and tags["comment"]: + return tags["comment"][0].strip() + + # Fallback: use raw response if it looks like valid content + cleaned = raw_response.strip() + + # Filter out meta-responses + skip_phrases = [ + "no comment", "nothing to add", "no additional", + "i don't have", "cannot provide", "unable to" + ] + if any(phrase in cleaned.lower() for phrase in skip_phrases): + return "" + + # If response is too short, it's likely not useful + if len(cleaned) < 10: + return "" + + # Limit output length + if len(cleaned) > 500: + cleaned = cleaned[:500] + "..." + + return cleaned + + # Attempt annotation with retries + comment = "" + last_error = None + + for attempt in range(n_retries + 1): + try: + raw_response = _attempt_annotation() + comment = _parse_comment(raw_response) + + if comment: # Success + logger.debug(f"[{self.name} Agent]: Annotation successful on attempt {attempt + 1}") + break + + except Exception as e: + last_error = e + logger.warning(f"[{self.name} Agent]: Annotation attempt {attempt + 1} failed: {e}") + + if attempt < n_retries: + # Reduce input size for retry + max_summary_length = max_summary_length // 2 + max_insights_count = max(1, max_insights_count - 1) + continue + + if not comment and last_error: + logger.warning(f"[{self.name} Agent]: All annotation attempts failed. Last error: {last_error}") + + return { + "author_agent": self.name, + "comment": comment + } + + def get_agent_info(self) -> Dict[str, Any]: + """ + Get comprehensive information about this agent. + + Returns: + Dictionary with agent metadata + """ + return { + "name": self.name, + "data_shape": self.data.shape, + "data_columns": list(self.data.columns), + "preliminary_priority": self.preliminary_priority, + "final_priority": self.final_priority, + "importance_label": self.importance_label, + "file_path": self.original_file_path + } diff --git a/workflow_engine/toolkits/insight_tool/dm_components/config.py b/workflow_engine/toolkits/insight_tool/dm_components/config.py new file mode 100644 index 0000000..3b28420 --- /dev/null +++ b/workflow_engine/toolkits/insight_tool/dm_components/config.py @@ -0,0 +1,120 @@ +# config.py +""" +Configuration file for the Insight Multi-Agent Framework. + +This file contains all configurable parameters for: +- LLM settings +- Data processing +- Output control +- Logging +""" +import os + +# ============================================================================= +# LLM 基础配置 +# ============================================================================= +MODEL_NAME = "gpt-4o" +TEMPERATURE = 0.0 +N_RETRIES = 4 + +# ============================================================================= +# 分析流程配置 +# ============================================================================= +BRANCH_DEPTH = 2 # 单源分析的深度探索层数 +MAX_QUESTIONS = 2 # 每次迭代的最大问题数 + +# ============================================================================= +# 输出目录配置 +# ============================================================================= +BASE_SAVEDIR = "/mnt/DataFlow/qry/DataCrossBench-Exp-Results/DataCross" + +# ============================================================================= +# 背景信息处理配置 (NEW) +# ============================================================================= +# 文本文件超过此字符数阈值时,将自动进行摘要提取 +TEXT_SUMMARY_THRESHOLD = 2000 + +# 图片分类使用的模型 (需要支持视觉能力) +IMAGE_CLASSIFICATION_MODEL = "gpt-4o" + +# ============================================================================= +# 输出控制配置 (NEW) +# ============================================================================= +# 默认输出模式: "concise" (简洁模式) 或 "detailed" (详细模式) +# - concise: 精简输出,适合快速查看结果 +# - detailed: 完整输出,包含detailed_appendix,适合benchmark对比 +DEFAULT_OUTPUT_MODE = "concise" + +# ============================================================================= +# 评分机制配置 (NEW) +# ============================================================================= +# 混合评分权重 +SCORING_WEIGHTS = { + "objective": 0.4, # 客观指标 (数据质量、丰富度、时间维度) + "semantic": 0.3, # 语义相关性 (关键词匹配) + "llm": 0.3 # LLM主观评分 +} + +# 优先级阈值 +PRIORITY_THRESHOLDS = { + "high": 7.0, # 分数 >= 7 为 High 优先级 + "medium": 4.0 # 分数 >= 4 为 Medium 优先级, 否则为 Low +} + +# ============================================================================= +# 批注流程配置 (NEW) +# ============================================================================= +# 批注输入限制 (防止token溢出) +ANNOTATION_MAX_SUMMARY_LENGTH = 1000 +ANNOTATION_MAX_INSIGHTS_COUNT = 3 +ANNOTATION_N_RETRIES = 2 + +# ============================================================================= +# LangGraph 相关 +# ============================================================================= +from langgraph.graph import StateGraph, END + +# ============================================================================= +# 类型提示 +# ============================================================================= +from typing import List, Dict, Any, TypedDict, Optional + +# ============================================================================= +# 日志配置 - 使用 Open-NotebookLM 的 logger +# ============================================================================= +import sys +import os +# 添加 workflow_engine 到路径以便导入 +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../../../..'))) +from workflow_engine.logger import get_logger +logger = get_logger(__name__) + + +# ============================================================================= +# 辅助函数 +# ============================================================================= +def get_config_dict() -> Dict[str, Any]: + """ + 获取所有配置项的字典形式,便于传递和记录。 + + Returns: + 包含所有配置项的字典 + """ + return { + "model_name": MODEL_NAME, + "temperature": TEMPERATURE, + "n_retries": N_RETRIES, + "branch_depth": BRANCH_DEPTH, + "max_questions": MAX_QUESTIONS, + "base_savedir": BASE_SAVEDIR, + "text_summary_threshold": TEXT_SUMMARY_THRESHOLD, + "image_classification_model": IMAGE_CLASSIFICATION_MODEL, + "default_output_mode": DEFAULT_OUTPUT_MODE, + "scoring_weights": SCORING_WEIGHTS, + "priority_thresholds": PRIORITY_THRESHOLDS, + "annotation_config": { + "max_summary_length": ANNOTATION_MAX_SUMMARY_LENGTH, + "max_insights_count": ANNOTATION_MAX_INSIGHTS_COUNT, + "n_retries": ANNOTATION_N_RETRIES + } + } \ No newline at end of file diff --git a/workflow_engine/toolkits/insight_tool/dm_components/insight_entry.py b/workflow_engine/toolkits/insight_tool/dm_components/insight_entry.py new file mode 100644 index 0000000..4bd8ada --- /dev/null +++ b/workflow_engine/toolkits/insight_tool/dm_components/insight_entry.py @@ -0,0 +1,457 @@ +# insight_discovery.py +""" +Main API for multi-dataset insight discovery system. +Provides high-level interfaces for both folder-based and single-dataset analysis. + +Now supports: +- Non-tabular data handling (txt, images) +- Background knowledge collection and injection +- Concise/Detailed output modes +""" + +import os +import json +import pandas as pd +from typing import List, Tuple, Dict, Any, Optional, Union + +from dm_components.config import logger +from dm_components.agents.datasource_agent import DataSourceAgent +from dm_components.utils.dataloader_utils import DataSourceReader +from dm_components.workflows.orches_workflow import OrchestratorWorkflow + + +# Type alias for background text data +BackgroundTextData = Dict[str, Any] + + +class InsightEntry: + """ + Main interface for automated insight discovery across datasets. + + This class provides two primary modes of operation: + 1. analyze_folder(): Analyze all datasets in a folder (with meta-info.json support) + 2. analyze_single_dataset(): Analyze one or two specific datasets + + New features: + - Handles non-tabular data (txt, images) as background knowledge + - Supports output_mode: "concise" or "detailed" + - Collects and injects background information into final synthesis + + Both methods return a tuple of (insights_list, summary_string) or + (insights_list, summary_string, detailed_appendix) in detailed mode. + """ + + def __init__(self, + model_name: str = "gpt-4.1-nano", + base_savedir: str = "./outputs", + temperature: float = 0.1, + n_retries: int = 1, + branch_depth: int = 1, + max_questions: int = 1, + text_summary_threshold: int = 2000, + default_output_mode: str = "concise", + api_key: str = "", + base_url: str = ""): + """ + Initialize the insight discovery system. + + Args: + model_name: LLM model to use for analysis + base_savedir: Base directory for saving outputs + temperature: LLM temperature (0.0-1.0) + n_retries: Number of retries for failed LLM calls + branch_depth: Exploration depth for single-dataset analysis + max_questions: Max questions per iteration + text_summary_threshold: Character threshold for text summarization (NEW) + default_output_mode: Default output mode - "concise" or "detailed" (NEW) + api_key: API key for LLM (NEW) + base_url: Base URL for LLM API (NEW) + """ + self.model_name = model_name + self.base_savedir = base_savedir + self.temperature = temperature + self.n_retries = n_retries + self.branch_depth = branch_depth + self.max_questions = max_questions + self.text_summary_threshold = text_summary_threshold + self.default_output_mode = default_output_mode + self.api_key = api_key + self.base_url = base_url + + # Set global API config for all downstream functions + from dm_components.utils import agent_utils as au + au.set_global_api_config(api_key=api_key, base_url=base_url) + + # Background knowledge pool for non-tabular data + self.background_knowledge_pool: List[BackgroundTextData] = [] + + # Ensure output directory exists + os.makedirs(self.base_savedir, exist_ok=True) + + logger.info(f"InsightEntry initialized with model={model_name}, " + f"output_dir={base_savedir}, output_mode={default_output_mode}") + + def _get_agent_config(self) -> Dict[str, Any]: + """Get unified agent configuration.""" + return { + "model_name": self.model_name, + "base_savedir": self.base_savedir, + "temperature": self.temperature, + "n_retries": self.n_retries, + "branch_depth": self.branch_depth, + "max_questions": self.max_questions, + "api_key": self.api_key, + "base_url": self.base_url + } + + def _create_agent_from_dataframe(self, + name: str, + dataframe: pd.DataFrame, + file_path: str, + description: str) -> Optional[DataSourceAgent]: + """ + Create a DataSourceAgent from a DataFrame. + + Args: + name: Agent name + dataframe: Data to analyze + file_path: Original file path + description: Data source description + + Returns: + DataSourceAgent instance or None if creation fails + """ + try: + external_knowledge = ( + f"An expert analyst in {name} domain. " + f"Data source: {description}" + ) + + agent = DataSourceAgent( + name=name, + data=dataframe, + original_file_path=file_path, + external_knowledge=external_knowledge, + agent_config=self._get_agent_config(), + global_goal="" # Will be set by workflow + ) + + logger.info(f"Created agent [{name}] with {len(dataframe)} rows") + return agent + + except Exception as e: + logger.error(f"Failed to create agent {name}: {e}") + return None + + def _read_meta_config(self, data_folder: str) -> Dict[str, Any]: + """ + Read meta.json configuration from parent directory. + + Args: + data_folder: Data folder path + + Returns: + Meta configuration dictionary + """ + parent_dir = os.path.dirname(data_folder) + meta_path = os.path.join(parent_dir, "meta-info.json") + + if not os.path.exists(meta_path): + logger.info(f"No meta-info.json found at {meta_path}") + return {} + + try: + with open(meta_path, 'r', encoding='utf-8') as f: + meta_data = json.load(f) + + logger.info(f"Loaded meta.json: goal='{meta_data.get('goal', 'unspecified')}'") + return meta_data + + except Exception as e: + logger.error(f"Failed to read meta-info.json: {e}") + return {} + + def _process_single_file( + self, + file_path: str, + processed_dir: str, + include_background: bool = True + ) -> Tuple[List[DataSourceAgent], List[BackgroundTextData]]: + """ + Process a single data file and create appropriate agents. + + Now handles non-tabular data (txt, images) as background knowledge. + + Args: + file_path: Path to data file + processed_dir: Directory for processed files + include_background: Whether to collect non-tabular data as background + + Returns: + Tuple of (list of created agents, list of background info dicts) + """ + agents = [] + background_info = [] + filename = os.path.basename(file_path) + + try: + # Read data using enhanced DataSourceReader + loaded_data = DataSourceReader.read_data( + file_path, + as_background=False, # Try structured first + max_chars_for_direct_use=self.text_summary_threshold + ) + + # Check if it's background text data (non-tabular) + if isinstance(loaded_data, dict) and loaded_data.get('type') == 'background_text': + # This is non-tabular data (txt or chart image) + if include_background: + logger.info(f"Collected background info from: {filename}") + background_info.append(loaded_data) + else: + logger.info(f"Skipping background data (include_background=False): {filename}") + return agents, background_info + + if isinstance(loaded_data, dict) and 'type' not in loaded_data: + # Multi-table file (e.g., SQLite) - dict of DataFrames + logger.info(f"Processing multi-table file: {filename}") + file_basename = os.path.splitext(filename)[0] + + for table_name, table_df in loaded_data.items(): + if not isinstance(table_df, pd.DataFrame): + continue + + csv_name = f"{file_basename}_{table_name}.csv" + processed_file_path = os.path.join(processed_dir, csv_name) + table_df.to_csv(processed_file_path, index=False) + + agent_name = f"{file_basename}-{table_name}" + description = f"Table {table_name} in {filename}" + agent = self._create_agent_from_dataframe( + agent_name, table_df, processed_file_path, description + ) + if agent: + agents.append(agent) + + elif isinstance(loaded_data, pd.DataFrame): + # Single-table file + logger.info(f"Processing single-table file: {filename}") + file_basename = os.path.splitext(filename)[0] + csv_name = f"{file_basename}.csv" + processed_file_path = os.path.join(processed_dir, csv_name) + loaded_data.to_csv(processed_file_path, index=False) + + description = f"File {filename}" + agent = self._create_agent_from_dataframe( + file_basename, loaded_data, processed_file_path, description + ) + if agent: + agents.append(agent) + + else: + logger.warning(f"Unknown data type from {filename}: {type(loaded_data)}") + + except Exception as e: + logger.error(f"Failed to process file {filename}: {e}") + + return agents, background_info + + def analyze_folder( + self, + data_folder: str, + use_meta_goal: bool = True, + output_mode: Optional[str] = None, + include_background: bool = True + ) -> Union[Tuple[List[str], str], Tuple[List[str], str, Dict[str, Any]]]: + """ + Analyze all datasets in a folder. + + Now supports: + - Non-tabular data (txt, images) as background knowledge + - Concise/Detailed output modes + + Args: + data_folder: Path to folder containing data files + use_meta_goal: Whether to use goal from meta-info.json + output_mode: "concise" or "detailed" (defaults to self.default_output_mode) + include_background: Whether to process non-tabular data as background + + Returns: + Tuple of (insights, summary) in concise mode, or + Tuple of (insights, summary, detailed_appendix) in detailed mode + """ + logger.info(f"Analyzing data folder: {data_folder}") + + # Use default output mode if not specified + output_mode = output_mode or self.default_output_mode + + # Validate folder + if not os.path.exists(data_folder): + logger.error(f"Data folder not found: {data_folder}") + return { + "synthesized_insights": [], + "raw_insights": [], + "summary": f"Data folder not found: {data_folder}", + "detailed_appendix": {} + } + + # Read meta configuration + meta_data = self._read_meta_config(data_folder) if use_meta_goal else {} + global_goal = meta_data.get('goal', 'Discover insights from multiple datasets') + + processed_dir = os.path.join(data_folder, "processed") + os.makedirs(processed_dir, exist_ok=True) + + # Process all files - collect agents and background info + all_agents = [] + all_background_info = [] + + for filename in os.listdir(data_folder): + file_path = os.path.join(data_folder, filename) + + if os.path.isdir(file_path): # 跳过目录 + continue + + agents, background_items = self._process_single_file( + file_path, + processed_dir, + include_background=include_background + ) + all_agents.extend(agents) + all_background_info.extend(background_items) + + # Store background info for reference + self.background_knowledge_pool = all_background_info + + # Check if any agents were created + if not all_agents: + # If we have background info but no agents, provide a warning + if all_background_info: + logger.warning("No tabular data found, but collected background info. " + "Analysis requires at least one tabular data source.") + logger.error("No data agents were successfully created") + return { + "synthesized_insights": [], + "raw_insights": [], + "summary": "No analyzable data found", + "detailed_appendix": {} + } + + logger.info(f"Created {len(all_agents)} agents from {data_folder}") + if all_background_info: + logger.info(f"Collected {len(all_background_info)} background knowledge items") + + # Run multi-agent analysis + try: + workflow = OrchestratorWorkflow( + data_agents=all_agents, + global_goal=global_goal + ) + + # Run with background info and output mode + result = workflow.run( + output_mode=output_mode, + background_knowledge_pool=all_background_info + ) + + # result is now a dictionary with keys: synthesized_insights, raw_insights, summary, detailed_appendix + insights = result.get("synthesized_insights", []) + raw_insights = result.get("raw_insights", []) + summary = result.get("summary", "") + detailed_appendix = result.get("detailed_appendix", {}) + + logger.info(f"Analysis completed: {len(insights)} synthesized insights, {len(raw_insights)} raw insights (mode: {output_mode})") + + # Return the full result dictionary + return result + + except Exception as e: + logger.error(f"Multi-agent analysis failed: {e}", exc_info=True) + # Return dictionary format for consistency + return { + "synthesized_insights": [], + "raw_insights": [], + "summary": f"Analysis failed: {str(e)}", + "detailed_appendix": {} + } + + + + def analyze_insight_bench(self, dataset_csv_path: str, user_dataset_csv_path: Optional[str] = None, **kwargs) -> Tuple[List[str], str]: + """ + 分析单个数据集 (可能包含用户数据集) 并返回洞察。 + 这个方法模仿 analyze_folder 的内部逻辑,但只处理传入的 CSV 文件路径。 + + Args: + dataset_csv_path: 核心数据集的 CSV 路径 + user_dataset_csv_path: 用户提供的数据集的 CSV 路径 (可选) + **kwargs: 允许传入 max_questions, branch_depth 等用于覆盖 Agent 配置的参数 + + Returns: + Tuple[pred_insights, pred_summary]: 洞察列表和总结 + """ + all_agents: List[DataSourceAgent] = [] + + # 临时更新 agent_config 以接受 exp_dict 中的参数覆盖 + current_config = self.agent_config.copy() + current_config.update(kwargs) + + + def _create_single_agent_from_path(file_path: str, agent_suffix: str, source_description: str): + """内部辅助函数:从路径读取数据并创建 Agent""" + try: + # 1. 读取数据 + loaded_data = DataSourceReader.read_data(file_path) + if not isinstance(loaded_data, pd.DataFrame): + logger.warning(f"文件 {os.path.basename(file_path)} 返回了非 DataFrame 数据,跳过。") + return + + # 2. 创建 Agent + agent_name = agent_suffix.replace('_', ' ').strip().title() + # 外部知识可以定义 Agent 的角色和数据来源 + external_knowledge = f"一名 {agent_name} 领域的专家分析师。数据来源: {source_description}" + + agent = DataSourceAgent( + name=agent_name, + data=loaded_data, + original_file_path=file_path, + external_knowledge=external_knowledge, + agent_config=current_config # 使用更新后的配置 + ) + all_agents.append(agent) + logger.info(f"成功创建 Agent: [{agent_name}] (文件: {os.path.basename(file_path)})") + except Exception as e: + logger.error(f"处理文件 {os.path.basename(file_path)} 时发生意外错误: {e}") + + + # 1. 处理核心数据集 + _create_single_agent_from_path( + file_path=dataset_csv_path, + agent_suffix="Core Dataset", + source_description=f"核心数据集文件: {os.path.basename(dataset_csv_path)}" + ) + + # 2. 处理用户数据集 (如果存在) + if user_dataset_csv_path and os.path.exists(user_dataset_csv_path): + _create_single_agent_from_path( + file_path=user_dataset_csv_path, + agent_suffix="User Dataset", + source_description=f"用户提供的数据集文件: {os.path.basename(user_dataset_csv_path)}" + ) + + if not all_agents: + logger.error(f"未能成功加载数据集 {os.path.basename(dataset_csv_path)} 的任何 Agent。") + return { + "synthesized_insights": [], + "raw_insights": [], + "summary": "没有可分析的数据", + "detailed_appendix": {} + } + + # 3. 运行工作流 (与 analyze_folder 相同) + logger.info("===== 运行单个数据集分析工作流 =====") + workflow = OrchestratorWorkflow(data_agents=all_agents) + result = workflow.run() + + # Return the full result dictionary + return result diff --git a/workflow_engine/toolkits/insight_tool/dm_components/main.py b/workflow_engine/toolkits/insight_tool/dm_components/main.py new file mode 100644 index 0000000..c20103b --- /dev/null +++ b/workflow_engine/toolkits/insight_tool/dm_components/main.py @@ -0,0 +1,33 @@ +# main.py +import os +from insight_entry import InsightEntry + +def main(): + + if "QDF_API_KEY" not in os.environ or "QDF_API_URL" not in os.environ: + print("Error: Please set QDF_API_KEY and QDF_API_URL environments.") + return + + # 创建分析器实例 + analyzer = InsightEntry( + model_name="gpt-4.1-nano", + base_savedir="./outputs", + temperature=0.1, + n_retries=1, + branch_depth=1, + max_questions=1 + ) + + sample_data_dir = "./insight/sample_data/flag-99/output" + insights, summary = analyzer.analyze_folder(sample_data_dir) + + print("\n=== Analysis Results ===") + print(f"Summary: {summary}") + print("\nInsights:") + for i, insight in enumerate(insights, 1): + print(f"{i}. {insight}") + + return insights, summary + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/workflow_engine/toolkits/insight_tool/dm_components/prompts/__init__.py b/workflow_engine/toolkits/insight_tool/dm_components/prompts/__init__.py new file mode 100644 index 0000000..fb996ce --- /dev/null +++ b/workflow_engine/toolkits/insight_tool/dm_components/prompts/__init__.py @@ -0,0 +1,1137 @@ +# prompts.py +""" +All prompt templates for the multi-agent analysis system. +All prompts are in English to ensure consistency with LLM models. +""" + +# ============================================================================== +# Core Prompt Definitions for Data Annotation and Cross-Analysis +# ============================================================================== + +PRELIMINARY_EVAL_PROMPT = """ +Role: Data Strategy Consultant +Global Goal: {global_goal} +Current Dataset Metadata (Schema & Stats): +{data_profile} + +Task: +Based ONLY on the schema and statistics provided, evaluate the potential relevance of this dataset to the Global Goal. +Identify if this data contains core KPIs, key dimensions, or is likely just background noise. + +Output Format: +Please wrap your evaluation in the following tags: +1-10 +A brief explanation +High/Medium/Low +""" + +FORMAL_ANNOTATION_PROMPT = """ +Role: Chief Data Scientist +Global Goal: {global_goal} +Given the following schema: +{schema} +Exploration Summary from the Agent's Deep-Dive: +{exploration_summary} + +Task: +Perform a final assessment of this data's importance to the global objective. +Metrics: +- Information Richness (1-10): How deep and high-quality are the insights found? +- Theme Alignment (1-10): How directly does this support the Global Goal? + +Decision Criteria: +- "Primary": Contains core metrics; can drive the main analysis. +- "Secondary": Provides context, auxiliary dimensions, or validation. + +Output Format: +1-10 +1-10 + +Detailed reason +""" + +CROSS_QUESTION_PROMPT = """ +Role: Cross-domain Analyst +Dataset A (Your Data) Summary: {my_summary} +Dataset B (Target Data) Summary: {other_summary} +Your Label: {my_label} + +Task: +Generate analytical questions that require JOINING or COMPARING both datasets to find hidden patterns. +Constraint: +- If your label is "Primary", generate 3 deep questions. +- If your label is "Secondary", generate 1 focused question. + +Output Format: +Generate your questions, each enclosed in tags. +Example: Your question text (Rationale: ...) +""" + +# Background annotation prompt +ANNOTATION_PROMPT_TEMPLATE = """ +Role: Domain Expert & Critical Reviewer + +You are: {annotator_name} +Your Domain Knowledge: {annotator_knowledge} +Your Data Schema: {annotator_schema} + +You are reviewing analysis results from another agent. +Target Agent: {target_agent_name} +Target Agent's Analysis Insights: {target_insight} +Target Agent's Analysis Summary: {target_summary} + +Task: +Provide critical comments or cross-domain insights based on your expertise. +Focus on: +1. Missing perspectives that your data might provide +2. Potential data quality issues +3. Alternative interpretations +4. Connections to broader business context + +If you have no meaningful comments, respond with "no comment". +Otherwise, provide concise but insightful feedback. + +Output Format: +Your critical feedback here +""" + +# # Numerical crossover ideation prompt +# NUMERICAL_CROSSOVER_IDEATION_PROMPT = """ +# Role: Senior Data Analyst +# Global Goal: {global_goal} + +# Context from all datasets: +# {context} + +# Task: +# Based on the analysis summaries and cross-agent annotations above, generate specific, actionable analytical questions that require combining data from multiple datasets. + +# Focus on questions that: +# 1. Reveal relationships between different datasets +# 2. Combine metrics from primary and secondary datasets +# 3. Uncover hidden patterns through data joins +# 4. Address the Global Goal through multi-dataset analysis + +# Generate 3-5 high-quality analytical questions. Each question should specify which datasets need to be combined. + +# Output Format: +# Wrap each question in tags. +# Question 1: Clear description of what to analyze and why (Specify datasets: DatasetA + DatasetB) +# Question 2: ... +# """ + +# Final synthesis prompt +FINAL_PROMPT_TEMPLATE = """ +Role: Senior Business Intelligence Analyst + +Context from Complete Multi-Agent Analysis: +{full_context} + +Task: +Synthesize all analyses, cross-dataset findings, and agent annotations into a comprehensive final report. + +Your output should include: +1. Executive Summary (2-3 paragraphs) +2. Key Insights (bullet points, prioritized by importance) +3. Cross-dataset Discoveries +4. Limitations and Data Quality Notes +5. Recommended Next Steps + +Format your response as a JSON object with the following structure: +{{ + "summary": "executive summary text here", + "insights": ["insight 1", "insight 2", ...], + "cross_dataset_discoveries": ["discovery 1", ...], + "limitations": ["limitation 1", ...], + "next_steps": ["recommendation 1", ...] +}} +""" + +# Report generation prompt +REPORT_GENERATION_PROMPT = """ +Role: Technical Report Writer + +Analysis Workflow Final State: +{state_str} + +Task: +Convert this technical workflow state into a professional Markdown report suitable for business stakeholders. +The report should be clear, concise, and focus on actionable insights rather than technical details. + +Include: +1. Title and Executive Summary +2. Analysis Methodology Overview +3. Key Findings by Dataset +4. Cross-Dataset Insights +5. Limitations and Assumptions +6. Recommendations +7. Appendices (technical details if necessary) + +Format the entire report in Markdown with appropriate headings and structure. +""" + +# __all__ = [ +# 'PRELIMINARY_EVAL_PROMPT', +# 'FORMAL_ANNOTATION_PROMPT', +# 'CROSS_QUESTION_PROMPT', +# 'ANNOTATION_PROMPT_TEMPLATE', +# 'NUMERICAL_CROSSOVER_IDEATION_PROMPT', +# 'FINAL_PROMPT_TEMPLATE', +# 'REPORT_GENERATION_PROMPT' +# ] + + +FINAL_PROMPT_TEMPLATE = """ +你是一名首席战略官。你已经收到了来自多个部门的报告,包括他们之间的同行评审(批注),以及一份跨部门的数值交叉分析报告。 +你的任务是综合所有这些信息,提炼出 2-5 个高层次的、跨职能的洞察。如果给你的信息中包含了前置内容的一些报错情况,请忽略,不要反映在最终报告中。 +请专注于识别单一部门会错过的因果联系、利弊权衡和战略机会。 + +**完整上下文:** +{full_context} + +**指示:** +请生成一个包含两个部分的综合报告: +1. **pred_insights**: 2-5个高层次的洞察,每个洞察应该是一个完整的句子,并包含洞察类型标签(如Trend:, Comparison:, Extreme:, Attribution:等) +2. **pred_summary**: 一个简短的总结段落,概括报告的核心内容 + +**重要:所有内容(insights 的值、summary)必须用中文撰写。只有 JSON 的 key(如 "insights"、"summary")和前缀(如 "Trend:"、"Comparison:")保持英文。** + +请使用以下JSON格式输出你的回答: +{{ + "insights": [ + "Trend: 洞察内容1(用中文撰写)", + "Comparison: 洞察内容2(用中文撰写)", + "Extreme: 洞察内容3(用中文撰写)", + "Attribution: 洞察内容4(用中文撰写)" + ], + "summary": "总结内容(用中文撰写)" +}} +""" + +# 支持简洁模式和详细模式的最终合成Prompt +FINAL_PROMPT_TEMPLATE_WITH_MODES = """ +你是一位**经验丰富的、专注于生成可操作业务报告的资深数据分析师**。 + +你的任务是:基于提供的所有分析报告、交叉批注、计算结果和背景信息(即 Context),生成一份**最终综合分析报告**。 + +### 分析内容 +{full_context} + +### 背景知识(仅供参考,核心结论需基于数据分析) +{background_info} + +### 输出模式: {output_mode} + +### 输出要求 + +**如果输出模式是 "concise" (简洁模式):** +1. 每个数据源只保留最重要的 1-2 条洞察 +2. 重点突出跨源发现和联合分析结论 +3. 总结控制在 500 字以内 + +**如果输出模式是 "detailed" (详细模式):** +1. 保留所有数据源的完整洞察 +2. 包含详细的跨源分析和背景信息关联 +3. 在 detailed_appendix 中保存完整的原始报告供参考 + +### JSON 输出格式 +{{ + "insights": [ + "Trend: 洞察1 (关键趋势发现,如时间序列变化、增长/下降趋势)", + "Comparison: 洞察2 (跨源对比分析,如不同数据源之间的差异、关联)", + "Extreme: 洞察3 (异常值或极端情况,如最大值、最小值、异常波动)", + "Attribution: 洞察4 (归因分析,如因果关系、影响因素分析)" + ], + "summary": "综合性摘要,概括核心发现和业务含义", + "detailed_appendix": {{ + "full_reports": ["仅详细模式填充"], + "crossover_results": ["仅详细模式填充"], + "background_info": ["仅详细模式填充"] + }} +}} + +**重要要求:** +- **语言要求**:所有内容(insights 的值、summary)必须用**中文**撰写。只有 JSON 的 key(如 "insights"、"summary")和前缀(如 "Trend:"、"Comparison:")保持英文。 +- insights 列表中的每条洞察必须以 "Trend:"、"Comparison:"、"Extreme:" 或 "Attribution:" 开头,但冒号后的内容必须用中文撰写 +- 尽量覆盖所有四种类型,如果某种类型没有相关发现,可以省略 +- 每条洞察应该是完整的句子,清晰表达发现 +- summary 必须用中文撰写,全面概括核心发现和业务含义 + +**注意:** +- 如果是简洁模式,detailed_appendix 应为空对象 {{}} +- 如果是详细模式,detailed_appendix 应包含完整信息 +- 忽略任何错误信息,只关注有效的分析结果 +""" + +# 用于背景信息注入的简化模板 +BACKGROUND_INFO_SECTION_TEMPLATE = """ +--- Background Knowledge (Reference Only) --- +The following background information is provided for context. + +{background_content} + +--- End of Background Knowledge --- +""" + +FINAL_PROMPT_TEMPLATE_IB = """ +你是一位**经验丰富的、专注于生成可操作业务报告的资深数据分析师**。 + +你的任务是:基于提供的所有分析报告、交叉批注和计算结果(即 Context),生成一份**最终综合分析报告**。 + +### 输出要求 + +1. **内容来源**:严格且仅基于提供的 **Context** 信息进行总结和洞察提取。 +2. **格式要求**:**必须**以 **JSON** 格式输出,并且只包含 **'insights'** 和 **'summary'** 两个顶级键。 +3. **洞察 (insights)**: + * 必须是**列表 (List)** 形式。 + * 每条洞察应是一个独立的、简洁的**陈述句**。 + * 侧重于**关键发现、异常值或重要趋势** +4. **总结 (summary)**: + * 必须是**字符串 (String)** 形式。 + * 内容应是结构化的、详细的**叙述性段落**,全面概述核心发现及其业务含义。 + +### Context (分析内容和中间结果) +{full_context} + +### 最终输出格式 (必须是 JSON,用英文撰写) +{{ + "insights": [ + "第一个关键发现...", + "第二个关键发现..." + ], + "summary": "(详细且结构化的叙述性总结)" +}} +""" + + +SYNTHESIS_PROMPT_TEMPLATE = """ +你是一名首席战略官。你已经收到了来自多个部门的报告,包括他们之间的同行评审(批注),以及一份跨部门的数值交叉分析报告。 +你的任务是综合所有这些信息,提炼出 2-5 个高层次的、跨职能的洞察。 +请专注于识别单一部门会错过的因果联系、利弊权衡和战略机会。 + +**完整上下文:** +{full_context} + +**指示:** +生成一个最终的洞察列表。每个洞察都应该是一个完整的句子。 +请使用 Markdown 的无序列表(以 - 开始)格式化你的回答。 +""" + +# ### [NEW] ### +# 为数值交叉步骤生成问题的Prompt +NUMERICAL_CROSSOVER_IDEATION_PROMPT = """ +你是一位经验丰富的数据分析主管。你已经收到了各个部门的初步分析报告和他们之间的交叉评论(批注)。 +你的任务是基于这些信息,提出 1-3 个需要进行**跨数据集数值计算**的具体问题。 + +这些问题应该是: +1. **具体的**:可以被转化为代码执行。例如,“比较营销部门的广告支出和销售部门的销售额随时间变化的趋势” 而不是 “看看营销和销售有没有关系”。 +2. **跨领域的**:需要联合至少两个数据源才能回答。 +3. **有价值的**:能够揭示单一部门无法发现的深层联系。 + +**背景信息:** +{context} + +请严格按照以下格式,只输出需要计算的问题,每个问题占一行,不要有其他多余的文字: +问题1 +问题2 +... +""" + +INTERPRET_SOLUTION = """ +### Instruction: +You are trying to answer a question based on information provided by a data scientist. + +Given the context: + + You need to answer a question based on information provided by a data scientist. + + +Given the following dataset schema: +{schema} + +Given the goal: +{goal} + +Given the question: +{question} + +Given the analysis: + + + {message} + + {insights} + + +Instructions: +* Based on the analysis and other information provided above, write an answer to the question enclosed with tags. +* **重要:所有内容(answer、insight、justification)必须用中文撰写。** +* The answer should be a single sentence, but it should not be too high level and should include the key details from justification. +* Write your answer in HTML-like tags, enclosing the answer between tags, followed by a justification between tags, followed by an insight between tags. +* Refer to the following example response for the format of the answer and justification. +* The insight should be something interesting and grounded based on the question, goal, and the dataset schema, something that would be interesting. +* The insight should be as quantiative as possible and informative and non-trivial and concise. +* The insight should be a meaningful conclusion that can be acquired from the analysis in laymans terms + +Example response: +This is a sample answer +This is a sample insight +This is a sample justification + +### Response: +""" + + +# =========================== +# (1) Recommend Questions Prompts +# =========================== +def get_question_prompt(method="basic"): + if method == "basic": + prompt_template = GET_QUESTIONS_TEMPLATE + system_template = GET_QUESTIONS_SYSTEM_MESSAGE + if method == "follow_up": + prompt_template = FOLLOW_UP_TEMPLATE + system_template = FOLLOW_UP_SYSTEM_MESSAGE + if method == "follow_up_with_type": + prompt_template = FOLLOW_UP_TYPE_TEMPLATE + system_template = FOLLOW_UP_SYSTEM_MESSAGE + + return prompt_template, system_template + + +# =========================== +# (2) CODE Prompts +# =========================== + +def get_code_prompt(method=None): + """ + Returns the appropriate prompt template for code generation based on the method. + """ + code_template = None # Initialize + + if method == "single" or method == "basic": + # 【修改】 指向我们强化的 SINGLE 模板 + code_template = GENERATE_CODE_SINGLE_TEMPLATE + + elif method == "multi": + # 【修改】 指向我们强化的 REINFORCED_MULTI 模板 + code_template = REINFORCED_MULTI_CODE_PROMPT + + elif method == "multi_with_paths": + # 【修改】 指向我们强化的 REINFORCED_MULTI 模板 + code_template = REINFORCED_MULTI_CODE_PROMPT + + else: + # 添加一个后备/默认选项或抛出错误,避免UnboundLocalError + print(f"Warning: Code prompt method '{method}' not recognized. Falling back to 'single'.") + # 【修改】 默认也使用强化的 SINGLE 模板 + code_template = GENERATE_CODE_SINGLE_TEMPLATE + + return code_template + + +# =========================== +# (3) Interpret Prompt +# =========================== + +# 在文件顶部,和 INTERPRET_SOLUTION 放在一起 +# 我们为多文件场景创建一个新的(或者复用现有的)解释模板 +# 这里我们复用 INTERPRET_SOLUTION,因为它足够通用。 +# 如果需要更复杂的,可以专门为多文件场景写一个。 +INTERPRET_SOLUTION_MULTI = INTERPRET_SOLUTION + +def get_interpret_prompt(method): + prompt_template = None # 先初始化为 None + + if method == "basic": + prompt_template = INTERPRET_SOLUTION + + # 增加对 "interpret" 方法的处理,因为你的 agents.py 可能也用到了这个默认值 + elif method == "interpret": + prompt_template = INTERPRET_SOLUTION + + # 增加对多文件场景的处理,这直接解决了你的 UnboundLocalError + elif method == "multi_with_paths": + prompt_template = INTERPRET_SOLUTION_MULTI # 使用多文件解释模板 + + # 提供一个健壮的后备选项 + else: + print(f"Warning: Interpret prompt method '{method}' not recognized. Falling back to 'basic'.") + prompt_template = INTERPRET_SOLUTION + + return prompt_template + + +# =========================== +# (4) Summarize Insights Prompt +# =========================== +def get_summarize_prompt(method="basic"): + if method == "basic": + prompt_template = SUMMARIZE_TEMPLATE + system_template = SUMMARIZE_SYSTEM_MESSAGE + + return prompt_template, system_template + + +GET_QUESTIONS_TEMPLATE = """ +### Instruction: + +Given the following context: +{context} + +Given the following goal: +{goal} + +Given the following schema: +{schema} + +Instructions: +* Write a list of questions to be solved by the data scientists in your team to explore my data and reach my goal. +* Explore diverse aspects of the data, and ask questions that are relevant to my goal. +* You must ask the right questions to surface anything interesting (trends, anomalies, etc.) +* Make sure these can realistically be answered based on the data schema. +* The insights that your team will extract will be used to generate a report. +* Each question should only have one part, that is a single '?' at the end which only require a single answer. +* Do not number the questions. +* You can produce at most {max_questions} questions. Stop generation after that. +* Most importantly, each question must be enclosed within tags. Refer to the example response below: + +Example response: +What is the average age of the customers? +What is the distribution of the customers based on their age? + +### Response: +""" + +GET_QUESTIONS_SYSTEM_MESSAGE = """ +You the manager of a data science team whose goal is to help stakeholders within your company extract actionable insights from their data. +You have access to a team of highly skilled data scientists that can answer complex questions about the data. +You call the shots and they do the work. +Your ultimate deliverable is a report that summarizes the findings and makes hypothesis for any trend or anomaly that was found. +""" + + + +RETRY_TEMPLATE = """You failed. + +Instructions: +------------- +{initial_prompt} +------------- + +Completion: +------------- +{prev_output} +------------- + +Above, the Completion did not satisfy the constraints given in the Instructions. +Error: +------------- +{error} +------------- + +Please try again. Do not apologize. Please only respond with an answer that satisfies the constraints laid out in the Instructions: + +""" + + +GET_INSIGHTS_TEMPLATE = """ +Hi, I require the services of your team to help me reach my goal. + +{context} + +{goal} + +{schema} + +Instructions: +* Produce a list of possible insights that we should look into to explore my data and reach my goal. +* Explore diverse aspects of the data, and present possible interesting insights (with explanation) that are relevant to my goal. +* Make sure these can realistically be based on the data schema. +* The insights that your team will extract will be used to insight a report. +* Each question that you produce must be enclosed in tags. +* Do not number the questions. +* You can produce at most {max_questions} insight. + +""" + +GET_INSIGHTS_SYSTEM_MESSAGE = """ +You the manager of a data science team whose goal is to help stakeholders within your company extract actionable insights from their data. +You have access to a team of highly skilled data scientists that can answer complex questions about the data. +You call the shots and they do the work. +Your ultimate deliverable is a report that summarizes the findings and makes hypothesis for any trend or anomaly that was found. +""" + + +GET_DATASET_DESCRIPTION_TEMPLATE = """ +Hi, I require the services of your team to help me reach my goal. + +{context} + +{goal} + +{schema} + +Instructions: +* Generate a description of the dataset provided in the schema. +* The description should include the number of rows, columns, and a brief summary of the data. +* The description should be enclosed inside content tags. + +""" + +GET_DATASET_DESCRIPTION_SYSTEM_MESSAGE = """ +You the manager of a data science team whose goal is to help stakeholders within your company extract actionable insights from their data. +You have access to a team of highly skilled data scientists that can answer complex questions about the data. +You call the shots and they do the work. +Your ultimate deliverable is a report that summarizes the findings and makes hypothesis for any trend or anomaly that was found. +""" + +FOLLOW_UP_TEMPLATE = """ +Hi, I require the services of your team to help me reach my goal. + +{context} + +{goal} + +{schema} + +{question} + +{answer} + +Instructions: +* Produce a list of follow up questions explore my data and reach my goal. +* Note that we have already answered and have the answer at , do not include a question similar to the one above. +* Explore diverse aspects of the data, and ask questions that are relevant to my goal. +* You must ask the right questions to surface anything interesting (trends, anomalies, etc.) +* Make sure these can realistically be answered based on the data schema. +* The insights that your team will extract will be used to generate a report. +* Each question that you produce must be enclosed in content tags. +* Each question should only have one part, that is a single '?' at the end which only require a single answer. +* Do not number the questions. +* You can produce at most {max_questions} questions. + +""" + +FOLLOW_UP_TYPE_TEMPLATE = """ +Hi, I require the services of your team to help me reach my goal. + +{context} + +{goal} + +{schema} + +{question_type} + +{question} + +{answer} + +Instructions: +* Produce a list of follow up questions explore my data and reach my goal. +* Note that we have already answered and have the answer at , do not include a question similar to the one above. +* Explore diverse aspects of the data, and ask questions that are relevant to my goal. +* You must ask the right questions to surface anything interesting (trends, anomalies, etc.) +* Make sure these can realistically be answered based on the data schema. +* The insights that your team will extract will be used to generate a report. +* The question has to adhere to the type of question that is provided in the tag +* The type of question is either descriptive, diagnostic, prescriptive, or predictive. +* Each question that you produce must be enclosed in content tags. +* Each question should only have one part, that is a single '?' at the end which only require a single answer. +* Do not number the questions. +* You can produce at most {max_questions} questions. + +""" + + +FOLLOW_UP_SYSTEM_MESSAGE = """ +You the manager of a data science team whose goal is to help stakeholders within your company extract actionable insights from their data. +You have access to a team of highly skilled data scientists that can answer complex questions about the data. +You call the shots and they do the work. +Your ultimate deliverable is a report that summarizes the findings and makes hypothesis for any trend or anomaly that was found. +""" + +SELECT_A_QUESTION_TEMPLATE = """ +Hi, I require the services of your team to help me reach my goal. + +{context} + +{goal} + +{prev_questions_formatted} + +{followup_questions_formatted} + +Instructions: +* Given a context and a goal, select one follow up question from the above list to explore after prev_question that will help me reach my goal. +* Do not select a question similar to the prev_questions above. +* Output only the index of the question in your response inside tag. +* The output questions id must be 0-indexed. +""" + +SELECT_A_QUESTION_SYSTEM_MESSAGE = """ +You the manager of a data science team whose goal is to help stakeholders within your company extract actionable insights from their data. +You have access to a team of highly skilled data scientists that can answer complex questions about the data. +You call the shots and they do the work. +Your ultimate deliverable is a report that summarizes the findings and makes hypothesis for any trend or anomaly that was found. +""" + + +# 【修改】 这是旧的、较弱的模板。我们同样强化它。 +GENERATE_CODE_TEMPLATE = """ + +Given the goal:\n +{goal} + +Given the schema:\n +{schema} + +Given the data path:\n +{database_path} + +Given the list of predefined functions in insight.tools module and their example usage:\n\n +{function_docs} + +Give me the python code required to answer this question "{question}" and put a comment on top of each variable.\n\n + +--- +**CRITICAL INSTRUCTIONS FOR WRITING PYTHON CODE:** +--- +1. **File Reading**: + - You MUST load the file at `{database_path}` using the appropriate pandas function **based on its file extension**. + - For example: use `pd.read_csv()` for `.csv` files, `pd.read_json()` for `.json` files. + - **If reading a CSV file**: Handle potential `UnicodeDecodeError`. First, try `encoding='utf-8'`. If it fails, try `encoding='gbk'` or `encoding='latin1'`. + +2. **CRITICAL: Date/Time Columns**: + - After loading the data, inspect the schema. If you see any columns that represent dates or times (e.g., 'date', 'timestamp'), you **MUST** convert them to datetime objects using `pd.to_datetime(df['column_name'], errors='coerce')`. + - **DO NOT** use string methods like `.strftime()` before conversion. All date operations **MUST** use the `.dt` accessor *after* conversion. + +3. **Code Quality & Data Types**: + - When creating a `pd.DataFrame` from a dictionary, ensure all arrays/lists have the same length to avoid `ValueError`. + - Be mindful of data types. Do not assign string values to numeric columns or vice-versa, to avoid `FutureWarning`. + +4. **Output Generation**: + - **Chinese Font Setup**: Configure matplotlib for Chinese fonts using: + ```python + import matplotlib + matplotlib.rcParams['font.sans-serif'] = ['SimHei', 'DejaVu Sans', 'Arial Unicode MS'] + matplotlib.rcParams['axes.unicode_minus'] = False + ``` + - Make simple plots and save them as `plot.jpg` file. + - Use standard Python `json` module to save JSON outputs: + ```python + with open('filename.json', 'w', encoding='utf-8') as f: + json.dump(data, f, ensure_ascii=False, indent=2) + ``` + - For every plot, save a stats json file (`stat.json`), and x/y axis json files (`x_axis.json`, `y_axis.json`). + - Each json file must have a "name", "description", and "value" field. +--- + +Make a single code block for starting with ```python +Import json, pandas as pd, and numpy as np at the beginning. +End your code with ```. + +Output code:\n +""" + +# 【修改】 这是旧的、较弱的多文件模板。我们同样强化它。 +GENERATE_CODE_TEMPLATE_MULTI = """ + +Given the goal:\n +{goal} + +Given the schema of the first dataset:\n +{schema} + +Given the data path of the first dataset:\n +{database_path} + +Given the schema of the second dataset:\n +{user_schema} + +Given the data path of the second dataset:\n +{user_database_path} + +Given the list of predefined functions in insight.tools module and their example usage:\n\n +{function_docs} + +Give me the python code required to answer this question "{question}" and put a comment on top of each variable.\n\n + +--- +**CRITICAL INSTRUCTIONS FOR WRITING PYTHON CODE:** +--- +1. **File Reading**: + - You MUST load the files (e.g., `{database_path}`, `{user_database_path}`) using the appropriate pandas function **based on each file's extension**. + - For example: use `pd.read_csv()` for `.csv` files, `pd.read_json()` for `.json` files. + - **If reading a CSV file**: Handle potential `UnicodeDecodeError`. First, try `encoding='utf-8'`. If it fails, try `encoding='gbk'` or `encoding='latin1'`. + +2. **CRITICAL: Date/Time Columns**: + - After loading **EACH** dataframe, inspect its schema. If you see any columns that represent dates or times, you **MUST** convert them to datetime objects using `pd.to_datetime(df['column_name'], errors='coerce')`. + - **DO NOT** use string methods like `.strftime()` before conversion. All date operations **MUST** use the `.dt` accessor *after* conversion. + +3. **Code Quality & Data Types**: + - When creating a `pd.DataFrame` from a dictionary, ensure all arrays/lists have the same length to avoid `ValueError`. + - Be mindful of data types. Do not assign string values to numeric columns or vice-versa, to avoid `FutureWarning`. + +4. **Output Generation**: + - **Chinese Font Setup**: Configure matplotlib for Chinese fonts using: + ```python + import matplotlib + matplotlib.rcParams['font.sans-serif'] = ['SimHei', 'DejaVu Sans', 'Arial Unicode MS'] + matplotlib.rcParams['axes.unicode_minus'] = False + ``` + - You must generate one single simple plot and save it as a `plot.jpg` file. + - Use standard Python `json` module to save JSON outputs: + ```python + with open('filename.json', 'w', encoding='utf-8') as f: + json.dump(data, f, ensure_ascii=False, indent=2) + ``` + - For the plot, save a stats json file (`stat.json`), and x/y axis json files (`x_axis.json`, `y_axis.json`). + - Each json file must have a "name", "description", and "value" field. +--- + +Make a single code block for starting with ```python +Import json, pandas as pd, and numpy as np at the beginning. +Do not produce code blocks for languages other than Python. +End your code with ```. + +Output code:\n +""" + +# 【修改】 这是你定义的单文件强化模板。我们应用所有修复。 +GENERATE_CODE_SINGLE_TEMPLATE = """ +**Goal:** {goal} +**Question:** "{question}" +**Dataset Schema:** +{schema} + +**File Path:** +The dataset is located at `{database_path}`. + +--- +**CRITICAL INSTRUCTIONS FOR WRITING PYTHON CODE:** + +1. **File Reading**: + - You MUST load the file at `{database_path}` using the appropriate pandas function **based on its file extension** (e.g., `pd.read_csv()`, `pd.read_json()`). + - **If reading a CSV file**: You MUST handle encoding errors. Use a `try-except` block. First, try `encoding='utf-8'`. If it fails, try `encoding='gbk'` or `encoding='latin1'`. + - **Example for robust CSV reading**: + ```python + import pandas as pd + file_path = '{database_path}' # This is the path + try: + df = pd.read_csv(file_path, encoding='utf-8') + except UnicodeDecodeError: + df = pd.read_csv(file_path, encoding='gbk') + ``` + +2. **CRITICAL: Date/Time Columns**: + - After loading the data, inspect the schema. If you see any columns that represent dates or times (e.g., 'date', 'timestamp'), you **MUST** convert them to datetime objects using `pd.to_datetime(df['column_name'], errors='coerce')`. + - **DO NOT** attempt to use string methods like `.strftime()` on a column before converting it to datetime. All date operations **MUST** use the `.dt` accessor *after* this conversion. + - Use `insight.tools.safe_datetime_parse()` for robust date parsing if standard methods fail. + +3. **Code Quality & Data Types**: + - When creating a `pd.DataFrame` from a dictionary, ensure all arrays/lists have the same length to avoid `ValueError`. + - Be mindful of data types. Do not assign string values to numeric columns or vice-versa, to avoid `FutureWarning`. + - Use `insight.tools.safe_numeric_convert()` for converting mixed-type columns to numeric. + +4. **CRITICAL: Empty DataFrame Checks**: + - After loading data, ALWAYS check if the DataFrame is empty: `if df.empty: print("Warning: Empty DataFrame")` + - After filtering operations, check if the result is empty before proceeding. + - Before aggregations (mean, sum, etc.), verify there is data to aggregate. + - **Example pattern**: + ```python + filtered_df = df[df['column'] > threshold] + if filtered_df.empty: + print("No data matches the filter criteria") + # Provide sensible defaults or skip the operation + else: + result = filtered_df['value'].mean() + ``` + +5. **Error Handling**: + - Wrap critical operations in try-except blocks. + - For column access, verify the column exists first: `if 'column_name' in df.columns:` + - Handle KeyError, ValueError, and TypeError gracefully. + +6. **Output Generation**: + - **Chinese Font Setup**: Configure matplotlib for Chinese fonts using: + ```python + import matplotlib + matplotlib.rcParams['font.sans-serif'] = ['SimHei', 'DejaVu Sans', 'Arial Unicode MS'] + matplotlib.rcParams['axes.unicode_minus'] = False + ``` + - Use standard Python `json` module to save JSON outputs: + ```python + with open('filename.json', 'w', encoding='utf-8') as f: + json.dump(data, f, ensure_ascii=False, indent=2) + ``` + - Generate one simple plot and save it as a `.jpg` file. + - For the plot, save a statistics summary to `stat.json`. + - Save the X and Y axis data (max 50 points) to `x_axis.json` and `y_axis.json` respectively. + - Each JSON file must have "name", "description", and "value" fields. Ensure content is less than 4500 characters. + +7. **Code Structure**: + - Start your code block with ```python and end it with ```. + - Do not produce any text outside of this single Python code block. + +**Available Tools:** +{function_docs} + +--- +Now, write the Python code to answer the question. + +```python +""" + +# 【修改】 这是你定义的多文件强化模板。我们应用所有修复。 +REINFORCED_MULTI_CODE_PROMPT = """ +**Goal:** {goal} +**Question:** "{question}" + +**Available Datasets:** +--- +**Dataset(s):** +- **Schema:** +{multi_schema} +- **File Path(s):** `{multi_database_path}` +- **Data Profiles (Statistical Summary):** +{multi_profile} +--- + +**CRITICAL INSTRUCTIONS FOR WRITING PYTHON CODE:** + +1. **File Reading**: + - You MUST load all files (from `{multi_database_path}`) using the appropriate pandas function **based on their file extension** (e.g., `pd.read_csv()`, `pd.read_json()`). + - **If reading a CSV file**: You MUST handle encoding errors. Use a `try-except` block for EACH CSV file. First, try `encoding='utf-8'`. If it fails, try `encoding='gbk'` or `encoding='latin1'`. + - **Example for robust CSV reading** (apply this logic to all CSV files you read): + ```python + import pandas as pd + # For main file (if it's a CSV) + try: + df1 = pd.read_csv('/your/csv/path.csv', encoding='utf-8') + except UnicodeDecodeError: + df1 = pd.read_csv('/your/csv/path.csv', encoding='gbk') + ``` + - **For JSON files**: Simply use `df = pd.read_json(file_path)` + +2. **CRITICAL: Date/Time Columns**: + - After loading **EACH** dataframe, inspect its schema. If a column represents dates/times, you **MUST** convert it to a datetime object using `pd.to_datetime(df['column_name'], errors='coerce')`. + - **DO NOT** use string methods like `.strftime()` before conversion. All date operations **MUST** use the `.dt` accessor *after* conversion. + - Use `insight.tools.safe_datetime_parse()` for robust date parsing if standard methods fail. + +3. **Data Merging**: + - You will likely need to merge or join the dataframes to answer the question. Use `pd.merge()` or `pd.concat()` on a common column (e.g., 'user_id', 'date'). + - **IMPORTANT**: Before merging, verify that the join columns exist in both DataFrames and have compatible types. + - After merging, check if the result is empty: `if merged_df.empty: print("Warning: No matching records found")` + +4. **CRITICAL: Empty DataFrame Checks**: + - After loading each file, check if empty: `if df.empty: print(f"Warning: {{filepath}} is empty")` + - After filtering or merging, always check for empty results. + - Before aggregations, verify there is data to aggregate. + +5. **Code Quality & Data Types**: + - When creating a `pd.DataFrame` from a dictionary, ensure all arrays/lists have the same length to avoid `ValueError`. + - Be mindful of data types. Do not assign string values to numeric columns or vice-versa, to avoid `FutureWarning`. + - Use `insight.tools.safe_numeric_convert()` for converting mixed-type columns to numeric. + +6. **Error Handling**: + - Wrap critical operations in try-except blocks. + - Verify columns exist before accessing them. + - Handle KeyError, ValueError, and TypeError gracefully. + +7. **Output Generation**: + - **CRITICAL: Chinese Font Setup**: Before creating any plot, you MUST call `setup()` from `insight.tools` to ensure Chinese characters display correctly in plots. Add this line before any plotting code: `setup()`. + - Use functions from `insight.tools` to save all outputs. + - Generate one plot and save it as a `.jpg`. + - Save statistics to `stat.json`, and axis data (max 100 points) to `x_axis.json` and `y_axis.json`. + - Call `insight.tools.fix_fnames()` at the very end. + +8. **Code Structure**: + - Enclose your entire script in a single ```python code block. No other text. + +**Available Tools:** +{function_docs} + +--- +Now, write the Python code to answer the question by analyzing and combining the provided datasets. + +```python +""" + + +def get_g_eval_prompt(method="basic"): + if method == "basic": + geval_template, system_template = ( + G_EVAL_BASIC_TEMPLATE, + G_EVAL_BASIC_SYSTEM_MESSAGE, + ) + if method == "binary": + geval_template, system_template = ( + G_EVAL_BINARY_TEMPLATE, + G_EVAL_BINARY_SYSTEM_MESSAGE, + ) + + return geval_template, system_template + + +G_EVAL_BASIC_TEMPLATE = """ +Below is an instruction that describes a task. Write a response that appropriately completes the request. + +### Instruction: +Provided Answer: +{answer} + +Ground Truth Answer: +{gt_answer} + +Follow these instructions when writing your response: +* On a scale of 1-10, provide a numerical rating for how close the provided answer is to the ground truth answer, with 10 denoting that the provided answer is the same as ground truth answer. +* Your response should contain only the numerical rating. DONOT include anything else like the provided answer, the ground truth answer, or an explanation of your rating scale in your response. +* Wrap your numerical rating inside tags. +* Check very carefully before answering. +* Follow the output format as shown in the example below: +Example response: +7 + +### Response: + +""" + +G_EVAL_BINARY_SYSTEM_MESSAGE = """You are a high school teacher evaluating student responses to a question. You are tasked with grading the response based on how well it answers the question. You are to provide a numerical rating for how well the provided response matches the ground truth answer.""" + +G_EVAL_BASIC_SYSTEM_MESSAGE = """You are a high school teacher evaluating student responses to a question. You are tasked with grading the response based on how well it answers the question. You are to provide a numerical rating for how well the response answers the question based on the ground truth answer.""" + + +G_EVAL_BINARY_TEMPLATE = """ +Below is an instruction that describes a task. Write a response that appropriately completes the request. + +### Instruction: +Provided answer: +{answer} + +GT Answer: +{gt_answer} + +On a scale of 1-10, provide a numerical rating for how close the provided answer is to the ground truth answer, with 10 denoting that the provided answer is the the same as ground truth answer. The response should contain only the numerical rating.\ + +Check very carefully before answering. + +### Response: +""" + +G_EVAL_SYSTEM_MESSAGE = """You are a a high school teacher evaluating student responses to a question. You are tasked with grading the response based on how well it answers the question. You are to provide a numerical rating for how well the response answers the question based on the ground truth answer.""" + + +G_EVAL_M2M_TEMPLATE = """ +Below is an instruction that describes a task. Write a response that appropriately completes the request. + +### Instruction: +Predicted Answers: +{pred_list} + +Grouth Truth Answers: +{gt_list} + +For each ground truth answer above, provide the index of the most appropriate predicted answer (1-indexed). +Each line must contain a single integer value denoting the id of the matched prediction. +If there is no appropriate prediction for a ground truth answer, write -1. +Check very carefully before answering. + +### Response: +""" + +G_EWAL_M2M_SYSTEM_MESSAGE = "You are a high school teacher evaluating student responses to some questions. Before scoring their answers, you need to first match each ground truth answer with the most appropriate answer provided by the student." + +SUMMARIZE_TEMPLATE = """ +Hi, I require the services of your team to help me reach my goal. + +{context} + +{goal} + +{history} + +Instructions: +* Given a context and a goal, and all the history of pairs from the above list generate the 3 top actionable insights. +* Make sure they don't offer actions and the summary should be more about highlights of the findings +* Output each insight within this tag . +* Each insight should be a meaningful conclusion that can be acquired from the analysis in laymans terms and should be as quantiative as possible and should aggregate the findings. +""" + +SUMMARIZE_SYSTEM_MESSAGE = """ +You the manager of a data science team whose goal is to help stakeholders within your company extract actionable insights from their data. +You have access to a team of highly skilled data scientists that can answer complex questions about the data. +You call the shots and they do the work. +Your ultimate deliverable is a report that summarizes the findings and makes hypothesis for any trend or anomaly that was found. +""" + +# --- 在 prompts/__init__.py 的顶部或合适位置,添加这个新的模板 --- + +# 新增的Prompt模板,用于指导LLM处理多个文件路径 +MULTI_WITH_PATHS_CODE_PROMPT = """ +Your goal is to write a Python script that addresses the following question. + +**Overall Goal:** +{goal} + +**Current Question:** +{question} + +You have access to multiple datasets. Here are their schemas and file paths: + + +{schema} + + +**Important Instructions:** +1. You **MUST** write a Python script. +2. Load the necessary data from the provided CSV file paths. You might need to load and merge data from multiple files. +3. The main dataset is located at `{database_path}`. Other datasets are listed in the schema section. +4. **CRITICAL: Chinese Font Setup**: Before creating any plot, you MUST call `setup()` from `insight.tools` to ensure Chinese characters display correctly in plots. Add this line before any plotting code: `setup()`. +5. Use the `tools` module for plotting and saving results. All outputs **MUST** be saved to files using the provided functions. +6. **Do not** use `plt.show()` or `print()` for final outputs. Save plots as `plot.jpg` and statistical results as JSON files (`stat.json`, `x_axis.json`, `y_axis.json`). +7. The final script should be enclosed in a single ```python code block. + +Available tools from the `tools` module: +{function_docs} + +Begin writing the Python script now. +```python +""" + +REPORT_GENERATION_PROMPT = """ +You are a professional data scientist and report writer. Your task is to generate a comprehensive academic data analysis report in **Markdown** format based on the conversation history provided. + +### Requirements: +1. **Language**: The entire report must be written in **English**. +2. **Content Filtering**: + - Focus ONLY on the successful analysis steps, logic, and results. + - **DO NOT** include any code errors, debugging processes, or failed attempts that appeared in the chat history. +3. **Image Insertion**: + - You must include the visualizations generated during the analysis. + - Insert them using Markdown syntax: `![Figure Description](path/to/figure.png)`. + - Use the figure paths provided in the context. +4. **Formatting**: Use clear headers (H1, H2, H3), bullet points, and tables to organize the content. + +### Report Structure: +1. **Title**: A concise title for the analysis. +2. **Abstract**: (approx. 200 words) Background, dataset summary, methods, and key conclusions. +3. **Introduction**: Background of the task and dataset description. +4. **Methodology**: + - **Dataset**: Statistical description, feature analysis, missing values, etc. + - **Data Processing**: Steps taken to clean and process the data (show processed data examples if available). + - **Modeling/Analysis**: Algorithms or analytical methods used. +5. **Results**: + - Present key findings. + - **Crucial**: Insert the generated figures here to support your analysis. + - Use tables to summarize model metrics or key data statistics. +6. **Conclusion**: (approx. 200 words) Summary of the entire report. + +### Context: +The chat history involves a user interacting with a code interpreter to analyze data. Your job is to synthesize this interaction into a formal report. + +Here's the state data: {state_str} + +Please generate a detailed report based on data and requirements above. +""" \ No newline at end of file diff --git a/workflow_engine/toolkits/insight_tool/dm_components/requirements.txt b/workflow_engine/toolkits/insight_tool/dm_components/requirements.txt new file mode 100644 index 0000000..096ee03 --- /dev/null +++ b/workflow_engine/toolkits/insight_tool/dm_components/requirements.txt @@ -0,0 +1,101 @@ +absl-py==2.3.1 +aiohappyeyeballs==2.6.1 +aiohttp==3.13.2 +aiosignal==1.4.0 +aiosqlite==0.21.0 +annotated-types==0.7.0 +anyio==4.11.0 +async-timeout==4.0.3 +attrs==25.4.0 +certifi==2025.10.5 +charset-normalizer==3.4.4 +click==8.3.1 +contourpy==1.3.2 +cycler==0.12.1 +datasets==4.4.1 +dill==0.4.0 +distro==1.9.0 +evaluate==0.4.6 +exceptiongroup==1.3.0 +fastjsonschema==2.21.2 +filelock==3.20.0 +fonttools==4.60.1 +frozenlist==1.8.0 +fsspec==2025.10.0 +greenlet==3.2.4 +h11==0.16.0 +hf-xet==1.2.0 +httpcore==1.0.9 +httpx==0.28.1 +huggingface_hub==1.1.7 +idna==3.11 +jiter==0.11.0 +joblib==1.5.2 +jsonpatch==1.33 +jsonpointer==3.0.0 +jsonschema==4.25.1 +jsonschema-specifications==2025.9.1 +jupyter_core==5.9.1 +kiwisolver==1.4.9 +langchain==0.3.27 +langchain-core==0.3.79 +langchain-openai==0.3.35 +langchain-text-splitters==0.3.11 +langgraph==0.6.10 +langgraph-checkpoint==2.1.2 +langgraph-checkpoint-sqlite==2.0.11 +langgraph-prebuilt==0.6.4 +langgraph-sdk==0.2.9 +langsmith==0.4.37 +loguru==0.7.3 +matplotlib==3.10.7 +multidict==6.7.0 +multiprocess==0.70.18 +nbformat==5.10.4 +nltk==3.9.2 +numpy==2.2.6 +openai==2.4.0 +orjson==3.11.3 +ormsgpack==1.11.0 +packaging==25.0 +pandas==2.3.3 +pillow==12.0.0 +platformdirs==4.5.0 +propcache==0.4.1 +pyarrow==22.0.0 +pydantic==2.12.2 +pydantic_core==2.41.4 +pyparsing==3.2.5 +python-dateutil==2.9.0.post0 +pytz==2025.2 +PyYAML==6.0.3 +referencing==0.37.0 +regex==2025.9.18 +reportlab==4.4.4 +requests==2.32.5 +requests-toolbelt==1.0.0 +rouge-score==0.1.2 +rpds-py==0.27.1 +scikit-learn==1.7.2 +scipy==1.15.3 +seaborn==0.13.2 +shellingham==1.5.4 +six==1.17.0 +sniffio==1.3.1 +socksio==1.0.0 +SQLAlchemy==2.0.44 +sqlite-vec==0.1.6 +tenacity==9.1.2 +threadpoolctl==3.6.0 +tiktoken==0.12.0 +tqdm==4.67.1 +traitlets==5.14.3 +typer-slim==0.20.0 +typing-inspection==0.4.2 +typing_extensions==4.15.0 +tzdata==2025.2 +urllib3==2.5.0 +wordcloud==1.9.4 +xxhash==3.6.0 +yarl==1.22.0 +zstandard==0.25.0 diff --git a/workflow_engine/toolkits/insight_tool/dm_components/tools.py b/workflow_engine/toolkits/insight_tool/dm_components/tools.py new file mode 100644 index 0000000..1c1dbfc --- /dev/null +++ b/workflow_engine/toolkits/insight_tool/dm_components/tools.py @@ -0,0 +1,565 @@ +import matplotlib +import matplotlib.pyplot as plt +import json, pandas as pd, os +import numpy as np +from typing import Dict, List, Optional, Union, Callable +from copy import deepcopy +from wordcloud import WordCloud +import seaborn as sns +from functools import wraps +import warnings + +def setup(): + """ + Set up Chinese font for matplotlib to ensure proper display of Chinese characters. + Tries multiple font options in order of preference, prioritizing the provided font file. + """ + # Priority 1: Use the provided font file if it exists (highest priority) + primary_font_path = "/mnt/DataFlow/qry/DM/DataManus/src/insight/utils/simhei.ttf" + + if os.path.exists(primary_font_path): + try: + matplotlib.font_manager.fontManager.addfont(primary_font_path) + font_prop = matplotlib.font_manager.FontProperties(fname=primary_font_path) + font_name = font_prop.get_name() + plt.rcParams['font.sans-serif'] = [font_name] + plt.rcParams['font.sans-serif'] + plt.rcParams['axes.unicode_minus'] = False + # Test if the font works by creating a test figure + test_fig = plt.figure(figsize=(1, 1)) + plt.close(test_fig) + return True + except Exception as e: + print(f"Warning: Failed to load primary font from {primary_font_path}: {e}") + + # Priority 2: Try other custom font paths + custom_font_paths = [ + "/home/ubuntu/qiruyi/DM/DataManus/src/insight/utils/simhei.ttf", + "/usr/share/fonts/opentype/noto/NotoSansCJK-Regular.ttc", + "/usr/share/fonts/truetype/droid/DroidSansFallbackFull.ttf", + ] + + for font_path in custom_font_paths: + if os.path.exists(font_path): + try: + matplotlib.font_manager.fontManager.addfont(font_path) + font_prop = matplotlib.font_manager.FontProperties(fname=font_path) + font_name = font_prop.get_name() + plt.rcParams['font.sans-serif'] = [font_name] + plt.rcParams['font.sans-serif'] + plt.rcParams['axes.unicode_minus'] = False + test_fig = plt.figure(figsize=(1, 1)) + plt.close(test_fig) + return True + except Exception as e: + continue + + # Priority 3: Try system fonts + font_options = [ + 'Noto Sans CJK SC', + 'Noto Sans CJK TC', + 'SimHei', + 'Microsoft YaHei', + 'Droid Sans Fallback', + 'WenQuanYi Micro Hei', + 'WenQuanYi Zen Hei', + ] + + for font_name in font_options: + try: + plt.rcParams['font.sans-serif'] = [font_name] + plt.rcParams['font.sans-serif'] + plt.rcParams['axes.unicode_minus'] = False + # Test if the font works + test_fig = plt.figure(figsize=(1, 1)) + plt.close(test_fig) + return True + except: + continue + + # If all else fails, use DejaVu Sans (won't show Chinese but won't crash) + print("Warning: No Chinese font found. Chinese characters may not display correctly.") + plt.rcParams['axes.unicode_minus'] = False + return False + + +# 【新增】一个自定义的 JSON 编码器,用于处理 +class CustomJSONEncoder(json.JSONEncoder): + """ + 自定义 JSON 编码器,用于处理标准库无法序列化的类型 + (例如 numpy.int64, pandas.Timestamp) + """ + def default(self, obj): + if isinstance(obj, np.integer): + return int(obj) + if isinstance(obj, np.floating): + return float(obj) + if isinstance(obj, np.ndarray): + return obj.tolist() + if isinstance(obj, pd.Timestamp): + return obj.isoformat() # 将 Timestamp 转换为 ISO 格式的字符串 + if isinstance(obj, pd.Series): + return obj.tolist() + # 处理其他无法序列化的类型(例如,如果LLM返回了集合) + if isinstance(obj, set): + return list(obj) + + # 让基类来处理它不知道的类型 + return super(CustomJSONEncoder, self).default(obj) + + +def plot_countplot(df: pd.DataFrame, plot_column: str, plot_title: str) -> None: + """ + Takes a DataFrame as input, performs a group by on plot_column and saves a count plot. + The plot is then saved into plot.jpg + + Parameters: + df: DataFrame containing the data. + plot_column: Column name to plot. + plot_title: Title of the plot. + + Example usage: + >>> data = pd.DataFrame({ + ... 'category': ['A', 'B', 'A', 'B', 'A'], + ... }) + >>> plot_column = 'category' + >>> plot_title = 'Category count plot' + >>> plot_countplot(data, plot_column) + """ + # make countplot with plot title using seaborn + sns.countplot(data=df, x=plot_column, hue=plot_column).set_title(plot_title) + plt.savefig("plot.jpg") + plt.close() + + +def plot_lines( + df: pd.DataFrame, x_column: str, plot_columns: List[str], plot_title: str +) -> None: + """ + Takes a DataFrame as input, and makes a line plot of the data in plot_columns using seaborn. + The plot is then saved into plot.jpg + + Parameters: + df: DataFrame containing the data. + x_column: Column name with the x-axis data. + plot_columns: Columns with y-axis data to plot. + plot_title: Title of the plot. + + Example usage: + >>> data = pd.DataFrame({ + ... 'time': [10, 20, 30, 40, 50], + ... 'A': [1, 2, 3, 4, 5], + ... 'B': [5, 4, 3, 2, 1], + ... }) + >>> x_column = 'time' + >>> plot_columns = ['A', 'B'] + >>> plot_title = 'Line plot of A and B' + >>> plot_lines(data, x_column, plot_columns) + """ + # make lineplot with plot title using seaborn + for plot_column in plot_columns: + df[x_column] = df[x_column].astype(str) + sns.lineplot(data=df, x=x_column, y=plot_column, label=plot_column) + # set plot title + plt.title(plot_title) + plt.savefig("plot.jpg") + plt.close() + + +def save_json(data_dict: Dict, ftype: str) -> None: + """ + Saves data_dict to a json file. + + Parameters: + data_dict: Dictionary containing data to be saved. + ftype: One of "stat", "x_axis", or "y_axis". + + Example usage: + >>> ftype = "x_axis" + >>> data_dict = { + ... 'name': "X-axis", + ... 'description': "Different x-axis values for the plot.", + ... 'value': ["apple", "orange", "banana", "grapes"], + ... } + >>> save_json(data_dict, ftype) + """ + + def validate_dict(parent): + """ + Goes through all the keys in the dictionary and converts the keys are strings. + If the values are dictionaries, it recursively fixes them as well. + """ + duplicate = deepcopy(parent) + for k, v in duplicate.items(): + if isinstance(v, dict): + parent[k] = validate_dict(v) + if not isinstance(k, str): + parent[str(k)] = parent.pop(k) + return parent + + ftype = ftype.lower() + if "stat" in ftype: + ftype = "stat" + elif "x_axis" in ftype: + ftype = "x_axis" + elif "y_axis" in ftype: + ftype = "y_axis" + + assert all(isinstance(k, str) for k in data_dict.keys()) + # perform a sanity check that all the keys are strings + validate_dict(data_dict) + # recursively check if all the keys are strings + + # filename depends on the number of plots already in the folder + ftype_count = len([f for f in os.listdir() if f.startswith(f"{ftype}_")]) + fname = f"{ftype}.json" + with open(fname, "w", encoding='utf-8') as f: + # 【修改】 使用我们自定义的编码器,并确保中文正常显示 + json.dump(data_dict, f, indent=4, cls=CustomJSONEncoder, ensure_ascii=False) + + +def generate_wordcloud( + df: pd.DataFrame, group_by_column: str, plot_column: str +) -> None: + """ + Generates a wordcloud by performing a groupby on df and using the plot_column. + The plot is then saved into plot.jpg + + Parameters: + df: DataFrame containing the data. + group_by_column: Column name to group by. + plot_column: Column name to plot. + + Example usage: + >>> data = pd.DataFrame({ + ... 'category': ['A', 'B', 'A', 'B', 'A'], + ... 'description': ['apple', 'orange', 'banana', 'grapes', 'kiwi'], + ... }) + >>> group_by_column = 'category' + >>> plot_column = 'description' + >>> generate_wordcloud(data, group_by_column, plot_column) + """ + # check if data in plot_column is a string + assert isinstance(df[plot_column].iloc[0], str) + + # group by the column and aggregate the data + grouped_data = df.groupby(group_by_column)[plot_column].apply(list).reset_index() + # generate a wordcloud for each group + plt.figure(figsize=(20, 10)) + for i, row in grouped_data.iterrows(): + wc = WordCloud(width=800, height=400).generate(" ".join(row[plot_column])) + plt.subplot(1, len(grouped_data), i + 1) + plt.imshow(wc, interpolation="bilinear") + plt.title(row[group_by_column]) + plt.axis("off") + plt.savefig("plot.jpg") + plt.close() + + +def linear_regression(X, y): + """ + Fits a linear regression model on the data and returns the model. + + Parameters: + X: Features to fit the model on. + y: Target variable to predict. + + Example usage: + >>> X = np.array([1, 2, 3, 4, 5]).reshape(-1, 1) + >>> y = np.array([2, 4, 6, 8, 10]) + >>> model = linear_regression(X, y) + """ + from sklearn.linear_model import LinearRegression + + model = LinearRegression() + model.fit(X, y) + return model + + +def fix_fnames(): + """ + Renames all the plot and stat files in the current directory to plot_.jpg. + """ + for i, f in enumerate([f for f in os.listdir() if f.startswith("plot")]): + if f.startswith("plot"): + os.rename(f, f"plot.jpg") + + for i, f in enumerate([f for f in os.listdir() if f.startswith("stat")]): + if f.startswith("stat"): + os.rename(f, f"stat.json") + + for i, f in enumerate([f for f in os.listdir() if f.startswith("x_axis")]): + if f.startswith("x_axis"): + os.rename(f, f"x_axis.json") + + for i, f in enumerate([f for f in os.listdir() if f.startswith("y_axis")]): + if f.startswith("y_axis"): + os.rename(f, f"y_axis.json") + + +# ============================================================================= +# Robust Utility Functions - Added for improved code execution reliability +# ============================================================================= + +def safe_datetime_parse( + series: pd.Series, + formats: Optional[List[str]] = None, + errors: str = 'coerce' +) -> pd.Series: + """ + Safely parse a pandas Series to datetime, trying multiple formats automatically. + + This function attempts to convert a Series to datetime using various common formats. + It handles mixed formats and returns NaT for unparseable values when errors='coerce'. + + Parameters: + ----------- + series : pd.Series + The Series containing date/time values to parse. + formats : List[str], optional + List of datetime format strings to try. If None, uses common formats. + errors : str, default 'coerce' + How to handle parsing errors: 'coerce' (return NaT), 'raise', or 'ignore'. + + Returns: + -------- + pd.Series + A Series with datetime64[ns] dtype. + + Example usage: + >>> dates = pd.Series(['2023-01-15', '15/01/2023', '01-15-2023', 'invalid']) + >>> parsed = safe_datetime_parse(dates) + >>> print(parsed.dtype) # datetime64[ns] + """ + if series.empty: + return pd.Series(dtype='datetime64[ns]') + + # Default formats to try + default_formats = [ + '%Y-%m-%d', # 2023-01-15 + '%Y/%m/%d', # 2023/01/15 + '%d-%m-%Y', # 15-01-2023 + '%d/%m/%Y', # 15/01/2023 + '%m-%d-%Y', # 01-15-2023 + '%m/%d/%Y', # 01/15/2023 + '%Y-%m-%d %H:%M:%S', # 2023-01-15 14:30:00 + '%Y/%m/%d %H:%M:%S', # 2023/01/15 14:30:00 + '%d-%m-%Y %H:%M:%S', # 15-01-2023 14:30:00 + '%Y%m%d', # 20230115 + '%Y-%m-%dT%H:%M:%S', # ISO format + '%Y-%m-%dT%H:%M:%SZ', # ISO format with Z + ] + + formats_to_try = formats if formats else default_formats + + # First, try pandas' intelligent parser + try: + result = pd.to_datetime(series, errors=errors, infer_datetime_format=True) + if result.notna().any(): + return result + except Exception: + pass + + # Try each format explicitly + for fmt in formats_to_try: + try: + result = pd.to_datetime(series, format=fmt, errors='coerce') + # If more than 50% parsed successfully, use this format + if result.notna().sum() > len(series) * 0.5: + return result + except Exception: + continue + + # Fallback: try generic parsing with coerce + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + return pd.to_datetime(series, errors=errors) + + +def safe_numeric_convert( + series: pd.Series, + downcast: Optional[str] = None, + fill_value: Optional[Union[int, float]] = None +) -> pd.Series: + """ + Safely convert a pandas Series to numeric type, handling mixed types gracefully. + + This function handles common issues like: + - Strings mixed with numbers (e.g., '1,234' or '50%') + - Currency symbols (e.g., '$100', '€50') + - Whitespace and special characters + - Empty strings and None values + + Parameters: + ----------- + series : pd.Series + The Series to convert to numeric. + downcast : str, optional + Downcast to 'integer', 'signed', 'unsigned', or 'float' for memory efficiency. + fill_value : int or float, optional + Value to use for non-convertible entries. If None, uses NaN. + + Returns: + -------- + pd.Series + A Series with numeric dtype. + + Example usage: + >>> mixed = pd.Series(['100', '1,234', '$50.5', '75%', 'N/A', None]) + >>> numeric = safe_numeric_convert(mixed) + >>> print(numeric.tolist()) # [100.0, 1234.0, 50.5, 75.0, nan, nan] + """ + if series.empty: + return pd.Series(dtype='float64') + + # Work on a copy + result = series.copy() + + # Convert to string for consistent processing + result = result.astype(str) + + # Remove common non-numeric characters + # Currency symbols + result = result.str.replace(r'[$€£¥₹]', '', regex=True) + # Thousands separators (comma) + result = result.str.replace(',', '', regex=False) + # Percentage signs (but keep the number) + result = result.str.replace('%', '', regex=False) + # Whitespace + result = result.str.strip() + # Common null representations + result = result.replace(['', 'nan', 'NaN', 'null', 'NULL', 'None', 'N/A', 'n/a', '-', '--'], np.nan) + + # Convert to numeric + result = pd.to_numeric(result, errors='coerce', downcast=downcast) + + # Fill NaN values if specified + if fill_value is not None: + result = result.fillna(fill_value) + + return result + + +def handle_empty_dataframe(operation_name: str = "operation"): + """ + Decorator to handle empty DataFrame scenarios gracefully. + + This decorator wraps functions that operate on DataFrames and provides + informative error messages when the DataFrame is empty or has no valid data. + + Parameters: + ----------- + operation_name : str + Name of the operation for error messages. + + Returns: + -------- + Callable + Decorated function with empty DataFrame handling. + + Example usage: + >>> @handle_empty_dataframe("aggregation") + ... def compute_stats(df: pd.DataFrame, column: str) -> Dict: + ... return { + ... 'mean': df[column].mean(), + ... 'sum': df[column].sum() + ... } + >>> + >>> result = compute_stats(pd.DataFrame(), 'value') # Returns empty dict with warning + """ + def decorator(func: Callable) -> Callable: + @wraps(func) + def wrapper(*args, **kwargs): + # Find DataFrame in arguments + df = None + for arg in args: + if isinstance(arg, pd.DataFrame): + df = arg + break + if df is None: + for key, value in kwargs.items(): + if isinstance(value, pd.DataFrame): + df = value + break + + # Check if DataFrame is empty + if df is not None and df.empty: + warnings.warn( + f"Empty DataFrame provided to {operation_name}. " + f"Returning default empty result.", + UserWarning + ) + # Try to return a sensible default based on return type hints + return_type = func.__annotations__.get('return', None) + if return_type == dict or return_type == Dict: + return {} + elif return_type == list or return_type == List: + return [] + elif return_type == pd.DataFrame: + return pd.DataFrame() + elif return_type == pd.Series: + return pd.Series(dtype='float64') + else: + return None + + # Check if DataFrame has all NaN values in relevant columns + if df is not None and len(df) > 0: + # Check if all values are NaN + if df.isna().all().all(): + warnings.warn( + f"DataFrame contains only NaN values for {operation_name}. " + f"Results may be unreliable.", + UserWarning + ) + + return func(*args, **kwargs) + return wrapper + return decorator + + +def validate_columns(df: pd.DataFrame, required_columns: List[str], operation_name: str = "operation") -> bool: + """ + Validate that required columns exist in a DataFrame. + + Parameters: + ----------- + df : pd.DataFrame + The DataFrame to validate. + required_columns : List[str] + List of column names that must be present. + operation_name : str + Name of the operation for error messages. + + Returns: + -------- + bool + True if all columns exist. + + Raises: + ------- + KeyError + If any required column is missing, with helpful suggestions. + + Example usage: + >>> df = pd.DataFrame({'A': [1, 2], 'B': [3, 4]}) + >>> validate_columns(df, ['A', 'C'], 'analysis') # Raises KeyError with suggestions + """ + missing = [col for col in required_columns if col not in df.columns] + + if missing: + available = df.columns.tolist() + # Find similar column names for suggestions + suggestions = {} + for m in missing: + similar = [c for c in available if m.lower() in c.lower() or c.lower() in m.lower()] + if similar: + suggestions[m] = similar + + error_msg = f"Missing columns for {operation_name}: {missing}\n" + error_msg += f"Available columns: {available}\n" + if suggestions: + error_msg += "Possible matches:\n" + for m, s in suggestions.items(): + error_msg += f" '{m}' -> {s}\n" + + raise KeyError(error_msg) + + return True + \ No newline at end of file diff --git a/workflow_engine/toolkits/insight_tool/dm_components/utils/__init__.py b/workflow_engine/toolkits/insight_tool/dm_components/utils/__init__.py new file mode 100644 index 0000000..d3f5a12 --- /dev/null +++ b/workflow_engine/toolkits/insight_tool/dm_components/utils/__init__.py @@ -0,0 +1 @@ + diff --git a/workflow_engine/toolkits/insight_tool/dm_components/utils/agent_utils.py b/workflow_engine/toolkits/insight_tool/dm_components/utils/agent_utils.py new file mode 100644 index 0000000..cb1c4e7 --- /dev/null +++ b/workflow_engine/toolkits/insight_tool/dm_components/utils/agent_utils.py @@ -0,0 +1,1629 @@ +import os +import numpy as np, pandas as pd, re, json, os, shutil, inspect +import nbformat +import contextlib +import io +import os +import re +import subprocess +import traceback +import sys + +from io import StringIO +from pathlib import Path +from dm_components import prompts +from copy import deepcopy +from typing import Dict, Any, List +from dm_components import tools +from dateutil.parser import parse +from langchain.schema import HumanMessage, SystemMessage +from warnings import warn +from functools import partial +from langchain.prompts import PromptTemplate +from openai import OpenAI +from pathlib import Path +import httpx + +# from langchain_community.chat_models import ChatOpenAI + + +OPENAI_API_KEY = os.getenv("QDF_API_KEY") +OPENAI_API_URL = os.getenv("QDF_API_URL") + +# Global config for API credentials (set by InsightEntry) +_global_api_key = None +_global_base_url = None + +def set_global_api_config(api_key: str = None, base_url: str = None): + """Set global API configuration for all get_chat_model calls.""" + global _global_api_key, _global_base_url + if api_key: + _global_api_key = api_key + if base_url: + _global_base_url = base_url +JSON_MAX_TOKENS = 40000 +JSON_MAX_CHARS = JSON_MAX_TOKENS * 4 # 4 chars per token (roughly) + + +def interpret_solution( + solution: dict, + n_retries: int, + model_name: str, + prompt_method, + schema, + temperature: 0, +) -> str: + """ + Produce insights for a task based on a solution output by a model + + Parameters: + ----------- + solution: dict + The output of the code generation function + answer_template: dict + A template for the answer that the human should provide. This template should contain a "results" tag + that contains a list of expected results in the form of dictionaries. Each dictionary should contain + the following keys: "name", "description", and "value". The model will be asked to fill in the values. + model: str + The name of the model to use (default: gpt-4) + n_retries: int + The number of times to retry the interpretation if it fails + + Returns: + -------- + solution_path: str + The path to the input solution file, which has been updated with the interpretation + + """ + prompt_template = prompts.get_interpret_prompt(method=prompt_method) + # create prompt + prompt = PromptTemplate.from_template(prompt_template) + + insight_prompt = _build_insight_prompt(solution) + + # instantiate llm model + llm = get_chat_model(model_name, temperature) + + # Get human readable answer + out, _ = retry_on_parsing_error( + llm, + prompt.format( + goal=solution["goal"], + question=solution["question"], + message=solution["code_output"], + insights=insight_prompt, + schema=schema, + ), + parser=_parse_human_readable_insight, + n_retries=n_retries, + ) + solution["interpretation"] = out + return solution + + +def extract_python_code_blocks(text): + """ + Extract and merge Python code blocks from a given text string. + + The function identifies code blocks that start with ``` or ```python and end with ```. + After extracting the code blocks, it removes the start and end delimiters (```, ```python), + and merges the code blocks together into a single string. + + Parameters + ---------- + text : str + The input string from which Python code blocks need to be extracted. + + Returns + ------- + str + A string containing the merged Python code blocks stripped of leading and trailing whitespaces. + Code blocks are separated by a newline character. + + """ + + code_blocks = re.findall(r"```(?:python)?(.*?)```", text, re.DOTALL) + return "\n".join(block.strip() for block in code_blocks) + + +class PythonREPL: + """ + Simulates a standalone Python REPL. + + TODO add a way to pass a random seed to the REPL + """ + + def __init__(self): + self.history = [] + + def run(self, command: str, workdir: str = None) -> str: + """Run command with own globals/locals and returns anything printed.""" + + if workdir is not None: + old_cwd = Path.cwd() + os.chdir(workdir) + + # 1. 定义一个所有代码执行前都需要的 "Header" + # 这可以防止 LLM 忘记导入常用库 + REPL_HEADER = """ +import pandas as pd +import numpy as np +import os +import sys +import json +import warnings +warnings.filterwarnings('ignore') + +# Try to setup Chinese font for matplotlib (optional) +try: + import matplotlib + matplotlib.rcParams['font.sans-serif'] = ['SimHei', 'DejaVu Sans', 'Arial Unicode MS', 'sans-serif'] + matplotlib.rcParams['axes.unicode_minus'] = False +except Exception: + pass +""" + + # 2. Dynamically add project root to sys.path + # Path calculation: agent_utils.py is at insight_tool/dm_components/utils/agent_utils.py + # Need to go up 5 levels to reach Open-NotebookLM root + src_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..', '..', '..', '..')) + + # 3. 组合 Header、路径修复 和 LLM 生成的代码 + full_command = f""" +{REPL_HEADER} +if '{src_dir}' not in sys.path: + sys.path.insert(0, '{src_dir}') + +# --- LLM Generated Code Below --- +{command} +""" + + buffer = io.StringIO() + with contextlib.redirect_stdout(buffer): + try: + # Dynamically add project root to sys.path + # This allows the REPL environment to find 'insight' package + src_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..', '..', '..', '..')) + exec(f"import sys\nif '{src_dir}' not in sys.path:\n sys.path.insert(0, '{src_dir}')", locals()) + + exec(full_command, locals()) + valid = True + retry_message = "" + self.history.append((command, workdir)) + except Exception as e: + valid = False + retry_message = traceback.format_exc() + "\n" + str(e) + finally: + if workdir is not None: + os.chdir(old_cwd) + output = buffer.getvalue() + + return output, valid, retry_message + + def clone(self): + """Clone the REPL from history. + + it is not possible to clone the REPL from the globals/locals because they + may contain references to objects that cannot be pickled e.g. python modules. + Instead, we clone the REPL by replaying the history. + + Python REPL类中的clone()方法是为了复制当前REPL的状态。 + 由于REPL类内部维护了一个历史记录(history),该历史记录包含了之前执行过的所有命令以及执行时的工作目录。 + clone()方法通过创建一个新的REPL实例,并深拷贝历史记录,然后重新执行历史记录中的所有命令来复制当前REPL的状态。 + """ + new_repl = PythonREPL() + # deepcopy of history + new_repl.history = deepcopy(self.history) + + for command, workdir in self.history: + new_repl.run(command, workdir=workdir) + + return new_repl + + +def _execute_command(args): + """Execute a command and return the stdout, stderr and return code + + Parameters + ---------- + args : list of str or str, directly passed to subprocess.Popen + + Returns + ------- + stdout : str + stdout of the command + stderr : str + stderr of the command + returncode : int + return code of the command + """ + try: + process = subprocess.Popen( + args, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=False + ) + except FileNotFoundError as e: + return "", str(e), 1 + + stdout, stderr = process.communicate() + + # decode bytes object to string + stdout = stdout.decode("utf-8") + stderr = stderr.decode("utf-8") + + return stdout, stderr, process.returncode + + +def pip_install(requirements): + """Install a list of requirements using pip + + Parameters + ---------- + requirements : list of str + List of requirements to install, should be in the format of pip install + + Returns + ------- + stdout : str + stdout of the command + valid : bool + True if the installation was successful, False otherwise + retry_message : str + Error message if the installation was not successful + """ + + if isinstance(requirements, str): + requirements = [requirements] + retry_messages = [] + stdouts = [] + for req in requirements: + stdout, stderr, code = _execute_command(["pip", "install", req]) + stdouts.append(stdout) + if stdout.strip().startswith("Usage:"): + retry_messages.append( + f"Seems like there is an error on the pip commandline, it just prints usage. stderr:\n{stderr}. stdout:\n{stdout[-1000:]}" + ) + if code != 0: + retry_messages.append( + f"Error code {code} when installing {req}. stderr:\n{stderr}. stdout:\n{stdout[-1000:]}" + ) + + valid = len(retry_messages) == 0 + retry_message = "\n".join(retry_messages) + return "\n".join(stdouts), valid, retry_message + + +# ============================================================================= +# Code generation +# ============================================================================= +def _code_parser(code, output_folder): + """ + A parser that is used to parse the code generated by the LLM + and determine whether it is acceptable or not + + """ + # Clean output folder + output_folder = Path(output_folder) + if output_folder.exists(): + shutil.rmtree(output_folder) + output_folder.mkdir(parents=True) + + # Extract code blocks from the input code (might contain other text) + code_block = extract_python_code_blocks(code) + if len(code_block) == 0: + # No code blocks detected so input is likely already raw code + code_block = code + + # Run code and report any errors + output, valid, retry_message = PythonREPL().run(code_block, workdir=output_folder) + if not valid: + return "", valid, retry_message + + # Validate output files + json_files = [f.name for f in output_folder.glob("*.json")] + plot_files = [f.name for f in output_folder.glob("*.jpg")] + + try: + # assert that there is x_axis.json, y_axis.json, and stat.json in json_files + assert "x_axis.json" in json_files + assert "y_axis.json" in json_files + assert "stat.json" in json_files + assert len(json_files) == 3 + + # Check that the total length of all json files is not too long + json_lengths = [len(open(output_folder / f).read()) for f in json_files] + total_json_chars = sum(json_lengths) + if total_json_chars > JSON_MAX_CHARS: + return ( + "", + False, + f"Error: The total length of your json files cannot exceed {JSON_MAX_CHARS} characters. Here is the total length of each json file: {', '.join(f'{f} ({l} characters)' for f, l in zip(json_files, json_lengths))}.", + ) + + assert len(plot_files) == 1 and "plot.jpg" in plot_files + + except: + return ( + "", + False, + f"Error: Your code did not generate the expected output files. Expected x_axis.json, y_axis.json, and stat.json files.", + ) + + # All checks have passed! + return output, True, "" + + +def retry_on_parsing_error( + llm, + initial_prompt, + parser, + n_retries, + exception_on_max_retries=True, +): + """ + Try querying a LLM until it returns a valid value with a maximum number of retries. + + Parameters: + ----------- + llm : callable + A langchain LLM model. + initial_prompt : str + The initial prompt to send to the LLM. + parser : callable + A function taking a message and returning a tuple (value, valid, retry_message), + where retries will be made until valid is True. + n_retries : int + The maximum number of retries. + exception_on_max_retries : bool + If True, raise an exception if the maximum number of retries is reached. + Otherwise, returns "". + + Returns: + -------- + value : str + The value returned by the LLM. + completions : list + The attempts made by the LLM. + + """ + retry_template = prompts.RETRY_TEMPLATE + prompt = initial_prompt + + completions = [] + for i in range(n_retries + 1): # Add one since the initial prompt is not a retry + # Try to get a valid completion + completions.append(llm(prompt)) + output, valid, retry_message = parser(completions[-1]) + + # If parser execution succeeds return the output + if valid: + return output, completions + + # If parser execution fails, produce a new prompt that includes the previous output and the error message + warn( + f"Retry {i+1}/{n_retries} - Query failed with error: {retry_message}", + RuntimeWarning, + ) + prompt = retry_template.format( + initial_prompt=initial_prompt, + prev_output=completions[-1], + error=retry_message, + ) + + if exception_on_max_retries: + return f"Could not parse a valid value after {n_retries} retries.", [ + "```python\nimport pandas as pd```", + "```python\nimport numpy as np```", + ] + else: + return retry_message, completions + + +def _extract_top_values(values, k=5, max_str_len=100): + """ + Extracts the top k values from a pandas series + + Parameters + ---------- + values : pandas.Series + Series to extract top values from + k : int, optional + Number of top values to extract, by default 5 + max_str_len : int, optional + Maximum length of string values (will be truncated), by default 100 + + """ + top = values.value_counts().iloc[:k].index.values.tolist() + top = [x if not isinstance(x, str) else x[:max_str_len] for x in top] + return top + + +def get_schema(df): + """ + Extracts schema from a pandas dataframe + + Parameters + ---------- + df : pandas.DataFrame + Dataframe to extract schema from + + Returns + ------- + list of dict + Schema for each column in the dataframe + + """ + schema = [] + + for col in df.columns: + info = { + "name": col, + "type": df[col].dtype, + "missing_count": df[col].isna().sum(), + "unique_count": df[col].unique().shape[0], + } + + # If the column is numeric, extract some stats + if np.issubdtype(df[col].dtype, np.number): + info["min"] = df[col].min() + info["max"] = df[col].max() + info["mean"] = df[col].mean() + info["std"] = df[col].std() + # If the column is a date, extract the min and max + elif _is_date(df[col].iloc[0]): + info["min"] = df[col].dropna().min() + info["max"] = df[col].dropna().max() + # If the column is something else, extract the top values + else: + info["top5_unique_values"] = _extract_top_values(df[col]) + + schema.append(info) + + return schema + + +def schema_to_str(schema) -> str: + """Converts the list of dict to a promptable string. + + Parameters + ---------- + schema : list of dict + Schema for each column in the dataframe + + Returns + ------- + str + String representation of the schema + """ + schema_str = "" + for col in schema: + schema_str += f"Column: {col['name']} ({col['type']})\n" + for key, val in col.items(): + if key in ["name", "type"]: + continue + schema_str += f" {key}: {val}\n" + return schema_str + + +def _is_date(string): + """ + Checks if a string is a date + + Parameters + ---------- + string : str + String to check + + Returns + ------- + bool + True if the string is a date, False otherwise + + """ + try: + parse(str(string)) + return True + except ValueError: + return False + + +def schema_to_str(schema) -> str: + """Converts the list of dict to a promptable string. + + Parameters + ---------- + schema : list of dict + Schema for each column in the dataframe + + Returns + ------- + str + String representation of the schema + """ + schema_str = "" + for col in schema: + schema_str += f"Column: {col['name']} ({col['type']})\n" + for key, val in col.items(): + if key in ["name", "type"]: + continue + schema_str += f" {key}: {val}\n" + return schema_str + + +def convert_messages_to_text(messages): + """ + Convert a list of messages to a string + + Parameters + ---------- + messages : list + List of messages to convert + + Returns + ------- + str + String representation of the messages + + """ + + text_list = [ + # 如果 m 是一个有 .type 属性的对象,并且 type 是 system 或 agent + f"{m.type.capitalize()}:\n{m.content}" + if hasattr(m, 'type') and m.type in ["system", "agent"] + # 如果 m 是一个有 .content 属性但没有 .type 的对象 (例如 HumanMessage) + else m.content if hasattr(m, 'content') + # 如果 m 本身就是个字符串 + else str(m) + for m in messages + ] + + return "\n".join(text_list) + + +def chat_and_retry(chat, messages, n_retry, parser): + """ + Retry querying the chat models until it returns a valid value with a maximum number of retries. + + Parameters: + ----------- + chat: callable + A langchain chat object taking a list of messages and returning the llm's message. + messages: list + The list of messages so far. + n_retry: int + The maximum number of retries. + parser: callable + A function taking a message and returning a tuple (value, valid, retry_message) + where value is the parsed value, valid is a boolean indicating if the value is valid and retry_message + is a message to display to the user if the value is not valid. + + Returns: + -------- + value: object + The parsed value. + + Raises: + ------- + ValueError: if the value could not be parsed after n_retry retries. + + """ + for i in range(n_retry): + messages = convert_messages_to_text(messages) + answer = chat(messages) + value, valid, retry_message = parser(answer) + + if valid: + return value + + msg = f"Query failed. Retrying {i+1}/{n_retry}.\n[LLM]:\n{answer}\n[User]:\n{retry_message}" + warn(msg, RuntimeWarning) + messages += answer + messages += retry_message + + return { + "answer": "Error occured", + "justification": f"Could not parse a valid value after {n_retry} retries.", + } + + +def extract_html_tags(text, keys): + """Extract the content within HTML tags for a list of keys. + + Parameters + ---------- + text : str + The input string containing the HTML tags. + keys : list of str + The HTML tags to extract the content from. + + Returns + ------- + dict + A dictionary mapping each key to a list of subset in `text` that match the key. + + Notes + ----- + All text and keys will be converted to lowercase before matching. + + """ + content_dict = {} + keys = set(keys) + for key in keys: + pattern = f"<{key}>(.*?)" + matches = re.findall(pattern, text, re.DOTALL) + # print(matches) + if matches: + content_dict[key] = [match.strip() for match in matches] + return content_dict + + +def _parse_human_readable_insight(output): + """ + A parser that makes sure that the human readable insight is produced in the correct format + + """ + try: + answer = extract_html_tags(output, ["answer"]) + if "answer" not in answer: + return ( + "", + False, + f"Error: you did not generate answers within the tags", + ) + answer = answer["answer"][0] + except ValueError as e: + return ( + "", + False, + f"The following error occured while extracting the value for the tag: {str(e)}", + ) + + try: + justification = extract_html_tags(output, ["justification"]) + if "justification" not in justification: + return ( + "", + False, + f"Error: you did not generate answers within the tags", + ) + justification = justification["justification"][0] + except ValueError as e: + return ( + "", + False, + f"The following error occured while extracting the value for the tag: {str(e)}", + ) + try: + insight = extract_html_tags(output, ["insight"]) + if "insight" not in insight: + return ( + "", + False, + f"Error: you did not generate answers within the tags", + ) + insight = insight["insight"][0] + except ValueError as e: + return ( + "", + False, + f"The following error occured while extracting the value for the tag: {str(e)}", + ) + + return ( + {"answer": answer, "justification": justification, "insight": insight}, + True, + "", + ) + + +def _build_insight_prompt(solution) -> str: + """ + Gather all plots and statistics produced by the model and format then nicely into text + + """ + insight_prompt = "" + for i, var in enumerate(solution["vars"]): + insight_prompt += f"" + insight_prompt += f" " + insight_prompt += f" {var['stat'].get('name', 'n/a')}" + insight_prompt += f" {var['stat'].get('description', 'n/a')}" + stat_val = var["stat"].get("value", "n/a") + stat_val = stat_val[:50] if isinstance(stat_val, list) else stat_val + insight_prompt += f" {stat_val}" + insight_prompt += f" " + insight_prompt += f" " + insight_prompt += f" " + insight_prompt += f" {var['x_axis'].get('description', 'n/a')}" + x_val = var["x_axis"].get("value", "n/a") + x_val = x_val[:50] if isinstance(x_val, list) else x_val + insight_prompt += f" {x_val}" + insight_prompt += f" " + insight_prompt += f" " + insight_prompt += f" {var['y_axis'].get('description', 'n/a')}" + y_val = var["y_axis"].get("value", "n/a") + y_val = y_val[:50] if isinstance(y_val, list) else y_val + insight_prompt += f" {y_val}" + insight_prompt += f" " + insight_prompt += f" " + insight_prompt += f"" + return insight_prompt + + +def get_insights( + context, + goal, + messages=[], + schema=None, + max_questions=3, + model_name="gpt-3.5-turbo-0125", + temperature=0, +): + + chat = get_chat_model(model_name, temperature) + + prompt = prompts.GET_INSIGHTS_TEMPLATE + messages = [ + SystemMessage(content=prompts.GET_INSIGHTS_SYSTEM_MESSAGE), + HumanMessage( + content=prompt.format( + context=context, goal=goal, schema=schema, max_questions=max_questions + ) + ), + ] + + def _validate_tasks(out): + isights = extract_html_tags(out, ["insight"]) + + # Check that there are insights generated + if "insight" not in isights: + return ( + out, + False, + f"Error: you did not generate insights within the tags.", + ) + isights = isights["insight"] + print("The insights are:", isights) + print("Length:", len(isights), " Max:", max_questions) + return (isights, out), True, "" + + insights, message = chat_and_retry( + chat, messages, n_retry=3, parser=_validate_tasks + ) + + return insights + + +def get_questions( + prompt_method, + context, + goal, + messages=[], + schema=None, + max_questions=10, + model_name="gpt-3.5-turbo-0125", + temperature=0, +): + if prompt_method is None: + prompt_method = "basic" + + prompt, system = prompts.get_question_prompt(method=prompt_method) + + chat = get_chat_model(model_name, temperature) + + messages = [ + SystemMessage(content=system), + HumanMessage( + content=prompt.format( + context=context, goal=goal, schema=schema, max_questions=max_questions + ) + ), + ] + + def _validate_tasks(out): + questions = extract_html_tags(out, ["question"]) + if "question" not in questions: + return ( + out, + False, + f"Error: you did not generate questions within the tags", + ) + questions = questions["question"] + # Check that there are at most max_questions questions + if len(questions) > max_questions: + return ( + out, + False, + f"Error: you can only ask at most {max_questions} questions, but you asked {len(questions)}.", + ) + + return (questions, out), True, "" + + questions, message = chat_and_retry( + chat, messages, n_retry=3, parser=_validate_tasks + ) + + return questions + + +def get_dataset_description( + prompt, + system, + context, + goal, + messages=[], + schema=None, + model_name="gpt-3.5-turbo-0125", + temperature=0, +): + + chat = get_chat_model(model_name, temperature) + + messages = [ + SystemMessage(content=system), + HumanMessage(content=prompt.format(context=context, goal=goal, schema=schema)), + ] + + def _validate_tasks(out): + try: + questions = extract_html_tags(out, ["description"])["description"] + except Exception as e: + return ( + out, + False, + f"Error: {str(e)}", + ) + + return (questions, out), True, "" + + data_description, message = chat_and_retry( + chat, messages, n_retry=2, parser=_validate_tasks + ) + + return data_description + + +def get_follow_up_questions( + context, + goal, + question, + answer, + schema=None, + max_questions=3, + model_name="gpt-3.5-turbo-0125", + prompt_method=None, + question_type="descriptive", + temperature=0, +): + if prompt_method is None: + prompt_method = "follow_up" + + prompt, system = prompts.get_question_prompt(method=prompt_method) + chat = get_chat_model(model_name, temperature) + + if prompt_method == "follow_up_with_type": + content = prompt.format( + context=context, + goal=goal, + question=question, + answer=answer, + schema=schema, + max_questions=max_questions, + question_type=question_type, + ) + + else: + content = prompt.format( + context=context, + goal=goal, + question=question, + answer=answer, + schema=schema, + max_questions=max_questions, + ) + + messages = [ + SystemMessage(content=system), + HumanMessage(content=content), + ] + + def _validate_tasks(out): + print(out) + questions = extract_html_tags(out, ["question"])["question"] + # print("The questions are:", questions) + # print("Length:", len(questions), " Max:", max_questions) + + # Check that there are at most max_questions questions + if len(questions) > max_questions: + return ( + out, + False, + f"Error: you can only ask at most {max_questions} questions, but you asked {len(questions)}.", + ) + + return (questions, out), True, "" + + questions, message = chat_and_retry( + chat, messages, n_retry=3, parser=_validate_tasks + ) + + return questions + + +def select_a_question( + questions, + context, + goal, + prev_questions, + model_name="gpt-3.5-turbo-0125", + prompt_template=None, + system_template=None, + temperature=0, +): + + chat = get_chat_model(model_name, temperature) + + followup_questions_formatted = "\n".join( + [f"{i+1}. {q}\n" for i, q in enumerate(questions)] + ) + if prev_questions: + prev_questions_formatted = "\n".join( + [f"{i+1}. {q}\n" for i, q in enumerate(prev_questions)] + ) + else: + prev_questions_formatted = None + + prompt = prompt_template + messages = [ + SystemMessage(content=system_template), + HumanMessage( + content=prompt.format( + context=context, + goal=goal, + prev_questions_formatted=prev_questions_formatted, + followup_questions_formatted=followup_questions_formatted, + ) + ), + ] + + def _validate_tasks(out): + question_id = extract_html_tags(out, ["question_id"])["question_id"][0] + # Check that there are at most max_questions questions + if int(question_id) >= len(questions): + return ( + out, + False, + f"Error: selected question index should be between 0-{len(questions)-1}.", + ) + return (int(question_id), out), True, "" + + question_id, message = chat_and_retry( + chat, messages, n_retry=3, parser=_validate_tasks + ) + return question_id + + +def generate_code( + schema, + multi_schema, + goal, + question, + database_path, + multi_database_path, + output_folder, + n_retries, + prompt_method=None, + model_name="gpt-3.5-turbo-0125", + temperature=0, +): + """ + Solve a task using the naive single step approach + + See main function docstring for more details + + """ + prompt_template = prompts.get_code_prompt(method=prompt_method) + + available_functions = [ + func_name + for func_name, obj in inspect.getmembers(tools) + if inspect.isfunction(obj) + ] + function_docs = [] + for func_name in available_functions: + function_docs.append( + f"{func_name}{inspect.signature(getattr(tools, func_name))}:\n{inspect.getdoc(getattr(tools, func_name))}\n" + + "=" * 20 + + "\n" + ) + function_docs = "\n".join(function_docs) + + + # create prompt + full_schema_str = schema_to_str(schema) + "\n" + + # 这是为了适配新的 MULTI_WITH_PATHS_CODE_PROMPT + if multi_schema and multi_database_path: + full_schema_str = "" + database_path = multi_database_path + # multi_schema 现在是一个 schema 对象的列表 + for i, u_schema in enumerate(multi_schema): + path = multi_database_path[i] if i < len(multi_database_path) else "N/A" + full_schema_str += f"--- Dataset {i+1} ---\nFile Path: {path}\nSchema:\n{schema_to_str(u_schema)}\n\n" + # ### [FIX END] ### + + llm = get_chat_model(model_name, temperature) + + formatted_prompt = prompt_template.format( + goal=goal, + question=question, + schema=full_schema_str, + database_path=database_path, + function_docs=function_docs, + ) + + if prompt_method == "multi_with_paths": print(formatted_prompt) + + output, completions = retry_on_parsing_error( + llm=llm, + initial_prompt=formatted_prompt, # <--- 传递格式化好的 prompt + parser=partial(_code_parser, output_folder=output_folder), + n_retries=n_retries, + exception_on_max_retries=False, + ) + # ### [FIX END] ### + + + # Create the output dict + # Then, iterate over all generated plots and add them to the output dict + output_dict = { + "code": completions[-1], + "prompt": formatted_prompt, + "code_output": output, + "message": output, + "n_retries": len(completions) - 1, + "goal": goal, + "question": question, + "vars": [], + } + + # write code to a file + with open(f"{output_folder}/code.py", "w") as file: + # use regex to capture the python code block + code = completions[-1] + try: + code = re.findall(r"```python(.*?)```", code, re.DOTALL)[0] + file.write(code.strip()) + except Exception as e: + print(f"Failed to write code", e) + file.write(code.strip()) + + # Try to load the model's output files + # TODO: We should detect errors in such files and trigger a retry + try: + stat = json.load(open(f"{output_folder}/stat.json", "r")) + except Exception as e: + print(f"Failed to load {output_folder}/stat.json", e) + stat = {} + try: + x_axis = json.load(open(f"{output_folder}/x_axis.json", "r")) + except Exception as e: + print(f"Failed to load {output_folder}/x_axis.json", e) + x_axis = {} + try: + y_axis = json.load(open(f"{output_folder}/y_axis.json", "r")) + except Exception as e: + print(f"Failed to load {output_folder}/y_axis.json", e) + y_axis = {} + + # Add the plot to the final output dict + plot_path = f"{output_folder}/plot.jpg" + stat["type"] = "stat" + x_axis["type"] = "x_axis" + y_axis["type"] = "y_axis" + plot_dict = {"name": plot_path, "type": "plot"} + output_dict["vars"] += [ + { + "stat": stat, + "x_axis": x_axis, + "y_axis": y_axis, + "plot": plot_dict, + } + ] + + return output_dict + + +def generate_code( + schema, + multi_schema, + goal, + question, + database_path, + multi_database_path, + output_folder, + n_retries, + prompt_method=None, + model_name="gpt-3.5-turbo-0125", + temperature=0, + multi_profile=None, # NEW: Add profile information parameter +): + """ + Generates Python code to answer a question based on provided schemas and data. + This function is now robust and handles both single-file and multi-file scenarios. + + Args: + schema: Schema for single dataset + multi_schema: List of schemas for multi-dataset scenarios + goal: Analysis goal + question: Question to answer + database_path: Path to main dataset + multi_database_path: List of paths for multi-dataset scenarios + output_folder: Where to save generated code and outputs + n_retries: Number of retry attempts + prompt_method: Which prompt template to use + model_name: LLM model name + temperature: LLM temperature + multi_profile: List of statistical profiles for each dataset (NEW) + """ + prompt_template = prompts.get_code_prompt(method=prompt_method) + + # --- 1. Prepare function documentation --- + available_functions = [ + func_name + for func_name, obj in inspect.getmembers(tools) + if inspect.isfunction(obj) + ] + function_docs = "\n".join( + [ + f"{func_name}{inspect.signature(getattr(tools, func_name))}:\n{inspect.getdoc(getattr(tools, func_name))}\n" + + "=" * 20 + + "\n" + for func_name in available_functions + ] + ) + + # --- 2. Prepare arguments for the prompt template --- + schema_str = schema_to_str(schema) if schema else None + format_kwargs = { + "goal": goal, + "question": question, + "database_path": database_path, + "function_docs": function_docs, + "schema": schema_str, # Always include the main schema + } + + # Conditionally add multi-file arguments if needed by the prompt + if prompt_method in ["multi", "multi_with_paths"]: + multi_schema_str = "" + # multi_schema can be a list of schemas. We need to stringify them all. + if multi_schema: + multi_schema_str_list = [schema_to_str(s) for s in multi_schema] + multi_schema_str = "\n\n---\n\n".join(multi_schema_str_list) + + format_kwargs["multi_schema"] = multi_schema_str + # Convert list of paths to a string representation for the prompt + format_kwargs["multi_database_path"] = str(multi_database_path) + + # NEW: Add profile information for multi-dataset scenarios + multi_profile_str = "No profile information available." + if multi_profile: + profile_parts = [] + for i, profile in enumerate(multi_profile): + if isinstance(profile, str): + profile_parts.append(f"Dataset {i+1} Profile:\n{profile[:1500]}") # Limit length + elif isinstance(profile, dict): + profile_parts.append(f"Dataset {i+1} Profile:\n{json.dumps(profile, indent=2)[:1500]}") + multi_profile_str = "\n\n".join(profile_parts) + + format_kwargs["multi_profile"] = multi_profile_str + + # --- 3. Format the prompt and generate code --- + llm = get_chat_model(model_name, temperature) + + # This is now safe and will provide all necessary keys for multi-file prompts + formatted_prompt = prompt_template.format(**format_kwargs) + + output, completions = retry_on_parsing_error( + llm=llm, + initial_prompt=formatted_prompt, + parser=partial(_code_parser, output_folder=output_folder), + n_retries=n_retries, + exception_on_max_retries=False, + ) + + # --- 4. Process and return the results --- + output_dict = { + "code": completions[-1] if completions else "No code generated.", + "prompt": formatted_prompt, + "code_output": output, + "message": output, + "n_retries": len(completions) - 1 if completions else n_retries, + "goal": goal, + "question": question, + "vars": [], + } + + # write code to a file + if completions: + with open(os.path.join(output_folder, "code.py"), "w") as file: + code = extract_python_code_blocks(completions[-1]) + if not code: + code = completions[-1] # fallback if no blocks found + file.write(code.strip()) + + # Try to load the model's output files + try: + stat = json.load(open(os.path.join(output_folder, "stat.json"), "r")) + except Exception as e: + print(f"Failed to load {output_folder}/stat.json: {e}") + stat = {} + try: + x_axis = json.load(open(os.path.join(output_folder, "x_axis.json"), "r")) + except Exception as e: + print(f"Failed to load {output_folder}/x_axis.json: {e}") + x_axis = {} + try: + y_axis = json.load(open(os.path.join(output_folder, "y_axis.json"), "r")) + except Exception as e: + print(f"Failed to load {output_folder}/y_axis.json: {e}") + y_axis = {} + + # Add the plot to the final output dict + plot_path = os.path.join(output_folder, "plot.jpg") + stat["type"] = "stat" + x_axis["type"] = "x_axis" + y_axis["type"] = "y_axis" + plot_dict = {"name": plot_path, "type": "plot"} + output_dict["vars"].append( + { + "stat": stat, + "x_axis": x_axis, + "y_axis": y_axis, + "plot": plot_dict, + } + ) + + return output_dict + + +def analysis_nb_to_gt(fname_notebook, include_df_head=False) -> None: + """ + Reads all ipynb files in data_dir and parses each cell and converts it into a ground truth file. + The ipynb files are structured as follows: code (outputs plot), then a cell with an insight dict + """ + + def _extract_metadata(nb): + # iterate through the cells + metadata = {} + # extract metadata + + # extract name of the dataset from the first cell + dname = re.findall(r"## (.+) \(Flag \d+\)", nb.cells[0].source)[0].strip() + metadata["dataset_name"] = dname + # extract dataset description + description = ( + re.findall( + r"(Dataset Overview|Description)(.+)(Your Objective|Task)", + nb.cells[0].source, + re.DOTALL, + )[0][1] + .replace("#", "") + .strip() + ) + metadata["dataset_description"] = description + + # extract goal and role + metadata["goal"] = re.findall(r"Goal|Objective\**:(.+)", nb.cells[0].source)[ + 0 + ].strip() + metadata["role"] = re.findall(r"Role\**:(.+)", nb.cells[0].source)[0].strip() + + metadata["difficulty"] = re.findall( + r"Difficulty|Challenge Level\**: (\d) out of \d", nb.cells[0].source + )[0].strip() + metadata["difficulty_description"] = ( + re.findall( + r"Difficulty|Challenge Level\**: \d out of \d(.+)", nb.cells[0].source + )[0] + .replace("*", "") + .strip() + ) + metadata["dataset_category"] = re.findall( + r"Category\**: (.+)", nb.cells[0].source + )[0].strip() + + # Get Dataset Info + tag = r"^dataset_path =(.+)" + + dataset_csv_path = None + for cell in nb.cells: + if cell.cell_type == "code": + if re.search(tag, cell.source): + dataset_csv_path = ( + re.findall(tag, cell.source)[0] + .strip() + .replace("'", "") + .replace('"', "") + ) + break + assert dataset_csv_path is not None + metadata["dataset_csv_path"] = dataset_csv_path + + if include_df_head: + metadata["df_head"] = pd.read_html( + StringIO(cell.outputs[0]["data"]["text/html"]) + ) + + # Get Dataset Info + tag = r"multi_dataset_path =(.+)" + + multi_dataset_csv_path = None + for cell in nb.cells: + if cell.cell_type == "code": + if re.search(tag, cell.source): + multi_dataset_csv_path = ( + re.findall(tag, cell.source)[0] + .strip() + .replace("'", "") + .replace('"', "") + ) + break + metadata["multi_dataset_csv_path"] = multi_dataset_csv_path + + # Get Summary of Findings + tag = r"Summary of Findings \(Flag \d+\)(.+)" + + flag = None + for cell in reversed(nb.cells): + if cell.cell_type == "markdown": + if re.search(tag, cell.source, re.DOTALL | re.IGNORECASE): + flag = ( + re.findall(tag, cell.source, re.DOTALL | re.IGNORECASE)[0] + .replace("#", "") + .replace("*", "") + .strip() + ) + break + assert flag is not None + metadata["flag"] = flag + + return metadata + + def _parse_question(nb, cell_idx): + qdict = {} + qdict["question"] = ( + re.findall( + r"Question( |-)(\d+).*:(.+)", nb.cells[cell_idx].source, re.IGNORECASE + )[0][2] + .replace("*", "") + .strip() + ) + + if nb.cells[cell_idx + 2].cell_type == "code": + # action to take to answer the question + assert nb.cells[cell_idx + 1].cell_type == "markdown" + qdict["q_action"] = nb.cells[cell_idx + 1].source.replace("#", "").strip() + assert nb.cells[cell_idx + 2].cell_type == "code" + qdict["code"] = nb.cells[cell_idx + 2].source + # extract output plot. Note that this image data is in str, + # will need to use base64 to load this data + + qdict["plot"] = nb.cells[cell_idx + 2].outputs + # loop as there might be multiple outputs and some might be stderr + for o in qdict["plot"]: + if "data" in o and "image/png" in o["data"]: + qdict["plot"] = o["data"]["image/png"] + break + + # extract the insight + try: + qdict["insight_dict"] = json.loads(nb.cells[cell_idx + 4].source) + except Exception as e: + # find the next cell with the insight dict + for cell in nb.cells[cell_idx + 3 :]: + try: + qdict["insight_dict"] = json.loads(cell.source) + break + except Exception as e: + continue + + else: + # print(f"Found prescriptive insight in {fname_notebook}") + qdict["insight_dict"] = { + "data_type": "prescriptive", + "insight": nb.cells[cell_idx + 1].source.strip(), + "question": qdict["question"], + } + return qdict + + def _parse_notebook(nb): + gt_dict = _extract_metadata(nb) + + # extract questions, code, and outputs + que_indices = [ + idx + for idx, cell in enumerate(nb.cells) + if cell.cell_type == "markdown" + and re.search(r"Question( |-)\d+", cell.source, re.IGNORECASE) + ] + gt_dict["insights"] = [] + for que_idx in que_indices: + gt_dict["insights"].append(_parse_question(nb, que_idx)) + return gt_dict + + # Convert the notebook to a ground truth file + if not fname_notebook.endswith(".ipynb"): + raise ValueError("The file must be an ipynb file") + else: + # extract dataset id from flag-analysis-i.ipynb using re + fname_json = fname_notebook.replace(".ipynb", ".json") + + with open(fname_notebook, "r") as f: + notebook = nbformat.read(f, as_version=4) + gt_dict = _parse_notebook(notebook) + + return gt_dict + + +def get_chat_model(model_name, temperature=0, api_key=None, base_url=None): + """ + Get a chat model function for the specified model name. + Supports GPT models and other OpenAI-compatible models (e.g., qwen). + + Args: + model_name: Name of the model to use + temperature: Temperature for generation + api_key: API key (optional, defaults to env var QDF_API_KEY) + base_url: Base URL for API (optional, defaults to env var QDF_API_URL) + + Returns: + A lambda function that takes content and returns model response + """ + # Use provided values, fall back to global config, then environment variables + key = api_key if api_key else (_global_api_key if _global_api_key else OPENAI_API_KEY) + url = base_url if base_url else (_global_base_url if _global_base_url else OPENAI_API_URL) + + # Debug: Log the key and url (masked) + import logging + _logger = logging.getLogger(__name__) + _logger.warning(f"[get_chat_model] api_key provided: {bool(api_key)}, global: {bool(_global_api_key)}") + if key: + _logger.warning(f"[get_chat_model] key starts with: {key[:10] if len(key) > 10 else key}") + else: + _logger.warning(f"[get_chat_model] NO API KEY available") + + # Create a custom HTTP client without proxy to avoid SOCKS5 issues + http_client = httpx.Client(proxies={"http://": None, "https://": None}) + client = OpenAI(api_key=key, base_url=url, http_client=http_client) + llm = ( + lambda content: client.chat.completions.create( + model=model_name, + temperature=temperature, + messages=[{"role": "user", "content": content}], + ) + .choices[0] + .message.content + ) + return llm + + +class SuppressOutput: + def __enter__(self): + self._original_stdout = sys.stdout + self._original_stderr = sys.stderr + sys.stdout = open(os.devnull, "w") + sys.stderr = open(os.devnull, "w") + + def __exit__(self, exc_type, exc_value, traceback): + sys.stdout.close() + sys.stderr.close() + sys.stdout = self._original_stdout + sys.stderr = self._original_stderr + + + +# -------------------------------------------------- + +def get_enhanced_data_profile(df: pd.DataFrame) -> Dict[str, Any]: + """ + Generate enhanced data profile with statistics for different data types. + + Args: + df: Input DataFrame + + Returns: + JSON string containing comprehensive data profile + """ + profile = { + "shape": df.shape, + "columns_info": {}, + "missing_summary": df.isnull().sum().to_dict() + } + + for col in df.columns: + col_data = df[col] + dtype = str(col_data.dtype) + unique_count = col_data.nunique() + unique_ratio = unique_count / len(df) if len(df) > 0 else 0 + + info = { + "dtype": dtype, + "unique_count": unique_count, + "unique_ratio": round(unique_ratio, 4), + "sample_values": col_data.dropna().unique()[:3].tolist() + } + + # Numeric columns: extract statistical measures + if np.issubdtype(col_data.dtype, np.number): + info.update({ + "min": float(col_data.min()) if not pd.isna(col_data.min()) else None, + "max": float(col_data.max()) if not pd.isna(col_data.max()) else None, + "avg": round(float(col_data.mean()), 2) if not pd.isna(col_data.mean()) else None, + "q1": float(col_data.quantile(0.25)), + "q3": float(col_data.quantile(0.75)) + }) + # Temporal columns: extract date ranges + elif "datetime" in dtype or col.lower() in ['date', 'time', 'timestamp']: + try: + temp_dt = pd.to_datetime(col_data) + info.update({ + "min_time": str(temp_dt.min()), + "max_time": str(temp_dt.max()), + "is_temporal": True + }) + except: + pass # Skip if conversion fails + + profile["columns_info"][col] = info + + # Return as JSON string for LLM consumption + return json.dumps(profile, ensure_ascii=False, indent=2) + + +def validate_file_path(file_path: str) -> bool: + """ + Validate if file exists and is accessible. + + Args: + file_path: Path to check + + Returns: + True if file exists and is readable + """ + return os.path.exists(file_path) and os.access(file_path, os.R_OK) + + +def safe_json_loads(json_str: str, default: Any = None) -> Any: + """ + Safely parse JSON string with error handling. + + Args: + json_str: JSON string to parse + default: Default value if parsing fails + + Returns: + Parsed JSON object or default value + """ + try: + return json.loads(json_str) + except (json.JSONDecodeError, TypeError): + return default + + + + +if __name__ == "__main__": + # dataset_id = 1 + # results_dir = f"./.tmp/outputs_no_goal/gpt-4-turbo-2024-04-09/{dataset_id}/" + # data_dir = "/mnt/cba/data/servicenow_incidents/flags" + # goal = json.load(open(os.path.join(data_dir, f"gt_flag_{dataset_id}.json")))["Goal"] + # history = json.load(open(os.path.join(results_dir, "history.json"))) + # root_depth_to_prompt( + # history=history, + # root=1, + # depth=3, + # goal=goal, + # csv_path=os.path.join(data_dir, f"flag-{dataset_id}.csv"), + # results_dir=results_dir, + # ) + + analysis_nb_to_gt("/mnt/home/projects/research-cba/.tmp/new_notebooks") diff --git a/workflow_engine/toolkits/insight_tool/dm_components/utils/dataloader_utils.py b/workflow_engine/toolkits/insight_tool/dm_components/utils/dataloader_utils.py new file mode 100644 index 0000000..08c621c --- /dev/null +++ b/workflow_engine/toolkits/insight_tool/dm_components/utils/dataloader_utils.py @@ -0,0 +1,812 @@ +# src/insight/singlesource_insight_agent/utils/data_source_reader.py +import json +import pandas as pd +import sqlite3 +import logging +import os +import base64 +from typing import Dict, Any, Optional, List, Union, Literal + +# 配置日志 +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) +from io import StringIO + +try: + from openai import OpenAI +except ImportError: + OpenAI = None + +# Type alias for background text data +BackgroundTextData = Dict[str, Any] # {"type": "background_text", "content": str, "source": str, ...} + +class DataSourceReader: + """多数据源读取器""" + + @staticmethod + def read_data(file_path: str, **kwargs) -> Union[pd.DataFrame, Dict[str, pd.DataFrame]]: + """ + 根据文件扩展名自动选择读取方法 + + Args: + file_path: 文件路径 + **kwargs: 各读取方法的额外参数 + + Returns: + Union[pd.DataFrame, Dict[str, pd.DataFrame]]: + - 通常返回 DataFrame + - 当读取含有多个表的 SQLite 或多个 Sheet 的 Excel(预留)时,可能返回字典 {表名: DataFrame} + """ + if not os.path.exists(file_path): + raise FileNotFoundError(f"文件不存在: {file_path}") + + file_ext = os.path.splitext(file_path)[1].lower() + + try: + if file_ext == '.csv': + return DataSourceReader.read_csv(file_path, **kwargs) + elif file_ext in ['.sqlite', '.db']: + return DataSourceReader.read_sqlite(file_path, **kwargs) + elif file_ext == '.txt': + return DataSourceReader.read_txt(file_path, **kwargs) + elif file_ext in ['.xlsx', '.xls']: + return DataSourceReader.read_excel(file_path, **kwargs) + elif file_ext == '.json': + return DataSourceReader.read_json(file_path, **kwargs) + elif file_ext in ['.jsonl', '.ndjson']: + return DataSourceReader.read_jsonl(file_path, **kwargs) + elif file_ext == '.parquet': + return DataSourceReader.read_parquet(file_path, **kwargs) + elif file_ext in ['.png', '.jpg', '.jpeg', '.bmp']: + return DataSourceReader.read_image(file_path, **kwargs) + else: + raise ValueError(f"不支持的文件格式: {file_ext}") + except Exception as e: + logger.error(f"读取文件 {file_path} 时出错: {str(e)}") + raise + + + @staticmethod + def read_csv(file_path: str, **kwargs) -> pd.DataFrame: + """读取CSV文件""" + # Filter out non-pandas parameters + pandas_kwargs = {k: v for k, v in kwargs.items() + if k not in ['as_background', 'max_chars_for_direct_use', 'model']} + + default_kwargs = { + 'encoding': 'utf-8', + 'sep': ',', + 'header': 0 + } + default_kwargs.update(pandas_kwargs) + return pd.read_csv(file_path, **default_kwargs) + + @staticmethod + def read_sqlite(file_path: str, **kwargs) -> Union[pd.DataFrame, Dict[str, pd.DataFrame]]: + """ + 读取SQLite数据库 + + Args: + file_path: SQLite文件路径 + table_name: 指定表名(如果指定,返回单个 DataFrame) + query: 自定义SQL查询(如果指定,返回单个 DataFrame) + + Returns: + 如果指定了 table_name 或 query,返回 pd.DataFrame + 否则,返回 Dict[str, pd.DataFrame],包含数据库中所有表的数据 + """ + table_name = kwargs.get('table_name') + query = kwargs.get('query') + + with sqlite3.connect(file_path) as conn: + if query: + # 情况1:执行自定义查询 + return pd.read_sql_query(query, conn) + elif table_name: + # 情况2:读取指定表 + return pd.read_sql_query(f"SELECT * FROM {table_name}", conn) + else: + # 情况3:读取所有表 + cursor = conn.cursor() + cursor.execute("SELECT name FROM sqlite_master WHERE type='table';") + tables = cursor.fetchall() + + if not tables: + raise ValueError("数据库中未找到任何表") + + all_tables_data = {} + logger.info(f"未指定表名,开始读取数据库中所有表: {[t[0] for t in tables]}") + + for table in tables: + t_name = table[0] + try: + df = pd.read_sql_query(f"SELECT * FROM {t_name}", conn) + all_tables_data[t_name] = df + logger.info(f"已读取表: {t_name}, 行数: {len(df)}") + except Exception as e: + logger.warning(f"读取表 {t_name} 失败: {str(e)}") + + if not all_tables_data: + raise ValueError("未能成功读取任何表数据") + + return all_tables_data + + + @staticmethod + def _summarize_text(content: str, model: str = "gpt-4o", max_summary_length: int = 1000) -> str: + """ + 使用 LLM 对长文本生成摘要 + + Args: + content: 原始文本内容 + model: 模型名称 + max_summary_length: 摘要最大长度 + + Returns: + 文本摘要 + """ + client = DataSourceReader._get_openai_client() + + summarize_prompt = f""" + 请对以下文本内容进行信息提取和摘要,生成一个结构化的摘要。 + + 要求: + 1. 保留关键的数据信息、时间节点、数量指标等 + 2. 识别并保留重要的业务背景、行业知识 + 3. 摘要长度控制在 {max_summary_length} 字符以内 + 4. 使用清晰的分点结构 + + 原始文本: + {content[:8000]} # 限制输入长度避免超过上下文限制 + + 请生成摘要: + """ + + try: + response = client.chat.completions.create( + model=model, + messages=[{"role": "user", "content": summarize_prompt}], + max_tokens=max_summary_length, + temperature=0.1 + ) + return response.choices[0].message.content.strip() + except Exception as e: + logger.warning(f"文本摘要生成失败: {str(e)}") + # 降级处理:截取前后部分 + return content[:max_summary_length // 2] + "\n...[中间内容已省略]...\n" + content[-max_summary_length // 2:] + + @staticmethod + def read_txt(file_path: str, **kwargs) -> Union[pd.DataFrame, BackgroundTextData]: + """ + 读取文本文件 - 支持结构化表格解析和背景信息提取 + + Args: + file_path: 文件路径 + as_background: 是否作为背景信息处理(默认 False,先尝试解析为表格) + max_chars_for_direct_use: 直接使用的最大字符数阈值(默认 2000) + model: 摘要生成使用的模型名称 + **kwargs: 传递给 pd.read_csv 的参数 + + Returns: + - 如果成功解析为表格: 返回 pd.DataFrame + - 如果作为背景信息: 返回 {"type": "background_text", ...} + """ + as_background = kwargs.pop('as_background', False) + max_chars_for_direct_use = kwargs.pop('max_chars_for_direct_use', 2000) + model = kwargs.pop('model', 'gpt-4o') + + # 首先读取文件内容 + encoding = kwargs.get('encoding', 'utf-8') + try: + with open(file_path, 'r', encoding=encoding) as f: + raw_content = f.read() + except UnicodeDecodeError: + # 尝试其他编码 + for enc in ['gbk', 'latin1', 'utf-16']: + try: + with open(file_path, 'r', encoding=enc) as f: + raw_content = f.read() + encoding = enc + break + except UnicodeDecodeError: + continue + else: + raise ValueError(f"无法使用常见编码读取文件: {file_path}") + + # 如果强制作为背景信息处理 + if as_background: + return DataSourceReader._process_as_background_text( + raw_content, file_path, max_chars_for_direct_use, model + ) + + # 尝试解析为结构化表格 + default_kwargs = {'encoding': encoding, 'sep': '\t'} + default_kwargs.update(kwargs) + separators = [default_kwargs.pop('sep', '\t'), ',', ';', '|'] + + for sep in separators: + try: + df = pd.read_csv(file_path, sep=sep, encoding=encoding, **{k: v for k, v in default_kwargs.items() if k != 'encoding'}) + if len(df.columns) > 1: + logger.info(f"成功读取文本文件为表格,使用分隔符: {repr(sep)}") + return df + except Exception: + continue + + # 尝试固定宽度格式 + try: + df = pd.read_fwf(file_path, encoding=encoding) + if len(df.columns) > 1: + logger.info(f"成功读取文本文件为固定宽度格式") + return df + except Exception: + pass + + # 无法解析为表格,作为背景信息处理 + logger.info(f"无法将文本文件解析为表格,转为背景信息处理: {file_path}") + return DataSourceReader._process_as_background_text( + raw_content, file_path, max_chars_for_direct_use, model + ) + + @staticmethod + def _process_as_background_text( + content: str, + source_path: str, + max_chars_for_direct_use: int = 2000, + model: str = "gpt-4o" + ) -> BackgroundTextData: + """ + 将文本内容处理为背景信息格式 + + Args: + content: 文本内容 + source_path: 来源文件路径 + max_chars_for_direct_use: 直接使用的最大字符数阈值 + model: 摘要生成使用的模型 + + Returns: + 背景信息字典 + """ + content = content.strip() + + if len(content) <= max_chars_for_direct_use: + # 短文本直接使用 + logger.info(f"文本较短({len(content)}字符),直接使用") + return { + "type": "background_text", + "content": content, + "source": source_path, + "source_type": "text_file", + "is_summarized": False, + "original_length": len(content) + } + else: + # 长文本需要摘要 + logger.info(f"文本较长({len(content)}字符),进行摘要提取") + summary = DataSourceReader._summarize_text(content, model) + return { + "type": "background_text", + "content": summary, + "source": source_path, + "source_type": "text_file", + "is_summarized": True, + "original_length": len(content), + "summary_length": len(summary) + } + + + @staticmethod + def read_excel(file_path: str, **kwargs) -> pd.DataFrame: + """读取Excel文件""" + default_kwargs = { + 'sheet_name': 0, # 第一个sheet + 'header': 0 + } + default_kwargs.update(kwargs) + return pd.read_excel(file_path, **default_kwargs) + + @staticmethod + def _try_parse_json_string(val: Any) -> Any: + """尝试解析字符串形式的 JSON,如果不是 JSON 则返回原值""" + if isinstance(val, str) and val.strip().startswith(('{', '[')): + try: + return json.loads(val) + except: + return val + return val + + @staticmethod + def _try_parse_json_string(val: Any) -> Any: + """尝试解析字符串形式的 JSON,如果不是 JSON 或解析失败则返回原值""" + if isinstance(val, str): + trimmed = val.strip() + if trimmed.startswith(('{', '[')): + try: + return json.loads(trimmed) + except (json.JSONDecodeError, TypeError): + return val + return val + + @staticmethod + def _deep_flatten_dataframe(df: pd.DataFrame) -> pd.DataFrame: + """ + 深度递归拉平逻辑: + 1. 探测字符串形式的 JSON 并解析。 + 2. 使用 pd.json_normalize 展开嵌套字典。 + 3. 循环直到没有可展开的内容。 + """ + if df.empty: + return df + + while True: + changed = False + cols_to_process = df.columns.tolist() + + for col in cols_to_process: + # 采样检查该列是否包含字典或 JSON 字符串 + non_na_values = df[col].dropna() + if non_na_values.empty: + continue + + sample_val = non_na_values.iloc[0] + + # 场景 A: 该列是字符串,但内容是 JSON 结构 + if isinstance(sample_val, str) and sample_val.strip().startswith(('{', '[')): + parsed_series = df[col].apply(DataSourceReader._try_parse_json_string) + # 如果解析后确实变成了字典或列表,则更新并标记需要进一步处理 + if not parsed_series.equals(df[col]): + df[col] = parsed_series + changed = True + + # 场景 B: 该列是字典对象(由场景A解析而来,或原本就是嵌套结构) + if isinstance(sample_val, dict): + # 使用 json_normalize 展开 + # errors='ignore' 确保即使某些行不是字典也能跳过 + expanded = pd.json_normalize(df[col].tolist()) + expanded.index = df.index + # 增加前缀以保留层级关系 + expanded = expanded.add_prefix(f"{col}_") + + # 合并回原表并删除旧列 + df = pd.concat([df.drop(columns=[col]), expanded], axis=1) + changed = True + break # 结构改变,跳出当前循环重新扫描所有列 + + if not changed: + break + + return df + + @staticmethod + def read_json(file_path: str, **kwargs) -> Union[pd.DataFrame, Dict[str, pd.DataFrame]]: + """读取标准 JSON,支持多表识别和深度递归拉平""" + with open(file_path, 'r', encoding='utf-8') as file: + data = json.load(file) + + # 1. 多表识别 (类似 DB 逻辑) + if isinstance(data, dict): + list_keys = [k for k, v in data.items() if isinstance(v, list)] + if len(list_keys) > 1: + all_tables = {} + for k in list_keys: + temp_df = pd.DataFrame(data[k]) + all_tables[k] = DataSourceReader._deep_flatten_dataframe(temp_df) + return all_tables + + # 单表情况 + if not isinstance(data, list): + data = [data] + + # 2. 转换为初始 DataFrame 并执行深度拉平 + df = pd.DataFrame(data) + return DataSourceReader._deep_flatten_dataframe(df) + + @staticmethod + def read_jsonl(file_path: str, **kwargs) -> pd.DataFrame: + """读取 JSONL 并执行深度递归拉平""" + data = [] + with open(file_path, 'r', encoding='utf-8') as f: + for line in f: + if line.strip(): + data.append(json.loads(line)) + + if not data: + raise ValueError("JSONL 文件为空") + + df = pd.DataFrame(data) + return DataSourceReader._deep_flatten_dataframe(df) + + + @staticmethod + def _get_openai_client(): + """获取 OpenAI 客户端实例""" + if OpenAI is None: + raise ImportError("需要安装 'openai' 库才能使用图片读取功能: pip install openai") + + OPENAI_API_KEY = os.getenv("QDF_API_KEY") + OPENAI_API_URL = os.getenv("QDF_API_URL") + if not OPENAI_API_KEY: + raise ValueError("未提供 OpenAI API Key,无法调用 LLM") + + return OpenAI(api_key=OPENAI_API_KEY, base_url=OPENAI_API_URL) + + @staticmethod + def _encode_image(image_path: str) -> str: + """将图片编码为 base64 字符串""" + with open(image_path, "rb") as image_file: + return base64.b64encode(image_file.read()).decode('utf-8') + + @staticmethod + def _classify_image(file_path: str, model: str = "gpt-4o") -> Literal["table", "chart", "other"]: + """ + 使用 LLM 判断图片类型 + + Args: + file_path: 图片路径 + model: 模型名称 + + Returns: + 图片类型: "table" (表格), "chart" (可视化图表), "other" (其他) + """ + client = DataSourceReader._get_openai_client() + base64_image = DataSourceReader._encode_image(file_path) + + classify_prompt = """ + 请仔细观察这张图片,判断它属于以下哪种类型: + + 1. "table" - 图片包含结构化的数据表格(有行、列、表头的表格数据) + 2. "chart" - 图片是数据可视化图表(如折线图、柱状图、饼图、散点图、箱线图、热力图等) + 3. "other" - 图片既不是表格也不是数据图表(如普通照片、文档截图、示意图等) + + 请只返回一个单词:table、chart 或 other,不要包含其他任何内容。 + """ + + try: + response = client.chat.completions.create( + model=model, + messages=[ + { + "role": "user", + "content": [ + {"type": "text", "text": classify_prompt}, + { + "type": "image_url", + "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"} + } + ] + } + ], + max_tokens=50, + temperature=0 + ) + + result = response.choices[0].message.content.strip().lower() + + # 确保返回值是有效的类型 + if "table" in result: + return "table" + elif "chart" in result: + return "chart" + else: + return "other" + + except Exception as e: + logger.warning(f"图片分类失败,默认返回 'other': {str(e)}") + return "other" + + @staticmethod + def _generate_chart_description(file_path: str, model: str = "gpt-4o") -> str: + """ + 为可视化图表生成自然语言描述 + + Args: + file_path: 图片路径 + model: 模型名称 + + Returns: + 图表的自然语言描述 + """ + client = DataSourceReader._get_openai_client() + base64_image = DataSourceReader._encode_image(file_path) + + description_prompt = """ + 请详细描述这张数据可视化图表,包括: + + 1. 图表类型(折线图、柱状图、饼图、散点图等) + 2. X轴和Y轴分别代表什么(如果适用) + 3. 图表展示的主要数据趋势或模式 + 4. 关键的数据点或峰值(如果可见) + 5. 图表标题和图例信息(如果有) + 6. 任何值得注意的异常值或特殊模式 + + 请用清晰、结构化的方式描述,便于后续数据分析使用。 + """ + + try: + response = client.chat.completions.create( + model=model, + messages=[ + { + "role": "user", + "content": [ + {"type": "text", "text": description_prompt}, + { + "type": "image_url", + "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"} + } + ] + } + ], + max_tokens=1000, + temperature=0.1 + ) + + return response.choices[0].message.content.strip() + + except Exception as e: + logger.error(f"图表描述生成失败: {str(e)}") + return f"[图表描述生成失败] 文件: {os.path.basename(file_path)}" + + @staticmethod + def _extract_table_from_image(file_path: str, model: str = "gpt-4o") -> pd.DataFrame: + """ + 从图片中提取表格数据 + + Args: + file_path: 图片路径 + model: 模型名称 + + Returns: + pd.DataFrame: 提取的表格数据 + """ + client = DataSourceReader._get_openai_client() + base64_image = DataSourceReader._encode_image(file_path) + + prompt = """ + 你是一个数据提取助手。请查看这张图片,它包含一个或多个表格。 + 请提取表格数据并将其转换为标准的 JSON 格式。 + + 要求: + 1. 返回结果必须是一个纯 JSON 数组(List of Objects),每个对象代表表格的一行。 + 2. 对象的键(Key)应该是表头,值(Value)是单元格内容。 + 3. 不要包含 Markdown 代码块标记(如 ```json),只返回纯文本 JSON 字符串。 + 4. 如果图片中没有识别出表格,返回空数组 []。 + 5. 能够智能处理合并单元格,将其拆分为对应的数据。 + """ + + logger.info(f"正在调用 {model} 解析图片表格: {file_path}") + + try: + response = client.chat.completions.create( + model=model, + messages=[ + { + "role": "user", + "content": [ + {"type": "text", "text": prompt}, + { + "type": "image_url", + "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"} + } + ] + } + ], + max_tokens=4096, + temperature=0 + ) + + content = response.choices[0].message.content.strip() + + # 清理可能存在的 Markdown 标记 + if content.startswith("```json"): + content = content[7:] + if content.startswith("```"): + content = content[3:] + if content.endswith("```"): + content = content[:-3] + + content = content.strip() + + data = json.loads(content) + + if not data: + logger.warning("模型未在图片中识别到数据") + return pd.DataFrame() + + df = pd.DataFrame(data) + logger.info(f"成功从图片提取数据,形状: {df.shape}") + return df + + except Exception as e: + logger.error(f"图片表格解析失败: {str(e)}") + raise + + @staticmethod + def read_image(file_path: str, **kwargs) -> Union[pd.DataFrame, BackgroundTextData]: + """ + 智能读取图片数据 - 支持表格提取和图表描述生成 + + Args: + file_path: 图片路径 + model: 模型名称 (默认 gpt-4o) + force_type: 强制指定图片类型 ("table", "chart", "other"),跳过自动分类 + + Returns: + - 如果是表格图片: 返回 pd.DataFrame + - 如果是图表/其他: 返回 {"type": "background_text", "content": str, "source": str, ...} + """ + model = kwargs.get('model', 'gpt-4o') + force_type = kwargs.get('force_type', None) + + # Phase 1: 分类图片类型 + if force_type: + image_type = force_type + logger.info(f"强制指定图片类型为: {image_type}") + else: + logger.info(f"正在分类图片类型: {file_path}") + image_type = DataSourceReader._classify_image(file_path, model) + logger.info(f"图片分类结果: {image_type}") + + # Phase 2: 根据类型处理 + if image_type == "table": + return DataSourceReader._extract_table_from_image(file_path, model) + elif image_type == "chart": + description = DataSourceReader._generate_chart_description(file_path, model) + return { + "type": "background_text", + "content": description, + "source": file_path, + "source_type": "chart_image", + "is_summarized": False + } + else: + # 对于其他类型的图片,生成简单描述 + return { + "type": "background_text", + "content": f"[非结构化图片] 文件: {os.path.basename(file_path)}", + "source": file_path, + "source_type": "other_image", + "is_summarized": False + } + + @staticmethod + def get_file_info(file_path: str) -> Dict[str, Any]: + """获取文件信息""" + if not os.path.exists(file_path): + raise FileNotFoundError(f"文件不存在: {file_path}") + + file_ext = os.path.splitext(file_path)[1].lower() + file_size = os.path.getsize(file_path) + + info = { + 'file_path': file_path, + 'file_extension': file_ext, + 'file_size': file_size, + 'file_size_mb': round(file_size / (1024 * 1024), 2), + 'exists': True + } + + if file_ext in ['.sqlite', '.db']: + info.update(DataSourceReader._get_sqlite_info(file_path)) + elif file_ext == '.csv': + info.update(DataSourceReader._get_csv_info(file_path)) + elif file_ext in ['.xlsx', '.xls']: + info.update(DataSourceReader._get_excel_info(file_path)) + # 这里可以继续扩展 JSONL 和 Image 的 info 获取逻辑 + + return info + + @staticmethod + def _get_sqlite_info(file_path: str) -> Dict[str, Any]: + """获取SQLite数据库信息""" + try: + with sqlite3.connect(file_path) as conn: + cursor = conn.cursor() + cursor.execute("SELECT name FROM sqlite_master WHERE type='table';") + tables = [table[0] for table in cursor.fetchall()] + + table_info = {} + for table in tables: + cursor.execute(f"SELECT COUNT(*) FROM {table}") + row_count = cursor.fetchone()[0] + + cursor.execute(f"PRAGMA table_info({table})") + columns_info = cursor.fetchall() + columns = [col[1] for col in columns_info] + + table_info[table] = { + 'row_count': row_count, + 'columns': columns, + # 'columns_detail': columns_info # 详细信息可能过大,视需求保留 + } + + return { + 'database_tables': tables, + 'table_info': table_info, + 'total_tables': len(tables) + } + except Exception as e: + logger.warning(f"获取SQLite数据库信息失败: {str(e)}") + return {'error': str(e)} + + @staticmethod + def _get_csv_info(file_path: str) -> Dict[str, Any]: + """获取CSV文件信息""" + try: + # 只读取前几行来获取信息,避免读取整个大文件 + df_sample = pd.read_csv(file_path, nrows=5) + full_df = pd.read_csv(file_path, nrows=0) # 只读取列名 + return { + 'columns': full_df.columns.tolist(), + 'dtypes': full_df.dtypes.astype(str).to_dict(), + 'sample_data': df_sample.to_dict('records'), + 'sample_shape': df_sample.shape + } + except Exception as e: + logger.warning(f"获取CSV文件信息失败: {str(e)}") + return {'error': str(e)} + + @staticmethod + def _get_excel_info(file_path: str) -> Dict[str, Any]: + """获取Excel文件信息""" + try: + excel_file = pd.ExcelFile(file_path) + sheets = excel_file.sheet_names + + sheet_info = {} + for sheet in sheets[:5]: # 只检查前5个sheet,避免大文件 + df_sample = pd.read_excel(file_path, sheet_name=sheet, nrows=5) + sheet_info[sheet] = { + 'columns': df_sample.columns.tolist(), + 'sample_shape': df_sample.shape + } + + return { + 'sheets': sheets, + 'sheet_info': sheet_info, + 'total_sheets': len(sheets) + } + except Exception as e: + logger.warning(f"获取Excel文件信息失败: {str(e)}") + return {'error': str(e)} + + @staticmethod + def validate_data_source(file_path: str, **kwargs) -> Dict[str, Any]: + """ + 验证数据源是否可读 + + Returns: + Dict[str, Any]: 验证结果 + """ + try: + # 尝试读取少量数据 + df = DataSourceReader.read_data(file_path, **kwargs) + info = DataSourceReader.get_file_info(file_path) + + return { + 'valid': True, + 'message': '数据源验证成功', + 'shape': df.shape, + 'columns': df.columns.tolist(), + 'file_info': info + } + except Exception as e: + return { + 'valid': False, + 'message': f'数据源验证失败: {str(e)}', + 'error': str(e) + } + + @staticmethod + def get_supported_formats() -> List[Dict[str, str]]: + """获取支持的文件格式列表""" + return [ + {'format': 'CSV', 'extensions': ['.csv'], 'description': '逗号分隔值文件'}, + {'format': 'SQLite', 'extensions': ['.sqlite', '.db'], 'description': 'SQLite数据库文件(支持多表)'}, + {'format': 'Excel', 'extensions': ['.xlsx', '.xls'], 'description': 'Excel电子表格'}, + {'format': 'JSON', 'extensions': ['.json'], 'description': 'JSON数据文件'}, + {'format': 'JSONL', 'extensions': ['.jsonl', '.ndjson'], 'description': '换行符分隔的JSON文件'}, + {'format': 'Text', 'extensions': ['.txt'], 'description': '文本文件'}, + {'format': 'Parquet', 'extensions': ['.parquet'], 'description': 'Parquet列式存储文件'}, + {'format': 'Image', 'extensions': ['.png', '.jpg', '.jpeg'], 'description': '图片表格(AI提取)'} + ] \ No newline at end of file diff --git a/workflow_engine/toolkits/insight_tool/dm_components/utils/exp_utils.py b/workflow_engine/toolkits/insight_tool/dm_components/utils/exp_utils.py new file mode 100644 index 0000000..d416c98 --- /dev/null +++ b/workflow_engine/toolkits/insight_tool/dm_components/utils/exp_utils.py @@ -0,0 +1,111 @@ +import hashlib, os, json, pprint + + +def print(string): + pprint.pprint(string) + + +def hash_str(string): + """Create a hash for a string. + + Parameters + ---------- + string : str + A string + + Returns + ------- + hash_id: str + A unique id defining the string + """ + hash_id = hashlib.md5(string.encode()).hexdigest() + return hash_id + + +def hash_dict(exp_dict): + """Create a hash for an experiment. + + Parameters + ---------- + exp_dict : dict + An experiment, which is a single set of hyper-parameters + + Returns + ------- + hash_id: str + A unique id defining the experiment + """ + dict2hash = "" + if not isinstance(exp_dict, dict): + raise ValueError("exp_dict is not a dict") + + for k in sorted(exp_dict.keys()): + if "." in k: + raise ValueError(". has special purpose") + elif isinstance(exp_dict[k], dict): + v = hash_dict(exp_dict[k]) + elif isinstance(exp_dict[k], tuple): + raise ValueError( + f"{exp_dict[k]} tuples can't be hashed yet, consider converting tuples to lists" + ) + elif ( + isinstance(exp_dict[k], list) + and len(exp_dict[k]) + and isinstance(exp_dict[k][0], dict) + ): + v_str = "" + for e in exp_dict[k]: + if isinstance(e, dict): + v_str += hash_dict(e) + else: + raise ValueError("all have to be dicts") + v = v_str + else: + v = exp_dict[k] + + dict2hash += str(k) + "/" + str(v) + hash_id = hashlib.md5(dict2hash.encode()).hexdigest() + + return hash_id + + +def save_json(fname, data, makedirs=True): + """Save data into a json file. + + Parameters + ---------- + fname : str + Name of the json file + data : [type] + Data to save into the json file + makedirs : bool, optional + If enabled creates the folder for saving the file, by default True + """ + # turn fname to string in case it is a Path object + fname = str(fname) + dirname = os.path.dirname(fname) + if makedirs and dirname != "": + os.makedirs(dirname, exist_ok=True) + with open(fname, "w", encoding='utf-8') as json_file: + json.dump(data, json_file, indent=4, sort_keys=True, ensure_ascii=False) + + +def load_json(fname, decode=None): # TODO: decode??? + """Load a json file. + + Parameters + ---------- + fname : str + Name of the file + decode : [type], optional + [description], by default None + + Returns + ------- + [type] + Content of the file + """ + with open(fname, "r") as json_file: + d = json.load(json_file) + + return d diff --git a/workflow_engine/toolkits/insight_tool/dm_components/utils/simhei.ttf b/workflow_engine/toolkits/insight_tool/dm_components/utils/simhei.ttf new file mode 100644 index 0000000..c5030ae Binary files /dev/null and b/workflow_engine/toolkits/insight_tool/dm_components/utils/simhei.ttf differ diff --git a/workflow_engine/toolkits/insight_tool/dm_components/workflows/__init__.py b/workflow_engine/toolkits/insight_tool/dm_components/workflows/__init__.py new file mode 100644 index 0000000..e5e600b --- /dev/null +++ b/workflow_engine/toolkits/insight_tool/dm_components/workflows/__init__.py @@ -0,0 +1 @@ +# DM Components - Workflows diff --git a/workflow_engine/toolkits/insight_tool/dm_components/workflows/insight_workflow.py b/workflow_engine/toolkits/insight_tool/dm_components/workflows/insight_workflow.py new file mode 100644 index 0000000..a47a64e --- /dev/null +++ b/workflow_engine/toolkits/insight_tool/dm_components/workflows/insight_workflow.py @@ -0,0 +1,163 @@ +from typing import TypedDict, List, Dict, Optional, Literal +from langgraph.graph import StateGraph, END +from dm_components.config import logger +import json + +# --- 1. 定义状态 --- +class InsightState(TypedDict): + agent_base: object + initial_goal: str + branch_depth: int + max_questions: int + + pending_root_questions: List[str] + current_question: str + current_branch_iteration: int + + # 路由控制标记 + next_action: Literal["continue_deep", "next_root", "summarize"] + + insights_history: List[Dict] + final_summary: Optional[str] + +class InsightWorkflow: + def __init__(self, agent_base, branch_depth=3): + self.agent_base = agent_base + self.branch_depth = branch_depth + self.app = self._build_graph() + + def _build_graph(self): + workflow = StateGraph(InsightState) + + # 定义节点 + workflow.add_node("init_node", self.initialize_workflow) + workflow.add_node("answer_node", self.answer_question) + workflow.add_node("logic_node", self.decide_next_step) # 核心逻辑指挥中心 + workflow.add_node("summarize_node", self.summarize_insights) + + # 连线 + workflow.set_entry_point("init_node") + workflow.add_edge("init_node", "answer_node") + workflow.add_edge("answer_node", "logic_node") + + # 根据 logic_node 计算出的 next_action 进行分发 + workflow.add_conditional_edges( + "logic_node", + lambda state: state["next_action"], + { + "continue_deep": "answer_node", + "next_root": "answer_node", + "summarize": "summarize_node" + } + ) + workflow.add_edge("summarize_node", END) + + return workflow.compile() + + # --- Node 实现 --- + + def initialize_workflow(self, state: InsightState) -> Dict: + logger.info("### NODE: Initializing Workflow ###") + agent = state['agent_base'] + root_qs = agent.recommend_questions( + prompt_method="basic", + n_questions=state['max_questions'] + ) + + if not root_qs: + logger.error("No root questions generated.") + return {"next_action": "summarize", "current_question": "End"} + + return { + "pending_root_questions": root_qs[1:], + "current_question": root_qs[0], + "current_branch_iteration": 0, + "insights_history": [], + "next_action": "continue_deep" # 初始动作 + } + + def answer_question(self, state: InsightState) -> Dict: + agent = state['agent_base'] + q = state['current_question'] + depth = state['current_branch_iteration'] + + logger.info(f"### NODE: Answering Question (Depth: {depth}) ###") + logger.info(f"Question: {q}") + + _, insight_dict = agent.answer_question( + q, + prompt_code_method="single", + prompt_interpret_method="basic" + ) + + new_history = state.get('insights_history', []) + [insight_dict] + return {"insights_history": new_history} + + def decide_next_step(self, state: InsightState) -> Dict: + """ + 这个节点替代了原来的 recommend_node + should_continue。 + 它负责判断:是生成追问、切换根问题、还是结束。 + """ + logger.info("### NODE: Deciding Next Step ###") + agent = state['agent_base'] + + # 1. 检查是否需要深挖 (Deep Dive) + if state['current_branch_iteration'] + 1 < state['branch_depth']: + logger.info(f"Action: Generating follow-up (Depth {state['current_branch_iteration'] + 1})") + + last_insight = state['insights_history'][-1] + next_qs = agent.recommend_questions( + n_questions=state['max_questions'], + insights_history=[last_insight] + ) + + # 安全解析问题 + if next_qs: + # 假设 agent.select_a_question 返回索引 + idx = agent.select_a_question(next_qs) + return { + "current_question": next_qs[idx], + "current_branch_iteration": state['current_branch_iteration'] + 1, + "next_action": "continue_deep" + } + else: + logger.warning("No follow-up questions generated, trying to switch root.") + + # 2. 检查是否有待处理的根问题 (Next Root) + if state['pending_root_questions']: + next_root = state['pending_root_questions'][0] + logger.info(f"Action: Switching to next root question: {next_root}") + return { + "current_question": next_root, + "pending_root_questions": state['pending_root_questions'][1:], + "current_branch_iteration": 0, + "next_action": "next_root" + } + + # 3. 都没有则结束 + logger.info("Action: All tasks completed.") + return {"next_action": "summarize"} + + def summarize_insights(self, state: InsightState) -> Dict: + logger.info("### NODE: Summarizing ###") + agent = state['agent_base'] + summary = agent.summarize(state['insights_history']) + return {"final_summary": summary} + + # --- Run 方法保持不变 --- + def run(self, initial_goal: str, max_questions: int = 3, output_json_path: Optional[str] = None): + initial_state = { + "agent_base": self.agent_base, + "initial_goal": initial_goal, + "branch_depth": self.branch_depth, + "max_questions": max_questions, + "pending_root_questions": [], + "current_question": "", + "current_branch_iteration": 0, + "insights_history": [], + "final_summary": None, + "next_action": "continue_deep" + } + # ... invoke and save logic ... + final_state = self.app.invoke(initial_state) + return final_state \ No newline at end of file diff --git a/workflow_engine/toolkits/insight_tool/dm_components/workflows/orches_workflow.py b/workflow_engine/toolkits/insight_tool/dm_components/workflows/orches_workflow.py new file mode 100644 index 0000000..0dd965e --- /dev/null +++ b/workflow_engine/toolkits/insight_tool/dm_components/workflows/orches_workflow.py @@ -0,0 +1,876 @@ +# orchestrator_workflow.py +import os +import json +import re +import numpy as np +from typing import Dict, Any, List, TypedDict, Optional +from langgraph.graph import StateGraph, END + +from dm_components import prompts +from dm_components.config import logger +from dm_components.agents.base_agent import AgentBase +from dm_components.agents.datasource_agent import DataSourceAgent +from dm_components.utils import agent_utils as au + + +# ============================================================================= +# Hybrid Scoring Functions +# ============================================================================= + +def calculate_objective_score(agent: DataSourceAgent) -> float: + """ + Calculate objective data quality and richness score. + + Components: + - Data quality score (based on missing rate) + - Data richness score (based on columns, rows, unique values) + - Temporal dimension score (presence of time-related columns) + + Returns: + Score between 0-10 + """ + df = agent.data + + # 1. Data quality score (0-10) - lower missing rate = higher score + if len(df) == 0: + return 0.0 + + missing_rate = df.isnull().sum().sum() / (len(df) * len(df.columns)) if len(df.columns) > 0 else 1 + quality_score = 10 * (1 - min(missing_rate, 1)) + + # 2. Data richness score (0-10) + # - Column diversity: more columns = richer data + # - Row count: logarithmic scale to handle large datasets + # - Unique value ratio: higher ratio suggests more information + col_score = min(len(df.columns) * 0.5, 5) # Max 5 points for columns + row_score = min(np.log10(max(len(df), 1)) * 1.5, 3) # Max 3 points for rows + + # Average unique ratio across columns + unique_ratios = [] + for col in df.columns: + try: + unique_ratio = df[col].nunique() / len(df) if len(df) > 0 else 0 + unique_ratios.append(unique_ratio) + except: + continue + avg_unique_ratio = np.mean(unique_ratios) if unique_ratios else 0 + diversity_score = min(avg_unique_ratio * 5, 2) # Max 2 points + + richness_score = col_score + row_score + diversity_score + + # 3. Temporal dimension score (0-10) + temporal_keywords = ['date', 'time', 'timestamp', 'datetime', 'year', 'month', 'day', 'created', 'updated'] + has_temporal = any( + any(kw in col.lower() for kw in temporal_keywords) + for col in df.columns + ) + temporal_score = 10 if has_temporal else 0 + + # Combine scores (weighted average) + final_score = (quality_score * 0.4 + richness_score * 0.4 + temporal_score * 0.2) + + return round(min(final_score, 10), 2) + + +def calculate_semantic_relevance(schema_str: str, global_goal: str) -> float: + """ + Calculate semantic relevance based on keyword overlap. + + Uses simple keyword matching between schema and goal. + + Returns: + Score between 0-10 + """ + if not global_goal or not schema_str: + return 5.0 # Neutral score if no goal provided + + # Extract keywords from goal (simple tokenization) + goal_words = set( + word.lower().strip('.,!?;:()[]{}"\' ') + for word in global_goal.split() + if len(word) > 2 + ) + + # Extract words from schema + schema_words = set( + word.lower().strip('.,!?;:()[]{}"\' ') + for word in schema_str.split() + if len(word) > 2 + ) + + if not goal_words: + return 5.0 + + # Calculate overlap + overlap = len(goal_words & schema_words) + relevance_ratio = overlap / len(goal_words) + + # Scale to 0-10 + return round(min(relevance_ratio * 15, 10), 2) # Slightly generous scaling + + +def calculate_hybrid_score( + agent: DataSourceAgent, + global_goal: str, + llm_score: float, + weights: Dict[str, float] = None +) -> float: + """ + Calculate final hybrid score combining all metrics. + + Args: + agent: DataSourceAgent instance + global_goal: Analysis goal + llm_score: Score from LLM evaluation (0-10) + weights: Custom weights for each component + + Returns: + Final score between 0-10 + """ + if weights is None: + weights = { + "objective": 0.4, + "semantic": 0.3, + "llm": 0.3 + } + + objective_score = calculate_objective_score(agent) + semantic_score = calculate_semantic_relevance(agent.schema_str, global_goal) + + final_score = ( + weights["objective"] * objective_score + + weights["semantic"] * semantic_score + + weights["llm"] * llm_score + ) + + return round(final_score, 2) + + +class OrchestratorState(TypedDict): + """ + Main workflow state definition. + + Attributes: + data_agents: List of DataSourceAgent instances + initial_reports: Reports from independent analysis phase + annotated_reports: Reports with cross-agent annotations + numerical_crossover_ideas: Generated questions for cross-dataset analysis + numerical_crossover_results: Results from cross-dataset calculations + pred_insights: Final synthesized insights + pred_summary: Final executive summary + detailed_appendix: Detailed information for benchmark comparison (NEW) + orchestrator_agent: Orchestrator agent instance + global_goal: Overall analysis objective + data_registry: Mapping of agent names to file paths + background_knowledge_pool: List of background information from non-tabular sources (NEW) + output_mode: Output mode - "concise" or "detailed" (NEW) + """ + data_agents: List[DataSourceAgent] + initial_reports: List[Dict[str, Any]] + annotated_reports: List[Dict[str, Any]] + numerical_crossover_ideas: List[str] + numerical_crossover_results: List[Dict[str, Any]] + pred_insights: List[str] + pred_summary: str + detailed_appendix: Dict[str, Any] # NEW + raw_single_insights: List[Dict[str, Any]] # NEW: Raw insights from single-source analysis + raw_crossover_insights: List[Dict[str, Any]] # NEW: Raw insights from crossover analysis + orchestrator_agent: Any + global_goal: str + data_registry: Dict[str, str] + background_knowledge_pool: List[Dict[str, Any]] # NEW + output_mode: str # NEW: "concise" or "detailed" + + +class OrchestratorWorkflow: + """ + Main orchestrator workflow using LangGraph for multi-agent analysis. + """ + + def __init__(self, data_agents: List[DataSourceAgent], global_goal: str = ""): + """ + Initialize the orchestrator workflow. + + Args: + data_agents: List of DataSourceAgent instances + global_goal: Overall analysis objective + """ + self.data_agents = data_agents + self.global_goal = global_goal + self.orchestrator_agent = self._create_orchestrator_agent(data_agents) + self.data_registry = {agent.name: agent.original_file_path for agent in data_agents} + self.app = self._build_graph() + + def _create_orchestrator_agent(self, data_agents: List[DataSourceAgent]) -> Any: + """ + Create orchestrator agent with shared resources. + + Args: + data_agents: List of data agents for configuration reference + + Returns: + Orchestrator agent object + """ + if not data_agents: + return None + + orchestrator = type('OrchestratorAgent', (object,), {})() + + # Use the first agent's configuration for consistency + orchestrator.chat_model = au.get_chat_model( + data_agents[0].agent_config['model_name'], + data_agents[0].agent_config['temperature'], + api_key=data_agents[0].agent_config.get('api_key'), + base_url=data_agents[0].agent_config.get('base_url') + ) + + # Create shared AgentBase for cross-dataset analysis + orchestrator.crossover_poirot = AgentBase( + model_name=data_agents[0].agent_config['model_name'], + savedir=os.path.join(data_agents[0].agent_config['base_savedir'], "_crossover_agent"), + goal="Perform cross-dataset numerical analysis to answer specific questions.", + verbose=True, + temperature=data_agents[0].agent_config['temperature'], + n_retries=data_agents[0].agent_config['n_retries'], + ) + + return orchestrator + + def _build_graph(self): + """ + Build the LangGraph state machine. + + Returns: + Compiled LangGraph application + """ + workflow = StateGraph(OrchestratorState) + + # Define all processing nodes + workflow.add_node("initial_data_profile", self.initial_data_profile_node) + workflow.add_node("heuristic_exploration", self.heuristic_exploration_node) + workflow.add_node("formal_annotation", self.formal_annotation_node) + workflow.add_node("background_crossover", self.background_crossover_node) + workflow.add_node("numerical_crossover", self.numerical_crossover_node) + workflow.add_node("viewpoint_crossover", self.viewpoint_crossover_node) + + # Define execution flow + workflow.set_entry_point("initial_data_profile") + workflow.add_edge("initial_data_profile", "heuristic_exploration") + workflow.add_edge("heuristic_exploration", "formal_annotation") + workflow.add_edge("formal_annotation", "background_crossover") + workflow.add_edge("background_crossover", "numerical_crossover") + workflow.add_edge("numerical_crossover", "viewpoint_crossover") + workflow.add_edge("viewpoint_crossover", END) + + return workflow.compile() + + def initial_data_profile_node(self, state: OrchestratorState) -> Dict[str, Any]: + """ + Step 1A: Generate enhanced data profiles and preliminary evaluation using hybrid scoring. + + Uses a combination of: + - Objective metrics (data quality, richness, temporal dimensions) + - Semantic relevance (keyword matching with goal) + - LLM evaluation (subjective assessment) + + Args: + state: Current workflow state + + Returns: + Updated state with data profiles and hybrid scores + """ + logger.info("\n===== STEP 1A: Enhanced Data Profiling with Hybrid Scoring =====") + agents = state['data_agents'] + global_goal = state['global_goal'] + + for agent in agents: + # 1. Calculate objective score + objective_score = calculate_objective_score(agent) + logger.info(f"Agent {agent.name} - Objective Score: {objective_score}/10") + + # 2. Calculate semantic relevance score + semantic_score = calculate_semantic_relevance(agent.schema_str, global_goal) + logger.info(f"Agent {agent.name} - Semantic Relevance Score: {semantic_score}/10") + + # 3. Get LLM evaluation score + prompt = prompts.PRELIMINARY_EVAL_PROMPT.format( + global_goal=global_goal, + data_profile=agent.schema_str + ) + response = agent.chat_model(prompt) + response_content = response.content if hasattr(response, 'content') else str(response) + + tags = au.extract_html_tags(response_content, ["relevance", "reasoning", "priority"]) + + # Parse LLM relevance score (0-10) + try: + llm_score = float(tags.get("relevance", ["5"])[0]) + llm_score = min(max(llm_score, 0), 10) # Clamp to 0-10 + except (ValueError, IndexError): + llm_score = 5.0 + logger.info(f"Agent {agent.name} - LLM Score: {llm_score}/10") + + # 4. Calculate final hybrid score + hybrid_score = calculate_hybrid_score( + agent=agent, + global_goal=global_goal, + llm_score=llm_score + ) + + # 5. Determine priority based on hybrid score + if hybrid_score >= 7: + agent.preliminary_priority = "High" + elif hybrid_score >= 4: + agent.preliminary_priority = "Medium" + else: + agent.preliminary_priority = "Low" + + # Store scores for later use + agent.hybrid_score = hybrid_score + agent.objective_score = objective_score + agent.semantic_score = semantic_score + agent.llm_score = llm_score + + logger.info(f"Agent {agent.name} - FINAL Hybrid Score: {hybrid_score}/10 " + f"-> Priority: {agent.preliminary_priority}") + + return {"data_agents": agents} + + def heuristic_exploration_node(self, state: OrchestratorState) -> Dict[str, Any]: + """ + Step 1B: Each agent performs independent deep-dive analysis. + Args: + state: Current workflow state + Returns: + Updated state with initial reports + """ + logger.info("\n===== STEP 1B: Independent Deep-Dive Analysis =====") + agents = state['data_agents'] + initial_reports = [agent.analyze_self() for agent in agents] + return {"initial_reports": initial_reports} + + + def formal_annotation_node(self, state: OrchestratorState) -> Dict[str, Any]: + """ + Step 2: Formal importance labeling based on analysis quality. + + Args: + state: Current workflow state + + Returns: + Updated state with formal priority labels + """ + logger.info("===== STEP 2: Formal Importance Labeling =====") + + agents = state['data_agents'] + updated_reports = state['initial_reports'].copy() + + for i, agent in enumerate(agents): + summary = state['initial_reports'][i]['summary'] + + prompt = prompts.FORMAL_ANNOTATION_PROMPT.format( + global_goal=state['global_goal'], + schema=agent.agent_base.schema, + exploration_summary=summary + ) + + response = agent.chat_model(prompt) + response_content = response.content if hasattr(response, 'content') else str(response) + + tags = au.extract_html_tags(response_content, ["richness", "alignment", "label", "justification"]) + print(tags) + + # Map label to priority logic + label = tags.get("label", ["Secondary"])[0] + agent.final_priority = "High" if label == "Primary" else "Medium" + agent.importance_label = label + + justification = tags.get("justification", [""])[0][:100] # Truncate for logging + logger.info(f"Agent {agent.name} labeled as: {agent.final_priority} " + f"({label}) - Justification: {justification}...") + + # Update report with formal label + updated_reports[i]['formal_label'] = label + updated_reports[i]['formal_priority'] = agent.final_priority + + return { + "data_agents": agents, + "initial_reports": updated_reports + } + + + def background_crossover_node(self, state: OrchestratorState) -> Dict[str, Any]: + """ + Step 3: Background crossover with priority-weighted annotation. + + Args: + state: Current workflow state + + Returns: + Updated state with annotated reports and crossover questions + """ + logger.info("\n===== STEP 3: Background Crossover & Idea Generation (Priority-Weighted) =====") + + initial_reports = state['initial_reports'] + agents = state['data_agents'] + orchestrator = state['orchestrator_agent'] + + # 1. Stratify by priority + high_reports = [r for r in initial_reports if r.get('formal_priority') == "High"] + other_reports = [r for r in initial_reports if r.get('formal_priority') != "High"] + + context_for_ideation = "=== [CORE DATASETS - MUST ANALYZE] ===\n" + + # 2. Process high-priority datasets (full information + deep annotations) + for report in high_reports: + context_for_ideation += f"\n[PRIMARY] Agent: {report['agent_name']}\nSummary: {report['summary']}\n" + + # Collect annotations from other agents + annotations = [] + for annotator in agents: + if annotator.name != report['agent_name']: + annotation = annotator.annotate_other_agent_summary(report) + if annotation['comment']: + annotations.append(annotation) + context_for_ideation += f" - Critical Note by [{annotator.name}]: {annotation['comment']}\n" + + # Store annotations in report + report['annotations'] = annotations + + context_for_ideation += "\n=== [SUPPORTING DATASETS - AUXILIARY ONLY] ===\n" + + # 3. Process low-priority datasets (compressed information, background only) + for report in other_reports: + brief_summary = report['summary'][:300] + "..." if len(report['summary']) > 300 else report['summary'] + context_for_ideation += f"[SECONDARY] Agent: {report['agent_name']}\nBrief: {brief_summary}\n" + + # 4. Generate cross-dataset analytical questions + prompt = prompts.NUMERICAL_CROSSOVER_IDEATION_PROMPT.format( + global_goal=state['global_goal'], + context=context_for_ideation + ) + + response = orchestrator.chat_model(prompt) + response_content = response.content if hasattr(response, 'content') else str(response) + + ideas = au.extract_html_tags(response_content, ["question"]).get("question", []) + logger.info(f"Generated {len(ideas)} cross-dataset analytical questions after prioritizing High Priority datasets.") + + return { + "annotated_reports": initial_reports, + "numerical_crossover_ideas": ideas + } + + def numerical_crossover_node(self, state: OrchestratorState) -> Dict[str, Any]: + """ + Step 4: Execute numerical cross-dataset calculations with enhanced profile info. + + Now includes statistical profiles of each dataset to help code generation + make better decisions about data types, ranges, and join strategies. + + Args: + state: Current workflow state + + Returns: + Updated state with numerical crossover results + """ + logger.info("\n===== STEP 4: Numerical Cross-Dataset Calculation (with Profile Info) =====") + + ideas = state['numerical_crossover_ideas'] + if not ideas: + logger.info("No numerical crossover questions generated, skipping this step.") + return {"numerical_crossover_results": []} + + agents = state['data_agents'] + crossover_agent = state['orchestrator_agent'].crossover_poirot + results = [] + schemas = [] + paths = [] + profiles = [] # NEW: Collect profile information + + for agent in agents: + schemas.append(agent.agent_base.schema) + paths.append(agent.original_file_path) + # Collect profile information for each dataset + profiles.append(agent.profile if hasattr(agent, 'profile') else "{}") + + for question in ideas: + logger.info(f"\nProcessing numerical crossover question: {question}") + + crossover_agent.multi_schema = schemas + crossover_agent.multi_dataset_path = paths + crossover_agent.multi_profile = profiles # NEW: Pass profile info + + # Execute cross-dataset analysis + try: + _, insight_dict = crossover_agent.answer_question( + question, + prompt_code_method="multi_with_paths" + ) + results.append(insight_dict) + except Exception as e: + logger.error(f"Failed to process crossover question: {e}") + results.append({ + "question": question, + "answer": f"Analysis failed: {str(e)}", + "error": True + }) + + logger.info(f"Completed {len(results)} numerical crossover calculations.") + return {"numerical_crossover_results": results} + + def viewpoint_crossover_node(self, state: OrchestratorState) -> Dict[str, Any]: + """ + Step 5: Viewpoint crossover - synthesize all information into final report. + + Supports two output modes: + - concise: Brief insights with truncated details + - detailed: Full information with detailed_appendix for benchmark comparison + + Also supports background knowledge injection from non-tabular sources. + + Args: + state: Current workflow state + + Returns: + Updated state with final insights, summary, and optional detailed_appendix + """ + logger.info("\n===== STEP 5: Viewpoint Crossover & Final Synthesis =====") + + orchestrator = state['orchestrator_agent'] + annotated_reports = state['annotated_reports'] + numerical_results = state['numerical_crossover_results'] + output_mode = state.get('output_mode', 'concise') + background_pool = state.get('background_knowledge_pool', []) + + logger.info(f"Output mode: {output_mode}") + + # Build comprehensive context + full_context = "" + + for report in annotated_reports: + full_context += f"--- Analysis Report Source: {report['agent_name']} ---\n" + + # In concise mode, limit summary length + summary = report['summary'] + if output_mode == 'concise' and len(summary) > 500: + summary = summary[:500] + "... [truncated]" + full_context += f"Preliminary Summary: {summary}\n" + + if report.get('annotations'): + full_context += "Cross-Agent Annotations:\n" + for ann in report['annotations']: + comment = ann['comment'] + if output_mode == 'concise' and len(comment) > 200: + comment = comment[:200] + "..." + full_context += f" - [{ann['author_agent']}]: {comment}\n" + + if report.get('formal_label'): + full_context += f"Formal Classification: {report['formal_label']} ({report.get('formal_priority', 'Medium')})\n" + + full_context += "\n" + + if numerical_results: + full_context += "--- Cross-Domain Numerical Analysis ---\n" + for res in numerical_results: + full_context += f"Question: {res.get('question', 'N/A')}\n" + answer = res.get('answer', 'N/A') + if output_mode == 'concise' and isinstance(answer, str) and len(answer) > 300: + answer = answer[:300] + "..." + full_context += f"Findings: {answer}\n" + + if res.get('error'): + full_context += f"Status: ERROR\n" + + full_context += "\n" + + # Build background information string + background_info = "No additional background information available." + if background_pool: + bg_parts = [] + for bg in background_pool: + if isinstance(bg, dict): + source = bg.get('source', 'Unknown') + content = bg.get('content', '') + source_type = bg.get('source_type', 'unknown') + # Limit each background item + if output_mode == 'concise' and len(content) > 500: + content = content[:500] + "... [truncated]" + bg_parts.append(f"[{source_type}] {os.path.basename(source)}:\n{content}") + background_info = "\n\n".join(bg_parts) if bg_parts else background_info + + # Generate final synthesis using the appropriate prompt + if hasattr(prompts, 'FINAL_PROMPT_TEMPLATE_WITH_MODES'): + prompt = prompts.FINAL_PROMPT_TEMPLATE_WITH_MODES.format( + full_context=full_context, + background_info=background_info, + output_mode=output_mode + ) + else: + # Fallback to original prompt + prompt = prompts.FINAL_PROMPT_TEMPLATE.format(full_context=full_context) + + response = orchestrator.chat_model(prompt) + response_content = response.content if hasattr(response, 'content') else str(response) + + # Parse response + pred_insights = [] + pred_summary = "" + detailed_appendix = {} + + # Collect all raw insights from single-source analysis and crossover analysis + all_single_insights = [] + all_crossover_insights = [] + + # Extract single-source insights + for report in annotated_reports: + agent_name = report.get('agent_name', 'Unknown') + # Get key_metrics (insights) from each agent's report + agent_insights = report.get('key_metrics', []) + logger.debug(f"Extracting insights from {agent_name}: found {len(agent_insights)} items") + + for insight in agent_insights: + if isinstance(insight, str): + all_single_insights.append({ + "source": agent_name, + "insight": insight + }) + elif isinstance(insight, dict): + # insights_history format: {question, answer, insight, justification, ...} + # Prefer 'insight' field, fallback to 'answer', then to string representation + insight_text = insight.get('insight') or insight.get('answer') or str(insight) + question = insight.get('question', '') + + all_single_insights.append({ + "source": agent_name, + "insight": insight_text, + "question": question if question else None + }) + else: + # Fallback for any other type + all_single_insights.append({ + "source": agent_name, + "insight": str(insight) + }) + + logger.debug(f"Extracted {len([i for i in all_single_insights if i['source'] == agent_name])} insights from {agent_name}") + + # Extract crossover insights + for res in numerical_results: + question = res.get('question', 'N/A') + answer = res.get('answer', 'N/A') + if not res.get('error') and answer != 'N/A': + all_crossover_insights.append({ + "question": question, + "finding": answer + }) + + try: + # Extract JSON from response + json_match = re.search(r'\{.*\}', response_content, re.DOTALL) + if json_match: + json_str = json_match.group() + result = json.loads(json_str) + pred_insights = result.get("insights", []) + pred_summary = result.get("summary", "") + + # Extract detailed_appendix if in detailed mode + if output_mode == 'detailed': + detailed_appendix = result.get("detailed_appendix", {}) + # If LLM didn't provide, build it ourselves + if not detailed_appendix: + detailed_appendix = { + "full_reports": [ + { + "agent_name": r['agent_name'], + "summary": r['summary'], + "insights": r.get('key_metrics', []), + "annotations": r.get('annotations', []) + } + for r in annotated_reports + ], + "crossover_results": numerical_results, + "background_info": background_pool + } + else: + # Fallback to simple parsing + logger.warning("Could not parse JSON response, using fallback logic") + pred_insights = [ + line.strip().lstrip('- ') + for line in response_content.split('\n') + if line.strip() and line.strip().startswith('-') + ] + pred_summary = "Auto-generated summary from analysis." + + except json.JSONDecodeError as e: + logger.error(f"JSON parsing error: {e}") + pred_insights = [ + line.strip().lstrip('- ') + for line in response_content.split('\n') + if line.strip() and line.strip().startswith('-') + ] + pred_summary = "Auto-generated summary from analysis." + + # Log final results + logger.info("\n===== Final Synthesis Report =====") + logger.info(f"Output Mode: {output_mode}") + logger.info(f"Executive Summary: {pred_summary[:200]}...") + logger.info(f"Number of Synthesized Insights (categorized): {len(pred_insights)}") + logger.info(f"Number of Raw Single-Source Insights: {len(all_single_insights)}") + logger.info(f"Number of Raw Crossover Insights: {len(all_crossover_insights)}") + + for i, insight in enumerate(pred_insights[:5], 1): + logger.info(f"{i}. {insight[:100]}...") + + if len(pred_insights) > 5: + logger.info(f"... and {len(pred_insights) - 5} more insights") + + if detailed_appendix: + logger.info(f"Detailed appendix included with {len(detailed_appendix)} sections") + + # Log what we're returning + logger.info(f"\nReturning from viewpoint_crossover_node:") + logger.info(f" - raw_single_insights: {len(all_single_insights)} items") + logger.info(f" - raw_crossover_insights: {len(all_crossover_insights)} items") + if all_single_insights: + logger.info(f" - Sample single insight: {all_single_insights[0]}") + if all_crossover_insights: + logger.info(f" - Sample crossover insight: {all_crossover_insights[0]}") + + return { + "pred_insights": pred_insights, # Synthesized insights (categorized by Trend/Comparison/Extreme/Attribution) + "pred_summary": pred_summary, + "detailed_appendix": detailed_appendix, + "raw_single_insights": all_single_insights, # All single-source insights + "raw_crossover_insights": all_crossover_insights # All crossover insights + } + + def run( + self, + output_mode: str = "concise", + background_knowledge_pool: Optional[List[Dict[str, Any]]] = None + ): + """ + Execute the complete workflow. + + Args: + output_mode: "concise" for brief output, "detailed" for full output with appendix + background_knowledge_pool: List of background info dicts from non-tabular sources + + Returns: + Tuple of (insights list, summary string) or + Tuple of (insights list, summary string, detailed_appendix) if detailed mode + """ + # Initialize workflow state + initial_state = { + "data_agents": self.data_agents, + "data_registry": self.data_registry, + "orchestrator_agent": self.orchestrator_agent, + "global_goal": self.global_goal, + "initial_reports": [], + "annotated_reports": [], + "numerical_crossover_ideas": [], + "numerical_crossover_results": [], + "pred_insights": [], + "pred_summary": "", + "detailed_appendix": {}, + "raw_single_insights": [], # Initialize raw insights + "raw_crossover_insights": [], # Initialize raw crossover insights + "output_mode": output_mode, + "background_knowledge_pool": background_knowledge_pool or [] + } + + logger.info(f"Starting workflow execution (output_mode={output_mode})...") + + # Execute the LangGraph state machine + final_state = self.app.invoke(initial_state) + + logger.info("Workflow execution completed successfully.") + + # Debug: log all keys in final_state + logger.info(f"Final state keys: {list(final_state.keys())}") + + # Extract final results + pred_insights = final_state.get("pred_insights", []) + pred_summary = final_state.get("pred_summary", "") + detailed_appendix = final_state.get("detailed_appendix", {}) + raw_single_insights = final_state.get("raw_single_insights", []) + raw_crossover_insights = final_state.get("raw_crossover_insights", []) + + # Log extracted raw insights + logger.info(f"Extracted from final_state:") + logger.info(f" - raw_single_insights: {len(raw_single_insights)} items") + logger.info(f" - raw_crossover_insights: {len(raw_crossover_insights)} items") + + # Debug: check if keys exist but are empty + if "raw_single_insights" in final_state: + logger.info(f" - raw_single_insights key exists, value type: {type(final_state['raw_single_insights'])}, length: {len(final_state.get('raw_single_insights', []))}") + else: + logger.warning(" - raw_single_insights key NOT FOUND in final_state!") + + if "raw_crossover_insights" in final_state: + logger.info(f" - raw_crossover_insights key exists, value type: {type(final_state['raw_crossover_insights'])}, length: {len(final_state.get('raw_crossover_insights', []))}") + else: + logger.warning(" - raw_crossover_insights key NOT FOUND in final_state!") + + # Combine crossover insights into the same format as single insights + combined_crossover = [] + for item in raw_crossover_insights: + if isinstance(item, dict): + finding = item.get('finding', item.get('answer', 'N/A')) + question = item.get('question', 'N/A') + combined_crossover.append({ + "source": "crossover", + "insight": finding, + "question": question + }) + + # Combine all raw insights + all_raw_insights = raw_single_insights + combined_crossover + + logger.info(f"Combined raw insights: {len(all_raw_insights)} total (single: {len(raw_single_insights)}, crossover: {len(combined_crossover)})") + + # Return all insights types + return { + "synthesized_insights": pred_insights, # Categorized insights (Trend/Comparison/Extreme/Attribution) + "raw_insights": all_raw_insights, # Combined all raw insights + "summary": pred_summary, + "detailed_appendix": detailed_appendix if output_mode == "detailed" else {} + } + + def generate_markdown_report(self, final_state: Dict[str, Any], filename: str) -> str: + """ + Generate a Markdown report from the final workflow state. + + Args: + final_state: Complete workflow state + filename: Output file path + + Returns: + Path to generated report file + """ + logger.info(f"Generating Markdown report: {filename}") + + # Create simplified state for report generation + simplified_state = {} + + for key, value in final_state.items(): + if isinstance(value, (str, int, float, bool, type(None))): + simplified_state[key] = value + elif isinstance(value, (list, dict)): + try: + json.dumps(value) + simplified_state[key] = value + except (TypeError, ValueError): + simplified_state[key] = f"[Non-serializable {type(value).__name__} object]" + else: + simplified_state[key] = f"[{type(value).__name__} object]" + + state_str = json.dumps(simplified_state, indent=2, ensure_ascii=False) + + # Generate report using LLM + chat_model = au.get_chat_model("gpt-4o", 0) + prompt = prompts.REPORT_GENERATION_PROMPT.format(state_str=state_str) + report_content = chat_model(prompt).content + + # Save report \ No newline at end of file diff --git a/workflow_engine/toolkits/insight_tool/insight_wrapper.py b/workflow_engine/toolkits/insight_tool/insight_wrapper.py new file mode 100644 index 0000000..d0c387d --- /dev/null +++ b/workflow_engine/toolkits/insight_tool/insight_wrapper.py @@ -0,0 +1,97 @@ +""" +Thin wrapper around DM InsightEntry for Open-NotebookLM integration. +Handles API key passing and result formatting. +""" +from pathlib import Path +from typing import Dict, Any, Optional +import sys +import os + +# Add insight_tool directory to path so dm_components can be imported as a package +_insight_tool_path = os.path.dirname(os.path.abspath(__file__)) +if _insight_tool_path not in sys.path: + sys.path.insert(0, _insight_tool_path) + +from workflow_engine.logger import get_logger + +log = get_logger(__name__) + + +class InsightToolkit: + """Wrapper for DM insight discovery functionality.""" + + def __init__(self, + model_name: str = "deepseek-v3.2", + api_key: str = "", + base_url: str = "", + base_savedir: str = "./outputs/insights", + temperature: float = 0.1): + """ + Initialize with explicit API credentials. + + Args: + model_name: LLM model name + api_key: API key for LLM + base_url: Base URL for LLM API + base_savedir: Output directory + temperature: LLM temperature + """ + from dm_components.insight_entry import InsightEntry + + self.api_key = api_key + self.base_url = base_url + + # Initialize DM InsightEntry + self.insight_entry = InsightEntry( + model_name=model_name, + base_savedir=base_savedir, + temperature=temperature, + n_retries=1, + branch_depth=1, + max_questions=1, + default_output_mode="concise", + api_key=api_key, + base_url=base_url + ) + + log.info(f"InsightToolkit initialized: model={model_name}") + + def analyze_folder(self, + data_folder: str, + output_mode: str = "concise") -> Dict[str, Any]: + """ + Analyze all datasets in a folder. + + Args: + data_folder: Path to folder with data files + output_mode: "concise" or "detailed" + + Returns: + { + "synthesized_insights": List[str], + "raw_insights": List[str], + "summary": str, + "detailed_appendix": Dict + } + """ + log.info(f"Analyzing folder: {data_folder}") + + try: + result = self.insight_entry.analyze_folder( + data_folder=data_folder, + use_meta_goal=True, + output_mode=output_mode, + include_background=True + ) + + log.info(f"Analysis complete: {len(result.get('synthesized_insights', []))} insights") + return result + + except Exception as e: + log.error(f"Analysis failed: {e}", exc_info=True) + return { + "synthesized_insights": [], + "raw_insights": [], + "summary": f"Analysis failed: {str(e)}", + "detailed_appendix": {} + } diff --git a/workflow_engine/toolkits/ragtool/vector_store_tool.py b/workflow_engine/toolkits/ragtool/vector_store_tool.py index c1c84fc..0dc86c8 100644 --- a/workflow_engine/toolkits/ragtool/vector_store_tool.py +++ b/workflow_engine/toolkits/ragtool/vector_store_tool.py @@ -274,9 +274,10 @@ def _call_embedding_api(self, texts: List[str]) -> np.ndarray: vecs = [] # Batch processing to avoid payload limits - batch_size = 10 - - with httpx.Client(timeout=60.0) as client: + batch_size = 10 + + # Disable proxy to avoid issues with unavailable proxy servers + with httpx.Client(timeout=60.0, proxies=False) as client: for i in range(0, len(texts), batch_size): batch = texts[i:i+batch_size] # Replace newlines which can negatively affect performance diff --git a/workflow_engine/toolkits/research_tools.py b/workflow_engine/toolkits/research_tools.py index 39354b0..1468ee3 100644 --- a/workflow_engine/toolkits/research_tools.py +++ b/workflow_engine/toolkits/research_tools.py @@ -52,7 +52,7 @@ def fetch_page_text(url: str, max_chars: int = 50000) -> str: "User-Agent": "Mozilla/5.0 (compatible; OpenNotebook/1.0; +https://opennotebook.ai)" } try: - with httpx.Client(timeout=20, headers=headers, follow_redirects=True) as client: + with httpx.Client(timeout=20, headers=headers, follow_redirects=True, proxies=False) as client: resp = client.get(url.strip()) resp.raise_for_status() content_type = (resp.headers.get("content-type") or "").lower() @@ -88,7 +88,7 @@ def serpapi_search(query: str, api_key: str, engine: str = "google", num: int = "api_key": api_key, "num": num, } - with httpx.Client(timeout=20) as client: + with httpx.Client(timeout=20, proxies=False) as client: resp = client.get("https://serpapi.com/search.json", params=params) resp.raise_for_status() data = resp.json() @@ -119,7 +119,7 @@ def google_cse_search( "num": max(1, min(10, num)), "start": max(1, start), } - with httpx.Client(timeout=20) as client: + with httpx.Client(timeout=20, proxies=False) as client: resp = client.get("https://www.googleapis.com/customsearch/v1", params=params) resp.raise_for_status() data = resp.json() @@ -141,7 +141,7 @@ def brave_search(query: str, api_key: str, count: int = 10) -> List[Dict[str, An """Brave Search API.""" headers = {"X-Subscription-Token": api_key} params = {"q": query, "count": max(1, min(20, count))} - with httpx.Client(timeout=20, headers=headers) as client: + with httpx.Client(timeout=20, headers=headers, proxies=False) as client: resp = client.get("https://api.search.brave.com/res/v1/web/search", params=params) resp.raise_for_status() data = resp.json() @@ -181,7 +181,7 @@ def bocha_web_search( "Authorization": f"Bearer {api_key.strip()}", "Content-Type": "application/json", } - with httpx.Client(timeout=25) as client: + with httpx.Client(timeout=25, proxies=False) as client: resp = client.post(BOCHA_WEB_SEARCH_URL, json=payload, headers=headers) resp.raise_for_status() body = resp.json() diff --git a/workflow_engine/workflow/wf_data_insight.py b/workflow_engine/workflow/wf_data_insight.py new file mode 100644 index 0000000..9d1c18d --- /dev/null +++ b/workflow_engine/workflow/wf_data_insight.py @@ -0,0 +1,129 @@ +""" +Data Insight Discovery Workflow +Integrates DM insight framework for multi-dataset analysis. +""" +from __future__ import annotations +import os +import shutil +import time +from pathlib import Path +from typing import Dict, Any + +from workflow_engine.workflow.registry import register +from workflow_engine.graphbuilder.graph_builder import GenericGraphBuilder +from workflow_engine.logger import get_logger +from workflow_engine.state import DataInsightState +from workflow_engine.utils import get_project_root + +log = get_logger(__name__) + + +@register("data_insight") +def create_data_insight_graph() -> GenericGraphBuilder: + """ + Workflow for multi-dataset insight discovery using DM framework. + + Steps: + 1. Initialize paths and validate inputs + 2. Prepare data folder (copy uploaded files) + 3. Run DM insight analysis + """ + builder = GenericGraphBuilder(state_model=DataInsightState, entry_point="_start_") + + def _start_(state: DataInsightState) -> DataInsightState: + """Initialize paths and validate inputs.""" + if not state.request.file_ids: + state.request.file_ids = [] + + # Create output directory + if not state.result_path: + project_root = get_project_root() + ts = int(time.time()) + email = getattr(state.request, 'email', None) or 'default' + output_dir = project_root / "outputs" / "data_insights" / email / f"{ts}_insight" + output_dir.mkdir(parents=True, exist_ok=True) + state.result_path = str(output_dir) + log.info(f"Output directory: {state.result_path}") + + return state + + async def prepare_data_node(state: DataInsightState) -> DataInsightState: + """Copy uploaded files to analysis folder.""" + if not state.request.file_ids: + log.warning("No files provided for analysis") + return state + + # Create data folder + data_folder = Path(state.result_path) / "data" + data_folder.mkdir(exist_ok=True) + + # Copy files + for file_path in state.request.file_ids: + src = Path(file_path) + if src.exists(): + dst = data_folder / src.name + shutil.copy2(src, dst) + log.info(f"Copied {src.name} to analysis folder") + + # Create meta-info.json if custom goal provided + if state.request.analysis_goal: + import json + meta_path = Path(state.result_path) / "meta-info.json" + meta_path.write_text(json.dumps({ + "goal": state.request.analysis_goal + }, ensure_ascii=False, indent=2)) + + return state + + async def analyze_node(state: DataInsightState) -> DataInsightState: + """Run DM insight analysis.""" + from workflow_engine.toolkits.insight_tool.insight_wrapper import InsightToolkit + + data_folder = Path(state.result_path) / "data" + + try: + # Initialize toolkit with API credentials + toolkit = InsightToolkit( + model_name=state.request.model, + api_key=state.request.api_key, + base_url=state.request.chat_api_url, + base_savedir=state.result_path, + temperature=0.1 + ) + + # Run analysis + result = toolkit.analyze_folder( + data_folder=str(data_folder), + output_mode=state.request.output_mode + ) + + # Extract results + state.synthesized_insights = result.get("synthesized_insights", []) + state.raw_insights = result.get("raw_insights", []) + state.summary = result.get("summary", "") + state.detailed_appendix = result.get("detailed_appendix", {}) + + log.info(f"Analysis complete: {len(state.synthesized_insights)} insights") + + except Exception as e: + log.error(f"Insight analysis failed: {e}", exc_info=True) + state.summary = f"Analysis failed: {str(e)}" + + return state + + # Build graph + nodes = { + "_start_": _start_, + "prepare_data": prepare_data_node, + "analyze": analyze_node, + "_end_": lambda s: s + } + + edges = [ + ("_start_", "prepare_data"), + ("prepare_data", "analyze"), + ("analyze", "_end_") + ] + + builder.add_nodes(nodes).add_edges(edges) + return builder