Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions pkg-py/src/querychat/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,12 @@
from ._deprecated import mod_server as server
from ._deprecated import mod_ui as ui
from ._shiny import QueryChat
from .tools import VisualizeDashboardData, VisualizeQueryData

__all__ = (
"QueryChat",
"VisualizeDashboardData",
"VisualizeQueryData",
# TODO(lifecycle): Remove these deprecated functions when we reach v1.0
"greeting",
"init",
Expand Down
228 changes: 220 additions & 8 deletions pkg-py/src/querychat/_dash.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from collections.abc import Callable
from pathlib import Path as PathType

import altair as alt
import chatlas
import ibis
import narwhals.stable.v1 as nw
Expand Down Expand Up @@ -113,7 +114,12 @@ def __init__(
*,
greeting: Optional[str | PathType] = None,
client: Optional[str | chatlas.Chat] = None,
tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = ("update", "query"),
tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = (
"update",
"query",
"visualize_dashboard",
"visualize_query",
),
data_description: Optional[str | PathType] = None,
categorical_threshold: int = 20,
extra_instructions: Optional[str | PathType] = None,
Expand All @@ -129,7 +135,12 @@ def __init__(
*,
greeting: Optional[str | PathType] = None,
client: Optional[str | chatlas.Chat] = None,
tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = ("update", "query"),
tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = (
"update",
"query",
"visualize_dashboard",
"visualize_query",
),
data_description: Optional[str | PathType] = None,
categorical_threshold: int = 20,
extra_instructions: Optional[str | PathType] = None,
Expand All @@ -145,7 +156,12 @@ def __init__(
*,
greeting: Optional[str | PathType] = None,
client: Optional[str | chatlas.Chat] = None,
tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = ("update", "query"),
tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = (
"update",
"query",
"visualize_dashboard",
"visualize_query",
),
data_description: Optional[str | PathType] = None,
categorical_threshold: int = 20,
extra_instructions: Optional[str | PathType] = None,
Expand All @@ -161,7 +177,12 @@ def __init__(
*,
greeting: Optional[str | PathType] = None,
client: Optional[str | chatlas.Chat] = None,
tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = ("update", "query"),
tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = (
"update",
"query",
"visualize_dashboard",
"visualize_query",
),
data_description: Optional[str | PathType] = None,
categorical_threshold: int = 20,
extra_instructions: Optional[str | PathType] = None,
Expand All @@ -176,7 +197,12 @@ def __init__(
*,
greeting: Optional[str | PathType] = None,
client: Optional[str | chatlas.Chat] = None,
tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = ("update", "query"),
tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = (
"update",
"query",
"visualize_dashboard",
"visualize_query",
),
data_description: Optional[str | PathType] = None,
categorical_threshold: int = 20,
extra_instructions: Optional[str | PathType] = None,
Expand Down Expand Up @@ -207,6 +233,96 @@ def store_id(self) -> str:
"""
return self._ids.store

def ggvis(
self, state: AppStateDict, source: Literal["filter", "query"] = "filter"
) -> alt.Chart | None:
"""
Get the current Altair visualization chart.

Parameters
----------
state
The state dict from the store component.
source
Which visualization to return. "filter" returns the dashboard
visualization (from visualize_dashboard tool), "query" returns
the query visualization (from visualize_query tool).

Returns
-------
:
An Altair Chart object, or None if no visualization exists.

"""
import ggsql

if source == "filter":
spec = state.get("filter_viz_spec")
if spec is None:
return None
# Render against current filtered data
df = as_narwhals(self.df(state))
return ggsql.render_altair(df, spec)
else:
ggsql_query = state.get("query_viz_ggsql")
if ggsql_query is None:
return None
# Re-execute SQL and render
sql, viz_spec = ggsql.split_query(ggsql_query)
df = as_narwhals(self._data_source.execute_query(sql))
return ggsql.render_altair(df, viz_spec)

def ggsql(
self, state: AppStateDict, source: Literal["filter", "query"] = "filter"
) -> str | None:
"""
Get the current ggsql specification.

Parameters
----------
state
The state dict from the store component.
source
Which specification to return. "filter" returns the VISUALISE spec
from visualize_dashboard, "query" returns the full ggsql query
from visualize_query.

Returns
-------
:
The ggsql specification string, or None if no visualization exists.

"""
if source == "filter":
return state.get("filter_viz_spec")
else:
return state.get("query_viz_ggsql")

def ggtitle(
self, state: AppStateDict, source: Literal["filter", "query"] = "filter"
) -> str | None:
"""
Get the current visualization title.

Parameters
----------
state
The state dict from the store component.
source
Which title to return. "filter" returns the title from
visualize_dashboard, "query" returns the title from visualize_query.

Returns
-------
:
The visualization title, or None if no visualization exists.

"""
if source == "filter":
return state.get("filter_viz_title")
else:
return state.get("query_viz_title")

def app(self) -> dash.Dash:
"""
Create a complete Dash app.
Expand Down Expand Up @@ -393,6 +509,34 @@ def app_layout(ids: IDs, table_name: str, chat_ui):
body_class_name="d-flex flex-column",
)

# Filter plot card
filter_plot_card = card_ui(
body=[
html.Iframe(
id=ids.filter_plot,
srcDoc="<p>No filter visualization yet. Ask the assistant to create one.</p>",
style={"width": "100%", "height": "400px", "border": "none"},
),
dcc.Markdown(id=ids.filter_ggsql, className="querychat-ggsql-display mt-2"),
],
title_id=ids.filter_plot_title,
class_name="h-100",
)

# Query plot card
query_plot_card = card_ui(
body=[
html.Iframe(
id=ids.query_plot,
srcDoc="<p>No query visualization yet. Ask the assistant to create one.</p>",
style={"width": "100%", "height": "400px", "border": "none"},
),
dcc.Markdown(id=ids.query_ggsql, className="querychat-ggsql-display mt-2"),
],
title_id=ids.query_plot_title,
class_name="h-100",
)

chat_card = card_ui(
body=chat_ui,
title="Chat",
Expand All @@ -407,7 +551,27 @@ def app_layout(ids: IDs, table_name: str, chat_ui):
[
dbc.Col(chat_card, width=4),
dbc.Col(
[sql_card, data_card],
dbc.Tabs(
[
dbc.Tab(
[sql_card, data_card],
label="Data",
tab_id="data-tab",
),
dbc.Tab(
filter_plot_card,
label="Filter Plot",
tab_id="filter-plot-tab",
),
dbc.Tab(
query_plot_card,
label="Query Plot",
tab_id="query-plot-tab",
),
],
id=ids.tabs,
active_tab="data-tab",
),
width=8,
className="d-flex flex-column",
),
Expand All @@ -426,12 +590,15 @@ def register_app_callbacks(
table_name: str,
deserialize_state: Callable[[AppStateDict], AppState],
) -> None:
"""Register callbacks for SQL display, data table, and export."""
"""Register callbacks for SQL display, data table, visualizations, and export."""
import ggsql
from dash.dcc.express import send_data_frame

import dash
from dash import Input, Output, State

from ._ggsql import vegalite_to_html

@app.callback(
[
Output(ids.sql_title, "children"),
Expand All @@ -440,6 +607,12 @@ def register_app_callbacks(
Output(ids.data_table, "columnDefs"),
Output(ids.data_info, "children"),
Output(ids.store, "data", allow_duplicate=True),
Output(ids.filter_plot_title, "children"),
Output(ids.filter_plot, "srcDoc"),
Output(ids.filter_ggsql, "children"),
Output(ids.query_plot_title, "children"),
Output(ids.query_plot, "srcDoc"),
Output(ids.query_ggsql, "children"),
],
[
Input(ids.store, "data"),
Expand All @@ -459,7 +632,8 @@ def update_display(state_data: AppStateDict, reset_clicks):
sql_title = state.title or "SQL Query"
sql_code = f"```sql\n{state.get_display_sql()}\n```"

nw_df = as_narwhals(state.get_current_data())
current_data = state.get_current_data()
nw_df = as_narwhals(current_data)
nrow, ncol = nw_df.shape

display_df = nw_df.to_pandas()
Expand All @@ -472,13 +646,51 @@ def update_display(state_data: AppStateDict, reset_clicks):
data_info_parts.append(f"Data has {nrow} rows and {ncol} columns.")
data_info = " ".join(data_info_parts)

# Filter visualization - render on demand
filter_title = state.filter_viz_title or "Filter Plot"
filter_spec = state.filter_viz_spec

if filter_spec:
# Render against current filtered data
chart = ggsql.render_altair(nw_df, filter_spec)
filter_html = vegalite_to_html(chart.to_dict())
else:
filter_html = (
"<p>No filter visualization yet. Ask the assistant to create one.</p>"
)

filter_ggsql_md = f"```sql\n{filter_spec}\n```" if filter_spec else ""

# Query visualization - render on demand
query_title = state.query_viz_title or "Query Plot"
query_ggsql_str = state.query_viz_ggsql

if query_ggsql_str:
# Re-execute SQL and render
sql_part, viz_spec = ggsql.split_query(query_ggsql_str)
query_df = as_narwhals(state.data_source.execute_query(sql_part))
chart = ggsql.render_altair(query_df, viz_spec)
query_html = vegalite_to_html(chart.to_dict())
else:
query_html = (
"<p>No query visualization yet. Ask the assistant to create one.</p>"
)

query_ggsql_md = f"```sql\n{query_ggsql_str}\n```" if query_ggsql_str else ""

return (
sql_title,
sql_code,
table_data,
table_columns,
data_info,
state.to_dict(),
filter_title,
filter_html,
filter_ggsql_md,
query_title,
query_html,
query_ggsql_md,
)

@app.callback(
Expand Down
14 changes: 14 additions & 0 deletions pkg-py/src/querychat/_dash_ui.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,13 @@ class IDs:
data_info: str
download_csv: str
export_button: str
tabs: str
filter_plot: str
filter_plot_title: str
filter_ggsql: str
query_plot: str
query_plot_title: str
query_ggsql: str

@classmethod
def from_table_name(cls, table_name: str) -> IDs:
Expand All @@ -46,6 +53,13 @@ def from_table_name(cls, table_name: str) -> IDs:
data_info=f"{prefix}-data-info",
download_csv=f"{prefix}-download-csv",
export_button=f"{prefix}-export-button",
tabs=f"{prefix}-tabs",
filter_plot=f"{prefix}-filter-plot",
filter_plot_title=f"{prefix}-filter-plot-title",
filter_ggsql=f"{prefix}-filter-ggsql",
query_plot=f"{prefix}-query-plot",
query_plot_title=f"{prefix}-query-plot-title",
query_ggsql=f"{prefix}-query-ggsql",
)


Expand Down
1 change: 1 addition & 0 deletions pkg-py/src/querychat/_datasource.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,7 @@ def __init__(self, df: nw.DataFrame, table_name: str):
self._df_lib = native_namespace.__name__

self._conn = duckdb.connect(database=":memory:")
# NOTE: if native representation is polars, pyarrow is required for registration
self._conn.register(table_name, self._df.to_native())
self._conn.execute("""
-- extensions: lock down supply chain + auto behaviors
Expand Down
Loading
Loading