From edb945d1fafcbb87a4097bd9d9a6d20669a60c75 Mon Sep 17 00:00:00 2001 From: Ivan Kuznetsov Date: Fri, 16 Jan 2026 18:10:03 +0100 Subject: [PATCH 01/16] Refactor agent architecture: Split smart_agent into info gathering and data analysis MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This commit restructures the multi-agent workflow to improve modularity and enable true parallel execution of information gathering agents. Key Changes: - Split smart_agent responsibilities: * smart_agent: Info gathering only (Wikipedia, RAG, ECOCROP) * data_analysis_agent: Data extraction & visualization (stub for now) - Updated routing logic in climsight_engine.py: * smart_agent now runs in true parallel with other agents (not triggered from data_agent) * Removed route_fromdata() function * All parallel agents → data_analysis_agent → combine_agent - Fixed NoSessionContext errors: * Added exception handling in stream_handler.py * Added exception handling in streamlit_interface.py * Parallel agents can now safely call update_progress() from worker threads - Added data_analysis_response field to AgentState Files modified: 5 Files added: 1 --- src/climsight/climsight_classes.py | 3 +- src/climsight/climsight_engine.py | 58 ++++---- src/climsight/data_analysis_agent.py | 72 +++++++++ src/climsight/smart_agent.py | 209 +++++++++------------------ src/climsight/stream_handler.py | 32 +++- src/climsight/streamlit_interface.py | 12 +- 6 files changed, 219 insertions(+), 167 deletions(-) create mode 100644 src/climsight/data_analysis_agent.py diff --git a/src/climsight/climsight_classes.py b/src/climsight/climsight_classes.py index 0f3c07b..082ff0a 100644 --- a/src/climsight/climsight_classes.py +++ b/src/climsight/climsight_classes.py @@ -21,7 +21,8 @@ class AgentState(BaseModel): ecocrop_search_response: str = "" rag_search_response: str = "" ipcc_rag_agent_response: str = "" - general_rag_agent_response: str = "" + general_rag_agent_response: str = "" + data_analysis_response: str = "" # Response from data analysis agent df_list: list = [] # List of dataframes with climate data references: list = [] # List of references combine_agent_prompt_text: str = "" diff --git a/src/climsight/climsight_engine.py b/src/climsight/climsight_engine.py index 8fa71fd..6d2fe23 100644 --- a/src/climsight/climsight_engine.py +++ b/src/climsight/climsight_engine.py @@ -63,6 +63,8 @@ # import smart_agent from smart_agent import get_aitta_chat_model, smart_agent +# import data_analysis_agent +from data_analysis_agent import data_analysis_agent # import climsight functions from geo_functions import ( @@ -1079,24 +1081,22 @@ def combine_agent(state: AgentState): } def route_fromintro(state: AgentState) -> Sequence[str]: + """Route from intro agent to parallel information gathering agents.""" output = [] if "FINISH" in state.next: return "FINISH" else: + # Always run these agents in parallel output.append("ipcc_rag_agent") output.append("general_rag_agent") output.append("data_agent") output.append("zero_rag_agent") - #output.append("smart_agent") - return output - def route_fromdata(state: AgentState) -> Sequence[str]: - output = [] - if config['use_smart_agent']: - output.append("smart_agent") - else: - output.append("combine_agent") + + # Conditionally add smart_agent based on config + if config.get('use_smart_agent', False): + output.append("smart_agent") return output - + workflow = StateGraph(AgentState) figs = data_pocket.figs @@ -1108,28 +1108,36 @@ def route_fromdata(state: AgentState) -> Sequence[str]: workflow.add_node("ipcc_rag_agent", ipcc_rag_agent) workflow.add_node("general_rag_agent", general_rag_agent) workflow.add_node("data_agent", lambda s: data_agent(s, data, df)) # Pass `data` as argument - workflow.add_node("zero_rag_agent", lambda s: zero_rag_agent(s, figs)) # Pass `figs` as argument + workflow.add_node("zero_rag_agent", lambda s: zero_rag_agent(s, figs)) # Pass `figs` as argument workflow.add_node("smart_agent", lambda s: smart_agent(s, config, api_key, api_key_local, stream_handler)) + workflow.add_node("data_analysis_agent", lambda s: data_analysis_agent(s, config, api_key, api_key_local, stream_handler)) workflow.add_node("combine_agent", combine_agent) - path_map = {'ipcc_rag_agent':'ipcc_rag_agent', 'general_rag_agent':'general_rag_agent', 'data_agent':'data_agent','zero_rag_agent':'zero_rag_agent','FINISH':END} - path_map_data = {'combine_agent':'combine_agent', 'smart_agent':'smart_agent'} + path_map = { + 'ipcc_rag_agent': 'ipcc_rag_agent', + 'general_rag_agent': 'general_rag_agent', + 'data_agent': 'data_agent', + 'zero_rag_agent': 'zero_rag_agent', + 'smart_agent': 'smart_agent', + 'FINISH': END + } - workflow.set_entry_point("intro_agent") # Set the entry point of the graph - + workflow.set_entry_point("intro_agent") # Set the entry point of the graph + + # Route from intro_agent to parallel agents (conditionally includes smart_agent) workflow.add_conditional_edges("intro_agent", route_fromintro, path_map=path_map) - workflow.add_conditional_edges("data_agent", route_fromdata, path_map=path_map_data) - #if config['use_smart_agent']: - # workflow.add_edge(["ipcc_rag_agent","general_rag_agent","data_agent","zero_rag_agent"], "combine_agent") - #else: - workflow.add_edge(["ipcc_rag_agent","general_rag_agent","smart_agent","zero_rag_agent"], "combine_agent") - - #workflow.add_edge("ipcc_rag_agent", "combine_agent") - #workflow.add_edge("general_rag_agent", "combine_agent") - #workflow.add_edge("data_agent", "combine_agent") - #workflow.add_edge("zero_rag_agent", "combine_agent") - #workflow.add_edge("smart_agent", "combine_agent") + # All parallel agents (ipcc_rag, general_rag, data, zero_rag) go to data_analysis_agent + workflow.add_edge(["ipcc_rag_agent", "general_rag_agent", "data_agent", "zero_rag_agent"], "data_analysis_agent") + + # If smart_agent is enabled, it also goes to data_analysis_agent + if config.get('use_smart_agent', False): + workflow.add_edge("smart_agent", "data_analysis_agent") + + # Data analysis agent goes to combine agent + workflow.add_edge("data_analysis_agent", "combine_agent") + + # Combine agent goes to END workflow.add_edge("combine_agent", END) # Compile the graph app = workflow.compile() diff --git a/src/climsight/data_analysis_agent.py b/src/climsight/data_analysis_agent.py new file mode 100644 index 0000000..c3258a0 --- /dev/null +++ b/src/climsight/data_analysis_agent.py @@ -0,0 +1,72 @@ +""" +Data Analysis Agent + +This agent receives all state from previous agents and performs: +1. Climate data extraction from datasets +2. Data post-processing and calculations +3. Visualization and analysis using Python REPL +4. Scientific interpretation of results + +Currently a stub implementation - will be fully implemented in next phase. + +This agent will eventually include: +- get_data_components tool (from smart_agent_backup.py) +- python_repl tool for analysis and visualization +- image_viewer tool for scientific interpretation of plots +- Statistical calculations and trend analysis +""" + +from climsight_classes import AgentState +import logging + +logger = logging.getLogger(__name__) + + +def data_analysis_agent(state: AgentState, config, api_key, api_key_local, stream_handler): + """ + Data analysis and visualization agent (stub implementation). + + This agent will be responsible for: + - Extracting specific climate data components based on user needs + - Performing statistical analysis and calculations + - Creating visualizations and plots + - Interpreting results using image_viewer + - Providing data-driven insights + + Args: + state: AgentState containing all previous agent outputs including: + - user: Original user query + - input_params: Location and other parameters + - df_list: Climate data from data_agent + - smart_agent_response: Background information (if smart_agent was enabled) + - ipcc_rag_agent_response: IPCC report insights + - general_rag_agent_response: General climate literature insights + - zero_rag_agent_response: Geographic/environmental context + config: Configuration dictionary + api_key: OpenAI API key + api_key_local: Local model API key + stream_handler: Stream handler for progress updates + + Returns: + dict: Updated state with analysis results + """ + stream_handler.update_progress("Data analysis agent (stub - passing through)...") + + logger.info("Data analysis agent stub called") + logger.info(f"User query: {state.user}") + logger.info(f"Location: {state.input_params.get('location_str', 'unknown')}") + + # TODO: Implement full data analysis functionality + # Phase 1 implementation will include: + # 1. Move get_data_components tool from smart_agent_backup.py + # 2. Move python_repl tool initialization with climate data context + # 3. Move image_viewer tool for plot interpretation + # 4. Create agent executor with data analysis prompts + # 5. Integrate with combine_agent for final synthesis + + # For now, just pass through with a stub message + stub_response = "Data analysis agent not yet implemented. Analysis and visualization features will be added in the next implementation phase." + + return { + 'data_analysis_response': stub_response + } diff --git a/src/climsight/smart_agent.py b/src/climsight/smart_agent.py index e27eea9..4a98760 100644 --- a/src/climsight/smart_agent.py +++ b/src/climsight/smart_agent.py @@ -29,8 +29,7 @@ from langchain_core.prompts import ChatPromptTemplate #Import tools -from tools.python_repl import create_python_repl_tool -from tools.image_viewer import create_image_viewer_tool +#Import tools (python_repl and image_viewer moved to data_analysis_agent) #import requests #from bs4 import BeautifulSoup @@ -85,80 +84,74 @@ def smart_agent(state: AgentState, config, api_key, api_key_local, stream_handle # System prompt prompt = f""" - You are the smart agent of ClimSight. Your task is to retrieve necessary components of the climatic datasets based on the user's request. - You have access to tools called "get_data_components", "wikipedia_search", "RAG_search" and "ECOCROP_search" and "python_repl" which you can use to retrieve the necessary environmental data components. - - "get_data_components" will retrieve the necessary data from the climatic datasets at the location of interest (latitude: {lat}, longitude: {lon}). It accepts an 'environmental_data' parameter to specify the type of data, and a 'months' parameter to specify which months to retrieve data for. The 'months' parameter is a list of month names (e.g., ['Jan', 'Feb', 'Mar']). If 'months' is not specified, data for all months will be retrieved. - Call "get_data_components" tool multiple times if necessary, but only within one iteration, [chat_completion -> n * "get_data_components" -> chat_completion] after you recieve the necessary data from wikipedia_search and RAG_search. - - "wikipedia_search" will help you determine the necessary data to retrieve with the get_data_components tool. - - "RAG_search" can provide detailed information about environmental conditions for growing corn from your internal knowledge base. - - "ECOCROP_search" will help you determine the specific environmental requirements for the crop of interest from ecocrop database. - call "ECOCROP_search" ONLY and ONLY if you sure that the user question is related to the crop of interest. - - "python_repl" allows you to execute Python code for data analysis, visualization, and calculations. - Use this tool when: - - Creating visualizations (plots, charts, graphs) of climate data - - Performing statistical analysis (means, trends, correlations, standard deviations) - - Comparing data between different time periods (e.g., historical vs future projections) - - Calculating climate indicators or derived metrics - - Any analysis that requires more than simple data retrieval - - The tool has access to pandas, numpy, matplotlib, and xarray. - - **IMPORTANT: Climate data is pre-loaded in the Python environment. To see what's available, run:** - ```python - print(DATA_CATALOG) # Shows all available climate datasets and their descriptions - print(list(locals().keys())) # Shows all available variables - ``` - The climate data includes historical reference periods and future projections with monthly temperature, - precipitation, and wind data for your specific location. - - **CRITICAL: Your working directory is available at `work_dir` = '{str(work_dir)}'** - **When saving plots, ALWAYS store the full path in a variable for later use!** - - **CORRECT way to save and reference images:** - ```python - # Save the plot and store the full path - plot_path = f'{{{{work_dir}}}}/my_plot.png' # Or use Path(work_dir) / 'my_plot.png' - plt.savefig(plot_path) - print(f"Plot saved to: {{{{plot_path}}}}") # Verify the path - # Now plot_path contains the full path for image_viewer - ``` - - **WRONG way (DO NOT do this):** - ```python - plt.savefig('work_dir/my_plot.png') # This is just a string, not the actual path! - ``` - - - - "image_viewer" allows you to analyze saved climate visualizations to extract scientific insights. - Use this tool when: - - You have created and saved a visualization using python_repl - - You need to describe the patterns shown in a climate plot - - You want to extract specific values or trends from a generated figure - - The tool will provide scientific analysis of the image including: - - Data description and quantitative observations - - Temporal patterns and climate insights - - Key findings and implications - - **CRITICAL: How to use image_viewer correctly:** - 1. First save your plot in python_repl and store the path: - ```python - plot_path = f'{{{{work_dir}}}}/temperature_plot.png' - plt.savefig(plot_path) - ``` - 2. Then use image_viewer with the VARIABLE containing the path: - ``` - image_viewer(plot_path) # Use the variable, not a string! - ``` - - **NEVER do this:** - ``` - image_viewer('work_dir/plot.png') # WRONG - this is just a string! - ``` - - **Your actual work_dir path is: {str(work_dir)}** - **Always use the full path stored in a variable when calling image_viewer!** - + You are the information gathering agent of ClimSight. Your task is to collect relevant background + information about the user's query using external knowledge sources. + + You have access to three information retrieval tools: + - "wikipedia_search": Search Wikipedia for general information about the topic of interest. + Use this to understand the broader context of the user's question. + + - "RAG_search": Query ClimSight's internal climate knowledge base for scientific details. + Use this to find detailed scientific information from climate literature and research. + + - "ECOCROP_search": Get crop-specific environmental requirements from the ECOCROP database. + ONLY use this tool if the user's question is clearly about agriculture or crops. + Do NOT use it for general climate queries. + + **Your goal**: Gather comprehensive background information and compile it into a well-structured + summary that will be used by subsequent agents for data analysis. + + **Important guidelines**: + - When citing information from Wikipedia or RAG sources, include source references in your summary + - Present ECOCROP database results exactly as they are provided (no citations needed for database facts) + - Focus on gathering contextual information, NOT on extracting or analyzing specific climate data + - Do NOT attempt to retrieve climate data values - that will be handled by other agents + - Your output should be a comprehensive text summary with relevant background context + """ + if config['llm_smart']['model_type'] in ("local", "aitta"): + prompt += f""" + + - Tool use order. Call the wikipedia_search, RAG_search, and ECOCROP_search tools as needed, + but only one at a time per turn. Gather all relevant background information. + """ + else: + prompt += f""" + + - Tool use order. Call wikipedia_search, RAG_search, and ECOCROP_search + (if applicable) SIMULTANEOUSLY to gather background information efficiently. + + """ + prompt += f""" + + **Output format**: + After gathering information from the available tools, provide a well-structured summary organized as follows: + + 1. **General Background** (from Wikipedia if available): + - Key facts and context about the topic + - Relevant qualitative and quantitative information + - Include specific values with units where mentioned + + 2. **Scientific Context** (from RAG if available): + - Detailed scientific information from climate literature + - Relevant research findings and climate insights + - Technical details that provide context + + 3. **Specific Requirements** (from ECOCROP if applicable): + - Database results presented exactly as provided + - Environmental thresholds and requirements + - Optimal and tolerable ranges + + 4. **Key Takeaways**: + - Summarize the most important points for understanding the user's query + - Highlight critical factors or thresholds mentioned + + + - Keep your summary concise but comprehensive + - Do NOT include chain-of-thought reasoning or tool invocation details + - Do NOT mention what you're going to do or explain your process + - Focus on presenting the gathered information in a clear, organized format + - The summary should help subsequent agents understand the context for data analysis + """ if config['llm_smart']['model_type'] in ("local", "aitta"): prompt += f""" @@ -782,53 +775,8 @@ def process_ecocrop_search(query: str) -> str: args_schema=EcoCropSearchArgs ) - python_repl_tool = create_python_repl_tool() - - def inject_climate_context(): - context = { - 'lat': lat, - 'lon': lon, - 'location_str': state.input_params.get('location_str', ''), - 'work_dir': str(work_dir), - } - - # Add climate data if available - if state.df_list: - # Add all dataframes from df_list - for i, entry in enumerate(state.df_list): - df = entry.get('dataframe') - if df is not None: - context[f'climate_df_{i}'] = df - context[f'climate_info_{i}'] = { - 'years': entry.get('years_of_averaging', ''), - 'description': entry.get('description', ''), - 'variables': entry.get('extracted_vars', {}) - } - - # ADD THIS - Build DATA_CATALOG dynamically - catalog = "Available climate datasets:\n" - for i, entry in enumerate(state.df_list): - years = entry.get('years_of_averaging', '') - desc = entry.get('description', '') - is_main = " (historical reference)" if entry.get('main', False) else "" - catalog += f"- climate_df_{i}: {years}{is_main} - {desc}\n" - - catalog += "\nEach dataset contains monthly values for:\n" - if state.df_list and state.df_list[0].get('extracted_vars'): - for var_name, var_info in state.df_list[0]['extracted_vars'].items(): - catalog += f"- {var_info['full_name']} ({var_info['units']})\n" - - context['DATA_CATALOG'] = catalog - - return context - - # Inject climate context into Python REPL BEFORE agent runs - if hasattr(python_repl_tool.func, '__self__'): - repl_instance = python_repl_tool.func.__self__ - context = inject_climate_context() - repl_instance.locals.update(context) - + # python_repl_tool moved to data_analysis_agent # Initialize the LLM if config['llm_smart']['model_type'] == "local": @@ -849,21 +797,8 @@ def inject_climate_context(): llm = get_aitta_chat_model(config['llm_smart']['model_name'], temperature = 0) # List of tools - #tools = [data_extraction_tool, rag_tool, wikipedia_tool, ecocrop_tool, python_repl_tool] - tools = [data_extraction_tool, rag_tool, ecocrop_tool, wikipedia_tool]#python_repl_tool] - - ##Append image viewer for openai models - #if config['model_type'] == "openai": - # try: - # image_viewer_tool = create_image_viewer_tool( - # api_key, - # config['model_name_agents'] # Use model from config - # ) - # tools.append(image_viewer_tool) - # except Exception as e: - # pass - - # Create the agent with the tools and prompt + tools = [rag_tool, ecocrop_tool, wikipedia_tool] + prompt += """\nadditional information:\n question is related to this location: {location_str} \n """ diff --git a/src/climsight/stream_handler.py b/src/climsight/stream_handler.py index 0720661..b3699c3 100644 --- a/src/climsight/stream_handler.py +++ b/src/climsight/stream_handler.py @@ -1,4 +1,7 @@ from langchain_core.callbacks.base import BaseCallbackHandler +import logging + +logger = logging.getLogger(__name__) class StreamHandler(BaseCallbackHandler): """ @@ -34,7 +37,15 @@ def _display_text(self): if self.container: display_function = getattr(self.container, self.display_method, None) if display_function is not None: - display_function(self.text) + try: + display_function(self.text) + except Exception as e: + # Streamlit context not available (e.g., running in worker thread) + # Log the message instead + if "NoSessionContext" in str(type(e).__name__): + logger.debug(f"Streamlit context not available, skipping display: {self.text[:100]}") + else: + logger.error(f"Error displaying text: {e}") else: raise ValueError(f"Invalid display_method: {self.display_method}") @@ -42,7 +53,14 @@ def _display_reference_text(self): if self.container2: display_function = getattr(self.container2, self.display_method, None) if display_function is not None: - display_function(self.reference_text) + try: + display_function(self.reference_text) + except Exception as e: + # Streamlit context not available (e.g., running in worker thread) + if "NoSessionContext" in str(type(e).__name__): + logger.debug(f"Streamlit context not available, skipping reference display") + else: + logger.error(f"Error displaying reference text: {e}") else: raise ValueError(f"Invalid display_method: {self.display_method}") @@ -50,7 +68,15 @@ def _display_progress(self): if self.container: display_function = getattr(self.container, "info", None) if display_function is not None: - display_function(self.progress_text) + try: + display_function(self.progress_text) + except Exception as e: + # Streamlit context not available (e.g., running in worker thread) + # This is expected when agents run in parallel + if "NoSessionContext" in str(type(e).__name__): + logger.debug(f"Progress update (no UI context): {self.progress_text}") + else: + logger.error(f"Error displaying progress: {e}") def get_text(self): return self.text diff --git a/src/climsight/streamlit_interface.py b/src/climsight/streamlit_interface.py index 5746421..394b804 100644 --- a/src/climsight/streamlit_interface.py +++ b/src/climsight/streamlit_interface.py @@ -288,7 +288,17 @@ def run_streamlit(config, api_key='', skip_llm_call=False, rag_activated=True, r # Add a method to update the progress area def update_progress_ui(message): - progress_area.info(message) + try: + progress_area.info(message) + except Exception as e: + # Streamlit context not available (e.g., running in worker thread) + # This is expected when agents run in parallel + import logging + logger = logging.getLogger(__name__) + if "NoSessionContext" in str(type(e).__name__): + logger.debug(f"Progress update (no UI context): {message}") + else: + logger.error(f"Error displaying progress: {e}") # Attach this method to your StreamHandler stream_handler.update_progress = update_progress_ui From 8154af24ab48e9f97606182f5a4c7cf1768796d4 Mon Sep 17 00:00:00 2001 From: Ivan Kuznetsov Date: Fri, 16 Jan 2026 18:33:24 +0100 Subject: [PATCH 02/16] Fix smart_agent.py: Remove stale prompt sections and tool definitions - Removed prompt sections referencing get_data_components and python_repl - Removed get_data_components tool definition (244 lines) - Removed tool output processing for get_data_components and python_repl - Prompt now correctly matches available tools (wikipedia_search, RAG_search, ECOCROP_search only) This fixes the mismatch where the prompt instructed the model to use tools that were no longer exposed in the agent's tool list. --- src/climsight/smart_agent.py | 291 +---------------------------------- 1 file changed, 3 insertions(+), 288 deletions(-) diff --git a/src/climsight/smart_agent.py b/src/climsight/smart_agent.py index 4a98760..30255da 100644 --- a/src/climsight/smart_agent.py +++ b/src/climsight/smart_agent.py @@ -153,283 +153,8 @@ def smart_agent(state: AgentState, config, api_key, api_key_local, stream_handle - The summary should help subsequent agents understand the context for data analysis """ - if config['llm_smart']['model_type'] in ("local", "aitta"): - prompt += f""" - - - Tool use order. Always call the wikipedia_search, RAG_search, and ECOCROP_search tools as needed, but only one at a time per turn.; it will help you determine the necessary data to retrieve with the get_data_components tool. At second step, call the get_data_components tool with the necessary data. - """ - else: - prompt += f""" - - - Tool use order. ALWAYS call FIRST SIMULTANEOUSLY the wikipedia_search, RAG_search and "ECOCROP_search"; it will help you determine the necessary data to retrieve with the get_data_components tool. At second step, call the get_data_components tool with the necessary data. - - """ - prompt += f""" - - Use these tools to get the data you need to answer the user's question. - After retrieving the data, provide a concise summary of the parameters you retrieved, explaining briefly why they are important. Keep your response short and to the point. - Do not include any additional explanations or reasoning beyond the concise summary. - Do not include any chain-of-thought reasoning or action steps in your final answer. - Do not ask the user for any additional information, but you can include into the final answer what kind of information user should provide in the future. - - Additional workflow guidance: - - If the user asks for analysis, trends, comparisons, or visualizations, use python_repl after retrieving data to: - * Create plots and charts - * Calculate statistics (averages, changes, trends) - * Compare different time periods - * Save any outputs to work_dir - - - For the final response try to follow the following format: - 'The [retrieved values of the parameter] for the [object of interest] at [location] is [value for current and future are ...], [according to the Wikipedia article] the required [parameter] for [object of interest] is [value]. [Two sentence of clarification, with criitcal montly-based assessment of the potential changes]' - 'Repeat for each parameter.' - - """ - - - #[1] Tool description for netCDF extraction - class get_data_components_args(BaseModel): - environmental_data: Optional[Union[str, Literal["Temperature", "Precipitation", "u_wind", "v_wind"]]] = Field( - default=None, - description="The type of environmental data to retrieve. Choose from Temperature, Precipitation, u_wind, or v_wind.", - enum_description={ - "Temperature": "The mean monthly temperature data.", - "Precipitation": "The mean monthly precipitation data.", - "u_wind": "The mean monthly u wind component data.", - "v_wind": "The mean monthly v wind component data." - } - ) - months: Optional[Union[str, List[Literal["Jan", "Feb", "Mar", "Apr", "May", "Jun", - "Jul", "Aug", "Sep", "Oct", "Nov", "Dec"]]]] = Field( - default=None, - description="List of months or a stringified list of month names to retrieve data for. Each month should be one of 'Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun', 'Jul', 'Aug', 'Sep', 'Oct', 'Nov', 'Dec'. If not specified, data for all months will be retrieved." - ) - def get_data_components(**kwargs): - stream_handler.update_progress("Retrieving data for advanced analysis with a smart agent...") - - if isinstance(kwargs.get("months"), str): - try: - kwargs["months"] = ast.literal_eval(kwargs["months"]) - except (ValueError, SyntaxError): - # Optional: handle invalid input - kwargs["months"] = None # or raise an exception - args = get_data_components_args(**kwargs) - # Parse the arguments using the args_schema - environmental_data = args.environmental_data - months = args.months # List of month names - if environmental_data is None: - return {"error": "No environmental data type specified."} - if environmental_data not in ["Temperature", "Precipitation", "u_wind", "v_wind"]: - return {"error": f"Invalid environmental data type: {environmental_data}"} - - # Get climate data source from config - climate_source = config.get('climate_data_source', 'nextGEMS') - - # Check if we have df_list from the provider (unified approach) - df_list = getattr(state, 'df_list', None) - - if df_list: - # Use unified df_list from any provider (nextGEMS, ICCP, AWI_CM) - response = {} - - # Variable mapping depends on data source - if climate_source == 'nextGEMS': - environmental_mapping = { - "Temperature": "mean2t", - "Precipitation": "tp", - "u_wind": "wind_u", - "v_wind": "wind_v" - } - elif climate_source == 'ICCP': - environmental_mapping = { - "Temperature": "mean2t", - "Precipitation": "tp", - "u_wind": "wind_u", - "v_wind": "wind_v" - } - elif climate_source == 'AWI_CM': - environmental_mapping = { - "Temperature": "Present Day Temperature", - "Precipitation": "Present Day Precipitation", - "u_wind": "u_wind", - "v_wind": "v_wind" - } - else: - environmental_mapping = { - "Temperature": "mean2t", - "Precipitation": "tp", - "u_wind": "wind_u", - "v_wind": "wind_v" - } - - if environmental_data not in environmental_mapping: - return {"error": f"Invalid environmental data type: {environmental_data}"} - - # Filter the DataFrame for the selected months and extract the values - var_name = environmental_mapping[environmental_data] - - if not months: - months = [calendar.month_abbr[m] for m in range(1, 13)] - - # Create a mapping from abbreviated to full month names - month_mapping = {calendar.month_abbr[m]: calendar.month_name[m] for m in range(1, 13)} - selected_months = [month_mapping[abbr] for abbr in months] - - for entry in df_list: - df = entry.get('dataframe') - extracted_vars = entry.get('extracted_vars', {}) - - if df is None: - raise ValueError(f"Entry does not contain a 'dataframe' key.") - - # Try to find the variable in the dataframe - if var_name in df.columns: - var_meta = extracted_vars.get(var_name, {'units': ''}) - data_values = df[df['Month'].isin(selected_months)][var_name].tolist() - ext_data = {month: np.round(value, 2) for month, value in zip(selected_months, data_values)} - ext_exp = f"Monthly mean values of {environmental_data}, {var_meta.get('units', '')} for years: " + entry['years_of_averaging'] - response.update({ext_exp: ext_data}) - else: - # Try to find by full name in columns - matching_cols = [col for col in df.columns if environmental_data.lower() in col.lower()] - if matching_cols: - col_name = matching_cols[0] - data_values = df[df['Month'].isin(selected_months)][col_name].tolist() - ext_data = {month: np.round(value, 2) for month, value in zip(selected_months, data_values)} - ext_exp = f"Monthly mean values of {environmental_data} for years: " + entry['years_of_averaging'] - response.update({ext_exp: ext_data}) - - return response - else: - # Legacy fallback for AWI_CM without df_list - lat = float(state.input_params['lat']) - lon = float(state.input_params['lon']) - data_path = config['data_settings']['data_path'] - - # Dictionaries for historical and SSP585 data files - data_files_historical = { - "Temperature": ("AWI_CM_mm_historical.nc", "tas"), - "Precipitation": ("AWI_CM_mm_historical_pr.nc", "pr"), - "u_wind": ("AWI_CM_mm_historical_uas.nc", "uas"), - "v_wind": ("AWI_CM_mm_historical_vas.nc", "vas") - } - - data_files_ssp585 = { - "Temperature": ("AWI_CM_mm_ssp585.nc", "tas"), - "Precipitation": ("AWI_CM_mm_ssp585_pr.nc", "pr"), - "u_wind": ("AWI_CM_mm_ssp585_uas.nc", "uas"), - "v_wind": ("AWI_CM_mm_ssp585_vas.nc", "vas") - } - - if environmental_data not in data_files_historical: - return {"error": f"Invalid environmental data type: {environmental_data}"} - - # Get file names and variable names for both datasets - file_name_hist, var_name_hist = data_files_historical[environmental_data] - file_name_ssp585, var_name_ssp585 = data_files_ssp585[environmental_data] - - # Build file paths - file_path_hist = os.path.join(data_path, file_name_hist) - file_path_ssp585 = os.path.join(data_path, file_name_ssp585) - - # Check if files exist - if not os.path.exists(file_path_hist): - return {"error": f"Data file {file_name_hist} not found in {data_path}"} - if not os.path.exists(file_path_ssp585): - return {"error": f"Data file {file_name_ssp585} not found in {data_path}"} - - # Open datasets - dataset_hist = nc.Dataset(file_path_hist) - dataset_ssp585 = nc.Dataset(file_path_ssp585) - - # Get latitude and longitude arrays - lats_hist = dataset_hist.variables['lat'][:] - lons_hist = dataset_hist.variables['lon'][:] - lats_ssp585 = dataset_ssp585.variables['lat'][:] - lons_ssp585 = dataset_ssp585.variables['lon'][:] - - # Find the nearest indices for historical data - lat_idx_hist = (np.abs(lats_hist - lat)).argmin() - lon_idx_hist = (np.abs(lons_hist - lon)).argmin() - - # Find the nearest indices for SSP585 data - lat_idx_ssp585 = (np.abs(lats_ssp585 - lat)).argmin() - lon_idx_ssp585 = (np.abs(lons_ssp585 - lon)).argmin() - - # Extract data at the specified location - data_hist = dataset_hist.variables[var_name_hist][:, :, :, lat_idx_hist, lon_idx_hist] - data_ssp585 = dataset_ssp585.variables[var_name_ssp585][:, :, :, lat_idx_ssp585, lon_idx_ssp585] - - # Squeeze data to remove singleton dimensions (shape becomes (12,)) - data_hist = np.squeeze(data_hist) - data_ssp585 = np.squeeze(data_ssp585) - - # Process data according to the variable - if environmental_data == "Temperature": - # Convert from Kelvin to Celsius - data_hist = data_hist - 273.15 - data_ssp585 = data_ssp585 - 273.15 - units = "°C" - elif environmental_data == "Precipitation": - # Convert from kg m-2 s-1 to mm/month - days_in_month = np.array([31, 28, 31, 30, 31, 30, - 31, 31, 30, 31, 30, 31]) - seconds_in_month = days_in_month * 24 * 3600 # seconds in each month - data_hist = data_hist * seconds_in_month - data_ssp585 = data_ssp585 * seconds_in_month - units = "mm/month" - elif environmental_data in ["u_wind", "v_wind"]: - # Units are already in m/s - units = "m/s" - else: - units = "unknown" - - # Close datasets - dataset_hist.close() - dataset_ssp585.close() - - # List of all month names - all_months = ['Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun', - 'Jul', 'Aug', 'Sep', 'Oct', 'Nov', 'Dec'] - - # Map month names to indices - month_indices = {month: idx for idx, month in enumerate(all_months)} - - # If months are specified, select data for those months - if months: - # Validate months - valid_months = [month for month in months if month in month_indices] - if not valid_months: - return {"error": "Invalid months provided."} - selected_indices = [month_indices[month] for month in valid_months] - selected_months = valid_months - else: - # Use all months if none are specified - selected_indices = list(range(12)) - selected_months = all_months - - # Subset data for selected months - data_hist = data_hist[selected_indices] - data_ssp585 = data_ssp585[selected_indices] - - # Create dictionaries mapping months to values with units - hist_data_dict = {month: f"{value:.2f} {units}" for month, value in zip(selected_months, data_hist)} - ssp585_data_dict = {month: f"{value:.2f} {units}" for month, value in zip(selected_months, data_ssp585)} - - # Return both historical and SSP585 data - return { - f"{environmental_data}_historical": hist_data_dict, - f"{environmental_data}_ssp585": ssp585_data_dict - } - - # Define the data_extraction_tool - data_extraction_tool = StructuredTool.from_function( - func=get_data_components, - name="get_data_components", - description="Retrieve the necessary environmental data component.", - args_schema=get_data_components_args - ) + #[1] get_data_components tool moved to data_analysis_agent.py #[2] Wikipedia processing tool def process_wikipedia_article(query: str) -> str: @@ -850,23 +575,13 @@ def process_ecocrop_search(query: str) -> str: state.references.append(output.get('references', [])) else: tool_outputs['wikipedia_search'] = output - elif action.tool == 'get_data_components': - if isinstance(observation, AIMessage): - tool_outputs['get_data_components'] = observation.content - else: - tool_outputs['get_data_components'] = observation elif action.tool == 'ECOCROP_search': if isinstance(observation, AIMessage): tool_outputs['ECOCROP_search'] = observation.content else: tool_outputs['ECOCROP_search'] = observation - if any("FAO, IIASA" not in element for element in state.references): - state.references.append("FAO, IIASA: Global Agro-Ecological Zones (GAEZ V4) - Data Portal User's Guide, 1st edn. FAO and IIASA, Rome, Italy (2021). https://doi.org/10.4060/cb5167en") - elif action.tool == 'python_repl': - if isinstance(observation, AIMessage): - tool_outputs['python_repl'] = observation.content - else: - tool_outputs['python_repl'] = observation + if any("FAO, IIASA" not in element for element in state.references): + state.references.append("FAO, IIASA: Global Agro-Ecological Zones (GAEZ V4) - Data Portal User's Guide, 1st edn. FAO and IIASA, Rome, Italy (2021). https://doi.org/10.4060/cb5167en") if action.tool == 'RAG_search': if isinstance(observation, AIMessage): tool_outputs['RAG_search'] = observation.content From 1a691add2a587ca387c5a67a62023732078d1f71 Mon Sep 17 00:00:00 2001 From: Ivan Kuznetsov Date: Tue, 20 Jan 2026 17:03:31 +0100 Subject: [PATCH 03/16] Implement data_analysis_agent with ERA5 ground truth workflow and tool improvements Major changes: - Implement full data_analysis_agent with dynamic tool prompt based on config - Make ERA5 mandatory when enabled - use as ground truth for climate model validation - Remove redundant get_data_components when Python_REPL is enabled - Preserve user query in analysis_brief with "USER QUESTION:" header - Add configurable filter step via use_filter_step in config - Accumulate multiple Wikipedia/RAG results in smart_agent (was overwriting) - Add Wikipedia call limit (10 max) to prevent excessive API calls New files: - agent_helpers.py: Helper utilities for tool-based agents - sandbox_utils.py: Sandbox directory management - config.py: Configuration utilities - utils.py: Logging and history utilities - tools/era5_retrieval_tool.py: ERA5 data download tool - tools/get_data_components.py: Climate data extraction tool - tools/visualization_tools.py: File listing and helper tools - tools/reflection_tools.py: Agent reflection tool - tools/package_tools.py: Package installation tool Type changes in AgentState: - wikipedia_tool_response: str -> list (accumulate multiple results) - rag_search_response: str -> list (accumulate multiple results) - Added: data_analysis_prompt_text, data_analysis_images, thread_id, uuid_main_dir, results_dir, climate_data_dir, era5_data_dir Python REPL improvements: - Rewrite to use JupyterKernelExecutor (PangaeaGPT pattern) - Auto-load datasets from sandbox paths - Better plot detection and results directory handling --- .gitignore | 3 +- config.yml | 12 +- pyproject.toml | 13 +- requirements.txt | 13 + src/climsight/agent_helpers.py | 121 +++++ src/climsight/climsight_classes.py | 11 +- src/climsight/climsight_engine.py | 86 +++- src/climsight/config.py | 20 + src/climsight/data_analysis_agent.py | 560 ++++++++++++++++++--- src/climsight/sandbox_utils.py | 130 +++++ src/climsight/smart_agent.py | 102 ++-- src/climsight/streamlit_interface.py | 36 +- src/climsight/terminal_interface.py | 11 +- src/climsight/tools/__init__.py | 11 +- src/climsight/tools/era5_retrieval_tool.py | 275 ++++++++++ src/climsight/tools/get_data_components.py | 206 ++++++++ src/climsight/tools/package_tools.py | 37 ++ src/climsight/tools/python_repl.py | 511 +++++++++++++++---- src/climsight/tools/reflection_tools.py | 135 +++++ src/climsight/tools/visualization_tools.py | 261 ++++++++++ src/climsight/utils.py | 117 +++++ 21 files changed, 2451 insertions(+), 220 deletions(-) create mode 100644 src/climsight/agent_helpers.py create mode 100644 src/climsight/config.py create mode 100644 src/climsight/sandbox_utils.py create mode 100644 src/climsight/tools/era5_retrieval_tool.py create mode 100644 src/climsight/tools/get_data_components.py create mode 100644 src/climsight/tools/package_tools.py create mode 100644 src/climsight/tools/reflection_tools.py create mode 100644 src/climsight/tools/visualization_tools.py create mode 100644 src/climsight/utils.py diff --git a/.gitignore b/.gitignore index 1819ead..29f8f22 100644 --- a/.gitignore +++ b/.gitignore @@ -12,4 +12,5 @@ rag_articles/ .* *.log venv311 -venv \ No newline at end of file +venv +tmp/ \ No newline at end of file diff --git a/config.yml b/config.yml index 4e0c2e6..92afe69 100644 --- a/config.yml +++ b/config.yml @@ -2,16 +2,22 @@ #model_type: "openai" #"openai / local / aitta llm_rag: model_type: "openai" - model_name: "gpt-4.1-nano" # used only for RAGs + model_name: "gpt-5-mini" # used only for RAGs llm_smart: #used only in smart_agent model_type: "openai" - model_name: "gpt-4.1-nano" # used only for smart agent + model_name: "gpt-5-mini" # used only for smart agent llm_combine: #used only in combine_agent and intro model_type: "openai" - model_name: "gpt-4.1-nano" # used only for combine agent ("mkchaou/climsight-calm_ft_Q3_13k") + model_name: "gpt-5.2" # used only for combine agent ("mkchaou/climsight-calm_ft_Q3_13k") +llm_dataanalysis: #used only in data_analysis_agent + model_type: "openai" + model_name: "gpt-5.2" + use_filter_step: true # Set to false to skip context filtering LLM call climatemodel_name: "AWI_CM" llmModeKey: "agent_llm" #"agent_llm" #"direct_llm" use_smart_agent: false +use_era5_data: false +use_powerful_data_analysis: false # Climate Data Source Configuration # Options: "nextGEMS", "ICCP", "AWI_CM" diff --git a/pyproject.toml b/pyproject.toml index d17c21d..7268c77 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,12 +17,14 @@ dependencies = [ "xarray", "geopy", "geopandas", + "shapely", "pyproj", "requests", "requests-mock", "pandas", "folium", "langchain", + "langchain-classic", "streamlit-folium", "netcdf4", "dask", @@ -34,13 +36,22 @@ dependencies = [ "langchain-openai", "langchain-chroma", "langchain-core", + "langchain-text-splitters", + "langchain-experimental", + "langchain-anthropic", "pydantic", "langgraph", "bs4", "wikipedia", "scipy", - "pyproj", "reportlab", + "pyyaml", + "jupyter_client", + "ipykernel", + "zarr", + "gcsfs", + "numcodecs", + "blosc", "aitta-client" ] diff --git a/requirements.txt b/requirements.txt index 2bd51be..d8602d0 100644 --- a/requirements.txt +++ b/requirements.txt @@ -10,6 +10,7 @@ streamlit-folium xarray netcdf4 geopandas +shapely pyproj geopy osmnx @@ -32,6 +33,8 @@ langchain-openai langchain-chroma langchain-core langchain-text-splitters +langchain-experimental +langchain-anthropic langgraph openai pydantic @@ -45,6 +48,16 @@ requests-mock # PDF Generation reportlab +# ERA5 + Zarr tooling +zarr +gcsfs +numcodecs +blosc + +# Jupyter-backed Python REPL +jupyter_client +ipykernel + # Optional: CSC's AITTA Platform (local/custom AI models) # Uncomment if you need CSC AITTA integration # aitta-client diff --git a/src/climsight/agent_helpers.py b/src/climsight/agent_helpers.py new file mode 100644 index 0000000..4499aeb --- /dev/null +++ b/src/climsight/agent_helpers.py @@ -0,0 +1,121 @@ +"""Helper utilities for tool-based agents (PangaeaGPT parity).""" + +import logging +import os +from typing import Any, Dict, List, Tuple + +try: + import streamlit as st +except ImportError: + st = None + +try: + from langchain.agents import AgentExecutor, create_openai_tools_agent +except ImportError: + from langchain_classic.agents import AgentExecutor, create_openai_tools_agent + +from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder + +logger = logging.getLogger(__name__) + + +def prepare_visualization_environment(datasets_info: List[Dict[str, Any]]) -> Tuple[Dict[str, Any], str, List[str]]: + """Prepare dataset variables and prompt text for tool agents.""" + datasets_text = "" + dataset_variables: List[str] = [] + datasets: Dict[str, Any] = {} + + uuid_main_dir = None + for info in datasets_info: + sandbox_path = info.get("sandbox_path") + if sandbox_path and isinstance(sandbox_path, str) and os.path.isdir(sandbox_path): + uuid_main_dir = os.path.dirname(os.path.abspath(sandbox_path)) + logger.info("Found main UUID directory from sandbox_path: %s", uuid_main_dir) + break + + datasets["uuid_main_dir"] = uuid_main_dir + + results_dir = None + uuid_dir_files: List[str] = [] + if uuid_main_dir and os.path.exists(uuid_main_dir): + results_dir = os.path.join(uuid_main_dir, "results") + os.makedirs(results_dir, exist_ok=True) + datasets["results_dir"] = results_dir + try: + uuid_dir_files = os.listdir(uuid_main_dir) + except Exception as exc: + logger.error("Error listing UUID directory files: %s", exc) + + # Path instructions for the prompt + uuid_paths = "WARNING: EXACT DATASET PATHS - USE THESE EXACTLY AS SHOWN\n" + uuid_paths += "The following paths contain unique IDs that MUST be used with os.path.join().\n\n" + + if uuid_main_dir: + uuid_paths += "# MAIN OUTPUT DIRECTORY\n" + uuid_paths += f"uuid_main_dir = r'{uuid_main_dir}'\n" + uuid_paths += f"results_dir = r'{results_dir}' # Save all plots here\n\n" + uuid_paths += f"# Files in main directory: {', '.join(uuid_dir_files) if uuid_dir_files else 'None'}\n\n" + + for i, info in enumerate(datasets_info): + var_name = f"dataset_{i + 1}" + datasets[var_name] = info.get("dataset") + dataset_variables.append(var_name) + + sandbox_path = info.get("sandbox_path") + if sandbox_path and isinstance(sandbox_path, str) and os.path.isdir(sandbox_path): + full_uuid_path = os.path.abspath(sandbox_path).replace("\\", "/") + uuid_paths += f"# Dataset {i + 1}: {info.get('name', 'unknown')}\n" + uuid_paths += f"{var_name}_path = r'{full_uuid_path}'\n\n" + if os.path.exists(full_uuid_path): + try: + files = os.listdir(full_uuid_path) + uuid_paths += f"# Files available in {var_name}_path: {', '.join(files)}\n\n" + except Exception as exc: + uuid_paths += f"# Error listing files: {exc}\n\n" + + uuid_paths += "# WARNINGS\n" + uuid_paths += "# 1. Never use placeholder paths.\n" + uuid_paths += "# 2. Always use the dataset_X_path variables shown above.\n" + uuid_paths += "# 3. Check which files exist before reading.\n\n" + + datasets_summary = "" + for i, info in enumerate(datasets_info): + datasets_summary += ( + f"Dataset {i + 1}:\n" + f"Name: {info.get('name', 'Unknown')}\n" + f"Description: {info.get('description', 'No description available')}\n" + f"Type: {info.get('data_type', 'Unknown type')}\n" + f"Sample Data: {info.get('df_head', 'No sample available')}\n\n" + ) + + datasets_text = uuid_paths + datasets_summary + + if st is not None and hasattr(st, "session_state"): + st.session_state["viz_datasets_text"] = datasets_text + + return datasets, datasets_text, dataset_variables + + +def create_standard_agent_executor(llm, tools, prompt_template, max_iterations: int = 25) -> AgentExecutor: + """Create an OpenAI tools agent executor with standard wiring.""" + agent = create_openai_tools_agent( + llm, + tools=tools, + prompt=ChatPromptTemplate.from_messages( + [ + ("system", prompt_template), + ("user", "{input}"), + MessagesPlaceholder(variable_name="messages"), + MessagesPlaceholder(variable_name="agent_scratchpad"), + ] + ), + ) + + return AgentExecutor( + agent=agent, + tools=tools, + verbose=True, + handle_parsing_errors=True, + max_iterations=max_iterations, + return_intermediate_steps=True, + ) diff --git a/src/climsight/climsight_classes.py b/src/climsight/climsight_classes.py index 082ff0a..618c072 100644 --- a/src/climsight/climsight_classes.py +++ b/src/climsight/climsight_classes.py @@ -17,13 +17,20 @@ class AgentState(BaseModel): content_message: str = "" input_params: dict = {} smart_agent_response: dict = {} - wikipedia_tool_response: str = "" + wikipedia_tool_response: list = [] ecocrop_search_response: str = "" - rag_search_response: str = "" + rag_search_response: list = [] ipcc_rag_agent_response: str = "" general_rag_agent_response: str = "" data_analysis_response: str = "" # Response from data analysis agent + data_analysis_prompt_text: str = "" # Filtered analysis brief for tools + data_analysis_images: list = [] # Paths to generated analysis images df_list: list = [] # List of dataframes with climate data references: list = [] # List of references combine_agent_prompt_text: str = "" + thread_id: str = "" # Session ID for sandbox storage + uuid_main_dir: str = "" # Root sandbox path + results_dir: str = "" # Plot output directory + climate_data_dir: str = "" # Saved climatology directory + era5_data_dir: str = "" # ERA5 output directory # stream_handler: StreamHandler # Uncomment if needed diff --git a/src/climsight/climsight_engine.py b/src/climsight/climsight_engine.py index 6d2fe23..53ff43a 100644 --- a/src/climsight/climsight_engine.py +++ b/src/climsight/climsight_engine.py @@ -61,6 +61,13 @@ # import climsight classes from climsight_classes import AgentState +# sandbox helpers +from sandbox_utils import ( + ensure_thread_id, + ensure_sandbox_dirs, + get_sandbox_paths, + write_climate_data_manifest, +) # import smart_agent from smart_agent import get_aitta_chat_model, smart_agent # import data_analysis_agent @@ -633,6 +640,13 @@ def agent_llm_request(content_message, input_params, config, api_key, api_key_lo if not isinstance(stream_handler, StreamHandler): logging.error(f"stream_handler must be an instance of StreamHandler") raise TypeError("stream_handler must be an instance of StreamHandler") + + # Ensure sandbox paths for this session (Streamlit or CLI). + thread_id = ensure_thread_id(existing_thread_id=input_params.get("thread_id", "")) + sandbox_paths = get_sandbox_paths(thread_id) + ensure_sandbox_dirs(sandbox_paths) + input_params.update(sandbox_paths) + input_params["thread_id"] = thread_id lat = float(input_params['lat']) # should be already present in input_params lon = float(input_params['lon']) # should be already present in input_params @@ -671,6 +685,31 @@ def agent_llm_request(content_message, input_params, config, api_key, api_key_lo max_completion_tokens=4096 ) llm_intro = llm_combine_agent + + # Data analysis LLM (separate from combine step). + llm_dataanalysis_cfg = config.get("llm_dataanalysis") + if not llm_dataanalysis_cfg: + raise RuntimeError("Missing llm_dataanalysis configuration.") + if llm_dataanalysis_cfg.get("model_type") == "local": + llm_dataanalysis_agent = ChatOpenAI( + openai_api_base="http://localhost:8000/v1", + model_name=llm_dataanalysis_cfg.get("model_name"), + openai_api_key=api_key_local, + max_tokens=16000, + ) + elif llm_dataanalysis_cfg.get("model_type") == "openai": + llm_dataanalysis_agent = ChatOpenAI( + openai_api_key=api_key, + model_name=llm_dataanalysis_cfg.get("model_name"), + max_tokens=16000, + ) + elif llm_dataanalysis_cfg.get("model_type") == "aitta": + llm_dataanalysis_agent = get_aitta_chat_model( + llm_dataanalysis_cfg.get("model_name"), + max_completion_tokens=4096 + ) + else: + llm_dataanalysis_agent = llm_combine_agent def zero_rag_agent(state: AgentState, figs = {}): logger.debug(f"get_elevation_from_api from: {lat}, {lon}") @@ -843,6 +882,20 @@ def data_agent(state: AgentState, data={}, df={}): data['high_res_climate'] = data['climate_data'] state.df_list = df_list + # Persist climatology into the sandbox for cross-agent access. + if state.thread_id: + sandbox_paths = get_sandbox_paths(state.thread_id) + ensure_sandbox_dirs(sandbox_paths) + manifest_path, _ = write_climate_data_manifest( + df_list, sandbox_paths["climate_data_dir"], climate_source + ) + state.uuid_main_dir = sandbox_paths["uuid_main_dir"] + state.results_dir = sandbox_paths["results_dir"] + state.climate_data_dir = sandbox_paths["climate_data_dir"] + state.era5_data_dir = sandbox_paths["era5_data_dir"] + state.input_params.update(sandbox_paths) + state.input_params["climate_data_manifest"] = manifest_path + # Add appropriate references based on data source ref_key_map = { 'nextGEMS': 'high_resolution_climate_model', @@ -1022,6 +1075,13 @@ def combine_agent(state: AgentState): state.content_message += state.data_agent_response['content_message'] state.input_params.update(state.data_agent_response['input_params']) + #add data_analysis_agent response to content_message and input_params + if state.data_analysis_response: + state.input_params['data_analysis_response'] = state.data_analysis_response + state.content_message += "\n Data analysis agent response: {data_analysis_response} " + if state.data_analysis_images: + state.input_params['data_analysis_images'] = state.data_analysis_images + if state.smart_agent_response != {}: smart_analysis = state.smart_agent_response.get('output', '') state.input_params['smart_agent_analysis'] = smart_analysis @@ -1029,12 +1089,7 @@ def combine_agent(state: AgentState): logger.info(f"smart_agent_response: {state.smart_agent_response}") # Add Wikipedia tool response - if state.wikipedia_tool_response != {}: - wiki_response = state.wikipedia_tool_response - state.input_params['wikipedia_tool_response'] = wiki_response - state.content_message += "\n Wikipedia Search Response: {wikipedia_tool_response} " - logger.info(f"Wikipedia_tool_reponse: {state.wikipedia_tool_response}") - if state.ecocrop_search_response != {}: + if state.ecocrop_search_response: ecocrop_response = state.ecocrop_search_response state.input_params['ecocrop_search_response'] = ecocrop_response state.content_message += "\n ECOCROP Search Response: {ecocrop_search_response} " @@ -1110,7 +1165,9 @@ def route_fromintro(state: AgentState) -> Sequence[str]: workflow.add_node("data_agent", lambda s: data_agent(s, data, df)) # Pass `data` as argument workflow.add_node("zero_rag_agent", lambda s: zero_rag_agent(s, figs)) # Pass `figs` as argument workflow.add_node("smart_agent", lambda s: smart_agent(s, config, api_key, api_key_local, stream_handler)) - workflow.add_node("data_analysis_agent", lambda s: data_analysis_agent(s, config, api_key, api_key_local, stream_handler)) + workflow.add_node("data_analysis_agent", lambda s: data_analysis_agent( + s, config, api_key, api_key_local, stream_handler, llm_dataanalysis_agent + )) workflow.add_node("combine_agent", combine_agent) path_map = { @@ -1148,7 +1205,18 @@ def route_fromintro(state: AgentState) -> Sequence[str]: # with open(graph_image_path, 'wb') as f: # f.write(graph_img) # Write the image bytes to the file - state = AgentState(messages=[], input_params=input_params, user=input_params['user_message'], content_message=content_message, references=[]) + state = AgentState( + messages=[], + input_params=input_params, + user=input_params['user_message'], + content_message=content_message, + references=[], + thread_id=input_params.get("thread_id", ""), + uuid_main_dir=input_params.get("uuid_main_dir", ""), + results_dir=input_params.get("results_dir", ""), + climate_data_dir=input_params.get("climate_data_dir", ""), + era5_data_dir=input_params.get("era5_data_dir", ""), + ) stream_handler.update_progress("Starting workflow...") output = app.invoke(state) @@ -1168,4 +1236,4 @@ def route_fromintro(state: AgentState) -> Sequence[str]: stream_handler.send_reference_text('- '+ref+' \n') - return output['final_answer'], input_params, content_message, combine_agent_prompt_text \ No newline at end of file + return output['final_answer'], input_params, content_message, combine_agent_prompt_text diff --git a/src/climsight/config.py b/src/climsight/config.py new file mode 100644 index 0000000..2757216 --- /dev/null +++ b/src/climsight/config.py @@ -0,0 +1,20 @@ +"""Minimal config helpers for tool modules.""" + +import os + +try: + import streamlit as st +except ImportError: + st = None + + +def _get_openai_api_key() -> str: + if st is not None and hasattr(st, "secrets"): + try: + return st.secrets["general"]["openai_api_key"] + except Exception: + pass + return os.environ.get("OPENAI_API_KEY", "") + + +API_KEY = _get_openai_api_key() diff --git a/src/climsight/data_analysis_agent.py b/src/climsight/data_analysis_agent.py index c3258a0..7731551 100644 --- a/src/climsight/data_analysis_agent.py +++ b/src/climsight/data_analysis_agent.py @@ -1,72 +1,512 @@ """ -Data Analysis Agent +Data analysis agent for Climsight. -This agent receives all state from previous agents and performs: -1. Climate data extraction from datasets -2. Data post-processing and calculations -3. Visualization and analysis using Python REPL -4. Scientific interpretation of results - -Currently a stub implementation - will be fully implemented in next phase. - -This agent will eventually include: -- get_data_components tool (from smart_agent_backup.py) -- python_repl tool for analysis and visualization -- image_viewer tool for scientific interpretation of plots -- Statistical calculations and trend analysis +This agent mirrors PangaeaGPT's oceanographer/visualization style while operating +on local climatology. It filters context, then uses tools to extract or analyze +climate data, saving outputs into the sandbox. """ -from climsight_classes import AgentState +import json import logging +import os +from typing import Any, Dict, List + +from climsight_classes import AgentState + +try: + from utils import make_json_serializable +except ImportError: + from .utils import make_json_serializable +from sandbox_utils import ensure_thread_id, ensure_sandbox_dirs, get_sandbox_paths +from agent_helpers import create_standard_agent_executor +from tools.get_data_components import create_get_data_components_tool +from tools.era5_retrieval_tool import era5_retrieval_tool +from tools.python_repl import CustomPythonREPLTool +from tools.reflection_tools import reflect_tool +from tools.visualization_tools import ( + list_plotting_data_files_tool, + wise_agent_tool, +) + +from langchain_core.prompts import ChatPromptTemplate logger = logging.getLogger(__name__) -def data_analysis_agent(state: AgentState, config, api_key, api_key_local, stream_handler): - """ - Data analysis and visualization agent (stub implementation). - - This agent will be responsible for: - - Extracting specific climate data components based on user needs - - Performing statistical analysis and calculations - - Creating visualizations and plots - - Interpreting results using image_viewer - - Providing data-driven insights - - Args: - state: AgentState containing all previous agent outputs including: - - user: Original user query - - input_params: Location and other parameters - - df_list: Climate data from data_agent - - smart_agent_response: Background information (if smart_agent was enabled) - - ipcc_rag_agent_response: IPCC report insights - - general_rag_agent_response: General climate literature insights - - zero_rag_agent_response: Geographic/environmental context - config: Configuration dictionary - api_key: OpenAI API key - api_key_local: Local model API key - stream_handler: Stream handler for progress updates - - Returns: - dict: Updated state with analysis results - """ - stream_handler.update_progress("Data analysis agent (stub - passing through)...") - - logger.info("Data analysis agent stub called") - logger.info(f"User query: {state.user}") - logger.info(f"Location: {state.input_params.get('location_str', 'unknown')}") - - # TODO: Implement full data analysis functionality - # Phase 1 implementation will include: - # 1. Move get_data_components tool from smart_agent_backup.py - # 2. Move python_repl tool initialization with climate data context - # 3. Move image_viewer tool for plot interpretation - # 4. Create agent executor with data analysis prompts - # 5. Integrate with combine_agent for final synthesis - - # For now, just pass through with a stub message - stub_response = "Data analysis agent not yet implemented. Analysis and visualization features will be added in the next implementation phase." +def _build_climate_data_summary(df_list: List[Dict[str, Any]]) -> str: + """Summarize available climatology without exposing raw values.""" + if not df_list: + return "No climatology data available." + + lines = [] + for entry in df_list: + vars_summary = [] + for var_name, var_info in entry.get("extracted_vars", {}).items(): + full_name = var_info.get("full_name", var_name) + units = var_info.get("units", "") + vars_summary.append(f"{full_name} ({units})") + vars_text = ", ".join(vars_summary) if vars_summary else "Unknown variables" + lines.append( + f"- {entry.get('years_of_averaging', '')}: {entry.get('description', '')} | {vars_text}" + ) + + return "\n".join(lines) + + +def _build_datasets_text(state) -> str: + """Build simple dataset paths text for prompt injection.""" + lines = [ + "## Sandbox Paths (use these exact paths in your code)", + f"- Main directory: {state.uuid_main_dir}", + f"- Results directory (save plots here): {state.results_dir}", + f"- Climate data: {state.climate_data_dir}", + ] + + if state.era5_data_dir: + lines.append(f"- ERA5 data: {state.era5_data_dir}") + + # List available climate data files + if state.climate_data_dir and os.path.exists(state.climate_data_dir): + try: + files = os.listdir(state.climate_data_dir) + if files: + lines.append(f"\n## Climate Data Files Available") + lines.append(f"Files in climate_data/: {', '.join(files)}") + # Highlight the main data.csv file + if "data.csv" in files: + lines.append("Note: 'data.csv' contains the main climatology dataset") + except Exception as e: + logger.warning(f"Could not list climate data files: {e}") + + return "\n".join(lines) + + +def _build_filter_prompt() -> str: + """Prompt for the analysis brief filter LLM.""" + return ( + "You are a data analysis filter for a climate assistant.\n" + "Extract only actionable analysis requirements.\n\n" + "Output format (bullets only):\n" + "- Target variables (with units if specified)\n" + "- Thresholds or criteria\n" + "- Time ranges or scenarios\n" + "- Spatial specifics (location, buffers)\n" + "- Analysis tasks (comparisons, trends, plots)\n\n" + "Rules:\n" + "- Do NOT include raw climate data values.\n" + "- Do NOT include long RAG or Wikipedia text.\n" + "- Omit vague statements that are not actionable.\n" + ) + + +def _create_tool_prompt(datasets_text: str, config: dict) -> str: + """System prompt for tool-driven analysis - dynamically built based on config.""" + has_era5 = config.get("use_era5_data", False) + has_repl = config.get("use_powerful_data_analysis", False) + + prompt = """You are the data analysis agent for ClimSight. +Your job is to provide quantitative climate analysis with visualizations. + +## AVAILABLE TOOLS +""" + + tool_num = 1 + + # Only include get_data_components if Python_REPL is NOT available + if not has_repl: + prompt += f""" +{tool_num}. **get_data_components** - Extract climate variables from stored climatology + - Variables: Temperature, Precipitation, u_wind, v_wind + - Returns monthly values for historical AND future climate projections + - Example: get_data_components(environmental_data="Temperature", months=["Jan", "Feb", "Mar"]) + - ALWAYS use this first to get the baseline climatology data +""" + tool_num += 1 + + if has_era5: + prompt += f""" +{tool_num}. **retrieve_era5_data** - Download ERA5 reanalysis data (OBSERVED REALITY) + - Variables: 2m_temperature, total_precipitation, 10m_u_component_of_wind, 10m_v_component_of_wind + - Specify start_date, end_date (YYYY-MM-DD), and lat/lon bounds + - Data saved to sandbox as Zarr files for Python analysis + + **CRITICAL**: ERA5 is OBSERVED climate data (ground truth), NOT model projections! + - Use ERA5 to validate climate models by comparing with model historical period + - ERA5 shows the ACTUAL climate at this location (recent 10-20 years) + - Climate models may have bias - ERA5 reveals this bias + + **Standard workflow**: Download ~10 years of recent ERA5 data (e.g., 2014-2024) + - Example: retrieve_era5_data(variable_id="2m_temperature", start_date="2014-01-01", end_date="2023-12-31", min_latitude=48.0, max_latitude=49.0, min_longitude=11.0, max_longitude=12.0) + - Then use Python_REPL to compare ERA5 vs climate model historical period +""" + tool_num += 1 + + if has_repl: + era5_note = "" + if has_era5: + era5_note = """ + **ERA5 DATA PROCESSING** (when available): + - ERA5 data is saved as Zarr format in era5_data/ directory + - Use xarray to load: `ds = xr.open_zarr('era5_data/.zarr')` + - ERA5 = OBSERVED reality, climate_data/data.csv = MODEL projections + - ALWAYS compare ERA5 with model historical period to show model bias +""" + + prompt += f""" +{tool_num}. **Python_REPL** - Execute Python code for data analysis and visualizations + - Pre-loaded libraries: pandas (pd), numpy (np), matplotlib.pyplot (plt), xarray (xr) + - Climate data is AUTO-LOADED in the namespace (see data files below) + - Working directory is the sandbox root + - ALWAYS save plots to results/ directory + + **Climate Model Data** (climate_data/data.csv): + - Monthly climatology from climate models + - Columns: Month, mean2t (temperature), tp (precipitation), wind_u, wind_v + - Contains BOTH historical period (1995-2014) AND future projections (2020-2049) + - These are MODEL outputs, not observations! +{era5_note} + Example workflow (climate models only): + ```python + import pandas as pd + import matplotlib.pyplot as plt + + # Read climate model data + df = pd.read_csv('climate_data/data.csv') + print(df.head()) + + # Create temperature plot + plt.figure(figsize=(10, 6)) + plt.plot(df['Month'], df['mean2t'], marker='o', label='Temperature') + plt.xlabel('Month') + plt.ylabel('Temperature (°C)') + plt.title('Monthly Temperature') + plt.legend() + plt.savefig('results/temperature_plot.png', dpi=150, bbox_inches='tight') + plt.close() + print('Plot saved to results/temperature_plot.png') + ``` +""" + tool_num += 1 + + prompt += f""" +{tool_num}. **list_plotting_data_files** - List files in sandbox directories +{tool_num + 1}. **reflect** - Think through analysis approach before executing +{tool_num + 2}. **wise_agent** - Get guidance on complex analysis decisions + +## REQUIRED WORKFLOW +""" + + if has_era5 and has_repl: + # ERA5 + Python_REPL: Most comprehensive workflow + prompt += """ +**CRITICAL**: ERA5 is enabled - you MUST download and use it as ground truth! + +1. **DOWNLOAD ERA5 OBSERVATIONS** - This is MANDATORY, not optional: + - Call retrieve_era5_data for the last 10 years (e.g., 2014-01-01 to 2023-12-31) + - Download at least temperature and precipitation + - Use exact coordinates from the user query + - ERA5 = observed reality at this location + +2. **LOAD CLIMATE MODEL DATA** - Use Python_REPL: + - Read climate_data/data.csv (contains model historical + future projections) + - These are MODEL outputs, not observations + +3. **COMPARE ERA5 vs MODEL** - This reveals model bias: + - Load ERA5 data with xarray: `ds = xr.open_zarr('era5_data/.zarr')` + - Calculate ERA5 monthly climatology (mean by month across years) + - Compare with model historical period (1995-2014) + - Show the difference (bias) between observations and model + +4. **ANALYZE FUTURE PROJECTIONS** - With ERA5 context: + - Show model future projections + - Explain how model bias affects confidence in projections + - Use ERA5 as the baseline ("current climate") + +5. **CREATE VISUALIZATIONS** - MANDATORY: + - Plot 1: ERA5 observations vs model historical vs model future (3-way comparison) + - Plot 2: Model bias (ERA5 minus model historical) + - Plot 3: Future change (model future minus ERA5 baseline) + - Save ALL plots to results/ directory +""" + elif has_repl and not has_era5: + # Python_REPL without ERA5: Standard workflow + prompt += """ +1. **READ THE DATA** - Use Python_REPL to inspect climate_data/data.csv +2. **ANALYZE** - Compare historical vs future projections in the data +3. **CREATE VISUALIZATIONS** - This is MANDATORY: + - Plot monthly temperature comparison (historical vs future) + - Plot precipitation patterns if relevant + - Plot wind data if relevant to the query + - Save ALL plots to results/ directory +""" + elif not has_repl: + # Without Python_REPL, use get_data_components + prompt += """ +1. **ALWAYS START** by calling get_data_components for Temperature (all months) +2. **THEN** call get_data_components for Precipitation (all months) +3. **ANALYZE** the data - compare historical vs future projections +""" + + if has_era5: + prompt += """ + +**NOTE**: ERA5 is enabled but Python_REPL is not. You can download ERA5 data, but +you won't be able to process it effectively. Consider using get_data_components only. +""" + + prompt += f""" +## SANDBOX PATHS AND DATA + +{datasets_text} + +## PROACTIVE ANALYSIS + +Even if the user doesn't explicitly ask for plots, you SHOULD: +- Create temperature trend visualizations +- Show precipitation comparisons between historical and future +- Highlight months with the largest projected changes +- Identify potential climate risks (e.g., heat stress, drought, flooding) +""" + + if has_era5: + prompt += """ +**With ERA5 data, ALWAYS include:** +- Model bias analysis (ERA5 vs model historical) +- Current climate baseline (from ERA5 observations) +- Confidence assessment (how well do models match observations?) +""" + + prompt += """ +## OUTPUT FORMAT + +Your final response should include: +""" + + if has_era5 and has_repl: + prompt += """1. **Current Climate (ERA5 Observations)**: Recent observed temperature, precipitation +2. **Model Performance**: How well climate models match ERA5 observations (bias analysis) +3. **Future Projections**: Model predictions with context of model bias +4. **Climate Change Signal**: Differences between ERA5 baseline and future projections +5. **Critical Months**: Which months show largest changes +6. **Visualizations**: List of plot files created (MANDATORY: 3-way comparison plots) +7. **Implications**: Interpretation with confidence levels based on model-observation agreement +""" + else: + prompt += """1. **Key Climate Values**: Extracted temperature, precipitation data +2. **Climate Change Signal**: Differences between historical and future projections +3. **Critical Months**: Which months show largest changes +4. **Visualizations**: List of plot files created (if Python_REPL available) +5. **Implications**: Brief interpretation relevant to the user's query +""" + + prompt += """ +Limit total tool calls to 20. +""" + return prompt + + +def _normalize_tool_observation(observation: Any) -> Any: + """Normalize tool output into a plain Python object.""" + try: + from langchain_core.messages import AIMessage + except Exception: + AIMessage = None + + if AIMessage is not None and isinstance(observation, AIMessage): + return observation.content + return observation + + +def data_analysis_agent( + state: AgentState, + config: dict, + api_key: str, + api_key_local: str, + stream_handler, + llm_dataanalysis_agent=None, +): + """Run filtered analysis + tool-based climatology extraction.""" + stream_handler.update_progress("Data analysis: preparing sandbox...") + + # Ensure sandbox paths are available. + thread_id = ensure_thread_id(existing_thread_id=state.thread_id) + sandbox_paths = get_sandbox_paths(thread_id) + ensure_sandbox_dirs(sandbox_paths) + + state.thread_id = thread_id + state.uuid_main_dir = sandbox_paths["uuid_main_dir"] + state.results_dir = sandbox_paths["results_dir"] + state.climate_data_dir = sandbox_paths["climate_data_dir"] + state.era5_data_dir = sandbox_paths["era5_data_dir"] + + # Build analysis context for filtering. + climate_summary = _build_climate_data_summary(state.df_list) + context_sections = [ + f"User query: {state.user}", + f"Location: {state.input_params.get('location_str', '')}", + f"Coordinates: {state.input_params.get('lat', '')}, {state.input_params.get('lon', '')}", + f"Climatology summary:\n{climate_summary}", + ] + + if state.ipcc_rag_agent_response: + context_sections.append(f"IPCC RAG: {state.ipcc_rag_agent_response}") + if state.general_rag_agent_response: + context_sections.append(f"General RAG: {state.general_rag_agent_response}") + if state.smart_agent_response: + context_sections.append(f"Smart agent: {state.smart_agent_response.get('output', '')}") + if state.ecocrop_search_response: + context_sections.append(f"ECOCROP: {state.ecocrop_search_response}") + if state.zero_agent_response: + safe_zero_context = make_json_serializable(state.zero_agent_response) + context_sections.append(f"Local context: {json.dumps(safe_zero_context, indent=2)}") + + analysis_context = "\n\n".join(context_sections) + + # Check if filter step is enabled (configurable) + use_filter_step = config.get("llm_dataanalysis", {}).get("use_filter_step", True) + + if use_filter_step and llm_dataanalysis_agent is not None: + stream_handler.update_progress("Data analysis: filtering context...") + filter_prompt = ChatPromptTemplate.from_messages( + [ + ("system", _build_filter_prompt()), + ("user", "{context}"), + ] + ) + result = llm_dataanalysis_agent.invoke(filter_prompt.format_messages(context=analysis_context)) + filtered_context = result.content if hasattr(result, "content") else str(result) + + # CRITICAL: Always preserve the user's original question + location_str = state.input_params.get('location_str', 'Unknown location') + analysis_brief = f"""USER QUESTION: {state.user} + +Location: {location_str} +Coordinates: {state.input_params.get('lat', '')}, {state.input_params.get('lon', '')} + +ANALYSIS REQUIREMENTS: +{filtered_context} +""" + else: + # Skip filter step - pass essential context directly + stream_handler.update_progress("Data analysis: preparing context (no filter)...") + location_str = state.input_params.get('location_str', 'Unknown location') + analysis_brief = f"""USER QUESTION: {state.user} + +Location: {location_str} +Coordinates: {state.input_params.get('lat', '')}, {state.input_params.get('lon', '')} + +Available climatology: +{climate_summary} + +Required analysis: +- Extract Temperature and Precipitation data +- Compare historical vs future projections +- Create visualizations if Python_REPL is available +""" + + state.data_analysis_prompt_text = analysis_brief + + brief_path = os.path.join(state.uuid_main_dir, "analysis_brief.txt") + with open(brief_path, "w", encoding="utf-8") as f: + f.write(analysis_brief) + + # Build simplified datasets_text for prompt + datasets_text = _build_datasets_text(state) + + # Build datasets dict for Python REPL + datasets = { + "uuid_main_dir": state.uuid_main_dir, + "results_dir": state.results_dir, + } + if state.climate_data_dir: + datasets["climate_data_dir"] = state.climate_data_dir + if state.era5_data_dir: + datasets["era5_data_dir"] = state.era5_data_dir + + # Tool setup - ORDER MATTERS (matches prompt workflow) + tools = [] + + has_python_repl = config.get("use_powerful_data_analysis", False) + + # 1. get_data_components - ONLY if Python_REPL is NOT available + # (When Python_REPL is enabled, it's redundant - agent reads CSV directly) + if not has_python_repl: + tools.append(create_get_data_components_tool(state, config, stream_handler)) + + # 2. ERA5 retrieval (if enabled) + if config.get("use_era5_data", False): + tools.append(era5_retrieval_tool) + + # 3. Python REPL for analysis/visualization (if enabled) + if has_python_repl: + repl_tool = CustomPythonREPLTool( + datasets=datasets, + results_dir=state.results_dir, + session_key=thread_id, + ) + tools.append(repl_tool) + + # 4. Helper tools + tools.extend([ + list_plotting_data_files_tool, + reflect_tool, + wise_agent_tool, + ]) + + stream_handler.update_progress("Data analysis: running tools...") + tool_prompt = _create_tool_prompt(datasets_text, config) + + if llm_dataanalysis_agent is None: + from langchain_openai import ChatOpenAI + + llm_dataanalysis_agent = ChatOpenAI( + openai_api_key=api_key, + model_name=config.get("llm_combine", {}).get("model_name", "gpt-4.1-nano"), + ) + + agent_executor = create_standard_agent_executor( + llm_dataanalysis_agent, + tools, + tool_prompt, + max_iterations=20, + ) + + agent_input = { + "input": analysis_brief or state.user, + "messages": state.messages, + } + + result = agent_executor(agent_input) + + data_components_outputs = [] + plot_images: List[str] = [] + + for action, observation in result.get("intermediate_steps", []): + if action.tool == "get_data_components": + data_components_outputs.append(_normalize_tool_observation(observation)) + if action.tool in ("Python_REPL", "python_repl"): + obs = _normalize_tool_observation(observation) + if isinstance(obs, dict): + plot_images.extend(obs.get("plot_images", [])) + if action.tool == "retrieve_era5_data": + # Keep ERA5 outputs in state only (no raw text in summary). + state.input_params.setdefault("era5_results", []).append( + _normalize_tool_observation(observation) + ) + + analysis_text = result.get("output", "") + if data_components_outputs: + analysis_text += "\n\nClimatology extracts:\n" + for item in data_components_outputs: + analysis_text += json.dumps(item, indent=2) + "\n" + + state.data_analysis_response = analysis_text + state.data_analysis_images = plot_images + + stream_handler.update_progress("Data analysis complete.") return { - 'data_analysis_response': stub_response + "data_analysis_response": analysis_text, + "data_analysis_images": plot_images, + "data_analysis_prompt_text": analysis_brief, } diff --git a/src/climsight/sandbox_utils.py b/src/climsight/sandbox_utils.py new file mode 100644 index 0000000..963731c --- /dev/null +++ b/src/climsight/sandbox_utils.py @@ -0,0 +1,130 @@ +""" +Sandbox utilities for per-session data storage. + +This mirrors the PangaeaGPT layout while keeping the API minimal for Climsight. +""" + +import json +import logging +import os +import uuid +from pathlib import Path +from typing import Dict, List, Tuple + +logger = logging.getLogger(__name__) + + +def ensure_thread_id(session_state=None, existing_thread_id: str = "") -> str: + """Ensure a stable session thread_id across Streamlit/CLI runs.""" + thread_id = existing_thread_id or "" + + if not thread_id and session_state is not None: + thread_id = session_state.get("thread_id", "") + + if not thread_id: + thread_id = uuid.uuid4().hex + + if session_state is not None: + session_state["thread_id"] = thread_id + + # Expose to non-Streamlit tools (CLI, background workers). + os.environ.setdefault("CLIMSIGHT_THREAD_ID", thread_id) + + return thread_id + + +def get_sandbox_paths(thread_id: str) -> Dict[str, str]: + """Return sandbox paths for a given session.""" + base_dir = Path("tmp") / "sandbox" / thread_id + return { + "uuid_main_dir": str(base_dir), + "results_dir": str(base_dir / "results"), + "climate_data_dir": str(base_dir / "climate_data"), + "era5_data_dir": str(base_dir / "era5_data"), + } + + +def ensure_sandbox_dirs(paths: Dict[str, str]) -> None: + """Create sandbox directories if they do not exist.""" + for key, path in paths.items(): + if not path: + continue + os.makedirs(path, exist_ok=True) + logger.debug("Ensured sandbox dir %s: %s", key, path) + + +def write_climate_data_manifest( + df_list: List[Dict], + climate_data_dir: str, + source: str, +) -> Tuple[str, List[Dict]]: + """Persist climate dataframes and metadata into the sandbox. + + Returns: + (manifest_path, entries) + """ + os.makedirs(climate_data_dir, exist_ok=True) + + entries: List[Dict] = [] + main_index = 0 + + for i, entry in enumerate(df_list): + if entry.get("main"): + main_index = i + break + + for i, entry in enumerate(df_list): + df = entry.get("dataframe") + if df is None: + continue + + csv_name = f"simulation_{i + 1}.csv" + meta_name = f"simulation_{i + 1}_meta.json" + csv_path = os.path.join(climate_data_dir, csv_name) + meta_path = os.path.join(climate_data_dir, meta_name) + + df.to_csv(csv_path, index=False) + + meta = { + "years_of_averaging": entry.get("years_of_averaging", ""), + "description": entry.get("description", ""), + "extracted_vars": entry.get("extracted_vars", {}), + "main": bool(entry.get("main", False)), + "source": entry.get("source", source), + "filename": entry.get("filename", ""), + } + with open(meta_path, "w", encoding="utf-8") as f: + json.dump(meta, f, indent=2) + + entries.append({ + "csv": csv_name, + "meta": meta_name, + "years_of_averaging": meta["years_of_averaging"], + "description": meta["description"], + "main": meta["main"], + }) + + # Provide a simple auto-load CSV for Python_REPL parity. + if df_list: + main_df = df_list[main_index].get("dataframe") + if main_df is not None: + main_df.to_csv(os.path.join(climate_data_dir, "data.csv"), index=False) + + manifest = { + "source": source, + "entries": entries, + } + manifest_path = os.path.join(climate_data_dir, "climate_data_manifest.json") + with open(manifest_path, "w", encoding="utf-8") as f: + json.dump(manifest, f, indent=2) + + return manifest_path, entries + + +def load_climate_data_manifest(manifest_path: str) -> Dict: + """Load a climate data manifest from disk.""" + if not manifest_path or not os.path.exists(manifest_path): + return {} + + with open(manifest_path, "r", encoding="utf-8") as f: + return json.load(f) diff --git a/src/climsight/smart_agent.py b/src/climsight/smart_agent.py index 30255da..a5b21b1 100644 --- a/src/climsight/smart_agent.py +++ b/src/climsight/smart_agent.py @@ -107,6 +107,8 @@ def smart_agent(state: AgentState, config, api_key, api_key_local, stream_handle - Focus on gathering contextual information, NOT on extracting or analyzing specific climate data - Do NOT attempt to retrieve climate data values - that will be handled by other agents - Your output should be a comprehensive text summary with relevant background context + - If you call a tool multiple times, incorporate all results in your summary + - Do NOT call wikipedia_search more than 10 times total """ if config['llm_smart']['model_type'] in ("local", "aitta"): prompt += f""" @@ -151,13 +153,20 @@ def smart_agent(state: AgentState, config, api_key, api_key_local, stream_handle - Do NOT mention what you're going to do or explain your process - Focus on presenting the gathered information in a clear, organized format - The summary should help subsequent agents understand the context for data analysis + - If you have already called wikipedia_search 10 times, proceed without further Wikipedia calls """ #[1] get_data_components tool moved to data_analysis_agent.py #[2] Wikipedia processing tool - def process_wikipedia_article(query: str) -> str: + wikipedia_call_state = {"count": 0} + + def process_wikipedia_article(query: str) -> str: + if wikipedia_call_state["count"] >= 10: + return "" + wikipedia_call_state["count"] += 1 + stream_handler.update_progress("Searching Wikipedia for related information with a smart agent...") # Initialize the LLM @@ -562,48 +571,61 @@ def process_ecocrop_search(query: str) -> str: result = agent_executor(agent_input) # Extract the tool outputs - tool_outputs = {} + wikipedia_results = [] + rag_results = [] + + def _extract_tool_query(action): + tool_input = getattr(action, "tool_input", None) + if tool_input is None: + tool_input = getattr(action, "input", None) + if isinstance(tool_input, dict): + return tool_input.get("query") or tool_input.get("input") or tool_input + return tool_input + + def _normalize_tool_output(observation): + if isinstance(observation, AIMessage): + return observation.content, None + if isinstance(observation, dict): + result = observation.get("result") + if isinstance(result, AIMessage): + result = result.content + return result, observation.get("references") + return observation, None + for action, observation in result['intermediate_steps']: - if action.tool == 'wikipedia_search': - if isinstance(observation, AIMessage): - tool_outputs['wikipedia_search'] = observation.content - else: - output = observation - # Assuming output_data is a dict now - if isinstance(output, dict): - tool_outputs['wikipedia_search'] = output.get('result').content - state.references.append(output.get('references', [])) - else: - tool_outputs['wikipedia_search'] = output - elif action.tool == 'ECOCROP_search': - if isinstance(observation, AIMessage): - tool_outputs['ECOCROP_search'] = observation.content - else: - tool_outputs['ECOCROP_search'] = observation + tool_name = action.tool + tool_query = _extract_tool_query(action) + + if tool_name == 'wikipedia_search': + answer_text, references = _normalize_tool_output(observation) + if answer_text: + wikipedia_results.append({"query": tool_query, "answer": answer_text}) + if references: + if isinstance(references, list): + state.references.extend(references) + else: + state.references.append(references) + elif tool_name == 'RAG_search': + answer_text, references = _normalize_tool_output(observation) + if answer_text: + rag_results.append({"query": tool_query, "answer": answer_text}) + if references: + if isinstance(references, list): + state.references.extend(references) + else: + state.references.append(references) + elif tool_name == 'ECOCROP_search': + answer_text, _ = _normalize_tool_output(observation) + if answer_text: + if state.ecocrop_search_response: + state.ecocrop_search_response += "\n" + answer_text + else: + state.ecocrop_search_response = answer_text if any("FAO, IIASA" not in element for element in state.references): state.references.append("FAO, IIASA: Global Agro-Ecological Zones (GAEZ V4) - Data Portal User's Guide, 1st edn. FAO and IIASA, Rome, Italy (2021). https://doi.org/10.4060/cb5167en") - if action.tool == 'RAG_search': - if isinstance(observation, AIMessage): - tool_outputs['RAG_search'] = observation.content - else: - output = observation - # Assuming output_data is a dict now - if isinstance(output, dict): - tool_outputs['RAG_search'] = output.get('result').content - # If refs is a list, extend; if it's a string, append. - refs = output.get('references', []) - if isinstance(refs, list): - state.references.extend(refs) - elif isinstance(refs, str): - state.references.append(refs) - - # Store the response from the wikipedia_search tool into state - if 'wikipedia_search' in tool_outputs: - state.wikipedia_tool_response = tool_outputs['wikipedia_search'] - if 'ECOCROP_search' in tool_outputs: - state.ecocrop_search_response = tool_outputs['ECOCROP_search'] - if 'RAG_search' in tool_outputs: - state.rag_search_response = tool_outputs['RAG_search'] + + state.wikipedia_tool_response = wikipedia_results + state.rag_search_response = rag_results # Also store the agent's final answer smart_agent_response = result['output'] diff --git a/src/climsight/streamlit_interface.py b/src/climsight/streamlit_interface.py index 394b804..3e16eb3 100644 --- a/src/climsight/streamlit_interface.py +++ b/src/climsight/streamlit_interface.py @@ -22,6 +22,7 @@ from extract_climatedata_functions import plot_climate_data from embedding_utils import create_embeddings from climate_data_providers import get_available_providers +from sandbox_utils import ensure_thread_id, ensure_sandbox_dirs, get_sandbox_paths #ui for saving docs from datetime import datetime @@ -42,7 +43,11 @@ def run_streamlit(config, api_key='', skip_llm_call=False, rag_activated=True, r - references (dict): References for the data used in the analysis. Returns: None - """ + """ + # Ensure sandbox exists for the current session. + thread_id = ensure_thread_id(session_state=st.session_state) + sandbox_paths = get_sandbox_paths(thread_id) + ensure_sandbox_dirs(sandbox_paths) # Config try: @@ -133,6 +138,16 @@ def run_streamlit(config, api_key='', skip_llm_call=False, rag_activated=True, r show_add_info = st.toggle("Provide additional information", value=False, help="""If this is activated you will see all the variables that were taken into account for the analysis as well as some plots.""") smart_agent = st.toggle("Use extra search", value=False, help="""If this is activated, ClimSight will make additional requests to Wikipedia and RAG, which can significantly increase response time.""") + use_era5_data = st.toggle( + "Enable ERA5 data", + value=config.get("use_era5_data", False), + help="Allow the data analysis agent to retrieve ERA5 data into the sandbox.", + ) + use_powerful_data_analysis = st.toggle( + "Enable Python analysis", + value=config.get("use_powerful_data_analysis", False), + help="Allow the data analysis agent to use the Python REPL and generate plots.", + ) # remove the llmModeKey_box from the form, as we tend to run the agent mode, direct mode is for development only #llmModeKey_box = st.radio("Select LLM mode 👉", key="visibility", options=["Direct", "Agent (experimental)"]) @@ -194,6 +209,8 @@ def run_streamlit(config, api_key='', skip_llm_call=False, rag_activated=True, r #config['llmModeKey'] = "direct_llm" if llmModeKey_box == "Direct" else "agent_llm" config['show_add_info'] = show_add_info config['use_smart_agent'] = smart_agent + config['use_era5_data'] = use_era5_data + config['use_powerful_data_analysis'] = use_powerful_data_analysis # RUN submit button if submit_button and user_message: @@ -206,6 +223,8 @@ def run_streamlit(config, api_key='', skip_llm_call=False, rag_activated=True, r #config['llmModeKey'] = "direct_llm" if llmModeKey_box == "Direct" else "agent_llm" config['show_add_info'] = show_add_info config['use_smart_agent'] = smart_agent + config['use_era5_data'] = use_era5_data + config['use_powerful_data_analysis'] = use_powerful_data_analysis # Creating a potential bottle neck here with loading the db inside the streamlit form, but it works fine # for the moment. Just making a note here for any potential problems that might arise later one. @@ -266,6 +285,9 @@ def run_streamlit(config, api_key='', skip_llm_call=False, rag_activated=True, r is_on_land = False st.markdown(f"The selected point is in the ocean.\n Please choose a location on land.") else: + # Pass sandbox paths into the agent state. + input_params['thread_id'] = thread_id + input_params.update(sandbox_paths) # extend input_params with user_message input_params['user_message'] = user_message content_message = "Human request: {user_message} \n " + content_message @@ -431,6 +453,16 @@ def update_progress_ui(message): with st.expander("Source"): st.markdown(model_info) + # Data analysis images (from python_repl) + analysis_images = stored_input_params.get('data_analysis_images', []) + if analysis_images: + st.markdown("**Data analysis visuals:**") + for image_path in analysis_images: + if os.path.exists(image_path): + st.image(image_path) + else: + st.caption(f"Missing image: {image_path}") + # Natural Hazards if 'haz_fig' in stored_figs: st.markdown("**Natural hazards:**") @@ -490,4 +522,4 @@ def update_progress_ui(message): key=f"download_button_txt_{timestamp}" ) - return \ No newline at end of file + return diff --git a/src/climsight/terminal_interface.py b/src/climsight/terminal_interface.py index 6909357..ea2cc6f 100644 --- a/src/climsight/terminal_interface.py +++ b/src/climsight/terminal_interface.py @@ -17,6 +17,7 @@ from data_container import DataContainer from extract_climatedata_functions import plot_climate_data +from sandbox_utils import ensure_thread_id, ensure_sandbox_dirs, get_sandbox_paths logger = logging.getLogger(__name__) @@ -57,6 +58,11 @@ def run_terminal(config, api_key='', skip_llm_call=False, lon=None, lat=None, us logging.error(f"Missing configuration key: {e}") raise RuntimeError(f"Missing configuration key: {e}") + # Ensure sandbox exists for this CLI session. + thread_id = ensure_thread_id() + sandbox_paths = get_sandbox_paths(thread_id) + ensure_sandbox_dirs(sandbox_paths) + if not isinstance(skip_llm_call, bool): logging.error(f"skip_llm_call must be bool") raise TypeError("skip_llm_call must be bool") @@ -192,6 +198,9 @@ def run_terminal(config, api_key='', skip_llm_call=False, lon=None, lat=None, us is_on_land = False print_verbose(verbose, f"The selected point is in the ocean. Please choose a location on land.") else: + # Pass sandbox paths into the agent state. + input_params['thread_id'] = thread_id + input_params.update(sandbox_paths) # extend input_params with user_message input_params['user_message'] = user_message content_message = "Human request: {user_message} \n " + content_message @@ -331,4 +340,4 @@ def print_progress(message): #print(f"Time for forming request: {forming_request_time}") #print(f"Time for LLM request: {llm_request_time}") - return output, input_params, content_message, combine_agent_prompt_text \ No newline at end of file + return output, input_params, content_message, combine_agent_prompt_text diff --git a/src/climsight/tools/__init__.py b/src/climsight/tools/__init__.py index b46b142..bae2b6c 100644 --- a/src/climsight/tools/__init__.py +++ b/src/climsight/tools/__init__.py @@ -3,7 +3,14 @@ Tools for Climsight smart agents. """ -from .python_repl import create_python_repl_tool from .image_viewer import create_image_viewer_tool +from .python_repl import CustomPythonREPLTool -__all__ = ['create_python_repl_tool', 'create_image_viewer_tool'] \ No newline at end of file +__all__ = ['CustomPythonREPLTool', 'create_image_viewer_tool'] + +# Backwards-compatible import if older code expects create_python_repl_tool. +try: + from .python_repl import create_python_repl_tool # type: ignore + __all__.append('create_python_repl_tool') +except Exception: + pass diff --git a/src/climsight/tools/era5_retrieval_tool.py b/src/climsight/tools/era5_retrieval_tool.py new file mode 100644 index 0000000..c0ecfef --- /dev/null +++ b/src/climsight/tools/era5_retrieval_tool.py @@ -0,0 +1,275 @@ +# src/tools/era5_retrieval_tool.py +""" +src/tools/era5_retrieval_tool.py + +ERA5 data retrieval tool for use in the visualization agent. +Retrieves climate data from the Google Cloud ARCO-ERA5 dataset. +Saves the retrieved data ONLY in Zarr format. +The Zarr filename is based on a hash of the request parameters. +Optimized for memory efficiency using lazy loading (Dask) and streaming. +""" + +import os +import sys +import logging +import uuid +import hashlib +import shutil +import xarray as xr +import pandas as pd +import numpy as np +from pydantic import BaseModel, Field +from typing import Optional, Literal +from langchain_core.tools import StructuredTool + +# --- IMPORTS & CONFIGURATION --- +try: + import zarr + import gcsfs + # Optional: Check for Streamlit to support session state if available + try: + import streamlit as st + except ImportError: + st = None +except ImportError as e: + install_command = "pip install --upgrade xarray zarr numcodecs gcsfs blosc pandas numpy pydantic langchain-core" + raise ImportError( + f"Required libraries missing. Please ensure zarr, gcsfs, and others are installed.\n" + f"Try running: {install_command}" + ) from e + +logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') + +# Target the Analysis-Ready (AR) Zarr store (Chunk-1 optimized) +ARCO_ERA5_MAIN_ZARR_STORE = 'gs://gcp-public-data-arco-era5/ar/full_37-1h-0p25deg-chunk-1.zarr-v3' + +class ERA5RetrievalArgs(BaseModel): + variable_id: Literal[ + "sea_surface_temperature", "surface_pressure", "total_cloud_cover", "total_precipitation", + "10m_u_component_of_wind", "10m_v_component_of_wind", "2m_temperature", "2m_dewpoint_temperature", + "geopotential", "specific_humidity", "temperature", "u_component_of_wind", + "v_component_of_wind", "vertical_velocity" + ] = Field(description="ERA5 variable to retrieve (must match Zarr store names).") + start_date: str = Field(description="Start date (YYYY-MM-DD or YYYY-MM-DD HH:MM:SS).") + end_date: str = Field(description="End date (YYYY-MM-DD or YYYY-MM-DD HH:MM:SS).") + min_latitude: float = Field(-90.0, description="Minimum latitude (-90 to 90).") + max_latitude: float = Field(90.0, description="Maximum latitude (-90 to 90).") + min_longitude: float = Field(0.0, description="Minimum longitude (0-360 or -180 to 360).") + max_longitude: float = Field(359.75, description="Maximum longitude (0-360 or -180 to 360).") + pressure_level: Optional[int] = Field(None, description="Pressure level in hPa for 3D variables.") + +def _generate_descriptive_filename( + variable_id: str, start_date: str, end_date: str, + pressure_level: Optional[int] = None +) -> str: + """ + Generate a descriptive filename based on request parameters. + """ + clean_variable = variable_id.replace('-', '_').replace('/', '_').replace(' ', '_') + clean_start = start_date.split()[0].replace('-', '_') + clean_end = end_date.split()[0].replace('-', '_') + + filename_parts = [clean_variable, clean_start, "to", clean_end] + if pressure_level is not None: + filename_parts.extend(["level", str(pressure_level)]) + + filename = "_".join(filename_parts) + ".zarr" + return filename + +def _generate_request_hash( + variable_id: str, start_date: str, end_date: str, + min_latitude: float, max_latitude: float, + min_longitude: float, max_longitude: float, + pressure_level: Optional[int] +) -> str: + params_string = ( + f"{variable_id}-{start_date}-{end_date}-" + f"{min_latitude:.2f}-{max_latitude:.2f}-" + f"{min_longitude:.2f}-{max_longitude:.2f}-" + f"{pressure_level if pressure_level is not None else 'None'}" + ) + return hashlib.md5(params_string.encode('utf-8')).hexdigest() + +def retrieve_era5_data( + variable_id: str, start_date: str, end_date: str, + min_latitude: float = -90.0, max_latitude: float = 90.0, + min_longitude: float = 0.0, max_longitude: float = 359.75, + pressure_level: Optional[int] = None +) -> dict: + ds_arco = None + zarr_path = None + + try: + logging.info(f"ERA5 retrieval: var={variable_id}, time={start_date} to {end_date}, " + f"lat=[{min_latitude},{max_latitude}], lon=[{min_longitude},{max_longitude}], " + f"level={pressure_level}") + + # --- 1. Sandbox / Path Logic --- + main_dir = None + if "streamlit" in sys.modules and st is not None and hasattr(st, 'session_state'): + try: + thread_id = st.session_state.get("thread_id") + if thread_id: + main_dir = os.path.join("tmp", "sandbox", thread_id) + logging.info(f"Found session thread_id. Using persistent sandbox: {main_dir}") + except Exception: + pass + + # CLI fallback: use environment-provided session id when available. + if not main_dir: + env_thread_id = os.environ.get("CLIMSIGHT_THREAD_ID") + if env_thread_id: + main_dir = os.path.join("tmp", "sandbox", env_thread_id) + logging.info(f"Using CLI sandbox from CLIMSIGHT_THREAD_ID: {main_dir}") + + if not main_dir: + main_dir = os.path.join("tmp", "sandbox", "era5_data") + logging.info(f"Using general sandbox: {main_dir}") + + os.makedirs(main_dir, exist_ok=True) + era5_specific_dir = os.path.join(main_dir, "era5_data") + os.makedirs(era5_specific_dir, exist_ok=True) + + # Check cache + zarr_filename = _generate_descriptive_filename(variable_id, start_date, end_date, pressure_level) + zarr_path = os.path.join(era5_specific_dir, zarr_filename) + relative_zarr_path = os.path.join(os.path.basename(era5_specific_dir), zarr_filename) + + if os.path.exists(zarr_path): + logging.info(f"Data already cached at {zarr_path}") + return { + "success": True, + "output_path_zarr": relative_zarr_path, + "full_path": zarr_path, + "variable": variable_id, + "message": f"Cached ERA5 data found at {relative_zarr_path}" + } + + # --- 2. Open Dataset (Lazy / Dask) --- + logging.info(f"Opening ERA5 dataset: {ARCO_ERA5_MAIN_ZARR_STORE}") + + # CRITICAL OPTIMIZATION: chunks={'time': 24} + # This prevents loading the whole dataset into RAM and creates an efficient download graph. + ds_arco = xr.open_zarr( + ARCO_ERA5_MAIN_ZARR_STORE, + chunks={'time': 24}, + consolidated=True, + storage_options={'token': 'anon'} + ) + + if variable_id not in ds_arco: + raise ValueError(f"Variable '{variable_id}' not found. Available: {list(ds_arco.data_vars)}") + + var_data_from_arco = ds_arco[variable_id] + + # --- 3. Lazy Subsetting --- + start_datetime_obj = pd.to_datetime(start_date) + end_datetime_obj = pd.to_datetime(end_date) + time_filtered_data = var_data_from_arco.sel(time=slice(start_datetime_obj, end_datetime_obj)) + + lat_slice_coords = slice(max_latitude, min_latitude) + lon_min_request_360 = min_longitude % 360 + lon_max_request_360 = max_longitude % 360 + + # Anti-Meridian / Date Line Logic + if lon_min_request_360 > lon_max_request_360: + part1 = time_filtered_data.sel(latitude=lat_slice_coords, longitude=slice(lon_min_request_360, 359.75)) + part2 = time_filtered_data.sel(latitude=lat_slice_coords, longitude=slice(0, lon_max_request_360)) + + if part1.sizes.get('longitude', 0) > 0 and part2.sizes.get('longitude', 0) > 0: + space_filtered_data = xr.concat([part1, part2], dim='longitude') + elif part1.sizes.get('longitude', 0) > 0: + space_filtered_data = part1 + elif part2.sizes.get('longitude', 0) > 0: + space_filtered_data = part2 + else: + space_filtered_data = time_filtered_data.sel(latitude=lat_slice_coords, longitude=slice(0,0)) + else: + space_filtered_data = time_filtered_data.sel( + latitude=lat_slice_coords, + longitude=slice(lon_min_request_360, lon_max_request_360) + ) + + if pressure_level is not None and "level" in space_filtered_data.coords: + space_filtered_data = space_filtered_data.sel(level=pressure_level) + + if not all(dim_size > 0 for dim_size in space_filtered_data.sizes.values()): + msg = "Selected region/time period has zero size." + logging.warning(msg) + return {"success": False, "error": msg, "message": f"Failed: {msg}"} + + # --- 4. Prepare Output Dataset --- + final_subset_ds = space_filtered_data.to_dataset(name=variable_id) + + # Clear encoding (fixes Zarr v3 codec conflicts) + for var in final_subset_ds.variables: + final_subset_ds[var].encoding = {} + + # RE-CHUNK for safe streaming (24h blocks are manageable for Mac RAM) + final_subset_ds = final_subset_ds.chunk({'time': 24, 'latitude': -1, 'longitude': -1}) + + # Metadata + request_hash_str = _generate_request_hash( + variable_id, start_date, end_date, min_latitude, max_latitude, + min_longitude, max_longitude, pressure_level + ) + final_subset_ds.attrs.update({ + 'title': f"ERA5 {variable_id} data subset", + 'source_arco_era5': ARCO_ERA5_MAIN_ZARR_STORE, + 'retrieval_parameters_hash': request_hash_str + }) + + if os.path.exists(zarr_path): + logging.warning(f"Zarr path {zarr_path} exists. Removing for freshness.") + shutil.rmtree(zarr_path) + + # --- 5. Streaming Write (NO .load()) --- + logging.info(f"Streaming data to Zarr store: {zarr_path}") + + # compute=True triggers the download/write graph chunk-by-chunk + final_subset_ds.to_zarr( + store=zarr_path, + mode='w', + consolidated=True, + compute=True + ) + + logging.info(f"Successfully saved: {zarr_path}") + + return { + "success": True, + "output_path_zarr": relative_zarr_path, + "full_path": zarr_path, + "variable": variable_id, + "message": f"ERA5 data saved to {relative_zarr_path}" + } + + except Exception as e: + logging.error(f"Error in ERA5 retrieval: {e}", exc_info=True) + error_msg = str(e) + + if "Resource Exhausted" in error_msg or "Too many open files" in error_msg: + error_msg += " (GCS access issue or system limits. Try again.)" + elif "Unsupported type for store_like" in error_msg: + error_msg += " (Issue with Zarr store path. Ensure gs:// URI is used.)" + + # Cleanup partial data on failure + if zarr_path and os.path.exists(zarr_path): + shutil.rmtree(zarr_path, ignore_errors=True) + + return {"success": False, "error": error_msg, "message": f"Failed: {error_msg}"} + + finally: + if ds_arco is not None: + ds_arco.close() + +era5_retrieval_tool = StructuredTool.from_function( + func=retrieve_era5_data, + name="retrieve_era5_data", + description=( + "Retrieves a subset of the ARCO-ERA5 Zarr climate reanalysis dataset. " + "Saves the data locally as a Zarr store. " + "Uses efficient streaming to prevent memory overflow on large files." + ), + args_schema=ERA5RetrievalArgs +) diff --git a/src/climsight/tools/get_data_components.py b/src/climsight/tools/get_data_components.py new file mode 100644 index 0000000..7eb0993 --- /dev/null +++ b/src/climsight/tools/get_data_components.py @@ -0,0 +1,206 @@ +"""Tool for extracting specific climate variables from stored climatology.""" + +import ast +import calendar +import json +import logging +import os +from typing import List, Optional, Union + +import numpy as np +import pandas as pd +from pydantic import BaseModel, Field + +try: + from ..sandbox_utils import load_climate_data_manifest +except ImportError: + from sandbox_utils import load_climate_data_manifest + +logger = logging.getLogger(__name__) + + +def _load_df_list_from_manifest(manifest_path: str, climate_data_dir: str) -> List[dict]: + """Rehydrate df_list from a sandbox manifest when state.df_list is absent.""" + manifest = load_climate_data_manifest(manifest_path) + if not manifest: + return [] + + entries = [] + for entry in manifest.get("entries", []): + csv_path = os.path.join(climate_data_dir, entry.get("csv", "")) + meta_path = os.path.join(climate_data_dir, entry.get("meta", "")) + if not os.path.exists(csv_path): + continue + + df = pd.read_csv(csv_path) + meta = {} + if os.path.exists(meta_path): + try: + with open(meta_path, "r", encoding="utf-8") as f: + meta = json.load(f) + except Exception: + meta = {} + + entries.append({ + "dataframe": df, + "extracted_vars": meta.get("extracted_vars", {}), + "years_of_averaging": meta.get("years_of_averaging", ""), + "description": meta.get("description", ""), + "main": meta.get("main", False), + "source": meta.get("source", ""), + }) + + return entries + + +class GetDataComponentsArgs(BaseModel): + environmental_data: Optional[str] = Field( + default=None, + description=( + "The type of environmental data to retrieve. " + "Choose from Temperature, Precipitation, u_wind, or v_wind." + ), + ) + months: Optional[Union[str, List[str]]] = Field( + default=None, + description=( + "List of months or a stringified list of month names to retrieve data for. " + "Each month should be one of 'Jan', 'Feb', ..., 'Dec'. " + "If not specified, data for all months will be retrieved." + ), + ) + + +def create_get_data_components_tool(state, config, stream_handler=None): + """Create a StructuredTool bound to the current agent state.""" + try: + from langchain.tools import StructuredTool + except ImportError: + from langchain_core.tools import StructuredTool + + def get_data_components(**kwargs): + if stream_handler is not None: + stream_handler.update_progress("Retrieving specific climatology values...") + + if isinstance(kwargs.get("months"), str): + try: + kwargs["months"] = ast.literal_eval(kwargs["months"]) + except (ValueError, SyntaxError): + kwargs["months"] = None + + args = GetDataComponentsArgs(**kwargs) + environmental_data = args.environmental_data + months = args.months + + if not environmental_data: + return {"error": "No environmental data type specified."} + + climate_source = config.get("climate_data_source", "nextGEMS") + + # Normalize common shorthand to canonical names. + env_normalized = environmental_data.strip() + env_aliases = { + "tp": "Precipitation", + "precipitation": "Precipitation", + "mean2t": "Temperature", + "tas": "Temperature", + "temp": "Temperature", + "t2m": "Temperature", + "wind_u": "u_wind", + "wind_v": "v_wind", + "uas": "u_wind", + "vas": "v_wind", + } + env_normalized = env_aliases.get(env_normalized, env_normalized) + + df_list = getattr(state, "df_list", None) + if not df_list: + manifest_path = state.input_params.get("climate_data_manifest", "") + climate_data_dir = state.input_params.get("climate_data_dir", "") or getattr( + state, "climate_data_dir", "" + ) + df_list = _load_df_list_from_manifest(manifest_path, climate_data_dir) + + if not df_list: + return {"error": "No climatology data available."} + + if climate_source in ("nextGEMS", "ICCP"): + environmental_mapping = { + "Temperature": "mean2t", + "Precipitation": "tp", + "u_wind": "wind_u", + "v_wind": "wind_v", + } + elif climate_source == "AWI_CM": + environmental_mapping = { + "Temperature": "Present Day Temperature", + "Precipitation": "Present Day Precipitation", + "u_wind": "u_wind", + "v_wind": "v_wind", + } + else: + environmental_mapping = { + "Temperature": "mean2t", + "Precipitation": "tp", + "u_wind": "wind_u", + "v_wind": "wind_v", + } + + if env_normalized not in environmental_mapping: + return {"error": f"Invalid environmental data type: {environmental_data}"} + + var_name = environmental_mapping[env_normalized] + + if not months: + months = [calendar.month_abbr[m] for m in range(1, 13)] + + month_mapping = {calendar.month_abbr[m]: calendar.month_name[m] for m in range(1, 13)} + selected_months = [] + for month in months: + if month in month_mapping: + selected_months.append(month_mapping[month]) + elif month in calendar.month_name: + selected_months.append(month) + + response = {} + for entry in df_list: + df = entry.get("dataframe") + extracted_vars = entry.get("extracted_vars", {}) + + if df is None: + continue + if "Month" not in df.columns: + continue + + if var_name in df.columns: + var_meta = extracted_vars.get(var_name, {"units": ""}) + data_values = df[df["Month"].isin(selected_months)][var_name].tolist() + ext_data = {month: float(np.round(value, 2)) for month, value in zip(selected_months, data_values)} + ext_exp = ( + f"Monthly mean values of {env_normalized}, {var_meta.get('units', '')} " + f"for years: {entry.get('years_of_averaging', '')}" + ) + response.update({ext_exp: ext_data}) + else: + matching_cols = [col for col in df.columns if environmental_data.lower() in col.lower()] + if matching_cols: + col_name = matching_cols[0] + data_values = df[df["Month"].isin(selected_months)][col_name].tolist() + ext_data = {month: float(np.round(value, 2)) for month, value in zip(selected_months, data_values)} + ext_exp = ( + f"Monthly mean values of {env_normalized} " + f"for years: {entry.get('years_of_averaging', '')}" + ) + response.update({ext_exp: ext_data}) + + if not response: + return {"error": f"Variable '{environmental_data}' not found in climatology."} + + return response + + return StructuredTool.from_function( + func=get_data_components, + name="get_data_components", + description="Retrieve specific climate variables from the saved climatology.", + args_schema=GetDataComponentsArgs, + ) diff --git a/src/climsight/tools/package_tools.py b/src/climsight/tools/package_tools.py new file mode 100644 index 0000000..4b4ecf1 --- /dev/null +++ b/src/climsight/tools/package_tools.py @@ -0,0 +1,37 @@ +# src/tools/package_tools.py +import sys +import subprocess +import logging +from pydantic import BaseModel, Field +from langchain_core.tools import StructuredTool + +def install_package(package_name: str, pip_options: str = ""): + """ + Installs a Python package using pip. + + Args: + package_name: The name of the package to install + pip_options: Additional pip options (e.g., '--force-reinstall') + + Returns: + str: Success or error message + """ + try: + command = [sys.executable, '-m', 'pip', 'install'] + pip_options.split() + [package_name] + subprocess.check_call(command) + return f"Package '{package_name}' installed successfully." + except Exception as e: + return f"Failed to install package '{package_name}': {e}" + +# Define the args schema for install_package +class InstallPackageArgs(BaseModel): + package_name: str = Field(description="The name of the package to install.") + pip_options: str = Field(default="", description="Additional pip options (e.g., '--force-reinstall').") + +# Create the install_package_tool +install_package_tool = StructuredTool.from_function( + func=install_package, + name="install_package", + description="Installs a Python package using pip. Use this tool if you encounter a ModuleNotFoundError or need a package that's not installed.", + args_schema=InstallPackageArgs +) \ No newline at end of file diff --git a/src/climsight/tools/python_repl.py b/src/climsight/tools/python_repl.py index 580e17a..f78fbf7 100644 --- a/src/climsight/tools/python_repl.py +++ b/src/climsight/tools/python_repl.py @@ -1,120 +1,433 @@ -# src/climsight/tools/python_repl.py -""" -Simple Python REPL Tool for Climsight with persistent state. -""" +# src/tools/python_repl.py +import os import sys import logging +import pandas as pd +import matplotlib.pyplot as plt +import xarray as xr +import streamlit as st from io import StringIO -from typing import Dict, Any +from pydantic import BaseModel, Field, PrivateAttr +from langchain_experimental.tools import PythonREPLTool +from typing import Any, Dict, Optional, List +import re +import threading +import json +import queue +import time +import atexit -logger = logging.getLogger(__name__) +# Import Jupyter Client +try: + # We use KernelManager for lifecycle management and KernelClient for communication + from jupyter_client import KernelManager + from jupyter_client.client import KernelClient +except ImportError: + logging.error("Jupyter client not installed. Please run 'pip install jupyter_client ipykernel'") + # Raise error to ensure environment is correctly set up + raise ImportError("Missing dependencies for persistent REPL. Install jupyter_client and ipykernel.") +try: + from ..utils import log_history_event +except ImportError: + from utils import log_history_event -class PersistentPythonREPL: +# --- New Jupyter Kernel Executor --- + +class JupyterKernelExecutor: """ - A persistent Python REPL that maintains state between executions. - Variables created in one execution persist to the next, like a Jupyter notebook. + Manages a persistent Jupyter kernel for code execution. + Replaces the previous subprocess-based PersistentREPL. """ - - def __init__(self): - """Initialize with an empty locals dictionary and pre-imported modules.""" - # Persistent locals dictionary - this is THE KEY FEATURE - self.locals = {} + def __init__(self, working_dir=None): + self._working_dir = working_dir + # Initialize the KernelManager to use the Climsight kernel by default. + kernel_name = os.environ.get("CLIMSIGHT_KERNEL_NAME", "climsight") + self.km = KernelManager(kernel_name=kernel_name) + self.kc: Optional[KernelClient] = None + self.is_initialized = False + self._start_kernel() + + def _start_kernel(self): + """Starts the Jupyter kernel.""" + logging.info("Starting Jupyter kernel...") - # Pre-import common modules into the persistent namespace - self.locals.update({ - 'pd': __import__('pandas'), - 'np': __import__('numpy'), - 'plt': __import__('matplotlib.pyplot', fromlist=['pyplot']), - 'xr': __import__('xarray'), - 'os': __import__('os'), - 'Path': __import__('pathlib').Path, - }) + # Ensure the ipykernel is available. + try: + # Accessing kernel_spec validates that the kernel is installed and found + self.km.kernel_spec + except Exception as e: + logging.warning( + "Could not find python3 kernel spec. Falling back to current interpreter. Error: %s", + e, + ) + self.km.kernel_cmd = [ + sys.executable, + "-m", + "ipykernel_launcher", + "-f", + "{connection_file}", + ] + + # If kernel spec points to a missing interpreter, fall back to current Python. + try: + kernel_cmd = list(self.km.kernel_cmd or []) + if kernel_cmd: + exec_path = kernel_cmd[0] + if os.path.isabs(exec_path) and not os.path.exists(exec_path): + logging.warning( + "Kernel spec uses missing interpreter: %s. Using current Python.", + exec_path, + ) + self.km.kernel_cmd = [ + sys.executable, + "-m", + "ipykernel_launcher", + "-f", + "{connection_file}", + ] + except Exception as e: + logging.warning("Failed to validate kernel_cmd: %s", e) + + # Start the kernel in the specified working directory (CWD) + if self._working_dir and os.path.exists(self._working_dir): + self.km.start_kernel(cwd=self._working_dir) + else: + self.km.start_kernel() + + # Create the synchronous client + self.kc = self.km.client() + self.kc.start_channels() - def execute(self, code: str) -> str: + # Wait for the kernel to be ready + try: + self.kc.wait_for_ready(timeout=60) + logging.info("Jupyter kernel started and ready.") + except RuntimeError as e: + logging.error(f"Kernel failed to start: {e}") + self.km.shutdown_kernel() + raise + + def _execute_code(self, code: str, timeout: float = 300.0) -> Dict[str, Any]: + """Executes code synchronously in the kernel and captures outputs.""" + if not self.kc: + return {"status": "error", "error": "Kernel client not available."} + + # Send the execution request + msg_id = self.kc.execute(code) + result = { + "status": "success", + "stdout": "", + "stderr": "", + "display_data": [] + } + start_time = time.time() + + # Process messages until the kernel is idle or timeout occurs + while True: + if time.time() - start_time > timeout: + logging.warning("Kernel execution timed out. Interrupting kernel.") + self.km.interrupt_kernel() + result["status"] = "error" + result["error"] = f"Execution timed out after {timeout} seconds." + break + + try: + # Get messages from the IOPub channel (where outputs are published) + # Use a small timeout to remain responsive + msg = self.kc.get_iopub_msg(timeout=0.1) + except queue.Empty: + # Check if the kernel is still alive if the queue is empty + if not self.km.is_alive(): + result["status"] = "error" + result["error"] = "Kernel died unexpectedly." + break + continue + except Exception as e: + logging.error(f"Error getting IOPub message: {e}") + continue + + # Process the message if it relates to our execution request + if msg['parent_header'].get('msg_id') == msg_id: + msg_type = msg['msg_type'] + + if msg_type == 'status': + if msg['content']['execution_state'] == 'idle': + break # Execution finished + elif msg_type == 'stream': + if msg['content']['name'] == 'stdout': + result["stdout"] += msg['content']['text'] + elif msg['content']['name'] == 'stderr': + result["stderr"] += msg['content']['text'] + elif msg_type == 'display_data' or msg_type == 'execute_result': + # Capture rich display data (e.g., for interactive plots in the future) + result["display_data"].append(msg['content']['data']) + # Also capture text representation for the current agent output + if 'text/plain' in msg['content']['data']: + result["stdout"] += msg['content']['data']['text/plain'] + "\n" + elif msg_type == 'error': + result["status"] = "error" + ename = msg['content']['ename'] + evalue = msg['content']['evalue'] + # Clean ANSI escape codes from traceback + traceback = "\n".join(msg['content']['traceback']) + traceback = re.sub(r'\x1b\[[0-9;]*m', '', traceback) + result["error"] = f"{ename}: {evalue}\n{traceback}" + break + + return result + + @staticmethod + def sanitize_input(query: str) -> str: + """Sanitize input (from LangChain).""" + # Remove surrounding backticks, whitespace, and the 'python' keyword + query = re.sub(r"^(\s|`)*(?i:python)?\s*", "", query) + query = re.sub(r"(\s|`)*$", "", query) + return query + + def _sanitize_code(self, code: str) -> str: + """Clean specific syntax.""" + # Jupyter kernels handle magic commands natively, so we just need LangChain sanitization + return self.sanitize_input(code) + + def run(self, code: str, timeout: int = 300) -> str: + """Execute code in the Jupyter kernel and return formatted output.""" + cleaned_code = self._sanitize_code(code) + + result = self._execute_code(cleaned_code, timeout) + + if result["status"] == "success": + output = result["stdout"] + if result.get("stderr"): + # Include stderr output (often contains warnings) + if output: + output += "\n[STDERR]\n" + result["stderr"] + else: + output = "[STDERR]\n" + result["stderr"] + return output.strip() + elif result["status"] == "error": + return result["error"] + else: + return f"Fatal error: {result.get('error', 'Unknown error')}" + + def close(self): + """Cleanup and terminate the kernel.""" + if self.kc: + self.kc.stop_channels() + if hasattr(self, 'km') and self.km.is_alive(): + logging.info("Shutting down Jupyter kernel...") + self.km.shutdown_kernel(now=True) + logging.info("Jupyter kernel shut down.") + +# REPLManager updated to manage JupyterKernelExecutor instances +class REPLManager: + """ + Singleton for managing JupyterKernelExecutor instances for each session (thread_id). + """ + _instances: dict[str, JupyterKernelExecutor] = {} + _lock = threading.Lock() + + @classmethod + def get_repl(cls, session_id: str, sandbox_path: str = None) -> JupyterKernelExecutor: + with cls._lock: + if session_id not in cls._instances: + logging.info(f"Creating new Jupyter Kernel for session: {session_id}") + if sandbox_path: + logging.info(f"Starting Kernel in sandbox directory: {sandbox_path}") + try: + cls._instances[session_id] = JupyterKernelExecutor(working_dir=sandbox_path) + except Exception as e: + logging.error(f"Failed to create Jupyter Kernel for session {session_id}: {e}") + # Raise the exception so the calling agent can handle the failure + raise RuntimeError(f"Failed to initialize execution environment: {e}") + return cls._instances[session_id] + + @classmethod + def cleanup_repl(cls, session_id: str): + with cls._lock: + if session_id in cls._instances: + logging.info(f"Closing and removing Kernel for session: {session_id}") + cls._instances[session_id].close() + del cls._instances[session_id] + + @classmethod + def cleanup_all(cls): + """Cleanup all running kernels.""" + logging.info("Cleaning up all kernels on exit...") + sessions = list(cls._instances.keys()) + for session_id in sessions: + cls.cleanup_repl(session_id) + +# Ensure kernels are cleaned up when the application exits +atexit.register(REPLManager.cleanup_all) + +# --- End of new code --- + + +class CustomPythonREPLTool(PythonREPLTool): + _datasets: dict = PrivateAttr() + _results_dir: Optional[str] = PrivateAttr() + _session_key: Optional[str] = PrivateAttr() + + def __init__(self, datasets, results_dir=None, session_key=None, **kwargs): + super().__init__(**kwargs) + self._datasets = datasets + self._results_dir = results_dir + self._session_key = session_key + + def _initialize_session(self, repl: JupyterKernelExecutor) -> None: + """Prepare the Kernel session by importing libraries and auto-loading data.""" + logging.info("Initializing persistent Kernel session with auto-loading...") + + initialization_code = [ + "import os", + "import sys", + "import pandas as pd", + "import numpy as np", + "import matplotlib.pyplot as plt", + "import xarray as xr", + "import logging", + "from io import StringIO", + # Configure matplotlib for non-interactive backend (Crucial for kernels) + "import matplotlib", + "matplotlib.use('Agg')" + ] + + # Add results_dir if available + if self._results_dir: + # Ensure paths are correctly formatted for the kernel environment (use forward slashes) + results_dir_path = os.path.abspath(self._results_dir).replace('\\', '/') + initialization_code.append(f"results_dir = r'{results_dir_path}'") + # Kernel CWD should handle the existence, but we ensure the 'results' subdirectory exists + initialization_code.append("os.makedirs(results_dir, exist_ok=True)") + + # Inject dataset variables by auto-loading from files + for key, value in self._datasets.items(): + if key.startswith('dataset_') and isinstance(value, str) and os.path.isdir(value): + # 1. Create the path variable (e.g., dataset_1_path) + path_var_name = f"{key}_path" + abs_path = os.path.abspath(value).replace('\\', '/') + initialization_code.append(f"{path_var_name} = r'{abs_path}'") + + # 2. Attempt to find and auto-load data.csv + csv_path = os.path.join(value, 'data.csv') + if os.path.exists(csv_path): + initialization_code.append(f"# Auto-loading {key} from data.csv") + initialization_code.append(f"try:") + initialization_code.append(f" {key} = pd.read_csv(os.path.join({path_var_name}, 'data.csv'))") + initialization_code.append(f" print(f'Successfully auto-loaded data.csv into `{key}` variable.')") + initialization_code.append(f"except Exception as e:") + initialization_code.append(f" print(f'Could not auto-load {key}: {{e}}')") + + elif key == "uuid_main_dir" and isinstance(value, str): + # Add the main UUID directory + abs_path = os.path.abspath(value).replace('\\', '/') + initialization_code.append(f"uuid_main_dir = r'{abs_path}'") + + # Execute all initialization code in one go + full_init_code = "\n".join(initialization_code) + result = repl.run(full_init_code) + logging.info(f"Kernel session initialization result: {result.strip()}") + repl.is_initialized = True + + def _run(self, query: str, **kwargs) -> Any: """ - Execute Python code in the persistent environment. - - Args: - code: Python code to execute - - Returns: - String containing output or error message + Execute code using the persistent Jupyter Kernel tied to the session. """ - # Capture stdout - old_stdout = sys.stdout - stdout_capture = StringIO() - sys.stdout = stdout_capture + # Resolve session_id for Streamlit or CLI execution. + session_id = "" + if hasattr(self, "_session_key") and self._session_key: + session_id = self._session_key + else: + try: + session_id = st.session_state.thread_id + except Exception: + session_id = os.environ.get("CLIMSIGHT_THREAD_ID", "") + + if not session_id: + logging.error("CRITICAL ERROR: Session ID (thread_id) is missing.") + return { + "result": "ERROR: Session ID is missing. Cannot continue execution.", + "output": "ERROR: Session ID is missing. Cannot continue execution.", + "plot_images": [] + } + + # Get sandbox path from datasets (this is the intended working directory) + sandbox_path = self._datasets.get("uuid_main_dir") try: - # Try to execute as an expression first (to show return values) - try: - result = eval(code, self.locals, self.locals) - if result is not None: - print(repr(result)) - except SyntaxError: - # If it's not an expression, execute as statement(s) - exec(code, self.locals, self.locals) - - # Get the output - output = stdout_capture.getvalue() + # Retrieve or create the kernel executor + repl = REPLManager.get_repl(session_id, sandbox_path=sandbox_path) + except RuntimeError as e: + # Handle kernel initialization failure + return { + "result": f"ERROR: Failed to initialize the execution environment (Jupyter Kernel). Details: {e}", + "output": f"ERROR: Failed to initialize the execution environment (Jupyter Kernel). Details: {e}", + "plot_images": [] + } + + # "Warm up" the REPL on first call in the session + if not repl.is_initialized: + self._initialize_session(repl) + + logging.info(f"Executing code in persistent Kernel for session {session_id}:\n{query}") + + # Execute user code + output = repl.run(query) + + # Logic for detecting generated plots + generated_plots = [] + image_extensions = ('.png', '.jpg', '.jpeg', '.svg', '.pdf') + + # --- Plot Detection Logic --- + # The kernel executes with the sandbox_path as its CWD. + # We rely on the agent saving plots to relative paths (e.g., 'results/plot.png') + + # 1. Check for plot files mentioned in the output (stdout/stderr) + # Use regex to find potential paths (e.g., "Plot saved to: results/my_plot.png") + matches = re.findall(r'([a-zA-Z0-9_\-/\\.]+\.(png|jpg|jpeg|svg|pdf))', output, re.IGNORECASE) + for match in matches: + potential_path = match[0].strip() - if output: - return f"Output:\n{output}" + # Resolve the path relative to the sandbox directory (where the kernel is running) + if sandbox_path: + # Normalize slashes + potential_path = potential_path.replace('\\', '/') + full_potential_path = os.path.join(sandbox_path, potential_path) else: - return "Code executed successfully (no output)." - - except Exception as e: - return f"Error: {type(e).__name__}: {str(e)}" - - finally: - # Restore stdout - sys.stdout = old_stdout - + # If no sandbox path (e.g. CLI fallback), check relative to current working dir + full_potential_path = os.path.abspath(potential_path) -# Create a global instance for persistence across tool calls -_repl_instance = PersistentPythonREPL() + if os.path.exists(full_potential_path): + if full_potential_path not in generated_plots: + generated_plots.append(full_potential_path) + logging.info(f"Found plot path in output: {full_potential_path}") + # 2. Check results directory for recently created files (fallback mechanism) + if self._results_dir and os.path.exists(self._results_dir): + for file in os.listdir(self._results_dir): + if file.lower().endswith(image_extensions): + full_path = os.path.join(self._results_dir, file) + if full_path not in generated_plots and os.path.exists(full_path): + # Check if file was created recently (within the last 10 seconds) + # This prevents picking up old plots from the same session + if time.time() - os.path.getmtime(full_path) < 10: + generated_plots.append(full_path) + logging.info(f"Found recent plot in results directory: {full_path}") -def create_python_repl_tool(): + if generated_plots: + # Update UI flag if using Streamlit (safe check) + if 'streamlit' in sys.modules and hasattr(st, 'session_state'): + st.session_state.new_plot_generated = True + + log_history_event( + st.session_state, "plot_generated", + {"plot_paths": generated_plots, "agent": "PythonREPL", "content": query} + ) - """ - Create a Python REPL tool for LangChain agents with persistent state. - - Returns: - StructuredTool configured for Python code execution - """ - try: - from langchain.tools import StructuredTool - except ImportError: - from langchain_core.tools import StructuredTool - from pydantic import BaseModel, Field - - class PythonREPLInput(BaseModel): - code: str = Field( - description="Python code to execute. Has access to pandas (pd), numpy (np), matplotlib.pyplot (plt), and xarray (xr). Variables persist between executions." - ) - # Get available variables info - available_vars = [] - if hasattr(_repl_instance, 'locals'): - for key, value in _repl_instance.locals.items(): - if not key.startswith('__'): - available_vars.append(f"{key} ({type(value).__name__})") - - vars_description = "" - if available_vars: - vars_description = f"\nPre-loaded variables: {', '.join(available_vars[:10])}..." - - tool = StructuredTool.from_function( - func=_repl_instance.execute, - name="python_repl", - description=( - "Execute Python code for data analysis and calculations. " - "Available: pandas as pd, numpy as np, matplotlib.pyplot as plt, xarray as xr. " - "Variables and imports persist between executions like a Jupyter notebook." - + vars_description - ), - args_schema=PythonREPLInput - ) - return tool \ No newline at end of file + return { + "result": output, + "output": output, # For backward compatibility + "plot_images": generated_plots + } diff --git a/src/climsight/tools/reflection_tools.py b/src/climsight/tools/reflection_tools.py new file mode 100644 index 0000000..e6925de --- /dev/null +++ b/src/climsight/tools/reflection_tools.py @@ -0,0 +1,135 @@ +# src/tools/reflection_tools.py +import base64 +import os +import logging +try: + import streamlit as st # Import streamlit to access session state +except ImportError: + st = None +from pydantic import BaseModel, Field +from langchain_core.tools import StructuredTool +from openai import OpenAI +try: + from ..config import API_KEY as _API_KEY +except ImportError: + from config import API_KEY as _API_KEY + +# Define the function to encode the image +def encode_image(image_path): + with open(image_path, "rb") as image_file: + return base64.b64encode(image_file.read()).decode('utf-8') + +def reflect_on_image(image_path: str) -> str: + """ + Analyzes an image and provides feedback. Automatically resolves sandbox paths. + """ + # --- NEW SANDBOX PATH RESOLUTION LOGIC --- + final_image_path = image_path + + # If the path is relative, resolve it against the current session's sandbox + if not os.path.isabs(image_path): + thread_id = None + if st is not None and hasattr(st, "session_state"): + thread_id = st.session_state.get("thread_id") + if not thread_id: + thread_id = os.environ.get("CLIMSIGHT_THREAD_ID") + + if thread_id: + sandbox_dir = os.path.join("tmp", "sandbox", thread_id) + potential_path = os.path.join(sandbox_dir, image_path) + if os.path.exists(potential_path): + final_image_path = potential_path + logging.info(f"Resolved relative path '{image_path}' to sandbox path '{final_image_path}'") + else: + logging.warning(f"Could not resolve relative path '{image_path}' in sandbox '{sandbox_dir}'") + else: + logging.warning("Could not resolve relative path: No thread_id available.") + # --- END OF NEW LOGIC --- + + if not os.path.exists(final_image_path): + # Provide a more informative error message + return (f"Error: The file '{final_image_path}' (resolved from '{image_path}') does not exist. " + f"This often happens if the file was not saved correctly or if there is a path mismatch between the agent's environment and the tool's environment.") + + base64_image = encode_image(final_image_path) + + prompt = """You are a STRICT professional reviewer of scientific images. Your task is to provide critical feedback to the visual creator agent so they can improve their visualization. Be particularly harsh on basic readability issues. Evaluate the provided image using the following criteria: + +**CRITICAL FAILURES (Each results in automatic score reduction of at least 5 points):** +- Any overlapping text or labels +- Illegible or cut-off axis labels +- Missing axis titles or units +- Text that is too small to read clearly +- Labels that obscure data points + +1. **Axis and Font Quality** (CRITICAL): Evaluate the visibility of axes and appropriateness of font size and style. ANY of the following issues should result in a score of 3/10 or lower: + - Axis labels that are cut off, truncated, or partially visible + - Font size that is too small to read comfortably + - Missing axis titles or units + - Poorly formatted tick labels (overlapping, rotated at bad angles, etc.) + +2. **Label Clarity** (CRITICAL): This is ABSOLUTELY ESSENTIAL. If ANY text overlaps with other text, data points, or visual elements, the maximum possible score is 2/10. Check for: + - Text overlapping with other text + - Labels overlapping with data points or lines + - Legend text that overlaps or is cut off + - Annotations that clash with other elements + +3. Color Scheme: Analyze the color choices. Is the color scheme appropriate for the data presented? Are the colors distinguishable and not causing visual confusion? + +4. Data Representation: Evaluate how well the data is represented. Are data points clearly visible? Is the chosen chart or graph type appropriate for the data? + +5. **Legend and Scale** (Important): Check the presence and clarity of legends and scales. If the legend overlaps with the plot area or has overlapping text, reduce score by at least 4 points. + +6. Overall Layout: Assess the overall layout and use of space. Poor spacing that causes any text overlap should be heavily penalized. + +7. Technical Issues: Identify any technical problems such as pixelation, blurriness, or artifacts that might affect the image quality. + +8. Scientific Accuracy: To the best of your ability, comment on whether the image appears scientifically accurate and free from obvious errors or misrepresentations. + +9. **Convention Adherence**: Verify that the figure follows scientific conventions. For example, when depicting variables like 'Depth of water' or other vertical dimensions, these should appear on the Y-axis with minimum values at the top and maximum depth at the bottom. This is a critically important scientific convention - if depth/vertical dimensions are incorrectly presented on the horizontal X-axis, assign a score of 1/10. + +**SCORING GUIDELINES:** +- 7-10: Professional quality with no text/label issues +- 5-6: Minor issues but all text is readable +- 3-4: Significant problems including some text overlap or readability issues +- 1-2: Major failures with overlapping text, illegible labels, or missing critical elements + +BE CRITICAL of any text overlap or readability issues. A visualization with overlapping text is fundamentally flawed and should receive a very low score regardless of other qualities. + +Please provide a structured review addressing each of these points. Conclude with an overall assessment of the image quality, highlighting any significant issues or exemplary aspects. Finally, give the image a score out of 10.""" + if not _API_KEY: + return "Error: OPENAI_API_KEY not configured for reflect_on_image." + + openai_client = OpenAI(api_key=_API_KEY) + response = openai_client.chat.completions.create( + model="gpt-4o", + messages=[ + { + "role": "user", + "content": [ + {"type": "text", "text": prompt}, + { + "type": "image_url", + "image_url": { + "url": f"data:image/png;base64,{base64_image}" + } + } + ] + } + ], + max_tokens=1000 + ) + + return response.choices[0].message.content + +# Define the args schema for reflect_on_image +class ReflectOnImageArgs(BaseModel): + image_path: str = Field(description="The path to the image to reflect on.") + +# Define the reflect_on_image tool +reflect_tool = StructuredTool.from_function( + func=reflect_on_image, + name="reflect_on_image", + description="A tool to reflect on an image and provide feedback for improvements.", + args_schema=ReflectOnImageArgs +) diff --git a/src/climsight/tools/visualization_tools.py b/src/climsight/tools/visualization_tools.py new file mode 100644 index 0000000..e4bbc02 --- /dev/null +++ b/src/climsight/tools/visualization_tools.py @@ -0,0 +1,261 @@ +# src/tools/visualization_tools.py +import os +import logging +from pydantic import BaseModel, Field +from langchain_core.tools import StructuredTool +from langchain_openai import OpenAIEmbeddings +from langchain_community.vectorstores import Chroma +import streamlit as st + +try: + from ..config import API_KEY as _API_KEY +except ImportError: + from config import API_KEY as _API_KEY + +class ExampleVisualizationArgs(BaseModel): + query: str = Field(description="The user's query about plotting.") + +def get_example_of_visualizations(query: str) -> str: + """ + Retrieves example visualizations related to the query. + + Parameters: + - query (str): The user's query about plotting. + + Returns: + - str: The content of the most relevant example file. + """ + # Initialize embeddings from session state or config + embeddings = OpenAIEmbeddings(api_key=_API_KEY) + + # Load the existing vector store + vector_store = Chroma( + collection_name="example_collection", + embedding_function=embeddings, + persist_directory=os.path.join('data', 'examples_database', 'chroma_langchain_notebooks') + ) + + # Perform a similarity search + results = vector_store.similarity_search_with_score(query, k=1) + + # Extract the most relevant document + doc, score = results[0] + + # Construct the full path to the txt file + file_name = doc.metadata['source'].lstrip('./') + full_path = os.path.join('data', 'examples_database', file_name) + + # Read and return the content of the txt file + try: + with open(full_path, 'r', encoding='utf-8') as file: + content = file.read() + return content + except Exception as e: + logging.error(f"An error occurred while reading the file: {str(e)}") + return "" # Return empty string if error occurs + +# Create the example visualization tool +example_visualization_tool = StructuredTool.from_function( + func=get_example_of_visualizations, + name="get_example_of_visualizations", + description="Retrieves example visualization code related to the user's query.", + args_schema=ExampleVisualizationArgs +) + +# File listing tool definition +class ListPlottingDataFilesArgs(BaseModel): + dummy_arg: str = Field(default="", description="(No arguments needed)") + +def list_plotting_data_files(dummy_arg: str = "") -> str: + """ + Lists ALL files recursively from two sources: + 1. The data/plotting_data directory (static resources) + 2. All files in the current UUID sandbox directories (active datasets) + + Returns a flat list of all available file paths using relative paths. + """ + import os + import streamlit as st + + all_files = [] + cwd = os.getcwd() + + # Part 1: List files from data/plotting_data + plotting_data_dir = os.path.join("data", "plotting_data") + if os.path.exists(plotting_data_dir): + for root, dirs, files in os.walk(plotting_data_dir): + for filename in files: + full_path = os.path.join(root, filename) + # Keep this as a relative path + all_files.append(f"STATIC: {full_path}") + + # Part 2: List all files from the current sandbox directory + thread_id = st.session_state.get("thread_id") if hasattr(st, "session_state") else None + if not thread_id: + thread_id = os.environ.get("CLIMSIGHT_THREAD_ID") + + if thread_id: + sandbox_dir = os.path.join("tmp", "sandbox", thread_id) + if os.path.exists(sandbox_dir): + for root, dirs, files in os.walk(sandbox_dir): + for filename in files: + full_path = os.path.join(root, filename) + if full_path.startswith(cwd): + rel_path = full_path[len(cwd) + 1:] + else: + rel_path = full_path + + rel_path = rel_path.replace('\\', '/') + + if "era5_data" in rel_path: + all_files.append(f"ERA5: {rel_path}") + else: + all_files.append(f"DATA: {rel_path}") + + # Return a simple list of all available files + if all_files: + return "Available files:\n" + "\n".join(all_files) + else: + return "No files found in plotting_data or active datasets." + +# Create the list plotting data files tool +list_plotting_data_files_tool = StructuredTool.from_function( + func=list_plotting_data_files, + name="list_plotting_data_files", + description="Lists ALL available files recursively, including plotting resources, dataset files, and ERA5 data. Use this to see exactly what files you can work with.", + args_schema=ListPlottingDataFilesArgs +) + +class WiseAgentToolArgs(BaseModel): + query: str = Field(description="The query about visualization to send to Claude for advice. Include details about your dataset structure, variables, and visualization goals.") + +def wise_agent(query: str) -> str: + """ + A tool that provides visualization advice using either OpenAI or Anthropic models. + + Args: + query: The query about visualization to send to the AI model + + Returns: + str: AI's advice on visualization + """ + import streamlit as st + import logging + import yaml + import os + + # Load configuration (Climsight uses config.yml by default) + config_path = os.path.join(os.getcwd(), "config.yml") + if os.path.exists(config_path): + with open(config_path, "r") as f: + app_config = yaml.safe_load(f) + else: + app_config = {} + + # Get wise agent configuration + wise_agent_config = app_config.get("wise_agent", {}) + provider = wise_agent_config.get("provider", "openai") # Default to OpenAI + + # Get dataset information from session state + datasets_text = st.session_state.get("viz_datasets_text", "") + + if not datasets_text: + datasets_text = "No dataset information available" + + # Get the list of available plotting data files + try: + available_files = list_plotting_data_files("") + logging.info("Successfully retrieved available plotting data files") + except Exception as e: + logging.error(f"Error retrieving available files: {str(e)}") + available_files = f"Error retrieving available files: {str(e)}" + + # Create the system prompt + system_prompt = """You are WISE_AGENT, a scientific visualization expert specializing in data visualization for research datasets. + +Your role is to provide specific, actionable advice on how to create the most effective visualizations for scientific data. + +When giving visualization advice: +0. Provide superb visualizations! That's your life goal! +1. ANALYZE THE DATA STRUCTURE first - recommend plot types based on the actual data dimensions and variables +2. Consider the SCIENTIFIC DOMAIN (oceanography, climate science, biodiversity) and its standard visualization practices +3. Recommend specific matplotlib/seaborn/plotly code strategies tailored to the data +4. Suggest appropriate color schemes that follow scientific conventions (e.g., sequential for continuous variables, categorical for discrete) +5. Provide precise advice on layouts, axes, legends, and annotations +6. For spatial/geographic data, recommend appropriate projections and map types +7. For time series, recommend appropriate temporal visualizations +8. Always prioritize clarity, accuracy, and scientific information density + +Your advice should be specifically tailored to the datasets the user is working with. Be concise but thorough in your recommendations. +""" + + # Enhance the query with dataset information and available files + enhanced_query = f""" +DATASET INFORMATION: +{datasets_text} + +AVAILABLE PLOTTING DATA FILES: +{available_files} + +USER QUERY: +{query} + +Please provide visualization advice based on this information. +""" + + try: + if provider.lower() == "anthropic": + # Use Anthropic's Claude + try: + anthropic_api_key = st.secrets["general"]["anthropic_api_key"] + logging.info("Using Anthropic Claude for wise_agent") + except KeyError: + logging.error("Anthropic API key not found in .streamlit/secrets.toml") + return "Error: Anthropic API key not found in .streamlit/secrets.toml. Please add it to use WISE_AGENT with Claude." + + anthropic_model = wise_agent_config.get("anthropic_model", "claude-3-7-sonnet-20250219") + + from langchain_anthropic import ChatAnthropic + llm = ChatAnthropic( + model=anthropic_model, + anthropic_api_key=anthropic_api_key, + #temperature=0.2, + ) + + logging.info(f"Making request to Claude model: {anthropic_model}") + + else: # Default to OpenAI + from langchain_openai import ChatOpenAI + + logging.info("Using OpenAI for wise_agent") + + openai_model = wise_agent_config.get("openai_model", "gpt-5") + llm = ChatOpenAI( + api_key=_API_KEY, + model_name=openai_model, + ) + + logging.info(f"Making request to OpenAI model: {openai_model}") + + # Generate the response + response = llm.invoke( + [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": enhanced_query} + ] + ) + + logging.info("Successfully received response from AI model") + return response.content + + except Exception as e: + logging.error(f"Error using WISE_AGENT: {str(e)}") + return f"Error using WISE_AGENT: {str(e)}" + +# Create the wise agent tool +wise_agent_tool = StructuredTool.from_function( + func=wise_agent, + name="wise_agent", + description="A tool that provides expert visualization advice using advanced AI models. Use this tool FIRST when planning complex visualizations or when you need guidance on best visualization practices for scientific data. Provide a detailed description of the data structure and visualization goals.", + args_schema=WiseAgentToolArgs +) diff --git a/src/climsight/utils.py b/src/climsight/utils.py new file mode 100644 index 0000000..7e921c1 --- /dev/null +++ b/src/climsight/utils.py @@ -0,0 +1,117 @@ +"""Utility helpers reused by tool modules.""" + +import json +import logging +import os +import re +import time +import uuid +from datetime import date, datetime +from typing import Any + +import pandas as pd + + +def generate_unique_image_path(sandbox_path: str = None) -> str: + """Generate a unique image path for saving plots.""" + unique_filename = f"fig_{uuid.uuid4()}.png" + if sandbox_path and os.path.exists(sandbox_path): + results_dir = os.path.join(sandbox_path, "results") + os.makedirs(results_dir, exist_ok=True) + return os.path.join(results_dir, unique_filename) + + figs_dir = os.path.join("tmp", "figs") + os.makedirs(figs_dir, exist_ok=True) + return os.path.join(figs_dir, unique_filename) + + +def sanitize_input(query: str) -> str: + return query.strip() + + +def make_json_serializable(obj: Any) -> Any: + """Convert objects into JSON-serializable structures.""" + if obj is None: + return None + if isinstance(obj, (str, int, float, bool)): + return obj + if isinstance(obj, (datetime, date)): + return obj.isoformat() + if isinstance(obj, pd.Series): + return obj.to_dict() + if isinstance(obj, pd.DataFrame): + return obj.to_dict(orient="records") + if hasattr(obj, "tolist"): + return obj.tolist() + if hasattr(obj, "item"): + return obj.item() + if isinstance(obj, dict): + return {k: make_json_serializable(v) for k, v in obj.items()} + if isinstance(obj, (list, tuple)): + return [make_json_serializable(item) for item in obj] + if isinstance(obj, set): + return list(obj) + if hasattr(obj, "__dict__"): + return make_json_serializable(obj.__dict__) + return str(obj) + + +def log_history_event(session_data: dict, event_type: str, details: dict) -> None: + """Append a structured event to session history.""" + if "execution_history" not in session_data: + session_data["execution_history"] = [] + + timestamp = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()) + event = { + "type": event_type, + "timestamp": timestamp, + } + + try: + serializable_details = make_json_serializable(details) + event.update(serializable_details) + except Exception as exc: + logging.error("Failed to serialize event details: %s", exc) + event.update({ + "serialization_error": str(exc), + "original_keys": list(details.keys()) if isinstance(details, dict) else "not_dict", + }) + + session_data["execution_history"].append(event) + + +def list_directory_contents(path: str) -> str: + """Return a formatted tree of directory contents.""" + result = [] + for root, _, files in os.walk(path): + level = root.replace(path, "").count(os.sep) + indent = " " * 4 * level + result.append(f"{indent}{os.path.basename(root)}/") + sub_indent = " " * 4 * (level + 1) + for file in files: + result.append(f"{sub_indent}{file}") + return "\n".join(result) + + +def escape_curly_braces(text: str) -> str: + if isinstance(text, str): + return text.replace("{", "{{").replace("}", "}}") + return str(text) + + +def get_last_python_repl_command(session_state: dict) -> str: + """Extract last Python_REPL tool call from a LangChain intermediate steps list.""" + intermediate_steps = session_state.get("intermediate_steps") + if not intermediate_steps: + return "" + + python_repl_commands = [] + for action, _ in intermediate_steps: + if action.get("tool") == "Python_REPL": + python_repl_commands.append(action) + + if python_repl_commands: + last_command_action = python_repl_commands[-1] + return last_command_action.get("tool_input", "") + + return "" From a9ef91a994067aacd948e4e8b3645957e1fef5c5 Mon Sep 17 00:00:00 2001 From: Ivan Kuznetsov Date: Wed, 21 Jan 2026 00:14:01 +0100 Subject: [PATCH 04/16] Add ERA5 climatology baseline tool and wire into analysis flow --- config.yml | 7 +- src/climsight/climsight_classes.py | 1 + src/climsight/climsight_engine.py | 21 ++ src/climsight/data_analysis_agent.py | 293 ++++++++++------- src/climsight/tools/__init__.py | 7 +- src/climsight/tools/era5_climatology_tool.py | 325 +++++++++++++++++++ 6 files changed, 533 insertions(+), 121 deletions(-) create mode 100644 src/climsight/tools/era5_climatology_tool.py diff --git a/config.yml b/config.yml index 92afe69..276c2dd 100644 --- a/config.yml +++ b/config.yml @@ -16,9 +16,14 @@ llm_dataanalysis: #used only in data_analysis_agent climatemodel_name: "AWI_CM" llmModeKey: "agent_llm" #"agent_llm" #"direct_llm" use_smart_agent: false -use_era5_data: false +use_era5_data: false # Download ERA5 time series from CDS API (requires credentials) use_powerful_data_analysis: false +# ERA5 Climatology Configuration (pre-computed observational baseline) +era5_climatology: + enabled: true # Always use ERA5 climatology as ground truth baseline + path: "data/era5/era5_climatology_2015_2025.zarr" # Path to pre-computed climatology + # Climate Data Source Configuration # Options: "nextGEMS", "ICCP", "AWI_CM" climate_data_source: "nextGEMS" diff --git a/src/climsight/climsight_classes.py b/src/climsight/climsight_classes.py index 618c072..d43c025 100644 --- a/src/climsight/climsight_classes.py +++ b/src/climsight/climsight_classes.py @@ -33,4 +33,5 @@ class AgentState(BaseModel): results_dir: str = "" # Plot output directory climate_data_dir: str = "" # Saved climatology directory era5_data_dir: str = "" # ERA5 output directory + era5_climatology_response: dict = {} # ERA5 observed climatology (ground truth) # stream_handler: StreamHandler # Uncomment if needed diff --git a/src/climsight/climsight_engine.py b/src/climsight/climsight_engine.py index 53ff43a..acd6dee 100644 --- a/src/climsight/climsight_engine.py +++ b/src/climsight/climsight_engine.py @@ -1079,8 +1079,29 @@ def combine_agent(state: AgentState): if state.data_analysis_response: state.input_params['data_analysis_response'] = state.data_analysis_response state.content_message += "\n Data analysis agent response: {data_analysis_response} " + + # Add ERA5 climatology (observational ground truth) to prompt + if state.era5_climatology_response and isinstance(state.era5_climatology_response, dict): + era5_data = state.era5_climatology_response + # Format ERA5 data as structured text for the LLM + era5_text = "ERA5 OBSERVATIONAL CLIMATOLOGY (2015-2025 average - GROUND TRUTH):\n" + era5_text += f"Source: {era5_data.get('source', 'ERA5 Reanalysis')}\n" + era5_text += f"Location: {era5_data.get('extracted_location', {})}\n" + if 'variables' in era5_data: + for var_name, var_info in era5_data['variables'].items(): + era5_text += f"\n{var_info.get('full_name', var_name)} ({var_info.get('units', '')}):\n" + monthly = var_info.get('monthly_values', {}) + for month, value in monthly.items(): + era5_text += f" {month}: {value}\n" + state.input_params['era5_climatology'] = era5_text + state.content_message += "\n ERA5 Observations (ground truth baseline): {era5_climatology} " + + # Add generated plot images to prompt so LLM knows about them if state.data_analysis_images: state.input_params['data_analysis_images'] = state.data_analysis_images + images_list = ", ".join(state.data_analysis_images) + state.input_params['data_analysis_images_text'] = images_list + state.content_message += "\n Generated visualizations (plot files): {data_analysis_images_text} " if state.smart_agent_response != {}: smart_analysis = state.smart_agent_response.get('output', '') diff --git a/src/climsight/data_analysis_agent.py b/src/climsight/data_analysis_agent.py index 7731551..5d2796d 100644 --- a/src/climsight/data_analysis_agent.py +++ b/src/climsight/data_analysis_agent.py @@ -20,6 +20,7 @@ from sandbox_utils import ensure_thread_id, ensure_sandbox_dirs, get_sandbox_paths from agent_helpers import create_standard_agent_executor from tools.get_data_components import create_get_data_components_tool +from tools.era5_climatology_tool import create_era5_climatology_tool from tools.era5_retrieval_tool import era5_retrieval_tool from tools.python_repl import CustomPythonREPLTool from tools.reflection_tools import reflect_tool @@ -99,9 +100,10 @@ def _build_filter_prompt() -> str: ) -def _create_tool_prompt(datasets_text: str, config: dict) -> str: +def _create_tool_prompt(datasets_text: str, config: dict, lat: float = None, lon: float = None) -> str: """System prompt for tool-driven analysis - dynamically built based on config.""" - has_era5 = config.get("use_era5_data", False) + has_era5_climatology = config.get("era5_climatology", {}).get("enabled", True) + has_era5_download = config.get("use_era5_data", False) has_repl = config.get("use_powerful_data_analysis", False) prompt = """You are the data analysis agent for ClimSight. @@ -112,146 +114,162 @@ def _create_tool_prompt(datasets_text: str, config: dict) -> str: tool_num = 1 - # Only include get_data_components if Python_REPL is NOT available + # TOOL #1: ERA5 Climatology - ALWAYS FIRST (observational ground truth) + if has_era5_climatology: + coord_example = "" + if lat is not None and lon is not None: + coord_example = f" - For this query: get_era5_climatology(latitude={lat}, longitude={lon}, variables=[\"t2m\", \"tp\", \"u10\", \"v10\"])\n" + + prompt += f""" +{tool_num}. **get_era5_climatology** - Extract OBSERVED climate data (CALL THIS FIRST!) + - Source: ERA5 reanalysis 2015-2025 monthly climatology + - This is GROUND TRUTH - actual observations, not model output + - Variables: temperature (t2m), precipitation (tp), wind_u (u10), wind_v (v10), dewpoint (d2m), pressure (msl) + - Returns monthly averages for the nearest grid point (~28km resolution) +{coord_example} + **CRITICAL**: This tool provides what the climate ACTUALLY IS at this location. + Use this as the BASELINE to compare against climate model projections. +""" + tool_num += 1 + + # TOOL #2: get_data_components - only if Python_REPL is NOT available if not has_repl: prompt += f""" -{tool_num}. **get_data_components** - Extract climate variables from stored climatology +{tool_num}. **get_data_components** - Extract climate variables from climate MODEL projections - Variables: Temperature, Precipitation, u_wind, v_wind - Returns monthly values for historical AND future climate projections - Example: get_data_components(environmental_data="Temperature", months=["Jan", "Feb", "Mar"]) - - ALWAYS use this first to get the baseline climatology data + - These are MODEL outputs - compare with ERA5 observations to assess model quality """ tool_num += 1 - if has_era5: + # TOOL #3: ERA5 time series download (optional, for detailed analysis) + if has_era5_download: prompt += f""" -{tool_num}. **retrieve_era5_data** - Download ERA5 reanalysis data (OBSERVED REALITY) - - Variables: 2m_temperature, total_precipitation, 10m_u_component_of_wind, 10m_v_component_of_wind +{tool_num}. **retrieve_era5_data** - Download detailed ERA5 time series (OPTIONAL) + - Use ONLY if you need year-by-year data beyond the climatology + - For most analyses, get_era5_climatology is sufficient - Specify start_date, end_date (YYYY-MM-DD), and lat/lon bounds - - Data saved to sandbox as Zarr files for Python analysis - - **CRITICAL**: ERA5 is OBSERVED climate data (ground truth), NOT model projections! - - Use ERA5 to validate climate models by comparing with model historical period - - ERA5 shows the ACTUAL climate at this location (recent 10-20 years) - - Climate models may have bias - ERA5 reveals this bias - - **Standard workflow**: Download ~10 years of recent ERA5 data (e.g., 2014-2024) - - Example: retrieve_era5_data(variable_id="2m_temperature", start_date="2014-01-01", end_date="2023-12-31", min_latitude=48.0, max_latitude=49.0, min_longitude=11.0, max_longitude=12.0) - - Then use Python_REPL to compare ERA5 vs climate model historical period """ tool_num += 1 + # TOOL #4: Python REPL if has_repl: - era5_note = "" - if has_era5: - era5_note = """ - **ERA5 DATA PROCESSING** (when available): - - ERA5 data is saved as Zarr format in era5_data/ directory - - Use xarray to load: `ds = xr.open_zarr('era5_data/.zarr')` - - ERA5 = OBSERVED reality, climate_data/data.csv = MODEL projections - - ALWAYS compare ERA5 with model historical period to show model bias -""" - prompt += f""" {tool_num}. **Python_REPL** - Execute Python code for data analysis and visualizations - - Pre-loaded libraries: pandas (pd), numpy (np), matplotlib.pyplot (plt), xarray (xr) - - Climate data is AUTO-LOADED in the namespace (see data files below) + - Pre-loaded: pandas (pd), numpy (np), matplotlib.pyplot (plt), xarray (xr) - Working directory is the sandbox root - ALWAYS save plots to results/ directory **Climate Model Data** (climate_data/data.csv): - - Monthly climatology from climate models - - Columns: Month, mean2t (temperature), tp (precipitation), wind_u, wind_v - - Contains BOTH historical period (1995-2014) AND future projections (2020-2049) - - These are MODEL outputs, not observations! -{era5_note} - Example workflow (climate models only): + - Monthly climatology from climate models (MODEL projections, not observations) + - Columns: Month, mean2t (temperature °C), tp (precipitation), wind_u, wind_v + - Contains historical period AND future projections + + **ERA5 Climatology** (era5_climatology.json): + - After calling get_era5_climatology, results are saved here + - Load with: `import json; era5 = json.load(open('era5_climatology.json'))` + - Use ERA5 as GROUND TRUTH baseline for comparisons + + Example workflow: ```python import pandas as pd + import json import matplotlib.pyplot as plt - # Read climate model data + # Load ERA5 observations (ground truth) + era5 = json.load(open('era5_climatology.json')) + era5_temp = era5['variables']['t2m']['monthly_values'] + + # Load climate model projections df = pd.read_csv('climate_data/data.csv') - print(df.head()) - # Create temperature plot + # Compare: ERA5 observations vs climate model + months = list(era5_temp.keys()) + era5_values = list(era5_temp.values()) + model_values = df['mean2t'].tolist()[:12] # historical period + plt.figure(figsize=(10, 6)) - plt.plot(df['Month'], df['mean2t'], marker='o', label='Temperature') + plt.plot(months, era5_values, 'b-o', label='ERA5 Observations (2015-2025)') + plt.plot(months, model_values, 'r--s', label='Climate Model') plt.xlabel('Month') plt.ylabel('Temperature (°C)') - plt.title('Monthly Temperature') + plt.title('Observed vs Modeled Temperature') plt.legend() - plt.savefig('results/temperature_plot.png', dpi=150, bbox_inches='tight') + plt.xticks(rotation=45) + plt.tight_layout() + plt.savefig('results/era5_vs_model.png', dpi=150) plt.close() - print('Plot saved to results/temperature_plot.png') + print('Plot saved to results/era5_vs_model.png') ``` """ tool_num += 1 prompt += f""" {tool_num}. **list_plotting_data_files** - List files in sandbox directories -{tool_num + 1}. **reflect** - Think through analysis approach before executing -{tool_num + 2}. **wise_agent** - Get guidance on complex analysis decisions +""" + tool_num += 1 + # reflect_on_image and wise_agent only available with Python REPL + if has_repl: + prompt += f"""{tool_num}. **reflect_on_image** - Analyze a generated plot and get feedback for improvements +{tool_num + 1}. **wise_agent** - Get guidance on complex visualization decisions +""" + tool_num += 2 + + prompt += """ ## REQUIRED WORKFLOW """ - if has_era5 and has_repl: - # ERA5 + Python_REPL: Most comprehensive workflow + if has_era5_climatology: prompt += """ -**CRITICAL**: ERA5 is enabled - you MUST download and use it as ground truth! - -1. **DOWNLOAD ERA5 OBSERVATIONS** - This is MANDATORY, not optional: - - Call retrieve_era5_data for the last 10 years (e.g., 2014-01-01 to 2023-12-31) - - Download at least temperature and precipitation - - Use exact coordinates from the user query - - ERA5 = observed reality at this location - -2. **LOAD CLIMATE MODEL DATA** - Use Python_REPL: - - Read climate_data/data.csv (contains model historical + future projections) - - These are MODEL outputs, not observations - -3. **COMPARE ERA5 vs MODEL** - This reveals model bias: - - Load ERA5 data with xarray: `ds = xr.open_zarr('era5_data/.zarr')` - - Calculate ERA5 monthly climatology (mean by month across years) - - Compare with model historical period (1995-2014) - - Show the difference (bias) between observations and model - -4. **ANALYZE FUTURE PROJECTIONS** - With ERA5 context: - - Show model future projections - - Explain how model bias affects confidence in projections - - Use ERA5 as the baseline ("current climate") - -5. **CREATE VISUALIZATIONS** - MANDATORY: - - Plot 1: ERA5 observations vs model historical vs model future (3-way comparison) - - Plot 2: Model bias (ERA5 minus model historical) - - Plot 3: Future change (model future minus ERA5 baseline) - - Save ALL plots to results/ directory +**STEP 1 - GET OBSERVATIONS (MANDATORY):** +Call get_era5_climatology FIRST to get the observed climate baseline. +- Extract at minimum: temperature (t2m), precipitation (tp) +- This is what the climate ACTUALLY IS at this location (2015-2025 average) """ - elif has_repl and not has_era5: - # Python_REPL without ERA5: Standard workflow + + if has_repl: + prompt += """ +**STEP 2 - LOAD MODEL DATA:** +Use Python_REPL to read climate_data/data.csv (model projections) + +**STEP 3 - COMPARE OBSERVATIONS vs MODEL:** +- ERA5 = ground truth (what we observe NOW) +- Model historical = what the model simulates for recent past +- Difference = MODEL BIAS (critical for interpreting future projections) + +**STEP 4 - ANALYZE FUTURE WITH CONTEXT:** +- Show future projections from climate model +- Explain how model bias affects confidence +- Future change = (Model future) - (ERA5 baseline) + +**STEP 5 - CREATE VISUALIZATIONS (MANDATORY):** +- Plot 1: ERA5 observations vs model (shows model bias) +- Plot 2: Future projections with ERA5 baseline +- Save ALL plots to results/ directory +""" + else: + prompt += """ +**STEP 2 - GET MODEL DATA:** +Call get_data_components for Temperature and Precipitation + +**STEP 3 - COMPARE:** +- ERA5 climatology = observations (ground truth) +- Model data = projections +- Note any differences between observed and modeled values +""" + elif has_repl: prompt += """ 1. **READ THE DATA** - Use Python_REPL to inspect climate_data/data.csv -2. **ANALYZE** - Compare historical vs future projections in the data -3. **CREATE VISUALIZATIONS** - This is MANDATORY: - - Plot monthly temperature comparison (historical vs future) - - Plot precipitation patterns if relevant - - Plot wind data if relevant to the query - - Save ALL plots to results/ directory +2. **ANALYZE** - Compare historical vs future projections +3. **CREATE VISUALIZATIONS** - Save plots to results/ directory """ - elif not has_repl: - # Without Python_REPL, use get_data_components + else: prompt += """ 1. **ALWAYS START** by calling get_data_components for Temperature (all months) 2. **THEN** call get_data_components for Precipitation (all months) 3. **ANALYZE** the data - compare historical vs future projections -""" - - if has_era5: - prompt += """ - -**NOTE**: ERA5 is enabled but Python_REPL is not. You can download ERA5 data, but -you won't be able to process it effectively. Consider using get_data_components only. """ prompt += f""" @@ -263,17 +281,17 @@ def _create_tool_prompt(datasets_text: str, config: dict) -> str: Even if the user doesn't explicitly ask for plots, you SHOULD: - Create temperature trend visualizations -- Show precipitation comparisons between historical and future -- Highlight months with the largest projected changes -- Identify potential climate risks (e.g., heat stress, drought, flooding) +- Show precipitation comparisons +- Highlight months with largest projected changes +- Identify potential climate risks (heat stress, drought, flooding) """ - if has_era5: + if has_era5_climatology: prompt += """ -**With ERA5 data, ALWAYS include:** -- Model bias analysis (ERA5 vs model historical) -- Current climate baseline (from ERA5 observations) -- Confidence assessment (how well do models match observations?) +**With ERA5 observations, ALWAYS include:** +- Current observed climate (from ERA5 - this is REALITY) +- Model performance assessment (how well does the model match observations?) +- Future projections interpreted in context of model quality """ prompt += """ @@ -282,14 +300,14 @@ def _create_tool_prompt(datasets_text: str, config: dict) -> str: Your final response should include: """ - if has_era5 and has_repl: - prompt += """1. **Current Climate (ERA5 Observations)**: Recent observed temperature, precipitation -2. **Model Performance**: How well climate models match ERA5 observations (bias analysis) -3. **Future Projections**: Model predictions with context of model bias -4. **Climate Change Signal**: Differences between ERA5 baseline and future projections + if has_era5_climatology: + prompt += """1. **Current Climate (ERA5 Observations)**: What the climate ACTUALLY IS (2015-2025 average) +2. **Model Assessment**: How well climate models match ERA5 observations +3. **Future Projections**: Model predictions with confidence based on model-observation agreement +4. **Climate Change Signal**: Projected changes from current observed baseline 5. **Critical Months**: Which months show largest changes -6. **Visualizations**: List of plot files created (MANDATORY: 3-way comparison plots) -7. **Implications**: Interpretation with confidence levels based on model-observation agreement +6. **Visualizations**: List of plot files created +7. **Implications**: Interpretation relevant to the user's query """ else: prompt += """1. **Key Climate Values**: Extracted temperature, precipitation data @@ -423,21 +441,35 @@ def data_analysis_agent( if state.era5_data_dir: datasets["era5_data_dir"] = state.era5_data_dir + # Get coordinates for prompt + lat = state.input_params.get('lat') + lon = state.input_params.get('lon') + try: + lat = float(lat) if lat is not None else None + lon = float(lon) if lon is not None else None + except (ValueError, TypeError): + lat, lon = None, None + # Tool setup - ORDER MATTERS (matches prompt workflow) tools = [] has_python_repl = config.get("use_powerful_data_analysis", False) + has_era5_climatology = config.get("era5_climatology", {}).get("enabled", True) - # 1. get_data_components - ONLY if Python_REPL is NOT available + # 1. ERA5 Climatology - ALWAYS FIRST (observational ground truth) + if has_era5_climatology: + tools.append(create_era5_climatology_tool(state, config, stream_handler)) + + # 2. get_data_components - ONLY if Python_REPL is NOT available # (When Python_REPL is enabled, it's redundant - agent reads CSV directly) if not has_python_repl: tools.append(create_get_data_components_tool(state, config, stream_handler)) - # 2. ERA5 retrieval (if enabled) + # 3. ERA5 time series retrieval (if enabled - for detailed year-by-year analysis) if config.get("use_era5_data", False): tools.append(era5_retrieval_tool) - # 3. Python REPL for analysis/visualization (if enabled) + # 4. Python REPL for analysis/visualization (if enabled) if has_python_repl: repl_tool = CustomPythonREPLTool( datasets=datasets, @@ -446,15 +478,17 @@ def data_analysis_agent( ) tools.append(repl_tool) - # 4. Helper tools - tools.extend([ - list_plotting_data_files_tool, - reflect_tool, - wise_agent_tool, - ]) + # 5. Helper tools + tools.append(list_plotting_data_files_tool) + + # 6. Image reflection and wise_agent - ONLY when Python REPL is enabled + # (these tools are for evaluating/creating visualizations) + if has_python_repl: + tools.append(reflect_tool) + tools.append(wise_agent_tool) stream_handler.update_progress("Data analysis: running tools...") - tool_prompt = _create_tool_prompt(datasets_text, config) + tool_prompt = _create_tool_prompt(datasets_text, config, lat=lat, lon=lon) if llm_dataanalysis_agent is None: from langchain_openai import ChatOpenAI @@ -480,8 +514,14 @@ def data_analysis_agent( data_components_outputs = [] plot_images: List[str] = [] + era5_climatology_output = None for action, observation in result.get("intermediate_steps", []): + if action.tool == "get_era5_climatology": + obs = _normalize_tool_observation(observation) + if isinstance(obs, dict) and "error" not in obs: + era5_climatology_output = obs + state.era5_climatology_response = obs if action.tool == "get_data_components": data_components_outputs.append(_normalize_tool_observation(observation)) if action.tool in ("Python_REPL", "python_repl"): @@ -489,14 +529,28 @@ def data_analysis_agent( if isinstance(obs, dict): plot_images.extend(obs.get("plot_images", [])) if action.tool == "retrieve_era5_data": - # Keep ERA5 outputs in state only (no raw text in summary). + # Keep ERA5 time series outputs in state only (no raw text in summary). state.input_params.setdefault("era5_results", []).append( _normalize_tool_observation(observation) ) analysis_text = result.get("output", "") + + # Append ERA5 climatology summary if available + if era5_climatology_output: + analysis_text += "\n\n### ERA5 Observational Baseline (2015-2025)\n" + analysis_text += f"Location: {era5_climatology_output.get('extracted_location', {})}\n" + if "variables" in era5_climatology_output: + for var_name, var_data in era5_climatology_output["variables"].items(): + analysis_text += f"\n**{var_data.get('full_name', var_name)}** ({var_data.get('units', '')}):\n" + monthly = var_data.get("monthly_values", {}) + # Show a few key months + for month in ["January", "April", "July", "October"]: + if month in monthly: + analysis_text += f" {month}: {monthly[month]}\n" + if data_components_outputs: - analysis_text += "\n\nClimatology extracts:\n" + analysis_text += "\n\n### Climate Model Extracts:\n" for item in data_components_outputs: analysis_text += json.dumps(item, indent=2) + "\n" @@ -509,4 +563,5 @@ def data_analysis_agent( "data_analysis_response": analysis_text, "data_analysis_images": plot_images, "data_analysis_prompt_text": analysis_brief, + "era5_climatology_response": state.era5_climatology_response, } diff --git a/src/climsight/tools/__init__.py b/src/climsight/tools/__init__.py index bae2b6c..9182d27 100644 --- a/src/climsight/tools/__init__.py +++ b/src/climsight/tools/__init__.py @@ -5,8 +5,13 @@ from .image_viewer import create_image_viewer_tool from .python_repl import CustomPythonREPLTool +from .era5_climatology_tool import create_era5_climatology_tool -__all__ = ['CustomPythonREPLTool', 'create_image_viewer_tool'] +__all__ = [ + 'CustomPythonREPLTool', + 'create_image_viewer_tool', + 'create_era5_climatology_tool', +] # Backwards-compatible import if older code expects create_python_repl_tool. try: diff --git a/src/climsight/tools/era5_climatology_tool.py b/src/climsight/tools/era5_climatology_tool.py new file mode 100644 index 0000000..bfe6308 --- /dev/null +++ b/src/climsight/tools/era5_climatology_tool.py @@ -0,0 +1,325 @@ +"""Tool for extracting ERA5 climatology data from pre-computed Zarr file. + +This tool provides OBSERVED climate data (ground truth) from ERA5 reanalysis. +It should be called FIRST before analyzing climate model projections. +""" + +import json +import logging +import os +from typing import List, Optional, Union + +import numpy as np +import xarray as xr +from pydantic import BaseModel, Field + +logger = logging.getLogger(__name__) + +# Default path to ERA5 climatology file +DEFAULT_ERA5_CLIMATOLOGY_PATH = "data/era5/era5_climatology_2015_2025.zarr" + +# Variable name mapping: agent-friendly names -> ERA5 variable names +VARIABLE_ALIASES = { + # Temperature + "temperature": "t2m", + "temp": "t2m", + "2m_temperature": "t2m", + "t2m": "t2m", + # Dewpoint + "dewpoint": "d2m", + "dewpoint_temperature": "d2m", + "2m_dewpoint": "d2m", + "d2m": "d2m", + # Precipitation + "precipitation": "tp", + "precip": "tp", + "total_precipitation": "tp", + "tp": "tp", + # Wind U component + "wind_u": "u10", + "u_wind": "u10", + "10m_u_wind": "u10", + "u10": "u10", + # Wind V component + "wind_v": "v10", + "v_wind": "v10", + "10m_v_wind": "v10", + "v10": "v10", + # Pressure + "pressure": "msl", + "mean_sea_level_pressure": "msl", + "msl": "msl", + # Surface pressure + "surface_pressure": "sp", + "sp": "sp", + # Sea surface temperature + "sea_surface_temperature": "sst", + "sst": "sst", +} + +# Variable metadata +# Note: tp in ERA5 Zarr is stored as m/day (daily average rate), converted to mm/month +VARIABLE_INFO = { + "t2m": {"full_name": "2 metre temperature", "units": "K", "convert_to_celsius": True}, + "d2m": {"full_name": "2 metre dewpoint temperature", "units": "K", "convert_to_celsius": True}, + "tp": {"full_name": "Total precipitation", "units": "m/day", "convert_to_mm": True}, + "u10": {"full_name": "10 metre U wind component", "units": "m/s", "convert_to_celsius": False}, + "v10": {"full_name": "10 metre V wind component", "units": "m/s", "convert_to_celsius": False}, + "msl": {"full_name": "Mean sea level pressure", "units": "Pa", "convert_to_celsius": False}, + "sp": {"full_name": "Surface pressure", "units": "Pa", "convert_to_celsius": False}, + "sst": {"full_name": "Sea surface temperature", "units": "K", "convert_to_celsius": True}, +} + +MONTH_NAMES = [ + "January", "February", "March", "April", "May", "June", + "July", "August", "September", "October", "November", "December" +] + +# Days per month (climatological average, Feb=28.25 for leap year average) +DAYS_IN_MONTH = [31, 28.25, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31] + + +def _haversine_distance(lat1: float, lon1: float, lat2: float, lon2: float) -> float: + """Calculate distance in km between two points using Haversine formula.""" + R = 6371 # Earth's radius in km + + lat1_rad = np.radians(lat1) + lat2_rad = np.radians(lat2) + dlat = np.radians(lat2 - lat1) + dlon = np.radians(lon2 - lon1) + + a = np.sin(dlat/2)**2 + np.cos(lat1_rad) * np.cos(lat2_rad) * np.sin(dlon/2)**2 + c = 2 * np.arctan2(np.sqrt(a), np.sqrt(1-a)) + + return R * c + + +def _normalize_longitude(lon: float) -> float: + """Normalize longitude to 0-360 range (ERA5 uses 0-360).""" + if lon < 0: + return lon + 360 + return lon + + +class GetERA5ClimatologyArgs(BaseModel): + latitude: float = Field( + description="Latitude of the location to extract data for (decimal degrees, -90 to 90)" + ) + longitude: float = Field( + description="Longitude of the location to extract data for (decimal degrees, -180 to 180)" + ) + variables: Optional[Union[str, List[str]]] = Field( + default=None, + description=( + "List of variables to extract. Options: temperature (t2m), precipitation (tp), " + "wind_u (u10), wind_v (v10), dewpoint (d2m), pressure (msl), surface_pressure (sp), sst. " + "If not specified, extracts temperature and precipitation by default." + ) + ) + + +def create_era5_climatology_tool(state, config, stream_handler=None): + """Create a StructuredTool for extracting ERA5 climatology data. + + Args: + state: AgentState with sandbox paths + config: Configuration dict + stream_handler: Optional progress handler + + Returns: + StructuredTool instance + """ + try: + from langchain.tools import StructuredTool + except ImportError: + from langchain_core.tools import StructuredTool + + # Get ERA5 climatology path from config or use default + era5_config = config.get("era5_climatology", {}) + era5_path = era5_config.get("path", DEFAULT_ERA5_CLIMATOLOGY_PATH) + + def get_era5_climatology( + latitude: float, + longitude: float, + variables: Optional[Union[str, List[str]]] = None + ) -> dict: + """Extract ERA5 climatology for a specific location. + + This provides OBSERVED climate data (2015-2025 average) as ground truth. + Use this data as the baseline to compare against climate model projections. + """ + if stream_handler is not None: + stream_handler.update_progress("Extracting ERA5 observational climatology...") + + # Parse variables input + if variables is None: + variables = ["t2m", "tp"] # Default: temperature and precipitation + elif isinstance(variables, str): + # Handle string input (might be JSON list or comma-separated) + try: + import ast + variables = ast.literal_eval(variables) + except (ValueError, SyntaxError): + variables = [v.strip() for v in variables.split(",")] + + # Normalize variable names + normalized_vars = [] + for var in variables: + var_lower = var.lower().strip() + if var_lower in VARIABLE_ALIASES: + normalized_vars.append(VARIABLE_ALIASES[var_lower]) + else: + logger.warning(f"Unknown variable: {var}. Skipping.") + + if not normalized_vars: + return { + "error": "No valid variables specified.", + "available_variables": list(VARIABLE_INFO.keys()) + } + + # Remove duplicates while preserving order + normalized_vars = list(dict.fromkeys(normalized_vars)) + + # Check if ERA5 file exists + if not os.path.exists(era5_path): + return { + "error": f"ERA5 climatology file not found at: {era5_path}", + "suggestion": "Check config.yml era5_climatology.path setting" + } + + try: + # Open the Zarr dataset + ds = xr.open_zarr(era5_path) + + # Normalize longitude to 0-360 range (ERA5 convention) + lon_normalized = _normalize_longitude(longitude) + + # Find nearest point + nearest = ds.sel( + latitude=latitude, + longitude=lon_normalized, + method="nearest" + ) + + actual_lat = float(nearest.latitude.values) + actual_lon = float(nearest.longitude.values) + + # Convert actual longitude back to -180 to 180 if needed + actual_lon_display = actual_lon if actual_lon <= 180 else actual_lon - 360 + + # Calculate distance to nearest point + distance_km = _haversine_distance(latitude, longitude, actual_lat, actual_lon_display) + + # Extract data for each variable + variables_data = {} + for var_name in normalized_vars: + if var_name not in ds.data_vars: + logger.warning(f"Variable {var_name} not in dataset. Available: {list(ds.data_vars)}") + continue + + var_info = VARIABLE_INFO.get(var_name, {}) + # Force compute if dask array and convert to numpy + raw_data = nearest[var_name] + if hasattr(raw_data, 'compute'): + raw_data = raw_data.compute() + values = np.array(raw_data.values, dtype=np.float64) + + # Convert units if needed + if var_info.get("convert_to_celsius", False): + values = values - 273.15 # K to °C + units = "°C" + elif var_info.get("convert_to_mm", False): + # ERA5 tp is in m/day (daily average rate), convert to mm/month + # Multiply by 1000 (m->mm) and by days in each month + values_monthly = np.array([ + values[i] * 1000.0 * DAYS_IN_MONTH[i] + for i in range(len(values)) + ]) + values = values_monthly + units = "mm/month" + else: + units = var_info.get("units", "") + + logger.debug(f"Variable {var_name}: raw range [{np.min(values):.4f}, {np.max(values):.4f}] {units}") + + # Build monthly values dict + monthly_values = {} + for i, month_name in enumerate(MONTH_NAMES): + monthly_values[month_name] = round(float(values[i]), 2) + + variables_data[var_name] = { + "full_name": var_info.get("full_name", var_name), + "units": units, + "monthly_values": monthly_values + } + + ds.close() + + # Build result + result = { + "source": "ERA5 Reanalysis (ECMWF)", + "data_type": "OBSERVATIONS (ground truth)", + "period": "2015-2025 monthly climatology (10-year average)", + "resolution": "0.25° grid (~28 km)", + "description": ( + "This is OBSERVED climate data from ERA5 reanalysis. " + "Use these values as the GROUND TRUTH baseline. " + "Climate model projections should be compared against this observational data." + ), + "requested_location": { + "latitude": latitude, + "longitude": longitude + }, + "extracted_location": { + "latitude": actual_lat, + "longitude": actual_lon_display + }, + "distance_from_requested_km": round(distance_km, 1), + "note": ( + f"ERA5 grid resolution is 0.25° (~28km). " + f"Data extracted from nearest grid point, {round(distance_km, 1)} km from requested location." + ), + "variables": variables_data, + "usage_guidance": { + "comparison": "Compare ERA5 values with climate model historical period to assess model bias", + "baseline": "Use ERA5 as the 'current climate' baseline (what we observe NOW)", + "interpretation": "ERA5 represents actual observed conditions, climate models are projections" + } + } + + # Save to sandbox if available + if hasattr(state, 'uuid_main_dir') and state.uuid_main_dir: + output_path = os.path.join(state.uuid_main_dir, "era5_climatology.json") + try: + with open(output_path, "w", encoding="utf-8") as f: + json.dump(result, f, indent=2) + logger.info(f"ERA5 climatology saved to: {output_path}") + result["saved_to"] = output_path + except Exception as e: + logger.warning(f"Could not save ERA5 climatology: {e}") + + # Also store in state + if hasattr(state, 'era5_climatology_response'): + state.era5_climatology_response = result + + return result + + except Exception as e: + logger.error(f"Error extracting ERA5 climatology: {e}") + return { + "error": f"Failed to extract ERA5 climatology: {str(e)}", + "latitude": latitude, + "longitude": longitude, + "variables_requested": normalized_vars + } + + return StructuredTool.from_function( + func=get_era5_climatology, + name="get_era5_climatology", + description=( + "Extract OBSERVED climate data from ERA5 reanalysis (2015-2025 climatology). " + "This provides GROUND TRUTH observations - call this FIRST before analyzing climate model data. " + "Returns monthly averages for temperature, precipitation, wind, etc. at the specified location." + ), + args_schema=GetERA5ClimatologyArgs, + ) From 76df28a071a5b8469e0218108fc2a5601a783474 Mon Sep 17 00:00:00 2001 From: Ivan Kuznetsov Date: Fri, 23 Jan 2026 14:05:46 +0100 Subject: [PATCH 05/16] Fix longitude normalization for NextGEMS HEALPix data and UI cleanup - Fix NextGEMSProvider: normalize negative longitudes to 0-360 range before KDTree query (fixes wrong temperatures for western hemisphere) - Remove "Provide additional information" toggle, always show extra info - Minor config and gitignore updates --- .gitignore | 2 +- config.yml | 2 +- src/climsight/climate_data_providers.py | 7 ++++++- src/climsight/streamlit_interface.py | 4 ++-- 4 files changed, 10 insertions(+), 5 deletions(-) diff --git a/.gitignore b/.gitignore index 29f8f22..9c65ee4 100644 --- a/.gitignore +++ b/.gitignore @@ -2,7 +2,7 @@ data_climate_foresight.tar data/ *.ipynb __pycache__ -climsight.log +climsight.logwhe climsight_evaluation.log cache/ evaluation/evaluation_report.txt diff --git a/config.yml b/config.yml index 276c2dd..66a8490 100644 --- a/config.yml +++ b/config.yml @@ -5,7 +5,7 @@ llm_rag: model_name: "gpt-5-mini" # used only for RAGs llm_smart: #used only in smart_agent model_type: "openai" - model_name: "gpt-5-mini" # used only for smart agent + model_name: "gpt-5.2" # used only for smart agent llm_combine: #used only in combine_agent and intro model_type: "openai" model_name: "gpt-5.2" # used only for combine agent ("mkchaou/climsight-calm_ft_Q3_13k") diff --git a/src/climsight/climate_data_providers.py b/src/climsight/climate_data_providers.py index 07275e4..4fc9c9f 100644 --- a/src/climsight/climate_data_providers.py +++ b/src/climsight/climate_data_providers.py @@ -296,8 +296,13 @@ def _extract_data_healpix( """ dataset = xr.open_dataset(nc_file) + # Normalize longitude to 0-360 range (HEALPix data uses 0-360) + query_lon = desired_lon + if desired_lon < 0: + query_lon = desired_lon + 360 + # Query 4 nearest neighbors - distances, indices = tree.query([desired_lon, desired_lat], k=4) + distances, indices = tree.query([query_lon, desired_lat], k=4) neighbors_lons = lons[indices] neighbors_lats = lats[indices] diff --git a/src/climsight/streamlit_interface.py b/src/climsight/streamlit_interface.py index 3e16eb3..c440d2e 100644 --- a/src/climsight/streamlit_interface.py +++ b/src/climsight/streamlit_interface.py @@ -135,8 +135,8 @@ def run_streamlit(config, api_key='', skip_llm_call=False, rag_activated=True, r else: config['llm_combine']['model_type'] = "local" with col1: - show_add_info = st.toggle("Provide additional information", value=False, help="""If this is activated you will see all the variables - that were taken into account for the analysis as well as some plots.""") + # Always show additional information (removed toggle per user request) + show_add_info = True smart_agent = st.toggle("Use extra search", value=False, help="""If this is activated, ClimSight will make additional requests to Wikipedia and RAG, which can significantly increase response time.""") use_era5_data = st.toggle( "Enable ERA5 data", From 2a2a8c284fc471d8a1f05c1a2daebd16e1d343dd Mon Sep 17 00:00:00 2001 From: dmpantiu Date: Mon, 26 Jan 2026 18:09:38 +0100 Subject: [PATCH 06/16] Refactor ERA5 retrieval (Arraylake) & Python REPL; fix agent prompts --- src/climsight/climsight_classes.py | 1 + src/climsight/climsight_engine.py | 1 + src/climsight/data_analysis_agent.py | 60 ++- src/climsight/smart_agent.py | 25 +- src/climsight/tools/__init__.py | 14 +- src/climsight/tools/era5_retrieval_tool.py | 407 +++++++-------- src/climsight/tools/python_repl.py | 559 +++++++-------------- 7 files changed, 461 insertions(+), 606 deletions(-) diff --git a/src/climsight/climsight_classes.py b/src/climsight/climsight_classes.py index d43c025..c5cb4d3 100644 --- a/src/climsight/climsight_classes.py +++ b/src/climsight/climsight_classes.py @@ -34,4 +34,5 @@ class AgentState(BaseModel): climate_data_dir: str = "" # Saved climatology directory era5_data_dir: str = "" # ERA5 output directory era5_climatology_response: dict = {} # ERA5 observed climatology (ground truth) + era5_tool_response: str = "" # stream_handler: StreamHandler # Uncomment if needed diff --git a/src/climsight/climsight_engine.py b/src/climsight/climsight_engine.py index acd6dee..006c0b4 100644 --- a/src/climsight/climsight_engine.py +++ b/src/climsight/climsight_engine.py @@ -985,6 +985,7 @@ def intro_agent(state: AgentState): - **No Keywords Required:** Do not look for specific words like "climate" or "weather". - **Accept Fragments:** "Bridge", "Data Center", "Tomatoes", "Here", "My car" are all **VALID**. - **Accept Statements:** "I am worried about the heat", "Building a shed" are **VALID**. + - **Accept Technical Constraints:** Requests specifying years, models, or datasets (e.g., "Use 1980-2000 baseline") are **VALID**. Based on the conversation, decide on one of the following responses: - "next": either "FINISH" or "CONTINUE" diff --git a/src/climsight/data_analysis_agent.py b/src/climsight/data_analysis_agent.py index 5d2796d..290dd2c 100644 --- a/src/climsight/data_analysis_agent.py +++ b/src/climsight/data_analysis_agent.py @@ -23,6 +23,7 @@ from tools.era5_climatology_tool import create_era5_climatology_tool from tools.era5_retrieval_tool import era5_retrieval_tool from tools.python_repl import CustomPythonREPLTool +from tools.image_viewer import create_image_viewer_tool from tools.reflection_tools import reflect_tool from tools.visualization_tools import ( list_plotting_data_files_tool, @@ -145,11 +146,28 @@ def _create_tool_prompt(datasets_text: str, config: dict, lat: float = None, lon # TOOL #3: ERA5 time series download (optional, for detailed analysis) if has_era5_download: + work_dir_str = sandbox_paths.get("uuid_main_dir", "") if "sandbox_paths" in dir() else "" prompt += f""" -{tool_num}. **retrieve_era5_data** - Download detailed ERA5 time series (OPTIONAL) - - Use ONLY if you need year-by-year data beyond the climatology - - For most analyses, get_era5_climatology is sufficient - - Specify start_date, end_date (YYYY-MM-DD), and lat/lon bounds +{tool_num}. **retrieve_era5_data** - Retrieve ERA5 Surface climate data from Earthmover (Arraylake) + + Use this tool to retrieve **historical weather/climate context (Time Series)**. + + **DATA SOURCE:** Earthmover (Arraylake), hardcoded to "temporal" mode. + **VARIABLE CODES:** Use short codes: 't2' (Temp), 'u10'/'v10' (Wind), 'mslp' (Pressure), 'tp' (Precip). + **PATH RULE:** You MUST always pass `work_dir` to this tool. + + **OUTPUT USAGE:** + The tool returns a path to a Zarr store. + **CRITICAL:** If you requested a specific point, the data is likely already reduced to that point (nearest neighbor). + + How to load result in Python_REPL: + ```python + import xarray as xr + ds = xr.open_dataset(path_from_tool_response, engine='zarr', chunks={{}}) + # If lat/lon are scalar, access directly: + data = ds['t2'].to_series() + ``` + """ tool_num += 1 @@ -210,6 +228,17 @@ def _create_tool_prompt(datasets_text: str, config: dict, lat: float = None, lon """ tool_num += 1 + # image_viewer only available with Python REPL + if has_repl: + prompt += f""" +{tool_num}. **image_viewer** - View generated plots to verify quality + + Pass the EXACT path printed by Python_REPL after saving a plot. + Use this to verify your visualizations before finalizing. + +""" + tool_num += 1 + # reflect_on_image and wise_agent only available with Python REPL if has_repl: prompt += f"""{tool_num}. **reflect_on_image** - Analyze a generated plot and get feedback for improvements @@ -481,7 +510,12 @@ def data_analysis_agent( # 5. Helper tools tools.append(list_plotting_data_files_tool) - # 6. Image reflection and wise_agent - ONLY when Python REPL is enabled + # 6. Image viewer - ONLY when Python REPL is enabled + if has_python_repl: + image_viewer_tool = create_image_viewer_tool() + tools.append(image_viewer_tool) + + # 7. Image reflection and wise_agent - ONLY when Python REPL is enabled # (these tools are for evaluating/creating visualizations) if has_python_repl: tools.append(reflect_tool) @@ -529,10 +563,17 @@ def data_analysis_agent( if isinstance(obs, dict): plot_images.extend(obs.get("plot_images", [])) if action.tool == "retrieve_era5_data": - # Keep ERA5 time series outputs in state only (no raw text in summary). - state.input_params.setdefault("era5_results", []).append( - _normalize_tool_observation(observation) - ) + # Handle ERA5 retrieval tool output + obs = _normalize_tool_observation(observation) + if isinstance(obs, dict): + era5_output = str(obs) + elif hasattr(obs, 'content'): + era5_output = obs.content + else: + era5_output = str(obs) + # Store in state + state.era5_tool_response = era5_output + state.input_params.setdefault("era5_results", []).append(obs) analysis_text = result.get("output", "") @@ -564,4 +605,5 @@ def data_analysis_agent( "data_analysis_images": plot_images, "data_analysis_prompt_text": analysis_brief, "era5_climatology_response": state.era5_climatology_response, + "era5_tool_response": getattr(state, 'era5_tool_response', None), } diff --git a/src/climsight/smart_agent.py b/src/climsight/smart_agent.py index a5b21b1..3203785 100644 --- a/src/climsight/smart_agent.py +++ b/src/climsight/smart_agent.py @@ -29,7 +29,9 @@ from langchain_core.prompts import ChatPromptTemplate #Import tools -#Import tools (python_repl and image_viewer moved to data_analysis_agent) +from tools.python_repl import create_python_repl_tool +from tools.image_viewer import create_image_viewer_tool +# era5_retrieval_tool is used only in data_analysis_agent #import requests #from bs4 import BeautifulSoup @@ -81,6 +83,7 @@ def smart_agent(state: AgentState, config, api_key, api_key_local, stream_handle # Create working directory work_dir = Path("tmp/sandbox") / st.session_state.session_uuid work_dir.mkdir(parents=True, exist_ok=True) + work_dir_str = str(work_dir.resolve()) # System prompt prompt = f""" @@ -98,6 +101,21 @@ def smart_agent(state: AgentState, config, api_key, api_key_local, stream_handle ONLY use this tool if the user's question is clearly about agriculture or crops. Do NOT use it for general climate queries. + - "python_repl" allows you to execute Python code for data analysis, visualization, and calculations. + Use this tool when: + - Creating visualizations (plots, charts, graphs) of climate data + + **CRITICAL: Your working directory is available at `work_dir` = '{work_dir_str}'** + **When saving plots, ALWAYS store the full path in a variable for later use!** + + **CORRECT way to save and reference images:** + ```python + plot_path = f'{{{{work_dir}}}}/temperature_plot.png' + plt.savefig(plot_path) + print(plot_path) # PRINT IT TO CONFIRM + ``` + + **Your goal**: Gather comprehensive background information and compile it into a well-structured summary that will be used by subsequent agents for data analysis. @@ -510,7 +528,8 @@ def process_ecocrop_search(query: str) -> str: ) - # python_repl_tool moved to data_analysis_agent + # Create python_repl tool + python_repl_tool = create_python_repl_tool() # Initialize the LLM if config['llm_smart']['model_type'] == "local": @@ -531,7 +550,7 @@ def process_ecocrop_search(query: str) -> str: llm = get_aitta_chat_model(config['llm_smart']['model_name'], temperature = 0) # List of tools - tools = [rag_tool, ecocrop_tool, wikipedia_tool] + tools = [rag_tool, ecocrop_tool, wikipedia_tool, python_repl_tool] prompt += """\nadditional information:\n question is related to this location: {location_str} \n diff --git a/src/climsight/tools/__init__.py b/src/climsight/tools/__init__.py index 9182d27..6e60df4 100644 --- a/src/climsight/tools/__init__.py +++ b/src/climsight/tools/__init__.py @@ -4,18 +4,14 @@ """ from .image_viewer import create_image_viewer_tool -from .python_repl import CustomPythonREPLTool +from .python_repl import CustomPythonREPLTool, create_python_repl_tool from .era5_climatology_tool import create_era5_climatology_tool +from .era5_retrieval_tool import era5_retrieval_tool __all__ = [ 'CustomPythonREPLTool', + 'create_python_repl_tool', 'create_image_viewer_tool', 'create_era5_climatology_tool', -] - -# Backwards-compatible import if older code expects create_python_repl_tool. -try: - from .python_repl import create_python_repl_tool # type: ignore - __all__.append('create_python_repl_tool') -except Exception: - pass + 'era5_retrieval_tool', +] \ No newline at end of file diff --git a/src/climsight/tools/era5_retrieval_tool.py b/src/climsight/tools/era5_retrieval_tool.py index c0ecfef..e395ba1 100644 --- a/src/climsight/tools/era5_retrieval_tool.py +++ b/src/climsight/tools/era5_retrieval_tool.py @@ -1,23 +1,18 @@ -# src/tools/era5_retrieval_tool.py +# src/climsight/tools/era5_retrieval_tool.py """ -src/tools/era5_retrieval_tool.py - ERA5 data retrieval tool for use in the visualization agent. -Retrieves climate data from the Google Cloud ARCO-ERA5 dataset. -Saves the retrieved data ONLY in Zarr format. -The Zarr filename is based on a hash of the request parameters. -Optimized for memory efficiency using lazy loading (Dask) and streaming. +Retrieves ERA5 Surface climate data from Earthmover (Arraylake). +Saves the retrieved data locally in Zarr format. +Hardcoded to 'temporal' query mode for efficient time-series retrieval. +Uses 'nearest' neighbor selection for point coordinates to prevent empty spatial slices. """ import os import sys import logging -import uuid -import hashlib import shutil import xarray as xr import pandas as pd -import numpy as np from pydantic import BaseModel, Field from typing import Optional, Literal from langchain_core.tools import StructuredTool @@ -25,251 +20,263 @@ # --- IMPORTS & CONFIGURATION --- try: import zarr - import gcsfs + import arraylake + from arraylake import Client # Optional: Check for Streamlit to support session state if available try: import streamlit as st except ImportError: st = None except ImportError as e: - install_command = "pip install --upgrade xarray zarr numcodecs gcsfs blosc pandas numpy pydantic langchain-core" + install_command = "pip install --upgrade xarray zarr arraylake pandas numpy pydantic langchain-core" raise ImportError( - f"Required libraries missing. Please ensure zarr, gcsfs, and others are installed.\n" + f"Required libraries missing. Please ensure arraylake is installed.\n" f"Try running: {install_command}" ) from e logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') -# Target the Analysis-Ready (AR) Zarr store (Chunk-1 optimized) -ARCO_ERA5_MAIN_ZARR_STORE = 'gs://gcp-public-data-arco-era5/ar/full_37-1h-0p25deg-chunk-1.zarr-v3' +# ============================================================================= +# EARTHMOVER (ARRAYLAKE) IMPLEMENTATION +# ============================================================================= + +# Variable Mapping: Maps friendly names to Earthmover short codes +VARIABLE_MAPPING = { + # Temperature + "sea_surface_temperature": "sst", + "2m_temperature": "t2", + "temperature": "t2", + "skin_temperature": "skt", + "dewpoint_temperature": "d2", + + # Wind + "10m_u_component_of_wind": "u10", + "10m_v_component_of_wind": "v10", + "u_component_of_wind": "u10", + "v_component_of_wind": "v10", + + # Pressure + "surface_pressure": "sp", + "mean_sea_level_pressure": "mslp", + + # Clouds/Precip + "total_cloud_cover": "tcc", + "convective_precipitation": "cp", + "large_scale_precipitation": "lsp", + "total_precipitation": "tp", + + # Identity mappings (so short codes work) + "t2": "t2", "sst": "sst", "mslp": "mslp", "u10": "u10", "v10": "v10", + "sp": "sp", "tcc": "tcc", "cp": "cp", "lsp": "lsp", "sd": "sd", "tp": "tp" +} class ERA5RetrievalArgs(BaseModel): + # query_type is removed from the agent's view (hardcoded internally) variable_id: Literal[ + "t2", "sst", "mslp", "u10", "v10", "sp", "tcc", "cp", "lsp", "sd", "skt", "d2", "tp", "sea_surface_temperature", "surface_pressure", "total_cloud_cover", "total_precipitation", "10m_u_component_of_wind", "10m_v_component_of_wind", "2m_temperature", "2m_dewpoint_temperature", - "geopotential", "specific_humidity", "temperature", "u_component_of_wind", - "v_component_of_wind", "vertical_velocity" - ] = Field(description="ERA5 variable to retrieve (must match Zarr store names).") - start_date: str = Field(description="Start date (YYYY-MM-DD or YYYY-MM-DD HH:MM:SS).") - end_date: str = Field(description="End date (YYYY-MM-DD or YYYY-MM-DD HH:MM:SS).") - min_latitude: float = Field(-90.0, description="Minimum latitude (-90 to 90).") - max_latitude: float = Field(90.0, description="Maximum latitude (-90 to 90).") - min_longitude: float = Field(0.0, description="Minimum longitude (0-360 or -180 to 360).") - max_longitude: float = Field(359.75, description="Maximum longitude (0-360 or -180 to 360).") - pressure_level: Optional[int] = Field(None, description="Pressure level in hPa for 3D variables.") - -def _generate_descriptive_filename( - variable_id: str, start_date: str, end_date: str, - pressure_level: Optional[int] = None -) -> str: - """ - Generate a descriptive filename based on request parameters. - """ - clean_variable = variable_id.replace('-', '_').replace('/', '_').replace(' ', '_') - clean_start = start_date.split()[0].replace('-', '_') - clean_end = end_date.split()[0].replace('-', '_') - - filename_parts = [clean_variable, clean_start, "to", clean_end] - if pressure_level is not None: - filename_parts.extend(["level", str(pressure_level)]) - - filename = "_".join(filename_parts) + ".zarr" - return filename - -def _generate_request_hash( - variable_id: str, start_date: str, end_date: str, - min_latitude: float, max_latitude: float, - min_longitude: float, max_longitude: float, - pressure_level: Optional[int] -) -> str: - params_string = ( - f"{variable_id}-{start_date}-{end_date}-" - f"{min_latitude:.2f}-{max_latitude:.2f}-" - f"{min_longitude:.2f}-{max_longitude:.2f}-" - f"{pressure_level if pressure_level is not None else 'None'}" - ) - return hashlib.md5(params_string.encode('utf-8')).hexdigest() + "temperature", "u_component_of_wind", "v_component_of_wind" + ] = Field(description="ERA5 variable to retrieve. Preferred short codes: 't2' (Air Temp), 'sst' (Sea Surface Temp), 'u10'/'v10' (Wind), 'mslp' (Pressure), 'tp' (Total Precip).") + + start_date: str = Field(description="Start date (YYYY-MM-DD). Data available 1979-2024.") + end_date: str = Field(description="End date (YYYY-MM-DD).") + + # Coordinates + min_latitude: float = Field(-90.0, description="Minimum latitude. For a specific point, use the same value as max_latitude.") + max_latitude: float = Field(90.0, description="Maximum latitude. For a specific point, use the same value as min_latitude.") + min_longitude: float = Field(0.0, description="Minimum longitude. For a specific point, use the same value as max_longitude.") + max_longitude: float = Field(359.75, description="Maximum longitude. For a specific point, use the same value as min_longitude.") + + work_dir: Optional[str] = Field(None, description="The absolute path to the working directory where data should be saved.") + +def _generate_descriptive_filename(variable_id: str, query_type: str, start_date: str, end_date: str) -> str: + """Generate a descriptive directory name for the Zarr store.""" + clean_var = variable_id.replace('_', '') + clean_start = start_date.split()[0].replace('-', '') + clean_end = end_date.split()[0].replace('-', '') + # We use .zarr extension, but it is a directory + return f"era5_{clean_var}_{query_type}_{clean_start}_{clean_end}.zarr" def retrieve_era5_data( - variable_id: str, start_date: str, end_date: str, - min_latitude: float = -90.0, max_latitude: float = 90.0, - min_longitude: float = 0.0, max_longitude: float = 359.75, - pressure_level: Optional[int] = None + variable_id: str, + start_date: str, + end_date: str, + min_latitude: float = -90.0, + max_latitude: float = 90.0, + min_longitude: float = 0.0, + max_longitude: float = 359.75, + work_dir: Optional[str] = None, + **kwargs # Catch-all for unused args ) -> dict: - ds_arco = None - zarr_path = None + """ + Retrieves ERA5 Surface data from Earthmover (Arraylake). + Hardcoded to use 'temporal' (Time-series) queries. + Uses Nearest Neighbor selection for points to avoid empty slices. + """ + ds = None + local_zarr_path = None + query_type = "temporal" # Hardcoded for Climsight point-data focus + + # Get API Key from environment + ARRAYLAKE_API_KEY = os.environ.get("ARRAYLAKE_API_KEY") + + if not ARRAYLAKE_API_KEY: + return {"success": False, "error": "Missing Arraylake API Key", "message": "Please set ARRAYLAKE_API_KEY environment variable."} try: - logging.info(f"ERA5 retrieval: var={variable_id}, time={start_date} to {end_date}, " - f"lat=[{min_latitude},{max_latitude}], lon=[{min_longitude},{max_longitude}], " - f"level={pressure_level}") + # Map Variable Name + short_var = VARIABLE_MAPPING.get(variable_id.lower(), variable_id) + logging.info(f"🌍 Earthmover ERA5 Retrieval ({query_type}): {short_var} | {start_date} to {end_date}") # --- 1. Sandbox / Path Logic --- main_dir = None - if "streamlit" in sys.modules and st is not None and hasattr(st, 'session_state'): + if work_dir and os.path.exists(work_dir): + main_dir = work_dir + # Fallback logic for streamlit session if work_dir not provided + elif "streamlit" in sys.modules and st is not None and hasattr(st, 'session_state'): try: - thread_id = st.session_state.get("thread_id") - if thread_id: - main_dir = os.path.join("tmp", "sandbox", thread_id) - logging.info(f"Found session thread_id. Using persistent sandbox: {main_dir}") - except Exception: + session_uuid = getattr(st.session_state, "session_uuid", None) + if session_uuid: + main_dir = os.path.join("tmp", "sandbox", session_uuid) + else: + thread_id = st.session_state.get("thread_id") + if thread_id: + main_dir = os.path.join("tmp", "sandbox", thread_id) + except: pass - # CLI fallback: use environment-provided session id when available. if not main_dir: - env_thread_id = os.environ.get("CLIMSIGHT_THREAD_ID") - if env_thread_id: - main_dir = os.path.join("tmp", "sandbox", env_thread_id) - logging.info(f"Using CLI sandbox from CLIMSIGHT_THREAD_ID: {main_dir}") - - if not main_dir: - main_dir = os.path.join("tmp", "sandbox", "era5_data") - logging.info(f"Using general sandbox: {main_dir}") + main_dir = os.path.join("tmp", "sandbox", "era5_earthmover") os.makedirs(main_dir, exist_ok=True) - era5_specific_dir = os.path.join(main_dir, "era5_data") - os.makedirs(era5_specific_dir, exist_ok=True) - - # Check cache - zarr_filename = _generate_descriptive_filename(variable_id, start_date, end_date, pressure_level) - zarr_path = os.path.join(era5_specific_dir, zarr_filename) - relative_zarr_path = os.path.join(os.path.basename(era5_specific_dir), zarr_filename) - - if os.path.exists(zarr_path): - logging.info(f"Data already cached at {zarr_path}") - return { - "success": True, - "output_path_zarr": relative_zarr_path, - "full_path": zarr_path, - "variable": variable_id, - "message": f"Cached ERA5 data found at {relative_zarr_path}" + era5_dir = os.path.join(main_dir, "era5_data") + os.makedirs(era5_dir, exist_ok=True) + + # Check Cache + zarr_dirname = _generate_descriptive_filename(short_var, query_type, start_date, end_date) + local_zarr_path = os.path.join(era5_dir, zarr_dirname) + absolute_zarr_path = os.path.abspath(local_zarr_path) + + if os.path.exists(local_zarr_path): + logging.info(f"⚡ Cache hit: {local_zarr_path}") + return { + "success": True, + "output_path_zarr": absolute_zarr_path, + "full_path": absolute_zarr_path, + "variable": short_var, + "query_type": query_type, + "message": f"Cached ERA5 data found at {absolute_zarr_path}" } - # --- 2. Open Dataset (Lazy / Dask) --- - logging.info(f"Opening ERA5 dataset: {ARCO_ERA5_MAIN_ZARR_STORE}") - - # CRITICAL OPTIMIZATION: chunks={'time': 24} - # This prevents loading the whole dataset into RAM and creates an efficient download graph. - ds_arco = xr.open_zarr( - ARCO_ERA5_MAIN_ZARR_STORE, - chunks={'time': 24}, - consolidated=True, - storage_options={'token': 'anon'} + # --- 2. Connect to Earthmover --- + logging.info("Connecting to Arraylake...") + client = Client(token=ARRAYLAKE_API_KEY) + repo_name = "earthmover-public/era5-surface-aws" + repo = client.get_repo(repo_name) + session = repo.readonly_session("main") + + # Open Dataset + ds = xr.open_dataset( + session.store, + engine="zarr", + consolidated=False, + zarr_format=3, + chunks=None, + group=query_type ) - if variable_id not in ds_arco: - raise ValueError(f"Variable '{variable_id}' not found. Available: {list(ds_arco.data_vars)}") + if short_var not in ds: + return {"success": False, "error": f"Variable '{short_var}' not found. Available: {list(ds.data_vars)}"} - var_data_from_arco = ds_arco[variable_id] - - # --- 3. Lazy Subsetting --- + # --- 3. Slicing & Selection --- start_datetime_obj = pd.to_datetime(start_date) end_datetime_obj = pd.to_datetime(end_date) - time_filtered_data = var_data_from_arco.sel(time=slice(start_datetime_obj, end_datetime_obj)) - - lat_slice_coords = slice(max_latitude, min_latitude) - lon_min_request_360 = min_longitude % 360 - lon_max_request_360 = max_longitude % 360 - - # Anti-Meridian / Date Line Logic - if lon_min_request_360 > lon_max_request_360: - part1 = time_filtered_data.sel(latitude=lat_slice_coords, longitude=slice(lon_min_request_360, 359.75)) - part2 = time_filtered_data.sel(latitude=lat_slice_coords, longitude=slice(0, lon_max_request_360)) - - if part1.sizes.get('longitude', 0) > 0 and part2.sizes.get('longitude', 0) > 0: - space_filtered_data = xr.concat([part1, part2], dim='longitude') - elif part1.sizes.get('longitude', 0) > 0: - space_filtered_data = part1 - elif part2.sizes.get('longitude', 0) > 0: - space_filtered_data = part2 - else: - space_filtered_data = time_filtered_data.sel(latitude=lat_slice_coords, longitude=slice(0,0)) - else: - space_filtered_data = time_filtered_data.sel( - latitude=lat_slice_coords, - longitude=slice(lon_min_request_360, lon_max_request_360) - ) + time_slice = slice(start_datetime_obj, end_datetime_obj) - if pressure_level is not None and "level" in space_filtered_data.coords: - space_filtered_data = space_filtered_data.sel(level=pressure_level) + # Check if it's a point query (min approx equal to max) + is_point_query = (abs(max_latitude - min_latitude) < 0.01) and (abs(max_longitude - min_longitude) < 0.01) - if not all(dim_size > 0 for dim_size in space_filtered_data.sizes.values()): - msg = "Selected region/time period has zero size." - logging.warning(msg) - return {"success": False, "error": msg, "message": f"Failed: {msg}"} + if is_point_query: + # CRITICAL FIX: Cannot use method="nearest" with slice objects! + # Must do spatial selection FIRST (with method="nearest"), THEN time slicing. + center_lat = (min_latitude + max_latitude) / 2.0 + center_lon = (min_longitude + max_longitude) / 2.0 - # --- 4. Prepare Output Dataset --- - final_subset_ds = space_filtered_data.to_dataset(name=variable_id) + logging.info(f"Point query detected: selecting nearest neighbor to {center_lat}, {center_lon}") - # Clear encoding (fixes Zarr v3 codec conflicts) - for var in final_subset_ds.variables: - final_subset_ds[var].encoding = {} + # Step 1: Select nearest spatial point FIRST (no slices!) + subset = ds[short_var].sel( + latitude=center_lat, + longitude=center_lon, + method="nearest" + ) + # Step 2: THEN slice by time + subset = subset.sel(time=time_slice) + else: + # Box query - no method="nearest" needed, standard slicing + # Earthmover lat is typically sorted Max -> Min + req_min_lon = min_longitude % 360 + req_max_lon = max_longitude % 360 + + if req_min_lon > req_max_lon: + # Simple clamp for now + lon_slice = slice(req_min_lon, 359.75) + else: + lon_slice = slice(req_min_lon, req_max_lon) + + subset = ds[short_var].sel( + time=time_slice, + latitude=slice(max_latitude, min_latitude), + longitude=lon_slice + ) - # RE-CHUNK for safe streaming (24h blocks are manageable for Mac RAM) - final_subset_ds = final_subset_ds.chunk({'time': 24, 'latitude': -1, 'longitude': -1}) + # Validate Data Existence + if subset.sizes.get('time', 0) == 0: + return {"success": False, "error": "Empty time slice.", "message": "No data found for the requested dates."} - # Metadata - request_hash_str = _generate_request_hash( - variable_id, start_date, end_date, min_latitude, max_latitude, - min_longitude, max_longitude, pressure_level - ) - final_subset_ds.attrs.update({ - 'title': f"ERA5 {variable_id} data subset", - 'source_arco_era5': ARCO_ERA5_MAIN_ZARR_STORE, - 'retrieval_parameters_hash': request_hash_str - }) - - if os.path.exists(zarr_path): - logging.warning(f"Zarr path {zarr_path} exists. Removing for freshness.") - shutil.rmtree(zarr_path) - - # --- 5. Streaming Write (NO .load()) --- - logging.info(f"Streaming data to Zarr store: {zarr_path}") + # --- 4. Save to Zarr --- + logging.info(f"Downloading data to {local_zarr_path}...") + + ds_out = subset.to_dataset(name=short_var) - # compute=True triggers the download/write graph chunk-by-chunk - final_subset_ds.to_zarr( - store=zarr_path, - mode='w', - consolidated=True, - compute=True - ) + # Clear encoding + for var in ds_out.variables: + ds_out[var].encoding = {} - logging.info(f"Successfully saved: {zarr_path}") + if os.path.exists(local_zarr_path): + shutil.rmtree(local_zarr_path) + + ds_out.to_zarr(local_zarr_path, mode="w", consolidated=True, compute=True) + logging.info("✅ Download complete.") return { - "success": True, - "output_path_zarr": relative_zarr_path, - "full_path": zarr_path, - "variable": variable_id, - "message": f"ERA5 data saved to {relative_zarr_path}" + "success": True, + "output_path_zarr": absolute_zarr_path, + "full_path": absolute_zarr_path, + "variable": short_var, + "query_type": query_type, + "message": f"ERA5 data ({query_type} optimized) retrieved and saved to {absolute_zarr_path}" } - - except Exception as e: - logging.error(f"Error in ERA5 retrieval: {e}", exc_info=True) - error_msg = str(e) - - if "Resource Exhausted" in error_msg or "Too many open files" in error_msg: - error_msg += " (GCS access issue or system limits. Try again.)" - elif "Unsupported type for store_like" in error_msg: - error_msg += " (Issue with Zarr store path. Ensure gs:// URI is used.)" - - # Cleanup partial data on failure - if zarr_path and os.path.exists(zarr_path): - shutil.rmtree(zarr_path, ignore_errors=True) - return {"success": False, "error": error_msg, "message": f"Failed: {error_msg}"} - + except Exception as e: + logging.error(f"Error in ERA5 Earthmover retrieval: {e}", exc_info=True) + if local_zarr_path and os.path.exists(local_zarr_path): + shutil.rmtree(local_zarr_path, ignore_errors=True) + return {"success": False, "error": str(e), "message": f"Failed: {str(e)}"} finally: - if ds_arco is not None: - ds_arco.close() + if ds is not None: + ds.close() era5_retrieval_tool = StructuredTool.from_function( func=retrieve_era5_data, name="retrieve_era5_data", description=( - "Retrieves a subset of the ARCO-ERA5 Zarr climate reanalysis dataset. " - "Saves the data locally as a Zarr store. " - "Uses efficient streaming to prevent memory overflow on large files." + "Retrieves ERA5 Surface climate data from Earthmover (Arraylake). " + "Optimized for TEMPORAL time-series extraction at specific locations. " + "Automatically snaps to nearest grid point to ensure data is returned. " + "Returns a Zarr directory path. " + "Available vars: t2 (temp), sst, u10/v10 (wind), mslp (pressure), tp (precip)." ), args_schema=ERA5RetrievalArgs -) +) \ No newline at end of file diff --git a/src/climsight/tools/python_repl.py b/src/climsight/tools/python_repl.py index f78fbf7..c41b4e7 100644 --- a/src/climsight/tools/python_repl.py +++ b/src/climsight/tools/python_repl.py @@ -1,433 +1,222 @@ -# src/tools/python_repl.py +# src/climsight/tools/python_repl.py +""" +Simple Python REPL Tool for Climsight. +Thread-safe implementation that maintains state within a specific agent execution context. +Includes Matplotlib backend fixes and Markdown sanitization. +""" -import os import sys import logging -import pandas as pd +import ast +import re +import traceback +import os +import matplotlib +# Force non-interactive backend before importing pyplot +matplotlib.use('Agg') import matplotlib.pyplot as plt -import xarray as xr -import streamlit as st from io import StringIO +from typing import Dict, Any, Optional from pydantic import BaseModel, Field, PrivateAttr -from langchain_experimental.tools import PythonREPLTool -from typing import Any, Dict, Optional, List -import re -import threading -import json -import queue -import time -import atexit - -# Import Jupyter Client -try: - # We use KernelManager for lifecycle management and KernelClient for communication - from jupyter_client import KernelManager - from jupyter_client.client import KernelClient -except ImportError: - logging.error("Jupyter client not installed. Please run 'pip install jupyter_client ipykernel'") - # Raise error to ensure environment is correctly set up - raise ImportError("Missing dependencies for persistent REPL. Install jupyter_client and ipykernel.") +# Try importing StructuredTool from different locations based on version try: - from ..utils import log_history_event + from langchain.tools import StructuredTool except ImportError: - from utils import log_history_event + from langchain_core.tools import StructuredTool +from langchain_experimental.tools import PythonREPLTool -# --- New Jupyter Kernel Executor --- +logger = logging.getLogger(__name__) -class JupyterKernelExecutor: +class PersistentPythonREPL: """ - Manages a persistent Jupyter kernel for code execution. - Replaces the previous subprocess-based PersistentREPL. + A persistent Python REPL that maintains state between executions within a single session. + Replaces the Jupyter Kernel approach with a lighter in-process exec approach. """ - def __init__(self, working_dir=None): - self._working_dir = working_dir - # Initialize the KernelManager to use the Climsight kernel by default. - kernel_name = os.environ.get("CLIMSIGHT_KERNEL_NAME", "climsight") - self.km = KernelManager(kernel_name=kernel_name) - self.kc: Optional[KernelClient] = None - self.is_initialized = False - self._start_kernel() - - def _start_kernel(self): - """Starts the Jupyter kernel.""" - logging.info("Starting Jupyter kernel...") - - # Ensure the ipykernel is available. - try: - # Accessing kernel_spec validates that the kernel is installed and found - self.km.kernel_spec - except Exception as e: - logging.warning( - "Could not find python3 kernel spec. Falling back to current interpreter. Error: %s", - e, - ) - self.km.kernel_cmd = [ - sys.executable, - "-m", - "ipykernel_launcher", - "-f", - "{connection_file}", - ] - # If kernel spec points to a missing interpreter, fall back to current Python. - try: - kernel_cmd = list(self.km.kernel_cmd or []) - if kernel_cmd: - exec_path = kernel_cmd[0] - if os.path.isabs(exec_path) and not os.path.exists(exec_path): - logging.warning( - "Kernel spec uses missing interpreter: %s. Using current Python.", - exec_path, - ) - self.km.kernel_cmd = [ - sys.executable, - "-m", - "ipykernel_launcher", - "-f", - "{connection_file}", - ] - except Exception as e: - logging.warning("Failed to validate kernel_cmd: %s", e) - - # Start the kernel in the specified working directory (CWD) - if self._working_dir and os.path.exists(self._working_dir): - self.km.start_kernel(cwd=self._working_dir) - else: - self.km.start_kernel() + def __init__(self): + """Initialize with an empty locals dictionary and pre-imported modules.""" + # Persistent locals dictionary - this is THE KEY FEATURE + self.locals = {} + + # Pre-import common modules into the persistent namespace + self.locals.update({ + 'pd': __import__('pandas'), + 'np': __import__('numpy'), + 'plt': __import__('matplotlib.pyplot', fromlist=['pyplot']), + 'xr': __import__('xarray'), + 'os': __import__('os'), + 'Path': __import__('pathlib').Path, + 'json': __import__('json'), + }) + # Add matplotlib.pyplot explicitly as plt (redundant but safe) + self.locals['plt'] = plt + + def execute(self, code: str) -> str: + """ + Execute Python code and capture stdout/stderr. + Automatically prints the result of the last expression. + """ + # 1. Sanitize input (remove Markdown code blocks) + code = re.sub(r"^(\s|`)*(?i:python)?\s*", "", code) + code = re.sub(r"(\s|`)*$", "", code) + + # 2. Capture stdout/stderr + old_stdout = sys.stdout + old_stderr = sys.stderr + stdout_capture = StringIO() + stderr_capture = StringIO() + sys.stdout = stdout_capture + sys.stderr = stderr_capture - # Create the synchronous client - self.kc = self.km.client() - self.kc.start_channels() - - # Wait for the kernel to be ready try: - self.kc.wait_for_ready(timeout=60) - logging.info("Jupyter kernel started and ready.") - except RuntimeError as e: - logging.error(f"Kernel failed to start: {e}") - self.km.shutdown_kernel() - raise - - def _execute_code(self, code: str, timeout: float = 300.0) -> Dict[str, Any]: - """Executes code synchronously in the kernel and captures outputs.""" - if not self.kc: - return {"status": "error", "error": "Kernel client not available."} - - # Send the execution request - msg_id = self.kc.execute(code) - result = { - "status": "success", - "stdout": "", - "stderr": "", - "display_data": [] - } - start_time = time.time() - - # Process messages until the kernel is idle or timeout occurs - while True: - if time.time() - start_time > timeout: - logging.warning("Kernel execution timed out. Interrupting kernel.") - self.km.interrupt_kernel() - result["status"] = "error" - result["error"] = f"Execution timed out after {timeout} seconds." - break - - try: - # Get messages from the IOPub channel (where outputs are published) - # Use a small timeout to remain responsive - msg = self.kc.get_iopub_msg(timeout=0.1) - except queue.Empty: - # Check if the kernel is still alive if the queue is empty - if not self.km.is_alive(): - result["status"] = "error" - result["error"] = "Kernel died unexpectedly." - break - continue - except Exception as e: - logging.error(f"Error getting IOPub message: {e}") - continue - - # Process the message if it relates to our execution request - if msg['parent_header'].get('msg_id') == msg_id: - msg_type = msg['msg_type'] - - if msg_type == 'status': - if msg['content']['execution_state'] == 'idle': - break # Execution finished - elif msg_type == 'stream': - if msg['content']['name'] == 'stdout': - result["stdout"] += msg['content']['text'] - elif msg['content']['name'] == 'stderr': - result["stderr"] += msg['content']['text'] - elif msg_type == 'display_data' or msg_type == 'execute_result': - # Capture rich display data (e.g., for interactive plots in the future) - result["display_data"].append(msg['content']['data']) - # Also capture text representation for the current agent output - if 'text/plain' in msg['content']['data']: - result["stdout"] += msg['content']['data']['text/plain'] + "\n" - elif msg_type == 'error': - result["status"] = "error" - ename = msg['content']['ename'] - evalue = msg['content']['evalue'] - # Clean ANSI escape codes from traceback - traceback = "\n".join(msg['content']['traceback']) - traceback = re.sub(r'\x1b\[[0-9;]*m', '', traceback) - result["error"] = f"{ename}: {evalue}\n{traceback}" - break - - return result + # 3. Parse the code + tree = ast.parse(code) + + if not tree.body: + return "No code to execute." + + # 4. Execution Logic: Handle last expression printing + last_node = tree.body[-1] + if isinstance(last_node, ast.Expr): + # Execute everything before the last expression + if len(tree.body) > 1: + module = ast.Module(body=tree.body[:-1], type_ignores=[]) + exec(compile(module, "", "exec"), self.locals, self.locals) + + # Evaluate the last expression and print result + expr = ast.Expression(body=last_node.value) + result = eval(compile(expr, "", "eval"), self.locals, self.locals) + if result is not None: + print(repr(result)) + else: + # Execute the whole block if it doesn't end in an expression + exec(compile(tree, "", "exec"), self.locals, self.locals) - @staticmethod - def sanitize_input(query: str) -> str: - """Sanitize input (from LangChain).""" - # Remove surrounding backticks, whitespace, and the 'python' keyword - query = re.sub(r"^(\s|`)*(?i:python)?\s*", "", query) - query = re.sub(r"(\s|`)*$", "", query) - return query + output = stdout_capture.getvalue() + errors = stderr_capture.getvalue() - def _sanitize_code(self, code: str) -> str: - """Clean specific syntax.""" - # Jupyter kernels handle magic commands natively, so we just need LangChain sanitization - return self.sanitize_input(code) + # Combine output + final_output = output + if errors: + final_output += f"\n[Stderr]: {errors}" - def run(self, code: str, timeout: int = 300) -> str: - """Execute code in the Jupyter kernel and return formatted output.""" - cleaned_code = self._sanitize_code(code) + if not final_output.strip(): + return "Code executed successfully (no output)." - result = self._execute_code(cleaned_code, timeout) + return final_output.strip() - if result["status"] == "success": - output = result["stdout"] - if result.get("stderr"): - # Include stderr output (often contains warnings) - if output: - output += "\n[STDERR]\n" + result["stderr"] - else: - output = "[STDERR]\n" + result["stderr"] - return output.strip() - elif result["status"] == "error": - return result["error"] - else: - return f"Fatal error: {result.get('error', 'Unknown error')}" + except Exception: + # Capture the full traceback for debugging + return traceback.format_exc() - def close(self): - """Cleanup and terminate the kernel.""" - if self.kc: - self.kc.stop_channels() - if hasattr(self, 'km') and self.km.is_alive(): - logging.info("Shutting down Jupyter kernel...") - self.km.shutdown_kernel(now=True) - logging.info("Jupyter kernel shut down.") + finally: + # Restore stdout/stderr + sys.stdout = old_stdout + sys.stderr = old_stderr -# REPLManager updated to manage JupyterKernelExecutor instances -class REPLManager: +class CustomPythonREPLTool(PythonREPLTool): """ - Singleton for managing JupyterKernelExecutor instances for each session (thread_id). + A wrapper class to maintain compatibility with Agent logic that expects + CustomPythonREPLTool. It uses PersistentPythonREPL internally. """ - _instances: dict[str, JupyterKernelExecutor] = {} - _lock = threading.Lock() - - @classmethod - def get_repl(cls, session_id: str, sandbox_path: str = None) -> JupyterKernelExecutor: - with cls._lock: - if session_id not in cls._instances: - logging.info(f"Creating new Jupyter Kernel for session: {session_id}") - if sandbox_path: - logging.info(f"Starting Kernel in sandbox directory: {sandbox_path}") - try: - cls._instances[session_id] = JupyterKernelExecutor(working_dir=sandbox_path) - except Exception as e: - logging.error(f"Failed to create Jupyter Kernel for session {session_id}: {e}") - # Raise the exception so the calling agent can handle the failure - raise RuntimeError(f"Failed to initialize execution environment: {e}") - return cls._instances[session_id] - - @classmethod - def cleanup_repl(cls, session_id: str): - with cls._lock: - if session_id in cls._instances: - logging.info(f"Closing and removing Kernel for session: {session_id}") - cls._instances[session_id].close() - del cls._instances[session_id] - - @classmethod - def cleanup_all(cls): - """Cleanup all running kernels.""" - logging.info("Cleaning up all kernels on exit...") - sessions = list(cls._instances.keys()) - for session_id in sessions: - cls.cleanup_repl(session_id) - -# Ensure kernels are cleaned up when the application exits -atexit.register(REPLManager.cleanup_all) - -# --- End of new code --- - - -class CustomPythonREPLTool(PythonREPLTool): _datasets: dict = PrivateAttr() _results_dir: Optional[str] = PrivateAttr() _session_key: Optional[str] = PrivateAttr() + _repl_instance: PersistentPythonREPL = PrivateAttr() def __init__(self, datasets, results_dir=None, session_key=None, **kwargs): super().__init__(**kwargs) self._datasets = datasets self._results_dir = results_dir self._session_key = session_key + self._repl_instance = PersistentPythonREPL() + self._initialize_session() - def _initialize_session(self, repl: JupyterKernelExecutor) -> None: - """Prepare the Kernel session by importing libraries and auto-loading data.""" - logging.info("Initializing persistent Kernel session with auto-loading...") - - initialization_code = [ - "import os", - "import sys", - "import pandas as pd", - "import numpy as np", - "import matplotlib.pyplot as plt", - "import xarray as xr", - "import logging", - "from io import StringIO", - # Configure matplotlib for non-interactive backend (Crucial for kernels) - "import matplotlib", - "matplotlib.use('Agg')" - ] - - # Add results_dir if available + def _initialize_session(self) -> None: + """Pre-load variables into the REPL locals.""" + logging.info("Initializing PersistentPythonREPL session...") + + # Inject dataset variables if self._results_dir: - # Ensure paths are correctly formatted for the kernel environment (use forward slashes) - results_dir_path = os.path.abspath(self._results_dir).replace('\\', '/') - initialization_code.append(f"results_dir = r'{results_dir_path}'") - # Kernel CWD should handle the existence, but we ensure the 'results' subdirectory exists - initialization_code.append("os.makedirs(results_dir, exist_ok=True)") + self._repl_instance.locals['results_dir'] = self._results_dir + # Ensure it exists + os.makedirs(self._results_dir, exist_ok=True) - # Inject dataset variables by auto-loading from files - for key, value in self._datasets.items(): - if key.startswith('dataset_') and isinstance(value, str) and os.path.isdir(value): - # 1. Create the path variable (e.g., dataset_1_path) - path_var_name = f"{key}_path" - abs_path = os.path.abspath(value).replace('\\', '/') - initialization_code.append(f"{path_var_name} = r'{abs_path}'") - - # 2. Attempt to find and auto-load data.csv - csv_path = os.path.join(value, 'data.csv') - if os.path.exists(csv_path): - initialization_code.append(f"# Auto-loading {key} from data.csv") - initialization_code.append(f"try:") - initialization_code.append(f" {key} = pd.read_csv(os.path.join({path_var_name}, 'data.csv'))") - initialization_code.append(f" print(f'Successfully auto-loaded data.csv into `{key}` variable.')") - initialization_code.append(f"except Exception as e:") - initialization_code.append(f" print(f'Could not auto-load {key}: {{e}}')") - - elif key == "uuid_main_dir" and isinstance(value, str): - # Add the main UUID directory - abs_path = os.path.abspath(value).replace('\\', '/') - initialization_code.append(f"uuid_main_dir = r'{abs_path}'") - - # Execute all initialization code in one go - full_init_code = "\n".join(initialization_code) - result = repl.run(full_init_code) - logging.info(f"Kernel session initialization result: {result.strip()}") - repl.is_initialized = True + if self._datasets.get("uuid_main_dir"): + self._repl_instance.locals['uuid_main_dir'] = self._datasets["uuid_main_dir"] + + # Auto-load data.csv if available + climate_data_dir = self._datasets.get("climate_data_dir") + if climate_data_dir: + self._repl_instance.locals['climate_data_dir'] = climate_data_dir + data_csv = os.path.join(climate_data_dir, 'data.csv') + if os.path.exists(data_csv): + try: + pd = self._repl_instance.locals['pd'] + self._repl_instance.locals['df'] = pd.read_csv(data_csv) + logging.info(f"Auto-loaded dataframe 'df' from {data_csv}") + except Exception as e: + logging.error(f"Failed to auto-load data.csv: {e}") def _run(self, query: str, **kwargs) -> Any: """ - Execute code using the persistent Jupyter Kernel tied to the session. + Execute code using the persistent REPL. """ - # Resolve session_id for Streamlit or CLI execution. - session_id = "" - if hasattr(self, "_session_key") and self._session_key: - session_id = self._session_key - else: - try: - session_id = st.session_state.thread_id - except Exception: - session_id = os.environ.get("CLIMSIGHT_THREAD_ID", "") - - if not session_id: - logging.error("CRITICAL ERROR: Session ID (thread_id) is missing.") - return { - "result": "ERROR: Session ID is missing. Cannot continue execution.", - "output": "ERROR: Session ID is missing. Cannot continue execution.", - "plot_images": [] - } - - # Get sandbox path from datasets (this is the intended working directory) - sandbox_path = self._datasets.get("uuid_main_dir") + logging.info(f"Executing code via PersistentPythonREPL:\n{query}") - try: - # Retrieve or create the kernel executor - repl = REPLManager.get_repl(session_id, sandbox_path=sandbox_path) - except RuntimeError as e: - # Handle kernel initialization failure - return { - "result": f"ERROR: Failed to initialize the execution environment (Jupyter Kernel). Details: {e}", - "output": f"ERROR: Failed to initialize the execution environment (Jupyter Kernel). Details: {e}", - "plot_images": [] - } - - # "Warm up" the REPL on first call in the session - if not repl.is_initialized: - self._initialize_session(repl) - - logging.info(f"Executing code in persistent Kernel for session {session_id}:\n{query}") - - # Execute user code - output = repl.run(query) - - # Logic for detecting generated plots - generated_plots = [] - image_extensions = ('.png', '.jpg', '.jpeg', '.svg', '.pdf') - - # --- Plot Detection Logic --- - # The kernel executes with the sandbox_path as its CWD. - # We rely on the agent saving plots to relative paths (e.g., 'results/plot.png') + # Execute + output = self._repl_instance.execute(query) - # 1. Check for plot files mentioned in the output (stdout/stderr) - # Use regex to find potential paths (e.g., "Plot saved to: results/my_plot.png") - matches = re.findall(r'([a-zA-Z0-9_\-/\\.]+\.(png|jpg|jpeg|svg|pdf))', output, re.IGNORECASE) - for match in matches: - potential_path = match[0].strip() - - # Resolve the path relative to the sandbox directory (where the kernel is running) - if sandbox_path: - # Normalize slashes - potential_path = potential_path.replace('\\', '/') - full_potential_path = os.path.join(sandbox_path, potential_path) - else: - # If no sandbox path (e.g. CLI fallback), check relative to current working dir - full_potential_path = os.path.abspath(potential_path) - - if os.path.exists(full_potential_path): - if full_potential_path not in generated_plots: - generated_plots.append(full_potential_path) - logging.info(f"Found plot path in output: {full_potential_path}") - - # 2. Check results directory for recently created files (fallback mechanism) + # Detect plots in results_dir + generated_plots = [] if self._results_dir and os.path.exists(self._results_dir): - for file in os.listdir(self._results_dir): - if file.lower().endswith(image_extensions): - full_path = os.path.join(self._results_dir, file) - if full_path not in generated_plots and os.path.exists(full_path): - # Check if file was created recently (within the last 10 seconds) - # This prevents picking up old plots from the same session - if time.time() - os.path.getmtime(full_path) < 10: - generated_plots.append(full_path) - logging.info(f"Found recent plot in results directory: {full_path}") - - if generated_plots: - # Update UI flag if using Streamlit (safe check) - if 'streamlit' in sys.modules and hasattr(st, 'session_state'): - st.session_state.new_plot_generated = True - - log_history_event( - st.session_state, "plot_generated", - {"plot_paths": generated_plots, "agent": "PythonREPL", "content": query} - ) + # Logic to find recently created images + # This simplistic check just looks for files mentioned in the output or assumes recent creation + matches = re.findall(r'([a-zA-Z0-9_\-/\\.]+\.(png|jpg|jpeg|svg|pdf))', output, re.IGNORECASE) + for match in matches: + path = match[0] + # Handle relative paths if they start with results_dir name + if os.path.basename(self._results_dir) in path: + full_path = os.path.join(os.path.dirname(self._results_dir), path) + if os.path.exists(full_path): + generated_plots.append(full_path) + # Check direct existence + elif os.path.exists(path): + generated_plots.append(path) return { "result": output, - "output": output, # For backward compatibility - "plot_images": generated_plots + "output": output, + "plot_images": list(set(generated_plots)) # Deduplicate } + +# Factory function for direct tool creation (used by Smart Agent) +def create_python_repl_tool() -> StructuredTool: + """ + Create a NEW Python REPL tool instance. + CRITICAL CHANGE: This creates a NEW PersistentPythonREPL() every time it is called. + """ + from pydantic import BaseModel, Field + + # Instantiate a fresh REPL for this specific tool instance + repl_instance = PersistentPythonREPL() + + class PythonREPLInput(BaseModel): + code: str = Field( + description="Python code to execute. Pre-loaded: pandas(pd), numpy(np), matplotlib.pyplot(plt), xarray(xr)." + ) + + # Bind the execute method of THIS specific instance + tool = StructuredTool.from_function( + func=repl_instance.execute, + name="python_repl", + description=( + "Execute Python code for data analysis and visualization. " + "State persists between calls within this session. " + "The last line expression is automatically printed. " + "Plots must be saved to 'results_dir' using plt.savefig()." + ), + args_schema=PythonREPLInput + ) + return tool \ No newline at end of file From 1799ea3683b05ddd36dd6466fd3eb0e95a8ee709 Mon Sep 17 00:00:00 2001 From: dmpantiu Date: Tue, 27 Jan 2026 14:07:08 +0100 Subject: [PATCH 07/16] feat: migrate Python REPL to persistent Jupyter kernels --- src/climsight/data_analysis_agent.py | 6 +- src/climsight/tools/python_repl.py | 486 +++++++++++++++++---------- 2 files changed, 322 insertions(+), 170 deletions(-) diff --git a/src/climsight/data_analysis_agent.py b/src/climsight/data_analysis_agent.py index 290dd2c..35ea262 100644 --- a/src/climsight/data_analysis_agent.py +++ b/src/climsight/data_analysis_agent.py @@ -512,7 +512,11 @@ def data_analysis_agent( # 6. Image viewer - ONLY when Python REPL is enabled if has_python_repl: - image_viewer_tool = create_image_viewer_tool() + # Extract model name from config or default to gpt-4o + vision_model = config.get("llm_combine", {}).get("model_name", "gpt-4o") + + # FIX: Pass the required api_key and model_name + image_viewer_tool = create_image_viewer_tool(api_key, vision_model) tools.append(image_viewer_tool) # 7. Image reflection and wise_agent - ONLY when Python REPL is enabled diff --git a/src/climsight/tools/python_repl.py b/src/climsight/tools/python_repl.py index c41b4e7..cef8cea 100644 --- a/src/climsight/tools/python_repl.py +++ b/src/climsight/tools/python_repl.py @@ -1,222 +1,370 @@ # src/climsight/tools/python_repl.py """ -Simple Python REPL Tool for Climsight. -Thread-safe implementation that maintains state within a specific agent execution context. -Includes Matplotlib backend fixes and Markdown sanitization. +Python REPL Tool for Climsight using Jupyter Kernel. +Executes code in a persistent, isolated Jupyter Kernel process. """ +import os import sys import logging -import ast import re -import traceback -import os -import matplotlib -# Force non-interactive backend before importing pyplot -matplotlib.use('Agg') -import matplotlib.pyplot as plt -from io import StringIO -from typing import Dict, Any, Optional +import threading +import queue +import time +import atexit +import textwrap +from typing import Any, Dict, Optional, List, Set from pydantic import BaseModel, Field, PrivateAttr -# Try importing StructuredTool from different locations based on version +# Import LangChain components try: from langchain.tools import StructuredTool except ImportError: from langchain_core.tools import StructuredTool from langchain_experimental.tools import PythonREPLTool +# Import Jupyter Client +try: + from jupyter_client import KernelManager + from jupyter_client.client import KernelClient +except ImportError: + raise ImportError("Missing dependencies for persistent REPL. Install jupyter_client and ipykernel.") + +try: + import streamlit as st +except ImportError: + st = None + logger = logging.getLogger(__name__) -class PersistentPythonREPL: +# --- Helper: Unified Session ID --- + +def get_global_session_id() -> str: + """ + Returns a unified session ID prioritizing: + 1. Streamlit thread_id (most specific) + 2. Streamlit session_uuid (fallback) + 3. Environment variable (CLI/Docker) + 4. Default fallback + """ + if st is not None and hasattr(st, "session_state"): + if getattr(st.session_state, "thread_id", None): + return st.session_state.thread_id + if getattr(st.session_state, "session_uuid", None): + return st.session_state.session_uuid + + return os.environ.get("CLIMSIGHT_THREAD_ID", "default_cli_session") + +# --- Jupyter Kernel Executor (ISOLATED PROCESS) --- + +class JupyterKernelExecutor: """ - A persistent Python REPL that maintains state between executions within a single session. - Replaces the Jupyter Kernel approach with a lighter in-process exec approach. + Manages a persistent Jupyter kernel for code execution. + Running in a separate process ensures isolation. """ + def __init__(self, working_dir=None): + self._working_dir = working_dir + # Use default python3 kernel + self.km = KernelManager(kernel_name="python3") + self.kc: Optional[KernelClient] = None + self.is_initialized = False + self._start_kernel() + + def _start_kernel(self): + cwd = self._working_dir if (self._working_dir and os.path.exists(self._working_dir)) else os.getcwd() + logging.info(f"Starting Jupyter kernel in {cwd}...") + + try: + # Robust check for kernel spec + if not self.km.kernel_spec: + raise RuntimeError("No kernel spec found") + except Exception as e: + logging.warning(f"Kernel spec warning: {e}. Forcing sys.executable for ipykernel.") + # Fallback: explicitly call the current python executable to avoid path issues + self.km.kernel_cmd = [sys.executable, "-m", "ipykernel_launcher", "-f", "{connection_file}"] + + try: + self.km.start_kernel(cwd=cwd) + self.kc = self.km.client() + self.kc.start_channels() + self.kc.wait_for_ready(timeout=60) + logging.info(f"Jupyter kernel started successfully.") + except Exception as e: + logging.error(f"Kernel failed to start: {e}") + self.close() + raise + + def restart_kernel(self): + """Hard restart in case of stuck process or timeout.""" + logging.warning("Restarting Jupyter Kernel...") + self.close() + self._start_kernel() + self.is_initialized = False + + def _drain_channels(self, timeout=1.0): + """Drain messages from channels to prevent cross-contamination after interrupt/timeout.""" + if not self.kc: return + start = time.time() + while time.time() - start < timeout: + try: + self.kc.get_iopub_msg(timeout=0.1) + except queue.Empty: + break + + def _execute_code(self, code: str, timeout: float = 300.0) -> Dict[str, Any]: + if not self.kc: + return {"status": "error", "error": "Kernel client not available."} + + # Flush previous messages + self._drain_channels(timeout=0.1) + + msg_id = self.kc.execute(code) + result = {"status": "success", "stdout": "", "stderr": "", "display_data": []} + start_time = time.time() - def __init__(self): - """Initialize with an empty locals dictionary and pre-imported modules.""" - # Persistent locals dictionary - this is THE KEY FEATURE - self.locals = {} - - # Pre-import common modules into the persistent namespace - self.locals.update({ - 'pd': __import__('pandas'), - 'np': __import__('numpy'), - 'plt': __import__('matplotlib.pyplot', fromlist=['pyplot']), - 'xr': __import__('xarray'), - 'os': __import__('os'), - 'Path': __import__('pathlib').Path, - 'json': __import__('json'), - }) - # Add matplotlib.pyplot explicitly as plt (redundant but safe) - self.locals['plt'] = plt - - def execute(self, code: str) -> str: + while True: + # 1. Check Timeout + if time.time() - start_time > timeout: + logging.warning(f"Code execution timed out ({timeout}s). Interrupting kernel.") + self.km.interrupt_kernel() + # Drain leftover messages from the interrupted execution + self._drain_channels(timeout=2.0) + result["status"] = "error" + result["error"] = f"Timeout after {timeout}s. Execution interrupted." + break + + # 2. Check Kernel Vitality + if not self.km.is_alive(): + result["status"] = "error" + result["error"] = "Kernel died unexpectedly." + self.restart_kernel() + break + + # 3. Get Message + try: + msg = self.kc.get_iopub_msg(timeout=0.1) + except queue.Empty: + continue + + # 4. Filter by Parent Message ID + if msg['parent_header'].get('msg_id') != msg_id: + continue + + msg_type = msg['msg_type'] + content = msg['content'] + + if msg_type == 'status' and content['execution_state'] == 'idle': + break + elif msg_type == 'stream': + if content['name'] == 'stdout': result["stdout"] += content['text'] + elif content['name'] == 'stderr': result["stderr"] += content['text'] + elif msg_type in ('display_data', 'execute_result'): + result["display_data"].append(content['data']) + if 'text/plain' in content['data']: + result["stdout"] += content['data']['text/plain'] + "\n" + elif msg_type == 'error': + result["status"] = "error" + # Remove ANSI colors for readable error logs + error_trace = "\n".join(content['traceback']) + clean_error = re.sub(r'\x1b\[[0-9;]*m', '', error_trace) + result["error"] = f"{content['ename']}: {content['evalue']}\n{clean_error}" + break + + return result + + def run(self, code: str) -> Dict[str, Any]: """ - Execute Python code and capture stdout/stderr. - Automatically prints the result of the last expression. + Executes code and returns full result dict. + Sanitizes markdown fences but PRESERVES indentation. """ - # 1. Sanitize input (remove Markdown code blocks) - code = re.sub(r"^(\s|`)*(?i:python)?\s*", "", code) - code = re.sub(r"(\s|`)*$", "", code) - - # 2. Capture stdout/stderr - old_stdout = sys.stdout - old_stderr = sys.stderr - stdout_capture = StringIO() - stderr_capture = StringIO() - sys.stdout = stdout_capture - sys.stderr = stderr_capture + # 1. Dedent to fix common indentation errors (e.g. code inside list items) + code = textwrap.dedent(code) + + # 2. Remove markdown fences + code = code.strip() + if code.startswith("```"): + # Remove leading ```python or ``` (case insensitive) + code = re.sub(r"^```(?:python)?\s*\n", "", code, flags=re.IGNORECASE) + # Remove trailing ``` + code = re.sub(r"\n\s*```\s*$", "", code) + + return self._execute_code(code) - try: - # 3. Parse the code - tree = ast.parse(code) - - if not tree.body: - return "No code to execute." - - # 4. Execution Logic: Handle last expression printing - last_node = tree.body[-1] - if isinstance(last_node, ast.Expr): - # Execute everything before the last expression - if len(tree.body) > 1: - module = ast.Module(body=tree.body[:-1], type_ignores=[]) - exec(compile(module, "", "exec"), self.locals, self.locals) - - # Evaluate the last expression and print result - expr = ast.Expression(body=last_node.value) - result = eval(compile(expr, "", "eval"), self.locals, self.locals) - if result is not None: - print(repr(result)) - else: - # Execute the whole block if it doesn't end in an expression - exec(compile(tree, "", "exec"), self.locals, self.locals) + def close(self): + if self.kc: + try: self.kc.stop_channels() + except: pass + if self.km and self.km.is_alive(): + try: self.km.shutdown_kernel(now=True) + except: pass + +# --- Session Manager --- - output = stdout_capture.getvalue() - errors = stderr_capture.getvalue() +class REPLManager: + """Singleton to manage Kernels per session (thread_id).""" + _instances: dict[str, JupyterKernelExecutor] = {} + _lock = threading.Lock() - # Combine output - final_output = output - if errors: - final_output += f"\n[Stderr]: {errors}" + @classmethod + def get_repl(cls, session_id: str, sandbox_path: str = None) -> JupyterKernelExecutor: + with cls._lock: + if session_id not in cls._instances: + logging.info(f"Creating new Kernel for session: {session_id}") + try: + cls._instances[session_id] = JupyterKernelExecutor(working_dir=sandbox_path) + except Exception as e: + logging.error(f"Failed to create kernel: {e}") + raise + + # Restart dead kernels if necessary (Zombie check) + if not cls._instances[session_id].km.is_alive(): + logging.warning(f"Kernel for {session_id} died. Restarting.") + cls._instances[session_id].close() + cls._instances[session_id] = JupyterKernelExecutor(working_dir=sandbox_path) - if not final_output.strip(): - return "Code executed successfully (no output)." + return cls._instances[session_id] - return final_output.strip() + @classmethod + def cleanup_all(cls): + for sid in list(cls._instances.keys()): + try: + cls._instances[sid].close() + except Exception: + pass + del cls._instances[sid] - except Exception: - # Capture the full traceback for debugging - return traceback.format_exc() +atexit.register(REPLManager.cleanup_all) - finally: - # Restore stdout/stderr - sys.stdout = old_stdout - sys.stderr = old_stderr +# --- LangChain Tool Wrappers --- class CustomPythonREPLTool(PythonREPLTool): - """ - A wrapper class to maintain compatibility with Agent logic that expects - CustomPythonREPLTool. It uses PersistentPythonREPL internally. - """ + """Used by data_analysis_agent. Imports dataset paths automatically.""" _datasets: dict = PrivateAttr() _results_dir: Optional[str] = PrivateAttr() _session_key: Optional[str] = PrivateAttr() - _repl_instance: PersistentPythonREPL = PrivateAttr() def __init__(self, datasets, results_dir=None, session_key=None, **kwargs): super().__init__(**kwargs) self._datasets = datasets self._results_dir = results_dir self._session_key = session_key - self._repl_instance = PersistentPythonREPL() - self._initialize_session() - - def _initialize_session(self) -> None: - """Pre-load variables into the REPL locals.""" - logging.info("Initializing PersistentPythonREPL session...") + + def _run(self, query: str, **kwargs) -> Any: + # 1. Get Unified Session ID + sid = self._session_key or get_global_session_id() - # Inject dataset variables - if self._results_dir: - self._repl_instance.locals['results_dir'] = self._results_dir - # Ensure it exists - os.makedirs(self._results_dir, exist_ok=True) - - if self._datasets.get("uuid_main_dir"): - self._repl_instance.locals['uuid_main_dir'] = self._datasets["uuid_main_dir"] + # 2. Get/Create Kernel + sandbox_path = self._datasets.get("uuid_main_dir") + try: + repl = REPLManager.get_repl(sid, sandbox_path=sandbox_path) + except Exception as e: + return f"System Error: Could not start Python environment. {str(e)}" - # Auto-load data.csv if available - climate_data_dir = self._datasets.get("climate_data_dir") - if climate_data_dir: - self._repl_instance.locals['climate_data_dir'] = climate_data_dir - data_csv = os.path.join(climate_data_dir, 'data.csv') - if os.path.exists(data_csv): - try: - pd = self._repl_instance.locals['pd'] - self._repl_instance.locals['df'] = pd.read_csv(data_csv) - logging.info(f"Auto-loaded dataframe 'df' from {data_csv}") - except Exception as e: - logging.error(f"Failed to auto-load data.csv: {e}") + # 3. Auto-init if fresh + if not repl.is_initialized: + # Initialize matplotlib backend inside the kernel! + init_code = [ + "import pandas as pd", + "import numpy as np", + "import matplotlib", + "matplotlib.use('Agg')", # CRITICAL: Non-interactive backend + "import matplotlib.pyplot as plt", + "import xarray as xr", + "import os", + "import json" + ] + + # Safe path injection using repr() (handles Windows backslashes/quotes) + if self._results_dir: + init_code.append(f"results_dir = {repr(self._results_dir)}") + init_code.append("os.makedirs(results_dir, exist_ok=True)") + + if "climate_data_dir" in self._datasets: + cd_dir = self._datasets["climate_data_dir"] + init_code.append(f"climate_data_dir = {repr(cd_dir)}") + # Auto-load main data if exists + init_code.append(f"try:\n if os.path.exists(f'{{climate_data_dir}}/data.csv'):\n df = pd.read_csv(f'{{climate_data_dir}}/data.csv')\n print('Loaded df from data.csv')\nexcept Exception as e: print(f'Auto-load failed: {{e}}')") - def _run(self, query: str, **kwargs) -> Any: - """ - Execute code using the persistent REPL. - """ - logging.info(f"Executing code via PersistentPythonREPL:\n{query}") + if "era5_data_dir" in self._datasets: + init_code.append(f"era5_data_dir = {repr(self._datasets['era5_data_dir'])}") + + # Execute and check status + init_res = repl.run("\n".join(init_code)) + if init_res["status"] == "success": + repl.is_initialized = True + else: + return f"Kernel Initialization Error: {init_res.get('error')}" + + # 4. Detect Plots (Snapshot Method - Robust) + existing_files: Set[str] = set() + if self._results_dir and os.path.exists(self._results_dir): + try: + existing_files = set(os.listdir(self._results_dir)) + except Exception: + pass + + # 5. Execute User Code + res = repl.run(query) - # Execute - output = self._repl_instance.execute(query) + output_str = res["stdout"] + if res["stderr"]: + output_str += f"\n[STDERR]\n{res['stderr']}" - # Detect plots in results_dir - generated_plots = [] - if self._results_dir and os.path.exists(self._results_dir): - # Logic to find recently created images - # This simplistic check just looks for files mentioned in the output or assumes recent creation - matches = re.findall(r'([a-zA-Z0-9_\-/\\.]+\.(png|jpg|jpeg|svg|pdf))', output, re.IGNORECASE) - for match in matches: - path = match[0] - # Handle relative paths if they start with results_dir name - if os.path.basename(self._results_dir) in path: - full_path = os.path.join(os.path.dirname(self._results_dir), path) - if os.path.exists(full_path): - generated_plots.append(full_path) - # Check direct existence - elif os.path.exists(path): - generated_plots.append(path) + if res["status"] == "error": + return f"Error:\n{res.get('error')}" + # 6. Identify New Plots (Diff) + plots = [] + if self._results_dir and os.path.exists(self._results_dir): + try: + current_files = set(os.listdir(self._results_dir)) + new_files = current_files - existing_files + + for f in new_files: + if f.lower().endswith(('.png', '.jpg', '.jpeg', '.pdf', '.svg')): + full_path = os.path.join(self._results_dir, f) + plots.append(full_path) + except Exception: + pass + return { - "result": output, - "output": output, - "plot_images": list(set(generated_plots)) # Deduplicate + "result": output_str.strip() or "Code executed successfully (no output).", + "output": output_str.strip(), + "plot_images": plots } -# Factory function for direct tool creation (used by Smart Agent) def create_python_repl_tool() -> StructuredTool: - """ - Create a NEW Python REPL tool instance. - CRITICAL CHANGE: This creates a NEW PersistentPythonREPL() every time it is called. - """ - from pydantic import BaseModel, Field - - # Instantiate a fresh REPL for this specific tool instance - repl_instance = PersistentPythonREPL() + """Factory for simple use cases (Smart Agent).""" + class Input(BaseModel): + code: str = Field(description="Python code to execute.") + + def run_repl(code: str): + sid = get_global_session_id() + + # Determine sandbox path + sandbox = None + if sid != "default_cli_session": + sandbox = os.path.join("tmp", "sandbox", sid) + if not os.path.exists(sandbox): + os.makedirs(sandbox, exist_ok=True) + + try: + repl = REPLManager.get_repl(sid, sandbox_path=sandbox) + # Basic init + if not repl.is_initialized: + repl.run("import matplotlib; matplotlib.use('Agg'); import matplotlib.pyplot as plt; import pandas as pd; import numpy as np; import os") + repl.is_initialized = True - class PythonREPLInput(BaseModel): - code: str = Field( - description="Python code to execute. Pre-loaded: pandas(pd), numpy(np), matplotlib.pyplot(plt), xarray(xr)." - ) + res = repl.run(code) + if res["status"] == "success": + return res["stdout"] or "Executed." + return f"Error: {res.get('error')}" + except Exception as e: + return f"Kernel Error: {e}" - # Bind the execute method of THIS specific instance - tool = StructuredTool.from_function( - func=repl_instance.execute, + return StructuredTool.from_function( + func=run_repl, name="python_repl", - description=( - "Execute Python code for data analysis and visualization. " - "State persists between calls within this session. " - "The last line expression is automatically printed. " - "Plots must be saved to 'results_dir' using plt.savefig()." - ), - args_schema=PythonREPLInput - ) - return tool \ No newline at end of file + description="Execute Python code in a persistent Jupyter Kernel. State is preserved. Use this for calculations and plotting.", + args_schema=Input + ) \ No newline at end of file From cd6cca7bad8aa7263d04285798e52b703ef47919 Mon Sep 17 00:00:00 2001 From: dmpantiu Date: Tue, 27 Jan 2026 14:26:37 +0100 Subject: [PATCH 08/16] bug fix --- src/climsight/data_analysis_agent.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/climsight/data_analysis_agent.py b/src/climsight/data_analysis_agent.py index 35ea262..7ba8309 100644 --- a/src/climsight/data_analysis_agent.py +++ b/src/climsight/data_analysis_agent.py @@ -163,7 +163,7 @@ def _create_tool_prompt(datasets_text: str, config: dict, lat: float = None, lon How to load result in Python_REPL: ```python import xarray as xr - ds = xr.open_dataset(path_from_tool_response, engine='zarr', chunks={{}}) + ds = xr.open_dataset(path_from_tool_response, engine='zarr', chunks={{{{}}}}) # If lat/lon are scalar, access directly: data = ds['t2'].to_series() ``` From beb2b9b055dab0b01fe322cb5732c43d46491da5 Mon Sep 17 00:00:00 2001 From: dmpantiu Date: Wed, 28 Jan 2026 15:44:54 +0100 Subject: [PATCH 09/16] fix: simplify sandbox paths to relative and update LLM token parameters --- src/climsight/data_analysis_agent.py | 50 ++++++++++++++-------- src/climsight/smart_agent.py | 19 +------- src/climsight/tools/era5_retrieval_tool.py | 7 ++- src/climsight/tools/image_viewer.py | 2 +- src/climsight/tools/python_repl.py | 12 +++--- src/climsight/tools/reflection_tools.py | 2 +- 6 files changed, 49 insertions(+), 43 deletions(-) diff --git a/src/climsight/data_analysis_agent.py b/src/climsight/data_analysis_agent.py index 7ba8309..741d226 100644 --- a/src/climsight/data_analysis_agent.py +++ b/src/climsight/data_analysis_agent.py @@ -56,27 +56,32 @@ def _build_climate_data_summary(df_list: List[Dict[str, Any]]) -> str: def _build_datasets_text(state) -> str: - """Build simple dataset paths text for prompt injection.""" + """Build simple dataset paths text for prompt injection. + + IMPORTANT: The Python REPL kernel CWD is already set to the sandbox root, + so we tell the agent to use RELATIVE paths (not full tmp/sandbox/... paths). + """ lines = [ - "## Sandbox Paths (use these exact paths in your code)", - f"- Main directory: {state.uuid_main_dir}", - f"- Results directory (save plots here): {state.results_dir}", - f"- Climate data: {state.climate_data_dir}", + "## Sandbox Paths (Python REPL is ALREADY inside the sandbox directory)", + "**CRITICAL: Use RELATIVE paths in your Python code, NOT full paths starting with 'tmp/sandbox/...'**", + f"- Current Working Directory: '.' (which is {state.uuid_main_dir})", + f"- Results directory: 'results' (save all plots here)", + f"- Climate data: 'climate_data'", ] if state.era5_data_dir: - lines.append(f"- ERA5 data: {state.era5_data_dir}") + lines.append(f"- ERA5 data: 'era5_data'") # List available climate data files if state.climate_data_dir and os.path.exists(state.climate_data_dir): try: files = os.listdir(state.climate_data_dir) if files: - lines.append(f"\n## Climate Data Files Available") - lines.append(f"Files in climate_data/: {', '.join(files)}") + lines.append(f"\n## Climate Data Files Available (in 'climate_data/' folder)") + lines.append(f"Files: {', '.join(files)}") # Highlight the main data.csv file if "data.csv" in files: - lines.append("Note: 'data.csv' contains the main climatology dataset") + lines.append("Note: Load with `pd.read_csv('climate_data/data.csv')`") except Exception as e: logger.warning(f"Could not list climate data files: {e}") @@ -146,7 +151,6 @@ def _create_tool_prompt(datasets_text: str, config: dict, lat: float = None, lon # TOOL #3: ERA5 time series download (optional, for detailed analysis) if has_era5_download: - work_dir_str = sandbox_paths.get("uuid_main_dir", "") if "sandbox_paths" in dir() else "" prompt += f""" {tool_num}. **retrieve_era5_data** - Retrieve ERA5 Surface climate data from Earthmover (Arraylake) @@ -154,17 +158,22 @@ def _create_tool_prompt(datasets_text: str, config: dict, lat: float = None, lon **DATA SOURCE:** Earthmover (Arraylake), hardcoded to "temporal" mode. **VARIABLE CODES:** Use short codes: 't2' (Temp), 'u10'/'v10' (Wind), 'mslp' (Pressure), 'tp' (Precip). - **PATH RULE:** You MUST always pass `work_dir` to this tool. + **WORK_DIR:** Pass work_dir='.' to save in current sandbox. - **OUTPUT USAGE:** - The tool returns a path to a Zarr store. - **CRITICAL:** If you requested a specific point, the data is likely already reduced to that point (nearest neighbor). + **OUTPUT & LOADING:** + The tool returns an absolute path to a Zarr store (saved in 'era5_data/' folder). + **CRITICAL:** In Python_REPL, use RELATIVE paths to load the data since CWD is already the sandbox: - How to load result in Python_REPL: ```python import xarray as xr - ds = xr.open_dataset(path_from_tool_response, engine='zarr', chunks={{{{}}}}) - # If lat/lon are scalar, access directly: + import glob + + # List available ERA5 Zarr files (RELATIVE path) + era5_files = glob.glob('era5_data/*.zarr') + print(era5_files) + + # Load using RELATIVE path (NOT the absolute path from tool response) + ds = xr.open_dataset('era5_data/era5_t2_temporal_....zarr', engine='zarr', chunks={{{{}}}}) data = ds['t2'].to_series() ``` @@ -176,9 +185,14 @@ def _create_tool_prompt(datasets_text: str, config: dict, lat: float = None, lon prompt += f""" {tool_num}. **Python_REPL** - Execute Python code for data analysis and visualizations - Pre-loaded: pandas (pd), numpy (np), matplotlib.pyplot (plt), xarray (xr) - - Working directory is the sandbox root + - Working directory is ALREADY the sandbox root - ALWAYS save plots to results/ directory + **CRITICAL PATH RULE:** + ❌ WRONG: `base='tmp/sandbox/uuid...'` then `f'{{{{base}}}}/era5_data/...'` + ✓ CORRECT: Use relative paths directly: `'era5_data/...'`, `'climate_data/...'`, `'results/...'` + The kernel CWD is already inside the sandbox, so DO NOT prepend 'tmp/sandbox/...'! + **Climate Model Data** (climate_data/data.csv): - Monthly climatology from climate models (MODEL projections, not observations) - Columns: Month, mean2t (temperature °C), tp (precipitation), wind_u, wind_v diff --git a/src/climsight/smart_agent.py b/src/climsight/smart_agent.py index 3203785..b0794a5 100644 --- a/src/climsight/smart_agent.py +++ b/src/climsight/smart_agent.py @@ -101,21 +101,6 @@ def smart_agent(state: AgentState, config, api_key, api_key_local, stream_handle ONLY use this tool if the user's question is clearly about agriculture or crops. Do NOT use it for general climate queries. - - "python_repl" allows you to execute Python code for data analysis, visualization, and calculations. - Use this tool when: - - Creating visualizations (plots, charts, graphs) of climate data - - **CRITICAL: Your working directory is available at `work_dir` = '{work_dir_str}'** - **When saving plots, ALWAYS store the full path in a variable for later use!** - - **CORRECT way to save and reference images:** - ```python - plot_path = f'{{{{work_dir}}}}/temperature_plot.png' - plt.savefig(plot_path) - print(plot_path) # PRINT IT TO CONFIRM - ``` - - **Your goal**: Gather comprehensive background information and compile it into a well-structured summary that will be used by subsequent agents for data analysis. @@ -529,7 +514,7 @@ def process_ecocrop_search(query: str) -> str: # Create python_repl tool - python_repl_tool = create_python_repl_tool() + #python_repl_tool = create_python_repl_tool() # Initialize the LLM if config['llm_smart']['model_type'] == "local": @@ -550,7 +535,7 @@ def process_ecocrop_search(query: str) -> str: llm = get_aitta_chat_model(config['llm_smart']['model_name'], temperature = 0) # List of tools - tools = [rag_tool, ecocrop_tool, wikipedia_tool, python_repl_tool] + tools = [rag_tool, ecocrop_tool, wikipedia_tool] prompt += """\nadditional information:\n question is related to this location: {location_str} \n diff --git a/src/climsight/tools/era5_retrieval_tool.py b/src/climsight/tools/era5_retrieval_tool.py index e395ba1..d934441 100644 --- a/src/climsight/tools/era5_retrieval_tool.py +++ b/src/climsight/tools/era5_retrieval_tool.py @@ -150,7 +150,12 @@ def retrieve_era5_data( main_dir = os.path.join("tmp", "sandbox", "era5_earthmover") os.makedirs(main_dir, exist_ok=True) - era5_dir = os.path.join(main_dir, "era5_data") + + # FIX: Prevent double nesting if work_dir already ends with 'era5_data' + if os.path.basename(main_dir.rstrip(os.sep)) == "era5_data": + era5_dir = main_dir + else: + era5_dir = os.path.join(main_dir, "era5_data") os.makedirs(era5_dir, exist_ok=True) # Check Cache diff --git a/src/climsight/tools/image_viewer.py b/src/climsight/tools/image_viewer.py index df489ee..b501c89 100644 --- a/src/climsight/tools/image_viewer.py +++ b/src/climsight/tools/image_viewer.py @@ -82,7 +82,7 @@ def view_and_analyze_image(image_path: str, openai_api_key: str, model_name: str ] } ], - max_tokens=5000 + max_completion_tokens=5000 ) return response.choices[0].message.content diff --git a/src/climsight/tools/python_repl.py b/src/climsight/tools/python_repl.py index cef8cea..a5f9ba0 100644 --- a/src/climsight/tools/python_repl.py +++ b/src/climsight/tools/python_repl.py @@ -274,19 +274,21 @@ def _run(self, query: str, **kwargs) -> Any: "import json" ] - # Safe path injection using repr() (handles Windows backslashes/quotes) + # Safe path injection using repr() if self._results_dir: - init_code.append(f"results_dir = {repr(self._results_dir)}") + # FIX: Since the kernel CWD is already the sandbox root, + init_code.append(f"results_dir = 'results'") init_code.append("os.makedirs(results_dir, exist_ok=True)") if "climate_data_dir" in self._datasets: - cd_dir = self._datasets["climate_data_dir"] - init_code.append(f"climate_data_dir = {repr(cd_dir)}") + # FIX: Set simple relative path because Kernel CWD is already the sandbox root + init_code.append(f"climate_data_dir = 'climate_data'") # Auto-load main data if exists init_code.append(f"try:\n if os.path.exists(f'{{climate_data_dir}}/data.csv'):\n df = pd.read_csv(f'{{climate_data_dir}}/data.csv')\n print('Loaded df from data.csv')\nexcept Exception as e: print(f'Auto-load failed: {{e}}')") if "era5_data_dir" in self._datasets: - init_code.append(f"era5_data_dir = {repr(self._datasets['era5_data_dir'])}") + # FIX: Set simple relative path because Kernel CWD is already the sandbox root + init_code.append(f"era5_data_dir = 'era5_data'") # Execute and check status init_res = repl.run("\n".join(init_code)) diff --git a/src/climsight/tools/reflection_tools.py b/src/climsight/tools/reflection_tools.py index e6925de..728c498 100644 --- a/src/climsight/tools/reflection_tools.py +++ b/src/climsight/tools/reflection_tools.py @@ -117,7 +117,7 @@ def reflect_on_image(image_path: str) -> str: ] } ], - max_tokens=1000 + max_completion_tokens=1000 ) return response.choices[0].message.content From c19b4c14f38842e02a32b6de71d673887f36a345 Mon Sep 17 00:00:00 2001 From: Ivan Kuznetsov Date: Wed, 28 Jan 2026 17:12:40 +0100 Subject: [PATCH 10/16] update 260 grad --- src/climsight/climate_data_providers.py | 7 ++++++- src/climsight/climate_functions.py | 8 +++++++- src/climsight/geo_functions.py | 8 ++++---- 3 files changed, 17 insertions(+), 6 deletions(-) diff --git a/src/climsight/climate_data_providers.py b/src/climsight/climate_data_providers.py index 4fc9c9f..bbaab8e 100644 --- a/src/climsight/climate_data_providers.py +++ b/src/climsight/climate_data_providers.py @@ -530,8 +530,13 @@ def _extract_data_regular_grid( # Open NetCDF file ds = xr.open_dataset(file_path) + # Normalize longitude to 0-360 range if dataset uses that convention + query_lon = lon + if lon < 0 and ds.lon.min() >= 0: + query_lon = lon + 360 + # Use bilinear interpolation to get data at the exact point - ds_interpolated = ds.interp(lat=lat, lon=lon, method='linear') + ds_interpolated = ds.interp(lat=lat, lon=query_lon, method='linear') # Extract variables for all 12 months month_names = [calendar.month_name[i] for i in range(1, 13)] diff --git a/src/climsight/climate_functions.py b/src/climsight/climate_functions.py index bf8734d..fecedaa 100644 --- a/src/climsight/climate_functions.py +++ b/src/climsight/climate_functions.py @@ -63,10 +63,16 @@ def load_data(config): def select_data(dataset, variable, dimensions, lat, lon): """ Selects data for a given variable at specified latitude and longitude. + Handles longitude normalization for datasets using 0-360 range. """ + # Normalize longitude if dataset uses 0-360 range and input is negative + lon_dim = dimensions['longitude'] + if lon < 0 and dataset[lon_dim].min() >= 0: + lon = lon + 360 + return dataset[variable].sel(**{ dimensions['latitude']: lat, - dimensions['longitude']: lon}, + lon_dim: lon}, method="nearest") def verify_shape(hist_units, future_units, variable): diff --git a/src/climsight/geo_functions.py b/src/climsight/geo_functions.py index 5c05089..bc81280 100644 --- a/src/climsight/geo_functions.py +++ b/src/climsight/geo_functions.py @@ -46,7 +46,7 @@ def get_location(lat, lon): "User-Agent": "climsight", "accept-language": "en" } - response = requests.get(url, params=params, headers=headers, timeout=5) + response = requests.get(url, params=params, headers=headers, timeout=10) location = response.json() # Wait before making the next request (according to terms of use) @@ -347,7 +347,7 @@ def get_elevation_from_api(lat, lon): float: The elevation of the location in meters. """ url = f"https://api.opentopodata.org/v1/etopo1?locations={lat},{lon}" - response = requests.get(url, timeout=3) + response = requests.get(url, timeout=10) data = response.json() return data["results"][0]["elevation"] @@ -370,7 +370,7 @@ def fetch_land_use(lon, lat): area.a["landuse"]; out tags; """ - response = requests.get(overpass_url, params={"data": overpass_query}, timeout=3) + response = requests.get(overpass_url, params={"data": overpass_query}, timeout=10) data = response.json() return data @@ -388,7 +388,7 @@ def get_soil_from_api(lat, lon): """ try: url = f"https://rest.isric.org/soilgrids/v2.0/classification/query?lon={lon}&lat={lat}&number_classes=5" - response = requests.get(url, timeout=3) # Set timeout to 2 seconds + response = requests.get(url, timeout=10) # Set timeout to 2 seconds data = response.json() return data["wrb_class_name"] except Timeout: From f8e448b45fb70b625ce1172085376c5b5a1ab1c6 Mon Sep 17 00:00:00 2001 From: Ivan Kuznetsov Date: Sun, 1 Feb 2026 14:54:21 +0100 Subject: [PATCH 11/16] update prompt --- src/climsight/data_analysis_agent.py | 77 +++++++++++++++------- src/climsight/sandbox_utils.py | 3 +- src/climsight/tools/era5_retrieval_tool.py | 21 ++++-- 3 files changed, 68 insertions(+), 33 deletions(-) diff --git a/src/climsight/data_analysis_agent.py b/src/climsight/data_analysis_agent.py index 741d226..998186f 100644 --- a/src/climsight/data_analysis_agent.py +++ b/src/climsight/data_analysis_agent.py @@ -187,16 +187,18 @@ def _create_tool_prompt(datasets_text: str, config: dict, lat: float = None, lon - Pre-loaded: pandas (pd), numpy (np), matplotlib.pyplot (plt), xarray (xr) - Working directory is ALREADY the sandbox root - ALWAYS save plots to results/ directory - **CRITICAL PATH RULE:** ❌ WRONG: `base='tmp/sandbox/uuid...'` then `f'{{{{base}}}}/era5_data/...'` ✓ CORRECT: Use relative paths directly: `'era5_data/...'`, `'climate_data/...'`, `'results/...'` The kernel CWD is already inside the sandbox, so DO NOT prepend 'tmp/sandbox/...'! - **Climate Model Data** (climate_data/data.csv): - - Monthly climatology from climate models (MODEL projections, not observations) - - Columns: Month, mean2t (temperature °C), tp (precipitation), wind_u, wind_v - - Contains historical period AND future projections + **Climate Model Data** (climate_data/ directory): + - `climate_data_manifest.json` - READ THIS FIRST to see all available simulations + - `simulation_1.csv`, `simulation_2.csv`, ... - Data for different time periods + - `simulation_N_meta.json` - Metadata (years, description) for each simulation + - `data.csv` - Main/baseline simulation only (for quick access) + - Columns: Month, mean2t (temperature °C), tp (precipitation mm/month), wind_u, wind_v, wind_speed, wind_direction + - Use list_plotting_data_files tool to discover all available files **ERA5 Climatology** (era5_climatology.json): - After calling get_era5_climatology, results are saved here @@ -209,30 +211,35 @@ def _create_tool_prompt(datasets_text: str, config: dict, lat: float = None, lon import json import matplotlib.pyplot as plt - # Load ERA5 observations (ground truth) + # 1. Load manifest to see all available simulations + manifest = json.load(open('climate_data/climate_data_manifest.json')) + print(f"Data source: {{{{manifest['source']}}}}") + for entry in manifest['entries']: + print(f" {{{{entry['csv']}}}} : {{{{entry['years_of_averaging']}}}}") + + # 2. Load ERA5 observations (ground truth) era5 = json.load(open('era5_climatology.json')) era5_temp = era5['variables']['t2m']['monthly_values'] - # Load climate model projections - df = pd.read_csv('climate_data/data.csv') + # 3. Load multiple climate model simulations + simulations = [] + for entry in manifest['entries']: + df = pd.read_csv(f"climate_data/{{{{entry['csv']}}}}") + simulations.append({{{{'df': df, 'years': entry['years_of_averaging'], 'main': entry['main']}}}})) - # Compare: ERA5 observations vs climate model + # 4. Plot comparison months = list(era5_temp.keys()) - era5_values = list(era5_temp.values()) - model_values = df['mean2t'].tolist()[:12] # historical period - - plt.figure(figsize=(10, 6)) - plt.plot(months, era5_values, 'b-o', label='ERA5 Observations (2015-2025)') - plt.plot(months, model_values, 'r--s', label='Climate Model') + plt.figure(figsize=(12, 6)) + plt.plot(months, list(era5_temp.values()), 'k-o', linewidth=2, label='ERA5 Observations') + for sim in simulations: + style = '-' if sim['main'] else '--' + plt.plot(months, sim['df']['mean2t'].tolist(), style, label=f"Model {{{{sim['years']}}}}") plt.xlabel('Month') plt.ylabel('Temperature (°C)') - plt.title('Observed vs Modeled Temperature') plt.legend() - plt.xticks(rotation=45) plt.tight_layout() - plt.savefig('results/era5_vs_model.png', dpi=150) + plt.savefig('results/temperature_comparison.png', dpi=150) plt.close() - print('Plot saved to results/era5_vs_model.png') ``` """ tool_num += 1 @@ -270,24 +277,44 @@ def _create_tool_prompt(datasets_text: str, config: dict, lat: float = None, lon Call get_era5_climatology FIRST to get the observed climate baseline. - Extract at minimum: temperature (t2m), precipitation (tp) - This is what the climate ACTUALLY IS at this location (2015-2025 average) +""" + # Add ERA5 time series download and analysis steps if available + if has_era5_download and has_repl: + prompt += """ +**STEP 2 - DOWNLOAD ERA5 TIME SERIES (RECOMMENDED):** +Call retrieve_era5_data to get detailed historical time series for deeper analysis. +- Download temperature (t2) and precipitation (tp) for 2015-2024 +- This provides year-by-year data, not just climatological averages +- Data saved to era5_data/ as .zarr files for Python analysis +- Enables trend analysis, extreme event detection, interannual variability + +**STEP 3 - ANALYZE ERA5 TIME SERIES:** +Use Python_REPL to analyze the downloaded ERA5 time series: +```python +import xarray as xr +ds = xr.open_zarr('era5_data/era5_t2_temporal_YYYYMMDD_YYYYMMDD.zarr') +# Compute trends, identify extreme years, plot interannual variability +``` """ if has_repl: - prompt += """ -**STEP 2 - LOAD MODEL DATA:** + # step_offset accounts for ERA5 download (step 2) and analysis (step 3) when enabled + step_offset = 2 if has_era5_download else 0 + prompt += f""" +**STEP {2 + step_offset} - LOAD MODEL DATA:** Use Python_REPL to read climate_data/data.csv (model projections) -**STEP 3 - COMPARE OBSERVATIONS vs MODEL:** +**STEP {3 + step_offset} - COMPARE OBSERVATIONS vs MODEL:** - ERA5 = ground truth (what we observe NOW) - Model historical = what the model simulates for recent past - Difference = MODEL BIAS (critical for interpreting future projections) -**STEP 4 - ANALYZE FUTURE WITH CONTEXT:** +**STEP {4 + step_offset} - ANALYZE FUTURE WITH CONTEXT:** - Show future projections from climate model - Explain how model bias affects confidence - Future change = (Model future) - (ERA5 baseline) -**STEP 5 - CREATE VISUALIZATIONS (MANDATORY):** +**STEP {5 + step_offset} - CREATE VISUALIZATIONS (MANDATORY):** - Plot 1: ERA5 observations vs model (shows model bias) - Plot 2: Future projections with ERA5 baseline - Save ALL plots to results/ directory @@ -361,7 +388,7 @@ def _create_tool_prompt(datasets_text: str, config: dict, lat: float = None, lon """ prompt += """ -Limit total tool calls to 20. +Limit total tool calls to 50. """ return prompt diff --git a/src/climsight/sandbox_utils.py b/src/climsight/sandbox_utils.py index 963731c..5b52953 100644 --- a/src/climsight/sandbox_utils.py +++ b/src/climsight/sandbox_utils.py @@ -28,7 +28,8 @@ def ensure_thread_id(session_state=None, existing_thread_id: str = "") -> str: session_state["thread_id"] = thread_id # Expose to non-Streamlit tools (CLI, background workers). - os.environ.setdefault("CLIMSIGHT_THREAD_ID", thread_id) + # CRITICAL: Always update (not setdefault) to ensure all tools use the same thread_id + os.environ["CLIMSIGHT_THREAD_ID"] = thread_id return thread_id diff --git a/src/climsight/tools/era5_retrieval_tool.py b/src/climsight/tools/era5_retrieval_tool.py index d934441..435928d 100644 --- a/src/climsight/tools/era5_retrieval_tool.py +++ b/src/climsight/tools/era5_retrieval_tool.py @@ -130,24 +130,31 @@ def retrieve_era5_data( logging.info(f"🌍 Earthmover ERA5 Retrieval ({query_type}): {short_var} | {start_date} to {end_date}") # --- 1. Sandbox / Path Logic --- + # Priority: 1) CLIMSIGHT_THREAD_ID env var, 2) Streamlit session, 3) work_dir, 4) default main_dir = None - if work_dir and os.path.exists(work_dir): - main_dir = work_dir - # Fallback logic for streamlit session if work_dir not provided + thread_id = os.environ.get("CLIMSIGHT_THREAD_ID") + + if thread_id: + # Use sandbox path based on thread_id (set by sandbox_utils) + main_dir = os.path.join("tmp", "sandbox", thread_id) elif "streamlit" in sys.modules and st is not None and hasattr(st, 'session_state'): try: session_uuid = getattr(st.session_state, "session_uuid", None) if session_uuid: main_dir = os.path.join("tmp", "sandbox", session_uuid) else: - thread_id = st.session_state.get("thread_id") - if thread_id: - main_dir = os.path.join("tmp", "sandbox", thread_id) + st_thread_id = st.session_state.get("thread_id") + if st_thread_id: + main_dir = os.path.join("tmp", "sandbox", st_thread_id) except: pass + # Fallback to work_dir if no sandbox available + if not main_dir and work_dir: + main_dir = work_dir + if not main_dir: - main_dir = os.path.join("tmp", "sandbox", "era5_earthmover") + main_dir = os.path.join("tmp", "sandbox", "era5_default") os.makedirs(main_dir, exist_ok=True) From 73069aca2417829e38d35afd392b8c0a16e7be7d Mon Sep 17 00:00:00 2001 From: Ivan Kuznetsov Date: Mon, 2 Feb 2026 10:51:15 +0100 Subject: [PATCH 12/16] Add DestinE provider and pass Arraylake API key for ERA5 retrieval add DestinE config (time periods, variable mapping/suffixes) and provider implementation with unstructured grid interpolation and unit conversions wire DestinE into provider factory and availability list gate ERA5 retrieval tool on Arraylake API key, pass via config from Streamlit UI allow era5 retrieval tool creation with bound API key while keeping env-based fallback --- config.yml | 44 ++++ src/climsight/climate_data_providers.py | 283 ++++++++++++++++++++- src/climsight/data_analysis_agent.py | 8 +- src/climsight/streamlit_interface.py | 49 +++- src/climsight/tools/era5_retrieval_tool.py | 68 ++++- 5 files changed, 434 insertions(+), 18 deletions(-) diff --git a/config.yml b/config.yml index 66a8490..1f067df 100644 --- a/config.yml +++ b/config.yml @@ -137,6 +137,50 @@ climate_data_sources: longitude: "lon" time: "month" + DestinE: + enabled: true + coordinate_system: "unstructured" + description: "DestinE IFS-FESOM high-resolution climate simulations (SSP3-7.0)" + data_path: "./data/DestinE/" + # Time periods configuration + time_periods: + historical: + pattern: "ifs-fesom_baseline_hist_sfc_high_monthly_1990_2014_mean" + years_of_averaging: "1990-2014" + description: "DestinE IFS-FESOM historical baseline simulation" + is_main: true + source: "Destination Earth Climate DT, IFS-FESOM coupled model" + 2015_2019: + pattern: "ifs-fesom_projections_ssp3-7.0_sfc_high_monthly_2015_2019_mean" + years_of_averaging: "2015-2019" + description: "DestinE IFS-FESOM SSP3-7.0 near-term projection" + is_main: false + source: "Destination Earth Climate DT, IFS-FESOM coupled model, SSP3-7.0" + 2020_2029: + pattern: "ifs-fesom_projections_ssp3-7.0_sfc_high_monthly_2020_2029_mean" + years_of_averaging: "2020-2029" + description: "DestinE IFS-FESOM SSP3-7.0 mid-term projection" + is_main: false + source: "Destination Earth Climate DT, IFS-FESOM coupled model, SSP3-7.0" + 2040_2049: + pattern: "ifs-fesom_projections_ssp3-7.0_sfc_high_monthly_2040_2049_mean" + years_of_averaging: "2040-2049" + description: "DestinE IFS-FESOM SSP3-7.0 far-term projection" + is_main: false + source: "Destination Earth Climate DT, IFS-FESOM coupled model, SSP3-7.0" + # Variable mapping: display_name -> netcdf_variable + variable_mapping: + Temperature: avg_2t + Total Precipitation: avg_tprate + Wind U: avg_10u + Wind V: avg_10v + # Variable file suffixes (to construct full filenames) + variable_suffixes: + avg_2t: "_avg_2t.nc" + avg_tprate: "_avg_tprate.nc" + avg_10u: "_avg_10u.nc" + avg_10v: "_avg_10v.nc" + # Legacy settings (kept for backwards compatibility, will be migrated automatically) data_settings: data_path: "./data/" diff --git a/src/climsight/climate_data_providers.py b/src/climsight/climate_data_providers.py index bbaab8e..a2c15a0 100644 --- a/src/climsight/climate_data_providers.py +++ b/src/climsight/climate_data_providers.py @@ -830,6 +830,285 @@ def extract_data( ) +class DestinEProvider(ClimateDataProvider): + """Provider for DestinE IFS-FESOM high-resolution climate data. + + This provider handles unstructured grid data (similar to HEALPix) using + cKDTree for efficient spatial lookups. Unlike NextGEMS, DestinE stores + each variable in separate files that must be combined. + + Data characteristics: + - Coordinate system: Unstructured grid (12.5M points) + - Variables: avg_2t (temp), avg_tprate (precip), avg_10u/avg_10v (wind) + - Time periods: 1990-2014 (hist), 2015-2019, 2020-2029, 2040-2049 (SSP3-7.0) + - Units: Temperature in K, precipitation in kg m⁻² s⁻¹, wind in m/s + """ + + # Days in each month for precipitation conversion (using 28.25 for Feb to account for leap years) + DAYS_IN_MONTH = [31, 28.25, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31] + SECONDS_PER_DAY = 86400 + + def __init__(self, source_config: dict, global_config: dict = None): + super().__init__(source_config, global_config) + self._spatial_indices = {} # Cache for cKDTree per file + self._data_path = source_config.get('data_path', './data/DestinE/') + + @property + def name(self) -> str: + return "DestinE" + + @property + def coordinate_system(self) -> str: + return "unstructured" + + def is_available(self) -> bool: + """Check if at least one complete time period exists.""" + time_periods = self.source_config.get('time_periods', {}) + var_suffixes = self.source_config.get('variable_suffixes', {}) + + if not time_periods or not var_suffixes: + return False + + for period_key, period_meta in time_periods.items(): + pattern = period_meta.get('pattern', '') + if not pattern: + continue + + # Check if all variable files exist for this period + all_exist = True + for var_name, suffix in var_suffixes.items(): + file_path = os.path.join(self._data_path, f"{pattern}{suffix}") + if not os.path.exists(file_path): + all_exist = False + break + if all_exist: + return True + return False + + def _build_spatial_index(self, nc_file: str) -> Tuple[cKDTree, np.ndarray, np.ndarray]: + """Build cKDTree spatial index from a NetCDF file.""" + if nc_file in self._spatial_indices: + return self._spatial_indices[nc_file] + + ds = xr.open_dataset(nc_file) + lons = ds['longitude'].values + lats = ds['latitude'].values + ds.close() + + points = np.column_stack((lons, lats)) + tree = cKDTree(points) + + self._spatial_indices[nc_file] = (tree, lons, lats) + return tree, lons, lats + + def _extract_point_data( + self, + ds: xr.Dataset, + var_name: str, + indices: np.ndarray, + weights: np.ndarray, + use_exact: bool, + exact_idx: int = 0 + ) -> np.ndarray: + """Extract interpolated values for all months at a point.""" + interpolated = [] + for month_idx in range(12): + data_values = ds[var_name][month_idx, indices].values + if use_exact: + value = data_values[exact_idx] + else: + value = np.dot(weights, data_values) + interpolated.append(value) + return np.array(interpolated) + + def _convert_precip_rate(self, values: np.ndarray) -> np.ndarray: + """Convert precipitation from kg m⁻² s⁻¹ to mm/month. + + 1 kg/m² water = 1 mm depth (by definition of water density) + So: kg m⁻² s⁻¹ = mm/s + mm/month = mm/s × days_in_month × 86400 s/day + """ + converted = np.array([ + values[i] * self.DAYS_IN_MONTH[i] * self.SECONDS_PER_DAY + for i in range(len(values)) + ]) + return converted + + def _post_process_data( + self, + df: pd.DataFrame, + df_vars: Dict + ) -> Tuple[pd.DataFrame, Dict]: + """Apply unit conversions and calculate wind speed/direction.""" + df_processed = df.copy() + df_vars_processed = df_vars.copy() + + # Temperature: K to °C + if 'avg_2t' in df_processed.columns: + df_processed['avg_2t'] = df_processed['avg_2t'] - 273.15 + df_vars_processed['avg_2t']['units'] = '°C' + + # Precipitation: kg m⁻² s⁻¹ to mm/month + if 'avg_tprate' in df_processed.columns: + df_processed['avg_tprate'] = self._convert_precip_rate( + df_processed['avg_tprate'].values + ) + df_vars_processed['avg_tprate']['units'] = 'mm/month' + + # Calculate wind speed and direction + if 'avg_10u' in df_processed.columns and 'avg_10v' in df_processed.columns: + wind_speed = np.sqrt( + df_processed['avg_10u']**2 + df_processed['avg_10v']**2 + ) + wind_direction = (180.0 + np.degrees( + np.arctan2(df_processed['avg_10u'], df_processed['avg_10v']) + )) % 360 + + df_processed['wind_speed'] = wind_speed.round(2) + df_processed['wind_direction'] = wind_direction.round(2) + + df_vars_processed['wind_speed'] = { + 'name': 'wind_speed', 'units': 'm/s', + 'full_name': 'Wind Speed', 'long_name': 'Wind Speed' + } + df_vars_processed['wind_direction'] = { + 'name': 'wind_direction', 'units': '°', + 'full_name': 'Wind Direction', 'long_name': 'Wind Direction' + } + + # Round numeric columns + for var in df_processed.columns: + if var != 'Month' and pd.api.types.is_numeric_dtype(df_processed[var]): + df_processed[var] = df_processed[var].round(2) + + return df_processed, df_vars_processed + + def extract_data( + self, + lon: float, + lat: float, + months: Optional[List[int]] = None + ) -> ClimateDataResult: + """Extract climate data for a location from DestinE data.""" + time_periods = self.source_config.get('time_periods', {}) + var_mapping = self.source_config.get('variable_mapping', {}) + var_suffixes = self.source_config.get('variable_suffixes', {}) + + if months is None: + months = list(range(1, 13)) + + df_list = [] + tree = None + lons = None + lats = None + indices = None + weights = None + use_exact = False + exact_idx = 0 + + # Process each time period + for period_key, period_meta in time_periods.items(): + pattern = period_meta.get('pattern', '') + if not pattern: + continue + + # Build paths for all variable files + var_files = {} + all_exist = True + for nc_var, suffix in var_suffixes.items(): + file_path = os.path.join(self._data_path, f"{pattern}{suffix}") + if os.path.exists(file_path): + var_files[nc_var] = file_path + else: + all_exist = False + logger.debug(f"Missing DestinE file: {file_path}") + break + + if not all_exist: + continue + + # Build spatial index from first variable file (coordinates are same in all files) + first_file = list(var_files.values())[0] + if tree is None: + tree, lons, lats = self._build_spatial_index(first_file) + + # Normalize longitude to 0-360 if data uses that convention + query_lon = lon + if lon < 0 and lons.min() >= 0: + query_lon = lon + 360 + + # Query 4 nearest neighbors + distances, indices = tree.query([query_lon, lat], k=4) + + # Project and compute inverse distance weights + pste = pyproj.Proj( + proj="stere", errcheck=True, ellps='WGS84', + lat_0=lat, lon_0=lon + ) + neighbors_lons = lons[indices] + neighbors_lats = lats[indices] + neighbors_x, neighbors_y = pste(neighbors_lons, neighbors_lats) + desired_x, desired_y = pste(lon, lat) + + dx = neighbors_x - desired_x + dy = neighbors_y - desired_y + distances_proj = np.hypot(dx, dy) + + if np.any(distances_proj == 0): + exact_idx = np.where(distances_proj == 0)[0][0] + use_exact = True + else: + inv_distances = 1.0 / distances_proj + weights = inv_distances / inv_distances.sum() + + # Extract data from all variable files + df_data = {'Month': [calendar.month_name[m] for m in months]} + df_vars = {} + + for display_name, nc_var in var_mapping.items(): + if nc_var not in var_files: + continue + + ds = xr.open_dataset(var_files[nc_var]) + + values = self._extract_point_data( + ds, nc_var, indices, weights, use_exact, exact_idx + ) + + df_data[nc_var] = values + df_vars[nc_var] = { + 'name': nc_var, + 'units': ds[nc_var].attrs.get('units', ''), + 'full_name': display_name, + 'long_name': ds[nc_var].attrs.get('long_name', '') + } + + ds.close() + + df = pd.DataFrame(df_data) + df_processed, df_vars_processed = self._post_process_data(df, df_vars) + + df_list.append({ + 'filename': pattern, + 'years_of_averaging': period_meta.get('years_of_averaging', ''), + 'description': period_meta.get('description', ''), + 'dataframe': df_processed, + 'extracted_vars': df_vars_processed, + 'main': period_meta.get('is_main', False), + 'source': period_meta.get('source', 'DestinE IFS-FESOM') + }) + + # Prepare data_agent_response + data_agent_response = self._prepare_data_agent_response(df_list) + + return ClimateDataResult( + df_list=df_list, + data_agent_response=data_agent_response, + source_name=self.name, + source_description=self.description + ) + + # Factory functions def get_climate_data_provider( @@ -857,6 +1136,8 @@ def get_climate_data_provider( return ICCPProvider(sources_config.get('ICCP', {}), config) elif source == 'AWI_CM': return AWICMProvider(sources_config.get('AWI_CM', {}), config) + elif source == 'DestinE': + return DestinEProvider(sources_config.get('DestinE', {}), config) else: raise ValueError(f"Unknown climate data source: {source}") @@ -874,7 +1155,7 @@ def get_available_providers(config: dict) -> List[str]: List of provider names that are available """ available = [] - for source in ['nextGEMS', 'ICCP', 'AWI_CM']: + for source in ['nextGEMS', 'ICCP', 'AWI_CM', 'DestinE']: try: provider = get_climate_data_provider(config, source) if provider.is_available(): diff --git a/src/climsight/data_analysis_agent.py b/src/climsight/data_analysis_agent.py index 998186f..a820399 100644 --- a/src/climsight/data_analysis_agent.py +++ b/src/climsight/data_analysis_agent.py @@ -21,7 +21,7 @@ from agent_helpers import create_standard_agent_executor from tools.get_data_components import create_get_data_components_tool from tools.era5_climatology_tool import create_era5_climatology_tool -from tools.era5_retrieval_tool import era5_retrieval_tool +from tools.era5_retrieval_tool import create_era5_retrieval_tool from tools.python_repl import CustomPythonREPLTool from tools.image_viewer import create_image_viewer_tool from tools.reflection_tools import reflect_tool @@ -537,7 +537,11 @@ def data_analysis_agent( # 3. ERA5 time series retrieval (if enabled - for detailed year-by-year analysis) if config.get("use_era5_data", False): - tools.append(era5_retrieval_tool) + arraylake_api_key = config.get("arraylake_api_key", "") + if arraylake_api_key: + tools.append(create_era5_retrieval_tool(arraylake_api_key)) + else: + logger.warning("ERA5 data enabled but no arraylake_api_key in config. ERA5 retrieval tool not added.") # 4. Python REPL for analysis/visualization (if enabled) if has_python_repl: diff --git a/src/climsight/streamlit_interface.py b/src/climsight/streamlit_interface.py index c440d2e..b229418 100644 --- a/src/climsight/streamlit_interface.py +++ b/src/climsight/streamlit_interface.py @@ -71,7 +71,10 @@ def run_streamlit(config, api_key='', skip_llm_call=False, rag_activated=True, r api_key_local = os.environ.get("OPENAI_API_KEY_LOCAL") if not api_key_local: api_key_local = "" - + + # Check for Arraylake API key (for ERA5 data retrieval) + arraylake_api_key = os.environ.get("ARRAYLAKE_API_KEY", "") + #read data while loading here ##### like hist, future = load_data(config) @@ -164,7 +167,8 @@ def run_streamlit(config, api_key='', skip_llm_call=False, rag_activated=True, r source_descriptions = { 'nextGEMS': 'nextGEMS (High resolution)', 'ICCP': 'ICCP (AWI-CM3, medium resolution)', - 'AWI_CM': 'AWI-CM (CMIP6, low resolution)' + 'AWI_CM': 'AWI-CM (CMIP6, low resolution)', + 'DestinE': 'DestinE IFS-FESOM (High resolution, SSP3-7.0)' } col1_src, col2_src = st.columns([1, 1]) @@ -194,7 +198,16 @@ def run_streamlit(config, api_key='', skip_llm_call=False, rag_activated=True, r type="password", ) - + # Include Arraylake API key input if ERA5 data is enabled and key not in environment + arraylake_api_key_input = "" + if use_era5_data and not arraylake_api_key: + arraylake_api_key_input = st.text_input( + "Arraylake API key (for ERA5 data)", + placeholder="Enter your Arraylake API key here", + type="password", + help="Required for downloading ERA5 time series data from Earthmover/Arraylake.", + ) + # Replace the st.button with st.form_submit_button submit_button = st.form_submit_button(label='Generate') @@ -205,22 +218,44 @@ def run_streamlit(config, api_key='', skip_llm_call=False, rag_activated=True, r if (not api_key) and (not skip_llm_call) and (config['llm_combine']['model_type'] == "openai"): st.error("Please provide an OpenAI API key.") st.stop() + + # Handle Arraylake API key for ERA5 data + if use_era5_data: + if not arraylake_api_key: + arraylake_api_key = arraylake_api_key_input + if not arraylake_api_key: + st.error("Please provide an Arraylake API key to use ERA5 data retrieval.") + st.stop() + # Store in config so data_analysis_agent can pass it to the tool + config["arraylake_api_key"] = arraylake_api_key + # Update config with the selected LLM mode - #config['llmModeKey'] = "direct_llm" if llmModeKey_box == "Direct" else "agent_llm" + #config['llmModeKey'] = "direct_llm" if llmModeKey_box == "Direct" else "agent_llm" config['show_add_info'] = show_add_info config['use_smart_agent'] = smart_agent config['use_era5_data'] = use_era5_data config['use_powerful_data_analysis'] = use_powerful_data_analysis - - # RUN submit button + + # RUN submit button if submit_button and user_message: if not api_key: api_key = api_key_input if (not api_key) and (not skip_llm_call) and (config['llm_combine']['model_type'] == "openai"): st.error("Please provide an OpenAI API key.") st.stop() + + # Handle Arraylake API key for ERA5 data (in nested block too) + if use_era5_data: + if not arraylake_api_key: + arraylake_api_key = arraylake_api_key_input + if not arraylake_api_key: + st.error("Please provide an Arraylake API key to use ERA5 data retrieval.") + st.stop() + # Store in config so data_analysis_agent can pass it to the tool + config["arraylake_api_key"] = arraylake_api_key + # Update config with the selected LLM mode - #config['llmModeKey'] = "direct_llm" if llmModeKey_box == "Direct" else "agent_llm" + #config['llmModeKey'] = "direct_llm" if llmModeKey_box == "Direct" else "agent_llm" config['show_add_info'] = show_add_info config['use_smart_agent'] = smart_agent config['use_era5_data'] = use_era5_data diff --git a/src/climsight/tools/era5_retrieval_tool.py b/src/climsight/tools/era5_retrieval_tool.py index 435928d..344461b 100644 --- a/src/climsight/tools/era5_retrieval_tool.py +++ b/src/climsight/tools/era5_retrieval_tool.py @@ -99,30 +99,35 @@ def _generate_descriptive_filename(variable_id: str, query_type: str, start_date return f"era5_{clean_var}_{query_type}_{clean_start}_{clean_end}.zarr" def retrieve_era5_data( - variable_id: str, - start_date: str, + variable_id: str, + start_date: str, end_date: str, - min_latitude: float = -90.0, + min_latitude: float = -90.0, max_latitude: float = 90.0, - min_longitude: float = 0.0, + min_longitude: float = 0.0, max_longitude: float = 359.75, work_dir: Optional[str] = None, + arraylake_api_key: Optional[str] = None, **kwargs # Catch-all for unused args ) -> dict: """ Retrieves ERA5 Surface data from Earthmover (Arraylake). Hardcoded to use 'temporal' (Time-series) queries. Uses Nearest Neighbor selection for points to avoid empty slices. + + Args: + arraylake_api_key: Arraylake API key. If not provided, falls back to + ARRAYLAKE_API_KEY environment variable. """ ds = None local_zarr_path = None query_type = "temporal" # Hardcoded for Climsight point-data focus - # Get API Key from environment - ARRAYLAKE_API_KEY = os.environ.get("ARRAYLAKE_API_KEY") + # Get API Key - prefer passed parameter, fall back to environment + ARRAYLAKE_API_KEY = arraylake_api_key or os.environ.get("ARRAYLAKE_API_KEY") - if not ARRAYLAKE_API_KEY: - return {"success": False, "error": "Missing Arraylake API Key", "message": "Please set ARRAYLAKE_API_KEY environment variable."} + if not ARRAYLAKE_API_KEY: + return {"success": False, "error": "Missing Arraylake API Key", "message": "Please provide arraylake_api_key parameter or set ARRAYLAKE_API_KEY environment variable."} try: # Map Variable Name @@ -280,6 +285,53 @@ def retrieve_era5_data( if ds is not None: ds.close() +def create_era5_retrieval_tool(arraylake_api_key: str): + """ + Create the ERA5 retrieval tool with the API key bound. + + Args: + arraylake_api_key: Arraylake API key for accessing Earthmover data + + Returns: + StructuredTool configured for ERA5 data retrieval + """ + def retrieve_era5_wrapper( + variable_id: str, + start_date: str, + end_date: str, + min_latitude: float = -90.0, + max_latitude: float = 90.0, + min_longitude: float = 0.0, + max_longitude: float = 359.75, + work_dir: Optional[str] = None, + ) -> dict: + return retrieve_era5_data( + variable_id=variable_id, + start_date=start_date, + end_date=end_date, + min_latitude=min_latitude, + max_latitude=max_latitude, + min_longitude=min_longitude, + max_longitude=max_longitude, + work_dir=work_dir, + arraylake_api_key=arraylake_api_key, + ) + + return StructuredTool.from_function( + func=retrieve_era5_wrapper, + name="retrieve_era5_data", + description=( + "Retrieves ERA5 Surface climate data from Earthmover (Arraylake). " + "Optimized for TEMPORAL time-series extraction at specific locations. " + "Automatically snaps to nearest grid point to ensure data is returned. " + "Returns a Zarr directory path. " + "Available vars: t2 (temp), sst, u10/v10 (wind), mslp (pressure), tp (precip)." + ), + args_schema=ERA5RetrievalArgs + ) + + +# Keep backward-compatible module-level tool that uses environment variable era5_retrieval_tool = StructuredTool.from_function( func=retrieve_era5_data, name="retrieve_era5_data", From 9c2cea6bad0469639d6f203152c1741b0ce32223 Mon Sep 17 00:00:00 2001 From: Ivan Kuznetsov Date: Mon, 2 Feb 2026 11:05:54 +0100 Subject: [PATCH 13/16] update requiremenets --- environment.yml | 74 +++++++++++++++++++++++++++++++++++------------- pyproject.toml | 60 ++++++++++++++++++++++++++------------- requirements.txt | 16 +++++------ 3 files changed, 103 insertions(+), 47 deletions(-) diff --git a/environment.yml b/environment.yml index b8d272c..63df334 100644 --- a/environment.yml +++ b/environment.yml @@ -4,36 +4,72 @@ channels: - defaults dependencies: - python=3.11 + + # Web Framework - streamlit + - streamlit-folium + - folium + + # Climate & Geospatial Data Processing - xarray - - geopy + - netcdf4 - geopandas + - shapely - pyproj - - requests - - requests-mock + - geopy + - osmnx + + # Data Science & Numerical Computing - pandas - - folium - - openai - - langchain - - streamlit-folium - - netcdf4 + - numpy + - scipy - dask - - pip - - langchain-community + - matplotlib + + # LangChain & LLM Framework (conda-forge packages) + - langchain + - langchain-community - langchain-openai - langchain-chroma - langchain-core - - osmnx - - matplotlib - - pydantic + - langchain-text-splitters + - langchain-experimental - langgraph - - chardet + - openai + - pydantic + + # Web Scraping & Data Retrieval - bs4 - wikipedia - - scipy - - pyproj + - requests + - requests-mock + + # ERA5 & Zarr tooling + - zarr + - gcsfs + - numcodecs + - blosc + + # Jupyter-backed Python REPL + - jupyter_client + - ipykernel + + # PDF Generation - reportlab - # pip-only packages must go under this section: + + # Configuration & Utilities + - pyyaml + - python-dotenv + + # Testing + - pytest + + # Package Management + - pip + + # pip-only packages (not available on conda-forge) - pip: - - aitta-client - - langchain_classic + - langchain-classic + - langchain-anthropic + - arraylake + - aitta-client diff --git a/pyproject.toml b/pyproject.toml index 7268c77..16f232c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,27 +11,33 @@ authors = [ {name = "boryasbora"} ] license = {text = "BSD-3-Clause"} -keywords = ["climate", "llm", "climate-assessment", "rag", "decision-support"] +keywords = ["climate", "llm", "climate-assessment", "rag", "decision-support"] +requires-python = ">=3.11" dependencies = [ + # Web Framework "streamlit", + "streamlit-folium", + "folium", + + # Climate & Geospatial Data Processing "xarray", - "geopy", + "netcdf4", "geopandas", "shapely", "pyproj", - "requests", - "requests-mock", + "geopy", + "osmnx", + + # Data Science & Numerical Computing "pandas", - "folium", - "langchain", - "langchain-classic", - "streamlit-folium", - "netcdf4", + "numpy", + "scipy", "dask", - "pip", - "osmnx", "matplotlib", - "openai", + + # LangChain & LLM Framework + "langchain", + "langchain-classic", "langchain-community", "langchain-openai", "langchain-chroma", @@ -39,22 +45,38 @@ dependencies = [ "langchain-text-splitters", "langchain-experimental", "langchain-anthropic", - "pydantic", "langgraph", + "openai", + "pydantic", + + # Web Scraping & Data Retrieval "bs4", "wikipedia", - "scipy", + "requests", + "requests-mock", + + # PDF Generation "reportlab", - "pyyaml", - "jupyter_client", - "ipykernel", + + # ERA5 & Zarr tooling "zarr", + "arraylake", "gcsfs", "numcodecs", "blosc", - "aitta-client" + + # Jupyter-backed Python REPL + "jupyter_client", + "ipykernel", + + # Configuration & Utilities + "pyyaml", + "python-dotenv", ] +[project.optional-dependencies] +aitta = ["aitta-client"] +dev = ["pytest", "flake8"] [build-system] requires = ["setuptools"] @@ -65,8 +87,6 @@ package-dir = {"" = "src"} [tool.setuptools.packages.find] where = ["src"] -#find = {} # Scan the project directory with the default parameters - [project.scripts] climsight = "climsight.launch:launch_streamlit" diff --git a/requirements.txt b/requirements.txt index d8602d0..b874d85 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,10 +1,10 @@ # ClimSight Dependencies -# Generated from pyproject.toml # Install with: pip install -r requirements.txt # Web Framework streamlit streamlit-folium +folium # Climate & Geospatial Data Processing xarray @@ -14,7 +14,6 @@ shapely pyproj geopy osmnx -folium # Data Science & Numerical Computing pandas @@ -50,6 +49,7 @@ reportlab # ERA5 + Zarr tooling zarr +arraylake gcsfs numcodecs blosc @@ -58,13 +58,13 @@ blosc jupyter_client ipykernel -# Optional: CSC's AITTA Platform (local/custom AI models) -# Uncomment if you need CSC AITTA integration -# aitta-client +# Configuration & Utilities +pyyaml +python-dotenv # Development & Testing pytest -pyyaml -# Package Management -pip +# Optional: CSC's AITTA Platform (local/custom AI models) +# Uncomment if you need CSC AITTA integration +# aitta-client From 797507b570868f0c234415410aa0db390d0564e5 Mon Sep 17 00:00:00 2001 From: Ivan Kuznetsov Date: Mon, 2 Feb 2026 11:21:46 +0100 Subject: [PATCH 14/16] update README --- README.md | 143 ++++++++++++++++++------------------------------------ 1 file changed, 48 insertions(+), 95 deletions(-) diff --git a/README.md b/README.md index 734aaf4..2e2985c 100644 --- a/README.md +++ b/README.md @@ -2,8 +2,7 @@ ClimSight is an advanced tool that integrates Large Language Models (LLMs) with climate data to provide localized climate insights for decision-making. ClimSight transforms complex climate data into actionable insights for agriculture, urban planning, disaster management, and policy development. -The target audience includes researchers, providers of climate services, policymakers, agricultural planners, urban developers, and other stakeholders who require detailed climate information to support decision-making. ClimSight is designed to democratize access to climate -data, empowering users with insights relevant to their specific contexts. +The target audience includes researchers, providers of climate services, policymakers, agricultural planners, urban developers, and other stakeholders who require detailed climate information to support decision-making. ClimSight is designed to democratize access to climate data, empowering users with insights relevant to their specific contexts. ![Image](https://github.com/user-attachments/assets/f9f89735-ef08-4c91-bc03-112c8e4c0896) @@ -15,61 +14,11 @@ ClimSight distinguishes itself through several key advancements: - **Real-World Applications**: ClimSight is validated through practical examples, such as assessing climate risks for specific agricultural activities and urban planning scenarios. -## Installation Options +## Installation -You can use ClimSight in three ways: -1. Run a pre-built Docker container (simplest approach) -2. Build and run a Docker container from source -3. Install the Python package (via pip or conda/mamba) +### Recommended: Building from source with conda/mamba -Using ClimSight requires an OpenAI API key unless using the `skipLLMCall` mode for testing. The API key is only needed when running the application, not during installation. - -## Batch Processing - -For batch processing of climate questions, the `sequential` directory contains specialized tools for generating, validating, and processing questions in bulk. These tools are particularly useful for research and analysis requiring multiple climate queries. See the [sequential/README.md](sequential/README.md) for detailed usage instructions. - -## 1. Running with Docker (Pre-built Container) - -The simplest way to get started is with our pre-built Docker container: - -```bash -# Make sure your OpenAI API key is set as an environment variable -export OPENAI_API_KEY="your-api-key-here" - -# Pull and run the container -docker pull koldunovn/climsight:stable -docker run -p 8501:8501 -e OPENAI_API_KEY=$OPENAI_API_KEY koldunovn/climsight:stable -``` - -Then open `http://localhost:8501/` in your browser. - -## 2. Building and Running from Source with Docker - -If you prefer to build from the latest source: - -```bash -# Clone the repository -git clone https://github.com/CliDyn/climsight.git -cd climsight - -# Download required data -python download_data.py - -# Build and run the container -docker build -t climsight . -docker run -p 8501:8501 -e OPENAI_API_KEY=$OPENAI_API_KEY climsight -``` - -Visit `http://localhost:8501/` in your browser once the container is running. - -For testing without OpenAI API calls: -```bash -docker run -p 8501:8501 -e STREAMLIT_ARGS="skipLLMCall" climsight -``` - -## 3. Python Package Installation - -### Option A: Building from source with conda/mamba +This is the recommended installation method to get the latest features and updates. ```bash # Clone the repository @@ -84,11 +33,10 @@ conda activate climsight python download_data.py ``` -### Option B: Using pip +### Alternative: Using pip from source -It's recommended to create a virtual environment to avoid dependency conflicts: ```bash -# Option 1: Install from source +# Clone the repository git clone https://github.com/CliDyn/climsight.git cd climsight @@ -96,33 +44,34 @@ cd climsight python -m venv venv source venv/bin/activate # On Windows: venv\Scripts\activate -# Install ClimSight -pip install -e . +# Install dependencies +pip install -r requirements.txt + +# Download required data python download_data.py ``` -Or if you prefer to set up without cloning the repository: +### Running with Docker (Stable Release v1.0.0) + +The Docker container provides a stable release (v1.0.0) of ClimSight. For the latest features, please install from source as described above. ```bash -# Option 2: Install from PyPI -# Create and activate a virtual environment -python -m venv climsight_env -source climsight_env/bin/activate # On Windows: climsight_env\Scripts\activate +# Make sure your OpenAI API key is set as an environment variable +export OPENAI_API_KEY="your-api-key-here" -# Install the package -pip install climsight +# Pull and run the container +docker pull koldunovn/climsight:stable +docker run -p 8501:8501 -e OPENAI_API_KEY=$OPENAI_API_KEY koldunovn/climsight:stable +``` -# Create a directory for data -mkdir -p climsight -cd climsight +Then open `http://localhost:8501/` in your browser. -# Download necessary configuration files -wget https://raw.githubusercontent.com/CliDyn/climsight/main/data_sources.yml -wget https://raw.githubusercontent.com/CliDyn/climsight/main/download_data.py -wget https://raw.githubusercontent.com/CliDyn/climsight/main/config.yml +### Using pip from PyPI (Stable Release v1.0.0) -# Download the required data (about 8 GB) -python download_data.py +The PyPI package provides a stable release (v1.0.0) of ClimSight. For the latest features, please install from source as described above. + +```bash +pip install climsight ``` ## Configuration @@ -131,50 +80,54 @@ ClimSight will automatically use a `config.yml` file from the current directory. ```yaml # Key settings you can modify in config.yml: -# - LLM model (gpt-4, ...) +# - LLM model (gpt-4, gpt-5, ...) # - Climate data sources # - RAG database configuration # - Agent parameters +# - ERA5 data retrieval settings ``` -## Running ClimSight -### If installed with conda/mamba from source: +## API Keys -```bash -# Run from the repository root -streamlit run src/climsight/climsight.py -``` +### OpenAI API Key -### If installed with pip: +ClimSight requires an OpenAI API key for LLM functionality. You can set it as an environment variable: ```bash -# Make sure you're in the directory with your data and config -climsight +export OPENAI_API_KEY="your-api-key-here" ``` -You can optionally set your OpenAI API key as an environment variable: +Alternatively, you can enter your API key directly in the browser interface when prompted. + +### Arraylake API Key (Optional - for ERA5 Data) + +If you want to use ERA5 time series data retrieval (enabled via the "Enable ERA5 data" toggle in the UI), you need an Arraylake API key from [Earthmover](https://earthmover.io/). This allows downloading ERA5 reanalysis data for detailed historical climate analysis. + ```bash -export OPENAI_API_KEY="your-api-key-here" +export ARRAYLAKE_API_KEY="your-arraylake-api-key-here" ``` -Otherwise, you can enter your API key directly in the browser interface when prompted. +You can also enter the Arraylake API key in the browser interface when the ERA5 data option is enabled. -### Testing without an OpenAI API key: +## Running ClimSight ```bash -# From source: -streamlit run src/climsight/climsight.py skipLLMCall - -# Or if installed with pip: -climsight skipLLMCall +# Run from the repository root +streamlit run src/climsight/climsight.py ``` The application will open in your browser automatically. Just type your climate-related questions and press "Generate" to get insights. ClimSight Interface +## Batch Processing + +For batch processing of climate questions, the `sequential` directory contains specialized tools for generating, validating, and processing questions in bulk. These tools are particularly useful for research and analysis requiring multiple climate queries. See the [sequential/README.md](sequential/README.md) for detailed usage instructions. + ## Citation If you use or refer to ClimSight in your work, please cite: +Kuznetsov, I., Jost, A.A., Pantiukhin, D. et al. Transforming climate services with LLMs and multi-source data integration. _npj Clim. Action_ **4**, 97 (2025). https://doi.org/10.1038/s44168-025-00300-y + Koldunov, N., Jung, T. Local climate services for all, courtesy of large language models. _Commun Earth Environ_ **5**, 13 (2024). https://doi.org/10.1038/s43247-023-01199-1 From ec5233fe596fd94f8a40cd6300e9868305e031ad Mon Sep 17 00:00:00 2001 From: Ivan Kuznetsov Date: Mon, 2 Feb 2026 12:21:53 +0100 Subject: [PATCH 15/16] update download sources --- README.md | 6 ++++++ data_sources.yml | 6 ++++++ download_data.py | 16 ++++++++++++++++ 3 files changed, 28 insertions(+) diff --git a/README.md b/README.md index 2e2985c..4e58e7e 100644 --- a/README.md +++ b/README.md @@ -31,6 +31,9 @@ conda activate climsight # Download required data python download_data.py + +# Optional: download DestinE data (large ~12 GB, not downloaded by default) +python download_data.py DestinE ``` ### Alternative: Using pip from source @@ -49,6 +52,9 @@ pip install -r requirements.txt # Download required data python download_data.py + +# Optional: download DestinE data (large ~12 GB, not downloaded by default) +python download_data.py DestinE ``` ### Running with Docker (Stable Release v1.0.0) diff --git a/data_sources.yml b/data_sources.yml index 88b7f7b..45f1e08 100644 --- a/data_sources.yml +++ b/data_sources.yml @@ -120,4 +120,10 @@ sources: url: 'https://swift.dkrz.de/v1/dkrz_035d8f6ff058403bb42f8302e6badfbc/climsight/awi_cm.zip?temp_url_sig=f40cc2f349b24482a6f7247d173ca194fad28950&temp_url_expires=2299-10-02T09:52:13Z' archive_type: 'zip' subdir: './' + citation: + + - filename: 'DestinE.zip' + url: 'https://swift.dkrz.de/v1/dkrz_035d8f6ff058403bb42f8302e6badfbc/climsight/DestinE.zip?temp_url_sig=f60ad2be0bf65479f489611255c066148dc4741c&temp_url_expires=2053-06-19T11:20:40Z' + archive_type: 'zip' + subdir: './' citation: \ No newline at end of file diff --git a/download_data.py b/download_data.py index 47c890f..9f6a81a 100644 --- a/download_data.py +++ b/download_data.py @@ -97,6 +97,11 @@ def main(): # Parse command-line argument (--source_files) parser = argparse.ArgumentParser(description="Download and extract the raw source files of the RAG.") parser.add_argument('--source_files', type=bool, default=False, help='Whether to download and extract source files (IPCC text reports).') + parser.add_argument( + 'datasets', + nargs='*', + help="Optional extra datasets to include (e.g. DestinE).", + ) #parser.add_argument('--CMIP_OIFS', type=bool, default=False, help='Whether to download CMIP6 low resolution AWI model data and ECE4/OIFS data.') args = parser.parse_args() @@ -112,6 +117,11 @@ def main(): sources = [d for d in sources if d['filename'] != 'ipcc_text_reports.zip'] #if not args.CMIP_OIFS: # sources = [d for d in sources if d['filename'] != 'data_climate_foresight.zip'] + + # Skip DestinE unless explicitly requested (large dataset). + requested = {name.strip().lower() for name in args.datasets} + if 'destine' not in requested: + sources = [d for d in sources if d['filename'] != 'DestinE.zip'] #make subdirs list and clean it subdirs = [] @@ -136,6 +146,12 @@ def main(): url = entry['url'] subdir = os.path.join(base_path, entry['subdir']) + if not url: + files_skiped.append(file) + urls_skiped.append(url) + subdirs_skiped.append(subdir) + continue + if download_file(url, file): extract_arch(file, subdir) files_downloaded.append(file) From e217f4982bb2c103f8288e9ed650ef6e07d732ff Mon Sep 17 00:00:00 2001 From: dmpantiu Date: Mon, 2 Feb 2026 17:07:33 +0100 Subject: [PATCH 16/16] fix: improve REPL robustness and path resolution - Fix regex to handle leading whitespace before code blocks (python_repl.py) - Reset is_initialized flag on kernel restart to re-run initialization (python_repl.py) - Add sandbox_path parameter for relative path resolution (image_viewer.py) - Pass sandbox_path in data_analysis_agent.py call - Implement atomic writes for ERA5 cache to prevent corruption (era5_retrieval_tool.py) --- src/climsight/data_analysis_agent.py | 8 +++++-- src/climsight/tools/era5_retrieval_tool.py | 27 +++++++++++++++++----- src/climsight/tools/image_viewer.py | 16 ++++++++++--- src/climsight/tools/python_repl.py | 5 ++-- 4 files changed, 43 insertions(+), 13 deletions(-) diff --git a/src/climsight/data_analysis_agent.py b/src/climsight/data_analysis_agent.py index a820399..6a12597 100644 --- a/src/climsight/data_analysis_agent.py +++ b/src/climsight/data_analysis_agent.py @@ -560,8 +560,12 @@ def data_analysis_agent( # Extract model name from config or default to gpt-4o vision_model = config.get("llm_combine", {}).get("model_name", "gpt-4o") - # FIX: Pass the required api_key and model_name - image_viewer_tool = create_image_viewer_tool(api_key, vision_model) + # Pass sandbox_path so relative paths from the agent can be resolved + image_viewer_tool = create_image_viewer_tool( + openai_api_key=api_key, + model_name=vision_model, + sandbox_path=state.uuid_main_dir + ) tools.append(image_viewer_tool) # 7. Image reflection and wise_agent - ONLY when Python REPL is enabled diff --git a/src/climsight/tools/era5_retrieval_tool.py b/src/climsight/tools/era5_retrieval_tool.py index 344461b..fbcdfe5 100644 --- a/src/climsight/tools/era5_retrieval_tool.py +++ b/src/climsight/tools/era5_retrieval_tool.py @@ -252,7 +252,7 @@ def retrieve_era5_data( if subset.sizes.get('time', 0) == 0: return {"success": False, "error": "Empty time slice.", "message": "No data found for the requested dates."} - # --- 4. Save to Zarr --- + # --- 4. Save to Zarr (Atomic Write) --- logging.info(f"Downloading data to {local_zarr_path}...") ds_out = subset.to_dataset(name=short_var) @@ -261,11 +261,26 @@ def retrieve_era5_data( for var in ds_out.variables: ds_out[var].encoding = {} - if os.path.exists(local_zarr_path): - shutil.rmtree(local_zarr_path) - - ds_out.to_zarr(local_zarr_path, mode="w", consolidated=True, compute=True) - logging.info("✅ Download complete.") + # Atomic write: write to temp dir first, then rename on success + temp_zarr_path = local_zarr_path + ".tmp" + + # Clean up any previous failed temp directory + if os.path.exists(temp_zarr_path): + shutil.rmtree(temp_zarr_path) + + try: + ds_out.to_zarr(temp_zarr_path, mode="w", consolidated=True, compute=True) + + # Atomic replace: remove old cache (if exists) and rename temp to final + if os.path.exists(local_zarr_path): + shutil.rmtree(local_zarr_path) + os.rename(temp_zarr_path, local_zarr_path) + logging.info("✅ Download complete.") + except Exception as write_error: + # Clean up failed temp directory + if os.path.exists(temp_zarr_path): + shutil.rmtree(temp_zarr_path, ignore_errors=True) + raise write_error return { "success": True, diff --git a/src/climsight/tools/image_viewer.py b/src/climsight/tools/image_viewer.py index b501c89..35589cb 100644 --- a/src/climsight/tools/image_viewer.py +++ b/src/climsight/tools/image_viewer.py @@ -101,19 +101,29 @@ class ImageViewerArgs(BaseModel): ) -def create_image_viewer_tool(openai_api_key: str, model_name: str): +def create_image_viewer_tool(openai_api_key: str, model_name: str, sandbox_path: Optional[str] = None): """ Create the image viewer tool with the API key and model bound. Args: openai_api_key: OpenAI API key model_name: Model name to use + sandbox_path: Optional sandbox directory path for resolving relative paths Returns: StructuredTool configured for image analysis """ def view_image_wrapper(image_path: str) -> str: - return view_and_analyze_image(image_path, openai_api_key, model_name) + resolved_path = image_path + # If path is relative and sandbox_path is provided, resolve it + if sandbox_path and not os.path.isabs(image_path): + resolved_path = os.path.join(sandbox_path, image_path) + # Fallback: if still not found and sandbox_path exists, try it anyway + if not os.path.exists(resolved_path) and sandbox_path: + alt_path = os.path.join(sandbox_path, image_path) + if os.path.exists(alt_path): + resolved_path = alt_path + return view_and_analyze_image(resolved_path, openai_api_key, model_name) return StructuredTool.from_function( func=view_image_wrapper, @@ -122,7 +132,7 @@ def view_image_wrapper(image_path: str) -> str: "Analyze climate-related visualizations to extract scientific insights. " "Use this tool after generating plots with python_repl to understand " "the patterns and trends shown in the visualization. " - "Provide the full path to the saved image file." + "Provide the path to the saved image file (can be relative to the sandbox)." ), args_schema=ImageViewerArgs ) \ No newline at end of file diff --git a/src/climsight/tools/python_repl.py b/src/climsight/tools/python_repl.py index a5f9ba0..8892666 100644 --- a/src/climsight/tools/python_repl.py +++ b/src/climsight/tools/python_repl.py @@ -136,8 +136,9 @@ def _execute_code(self, code: str, timeout: float = 300.0) -> Dict[str, Any]: # 2. Check Kernel Vitality if not self.km.is_alive(): result["status"] = "error" - result["error"] = "Kernel died unexpectedly." + result["error"] = "Kernel died unexpectedly. Session was restarted - please retry your command." self.restart_kernel() + self.is_initialized = False # Force re-initialization on next run break # 3. Get Message @@ -184,7 +185,7 @@ def run(self, code: str) -> Dict[str, Any]: code = code.strip() if code.startswith("```"): # Remove leading ```python or ``` (case insensitive) - code = re.sub(r"^```(?:python)?\s*\n", "", code, flags=re.IGNORECASE) + code = re.sub(r"^\s*```(?:python)?\s*\n", "", code, flags=re.IGNORECASE) # Remove trailing ``` code = re.sub(r"\n\s*```\s*$", "", code)