From e564ead8de51ef7f9e6c86909d8a0818bab94268 Mon Sep 17 00:00:00 2001 From: Carson Date: Tue, 27 Jan 2026 15:49:48 -0600 Subject: [PATCH 1/3] feat(pkg-py): add ggsql visualization infrastructure Add core support for LLM-generated visualizations using ggsql syntax: - Add ggsql and altair dependencies to pyproject.toml - Create _ggsql.py with helpers for parsing and rendering visualizations - Extend AppState with visualization state fields (filter_viz_*, query_viz_*) - Implement visualize_dashboard and visualize_query tools in tools.py - Add prompt templates for visualization tools with ggsql syntax reference - Update system prompt with ggsql grammar documentation - Add visualization accessor methods to QueryChatBase - Export visualization data types (VisualizeDashboardData, VisualizeQueryData) The ggsql DSL allows the LLM to generate chart specifications that are rendered to Altair/Vega-Lite charts, supporting bar, line, point, area, and boxplot marks with various encodings. Co-Authored-By: Claude Opus 4.5 --- pkg-py/src/querychat/__init__.py | 3 + pkg-py/src/querychat/_datasource.py | 1 + pkg-py/src/querychat/_ggsql.py | 74 ++++++ pkg-py/src/querychat/_icons.py | 11 +- pkg-py/src/querychat/_querychat_base.py | 92 ++++++- pkg-py/src/querychat/_querychat_core.py | 83 ++++++- pkg-py/src/querychat/_system_prompt.py | 4 + pkg-py/src/querychat/_utils.py | 12 +- pkg-py/src/querychat/prompts/prompt.md | 42 ++++ .../prompts/tool-visualize-dashboard.md | 102 ++++++++ .../querychat/prompts/tool-visualize-query.md | 103 ++++++++ pkg-py/src/querychat/tools.py | 235 ++++++++++++++++++ pkg-py/src/querychat/types/__init__.py | 4 +- pyproject.toml | 5 +- 14 files changed, 754 insertions(+), 17 deletions(-) create mode 100644 pkg-py/src/querychat/_ggsql.py create mode 100644 pkg-py/src/querychat/prompts/tool-visualize-dashboard.md create mode 100644 pkg-py/src/querychat/prompts/tool-visualize-query.md diff --git a/pkg-py/src/querychat/__init__.py b/pkg-py/src/querychat/__init__.py index 0e3eaa5f..00527afb 100644 --- a/pkg-py/src/querychat/__init__.py +++ b/pkg-py/src/querychat/__init__.py @@ -2,9 +2,12 @@ from ._deprecated import mod_server as server from ._deprecated import mod_ui as ui from ._shiny import QueryChat +from .tools import VisualizeDashboardData, VisualizeQueryData __all__ = ( "QueryChat", + "VisualizeDashboardData", + "VisualizeQueryData", # TODO(lifecycle): Remove these deprecated functions when we reach v1.0 "greeting", "init", diff --git a/pkg-py/src/querychat/_datasource.py b/pkg-py/src/querychat/_datasource.py index 5cac5f08..af7628f5 100644 --- a/pkg-py/src/querychat/_datasource.py +++ b/pkg-py/src/querychat/_datasource.py @@ -214,6 +214,7 @@ def __init__(self, df: nw.DataFrame, table_name: str): self._df_lib = native_namespace.__name__ self._conn = duckdb.connect(database=":memory:") + # NOTE: if native representation is polars, pyarrow is required for registration self._conn.register(table_name, self._df.to_native()) self._conn.execute(""" -- extensions: lock down supply chain + auto behaviors diff --git a/pkg-py/src/querychat/_ggsql.py b/pkg-py/src/querychat/_ggsql.py new file mode 100644 index 00000000..9c7bb2cc --- /dev/null +++ b/pkg-py/src/querychat/_ggsql.py @@ -0,0 +1,74 @@ +"""Helpers for ggsql integration.""" + +from __future__ import annotations + +import re + + +def extract_title(viz_spec: str) -> str | None: + """ + Extract the title from a VISUALISE spec's LABEL clause. + + Parameters + ---------- + viz_spec + The VISUALISE portion of a ggsql query. + + Returns + ------- + str | None + The title if found, otherwise None. + + """ + # Match LABEL title => 'value' or LABEL title => "value" + pattern = r"LABEL\s+title\s*=>\s*['\"]([^'\"]+)['\"]" + match = re.search(pattern, viz_spec, re.IGNORECASE) + if match: + return match.group(1) + return None + + +def vegalite_to_html(vegalite_spec: dict) -> str: + """ + Convert a Vega-Lite specification to standalone HTML. + + This renders the spec directly using vega-embed. + + Parameters + ---------- + vegalite_spec + A Vega-Lite specification as a dictionary. + + Returns + ------- + str + A complete HTML document that renders the chart. + + """ + import json + + spec_json = json.dumps(vegalite_spec) + + # ggsql produces v6 specs + vl_version = "6" + + return f""" + + + + + + + + + +
+ + +""" diff --git a/pkg-py/src/querychat/_icons.py b/pkg-py/src/querychat/_icons.py index 2b7683da..61880f83 100644 --- a/pkg-py/src/querychat/_icons.py +++ b/pkg-py/src/querychat/_icons.py @@ -2,7 +2,14 @@ from shiny import ui -ICON_NAMES = Literal["arrow-counterclockwise", "funnel-fill", "terminal-fill", "table"] +ICON_NAMES = Literal[ + "arrow-counterclockwise", + "bar-chart-fill", + "funnel-fill", + "graph-up", + "terminal-fill", + "table", +] def bs_icon(name: ICON_NAMES) -> ui.HTML: @@ -14,7 +21,9 @@ def bs_icon(name: ICON_NAMES) -> ui.HTML: BS_ICONS = { "arrow-counterclockwise": '', + "bar-chart-fill": '', "funnel-fill": '', + "graph-up": '', "terminal-fill": '', "table": '', } diff --git a/pkg-py/src/querychat/_querychat_base.py b/pkg-py/src/querychat/_querychat_base.py index e8a7c7f1..83856989 100644 --- a/pkg-py/src/querychat/_querychat_base.py +++ b/pkg-py/src/querychat/_querychat_base.py @@ -25,17 +25,22 @@ from ._utils import MISSING, MISSING_TYPE, is_ibis_table from .tools import ( UpdateDashboardData, + VisualizeDashboardData, + VisualizeQueryData, tool_query, tool_reset_dashboard, tool_update_dashboard, + tool_visualize_dashboard, + tool_visualize_query, ) if TYPE_CHECKING: from collections.abc import Callable + import altair as alt from narwhals.stable.v1.typing import IntoFrame -TOOL_GROUPS = Literal["update", "query"] +TOOL_GROUPS = Literal["update", "query", "visualize_dashboard", "visualize_query"] class QueryChatBase(Generic[IntoFrameT]): @@ -58,7 +63,12 @@ def __init__( *, greeting: Optional[str | Path] = None, client: Optional[str | chatlas.Chat] = None, - tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = ("update", "query"), + tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = ( + "update", + "query", + "visualize_dashboard", + "visualize_query", + ), data_description: Optional[str | Path] = None, categorical_threshold: int = 20, extra_instructions: Optional[str | Path] = None, @@ -72,7 +82,9 @@ def __init__( "Table name must begin with a letter and contain only letters, numbers, and underscores", ) - self.tools = normalize_tools(tools, default=("update", "query")) + self.tools = normalize_tools( + tools, default=("update", "query", "visualize_dashboard", "visualize_query") + ) self.greeting = greeting.read_text() if isinstance(greeting, Path) else greeting # Store init parameters for deferred system prompt building @@ -132,6 +144,8 @@ def client( tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None | MISSING_TYPE = MISSING, update_dashboard: Callable[[UpdateDashboardData], None] | None = None, reset_dashboard: Callable[[], None] | None = None, + visualize_dashboard: Callable[[VisualizeDashboardData], None] | None = None, + visualize_query: Callable[[VisualizeQueryData], None] | None = None, ) -> chatlas.Chat: """ Create a chat client with registered tools. @@ -139,11 +153,16 @@ def client( Parameters ---------- tools - Which tools to include: `"update"`, `"query"`, or both. + Which tools to include: `"update"`, `"query"`, `"visualize_dashboard"`, + `"visualize_query"`, or a combination. update_dashboard Callback when update_dashboard tool succeeds. reset_dashboard Callback when reset_dashboard tool is invoked. + visualize_dashboard + Callback when visualize_dashboard tool succeeds. + visualize_query + Callback when visualize_query tool succeeds. Returns ------- @@ -172,6 +191,14 @@ def client( if "query" in tools: chat.register_tool(tool_query(data_source)) + if "visualize_dashboard" in tools: + viz_fn = visualize_dashboard or (lambda _: None) + chat.register_tool(tool_visualize_dashboard(self._data_source, viz_fn)) + + if "visualize_query" in tools: + query_viz_fn = visualize_query or (lambda _: None) + chat.register_tool(tool_visualize_query(self._data_source, query_viz_fn)) + return chat def generate_greeting(self, *, echo: Literal["none", "output"] = "none") -> str: @@ -221,6 +248,63 @@ def cleanup(self) -> None: if self._data_source is not None: self._data_source.cleanup() + def ggvis(self, source: Literal["filter", "query"] = "filter") -> alt.Chart | None: + """ + Get the visualization chart. + + Parameters + ---------- + source + Which visualization to return: + - "filter": Chart from visualize_dashboard (updates with filter changes) + - "query": Chart from visualize_query (most recent inline visualization) + + Returns + ------- + alt.Chart | None + The Altair chart, or None if no visualization exists. + + """ + raise NotImplementedError("Subclasses must implement ggvis()") + + def ggsql(self, source: Literal["filter", "query"] = "filter") -> str | None: + """ + Get the ggsql specification. + + Parameters + ---------- + source + Which specification to return: + - "filter": VISUALISE spec only (from visualize_dashboard) + - "query": Full ggsql query (from visualize_query) + + Returns + ------- + str | None + The ggsql specification, or None if no visualization exists. + + """ + raise NotImplementedError("Subclasses must implement ggsql()") + + def ggtitle(self, source: Literal["filter", "query"] = "filter") -> str | None: + """ + Get the visualization title. + + Parameters + ---------- + source + Which title to return: + - "filter": Title from visualize_dashboard + - "query": Title from visualize_query + + Returns + ------- + str | None + The title, or None if no visualization exists. + + """ + raise NotImplementedError("Subclasses must implement ggtitle()") + def normalize_data_source( data_source: IntoFrame | sqlalchemy.Engine | DataSource, diff --git a/pkg-py/src/querychat/_querychat_core.py b/pkg-py/src/querychat/_querychat_core.py index af0685e0..a7553a3c 100644 --- a/pkg-py/src/querychat/_querychat_core.py +++ b/pkg-py/src/querychat/_querychat_core.py @@ -21,7 +21,7 @@ from chatlas.types import Content from narwhals.stable.v1.typing import IntoFrameT -from .tools import UpdateDashboardData +from .tools import UpdateDashboardData, VisualizeDashboardData, VisualizeQueryData GREETING_PROMPT: str = ( "Please give me a friendly greeting. " @@ -38,10 +38,15 @@ ClientFactory = Callable[ - [Callable[[UpdateDashboardData], None], Callable[[], None]], + [ + Callable[[UpdateDashboardData], None], + Callable[[], None], + Callable[[VisualizeDashboardData], None], + Callable[[VisualizeQueryData], None], + ], Chat, ] -"""Factory that creates a Chat client with update_dashboard and reset_dashboard callbacks.""" +"""Factory that creates a Chat client with dashboard and visualization callbacks.""" class AppStateDict(TypedDict): @@ -51,6 +56,11 @@ class AppStateDict(TypedDict): title: str | None error: str | None turns: list[dict] # Serialized chatlas Turns via model_dump() + # Visualization state - only specs stored, charts rendered on demand + filter_viz_spec: str | None + filter_viz_title: str | None + query_viz_ggsql: str | None + query_viz_title: str | None class DisplayMessage(TypedDict): @@ -69,9 +79,16 @@ def _client_factory( self, update_cb: Callable[[UpdateDashboardData], None], reset_cb: Callable[[], None], + filter_viz_cb: Callable[[VisualizeDashboardData], None], + query_viz_cb: Callable[[VisualizeQueryData], None], ) -> Chat: - """Create a chat client with dashboard callbacks.""" - return self.client(update_dashboard=update_cb, reset_dashboard=reset_cb) # type: ignore[attr-defined] + """Create a chat client with dashboard and visualization callbacks.""" + return self.client( # type: ignore[attr-defined] + update_dashboard=update_cb, + reset_dashboard=reset_cb, + visualize_dashboard=filter_viz_cb, + visualize_query=query_viz_cb, + ) def df(self, state: AppStateDict | None) -> IntoFrameT: """ @@ -204,6 +221,15 @@ class AppState: title: Optional[str] = None error: Optional[str] = None + # Filter visualization state (from visualize_dashboard tool) + # Only specs stored, charts rendered on demand via ggsql.render_altair() + filter_viz_spec: Optional[str] = None + filter_viz_title: Optional[str] = None + + # Query visualization state (from visualize_query tool) + query_viz_ggsql: Optional[str] = None + query_viz_title: Optional[str] = None + def update_dashboard(self, data: UpdateDashboardData) -> None: self.sql = data["query"] self.title = data["title"] @@ -213,6 +239,27 @@ def reset_dashboard(self) -> None: self.sql = None self.title = None self.error = None + # Also clear filter visualization + self.filter_viz_spec = None + self.filter_viz_title = None + + def update_filter_viz( + self, + spec: str, + title: Optional[str], + ) -> None: + """Update filter visualization state.""" + self.filter_viz_spec = spec + self.filter_viz_title = title + + def update_query_viz( + self, + ggsql: str, + title: Optional[str], + ) -> None: + """Update query visualization state.""" + self.query_viz_ggsql = ggsql + self.query_viz_title = title def get_current_data(self) -> IntoFrame: """Get current data, falling back to default if query fails.""" @@ -281,6 +328,10 @@ def to_dict(self) -> AppStateDict: "title": self.title, "error": self.error, "turns": [turn.model_dump() for turn in self.client.get_turns()], + "filter_viz_spec": self.filter_viz_spec, + "filter_viz_title": self.filter_viz_title, + "query_viz_ggsql": self.query_viz_ggsql, + "query_viz_title": self.query_viz_title, } def update_from_dict(self, data: AppStateDict) -> None: @@ -295,6 +346,12 @@ def update_from_dict(self, data: AppStateDict) -> None: turns = [Turn.model_validate(t) for t in turns_data] self.client.set_turns(turns) + # Restore visualization state + self.filter_viz_spec = data.get("filter_viz_spec") + self.filter_viz_title = data.get("filter_viz_title") + self.query_viz_ggsql = data.get("query_viz_ggsql") + self.query_viz_title = data.get("query_viz_title") + def create_app_state( data_source: DataSource, @@ -316,7 +373,21 @@ def reset_callback() -> None: raise RuntimeError("Callback invoked before state initialization") state.reset_dashboard() - client = client_factory(update_callback, reset_callback) + def filter_viz_callback(data: VisualizeDashboardData) -> None: + state = state_holder["state"] + if state is None: + raise RuntimeError("Callback invoked before state initialization") + state.update_filter_viz(data["spec"], data["title"]) + + def query_viz_callback(data: VisualizeQueryData) -> None: + state = state_holder["state"] + if state is None: + raise RuntimeError("Callback invoked before state initialization") + state.update_query_viz(data["ggsql"], data["title"]) + + client = client_factory( + update_callback, reset_callback, filter_viz_callback, query_viz_callback + ) state = AppState( data_source=data_source, client=client, diff --git a/pkg-py/src/querychat/_system_prompt.py b/pkg-py/src/querychat/_system_prompt.py index 7b4f737a..bf9774f9 100644 --- a/pkg-py/src/querychat/_system_prompt.py +++ b/pkg-py/src/querychat/_system_prompt.py @@ -77,6 +77,10 @@ def render(self, tools: tuple[TOOL_GROUPS, ...] | None) -> str: "extra_instructions": self.extra_instructions, "has_tool_update": "update" in tools if tools else False, "has_tool_query": "query" in tools if tools else False, + "has_tool_visualize_dashboard": "visualize_dashboard" in tools + if tools + else False, + "has_tool_visualize_query": "visualize_query" in tools if tools else False, "include_query_guidelines": len(tools or ()) > 0, } diff --git a/pkg-py/src/querychat/_utils.py b/pkg-py/src/querychat/_utils.py index 555e8e37..ce2c0197 100644 --- a/pkg-py/src/querychat/_utils.py +++ b/pkg-py/src/querychat/_utils.py @@ -171,14 +171,18 @@ def get_tool_details_setting() -> Optional[Literal["expanded", "collapsed", "def return setting_lower -def querychat_tool_starts_open(action: Literal["update", "query", "reset"]) -> bool: +def querychat_tool_starts_open( + action: Literal[ + "update", "query", "reset", "visualize_dashboard", "visualize_query" + ], +) -> bool: """ Determine whether a tool card should be open based on action and setting. Parameters ---------- action : str - The action type ('update', 'query', or 'reset') + The action type ('update', 'query', 'reset', 'visualize_dashboard', or 'visualize_query') Returns ------- @@ -189,14 +193,14 @@ def querychat_tool_starts_open(action: Literal["update", "query", "reset"]) -> b setting = get_tool_details_setting() if setting is None: - return action != "reset" + return action in ("query", "update", "visualize_dashboard", "visualize_query") if setting == "expanded": return True elif setting == "collapsed": return False else: # setting == "default" - return action != "reset" + return action in ("query", "update", "visualize_dashboard", "visualize_query") def is_ibis_table(obj: Any) -> TypeGuard[ibis.Table]: diff --git a/pkg-py/src/querychat/prompts/prompt.md b/pkg-py/src/querychat/prompts/prompt.md index 8c6ff97b..c5315102 100644 --- a/pkg-py/src/querychat/prompts/prompt.md +++ b/pkg-py/src/querychat/prompts/prompt.md @@ -180,6 +180,48 @@ You might want to explore the advanced features - Never use generic phrases like "If you'd like to..." or "Would you like to explore..." — instead, provide concrete suggestions - Never refer to suggestions as "prompts" – call them "suggestions" or "ideas" or similar +{{#has_tool_visualize_dashboard}} +## Visualization with ggsql + +You can create visualizations using the `visualize_dashboard` and `visualize_query` tools. These use ggsql, a SQL extension for declarative data visualization. + +### Basic Syntax + +``` +SELECT FROM +VISUALISE AS x, AS y [, AS color] +DRAW +[LABEL title => 'Title'] +``` + +### Available Geoms +- `point` - scatter plot +- `line` - line chart +- `bar` - bar chart +- `area` - area chart +- `histogram` - histogram (single variable) +- `boxplot` - box plot + +### When to Use Each Tool + +- **visualize_dashboard**: Creates a persistent chart that updates when filters change. Use for dashboards. +- **visualize_query**: Creates a one-off chart from a specific SQL query. Use for exploratory analysis. + +### Examples + +Scatter plot: +``` +VISUALISE mpg AS x, hp AS y DRAW point +``` + +Time series: +```sql +SELECT date, revenue FROM sales +VISUALISE date AS x, revenue AS y DRAW line +LABEL title => 'Revenue Trend' +``` +{{/has_tool_visualize_dashboard}} + ## Important Guidelines - **Ask for clarification** if any request is unclear or ambiguous diff --git a/pkg-py/src/querychat/prompts/tool-visualize-dashboard.md b/pkg-py/src/querychat/prompts/tool-visualize-dashboard.md new file mode 100644 index 00000000..ecfa9bc5 --- /dev/null +++ b/pkg-py/src/querychat/prompts/tool-visualize-dashboard.md @@ -0,0 +1,102 @@ +Create or update a persistent visualization for the dashboard's Filter Plot tab. + +## Input Format + +Provide a VISUALISE-only specification (no SELECT statement). The visualization will be applied to the current filtered/sorted data. + +## ggsql VISUALISE Syntax + +``` +VISUALISE +[DRAW ] +[SCALE ] +[FACET ] +[COORD ] +[LABEL ] +[THEME ] +``` + +### Mappings +- Basic: `VISUALISE x, y` (column names map to x/y aesthetics) +- Named: `VISUALISE date AS x, revenue AS y, region AS color` +- With aggregation: `VISUALISE category AS x, SUM(amount) AS y` + +### DRAW Geoms +- `point` - scatter plot +- `line` - line chart +- `bar` - bar chart +- `area` - area chart +- `histogram` - histogram (use `DRAW histogram` with single variable) +- `boxplot` - box plot +- `text` - text labels + +### SCALE Configuration +- Type: `SCALE x TYPE log`, `SCALE y TYPE sqrt` +- Palette: `SCALE color PALETTE 'viridis'`, `SCALE color PALETTE 'category10'` +- Domain: `SCALE x DOMAIN [0, 100]` + +### FACET Configuration +- Wrap: `FACET WRAP region` +- Grid: `FACET GRID row => year, col => quarter` + +### COORD Configuration +- Flip: `COORD flip` +- Polar: `COORD polar` + +### LABEL Configuration +- Title: `LABEL title => 'Sales by Region'` +- Axes: `LABEL x => 'Date', y => 'Revenue ($)'` +- Caption: `LABEL caption => 'Data source: internal sales'` + +### THEME Configuration +- Built-in: `THEME dark`, `THEME minimal` + +## Examples + +**Simple scatter plot:** +``` +VISUALISE mpg AS x, hp AS y +DRAW point +LABEL title => 'MPG vs Horsepower' +``` + +**Bar chart with color:** +``` +VISUALISE category AS x, COUNT(*) AS y, category AS color +DRAW bar +LABEL title => 'Count by Category' +``` + +**Time series:** +``` +VISUALISE date AS x, revenue AS y +DRAW line +LABEL title => 'Revenue Over Time', x => 'Date', y => 'Revenue' +``` + +**Faceted histogram:** +``` +VISUALISE age +DRAW histogram +FACET WRAP gender +LABEL title => 'Age Distribution by Gender' +``` + +## Behavior + +- The visualization is applied to the **current filtered data** (after any `update_dashboard` filters) +- When filters change, the visualization automatically re-renders with the new data +- The chart appears in the Filter Plot tab +- Calling this tool again replaces the previous filter visualization + +Parameters +---------- +viz_spec : + A ggsql VISUALISE specification (without SELECT). Must include at least a VISUALISE clause with column mappings. Optional clauses: DRAW, SCALE, FACET, COORD, LABEL, THEME. +title : + A brief, user-friendly title for this visualization. + +Returns +------- +: + A confirmation that the dashboard visualization was updated successfully, or the error that occurred. The visualization will appear in the Filter Plot tab and will automatically update when filters change. diff --git a/pkg-py/src/querychat/prompts/tool-visualize-query.md b/pkg-py/src/querychat/prompts/tool-visualize-query.md new file mode 100644 index 00000000..95c429fb --- /dev/null +++ b/pkg-py/src/querychat/prompts/tool-visualize-query.md @@ -0,0 +1,103 @@ +Run an exploratory visualization query inline in the chat. + +## Input Format + +Provide a full ggsql query with both SELECT and VISUALISE clauses: + +```sql +SELECT +FROM
+[WHERE ] +[GROUP BY ] +VISUALISE +[DRAW ] +[SCALE ] +[FACET ] +[LABEL ] +``` + +## When to Use + +Use this tool when: +- The user asks an exploratory question that benefits from visualization +- You want to show a one-off chart without affecting the dashboard filter +- You need to visualize data with specific SQL transformations + +Use `visualize_dashboard` instead when: +- The user wants a persistent chart that updates with filter changes +- The visualization should reflect the current filtered data + +## ggsql VISUALISE Syntax + +See `visualize_dashboard` tool for full syntax reference. + +### Quick Reference + +**Mappings:** `VISUALISE col1 AS x, col2 AS y, col3 AS color` + +**Geoms:** `point`, `line`, `bar`, `area`, `histogram`, `boxplot`, `text` + +**Labels:** `LABEL title => 'Title', x => 'X Label', y => 'Y Label'` + +## Examples + +**Aggregated bar chart:** +```sql +SELECT region, SUM(sales) as total_sales +FROM orders +GROUP BY region +VISUALISE region AS x, total_sales AS y +DRAW bar +LABEL title => 'Total Sales by Region' +``` + +**Filtered time series:** +```sql +SELECT date, revenue +FROM sales +WHERE year = 2024 +VISUALISE date AS x, revenue AS y +DRAW line +LABEL title => '2024 Revenue Trend' +``` + +**Correlation scatter with subset:** +```sql +SELECT mpg, horsepower, cylinders +FROM cars +WHERE cylinders IN (4, 6, 8) +VISUALISE mpg AS x, horsepower AS y, cylinders AS color +DRAW point +LABEL title => 'MPG vs HP by Cylinder Count' +``` + +**Distribution comparison:** +```sql +SELECT age, gender +FROM users +WHERE age BETWEEN 18 AND 65 +VISUALISE age +DRAW histogram +FACET WRAP gender +LABEL title => 'Age Distribution by Gender' +``` + +## Behavior + +- Executes the SQL query against the data source +- Renders the visualization inline in the chat +- The chart is also accessible via the Query Plot tab +- Does NOT affect the dashboard filter or filtered data +- Each call replaces the previous query visualization + +Parameters +---------- +ggsql : + A full ggsql query with SELECT and VISUALISE clauses. The SELECT portion follows standard {{db_type}} SQL syntax. The VISUALISE portion specifies the chart configuration. +title : + A brief, user-friendly title for this visualization. + +Returns +------- +: + The visualization rendered inline in the chat, or the error that occurred. The chart will also be accessible in the Query Plot tab. Does not affect the dashboard filter state. diff --git a/pkg-py/src/querychat/tools.py b/pkg-py/src/querychat/tools.py index 67ea453f..9fc1cc96 100644 --- a/pkg-py/src/querychat/tools.py +++ b/pkg-py/src/querychat/tools.py @@ -69,6 +69,48 @@ def log_update(data: UpdateDashboardData): title: str +class VisualizeDashboardData(TypedDict): + """ + Data passed to visualize_dashboard callback. + + This TypedDict defines the structure of data passed to the + `tool_visualize_dashboard` callback function when the LLM creates a + persistent visualization for the dashboard. + + Attributes + ---------- + spec + The ggsql VISUALISE specification string. + title + A descriptive title for the visualization, or None if not provided. + + """ + + spec: str + title: str | None + + +class VisualizeQueryData(TypedDict): + """ + Data passed to visualize_query callback. + + This TypedDict defines the structure of data passed to the + `tool_visualize_query` callback function when the LLM creates an + exploratory visualization from a ggsql query. + + Attributes + ---------- + ggsql + The full ggsql query string (SQL + VISUALISE). + title + A descriptive title for the visualization, or None if not provided. + + """ + + ggsql: str + title: str | None + + def _read_prompt_template(filename: str, **kwargs) -> str: """Read and interpolate a prompt template file.""" template_path = Path(__file__).parent / "prompts" / filename @@ -286,3 +328,196 @@ def tool_query(data_source: DataSource) -> Tool: name="querychat_query", annotations={"title": "Query Data"}, ) + + +def _visualize_dashboard_impl( + data_source: DataSource, + update_fn: Callable[[VisualizeDashboardData], None], +) -> Callable[[str, str | None], ContentToolResult]: + """Create the visualize_dashboard implementation function.""" + import ggsql + + from ._ggsql import extract_title + + def visualize_dashboard( + viz_spec: str, + title: str | None = None, + ) -> ContentToolResult: + """Create a dashboard visualization from a VISUALISE spec.""" + markdown = f"```ggsql\n{viz_spec}\n```" + + try: + # Validate the spec by rendering it (will raise on error) + df = data_source.get_data() + ggsql.render_altair(df, viz_spec) + + # Extract title from spec if not provided + if title is None: + title = extract_title(viz_spec) + + # Store just the spec - rendering happens on display + update_fn( + { + "spec": viz_spec, + "title": title, + } + ) + + # Format success message + title_display = f" - {title}" if title else "" + markdown += f"\n\nVisualization created{title_display}." + + return ContentToolResult( + value=markdown, + extra={ + "display": ToolResultDisplay( + markdown=markdown, + title=title or "Filter Visualization", + show_request=False, + open=querychat_tool_starts_open("visualize_dashboard"), + icon=bs_icon("bar-chart-fill"), + ), + }, + ) + + except Exception as e: + error_msg = str(e) + markdown += f"\n\n> Error: {error_msg}" + return ContentToolResult(value=markdown, error=e) + + return visualize_dashboard + + +def tool_visualize_dashboard( + data_source: DataSource, + update_fn: Callable[[VisualizeDashboardData], None], +) -> Tool: + """ + Create a tool that creates a persistent visualization for the dashboard. + + Parameters + ---------- + data_source + The data source to visualize + update_fn + Callback function to call with VisualizeDashboardData when visualization succeeds + + Returns + ------- + Tool + A tool that can be registered with chatlas + + """ + impl = _visualize_dashboard_impl(data_source, update_fn) + impl.__doc__ = _read_prompt_template("tool-visualize-dashboard.md") + + return Tool.from_func( + impl, + name="querychat_visualize_dashboard", + annotations={"title": "Create Filter Visualization"}, + ) + + +def _visualize_query_impl( + data_source: DataSource, + update_fn: Callable[[VisualizeQueryData], None], +) -> Callable[[str, str | None], ContentToolResult]: + """Create the visualize_query implementation function.""" + import ggsql as ggsql_pkg + + from ._ggsql import extract_title + + def visualize_query( + ggsql: str, + title: str | None = None, + ) -> ContentToolResult: + """Execute a ggsql query and render the visualization.""" + markdown = f"```sql\n{ggsql}\n```" + + try: + # Split the query + sql, viz_spec = ggsql_pkg.split_query(ggsql) + + if not viz_spec: + raise ValueError( + "Query must include a VISUALISE clause. " + "Use querychat_query for queries without visualization." + ) + + # Execute the SQL and validate by rendering + df = data_source.execute_query(sql) + ggsql_pkg.render_altair(df, viz_spec) + + # Extract title from spec if not provided + if title is None: + title = extract_title(viz_spec) + + # Store just the ggsql - rendering happens on display + update_fn( + { + "ggsql": ggsql, + "title": title, + } + ) + + # Format success message with data summary + nw_df = as_narwhals(df) + row_count = len(nw_df) + col_count = len(nw_df.columns) + + title_display = f" - {title}" if title else "" + markdown += f"\n\nVisualization created{title_display}." + markdown += f"\n\nData: {row_count} rows, {col_count} columns." + + return ContentToolResult( + value=markdown, + extra={ + "display": ToolResultDisplay( + markdown=markdown, + title=title or "Query Visualization", + show_request=False, + open=querychat_tool_starts_open("visualize_query"), + icon=bs_icon("graph-up"), + ), + }, + ) + + except Exception as e: + error_msg = str(e) + markdown += f"\n\n> Error: {error_msg}" + return ContentToolResult(value=markdown, error=e) + + return visualize_query + + +def tool_visualize_query( + data_source: DataSource, + update_fn: Callable[[VisualizeQueryData], None], +) -> Tool: + """ + Create a tool that executes a ggsql query and renders the visualization. + + Parameters + ---------- + data_source + The data source to query against + update_fn + Callback function to call with VisualizeQueryData when visualization succeeds + + Returns + ------- + Tool + A tool that can be registered with chatlas + + """ + impl = _visualize_query_impl(data_source, update_fn) + impl.__doc__ = _read_prompt_template( + "tool-visualize-query.md", + db_type=data_source.get_db_type(), + ) + + return Tool.from_func( + impl, + name="querychat_visualize_query", + annotations={"title": "Query Visualization"}, + ) diff --git a/pkg-py/src/querychat/types/__init__.py b/pkg-py/src/querychat/types/__init__.py index f9a8163d..7498b8e4 100644 --- a/pkg-py/src/querychat/types/__init__.py +++ b/pkg-py/src/querychat/types/__init__.py @@ -9,7 +9,7 @@ from .._querychat_core import AppStateDict from .._shiny_module import ServerValues from .._utils import UnsafeQueryError -from ..tools import UpdateDashboardData +from ..tools import UpdateDashboardData, VisualizeDashboardData, VisualizeQueryData __all__ = ( "AppStateDict", @@ -22,4 +22,6 @@ "ServerValues", "UnsafeQueryError", "UpdateDashboardData", + "VisualizeDashboardData", + "VisualizeQueryData", ) diff --git a/pyproject.toml b/pyproject.toml index dcfa6491..57a9b755 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,12 +23,15 @@ dependencies = [ "duckdb", "shiny>=1.5.1", "shinychat>=0.2.8", + "shinywidgets>=0.3.0", "htmltools", "chatlas>=0.13.2", "narwhals", "chevron", "sqlalchemy>=2.0.0", # Using 2.0+ for improved type hints and API "great-tables>=0.16.0", + "ggsql @ {root:uri}/../ggsql/ggsql-python", + "altair>=5.0", ] classifiers = [ "Programming Language :: Python", @@ -42,7 +45,7 @@ classifiers = [ [project.optional-dependencies] # For SQLAlchemySource and sample data, one of polars or pandas is required pandas = ["pandas"] -polars = ["polars"] +polars = ["polars", "pyarrow"] # duckdb requires pyarrow for polars DataFrame registration ibis = ["ibis-framework>=9.0.0", "pandas"] # pandas required for ibis .execute() to return DataFrames # Web framework extras streamlit = ["streamlit>=1.30"] From 4865e706fdfe72c99d6c68d62b7bd4f824ce4e9f Mon Sep 17 00:00:00 2001 From: Carson Date: Tue, 27 Jan 2026 15:50:02 -0600 Subject: [PATCH 2/3] feat(pkg-py): add visualization support to web frameworks Integrate ggsql visualization capabilities into all supported frameworks: Shiny: - Add visualization state to ServerValues dataclass - Implement ggvis(), ggsql(), ggtitle() reactive accessors - Add filter visualization re-rendering on data change - Create tabbed UI with Data/Filter Plot/Query Plot tabs Streamlit: - Add ggvis(), ggsql(), ggtitle() methods reading from session state - Create tabbed app layout with visualization tabs - Render Altair charts with expandable ggsql specs Gradio: - Add ggvis(), ggsql(), ggtitle() methods taking state dict - Create tabbed Blocks layout with visualization displays - Wire state changes to update all visualization outputs Dash: - Add visualization callbacks and state management - Create tabbed layout with dcc.Graph for Altair charts - Add ggsql spec display in collapsible sections All frameworks enable visualization tools by default and support both filter (dashboard) and query visualizations. Co-Authored-By: Claude Opus 4.5 --- pkg-py/src/querychat/_dash.py | 228 ++++++++++++++++++++++- pkg-py/src/querychat/_dash_ui.py | 14 ++ pkg-py/src/querychat/_gradio.py | 249 ++++++++++++++++++++++---- pkg-py/src/querychat/_shiny.py | 194 ++++++++++++++++++-- pkg-py/src/querychat/_shiny_module.py | 108 ++++++++++- pkg-py/src/querychat/_streamlit.py | 197 ++++++++++++++++++-- 6 files changed, 924 insertions(+), 66 deletions(-) diff --git a/pkg-py/src/querychat/_dash.py b/pkg-py/src/querychat/_dash.py index 16d612dc..97d53091 100644 --- a/pkg-py/src/querychat/_dash.py +++ b/pkg-py/src/querychat/_dash.py @@ -24,6 +24,7 @@ from collections.abc import Callable from pathlib import Path as PathType + import altair as alt import chatlas import ibis import narwhals.stable.v1 as nw @@ -113,7 +114,12 @@ def __init__( *, greeting: Optional[str | PathType] = None, client: Optional[str | chatlas.Chat] = None, - tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = ("update", "query"), + tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = ( + "update", + "query", + "visualize_dashboard", + "visualize_query", + ), data_description: Optional[str | PathType] = None, categorical_threshold: int = 20, extra_instructions: Optional[str | PathType] = None, @@ -129,7 +135,12 @@ def __init__( *, greeting: Optional[str | PathType] = None, client: Optional[str | chatlas.Chat] = None, - tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = ("update", "query"), + tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = ( + "update", + "query", + "visualize_dashboard", + "visualize_query", + ), data_description: Optional[str | PathType] = None, categorical_threshold: int = 20, extra_instructions: Optional[str | PathType] = None, @@ -145,7 +156,12 @@ def __init__( *, greeting: Optional[str | PathType] = None, client: Optional[str | chatlas.Chat] = None, - tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = ("update", "query"), + tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = ( + "update", + "query", + "visualize_dashboard", + "visualize_query", + ), data_description: Optional[str | PathType] = None, categorical_threshold: int = 20, extra_instructions: Optional[str | PathType] = None, @@ -161,7 +177,12 @@ def __init__( *, greeting: Optional[str | PathType] = None, client: Optional[str | chatlas.Chat] = None, - tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = ("update", "query"), + tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = ( + "update", + "query", + "visualize_dashboard", + "visualize_query", + ), data_description: Optional[str | PathType] = None, categorical_threshold: int = 20, extra_instructions: Optional[str | PathType] = None, @@ -176,7 +197,12 @@ def __init__( *, greeting: Optional[str | PathType] = None, client: Optional[str | chatlas.Chat] = None, - tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = ("update", "query"), + tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = ( + "update", + "query", + "visualize_dashboard", + "visualize_query", + ), data_description: Optional[str | PathType] = None, categorical_threshold: int = 20, extra_instructions: Optional[str | PathType] = None, @@ -207,6 +233,96 @@ def store_id(self) -> str: """ return self._ids.store + def ggvis( + self, state: AppStateDict, source: Literal["filter", "query"] = "filter" + ) -> alt.Chart | None: + """ + Get the current Altair visualization chart. + + Parameters + ---------- + state + The state dict from the store component. + source + Which visualization to return. "filter" returns the dashboard + visualization (from visualize_dashboard tool), "query" returns + the query visualization (from visualize_query tool). + + Returns + ------- + : + An Altair Chart object, or None if no visualization exists. + + """ + import ggsql + + if source == "filter": + spec = state.get("filter_viz_spec") + if spec is None: + return None + # Render against current filtered data + df = as_narwhals(self.df(state)) + return ggsql.render_altair(df, spec) + else: + ggsql_query = state.get("query_viz_ggsql") + if ggsql_query is None: + return None + # Re-execute SQL and render + sql, viz_spec = ggsql.split_query(ggsql_query) + df = as_narwhals(self._data_source.execute_query(sql)) + return ggsql.render_altair(df, viz_spec) + + def ggsql( + self, state: AppStateDict, source: Literal["filter", "query"] = "filter" + ) -> str | None: + """ + Get the current ggsql specification. + + Parameters + ---------- + state + The state dict from the store component. + source + Which specification to return. "filter" returns the VISUALISE spec + from visualize_dashboard, "query" returns the full ggsql query + from visualize_query. + + Returns + ------- + : + The ggsql specification string, or None if no visualization exists. + + """ + if source == "filter": + return state.get("filter_viz_spec") + else: + return state.get("query_viz_ggsql") + + def ggtitle( + self, state: AppStateDict, source: Literal["filter", "query"] = "filter" + ) -> str | None: + """ + Get the current visualization title. + + Parameters + ---------- + state + The state dict from the store component. + source + Which title to return. "filter" returns the title from + visualize_dashboard, "query" returns the title from visualize_query. + + Returns + ------- + : + The visualization title, or None if no visualization exists. + + """ + if source == "filter": + return state.get("filter_viz_title") + else: + return state.get("query_viz_title") + def app(self) -> dash.Dash: """ Create a complete Dash app. @@ -393,6 +509,34 @@ def app_layout(ids: IDs, table_name: str, chat_ui): body_class_name="d-flex flex-column", ) + # Filter plot card + filter_plot_card = card_ui( + body=[ + html.Iframe( + id=ids.filter_plot, + srcDoc="

No filter visualization yet. Ask the assistant to create one.

", + style={"width": "100%", "height": "400px", "border": "none"}, + ), + dcc.Markdown(id=ids.filter_ggsql, className="querychat-ggsql-display mt-2"), + ], + title_id=ids.filter_plot_title, + class_name="h-100", + ) + + # Query plot card + query_plot_card = card_ui( + body=[ + html.Iframe( + id=ids.query_plot, + srcDoc="

No query visualization yet. Ask the assistant to create one.

", + style={"width": "100%", "height": "400px", "border": "none"}, + ), + dcc.Markdown(id=ids.query_ggsql, className="querychat-ggsql-display mt-2"), + ], + title_id=ids.query_plot_title, + class_name="h-100", + ) + chat_card = card_ui( body=chat_ui, title="Chat", @@ -407,7 +551,27 @@ def app_layout(ids: IDs, table_name: str, chat_ui): [ dbc.Col(chat_card, width=4), dbc.Col( - [sql_card, data_card], + dbc.Tabs( + [ + dbc.Tab( + [sql_card, data_card], + label="Data", + tab_id="data-tab", + ), + dbc.Tab( + filter_plot_card, + label="Filter Plot", + tab_id="filter-plot-tab", + ), + dbc.Tab( + query_plot_card, + label="Query Plot", + tab_id="query-plot-tab", + ), + ], + id=ids.tabs, + active_tab="data-tab", + ), width=8, className="d-flex flex-column", ), @@ -426,12 +590,15 @@ def register_app_callbacks( table_name: str, deserialize_state: Callable[[AppStateDict], AppState], ) -> None: - """Register callbacks for SQL display, data table, and export.""" + """Register callbacks for SQL display, data table, visualizations, and export.""" + import ggsql from dash.dcc.express import send_data_frame import dash from dash import Input, Output, State + from ._ggsql import vegalite_to_html + @app.callback( [ Output(ids.sql_title, "children"), @@ -440,6 +607,12 @@ def register_app_callbacks( Output(ids.data_table, "columnDefs"), Output(ids.data_info, "children"), Output(ids.store, "data", allow_duplicate=True), + Output(ids.filter_plot_title, "children"), + Output(ids.filter_plot, "srcDoc"), + Output(ids.filter_ggsql, "children"), + Output(ids.query_plot_title, "children"), + Output(ids.query_plot, "srcDoc"), + Output(ids.query_ggsql, "children"), ], [ Input(ids.store, "data"), @@ -459,7 +632,8 @@ def update_display(state_data: AppStateDict, reset_clicks): sql_title = state.title or "SQL Query" sql_code = f"```sql\n{state.get_display_sql()}\n```" - nw_df = as_narwhals(state.get_current_data()) + current_data = state.get_current_data() + nw_df = as_narwhals(current_data) nrow, ncol = nw_df.shape display_df = nw_df.to_pandas() @@ -472,6 +646,38 @@ def update_display(state_data: AppStateDict, reset_clicks): data_info_parts.append(f"Data has {nrow} rows and {ncol} columns.") data_info = " ".join(data_info_parts) + # Filter visualization - render on demand + filter_title = state.filter_viz_title or "Filter Plot" + filter_spec = state.filter_viz_spec + + if filter_spec: + # Render against current filtered data + chart = ggsql.render_altair(nw_df, filter_spec) + filter_html = vegalite_to_html(chart.to_dict()) + else: + filter_html = ( + "

No filter visualization yet. Ask the assistant to create one.

" + ) + + filter_ggsql_md = f"```sql\n{filter_spec}\n```" if filter_spec else "" + + # Query visualization - render on demand + query_title = state.query_viz_title or "Query Plot" + query_ggsql_str = state.query_viz_ggsql + + if query_ggsql_str: + # Re-execute SQL and render + sql_part, viz_spec = ggsql.split_query(query_ggsql_str) + query_df = as_narwhals(state.data_source.execute_query(sql_part)) + chart = ggsql.render_altair(query_df, viz_spec) + query_html = vegalite_to_html(chart.to_dict()) + else: + query_html = ( + "

No query visualization yet. Ask the assistant to create one.

" + ) + + query_ggsql_md = f"```sql\n{query_ggsql_str}\n```" if query_ggsql_str else "" + return ( sql_title, sql_code, @@ -479,6 +685,12 @@ def update_display(state_data: AppStateDict, reset_clicks): table_columns, data_info, state.to_dict(), + filter_title, + filter_html, + filter_ggsql_md, + query_title, + query_html, + query_ggsql_md, ) @app.callback( diff --git a/pkg-py/src/querychat/_dash_ui.py b/pkg-py/src/querychat/_dash_ui.py index f6ba64ca..f838aaeb 100644 --- a/pkg-py/src/querychat/_dash_ui.py +++ b/pkg-py/src/querychat/_dash_ui.py @@ -28,6 +28,13 @@ class IDs: data_info: str download_csv: str export_button: str + tabs: str + filter_plot: str + filter_plot_title: str + filter_ggsql: str + query_plot: str + query_plot_title: str + query_ggsql: str @classmethod def from_table_name(cls, table_name: str) -> IDs: @@ -46,6 +53,13 @@ def from_table_name(cls, table_name: str) -> IDs: data_info=f"{prefix}-data-info", download_csv=f"{prefix}-download-csv", export_button=f"{prefix}-export-button", + tabs=f"{prefix}-tabs", + filter_plot=f"{prefix}-filter-plot", + filter_plot_title=f"{prefix}-filter-plot-title", + filter_ggsql=f"{prefix}-filter-ggsql", + query_plot=f"{prefix}-query-plot", + query_plot_title=f"{prefix}-query-plot-title", + query_ggsql=f"{prefix}-query-ggsql", ) diff --git a/pkg-py/src/querychat/_gradio.py b/pkg-py/src/querychat/_gradio.py index c0f3518d..374c375d 100644 --- a/pkg-py/src/querychat/_gradio.py +++ b/pkg-py/src/querychat/_gradio.py @@ -2,7 +2,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Optional, overload +from typing import TYPE_CHECKING, Any, Literal, Optional, overload from gradio.context import Context from narwhals.stable.v1.typing import IntoDataFrameT, IntoFrameT, IntoLazyFrameT @@ -24,6 +24,7 @@ if TYPE_CHECKING: from pathlib import Path + import altair as alt import chatlas import ibis import sqlalchemy @@ -106,7 +107,12 @@ def __init__( *, greeting: Optional[str | Path] = None, client: Optional[str | chatlas.Chat] = None, - tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = ("update", "query"), + tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = ( + "update", + "query", + "visualize_dashboard", + "visualize_query", + ), data_description: Optional[str | Path] = None, categorical_threshold: int = 20, extra_instructions: Optional[str | Path] = None, @@ -121,7 +127,12 @@ def __init__( *, greeting: Optional[str | Path] = None, client: Optional[str | chatlas.Chat] = None, - tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = ("update", "query"), + tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = ( + "update", + "query", + "visualize_dashboard", + "visualize_query", + ), data_description: Optional[str | Path] = None, categorical_threshold: int = 20, extra_instructions: Optional[str | Path] = None, @@ -136,7 +147,12 @@ def __init__( *, greeting: Optional[str | Path] = None, client: Optional[str | chatlas.Chat] = None, - tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = ("update", "query"), + tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = ( + "update", + "query", + "visualize_dashboard", + "visualize_query", + ), data_description: Optional[str | Path] = None, categorical_threshold: int = 20, extra_instructions: Optional[str | Path] = None, @@ -151,7 +167,12 @@ def __init__( *, greeting: Optional[str | Path] = None, client: Optional[str | chatlas.Chat] = None, - tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = ("update", "query"), + tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = ( + "update", + "query", + "visualize_dashboard", + "visualize_query", + ), data_description: Optional[str | Path] = None, categorical_threshold: int = 20, extra_instructions: Optional[str | Path] = None, @@ -165,7 +186,12 @@ def __init__( *, greeting: Optional[str | Path] = None, client: Optional[str | chatlas.Chat] = None, - tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = ("update", "query"), + tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = ( + "update", + "query", + "visualize_dashboard", + "visualize_query", + ), data_description: Optional[str | Path] = None, categorical_threshold: int = 20, extra_instructions: Optional[str | Path] = None, @@ -203,6 +229,96 @@ def head(self) -> str: """ return f"" + def ggvis( + self, state: AppStateDict, source: Literal["filter", "query"] = "filter" + ) -> alt.Chart | None: + """ + Get the current Altair visualization chart. + + Parameters + ---------- + state + The state dict from the state component. + source + Which visualization to return. "filter" returns the dashboard + visualization (from visualize_dashboard tool), "query" returns + the query visualization (from visualize_query tool). + + Returns + ------- + : + An Altair Chart object, or None if no visualization exists. + + """ + import ggsql + + if source == "filter": + spec = state.get("filter_viz_spec") + if spec is None: + return None + # Render against current filtered data + df = as_narwhals(self.df(state)) + return ggsql.render_altair(df, spec) + else: + ggsql_query = state.get("query_viz_ggsql") + if ggsql_query is None: + return None + # Re-execute SQL and render + sql, viz_spec = ggsql.split_query(ggsql_query) + df = as_narwhals(self._data_source.execute_query(sql)) + return ggsql.render_altair(df, viz_spec) + + def ggsql( + self, state: AppStateDict, source: Literal["filter", "query"] = "filter" + ) -> str | None: + """ + Get the current ggsql specification. + + Parameters + ---------- + state + The state dict from the state component. + source + Which specification to return. "filter" returns the VISUALISE spec + from visualize_dashboard, "query" returns the full ggsql query + from visualize_query. + + Returns + ------- + : + The ggsql specification string, or None if no visualization exists. + + """ + if source == "filter": + return state.get("filter_viz_spec") + else: + return state.get("query_viz_ggsql") + + def ggtitle( + self, state: AppStateDict, source: Literal["filter", "query"] = "filter" + ) -> str | None: + """ + Get the current visualization title. + + Parameters + ---------- + state + The state dict from the state component. + source + Which title to return. "filter" returns the title from + visualize_dashboard, "query" returns the title from visualize_query. + + Returns + ------- + : + The visualization title, or None if no visualization exists. + + """ + if source == "filter": + return state.get("filter_viz_title") + else: + return state.get("query_viz_title") + def ui(self) -> gr.State: """ Create chat UI components for custom layouts. @@ -327,6 +443,11 @@ def app(self) -> GradioBlocksWrapper: A wrapped Gradio Blocks app ready to launch. The wrapper injects querychat CSS/JS at launch time for Gradio 6.0+ compatibility. + The app includes three tabs: + - **Data**: Shows the filtered data table with the current SQL query + - **Filter Plot**: Shows the persistent dashboard visualization + - **Query Plot**: Shows the most recent query visualization + """ data_source = self._require_data_source("app") from gradio.themes import Soft @@ -343,31 +464,53 @@ def app(self) -> GradioBlocksWrapper: gr.Markdown(f"## `{table_name}`") - with gr.Group(): - with gr.Row(): - sql_title = gr.Markdown("**Current Query**") - reset_btn = gr.Button( - "Reset", size="sm", variant="secondary", scale=0 + with gr.Tabs(): + with gr.Tab("Data"): + with gr.Group(): + with gr.Row(): + sql_title = gr.Markdown("**Current Query**") + reset_btn = gr.Button( + "Reset", size="sm", variant="secondary", scale=0 + ) + sql_display = gr.Code( + label="", + language="sql", + value=f"SELECT * FROM {table_name}", + interactive=False, + lines=2, + ) + + with gr.Group(): + gr.Markdown("**Data Preview**") + data_display = gr.Dataframe( + label="", + buttons=["fullscreen", "copy"], + show_search="filter", + ) + data_info = gr.Markdown("") + + with gr.Tab("Filter Plot"): + filter_plot_title = gr.Markdown("") + filter_plot_display = gr.Plot(label="") + filter_ggsql_display = gr.Code( + label="ggsql spec", language="sql", lines=2 + ) + filter_plot_info = gr.Markdown( + "*No filter visualization yet. Ask the assistant to create one.*" ) - sql_display = gr.Code( - label="", - language="sql", - value=f"SELECT * FROM {table_name}", - interactive=False, - lines=2, - ) - with gr.Group(): - gr.Markdown("**Data Preview**") - data_display = gr.Dataframe( - label="", - buttons=["fullscreen", "copy"], - show_search="filter", - ) - data_info = gr.Markdown("") + with gr.Tab("Query Plot"): + query_plot_title = gr.Markdown("") + query_plot_display = gr.Plot(label="") + query_ggsql_display = gr.Code( + label="ggsql query", language="sql", lines=2 + ) + query_plot_info = gr.Markdown( + "*No query visualization yet. Ask the assistant to create one.*" + ) def update_displays(state_dict: AppStateDict): - """Update SQL and data displays based on state.""" + """Update SQL, data, and visualization displays based on state.""" title = state_dict.get("title") if state_dict else None error = state_dict.get("error") if state_dict else None @@ -385,11 +528,44 @@ def update_displays(state_dict: AppStateDict): data_info_parts = [] if error: - data_info_parts.append(f"⚠️ {error}") + data_info_parts.append(f"Warning: {error}") data_info_parts.append(f"*Data has {nrow} rows and {ncol} columns.*") data_info_text = " ".join(data_info_parts) - return sql_title_text, sql_code, native_df, data_info_text + # Filter visualization + filter_chart = self.ggvis(state_dict, "filter") + filter_title_text = self.ggtitle(state_dict, "filter") or "" + filter_spec = self.ggsql(state_dict, "filter") or "" + filter_info = ( + "" + if filter_chart + else "*No filter visualization yet. Ask the assistant to create one.*" + ) + + # Query visualization + query_chart = self.ggvis(state_dict, "query") + query_title_text = self.ggtitle(state_dict, "query") or "" + query_spec = self.ggsql(state_dict, "query") or "" + query_info = ( + "" + if query_chart + else "*No query visualization yet. Ask the assistant to create one.*" + ) + + return ( + sql_title_text, + sql_code, + native_df, + data_info_text, + f"### {filter_title_text}" if filter_title_text else "", + filter_chart, + filter_spec, + filter_info, + f"### {query_title_text}" if query_title_text else "", + query_chart, + query_spec, + query_info, + ) def reset_query(state_dict: AppStateDict): """Reset state to show full dataset.""" @@ -401,7 +577,20 @@ def reset_query(state_dict: AppStateDict): state_holder.change( fn=update_displays, inputs=[state_holder], - outputs=[sql_title, sql_display, data_display, data_info], + outputs=[ + sql_title, + sql_display, + data_display, + data_info, + filter_plot_title, + filter_plot_display, + filter_ggsql_display, + filter_plot_info, + query_plot_title, + query_plot_display, + query_ggsql_display, + query_plot_info, + ], ) reset_btn.click( diff --git a/pkg-py/src/querychat/_shiny.py b/pkg-py/src/querychat/_shiny.py index c1dcc9a1..016ad59d 100644 --- a/pkg-py/src/querychat/_shiny.py +++ b/pkg-py/src/querychat/_shiny.py @@ -17,6 +17,7 @@ if TYPE_CHECKING: from pathlib import Path + import altair as alt import chatlas import ibis import narwhals.stable.v1 as nw @@ -97,10 +98,10 @@ class QueryChat(QueryChatBase[IntoFrameT]): tools Which querychat tools to include in the chat client by default. Can be: - A single tool string: `"update"` or `"query"` - - A tuple of tools: `("update", "query")` + - A tuple of tools: `("update", "query", "visualize_dashboard", "visualize_query")` - `None` or `()` to disable all tools - Default is `("update", "query")` (both tools enabled). + Default is `("update", "query", "visualize_dashboard", "visualize_query")` (all tools enabled). Set to `"update"` to prevent the LLM from accessing data values, only allowing dashboard filtering without answering questions. @@ -156,7 +157,12 @@ def __init__( id: Optional[str] = None, greeting: Optional[str | Path] = None, client: Optional[str | chatlas.Chat] = None, - tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = ("update", "query"), + tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = ( + "update", + "query", + "visualize_dashboard", + "visualize_query", + ), data_description: Optional[str | Path] = None, categorical_threshold: int = 20, extra_instructions: Optional[str | Path] = None, @@ -172,7 +178,12 @@ def __init__( id: Optional[str] = None, greeting: Optional[str | Path] = None, client: Optional[str | chatlas.Chat] = None, - tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = ("update", "query"), + tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = ( + "update", + "query", + "visualize_dashboard", + "visualize_query", + ), data_description: Optional[str | Path] = None, categorical_threshold: int = 20, extra_instructions: Optional[str | Path] = None, @@ -188,7 +199,12 @@ def __init__( id: Optional[str] = None, greeting: Optional[str | Path] = None, client: Optional[str | chatlas.Chat] = None, - tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = ("update", "query"), + tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = ( + "update", + "query", + "visualize_dashboard", + "visualize_query", + ), data_description: Optional[str | Path] = None, categorical_threshold: int = 20, extra_instructions: Optional[str | Path] = None, @@ -204,7 +220,12 @@ def __init__( id: Optional[str] = None, greeting: Optional[str | Path] = None, client: Optional[str | chatlas.Chat] = None, - tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = ("update", "query"), + tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = ( + "update", + "query", + "visualize_dashboard", + "visualize_query", + ), data_description: Optional[str | Path] = None, categorical_threshold: int = 20, extra_instructions: Optional[str | Path] = None, @@ -219,7 +240,12 @@ def __init__( id: Optional[str] = None, greeting: Optional[str | Path] = None, client: Optional[str | chatlas.Chat] = None, - tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = ("update", "query"), + tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = ( + "update", + "query", + "visualize_dashboard", + "visualize_query", + ), data_description: Optional[str | Path] = None, categorical_threshold: int = 20, extra_instructions: Optional[str | Path] = None, @@ -245,9 +271,14 @@ def app( """ Quickly chat with a dataset. - Creates a Shiny app with a chat sidebar and data table view -- providing a + Creates a Shiny app with a chat sidebar and tabbed view -- providing a quick-and-easy way to start chatting with your data. + The app includes three tabs: + - **Data**: Shows the filtered data table + - **Filter Plot**: Shows the persistent dashboard visualization + - **Query Plot**: Shows the most recent query visualization + Parameters ---------- bookmark_store @@ -285,9 +316,23 @@ def app_ui(request): fill=False, style="max-height: 33%;", ), - ui.card( - ui.card_header(bs_icon("table"), " Data"), - ui.output_data_frame("dt"), + ui.navset_tab( + ui.nav_panel( + "Data", + ui.card( + ui.card_header(bs_icon("table"), " Data"), + ui.output_data_frame("dt"), + ), + ), + ui.nav_panel( + "Filter Plot", + ui.output_ui("filter_plot_container"), + ), + ui.nav_panel( + "Query Plot", + ui.output_ui("query_plot_container"), + ), + id="main_tabs", ), title=ui.span("querychat with ", ui.code(table_name)), class_="bslib-page-dashboard", @@ -301,6 +346,7 @@ def app_server(input: Inputs, output: Outputs, session: Session): greeting=self.greeting, client=self._client, enable_bookmarking=enable_bookmarking, + tools=self.tools, ) @render.text @@ -338,6 +384,64 @@ def sql_output(): width="100%", ) + @render.ui + def filter_plot_container(): + from shinywidgets import output_widget, render_altair + + chart = vals.filter_viz_chart() + if chart is None: + return ui.card( + ui.card_body( + ui.p( + "No filter visualization yet. " + "Use the chat to create one." + ), + class_="text-muted text-center py-5", + ), + ) + + @render_altair + def filter_chart(): + return chart + + return ui.card( + ui.card_header( + bs_icon("bar-chart-fill"), + " ", + vals.filter_viz_title.get() or "Filter Visualization", + ), + output_widget("filter_chart"), + ) + + @render.ui + def query_plot_container(): + from shinywidgets import output_widget, render_altair + + chart = vals.query_viz_chart() + if chart is None: + return ui.card( + ui.card_body( + ui.p( + "No query visualization yet. " + "Use the chat to create one." + ), + class_="text-muted text-center py-5", + ), + ) + + @render_altair + def query_chart(): + return chart + + return ui.card( + ui.card_header( + bs_icon("bar-chart-fill"), + " ", + vals.query_viz_title.get() or "Query Visualization", + ), + output_widget("query_chart"), + ) + return App(app_ui, app_server, bookmark_store=bookmark_store) def sidebar( @@ -493,6 +597,7 @@ def title(): greeting=self.greeting, client=self.client, enable_bookmarking=enable_bookmarking, + tools=self.tools, ) @@ -730,6 +835,7 @@ def __init__( greeting=self.greeting, client=self._client, enable_bookmarking=enable, + tools=self.tools, ) def sidebar( @@ -870,3 +976,69 @@ def title(self, value: Optional[str] = None) -> str | None | bool: return self._vals.title() else: return self._vals.title.set(value) + + def ggvis(self, source: Literal["filter", "query"] = "filter") -> alt.Chart | None: + """ + Get the visualization chart. + + Parameters + ---------- + source + Which visualization to return: + - "filter": Chart from visualize_dashboard (updates with filter changes) + - "query": Chart from visualize_query (most recent inline visualization) + + Returns + ------- + : + The Altair chart, or None if no visualization exists. + + """ + if source == "filter": + return self._vals.filter_viz_chart() + else: + return self._vals.query_viz_chart() + + def ggsql(self, source: Literal["filter", "query"] = "filter") -> str | None: + """ + Get the ggsql specification. + + Parameters + ---------- + source + Which specification to return: + - "filter": VISUALISE spec only (from visualize_dashboard) + - "query": Full ggsql query (from visualize_query) + + Returns + ------- + : + The ggsql specification, or None if no visualization exists. + + """ + if source == "filter": + return self._vals.filter_viz_spec.get() + else: + return self._vals.query_viz_ggsql.get() + + def ggtitle(self, source: Literal["filter", "query"] = "filter") -> str | None: + """ + Get the visualization title. + + Parameters + ---------- + source + Which title to return: + - "filter": Title from visualize_dashboard + - "query": Title from visualize_query + + Returns + ------- + : + The title, or None if no visualization exists. + + """ + if source == "filter": + return self._vals.filter_viz_title.get() + else: + return self._vals.query_viz_title.get() diff --git a/pkg-py/src/querychat/_shiny_module.py b/pkg-py/src/querychat/_shiny_module.py index 335f6803..584f7aff 100644 --- a/pkg-py/src/querychat/_shiny_module.py +++ b/pkg-py/src/querychat/_shiny_module.py @@ -1,6 +1,7 @@ from __future__ import annotations import copy +import logging import warnings from dataclasses import dataclass from pathlib import Path @@ -13,17 +14,26 @@ from shiny import module, reactive, ui from ._querychat_core import GREETING_PROMPT -from .tools import tool_query, tool_reset_dashboard, tool_update_dashboard +from .tools import ( + tool_query, + tool_reset_dashboard, + tool_update_dashboard, + tool_visualize_dashboard, + tool_visualize_query, +) if TYPE_CHECKING: from collections.abc import Callable + import altair as alt from shiny.bookmark import BookmarkState, RestoreState from shiny import Inputs, Outputs, Session from ._datasource import DataSource - from .types import UpdateDashboardData + from .tools import UpdateDashboardData, VisualizeDashboardData, VisualizeQueryData + +logger = logging.getLogger(__name__) ReactiveString = reactive.Value[str] """A reactive string value.""" @@ -79,6 +89,26 @@ class ServerValues(Generic[IntoFrameT]): The session-specific chat client instance. This is a deep copy of the base client configured for this specific session, containing the chat history and tool registrations for this session only. + filter_viz_spec + A reactive Value containing the VISUALISE spec from visualize_dashboard. + Returns `None` if no visualization has been created. + filter_viz_title + A reactive Value containing the title from visualize_dashboard. + Returns `None` if no visualization has been created. + filter_viz_chart + A callable returning the rendered Altair chart from visualize_dashboard. + Returns `None` if no visualization has been created. The chart is + re-rendered on each call using `ggsql.render_altair()`. + query_viz_ggsql + A reactive Value containing the full ggsql query from visualize_query. + Returns `None` if no visualization has been created. + query_viz_title + A reactive Value containing the title from visualize_query. + Returns `None` if no visualization has been created. + query_viz_chart + A callable returning the rendered Altair chart from visualize_query. + Returns `None` if no visualization has been created. The chart is + re-rendered on each call using `ggsql.render_altair()`. """ @@ -86,6 +116,13 @@ class ServerValues(Generic[IntoFrameT]): sql: ReactiveStringOrNone title: ReactiveStringOrNone client: chatlas.Chat + # Visualization state + filter_viz_spec: ReactiveStringOrNone + filter_viz_title: ReactiveStringOrNone + filter_viz_chart: Callable[[], alt.TopLevelMixin | None] + query_viz_ggsql: ReactiveStringOrNone + query_viz_title: ReactiveStringOrNone + query_viz_chart: Callable[[], alt.TopLevelMixin | None] @module.server @@ -98,12 +135,19 @@ def mod_server( greeting: str | None, client: chatlas.Chat | Callable, enable_bookmarking: bool, + tools: tuple[str, ...] | None = None, ) -> ServerValues[IntoFrameT]: # Reactive values to store state sql = ReactiveStringOrNone(None) title = ReactiveStringOrNone(None) has_greeted = reactive.value[bool](False) # noqa: FBT003 + # Visualization state - store only specs, render on demand + filter_viz_spec: reactive.Value[str | None] = reactive.value(None) + filter_viz_title: reactive.Value[str | None] = reactive.value(None) + query_viz_ggsql: reactive.Value[str | None] = reactive.value(None) + query_viz_title: reactive.Value[str | None] = reactive.value(None) + # Short-circuit for stub sessions (e.g. 1st run of an Express app) # data_source may be None during stub session for deferred pattern if session.is_stub_session(): @@ -116,6 +160,12 @@ def _stub_df(): sql=sql, title=title, client=client if isinstance(client, chatlas.Chat) else client(), + filter_viz_spec=filter_viz_spec, + filter_viz_title=filter_viz_title, + filter_viz_chart=lambda: filter_viz_chart.get(), + query_viz_ggsql=query_viz_ggsql, + query_viz_title=query_viz_title, + query_viz_chart=lambda: query_viz_chart.get(), ) # Real session requires data_source @@ -133,6 +183,14 @@ def reset_dashboard(): sql.set(None) title.set(None) + def update_filter_viz(data: VisualizeDashboardData): + filter_viz_spec.set(data["spec"]) + filter_viz_title.set(data["title"]) + + def update_query_viz(data: VisualizeQueryData): + query_viz_ggsql.set(data["ggsql"]) + query_viz_title.set(data["title"]) + # Set up the chat object for this session # Support both a callable that creates a client and legacy instance pattern if callable(client) and not isinstance(client, chatlas.Chat): @@ -147,6 +205,12 @@ def reset_dashboard(): chat.register_tool(tool_query(data_source)) chat.register_tool(tool_reset_dashboard(reset_dashboard)) + # Register visualization tools if enabled + if tools and "visualize_dashboard" in tools: + chat.register_tool(tool_visualize_dashboard(data_source, update_filter_viz)) + if tools and "visualize_query" in tools: + chat.register_tool(tool_visualize_query(data_source, update_query_viz)) + # Execute query when SQL changes @reactive.calc def filtered_df(): @@ -154,6 +218,33 @@ def filtered_df(): df = data_source.get_data() if not query else data_source.execute_query(query) return df + # Render filter visualization on demand + @reactive.calc + def render_filter_viz_chart(): + """Render filter visualization using current filtered data.""" + import ggsql + + spec = filter_viz_spec.get() + if spec is None: + return None + + current_df = filtered_df() + return ggsql.render_altair(current_df, spec) + + # Render query visualization on demand + @reactive.calc + def render_query_viz_chart(): + """Render query visualization by re-executing the ggsql query.""" + import ggsql + + ggsql_query = query_viz_ggsql.get() + if ggsql_query is None: + return None + + sql_part, viz_spec = ggsql.split_query(ggsql_query) + df = data_source.execute_query(sql_part) + return ggsql.render_altair(df, viz_spec) + # Chat UI logic chat_ui = shinychat.Chat(CHAT_ID) @@ -220,7 +311,18 @@ def _on_restore(x: RestoreState) -> None: if "querychat_has_greeted" in vals: has_greeted.set(vals["querychat_has_greeted"]) - return ServerValues(df=filtered_df, sql=sql, title=title, client=chat) + return ServerValues( + df=filtered_df, + sql=sql, + title=title, + client=chat, + filter_viz_spec=filter_viz_spec, + filter_viz_title=filter_viz_title, + filter_viz_chart=render_filter_viz_chart, + query_viz_ggsql=query_viz_ggsql, + query_viz_title=query_viz_title, + query_viz_chart=render_query_viz_chart, + ) class GreetWarning(Warning): diff --git a/pkg-py/src/querychat/_streamlit.py b/pkg-py/src/querychat/_streamlit.py index 484b8f4a..a5f30614 100644 --- a/pkg-py/src/querychat/_streamlit.py +++ b/pkg-py/src/querychat/_streamlit.py @@ -2,7 +2,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Optional, cast, overload +from typing import TYPE_CHECKING, Any, Literal, Optional, cast, overload from narwhals.stable.v1.typing import IntoDataFrameT, IntoFrameT, IntoLazyFrameT @@ -19,6 +19,7 @@ if TYPE_CHECKING: from pathlib import Path + import altair as alt import chatlas import ibis import narwhals.stable.v1 as nw @@ -81,7 +82,12 @@ def __init__( *, greeting: Optional[str | Path] = None, client: Optional[str | chatlas.Chat] = None, - tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = ("update", "query"), + tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = ( + "update", + "query", + "visualize_dashboard", + "visualize_query", + ), data_description: Optional[str | Path] = None, categorical_threshold: int = 20, extra_instructions: Optional[str | Path] = None, @@ -96,7 +102,12 @@ def __init__( *, greeting: Optional[str | Path] = None, client: Optional[str | chatlas.Chat] = None, - tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = ("update", "query"), + tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = ( + "update", + "query", + "visualize_dashboard", + "visualize_query", + ), data_description: Optional[str | Path] = None, categorical_threshold: int = 20, extra_instructions: Optional[str | Path] = None, @@ -111,7 +122,12 @@ def __init__( *, greeting: Optional[str | Path] = None, client: Optional[str | chatlas.Chat] = None, - tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = ("update", "query"), + tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = ( + "update", + "query", + "visualize_dashboard", + "visualize_query", + ), data_description: Optional[str | Path] = None, categorical_threshold: int = 20, extra_instructions: Optional[str | Path] = None, @@ -126,7 +142,12 @@ def __init__( *, greeting: Optional[str | Path] = None, client: Optional[str | chatlas.Chat] = None, - tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = ("update", "query"), + tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = ( + "update", + "query", + "visualize_dashboard", + "visualize_query", + ), data_description: Optional[str | Path] = None, categorical_threshold: int = 20, extra_instructions: Optional[str | Path] = None, @@ -140,7 +161,12 @@ def __init__( *, greeting: Optional[str | Path] = None, client: Optional[str | chatlas.Chat] = None, - tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = ("update", "query"), + tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = ( + "update", + "query", + "visualize_dashboard", + "visualize_query", + ), data_description: Optional[str | Path] = None, categorical_threshold: int = 20, extra_instructions: Optional[str | Path] = None, @@ -167,9 +193,11 @@ def _get_state(self) -> AppState: if self._state_key not in st.session_state: st.session_state[self._state_key] = create_app_state( data_source, - lambda update_cb, reset_cb: self.client( + lambda update_cb, reset_cb, filter_viz_cb, query_viz_cb: self.client( update_dashboard=update_cb, reset_dashboard=reset_cb, + visualize_dashboard=filter_viz_cb, + visualize_query=query_viz_cb, ), self.greeting, ) @@ -180,7 +208,12 @@ def app(self) -> None: Render a complete Streamlit app. Configures the page, renders chat in sidebar, and displays - SQL query and data table in the main area. + SQL query, data table, and visualizations in a tabbed interface. + + The app includes three tabs: + - **Data**: Shows the filtered data table with the current SQL query + - **Filter Plot**: Shows the persistent dashboard visualization + - **Query Plot**: Shows the most recent query visualization """ data_source = self._require_data_source("app") import streamlit as st @@ -192,7 +225,23 @@ def app(self) -> None: ) self.sidebar() - self._render_main_content() + + state = self._get_state() + + st.title(f"querychat with `{self._data_source.table_name}`") + + data_tab, filter_plot_tab, query_plot_tab = st.tabs( + ["Data", "Filter Plot", "Query Plot"] + ) + + with data_tab: + self._render_data_tab(state) + + with filter_plot_tab: + self._render_filter_plot_tab() + + with query_plot_tab: + self._render_query_plot_tab() def sidebar(self) -> None: """Render the chat interface in the Streamlit sidebar.""" @@ -303,14 +352,92 @@ def reset(self) -> None: state.reset_dashboard() st.rerun() - def _render_main_content(self) -> None: - """Render the main content area (SQL + data table).""" - data_source = self._require_data_source("_render_main_content") - import streamlit as st + def ggvis(self, source: Literal["filter", "query"] = "filter") -> alt.Chart | None: + """ + Get the current Altair visualization chart. + + Parameters + ---------- + source + Which visualization to return. "filter" returns the dashboard + visualization (from visualize_dashboard tool), "query" returns + the query visualization (from visualize_query tool). + + Returns + ------- + : + An Altair Chart object, or None if no visualization exists. + + """ + import ggsql + + from ._utils import as_narwhals state = self._get_state() + if source == "filter": + spec = state.filter_viz_spec + if spec is None: + return None + # Render against current filtered data + df = as_narwhals(self.df()) + return ggsql.render_altair(df, spec) + else: + ggsql_query = state.query_viz_ggsql + if ggsql_query is None: + return None + # Re-execute SQL and render + sql, viz_spec = ggsql.split_query(ggsql_query) + df = as_narwhals(self._data_source.execute_query(sql)) + return ggsql.render_altair(df, viz_spec) + + def ggsql(self, source: Literal["filter", "query"] = "filter") -> str | None: + """ + Get the current ggsql specification. - st.title(f"querychat with `{data_source.table_name}`") + Parameters + ---------- + source + Which specification to return. "filter" returns the VISUALISE spec + from visualize_dashboard, "query" returns the full ggsql query + from visualize_query. + + Returns + ------- + : + The ggsql specification string, or None if no visualization exists. + + """ + state = self._get_state() + if source == "filter": + return state.filter_viz_spec + else: + return state.query_viz_ggsql + + def ggtitle(self, source: Literal["filter", "query"] = "filter") -> str | None: + """ + Get the current visualization title. + + Parameters + ---------- + source + Which title to return. "filter" returns the title from + visualize_dashboard, "query" returns the title from visualize_query. + + Returns + ------- + : + The visualization title, or None if no visualization exists. + + """ + state = self._get_state() + if source == "filter": + return state.filter_viz_title + else: + return state.query_viz_title + + def _render_data_tab(self, state: AppState) -> None: + """Render the Data tab content.""" + import streamlit as st st.subheader(state.title or "SQL Query") @@ -331,3 +458,45 @@ def _render_main_content(self) -> None: df.to_native(), use_container_width=True, height=400, hide_index=True ) st.caption(f"Data has {df.shape[0]} rows and {df.shape[1]} columns.") + + def _render_filter_plot_tab(self) -> None: + """Render the Filter Plot tab content.""" + import streamlit as st + + chart = self.ggvis("filter") + if chart is not None: + title = self.ggtitle("filter") + if title: + st.subheader(title) + st.altair_chart(chart, use_container_width=True) + + spec = self.ggsql("filter") + if spec: + with st.expander("ggsql spec"): + st.code(spec, language="sql") + else: + st.info( + "No filter visualization. Ask the assistant to create one " + "using the visualize_dashboard tool." + ) + + def _render_query_plot_tab(self) -> None: + """Render the Query Plot tab content.""" + import streamlit as st + + chart = self.ggvis("query") + if chart is not None: + title = self.ggtitle("query") + if title: + st.subheader(title) + st.altair_chart(chart, use_container_width=True) + + spec = self.ggsql("query") + if spec: + with st.expander("ggsql query"): + st.code(spec, language="sql") + else: + st.info( + "No query visualization. Ask the assistant to create one " + "using the visualize_query tool." + ) From a21712ca3d2a8910bdd2c9a422293910f8132737 Mon Sep 17 00:00:00 2001 From: Carson Date: Tue, 27 Jan 2026 15:50:13 -0600 Subject: [PATCH 3/3] test(pkg-py): add visualization tests Add comprehensive test coverage for ggsql visualization features: - test_ggsql.py: Unit tests for ggsql parsing and rendering - test_ggsql_integration.py: Integration tests for end-to-end visualization - test_viz_tools.py: Tests for visualize_dashboard and visualize_query tools - test_visualization_tabs.py: Playwright tests for UI tab interactions - Update test_state.py with visualization state field tests - Update test_tools.py and test_base.py for new tool configurations Co-Authored-By: Claude Opus 4.5 --- .../playwright/test_visualization_tabs.py | 286 ++++++++++++++++++ pkg-py/tests/test_base.py | 2 +- pkg-py/tests/test_client_console.py | 4 +- pkg-py/tests/test_ggsql.py | 113 +++++++ pkg-py/tests/test_ggsql_integration.py | 225 ++++++++++++++ pkg-py/tests/test_state.py | 76 ++++- pkg-py/tests/test_tools.py | 8 + pkg-py/tests/test_viz_tools.py | 144 +++++++++ 8 files changed, 853 insertions(+), 5 deletions(-) create mode 100644 pkg-py/tests/playwright/test_visualization_tabs.py create mode 100644 pkg-py/tests/test_ggsql.py create mode 100644 pkg-py/tests/test_ggsql_integration.py create mode 100644 pkg-py/tests/test_viz_tools.py diff --git a/pkg-py/tests/playwright/test_visualization_tabs.py b/pkg-py/tests/playwright/test_visualization_tabs.py new file mode 100644 index 00000000..c7f68c9e --- /dev/null +++ b/pkg-py/tests/playwright/test_visualization_tabs.py @@ -0,0 +1,286 @@ +""" +Playwright tests for visualization tabs (Filter Plot, Query Plot). + +These tests verify that the visualization tabs are present and show +appropriate placeholder messages when no visualization has been created. + +Since the visualization tools require real LLM interaction to create charts, +these tests focus on: +1. Tab presence and accessibility +2. Placeholder messages +3. Tab switching functionality +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import pytest +from playwright.sync_api import expect + +if TYPE_CHECKING: + from playwright.sync_api import Page + + +# Shiny Tests +class TestShinyVisualizationTabs: + """Tests for visualization tabs in Shiny app (01-hello-app.py).""" + + @pytest.fixture(autouse=True) + def setup(self, page: Page, app_01_hello: str) -> None: + """Navigate to the app before each test.""" + page.goto(app_01_hello) + page.wait_for_selector("table", timeout=30000) + self.page = page + + def test_three_tabs_present(self) -> None: + """VIZ-SHINY-01: Three tabs are visible (Data, Filter Plot, Query Plot).""" + tabs = self.page.locator('[role="tab"]') + expect(tabs).to_have_count(3) + + expect(self.page.get_by_role("tab", name="Data")).to_be_visible() + expect(self.page.get_by_role("tab", name="Filter Plot")).to_be_visible() + expect(self.page.get_by_role("tab", name="Query Plot")).to_be_visible() + + def test_filter_plot_tab_clickable(self) -> None: + """VIZ-SHINY-02: Filter Plot tab can be clicked.""" + filter_tab = self.page.locator('text="Filter Plot"') + filter_tab.click() + + # Should show placeholder message + expect(self.page.locator("text=No filter visualization")).to_be_visible( + timeout=5000 + ) + + def test_query_plot_tab_clickable(self) -> None: + """VIZ-SHINY-03: Query Plot tab can be clicked.""" + query_tab = self.page.locator('text="Query Plot"') + query_tab.click() + + # Should show placeholder message + expect(self.page.locator("text=No query visualization")).to_be_visible( + timeout=5000 + ) + + def test_filter_plot_shows_placeholder(self) -> None: + """VIZ-SHINY-04: Filter Plot tab shows placeholder when empty.""" + filter_tab = self.page.locator('text="Filter Plot"') + filter_tab.click() + + placeholder = self.page.locator("text=Use the chat to create one") + expect(placeholder).to_be_visible(timeout=5000) + + def test_query_plot_shows_placeholder(self) -> None: + """VIZ-SHINY-05: Query Plot tab shows placeholder when empty.""" + query_tab = self.page.locator('text="Query Plot"') + query_tab.click() + + placeholder = self.page.locator("text=Use the chat to create one") + expect(placeholder).to_be_visible(timeout=5000) + + def test_can_switch_between_tabs(self) -> None: + """VIZ-SHINY-06: Can switch between all three tabs.""" + # Start on Data tab (default) + expect(self.page.locator("table")).to_be_visible() + + # Switch to Filter Plot + self.page.locator('text="Filter Plot"').click() + expect(self.page.locator("text=No filter visualization")).to_be_visible( + timeout=5000 + ) + + # Switch to Query Plot + self.page.locator('text="Query Plot"').click() + expect(self.page.locator("text=No query visualization")).to_be_visible( + timeout=5000 + ) + + # Switch back to Data + self.page.locator('[role="tab"]:has-text("Data")').click() + expect(self.page.locator("table")).to_be_visible() + + +# Streamlit Tests +class TestStreamlitVisualizationTabs: + """Tests for visualization tabs in Streamlit app (04-streamlit-app.py).""" + + @pytest.fixture(autouse=True) + def setup(self, page: Page, app_04_streamlit: str) -> None: + """Navigate to the app before each test.""" + page.goto(app_04_streamlit) + page.wait_for_selector('[data-testid="stApp"]', timeout=30000) + page.wait_for_selector('[data-testid="stChatMessage"]', timeout=30000) + self.page = page + + def test_three_tabs_present(self) -> None: + """VIZ-STREAMLIT-01: Three tabs are visible.""" + tabs = self.page.locator('[data-baseweb="tab"]') + expect(tabs).to_have_count(3) + + def test_filter_plot_tab_clickable(self) -> None: + """VIZ-STREAMLIT-02: Filter Plot tab can be clicked.""" + tabs = self.page.locator('[data-baseweb="tab"]') + tabs.nth(1).click() # Filter Plot is second tab + + # Should show info message about no visualization + expect(self.page.locator("text=No filter visualization")).to_be_visible( + timeout=5000 + ) + + def test_query_plot_tab_clickable(self) -> None: + """VIZ-STREAMLIT-03: Query Plot tab can be clicked.""" + tabs = self.page.locator('[data-baseweb="tab"]') + tabs.nth(2).click() # Query Plot is third tab + + # Should show info message about no visualization + expect(self.page.locator("text=No query visualization")).to_be_visible( + timeout=5000 + ) + + def test_filter_plot_mentions_tool(self) -> None: + """VIZ-STREAMLIT-04: Filter Plot placeholder mentions the tool.""" + tabs = self.page.locator('[data-baseweb="tab"]') + tabs.nth(1).click() + + expect(self.page.locator("text=visualize_dashboard")).to_be_visible( + timeout=5000 + ) + + def test_query_plot_mentions_tool(self) -> None: + """VIZ-STREAMLIT-05: Query Plot placeholder mentions the tool.""" + tabs = self.page.locator('[data-baseweb="tab"]') + tabs.nth(2).click() + + expect(self.page.locator("text=visualize_query")).to_be_visible(timeout=5000) + + +# Gradio Tests +class TestGradioVisualizationTabs: + """Tests for visualization tabs in Gradio app (05-gradio-app.py).""" + + @pytest.fixture(autouse=True) + def setup(self, page: Page, app_05_gradio: str) -> None: + """Navigate to the app before each test.""" + page.goto(app_05_gradio) + page.wait_for_selector("gradio-app", timeout=30000) + page.wait_for_selector('[data-testid="bot"]', timeout=30000) + self.page = page + + def test_three_tabs_present(self) -> None: + """VIZ-GRADIO-01: Three tabs are visible.""" + tabs = self.page.locator('[role="tab"]') + expect(tabs).to_have_count(3) + + def test_filter_plot_tab_clickable(self) -> None: + """VIZ-GRADIO-02: Filter Plot tab can be clicked.""" + filter_tab = self.page.locator('button[role="tab"]:has-text("Filter Plot")') + filter_tab.click() + + # Should show placeholder message + expect(self.page.locator("text=No filter visualization")).to_be_visible( + timeout=5000 + ) + + def test_query_plot_tab_clickable(self) -> None: + """VIZ-GRADIO-03: Query Plot tab can be clicked.""" + query_tab = self.page.locator('button[role="tab"]:has-text("Query Plot")') + query_tab.click() + + # Should show placeholder message + expect(self.page.locator("text=No query visualization")).to_be_visible( + timeout=5000 + ) + + def test_filter_plot_has_plot_area(self) -> None: + """VIZ-GRADIO-04: Filter Plot tab has plot and ggsql spec areas.""" + filter_tab = self.page.locator('button[role="tab"]:has-text("Filter Plot")') + filter_tab.click() + + # Should have Plot label + expect(self.page.locator('text="Plot"')).to_be_visible(timeout=5000) + # Should have ggsql spec label + expect(self.page.locator('text="ggsql spec"')).to_be_visible(timeout=5000) + + def test_query_plot_has_plot_area(self) -> None: + """VIZ-GRADIO-05: Query Plot tab has plot and ggsql query areas.""" + query_tab = self.page.locator('button[role="tab"]:has-text("Query Plot")') + query_tab.click() + + # Should have Plot label + expect(self.page.locator('text="Plot"')).to_be_visible(timeout=5000) + # Should have ggsql query label + expect(self.page.locator('text="ggsql query"')).to_be_visible(timeout=5000) + + +# Dash Tests +class TestDashVisualizationTabs: + """Tests for visualization tabs in Dash app (06-dash-app.py).""" + + @pytest.fixture(autouse=True) + def setup(self, page: Page, app_06_dash: str) -> None: + """Navigate to the app before each test.""" + page.goto(app_06_dash) + page.wait_for_selector("#querychat-titanic-chat-history", timeout=30000) + # Wait for greeting + expect(page.locator("#querychat-titanic-chat-history")).to_contain_text( + "Hello", timeout=30000 + ) + self.page = page + + def test_three_tabs_present(self) -> None: + """VIZ-DASH-01: Three tabs are visible.""" + tabs = self.page.locator('[role="tab"]') + expect(tabs).to_have_count(3) + + expect(self.page.get_by_role("tab", name="Data")).to_be_visible() + expect(self.page.get_by_role("tab", name="Filter Plot")).to_be_visible() + expect(self.page.get_by_role("tab", name="Query Plot")).to_be_visible() + + def test_filter_plot_tab_clickable(self) -> None: + """VIZ-DASH-02: Filter Plot tab can be clicked.""" + filter_tab = self.page.get_by_role("tab", name="Filter Plot") + filter_tab.click() + + # Should show placeholder in iframe + filter_plot = self.page.locator("#querychat-titanic-filter-plot") + expect(filter_plot).to_be_visible(timeout=5000) + + def test_query_plot_tab_clickable(self) -> None: + """VIZ-DASH-03: Query Plot tab can be clicked.""" + query_tab = self.page.get_by_role("tab", name="Query Plot") + query_tab.click() + + # Should show placeholder in iframe + query_plot = self.page.locator("#querychat-titanic-query-plot") + expect(query_plot).to_be_visible(timeout=5000) + + def test_data_tab_shows_table(self) -> None: + """VIZ-DASH-04: Data tab shows the data table.""" + # Data tab is default, should show AG Grid. The table wrapper is present + # but the grid may have height: 0 until data loads. + # Check that rows are rendered instead. + data_rows = self.page.locator(".ag-row") + expect(data_rows.first).to_be_visible(timeout=15000) + + def test_can_switch_between_tabs(self) -> None: + """VIZ-DASH-05: Can switch between all three tabs.""" + # Start on Data tab (default) + expect(self.page.locator("#querychat-titanic-sql-display")).to_be_visible() + + # Switch to Filter Plot + self.page.get_by_role("tab", name="Filter Plot").click() + expect(self.page.locator("#querychat-titanic-filter-plot")).to_be_visible( + timeout=5000 + ) + + # Switch to Query Plot + self.page.get_by_role("tab", name="Query Plot").click() + expect(self.page.locator("#querychat-titanic-query-plot")).to_be_visible( + timeout=5000 + ) + + # Switch back to Data + self.page.get_by_role("tab", name="Data").click() + expect(self.page.locator("#querychat-titanic-sql-display")).to_be_visible( + timeout=5000 + ) diff --git a/pkg-py/tests/test_base.py b/pkg-py/tests/test_base.py index ead9001e..d95481e1 100644 --- a/pkg-py/tests/test_base.py +++ b/pkg-py/tests/test_base.py @@ -161,7 +161,7 @@ class TestQueryChatBase: def test_init_with_dataframe(self, sample_df): qc = QueryChatBase(sample_df, "test_table") assert isinstance(qc.data_source, DataFrameSource) - assert qc.tools == ("update", "query") + assert qc.tools == ("update", "query", "visualize_dashboard", "visualize_query") def test_init_with_custom_greeting(self, sample_df): qc = QueryChatBase(sample_df, "test_table", greeting="Hello!") diff --git a/pkg-py/tests/test_client_console.py b/pkg-py/tests/test_client_console.py index 5db3e6e6..d1b61e47 100644 --- a/pkg-py/tests/test_client_console.py +++ b/pkg-py/tests/test_client_console.py @@ -273,7 +273,7 @@ def test_default_tools_maintain_current_behavior(self, sample_df): # Without tools parameter, should include both tools (like before) qc = QueryChat(sample_df, "test_table", greeting="Hello!") - assert qc.tools == ("update", "query") + assert qc.tools == ("update", "query", "visualize_dashboard", "visualize_query") prompt = qc.system_prompt assert "Filtering and Sorting Data" in prompt @@ -292,4 +292,4 @@ def test_existing_initialization_still_works(self, sample_df): assert qc is not None assert qc.id == "querychat_test_table" - assert qc.tools == ("update", "query") + assert qc.tools == ("update", "query", "visualize_dashboard", "visualize_query") diff --git a/pkg-py/tests/test_ggsql.py b/pkg-py/tests/test_ggsql.py new file mode 100644 index 00000000..9a278f0b --- /dev/null +++ b/pkg-py/tests/test_ggsql.py @@ -0,0 +1,113 @@ +"""Tests for ggsql integration helpers.""" + +import ggsql +import polars as pl +import pytest +from querychat._ggsql import extract_title + + +def _ggsql_render_works() -> bool: + """Check if ggsql.render_altair() is functional (build can be broken in some envs).""" + try: + import ggsql + + df = pl.DataFrame({"x": [1, 2], "y": [3, 4]}) + result = ggsql.render_altair(df, "VISUALISE x, y DRAW point") + spec = result.to_dict() + return "$schema" in spec + except (ValueError, ImportError): + return False + + +ggsql_render_works = pytest.mark.skipif( + not _ggsql_render_works(), + reason="ggsql.render_altair() not functional (build environment issue)", +) + + +class TestGgsqlSplitQuery: + """Tests for ggsql.split_query() usage.""" + + def test_splits_query_with_visualise(self): + query = "SELECT x, y FROM data VISUALISE x, y DRAW point" + sql, viz = ggsql.split_query(query) + assert sql == "SELECT x, y FROM data" + assert viz == "VISUALISE x, y DRAW point" + + def test_returns_empty_viz_without_visualise(self): + query = "SELECT x, y FROM data" + sql, viz = ggsql.split_query(query) + assert sql == "SELECT x, y FROM data" + assert viz == "" + + def test_handles_complex_query(self): + query = """ + SELECT date, SUM(revenue) as total + FROM sales + GROUP BY date + VISUALISE date AS x, total AS y + DRAW line + LABEL title => 'Revenue Over Time' + """ + sql, viz = ggsql.split_query(query) + assert "SELECT date, SUM(revenue)" in sql + assert "GROUP BY date" in sql + assert "VISUALISE date AS x" in viz + assert "LABEL title" in viz + + +class TestGgsqlRenderAltair: + @ggsql_render_works + def test_renders_simple_scatter(self): + import ggsql + + df = pl.DataFrame({"x": [1, 2, 3], "y": [4, 5, 6]}) + viz_spec = "VISUALISE x, y DRAW point" + chart = ggsql.render_altair(df, viz_spec) + result = chart.to_dict() + assert "$schema" in result + assert "vega-lite" in result["$schema"] + assert "layer" in result + + @ggsql_render_works + def test_returns_altair_chart(self): + import altair as alt + import ggsql + + df = pl.DataFrame({"a": [1, 2], "b": [3, 4]}) + viz_spec = "VISUALISE a AS x, b AS y DRAW line" + chart = ggsql.render_altair(df, viz_spec) + # ggsql returns LayerChart or other chart types + assert isinstance(chart, (alt.Chart, alt.LayerChart, alt.FacetChart)) + result = chart.to_dict() + assert result["$schema"] == "https://vega.github.io/schema/vega-lite/v6.json" + + @ggsql_render_works + def test_renders_pandas_dataframe(self): + import ggsql + import pandas as pd + + df = pd.DataFrame({"x": [1, 2, 3], "y": [4, 5, 6]}) + viz_spec = "VISUALISE x, y DRAW point" + chart = ggsql.render_altair(df, viz_spec) + result = chart.to_dict() + assert "$schema" in result + assert "vega-lite" in result["$schema"] + assert "layer" in result + + +class TestExtractTitle: + def test_extracts_title_from_label(self): + viz_spec = "VISUALISE x, y DRAW point LABEL title => 'My Chart'" + title = extract_title(viz_spec) + assert title == "My Chart" + + def test_returns_none_without_title(self): + viz_spec = "VISUALISE x, y DRAW point" + title = extract_title(viz_spec) + assert title is None + + def test_extracts_title_with_double_quotes(self): + viz_spec = 'VISUALISE x, y DRAW point LABEL title => "Double Quoted"' + title = extract_title(viz_spec) + assert title == "Double Quoted" diff --git a/pkg-py/tests/test_ggsql_integration.py b/pkg-py/tests/test_ggsql_integration.py new file mode 100644 index 00000000..05277f6c --- /dev/null +++ b/pkg-py/tests/test_ggsql_integration.py @@ -0,0 +1,225 @@ +"""Integration tests for ggsql visualization.""" + +import narwhals.stable.v1 as nw +import polars as pl +import pytest +from querychat._datasource import DataFrameSource +from querychat._querychat_core import AppState +from querychat.tools import ( + VisualizeDashboardData, + VisualizeQueryData, + tool_visualize_dashboard, + tool_visualize_query, +) + + +def _ggsql_render_works() -> bool: + """Check if ggsql.render_altair() is functional.""" + try: + import ggsql + + df = pl.DataFrame({"x": [1, 2], "y": [3, 4]}) + result = ggsql.render_altair(df, "VISUALISE x, y DRAW point") + spec = result.to_dict() + return "$schema" in spec + except (ValueError, ImportError): + return False + + +ggsql_render_works = pytest.mark.skipif( + not _ggsql_render_works(), + reason="ggsql.render_altair() not functional (build environment issue)", +) + + +@pytest.fixture +def sample_df(): + return pl.DataFrame( + { + "date": ["2024-01-01", "2024-01-02", "2024-01-03"], + "revenue": [100, 150, 120], + "category": ["A", "B", "A"], + } + ) + + +@pytest.fixture +def data_source(sample_df): + nw_df = nw.from_native(sample_df) + return DataFrameSource(nw_df, "test_data") + + +class TestVisualizeDashboardIntegration: + """Integration tests for visualize_dashboard tool.""" + + @ggsql_render_works + def test_creates_vegalite_chart(self, data_source): + """Test that visualize_dashboard stores spec (chart rendered on demand).""" + captured = {} + + def update_callback(data: VisualizeDashboardData): + captured.update(data) + + tool = tool_visualize_dashboard(data_source, update_callback) + impl = tool.func + + result = impl( + viz_spec="VISUALISE category AS x, revenue AS y DRAW bar", + title="Revenue by Category", + ) + + assert result.error is None + assert "spec" in captured + assert captured["title"] == "Revenue by Category" + # Chart is now rendered on demand, not stored + assert "chart" not in captured + + @ggsql_render_works + def test_extracts_title_from_spec(self, data_source): + """Test that title is extracted from LABEL clause when not provided.""" + captured = {} + + def update_callback(data: VisualizeDashboardData): + captured.update(data) + + tool = tool_visualize_dashboard(data_source, update_callback) + impl = tool.func + + impl( + viz_spec="VISUALISE category AS x, revenue AS y DRAW bar LABEL title => 'From Spec Title'", + title=None, + ) + + assert captured["title"] == "From Spec Title" + + +class TestVisualizeQueryIntegration: + """Integration tests for visualize_query tool.""" + + @ggsql_render_works + def test_executes_sql_and_creates_chart(self, data_source): + """Test that visualize_query stores ggsql (chart rendered on demand).""" + captured = {} + + def update_callback(data: VisualizeQueryData): + captured.update(data) + + tool = tool_visualize_query(data_source, update_callback) + impl = tool.func + + result = impl( + ggsql="SELECT category, SUM(revenue) as total FROM test_data GROUP BY category VISUALISE category AS x, total AS y DRAW bar", + title="Total Revenue by Category", + ) + + assert result.error is None + assert "ggsql" in captured + assert captured["title"] == "Total Revenue by Category" + # Chart is now rendered on demand, not stored + assert "chart" not in captured + + @ggsql_render_works + def test_handles_filter_in_query(self, data_source): + """Test that WHERE clause filters data correctly.""" + captured = {} + + def update_callback(data: VisualizeQueryData): + captured.update(data) + + tool = tool_visualize_query(data_source, update_callback) + impl = tool.func + + result = impl( + ggsql="SELECT date, revenue FROM test_data WHERE category = 'A' VISUALISE date AS x, revenue AS y DRAW line", + title="Category A Revenue", + ) + + assert result.error is None + # Chart is now rendered on demand, just verify the ggsql was captured + assert "ggsql" in captured + + def test_returns_error_without_visualise(self, data_source): + """Test that query without VISUALISE returns error.""" + captured = {} + + def update_callback(data: VisualizeQueryData): + captured.update(data) + + tool = tool_visualize_query(data_source, update_callback) + impl = tool.func + + result = impl(ggsql="SELECT * FROM test_data", title="No Viz") + + assert result.error is not None + assert "VISUALISE" in str(result.error) + assert "ggsql" not in captured + + +class TestAppStateVisualizationIntegration: + """Integration tests for AppState visualization handling.""" + + @ggsql_render_works + def test_state_serialization_includes_viz(self, data_source): + """Test that to_dict() includes visualization state (specs only).""" + from unittest.mock import MagicMock + + mock_client = MagicMock() + mock_client.get_turns.return_value = [] + + state = AppState(data_source=data_source, client=mock_client) + state.update_filter_viz( + spec="VISUALISE x, y DRAW point", + title="Test Chart", + ) + + state_dict = state.to_dict() + + assert state_dict["filter_viz_spec"] == "VISUALISE x, y DRAW point" + assert state_dict["filter_viz_title"] == "Test Chart" + # Chart is no longer stored, only spec + assert "filter_viz_chart" not in state_dict + + @ggsql_render_works + def test_state_deserialization_restores_viz(self, data_source): + """Test that update_from_dict() restores visualization state (specs only).""" + from unittest.mock import MagicMock + + mock_client = MagicMock() + mock_client.get_turns.return_value = [] + + state = AppState(data_source=data_source, client=mock_client) + + state_dict = { + "sql": None, + "title": None, + "error": None, + "turns": [], + "filter_viz_spec": "VISUALISE a, b DRAW line", + "filter_viz_title": "Restored Chart", + "query_viz_ggsql": None, + "query_viz_title": None, + } + + state.update_from_dict(state_dict) + + assert state.filter_viz_spec == "VISUALISE a, b DRAW line" + assert state.filter_viz_title == "Restored Chart" + + +class TestGgsqlRenderAltair: + """Tests for ggsql.render_altair() which renders charts on demand.""" + + @ggsql_render_works + def test_render_altair_returns_correct_type(self): + """Test that ggsql.render_altair returns Altair charts directly.""" + import altair as alt + import ggsql + + df = pl.DataFrame({"x": [1, 2, 3], "y": [10, 20, 30]}) + chart = ggsql.render_altair(df, "VISUALISE x, y DRAW point") + + # ggsql.render_altair now returns the correct Altair type + assert isinstance(chart, alt.TopLevelMixin) + spec = chart.to_dict() + assert "$schema" in spec + assert "vega-lite" in spec["$schema"] diff --git a/pkg-py/tests/test_state.py b/pkg-py/tests/test_state.py index 44bca227..58deb6a6 100644 --- a/pkg-py/tests/test_state.py +++ b/pkg-py/tests/test_state.py @@ -141,10 +141,14 @@ class TestCreateAppState: def test_creates_state_with_callbacks(self, data_source): callback_data: dict[str, Any] = {} - def client_factory(update_callback, reset_callback): + def client_factory( + update_callback, reset_callback, filter_viz_callback, query_viz_callback + ): # Store the callbacks for testing callback_data["update_callback"] = update_callback callback_data["reset_callback"] = reset_callback + callback_data["filter_viz_callback"] = filter_viz_callback + callback_data["query_viz_callback"] = query_viz_callback return MagicMock() state = create_app_state(data_source, client_factory, greeting="Welcome!") @@ -161,6 +165,20 @@ def client_factory(update_callback, reset_callback): assert state.sql is None assert state.title is None + # Test that the filter visualization callback works + callback_data["filter_viz_callback"]( + {"spec": "VISUALISE ...", "title": "Filter Chart"} + ) + assert state.filter_viz_spec == "VISUALISE ..." + assert state.filter_viz_title == "Filter Chart" + + # Test that the query visualization callback works + callback_data["query_viz_callback"]( + {"ggsql": "SELECT * VISUALISE ...", "title": "Query Chart"} + ) + assert state.query_viz_ggsql == "SELECT * VISUALISE ..." + assert state.query_viz_title == "Query Chart" + class TestStreamResponse: def test_stream_response_yields_strings(self): @@ -266,6 +284,10 @@ def test_app_state_dict_structure(self): "turns": [ {"role": "user", "contents": [{"content_type": "text", "text": "hi"}]} ], + "filter_viz_spec": None, + "filter_viz_title": None, + "query_viz_ggsql": None, + "query_viz_title": None, } assert state["sql"] == "SELECT * FROM test" assert len(state["turns"]) == 1 @@ -321,6 +343,10 @@ def test_update_from_dict_restores_turns(self, data_source, mock_client): ], }, ], + "filter_viz_spec": None, + "filter_viz_title": None, + "query_viz_ggsql": None, + "query_viz_title": None, } ) @@ -334,5 +360,51 @@ def test_update_from_dict_restores_turns(self, data_source, mock_client): def test_update_from_dict_empty_turns(self, data_source, mock_client): state = AppState(data_source=data_source, client=mock_client) - state.update_from_dict({"sql": None, "title": None, "error": None, "turns": []}) + state.update_from_dict( + { + "sql": None, + "title": None, + "error": None, + "turns": [], + "filter_viz_spec": None, + "filter_viz_title": None, + "query_viz_ggsql": None, + "query_viz_title": None, + } + ) mock_client.set_turns.assert_called_with([]) + + +class TestVisualizationState: + def test_initial_viz_state(self, data_source, mock_client): + state = AppState(data_source=data_source, client=mock_client) + assert state.filter_viz_spec is None + assert state.filter_viz_title is None + assert state.query_viz_ggsql is None + assert state.query_viz_title is None + + def test_update_filter_viz(self, data_source, mock_client): + state = AppState(data_source=data_source, client=mock_client) + state.update_filter_viz( + spec="VISUALISE x, y DRAW point", + title="Scatter Plot", + ) + assert state.filter_viz_spec == "VISUALISE x, y DRAW point" + assert state.filter_viz_title == "Scatter Plot" + + def test_update_query_viz(self, data_source, mock_client): + state = AppState(data_source=data_source, client=mock_client) + state.update_query_viz( + ggsql="SELECT x, y FROM t VISUALISE x, y DRAW point", + title="Query Plot", + ) + assert state.query_viz_ggsql == "SELECT x, y FROM t VISUALISE x, y DRAW point" + assert state.query_viz_title == "Query Plot" + + def test_reset_dashboard_clears_filter_viz(self, data_source, mock_client): + state = AppState(data_source=data_source, client=mock_client) + state.filter_viz_spec = "VISUALISE x, y DRAW point" + state.filter_viz_title = "Plot" + state.reset_dashboard() + assert state.filter_viz_spec is None + assert state.filter_viz_title is None diff --git a/pkg-py/tests/test_tools.py b/pkg-py/tests/test_tools.py index 682f259c..bd7e478e 100644 --- a/pkg-py/tests/test_tools.py +++ b/pkg-py/tests/test_tools.py @@ -12,6 +12,8 @@ def test_querychat_tool_starts_open_default_behavior(monkeypatch): assert querychat_tool_starts_open("query") is True assert querychat_tool_starts_open("update") is True assert querychat_tool_starts_open("reset") is False + assert querychat_tool_starts_open("visualize_dashboard") is True + assert querychat_tool_starts_open("visualize_query") is True def test_querychat_tool_starts_open_expanded(monkeypatch): @@ -21,6 +23,8 @@ def test_querychat_tool_starts_open_expanded(monkeypatch): assert querychat_tool_starts_open("query") is True assert querychat_tool_starts_open("update") is True assert querychat_tool_starts_open("reset") is True + assert querychat_tool_starts_open("visualize_dashboard") is True + assert querychat_tool_starts_open("visualize_query") is True def test_querychat_tool_starts_open_collapsed(monkeypatch): @@ -30,6 +34,8 @@ def test_querychat_tool_starts_open_collapsed(monkeypatch): assert querychat_tool_starts_open("query") is False assert querychat_tool_starts_open("update") is False assert querychat_tool_starts_open("reset") is False + assert querychat_tool_starts_open("visualize_dashboard") is False + assert querychat_tool_starts_open("visualize_query") is False def test_querychat_tool_starts_open_default_setting(monkeypatch): @@ -39,6 +45,8 @@ def test_querychat_tool_starts_open_default_setting(monkeypatch): assert querychat_tool_starts_open("query") is True assert querychat_tool_starts_open("update") is True assert querychat_tool_starts_open("reset") is False + assert querychat_tool_starts_open("visualize_dashboard") is True + assert querychat_tool_starts_open("visualize_query") is True def test_querychat_tool_starts_open_case_insensitive(monkeypatch): diff --git a/pkg-py/tests/test_viz_tools.py b/pkg-py/tests/test_viz_tools.py new file mode 100644 index 00000000..49af7087 --- /dev/null +++ b/pkg-py/tests/test_viz_tools.py @@ -0,0 +1,144 @@ +"""Tests for visualization tool functions.""" + +import narwhals.stable.v1 as nw +import polars as pl +import pytest +from querychat._datasource import DataFrameSource +from querychat.tools import ( + VisualizeDashboardData, + VisualizeQueryData, + tool_visualize_dashboard, + tool_visualize_query, +) + + +@pytest.fixture +def sample_df(): + return pl.DataFrame( + { + "x": [1, 2, 3, 4, 5], + "y": [10, 20, 15, 25, 30], + "category": ["A", "B", "A", "B", "A"], + } + ) + + +@pytest.fixture +def data_source(sample_df): + nw_df = nw.from_native(sample_df) + return DataFrameSource(nw_df, "test_data") + + +def _ggsql_render_works() -> bool: + """Check if ggsql.render_altair() is functional (build can be broken in some envs).""" + try: + import ggsql + + df = pl.DataFrame({"x": [1, 2], "y": [3, 4]}) + result = ggsql.render_altair(df, "VISUALISE x, y DRAW point") + spec = result.to_dict() + return "$schema" in spec + except (ValueError, ImportError): + return False + + +ggsql_render_works = pytest.mark.skipif( + not _ggsql_render_works(), + reason="ggsql.render_altair() not functional (build environment issue)", +) + + +class TestToolVisualizeDashboard: + def test_creates_tool(self, data_source): + callback_data = {} + + def update_fn(data: VisualizeDashboardData): + callback_data.update(data) + + tool = tool_visualize_dashboard(data_source, update_fn) + assert tool.name == "querychat_visualize_dashboard" + + @ggsql_render_works + def test_tool_renders_visualization(self, data_source): + callback_data = {} + + def update_fn(data: VisualizeDashboardData): + callback_data.update(data) + + tool = tool_visualize_dashboard(data_source, update_fn) + impl = tool.func + + impl(viz_spec="VISUALISE x, y DRAW point", title="Test Scatter") + + assert "spec" in callback_data + assert "title" in callback_data + assert callback_data["title"] == "Test Scatter" + # Chart is now rendered on demand, not stored in callback data + assert "chart" not in callback_data + + @ggsql_render_works + def test_tool_extracts_title_from_spec(self, data_source): + callback_data = {} + + def update_fn(data: VisualizeDashboardData): + callback_data.update(data) + + tool = tool_visualize_dashboard(data_source, update_fn) + impl = tool.func + + impl( + viz_spec="VISUALISE x, y DRAW point LABEL title => 'From Spec'", title=None + ) + + # Title from spec should be used when title param is None + assert callback_data["title"] == "From Spec" + + +class TestToolVisualizeQuery: + def test_creates_tool(self, data_source): + callback_data = {} + + def update_fn(data: VisualizeQueryData): + callback_data.update(data) + + tool = tool_visualize_query(data_source, update_fn) + assert tool.name == "querychat_visualize_query" + + @ggsql_render_works + def test_tool_executes_sql_and_renders(self, data_source): + callback_data = {} + + def update_fn(data: VisualizeQueryData): + callback_data.update(data) + + tool = tool_visualize_query(data_source, update_fn) + impl = tool.func + + impl( + ggsql="SELECT x, y FROM test_data WHERE x > 2 VISUALISE x, y DRAW point", + title="Filtered Scatter", + ) + + assert "ggsql" in callback_data + assert "title" in callback_data + assert callback_data["title"] == "Filtered Scatter" + # Chart is now rendered on demand, not stored in callback data + assert "chart" not in callback_data + + @ggsql_render_works + def test_tool_handles_query_without_visualise(self, data_source): + callback_data = {} + + def update_fn(data: VisualizeQueryData): + callback_data.update(data) + + tool = tool_visualize_query(data_source, update_fn) + impl = tool.func + + # Query without VISUALISE should return error result + result = impl(ggsql="SELECT x, y FROM test_data", title="No Viz") + + # Check that error is returned and callback was not called + assert result.error is not None + assert "VISUALISE" in str(result.error) + assert "chart" not in callback_data