diff --git a/rust/cubesqlplanner/cubesqlplanner/src/planner/sql_evaluator/sql_call.rs b/rust/cubesqlplanner/cubesqlplanner/src/planner/sql_evaluator/sql_call.rs index 08dc7235cb380..ad2aca6b1e2d8 100644 --- a/rust/cubesqlplanner/cubesqlplanner/src/planner/sql_evaluator/sql_call.rs +++ b/rust/cubesqlplanner/cubesqlplanner/src/planner/sql_evaluator/sql_call.rs @@ -6,6 +6,7 @@ use crate::planner::sql_evaluator::sql_nodes::SqlNodesFactory; use crate::planner::sql_evaluator::{CubeNameSymbol, CubeTableSymbol}; use crate::planner::sql_templates::PlanSqlTemplates; use crate::planner::VisitorContext; +use crate::utils::sql_expression_scanner::analyze_template_arg_contexts; use cubenativeutils::CubeError; use itertools::Itertools; use std::collections::HashMap; @@ -117,6 +118,10 @@ pub struct SqlCall { filter_params: Vec, filter_groups: Vec, security_context: SecutityContextProps, + /// Per `{arg:N}` index: whether the surrounding context in the template + /// would make a compound substitution unsafe (requiring parentheses). + /// Computed once at construction from the template. 
+ arg_paren_contexts: HashMap, } impl SqlCall { @@ -127,12 +132,26 @@ impl SqlCall { filter_groups: Vec, security_context: SecutityContextProps, ) -> Self { + let arg_paren_contexts = match &template { + SqlTemplate::String(s) => analyze_template_arg_contexts(s), + SqlTemplate::StringVec(strings) => { + let mut merged: HashMap = HashMap::new(); + for s in strings { + for (idx, needs_safe) in analyze_template_arg_contexts(s) { + let entry = merged.entry(idx).or_insert(false); + *entry = *entry || needs_safe; + } + } + merged + } + }; Self { template, deps, filter_params, filter_groups, security_context, + arg_paren_contexts, } } @@ -254,10 +273,22 @@ impl SqlCall { let deps = self .deps .iter() - .map(|dep| match dep { - SqlDependency::Symbol(m) => visitor.apply(m, node_processor.clone(), templates), - SqlDependency::CubeRef(cr) => { - visitor.evaluate_cube_ref(cr, node_processor.clone(), templates) + .enumerate() + .map(|(i, dep)| { + // Each arg's `arg_needs_paren_safe` flag is set by this call's + // template context, overriding whatever the caller's visitor + // carried. The caller's flag only governs wrapping of this + // whole SqlCall's output, handled by an enclosing Parenthesize + // node up the processor chain. 
+ let needs_safe = *self.arg_paren_contexts.get(&i).unwrap_or(&false); + let arg_visitor = visitor.with_arg_needs_paren_safe(needs_safe); + match dep { + SqlDependency::Symbol(m) => { + arg_visitor.apply(m, node_processor.clone(), templates) + } + SqlDependency::CubeRef(cr) => { + arg_visitor.evaluate_cube_ref(cr, node_processor.clone(), templates) + } } }) .collect::, _>>()?; diff --git a/rust/cubesqlplanner/cubesqlplanner/src/planner/sql_evaluator/sql_nodes/calendar_time_shift.rs b/rust/cubesqlplanner/cubesqlplanner/src/planner/sql_evaluator/sql_nodes/calendar_time_shift.rs index 7bfc8d9e40be4..e77234cc3315e 100644 --- a/rust/cubesqlplanner/cubesqlplanner/src/planner/sql_evaluator/sql_nodes/calendar_time_shift.rs +++ b/rust/cubesqlplanner/cubesqlplanner/src/planner/sql_evaluator/sql_nodes/calendar_time_shift.rs @@ -36,13 +36,6 @@ impl SqlNode for CalendarTimeShiftSqlNode { node_processor: Rc, templates: &PlanSqlTemplates, ) -> Result { - let input = self.input.to_sql( - visitor, - node, - query_tools.clone(), - node_processor.clone(), - templates, - )?; let res = match node.as_ref() { MemberSymbol::Dimension(ev) => { if !ev.is_reference() { @@ -55,20 +48,52 @@ impl SqlNode for CalendarTimeShiftSqlNode { templates, )? } else if let Some(interval) = &shift.interval { + let inner_visitor = visitor.with_arg_needs_paren_safe(false); + let input = self.input.to_sql( + &inner_visitor, + node, + query_tools.clone(), + node_processor.clone(), + templates, + )?; let res = templates .add_timestamp_interval(input, interval.inverse().to_sql())?; format!("({})", res) } else { - input + self.input.to_sql( + visitor, + node, + query_tools.clone(), + node_processor.clone(), + templates, + )? } } else { - input + self.input.to_sql( + visitor, + node, + query_tools.clone(), + node_processor.clone(), + templates, + )? } } else { - input + self.input.to_sql( + visitor, + node, + query_tools.clone(), + node_processor.clone(), + templates, + )? 
} } - _ => input, + _ => self.input.to_sql( + visitor, + node, + query_tools.clone(), + node_processor.clone(), + templates, + )?, }; Ok(res) } diff --git a/rust/cubesqlplanner/cubesqlplanner/src/planner/sql_evaluator/sql_nodes/case.rs b/rust/cubesqlplanner/cubesqlplanner/src/planner/sql_evaluator/sql_nodes/case.rs index d1fe0e1e7f23b..d1bc4452e665c 100644 --- a/rust/cubesqlplanner/cubesqlplanner/src/planner/sql_evaluator/sql_nodes/case.rs +++ b/rust/cubesqlplanner/cubesqlplanner/src/planner/sql_evaluator/sql_nodes/case.rs @@ -31,10 +31,12 @@ impl CaseSqlNode { node_processor: Rc, templates: &PlanSqlTemplates, ) -> Result { + // All sub-SQLs end up inside `CASE … END` — a safe wrap. + let inner_visitor = visitor.with_arg_needs_paren_safe(false); let mut when_then = Vec::new(); for itm in case.items.iter() { let when = itm.sql.eval( - visitor, + &inner_visitor, node_processor.clone(), query_tools.clone(), templates, @@ -42,7 +44,7 @@ impl CaseSqlNode { let then = match &itm.label { CaseLabel::String(s) => templates.quote_string(&s)?, CaseLabel::Sql(sql) => sql.eval( - visitor, + &inner_visitor, node_processor.clone(), query_tools.clone(), templates, @@ -53,7 +55,7 @@ impl CaseSqlNode { let else_label = match &case.else_label { CaseLabel::String(s) => templates.quote_string(&s)?, CaseLabel::Sql(sql) => sql.eval( - visitor, + &inner_visitor, node_processor.clone(), query_tools.clone(), templates, @@ -69,6 +71,9 @@ impl CaseSqlNode { node_processor: Rc, templates: &PlanSqlTemplates, ) -> Result { + // Degenerate shortcuts return the inner SQL as-is — propagate the outer + // visitor so an enclosing ParenthesizeSqlNode still sees the compound + // flag. 
if case.items.len() == 1 && case.else_sql.is_none() { return case.items[0].sql.eval( visitor, @@ -85,22 +90,23 @@ impl CaseSqlNode { templates, ); } + let inner_visitor = visitor.with_arg_needs_paren_safe(false); let expr = match &case.switch { CaseSwitchItem::Sql(sql_call) => sql_call.eval( - visitor, + &inner_visitor, node_processor.clone(), query_tools.clone(), templates, )?, CaseSwitchItem::Member(member_symbol) => { - visitor.apply(&member_symbol, node_processor.clone(), templates)? + inner_visitor.apply(&member_symbol, node_processor.clone(), templates)? } }; let mut when_then = Vec::new(); for itm in case.items.iter() { let when = templates.quote_string(&itm.value)?; let then = itm.sql.eval( - visitor, + &inner_visitor, node_processor.clone(), query_tools.clone(), templates, @@ -109,7 +115,7 @@ impl CaseSqlNode { } let else_label = if let Some(else_sql) = &case.else_sql { Some(else_sql.eval( - visitor, + &inner_visitor, node_processor.clone(), query_tools.clone(), templates, diff --git a/rust/cubesqlplanner/cubesqlplanner/src/planner/sql_evaluator/sql_nodes/factory.rs b/rust/cubesqlplanner/cubesqlplanner/src/planner/sql_evaluator/sql_nodes/factory.rs index 5d15962841284..675006d9726aa 100644 --- a/rust/cubesqlplanner/cubesqlplanner/src/planner/sql_evaluator/sql_nodes/factory.rs +++ b/rust/cubesqlplanner/cubesqlplanner/src/planner/sql_evaluator/sql_nodes/factory.rs @@ -1,9 +1,9 @@ use super::{ AutoPrefixSqlNode, CaseSqlNode, EvaluateSqlNode, FinalMeasureSqlNode, FinalPreAggregationMeasureSqlNode, GeoDimensionSqlNode, MaskedSqlNode, MeasureFilterSqlNode, - MultiStageRankNode, MultiStageWindowNode, RenderReferencesSqlNode, RenderReferencesType, - RollingWindowNode, RootSqlNode, SqlNode, TimeDimensionNode, TimeShiftSqlNode, - UngroupedMeasureSqlNode, UngroupedQueryFinalMeasureSqlNode, + MultiStageRankNode, MultiStageWindowNode, ParenthesizeSqlNode, RenderReferencesSqlNode, + RenderReferencesType, RollingWindowNode, RootSqlNode, SqlNode, TimeDimensionNode, + 
TimeShiftSqlNode, UngroupedMeasureSqlNode, UngroupedQueryFinalMeasureSqlNode, }; use crate::planner::planners::multi_stage::TimeShiftState; use crate::planner::sql_evaluator::cube_ref_evaluator::CubeRefEvaluator; @@ -156,8 +156,10 @@ impl SqlNodesFactory { evaluate_sql_processor.clone(), self.cube_name_references.clone(), ); + let parenthesize_processor: Rc = + ParenthesizeSqlNode::new(auto_prefix_processor.clone()); - let measure_filter_processor = MeasureFilterSqlNode::new(auto_prefix_processor.clone()); + let measure_filter_processor = MeasureFilterSqlNode::new(parenthesize_processor.clone()); let measure_processor = CaseSqlNode::new(measure_filter_processor.clone()); let measure_processor = self.add_ungrouped_measure_reference_if_needed(measure_processor); @@ -182,10 +184,11 @@ impl SqlNodesFactory { } else { evaluate_sql_processor.clone() }; + let default_processor: Rc = ParenthesizeSqlNode::new(default_processor); let root_node = RootSqlNode::new( self.dimension_processor(evaluate_sql_processor.clone()), - self.time_dimension_processor(evaluate_sql_processor.clone()), + self.time_dimension_processor(ParenthesizeSqlNode::new(evaluate_sql_processor.clone())), measure_processor.clone(), default_processor, ); @@ -261,6 +264,8 @@ impl SqlNodesFactory { let input: Rc = AutoPrefixSqlNode::new(input, self.cube_name_references.clone()); + let input: Rc = ParenthesizeSqlNode::new(input); + let input: Rc = TimeDimensionNode::new(self.dimensions_with_ignored_timezone.clone(), input); diff --git a/rust/cubesqlplanner/cubesqlplanner/src/planner/sql_evaluator/sql_nodes/final_measure.rs b/rust/cubesqlplanner/cubesqlplanner/src/planner/sql_evaluator/sql_nodes/final_measure.rs index 7c06e438ea43b..4135608d145f3 100644 --- a/rust/cubesqlplanner/cubesqlplanner/src/planner/sql_evaluator/sql_nodes/final_measure.rs +++ b/rust/cubesqlplanner/cubesqlplanner/src/planner/sql_evaluator/sql_nodes/final_measure.rs @@ -1,6 +1,6 @@ use super::SqlNode; use 
crate::planner::query_tools::QueryTools; -use crate::planner::sql_evaluator::symbols::{AggregateWrap, MeasureSymbol}; +use crate::planner::sql_evaluator::symbols::AggregateWrap; use crate::planner::sql_evaluator::MemberSymbol; use crate::planner::sql_evaluator::SqlEvaluatorVisitor; use crate::planner::sql_templates::PlanSqlTemplates; @@ -32,16 +32,13 @@ impl FinalMeasureSqlNode { &self.input } - fn wrap_aggregate( + fn apply_wrap( &self, - ev: &MeasureSymbol, + wrap: AggregateWrap, input: String, templates: &PlanSqlTemplates, ) -> Result { - let is_multiplied = self - .rendered_as_multiplied_measures - .contains(&ev.full_name()); - match ev.kind().aggregate_wrap(is_multiplied) { + match wrap { AggregateWrap::PassThrough => Ok(input), AggregateWrap::Function(name) => Ok(format!("{}({})", name, input)), AggregateWrap::CountDistinct => templates.count_distinct(&input), @@ -67,14 +64,22 @@ impl SqlNode for FinalMeasureSqlNode { ) -> Result { let res = match node.as_ref() { MemberSymbol::Measure(ev) => { + let is_multiplied = self + .rendered_as_multiplied_measures + .contains(&ev.full_name()); + let wrap = ev.kind().aggregate_wrap(is_multiplied); + let child_visitor = match wrap { + AggregateWrap::PassThrough => visitor.clone(), + _ => visitor.with_arg_needs_paren_safe(false), + }; let input = self.input.to_sql( - visitor, + &child_visitor, node, query_tools.clone(), node_processor.clone(), templates, )?; - self.wrap_aggregate(ev, input, templates)? + self.apply_wrap(wrap, input, templates)? 
} _ => { return Err(CubeError::internal(format!( diff --git a/rust/cubesqlplanner/cubesqlplanner/src/planner/sql_evaluator/sql_nodes/geo_dimension.rs b/rust/cubesqlplanner/cubesqlplanner/src/planner/sql_evaluator/sql_nodes/geo_dimension.rs index c74ca07a708b0..32f0863a8bdb0 100644 --- a/rust/cubesqlplanner/cubesqlplanner/src/planner/sql_evaluator/sql_nodes/geo_dimension.rs +++ b/rust/cubesqlplanner/cubesqlplanner/src/planner/sql_evaluator/sql_nodes/geo_dimension.rs @@ -34,14 +34,15 @@ impl SqlNode for GeoDimensionSqlNode { let res = match node.as_ref() { MemberSymbol::Dimension(ev) => { if let DimensionKind::Geo(geo) = ev.kind() { + let inner_visitor = visitor.with_arg_needs_paren_safe(false); let latitude_str = geo.latitude().eval( - visitor, + &inner_visitor, node_processor.clone(), query_tools.clone(), templates, )?; let longitude_str = geo.longitude().eval( - visitor, + &inner_visitor, node_processor.clone(), query_tools.clone(), templates, diff --git a/rust/cubesqlplanner/cubesqlplanner/src/planner/sql_evaluator/sql_nodes/measure_filter.rs b/rust/cubesqlplanner/cubesqlplanner/src/planner/sql_evaluator/sql_nodes/measure_filter.rs index e180959925ec0..d5eb99c920b6f 100644 --- a/rust/cubesqlplanner/cubesqlplanner/src/planner/sql_evaluator/sql_nodes/measure_filter.rs +++ b/rust/cubesqlplanner/cubesqlplanner/src/planner/sql_evaluator/sql_nodes/measure_filter.rs @@ -30,24 +30,25 @@ impl SqlNode for MeasureFilterSqlNode { node_processor: Rc, templates: &PlanSqlTemplates, ) -> Result { - let input = self.input.to_sql( - visitor, - node, - query_tools.clone(), - node_processor.clone(), - templates, - )?; let res = match node.as_ref() { MemberSymbol::Measure(ev) => { let measure_filters = ev.measure_filters(); if !measure_filters.is_empty() { + let inner_visitor = visitor.with_arg_needs_paren_safe(false); + let input = self.input.to_sql( + &inner_visitor, + node, + query_tools.clone(), + node_processor.clone(), + templates, + )?; let filters = measure_filters .iter() 
.map(|filter| -> Result { Ok(format!( "({})", filter.eval( - &visitor, + &inner_visitor, node_processor.clone(), query_tools.clone(), templates @@ -63,7 +64,14 @@ impl SqlNode for MeasureFilterSqlNode { }; format!("CASE WHEN {} THEN {} END", filters, result) } else { - input + // Passthrough — propagate visitor unchanged. + self.input.to_sql( + visitor, + node, + query_tools.clone(), + node_processor.clone(), + templates, + )? } } _ => { diff --git a/rust/cubesqlplanner/cubesqlplanner/src/planner/sql_evaluator/sql_nodes/mod.rs b/rust/cubesqlplanner/cubesqlplanner/src/planner/sql_evaluator/sql_nodes/mod.rs index 1e515a39688e3..35f81ee5d4708 100644 --- a/rust/cubesqlplanner/cubesqlplanner/src/planner/sql_evaluator/sql_nodes/mod.rs +++ b/rust/cubesqlplanner/cubesqlplanner/src/planner/sql_evaluator/sql_nodes/mod.rs @@ -11,6 +11,7 @@ pub mod masked; pub mod measure_filter; pub mod multi_stage_rank; pub mod multi_stage_window; +pub mod parenthesize; pub mod render_references; pub mod rolling_window; pub mod root_processor; @@ -32,6 +33,7 @@ pub use masked::MaskedSqlNode; pub use measure_filter::MeasureFilterSqlNode; pub use multi_stage_rank::MultiStageRankNode; pub use multi_stage_window::MultiStageWindowNode; +pub use parenthesize::ParenthesizeSqlNode; pub use render_references::*; pub use rolling_window::RollingWindowNode; pub use root_processor::RootSqlNode; diff --git a/rust/cubesqlplanner/cubesqlplanner/src/planner/sql_evaluator/sql_nodes/multi_stage_rank.rs b/rust/cubesqlplanner/cubesqlplanner/src/planner/sql_evaluator/sql_nodes/multi_stage_rank.rs index c2b9f52e30671..392b1f82ea275 100644 --- a/rust/cubesqlplanner/cubesqlplanner/src/planner/sql_evaluator/sql_nodes/multi_stage_rank.rs +++ b/rust/cubesqlplanner/cubesqlplanner/src/planner/sql_evaluator/sql_nodes/multi_stage_rank.rs @@ -42,13 +42,14 @@ impl SqlNode for MultiStageRankNode { let res = match node.as_ref() { MemberSymbol::Measure(m) => { if m.is_multi_stage() && matches!(m.kind(), MeasureKind::Rank) { + 
let inner_visitor = visitor.with_arg_needs_paren_safe(false); let order_by = if !m.measure_order_by().is_empty() { let sql = m .measure_order_by() .iter() .map(|item| -> Result { let sql = item.sql_call().eval( - visitor, + &inner_visitor, node_processor.clone(), query_tools.clone(), templates, diff --git a/rust/cubesqlplanner/cubesqlplanner/src/planner/sql_evaluator/sql_nodes/multi_stage_window.rs b/rust/cubesqlplanner/cubesqlplanner/src/planner/sql_evaluator/sql_nodes/multi_stage_window.rs index fe41626d666c8..6258677ddfe97 100644 --- a/rust/cubesqlplanner/cubesqlplanner/src/planner/sql_evaluator/sql_nodes/multi_stage_window.rs +++ b/rust/cubesqlplanner/cubesqlplanner/src/planner/sql_evaluator/sql_nodes/multi_stage_window.rs @@ -50,8 +50,9 @@ impl SqlNode for MultiStageWindowNode { let res = match node.as_ref() { MemberSymbol::Measure(m) => { if m.is_multi_stage() && !m.is_calculated() { + let inner_visitor = visitor.with_arg_needs_paren_safe(false); let input_sql = self.input.to_sql( - visitor, + &inner_visitor, node, query_tools.clone(), node_processor.clone(), diff --git a/rust/cubesqlplanner/cubesqlplanner/src/planner/sql_evaluator/sql_nodes/parenthesize.rs b/rust/cubesqlplanner/cubesqlplanner/src/planner/sql_evaluator/sql_nodes/parenthesize.rs new file mode 100644 index 0000000000000..a7a238becc116 --- /dev/null +++ b/rust/cubesqlplanner/cubesqlplanner/src/planner/sql_evaluator/sql_nodes/parenthesize.rs @@ -0,0 +1,62 @@ +use super::SqlNode; +use crate::planner::query_tools::QueryTools; +use crate::planner::sql_evaluator::MemberSymbol; +use crate::planner::sql_evaluator::SqlEvaluatorVisitor; +use crate::planner::sql_templates::PlanSqlTemplates; +use crate::utils::sql_expression_scanner::is_top_level_compound; +use cubenativeutils::CubeError; +use std::any::Any; +use std::rc::Rc; + +/// Wraps the child's rendered SQL in parentheses when the visitor signals that +/// the surrounding context expects a parentheses-safe argument (for example, a +/// `SqlCall` 
substitution into an arithmetic or logical position) and the +/// rendered expression is compound at the top level. +/// +/// Sits immediately above [`AutoPrefixSqlNode`] in the processor chain — the +/// lowest point where renaming is complete. Higher-layer nodes that wrap the +/// child's output in a syntactically safe construct (aggregate, window +/// function, CASE/DATE_TRUNC/CONVERT_TZ, etc.) should reset +/// `arg_needs_paren_safe` on the visitor before recursing, so this node avoids +/// scanning output that will be discarded. +pub struct ParenthesizeSqlNode { + input: Rc, +} + +impl ParenthesizeSqlNode { + pub fn new(input: Rc) -> Rc { + Rc::new(Self { input }) + } + + pub fn input(&self) -> &Rc { + &self.input + } +} + +impl SqlNode for ParenthesizeSqlNode { + fn to_sql( + &self, + visitor: &SqlEvaluatorVisitor, + node: &Rc, + query_tools: Rc, + node_processor: Rc, + templates: &PlanSqlTemplates, + ) -> Result { + let input_sql = self + .input + .to_sql(visitor, node, query_tools, node_processor, templates)?; + if visitor.arg_needs_paren_safe() && is_top_level_compound(&input_sql) { + Ok(format!("({})", input_sql)) + } else { + Ok(input_sql) + } + } + + fn as_any(self: Rc) -> Rc { + self.clone() + } + + fn childs(&self) -> Vec> { + vec![self.input.clone()] + } +} diff --git a/rust/cubesqlplanner/cubesqlplanner/src/planner/sql_evaluator/sql_nodes/rolling_window.rs b/rust/cubesqlplanner/cubesqlplanner/src/planner/sql_evaluator/sql_nodes/rolling_window.rs index 133bfde729376..e2a3440d6017e 100644 --- a/rust/cubesqlplanner/cubesqlplanner/src/planner/sql_evaluator/sql_nodes/rolling_window.rs +++ b/rust/cubesqlplanner/cubesqlplanner/src/planner/sql_evaluator/sql_nodes/rolling_window.rs @@ -35,55 +35,52 @@ impl SqlNode for RollingWindowNode { templates: &PlanSqlTemplates, ) -> Result { let res = match node.as_ref() { - MemberSymbol::Measure(m) => { - if m.is_cumulative() { - let input = self.input.to_sql( + MemberSymbol::Measure(m) if m.is_cumulative() => { + let 
delegate = || { + self.default_processor.to_sql( visitor, node, query_tools.clone(), node_processor.clone(), templates, - )?; - match m.kind() { - MeasureKind::Aggregated(a) - if a.agg_type() == AggregationType::CountDistinctApprox => - { - templates.hll_cardinality_merge(input)? - } - MeasureKind::Count(_) => format!("sum({})", input), - MeasureKind::Aggregated(a) => match a.agg_type() { - AggregationType::Sum | AggregationType::RunningTotal => { - format!("sum({})", input) - } - AggregationType::Min | AggregationType::Max => { - format!("{}({})", a.agg_type().as_str(), input) - } - _ => self.default_processor.to_sql( - visitor, - node, - query_tools.clone(), - node_processor, - templates, - )?, - }, - _ => self.default_processor.to_sql( - visitor, - node, - query_tools.clone(), - node_processor, - templates, - )?, - } - } else { - self.default_processor.to_sql( - visitor, + ) + }; + let render_input = || -> Result { + let inner_visitor = visitor.with_arg_needs_paren_safe(false); + self.input.to_sql( + &inner_visitor, node, query_tools.clone(), - node_processor, + node_processor.clone(), templates, - )? + ) + }; + match m.kind() { + MeasureKind::Count(_) => format!("sum({})", render_input()?), + MeasureKind::Aggregated(a) => match a.agg_type() { + AggregationType::CountDistinctApprox => { + templates.hll_cardinality_merge(render_input()?)? + } + AggregationType::Sum | AggregationType::RunningTotal => { + format!("sum({})", render_input()?) + } + AggregationType::Min | AggregationType::Max => { + format!("{}({})", a.agg_type().as_str(), render_input()?) 
+ } + AggregationType::Avg + | AggregationType::CountDistinct + | AggregationType::NumberAgg => delegate()?, + }, + _ => delegate()?, } } + MemberSymbol::Measure(_) => self.default_processor.to_sql( + visitor, + node, + query_tools.clone(), + node_processor, + templates, + )?, _ => { return Err(CubeError::internal(format!( "Unexpected evaluation node type for RollingWindowNode" diff --git a/rust/cubesqlplanner/cubesqlplanner/src/planner/sql_evaluator/sql_nodes/time_dimension.rs b/rust/cubesqlplanner/cubesqlplanner/src/planner/sql_evaluator/sql_nodes/time_dimension.rs index 861fd3ba9ee1c..3d9a15747adc0 100644 --- a/rust/cubesqlplanner/cubesqlplanner/src/planner/sql_evaluator/sql_nodes/time_dimension.rs +++ b/rust/cubesqlplanner/cubesqlplanner/src/planner/sql_evaluator/sql_nodes/time_dimension.rs @@ -34,16 +34,12 @@ impl SqlNode for TimeDimensionNode { node_processor: Rc, templates: &PlanSqlTemplates, ) -> Result { - let input_sql = self.input.to_sql( - visitor, - node, - query_tools.clone(), - node_processor.clone(), - templates, - )?; match node.as_ref() { MemberSymbol::TimeDimension(ev) => { - let res = if let Some(granularity_obj) = ev.granularity_obj() { + if let Some(granularity_obj) = ev.granularity_obj() { + // Short-circuits to calendar SQL — `self.input` is not used. + // Propagate the outer visitor: the calendar SQL is the + // expression itself, not wrapped further here. if let Some(calendar_sql) = granularity_obj.calendar_sql() { return calendar_sql.eval( visitor, @@ -52,7 +48,16 @@ impl SqlNode for TimeDimensionNode { templates, ); } - + // Wraps in `convert_tz(…)` and a granularity function — + // safe, reset for child render. 
+ let inner_visitor = visitor.with_arg_needs_paren_safe(false); + let input_sql = self.input.to_sql( + &inner_visitor, + node, + query_tools.clone(), + node_processor.clone(), + templates, + )?; let skip_convert_tz = self .dimensions_with_ignored_timezone .contains(&ev.full_name()); @@ -63,23 +68,48 @@ impl SqlNode for TimeDimensionNode { templates.convert_tz(input_sql)? }; - granularity_obj.apply_to_input_sql(templates, converted_tz)? + Ok(granularity_obj.apply_to_input_sql(templates, converted_tz)?) } else { - input_sql - }; - Ok(res) + self.input.to_sql( + visitor, + node, + query_tools.clone(), + node_processor.clone(), + templates, + ) + } } MemberSymbol::Dimension(ev) => { - if !visitor.ignore_tz_convert() + let wraps_convert_tz = !visitor.ignore_tz_convert() && query_tools.convert_tz_for_raw_time_dimension() - && ev.dimension_type() == "time" - { + && ev.dimension_type() == "time"; + if wraps_convert_tz { + let inner_visitor = visitor.with_arg_needs_paren_safe(false); + let input_sql = self.input.to_sql( + &inner_visitor, + node, + query_tools.clone(), + node_processor.clone(), + templates, + )?; Ok(templates.convert_tz(input_sql)?) 
} else { - Ok(input_sql) + self.input.to_sql( + visitor, + node, + query_tools.clone(), + node_processor.clone(), + templates, + ) } } - _ => Ok(input_sql), + _ => self.input.to_sql( + visitor, + node, + query_tools.clone(), + node_processor.clone(), + templates, + ), } } diff --git a/rust/cubesqlplanner/cubesqlplanner/src/planner/sql_evaluator/sql_nodes/time_shift.rs b/rust/cubesqlplanner/cubesqlplanner/src/planner/sql_evaluator/sql_nodes/time_shift.rs index 1c0a37f9aa8da..39f0434830063 100644 --- a/rust/cubesqlplanner/cubesqlplanner/src/planner/sql_evaluator/sql_nodes/time_shift.rs +++ b/rust/cubesqlplanner/cubesqlplanner/src/planner/sql_evaluator/sql_nodes/time_shift.rs @@ -32,28 +32,47 @@ impl SqlNode for TimeShiftSqlNode { node_processor: Rc, templates: &PlanSqlTemplates, ) -> Result { - let input = self.input.to_sql( - visitor, - node, - query_tools.clone(), - node_processor.clone(), - templates, - )?; let res = match node.as_ref() { MemberSymbol::Dimension(ev) => { if !ev.is_reference() && ev.is_time() { if let Some(shift) = self.shifts.dimensions_shifts.get(&ev.full_name()) { - let shift = shift.interval.clone().unwrap().to_sql(); // Common time shifts should always have an interval + let shift = shift.interval.clone().unwrap().to_sql(); + let inner_visitor = visitor.with_arg_needs_paren_safe(false); + let input = self.input.to_sql( + &inner_visitor, + node, + query_tools.clone(), + node_processor.clone(), + templates, + )?; let res = templates.add_timestamp_interval(input, shift)?; format!("({})", res) } else { - input + self.input.to_sql( + visitor, + node, + query_tools.clone(), + node_processor.clone(), + templates, + )? } } else { - input + self.input.to_sql( + visitor, + node, + query_tools.clone(), + node_processor.clone(), + templates, + )? 
} } - _ => input, + _ => self.input.to_sql( + visitor, + node, + query_tools.clone(), + node_processor.clone(), + templates, + )?, }; Ok(res) } diff --git a/rust/cubesqlplanner/cubesqlplanner/src/planner/sql_evaluator/sql_nodes/ungroupped_query_final_measure.rs b/rust/cubesqlplanner/cubesqlplanner/src/planner/sql_evaluator/sql_nodes/ungroupped_query_final_measure.rs index 0dbbd94539de4..fe9992b58ea30 100644 --- a/rust/cubesqlplanner/cubesqlplanner/src/planner/sql_evaluator/sql_nodes/ungroupped_query_final_measure.rs +++ b/rust/cubesqlplanner/cubesqlplanner/src/planner/sql_evaluator/sql_nodes/ungroupped_query_final_measure.rs @@ -33,8 +33,23 @@ impl SqlNode for UngroupedQueryFinalMeasureSqlNode { ) -> Result { let res = match node.as_ref() { MemberSymbol::Measure(ev) => { + let is_count_like = match ev.kind() { + MeasureKind::Count(_) => true, + MeasureKind::Aggregated(a) => matches!( + a.agg_type(), + AggregationType::CountDistinct | AggregationType::CountDistinctApprox + ), + _ => false, + }; + // Count-likes wrap the child in `CASE WHEN … IS NOT NULL THEN 1 END` + // (safe), other kinds pass through and must propagate the flag. + let child_visitor = if is_count_like { + visitor.with_arg_needs_paren_safe(false) + } else { + visitor.clone() + }; let input = self.input.to_sql( - visitor, + &child_visitor, node, query_tools.clone(), node_processor.clone(), @@ -43,20 +58,10 @@ impl SqlNode for UngroupedQueryFinalMeasureSqlNode { if input == "*" { "1".to_string() + } else if is_count_like { + format!("CASE WHEN ({}) IS NOT NULL THEN 1 END", input) //TODO templates!! } else { - let is_count_like = match ev.kind() { - MeasureKind::Count(_) => true, - MeasureKind::Aggregated(a) => matches!( - a.agg_type(), - AggregationType::CountDistinct | AggregationType::CountDistinctApprox - ), - _ => false, - }; - if is_count_like { - format!("CASE WHEN ({}) IS NOT NULL THEN 1 END", input) //TODO templates!! 
- } else { - input - } + input } } _ => { diff --git a/rust/cubesqlplanner/cubesqlplanner/src/planner/sql_evaluator/sql_visitor.rs b/rust/cubesqlplanner/cubesqlplanner/src/planner/sql_evaluator/sql_visitor.rs index 25e14ab148337..7182b0a10963d 100644 --- a/rust/cubesqlplanner/cubesqlplanner/src/planner/sql_evaluator/sql_visitor.rs +++ b/rust/cubesqlplanner/cubesqlplanner/src/planner/sql_evaluator/sql_visitor.rs @@ -14,6 +14,10 @@ pub struct SqlEvaluatorVisitor { cube_ref_evaluator: Rc, all_filters: Option, //To pass to FILTER_PARAMS and FILTER_GROUP ignore_tz_convert: bool, + /// When `true`, the caller (typically a `SqlCall` substitution site) expects + /// the rendered expression to be safe for embedding next to operators — + /// i.e. a compound top-level result should be wrapped in parentheses. + arg_needs_paren_safe: bool, } impl SqlEvaluatorVisitor { @@ -27,6 +31,7 @@ impl SqlEvaluatorVisitor { cube_ref_evaluator, all_filters, ignore_tz_convert: false, + arg_needs_paren_safe: false, } } @@ -36,6 +41,16 @@ impl SqlEvaluatorVisitor { self_copy } + pub fn with_arg_needs_paren_safe(&self, value: bool) -> Self { + let mut self_copy = self.clone(); + self_copy.arg_needs_paren_safe = value; + self_copy + } + + pub fn arg_needs_paren_safe(&self) -> bool { + self.arg_needs_paren_safe + } + pub fn all_filters(&self) -> Option { self.all_filters.clone() } diff --git a/rust/cubesqlplanner/cubesqlplanner/src/test_fixtures/schemas/yaml_files/common/auto_parentheses_tests.yaml b/rust/cubesqlplanner/cubesqlplanner/src/test_fixtures/schemas/yaml_files/common/auto_parentheses_tests.yaml new file mode 100644 index 0000000000000..c7d99fdc45ede --- /dev/null +++ b/rust/cubesqlplanner/cubesqlplanner/src/test_fixtures/schemas/yaml_files/common/auto_parentheses_tests.yaml @@ -0,0 +1,131 @@ +cubes: + - name: expr_cube + sql: "SELECT * FROM expr_cube_table" + dimensions: + - name: id + type: number + sql: id + primary_key: true + + # Atomic single-column dimensions. 
+ - name: atomic + type: number + sql: "{CUBE}.col" + + - name: a + type: number + sql: "{CUBE}.a" + + - name: b + type: number + sql: "{CUBE}.b" + + # Compound: top-level arithmetic. + - name: compound_arith + type: number + sql: "{CUBE}.a + {CUBE}.b" + + # Compound: top-level logical / comparison. + - name: compound_or + type: boolean + sql: "{CUBE}.a > 0 OR {CUBE}.b > 0" + + # Consumer 1: compound in arithmetic context → should WRAP. + - name: arith_over_compound + type: number + sql: "{compound_arith} * 2" + + # Consumer 2: compound in logical context → should WRAP. + - name: and_over_compound + type: boolean + sql: "{compound_or} AND {CUBE}.a > 10" + + # Consumer 3: compound as function argument → NO wrap. + - name: abs_of_compound + type: number + sql: "ABS({compound_arith})" + + # Consumer 4: compound inside CAST → NO wrap. + - name: cast_of_compound + type: number + sql: "CAST({compound_arith} AS INT)" + + # Consumer 5: compound in CASE THEN branch — NO wrap. + - name: case_with_compound + type: number + sql: "CASE WHEN {CUBE}.a > 0 THEN {compound_arith} ELSE 0 END" + + # Consumer 6: direct reference, single placeholder template — NO wrap. + - name: direct_compound + type: number + sql: "{compound_arith}" + + # Consumer 7: atomic dep in arithmetic — NO wrap (dep output is atomic). + - name: arith_over_atomic + type: number + sql: "{atomic} * 2" + + measures: + - name: cnt + type: count + + # Atomic sum of a simple column. + - name: sum_a + type: sum + sql: "{CUBE}.a" + + - name: sum_b + type: sum + sql: "{CUBE}.b" + + # SUM(compound) — aggregate wrap is safe, so the inner chain should + # render without adding parens around the compound expression. + - name: sum_compound + type: sum + sql: "{compound_arith}" + + # SUM with a compound template itself. + - name: sum_of_sum_template + type: sum + sql: "{CUBE}.a + {CUBE}.b" + + # `type: number` — calculated measure, PassThrough aggregate wrap. + # Compound template at top level: no outer context, renders as-is. 
+ - name: calc_number_compound + type: number + sql: "{sum_a} + {sum_b}" + + # Calculated `number` with arithmetic over a compound calc measure — + # the dep's rendered SQL is `sum(a) + sum(b)` which is compound, so + # the outer `* 100` context must wrap it. + - name: calc_number_over_compound + type: number + sql: "{calc_number_compound} * 100" + + # Calculated `boolean` with a logical expression at top level. + - name: calc_boolean_compound + type: boolean + sql: "{sum_a} > 0 OR {sum_b} > 0" + + # Calculated `boolean` in AND context — compound dep must wrap. + - name: calc_boolean_combined + type: boolean + sql: "{calc_boolean_compound} AND {sum_a} > 100" + + # Measure with a `case:` definition over compound branches. + # CASE wrap is safe, so the compound THEN branch does not need parens. + - name: measure_case + type: number + case: + when: + - sql: "{CUBE}.a > 0" + label: positive + else: + label: other + + # SUM over an arithmetic template that references a compound number + # measure — the outer SUM is a safe wrap, so no parens are needed + # around the compound child despite the enclosing `+ 1`. + - name: sum_over_calc_number_plus_one + type: sum + sql: "{calc_number_compound} + 1" diff --git a/rust/cubesqlplanner/cubesqlplanner/src/tests/auto_parentheses.rs b/rust/cubesqlplanner/cubesqlplanner/src/tests/auto_parentheses.rs new file mode 100644 index 0000000000000..703869adb4899 --- /dev/null +++ b/rust/cubesqlplanner/cubesqlplanner/src/tests/auto_parentheses.rs @@ -0,0 +1,178 @@ +//! Integration tests for automatic parenthesization of SqlCall arg substitutions. +//! +//! Each test renders a dimension or measure via the full processor chain and +//! inspects the produced SQL for the presence or absence of parentheses +//! around the substituted expression. 
+ +use crate::test_fixtures::cube_bridge::MockSchema; +use crate::test_fixtures::test_utils::TestContext; + +fn ctx() -> TestContext { + let schema = MockSchema::from_yaml_file("common/auto_parentheses_tests.yaml"); + TestContext::new(schema).unwrap() +} + +fn dimension_sql(ctx: &TestContext, path: &str) -> String { + let sym = ctx.create_dimension(path).unwrap(); + ctx.evaluate_symbol(&sym).unwrap() +} + +fn measure_sql(ctx: &TestContext, path: &str) -> String { + let sym = ctx.create_measure(path).unwrap(); + ctx.evaluate_symbol(&sym).unwrap() +} + +#[test] +fn atomic_dimension_is_never_wrapped() { + let ctx = ctx(); + let sql = dimension_sql(&ctx, "expr_cube.atomic"); + assert_eq!(sql, "\"expr_cube\".col"); +} + +#[test] +fn compound_dimension_rendered_standalone_is_not_wrapped() { + let ctx = ctx(); + // Top-level `evaluate_symbol` call sets no paren-safe expectation, so the + // outermost compound expression is rendered as-is. + let sql = dimension_sql(&ctx, "expr_cube.compound_arith"); + assert_eq!(sql, "\"expr_cube\".a + \"expr_cube\".b"); +} + +#[test] +fn compound_in_arithmetic_context_is_wrapped() { + let ctx = ctx(); + let sql = dimension_sql(&ctx, "expr_cube.arith_over_compound"); + assert_eq!(sql, "(\"expr_cube\".a + \"expr_cube\".b) * 2"); +} + +#[test] +fn compound_in_logical_context_is_wrapped() { + let ctx = ctx(); + let sql = dimension_sql(&ctx, "expr_cube.and_over_compound"); + assert_eq!( + sql, + "(\"expr_cube\".a > 0 OR \"expr_cube\".b > 0) AND \"expr_cube\".a > 10" + ); +} + +#[test] +fn compound_in_function_arg_is_not_wrapped() { + let ctx = ctx(); + let sql = dimension_sql(&ctx, "expr_cube.abs_of_compound"); + assert_eq!(sql, "ABS(\"expr_cube\".a + \"expr_cube\".b)"); +} + +#[test] +fn compound_in_cast_is_not_wrapped() { + let ctx = ctx(); + let sql = dimension_sql(&ctx, "expr_cube.cast_of_compound"); + assert_eq!(sql, "CAST(\"expr_cube\".a + \"expr_cube\".b AS INT)"); +} + +#[test] +fn compound_in_case_branch_is_not_wrapped() { + let ctx = 
ctx(); + let sql = dimension_sql(&ctx, "expr_cube.case_with_compound"); + assert_eq!( + sql, + "CASE WHEN \"expr_cube\".a > 0 THEN \"expr_cube\".a + \"expr_cube\".b ELSE 0 END" + ); +} + +#[test] +fn direct_reference_does_not_add_parens() { + let ctx = ctx(); + // Template is a single `{arg:0}` placeholder — no surrounding operators. + let sql = dimension_sql(&ctx, "expr_cube.direct_compound"); + assert_eq!(sql, "\"expr_cube\".a + \"expr_cube\".b"); +} + +#[test] +fn atomic_dep_in_arithmetic_stays_unwrapped() { + let ctx = ctx(); + let sql = dimension_sql(&ctx, "expr_cube.arith_over_atomic"); + assert_eq!(sql, "\"expr_cube\".col * 2"); +} + +#[test] +fn aggregate_wrap_resets_flag_for_inner_compound() { + let ctx = ctx(); + // `SUM(…)` already provides a safe context, so the compound expression + // inside is not additionally parenthesized. + let sql = measure_sql(&ctx, "expr_cube.sum_compound"); + assert_eq!(sql, "sum(\"expr_cube\".a + \"expr_cube\".b)"); +} + +#[test] +fn sum_with_compound_template_no_inner_wrap() { + let ctx = ctx(); + // The template itself is compound; the SUM wrap resets the child flag + // so the arithmetic inside stays flat. + let sql = measure_sql(&ctx, "expr_cube.sum_of_sum_template"); + assert_eq!(sql, "sum(\"expr_cube\".a + \"expr_cube\".b)"); +} + +#[test] +fn calculated_number_compound_top_level_is_not_wrapped() { + let ctx = ctx(); + // `type: number` is a passthrough aggregate wrap. Rendered at top level + // (no caller context), the compound expression comes out as-is. + let sql = measure_sql(&ctx, "expr_cube.calc_number_compound"); + assert_eq!(sql, "sum(\"expr_cube\".a) + sum(\"expr_cube\".b)"); +} + +#[test] +fn calculated_number_over_compound_calc_wraps_dep() { + let ctx = ctx(); + // `{calc_number_compound} * 100` — the dep renders to a compound + // expression and must be wrapped before being embedded next to `*`. 
+ let sql = measure_sql(&ctx, "expr_cube.calc_number_over_compound"); + assert_eq!(sql, "(sum(\"expr_cube\".a) + sum(\"expr_cube\".b)) * 100"); +} + +#[test] +fn calculated_boolean_compound_top_level_is_not_wrapped() { + let ctx = ctx(); + let sql = measure_sql(&ctx, "expr_cube.calc_boolean_compound"); + assert_eq!(sql, "sum(\"expr_cube\".a) > 0 OR sum(\"expr_cube\".b) > 0"); +} + +#[test] +fn calculated_boolean_combined_wraps_compound_dep() { + let ctx = ctx(); + let sql = measure_sql(&ctx, "expr_cube.calc_boolean_combined"); + assert_eq!( + sql, + "(sum(\"expr_cube\".a) > 0 OR sum(\"expr_cube\".b) > 0) AND sum(\"expr_cube\".a) > 100" + ); +} + +#[test] +fn measure_with_case_definition_renders_safely() { + let ctx = ctx(); + // `CaseSqlNode` wraps the whole result in `CASE … END`. No substituted + // deps here, but we verify the node's reset path doesn't break anything. + let sql = measure_sql(&ctx, "expr_cube.measure_case"); + assert_eq!( + sql, + "CASE WHEN \"expr_cube\".a > 0 THEN 'positive' ELSE 'other' END" + ); +} + +#[test] +fn sum_over_compound_plus_one_inner_stays_wrapped_outer_not() { + let ctx = ctx(); + // Two effects stack: + // - The inner SqlCall template `{calc_number_compound} + 1` marks the + // placeholder as unsafe (adjacent to `+`), so the compound dep is + // wrapped: `(sum(a) + sum(b)) + 1`. + // - The outer `SUM(…)` is then applied by FinalMeasure and provides its + // own safe wrap, so no further parens are added around the `+ 1`. + // (The inner `(…)` is over-parenthesized for `+` due to conservative + // handling — acceptable trade-off, see scanner design notes.) 
+ let sql = measure_sql(&ctx, "expr_cube.sum_over_calc_number_plus_one"); + assert_eq!( + sql, + "sum((sum(\"expr_cube\".a) + sum(\"expr_cube\".b)) + 1)" + ); +} diff --git a/rust/cubesqlplanner/cubesqlplanner/src/tests/mod.rs b/rust/cubesqlplanner/cubesqlplanner/src/tests/mod.rs index 588c2a52808a6..984209ba6191e 100644 --- a/rust/cubesqlplanner/cubesqlplanner/src/tests/mod.rs +++ b/rust/cubesqlplanner/cubesqlplanner/src/tests/mod.rs @@ -1,3 +1,4 @@ +mod auto_parentheses; mod common_sql_generation; mod compiled_member_path; mod cube_evaluator; diff --git a/rust/cubesqlplanner/cubesqlplanner/src/utils/mod.rs b/rust/cubesqlplanner/cubesqlplanner/src/utils/mod.rs index 647ec00d15068..6e3604fbcf382 100644 --- a/rust/cubesqlplanner/cubesqlplanner/src/utils/mod.rs +++ b/rust/cubesqlplanner/cubesqlplanner/src/utils/mod.rs @@ -1,4 +1,5 @@ pub mod debug; +pub mod sql_expression_scanner; mod unique_vector; pub use unique_vector::UniqueVector; diff --git a/rust/cubesqlplanner/cubesqlplanner/src/utils/sql_expression_scanner.rs b/rust/cubesqlplanner/cubesqlplanner/src/utils/sql_expression_scanner.rs new file mode 100644 index 0000000000000..e66ce6e319d8c --- /dev/null +++ b/rust/cubesqlplanner/cubesqlplanner/src/utils/sql_expression_scanner.rs @@ -0,0 +1,946 @@ +//! Minimal dialect-agnostic SQL scanner used to decide whether a SQL expression +//! needs to be wrapped in parentheses when substituted into another expression. +//! +//! Two public entry points: +//! - [`is_top_level_compound`] classifies a rendered SQL string as either atomic +//! (safe to inline as-is) or compound (has a top-level operator and needs +//! parentheses in arithmetic/logical contexts). +//! - [`analyze_template_arg_contexts`] analyses a `SqlCall` template and, per +//! `{arg:N}` placeholder, reports whether the surrounding context would make +//! a compound substitution unsafe. +//! +//! The scanner is intentionally not a full SQL parser. It tokenizes enough to +//! 
respect strings, comments, brackets and Cube placeholders, then decides +//! atomicity by a positive-list rule: an expression is atomic iff its top-level +//! token stream contains no operator, no operator-keyword and no top-level +//! comma outside of `CASE ... END`. + +use std::collections::HashMap; + +// ---------- Tokenizer ---------- + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum PlaceholderKind { + Arg, + FilterParam, + FilterGroup, + SecurityValue, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum TokenKind { + Word, + QuotedIdent, + Number, + StringLit, + Open(char), + Close(char), + Comma, + Dot, + Semicolon, + CastOp, + Operator, + Placeholder { kind: PlaceholderKind, index: usize }, + OpaqueBraces, + Unknown, +} + +#[derive(Debug, Clone)] +struct Token<'a> { + kind: TokenKind, + text: &'a str, + depth: usize, +} + +struct Tokenizer<'a> { + src: &'a str, + bytes: &'a [u8], + pos: usize, + depth: usize, +} + +impl<'a> Tokenizer<'a> { + fn new(src: &'a str) -> Self { + Self { + src, + bytes: src.as_bytes(), + pos: 0, + depth: 0, + } + } + + fn peek(&self, offset: usize) -> Option { + self.bytes.get(self.pos + offset).copied() + } + + fn at_eof(&self) -> bool { + self.pos >= self.bytes.len() + } + + fn skip_trivia(&mut self) { + loop { + if self.at_eof() { + return; + } + match self.peek(0).unwrap() { + b' ' | b'\t' | b'\n' | b'\r' => { + self.pos += 1; + } + b'-' if self.peek(1) == Some(b'-') => { + self.pos += 2; + while !self.at_eof() && self.peek(0) != Some(b'\n') { + self.pos += 1; + } + } + b'/' if self.peek(1) == Some(b'/') => { + // Line comment variant in BigQuery and Snowflake. 
+ self.pos += 2; + while !self.at_eof() && self.peek(0) != Some(b'\n') { + self.pos += 1; + } + } + b'/' if self.peek(1) == Some(b'*') => { + self.pos += 2; + let mut nested = 1usize; + while !self.at_eof() { + if self.peek(0) == Some(b'/') && self.peek(1) == Some(b'*') { + self.pos += 2; + nested += 1; + } else if self.peek(0) == Some(b'*') && self.peek(1) == Some(b'/') { + self.pos += 2; + nested -= 1; + if nested == 0 { + break; + } + } else { + self.pos += 1; + } + } + } + _ => return, + } + } + } + + fn next_token(&mut self) -> Option> { + self.skip_trivia(); + if self.at_eof() { + return None; + } + let offset = self.pos; + let b = self.peek(0).unwrap(); + + // String-like prefixes: N'...', E'...', B'...', R'...', X'..., also lowercased. + if matches!( + b, + b'N' | b'n' | b'E' | b'e' | b'B' | b'b' | b'R' | b'r' | b'X' | b'x' + ) && self.peek(1) == Some(b'\'') + { + self.pos += 1; + return Some(self.read_quoted(offset, b'\'', true, TokenKind::StringLit)); + } + + match b { + b'\'' => return Some(self.read_quoted(offset, b'\'', true, TokenKind::StringLit)), + b'"' => return Some(self.read_quoted(offset, b'"', true, TokenKind::QuotedIdent)), + b'`' => return Some(self.read_quoted(offset, b'`', false, TokenKind::QuotedIdent)), + b'$' => { + if let Some(tok) = self.try_read_dollar_quoted(offset) { + return Some(tok); + } + // fall through to operator handling + } + _ => {} + } + + // Brackets + if b == b'(' { + self.pos += 1; + self.depth += 1; + return Some(Token { + kind: TokenKind::Open('('), + text: &self.src[offset..self.pos], + depth: self.depth - 1, + }); + } + if b == b')' { + self.pos += 1; + if self.depth > 0 { + self.depth -= 1; + } + return Some(Token { + kind: TokenKind::Close(')'), + text: &self.src[offset..self.pos], + depth: self.depth, + }); + } + if b == b'[' { + self.pos += 1; + self.depth += 1; + return Some(Token { + kind: TokenKind::Open('['), + text: &self.src[offset..self.pos], + depth: self.depth - 1, + }); + } + if b == b']' { + self.pos 
+= 1; + if self.depth > 0 { + self.depth -= 1; + } + return Some(Token { + kind: TokenKind::Close(']'), + text: &self.src[offset..self.pos], + depth: self.depth, + }); + } + if b == b'{' { + if let Some(tok) = self.try_read_placeholder(offset) { + return Some(tok); + } + return Some(self.read_opaque_braces(offset)); + } + if b == b'}' { + self.pos += 1; + return Some(Token { + kind: TokenKind::Unknown, + text: &self.src[offset..self.pos], + depth: self.depth, + }); + } + + if b == b',' { + self.pos += 1; + return Some(Token { + kind: TokenKind::Comma, + text: &self.src[offset..self.pos], + depth: self.depth, + }); + } + if b == b';' { + self.pos += 1; + return Some(Token { + kind: TokenKind::Semicolon, + text: &self.src[offset..self.pos], + depth: self.depth, + }); + } + if b == b'.' { + self.pos += 1; + return Some(Token { + kind: TokenKind::Dot, + text: &self.src[offset..self.pos], + depth: self.depth, + }); + } + + if b == b':' && self.peek(1) == Some(b':') { + self.pos += 2; + return Some(Token { + kind: TokenKind::CastOp, + text: &self.src[offset..self.pos], + depth: self.depth, + }); + } + + if b.is_ascii_digit() { + return Some(self.read_number(offset)); + } + + if is_ident_start(b) { + return Some(self.read_word(offset)); + } + + if is_operator_byte(b) { + while !self.at_eof() && is_operator_byte(self.peek(0).unwrap()) { + self.pos += 1; + } + return Some(Token { + kind: TokenKind::Operator, + text: &self.src[offset..self.pos], + depth: self.depth, + }); + } + + // Anything else — consume one byte (or UTF-8 char) as Unknown so we keep progressing. + let char_len = utf8_char_len(self.bytes, self.pos); + self.pos += char_len; + Some(Token { + kind: TokenKind::Unknown, + text: &self.src[offset..self.pos], + depth: self.depth, + }) + } + + /// Reads a quoted run delimited by `quote`. Supports the doubled-quote escape + /// (`''`, `""`, `` `` ``) and backslash escapes. 
When `allow_triple` is set + /// and the opener is tripled (`'''`/`"""`), reads until the matching triple. + fn read_quoted( + &mut self, + offset: usize, + quote: u8, + allow_triple: bool, + kind: TokenKind, + ) -> Token<'a> { + if allow_triple + && self.peek(0) == Some(quote) + && self.peek(1) == Some(quote) + && self.peek(2) == Some(quote) + { + self.pos += 3; + while !self.at_eof() { + if self.peek(0) == Some(quote) + && self.peek(1) == Some(quote) + && self.peek(2) == Some(quote) + { + self.pos += 3; + break; + } + self.pos += 1; + } + return Token { + kind, + text: &self.src[offset..self.pos], + depth: self.depth, + }; + } + self.pos += 1; + while !self.at_eof() { + let c = self.peek(0).unwrap(); + if c == b'\\' { + self.pos = (self.pos + 2).min(self.bytes.len()); + continue; + } + if c == quote { + if self.peek(1) == Some(quote) { + self.pos += 2; + continue; + } + self.pos += 1; + break; + } + self.pos += 1; + } + Token { + kind, + text: &self.src[offset..self.pos], + depth: self.depth, + } + } + + fn try_read_dollar_quoted(&mut self, offset: usize) -> Option> { + // $[tag]$...$[tag]$ where tag is optional alnum/_ run. 
+ let mut tag_end = self.pos + 1; + while tag_end < self.bytes.len() + && (self.bytes[tag_end] == b'_' || self.bytes[tag_end].is_ascii_alphanumeric()) + { + tag_end += 1; + } + if tag_end >= self.bytes.len() || self.bytes[tag_end] != b'$' { + return None; + } + let delim_len = tag_end - self.pos + 1; + let delim = self.src[self.pos..self.pos + delim_len].to_string(); + self.pos += delim_len; + while self.pos + delim_len <= self.bytes.len() { + if self.src[self.pos..self.pos + delim_len] == delim { + self.pos += delim_len; + return Some(Token { + kind: TokenKind::StringLit, + text: &self.src[offset..self.pos], + depth: self.depth, + }); + } + self.pos += 1; + } + self.pos = self.bytes.len(); + Some(Token { + kind: TokenKind::StringLit, + text: &self.src[offset..self.pos], + depth: self.depth, + }) + } + + fn try_read_placeholder(&mut self, offset: usize) -> Option> { + let mut end = self.pos + 1; + while end < self.bytes.len() && self.bytes[end] != b'}' { + end += 1; + } + if end >= self.bytes.len() { + return None; + } + let inner = &self.src[self.pos + 1..end]; + let (prefix, idx_str) = inner.split_once(':')?; + let idx: usize = idx_str.parse().ok()?; + let kind = match prefix { + "arg" => PlaceholderKind::Arg, + "fp" => PlaceholderKind::FilterParam, + "fg" => PlaceholderKind::FilterGroup, + "sv" => PlaceholderKind::SecurityValue, + _ => return None, + }; + self.pos = end + 1; + Some(Token { + kind: TokenKind::Placeholder { kind, index: idx }, + text: &self.src[offset..self.pos], + depth: self.depth, + }) + } + + fn read_opaque_braces(&mut self, offset: usize) -> Token<'a> { + self.pos += 1; + let mut nested = 1usize; + while self.pos < self.bytes.len() && nested > 0 { + match self.bytes[self.pos] { + b'{' => nested += 1, + b'}' => nested -= 1, + _ => {} + } + self.pos += 1; + } + Token { + kind: TokenKind::OpaqueBraces, + text: &self.src[offset..self.pos], + depth: self.depth, + } + } + + fn read_number(&mut self, offset: usize) -> Token<'a> { + while 
!self.at_eof() { + let c = self.peek(0).unwrap(); + if c.is_ascii_digit() || c == b'.' || c == b'_' { + self.pos += 1; + continue; + } + if (c == b'e' || c == b'E') + && matches!(self.peek(1), Some(b'+') | Some(b'-') | Some(b'0'..=b'9')) + { + self.pos += 1; + if matches!(self.peek(0), Some(b'+') | Some(b'-')) { + self.pos += 1; + } + while !self.at_eof() && self.peek(0).unwrap().is_ascii_digit() { + self.pos += 1; + } + break; + } + break; + } + Token { + kind: TokenKind::Number, + text: &self.src[offset..self.pos], + depth: self.depth, + } + } + + fn read_word(&mut self, offset: usize) -> Token<'a> { + while !self.at_eof() && is_ident_cont(self.peek(0).unwrap()) { + self.pos += 1; + } + Token { + kind: TokenKind::Word, + text: &self.src[offset..self.pos], + depth: self.depth, + } + } +} + +fn is_ident_start(b: u8) -> bool { + b == b'_' || b.is_ascii_alphabetic() || b >= 0x80 +} + +fn is_ident_cont(b: u8) -> bool { + b == b'_' || b.is_ascii_alphanumeric() || b >= 0x80 +} + +fn is_operator_byte(b: u8) -> bool { + matches!( + b, + b'+' | b'-' + | b'*' + | b'/' + | b'%' + | b'=' + | b'<' + | b'>' + | b'!' + | b'|' + | b'&' + | b'^' + | b'~' + | b'?' 
+ | b'@' + | b'#' + | b':' + ) +} + +fn utf8_char_len(bytes: &[u8], pos: usize) -> usize { + let b = bytes[pos]; + if b < 0x80 { + 1 + } else if b < 0xC0 { + 1 + } else if b < 0xE0 { + 2 + } else if b < 0xF0 { + 3 + } else { + 4 + } + .min(bytes.len() - pos) +} + +fn matches_any_keyword(word: &str, keywords: &[&str]) -> bool { + keywords.iter().any(|kw| word.eq_ignore_ascii_case(kw)) +} + +fn is_operator_keyword(word: &str) -> bool { + const KEYWORDS: &[&str] = &[ + "AND", "OR", "NOT", "IS", "LIKE", "ILIKE", "RLIKE", "BETWEEN", "IN", "SIMILAR", "OVERLAPS", + "ESCAPE", "ANY", "ALL", "SOME", "COLLATE", + ]; + matches_any_keyword(word, KEYWORDS) +} + +fn is_case_start(word: &str) -> bool { + word.eq_ignore_ascii_case("CASE") +} + +fn is_case_end(word: &str) -> bool { + word.eq_ignore_ascii_case("END") +} + +fn is_case_keyword(word: &str) -> bool { + const KEYWORDS: &[&str] = &["WHEN", "THEN", "ELSE", "CASE", "END"]; + matches_any_keyword(word, KEYWORDS) +} + +fn tokenize_all(src: &str) -> Vec> { + let mut tokenizer = Tokenizer::new(src); + let mut out = Vec::new(); + while let Some(t) = tokenizer.next_token() { + out.push(t); + } + out +} + +// ---------- Classifier: render-time atomicity ---------- + +/// Returns `true` if `sql` has a top-level operator (or operator-keyword) and +/// therefore needs parentheses when embedded in an operator context. +/// Atomic forms — identifier, literal, function call (optionally with +/// `OVER/FILTER/WITHIN GROUP/IGNORE NULLS/RESPECT NULLS` suffixes), +/// `CAST/EXTRACT/CASE` constructs, or an already-parenthesized expression — +/// return `false`. 
+pub fn is_top_level_compound(sql: &str) -> bool { + let mut case_depth: usize = 0; + let mut prev_significant: Option = None; + for tok in tokenize_all(sql) { + if tok.depth != 0 { + continue; + } + if let TokenKind::Word = tok.kind { + if is_case_start(tok.text) { + case_depth += 1; + prev_significant = Some(tok.kind); + continue; + } + if is_case_end(tok.text) && case_depth > 0 { + case_depth -= 1; + prev_significant = Some(tok.kind); + continue; + } + } + if case_depth > 0 { + prev_significant = Some(tok.kind); + continue; + } + match &tok.kind { + TokenKind::Operator => return true, + TokenKind::Comma | TokenKind::Semicolon => return true, + TokenKind::Word => { + if is_operator_keyword(tok.text) { + // Avoid treating "xxx.in" style column refs as operator. + if !matches!(prev_significant, Some(TokenKind::Dot)) { + return true; + } + } + } + _ => {} + } + prev_significant = Some(tok.kind); + } + false +} + +// ---------- Template analyzer: compile-time placeholder contexts ---------- + +/// Analyses an `SqlCall` template and returns, for each `{arg:N}` index present, +/// whether the surrounding context would require a compound substitution to be +/// wrapped in parentheses (`true`) or allow raw inlining (`false`). +/// +/// Indices absent from the returned map were not referenced by any placeholder +/// in the template; the caller should treat them as safe by default. 
+pub fn analyze_template_arg_contexts(template: &str) -> HashMap { + let tokens = tokenize_all(template); + let mut result: HashMap = HashMap::new(); + + for (i, tok) in tokens.iter().enumerate() { + let idx = match &tok.kind { + TokenKind::Placeholder { + kind: PlaceholderKind::Arg, + index, + } => *index, + _ => continue, + }; + let unsafe_here = is_placeholder_context_unsafe(&tokens, i); + let entry = result.entry(idx).or_insert(false); + *entry = *entry || unsafe_here; + } + result +} + +fn is_placeholder_context_unsafe(tokens: &[Token<'_>], idx: usize) -> bool { + let placeholder_depth = tokens[idx].depth; + + let scan_start = match find_left_boundary(tokens, idx, placeholder_depth) { + Some(lb) => lb + 1, + None => 0, + }; + let scan_end = find_right_boundary(tokens, idx, placeholder_depth).unwrap_or(tokens.len()); + + let mut case_depth: usize = 0; + for (i, tok) in tokens + .iter() + .enumerate() + .skip(scan_start) + .take(scan_end - scan_start) + { + if i == idx { + continue; + } + if tok.depth != placeholder_depth { + continue; + } + if let TokenKind::Word = tok.kind { + if is_case_start(tok.text) { + case_depth += 1; + continue; + } + if is_case_end(tok.text) && case_depth > 0 { + case_depth -= 1; + continue; + } + } + if case_depth > 0 { + // Treat CASE keywords at this depth as boundaries; nothing inside a + // sibling CASE branch can affect this placeholder. + continue; + } + match &tok.kind { + TokenKind::Operator => return true, + TokenKind::Word => { + if is_operator_keyword(tok.text) { + return true; + } + } + _ => {} + } + } + false +} + +/// Returns the largest index `j < idx` that acts as a left boundary for the +/// placeholder's scope: an `Open` at `depth - 1`, or a `Comma/Semicolon/CASE` +/// keyword at `depth`. `None` means "scan from start of input". 
+fn find_left_boundary(tokens: &[Token<'_>], idx: usize, depth: usize) -> Option { + let mut i = idx; + while i > 0 { + i -= 1; + let t = &tokens[i]; + if t.depth < depth { + return Some(i); + } + if t.depth == depth { + match &t.kind { + TokenKind::Comma | TokenKind::Semicolon => return Some(i), + TokenKind::Word if is_case_keyword(t.text) => return Some(i), + _ => {} + } + } + if let TokenKind::Open(_) = t.kind { + if depth > 0 && t.depth == depth - 1 { + return Some(i); + } + } + } + None +} + +/// Mirror of [`find_left_boundary`]. `None` means "scan to end of input". +fn find_right_boundary(tokens: &[Token<'_>], idx: usize, depth: usize) -> Option { + let mut i = idx + 1; + while i < tokens.len() { + let t = &tokens[i]; + if t.depth < depth { + return Some(i); + } + if t.depth == depth { + match &t.kind { + TokenKind::Comma | TokenKind::Semicolon => return Some(i), + TokenKind::Word if is_case_keyword(t.text) => return Some(i), + _ => {} + } + } + if let TokenKind::Close(_) = t.kind { + if depth > 0 && t.depth == depth - 1 { + return Some(i); + } + } + i += 1; + } + None +} + +// ---------- Tests ---------- + +#[cfg(test)] +mod tests { + use super::*; + + // ----- is_top_level_compound ----- + + #[test] + fn atomic_simple_identifier() { + assert!(!is_top_level_compound("a")); + assert!(!is_top_level_compound("users.id")); + assert!(!is_top_level_compound("schema.table.col")); + } + + #[test] + fn atomic_literals() { + assert!(!is_top_level_compound("1")); + assert!(!is_top_level_compound("1.5")); + assert!(!is_top_level_compound("'hello'")); + assert!(!is_top_level_compound("NULL")); + assert!(!is_top_level_compound("TRUE")); + assert!(!is_top_level_compound("DATE '2020-01-01'")); + } + + #[test] + fn atomic_function_call() { + assert!(!is_top_level_compound("COUNT(*)")); + assert!(!is_top_level_compound("COALESCE(a, b)")); + assert!(!is_top_level_compound("MAX(a + b)")); + assert!(!is_top_level_compound("FN(a OR b, c AND d)")); + 
assert!(!is_top_level_compound("schema.fn(a)")); + } + + #[test] + fn atomic_window_function() { + assert!(!is_top_level_compound( + "ROW_NUMBER() OVER (PARTITION BY x ORDER BY y)" + )); + assert!(!is_top_level_compound("COUNT(*) FILTER (WHERE x > 0)")); + assert!(!is_top_level_compound( + "PERCENTILE_CONT(0.5) WITHIN GROUP (ORDER BY x)" + )); + assert!(!is_top_level_compound( + "LAST_VALUE(x IGNORE NULLS) OVER (ORDER BY y)" + )); + } + + #[test] + fn atomic_cast() { + assert!(!is_top_level_compound("CAST(x AS INT)")); + assert!(!is_top_level_compound("EXTRACT(YEAR FROM ts)")); + assert!(!is_top_level_compound("x::int")); + assert!(!is_top_level_compound("x::int::text")); + } + + #[test] + fn atomic_case() { + // Searched form. + assert!(!is_top_level_compound( + "CASE WHEN x = 1 THEN 'a' ELSE 'b' END" + )); + assert!(!is_top_level_compound( + "CASE WHEN x IS NULL THEN 0 ELSE x + 1 END" + )); + // Simple form (expression after CASE). + assert!(!is_top_level_compound( + "CASE status WHEN 'active' THEN 1 WHEN 'inactive' THEN 0 ELSE -1 END" + )); + assert!(!is_top_level_compound("CASE x + 1 WHEN 2 THEN 'a' END")); + } + + #[test] + fn atomic_parenthesized() { + assert!(!is_top_level_compound("(a + b)")); + assert!(!is_top_level_compound("(a OR b)")); + } + + #[test] + fn atomic_array_and_tuple_literal() { + assert!(!is_top_level_compound("[1, 2, 3]")); + assert!(!is_top_level_compound("ARRAY[1, 2, 3]")); + } + + #[test] + fn compound_arithmetic() { + assert!(is_top_level_compound("a + b")); + assert!(is_top_level_compound("a - b")); + assert!(is_top_level_compound("a * b")); + assert!(is_top_level_compound("a / b + c")); + } + + #[test] + fn compound_logical() { + assert!(is_top_level_compound("a AND b")); + assert!(is_top_level_compound("a OR b")); + assert!(is_top_level_compound("NOT x")); + assert!(is_top_level_compound("x IS NULL")); + assert!(is_top_level_compound("x BETWEEN 1 AND 10")); + assert!(is_top_level_compound("x LIKE '%foo%'")); + 
assert!(is_top_level_compound("x IN (1, 2, 3)")); + } + + #[test] + fn compound_comparison() { + assert!(is_top_level_compound("a = b")); + assert!(is_top_level_compound("a < b")); + assert!(is_top_level_compound("a >= b")); + assert!(is_top_level_compound("a <> b")); + } + + #[test] + fn compound_string_concat() { + assert!(is_top_level_compound("'a' || 'b'")); + } + + #[test] + fn strings_protect_contents() { + assert!(!is_top_level_compound("'a OR b'")); + assert!(!is_top_level_compound("'a + b'")); + } + + #[test] + fn comments_protect_contents() { + assert!(!is_top_level_compound("users.id -- a + b")); + assert!(!is_top_level_compound("users.id // a + b")); + assert!(!is_top_level_compound("users.id /* a + b */")); + assert!(!is_top_level_compound("users./* comment */id")); + } + + #[test] + fn dollar_quoted_strings() { + assert!(!is_top_level_compound("$$a + b$$")); + assert!(!is_top_level_compound("$tag$a + b$tag$")); + } + + #[test] + fn triple_quoted_strings() { + assert!(!is_top_level_compound("'''a + b'''")); + assert!(!is_top_level_compound("\"\"\"a + b\"\"\"")); + } + + #[test] + fn mssql_bracket_identifier() { + assert!(!is_top_level_compound("[my col]")); + assert!(!is_top_level_compound("[a+b]")); + } + + #[test] + fn clickhouse_opaque_braces() { + // {name:Type} is a CH parameter, should be treated as opaque atom. 
+ assert!(!is_top_level_compound("{user_id:Int64}")); + } + + #[test] + fn nested_case_is_atomic() { + assert!(!is_top_level_compound( + "CASE WHEN CASE WHEN x = 1 THEN y = 2 ELSE y = 3 END THEN 'a' ELSE 'b' END" + )); + } + + // ----- analyze_template_arg_contexts ----- + + fn is_unsafe(template: &str, arg: usize) -> bool { + let m = analyze_template_arg_contexts(template); + *m.get(&arg).unwrap_or(&false) + } + + #[test] + fn direct_reference_is_safe() { + assert!(!is_unsafe("{arg:0}", 0)); + } + + #[test] + fn top_level_arithmetic_is_unsafe() { + assert!(is_unsafe("{arg:0} + 1", 0)); + assert!(is_unsafe("1 + {arg:0}", 0)); + assert!(is_unsafe("{arg:0} * {arg:1}", 0)); + assert!(is_unsafe("{arg:0} * {arg:1}", 1)); + } + + #[test] + fn top_level_logical_is_unsafe() { + assert!(is_unsafe("{arg:0} AND x", 0)); + assert!(is_unsafe("{arg:0} OR {arg:1}", 0)); + assert!(is_unsafe("NOT {arg:0}", 0)); + assert!(is_unsafe("{arg:0} IS NULL", 0)); + assert!(is_unsafe("{arg:0} BETWEEN 1 AND 10", 0)); + } + + #[test] + fn function_arg_is_safe() { + assert!(!is_unsafe("FN({arg:0})", 0)); + assert!(!is_unsafe("FN({arg:0}, x)", 0)); + assert!(!is_unsafe("FN(x, {arg:0})", 0)); + assert!(!is_unsafe("COALESCE({arg:0}, {arg:1}, 0)", 0)); + assert!(!is_unsafe("COALESCE({arg:0}, {arg:1}, 0)", 1)); + } + + #[test] + fn cast_arg_is_safe() { + assert!(!is_unsafe("CAST({arg:0} AS INT)", 0)); + assert!(!is_unsafe("EXTRACT(YEAR FROM {arg:0})", 0)); + } + + #[test] + fn function_with_inner_operator_is_unsafe() { + assert!(is_unsafe("FN({arg:0} + 1)", 0)); + assert!(is_unsafe("FN(x, {arg:0} OR y)", 0)); + } + + #[test] + fn join_equality_template() { + // {arg:0} = {arg:1} — classic join condition. + assert!(is_unsafe("{arg:0} = {arg:1}", 0)); + assert!(is_unsafe("{arg:0} = {arg:1}", 1)); + } + + #[test] + fn case_branch_scoping() { + // Inside a CASE branch, sibling branches should not affect scoping. + // `{arg:0}` sits in the THEN branch; the `=` is in a sibling WHEN branch. 
+ assert!(!is_unsafe("CASE WHEN y = 1 THEN {arg:0} ELSE 0 END", 0)); + // But if the placeholder is inside a branch with a top-level operator, + // that branch still produces compound context. + assert!(is_unsafe("CASE WHEN y = 1 THEN {arg:0} + 1 ELSE 0 END", 0)); + } + + #[test] + fn string_literals_hide_placeholders_logic() { + // Placeholder here is inside a string literal — tokenizer swallows it. + // The non-string template still reports correctly. + assert!(is_unsafe("'{arg:0}' + {arg:0}", 0)); + } + + #[test] + fn multiple_occurrences_merge_with_or() { + // One occurrence safe, another unsafe — overall must be unsafe. + assert!(is_unsafe("FN({arg:0}) + {arg:0}", 0)); + } +}