diff --git a/pkg-py/examples/10-viz-app.py b/pkg-py/examples/10-viz-app.py new file mode 100644 index 000000000..54df9a164 --- /dev/null +++ b/pkg-py/examples/10-viz-app.py @@ -0,0 +1,21 @@ +from querychat import QueryChat +from querychat.data import titanic + +from shiny import App, ui + +qc = QueryChat( + titanic(), + "titanic", + tools=("query", "visualize_query"), +) + +app_ui = ui.page_fillable( + qc.ui(), +) + + +def server(input, output, session): + qc.server() + + +app = App(app_ui, server) diff --git a/pkg-py/src/querychat/__init__.py b/pkg-py/src/querychat/__init__.py index 0e3eaa5f5..04ba01723 100644 --- a/pkg-py/src/querychat/__init__.py +++ b/pkg-py/src/querychat/__init__.py @@ -2,9 +2,11 @@ from ._deprecated import mod_server as server from ._deprecated import mod_ui as ui from ._shiny import QueryChat +from .tools import VisualizeQueryData __all__ = ( "QueryChat", + "VisualizeQueryData", # TODO(lifecycle): Remove these deprecated functions when we reach v1.0 "greeting", "init", diff --git a/pkg-py/src/querychat/_datasource.py b/pkg-py/src/querychat/_datasource.py index 5cac5f08c..af7628f5c 100644 --- a/pkg-py/src/querychat/_datasource.py +++ b/pkg-py/src/querychat/_datasource.py @@ -214,6 +214,7 @@ def __init__(self, df: nw.DataFrame, table_name: str): self._df_lib = native_namespace.__name__ self._conn = duckdb.connect(database=":memory:") + # NOTE: if native representation is polars, pyarrow is required for registration self._conn.register(table_name, self._df.to_native()) self._conn.execute(""" -- extensions: lock down supply chain + auto behaviors diff --git a/pkg-py/src/querychat/_ggsql.py b/pkg-py/src/querychat/_ggsql.py new file mode 100644 index 000000000..34dfbb5d4 --- /dev/null +++ b/pkg-py/src/querychat/_ggsql.py @@ -0,0 +1,80 @@ +"""Helpers for ggsql integration.""" + +from __future__ import annotations + +import json +from typing import TYPE_CHECKING + +import narwhals.stable.v1 as nw + +if TYPE_CHECKING: + import ggsql + import polars as pl + from narwhals.stable.v1.typing import IntoFrame + + from ._datasource import DataSource + + +def to_polars(data: IntoFrame) -> pl.DataFrame: + """Convert any narwhals-compatible frame to a polars DataFrame.""" + nw_df = nw.from_native(data) + if isinstance(nw_df, nw.LazyFrame): + nw_df = nw_df.collect() + return nw_df.to_polars() + + +def execute_ggsql(data_source: DataSource, query: str) -> ggsql.Spec: + """ + Execute a full ggsql query against a DataSource, returning a Spec. + + Uses ggsql.validate() to split SQL from VISUALISE, executes the SQL + through DataSource (preserving database pushdown), then feeds the result + into a ggsql DuckDBReader to produce a Spec. + + Parameters + ---------- + data_source + The querychat DataSource to execute the SQL portion against. + query + A full ggsql query (SQL + VISUALISE). + + Returns + ------- + ggsql.Spec + The writer-independent plot specification. + + """ + import ggsql as _ggsql + + validated = _ggsql.validate(query) + pl_df = to_polars(data_source.execute_query(validated.sql())) + + reader = _ggsql.DuckDBReader("duckdb://memory") + reader.register("_data", pl_df) + return reader.execute(f"SELECT * FROM _data {validated.visual()}") + + +def spec_to_altair(spec: ggsql.Spec) -> ggsql.AltairChart: + """Render a ggsql Spec to an Altair chart via VegaLiteWriter.""" + import ggsql as _ggsql + + writer = _ggsql.VegaLiteWriter() + return writer.render_chart(spec, validate=False) + + +def extract_title(spec: ggsql.Spec) -> str | None: + """ + Extract the title from a ggsql Spec's rendered Vega-Lite JSON. + + TODO: Replace with ``spec.title()`` once ggsql exposes this natively. + """ + import ggsql as _ggsql + + writer = _ggsql.VegaLiteWriter() + vl: dict[str, object] = json.loads(writer.render(spec)) + title = vl.get("title") + if isinstance(title, str): + return title + if isinstance(title, dict): + return title.get("text") + return None diff --git a/pkg-py/src/querychat/_icons.py b/pkg-py/src/querychat/_icons.py index 2b7683da0..61880f830 100644 --- a/pkg-py/src/querychat/_icons.py +++ b/pkg-py/src/querychat/_icons.py @@ -2,7 +2,14 @@ from shiny import ui -ICON_NAMES = Literal["arrow-counterclockwise", "funnel-fill", "terminal-fill", "table"] +ICON_NAMES = Literal[ + "arrow-counterclockwise", + "bar-chart-fill", + "funnel-fill", + "graph-up", + "terminal-fill", + "table", +] def bs_icon(name: ICON_NAMES) -> ui.HTML: @@ -14,7 +21,9 @@ def bs_icon(name: ICON_NAMES) -> ui.HTML: BS_ICONS = { "arrow-counterclockwise": '', + "bar-chart-fill": '', "funnel-fill": '', + "graph-up": '', "terminal-fill": '', "table": '', } diff --git a/pkg-py/src/querychat/_querychat_base.py b/pkg-py/src/querychat/_querychat_base.py index e8a7c7f15..839e2e5e6 100644 --- a/pkg-py/src/querychat/_querychat_base.py +++ b/pkg-py/src/querychat/_querychat_base.py @@ -25,9 +25,11 @@ from ._utils import MISSING, MISSING_TYPE, is_ibis_table from .tools import ( UpdateDashboardData, + VisualizeQueryData, tool_query, tool_reset_dashboard, tool_update_dashboard, + tool_visualize_query, ) if TYPE_CHECKING: @@ -35,7 +37,32 @@ from narwhals.stable.v1.typing import IntoFrame -TOOL_GROUPS = Literal["update", "query"] +TOOL_GROUPS = Literal["update", "query", "visualize_query"] +DEFAULT_TOOLS: tuple[TOOL_GROUPS, ...] = ("update", "query") +ALL_TOOLS: tuple[TOOL_GROUPS, ...] = ( + "update", + "query", + "visualize_query", +) + +VIZ_TOOLS: tuple[TOOL_GROUPS, ...] = ("visualize_query",) + + +def check_viz_dependencies(tools: tuple[TOOL_GROUPS, ...] | None) -> None: + """Raise ImportError early if viz tools are requested but ggsql is not installed.""" + if tools is None: + return + has_viz = any(t in VIZ_TOOLS for t in tools) + if not has_viz: + return + try: + import altair as alt # noqa: F401 + import ggsql # noqa: F401 + except ImportError as e: + raise ImportError( + f"Visualization tools require ggsql and altair: {e}. " + "Install them with: pip install querychat[viz]" + ) from e class QueryChatBase(Generic[IntoFrameT]): @@ -58,7 +85,7 @@ def __init__( *, greeting: Optional[str | Path] = None, client: Optional[str | chatlas.Chat] = None, - tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = ("update", "query"), + tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = DEFAULT_TOOLS, data_description: Optional[str | Path] = None, categorical_threshold: int = 20, extra_instructions: Optional[str | Path] = None, @@ -72,7 +99,8 @@ def __init__( "Table name must begin with a letter and contain only letters, numbers, and underscores", ) - self.tools = normalize_tools(tools, default=("update", "query")) + self.tools = normalize_tools(tools, default=DEFAULT_TOOLS) + check_viz_dependencies(self.tools) self.greeting = greeting.read_text() if isinstance(greeting, Path) else greeting # Store init parameters for deferred system prompt building @@ -132,6 +160,7 @@ def client( tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None | MISSING_TYPE = MISSING, update_dashboard: Callable[[UpdateDashboardData], None] | None = None, reset_dashboard: Callable[[], None] | None = None, + visualize_query: Callable[[VisualizeQueryData], None] | None = None, ) -> chatlas.Chat: """ Create a chat client with registered tools. @@ -139,11 +168,14 @@ def client( Parameters ---------- tools - Which tools to include: `"update"`, `"query"`, or both. + Which tools to include: `"update"`, `"query"`, `"visualize_query"`, + or a combination. update_dashboard Callback when update_dashboard tool succeeds. reset_dashboard Callback when reset_dashboard tool is invoked. + visualize_query + Callback when visualize_query tool succeeds. Returns ------- @@ -172,6 +204,10 @@ def client( if "query" in tools: chat.register_tool(tool_query(data_source)) + if "visualize_query" in tools: + query_viz_fn = visualize_query or (lambda _: None) + chat.register_tool(tool_visualize_query(self._data_source, query_viz_fn)) + return chat def generate_greeting(self, *, echo: Literal["none", "output"] = "none") -> str: diff --git a/pkg-py/src/querychat/_shiny.py b/pkg-py/src/querychat/_shiny.py index c1dcc9a19..bbda9fb76 100644 --- a/pkg-py/src/querychat/_shiny.py +++ b/pkg-py/src/querychat/_shiny.py @@ -10,13 +10,14 @@ from shiny import App, Inputs, Outputs, Session, reactive, render, req, ui from ._icons import bs_icon -from ._querychat_base import TOOL_GROUPS, QueryChatBase +from ._querychat_base import DEFAULT_TOOLS, TOOL_GROUPS, QueryChatBase from ._shiny_module import ServerValues, mod_server, mod_ui from ._utils import as_narwhals if TYPE_CHECKING: from pathlib import Path + import altair as alt import chatlas import ibis import narwhals.stable.v1 as nw @@ -97,10 +98,11 @@ class QueryChat(QueryChatBase[IntoFrameT]): tools Which querychat tools to include in the chat client by default. Can be: - A single tool string: `"update"` or `"query"` - - A tuple of tools: `("update", "query")` + - A tuple of tools: `("update", "query", "visualize_query")` - `None` or `()` to disable all tools - Default is `("update", "query")` (both tools enabled). + Default is `("update", "query")`. The visualization tool (`"visualize_query"`) + can be opted into by including it in the tuple. Set to `"update"` to prevent the LLM from accessing data values, only allowing dashboard filtering without answering questions. @@ -156,7 +158,7 @@ def __init__( id: Optional[str] = None, greeting: Optional[str | Path] = None, client: Optional[str | chatlas.Chat] = None, - tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = ("update", "query"), + tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = DEFAULT_TOOLS, data_description: Optional[str | Path] = None, categorical_threshold: int = 20, extra_instructions: Optional[str | Path] = None, @@ -172,7 +174,7 @@ def __init__( id: Optional[str] = None, greeting: Optional[str | Path] = None, client: Optional[str | chatlas.Chat] = None, - tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = ("update", "query"), + tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = DEFAULT_TOOLS, data_description: Optional[str | Path] = None, categorical_threshold: int = 20, extra_instructions: Optional[str | Path] = None, @@ -188,7 +190,7 @@ def __init__( id: Optional[str] = None, greeting: Optional[str | Path] = None, client: Optional[str | chatlas.Chat] = None, - tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = ("update", "query"), + tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = DEFAULT_TOOLS, data_description: Optional[str | Path] = None, categorical_threshold: int = 20, extra_instructions: Optional[str | Path] = None, @@ -204,7 +206,7 @@ def __init__( id: Optional[str] = None, greeting: Optional[str | Path] = None, client: Optional[str | chatlas.Chat] = None, - tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = ("update", "query"), + tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = DEFAULT_TOOLS, data_description: Optional[str | Path] = None, categorical_threshold: int = 20, extra_instructions: Optional[str | Path] = None, @@ -219,7 +221,7 @@ def __init__( id: Optional[str] = None, greeting: Optional[str | Path] = None, client: Optional[str | chatlas.Chat] = None, - tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = ("update", "query"), + tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = DEFAULT_TOOLS, data_description: Optional[str | Path] = None, categorical_threshold: int = 20, extra_instructions: Optional[str | Path] = None, @@ -245,9 +247,13 @@ def app( """ Quickly chat with a dataset. - Creates a Shiny app with a chat sidebar and data table view -- providing a + Creates a Shiny app with a chat sidebar and tabbed view -- providing a quick-and-easy way to start chatting with your data. + The app includes two tabs: + - **Data**: Shows the filtered data table + - **Query Plot**: Shows the most recent query visualization + Parameters ---------- bookmark_store @@ -266,7 +272,30 @@ def app( enable_bookmarking = bookmark_store != "disable" table_name = data_source.table_name + tools_tuple = ( + (self.tools,) + if isinstance(self.tools, str) + else (self.tools or ()) + ) + has_query_viz = "visualize_query" in tools_tuple + def app_ui(request): + nav_panels = [ + ui.nav_panel( + "Data", + ui.card( + ui.card_header(bs_icon("table"), " Data"), + ui.output_data_frame("dt"), + ), + ), + ] + if has_query_viz: + nav_panels.append( + ui.nav_panel( + "Query Plot", + ui.output_ui("query_plot_container"), + ) + ) return ui.page_sidebar( self.sidebar(), ui.card( @@ -285,10 +314,7 @@ def app_ui(request): fill=False, style="max-height: 33%;", ), - ui.card( - ui.card_header(bs_icon("table"), " Data"), - ui.output_data_frame("dt"), - ), + ui.navset_tab(*nav_panels, id="main_tabs"), title=ui.span("querychat with ", ui.code(table_name)), class_="bslib-page-dashboard", fillable=True, @@ -301,6 +327,7 @@ def app_server(input: Inputs, output: Outputs, session: Session): greeting=self.greeting, client=self._client, enable_bookmarking=enable_bookmarking, + tools=self.tools, ) @render.text @@ -338,6 +365,37 @@ def sql_output(): width="100%", ) + if has_query_viz: + + @render.ui + def query_plot_container(): + from shinywidgets import output_widget, render_altair + + chart = vals.query_viz_chart() + if chart is None: + return ui.card( + ui.card_body( + ui.p( + "No query visualization yet. " + "Use the chat to create one." + ), + class_="text-muted text-center py-5", + ), + ) + + @render_altair + def query_chart(): + return chart + + return ui.card( + ui.card_header( + bs_icon("bar-chart-fill"), + " ", + vals.query_viz_title.get() or "Query Visualization", + ), + output_widget("query_chart"), + ) + return App(app_ui, app_server, bookmark_store=bookmark_store) def sidebar( @@ -493,6 +551,7 @@ def title(): greeting=self.greeting, client=self.client, enable_bookmarking=enable_bookmarking, + tools=self.tools, ) @@ -730,6 +789,7 @@ def __init__( greeting=self.greeting, client=self._client, enable_bookmarking=enable, + tools=self.tools, ) def sidebar( @@ -870,3 +930,39 @@ def title(self, value: Optional[str] = None) -> str | None | bool: return self._vals.title() else: return self._vals.title.set(value) + + def ggvis(self) -> alt.Chart | None: + """ + Get the visualization chart from the most recent visualize_query call. + + Returns + ------- + : + The Altair chart, or None if no visualization exists. + + """ + return self._vals.query_viz_chart() + + def ggsql(self) -> str | None: + """ + Get the full ggsql query from the most recent visualize_query call. + + Returns + ------- + : + The ggsql query string, or None if no visualization exists. + + """ + return self._vals.query_viz_ggsql.get() + + def ggtitle(self) -> str | None: + """ + Get the visualization title from the most recent visualize_query call. + + Returns + ------- + : + The title, or None if no visualization exists. + + """ + return self._vals.query_viz_title.get() diff --git a/pkg-py/src/querychat/_shiny_module.py b/pkg-py/src/querychat/_shiny_module.py index 335f6803a..acc550411 100644 --- a/pkg-py/src/querychat/_shiny_module.py +++ b/pkg-py/src/querychat/_shiny_module.py @@ -1,6 +1,7 @@ from __future__ import annotations import copy +import logging import warnings from dataclasses import dataclass from pathlib import Path @@ -13,17 +14,26 @@ from shiny import module, reactive, ui from ._querychat_core import GREETING_PROMPT -from .tools import tool_query, tool_reset_dashboard, tool_update_dashboard +from .tools import ( + tool_query, + tool_reset_dashboard, + tool_update_dashboard, + tool_visualize_query, +) if TYPE_CHECKING: from collections.abc import Callable + import altair as alt from shiny.bookmark import BookmarkState, RestoreState from shiny import Inputs, Outputs, Session from ._datasource import DataSource - from .types import UpdateDashboardData + from ._querychat_base import TOOL_GROUPS + from .tools import UpdateDashboardData, VisualizeQueryData + +logger = logging.getLogger(__name__) ReactiveString = reactive.Value[str] """A reactive string value.""" @@ -79,6 +89,16 @@ class ServerValues(Generic[IntoFrameT]): The session-specific chat client instance. This is a deep copy of the base client configured for this specific session, containing the chat history and tool registrations for this session only. + query_viz_ggsql + A reactive Value containing the full ggsql query from visualize_query. + Returns `None` if no visualization has been created. + query_viz_title + A reactive Value containing the title from visualize_query. + Returns `None` if no visualization has been created. + query_viz_chart + A callable returning the rendered Altair chart from visualize_query. + Returns `None` if no visualization has been created. The chart is + re-rendered on each call using `ggsql.render_altair()`. """ @@ -86,6 +106,10 @@ class ServerValues(Generic[IntoFrameT]): sql: ReactiveStringOrNone title: ReactiveStringOrNone client: chatlas.Chat + # Visualization state + query_viz_ggsql: ReactiveStringOrNone + query_viz_title: ReactiveStringOrNone + query_viz_chart: Callable[[], alt.TopLevelMixin | None] @module.server @@ -98,12 +122,17 @@ def mod_server( greeting: str | None, client: chatlas.Chat | Callable, enable_bookmarking: bool, + tools: tuple[TOOL_GROUPS, ...] | None = None, ) -> ServerValues[IntoFrameT]: # Reactive values to store state sql = ReactiveStringOrNone(None) title = ReactiveStringOrNone(None) has_greeted = reactive.value[bool](False) # noqa: FBT003 + # Visualization state - store only specs, render on demand + query_viz_ggsql = ReactiveStringOrNone(None) + query_viz_title = ReactiveStringOrNone(None) + # Short-circuit for stub sessions (e.g. 1st run of an Express app) # data_source may be None during stub session for deferred pattern if session.is_stub_session(): @@ -116,6 +145,9 @@ def _stub_df(): sql=sql, title=title, client=client if isinstance(client, chatlas.Chat) else client(), + query_viz_ggsql=query_viz_ggsql, + query_viz_title=query_viz_title, + query_viz_chart=lambda: None, ) # Real session requires data_source @@ -133,11 +165,17 @@ def reset_dashboard(): sql.set(None) title.set(None) + def update_query_viz(data: VisualizeQueryData): + query_viz_ggsql.set(data["ggsql"]) + query_viz_title.set(data["title"]) + # Set up the chat object for this session # Support both a callable that creates a client and legacy instance pattern if callable(client) and not isinstance(client, chatlas.Chat): chat = client( - update_dashboard=update_dashboard, reset_dashboard=reset_dashboard + update_dashboard=update_dashboard, + reset_dashboard=reset_dashboard, + visualize_query=update_query_viz, ) else: # Legacy pattern: client is Chat instance @@ -147,12 +185,26 @@ def reset_dashboard(): chat.register_tool(tool_query(data_source)) chat.register_tool(tool_reset_dashboard(reset_dashboard)) + if tools and "visualize_query" in tools: + chat.register_tool(tool_visualize_query(data_source, update_query_viz)) + # Execute query when SQL changes @reactive.calc def filtered_df(): query = sql.get() - df = data_source.get_data() if not query else data_source.execute_query(query) - return df + return data_source.get_data() if not query else data_source.execute_query(query) + + # Render query visualization on demand + @reactive.calc + def render_query_viz_chart(): + from ._ggsql import execute_ggsql, spec_to_altair + + ggsql_query = query_viz_ggsql.get() + if ggsql_query is None: + return None + + spec = execute_ggsql(data_source, ggsql_query) + return spec_to_altair(spec) # Chat UI logic chat_ui = shinychat.Chat(CHAT_ID) @@ -209,6 +261,8 @@ def _on_bookmark(x: BookmarkState) -> None: vals["querychat_sql"] = sql.get() vals["querychat_title"] = title.get() vals["querychat_has_greeted"] = has_greeted.get() + vals["querychat_query_viz_ggsql"] = query_viz_ggsql.get() + vals["querychat_query_viz_title"] = query_viz_title.get() @session.bookmark.on_restore def _on_restore(x: RestoreState) -> None: @@ -219,8 +273,20 @@ def _on_restore(x: RestoreState) -> None: title.set(vals["querychat_title"]) if "querychat_has_greeted" in vals: has_greeted.set(vals["querychat_has_greeted"]) - - return ServerValues(df=filtered_df, sql=sql, title=title, client=chat) + if "querychat_query_viz_ggsql" in vals: + query_viz_ggsql.set(vals["querychat_query_viz_ggsql"]) + if "querychat_query_viz_title" in vals: + query_viz_title.set(vals["querychat_query_viz_title"]) + + return ServerValues( + df=filtered_df, + sql=sql, + title=title, + client=chat, + query_viz_ggsql=query_viz_ggsql, + query_viz_title=query_viz_title, + query_viz_chart=render_query_viz_chart, + ) class GreetWarning(Warning): diff --git a/pkg-py/src/querychat/_system_prompt.py b/pkg-py/src/querychat/_system_prompt.py index 5a8445e93..c81ff478b 100644 --- a/pkg-py/src/querychat/_system_prompt.py +++ b/pkg-py/src/querychat/_system_prompt.py @@ -83,6 +83,7 @@ def render(self, tools: tuple[TOOL_GROUPS, ...] | None) -> str: "extra_instructions": self.extra_instructions, "has_tool_update": "update" in tools if tools else False, "has_tool_query": "query" in tools if tools else False, + "has_tool_visualize_query": "visualize_query" in tools if tools else False, "include_query_guidelines": len(tools or ()) > 0, } diff --git a/pkg-py/src/querychat/_utils.py b/pkg-py/src/querychat/_utils.py index 555e8e376..7b620e50a 100644 --- a/pkg-py/src/querychat/_utils.py +++ b/pkg-py/src/querychat/_utils.py @@ -171,14 +171,18 @@ def get_tool_details_setting() -> Optional[Literal["expanded", "collapsed", "def return setting_lower -def querychat_tool_starts_open(action: Literal["update", "query", "reset"]) -> bool: +def querychat_tool_starts_open( + action: Literal[ + "update", "query", "reset", "visualize_query" + ], +) -> bool: """ Determine whether a tool card should be open based on action and setting. Parameters ---------- action : str - The action type ('update', 'query', or 'reset') + The action type ('update', 'query', 'reset', or 'visualize_query') Returns ------- diff --git a/pkg-py/src/querychat/prompts/prompt.md b/pkg-py/src/querychat/prompts/prompt.md index 8c6ff97bc..0ed81bd97 100644 --- a/pkg-py/src/querychat/prompts/prompt.md +++ b/pkg-py/src/querychat/prompts/prompt.md @@ -180,6 +180,12 @@ You might want to explore the advanced features - Never use generic phrases like "If you'd like to..." or "Would you like to explore..." — instead, provide concrete suggestions - Never refer to suggestions as "prompts" – call them "suggestions" or "ideas" or similar +{{#has_tool_visualize_query}} +## Visualization with ggsql + +You can create visualizations using the `visualize_query` tool, which uses ggsql — a SQL extension for declarative data visualization. The tool description contains the full ggsql syntax reference. Always consult it when constructing visualization queries. +{{/has_tool_visualize_query}} + ## Important Guidelines - **Ask for clarification** if any request is unclear or ambiguous diff --git a/pkg-py/src/querychat/prompts/tool-visualize-query.md b/pkg-py/src/querychat/prompts/tool-visualize-query.md new file mode 100644 index 000000000..4bb33d40d --- /dev/null +++ b/pkg-py/src/querychat/prompts/tool-visualize-query.md @@ -0,0 +1,424 @@ +Run an exploratory visualization query inline in the chat. + +## When to Use + +- The user asks an exploratory question that benefits from visualization +- You want to show a one-off chart without affecting the dashboard filter +- You need to visualize data with specific SQL transformations + +## Behavior + +- Executes the SQL query against the data source +- Renders the visualization inline in the chat +- The chart is also accessible via the Query Plot tab +- Does NOT affect the dashboard filter or filtered data +- Each call replaces the previous query visualization +- The `title` parameter is displayed as the card header above the chart — do NOT also put a title in the ggsql query via `LABEL title => ...` as it will be redundant +- Always provide the `title` parameter with a brief, descriptive title for the visualization + +## ggsql Syntax Reference + +### Quick Reference + +```sql +[WITH cte AS (...), ...] +[SELECT columns FROM table WHERE conditions] +VISUALISE [mappings] [FROM source] +DRAW geom_type + [MAPPING col AS aesthetic, ... FROM source] + [REMAPPING stat AS aesthetic, ...] + [SETTING param => value, ...] + [FILTER sql_condition] + [PARTITION BY col, ...] + [ORDER BY col [ASC|DESC], ...] +[SCALE [TYPE] aesthetic [FROM ...] [TO ...] [VIA ...] [SETTING ...] [RENAMING ...]] +[COORD type SETTING property => value, ...] +[FACET var | row_var BY col_var [SETTING free => 'x'|'y'|['x','y'], ncol => N]] +[LABEL x => '...', y => '...', ...] +[THEME name [SETTING property => value, ...]] +``` + +### VISUALISE Clause + +Entry point for visualization. Marks where SQL ends and visualization begins. + +```sql +-- After SELECT (most common) +SELECT date, revenue, region FROM sales +VISUALISE date AS x, revenue AS y, region AS color +DRAW line + +-- Shorthand with FROM (auto-generates SELECT * FROM) +VISUALISE FROM sales +DRAW bar MAPPING region AS x, total AS y +``` + +### Mapping Styles + +| Style | Syntax | Use When | +|-------|--------|----------| +| Explicit | `date AS x` | Column name differs from aesthetic | +| Implicit | `x` | Column name equals aesthetic name | +| Wildcard | `*` | Map all matching columns automatically | +| Literal | `'string' AS color` | Use a literal value (for legend labels in multi-layer plots) | + +### DRAW Clause (Layers) + +Multiple DRAW clauses create layered visualizations. + +```sql +DRAW geom_type + [MAPPING col AS aesthetic, ... FROM source] + [REMAPPING stat AS aesthetic, ...] + [SETTING param => value, ...] + [FILTER sql_condition] + [PARTITION BY col, ...] + [ORDER BY col [ASC|DESC], ...] +``` + +**Geom types:** + +| Category | Types | +|----------|-------| +| Basic | `point`, `line`, `path`, `bar`, `area`, `tile`, `polygon`, `ribbon` | +| Statistical | `histogram`, `density`, `smooth`, `boxplot`, `violin` | +| Annotation | `text`, `label`, `segment`, `arrow`, `hline`, `vline`, `abline`, `errorbar` | + +**Aesthetics (MAPPING):** + +| Category | Aesthetics | +|----------|------------| +| Position | `x`, `y`, `xmin`, `xmax`, `ymin`, `ymax`, `xend`, `yend` | +| Color | `color`/`colour`, `fill`, `stroke`, `opacity` | +| Size/Shape | `size`, `shape`, `linewidth`, `linetype`, `width`, `height` | +| Text | `label`, `family`, `fontface`, `hjust`, `vjust` | +| Aggregation | `weight` (for histogram/bar/density/violin) | + +**Layer-specific data source:** Each layer can use a different data source: + +```sql +WITH summary AS (SELECT region, SUM(sales) as total FROM sales GROUP BY region) +SELECT * FROM sales +VISUALISE date AS x, amount AS y +DRAW line +DRAW bar MAPPING region AS x, total AS y FROM summary +``` + +**PARTITION BY** groups data without visual encoding (useful for separate lines per group without color): + +```sql +DRAW line PARTITION BY category +``` + +**ORDER BY** controls row ordering within a layer: + +```sql +DRAW line ORDER BY date ASC +``` + +### Statistical Layers and REMAPPING + +Some layers compute statistics. Use REMAPPING to access computed values: + +| Layer | Computed Stats | Default Remapping | +|-------|---------------|-------------------| +| `bar` (y unmapped) | `count`, `proportion` | `count AS y` | +| `histogram` | `count`, `density` | `count AS y` | +| `density` | `density`, `intensity` | `density AS y` | +| `violin` | `density`, `intensity` | `density AS offset` | +| `boxplot` | `value`, `type` | `value AS y` | + +`density` computes a KDE from a continuous `x`. Settings: `bandwidth` (numeric), `adjust` (multiplier, default 1), `kernel` (`'gaussian'` default, `'epanechnikov'`, `'triangular'`, `'rectangular'`, `'biweight'`, `'cosine'`), `stacking` (`'off'` default, `'on'`, `'fill'`). Use `REMAPPING intensity AS y` to show unnormalized density that reflects group size differences. + +`violin` displays mirrored KDE curves for groups. Requires both `x` (categorical) and `y` (continuous). Accepts the same bandwidth/adjust/kernel settings as density. Use `REMAPPING intensity AS offset` to reflect group size differences. + +**Examples:** + +```sql +-- Density histogram (instead of count) +VISUALISE FROM products +DRAW histogram MAPPING price AS x REMAPPING density AS y + +-- Bar showing proportion +VISUALISE FROM sales +DRAW bar MAPPING region AS x REMAPPING proportion AS y + +-- Overlay histogram and density on the same scale +VISUALISE FROM measurements +DRAW histogram MAPPING value AS x SETTING opacity => 0.5 +DRAW density MAPPING value AS x REMAPPING intensity AS y SETTING opacity => 0.5 + +-- Violin plot +SELECT department, salary FROM employees +VISUALISE department AS x, salary AS y +DRAW violin +``` + +### SCALE Clause + +Configures how data maps to visual properties. All sub-clauses are optional; type and transform are auto-detected from data when omitted. + +```sql +SCALE [TYPE] aesthetic [FROM range] [TO output] [VIA transform] [SETTING prop => value, ...] [RENAMING ...] +``` + +**Type identifiers** (optional — auto-detected if omitted): + +| Type | Description | +|------|-------------| +| `CONTINUOUS` | Numeric data on a continuous axis | +| `DISCRETE` | Categorical/nominal data | +| `BINNED` | Pre-bucketed data | +| `ORDINAL` | Ordered categories with interpolated output | +| `IDENTITY` | Data values are already visual values (e.g., literal hex colors) | + +**FROM** — input domain: +```sql +SCALE x FROM [0, 100] -- explicit min and max +SCALE x FROM [0, null] -- explicit min, auto max +SCALE DISCRETE x FROM ['A', 'B', 'C'] -- explicit category order +``` + +**TO** — output range or palette: +```sql +SCALE color TO navia -- named palette (default continuous: navia) +SCALE color TO viridis -- other continuous: viridis, plasma, inferno, magma, cividis, batlow +SCALE color TO vik -- diverging: vik, rdbu, rdylbu, spectral, brbg +SCALE DISCRETE color TO ggsql10 -- discrete (default: ggsql10): tableau10, category10, set1, set2, dark2 +SCALE color TO ['red', 'blue'] -- explicit color array +SCALE size TO [1, 10] -- numeric output range +``` + +**VIA** — transformation: +```sql +SCALE x VIA date -- date axis (auto-detected from Date columns) +SCALE x VIA datetime -- datetime axis +SCALE y VIA log10 -- base-10 logarithm +SCALE y VIA sqrt -- square root +``` + +| Category | Transforms | +|----------|------------| +| Logarithmic | `log10`, `log2`, `log` (natural) | +| Power | `sqrt`, `square` | +| Exponential | `exp`, `exp2`, `exp10` | +| Other | `asinh`, `pseudo_log` | +| Temporal | `date`, `datetime`, `time` | +| Type coercion | `integer`, `string`, `bool` | + +**SETTING** — additional properties: +```sql +SCALE x SETTING breaks => 5 -- number of tick marks +SCALE x SETTING breaks => '2 months' -- interval-based breaks +SCALE x SETTING expand => 0.05 -- expand scale range by 5% +SCALE x SETTING reverse => true -- reverse direction +``` + +**RENAMING** — custom axis/legend labels: +```sql +SCALE DISCRETE x RENAMING 'A' => 'Alpha', 'B' => 'Beta' +SCALE CONTINUOUS x RENAMING * => '{} units' -- template for all labels +SCALE x VIA date RENAMING * => '{:time %b %Y}' -- date label formatting +``` + +### Date/Time Axes + +Temporal transforms are auto-detected from column data types. Use `VIA date` explicitly only when the column isn't typed as Date (e.g., after `DATE_TRUNC` which returns timestamps). + +**Break intervals:** +```sql +SCALE x SETTING breaks => 'month' -- one break per month +SCALE x SETTING breaks => '2 weeks' -- every 2 weeks +SCALE x SETTING breaks => '3 months' -- quarterly +SCALE x SETTING breaks => 'year' -- yearly +``` + +Valid units: `day`, `week`, `month`, `year` (for date); also `hour`, `minute`, `second` (for datetime/time). + +**Date label formatting** (strftime syntax): +```sql +SCALE x VIA date RENAMING * => '{:time %b %Y}' -- "Jan 2024" +SCALE x VIA date RENAMING * => '{:time %B %d, %Y}' -- "January 15, 2024" +SCALE x VIA date RENAMING * => '{:time %b %d}' -- "Jan 15" +``` + +### COORD Clause + +Sets coordinate system. Types: `cartesian` (default), `flip`, `polar`, `fixed`, `trans`, `map`, `quickmap`. + +```sql +COORD cartesian SETTING xlim => [0, 100], ylim => [0, 50] +COORD polar -- Pie/radial charts +COORD polar SETTING theta => y +``` + +**WARNING:** `COORD flip` is currently broken and produces empty charts. Avoid using it. + +### FACET Clause + +Creates small multiples (subplots by category). + +```sql +FACET category -- Single variable, wrapped layout +FACET row_var BY col_var -- Grid layout (rows x columns) +FACET category SETTING free => 'y' -- Independent y-axes +FACET category SETTING free => ['x', 'y'] -- Independent both axes +FACET category SETTING ncol => 4 -- Control number of columns +``` + +Custom strip labels via SCALE: +```sql +FACET region +SCALE panel RENAMING 'N' => 'North', 'S' => 'South' +``` + +### LABEL Clause + +Use LABEL for axis labels only. Do NOT use `title =>` — the tool's `title` parameter handles chart titles. + +```sql +LABEL x => 'X Axis Label', y => 'Y Axis Label' +``` + +### THEME Clause + +Available themes: `minimal`, `classic`, `gray`/`grey`, `bw`, `dark`, `light`, `void` + +```sql +THEME minimal +THEME dark +THEME classic SETTING background => '#f5f5f5' +``` + +## Complete Examples + +**Line chart with multiple series:** +```sql +SELECT date, revenue, region FROM sales WHERE year = 2024 +VISUALISE date AS x, revenue AS y, region AS color +DRAW line +SCALE x VIA date +LABEL x => 'Date', y => 'Revenue ($)' +THEME minimal +``` + +**Bar chart (auto-count):** +```sql +VISUALISE FROM products +DRAW bar MAPPING category AS x +``` + +**Scatter plot with trend line:** +```sql +SELECT mpg, hp, cylinders FROM cars +VISUALISE mpg AS x, hp AS y +DRAW point MAPPING cylinders AS color +DRAW smooth +``` + +**Histogram with density overlay:** +```sql +VISUALISE FROM measurements +DRAW histogram MAPPING value AS x SETTING bins => 20, opacity => 0.5 +DRAW density MAPPING value AS x REMAPPING intensity AS y SETTING opacity => 0.5 +``` + +**Density plot with groups:** +```sql +VISUALISE FROM measurements +DRAW density MAPPING value AS x, category AS color SETTING opacity => 0.7 +``` + +**Faceted chart:** +```sql +SELECT month, sales, region FROM data +VISUALISE month AS x, sales AS y +DRAW line +DRAW point +FACET region +SCALE x VIA date +``` + +**CTE with aggregation and date formatting:** +```sql +WITH monthly AS ( + SELECT DATE_TRUNC('month', order_date) as month, SUM(amount) as total + FROM orders GROUP BY 1 +) +VISUALISE month AS x, total AS y FROM monthly +DRAW line +DRAW point +SCALE x VIA date SETTING breaks => 'month' RENAMING * => '{:time %b %Y}' +LABEL y => 'Revenue ($)' +``` + +**Ribbon / confidence band:** +```sql +WITH daily AS ( + SELECT DATE_TRUNC('day', timestamp) as day, + AVG(temperature) as avg_temp, + MIN(temperature) as min_temp, + MAX(temperature) as max_temp + FROM sensor_data + GROUP BY DATE_TRUNC('day', timestamp) +) +VISUALISE day AS x FROM daily +DRAW ribbon MAPPING min_temp AS ymin, max_temp AS ymax SETTING opacity => 0.3 +DRAW line MAPPING avg_temp AS y +SCALE x VIA date +LABEL y => 'Temperature' +``` + +## Important Notes + +1. **Date columns**: Use `SCALE x VIA date` for date/time columns. It's auto-detected from `DATE` columns but needed after `DATE_TRUNC` (which returns `TIMESTAMP`). Customize labels with `RENAMING * => '{:time ...}'` for readable axes. +2. **Multiple layers**: Use multiple DRAW clauses for overlaid visualizations. +3. **Charts vs Tables**: For visualizations use VISUALISE with DRAW. For tabular data use plain SQL without VISUALISE. +4. **CTEs work**: Use `WITH ... SELECT ... VISUALISE` or shorthand `WITH ... VISUALISE FROM cte_name`. +5. **Statistical layers**: When using `histogram`, `bar` (without y), `density`, `violin`, or `boxplot`, the layer computes statistics. Use REMAPPING to access `density`, `intensity`, `proportion`, etc. +6. **Stacked bars via fill**: Map a categorical column to `fill` — there is no `position => 'stack'` setting: + ```sql + DRAW bar MAPPING category AS x, subcategory AS fill + ``` +7. **String values use single quotes**: In SETTING, LABEL, and RENAMING clauses, always use single quotes for string values. Double quotes cause parse errors. +8. **Column casing**: DuckDB lowercases unquoted column names. VISUALISE validates column references case-sensitively. Always alias to lowercase in SELECT: + ```sql + -- WRONG: uppercase column name + SELECT ROOM_TYPE, COUNT(*) AS listings FROM airbnb + VISUALISE ROOM_TYPE AS x, listings AS y + DRAW bar + + -- CORRECT: alias to lowercase + SELECT ROOM_TYPE AS room_type, COUNT(*) AS listings FROM airbnb + VISUALISE room_type AS x, listings AS y + DRAW bar + ``` +9. **COORD flip is broken**: It currently produces empty charts. Avoid using it. +10. **Do not mix `VISUALISE FROM` with a preceding `SELECT`**: `VISUALISE FROM table` is shorthand that auto-generates `SELECT * FROM table`. If you already have a `SELECT`, use `SELECT ... VISUALISE` instead: + ```sql + -- WRONG: VISUALISE FROM after SELECT + SELECT * FROM titanic + VISUALISE FROM titanic + DRAW bar MAPPING class AS x + + -- CORRECT: use VISUALISE (without FROM) after SELECT + SELECT * FROM titanic + VISUALISE class AS x + DRAW bar + + -- ALSO CORRECT: use VISUALISE FROM without any SELECT + VISUALISE FROM titanic + DRAW bar MAPPING class AS x + ``` + +Parameters +---------- +ggsql : + A full ggsql query with SELECT and VISUALISE clauses. The SELECT portion follows standard {{db_type}} SQL syntax. The VISUALISE portion specifies the chart configuration. Do NOT include `LABEL title => ...` in the query — use the `title` parameter instead. +title : + Always provide this. A brief, user-friendly title for this visualization. This is displayed as the card header above the chart. + +Returns +------- +: + The visualization rendered inline in the chat, or the error that occurred. The chart will also be accessible in the Query Plot tab. Does not affect the dashboard filter state. diff --git a/pkg-py/src/querychat/static/css/styles.css b/pkg-py/src/querychat/static/css/styles.css index bd227030a..1a8ba2f17 100644 --- a/pkg-py/src/querychat/static/css/styles.css +++ b/pkg-py/src/querychat/static/css/styles.css @@ -13,3 +13,132 @@ right: 4px; top: 4px; } + +/* Hide Vega's built-in action dropdown (we have our own save button) */ +.querychat-viz-container details:has(> .vega-actions) { + display: none !important; +} + +/* ---- Visualization container ---- */ + +.querychat-viz-container { + aspect-ratio: 4 / 2; + width: 100%; +} + +/* In full-screen mode, let the chart fill the available space */ +.bslib-full-screen-container .querychat-viz-container { + aspect-ratio: unset; +} + +/* ---- Visualization footer ---- */ + +.querychat-footer-buttons { + display: flex; + justify-content: space-between; + align-items: center; +} + +.querychat-footer-left, +.querychat-footer-right { + display: flex; + align-items: center; + gap: 4px; +} + +.querychat-show-query-btn, +.querychat-save-btn { + display: inline-flex; + align-items: center; + gap: 4px; + padding: 2px 8px; + height: 28px; + border: none; + border-radius: var(--bs-border-radius, 4px); + background: transparent; + color: var(--bs-secondary-color, #6c757d); + font-size: 0.75rem; + cursor: pointer; + white-space: nowrap; +} + +.querychat-show-query-btn:hover, +.querychat-save-btn:hover { + color: var(--bs-body-color, #212529); + background-color: rgba(var(--bs-emphasis-color-rgb, 0, 0, 0), 0.05); +} + +.querychat-query-chevron { + font-size: 0.625rem; + transition: transform 150ms; + display: inline-block; +} + +.querychat-query-chevron--expanded { + transform: rotate(90deg); +} + +.querychat-icon { + width: 14px; + height: 14px; +} + +.querychat-dropdown-chevron { + width: 12px; + height: 12px; + margin-left: 2px; +} + +.querychat-save-dropdown { + position: relative; +} + +.querychat-save-menu { + display: none; + position: absolute; + right: 0; + bottom: 100%; + margin-bottom: 4px; + z-index: 20; + background: var(--bs-body-bg, #fff); + border: 1px solid var(--bs-border-color, #dee2e6); + border-radius: var(--bs-border-radius, 4px); + box-shadow: 0 2px 8px rgba(0, 0, 0, 0.15); + padding: 4px 0; + min-width: 120px; +} + +.querychat-save-menu--visible { + display: block; +} + +.querychat-save-menu button { + display: block; + width: 100%; + padding: 6px 12px; + border: none; + background: transparent; + color: var(--bs-body-color, #212529); + font-size: 0.75rem; + text-align: left; + cursor: pointer; +} + +.querychat-save-menu button:hover { + background-color: rgba(var(--bs-emphasis-color-rgb, 0, 0, 0), 0.05); +} + +.querychat-query-section { + display: none; + position: relative; + border-top: 1px solid var(--bs-border-color, #dee2e6); +} + +.querychat-query-section--visible { + display: block; +} + +.querychat-query-section bslib-code-editor { + border-radius: 0; +} + diff --git a/pkg-py/src/querychat/static/js/ggsql-grammar.js b/pkg-py/src/querychat/static/js/ggsql-grammar.js new file mode 100644 index 000000000..fbb36d1e3 --- /dev/null +++ b/pkg-py/src/querychat/static/js/ggsql-grammar.js @@ -0,0 +1,146 @@ +// ggsql-grammar.js - Extends SQL grammar for Prism Code Editor with ggsql tokens. +// +// Grammar derived from the TextMate grammar at posit-dev/ggsql: +// https://github.com/posit-dev/ggsql/blob/main/ggsql-vscode/syntaxes/ggsql.tmLanguage.json +// +// IMPORTANT: This dynamically imports from a content-hashed internal module of +// prism-code-editor. The filename (index-C1_GGQ8y.js) may change when bslib +// updates prism-code-editor. If syntax highlighting stops working after a bslib +// upgrade, check: +// {shiny_package}/www/shared/prism-code-editor/prism/languages/sql.js +// and copy the hashed filename from the import on the first line. + +// Extend a SQL grammar object in-place with ggsql tokens, then reorder keys +// so ggsql tokens are checked before generic SQL tokens. +function extendWithGgsql(sqlGrammar) { + // ggsql clause keywords — alias "keyword" so the theme styles them + sqlGrammar["ggsql-keyword"] = { + pattern: + /\b(?:VISUALISE|VISUALIZE|DRAW|MAPPING|REMAPPING|SETTING|FILTER|PARTITION|SCALE|FACET|PROJECT|LABEL|THEME|RENAMING|VIA|TO)\b/i, + alias: "keyword", + }; + + // Geom types (after DRAW) + sqlGrammar["ggsql-geom"] = { + pattern: + /\b(?:point|line|path|bar|col|area|tile|polygon|ribbon|histogram|density|smooth|boxplot|violin|text|label|segment|arrow|hline|vline|abline|errorbar)\b/, + alias: "builtin", + }; + + // Scale type modifiers + sqlGrammar["ggsql-scale-type"] = { + pattern: /\b(?:CONTINUOUS|DISCRETE|BINNED|ORDINAL|IDENTITY)\b/i, + alias: "builtin", + }; + + // Aesthetic names + sqlGrammar["ggsql-aesthetic"] = { + pattern: + /\b(?:x|y|xmin|xmax|ymin|ymax|xend|yend|weight|color|colour|fill|stroke|opacity|size|shape|linetype|linewidth|width|height|family|fontface|hjust|vjust|panel|row|column|theta|radius|thetamin|thetamax|radiusmin|radiusmax|thetaend|radiusend|offset)\b/, + alias: "attr-name", + }; + + // Theme names + sqlGrammar["ggsql-theme"] = { + pattern: /\b(?:minimal|classic|gray|grey|bw|dark|light|void)\b/, + alias: "class-name", + }; + + // Project types + sqlGrammar["ggsql-project"] = { + pattern: /\b(?:cartesian|polar|flip|fixed|trans|map|quickmap)\b/, + alias: "class-name", + }; + + // Fat arrow operator (SETTING/LABEL/RENAMING clauses) + sqlGrammar["ggsql-arrow"] = { + pattern: /=>/, + alias: "operator", + }; + + // Broader SQL function coverage: aggregate, window, datetime, string, math, + // conversion, conditional, JSON, list (from TextMate grammar sql-functions) + sqlGrammar["function"] = + /\b(?:count|sum|avg|min|max|stddev|variance|array_agg|string_agg|group_concat|row_number|rank|dense_rank|ntile|lag|lead|first_value|last_value|nth_value|cume_dist|percent_rank|date_trunc|date_part|datepart|datename|dateadd|datediff|extract|now|current_date|current_time|current_timestamp|getdate|getutcdate|strftime|strptime|make_date|make_time|make_timestamp|concat|substring|substr|left|right|length|len|char_length|lower|upper|trim|ltrim|rtrim|replace|reverse|repeat|lpad|rpad|split_part|string_split|format|printf|regexp_replace|regexp_extract|regexp_matches|abs|ceil|ceiling|floor|round|trunc|truncate|mod|power|sqrt|exp|ln|log|log10|log2|sign|sin|cos|tan|asin|acos|atan|atan2|pi|degrees|radians|random|rand|cast|convert|coalesce|nullif|ifnull|isnull|nvl|try_cast|typeof|if|iff|iif|greatest|least|decode|json|json_extract|json_extract_path|json_extract_string|json_value|json_query|json_object|json_array|json_array_length|to_json|from_json|list|list_value|list_aggregate|array_length|unnest|generate_series|range|first|last)(?=\s*\()/i; + + // Reorder: Prism checks tokens in object key order. + // ggsql tokens must come before generic SQL keyword/boolean to win. + var ggsqlKeys = [ + "ggsql-keyword", + "ggsql-geom", + "ggsql-scale-type", + "ggsql-aesthetic", + "ggsql-theme", + "ggsql-project", + "ggsql-arrow", + ]; + var ordered = {}; + + // 1. Greedy/high-priority tokens first + ["comment", "string", "identifier", "variable"].forEach(function (key) { + if (key in sqlGrammar) ordered[key] = sqlGrammar[key]; + }); + + // 2. ggsql-specific tokens + ggsqlKeys.forEach(function (key) { + if (key in sqlGrammar) ordered[key] = sqlGrammar[key]; + }); + + // 3. Remaining SQL tokens + Object.keys(sqlGrammar).forEach(function (key) { + if (!(key in ordered)) ordered[key] = sqlGrammar[key]; + }); + + // Update in-place to preserve the object identity Prism holds. + Object.keys(sqlGrammar).forEach(function (key) { + delete sqlGrammar[key]; + }); + Object.assign(sqlGrammar, ordered); +} + +(async () => { + // Locate the prism-code-editor base URL from its script tag. + var scriptEl = document.querySelector( + 'script[src*="prism-code-editor"][src$="index.js"]' + ); + if (!scriptEl) return; + + var baseUrl = scriptEl.src.replace(/\/index\.js$/, ""); + + // Import the GRAMMAR registry (not the languageMap from index.js). + // The grammar registry is only exported from the internal hashed module. + var grammarModule = await import(baseUrl + "/index-C1_GGQ8y.js"); + var languages = grammarModule.l; + if (!languages) return; + + // If the SQL grammar is already loaded, extend it directly. + if (languages.sql) { + extendWithGgsql(languages.sql); + return; + } + + // Otherwise, intercept the assignment from sql.js so we extend the grammar + // BEFORE the code editor reads it for tokenization. Using a property setter + // ensures zero timing gap — the grammar is extended the instant sql.js + // assigns it. + var _sql; + Object.defineProperty(languages, "sql", { + set: function (grammar) { + // Store the grammar, then extend it with ggsql tokens. + _sql = grammar; + extendWithGgsql(_sql); + // Replace the setter with a plain property now that we've intercepted. + Object.defineProperty(languages, "sql", { + value: _sql, + writable: true, + configurable: true, + enumerable: true, + }); + }, + get: function () { + return _sql; + }, + configurable: true, + enumerable: true, + }); +})(); diff --git a/pkg-py/src/querychat/static/js/querychat.js b/pkg-py/src/querychat/static/js/querychat.js index 18d6b4f45..75fcfe447 100644 --- a/pkg-py/src/querychat/static/js/querychat.js +++ b/pkg-py/src/querychat/static/js/querychat.js @@ -1,3 +1,62 @@ +// Helper: get the real click target, even inside Shadow DOM. +// event.target is retargeted to the shadow host when the click originates +// inside a shadow tree, so .closest() fails. composedPath() gives the +// full path including shadow-internal elements. +function deepTarget(event) { + return event.composedPath()[0] || event.target; +} + +// Helper: find a widget container by its base ID. +// Shiny module namespacing may prefix the ID (e.g. "mod-querychat_viz_abc"), +// so we match elements whose ID ends with the base widget ID. +function findWidgetContainer(widgetId) { + return document.getElementById(widgetId) + || document.querySelector('[id$="' + CSS.escape(widgetId) + '"]'); +} + +// Helper: get the SVG element from a widget container. +// Works with both vega-embed (via __view__) and shinywidgets (direct SVG). +function getChartSvg(container) { + var vegaEmbed = container.querySelector(".vega-embed"); + if (!vegaEmbed) return null; + return vegaEmbed.querySelector("svg"); +} + +// Helper: serialize an SVG element to a standalone SVG string. +function serializeSvg(svgEl) { + var clone = svgEl.cloneNode(true); + if (!clone.getAttribute("xmlns")) { + clone.setAttribute("xmlns", "http://www.w3.org/2000/svg"); + } + return new XMLSerializer().serializeToString(clone); +} + +// Helper: trigger a file download from a Blob. +function downloadBlob(blob, filename) { + var url = URL.createObjectURL(blob); + var link = document.createElement("a"); + link.download = filename; + link.href = url; + link.click(); + URL.revokeObjectURL(url); +} + +// Helper: close all visible save menus, including those inside Shadow DOM. +function closeAllSaveMenus() { + // Light DOM + document.querySelectorAll(".querychat-save-menu--visible").forEach(function (menu) { + menu.classList.remove("querychat-save-menu--visible"); + }); + // Shadow DOM (shinychat tool result cards) + document.querySelectorAll("shiny-tool-result").forEach(function (el) { + var root = el.shadowRoot; + if (!root) return; + root.querySelectorAll(".querychat-save-menu--visible").forEach(function (menu) { + menu.classList.remove("querychat-save-menu--visible"); + }); + }); +} + (function () { if (!window.Shiny) return; @@ -17,4 +76,97 @@ { priority: "event" } ); }); -})(); \ No newline at end of file +})(); + +// Show/Hide Query toggle +window.addEventListener("click", function (event) { + var btn = deepTarget(event).closest(".querychat-show-query-btn"); + if (!btn) return; + event.stopPropagation(); + var targetId = btn.dataset.target; + // Section may be inside the same shadow root as the button + var root = btn.getRootNode(); + var section = root.getElementById + ? root.getElementById(targetId) + : document.getElementById(targetId); + if (!section) return; + var isVisible = section.classList.toggle("querychat-query-section--visible"); + var label = btn.querySelector(".querychat-query-label"); + var chevron = btn.querySelector(".querychat-query-chevron"); + if (label) label.textContent = isVisible ? "Hide Query" : "Show Query"; + if (chevron) chevron.classList.toggle("querychat-query-chevron--expanded", isVisible); +}); + +// Save dropdown toggle + close on outside click +window.addEventListener("click", function (event) { + var btn = deepTarget(event).closest(".querychat-save-btn"); + if (btn) { + event.stopPropagation(); + var menu = btn.parentElement.querySelector(".querychat-save-menu"); + if (menu) menu.classList.toggle("querychat-save-menu--visible"); + return; + } + closeAllSaveMenus(); +}); + +// Save as PNG: render the chart SVG onto a canvas and export +window.addEventListener("click", function (event) { + var btn = deepTarget(event).closest(".querychat-save-png-btn"); + if (!btn) return; + event.stopPropagation(); + var widgetId = btn.dataset.widgetId; + var title = btn.dataset.title || "chart"; + var menu = btn.closest(".querychat-save-menu"); + if (menu) menu.classList.remove("querychat-save-menu--visible"); + + var container = findWidgetContainer(widgetId); + if (!container) return; + var svgEl = getChartSvg(container); + if (!svgEl) return; + + var svgStr = serializeSvg(svgEl); + var svgBlob = new Blob([svgStr], { type: "image/svg+xml;charset=utf-8" }); + var url = URL.createObjectURL(svgBlob); + var img = new Image(); + var scale = 2; + + img.onload = function () { + var canvas = document.createElement("canvas"); + canvas.width = img.width * scale; + canvas.height = img.height * scale; + var ctx = canvas.getContext("2d"); + ctx.scale(scale, scale); + ctx.drawImage(img, 0, 0); + URL.revokeObjectURL(url); + + canvas.toBlob(function (blob) { + if (blob) downloadBlob(blob, title + ".png"); + }, "image/png"); + }; + img.onerror = function () { + console.error("Failed to save chart as PNG: SVG image load failed"); + URL.revokeObjectURL(url); + }; + img.src = url; +}); + +// Save as SVG: extract the SVG directly from the DOM +window.addEventListener("click", function (event) { + var btn = deepTarget(event).closest(".querychat-save-svg-btn"); + if (!btn) return; + event.stopPropagation(); + var widgetId = btn.dataset.widgetId; + var title = btn.dataset.title || "chart"; + var menu = btn.closest(".querychat-save-menu"); + if (menu) menu.classList.remove("querychat-save-menu--visible"); + + var container = findWidgetContainer(widgetId); + if (!container) return; + var svgEl = getChartSvg(container); + if (!svgEl) return; + + var svgStr = serializeSvg(svgEl); + var blob = new Blob([svgStr], { type: "image/svg+xml" }); + downloadBlob(blob, title + ".svg"); +}); + diff --git a/pkg-py/src/querychat/tools.py b/pkg-py/src/querychat/tools.py index 67ea453f5..800d173c8 100644 --- a/pkg-py/src/querychat/tools.py +++ b/pkg-py/src/querychat/tools.py @@ -2,6 +2,7 @@ from pathlib import Path from typing import TYPE_CHECKING, Any, Protocol, TypedDict, runtime_checkable +from uuid import uuid4 import chevron from chatlas import ContentToolResult, Tool @@ -13,6 +14,8 @@ if TYPE_CHECKING: from collections.abc import Callable + from htmltools import TagList + from ._datasource import DataSource @@ -69,6 +72,27 @@ def log_update(data: UpdateDashboardData): title: str +class VisualizeQueryData(TypedDict): + """ + Data passed to visualize_query callback. + + This TypedDict defines the structure of data passed to the + `tool_visualize_query` callback function when the LLM creates an + exploratory visualization from a ggsql query. + + Attributes + ---------- + ggsql + The full ggsql query string (SQL + VISUALISE). + title + A descriptive title for the visualization, or None if not provided. + + """ + + ggsql: str + title: str | None + + def _read_prompt_template(filename: str, **kwargs) -> str: """Read and interpolate a prompt template file.""" template_path = Path(__file__).parent / "prompts" / filename @@ -286,3 +310,260 @@ def tool_query(data_source: DataSource) -> Tool: name="querychat_query", annotations={"title": "Query Data"}, ) + + +def _build_viz_footer( + ggsql_str: str, + title: str | None, + widget_id: str, +) -> TagList: + """Build footer HTML for visualization tool results.""" + from htmltools import HTMLDependency, Tag, TagList, tags + + from shiny import ui + + footer_id = f"querychat_footer_{uuid4().hex[:8]}" + query_section_id = f"{footer_id}_query" + code_editor_id = f"{footer_id}_code" + + # ggsql grammar dependency (extends SQL grammar at runtime) + ggsql_grammar_dep = HTMLDependency( + name="querychat-ggsql-grammar", + version="0.1.0", + source={"package": "querychat", "subdir": "static/js"}, + script={"src": "ggsql-grammar.js", "type": "module"}, + ) + + # Read-only code editor for query display + code_editor = ui.input_code_editor( + id=code_editor_id, + value=ggsql_str, + language="sql", + read_only=True, + line_numbers=False, + height="auto", + theme_dark="github-dark", + ) + + # Query section (hidden by default) + query_section = tags.div( + {"class": "querychat-query-section", "id": query_section_id}, + code_editor, + ) + + # Footer buttons row + buttons_row = tags.div( + {"class": "querychat-footer-buttons"}, + # Left: Show Query toggle + tags.div( + {"class": "querychat-footer-left"}, + tags.button( + { + "class": "querychat-show-query-btn", + "data-target": query_section_id, + }, + tags.span({"class": "querychat-query-chevron"}, "\u25b6"), + tags.span({"class": "querychat-query-label"}, "Show Query"), + ), + ), + # Right: Save dropdown + tags.div( + {"class": "querychat-footer-right"}, + tags.div( + {"class": "querychat-save-dropdown"}, + tags.button( + { + "class": "querychat-save-btn", + "data-widget-id": widget_id, + }, + tags.svg( + { + "class": "querychat-icon", + "viewBox": "0 0 20 20", + "fill": "currentColor", + "xmlns": "http://www.w3.org/2000/svg", + }, + Tag( + "path", + d="M10.75 2.75a.75.75 0 00-1.5 0v8.614L6.295 8.235a.75.75 0 10-1.09 1.03l4.25 4.5a.75.75 0 001.09 0l4.25-4.5a.75.75 0 00-1.09-1.03l-2.955 3.129V2.75z", + ), + Tag( + "path", + d="M3.5 12.75a.75.75 0 00-1.5 0v2.5A2.75 2.75 0 004.75 18h10.5A2.75 2.75 0 0018 15.25v-2.5a.75.75 0 00-1.5 0v2.5c0 .69-.56 1.25-1.25 1.25H4.75c-.69 0-1.25-.56-1.25-1.25v-2.5z", + ), + ), + "Save", + tags.svg( + { + "class": "querychat-dropdown-chevron", + "viewBox": "0 0 20 20", + "fill": "currentColor", + "xmlns": "http://www.w3.org/2000/svg", + }, + Tag( + "path", + clip_rule="evenodd", + fill_rule="evenodd", + d="M5.22 8.22a.75.75 0 0 1 1.06 0L10 11.94l3.72-3.72a.75.75 0 1 1 1.06 1.06l-4.25 4.25a.75.75 0 0 1-1.06 0L5.22 9.28a.75.75 0 0 1 0-1.06Z", + ), + ), + ), + tags.div( + {"class": "querychat-save-menu"}, + tags.button( + { + "class": "querychat-save-png-btn", + "data-widget-id": widget_id, + "data-title": title or "chart", + }, + "Save as PNG", + ), + tags.button( + { + "class": "querychat-save-svg-btn", + "data-widget-id": widget_id, + "data-title": title or "chart", + }, + "Save as SVG", + ), + ), + ), + ), + ) + + return TagList(ggsql_grammar_dep, buttons_row, query_section) + + +class VisualizeQueryResult(ContentToolResult): + """Tool result that embeds an Altair chart inline via shinywidgets.""" + + def __init__( + self, + chart: Any, + ggsql_str: str, + title: str | None, + row_count: int, + col_count: int, + **kwargs: Any, + ): + from shinywidgets import output_widget, register_widget + + widget_id = f"querychat_viz_{uuid4().hex[:8]}" + register_widget(widget_id, chart) + + title_display = f" - {title}" if title else "" + markdown = f"```sql\n{ggsql_str}\n```" + markdown += f"\n\nVisualization created{title_display}." + markdown += f"\n\nData: {row_count} rows, {col_count} columns." + + footer = _build_viz_footer(ggsql_str, title, widget_id) + + widget_html = output_widget(widget_id, fill=True, fillable=True) + widget_html.add_class("querychat-viz-container") + + extra = { + "display": ToolResultDisplay( + html=widget_html, + title=title or "Query Visualization", + show_request=False, + open=True, + full_screen=True, + icon=bs_icon("graph-up"), + footer=footer, + ), + } + + super().__init__(value=markdown, extra=extra, **kwargs) + + +def _visualize_query_impl( + data_source: DataSource, + update_fn: Callable[[VisualizeQueryData], None], +) -> Callable[[str, str | None], ContentToolResult]: + """Create the visualize_query implementation function.""" + import ggsql as ggsql_pkg + + from ._ggsql import execute_ggsql, extract_title, spec_to_altair + + def visualize_query( + ggsql: str, + title: str | None = None, + ) -> ContentToolResult: + """Execute a ggsql query and render the visualization.""" + markdown = f"```sql\n{ggsql}\n```" + + try: + # Validate and split the query + validated = ggsql_pkg.validate(ggsql) + if not validated.has_visual(): + raise ValueError( + "Query must include a VISUALISE clause. " + "Use querychat_query for queries without visualization." + ) + + # Execute the SQL and render the visualization + spec = execute_ggsql(data_source, ggsql) + chart = spec_to_altair(spec) + + if title is None: + title = extract_title(spec) + metadata = spec.metadata() + row_count = metadata["rows"] + col_count = len(metadata["columns"]) + + update_fn( + { + "ggsql": ggsql, + "title": title, + } + ) + + chart = chart.properties(width="container", height="container") + + return VisualizeQueryResult( + chart=chart, + ggsql_str=ggsql, + title=title, + row_count=row_count, + col_count=col_count, + ) + + except Exception as e: + error_msg = str(e) + markdown += f"\n\n> Error: {error_msg}" + return ContentToolResult(value=markdown, error=e) + + return visualize_query + + +def tool_visualize_query( + data_source: DataSource, + update_fn: Callable[[VisualizeQueryData], None], +) -> Tool: + """ + Create a tool that executes a ggsql query and renders the visualization. + + Parameters + ---------- + data_source + The data source to query against + update_fn + Callback function to call with VisualizeQueryData when visualization succeeds + + Returns + ------- + Tool + A tool that can be registered with chatlas + + """ + impl = _visualize_query_impl(data_source, update_fn) + impl.__doc__ = _read_prompt_template( + "tool-visualize-query.md", + db_type=data_source.get_db_type(), + ) + + return Tool.from_func( + impl, + name="querychat_visualize_query", + annotations={"title": "Query Visualization"}, + ) diff --git a/pkg-py/src/querychat/types/__init__.py b/pkg-py/src/querychat/types/__init__.py index f9a8163df..002d9ba85 100644 --- a/pkg-py/src/querychat/types/__init__.py +++ b/pkg-py/src/querychat/types/__init__.py @@ -9,7 +9,7 @@ from .._querychat_core import AppStateDict from .._shiny_module import ServerValues from .._utils import UnsafeQueryError -from ..tools import UpdateDashboardData +from ..tools import UpdateDashboardData, VisualizeQueryData __all__ = ( "AppStateDict", @@ -22,4 +22,5 @@ "ServerValues", "UnsafeQueryError", "UpdateDashboardData", + "VisualizeQueryData", ) diff --git a/pkg-py/tests/conftest.py b/pkg-py/tests/conftest.py new file mode 100644 index 000000000..005195e77 --- /dev/null +++ b/pkg-py/tests/conftest.py @@ -0,0 +1,23 @@ +"""Shared pytest fixtures for querychat unit tests.""" + +import polars as pl +import pytest + + +def _ggsql_render_works() -> bool: + """Check if ggsql.render_altair() is functional (build can be broken in some envs).""" + try: + import ggsql + + df = pl.DataFrame({"x": [1, 2], "y": [3, 4]}) + result = ggsql.render_altair(df, "VISUALISE x, y DRAW point") + spec = result.to_dict() + return "$schema" in spec + except (ValueError, ImportError): + return False + + +ggsql_render_works = pytest.mark.skipif( + not _ggsql_render_works(), + reason="ggsql.render_altair() not functional (build environment issue)", +) diff --git a/pkg-py/tests/playwright/conftest.py b/pkg-py/tests/playwright/conftest.py index 6febfd4e8..961af01f3 100644 --- a/pkg-py/tests/playwright/conftest.py +++ b/pkg-py/tests/playwright/conftest.py @@ -592,3 +592,31 @@ def dash_cleanup(_thread, server): yield url finally: _stop_dash_server(server) + + +@pytest.fixture(scope="module") +def app_10_viz() -> Generator[str, None, None]: + """Start the 10-viz-app.py Shiny server for testing.""" + app_path = str(EXAMPLES_DIR / "10-viz-app.py") + + def start_factory(): + port = _find_free_port() + url = f"http://localhost:{port}" + return url, lambda: _start_shiny_app_threaded(app_path, port) + + def shiny_cleanup(_thread, server): + _stop_shiny_server(server) + + url, _thread, server = _start_server_with_retry( + start_factory, shiny_cleanup, timeout=30.0 + ) + try: + yield url + finally: + _stop_shiny_server(server) + + +@pytest.fixture +def chat_10_viz(page: Page) -> ChatControllerType: + """Create a ChatController for the 10-viz-app chat component.""" + return _create_chat_controller(page, "titanic") diff --git a/pkg-py/tests/playwright/test_10_viz_inline.py b/pkg-py/tests/playwright/test_10_viz_inline.py new file mode 100644 index 000000000..818a5d963 --- /dev/null +++ b/pkg-py/tests/playwright/test_10_viz_inline.py @@ -0,0 +1,123 @@ +""" +Playwright tests for inline visualization and fullscreen behavior. + +These tests verify that: +1. The visualize_query tool renders Altair charts inline in tool result cards +2. The fullscreen toggle button appears on visualization tool results +3. Fullscreen mode works (expand and collapse via button and Escape key) +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import pytest +from playwright.sync_api import expect + +if TYPE_CHECKING: + from playwright.sync_api import Page + from shinychat.playwright import ChatController + + +class TestInlineVisualization: + """Tests for inline chart rendering in tool result cards.""" + + @pytest.fixture(autouse=True) + def setup( + self, page: Page, app_10_viz: str, chat_10_viz: ChatController + ) -> None: + """Navigate to the viz app before each test.""" + page.goto(app_10_viz) + page.wait_for_selector("table", timeout=30000) + self.page = page + self.chat = chat_10_viz + + def test_app_loads_with_query_plot_tab(self) -> None: + """VIZ-INIT: App with visualize_query has a Query Plot tab.""" + expect(self.page.get_by_role("tab", name="Query Plot")).to_be_visible() + + def test_viz_tool_renders_inline_chart(self) -> None: + """VIZ-INLINE: Visualization tool result contains an inline chart widget.""" + self.chat.set_user_input( + "Create a scatter plot of age vs fare for the titanic passengers" + ) + self.chat.send_user_input(method="click") + + # Wait for a tool result card with full-screen attribute (viz results have it) + tool_card = self.page.locator("shiny-tool-result[full-screen]") + expect(tool_card).to_be_visible(timeout=90000) + + # The card should contain a widget output (Altair chart) + widget_output = tool_card.locator(".jupyter-widgets") + expect(widget_output).to_be_visible(timeout=10000) + + def test_fullscreen_button_visible_on_viz_card(self) -> None: + """VIZ-FS-BTN: Fullscreen toggle button appears on visualization cards.""" + self.chat.set_user_input( + "Make a bar chart showing count of passengers by class" + ) + self.chat.send_user_input(method="click") + + # Wait for viz tool result + tool_card = self.page.locator("shiny-tool-result[full-screen]") + expect(tool_card).to_be_visible(timeout=90000) + + # Fullscreen toggle should be visible + fs_button = tool_card.locator(".tool-fullscreen-toggle") + expect(fs_button).to_be_visible() + + def test_fullscreen_toggle_expands_card(self) -> None: + """VIZ-FS-EXPAND: Clicking fullscreen button expands the card.""" + self.chat.set_user_input( + "Plot a histogram of passenger ages from the titanic data" + ) + self.chat.send_user_input(method="click") + + # Wait for viz tool result + tool_result = self.page.locator("shiny-tool-result[full-screen]") + expect(tool_result).to_be_visible(timeout=90000) + + # Click fullscreen toggle + fs_button = tool_result.locator(".tool-fullscreen-toggle") + fs_button.click() + + # The .shiny-tool-card inside should now have fullscreen attribute + card = tool_result.locator(".shiny-tool-card[fullscreen]") + expect(card).to_be_visible() + + def test_escape_closes_fullscreen(self) -> None: + """VIZ-FS-ESC: Pressing Escape closes fullscreen mode.""" + self.chat.set_user_input( + "Create a visualization of survival rate by passenger class" + ) + self.chat.send_user_input(method="click") + + # Wait for viz tool result + tool_result = self.page.locator("shiny-tool-result[full-screen]") + expect(tool_result).to_be_visible(timeout=90000) + + # Enter fullscreen + fs_button = tool_result.locator(".tool-fullscreen-toggle") + fs_button.click() + + card = tool_result.locator(".shiny-tool-card[fullscreen]") + expect(card).to_be_visible() + + # Press Escape + self.page.keyboard.press("Escape") + + # Fullscreen should be removed + expect(card).not_to_be_visible() + + def test_non_viz_tool_results_have_no_fullscreen(self) -> None: + """VIZ-NO-FS: Non-visualization tool results don't have fullscreen.""" + self.chat.set_user_input("Show me passengers who survived") + self.chat.send_user_input(method="click") + + # Wait for a tool result (any) + tool_result = self.page.locator("shiny-tool-result").first + expect(tool_result).to_be_visible(timeout=90000) + + # Non-viz tool results should NOT have full-screen attribute + fs_results = self.page.locator("shiny-tool-result[full-screen]") + expect(fs_results).to_have_count(0) diff --git a/pkg-py/tests/playwright/test_visualization_tabs.py b/pkg-py/tests/playwright/test_visualization_tabs.py new file mode 100644 index 000000000..48c8ba7b6 --- /dev/null +++ b/pkg-py/tests/playwright/test_visualization_tabs.py @@ -0,0 +1,38 @@ +""" +Playwright tests for visualization tab behavior based on tools config. + +These tests verify that the Query Plot tab is only present when the +visualize_query tool is enabled. With default tools ("update", "query"), +only the Data tab should appear. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import pytest +from playwright.sync_api import expect + +if TYPE_CHECKING: + from playwright.sync_api import Page + + +# Shiny Tests +class TestShinyVisualizationTabs: + """Tests for tab behavior in Shiny app with default tools (no viz).""" + + @pytest.fixture(autouse=True) + def setup(self, page: Page, app_01_hello: str) -> None: + page.goto(app_01_hello) + page.wait_for_selector("table", timeout=30000) + self.page = page + + def test_only_data_tab_present_without_viz_tools(self) -> None: + """With default tools, only the Data tab should be visible.""" + tabs = self.page.locator('[role="tab"]') + expect(tabs).to_have_count(1) + expect(self.page.get_by_role("tab", name="Data")).to_be_visible() + + def test_no_query_plot_tab(self) -> None: + """Query Plot tab should not exist without visualize_query tool.""" + expect(self.page.get_by_role("tab", name="Query Plot")).to_have_count(0) diff --git a/pkg-py/tests/test_ggsql.py b/pkg-py/tests/test_ggsql.py new file mode 100644 index 000000000..2733ed450 --- /dev/null +++ b/pkg-py/tests/test_ggsql.py @@ -0,0 +1,116 @@ +"""Tests for ggsql integration helpers.""" + +import ggsql +import narwhals.stable.v1 as nw +import polars as pl +from conftest import ggsql_render_works +from querychat._datasource import DataFrameSource +from querychat._ggsql import execute_ggsql, extract_title, spec_to_altair + + +class TestGgsqlValidate: + """Tests for ggsql.validate() usage (split SQL and VISUALISE).""" + + def test_splits_query_with_visualise(self): + query = "SELECT x, y FROM data VISUALISE x, y DRAW point" + validated = ggsql.validate(query) + assert validated.sql() == "SELECT x, y FROM data" + assert validated.visual() == "VISUALISE x, y DRAW point" + assert validated.has_visual() + + def test_returns_empty_viz_without_visualise(self): + query = "SELECT x, y FROM data" + validated = ggsql.validate(query) + assert validated.sql() == "SELECT x, y FROM data" + assert validated.visual() == "" + assert not validated.has_visual() + + def test_handles_complex_query(self): + query = """ + SELECT date, SUM(revenue) as total + FROM sales + GROUP BY date + VISUALISE date AS x, total AS y + DRAW line + LABEL title => 'Revenue Over Time' + """ + validated = ggsql.validate(query) + assert "SELECT date, SUM(revenue)" in validated.sql() + assert "GROUP BY date" in validated.sql() + assert "VISUALISE date AS x" in validated.visual() + assert "LABEL title" in validated.visual() + + +class TestExtractTitle: + @ggsql_render_works + def test_extracts_title_from_spec(self): + nw_df = nw.from_native(pl.DataFrame({"x": [1, 2], "y": [3, 4]})) + ds = DataFrameSource(nw_df, "data") + spec = execute_ggsql( + ds, "SELECT * FROM data VISUALISE x, y DRAW point LABEL title => 'My Chart'" + ) + assert extract_title(spec) == "My Chart" + + @ggsql_render_works + def test_returns_none_without_title(self): + nw_df = nw.from_native(pl.DataFrame({"x": [1, 2], "y": [3, 4]})) + ds = DataFrameSource(nw_df, "data") + spec = execute_ggsql(ds, "SELECT * FROM data VISUALISE x, y DRAW point") + assert extract_title(spec) is None + + +class TestSpecToAltair: + @ggsql_render_works + def test_produces_altair_chart(self): + import altair as alt + import ggsql + + reader = ggsql.DuckDBReader("duckdb://memory") + df = pl.DataFrame({"x": [1, 2, 3], "y": [4, 5, 6]}) + reader.register("data", df) + spec = reader.execute("SELECT * FROM data VISUALISE x, y DRAW point") + chart = spec_to_altair(spec) + assert isinstance(chart, (alt.Chart, alt.LayerChart, alt.FacetChart)) + result = chart.to_dict() + assert "$schema" in result + assert "vega-lite" in result["$schema"] + + +class TestExecuteGgsql: + @ggsql_render_works + def test_full_pipeline(self): + nw_df = nw.from_native(pl.DataFrame({"x": [1, 2, 3], "y": [4, 5, 6]})) + ds = DataFrameSource(nw_df, "test_data") + spec = execute_ggsql(ds, "SELECT * FROM test_data VISUALISE x, y DRAW point") + chart = spec_to_altair(spec) + result = chart.to_dict() + assert "$schema" in result + + @ggsql_render_works + def test_with_filtered_query(self): + nw_df = nw.from_native( + pl.DataFrame({"x": [1, 2, 3, 4, 5], "y": [10, 20, 30, 40, 50]}) + ) + ds = DataFrameSource(nw_df, "test_data") + spec = execute_ggsql( + ds, "SELECT * FROM test_data WHERE x > 2 VISUALISE x, y DRAW point" + ) + assert spec.metadata()["rows"] == 3 + + @ggsql_render_works + def test_spec_has_visual(self): + nw_df = nw.from_native(pl.DataFrame({"x": [1, 2], "y": [3, 4]})) + ds = DataFrameSource(nw_df, "test_data") + spec = execute_ggsql(ds, "SELECT * FROM test_data VISUALISE x, y DRAW point") + assert "VISUALISE" in spec.visual() + + @ggsql_render_works + def test_with_pandas_dataframe(self): + import pandas as pd + + nw_df = nw.from_native(pd.DataFrame({"x": [1, 2, 3], "y": [4, 5, 6]})) + ds = DataFrameSource(nw_df, "test_data") + spec = execute_ggsql(ds, "SELECT * FROM test_data VISUALISE x, y DRAW point") + chart = spec_to_altair(spec) + result = chart.to_dict() + assert "$schema" in result diff --git a/pkg-py/tests/test_tools.py b/pkg-py/tests/test_tools.py index 682f259cf..94d8e3c64 100644 --- a/pkg-py/tests/test_tools.py +++ b/pkg-py/tests/test_tools.py @@ -12,6 +12,7 @@ def test_querychat_tool_starts_open_default_behavior(monkeypatch): assert querychat_tool_starts_open("query") is True assert querychat_tool_starts_open("update") is True assert querychat_tool_starts_open("reset") is False + assert querychat_tool_starts_open("visualize_query") is True def test_querychat_tool_starts_open_expanded(monkeypatch): @@ -21,6 +22,7 @@ def test_querychat_tool_starts_open_expanded(monkeypatch): assert querychat_tool_starts_open("query") is True assert querychat_tool_starts_open("update") is True assert querychat_tool_starts_open("reset") is True + assert querychat_tool_starts_open("visualize_query") is True def test_querychat_tool_starts_open_collapsed(monkeypatch): @@ -30,6 +32,7 @@ def test_querychat_tool_starts_open_collapsed(monkeypatch): assert querychat_tool_starts_open("query") is False assert querychat_tool_starts_open("update") is False assert querychat_tool_starts_open("reset") is False + assert querychat_tool_starts_open("visualize_query") is False def test_querychat_tool_starts_open_default_setting(monkeypatch): @@ -39,6 +42,7 @@ def test_querychat_tool_starts_open_default_setting(monkeypatch): assert querychat_tool_starts_open("query") is True assert querychat_tool_starts_open("update") is True assert querychat_tool_starts_open("reset") is False + assert querychat_tool_starts_open("visualize_query") is True def test_querychat_tool_starts_open_case_insensitive(monkeypatch): diff --git a/pkg-py/tests/test_viz_tools.py b/pkg-py/tests/test_viz_tools.py new file mode 100644 index 000000000..27478b19f --- /dev/null +++ b/pkg-py/tests/test_viz_tools.py @@ -0,0 +1,110 @@ +"""Tests for visualization tool functions.""" + +import builtins + +import narwhals.stable.v1 as nw +import polars as pl +import pytest +from conftest import ggsql_render_works +from querychat._datasource import DataFrameSource +from querychat.tools import ( + VisualizeQueryData, + VisualizeQueryResult, + tool_visualize_query, +) + + +class TestVizDependencyCheck: + def test_missing_ggsql_raises_helpful_error(self, monkeypatch): + """Requesting viz tools without ggsql installed should fail early.""" + real_import = builtins.__import__ + + def mock_import(name, *args, **kwargs): + if name == "ggsql": + raise ImportError("No module named 'ggsql'") + return real_import(name, *args, **kwargs) + + monkeypatch.setattr(builtins, "__import__", mock_import) + + from querychat._querychat_base import check_viz_dependencies + + with pytest.raises(ImportError, match="pip install querychat\\[viz\\]"): + check_viz_dependencies(("visualize_query",)) + + def test_no_error_without_viz_tools(self): + """Non-viz tool configs should not check for ggsql.""" + from querychat._querychat_base import check_viz_dependencies + + # Should not raise + check_viz_dependencies(("update", "query")) + check_viz_dependencies(None) + + +@pytest.fixture +def sample_df(): + return pl.DataFrame( + { + "x": [1, 2, 3, 4, 5], + "y": [10, 20, 15, 25, 30], + "category": ["A", "B", "A", "B", "A"], + } + ) + + +@pytest.fixture +def data_source(sample_df): + nw_df = nw.from_native(sample_df) + return DataFrameSource(nw_df, "test_data") + + +class TestToolVisualizeQuery: + def test_creates_tool(self, data_source): + callback_data = {} + + def update_fn(data: VisualizeQueryData): + callback_data.update(data) + + tool = tool_visualize_query(data_source, update_fn) + assert tool.name == "querychat_visualize_query" + + @ggsql_render_works + def test_tool_executes_sql_and_renders(self, data_source, monkeypatch): + callback_data = {} + + def update_fn(data: VisualizeQueryData): + callback_data.update(data) + + monkeypatch.setattr("shinywidgets.register_widget", lambda _widget_id, _chart: None) + monkeypatch.setattr("shinywidgets.output_widget", lambda widget_id: widget_id) + + tool = tool_visualize_query(data_source, update_fn) + impl = tool.func + + result = impl( + ggsql="SELECT x, y FROM test_data WHERE x > 2 VISUALISE x, y DRAW point", + title="Filtered Scatter", + ) + + assert "ggsql" in callback_data + assert "title" in callback_data + assert callback_data["title"] == "Filtered Scatter" + + assert isinstance(result, VisualizeQueryResult) + display = result.extra["display"] + assert display.full_screen is True + assert display.open is True + + @ggsql_render_works + def test_tool_handles_query_without_visualise(self, data_source): + callback_data = {} + + def update_fn(data: VisualizeQueryData): + callback_data.update(data) + + tool = tool_visualize_query(data_source, update_fn) + impl = tool.func + + result = impl(ggsql="SELECT x, y FROM test_data", title="No Viz") + + assert result.error is not None + assert "VISUALISE" in str(result.error) diff --git a/pyproject.toml b/pyproject.toml index 02ed2eed7..d901d19da 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -42,12 +42,14 @@ classifiers = [ [project.optional-dependencies] # For SQLAlchemySource and sample data, one of polars or pandas is required pandas = ["pandas"] -polars = ["polars"] +polars = ["polars", "pyarrow"] # duckdb requires pyarrow for polars DataFrame registration ibis = ["ibis-framework>=9.0.0", "pandas"] # pandas required for ibis .execute() to return DataFrames # Web framework extras streamlit = ["streamlit>=1.30"] gradio = ["gradio>=6.0"] dash = ["dash-ag-grid>=31.0", "dash[async]>=3.1", "dash-bootstrap-components>=2.0", "pandas"] +# Visualization with ggsql +viz = ["ggsql>=0.1.5", "altair>=5.0", "shinywidgets>=0.3.0"] [project.urls] Homepage = "https://github.com/posit-dev/querychat" # TODO update when we have docs @@ -65,7 +67,7 @@ packages = ["pkg-py/src/querychat"] include = ["pkg-py/src/querychat", "pkg-py/LICENSE", "pkg-py/README.md"] [dependency-groups] -dev = ["ruff>=0.6.5", "pyright>=1.1.401", "tox-uv>=1.11.4", "pytest>=8.4.0", "polars>=1.0.0", "pyarrow>=14.0.0", "ibis-framework[duckdb]>=9.0.0"] +dev = ["ruff>=0.6.5", "pyright>=1.1.401", "tox-uv>=1.11.4", "pytest>=8.4.0", "polars>=1.0.0", "pyarrow>=14.0.0", "ibis-framework[duckdb]>=9.0.0", "ggsql>=0.1.5", "altair>=5.0", "shinywidgets>=0.3.0"] docs = ["quartodoc>=0.11.1", "griffe<2", "nbformat", "nbclient", "ipykernel"] examples = [ "openai", @@ -218,6 +220,9 @@ line-ending = "auto" docstring-code-format = true docstring-code-line-length = "dynamic" +[tool.pytest.ini_options] +pythonpath = ["pkg-py/tests"] + [tool.pyright] include = ["pkg-py/src/querychat"]