Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use crate::planner::sql_evaluator::sql_nodes::SqlNodesFactory;
use crate::planner::sql_evaluator::{CubeNameSymbol, CubeTableSymbol};
use crate::planner::sql_templates::PlanSqlTemplates;
use crate::planner::VisitorContext;
use crate::utils::sql_expression_scanner::analyze_template_arg_contexts;
use cubenativeutils::CubeError;
use itertools::Itertools;
use std::collections::HashMap;
Expand Down Expand Up @@ -117,6 +118,10 @@ pub struct SqlCall {
filter_params: Vec<SqlCallFilterParamsItem>,
filter_groups: Vec<SqlCallFilterGroupItem>,
security_context: SecutityContextProps,
/// Per `{arg:N}` index: whether the surrounding context in the template
/// would make a compound substitution unsafe (requiring parentheses).
/// Computed once at construction from the template.
arg_paren_contexts: HashMap<usize, bool>,
}

impl SqlCall {
Expand All @@ -127,12 +132,26 @@ impl SqlCall {
filter_groups: Vec<SqlCallFilterGroupItem>,
security_context: SecutityContextProps,
) -> Self {
let arg_paren_contexts = match &template {
SqlTemplate::String(s) => analyze_template_arg_contexts(s),
SqlTemplate::StringVec(strings) => {
let mut merged: HashMap<usize, bool> = HashMap::new();
for s in strings {
for (idx, needs_safe) in analyze_template_arg_contexts(s) {
let entry = merged.entry(idx).or_insert(false);
*entry = *entry || needs_safe;
}
}
merged
}
};
Self {
template,
deps,
filter_params,
filter_groups,
security_context,
arg_paren_contexts,
}
}

Expand Down Expand Up @@ -254,10 +273,22 @@ impl SqlCall {
let deps = self
.deps
.iter()
.map(|dep| match dep {
SqlDependency::Symbol(m) => visitor.apply(m, node_processor.clone(), templates),
SqlDependency::CubeRef(cr) => {
visitor.evaluate_cube_ref(cr, node_processor.clone(), templates)
.enumerate()
.map(|(i, dep)| {
// Each arg's `arg_needs_paren_safe` flag is set by this call's
// template context, overriding whatever the caller's visitor
// carried. The caller's flag only governs wrapping of this
// whole SqlCall's output, handled by an enclosing Parenthesize
// node up the processor chain.
let needs_safe = *self.arg_paren_contexts.get(&i).unwrap_or(&false);
let arg_visitor = visitor.with_arg_needs_paren_safe(needs_safe);
match dep {
SqlDependency::Symbol(m) => {
arg_visitor.apply(m, node_processor.clone(), templates)
}
SqlDependency::CubeRef(cr) => {
arg_visitor.evaluate_cube_ref(cr, node_processor.clone(), templates)
}
}
})
.collect::<Result<Vec<_>, _>>()?;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,6 @@ impl SqlNode for CalendarTimeShiftSqlNode {
node_processor: Rc<dyn SqlNode>,
templates: &PlanSqlTemplates,
) -> Result<String, CubeError> {
let input = self.input.to_sql(
visitor,
node,
query_tools.clone(),
node_processor.clone(),
templates,
)?;
let res = match node.as_ref() {
MemberSymbol::Dimension(ev) => {
if !ev.is_reference() {
Expand All @@ -55,20 +48,52 @@ impl SqlNode for CalendarTimeShiftSqlNode {
templates,
)?
} else if let Some(interval) = &shift.interval {
let inner_visitor = visitor.with_arg_needs_paren_safe(false);
let input = self.input.to_sql(
&inner_visitor,
node,
query_tools.clone(),
node_processor.clone(),
templates,
)?;
let res = templates
.add_timestamp_interval(input, interval.inverse().to_sql())?;
format!("({})", res)
} else {
input
self.input.to_sql(
visitor,
node,
query_tools.clone(),
node_processor.clone(),
templates,
)?
}
} else {
input
self.input.to_sql(
visitor,
node,
query_tools.clone(),
node_processor.clone(),
templates,
)?
}
} else {
input
self.input.to_sql(
visitor,
node,
query_tools.clone(),
node_processor.clone(),
templates,
)?
}
}
_ => input,
_ => self.input.to_sql(
visitor,
node,
query_tools.clone(),
node_processor.clone(),
templates,
)?,
};
Ok(res)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,18 +31,20 @@ impl CaseSqlNode {
node_processor: Rc<dyn SqlNode>,
templates: &PlanSqlTemplates,
) -> Result<String, CubeError> {
// All sub-SQLs end up inside `CASE … END` — a safe wrap.
let inner_visitor = visitor.with_arg_needs_paren_safe(false);
let mut when_then = Vec::new();
for itm in case.items.iter() {
let when = itm.sql.eval(
visitor,
&inner_visitor,
node_processor.clone(),
query_tools.clone(),
templates,
)?;
let then = match &itm.label {
CaseLabel::String(s) => templates.quote_string(&s)?,
CaseLabel::Sql(sql) => sql.eval(
visitor,
&inner_visitor,
node_processor.clone(),
query_tools.clone(),
templates,
Expand All @@ -53,7 +55,7 @@ impl CaseSqlNode {
let else_label = match &case.else_label {
CaseLabel::String(s) => templates.quote_string(&s)?,
CaseLabel::Sql(sql) => sql.eval(
visitor,
&inner_visitor,
node_processor.clone(),
query_tools.clone(),
templates,
Expand All @@ -69,6 +71,9 @@ impl CaseSqlNode {
node_processor: Rc<dyn SqlNode>,
templates: &PlanSqlTemplates,
) -> Result<String, CubeError> {
// Degenerate shortcuts return the inner SQL as-is — propagate the outer
// visitor so an enclosing ParenthesizeSqlNode still sees the compound
// flag.
if case.items.len() == 1 && case.else_sql.is_none() {
return case.items[0].sql.eval(
visitor,
Expand All @@ -85,22 +90,23 @@ impl CaseSqlNode {
templates,
);
}
let inner_visitor = visitor.with_arg_needs_paren_safe(false);
let expr = match &case.switch {
CaseSwitchItem::Sql(sql_call) => sql_call.eval(
visitor,
&inner_visitor,
node_processor.clone(),
query_tools.clone(),
templates,
)?,
CaseSwitchItem::Member(member_symbol) => {
visitor.apply(&member_symbol, node_processor.clone(), templates)?
inner_visitor.apply(&member_symbol, node_processor.clone(), templates)?
}
};
let mut when_then = Vec::new();
for itm in case.items.iter() {
let when = templates.quote_string(&itm.value)?;
let then = itm.sql.eval(
visitor,
&inner_visitor,
node_processor.clone(),
query_tools.clone(),
templates,
Expand All @@ -109,7 +115,7 @@ impl CaseSqlNode {
}
let else_label = if let Some(else_sql) = &case.else_sql {
Some(else_sql.eval(
visitor,
&inner_visitor,
node_processor.clone(),
query_tools.clone(),
templates,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
use super::{
AutoPrefixSqlNode, CaseSqlNode, EvaluateSqlNode, FinalMeasureSqlNode,
FinalPreAggregationMeasureSqlNode, GeoDimensionSqlNode, MaskedSqlNode, MeasureFilterSqlNode,
MultiStageRankNode, MultiStageWindowNode, RenderReferencesSqlNode, RenderReferencesType,
RollingWindowNode, RootSqlNode, SqlNode, TimeDimensionNode, TimeShiftSqlNode,
UngroupedMeasureSqlNode, UngroupedQueryFinalMeasureSqlNode,
MultiStageRankNode, MultiStageWindowNode, ParenthesizeSqlNode, RenderReferencesSqlNode,
RenderReferencesType, RollingWindowNode, RootSqlNode, SqlNode, TimeDimensionNode,
TimeShiftSqlNode, UngroupedMeasureSqlNode, UngroupedQueryFinalMeasureSqlNode,
};
use crate::planner::planners::multi_stage::TimeShiftState;
use crate::planner::sql_evaluator::cube_ref_evaluator::CubeRefEvaluator;
Expand Down Expand Up @@ -156,8 +156,10 @@ impl SqlNodesFactory {
evaluate_sql_processor.clone(),
self.cube_name_references.clone(),
);
let parenthesize_processor: Rc<dyn SqlNode> =
ParenthesizeSqlNode::new(auto_prefix_processor.clone());

let measure_filter_processor = MeasureFilterSqlNode::new(auto_prefix_processor.clone());
let measure_filter_processor = MeasureFilterSqlNode::new(parenthesize_processor.clone());
let measure_processor = CaseSqlNode::new(measure_filter_processor.clone());

let measure_processor = self.add_ungrouped_measure_reference_if_needed(measure_processor);
Expand All @@ -182,10 +184,11 @@ impl SqlNodesFactory {
} else {
evaluate_sql_processor.clone()
};
let default_processor: Rc<dyn SqlNode> = ParenthesizeSqlNode::new(default_processor);

let root_node = RootSqlNode::new(
self.dimension_processor(evaluate_sql_processor.clone()),
self.time_dimension_processor(evaluate_sql_processor.clone()),
self.time_dimension_processor(ParenthesizeSqlNode::new(evaluate_sql_processor.clone())),
measure_processor.clone(),
default_processor,
);
Expand Down Expand Up @@ -261,6 +264,8 @@ impl SqlNodesFactory {
let input: Rc<dyn SqlNode> =
AutoPrefixSqlNode::new(input, self.cube_name_references.clone());

let input: Rc<dyn SqlNode> = ParenthesizeSqlNode::new(input);

let input: Rc<dyn SqlNode> =
TimeDimensionNode::new(self.dimensions_with_ignored_timezone.clone(), input);

Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use super::SqlNode;
use crate::planner::query_tools::QueryTools;
use crate::planner::sql_evaluator::symbols::{AggregateWrap, MeasureSymbol};
use crate::planner::sql_evaluator::symbols::AggregateWrap;
use crate::planner::sql_evaluator::MemberSymbol;
use crate::planner::sql_evaluator::SqlEvaluatorVisitor;
use crate::planner::sql_templates::PlanSqlTemplates;
Expand Down Expand Up @@ -32,16 +32,13 @@ impl FinalMeasureSqlNode {
&self.input
}

fn wrap_aggregate(
fn apply_wrap(
&self,
ev: &MeasureSymbol,
wrap: AggregateWrap,
input: String,
templates: &PlanSqlTemplates,
) -> Result<String, CubeError> {
let is_multiplied = self
.rendered_as_multiplied_measures
.contains(&ev.full_name());
match ev.kind().aggregate_wrap(is_multiplied) {
match wrap {
AggregateWrap::PassThrough => Ok(input),
AggregateWrap::Function(name) => Ok(format!("{}({})", name, input)),
AggregateWrap::CountDistinct => templates.count_distinct(&input),
Expand All @@ -67,14 +64,22 @@ impl SqlNode for FinalMeasureSqlNode {
) -> Result<String, CubeError> {
let res = match node.as_ref() {
MemberSymbol::Measure(ev) => {
let is_multiplied = self
.rendered_as_multiplied_measures
.contains(&ev.full_name());
let wrap = ev.kind().aggregate_wrap(is_multiplied);
let child_visitor = match wrap {
AggregateWrap::PassThrough => visitor.clone(),
_ => visitor.with_arg_needs_paren_safe(false),
};
let input = self.input.to_sql(
visitor,
&child_visitor,
node,
query_tools.clone(),
node_processor.clone(),
templates,
)?;
self.wrap_aggregate(ev, input, templates)?
self.apply_wrap(wrap, input, templates)?
}
_ => {
return Err(CubeError::internal(format!(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,14 +34,15 @@ impl SqlNode for GeoDimensionSqlNode {
let res = match node.as_ref() {
MemberSymbol::Dimension(ev) => {
if let DimensionKind::Geo(geo) = ev.kind() {
let inner_visitor = visitor.with_arg_needs_paren_safe(false);
let latitude_str = geo.latitude().eval(
visitor,
&inner_visitor,
node_processor.clone(),
query_tools.clone(),
templates,
)?;
let longitude_str = geo.longitude().eval(
visitor,
&inner_visitor,
node_processor.clone(),
query_tools.clone(),
templates,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,24 +30,25 @@ impl SqlNode for MeasureFilterSqlNode {
node_processor: Rc<dyn SqlNode>,
templates: &PlanSqlTemplates,
) -> Result<String, CubeError> {
let input = self.input.to_sql(
visitor,
node,
query_tools.clone(),
node_processor.clone(),
templates,
)?;
let res = match node.as_ref() {
MemberSymbol::Measure(ev) => {
let measure_filters = ev.measure_filters();
if !measure_filters.is_empty() {
let inner_visitor = visitor.with_arg_needs_paren_safe(false);
let input = self.input.to_sql(
&inner_visitor,
node,
query_tools.clone(),
node_processor.clone(),
templates,
)?;
let filters = measure_filters
.iter()
.map(|filter| -> Result<String, CubeError> {
Ok(format!(
"({})",
filter.eval(
&visitor,
&inner_visitor,
node_processor.clone(),
query_tools.clone(),
templates
Expand All @@ -63,7 +64,14 @@ impl SqlNode for MeasureFilterSqlNode {
};
format!("CASE WHEN {} THEN {} END", filters, result)
} else {
input
// Passthrough — propagate visitor unchanged.
self.input.to_sql(
visitor,
node,
query_tools.clone(),
node_processor.clone(),
templates,
)?
}
}
_ => {
Expand Down
Loading
Loading