Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions js/build.mjs
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,21 @@ const jsTargets = [
source: "src/viz.ts",
output: "../pkg-py/src/querychat/static/js/viz.js",
},
{
source: "src/viz.ts",
output: "../pkg-r/inst/htmldep/viz.js",
},
];

const cssTargets = [
{
source: "src/viz.css",
output: "../pkg-py/src/querychat/static/css/viz.css",
},
{
source: "src/viz.css",
output: "../pkg-r/inst/htmldep/viz.css",
},
];

const ensureParentDir = async (relativePath) => {
Expand Down Expand Up @@ -81,10 +89,9 @@ const reportMissingSources = async () => {
};

export const stageBuildOutputs = async (stageDir) => {
const cssSourcePath = path.resolve(rootDir, "src/viz.css");
const cssSource = await readFile(cssSourcePath, "utf8");

for (const target of cssTargets) {
const cssSourcePath = path.resolve(rootDir, target.source);
const cssSource = await readFile(cssSourcePath, "utf8");
const outputPath = resolveOutputPath(stageDir, target.output);
await mkdir(path.dirname(outputPath), { recursive: true });
await writeFile(outputPath, `${banner(target.source)}${cssSource}`, "utf8");
Expand Down
9 changes: 7 additions & 2 deletions pkg-r/DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ BugReports: https://github.com/posit-dev/querychat/issues
Depends:
R (>= 4.1.0)
Imports:
bslib,
bsicons,
bslib (>= 0.10.0),
cli,
DBI,
ellmer (>= 0.3.0),
Expand All @@ -37,20 +38,24 @@ Imports:
utils,
whisker
Suggests:
bsicons,
dbplyr,
dplyr,
DT,
duckdb,
ggsql,
knitr,
palmerpenguins,
rmarkdown,
RSQLite,
rsvg,
shinytest2,
testthat (>= 3.0.0),
V8,
withr
VignetteBuilder:
knitr
Remotes:
posit-dev/ggsql-r
Config/testthat/edition: 3
Config/testthat/parallel: true
Encoding: UTF-8
Expand Down
43 changes: 37 additions & 6 deletions pkg-r/R/QueryChat.R
Original file line number Diff line number Diff line change
Expand Up @@ -137,8 +137,10 @@ QueryChat <- R6::R6Class(
create_session_client = function(
client_spec = NULL,
tools = NA,
session = NULL,
update_dashboard = function(query, title) {},
reset_dashboard = function() {}
reset_dashboard = function() {},
visualize = function(data) {}
) {
spec <- client_spec %||% private$.client_spec
chat <- as_querychat_client(spec)
Expand Down Expand Up @@ -169,6 +171,21 @@ QueryChat <- R6::R6Class(
chat$register_tool(tool_query(private$.data_source))
}

if ("visualize" %in% tools) {
rlang::check_installed(
"ggsql",
reason = "for visualization support."
)
chat$register_tool(
tool_visualize_dashboard(
private$.data_source,
session = session,
update_fn = visualize,
has_tool_query = "query" %in% tools
)
)
}

chat
}
),
Expand Down Expand Up @@ -249,7 +266,11 @@ QueryChat <- R6::R6Class(
# Validate arguments
check_string(id, allow_null = TRUE)
check_string(greeting, allow_null = TRUE)
arg_match(tools)
arg_match(
tools,
values = c("update", "query", "visualize"),
multiple = TRUE
)
check_string(data_description, allow_null = TRUE)
check_number_whole(categorical_threshold, min = 1)
check_string(extra_instructions, allow_null = TRUE)
Expand Down Expand Up @@ -318,25 +339,35 @@ QueryChat <- R6::R6Class(
#' `title` generated by the LLM for the `update_dashboard` tool.
#' @param reset_dashboard Optional function to call when the
#' `reset_dashboard` tool is called.
#' @param visualize Optional function to call with a list containing
#' `ggsql`, `title`, and `widget_id` when a visualization succeeds.
#' @param session A Shiny session object. Required when `"visualize"` is
#' in `tools` and you want interactive chart rendering. When `NULL`
#' (the default), visualizations still execute but are not rendered
#' as Shiny outputs.
client = function(
tools = NA,
update_dashboard = function(query, title) {},
reset_dashboard = function() {}
reset_dashboard = function() {},
visualize = function(data) {},
session = NULL
) {
private$require_data_source("$client")

if (!is_na(tools) && !is.null(tools)) {
tools <- arg_match(
tools,
values = c("update", "query"),
values = c("update", "query", "visualize"),
multiple = TRUE
)
}

private$create_session_client(
tools = tools,
session = session,
update_dashboard = update_dashboard,
reset_dashboard = reset_dashboard
reset_dashboard = reset_dashboard,
visualize = visualize
)
},

Expand Down Expand Up @@ -417,7 +448,6 @@ QueryChat <- R6::R6Class(
app_obj = function(..., bookmark_store = "url") {
private$require_data_source("$app_obj")
check_installed("DT")
check_installed("bsicons")
check_dots_empty()

table_name <- private$.data_source$table_name
Expand Down Expand Up @@ -705,6 +735,7 @@ QueryChat <- R6::R6Class(
data_source = private$.data_source,
greeting = self$greeting,
client = create_session_client,
tools = self$tools,
enable_bookmarking = enable_bookmarking
)
},
Expand Down
16 changes: 15 additions & 1 deletion pkg-r/R/QueryChatSystemPrompt.R
Original file line number Diff line number Diff line change
Expand Up @@ -101,10 +101,24 @@ QueryChatSystemPrompt <- R6::R6Class(
extra_instructions = self$extra_instructions,
has_tool_update = if ("update" %in% tools) "true",
has_tool_query = if ("query" %in% tools) "true",
has_tool_visualize = if ("visualize" %in% tools) "true",
include_query_guidelines = if (length(tools) > 0) "true"
)

whisker::whisker.render(self$template, context)
partials <- list()
syntax_path <- system.file(
"prompts",
"ggsql-syntax.md",
package = "querychat"
)
if (nzchar(syntax_path)) {
partials[["ggsql-syntax"]] <- paste(
readLines(syntax_path),
collapse = "\n"
)
}

whisker::whisker.render(self$template, context, partials = partials)
}
)
)
Expand Down
53 changes: 52 additions & 1 deletion pkg-r/R/querychat_module.R
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ mod_server <- function(
data_source,
greeting,
client,
tools = c("update", "query"),
enable_bookmarking = FALSE
) {
shiny::moduleServer(id, function(input, output, session) {
Expand Down Expand Up @@ -61,11 +62,24 @@ mod_server <- function(
querychat_tool_result(action = "reset")
}

# Non-reactive bookkeeping for bookmark save/restore of viz widgets
viz_widgets <- list()

on_visualize <- function(data) {
viz_widgets[[length(viz_widgets) + 1L]] <<- list(
widget_id = data$widget_id,
ggsql = data$ggsql
)
}

# Set up the chat object for this session
check_function(client)
chat <- client(
update_dashboard = update_dashboard,
reset_dashboard = reset_query
reset_dashboard = reset_query,
visualize = on_visualize,
tools = tools,
session = session
)

# Prepopulate the chat UI with a welcome message that appears to be from the
Expand Down Expand Up @@ -121,6 +135,9 @@ mod_server <- function(
state$values$querychat_sql <- current_query()
state$values$querychat_title <- current_title()
state$values$querychat_has_greeted <- has_greeted()
if (length(viz_widgets) > 0) {
state$values$querychat_viz_widgets <- viz_widgets
}
})

shiny::onRestore(function(state) {
Expand All @@ -133,6 +150,14 @@ mod_server <- function(
if (!is.null(state$values$querychat_has_greeted)) {
has_greeted(state$values$querychat_has_greeted)
}
if (!is.null(state$values$querychat_viz_widgets)) {
restored <- restore_viz_widgets(
data_source,
state$values$querychat_viz_widgets,
session
)
viz_widgets <<- restored
}
})
}

Expand All @@ -147,3 +172,29 @@ mod_server <- function(

# TODO: Make this dependent on enabled tools
GREETING_PROMPT <- "Please give me a friendly greeting. Include a few sample prompts in a two-level bulleted list."

restore_viz_widgets <- function(data_source, saved_widgets, session) {
rlang::check_installed("ggsql", reason = "for visualization support.")

restored <- list()
for (entry in saved_widgets) {
tryCatch(
{
validated <- ggsql::ggsql_validate(entry$ggsql)
spec <- execute_ggsql(data_source, validated)
session$output[[entry$widget_id]] <- ggsql::renderGgsql(spec)
restored <- c(restored, list(entry))
},
error = function(e) {
warning(
sprintf(
"Failed to restore visualization widget '%s' on bookmark restore.",
entry$widget_id
),
call. = FALSE
)
}
)
}
restored
}
22 changes: 18 additions & 4 deletions pkg-r/R/querychat_tools.R
Original file line number Diff line number Diff line change
Expand Up @@ -92,8 +92,13 @@ tool_query <- function(data_source) {
db_type <- data_source$get_db_type()

ellmer::tool(
function(query, `_intent` = "") {
querychat_tool_result(data_source, query, action = "query")
function(query, `_intent` = "", collapsed = FALSE) {
querychat_tool_result(
data_source,
query,
action = "query",
collapsed = collapsed
)
},
name = "querychat_query",
description = interpolate_package("tool-query.md", db_type = db_type),
Expand All @@ -106,6 +111,10 @@ tool_query <- function(data_source) {
),
`_intent` = ellmer::type_string(
"A brief, user-friendly description of what this query calculates or retrieves."
),
collapsed = ellmer::type_boolean(
"Optional (default: false). Set to true for exploratory or preparatory queries whose results aren't the primary answer. When true, the result card starts collapsed.",
required = FALSE
)
),
annotations = ellmer::tool_annotations(
Expand Down Expand Up @@ -161,7 +170,8 @@ querychat_tool_result <- function(
data_source,
query,
title = NULL,
action = "update"
action = "update",
collapsed = NULL
) {
action <- arg_match(action, c("update", "query", "reset"))

Expand Down Expand Up @@ -231,7 +241,11 @@ querychat_tool_result <- function(
title = if (action == "update" && !is.null(title)) title,
show_request = is_error,
markdown = display_md,
open = querychat_tool_starts_open(action)
open = if (!is.null(collapsed)) {
!collapsed
} else {
querychat_tool_starts_open(action)
}
)
)
)
Expand Down
Loading
Loading