diff --git a/pkg-py/src/querychat/__init__.py b/pkg-py/src/querychat/__init__.py index 0e3eaa5f..00527afb 100644 --- a/pkg-py/src/querychat/__init__.py +++ b/pkg-py/src/querychat/__init__.py @@ -2,9 +2,12 @@ from ._deprecated import mod_server as server from ._deprecated import mod_ui as ui from ._shiny import QueryChat +from .tools import VisualizeDashboardData, VisualizeQueryData __all__ = ( "QueryChat", + "VisualizeDashboardData", + "VisualizeQueryData", # TODO(lifecycle): Remove these deprecated functions when we reach v1.0 "greeting", "init", diff --git a/pkg-py/src/querychat/_dash.py b/pkg-py/src/querychat/_dash.py index 16d612dc..97d53091 100644 --- a/pkg-py/src/querychat/_dash.py +++ b/pkg-py/src/querychat/_dash.py @@ -24,6 +24,7 @@ from collections.abc import Callable from pathlib import Path as PathType + import altair as alt import chatlas import ibis import narwhals.stable.v1 as nw @@ -113,7 +114,12 @@ def __init__( *, greeting: Optional[str | PathType] = None, client: Optional[str | chatlas.Chat] = None, - tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = ("update", "query"), + tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = ( + "update", + "query", + "visualize_dashboard", + "visualize_query", + ), data_description: Optional[str | PathType] = None, categorical_threshold: int = 20, extra_instructions: Optional[str | PathType] = None, @@ -129,7 +135,12 @@ def __init__( *, greeting: Optional[str | PathType] = None, client: Optional[str | chatlas.Chat] = None, - tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = ("update", "query"), + tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = ( + "update", + "query", + "visualize_dashboard", + "visualize_query", + ), data_description: Optional[str | PathType] = None, categorical_threshold: int = 20, extra_instructions: Optional[str | PathType] = None, @@ -145,7 +156,12 @@ def __init__( *, greeting: Optional[str | PathType] = None, client: Optional[str | chatlas.Chat] = None, - tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = ("update", "query"), + tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = ( + "update", + "query", + "visualize_dashboard", + "visualize_query", + ), data_description: Optional[str | PathType] = None, categorical_threshold: int = 20, extra_instructions: Optional[str | PathType] = None, @@ -161,7 +177,12 @@ def __init__( *, greeting: Optional[str | PathType] = None, client: Optional[str | chatlas.Chat] = None, - tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = ("update", "query"), + tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = ( + "update", + "query", + "visualize_dashboard", + "visualize_query", + ), data_description: Optional[str | PathType] = None, categorical_threshold: int = 20, extra_instructions: Optional[str | PathType] = None, @@ -176,7 +197,12 @@ def __init__( *, greeting: Optional[str | PathType] = None, client: Optional[str | chatlas.Chat] = None, - tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = ("update", "query"), + tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = ( + "update", + "query", + "visualize_dashboard", + "visualize_query", + ), data_description: Optional[str | PathType] = None, categorical_threshold: int = 20, extra_instructions: Optional[str | PathType] = None, @@ -207,6 +233,96 @@ def store_id(self) -> str: """ return self._ids.store + def ggvis( + self, state: AppStateDict, source: Literal["filter", "query"] = "filter" + ) -> alt.Chart | None: + """ + Get the current Altair visualization chart. + + Parameters + ---------- + state + The state dict from the store component. + source + Which visualization to return. "filter" returns the dashboard + visualization (from visualize_dashboard tool), "query" returns + the query visualization (from visualize_query tool). + + Returns + ------- + : + An Altair Chart object, or None if no visualization exists. + + """ + import ggsql + + if source == "filter": + spec = state.get("filter_viz_spec") + if spec is None: + return None + # Render against current filtered data + df = as_narwhals(self.df(state)) + return ggsql.render_altair(df, spec) + else: + ggsql_query = state.get("query_viz_ggsql") + if ggsql_query is None: + return None + # Re-execute SQL and render + sql, viz_spec = ggsql.split_query(ggsql_query) + df = as_narwhals(self._data_source.execute_query(sql)) + return ggsql.render_altair(df, viz_spec) + + def ggsql( + self, state: AppStateDict, source: Literal["filter", "query"] = "filter" + ) -> str | None: + """ + Get the current ggsql specification. + + Parameters + ---------- + state + The state dict from the store component. + source + Which specification to return. "filter" returns the VISUALISE spec + from visualize_dashboard, "query" returns the full ggsql query + from visualize_query. + + Returns + ------- + : + The ggsql specification string, or None if no visualization exists. + + """ + if source == "filter": + return state.get("filter_viz_spec") + else: + return state.get("query_viz_ggsql") + + def ggtitle( + self, state: AppStateDict, source: Literal["filter", "query"] = "filter" + ) -> str | None: + """ + Get the current visualization title. + + Parameters + ---------- + state + The state dict from the store component. + source + Which title to return. "filter" returns the title from + visualize_dashboard, "query" returns the title from visualize_query. + + Returns + ------- + : + The visualization title, or None if no visualization exists. + + """ + if source == "filter": + return state.get("filter_viz_title") + else: + return state.get("query_viz_title") + def app(self) -> dash.Dash: """ Create a complete Dash app. @@ -393,6 +509,34 @@ def app_layout(ids: IDs, table_name: str, chat_ui): body_class_name="d-flex flex-column", ) + # Filter plot card + filter_plot_card = card_ui( + body=[ + html.Iframe( + id=ids.filter_plot, + srcDoc="
No filter visualization yet. Ask the assistant to create one.
", + style={"width": "100%", "height": "400px", "border": "none"}, + ), + dcc.Markdown(id=ids.filter_ggsql, className="querychat-ggsql-display mt-2"), + ], + title_id=ids.filter_plot_title, + class_name="h-100", + ) + + # Query plot card + query_plot_card = card_ui( + body=[ + html.Iframe( + id=ids.query_plot, + srcDoc="No query visualization yet. Ask the assistant to create one.
", + style={"width": "100%", "height": "400px", "border": "none"}, + ), + dcc.Markdown(id=ids.query_ggsql, className="querychat-ggsql-display mt-2"), + ], + title_id=ids.query_plot_title, + class_name="h-100", + ) + chat_card = card_ui( body=chat_ui, title="Chat", @@ -407,7 +551,27 @@ def app_layout(ids: IDs, table_name: str, chat_ui): [ dbc.Col(chat_card, width=4), dbc.Col( - [sql_card, data_card], + dbc.Tabs( + [ + dbc.Tab( + [sql_card, data_card], + label="Data", + tab_id="data-tab", + ), + dbc.Tab( + filter_plot_card, + label="Filter Plot", + tab_id="filter-plot-tab", + ), + dbc.Tab( + query_plot_card, + label="Query Plot", + tab_id="query-plot-tab", + ), + ], + id=ids.tabs, + active_tab="data-tab", + ), width=8, className="d-flex flex-column", ), @@ -426,12 +590,15 @@ def register_app_callbacks( table_name: str, deserialize_state: Callable[[AppStateDict], AppState], ) -> None: - """Register callbacks for SQL display, data table, and export.""" + """Register callbacks for SQL display, data table, visualizations, and export.""" + import ggsql from dash.dcc.express import send_data_frame import dash from dash import Input, Output, State + from ._ggsql import vegalite_to_html + @app.callback( [ Output(ids.sql_title, "children"), @@ -440,6 +607,12 @@ def register_app_callbacks( Output(ids.data_table, "columnDefs"), Output(ids.data_info, "children"), Output(ids.store, "data", allow_duplicate=True), + Output(ids.filter_plot_title, "children"), + Output(ids.filter_plot, "srcDoc"), + Output(ids.filter_ggsql, "children"), + Output(ids.query_plot_title, "children"), + Output(ids.query_plot, "srcDoc"), + Output(ids.query_ggsql, "children"), ], [ Input(ids.store, "data"), @@ -459,7 +632,8 @@ def update_display(state_data: AppStateDict, reset_clicks): sql_title = state.title or "SQL Query" sql_code = f"```sql\n{state.get_display_sql()}\n```" - nw_df = as_narwhals(state.get_current_data()) + current_data = state.get_current_data() + nw_df = as_narwhals(current_data) nrow, ncol = nw_df.shape display_df = nw_df.to_pandas() @@ -472,6 +646,38 @@ def update_display(state_data: AppStateDict, reset_clicks): data_info_parts.append(f"Data has {nrow} rows and {ncol} columns.") data_info = " ".join(data_info_parts) + # Filter visualization - render on demand + filter_title = state.filter_viz_title or "Filter Plot" + filter_spec = state.filter_viz_spec + + if filter_spec: + # Render against current filtered data + chart = ggsql.render_altair(nw_df, filter_spec) + filter_html = vegalite_to_html(chart.to_dict()) + else: + filter_html = ( + "No filter visualization yet. Ask the assistant to create one.
" + ) + + filter_ggsql_md = f"```sql\n{filter_spec}\n```" if filter_spec else "" + + # Query visualization - render on demand + query_title = state.query_viz_title or "Query Plot" + query_ggsql_str = state.query_viz_ggsql + + if query_ggsql_str: + # Re-execute SQL and render + sql_part, viz_spec = ggsql.split_query(query_ggsql_str) + query_df = as_narwhals(state.data_source.execute_query(sql_part)) + chart = ggsql.render_altair(query_df, viz_spec) + query_html = vegalite_to_html(chart.to_dict()) + else: + query_html = ( + "No query visualization yet. Ask the assistant to create one.
" + ) + + query_ggsql_md = f"```sql\n{query_ggsql_str}\n```" if query_ggsql_str else "" + return ( sql_title, sql_code, @@ -479,6 +685,12 @@ def update_display(state_data: AppStateDict, reset_clicks): table_columns, data_info, state.to_dict(), + filter_title, + filter_html, + filter_ggsql_md, + query_title, + query_html, + query_ggsql_md, ) @app.callback( diff --git a/pkg-py/src/querychat/_dash_ui.py b/pkg-py/src/querychat/_dash_ui.py index f6ba64ca..f838aaeb 100644 --- a/pkg-py/src/querychat/_dash_ui.py +++ b/pkg-py/src/querychat/_dash_ui.py @@ -28,6 +28,13 @@ class IDs: data_info: str download_csv: str export_button: str + tabs: str + filter_plot: str + filter_plot_title: str + filter_ggsql: str + query_plot: str + query_plot_title: str + query_ggsql: str @classmethod def from_table_name(cls, table_name: str) -> IDs: @@ -46,6 +53,13 @@ def from_table_name(cls, table_name: str) -> IDs: data_info=f"{prefix}-data-info", download_csv=f"{prefix}-download-csv", export_button=f"{prefix}-export-button", + tabs=f"{prefix}-tabs", + filter_plot=f"{prefix}-filter-plot", + filter_plot_title=f"{prefix}-filter-plot-title", + filter_ggsql=f"{prefix}-filter-ggsql", + query_plot=f"{prefix}-query-plot", + query_plot_title=f"{prefix}-query-plot-title", + query_ggsql=f"{prefix}-query-ggsql", ) diff --git a/pkg-py/src/querychat/_datasource.py b/pkg-py/src/querychat/_datasource.py index 5cac5f08..af7628f5 100644 --- a/pkg-py/src/querychat/_datasource.py +++ b/pkg-py/src/querychat/_datasource.py @@ -214,6 +214,7 @@ def __init__(self, df: nw.DataFrame, table_name: str): self._df_lib = native_namespace.__name__ self._conn = duckdb.connect(database=":memory:") + # NOTE: if native representation is polars, pyarrow is required for registration self._conn.register(table_name, self._df.to_native()) self._conn.execute(""" -- extensions: lock down supply chain + auto behaviors diff --git a/pkg-py/src/querychat/_ggsql.py b/pkg-py/src/querychat/_ggsql.py new file mode 100644 index 00000000..9c7bb2cc --- /dev/null +++ b/pkg-py/src/querychat/_ggsql.py @@ -0,0 +1,74 @@ +"""Helpers for ggsql integration.""" + +from __future__ import annotations + +import re + + +def extract_title(viz_spec: str) -> str | None: + """ + Extract the title from a VISUALISE spec's LABEL clause. + + Parameters + ---------- + viz_spec + The VISUALISE portion of a ggsql query. + + Returns + ------- + str | None + The title if found, otherwise None. + + """ + # Match LABEL title => 'value' or LABEL title => "value" + pattern = r"LABEL\s+title\s*=>\s*['\"]([^'\"]+)['\"]" + match = re.search(pattern, viz_spec, re.IGNORECASE) + if match: + return match.group(1) + return None + + +def vegalite_to_html(vegalite_spec: dict) -> str: + """ + Convert a Vega-Lite specification to standalone HTML. + + This renders the spec directly using vega-embed. + + Parameters + ---------- + vegalite_spec + A Vega-Lite specification as a dictionary. + + Returns + ------- + str + A complete HTML document that renders the chart. + + """ + import json + + spec_json = json.dumps(vegalite_spec) + + # ggsql produces v6 specs + vl_version = "6" + + return f""" + + + + + + + + + + + + +""" diff --git a/pkg-py/src/querychat/_gradio.py b/pkg-py/src/querychat/_gradio.py index c0f3518d..374c375d 100644 --- a/pkg-py/src/querychat/_gradio.py +++ b/pkg-py/src/querychat/_gradio.py @@ -2,7 +2,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Optional, overload +from typing import TYPE_CHECKING, Any, Literal, Optional, overload from gradio.context import Context from narwhals.stable.v1.typing import IntoDataFrameT, IntoFrameT, IntoLazyFrameT @@ -24,6 +24,7 @@ if TYPE_CHECKING: from pathlib import Path + import altair as alt import chatlas import ibis import sqlalchemy @@ -106,7 +107,12 @@ def __init__( *, greeting: Optional[str | Path] = None, client: Optional[str | chatlas.Chat] = None, - tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = ("update", "query"), + tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = ( + "update", + "query", + "visualize_dashboard", + "visualize_query", + ), data_description: Optional[str | Path] = None, categorical_threshold: int = 20, extra_instructions: Optional[str | Path] = None, @@ -121,7 +127,12 @@ def __init__( *, greeting: Optional[str | Path] = None, client: Optional[str | chatlas.Chat] = None, - tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = ("update", "query"), + tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = ( + "update", + "query", + "visualize_dashboard", + "visualize_query", + ), data_description: Optional[str | Path] = None, categorical_threshold: int = 20, extra_instructions: Optional[str | Path] = None, @@ -136,7 +147,12 @@ def __init__( *, greeting: Optional[str | Path] = None, client: Optional[str | chatlas.Chat] = None, - tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = ("update", "query"), + tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = ( + "update", + "query", + "visualize_dashboard", + "visualize_query", + ), data_description: Optional[str | Path] = None, categorical_threshold: int = 20, extra_instructions: Optional[str | Path] = None, @@ -151,7 +167,12 @@ def __init__( *, greeting: Optional[str | Path] = None, client: Optional[str | chatlas.Chat] = None, - tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = ("update", "query"), + tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = ( + "update", + "query", + "visualize_dashboard", + "visualize_query", + ), data_description: Optional[str | Path] = None, categorical_threshold: int = 20, extra_instructions: Optional[str | Path] = None, @@ -165,7 +186,12 @@ def __init__( *, greeting: Optional[str | Path] = None, client: Optional[str | chatlas.Chat] = None, - tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = ("update", "query"), + tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = ( + "update", + "query", + "visualize_dashboard", + "visualize_query", + ), data_description: Optional[str | Path] = None, categorical_threshold: int = 20, extra_instructions: Optional[str | Path] = None, @@ -203,6 +229,96 @@ def head(self) -> str: """ return f"" + def ggvis( + self, state: AppStateDict, source: Literal["filter", "query"] = "filter" + ) -> alt.Chart | None: + """ + Get the current Altair visualization chart. + + Parameters + ---------- + state + The state dict from the state component. + source + Which visualization to return. "filter" returns the dashboard + visualization (from visualize_dashboard tool), "query" returns + the query visualization (from visualize_query tool). + + Returns + ------- + : + An Altair Chart object, or None if no visualization exists. + + """ + import ggsql + + if source == "filter": + spec = state.get("filter_viz_spec") + if spec is None: + return None + # Render against current filtered data + df = as_narwhals(self.df(state)) + return ggsql.render_altair(df, spec) + else: + ggsql_query = state.get("query_viz_ggsql") + if ggsql_query is None: + return None + # Re-execute SQL and render + sql, viz_spec = ggsql.split_query(ggsql_query) + df = as_narwhals(self._data_source.execute_query(sql)) + return ggsql.render_altair(df, viz_spec) + + def ggsql( + self, state: AppStateDict, source: Literal["filter", "query"] = "filter" + ) -> str | None: + """ + Get the current ggsql specification. + + Parameters + ---------- + state + The state dict from the state component. + source + Which specification to return. "filter" returns the VISUALISE spec + from visualize_dashboard, "query" returns the full ggsql query + from visualize_query. + + Returns + ------- + : + The ggsql specification string, or None if no visualization exists. + + """ + if source == "filter": + return state.get("filter_viz_spec") + else: + return state.get("query_viz_ggsql") + + def ggtitle( + self, state: AppStateDict, source: Literal["filter", "query"] = "filter" + ) -> str | None: + """ + Get the current visualization title. + + Parameters + ---------- + state + The state dict from the state component. + source + Which title to return. "filter" returns the title from + visualize_dashboard, "query" returns the title from visualize_query. + + Returns + ------- + : + The visualization title, or None if no visualization exists. + + """ + if source == "filter": + return state.get("filter_viz_title") + else: + return state.get("query_viz_title") + def ui(self) -> gr.State: """ Create chat UI components for custom layouts. @@ -327,6 +443,11 @@ def app(self) -> GradioBlocksWrapper: A wrapped Gradio Blocks app ready to launch. The wrapper injects querychat CSS/JS at launch time for Gradio 6.0+ compatibility. + The app includes three tabs: + - **Data**: Shows the filtered data table with the current SQL query + - **Filter Plot**: Shows the persistent dashboard visualization + - **Query Plot**: Shows the most recent query visualization + """ data_source = self._require_data_source("app") from gradio.themes import Soft @@ -343,31 +464,53 @@ def app(self) -> GradioBlocksWrapper: gr.Markdown(f"## `{table_name}`") - with gr.Group(): - with gr.Row(): - sql_title = gr.Markdown("**Current Query**") - reset_btn = gr.Button( - "Reset", size="sm", variant="secondary", scale=0 + with gr.Tabs(): + with gr.Tab("Data"): + with gr.Group(): + with gr.Row(): + sql_title = gr.Markdown("**Current Query**") + reset_btn = gr.Button( + "Reset", size="sm", variant="secondary", scale=0 + ) + sql_display = gr.Code( + label="", + language="sql", + value=f"SELECT * FROM {table_name}", + interactive=False, + lines=2, + ) + + with gr.Group(): + gr.Markdown("**Data Preview**") + data_display = gr.Dataframe( + label="", + buttons=["fullscreen", "copy"], + show_search="filter", + ) + data_info = gr.Markdown("") + + with gr.Tab("Filter Plot"): + filter_plot_title = gr.Markdown("") + filter_plot_display = gr.Plot(label="") + filter_ggsql_display = gr.Code( + label="ggsql spec", language="sql", lines=2 + ) + filter_plot_info = gr.Markdown( + "*No filter visualization yet. Ask the assistant to create one.*" ) - sql_display = gr.Code( - label="", - language="sql", - value=f"SELECT * FROM {table_name}", - interactive=False, - lines=2, - ) - with gr.Group(): - gr.Markdown("**Data Preview**") - data_display = gr.Dataframe( - label="", - buttons=["fullscreen", "copy"], - show_search="filter", - ) - data_info = gr.Markdown("") + with gr.Tab("Query Plot"): + query_plot_title = gr.Markdown("") + query_plot_display = gr.Plot(label="") + query_ggsql_display = gr.Code( + label="ggsql query", language="sql", lines=2 + ) + query_plot_info = gr.Markdown( + "*No query visualization yet. Ask the assistant to create one.*" + ) def update_displays(state_dict: AppStateDict): - """Update SQL and data displays based on state.""" + """Update SQL, data, and visualization displays based on state.""" title = state_dict.get("title") if state_dict else None error = state_dict.get("error") if state_dict else None @@ -385,11 +528,44 @@ def update_displays(state_dict: AppStateDict): data_info_parts = [] if error: - data_info_parts.append(f"⚠️ {error}") + data_info_parts.append(f"Warning: {error}") data_info_parts.append(f"*Data has {nrow} rows and {ncol} columns.*") data_info_text = " ".join(data_info_parts) - return sql_title_text, sql_code, native_df, data_info_text + # Filter visualization + filter_chart = self.ggvis(state_dict, "filter") + filter_title_text = self.ggtitle(state_dict, "filter") or "" + filter_spec = self.ggsql(state_dict, "filter") or "" + filter_info = ( + "" + if filter_chart + else "*No filter visualization yet. Ask the assistant to create one.*" + ) + + # Query visualization + query_chart = self.ggvis(state_dict, "query") + query_title_text = self.ggtitle(state_dict, "query") or "" + query_spec = self.ggsql(state_dict, "query") or "" + query_info = ( + "" + if query_chart + else "*No query visualization yet. Ask the assistant to create one.*" + ) + + return ( + sql_title_text, + sql_code, + native_df, + data_info_text, + f"### {filter_title_text}" if filter_title_text else "", + filter_chart, + filter_spec, + filter_info, + f"### {query_title_text}" if query_title_text else "", + query_chart, + query_spec, + query_info, + ) def reset_query(state_dict: AppStateDict): """Reset state to show full dataset.""" @@ -401,7 +577,20 @@ def reset_query(state_dict: AppStateDict): state_holder.change( fn=update_displays, inputs=[state_holder], - outputs=[sql_title, sql_display, data_display, data_info], + outputs=[ + sql_title, + sql_display, + data_display, + data_info, + filter_plot_title, + filter_plot_display, + filter_ggsql_display, + filter_plot_info, + query_plot_title, + query_plot_display, + query_ggsql_display, + query_plot_info, + ], ) reset_btn.click( diff --git a/pkg-py/src/querychat/_icons.py b/pkg-py/src/querychat/_icons.py index 2b7683da..61880f83 100644 --- a/pkg-py/src/querychat/_icons.py +++ b/pkg-py/src/querychat/_icons.py @@ -2,7 +2,14 @@ from shiny import ui -ICON_NAMES = Literal["arrow-counterclockwise", "funnel-fill", "terminal-fill", "table"] +ICON_NAMES = Literal[ + "arrow-counterclockwise", + "bar-chart-fill", + "funnel-fill", + "graph-up", + "terminal-fill", + "table", +] def bs_icon(name: ICON_NAMES) -> ui.HTML: @@ -14,7 +21,9 @@ def bs_icon(name: ICON_NAMES) -> ui.HTML: BS_ICONS = { "arrow-counterclockwise": '', + "bar-chart-fill": '', "funnel-fill": '', + "graph-up": '', "terminal-fill": '', "table": '', } diff --git a/pkg-py/src/querychat/_querychat_base.py b/pkg-py/src/querychat/_querychat_base.py index e8a7c7f1..83856989 100644 --- a/pkg-py/src/querychat/_querychat_base.py +++ b/pkg-py/src/querychat/_querychat_base.py @@ -25,17 +25,22 @@ from ._utils import MISSING, MISSING_TYPE, is_ibis_table from .tools import ( UpdateDashboardData, + VisualizeDashboardData, + VisualizeQueryData, tool_query, tool_reset_dashboard, tool_update_dashboard, + tool_visualize_dashboard, + tool_visualize_query, ) if TYPE_CHECKING: from collections.abc import Callable + import altair as alt from narwhals.stable.v1.typing import IntoFrame -TOOL_GROUPS = Literal["update", "query"] +TOOL_GROUPS = Literal["update", "query", "visualize_dashboard", "visualize_query"] class QueryChatBase(Generic[IntoFrameT]): @@ -58,7 +63,12 @@ def __init__( *, greeting: Optional[str | Path] = None, client: Optional[str | chatlas.Chat] = None, - tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = ("update", "query"), + tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = ( + "update", + "query", + "visualize_dashboard", + "visualize_query", + ), data_description: Optional[str | Path] = None, categorical_threshold: int = 20, extra_instructions: Optional[str | Path] = None, @@ -72,7 +82,9 @@ def __init__( "Table name must begin with a letter and contain only letters, numbers, and underscores", ) - self.tools = normalize_tools(tools, default=("update", "query")) + self.tools = normalize_tools( + tools, default=("update", "query", "visualize_dashboard", "visualize_query") + ) self.greeting = greeting.read_text() if isinstance(greeting, Path) else greeting # Store init parameters for deferred system prompt building @@ -132,6 +144,8 @@ def client( tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None | MISSING_TYPE = MISSING, update_dashboard: Callable[[UpdateDashboardData], None] | None = None, reset_dashboard: Callable[[], None] | None = None, + visualize_dashboard: Callable[[VisualizeDashboardData], None] | None = None, + visualize_query: Callable[[VisualizeQueryData], None] | None = None, ) -> chatlas.Chat: """ Create a chat client with registered tools. @@ -139,11 +153,16 @@ def client( Parameters ---------- tools - Which tools to include: `"update"`, `"query"`, or both. + Which tools to include: `"update"`, `"query"`, `"visualize_dashboard"`, + `"visualize_query"`, or a combination. update_dashboard Callback when update_dashboard tool succeeds. reset_dashboard Callback when reset_dashboard tool is invoked. + visualize_dashboard + Callback when visualize_dashboard tool succeeds. + visualize_query + Callback when visualize_query tool succeeds. Returns ------- @@ -172,6 +191,14 @@ def client( if "query" in tools: chat.register_tool(tool_query(data_source)) + if "visualize_dashboard" in tools: + viz_fn = visualize_dashboard or (lambda _: None) + chat.register_tool(tool_visualize_dashboard(self._data_source, viz_fn)) + + if "visualize_query" in tools: + query_viz_fn = visualize_query or (lambda _: None) + chat.register_tool(tool_visualize_query(self._data_source, query_viz_fn)) + return chat def generate_greeting(self, *, echo: Literal["none", "output"] = "none") -> str: @@ -221,6 +248,63 @@ def cleanup(self) -> None: if self._data_source is not None: self._data_source.cleanup() + def ggvis(self, source: Literal["filter", "query"] = "filter") -> alt.Chart | None: + """ + Get the visualization chart. + + Parameters + ---------- + source + Which visualization to return: + - "filter": Chart from visualize_dashboard (updates with filter changes) + - "query": Chart from visualize_query (most recent inline visualization) + + Returns + ------- + alt.Chart | None + The Altair chart, or None if no visualization exists. + + """ + raise NotImplementedError("Subclasses must implement ggvis()") + + def ggsql(self, source: Literal["filter", "query"] = "filter") -> str | None: + """ + Get the ggsql specification. + + Parameters + ---------- + source + Which specification to return: + - "filter": VISUALISE spec only (from visualize_dashboard) + - "query": Full ggsql query (from visualize_query) + + Returns + ------- + str | None + The ggsql specification, or None if no visualization exists. + + """ + raise NotImplementedError("Subclasses must implement ggsql()") + + def ggtitle(self, source: Literal["filter", "query"] = "filter") -> str | None: + """ + Get the visualization title. + + Parameters + ---------- + source + Which title to return: + - "filter": Title from visualize_dashboard + - "query": Title from visualize_query + + Returns + ------- + str | None + The title, or None if no visualization exists. + + """ + raise NotImplementedError("Subclasses must implement ggtitle()") + def normalize_data_source( data_source: IntoFrame | sqlalchemy.Engine | DataSource, diff --git a/pkg-py/src/querychat/_querychat_core.py b/pkg-py/src/querychat/_querychat_core.py index af0685e0..a7553a3c 100644 --- a/pkg-py/src/querychat/_querychat_core.py +++ b/pkg-py/src/querychat/_querychat_core.py @@ -21,7 +21,7 @@ from chatlas.types import Content from narwhals.stable.v1.typing import IntoFrameT -from .tools import UpdateDashboardData +from .tools import UpdateDashboardData, VisualizeDashboardData, VisualizeQueryData GREETING_PROMPT: str = ( "Please give me a friendly greeting. " @@ -38,10 +38,15 @@ ClientFactory = Callable[ - [Callable[[UpdateDashboardData], None], Callable[[], None]], + [ + Callable[[UpdateDashboardData], None], + Callable[[], None], + Callable[[VisualizeDashboardData], None], + Callable[[VisualizeQueryData], None], + ], Chat, ] -"""Factory that creates a Chat client with update_dashboard and reset_dashboard callbacks.""" +"""Factory that creates a Chat client with dashboard and visualization callbacks.""" class AppStateDict(TypedDict): @@ -51,6 +56,11 @@ class AppStateDict(TypedDict): title: str | None error: str | None turns: list[dict] # Serialized chatlas Turns via model_dump() + # Visualization state - only specs stored, charts rendered on demand + filter_viz_spec: str | None + filter_viz_title: str | None + query_viz_ggsql: str | None + query_viz_title: str | None class DisplayMessage(TypedDict): @@ -69,9 +79,16 @@ def _client_factory( self, update_cb: Callable[[UpdateDashboardData], None], reset_cb: Callable[[], None], + filter_viz_cb: Callable[[VisualizeDashboardData], None], + query_viz_cb: Callable[[VisualizeQueryData], None], ) -> Chat: - """Create a chat client with dashboard callbacks.""" - return self.client(update_dashboard=update_cb, reset_dashboard=reset_cb) # type: ignore[attr-defined] + """Create a chat client with dashboard and visualization callbacks.""" + return self.client( # type: ignore[attr-defined] + update_dashboard=update_cb, + reset_dashboard=reset_cb, + visualize_dashboard=filter_viz_cb, + visualize_query=query_viz_cb, + ) def df(self, state: AppStateDict | None) -> IntoFrameT: """ @@ -204,6 +221,15 @@ class AppState: title: Optional[str] = None error: Optional[str] = None + # Filter visualization state (from visualize_dashboard tool) + # Only specs stored, charts rendered on demand via ggsql.render_altair() + filter_viz_spec: Optional[str] = None + filter_viz_title: Optional[str] = None + + # Query visualization state (from visualize_query tool) + query_viz_ggsql: Optional[str] = None + query_viz_title: Optional[str] = None + def update_dashboard(self, data: UpdateDashboardData) -> None: self.sql = data["query"] self.title = data["title"] @@ -213,6 +239,27 @@ def reset_dashboard(self) -> None: self.sql = None self.title = None self.error = None + # Also clear filter visualization + self.filter_viz_spec = None + self.filter_viz_title = None + + def update_filter_viz( + self, + spec: str, + title: Optional[str], + ) -> None: + """Update filter visualization state.""" + self.filter_viz_spec = spec + self.filter_viz_title = title + + def update_query_viz( + self, + ggsql: str, + title: Optional[str], + ) -> None: + """Update query visualization state.""" + self.query_viz_ggsql = ggsql + self.query_viz_title = title def get_current_data(self) -> IntoFrame: """Get current data, falling back to default if query fails.""" @@ -281,6 +328,10 @@ def to_dict(self) -> AppStateDict: "title": self.title, "error": self.error, "turns": [turn.model_dump() for turn in self.client.get_turns()], + "filter_viz_spec": self.filter_viz_spec, + "filter_viz_title": self.filter_viz_title, + "query_viz_ggsql": self.query_viz_ggsql, + "query_viz_title": self.query_viz_title, } def update_from_dict(self, data: AppStateDict) -> None: @@ -295,6 +346,12 @@ def update_from_dict(self, data: AppStateDict) -> None: turns = [Turn.model_validate(t) for t in turns_data] self.client.set_turns(turns) + # Restore visualization state + self.filter_viz_spec = data.get("filter_viz_spec") + self.filter_viz_title = data.get("filter_viz_title") + self.query_viz_ggsql = data.get("query_viz_ggsql") + self.query_viz_title = data.get("query_viz_title") + def create_app_state( data_source: DataSource, @@ -316,7 +373,21 @@ def reset_callback() -> None: raise RuntimeError("Callback invoked before state initialization") state.reset_dashboard() - client = client_factory(update_callback, reset_callback) + def filter_viz_callback(data: VisualizeDashboardData) -> None: + state = state_holder["state"] + if state is None: + raise RuntimeError("Callback invoked before state initialization") + state.update_filter_viz(data["spec"], data["title"]) + + def query_viz_callback(data: VisualizeQueryData) -> None: + state = state_holder["state"] + if state is None: + raise RuntimeError("Callback invoked before state initialization") + state.update_query_viz(data["ggsql"], data["title"]) + + client = client_factory( + update_callback, reset_callback, filter_viz_callback, query_viz_callback + ) state = AppState( data_source=data_source, client=client, diff --git a/pkg-py/src/querychat/_shiny.py b/pkg-py/src/querychat/_shiny.py index c1dcc9a1..016ad59d 100644 --- a/pkg-py/src/querychat/_shiny.py +++ b/pkg-py/src/querychat/_shiny.py @@ -17,6 +17,7 @@ if TYPE_CHECKING: from pathlib import Path + import altair as alt import chatlas import ibis import narwhals.stable.v1 as nw @@ -97,10 +98,10 @@ class QueryChat(QueryChatBase[IntoFrameT]): tools Which querychat tools to include in the chat client by default. Can be: - A single tool string: `"update"` or `"query"` - - A tuple of tools: `("update", "query")` + - A tuple of tools: `("update", "query", "visualize_dashboard", "visualize_query")` - `None` or `()` to disable all tools - Default is `("update", "query")` (both tools enabled). + Default is `("update", "query", "visualize_dashboard", "visualize_query")` (all tools enabled). Set to `"update"` to prevent the LLM from accessing data values, only allowing dashboard filtering without answering questions. @@ -156,7 +157,12 @@ def __init__( id: Optional[str] = None, greeting: Optional[str | Path] = None, client: Optional[str | chatlas.Chat] = None, - tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = ("update", "query"), + tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = ( + "update", + "query", + "visualize_dashboard", + "visualize_query", + ), data_description: Optional[str | Path] = None, categorical_threshold: int = 20, extra_instructions: Optional[str | Path] = None, @@ -172,7 +178,12 @@ def __init__( id: Optional[str] = None, greeting: Optional[str | Path] = None, client: Optional[str | chatlas.Chat] = None, - tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = ("update", "query"), + tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = ( + "update", + "query", + "visualize_dashboard", + "visualize_query", + ), data_description: Optional[str | Path] = None, categorical_threshold: int = 20, extra_instructions: Optional[str | Path] = None, @@ -188,7 +199,12 @@ def __init__( id: Optional[str] = None, greeting: Optional[str | Path] = None, client: Optional[str | chatlas.Chat] = None, - tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = ("update", "query"), + tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = ( + "update", + "query", + "visualize_dashboard", + "visualize_query", + ), data_description: Optional[str | Path] = None, categorical_threshold: int = 20, extra_instructions: Optional[str | Path] = None, @@ -204,7 +220,12 @@ def __init__( id: Optional[str] = None, greeting: Optional[str | Path] = None, client: Optional[str | chatlas.Chat] = None, - tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = ("update", "query"), + tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = ( + "update", + "query", + "visualize_dashboard", + "visualize_query", + ), data_description: Optional[str | Path] = None, categorical_threshold: int = 20, extra_instructions: Optional[str | Path] = None, @@ -219,7 +240,12 @@ def __init__( id: Optional[str] = None, greeting: Optional[str | Path] = None, client: Optional[str | chatlas.Chat] = None, - tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = ("update", "query"), + tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = ( + "update", + "query", + "visualize_dashboard", + "visualize_query", + ), data_description: Optional[str | Path] = None, categorical_threshold: int = 20, extra_instructions: Optional[str | Path] = None, @@ -245,9 +271,14 @@ def app( """ Quickly chat with a dataset. - Creates a Shiny app with a chat sidebar and data table view -- providing a + Creates a Shiny app with a chat sidebar and tabbed view -- providing a quick-and-easy way to start chatting with your data. + The app includes three tabs: + - **Data**: Shows the filtered data table + - **Filter Plot**: Shows the persistent dashboard visualization + - **Query Plot**: Shows the most recent query visualization + Parameters ---------- bookmark_store @@ -285,9 +316,23 @@ def app_ui(request): fill=False, style="max-height: 33%;", ), - ui.card( - ui.card_header(bs_icon("table"), " Data"), - ui.output_data_frame("dt"), + ui.navset_tab( + ui.nav_panel( + "Data", + ui.card( + ui.card_header(bs_icon("table"), " Data"), + ui.output_data_frame("dt"), + ), + ), + ui.nav_panel( + "Filter Plot", + ui.output_ui("filter_plot_container"), + ), + ui.nav_panel( + "Query Plot", + ui.output_ui("query_plot_container"), + ), + id="main_tabs", ), title=ui.span("querychat with ", ui.code(table_name)), class_="bslib-page-dashboard", @@ -301,6 +346,7 @@ def app_server(input: Inputs, output: Outputs, session: Session): greeting=self.greeting, client=self._client, enable_bookmarking=enable_bookmarking, + tools=self.tools, ) @render.text @@ -338,6 +384,64 @@ def sql_output(): width="100%", ) + @render.ui + def filter_plot_container(): + from shinywidgets import output_widget, render_altair + + chart = vals.filter_viz_chart() + if chart is None: + return ui.card( + ui.card_body( + ui.p( + "No filter visualization yet. " + "Use the chat to create one." + ), + class_="text-muted text-center py-5", + ), + ) + + @render_altair + def filter_chart(): + return chart + + return ui.card( + ui.card_header( + bs_icon("bar-chart-fill"), + " ", + vals.filter_viz_title.get() or "Filter Visualization", + ), + output_widget("filter_chart"), + ) + + @render.ui + def query_plot_container(): + from shinywidgets import output_widget, render_altair + + chart = vals.query_viz_chart() + if chart is None: + return ui.card( + ui.card_body( + ui.p( + "No query visualization yet. " + "Use the chat to create one." + ), + class_="text-muted text-center py-5", + ), + ) + + @render_altair + def query_chart(): + return chart + + return ui.card( + ui.card_header( + bs_icon("bar-chart-fill"), + " ", + vals.query_viz_title.get() or "Query Visualization", + ), + output_widget("query_chart"), + ) + return App(app_ui, app_server, bookmark_store=bookmark_store) def sidebar( @@ -493,6 +597,7 @@ def title(): greeting=self.greeting, client=self.client, enable_bookmarking=enable_bookmarking, + tools=self.tools, ) @@ -730,6 +835,7 @@ def __init__( greeting=self.greeting, client=self._client, enable_bookmarking=enable, + tools=self.tools, ) def sidebar( @@ -870,3 +976,69 @@ def title(self, value: Optional[str] = None) -> str | None | bool: return self._vals.title() else: return self._vals.title.set(value) + + def ggvis(self, source: Literal["filter", "query"] = "filter") -> alt.Chart | None: + """ + Get the visualization chart. + + Parameters + ---------- + source + Which visualization to return: + - "filter": Chart from visualize_dashboard (updates with filter changes) + - "query": Chart from visualize_query (most recent inline visualization) + + Returns + ------- + : + The Altair chart, or None if no visualization exists. + + """ + if source == "filter": + return self._vals.filter_viz_chart() + else: + return self._vals.query_viz_chart() + + def ggsql(self, source: Literal["filter", "query"] = "filter") -> str | None: + """ + Get the ggsql specification. + + Parameters + ---------- + source + Which specification to return: + - "filter": VISUALISE spec only (from visualize_dashboard) + - "query": Full ggsql query (from visualize_query) + + Returns + ------- + : + The ggsql specification, or None if no visualization exists. + + """ + if source == "filter": + return self._vals.filter_viz_spec.get() + else: + return self._vals.query_viz_ggsql.get() + + def ggtitle(self, source: Literal["filter", "query"] = "filter") -> str | None: + """ + Get the visualization title. + + Parameters + ---------- + source + Which title to return: + - "filter": Title from visualize_dashboard + - "query": Title from visualize_query + + Returns + ------- + : + The title, or None if no visualization exists. + + """ + if source == "filter": + return self._vals.filter_viz_title.get() + else: + return self._vals.query_viz_title.get() diff --git a/pkg-py/src/querychat/_shiny_module.py b/pkg-py/src/querychat/_shiny_module.py index 335f6803..584f7aff 100644 --- a/pkg-py/src/querychat/_shiny_module.py +++ b/pkg-py/src/querychat/_shiny_module.py @@ -1,6 +1,7 @@ from __future__ import annotations import copy +import logging import warnings from dataclasses import dataclass from pathlib import Path @@ -13,17 +14,26 @@ from shiny import module, reactive, ui from ._querychat_core import GREETING_PROMPT -from .tools import tool_query, tool_reset_dashboard, tool_update_dashboard +from .tools import ( + tool_query, + tool_reset_dashboard, + tool_update_dashboard, + tool_visualize_dashboard, + tool_visualize_query, +) if TYPE_CHECKING: from collections.abc import Callable + import altair as alt from shiny.bookmark import BookmarkState, RestoreState from shiny import Inputs, Outputs, Session from ._datasource import DataSource - from .types import UpdateDashboardData + from .tools import UpdateDashboardData, VisualizeDashboardData, VisualizeQueryData + +logger = logging.getLogger(__name__) ReactiveString = reactive.Value[str] """A reactive string value.""" @@ -79,6 +89,26 @@ class ServerValues(Generic[IntoFrameT]): The session-specific chat client instance. This is a deep copy of the base client configured for this specific session, containing the chat history and tool registrations for this session only. + filter_viz_spec + A reactive Value containing the VISUALISE spec from visualize_dashboard. + Returns `None` if no visualization has been created. + filter_viz_title + A reactive Value containing the title from visualize_dashboard. + Returns `None` if no visualization has been created. + filter_viz_chart + A callable returning the rendered Altair chart from visualize_dashboard. + Returns `None` if no visualization has been created. The chart is + re-rendered on each call using `ggsql.render_altair()`. + query_viz_ggsql + A reactive Value containing the full ggsql query from visualize_query. + Returns `None` if no visualization has been created. + query_viz_title + A reactive Value containing the title from visualize_query. + Returns `None` if no visualization has been created. + query_viz_chart + A callable returning the rendered Altair chart from visualize_query. + Returns `None` if no visualization has been created. The chart is + re-rendered on each call using `ggsql.render_altair()`. """ @@ -86,6 +116,13 @@ class ServerValues(Generic[IntoFrameT]): sql: ReactiveStringOrNone title: ReactiveStringOrNone client: chatlas.Chat + # Visualization state + filter_viz_spec: ReactiveStringOrNone + filter_viz_title: ReactiveStringOrNone + filter_viz_chart: Callable[[], alt.TopLevelMixin | None] + query_viz_ggsql: ReactiveStringOrNone + query_viz_title: ReactiveStringOrNone + query_viz_chart: Callable[[], alt.TopLevelMixin | None] @module.server @@ -98,12 +135,19 @@ def mod_server( greeting: str | None, client: chatlas.Chat | Callable, enable_bookmarking: bool, + tools: tuple[str, ...] | None = None, ) -> ServerValues[IntoFrameT]: # Reactive values to store state sql = ReactiveStringOrNone(None) title = ReactiveStringOrNone(None) has_greeted = reactive.value[bool](False) # noqa: FBT003 + # Visualization state - store only specs, render on demand + filter_viz_spec: reactive.Value[str | None] = reactive.value(None) + filter_viz_title: reactive.Value[str | None] = reactive.value(None) + query_viz_ggsql: reactive.Value[str | None] = reactive.value(None) + query_viz_title: reactive.Value[str | None] = reactive.value(None) + # Short-circuit for stub sessions (e.g. 1st run of an Express app) # data_source may be None during stub session for deferred pattern if session.is_stub_session(): @@ -116,6 +160,12 @@ def _stub_df(): sql=sql, title=title, client=client if isinstance(client, chatlas.Chat) else client(), + filter_viz_spec=filter_viz_spec, + filter_viz_title=filter_viz_title, + filter_viz_chart=lambda: filter_viz_chart.get(), + query_viz_ggsql=query_viz_ggsql, + query_viz_title=query_viz_title, + query_viz_chart=lambda: query_viz_chart.get(), ) # Real session requires data_source @@ -133,6 +183,14 @@ def reset_dashboard(): sql.set(None) title.set(None) + def update_filter_viz(data: VisualizeDashboardData): + filter_viz_spec.set(data["spec"]) + filter_viz_title.set(data["title"]) + + def update_query_viz(data: VisualizeQueryData): + query_viz_ggsql.set(data["ggsql"]) + query_viz_title.set(data["title"]) + # Set up the chat object for this session # Support both a callable that creates a client and legacy instance pattern if callable(client) and not isinstance(client, chatlas.Chat): @@ -147,6 +205,12 @@ def reset_dashboard(): chat.register_tool(tool_query(data_source)) chat.register_tool(tool_reset_dashboard(reset_dashboard)) + # Register visualization tools if enabled + if tools and "visualize_dashboard" in tools: + chat.register_tool(tool_visualize_dashboard(data_source, update_filter_viz)) + if tools and "visualize_query" in tools: + chat.register_tool(tool_visualize_query(data_source, update_query_viz)) + # Execute query when SQL changes @reactive.calc def filtered_df(): @@ -154,6 +218,33 @@ def filtered_df(): df = data_source.get_data() if not query else data_source.execute_query(query) return df + # Render filter visualization on demand + @reactive.calc + def render_filter_viz_chart(): + """Render filter visualization using current filtered data.""" + import ggsql + + spec = filter_viz_spec.get() + if spec is None: + return None + + current_df = filtered_df() + return ggsql.render_altair(current_df, spec) + + # Render query visualization on demand + @reactive.calc + def render_query_viz_chart(): + """Render query visualization by re-executing the ggsql query.""" + import ggsql + + ggsql_query = query_viz_ggsql.get() + if ggsql_query is None: + return None + + sql_part, viz_spec = ggsql.split_query(ggsql_query) + df = data_source.execute_query(sql_part) + return ggsql.render_altair(df, viz_spec) + # Chat UI logic chat_ui = shinychat.Chat(CHAT_ID) @@ -220,7 +311,18 @@ def _on_restore(x: RestoreState) -> None: if "querychat_has_greeted" in vals: has_greeted.set(vals["querychat_has_greeted"]) - return ServerValues(df=filtered_df, sql=sql, title=title, client=chat) + return ServerValues( + df=filtered_df, + sql=sql, + title=title, + client=chat, + filter_viz_spec=filter_viz_spec, + filter_viz_title=filter_viz_title, + filter_viz_chart=render_filter_viz_chart, + query_viz_ggsql=query_viz_ggsql, + query_viz_title=query_viz_title, + query_viz_chart=render_query_viz_chart, + ) class GreetWarning(Warning): diff --git a/pkg-py/src/querychat/_streamlit.py b/pkg-py/src/querychat/_streamlit.py index 484b8f4a..a5f30614 100644 --- a/pkg-py/src/querychat/_streamlit.py +++ b/pkg-py/src/querychat/_streamlit.py @@ -2,7 +2,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Optional, cast, overload +from typing import TYPE_CHECKING, Any, Literal, Optional, cast, overload from narwhals.stable.v1.typing import IntoDataFrameT, IntoFrameT, IntoLazyFrameT @@ -19,6 +19,7 @@ if TYPE_CHECKING: from pathlib import Path + import altair as alt import chatlas import ibis import narwhals.stable.v1 as nw @@ -81,7 +82,12 @@ def __init__( *, greeting: Optional[str | Path] = None, client: Optional[str | chatlas.Chat] = None, - tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = ("update", "query"), + tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = ( + "update", + "query", + "visualize_dashboard", + "visualize_query", + ), data_description: Optional[str | Path] = None, categorical_threshold: int = 20, extra_instructions: Optional[str | Path] = None, @@ -96,7 +102,12 @@ def __init__( *, greeting: Optional[str | Path] = None, client: Optional[str | chatlas.Chat] = None, - tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = ("update", "query"), + tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = ( + "update", + "query", + "visualize_dashboard", + "visualize_query", + ), data_description: Optional[str | Path] = None, categorical_threshold: int = 20, extra_instructions: Optional[str | Path] = None, @@ -111,7 +122,12 @@ def __init__( *, greeting: Optional[str | Path] = None, client: Optional[str | chatlas.Chat] = None, - tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = ("update", "query"), + tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = ( + "update", + "query", + "visualize_dashboard", + "visualize_query", + ), data_description: Optional[str | Path] = None, categorical_threshold: int = 20, extra_instructions: Optional[str | Path] = None, @@ -126,7 +142,12 @@ def __init__( *, greeting: Optional[str | Path] = None, client: Optional[str | chatlas.Chat] = None, - tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = ("update", "query"), + tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = ( + "update", + "query", + "visualize_dashboard", + "visualize_query", + ), data_description: Optional[str | Path] = None, categorical_threshold: int = 20, extra_instructions: Optional[str | Path] = None, @@ -140,7 +161,12 @@ def __init__( *, greeting: Optional[str | Path] = None, client: Optional[str | chatlas.Chat] = None, - tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = ("update", "query"), + tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = ( + "update", + "query", + "visualize_dashboard", + "visualize_query", + ), data_description: Optional[str | Path] = None, categorical_threshold: int = 20, extra_instructions: Optional[str | Path] = None, @@ -167,9 +193,11 @@ def _get_state(self) -> AppState: if self._state_key not in st.session_state: st.session_state[self._state_key] = create_app_state( data_source, - lambda update_cb, reset_cb: self.client( + lambda update_cb, reset_cb, filter_viz_cb, query_viz_cb: self.client( update_dashboard=update_cb, reset_dashboard=reset_cb, + visualize_dashboard=filter_viz_cb, + visualize_query=query_viz_cb, ), self.greeting, ) @@ -180,7 +208,12 @@ def app(self) -> None: Render a complete Streamlit app. Configures the page, renders chat in sidebar, and displays - SQL query and data table in the main area. + SQL query, data table, and visualizations in a tabbed interface. + + The app includes three tabs: + - **Data**: Shows the filtered data table with the current SQL query + - **Filter Plot**: Shows the persistent dashboard visualization + - **Query Plot**: Shows the most recent query visualization """ data_source = self._require_data_source("app") import streamlit as st @@ -192,7 +225,23 @@ def app(self) -> None: ) self.sidebar() - self._render_main_content() + + state = self._get_state() + + st.title(f"querychat with `{self._data_source.table_name}`") + + data_tab, filter_plot_tab, query_plot_tab = st.tabs( + ["Data", "Filter Plot", "Query Plot"] + ) + + with data_tab: + self._render_data_tab(state) + + with filter_plot_tab: + self._render_filter_plot_tab() + + with query_plot_tab: + self._render_query_plot_tab() def sidebar(self) -> None: """Render the chat interface in the Streamlit sidebar.""" @@ -303,14 +352,92 @@ def reset(self) -> None: state.reset_dashboard() st.rerun() - def _render_main_content(self) -> None: - """Render the main content area (SQL + data table).""" - data_source = self._require_data_source("_render_main_content") - import streamlit as st + def ggvis(self, source: Literal["filter", "query"] = "filter") -> alt.Chart | None: + """ + Get the current Altair visualization chart. + + Parameters + ---------- + source + Which visualization to return. "filter" returns the dashboard + visualization (from visualize_dashboard tool), "query" returns + the query visualization (from visualize_query tool). + + Returns + ------- + : + An Altair Chart object, or None if no visualization exists. + + """ + import ggsql + + from ._utils import as_narwhals state = self._get_state() + if source == "filter": + spec = state.filter_viz_spec + if spec is None: + return None + # Render against current filtered data + df = as_narwhals(self.df()) + return ggsql.render_altair(df, spec) + else: + ggsql_query = state.query_viz_ggsql + if ggsql_query is None: + return None + # Re-execute SQL and render + sql, viz_spec = ggsql.split_query(ggsql_query) + df = as_narwhals(self._data_source.execute_query(sql)) + return ggsql.render_altair(df, viz_spec) + + def ggsql(self, source: Literal["filter", "query"] = "filter") -> str | None: + """ + Get the current ggsql specification. - st.title(f"querychat with `{data_source.table_name}`") + Parameters + ---------- + source + Which specification to return. "filter" returns the VISUALISE spec + from visualize_dashboard, "query" returns the full ggsql query + from visualize_query. + + Returns + ------- + : + The ggsql specification string, or None if no visualization exists. + + """ + state = self._get_state() + if source == "filter": + return state.filter_viz_spec + else: + return state.query_viz_ggsql + + def ggtitle(self, source: Literal["filter", "query"] = "filter") -> str | None: + """ + Get the current visualization title. + + Parameters + ---------- + source + Which title to return. "filter" returns the title from + visualize_dashboard, "query" returns the title from visualize_query. + + Returns + ------- + : + The visualization title, or None if no visualization exists. + + """ + state = self._get_state() + if source == "filter": + return state.filter_viz_title + else: + return state.query_viz_title + + def _render_data_tab(self, state: AppState) -> None: + """Render the Data tab content.""" + import streamlit as st st.subheader(state.title or "SQL Query") @@ -331,3 +458,45 @@ def _render_main_content(self) -> None: df.to_native(), use_container_width=True, height=400, hide_index=True ) st.caption(f"Data has {df.shape[0]} rows and {df.shape[1]} columns.") + + def _render_filter_plot_tab(self) -> None: + """Render the Filter Plot tab content.""" + import streamlit as st + + chart = self.ggvis("filter") + if chart is not None: + title = self.ggtitle("filter") + if title: + st.subheader(title) + st.altair_chart(chart, use_container_width=True) + + spec = self.ggsql("filter") + if spec: + with st.expander("ggsql spec"): + st.code(spec, language="sql") + else: + st.info( + "No filter visualization. Ask the assistant to create one " + "using the visualize_dashboard tool." + ) + + def _render_query_plot_tab(self) -> None: + """Render the Query Plot tab content.""" + import streamlit as st + + chart = self.ggvis("query") + if chart is not None: + title = self.ggtitle("query") + if title: + st.subheader(title) + st.altair_chart(chart, use_container_width=True) + + spec = self.ggsql("query") + if spec: + with st.expander("ggsql query"): + st.code(spec, language="sql") + else: + st.info( + "No query visualization. Ask the assistant to create one " + "using the visualize_query tool." + ) diff --git a/pkg-py/src/querychat/_system_prompt.py b/pkg-py/src/querychat/_system_prompt.py index 7b4f737a..bf9774f9 100644 --- a/pkg-py/src/querychat/_system_prompt.py +++ b/pkg-py/src/querychat/_system_prompt.py @@ -77,6 +77,10 @@ def render(self, tools: tuple[TOOL_GROUPS, ...] | None) -> str: "extra_instructions": self.extra_instructions, "has_tool_update": "update" in tools if tools else False, "has_tool_query": "query" in tools if tools else False, + "has_tool_visualize_dashboard": "visualize_dashboard" in tools + if tools + else False, + "has_tool_visualize_query": "visualize_query" in tools if tools else False, "include_query_guidelines": len(tools or ()) > 0, } diff --git a/pkg-py/src/querychat/_utils.py b/pkg-py/src/querychat/_utils.py index 555e8e37..ce2c0197 100644 --- a/pkg-py/src/querychat/_utils.py +++ b/pkg-py/src/querychat/_utils.py @@ -171,14 +171,18 @@ def get_tool_details_setting() -> Optional[Literal["expanded", "collapsed", "def return setting_lower -def querychat_tool_starts_open(action: Literal["update", "query", "reset"]) -> bool: +def querychat_tool_starts_open( + action: Literal[ + "update", "query", "reset", "visualize_dashboard", "visualize_query" + ], +) -> bool: """ Determine whether a tool card should be open based on action and setting. Parameters ---------- action : str - The action type ('update', 'query', or 'reset') + The action type ('update', 'query', 'reset', 'visualize_dashboard', or 'visualize_query') Returns ------- @@ -189,14 +193,14 @@ def querychat_tool_starts_open(action: Literal["update", "query", "reset"]) -> b setting = get_tool_details_setting() if setting is None: - return action != "reset" + return action in ("query", "update", "visualize_dashboard", "visualize_query") if setting == "expanded": return True elif setting == "collapsed": return False else: # setting == "default" - return action != "reset" + return action in ("query", "update", "visualize_dashboard", "visualize_query") def is_ibis_table(obj: Any) -> TypeGuard[ibis.Table]: diff --git a/pkg-py/src/querychat/prompts/prompt.md b/pkg-py/src/querychat/prompts/prompt.md index 8c6ff97b..c5315102 100644 --- a/pkg-py/src/querychat/prompts/prompt.md +++ b/pkg-py/src/querychat/prompts/prompt.md @@ -180,6 +180,48 @@ You might want to explore the advanced features - Never use generic phrases like "If you'd like to..." or "Would you like to explore..." — instead, provide concrete suggestions - Never refer to suggestions as "prompts" – call them "suggestions" or "ideas" or similar +{{#has_tool_visualize_dashboard}} +## Visualization with ggsql + +You can create visualizations using the `visualize_dashboard` and `visualize_query` tools. These use ggsql, a SQL extension for declarative data visualization. + +### Basic Syntax + +``` +SELECT