From a12e198926191e3e1df6c9a21b33ca2545696309 Mon Sep 17 00:00:00 2001 From: Thomas Lin Pedersen Date: Mon, 27 Apr 2026 11:58:27 +0200 Subject: [PATCH 01/33] first pass --- src/execute/layer.rs | 13 + src/execute/mod.rs | 22 +- src/plot/layer/geom/area.rs | 20 +- src/plot/layer/geom/arrow.rs | 4 + src/plot/layer/geom/bar.rs | 16 +- src/plot/layer/geom/errorbar.rs | 4 + src/plot/layer/geom/line.rs | 22 +- src/plot/layer/geom/mod.rs | 47 +- src/plot/layer/geom/path.rs | 4 + src/plot/layer/geom/point.rs | 4 + src/plot/layer/geom/polygon.rs | 4 + src/plot/layer/geom/ribbon.rs | 20 +- src/plot/layer/geom/rule.rs | 4 + src/plot/layer/geom/segment.rs | 4 + src/plot/layer/geom/stat_aggregate.rs | 982 ++++++++++++++++++++++++++ src/plot/layer/geom/text.rs | 4 + src/plot/layer/mod.rs | 5 + src/reader/duckdb.rs | 8 + src/reader/mod.rs | 11 + 19 files changed, 1164 insertions(+), 34 deletions(-) create mode 100644 src/plot/layer/geom/stat_aggregate.rs diff --git a/src/execute/layer.rs b/src/execute/layer.rs index 6af5c641..1c9b1b0d 100644 --- a/src/execute/layer.rs +++ b/src/execute/layer.rs @@ -584,6 +584,19 @@ where layer.mappings.aesthetics.remove(aes); } + // Auto-remap stat columns whose names are position aesthetics that were + // consumed by the stat (e.g. Aggregate's `pos1`/`pos2` outputs). The geom + // can't list these in `default_remappings` because the set of position + // aesthetics in play is dynamic per layer. + for stat in &stat_columns { + if final_remappings.contains_key(stat) { + continue; + } + if aesthetic::is_position_aesthetic(stat) && consumed_aesthetics.contains(stat) { + final_remappings.insert(stat.clone(), stat.clone()); + } + } + // Apply stat_columns to layer aesthetics using the remappings for stat in &stat_columns { if let Some(aesthetic) = final_remappings.get(stat) { diff --git a/src/execute/mod.rs b/src/execute/mod.rs index 82c9e1c0..063ca1d6 100644 --- a/src/execute/mod.rs +++ b/src/execute/mod.rs @@ -116,24 +116,38 @@ fn validate( } } - // Validate remapping source columns are valid stat columns for this geom + // Validate remapping source columns are valid stat columns for this geom. + // Geoms that opt into the Aggregate stat (`supports_aggregate`) also accept + // `aggregate`, `count`, and any position aesthetic name as a stat source. let valid_stat_columns = layer.geom.valid_stat_columns(); + let supports_aggregate = layer.geom.supports_aggregate(); for stat_value in layer.remappings.aesthetics.values() { if let Some(stat_col) = stat_value.column_name() { - if !valid_stat_columns.contains(&stat_col) { - if valid_stat_columns.is_empty() { + let is_aggregate_stat_col = supports_aggregate + && (stat_col == "aggregate" + || stat_col == "count" + || crate::plot::aesthetic::is_position_aesthetic(stat_col)); + if !valid_stat_columns.contains(&stat_col) && !is_aggregate_stat_col { + if valid_stat_columns.is_empty() && !supports_aggregate { return Err(GgsqlError::ValidationError(format!( "Layer {}: REMAPPING not supported for geom '{}' (no stat transform)", idx + 1, layer.geom ))); } else { + let mut valid: Vec = + valid_stat_columns.iter().map(|s| s.to_string()).collect(); + if supports_aggregate { + valid.push("aggregate".to_string()); + valid.push("count".to_string()); + } + let valid_refs: Vec<&str> = valid.iter().map(|s| s.as_str()).collect(); return Err(GgsqlError::ValidationError(format!( "Layer {}: REMAPPING references unknown stat column '{}'. Valid stat columns for geom '{}' are: {}", idx + 1, stat_col, layer.geom, - crate::and_list(valid_stat_columns) + crate::and_list(&valid_refs) ))); } } diff --git a/src/plot/layer/geom/area.rs b/src/plot/layer/geom/area.rs index a9df6bff..101806d0 100644 --- a/src/plot/layer/geom/area.rs +++ b/src/plot/layer/geom/area.rs @@ -5,8 +5,9 @@ use crate::plot::types::DefaultAestheticValue; use crate::plot::{DefaultParamValue, ParamDefinition}; use crate::{naming, Mappings}; +use super::stat_aggregate; use super::types::{ParamConstraint, POSITION_VALUES}; -use super::{DefaultAesthetics, GeomTrait, GeomType, StatResult}; +use super::{has_aggregate_param, DefaultAesthetics, GeomTrait, GeomType, StatResult}; /// Area geom - filled area charts #[derive(Debug, Clone, Copy)] @@ -54,6 +55,10 @@ impl GeomTrait for Area { PARAMS } + fn supports_aggregate(&self) -> bool { + true + } + fn needs_stat_transform(&self, _aesthetics: &Mappings) -> bool { true } @@ -61,13 +66,16 @@ impl GeomTrait for Area { fn apply_stat_transform( &self, query: &str, - _schema: &crate::plot::Schema, - _aesthetics: &Mappings, - _group_by: &[String], - _parameters: &std::collections::HashMap, + schema: &crate::plot::Schema, + aesthetics: &Mappings, + group_by: &[String], + parameters: &std::collections::HashMap, _execute_query: &dyn Fn(&str) -> crate::Result, - _dialect: &dyn crate::reader::SqlDialect, + dialect: &dyn crate::reader::SqlDialect, ) -> crate::Result { + if has_aggregate_param(parameters) { + return stat_aggregate::apply(query, schema, aesthetics, group_by, parameters, dialect); + } // Area geom needs ordering by pos1 (domain axis) for proper rendering let order_col = naming::aesthetic_column("pos1"); Ok(StatResult::Transformed { diff --git a/src/plot/layer/geom/arrow.rs b/src/plot/layer/geom/arrow.rs index 375d9754..2e3369d2 100644 --- a/src/plot/layer/geom/arrow.rs +++ b/src/plot/layer/geom/arrow.rs @@ -39,6 +39,10 @@ impl GeomTrait for Arrow { }]; PARAMS } + + fn supports_aggregate(&self) -> bool { + true + } } impl std::fmt::Display for Arrow { diff --git a/src/plot/layer/geom/bar.rs b/src/plot/layer/geom/bar.rs index d64bce9f..7824f74f 100644 --- a/src/plot/layer/geom/bar.rs +++ b/src/plot/layer/geom/bar.rs @@ -3,10 +3,11 @@ use std::collections::HashMap; use std::collections::HashSet; +use super::stat_aggregate; use super::types::{get_column_name, POSITION_VALUES}; use super::{ - DefaultAesthetics, DefaultParamValue, GeomTrait, GeomType, ParamConstraint, ParamDefinition, - StatResult, + has_aggregate_param, DefaultAesthetics, DefaultParamValue, GeomTrait, GeomType, ParamConstraint, + ParamDefinition, StatResult, }; use crate::naming; use crate::plot::types::{DefaultAestheticValue, ParameterValue}; @@ -79,6 +80,10 @@ impl GeomTrait for Bar { &["pos1", "pos2", "weight"] } + fn supports_aggregate(&self) -> bool { + true + } + fn needs_stat_transform(&self, _aesthetics: &Mappings) -> bool { true // Bar stat decides COUNT vs identity based on y mapping } @@ -89,10 +94,13 @@ impl GeomTrait for Bar { schema: &Schema, aesthetics: &Mappings, group_by: &[String], - _parameters: &HashMap, + parameters: &HashMap, _execute_query: &dyn Fn(&str) -> Result, - _dialect: &dyn SqlDialect, + dialect: &dyn SqlDialect, ) -> Result { + if has_aggregate_param(parameters) { + return stat_aggregate::apply(query, schema, aesthetics, group_by, parameters, dialect); + } stat_bar_count(query, schema, aesthetics, group_by) } } diff --git a/src/plot/layer/geom/errorbar.rs b/src/plot/layer/geom/errorbar.rs index 394c81e9..2821d141 100644 --- a/src/plot/layer/geom/errorbar.rs +++ b/src/plot/layer/geom/errorbar.rs @@ -44,6 +44,10 @@ impl GeomTrait for ErrorBar { ]; PARAMS } + + fn supports_aggregate(&self) -> bool { + true + } } impl std::fmt::Display for ErrorBar { diff --git a/src/plot/layer/geom/line.rs b/src/plot/layer/geom/line.rs index a8ded3b1..20b87228 100644 --- a/src/plot/layer/geom/line.rs +++ b/src/plot/layer/geom/line.rs @@ -1,8 +1,9 @@ //! Line geom implementation +use super::stat_aggregate; use super::{ - DefaultAesthetics, DefaultParamValue, GeomTrait, GeomType, ParamConstraint, ParamDefinition, - StatResult, + has_aggregate_param, DefaultAesthetics, DefaultParamValue, GeomTrait, GeomType, ParamConstraint, + ParamDefinition, StatResult, }; use crate::plot::layer::orientation::{ALIGNED, ORIENTATION_VALUES}; use crate::plot::types::DefaultAestheticValue; @@ -39,6 +40,10 @@ impl GeomTrait for Line { PARAMS } + fn supports_aggregate(&self) -> bool { + true + } + fn needs_stat_transform(&self, _aesthetics: &Mappings) -> bool { true } @@ -46,13 +51,16 @@ impl GeomTrait for Line { fn apply_stat_transform( &self, query: &str, - _schema: &crate::plot::Schema, - _aesthetics: &Mappings, - _group_by: &[String], - _parameters: &std::collections::HashMap, + schema: &crate::plot::Schema, + aesthetics: &Mappings, + group_by: &[String], + parameters: &std::collections::HashMap, _execute_query: &dyn Fn(&str) -> crate::Result, - _dialect: &dyn crate::reader::SqlDialect, + dialect: &dyn crate::reader::SqlDialect, ) -> crate::Result { + if has_aggregate_param(parameters) { + return stat_aggregate::apply(query, schema, aesthetics, group_by, parameters, dialect); + } // Line geom needs ordering by pos1 (domain axis) for proper rendering let order_col = naming::aesthetic_column("pos1"); Ok(StatResult::Transformed { diff --git a/src/plot/layer/geom/mod.rs b/src/plot/layer/geom/mod.rs index 145f8089..a10e6165 100644 --- a/src/plot/layer/geom/mod.rs +++ b/src/plot/layer/geom/mod.rs @@ -43,6 +43,7 @@ mod ribbon; mod rule; mod segment; mod smooth; +pub(crate) mod stat_aggregate; mod text; mod tile; mod violin; @@ -192,20 +193,35 @@ pub trait GeomTrait: std::fmt::Debug + std::fmt::Display + Send + Sync { false } + /// Whether this geom accepts the `aggregate` SETTING parameter. + /// + /// Geoms that opt in (the Identity-stat geoms) gain a generic Aggregate stat + /// that groups by discrete mappings + PARTITION BY and emits one row per + /// (group × aggregation function). Statistical geoms (histogram, density, + /// smooth, boxplot, violin) leave this `false` to keep their bespoke stats. + fn supports_aggregate(&self) -> bool { + false + } + /// Apply statistical transformation to the layer query. /// - /// The default implementation returns identity (no transformation). + /// The default implementation dispatches to the Aggregate stat when + /// `supports_aggregate()` is true and the `aggregate` parameter is set; + /// otherwise returns identity (no transformation). #[allow(clippy::too_many_arguments)] fn apply_stat_transform( &self, - _query: &str, - _schema: &Schema, - _aesthetics: &Mappings, - _group_by: &[String], - _parameters: &HashMap, + query: &str, + schema: &Schema, + aesthetics: &Mappings, + group_by: &[String], + parameters: &HashMap, _execute_query: &dyn Fn(&str) -> Result, - _dialect: &dyn SqlDialect, + dialect: &dyn SqlDialect, ) -> Result { + if self.supports_aggregate() && has_aggregate_param(parameters) { + return stat_aggregate::apply(query, schema, aesthetics, group_by, parameters, dialect); + } Ok(StatResult::Identity) } @@ -248,10 +264,22 @@ pub trait GeomTrait: std::fmt::Debug + std::fmt::Display + Send + Sync { for param in self.default_params() { valid.push(param.name); } + if self.supports_aggregate() { + valid.push("aggregate"); + } valid } } +/// True when `parameters["aggregate"]` is set to a non-null string or array. +pub(crate) fn has_aggregate_param(parameters: &HashMap) -> bool { + match parameters.get("aggregate") { + None | Some(ParameterValue::Null) => false, + Some(ParameterValue::String(_)) | Some(ParameterValue::Array(_)) => true, + _ => false, + } +} + /// Wrapper struct for geom trait objects /// /// This provides a convenient interface for working with geoms while hiding @@ -455,6 +483,11 @@ impl Geom { self.0.valid_settings() } + /// Whether this geom accepts the `aggregate` SETTING parameter. + pub fn supports_aggregate(&self) -> bool { + self.0.supports_aggregate() + } + /// Validate aesthetic mappings pub fn validate_aesthetics(&self, mappings: &Mappings) -> std::result::Result<(), String> { self.0.validate_aesthetics(mappings) diff --git a/src/plot/layer/geom/path.rs b/src/plot/layer/geom/path.rs index 5e32a3be..062e5f73 100644 --- a/src/plot/layer/geom/path.rs +++ b/src/plot/layer/geom/path.rs @@ -36,6 +36,10 @@ impl GeomTrait for Path { }]; PARAMS } + + fn supports_aggregate(&self) -> bool { + true + } } impl std::fmt::Display for Path { diff --git a/src/plot/layer/geom/point.rs b/src/plot/layer/geom/point.rs index 3dafde2a..5101f2f0 100644 --- a/src/plot/layer/geom/point.rs +++ b/src/plot/layer/geom/point.rs @@ -38,6 +38,10 @@ impl GeomTrait for Point { }]; PARAMS } + + fn supports_aggregate(&self) -> bool { + true + } } impl std::fmt::Display for Point { diff --git a/src/plot/layer/geom/polygon.rs b/src/plot/layer/geom/polygon.rs index d1ed6841..dee34338 100644 --- a/src/plot/layer/geom/polygon.rs +++ b/src/plot/layer/geom/polygon.rs @@ -37,6 +37,10 @@ impl GeomTrait for Polygon { }]; PARAMS } + + fn supports_aggregate(&self) -> bool { + true + } } impl std::fmt::Display for Polygon { diff --git a/src/plot/layer/geom/ribbon.rs b/src/plot/layer/geom/ribbon.rs index 87d4636c..98b60951 100644 --- a/src/plot/layer/geom/ribbon.rs +++ b/src/plot/layer/geom/ribbon.rs @@ -1,7 +1,8 @@ //! Ribbon geom implementation +use super::stat_aggregate; use super::types::POSITION_VALUES; -use super::{DefaultAesthetics, GeomTrait, GeomType, StatResult}; +use super::{has_aggregate_param, DefaultAesthetics, GeomTrait, GeomType, StatResult}; use crate::plot::types::DefaultAestheticValue; use crate::plot::{DefaultParamValue, ParamConstraint, ParamDefinition}; use crate::{naming, Mappings}; @@ -39,6 +40,10 @@ impl GeomTrait for Ribbon { PARAMS } + fn supports_aggregate(&self) -> bool { + true + } + fn needs_stat_transform(&self, _aesthetics: &Mappings) -> bool { true } @@ -46,13 +51,16 @@ impl GeomTrait for Ribbon { fn apply_stat_transform( &self, query: &str, - _schema: &crate::plot::Schema, - _aesthetics: &Mappings, - _group_by: &[String], - _parameters: &std::collections::HashMap, + schema: &crate::plot::Schema, + aesthetics: &Mappings, + group_by: &[String], + parameters: &std::collections::HashMap, _execute_query: &dyn Fn(&str) -> crate::Result, - _dialect: &dyn crate::reader::SqlDialect, + dialect: &dyn crate::reader::SqlDialect, ) -> crate::Result { + if has_aggregate_param(parameters) { + return stat_aggregate::apply(query, schema, aesthetics, group_by, parameters, dialect); + } // Ribbon geom needs ordering by pos1 (domain axis) for proper rendering let order_col = naming::aesthetic_column("pos1"); Ok(StatResult::Transformed { diff --git a/src/plot/layer/geom/rule.rs b/src/plot/layer/geom/rule.rs index be434f7a..a495cb48 100644 --- a/src/plot/layer/geom/rule.rs +++ b/src/plot/layer/geom/rule.rs @@ -25,6 +25,10 @@ impl GeomTrait for Rule { } } + fn supports_aggregate(&self) -> bool { + true + } + fn validate_aesthetics(&self, mappings: &crate::Mappings) -> std::result::Result<(), String> { // Rule requires exactly one of pos1 or pos2 (XOR logic) let has_pos1 = mappings.contains_key("pos1"); diff --git a/src/plot/layer/geom/segment.rs b/src/plot/layer/geom/segment.rs index 2ebfe920..d3fac22d 100644 --- a/src/plot/layer/geom/segment.rs +++ b/src/plot/layer/geom/segment.rs @@ -39,6 +39,10 @@ impl GeomTrait for Segment { }]; PARAMS } + + fn supports_aggregate(&self) -> bool { + true + } } impl std::fmt::Display for Segment { diff --git a/src/plot/layer/geom/stat_aggregate.rs b/src/plot/layer/geom/stat_aggregate.rs new file mode 100644 index 00000000..54692d88 --- /dev/null +++ b/src/plot/layer/geom/stat_aggregate.rs @@ -0,0 +1,982 @@ +//! Aggregate stat - groups data and applies one or more aggregation functions per group. +//! +//! When a layer's `aggregate` SETTING is set to a function name (or array of names), +//! this stat groups by discrete mappings + PARTITION BY columns and produces one row +//! per (group × function), aggregating numeric position aesthetics. +//! +//! Output columns: +//! - One column per numeric position aesthetic (named `pos1`, `pos2`, etc.) holding the +//! aggregated value. NULL for `count` rows. +//! - `aggregate` - the function name for the row. +//! - `count` (only when `count` is requested) - the row tally for that group. + +use std::collections::HashMap; + +use super::types::StatResult; +use crate::naming; +use crate::plot::aesthetic::is_position_aesthetic; +use crate::plot::types::{ + DefaultParamValue, ParamConstraint, ParamDefinition, ParameterValue, Schema, +}; +use crate::reader::SqlDialect; +use crate::{GgsqlError, Mappings, Result}; + +/// All aggregation function names accepted by the `aggregate` SETTING. +pub const AGG_NAMES: &[&str] = &[ + // Tallies & sums + "count", + "sum", + "prod", + // Extremes + "min", + "max", + "range", + // Central tendency + "mean", + "geomean", + "harmean", + "rms", + "median", + // Spread (standalone) + "sdev", + "var", + "iqr", + // Quantiles + "q05", + "q10", + "q25", + "q50", + "q75", + "q90", + "q95", + // Bands (mean ± spread) + "mean-sdev", + "mean+sdev", + "mean-2sdev", + "mean+2sdev", + "mean-se", + "mean+se", +]; + +/// Returns the `ParamDefinition` for the `aggregate` SETTING parameter. +/// +/// Used by `Layer::validate_settings` to check the value against `AGG_NAMES`, +/// and by geoms that support aggregation. +pub fn aggregate_param_definition() -> ParamDefinition { + ParamDefinition { + name: "aggregate", + default: DefaultParamValue::Null, + constraint: ParamConstraint::string_or_string_array(AGG_NAMES), + } +} + +/// Apply the Aggregate stat to a layer query. +/// +/// Returns `StatResult::Identity` when the `aggregate` parameter is unset or null. +/// Otherwise, builds a grouped-aggregation query and returns `StatResult::Transformed`. +/// +/// Strategy: +/// - **Single-pass** (preferred): one `GROUP BY` produces a wide row per group, then +/// `CROSS JOIN VALUES(...)` of function names explodes to one row per (group × function). +/// Used when all requested functions are inline-able. +/// - **UNION ALL fallback**: when a quantile is requested but the dialect doesn't +/// provide `sql_quantile_inline`, fall back to per-function subqueries using +/// `dialect.sql_percentile`. +pub fn apply( + query: &str, + schema: &Schema, + aesthetics: &Mappings, + group_by: &[String], + parameters: &HashMap, + dialect: &dyn SqlDialect, +) -> Result { + let funcs = match extract_aggregate_param(parameters) { + None => return Ok(StatResult::Identity), + Some(funcs) => funcs, + }; + + // Discover position aesthetics on the layer, splitting into numeric (to be + // aggregated) and discrete (to be carried through as group columns). + let mut numeric_pos: Vec<(String, String)> = Vec::new(); // (aesthetic, prefixed col) + let mut discrete_pos_cols: Vec = Vec::new(); + for (aesthetic, value) in &aesthetics.aesthetics { + if !is_position_aesthetic(aesthetic) { + continue; + } + let col = match value.column_name() { + Some(c) => c.to_string(), + None => continue, + }; + let info = schema.iter().find(|c| c.name == col); + let is_discrete = info.map(|c| c.is_discrete).unwrap_or(false); + if is_discrete { + discrete_pos_cols.push(col); + } else { + numeric_pos.push((aesthetic.clone(), col)); + } + } + numeric_pos.sort_by(|a, b| a.0.cmp(&b.0)); + discrete_pos_cols.sort(); + + if numeric_pos.is_empty() && !funcs.iter().any(|f| f == "count") { + return Err(GgsqlError::ValidationError( + "aggregate requires at least one numeric position aesthetic, or the 'count' function" + .to_string(), + )); + } + + // Group columns: PARTITION BY + discrete mappings (already in group_by) + discrete + // position aesthetic columns. Deduplicated, preserving order. + let mut group_cols: Vec = Vec::new(); + for g in group_by { + if !group_cols.contains(g) { + group_cols.push(g.clone()); + } + } + for c in &discrete_pos_cols { + if !group_cols.contains(c) { + group_cols.push(c.clone()); + } + } + + let needs_count_col = funcs.iter().any(|f| f == "count"); + + // Decide strategy: single-pass when every quantile can be inlined. + let needs_fallback = funcs.iter().any(|f| { + if let Some(frac) = quantile_fraction(f) { + // Use the first numeric column (any will do) for the probe, since we + // only care whether the dialect produces Some or None. + let probe = numeric_pos + .first() + .map(|(_, c)| c.as_str()) + .unwrap_or("__ggsql_probe__"); + dialect.sql_quantile_inline(probe, frac).is_none() + } else { + false + } + }); + + let transformed_query = if needs_fallback { + build_union_all_query(query, &funcs, &numeric_pos, &group_cols, dialect) + } else { + build_single_pass_query(query, &funcs, &numeric_pos, &group_cols, dialect) + }; + + let mut stat_columns: Vec = numeric_pos.iter().map(|(a, _)| a.clone()).collect(); + stat_columns.push("aggregate".to_string()); + if needs_count_col { + stat_columns.push("count".to_string()); + } + + let consumed_aesthetics: Vec = numeric_pos.into_iter().map(|(a, _)| a).collect(); + + Ok(StatResult::Transformed { + query: transformed_query, + stat_columns, + dummy_columns: vec![], + consumed_aesthetics, + }) +} + +/// Extract the `aggregate` parameter as a list of function names, or `None` when +/// the parameter is unset/null. +fn extract_aggregate_param(parameters: &HashMap) -> Option> { + use crate::plot::types::ArrayElement; + match parameters.get("aggregate") { + None | Some(ParameterValue::Null) => None, + Some(ParameterValue::String(s)) => Some(vec![s.clone()]), + Some(ParameterValue::Array(arr)) => { + let names: Vec = arr + .iter() + .filter_map(|el| match el { + ArrayElement::String(s) => Some(s.clone()), + _ => None, + }) + .collect(); + if names.is_empty() { + None + } else { + Some(names) + } + } + _ => None, + } +} + +/// Map a quantile function name (`q05`..`q95`, `median`) to its fraction. +fn quantile_fraction(func: &str) -> Option { + match func { + "median" | "q50" => Some(0.50), + "q05" => Some(0.05), + "q10" => Some(0.10), + "q25" => Some(0.25), + "q75" => Some(0.75), + "q90" => Some(0.90), + "q95" => Some(0.95), + _ => None, + } +} + +/// Build the inline SQL fragment for a function applied to a quoted column. +/// +/// Returns None for `count` (which doesn't take a column) and for quantiles when +/// the dialect lacks an inline form (caller should switch to UNION ALL strategy). +fn function_inline_sql(func: &str, qcol: &str, dialect: &dyn SqlDialect) -> Option { + if func == "count" { + return None; + } + if let Some(frac) = quantile_fraction(func) { + // Strip the quotes added by `naming::quote_ident` so we can re-quote inside + // `sql_quantile_inline` via the same helper. The dialect impl quotes itself. + let unquoted = unquote(qcol); + return dialect.sql_quantile_inline(&unquoted, frac); + } + Some(match func { + "sum" => format!("SUM({})", qcol), + "prod" => format!("EXP(SUM(LN({})))", qcol), + "min" => format!("MIN({})", qcol), + "max" => format!("MAX({})", qcol), + "range" => format!("(MAX({c}) - MIN({c}))", c = qcol), + "mean" => format!("AVG({})", qcol), + "geomean" => format!("EXP(AVG(LN({})))", qcol), + "harmean" => format!("(COUNT({c}) * 1.0 / SUM(1.0 / {c}))", c = qcol), + "rms" => format!("SQRT(AVG({c} * {c}))", c = qcol), + "sdev" => format!("STDDEV_POP({})", qcol), + "var" => format!("VAR_POP({})", qcol), + "mean-sdev" => format!("(AVG({c}) - STDDEV_POP({c}))", c = qcol), + "mean+sdev" => format!("(AVG({c}) + STDDEV_POP({c}))", c = qcol), + "mean-2sdev" => format!("(AVG({c}) - 2.0 * STDDEV_POP({c}))", c = qcol), + "mean+2sdev" => format!("(AVG({c}) + 2.0 * STDDEV_POP({c}))", c = qcol), + "mean-se" => format!( + "(AVG({c}) - STDDEV_POP({c}) / SQRT(COUNT({c})))", + c = qcol + ), + "mean+se" => format!( + "(AVG({c}) + STDDEV_POP({c}) / SQRT(COUNT({c})))", + c = qcol + ), + // `iqr` is computed from quantiles - handled separately. + _ => return None, + }) +} + +/// Strip surrounding double quotes from an identifier, undoing `naming::quote_ident`. +fn unquote(qcol: &str) -> String { + let trimmed = qcol.trim_start_matches('"').trim_end_matches('"'); + trimmed.replace("\"\"", "\"") +} + +/// SQL for a function name literal, properly escaped. +fn func_literal(func: &str) -> String { + format!("'{}'", func.replace('\'', "''")) +} + +// ============================================================================= +// Single-pass strategy: GROUP BY produces a wide CTE, then CROSS JOIN explodes +// rows per requested function. +// ============================================================================= + +fn build_single_pass_query( + query: &str, + funcs: &[String], + numeric_pos: &[(String, String)], + group_cols: &[String], + dialect: &dyn SqlDialect, +) -> String { + let src_alias = "\"__ggsql_stat_src__\""; + let agg_alias = "\"__ggsql_stat_agg__\""; + let funcs_alias = "\"__ggsql_stat_funcs__\""; + + let group_by_clause = if group_cols.is_empty() { + String::new() + } else { + let qcols: Vec = group_cols.iter().map(|c| naming::quote_ident(c)).collect(); + format!(" GROUP BY {}", qcols.join(", ")) + }; + + // Build the wide aggregation SELECT: one column per (function × position). + let mut wide_select_exprs: Vec = group_cols + .iter() + .map(|c| naming::quote_ident(c)) + .collect(); + + // Track the synthetic column names for each (aesthetic, function) pair. + let mut wide_col_for: HashMap<(String, String), String> = HashMap::new(); + + for (aes, col) in numeric_pos { + let qcol = naming::quote_ident(col); + for func in funcs { + if func == "count" { + continue; + } + let key = (aes.clone(), func.clone()); + if wide_col_for.contains_key(&key) { + continue; + } + let wide_name = synthetic_col_name(aes, func); + let expr = match func.as_str() { + "iqr" => { + // q75 - q25 inline if dialect supports it + let q75 = dialect + .sql_quantile_inline(col, 0.75) + .expect("sql_quantile_inline must be Some when single-pass is selected"); + let q25 = dialect + .sql_quantile_inline(col, 0.25) + .expect("sql_quantile_inline must be Some when single-pass is selected"); + format!("({} - {})", q75, q25) + } + _ => function_inline_sql(func, &qcol, dialect) + .expect("function_inline_sql must be Some when single-pass is selected"), + }; + wide_select_exprs.push(format!("{} AS {}", expr, naming::quote_ident(&wide_name))); + wide_col_for.insert(key, wide_name); + } + } + + let needs_count_col = funcs.iter().any(|f| f == "count"); + let count_wide = if needs_count_col { + let c = "__ggsql_stat_cnt__"; + wide_select_exprs.push(format!("COUNT(*) AS {}", naming::quote_ident(c))); + Some(c.to_string()) + } else { + None + }; + + let wide_select = wide_select_exprs.join(", "); + + // Build the CROSS JOIN VALUES table of function names. + let funcs_values: Vec = funcs.iter().map(|f| format!("({})", func_literal(f))).collect(); + let funcs_cte = format!( + "{}(name) AS (VALUES {})", + funcs_alias, + funcs_values.join(", ") + ); + + // Build the outer SELECT: group cols + per-aesthetic CASE + count CASE + name AS aggregate. + let mut outer_exprs: Vec = group_cols + .iter() + .map(|c| format!("{}.{}", agg_alias, naming::quote_ident(c))) + .collect(); + + for (aes, _) in numeric_pos { + let stat_col = naming::stat_column(aes); + let mut whens: Vec = Vec::new(); + for func in funcs { + if let Some(wide_name) = wide_col_for.get(&(aes.clone(), func.clone())) { + whens.push(format!( + "WHEN {} THEN {}.{}", + func_literal(func), + agg_alias, + naming::quote_ident(wide_name) + )); + } + } + let case_expr = if whens.is_empty() { + "NULL".to_string() + } else { + format!( + "CASE {}.name {} ELSE NULL END", + funcs_alias, + whens.join(" ") + ) + }; + outer_exprs.push(format!("{} AS {}", case_expr, naming::quote_ident(&stat_col))); + } + + if let Some(count_wide) = count_wide { + let stat_col = naming::stat_column("count"); + let case_expr = format!( + "CASE {f}.name WHEN {lit} THEN {a}.{c} ELSE NULL END", + f = funcs_alias, + a = agg_alias, + lit = func_literal("count"), + c = naming::quote_ident(&count_wide) + ); + outer_exprs.push(format!("{} AS {}", case_expr, naming::quote_ident(&stat_col))); + } + + let stat_aggregate_col = naming::stat_column("aggregate"); + outer_exprs.push(format!( + "{}.name AS {}", + funcs_alias, + naming::quote_ident(&stat_aggregate_col) + )); + + format!( + "WITH {src} AS ({query}), \ + {agg_alias_def} AS (SELECT {wide_select} FROM {src}{group_by}), \ + {funcs_cte} \ + SELECT {outer} FROM {agg} CROSS JOIN {funcs}", + src = src_alias, + query = query, + agg_alias_def = agg_alias, + wide_select = wide_select, + group_by = group_by_clause, + funcs_cte = funcs_cte, + outer = outer_exprs.join(", "), + agg = agg_alias, + funcs = funcs_alias, + ) +} + +/// Synthetic name for a (aesthetic, function) intermediate column in the wide CTE. +/// Includes a sanitized form of the function name to avoid collisions on `+`/`-`. +fn synthetic_col_name(aes: &str, func: &str) -> String { + let safe: String = func + .chars() + .map(|c| match c { + '+' => 'p', + '-' => 'm', + _ if c.is_ascii_alphanumeric() => c, + _ => '_', + }) + .collect(); + format!("__ggsql_stat_{}_{}", aes, safe) +} + +// ============================================================================= +// UNION ALL fallback strategy: one SELECT per requested function. +// ============================================================================= + +fn build_union_all_query( + query: &str, + funcs: &[String], + numeric_pos: &[(String, String)], + group_cols: &[String], + dialect: &dyn SqlDialect, +) -> String { + let src_alias = "\"__ggsql_stat_src__\""; + + let group_by_clause = if group_cols.is_empty() { + String::new() + } else { + let qcols: Vec = group_cols.iter().map(|c| naming::quote_ident(c)).collect(); + format!(" GROUP BY {}", qcols.join(", ")) + }; + + let group_select: Vec = group_cols + .iter() + .map(|c| naming::quote_ident(c)) + .collect(); + + let needs_count_col = funcs.iter().any(|f| f == "count"); + let stat_aggregate_col = naming::stat_column("aggregate"); + let stat_count_col = naming::stat_column("count"); + + let branches: Vec = funcs + .iter() + .map(|func| { + let mut select_parts: Vec = group_select.clone(); + + for (aes, col) in numeric_pos { + let stat_col = naming::stat_column(aes); + let value_expr = if func == "count" { + "NULL".to_string() + } else if func == "iqr" { + let q75 = dialect.sql_percentile(col, 0.75, src_alias, group_cols); + let q25 = dialect.sql_percentile(col, 0.25, src_alias, group_cols); + format!("({} - {})", q75, q25) + } else if let Some(frac) = quantile_fraction(func) { + dialect.sql_percentile(col, frac, src_alias, group_cols) + } else { + let qcol = naming::quote_ident(col); + function_inline_sql(func, &qcol, dialect).unwrap_or_else(|| "NULL".to_string()) + }; + select_parts.push(format!("{} AS {}", value_expr, naming::quote_ident(&stat_col))); + } + + if needs_count_col { + let value_expr = if func == "count" { + "COUNT(*)".to_string() + } else { + "NULL".to_string() + }; + select_parts.push(format!( + "{} AS {}", + value_expr, + naming::quote_ident(&stat_count_col) + )); + } + + select_parts.push(format!( + "{} AS {}", + func_literal(func), + naming::quote_ident(&stat_aggregate_col) + )); + + // Quantile fallbacks (sql_percentile) need the outer alias `__ggsql_qt__` + // so their correlated WHERE clause can find group columns. + format!( + "SELECT {} FROM {} AS \"__ggsql_qt__\"{}", + select_parts.join(", "), + src_alias, + group_by_clause + ) + }) + .collect(); + + format!( + "WITH {src} AS ({query}) {branches}", + src = src_alias, + query = query, + branches = branches.join(" UNION ALL ") + ) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::plot::types::{AestheticValue, ColumnInfo}; + use arrow::datatypes::DataType; + + /// A test dialect that mimics DuckDB's native QUANTILE_CONT support. + struct InlineQuantileDialect; + impl SqlDialect for InlineQuantileDialect { + fn sql_quantile_inline(&self, column: &str, fraction: f64) -> Option { + Some(format!( + "QUANTILE_CONT({}, {})", + naming::quote_ident(column), + fraction + )) + } + } + + /// A test dialect with no inline quantile support, exercising the UNION ALL fallback. + struct NoInlineQuantileDialect; + impl SqlDialect for NoInlineQuantileDialect {} + + fn col(name: &str) -> AestheticValue { + AestheticValue::Column { + name: name.to_string(), + original_name: None, + is_dummy: false, + } + } + + fn numeric_schema(cols: &[&str]) -> Schema { + cols.iter() + .map(|c| ColumnInfo { + name: c.to_string(), + dtype: DataType::Float64, + is_discrete: false, + min: None, + max: None, + }) + .collect() + } + + #[test] + fn returns_identity_when_param_unset() { + let aes = Mappings::new(); + let schema: Schema = vec![]; + let params = HashMap::new(); + let result = apply( + "SELECT * FROM t", + &schema, + &aes, + &[], + ¶ms, + &InlineQuantileDialect, + ) + .unwrap(); + assert_eq!(result, StatResult::Identity); + } + + #[test] + fn returns_identity_when_param_null() { + let aes = Mappings::new(); + let schema: Schema = vec![]; + let mut params = HashMap::new(); + params.insert("aggregate".to_string(), ParameterValue::Null); + let result = apply( + "SELECT * FROM t", + &schema, + &aes, + &[], + ¶ms, + &InlineQuantileDialect, + ) + .unwrap(); + assert_eq!(result, StatResult::Identity); + } + + #[test] + fn single_pass_for_mean_emits_avg() { + let mut aes = Mappings::new(); + aes.insert("pos2", col("__ggsql_aes_pos2__")); + let schema = numeric_schema(&["__ggsql_aes_pos2__"]); + let mut params = HashMap::new(); + params.insert( + "aggregate".to_string(), + ParameterValue::String("mean".to_string()), + ); + + let result = apply( + "SELECT * FROM t", + &schema, + &aes, + &[], + ¶ms, + &InlineQuantileDialect, + ) + .unwrap(); + + match result { + StatResult::Transformed { + query, + stat_columns, + consumed_aesthetics, + .. + } => { + assert!(query.contains("AVG(\"__ggsql_aes_pos2__\")"), "query: {}", query); + assert!(query.contains("CROSS JOIN")); + assert!(stat_columns.contains(&"pos2".to_string())); + assert!(stat_columns.contains(&"aggregate".to_string())); + assert!(!stat_columns.contains(&"count".to_string())); + assert_eq!(consumed_aesthetics, vec!["pos2".to_string()]); + } + _ => panic!("expected Transformed"), + } + } + + #[test] + fn count_emits_count_star_and_keeps_count_column() { + let mut aes = Mappings::new(); + aes.insert("pos2", col("__ggsql_aes_pos2__")); + let schema = numeric_schema(&["__ggsql_aes_pos2__"]); + let mut params = HashMap::new(); + params.insert( + "aggregate".to_string(), + ParameterValue::String("count".to_string()), + ); + + let result = apply( + "SELECT * FROM t", + &schema, + &aes, + &[], + ¶ms, + &InlineQuantileDialect, + ) + .unwrap(); + + match result { + StatResult::Transformed { + query, + stat_columns, + .. + } => { + assert!(query.contains("COUNT(*)")); + assert!(stat_columns.contains(&"count".to_string())); + assert!(stat_columns.contains(&"aggregate".to_string())); + } + _ => panic!("expected Transformed"), + } + } + + #[test] + fn mixed_count_and_mean_produces_two_rows_per_group() { + let mut aes = Mappings::new(); + aes.insert("pos2", col("__ggsql_aes_pos2__")); + let schema = numeric_schema(&["__ggsql_aes_pos2__"]); + let mut params = HashMap::new(); + use crate::plot::types::ArrayElement; + params.insert( + "aggregate".to_string(), + ParameterValue::Array(vec![ + ArrayElement::String("count".to_string()), + ArrayElement::String("mean".to_string()), + ]), + ); + + let result = apply( + "SELECT * FROM t", + &schema, + &aes, + &[], + ¶ms, + &InlineQuantileDialect, + ) + .unwrap(); + match result { + StatResult::Transformed { query, .. } => { + assert!(query.contains("AVG(\"__ggsql_aes_pos2__\")")); + assert!(query.contains("COUNT(*)")); + assert!(query.contains("'count'")); + assert!(query.contains("'mean'")); + // The count CASE must reference the agg CTE for the value column, + // not the funcs CTE (regression: previously emitted funcs.cnt which + // doesn't exist). + assert!( + query.contains("\"__ggsql_stat_agg__\".\"__ggsql_stat_cnt__\""), + "count CASE should reference the agg CTE, query was: {}", + query + ); + } + _ => panic!("expected Transformed"), + } + } + + #[test] + fn quantile_uses_dialect_inline_when_available() { + let mut aes = Mappings::new(); + aes.insert("pos2", col("__ggsql_aes_pos2__")); + let schema = numeric_schema(&["__ggsql_aes_pos2__"]); + let mut params = HashMap::new(); + params.insert( + "aggregate".to_string(), + ParameterValue::String("q25".to_string()), + ); + + let result = apply( + "SELECT * FROM t", + &schema, + &aes, + &[], + ¶ms, + &InlineQuantileDialect, + ) + .unwrap(); + match result { + StatResult::Transformed { query, .. } => { + assert!(query.contains("QUANTILE_CONT")); + assert!(query.contains("0.25")); + assert!(!query.contains("UNION ALL")); + } + _ => panic!("expected Transformed"), + } + } + + #[test] + fn quantile_falls_back_to_union_all_without_dialect_support() { + let mut aes = Mappings::new(); + aes.insert("pos2", col("__ggsql_aes_pos2__")); + let schema = numeric_schema(&["__ggsql_aes_pos2__"]); + let mut params = HashMap::new(); + params.insert( + "aggregate".to_string(), + ParameterValue::String("q25".to_string()), + ); + + let result = apply( + "SELECT * FROM t", + &schema, + &aes, + &[], + ¶ms, + &NoInlineQuantileDialect, + ) + .unwrap(); + match result { + StatResult::Transformed { query, .. } => { + // Fallback dialect uses NTILE-based correlated subquery via UNION ALL. + assert!(query.contains("NTILE(4)")); + assert!(query.contains("UNION ALL") || !query.contains("CROSS JOIN")); + } + _ => panic!("expected Transformed"), + } + } + + #[test] + fn mean_sdev_emits_avg_and_stddev() { + let mut aes = Mappings::new(); + aes.insert("pos2", col("__ggsql_aes_pos2__")); + let schema = numeric_schema(&["__ggsql_aes_pos2__"]); + let mut params = HashMap::new(); + use crate::plot::types::ArrayElement; + params.insert( + "aggregate".to_string(), + ParameterValue::Array(vec![ + ArrayElement::String("mean-sdev".to_string()), + ArrayElement::String("mean+sdev".to_string()), + ]), + ); + + let result = apply( + "SELECT * FROM t", + &schema, + &aes, + &[], + ¶ms, + &InlineQuantileDialect, + ) + .unwrap(); + match result { + StatResult::Transformed { query, .. } => { + assert!(query.contains("STDDEV_POP")); + assert!(query.contains("AVG")); + } + _ => panic!("expected Transformed"), + } + } + + #[test] + fn mean_se_includes_sqrt_count() { + let mut aes = Mappings::new(); + aes.insert("pos2", col("__ggsql_aes_pos2__")); + let schema = numeric_schema(&["__ggsql_aes_pos2__"]); + let mut params = HashMap::new(); + params.insert( + "aggregate".to_string(), + ParameterValue::String("mean+se".to_string()), + ); + + let result = apply( + "SELECT * FROM t", + &schema, + &aes, + &[], + ¶ms, + &InlineQuantileDialect, + ) + .unwrap(); + match result { + StatResult::Transformed { query, .. } => { + assert!(query.contains("SQRT(COUNT")); + } + _ => panic!("expected Transformed"), + } + } + + #[test] + fn prod_emits_exp_sum_ln() { + let mut aes = Mappings::new(); + aes.insert("pos2", col("__ggsql_aes_pos2__")); + let schema = numeric_schema(&["__ggsql_aes_pos2__"]); + let mut params = HashMap::new(); + params.insert( + "aggregate".to_string(), + ParameterValue::String("prod".to_string()), + ); + + let result = apply( + "SELECT * FROM t", + &schema, + &aes, + &[], + ¶ms, + &InlineQuantileDialect, + ) + .unwrap(); + match result { + StatResult::Transformed { query, .. } => { + assert!(query.contains("EXP(SUM(LN")); + } + _ => panic!("expected Transformed"), + } + } + + #[test] + fn iqr_emits_q75_minus_q25() { + let mut aes = Mappings::new(); + aes.insert("pos2", col("__ggsql_aes_pos2__")); + let schema = numeric_schema(&["__ggsql_aes_pos2__"]); + let mut params = HashMap::new(); + params.insert( + "aggregate".to_string(), + ParameterValue::String("iqr".to_string()), + ); + + let result = apply( + "SELECT * FROM t", + &schema, + &aes, + &[], + ¶ms, + &InlineQuantileDialect, + ) + .unwrap(); + match result { + StatResult::Transformed { query, .. } => { + assert!(query.contains("0.75")); + assert!(query.contains("0.25")); + } + _ => panic!("expected Transformed"), + } + } + + #[test] + fn discrete_position_aesthetic_becomes_group_column() { + let mut aes = Mappings::new(); + aes.insert("pos1", col("__ggsql_aes_pos1__")); + aes.insert("pos2", col("__ggsql_aes_pos2__")); + let schema = vec![ + ColumnInfo { + name: "__ggsql_aes_pos1__".to_string(), + dtype: DataType::Utf8, + is_discrete: true, + min: None, + max: None, + }, + ColumnInfo { + name: "__ggsql_aes_pos2__".to_string(), + dtype: DataType::Float64, + is_discrete: false, + min: None, + max: None, + }, + ]; + let mut params = HashMap::new(); + params.insert( + "aggregate".to_string(), + ParameterValue::String("mean".to_string()), + ); + + let result = apply( + "SELECT * FROM t", + &schema, + &aes, + &[], + ¶ms, + &InlineQuantileDialect, + ) + .unwrap(); + match result { + StatResult::Transformed { + query, + stat_columns, + consumed_aesthetics, + .. + } => { + // pos1 (discrete) is in GROUP BY, not aggregated. + assert!(query.contains("GROUP BY \"__ggsql_aes_pos1__\"")); + // pos2 is aggregated. + assert!(query.contains("AVG(\"__ggsql_aes_pos2__\")")); + // Only pos2 is consumed. + assert_eq!(consumed_aesthetics, vec!["pos2".to_string()]); + // Only pos2 (numeric) appears in stat_columns; pos1 stays as-is. + assert!(stat_columns.contains(&"pos2".to_string())); + assert!(!stat_columns.contains(&"pos1".to_string())); + } + _ => panic!("expected Transformed"), + } + } + + #[test] + fn explicit_group_by_columns_appear_in_query() { + let mut aes = Mappings::new(); + aes.insert("pos2", col("__ggsql_aes_pos2__")); + let schema = numeric_schema(&["__ggsql_aes_pos2__"]); + let mut params = HashMap::new(); + params.insert( + "aggregate".to_string(), + ParameterValue::String("mean".to_string()), + ); + + let result = apply( + "SELECT * FROM t", + &schema, + &aes, + &["region".to_string()], + ¶ms, + &InlineQuantileDialect, + ) + .unwrap(); + match result { + StatResult::Transformed { query, .. } => { + assert!(query.contains("GROUP BY \"region\"")); + } + _ => panic!("expected Transformed"), + } + } +} diff --git a/src/plot/layer/geom/text.rs b/src/plot/layer/geom/text.rs index 6ceb45f9..5909c34d 100644 --- a/src/plot/layer/geom/text.rs +++ b/src/plot/layer/geom/text.rs @@ -63,6 +63,10 @@ impl GeomTrait for Text { PARAMS } + fn supports_aggregate(&self) -> bool { + true + } + fn post_process( &self, df: DataFrame, diff --git a/src/plot/layer/mod.rs b/src/plot/layer/mod.rs index 33321a48..e6656590 100644 --- a/src/plot/layer/mod.rs +++ b/src/plot/layer/mod.rs @@ -419,6 +419,11 @@ impl Layer { { validate_parameter(param_name, value, ¶m.constraint)?; } + // Or the shared `aggregate` param for Identity-stat geoms + else if param_name == "aggregate" && self.geom.supports_aggregate() { + let definition = crate::plot::layer::geom::stat_aggregate::aggregate_param_definition(); + validate_parameter(param_name, value, &definition.constraint)?; + } // Otherwise it's a valid aesthetic setting (no constraint validation needed) } diff --git a/src/reader/duckdb.rs b/src/reader/duckdb.rs index aae89f20..6e7ab0cb 100644 --- a/src/reader/duckdb.rs +++ b/src/reader/duckdb.rs @@ -41,6 +41,14 @@ impl super::SqlDialect for DuckDbDialect { ) } + fn sql_quantile_inline(&self, column: &str, fraction: f64) -> Option { + Some(format!( + "QUANTILE_CONT({}, {})", + naming::quote_ident(column), + fraction + )) + } + fn sql_percentile(&self, column: &str, fraction: f64, from: &str, groups: &[String]) -> String { let group_filter = groups .iter() diff --git a/src/reader/mod.rs b/src/reader/mod.rs index 16a96b66..a02bda3c 100644 --- a/src/reader/mod.rs +++ b/src/reader/mod.rs @@ -215,6 +215,17 @@ pub trait SqlDialect { ) } + /// Inline-form quantile aggregate, usable directly in a `SELECT` list. + /// + /// Returns `Some(sql_fragment)` when the dialect supports a native quantile + /// aggregate that can be combined with other aggregates in the same `GROUP BY` + /// query (e.g. DuckDB's `QUANTILE_CONT`). Returns `None` when no native + /// inline form exists; callers should then fall back to [`sql_percentile`], + /// which produces a correlated scalar subquery. + fn sql_quantile_inline(&self, _column: &str, _fraction: f64) -> Option { + None + } + /// SQL literal for a date value (days since Unix epoch). fn sql_date_literal(&self, days_since_epoch: i32) -> String { format!( From 778b6acbfe7305403d99543e922cdd266b4376bb Mon Sep 17 00:00:00 2001 From: Thomas Lin Pedersen Date: Mon, 27 Apr 2026 13:58:46 +0200 Subject: [PATCH 02/33] support numeric axis geoms --- src/plot/layer/geom/area.rs | 31 +-- src/plot/layer/geom/arrow.rs | 4 + src/plot/layer/geom/bar.rs | 10 +- src/plot/layer/geom/line.rs | 30 ++- src/plot/layer/geom/mod.rs | 28 ++- src/plot/layer/geom/path.rs | 4 + src/plot/layer/geom/point.rs | 4 + src/plot/layer/geom/polygon.rs | 4 + src/plot/layer/geom/ribbon.rs | 31 +-- src/plot/layer/geom/rule.rs | 6 + src/plot/layer/geom/segment.rs | 4 + src/plot/layer/geom/stat_aggregate.rs | 303 +++++++++++++++++++++++++- src/plot/layer/geom/text.rs | 4 + src/plot/layer/geom/types.rs | 88 ++++++++ 14 files changed, 501 insertions(+), 50 deletions(-) diff --git a/src/plot/layer/geom/area.rs b/src/plot/layer/geom/area.rs index 101806d0..f6388032 100644 --- a/src/plot/layer/geom/area.rs +++ b/src/plot/layer/geom/area.rs @@ -3,10 +3,10 @@ use crate::plot::layer::orientation::{ALIGNED, ORIENTATION_VALUES}; use crate::plot::types::DefaultAestheticValue; use crate::plot::{DefaultParamValue, ParamDefinition}; -use crate::{naming, Mappings}; +use crate::Mappings; use super::stat_aggregate; -use super::types::{ParamConstraint, POSITION_VALUES}; +use super::types::{wrap_with_order_by, ParamConstraint, POSITION_VALUES}; use super::{has_aggregate_param, DefaultAesthetics, GeomTrait, GeomType, StatResult}; /// Area geom - filled area charts @@ -73,17 +73,22 @@ impl GeomTrait for Area { _execute_query: &dyn Fn(&str) -> crate::Result, dialect: &dyn crate::reader::SqlDialect, ) -> crate::Result { - if has_aggregate_param(parameters) { - return stat_aggregate::apply(query, schema, aesthetics, group_by, parameters, dialect); - } - // Area geom needs ordering by pos1 (domain axis) for proper rendering - let order_col = naming::aesthetic_column("pos1"); - Ok(StatResult::Transformed { - query: format!("{} ORDER BY {}", query, naming::quote_ident(&order_col)), - stat_columns: vec![], - dummy_columns: vec![], - consumed_aesthetics: vec![], - }) + let result = if has_aggregate_param(parameters) { + stat_aggregate::apply( + query, + schema, + aesthetics, + group_by, + parameters, + dialect, + self.aggregate_slots(), + )? + } else { + StatResult::Identity + }; + // Area needs ordering by pos1 (domain axis) for proper rendering, in both + // the Identity and Aggregate paths. + Ok(wrap_with_order_by(query, result, "pos1")) } } diff --git a/src/plot/layer/geom/arrow.rs b/src/plot/layer/geom/arrow.rs index 2e3369d2..5737bb95 100644 --- a/src/plot/layer/geom/arrow.rs +++ b/src/plot/layer/geom/arrow.rs @@ -43,6 +43,10 @@ impl GeomTrait for Arrow { fn supports_aggregate(&self) -> bool { true } + + fn aggregate_slots(&self) -> &'static [u8] { + &[1, 2] + } } impl std::fmt::Display for Arrow { diff --git a/src/plot/layer/geom/bar.rs b/src/plot/layer/geom/bar.rs index 7824f74f..b3b82d72 100644 --- a/src/plot/layer/geom/bar.rs +++ b/src/plot/layer/geom/bar.rs @@ -99,7 +99,15 @@ impl GeomTrait for Bar { dialect: &dyn SqlDialect, ) -> Result { if has_aggregate_param(parameters) { - return stat_aggregate::apply(query, schema, aesthetics, group_by, parameters, dialect); + return stat_aggregate::apply( + query, + schema, + aesthetics, + group_by, + parameters, + dialect, + self.aggregate_slots(), + ); } stat_bar_count(query, schema, aesthetics, group_by) } diff --git a/src/plot/layer/geom/line.rs b/src/plot/layer/geom/line.rs index 20b87228..92f7927c 100644 --- a/src/plot/layer/geom/line.rs +++ b/src/plot/layer/geom/line.rs @@ -1,13 +1,14 @@ //! Line geom implementation use super::stat_aggregate; +use super::types::wrap_with_order_by; use super::{ has_aggregate_param, DefaultAesthetics, DefaultParamValue, GeomTrait, GeomType, ParamConstraint, ParamDefinition, StatResult, }; use crate::plot::layer::orientation::{ALIGNED, ORIENTATION_VALUES}; use crate::plot::types::DefaultAestheticValue; -use crate::{naming, Mappings}; +use crate::Mappings; /// Line geom - line charts with connected points #[derive(Debug, Clone, Copy)] @@ -58,17 +59,22 @@ impl GeomTrait for Line { _execute_query: &dyn Fn(&str) -> crate::Result, dialect: &dyn crate::reader::SqlDialect, ) -> crate::Result { - if has_aggregate_param(parameters) { - return stat_aggregate::apply(query, schema, aesthetics, group_by, parameters, dialect); - } - // Line geom needs ordering by pos1 (domain axis) for proper rendering - let order_col = naming::aesthetic_column("pos1"); - Ok(StatResult::Transformed { - query: format!("{} ORDER BY {}", query, naming::quote_ident(&order_col)), - stat_columns: vec![], - dummy_columns: vec![], - consumed_aesthetics: vec![], - }) + let result = if has_aggregate_param(parameters) { + stat_aggregate::apply( + query, + schema, + aesthetics, + group_by, + parameters, + dialect, + self.aggregate_slots(), + )? + } else { + StatResult::Identity + }; + // Line needs ordering by pos1 (domain axis) for proper rendering, in both + // the Identity and Aggregate paths. + Ok(wrap_with_order_by(query, result, "pos1")) } } diff --git a/src/plot/layer/geom/mod.rs b/src/plot/layer/geom/mod.rs index a10e6165..1fd22dbd 100644 --- a/src/plot/layer/geom/mod.rs +++ b/src/plot/layer/geom/mod.rs @@ -203,6 +203,19 @@ pub trait GeomTrait: std::fmt::Debug + std::fmt::Display + Send + Sync { false } + /// Which numeric position-aesthetic slots the Aggregate stat should reduce. + /// + /// Slot 1 is `pos1`/`pos1min`/`pos1max`/`pos1end` (the independent / domain axis). + /// Slot 2 is `pos2`/`pos2min`/`pos2max`/`pos2end` (the dependent / range axis). + /// + /// Default: `&[2]` — only the dependent axis is reduced; pos1-family stays as a + /// grouping column, so e.g. line geoms produce a summary trace along x. Geoms + /// whose natural Aggregate is centroid-like (point, polygon, segment, arrow, + /// text, path, tile, rule) override to `&[1, 2]`. + fn aggregate_slots(&self) -> &'static [u8] { + &[2] + } + /// Apply statistical transformation to the layer query. /// /// The default implementation dispatches to the Aggregate stat when @@ -220,7 +233,15 @@ pub trait GeomTrait: std::fmt::Debug + std::fmt::Display + Send + Sync { dialect: &dyn SqlDialect, ) -> Result { if self.supports_aggregate() && has_aggregate_param(parameters) { - return stat_aggregate::apply(query, schema, aesthetics, group_by, parameters, dialect); + return stat_aggregate::apply( + query, + schema, + aesthetics, + group_by, + parameters, + dialect, + self.aggregate_slots(), + ); } Ok(StatResult::Identity) } @@ -488,6 +509,11 @@ impl Geom { self.0.supports_aggregate() } + /// Which position-aesthetic slots the Aggregate stat should reduce. + pub fn aggregate_slots(&self) -> &'static [u8] { + self.0.aggregate_slots() + } + /// Validate aesthetic mappings pub fn validate_aesthetics(&self, mappings: &Mappings) -> std::result::Result<(), String> { self.0.validate_aesthetics(mappings) diff --git a/src/plot/layer/geom/path.rs b/src/plot/layer/geom/path.rs index 062e5f73..c2c8af9f 100644 --- a/src/plot/layer/geom/path.rs +++ b/src/plot/layer/geom/path.rs @@ -40,6 +40,10 @@ impl GeomTrait for Path { fn supports_aggregate(&self) -> bool { true } + + fn aggregate_slots(&self) -> &'static [u8] { + &[1, 2] + } } impl std::fmt::Display for Path { diff --git a/src/plot/layer/geom/point.rs b/src/plot/layer/geom/point.rs index 5101f2f0..1f60a5f6 100644 --- a/src/plot/layer/geom/point.rs +++ b/src/plot/layer/geom/point.rs @@ -42,6 +42,10 @@ impl GeomTrait for Point { fn supports_aggregate(&self) -> bool { true } + + fn aggregate_slots(&self) -> &'static [u8] { + &[1, 2] + } } impl std::fmt::Display for Point { diff --git a/src/plot/layer/geom/polygon.rs b/src/plot/layer/geom/polygon.rs index dee34338..efda483e 100644 --- a/src/plot/layer/geom/polygon.rs +++ b/src/plot/layer/geom/polygon.rs @@ -41,6 +41,10 @@ impl GeomTrait for Polygon { fn supports_aggregate(&self) -> bool { true } + + fn aggregate_slots(&self) -> &'static [u8] { + &[1, 2] + } } impl std::fmt::Display for Polygon { diff --git a/src/plot/layer/geom/ribbon.rs b/src/plot/layer/geom/ribbon.rs index 98b60951..bf1898b9 100644 --- a/src/plot/layer/geom/ribbon.rs +++ b/src/plot/layer/geom/ribbon.rs @@ -1,11 +1,11 @@ //! Ribbon geom implementation use super::stat_aggregate; -use super::types::POSITION_VALUES; +use super::types::{wrap_with_order_by, POSITION_VALUES}; use super::{has_aggregate_param, DefaultAesthetics, GeomTrait, GeomType, StatResult}; use crate::plot::types::DefaultAestheticValue; use crate::plot::{DefaultParamValue, ParamConstraint, ParamDefinition}; -use crate::{naming, Mappings}; +use crate::Mappings; /// Ribbon geom - confidence bands and ranges #[derive(Debug, Clone, Copy)] @@ -58,17 +58,22 @@ impl GeomTrait for Ribbon { _execute_query: &dyn Fn(&str) -> crate::Result, dialect: &dyn crate::reader::SqlDialect, ) -> crate::Result { - if has_aggregate_param(parameters) { - return stat_aggregate::apply(query, schema, aesthetics, group_by, parameters, dialect); - } - // Ribbon geom needs ordering by pos1 (domain axis) for proper rendering - let order_col = naming::aesthetic_column("pos1"); - Ok(StatResult::Transformed { - query: format!("{} ORDER BY {}", query, naming::quote_ident(&order_col)), - stat_columns: vec![], - dummy_columns: vec![], - consumed_aesthetics: vec![], - }) + let result = if has_aggregate_param(parameters) { + stat_aggregate::apply( + query, + schema, + aesthetics, + group_by, + parameters, + dialect, + self.aggregate_slots(), + )? + } else { + StatResult::Identity + }; + // Ribbon needs ordering by pos1 (domain axis) for proper rendering, in both + // the Identity and Aggregate paths. + Ok(wrap_with_order_by(query, result, "pos1")) } } diff --git a/src/plot/layer/geom/rule.rs b/src/plot/layer/geom/rule.rs index a495cb48..21d7adbe 100644 --- a/src/plot/layer/geom/rule.rs +++ b/src/plot/layer/geom/rule.rs @@ -29,6 +29,12 @@ impl GeomTrait for Rule { true } + fn aggregate_slots(&self) -> &'static [u8] { + // Rule maps exactly one of pos1/pos2 (XOR). Allow either to be the reduced + // axis — whichever is mapped wins, and the other slot has nothing to filter. + &[1, 2] + } + fn validate_aesthetics(&self, mappings: &crate::Mappings) -> std::result::Result<(), String> { // Rule requires exactly one of pos1 or pos2 (XOR logic) let has_pos1 = mappings.contains_key("pos1"); diff --git a/src/plot/layer/geom/segment.rs b/src/plot/layer/geom/segment.rs index d3fac22d..b229d054 100644 --- a/src/plot/layer/geom/segment.rs +++ b/src/plot/layer/geom/segment.rs @@ -43,6 +43,10 @@ impl GeomTrait for Segment { fn supports_aggregate(&self) -> bool { true } + + fn aggregate_slots(&self) -> &'static [u8] { + &[1, 2] + } } impl std::fmt::Display for Segment { diff --git a/src/plot/layer/geom/stat_aggregate.rs b/src/plot/layer/geom/stat_aggregate.rs index 54692d88..b446cc0a 100644 --- a/src/plot/layer/geom/stat_aggregate.rs +++ b/src/plot/layer/geom/stat_aggregate.rs @@ -14,7 +14,7 @@ use std::collections::HashMap; use super::types::StatResult; use crate::naming; -use crate::plot::aesthetic::is_position_aesthetic; +use crate::plot::aesthetic::{is_position_aesthetic, parse_position}; use crate::plot::types::{ DefaultParamValue, ParamConstraint, ParamDefinition, ParameterValue, Schema, }; @@ -89,16 +89,19 @@ pub fn apply( group_by: &[String], parameters: &HashMap, dialect: &dyn SqlDialect, + agg_slots: &[u8], ) -> Result { let funcs = match extract_aggregate_param(parameters) { None => return Ok(StatResult::Identity), Some(funcs) => funcs, }; - // Discover position aesthetics on the layer, splitting into numeric (to be - // aggregated) and discrete (to be carried through as group columns). + // Walk the layer's position aesthetics and route each by (slot, type): + // in-axis slot && numeric → aggregated (numeric_pos) + // in-axis slot && discrete → kept as group column (kept_pos_cols) + // out-of-axis (any type) → kept as group column (kept_pos_cols) let mut numeric_pos: Vec<(String, String)> = Vec::new(); // (aesthetic, prefixed col) - let mut discrete_pos_cols: Vec = Vec::new(); + let mut kept_pos_cols: Vec = Vec::new(); for (aesthetic, value) in &aesthetics.aesthetics { if !is_position_aesthetic(aesthetic) { continue; @@ -107,16 +110,19 @@ pub fn apply( Some(c) => c.to_string(), None => continue, }; + let slot = parse_position(aesthetic).map(|(s, _)| s).unwrap_or(0); + let in_axis = agg_slots.contains(&slot); let info = schema.iter().find(|c| c.name == col); let is_discrete = info.map(|c| c.is_discrete).unwrap_or(false); - if is_discrete { - discrete_pos_cols.push(col); + + if !in_axis || is_discrete { + kept_pos_cols.push(col); } else { numeric_pos.push((aesthetic.clone(), col)); } } numeric_pos.sort_by(|a, b| a.0.cmp(&b.0)); - discrete_pos_cols.sort(); + kept_pos_cols.sort(); if numeric_pos.is_empty() && !funcs.iter().any(|f| f == "count") { return Err(GgsqlError::ValidationError( @@ -125,15 +131,16 @@ pub fn apply( )); } - // Group columns: PARTITION BY + discrete mappings (already in group_by) + discrete - // position aesthetic columns. Deduplicated, preserving order. + // Group columns: PARTITION BY + discrete mappings (already in group_by) + any + // position-aesthetic columns we kept (out-of-axis or in-axis-but-discrete). + // Deduplicated, preserving order. let mut group_cols: Vec = Vec::new(); for g in group_by { if !group_cols.contains(g) { group_cols.push(g.clone()); } } - for c in &discrete_pos_cols { + for c in &kept_pos_cols { if !group_cols.contains(c) { group_cols.push(c.clone()); } @@ -577,6 +584,7 @@ mod tests { &[], ¶ms, &InlineQuantileDialect, + &[2], ) .unwrap(); assert_eq!(result, StatResult::Identity); @@ -595,6 +603,7 @@ mod tests { &[], ¶ms, &InlineQuantileDialect, + &[2], ) .unwrap(); assert_eq!(result, StatResult::Identity); @@ -618,6 +627,7 @@ mod tests { &[], ¶ms, &InlineQuantileDialect, + &[2], ) .unwrap(); @@ -657,6 +667,7 @@ mod tests { &[], ¶ms, &InlineQuantileDialect, + &[2], ) .unwrap(); @@ -696,6 +707,7 @@ mod tests { &[], ¶ms, &InlineQuantileDialect, + &[2], ) .unwrap(); match result { @@ -735,6 +747,7 @@ mod tests { &[], ¶ms, &InlineQuantileDialect, + &[2], ) .unwrap(); match result { @@ -765,6 +778,7 @@ mod tests { &[], ¶ms, &NoInlineQuantileDialect, + &[2], ) .unwrap(); match result { @@ -799,6 +813,7 @@ mod tests { &[], ¶ms, &InlineQuantileDialect, + &[2], ) .unwrap(); match result { @@ -828,6 +843,7 @@ mod tests { &[], ¶ms, &InlineQuantileDialect, + &[2], ) .unwrap(); match result { @@ -856,6 +872,7 @@ mod tests { &[], ¶ms, &InlineQuantileDialect, + &[2], ) .unwrap(); match result { @@ -884,6 +901,7 @@ mod tests { &[], ¶ms, &InlineQuantileDialect, + &[2], ) .unwrap(); match result { @@ -929,6 +947,7 @@ mod tests { &[], ¶ms, &InlineQuantileDialect, + &[2], ) .unwrap(); match result { @@ -970,6 +989,7 @@ mod tests { &["region".to_string()], ¶ms, &InlineQuantileDialect, + &[2], ) .unwrap(); match result { @@ -979,4 +999,267 @@ mod tests { _ => panic!("expected Transformed"), } } + + #[test] + fn line_style_groups_by_pos1_and_aggregates_pos2() { + // slots=[2]: pos1 stays as group (even though numeric), pos2 gets aggregated. + let mut aes = Mappings::new(); + aes.insert("pos1", col("__ggsql_aes_pos1__")); + aes.insert("pos2", col("__ggsql_aes_pos2__")); + let schema = numeric_schema(&["__ggsql_aes_pos1__", "__ggsql_aes_pos2__"]); + let mut params = HashMap::new(); + params.insert( + "aggregate".to_string(), + ParameterValue::String("max".to_string()), + ); + + let result = apply( + "SELECT * FROM t", + &schema, + &aes, + &[], + ¶ms, + &InlineQuantileDialect, + &[2], + ) + .unwrap(); + match result { + StatResult::Transformed { + query, + consumed_aesthetics, + stat_columns, + .. + } => { + assert!(query.contains("MAX(\"__ggsql_aes_pos2__\")"), "query: {}", query); + assert!(!query.contains("MAX(\"__ggsql_aes_pos1__\")")); + assert!(query.contains("GROUP BY \"__ggsql_aes_pos1__\"")); + assert_eq!(consumed_aesthetics, vec!["pos2".to_string()]); + assert!(stat_columns.contains(&"pos2".to_string())); + assert!(!stat_columns.contains(&"pos1".to_string())); + } + _ => panic!("expected Transformed"), + } + } + + #[test] + fn point_style_aggregates_both_slots() { + // slots=[1,2]: both pos1 and pos2 (numeric) get aggregated → centroid. + let mut aes = Mappings::new(); + aes.insert("pos1", col("__ggsql_aes_pos1__")); + aes.insert("pos2", col("__ggsql_aes_pos2__")); + let schema = numeric_schema(&["__ggsql_aes_pos1__", "__ggsql_aes_pos2__"]); + let mut params = HashMap::new(); + params.insert( + "aggregate".to_string(), + ParameterValue::String("mean".to_string()), + ); + + let result = apply( + "SELECT * FROM t", + &schema, + &aes, + &[], + ¶ms, + &InlineQuantileDialect, + &[1, 2], + ) + .unwrap(); + match result { + StatResult::Transformed { + query, + consumed_aesthetics, + stat_columns, + .. + } => { + assert!(query.contains("AVG(\"__ggsql_aes_pos1__\")"), "query: {}", query); + assert!(query.contains("AVG(\"__ggsql_aes_pos2__\")"), "query: {}", query); + assert!(!query.contains("GROUP BY \"__ggsql_aes_pos1__\"")); + let mut consumed = consumed_aesthetics.clone(); + consumed.sort(); + assert_eq!(consumed, vec!["pos1".to_string(), "pos2".to_string()]); + assert!(stat_columns.contains(&"pos1".to_string())); + assert!(stat_columns.contains(&"pos2".to_string())); + } + _ => panic!("expected Transformed"), + } + } + + #[test] + fn errorbar_aggregates_pos2_minmax() { + // slots=[2]: pos1 fixed (group), pos2min and pos2max both aggregated. + let mut aes = Mappings::new(); + aes.insert("pos1", col("__ggsql_aes_pos1__")); + aes.insert("pos2min", col("__ggsql_aes_pos2min__")); + aes.insert("pos2max", col("__ggsql_aes_pos2max__")); + let schema = numeric_schema(&[ + "__ggsql_aes_pos1__", + "__ggsql_aes_pos2min__", + "__ggsql_aes_pos2max__", + ]); + let mut params = HashMap::new(); + params.insert( + "aggregate".to_string(), + ParameterValue::String("mean".to_string()), + ); + + let result = apply( + "SELECT * FROM t", + &schema, + &aes, + &[], + ¶ms, + &InlineQuantileDialect, + &[2], + ) + .unwrap(); + match result { + StatResult::Transformed { + query, + consumed_aesthetics, + .. + } => { + assert!(query.contains("AVG(\"__ggsql_aes_pos2min__\")"), "query: {}", query); + assert!(query.contains("AVG(\"__ggsql_aes_pos2max__\")"), "query: {}", query); + assert!(query.contains("GROUP BY \"__ggsql_aes_pos1__\"")); + let mut consumed = consumed_aesthetics.clone(); + consumed.sort(); + assert_eq!(consumed, vec!["pos2max".to_string(), "pos2min".to_string()]); + } + _ => panic!("expected Transformed"), + } + } + + #[test] + fn out_of_axis_numeric_pos_stays_as_group() { + // slots=[2], numeric pos1 → still goes to GROUP BY (not aggregated). + // Same expectation as line_style_groups_by_pos1_and_aggregates_pos2 but + // explicit about the "numeric out-of-axis" path. + let mut aes = Mappings::new(); + aes.insert("pos1", col("__ggsql_aes_pos1__")); + aes.insert("pos2", col("__ggsql_aes_pos2__")); + let schema = numeric_schema(&["__ggsql_aes_pos1__", "__ggsql_aes_pos2__"]); + let mut params = HashMap::new(); + params.insert( + "aggregate".to_string(), + ParameterValue::String("mean".to_string()), + ); + + let result = apply( + "SELECT * FROM t", + &schema, + &aes, + &[], + ¶ms, + &InlineQuantileDialect, + &[2], + ) + .unwrap(); + match result { + StatResult::Transformed { query, .. } => { + assert!(query.contains("GROUP BY \"__ggsql_aes_pos1__\"")); + } + _ => panic!("expected Transformed"), + } + } + + #[test] + fn discrete_in_axis_pos_stays_as_group_on_centroid_geom() { + // slots=[1,2], pos1 discrete + pos2 numeric → only pos2 aggregated, + // pos1 stays as GROUP BY. Confirms numeric check is preserved on + // slot=[1,2] geoms (e.g. point with category AS x, value AS y). + let mut aes = Mappings::new(); + aes.insert("pos1", col("__ggsql_aes_pos1__")); + aes.insert("pos2", col("__ggsql_aes_pos2__")); + let schema = vec![ + ColumnInfo { + name: "__ggsql_aes_pos1__".to_string(), + dtype: DataType::Utf8, + is_discrete: true, + min: None, + max: None, + }, + ColumnInfo { + name: "__ggsql_aes_pos2__".to_string(), + dtype: DataType::Float64, + is_discrete: false, + min: None, + max: None, + }, + ]; + let mut params = HashMap::new(); + params.insert( + "aggregate".to_string(), + ParameterValue::String("mean".to_string()), + ); + + let result = apply( + "SELECT * FROM t", + &schema, + &aes, + &[], + ¶ms, + &InlineQuantileDialect, + &[1, 2], + ) + .unwrap(); + match result { + StatResult::Transformed { + query, + consumed_aesthetics, + stat_columns, + .. + } => { + assert!(query.contains("AVG(\"__ggsql_aes_pos2__\")")); + assert!(!query.contains("AVG(\"__ggsql_aes_pos1__\")")); + assert!(query.contains("GROUP BY \"__ggsql_aes_pos1__\"")); + assert_eq!(consumed_aesthetics, vec!["pos2".to_string()]); + assert!(stat_columns.contains(&"pos2".to_string())); + assert!(!stat_columns.contains(&"pos1".to_string())); + } + _ => panic!("expected Transformed"), + } + } + + #[test] + fn count_works_with_no_numeric_pos() { + // slots=[2], only discrete pos1 mapped, aggregate=count → no + // "needs numeric" error; query has COUNT(*) and groups by pos1. + let mut aes = Mappings::new(); + aes.insert("pos1", col("__ggsql_aes_pos1__")); + let schema = vec![ColumnInfo { + name: "__ggsql_aes_pos1__".to_string(), + dtype: DataType::Utf8, + is_discrete: true, + min: None, + max: None, + }]; + let mut params = HashMap::new(); + params.insert( + "aggregate".to_string(), + ParameterValue::String("count".to_string()), + ); + + let result = apply( + "SELECT * FROM t", + &schema, + &aes, + &[], + ¶ms, + &InlineQuantileDialect, + &[2], + ) + .unwrap(); + match result { + StatResult::Transformed { + query, + stat_columns, + .. + } => { + assert!(query.contains("COUNT(*)")); + assert!(query.contains("GROUP BY \"__ggsql_aes_pos1__\"")); + assert!(stat_columns.contains(&"count".to_string())); + } + _ => panic!("expected Transformed"), + } + } } diff --git a/src/plot/layer/geom/text.rs b/src/plot/layer/geom/text.rs index 5909c34d..d9af79ac 100644 --- a/src/plot/layer/geom/text.rs +++ b/src/plot/layer/geom/text.rs @@ -67,6 +67,10 @@ impl GeomTrait for Text { true } + fn aggregate_slots(&self) -> &'static [u8] { + &[1, 2] + } + fn post_process( &self, df: DataFrame, diff --git a/src/plot/layer/geom/types.rs b/src/plot/layer/geom/types.rs index fb0ab5b8..4bdcf897 100644 --- a/src/plot/layer/geom/types.rs +++ b/src/plot/layer/geom/types.rs @@ -175,6 +175,44 @@ pub use crate::plot::types::ColumnInfo; /// Schema of a data source - list of columns with type info pub use crate::plot::types::Schema; +/// Wrap a stat result with `ORDER BY `. +/// +/// Used by line/area/ribbon to ensure the rendered output is sorted along the +/// domain axis whether or not the layer also goes through the Aggregate stat. +/// +/// - `Identity` → becomes `Transformed` with ` ORDER BY `, +/// empty `stat_columns`/`dummy_columns`/`consumed_aesthetics`. Same shape as +/// the previous inline `ORDER BY` path produced. +/// - `Transformed` → wraps the existing query in +/// `SELECT * FROM () AS "__ggsql_ord__" ORDER BY ` and preserves +/// the stat metadata. +pub fn wrap_with_order_by(input_query: &str, result: StatResult, aesthetic: &str) -> StatResult { + let order_col = naming::aesthetic_column(aesthetic); + let order_quoted = naming::quote_ident(&order_col); + match result { + StatResult::Identity => StatResult::Transformed { + query: format!("{} ORDER BY {}", input_query, order_quoted), + stat_columns: vec![], + dummy_columns: vec![], + consumed_aesthetics: vec![], + }, + StatResult::Transformed { + query, + stat_columns, + dummy_columns, + consumed_aesthetics, + } => StatResult::Transformed { + query: format!( + "SELECT * FROM ({}) AS \"__ggsql_ord__\" ORDER BY {}", + query, order_quoted + ), + stat_columns, + dummy_columns, + consumed_aesthetics, + }, + } +} + /// Helper to extract column name from aesthetic value pub fn get_column_name(aesthetics: &Mappings, aesthetic: &str) -> Option { use crate::AestheticValue; @@ -260,6 +298,56 @@ mod tests { assert!(!aes.is_required("yend")); } + #[test] + fn wrap_with_order_by_identity_appends_order() { + let result = wrap_with_order_by("SELECT * FROM t", StatResult::Identity, "pos1"); + match result { + StatResult::Transformed { + query, + stat_columns, + dummy_columns, + consumed_aesthetics, + } => { + assert_eq!( + query, + "SELECT * FROM t ORDER BY \"__ggsql_aes_pos1__\"" + ); + assert!(stat_columns.is_empty()); + assert!(dummy_columns.is_empty()); + assert!(consumed_aesthetics.is_empty()); + } + _ => panic!("expected Transformed"), + } + } + + #[test] + fn wrap_with_order_by_transformed_wraps_query_and_preserves_metadata() { + let inner = StatResult::Transformed { + query: "SELECT * FROM grouped".to_string(), + stat_columns: vec!["pos2".to_string(), "aggregate".to_string()], + dummy_columns: vec!["pos1".to_string()], + consumed_aesthetics: vec!["pos2".to_string()], + }; + let result = wrap_with_order_by("SELECT * FROM raw", inner, "pos1"); + match result { + StatResult::Transformed { + query, + stat_columns, + dummy_columns, + consumed_aesthetics, + } => { + assert_eq!( + query, + "SELECT * FROM (SELECT * FROM grouped) AS \"__ggsql_ord__\" ORDER BY \"__ggsql_aes_pos1__\"" + ); + assert_eq!(stat_columns, vec!["pos2".to_string(), "aggregate".to_string()]); + assert_eq!(dummy_columns, vec!["pos1".to_string()]); + assert_eq!(consumed_aesthetics, vec!["pos2".to_string()]); + } + _ => panic!("expected Transformed"), + } + } + #[test] fn test_color_alias_requires_stroke_or_fill() { // Geom with neither stroke nor fill: color alias should NOT be supported From 0a1b214f12a0a389028319ee441311c053b17ee0 Mon Sep 17 00:00:00 2001 From: Thomas Lin Pedersen Date: Mon, 27 Apr 2026 14:27:16 +0200 Subject: [PATCH 03/33] support range geoms --- src/plot/layer/geom/area.rs | 1 + src/plot/layer/geom/bar.rs | 1 + src/plot/layer/geom/errorbar.rs | 8 + src/plot/layer/geom/line.rs | 1 + src/plot/layer/geom/mod.rs | 19 ++ src/plot/layer/geom/ribbon.rs | 9 + src/plot/layer/geom/stat_aggregate.rs | 414 ++++++++++++++++++++++++++ src/plot/layer/mod.rs | 15 + 8 files changed, 468 insertions(+) diff --git a/src/plot/layer/geom/area.rs b/src/plot/layer/geom/area.rs index f6388032..31617ea6 100644 --- a/src/plot/layer/geom/area.rs +++ b/src/plot/layer/geom/area.rs @@ -82,6 +82,7 @@ impl GeomTrait for Area { parameters, dialect, self.aggregate_slots(), + self.aggregate_range_pair(), )? } else { StatResult::Identity diff --git a/src/plot/layer/geom/bar.rs b/src/plot/layer/geom/bar.rs index b3b82d72..aebf207e 100644 --- a/src/plot/layer/geom/bar.rs +++ b/src/plot/layer/geom/bar.rs @@ -107,6 +107,7 @@ impl GeomTrait for Bar { parameters, dialect, self.aggregate_slots(), + self.aggregate_range_pair(), ); } stat_bar_count(query, schema, aesthetics, group_by) diff --git a/src/plot/layer/geom/errorbar.rs b/src/plot/layer/geom/errorbar.rs index 2821d141..5d088224 100644 --- a/src/plot/layer/geom/errorbar.rs +++ b/src/plot/layer/geom/errorbar.rs @@ -21,6 +21,10 @@ impl GeomTrait for ErrorBar { ("pos1", DefaultAestheticValue::Required), ("pos2min", DefaultAestheticValue::Required), ("pos2max", DefaultAestheticValue::Required), + // pos2 is the input column for the Aggregate stat in range mode + // (`SETTING aggregate => (lower_func, upper_func)` consumes pos2 + // and produces pos2min/pos2max). Optional otherwise. + ("pos2", DefaultAestheticValue::Null), ("stroke", DefaultAestheticValue::String("black")), ("opacity", DefaultAestheticValue::Number(1.0)), ("linewidth", DefaultAestheticValue::Number(1.0)), @@ -48,6 +52,10 @@ impl GeomTrait for ErrorBar { fn supports_aggregate(&self) -> bool { true } + + fn aggregate_range_pair(&self) -> Option<(&'static str, &'static str)> { + Some(("pos2min", "pos2max")) + } } impl std::fmt::Display for ErrorBar { diff --git a/src/plot/layer/geom/line.rs b/src/plot/layer/geom/line.rs index 92f7927c..6493fd83 100644 --- a/src/plot/layer/geom/line.rs +++ b/src/plot/layer/geom/line.rs @@ -68,6 +68,7 @@ impl GeomTrait for Line { parameters, dialect, self.aggregate_slots(), + self.aggregate_range_pair(), )? } else { StatResult::Identity diff --git a/src/plot/layer/geom/mod.rs b/src/plot/layer/geom/mod.rs index 1fd22dbd..806a1405 100644 --- a/src/plot/layer/geom/mod.rs +++ b/src/plot/layer/geom/mod.rs @@ -216,6 +216,19 @@ pub trait GeomTrait: std::fmt::Debug + std::fmt::Display + Send + Sync { &[2] } + /// Range pair for range-style Aggregate output. + /// + /// When `Some((lower, upper))`, this geom is a "range geom" that takes exactly + /// two `aggregate` functions and assigns them to the two named aesthetics + /// (e.g. `("pos2min", "pos2max")` for ribbon/errorbar). The user maps `pos2` + /// as the input column; the stat consumes pos2 and produces the range pair. + /// One row per group; no `aggregate` tag column. + /// + /// `None` (default) means standard per-function-rows aggregation. + fn aggregate_range_pair(&self) -> Option<(&'static str, &'static str)> { + None + } + /// Apply statistical transformation to the layer query. /// /// The default implementation dispatches to the Aggregate stat when @@ -241,6 +254,7 @@ pub trait GeomTrait: std::fmt::Debug + std::fmt::Display + Send + Sync { parameters, dialect, self.aggregate_slots(), + self.aggregate_range_pair(), ); } Ok(StatResult::Identity) @@ -514,6 +528,11 @@ impl Geom { self.0.aggregate_slots() } + /// Range pair for range-style Aggregate output, if any. + pub fn aggregate_range_pair(&self) -> Option<(&'static str, &'static str)> { + self.0.aggregate_range_pair() + } + /// Validate aesthetic mappings pub fn validate_aesthetics(&self, mappings: &Mappings) -> std::result::Result<(), String> { self.0.validate_aesthetics(mappings) diff --git a/src/plot/layer/geom/ribbon.rs b/src/plot/layer/geom/ribbon.rs index bf1898b9..07f005d7 100644 --- a/src/plot/layer/geom/ribbon.rs +++ b/src/plot/layer/geom/ribbon.rs @@ -22,6 +22,10 @@ impl GeomTrait for Ribbon { ("pos1", DefaultAestheticValue::Required), ("pos2min", DefaultAestheticValue::Required), ("pos2max", DefaultAestheticValue::Required), + // pos2 is the input column for the Aggregate stat in range mode + // (`SETTING aggregate => (lower_func, upper_func)` consumes pos2 + // and produces pos2min/pos2max). Optional otherwise. + ("pos2", DefaultAestheticValue::Null), ("fill", DefaultAestheticValue::String("black")), ("stroke", DefaultAestheticValue::String("black")), ("opacity", DefaultAestheticValue::Number(0.8)), @@ -44,6 +48,10 @@ impl GeomTrait for Ribbon { true } + fn aggregate_range_pair(&self) -> Option<(&'static str, &'static str)> { + Some(("pos2min", "pos2max")) + } + fn needs_stat_transform(&self, _aesthetics: &Mappings) -> bool { true } @@ -67,6 +75,7 @@ impl GeomTrait for Ribbon { parameters, dialect, self.aggregate_slots(), + self.aggregate_range_pair(), )? } else { StatResult::Identity diff --git a/src/plot/layer/geom/stat_aggregate.rs b/src/plot/layer/geom/stat_aggregate.rs index b446cc0a..2d3ce1a0 100644 --- a/src/plot/layer/geom/stat_aggregate.rs +++ b/src/plot/layer/geom/stat_aggregate.rs @@ -90,12 +90,17 @@ pub fn apply( parameters: &HashMap, dialect: &dyn SqlDialect, agg_slots: &[u8], + range_pair: Option<(&'static str, &'static str)>, ) -> Result { let funcs = match extract_aggregate_param(parameters) { None => return Ok(StatResult::Identity), Some(funcs) => funcs, }; + if let Some((lo, hi)) = range_pair { + return apply_range_mode(query, schema, aesthetics, group_by, &funcs, dialect, lo, hi); + } + // Walk the layer's position aesthetics and route each by (slot, type): // in-axis slot && numeric → aggregated (numeric_pos) // in-axis slot && discrete → kept as group column (kept_pos_cols) @@ -278,6 +283,143 @@ fn func_literal(func: &str) -> String { format!("'{}'", func.replace('\'', "''")) } +// ============================================================================= +// Range-mode strategy: exactly two functions filling a (lower, upper) aesthetic +// pair on the same row. Used by ribbon/errorbar. +// ============================================================================= + +fn apply_range_mode( + query: &str, + schema: &Schema, + aesthetics: &Mappings, + group_by: &[String], + funcs: &[String], + dialect: &dyn SqlDialect, + lo: &'static str, + hi: &'static str, +) -> Result { + if funcs.len() != 2 { + return Err(GgsqlError::ValidationError(format!( + "aggregate on a range geom must be an array of exactly two functions (lower, upper), got {}", + funcs.len() + ))); + } + + // Range mode requires `pos2` mapped to a numeric input column. The user + // writes `MAPPING value AS y` and the stat consumes it to produce both + // bounds. + let input_col = match aesthetics.get("pos2").and_then(|v| v.column_name()) { + Some(c) => c.to_string(), + None => { + return Err(GgsqlError::ValidationError( + "aggregate on a range geom requires a `y` (pos2) mapping as the input column" + .to_string(), + )); + } + }; + let info = schema.iter().find(|c| c.name == input_col); + if info.map(|c| c.is_discrete).unwrap_or(false) { + return Err(GgsqlError::ValidationError( + "aggregate on a range geom requires a numeric `y` (pos2) input, not a discrete column" + .to_string(), + )); + } + let qcol = naming::quote_ident(&input_col); + + // Group columns: PARTITION BY + discrete mappings (already in group_by) + + // any discrete position aesthetics on the layer (e.g. pos1 if it's a string). + let mut group_cols: Vec = Vec::new(); + for g in group_by { + if !group_cols.contains(g) { + group_cols.push(g.clone()); + } + } + for (aesthetic, value) in &aesthetics.aesthetics { + if !is_position_aesthetic(aesthetic) || aesthetic == "pos2" { + continue; + } + let col = match value.column_name() { + Some(c) => c.to_string(), + None => continue, + }; + if !group_cols.contains(&col) { + group_cols.push(col); + } + } + + let src_alias = "\"__ggsql_stat_src__\""; + let group_by_clause = if group_cols.is_empty() { + String::new() + } else { + let qcols: Vec = group_cols.iter().map(|c| naming::quote_ident(c)).collect(); + format!(" GROUP BY {}", qcols.join(", ")) + }; + + // Build the two function expressions. Quantiles use the inline form when + // available; otherwise fall back to `sql_percentile` correlated to the + // outer alias used in the FROM (`__ggsql_qt__`, matching boxplot/etc.). + let lo_expr = build_range_function_sql(&funcs[0], &qcol, &input_col, dialect, &group_cols)?; + let hi_expr = build_range_function_sql(&funcs[1], &qcol, &input_col, dialect, &group_cols)?; + + let stat_lo = naming::stat_column(lo); + let stat_hi = naming::stat_column(hi); + + let group_select: Vec = group_cols + .iter() + .map(|c| naming::quote_ident(c)) + .collect(); + let mut select_parts = group_select.clone(); + select_parts.push(format!("{} AS {}", lo_expr, naming::quote_ident(&stat_lo))); + select_parts.push(format!("{} AS {}", hi_expr, naming::quote_ident(&stat_hi))); + + let transformed_query = format!( + "WITH {src} AS ({query}) SELECT {sel} FROM {src} AS \"__ggsql_qt__\"{gb}", + src = src_alias, + query = query, + sel = select_parts.join(", "), + gb = group_by_clause, + ); + + // consumed_aesthetics: pos2 carries the original-name capture for axis + // labels; lo/hi flag the auto-rename in execute/layer.rs (their stat-column + // names match the position aesthetics they fill). + Ok(StatResult::Transformed { + query: transformed_query, + stat_columns: vec![lo.to_string(), hi.to_string()], + dummy_columns: vec![], + consumed_aesthetics: vec!["pos2".to_string(), lo.to_string(), hi.to_string()], + }) +} + +/// Build the SQL fragment for one function in range mode. Quantiles get the +/// inline form when the dialect supports it; otherwise the fallback subquery. +fn build_range_function_sql( + func: &str, + qcol: &str, + raw_col: &str, + dialect: &dyn SqlDialect, + group_cols: &[String], +) -> Result { + if func == "count" { + return Err(GgsqlError::ValidationError( + "aggregate on a range geom does not support 'count' (it has no range semantics)" + .to_string(), + )); + } + if let Some(frac) = quantile_fraction(func) { + if let Some(inline) = dialect.sql_quantile_inline(raw_col, frac) { + return Ok(inline); + } + return Ok(dialect.sql_percentile(raw_col, frac, "\"__ggsql_stat_src__\"", group_cols)); + } + function_inline_sql(func, qcol, dialect).ok_or_else(|| { + GgsqlError::ValidationError(format!( + "aggregate on a range geom does not support function '{}' on this dialect", + func + )) + }) +} + // ============================================================================= // Single-pass strategy: GROUP BY produces a wide CTE, then CROSS JOIN explodes // rows per requested function. @@ -585,6 +727,7 @@ mod tests { ¶ms, &InlineQuantileDialect, &[2], + None, ) .unwrap(); assert_eq!(result, StatResult::Identity); @@ -604,6 +747,7 @@ mod tests { ¶ms, &InlineQuantileDialect, &[2], + None, ) .unwrap(); assert_eq!(result, StatResult::Identity); @@ -628,6 +772,7 @@ mod tests { ¶ms, &InlineQuantileDialect, &[2], + None, ) .unwrap(); @@ -668,6 +813,7 @@ mod tests { ¶ms, &InlineQuantileDialect, &[2], + None, ) .unwrap(); @@ -708,6 +854,7 @@ mod tests { ¶ms, &InlineQuantileDialect, &[2], + None, ) .unwrap(); match result { @@ -748,6 +895,7 @@ mod tests { ¶ms, &InlineQuantileDialect, &[2], + None, ) .unwrap(); match result { @@ -779,6 +927,7 @@ mod tests { ¶ms, &NoInlineQuantileDialect, &[2], + None, ) .unwrap(); match result { @@ -814,6 +963,7 @@ mod tests { ¶ms, &InlineQuantileDialect, &[2], + None, ) .unwrap(); match result { @@ -844,6 +994,7 @@ mod tests { ¶ms, &InlineQuantileDialect, &[2], + None, ) .unwrap(); match result { @@ -873,6 +1024,7 @@ mod tests { ¶ms, &InlineQuantileDialect, &[2], + None, ) .unwrap(); match result { @@ -902,6 +1054,7 @@ mod tests { ¶ms, &InlineQuantileDialect, &[2], + None, ) .unwrap(); match result { @@ -948,6 +1101,7 @@ mod tests { ¶ms, &InlineQuantileDialect, &[2], + None, ) .unwrap(); match result { @@ -990,6 +1144,7 @@ mod tests { ¶ms, &InlineQuantileDialect, &[2], + None, ) .unwrap(); match result { @@ -1021,6 +1176,7 @@ mod tests { ¶ms, &InlineQuantileDialect, &[2], + None, ) .unwrap(); match result { @@ -1062,6 +1218,7 @@ mod tests { ¶ms, &InlineQuantileDialect, &[1, 2], + None, ) .unwrap(); match result { @@ -1110,6 +1267,7 @@ mod tests { ¶ms, &InlineQuantileDialect, &[2], + None, ) .unwrap(); match result { @@ -1152,6 +1310,7 @@ mod tests { ¶ms, &InlineQuantileDialect, &[2], + None, ) .unwrap(); match result { @@ -1200,6 +1359,7 @@ mod tests { ¶ms, &InlineQuantileDialect, &[1, 2], + None, ) .unwrap(); match result { @@ -1247,6 +1407,7 @@ mod tests { ¶ms, &InlineQuantileDialect, &[2], + None, ) .unwrap(); match result { @@ -1262,4 +1423,257 @@ mod tests { _ => panic!("expected Transformed"), } } + + // ======================================================================== + // Range-mode tests (ribbon / errorbar) + // ======================================================================== + + fn range_pair() -> Option<(&'static str, &'static str)> { + Some(("pos2min", "pos2max")) + } + + fn range_input_aes_with_group() -> (Mappings, Schema) { + let mut aes = Mappings::new(); + aes.insert("pos1", col("__ggsql_aes_pos1__")); + aes.insert("pos2", col("__ggsql_aes_pos2__")); + let schema = vec![ + ColumnInfo { + name: "__ggsql_aes_pos1__".to_string(), + dtype: DataType::Utf8, + is_discrete: true, + min: None, + max: None, + }, + ColumnInfo { + name: "__ggsql_aes_pos2__".to_string(), + dtype: DataType::Float64, + is_discrete: false, + min: None, + max: None, + }, + ]; + (aes, schema) + } + + #[test] + fn range_mode_two_functions_emits_one_row_per_group() { + let (aes, schema) = range_input_aes_with_group(); + let mut params = HashMap::new(); + use crate::plot::types::ArrayElement; + params.insert( + "aggregate".to_string(), + ParameterValue::Array(vec![ + ArrayElement::String("mean-sdev".to_string()), + ArrayElement::String("mean+sdev".to_string()), + ]), + ); + + let result = apply( + "SELECT * FROM t", + &schema, + &aes, + &[], + ¶ms, + &InlineQuantileDialect, + &[2], + range_pair(), + ) + .unwrap(); + match result { + StatResult::Transformed { + query, + stat_columns, + consumed_aesthetics, + .. + } => { + assert!( + query.contains("AVG(\"__ggsql_aes_pos2__\") - STDDEV_POP(\"__ggsql_aes_pos2__\")"), + "lower bound expr missing: {}", + query + ); + assert!( + query.contains("AVG(\"__ggsql_aes_pos2__\") + STDDEV_POP(\"__ggsql_aes_pos2__\")"), + "upper bound expr missing: {}", + query + ); + assert!(query.contains("GROUP BY \"__ggsql_aes_pos1__\"")); + assert!(!query.contains("UNION ALL")); + assert!(!query.contains("CROSS JOIN")); + // No `aggregate` tag column in range mode. + assert!(!query.contains("__ggsql_stat_aggregate__")); + assert_eq!( + stat_columns, + vec!["pos2min".to_string(), "pos2max".to_string()] + ); + assert!(consumed_aesthetics.contains(&"pos2".to_string())); + assert!(consumed_aesthetics.contains(&"pos2min".to_string())); + assert!(consumed_aesthetics.contains(&"pos2max".to_string())); + } + _ => panic!("expected Transformed"), + } + } + + #[test] + fn range_mode_rejects_single_function() { + let (aes, schema) = range_input_aes_with_group(); + let mut params = HashMap::new(); + params.insert( + "aggregate".to_string(), + ParameterValue::String("mean".to_string()), + ); + + let result = apply( + "SELECT * FROM t", + &schema, + &aes, + &[], + ¶ms, + &InlineQuantileDialect, + &[2], + range_pair(), + ); + let err = result.unwrap_err().to_string(); + assert!( + err.contains("exactly two"), + "expected 'exactly two' in error, got: {}", + err + ); + } + + #[test] + fn range_mode_rejects_three_functions() { + let (aes, schema) = range_input_aes_with_group(); + let mut params = HashMap::new(); + use crate::plot::types::ArrayElement; + params.insert( + "aggregate".to_string(), + ParameterValue::Array(vec![ + ArrayElement::String("min".to_string()), + ArrayElement::String("mean".to_string()), + ArrayElement::String("max".to_string()), + ]), + ); + + let err = apply( + "SELECT * FROM t", + &schema, + &aes, + &[], + ¶ms, + &InlineQuantileDialect, + &[2], + range_pair(), + ) + .unwrap_err() + .to_string(); + assert!(err.contains("exactly two")); + } + + #[test] + fn range_mode_quantile_uses_inline_when_available() { + let (aes, schema) = range_input_aes_with_group(); + let mut params = HashMap::new(); + use crate::plot::types::ArrayElement; + params.insert( + "aggregate".to_string(), + ParameterValue::Array(vec![ + ArrayElement::String("q25".to_string()), + ArrayElement::String("q75".to_string()), + ]), + ); + + let result = apply( + "SELECT * FROM t", + &schema, + &aes, + &[], + ¶ms, + &InlineQuantileDialect, + &[2], + range_pair(), + ) + .unwrap(); + match result { + StatResult::Transformed { query, .. } => { + assert!(query.contains("QUANTILE_CONT")); + assert!(query.contains("0.25")); + assert!(query.contains("0.75")); + assert!(!query.contains("NTILE(4)")); + } + _ => panic!("expected Transformed"), + } + } + + #[test] + fn range_mode_quantile_falls_back_without_dialect_support() { + let (aes, schema) = range_input_aes_with_group(); + let mut params = HashMap::new(); + use crate::plot::types::ArrayElement; + params.insert( + "aggregate".to_string(), + ParameterValue::Array(vec![ + ArrayElement::String("q25".to_string()), + ArrayElement::String("q75".to_string()), + ]), + ); + + let result = apply( + "SELECT * FROM t", + &schema, + &aes, + &[], + ¶ms, + &NoInlineQuantileDialect, + &[2], + range_pair(), + ) + .unwrap(); + match result { + StatResult::Transformed { query, .. } => { + assert!(query.contains("NTILE(4)")); + } + _ => panic!("expected Transformed"), + } + } + + #[test] + fn range_mode_requires_pos2_input() { + // Range geom but pos2 not mapped → error. + let mut aes = Mappings::new(); + aes.insert("pos1", col("__ggsql_aes_pos1__")); + let schema = vec![ColumnInfo { + name: "__ggsql_aes_pos1__".to_string(), + dtype: DataType::Utf8, + is_discrete: true, + min: None, + max: None, + }]; + let mut params = HashMap::new(); + use crate::plot::types::ArrayElement; + params.insert( + "aggregate".to_string(), + ParameterValue::Array(vec![ + ArrayElement::String("mean-sdev".to_string()), + ArrayElement::String("mean+sdev".to_string()), + ]), + ); + + let err = apply( + "SELECT * FROM t", + &schema, + &aes, + &[], + ¶ms, + &InlineQuantileDialect, + &[2], + range_pair(), + ) + .unwrap_err() + .to_string(); + assert!( + err.contains("pos2") || err.contains("`y`"), + "expected pos2/y mention in error, got: {}", + err + ); + } } diff --git a/src/plot/layer/mod.rs b/src/plot/layer/mod.rs index e6656590..8416c411 100644 --- a/src/plot/layer/mod.rs +++ b/src/plot/layer/mod.rs @@ -200,10 +200,25 @@ impl Layer { }; // Check if all required aesthetics exist. + // When `aggregate` is set on a range geom, the (lower, upper) range pair + // is filled by the stat (e.g. pos2min/pos2max for ribbon) and shouldn't + // be required from the user. + let range_pair_skip: Option<(&'static str, &'static str)> = + if crate::plot::layer::geom::has_aggregate_param(&self.parameters) { + self.geom.aggregate_range_pair() + } else { + None + }; + let mut missing = Vec::new(); let mut position_reqs: Vec<(&str, u8, &str)> = Vec::new(); for aesthetic in self.geom.aesthetics().required() { + if let Some((lo, hi)) = range_pair_skip { + if aesthetic == lo || aesthetic == hi { + continue; + } + } if let Some((slot, suffix)) = parse_position(aesthetic) { position_reqs.push((aesthetic, slot, suffix)) } else if !self.mappings.contains_key(aesthetic) { From 218f302aea3b1691f2b3b17503ac3052fc89df19 Mon Sep 17 00:00:00 2001 From: Thomas Lin Pedersen Date: Mon, 27 Apr 2026 14:32:55 +0200 Subject: [PATCH 04/33] reformat --- src/plot/layer/geom/bar.rs | 4 +- src/plot/layer/geom/line.rs | 4 +- src/plot/layer/geom/stat_aggregate.rs | 93 ++++++++++++++++++--------- src/plot/layer/geom/types.rs | 10 +-- src/plot/layer/mod.rs | 3 +- 5 files changed, 72 insertions(+), 42 deletions(-) diff --git a/src/plot/layer/geom/bar.rs b/src/plot/layer/geom/bar.rs index aebf207e..e65a0256 100644 --- a/src/plot/layer/geom/bar.rs +++ b/src/plot/layer/geom/bar.rs @@ -6,8 +6,8 @@ use std::collections::HashSet; use super::stat_aggregate; use super::types::{get_column_name, POSITION_VALUES}; use super::{ - has_aggregate_param, DefaultAesthetics, DefaultParamValue, GeomTrait, GeomType, ParamConstraint, - ParamDefinition, StatResult, + has_aggregate_param, DefaultAesthetics, DefaultParamValue, GeomTrait, GeomType, + ParamConstraint, ParamDefinition, StatResult, }; use crate::naming; use crate::plot::types::{DefaultAestheticValue, ParameterValue}; diff --git a/src/plot/layer/geom/line.rs b/src/plot/layer/geom/line.rs index 6493fd83..e0600af0 100644 --- a/src/plot/layer/geom/line.rs +++ b/src/plot/layer/geom/line.rs @@ -3,8 +3,8 @@ use super::stat_aggregate; use super::types::wrap_with_order_by; use super::{ - has_aggregate_param, DefaultAesthetics, DefaultParamValue, GeomTrait, GeomType, ParamConstraint, - ParamDefinition, StatResult, + has_aggregate_param, DefaultAesthetics, DefaultParamValue, GeomTrait, GeomType, + ParamConstraint, ParamDefinition, StatResult, }; use crate::plot::layer::orientation::{ALIGNED, ORIENTATION_VALUES}; use crate::plot::types::DefaultAestheticValue; diff --git a/src/plot/layer/geom/stat_aggregate.rs b/src/plot/layer/geom/stat_aggregate.rs index 2d3ce1a0..bd8cb0dc 100644 --- a/src/plot/layer/geom/stat_aggregate.rs +++ b/src/plot/layer/geom/stat_aggregate.rs @@ -259,14 +259,8 @@ fn function_inline_sql(func: &str, qcol: &str, dialect: &dyn SqlDialect) -> Opti "mean+sdev" => format!("(AVG({c}) + STDDEV_POP({c}))", c = qcol), "mean-2sdev" => format!("(AVG({c}) - 2.0 * STDDEV_POP({c}))", c = qcol), "mean+2sdev" => format!("(AVG({c}) + 2.0 * STDDEV_POP({c}))", c = qcol), - "mean-se" => format!( - "(AVG({c}) - STDDEV_POP({c}) / SQRT(COUNT({c})))", - c = qcol - ), - "mean+se" => format!( - "(AVG({c}) + STDDEV_POP({c}) / SQRT(COUNT({c})))", - c = qcol - ), + "mean-se" => format!("(AVG({c}) - STDDEV_POP({c}) / SQRT(COUNT({c})))", c = qcol), + "mean+se" => format!("(AVG({c}) + STDDEV_POP({c}) / SQRT(COUNT({c})))", c = qcol), // `iqr` is computed from quantiles - handled separately. _ => return None, }) @@ -364,10 +358,7 @@ fn apply_range_mode( let stat_lo = naming::stat_column(lo); let stat_hi = naming::stat_column(hi); - let group_select: Vec = group_cols - .iter() - .map(|c| naming::quote_ident(c)) - .collect(); + let group_select: Vec = group_cols.iter().map(|c| naming::quote_ident(c)).collect(); let mut select_parts = group_select.clone(); select_parts.push(format!("{} AS {}", lo_expr, naming::quote_ident(&stat_lo))); select_parts.push(format!("{} AS {}", hi_expr, naming::quote_ident(&stat_hi))); @@ -444,10 +435,8 @@ fn build_single_pass_query( }; // Build the wide aggregation SELECT: one column per (function × position). - let mut wide_select_exprs: Vec = group_cols - .iter() - .map(|c| naming::quote_ident(c)) - .collect(); + let mut wide_select_exprs: Vec = + group_cols.iter().map(|c| naming::quote_ident(c)).collect(); // Track the synthetic column names for each (aesthetic, function) pair. let mut wide_col_for: HashMap<(String, String), String> = HashMap::new(); @@ -494,7 +483,10 @@ fn build_single_pass_query( let wide_select = wide_select_exprs.join(", "); // Build the CROSS JOIN VALUES table of function names. - let funcs_values: Vec = funcs.iter().map(|f| format!("({})", func_literal(f))).collect(); + let funcs_values: Vec = funcs + .iter() + .map(|f| format!("({})", func_literal(f))) + .collect(); let funcs_cte = format!( "{}(name) AS (VALUES {})", funcs_alias, @@ -529,7 +521,11 @@ fn build_single_pass_query( whens.join(" ") ) }; - outer_exprs.push(format!("{} AS {}", case_expr, naming::quote_ident(&stat_col))); + outer_exprs.push(format!( + "{} AS {}", + case_expr, + naming::quote_ident(&stat_col) + )); } if let Some(count_wide) = count_wide { @@ -541,7 +537,11 @@ fn build_single_pass_query( lit = func_literal("count"), c = naming::quote_ident(&count_wide) ); - outer_exprs.push(format!("{} AS {}", case_expr, naming::quote_ident(&stat_col))); + outer_exprs.push(format!( + "{} AS {}", + case_expr, + naming::quote_ident(&stat_col) + )); } let stat_aggregate_col = naming::stat_column("aggregate"); @@ -603,10 +603,7 @@ fn build_union_all_query( format!(" GROUP BY {}", qcols.join(", ")) }; - let group_select: Vec = group_cols - .iter() - .map(|c| naming::quote_ident(c)) - .collect(); + let group_select: Vec = group_cols.iter().map(|c| naming::quote_ident(c)).collect(); let needs_count_col = funcs.iter().any(|f| f == "count"); let stat_aggregate_col = naming::stat_column("aggregate"); @@ -631,7 +628,11 @@ fn build_union_all_query( let qcol = naming::quote_ident(col); function_inline_sql(func, &qcol, dialect).unwrap_or_else(|| "NULL".to_string()) }; - select_parts.push(format!("{} AS {}", value_expr, naming::quote_ident(&stat_col))); + select_parts.push(format!( + "{} AS {}", + value_expr, + naming::quote_ident(&stat_col) + )); } if needs_count_col { @@ -783,7 +784,11 @@ mod tests { consumed_aesthetics, .. } => { - assert!(query.contains("AVG(\"__ggsql_aes_pos2__\")"), "query: {}", query); + assert!( + query.contains("AVG(\"__ggsql_aes_pos2__\")"), + "query: {}", + query + ); assert!(query.contains("CROSS JOIN")); assert!(stat_columns.contains(&"pos2".to_string())); assert!(stat_columns.contains(&"aggregate".to_string())); @@ -1186,7 +1191,11 @@ mod tests { stat_columns, .. } => { - assert!(query.contains("MAX(\"__ggsql_aes_pos2__\")"), "query: {}", query); + assert!( + query.contains("MAX(\"__ggsql_aes_pos2__\")"), + "query: {}", + query + ); assert!(!query.contains("MAX(\"__ggsql_aes_pos1__\")")); assert!(query.contains("GROUP BY \"__ggsql_aes_pos1__\"")); assert_eq!(consumed_aesthetics, vec!["pos2".to_string()]); @@ -1228,8 +1237,16 @@ mod tests { stat_columns, .. } => { - assert!(query.contains("AVG(\"__ggsql_aes_pos1__\")"), "query: {}", query); - assert!(query.contains("AVG(\"__ggsql_aes_pos2__\")"), "query: {}", query); + assert!( + query.contains("AVG(\"__ggsql_aes_pos1__\")"), + "query: {}", + query + ); + assert!( + query.contains("AVG(\"__ggsql_aes_pos2__\")"), + "query: {}", + query + ); assert!(!query.contains("GROUP BY \"__ggsql_aes_pos1__\"")); let mut consumed = consumed_aesthetics.clone(); consumed.sort(); @@ -1276,8 +1293,16 @@ mod tests { consumed_aesthetics, .. } => { - assert!(query.contains("AVG(\"__ggsql_aes_pos2min__\")"), "query: {}", query); - assert!(query.contains("AVG(\"__ggsql_aes_pos2max__\")"), "query: {}", query); + assert!( + query.contains("AVG(\"__ggsql_aes_pos2min__\")"), + "query: {}", + query + ); + assert!( + query.contains("AVG(\"__ggsql_aes_pos2max__\")"), + "query: {}", + query + ); assert!(query.contains("GROUP BY \"__ggsql_aes_pos1__\"")); let mut consumed = consumed_aesthetics.clone(); consumed.sort(); @@ -1487,12 +1512,16 @@ mod tests { .. } => { assert!( - query.contains("AVG(\"__ggsql_aes_pos2__\") - STDDEV_POP(\"__ggsql_aes_pos2__\")"), + query.contains( + "AVG(\"__ggsql_aes_pos2__\") - STDDEV_POP(\"__ggsql_aes_pos2__\")" + ), "lower bound expr missing: {}", query ); assert!( - query.contains("AVG(\"__ggsql_aes_pos2__\") + STDDEV_POP(\"__ggsql_aes_pos2__\")"), + query.contains( + "AVG(\"__ggsql_aes_pos2__\") + STDDEV_POP(\"__ggsql_aes_pos2__\")" + ), "upper bound expr missing: {}", query ); diff --git a/src/plot/layer/geom/types.rs b/src/plot/layer/geom/types.rs index 4bdcf897..e9f86bcf 100644 --- a/src/plot/layer/geom/types.rs +++ b/src/plot/layer/geom/types.rs @@ -308,10 +308,7 @@ mod tests { dummy_columns, consumed_aesthetics, } => { - assert_eq!( - query, - "SELECT * FROM t ORDER BY \"__ggsql_aes_pos1__\"" - ); + assert_eq!(query, "SELECT * FROM t ORDER BY \"__ggsql_aes_pos1__\""); assert!(stat_columns.is_empty()); assert!(dummy_columns.is_empty()); assert!(consumed_aesthetics.is_empty()); @@ -340,7 +337,10 @@ mod tests { query, "SELECT * FROM (SELECT * FROM grouped) AS \"__ggsql_ord__\" ORDER BY \"__ggsql_aes_pos1__\"" ); - assert_eq!(stat_columns, vec!["pos2".to_string(), "aggregate".to_string()]); + assert_eq!( + stat_columns, + vec!["pos2".to_string(), "aggregate".to_string()] + ); assert_eq!(dummy_columns, vec!["pos1".to_string()]); assert_eq!(consumed_aesthetics, vec!["pos2".to_string()]); } diff --git a/src/plot/layer/mod.rs b/src/plot/layer/mod.rs index 8416c411..562a55fc 100644 --- a/src/plot/layer/mod.rs +++ b/src/plot/layer/mod.rs @@ -436,7 +436,8 @@ impl Layer { } // Or the shared `aggregate` param for Identity-stat geoms else if param_name == "aggregate" && self.geom.supports_aggregate() { - let definition = crate::plot::layer::geom::stat_aggregate::aggregate_param_definition(); + let definition = + crate::plot::layer::geom::stat_aggregate::aggregate_param_definition(); validate_parameter(param_name, value, &definition.constraint)?; } // Otherwise it's a valid aesthetic setting (no constraint validation needed) From 8c5845fa9a618c3e89a672ee85dfd3a2eb9978eb Mon Sep 17 00:00:00 2001 From: Thomas Lin Pedersen Date: Tue, 28 Apr 2026 11:41:04 +0200 Subject: [PATCH 05/33] support aggregation in segment --- src/plot/layer/geom/segment.rs | 5 +++++ src/plot/layer/mod.rs | 4 ++-- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/src/plot/layer/geom/segment.rs b/src/plot/layer/geom/segment.rs index e0815be7..499ae173 100644 --- a/src/plot/layer/geom/segment.rs +++ b/src/plot/layer/geom/segment.rs @@ -44,6 +44,11 @@ impl GeomTrait for Segment { } fn aggregate_slots(&self) -> &'static [u8] { + // Segment is two endpoints connected by a line. Aggregate runs + // independently on each of the four position aesthetics: pos1 and + // pos1end (slot 1), pos2 and pos2end (slot 2). With `aggregate => 'mean'`, + // the segment goes from `(mean(pos1), mean(pos2))` to + // `(mean(pos1end), mean(pos2end))`. &[1, 2] } } diff --git a/src/plot/layer/mod.rs b/src/plot/layer/mod.rs index 562a55fc..14572919 100644 --- a/src/plot/layer/mod.rs +++ b/src/plot/layer/mod.rs @@ -201,8 +201,8 @@ impl Layer { // Check if all required aesthetics exist. // When `aggregate` is set on a range geom, the (lower, upper) range pair - // is filled by the stat (e.g. pos2min/pos2max for ribbon) and shouldn't - // be required from the user. + // is filled by the stat (e.g. pos2min/pos2max for ribbon, pos2/pos2end + // for segment) and shouldn't be required from the user. let range_pair_skip: Option<(&'static str, &'static str)> = if crate::plot::layer::geom::has_aggregate_param(&self.parameters) { self.geom.aggregate_range_pair() From 2cb021646acba3741620640d9b9776590b71b96a Mon Sep 17 00:00:00 2001 From: Thomas Lin Pedersen Date: Tue, 28 Apr 2026 12:21:11 +0200 Subject: [PATCH 06/33] allow orientation in range and ribbon for aggregation case --- src/plot/layer/geom/range.rs | 8 ++++++++ src/plot/layer/geom/ribbon.rs | 20 +++++++++++++++----- 2 files changed, 23 insertions(+), 5 deletions(-) diff --git a/src/plot/layer/geom/range.rs b/src/plot/layer/geom/range.rs index 12789dd4..d368c4e7 100644 --- a/src/plot/layer/geom/range.rs +++ b/src/plot/layer/geom/range.rs @@ -4,6 +4,7 @@ use super::types::POSITION_VALUES; use super::{ DefaultAesthetics, DefaultParamValue, GeomTrait, GeomType, ParamConstraint, ParamDefinition, }; +use crate::plot::layer::orientation::ORIENTATION_VALUES; use crate::plot::types::DefaultAestheticValue; /// Range geom - intervals along the secondary axis @@ -45,6 +46,13 @@ impl GeomTrait for Range { default: DefaultParamValue::Number(10.0), constraint: ParamConstraint::number_min(0.0), }, + // Default Null → resolve_orientation auto-detects from mappings/scales. + // User can override with `SETTING orientation => 'transposed'`. + ParamDefinition { + name: "orientation", + default: DefaultParamValue::Null, + constraint: ParamConstraint::string_option(ORIENTATION_VALUES), + }, ]; PARAMS } diff --git a/src/plot/layer/geom/ribbon.rs b/src/plot/layer/geom/ribbon.rs index 07f005d7..47b58a97 100644 --- a/src/plot/layer/geom/ribbon.rs +++ b/src/plot/layer/geom/ribbon.rs @@ -3,6 +3,7 @@ use super::stat_aggregate; use super::types::{wrap_with_order_by, POSITION_VALUES}; use super::{has_aggregate_param, DefaultAesthetics, GeomTrait, GeomType, StatResult}; +use crate::plot::layer::orientation::ORIENTATION_VALUES; use crate::plot::types::DefaultAestheticValue; use crate::plot::{DefaultParamValue, ParamConstraint, ParamDefinition}; use crate::Mappings; @@ -36,11 +37,20 @@ impl GeomTrait for Ribbon { } fn default_params(&self) -> &'static [ParamDefinition] { - const PARAMS: &[ParamDefinition] = &[ParamDefinition { - name: "position", - default: DefaultParamValue::String("identity"), - constraint: ParamConstraint::string_option(POSITION_VALUES), - }]; + const PARAMS: &[ParamDefinition] = &[ + ParamDefinition { + name: "position", + default: DefaultParamValue::String("identity"), + constraint: ParamConstraint::string_option(POSITION_VALUES), + }, + // Default Null → resolve_orientation auto-detects from mappings/scales. + // User can override with `SETTING orientation => 'transposed'`. + ParamDefinition { + name: "orientation", + default: DefaultParamValue::Null, + constraint: ParamConstraint::string_option(ORIENTATION_VALUES), + }, + ]; PARAMS } From cc390bd6e31230799c1bfb7d5672b04e9bc736a9 Mon Sep 17 00:00:00 2001 From: Thomas Lin Pedersen Date: Tue, 28 Apr 2026 12:25:57 +0200 Subject: [PATCH 07/33] rename to percentile --- src/plot/layer/geom/stat_aggregate.rs | 70 +++++++++++++-------------- 1 file changed, 35 insertions(+), 35 deletions(-) diff --git a/src/plot/layer/geom/stat_aggregate.rs b/src/plot/layer/geom/stat_aggregate.rs index 67dc9d38..875f6e29 100644 --- a/src/plot/layer/geom/stat_aggregate.rs +++ b/src/plot/layer/geom/stat_aggregate.rs @@ -41,14 +41,14 @@ pub const AGG_NAMES: &[&str] = &[ "sdev", "var", "iqr", - // Quantiles - "q05", - "q10", - "q25", - "q50", - "q75", - "q90", - "q95", + // Percentiles + "p05", + "p10", + "p25", + "p50", + "p75", + "p90", + "p95", // Bands (mean ± spread) "mean-sdev", "mean+sdev", @@ -155,7 +155,7 @@ pub fn apply( // Decide strategy: single-pass when every quantile can be inlined. let needs_fallback = funcs.iter().any(|f| { - if let Some(frac) = quantile_fraction(f) { + if let Some(frac) = percentile_fraction(f) { // Use the first numeric column (any will do) for the probe, since we // only care whether the dialect produces Some or None. let probe = numeric_pos @@ -215,16 +215,16 @@ fn extract_aggregate_param(parameters: &HashMap) -> Opti } } -/// Map a quantile function name (`q05`..`q95`, `median`) to its fraction. -fn quantile_fraction(func: &str) -> Option { +/// Map a percentile function name (`p05`..`p95`, `median`) to its fraction. +fn percentile_fraction(func: &str) -> Option { match func { - "median" | "q50" => Some(0.50), - "q05" => Some(0.05), - "q10" => Some(0.10), - "q25" => Some(0.25), - "q75" => Some(0.75), - "q90" => Some(0.90), - "q95" => Some(0.95), + "median" | "p50" => Some(0.50), + "p05" => Some(0.05), + "p10" => Some(0.10), + "p25" => Some(0.25), + "p75" => Some(0.75), + "p90" => Some(0.90), + "p95" => Some(0.95), _ => None, } } @@ -237,7 +237,7 @@ fn function_inline_sql(func: &str, qcol: &str, dialect: &dyn SqlDialect) -> Opti if func == "count" { return None; } - if let Some(frac) = quantile_fraction(func) { + if let Some(frac) = percentile_fraction(func) { // Strip the quotes added by `naming::quote_ident` so we can re-quote inside // `sql_quantile_inline` via the same helper. The dialect impl quotes itself. let unquoted = unquote(qcol); @@ -397,7 +397,7 @@ fn build_range_function_sql( .to_string(), )); } - if let Some(frac) = quantile_fraction(func) { + if let Some(frac) = percentile_fraction(func) { if let Some(inline) = dialect.sql_quantile_inline(raw_col, frac) { return Ok(inline); } @@ -454,14 +454,14 @@ fn build_single_pass_query( let wide_name = synthetic_col_name(aes, func); let expr = match func.as_str() { "iqr" => { - // q75 - q25 inline if dialect supports it - let q75 = dialect + // p75 - p25 inline if dialect supports it + let p75 = dialect .sql_quantile_inline(col, 0.75) .expect("sql_quantile_inline must be Some when single-pass is selected"); - let q25 = dialect + let p25 = dialect .sql_quantile_inline(col, 0.25) .expect("sql_quantile_inline must be Some when single-pass is selected"); - format!("({} - {})", q75, q25) + format!("({} - {})", p75, p25) } _ => function_inline_sql(func, &qcol, dialect) .expect("function_inline_sql must be Some when single-pass is selected"), @@ -619,10 +619,10 @@ fn build_union_all_query( let value_expr = if func == "count" { "NULL".to_string() } else if func == "iqr" { - let q75 = dialect.sql_percentile(col, 0.75, src_alias, group_cols); - let q25 = dialect.sql_percentile(col, 0.25, src_alias, group_cols); - format!("({} - {})", q75, q25) - } else if let Some(frac) = quantile_fraction(func) { + let p75 = dialect.sql_percentile(col, 0.75, src_alias, group_cols); + let p25 = dialect.sql_percentile(col, 0.25, src_alias, group_cols); + format!("({} - {})", p75, p25) + } else if let Some(frac) = percentile_fraction(func) { dialect.sql_percentile(col, frac, src_alias, group_cols) } else { let qcol = naming::quote_ident(col); @@ -889,7 +889,7 @@ mod tests { let mut params = HashMap::new(); params.insert( "aggregate".to_string(), - ParameterValue::String("q25".to_string()), + ParameterValue::String("p25".to_string()), ); let result = apply( @@ -921,7 +921,7 @@ mod tests { let mut params = HashMap::new(); params.insert( "aggregate".to_string(), - ParameterValue::String("q25".to_string()), + ParameterValue::String("p25".to_string()), ); let result = apply( @@ -1041,7 +1041,7 @@ mod tests { } #[test] - fn iqr_emits_q75_minus_q25() { + fn iqr_emits_p75_minus_p25() { let mut aes = Mappings::new(); aes.insert("pos2", col("__ggsql_aes_pos2__")); let schema = numeric_schema(&["__ggsql_aes_pos2__"]); @@ -1606,8 +1606,8 @@ mod tests { params.insert( "aggregate".to_string(), ParameterValue::Array(vec![ - ArrayElement::String("q25".to_string()), - ArrayElement::String("q75".to_string()), + ArrayElement::String("p25".to_string()), + ArrayElement::String("p75".to_string()), ]), ); @@ -1641,8 +1641,8 @@ mod tests { params.insert( "aggregate".to_string(), ParameterValue::Array(vec![ - ArrayElement::String("q25".to_string()), - ArrayElement::String("q75".to_string()), + ArrayElement::String("p25".to_string()), + ArrayElement::String("p75".to_string()), ]), ); From 447600565ce9c6dfffeb5cdda7221bbea3b9ea4e Mon Sep 17 00:00:00 2001 From: Thomas Lin Pedersen Date: Tue, 28 Apr 2026 13:06:50 +0200 Subject: [PATCH 08/33] make aggregates parametric --- src/plot/layer/geom/stat_aggregate.rs | 749 +++++++++++++++++++++++--- src/plot/layer/mod.rs | 4 +- 2 files changed, 666 insertions(+), 87 deletions(-) diff --git a/src/plot/layer/geom/stat_aggregate.rs b/src/plot/layer/geom/stat_aggregate.rs index 875f6e29..1400bc33 100644 --- a/src/plot/layer/geom/stat_aggregate.rs +++ b/src/plot/layer/geom/stat_aggregate.rs @@ -15,13 +15,15 @@ use std::collections::HashMap; use super::types::StatResult; use crate::naming; use crate::plot::aesthetic::{is_position_aesthetic, parse_position}; -use crate::plot::types::{ - DefaultParamValue, ParamConstraint, ParamDefinition, ParameterValue, Schema, -}; +use crate::plot::types::{ParameterValue, Schema}; use crate::reader::SqlDialect; use crate::{GgsqlError, Mappings, Result}; -/// All aggregation function names accepted by the `aggregate` SETTING. +/// All simple-aggregation function names accepted by the `aggregate` SETTING. +/// +/// Band names (e.g. `mean+sdev`, `median-0.5iqr`) are validated separately by +/// `parse_agg_name`, which checks the offset against `OFFSET_STATS` and the +/// expansion against `EXPANSION_STATS`. pub const AGG_NAMES: &[&str] = &[ // Tallies & sums "count", @@ -49,25 +51,212 @@ pub const AGG_NAMES: &[&str] = &[ "p75", "p90", "p95", - // Bands (mean ± spread) - "mean-sdev", - "mean+sdev", - "mean-2sdev", - "mean+2sdev", - "mean-se", - "mean+se", ]; -/// Returns the `ParamDefinition` for the `aggregate` SETTING parameter. +/// Stats that can appear as the *offset* (left of `±`) in a band name like +/// `mean+sdev`. Single-value central or representative quantities only — +/// counts/spreads are excluded. +pub const OFFSET_STATS: &[&str] = &[ + "mean", + "median", + "geomean", + "harmean", + "rms", + "sum", + "prod", + "min", + "max", + "p05", + "p10", + "p25", + "p50", + "p75", + "p90", + "p95", +]; + +/// Stats that can appear as the *expansion* (right of `±[mod]`) in a band name. +/// Spread / dispersion measures only. +pub const EXPANSION_STATS: &[&str] = &["sdev", "se", "var", "iqr", "range"]; + +/// Parsed representation of any aggregate-function name. /// -/// Used by `Layer::validate_settings` to check the value against `AGG_NAMES`, -/// and by geoms that support aggregation. -pub fn aggregate_param_definition() -> ParamDefinition { - ParamDefinition { - name: "aggregate", - default: DefaultParamValue::Null, - constraint: ParamConstraint::string_or_string_array(AGG_NAMES), +/// Simple aggregates (`mean`, `count`, `p25`) have `band == None`. Band names +/// (`mean+sdev`, `median-0.5iqr`) have `band == Some(...)` with the offset +/// stored in `offset` and the spread/multiplier in `band`. +#[derive(Debug, Clone, PartialEq)] +pub struct AggSpec { + pub offset: &'static str, + pub band: Option, +} + +#[derive(Debug, Clone, PartialEq)] +pub struct Band { + pub sign: char, + pub mod_value: f64, + pub expansion: &'static str, +} + +/// Resolve a name to its canonical `&'static str` from the given vocabulary, +/// or `None` if the input doesn't match any entry. +fn resolve_static(name: &str, vocab: &'static [&'static str]) -> Option<&'static str> { + vocab.iter().copied().find(|v| *v == name) +} + +/// Parse an aggregate-function name into an `AggSpec`. Returns `None` on +/// invalid input (unknown stat, malformed band, or band with vocabulary +/// violation). +pub fn parse_agg_name(name: &str) -> Option { + if let Some(spec) = parse_band(name) { + return Some(spec); + } + resolve_static(name, AGG_NAMES).map(|offset| AggSpec { offset, band: None }) +} + +/// Try to parse `name` as a band: `?`. Returns +/// `None` if it doesn't match the band shape OR if either half is outside its +/// allowed vocabulary. +fn parse_band(name: &str) -> Option { + // Walk offsets longest-first so `median` matches before `mean`. + let mut offsets: Vec<&'static str> = OFFSET_STATS.to_vec(); + offsets.sort_by_key(|s| std::cmp::Reverse(s.len())); + + for offset in offsets { + let rest = match name.strip_prefix(offset) { + Some(r) => r, + None => continue, // doesn't start with this offset + }; + let (sign, after_sign) = match rest.chars().next() { + Some('+') => ('+', &rest[1..]), + Some('-') => ('-', &rest[1..]), + _ => continue, // wrong sign char — try next offset + }; + + let (mod_value, expansion_str) = parse_mod_and_remainder(after_sign); + let expansion = match resolve_static(expansion_str, EXPANSION_STATS) { + Some(e) => e, + None => continue, // expansion doesn't match — try next offset + }; + + return Some(AggSpec { + offset, + band: Some(Band { + sign, + mod_value, + expansion, + }), + }); } + None +} + +/// Parse a leading `(.)?` modifier from `s`. Returns +/// `(parsed_value, rest_of_string)`. If no leading digits, returns +/// `(1.0, s)` — modifier defaults to 1. +fn parse_mod_and_remainder(s: &str) -> (f64, &str) { + let mut idx = 0; + let bytes = s.as_bytes(); + while idx < bytes.len() && bytes[idx].is_ascii_digit() { + idx += 1; + } + if idx < bytes.len() && bytes[idx] == b'.' { + let mut after_dot = idx + 1; + while after_dot < bytes.len() && bytes[after_dot].is_ascii_digit() { + after_dot += 1; + } + if after_dot > idx + 1 { + // need at least one digit after '.' + idx = after_dot; + } + } + if idx == 0 { + return (1.0, s); + } + let num_str = &s[..idx]; + let value: f64 = num_str.parse().unwrap_or(1.0); + (value, &s[idx..]) +} + +/// Validate the `aggregate` SETTING value: null, a single function name, or +/// an array of function names. Each name must be parseable by `parse_agg_name`. +pub fn validate_aggregate_param(value: &ParameterValue) -> std::result::Result<(), String> { + use crate::plot::types::ArrayElement; + match value { + ParameterValue::Null => Ok(()), + ParameterValue::String(s) => validate_function_name(s), + ParameterValue::Array(arr) => { + for el in arr { + match el { + ArrayElement::String(s) => validate_function_name(s)?, + ArrayElement::Null => continue, + _ => { + return Err( + "'aggregate' array entries must be strings or null".to_string() + ); + } + } + } + Ok(()) + } + _ => Err("'aggregate' must be a string, array of strings, or null".to_string()), + } +} + +fn validate_function_name(name: &str) -> std::result::Result<(), String> { + match parse_agg_name(name) { + Some(_) => Ok(()), + None => Err(diagnose_invalid_function_name(name)), + } +} + +/// Build a per-role error message for a name that didn't parse. Re-walks the +/// input with looser rules to identify which side (offset / expansion) failed. +fn diagnose_invalid_function_name(name: &str) -> String { + // Look for a sign character. If there is one, examine the offset and + // expansion halves separately. + if let Some(sign_idx) = name.find(|c| c == '+' || c == '-') { + let offset_str = &name[..sign_idx]; + let after_sign = &name[sign_idx + 1..]; + let (_mod_value, expansion_str) = parse_mod_and_remainder(after_sign); + + let offset_known_simple = AGG_NAMES.contains(&offset_str); + let offset_known_band = OFFSET_STATS.contains(&offset_str); + let expansion_known_band = EXPANSION_STATS.contains(&expansion_str); + + if !offset_known_band { + // The offset half is the problem. + if offset_known_simple { + return format!( + "'{}': '{}' is not a valid offset stat. Allowed offsets: {}", + name, + offset_str, + crate::or_list_quoted(OFFSET_STATS, '\''), + ); + } + return format!( + "'{}': '{}' is not a known stat. Allowed offsets: {}", + name, + offset_str, + crate::or_list_quoted(OFFSET_STATS, '\''), + ); + } + if !expansion_known_band { + return format!( + "'{}': '{}' is not a valid expansion stat. Allowed expansions: {}", + name, + expansion_str, + crate::or_list_quoted(EXPANSION_STATS, '\''), + ); + } + // Both halves are individually valid but band parsing failed for some + // other reason (e.g. malformed modifier). + return format!("'{}' is not a valid aggregate function name", name); + } + format!( + "unknown aggregate function '{}'. Allowed: {} (or use a band like `mean+sdev`)", + name, + crate::or_list_quoted(AGG_NAMES, '\''), + ) } /// Apply the Aggregate stat to a layer query. @@ -153,19 +342,15 @@ pub fn apply( let needs_count_col = funcs.iter().any(|f| f == "count"); - // Decide strategy: single-pass when every quantile can be inlined. + // Decide strategy: single-pass when every percentile component can be inlined. + let probe = numeric_pos + .first() + .map(|(_, c)| c.as_str()) + .unwrap_or("__ggsql_probe__"); let needs_fallback = funcs.iter().any(|f| { - if let Some(frac) = percentile_fraction(f) { - // Use the first numeric column (any will do) for the probe, since we - // only care whether the dialect produces Some or None. - let probe = numeric_pos - .first() - .map(|(_, c)| c.as_str()) - .unwrap_or("__ggsql_probe__"); - dialect.sql_quantile_inline(probe, frac).is_none() - } else { - false - } + parse_agg_name(f) + .map(|spec| needs_quantile_fallback(&spec, probe, dialect)) + .unwrap_or(false) }); let transformed_query = if needs_fallback { @@ -229,21 +414,27 @@ fn percentile_fraction(func: &str) -> Option { } } -/// Build the inline SQL fragment for a function applied to a quoted column. +/// Build the inline SQL fragment for a *simple* stat (no band) applied to a +/// quoted column. /// -/// Returns None for `count` (which doesn't take a column) and for quantiles when -/// the dialect lacks an inline form (caller should switch to UNION ALL strategy). -fn function_inline_sql(func: &str, qcol: &str, dialect: &dyn SqlDialect) -> Option { - if func == "count" { +/// Returns `None` for `count` (which doesn't take a column) and for percentile- +/// based stats (`p05..p95`, `median`, `iqr`) when the dialect lacks an inline +/// quantile aggregate (caller should switch to UNION ALL strategy). +fn simple_stat_sql_inline(name: &str, qcol: &str, dialect: &dyn SqlDialect) -> Option { + if name == "count" { return None; } - if let Some(frac) = percentile_fraction(func) { - // Strip the quotes added by `naming::quote_ident` so we can re-quote inside - // `sql_quantile_inline` via the same helper. The dialect impl quotes itself. + if let Some(frac) = percentile_fraction(name) { let unquoted = unquote(qcol); return dialect.sql_quantile_inline(&unquoted, frac); } - Some(match func { + if name == "iqr" { + let unquoted = unquote(qcol); + let p75 = dialect.sql_quantile_inline(&unquoted, 0.75)?; + let p25 = dialect.sql_quantile_inline(&unquoted, 0.25)?; + return Some(format!("({} - {})", p75, p25)); + } + Some(match name { "sum" => format!("SUM({})", qcol), "prod" => format!("EXP(SUM(LN({})))", qcol), "min" => format!("MIN({})", qcol), @@ -254,18 +445,103 @@ fn function_inline_sql(func: &str, qcol: &str, dialect: &dyn SqlDialect) -> Opti "harmean" => format!("(COUNT({c}) * 1.0 / SUM(1.0 / {c}))", c = qcol), "rms" => format!("SQRT(AVG({c} * {c}))", c = qcol), "sdev" => format!("STDDEV_POP({})", qcol), + "se" => format!("(STDDEV_POP({c}) / SQRT(COUNT({c})))", c = qcol), "var" => format!("VAR_POP({})", qcol), - "mean-sdev" => format!("(AVG({c}) - STDDEV_POP({c}))", c = qcol), - "mean+sdev" => format!("(AVG({c}) + STDDEV_POP({c}))", c = qcol), - "mean-2sdev" => format!("(AVG({c}) - 2.0 * STDDEV_POP({c}))", c = qcol), - "mean+2sdev" => format!("(AVG({c}) + 2.0 * STDDEV_POP({c}))", c = qcol), - "mean-se" => format!("(AVG({c}) - STDDEV_POP({c}) / SQRT(COUNT({c})))", c = qcol), - "mean+se" => format!("(AVG({c}) + STDDEV_POP({c}) / SQRT(COUNT({c})))", c = qcol), - // `iqr` is computed from quantiles - handled separately. _ => return None, }) } +/// Inline SQL for a parsed `AggSpec`. Combines the offset and (optional) +/// expansion halves with the appropriate sign and modifier. +fn agg_sql_inline(spec: &AggSpec, qcol: &str, dialect: &dyn SqlDialect) -> Option { + let offset_sql = simple_stat_sql_inline(spec.offset, qcol, dialect)?; + match &spec.band { + None => Some(offset_sql), + Some(band) => { + let exp_sql = simple_stat_sql_inline(band.expansion, qcol, dialect)?; + Some(format_band(&offset_sql, band.sign, band.mod_value, &exp_sql)) + } + } +} + +/// Build the SQL fragment `(offset ± mod * exp)`, omitting the `mod *` prefix +/// when `mod_value == 1.0`. +fn format_band(offset: &str, sign: char, mod_value: f64, exp: &str) -> String { + if mod_value == 1.0 { + format!("({} {} {})", offset, sign, exp) + } else { + format!("({} {} {} * {})", offset, sign, mod_value, exp) + } +} + +/// Fallback SQL for a simple stat. Used by the UNION-ALL path for percentile +/// components (which need correlated `sql_percentile`) and falls through to +/// the inline form for everything else. +fn simple_stat_sql_fallback( + name: &str, + raw_col: &str, + dialect: &dyn SqlDialect, + src_alias: &str, + group_cols: &[String], +) -> String { + if name == "count" { + return "NULL".to_string(); + } + if let Some(frac) = percentile_fraction(name) { + return dialect.sql_percentile(raw_col, frac, src_alias, group_cols); + } + if name == "iqr" { + let p75 = dialect.sql_percentile(raw_col, 0.75, src_alias, group_cols); + let p25 = dialect.sql_percentile(raw_col, 0.25, src_alias, group_cols); + return format!("({} - {})", p75, p25); + } + let qcol = naming::quote_ident(raw_col); + simple_stat_sql_inline(name, &qcol, dialect).unwrap_or_else(|| "NULL".to_string()) +} + +/// Fallback SQL for a parsed `AggSpec` (UNION-ALL path). +fn agg_sql_fallback( + spec: &AggSpec, + raw_col: &str, + dialect: &dyn SqlDialect, + src_alias: &str, + group_cols: &[String], +) -> String { + let offset_sql = simple_stat_sql_fallback(spec.offset, raw_col, dialect, src_alias, group_cols); + match &spec.band { + None => offset_sql, + Some(band) => { + let exp_sql = + simple_stat_sql_fallback(band.expansion, raw_col, dialect, src_alias, group_cols); + format_band(&offset_sql, band.sign, band.mod_value, &exp_sql) + } + } +} + +/// Whether this spec has any percentile component that the dialect can't +/// inline (in which case the caller must use the UNION-ALL fallback). +fn needs_quantile_fallback(spec: &AggSpec, probe_col: &str, dialect: &dyn SqlDialect) -> bool { + if simple_needs_fallback(spec.offset, probe_col, dialect) { + return true; + } + if let Some(band) = &spec.band { + if simple_needs_fallback(band.expansion, probe_col, dialect) { + return true; + } + } + false +} + +fn simple_needs_fallback(name: &str, probe_col: &str, dialect: &dyn SqlDialect) -> bool { + if let Some(frac) = percentile_fraction(name) { + return dialect.sql_quantile_inline(probe_col, frac).is_none(); + } + if name == "iqr" { + return dialect.sql_quantile_inline(probe_col, 0.5).is_none(); + } + false +} + /// Strip surrounding double quotes from an identifier, undoing `naming::quote_ident`. fn unquote(qcol: &str) -> String { let trimmed = qcol.trim_start_matches('"').trim_end_matches('"'); @@ -349,9 +625,9 @@ fn apply_range_mode( format!(" GROUP BY {}", qcols.join(", ")) }; - // Build the two function expressions. Quantiles use the inline form when - // available; otherwise fall back to `sql_percentile` correlated to the - // outer alias used in the FROM (`__ggsql_qt__`, matching boxplot/etc.). + // Parse and emit each bound. Use the inline form when the dialect supports + // every percentile component; otherwise fall back to `sql_percentile` + // correlated to the outer alias used in the FROM (`__ggsql_qt__`). let lo_expr = build_range_function_sql(&funcs[0], &qcol, &input_col, dialect, &group_cols)?; let hi_expr = build_range_function_sql(&funcs[1], &qcol, &input_col, dialect, &group_cols)?; @@ -382,8 +658,10 @@ fn apply_range_mode( }) } -/// Build the SQL fragment for one function in range mode. Quantiles get the -/// inline form when the dialect supports it; otherwise the fallback subquery. +/// Build the SQL fragment for one function in range mode. Parses the function +/// name into an `AggSpec` (which validates the offset/expansion vocabulary) +/// and emits inline SQL when the dialect supports every percentile component, +/// otherwise the correlated fallback. fn build_range_function_sql( func: &str, qcol: &str, @@ -397,18 +675,28 @@ fn build_range_function_sql( .to_string(), )); } - if let Some(frac) = percentile_fraction(func) { - if let Some(inline) = dialect.sql_quantile_inline(raw_col, frac) { - return Ok(inline); - } - return Ok(dialect.sql_percentile(raw_col, frac, "\"__ggsql_stat_src__\"", group_cols)); - } - function_inline_sql(func, qcol, dialect).ok_or_else(|| { + let spec = parse_agg_name(func).ok_or_else(|| { GgsqlError::ValidationError(format!( - "aggregate on a range geom does not support function '{}' on this dialect", - func + "aggregate on a range geom: {}", + diagnose_invalid_function_name(func) )) - }) + })?; + if needs_quantile_fallback(&spec, raw_col, dialect) { + Ok(agg_sql_fallback( + &spec, + raw_col, + dialect, + "\"__ggsql_stat_src__\"", + group_cols, + )) + } else { + agg_sql_inline(&spec, qcol, dialect).ok_or_else(|| { + GgsqlError::ValidationError(format!( + "aggregate on a range geom does not support function '{}' on this dialect", + func + )) + }) + } } // ============================================================================= @@ -452,20 +740,10 @@ fn build_single_pass_query( continue; } let wide_name = synthetic_col_name(aes, func); - let expr = match func.as_str() { - "iqr" => { - // p75 - p25 inline if dialect supports it - let p75 = dialect - .sql_quantile_inline(col, 0.75) - .expect("sql_quantile_inline must be Some when single-pass is selected"); - let p25 = dialect - .sql_quantile_inline(col, 0.25) - .expect("sql_quantile_inline must be Some when single-pass is selected"); - format!("({} - {})", p75, p25) - } - _ => function_inline_sql(func, &qcol, dialect) - .expect("function_inline_sql must be Some when single-pass is selected"), - }; + let spec = parse_agg_name(func) + .expect("aggregate function names are validated upstream of single-pass"); + let expr = agg_sql_inline(&spec, &qcol, dialect) + .expect("agg_sql_inline must be Some when single-pass is selected"); wide_select_exprs.push(format!("{} AS {}", expr, naming::quote_ident(&wide_name))); wide_col_for.insert(key, wide_name); } @@ -614,19 +892,18 @@ fn build_union_all_query( .map(|func| { let mut select_parts: Vec = group_select.clone(); + // Parse the function name once per branch. Falls through to a + // string-NULL value column if parsing fails (shouldn't happen + // because validation runs upstream, but stay defensive). + let parsed_spec = parse_agg_name(func); for (aes, col) in numeric_pos { let stat_col = naming::stat_column(aes); let value_expr = if func == "count" { "NULL".to_string() - } else if func == "iqr" { - let p75 = dialect.sql_percentile(col, 0.75, src_alias, group_cols); - let p25 = dialect.sql_percentile(col, 0.25, src_alias, group_cols); - format!("({} - {})", p75, p25) - } else if let Some(frac) = percentile_fraction(func) { - dialect.sql_percentile(col, frac, src_alias, group_cols) + } else if let Some(spec) = &parsed_spec { + agg_sql_fallback(spec, col, dialect, src_alias, group_cols) } else { - let qcol = naming::quote_ident(col); - function_inline_sql(func, &qcol, dialect).unwrap_or_else(|| "NULL".to_string()) + "NULL".to_string() }; select_parts.push(format!( "{} AS {}", @@ -1705,4 +1982,308 @@ mod tests { err ); } + + // ======================================================================== + // Parser tests (parse_agg_name) + // ======================================================================== + + #[test] + fn parse_simple_names() { + assert_eq!( + parse_agg_name("mean"), + Some(AggSpec { offset: "mean", band: None }) + ); + assert_eq!( + parse_agg_name("count"), + Some(AggSpec { offset: "count", band: None }) + ); + assert_eq!( + parse_agg_name("p25"), + Some(AggSpec { offset: "p25", band: None }) + ); + } + + #[test] + fn parse_band_default_modifier() { + let spec = parse_agg_name("mean+sdev").unwrap(); + assert_eq!(spec.offset, "mean"); + let band = spec.band.unwrap(); + assert_eq!(band.sign, '+'); + assert_eq!(band.mod_value, 1.0); + assert_eq!(band.expansion, "sdev"); + } + + #[test] + fn parse_band_integer_modifier() { + let spec = parse_agg_name("mean-2sdev").unwrap(); + let band = spec.band.unwrap(); + assert_eq!(band.sign, '-'); + assert_eq!(band.mod_value, 2.0); + assert_eq!(band.expansion, "sdev"); + } + + #[test] + fn parse_band_decimal_modifier() { + let spec = parse_agg_name("mean+1.96sdev").unwrap(); + let band = spec.band.unwrap(); + assert_eq!(band.mod_value, 1.96); + } + + #[test] + fn parse_band_longest_offset_wins() { + // 'median+sdev' must match offset 'median', not 'me' (which isn't an + // offset anyway, but more pertinently the parser must not stop at a + // shorter prefix). + let spec = parse_agg_name("median+sdev").unwrap(); + assert_eq!(spec.offset, "median"); + } + + #[test] + fn parse_band_percentile_offset() { + let spec = parse_agg_name("p25+0.5range").unwrap(); + assert_eq!(spec.offset, "p25"); + let band = spec.band.unwrap(); + assert_eq!(band.mod_value, 0.5); + assert_eq!(band.expansion, "range"); + } + + #[test] + fn parse_band_rejects_invalid_offset() { + assert!(parse_agg_name("count+sdev").is_none()); + assert!(parse_agg_name("iqr+sdev").is_none()); + } + + #[test] + fn parse_band_rejects_invalid_expansion() { + assert!(parse_agg_name("mean+count").is_none()); + assert!(parse_agg_name("mean+median").is_none()); + } + + #[test] + fn parse_rejects_unknown() { + assert!(parse_agg_name("foo").is_none()); + assert!(parse_agg_name("").is_none()); + } + + // ======================================================================== + // Validation tests (validate_aggregate_param) + // ======================================================================== + + #[test] + fn validate_accepts_simple_names_and_bands() { + use crate::plot::types::ArrayElement; + validate_aggregate_param(&ParameterValue::String("mean".to_string())).unwrap(); + validate_aggregate_param(&ParameterValue::String("mean+sdev".to_string())).unwrap(); + validate_aggregate_param(&ParameterValue::String("median-0.5iqr".to_string())).unwrap(); + validate_aggregate_param(&ParameterValue::Array(vec![ + ArrayElement::String("mean".to_string()), + ArrayElement::String("mean+1.96sdev".to_string()), + ])) + .unwrap(); + } + + #[test] + fn validate_diagnostic_for_invalid_offset() { + let err = validate_aggregate_param(&ParameterValue::String("count+sdev".to_string())) + .unwrap_err(); + assert!(err.contains("count"), "err: {}", err); + assert!(err.contains("offset"), "err: {}", err); + } + + #[test] + fn validate_diagnostic_for_invalid_expansion() { + let err = validate_aggregate_param(&ParameterValue::String("mean+count".to_string())) + .unwrap_err(); + assert!(err.contains("count"), "err: {}", err); + assert!(err.contains("expansion"), "err: {}", err); + } + + #[test] + fn validate_diagnostic_for_unknown() { + let err = + validate_aggregate_param(&ParameterValue::String("foo".to_string())).unwrap_err(); + assert!(err.contains("unknown"), "err: {}", err); + assert!(err.contains("foo"), "err: {}", err); + } + + // ======================================================================== + // SQL emission for parametric bands + // ======================================================================== + + #[test] + fn band_decimal_modifier_emits_in_sql() { + let mut aes = Mappings::new(); + aes.insert("pos2", col("__ggsql_aes_pos2__")); + let schema = numeric_schema(&["__ggsql_aes_pos2__"]); + let mut params = HashMap::new(); + params.insert( + "aggregate".to_string(), + ParameterValue::String("mean+1.96sdev".to_string()), + ); + + let result = apply( + "SELECT * FROM t", + &schema, + &aes, + &[], + ¶ms, + &InlineQuantileDialect, + &[2], + None, + ) + .unwrap(); + match result { + StatResult::Transformed { query, .. } => { + assert!( + query.contains("AVG(\"__ggsql_aes_pos2__\") + 1.96 * STDDEV_POP(\"__ggsql_aes_pos2__\")"), + "query: {}", + query + ); + } + _ => panic!("expected Transformed"), + } + } + + #[test] + fn band_with_percentile_offset_inline() { + // median-0.5iqr on a dialect with inline quantile support. + let mut aes = Mappings::new(); + aes.insert("pos2", col("__ggsql_aes_pos2__")); + let schema = numeric_schema(&["__ggsql_aes_pos2__"]); + let mut params = HashMap::new(); + params.insert( + "aggregate".to_string(), + ParameterValue::String("median-0.5iqr".to_string()), + ); + + let result = apply( + "SELECT * FROM t", + &schema, + &aes, + &[], + ¶ms, + &InlineQuantileDialect, + &[2], + None, + ) + .unwrap(); + match result { + StatResult::Transformed { query, .. } => { + // median uses QUANTILE_CONT(col, 0.5); iqr uses QUANTILE_CONT(.., 0.75) and 0.25. + assert!( + query.contains("QUANTILE_CONT") && query.contains("0.5"), + "query: {}", + query + ); + assert!(query.contains("0.75") && query.contains("0.25")); + } + _ => panic!("expected Transformed"), + } + } + + #[test] + fn band_with_percentile_offset_falls_back() { + // median+2sdev on a dialect WITHOUT inline quantile support → UNION-ALL + // path with sql_percentile for median, inline STDDEV_POP for sdev. + let mut aes = Mappings::new(); + aes.insert("pos2", col("__ggsql_aes_pos2__")); + let schema = numeric_schema(&["__ggsql_aes_pos2__"]); + let mut params = HashMap::new(); + params.insert( + "aggregate".to_string(), + ParameterValue::String("median+2sdev".to_string()), + ); + + let result = apply( + "SELECT * FROM t", + &schema, + &aes, + &[], + ¶ms, + &NoInlineQuantileDialect, + &[2], + None, + ) + .unwrap(); + match result { + StatResult::Transformed { query, .. } => { + assert!(query.contains("NTILE(4)")); + assert!(query.contains("STDDEV_POP")); + assert!(query.contains("2 * ")); + } + _ => panic!("expected Transformed"), + } + } + + #[test] + fn band_with_default_modifier_omits_one_prefix() { + let mut aes = Mappings::new(); + aes.insert("pos2", col("__ggsql_aes_pos2__")); + let schema = numeric_schema(&["__ggsql_aes_pos2__"]); + let mut params = HashMap::new(); + params.insert( + "aggregate".to_string(), + ParameterValue::String("mean+sdev".to_string()), + ); + + let result = apply( + "SELECT * FROM t", + &schema, + &aes, + &[], + ¶ms, + &InlineQuantileDialect, + &[2], + None, + ) + .unwrap(); + match result { + StatResult::Transformed { query, .. } => { + // mod=1 case: (offset + exp), no `1 *` prefix. + assert!( + query.contains( + "AVG(\"__ggsql_aes_pos2__\") + STDDEV_POP(\"__ggsql_aes_pos2__\")" + ), + "expected `(AVG + STDDEV_POP)` form, got: {}", + query + ); + assert!(!query.contains("1 * STDDEV_POP")); + } + _ => panic!("expected Transformed"), + } + } + + #[test] + fn range_mode_supports_decimal_band() { + // Ribbon range mode + 95% CI band. + let (aes, schema) = range_input_aes_with_group(); + let mut params = HashMap::new(); + use crate::plot::types::ArrayElement; + params.insert( + "aggregate".to_string(), + ParameterValue::Array(vec![ + ArrayElement::String("mean-1.96sdev".to_string()), + ArrayElement::String("mean+1.96sdev".to_string()), + ]), + ); + + let result = apply( + "SELECT * FROM t", + &schema, + &aes, + &[], + ¶ms, + &InlineQuantileDialect, + &[2], + range_pair(), + ) + .unwrap(); + match result { + StatResult::Transformed { query, .. } => { + assert!(query.contains("- 1.96 * STDDEV_POP")); + assert!(query.contains("+ 1.96 * STDDEV_POP")); + } + _ => panic!("expected Transformed"), + } + } } diff --git a/src/plot/layer/mod.rs b/src/plot/layer/mod.rs index 14572919..91961156 100644 --- a/src/plot/layer/mod.rs +++ b/src/plot/layer/mod.rs @@ -436,9 +436,7 @@ impl Layer { } // Or the shared `aggregate` param for Identity-stat geoms else if param_name == "aggregate" && self.geom.supports_aggregate() { - let definition = - crate::plot::layer::geom::stat_aggregate::aggregate_param_definition(); - validate_parameter(param_name, value, &definition.constraint)?; + crate::plot::layer::geom::stat_aggregate::validate_aggregate_param(value)?; } // Otherwise it's a valid aesthetic setting (no constraint validation needed) } From 3f1a4335059e9d6f70ee3d945d3ef6aa90d66022 Mon Sep 17 00:00:00 2001 From: Thomas Lin Pedersen Date: Tue, 28 Apr 2026 13:07:03 +0200 Subject: [PATCH 09/33] reformat --- src/plot/layer/geom/stat_aggregate.rs | 81 ++++++++++----------------- 1 file changed, 30 insertions(+), 51 deletions(-) diff --git a/src/plot/layer/geom/stat_aggregate.rs b/src/plot/layer/geom/stat_aggregate.rs index 1400bc33..98246b8f 100644 --- a/src/plot/layer/geom/stat_aggregate.rs +++ b/src/plot/layer/geom/stat_aggregate.rs @@ -26,53 +26,19 @@ use crate::{GgsqlError, Mappings, Result}; /// expansion against `EXPANSION_STATS`. pub const AGG_NAMES: &[&str] = &[ // Tallies & sums - "count", - "sum", - "prod", - // Extremes - "min", - "max", - "range", - // Central tendency - "mean", - "geomean", - "harmean", - "rms", - "median", - // Spread (standalone) - "sdev", - "var", - "iqr", - // Percentiles - "p05", - "p10", - "p25", - "p50", - "p75", - "p90", - "p95", + "count", "sum", "prod", // Extremes + "min", "max", "range", // Central tendency + "mean", "geomean", "harmean", "rms", "median", // Spread (standalone) + "sdev", "var", "iqr", // Percentiles + "p05", "p10", "p25", "p50", "p75", "p90", "p95", ]; /// Stats that can appear as the *offset* (left of `±`) in a band name like /// `mean+sdev`. Single-value central or representative quantities only — /// counts/spreads are excluded. pub const OFFSET_STATS: &[&str] = &[ - "mean", - "median", - "geomean", - "harmean", - "rms", - "sum", - "prod", - "min", - "max", - "p05", - "p10", - "p25", - "p50", - "p75", - "p90", - "p95", + "mean", "median", "geomean", "harmean", "rms", "sum", "prod", "min", "max", "p05", "p10", + "p25", "p50", "p75", "p90", "p95", ]; /// Stats that can appear as the *expansion* (right of `±[mod]`) in a band name. @@ -190,9 +156,7 @@ pub fn validate_aggregate_param(value: &ParameterValue) -> std::result::Result<( ArrayElement::String(s) => validate_function_name(s)?, ArrayElement::Null => continue, _ => { - return Err( - "'aggregate' array entries must be strings or null".to_string() - ); + return Err("'aggregate' array entries must be strings or null".to_string()); } } } @@ -459,7 +423,12 @@ fn agg_sql_inline(spec: &AggSpec, qcol: &str, dialect: &dyn SqlDialect) -> Optio None => Some(offset_sql), Some(band) => { let exp_sql = simple_stat_sql_inline(band.expansion, qcol, dialect)?; - Some(format_band(&offset_sql, band.sign, band.mod_value, &exp_sql)) + Some(format_band( + &offset_sql, + band.sign, + band.mod_value, + &exp_sql, + )) } } } @@ -1991,15 +1960,24 @@ mod tests { fn parse_simple_names() { assert_eq!( parse_agg_name("mean"), - Some(AggSpec { offset: "mean", band: None }) + Some(AggSpec { + offset: "mean", + band: None + }) ); assert_eq!( parse_agg_name("count"), - Some(AggSpec { offset: "count", band: None }) + Some(AggSpec { + offset: "count", + band: None + }) ); assert_eq!( parse_agg_name("p25"), - Some(AggSpec { offset: "p25", band: None }) + Some(AggSpec { + offset: "p25", + band: None + }) ); } @@ -2100,8 +2078,7 @@ mod tests { #[test] fn validate_diagnostic_for_unknown() { - let err = - validate_aggregate_param(&ParameterValue::String("foo".to_string())).unwrap_err(); + let err = validate_aggregate_param(&ParameterValue::String("foo".to_string())).unwrap_err(); assert!(err.contains("unknown"), "err: {}", err); assert!(err.contains("foo"), "err: {}", err); } @@ -2135,7 +2112,9 @@ mod tests { match result { StatResult::Transformed { query, .. } => { assert!( - query.contains("AVG(\"__ggsql_aes_pos2__\") + 1.96 * STDDEV_POP(\"__ggsql_aes_pos2__\")"), + query.contains( + "AVG(\"__ggsql_aes_pos2__\") + 1.96 * STDDEV_POP(\"__ggsql_aes_pos2__\")" + ), "query: {}", query ); From 6147ccc42cf4556bbd47a2b2a60bd2b3153b80e5 Mon Sep 17 00:00:00 2001 From: Thomas Lin Pedersen Date: Tue, 28 Apr 2026 13:16:25 +0200 Subject: [PATCH 10/33] clippy be happy --- src/plot/layer/geom/stat_aggregate.rs | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/plot/layer/geom/stat_aggregate.rs b/src/plot/layer/geom/stat_aggregate.rs index 98246b8f..61ee15cf 100644 --- a/src/plot/layer/geom/stat_aggregate.rs +++ b/src/plot/layer/geom/stat_aggregate.rs @@ -178,7 +178,7 @@ fn validate_function_name(name: &str) -> std::result::Result<(), String> { fn diagnose_invalid_function_name(name: &str) -> String { // Look for a sign character. If there is one, examine the offset and // expansion halves separately. - if let Some(sign_idx) = name.find(|c| c == '+' || c == '-') { + if let Some(sign_idx) = name.find(['+', '-']) { let offset_str = &name[..sign_idx]; let after_sign = &name[sign_idx + 1..]; let (_mod_value, expansion_str) = parse_mod_and_remainder(after_sign); @@ -235,6 +235,7 @@ fn diagnose_invalid_function_name(name: &str) -> String { /// - **UNION ALL fallback**: when a quantile is requested but the dialect doesn't /// provide `sql_quantile_inline`, fall back to per-function subqueries using /// `dialect.sql_percentile`. +#[allow(clippy::too_many_arguments)] pub fn apply( query: &str, schema: &Schema, @@ -527,6 +528,7 @@ fn func_literal(func: &str) -> String { // pair on the same row. Used by ribbon/range. // ============================================================================= +#[allow(clippy::too_many_arguments)] fn apply_range_mode( query: &str, schema: &Schema, From 1c613e48228cd429e260bf1902da0108b1ff331e Mon Sep 17 00:00:00 2001 From: Thomas Lin Pedersen Date: Tue, 28 Apr 2026 13:55:08 +0200 Subject: [PATCH 11/33] ensure multiple aggregates give rise to multiple groups --- src/execute/layer.rs | 45 ++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 43 insertions(+), 2 deletions(-) diff --git a/src/execute/layer.rs b/src/execute/layer.rs index 1c9b1b0d..1f0c10d0 100644 --- a/src/execute/layer.rs +++ b/src/execute/layer.rs @@ -187,11 +187,16 @@ pub fn apply_remappings_post_query(df: DataFrame, layer: &Layer) -> Result = df .get_column_names() .into_iter() - .filter(|name| naming::is_stat_column(name)) + .filter(|name| { + naming::is_stat_column(name) && !layer.partition_by.contains(&name.to_string()) + }) .collect(); if !stat_cols.is_empty() { df = df.drop_many(&stat_cols)?; @@ -200,6 +205,18 @@ pub fn apply_remappings_post_query(df: DataFrame, layer: &Layer) -> Result) -> usize { + match parameters.get("aggregate") { + Some(ParameterValue::String(_)) => 1, + Some(ParameterValue::Array(arr)) => arr.len(), + _ => 0, + } +} + /// Convert a literal value to an Arrow ArrayRef with constant values. /// /// For string literals, attempts to parse as temporal types (date/datetime/time) @@ -634,6 +651,30 @@ where } } + // The `aggregate` stat column (produced by stat_aggregate when the + // user requests multiple functions) tags each row with its function + // name. For mark types that connect rows within a group (line, area, + // path, polygon), we need to add this column to `layer.partition_by` + // so that e.g. `aggregate => ('min', 'max')` renders as two separate + // lines rather than one zigzag through both. Resolves to the + // post-rename data-column name: if the user remapped `aggregate AS + // `, the prefixed aesthetic column; otherwise the stat column. + // + // Only fires when more than one function is requested — a single + // function produces a constant aggregate column, partitioning by + // which would just add a no-op detail channel. + if stat_columns.iter().any(|s| s == "aggregate") + && aggregate_param_function_count(&layer.parameters) > 1 + { + let partition_col = match final_remappings.get("aggregate") { + Some(aes) => naming::aesthetic_column(aes), + None => naming::stat_column("aggregate"), + }; + if !layer.partition_by.contains(&partition_col) { + layer.partition_by.push(partition_col); + } + } + // Wrap transformed query to rename stat columns to prefixed aesthetic names let stat_rename_exprs: Vec = stat_columns .iter() From f3081a3e4dbdc6ecaceeea700c0053b869218d45 Mon Sep 17 00:00:00 2001 From: Thomas Lin Pedersen Date: Tue, 28 Apr 2026 15:37:42 +0200 Subject: [PATCH 12/33] begin to document --- doc/syntax/clause/draw.qmd | 19 +++++++++++++++++++ doc/syntax/layer/type/area.qmd | 5 ++++- doc/syntax/layer/type/bar.qmd | 13 +++++++++++++ doc/syntax/layer/type/line.qmd | 18 ++++++++++++++++-- 4 files changed, 52 insertions(+), 3 deletions(-) diff --git a/doc/syntax/clause/draw.qmd b/doc/syntax/clause/draw.qmd index 31bf8f3a..f8b8b30e 100644 --- a/doc/syntax/clause/draw.qmd +++ b/doc/syntax/clause/draw.qmd @@ -76,6 +76,25 @@ The `SETTING` clause can be used for two different things: #### Position A special setting is `position` which controls how overlapping objects are repositioned to avoid overlapping etc. Position adjustments have special mapping requirements so all position adjustments will not be relevant for all layer types. Different layers have different defaults as detailed in their documentation. You can read about each different position adjustment at [their own documentation sites](../index.qmd#position-adjustments). +#### Aggregate +Some layers support aggregation of its data through the `aggregate` setting. These layers will state this. `aggregate` allows a single string or an array of strings that specify the aggregation to calculate. The aggregates can be either a simple function or a parameterized band function. + +The simple functions can be one of: + +* `'count'`: Row count +* `'sum'` and `'prod'`: The sum or product +* `'min'`, `'max'`, and `'range'`: Extremes and max - min +* `'mean'`, and `'median'`: Central tendency +* `'geomean'`, `'harmean'`, and `'rms'`: Geometric, harmonic, and root-mean-square +* `'sdev'`, `'var'`, `'iqr'`, and `'se'`: Standard deviation, variance, interquartile range, and standard error +* `'p05'`, `'p10'`, `'p25'`, `'p50'`, `'p75'`, `'p90'`, and `'p95'`: Percentiles + +For band functions you combine an offset with an expansion, potentially multiplied. An example could be `'mean-1.96sdev'` which does exactly what you'd expect it to be. The general form is `±` with `` being optional (defaults to `1`). + +Allowed offsets are: `'mean'`, `'median'`, `'geomean'`, `'harmean'`, `'rms'`, `'sum'`, `'prod'`, `'min'`, `'max'`, and `'p05'`–`'p95'` + +Allowed expansions are: `'sdev'`, `'se'`, `'var'`, `'iqr'`, and `'range'` + ### `FILTER` ```ggsql FILTER diff --git a/doc/syntax/layer/type/area.qmd b/doc/syntax/layer/type/area.qmd index a72b059f..1fa0cdc3 100644 --- a/doc/syntax/layer/type/area.qmd +++ b/doc/syntax/layer/type/area.qmd @@ -25,9 +25,12 @@ The following aesthetics are recognised by the area layer. * `orientation`: The orientation of the layer, see the [Orientation section](#orientation). One of the following: * `'aligned'` to align the layer's primary axis with the coordinate system's first axis. * `'transposed'` to align the layer's primary axis with the coordinate system's second axis. +* `aggregate`: Aggregation functions to apply per group. Either a single string or an array of strings. See an overview of aggregation function in [the `DRAW` documentation](../../clause/draw.qmd#aggregate) and more information in the *Data transformation* section below. ## Data transformation -The area layer sorts the data along its primary axis +This layer supports aggregation through the `aggregate` setting. Within each group, defined by `PARTITION BY`, all discrete mappings, and the primary axis, aggregates will be calculated and used as the values to plot. Multiple aggregates will give rise to multiple separate groups in the end. These can be distinguished through the added `aggregate` column you can remap to, e.g. `REMAPPING aggregate AS color` + +Further, the area layer sorts the data along its primary axis before returning it. ## Orientation Area plots are sorted and connected along their primary axis. Since the primary axis cannot be deduced from the mapping it must be specified using the `orientation` setting. E.g. if you wish to create a vertical area plot you need to set `orientation => 'transposed'` to indicate that the primary layer axis follows the second axis of the coordinate system. diff --git a/doc/syntax/layer/type/bar.qmd b/doc/syntax/layer/type/bar.qmd index d34a4953..f8efc63b 100644 --- a/doc/syntax/layer/type/bar.qmd +++ b/doc/syntax/layer/type/bar.qmd @@ -25,10 +25,13 @@ The bar layer has no required aesthetics ## Settings * `position`: Position adjustment. One of `'identity'`, `'stack'` (default), `'dodge'`, or `'jitter'` * `width`: The width of the bars as a proportion of the available width (0 to 1) +* `aggregate`: Aggregation functions to apply per group if the secondary position has been mapped. Either a single string or an array of strings. See an overview of aggregation function in [the `DRAW` documentation](../../clause/draw.qmd#aggregate) and more information in the *Data transformation* section below. ## Data transformation If the secondary axis has not been mapped the layer will calculate counts for you and display these as the secondary axis. +If the secondary axis has been mapped you can apply aggregation through the `aggregate` setting. Within each group, defined by `PARTITION BY`, all discrete mappings, and the primary axis, aggregates will be calculated and used as the values to plot. Multiple aggregates will give rise to multiple separate groups in the end. These can be distinguished through the added `aggregate` column you can remap to, e.g. `REMAPPING aggregate AS color` + ### Properties * `weight`: If mapped, the sum of the weights within each group is calculated instead of the count in each group @@ -116,3 +119,13 @@ DRAW bar MAPPING species AS fill PROJECT TO polar ``` + +Use a different type of aggregation for the bars through the `aggregate` setting: + +```{ggsql} +VISUALISE species AS y, body_mass AS y FROM ggsql:penguins +DRAW bar + SETTING aggregate => 'mean', fill => 'steelblue' +DRAW range + setting aggregate => ('mean-1.96sdev', 'mean+1.96sdev') +``` diff --git a/doc/syntax/layer/type/line.qmd b/doc/syntax/layer/type/line.qmd index 3ec9ec21..88bc9034 100644 --- a/doc/syntax/layer/type/line.qmd +++ b/doc/syntax/layer/type/line.qmd @@ -24,11 +24,16 @@ The following aesthetics are recognised by the line layer. * `orientation`: The orientation of the layer, see the [Orientation section](#orientation). One of the following: * `'aligned'` to align the layer's primary axis with the coordinate system's first axis. * `'transposed'` to align the layer's primary axis with the coordinate system's second axis. +* `aggregate`: Aggregation functions to apply per group. Either a single string or an array of strings. See an overview of aggregation function in [the `DRAW` documentation](../../clause/draw.qmd#aggregate) and more information in the *Data transformation* section below. ## Data transformation -The line layer sorts the data along its primary axis. +This layer supports aggregation through the `aggregate` setting. Within each group, defined by `PARTITION BY`, all discrete mappings, and the primary axis, aggregates will be calculated and used as the values to plot. Multiple aggregates will give rise to multiple separate groups in the end. These can be distinguished through the added `aggregate` column you can remap to, e.g. `REMAPPING aggregate AS color` + +Further, the line layer sorts the data along its primary axis before returning it. + If the line has a variable `stroke` or `opacity` aesthetic within groups, the line is broken into segments. Each segment gets the property of the preceding datapoint, so the last datapoint in a group does not transfer these properties. +This behavior is not compatible with aggregation. ## Orientation Line plots are sorted and connected along their primary axis. Since the primary axis cannot be deduced from the mapping it must be specified using the `orientation` setting. If you wish to create a vertical line plot, you need to set `orientation => 'transposed'` to indicate that the primary layer axis follows the second axis of the coordinate system. @@ -89,4 +94,13 @@ VISUALISE x, y FROM data DRAW line MAPPING z AS linewidth SCALE linewidth TO (0, 30) -``` \ No newline at end of file +``` + +Use aggregation to draw min and max lines from a set of observations + +```{ggsql} +VISUALISE Day AS x, Temp AS y FROM ggsql:airquality +DRAW line + SETTING aggregate => ('min', 'max') +DRAW point +``` From 56780b0df397ebf2a6ff00ca8674d2c32ed3fcda Mon Sep 17 00:00:00 2001 From: Thomas Lin Pedersen Date: Tue, 28 Apr 2026 16:57:49 +0200 Subject: [PATCH 13/33] polygon and path doesn't allow aggregation --- src/plot/layer/geom/path.rs | 8 -------- src/plot/layer/geom/polygon.rs | 8 -------- 2 files changed, 16 deletions(-) diff --git a/src/plot/layer/geom/path.rs b/src/plot/layer/geom/path.rs index c2c8af9f..5e32a3be 100644 --- a/src/plot/layer/geom/path.rs +++ b/src/plot/layer/geom/path.rs @@ -36,14 +36,6 @@ impl GeomTrait for Path { }]; PARAMS } - - fn supports_aggregate(&self) -> bool { - true - } - - fn aggregate_slots(&self) -> &'static [u8] { - &[1, 2] - } } impl std::fmt::Display for Path { diff --git a/src/plot/layer/geom/polygon.rs b/src/plot/layer/geom/polygon.rs index efda483e..d1ed6841 100644 --- a/src/plot/layer/geom/polygon.rs +++ b/src/plot/layer/geom/polygon.rs @@ -37,14 +37,6 @@ impl GeomTrait for Polygon { }]; PARAMS } - - fn supports_aggregate(&self) -> bool { - true - } - - fn aggregate_slots(&self) -> &'static [u8] { - &[1, 2] - } } impl std::fmt::Display for Polygon { From 802f1f1d34bb1e1d7967d36bc8e7705dfcee8001 Mon Sep 17 00:00:00 2001 From: Thomas Lin Pedersen Date: Tue, 28 Apr 2026 20:04:36 +0200 Subject: [PATCH 14/33] Add documentation for non-range layers --- doc/syntax/layer/type/bar.qmd | 2 +- doc/syntax/layer/type/point.qmd | 3 ++- doc/syntax/layer/type/rule.qmd | 14 +++++++++++++- doc/syntax/layer/type/segment.qmd | 3 ++- doc/syntax/layer/type/text.qmd | 14 +++++++++++++- src/execute/mod.rs | 7 ++++++- 6 files changed, 37 insertions(+), 6 deletions(-) diff --git a/doc/syntax/layer/type/bar.qmd b/doc/syntax/layer/type/bar.qmd index f8efc63b..c6916866 100644 --- a/doc/syntax/layer/type/bar.qmd +++ b/doc/syntax/layer/type/bar.qmd @@ -123,7 +123,7 @@ PROJECT TO polar Use a different type of aggregation for the bars through the `aggregate` setting: ```{ggsql} -VISUALISE species AS y, body_mass AS y FROM ggsql:penguins +VISUALISE species AS x, body_mass AS y FROM ggsql:penguins DRAW bar SETTING aggregate => 'mean', fill => 'steelblue' DRAW range diff --git a/doc/syntax/layer/type/point.qmd b/doc/syntax/layer/type/point.qmd index a64ca258..5c465488 100644 --- a/doc/syntax/layer/type/point.qmd +++ b/doc/syntax/layer/type/point.qmd @@ -23,9 +23,10 @@ The following aesthetics are recognised by the point layer. ## Settings * `position`: Position adjustment. One of `'identity'` (default), `'stack'`, `'dodge'`, or `'jitter'` +* `aggregate`: Aggregation functions to apply per group. Either a single string or an array of strings. See an overview of aggregation function in [the `DRAW` documentation](../../clause/draw.qmd#aggregate) and more information in the *Data transformation* section below. ## Data transformation -The point layer does not transform its data but passes it through unchanged +This layer supports aggregation through the `aggregate` setting. Within each group, defined by `PARTITION BY` and all discrete mappings, aggregates will be calculated and used as the values to plot. Multiple aggregates will give rise to multiple separate groups in the end. These can be distinguished through the added `aggregate` column you can remap to, e.g. `REMAPPING aggregate AS color` ## Orientation The point layer has no orientation. The axes are treated symmetrically. diff --git a/doc/syntax/layer/type/rule.qmd b/doc/syntax/layer/type/rule.qmd index 71a2ceb4..ba9edcf1 100644 --- a/doc/syntax/layer/type/rule.qmd +++ b/doc/syntax/layer/type/rule.qmd @@ -25,8 +25,10 @@ The following aesthetics are recognised by the rule layer. ## Settings * `position`: Position adjustment. One of `'identity'` (default), `'stack'`, `'dodge'`, or `'jitter'` +* `aggregate`: Aggregation functions to apply per group. Either a single string or an array of strings. See an overview of aggregation function in [the `DRAW` documentation](../../clause/draw.qmd#aggregate) and more information in the *Data transformation* section below. ## Data transformation +This layer supports aggregation through the `aggregate` setting. Within each group, defined by `PARTITION BY` and all discrete mappings, aggregates will be calculated and used as the values to plot. Multiple aggregates will give rise to multiple separate groups in the end. These can be distinguished through the added `aggregate` column you can remap to, e.g. `REMAPPING aggregate AS color` For diagonal lines, the position aesthetic determines the intercept: @@ -110,4 +112,14 @@ VISUALISE FROM ggsql:penguins intercept AS y, label AS colour FROM lines -``` \ No newline at end of file +``` + +Show a max rule for a timeseries + +```{ggsql} +VISUALISE Temp AS y FROM ggsql:airquality +DRAW line + MAPPING Date AS x +DRAW rule + SETTING aggregate => 'max' +``` diff --git a/doc/syntax/layer/type/segment.qmd b/doc/syntax/layer/type/segment.qmd index 7553aef9..4ee8e367 100644 --- a/doc/syntax/layer/type/segment.qmd +++ b/doc/syntax/layer/type/segment.qmd @@ -25,9 +25,10 @@ For axis-aligned intervals where one coordinate is shared between the start and ## Settings * `position`: Position adjustment. One of `'identity'` (default), `'stack'`, `'dodge'`, or `'jitter'` +* `aggregate`: Aggregation functions to apply per group. Either a single string or an array of strings. See an overview of aggregation function in [the `DRAW` documentation](../../clause/draw.qmd#aggregate) and more information in the *Data transformation* section below. ## Data transformation -The segment layer does not transform its data but passes it through unchanged. +This layer supports aggregation through the `aggregate` setting. Within each group, defined by `PARTITION BY` and all discrete mappings, aggregates will be calculated and used as the values to plot. Multiple aggregates will give rise to multiple separate groups in the end. These can be distinguished through the added `aggregate` column you can remap to, e.g. `REMAPPING aggregate AS color` ## Orientation The segment layer has no orientations. The axes are treated symmetrically. diff --git a/doc/syntax/layer/type/text.qmd b/doc/syntax/layer/type/text.qmd index fa9d010a..002033f8 100644 --- a/doc/syntax/layer/type/text.qmd +++ b/doc/syntax/layer/type/text.qmd @@ -35,6 +35,7 @@ The following aesthetics are recognised by the text layer. * a 2-element numeric array `[h, v]` where the first number is the horizontal offset and the second number is the vertical offset. * `format` Formatting specifier, see explanation below. * `position`: Position adjustment. One of `'identity'` (default), `'stack'`, `'dodge'`, or `'jitter'` +* `aggregate`: Aggregation functions to apply per group. Either a single string or an array of strings. See an overview of aggregation function in [the `DRAW` documentation](../../clause/draw.qmd#aggregate) and more information in the *Data transformation* section below. ### Format The `format` setting can take a string that will be used in formatting the `label` aesthetic. @@ -66,7 +67,7 @@ Known formatters are: * `x`/`X`: Unsigned hexadecimal ## Data transformation -The text layer does not transform its data but passed it through unchanged. +This layer supports aggregation through the `aggregate` setting. Within each group, defined by `PARTITION BY` and all discrete mappings, aggregates will be calculated and used as the values to plot. Multiple aggregates will give rise to multiple separate groups in the end. These can be distinguished through the added `aggregate` column you can remap to, e.g. `REMAPPING aggregate AS color` ## Orientation The text layer has no orientation. The axes are treated symmetrically. @@ -146,3 +147,14 @@ PLACE text x => (40, 50, 50), y => (19, 19, 15) ``` + +Use aggregation to place labels at their centroid + +```{ggsql} +VISUALISE bill_len AS x, bill_dep AS y FROM ggsql:penguins +DRAW point + MAPPING species AS fill +DRAW text + MAPPING species AS label + SETTING aggregate => 'mean', stroke => 'white', fontweight => 'bold', fontsize => 20 +``` diff --git a/src/execute/mod.rs b/src/execute/mod.rs index 869cae37..9faf9ee8 100644 --- a/src/execute/mod.rs +++ b/src/execute/mod.rs @@ -714,9 +714,14 @@ fn add_discrete_columns_to_partition_by( // Build set of excluded aesthetics that should not trigger auto-grouping: // - Stat-consumed aesthetics (transformed, not grouped) // - 'label' aesthetic (text content to display, not grouping categories) + // — except when `aggregate` is set on the layer, in which case label + // becomes a legitimate grouping key (e.g. "mean per species, place + // species name at the centroid"). let consumed_aesthetics = layer.geom.stat_consumed_aesthetics(); let mut excluded_aesthetics: HashSet<&str> = consumed_aesthetics.iter().copied().collect(); - excluded_aesthetics.insert("label"); + if !crate::plot::layer::geom::has_aggregate_param(&layer.parameters) { + excluded_aesthetics.insert("label"); + } for (aesthetic, value) in &layer.mappings.aesthetics { // Skip position aesthetics - these should not trigger auto-grouping. From c6dd4a90c3e6f63fd72381988cebdaffd8c6a8e0 Mon Sep 17 00:00:00 2001 From: Thomas Lin Pedersen Date: Mon, 4 May 2026 13:34:26 +0200 Subject: [PATCH 15/33] rethink aggregation --- CHANGELOG.md | 9 + doc/syntax/clause/draw.qmd | 15 +- doc/syntax/layer/type/area.qmd | 2 +- doc/syntax/layer/type/bar.qmd | 10 +- doc/syntax/layer/type/line.qmd | 8 +- doc/syntax/layer/type/point.qmd | 2 +- doc/syntax/layer/type/range.qmd | 3 +- doc/syntax/layer/type/ribbon.qmd | 3 +- doc/syntax/layer/type/rule.qmd | 2 +- doc/syntax/layer/type/segment.qmd | 2 +- doc/syntax/layer/type/text.qmd | 2 +- src/execute/layer.rs | 48 +- src/execute/mod.rs | 1 + src/plot/layer/geom/area.rs | 4 +- src/plot/layer/geom/arrow.rs | 4 - src/plot/layer/geom/bar.rs | 4 +- src/plot/layer/geom/boxplot.rs | 1 + src/plot/layer/geom/density.rs | 1 + src/plot/layer/geom/histogram.rs | 1 + src/plot/layer/geom/line.rs | 4 +- src/plot/layer/geom/mod.rs | 52 +- src/plot/layer/geom/point.rs | 4 - src/plot/layer/geom/range.rs | 16 - src/plot/layer/geom/ribbon.rs | 32 +- src/plot/layer/geom/rule.rs | 6 - src/plot/layer/geom/segment.rs | 9 - src/plot/layer/geom/smooth.rs | 1 + src/plot/layer/geom/stat_aggregate.rs | 2214 +++++++------------------ src/plot/layer/geom/text.rs | 4 - src/plot/layer/geom/tile.rs | 1 + src/plot/layer/geom/violin.rs | 1 + src/plot/layer/mod.rs | 20 +- 32 files changed, 658 insertions(+), 1828 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 9f76a91d..1e176731 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,14 @@ ### Added +- New `aggregate` SETTING on Identity-stat layers (point, line, area, bar, ribbon, +range, segment, arrow, rule, text). Collapses each group to a single row by +replacing every numeric mapping in place with its aggregated value. Accepts a +single string or array of strings; entries are either unprefixed defaults +(`'mean'`) or per-aesthetic targets (`'y:max'`, `'color:median'`). Up to two +defaults may be supplied — the first applies to lower-half aesthetics plus all +non-range layers, the second to upper-half (`max`/`end` suffix). Numeric +mappings without a target or applicable default are dropped with a warning. - Add cell delimiters and code lens actions to the Positron extension (#366) - ODBC is now turned on for the CLI as well (#344) - `FROM` can now come before `VISUALIZE`, mirroring the DuckDB style. This means @@ -37,6 +45,7 @@ portion (#364). - Removed polars from dependency list along with all its transient dependencies. Rewrote DataFrame struct on top of arrow (#350) - Moved ggsql-python to its own repo (posit-dev/ggsql-python) and cleaned up any additional references to it - Moved ggsql-r to its own repo (posit-dev/ggsql-r) +- The `orientation` setting on `ribbon` and `range` layers. With explicit `xmin`/`xmax` or `ymin`/`ymax` mappings, orientation is unambiguous and is auto-detected from the mappings; the override is no longer needed. ## [2.7.0] - 2026-04-20 diff --git a/doc/syntax/clause/draw.qmd b/doc/syntax/clause/draw.qmd index f8b8b30e..ba84fc0e 100644 --- a/doc/syntax/clause/draw.qmd +++ b/doc/syntax/clause/draw.qmd @@ -77,11 +77,18 @@ The `SETTING` clause can be used for two different things: A special setting is `position` which controls how overlapping objects are repositioned to avoid overlapping etc. Position adjustments have special mapping requirements so all position adjustments will not be relevant for all layer types. Different layers have different defaults as detailed in their documentation. You can read about each different position adjustment at [their own documentation sites](../index.qmd#position-adjustments). #### Aggregate -Some layers support aggregation of its data through the `aggregate` setting. These layers will state this. `aggregate` allows a single string or an array of strings that specify the aggregation to calculate. The aggregates can be either a simple function or a parameterized band function. +Some layers support aggregation of their data through the `aggregate` setting. These layers will state this. `aggregate` collapses each group to a single row, replacing every numeric mapping in place with its aggregated value. Groups are defined by `PARTITION BY` together with all discrete mappings. -The simple functions can be one of: +The setting takes a single string or an array of strings. Each string is one of: -* `'count'`: Row count +* **Default** — `''` (no prefix). With one default the function applies to every untargeted numeric mapping. With two defaults the first is used for the lower side of range layers (e.g. `x`/`xmin`) plus all non-range layers, and the second is used for the upper side of range layers (e.g. `xend`/`xmax`). More than two defaults is an error. +* **Target** — `':'`. Applies `func` to the named aesthetic only (`` is a user-facing name like `x`, `y`, `xmin`, `xmax`, `xend`, `yend`, `color`, `size`, …). A target overrides any default for that aesthetic. + +A numeric mapping that has neither a target nor an applicable default is dropped from the layer with a warning. + +The simple functions are: + +* `'count'`: Non-null tally of the bound column. * `'sum'` and `'prod'`: The sum or product * `'min'`, `'max'`, and `'range'`: Extremes and max - min * `'mean'`, and `'median'`: Central tendency @@ -95,6 +102,8 @@ Allowed offsets are: `'mean'`, `'median'`, `'geomean'`, `'harmean'`, `'rms'`, `' Allowed expansions are: `'sdev'`, `'se'`, `'var'`, `'iqr'`, and `'range'` +Aggregation applies in place: there is no extra `aggregate` column to remap, and you do not need a `REMAPPING` clause to consume aggregate output. The aggregated value replaces the bound column for the same aesthetic. + ### `FILTER` ```ggsql FILTER diff --git a/doc/syntax/layer/type/area.qmd b/doc/syntax/layer/type/area.qmd index 1fa0cdc3..5213175e 100644 --- a/doc/syntax/layer/type/area.qmd +++ b/doc/syntax/layer/type/area.qmd @@ -28,7 +28,7 @@ The following aesthetics are recognised by the area layer. * `aggregate`: Aggregation functions to apply per group. Either a single string or an array of strings. See an overview of aggregation function in [the `DRAW` documentation](../../clause/draw.qmd#aggregate) and more information in the *Data transformation* section below. ## Data transformation -This layer supports aggregation through the `aggregate` setting. Within each group, defined by `PARTITION BY`, all discrete mappings, and the primary axis, aggregates will be calculated and used as the values to plot. Multiple aggregates will give rise to multiple separate groups in the end. These can be distinguished through the added `aggregate` column you can remap to, e.g. `REMAPPING aggregate AS color` +This layer supports aggregation through the `aggregate` setting. Within each group, defined by `PARTITION BY`, all discrete mappings, and the primary axis, every numeric mapping is replaced in place by its aggregated value. Use a default like `'mean'` or target individual aesthetics with `':'`. See [the `DRAW` documentation](../../clause/draw.qmd#aggregate) for the full setting shape. Further, the area layer sorts the data along its primary axis before returning it. diff --git a/doc/syntax/layer/type/bar.qmd b/doc/syntax/layer/type/bar.qmd index c6916866..e71ba3f2 100644 --- a/doc/syntax/layer/type/bar.qmd +++ b/doc/syntax/layer/type/bar.qmd @@ -30,7 +30,7 @@ The bar layer has no required aesthetics ## Data transformation If the secondary axis has not been mapped the layer will calculate counts for you and display these as the secondary axis. -If the secondary axis has been mapped you can apply aggregation through the `aggregate` setting. Within each group, defined by `PARTITION BY`, all discrete mappings, and the primary axis, aggregates will be calculated and used as the values to plot. Multiple aggregates will give rise to multiple separate groups in the end. These can be distinguished through the added `aggregate` column you can remap to, e.g. `REMAPPING aggregate AS color` +If the secondary axis has been mapped you can apply aggregation through the `aggregate` setting. Within each group, defined by `PARTITION BY`, all discrete mappings, and the primary axis, every numeric mapping is replaced in place by its aggregated value. Use a default like `'mean'` or target individual aesthetics with `':'`. See [the `DRAW` documentation](../../clause/draw.qmd#aggregate) for the full setting shape. ### Properties @@ -120,12 +120,14 @@ DRAW bar PROJECT TO polar ``` -Use a different type of aggregation for the bars through the `aggregate` setting: +Use a different type of aggregation for the bars through the `aggregate` setting. The `range` layer needs both `ymin` and `ymax` mapped; with two defaults, the first is applied to the lower bound and the second to the upper bound. ```{ggsql} -VISUALISE species AS x, body_mass AS y FROM ggsql:penguins +VISUALISE species AS x FROM ggsql:penguins DRAW bar + MAPPING body_mass AS y SETTING aggregate => 'mean', fill => 'steelblue' DRAW range - setting aggregate => ('mean-1.96sdev', 'mean+1.96sdev') + MAPPING body_mass AS ymin, body_mass AS ymax + SETTING aggregate => ('mean-1.96sdev', 'mean+1.96sdev') ``` diff --git a/doc/syntax/layer/type/line.qmd b/doc/syntax/layer/type/line.qmd index 88bc9034..a40fd486 100644 --- a/doc/syntax/layer/type/line.qmd +++ b/doc/syntax/layer/type/line.qmd @@ -27,7 +27,7 @@ The following aesthetics are recognised by the line layer. * `aggregate`: Aggregation functions to apply per group. Either a single string or an array of strings. See an overview of aggregation function in [the `DRAW` documentation](../../clause/draw.qmd#aggregate) and more information in the *Data transformation* section below. ## Data transformation -This layer supports aggregation through the `aggregate` setting. Within each group, defined by `PARTITION BY`, all discrete mappings, and the primary axis, aggregates will be calculated and used as the values to plot. Multiple aggregates will give rise to multiple separate groups in the end. These can be distinguished through the added `aggregate` column you can remap to, e.g. `REMAPPING aggregate AS color` +This layer supports aggregation through the `aggregate` setting. Within each group, defined by `PARTITION BY`, all discrete mappings, and the primary axis, every numeric mapping is replaced in place by its aggregated value to produce a summary trace. Use a default like `'mean'` to summarise the secondary axis, or target other aesthetics with `':'` (e.g. `'color:median'`). To draw min/max envelope lines, use a separate `DRAW line` layer per function, or use a [`range` layer](range.qmd) for a single range mark. See [the `DRAW` documentation](../../clause/draw.qmd#aggregate) for the full setting shape. Further, the line layer sorts the data along its primary axis before returning it. @@ -96,11 +96,13 @@ DRAW line SCALE linewidth TO (0, 30) ``` -Use aggregation to draw min and max lines from a set of observations +Use aggregation to draw min and max lines from a set of observations. Each layer produces one summary trace; stack two layers for both bounds. ```{ggsql} VISUALISE Day AS x, Temp AS y FROM ggsql:airquality DRAW line - SETTING aggregate => ('min', 'max') + SETTING aggregate => 'min' +DRAW line + SETTING aggregate => 'max' DRAW point ``` diff --git a/doc/syntax/layer/type/point.qmd b/doc/syntax/layer/type/point.qmd index 5c465488..16687e0f 100644 --- a/doc/syntax/layer/type/point.qmd +++ b/doc/syntax/layer/type/point.qmd @@ -26,7 +26,7 @@ The following aesthetics are recognised by the point layer. * `aggregate`: Aggregation functions to apply per group. Either a single string or an array of strings. See an overview of aggregation function in [the `DRAW` documentation](../../clause/draw.qmd#aggregate) and more information in the *Data transformation* section below. ## Data transformation -This layer supports aggregation through the `aggregate` setting. Within each group, defined by `PARTITION BY` and all discrete mappings, aggregates will be calculated and used as the values to plot. Multiple aggregates will give rise to multiple separate groups in the end. These can be distinguished through the added `aggregate` column you can remap to, e.g. `REMAPPING aggregate AS color` +This layer supports aggregation through the `aggregate` setting. Within each group, defined by `PARTITION BY` and all discrete mappings, every numeric mapping is replaced in place by its aggregated value. Use a default like `'mean'` or target individual aesthetics with `':'`. See [the `DRAW` documentation](../../clause/draw.qmd#aggregate) for the full setting shape. ## Orientation The point layer has no orientation. The axes are treated symmetrically. diff --git a/doc/syntax/layer/type/range.qmd b/doc/syntax/layer/type/range.qmd index d3982bd6..d8c6d672 100644 --- a/doc/syntax/layer/type/range.qmd +++ b/doc/syntax/layer/type/range.qmd @@ -22,9 +22,10 @@ The following aesthetics are recognised by the range layer. ## Settings * `width`: The width of the hinges in points (must be >= 0). Defaults to 10. Can be set to `null` to not display hinges. +* `aggregate`: Aggregation functions to apply per group. Either a single string or an array of strings. See [the `DRAW` documentation](../../clause/draw.qmd#aggregate) and the *Data transformation* section below. ## Data transformation -The range layer does not transform its data but passes it through unchanged. +This layer supports aggregation through the `aggregate` setting. Within each group, defined by `PARTITION BY` and all discrete mappings, every numeric mapping is replaced in place by its aggregated value, producing one range per group. Range is a range layer: with two defaults the first applies to the start point (`xmin`/`ymin`) and the second applies to the end point (`xmax`/`ymax`). Use a single default like `'mean'` to apply the same function to all values, or target individual aesthetics with `':'`. See [the `DRAW` documentation](../../clause/draw.qmd#aggregate) for the full setting shape. ## Orientation The orientation of range layers is deduced directly from the mapping, because the interval is mapped to the secondary axis. To create a horizontal range layer, you map the independent variable to `y` instead of `x` and the interval to `xmin` and `xmax` (assuming a default Cartesian coordinate system). diff --git a/doc/syntax/layer/type/ribbon.qmd b/doc/syntax/layer/type/ribbon.qmd index 50a38d25..cbc7379e 100644 --- a/doc/syntax/layer/type/ribbon.qmd +++ b/doc/syntax/layer/type/ribbon.qmd @@ -23,9 +23,10 @@ The following aesthetics are recognised by the ribbon layer. ## Settings * `position`: Position adjustment. One of `'identity'` (default), `'stack'`, `'dodge'`, or `'jitter'` +* `aggregate`: Aggregation functions to apply per group. Either a single string or an array of strings. See [the `DRAW` documentation](../../clause/draw.qmd#aggregate) and the *Data transformation* section below. ## Data transformation -The ribbon layer sorts the data along its primary axis +This layer supports aggregation through the `aggregate` setting. Within each group, defined by `PARTITION BY` and all discrete mappings, every numeric mapping is replaced in place by its aggregated value, producing one ribbon per group. Ribon is a range layer: with two defaults the first applies to the start point (`xmin`/`ymin`) and the second applies to the end point (`xmax`/`ymax`). Use a single default like `'mean'` to apply the same function to all values, or target individual aesthetics with `':'`. See [the `DRAW` documentation](../../clause/draw.qmd#aggregate) for the full setting shape. ## Orientation Ribbon layers are sorted and connected along their primary axis. The orientation is deduced directly from the mapping, because the interval is mapped to the secondary axis. To create a vertical ribbon layer you map the independent variable to `y` instead of `x` and the interval to `xmin` and `xmax` (assuming a default Cartesian coordinate system). diff --git a/doc/syntax/layer/type/rule.qmd b/doc/syntax/layer/type/rule.qmd index ba9edcf1..470ea39e 100644 --- a/doc/syntax/layer/type/rule.qmd +++ b/doc/syntax/layer/type/rule.qmd @@ -28,7 +28,7 @@ The following aesthetics are recognised by the rule layer. * `aggregate`: Aggregation functions to apply per group. Either a single string or an array of strings. See an overview of aggregation function in [the `DRAW` documentation](../../clause/draw.qmd#aggregate) and more information in the *Data transformation* section below. ## Data transformation -This layer supports aggregation through the `aggregate` setting. Within each group, defined by `PARTITION BY` and all discrete mappings, aggregates will be calculated and used as the values to plot. Multiple aggregates will give rise to multiple separate groups in the end. These can be distinguished through the added `aggregate` column you can remap to, e.g. `REMAPPING aggregate AS color` +This layer supports aggregation through the `aggregate` setting. Within each group, defined by `PARTITION BY` and all discrete mappings, every numeric mapping is replaced in place by its aggregated value. Use a default like `'mean'` or target individual aesthetics with `':'`. See [the `DRAW` documentation](../../clause/draw.qmd#aggregate) for the full setting shape. For diagonal lines, the position aesthetic determines the intercept: diff --git a/doc/syntax/layer/type/segment.qmd b/doc/syntax/layer/type/segment.qmd index 4ee8e367..ac759829 100644 --- a/doc/syntax/layer/type/segment.qmd +++ b/doc/syntax/layer/type/segment.qmd @@ -28,7 +28,7 @@ For axis-aligned intervals where one coordinate is shared between the start and * `aggregate`: Aggregation functions to apply per group. Either a single string or an array of strings. See an overview of aggregation function in [the `DRAW` documentation](../../clause/draw.qmd#aggregate) and more information in the *Data transformation* section below. ## Data transformation -This layer supports aggregation through the `aggregate` setting. Within each group, defined by `PARTITION BY` and all discrete mappings, aggregates will be calculated and used as the values to plot. Multiple aggregates will give rise to multiple separate groups in the end. These can be distinguished through the added `aggregate` column you can remap to, e.g. `REMAPPING aggregate AS color` +This layer supports aggregation through the `aggregate` setting. Within each group, defined by `PARTITION BY` and all discrete mappings, every numeric mapping is replaced in place by its aggregated value, producing one segment per group. Segment is a range layer: with two defaults the first applies to the start point (`x`/`y`) and the second applies to the end point (`xend`/`yend`). Use a single default like `'mean'` to apply the same function to all four endpoints, or target individual aesthetics with `':'`. See [the `DRAW` documentation](../../clause/draw.qmd#aggregate) for the full setting shape. ## Orientation The segment layer has no orientations. The axes are treated symmetrically. diff --git a/doc/syntax/layer/type/text.qmd b/doc/syntax/layer/type/text.qmd index 002033f8..58983055 100644 --- a/doc/syntax/layer/type/text.qmd +++ b/doc/syntax/layer/type/text.qmd @@ -67,7 +67,7 @@ Known formatters are: * `x`/`X`: Unsigned hexadecimal ## Data transformation -This layer supports aggregation through the `aggregate` setting. Within each group, defined by `PARTITION BY` and all discrete mappings, aggregates will be calculated and used as the values to plot. Multiple aggregates will give rise to multiple separate groups in the end. These can be distinguished through the added `aggregate` column you can remap to, e.g. `REMAPPING aggregate AS color` +This layer supports aggregation through the `aggregate` setting. Within each group, defined by `PARTITION BY` and all discrete mappings, every numeric mapping is replaced in place by its aggregated value. Use a default like `'mean'` or target individual aesthetics with `':'`. See [the `DRAW` documentation](../../clause/draw.qmd#aggregate) for the full setting shape. ## Orientation The text layer has no orientation. The axes are treated symmetrically. diff --git a/src/execute/layer.rs b/src/execute/layer.rs index 1f0c10d0..51457a93 100644 --- a/src/execute/layer.rs +++ b/src/execute/layer.rs @@ -205,18 +205,6 @@ pub fn apply_remappings_post_query(df: DataFrame, layer: &Layer) -> Result) -> usize { - match parameters.get("aggregate") { - Some(ParameterValue::String(_)) => 1, - Some(ParameterValue::Array(arr)) => arr.len(), - _ => 0, - } -} - /// Convert a literal value to an Arrow ArrayRef with constant values. /// /// For string literals, attempts to parse as temporal types (date/datetime/time) @@ -453,6 +441,7 @@ pub fn apply_layer_transforms( scales: &[Scale], dialect: &dyn SqlDialect, execute_query: &F, + aesthetic_ctx: &AestheticContext, ) -> Result where F: Fn(&str) -> Result, @@ -528,6 +517,7 @@ where &layer.parameters, execute_query, dialect, + aesthetic_ctx, )?; // Flip user remappings BEFORE merging defaults for Transposed orientation. @@ -601,15 +591,15 @@ where layer.mappings.aesthetics.remove(aes); } - // Auto-remap stat columns whose names are position aesthetics that were - // consumed by the stat (e.g. Aggregate's `pos1`/`pos2` outputs). The geom - // can't list these in `default_remappings` because the set of position - // aesthetics in play is dynamic per layer. + // Auto-remap stat columns whose names match aesthetics that were + // consumed by the stat (e.g. Aggregate's per-aesthetic outputs). The + // geom can't list these in `default_remappings` because the set of + // mapped aesthetics is dynamic per layer. for stat in &stat_columns { if final_remappings.contains_key(stat) { continue; } - if aesthetic::is_position_aesthetic(stat) && consumed_aesthetics.contains(stat) { + if consumed_aesthetics.contains(stat) { final_remappings.insert(stat.clone(), stat.clone()); } } @@ -651,30 +641,6 @@ where } } - // The `aggregate` stat column (produced by stat_aggregate when the - // user requests multiple functions) tags each row with its function - // name. For mark types that connect rows within a group (line, area, - // path, polygon), we need to add this column to `layer.partition_by` - // so that e.g. `aggregate => ('min', 'max')` renders as two separate - // lines rather than one zigzag through both. Resolves to the - // post-rename data-column name: if the user remapped `aggregate AS - // `, the prefixed aesthetic column; otherwise the stat column. - // - // Only fires when more than one function is requested — a single - // function produces a constant aggregate column, partitioning by - // which would just add a no-op detail channel. - if stat_columns.iter().any(|s| s == "aggregate") - && aggregate_param_function_count(&layer.parameters) > 1 - { - let partition_col = match final_remappings.get("aggregate") { - Some(aes) => naming::aesthetic_column(aes), - None => naming::stat_column("aggregate"), - }; - if !layer.partition_by.contains(&partition_col) { - layer.partition_by.push(partition_col); - } - } - // Wrap transformed query to rename stat columns to prefixed aesthetic names let stat_rename_exprs: Vec = stat_columns .iter() diff --git a/src/execute/mod.rs b/src/execute/mod.rs index 9faf9ee8..40efdfc8 100644 --- a/src/execute/mod.rs +++ b/src/execute/mod.rs @@ -1194,6 +1194,7 @@ pub fn prepare_data_with_reader(query: &str, reader: &dyn Reader) -> Result, _execute_query: &dyn Fn(&str) -> crate::Result, dialect: &dyn crate::reader::SqlDialect, + aesthetic_ctx: &crate::plot::aesthetic::AestheticContext, ) -> crate::Result { let result = if has_aggregate_param(parameters) { stat_aggregate::apply( @@ -81,8 +82,7 @@ impl GeomTrait for Area { group_by, parameters, dialect, - self.aggregate_slots(), - self.aggregate_range_pair(), + aesthetic_ctx, )? } else { StatResult::Identity diff --git a/src/plot/layer/geom/arrow.rs b/src/plot/layer/geom/arrow.rs index 5737bb95..2e3369d2 100644 --- a/src/plot/layer/geom/arrow.rs +++ b/src/plot/layer/geom/arrow.rs @@ -43,10 +43,6 @@ impl GeomTrait for Arrow { fn supports_aggregate(&self) -> bool { true } - - fn aggregate_slots(&self) -> &'static [u8] { - &[1, 2] - } } impl std::fmt::Display for Arrow { diff --git a/src/plot/layer/geom/bar.rs b/src/plot/layer/geom/bar.rs index e65a0256..65f24c80 100644 --- a/src/plot/layer/geom/bar.rs +++ b/src/plot/layer/geom/bar.rs @@ -97,6 +97,7 @@ impl GeomTrait for Bar { parameters: &HashMap, _execute_query: &dyn Fn(&str) -> Result, dialect: &dyn SqlDialect, + aesthetic_ctx: &crate::plot::aesthetic::AestheticContext, ) -> Result { if has_aggregate_param(parameters) { return stat_aggregate::apply( @@ -106,8 +107,7 @@ impl GeomTrait for Bar { group_by, parameters, dialect, - self.aggregate_slots(), - self.aggregate_range_pair(), + aesthetic_ctx, ); } stat_bar_count(query, schema, aesthetics, group_by) diff --git a/src/plot/layer/geom/boxplot.rs b/src/plot/layer/geom/boxplot.rs index fdc7bae6..5d99b358 100644 --- a/src/plot/layer/geom/boxplot.rs +++ b/src/plot/layer/geom/boxplot.rs @@ -95,6 +95,7 @@ impl GeomTrait for Boxplot { parameters: &HashMap, _execute_query: &dyn Fn(&str) -> Result, dialect: &dyn SqlDialect, + _aesthetic_ctx: &crate::plot::aesthetic::AestheticContext, ) -> Result { stat_boxplot(query, aesthetics, group_by, parameters, dialect) } diff --git a/src/plot/layer/geom/density.rs b/src/plot/layer/geom/density.rs index 89910be9..7198e491 100644 --- a/src/plot/layer/geom/density.rs +++ b/src/plot/layer/geom/density.rs @@ -111,6 +111,7 @@ impl GeomTrait for Density { parameters: &std::collections::HashMap, _execute_query: &dyn Fn(&str) -> crate::Result, dialect: &dyn SqlDialect, + _aesthetic_ctx: &crate::plot::aesthetic::AestheticContext, ) -> crate::Result { // Density geom: no tails limit (don't set tails parameter, defaults to None) stat_density( diff --git a/src/plot/layer/geom/histogram.rs b/src/plot/layer/geom/histogram.rs index 66400e56..bfb80050 100644 --- a/src/plot/layer/geom/histogram.rs +++ b/src/plot/layer/geom/histogram.rs @@ -97,6 +97,7 @@ impl GeomTrait for Histogram { parameters: &HashMap, execute_query: &dyn Fn(&str) -> Result, dialect: &dyn SqlDialect, + _aesthetic_ctx: &crate::plot::aesthetic::AestheticContext, ) -> Result { stat_histogram( query, diff --git a/src/plot/layer/geom/line.rs b/src/plot/layer/geom/line.rs index e0600af0..63980856 100644 --- a/src/plot/layer/geom/line.rs +++ b/src/plot/layer/geom/line.rs @@ -58,6 +58,7 @@ impl GeomTrait for Line { parameters: &std::collections::HashMap, _execute_query: &dyn Fn(&str) -> crate::Result, dialect: &dyn crate::reader::SqlDialect, + aesthetic_ctx: &crate::plot::aesthetic::AestheticContext, ) -> crate::Result { let result = if has_aggregate_param(parameters) { stat_aggregate::apply( @@ -67,8 +68,7 @@ impl GeomTrait for Line { group_by, parameters, dialect, - self.aggregate_slots(), - self.aggregate_range_pair(), + aesthetic_ctx, )? } else { StatResult::Identity diff --git a/src/plot/layer/geom/mod.rs b/src/plot/layer/geom/mod.rs index 9296d7e3..38c8a034 100644 --- a/src/plot/layer/geom/mod.rs +++ b/src/plot/layer/geom/mod.rs @@ -73,6 +73,7 @@ pub use text::Text; pub use tile::Tile; pub use violin::Violin; +use crate::plot::aesthetic::AestheticContext; use crate::plot::types::{ParameterValue, Schema}; use crate::reader::SqlDialect; @@ -195,40 +196,15 @@ pub trait GeomTrait: std::fmt::Debug + std::fmt::Display + Send + Sync { /// Whether this geom accepts the `aggregate` SETTING parameter. /// - /// Geoms that opt in (the Identity-stat geoms) gain a generic Aggregate stat - /// that groups by discrete mappings + PARTITION BY and emits one row per - /// (group × aggregation function). Statistical geoms (histogram, density, - /// smooth, boxplot, violin) leave this `false` to keep their bespoke stats. + /// Geoms that opt in gain a generic Aggregate stat that groups by discrete + /// mappings + PARTITION BY and emits one row per group, replacing every + /// numeric mapping (positional and material) with its aggregated value. + /// Statistical geoms (histogram, density, smooth, boxplot, violin) leave + /// this `false` to keep their bespoke stats. fn supports_aggregate(&self) -> bool { false } - /// Which numeric position-aesthetic slots the Aggregate stat should reduce. - /// - /// Slot 1 is `pos1`/`pos1min`/`pos1max`/`pos1end` (the independent / domain axis). - /// Slot 2 is `pos2`/`pos2min`/`pos2max`/`pos2end` (the dependent / range axis). - /// - /// Default: `&[2]` — only the dependent axis is reduced; pos1-family stays as a - /// grouping column, so e.g. line geoms produce a summary trace along x. Geoms - /// whose natural Aggregate is centroid-like (point, polygon, segment, arrow, - /// text, path, tile, rule) override to `&[1, 2]`. - fn aggregate_slots(&self) -> &'static [u8] { - &[2] - } - - /// Range pair for range-style Aggregate output. - /// - /// When `Some((lower, upper))`, this geom is a "range geom" that takes exactly - /// two `aggregate` functions and assigns them to the two named aesthetics - /// (e.g. `("pos2min", "pos2max")` for ribbon/range). The user maps `pos2` - /// as the input column; the stat consumes pos2 and produces the range pair. - /// One row per group; no `aggregate` tag column. - /// - /// `None` (default) means standard per-function-rows aggregation. - fn aggregate_range_pair(&self) -> Option<(&'static str, &'static str)> { - None - } - /// Apply statistical transformation to the layer query. /// /// The default implementation dispatches to the Aggregate stat when @@ -244,6 +220,7 @@ pub trait GeomTrait: std::fmt::Debug + std::fmt::Display + Send + Sync { parameters: &HashMap, _execute_query: &dyn Fn(&str) -> Result, dialect: &dyn SqlDialect, + aesthetic_ctx: &AestheticContext, ) -> Result { if self.supports_aggregate() && has_aggregate_param(parameters) { return stat_aggregate::apply( @@ -253,8 +230,7 @@ pub trait GeomTrait: std::fmt::Debug + std::fmt::Display + Send + Sync { group_by, parameters, dialect, - self.aggregate_slots(), - self.aggregate_range_pair(), + aesthetic_ctx, ); } Ok(StatResult::Identity) @@ -483,6 +459,7 @@ impl Geom { parameters: &HashMap, execute_query: &dyn Fn(&str) -> Result, dialect: &dyn SqlDialect, + aesthetic_ctx: &AestheticContext, ) -> Result { self.0.apply_stat_transform( query, @@ -492,6 +469,7 @@ impl Geom { parameters, execute_query, dialect, + aesthetic_ctx, ) } @@ -523,16 +501,6 @@ impl Geom { self.0.supports_aggregate() } - /// Which position-aesthetic slots the Aggregate stat should reduce. - pub fn aggregate_slots(&self) -> &'static [u8] { - self.0.aggregate_slots() - } - - /// Range pair for range-style Aggregate output, if any. - pub fn aggregate_range_pair(&self) -> Option<(&'static str, &'static str)> { - self.0.aggregate_range_pair() - } - /// Validate aesthetic mappings pub fn validate_aesthetics(&self, mappings: &Mappings) -> std::result::Result<(), String> { self.0.validate_aesthetics(mappings) diff --git a/src/plot/layer/geom/point.rs b/src/plot/layer/geom/point.rs index 1f60a5f6..5101f2f0 100644 --- a/src/plot/layer/geom/point.rs +++ b/src/plot/layer/geom/point.rs @@ -42,10 +42,6 @@ impl GeomTrait for Point { fn supports_aggregate(&self) -> bool { true } - - fn aggregate_slots(&self) -> &'static [u8] { - &[1, 2] - } } impl std::fmt::Display for Point { diff --git a/src/plot/layer/geom/range.rs b/src/plot/layer/geom/range.rs index d368c4e7..5cacd874 100644 --- a/src/plot/layer/geom/range.rs +++ b/src/plot/layer/geom/range.rs @@ -4,7 +4,6 @@ use super::types::POSITION_VALUES; use super::{ DefaultAesthetics, DefaultParamValue, GeomTrait, GeomType, ParamConstraint, ParamDefinition, }; -use crate::plot::layer::orientation::ORIENTATION_VALUES; use crate::plot::types::DefaultAestheticValue; /// Range geom - intervals along the secondary axis @@ -22,10 +21,6 @@ impl GeomTrait for Range { ("pos1", DefaultAestheticValue::Required), ("pos2min", DefaultAestheticValue::Required), ("pos2max", DefaultAestheticValue::Required), - // pos2 is the input column for the Aggregate stat in range mode - // (`SETTING aggregate => (lower_func, upper_func)` consumes pos2 - // and produces pos2min/pos2max). Optional otherwise. - ("pos2", DefaultAestheticValue::Null), ("stroke", DefaultAestheticValue::String("black")), ("opacity", DefaultAestheticValue::Number(1.0)), ("linewidth", DefaultAestheticValue::Number(1.0)), @@ -46,13 +41,6 @@ impl GeomTrait for Range { default: DefaultParamValue::Number(10.0), constraint: ParamConstraint::number_min(0.0), }, - // Default Null → resolve_orientation auto-detects from mappings/scales. - // User can override with `SETTING orientation => 'transposed'`. - ParamDefinition { - name: "orientation", - default: DefaultParamValue::Null, - constraint: ParamConstraint::string_option(ORIENTATION_VALUES), - }, ]; PARAMS } @@ -60,10 +48,6 @@ impl GeomTrait for Range { fn supports_aggregate(&self) -> bool { true } - - fn aggregate_range_pair(&self) -> Option<(&'static str, &'static str)> { - Some(("pos2min", "pos2max")) - } } impl std::fmt::Display for Range { diff --git a/src/plot/layer/geom/ribbon.rs b/src/plot/layer/geom/ribbon.rs index 47b58a97..c13f56f5 100644 --- a/src/plot/layer/geom/ribbon.rs +++ b/src/plot/layer/geom/ribbon.rs @@ -3,7 +3,6 @@ use super::stat_aggregate; use super::types::{wrap_with_order_by, POSITION_VALUES}; use super::{has_aggregate_param, DefaultAesthetics, GeomTrait, GeomType, StatResult}; -use crate::plot::layer::orientation::ORIENTATION_VALUES; use crate::plot::types::DefaultAestheticValue; use crate::plot::{DefaultParamValue, ParamConstraint, ParamDefinition}; use crate::Mappings; @@ -23,10 +22,6 @@ impl GeomTrait for Ribbon { ("pos1", DefaultAestheticValue::Required), ("pos2min", DefaultAestheticValue::Required), ("pos2max", DefaultAestheticValue::Required), - // pos2 is the input column for the Aggregate stat in range mode - // (`SETTING aggregate => (lower_func, upper_func)` consumes pos2 - // and produces pos2min/pos2max). Optional otherwise. - ("pos2", DefaultAestheticValue::Null), ("fill", DefaultAestheticValue::String("black")), ("stroke", DefaultAestheticValue::String("black")), ("opacity", DefaultAestheticValue::Number(0.8)), @@ -37,20 +32,11 @@ impl GeomTrait for Ribbon { } fn default_params(&self) -> &'static [ParamDefinition] { - const PARAMS: &[ParamDefinition] = &[ - ParamDefinition { - name: "position", - default: DefaultParamValue::String("identity"), - constraint: ParamConstraint::string_option(POSITION_VALUES), - }, - // Default Null → resolve_orientation auto-detects from mappings/scales. - // User can override with `SETTING orientation => 'transposed'`. - ParamDefinition { - name: "orientation", - default: DefaultParamValue::Null, - constraint: ParamConstraint::string_option(ORIENTATION_VALUES), - }, - ]; + const PARAMS: &[ParamDefinition] = &[ParamDefinition { + name: "position", + default: DefaultParamValue::String("identity"), + constraint: ParamConstraint::string_option(POSITION_VALUES), + }]; PARAMS } @@ -58,10 +44,6 @@ impl GeomTrait for Ribbon { true } - fn aggregate_range_pair(&self) -> Option<(&'static str, &'static str)> { - Some(("pos2min", "pos2max")) - } - fn needs_stat_transform(&self, _aesthetics: &Mappings) -> bool { true } @@ -75,6 +57,7 @@ impl GeomTrait for Ribbon { parameters: &std::collections::HashMap, _execute_query: &dyn Fn(&str) -> crate::Result, dialect: &dyn crate::reader::SqlDialect, + aesthetic_ctx: &crate::plot::aesthetic::AestheticContext, ) -> crate::Result { let result = if has_aggregate_param(parameters) { stat_aggregate::apply( @@ -84,8 +67,7 @@ impl GeomTrait for Ribbon { group_by, parameters, dialect, - self.aggregate_slots(), - self.aggregate_range_pair(), + aesthetic_ctx, )? } else { StatResult::Identity diff --git a/src/plot/layer/geom/rule.rs b/src/plot/layer/geom/rule.rs index 21d7adbe..a495cb48 100644 --- a/src/plot/layer/geom/rule.rs +++ b/src/plot/layer/geom/rule.rs @@ -29,12 +29,6 @@ impl GeomTrait for Rule { true } - fn aggregate_slots(&self) -> &'static [u8] { - // Rule maps exactly one of pos1/pos2 (XOR). Allow either to be the reduced - // axis — whichever is mapped wins, and the other slot has nothing to filter. - &[1, 2] - } - fn validate_aesthetics(&self, mappings: &crate::Mappings) -> std::result::Result<(), String> { // Rule requires exactly one of pos1 or pos2 (XOR logic) let has_pos1 = mappings.contains_key("pos1"); diff --git a/src/plot/layer/geom/segment.rs b/src/plot/layer/geom/segment.rs index 499ae173..4dd7e65f 100644 --- a/src/plot/layer/geom/segment.rs +++ b/src/plot/layer/geom/segment.rs @@ -42,15 +42,6 @@ impl GeomTrait for Segment { fn supports_aggregate(&self) -> bool { true } - - fn aggregate_slots(&self) -> &'static [u8] { - // Segment is two endpoints connected by a line. Aggregate runs - // independently on each of the four position aesthetics: pos1 and - // pos1end (slot 1), pos2 and pos2end (slot 2). With `aggregate => 'mean'`, - // the segment goes from `(mean(pos1), mean(pos2))` to - // `(mean(pos1end), mean(pos2end))`. - &[1, 2] - } } impl std::fmt::Display for Segment { diff --git a/src/plot/layer/geom/smooth.rs b/src/plot/layer/geom/smooth.rs index 2f053235..f65111ee 100644 --- a/src/plot/layer/geom/smooth.rs +++ b/src/plot/layer/geom/smooth.rs @@ -100,6 +100,7 @@ impl GeomTrait for Smooth { parameters: &std::collections::HashMap, _execute_query: &dyn Fn(&str) -> crate::Result, dialect: &dyn SqlDialect, + _aesthetic_ctx: &crate::plot::aesthetic::AestheticContext, ) -> crate::Result { // Get method from parameters (validated by ParamConstraint::string_option) let ParameterValue::String(method) = parameters.get("method").unwrap() else { diff --git a/src/plot/layer/geom/stat_aggregate.rs b/src/plot/layer/geom/stat_aggregate.rs index 61ee15cf..6078d8f5 100644 --- a/src/plot/layer/geom/stat_aggregate.rs +++ b/src/plot/layer/geom/stat_aggregate.rs @@ -1,21 +1,36 @@ -//! Aggregate stat - groups data and applies one or more aggregation functions per group. +//! Aggregate stat — collapse each group to a single row by applying an +//! aggregate function per numeric mapping. //! -//! When a layer's `aggregate` SETTING is set to a function name (or array of names), -//! this stat groups by discrete mappings + PARTITION BY columns and produces one row -//! per (group × function), aggregating numeric position aesthetics. +//! When a layer's `aggregate` SETTING is set, this stat groups by discrete +//! mappings + PARTITION BY columns and emits one row per group. Each numeric +//! column-mapping (positional *and* material) is replaced in place by the +//! aggregated value of its bound column. Discrete mappings stay as group keys; +//! literal mappings pass through unchanged. //! -//! Output columns: -//! - One column per numeric position aesthetic (named `pos1`, `pos2`, etc.) holding the -//! aggregated value. NULL for `count` rows. -//! - `aggregate` - the function name for the row. -//! - `count` (only when `count` is requested) - the row tally for that group. +//! # Setting shape +//! +//! `aggregate` accepts a single string or array of strings. Each string is +//! either: +//! +//! - **default** — `''` (no prefix). Up to two defaults may be supplied. +//! With one default it applies to every untargeted numeric mapping. With two +//! defaults the first applies to *lower-half* aesthetics (no suffix or `min` +//! suffix) plus all non-range geoms, and the second applies to *upper-half* +//! aesthetics (`max` or `end` suffix). More than two defaults is an error. +//! - **target** — `':'`. Applies `func` to the named aesthetic only. +//! `` is a user-facing name (`x`, `y`, `xmin`, `xmax`, `xend`, `yend`, +//! `color`, `size`, …); the stat resolves it to the internal name through +//! `AestheticContext`. +//! +//! Numeric mappings without a target *or* applicable default are dropped with +//! a warning to stderr. use std::collections::HashMap; use super::types::StatResult; use crate::naming; -use crate::plot::aesthetic::{is_position_aesthetic, parse_position}; -use crate::plot::types::{ParameterValue, Schema}; +use crate::plot::aesthetic::AestheticContext; +use crate::plot::types::{ArrayElement, ParameterValue, Schema}; use crate::reader::SqlDialect; use crate::{GgsqlError, Mappings, Result}; @@ -63,8 +78,6 @@ pub struct Band { pub expansion: &'static str, } -/// Resolve a name to its canonical `&'static str` from the given vocabulary, -/// or `None` if the input doesn't match any entry. fn resolve_static(name: &str, vocab: &'static [&'static str]) -> Option<&'static str> { vocab.iter().copied().find(|v| *v == name) } @@ -79,9 +92,6 @@ pub fn parse_agg_name(name: &str) -> Option { resolve_static(name, AGG_NAMES).map(|offset| AggSpec { offset, band: None }) } -/// Try to parse `name` as a band: `?`. Returns -/// `None` if it doesn't match the band shape OR if either half is outside its -/// allowed vocabulary. fn parse_band(name: &str) -> Option { // Walk offsets longest-first so `median` matches before `mean`. let mut offsets: Vec<&'static str> = OFFSET_STATS.to_vec(); @@ -90,18 +100,18 @@ fn parse_band(name: &str) -> Option { for offset in offsets { let rest = match name.strip_prefix(offset) { Some(r) => r, - None => continue, // doesn't start with this offset + None => continue, }; let (sign, after_sign) = match rest.chars().next() { Some('+') => ('+', &rest[1..]), Some('-') => ('-', &rest[1..]), - _ => continue, // wrong sign char — try next offset + _ => continue, }; let (mod_value, expansion_str) = parse_mod_and_remainder(after_sign); let expansion = match resolve_static(expansion_str, EXPANSION_STATS) { Some(e) => e, - None => continue, // expansion doesn't match — try next offset + None => continue, }; return Some(AggSpec { @@ -116,9 +126,6 @@ fn parse_band(name: &str) -> Option { None } -/// Parse a leading `(.)?` modifier from `s`. Returns -/// `(parsed_value, rest_of_string)`. If no leading digits, returns -/// `(1.0, s)` — modifier defaults to 1. fn parse_mod_and_remainder(s: &str) -> (f64, &str) { let mut idx = 0; let bytes = s.as_bytes(); @@ -131,7 +138,6 @@ fn parse_mod_and_remainder(s: &str) -> (f64, &str) { after_dot += 1; } if after_dot > idx + 1 { - // need at least one digit after '.' idx = after_dot; } } @@ -143,41 +149,119 @@ fn parse_mod_and_remainder(s: &str) -> (f64, &str) { (value, &s[idx..]) } -/// Validate the `aggregate` SETTING value: null, a single function name, or -/// an array of function names. Each name must be parseable by `parse_agg_name`. -pub fn validate_aggregate_param(value: &ParameterValue) -> std::result::Result<(), String> { - use crate::plot::types::ArrayElement; - match value { - ParameterValue::Null => Ok(()), - ParameterValue::String(s) => validate_function_name(s), +// ============================================================================= +// AggregateSpec — parsed representation of the `aggregate` SETTING. +// ============================================================================= + +/// Parsed `aggregate` SETTING: zero, one, or two unprefixed defaults plus an +/// optional set of per-aesthetic targets keyed by user-facing aesthetic name. +#[derive(Debug, Clone, PartialEq)] +pub struct AggregateSpec { + pub default_lower: Option, + pub default_upper: Option, + /// Targets keyed by user-facing aesthetic name (e.g. `"y"`, `"xmax"`, + /// `"color"`). Resolved to internal names at apply-time. + pub targets: HashMap, +} + +impl AggregateSpec { + fn new() -> Self { + Self { + default_lower: None, + default_upper: None, + targets: HashMap::new(), + } + } +} + +/// Parse the `aggregate` SETTING value into an `AggregateSpec`. Returns `Ok(None)` +/// when the parameter is unset or null. Returns `Err(...)` for malformed input. +pub fn parse_aggregate_param( + value: &ParameterValue, +) -> std::result::Result, String> { + let entries: Vec<&str> = match value { + ParameterValue::Null => return Ok(None), + ParameterValue::String(s) => vec![s.as_str()], ParameterValue::Array(arr) => { + let mut out = Vec::with_capacity(arr.len()); for el in arr { match el { - ArrayElement::String(s) => validate_function_name(s)?, + ArrayElement::String(s) => out.push(s.as_str()), ArrayElement::Null => continue, _ => { return Err("'aggregate' array entries must be strings or null".to_string()); } } } - Ok(()) + if out.is_empty() { + return Ok(None); + } + out + } + _ => return Err("'aggregate' must be a string, array of strings, or null".to_string()), + }; + + let mut spec = AggregateSpec::new(); + for entry in entries { + if let Some((aes, func)) = split_target(entry) { + if aes.is_empty() { + return Err(format!("'{}': aesthetic prefix is empty", entry)); + } + if func.is_empty() { + return Err(format!("'{}': aggregate function is empty", entry)); + } + let agg = parse_agg_name(func).ok_or_else(|| { + format!( + "'{}': {}", + entry, + diagnose_invalid_function_name(func) + ) + })?; + if spec.targets.contains_key(aes) { + return Err(format!( + "aesthetic '{}' is targeted by more than one aggregate", + aes + )); + } + spec.targets.insert(aes.to_string(), agg); + } else { + let agg = parse_agg_name(entry) + .ok_or_else(|| diagnose_invalid_function_name(entry))?; + if spec.default_lower.is_none() { + spec.default_lower = Some(agg); + } else if spec.default_upper.is_none() { + spec.default_upper = Some(agg); + } else { + return Err(format!( + "'aggregate' accepts at most two unprefixed defaults; got a third: '{}'", + entry + )); + } } - _ => Err("'aggregate' must be a string, array of strings, or null".to_string()), } -} -fn validate_function_name(name: &str) -> std::result::Result<(), String> { - match parse_agg_name(name) { - Some(_) => Ok(()), - None => Err(diagnose_invalid_function_name(name)), + if spec.default_lower.is_none() && spec.default_upper.is_none() && spec.targets.is_empty() { + return Ok(None); } + Ok(Some(spec)) +} + +/// Split an entry into `(aesthetic, function)` if it contains a `:`. Returns +/// `None` for an unprefixed entry like `'mean'`. +fn split_target(entry: &str) -> Option<(&str, &str)> { + entry.split_once(':') +} + +/// Validate the `aggregate` SETTING value at parse-time. Used by +/// `Layer::validate_settings`. Aesthetic-name resolution is deferred to +/// `apply()` because `AestheticContext` isn't available here. +pub fn validate_aggregate_param(value: &ParameterValue) -> std::result::Result<(), String> { + parse_aggregate_param(value).map(|_| ()) } /// Build a per-role error message for a name that didn't parse. Re-walks the /// input with looser rules to identify which side (offset / expansion) failed. fn diagnose_invalid_function_name(name: &str) -> String { - // Look for a sign character. If there is one, examine the offset and - // expansion halves separately. if let Some(sign_idx) = name.find(['+', '-']) { let offset_str = &name[..sign_idx]; let after_sign = &name[sign_idx + 1..]; @@ -188,7 +272,6 @@ fn diagnose_invalid_function_name(name: &str) -> String { let expansion_known_band = EXPANSION_STATS.contains(&expansion_str); if !offset_known_band { - // The offset half is the problem. if offset_known_simple { return format!( "'{}': '{}' is not a valid offset stat. Allowed offsets: {}", @@ -212,8 +295,6 @@ fn diagnose_invalid_function_name(name: &str) -> String { crate::or_list_quoted(EXPANSION_STATS, '\''), ); } - // Both halves are individually valid but band parsing failed for some - // other reason (e.g. malformed modifier). return format!("'{}' is not a valid aggregate function name", name); } format!( @@ -223,147 +304,9 @@ fn diagnose_invalid_function_name(name: &str) -> String { ) } -/// Apply the Aggregate stat to a layer query. -/// -/// Returns `StatResult::Identity` when the `aggregate` parameter is unset or null. -/// Otherwise, builds a grouped-aggregation query and returns `StatResult::Transformed`. -/// -/// Strategy: -/// - **Single-pass** (preferred): one `GROUP BY` produces a wide row per group, then -/// `CROSS JOIN VALUES(...)` of function names explodes to one row per (group × function). -/// Used when all requested functions are inline-able. -/// - **UNION ALL fallback**: when a quantile is requested but the dialect doesn't -/// provide `sql_quantile_inline`, fall back to per-function subqueries using -/// `dialect.sql_percentile`. -#[allow(clippy::too_many_arguments)] -pub fn apply( - query: &str, - schema: &Schema, - aesthetics: &Mappings, - group_by: &[String], - parameters: &HashMap, - dialect: &dyn SqlDialect, - agg_slots: &[u8], - range_pair: Option<(&'static str, &'static str)>, -) -> Result { - let funcs = match extract_aggregate_param(parameters) { - None => return Ok(StatResult::Identity), - Some(funcs) => funcs, - }; - - if let Some((lo, hi)) = range_pair { - return apply_range_mode(query, schema, aesthetics, group_by, &funcs, dialect, lo, hi); - } - - // Walk the layer's position aesthetics and route each by (slot, type): - // in-axis slot && numeric → aggregated (numeric_pos) - // in-axis slot && discrete → kept as group column (kept_pos_cols) - // out-of-axis (any type) → kept as group column (kept_pos_cols) - let mut numeric_pos: Vec<(String, String)> = Vec::new(); // (aesthetic, prefixed col) - let mut kept_pos_cols: Vec = Vec::new(); - for (aesthetic, value) in &aesthetics.aesthetics { - if !is_position_aesthetic(aesthetic) { - continue; - } - let col = match value.column_name() { - Some(c) => c.to_string(), - None => continue, - }; - let slot = parse_position(aesthetic).map(|(s, _)| s).unwrap_or(0); - let in_axis = agg_slots.contains(&slot); - let info = schema.iter().find(|c| c.name == col); - let is_discrete = info.map(|c| c.is_discrete).unwrap_or(false); - - if !in_axis || is_discrete { - kept_pos_cols.push(col); - } else { - numeric_pos.push((aesthetic.clone(), col)); - } - } - numeric_pos.sort_by(|a, b| a.0.cmp(&b.0)); - kept_pos_cols.sort(); - - if numeric_pos.is_empty() && !funcs.iter().any(|f| f == "count") { - return Err(GgsqlError::ValidationError( - "aggregate requires at least one numeric position aesthetic, or the 'count' function" - .to_string(), - )); - } - - // Group columns: PARTITION BY + discrete mappings (already in group_by) + any - // position-aesthetic columns we kept (out-of-axis or in-axis-but-discrete). - // Deduplicated, preserving order. - let mut group_cols: Vec = Vec::new(); - for g in group_by { - if !group_cols.contains(g) { - group_cols.push(g.clone()); - } - } - for c in &kept_pos_cols { - if !group_cols.contains(c) { - group_cols.push(c.clone()); - } - } - - let needs_count_col = funcs.iter().any(|f| f == "count"); - - // Decide strategy: single-pass when every percentile component can be inlined. - let probe = numeric_pos - .first() - .map(|(_, c)| c.as_str()) - .unwrap_or("__ggsql_probe__"); - let needs_fallback = funcs.iter().any(|f| { - parse_agg_name(f) - .map(|spec| needs_quantile_fallback(&spec, probe, dialect)) - .unwrap_or(false) - }); - - let transformed_query = if needs_fallback { - build_union_all_query(query, &funcs, &numeric_pos, &group_cols, dialect) - } else { - build_single_pass_query(query, &funcs, &numeric_pos, &group_cols, dialect) - }; - - let mut stat_columns: Vec = numeric_pos.iter().map(|(a, _)| a.clone()).collect(); - stat_columns.push("aggregate".to_string()); - if needs_count_col { - stat_columns.push("count".to_string()); - } - - let consumed_aesthetics: Vec = numeric_pos.into_iter().map(|(a, _)| a).collect(); - - Ok(StatResult::Transformed { - query: transformed_query, - stat_columns, - dummy_columns: vec![], - consumed_aesthetics, - }) -} - -/// Extract the `aggregate` parameter as a list of function names, or `None` when -/// the parameter is unset/null. -fn extract_aggregate_param(parameters: &HashMap) -> Option> { - use crate::plot::types::ArrayElement; - match parameters.get("aggregate") { - None | Some(ParameterValue::Null) => None, - Some(ParameterValue::String(s)) => Some(vec![s.clone()]), - Some(ParameterValue::Array(arr)) => { - let names: Vec = arr - .iter() - .filter_map(|el| match el { - ArrayElement::String(s) => Some(s.clone()), - _ => None, - }) - .collect(); - if names.is_empty() { - None - } else { - Some(names) - } - } - _ => None, - } -} +// ============================================================================= +// SQL fragment helpers (per-column aggregate expressions). +// ============================================================================= /// Map a percentile function name (`p05`..`p95`, `median`) to its fraction. fn percentile_fraction(func: &str) -> Option { @@ -380,14 +323,13 @@ fn percentile_fraction(func: &str) -> Option { } /// Build the inline SQL fragment for a *simple* stat (no band) applied to a -/// quoted column. -/// -/// Returns `None` for `count` (which doesn't take a column) and for percentile- -/// based stats (`p05..p95`, `median`, `iqr`) when the dialect lacks an inline -/// quantile aggregate (caller should switch to UNION ALL strategy). +/// quoted column. Returns `None` for percentile-based stats when the dialect +/// lacks an inline quantile aggregate (caller switches to the correlated +/// `sql_percentile` fallback). fn simple_stat_sql_inline(name: &str, qcol: &str, dialect: &dyn SqlDialect) -> Option { if name == "count" { - return None; + // `count` in this position is COUNT(col): non-null tally for that column. + return Some(format!("COUNT({})", qcol)); } if let Some(frac) = percentile_fraction(name) { let unquoted = unquote(qcol); @@ -416,8 +358,6 @@ fn simple_stat_sql_inline(name: &str, qcol: &str, dialect: &dyn SqlDialect) -> O }) } -/// Inline SQL for a parsed `AggSpec`. Combines the offset and (optional) -/// expansion halves with the appropriate sign and modifier. fn agg_sql_inline(spec: &AggSpec, qcol: &str, dialect: &dyn SqlDialect) -> Option { let offset_sql = simple_stat_sql_inline(spec.offset, qcol, dialect)?; match &spec.band { @@ -434,8 +374,6 @@ fn agg_sql_inline(spec: &AggSpec, qcol: &str, dialect: &dyn SqlDialect) -> Optio } } -/// Build the SQL fragment `(offset ± mod * exp)`, omitting the `mod *` prefix -/// when `mod_value == 1.0`. fn format_band(offset: &str, sign: char, mod_value: f64, exp: &str) -> String { if mod_value == 1.0 { format!("({} {} {})", offset, sign, exp) @@ -444,9 +382,9 @@ fn format_band(offset: &str, sign: char, mod_value: f64, exp: &str) -> String { } } -/// Fallback SQL for a simple stat. Used by the UNION-ALL path for percentile -/// components (which need correlated `sql_percentile`) and falls through to -/// the inline form for everything else. +/// Fallback SQL for a simple stat — used when a percentile component lacks +/// inline support. Emits a correlated `sql_percentile` subquery; falls +/// through to the inline form for everything else. fn simple_stat_sql_fallback( name: &str, raw_col: &str, @@ -454,9 +392,6 @@ fn simple_stat_sql_fallback( src_alias: &str, group_cols: &[String], ) -> String { - if name == "count" { - return "NULL".to_string(); - } if let Some(frac) = percentile_fraction(name) { return dialect.sql_percentile(raw_col, frac, src_alias, group_cols); } @@ -469,7 +404,6 @@ fn simple_stat_sql_fallback( simple_stat_sql_inline(name, &qcol, dialect).unwrap_or_else(|| "NULL".to_string()) } -/// Fallback SQL for a parsed `AggSpec` (UNION-ALL path). fn agg_sql_fallback( spec: &AggSpec, raw_col: &str, @@ -488,8 +422,6 @@ fn agg_sql_fallback( } } -/// Whether this spec has any percentile component that the dialect can't -/// inline (in which case the caller must use the UNION-ALL fallback). fn needs_quantile_fallback(spec: &AggSpec, probe_col: &str, dialect: &dyn SqlDialect) -> bool { if simple_needs_fallback(spec.offset, probe_col, dialect) { return true; @@ -512,418 +444,254 @@ fn simple_needs_fallback(name: &str, probe_col: &str, dialect: &dyn SqlDialect) false } -/// Strip surrounding double quotes from an identifier, undoing `naming::quote_ident`. fn unquote(qcol: &str) -> String { let trimmed = qcol.trim_start_matches('"').trim_end_matches('"'); trimmed.replace("\"\"", "\"") } -/// SQL for a function name literal, properly escaped. -fn func_literal(func: &str) -> String { - format!("'{}'", func.replace('\'', "''")) -} - // ============================================================================= -// Range-mode strategy: exactly two functions filling a (lower, upper) aesthetic -// pair on the same row. Used by ribbon/range. +// apply — entry point. // ============================================================================= +/// Resolve a user-facing target aesthetic name to one or more internal names +/// that are actually mapped on the layer. Handles three cases: +/// 1. The name maps directly through `AestheticContext` (e.g. `y` → `pos2`). +/// 2. The name is an alias from `AESTHETIC_ALIASES` (e.g. `color` → `stroke`, +/// `fill`); each target whose internal counterpart is mapped is included. +/// 3. The name is a material aesthetic with the same internal name (e.g. `size`). +/// +/// Returns the empty vector if no resolution finds a mapped aesthetic. +fn resolve_target_aesthetic( + user_aes: &str, + aesthetics: &Mappings, + aesthetic_ctx: &AestheticContext, +) -> Vec { + use crate::plot::layer::geom::types::AESTHETIC_ALIASES; + let mut out = Vec::new(); + if let Some(internal) = aesthetic_ctx.map_user_to_internal(user_aes) { + if aesthetics.aesthetics.contains_key(internal) { + out.push(internal.to_string()); + return out; + } + } + for (alias, targets) in AESTHETIC_ALIASES { + if *alias == user_aes { + for t in *targets { + let internal = aesthetic_ctx + .map_user_to_internal(t) + .map(|s| s.to_string()) + .unwrap_or_else(|| (*t).to_string()); + if aesthetics.aesthetics.contains_key(&internal) && !out.contains(&internal) { + out.push(internal); + } + } + return out; + } + } + if aesthetics.aesthetics.contains_key(user_aes) { + out.push(user_aes.to_string()); + } + out +} + +/// Classify an internal aesthetic name as upper-half or lower-half for the +/// purpose of default-aggregate routing. +/// +/// `min` suffix → lower; `max`/`end` → upper; no suffix → lower. Material +/// aesthetics (no position prefix) are always lower. +fn is_upper_half(internal_aes: &str) -> bool { + internal_aes.ends_with("max") || internal_aes.ends_with("end") +} + +/// Apply the Aggregate stat to a layer query. +/// +/// Returns `StatResult::Identity` when the `aggregate` parameter is unset, null, +/// or empty. Otherwise, builds a single-pass `GROUP BY` query producing one row +/// per group with one aggregated column per kept numeric mapping. #[allow(clippy::too_many_arguments)] -fn apply_range_mode( +pub fn apply( query: &str, schema: &Schema, aesthetics: &Mappings, group_by: &[String], - funcs: &[String], + parameters: &HashMap, dialect: &dyn SqlDialect, - lo: &'static str, - hi: &'static str, + aesthetic_ctx: &AestheticContext, ) -> Result { - if funcs.len() != 2 { - return Err(GgsqlError::ValidationError(format!( - "aggregate on a range geom must be an array of exactly two functions (lower, upper), got {}", - funcs.len() - ))); + let raw = match parameters.get("aggregate") { + None | Some(ParameterValue::Null) => return Ok(StatResult::Identity), + Some(v) => v, + }; + let spec = parse_aggregate_param(raw) + .map_err(GgsqlError::ValidationError)?; + let spec = match spec { + Some(s) => s, + None => return Ok(StatResult::Identity), + }; + + // Resolve target keys (user-facing) → internal aesthetic names. An alias + // like `color` expands to whichever of its targets (stroke/fill) is mapped + // on the layer; the function applies to all of them. + let mut targets_internal: HashMap = HashMap::new(); + for (user_aes, agg) in &spec.targets { + let resolved = resolve_target_aesthetic(user_aes, aesthetics, aesthetic_ctx); + if resolved.is_empty() { + return Err(GgsqlError::ValidationError(format!( + "aggregate target '{}' is not mapped on this layer", + user_aes + ))); + } + for internal in resolved { + if targets_internal.contains_key(&internal) { + return Err(GgsqlError::ValidationError(format!( + "aggregate target '{}' resolves to aesthetic '{}' which is already targeted", + user_aes, internal + ))); + } + targets_internal.insert(internal, agg.clone()); + } } - // Range mode requires `pos2` mapped to a numeric input column. The user - // writes `MAPPING value AS y` and the stat consumes it to produce both - // bounds. - let input_col = match aesthetics.get("pos2").and_then(|v| v.column_name()) { - Some(c) => c.to_string(), - None => { - return Err(GgsqlError::ValidationError( - "aggregate on a range geom requires a `y` (pos2) mapping as the input column" - .to_string(), - )); + // Walk mappings. Three buckets: + // - aggregated: (internal_aes, raw_col, AggSpec) — each emits one column + // - kept_cols: discrete column-mappings — keep as group key + // - dropped: numeric mapping with no applicable function (warn & skip) + let mut aggregated: Vec<(String, String, AggSpec)> = Vec::new(); + let mut kept_cols: Vec = Vec::new(); + let mut dropped: Vec = Vec::new(); + + let mut entries: Vec<(&String, &crate::AestheticValue)> = aesthetics.aesthetics.iter().collect(); + entries.sort_by(|a, b| a.0.cmp(b.0)); + + for (aes, value) in entries { + let col = match value.column_name() { + Some(c) => c.to_string(), + None => continue, // literals & annotation columns pass through + }; + let info = schema.iter().find(|c| c.name == col); + let is_discrete = info.map(|c| c.is_discrete).unwrap_or(false); + if is_discrete { + if !kept_cols.contains(&col) { + kept_cols.push(col); + } + continue; } - }; - let info = schema.iter().find(|c| c.name == input_col); - if info.map(|c| c.is_discrete).unwrap_or(false) { - return Err(GgsqlError::ValidationError( - "aggregate on a range geom requires a numeric `y` (pos2) input, not a discrete column" - .to_string(), - )); + + // Numeric mapping. Look up the aggregation function. + let agg = if let Some(targeted) = targets_internal.get(aes) { + Some(targeted.clone()) + } else if is_upper_half(aes) { + spec.default_upper + .clone() + .or_else(|| spec.default_lower.clone()) + .filter(|_| spec.default_upper.is_some() || spec.default_lower.is_some()) + } else { + spec.default_lower.clone() + }; + + match agg { + Some(a) => aggregated.push((aes.clone(), col, a)), + None => dropped.push(aes.clone()), + } + } + + // The *only* time we have nothing to aggregate but should still transform + // is when defaults exist but every numeric mapping was dropped — we still + // emit a GROUP BY to honour the grouping. If there are no aggregations and + // no kept columns and no group_by, return Identity. + if aggregated.is_empty() && kept_cols.is_empty() && group_by.is_empty() { + for d in &dropped { + eprintln!( + "Warning: aggregate dropped numeric mapping for aesthetic '{}' (no applicable default and no targeted function)", + aesthetic_ctx.map_internal_to_user(d) + ); + } + return Ok(StatResult::Identity); + } + + for d in &dropped { + eprintln!( + "Warning: aggregate dropped numeric mapping for aesthetic '{}' (no applicable default and no targeted function)", + aesthetic_ctx.map_internal_to_user(d) + ); } - let qcol = naming::quote_ident(&input_col); - // Group columns: PARTITION BY + discrete mappings (already in group_by) + - // any discrete position aesthetics on the layer (e.g. pos1 if it's a string). + // Group columns: PARTITION BY + discrete column-mappings, deduped. let mut group_cols: Vec = Vec::new(); for g in group_by { if !group_cols.contains(g) { group_cols.push(g.clone()); } } - for (aesthetic, value) in &aesthetics.aesthetics { - if !is_position_aesthetic(aesthetic) || aesthetic == "pos2" { - continue; - } - let col = match value.column_name() { - Some(c) => c.to_string(), - None => continue, - }; - if !group_cols.contains(&col) { - group_cols.push(col); + for c in &kept_cols { + if !group_cols.contains(c) { + group_cols.push(c.clone()); } } - let src_alias = "\"__ggsql_stat_src__\""; - let group_by_clause = if group_cols.is_empty() { - String::new() - } else { - let qcols: Vec = group_cols.iter().map(|c| naming::quote_ident(c)).collect(); - format!(" GROUP BY {}", qcols.join(", ")) - }; - - // Parse and emit each bound. Use the inline form when the dialect supports - // every percentile component; otherwise fall back to `sql_percentile` - // correlated to the outer alias used in the FROM (`__ggsql_qt__`). - let lo_expr = build_range_function_sql(&funcs[0], &qcol, &input_col, dialect, &group_cols)?; - let hi_expr = build_range_function_sql(&funcs[1], &qcol, &input_col, dialect, &group_cols)?; + let transformed_query = + build_group_by_query(query, &aggregated, &group_cols, dialect); - let stat_lo = naming::stat_column(lo); - let stat_hi = naming::stat_column(hi); - - let group_select: Vec = group_cols.iter().map(|c| naming::quote_ident(c)).collect(); - let mut select_parts = group_select.clone(); - select_parts.push(format!("{} AS {}", lo_expr, naming::quote_ident(&stat_lo))); - select_parts.push(format!("{} AS {}", hi_expr, naming::quote_ident(&stat_hi))); - - let transformed_query = format!( - "WITH {src} AS ({query}) SELECT {sel} FROM {src} AS \"__ggsql_qt__\"{gb}", - src = src_alias, - query = query, - sel = select_parts.join(", "), - gb = group_by_clause, - ); + let stat_columns: Vec = aggregated.iter().map(|(a, _, _)| a.clone()).collect(); + let consumed_aesthetics: Vec = stat_columns.clone(); - // consumed_aesthetics: pos2 carries the original-name capture for axis - // labels; lo/hi flag the auto-rename in execute/layer.rs (their stat-column - // names match the position aesthetics they fill). Ok(StatResult::Transformed { query: transformed_query, - stat_columns: vec![lo.to_string(), hi.to_string()], + stat_columns, dummy_columns: vec![], - consumed_aesthetics: vec!["pos2".to_string(), lo.to_string(), hi.to_string()], + consumed_aesthetics, }) } -/// Build the SQL fragment for one function in range mode. Parses the function -/// name into an `AggSpec` (which validates the offset/expansion vocabulary) -/// and emits inline SQL when the dialect supports every percentile component, -/// otherwise the correlated fallback. -fn build_range_function_sql( - func: &str, - qcol: &str, - raw_col: &str, - dialect: &dyn SqlDialect, - group_cols: &[String], -) -> Result { - if func == "count" { - return Err(GgsqlError::ValidationError( - "aggregate on a range geom does not support 'count' (it has no range semantics)" - .to_string(), - )); - } - let spec = parse_agg_name(func).ok_or_else(|| { - GgsqlError::ValidationError(format!( - "aggregate on a range geom: {}", - diagnose_invalid_function_name(func) - )) - })?; - if needs_quantile_fallback(&spec, raw_col, dialect) { - Ok(agg_sql_fallback( - &spec, - raw_col, - dialect, - "\"__ggsql_stat_src__\"", - group_cols, - )) - } else { - agg_sql_inline(&spec, qcol, dialect).ok_or_else(|| { - GgsqlError::ValidationError(format!( - "aggregate on a range geom does not support function '{}' on this dialect", - func - )) - }) - } -} - -// ============================================================================= -// Single-pass strategy: GROUP BY produces a wide CTE, then CROSS JOIN explodes -// rows per requested function. -// ============================================================================= - -fn build_single_pass_query( +/// Build the `WITH src AS () SELECT , FROM src +/// AS "__ggsql_qt__" GROUP BY ` query. +/// +/// Falls back to `dialect.sql_percentile()` per-column when an aggregate's +/// percentile component lacks inline support. +fn build_group_by_query( query: &str, - funcs: &[String], - numeric_pos: &[(String, String)], + aggregated: &[(String, String, AggSpec)], group_cols: &[String], dialect: &dyn SqlDialect, ) -> String { let src_alias = "\"__ggsql_stat_src__\""; - let agg_alias = "\"__ggsql_stat_agg__\""; - let funcs_alias = "\"__ggsql_stat_funcs__\""; + let outer_alias = "\"__ggsql_qt__\""; + let group_select: Vec = group_cols.iter().map(|c| naming::quote_ident(c)).collect(); let group_by_clause = if group_cols.is_empty() { String::new() } else { - let qcols: Vec = group_cols.iter().map(|c| naming::quote_ident(c)).collect(); - format!(" GROUP BY {}", qcols.join(", ")) + format!(" GROUP BY {}", group_select.join(", ")) }; - // Build the wide aggregation SELECT: one column per (function × position). - let mut wide_select_exprs: Vec = - group_cols.iter().map(|c| naming::quote_ident(c)).collect(); - - // Track the synthetic column names for each (aesthetic, function) pair. - let mut wide_col_for: HashMap<(String, String), String> = HashMap::new(); - - for (aes, col) in numeric_pos { - let qcol = naming::quote_ident(col); - for func in funcs { - if func == "count" { - continue; - } - let key = (aes.clone(), func.clone()); - if wide_col_for.contains_key(&key) { - continue; - } - let wide_name = synthetic_col_name(aes, func); - let spec = parse_agg_name(func) - .expect("aggregate function names are validated upstream of single-pass"); - let expr = agg_sql_inline(&spec, &qcol, dialect) - .expect("agg_sql_inline must be Some when single-pass is selected"); - wide_select_exprs.push(format!("{} AS {}", expr, naming::quote_ident(&wide_name))); - wide_col_for.insert(key, wide_name); - } - } - - let needs_count_col = funcs.iter().any(|f| f == "count"); - let count_wide = if needs_count_col { - let c = "__ggsql_stat_cnt__"; - wide_select_exprs.push(format!("COUNT(*) AS {}", naming::quote_ident(c))); - Some(c.to_string()) - } else { - None - }; + let mut select_parts: Vec = group_select.clone(); - let wide_select = wide_select_exprs.join(", "); - - // Build the CROSS JOIN VALUES table of function names. - let funcs_values: Vec = funcs - .iter() - .map(|f| format!("({})", func_literal(f))) - .collect(); - let funcs_cte = format!( - "{}(name) AS (VALUES {})", - funcs_alias, - funcs_values.join(", ") - ); - - // Build the outer SELECT: group cols + per-aesthetic CASE + count CASE + name AS aggregate. - let mut outer_exprs: Vec = group_cols - .iter() - .map(|c| format!("{}.{}", agg_alias, naming::quote_ident(c))) - .collect(); - - for (aes, _) in numeric_pos { + for (aes, raw_col, agg) in aggregated { let stat_col = naming::stat_column(aes); - let mut whens: Vec = Vec::new(); - for func in funcs { - if let Some(wide_name) = wide_col_for.get(&(aes.clone(), func.clone())) { - whens.push(format!( - "WHEN {} THEN {}.{}", - func_literal(func), - agg_alias, - naming::quote_ident(wide_name) - )); - } - } - let case_expr = if whens.is_empty() { - "NULL".to_string() + let qcol = naming::quote_ident(raw_col); + let expr = if needs_quantile_fallback(agg, raw_col, dialect) { + agg_sql_fallback(agg, raw_col, dialect, src_alias, group_cols) } else { - format!( - "CASE {}.name {} ELSE NULL END", - funcs_alias, - whens.join(" ") - ) + agg_sql_inline(agg, &qcol, dialect) + .expect("agg_sql_inline must succeed when needs_quantile_fallback is false") }; - outer_exprs.push(format!( - "{} AS {}", - case_expr, - naming::quote_ident(&stat_col) - )); - } - - if let Some(count_wide) = count_wide { - let stat_col = naming::stat_column("count"); - let case_expr = format!( - "CASE {f}.name WHEN {lit} THEN {a}.{c} ELSE NULL END", - f = funcs_alias, - a = agg_alias, - lit = func_literal("count"), - c = naming::quote_ident(&count_wide) - ); - outer_exprs.push(format!( - "{} AS {}", - case_expr, - naming::quote_ident(&stat_col) - )); + select_parts.push(format!("{} AS {}", expr, naming::quote_ident(&stat_col))); } - let stat_aggregate_col = naming::stat_column("aggregate"); - outer_exprs.push(format!( - "{}.name AS {}", - funcs_alias, - naming::quote_ident(&stat_aggregate_col) - )); - - format!( - "WITH {src} AS ({query}), \ - {agg_alias_def} AS (SELECT {wide_select} FROM {src}{group_by}), \ - {funcs_cte} \ - SELECT {outer} FROM {agg} CROSS JOIN {funcs}", - src = src_alias, - query = query, - agg_alias_def = agg_alias, - wide_select = wide_select, - group_by = group_by_clause, - funcs_cte = funcs_cte, - outer = outer_exprs.join(", "), - agg = agg_alias, - funcs = funcs_alias, - ) -} - -/// Synthetic name for a (aesthetic, function) intermediate column in the wide CTE. -/// Includes a sanitized form of the function name to avoid collisions on `+`/`-`. -fn synthetic_col_name(aes: &str, func: &str) -> String { - let safe: String = func - .chars() - .map(|c| match c { - '+' => 'p', - '-' => 'm', - _ if c.is_ascii_alphanumeric() => c, - _ => '_', - }) - .collect(); - format!("__ggsql_stat_{}_{}", aes, safe) -} - -// ============================================================================= -// UNION ALL fallback strategy: one SELECT per requested function. -// ============================================================================= - -fn build_union_all_query( - query: &str, - funcs: &[String], - numeric_pos: &[(String, String)], - group_cols: &[String], - dialect: &dyn SqlDialect, -) -> String { - let src_alias = "\"__ggsql_stat_src__\""; - - let group_by_clause = if group_cols.is_empty() { - String::new() - } else { - let qcols: Vec = group_cols.iter().map(|c| naming::quote_ident(c)).collect(); - format!(" GROUP BY {}", qcols.join(", ")) - }; - - let group_select: Vec = group_cols.iter().map(|c| naming::quote_ident(c)).collect(); - - let needs_count_col = funcs.iter().any(|f| f == "count"); - let stat_aggregate_col = naming::stat_column("aggregate"); - let stat_count_col = naming::stat_column("count"); - - let branches: Vec = funcs - .iter() - .map(|func| { - let mut select_parts: Vec = group_select.clone(); - - // Parse the function name once per branch. Falls through to a - // string-NULL value column if parsing fails (shouldn't happen - // because validation runs upstream, but stay defensive). - let parsed_spec = parse_agg_name(func); - for (aes, col) in numeric_pos { - let stat_col = naming::stat_column(aes); - let value_expr = if func == "count" { - "NULL".to_string() - } else if let Some(spec) = &parsed_spec { - agg_sql_fallback(spec, col, dialect, src_alias, group_cols) - } else { - "NULL".to_string() - }; - select_parts.push(format!( - "{} AS {}", - value_expr, - naming::quote_ident(&stat_col) - )); - } - - if needs_count_col { - let value_expr = if func == "count" { - "COUNT(*)".to_string() - } else { - "NULL".to_string() - }; - select_parts.push(format!( - "{} AS {}", - value_expr, - naming::quote_ident(&stat_count_col) - )); - } - - select_parts.push(format!( - "{} AS {}", - func_literal(func), - naming::quote_ident(&stat_aggregate_col) - )); - - // Quantile fallbacks (sql_percentile) need the outer alias `__ggsql_qt__` - // so their correlated WHERE clause can find group columns. - format!( - "SELECT {} FROM {} AS \"__ggsql_qt__\"{}", - select_parts.join(", "), - src_alias, - group_by_clause - ) - }) - .collect(); - format!( - "WITH {src} AS ({query}) {branches}", + "WITH {src} AS ({query}) SELECT {sel} FROM {src} AS {outer}{gb}", src = src_alias, query = query, - branches = branches.join(" UNION ALL ") + sel = select_parts.join(", "), + outer = outer_alias, + gb = group_by_clause, ) } #[cfg(test)] mod tests { use super::*; + use crate::plot::aesthetic::AestheticContext; use crate::plot::types::{AestheticValue, ColumnInfo}; use arrow::datatypes::DataType; @@ -939,7 +707,8 @@ mod tests { } } - /// A test dialect with no inline quantile support, exercising the UNION ALL fallback. + /// A test dialect with no inline quantile support, exercising the + /// per-column `sql_percentile` fallback. struct NoInlineQuantileDialect; impl SqlDialect for NoInlineQuantileDialect {} @@ -951,410 +720,155 @@ mod tests { } } - fn numeric_schema(cols: &[&str]) -> Schema { + fn schema_for(cols: &[(&str, bool)]) -> Schema { cols.iter() - .map(|c| ColumnInfo { - name: c.to_string(), - dtype: DataType::Float64, - is_discrete: false, + .map(|(name, is_discrete)| ColumnInfo { + name: name.to_string(), + dtype: if *is_discrete { + DataType::Utf8 + } else { + DataType::Float64 + }, + is_discrete: *is_discrete, min: None, max: None, }) .collect() } + fn cartesian_ctx() -> AestheticContext { + AestheticContext::from_static(&["x", "y"], &[]) + } + + fn run( + params: ParameterValue, + aes: &Mappings, + schema: &Schema, + group_by: &[String], + dialect: &dyn SqlDialect, + ) -> Result { + let mut p = HashMap::new(); + p.insert("aggregate".to_string(), params); + let ctx = cartesian_ctx(); + apply("SELECT * FROM t", schema, aes, group_by, &p, dialect, &ctx) + } + + fn arr(items: &[&str]) -> ParameterValue { + ParameterValue::Array(items.iter().map(|s| ArrayElement::String(s.to_string())).collect()) + } + + // ---------- parser tests ---------- + #[test] - fn returns_identity_when_param_unset() { - let aes = Mappings::new(); - let schema: Schema = vec![]; - let params = HashMap::new(); - let result = apply( - "SELECT * FROM t", - &schema, - &aes, - &[], - ¶ms, - &InlineQuantileDialect, - &[2], - None, - ) - .unwrap(); - assert_eq!(result, StatResult::Identity); + fn parses_unset_and_null() { + assert_eq!(parse_aggregate_param(&ParameterValue::Null).unwrap(), None); + assert_eq!(parse_aggregate_param(&arr(&[])).unwrap(), None); } #[test] - fn returns_identity_when_param_null() { - let aes = Mappings::new(); - let schema: Schema = vec![]; - let mut params = HashMap::new(); - params.insert("aggregate".to_string(), ParameterValue::Null); - let result = apply( - "SELECT * FROM t", - &schema, - &aes, - &[], - ¶ms, - &InlineQuantileDialect, - &[2], - None, - ) - .unwrap(); - assert_eq!(result, StatResult::Identity); + fn parses_single_default() { + let s = parse_aggregate_param(&ParameterValue::String("mean".to_string())) + .unwrap() + .unwrap(); + assert_eq!(s.default_lower.as_ref().map(|a| a.offset), Some("mean")); + assert!(s.default_upper.is_none()); + assert!(s.targets.is_empty()); } #[test] - fn single_pass_for_mean_emits_avg() { - let mut aes = Mappings::new(); - aes.insert("pos2", col("__ggsql_aes_pos2__")); - let schema = numeric_schema(&["__ggsql_aes_pos2__"]); - let mut params = HashMap::new(); - params.insert( - "aggregate".to_string(), - ParameterValue::String("mean".to_string()), - ); - - let result = apply( - "SELECT * FROM t", - &schema, - &aes, - &[], - ¶ms, - &InlineQuantileDialect, - &[2], - None, - ) - .unwrap(); - - match result { - StatResult::Transformed { - query, - stat_columns, - consumed_aesthetics, - .. - } => { - assert!( - query.contains("AVG(\"__ggsql_aes_pos2__\")"), - "query: {}", - query - ); - assert!(query.contains("CROSS JOIN")); - assert!(stat_columns.contains(&"pos2".to_string())); - assert!(stat_columns.contains(&"aggregate".to_string())); - assert!(!stat_columns.contains(&"count".to_string())); - assert_eq!(consumed_aesthetics, vec!["pos2".to_string()]); - } - _ => panic!("expected Transformed"), - } + fn parses_two_defaults_in_order() { + let s = parse_aggregate_param(&arr(&["min", "max"])).unwrap().unwrap(); + assert_eq!(s.default_lower.as_ref().map(|a| a.offset), Some("min")); + assert_eq!(s.default_upper.as_ref().map(|a| a.offset), Some("max")); } #[test] - fn count_emits_count_star_and_keeps_count_column() { - let mut aes = Mappings::new(); - aes.insert("pos2", col("__ggsql_aes_pos2__")); - let schema = numeric_schema(&["__ggsql_aes_pos2__"]); - let mut params = HashMap::new(); - params.insert( - "aggregate".to_string(), - ParameterValue::String("count".to_string()), - ); - - let result = apply( - "SELECT * FROM t", - &schema, - &aes, - &[], - ¶ms, - &InlineQuantileDialect, - &[2], - None, - ) - .unwrap(); - - match result { - StatResult::Transformed { - query, - stat_columns, - .. - } => { - assert!(query.contains("COUNT(*)")); - assert!(stat_columns.contains(&"count".to_string())); - assert!(stat_columns.contains(&"aggregate".to_string())); - } - _ => panic!("expected Transformed"), - } + fn three_unprefixed_defaults_is_error() { + let err = parse_aggregate_param(&arr(&["mean", "min", "max"])).unwrap_err(); + assert!(err.contains("at most two"), "got: {}", err); } #[test] - fn mixed_count_and_mean_produces_two_rows_per_group() { - let mut aes = Mappings::new(); - aes.insert("pos2", col("__ggsql_aes_pos2__")); - let schema = numeric_schema(&["__ggsql_aes_pos2__"]); - let mut params = HashMap::new(); - use crate::plot::types::ArrayElement; - params.insert( - "aggregate".to_string(), - ParameterValue::Array(vec![ - ArrayElement::String("count".to_string()), - ArrayElement::String("mean".to_string()), - ]), - ); - - let result = apply( - "SELECT * FROM t", - &schema, - &aes, - &[], - ¶ms, - &InlineQuantileDialect, - &[2], - None, - ) - .unwrap(); - match result { - StatResult::Transformed { query, .. } => { - assert!(query.contains("AVG(\"__ggsql_aes_pos2__\")")); - assert!(query.contains("COUNT(*)")); - assert!(query.contains("'count'")); - assert!(query.contains("'mean'")); - // The count CASE must reference the agg CTE for the value column, - // not the funcs CTE (regression: previously emitted funcs.cnt which - // doesn't exist). - assert!( - query.contains("\"__ggsql_stat_agg__\".\"__ggsql_stat_cnt__\""), - "count CASE should reference the agg CTE, query was: {}", - query - ); - } - _ => panic!("expected Transformed"), - } + fn parses_targeted_entries() { + let s = parse_aggregate_param(&arr(&["mean", "y:max", "color:median"])) + .unwrap() + .unwrap(); + assert_eq!(s.default_lower.as_ref().map(|a| a.offset), Some("mean")); + assert_eq!(s.targets.get("y").map(|a| a.offset), Some("max")); + assert_eq!(s.targets.get("color").map(|a| a.offset), Some("median")); } #[test] - fn quantile_uses_dialect_inline_when_available() { - let mut aes = Mappings::new(); - aes.insert("pos2", col("__ggsql_aes_pos2__")); - let schema = numeric_schema(&["__ggsql_aes_pos2__"]); - let mut params = HashMap::new(); - params.insert( - "aggregate".to_string(), - ParameterValue::String("p25".to_string()), - ); - - let result = apply( - "SELECT * FROM t", - &schema, - &aes, - &[], - ¶ms, - &InlineQuantileDialect, - &[2], - None, - ) - .unwrap(); - match result { - StatResult::Transformed { query, .. } => { - assert!(query.contains("QUANTILE_CONT")); - assert!(query.contains("0.25")); - assert!(!query.contains("UNION ALL")); - } - _ => panic!("expected Transformed"), - } + fn duplicate_target_is_error() { + let err = parse_aggregate_param(&arr(&["y:mean", "y:median"])).unwrap_err(); + assert!(err.contains("more than one aggregate"), "got: {}", err); } #[test] - fn quantile_falls_back_to_union_all_without_dialect_support() { - let mut aes = Mappings::new(); - aes.insert("pos2", col("__ggsql_aes_pos2__")); - let schema = numeric_schema(&["__ggsql_aes_pos2__"]); - let mut params = HashMap::new(); - params.insert( - "aggregate".to_string(), - ParameterValue::String("p25".to_string()), - ); - - let result = apply( - "SELECT * FROM t", - &schema, - &aes, - &[], - ¶ms, - &NoInlineQuantileDialect, - &[2], - None, - ) - .unwrap(); - match result { - StatResult::Transformed { query, .. } => { - // Fallback dialect uses NTILE-based correlated subquery via UNION ALL. - assert!(query.contains("NTILE(4)")); - assert!(query.contains("UNION ALL") || !query.contains("CROSS JOIN")); - } - _ => panic!("expected Transformed"), - } + fn empty_prefix_is_error() { + let err = parse_aggregate_param(&ParameterValue::String(":mean".to_string())).unwrap_err(); + assert!(err.contains("aesthetic prefix"), "got: {}", err); } #[test] - fn mean_sdev_emits_avg_and_stddev() { - let mut aes = Mappings::new(); - aes.insert("pos2", col("__ggsql_aes_pos2__")); - let schema = numeric_schema(&["__ggsql_aes_pos2__"]); - let mut params = HashMap::new(); - use crate::plot::types::ArrayElement; - params.insert( - "aggregate".to_string(), - ParameterValue::Array(vec![ - ArrayElement::String("mean-sdev".to_string()), - ArrayElement::String("mean+sdev".to_string()), - ]), - ); - - let result = apply( - "SELECT * FROM t", - &schema, - &aes, - &[], - ¶ms, - &InlineQuantileDialect, - &[2], - None, - ) - .unwrap(); - match result { - StatResult::Transformed { query, .. } => { - assert!(query.contains("STDDEV_POP")); - assert!(query.contains("AVG")); - } - _ => panic!("expected Transformed"), - } + fn unknown_function_is_error() { + let err = parse_aggregate_param(&ParameterValue::String("nope".to_string())).unwrap_err(); + assert!(err.contains("unknown aggregate"), "got: {}", err); } #[test] - fn mean_se_includes_sqrt_count() { - let mut aes = Mappings::new(); - aes.insert("pos2", col("__ggsql_aes_pos2__")); - let schema = numeric_schema(&["__ggsql_aes_pos2__"]); - let mut params = HashMap::new(); - params.insert( - "aggregate".to_string(), - ParameterValue::String("mean+se".to_string()), + fn band_functions_parse() { + let s = parse_aggregate_param(&arr(&["mean-sdev", "mean+sdev"])) + .unwrap() + .unwrap(); + assert_eq!(s.default_lower.as_ref().unwrap().offset, "mean"); + assert_eq!( + s.default_lower.as_ref().unwrap().band.as_ref().unwrap().expansion, + "sdev" ); - - let result = apply( - "SELECT * FROM t", - &schema, - &aes, - &[], - ¶ms, - &InlineQuantileDialect, - &[2], - None, - ) - .unwrap(); - match result { - StatResult::Transformed { query, .. } => { - assert!(query.contains("SQRT(COUNT")); - } - _ => panic!("expected Transformed"), - } + assert_eq!( + s.default_lower.as_ref().unwrap().band.as_ref().unwrap().sign, + '-' + ); + assert_eq!(s.default_upper.as_ref().unwrap().offset, "mean"); } - #[test] - fn prod_emits_exp_sum_ln() { - let mut aes = Mappings::new(); - aes.insert("pos2", col("__ggsql_aes_pos2__")); - let schema = numeric_schema(&["__ggsql_aes_pos2__"]); - let mut params = HashMap::new(); - params.insert( - "aggregate".to_string(), - ParameterValue::String("prod".to_string()), - ); + // ---------- apply tests ---------- - let result = apply( - "SELECT * FROM t", - &schema, - &aes, - &[], - ¶ms, - &InlineQuantileDialect, - &[2], - None, - ) - .unwrap(); - match result { - StatResult::Transformed { query, .. } => { - assert!(query.contains("EXP(SUM(LN")); - } - _ => panic!("expected Transformed"), - } + #[test] + fn returns_identity_when_param_unset() { + let aes = Mappings::new(); + let schema: Schema = vec![]; + let p: HashMap = HashMap::new(); + let ctx = cartesian_ctx(); + let result = apply("SELECT * FROM t", &schema, &aes, &[], &p, &InlineQuantileDialect, &ctx) + .unwrap(); + assert_eq!(result, StatResult::Identity); } #[test] - fn iqr_emits_p75_minus_p25() { - let mut aes = Mappings::new(); - aes.insert("pos2", col("__ggsql_aes_pos2__")); - let schema = numeric_schema(&["__ggsql_aes_pos2__"]); - let mut params = HashMap::new(); - params.insert( - "aggregate".to_string(), - ParameterValue::String("iqr".to_string()), - ); - - let result = apply( - "SELECT * FROM t", - &schema, - &aes, - &[], - ¶ms, - &InlineQuantileDialect, - &[2], - None, - ) - .unwrap(); - match result { - StatResult::Transformed { query, .. } => { - assert!(query.contains("0.75")); - assert!(query.contains("0.25")); - } - _ => panic!("expected Transformed"), - } + fn returns_identity_when_param_null() { + let aes = Mappings::new(); + let schema: Schema = vec![]; + let result = run(ParameterValue::Null, &aes, &schema, &[], &InlineQuantileDialect).unwrap(); + assert_eq!(result, StatResult::Identity); } #[test] - fn discrete_position_aesthetic_becomes_group_column() { + fn single_default_applies_to_every_numeric_mapping() { let mut aes = Mappings::new(); aes.insert("pos1", col("__ggsql_aes_pos1__")); aes.insert("pos2", col("__ggsql_aes_pos2__")); - let schema = vec![ - ColumnInfo { - name: "__ggsql_aes_pos1__".to_string(), - dtype: DataType::Utf8, - is_discrete: true, - min: None, - max: None, - }, - ColumnInfo { - name: "__ggsql_aes_pos2__".to_string(), - dtype: DataType::Float64, - is_discrete: false, - min: None, - max: None, - }, - ]; - let mut params = HashMap::new(); - params.insert( - "aggregate".to_string(), + let schema = schema_for(&[ + ("__ggsql_aes_pos1__", false), + ("__ggsql_aes_pos2__", false), + ]); + let result = run( ParameterValue::String("mean".to_string()), - ); - - let result = apply( - "SELECT * FROM t", - &schema, &aes, + &schema, &[], - ¶ms, &InlineQuantileDialect, - &[2], - None, ) .unwrap(); match result { @@ -1364,907 +878,331 @@ mod tests { consumed_aesthetics, .. } => { - // pos1 (discrete) is in GROUP BY, not aggregated. - assert!(query.contains("GROUP BY \"__ggsql_aes_pos1__\"")); - // pos2 is aggregated. - assert!(query.contains("AVG(\"__ggsql_aes_pos2__\")")); - // Only pos2 is consumed. - assert_eq!(consumed_aesthetics, vec!["pos2".to_string()]); - // Only pos2 (numeric) appears in stat_columns; pos1 stays as-is. + assert!(query.contains("AVG(\"__ggsql_aes_pos1__\")"), "{}", query); + assert!(query.contains("AVG(\"__ggsql_aes_pos2__\")"), "{}", query); + // No GROUP BY when no discrete mappings or PARTITION BY — SQL + // collapses to a single row per query, which is correct. + assert!(!query.contains("CROSS JOIN")); + assert!(!query.contains("UNION ALL")); + assert_eq!(stat_columns.len(), 2); + assert!(stat_columns.contains(&"pos1".to_string())); assert!(stat_columns.contains(&"pos2".to_string())); - assert!(!stat_columns.contains(&"pos1".to_string())); + assert_eq!(consumed_aesthetics.len(), 2); } _ => panic!("expected Transformed"), } } #[test] - fn explicit_group_by_columns_appear_in_query() { + fn two_defaults_split_lower_and_upper_for_segment() { let mut aes = Mappings::new(); + aes.insert("pos1", col("__ggsql_aes_pos1__")); aes.insert("pos2", col("__ggsql_aes_pos2__")); - let schema = numeric_schema(&["__ggsql_aes_pos2__"]); - let mut params = HashMap::new(); - params.insert( - "aggregate".to_string(), - ParameterValue::String("mean".to_string()), - ); - - let result = apply( - "SELECT * FROM t", - &schema, - &aes, - &["region".to_string()], - ¶ms, - &InlineQuantileDialect, - &[2], - None, - ) - .unwrap(); + aes.insert("pos1end", col("__ggsql_aes_pos1end__")); + aes.insert("pos2end", col("__ggsql_aes_pos2end__")); + let schema = schema_for(&[ + ("__ggsql_aes_pos1__", false), + ("__ggsql_aes_pos2__", false), + ("__ggsql_aes_pos1end__", false), + ("__ggsql_aes_pos2end__", false), + ]); + let result = run(arr(&["min", "max"]), &aes, &schema, &[], &InlineQuantileDialect) + .unwrap(); match result { StatResult::Transformed { query, .. } => { - assert!(query.contains("GROUP BY \"region\"")); + // pos1, pos2 use MIN; pos1end, pos2end use MAX. + assert!(query.contains("MIN(\"__ggsql_aes_pos1__\")"), "{}", query); + assert!(query.contains("MIN(\"__ggsql_aes_pos2__\")"), "{}", query); + assert!(query.contains("MAX(\"__ggsql_aes_pos1end__\")"), "{}", query); + assert!(query.contains("MAX(\"__ggsql_aes_pos2end__\")"), "{}", query); + assert!(!query.contains("MIN(\"__ggsql_aes_pos1end__\")")); + assert!(!query.contains("MAX(\"__ggsql_aes_pos1__\")")); } _ => panic!("expected Transformed"), } } #[test] - fn line_style_groups_by_pos1_and_aggregates_pos2() { - // slots=[2]: pos1 stays as group (even though numeric), pos2 gets aggregated. + fn two_defaults_split_for_ribbon() { let mut aes = Mappings::new(); aes.insert("pos1", col("__ggsql_aes_pos1__")); - aes.insert("pos2", col("__ggsql_aes_pos2__")); - let schema = numeric_schema(&["__ggsql_aes_pos1__", "__ggsql_aes_pos2__"]); - let mut params = HashMap::new(); - params.insert( - "aggregate".to_string(), - ParameterValue::String("max".to_string()), - ); - - let result = apply( - "SELECT * FROM t", - &schema, + aes.insert("pos2min", col("__ggsql_aes_pos2min__")); + aes.insert("pos2max", col("__ggsql_aes_pos2max__")); + let schema = schema_for(&[ + ("__ggsql_aes_pos1__", false), + ("__ggsql_aes_pos2min__", false), + ("__ggsql_aes_pos2max__", false), + ]); + let result = run( + arr(&["mean-sdev", "mean+sdev"]), &aes, + &schema, &[], - ¶ms, &InlineQuantileDialect, - &[2], - None, ) .unwrap(); match result { - StatResult::Transformed { - query, - consumed_aesthetics, - stat_columns, - .. - } => { - assert!( - query.contains("MAX(\"__ggsql_aes_pos2__\")"), - "query: {}", - query - ); - assert!(!query.contains("MAX(\"__ggsql_aes_pos1__\")")); - assert!(query.contains("GROUP BY \"__ggsql_aes_pos1__\"")); - assert_eq!(consumed_aesthetics, vec!["pos2".to_string()]); - assert!(stat_columns.contains(&"pos2".to_string())); - assert!(!stat_columns.contains(&"pos1".to_string())); + StatResult::Transformed { query, .. } => { + assert!(query.contains("STDDEV_POP(\"__ggsql_aes_pos2max__\")")); + assert!(query.contains("AVG(\"__ggsql_aes_pos2min__\")")); + // upper default (mean+sdev) goes to pos2max → '+' between AVG and STDDEV + let pos2max_section = query + .split("__ggsql_aes_pos2max__\")") + .next() + .unwrap_or(""); + assert!(pos2max_section.contains('+') || query.contains("+ STDDEV_POP")); } _ => panic!("expected Transformed"), } } #[test] - fn point_style_aggregates_both_slots() { - // slots=[1,2]: both pos1 and pos2 (numeric) get aggregated → centroid. + fn targeted_prefix_overrides_default() { let mut aes = Mappings::new(); aes.insert("pos1", col("__ggsql_aes_pos1__")); aes.insert("pos2", col("__ggsql_aes_pos2__")); - let schema = numeric_schema(&["__ggsql_aes_pos1__", "__ggsql_aes_pos2__"]); - let mut params = HashMap::new(); - params.insert( - "aggregate".to_string(), - ParameterValue::String("mean".to_string()), - ); - - let result = apply( - "SELECT * FROM t", - &schema, + let schema = schema_for(&[ + ("__ggsql_aes_pos1__", false), + ("__ggsql_aes_pos2__", false), + ]); + let result = run( + arr(&["mean", "y:max"]), &aes, + &schema, &[], - ¶ms, &InlineQuantileDialect, - &[1, 2], - None, ) .unwrap(); match result { - StatResult::Transformed { - query, - consumed_aesthetics, - stat_columns, - .. - } => { - assert!( - query.contains("AVG(\"__ggsql_aes_pos1__\")"), - "query: {}", - query - ); - assert!( - query.contains("AVG(\"__ggsql_aes_pos2__\")"), - "query: {}", - query - ); - assert!(!query.contains("GROUP BY \"__ggsql_aes_pos1__\"")); - let mut consumed = consumed_aesthetics.clone(); - consumed.sort(); - assert_eq!(consumed, vec!["pos1".to_string(), "pos2".to_string()]); - assert!(stat_columns.contains(&"pos1".to_string())); - assert!(stat_columns.contains(&"pos2".to_string())); + StatResult::Transformed { query, .. } => { + assert!(query.contains("AVG(\"__ggsql_aes_pos1__\")"), "{}", query); + assert!(query.contains("MAX(\"__ggsql_aes_pos2__\")"), "{}", query); + assert!(!query.contains("AVG(\"__ggsql_aes_pos2__\")")); } _ => panic!("expected Transformed"), } } #[test] - fn range_geom_aggregates_pos2_minmax() { - // slots=[2]: pos1 fixed (group), pos2min and pos2max both aggregated. + fn material_aesthetic_targeted_by_user_facing_name() { let mut aes = Mappings::new(); aes.insert("pos1", col("__ggsql_aes_pos1__")); - aes.insert("pos2min", col("__ggsql_aes_pos2min__")); - aes.insert("pos2max", col("__ggsql_aes_pos2max__")); - let schema = numeric_schema(&[ - "__ggsql_aes_pos1__", - "__ggsql_aes_pos2min__", - "__ggsql_aes_pos2max__", + aes.insert("pos2", col("__ggsql_aes_pos2__")); + aes.insert("size", col("__ggsql_aes_size__")); + let schema = schema_for(&[ + ("__ggsql_aes_pos1__", false), + ("__ggsql_aes_pos2__", false), + ("__ggsql_aes_size__", false), ]); - let mut params = HashMap::new(); - params.insert( - "aggregate".to_string(), - ParameterValue::String("mean".to_string()), - ); - - let result = apply( - "SELECT * FROM t", - &schema, + let result = run( + arr(&["mean", "size:median"]), &aes, + &schema, &[], - ¶ms, &InlineQuantileDialect, - &[2], - None, ) .unwrap(); match result { - StatResult::Transformed { - query, - consumed_aesthetics, - .. - } => { - assert!( - query.contains("AVG(\"__ggsql_aes_pos2min__\")"), - "query: {}", - query - ); - assert!( - query.contains("AVG(\"__ggsql_aes_pos2max__\")"), - "query: {}", - query - ); - assert!(query.contains("GROUP BY \"__ggsql_aes_pos1__\"")); - let mut consumed = consumed_aesthetics.clone(); - consumed.sort(); - assert_eq!(consumed, vec!["pos2max".to_string(), "pos2min".to_string()]); + StatResult::Transformed { query, stat_columns, .. } => { + assert!(query.contains("QUANTILE_CONT(\"__ggsql_aes_size__\", 0.5)")); + assert!(stat_columns.contains(&"size".to_string())); } _ => panic!("expected Transformed"), } } #[test] - fn out_of_axis_numeric_pos_stays_as_group() { - // slots=[2], numeric pos1 → still goes to GROUP BY (not aggregated). - // Same expectation as line_style_groups_by_pos1_and_aggregates_pos2 but - // explicit about the "numeric out-of-axis" path. + fn color_alias_targets_stroke_and_fill() { + // `color` is an alias that resolves to whichever of `stroke`/`fill` + // is actually mapped on the layer. let mut aes = Mappings::new(); aes.insert("pos1", col("__ggsql_aes_pos1__")); aes.insert("pos2", col("__ggsql_aes_pos2__")); - let schema = numeric_schema(&["__ggsql_aes_pos1__", "__ggsql_aes_pos2__"]); - let mut params = HashMap::new(); - params.insert( - "aggregate".to_string(), - ParameterValue::String("mean".to_string()), - ); - - let result = apply( - "SELECT * FROM t", - &schema, + aes.insert("fill", col("__ggsql_aes_fill__")); + let schema = schema_for(&[ + ("__ggsql_aes_pos1__", false), + ("__ggsql_aes_pos2__", false), + ("__ggsql_aes_fill__", false), + ]); + let result = run( + arr(&["mean", "color:max"]), &aes, + &schema, &[], - ¶ms, &InlineQuantileDialect, - &[2], - None, ) .unwrap(); match result { - StatResult::Transformed { query, .. } => { - assert!(query.contains("GROUP BY \"__ggsql_aes_pos1__\"")); + StatResult::Transformed { query, stat_columns, .. } => { + assert!(query.contains("MAX(\"__ggsql_aes_fill__\")"), "{}", query); + assert!(query.contains("AVG(\"__ggsql_aes_pos1__\")")); + assert!(stat_columns.contains(&"fill".to_string())); } _ => panic!("expected Transformed"), } } #[test] - fn discrete_in_axis_pos_stays_as_group_on_centroid_geom() { - // slots=[1,2], pos1 discrete + pos2 numeric → only pos2 aggregated, - // pos1 stays as GROUP BY. Confirms numeric check is preserved on - // slot=[1,2] geoms (e.g. point with category AS x, value AS y). + fn discrete_mapping_becomes_group_key() { let mut aes = Mappings::new(); aes.insert("pos1", col("__ggsql_aes_pos1__")); aes.insert("pos2", col("__ggsql_aes_pos2__")); - let schema = vec![ - ColumnInfo { - name: "__ggsql_aes_pos1__".to_string(), - dtype: DataType::Utf8, - is_discrete: true, - min: None, - max: None, - }, - ColumnInfo { - name: "__ggsql_aes_pos2__".to_string(), - dtype: DataType::Float64, - is_discrete: false, - min: None, - max: None, - }, - ]; - let mut params = HashMap::new(); - params.insert( - "aggregate".to_string(), + aes.insert("color", col("__ggsql_aes_color__")); + let schema = schema_for(&[ + ("__ggsql_aes_pos1__", false), + ("__ggsql_aes_pos2__", false), + ("__ggsql_aes_color__", true), // discrete! + ]); + let result = run( ParameterValue::String("mean".to_string()), - ); - - let result = apply( - "SELECT * FROM t", - &schema, &aes, + &schema, &[], - ¶ms, &InlineQuantileDialect, - &[1, 2], - None, ) .unwrap(); match result { StatResult::Transformed { query, - consumed_aesthetics, stat_columns, .. } => { + assert!(query.contains("GROUP BY \"__ggsql_aes_color__\""), "{}", query); + assert!(!stat_columns.contains(&"color".to_string())); + assert!(query.contains("AVG(\"__ggsql_aes_pos1__\")")); assert!(query.contains("AVG(\"__ggsql_aes_pos2__\")")); - assert!(!query.contains("AVG(\"__ggsql_aes_pos1__\")")); - assert!(query.contains("GROUP BY \"__ggsql_aes_pos1__\"")); - assert_eq!(consumed_aesthetics, vec!["pos2".to_string()]); - assert!(stat_columns.contains(&"pos2".to_string())); - assert!(!stat_columns.contains(&"pos1".to_string())); } _ => panic!("expected Transformed"), } } #[test] - fn count_works_with_no_numeric_pos() { - // slots=[2], only discrete pos1 mapped, aggregate=count → no - // "needs numeric" error; query has COUNT(*) and groups by pos1. + fn literal_mapping_passes_through() { let mut aes = Mappings::new(); aes.insert("pos1", col("__ggsql_aes_pos1__")); - let schema = vec![ColumnInfo { - name: "__ggsql_aes_pos1__".to_string(), - dtype: DataType::Utf8, - is_discrete: true, - min: None, - max: None, - }]; - let mut params = HashMap::new(); - params.insert( - "aggregate".to_string(), - ParameterValue::String("count".to_string()), + aes.insert("pos2", col("__ggsql_aes_pos2__")); + aes.insert( + "fill", + AestheticValue::Literal(ParameterValue::String("steelblue".to_string())), ); - - let result = apply( - "SELECT * FROM t", - &schema, + let schema = schema_for(&[ + ("__ggsql_aes_pos1__", false), + ("__ggsql_aes_pos2__", false), + ]); + let result = run( + ParameterValue::String("mean".to_string()), &aes, + &schema, &[], - ¶ms, &InlineQuantileDialect, - &[2], - None, ) .unwrap(); match result { - StatResult::Transformed { - query, - stat_columns, - .. - } => { - assert!(query.contains("COUNT(*)")); - assert!(query.contains("GROUP BY \"__ggsql_aes_pos1__\"")); - assert!(stat_columns.contains(&"count".to_string())); + StatResult::Transformed { query, .. } => { + assert!(!query.contains("AVG(\"__ggsql_aes_fill__\")")); + assert!(query.contains("AVG(\"__ggsql_aes_pos1__\")")); + assert!(query.contains("AVG(\"__ggsql_aes_pos2__\")")); } _ => panic!("expected Transformed"), } } - // ======================================================================== - // Range-mode tests (ribbon / range) - // ======================================================================== - - fn range_pair() -> Option<(&'static str, &'static str)> { - Some(("pos2min", "pos2max")) - } - - fn range_input_aes_with_group() -> (Mappings, Schema) { + #[test] + fn untargeted_numeric_mapping_dropped_when_no_default() { let mut aes = Mappings::new(); aes.insert("pos1", col("__ggsql_aes_pos1__")); aes.insert("pos2", col("__ggsql_aes_pos2__")); - let schema = vec![ - ColumnInfo { - name: "__ggsql_aes_pos1__".to_string(), - dtype: DataType::Utf8, - is_discrete: true, - min: None, - max: None, - }, - ColumnInfo { - name: "__ggsql_aes_pos2__".to_string(), - dtype: DataType::Float64, - is_discrete: false, - min: None, - max: None, - }, - ]; - (aes, schema) - } - - #[test] - fn range_mode_two_functions_emits_one_row_per_group() { - let (aes, schema) = range_input_aes_with_group(); - let mut params = HashMap::new(); - use crate::plot::types::ArrayElement; - params.insert( - "aggregate".to_string(), - ParameterValue::Array(vec![ - ArrayElement::String("mean-sdev".to_string()), - ArrayElement::String("mean+sdev".to_string()), - ]), - ); - - let result = apply( - "SELECT * FROM t", - &schema, + let schema = schema_for(&[ + ("__ggsql_aes_pos1__", false), + ("__ggsql_aes_pos2__", false), + ]); + // Only `y` targeted, no default → x is dropped. + let result = run( + ParameterValue::String("y:mean".to_string()), &aes, + &schema, &[], - ¶ms, &InlineQuantileDialect, - &[2], - range_pair(), ) .unwrap(); match result { StatResult::Transformed { query, stat_columns, - consumed_aesthetics, .. } => { - assert!( - query.contains( - "AVG(\"__ggsql_aes_pos2__\") - STDDEV_POP(\"__ggsql_aes_pos2__\")" - ), - "lower bound expr missing: {}", - query - ); - assert!( - query.contains( - "AVG(\"__ggsql_aes_pos2__\") + STDDEV_POP(\"__ggsql_aes_pos2__\")" - ), - "upper bound expr missing: {}", - query - ); - assert!(query.contains("GROUP BY \"__ggsql_aes_pos1__\"")); - assert!(!query.contains("UNION ALL")); - assert!(!query.contains("CROSS JOIN")); - // No `aggregate` tag column in range mode. - assert!(!query.contains("__ggsql_stat_aggregate__")); - assert_eq!( - stat_columns, - vec!["pos2min".to_string(), "pos2max".to_string()] - ); - assert!(consumed_aesthetics.contains(&"pos2".to_string())); - assert!(consumed_aesthetics.contains(&"pos2min".to_string())); - assert!(consumed_aesthetics.contains(&"pos2max".to_string())); + assert!(query.contains("AVG(\"__ggsql_aes_pos2__\")")); + assert!(!query.contains("\"__ggsql_aes_pos1__\"")); + assert_eq!(stat_columns, vec!["pos2".to_string()]); } _ => panic!("expected Transformed"), } } #[test] - fn range_mode_rejects_single_function() { - let (aes, schema) = range_input_aes_with_group(); - let mut params = HashMap::new(); - params.insert( - "aggregate".to_string(), - ParameterValue::String("mean".to_string()), - ); - - let result = apply( - "SELECT * FROM t", - &schema, - &aes, - &[], - ¶ms, - &InlineQuantileDialect, - &[2], - range_pair(), - ); - let err = result.unwrap_err().to_string(); - assert!( - err.contains("exactly two"), - "expected 'exactly two' in error, got: {}", - err - ); - } - - #[test] - fn range_mode_rejects_three_functions() { - let (aes, schema) = range_input_aes_with_group(); - let mut params = HashMap::new(); - use crate::plot::types::ArrayElement; - params.insert( - "aggregate".to_string(), - ParameterValue::Array(vec![ - ArrayElement::String("min".to_string()), - ArrayElement::String("mean".to_string()), - ArrayElement::String("max".to_string()), - ]), - ); - - let err = apply( - "SELECT * FROM t", - &schema, + fn quantile_uses_dialect_inline_when_available() { + let mut aes = Mappings::new(); + aes.insert("pos2", col("__ggsql_aes_pos2__")); + let schema = schema_for(&[("__ggsql_aes_pos2__", false)]); + let result = run( + ParameterValue::String("p25".to_string()), &aes, - &[], - ¶ms, - &InlineQuantileDialect, - &[2], - range_pair(), - ) - .unwrap_err() - .to_string(); - assert!(err.contains("exactly two")); - } - - #[test] - fn range_mode_quantile_uses_inline_when_available() { - let (aes, schema) = range_input_aes_with_group(); - let mut params = HashMap::new(); - use crate::plot::types::ArrayElement; - params.insert( - "aggregate".to_string(), - ParameterValue::Array(vec![ - ArrayElement::String("p25".to_string()), - ArrayElement::String("p75".to_string()), - ]), - ); - - let result = apply( - "SELECT * FROM t", &schema, - &aes, &[], - ¶ms, &InlineQuantileDialect, - &[2], - range_pair(), ) .unwrap(); match result { StatResult::Transformed { query, .. } => { assert!(query.contains("QUANTILE_CONT")); assert!(query.contains("0.25")); - assert!(query.contains("0.75")); - assert!(!query.contains("NTILE(4)")); } _ => panic!("expected Transformed"), } } #[test] - fn range_mode_quantile_falls_back_without_dialect_support() { - let (aes, schema) = range_input_aes_with_group(); - let mut params = HashMap::new(); - use crate::plot::types::ArrayElement; - params.insert( - "aggregate".to_string(), - ParameterValue::Array(vec![ - ArrayElement::String("p25".to_string()), - ArrayElement::String("p75".to_string()), - ]), - ); - - let result = apply( - "SELECT * FROM t", - &schema, - &aes, - &[], - ¶ms, - &NoInlineQuantileDialect, - &[2], - range_pair(), - ) - .unwrap(); - match result { - StatResult::Transformed { query, .. } => { - assert!(query.contains("NTILE(4)")); - } - _ => panic!("expected Transformed"), - } - } - - #[test] - fn range_mode_requires_pos2_input() { - // Range geom but pos2 not mapped → error. - let mut aes = Mappings::new(); - aes.insert("pos1", col("__ggsql_aes_pos1__")); - let schema = vec![ColumnInfo { - name: "__ggsql_aes_pos1__".to_string(), - dtype: DataType::Utf8, - is_discrete: true, - min: None, - max: None, - }]; - let mut params = HashMap::new(); - use crate::plot::types::ArrayElement; - params.insert( - "aggregate".to_string(), - ParameterValue::Array(vec![ - ArrayElement::String("mean-sdev".to_string()), - ArrayElement::String("mean+sdev".to_string()), - ]), - ); - - let err = apply( - "SELECT * FROM t", - &schema, - &aes, - &[], - ¶ms, - &InlineQuantileDialect, - &[2], - range_pair(), - ) - .unwrap_err() - .to_string(); - assert!( - err.contains("pos2") || err.contains("`y`"), - "expected pos2/y mention in error, got: {}", - err - ); - } - - // ======================================================================== - // Parser tests (parse_agg_name) - // ======================================================================== - - #[test] - fn parse_simple_names() { - assert_eq!( - parse_agg_name("mean"), - Some(AggSpec { - offset: "mean", - band: None - }) - ); - assert_eq!( - parse_agg_name("count"), - Some(AggSpec { - offset: "count", - band: None - }) - ); - assert_eq!( - parse_agg_name("p25"), - Some(AggSpec { - offset: "p25", - band: None - }) - ); - } - - #[test] - fn parse_band_default_modifier() { - let spec = parse_agg_name("mean+sdev").unwrap(); - assert_eq!(spec.offset, "mean"); - let band = spec.band.unwrap(); - assert_eq!(band.sign, '+'); - assert_eq!(band.mod_value, 1.0); - assert_eq!(band.expansion, "sdev"); - } - - #[test] - fn parse_band_integer_modifier() { - let spec = parse_agg_name("mean-2sdev").unwrap(); - let band = spec.band.unwrap(); - assert_eq!(band.sign, '-'); - assert_eq!(band.mod_value, 2.0); - assert_eq!(band.expansion, "sdev"); - } - - #[test] - fn parse_band_decimal_modifier() { - let spec = parse_agg_name("mean+1.96sdev").unwrap(); - let band = spec.band.unwrap(); - assert_eq!(band.mod_value, 1.96); - } - - #[test] - fn parse_band_longest_offset_wins() { - // 'median+sdev' must match offset 'median', not 'me' (which isn't an - // offset anyway, but more pertinently the parser must not stop at a - // shorter prefix). - let spec = parse_agg_name("median+sdev").unwrap(); - assert_eq!(spec.offset, "median"); - } - - #[test] - fn parse_band_percentile_offset() { - let spec = parse_agg_name("p25+0.5range").unwrap(); - assert_eq!(spec.offset, "p25"); - let band = spec.band.unwrap(); - assert_eq!(band.mod_value, 0.5); - assert_eq!(band.expansion, "range"); - } - - #[test] - fn parse_band_rejects_invalid_offset() { - assert!(parse_agg_name("count+sdev").is_none()); - assert!(parse_agg_name("iqr+sdev").is_none()); - } - - #[test] - fn parse_band_rejects_invalid_expansion() { - assert!(parse_agg_name("mean+count").is_none()); - assert!(parse_agg_name("mean+median").is_none()); - } - - #[test] - fn parse_rejects_unknown() { - assert!(parse_agg_name("foo").is_none()); - assert!(parse_agg_name("").is_none()); - } - - // ======================================================================== - // Validation tests (validate_aggregate_param) - // ======================================================================== - - #[test] - fn validate_accepts_simple_names_and_bands() { - use crate::plot::types::ArrayElement; - validate_aggregate_param(&ParameterValue::String("mean".to_string())).unwrap(); - validate_aggregate_param(&ParameterValue::String("mean+sdev".to_string())).unwrap(); - validate_aggregate_param(&ParameterValue::String("median-0.5iqr".to_string())).unwrap(); - validate_aggregate_param(&ParameterValue::Array(vec![ - ArrayElement::String("mean".to_string()), - ArrayElement::String("mean+1.96sdev".to_string()), - ])) - .unwrap(); - } - - #[test] - fn validate_diagnostic_for_invalid_offset() { - let err = validate_aggregate_param(&ParameterValue::String("count+sdev".to_string())) - .unwrap_err(); - assert!(err.contains("count"), "err: {}", err); - assert!(err.contains("offset"), "err: {}", err); - } - - #[test] - fn validate_diagnostic_for_invalid_expansion() { - let err = validate_aggregate_param(&ParameterValue::String("mean+count".to_string())) - .unwrap_err(); - assert!(err.contains("count"), "err: {}", err); - assert!(err.contains("expansion"), "err: {}", err); - } - - #[test] - fn validate_diagnostic_for_unknown() { - let err = validate_aggregate_param(&ParameterValue::String("foo".to_string())).unwrap_err(); - assert!(err.contains("unknown"), "err: {}", err); - assert!(err.contains("foo"), "err: {}", err); - } - - // ======================================================================== - // SQL emission for parametric bands - // ======================================================================== - - #[test] - fn band_decimal_modifier_emits_in_sql() { + fn quantile_falls_back_to_correlated_subquery_without_inline() { let mut aes = Mappings::new(); aes.insert("pos2", col("__ggsql_aes_pos2__")); - let schema = numeric_schema(&["__ggsql_aes_pos2__"]); - let mut params = HashMap::new(); - params.insert( - "aggregate".to_string(), - ParameterValue::String("mean+1.96sdev".to_string()), - ); - - let result = apply( - "SELECT * FROM t", - &schema, - &aes, - &[], - ¶ms, - &InlineQuantileDialect, - &[2], - None, - ) - .unwrap(); - match result { - StatResult::Transformed { query, .. } => { - assert!( - query.contains( - "AVG(\"__ggsql_aes_pos2__\") + 1.96 * STDDEV_POP(\"__ggsql_aes_pos2__\")" - ), - "query: {}", - query - ); - } - _ => panic!("expected Transformed"), - } - } - - #[test] - fn band_with_percentile_offset_inline() { - // median-0.5iqr on a dialect with inline quantile support. - let mut aes = Mappings::new(); - aes.insert("pos2", col("__ggsql_aes_pos2__")); - let schema = numeric_schema(&["__ggsql_aes_pos2__"]); - let mut params = HashMap::new(); - params.insert( - "aggregate".to_string(), - ParameterValue::String("median-0.5iqr".to_string()), - ); - - let result = apply( - "SELECT * FROM t", - &schema, + let schema = schema_for(&[("__ggsql_aes_pos2__", false)]); + let result = run( + ParameterValue::String("p25".to_string()), &aes, - &[], - ¶ms, - &InlineQuantileDialect, - &[2], - None, - ) - .unwrap(); - match result { - StatResult::Transformed { query, .. } => { - // median uses QUANTILE_CONT(col, 0.5); iqr uses QUANTILE_CONT(.., 0.75) and 0.25. - assert!( - query.contains("QUANTILE_CONT") && query.contains("0.5"), - "query: {}", - query - ); - assert!(query.contains("0.75") && query.contains("0.25")); - } - _ => panic!("expected Transformed"), - } - } - - #[test] - fn band_with_percentile_offset_falls_back() { - // median+2sdev on a dialect WITHOUT inline quantile support → UNION-ALL - // path with sql_percentile for median, inline STDDEV_POP for sdev. - let mut aes = Mappings::new(); - aes.insert("pos2", col("__ggsql_aes_pos2__")); - let schema = numeric_schema(&["__ggsql_aes_pos2__"]); - let mut params = HashMap::new(); - params.insert( - "aggregate".to_string(), - ParameterValue::String("median+2sdev".to_string()), - ); - - let result = apply( - "SELECT * FROM t", &schema, - &aes, &[], - ¶ms, &NoInlineQuantileDialect, - &[2], - None, ) .unwrap(); match result { StatResult::Transformed { query, .. } => { + // The fallback dialect's sql_percentile uses NTILE. assert!(query.contains("NTILE(4)")); - assert!(query.contains("STDDEV_POP")); - assert!(query.contains("2 * ")); + // No explosion any more — single SELECT, no UNION ALL. + assert!(!query.contains("UNION ALL")); } _ => panic!("expected Transformed"), } } #[test] - fn band_with_default_modifier_omits_one_prefix() { + fn unknown_targeted_aesthetic_is_error() { let mut aes = Mappings::new(); + aes.insert("pos1", col("__ggsql_aes_pos1__")); aes.insert("pos2", col("__ggsql_aes_pos2__")); - let schema = numeric_schema(&["__ggsql_aes_pos2__"]); - let mut params = HashMap::new(); - params.insert( - "aggregate".to_string(), - ParameterValue::String("mean+sdev".to_string()), - ); - - let result = apply( - "SELECT * FROM t", - &schema, + let schema = schema_for(&[ + ("__ggsql_aes_pos1__", false), + ("__ggsql_aes_pos2__", false), + ]); + let err = run( + ParameterValue::String("size:mean".to_string()), &aes, - &[], - ¶ms, - &InlineQuantileDialect, - &[2], - None, - ) - .unwrap(); - match result { - StatResult::Transformed { query, .. } => { - // mod=1 case: (offset + exp), no `1 *` prefix. - assert!( - query.contains( - "AVG(\"__ggsql_aes_pos2__\") + STDDEV_POP(\"__ggsql_aes_pos2__\")" - ), - "expected `(AVG + STDDEV_POP)` form, got: {}", - query - ); - assert!(!query.contains("1 * STDDEV_POP")); - } - _ => panic!("expected Transformed"), - } - } - - #[test] - fn range_mode_supports_decimal_band() { - // Ribbon range mode + 95% CI band. - let (aes, schema) = range_input_aes_with_group(); - let mut params = HashMap::new(); - use crate::plot::types::ArrayElement; - params.insert( - "aggregate".to_string(), - ParameterValue::Array(vec![ - ArrayElement::String("mean-1.96sdev".to_string()), - ArrayElement::String("mean+1.96sdev".to_string()), - ]), - ); - - let result = apply( - "SELECT * FROM t", &schema, - &aes, &[], - ¶ms, &InlineQuantileDialect, - &[2], - range_pair(), ) - .unwrap(); - match result { - StatResult::Transformed { query, .. } => { - assert!(query.contains("- 1.96 * STDDEV_POP")); - assert!(query.contains("+ 1.96 * STDDEV_POP")); - } - _ => panic!("expected Transformed"), - } + .unwrap_err(); + let msg = format!("{}", err); + assert!(msg.contains("not mapped"), "got: {}", msg); } } diff --git a/src/plot/layer/geom/text.rs b/src/plot/layer/geom/text.rs index d9af79ac..5909c34d 100644 --- a/src/plot/layer/geom/text.rs +++ b/src/plot/layer/geom/text.rs @@ -67,10 +67,6 @@ impl GeomTrait for Text { true } - fn aggregate_slots(&self) -> &'static [u8] { - &[1, 2] - } - fn post_process( &self, df: DataFrame, diff --git a/src/plot/layer/geom/tile.rs b/src/plot/layer/geom/tile.rs index 3633f944..fea51d38 100644 --- a/src/plot/layer/geom/tile.rs +++ b/src/plot/layer/geom/tile.rs @@ -104,6 +104,7 @@ impl GeomTrait for Tile { parameters: &HashMap, _execute_query: &dyn Fn(&str) -> Result, _dialect: &dyn crate::reader::SqlDialect, + _aesthetic_ctx: &crate::plot::aesthetic::AestheticContext, ) -> Result { stat_tile(query, schema, aesthetics, group_by, parameters) } diff --git a/src/plot/layer/geom/violin.rs b/src/plot/layer/geom/violin.rs index 7fef55eb..6ee8d95b 100644 --- a/src/plot/layer/geom/violin.rs +++ b/src/plot/layer/geom/violin.rs @@ -123,6 +123,7 @@ impl GeomTrait for Violin { parameters: &HashMap, _execute_query: &dyn Fn(&str) -> crate::Result, dialect: &dyn crate::reader::SqlDialect, + _aesthetic_ctx: &crate::plot::aesthetic::AestheticContext, ) -> Result { stat_violin(query, aesthetics, group_by, parameters, dialect) } diff --git a/src/plot/layer/mod.rs b/src/plot/layer/mod.rs index 91961156..1cea17ac 100644 --- a/src/plot/layer/mod.rs +++ b/src/plot/layer/mod.rs @@ -199,26 +199,14 @@ impl Layer { format!("`{}`", name) }; - // Check if all required aesthetics exist. - // When `aggregate` is set on a range geom, the (lower, upper) range pair - // is filled by the stat (e.g. pos2min/pos2max for ribbon, pos2/pos2end - // for segment) and shouldn't be required from the user. - let range_pair_skip: Option<(&'static str, &'static str)> = - if crate::plot::layer::geom::has_aggregate_param(&self.parameters) { - self.geom.aggregate_range_pair() - } else { - None - }; - + // Check if all required aesthetics exist. The Aggregate stat replaces + // mapped values in place — it never synthesises new aesthetics — so + // every required aesthetic must be mapped by the user regardless of + // the `aggregate` setting. let mut missing = Vec::new(); let mut position_reqs: Vec<(&str, u8, &str)> = Vec::new(); for aesthetic in self.geom.aesthetics().required() { - if let Some((lo, hi)) = range_pair_skip { - if aesthetic == lo || aesthetic == hi { - continue; - } - } if let Some((slot, suffix)) = parse_position(aesthetic) { position_reqs.push((aesthetic, slot, suffix)) } else if !self.mappings.contains_key(aesthetic) { From caf0a8e88e3a7d933d170eaf6347df193ca3e800 Mon Sep 17 00:00:00 2001 From: Thomas Lin Pedersen Date: Mon, 4 May 2026 14:57:10 +0200 Subject: [PATCH 16/33] add back long-form aggregation --- CHANGELOG.md | 22 +- doc/syntax/clause/draw.qmd | 13 +- doc/syntax/layer/type/line.qmd | 7 +- src/execute/layer.rs | 18 + src/plot/layer/geom/stat_aggregate.rs | 485 +++++++++++++++++++++++--- 5 files changed, 485 insertions(+), 60 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 1e176731..71974b7c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,13 +3,21 @@ ### Added - New `aggregate` SETTING on Identity-stat layers (point, line, area, bar, ribbon, -range, segment, arrow, rule, text). Collapses each group to a single row by -replacing every numeric mapping in place with its aggregated value. Accepts a -single string or array of strings; entries are either unprefixed defaults -(`'mean'`) or per-aesthetic targets (`'y:max'`, `'color:median'`). Up to two -defaults may be supplied — the first applies to lower-half aesthetics plus all -non-range layers, the second to upper-half (`max`/`end` suffix). Numeric -mappings without a target or applicable default are dropped with a warning. +range, segment, arrow, rule, text). By default it collapses each group to a +single row by replacing every numeric mapping in place with its aggregated +value. Accepts a single string or array of strings; entries are either +unprefixed defaults (`'mean'`) or per-aesthetic targets (`'y:max'`, +`'color:median'`). Up to two defaults may be supplied — the first applies to +lower-half aesthetics plus all non-range layers, the second to upper-half +(`max`/`end` suffix). Numeric mappings without a target or applicable default +are dropped with a warning. Targeting the same aesthetic more than once +(e.g. `aggregate => ('y:min', 'y:max')`) produces one row per function with +a synthetic `aggregate` column tagging each row, available for `REMAPPING` to +another aesthetic; targets with a single function and the unprefixed defaults +are reused unchanged across the exploded rows. The `aggregate` column's value +is built from the dedup-and-joined function names of all exploded targets at +each row, separated by `/` (so `('y:min', 'y:max', 'color:sum', 'color:prod')` +yields `'min/sum'` and `'max/prod'`). Mixed lengths above 1 are an error. - Add cell delimiters and code lens actions to the Positron extension (#366) - ODBC is now turned on for the CLI as well (#344) - `FROM` can now come before `VISUALIZE`, mirroring the DuckDB style. This means diff --git a/doc/syntax/clause/draw.qmd b/doc/syntax/clause/draw.qmd index ba84fc0e..08cd2b2c 100644 --- a/doc/syntax/clause/draw.qmd +++ b/doc/syntax/clause/draw.qmd @@ -86,6 +86,17 @@ The setting takes a single string or an array of strings. Each string is one of: A numeric mapping that has neither a target nor an applicable default is dropped from the layer with a warning. +You can also target the same aesthetic more than once to produce **multiple rows per group** — one for each function. For example `aggregate => ('y:min', 'y:max')` emits a min row and a max row per group, so a single `DRAW line` produces two summary lines that connect within each group rather than across them. + +The stat exposes a synthetic `aggregate` column tagging each row, which you can pick up with a `REMAPPING` to drive another aesthetic — e.g. `REMAPPING aggregate AS stroke` to colour the two lines differently. The column's value is built from the per-row function names of the *exploded* targets, deduplicated, and joined with `/`: + +* `aggregate => ('y:min', 'y:max')` → rows tagged `'min'`, `'max'`. +* `aggregate => ('y:min', 'y:max', 'color:sum', 'color:prod')` → rows tagged `'min/sum'`, `'max/prod'`. +* `aggregate => ('y:mean', 'y:max', 'color:mean', 'color:prod')` → rows tagged `'mean'`, `'max/prod'` (the duplicate `'mean'` collapses). +* `aggregate => ('y:min', 'y:max', 'color:median')` → rows tagged `'min'`, `'max'` (the single-function `color` target is recycled across rows and is not part of the label). + +When several aesthetics are targeted with the same number of functions, they explode in lockstep (row 1 uses each aesthetic's first function, row 2 the second, and so on); aesthetics with a single function — and the unprefixed defaults — are reused unchanged across every row. Mixing different lengths above 1 is an error. + The simple functions are: * `'count'`: Non-null tally of the bound column. @@ -102,7 +113,7 @@ Allowed offsets are: `'mean'`, `'median'`, `'geomean'`, `'harmean'`, `'rms'`, `' Allowed expansions are: `'sdev'`, `'se'`, `'var'`, `'iqr'`, and `'range'` -Aggregation applies in place: there is no extra `aggregate` column to remap, and you do not need a `REMAPPING` clause to consume aggregate output. The aggregated value replaces the bound column for the same aesthetic. +In the single-row (reduction) case aggregation applies in place — no `REMAPPING` is needed and no synthetic column is added. Only the multi-row (explosion) case described above introduces the synthetic `aggregate` column. ### `FILTER` ```ggsql diff --git a/doc/syntax/layer/type/line.qmd b/doc/syntax/layer/type/line.qmd index a40fd486..ca266a6a 100644 --- a/doc/syntax/layer/type/line.qmd +++ b/doc/syntax/layer/type/line.qmd @@ -96,13 +96,12 @@ DRAW line SCALE linewidth TO (0, 30) ``` -Use aggregation to draw min and max lines from a set of observations. Each layer produces one summary trace; stack two layers for both bounds. +Use aggregation to draw min and max lines from a set of observations on a single layer. Targeting `y` twice produces one summary line per function within the same layer, with a synthetic `aggregate` column tagging each row that you can remap to colour the lines distinctly: ```{ggsql} VISUALISE Day AS x, Temp AS y FROM ggsql:airquality DRAW line - SETTING aggregate => 'min' -DRAW line - SETTING aggregate => 'max' + REMAPPING aggregate AS stroke + SETTING aggregate => ('y:min', 'y:max') DRAW point ``` diff --git a/src/execute/layer.rs b/src/execute/layer.rs index 51457a93..8ada0d8c 100644 --- a/src/execute/layer.rs +++ b/src/execute/layer.rs @@ -604,6 +604,24 @@ where } } + // The synthetic `aggregate` stat column produced by an exploded + // Aggregate stat tags each row with its function name. For mark + // types that connect rows within a group (line, area, path, + // polygon) we add this column to `layer.partition_by` so e.g. + // `aggregate => ('y:min', 'y:max')` renders as two separate lines + // rather than one zigzag through both. Resolves to the post-rename + // data-column name: if the user remapped `aggregate AS `, the + // prefixed aesthetic column; otherwise the stat column. + if stat_columns.iter().any(|s| s == "aggregate") { + let partition_col = match final_remappings.get("aggregate") { + Some(aes) => naming::aesthetic_column(aes), + None => naming::stat_column("aggregate"), + }; + if !layer.partition_by.contains(&partition_col) { + layer.partition_by.push(partition_col); + } + } + // Apply stat_columns to layer aesthetics using the remappings for stat in &stat_columns { if let Some(aesthetic) = final_remappings.get(stat) { diff --git a/src/plot/layer/geom/stat_aggregate.rs b/src/plot/layer/geom/stat_aggregate.rs index 6078d8f5..3dc102e4 100644 --- a/src/plot/layer/geom/stat_aggregate.rs +++ b/src/plot/layer/geom/stat_aggregate.rs @@ -153,15 +153,20 @@ fn parse_mod_and_remainder(s: &str) -> (f64, &str) { // AggregateSpec — parsed representation of the `aggregate` SETTING. // ============================================================================= -/// Parsed `aggregate` SETTING: zero, one, or two unprefixed defaults plus an -/// optional set of per-aesthetic targets keyed by user-facing aesthetic name. +/// Parsed `aggregate` SETTING. +/// +/// Up to two unprefixed defaults plus per-aesthetic targets. A target may be +/// named more than once; the multiple functions cause that aesthetic to +/// *explode* into multiple rows per group #[derive(Debug, Clone, PartialEq)] pub struct AggregateSpec { pub default_lower: Option, pub default_upper: Option, - /// Targets keyed by user-facing aesthetic name (e.g. `"y"`, `"xmax"`, - /// `"color"`). Resolved to internal names at apply-time. - pub targets: HashMap, + /// Targets in declaration order. Each entry is `(user-facing aesthetic, + /// non-empty list of functions)`. Multiple SETTING entries with the same + /// aesthetic are merged into one list during parsing — the cumulative + /// length determines that aesthetic's explosion factor. + pub targets: Vec<(String, Vec)>, } impl AggregateSpec { @@ -169,13 +174,81 @@ impl AggregateSpec { Self { default_lower: None, default_upper: None, - targets: HashMap::new(), + targets: Vec::new(), + } + } + + /// Maximum target list length, or `1` if every target has a single function. + /// This is the number of exploded rows the stat will emit per group. + pub fn explosion_factor(&self) -> usize { + self.targets + .iter() + .map(|(_, fns)| fns.len()) + .max() + .unwrap_or(1) + .max(1) + } + + /// Per-row labels for the synthetic `aggregate` column. `None` for the + /// single-row case (no explosion), since the column only makes sense as a + /// row-differentiator and there's nothing to differentiate. + /// + /// For each row in `0..explosion_factor`, walks every *exploded* target + /// (length == n; length-1 recycled targets are skipped because they take + /// the same value on every row), collects each target's function name at + /// that row, deduplicates them while preserving declaration order, and + /// joins with `/`. + /// + /// Examples (with `n = 2`): + /// - `('y:min', 'y:max')` → `['min', 'max']` + /// - `('y:min', 'y:max', 'color:sum', 'color:prod')` → `['min/sum', 'max/prod']` + /// - `('y:mean', 'y:max', 'color:mean', 'color:prod')` → `['mean', 'max/prod']` + /// - `('y:min', 'y:max', 'color:median')` → `['min', 'max']` (color is recycled) + pub fn explosion_labels(&self) -> Option> { + let n = self.explosion_factor(); + if n <= 1 { + return None; } + let exploded: Vec<&Vec> = self + .targets + .iter() + .filter(|(_, fns)| fns.len() == n) + .map(|(_, fns)| fns) + .collect(); + let labels = (0..n) + .map(|row| { + let mut parts: Vec = Vec::new(); + for fns in &exploded { + let label = agg_label(&fns[row]); + if !parts.contains(&label) { + parts.push(label); + } + } + parts.join("/") + }) + .collect(); + Some(labels) } } -/// Parse the `aggregate` SETTING value into an `AggregateSpec`. Returns `Ok(None)` -/// when the parameter is unset or null. Returns `Err(...)` for malformed input. +/// Human-readable label for an `AggSpec`. Re-emits simple names verbatim and +/// reconstructs band names like `mean+sdev`. +fn agg_label(spec: &AggSpec) -> String { + match &spec.band { + None => spec.offset.to_string(), + Some(b) => { + if b.mod_value == 1.0 { + format!("{}{}{}", spec.offset, b.sign, b.expansion) + } else { + format!("{}{}{}{}", spec.offset, b.sign, b.mod_value, b.expansion) + } + } + } +} + +/// Parse the `aggregate` SETTING value into an `AggregateSpec`. Returns +/// `Ok(None)` when the parameter is unset, null, or empty. Returns `Err(...)` +/// for malformed input. pub fn parse_aggregate_param( value: &ParameterValue, ) -> std::result::Result, String> { @@ -217,13 +290,12 @@ pub fn parse_aggregate_param( diagnose_invalid_function_name(func) ) })?; - if spec.targets.contains_key(aes) { - return Err(format!( - "aesthetic '{}' is targeted by more than one aggregate", - aes - )); + // Append to existing list for this aesthetic, or create one. + if let Some((_, fns)) = spec.targets.iter_mut().find(|(a, _)| a == aes) { + fns.push(agg); + } else { + spec.targets.push((aes.to_string(), vec![agg])); } - spec.targets.insert(aes.to_string(), agg); } else { let agg = parse_agg_name(entry) .ok_or_else(|| diagnose_invalid_function_name(entry))?; @@ -243,6 +315,23 @@ pub fn parse_aggregate_param( if spec.default_lower.is_none() && spec.default_upper.is_none() && spec.targets.is_empty() { return Ok(None); } + + // Validate recycling: every target list must be length 1 or N (the max). + let n = spec.explosion_factor(); + if n > 1 { + for (aes, fns) in &spec.targets { + if fns.len() != 1 && fns.len() != n { + return Err(format!( + "aggregate target '{}' has {} functions; targets in an exploded layer must \ + have either 1 or {} functions (the longest target's count)", + aes, + fns.len(), + n + )); + } + } + } + Ok(Some(spec)) } @@ -506,8 +595,10 @@ fn is_upper_half(internal_aes: &str) -> bool { /// Apply the Aggregate stat to a layer query. /// /// Returns `StatResult::Identity` when the `aggregate` parameter is unset, null, -/// or empty. Otherwise, builds a single-pass `GROUP BY` query producing one row -/// per group with one aggregated column per kept numeric mapping. +/// or empty. Otherwise, builds a `GROUP BY` query producing one row per group +/// (the *reduce* path) — or, when at least one target lists multiple functions, +/// `N` rows per group with a synthetic `aggregate` column tagging each row +/// (the *explode* path). #[allow(clippy::too_many_arguments)] pub fn apply( query: &str, @@ -528,12 +619,15 @@ pub fn apply( Some(s) => s, None => return Ok(StatResult::Identity), }; - - // Resolve target keys (user-facing) → internal aesthetic names. An alias - // like `color` expands to whichever of its targets (stroke/fill) is mapped - // on the layer; the function applies to all of them. - let mut targets_internal: HashMap = HashMap::new(); - for (user_aes, agg) in &spec.targets { + let n = spec.explosion_factor(); + let labels = spec.explosion_labels(); + + // Resolve target keys (user-facing) → internal aesthetic names, keeping + // each target's function list. An alias like `color` expands to whichever + // of its targets (stroke/fill) is mapped on the layer; the same list + // applies to all of them. + let mut targets_internal: HashMap> = HashMap::new(); + for (user_aes, fns) in &spec.targets { let resolved = resolve_target_aesthetic(user_aes, aesthetics, aesthetic_ctx); if resolved.is_empty() { return Err(GgsqlError::ValidationError(format!( @@ -548,15 +642,15 @@ pub fn apply( user_aes, internal ))); } - targets_internal.insert(internal, agg.clone()); + targets_internal.insert(internal, fns.clone()); } } // Walk mappings. Three buckets: - // - aggregated: (internal_aes, raw_col, AggSpec) — each emits one column + // - aggregated: (internal_aes, raw_col, fns of length n) — each emits one column per row // - kept_cols: discrete column-mappings — keep as group key // - dropped: numeric mapping with no applicable function (warn & skip) - let mut aggregated: Vec<(String, String, AggSpec)> = Vec::new(); + let mut aggregated: Vec<(String, String, Vec)> = Vec::new(); let mut kept_cols: Vec = Vec::new(); let mut dropped: Vec = Vec::new(); @@ -577,20 +671,26 @@ pub fn apply( continue; } - // Numeric mapping. Look up the aggregation function. - let agg = if let Some(targeted) = targets_internal.get(aes) { - Some(targeted.clone()) - } else if is_upper_half(aes) { - spec.default_upper - .clone() - .or_else(|| spec.default_lower.clone()) - .filter(|_| spec.default_upper.is_some() || spec.default_lower.is_some()) + // Numeric mapping. Look up the function list (recycling to length n). + let fns: Option> = if let Some(list) = targets_internal.get(aes) { + if list.len() == n { + Some(list.clone()) + } else { + // Validated to be 1 or n during parsing; guard with a sanity check. + debug_assert_eq!(list.len(), 1); + Some(vec![list[0].clone(); n]) + } } else { - spec.default_lower.clone() + let default = if is_upper_half(aes) { + spec.default_upper.clone().or_else(|| spec.default_lower.clone()) + } else { + spec.default_lower.clone() + }; + default.map(|d| vec![d; n]) }; - match agg { - Some(a) => aggregated.push((aes.clone(), col, a)), + match fns { + Some(list) => aggregated.push((aes.clone(), col, list)), None => dropped.push(aes.clone()), } } @@ -629,11 +729,19 @@ pub fn apply( } } - let transformed_query = - build_group_by_query(query, &aggregated, &group_cols, dialect); + let transformed_query = match &labels { + Some(ls) => build_aggregate_query(query, &aggregated, &group_cols, ls, dialect), + None => build_group_by_query(query, &aggregated, &group_cols, dialect), + }; - let stat_columns: Vec = aggregated.iter().map(|(a, _, _)| a.clone()).collect(); + let mut stat_columns: Vec = aggregated.iter().map(|(a, _, _)| a.clone()).collect(); let consumed_aesthetics: Vec = stat_columns.clone(); + // The synthetic `aggregate` column is only emitted for the multi-row + // (explosion) case, where it differentiates rows that share the same + // group key. + if labels.is_some() { + stat_columns.push("aggregate".to_string()); + } Ok(StatResult::Transformed { query: transformed_query, @@ -643,14 +751,15 @@ pub fn apply( }) } -/// Build the `WITH src AS () SELECT , FROM src -/// AS "__ggsql_qt__" GROUP BY ` query. +/// Build the single-row `WITH src AS () SELECT , +/// FROM src AS "__ggsql_qt__" GROUP BY ` query. Each aggregated +/// aesthetic's function list is length 1 here. /// /// Falls back to `dialect.sql_percentile()` per-column when an aggregate's /// percentile component lacks inline support. fn build_group_by_query( query: &str, - aggregated: &[(String, String, AggSpec)], + aggregated: &[(String, String, Vec)], group_cols: &[String], dialect: &dyn SqlDialect, ) -> String { @@ -666,7 +775,8 @@ fn build_group_by_query( let mut select_parts: Vec = group_select.clone(); - for (aes, raw_col, agg) in aggregated { + for (aes, raw_col, fns) in aggregated { + let agg = &fns[0]; let stat_col = naming::stat_column(aes); let qcol = naming::quote_ident(raw_col); let expr = if needs_quantile_fallback(agg, raw_col, dialect) { @@ -688,6 +798,76 @@ fn build_group_by_query( ) } +/// Build the exploded `WITH src AS () UNION ALL +/// ...` query. One branch per row in `0..labels.len()`, each branch its own +/// `GROUP BY` with the row's aggregation functions and a literal label tagged +/// to `__ggsql_stat_aggregate__`. +fn build_aggregate_query( + query: &str, + aggregated: &[(String, String, Vec)], + group_cols: &[String], + labels: &[String], + dialect: &dyn SqlDialect, +) -> String { + let src_alias = "\"__ggsql_stat_src__\""; + let outer_alias = "\"__ggsql_qt__\""; + + let group_select: Vec = group_cols.iter().map(|c| naming::quote_ident(c)).collect(); + let group_by_clause = if group_cols.is_empty() { + String::new() + } else { + format!(" GROUP BY {}", group_select.join(", ")) + }; + + let stat_aggregate_col = naming::stat_column("aggregate"); + + let branches: Vec = labels + .iter() + .enumerate() + .map(|(row_idx, label)| { + let mut select_parts: Vec = group_select.clone(); + + for (aes, raw_col, fns) in aggregated { + let agg = &fns[row_idx]; + let stat_col = naming::stat_column(aes); + let qcol = naming::quote_ident(raw_col); + let expr = if needs_quantile_fallback(agg, raw_col, dialect) { + agg_sql_fallback(agg, raw_col, dialect, src_alias, group_cols) + } else { + agg_sql_inline(agg, &qcol, dialect) + .expect("agg_sql_inline must succeed when needs_quantile_fallback is false") + }; + select_parts.push(format!("{} AS {}", expr, naming::quote_ident(&stat_col))); + } + + select_parts.push(format!( + "{} AS {}", + func_literal(label), + naming::quote_ident(&stat_aggregate_col) + )); + + format!( + "SELECT {} FROM {} AS {}{}", + select_parts.join(", "), + src_alias, + outer_alias, + group_by_clause, + ) + }) + .collect(); + + format!( + "WITH {src} AS ({query}) {body}", + src = src_alias, + query = query, + body = branches.join(" UNION ALL "), + ) +} + +fn func_literal(s: &str) -> String { + format!("'{}'", s.replace('\'', "''")) +} + #[cfg(test)] mod tests { use super::*; @@ -788,20 +968,107 @@ mod tests { assert!(err.contains("at most two"), "got: {}", err); } + fn target_funcs<'a>(spec: &'a AggregateSpec, aes: &str) -> Option<&'a [AggSpec]> { + spec.targets + .iter() + .find(|(a, _)| a == aes) + .map(|(_, fns)| fns.as_slice()) + } + #[test] fn parses_targeted_entries() { let s = parse_aggregate_param(&arr(&["mean", "y:max", "color:median"])) .unwrap() .unwrap(); assert_eq!(s.default_lower.as_ref().map(|a| a.offset), Some("mean")); - assert_eq!(s.targets.get("y").map(|a| a.offset), Some("max")); - assert_eq!(s.targets.get("color").map(|a| a.offset), Some("median")); + assert_eq!(target_funcs(&s, "y").map(|fs| fs[0].offset), Some("max")); + assert_eq!(target_funcs(&s, "color").map(|fs| fs[0].offset), Some("median")); + } + + #[test] + fn duplicate_target_explodes_into_a_list() { + let s = parse_aggregate_param(&arr(&["y:min", "y:max"])).unwrap().unwrap(); + let fns = target_funcs(&s, "y").unwrap(); + assert_eq!(fns.len(), 2); + assert_eq!(fns[0].offset, "min"); + assert_eq!(fns[1].offset, "max"); + assert_eq!(s.explosion_factor(), 2); + assert_eq!( + s.explosion_labels(), + Some(vec!["min".to_string(), "max".to_string()]) + ); } #[test] - fn duplicate_target_is_error() { - let err = parse_aggregate_param(&arr(&["y:mean", "y:median"])).unwrap_err(); - assert!(err.contains("more than one aggregate"), "got: {}", err); + fn multi_aesthetic_explosion_joins_unique_function_names() { + // Two exploded targets contribute distinct function names per row → 'min/sum', 'max/prod'. + let s = parse_aggregate_param(&arr(&["y:min", "y:max", "color:sum", "color:prod"])) + .unwrap() + .unwrap(); + assert_eq!( + s.explosion_labels(), + Some(vec!["min/sum".to_string(), "max/prod".to_string()]) + ); + } + + #[test] + fn multi_aesthetic_explosion_dedups_repeats() { + // y and color both use 'mean' at row 0 → label is just 'mean' (deduped). + let s = parse_aggregate_param(&arr(&["y:mean", "y:max", "color:mean", "color:prod"])) + .unwrap() + .unwrap(); + assert_eq!( + s.explosion_labels(), + Some(vec!["mean".to_string(), "max/prod".to_string()]) + ); + } + + #[test] + fn recycled_target_excluded_from_label() { + // color has length 1 → recycled, not exploded; label only reflects y's functions. + let s = parse_aggregate_param(&arr(&["y:min", "y:max", "color:median"])) + .unwrap() + .unwrap(); + assert_eq!( + s.explosion_labels(), + Some(vec!["min".to_string(), "max".to_string()]) + ); + } + + #[test] + fn single_row_returns_no_labels() { + // The aggregate column only makes sense as a row-differentiator, and a + // single-row aggregation has nothing to differentiate, so no labels. + let s = parse_aggregate_param(&ParameterValue::String("mean".to_string())) + .unwrap() + .unwrap(); + assert_eq!(s.explosion_labels(), None); + + let s = parse_aggregate_param(&arr(&["mean", "color:median"])).unwrap().unwrap(); + assert_eq!(s.explosion_labels(), None); + } + + #[test] + fn recycling_violation_is_error() { + // y has 2, color has 3 → mismatched, neither is 1 nor matches the longest. + let err = parse_aggregate_param(&arr(&[ + "y:min", + "y:max", + "color:p10", + "color:p50", + "color:p90", + ])) + .unwrap_err(); + assert!(err.contains("longest target"), "got: {}", err); + } + + #[test] + fn length_one_target_recycles_in_explosion() { + let s = parse_aggregate_param(&arr(&["y:min", "y:max", "color:median"])) + .unwrap() + .unwrap(); + assert_eq!(s.explosion_factor(), 2); + assert_eq!(target_funcs(&s, "color").map(|f| f.len()), Some(1)); } #[test] @@ -1042,6 +1309,128 @@ mod tests { } } + #[test] + fn explosion_emits_union_all_with_aggregate_label_column() { + // ('y:min', 'y:max') on a line-style layer → 2 rows per group, each + // tagged with the function name in __ggsql_stat_aggregate__. + let mut aes = Mappings::new(); + aes.insert("pos1", col("__ggsql_aes_pos1__")); + aes.insert("pos2", col("__ggsql_aes_pos2__")); + let schema = schema_for(&[ + ("__ggsql_aes_pos1__", false), + ("__ggsql_aes_pos2__", false), + ]); + let result = run(arr(&["y:min", "y:max"]), &aes, &schema, &[], &InlineQuantileDialect) + .unwrap(); + match result { + StatResult::Transformed { + query, + stat_columns, + consumed_aesthetics, + .. + } => { + assert!(query.contains("UNION ALL"), "{}", query); + assert!(query.contains("MIN(\"__ggsql_aes_pos2__\")"), "{}", query); + assert!(query.contains("MAX(\"__ggsql_aes_pos2__\")"), "{}", query); + assert!(query.contains("'min' AS \"__ggsql_stat_aggregate\"")); + assert!(query.contains("'max' AS \"__ggsql_stat_aggregate\"")); + // Aesthetics consumed: pos2. The synthetic `aggregate` is in + // stat_columns but NOT consumed (it's a new column). + assert!(consumed_aesthetics.contains(&"pos2".to_string())); + assert!(!consumed_aesthetics.contains(&"aggregate".to_string())); + assert!(stat_columns.contains(&"aggregate".to_string())); + } + _ => panic!("expected Transformed"), + } + } + + #[test] + fn explosion_recycles_length_one_targets_and_defaults() { + // ('mean', 'y:min', 'y:max', 'color:median'): + // - default 'mean' applies to non-targeted aesthetics, recycled + // - y is exploded into [min, max] → N=2 + // - color is targeted with one function → recycled to [median, median] + let mut aes = Mappings::new(); + aes.insert("pos1", col("__ggsql_aes_pos1__")); + aes.insert("pos2", col("__ggsql_aes_pos2__")); + aes.insert("fill", col("__ggsql_aes_fill__")); + aes.insert("size", col("__ggsql_aes_size__")); + let schema = schema_for(&[ + ("__ggsql_aes_pos1__", false), + ("__ggsql_aes_pos2__", false), + ("__ggsql_aes_fill__", false), + ("__ggsql_aes_size__", false), + ]); + let result = run( + arr(&["mean", "y:min", "y:max", "color:median"]), + &aes, + &schema, + &[], + &InlineQuantileDialect, + ) + .unwrap(); + match result { + StatResult::Transformed { query, .. } => { + // y is exploded → MIN and MAX appear in different branches + assert!(query.contains("MIN(\"__ggsql_aes_pos2__\")"), "{}", query); + assert!(query.contains("MAX(\"__ggsql_aes_pos2__\")")); + // color (alias → fill) is recycled → QUANTILE_CONT(.5) appears in BOTH branches + let median_count = query.matches("QUANTILE_CONT(\"__ggsql_aes_fill__\", 0.5)").count(); + assert_eq!(median_count, 2, "color median should appear once per branch: {}", query); + // size has no target → uses default 'mean' → AVG appears in both branches + let avg_size = query.matches("AVG(\"__ggsql_aes_size__\")").count(); + assert_eq!(avg_size, 2, "size mean should appear once per branch: {}", query); + // pos1 (no target) → mean → AVG appears in both branches + let avg_pos1 = query.matches("AVG(\"__ggsql_aes_pos1__\")").count(); + assert_eq!(avg_pos1, 2); + } + _ => panic!("expected Transformed"), + } + } + + #[test] + fn explosion_with_range_geom_two_defaults() { + // For ribbon: pos1 + pos2min (lower) + pos2max (upper). + // ('mean-sdev', 'mean+sdev', 'color:p25', 'color:p75'): + // - two defaults split lower/upper for range aesthetics + // - color is exploded → N=2 + // Result: two rows, with color taking p25 in row 0 and p75 in row 1; + // pos1/pos2min always use mean-sdev, pos2max always uses mean+sdev. + let mut aes = Mappings::new(); + aes.insert("pos1", col("__ggsql_aes_pos1__")); + aes.insert("pos2min", col("__ggsql_aes_pos2min__")); + aes.insert("pos2max", col("__ggsql_aes_pos2max__")); + aes.insert("fill", col("__ggsql_aes_fill__")); + let schema = schema_for(&[ + ("__ggsql_aes_pos1__", false), + ("__ggsql_aes_pos2min__", false), + ("__ggsql_aes_pos2max__", false), + ("__ggsql_aes_fill__", false), + ]); + let result = run( + arr(&["mean-sdev", "mean+sdev", "color:p25", "color:p75"]), + &aes, + &schema, + &[], + &InlineQuantileDialect, + ) + .unwrap(); + match result { + StatResult::Transformed { query, stat_columns, .. } => { + assert!(query.contains("UNION ALL")); + // pos2max always uses mean+sdev (upper default) — a `+` between AVG and STDDEV + let upper_branch_marker = "AVG(\"__ggsql_aes_pos2max__\") + STDDEV_POP"; + assert!(query.contains(upper_branch_marker), "{}", query); + // color uses p25 in one branch, p75 in another + assert!(query.contains("QUANTILE_CONT(\"__ggsql_aes_fill__\", 0.25)")); + assert!(query.contains("QUANTILE_CONT(\"__ggsql_aes_fill__\", 0.75)")); + // Synthetic aggregate column is present + assert!(stat_columns.contains(&"aggregate".to_string())); + } + _ => panic!("expected Transformed"), + } + } + #[test] fn discrete_mapping_becomes_group_key() { let mut aes = Mappings::new(); From 564673cdb09fa4b311e89b336c71b80a6dfe65cf Mon Sep 17 00:00:00 2001 From: Thomas Lin Pedersen Date: Mon, 4 May 2026 14:57:27 +0200 Subject: [PATCH 17/33] reformat --- src/plot/layer/geom/stat_aggregate.rs | 190 +++++++++++++++++--------- 1 file changed, 126 insertions(+), 64 deletions(-) diff --git a/src/plot/layer/geom/stat_aggregate.rs b/src/plot/layer/geom/stat_aggregate.rs index 3dc102e4..a5e572cb 100644 --- a/src/plot/layer/geom/stat_aggregate.rs +++ b/src/plot/layer/geom/stat_aggregate.rs @@ -283,13 +283,8 @@ pub fn parse_aggregate_param( if func.is_empty() { return Err(format!("'{}': aggregate function is empty", entry)); } - let agg = parse_agg_name(func).ok_or_else(|| { - format!( - "'{}': {}", - entry, - diagnose_invalid_function_name(func) - ) - })?; + let agg = parse_agg_name(func) + .ok_or_else(|| format!("'{}': {}", entry, diagnose_invalid_function_name(func)))?; // Append to existing list for this aesthetic, or create one. if let Some((_, fns)) = spec.targets.iter_mut().find(|(a, _)| a == aes) { fns.push(agg); @@ -297,8 +292,7 @@ pub fn parse_aggregate_param( spec.targets.push((aes.to_string(), vec![agg])); } } else { - let agg = parse_agg_name(entry) - .ok_or_else(|| diagnose_invalid_function_name(entry))?; + let agg = parse_agg_name(entry).ok_or_else(|| diagnose_invalid_function_name(entry))?; if spec.default_lower.is_none() { spec.default_lower = Some(agg); } else if spec.default_upper.is_none() { @@ -613,8 +607,7 @@ pub fn apply( None | Some(ParameterValue::Null) => return Ok(StatResult::Identity), Some(v) => v, }; - let spec = parse_aggregate_param(raw) - .map_err(GgsqlError::ValidationError)?; + let spec = parse_aggregate_param(raw).map_err(GgsqlError::ValidationError)?; let spec = match spec { Some(s) => s, None => return Ok(StatResult::Identity), @@ -654,7 +647,8 @@ pub fn apply( let mut kept_cols: Vec = Vec::new(); let mut dropped: Vec = Vec::new(); - let mut entries: Vec<(&String, &crate::AestheticValue)> = aesthetics.aesthetics.iter().collect(); + let mut entries: Vec<(&String, &crate::AestheticValue)> = + aesthetics.aesthetics.iter().collect(); entries.sort_by(|a, b| a.0.cmp(b.0)); for (aes, value) in entries { @@ -682,7 +676,9 @@ pub fn apply( } } else { let default = if is_upper_half(aes) { - spec.default_upper.clone().or_else(|| spec.default_lower.clone()) + spec.default_upper + .clone() + .or_else(|| spec.default_lower.clone()) } else { spec.default_lower.clone() }; @@ -934,7 +930,12 @@ mod tests { } fn arr(items: &[&str]) -> ParameterValue { - ParameterValue::Array(items.iter().map(|s| ArrayElement::String(s.to_string())).collect()) + ParameterValue::Array( + items + .iter() + .map(|s| ArrayElement::String(s.to_string())) + .collect(), + ) } // ---------- parser tests ---------- @@ -957,7 +958,9 @@ mod tests { #[test] fn parses_two_defaults_in_order() { - let s = parse_aggregate_param(&arr(&["min", "max"])).unwrap().unwrap(); + let s = parse_aggregate_param(&arr(&["min", "max"])) + .unwrap() + .unwrap(); assert_eq!(s.default_lower.as_ref().map(|a| a.offset), Some("min")); assert_eq!(s.default_upper.as_ref().map(|a| a.offset), Some("max")); } @@ -982,12 +985,17 @@ mod tests { .unwrap(); assert_eq!(s.default_lower.as_ref().map(|a| a.offset), Some("mean")); assert_eq!(target_funcs(&s, "y").map(|fs| fs[0].offset), Some("max")); - assert_eq!(target_funcs(&s, "color").map(|fs| fs[0].offset), Some("median")); + assert_eq!( + target_funcs(&s, "color").map(|fs| fs[0].offset), + Some("median") + ); } #[test] fn duplicate_target_explodes_into_a_list() { - let s = parse_aggregate_param(&arr(&["y:min", "y:max"])).unwrap().unwrap(); + let s = parse_aggregate_param(&arr(&["y:min", "y:max"])) + .unwrap() + .unwrap(); let fns = target_funcs(&s, "y").unwrap(); assert_eq!(fns.len(), 2); assert_eq!(fns[0].offset, "min"); @@ -1044,7 +1052,9 @@ mod tests { .unwrap(); assert_eq!(s.explosion_labels(), None); - let s = parse_aggregate_param(&arr(&["mean", "color:median"])).unwrap().unwrap(); + let s = parse_aggregate_param(&arr(&["mean", "color:median"])) + .unwrap() + .unwrap(); assert_eq!(s.explosion_labels(), None); } @@ -1090,11 +1100,23 @@ mod tests { .unwrap(); assert_eq!(s.default_lower.as_ref().unwrap().offset, "mean"); assert_eq!( - s.default_lower.as_ref().unwrap().band.as_ref().unwrap().expansion, + s.default_lower + .as_ref() + .unwrap() + .band + .as_ref() + .unwrap() + .expansion, "sdev" ); assert_eq!( - s.default_lower.as_ref().unwrap().band.as_ref().unwrap().sign, + s.default_lower + .as_ref() + .unwrap() + .band + .as_ref() + .unwrap() + .sign, '-' ); assert_eq!(s.default_upper.as_ref().unwrap().offset, "mean"); @@ -1108,8 +1130,16 @@ mod tests { let schema: Schema = vec![]; let p: HashMap = HashMap::new(); let ctx = cartesian_ctx(); - let result = apply("SELECT * FROM t", &schema, &aes, &[], &p, &InlineQuantileDialect, &ctx) - .unwrap(); + let result = apply( + "SELECT * FROM t", + &schema, + &aes, + &[], + &p, + &InlineQuantileDialect, + &ctx, + ) + .unwrap(); assert_eq!(result, StatResult::Identity); } @@ -1117,7 +1147,14 @@ mod tests { fn returns_identity_when_param_null() { let aes = Mappings::new(); let schema: Schema = vec![]; - let result = run(ParameterValue::Null, &aes, &schema, &[], &InlineQuantileDialect).unwrap(); + let result = run( + ParameterValue::Null, + &aes, + &schema, + &[], + &InlineQuantileDialect, + ) + .unwrap(); assert_eq!(result, StatResult::Identity); } @@ -1126,10 +1163,7 @@ mod tests { let mut aes = Mappings::new(); aes.insert("pos1", col("__ggsql_aes_pos1__")); aes.insert("pos2", col("__ggsql_aes_pos2__")); - let schema = schema_for(&[ - ("__ggsql_aes_pos1__", false), - ("__ggsql_aes_pos2__", false), - ]); + let schema = schema_for(&[("__ggsql_aes_pos1__", false), ("__ggsql_aes_pos2__", false)]); let result = run( ParameterValue::String("mean".to_string()), &aes, @@ -1173,15 +1207,29 @@ mod tests { ("__ggsql_aes_pos1end__", false), ("__ggsql_aes_pos2end__", false), ]); - let result = run(arr(&["min", "max"]), &aes, &schema, &[], &InlineQuantileDialect) - .unwrap(); + let result = run( + arr(&["min", "max"]), + &aes, + &schema, + &[], + &InlineQuantileDialect, + ) + .unwrap(); match result { StatResult::Transformed { query, .. } => { // pos1, pos2 use MIN; pos1end, pos2end use MAX. assert!(query.contains("MIN(\"__ggsql_aes_pos1__\")"), "{}", query); assert!(query.contains("MIN(\"__ggsql_aes_pos2__\")"), "{}", query); - assert!(query.contains("MAX(\"__ggsql_aes_pos1end__\")"), "{}", query); - assert!(query.contains("MAX(\"__ggsql_aes_pos2end__\")"), "{}", query); + assert!( + query.contains("MAX(\"__ggsql_aes_pos1end__\")"), + "{}", + query + ); + assert!( + query.contains("MAX(\"__ggsql_aes_pos2end__\")"), + "{}", + query + ); assert!(!query.contains("MIN(\"__ggsql_aes_pos1end__\")")); assert!(!query.contains("MAX(\"__ggsql_aes_pos1__\")")); } @@ -1213,10 +1261,7 @@ mod tests { assert!(query.contains("STDDEV_POP(\"__ggsql_aes_pos2max__\")")); assert!(query.contains("AVG(\"__ggsql_aes_pos2min__\")")); // upper default (mean+sdev) goes to pos2max → '+' between AVG and STDDEV - let pos2max_section = query - .split("__ggsql_aes_pos2max__\")") - .next() - .unwrap_or(""); + let pos2max_section = query.split("__ggsql_aes_pos2max__\")").next().unwrap_or(""); assert!(pos2max_section.contains('+') || query.contains("+ STDDEV_POP")); } _ => panic!("expected Transformed"), @@ -1228,10 +1273,7 @@ mod tests { let mut aes = Mappings::new(); aes.insert("pos1", col("__ggsql_aes_pos1__")); aes.insert("pos2", col("__ggsql_aes_pos2__")); - let schema = schema_for(&[ - ("__ggsql_aes_pos1__", false), - ("__ggsql_aes_pos2__", false), - ]); + let schema = schema_for(&[("__ggsql_aes_pos1__", false), ("__ggsql_aes_pos2__", false)]); let result = run( arr(&["mean", "y:max"]), &aes, @@ -1270,7 +1312,11 @@ mod tests { ) .unwrap(); match result { - StatResult::Transformed { query, stat_columns, .. } => { + StatResult::Transformed { + query, + stat_columns, + .. + } => { assert!(query.contains("QUANTILE_CONT(\"__ggsql_aes_size__\", 0.5)")); assert!(stat_columns.contains(&"size".to_string())); } @@ -1300,7 +1346,11 @@ mod tests { ) .unwrap(); match result { - StatResult::Transformed { query, stat_columns, .. } => { + StatResult::Transformed { + query, + stat_columns, + .. + } => { assert!(query.contains("MAX(\"__ggsql_aes_fill__\")"), "{}", query); assert!(query.contains("AVG(\"__ggsql_aes_pos1__\")")); assert!(stat_columns.contains(&"fill".to_string())); @@ -1316,12 +1366,15 @@ mod tests { let mut aes = Mappings::new(); aes.insert("pos1", col("__ggsql_aes_pos1__")); aes.insert("pos2", col("__ggsql_aes_pos2__")); - let schema = schema_for(&[ - ("__ggsql_aes_pos1__", false), - ("__ggsql_aes_pos2__", false), - ]); - let result = run(arr(&["y:min", "y:max"]), &aes, &schema, &[], &InlineQuantileDialect) - .unwrap(); + let schema = schema_for(&[("__ggsql_aes_pos1__", false), ("__ggsql_aes_pos2__", false)]); + let result = run( + arr(&["y:min", "y:max"]), + &aes, + &schema, + &[], + &InlineQuantileDialect, + ) + .unwrap(); match result { StatResult::Transformed { query, @@ -1375,11 +1428,21 @@ mod tests { assert!(query.contains("MIN(\"__ggsql_aes_pos2__\")"), "{}", query); assert!(query.contains("MAX(\"__ggsql_aes_pos2__\")")); // color (alias → fill) is recycled → QUANTILE_CONT(.5) appears in BOTH branches - let median_count = query.matches("QUANTILE_CONT(\"__ggsql_aes_fill__\", 0.5)").count(); - assert_eq!(median_count, 2, "color median should appear once per branch: {}", query); + let median_count = query + .matches("QUANTILE_CONT(\"__ggsql_aes_fill__\", 0.5)") + .count(); + assert_eq!( + median_count, 2, + "color median should appear once per branch: {}", + query + ); // size has no target → uses default 'mean' → AVG appears in both branches let avg_size = query.matches("AVG(\"__ggsql_aes_size__\")").count(); - assert_eq!(avg_size, 2, "size mean should appear once per branch: {}", query); + assert_eq!( + avg_size, 2, + "size mean should appear once per branch: {}", + query + ); // pos1 (no target) → mean → AVG appears in both branches let avg_pos1 = query.matches("AVG(\"__ggsql_aes_pos1__\")").count(); assert_eq!(avg_pos1, 2); @@ -1416,7 +1479,11 @@ mod tests { ) .unwrap(); match result { - StatResult::Transformed { query, stat_columns, .. } => { + StatResult::Transformed { + query, + stat_columns, + .. + } => { assert!(query.contains("UNION ALL")); // pos2max always uses mean+sdev (upper default) — a `+` between AVG and STDDEV let upper_branch_marker = "AVG(\"__ggsql_aes_pos2max__\") + STDDEV_POP"; @@ -1456,7 +1523,11 @@ mod tests { stat_columns, .. } => { - assert!(query.contains("GROUP BY \"__ggsql_aes_color__\""), "{}", query); + assert!( + query.contains("GROUP BY \"__ggsql_aes_color__\""), + "{}", + query + ); assert!(!stat_columns.contains(&"color".to_string())); assert!(query.contains("AVG(\"__ggsql_aes_pos1__\")")); assert!(query.contains("AVG(\"__ggsql_aes_pos2__\")")); @@ -1474,10 +1545,7 @@ mod tests { "fill", AestheticValue::Literal(ParameterValue::String("steelblue".to_string())), ); - let schema = schema_for(&[ - ("__ggsql_aes_pos1__", false), - ("__ggsql_aes_pos2__", false), - ]); + let schema = schema_for(&[("__ggsql_aes_pos1__", false), ("__ggsql_aes_pos2__", false)]); let result = run( ParameterValue::String("mean".to_string()), &aes, @@ -1501,10 +1569,7 @@ mod tests { let mut aes = Mappings::new(); aes.insert("pos1", col("__ggsql_aes_pos1__")); aes.insert("pos2", col("__ggsql_aes_pos2__")); - let schema = schema_for(&[ - ("__ggsql_aes_pos1__", false), - ("__ggsql_aes_pos2__", false), - ]); + let schema = schema_for(&[("__ggsql_aes_pos1__", false), ("__ggsql_aes_pos2__", false)]); // Only `y` targeted, no default → x is dropped. let result = run( ParameterValue::String("y:mean".to_string()), @@ -1579,10 +1644,7 @@ mod tests { let mut aes = Mappings::new(); aes.insert("pos1", col("__ggsql_aes_pos1__")); aes.insert("pos2", col("__ggsql_aes_pos2__")); - let schema = schema_for(&[ - ("__ggsql_aes_pos1__", false), - ("__ggsql_aes_pos2__", false), - ]); + let schema = schema_for(&[("__ggsql_aes_pos1__", false), ("__ggsql_aes_pos2__", false)]); let err = run( ParameterValue::String("size:mean".to_string()), &aes, From 88a707bdb92fac558756e8283dfc338870cd2955 Mon Sep 17 00:00:00 2001 From: Thomas Lin Pedersen Date: Mon, 4 May 2026 18:53:26 +0200 Subject: [PATCH 18/33] fix aggregation of time-dependent layers --- src/plot/layer/geom/area.rs | 5 ++ src/plot/layer/geom/bar.rs | 1 + src/plot/layer/geom/line.rs | 5 ++ src/plot/layer/geom/mod.rs | 12 ++++ src/plot/layer/geom/ribbon.rs | 5 ++ src/plot/layer/geom/stat_aggregate.rs | 84 ++++++++++++++++++++++++++- 6 files changed, 111 insertions(+), 1 deletion(-) diff --git a/src/plot/layer/geom/area.rs b/src/plot/layer/geom/area.rs index 59c4cb04..e4d1230a 100644 --- a/src/plot/layer/geom/area.rs +++ b/src/plot/layer/geom/area.rs @@ -59,6 +59,10 @@ impl GeomTrait for Area { true } + fn aggregate_domain_aesthetics(&self) -> &'static [&'static str] { + &["pos1"] + } + fn needs_stat_transform(&self, _aesthetics: &Mappings) -> bool { true } @@ -83,6 +87,7 @@ impl GeomTrait for Area { parameters, dialect, aesthetic_ctx, + self.aggregate_domain_aesthetics(), )? } else { StatResult::Identity diff --git a/src/plot/layer/geom/bar.rs b/src/plot/layer/geom/bar.rs index 65f24c80..f2990467 100644 --- a/src/plot/layer/geom/bar.rs +++ b/src/plot/layer/geom/bar.rs @@ -108,6 +108,7 @@ impl GeomTrait for Bar { parameters, dialect, aesthetic_ctx, + self.aggregate_domain_aesthetics(), ); } stat_bar_count(query, schema, aesthetics, group_by) diff --git a/src/plot/layer/geom/line.rs b/src/plot/layer/geom/line.rs index 63980856..a6fd8edd 100644 --- a/src/plot/layer/geom/line.rs +++ b/src/plot/layer/geom/line.rs @@ -45,6 +45,10 @@ impl GeomTrait for Line { true } + fn aggregate_domain_aesthetics(&self) -> &'static [&'static str] { + &["pos1"] + } + fn needs_stat_transform(&self, _aesthetics: &Mappings) -> bool { true } @@ -69,6 +73,7 @@ impl GeomTrait for Line { parameters, dialect, aesthetic_ctx, + self.aggregate_domain_aesthetics(), )? } else { StatResult::Identity diff --git a/src/plot/layer/geom/mod.rs b/src/plot/layer/geom/mod.rs index 38c8a034..1a6b7082 100644 --- a/src/plot/layer/geom/mod.rs +++ b/src/plot/layer/geom/mod.rs @@ -205,6 +205,17 @@ pub trait GeomTrait: std::fmt::Debug + std::fmt::Display + Send + Sync { false } + /// Aesthetics that the Aggregate stat must keep as group keys rather than + /// aggregating, even if their bound column is continuous. This is for + /// geoms like line/area/ribbon where one axis is the *domain* — the + /// natural group identity of each row — and the user expects "summarise + /// the other axis per domain value" without writing an explicit target. + /// + /// Default empty; line/area/ribbon override to `&["pos1"]`. + fn aggregate_domain_aesthetics(&self) -> &'static [&'static str] { + &[] + } + /// Apply statistical transformation to the layer query. /// /// The default implementation dispatches to the Aggregate stat when @@ -231,6 +242,7 @@ pub trait GeomTrait: std::fmt::Debug + std::fmt::Display + Send + Sync { parameters, dialect, aesthetic_ctx, + self.aggregate_domain_aesthetics(), ); } Ok(StatResult::Identity) diff --git a/src/plot/layer/geom/ribbon.rs b/src/plot/layer/geom/ribbon.rs index c13f56f5..5b2a390e 100644 --- a/src/plot/layer/geom/ribbon.rs +++ b/src/plot/layer/geom/ribbon.rs @@ -44,6 +44,10 @@ impl GeomTrait for Ribbon { true } + fn aggregate_domain_aesthetics(&self) -> &'static [&'static str] { + &["pos1"] + } + fn needs_stat_transform(&self, _aesthetics: &Mappings) -> bool { true } @@ -68,6 +72,7 @@ impl GeomTrait for Ribbon { parameters, dialect, aesthetic_ctx, + self.aggregate_domain_aesthetics(), )? } else { StatResult::Identity diff --git a/src/plot/layer/geom/stat_aggregate.rs b/src/plot/layer/geom/stat_aggregate.rs index a5e572cb..4cb187b4 100644 --- a/src/plot/layer/geom/stat_aggregate.rs +++ b/src/plot/layer/geom/stat_aggregate.rs @@ -602,6 +602,7 @@ pub fn apply( parameters: &HashMap, dialect: &dyn SqlDialect, aesthetic_ctx: &AestheticContext, + domain_aesthetics: &[&'static str], ) -> Result { let raw = match parameters.get("aggregate") { None | Some(ParameterValue::Null) => return Ok(StatResult::Identity), @@ -656,6 +657,15 @@ pub fn apply( Some(c) => c.to_string(), None => continue, // literals & annotation columns pass through }; + // Geom-declared domain aesthetics (e.g. `pos1` for line/area/ribbon) + // always become group keys — they identify each row, never get + // aggregated, never get dropped. + if domain_aesthetics.contains(&aes.as_str()) { + if !kept_cols.contains(&col) { + kept_cols.push(col); + } + continue; + } let info = schema.iter().find(|c| c.name == col); let is_discrete = info.map(|c| c.is_discrete).unwrap_or(false); if is_discrete { @@ -922,11 +932,31 @@ mod tests { schema: &Schema, group_by: &[String], dialect: &dyn SqlDialect, + ) -> Result { + run_with_domain(params, aes, schema, group_by, dialect, &[]) + } + + fn run_with_domain( + params: ParameterValue, + aes: &Mappings, + schema: &Schema, + group_by: &[String], + dialect: &dyn SqlDialect, + domain: &[&'static str], ) -> Result { let mut p = HashMap::new(); p.insert("aggregate".to_string(), params); let ctx = cartesian_ctx(); - apply("SELECT * FROM t", schema, aes, group_by, &p, dialect, &ctx) + apply( + "SELECT * FROM t", + schema, + aes, + group_by, + &p, + dialect, + &ctx, + domain, + ) } fn arr(items: &[&str]) -> ParameterValue { @@ -1138,6 +1168,7 @@ mod tests { &p, &InlineQuantileDialect, &ctx, + &[], ) .unwrap(); assert_eq!(result, StatResult::Identity); @@ -1451,6 +1482,57 @@ mod tests { } } + #[test] + fn domain_aesthetic_kept_as_group_key_even_when_continuous() { + // Regression test for the line/area/ribbon case: the user writes + // DRAW line ... SETTING aggregate => ('y:min', 'y:max') + // and expects pos1 (the continuous time-axis column) to be a group + // key, not a dropped numeric mapping. The geom declares pos1 as a + // domain aesthetic; the stat keeps it as a group column. + let mut aes = Mappings::new(); + aes.insert("pos1", col("__ggsql_aes_pos1__")); + aes.insert("pos2", col("__ggsql_aes_pos2__")); + let schema = schema_for(&[ + ("__ggsql_aes_pos1__", false), // continuous, would be dropped without the domain hint + ("__ggsql_aes_pos2__", false), + ]); + let result = run_with_domain( + arr(&["y:min", "y:max"]), + &aes, + &schema, + &[], + &InlineQuantileDialect, + &["pos1"], + ) + .unwrap(); + match result { + StatResult::Transformed { + query, + stat_columns, + consumed_aesthetics, + .. + } => { + // pos1 is in the GROUP BY, not aggregated. + assert!( + query.contains("GROUP BY \"__ggsql_aes_pos1__\""), + "{}", + query + ); + assert!(!query.contains("MIN(\"__ggsql_aes_pos1__\")")); + assert!(!query.contains("MAX(\"__ggsql_aes_pos1__\")")); + // pos2 is exploded into MIN and MAX branches. + assert!(query.contains("MIN(\"__ggsql_aes_pos2__\")")); + assert!(query.contains("MAX(\"__ggsql_aes_pos2__\")")); + // pos1 is NOT consumed (kept), pos2 IS consumed. + assert!(!consumed_aesthetics.contains(&"pos1".to_string())); + assert!(consumed_aesthetics.contains(&"pos2".to_string())); + // synthetic aggregate column emitted in the explosion case. + assert!(stat_columns.contains(&"aggregate".to_string())); + } + _ => panic!("expected Transformed"), + } + } + #[test] fn explosion_with_range_geom_two_defaults() { // For ribbon: pos1 + pos2min (lower) + pos2max (upper). From b1938d82a9209ad76dd9645504564d215e0f3af7 Mon Sep 17 00:00:00 2001 From: Thomas Lin Pedersen Date: Wed, 6 May 2026 13:36:30 +0200 Subject: [PATCH 19/33] add additional aggregations + examples --- doc/syntax/clause/draw.qmd | 5 +- doc/syntax/layer/type/point.qmd | 10 ++ doc/syntax/layer/type/range.qmd | 12 ++ doc/syntax/layer/type/ribbon.qmd | 8 + src/plot/layer/geom/stat_aggregate.rs | 215 ++++++++++++++++++++++---- src/reader/duckdb.rs | 8 + src/reader/mod.rs | 41 +++++ src/reader/sqlite.rs | 17 ++ 8 files changed, 287 insertions(+), 29 deletions(-) diff --git a/doc/syntax/clause/draw.qmd b/doc/syntax/clause/draw.qmd index 08cd2b2c..f507fc2e 100644 --- a/doc/syntax/clause/draw.qmd +++ b/doc/syntax/clause/draw.qmd @@ -101,15 +101,16 @@ The simple functions are: * `'count'`: Non-null tally of the bound column. * `'sum'` and `'prod'`: The sum or product -* `'min'`, `'max'`, and `'range'`: Extremes and max - min +* `'min'`, `'max'`, `'range'`, and `'mid'`: Extremes, max - min, and (min + max) / 2 * `'mean'`, and `'median'`: Central tendency * `'geomean'`, `'harmean'`, and `'rms'`: Geometric, harmonic, and root-mean-square * `'sdev'`, `'var'`, `'iqr'`, and `'se'`: Standard deviation, variance, interquartile range, and standard error * `'p05'`, `'p10'`, `'p25'`, `'p50'`, `'p75'`, `'p90'`, and `'p95'`: Percentiles +* `'first'` and `'last'`: The first or last value in the group, in row order For band functions you combine an offset with an expansion, potentially multiplied. An example could be `'mean-1.96sdev'` which does exactly what you'd expect it to be. The general form is `±` with `` being optional (defaults to `1`). -Allowed offsets are: `'mean'`, `'median'`, `'geomean'`, `'harmean'`, `'rms'`, `'sum'`, `'prod'`, `'min'`, `'max'`, and `'p05'`–`'p95'` +Allowed offsets are: `'mean'`, `'median'`, `'geomean'`, `'harmean'`, `'rms'`, `'sum'`, `'prod'`, `'min'`, `'max'`, `'mid'`, and `'p05'`–`'p95'` Allowed expansions are: `'sdev'`, `'se'`, `'var'`, `'iqr'`, and `'range'` diff --git a/doc/syntax/layer/type/point.qmd b/doc/syntax/layer/type/point.qmd index 16687e0f..46eb13ed 100644 --- a/doc/syntax/layer/type/point.qmd +++ b/doc/syntax/layer/type/point.qmd @@ -73,3 +73,13 @@ VISUALISE species AS x, bill_dep AS y FROM ggsql:penguins DRAW point SETTING position => 'jitter', distribution => 'density' ``` + +Use aggregation to show a single point per group + +```{ggsql} +VISUALISE species AS x, island AS y, body_mass AS fill, body_mass AS size + FROM ggsql:penguins +DRAW point + SETTING aggregate => ('fill:mean', 'size:count') +SCALE size TO (5, 20) +``` diff --git a/doc/syntax/layer/type/range.qmd b/doc/syntax/layer/type/range.qmd index d8c6d672..bd8d501b 100644 --- a/doc/syntax/layer/type/range.qmd +++ b/doc/syntax/layer/type/range.qmd @@ -109,3 +109,15 @@ DRAW range MAPPING low AS ymin, high AS ymax SETTING width => null ``` + +Use aggregation to calculate bounds dynamically + +```{ggsql} +VISUALISE body_mass AS x, species AS y FROM ggsql:penguins +DRAW range + MAPPING body_mass AS xmin, body_mass AS xmax + SETTING aggregate => ('min', 'max'), width => null +DRAW point + REMAPPING aggregate AS fill + SETTING aggregate => ('x:min', 'x:max'), size => 20, opacity => 1 +``` diff --git a/doc/syntax/layer/type/ribbon.qmd b/doc/syntax/layer/type/ribbon.qmd index cbc7379e..742328f4 100644 --- a/doc/syntax/layer/type/ribbon.qmd +++ b/doc/syntax/layer/type/ribbon.qmd @@ -60,3 +60,11 @@ DRAW ribbon DRAW line MAPPING MeanTemp AS y ``` + +Use aggregation to calculate bounds on the fly + +```{ggsql} +VISUALISE Day AS x, Temp AS ymin, Temp AS ymax FROM ggsql:airquality +DRAW ribbon + SETTING aggregate => ('min', 'max') +``` diff --git a/src/plot/layer/geom/stat_aggregate.rs b/src/plot/layer/geom/stat_aggregate.rs index 4cb187b4..8015fd0e 100644 --- a/src/plot/layer/geom/stat_aggregate.rs +++ b/src/plot/layer/geom/stat_aggregate.rs @@ -42,18 +42,19 @@ use crate::{GgsqlError, Mappings, Result}; pub const AGG_NAMES: &[&str] = &[ // Tallies & sums "count", "sum", "prod", // Extremes - "min", "max", "range", // Central tendency + "min", "max", "range", "mid", // Central tendency "mean", "geomean", "harmean", "rms", "median", // Spread (standalone) "sdev", "var", "iqr", // Percentiles - "p05", "p10", "p25", "p50", "p75", "p90", "p95", + "p05", "p10", "p25", "p50", "p75", "p90", "p95", // Positional (row order in the group) + "first", "last", ]; /// Stats that can appear as the *offset* (left of `±`) in a band name like /// `mean+sdev`. Single-value central or representative quantities only — /// counts/spreads are excluded. pub const OFFSET_STATS: &[&str] = &[ - "mean", "median", "geomean", "harmean", "rms", "sum", "prod", "min", "max", "p05", "p10", - "p25", "p50", "p75", "p90", "p95", + "mean", "median", "geomean", "harmean", "rms", "sum", "prod", "min", "max", "mid", "p05", + "p10", "p25", "p50", "p75", "p90", "p95", ]; /// Stats that can appear as the *expansion* (right of `±[mod]`) in a band name. @@ -406,14 +407,12 @@ fn percentile_fraction(func: &str) -> Option { } /// Build the inline SQL fragment for a *simple* stat (no band) applied to a -/// quoted column. Returns `None` for percentile-based stats when the dialect -/// lacks an inline quantile aggregate (caller switches to the correlated -/// `sql_percentile` fallback). +/// quoted column. Returns `None` when the dialect cannot express this +/// aggregate inline — for the percentile/iqr family that means the caller +/// switches to the correlated `sql_percentile` fallback; for other names it +/// means the dialect doesn't support that function and the stat layer raises +/// a clear error before SQL is built (see `validate_supported`). fn simple_stat_sql_inline(name: &str, qcol: &str, dialect: &dyn SqlDialect) -> Option { - if name == "count" { - // `count` in this position is COUNT(col): non-null tally for that column. - return Some(format!("COUNT({})", qcol)); - } if let Some(frac) = percentile_fraction(name) { let unquoted = unquote(qcol); return dialect.sql_quantile_inline(&unquoted, frac); @@ -424,21 +423,40 @@ fn simple_stat_sql_inline(name: &str, qcol: &str, dialect: &dyn SqlDialect) -> O let p25 = dialect.sql_quantile_inline(&unquoted, 0.25)?; return Some(format!("({} - {})", p75, p25)); } - Some(match name { - "sum" => format!("SUM({})", qcol), - "prod" => format!("EXP(SUM(LN({})))", qcol), - "min" => format!("MIN({})", qcol), - "max" => format!("MAX({})", qcol), - "range" => format!("(MAX({c}) - MIN({c}))", c = qcol), - "mean" => format!("AVG({})", qcol), - "geomean" => format!("EXP(AVG(LN({})))", qcol), - "harmean" => format!("(COUNT({c}) * 1.0 / SUM(1.0 / {c}))", c = qcol), - "rms" => format!("SQRT(AVG({c} * {c}))", c = qcol), - "sdev" => format!("STDDEV_POP({})", qcol), - "se" => format!("(STDDEV_POP({c}) / SQRT(COUNT({c})))", c = qcol), - "var" => format!("VAR_POP({})", qcol), - _ => return None, - }) + dialect.sql_aggregate(name, qcol) +} + +/// Whether the dialect can produce SQL for this aggregate (inline or via the +/// percentile fallback). Used to surface a clear error before SQL is built. +fn dialect_supports(name: &str, dialect: &dyn SqlDialect) -> bool { + if percentile_fraction(name).is_some() || name == "iqr" { + // Always supported: percentile path falls back to a correlated subquery + // built from `sql_percentile`, which has a portable default. + return true; + } + dialect.sql_aggregate(name, "x").is_some() +} + +/// Walk every aggregate that will be emitted and confirm the dialect supports +/// it. Returns the list of unsupported function names, deduplicated. +fn unsupported_functions( + aggregated: &[(String, String, Vec)], + dialect: &dyn SqlDialect, +) -> Vec { + let mut missing: Vec = Vec::new(); + for (_, _, specs) in aggregated { + for spec in specs { + for name in [Some(spec.offset), spec.band.as_ref().map(|b| b.expansion)] + .into_iter() + .flatten() + { + if !dialect_supports(name, dialect) && !missing.iter().any(|m| m == name) { + missing.push(name.to_string()); + } + } + } + } + missing } fn agg_sql_inline(spec: &AggSpec, qcol: &str, dialect: &dyn SqlDialect) -> Option { @@ -735,6 +753,14 @@ pub fn apply( } } + let missing = unsupported_functions(&aggregated, dialect); + if !missing.is_empty() { + return Err(GgsqlError::ValidationError(format!( + "aggregate function(s) {} are not supported by this database backend", + crate::or_list_quoted(&missing, '\''), + ))); + } + let transformed_query = match &labels { Some(ls) => build_aggregate_query(query, &aggregated, &group_cols, ls, dialect), None => build_group_by_query(query, &aggregated, &group_cols, dialect), @@ -881,7 +907,8 @@ mod tests { use crate::plot::types::{AestheticValue, ColumnInfo}; use arrow::datatypes::DataType; - /// A test dialect that mimics DuckDB's native QUANTILE_CONT support. + /// A test dialect that mimics DuckDB: native QUANTILE_CONT plus the + /// row-positional FIRST / LAST aggregates. struct InlineQuantileDialect; impl SqlDialect for InlineQuantileDialect { fn sql_quantile_inline(&self, column: &str, fraction: f64) -> Option { @@ -891,6 +918,14 @@ mod tests { fraction )) } + + fn sql_aggregate(&self, name: &str, qcol: &str) -> Option { + match name { + "first" => Some(format!("FIRST({})", qcol)), + "last" => Some(format!("LAST({})", qcol)), + _ => crate::reader::default_sql_aggregate(name, qcol), + } + } } /// A test dialect with no inline quantile support, exercising the @@ -1225,6 +1260,132 @@ mod tests { } } + #[cfg(feature = "sqlite")] + #[test] + fn sqlite_dialect_emits_portable_stddev_and_rejects_first() { + use crate::reader::sqlite::SqliteDialect; + + let mut aes = Mappings::new(); + aes.insert("pos1", col("__ggsql_aes_pos1__")); + aes.insert("pos2", col("__ggsql_aes_pos2__")); + let schema = schema_for(&[("__ggsql_aes_pos1__", false), ("__ggsql_aes_pos2__", false)]); + + // sdev must not emit STDDEV_POP (SQLite has no such function). + let result = run( + ParameterValue::String("sdev".to_string()), + &aes, + &schema, + &[], + &SqliteDialect, + ) + .unwrap(); + match result { + StatResult::Transformed { query, .. } => { + assert!( + !query.contains("STDDEV_POP"), + "SQLite dialect must not emit STDDEV_POP, got: {query}" + ); + assert!(query.contains("SQRT") && query.contains("AVG"), "{query}"); + } + _ => panic!("expected Transformed"), + } + + // first / last route through the validation error rather than emitting + // SQL that SQLite cannot run. + let err = run( + ParameterValue::String("first".to_string()), + &aes, + &schema, + &[], + &SqliteDialect, + ) + .unwrap_err(); + assert!( + format!("{err}").contains("not supported"), + "expected unsupported-function error, got: {err}" + ); + } + + #[test] + fn unsupported_aggregate_errors_with_dialect_that_lacks_function() { + // AnsiDialect's default doesn't implement FIRST/LAST. + struct AnsiTestDialect; + impl SqlDialect for AnsiTestDialect {} + + let mut aes = Mappings::new(); + aes.insert("pos1", col("__ggsql_aes_pos1__")); + aes.insert("pos2", col("__ggsql_aes_pos2__")); + let schema = schema_for(&[("__ggsql_aes_pos1__", false), ("__ggsql_aes_pos2__", false)]); + let err = run( + ParameterValue::String("first".to_string()), + &aes, + &schema, + &[], + &AnsiTestDialect, + ) + .unwrap_err(); + let msg = format!("{}", err); + assert!( + msg.contains("first") && msg.contains("not supported"), + "expected unsupported-function error mentioning 'first', got: {msg}" + ); + } + + #[test] + fn mid_emits_min_max_midpoint() { + let mut aes = Mappings::new(); + aes.insert("pos1", col("__ggsql_aes_pos1__")); + aes.insert("pos2", col("__ggsql_aes_pos2__")); + let schema = schema_for(&[("__ggsql_aes_pos1__", false), ("__ggsql_aes_pos2__", false)]); + let result = run( + ParameterValue::String("mid".to_string()), + &aes, + &schema, + &[], + &InlineQuantileDialect, + ) + .unwrap(); + match result { + StatResult::Transformed { query, .. } => { + assert!( + query + .contains("(MIN(\"__ggsql_aes_pos1__\") + MAX(\"__ggsql_aes_pos1__\")) / 2.0"), + "{}", + query + ); + } + _ => panic!("expected Transformed"), + } + } + + #[test] + fn first_and_last_emit_positional_aggregates() { + let mut aes = Mappings::new(); + aes.insert("pos1", col("__ggsql_aes_pos1__")); + aes.insert("pos2min", col("__ggsql_aes_pos2min__")); + aes.insert("pos2max", col("__ggsql_aes_pos2max__")); + let schema = schema_for(&[ + ("__ggsql_aes_pos1__", false), + ("__ggsql_aes_pos2min__", false), + ("__ggsql_aes_pos2max__", false), + ]); + let result = run( + arr(&["first", "last"]), + &aes, + &schema, + &[], + &InlineQuantileDialect, + ) + .unwrap(); + match result { + StatResult::Transformed { query, .. } => { + assert!(query.contains("FIRST(\"__ggsql_aes_pos2min__\")"), "{}", query); + assert!(query.contains("LAST(\"__ggsql_aes_pos2max__\")"), "{}", query); + } + _ => panic!("expected Transformed"), + } + } + #[test] fn two_defaults_split_lower_and_upper_for_segment() { let mut aes = Mappings::new(); diff --git a/src/reader/duckdb.rs b/src/reader/duckdb.rs index fab422c4..8cac61bd 100644 --- a/src/reader/duckdb.rs +++ b/src/reader/duckdb.rs @@ -49,6 +49,14 @@ impl super::SqlDialect for DuckDbDialect { )) } + fn sql_aggregate(&self, name: &str, qcol: &str) -> Option { + match name { + "first" => Some(format!("FIRST({})", qcol)), + "last" => Some(format!("LAST({})", qcol)), + _ => super::default_sql_aggregate(name, qcol), + } + } + fn sql_percentile(&self, column: &str, fraction: f64, from: &str, groups: &[String]) -> String { let group_filter = groups .iter() diff --git a/src/reader/mod.rs b/src/reader/mod.rs index 6c469dc1..7e02448f 100644 --- a/src/reader/mod.rs +++ b/src/reader/mod.rs @@ -226,6 +226,20 @@ pub trait SqlDialect { None } + /// SQL fragment for a simple aggregate function applied to an + /// already-quoted column expression. + /// + /// Returns `Some(expr)` when the dialect can express this aggregate inline + /// in a `GROUP BY` query. Returns `None` when the aggregate is not + /// supported by this backend; the stat layer surfaces a clear error. + /// + /// Names handled here are the entries of `stat_aggregate::AGG_NAMES` other + /// than the percentile/iqr family, which goes through [`sql_quantile_inline`] + /// / [`sql_percentile`] instead. + fn sql_aggregate(&self, name: &str, qcol: &str) -> Option { + default_sql_aggregate(name, qcol) + } + /// SQL literal for a date value (days since Unix epoch). fn sql_date_literal(&self, days_since_epoch: i32) -> String { format!( @@ -302,6 +316,33 @@ pub(crate) fn wrap_with_column_aliases(body_sql: &str, column_aliases: &[String] ) } +/// Default aggregate SQL emission, shared so dialects can opt into the standard +/// portable forms while overriding selected functions. +/// +/// Returns `None` for names this default can't express portably (today: the +/// `first` / `last` row-positional aggregates — backends that have a native +/// equivalent override [`SqlDialect::sql_aggregate`]). +pub fn default_sql_aggregate(name: &str, qcol: &str) -> Option { + let s = match name { + "count" => format!("COUNT({})", qcol), + "sum" => format!("SUM({})", qcol), + "prod" => format!("EXP(SUM(LN({})))", qcol), + "min" => format!("MIN({})", qcol), + "max" => format!("MAX({})", qcol), + "range" => format!("(MAX({c}) - MIN({c}))", c = qcol), + "mid" => format!("((MIN({c}) + MAX({c})) / 2.0)", c = qcol), + "mean" => format!("AVG({})", qcol), + "geomean" => format!("EXP(AVG(LN({})))", qcol), + "harmean" => format!("(COUNT({c}) * 1.0 / SUM(1.0 / {c}))", c = qcol), + "rms" => format!("SQRT(AVG({c} * {c}))", c = qcol), + "sdev" => format!("STDDEV_POP({})", qcol), + "se" => format!("(STDDEV_POP({c}) / SQRT(COUNT({c})))", c = qcol), + "var" => format!("VAR_POP({})", qcol), + _ => return None, + }; + Some(s) +} + pub struct AnsiDialect; impl SqlDialect for AnsiDialect {} diff --git a/src/reader/sqlite.rs b/src/reader/sqlite.rs index 67f1033b..15645f70 100644 --- a/src/reader/sqlite.rs +++ b/src/reader/sqlite.rs @@ -93,6 +93,23 @@ impl super::SqlDialect for SqliteDialect { ) } + /// Stock SQLite has no `STDDEV_POP` / `VAR_POP`, so express variance, + /// standard deviation, and standard error in portable arithmetic. Every + /// other aggregate falls through to the shared default. + fn sql_aggregate(&self, name: &str, qcol: &str) -> Option { + // Population variance with a `MAX(0, …)` floor against tiny negative + // floats from catastrophic cancellation. Both `MAX(a, b)` and `SQRT` + // are scalar functions in modern bundled SQLite (math-functions build). + let var_pop = || format!("MAX(0.0, AVG({c} * {c}) - AVG({c}) * AVG({c}))", c = qcol); + let s = match name { + "var" => var_pop(), + "sdev" => format!("SQRT({})", var_pop()), + "se" => format!("(SQRT({}) / SQRT(COUNT({c})))", var_pop(), c = qcol), + _ => return super::default_sql_aggregate(name, qcol), + }; + Some(s) + } + /// SQLite does not support `CREATE OR REPLACE`, so emit a drop-then-create /// pair. Column aliases are preserved portably via the default CTE wrapper. fn create_or_replace_temp_table_sql( From c40ea31ab88163120bac1a14533fa1f2f38f8bbf Mon Sep 17 00:00:00 2001 From: Thomas Lin Pedersen Date: Wed, 6 May 2026 14:52:55 +0200 Subject: [PATCH 20/33] Apply suggestions from code review Co-authored-by: Thomas Lin Pedersen Co-authored-by: Teun van den Brand <49372158+teunbrand@users.noreply.github.com> --- doc/syntax/clause/draw.qmd | 9 +++++---- doc/syntax/layer/type/area.qmd | 6 ++++-- doc/syntax/layer/type/line.qmd | 3 +-- doc/syntax/layer/type/range.qmd | 2 +- doc/syntax/layer/type/text.qmd | 2 +- src/plot/layer/geom/mod.rs | 1 - src/plot/layer/geom/types.rs | 3 +-- 7 files changed, 13 insertions(+), 13 deletions(-) diff --git a/doc/syntax/clause/draw.qmd b/doc/syntax/clause/draw.qmd index f507fc2e..070e5740 100644 --- a/doc/syntax/clause/draw.qmd +++ b/doc/syntax/clause/draw.qmd @@ -81,19 +81,20 @@ Some layers support aggregation of their data through the `aggregate` setting. T The setting takes a single string or an array of strings. Each string is one of: -* **Default** — `''` (no prefix). With one default the function applies to every untargeted numeric mapping. With two defaults the first is used for the lower side of range layers (e.g. `x`/`xmin`) plus all non-range layers, and the second is used for the upper side of range layers (e.g. `xend`/`xmax`). More than two defaults is an error. -* **Target** — `':'`. Applies `func` to the named aesthetic only (`` is a user-facing name like `x`, `y`, `xmin`, `xmax`, `xend`, `yend`, `color`, `size`, …). A target overrides any default for that aesthetic. +* **Untargeted** — `''` (no prefix). With one untargeted aggregation the function applies to every numeric mapping that doesn't have a targeted aggregation. With two untargeted aggregations the first is used for the lower side of range layers (e.g. `x`/`xmin`) plus all non-range layers, and the second is used for the upper side of range layers (e.g. `xend`/`xmax`). More than two untargeted aggregations is an error. +* **Targeted** — `':'`. Applies `func` to the named aesthetic only (`` is a user-facing name like `x`, `y`, `xmin`, `xmax`, `xend`, `yend`, `color`, `size`, …). A target overrides any untargeted aggregation for that aesthetic. -A numeric mapping that has neither a target nor an applicable default is dropped from the layer with a warning. +A numeric mapping is dropped from the layer with a warning, when it has neither a target nor an applicable default . You can also target the same aesthetic more than once to produce **multiple rows per group** — one for each function. For example `aggregate => ('y:min', 'y:max')` emits a min row and a max row per group, so a single `DRAW line` produces two summary lines that connect within each group rather than across them. The stat exposes a synthetic `aggregate` column tagging each row, which you can pick up with a `REMAPPING` to drive another aesthetic — e.g. `REMAPPING aggregate AS stroke` to colour the two lines differently. The column's value is built from the per-row function names of the *exploded* targets, deduplicated, and joined with `/`: * `aggregate => ('y:min', 'y:max')` → rows tagged `'min'`, `'max'`. +* `aggregate => ('y:min', 'y:max', 'color:median')` → rows tagged `'min'`, `'max'` (the single-function `color` target is recycled across rows and is not part of the label). * `aggregate => ('y:min', 'y:max', 'color:sum', 'color:prod')` → rows tagged `'min/sum'`, `'max/prod'`. * `aggregate => ('y:mean', 'y:max', 'color:mean', 'color:prod')` → rows tagged `'mean'`, `'max/prod'` (the duplicate `'mean'` collapses). -* `aggregate => ('y:min', 'y:max', 'color:median')` → rows tagged `'min'`, `'max'` (the single-function `color` target is recycled across rows and is not part of the label). + When several aesthetics are targeted with the same number of functions, they explode in lockstep (row 1 uses each aesthetic's first function, row 2 the second, and so on); aesthetics with a single function — and the unprefixed defaults — are reused unchanged across every row. Mixing different lengths above 1 is an error. diff --git a/doc/syntax/layer/type/area.qmd b/doc/syntax/layer/type/area.qmd index 5213175e..623b00a2 100644 --- a/doc/syntax/layer/type/area.qmd +++ b/doc/syntax/layer/type/area.qmd @@ -25,10 +25,12 @@ The following aesthetics are recognised by the area layer. * `orientation`: The orientation of the layer, see the [Orientation section](#orientation). One of the following: * `'aligned'` to align the layer's primary axis with the coordinate system's first axis. * `'transposed'` to align the layer's primary axis with the coordinate system's second axis. -* `aggregate`: Aggregation functions to apply per group. Either a single string or an array of strings. See an overview of aggregation function in [the `DRAW` documentation](../../clause/draw.qmd#aggregate) and more information in the *Data transformation* section below. +* `aggregate` Aggregation functions to apply per group: + * `null` apply no group aggregation (default). + * A single string or an array of strings. See an overview of aggregation function in [the `DRAW` documentation](../../clause/draw.qmd#aggregate) and more information in the *Data transformation* section below. ## Data transformation -This layer supports aggregation through the `aggregate` setting. Within each group, defined by `PARTITION BY`, all discrete mappings, and the primary axis, every numeric mapping is replaced in place by its aggregated value. Use a default like `'mean'` or target individual aesthetics with `':'`. See [the `DRAW` documentation](../../clause/draw.qmd#aggregate) for the full setting shape. +This layer supports aggregation through the `aggregate` setting. Aggregation groups are defined by `PARTITION BY`, all discrete mappings, but also the primary axis. Within each group, every numeric mapping is replaced in place by its aggregated value. Use a default like `'mean'` or target individual aesthetics with `':'`. See [the `DRAW` documentation](../../clause/draw.qmd#aggregate) for the full setting shape. Further, the area layer sorts the data along its primary axis before returning it. diff --git a/doc/syntax/layer/type/line.qmd b/doc/syntax/layer/type/line.qmd index ca266a6a..acbe32f3 100644 --- a/doc/syntax/layer/type/line.qmd +++ b/doc/syntax/layer/type/line.qmd @@ -33,7 +33,6 @@ Further, the line layer sorts the data along its primary axis before returning i If the line has a variable `stroke` or `opacity` aesthetic within groups, the line is broken into segments. Each segment gets the property of the preceding datapoint, so the last datapoint in a group does not transfer these properties. -This behavior is not compatible with aggregation. ## Orientation Line plots are sorted and connected along their primary axis. Since the primary axis cannot be deduced from the mapping it must be specified using the `orientation` setting. If you wish to create a vertical line plot, you need to set `orientation => 'transposed'` to indicate that the primary layer axis follows the second axis of the coordinate system. @@ -96,7 +95,7 @@ DRAW line SCALE linewidth TO (0, 30) ``` -Use aggregation to draw min and max lines from a set of observations on a single layer. Targeting `y` twice produces one summary line per function within the same layer, with a synthetic `aggregate` column tagging each row that you can remap to colour the lines distinctly: +Use aggregation to draw min and max lines from a set of observations on a single layer. Targeting `y` twice produces one summary row per function within the same group. A synthetic `aggregate` column tags each row with the different function names, that you can remap to colour the lines distinctly: ```{ggsql} VISUALISE Day AS x, Temp AS y FROM ggsql:airquality diff --git a/doc/syntax/layer/type/range.qmd b/doc/syntax/layer/type/range.qmd index bd8d501b..2f3f116a 100644 --- a/doc/syntax/layer/type/range.qmd +++ b/doc/syntax/layer/type/range.qmd @@ -25,7 +25,7 @@ The following aesthetics are recognised by the range layer. * `aggregate`: Aggregation functions to apply per group. Either a single string or an array of strings. See [the `DRAW` documentation](../../clause/draw.qmd#aggregate) and the *Data transformation* section below. ## Data transformation -This layer supports aggregation through the `aggregate` setting. Within each group, defined by `PARTITION BY` and all discrete mappings, every numeric mapping is replaced in place by its aggregated value, producing one range per group. Range is a range layer: with two defaults the first applies to the start point (`xmin`/`ymin`) and the second applies to the end point (`xmax`/`ymax`). Use a single default like `'mean'` to apply the same function to all values, or target individual aesthetics with `':'`. See [the `DRAW` documentation](../../clause/draw.qmd#aggregate) for the full setting shape. +This layer supports aggregation through the `aggregate` setting. Within each group, defined by `PARTITION BY` and all discrete mappings, every numeric mapping is replaced in place by its aggregated value, producing one range per group. Range is a range layer with two defaults: the first applies to the start point (`xmin`/`ymin`) and the second applies to the end point (`xmax`/`ymax`). Use a single default like `'mean'` to apply the same function to all values, or target individual aesthetics with `':'`. See [the `DRAW` documentation](../../clause/draw.qmd#aggregate) for the full setting shape. ## Orientation The orientation of range layers is deduced directly from the mapping, because the interval is mapped to the secondary axis. To create a horizontal range layer, you map the independent variable to `y` instead of `x` and the interval to `xmin` and `xmax` (assuming a default Cartesian coordinate system). diff --git a/doc/syntax/layer/type/text.qmd b/doc/syntax/layer/type/text.qmd index 58983055..1cca11d9 100644 --- a/doc/syntax/layer/type/text.qmd +++ b/doc/syntax/layer/type/text.qmd @@ -148,7 +148,7 @@ PLACE text y => (19, 19, 15) ``` -Use aggregation to place labels at their centroid +Use aggregation to place labels at their centroid. ```{ggsql} VISUALISE bill_len AS x, bill_dep AS y FROM ggsql:penguins diff --git a/src/plot/layer/geom/mod.rs b/src/plot/layer/geom/mod.rs index 1a6b7082..b9493668 100644 --- a/src/plot/layer/geom/mod.rs +++ b/src/plot/layer/geom/mod.rs @@ -297,7 +297,6 @@ pub trait GeomTrait: std::fmt::Debug + std::fmt::Display + Send + Sync { /// True when `parameters["aggregate"]` is set to a non-null string or array. pub(crate) fn has_aggregate_param(parameters: &HashMap) -> bool { match parameters.get("aggregate") { - None | Some(ParameterValue::Null) => false, Some(ParameterValue::String(_)) | Some(ParameterValue::Array(_)) => true, _ => false, } diff --git a/src/plot/layer/geom/types.rs b/src/plot/layer/geom/types.rs index e9f86bcf..5f57cc0c 100644 --- a/src/plot/layer/geom/types.rs +++ b/src/plot/layer/geom/types.rs @@ -181,8 +181,7 @@ pub use crate::plot::types::Schema; /// domain axis whether or not the layer also goes through the Aggregate stat. /// /// - `Identity` → becomes `Transformed` with ` ORDER BY `, -/// empty `stat_columns`/`dummy_columns`/`consumed_aesthetics`. Same shape as -/// the previous inline `ORDER BY` path produced. +/// empty `stat_columns`/`dummy_columns`/`consumed_aesthetics`. /// - `Transformed` → wraps the existing query in /// `SELECT * FROM () AS "__ggsql_ord__" ORDER BY ` and preserves /// the stat metadata. From d76825d23e37914226d0bded4c462f6dfe16b812 Mon Sep 17 00:00:00 2001 From: Thomas Lin Pedersen Date: Wed, 6 May 2026 15:14:26 +0200 Subject: [PATCH 21/33] apply doc changes to all layers --- doc/syntax/clause/draw.qmd | 23 ++++++++++------------- doc/syntax/layer/type/bar.qmd | 6 ++++-- doc/syntax/layer/type/line.qmd | 6 ++++-- doc/syntax/layer/type/point.qmd | 6 ++++-- doc/syntax/layer/type/range.qmd | 4 +++- doc/syntax/layer/type/ribbon.qmd | 6 ++++-- doc/syntax/layer/type/rule.qmd | 6 ++++-- doc/syntax/layer/type/segment.qmd | 6 ++++-- doc/syntax/layer/type/text.qmd | 6 ++++-- 9 files changed, 41 insertions(+), 28 deletions(-) diff --git a/doc/syntax/clause/draw.qmd b/doc/syntax/clause/draw.qmd index 070e5740..d7aba95d 100644 --- a/doc/syntax/clause/draw.qmd +++ b/doc/syntax/clause/draw.qmd @@ -84,19 +84,7 @@ The setting takes a single string or an array of strings. Each string is one of: * **Untargeted** — `''` (no prefix). With one untargeted aggregation the function applies to every numeric mapping that doesn't have a targeted aggregation. With two untargeted aggregations the first is used for the lower side of range layers (e.g. `x`/`xmin`) plus all non-range layers, and the second is used for the upper side of range layers (e.g. `xend`/`xmax`). More than two untargeted aggregations is an error. * **Targeted** — `':'`. Applies `func` to the named aesthetic only (`` is a user-facing name like `x`, `y`, `xmin`, `xmax`, `xend`, `yend`, `color`, `size`, …). A target overrides any untargeted aggregation for that aesthetic. -A numeric mapping is dropped from the layer with a warning, when it has neither a target nor an applicable default . - -You can also target the same aesthetic more than once to produce **multiple rows per group** — one for each function. For example `aggregate => ('y:min', 'y:max')` emits a min row and a max row per group, so a single `DRAW line` produces two summary lines that connect within each group rather than across them. - -The stat exposes a synthetic `aggregate` column tagging each row, which you can pick up with a `REMAPPING` to drive another aesthetic — e.g. `REMAPPING aggregate AS stroke` to colour the two lines differently. The column's value is built from the per-row function names of the *exploded* targets, deduplicated, and joined with `/`: - -* `aggregate => ('y:min', 'y:max')` → rows tagged `'min'`, `'max'`. -* `aggregate => ('y:min', 'y:max', 'color:median')` → rows tagged `'min'`, `'max'` (the single-function `color` target is recycled across rows and is not part of the label). -* `aggregate => ('y:min', 'y:max', 'color:sum', 'color:prod')` → rows tagged `'min/sum'`, `'max/prod'`. -* `aggregate => ('y:mean', 'y:max', 'color:mean', 'color:prod')` → rows tagged `'mean'`, `'max/prod'` (the duplicate `'mean'` collapses). - - -When several aesthetics are targeted with the same number of functions, they explode in lockstep (row 1 uses each aesthetic's first function, row 2 the second, and so on); aesthetics with a single function — and the unprefixed defaults — are reused unchanged across every row. Mixing different lengths above 1 is an error. +A numeric mapping is dropped from the layer with a warning, when it has neither a target nor an applicable default. The simple functions are: @@ -115,6 +103,15 @@ Allowed offsets are: `'mean'`, `'median'`, `'geomean'`, `'harmean'`, `'rms'`, `' Allowed expansions are: `'sdev'`, `'se'`, `'var'`, `'iqr'`, and `'range'` +You can also target the same aesthetic more than once to produce **multiple rows per group** — one for each function. For example `aggregate => ('y:min', 'y:max')` emits a min row and a max row per group, so a single `DRAW line` produces two summary lines that connect within each group rather than across them. When multiple rows are created a synthetic `aggregate` column is made that tags each row with the aggregation function. You can use this with a `REMAPPING` to drive another aesthetic — e.g. `REMAPPING aggregate AS stroke` to colour the two lines differently. The column's value is built from the per-row function names of the *exploded* targets, deduplicated, and joined with `/`: + +* `aggregate => ('y:min', 'y:max')` → rows tagged `'min'`, `'max'`. +* `aggregate => ('y:min', 'y:max', 'color:median')` → rows tagged `'min'`, `'max'` (the single-function `color` target is recycled across rows and is not part of the label). +* `aggregate => ('y:min', 'y:max', 'color:sum', 'color:prod')` → rows tagged `'min/sum'`, `'max/prod'`. +* `aggregate => ('y:mean', 'y:max', 'color:mean', 'color:prod')` → rows tagged `'mean'`, `'max/prod'` (the duplicate `'mean'` collapses). + +When several aesthetics are targeted with the same number of functions, they explode in lockstep (row 1 uses each aesthetic's first function, row 2 the second, and so on); aesthetics with a single function — and the unprefixed defaults — are reused unchanged across every row. Mixing different lengths above 1 is an error. + In the single-row (reduction) case aggregation applies in place — no `REMAPPING` is needed and no synthetic column is added. Only the multi-row (explosion) case described above introduces the synthetic `aggregate` column. ### `FILTER` diff --git a/doc/syntax/layer/type/bar.qmd b/doc/syntax/layer/type/bar.qmd index e71ba3f2..d32b1f88 100644 --- a/doc/syntax/layer/type/bar.qmd +++ b/doc/syntax/layer/type/bar.qmd @@ -25,12 +25,14 @@ The bar layer has no required aesthetics ## Settings * `position`: Position adjustment. One of `'identity'`, `'stack'` (default), `'dodge'`, or `'jitter'` * `width`: The width of the bars as a proportion of the available width (0 to 1) -* `aggregate`: Aggregation functions to apply per group if the secondary position has been mapped. Either a single string or an array of strings. See an overview of aggregation function in [the `DRAW` documentation](../../clause/draw.qmd#aggregate) and more information in the *Data transformation* section below. +* `aggregate` Aggregation functions to apply per group: + * `null` apply no group aggregation (default). + * A single string or an array of strings. See an overview of aggregation function in [the `DRAW` documentation](../../clause/draw.qmd#aggregate) and more information in the *Data transformation* section below. ## Data transformation If the secondary axis has not been mapped the layer will calculate counts for you and display these as the secondary axis. -If the secondary axis has been mapped you can apply aggregation through the `aggregate` setting. Within each group, defined by `PARTITION BY`, all discrete mappings, and the primary axis, every numeric mapping is replaced in place by its aggregated value. Use a default like `'mean'` or target individual aesthetics with `':'`. See [the `DRAW` documentation](../../clause/draw.qmd#aggregate) for the full setting shape. +This layer supports aggregation through the `aggregate` setting. Aggregation groups are defined by `PARTITION BY` and all discrete mappings. Within each group, every numeric mapping is replaced in place by its aggregated value. Use a default like `'mean'` or target individual aesthetics with `':'`. See [the `DRAW` documentation](../../clause/draw.qmd#aggregate) for the full setting shape. ### Properties diff --git a/doc/syntax/layer/type/line.qmd b/doc/syntax/layer/type/line.qmd index acbe32f3..89864715 100644 --- a/doc/syntax/layer/type/line.qmd +++ b/doc/syntax/layer/type/line.qmd @@ -24,10 +24,12 @@ The following aesthetics are recognised by the line layer. * `orientation`: The orientation of the layer, see the [Orientation section](#orientation). One of the following: * `'aligned'` to align the layer's primary axis with the coordinate system's first axis. * `'transposed'` to align the layer's primary axis with the coordinate system's second axis. -* `aggregate`: Aggregation functions to apply per group. Either a single string or an array of strings. See an overview of aggregation function in [the `DRAW` documentation](../../clause/draw.qmd#aggregate) and more information in the *Data transformation* section below. +* `aggregate` Aggregation functions to apply per group: + * `null` apply no group aggregation (default). + * A single string or an array of strings. See an overview of aggregation function in [the `DRAW` documentation](../../clause/draw.qmd#aggregate) and more information in the *Data transformation* section below. ## Data transformation -This layer supports aggregation through the `aggregate` setting. Within each group, defined by `PARTITION BY`, all discrete mappings, and the primary axis, every numeric mapping is replaced in place by its aggregated value to produce a summary trace. Use a default like `'mean'` to summarise the secondary axis, or target other aesthetics with `':'` (e.g. `'color:median'`). To draw min/max envelope lines, use a separate `DRAW line` layer per function, or use a [`range` layer](range.qmd) for a single range mark. See [the `DRAW` documentation](../../clause/draw.qmd#aggregate) for the full setting shape. +This layer supports aggregation through the `aggregate` setting. Aggregation groups are defined by `PARTITION BY`, all discrete mappings, but also the primary axis. Within each group, every numeric mapping is replaced in place by its aggregated value. Use a default like `'mean'` or target individual aesthetics with `':'`. See [the `DRAW` documentation](../../clause/draw.qmd#aggregate) for the full setting shape. Further, the line layer sorts the data along its primary axis before returning it. diff --git a/doc/syntax/layer/type/point.qmd b/doc/syntax/layer/type/point.qmd index 46eb13ed..b9fd5016 100644 --- a/doc/syntax/layer/type/point.qmd +++ b/doc/syntax/layer/type/point.qmd @@ -23,10 +23,12 @@ The following aesthetics are recognised by the point layer. ## Settings * `position`: Position adjustment. One of `'identity'` (default), `'stack'`, `'dodge'`, or `'jitter'` -* `aggregate`: Aggregation functions to apply per group. Either a single string or an array of strings. See an overview of aggregation function in [the `DRAW` documentation](../../clause/draw.qmd#aggregate) and more information in the *Data transformation* section below. +* `aggregate` Aggregation functions to apply per group: + * `null` apply no group aggregation (default). + * A single string or an array of strings. See an overview of aggregation function in [the `DRAW` documentation](../../clause/draw.qmd#aggregate) and more information in the *Data transformation* section below. ## Data transformation -This layer supports aggregation through the `aggregate` setting. Within each group, defined by `PARTITION BY` and all discrete mappings, every numeric mapping is replaced in place by its aggregated value. Use a default like `'mean'` or target individual aesthetics with `':'`. See [the `DRAW` documentation](../../clause/draw.qmd#aggregate) for the full setting shape. +This layer supports aggregation through the `aggregate` setting. Aggregation groups are defined by `PARTITION BY` and all discrete mappings. Within each group, every numeric mapping is replaced in place by its aggregated value. Use a default like `'mean'` or target individual aesthetics with `':'`. See [the `DRAW` documentation](../../clause/draw.qmd#aggregate) for the full setting shape. ## Orientation The point layer has no orientation. The axes are treated symmetrically. diff --git a/doc/syntax/layer/type/range.qmd b/doc/syntax/layer/type/range.qmd index 2f3f116a..43d30693 100644 --- a/doc/syntax/layer/type/range.qmd +++ b/doc/syntax/layer/type/range.qmd @@ -22,7 +22,9 @@ The following aesthetics are recognised by the range layer. ## Settings * `width`: The width of the hinges in points (must be >= 0). Defaults to 10. Can be set to `null` to not display hinges. -* `aggregate`: Aggregation functions to apply per group. Either a single string or an array of strings. See [the `DRAW` documentation](../../clause/draw.qmd#aggregate) and the *Data transformation* section below. +* `aggregate` Aggregation functions to apply per group: + * `null` apply no group aggregation (default). + * A single string or an array of strings. See an overview of aggregation function in [the `DRAW` documentation](../../clause/draw.qmd#aggregate) and more information in the *Data transformation* section below. ## Data transformation This layer supports aggregation through the `aggregate` setting. Within each group, defined by `PARTITION BY` and all discrete mappings, every numeric mapping is replaced in place by its aggregated value, producing one range per group. Range is a range layer with two defaults: the first applies to the start point (`xmin`/`ymin`) and the second applies to the end point (`xmax`/`ymax`). Use a single default like `'mean'` to apply the same function to all values, or target individual aesthetics with `':'`. See [the `DRAW` documentation](../../clause/draw.qmd#aggregate) for the full setting shape. diff --git a/doc/syntax/layer/type/ribbon.qmd b/doc/syntax/layer/type/ribbon.qmd index 742328f4..b3e2b375 100644 --- a/doc/syntax/layer/type/ribbon.qmd +++ b/doc/syntax/layer/type/ribbon.qmd @@ -23,10 +23,12 @@ The following aesthetics are recognised by the ribbon layer. ## Settings * `position`: Position adjustment. One of `'identity'` (default), `'stack'`, `'dodge'`, or `'jitter'` -* `aggregate`: Aggregation functions to apply per group. Either a single string or an array of strings. See [the `DRAW` documentation](../../clause/draw.qmd#aggregate) and the *Data transformation* section below. +* `aggregate` Aggregation functions to apply per group: + * `null` apply no group aggregation (default). + * A single string or an array of strings. See an overview of aggregation function in [the `DRAW` documentation](../../clause/draw.qmd#aggregate) and more information in the *Data transformation* section below. ## Data transformation -This layer supports aggregation through the `aggregate` setting. Within each group, defined by `PARTITION BY` and all discrete mappings, every numeric mapping is replaced in place by its aggregated value, producing one ribbon per group. Ribon is a range layer: with two defaults the first applies to the start point (`xmin`/`ymin`) and the second applies to the end point (`xmax`/`ymax`). Use a single default like `'mean'` to apply the same function to all values, or target individual aesthetics with `':'`. See [the `DRAW` documentation](../../clause/draw.qmd#aggregate) for the full setting shape. +This layer supports aggregation through the `aggregate` setting. Within each group, defined by `PARTITION BY` and all discrete mappings, every numeric mapping is replaced in place by its aggregated value, producing one ribbon per group. Ribon is a range layer with two defaults: the first applies to the start point (`xmin`/`ymin`) and the second applies to the end point (`xmax`/`ymax`). Use a single default like `'mean'` to apply the same function to all values, or target individual aesthetics with `':'`. See [the `DRAW` documentation](../../clause/draw.qmd#aggregate) for the full setting shape. ## Orientation Ribbon layers are sorted and connected along their primary axis. The orientation is deduced directly from the mapping, because the interval is mapped to the secondary axis. To create a vertical ribbon layer you map the independent variable to `y` instead of `x` and the interval to `xmin` and `xmax` (assuming a default Cartesian coordinate system). diff --git a/doc/syntax/layer/type/rule.qmd b/doc/syntax/layer/type/rule.qmd index 470ea39e..032e5ea1 100644 --- a/doc/syntax/layer/type/rule.qmd +++ b/doc/syntax/layer/type/rule.qmd @@ -25,10 +25,12 @@ The following aesthetics are recognised by the rule layer. ## Settings * `position`: Position adjustment. One of `'identity'` (default), `'stack'`, `'dodge'`, or `'jitter'` -* `aggregate`: Aggregation functions to apply per group. Either a single string or an array of strings. See an overview of aggregation function in [the `DRAW` documentation](../../clause/draw.qmd#aggregate) and more information in the *Data transformation* section below. +* `aggregate` Aggregation functions to apply per group: + * `null` apply no group aggregation (default). + * A single string or an array of strings. See an overview of aggregation function in [the `DRAW` documentation](../../clause/draw.qmd#aggregate) and more information in the *Data transformation* section below. ## Data transformation -This layer supports aggregation through the `aggregate` setting. Within each group, defined by `PARTITION BY` and all discrete mappings, every numeric mapping is replaced in place by its aggregated value. Use a default like `'mean'` or target individual aesthetics with `':'`. See [the `DRAW` documentation](../../clause/draw.qmd#aggregate) for the full setting shape. +This layer supports aggregation through the `aggregate` setting. Aggregation groups are defined by `PARTITION BY` and all discrete mappings. Within each group, every numeric mapping is replaced in place by its aggregated value. Use a default like `'mean'` or target individual aesthetics with `':'`. See [the `DRAW` documentation](../../clause/draw.qmd#aggregate) for the full setting shape. For diagonal lines, the position aesthetic determines the intercept: diff --git a/doc/syntax/layer/type/segment.qmd b/doc/syntax/layer/type/segment.qmd index ac759829..f2aab57b 100644 --- a/doc/syntax/layer/type/segment.qmd +++ b/doc/syntax/layer/type/segment.qmd @@ -25,10 +25,12 @@ For axis-aligned intervals where one coordinate is shared between the start and ## Settings * `position`: Position adjustment. One of `'identity'` (default), `'stack'`, `'dodge'`, or `'jitter'` -* `aggregate`: Aggregation functions to apply per group. Either a single string or an array of strings. See an overview of aggregation function in [the `DRAW` documentation](../../clause/draw.qmd#aggregate) and more information in the *Data transformation* section below. +* `aggregate` Aggregation functions to apply per group: + * `null` apply no group aggregation (default). + * A single string or an array of strings. See an overview of aggregation function in [the `DRAW` documentation](../../clause/draw.qmd#aggregate) and more information in the *Data transformation* section below. ## Data transformation -This layer supports aggregation through the `aggregate` setting. Within each group, defined by `PARTITION BY` and all discrete mappings, every numeric mapping is replaced in place by its aggregated value, producing one segment per group. Segment is a range layer: with two defaults the first applies to the start point (`x`/`y`) and the second applies to the end point (`xend`/`yend`). Use a single default like `'mean'` to apply the same function to all four endpoints, or target individual aesthetics with `':'`. See [the `DRAW` documentation](../../clause/draw.qmd#aggregate) for the full setting shape. +This layer supports aggregation through the `aggregate` setting. Within each group, defined by `PARTITION BY` and all discrete mappings, every numeric mapping is replaced in place by its aggregated value, producing one segment per group. Segment is a range layer with two defaults: the first applies to the start point (`x`/`y`) and the second applies to the end point (`xend`/`yend`). Use a single default like `'mean'` to apply the same function to all four endpoints, or target individual aesthetics with `':'`. See [the `DRAW` documentation](../../clause/draw.qmd#aggregate) for the full setting shape. ## Orientation The segment layer has no orientations. The axes are treated symmetrically. diff --git a/doc/syntax/layer/type/text.qmd b/doc/syntax/layer/type/text.qmd index 1cca11d9..6a431033 100644 --- a/doc/syntax/layer/type/text.qmd +++ b/doc/syntax/layer/type/text.qmd @@ -35,7 +35,9 @@ The following aesthetics are recognised by the text layer. * a 2-element numeric array `[h, v]` where the first number is the horizontal offset and the second number is the vertical offset. * `format` Formatting specifier, see explanation below. * `position`: Position adjustment. One of `'identity'` (default), `'stack'`, `'dodge'`, or `'jitter'` -* `aggregate`: Aggregation functions to apply per group. Either a single string or an array of strings. See an overview of aggregation function in [the `DRAW` documentation](../../clause/draw.qmd#aggregate) and more information in the *Data transformation* section below. +* `aggregate` Aggregation functions to apply per group: + * `null` apply no group aggregation (default). + * A single string or an array of strings. See an overview of aggregation function in [the `DRAW` documentation](../../clause/draw.qmd#aggregate) and more information in the *Data transformation* section below. ### Format The `format` setting can take a string that will be used in formatting the `label` aesthetic. @@ -67,7 +69,7 @@ Known formatters are: * `x`/`X`: Unsigned hexadecimal ## Data transformation -This layer supports aggregation through the `aggregate` setting. Within each group, defined by `PARTITION BY` and all discrete mappings, every numeric mapping is replaced in place by its aggregated value. Use a default like `'mean'` or target individual aesthetics with `':'`. See [the `DRAW` documentation](../../clause/draw.qmd#aggregate) for the full setting shape. +This layer supports aggregation through the `aggregate` setting. Aggregation groups are defined by `PARTITION BY` and all discrete mappings. Within each group, every numeric mapping is replaced in place by its aggregated value. Use a default like `'mean'` or target individual aesthetics with `':'`. See [the `DRAW` documentation](../../clause/draw.qmd#aggregate) for the full setting shape. ## Orientation The text layer has no orientation. The axes are treated symmetrically. From bcbedba4841218b4333b86c13de46b5535af39d7 Mon Sep 17 00:00:00 2001 From: Thomas Lin Pedersen Date: Thu, 7 May 2026 10:45:21 +0200 Subject: [PATCH 22/33] support first and last in ANSI, add diff --- .gitignore | 1 + doc/syntax/clause/draw.qmd | 3 +- doc/syntax/layer/type/range.qmd | 12 ++ src/plot/layer/geom/stat_aggregate.rs | 215 +++++++++++++++++++++++--- src/reader/duckdb.rs | 1 + src/reader/mod.rs | 17 +- 6 files changed, 222 insertions(+), 27 deletions(-) diff --git a/.gitignore b/.gitignore index be1bbfdd..1432765b 100644 --- a/.gitignore +++ b/.gitignore @@ -96,6 +96,7 @@ criterion/ # Claude Code specific .claude/ +memory # R specific *.Rproj.user diff --git a/doc/syntax/clause/draw.qmd b/doc/syntax/clause/draw.qmd index d7aba95d..010857ff 100644 --- a/doc/syntax/clause/draw.qmd +++ b/doc/syntax/clause/draw.qmd @@ -95,7 +95,8 @@ The simple functions are: * `'geomean'`, `'harmean'`, and `'rms'`: Geometric, harmonic, and root-mean-square * `'sdev'`, `'var'`, `'iqr'`, and `'se'`: Standard deviation, variance, interquartile range, and standard error * `'p05'`, `'p10'`, `'p25'`, `'p50'`, `'p75'`, `'p90'`, and `'p95'`: Percentiles -* `'first'` and `'last'`: The first or last value in the group, in row order +* `'first'` and `'last'`: The first or last value in the group, in row order. Note that the row order within a group is engine-defined unless the source query has an `ORDER BY` — these are most useful when the upstream SQL provides an explicit ordering. +* `'diff'`: `last - first`. The change between the first and last value in row order — same ordering caveat applies. For band functions you combine an offset with an expansion, potentially multiplied. An example could be `'mean-1.96sdev'` which does exactly what you'd expect it to be. The general form is `±` with `` being optional (defaults to `1`). diff --git a/doc/syntax/layer/type/range.qmd b/doc/syntax/layer/type/range.qmd index 43d30693..9464c1c6 100644 --- a/doc/syntax/layer/type/range.qmd +++ b/doc/syntax/layer/type/range.qmd @@ -112,6 +112,18 @@ DRAW range SETTING width => null ``` +```{ggsql} +VISUALISE Date AS x, Temp AS ymin, Temp AS ymax, Temp AS color + FROM ggsql:airquality +DRAW range + REMAPPING aggregate AS linewidth + SETTING aggregate => ('x:first', 'ymin:first', 'ymin:min', 'ymax:last', 'ymax:max', 'color:diff'), width => null + PARTITION BY Week +SCALE linewidth TO (5, 1) +SCALE color FROM (-20, 20) TO managua + SETTING reverse => true, oob => 'keep' +``` + Use aggregation to calculate bounds dynamically ```{ggsql} diff --git a/src/plot/layer/geom/stat_aggregate.rs b/src/plot/layer/geom/stat_aggregate.rs index 8015fd0e..049b1509 100644 --- a/src/plot/layer/geom/stat_aggregate.rs +++ b/src/plot/layer/geom/stat_aggregate.rs @@ -45,8 +45,9 @@ pub const AGG_NAMES: &[&str] = &[ "min", "max", "range", "mid", // Central tendency "mean", "geomean", "harmean", "rms", "median", // Spread (standalone) "sdev", "var", "iqr", // Percentiles - "p05", "p10", "p25", "p50", "p75", "p90", "p95", // Positional (row order in the group) - "first", "last", + "p05", "p10", "p25", "p50", "p75", "p90", + "p95", // Positional (row order in the group) + "first", "last", "diff", ]; /// Stats that can appear as the *offset* (left of `±`) in a band name like @@ -783,6 +784,68 @@ pub fn apply( }) } +/// CTE preamble plus the alias the caller should `FROM`. When any emitted +/// aggregate references the `__ggsql_rn__` / `__ggsql_max_rn__` columns +/// (the dialect-portable form of `first` / `last`), wrap the source CTE in a +/// row-numbered layer. +fn source_cte_chain( + query: &str, + aggregated: &[(String, String, Vec)], + group_cols: &[String], + dialect: &dyn SqlDialect, +) -> (String, &'static str) { + let raw_src = "\"__ggsql_stat_src__\""; + if !needs_row_position(aggregated, dialect) { + return (format!("WITH {raw_src} AS ({query})"), raw_src); + } + let rn_src = "\"__ggsql_stat_src_rn__\""; + let group_select: Vec = group_cols.iter().map(|c| naming::quote_ident(c)).collect(); + // ORDER BY (SELECT 1) is the canonical "no real ordering" stand-in: it + // satisfies the standard's required ORDER BY for window functions while + // letting the engine pick the row order — same indeterminacy as DuckDB's + // native FIRST() without a user ORDER BY. + let partition = if group_select.is_empty() { + String::new() + } else { + format!("PARTITION BY {} ", group_select.join(", ")) + }; + let cte = format!( + "WITH {raw_src} AS ({query}), {rn_src} AS (\ + SELECT *, \ + ROW_NUMBER() OVER ({partition}ORDER BY (SELECT 1)) AS \"__ggsql_rn__\", \ + COUNT(*) OVER ({partition_no_order}) AS \"__ggsql_max_rn__\" \ + FROM {raw_src}\ + )", + partition_no_order = partition.trim_end(), + ); + (cte, rn_src) +} + +/// True iff at least one aggregate spec, after the dialect emits its SQL, +/// references the row-position columns. Backends with native `FIRST`/`LAST` +/// (DuckDB) emit a string that doesn't mention `__ggsql_rn__`, and so don't +/// pay for the extra window functions. +fn needs_row_position( + aggregated: &[(String, String, Vec)], + dialect: &dyn SqlDialect, +) -> bool { + for (_, _, specs) in aggregated { + for spec in specs { + for name in [Some(spec.offset), spec.band.as_ref().map(|b| b.expansion)] + .into_iter() + .flatten() + { + if let Some(sql) = dialect.sql_aggregate(name, "x") { + if sql.contains("__ggsql_rn__") { + return true; + } + } + } + } + } + false +} + /// Build the single-row `WITH src AS () SELECT , /// FROM src AS "__ggsql_qt__" GROUP BY ` query. Each aggregated /// aesthetic's function list is length 1 here. @@ -795,8 +858,8 @@ fn build_group_by_query( group_cols: &[String], dialect: &dyn SqlDialect, ) -> String { - let src_alias = "\"__ggsql_stat_src__\""; let outer_alias = "\"__ggsql_qt__\""; + let (with_clause, src_alias) = source_cte_chain(query, aggregated, group_cols, dialect); let group_select: Vec = group_cols.iter().map(|c| naming::quote_ident(c)).collect(); let group_by_clause = if group_cols.is_empty() { @@ -821,10 +884,9 @@ fn build_group_by_query( } format!( - "WITH {src} AS ({query}) SELECT {sel} FROM {src} AS {outer}{gb}", - src = src_alias, - query = query, + "{with_clause} SELECT {sel} FROM {src} AS {outer}{gb}", sel = select_parts.join(", "), + src = src_alias, outer = outer_alias, gb = group_by_clause, ) @@ -841,8 +903,8 @@ fn build_aggregate_query( labels: &[String], dialect: &dyn SqlDialect, ) -> String { - let src_alias = "\"__ggsql_stat_src__\""; let outer_alias = "\"__ggsql_qt__\""; + let (with_clause, src_alias) = source_cte_chain(query, aggregated, group_cols, dialect); let group_select: Vec = group_cols.iter().map(|c| naming::quote_ident(c)).collect(); let group_by_clause = if group_cols.is_empty() { @@ -889,9 +951,7 @@ fn build_aggregate_query( .collect(); format!( - "WITH {src} AS ({query}) {body}", - src = src_alias, - query = query, + "{with_clause} {body}", body = branches.join(" UNION ALL "), ) } @@ -923,6 +983,7 @@ mod tests { match name { "first" => Some(format!("FIRST({})", qcol)), "last" => Some(format!("LAST({})", qcol)), + "diff" => Some(format!("(LAST({c}) - FIRST({c}))", c = qcol)), _ => crate::reader::default_sql_aggregate(name, qcol), } } @@ -1262,7 +1323,7 @@ mod tests { #[cfg(feature = "sqlite")] #[test] - fn sqlite_dialect_emits_portable_stddev_and_rejects_first() { + fn sqlite_dialect_emits_portable_stddev_and_first() { use crate::reader::sqlite::SqliteDialect; let mut aes = Mappings::new(); @@ -1290,27 +1351,48 @@ mod tests { _ => panic!("expected Transformed"), } - // first / last route through the validation error rather than emitting - // SQL that SQLite cannot run. - let err = run( + // first now uses the portable ROW_NUMBER + MAX(CASE) form. It must run + // on SQLite without `FIRST` ever appearing as an aggregate call. + let result = run( ParameterValue::String("first".to_string()), &aes, &schema, &[], &SqliteDialect, ) - .unwrap_err(); - assert!( - format!("{err}").contains("not supported"), - "expected unsupported-function error, got: {err}" - ); + .unwrap(); + match result { + StatResult::Transformed { query, .. } => { + assert!( + query.contains("ROW_NUMBER()"), + "expected ROW_NUMBER prep, got: {query}" + ); + assert!( + query.contains("\"__ggsql_rn__\" = 1"), + "expected first via rn=1, got: {query}" + ); + assert!( + !query.contains("FIRST(\""), + "must not call FIRST as an aggregate, got: {query}" + ); + } + _ => panic!("expected Transformed"), + } } #[test] fn unsupported_aggregate_errors_with_dialect_that_lacks_function() { - // AnsiDialect's default doesn't implement FIRST/LAST. - struct AnsiTestDialect; - impl SqlDialect for AnsiTestDialect {} + // A dialect that explicitly opts out of `first` (returns None) must + // produce the validation error rather than emitting broken SQL. + struct OptOutDialect; + impl SqlDialect for OptOutDialect { + fn sql_aggregate(&self, name: &str, qcol: &str) -> Option { + if name == "first" { + return None; + } + crate::reader::default_sql_aggregate(name, qcol) + } + } let mut aes = Mappings::new(); aes.insert("pos1", col("__ggsql_aes_pos1__")); @@ -1321,7 +1403,7 @@ mod tests { &aes, &schema, &[], - &AnsiTestDialect, + &OptOutDialect, ) .unwrap_err(); let msg = format!("{}", err); @@ -1358,6 +1440,93 @@ mod tests { } } + #[test] + fn diff_uses_row_position_and_subtracts_first_from_last() { + let mut aes = Mappings::new(); + aes.insert("pos1", col("__ggsql_aes_pos1__")); + aes.insert("pos2", col("__ggsql_aes_pos2__")); + let schema = schema_for(&[("__ggsql_aes_pos1__", false), ("__ggsql_aes_pos2__", false)]); + + // AnsiDialect path: portable rn-based form for last - first. + struct AnsiTestDialect; + impl SqlDialect for AnsiTestDialect {} + let result = run( + ParameterValue::String("diff".to_string()), + &aes, + &schema, + &[], + &AnsiTestDialect, + ) + .unwrap(); + match result { + StatResult::Transformed { query, .. } => { + assert!(query.contains("ROW_NUMBER()"), "{query}"); + assert!( + query.contains("\"__ggsql_rn__\" = \"__ggsql_max_rn__\""), + "{query}" + ); + assert!(query.contains("\"__ggsql_rn__\" = 1"), "{query}"); + assert!(query.contains(" - "), "expected subtraction, got: {query}"); + } + _ => panic!("expected Transformed"), + } + + // Native-FIRST/LAST path: no rn CTE. + let result = run( + ParameterValue::String("diff".to_string()), + &aes, + &schema, + &[], + &InlineQuantileDialect, + ) + .unwrap(); + match result { + StatResult::Transformed { query, .. } => { + assert!( + query.contains("LAST(") && query.contains("FIRST("), + "expected native LAST/FIRST: {query}" + ); + assert!( + !query.contains("__ggsql_rn__"), + "native dialect must not add ROW_NUMBER prep: {query}" + ); + } + _ => panic!("expected Transformed"), + } + } + + #[cfg(feature = "duckdb")] + #[test] + fn duckdb_first_skips_row_number_cte() { + use crate::reader::duckdb::DuckDbDialect; + + let mut aes = Mappings::new(); + aes.insert("pos1", col("__ggsql_aes_pos1__")); + aes.insert("pos2", col("__ggsql_aes_pos2__")); + let schema = schema_for(&[("__ggsql_aes_pos1__", false), ("__ggsql_aes_pos2__", false)]); + let result = run( + ParameterValue::String("first".to_string()), + &aes, + &schema, + &[], + &DuckDbDialect, + ) + .unwrap(); + match result { + StatResult::Transformed { query, .. } => { + assert!( + query.contains("FIRST(\""), + "expected native FIRST aggregate, got: {query}" + ); + assert!( + !query.contains("__ggsql_rn__"), + "DuckDB has native FIRST, must not add ROW_NUMBER prep: {query}" + ); + } + _ => panic!("expected Transformed"), + } + } + #[test] fn first_and_last_emit_positional_aggregates() { let mut aes = Mappings::new(); diff --git a/src/reader/duckdb.rs b/src/reader/duckdb.rs index 8cac61bd..42bb9676 100644 --- a/src/reader/duckdb.rs +++ b/src/reader/duckdb.rs @@ -53,6 +53,7 @@ impl super::SqlDialect for DuckDbDialect { match name { "first" => Some(format!("FIRST({})", qcol)), "last" => Some(format!("LAST({})", qcol)), + "diff" => Some(format!("(LAST({c}) - FIRST({c}))", c = qcol)), _ => super::default_sql_aggregate(name, qcol), } } diff --git a/src/reader/mod.rs b/src/reader/mod.rs index 7e02448f..054f60d0 100644 --- a/src/reader/mod.rs +++ b/src/reader/mod.rs @@ -319,9 +319,10 @@ pub(crate) fn wrap_with_column_aliases(body_sql: &str, column_aliases: &[String] /// Default aggregate SQL emission, shared so dialects can opt into the standard /// portable forms while overriding selected functions. /// -/// Returns `None` for names this default can't express portably (today: the -/// `first` / `last` row-positional aggregates — backends that have a native -/// equivalent override [`SqlDialect::sql_aggregate`]). +/// `first` / `last` are expressed as `MAX(CASE WHEN __ggsql_rn__ = … THEN col END)`, +/// which depends on the row-number columns the stat layer injects when any +/// aggregate references them. Backends with a cheaper native equivalent +/// (e.g. DuckDB's `FIRST`/`LAST`) override [`SqlDialect::sql_aggregate`]. pub fn default_sql_aggregate(name: &str, qcol: &str) -> Option { let s = match name { "count" => format!("COUNT({})", qcol), @@ -338,6 +339,16 @@ pub fn default_sql_aggregate(name: &str, qcol: &str) -> Option { "sdev" => format!("STDDEV_POP({})", qcol), "se" => format!("(STDDEV_POP({c}) / SQRT(COUNT({c})))", c = qcol), "var" => format!("VAR_POP({})", qcol), + "first" => format!("MAX(CASE WHEN \"__ggsql_rn__\" = 1 THEN {} END)", qcol), + "last" => format!( + "MAX(CASE WHEN \"__ggsql_rn__\" = \"__ggsql_max_rn__\" THEN {} END)", + qcol + ), + "diff" => format!( + "(MAX(CASE WHEN \"__ggsql_rn__\" = \"__ggsql_max_rn__\" THEN {c} END) \ + - MAX(CASE WHEN \"__ggsql_rn__\" = 1 THEN {c} END))", + c = qcol + ), _ => return None, }; Some(s) From 905063d219b77d186901efbf8a4eb001b89ba512 Mon Sep 17 00:00:00 2001 From: Thomas Lin Pedersen Date: Thu, 7 May 2026 10:46:09 +0200 Subject: [PATCH 23/33] support tile --- doc/syntax/layer/type/tile.qmd | 14 ++ src/plot/layer/geom/tile.rs | 267 ++++++++++++++++++++++++++++++++- 2 files changed, 276 insertions(+), 5 deletions(-) diff --git a/doc/syntax/layer/type/tile.qmd b/doc/syntax/layer/type/tile.qmd index e700092c..99256e43 100644 --- a/doc/syntax/layer/type/tile.qmd +++ b/doc/syntax/layer/type/tile.qmd @@ -37,11 +37,16 @@ Alternatively, use only the center, which will set `height` to 1 by default. ## Settings * `position`: Position adjustment. One of `'identity'` (default), `'stack'`, `'dodge'`, or `'jitter'` +* `aggregate` Aggregation functions to apply per group: + * `null` apply no group aggregation (default). + * A single string or an array of strings. See an overview of aggregation function in [the `DRAW` documentation](../../clause/draw.qmd#aggregate) and more information in the *Data transformation* section below. ## Data transformation. When the primary aesthetics are continuous, primary data is reparameterised to {start, end}, e.g. `xmin` and `xmax`. When the secondary aesthetics are continuous, secondary data is reparameterised to {start, end}, e.g. `ymin` and `ymax`. +This layer also supports aggregation through the `aggregate` setting. Aggregation groups are defined by `PARTITION BY` and all discrete mappings. Within each group, every numeric mapping is replaced in place by its aggregated value. Use a default like `'mean'` or target individual aesthetics with `':'`. See [the `DRAW` documentation](../../clause/draw.qmd#aggregate) for the full setting shape. The position parameterisation runs after aggregation, so a heatmap from raw rows is just one `aggregate => ''` setting away. + ## Orientation The tile layer has no orientation. The axes are treated symmetrically. @@ -91,6 +96,15 @@ VISUALISE start AS xmin, end AS xmax, min AS ymin, max AS ymax DRAW tile ``` +Building a heatmap from raw rows by aggregating per cell. + +```{ggsql} +VISUALISE FROM ggsql:airquality +DRAW tile + MAPPING Month AS x, Day AS y, Temp AS fill + SETTING aggregate => 'mean' +``` + Using a tile as an annotation. Note we're using the `PLACE` clause here instead of `DRAW` because we're not mapping from data. ```{ggsql} diff --git a/src/plot/layer/geom/tile.rs b/src/plot/layer/geom/tile.rs index fea51d38..63bcd68d 100644 --- a/src/plot/layer/geom/tile.rs +++ b/src/plot/layer/geom/tile.rs @@ -2,12 +2,16 @@ use std::collections::HashMap; +use super::stat_aggregate; use super::types::POSITION_VALUES; use super::types::{get_column_name, get_quoted_column_name}; -use super::{DefaultAesthetics, GeomTrait, GeomType, ParamConstraint, StatResult}; +use super::{ + has_aggregate_param, DefaultAesthetics, GeomTrait, GeomType, ParamConstraint, StatResult, +}; use crate::naming; -use crate::plot::types::{DefaultAestheticValue, ParameterValue}; +use crate::plot::types::{ColumnInfo, DefaultAestheticValue, ParameterValue}; use crate::plot::{DefaultParamValue, ParamDefinition}; +use crate::reader::SqlDialect; use crate::{DataFrame, GgsqlError, Mappings, Result}; use super::types::Schema; @@ -95,6 +99,20 @@ impl GeomTrait for Tile { true } + fn supports_aggregate(&self) -> bool { + true + } + + /// Every spatial slot is pinned as a group key — the rectangle's position + /// and size *define* the group, they are never the thing being summarised. + /// Material aesthetics (fill, stroke, opacity, …) pass through to the + /// aggregate as normal. + fn aggregate_domain_aesthetics(&self) -> &'static [&'static str] { + &[ + "pos1", "pos1min", "pos1max", "width", "pos2", "pos2min", "pos2max", "height", + ] + } + fn apply_stat_transform( &self, query: &str, @@ -103,11 +121,117 @@ impl GeomTrait for Tile { group_by: &[String], parameters: &HashMap, _execute_query: &dyn Fn(&str) -> Result, - _dialect: &dyn crate::reader::SqlDialect, - _aesthetic_ctx: &crate::plot::aesthetic::AestheticContext, + dialect: &dyn SqlDialect, + aesthetic_ctx: &crate::plot::aesthetic::AestheticContext, ) -> Result { - stat_tile(query, schema, aesthetics, group_by, parameters) + // When `aggregate` is set, collapse rows first, then run the standard + // tile parameter consolidation over the aggregated result. The wrapper + // re-aliases stat-prefixed columns back to `__ggsql_aes_*` so stat_tile + // sees the same column shape as it does in the unaggregated path. When + // aggregate explodes (multi-function), stat_tile is given an extended + // schema so it passes the synthetic `__ggsql_stat_aggregate__` tag + // through to layer.rs (which uses it to drive `partition_by`). + let (working_query, exploded) = if has_aggregate_param(parameters) { + let agg = stat_aggregate::apply( + query, + schema, + aesthetics, + group_by, + parameters, + dialect, + aesthetic_ctx, + self.aggregate_domain_aesthetics(), + )?; + match agg { + StatResult::Transformed { + query: agg_query, + stat_columns: agg_stats, + consumed_aesthetics, + .. + } => { + let exploded = agg_stats.iter().any(|s| s == "aggregate"); + ( + rename_agg_stats_to_aes(agg_query, &consumed_aesthetics), + exploded, + ) + } + StatResult::Identity => (query.to_string(), false), + } + } else { + (query.to_string(), false) + }; + + // For exploded aggregate, splice the synthetic stat column into the + // schema so stat_tile's pass-through projection emits it. Avoids + // dropping the per-row function tag that `partition_by` needs. + let extended_schema: Schema; + let schema_for_tile = if exploded { + extended_schema = schema + .iter() + .cloned() + .chain(std::iter::once(ColumnInfo { + name: naming::stat_column("aggregate"), + dtype: arrow::datatypes::DataType::Utf8, + is_discrete: true, + min: None, + max: None, + })) + .collect(); + &extended_schema + } else { + schema + }; + + let tile_result = + stat_tile(&working_query, schema_for_tile, aesthetics, group_by, parameters)?; + + if exploded { + if let StatResult::Transformed { + query, + mut stat_columns, + dummy_columns, + consumed_aesthetics, + } = tile_result + { + if !stat_columns.iter().any(|s| s == "aggregate") { + stat_columns.push("aggregate".to_string()); + } + return Ok(StatResult::Transformed { + query, + stat_columns, + dummy_columns, + consumed_aesthetics, + }); + } + } + Ok(tile_result) + } +} + +/// Wrap an aggregated query so each `__ggsql_stat___` column is also +/// exposed as `__ggsql_aes___`. Lets downstream stages treat the +/// aggregated values as if they were original aesthetic columns, which is +/// exactly the substitution the tile layer wants when only material +/// aesthetics get aggregated. +fn rename_agg_stats_to_aes(agg_query: String, consumed: &[String]) -> String { + if consumed.is_empty() { + return agg_query; } + let aliases: Vec = consumed + .iter() + .map(|aes| { + format!( + "{} AS {}", + naming::quote_ident(&naming::stat_column(aes)), + naming::quote_ident(&naming::aesthetic_column(aes)), + ) + }) + .collect(); + format!( + "SELECT *, {} FROM ({}) AS \"__ggsql_post_agg__\"", + aliases.join(", "), + agg_query + ) } impl std::fmt::Display for Tile { @@ -899,6 +1023,139 @@ mod tests { } } + #[test] + fn test_aggregate_dispatches_to_aggregate_then_tile() { + use crate::plot::aesthetic::AestheticContext; + use crate::reader::AnsiDialect; + + let mut aesthetics = Mappings::new(); + for aes in ["pos1", "pos2", "fill"] { + aesthetics.insert( + aes.to_string(), + AestheticValue::standard_column(naming::aesthetic_column(aes)), + ); + } + // Heatmap shape: discrete x and y, continuous fill. + let mut schema = create_schema(&["pos1", "pos2"]); + schema.push(ColumnInfo { + name: "__ggsql_aes_fill__".to_string(), + dtype: DataType::Float64, + is_discrete: false, + min: None, + max: None, + }); + let ctx = AestheticContext::from_static(&["x", "y"], &[]); + let mut parameters = HashMap::new(); + parameters.insert( + "aggregate".to_string(), + ParameterValue::String("mean".to_string()), + ); + + let result = Tile + .apply_stat_transform( + "SELECT * FROM data", + &schema, + &aesthetics, + &[], + ¶meters, + &|_| panic!("execute_query should not run during stat building"), + &AnsiDialect, + &ctx, + ) + .unwrap(); + + match result { + StatResult::Transformed { query, .. } => { + // Aggregate stage: GROUP BY pos1/pos2, AVG of fill into a stat column. + assert!(query.contains("GROUP BY"), "expected GROUP BY, got: {query}"); + assert!( + query.contains("AVG(\"__ggsql_aes_fill__\")"), + "expected AVG over fill, got: {query}" + ); + // Re-alias stage: stat fill column re-exposed as the aesthetic name. + let expected_alias = format!( + "{} AS {}", + naming::quote_ident(&naming::stat_column("fill")), + naming::quote_ident(&naming::aesthetic_column("fill")), + ); + assert!( + query.contains(&expected_alias), + "expected re-alias '{expected_alias}', got: {query}" + ); + // Tile stage: discrete-x position computation runs on top. + assert!( + query.contains("\"__ggsql_aes_pos1__\" AS \"__ggsql_stat_pos1"), + "expected tile pos1 stat, got: {query}" + ); + } + _ => panic!("expected Transformed"), + } + } + + #[test] + fn test_aggregate_explosion_propagates_synthetic_column() { + use crate::plot::aesthetic::AestheticContext; + use crate::reader::AnsiDialect; + + let mut aesthetics = Mappings::new(); + for aes in ["pos1", "pos2", "fill"] { + aesthetics.insert( + aes.to_string(), + AestheticValue::standard_column(naming::aesthetic_column(aes)), + ); + } + let mut schema = create_schema(&["pos1", "pos2"]); + schema.push(ColumnInfo { + name: "__ggsql_aes_fill__".to_string(), + dtype: DataType::Float64, + is_discrete: false, + min: None, + max: None, + }); + let ctx = AestheticContext::from_static(&["x", "y"], &[]); + let mut parameters = HashMap::new(); + parameters.insert( + "aggregate".to_string(), + ParameterValue::Array(vec![ + crate::plot::types::ArrayElement::String("fill:min".to_string()), + crate::plot::types::ArrayElement::String("fill:max".to_string()), + ]), + ); + + let result = Tile + .apply_stat_transform( + "SELECT * FROM data", + &schema, + &aesthetics, + &[], + ¶meters, + &|_| panic!("execute_query should not run during stat building"), + &AnsiDialect, + &ctx, + ) + .unwrap(); + + match result { + StatResult::Transformed { + query, + stat_columns, + .. + } => { + assert!(query.contains("UNION ALL"), "expected UNION ALL, got: {query}"); + let synth = naming::stat_column("aggregate"); + assert!( + query.contains(&naming::quote_ident(&synth)), + "synthetic aggregate column dropped from query: {query}" + ); + assert!( + stat_columns.iter().any(|s| s == "aggregate"), + "stat_columns missing 'aggregate' tag: {stat_columns:?}" + ); + } + _ => panic!("expected Transformed"), + } + } + #[test] fn test_setting_width_as_fallback() { // Test that SETTING width/height are used when no MAPPING is provided From 840fc6e57ac826dfdca718f071e7645a9e154f53 Mon Sep 17 00:00:00 2001 From: Thomas Lin Pedersen Date: Thu, 7 May 2026 11:30:51 +0200 Subject: [PATCH 24/33] defer scaling of aggregated columns --- doc/syntax/layer/type/range.qmd | 26 ++++--- src/execute/layer.rs | 36 ++++++++++ src/execute/mod.rs | 21 +++++- src/execute/scale.rs | 90 ++++++++++++++++++++----- src/plot/layer/geom/mod.rs | 6 ++ src/plot/layer/geom/stat_aggregate.rs | 97 ++++++++++++++++++++++++++- 6 files changed, 242 insertions(+), 34 deletions(-) diff --git a/doc/syntax/layer/type/range.qmd b/doc/syntax/layer/type/range.qmd index 9464c1c6..35771ef9 100644 --- a/doc/syntax/layer/type/range.qmd +++ b/doc/syntax/layer/type/range.qmd @@ -112,26 +112,24 @@ DRAW range SETTING width => null ``` +Rather than precomputing the values and plotting them, you can use the aggregate functionality to calculate the relevant statistics dynamically: + ```{ggsql} VISUALISE Date AS x, Temp AS ymin, Temp AS ymax, Temp AS color FROM ggsql:airquality DRAW range REMAPPING aggregate AS linewidth - SETTING aggregate => ('x:first', 'ymin:first', 'ymin:min', 'ymax:last', 'ymax:max', 'color:diff'), width => null + SETTING + aggregate => ( + 'x:first', + 'ymin:first', 'ymin:min', + 'ymax:last', 'ymax:max', + 'color:diff' + ), + width => null PARTITION BY Week SCALE linewidth TO (5, 1) -SCALE color FROM (-20, 20) TO managua - SETTING reverse => true, oob => 'keep' +SCALE BINNED color TO ('steelblue', 'firebrick') + SETTING breaks => (-20, 0, 20) ``` -Use aggregation to calculate bounds dynamically - -```{ggsql} -VISUALISE body_mass AS x, species AS y FROM ggsql:penguins -DRAW range - MAPPING body_mass AS xmin, body_mass AS xmax - SETTING aggregate => ('min', 'max'), width => null -DRAW point - REMAPPING aggregate AS fill - SETTING aggregate => ('x:min', 'x:max'), size => 20, opacity => 1 -``` diff --git a/src/execute/layer.rs b/src/execute/layer.rs index 8ada0d8c..fa95923d 100644 --- a/src/execute/layer.rs +++ b/src/execute/layer.rs @@ -276,10 +276,24 @@ pub fn apply_pre_stat_transform( aesthetic_schema: &Schema, scales: &[Scale], dialect: &dyn SqlDialect, + aesthetic_ctx: &AestheticContext, ) -> String { let mut transform_exprs: Vec<(String, String)> = vec![]; let mut transformed_columns: HashSet = HashSet::new(); + // When a layer has `aggregate => …`, scale-driven rewrites are deferred to + // after the stat for aesthetics where running them up-front would defeat + // the aggregate. The post-stat machinery (`apply_post_stat_binning`, + // `apply_scale_oob`) picks the deferred ones up against the aggregated + // values in the materialised DataFrame. + let agg_buckets = crate::plot::layer::geom::stat_aggregate::aggregated_aesthetics( + &layer.parameters, + &layer.mappings, + aesthetic_schema, + aesthetic_ctx, + layer.geom.aggregate_domain_aesthetics(), + ); + // Check layer mappings for aesthetics with scales that need pre-stat transformation // Handles both column mappings and literal mappings (which are injected as synthetic columns) for (aesthetic, value) in &layer.mappings.aesthetics { @@ -308,6 +322,27 @@ pub fn apply_pre_stat_transform( // Find scale for this aesthetic if let Some(scale) = scales.iter().find(|s| s.aesthetic == *aesthetic) { if let Some(ref scale_type) = scale.scale_type { + // Defer this rewrite when the layer aggregates and the scale + // semantics call for it (see post-stat machinery for how the + // deferred rewrite actually runs). `Binned` only defers when + // the aesthetic is *explicitly* targeted (untargeted Binned + // still drives meaningful pre-stat grouping); OOB-flavoured + // rewrites defer whenever the aesthetic is being aggregated. + if let Some((ref targeted, ref aggregated)) = agg_buckets { + use crate::plot::scale::ScaleTypeKind; + let kind = scale_type.scale_type_kind(); + let defer = match kind { + ScaleTypeKind::Binned => targeted.contains(aesthetic), + ScaleTypeKind::Continuous + | ScaleTypeKind::Discrete + | ScaleTypeKind::Ordinal => aggregated.contains(aesthetic), + ScaleTypeKind::Identity => false, + }; + if defer { + continue; + } + } + // Get pre-stat SQL transformation from scale type (if applicable) // Each scale type's pre_stat_transform_sql() returns None if not applicable if let Some(sql) = @@ -488,6 +523,7 @@ where &aesthetic_schema, scales, dialect, + aesthetic_ctx, ); // Build group_by columns from partition_by diff --git a/src/execute/mod.rs b/src/execute/mod.rs index 40efdfc8..7cbbce83 100644 --- a/src/execute/mod.rs +++ b/src/execute/mod.rs @@ -723,6 +723,22 @@ fn add_discrete_columns_to_partition_by( excluded_aesthetics.insert("label"); } + // When aggregate is active, an explicitly-targeted Binned aesthetic + // shouldn't auto-promote to a group key — the user is summarising the + // raw values and the binning runs post-stat against the aggregate + // output. Untargeted Binned still groups, so binning can drive + // meaningful aggregation buckets in the common case. + let agg_targeted: HashSet = + crate::plot::layer::geom::stat_aggregate::aggregated_aesthetics( + &layer.parameters, + &layer.mappings, + schema, + aesthetic_ctx, + layer.geom.aggregate_domain_aesthetics(), + ) + .map(|(t, _)| t) + .unwrap_or_default(); + for (aesthetic, value) in &layer.mappings.aesthetics { // Skip position aesthetics - these should not trigger auto-grouping. // Stats that need to group by position aesthetics (like bar/histogram) @@ -754,9 +770,8 @@ fn add_discrete_columns_to_partition_by( let is_discrete = if let Some(scale) = scale_map.get(primary_aes) { if let Some(ref scale_type) = scale.scale_type { match scale_type.scale_type_kind() { - ScaleTypeKind::Discrete - | ScaleTypeKind::Binned - | ScaleTypeKind::Ordinal => true, + ScaleTypeKind::Discrete | ScaleTypeKind::Ordinal => true, + ScaleTypeKind::Binned => !agg_targeted.contains(aesthetic), ScaleTypeKind::Continuous => false, ScaleTypeKind::Identity => discrete_columns.contains(col), } diff --git a/src/execute/scale.rs b/src/execute/scale.rs index 9397ed47..afbcfa6d 100644 --- a/src/execute/scale.rs +++ b/src/execute/scale.rs @@ -151,6 +151,25 @@ pub fn apply_post_stat_binning( ) -> Result<()> { let aesthetic_ctx = spec.get_aesthetic_context(); + // Per-layer set of aesthetics that the aggregate stat *explicitly targets*. + // Targeted aesthetics had their pre-stat binning deferred (see + // `apply_pre_stat_transform`), so the materialised DataFrame still holds + // raw aggregate output for them — we need to bin those columns here. + // Untargeted aesthetics were binned pre-stat and the SQL `CASE WHEN` is + // already baked into the column, so the existing `__ggsql_aes_*` skip + // still applies for those. + let targeted_per_layer: Vec> = spec + .layers + .iter() + .map(|layer| { + crate::plot::layer::geom::stat_aggregate::targeted_aesthetics( + &layer.parameters, + &layer.mappings, + &aesthetic_ctx, + ) + }) + .collect(); + for scale in &spec.scales { // Only process Binned scales match &scale.scale_type { @@ -177,32 +196,49 @@ pub fn apply_post_stat_binning( _ => true, }; - // Find columns for this aesthetic across layers - let column_sources = find_columns_for_aesthetic_with_sources( - &spec.layers, - &scale.aesthetic, - data_map, - &aesthetic_ctx, - ); + // Walk layers directly so we can decide per-layer whether an + // aesthetic-named column was deferred (needs binning here) or + // already binned upstream by the pre-stat SQL. + let aesthetics_to_check = aesthetic_ctx + .internal_position_family(&scale.aesthetic) + .map(|f| f.to_vec()) + .unwrap_or_else(|| vec![scale.aesthetic.clone()]); + + for (idx, layer) in spec.layers.iter().enumerate() { + let data_key = naming::layer_key(idx); + if !data_map.contains_key(&data_key) { + continue; + } - // Apply binning to each column - for (data_key, col_name) in column_sources { - if let Some(df) = data_map.get(&data_key) { - // Skip if column doesn't exist in this data source + for aes_name in &aesthetics_to_check { + let col_name = match layer.mappings.get(aes_name) { + Some(crate::AestheticValue::Column { name, .. }) => name.clone(), + _ => continue, + }; + + let df = match data_map.get(&data_key) { + Some(d) => d, + None => continue, + }; if df.column(&col_name).is_err() { continue; } - // Skip post-stat binning for aesthetic columns (like __ggsql_aes_x__) - // because pre_stat_transform already binned them via SQL. - // Post-stat binning only applies to stat columns or remapped aesthetics. - if naming::is_aesthetic_column(&col_name) { + // Skip post-stat binning for aesthetic columns that were + // already binned via pre_stat_transform's CASE WHEN. The + // exception is when the layer's aggregate explicitly targets + // this aesthetic — in that case binning was deferred and the + // column holds the raw aggregate output that needs binning + // now. + if naming::is_aesthetic_column(&col_name) + && !targeted_per_layer[idx].contains(aes_name) + { continue; } let binned_df = apply_binning_to_dataframe(df, &col_name, &break_values, closed_left)?; - data_map.insert(data_key, binned_df); + data_map.insert(data_key.clone(), binned_df); } } } @@ -490,6 +526,22 @@ pub fn apply_pre_stat_resolve(spec: &mut Plot, layer_schemas: &[Schema]) -> Resu let aesthetic_ctx = spec.get_aesthetic_context(); + // Aesthetics that any layer's `aggregate` setting explicitly targets. Their + // BINNED scales must be resolved post-stat — the relevant column range is + // the aggregated output, not the raw input. Leaving them un-resolved here + // means `resolved == false` and `resolve_scales` will pick them up after + // the data is materialised. + let mut targeted_in_any_layer: HashSet = HashSet::new(); + for layer in &spec.layers { + for aes in crate::plot::layer::geom::stat_aggregate::targeted_aesthetics( + &layer.parameters, + &layer.mappings, + &aesthetic_ctx, + ) { + targeted_in_any_layer.insert(aes); + } + } + for scale in &mut spec.scales { // Only pre-resolve Binned scales let scale_type = match &scale.scale_type { @@ -497,6 +549,12 @@ pub fn apply_pre_stat_resolve(spec: &mut Plot, layer_schemas: &[Schema]) -> Resu _ => continue, }; + // Defer resolution for aesthetics targeted by aggregate so breaks + // come from the post-stat range. + if targeted_in_any_layer.contains(&scale.aesthetic) { + continue; + } + // Find all ColumnInfos for this aesthetic from schemas let column_infos = find_schema_columns_for_aesthetic( &spec.layers, diff --git a/src/plot/layer/geom/mod.rs b/src/plot/layer/geom/mod.rs index b9493668..fff38ae1 100644 --- a/src/plot/layer/geom/mod.rs +++ b/src/plot/layer/geom/mod.rs @@ -512,6 +512,12 @@ impl Geom { self.0.supports_aggregate() } + /// Aesthetics the Aggregate stat must keep as group keys rather than + /// aggregating, even if their bound column is continuous. + pub fn aggregate_domain_aesthetics(&self) -> &'static [&'static str] { + self.0.aggregate_domain_aesthetics() + } + /// Validate aesthetic mappings pub fn validate_aesthetics(&self, mappings: &Mappings) -> std::result::Result<(), String> { self.0.validate_aesthetics(mappings) diff --git a/src/plot/layer/geom/stat_aggregate.rs b/src/plot/layer/geom/stat_aggregate.rs index 049b1509..1414f6b6 100644 --- a/src/plot/layer/geom/stat_aggregate.rs +++ b/src/plot/layer/geom/stat_aggregate.rs @@ -25,7 +25,7 @@ //! Numeric mappings without a target *or* applicable default are dropped with //! a warning to stderr. -use std::collections::HashMap; +use std::collections::{HashMap, HashSet}; use super::types::StatResult; use crate::naming; @@ -605,6 +605,101 @@ fn is_upper_half(internal_aes: &str) -> bool { internal_aes.ends_with("max") || internal_aes.ends_with("end") } +/// Compute the set of internal aesthetic names that the layer's `aggregate` +/// setting *explicitly targets*. Lighter than [`aggregated_aesthetics`] — +/// doesn't need a schema — so post-stat callers can use it without rebuilding +/// type information from a materialised DataFrame. +pub fn targeted_aesthetics( + parameters: &HashMap, + aesthetics: &Mappings, + aesthetic_ctx: &AestheticContext, +) -> HashSet { + let raw = match parameters.get("aggregate") { + Some(v) if !matches!(v, ParameterValue::Null) => v, + _ => return HashSet::new(), + }; + let spec = match parse_aggregate_param(raw).ok().flatten() { + Some(s) => s, + None => return HashSet::new(), + }; + let mut targeted: HashSet = HashSet::new(); + for (user_aes, _fns) in &spec.targets { + for internal in resolve_target_aesthetic(user_aes, aesthetics, aesthetic_ctx) { + targeted.insert(internal); + } + } + targeted +} + +/// Compute, for a layer's `aggregate` setting, which internal aesthetic names +/// will be (a) *explicitly targeted* by `aggregate => ':'` and +/// (b) *aggregated* by the stat (either targeted OR a numeric mapping that an +/// untargeted default applies to). +/// +/// The execute pipeline uses this to decide whether to defer scale-driven +/// pre-stat rewrites (`SCALE BINNED `, `SCALE FROM […]`, …) until +/// after the stat. The bucketing here mirrors the per-mapping branching in +/// [`apply`]; both must stay in sync. +/// +/// Returns `None` when `aggregate` is unset, null, or fails to parse — i.e. +/// when the stat will return `Identity` and no aesthetic is touched. Parse +/// errors are swallowed; the stat itself surfaces a clean diagnostic. +pub fn aggregated_aesthetics( + parameters: &HashMap, + aesthetics: &Mappings, + schema: &Schema, + aesthetic_ctx: &AestheticContext, + domain_aesthetics: &[&'static str], +) -> Option<(HashSet, HashSet)> { + let raw = parameters.get("aggregate")?; + if matches!(raw, ParameterValue::Null) { + return None; + } + let spec = parse_aggregate_param(raw).ok()??; + + let mut targeted: HashSet = HashSet::new(); + for (user_aes, _fns) in &spec.targets { + for internal in resolve_target_aesthetic(user_aes, aesthetics, aesthetic_ctx) { + targeted.insert(internal); + } + } + + let mut aggregated: HashSet = targeted.clone(); + let mut entries: Vec<(&String, &crate::AestheticValue)> = + aesthetics.aesthetics.iter().collect(); + entries.sort_by(|a, b| a.0.cmp(b.0)); + for (aes, value) in entries { + let col = match value.column_name() { + Some(c) => c, + None => continue, + }; + if domain_aesthetics.contains(&aes.as_str()) { + continue; + } + let is_discrete = schema + .iter() + .find(|c| c.name == col) + .map(|c| c.is_discrete) + .unwrap_or(false); + if is_discrete { + continue; + } + if targeted.contains(aes) { + continue; + } + let default_applies = if is_upper_half(aes) { + spec.default_upper.is_some() || spec.default_lower.is_some() + } else { + spec.default_lower.is_some() + }; + if default_applies { + aggregated.insert(aes.clone()); + } + } + + Some((targeted, aggregated)) +} + /// Apply the Aggregate stat to a layer query. /// /// Returns `StatResult::Identity` when the `aggregate` parameter is unset, null, From bdcd700e661af0422784e9c067f8f74e1e2b7255 Mon Sep 17 00:00:00 2001 From: Thomas Lin Pedersen Date: Thu, 7 May 2026 11:34:27 +0200 Subject: [PATCH 25/33] update SKILL --- doc/vendor/SKILL.md | 44 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 44 insertions(+) diff --git a/doc/vendor/SKILL.md b/doc/vendor/SKILL.md index 90c41dcf..6039f00c 100644 --- a/doc/vendor/SKILL.md +++ b/doc/vendor/SKILL.md @@ -129,6 +129,35 @@ SETTING position => 'dodge' -- side by side (default for boxplot, violin) SETTING position => 'jitter' -- random offset ``` +**Aggregate** collapses each group to a single row, replacing every numeric mapping in place with its aggregated value. Groups = `PARTITION BY` columns + all discrete mappings. Supported by `point`, `line`, `path`, `bar`, `area`, `ribbon`, `range`, `segment`, `rule`, `text`, `tile`. Not supported by `histogram`, `density`, `smooth`, `boxplot`, `violin` (they have their own stats). + +```ggsql +SETTING aggregate => '' -- single +SETTING aggregate => ('', '', …) -- list +``` + +Each `` is either: +- **Untargeted** — `''`. Applies to every numeric mapping without an explicit target. With two untargeted defaults, the first applies to lower-side aesthetics (`x`/`xmin`/etc.) plus all non-range layers, the second to upper-side (`xend`/`xmax`). More than two untargeted defaults is an error. +- **Targeted** — `':'`. Applies `func` to the named aesthetic only. Overrides any untargeted default for that aesthetic. + +Functions: +- Standard reductions: `count`, `sum`, `prod`, `min`, `max`, `range` (max−min), `mid` ((min+max)/2), `mean`, `median`, `geomean`, `harmean`, `rms`, `sdev`, `var`, `iqr`, `se`, `p05`–`p95`. +- Positional (rely on upstream `ORDER BY` for deterministic order): `first`, `last`, `diff` (last − first). +- Band: `±[]`, e.g. `'mean+1.96sdev'`, `'median-iqr'`. Offsets: `mean`, `median`, `geomean`, `harmean`, `rms`, `sum`, `prod`, `min`, `max`, `mid`, `p05`–`p95`. Expansions: `sdev`, `se`, `var`, `iqr`, `range`. + +**Explosion** — targeting the same aesthetic with multiple functions emits one row per function per group. A synthetic `aggregate` column tags each row with the function name. Use `REMAPPING aggregate AS ` to drive another aesthetic from it. When several aesthetics are exploded with the same length, they explode in lockstep (row 1 = each target's first function, row 2 = second, …); single-function targets are reused on every row. Mixing target lengths > 1 is an error. + +```ggsql +-- min/max envelope as two lines per group, coloured by function +DRAW line + MAPPING Date AS x, Temp AS y + REMAPPING aggregate AS color + SETTING aggregate => ('y:min', 'y:max') + PARTITION BY Year +``` + +**Scale interaction** — for an aesthetic that is *targeted* by aggregate, `SCALE BINNED ` runs **after** aggregation (otherwise the diff/mean/etc. would cancel within a bin). Untargeted `SCALE BINNED` still bins pre-aggregate so the bins can drive grouping. Continuous censoring (`SCALE FROM (lo, hi)`) and discrete OOB filtering defer to post-aggregate whenever the aesthetic is being aggregated (targeted or untargeted default). + ### FILTER SQL WHERE condition applied to layer data. Content is passed to the database: @@ -461,6 +490,21 @@ VISUALISE DRAW line MAPPING Date AS x, value AS y, 'Temperature' AS color FROM temps DRAW point MAPPING Date AS x, value AS y, 'Ozone' AS color FROM ozone SCALE x VIA date + +-- Per-week summary: open/close range, weekly temperature change (binned post-aggregate) +VISUALISE Date AS x, Temp AS ymin, Temp AS ymax, Temp AS color + FROM ggsql:airquality +DRAW range + SETTING aggregate => ('x:first', 'ymin:first', 'ymax:last', 'color:diff'), + width => null + PARTITION BY Week +SCALE BINNED color + +-- Mean ± 1.96·sdev band per group, drawn as a ribbon +VISUALISE Day AS x, Temp AS ymin, Temp AS ymax FROM ggsql:airquality +DRAW ribbon + SETTING aggregate => ('mean-1.96sdev', 'mean+1.96sdev') + PARTITION BY Month ``` --- From 65c504cddc961b287523483f3e26c27cbf9dcdd7 Mon Sep 17 00:00:00 2001 From: Thomas Lin Pedersen Date: Thu, 7 May 2026 11:34:47 +0200 Subject: [PATCH 26/33] reformat --- src/plot/layer/geom/stat_aggregate.rs | 25 +++++++++++++++---------- src/plot/layer/geom/tile.rs | 19 +++++++++++++++---- 2 files changed, 30 insertions(+), 14 deletions(-) diff --git a/src/plot/layer/geom/stat_aggregate.rs b/src/plot/layer/geom/stat_aggregate.rs index 1414f6b6..0a5bb56c 100644 --- a/src/plot/layer/geom/stat_aggregate.rs +++ b/src/plot/layer/geom/stat_aggregate.rs @@ -45,8 +45,7 @@ pub const AGG_NAMES: &[&str] = &[ "min", "max", "range", "mid", // Central tendency "mean", "geomean", "harmean", "rms", "median", // Spread (standalone) "sdev", "var", "iqr", // Percentiles - "p05", "p10", "p25", "p50", "p75", "p90", - "p95", // Positional (row order in the group) + "p05", "p10", "p25", "p50", "p75", "p90", "p95", // Positional (row order in the group) "first", "last", "diff", ]; @@ -1045,10 +1044,7 @@ fn build_aggregate_query( }) .collect(); - format!( - "{with_clause} {body}", - body = branches.join(" UNION ALL "), - ) + format!("{with_clause} {body}", body = branches.join(" UNION ALL "),) } fn func_literal(s: &str) -> String { @@ -1525,8 +1521,9 @@ mod tests { match result { StatResult::Transformed { query, .. } => { assert!( - query - .contains("(MIN(\"__ggsql_aes_pos1__\") + MAX(\"__ggsql_aes_pos1__\")) / 2.0"), + query.contains( + "(MIN(\"__ggsql_aes_pos1__\") + MAX(\"__ggsql_aes_pos1__\")) / 2.0" + ), "{}", query ); @@ -1643,8 +1640,16 @@ mod tests { .unwrap(); match result { StatResult::Transformed { query, .. } => { - assert!(query.contains("FIRST(\"__ggsql_aes_pos2min__\")"), "{}", query); - assert!(query.contains("LAST(\"__ggsql_aes_pos2max__\")"), "{}", query); + assert!( + query.contains("FIRST(\"__ggsql_aes_pos2min__\")"), + "{}", + query + ); + assert!( + query.contains("LAST(\"__ggsql_aes_pos2max__\")"), + "{}", + query + ); } _ => panic!("expected Transformed"), } diff --git a/src/plot/layer/geom/tile.rs b/src/plot/layer/geom/tile.rs index 63bcd68d..8c133f6f 100644 --- a/src/plot/layer/geom/tile.rs +++ b/src/plot/layer/geom/tile.rs @@ -182,8 +182,13 @@ impl GeomTrait for Tile { schema }; - let tile_result = - stat_tile(&working_query, schema_for_tile, aesthetics, group_by, parameters)?; + let tile_result = stat_tile( + &working_query, + schema_for_tile, + aesthetics, + group_by, + parameters, + )?; if exploded { if let StatResult::Transformed { @@ -1067,7 +1072,10 @@ mod tests { match result { StatResult::Transformed { query, .. } => { // Aggregate stage: GROUP BY pos1/pos2, AVG of fill into a stat column. - assert!(query.contains("GROUP BY"), "expected GROUP BY, got: {query}"); + assert!( + query.contains("GROUP BY"), + "expected GROUP BY, got: {query}" + ); assert!( query.contains("AVG(\"__ggsql_aes_fill__\")"), "expected AVG over fill, got: {query}" @@ -1141,7 +1149,10 @@ mod tests { stat_columns, .. } => { - assert!(query.contains("UNION ALL"), "expected UNION ALL, got: {query}"); + assert!( + query.contains("UNION ALL"), + "expected UNION ALL, got: {query}" + ); let synth = naming::stat_column("aggregate"); assert!( query.contains(&naming::quote_ident(&synth)), From a2b24b9ed7bd0dc49d14f8c58357d31b4487413a Mon Sep 17 00:00:00 2001 From: Thomas Lin Pedersen Date: Thu, 7 May 2026 13:05:39 +0200 Subject: [PATCH 27/33] Merge aggregate_domain_aesthetics and supports_aggregate into one --- src/execute/layer.rs | 2 +- src/execute/mod.rs | 2 +- src/plot/layer/geom/area.rs | 10 +++---- src/plot/layer/geom/arrow.rs | 4 +-- src/plot/layer/geom/bar.rs | 6 ++--- src/plot/layer/geom/line.rs | 10 +++---- src/plot/layer/geom/mod.rs | 48 +++++++++++++++++++--------------- src/plot/layer/geom/point.rs | 4 +-- src/plot/layer/geom/range.rs | 4 +-- src/plot/layer/geom/ribbon.rs | 10 +++---- src/plot/layer/geom/rule.rs | 4 +-- src/plot/layer/geom/segment.rs | 4 +-- src/plot/layer/geom/text.rs | 4 +-- src/plot/layer/geom/tile.rs | 12 +++------ 14 files changed, 57 insertions(+), 67 deletions(-) diff --git a/src/execute/layer.rs b/src/execute/layer.rs index fa95923d..7a75a1d4 100644 --- a/src/execute/layer.rs +++ b/src/execute/layer.rs @@ -291,7 +291,7 @@ pub fn apply_pre_stat_transform( &layer.mappings, aesthetic_schema, aesthetic_ctx, - layer.geom.aggregate_domain_aesthetics(), + layer.geom.aggregate_domain_aesthetics().unwrap_or(&[]), ); // Check layer mappings for aesthetics with scales that need pre-stat transformation diff --git a/src/execute/mod.rs b/src/execute/mod.rs index 7cbbce83..45620e8d 100644 --- a/src/execute/mod.rs +++ b/src/execute/mod.rs @@ -734,7 +734,7 @@ fn add_discrete_columns_to_partition_by( &layer.mappings, schema, aesthetic_ctx, - layer.geom.aggregate_domain_aesthetics(), + layer.geom.aggregate_domain_aesthetics().unwrap_or(&[]), ) .map(|(t, _)| t) .unwrap_or_default(); diff --git a/src/plot/layer/geom/area.rs b/src/plot/layer/geom/area.rs index e4d1230a..66a5bc58 100644 --- a/src/plot/layer/geom/area.rs +++ b/src/plot/layer/geom/area.rs @@ -55,12 +55,8 @@ impl GeomTrait for Area { PARAMS } - fn supports_aggregate(&self) -> bool { - true - } - - fn aggregate_domain_aesthetics(&self) -> &'static [&'static str] { - &["pos1"] + fn aggregate_domain_aesthetics(&self) -> Option<&'static [&'static str]> { + Some(&["pos1"]) } fn needs_stat_transform(&self, _aesthetics: &Mappings) -> bool { @@ -87,7 +83,7 @@ impl GeomTrait for Area { parameters, dialect, aesthetic_ctx, - self.aggregate_domain_aesthetics(), + self.aggregate_domain_aesthetics().unwrap_or(&[]), )? } else { StatResult::Identity diff --git a/src/plot/layer/geom/arrow.rs b/src/plot/layer/geom/arrow.rs index 2e3369d2..9f6af1ee 100644 --- a/src/plot/layer/geom/arrow.rs +++ b/src/plot/layer/geom/arrow.rs @@ -40,8 +40,8 @@ impl GeomTrait for Arrow { PARAMS } - fn supports_aggregate(&self) -> bool { - true + fn aggregate_domain_aesthetics(&self) -> Option<&'static [&'static str]> { + Some(&[]) } } diff --git a/src/plot/layer/geom/bar.rs b/src/plot/layer/geom/bar.rs index f2990467..96b9d429 100644 --- a/src/plot/layer/geom/bar.rs +++ b/src/plot/layer/geom/bar.rs @@ -80,8 +80,8 @@ impl GeomTrait for Bar { &["pos1", "pos2", "weight"] } - fn supports_aggregate(&self) -> bool { - true + fn aggregate_domain_aesthetics(&self) -> Option<&'static [&'static str]> { + Some(&[]) } fn needs_stat_transform(&self, _aesthetics: &Mappings) -> bool { @@ -108,7 +108,7 @@ impl GeomTrait for Bar { parameters, dialect, aesthetic_ctx, - self.aggregate_domain_aesthetics(), + self.aggregate_domain_aesthetics().unwrap_or(&[]), ); } stat_bar_count(query, schema, aesthetics, group_by) diff --git a/src/plot/layer/geom/line.rs b/src/plot/layer/geom/line.rs index a6fd8edd..5acff920 100644 --- a/src/plot/layer/geom/line.rs +++ b/src/plot/layer/geom/line.rs @@ -41,12 +41,8 @@ impl GeomTrait for Line { PARAMS } - fn supports_aggregate(&self) -> bool { - true - } - - fn aggregate_domain_aesthetics(&self) -> &'static [&'static str] { - &["pos1"] + fn aggregate_domain_aesthetics(&self) -> Option<&'static [&'static str]> { + Some(&["pos1"]) } fn needs_stat_transform(&self, _aesthetics: &Mappings) -> bool { @@ -73,7 +69,7 @@ impl GeomTrait for Line { parameters, dialect, aesthetic_ctx, - self.aggregate_domain_aesthetics(), + self.aggregate_domain_aesthetics().unwrap_or(&[]), )? } else { StatResult::Identity diff --git a/src/plot/layer/geom/mod.rs b/src/plot/layer/geom/mod.rs index fff38ae1..26f8eade 100644 --- a/src/plot/layer/geom/mod.rs +++ b/src/plot/layer/geom/mod.rs @@ -194,26 +194,28 @@ pub trait GeomTrait: std::fmt::Debug + std::fmt::Display + Send + Sync { false } - /// Whether this geom accepts the `aggregate` SETTING parameter. + /// Whether the Aggregate stat applies to this geom, and which aesthetics + /// stay as group keys when it does. /// - /// Geoms that opt in gain a generic Aggregate stat that groups by discrete - /// mappings + PARTITION BY and emits one row per group, replacing every - /// numeric mapping (positional and material) with its aggregated value. - /// Statistical geoms (histogram, density, smooth, boxplot, violin) leave - /// this `false` to keep their bespoke stats. - fn supports_aggregate(&self) -> bool { - false + /// - `None` — geom doesn't accept the `aggregate` SETTING. Used by the + /// statistical geoms (`histogram`, `density`, `smooth`, `boxplot`, + /// `violin`) that have their own bespoke stats. + /// - `Some(&[])` — geom opts in; the stat groups by discrete mappings + + /// `PARTITION BY` only. Most non-statistical geoms. + /// - `Some(&[, …])` — geom opts in *and* pins the listed aesthetics + /// as group keys regardless of their column's continuity. Used by + /// `line`/`area`/`ribbon` (domain axis) and `tile` (every spatial slot). + /// + /// `supports_aggregate()` is derived from this; geoms only override one + /// method to opt in. + fn aggregate_domain_aesthetics(&self) -> Option<&'static [&'static str]> { + None } - /// Aesthetics that the Aggregate stat must keep as group keys rather than - /// aggregating, even if their bound column is continuous. This is for - /// geoms like line/area/ribbon where one axis is the *domain* — the - /// natural group identity of each row — and the user expects "summarise - /// the other axis per domain value" without writing an explicit target. - /// - /// Default empty; line/area/ribbon override to `&["pos1"]`. - fn aggregate_domain_aesthetics(&self) -> &'static [&'static str] { - &[] + /// Whether this geom accepts the `aggregate` SETTING parameter. + /// Derived from `aggregate_domain_aesthetics`; do not override. + fn supports_aggregate(&self) -> bool { + self.aggregate_domain_aesthetics().is_some() } /// Apply statistical transformation to the layer query. @@ -233,7 +235,10 @@ pub trait GeomTrait: std::fmt::Debug + std::fmt::Display + Send + Sync { dialect: &dyn SqlDialect, aesthetic_ctx: &AestheticContext, ) -> Result { - if self.supports_aggregate() && has_aggregate_param(parameters) { + if let (Some(domain), true) = ( + self.aggregate_domain_aesthetics(), + has_aggregate_param(parameters), + ) { return stat_aggregate::apply( query, schema, @@ -242,7 +247,7 @@ pub trait GeomTrait: std::fmt::Debug + std::fmt::Display + Send + Sync { parameters, dialect, aesthetic_ctx, - self.aggregate_domain_aesthetics(), + domain, ); } Ok(StatResult::Identity) @@ -513,8 +518,9 @@ impl Geom { } /// Aesthetics the Aggregate stat must keep as group keys rather than - /// aggregating, even if their bound column is continuous. - pub fn aggregate_domain_aesthetics(&self) -> &'static [&'static str] { + /// aggregating, even if their bound column is continuous. `None` when + /// the geom doesn't accept the `aggregate` setting. + pub fn aggregate_domain_aesthetics(&self) -> Option<&'static [&'static str]> { self.0.aggregate_domain_aesthetics() } diff --git a/src/plot/layer/geom/point.rs b/src/plot/layer/geom/point.rs index 5101f2f0..9202195e 100644 --- a/src/plot/layer/geom/point.rs +++ b/src/plot/layer/geom/point.rs @@ -39,8 +39,8 @@ impl GeomTrait for Point { PARAMS } - fn supports_aggregate(&self) -> bool { - true + fn aggregate_domain_aesthetics(&self) -> Option<&'static [&'static str]> { + Some(&[]) } } diff --git a/src/plot/layer/geom/range.rs b/src/plot/layer/geom/range.rs index 5cacd874..a6e0490a 100644 --- a/src/plot/layer/geom/range.rs +++ b/src/plot/layer/geom/range.rs @@ -45,8 +45,8 @@ impl GeomTrait for Range { PARAMS } - fn supports_aggregate(&self) -> bool { - true + fn aggregate_domain_aesthetics(&self) -> Option<&'static [&'static str]> { + Some(&[]) } } diff --git a/src/plot/layer/geom/ribbon.rs b/src/plot/layer/geom/ribbon.rs index 5b2a390e..d2c686c3 100644 --- a/src/plot/layer/geom/ribbon.rs +++ b/src/plot/layer/geom/ribbon.rs @@ -40,12 +40,8 @@ impl GeomTrait for Ribbon { PARAMS } - fn supports_aggregate(&self) -> bool { - true - } - - fn aggregate_domain_aesthetics(&self) -> &'static [&'static str] { - &["pos1"] + fn aggregate_domain_aesthetics(&self) -> Option<&'static [&'static str]> { + Some(&["pos1"]) } fn needs_stat_transform(&self, _aesthetics: &Mappings) -> bool { @@ -72,7 +68,7 @@ impl GeomTrait for Ribbon { parameters, dialect, aesthetic_ctx, - self.aggregate_domain_aesthetics(), + self.aggregate_domain_aesthetics().unwrap_or(&[]), )? } else { StatResult::Identity diff --git a/src/plot/layer/geom/rule.rs b/src/plot/layer/geom/rule.rs index a495cb48..502d3724 100644 --- a/src/plot/layer/geom/rule.rs +++ b/src/plot/layer/geom/rule.rs @@ -25,8 +25,8 @@ impl GeomTrait for Rule { } } - fn supports_aggregate(&self) -> bool { - true + fn aggregate_domain_aesthetics(&self) -> Option<&'static [&'static str]> { + Some(&[]) } fn validate_aesthetics(&self, mappings: &crate::Mappings) -> std::result::Result<(), String> { diff --git a/src/plot/layer/geom/segment.rs b/src/plot/layer/geom/segment.rs index 4dd7e65f..58bacb07 100644 --- a/src/plot/layer/geom/segment.rs +++ b/src/plot/layer/geom/segment.rs @@ -39,8 +39,8 @@ impl GeomTrait for Segment { PARAMS } - fn supports_aggregate(&self) -> bool { - true + fn aggregate_domain_aesthetics(&self) -> Option<&'static [&'static str]> { + Some(&[]) } } diff --git a/src/plot/layer/geom/text.rs b/src/plot/layer/geom/text.rs index 5909c34d..d203e023 100644 --- a/src/plot/layer/geom/text.rs +++ b/src/plot/layer/geom/text.rs @@ -63,8 +63,8 @@ impl GeomTrait for Text { PARAMS } - fn supports_aggregate(&self) -> bool { - true + fn aggregate_domain_aesthetics(&self) -> Option<&'static [&'static str]> { + Some(&[]) } fn post_process( diff --git a/src/plot/layer/geom/tile.rs b/src/plot/layer/geom/tile.rs index 8c133f6f..99ec1c3d 100644 --- a/src/plot/layer/geom/tile.rs +++ b/src/plot/layer/geom/tile.rs @@ -99,18 +99,14 @@ impl GeomTrait for Tile { true } - fn supports_aggregate(&self) -> bool { - true - } - /// Every spatial slot is pinned as a group key — the rectangle's position /// and size *define* the group, they are never the thing being summarised. /// Material aesthetics (fill, stroke, opacity, …) pass through to the /// aggregate as normal. - fn aggregate_domain_aesthetics(&self) -> &'static [&'static str] { - &[ + fn aggregate_domain_aesthetics(&self) -> Option<&'static [&'static str]> { + Some(&[ "pos1", "pos1min", "pos1max", "width", "pos2", "pos2min", "pos2max", "height", - ] + ]) } fn apply_stat_transform( @@ -140,7 +136,7 @@ impl GeomTrait for Tile { parameters, dialect, aesthetic_ctx, - self.aggregate_domain_aesthetics(), + self.aggregate_domain_aesthetics().unwrap_or(&[]), )?; match agg { StatResult::Transformed { From 7014e93cbc5ec0d1df31f9c48fc1b6fb755ed3cc Mon Sep 17 00:00:00 2001 From: Thomas Lin Pedersen Date: Thu, 7 May 2026 13:50:54 +0200 Subject: [PATCH 28/33] avoid twice parsing --- src/plot/layer/mod.rs | 29 ++++++++++++++++++++++++----- src/validate.rs | 11 +++++++++++ 2 files changed, 35 insertions(+), 5 deletions(-) diff --git a/src/plot/layer/mod.rs b/src/plot/layer/mod.rs index 1cea17ac..d6948860 100644 --- a/src/plot/layer/mod.rs +++ b/src/plot/layer/mod.rs @@ -422,16 +422,35 @@ impl Layer { { validate_parameter(param_name, value, ¶m.constraint)?; } - // Or the shared `aggregate` param for Identity-stat geoms - else if param_name == "aggregate" && self.geom.supports_aggregate() { - crate::plot::layer::geom::stat_aggregate::validate_aggregate_param(value)?; - } - // Otherwise it's a valid aesthetic setting (no constraint validation needed) + // Otherwise it's a valid aesthetic setting (no constraint validation needed). + // + // The shared `aggregate` parameter is intentionally not parsed here. + // The execute pipeline parses it once in `stat_aggregate::apply` + // (where the result is actually used), so doing a parse-then-discard + // here would be redundant. Standalone validation paths + // (`validate.rs::validate`, used by `ggsql validate`) call + // [`validate_aggregate_setting`] explicitly to surface malformed + // aggregate settings without going through execute. } Ok(()) } + /// Validate the `aggregate` SETTING in isolation. Used by the standalone + /// validation path (`ggsql validate`) where the error wouldn't otherwise + /// surface — execute paths catch the same error inside + /// `stat_aggregate::apply` once the value is actually used. + pub fn validate_aggregate_setting(&self) -> std::result::Result<(), String> { + if !self.geom.supports_aggregate() { + return Ok(()); + } + let value = match self.parameters.get("aggregate") { + Some(v) => v, + None => return Ok(()), + }; + crate::plot::layer::geom::stat_aggregate::validate_aggregate_param(value) + } + /// Update layer mappings to use prefixed aesthetic column names. /// /// After building a layer query that creates aesthetic columns with prefixed names, diff --git a/src/validate.rs b/src/validate.rs index 1a134ef1..52a275b1 100644 --- a/src/validate.rs +++ b/src/validate.rs @@ -212,6 +212,17 @@ pub fn validate(query: &str) -> Result { location: None, }); } + + // The aggregate setting is validated in isolation here so the + // standalone validate path (which doesn't run the stat) still + // catches malformed `aggregate` values. The execute path skips + // this; `stat_aggregate::apply` parses + reports there. + if let Err(e) = layer.validate_aggregate_setting() { + errors.push(ValidationError { + message: format!("{}: {}", context, e), + location: None, + }); + } } } From 5ce2ddcf3748ccc2921d5ad4a5eadc987cee6c1b Mon Sep 17 00:00:00 2001 From: Thomas Lin Pedersen Date: Thu, 7 May 2026 13:53:32 +0200 Subject: [PATCH 29/33] refactor aggregate parsing --- src/plot/layer/geom/stat_aggregate.rs | 297 +++++++++++++------------- 1 file changed, 143 insertions(+), 154 deletions(-) diff --git a/src/plot/layer/geom/stat_aggregate.rs b/src/plot/layer/geom/stat_aggregate.rs index 0a5bb56c..bf99e009 100644 --- a/src/plot/layer/geom/stat_aggregate.rs +++ b/src/plot/layer/geom/stat_aggregate.rs @@ -26,6 +26,9 @@ //! a warning to stderr. use std::collections::{HashMap, HashSet}; +use std::sync::OnceLock; + +use regex::Regex; use super::types::StatResult; use crate::naming; @@ -74,7 +77,9 @@ pub struct AggSpec { #[derive(Debug, Clone, PartialEq)] pub struct Band { - pub sign: char, + /// Signed multiplier on the expansion. `+1.0` corresponds to `+`; + /// `-1.96` corresponds to `-1.96`. The sign and magnitude are + /// folded together so there's a single source of truth. pub mod_value: f64, pub expansion: &'static str, } @@ -83,71 +88,99 @@ fn resolve_static(name: &str, vocab: &'static [&'static str]) -> Option<&'static vocab.iter().copied().find(|v| *v == name) } -/// Parse an aggregate-function name into an `AggSpec`. Returns `None` on -/// invalid input (unknown stat, malformed band, or band with vocabulary -/// violation). -pub fn parse_agg_name(name: &str) -> Option { - if let Some(spec) = parse_band(name) { - return Some(spec); - } - resolve_static(name, AGG_NAMES).map(|offset| AggSpec { offset, band: None }) +/// Single regex covering one `aggregate` entry: optional `:` prefix, +/// required offset name, optional `±[]` band suffix. +/// +/// Capture groups: +/// 1. aesthetic prefix (anything up to the first `:`; structural-only — full +/// aesthetic resolution happens in `apply()`) +/// 2. offset name +/// 3. sign — present iff the entry has a band +/// 4. magnitude — optional, defaults to `1.0` +/// 5. expansion name +fn entry_re() -> &'static Regex { + static RE: OnceLock = OnceLock::new(); + RE.get_or_init(|| { + Regex::new(r"^(?:([^:]+):)?([a-z]+\d*)(?:([+-])(\d+(?:\.\d+)?)?([a-z]+))?$").unwrap() + }) +} + +/// Parsed shape of a single `aggregate` array entry. +struct ParsedEntry { + /// `Some(name)` when the entry has an `:` prefix; `None` for an + /// unprefixed default. Resolution to internal aesthetic names happens in + /// `apply()` via `resolve_target_aesthetic`. + aesthetic: Option, + spec: AggSpec, } -fn parse_band(name: &str) -> Option { - // Walk offsets longest-first so `median` matches before `mean`. - let mut offsets: Vec<&'static str> = OFFSET_STATS.to_vec(); - offsets.sort_by_key(|s| std::cmp::Reverse(s.len())); +fn parse_entry(entry: &str) -> std::result::Result { + let caps = entry_re() + .captures(entry) + .ok_or_else(|| format!("could not parse aggregate entry '{}'", entry))?; - for offset in offsets { - let rest = match name.strip_prefix(offset) { - Some(r) => r, - None => continue, - }; - let (sign, after_sign) = match rest.chars().next() { - Some('+') => ('+', &rest[1..]), - Some('-') => ('-', &rest[1..]), - _ => continue, - }; + let aesthetic = caps.get(1).map(|m| m.as_str().to_string()); + let offset_str = caps.get(2).unwrap().as_str(); + let band_present = caps.get(3).is_some(); - let (mod_value, expansion_str) = parse_mod_and_remainder(after_sign); - let expansion = match resolve_static(expansion_str, EXPANSION_STATS) { - Some(e) => e, - None => continue, + let band = if band_present { + let expansion_str = caps.get(5).unwrap().as_str(); + let expansion = resolve_static(expansion_str, EXPANSION_STATS).ok_or_else(|| { + format!( + "'{}': '{}' is not a valid expansion stat. Allowed expansions: {}", + entry, + expansion_str, + crate::or_list_quoted(EXPANSION_STATS, '\''), + ) + })?; + let magnitude: f64 = caps + .get(4) + .map_or(1.0, |m| m.as_str().parse().unwrap_or(1.0)); + let mod_value = if caps.get(3).unwrap().as_str() == "-" { + -magnitude + } else { + magnitude }; + Some(Band { + mod_value, + expansion, + }) + } else { + None + }; - return Some(AggSpec { - offset, - band: Some(Band { - sign, - mod_value, - expansion, - }), - }); - } - None -} + let offset = if band.is_some() { + resolve_static(offset_str, OFFSET_STATS).ok_or_else(|| { + if AGG_NAMES.contains(&offset_str) { + format!( + "'{}': '{}' is not a valid offset stat. Allowed offsets: {}", + entry, + offset_str, + crate::or_list_quoted(OFFSET_STATS, '\''), + ) + } else { + format!( + "'{}': '{}' is not a known stat. Allowed offsets: {}", + entry, + offset_str, + crate::or_list_quoted(OFFSET_STATS, '\''), + ) + } + })? + } else { + resolve_static(offset_str, AGG_NAMES).ok_or_else(|| { + format!( + "unknown aggregate function '{}'. Allowed: {} (or use a band like `mean+sdev`)", + offset_str, + crate::or_list_quoted(AGG_NAMES, '\''), + ) + })? + }; -fn parse_mod_and_remainder(s: &str) -> (f64, &str) { - let mut idx = 0; - let bytes = s.as_bytes(); - while idx < bytes.len() && bytes[idx].is_ascii_digit() { - idx += 1; - } - if idx < bytes.len() && bytes[idx] == b'.' { - let mut after_dot = idx + 1; - while after_dot < bytes.len() && bytes[after_dot].is_ascii_digit() { - after_dot += 1; - } - if after_dot > idx + 1 { - idx = after_dot; - } - } - if idx == 0 { - return (1.0, s); - } - let num_str = &s[..idx]; - let value: f64 = num_str.parse().unwrap_or(1.0); - (value, &s[idx..]) + Ok(ParsedEntry { + aesthetic, + spec: AggSpec { offset, band }, + }) } // ============================================================================= @@ -233,15 +266,17 @@ impl AggregateSpec { } /// Human-readable label for an `AggSpec`. Re-emits simple names verbatim and -/// reconstructs band names like `mean+sdev`. +/// reconstructs band names like `mean+sdev` / `mean-1.96sdev`. fn agg_label(spec: &AggSpec) -> String { match &spec.band { None => spec.offset.to_string(), Some(b) => { - if b.mod_value == 1.0 { - format!("{}{}{}", spec.offset, b.sign, b.expansion) + let sign = if b.mod_value < 0.0 { '-' } else { '+' }; + let magnitude = b.mod_value.abs(); + if magnitude == 1.0 { + format!("{}{}{}", spec.offset, sign, b.expansion) } else { - format!("{}{}{}{}", spec.offset, b.sign, b.mod_value, b.expansion) + format!("{}{}{}{}", spec.offset, sign, magnitude, b.expansion) } } } @@ -277,32 +312,26 @@ pub fn parse_aggregate_param( let mut spec = AggregateSpec::new(); for entry in entries { - if let Some((aes, func)) = split_target(entry) { - if aes.is_empty() { - return Err(format!("'{}': aesthetic prefix is empty", entry)); - } - if func.is_empty() { - return Err(format!("'{}': aggregate function is empty", entry)); - } - let agg = parse_agg_name(func) - .ok_or_else(|| format!("'{}': {}", entry, diagnose_invalid_function_name(func)))?; - // Append to existing list for this aesthetic, or create one. - if let Some((_, fns)) = spec.targets.iter_mut().find(|(a, _)| a == aes) { - fns.push(agg); - } else { - spec.targets.push((aes.to_string(), vec![agg])); + let parsed = parse_entry(entry)?; + match parsed.aesthetic { + Some(aes) => { + if let Some((_, fns)) = spec.targets.iter_mut().find(|(a, _)| *a == aes) { + fns.push(parsed.spec); + } else { + spec.targets.push((aes, vec![parsed.spec])); + } } - } else { - let agg = parse_agg_name(entry).ok_or_else(|| diagnose_invalid_function_name(entry))?; - if spec.default_lower.is_none() { - spec.default_lower = Some(agg); - } else if spec.default_upper.is_none() { - spec.default_upper = Some(agg); - } else { - return Err(format!( - "'aggregate' accepts at most two unprefixed defaults; got a third: '{}'", - entry - )); + None => { + if spec.default_lower.is_none() { + spec.default_lower = Some(parsed.spec); + } else if spec.default_upper.is_none() { + spec.default_upper = Some(parsed.spec); + } else { + return Err(format!( + "'aggregate' accepts at most two unprefixed defaults; got a third: '{}'", + entry + )); + } } } } @@ -330,64 +359,13 @@ pub fn parse_aggregate_param( Ok(Some(spec)) } -/// Split an entry into `(aesthetic, function)` if it contains a `:`. Returns -/// `None` for an unprefixed entry like `'mean'`. -fn split_target(entry: &str) -> Option<(&str, &str)> { - entry.split_once(':') -} - /// Validate the `aggregate` SETTING value at parse-time. Used by -/// `Layer::validate_settings`. Aesthetic-name resolution is deferred to -/// `apply()` because `AestheticContext` isn't available here. +/// `Layer::validate_aggregate_setting`. Aesthetic-name resolution is deferred +/// to `apply()` because `AestheticContext` isn't available here. pub fn validate_aggregate_param(value: &ParameterValue) -> std::result::Result<(), String> { parse_aggregate_param(value).map(|_| ()) } -/// Build a per-role error message for a name that didn't parse. Re-walks the -/// input with looser rules to identify which side (offset / expansion) failed. -fn diagnose_invalid_function_name(name: &str) -> String { - if let Some(sign_idx) = name.find(['+', '-']) { - let offset_str = &name[..sign_idx]; - let after_sign = &name[sign_idx + 1..]; - let (_mod_value, expansion_str) = parse_mod_and_remainder(after_sign); - - let offset_known_simple = AGG_NAMES.contains(&offset_str); - let offset_known_band = OFFSET_STATS.contains(&offset_str); - let expansion_known_band = EXPANSION_STATS.contains(&expansion_str); - - if !offset_known_band { - if offset_known_simple { - return format!( - "'{}': '{}' is not a valid offset stat. Allowed offsets: {}", - name, - offset_str, - crate::or_list_quoted(OFFSET_STATS, '\''), - ); - } - return format!( - "'{}': '{}' is not a known stat. Allowed offsets: {}", - name, - offset_str, - crate::or_list_quoted(OFFSET_STATS, '\''), - ); - } - if !expansion_known_band { - return format!( - "'{}': '{}' is not a valid expansion stat. Allowed expansions: {}", - name, - expansion_str, - crate::or_list_quoted(EXPANSION_STATS, '\''), - ); - } - return format!("'{}' is not a valid aggregate function name", name); - } - format!( - "unknown aggregate function '{}'. Allowed: {} (or use a band like `mean+sdev`)", - name, - crate::or_list_quoted(AGG_NAMES, '\''), - ) -} - // ============================================================================= // SQL fragment helpers (per-column aggregate expressions). // ============================================================================= @@ -465,21 +443,22 @@ fn agg_sql_inline(spec: &AggSpec, qcol: &str, dialect: &dyn SqlDialect) -> Optio None => Some(offset_sql), Some(band) => { let exp_sql = simple_stat_sql_inline(band.expansion, qcol, dialect)?; - Some(format_band( - &offset_sql, - band.sign, - band.mod_value, - &exp_sql, - )) + Some(format_band(&offset_sql, band.mod_value, &exp_sql)) } } } -fn format_band(offset: &str, sign: char, mod_value: f64, exp: &str) -> String { - if mod_value == 1.0 { +/// Format a band expression `(offset ± [magnitude *] expansion)`. The sign and +/// magnitude come folded together in `mod_value`; this splits them back out +/// only when emitting SQL so the output is readable (e.g. `(mean - 1.96 * sdev)` +/// rather than `(mean + -1.96 * sdev)`). +fn format_band(offset: &str, mod_value: f64, exp: &str) -> String { + let sign = if mod_value < 0.0 { '-' } else { '+' }; + let magnitude = mod_value.abs(); + if magnitude == 1.0 { format!("({} {} {})", offset, sign, exp) } else { - format!("({} {} {} * {})", offset, sign, mod_value, exp) + format!("({} {} {} * {})", offset, sign, magnitude, exp) } } @@ -518,7 +497,7 @@ fn agg_sql_fallback( Some(band) => { let exp_sql = simple_stat_sql_fallback(band.expansion, raw_col, dialect, src_alias, group_cols); - format_band(&offset_sql, band.sign, band.mod_value, &exp_sql) + format_band(&offset_sql, band.mod_value, &exp_sql) } } } @@ -1301,7 +1280,7 @@ mod tests { #[test] fn empty_prefix_is_error() { let err = parse_aggregate_param(&ParameterValue::String(":mean".to_string())).unwrap_err(); - assert!(err.contains("aesthetic prefix"), "got: {}", err); + assert!(err.contains("could not parse"), "got: {}", err); } #[test] @@ -1333,10 +1312,20 @@ mod tests { .band .as_ref() .unwrap() - .sign, - '-' + .mod_value, + -1.0, ); assert_eq!(s.default_upper.as_ref().unwrap().offset, "mean"); + assert_eq!( + s.default_upper + .as_ref() + .unwrap() + .band + .as_ref() + .unwrap() + .mod_value, + 1.0, + ); } // ---------- apply tests ---------- From 4aa41596a772e75ae62e8073cfec9db2cbdfd574 Mon Sep 17 00:00:00 2001 From: Thomas Lin Pedersen Date: Thu, 7 May 2026 14:20:13 +0200 Subject: [PATCH 30/33] better warning --- src/plot/layer/geom/stat_aggregate.rs | 31 +++++++++++++-------------- 1 file changed, 15 insertions(+), 16 deletions(-) diff --git a/src/plot/layer/geom/stat_aggregate.rs b/src/plot/layer/geom/stat_aggregate.rs index bf99e009..6b1fe1ae 100644 --- a/src/plot/layer/geom/stat_aggregate.rs +++ b/src/plot/layer/geom/stat_aggregate.rs @@ -793,27 +793,26 @@ pub fn apply( } } - // The *only* time we have nothing to aggregate but should still transform - // is when defaults exist but every numeric mapping was dropped — we still - // emit a GROUP BY to honour the grouping. If there are no aggregations and - // no kept columns and no group_by, return Identity. - if aggregated.is_empty() && kept_cols.is_empty() && group_by.is_empty() { - for d in &dropped { - eprintln!( - "Warning: aggregate dropped numeric mapping for aesthetic '{}' (no applicable default and no targeted function)", - aesthetic_ctx.map_internal_to_user(d) - ); - } - return Ok(StatResult::Identity); - } - for d in &dropped { + let user_aes = aesthetic_ctx.map_internal_to_user(d); eprintln!( - "Warning: aggregate dropped numeric mapping for aesthetic '{}' (no applicable default and no targeted function)", - aesthetic_ctx.map_internal_to_user(d) + "Warning: aggregate dropped numeric mapping for aesthetic '{}' \ + (no applicable default and no targeted function). \ + Suggestion: add an unprefixed default like `aggregate => 'mean'` \ + to apply one function to every numeric mapping, or target this \ + aesthetic with `'{0}:'`.", + user_aes, ); } + // No aggregate functions to apply → the stat has nothing to do. Whether + // the layer has group keys or not is irrelevant: emitting a `SELECT keys + // FROM src GROUP BY keys` query would be a distinct-rows transform the + // user didn't ask for. + if aggregated.is_empty() { + return Ok(StatResult::Identity); + } + // Group columns: PARTITION BY + discrete column-mappings, deduped. let mut group_cols: Vec = Vec::new(); for g in group_by { From 95d4df19aa4d80141f4ac9958a53ef7e9442e8e5 Mon Sep 17 00:00:00 2001 From: Thomas Lin Pedersen Date: Thu, 7 May 2026 14:43:41 +0200 Subject: [PATCH 31/33] add finer test --- src/plot/layer/geom/stat_aggregate.rs | 65 +++++++++++++++++++++++++++ 1 file changed, 65 insertions(+) diff --git a/src/plot/layer/geom/stat_aggregate.rs b/src/plot/layer/geom/stat_aggregate.rs index 6b1fe1ae..076bdfb5 100644 --- a/src/plot/layer/geom/stat_aggregate.rs +++ b/src/plot/layer/geom/stat_aggregate.rs @@ -1607,6 +1607,71 @@ mod tests { } } + #[test] + fn last_with_discrete_group_partitions_row_number_over_group() { + // Pins build_group_by_query's behaviour for the rn-CTE + non-empty + // group_cols combo: every other test that exercises the rn CTE has + // empty group_cols (so windows emit `OVER ()`). A bug that + // forgot to thread `PARTITION BY ` through wouldn't + // surface in those tests. + let mut aes = Mappings::new(); + aes.insert("pos1", col("__ggsql_aes_pos1__")); + aes.insert("pos2", col("__ggsql_aes_pos2__")); + let schema = schema_for(&[ + ("__ggsql_aes_pos1__", true), // discrete group key + ("__ggsql_aes_pos2__", false), + ]); + + // Native-FIRST/LAST dialect: no rn CTE, GROUP BY uses the discrete key. + let result = run( + ParameterValue::String("last".to_string()), + &aes, + &schema, + &[], + &InlineQuantileDialect, + ) + .unwrap(); + match result { + StatResult::Transformed { query, .. } => { + assert!( + !query.contains("__ggsql_rn__"), + "native LAST must not add ROW_NUMBER prep: {query}" + ); + assert!(query.contains("LAST(\"__ggsql_aes_pos2__\")"), "{query}"); + assert!(query.contains("GROUP BY \"__ggsql_aes_pos1__\""), "{query}"); + } + _ => panic!("expected Transformed"), + } + + // Default dialect: rn CTE must partition by the discrete group key. + struct AnsiTestDialect; + impl SqlDialect for AnsiTestDialect {} + let result = run( + ParameterValue::String("last".to_string()), + &aes, + &schema, + &[], + &AnsiTestDialect, + ) + .unwrap(); + match result { + StatResult::Transformed { query, .. } => { + assert!( + query.contains( + "ROW_NUMBER() OVER (PARTITION BY \"__ggsql_aes_pos1__\" ORDER BY (SELECT 1))" + ), + "{query}" + ); + assert!( + query.contains("COUNT(*) OVER (PARTITION BY \"__ggsql_aes_pos1__\")"), + "{query}" + ); + assert!(query.contains("GROUP BY \"__ggsql_aes_pos1__\""), "{query}"); + } + _ => panic!("expected Transformed"), + } + } + #[test] fn first_and_last_emit_positional_aggregates() { let mut aes = Mappings::new(); From ee499981950b320f7fe8bb7428df4f5d68f2e320 Mon Sep 17 00:00:00 2001 From: Thomas Lin Pedersen Date: Thu, 7 May 2026 16:18:17 +0200 Subject: [PATCH 32/33] appease our dear lord and master clippy --- src/plot/layer/geom/mod.rs | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/plot/layer/geom/mod.rs b/src/plot/layer/geom/mod.rs index 7b2adf56..2ca8265b 100644 --- a/src/plot/layer/geom/mod.rs +++ b/src/plot/layer/geom/mod.rs @@ -305,10 +305,7 @@ pub trait GeomTrait: std::fmt::Debug + std::fmt::Display + Send + Sync { /// True when `parameters["aggregate"]` is set to a non-null string or array. pub(crate) fn has_aggregate_param(parameters: &HashMap) -> bool { - match parameters.get("aggregate") { - Some(ParameterValue::String(_)) | Some(ParameterValue::Array(_)) => true, - _ => false, - } + matches!(parameters.get("aggregate"), Some(ParameterValue::String(_)) | Some(ParameterValue::Array(_))) } /// Wrapper struct for geom trait objects From c574e45facc7f7170b2672b9685c0a257cd171db Mon Sep 17 00:00:00 2001 From: Thomas Lin Pedersen Date: Thu, 7 May 2026 20:42:33 +0200 Subject: [PATCH 33/33] Apply suggestions from code review Co-authored-by: Teun van den Brand <49372158+teunbrand@users.noreply.github.com> --- doc/syntax/clause/draw.qmd | 15 ++++++++------- doc/syntax/layer/type/ribbon.qmd | 2 +- 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/doc/syntax/clause/draw.qmd b/doc/syntax/clause/draw.qmd index 010857ff..173bf19f 100644 --- a/doc/syntax/clause/draw.qmd +++ b/doc/syntax/clause/draw.qmd @@ -77,12 +77,12 @@ The `SETTING` clause can be used for two different things: A special setting is `position` which controls how overlapping objects are repositioned to avoid overlapping etc. Position adjustments have special mapping requirements so all position adjustments will not be relevant for all layer types. Different layers have different defaults as detailed in their documentation. You can read about each different position adjustment at [their own documentation sites](../index.qmd#position-adjustments). #### Aggregate -Some layers support aggregation of their data through the `aggregate` setting. These layers will state this. `aggregate` collapses each group to a single row, replacing every numeric mapping in place with its aggregated value. Groups are defined by `PARTITION BY` together with all discrete mappings. +Some layers support aggregation of their data through the `aggregate` setting. Their documentation will state this. `aggregate` collapses each group to a single row, replacing every numeric mapping in place with its aggregated value. Groups are defined by `PARTITION BY` together with all discrete mappings. -The setting takes a single string or an array of strings. Each string is one of: +The `aggregate` setting takes a single string or an array of strings. Each string is one of: -* **Untargeted** — `''` (no prefix). With one untargeted aggregation the function applies to every numeric mapping that doesn't have a targeted aggregation. With two untargeted aggregations the first is used for the lower side of range layers (e.g. `x`/`xmin`) plus all non-range layers, and the second is used for the upper side of range layers (e.g. `xend`/`xmax`). More than two untargeted aggregations is an error. -* **Targeted** — `':'`. Applies `func` to the named aesthetic only (`` is a user-facing name like `x`, `y`, `xmin`, `xmax`, `xend`, `yend`, `color`, `size`, …). A target overrides any untargeted aggregation for that aesthetic. +* **Untargeted** — `''` (no prefix). With one untargeted aggregation, the function applies to every numeric mapping that doesn't have a targeted aggregation. With two untargeted aggregations, the first is used for the lower side of range layers (e.g. `x`/`xmin`) plus all non-range layers, and the second is used for the upper side of range layers (e.g. `xend`/`xmax`). More than two untargeted aggregations is not allowed. +* **Targeted** — `':'`. Applies `func` to the named aesthetic only (`` is a name like `x`, `y`, `xmin`, `xmax`, `xend`, `yend`, `color`, `size`, …). A target overrides any untargeted aggregation for that aesthetic. A numeric mapping is dropped from the layer with a warning, when it has neither a target nor an applicable default. @@ -90,7 +90,8 @@ The simple functions are: * `'count'`: Non-null tally of the bound column. * `'sum'` and `'prod'`: The sum or product -* `'min'`, `'max'`, `'range'`, and `'mid'`: Extremes, max - min, and (min + max) / 2 +* `'min'`, `'max'`: Extremes +* `'range'` (max - min), `'mid'` (min + max) / 2 * `'mean'`, and `'median'`: Central tendency * `'geomean'`, `'harmean'`, and `'rms'`: Geometric, harmonic, and root-mean-square * `'sdev'`, `'var'`, `'iqr'`, and `'se'`: Standard deviation, variance, interquartile range, and standard error @@ -104,14 +105,14 @@ Allowed offsets are: `'mean'`, `'median'`, `'geomean'`, `'harmean'`, `'rms'`, `' Allowed expansions are: `'sdev'`, `'se'`, `'var'`, `'iqr'`, and `'range'` -You can also target the same aesthetic more than once to produce **multiple rows per group** — one for each function. For example `aggregate => ('y:min', 'y:max')` emits a min row and a max row per group, so a single `DRAW line` produces two summary lines that connect within each group rather than across them. When multiple rows are created a synthetic `aggregate` column is made that tags each row with the aggregation function. You can use this with a `REMAPPING` to drive another aesthetic — e.g. `REMAPPING aggregate AS stroke` to colour the two lines differently. The column's value is built from the per-row function names of the *exploded* targets, deduplicated, and joined with `/`: +You can also target the same aesthetic more than once to produce **multiple rows per group** — one for each function. For example `aggregate => ('y:min', 'y:max')` emits a min row and a max row per group, so a single `DRAW line` produces two summary lines that connect within each group rather than across them. When multiple rows are created, a synthetic `aggregate` column is made that tags each row with the name of the aggregation function. You can use this with a `REMAPPING` to drive another aesthetic — e.g. `REMAPPING aggregate AS stroke` to colour the two lines differently. The column's value is built from the per-row function names of the *exploded* targets, deduplicated, and joined with `/`: * `aggregate => ('y:min', 'y:max')` → rows tagged `'min'`, `'max'`. * `aggregate => ('y:min', 'y:max', 'color:median')` → rows tagged `'min'`, `'max'` (the single-function `color` target is recycled across rows and is not part of the label). * `aggregate => ('y:min', 'y:max', 'color:sum', 'color:prod')` → rows tagged `'min/sum'`, `'max/prod'`. * `aggregate => ('y:mean', 'y:max', 'color:mean', 'color:prod')` → rows tagged `'mean'`, `'max/prod'` (the duplicate `'mean'` collapses). -When several aesthetics are targeted with the same number of functions, they explode in lockstep (row 1 uses each aesthetic's first function, row 2 the second, and so on); aesthetics with a single function — and the unprefixed defaults — are reused unchanged across every row. Mixing different lengths above 1 is an error. +When several aesthetics are targeted with the same number of functions, they explode in lockstep: row 1 uses each aesthetic's first function, row 2 the second, and so on. Aesthetics with a single function — and the unprefixed defaults — are reused unchanged across every row. Mixing different numbers of aggregation metrics above 1 across aesthetics is not allowed. In the single-row (reduction) case aggregation applies in place — no `REMAPPING` is needed and no synthetic column is added. Only the multi-row (explosion) case described above introduces the synthetic `aggregate` column. diff --git a/doc/syntax/layer/type/ribbon.qmd b/doc/syntax/layer/type/ribbon.qmd index b3e2b375..d46aa02a 100644 --- a/doc/syntax/layer/type/ribbon.qmd +++ b/doc/syntax/layer/type/ribbon.qmd @@ -63,7 +63,7 @@ DRAW line MAPPING MeanTemp AS y ``` -Use aggregation to calculate bounds on the fly +Use aggregation to calculate bounds on the fly. The two untargeted aggregation functions target the `ymin` and `ymax` aesthetics automatically. ```{ggsql} VISUALISE Day AS x, Temp AS ymin, Temp AS ymax FROM ggsql:airquality