diff --git a/sql_redis/parser.py b/sql_redis/parser.py index dc12690..3aff647 100644 --- a/sql_redis/parser.py +++ b/sql_redis/parser.py @@ -210,6 +210,30 @@ class ScoringSpec: scorer: str = "BM25" # Scorer algorithm (BM25, TFIDF, DISMAX, etc.) +@dataclass +class BoolLeaf: + """A leaf in the WHERE-clause boolean tree wrapping a single Condition.""" + + condition: Condition + + +@dataclass +class BoolGroup: + """An internal node in the WHERE-clause boolean tree. + + Preserves the SQL operator precedence and parenthesization so that + mixed expressions like ``A AND (B OR C)`` are not flattened into a + single boolean operator. + """ + + operator: str # "AND" or "OR" + children: list = dataclasses.field(default_factory=list) + + +# Type alias for a boolean tree node — Union[BoolLeaf, BoolGroup]. +BoolNode = "BoolLeaf | BoolGroup" + + @dataclass class ParsedQuery: """Result of parsing a SQL query.""" @@ -222,6 +246,12 @@ class ParsedQuery: default_factory=list ) boolean_operator: str = "AND" + condition_tree: object | None = None # BoolLeaf | BoolGroup | None + # True iff the WHERE clause contains an OR anywhere in the original SQL. + # Set during parsing — independent of the boolean tree, which may collapse + # an OR group when one side is a side-channel predicate (geo_distance) + # that produces no tree leaf. + has_or_in_where: bool = False aggregations: list[AggregationSpec] = dataclasses.field(default_factory=list) computed_fields: list[ComputedField] = dataclasses.field(default_factory=list) date_functions: list[DateFunctionSpec] = dataclasses.field(default_factory=list) @@ -267,7 +297,12 @@ def parse(self, sql: str) -> ParsedQuery: # Extract WHERE clause conditions where = ast.find(exp.Where) if where: - self._process_where_clause(where.this, result) + tree = self._process_where_clause(where.this, result) + result.condition_tree = tree + # Set legacy boolean_operator from the tree root for backward + # compatibility with callers that still consult this field. + if isinstance(tree, BoolGroup): + result.boolean_operator = tree.operator # Extract GROUP BY clause group = ast.find(exp.Group) @@ -352,11 +387,35 @@ def _process_select_expression_inner( func_name = redis_func_map.get(func_name, func_name) field_name = None # Get the field being aggregated (if any) - if expression.this: - if isinstance(expression.this, exp.Column): - field_name = expression.this.name - elif isinstance(expression.this, exp.Star): - field_name = None # COUNT(*) + inner = expression.this + if isinstance(inner, exp.Distinct): + # AGG(DISTINCT col) — only COUNT has a native RediSearch + # equivalent (COUNT_DISTINCT). Other aggregates can't be + # silently translated to a non-distinct form, so raise. + distinct_cols = inner.expressions or ( + [inner.this] if inner.this is not None else [] + ) + if len(distinct_cols) != 1 or not isinstance( + distinct_cols[0], exp.Column + ): + raise ValueError( + f"{func_name}(DISTINCT ...) expects a single column " + "reference; multi-column or expression DISTINCT is " + "not supported by RediSearch." + ) + if func_name != "COUNT": + raise ValueError( + f"{func_name}(DISTINCT ...) is not supported by " + "RediSearch. Only COUNT(DISTINCT x) maps to a native " + "reducer (COUNT_DISTINCT); pre-deduplicate the data " + "or use COUNT_DISTINCT for cardinality." + ) + func_name = "COUNT_DISTINCT" + field_name = distinct_cols[0].name + elif isinstance(inner, exp.Column): + field_name = inner.name + elif isinstance(inner, exp.Star): + field_name = None # COUNT(*) result.aggregations.append( AggregationSpec(function=func_name, field=field_name, alias=alias) ) @@ -688,39 +747,48 @@ def _process_date_expression( def _process_where_clause( self, expression, result: ParsedQuery, negated: bool = False - ) -> None: - """Process WHERE clause expression recursively.""" + ): + """Process WHERE clause expression recursively. + + Returns a boolean tree (BoolLeaf or BoolGroup) preserving the original + SQL operator precedence and grouping, or None when the expression + contributes no boolean clause to the RediSearch query string (e.g., + geo_distance comparisons stored separately on result.geo_conditions). + """ if isinstance(expression, exp.EQ): - self._add_condition(expression, "=", result, negated) + return self._leaf(self._add_condition(expression, "=", result, negated)) elif isinstance(expression, exp.GT): - self._add_condition(expression, ">", result, negated) + return self._leaf(self._add_condition(expression, ">", result, negated)) elif isinstance(expression, exp.GTE): - self._add_condition(expression, ">=", result, negated) + return self._leaf(self._add_condition(expression, ">=", result, negated)) elif isinstance(expression, exp.LT): - self._add_condition(expression, "<", result, negated) + return self._leaf(self._add_condition(expression, "<", result, negated)) elif isinstance(expression, exp.LTE): - self._add_condition(expression, "<=", result, negated) + return self._leaf(self._add_condition(expression, "<=", result, negated)) elif isinstance(expression, exp.NEQ): - self._add_condition(expression, "!=", result, negated) + return self._leaf(self._add_condition(expression, "!=", result, negated)) elif isinstance(expression, exp.Between): - self._add_between_condition(expression, result, negated) + return self._leaf(self._add_between_condition(expression, result, negated)) elif isinstance(expression, exp.In): - self._add_in_condition(expression, result, negated) + return self._leaf(self._add_in_condition(expression, result, negated)) elif isinstance(expression, exp.Like): # LIKE 'pattern%' / '%pattern' / '%pattern%' - self._add_condition(expression, "LIKE", result, negated) + return self._leaf(self._add_condition(expression, "LIKE", result, negated)) elif isinstance(expression, exp.And): - result.boolean_operator = "AND" - self._process_where_clause(expression.this, result, negated) - self._process_where_clause(expression.expression, result, negated) + left = self._process_where_clause(expression.this, result, negated) + right = self._process_where_clause(expression.expression, result, negated) + return self._combine("AND", left, right) elif isinstance(expression, exp.Or): - result.boolean_operator = "OR" - self._process_where_clause(expression.this, result, negated) - self._process_where_clause(expression.expression, result, negated) + result.has_or_in_where = True + left = self._process_where_clause(expression.this, result, negated) + right = self._process_where_clause(expression.expression, result, negated) + return self._combine("OR", left, right) elif isinstance(expression, exp.Not): - self._process_where_clause(expression.this, result, negated=not negated) + return self._process_where_clause( + expression.this, result, negated=not negated + ) elif isinstance(expression, exp.Paren): - self._process_where_clause(expression.this, result, negated=negated) + return self._process_where_clause(expression.this, result, negated=negated) elif isinstance(expression, exp.Is): # IS NULL: exp.Is(this=Column, expression=Null()) # IS NOT NULL arrives here with negated=True via the exp.Not handler above @@ -728,14 +796,14 @@ def _process_where_clause( expression.expression, exp.Null ): operator = "IS_NOT_NULL" if negated else "IS_NULL" - result.conditions.append( - Condition( - field=expression.this.name, - operator=operator, - value=None, - negated=False, - ) + cond = Condition( + field=expression.this.name, + operator=operator, + value=None, + negated=False, ) + result.conditions.append(cond) + return BoolLeaf(cond) else: raise ValueError( "Unsupported IS expression in WHERE clause; only " @@ -752,9 +820,39 @@ def _process_where_clause( "for post-aggregate filtering." ) # EXISTS (SELECT ...) — SQL subquery, silently ignored (not supported) + return None elif isinstance(expression, exp.Anonymous): # Custom function like MATCH(field, value) - self._add_function_condition(expression, result, negated) + return self._leaf(self._add_function_condition(expression, result, negated)) + return None + + @staticmethod + def _leaf(condition: Condition | None): + """Wrap a Condition in a BoolLeaf, or return None for non-leaf adds.""" + if condition is None: + return None + return BoolLeaf(condition) + + @staticmethod + def _combine(operator: str, left, right): + """Combine two child nodes under a boolean operator. + + Drops None children, flattens same-operator subtrees so that + ``A AND B AND C`` produces a single AND group with three children. + """ + children: list = [] + for child in (left, right): + if child is None: + continue + if isinstance(child, BoolGroup) and child.operator == operator: + children.extend(child.children) + else: + children.append(child) + if not children: + return None + if len(children) == 1: + return children[0] + return BoolGroup(operator=operator, children=children) def _process_having_clause(self, expression, result: ParsedQuery) -> None: """Process HAVING clause — routes exists() to filters.""" @@ -780,8 +878,12 @@ def _process_having_clause(self, expression, result: ParsedQuery) -> None: def _add_condition( self, expression, operator: str, result: ParsedQuery, negated: bool - ) -> None: - """Add a condition from a comparison expression.""" + ) -> Condition | None: + """Add a condition from a comparison expression. + + Returns the appended Condition for inclusion in the boolean tree, or + None when the expression was routed to result.geo_conditions instead. + """ field_name = None value = None is_geo_distance = False @@ -875,20 +977,26 @@ def _add_condition( unit=geo_unit, ) ) + return None else: - result.conditions.append( - Condition( - field=field_name, - operator=operator, - value=value, - negated=negated, - ) + cond = Condition( + field=field_name, + operator=operator, + value=value, + negated=negated, ) + result.conditions.append(cond) + return cond + return None def _add_between_condition( self, expression, result: ParsedQuery, negated: bool - ) -> None: - """Add a BETWEEN condition.""" + ) -> Condition | None: + """Add a BETWEEN condition. + + Returns the appended Condition for inclusion in the boolean tree, or + None when the expression was routed to result.geo_conditions instead. + """ field_name = None is_geo_distance = False geo_lon = None @@ -963,18 +1071,22 @@ def _add_between_condition( unit=geo_unit, ) ) + return None else: - result.conditions.append( - Condition( - field=field_name, - operator="BETWEEN", - value=(low_val, high_val), - negated=negated, - ) + cond = Condition( + field=field_name, + operator="BETWEEN", + value=(low_val, high_val), + negated=negated, ) + result.conditions.append(cond) + return cond + return None - def _add_in_condition(self, expression, result: ParsedQuery, negated: bool) -> None: - """Add an IN condition.""" + def _add_in_condition( + self, expression, result: ParsedQuery, negated: bool + ) -> Condition | None: + """Add an IN condition. Returns the appended Condition or None.""" field_name = None if isinstance(expression.this, exp.Column): field_name = expression.this.name @@ -982,16 +1094,21 @@ def _add_in_condition(self, expression, result: ParsedQuery, negated: bool) -> N values = [self._extract_literal_value(e) for e in expression.expressions] if field_name is not None: - result.conditions.append( - Condition( - field=field_name, operator="IN", value=values, negated=negated - ) + cond = Condition( + field=field_name, operator="IN", value=values, negated=negated ) + result.conditions.append(cond) + return cond + return None def _add_function_condition( self, expression, result: ParsedQuery, negated: bool - ) -> None: - """Add a condition from a function call like fulltext(field, value) or fuzzy(field, value, level).""" + ) -> Condition | None: + """Add a condition from a function call like fulltext(field, value) or fuzzy(field, value, level). + + Returns the appended Condition for inclusion in the boolean tree, or + None when no condition was added. + """ func_name = expression.name.upper() args = expression.expressions @@ -1071,16 +1188,16 @@ def _add_function_condition( "fulltext() first argument must be a column name, " f"got {args[0]}. Usage: fulltext(field, 'search terms')" ) - result.conditions.append( - Condition( - field=field_name, - operator="FULLTEXT", - value=value, - negated=negated, - slop=slop, - inorder=inorder, - ) + cond = Condition( + field=field_name, + operator="FULLTEXT", + value=value, + negated=negated, + slop=slop, + inorder=inorder, ) + result.conditions.append(cond) + return cond elif func_name == "FUZZY" and len(args) >= 2: field_name = args[0].name if isinstance(args[0], exp.Column) else None @@ -1121,15 +1238,17 @@ def _add_function_condition( "fuzzy() first argument must be a column name, " f"got {args[0]}. Usage: fuzzy(field, 'search term')" ) - result.conditions.append( - Condition( - field=field_name, - operator="FUZZY", - value=value, - negated=negated, - fuzzy_level=fuzzy_level, - ) + cond = Condition( + field=field_name, + operator="FUZZY", + value=value, + negated=negated, + fuzzy_level=fuzzy_level, ) + result.conditions.append(cond) + return cond + + return None def _extract_literal_value(self, expression, convert_dates: bool = False): """Extract a Python value from a sqlglot Literal or Neg expression. diff --git a/sql_redis/query_builder.py b/sql_redis/query_builder.py index 8a8acf0..94c2361 100644 --- a/sql_redis/query_builder.py +++ b/sql_redis/query_builder.py @@ -85,6 +85,25 @@ def _escape_text_value(value: str) -> str: # then escape double quotes. return value.replace("\\", "\\\\").replace('"', '\\"') + @classmethod + def _escape_text_equality_term(cls, term: str) -> str: + """Escape single-term equality while preserving legacy wildcard semantics. + + For backward compatibility, TEXT equality on a single token continues to + behave like a RediSearch token query instead of an exact quoted phrase. + This preserves wildcard markers like `*` and fuzzy markers like `%term%`, + while still escaping other operator characters. + """ + result = [] + for index, char in enumerate(term): + if char == "*" or (char == "~" and index == 0): + result.append(char) + elif char in cls.TEXT_QUERY_SPECIAL_CHARS: + result.append(f"\\{char}") + else: + result.append(char) + return "".join(result) + def build_text_condition( self, field: str | list[str], @@ -101,7 +120,7 @@ def build_text_condition( Args: field: Field name or list of field names for multi-field search. operator: One of =, !=, FULLTEXT, LIKE, FUZZY. - - = / !=: exact phrase match, value wrapped in double quotes. + - = / !=: single-term token match, or multi-word exact phrase. - FULLTEXT: tokenized keyword search with stopword filtering. - LIKE: prefix/suffix/infix pattern (SQL % → RediSearch *). - FUZZY: Levenshtein fuzzy match. @@ -112,7 +131,7 @@ def build_text_condition( inorder: If True with slop, require terms in order. Returns: - RediSearch query syntax like @field:"exact phrase" or @field:(term1 term2). + RediSearch query syntax like @field:term or @field:"exact phrase". """ # Derive negation from both the flag and the operator itself, # consistent with how build_tag_condition handles != via operator. @@ -143,38 +162,43 @@ def build_text_condition( pct = "%" * level search_value = f"{pct}{escaped}{pct}" elif operator in ("=", "!="): - # Exact phrase match — wrap in double quotes. - # Strip default stopwords because RediSearch does not index them; - # keeping them in the quoted phrase causes a query-time error - # (e.g. "diagnosing and treating" fails on "and"). - # Since the indexer assigns consecutive positions after dropping - # stopwords, the stripped phrase matches correctly. words = value.split() - removed = [w for w in words if w.lower() in REDIS_DEFAULT_STOPWORDS] - filtered = [w for w in words if w.lower() not in REDIS_DEFAULT_STOPWORDS] + if len(words) == 1: + search_value = self._escape_text_equality_term(words[0]) + else: + # Multi-word equality remains an exact phrase match. + # Strip default stopwords because RediSearch does not index them; + # keeping them in the quoted phrase causes a query-time error + # (e.g. "diagnosing and treating" fails on "and"). + # Since the indexer assigns consecutive positions after dropping + # stopwords, the stripped phrase matches correctly. + removed = [w for w in words if w.lower() in REDIS_DEFAULT_STOPWORDS] + filtered = [ + w for w in words if w.lower() not in REDIS_DEFAULT_STOPWORDS + ] - if removed: - phrase_words = filtered if filtered else words - if filtered: - sw_msg = f"Stopwords {removed} were removed from" - else: - sw_msg = ( - f"All tokens in '{value}' are stopwords and may not " - "be indexed in" + if removed: + phrase_words = filtered if filtered else words + if filtered: + sw_msg = f"Stopwords {removed} were removed from" + else: + sw_msg = ( + f"All tokens in '{value}' are stopwords and may not " + "be indexed in" + ) + warnings.warn( + f"{sw_msg} exact phrase '{value}'. " + "By default, Redis does not index stopwords. " + "To include stopwords in your index, create it " + "with STOPWORDS 0.", + UserWarning, + stacklevel=2, ) - warnings.warn( - f"{sw_msg} exact phrase '{value}'. " - "By default, Redis does not index stopwords. " - "To include stopwords in your index, create it " - "with STOPWORDS 0.", - UserWarning, - stacklevel=2, - ) - else: - phrase_words = words + else: + phrase_words = words - escaped = self._escape_text_value(" ".join(phrase_words)) - search_value = f'"{escaped}"' + escaped = self._escape_text_value(" ".join(phrase_words)) + search_value = f'"{escaped}"' elif re.search(r"(?:^|\s+)OR(?:\s+|$)", value): # OR union within text field: split on uppercase-only OR with # flexible whitespace, escape each term, join with |. @@ -270,7 +294,7 @@ def build_text_condition( ) escaped_words = [] - for w in (filtered_words if filtered_words else words): + for w in filtered_words if filtered_words else words: if w.startswith("~"): # Preserve ~ optional-term prefix, escape the rest escaped_words.append("~" + self._escape_fulltext_term(w[1:])) diff --git a/sql_redis/translator.py b/sql_redis/translator.py index 8f3f199..0751a37 100644 --- a/sql_redis/translator.py +++ b/sql_redis/translator.py @@ -9,6 +9,8 @@ from sql_redis.analyzer import AnalyzedQuery, Analyzer from sql_redis.parser import ( SQL_TO_REDIS_DATE_FUNCTIONS, + BoolGroup, + BoolLeaf, Condition, GeoDistanceCondition, ParsedQuery, @@ -119,7 +121,10 @@ def _build_command(self, analyzed: AnalyzedQuery) -> TranslatedQuery: # Geo filters are applied as top-level command args (GEOFILTER/FILTER) and # are not part of the boolean expression. Combining with OR would change # semantics (e.g., `A OR geo_distance(...)` would become `(A) AND geo_filter`). - if parsed.geo_conditions and parsed.boolean_operator == "OR": + # ``has_or_in_where`` is set by the parser whenever an OR appears in + # WHERE, even when the boolean tree collapses (e.g., the OR's other + # branch was a geo_distance predicate that produced no tree leaf). + if parsed.geo_conditions and parsed.has_or_in_where: raise ValueError( "Geo distance predicates cannot be combined with OR; " "they are applied as top-level filters and would change query " @@ -138,8 +143,12 @@ def _build_command(self, analyzed: AnalyzedQuery) -> TranslatedQuery: # Validate: date function predicates cannot be combined with OR # Date filters are applied via FILTER clauses (ANDed with query). - # Combining with OR would change semantics. - if has_date_func_conditions and parsed.boolean_operator == "OR": + # Combining with OR would change semantics. Walk the tree to reject + # mixing at any depth (e.g., `A AND (YEAR(x) = 2024 OR B)`), not just + # when OR is the root operator. + if has_date_func_conditions and self._tree_has_date_in_or( + parsed.condition_tree + ): raise ValueError( "Date function predicates cannot be combined with OR; " "they are applied as top-level filters and would change query " @@ -174,36 +183,30 @@ def _build_command(self, analyzed: AnalyzedQuery) -> TranslatedQuery: return self._build_search(analyzed, query_string) def _build_query_string(self, analyzed: AnalyzedQuery) -> str: - """Build the RediSearch query string from conditions.""" + """Build the RediSearch query string from conditions. + + Walks the boolean tree built by the parser so that mixed AND/OR + expressions like ``A AND (B OR C)`` keep their original grouping + instead of collapsing onto a single boolean operator. + """ parsed = analyzed.parsed - conditions = parsed.conditions - # Filter out date function conditions (they need FILTER in AGGREGATE) - regular_conditions = [ - c for c in conditions if not self._is_date_function_condition(c) - ] + # Render the boolean tree (with proper RediSearch parenthesization) + # if one was built by the parser. Date-function leaves are skipped — + # they are emitted as FILTER args by the FT.AGGREGATE path. + if parsed.condition_tree is not None: + combined = self._render_bool_tree(parsed.condition_tree, analyzed) or "" + else: + combined = "" - if not regular_conditions and not analyzed.vector_search: + if not combined and not analyzed.vector_search: return "*" - # Build condition strings by type - condition_strings: list[str] = [] - - for condition in regular_conditions: - field_type = analyzed.get_field_type(condition.field) - condition_str = self._build_condition(condition, field_type) - condition_strings.append(condition_str) - - # Combine with boolean operator - combined = self._query_builder.combine_conditions( - condition_strings, parsed.boolean_operator - ) - # Handle vector search with prefilter if analyzed.vector_search: vs = analyzed.vector_search # Vector search uses KNN syntax - if analyzed.has_prefilter: + if analyzed.has_prefilter and combined: # Prefilter: (filter)=>[KNN k @field $vec] return f"({combined})=>[KNN {vs.k} @{vs.field} $vector AS {vs.alias}]" else: @@ -212,6 +215,53 @@ def _build_query_string(self, analyzed: AnalyzedQuery) -> str: return combined + def _render_bool_tree(self, node, analyzed: AnalyzedQuery) -> str | None: + """Recursively render a BoolLeaf/BoolGroup tree to a query string. + + Date-function leaves are dropped (handled via FILTER in FT.AGGREGATE). + OR groups are wrapped in parentheses so that, when nested inside an + AND group, RediSearch's higher AND precedence does not silently + re-associate the expression. Returns None for an empty tree (e.g., + a group that contained only date-function leaves). + """ + if isinstance(node, BoolLeaf): + condition = node.condition + if self._is_date_function_condition(condition): + return None + field_type = analyzed.get_field_type(condition.field) + return self._build_condition(condition, field_type) + if isinstance(node, BoolGroup): + rendered = [ + r + for r in (self._render_bool_tree(c, analyzed) for c in node.children) + if r + ] + if not rendered: + return None + if len(rendered) == 1: + return rendered[0] + if node.operator == "OR": + return "(" + "|".join(rendered) + ")" + # AND: space-joined; RediSearch gives AND higher precedence than OR + # so child OR groups (already wrapped in parens above) keep grouping. + return " ".join(rendered) + return None + + def _tree_has_date_in_or(self, node, in_or: bool = False) -> bool: + """Return True if any date-function leaf is reachable through an OR. + + Walks the boolean tree and returns True as soon as a date-function + condition is found beneath an OR ancestor — used to reject + ``A OR YEAR(x) = 2024`` and similar mixes that the FT.AGGREGATE + FILTER path cannot represent. + """ + if isinstance(node, BoolLeaf): + return in_or and self._is_date_function_condition(node.condition) + if isinstance(node, BoolGroup): + now_in_or = in_or or node.operator == "OR" + return any(self._tree_has_date_in_or(c, now_in_or) for c in node.children) + return False + def _build_condition(self, condition: Condition, field_type: str | None) -> str: """Build a single condition string based on field type.""" # Short-circuit for IS NULL / IS NOT NULL → ismissing() diff --git a/tests/test_query_builder.py b/tests/test_query_builder.py index d01c1be..a98d40e 100644 --- a/tests/test_query_builder.py +++ b/tests/test_query_builder.py @@ -8,12 +8,33 @@ class TestQueryBuilderTextFields: """Tests for building TEXT field query syntax.""" - def test_text_single_term_exact(self): - """TEXT field with = wraps in quotes for exact phrase: @field:"term".""" + def test_text_single_term_equality_uses_token_match(self): + """TEXT field with = on one token uses token-match semantics.""" builder = QueryBuilder() result = builder.build_text_condition("title", "=", "laptop") - assert result == '@title:"laptop"' + assert result == "@title:laptop" + + def test_text_single_term_equality_preserves_prefix_wildcard(self): + """Single-term equality preserves RediSearch wildcard syntax.""" + builder = QueryBuilder() + result = builder.build_text_condition("title", "=", "lap*") + + assert result == "@title:lap*" + + def test_text_single_term_equality_preserves_suffix_wildcard(self): + """Single-term equality preserves suffix wildcard syntax.""" + builder = QueryBuilder() + result = builder.build_text_condition("title", "=", "*book") + + assert result == "@title:*book" + + def test_text_single_term_equality_preserves_fuzzy_markers(self): + """Single-term equality preserves fuzzy syntax markers.""" + builder = QueryBuilder() + result = builder.build_text_condition("title", "=", "%laptap%") + + assert result == "@title:%laptap%" def test_text_exact_phrase(self): """TEXT field with = preserves multi-word phrase: @field:"exact phrase".""" @@ -55,12 +76,12 @@ def test_text_exact_phrase_escapes_quotes(self): assert result == r'@title:"say \"hello\""' - def test_text_exact_phrase_escapes_backslashes(self): - """TEXT field with = escapes backslashes inside the value.""" + def test_text_single_term_equality_escapes_backslashes(self): + """Single-term equality still escapes backslashes and other operators.""" builder = QueryBuilder() result = builder.build_text_condition("path", "=", r"c:\users\docs") - assert result == r'@path:"c:\\users\\docs"' + assert result == r"@path:c\:\\users\\docs" def test_text_fulltext_term(self): """TEXT field with FULLTEXT (tokenized search): @field:term.""" diff --git a/tests/test_sql_parser.py b/tests/test_sql_parser.py index 72a4f56..7c6af51 100644 --- a/tests/test_sql_parser.py +++ b/tests/test_sql_parser.py @@ -675,16 +675,43 @@ def test_parse_insert_statement(self): assert result.index == "" def test_parse_count_distinct(self): - """Parse COUNT(DISTINCT field) - this isn't Column or Star.""" + """COUNT(DISTINCT field) routes to the COUNT_DISTINCT reducer.""" parser = SQLParser() result = parser.parse( "SELECT COUNT(DISTINCT category) AS unique_cats FROM products" ) - # DISTINCT wraps the column, so field stays None assert len(result.aggregations) == 1 - assert result.aggregations[0].function == "COUNT" - assert result.aggregations[0].field is None + assert result.aggregations[0].function == "COUNT_DISTINCT" + assert result.aggregations[0].field == "category" + assert result.aggregations[0].alias == "unique_cats" + + def test_parse_count_distinct_without_alias(self): + """COUNT(DISTINCT field) preserves the field even without an alias.""" + parser = SQLParser() + result = parser.parse("SELECT COUNT(DISTINCT title) FROM products") + + assert len(result.aggregations) == 1 + assert result.aggregations[0].function == "COUNT_DISTINCT" + assert result.aggregations[0].field == "title" + + def test_parse_sum_distinct_raises(self): + """SUM(DISTINCT field) is rejected — RediSearch has no equivalent.""" + parser = SQLParser() + with pytest.raises(ValueError, match="DISTINCT"): + parser.parse("SELECT SUM(DISTINCT price) FROM products") + + def test_parse_avg_distinct_raises(self): + """AVG(DISTINCT field) is rejected — RediSearch has no equivalent.""" + parser = SQLParser() + with pytest.raises(ValueError, match="DISTINCT"): + parser.parse("SELECT AVG(DISTINCT price) FROM products") + + def test_parse_count_distinct_multi_column_raises(self): + """COUNT(DISTINCT a, b) is rejected — RediSearch has no equivalent.""" + parser = SQLParser() + with pytest.raises(ValueError, match="single column"): + parser.parse("SELECT COUNT(DISTINCT a, b) FROM products") def test_parse_sum_expression(self): """Parse SUM of expression - not a simple Column.""" @@ -759,6 +786,123 @@ def test_not_parenthesized_condition(self): assert result.conditions[0].field == "status" +class TestSQLParserMixedBooleanLogic: + """Tests that the WHERE-clause boolean tree preserves AND/OR grouping. + + Regression coverage for bugs where ``A AND (B OR C)``, + ``A OR (B AND C)`` and similar mixed expressions collapsed onto a single + ``boolean_operator`` and a flat ``conditions`` list, losing the grouping + expressed by the SQL parentheses. + """ + + def _fields(self, node): + """Recursively extract leaf field names from a BoolNode tree.""" + from sql_redis.parser import BoolGroup, BoolLeaf + + if isinstance(node, BoolLeaf): + return [node.condition.field] + if isinstance(node, BoolGroup): + return [f for c in node.children for f in self._fields(c)] + return [] + + def test_and_with_nested_or_keeps_group(self): + """A AND (B OR C): root is AND, with an inner OR child.""" + from sql_redis.parser import BoolGroup, BoolLeaf + + parser = SQLParser() + result = parser.parse( + "SELECT * FROM idx WHERE a = '1' AND (b = '2' OR c = '3')" + ) + + tree = result.condition_tree + assert isinstance(tree, BoolGroup) + assert tree.operator == "AND" + assert len(tree.children) == 2 + # First child is the leaf `a = '1'` + assert isinstance(tree.children[0], BoolLeaf) + assert tree.children[0].condition.field == "a" + # Second child is the OR group with b and c + inner = tree.children[1] + assert isinstance(inner, BoolGroup) + assert inner.operator == "OR" + assert self._fields(inner) == ["b", "c"] + + def test_or_with_nested_and_keeps_group(self): + """A OR (B AND C): root is OR, with an inner AND child.""" + from sql_redis.parser import BoolGroup, BoolLeaf + + parser = SQLParser() + result = parser.parse( + "SELECT * FROM idx WHERE a = '1' OR (b = '2' AND c = '3')" + ) + + tree = result.condition_tree + assert isinstance(tree, BoolGroup) + assert tree.operator == "OR" + assert len(tree.children) == 2 + assert isinstance(tree.children[0], BoolLeaf) + assert tree.children[0].condition.field == "a" + inner = tree.children[1] + assert isinstance(inner, BoolGroup) + assert inner.operator == "AND" + assert self._fields(inner) == ["b", "c"] + + def test_or_group_first_then_and(self): + """(B OR C) AND A keeps the OR group as the first child.""" + from sql_redis.parser import BoolGroup, BoolLeaf + + parser = SQLParser() + result = parser.parse( + "SELECT * FROM idx WHERE (b = '2' OR c = '3') AND a = '1'" + ) + + tree = result.condition_tree + assert isinstance(tree, BoolGroup) + assert tree.operator == "AND" + assert isinstance(tree.children[0], BoolGroup) + assert tree.children[0].operator == "OR" + assert self._fields(tree.children[0]) == ["b", "c"] + assert isinstance(tree.children[1], BoolLeaf) + assert tree.children[1].condition.field == "a" + + def test_chained_ands_with_trailing_or_group(self): + """A AND B AND C AND (D OR E) flattens AND children and keeps OR.""" + from sql_redis.parser import BoolGroup + + parser = SQLParser() + result = parser.parse( + "SELECT * FROM idx " + "WHERE a = '1' AND b = '2' AND c = '3' AND (d = '4' OR e = '5')" + ) + + tree = result.condition_tree + assert isinstance(tree, BoolGroup) + assert tree.operator == "AND" + # Three AND leaves followed by an OR group — same-operator subtrees + # are flattened so the AND group has 4 children, not nested. + assert len(tree.children) == 4 + assert self._fields(tree.children[0]) == ["a"] + assert self._fields(tree.children[1]) == ["b"] + assert self._fields(tree.children[2]) == ["c"] + last = tree.children[3] + assert isinstance(last, BoolGroup) + assert last.operator == "OR" + assert self._fields(last) == ["d", "e"] + + def test_flat_and_chain_is_single_group(self): + """A AND B AND C produces one AND group with three children.""" + from sql_redis.parser import BoolGroup + + parser = SQLParser() + result = parser.parse("SELECT * FROM idx WHERE a = '1' AND b = '2' AND c = '3'") + + tree = result.condition_tree + assert isinstance(tree, BoolGroup) + assert tree.operator == "AND" + assert len(tree.children) == 3 + assert self._fields(tree) == ["a", "b", "c"] + + class TestSQLParserIsNull: """Tests for IS NULL / IS NOT NULL parsing.""" diff --git a/tests/test_translator.py b/tests/test_translator.py index 23fda9a..da85595 100644 --- a/tests/test_translator.py +++ b/tests/test_translator.py @@ -136,7 +136,7 @@ def test_select_with_text_filter(self, translator: Translator, basic_index: str) ) assert result.command == "FT.SEARCH" - assert result.query_string == '@title:"hello"' + assert result.query_string == "@title:hello" def test_select_with_numeric_filter(self, translator: Translator, basic_index: str): """SELECT with NUMERIC field condition.""" @@ -202,7 +202,7 @@ def test_and_conditions(self, translator: Translator, basic_index: str): f"SELECT * FROM {basic_index} WHERE title = 'hello' AND price > 50" ) - assert '@title:"hello"' in result.query_string + assert "@title:hello" in result.query_string assert "@price:[(50 +inf]" in result.query_string def test_or_conditions(self, translator: Translator, basic_index: str): @@ -221,6 +221,88 @@ def test_boolean_in_numeric_context_raises( translator.translate(f"SELECT * FROM {basic_index} WHERE price = true") +class TestTranslatorMixedBooleanLogic: + """Tests that mixed AND/OR WHERE clauses keep their SQL grouping. + + Regression coverage for the bug where ``A AND (B OR C)`` and similar + expressions were flattened to a single boolean operator (e.g. + ``@a|@b|@c``), losing the user's intended precedence. + """ + + def test_and_with_nested_or(self, translator: Translator, basic_index: str): + """A AND (B OR C) -> ``@a (B|C)`` and the OR group is parenthesized.""" + result = translator.translate( + f"SELECT * FROM {basic_index} " + "WHERE category = 'a' AND (status = 'b' OR status = 'c')" + ) + + assert result.query_string == "@category:{a} (@status:{b}|@status:{c})" + + def test_or_with_nested_and(self, translator: Translator, basic_index: str): + """A OR (B AND C) -> ``(@a|@b @c)`` with the whole tree wrapped.""" + result = translator.translate( + f"SELECT * FROM {basic_index} " + "WHERE category = 'a' OR (status = 'b' AND price > 50)" + ) + + assert result.query_string == "(@category:{a}|@status:{b} @price:[(50 +inf])" + + def test_or_group_first_then_and(self, translator: Translator, basic_index: str): + """(B OR C) AND A -> ``(@b|@c) @a`` keeps the leading OR group.""" + result = translator.translate( + f"SELECT * FROM {basic_index} " + "WHERE (status = 'b' OR status = 'c') AND category = 'a'" + ) + + assert result.query_string == "(@status:{b}|@status:{c}) @category:{a}" + + def test_chained_ands_with_trailing_or_group( + self, translator: Translator, basic_index: str + ): + """A AND B AND C AND (D OR E) keeps the OR group only around D|E.""" + result = translator.translate( + f"SELECT * FROM {basic_index} " + "WHERE category = 'a' AND status = 'b' AND price > 10 " + "AND (title = 'd' OR title = 'e')" + ) + + assert ( + result.query_string + == "@category:{a} @status:{b} @price:[(10 +inf] (@title:d|@title:e)" + ) + + def test_two_or_groups_anded(self, translator: Translator, basic_index: str): + """(A OR B) AND (C OR D) keeps both OR groups parenthesized.""" + result = translator.translate( + f"SELECT * FROM {basic_index} " + "WHERE (category = 'a' OR category = 'b') " + "AND (status = 'c' OR status = 'd')" + ) + + assert ( + result.query_string + == "(@category:{a}|@category:{b}) (@status:{c}|@status:{d})" + ) + + def test_pure_and_chain_unchanged(self, translator: Translator, basic_index: str): + """A AND B AND C still renders as space-joined without parens.""" + result = translator.translate( + f"SELECT * FROM {basic_index} WHERE category = 'a' " + "AND status = 'b' AND price > 10" + ) + + assert result.query_string == "@category:{a} @status:{b} @price:[(10 +inf]" + + def test_pure_or_chain_unchanged(self, translator: Translator, basic_index: str): + """A OR B OR C still renders as a single pipe-joined OR group.""" + result = translator.translate( + f"SELECT * FROM {basic_index} " + "WHERE category = 'a' OR category = 'b' OR category = 'c'" + ) + + assert result.query_string == "(@category:{a}|@category:{b}|@category:{c})" + + class TestTranslatorAggregate: """Tests for FT.AGGREGATE translation.""" @@ -344,6 +426,71 @@ def test_count_distinct_reducer(self, translator: Translator, basic_index: str): assert "AS" in args assert "unique_titles" in args + def test_sql_count_distinct_routes_to_count_distinct( + self, translator: Translator, basic_index: str + ): + """SQL COUNT(DISTINCT x) emits REDUCE COUNT_DISTINCT 1 @x, not COUNT 0.""" + result = translator.translate( + f"SELECT category, COUNT(DISTINCT title) AS unique_titles " + f"FROM {basic_index} GROUP BY category" + ) + + assert result.command == "FT.AGGREGATE" + args = result.args + reduce_idx = args.index("REDUCE") + assert args[reduce_idx + 1] == "COUNT_DISTINCT" + assert args[reduce_idx + 2] == "1" + assert args[reduce_idx + 3] == "@title" + assert args[reduce_idx + 4] == "AS" + assert args[reduce_idx + 5] == "unique_titles" + + def test_sql_count_distinct_global_aggregation( + self, translator: Translator, basic_index: str + ): + """Global COUNT(DISTINCT x) (no GROUP BY) still emits COUNT_DISTINCT.""" + result = translator.translate( + f"SELECT COUNT(DISTINCT title) AS n FROM {basic_index}" + ) + + assert result.command == "FT.AGGREGATE" + args = result.args + # GROUPBY 0 for global aggregation + groupby_idx = args.index("GROUPBY") + assert args[groupby_idx + 1] == "0" + reduce_idx = args.index("REDUCE") + assert args[reduce_idx + 1] == "COUNT_DISTINCT" + assert args[reduce_idx + 2] == "1" + assert args[reduce_idx + 3] == "@title" + + def test_sql_count_distinct_matches_count_distinct_function( + self, translator: Translator, basic_index: str + ): + """COUNT(DISTINCT x) and COUNT_DISTINCT(x) emit equivalent reducers.""" + sql_distinct = translator.translate( + f"SELECT category, COUNT(DISTINCT title) AS n " + f"FROM {basic_index} GROUP BY category" + ) + redis_distinct = translator.translate( + f"SELECT category, COUNT_DISTINCT(title) AS n " + f"FROM {basic_index} GROUP BY category" + ) + + assert sql_distinct.args == redis_distinct.args + + def test_sql_sum_distinct_raises(self, translator: Translator, basic_index: str): + """SUM(DISTINCT x) is rejected — no native RediSearch equivalent.""" + with pytest.raises(ValueError, match="DISTINCT"): + translator.translate(f"SELECT SUM(DISTINCT price) FROM {basic_index}") + + def test_sql_count_distinct_multi_column_raises( + self, translator: Translator, basic_index: str + ): + """COUNT(DISTINCT a, b) is rejected — multi-column DISTINCT unsupported.""" + with pytest.raises(ValueError, match="single column"): + translator.translate( + f"SELECT COUNT(DISTINCT title, category) FROM {basic_index}" + ) + def test_quantile_reducer(self, translator: Translator, basic_index: str): """QUANTILE(field, value) should generate REDUCE QUANTILE 2 @field value.""" result = translator.translate( @@ -429,7 +576,7 @@ def test_double_negation_cancels(self, translator: Translator, basic_index: str) f"SELECT * FROM {basic_index} WHERE NOT title != 'good'" ) - assert result.query_string == '@title:"good"' + assert result.query_string == "@title:good" class TestTranslatorOutput: