diff --git a/.gitignore b/.gitignore index 1819ead..9c65ee4 100644 --- a/.gitignore +++ b/.gitignore @@ -2,7 +2,7 @@ data_climate_foresight.tar data/ *.ipynb __pycache__ -climsight.log +climsight.log climsight_evaluation.log cache/ evaluation/evaluation_report.txt @@ -12,4 +12,5 @@ rag_articles/ .* *.log venv311 -venv \ No newline at end of file +venv +tmp/ \ No newline at end of file diff --git a/README.md b/README.md index 734aaf4..4e58e7e 100644 --- a/README.md +++ b/README.md @@ -2,8 +2,7 @@ ClimSight is an advanced tool that integrates Large Language Models (LLMs) with climate data to provide localized climate insights for decision-making. ClimSight transforms complex climate data into actionable insights for agriculture, urban planning, disaster management, and policy development. -The target audience includes researchers, providers of climate services, policymakers, agricultural planners, urban developers, and other stakeholders who require detailed climate information to support decision-making. ClimSight is designed to democratize access to climate -data, empowering users with insights relevant to their specific contexts. +The target audience includes researchers, providers of climate services, policymakers, agricultural planners, urban developers, and other stakeholders who require detailed climate information to support decision-making. ClimSight is designed to democratize access to climate data, empowering users with insights relevant to their specific contexts. ![Image](https://github.com/user-attachments/assets/f9f89735-ef08-4c91-bc03-112c8e4c0896) @@ -15,61 +14,11 @@ ClimSight distinguishes itself through several key advancements: - **Real-World Applications**: ClimSight is validated through practical examples, such as assessing climate risks for specific agricultural activities and urban planning scenarios. -## Installation Options +## Installation -You can use ClimSight in three ways: -1. Run a pre-built Docker container (simplest approach) -2. 
Build and run a Docker container from source -3. Install the Python package (via pip or conda/mamba) +### Recommended: Building from source with conda/mamba -Using ClimSight requires an OpenAI API key unless using the `skipLLMCall` mode for testing. The API key is only needed when running the application, not during installation. - -## Batch Processing - -For batch processing of climate questions, the `sequential` directory contains specialized tools for generating, validating, and processing questions in bulk. These tools are particularly useful for research and analysis requiring multiple climate queries. See the [sequential/README.md](sequential/README.md) for detailed usage instructions. - -## 1. Running with Docker (Pre-built Container) - -The simplest way to get started is with our pre-built Docker container: - -```bash -# Make sure your OpenAI API key is set as an environment variable -export OPENAI_API_KEY="your-api-key-here" - -# Pull and run the container -docker pull koldunovn/climsight:stable -docker run -p 8501:8501 -e OPENAI_API_KEY=$OPENAI_API_KEY koldunovn/climsight:stable -``` - -Then open `http://localhost:8501/` in your browser. - -## 2. Building and Running from Source with Docker - -If you prefer to build from the latest source: - -```bash -# Clone the repository -git clone https://github.com/CliDyn/climsight.git -cd climsight - -# Download required data -python download_data.py - -# Build and run the container -docker build -t climsight . -docker run -p 8501:8501 -e OPENAI_API_KEY=$OPENAI_API_KEY climsight -``` - -Visit `http://localhost:8501/` in your browser once the container is running. - -For testing without OpenAI API calls: -```bash -docker run -p 8501:8501 -e STREAMLIT_ARGS="skipLLMCall" climsight -``` - -## 3. Python Package Installation - -### Option A: Building from source with conda/mamba +This is the recommended installation method to get the latest features and updates. 
```bash # Clone the repository @@ -82,13 +31,15 @@ conda activate climsight # Download required data python download_data.py + +# Optional: download DestinE data (large ~12 GB, not downloaded by default) +python download_data.py DestinE ``` -### Option B: Using pip +### Alternative: Using pip from source -It's recommended to create a virtual environment to avoid dependency conflicts: ```bash -# Option 1: Install from source +# Clone the repository git clone https://github.com/CliDyn/climsight.git cd climsight @@ -96,33 +47,37 @@ cd climsight python -m venv venv source venv/bin/activate # On Windows: venv\Scripts\activate -# Install ClimSight -pip install -e . +# Install dependencies +pip install -r requirements.txt + +# Download required data python download_data.py + +# Optional: download DestinE data (large ~12 GB, not downloaded by default) +python download_data.py DestinE ``` -Or if you prefer to set up without cloning the repository: +### Running with Docker (Stable Release v1.0.0) + +The Docker container provides a stable release (v1.0.0) of ClimSight. For the latest features, please install from source as described above. ```bash -# Option 2: Install from PyPI -# Create and activate a virtual environment -python -m venv climsight_env -source climsight_env/bin/activate # On Windows: climsight_env\Scripts\activate +# Make sure your OpenAI API key is set as an environment variable +export OPENAI_API_KEY="your-api-key-here" -# Install the package -pip install climsight +# Pull and run the container +docker pull koldunovn/climsight:stable +docker run -p 8501:8501 -e OPENAI_API_KEY=$OPENAI_API_KEY koldunovn/climsight:stable +``` -# Create a directory for data -mkdir -p climsight -cd climsight +Then open `http://localhost:8501/` in your browser. 
-# Download necessary configuration files -wget https://raw.githubusercontent.com/CliDyn/climsight/main/data_sources.yml -wget https://raw.githubusercontent.com/CliDyn/climsight/main/download_data.py -wget https://raw.githubusercontent.com/CliDyn/climsight/main/config.yml +### Using pip from PyPI (Stable Release v1.0.0) -# Download the required data (about 8 GB) -python download_data.py +The PyPI package provides a stable release (v1.0.0) of ClimSight. For the latest features, please install from source as described above. + +```bash +pip install climsight ``` ## Configuration @@ -131,50 +86,54 @@ ClimSight will automatically use a `config.yml` file from the current directory. ```yaml # Key settings you can modify in config.yml: -# - LLM model (gpt-4, ...) +# - LLM model (gpt-4, gpt-5, ...) # - Climate data sources # - RAG database configuration # - Agent parameters +# - ERA5 data retrieval settings ``` -## Running ClimSight -### If installed with conda/mamba from source: +## API Keys -```bash -# Run from the repository root -streamlit run src/climsight/climsight.py -``` +### OpenAI API Key -### If installed with pip: +ClimSight requires an OpenAI API key for LLM functionality. You can set it as an environment variable: ```bash -# Make sure you're in the directory with your data and config -climsight +export OPENAI_API_KEY="your-api-key-here" ``` -You can optionally set your OpenAI API key as an environment variable: +Alternatively, you can enter your API key directly in the browser interface when prompted. + +### Arraylake API Key (Optional - for ERA5 Data) + +If you want to use ERA5 time series data retrieval (enabled via the "Enable ERA5 data" toggle in the UI), you need an Arraylake API key from [Earthmover](https://earthmover.io/). This allows downloading ERA5 reanalysis data for detailed historical climate analysis. 
+ ```bash -export OPENAI_API_KEY="your-api-key-here" +export ARRAYLAKE_API_KEY="your-arraylake-api-key-here" ``` -Otherwise, you can enter your API key directly in the browser interface when prompted. +You can also enter the Arraylake API key in the browser interface when the ERA5 data option is enabled. -### Testing without an OpenAI API key: +## Running ClimSight ```bash -# From source: -streamlit run src/climsight/climsight.py skipLLMCall - -# Or if installed with pip: -climsight skipLLMCall +# Run from the repository root +streamlit run src/climsight/climsight.py ``` The application will open in your browser automatically. Just type your climate-related questions and press "Generate" to get insights. ClimSight Interface +## Batch Processing + +For batch processing of climate questions, the `sequential` directory contains specialized tools for generating, validating, and processing questions in bulk. These tools are particularly useful for research and analysis requiring multiple climate queries. See the [sequential/README.md](sequential/README.md) for detailed usage instructions. + ## Citation If you use or refer to ClimSight in your work, please cite: +Kuznetsov, I., Jost, A.A., Pantiukhin, D. et al. Transforming climate services with LLMs and multi-source data integration. _npj Clim. Action_ **4**, 97 (2025). https://doi.org/10.1038/s44168-025-00300-y + Koldunov, N., Jung, T. Local climate services for all, courtesy of large language models. _Commun Earth Environ_ **5**, 13 (2024). 
https://doi.org/10.1038/s43247-023-01199-1 diff --git a/config.yml b/config.yml index 4e0c2e6..1f067df 100644 --- a/config.yml +++ b/config.yml @@ -2,16 +2,27 @@ #model_type: "openai" #"openai / local / aitta llm_rag: model_type: "openai" - model_name: "gpt-4.1-nano" # used only for RAGs + model_name: "gpt-5-mini" # used only for RAGs llm_smart: #used only in smart_agent model_type: "openai" - model_name: "gpt-4.1-nano" # used only for smart agent + model_name: "gpt-5.2" # used only for smart agent llm_combine: #used only in combine_agent and intro model_type: "openai" - model_name: "gpt-4.1-nano" # used only for combine agent ("mkchaou/climsight-calm_ft_Q3_13k") + model_name: "gpt-5.2" # used only for combine agent ("mkchaou/climsight-calm_ft_Q3_13k") +llm_dataanalysis: #used only in data_analysis_agent + model_type: "openai" + model_name: "gpt-5.2" + use_filter_step: true # Set to false to skip context filtering LLM call climatemodel_name: "AWI_CM" llmModeKey: "agent_llm" #"agent_llm" #"direct_llm" use_smart_agent: false +use_era5_data: false # Download ERA5 time series from CDS API (requires credentials) +use_powerful_data_analysis: false + +# ERA5 Climatology Configuration (pre-computed observational baseline) +era5_climatology: + enabled: true # Always use ERA5 climatology as ground truth baseline + path: "data/era5/era5_climatology_2015_2025.zarr" # Path to pre-computed climatology # Climate Data Source Configuration # Options: "nextGEMS", "ICCP", "AWI_CM" @@ -126,6 +137,50 @@ climate_data_sources: longitude: "lon" time: "month" + DestinE: + enabled: true + coordinate_system: "unstructured" + description: "DestinE IFS-FESOM high-resolution climate simulations (SSP3-7.0)" + data_path: "./data/DestinE/" + # Time periods configuration + time_periods: + historical: + pattern: "ifs-fesom_baseline_hist_sfc_high_monthly_1990_2014_mean" + years_of_averaging: "1990-2014" + description: "DestinE IFS-FESOM historical baseline simulation" + is_main: true + source: 
"Destination Earth Climate DT, IFS-FESOM coupled model" + 2015_2019: + pattern: "ifs-fesom_projections_ssp3-7.0_sfc_high_monthly_2015_2019_mean" + years_of_averaging: "2015-2019" + description: "DestinE IFS-FESOM SSP3-7.0 near-term projection" + is_main: false + source: "Destination Earth Climate DT, IFS-FESOM coupled model, SSP3-7.0" + 2020_2029: + pattern: "ifs-fesom_projections_ssp3-7.0_sfc_high_monthly_2020_2029_mean" + years_of_averaging: "2020-2029" + description: "DestinE IFS-FESOM SSP3-7.0 mid-term projection" + is_main: false + source: "Destination Earth Climate DT, IFS-FESOM coupled model, SSP3-7.0" + 2040_2049: + pattern: "ifs-fesom_projections_ssp3-7.0_sfc_high_monthly_2040_2049_mean" + years_of_averaging: "2040-2049" + description: "DestinE IFS-FESOM SSP3-7.0 far-term projection" + is_main: false + source: "Destination Earth Climate DT, IFS-FESOM coupled model, SSP3-7.0" + # Variable mapping: display_name -> netcdf_variable + variable_mapping: + Temperature: avg_2t + Total Precipitation: avg_tprate + Wind U: avg_10u + Wind V: avg_10v + # Variable file suffixes (to construct full filenames) + variable_suffixes: + avg_2t: "_avg_2t.nc" + avg_tprate: "_avg_tprate.nc" + avg_10u: "_avg_10u.nc" + avg_10v: "_avg_10v.nc" + # Legacy settings (kept for backwards compatibility, will be migrated automatically) data_settings: data_path: "./data/" diff --git a/data_sources.yml b/data_sources.yml index 88b7f7b..45f1e08 100644 --- a/data_sources.yml +++ b/data_sources.yml @@ -120,4 +120,10 @@ sources: url: 'https://swift.dkrz.de/v1/dkrz_035d8f6ff058403bb42f8302e6badfbc/climsight/awi_cm.zip?temp_url_sig=f40cc2f349b24482a6f7247d173ca194fad28950&temp_url_expires=2299-10-02T09:52:13Z' archive_type: 'zip' subdir: './' + citation: + + - filename: 'DestinE.zip' + url: 'https://swift.dkrz.de/v1/dkrz_035d8f6ff058403bb42f8302e6badfbc/climsight/DestinE.zip?temp_url_sig=f60ad2be0bf65479f489611255c066148dc4741c&temp_url_expires=2053-06-19T11:20:40Z' + archive_type: 'zip' + subdir: 
'./' citation: \ No newline at end of file diff --git a/download_data.py b/download_data.py index 47c890f..9f6a81a 100644 --- a/download_data.py +++ b/download_data.py @@ -97,6 +97,11 @@ def main(): # Parse command-line argument (--source_files) parser = argparse.ArgumentParser(description="Download and extract the raw source files of the RAG.") parser.add_argument('--source_files', type=bool, default=False, help='Whether to download and extract source files (IPCC text reports).') + parser.add_argument( + 'datasets', + nargs='*', + help="Optional extra datasets to include (e.g. DestinE).", + ) #parser.add_argument('--CMIP_OIFS', type=bool, default=False, help='Whether to download CMIP6 low resolution AWI model data and ECE4/OIFS data.') args = parser.parse_args() @@ -112,6 +117,11 @@ def main(): sources = [d for d in sources if d['filename'] != 'ipcc_text_reports.zip'] #if not args.CMIP_OIFS: # sources = [d for d in sources if d['filename'] != 'data_climate_foresight.zip'] + + # Skip DestinE unless explicitly requested (large dataset). 
+ requested = {name.strip().lower() for name in args.datasets} + if 'destine' not in requested: + sources = [d for d in sources if d['filename'] != 'DestinE.zip'] #make subdirs list and clean it subdirs = [] @@ -136,6 +146,12 @@ def main(): url = entry['url'] subdir = os.path.join(base_path, entry['subdir']) + if not url: + files_skiped.append(file) + urls_skiped.append(url) + subdirs_skiped.append(subdir) + continue + if download_file(url, file): extract_arch(file, subdir) files_downloaded.append(file) diff --git a/environment.yml b/environment.yml index b8d272c..63df334 100644 --- a/environment.yml +++ b/environment.yml @@ -4,36 +4,72 @@ channels: - defaults dependencies: - python=3.11 + + # Web Framework - streamlit + - streamlit-folium + - folium + + # Climate & Geospatial Data Processing - xarray - - geopy + - netcdf4 - geopandas + - shapely - pyproj - - requests - - requests-mock + - geopy + - osmnx + + # Data Science & Numerical Computing - pandas - - folium - - openai - - langchain - - streamlit-folium - - netcdf4 + - numpy + - scipy - dask - - pip - - langchain-community + - matplotlib + + # LangChain & LLM Framework (conda-forge packages) + - langchain + - langchain-community - langchain-openai - langchain-chroma - langchain-core - - osmnx - - matplotlib - - pydantic + - langchain-text-splitters + - langchain-experimental - langgraph - - chardet + - openai + - pydantic + + # Web Scraping & Data Retrieval - bs4 - wikipedia - - scipy - - pyproj + - requests + - requests-mock + + # ERA5 & Zarr tooling + - zarr + - gcsfs + - numcodecs + - blosc + + # Jupyter-backed Python REPL + - jupyter_client + - ipykernel + + # PDF Generation - reportlab - # pip-only packages must go under this section: + + # Configuration & Utilities + - pyyaml + - python-dotenv + + # Testing + - pytest + + # Package Management + - pip + + # pip-only packages (not available on conda-forge) - pip: - - aitta-client - - langchain_classic + - langchain-classic + - langchain-anthropic + - 
arraylake + - aitta-client diff --git a/pyproject.toml b/pyproject.toml index d17c21d..16f232c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,39 +11,72 @@ authors = [ {name = "boryasbora"} ] license = {text = "BSD-3-Clause"} -keywords = ["climate", "llm", "climate-assessment", "rag", "decision-support"] +keywords = ["climate", "llm", "climate-assessment", "rag", "decision-support"] +requires-python = ">=3.11" dependencies = [ + # Web Framework "streamlit", + "streamlit-folium", + "folium", + + # Climate & Geospatial Data Processing "xarray", - "geopy", + "netcdf4", "geopandas", + "shapely", "pyproj", - "requests", - "requests-mock", + "geopy", + "osmnx", + + # Data Science & Numerical Computing "pandas", - "folium", - "langchain", - "streamlit-folium", - "netcdf4", + "numpy", + "scipy", "dask", - "pip", - "osmnx", "matplotlib", - "openai", + + # LangChain & LLM Framework + "langchain", + "langchain-classic", "langchain-community", "langchain-openai", "langchain-chroma", "langchain-core", - "pydantic", + "langchain-text-splitters", + "langchain-experimental", + "langchain-anthropic", "langgraph", + "openai", + "pydantic", + + # Web Scraping & Data Retrieval "bs4", "wikipedia", - "scipy", - "pyproj", + "requests", + "requests-mock", + + # PDF Generation "reportlab", - "aitta-client" + + # ERA5 & Zarr tooling + "zarr", + "arraylake", + "gcsfs", + "numcodecs", + "blosc", + + # Jupyter-backed Python REPL + "jupyter_client", + "ipykernel", + + # Configuration & Utilities + "pyyaml", + "python-dotenv", ] +[project.optional-dependencies] +aitta = ["aitta-client"] +dev = ["pytest", "flake8"] [build-system] requires = ["setuptools"] @@ -54,8 +87,6 @@ package-dir = {"" = "src"} [tool.setuptools.packages.find] where = ["src"] -#find = {} # Scan the project directory with the default parameters - [project.scripts] climsight = "climsight.launch:launch_streamlit" diff --git a/requirements.txt b/requirements.txt index 2bd51be..b874d85 100644 --- a/requirements.txt +++ 
b/requirements.txt @@ -1,19 +1,19 @@ # ClimSight Dependencies -# Generated from pyproject.toml # Install with: pip install -r requirements.txt # Web Framework streamlit streamlit-folium +folium # Climate & Geospatial Data Processing xarray netcdf4 geopandas +shapely pyproj geopy osmnx -folium # Data Science & Numerical Computing pandas @@ -32,6 +32,8 @@ langchain-openai langchain-chroma langchain-core langchain-text-splitters +langchain-experimental +langchain-anthropic langgraph openai pydantic @@ -45,13 +47,24 @@ requests-mock # PDF Generation reportlab -# Optional: CSC's AITTA Platform (local/custom AI models) -# Uncomment if you need CSC AITTA integration -# aitta-client +# ERA5 + Zarr tooling +zarr +arraylake +gcsfs +numcodecs +blosc + +# Jupyter-backed Python REPL +jupyter_client +ipykernel + +# Configuration & Utilities +pyyaml +python-dotenv # Development & Testing pytest -pyyaml -# Package Management -pip +# Optional: CSC's AITTA Platform (local/custom AI models) +# Uncomment if you need CSC AITTA integration +# aitta-client diff --git a/src/climsight/agent_helpers.py b/src/climsight/agent_helpers.py new file mode 100644 index 0000000..4499aeb --- /dev/null +++ b/src/climsight/agent_helpers.py @@ -0,0 +1,121 @@ +"""Helper utilities for tool-based agents (PangaeaGPT parity).""" + +import logging +import os +from typing import Any, Dict, List, Tuple + +try: + import streamlit as st +except ImportError: + st = None + +try: + from langchain.agents import AgentExecutor, create_openai_tools_agent +except ImportError: + from langchain_classic.agents import AgentExecutor, create_openai_tools_agent + +from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder + +logger = logging.getLogger(__name__) + + +def prepare_visualization_environment(datasets_info: List[Dict[str, Any]]) -> Tuple[Dict[str, Any], str, List[str]]: + """Prepare dataset variables and prompt text for tool agents.""" + datasets_text = "" + dataset_variables: List[str] = [] + 
datasets: Dict[str, Any] = {} + + uuid_main_dir = None + for info in datasets_info: + sandbox_path = info.get("sandbox_path") + if sandbox_path and isinstance(sandbox_path, str) and os.path.isdir(sandbox_path): + uuid_main_dir = os.path.dirname(os.path.abspath(sandbox_path)) + logger.info("Found main UUID directory from sandbox_path: %s", uuid_main_dir) + break + + datasets["uuid_main_dir"] = uuid_main_dir + + results_dir = None + uuid_dir_files: List[str] = [] + if uuid_main_dir and os.path.exists(uuid_main_dir): + results_dir = os.path.join(uuid_main_dir, "results") + os.makedirs(results_dir, exist_ok=True) + datasets["results_dir"] = results_dir + try: + uuid_dir_files = os.listdir(uuid_main_dir) + except Exception as exc: + logger.error("Error listing UUID directory files: %s", exc) + + # Path instructions for the prompt + uuid_paths = "WARNING: EXACT DATASET PATHS - USE THESE EXACTLY AS SHOWN\n" + uuid_paths += "The following paths contain unique IDs that MUST be used with os.path.join().\n\n" + + if uuid_main_dir: + uuid_paths += "# MAIN OUTPUT DIRECTORY\n" + uuid_paths += f"uuid_main_dir = r'{uuid_main_dir}'\n" + uuid_paths += f"results_dir = r'{results_dir}' # Save all plots here\n\n" + uuid_paths += f"# Files in main directory: {', '.join(uuid_dir_files) if uuid_dir_files else 'None'}\n\n" + + for i, info in enumerate(datasets_info): + var_name = f"dataset_{i + 1}" + datasets[var_name] = info.get("dataset") + dataset_variables.append(var_name) + + sandbox_path = info.get("sandbox_path") + if sandbox_path and isinstance(sandbox_path, str) and os.path.isdir(sandbox_path): + full_uuid_path = os.path.abspath(sandbox_path).replace("\\", "/") + uuid_paths += f"# Dataset {i + 1}: {info.get('name', 'unknown')}\n" + uuid_paths += f"{var_name}_path = r'{full_uuid_path}'\n\n" + if os.path.exists(full_uuid_path): + try: + files = os.listdir(full_uuid_path) + uuid_paths += f"# Files available in {var_name}_path: {', '.join(files)}\n\n" + except Exception as exc: + 
uuid_paths += f"# Error listing files: {exc}\n\n" + + uuid_paths += "# WARNINGS\n" + uuid_paths += "# 1. Never use placeholder paths.\n" + uuid_paths += "# 2. Always use the dataset_X_path variables shown above.\n" + uuid_paths += "# 3. Check which files exist before reading.\n\n" + + datasets_summary = "" + for i, info in enumerate(datasets_info): + datasets_summary += ( + f"Dataset {i + 1}:\n" + f"Name: {info.get('name', 'Unknown')}\n" + f"Description: {info.get('description', 'No description available')}\n" + f"Type: {info.get('data_type', 'Unknown type')}\n" + f"Sample Data: {info.get('df_head', 'No sample available')}\n\n" + ) + + datasets_text = uuid_paths + datasets_summary + + if st is not None and hasattr(st, "session_state"): + st.session_state["viz_datasets_text"] = datasets_text + + return datasets, datasets_text, dataset_variables + + +def create_standard_agent_executor(llm, tools, prompt_template, max_iterations: int = 25) -> AgentExecutor: + """Create an OpenAI tools agent executor with standard wiring.""" + agent = create_openai_tools_agent( + llm, + tools=tools, + prompt=ChatPromptTemplate.from_messages( + [ + ("system", prompt_template), + ("user", "{input}"), + MessagesPlaceholder(variable_name="messages"), + MessagesPlaceholder(variable_name="agent_scratchpad"), + ] + ), + ) + + return AgentExecutor( + agent=agent, + tools=tools, + verbose=True, + handle_parsing_errors=True, + max_iterations=max_iterations, + return_intermediate_steps=True, + ) diff --git a/src/climsight/climate_data_providers.py b/src/climsight/climate_data_providers.py index 07275e4..a2c15a0 100644 --- a/src/climsight/climate_data_providers.py +++ b/src/climsight/climate_data_providers.py @@ -296,8 +296,13 @@ def _extract_data_healpix( """ dataset = xr.open_dataset(nc_file) + # Normalize longitude to 0-360 range (HEALPix data uses 0-360) + query_lon = desired_lon + if desired_lon < 0: + query_lon = desired_lon + 360 + # Query 4 nearest neighbors - distances, indices = 
tree.query([desired_lon, desired_lat], k=4) + distances, indices = tree.query([query_lon, desired_lat], k=4) neighbors_lons = lons[indices] neighbors_lats = lats[indices] @@ -525,8 +530,13 @@ def _extract_data_regular_grid( # Open NetCDF file ds = xr.open_dataset(file_path) + # Normalize longitude to 0-360 range if dataset uses that convention + query_lon = lon + if lon < 0 and ds.lon.min() >= 0: + query_lon = lon + 360 + # Use bilinear interpolation to get data at the exact point - ds_interpolated = ds.interp(lat=lat, lon=lon, method='linear') + ds_interpolated = ds.interp(lat=lat, lon=query_lon, method='linear') # Extract variables for all 12 months month_names = [calendar.month_name[i] for i in range(1, 13)] @@ -820,6 +830,285 @@ def extract_data( ) +class DestinEProvider(ClimateDataProvider): + """Provider for DestinE IFS-FESOM high-resolution climate data. + + This provider handles unstructured grid data (similar to HEALPix) using + cKDTree for efficient spatial lookups. Unlike NextGEMS, DestinE stores + each variable in separate files that must be combined. 
+ + Data characteristics: + - Coordinate system: Unstructured grid (12.5M points) + - Variables: avg_2t (temp), avg_tprate (precip), avg_10u/avg_10v (wind) + - Time periods: 1990-2014 (hist), 2015-2019, 2020-2029, 2040-2049 (SSP3-7.0) + - Units: Temperature in K, precipitation in kg m⁻² s⁻¹, wind in m/s + """ + + # Days in each month for precipitation conversion (using 28.25 for Feb to account for leap years) + DAYS_IN_MONTH = [31, 28.25, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31] + SECONDS_PER_DAY = 86400 + + def __init__(self, source_config: dict, global_config: dict = None): + super().__init__(source_config, global_config) + self._spatial_indices = {} # Cache for cKDTree per file + self._data_path = source_config.get('data_path', './data/DestinE/') + + @property + def name(self) -> str: + return "DestinE" + + @property + def coordinate_system(self) -> str: + return "unstructured" + + def is_available(self) -> bool: + """Check if at least one complete time period exists.""" + time_periods = self.source_config.get('time_periods', {}) + var_suffixes = self.source_config.get('variable_suffixes', {}) + + if not time_periods or not var_suffixes: + return False + + for period_key, period_meta in time_periods.items(): + pattern = period_meta.get('pattern', '') + if not pattern: + continue + + # Check if all variable files exist for this period + all_exist = True + for var_name, suffix in var_suffixes.items(): + file_path = os.path.join(self._data_path, f"{pattern}{suffix}") + if not os.path.exists(file_path): + all_exist = False + break + if all_exist: + return True + return False + + def _build_spatial_index(self, nc_file: str) -> Tuple[cKDTree, np.ndarray, np.ndarray]: + """Build cKDTree spatial index from a NetCDF file.""" + if nc_file in self._spatial_indices: + return self._spatial_indices[nc_file] + + ds = xr.open_dataset(nc_file) + lons = ds['longitude'].values + lats = ds['latitude'].values + ds.close() + + points = np.column_stack((lons, lats)) + tree = 
cKDTree(points) + + self._spatial_indices[nc_file] = (tree, lons, lats) + return tree, lons, lats + + def _extract_point_data( + self, + ds: xr.Dataset, + var_name: str, + indices: np.ndarray, + weights: np.ndarray, + use_exact: bool, + exact_idx: int = 0 + ) -> np.ndarray: + """Extract interpolated values for all months at a point.""" + interpolated = [] + for month_idx in range(12): + data_values = ds[var_name][month_idx, indices].values + if use_exact: + value = data_values[exact_idx] + else: + value = np.dot(weights, data_values) + interpolated.append(value) + return np.array(interpolated) + + def _convert_precip_rate(self, values: np.ndarray) -> np.ndarray: + """Convert precipitation from kg m⁻² s⁻¹ to mm/month. + + 1 kg/m² water = 1 mm depth (by definition of water density) + So: kg m⁻² s⁻¹ = mm/s + mm/month = mm/s × days_in_month × 86400 s/day + """ + converted = np.array([ + values[i] * self.DAYS_IN_MONTH[i] * self.SECONDS_PER_DAY + for i in range(len(values)) + ]) + return converted + + def _post_process_data( + self, + df: pd.DataFrame, + df_vars: Dict + ) -> Tuple[pd.DataFrame, Dict]: + """Apply unit conversions and calculate wind speed/direction.""" + df_processed = df.copy() + df_vars_processed = df_vars.copy() + + # Temperature: K to °C + if 'avg_2t' in df_processed.columns: + df_processed['avg_2t'] = df_processed['avg_2t'] - 273.15 + df_vars_processed['avg_2t']['units'] = '°C' + + # Precipitation: kg m⁻² s⁻¹ to mm/month + if 'avg_tprate' in df_processed.columns: + df_processed['avg_tprate'] = self._convert_precip_rate( + df_processed['avg_tprate'].values + ) + df_vars_processed['avg_tprate']['units'] = 'mm/month' + + # Calculate wind speed and direction + if 'avg_10u' in df_processed.columns and 'avg_10v' in df_processed.columns: + wind_speed = np.sqrt( + df_processed['avg_10u']**2 + df_processed['avg_10v']**2 + ) + wind_direction = (180.0 + np.degrees( + np.arctan2(df_processed['avg_10u'], df_processed['avg_10v']) + )) % 360 + + 
df_processed['wind_speed'] = wind_speed.round(2) + df_processed['wind_direction'] = wind_direction.round(2) + + df_vars_processed['wind_speed'] = { + 'name': 'wind_speed', 'units': 'm/s', + 'full_name': 'Wind Speed', 'long_name': 'Wind Speed' + } + df_vars_processed['wind_direction'] = { + 'name': 'wind_direction', 'units': '°', + 'full_name': 'Wind Direction', 'long_name': 'Wind Direction' + } + + # Round numeric columns + for var in df_processed.columns: + if var != 'Month' and pd.api.types.is_numeric_dtype(df_processed[var]): + df_processed[var] = df_processed[var].round(2) + + return df_processed, df_vars_processed + + def extract_data( + self, + lon: float, + lat: float, + months: Optional[List[int]] = None + ) -> ClimateDataResult: + """Extract climate data for a location from DestinE data.""" + time_periods = self.source_config.get('time_periods', {}) + var_mapping = self.source_config.get('variable_mapping', {}) + var_suffixes = self.source_config.get('variable_suffixes', {}) + + if months is None: + months = list(range(1, 13)) + + df_list = [] + tree = None + lons = None + lats = None + indices = None + weights = None + use_exact = False + exact_idx = 0 + + # Process each time period + for period_key, period_meta in time_periods.items(): + pattern = period_meta.get('pattern', '') + if not pattern: + continue + + # Build paths for all variable files + var_files = {} + all_exist = True + for nc_var, suffix in var_suffixes.items(): + file_path = os.path.join(self._data_path, f"{pattern}{suffix}") + if os.path.exists(file_path): + var_files[nc_var] = file_path + else: + all_exist = False + logger.debug(f"Missing DestinE file: {file_path}") + break + + if not all_exist: + continue + + # Build spatial index from first variable file (coordinates are same in all files) + first_file = list(var_files.values())[0] + if tree is None: + tree, lons, lats = self._build_spatial_index(first_file) + + # Normalize longitude to 0-360 if data uses that convention + query_lon = 
lon + if lon < 0 and lons.min() >= 0: + query_lon = lon + 360 + + # Query 4 nearest neighbors + distances, indices = tree.query([query_lon, lat], k=4) + + # Project and compute inverse distance weights + pste = pyproj.Proj( + proj="stere", errcheck=True, ellps='WGS84', + lat_0=lat, lon_0=lon + ) + neighbors_lons = lons[indices] + neighbors_lats = lats[indices] + neighbors_x, neighbors_y = pste(neighbors_lons, neighbors_lats) + desired_x, desired_y = pste(lon, lat) + + dx = neighbors_x - desired_x + dy = neighbors_y - desired_y + distances_proj = np.hypot(dx, dy) + + if np.any(distances_proj == 0): + exact_idx = np.where(distances_proj == 0)[0][0] + use_exact = True + else: + inv_distances = 1.0 / distances_proj + weights = inv_distances / inv_distances.sum() + + # Extract data from all variable files + df_data = {'Month': [calendar.month_name[m] for m in months]} + df_vars = {} + + for display_name, nc_var in var_mapping.items(): + if nc_var not in var_files: + continue + + ds = xr.open_dataset(var_files[nc_var]) + + values = self._extract_point_data( + ds, nc_var, indices, weights, use_exact, exact_idx + ) + + df_data[nc_var] = values + df_vars[nc_var] = { + 'name': nc_var, + 'units': ds[nc_var].attrs.get('units', ''), + 'full_name': display_name, + 'long_name': ds[nc_var].attrs.get('long_name', '') + } + + ds.close() + + df = pd.DataFrame(df_data) + df_processed, df_vars_processed = self._post_process_data(df, df_vars) + + df_list.append({ + 'filename': pattern, + 'years_of_averaging': period_meta.get('years_of_averaging', ''), + 'description': period_meta.get('description', ''), + 'dataframe': df_processed, + 'extracted_vars': df_vars_processed, + 'main': period_meta.get('is_main', False), + 'source': period_meta.get('source', 'DestinE IFS-FESOM') + }) + + # Prepare data_agent_response + data_agent_response = self._prepare_data_agent_response(df_list) + + return ClimateDataResult( + df_list=df_list, + data_agent_response=data_agent_response, + 
source_name=self.name, + source_description=self.description + ) + + # Factory functions def get_climate_data_provider( @@ -847,6 +1136,8 @@ def get_climate_data_provider( return ICCPProvider(sources_config.get('ICCP', {}), config) elif source == 'AWI_CM': return AWICMProvider(sources_config.get('AWI_CM', {}), config) + elif source == 'DestinE': + return DestinEProvider(sources_config.get('DestinE', {}), config) else: raise ValueError(f"Unknown climate data source: {source}") @@ -864,7 +1155,7 @@ def get_available_providers(config: dict) -> List[str]: List of provider names that are available """ available = [] - for source in ['nextGEMS', 'ICCP', 'AWI_CM']: + for source in ['nextGEMS', 'ICCP', 'AWI_CM', 'DestinE']: try: provider = get_climate_data_provider(config, source) if provider.is_available(): diff --git a/src/climsight/climate_functions.py b/src/climsight/climate_functions.py index bf8734d..fecedaa 100644 --- a/src/climsight/climate_functions.py +++ b/src/climsight/climate_functions.py @@ -63,10 +63,16 @@ def load_data(config): def select_data(dataset, variable, dimensions, lat, lon): """ Selects data for a given variable at specified latitude and longitude. + Handles longitude normalization for datasets using 0-360 range. 
""" + # Normalize longitude if dataset uses 0-360 range and input is negative + lon_dim = dimensions['longitude'] + if lon < 0 and dataset[lon_dim].min() >= 0: + lon = lon + 360 + return dataset[variable].sel(**{ dimensions['latitude']: lat, - dimensions['longitude']: lon}, + lon_dim: lon}, method="nearest") def verify_shape(hist_units, future_units, variable): diff --git a/src/climsight/climsight_classes.py b/src/climsight/climsight_classes.py index 0f3c07b..c5cb4d3 100644 --- a/src/climsight/climsight_classes.py +++ b/src/climsight/climsight_classes.py @@ -17,12 +17,22 @@ class AgentState(BaseModel): content_message: str = "" input_params: dict = {} smart_agent_response: dict = {} - wikipedia_tool_response: str = "" + wikipedia_tool_response: list = [] ecocrop_search_response: str = "" - rag_search_response: str = "" + rag_search_response: list = [] ipcc_rag_agent_response: str = "" - general_rag_agent_response: str = "" + general_rag_agent_response: str = "" + data_analysis_response: str = "" # Response from data analysis agent + data_analysis_prompt_text: str = "" # Filtered analysis brief for tools + data_analysis_images: list = [] # Paths to generated analysis images df_list: list = [] # List of dataframes with climate data references: list = [] # List of references combine_agent_prompt_text: str = "" + thread_id: str = "" # Session ID for sandbox storage + uuid_main_dir: str = "" # Root sandbox path + results_dir: str = "" # Plot output directory + climate_data_dir: str = "" # Saved climatology directory + era5_data_dir: str = "" # ERA5 output directory + era5_climatology_response: dict = {} # ERA5 observed climatology (ground truth) + era5_tool_response: str = "" # stream_handler: StreamHandler # Uncomment if needed diff --git a/src/climsight/climsight_engine.py b/src/climsight/climsight_engine.py index 8fa71fd..006c0b4 100644 --- a/src/climsight/climsight_engine.py +++ b/src/climsight/climsight_engine.py @@ -61,8 +61,17 @@ # import climsight classes from 
climsight_classes import AgentState +# sandbox helpers +from sandbox_utils import ( + ensure_thread_id, + ensure_sandbox_dirs, + get_sandbox_paths, + write_climate_data_manifest, +) # import smart_agent from smart_agent import get_aitta_chat_model, smart_agent +# import data_analysis_agent +from data_analysis_agent import data_analysis_agent # import climsight functions from geo_functions import ( @@ -631,6 +640,13 @@ def agent_llm_request(content_message, input_params, config, api_key, api_key_lo if not isinstance(stream_handler, StreamHandler): logging.error(f"stream_handler must be an instance of StreamHandler") raise TypeError("stream_handler must be an instance of StreamHandler") + + # Ensure sandbox paths for this session (Streamlit or CLI). + thread_id = ensure_thread_id(existing_thread_id=input_params.get("thread_id", "")) + sandbox_paths = get_sandbox_paths(thread_id) + ensure_sandbox_dirs(sandbox_paths) + input_params.update(sandbox_paths) + input_params["thread_id"] = thread_id lat = float(input_params['lat']) # should be already present in input_params lon = float(input_params['lon']) # should be already present in input_params @@ -669,6 +685,31 @@ def agent_llm_request(content_message, input_params, config, api_key, api_key_lo max_completion_tokens=4096 ) llm_intro = llm_combine_agent + + # Data analysis LLM (separate from combine step). 
+ llm_dataanalysis_cfg = config.get("llm_dataanalysis") + if not llm_dataanalysis_cfg: + raise RuntimeError("Missing llm_dataanalysis configuration.") + if llm_dataanalysis_cfg.get("model_type") == "local": + llm_dataanalysis_agent = ChatOpenAI( + openai_api_base="http://localhost:8000/v1", + model_name=llm_dataanalysis_cfg.get("model_name"), + openai_api_key=api_key_local, + max_tokens=16000, + ) + elif llm_dataanalysis_cfg.get("model_type") == "openai": + llm_dataanalysis_agent = ChatOpenAI( + openai_api_key=api_key, + model_name=llm_dataanalysis_cfg.get("model_name"), + max_tokens=16000, + ) + elif llm_dataanalysis_cfg.get("model_type") == "aitta": + llm_dataanalysis_agent = get_aitta_chat_model( + llm_dataanalysis_cfg.get("model_name"), + max_completion_tokens=4096 + ) + else: + llm_dataanalysis_agent = llm_combine_agent def zero_rag_agent(state: AgentState, figs = {}): logger.debug(f"get_elevation_from_api from: {lat}, {lon}") @@ -841,6 +882,20 @@ def data_agent(state: AgentState, data={}, df={}): data['high_res_climate'] = data['climate_data'] state.df_list = df_list + # Persist climatology into the sandbox for cross-agent access. + if state.thread_id: + sandbox_paths = get_sandbox_paths(state.thread_id) + ensure_sandbox_dirs(sandbox_paths) + manifest_path, _ = write_climate_data_manifest( + df_list, sandbox_paths["climate_data_dir"], climate_source + ) + state.uuid_main_dir = sandbox_paths["uuid_main_dir"] + state.results_dir = sandbox_paths["results_dir"] + state.climate_data_dir = sandbox_paths["climate_data_dir"] + state.era5_data_dir = sandbox_paths["era5_data_dir"] + state.input_params.update(sandbox_paths) + state.input_params["climate_data_manifest"] = manifest_path + # Add appropriate references based on data source ref_key_map = { 'nextGEMS': 'high_resolution_climate_model', @@ -930,6 +985,7 @@ def intro_agent(state: AgentState): - **No Keywords Required:** Do not look for specific words like "climate" or "weather". 
- **Accept Fragments:** "Bridge", "Data Center", "Tomatoes", "Here", "My car" are all **VALID**. - **Accept Statements:** "I am worried about the heat", "Building a shed" are **VALID**. + - **Accept Technical Constraints:** Requests specifying years, models, or datasets (e.g., "Use 1980-2000 baseline") are **VALID**. Based on the conversation, decide on one of the following responses: - "next": either "FINISH" or "CONTINUE" @@ -1020,6 +1076,34 @@ def combine_agent(state: AgentState): state.content_message += state.data_agent_response['content_message'] state.input_params.update(state.data_agent_response['input_params']) + #add data_analysis_agent response to content_message and input_params + if state.data_analysis_response: + state.input_params['data_analysis_response'] = state.data_analysis_response + state.content_message += "\n Data analysis agent response: {data_analysis_response} " + + # Add ERA5 climatology (observational ground truth) to prompt + if state.era5_climatology_response and isinstance(state.era5_climatology_response, dict): + era5_data = state.era5_climatology_response + # Format ERA5 data as structured text for the LLM + era5_text = "ERA5 OBSERVATIONAL CLIMATOLOGY (2015-2025 average - GROUND TRUTH):\n" + era5_text += f"Source: {era5_data.get('source', 'ERA5 Reanalysis')}\n" + era5_text += f"Location: {era5_data.get('extracted_location', {})}\n" + if 'variables' in era5_data: + for var_name, var_info in era5_data['variables'].items(): + era5_text += f"\n{var_info.get('full_name', var_name)} ({var_info.get('units', '')}):\n" + monthly = var_info.get('monthly_values', {}) + for month, value in monthly.items(): + era5_text += f" {month}: {value}\n" + state.input_params['era5_climatology'] = era5_text + state.content_message += "\n ERA5 Observations (ground truth baseline): {era5_climatology} " + + # Add generated plot images to prompt so LLM knows about them + if state.data_analysis_images: + state.input_params['data_analysis_images'] = 
state.data_analysis_images + images_list = ", ".join(state.data_analysis_images) + state.input_params['data_analysis_images_text'] = images_list + state.content_message += "\n Generated visualizations (plot files): {data_analysis_images_text} " + if state.smart_agent_response != {}: smart_analysis = state.smart_agent_response.get('output', '') state.input_params['smart_agent_analysis'] = smart_analysis @@ -1027,12 +1111,7 @@ def combine_agent(state: AgentState): logger.info(f"smart_agent_response: {state.smart_agent_response}") # Add Wikipedia tool response - if state.wikipedia_tool_response != {}: - wiki_response = state.wikipedia_tool_response - state.input_params['wikipedia_tool_response'] = wiki_response - state.content_message += "\n Wikipedia Search Response: {wikipedia_tool_response} " - logger.info(f"Wikipedia_tool_reponse: {state.wikipedia_tool_response}") - if state.ecocrop_search_response != {}: + if state.ecocrop_search_response: ecocrop_response = state.ecocrop_search_response state.input_params['ecocrop_search_response'] = ecocrop_response state.content_message += "\n ECOCROP Search Response: {ecocrop_search_response} " @@ -1079,24 +1158,22 @@ def combine_agent(state: AgentState): } def route_fromintro(state: AgentState) -> Sequence[str]: + """Route from intro agent to parallel information gathering agents.""" output = [] if "FINISH" in state.next: return "FINISH" else: + # Always run these agents in parallel output.append("ipcc_rag_agent") output.append("general_rag_agent") output.append("data_agent") output.append("zero_rag_agent") - #output.append("smart_agent") - return output - def route_fromdata(state: AgentState) -> Sequence[str]: - output = [] - if config['use_smart_agent']: - output.append("smart_agent") - else: - output.append("combine_agent") + + # Conditionally add smart_agent based on config + if config.get('use_smart_agent', False): + output.append("smart_agent") return output - + workflow = StateGraph(AgentState) figs = data_pocket.figs 
@@ -1108,28 +1185,38 @@ def route_fromdata(state: AgentState) -> Sequence[str]: workflow.add_node("ipcc_rag_agent", ipcc_rag_agent) workflow.add_node("general_rag_agent", general_rag_agent) workflow.add_node("data_agent", lambda s: data_agent(s, data, df)) # Pass `data` as argument - workflow.add_node("zero_rag_agent", lambda s: zero_rag_agent(s, figs)) # Pass `figs` as argument + workflow.add_node("zero_rag_agent", lambda s: zero_rag_agent(s, figs)) # Pass `figs` as argument workflow.add_node("smart_agent", lambda s: smart_agent(s, config, api_key, api_key_local, stream_handler)) + workflow.add_node("data_analysis_agent", lambda s: data_analysis_agent( + s, config, api_key, api_key_local, stream_handler, llm_dataanalysis_agent + )) workflow.add_node("combine_agent", combine_agent) - path_map = {'ipcc_rag_agent':'ipcc_rag_agent', 'general_rag_agent':'general_rag_agent', 'data_agent':'data_agent','zero_rag_agent':'zero_rag_agent','FINISH':END} - path_map_data = {'combine_agent':'combine_agent', 'smart_agent':'smart_agent'} + path_map = { + 'ipcc_rag_agent': 'ipcc_rag_agent', + 'general_rag_agent': 'general_rag_agent', + 'data_agent': 'data_agent', + 'zero_rag_agent': 'zero_rag_agent', + 'smart_agent': 'smart_agent', + 'FINISH': END + } - workflow.set_entry_point("intro_agent") # Set the entry point of the graph - + workflow.set_entry_point("intro_agent") # Set the entry point of the graph + + # Route from intro_agent to parallel agents (conditionally includes smart_agent) workflow.add_conditional_edges("intro_agent", route_fromintro, path_map=path_map) - workflow.add_conditional_edges("data_agent", route_fromdata, path_map=path_map_data) - #if config['use_smart_agent']: - # workflow.add_edge(["ipcc_rag_agent","general_rag_agent","data_agent","zero_rag_agent"], "combine_agent") - #else: - workflow.add_edge(["ipcc_rag_agent","general_rag_agent","smart_agent","zero_rag_agent"], "combine_agent") - - #workflow.add_edge("ipcc_rag_agent", "combine_agent") - 
#workflow.add_edge("general_rag_agent", "combine_agent") - #workflow.add_edge("data_agent", "combine_agent") - #workflow.add_edge("zero_rag_agent", "combine_agent") - #workflow.add_edge("smart_agent", "combine_agent") + # All parallel agents (ipcc_rag, general_rag, data, zero_rag) go to data_analysis_agent + workflow.add_edge(["ipcc_rag_agent", "general_rag_agent", "data_agent", "zero_rag_agent"], "data_analysis_agent") + + # If smart_agent is enabled, it also goes to data_analysis_agent + if config.get('use_smart_agent', False): + workflow.add_edge("smart_agent", "data_analysis_agent") + + # Data analysis agent goes to combine agent + workflow.add_edge("data_analysis_agent", "combine_agent") + + # Combine agent goes to END workflow.add_edge("combine_agent", END) # Compile the graph app = workflow.compile() @@ -1140,7 +1227,18 @@ def route_fromdata(state: AgentState) -> Sequence[str]: # with open(graph_image_path, 'wb') as f: # f.write(graph_img) # Write the image bytes to the file - state = AgentState(messages=[], input_params=input_params, user=input_params['user_message'], content_message=content_message, references=[]) + state = AgentState( + messages=[], + input_params=input_params, + user=input_params['user_message'], + content_message=content_message, + references=[], + thread_id=input_params.get("thread_id", ""), + uuid_main_dir=input_params.get("uuid_main_dir", ""), + results_dir=input_params.get("results_dir", ""), + climate_data_dir=input_params.get("climate_data_dir", ""), + era5_data_dir=input_params.get("era5_data_dir", ""), + ) stream_handler.update_progress("Starting workflow...") output = app.invoke(state) @@ -1160,4 +1258,4 @@ def route_fromdata(state: AgentState) -> Sequence[str]: stream_handler.send_reference_text('- '+ref+' \n') - return output['final_answer'], input_params, content_message, combine_agent_prompt_text \ No newline at end of file + return output['final_answer'], input_params, content_message, combine_agent_prompt_text diff --git 
a/src/climsight/config.py b/src/climsight/config.py new file mode 100644 index 0000000..2757216 --- /dev/null +++ b/src/climsight/config.py @@ -0,0 +1,20 @@ +"""Minimal config helpers for tool modules.""" + +import os + +try: + import streamlit as st +except ImportError: + st = None + + +def _get_openai_api_key() -> str: + if st is not None and hasattr(st, "secrets"): + try: + return st.secrets["general"]["openai_api_key"] + except Exception: + pass + return os.environ.get("OPENAI_API_KEY", "") + + +API_KEY = _get_openai_api_key() diff --git a/src/climsight/data_analysis_agent.py b/src/climsight/data_analysis_agent.py new file mode 100644 index 0000000..6a12597 --- /dev/null +++ b/src/climsight/data_analysis_agent.py @@ -0,0 +1,662 @@ +""" +Data analysis agent for Climsight. + +This agent mirrors PangaeaGPT's oceanographer/visualization style while operating +on local climatology. It filters context, then uses tools to extract or analyze +climate data, saving outputs into the sandbox. +""" + +import json +import logging +import os +from typing import Any, Dict, List + +from climsight_classes import AgentState + +try: + from utils import make_json_serializable +except ImportError: + from .utils import make_json_serializable +from sandbox_utils import ensure_thread_id, ensure_sandbox_dirs, get_sandbox_paths +from agent_helpers import create_standard_agent_executor +from tools.get_data_components import create_get_data_components_tool +from tools.era5_climatology_tool import create_era5_climatology_tool +from tools.era5_retrieval_tool import create_era5_retrieval_tool +from tools.python_repl import CustomPythonREPLTool +from tools.image_viewer import create_image_viewer_tool +from tools.reflection_tools import reflect_tool +from tools.visualization_tools import ( + list_plotting_data_files_tool, + wise_agent_tool, +) + +from langchain_core.prompts import ChatPromptTemplate + +logger = logging.getLogger(__name__) + + +def _build_climate_data_summary(df_list: 
List[Dict[str, Any]]) -> str: + """Summarize available climatology without exposing raw values.""" + if not df_list: + return "No climatology data available." + + lines = [] + for entry in df_list: + vars_summary = [] + for var_name, var_info in entry.get("extracted_vars", {}).items(): + full_name = var_info.get("full_name", var_name) + units = var_info.get("units", "") + vars_summary.append(f"{full_name} ({units})") + vars_text = ", ".join(vars_summary) if vars_summary else "Unknown variables" + lines.append( + f"- {entry.get('years_of_averaging', '')}: {entry.get('description', '')} | {vars_text}" + ) + + return "\n".join(lines) + + +def _build_datasets_text(state) -> str: + """Build simple dataset paths text for prompt injection. + + IMPORTANT: The Python REPL kernel CWD is already set to the sandbox root, + so we tell the agent to use RELATIVE paths (not full tmp/sandbox/... paths). + """ + lines = [ + "## Sandbox Paths (Python REPL is ALREADY inside the sandbox directory)", + "**CRITICAL: Use RELATIVE paths in your Python code, NOT full paths starting with 'tmp/sandbox/...'**", + f"- Current Working Directory: '.' 
(which is {state.uuid_main_dir})", + f"- Results directory: 'results' (save all plots here)", + f"- Climate data: 'climate_data'", + ] + + if state.era5_data_dir: + lines.append(f"- ERA5 data: 'era5_data'") + + # List available climate data files + if state.climate_data_dir and os.path.exists(state.climate_data_dir): + try: + files = os.listdir(state.climate_data_dir) + if files: + lines.append(f"\n## Climate Data Files Available (in 'climate_data/' folder)") + lines.append(f"Files: {', '.join(files)}") + # Highlight the main data.csv file + if "data.csv" in files: + lines.append("Note: Load with `pd.read_csv('climate_data/data.csv')`") + except Exception as e: + logger.warning(f"Could not list climate data files: {e}") + + return "\n".join(lines) + + +def _build_filter_prompt() -> str: + """Prompt for the analysis brief filter LLM.""" + return ( + "You are a data analysis filter for a climate assistant.\n" + "Extract only actionable analysis requirements.\n\n" + "Output format (bullets only):\n" + "- Target variables (with units if specified)\n" + "- Thresholds or criteria\n" + "- Time ranges or scenarios\n" + "- Spatial specifics (location, buffers)\n" + "- Analysis tasks (comparisons, trends, plots)\n\n" + "Rules:\n" + "- Do NOT include raw climate data values.\n" + "- Do NOT include long RAG or Wikipedia text.\n" + "- Omit vague statements that are not actionable.\n" + ) + + +def _create_tool_prompt(datasets_text: str, config: dict, lat: float = None, lon: float = None) -> str: + """System prompt for tool-driven analysis - dynamically built based on config.""" + has_era5_climatology = config.get("era5_climatology", {}).get("enabled", True) + has_era5_download = config.get("use_era5_data", False) + has_repl = config.get("use_powerful_data_analysis", False) + + prompt = """You are the data analysis agent for ClimSight. +Your job is to provide quantitative climate analysis with visualizations. 
+ +## AVAILABLE TOOLS +""" + + tool_num = 1 + + # TOOL #1: ERA5 Climatology - ALWAYS FIRST (observational ground truth) + if has_era5_climatology: + coord_example = "" + if lat is not None and lon is not None: + coord_example = f" - For this query: get_era5_climatology(latitude={lat}, longitude={lon}, variables=[\"t2m\", \"tp\", \"u10\", \"v10\"])\n" + + prompt += f""" +{tool_num}. **get_era5_climatology** - Extract OBSERVED climate data (CALL THIS FIRST!) + - Source: ERA5 reanalysis 2015-2025 monthly climatology + - This is GROUND TRUTH - actual observations, not model output + - Variables: temperature (t2m), precipitation (tp), wind_u (u10), wind_v (v10), dewpoint (d2m), pressure (msl) + - Returns monthly averages for the nearest grid point (~28km resolution) +{coord_example} + **CRITICAL**: This tool provides what the climate ACTUALLY IS at this location. + Use this as the BASELINE to compare against climate model projections. +""" + tool_num += 1 + + # TOOL #2: get_data_components - only if Python_REPL is NOT available + if not has_repl: + prompt += f""" +{tool_num}. **get_data_components** - Extract climate variables from climate MODEL projections + - Variables: Temperature, Precipitation, u_wind, v_wind + - Returns monthly values for historical AND future climate projections + - Example: get_data_components(environmental_data="Temperature", months=["Jan", "Feb", "Mar"]) + - These are MODEL outputs - compare with ERA5 observations to assess model quality +""" + tool_num += 1 + + # TOOL #3: ERA5 time series download (optional, for detailed analysis) + if has_era5_download: + prompt += f""" +{tool_num}. **retrieve_era5_data** - Retrieve ERA5 Surface climate data from Earthmover (Arraylake) + + Use this tool to retrieve **historical weather/climate context (Time Series)**. + + **DATA SOURCE:** Earthmover (Arraylake), hardcoded to "temporal" mode. + **VARIABLE CODES:** Use short codes: 't2' (Temp), 'u10'/'v10' (Wind), 'mslp' (Pressure), 'tp' (Precip). 
+ **WORK_DIR:** Pass work_dir='.' to save in current sandbox. + + **OUTPUT & LOADING:** + The tool returns an absolute path to a Zarr store (saved in 'era5_data/' folder). + **CRITICAL:** In Python_REPL, use RELATIVE paths to load the data since CWD is already the sandbox: + + ```python + import xarray as xr + import glob + + # List available ERA5 Zarr files (RELATIVE path) + era5_files = glob.glob('era5_data/*.zarr') + print(era5_files) + + # Load using RELATIVE path (NOT the absolute path from tool response) + ds = xr.open_dataset('era5_data/era5_t2_temporal_....zarr', engine='zarr', chunks={{{{}}}}) + data = ds['t2'].to_series() + ``` + +""" + tool_num += 1 + + # TOOL #4: Python REPL + if has_repl: + prompt += f""" +{tool_num}. **Python_REPL** - Execute Python code for data analysis and visualizations + - Pre-loaded: pandas (pd), numpy (np), matplotlib.pyplot (plt), xarray (xr) + - Working directory is ALREADY the sandbox root + - ALWAYS save plots to results/ directory + **CRITICAL PATH RULE:** + ❌ WRONG: `base='tmp/sandbox/uuid...'` then `f'{{{{base}}}}/era5_data/...'` + ✓ CORRECT: Use relative paths directly: `'era5_data/...'`, `'climate_data/...'`, `'results/...'` + The kernel CWD is already inside the sandbox, so DO NOT prepend 'tmp/sandbox/...'! + + **Climate Model Data** (climate_data/ directory): + - `climate_data_manifest.json` - READ THIS FIRST to see all available simulations + - `simulation_1.csv`, `simulation_2.csv`, ... 
- Data for different time periods + - `simulation_N_meta.json` - Metadata (years, description) for each simulation + - `data.csv` - Main/baseline simulation only (for quick access) + - Columns: Month, mean2t (temperature °C), tp (precipitation mm/month), wind_u, wind_v, wind_speed, wind_direction + - Use list_plotting_data_files tool to discover all available files + + **ERA5 Climatology** (era5_climatology.json): + - After calling get_era5_climatology, results are saved here + - Load with: `import json; era5 = json.load(open('era5_climatology.json'))` + - Use ERA5 as GROUND TRUTH baseline for comparisons + + Example workflow: + ```python + import pandas as pd + import json + import matplotlib.pyplot as plt + + # 1. Load manifest to see all available simulations + manifest = json.load(open('climate_data/climate_data_manifest.json')) + print(f"Data source: {{{{manifest['source']}}}}") + for entry in manifest['entries']: + print(f" {{{{entry['csv']}}}} : {{{{entry['years_of_averaging']}}}}") + + # 2. Load ERA5 observations (ground truth) + era5 = json.load(open('era5_climatology.json')) + era5_temp = era5['variables']['t2m']['monthly_values'] + + # 3. Load multiple climate model simulations + simulations = [] + for entry in manifest['entries']: + df = pd.read_csv(f"climate_data/{{{{entry['csv']}}}}") + simulations.append({{{{'df': df, 'years': entry['years_of_averaging'], 'main': entry['main']}}}}) + + # 4. Plot comparison + months = list(era5_temp.keys()) + plt.figure(figsize=(12, 6)) + plt.plot(months, list(era5_temp.values()), 'k-o', linewidth=2, label='ERA5 Observations') + for sim in simulations: + style = '-' if sim['main'] else '--' + plt.plot(months, sim['df']['mean2t'].tolist(), style, label=f"Model {{{{sim['years']}}}}") + plt.xlabel('Month') + plt.ylabel('Temperature (°C)') + plt.legend() + plt.tight_layout() + plt.savefig('results/temperature_comparison.png', dpi=150) + plt.close() + ``` +""" + tool_num += 1 + + prompt += f""" +{tool_num}.
**list_plotting_data_files** - List files in sandbox directories +""" + tool_num += 1 + + # image_viewer only available with Python REPL + if has_repl: + prompt += f""" +{tool_num}. **image_viewer** - View generated plots to verify quality + + Pass the EXACT path printed by Python_REPL after saving a plot. + Use this to verify your visualizations before finalizing. + +""" + tool_num += 1 + + # reflect_on_image and wise_agent only available with Python REPL + if has_repl: + prompt += f"""{tool_num}. **reflect_on_image** - Analyze a generated plot and get feedback for improvements +{tool_num + 1}. **wise_agent** - Get guidance on complex visualization decisions +""" + tool_num += 2 + + prompt += """ +## REQUIRED WORKFLOW +""" + + if has_era5_climatology: + prompt += """ +**STEP 1 - GET OBSERVATIONS (MANDATORY):** +Call get_era5_climatology FIRST to get the observed climate baseline. +- Extract at minimum: temperature (t2m), precipitation (tp) +- This is what the climate ACTUALLY IS at this location (2015-2025 average) +""" + # Add ERA5 time series download and analysis steps if available + if has_era5_download and has_repl: + prompt += """ +**STEP 2 - DOWNLOAD ERA5 TIME SERIES (RECOMMENDED):** +Call retrieve_era5_data to get detailed historical time series for deeper analysis. 
+- Download temperature (t2) and precipitation (tp) for 2015-2024 +- This provides year-by-year data, not just climatological averages +- Data saved to era5_data/ as .zarr files for Python analysis +- Enables trend analysis, extreme event detection, interannual variability + +**STEP 3 - ANALYZE ERA5 TIME SERIES:** +Use Python_REPL to analyze the downloaded ERA5 time series: +```python +import xarray as xr +ds = xr.open_zarr('era5_data/era5_t2_temporal_YYYYMMDD_YYYYMMDD.zarr') +# Compute trends, identify extreme years, plot interannual variability +``` +""" + + if has_repl: + # step_offset accounts for ERA5 download (step 2) and analysis (step 3) when enabled + step_offset = 2 if has_era5_download else 0 + prompt += f""" +**STEP {2 + step_offset} - LOAD MODEL DATA:** +Use Python_REPL to read climate_data/data.csv (model projections) + +**STEP {3 + step_offset} - COMPARE OBSERVATIONS vs MODEL:** +- ERA5 = ground truth (what we observe NOW) +- Model historical = what the model simulates for recent past +- Difference = MODEL BIAS (critical for interpreting future projections) + +**STEP {4 + step_offset} - ANALYZE FUTURE WITH CONTEXT:** +- Show future projections from climate model +- Explain how model bias affects confidence +- Future change = (Model future) - (ERA5 baseline) + +**STEP {5 + step_offset} - CREATE VISUALIZATIONS (MANDATORY):** +- Plot 1: ERA5 observations vs model (shows model bias) +- Plot 2: Future projections with ERA5 baseline +- Save ALL plots to results/ directory +""" + else: + prompt += """ +**STEP 2 - GET MODEL DATA:** +Call get_data_components for Temperature and Precipitation + +**STEP 3 - COMPARE:** +- ERA5 climatology = observations (ground truth) +- Model data = projections +- Note any differences between observed and modeled values +""" + elif has_repl: + prompt += """ +1. **READ THE DATA** - Use Python_REPL to inspect climate_data/data.csv +2. **ANALYZE** - Compare historical vs future projections +3. 
**CREATE VISUALIZATIONS** - Save plots to results/ directory +""" + else: + prompt += """ +1. **ALWAYS START** by calling get_data_components for Temperature (all months) +2. **THEN** call get_data_components for Precipitation (all months) +3. **ANALYZE** the data - compare historical vs future projections +""" + + prompt += f""" +## SANDBOX PATHS AND DATA + +{datasets_text} + +## PROACTIVE ANALYSIS + +Even if the user doesn't explicitly ask for plots, you SHOULD: +- Create temperature trend visualizations +- Show precipitation comparisons +- Highlight months with largest projected changes +- Identify potential climate risks (heat stress, drought, flooding) +""" + + if has_era5_climatology: + prompt += """ +**With ERA5 observations, ALWAYS include:** +- Current observed climate (from ERA5 - this is REALITY) +- Model performance assessment (how well does the model match observations?) +- Future projections interpreted in context of model quality +""" + + prompt += """ +## OUTPUT FORMAT + +Your final response should include: +""" + + if has_era5_climatology: + prompt += """1. **Current Climate (ERA5 Observations)**: What the climate ACTUALLY IS (2015-2025 average) +2. **Model Assessment**: How well climate models match ERA5 observations +3. **Future Projections**: Model predictions with confidence based on model-observation agreement +4. **Climate Change Signal**: Projected changes from current observed baseline +5. **Critical Months**: Which months show largest changes +6. **Visualizations**: List of plot files created +7. **Implications**: Interpretation relevant to the user's query +""" + else: + prompt += """1. **Key Climate Values**: Extracted temperature, precipitation data +2. **Climate Change Signal**: Differences between historical and future projections +3. **Critical Months**: Which months show largest changes +4. **Visualizations**: List of plot files created (if Python_REPL available) +5. 
**Implications**: Brief interpretation relevant to the user's query +""" + + prompt += """ +Limit total tool calls to 50. +""" + return prompt + + +def _normalize_tool_observation(observation: Any) -> Any: + """Normalize tool output into a plain Python object.""" + try: + from langchain_core.messages import AIMessage + except Exception: + AIMessage = None + + if AIMessage is not None and isinstance(observation, AIMessage): + return observation.content + return observation + + +def data_analysis_agent( + state: AgentState, + config: dict, + api_key: str, + api_key_local: str, + stream_handler, + llm_dataanalysis_agent=None, +): + """Run filtered analysis + tool-based climatology extraction.""" + stream_handler.update_progress("Data analysis: preparing sandbox...") + + # Ensure sandbox paths are available. + thread_id = ensure_thread_id(existing_thread_id=state.thread_id) + sandbox_paths = get_sandbox_paths(thread_id) + ensure_sandbox_dirs(sandbox_paths) + + state.thread_id = thread_id + state.uuid_main_dir = sandbox_paths["uuid_main_dir"] + state.results_dir = sandbox_paths["results_dir"] + state.climate_data_dir = sandbox_paths["climate_data_dir"] + state.era5_data_dir = sandbox_paths["era5_data_dir"] + + # Build analysis context for filtering. 
+ climate_summary = _build_climate_data_summary(state.df_list) + context_sections = [ + f"User query: {state.user}", + f"Location: {state.input_params.get('location_str', '')}", + f"Coordinates: {state.input_params.get('lat', '')}, {state.input_params.get('lon', '')}", + f"Climatology summary:\n{climate_summary}", + ] + + if state.ipcc_rag_agent_response: + context_sections.append(f"IPCC RAG: {state.ipcc_rag_agent_response}") + if state.general_rag_agent_response: + context_sections.append(f"General RAG: {state.general_rag_agent_response}") + if state.smart_agent_response: + context_sections.append(f"Smart agent: {state.smart_agent_response.get('output', '')}") + if state.ecocrop_search_response: + context_sections.append(f"ECOCROP: {state.ecocrop_search_response}") + if state.zero_agent_response: + safe_zero_context = make_json_serializable(state.zero_agent_response) + context_sections.append(f"Local context: {json.dumps(safe_zero_context, indent=2)}") + + analysis_context = "\n\n".join(context_sections) + + # Check if filter step is enabled (configurable) + use_filter_step = config.get("llm_dataanalysis", {}).get("use_filter_step", True) + + if use_filter_step and llm_dataanalysis_agent is not None: + stream_handler.update_progress("Data analysis: filtering context...") + filter_prompt = ChatPromptTemplate.from_messages( + [ + ("system", _build_filter_prompt()), + ("user", "{context}"), + ] + ) + result = llm_dataanalysis_agent.invoke(filter_prompt.format_messages(context=analysis_context)) + filtered_context = result.content if hasattr(result, "content") else str(result) + + # CRITICAL: Always preserve the user's original question + location_str = state.input_params.get('location_str', 'Unknown location') + analysis_brief = f"""USER QUESTION: {state.user} + +Location: {location_str} +Coordinates: {state.input_params.get('lat', '')}, {state.input_params.get('lon', '')} + +ANALYSIS REQUIREMENTS: +{filtered_context} +""" + else: + # Skip filter step - pass 
essential context directly + stream_handler.update_progress("Data analysis: preparing context (no filter)...") + location_str = state.input_params.get('location_str', 'Unknown location') + analysis_brief = f"""USER QUESTION: {state.user} + +Location: {location_str} +Coordinates: {state.input_params.get('lat', '')}, {state.input_params.get('lon', '')} + +Available climatology: +{climate_summary} + +Required analysis: +- Extract Temperature and Precipitation data +- Compare historical vs future projections +- Create visualizations if Python_REPL is available +""" + + state.data_analysis_prompt_text = analysis_brief + + brief_path = os.path.join(state.uuid_main_dir, "analysis_brief.txt") + with open(brief_path, "w", encoding="utf-8") as f: + f.write(analysis_brief) + + # Build simplified datasets_text for prompt + datasets_text = _build_datasets_text(state) + + # Build datasets dict for Python REPL + datasets = { + "uuid_main_dir": state.uuid_main_dir, + "results_dir": state.results_dir, + } + if state.climate_data_dir: + datasets["climate_data_dir"] = state.climate_data_dir + if state.era5_data_dir: + datasets["era5_data_dir"] = state.era5_data_dir + + # Get coordinates for prompt + lat = state.input_params.get('lat') + lon = state.input_params.get('lon') + try: + lat = float(lat) if lat is not None else None + lon = float(lon) if lon is not None else None + except (ValueError, TypeError): + lat, lon = None, None + + # Tool setup - ORDER MATTERS (matches prompt workflow) + tools = [] + + has_python_repl = config.get("use_powerful_data_analysis", False) + has_era5_climatology = config.get("era5_climatology", {}).get("enabled", True) + + # 1. ERA5 Climatology - ALWAYS FIRST (observational ground truth) + if has_era5_climatology: + tools.append(create_era5_climatology_tool(state, config, stream_handler)) + + # 2. 
get_data_components - ONLY if Python_REPL is NOT available + # (When Python_REPL is enabled, it's redundant - agent reads CSV directly) + if not has_python_repl: + tools.append(create_get_data_components_tool(state, config, stream_handler)) + + # 3. ERA5 time series retrieval (if enabled - for detailed year-by-year analysis) + if config.get("use_era5_data", False): + arraylake_api_key = config.get("arraylake_api_key", "") + if arraylake_api_key: + tools.append(create_era5_retrieval_tool(arraylake_api_key)) + else: + logger.warning("ERA5 data enabled but no arraylake_api_key in config. ERA5 retrieval tool not added.") + + # 4. Python REPL for analysis/visualization (if enabled) + if has_python_repl: + repl_tool = CustomPythonREPLTool( + datasets=datasets, + results_dir=state.results_dir, + session_key=thread_id, + ) + tools.append(repl_tool) + + # 5. Helper tools + tools.append(list_plotting_data_files_tool) + + # 6. Image viewer - ONLY when Python REPL is enabled + if has_python_repl: + # Extract model name from config or default to gpt-4o + vision_model = config.get("llm_combine", {}).get("model_name", "gpt-4o") + + # Pass sandbox_path so relative paths from the agent can be resolved + image_viewer_tool = create_image_viewer_tool( + openai_api_key=api_key, + model_name=vision_model, + sandbox_path=state.uuid_main_dir + ) + tools.append(image_viewer_tool) + + # 7. 
Image reflection and wise_agent - ONLY when Python REPL is enabled + # (these tools are for evaluating/creating visualizations) + if has_python_repl: + tools.append(reflect_tool) + tools.append(wise_agent_tool) + + stream_handler.update_progress("Data analysis: running tools...") + tool_prompt = _create_tool_prompt(datasets_text, config, lat=lat, lon=lon) + + if llm_dataanalysis_agent is None: + from langchain_openai import ChatOpenAI + + llm_dataanalysis_agent = ChatOpenAI( + openai_api_key=api_key, + model_name=config.get("llm_combine", {}).get("model_name", "gpt-4.1-nano"), + ) + + agent_executor = create_standard_agent_executor( + llm_dataanalysis_agent, + tools, + tool_prompt, + max_iterations=20, + ) + + agent_input = { + "input": analysis_brief or state.user, + "messages": state.messages, + } + + result = agent_executor(agent_input) + + data_components_outputs = [] + plot_images: List[str] = [] + era5_climatology_output = None + + for action, observation in result.get("intermediate_steps", []): + if action.tool == "get_era5_climatology": + obs = _normalize_tool_observation(observation) + if isinstance(obs, dict) and "error" not in obs: + era5_climatology_output = obs + state.era5_climatology_response = obs + if action.tool == "get_data_components": + data_components_outputs.append(_normalize_tool_observation(observation)) + if action.tool in ("Python_REPL", "python_repl"): + obs = _normalize_tool_observation(observation) + if isinstance(obs, dict): + plot_images.extend(obs.get("plot_images", [])) + if action.tool == "retrieve_era5_data": + # Handle ERA5 retrieval tool output + obs = _normalize_tool_observation(observation) + if isinstance(obs, dict): + era5_output = str(obs) + elif hasattr(obs, 'content'): + era5_output = obs.content + else: + era5_output = str(obs) + # Store in state + state.era5_tool_response = era5_output + state.input_params.setdefault("era5_results", []).append(obs) + + analysis_text = result.get("output", "") + + # Append ERA5 
climatology summary if available + if era5_climatology_output: + analysis_text += "\n\n### ERA5 Observational Baseline (2015-2025)\n" + analysis_text += f"Location: {era5_climatology_output.get('extracted_location', {})}\n" + if "variables" in era5_climatology_output: + for var_name, var_data in era5_climatology_output["variables"].items(): + analysis_text += f"\n**{var_data.get('full_name', var_name)}** ({var_data.get('units', '')}):\n" + monthly = var_data.get("monthly_values", {}) + # Show a few key months + for month in ["January", "April", "July", "October"]: + if month in monthly: + analysis_text += f" {month}: {monthly[month]}\n" + + if data_components_outputs: + analysis_text += "\n\n### Climate Model Extracts:\n" + for item in data_components_outputs: + analysis_text += json.dumps(item, indent=2) + "\n" + + state.data_analysis_response = analysis_text + state.data_analysis_images = plot_images + + stream_handler.update_progress("Data analysis complete.") + + return { + "data_analysis_response": analysis_text, + "data_analysis_images": plot_images, + "data_analysis_prompt_text": analysis_brief, + "era5_climatology_response": state.era5_climatology_response, + "era5_tool_response": getattr(state, 'era5_tool_response', None), + } diff --git a/src/climsight/geo_functions.py b/src/climsight/geo_functions.py index 5c05089..bc81280 100644 --- a/src/climsight/geo_functions.py +++ b/src/climsight/geo_functions.py @@ -46,7 +46,7 @@ def get_location(lat, lon): "User-Agent": "climsight", "accept-language": "en" } - response = requests.get(url, params=params, headers=headers, timeout=5) + response = requests.get(url, params=params, headers=headers, timeout=10) location = response.json() # Wait before making the next request (according to terms of use) @@ -347,7 +347,7 @@ def get_elevation_from_api(lat, lon): float: The elevation of the location in meters. 
""" url = f"https://api.opentopodata.org/v1/etopo1?locations={lat},{lon}" - response = requests.get(url, timeout=3) + response = requests.get(url, timeout=10) data = response.json() return data["results"][0]["elevation"] @@ -370,7 +370,7 @@ def fetch_land_use(lon, lat): area.a["landuse"]; out tags; """ - response = requests.get(overpass_url, params={"data": overpass_query}, timeout=3) + response = requests.get(overpass_url, params={"data": overpass_query}, timeout=10) data = response.json() return data @@ -388,7 +388,7 @@ def get_soil_from_api(lat, lon): """ try: url = f"https://rest.isric.org/soilgrids/v2.0/classification/query?lon={lon}&lat={lat}&number_classes=5" - response = requests.get(url, timeout=3) # Set timeout to 2 seconds + response = requests.get(url, timeout=10) # Set timeout to 2 seconds data = response.json() return data["wrb_class_name"] except Timeout: diff --git a/src/climsight/sandbox_utils.py b/src/climsight/sandbox_utils.py new file mode 100644 index 0000000..5b52953 --- /dev/null +++ b/src/climsight/sandbox_utils.py @@ -0,0 +1,131 @@ +""" +Sandbox utilities for per-session data storage. + +This mirrors the PangaeaGPT layout while keeping the API minimal for Climsight. +""" + +import json +import logging +import os +import uuid +from pathlib import Path +from typing import Dict, List, Tuple + +logger = logging.getLogger(__name__) + + +def ensure_thread_id(session_state=None, existing_thread_id: str = "") -> str: + """Ensure a stable session thread_id across Streamlit/CLI runs.""" + thread_id = existing_thread_id or "" + + if not thread_id and session_state is not None: + thread_id = session_state.get("thread_id", "") + + if not thread_id: + thread_id = uuid.uuid4().hex + + if session_state is not None: + session_state["thread_id"] = thread_id + + # Expose to non-Streamlit tools (CLI, background workers). 
+ # CRITICAL: Always update (not setdefault) to ensure all tools use the same thread_id + os.environ["CLIMSIGHT_THREAD_ID"] = thread_id + + return thread_id + + +def get_sandbox_paths(thread_id: str) -> Dict[str, str]: + """Return sandbox paths for a given session.""" + base_dir = Path("tmp") / "sandbox" / thread_id + return { + "uuid_main_dir": str(base_dir), + "results_dir": str(base_dir / "results"), + "climate_data_dir": str(base_dir / "climate_data"), + "era5_data_dir": str(base_dir / "era5_data"), + } + + +def ensure_sandbox_dirs(paths: Dict[str, str]) -> None: + """Create sandbox directories if they do not exist.""" + for key, path in paths.items(): + if not path: + continue + os.makedirs(path, exist_ok=True) + logger.debug("Ensured sandbox dir %s: %s", key, path) + + +def write_climate_data_manifest( + df_list: List[Dict], + climate_data_dir: str, + source: str, +) -> Tuple[str, List[Dict]]: + """Persist climate dataframes and metadata into the sandbox. + + Returns: + (manifest_path, entries) + """ + os.makedirs(climate_data_dir, exist_ok=True) + + entries: List[Dict] = [] + main_index = 0 + + for i, entry in enumerate(df_list): + if entry.get("main"): + main_index = i + break + + for i, entry in enumerate(df_list): + df = entry.get("dataframe") + if df is None: + continue + + csv_name = f"simulation_{i + 1}.csv" + meta_name = f"simulation_{i + 1}_meta.json" + csv_path = os.path.join(climate_data_dir, csv_name) + meta_path = os.path.join(climate_data_dir, meta_name) + + df.to_csv(csv_path, index=False) + + meta = { + "years_of_averaging": entry.get("years_of_averaging", ""), + "description": entry.get("description", ""), + "extracted_vars": entry.get("extracted_vars", {}), + "main": bool(entry.get("main", False)), + "source": entry.get("source", source), + "filename": entry.get("filename", ""), + } + with open(meta_path, "w", encoding="utf-8") as f: + json.dump(meta, f, indent=2) + + entries.append({ + "csv": csv_name, + "meta": meta_name, + 
"years_of_averaging": meta["years_of_averaging"], + "description": meta["description"], + "main": meta["main"], + }) + + # Provide a simple auto-load CSV for Python_REPL parity. + if df_list: + main_df = df_list[main_index].get("dataframe") + if main_df is not None: + main_df.to_csv(os.path.join(climate_data_dir, "data.csv"), index=False) + + manifest = { + "source": source, + "entries": entries, + } + manifest_path = os.path.join(climate_data_dir, "climate_data_manifest.json") + with open(manifest_path, "w", encoding="utf-8") as f: + json.dump(manifest, f, indent=2) + + return manifest_path, entries + + +def load_climate_data_manifest(manifest_path: str) -> Dict: + """Load a climate data manifest from disk.""" + if not manifest_path or not os.path.exists(manifest_path): + return {} + + with open(manifest_path, "r", encoding="utf-8") as f: + return json.load(f) diff --git a/src/climsight/smart_agent.py b/src/climsight/smart_agent.py index e27eea9..b0794a5 100644 --- a/src/climsight/smart_agent.py +++ b/src/climsight/smart_agent.py @@ -31,6 +31,7 @@ #Import tools from tools.python_repl import create_python_repl_tool from tools.image_viewer import create_image_viewer_tool +# era5_retrieval_tool is used only in data_analysis_agent #import requests #from bs4 import BeautifulSoup @@ -82,364 +83,93 @@ def smart_agent(state: AgentState, config, api_key, api_key_local, stream_handle # Create working directory work_dir = Path("tmp/sandbox") / st.session_state.session_uuid work_dir.mkdir(parents=True, exist_ok=True) + work_dir_str = str(work_dir.resolve()) # System prompt prompt = f""" - You are the smart agent of ClimSight. Your task is to retrieve necessary components of the climatic datasets based on the user's request. - You have access to tools called "get_data_components", "wikipedia_search", "RAG_search" and "ECOCROP_search" and "python_repl" which you can use to retrieve the necessary environmental data components. 
- - "get_data_components" will retrieve the necessary data from the climatic datasets at the location of interest (latitude: {lat}, longitude: {lon}). It accepts an 'environmental_data' parameter to specify the type of data, and a 'months' parameter to specify which months to retrieve data for. The 'months' parameter is a list of month names (e.g., ['Jan', 'Feb', 'Mar']). If 'months' is not specified, data for all months will be retrieved. - Call "get_data_components" tool multiple times if necessary, but only within one iteration, [chat_completion -> n * "get_data_components" -> chat_completion] after you recieve the necessary data from wikipedia_search and RAG_search. - - "wikipedia_search" will help you determine the necessary data to retrieve with the get_data_components tool. - - "RAG_search" can provide detailed information about environmental conditions for growing corn from your internal knowledge base. - - "ECOCROP_search" will help you determine the specific environmental requirements for the crop of interest from ecocrop database. - call "ECOCROP_search" ONLY and ONLY if you sure that the user question is related to the crop of interest. - - "python_repl" allows you to execute Python code for data analysis, visualization, and calculations. - Use this tool when: - - Creating visualizations (plots, charts, graphs) of climate data - - Performing statistical analysis (means, trends, correlations, standard deviations) - - Comparing data between different time periods (e.g., historical vs future projections) - - Calculating climate indicators or derived metrics - - Any analysis that requires more than simple data retrieval - - The tool has access to pandas, numpy, matplotlib, and xarray. - - **IMPORTANT: Climate data is pre-loaded in the Python environment. 
To see what's available, run:** - ```python - print(DATA_CATALOG) # Shows all available climate datasets and their descriptions - print(list(locals().keys())) # Shows all available variables - ``` - The climate data includes historical reference periods and future projections with monthly temperature, - precipitation, and wind data for your specific location. - - **CRITICAL: Your working directory is available at `work_dir` = '{str(work_dir)}'** - **When saving plots, ALWAYS store the full path in a variable for later use!** - - **CORRECT way to save and reference images:** - ```python - # Save the plot and store the full path - plot_path = f'{{{{work_dir}}}}/my_plot.png' # Or use Path(work_dir) / 'my_plot.png' - plt.savefig(plot_path) - print(f"Plot saved to: {{{{plot_path}}}}") # Verify the path - # Now plot_path contains the full path for image_viewer - ``` - - **WRONG way (DO NOT do this):** - ```python - plt.savefig('work_dir/my_plot.png') # This is just a string, not the actual path! - ``` - - - - "image_viewer" allows you to analyze saved climate visualizations to extract scientific insights. - Use this tool when: - - You have created and saved a visualization using python_repl - - You need to describe the patterns shown in a climate plot - - You want to extract specific values or trends from a generated figure - - The tool will provide scientific analysis of the image including: - - Data description and quantitative observations - - Temporal patterns and climate insights - - Key findings and implications - - **CRITICAL: How to use image_viewer correctly:** - 1. First save your plot in python_repl and store the path: - ```python - plot_path = f'{{{{work_dir}}}}/temperature_plot.png' - plt.savefig(plot_path) - ``` - 2. Then use image_viewer with the VARIABLE containing the path: - ``` - image_viewer(plot_path) # Use the variable, not a string! - ``` - - **NEVER do this:** - ``` - image_viewer('work_dir/plot.png') # WRONG - this is just a string! 
- ``` - - **Your actual work_dir path is: {str(work_dir)}** - **Always use the full path stored in a variable when calling image_viewer!** - + You are the information gathering agent of ClimSight. Your task is to collect relevant background + information about the user's query using external knowledge sources. + + You have access to three information retrieval tools: + - "wikipedia_search": Search Wikipedia for general information about the topic of interest. + Use this to understand the broader context of the user's question. + + - "RAG_search": Query ClimSight's internal climate knowledge base for scientific details. + Use this to find detailed scientific information from climate literature and research. + + - "ECOCROP_search": Get crop-specific environmental requirements from the ECOCROP database. + ONLY use this tool if the user's question is clearly about agriculture or crops. + Do NOT use it for general climate queries. + + **Your goal**: Gather comprehensive background information and compile it into a well-structured + summary that will be used by subsequent agents for data analysis. + + **Important guidelines**: + - When citing information from Wikipedia or RAG sources, include source references in your summary + - Present ECOCROP database results exactly as they are provided (no citations needed for database facts) + - Focus on gathering contextual information, NOT on extracting or analyzing specific climate data + - Do NOT attempt to retrieve climate data values - that will be handled by other agents + - Your output should be a comprehensive text summary with relevant background context + - If you call a tool multiple times, incorporate all results in your summary + - Do NOT call wikipedia_search more than 10 times total """ if config['llm_smart']['model_type'] in ("local", "aitta"): prompt += f""" - - Tool use order. 
Always call the wikipedia_search, RAG_search, and ECOCROP_search tools as needed, but only one at a time per turn.; it will help you determine the necessary data to retrieve with the get_data_components tool. At second step, call the get_data_components tool with the necessary data. + - Tool use order. Call the wikipedia_search, RAG_search, and ECOCROP_search tools as needed, + but only one at a time per turn. Gather all relevant background information. """ else: prompt += f""" - - Tool use order. ALWAYS call FIRST SIMULTANEOUSLY the wikipedia_search, RAG_search and "ECOCROP_search"; it will help you determine the necessary data to retrieve with the get_data_components tool. At second step, call the get_data_components tool with the necessary data. - - """ + - Tool use order. Call wikipedia_search, RAG_search, and ECOCROP_search + (if applicable) SIMULTANEOUSLY to gather background information efficiently. + + """ prompt += f""" - Use these tools to get the data you need to answer the user's question. - After retrieving the data, provide a concise summary of the parameters you retrieved, explaining briefly why they are important. Keep your response short and to the point. - Do not include any additional explanations or reasoning beyond the concise summary. - Do not include any chain-of-thought reasoning or action steps in your final answer. - Do not ask the user for any additional information, but you can include into the final answer what kind of information user should provide in the future. 
- - Additional workflow guidance: - - If the user asks for analysis, trends, comparisons, or visualizations, use python_repl after retrieving data to: - * Create plots and charts - * Calculate statistics (averages, changes, trends) - * Compare different time periods - * Save any outputs to work_dir - - - For the final response try to follow the following format: - 'The [retrieved values of the parameter] for the [object of interest] at [location] is [value for current and future are ...], [according to the Wikipedia article] the required [parameter] for [object of interest] is [value]. [Two sentence of clarification, with criitcal montly-based assessment of the potential changes]' - 'Repeat for each parameter.' + **Output format**: + After gathering information from the available tools, provide a well-structured summary organized as follows: + + 1. **General Background** (from Wikipedia if available): + - Key facts and context about the topic + - Relevant qualitative and quantitative information + - Include specific values with units where mentioned + + 2. **Scientific Context** (from RAG if available): + - Detailed scientific information from climate literature + - Relevant research findings and climate insights + - Technical details that provide context + + 3. **Specific Requirements** (from ECOCROP if applicable): + - Database results presented exactly as provided + - Environmental thresholds and requirements + - Optimal and tolerable ranges + + 4. 
**Key Takeaways**: + - Summarize the most important points for understanding the user's query + - Highlight critical factors or thresholds mentioned + + + - Keep your summary concise but comprehensive + - Do NOT include chain-of-thought reasoning or tool invocation details + - Do NOT mention what you're going to do or explain your process + - Focus on presenting the gathered information in a clear, organized format + - The summary should help subsequent agents understand the context for data analysis + - If you have already called wikipedia_search 10 times, proceed without further Wikipedia calls """ - - - #[1] Tool description for netCDF extraction - class get_data_components_args(BaseModel): - environmental_data: Optional[Union[str, Literal["Temperature", "Precipitation", "u_wind", "v_wind"]]] = Field( - default=None, - description="The type of environmental data to retrieve. Choose from Temperature, Precipitation, u_wind, or v_wind.", - enum_description={ - "Temperature": "The mean monthly temperature data.", - "Precipitation": "The mean monthly precipitation data.", - "u_wind": "The mean monthly u wind component data.", - "v_wind": "The mean monthly v wind component data." - } - ) - months: Optional[Union[str, List[Literal["Jan", "Feb", "Mar", "Apr", "May", "Jun", - "Jul", "Aug", "Sep", "Oct", "Nov", "Dec"]]]] = Field( - default=None, - description="List of months or a stringified list of month names to retrieve data for. Each month should be one of 'Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun', 'Jul', 'Aug', 'Sep', 'Oct', 'Nov', 'Dec'. If not specified, data for all months will be retrieved." 
- ) - def get_data_components(**kwargs): - stream_handler.update_progress("Retrieving data for advanced analysis with a smart agent...") - - if isinstance(kwargs.get("months"), str): - try: - kwargs["months"] = ast.literal_eval(kwargs["months"]) - except (ValueError, SyntaxError): - # Optional: handle invalid input - kwargs["months"] = None # or raise an exception - args = get_data_components_args(**kwargs) - # Parse the arguments using the args_schema - environmental_data = args.environmental_data - months = args.months # List of month names - if environmental_data is None: - return {"error": "No environmental data type specified."} - if environmental_data not in ["Temperature", "Precipitation", "u_wind", "v_wind"]: - return {"error": f"Invalid environmental data type: {environmental_data}"} - - # Get climate data source from config - climate_source = config.get('climate_data_source', 'nextGEMS') - - # Check if we have df_list from the provider (unified approach) - df_list = getattr(state, 'df_list', None) - - if df_list: - # Use unified df_list from any provider (nextGEMS, ICCP, AWI_CM) - response = {} - - # Variable mapping depends on data source - if climate_source == 'nextGEMS': - environmental_mapping = { - "Temperature": "mean2t", - "Precipitation": "tp", - "u_wind": "wind_u", - "v_wind": "wind_v" - } - elif climate_source == 'ICCP': - environmental_mapping = { - "Temperature": "mean2t", - "Precipitation": "tp", - "u_wind": "wind_u", - "v_wind": "wind_v" - } - elif climate_source == 'AWI_CM': - environmental_mapping = { - "Temperature": "Present Day Temperature", - "Precipitation": "Present Day Precipitation", - "u_wind": "u_wind", - "v_wind": "v_wind" - } - else: - environmental_mapping = { - "Temperature": "mean2t", - "Precipitation": "tp", - "u_wind": "wind_u", - "v_wind": "wind_v" - } - - if environmental_data not in environmental_mapping: - return {"error": f"Invalid environmental data type: {environmental_data}"} - - # Filter the DataFrame for the 
selected months and extract the values - var_name = environmental_mapping[environmental_data] - - if not months: - months = [calendar.month_abbr[m] for m in range(1, 13)] - - # Create a mapping from abbreviated to full month names - month_mapping = {calendar.month_abbr[m]: calendar.month_name[m] for m in range(1, 13)} - selected_months = [month_mapping[abbr] for abbr in months] + #[1] get_data_components tool moved to data_analysis_agent.py - for entry in df_list: - df = entry.get('dataframe') - extracted_vars = entry.get('extracted_vars', {}) + #[2] Wikipedia processing tool + wikipedia_call_state = {"count": 0} - if df is None: - raise ValueError(f"Entry does not contain a 'dataframe' key.") + def process_wikipedia_article(query: str) -> str: + if wikipedia_call_state["count"] >= 10: + return "" + wikipedia_call_state["count"] += 1 - # Try to find the variable in the dataframe - if var_name in df.columns: - var_meta = extracted_vars.get(var_name, {'units': ''}) - data_values = df[df['Month'].isin(selected_months)][var_name].tolist() - ext_data = {month: np.round(value, 2) for month, value in zip(selected_months, data_values)} - ext_exp = f"Monthly mean values of {environmental_data}, {var_meta.get('units', '')} for years: " + entry['years_of_averaging'] - response.update({ext_exp: ext_data}) - else: - # Try to find by full name in columns - matching_cols = [col for col in df.columns if environmental_data.lower() in col.lower()] - if matching_cols: - col_name = matching_cols[0] - data_values = df[df['Month'].isin(selected_months)][col_name].tolist() - ext_data = {month: np.round(value, 2) for month, value in zip(selected_months, data_values)} - ext_exp = f"Monthly mean values of {environmental_data} for years: " + entry['years_of_averaging'] - response.update({ext_exp: ext_data}) - - return response - else: - # Legacy fallback for AWI_CM without df_list - lat = float(state.input_params['lat']) - lon = float(state.input_params['lon']) - data_path = 
config['data_settings']['data_path'] - - # Dictionaries for historical and SSP585 data files - data_files_historical = { - "Temperature": ("AWI_CM_mm_historical.nc", "tas"), - "Precipitation": ("AWI_CM_mm_historical_pr.nc", "pr"), - "u_wind": ("AWI_CM_mm_historical_uas.nc", "uas"), - "v_wind": ("AWI_CM_mm_historical_vas.nc", "vas") - } - - data_files_ssp585 = { - "Temperature": ("AWI_CM_mm_ssp585.nc", "tas"), - "Precipitation": ("AWI_CM_mm_ssp585_pr.nc", "pr"), - "u_wind": ("AWI_CM_mm_ssp585_uas.nc", "uas"), - "v_wind": ("AWI_CM_mm_ssp585_vas.nc", "vas") - } - - if environmental_data not in data_files_historical: - return {"error": f"Invalid environmental data type: {environmental_data}"} - - # Get file names and variable names for both datasets - file_name_hist, var_name_hist = data_files_historical[environmental_data] - file_name_ssp585, var_name_ssp585 = data_files_ssp585[environmental_data] - - # Build file paths - file_path_hist = os.path.join(data_path, file_name_hist) - file_path_ssp585 = os.path.join(data_path, file_name_ssp585) - - # Check if files exist - if not os.path.exists(file_path_hist): - return {"error": f"Data file {file_name_hist} not found in {data_path}"} - if not os.path.exists(file_path_ssp585): - return {"error": f"Data file {file_name_ssp585} not found in {data_path}"} - - # Open datasets - dataset_hist = nc.Dataset(file_path_hist) - dataset_ssp585 = nc.Dataset(file_path_ssp585) - - # Get latitude and longitude arrays - lats_hist = dataset_hist.variables['lat'][:] - lons_hist = dataset_hist.variables['lon'][:] - lats_ssp585 = dataset_ssp585.variables['lat'][:] - lons_ssp585 = dataset_ssp585.variables['lon'][:] - - # Find the nearest indices for historical data - lat_idx_hist = (np.abs(lats_hist - lat)).argmin() - lon_idx_hist = (np.abs(lons_hist - lon)).argmin() - - # Find the nearest indices for SSP585 data - lat_idx_ssp585 = (np.abs(lats_ssp585 - lat)).argmin() - lon_idx_ssp585 = (np.abs(lons_ssp585 - lon)).argmin() - - # Extract data at 
the specified location - data_hist = dataset_hist.variables[var_name_hist][:, :, :, lat_idx_hist, lon_idx_hist] - data_ssp585 = dataset_ssp585.variables[var_name_ssp585][:, :, :, lat_idx_ssp585, lon_idx_ssp585] - - # Squeeze data to remove singleton dimensions (shape becomes (12,)) - data_hist = np.squeeze(data_hist) - data_ssp585 = np.squeeze(data_ssp585) - - # Process data according to the variable - if environmental_data == "Temperature": - # Convert from Kelvin to Celsius - data_hist = data_hist - 273.15 - data_ssp585 = data_ssp585 - 273.15 - units = "°C" - elif environmental_data == "Precipitation": - # Convert from kg m-2 s-1 to mm/month - days_in_month = np.array([31, 28, 31, 30, 31, 30, - 31, 31, 30, 31, 30, 31]) - seconds_in_month = days_in_month * 24 * 3600 # seconds in each month - data_hist = data_hist * seconds_in_month - data_ssp585 = data_ssp585 * seconds_in_month - units = "mm/month" - elif environmental_data in ["u_wind", "v_wind"]: - # Units are already in m/s - units = "m/s" - else: - units = "unknown" - - # Close datasets - dataset_hist.close() - dataset_ssp585.close() - - # List of all month names - all_months = ['Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun', - 'Jul', 'Aug', 'Sep', 'Oct', 'Nov', 'Dec'] - - # Map month names to indices - month_indices = {month: idx for idx, month in enumerate(all_months)} - - # If months are specified, select data for those months - if months: - # Validate months - valid_months = [month for month in months if month in month_indices] - if not valid_months: - return {"error": "Invalid months provided."} - selected_indices = [month_indices[month] for month in valid_months] - selected_months = valid_months - else: - # Use all months if none are specified - selected_indices = list(range(12)) - selected_months = all_months - - # Subset data for selected months - data_hist = data_hist[selected_indices] - data_ssp585 = data_ssp585[selected_indices] - - # Create dictionaries mapping months to values with units - 
hist_data_dict = {month: f"{value:.2f} {units}" for month, value in zip(selected_months, data_hist)} - ssp585_data_dict = {month: f"{value:.2f} {units}" for month, value in zip(selected_months, data_ssp585)} - - # Return both historical and SSP585 data - return { - f"{environmental_data}_historical": hist_data_dict, - f"{environmental_data}_ssp585": ssp585_data_dict - } - - # Define the data_extraction_tool - data_extraction_tool = StructuredTool.from_function( - func=get_data_components, - name="get_data_components", - description="Retrieve the necessary environmental data component.", - args_schema=get_data_components_args - ) - - #[2] Wikipedia processing tool - def process_wikipedia_article(query: str) -> str: stream_handler.update_progress("Searching Wikipedia for related information with a smart agent...") # Initialize the LLM @@ -782,53 +512,9 @@ def process_ecocrop_search(query: str) -> str: args_schema=EcoCropSearchArgs ) - python_repl_tool = create_python_repl_tool() - - def inject_climate_context(): - context = { - 'lat': lat, - 'lon': lon, - 'location_str': state.input_params.get('location_str', ''), - 'work_dir': str(work_dir), - } - - # Add climate data if available - if state.df_list: - # Add all dataframes from df_list - for i, entry in enumerate(state.df_list): - df = entry.get('dataframe') - if df is not None: - context[f'climate_df_{i}'] = df - context[f'climate_info_{i}'] = { - 'years': entry.get('years_of_averaging', ''), - 'description': entry.get('description', ''), - 'variables': entry.get('extracted_vars', {}) - } - - # ADD THIS - Build DATA_CATALOG dynamically - catalog = "Available climate datasets:\n" - for i, entry in enumerate(state.df_list): - years = entry.get('years_of_averaging', '') - desc = entry.get('description', '') - is_main = " (historical reference)" if entry.get('main', False) else "" - catalog += f"- climate_df_{i}: {years}{is_main} - {desc}\n" - - catalog += "\nEach dataset contains monthly values for:\n" - if 
state.df_list and state.df_list[0].get('extracted_vars'): - for var_name, var_info in state.df_list[0]['extracted_vars'].items(): - catalog += f"- {var_info['full_name']} ({var_info['units']})\n" - - context['DATA_CATALOG'] = catalog - - return context - - # Inject climate context into Python REPL BEFORE agent runs - if hasattr(python_repl_tool.func, '__self__'): - repl_instance = python_repl_tool.func.__self__ - context = inject_climate_context() - repl_instance.locals.update(context) - + # Create python_repl tool + #python_repl_tool = create_python_repl_tool() # Initialize the LLM if config['llm_smart']['model_type'] == "local": @@ -849,21 +535,8 @@ def inject_climate_context(): llm = get_aitta_chat_model(config['llm_smart']['model_name'], temperature = 0) # List of tools - #tools = [data_extraction_tool, rag_tool, wikipedia_tool, ecocrop_tool, python_repl_tool] - tools = [data_extraction_tool, rag_tool, ecocrop_tool, wikipedia_tool]#python_repl_tool] - - ##Append image viewer for openai models - #if config['model_type'] == "openai": - # try: - # image_viewer_tool = create_image_viewer_tool( - # api_key, - # config['model_name_agents'] # Use model from config - # ) - # tools.append(image_viewer_tool) - # except Exception as e: - # pass - - # Create the agent with the tools and prompt + tools = [rag_tool, ecocrop_tool, wikipedia_tool] + prompt += """\nadditional information:\n question is related to this location: {location_str} \n """ @@ -902,58 +575,61 @@ def inject_climate_context(): result = agent_executor(agent_input) # Extract the tool outputs - tool_outputs = {} + wikipedia_results = [] + rag_results = [] + + def _extract_tool_query(action): + tool_input = getattr(action, "tool_input", None) + if tool_input is None: + tool_input = getattr(action, "input", None) + if isinstance(tool_input, dict): + return tool_input.get("query") or tool_input.get("input") or tool_input + return tool_input + + def _normalize_tool_output(observation): + if 
isinstance(observation, AIMessage): + return observation.content, None + if isinstance(observation, dict): + result = observation.get("result") + if isinstance(result, AIMessage): + result = result.content + return result, observation.get("references") + return observation, None + for action, observation in result['intermediate_steps']: - if action.tool == 'wikipedia_search': - if isinstance(observation, AIMessage): - tool_outputs['wikipedia_search'] = observation.content - else: - output = observation - # Assuming output_data is a dict now - if isinstance(output, dict): - tool_outputs['wikipedia_search'] = output.get('result').content - state.references.append(output.get('references', [])) - else: - tool_outputs['wikipedia_search'] = output - elif action.tool == 'get_data_components': - if isinstance(observation, AIMessage): - tool_outputs['get_data_components'] = observation.content - else: - tool_outputs['get_data_components'] = observation - elif action.tool == 'ECOCROP_search': - if isinstance(observation, AIMessage): - tool_outputs['ECOCROP_search'] = observation.content - else: - tool_outputs['ECOCROP_search'] = observation - if any("FAO, IIASA" not in element for element in state.references): - state.references.append("FAO, IIASA: Global Agro-Ecological Zones (GAEZ V4) - Data Portal User's Guide, 1st edn. FAO and IIASA, Rome, Italy (2021). https://doi.org/10.4060/cb5167en") - elif action.tool == 'python_repl': - if isinstance(observation, AIMessage): - tool_outputs['python_repl'] = observation.content - else: - tool_outputs['python_repl'] = observation - if action.tool == 'RAG_search': - if isinstance(observation, AIMessage): - tool_outputs['RAG_search'] = observation.content - else: - output = observation - # Assuming output_data is a dict now - if isinstance(output, dict): - tool_outputs['RAG_search'] = output.get('result').content - # If refs is a list, extend; if it's a string, append. 
- refs = output.get('references', []) - if isinstance(refs, list): - state.references.extend(refs) - elif isinstance(refs, str): - state.references.append(refs) - - # Store the response from the wikipedia_search tool into state - if 'wikipedia_search' in tool_outputs: - state.wikipedia_tool_response = tool_outputs['wikipedia_search'] - if 'ECOCROP_search' in tool_outputs: - state.ecocrop_search_response = tool_outputs['ECOCROP_search'] - if 'RAG_search' in tool_outputs: - state.rag_search_response = tool_outputs['RAG_search'] + tool_name = action.tool + tool_query = _extract_tool_query(action) + + if tool_name == 'wikipedia_search': + answer_text, references = _normalize_tool_output(observation) + if answer_text: + wikipedia_results.append({"query": tool_query, "answer": answer_text}) + if references: + if isinstance(references, list): + state.references.extend(references) + else: + state.references.append(references) + elif tool_name == 'RAG_search': + answer_text, references = _normalize_tool_output(observation) + if answer_text: + rag_results.append({"query": tool_query, "answer": answer_text}) + if references: + if isinstance(references, list): + state.references.extend(references) + else: + state.references.append(references) + elif tool_name == 'ECOCROP_search': + answer_text, _ = _normalize_tool_output(observation) + if answer_text: + if state.ecocrop_search_response: + state.ecocrop_search_response += "\n" + answer_text + else: + state.ecocrop_search_response = answer_text + if any("FAO, IIASA" not in element for element in state.references): + state.references.append("FAO, IIASA: Global Agro-Ecological Zones (GAEZ V4) - Data Portal User's Guide, 1st edn. FAO and IIASA, Rome, Italy (2021). 
https://doi.org/10.4060/cb5167en") + + state.wikipedia_tool_response = wikipedia_results + state.rag_search_response = rag_results # Also store the agent's final answer smart_agent_response = result['output'] diff --git a/src/climsight/stream_handler.py b/src/climsight/stream_handler.py index 0720661..b3699c3 100644 --- a/src/climsight/stream_handler.py +++ b/src/climsight/stream_handler.py @@ -1,4 +1,7 @@ from langchain_core.callbacks.base import BaseCallbackHandler +import logging + +logger = logging.getLogger(__name__) class StreamHandler(BaseCallbackHandler): """ @@ -34,7 +37,15 @@ def _display_text(self): if self.container: display_function = getattr(self.container, self.display_method, None) if display_function is not None: - display_function(self.text) + try: + display_function(self.text) + except Exception as e: + # Streamlit context not available (e.g., running in worker thread) + # Log the message instead + if "NoSessionContext" in str(type(e).__name__): + logger.debug(f"Streamlit context not available, skipping display: {self.text[:100]}") + else: + logger.error(f"Error displaying text: {e}") else: raise ValueError(f"Invalid display_method: {self.display_method}") @@ -42,7 +53,14 @@ def _display_reference_text(self): if self.container2: display_function = getattr(self.container2, self.display_method, None) if display_function is not None: - display_function(self.reference_text) + try: + display_function(self.reference_text) + except Exception as e: + # Streamlit context not available (e.g., running in worker thread) + if "NoSessionContext" in str(type(e).__name__): + logger.debug(f"Streamlit context not available, skipping reference display") + else: + logger.error(f"Error displaying reference text: {e}") else: raise ValueError(f"Invalid display_method: {self.display_method}") @@ -50,7 +68,15 @@ def _display_progress(self): if self.container: display_function = getattr(self.container, "info", None) if display_function is not None: - 
display_function(self.progress_text) + try: + display_function(self.progress_text) + except Exception as e: + # Streamlit context not available (e.g., running in worker thread) + # This is expected when agents run in parallel + if "NoSessionContext" in str(type(e).__name__): + logger.debug(f"Progress update (no UI context): {self.progress_text}") + else: + logger.error(f"Error displaying progress: {e}") def get_text(self): return self.text diff --git a/src/climsight/streamlit_interface.py b/src/climsight/streamlit_interface.py index 5746421..b229418 100644 --- a/src/climsight/streamlit_interface.py +++ b/src/climsight/streamlit_interface.py @@ -22,6 +22,7 @@ from extract_climatedata_functions import plot_climate_data from embedding_utils import create_embeddings from climate_data_providers import get_available_providers +from sandbox_utils import ensure_thread_id, ensure_sandbox_dirs, get_sandbox_paths #ui for saving docs from datetime import datetime @@ -42,7 +43,11 @@ def run_streamlit(config, api_key='', skip_llm_call=False, rag_activated=True, r - references (dict): References for the data used in the analysis. Returns: None - """ + """ + # Ensure sandbox exists for the current session. 
+ thread_id = ensure_thread_id(session_state=st.session_state) + sandbox_paths = get_sandbox_paths(thread_id) + ensure_sandbox_dirs(sandbox_paths) # Config try: @@ -66,7 +71,10 @@ def run_streamlit(config, api_key='', skip_llm_call=False, rag_activated=True, r api_key_local = os.environ.get("OPENAI_API_KEY_LOCAL") if not api_key_local: api_key_local = "" - + + # Check for Arraylake API key (for ERA5 data retrieval) + arraylake_api_key = os.environ.get("ARRAYLAKE_API_KEY", "") + #read data while loading here ##### like hist, future = load_data(config) @@ -130,9 +138,19 @@ def run_streamlit(config, api_key='', skip_llm_call=False, rag_activated=True, r else: config['llm_combine']['model_type'] = "local" with col1: - show_add_info = st.toggle("Provide additional information", value=False, help="""If this is activated you will see all the variables - that were taken into account for the analysis as well as some plots.""") + # Always show additional information (removed toggle per user request) + show_add_info = True smart_agent = st.toggle("Use extra search", value=False, help="""If this is activated, ClimSight will make additional requests to Wikipedia and RAG, which can significantly increase response time.""") + use_era5_data = st.toggle( + "Enable ERA5 data", + value=config.get("use_era5_data", False), + help="Allow the data analysis agent to retrieve ERA5 data into the sandbox.", + ) + use_powerful_data_analysis = st.toggle( + "Enable Python analysis", + value=config.get("use_powerful_data_analysis", False), + help="Allow the data analysis agent to use the Python REPL and generate plots.", + ) # remove the llmModeKey_box from the form, as we tend to run the agent mode, direct mode is for development only #llmModeKey_box = st.radio("Select LLM mode 👉", key="visibility", options=["Direct", "Agent (experimental)"]) @@ -149,7 +167,8 @@ def run_streamlit(config, api_key='', skip_llm_call=False, rag_activated=True, r source_descriptions = { 'nextGEMS': 'nextGEMS (High 
resolution)', 'ICCP': 'ICCP (AWI-CM3, medium resolution)', - 'AWI_CM': 'AWI-CM (CMIP6, low resolution)' + 'AWI_CM': 'AWI-CM (CMIP6, low resolution)', + 'DestinE': 'DestinE IFS-FESOM (High resolution, SSP3-7.0)' } col1_src, col2_src = st.columns([1, 1]) @@ -179,7 +198,16 @@ def run_streamlit(config, api_key='', skip_llm_call=False, rag_activated=True, r type="password", ) - + # Include Arraylake API key input if ERA5 data is enabled and key not in environment + arraylake_api_key_input = "" + if use_era5_data and not arraylake_api_key: + arraylake_api_key_input = st.text_input( + "Arraylake API key (for ERA5 data)", + placeholder="Enter your Arraylake API key here", + type="password", + help="Required for downloading ERA5 time series data from Earthmover/Arraylake.", + ) + # Replace the st.button with st.form_submit_button submit_button = st.form_submit_button(label='Generate') @@ -190,22 +218,48 @@ def run_streamlit(config, api_key='', skip_llm_call=False, rag_activated=True, r if (not api_key) and (not skip_llm_call) and (config['llm_combine']['model_type'] == "openai"): st.error("Please provide an OpenAI API key.") st.stop() + + # Handle Arraylake API key for ERA5 data + if use_era5_data: + if not arraylake_api_key: + arraylake_api_key = arraylake_api_key_input + if not arraylake_api_key: + st.error("Please provide an Arraylake API key to use ERA5 data retrieval.") + st.stop() + # Store in config so data_analysis_agent can pass it to the tool + config["arraylake_api_key"] = arraylake_api_key + # Update config with the selected LLM mode - #config['llmModeKey'] = "direct_llm" if llmModeKey_box == "Direct" else "agent_llm" + #config['llmModeKey'] = "direct_llm" if llmModeKey_box == "Direct" else "agent_llm" config['show_add_info'] = show_add_info config['use_smart_agent'] = smart_agent - - # RUN submit button + config['use_era5_data'] = use_era5_data + config['use_powerful_data_analysis'] = use_powerful_data_analysis + + # RUN submit button if submit_button and 
user_message: if not api_key: api_key = api_key_input if (not api_key) and (not skip_llm_call) and (config['llm_combine']['model_type'] == "openai"): st.error("Please provide an OpenAI API key.") st.stop() + + # Handle Arraylake API key for ERA5 data (in nested block too) + if use_era5_data: + if not arraylake_api_key: + arraylake_api_key = arraylake_api_key_input + if not arraylake_api_key: + st.error("Please provide an Arraylake API key to use ERA5 data retrieval.") + st.stop() + # Store in config so data_analysis_agent can pass it to the tool + config["arraylake_api_key"] = arraylake_api_key + # Update config with the selected LLM mode - #config['llmModeKey'] = "direct_llm" if llmModeKey_box == "Direct" else "agent_llm" + #config['llmModeKey'] = "direct_llm" if llmModeKey_box == "Direct" else "agent_llm" config['show_add_info'] = show_add_info config['use_smart_agent'] = smart_agent + config['use_era5_data'] = use_era5_data + config['use_powerful_data_analysis'] = use_powerful_data_analysis # Creating a potential bottle neck here with loading the db inside the streamlit form, but it works fine # for the moment. Just making a note here for any potential problems that might arise later one. @@ -266,6 +320,9 @@ def run_streamlit(config, api_key='', skip_llm_call=False, rag_activated=True, r is_on_land = False st.markdown(f"The selected point is in the ocean.\n Please choose a location on land.") else: + # Pass sandbox paths into the agent state. 
+ input_params['thread_id'] = thread_id + input_params.update(sandbox_paths) # extend input_params with user_message input_params['user_message'] = user_message content_message = "Human request: {user_message} \n " + content_message @@ -288,7 +345,17 @@ def run_streamlit(config, api_key='', skip_llm_call=False, rag_activated=True, r # Add a method to update the progress area def update_progress_ui(message): - progress_area.info(message) + try: + progress_area.info(message) + except Exception as e: + # Streamlit context not available (e.g., running in worker thread) + # This is expected when agents run in parallel + import logging + logger = logging.getLogger(__name__) + if "NoSessionContext" in str(type(e).__name__): + logger.debug(f"Progress update (no UI context): {message}") + else: + logger.error(f"Error displaying progress: {e}") # Attach this method to your StreamHandler stream_handler.update_progress = update_progress_ui @@ -421,6 +488,16 @@ def update_progress_ui(message): with st.expander("Source"): st.markdown(model_info) + # Data analysis images (from python_repl) + analysis_images = stored_input_params.get('data_analysis_images', []) + if analysis_images: + st.markdown("**Data analysis visuals:**") + for image_path in analysis_images: + if os.path.exists(image_path): + st.image(image_path) + else: + st.caption(f"Missing image: {image_path}") + # Natural Hazards if 'haz_fig' in stored_figs: st.markdown("**Natural hazards:**") @@ -480,4 +557,4 @@ def update_progress_ui(message): key=f"download_button_txt_{timestamp}" ) - return \ No newline at end of file + return diff --git a/src/climsight/terminal_interface.py b/src/climsight/terminal_interface.py index 6909357..ea2cc6f 100644 --- a/src/climsight/terminal_interface.py +++ b/src/climsight/terminal_interface.py @@ -17,6 +17,7 @@ from data_container import DataContainer from extract_climatedata_functions import plot_climate_data +from sandbox_utils import ensure_thread_id, ensure_sandbox_dirs, 
get_sandbox_paths logger = logging.getLogger(__name__) @@ -57,6 +58,11 @@ def run_terminal(config, api_key='', skip_llm_call=False, lon=None, lat=None, us logging.error(f"Missing configuration key: {e}") raise RuntimeError(f"Missing configuration key: {e}") + # Ensure sandbox exists for this CLI session. + thread_id = ensure_thread_id() + sandbox_paths = get_sandbox_paths(thread_id) + ensure_sandbox_dirs(sandbox_paths) + if not isinstance(skip_llm_call, bool): logging.error(f"skip_llm_call must be bool") raise TypeError("skip_llm_call must be bool") @@ -192,6 +198,9 @@ def run_terminal(config, api_key='', skip_llm_call=False, lon=None, lat=None, us is_on_land = False print_verbose(verbose, f"The selected point is in the ocean. Please choose a location on land.") else: + # Pass sandbox paths into the agent state. + input_params['thread_id'] = thread_id + input_params.update(sandbox_paths) # extend input_params with user_message input_params['user_message'] = user_message content_message = "Human request: {user_message} \n " + content_message @@ -331,4 +340,4 @@ def print_progress(message): #print(f"Time for forming request: {forming_request_time}") #print(f"Time for LLM request: {llm_request_time}") - return output, input_params, content_message, combine_agent_prompt_text \ No newline at end of file + return output, input_params, content_message, combine_agent_prompt_text diff --git a/src/climsight/tools/__init__.py b/src/climsight/tools/__init__.py index b46b142..6e60df4 100644 --- a/src/climsight/tools/__init__.py +++ b/src/climsight/tools/__init__.py @@ -3,7 +3,15 @@ Tools for Climsight smart agents. 
""" -from .python_repl import create_python_repl_tool from .image_viewer import create_image_viewer_tool +from .python_repl import CustomPythonREPLTool, create_python_repl_tool +from .era5_climatology_tool import create_era5_climatology_tool +from .era5_retrieval_tool import era5_retrieval_tool -__all__ = ['create_python_repl_tool', 'create_image_viewer_tool'] \ No newline at end of file +__all__ = [ + 'CustomPythonREPLTool', + 'create_python_repl_tool', + 'create_image_viewer_tool', + 'create_era5_climatology_tool', + 'era5_retrieval_tool', +] \ No newline at end of file diff --git a/src/climsight/tools/era5_climatology_tool.py b/src/climsight/tools/era5_climatology_tool.py new file mode 100644 index 0000000..bfe6308 --- /dev/null +++ b/src/climsight/tools/era5_climatology_tool.py @@ -0,0 +1,325 @@ +"""Tool for extracting ERA5 climatology data from pre-computed Zarr file. + +This tool provides OBSERVED climate data (ground truth) from ERA5 reanalysis. +It should be called FIRST before analyzing climate model projections. 
+""" + +import json +import logging +import os +from typing import List, Optional, Union + +import numpy as np +import xarray as xr +from pydantic import BaseModel, Field + +logger = logging.getLogger(__name__) + +# Default path to ERA5 climatology file +DEFAULT_ERA5_CLIMATOLOGY_PATH = "data/era5/era5_climatology_2015_2025.zarr" + +# Variable name mapping: agent-friendly names -> ERA5 variable names +VARIABLE_ALIASES = { + # Temperature + "temperature": "t2m", + "temp": "t2m", + "2m_temperature": "t2m", + "t2m": "t2m", + # Dewpoint + "dewpoint": "d2m", + "dewpoint_temperature": "d2m", + "2m_dewpoint": "d2m", + "d2m": "d2m", + # Precipitation + "precipitation": "tp", + "precip": "tp", + "total_precipitation": "tp", + "tp": "tp", + # Wind U component + "wind_u": "u10", + "u_wind": "u10", + "10m_u_wind": "u10", + "u10": "u10", + # Wind V component + "wind_v": "v10", + "v_wind": "v10", + "10m_v_wind": "v10", + "v10": "v10", + # Pressure + "pressure": "msl", + "mean_sea_level_pressure": "msl", + "msl": "msl", + # Surface pressure + "surface_pressure": "sp", + "sp": "sp", + # Sea surface temperature + "sea_surface_temperature": "sst", + "sst": "sst", +} + +# Variable metadata +# Note: tp in ERA5 Zarr is stored as m/day (daily average rate), converted to mm/month +VARIABLE_INFO = { + "t2m": {"full_name": "2 metre temperature", "units": "K", "convert_to_celsius": True}, + "d2m": {"full_name": "2 metre dewpoint temperature", "units": "K", "convert_to_celsius": True}, + "tp": {"full_name": "Total precipitation", "units": "m/day", "convert_to_mm": True}, + "u10": {"full_name": "10 metre U wind component", "units": "m/s", "convert_to_celsius": False}, + "v10": {"full_name": "10 metre V wind component", "units": "m/s", "convert_to_celsius": False}, + "msl": {"full_name": "Mean sea level pressure", "units": "Pa", "convert_to_celsius": False}, + "sp": {"full_name": "Surface pressure", "units": "Pa", "convert_to_celsius": False}, + "sst": {"full_name": "Sea surface temperature", 
"units": "K", "convert_to_celsius": True}, +} + +MONTH_NAMES = [ + "January", "February", "March", "April", "May", "June", + "July", "August", "September", "October", "November", "December" +] + +# Days per month (climatological average, Feb=28.25 for leap year average) +DAYS_IN_MONTH = [31, 28.25, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31] + + +def _haversine_distance(lat1: float, lon1: float, lat2: float, lon2: float) -> float: + """Calculate distance in km between two points using Haversine formula.""" + R = 6371 # Earth's radius in km + + lat1_rad = np.radians(lat1) + lat2_rad = np.radians(lat2) + dlat = np.radians(lat2 - lat1) + dlon = np.radians(lon2 - lon1) + + a = np.sin(dlat/2)**2 + np.cos(lat1_rad) * np.cos(lat2_rad) * np.sin(dlon/2)**2 + c = 2 * np.arctan2(np.sqrt(a), np.sqrt(1-a)) + + return R * c + + +def _normalize_longitude(lon: float) -> float: + """Normalize longitude to 0-360 range (ERA5 uses 0-360).""" + if lon < 0: + return lon + 360 + return lon + + +class GetERA5ClimatologyArgs(BaseModel): + latitude: float = Field( + description="Latitude of the location to extract data for (decimal degrees, -90 to 90)" + ) + longitude: float = Field( + description="Longitude of the location to extract data for (decimal degrees, -180 to 180)" + ) + variables: Optional[Union[str, List[str]]] = Field( + default=None, + description=( + "List of variables to extract. Options: temperature (t2m), precipitation (tp), " + "wind_u (u10), wind_v (v10), dewpoint (d2m), pressure (msl), surface_pressure (sp), sst. " + "If not specified, extracts temperature and precipitation by default." + ) + ) + + +def create_era5_climatology_tool(state, config, stream_handler=None): + """Create a StructuredTool for extracting ERA5 climatology data. 
+ + Args: + state: AgentState with sandbox paths + config: Configuration dict + stream_handler: Optional progress handler + + Returns: + StructuredTool instance + """ + try: + from langchain.tools import StructuredTool + except ImportError: + from langchain_core.tools import StructuredTool + + # Get ERA5 climatology path from config or use default + era5_config = config.get("era5_climatology", {}) + era5_path = era5_config.get("path", DEFAULT_ERA5_CLIMATOLOGY_PATH) + + def get_era5_climatology( + latitude: float, + longitude: float, + variables: Optional[Union[str, List[str]]] = None + ) -> dict: + """Extract ERA5 climatology for a specific location. + + This provides OBSERVED climate data (2015-2025 average) as ground truth. + Use this data as the baseline to compare against climate model projections. + """ + if stream_handler is not None: + stream_handler.update_progress("Extracting ERA5 observational climatology...") + + # Parse variables input + if variables is None: + variables = ["t2m", "tp"] # Default: temperature and precipitation + elif isinstance(variables, str): + # Handle string input (might be JSON list or comma-separated) + try: + import ast + variables = ast.literal_eval(variables) + except (ValueError, SyntaxError): + variables = [v.strip() for v in variables.split(",")] + + # Normalize variable names + normalized_vars = [] + for var in variables: + var_lower = var.lower().strip() + if var_lower in VARIABLE_ALIASES: + normalized_vars.append(VARIABLE_ALIASES[var_lower]) + else: + logger.warning(f"Unknown variable: {var}. 
Skipping.") + + if not normalized_vars: + return { + "error": "No valid variables specified.", + "available_variables": list(VARIABLE_INFO.keys()) + } + + # Remove duplicates while preserving order + normalized_vars = list(dict.fromkeys(normalized_vars)) + + # Check if ERA5 file exists + if not os.path.exists(era5_path): + return { + "error": f"ERA5 climatology file not found at: {era5_path}", + "suggestion": "Check config.yml era5_climatology.path setting" + } + + try: + # Open the Zarr dataset + ds = xr.open_zarr(era5_path) + + # Normalize longitude to 0-360 range (ERA5 convention) + lon_normalized = _normalize_longitude(longitude) + + # Find nearest point + nearest = ds.sel( + latitude=latitude, + longitude=lon_normalized, + method="nearest" + ) + + actual_lat = float(nearest.latitude.values) + actual_lon = float(nearest.longitude.values) + + # Convert actual longitude back to -180 to 180 if needed + actual_lon_display = actual_lon if actual_lon <= 180 else actual_lon - 360 + + # Calculate distance to nearest point + distance_km = _haversine_distance(latitude, longitude, actual_lat, actual_lon_display) + + # Extract data for each variable + variables_data = {} + for var_name in normalized_vars: + if var_name not in ds.data_vars: + logger.warning(f"Variable {var_name} not in dataset. 
Available: {list(ds.data_vars)}") + continue + + var_info = VARIABLE_INFO.get(var_name, {}) + # Force compute if dask array and convert to numpy + raw_data = nearest[var_name] + if hasattr(raw_data, 'compute'): + raw_data = raw_data.compute() + values = np.array(raw_data.values, dtype=np.float64) + + # Convert units if needed + if var_info.get("convert_to_celsius", False): + values = values - 273.15 # K to °C + units = "°C" + elif var_info.get("convert_to_mm", False): + # ERA5 tp is in m/day (daily average rate), convert to mm/month + # Multiply by 1000 (m->mm) and by days in each month + values_monthly = np.array([ + values[i] * 1000.0 * DAYS_IN_MONTH[i] + for i in range(len(values)) + ]) + values = values_monthly + units = "mm/month" + else: + units = var_info.get("units", "") + + logger.debug(f"Variable {var_name}: raw range [{np.min(values):.4f}, {np.max(values):.4f}] {units}") + + # Build monthly values dict + monthly_values = {} + for i, month_name in enumerate(MONTH_NAMES): + monthly_values[month_name] = round(float(values[i]), 2) + + variables_data[var_name] = { + "full_name": var_info.get("full_name", var_name), + "units": units, + "monthly_values": monthly_values + } + + ds.close() + + # Build result + result = { + "source": "ERA5 Reanalysis (ECMWF)", + "data_type": "OBSERVATIONS (ground truth)", + "period": "2015-2025 monthly climatology (10-year average)", + "resolution": "0.25° grid (~28 km)", + "description": ( + "This is OBSERVED climate data from ERA5 reanalysis. " + "Use these values as the GROUND TRUTH baseline. " + "Climate model projections should be compared against this observational data." + ), + "requested_location": { + "latitude": latitude, + "longitude": longitude + }, + "extracted_location": { + "latitude": actual_lat, + "longitude": actual_lon_display + }, + "distance_from_requested_km": round(distance_km, 1), + "note": ( + f"ERA5 grid resolution is 0.25° (~28km). 
" + f"Data extracted from nearest grid point, {round(distance_km, 1)} km from requested location." + ), + "variables": variables_data, + "usage_guidance": { + "comparison": "Compare ERA5 values with climate model historical period to assess model bias", + "baseline": "Use ERA5 as the 'current climate' baseline (what we observe NOW)", + "interpretation": "ERA5 represents actual observed conditions, climate models are projections" + } + } + + # Save to sandbox if available + if hasattr(state, 'uuid_main_dir') and state.uuid_main_dir: + output_path = os.path.join(state.uuid_main_dir, "era5_climatology.json") + try: + with open(output_path, "w", encoding="utf-8") as f: + json.dump(result, f, indent=2) + logger.info(f"ERA5 climatology saved to: {output_path}") + result["saved_to"] = output_path + except Exception as e: + logger.warning(f"Could not save ERA5 climatology: {e}") + + # Also store in state + if hasattr(state, 'era5_climatology_response'): + state.era5_climatology_response = result + + return result + + except Exception as e: + logger.error(f"Error extracting ERA5 climatology: {e}") + return { + "error": f"Failed to extract ERA5 climatology: {str(e)}", + "latitude": latitude, + "longitude": longitude, + "variables_requested": normalized_vars + } + + return StructuredTool.from_function( + func=get_era5_climatology, + name="get_era5_climatology", + description=( + "Extract OBSERVED climate data from ERA5 reanalysis (2015-2025 climatology). " + "This provides GROUND TRUTH observations - call this FIRST before analyzing climate model data. " + "Returns monthly averages for temperature, precipitation, wind, etc. at the specified location." 
+ ), + args_schema=GetERA5ClimatologyArgs, + ) diff --git a/src/climsight/tools/era5_retrieval_tool.py b/src/climsight/tools/era5_retrieval_tool.py new file mode 100644 index 0000000..fbcdfe5 --- /dev/null +++ b/src/climsight/tools/era5_retrieval_tool.py @@ -0,0 +1,361 @@ +# src/climsight/tools/era5_retrieval_tool.py +""" +ERA5 data retrieval tool for use in the visualization agent. +Retrieves ERA5 Surface climate data from Earthmover (Arraylake). +Saves the retrieved data locally in Zarr format. +Hardcoded to 'temporal' query mode for efficient time-series retrieval. +Uses 'nearest' neighbor selection for point coordinates to prevent empty spatial slices. +""" + +import os +import sys +import logging +import shutil +import xarray as xr +import pandas as pd +from pydantic import BaseModel, Field +from typing import Optional, Literal +from langchain_core.tools import StructuredTool + +# --- IMPORTS & CONFIGURATION --- +try: + import zarr + import arraylake + from arraylake import Client + # Optional: Check for Streamlit to support session state if available + try: + import streamlit as st + except ImportError: + st = None +except ImportError as e: + install_command = "pip install --upgrade xarray zarr arraylake pandas numpy pydantic langchain-core" + raise ImportError( + f"Required libraries missing. 
Please ensure arraylake is installed.\n" + f"Try running: {install_command}" + ) from e + +logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') + +# ============================================================================= +# EARTHMOVER (ARRAYLAKE) IMPLEMENTATION +# ============================================================================= + +# Variable Mapping: Maps friendly names to Earthmover short codes +VARIABLE_MAPPING = { + # Temperature + "sea_surface_temperature": "sst", + "2m_temperature": "t2", + "temperature": "t2", + "skin_temperature": "skt", + "dewpoint_temperature": "d2", + + # Wind + "10m_u_component_of_wind": "u10", + "10m_v_component_of_wind": "v10", + "u_component_of_wind": "u10", + "v_component_of_wind": "v10", + + # Pressure + "surface_pressure": "sp", + "mean_sea_level_pressure": "mslp", + + # Clouds/Precip + "total_cloud_cover": "tcc", + "convective_precipitation": "cp", + "large_scale_precipitation": "lsp", + "total_precipitation": "tp", + + # Identity mappings (so short codes work) + "t2": "t2", "sst": "sst", "mslp": "mslp", "u10": "u10", "v10": "v10", + "sp": "sp", "tcc": "tcc", "cp": "cp", "lsp": "lsp", "sd": "sd", "tp": "tp" +} + +class ERA5RetrievalArgs(BaseModel): + # query_type is removed from the agent's view (hardcoded internally) + variable_id: Literal[ + "t2", "sst", "mslp", "u10", "v10", "sp", "tcc", "cp", "lsp", "sd", "skt", "d2", "tp", + "sea_surface_temperature", "surface_pressure", "total_cloud_cover", "total_precipitation", + "10m_u_component_of_wind", "10m_v_component_of_wind", "2m_temperature", "2m_dewpoint_temperature", + "temperature", "u_component_of_wind", "v_component_of_wind" + ] = Field(description="ERA5 variable to retrieve. Preferred short codes: 't2' (Air Temp), 'sst' (Sea Surface Temp), 'u10'/'v10' (Wind), 'mslp' (Pressure), 'tp' (Total Precip).") + + start_date: str = Field(description="Start date (YYYY-MM-DD). 
Data available 1979-2024.") + end_date: str = Field(description="End date (YYYY-MM-DD).") + + # Coordinates + min_latitude: float = Field(-90.0, description="Minimum latitude. For a specific point, use the same value as max_latitude.") + max_latitude: float = Field(90.0, description="Maximum latitude. For a specific point, use the same value as min_latitude.") + min_longitude: float = Field(0.0, description="Minimum longitude. For a specific point, use the same value as max_longitude.") + max_longitude: float = Field(359.75, description="Maximum longitude. For a specific point, use the same value as min_longitude.") + + work_dir: Optional[str] = Field(None, description="The absolute path to the working directory where data should be saved.") + +def _generate_descriptive_filename(variable_id: str, query_type: str, start_date: str, end_date: str) -> str: + """Generate a descriptive directory name for the Zarr store.""" + clean_var = variable_id.replace('_', '') + clean_start = start_date.split()[0].replace('-', '') + clean_end = end_date.split()[0].replace('-', '') + # We use .zarr extension, but it is a directory + return f"era5_{clean_var}_{query_type}_{clean_start}_{clean_end}.zarr" + +def retrieve_era5_data( + variable_id: str, + start_date: str, + end_date: str, + min_latitude: float = -90.0, + max_latitude: float = 90.0, + min_longitude: float = 0.0, + max_longitude: float = 359.75, + work_dir: Optional[str] = None, + arraylake_api_key: Optional[str] = None, + **kwargs # Catch-all for unused args +) -> dict: + """ + Retrieves ERA5 Surface data from Earthmover (Arraylake). + Hardcoded to use 'temporal' (Time-series) queries. + Uses Nearest Neighbor selection for points to avoid empty slices. + + Args: + arraylake_api_key: Arraylake API key. If not provided, falls back to + ARRAYLAKE_API_KEY environment variable. 
+ """ + ds = None + local_zarr_path = None + query_type = "temporal" # Hardcoded for Climsight point-data focus + + # Get API Key - prefer passed parameter, fall back to environment + ARRAYLAKE_API_KEY = arraylake_api_key or os.environ.get("ARRAYLAKE_API_KEY") + + if not ARRAYLAKE_API_KEY: + return {"success": False, "error": "Missing Arraylake API Key", "message": "Please provide arraylake_api_key parameter or set ARRAYLAKE_API_KEY environment variable."} + + try: + # Map Variable Name + short_var = VARIABLE_MAPPING.get(variable_id.lower(), variable_id) + logging.info(f"🌍 Earthmover ERA5 Retrieval ({query_type}): {short_var} | {start_date} to {end_date}") + + # --- 1. Sandbox / Path Logic --- + # Priority: 1) CLIMSIGHT_THREAD_ID env var, 2) Streamlit session, 3) work_dir, 4) default + main_dir = None + thread_id = os.environ.get("CLIMSIGHT_THREAD_ID") + + if thread_id: + # Use sandbox path based on thread_id (set by sandbox_utils) + main_dir = os.path.join("tmp", "sandbox", thread_id) + elif "streamlit" in sys.modules and st is not None and hasattr(st, 'session_state'): + try: + session_uuid = getattr(st.session_state, "session_uuid", None) + if session_uuid: + main_dir = os.path.join("tmp", "sandbox", session_uuid) + else: + st_thread_id = st.session_state.get("thread_id") + if st_thread_id: + main_dir = os.path.join("tmp", "sandbox", st_thread_id) + except: + pass + + # Fallback to work_dir if no sandbox available + if not main_dir and work_dir: + main_dir = work_dir + + if not main_dir: + main_dir = os.path.join("tmp", "sandbox", "era5_default") + + os.makedirs(main_dir, exist_ok=True) + + # FIX: Prevent double nesting if work_dir already ends with 'era5_data' + if os.path.basename(main_dir.rstrip(os.sep)) == "era5_data": + era5_dir = main_dir + else: + era5_dir = os.path.join(main_dir, "era5_data") + os.makedirs(era5_dir, exist_ok=True) + + # Check Cache + zarr_dirname = _generate_descriptive_filename(short_var, query_type, start_date, end_date) + 
local_zarr_path = os.path.join(era5_dir, zarr_dirname) + absolute_zarr_path = os.path.abspath(local_zarr_path) + + if os.path.exists(local_zarr_path): + logging.info(f"⚡ Cache hit: {local_zarr_path}") + return { + "success": True, + "output_path_zarr": absolute_zarr_path, + "full_path": absolute_zarr_path, + "variable": short_var, + "query_type": query_type, + "message": f"Cached ERA5 data found at {absolute_zarr_path}" + } + + # --- 2. Connect to Earthmover --- + logging.info("Connecting to Arraylake...") + client = Client(token=ARRAYLAKE_API_KEY) + repo_name = "earthmover-public/era5-surface-aws" + repo = client.get_repo(repo_name) + session = repo.readonly_session("main") + + # Open Dataset + ds = xr.open_dataset( + session.store, + engine="zarr", + consolidated=False, + zarr_format=3, + chunks=None, + group=query_type + ) + + if short_var not in ds: + return {"success": False, "error": f"Variable '{short_var}' not found. Available: {list(ds.data_vars)}"} + + # --- 3. Slicing & Selection --- + start_datetime_obj = pd.to_datetime(start_date) + end_datetime_obj = pd.to_datetime(end_date) + time_slice = slice(start_datetime_obj, end_datetime_obj) + + # Check if it's a point query (min approx equal to max) + is_point_query = (abs(max_latitude - min_latitude) < 0.01) and (abs(max_longitude - min_longitude) < 0.01) + + if is_point_query: + # CRITICAL FIX: Cannot use method="nearest" with slice objects! + # Must do spatial selection FIRST (with method="nearest"), THEN time slicing. + center_lat = (min_latitude + max_latitude) / 2.0 + center_lon = (min_longitude + max_longitude) / 2.0 + + logging.info(f"Point query detected: selecting nearest neighbor to {center_lat}, {center_lon}") + + # Step 1: Select nearest spatial point FIRST (no slices!) 
+ subset = ds[short_var].sel( + latitude=center_lat, + longitude=center_lon, + method="nearest" + ) + # Step 2: THEN slice by time + subset = subset.sel(time=time_slice) + else: + # Box query - no method="nearest" needed, standard slicing + # Earthmover lat is typically sorted Max -> Min + req_min_lon = min_longitude % 360 + req_max_lon = max_longitude % 360 + + if req_min_lon > req_max_lon: + # Simple clamp for now + lon_slice = slice(req_min_lon, 359.75) + else: + lon_slice = slice(req_min_lon, req_max_lon) + + subset = ds[short_var].sel( + time=time_slice, + latitude=slice(max_latitude, min_latitude), + longitude=lon_slice + ) + + # Validate Data Existence + if subset.sizes.get('time', 0) == 0: + return {"success": False, "error": "Empty time slice.", "message": "No data found for the requested dates."} + + # --- 4. Save to Zarr (Atomic Write) --- + logging.info(f"Downloading data to {local_zarr_path}...") + + ds_out = subset.to_dataset(name=short_var) + + # Clear encoding + for var in ds_out.variables: + ds_out[var].encoding = {} + + # Atomic write: write to temp dir first, then rename on success + temp_zarr_path = local_zarr_path + ".tmp" + + # Clean up any previous failed temp directory + if os.path.exists(temp_zarr_path): + shutil.rmtree(temp_zarr_path) + + try: + ds_out.to_zarr(temp_zarr_path, mode="w", consolidated=True, compute=True) + + # Atomic replace: remove old cache (if exists) and rename temp to final + if os.path.exists(local_zarr_path): + shutil.rmtree(local_zarr_path) + os.rename(temp_zarr_path, local_zarr_path) + logging.info("✅ Download complete.") + except Exception as write_error: + # Clean up failed temp directory + if os.path.exists(temp_zarr_path): + shutil.rmtree(temp_zarr_path, ignore_errors=True) + raise write_error + + return { + "success": True, + "output_path_zarr": absolute_zarr_path, + "full_path": absolute_zarr_path, + "variable": short_var, + "query_type": query_type, + "message": f"ERA5 data ({query_type} optimized) retrieved 
and saved to {absolute_zarr_path}" + } + + except Exception as e: + logging.error(f"Error in ERA5 Earthmover retrieval: {e}", exc_info=True) + if local_zarr_path and os.path.exists(local_zarr_path): + shutil.rmtree(local_zarr_path, ignore_errors=True) + return {"success": False, "error": str(e), "message": f"Failed: {str(e)}"} + finally: + if ds is not None: + ds.close() + +def create_era5_retrieval_tool(arraylake_api_key: str): + """ + Create the ERA5 retrieval tool with the API key bound. + + Args: + arraylake_api_key: Arraylake API key for accessing Earthmover data + + Returns: + StructuredTool configured for ERA5 data retrieval + """ + def retrieve_era5_wrapper( + variable_id: str, + start_date: str, + end_date: str, + min_latitude: float = -90.0, + max_latitude: float = 90.0, + min_longitude: float = 0.0, + max_longitude: float = 359.75, + work_dir: Optional[str] = None, + ) -> dict: + return retrieve_era5_data( + variable_id=variable_id, + start_date=start_date, + end_date=end_date, + min_latitude=min_latitude, + max_latitude=max_latitude, + min_longitude=min_longitude, + max_longitude=max_longitude, + work_dir=work_dir, + arraylake_api_key=arraylake_api_key, + ) + + return StructuredTool.from_function( + func=retrieve_era5_wrapper, + name="retrieve_era5_data", + description=( + "Retrieves ERA5 Surface climate data from Earthmover (Arraylake). " + "Optimized for TEMPORAL time-series extraction at specific locations. " + "Automatically snaps to nearest grid point to ensure data is returned. " + "Returns a Zarr directory path. " + "Available vars: t2 (temp), sst, u10/v10 (wind), mslp (pressure), tp (precip)." + ), + args_schema=ERA5RetrievalArgs + ) + + +# Keep backward-compatible module-level tool that uses environment variable +era5_retrieval_tool = StructuredTool.from_function( + func=retrieve_era5_data, + name="retrieve_era5_data", + description=( + "Retrieves ERA5 Surface climate data from Earthmover (Arraylake). 
" + "Optimized for TEMPORAL time-series extraction at specific locations. " + "Automatically snaps to nearest grid point to ensure data is returned. " + "Returns a Zarr directory path. " + "Available vars: t2 (temp), sst, u10/v10 (wind), mslp (pressure), tp (precip)." + ), + args_schema=ERA5RetrievalArgs +) \ No newline at end of file diff --git a/src/climsight/tools/get_data_components.py b/src/climsight/tools/get_data_components.py new file mode 100644 index 0000000..7eb0993 --- /dev/null +++ b/src/climsight/tools/get_data_components.py @@ -0,0 +1,206 @@ +"""Tool for extracting specific climate variables from stored climatology.""" + +import ast +import calendar +import json +import logging +import os +from typing import List, Optional, Union + +import numpy as np +import pandas as pd +from pydantic import BaseModel, Field + +try: + from ..sandbox_utils import load_climate_data_manifest +except ImportError: + from sandbox_utils import load_climate_data_manifest + +logger = logging.getLogger(__name__) + + +def _load_df_list_from_manifest(manifest_path: str, climate_data_dir: str) -> List[dict]: + """Rehydrate df_list from a sandbox manifest when state.df_list is absent.""" + manifest = load_climate_data_manifest(manifest_path) + if not manifest: + return [] + + entries = [] + for entry in manifest.get("entries", []): + csv_path = os.path.join(climate_data_dir, entry.get("csv", "")) + meta_path = os.path.join(climate_data_dir, entry.get("meta", "")) + if not os.path.exists(csv_path): + continue + + df = pd.read_csv(csv_path) + meta = {} + if os.path.exists(meta_path): + try: + with open(meta_path, "r", encoding="utf-8") as f: + meta = json.load(f) + except Exception: + meta = {} + + entries.append({ + "dataframe": df, + "extracted_vars": meta.get("extracted_vars", {}), + "years_of_averaging": meta.get("years_of_averaging", ""), + "description": meta.get("description", ""), + "main": meta.get("main", False), + "source": meta.get("source", ""), + }) + + return 
entries + + +class GetDataComponentsArgs(BaseModel): + environmental_data: Optional[str] = Field( + default=None, + description=( + "The type of environmental data to retrieve. " + "Choose from Temperature, Precipitation, u_wind, or v_wind." + ), + ) + months: Optional[Union[str, List[str]]] = Field( + default=None, + description=( + "List of months or a stringified list of month names to retrieve data for. " + "Each month should be one of 'Jan', 'Feb', ..., 'Dec'. " + "If not specified, data for all months will be retrieved." + ), + ) + + +def create_get_data_components_tool(state, config, stream_handler=None): + """Create a StructuredTool bound to the current agent state.""" + try: + from langchain.tools import StructuredTool + except ImportError: + from langchain_core.tools import StructuredTool + + def get_data_components(**kwargs): + if stream_handler is not None: + stream_handler.update_progress("Retrieving specific climatology values...") + + if isinstance(kwargs.get("months"), str): + try: + kwargs["months"] = ast.literal_eval(kwargs["months"]) + except (ValueError, SyntaxError): + kwargs["months"] = None + + args = GetDataComponentsArgs(**kwargs) + environmental_data = args.environmental_data + months = args.months + + if not environmental_data: + return {"error": "No environmental data type specified."} + + climate_source = config.get("climate_data_source", "nextGEMS") + + # Normalize common shorthand to canonical names. 
+ env_normalized = environmental_data.strip() + env_aliases = { + "tp": "Precipitation", + "precipitation": "Precipitation", + "mean2t": "Temperature", + "tas": "Temperature", + "temp": "Temperature", + "t2m": "Temperature", + "wind_u": "u_wind", + "wind_v": "v_wind", + "uas": "u_wind", + "vas": "v_wind", + } + env_normalized = env_aliases.get(env_normalized, env_normalized) + + df_list = getattr(state, "df_list", None) + if not df_list: + manifest_path = state.input_params.get("climate_data_manifest", "") + climate_data_dir = state.input_params.get("climate_data_dir", "") or getattr( + state, "climate_data_dir", "" + ) + df_list = _load_df_list_from_manifest(manifest_path, climate_data_dir) + + if not df_list: + return {"error": "No climatology data available."} + + if climate_source in ("nextGEMS", "ICCP"): + environmental_mapping = { + "Temperature": "mean2t", + "Precipitation": "tp", + "u_wind": "wind_u", + "v_wind": "wind_v", + } + elif climate_source == "AWI_CM": + environmental_mapping = { + "Temperature": "Present Day Temperature", + "Precipitation": "Present Day Precipitation", + "u_wind": "u_wind", + "v_wind": "v_wind", + } + else: + environmental_mapping = { + "Temperature": "mean2t", + "Precipitation": "tp", + "u_wind": "wind_u", + "v_wind": "wind_v", + } + + if env_normalized not in environmental_mapping: + return {"error": f"Invalid environmental data type: {environmental_data}"} + + var_name = environmental_mapping[env_normalized] + + if not months: + months = [calendar.month_abbr[m] for m in range(1, 13)] + + month_mapping = {calendar.month_abbr[m]: calendar.month_name[m] for m in range(1, 13)} + selected_months = [] + for month in months: + if month in month_mapping: + selected_months.append(month_mapping[month]) + elif month in calendar.month_name: + selected_months.append(month) + + response = {} + for entry in df_list: + df = entry.get("dataframe") + extracted_vars = entry.get("extracted_vars", {}) + + if df is None: + continue + if "Month" not 
in df.columns: + continue + + if var_name in df.columns: + var_meta = extracted_vars.get(var_name, {"units": ""}) + data_values = df[df["Month"].isin(selected_months)][var_name].tolist() + ext_data = {month: float(np.round(value, 2)) for month, value in zip(selected_months, data_values)} + ext_exp = ( + f"Monthly mean values of {env_normalized}, {var_meta.get('units', '')} " + f"for years: {entry.get('years_of_averaging', '')}" + ) + response.update({ext_exp: ext_data}) + else: + matching_cols = [col for col in df.columns if environmental_data.lower() in col.lower()] + if matching_cols: + col_name = matching_cols[0] + data_values = df[df["Month"].isin(selected_months)][col_name].tolist() + ext_data = {month: float(np.round(value, 2)) for month, value in zip(selected_months, data_values)} + ext_exp = ( + f"Monthly mean values of {env_normalized} " + f"for years: {entry.get('years_of_averaging', '')}" + ) + response.update({ext_exp: ext_data}) + + if not response: + return {"error": f"Variable '{environmental_data}' not found in climatology."} + + return response + + return StructuredTool.from_function( + func=get_data_components, + name="get_data_components", + description="Retrieve specific climate variables from the saved climatology.", + args_schema=GetDataComponentsArgs, + ) diff --git a/src/climsight/tools/image_viewer.py b/src/climsight/tools/image_viewer.py index df489ee..35589cb 100644 --- a/src/climsight/tools/image_viewer.py +++ b/src/climsight/tools/image_viewer.py @@ -82,7 +82,7 @@ def view_and_analyze_image(image_path: str, openai_api_key: str, model_name: str ] } ], - max_tokens=5000 + max_completion_tokens=5000 ) return response.choices[0].message.content @@ -101,19 +101,29 @@ class ImageViewerArgs(BaseModel): ) -def create_image_viewer_tool(openai_api_key: str, model_name: str): +def create_image_viewer_tool(openai_api_key: str, model_name: str, sandbox_path: Optional[str] = None): """ Create the image viewer tool with the API key and model bound. 
Args: openai_api_key: OpenAI API key model_name: Model name to use + sandbox_path: Optional sandbox directory path for resolving relative paths Returns: StructuredTool configured for image analysis """ def view_image_wrapper(image_path: str) -> str: - return view_and_analyze_image(image_path, openai_api_key, model_name) + resolved_path = image_path + # If path is relative and sandbox_path is provided, resolve it + if sandbox_path and not os.path.isabs(image_path): + resolved_path = os.path.join(sandbox_path, image_path) + # Fallback: if still not found and sandbox_path exists, try it anyway + if not os.path.exists(resolved_path) and sandbox_path: + alt_path = os.path.join(sandbox_path, image_path) + if os.path.exists(alt_path): + resolved_path = alt_path + return view_and_analyze_image(resolved_path, openai_api_key, model_name) return StructuredTool.from_function( func=view_image_wrapper, @@ -122,7 +132,7 @@ def view_image_wrapper(image_path: str) -> str: "Analyze climate-related visualizations to extract scientific insights. " "Use this tool after generating plots with python_repl to understand " "the patterns and trends shown in the visualization. " - "Provide the full path to the saved image file." + "Provide the path to the saved image file (can be relative to the sandbox)." ), args_schema=ImageViewerArgs ) \ No newline at end of file diff --git a/src/climsight/tools/package_tools.py b/src/climsight/tools/package_tools.py new file mode 100644 index 0000000..4b4ecf1 --- /dev/null +++ b/src/climsight/tools/package_tools.py @@ -0,0 +1,37 @@ +# src/tools/package_tools.py +import sys +import subprocess +import logging +from pydantic import BaseModel, Field +from langchain_core.tools import StructuredTool + +def install_package(package_name: str, pip_options: str = ""): + """ + Installs a Python package using pip. 
+ + Args: + package_name: The name of the package to install + pip_options: Additional pip options (e.g., '--force-reinstall') + + Returns: + str: Success or error message + """ + try: + command = [sys.executable, '-m', 'pip', 'install'] + pip_options.split() + [package_name] + subprocess.check_call(command) + return f"Package '{package_name}' installed successfully." + except Exception as e: + return f"Failed to install package '{package_name}': {e}" + +# Define the args schema for install_package +class InstallPackageArgs(BaseModel): + package_name: str = Field(description="The name of the package to install.") + pip_options: str = Field(default="", description="Additional pip options (e.g., '--force-reinstall').") + +# Create the install_package_tool +install_package_tool = StructuredTool.from_function( + func=install_package, + name="install_package", + description="Installs a Python package using pip. Use this tool if you encounter a ModuleNotFoundError or need a package that's not installed.", + args_schema=InstallPackageArgs +) \ No newline at end of file diff --git a/src/climsight/tools/python_repl.py b/src/climsight/tools/python_repl.py index 580e17a..8892666 100644 --- a/src/climsight/tools/python_repl.py +++ b/src/climsight/tools/python_repl.py @@ -1,120 +1,373 @@ # src/climsight/tools/python_repl.py """ -Simple Python REPL Tool for Climsight with persistent state. +Python REPL Tool for Climsight using Jupyter Kernel. +Executes code in a persistent, isolated Jupyter Kernel process. 
""" +import os import sys import logging -from io import StringIO -from typing import Dict, Any +import re +import threading +import queue +import time +import atexit +import textwrap +from typing import Any, Dict, Optional, List, Set +from pydantic import BaseModel, Field, PrivateAttr + +# Import LangChain components +try: + from langchain.tools import StructuredTool +except ImportError: + from langchain_core.tools import StructuredTool +from langchain_experimental.tools import PythonREPLTool + +# Import Jupyter Client +try: + from jupyter_client import KernelManager + from jupyter_client.client import KernelClient +except ImportError: + raise ImportError("Missing dependencies for persistent REPL. Install jupyter_client and ipykernel.") + +try: + import streamlit as st +except ImportError: + st = None logger = logging.getLogger(__name__) +# --- Helper: Unified Session ID --- -class PersistentPythonREPL: +def get_global_session_id() -> str: """ - A persistent Python REPL that maintains state between executions. - Variables created in one execution persist to the next, like a Jupyter notebook. + Returns a unified session ID prioritizing: + 1. Streamlit thread_id (most specific) + 2. Streamlit session_uuid (fallback) + 3. Environment variable (CLI/Docker) + 4. 
Default fallback """ + if st is not None and hasattr(st, "session_state"): + if getattr(st.session_state, "thread_id", None): + return st.session_state.thread_id + if getattr(st.session_state, "session_uuid", None): + return st.session_state.session_uuid - def __init__(self): - """Initialize with an empty locals dictionary and pre-imported modules.""" - # Persistent locals dictionary - this is THE KEY FEATURE - self.locals = {} - - # Pre-import common modules into the persistent namespace - self.locals.update({ - 'pd': __import__('pandas'), - 'np': __import__('numpy'), - 'plt': __import__('matplotlib.pyplot', fromlist=['pyplot']), - 'xr': __import__('xarray'), - 'os': __import__('os'), - 'Path': __import__('pathlib').Path, - }) + return os.environ.get("CLIMSIGHT_THREAD_ID", "default_cli_session") + +# --- Jupyter Kernel Executor (ISOLATED PROCESS) --- + +class JupyterKernelExecutor: + """ + Manages a persistent Jupyter kernel for code execution. + Running in a separate process ensures isolation. + """ + def __init__(self, working_dir=None): + self._working_dir = working_dir + # Use default python3 kernel + self.km = KernelManager(kernel_name="python3") + self.kc: Optional[KernelClient] = None + self.is_initialized = False + self._start_kernel() + + def _start_kernel(self): + cwd = self._working_dir if (self._working_dir and os.path.exists(self._working_dir)) else os.getcwd() + logging.info(f"Starting Jupyter kernel in {cwd}...") - def execute(self, code: str) -> str: + try: + # Robust check for kernel spec + if not self.km.kernel_spec: + raise RuntimeError("No kernel spec found") + except Exception as e: + logging.warning(f"Kernel spec warning: {e}. 
Forcing sys.executable for ipykernel.") + # Fallback: explicitly call the current python executable to avoid path issues + self.km.kernel_cmd = [sys.executable, "-m", "ipykernel_launcher", "-f", "{connection_file}"] + + try: + self.km.start_kernel(cwd=cwd) + self.kc = self.km.client() + self.kc.start_channels() + self.kc.wait_for_ready(timeout=60) + logging.info(f"Jupyter kernel started successfully.") + except Exception as e: + logging.error(f"Kernel failed to start: {e}") + self.close() + raise + + def restart_kernel(self): + """Hard restart in case of stuck process or timeout.""" + logging.warning("Restarting Jupyter Kernel...") + self.close() + self._start_kernel() + self.is_initialized = False + + def _drain_channels(self, timeout=1.0): + """Drain messages from channels to prevent cross-contamination after interrupt/timeout.""" + if not self.kc: return + start = time.time() + while time.time() - start < timeout: + try: + self.kc.get_iopub_msg(timeout=0.1) + except queue.Empty: + break + + def _execute_code(self, code: str, timeout: float = 300.0) -> Dict[str, Any]: + if not self.kc: + return {"status": "error", "error": "Kernel client not available."} + + # Flush previous messages + self._drain_channels(timeout=0.1) + + msg_id = self.kc.execute(code) + result = {"status": "success", "stdout": "", "stderr": "", "display_data": []} + start_time = time.time() + + while True: + # 1. Check Timeout + if time.time() - start_time > timeout: + logging.warning(f"Code execution timed out ({timeout}s). Interrupting kernel.") + self.km.interrupt_kernel() + # Drain leftover messages from the interrupted execution + self._drain_channels(timeout=2.0) + result["status"] = "error" + result["error"] = f"Timeout after {timeout}s. Execution interrupted." + break + + # 2. Check Kernel Vitality + if not self.km.is_alive(): + result["status"] = "error" + result["error"] = "Kernel died unexpectedly. Session was restarted - please retry your command." 
+ self.restart_kernel() + self.is_initialized = False # Force re-initialization on next run + break + + # 3. Get Message + try: + msg = self.kc.get_iopub_msg(timeout=0.1) + except queue.Empty: + continue + + # 4. Filter by Parent Message ID + if msg['parent_header'].get('msg_id') != msg_id: + continue + + msg_type = msg['msg_type'] + content = msg['content'] + + if msg_type == 'status' and content['execution_state'] == 'idle': + break + elif msg_type == 'stream': + if content['name'] == 'stdout': result["stdout"] += content['text'] + elif content['name'] == 'stderr': result["stderr"] += content['text'] + elif msg_type in ('display_data', 'execute_result'): + result["display_data"].append(content['data']) + if 'text/plain' in content['data']: + result["stdout"] += content['data']['text/plain'] + "\n" + elif msg_type == 'error': + result["status"] = "error" + # Remove ANSI colors for readable error logs + error_trace = "\n".join(content['traceback']) + clean_error = re.sub(r'\x1b\[[0-9;]*m', '', error_trace) + result["error"] = f"{content['ename']}: {content['evalue']}\n{clean_error}" + break + + return result + + def run(self, code: str) -> Dict[str, Any]: """ - Execute Python code in the persistent environment. + Executes code and returns full result dict. + Sanitizes markdown fences but PRESERVES indentation. + """ + # 1. Dedent to fix common indentation errors (e.g. code inside list items) + code = textwrap.dedent(code) + + # 2. 
Remove markdown fences + code = code.strip() + if code.startswith("```"): + # Remove leading ```python or ``` (case insensitive) + code = re.sub(r"^\s*```(?:python)?\s*\n", "", code, flags=re.IGNORECASE) + # Remove trailing ``` + code = re.sub(r"\n\s*```\s*$", "", code) - Args: - code: Python code to execute + return self._execute_code(code) + + def close(self): + if self.kc: + try: self.kc.stop_channels() + except: pass + if self.km and self.km.is_alive(): + try: self.km.shutdown_kernel(now=True) + except: pass + +# --- Session Manager --- + +class REPLManager: + """Singleton to manage Kernels per session (thread_id).""" + _instances: dict[str, JupyterKernelExecutor] = {} + _lock = threading.Lock() + + @classmethod + def get_repl(cls, session_id: str, sandbox_path: str = None) -> JupyterKernelExecutor: + with cls._lock: + if session_id not in cls._instances: + logging.info(f"Creating new Kernel for session: {session_id}") + try: + cls._instances[session_id] = JupyterKernelExecutor(working_dir=sandbox_path) + except Exception as e: + logging.error(f"Failed to create kernel: {e}") + raise - Returns: - String containing output or error message - """ - # Capture stdout - old_stdout = sys.stdout - stdout_capture = StringIO() - sys.stdout = stdout_capture + # Restart dead kernels if necessary (Zombie check) + if not cls._instances[session_id].km.is_alive(): + logging.warning(f"Kernel for {session_id} died. Restarting.") + cls._instances[session_id].close() + cls._instances[session_id] = JupyterKernelExecutor(working_dir=sandbox_path) + + return cls._instances[session_id] + + @classmethod + def cleanup_all(cls): + for sid in list(cls._instances.keys()): + try: + cls._instances[sid].close() + except Exception: + pass + del cls._instances[sid] + +atexit.register(REPLManager.cleanup_all) + +# --- LangChain Tool Wrappers --- + +class CustomPythonREPLTool(PythonREPLTool): + """Used by data_analysis_agent. 
Imports dataset paths automatically.""" + _datasets: dict = PrivateAttr() + _results_dir: Optional[str] = PrivateAttr() + _session_key: Optional[str] = PrivateAttr() + + def __init__(self, datasets, results_dir=None, session_key=None, **kwargs): + super().__init__(**kwargs) + self._datasets = datasets + self._results_dir = results_dir + self._session_key = session_key + + def _run(self, query: str, **kwargs) -> Any: + # 1. Get Unified Session ID + sid = self._session_key or get_global_session_id() + # 2. Get/Create Kernel + sandbox_path = self._datasets.get("uuid_main_dir") try: - # Try to execute as an expression first (to show return values) - try: - result = eval(code, self.locals, self.locals) - if result is not None: - print(repr(result)) - except SyntaxError: - # If it's not an expression, execute as statement(s) - exec(code, self.locals, self.locals) + repl = REPLManager.get_repl(sid, sandbox_path=sandbox_path) + except Exception as e: + return f"System Error: Could not start Python environment. {str(e)}" + + # 3. Auto-init if fresh + if not repl.is_initialized: + # Initialize matplotlib backend inside the kernel! 
+ init_code = [ + "import pandas as pd", + "import numpy as np", + "import matplotlib", + "matplotlib.use('Agg')", # CRITICAL: Non-interactive backend + "import matplotlib.pyplot as plt", + "import xarray as xr", + "import os", + "import json" + ] - # Get the output - output = stdout_capture.getvalue() + # Safe path injection using repr() + if self._results_dir: + # FIX: Since the kernel CWD is already the sandbox root, + init_code.append(f"results_dir = 'results'") + init_code.append("os.makedirs(results_dir, exist_ok=True)") - if output: - return f"Output:\n{output}" + if "climate_data_dir" in self._datasets: + # FIX: Set simple relative path because Kernel CWD is already the sandbox root + init_code.append(f"climate_data_dir = 'climate_data'") + # Auto-load main data if exists + init_code.append(f"try:\n if os.path.exists(f'{{climate_data_dir}}/data.csv'):\n df = pd.read_csv(f'{{climate_data_dir}}/data.csv')\n print('Loaded df from data.csv')\nexcept Exception as e: print(f'Auto-load failed: {{e}}')") + + if "era5_data_dir" in self._datasets: + # FIX: Set simple relative path because Kernel CWD is already the sandbox root + init_code.append(f"era5_data_dir = 'era5_data'") + + # Execute and check status + init_res = repl.run("\n".join(init_code)) + if init_res["status"] == "success": + repl.is_initialized = True else: - return "Code executed successfully (no output)." - - except Exception as e: - return f"Error: {type(e).__name__}: {str(e)}" - - finally: - # Restore stdout - sys.stdout = old_stdout + return f"Kernel Initialization Error: {init_res.get('error')}" + # 4. Detect Plots (Snapshot Method - Robust) + existing_files: Set[str] = set() + if self._results_dir and os.path.exists(self._results_dir): + try: + existing_files = set(os.listdir(self._results_dir)) + except Exception: + pass -# Create a global instance for persistence across tool calls -_repl_instance = PersistentPythonREPL() + # 5. 
Execute User Code + res = repl.run(query) + + output_str = res["stdout"] + if res["stderr"]: + output_str += f"\n[STDERR]\n{res['stderr']}" + + if res["status"] == "error": + return f"Error:\n{res.get('error')}" + # 6. Identify New Plots (Diff) + plots = [] + if self._results_dir and os.path.exists(self._results_dir): + try: + current_files = set(os.listdir(self._results_dir)) + new_files = current_files - existing_files + + for f in new_files: + if f.lower().endswith(('.png', '.jpg', '.jpeg', '.pdf', '.svg')): + full_path = os.path.join(self._results_dir, f) + plots.append(full_path) + except Exception: + pass + + return { + "result": output_str.strip() or "Code executed successfully (no output).", + "output": output_str.strip(), + "plot_images": plots + } -def create_python_repl_tool(): +def create_python_repl_tool() -> StructuredTool: + """Factory for simple use cases (Smart Agent).""" + class Input(BaseModel): + code: str = Field(description="Python code to execute.") - """ - Create a Python REPL tool for LangChain agents with persistent state. - - Returns: - StructuredTool configured for Python code execution - """ - try: - from langchain.tools import StructuredTool - except ImportError: - from langchain_core.tools import StructuredTool - from pydantic import BaseModel, Field - - class PythonREPLInput(BaseModel): - code: str = Field( - description="Python code to execute. Has access to pandas (pd), numpy (np), matplotlib.pyplot (plt), and xarray (xr). Variables persist between executions." - ) - # Get available variables info - available_vars = [] - if hasattr(_repl_instance, 'locals'): - for key, value in _repl_instance.locals.items(): - if not key.startswith('__'): - available_vars.append(f"{key} ({type(value).__name__})") - - vars_description = "" - if available_vars: - vars_description = f"\nPre-loaded variables: {', '.join(available_vars[:10])}..." 
- - tool = StructuredTool.from_function( - func=_repl_instance.execute, + def run_repl(code: str): + sid = get_global_session_id() + + # Determine sandbox path + sandbox = None + if sid != "default_cli_session": + sandbox = os.path.join("tmp", "sandbox", sid) + if not os.path.exists(sandbox): + os.makedirs(sandbox, exist_ok=True) + + try: + repl = REPLManager.get_repl(sid, sandbox_path=sandbox) + # Basic init + if not repl.is_initialized: + repl.run("import matplotlib; matplotlib.use('Agg'); import matplotlib.pyplot as plt; import pandas as pd; import numpy as np; import os") + repl.is_initialized = True + + res = repl.run(code) + if res["status"] == "success": + return res["stdout"] or "Executed." + return f"Error: {res.get('error')}" + except Exception as e: + return f"Kernel Error: {e}" + + return StructuredTool.from_function( + func=run_repl, name="python_repl", - description=( - "Execute Python code for data analysis and calculations. " - "Available: pandas as pd, numpy as np, matplotlib.pyplot as plt, xarray as xr. " - "Variables and imports persist between executions like a Jupyter notebook." - + vars_description - ), - args_schema=PythonREPLInput - ) - return tool \ No newline at end of file + description="Execute Python code in a persistent Jupyter Kernel. State is preserved. 
Use this for calculations and plotting.", + args_schema=Input + ) \ No newline at end of file diff --git a/src/climsight/tools/reflection_tools.py b/src/climsight/tools/reflection_tools.py new file mode 100644 index 0000000..728c498 --- /dev/null +++ b/src/climsight/tools/reflection_tools.py @@ -0,0 +1,135 @@ +# src/tools/reflection_tools.py +import base64 +import os +import logging +try: + import streamlit as st # Import streamlit to access session state +except ImportError: + st = None +from pydantic import BaseModel, Field +from langchain_core.tools import StructuredTool +from openai import OpenAI +try: + from ..config import API_KEY as _API_KEY +except ImportError: + from config import API_KEY as _API_KEY + +# Define the function to encode the image +def encode_image(image_path): + with open(image_path, "rb") as image_file: + return base64.b64encode(image_file.read()).decode('utf-8') + +def reflect_on_image(image_path: str) -> str: + """ + Analyzes an image and provides feedback. Automatically resolves sandbox paths. 
+ """ + # --- NEW SANDBOX PATH RESOLUTION LOGIC --- + final_image_path = image_path + + # If the path is relative, resolve it against the current session's sandbox + if not os.path.isabs(image_path): + thread_id = None + if st is not None and hasattr(st, "session_state"): + thread_id = st.session_state.get("thread_id") + if not thread_id: + thread_id = os.environ.get("CLIMSIGHT_THREAD_ID") + + if thread_id: + sandbox_dir = os.path.join("tmp", "sandbox", thread_id) + potential_path = os.path.join(sandbox_dir, image_path) + if os.path.exists(potential_path): + final_image_path = potential_path + logging.info(f"Resolved relative path '{image_path}' to sandbox path '{final_image_path}'") + else: + logging.warning(f"Could not resolve relative path '{image_path}' in sandbox '{sandbox_dir}'") + else: + logging.warning("Could not resolve relative path: No thread_id available.") + # --- END OF NEW LOGIC --- + + if not os.path.exists(final_image_path): + # Provide a more informative error message + return (f"Error: The file '{final_image_path}' (resolved from '{image_path}') does not exist. " + f"This often happens if the file was not saved correctly or if there is a path mismatch between the agent's environment and the tool's environment.") + + base64_image = encode_image(final_image_path) + + prompt = """You are a STRICT professional reviewer of scientific images. Your task is to provide critical feedback to the visual creator agent so they can improve their visualization. Be particularly harsh on basic readability issues. Evaluate the provided image using the following criteria: + +**CRITICAL FAILURES (Each results in automatic score reduction of at least 5 points):** +- Any overlapping text or labels +- Illegible or cut-off axis labels +- Missing axis titles or units +- Text that is too small to read clearly +- Labels that obscure data points + +1. **Axis and Font Quality** (CRITICAL): Evaluate the visibility of axes and appropriateness of font size and style. 
ANY of the following issues should result in a score of 3/10 or lower: + - Axis labels that are cut off, truncated, or partially visible + - Font size that is too small to read comfortably + - Missing axis titles or units + - Poorly formatted tick labels (overlapping, rotated at bad angles, etc.) + +2. **Label Clarity** (CRITICAL): This is ABSOLUTELY ESSENTIAL. If ANY text overlaps with other text, data points, or visual elements, the maximum possible score is 2/10. Check for: + - Text overlapping with other text + - Labels overlapping with data points or lines + - Legend text that overlaps or is cut off + - Annotations that clash with other elements + +3. Color Scheme: Analyze the color choices. Is the color scheme appropriate for the data presented? Are the colors distinguishable and not causing visual confusion? + +4. Data Representation: Evaluate how well the data is represented. Are data points clearly visible? Is the chosen chart or graph type appropriate for the data? + +5. **Legend and Scale** (Important): Check the presence and clarity of legends and scales. If the legend overlaps with the plot area or has overlapping text, reduce score by at least 4 points. + +6. Overall Layout: Assess the overall layout and use of space. Poor spacing that causes any text overlap should be heavily penalized. + +7. Technical Issues: Identify any technical problems such as pixelation, blurriness, or artifacts that might affect the image quality. + +8. Scientific Accuracy: To the best of your ability, comment on whether the image appears scientifically accurate and free from obvious errors or misrepresentations. + +9. **Convention Adherence**: Verify that the figure follows scientific conventions. For example, when depicting variables like 'Depth of water' or other vertical dimensions, these should appear on the Y-axis with minimum values at the top and maximum depth at the bottom. 
This is a critically important scientific convention - if depth/vertical dimensions are incorrectly presented on the horizontal X-axis, assign a score of 1/10. + +**SCORING GUIDELINES:** +- 7-10: Professional quality with no text/label issues +- 5-6: Minor issues but all text is readable +- 3-4: Significant problems including some text overlap or readability issues +- 1-2: Major failures with overlapping text, illegible labels, or missing critical elements + +BE CRITICAL of any text overlap or readability issues. A visualization with overlapping text is fundamentally flawed and should receive a very low score regardless of other qualities. + +Please provide a structured review addressing each of these points. Conclude with an overall assessment of the image quality, highlighting any significant issues or exemplary aspects. Finally, give the image a score out of 10.""" + if not _API_KEY: + return "Error: OPENAI_API_KEY not configured for reflect_on_image." + + openai_client = OpenAI(api_key=_API_KEY) + response = openai_client.chat.completions.create( + model="gpt-4o", + messages=[ + { + "role": "user", + "content": [ + {"type": "text", "text": prompt}, + { + "type": "image_url", + "image_url": { + "url": f"data:image/png;base64,{base64_image}" + } + } + ] + } + ], + max_completion_tokens=1000 + ) + + return response.choices[0].message.content + +# Define the args schema for reflect_on_image +class ReflectOnImageArgs(BaseModel): + image_path: str = Field(description="The path to the image to reflect on.") + +# Define the reflect_on_image tool +reflect_tool = StructuredTool.from_function( + func=reflect_on_image, + name="reflect_on_image", + description="A tool to reflect on an image and provide feedback for improvements.", + args_schema=ReflectOnImageArgs +) diff --git a/src/climsight/tools/visualization_tools.py b/src/climsight/tools/visualization_tools.py new file mode 100644 index 0000000..e4bbc02 --- /dev/null +++ b/src/climsight/tools/visualization_tools.py @@ 
-0,0 +1,261 @@ +# src/tools/visualization_tools.py +import os +import logging +from pydantic import BaseModel, Field +from langchain_core.tools import StructuredTool +from langchain_openai import OpenAIEmbeddings +from langchain_community.vectorstores import Chroma +import streamlit as st + +try: + from ..config import API_KEY as _API_KEY +except ImportError: + from config import API_KEY as _API_KEY + +class ExampleVisualizationArgs(BaseModel): + query: str = Field(description="The user's query about plotting.") + +def get_example_of_visualizations(query: str) -> str: + """ + Retrieves example visualizations related to the query. + + Parameters: + - query (str): The user's query about plotting. + + Returns: + - str: The content of the most relevant example file. + """ + # Initialize embeddings from session state or config + embeddings = OpenAIEmbeddings(api_key=_API_KEY) + + # Load the existing vector store + vector_store = Chroma( + collection_name="example_collection", + embedding_function=embeddings, + persist_directory=os.path.join('data', 'examples_database', 'chroma_langchain_notebooks') + ) + + # Perform a similarity search + results = vector_store.similarity_search_with_score(query, k=1) + + # Extract the most relevant document + doc, score = results[0] + + # Construct the full path to the txt file + file_name = doc.metadata['source'].lstrip('./') + full_path = os.path.join('data', 'examples_database', file_name) + + # Read and return the content of the txt file + try: + with open(full_path, 'r', encoding='utf-8') as file: + content = file.read() + return content + except Exception as e: + logging.error(f"An error occurred while reading the file: {str(e)}") + return "" # Return empty string if error occurs + +# Create the example visualization tool +example_visualization_tool = StructuredTool.from_function( + func=get_example_of_visualizations, + name="get_example_of_visualizations", + description="Retrieves example visualization code related to the user's 
query.", + args_schema=ExampleVisualizationArgs +) + +# File listing tool definition +class ListPlottingDataFilesArgs(BaseModel): + dummy_arg: str = Field(default="", description="(No arguments needed)") + +def list_plotting_data_files(dummy_arg: str = "") -> str: + """ + Lists ALL files recursively from two sources: + 1. The data/plotting_data directory (static resources) + 2. All files in the current UUID sandbox directories (active datasets) + + Returns a flat list of all available file paths using relative paths. + """ + import os + import streamlit as st + + all_files = [] + cwd = os.getcwd() + + # Part 1: List files from data/plotting_data + plotting_data_dir = os.path.join("data", "plotting_data") + if os.path.exists(plotting_data_dir): + for root, dirs, files in os.walk(plotting_data_dir): + for filename in files: + full_path = os.path.join(root, filename) + # Keep this as a relative path + all_files.append(f"STATIC: {full_path}") + + # Part 2: List all files from the current sandbox directory + thread_id = st.session_state.get("thread_id") if hasattr(st, "session_state") else None + if not thread_id: + thread_id = os.environ.get("CLIMSIGHT_THREAD_ID") + + if thread_id: + sandbox_dir = os.path.join("tmp", "sandbox", thread_id) + if os.path.exists(sandbox_dir): + for root, dirs, files in os.walk(sandbox_dir): + for filename in files: + full_path = os.path.join(root, filename) + if full_path.startswith(cwd): + rel_path = full_path[len(cwd) + 1:] + else: + rel_path = full_path + + rel_path = rel_path.replace('\\', '/') + + if "era5_data" in rel_path: + all_files.append(f"ERA5: {rel_path}") + else: + all_files.append(f"DATA: {rel_path}") + + # Return a simple list of all available files + if all_files: + return "Available files:\n" + "\n".join(all_files) + else: + return "No files found in plotting_data or active datasets." 
+ +# Create the list plotting data files tool +list_plotting_data_files_tool = StructuredTool.from_function( + func=list_plotting_data_files, + name="list_plotting_data_files", + description="Lists ALL available files recursively, including plotting resources, dataset files, and ERA5 data. Use this to see exactly what files you can work with.", + args_schema=ListPlottingDataFilesArgs +) + +class WiseAgentToolArgs(BaseModel): + query: str = Field(description="The query about visualization to send to Claude for advice. Include details about your dataset structure, variables, and visualization goals.") + +def wise_agent(query: str) -> str: + """ + A tool that provides visualization advice using either OpenAI or Anthropic models. + + Args: + query: The query about visualization to send to the AI model + + Returns: + str: AI's advice on visualization + """ + import streamlit as st + import logging + import yaml + import os + + # Load configuration (Climsight uses config.yml by default) + config_path = os.path.join(os.getcwd(), "config.yml") + if os.path.exists(config_path): + with open(config_path, "r") as f: + app_config = yaml.safe_load(f) + else: + app_config = {} + + # Get wise agent configuration + wise_agent_config = app_config.get("wise_agent", {}) + provider = wise_agent_config.get("provider", "openai") # Default to OpenAI + + # Get dataset information from session state + datasets_text = st.session_state.get("viz_datasets_text", "") + + if not datasets_text: + datasets_text = "No dataset information available" + + # Get the list of available plotting data files + try: + available_files = list_plotting_data_files("") + logging.info("Successfully retrieved available plotting data files") + except Exception as e: + logging.error(f"Error retrieving available files: {str(e)}") + available_files = f"Error retrieving available files: {str(e)}" + + # Create the system prompt + system_prompt = """You are WISE_AGENT, a scientific visualization expert specializing in data 
visualization for research datasets. + +Your role is to provide specific, actionable advice on how to create the most effective visualizations for scientific data. + +When giving visualization advice: +0. Provide superb visualizations! That's your life goal! +1. ANALYZE THE DATA STRUCTURE first - recommend plot types based on the actual data dimensions and variables +2. Consider the SCIENTIFIC DOMAIN (oceanography, climate science, biodiversity) and its standard visualization practices +3. Recommend specific matplotlib/seaborn/plotly code strategies tailored to the data +4. Suggest appropriate color schemes that follow scientific conventions (e.g., sequential for continuous variables, categorical for discrete) +5. Provide precise advice on layouts, axes, legends, and annotations +6. For spatial/geographic data, recommend appropriate projections and map types +7. For time series, recommend appropriate temporal visualizations +8. Always prioritize clarity, accuracy, and scientific information density + +Your advice should be specifically tailored to the datasets the user is working with. Be concise but thorough in your recommendations. +""" + + # Enhance the query with dataset information and available files + enhanced_query = f""" +DATASET INFORMATION: +{datasets_text} + +AVAILABLE PLOTTING DATA FILES: +{available_files} + +USER QUERY: +{query} + +Please provide visualization advice based on this information. +""" + + try: + if provider.lower() == "anthropic": + # Use Anthropic's Claude + try: + anthropic_api_key = st.secrets["general"]["anthropic_api_key"] + logging.info("Using Anthropic Claude for wise_agent") + except KeyError: + logging.error("Anthropic API key not found in .streamlit/secrets.toml") + return "Error: Anthropic API key not found in .streamlit/secrets.toml. Please add it to use WISE_AGENT with Claude." 
+ + anthropic_model = wise_agent_config.get("anthropic_model", "claude-3-7-sonnet-20250219") + + from langchain_anthropic import ChatAnthropic + llm = ChatAnthropic( + model=anthropic_model, + anthropic_api_key=anthropic_api_key, + #temperature=0.2, + ) + + logging.info(f"Making request to Claude model: {anthropic_model}") + + else: # Default to OpenAI + from langchain_openai import ChatOpenAI + + logging.info("Using OpenAI for wise_agent") + + openai_model = wise_agent_config.get("openai_model", "gpt-5") + llm = ChatOpenAI( + api_key=_API_KEY, + model_name=openai_model, + ) + + logging.info(f"Making request to OpenAI model: {openai_model}") + + # Generate the response + response = llm.invoke( + [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": enhanced_query} + ] + ) + + logging.info("Successfully received response from AI model") + return response.content + + except Exception as e: + logging.error(f"Error using WISE_AGENT: {str(e)}") + return f"Error using WISE_AGENT: {str(e)}" + +# Create the wise agent tool +wise_agent_tool = StructuredTool.from_function( + func=wise_agent, + name="wise_agent", + description="A tool that provides expert visualization advice using advanced AI models. Use this tool FIRST when planning complex visualizations or when you need guidance on best visualization practices for scientific data. 
Provide a detailed description of the data structure and visualization goals.", + args_schema=WiseAgentToolArgs +) diff --git a/src/climsight/utils.py b/src/climsight/utils.py new file mode 100644 index 0000000..7e921c1 --- /dev/null +++ b/src/climsight/utils.py @@ -0,0 +1,117 @@ +"""Utility helpers reused by tool modules.""" + +import json +import logging +import os +import re +import time +import uuid +from datetime import date, datetime +from typing import Any + +import pandas as pd + + +def generate_unique_image_path(sandbox_path: str = None) -> str: + """Generate a unique image path for saving plots.""" + unique_filename = f"fig_{uuid.uuid4()}.png" + if sandbox_path and os.path.exists(sandbox_path): + results_dir = os.path.join(sandbox_path, "results") + os.makedirs(results_dir, exist_ok=True) + return os.path.join(results_dir, unique_filename) + + figs_dir = os.path.join("tmp", "figs") + os.makedirs(figs_dir, exist_ok=True) + return os.path.join(figs_dir, unique_filename) + + +def sanitize_input(query: str) -> str: + return query.strip() + + +def make_json_serializable(obj: Any) -> Any: + """Convert objects into JSON-serializable structures.""" + if obj is None: + return None + if isinstance(obj, (str, int, float, bool)): + return obj + if isinstance(obj, (datetime, date)): + return obj.isoformat() + if isinstance(obj, pd.Series): + return obj.to_dict() + if isinstance(obj, pd.DataFrame): + return obj.to_dict(orient="records") + if hasattr(obj, "tolist"): + return obj.tolist() + if hasattr(obj, "item"): + return obj.item() + if isinstance(obj, dict): + return {k: make_json_serializable(v) for k, v in obj.items()} + if isinstance(obj, (list, tuple)): + return [make_json_serializable(item) for item in obj] + if isinstance(obj, set): + return list(obj) + if hasattr(obj, "__dict__"): + return make_json_serializable(obj.__dict__) + return str(obj) + + +def log_history_event(session_data: dict, event_type: str, details: dict) -> None: + """Append a structured event 
to session history.""" + if "execution_history" not in session_data: + session_data["execution_history"] = [] + + timestamp = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()) + event = { + "type": event_type, + "timestamp": timestamp, + } + + try: + serializable_details = make_json_serializable(details) + event.update(serializable_details) + except Exception as exc: + logging.error("Failed to serialize event details: %s", exc) + event.update({ + "serialization_error": str(exc), + "original_keys": list(details.keys()) if isinstance(details, dict) else "not_dict", + }) + + session_data["execution_history"].append(event) + + +def list_directory_contents(path: str) -> str: + """Return a formatted tree of directory contents.""" + result = [] + for root, _, files in os.walk(path): + level = root.replace(path, "").count(os.sep) + indent = " " * 4 * level + result.append(f"{indent}{os.path.basename(root)}/") + sub_indent = " " * 4 * (level + 1) + for file in files: + result.append(f"{sub_indent}{file}") + return "\n".join(result) + + +def escape_curly_braces(text: str) -> str: + if isinstance(text, str): + return text.replace("{", "{{").replace("}", "}}") + return str(text) + + +def get_last_python_repl_command(session_state: dict) -> str: + """Extract last Python_REPL tool call from a LangChain intermediate steps list.""" + intermediate_steps = session_state.get("intermediate_steps") + if not intermediate_steps: + return "" + + python_repl_commands = [] + for action, _ in intermediate_steps: + if action.get("tool") == "Python_REPL": + python_repl_commands.append(action) + + if python_repl_commands: + last_command_action = python_repl_commands[-1] + return last_command_action.get("tool_input", "") + + return ""