diff --git a/sql_redis/analyzer.py b/sql_redis/analyzer.py index a6eeb3a..c02f70d 100644 --- a/sql_redis/analyzer.py +++ b/sql_redis/analyzer.py @@ -9,6 +9,7 @@ ComputedField, Condition, DateFunctionSpec, + HybridSearchSpec, ParsedQuery, ) @@ -22,6 +23,19 @@ class VectorSearchAnalysis: alias: str +@dataclass +class HybridSearchAnalysis: + """Analyzed FT.HYBRID fusion search details. + + Wraps the parsed spec and adds the resolved KNN ``k`` (from LIMIT). Field + types are validated during analysis (text leg must be TEXT, vector leg must + be VECTOR). + """ + + spec: HybridSearchSpec + k: int + + @dataclass class AnalyzedQuery: """Result of analyzing a parsed SQL query with schema context.""" @@ -34,6 +48,7 @@ class AnalyzedQuery: groupby_fields: list[str] = field(default_factory=list) is_global_aggregation: bool = False vector_search: VectorSearchAnalysis | None = None + hybrid_search: HybridSearchAnalysis | None = None has_prefilter: bool = False def get_field_type(self, field_name: str) -> str | None: @@ -121,6 +136,11 @@ def analyze(self, parsed: ParsedQuery) -> AnalyzedQuery: if parsed.vector_search: referenced_fields.add(parsed.vector_search.field) + # Fields from hybrid fusion search (both legs) + if parsed.hybrid_search: + referenced_fields.add(parsed.hybrid_search.vector_field) + referenced_fields.add(parsed.hybrid_search.text_field) + # Fields from date functions (YEAR, MONTH, etc.) for date_func in parsed.date_functions: referenced_fields.add(date_func.field) @@ -141,6 +161,10 @@ def analyze(self, parsed: ParsedQuery) -> AnalyzedQuery: # KNN similarity; the alias is a computed column, not an indexed # field, so it must not be looked up in the schema. alias_names.add(parsed.vector_search.alias) + if parsed.hybrid_search is not None and parsed.hybrid_search.alias: + # ORDER BY sorts by the fused score; like the + # vector alias, it is a computed column, not an indexed field. + alias_names.add(parsed.hybrid_search.alias) # Fields from GROUP BY (exclude aliases since they're computed) for field_name in parsed.groupby_fields: @@ -180,4 +204,27 @@ def analyze(self, parsed: ParsedQuery) -> AnalyzedQuery: # Has prefilter if there are conditions result.has_prefilter = len(parsed.conditions) > 0 + # Analyze hybrid fusion search + if parsed.hybrid_search: + spec = parsed.hybrid_search + vector_type = schema.get(spec.vector_field) + if vector_type != "VECTOR": + raise ValueError( + f"hybrid_vector_search() vector leg field " + f"'{spec.vector_field}' must be a VECTOR field, " + f"got {vector_type}." + ) + text_type = schema.get(spec.text_field) + if text_type != "TEXT": + raise ValueError( + f"hybrid_vector_search() text leg field " + f"'{spec.text_field}' must be a TEXT field, got {text_type}." + ) + result.hybrid_search = HybridSearchAnalysis( + spec=spec, + k=parsed.limit or spec.k or 10, + ) + # Conditions become per-leg filters. + result.has_prefilter = len(parsed.conditions) > 0 + return result diff --git a/sql_redis/executor.py b/sql_redis/executor.py index 6672d6c..1790ff7 100644 --- a/sql_redis/executor.py +++ b/sql_redis/executor.py @@ -118,6 +118,54 @@ class QueryResult: class _ScoreParseMixin: """Shared helpers for score-related response parsing.""" + @staticmethod + def _is_unknown_command(error_msg: str) -> bool: + """Return True when a ResponseError means the command is unsupported.""" + lowered = error_msg.lower() + return "unknown command" in lowered or "unknown subcommand" in lowered + + @staticmethod + def _parse_hybrid_reply(raw_result) -> tuple[Any, list[dict]]: + """Parse an FT.HYBRID reply into (count, rows). + + FT.HYBRID does not use the FT.AGGREGATE array shape. The reply is a map + ``{total_results: N, results: [...], warnings: [...], ...}`` that arrives + either as a dict (redis-py 8.x / RESP3) or as a flat list + (``[total_results, N, results, [...], ...]``) on RESP2. Each result row + is likewise a dict or a flat ``[field, val, ...]`` list. Keys/values may + be bytes or str depending on the client's decode_responses setting. + """ + if isinstance(raw_result, dict): + reply = raw_result + else: + reply = dict(zip(raw_result[::2], raw_result[1::2])) + + def _field(name: str): + if name in reply: + return reply[name] + return reply.get(name.encode()) + + count = _field("total_results") or 0 + results = _field("results") or [] + rows = [ + dict(row) if isinstance(row, dict) else dict(zip(row[::2], row[1::2])) + for row in results + ] + return count, rows + + @staticmethod + def _inject_vector_param(cmd: list[str | bytes], vector_param: bytes) -> None: + """Replace the vector PARAMS value with the actual bytes, in place. + + Only the ``$vector`` token in the PARAMS value position (the one + preceded by the param name ``vector``) is replaced. Query-side + references to ``$vector`` (FT.SEARCH KNN expressions, FT.HYBRID VSIM) + must stay as parameter references so Redis resolves them from PARAMS. + """ + for i, arg in enumerate(cmd): + if arg == "$vector" and i > 0 and cmd[i - 1] == "vector": + cmd[i] = vector_param + @staticmethod def _has_return_0(args: list[str]) -> bool: """Return True when the args contain 'RETURN 0' (no document fields).""" @@ -188,17 +236,24 @@ def execute(self, sql: str, *, params: dict | None = None) -> QueryResult: vector_param = value break - # Replace $vector placeholder with actual bytes + # Replace the $vector PARAMS value with actual bytes (query/VSIM + # references to $vector stay as parameter references). if vector_param: - for i, arg in enumerate(cmd): - if arg == "$vector": - cmd[i] = vector_param + self._inject_vector_param(cmd, vector_param) # Execute command try: raw_result = self._client.execute_command(*cmd) except redis.ResponseError as e: error_msg = str(e) + if translated.command == "FT.HYBRID" and self._is_unknown_command( + error_msg + ): + raise redis.ResponseError( + f"{error_msg}. hybrid_vector_search() translates to FT.HYBRID, " + "which requires Redis 8.4+ (RediSearch with hybrid search) " + "and redis-py >= 7.1.0." + ) from e _ismissing_signatures = ( "Unknown function", "No such function", @@ -217,10 +272,14 @@ def execute(self, sql: str, *, params: dict | None = None) -> QueryResult: raise # Parse result based on command type - count = raw_result[0] if raw_result else 0 - rows = [] - - if translated.command == "FT.SEARCH": + # FT.SEARCH/FT.AGGREGATE replies are arrays with the count first; + # FT.HYBRID replies are maps (dict) and set count during parsing below. + count = raw_result[0] if isinstance(raw_result, list) and raw_result else 0 + rows: list[dict] = [] + + if translated.command == "FT.HYBRID": + count, rows = self._parse_hybrid_reply(raw_result) + elif translated.command == "FT.SEARCH": # Use the explicit score_alias signal rather than scanning args # for the literal token "WITHSCORES", which could false-positive # if a returned field happened to be named "WITHSCORES". @@ -323,17 +382,24 @@ async def execute(self, sql: str, *, params: dict | None = None) -> QueryResult: vector_param = value break - # Replace $vector placeholder with actual bytes + # Replace the $vector PARAMS value with actual bytes (query/VSIM + # references to $vector stay as parameter references). if vector_param: - for i, arg in enumerate(cmd): - if arg == "$vector": - cmd[i] = vector_param + self._inject_vector_param(cmd, vector_param) # Execute command asynchronously try: raw_result = await self._client.execute_command(*cmd) except redis.ResponseError as e: error_msg = str(e) + if translated.command == "FT.HYBRID" and self._is_unknown_command( + error_msg + ): + raise redis.ResponseError( + f"{error_msg}. hybrid_vector_search() translates to FT.HYBRID, " + "which requires Redis 8.4+ (RediSearch with hybrid search) " + "and redis-py >= 7.1.0." + ) from e _ismissing_signatures = ( "Unknown function", "No such function", @@ -352,10 +418,14 @@ async def execute(self, sql: str, *, params: dict | None = None) -> QueryResult: raise # Parse result based on command type - count = raw_result[0] if raw_result else 0 - rows = [] - - if translated.command == "FT.SEARCH": + # FT.SEARCH/FT.AGGREGATE replies are arrays with the count first; + # FT.HYBRID replies are maps (dict) and set count during parsing below. + count = raw_result[0] if isinstance(raw_result, list) and raw_result else 0 + rows: list[dict] = [] + + if translated.command == "FT.HYBRID": + count, rows = self._parse_hybrid_reply(raw_result) + elif translated.command == "FT.SEARCH": with_scores = translated.score_alias is not None no_content = self._has_return_0(translated.args) diff --git a/sql_redis/parser.py b/sql_redis/parser.py index 61ba228..5c73564 100644 --- a/sql_redis/parser.py +++ b/sql_redis/parser.py @@ -156,6 +156,34 @@ class VectorSearchSpec: k: int | None = None +@dataclass +class HybridSearchSpec: + """Specification for FT.HYBRID fusion search. + + Populated when a SELECT projection contains + ``hybrid_vector_search(, , )``. The vector + leg reuses ``cosine_distance``/``vector_distance``/``vector_range`` and the + text leg reuses ``fulltext``; the third argument is ``rrf(...)`` or + ``linear(...)``. Distinct from ``VectorSearchSpec`` (pre-filter hybrid + search), which fuses nothing. + """ + + vector_field: str + text_field: str + text_query: str + alias: str = "hybrid_score" # combined-score column (the SELECT alias) + text_scorer: str = "BM25STD" # SEARCH scorer + vector_method: str = "KNN" # "KNN" | "RANGE" + ef_runtime: int | None = None # KNN tuning knob + radius: float | None = None # RANGE knob + epsilon: float | None = None # RANGE knob + combine_method: str = "RRF" # "RRF" | "LINEAR" + rrf_constant: int | None = None # RRF knob (server default 60) + rrf_window: int | None = None # fusion window knob (server default 20) + linear_alpha: float | None = None # LINEAR knob; beta derived as (1 - alpha) + k: int | None = None # KNN K, derived from LIMIT by the analyzer + + @dataclass class Condition: """A WHERE condition.""" @@ -256,6 +284,7 @@ class ParsedQuery: computed_fields: list[ComputedField] = dataclasses.field(default_factory=list) date_functions: list[DateFunctionSpec] = dataclasses.field(default_factory=list) vector_search: VectorSearchSpec | None = None + hybrid_search: HybridSearchSpec | None = None groupby_fields: list[str] = dataclasses.field(default_factory=list) orderby_fields: list[tuple[str, str]] = dataclasses.field( default_factory=list @@ -531,7 +560,10 @@ def _process_select_expression_inner( "quantile", "random_sample", } - if func_name_lower == "vector_distance": + if func_name_lower == "hybrid_vector_search": + # FT.HYBRID fusion: hybrid_vector_search(vector_leg, text_leg, combine) + self._process_hybrid_vector_search(expression, result, alias) + elif func_name_lower == "vector_distance": # Extract the vector field name from first argument if expression.expressions: first_arg = expression.expressions[0] @@ -634,6 +666,173 @@ def _process_vector_distance( alias=alias or "vector_distance", ) + def _process_hybrid_vector_search( + self, expression, result: ParsedQuery, alias: str | None + ) -> None: + """Process a hybrid_vector_search() call into a HybridSearchSpec. + + Shape: hybrid_vector_search(, , ) where + the vector leg is cosine_distance/vector_distance/vector_range, the text + leg is fulltext(field, 'query'), and the optional combine is rrf()/linear(). + """ + args = expression.expressions + if len(args) < 2: + raise ValueError( + "hybrid_vector_search() requires a vector leg and a text leg, " + "e.g. hybrid_vector_search(cosine_distance(field, :vec), " + "fulltext(field, 'query'), rrf())." + ) + + vector_field, vector_method, ef_runtime, radius, epsilon = ( + self._parse_hybrid_vector_leg(args[0]) + ) + text_field, text_query, text_scorer = self._parse_hybrid_text_leg(args[1]) + combine_method, rrf_constant, rrf_window, linear_alpha = ( + self._parse_hybrid_combine(args[2] if len(args) >= 3 else None) + ) + + result.hybrid_search = HybridSearchSpec( + vector_field=vector_field, + text_field=text_field, + text_query=text_query, + alias=alias or "hybrid_score", + text_scorer=text_scorer, + vector_method=vector_method, + ef_runtime=ef_runtime, + radius=radius, + epsilon=epsilon, + combine_method=combine_method, + rrf_constant=rrf_constant, + rrf_window=rrf_window, + linear_alpha=linear_alpha, + ) + + def _extract_function_kwargs(self, func_expr) -> dict[str, object]: + """Return a {name: value} mapping for ``name => value`` kwargs in a call.""" + kwargs: dict[str, object] = {} + for child in func_expr.expressions: + if isinstance(child, exp.Kwarg): + key = getattr(child.this, "name", None) or str(child.this) + kwargs[key.lower()] = self._extract_literal_value(child.expression) + return kwargs + + def _parse_hybrid_vector_leg(self, leg): + """Parse the vector leg of hybrid_vector_search(). + + Returns (field, method, ef_runtime, radius, epsilon). + """ + # cosine_distance()/L2 distance parse as builtins (this == field column). + if isinstance(leg, (exp.CosineDistance, exp.Distance)): + if not isinstance(leg.this, exp.Column): + raise ValueError( + "hybrid_vector_search() vector leg field must be a column name." + ) + return leg.this.name, "KNN", None, None, None + + # vector_distance()/vector_range() parse as anonymous functions, which + # (unlike the cosine_distance builtin) accept extra tuning kwargs. + if isinstance(leg, exp.Anonymous): + name = leg.name.lower() + if name not in ("vector_distance", "cosine_distance", "vector_range"): + raise ValueError( + "hybrid_vector_search() first argument must be a vector leg " + "(cosine_distance, vector_distance, or vector_range), " + f"got {leg.name}()." + ) + field_node = leg.expressions[0] if leg.expressions else None + if not isinstance(field_node, exp.Column): + raise ValueError( + "hybrid_vector_search() vector leg field must be a column name." + ) + kwargs = self._extract_function_kwargs(leg) + if name == "vector_range": + radius = kwargs.get("radius") + if radius is None: + raise ValueError( + "vector_range() requires a radius, e.g. " + "vector_range(field, :vec, radius => 0.2)." + ) + epsilon = kwargs.get("epsilon") + return ( + field_node.name, + "RANGE", + None, + float(radius), + float(epsilon) if epsilon is not None else None, + ) + ef = kwargs.get("ef_runtime") + return ( + field_node.name, + "KNN", + int(ef) if ef is not None else None, + None, + None, + ) + + raise ValueError( + "hybrid_vector_search() first argument must be a vector leg, " + "e.g. cosine_distance(field, :vec)." + ) + + def _parse_hybrid_text_leg(self, leg): + """Parse the fulltext() text leg. Returns (field, query, scorer).""" + if not isinstance(leg, exp.Anonymous) or leg.name.lower() != "fulltext": + raise ValueError( + "hybrid_vector_search() second argument must be " + "fulltext(field, 'query')." + ) + if len(leg.expressions) < 2: + raise ValueError( + "fulltext() in hybrid_vector_search() requires a field and a " + "query string." + ) + field_node = leg.expressions[0] + if not isinstance(field_node, exp.Column): + raise ValueError("fulltext() field must be a column name.") + query_val = self._extract_literal_value(leg.expressions[1]) + if not isinstance(query_val, str): + raise ValueError("fulltext() query must be a string literal.") + kwargs = self._extract_function_kwargs(leg) + scorer = kwargs.get("scorer", "BM25STD") + return field_node.name, query_val, str(scorer) + + def _parse_hybrid_combine(self, combine): + """Parse the optional rrf()/linear() combine. + + Returns (method, rrf_constant, rrf_window, linear_alpha). When ``combine`` + is None the server default (RRF) applies. + """ + if combine is None: + return "RRF", None, None, None + if not isinstance(combine, exp.Anonymous): + raise ValueError( + "hybrid_vector_search() fusion argument must be rrf(...) or " + "linear(...)." + ) + name = combine.name.lower() + kwargs = self._extract_function_kwargs(combine) + window = kwargs.get("window") + rrf_window = int(window) if window is not None else None + if name == "rrf": + constant = kwargs.get("constant") + return ( + "RRF", + int(constant) if constant is not None else None, + rrf_window, + None, + ) + if name == "linear": + alpha = kwargs.get("alpha") + return ( + "LINEAR", + None, + rrf_window, + float(alpha) if alpha is not None else None, + ) + raise ValueError( + f"Unknown fusion method {combine.name}(). Use rrf(...) or linear(...)." + ) + def _process_geo_distance_select( self, expression, result: ParsedQuery, alias: str | None ) -> None: diff --git a/sql_redis/schema.py b/sql_redis/schema.py index 936954a..1d89f31 100644 --- a/sql_redis/schema.py +++ b/sql_redis/schema.py @@ -3,7 +3,7 @@ from __future__ import annotations import asyncio -from typing import TYPE_CHECKING, Callable +from typing import TYPE_CHECKING, Any, Callable import redis @@ -11,41 +11,74 @@ import redis.asyncio as async_redis -def _parse_schema_from_info(info: list) -> dict[str, str]: - """Parse field types from FT.INFO response. +def _decode(value: Any) -> Any: + """Decode a bytes value to str; pass through everything else.""" + return value.decode("utf-8") if isinstance(value, bytes) else value - This is a pure function with no I/O operations, shared by both - sync and async schema registries. + +def _extract_attributes(info: Any) -> list: + """Pull the ``attributes`` section out of an FT.INFO reply. + + Handles both reply shapes: the RESP2 flat list + (``[..., 'attributes', [...], ...]``) and the redis-py 8.x / RESP3 map + (``{b'attributes': [...], ...}``), whose keys may be bytes or str. + """ + if isinstance(info, dict): + for key, val in info.items(): + if _decode(key) == "attributes": + return val or [] + return [] + for i, item in enumerate(info): + if _decode(item) == "attributes": + return info[i + 1] + return [] + + +def _attribute_name_and_type(attr: Any) -> tuple[str | None, str | None]: + """Extract (field_name, field_type) from a single FT.INFO attribute. + + An attribute is either a dict (redis-py 8.x), e.g. + ``{b'attribute': b'title', b'type': b'TEXT', ...}``, or a flat list, e.g. + ``[b'identifier', b'title', b'attribute', b'title', b'type', b'TEXT', ...]``. + """ + if isinstance(attr, dict): + name = next( + (_decode(v) for k, v in attr.items() if _decode(k) == "attribute"), None + ) + ftype = next( + (_decode(v) for k, v in attr.items() if _decode(k) == "type"), None + ) + return name, ftype + + name = None + ftype = None + for j, val in enumerate(attr): + val_str = _decode(val) + if val_str == "attribute" and j + 1 < len(attr): + name = _decode(attr[j + 1]) + if val_str == "type" and j + 1 < len(attr): + ftype = _decode(attr[j + 1]) + return name, ftype + + +def _parse_schema_from_info(info: Any) -> dict[str, str]: + """Parse field types from an FT.INFO response. + + This is a pure function with no I/O operations, shared by both the sync + and async schema registries. It accepts both the RESP2 list reply and the + redis-py 8.x / RESP3 map reply (see ``_extract_attributes``). Args: - info: The raw response from FT.INFO command. + info: The raw response from the FT.INFO command (list or dict). Returns: Dictionary mapping field names to their types (e.g., {"title": "TEXT"}). """ - schema = {} - # Find the 'attributes' section in the info response - for i, item in enumerate(info): - # Handle bytes or string comparison - item_str = item.decode("utf-8") if isinstance(item, bytes) else item - if item_str == "attributes": - attributes = info[i + 1] - for attr in attributes: - field_name = None - field_type = None - # Each attribute is a list like: - # [b'identifier', b'title', b'attribute', b'title', b'type', b'TEXT', ...] - for j, val in enumerate(attr): - val_str = val.decode("utf-8") if isinstance(val, bytes) else val - if val_str == "attribute" and j + 1 < len(attr): - fn = attr[j + 1] - field_name = fn.decode("utf-8") if isinstance(fn, bytes) else fn - if val_str == "type" and j + 1 < len(attr): - ft = attr[j + 1] - field_type = ft.decode("utf-8") if isinstance(ft, bytes) else ft - if field_name and field_type: - schema[field_name] = field_type - break + schema: dict[str, str] = {} + for attr in _extract_attributes(info): + field_name, field_type = _attribute_name_and_type(attr) + if field_name and field_type: + schema[field_name] = field_type return schema diff --git a/sql_redis/translator.py b/sql_redis/translator.py index ac7673f..372ce02 100644 --- a/sql_redis/translator.py +++ b/sql_redis/translator.py @@ -21,23 +21,39 @@ from sql_redis.schema import AsyncSchemaRegistry, SchemaRegistry +def _fmt_num(value: float) -> str: + """Format a numeric FT.HYBRID parameter without trailing float noise.""" + return f"{value:g}" + + @dataclass class TranslatedQuery: """Result of translating SQL to Redis.""" - command: str # FT.SEARCH or FT.AGGREGATE + command: str # FT.SEARCH, FT.AGGREGATE, or FT.HYBRID index: str query_string: str args: list[str] = field(default_factory=list) params: dict[str, object] = field(default_factory=dict) # Named parameters score_alias: str | None = None # Alias for score column when WITHSCORES is used + # FT.HYBRID has no single top-level query string (its SEARCH/VSIM legs live + # in args), so it is rendered without the quoted query_string slot. + is_hybrid: bool = False def to_command_list(self) -> list[str]: """Return as a list suitable for redis.execute_command().""" + if self.is_hybrid: + return [self.command, self.index, *self.args] return [self.command, self.index, self.query_string, *self.args] def to_command_string(self) -> str: """Return as a human-readable command string.""" + if self.is_hybrid: + # Quote multi-word tokens (the SEARCH query and filter expressions) + # for readability; execution uses to_command_list (raw tokens). + rendered = [self.command, self.index] + rendered.extend(f'"{tok}"' if " " in tok else tok for tok in self.args) + return " ".join(rendered) parts = [self.command, self.index, f'"{self.query_string}"'] parts.extend(self.args) return " ".join(parts) @@ -117,6 +133,11 @@ def _build_command(self, analyzed: AnalyzedQuery) -> TranslatedQuery: """Build the Redis command from analyzed query.""" parsed = analyzed.parsed + # FT.HYBRID fusion is a dedicated command path with its own + # SEARCH/VSIM/COMBINE layout, distinct from FT.SEARCH/FT.AGGREGATE. + if analyzed.hybrid_search is not None: + return self._build_hybrid(analyzed) + # Validate: geo_distance cannot be combined with OR # Geo filters are applied as top-level command args (GEOFILTER/FILTER) and # are not part of the boolean expression. Combining with OR would change @@ -510,6 +531,106 @@ def _build_search( score_alias=(parsed.scoring.alias if parsed.scoring is not None else None), ) + def _build_hybrid(self, analyzed: AnalyzedQuery) -> TranslatedQuery: + """Build an FT.HYBRID command from a hybrid_vector_search() query. + + Layout: ``FT.HYBRID index SEARCH "" [SCORER s] VSIM @field $vector + [FILTER n ] (KNN|RANGE ...) COMBINE (RRF|LINEAR ...) [LOAD ...] + [LIMIT ...] PARAMS 2 vector $vector DIALECT 2``. The WHERE clause is + applied to both legs: folded into the SEARCH query and emitted as the + VSIM FILTER so the candidate sets agree. + """ + parsed = analyzed.parsed + hybrid = analyzed.hybrid_search + assert hybrid is not None # guaranteed by the caller's dispatch check + spec = hybrid.spec + args: list[str] = [] + + # WHERE clause becomes the shared per-leg filter (reuses the standard + # query-string builder; vector_search is unset on the hybrid path). + filter_expr = self._build_query_string(analyzed) + has_filter = bool(filter_expr) and filter_expr != "*" + + # SEARCH leg: tokenized text query, with the filter folded in. + text_query = self._query_builder.build_text_condition( + spec.text_field, "FULLTEXT", spec.text_query + ) + search_query = f"({text_query}) ({filter_expr})" if has_filter else text_query + args.extend(["SEARCH", search_query]) + if spec.text_scorer: + args.extend(["SCORER", spec.text_scorer]) + + # VSIM leg: vector field + param placeholder, then the KNN/RANGE method + # clause, then the optional per-leg filter (the method must precede + # FILTER in the VSIM grammar). + args.extend(["VSIM", f"@{spec.vector_field}", "$vector"]) + if spec.vector_method == "RANGE": + assert spec.radius is not None # parser requires radius for RANGE + method_tokens = ["RADIUS", _fmt_num(spec.radius)] + if spec.epsilon is not None: + method_tokens.extend(["EPSILON", _fmt_num(spec.epsilon)]) + args.extend(["RANGE", str(len(method_tokens)), *method_tokens]) + else: + method_tokens = ["K", str(hybrid.k)] + if spec.ef_runtime is not None: + method_tokens.extend(["EF_RUNTIME", str(spec.ef_runtime)]) + args.extend(["KNN", str(len(method_tokens)), *method_tokens]) + if has_filter: + args.extend(["FILTER", "1", filter_expr]) + + # COMBINE: RRF (default) or LINEAR. The combined score is yielded under + # the SELECT alias so it comes back as a column. + combine_tokens: list[str] = [] + if spec.combine_method == "LINEAR": + if spec.linear_alpha is not None: + combine_tokens.extend( + [ + "ALPHA", + _fmt_num(spec.linear_alpha), + "BETA", + _fmt_num(1 - spec.linear_alpha), + ] + ) + if spec.rrf_window is not None: + combine_tokens.extend(["WINDOW", str(spec.rrf_window)]) + combine_tokens.extend(["YIELD_SCORE_AS", spec.alias]) + args.extend( + ["COMBINE", "LINEAR", str(len(combine_tokens)), *combine_tokens] + ) + else: + if spec.rrf_constant is not None: + combine_tokens.extend(["CONSTANT", str(spec.rrf_constant)]) + if spec.rrf_window is not None: + combine_tokens.extend(["WINDOW", str(spec.rrf_window)]) + combine_tokens.extend(["YIELD_SCORE_AS", spec.alias]) + args.extend(["COMBINE", "RRF", str(len(combine_tokens)), *combine_tokens]) + + # LOAD: the SELECT columns (field names require an @ prefix). + load_fields = [f for f in parsed.fields if f != "*"] + if load_fields: + args.extend( + ["LOAD", str(len(load_fields)), *(f"@{f}" for f in load_fields)] + ) + + # LIMIT (final row cut; KNN K is set separately above). + if parsed.limit is not None: + args.extend(["LIMIT", str(parsed.offset or 0), str(parsed.limit)]) + + # PARAMS placeholder — the executor injects the vector bytes for "vector". + # Note: FT.HYBRID rejects an explicit DIALECT argument (the server uses + # its configured search-default-dialect), so none is appended here. + args.extend(["PARAMS", "2", "vector", "$vector"]) + + return TranslatedQuery( + command="FT.HYBRID", + index=parsed.index, + query_string="", + args=args, + params={"vector": None}, + score_alias=spec.alias, + is_hybrid=True, + ) + def _build_geo_filter_args(self, geo_cond: GeoDistanceCondition) -> list[str]: """Build GEOFILTER args from a GeoDistanceCondition.""" return [ diff --git a/tests/conftest.py b/tests/conftest.py index 7a11da2..c0e7524 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -9,8 +9,12 @@ @pytest.fixture(scope="module") def redis_container(): - """Create a Redis 8 container for testing.""" - with RedisContainer(image="redis:8.0.2") as container: + """Create a Redis 8 container for testing. + + Uses 8.4+ so FT.HYBRID (hybrid_vector_search) integration tests can run; + older versions cause those tests to skip via a server-capability check. + """ + with RedisContainer(image="redis:8.4") as container: yield container diff --git a/tests/test_ft_hybrid.py b/tests/test_ft_hybrid.py new file mode 100644 index 0000000..46b3a1d --- /dev/null +++ b/tests/test_ft_hybrid.py @@ -0,0 +1,821 @@ +"""TDD tests for FT.HYBRID support via the hybrid_vector_search() function. + +These tests are written ahead of the implementation (RAAE-1322). They define the +contract for translating + + hybrid_vector_search(, , ) + +into a native ``FT.HYBRID`` command (Redis 8.4+), which fuses an independently +ranked text search and vector search server-side (RRF or LINEAR). This is distinct +from the existing pre-filter hybrid search, where text is only a hard prefilter and +the ranking comes from the vector leg alone. + +Expected (not yet implemented) API contract +-------------------------------------------- +``ParsedQuery.hybrid_search``: ``HybridSearchSpec | None``, populated when the SELECT +projection contains ``hybrid_vector_search(...)``. The function composes the existing +``cosine_distance(field, :vec)`` (vector leg) and ``fulltext(field, 'query')`` (text +leg) functions, plus a third ``rrf(...)`` / ``linear(...)`` argument for fusion. + +``HybridSearchSpec`` fields: + vector_field: str # VSIM field + text_field: str # SEARCH field + text_query: str # SEARCH query string + text_scorer: str = "BM25STD" # SEARCH scorer + vector_method: str = "KNN" # "KNN" | "RANGE" + ef_runtime: int | None # KNN tuning knob + radius / epsilon: float|None # RANGE knobs + combine_method: str = "RRF" # "RRF" | "LINEAR" + rrf_constant: int | None # RRF knob (default 60) + rrf_window: int | None # RRF/LINEAR window knob (default 20) + linear_alpha: float | None # LINEAR knob; beta derived as (1 - alpha) + alias: str # combined-score column (the SELECT alias) + k: int | None # KNN K, derived from LIMIT + +``AnalyzedQuery.hybrid_search``: ``HybridSearchAnalysis | None`` with field types +resolved. ``Translator.translate(...)`` returns a ``TranslatedQuery`` whose +``command == "FT.HYBRID"``. +""" + +import struct + +import pytest +import redis + +from sql_redis.analyzer import Analyzer +from sql_redis.executor import Executor +from sql_redis.parser import SQLParser +from sql_redis.schema import SchemaRegistry +from sql_redis.translator import Translator + + +def float_vector_to_bytes(vector: list[float]) -> bytes: + """Convert a list of floats to binary format for Redis vector storage.""" + return struct.pack(f"{len(vector)}f", *vector) + + +def _hybrid_supported(client: redis.Redis) -> bool: + """Return True if the connected server understands FT.HYBRID (Redis 8.4+).""" + try: + client.execute_command("FT.HYBRID") + except redis.ResponseError as exc: + message = str(exc).lower() + # No args -> arity/syntax error means the command exists; only an + # "unknown command" reply means the server is too old. + return "unknown command" not in message and "unknown subcommand" not in message + return True + + +@pytest.fixture +def sample_schema() -> dict[str, dict[str, str]]: + """Schema with text, tag, vector, and numeric fields for hybrid tests.""" + return { + "items": { + "name": "TEXT", + "category": "TAG", + "description": "TEXT", + "price": "NUMERIC", + "embedding": "VECTOR", + } + } + + +@pytest.fixture(scope="module") +def hybrid_translator(redis_client: redis.Redis, items_index: str) -> Translator: + """Translator with the items index (text + tag + vector) loaded.""" + registry = SchemaRegistry(redis_client) + registry.load_all() + return Translator(registry) + + +@pytest.fixture(scope="module") +def hybrid_executor(redis_client: redis.Redis, items_data: str) -> Executor: + """Executor against the items index; skips when FT.HYBRID is unavailable.""" + if not _hybrid_supported(redis_client): + pytest.skip("FT.HYBRID requires Redis 8.4+ (the test container is 8.0.2)") + registry = SchemaRegistry(redis_client) + registry.load_all() + return Executor(redis_client, registry) + + +# A canonical hybrid query reused across layers. +HYBRID_SQL = ( + "SELECT name, description, " + "hybrid_vector_search(" + "cosine_distance(embedding, :vec), " + "fulltext(description, 'smartphone features'), " + "rrf()" + ") AS hybrid_score " + "FROM items " + "WHERE category = 'electronics' " + "ORDER BY hybrid_score DESC " + "LIMIT 5" +) + + +class TestHybridParserSelect: + """Parsing hybrid_vector_search() out of the SELECT clause.""" + + def test_detects_hybrid_search(self): + """A hybrid_vector_search() projection populates ParsedQuery.hybrid_search.""" + parser = SQLParser() + result = parser.parse(HYBRID_SQL) + + assert result.hybrid_search is not None + assert result.index == "items" + + def test_extracts_vector_and_text_legs(self): + """The vector and text legs are extracted from the nested functions.""" + parser = SQLParser() + result = parser.parse(HYBRID_SQL) + + spec = result.hybrid_search + assert spec.vector_field == "embedding" + assert spec.text_field == "description" + assert spec.text_query == "smartphone features" + + def test_combined_score_alias(self): + """The SELECT alias becomes the combined-score column name.""" + parser = SQLParser() + result = parser.parse(HYBRID_SQL) + + assert result.hybrid_search.alias == "hybrid_score" + + def test_defaults_rrf_and_bm25std(self): + """rrf() with no scorer override yields RRF + the default BM25STD scorer.""" + parser = SQLParser() + result = parser.parse(HYBRID_SQL) + + spec = result.hybrid_search + assert spec.combine_method == "RRF" + assert spec.text_scorer == "BM25STD" + + def test_defaults_to_knn_vector_method(self): + """cosine_distance() selects the KNN vector method by default.""" + parser = SQLParser() + result = parser.parse(HYBRID_SQL) + + assert result.hybrid_search.vector_method == "KNN" + + def test_where_condition_is_preserved(self): + """The WHERE clause is retained as the per-leg filter.""" + parser = SQLParser() + result = parser.parse(HYBRID_SQL) + + assert len(result.conditions) == 1 + assert result.conditions[0].field == "category" + + +class TestHybridParserKnobs: + """Parsing the full set of fusion / leg knobs.""" + + def test_linear_alpha(self): + """linear(alpha => 0.3) selects LINEAR fusion with the given alpha.""" + parser = SQLParser() + result = parser.parse( + "SELECT name, hybrid_vector_search(" + "cosine_distance(embedding, :vec), " + "fulltext(description, 'phone'), " + "linear(alpha => 0.3)" + ") AS score FROM items LIMIT 5" + ) + + spec = result.hybrid_search + assert spec.combine_method == "LINEAR" + assert spec.linear_alpha == 0.3 + + def test_rrf_constant_and_window(self): + """rrf(constant => 60, window => 20) captures both RRF knobs.""" + parser = SQLParser() + result = parser.parse( + "SELECT name, hybrid_vector_search(" + "cosine_distance(embedding, :vec), " + "fulltext(description, 'phone'), " + "rrf(constant => 60, window => 20)" + ") AS score FROM items LIMIT 5" + ) + + spec = result.hybrid_search + assert spec.rrf_constant == 60 + assert spec.rrf_window == 20 + + def test_custom_text_scorer(self): + """fulltext(..., scorer => 'TFIDF') overrides the default scorer.""" + parser = SQLParser() + result = parser.parse( + "SELECT name, hybrid_vector_search(" + "cosine_distance(embedding, :vec), " + "fulltext(description, 'phone', scorer => 'TFIDF'), " + "rrf()" + ") AS score FROM items LIMIT 5" + ) + + assert result.hybrid_search.text_scorer == "TFIDF" + + def test_knn_ef_runtime(self): + """vector_distance(..., ef_runtime => 20) captures the KNN tuning knob. + + The tuning knob rides on vector_distance() rather than cosine_distance(): + sqlglot models cosine_distance as a built-in capped at 2 args, while + vector_distance() parses as an anonymous function and accepts the extra arg. + """ + parser = SQLParser() + result = parser.parse( + "SELECT name, hybrid_vector_search(" + "vector_distance(embedding, :vec, ef_runtime => 20), " + "fulltext(description, 'phone'), " + "rrf()" + ") AS score FROM items LIMIT 5" + ) + + assert result.hybrid_search.ef_runtime == 20 + + +class TestHybridParserValidation: + """Error handling for malformed hybrid_vector_search() calls.""" + + def test_missing_text_leg_raises(self): + """hybrid_vector_search() without a text leg is rejected.""" + parser = SQLParser() + with pytest.raises(ValueError): + parser.parse( + "SELECT name, hybrid_vector_search(" + "cosine_distance(embedding, :vec)" + ") AS score FROM items LIMIT 5" + ) + + def test_combine_omitted_defaults_to_rrf(self): + """A two-argument call (no combine) defaults to RRF fusion.""" + parser = SQLParser() + result = parser.parse( + "SELECT name, hybrid_vector_search(" + "cosine_distance(embedding, :vec), " + "fulltext(description, 'phone')" + ") AS score FROM items LIMIT 5" + ) + + assert result.hybrid_search.combine_method == "RRF" + assert result.hybrid_search.rrf_constant is None + + def test_non_distance_vector_leg_raises(self): + """A bare column as the vector leg is rejected.""" + parser = SQLParser() + with pytest.raises(ValueError, match="vector leg"): + parser.parse( + "SELECT name, hybrid_vector_search(" + "embedding, fulltext(description, 'phone'), rrf()" + ") AS score FROM items LIMIT 5" + ) + + def test_unknown_vector_function_raises(self): + """An unrecognized vector-leg function is rejected.""" + parser = SQLParser() + with pytest.raises(ValueError, match="vector leg"): + parser.parse( + "SELECT name, hybrid_vector_search(" + "made_up(embedding, :vec), fulltext(description, 'phone'), rrf()" + ") AS score FROM items LIMIT 5" + ) + + def test_text_leg_not_fulltext_raises(self): + """A non-fulltext() text leg is rejected.""" + parser = SQLParser() + with pytest.raises(ValueError, match="fulltext"): + parser.parse( + "SELECT name, hybrid_vector_search(" + "cosine_distance(embedding, :vec), made_up(description, 'p'), rrf()" + ") AS score FROM items LIMIT 5" + ) + + def test_text_leg_non_string_query_raises(self): + """A non-string fulltext() query is rejected.""" + parser = SQLParser() + with pytest.raises(ValueError, match="string literal"): + parser.parse( + "SELECT name, hybrid_vector_search(" + "cosine_distance(embedding, :vec), fulltext(description, 123), rrf()" + ") AS score FROM items LIMIT 5" + ) + + def test_unknown_combine_method_raises(self): + """An unrecognized fusion function is rejected.""" + parser = SQLParser() + with pytest.raises(ValueError, match="rrf|linear"): + parser.parse( + "SELECT name, hybrid_vector_search(" + "cosine_distance(embedding, :vec), fulltext(description, 'p'), foo()" + ") AS score FROM items LIMIT 5" + ) + + def test_cosine_distance_non_column_field_raises(self): + """A literal (not a column) as the cosine_distance field is rejected.""" + parser = SQLParser() + with pytest.raises(ValueError, match="column name"): + parser.parse( + "SELECT name, hybrid_vector_search(" + "cosine_distance('lit', :vec), fulltext(description, 'p'), rrf()" + ") AS score FROM items LIMIT 5" + ) + + def test_vector_distance_non_column_field_raises(self): + """A literal (not a column) as the vector_distance field is rejected.""" + parser = SQLParser() + with pytest.raises(ValueError, match="column name"): + parser.parse( + "SELECT name, hybrid_vector_search(" + "vector_distance(123, :vec), fulltext(description, 'p'), rrf()" + ") AS score FROM items LIMIT 5" + ) + + def test_fulltext_insufficient_args_raises(self): + """fulltext() with only a field (no query) is rejected.""" + parser = SQLParser() + with pytest.raises(ValueError, match="field and a"): + parser.parse( + "SELECT name, hybrid_vector_search(" + "cosine_distance(embedding, :vec), fulltext(description), rrf()" + ") AS score FROM items LIMIT 5" + ) + + def test_fulltext_non_column_field_raises(self): + """A literal (not a column) as the fulltext field is rejected.""" + parser = SQLParser() + with pytest.raises(ValueError, match="column name"): + parser.parse( + "SELECT name, hybrid_vector_search(" + "cosine_distance(embedding, :vec), fulltext(123, 'p'), rrf()" + ") AS score FROM items LIMIT 5" + ) + + def test_vector_range_without_radius_raises(self): + """vector_range() without a radius is rejected.""" + parser = SQLParser() + with pytest.raises(ValueError, match="radius"): + parser.parse( + "SELECT name, hybrid_vector_search(" + "vector_range(embedding, :vec), fulltext(description, 'p'), rrf()" + ") AS score FROM items LIMIT 5" + ) + + def test_non_function_combine_raises(self): + """A literal (not rrf/linear) as the fusion argument is rejected.""" + parser = SQLParser() + with pytest.raises(ValueError, match="rrf|linear"): + parser.parse( + "SELECT name, hybrid_vector_search(" + "cosine_distance(embedding, :vec), fulltext(description, 'p'), 99" + ") AS score FROM items LIMIT 5" + ) + + +class TestHybridAnalyzer: + """Analyzing hybrid_vector_search() against a schema.""" + + def test_detects_hybrid_search(self, sample_schema): + """Analyzer surfaces the hybrid search on the analyzed query.""" + parser = SQLParser() + parsed = parser.parse(HYBRID_SQL) + result = Analyzer(sample_schema).analyze(parsed) + + assert result.hybrid_search is not None + + def test_resolves_leg_field_types(self, sample_schema): + """Both leg fields resolve to their schema types.""" + parser = SQLParser() + parsed = parser.parse(HYBRID_SQL) + result = Analyzer(sample_schema).analyze(parsed) + + assert result.get_field_type("embedding") == "VECTOR" + assert result.get_field_type("description") == "TEXT" + + def test_knn_k_derived_from_limit(self, sample_schema): + """LIMIT becomes the KNN K for the vector leg.""" + parser = SQLParser() + parsed = parser.parse(HYBRID_SQL) + result = Analyzer(sample_schema).analyze(parsed) + + assert result.hybrid_search.k == 5 + + def test_vector_leg_on_non_vector_field_raises(self, sample_schema): + """Using a non-VECTOR field as the vector leg is an error.""" + parser = SQLParser() + parsed = parser.parse( + "SELECT name, hybrid_vector_search(" + "cosine_distance(price, :vec), " + "fulltext(description, 'phone'), " + "rrf()" + ") AS score FROM items LIMIT 5" + ) + with pytest.raises(ValueError): + Analyzer(sample_schema).analyze(parsed) + + def test_text_leg_on_non_text_field_raises(self, sample_schema): + """Using a non-TEXT field as the text leg is an error.""" + parser = SQLParser() + parsed = parser.parse( + "SELECT name, hybrid_vector_search(" + "cosine_distance(embedding, :vec), " + "fulltext(price, 'phone'), " + "rrf()" + ") AS score FROM items LIMIT 5" + ) + with pytest.raises(ValueError): + Analyzer(sample_schema).analyze(parsed) + + +class TestHybridTranslator: + """Translating hybrid_vector_search() to an FT.HYBRID command.""" + + def test_emits_ft_hybrid_command( + self, hybrid_translator: Translator, items_index: str + ): + """The translated command targets FT.HYBRID, not FT.SEARCH/FT.AGGREGATE.""" + result = hybrid_translator.translate(HYBRID_SQL) + + assert result.command == "FT.HYBRID" + + def test_command_has_search_and_vsim_legs( + self, hybrid_translator: Translator, items_index: str + ): + """Both the SEARCH and VSIM legs appear in the rendered command.""" + cmd = hybrid_translator.translate(HYBRID_SQL).to_command_string() + + assert "SEARCH" in cmd + assert "VSIM" in cmd + assert "@embedding" in cmd + + def test_command_has_rrf_combine( + self, hybrid_translator: Translator, items_index: str + ): + """Default fusion renders a COMBINE RRF clause.""" + cmd = hybrid_translator.translate(HYBRID_SQL).to_command_string() + + assert "COMBINE" in cmd + assert "RRF" in cmd + + def test_command_omits_dialect( + self, hybrid_translator: Translator, items_index: str + ): + """FT.HYBRID rejects an explicit DIALECT argument, so none is emitted.""" + result = hybrid_translator.translate(HYBRID_SQL) + + assert "DIALECT" not in result.to_command_string() + + def test_where_becomes_filter( + self, hybrid_translator: Translator, items_index: str + ): + """The WHERE clause is rendered as a category filter on the legs.""" + cmd = hybrid_translator.translate(HYBRID_SQL).to_command_string() + + assert "electronics" in cmd + + def test_linear_combine_renders_alpha( + self, hybrid_translator: Translator, items_index: str + ): + """A linear() fusion renders COMBINE LINEAR with ALPHA.""" + cmd = hybrid_translator.translate( + "SELECT name, hybrid_vector_search(" + "cosine_distance(embedding, :vec), " + "fulltext(description, 'smartphone'), " + "linear(alpha => 0.3)" + f") AS score FROM {items_index} LIMIT 5" + ).to_command_string() + + assert "LINEAR" in cmd + assert "ALPHA" in cmd + + def test_linear_derives_beta_from_alpha( + self, hybrid_translator: Translator, items_index: str + ): + """LINEAR exposes alpha only; beta is derived as (1 - alpha).""" + cmd = hybrid_translator.translate( + "SELECT name, hybrid_vector_search(" + "cosine_distance(embedding, :vec), " + "fulltext(description, 'smartphone'), " + "linear(alpha => 0.3)" + f") AS score FROM {items_index} LIMIT 5" + ).to_command_string() + + assert "ALPHA 0.3" in cmd + assert "BETA 0.7" in cmd + + def test_rrf_constant_and_window_in_command( + self, hybrid_translator: Translator, items_index: str + ): + """RRF knobs render as CONSTANT and WINDOW.""" + cmd = hybrid_translator.translate( + "SELECT name, hybrid_vector_search(" + "cosine_distance(embedding, :vec), " + "fulltext(description, 'smartphone'), " + "rrf(constant => 60, window => 20)" + f") AS score FROM {items_index} LIMIT 5" + ).to_command_string() + + assert "CONSTANT 60" in cmd + assert "WINDOW 20" in cmd + + def test_knn_ef_runtime_in_command( + self, hybrid_translator: Translator, items_index: str + ): + """A KNN ef_runtime knob renders EF_RUNTIME in the VSIM leg.""" + cmd = hybrid_translator.translate( + "SELECT name, hybrid_vector_search(" + "vector_distance(embedding, :vec, ef_runtime => 20), " + "fulltext(description, 'smartphone'), " + "rrf()" + f") AS score FROM {items_index} LIMIT 5" + ).to_command_string() + + assert "EF_RUNTIME 20" in cmd + + def test_range_method_renders_radius( + self, hybrid_translator: Translator, items_index: str + ): + """A vector_range() leg renders a RANGE method with RADIUS/EPSILON.""" + cmd = hybrid_translator.translate( + "SELECT name, hybrid_vector_search(" + "vector_range(embedding, :vec, radius => 0.2, epsilon => 0.01), " + "fulltext(description, 'smartphone'), " + "rrf()" + f") AS score FROM {items_index} LIMIT 5" + ).to_command_string() + + assert "RANGE" in cmd + assert "RADIUS 0.2" in cmd + assert "EPSILON 0.01" in cmd + + def test_linear_window_in_command( + self, hybrid_translator: Translator, items_index: str + ): + """A linear() window knob renders WINDOW in the COMBINE clause.""" + cmd = hybrid_translator.translate( + "SELECT name, hybrid_vector_search(" + "cosine_distance(embedding, :vec), " + "fulltext(description, 'smartphone'), " + "linear(alpha => 0.3, window => 30)" + f") AS score FROM {items_index} LIMIT 5" + ).to_command_string() + + assert "WINDOW 30" in cmd + + def test_range_without_epsilon( + self, hybrid_translator: Translator, items_index: str + ): + """vector_range() without epsilon renders RADIUS and no EPSILON.""" + cmd = hybrid_translator.translate( + "SELECT name, hybrid_vector_search(" + "vector_range(embedding, :vec, radius => 0.2), " + "fulltext(description, 'smartphone'), " + "rrf()" + f") AS score FROM {items_index} LIMIT 5" + ).to_command_string() + + assert "RADIUS 0.2" in cmd + assert "EPSILON" not in cmd + + def test_score_only_select_omits_load( + self, hybrid_translator: Translator, items_index: str + ): + """With only the fused score projected, no LOAD clause is emitted.""" + cmd = hybrid_translator.translate( + "SELECT hybrid_vector_search(" + "cosine_distance(embedding, :vec), " + "fulltext(description, 'smartphone'), " + "rrf()" + f") AS score FROM {items_index} LIMIT 5" + ).to_command_string() + + assert "LOAD" not in cmd + + def test_no_where_omits_vsim_filter( + self, hybrid_translator: Translator, items_index: str + ): + """With no WHERE clause, the VSIM leg carries no FILTER.""" + cmd = hybrid_translator.translate( + "SELECT name, hybrid_vector_search(" + "cosine_distance(embedding, :vec), " + "fulltext(description, 'smartphone'), " + "rrf()" + f") AS score FROM {items_index} LIMIT 5" + ).to_command_string() + + assert "FILTER" not in cmd + + +class _FakeRegistry: + """Minimal schema registry returning the items schema for translation.""" + + def get_schema(self, index: str) -> dict[str, str]: + return { + "name": "TEXT", + "category": "TAG", + "description": "TEXT", + "embedding": "VECTOR", + } + + +class _FakeClient: + """Sync client stub that returns a canned reply for execute_command.""" + + def __init__(self, reply): + self.reply = reply + self.last_command: tuple | None = None + + def execute_command(self, *args): + self.last_command = args + if isinstance(self.reply, Exception): + raise self.reply + return self.reply + + +_HYBRID_SQL_NO_WHERE = ( + "SELECT name, description, " + "hybrid_vector_search(" + "cosine_distance(embedding, :vec), " + "fulltext(description, 'smartphone'), " + "rrf()" + ") AS hybrid_score " + "FROM items LIMIT 5" +) + + +class TestHybridExecutorVersionGuard: + """The executor raises a clear error when FT.HYBRID is unsupported.""" + + def test_unknown_command_raises_version_hint(self): + """An 'unknown command' reply is rewrapped with the 8.4 requirement.""" + client = _FakeClient(redis.ResponseError("ERR unknown command 'FT.HYBRID'")) + executor = Executor(client, _FakeRegistry()) + + with pytest.raises(redis.ResponseError, match="8.4"): + executor.execute(_HYBRID_SQL_NO_WHERE, params={"vec": b"\x00" * 16}) + + +class TestHybridExecutorParsing: + """The executor parses FT.HYBRID replies into rows with the fused score.""" + + def test_parses_rows_with_combined_score(self): + """A hybrid reply maps field/value pairs (incl. the score) into rows.""" + reply = [ + "total_results", + 1, + "results", + [["name", "iPhone 15", "description", "smartphone", "hybrid_score", "0.5"]], + "warnings", + [], + "execution_time", + "0.1", + ] + client = _FakeClient(reply) + executor = Executor(client, _FakeRegistry()) + + result = executor.execute(_HYBRID_SQL_NO_WHERE, params={"vec": b"\x00" * 16}) + + assert result.count == 1 + assert result.rows[0]["name"] == "iPhone 15" + assert result.rows[0]["hybrid_score"] == "0.5" + + def test_parses_rows_from_resp3_dict_reply(self): + """A redis-py 8.x / RESP3 map reply (dict of dict rows) parses to rows.""" + reply = { + b"total_results": 1, + b"results": [{b"name": b"iPhone 15", b"hybrid_score": b"0.5"}], + b"warnings": [], + b"execution_time": 0.1, + } + client = _FakeClient(reply) + executor = Executor(client, _FakeRegistry()) + + result = executor.execute(_HYBRID_SQL_NO_WHERE, params={"vec": b"\x00" * 16}) + + assert result.count == 1 + assert result.rows[0][b"name"] == b"iPhone 15" + assert result.rows[0][b"hybrid_score"] == b"0.5" + + def test_vector_bytes_injected_into_command(self): + """The vector param bytes replace the $vector placeholder in the command.""" + client = _FakeClient([0]) + executor = Executor(client, _FakeRegistry()) + blob = b"\x01" * 16 + + executor.execute(_HYBRID_SQL_NO_WHERE, params={"vec": blob}) + + assert client.last_command[0] == "FT.HYBRID" + # The bytes are injected as the PARAMS value... + assert blob in client.last_command + # ...while the VSIM leg keeps the $vector parameter reference. + assert "$vector" in client.last_command + + +class _FakeAsyncRegistry: + """Minimal async schema registry for the async executor unit tests.""" + + async def ensure_schema(self, index: str) -> None: + return None + + def get_schema(self, index: str) -> dict[str, str]: + return { + "name": "TEXT", + "category": "TAG", + "description": "TEXT", + "embedding": "VECTOR", + } + + +class _FakeAsyncClient: + """Async client stub that returns a canned reply for execute_command.""" + + def __init__(self, reply): + self.reply = reply + self.last_command: tuple | None = None + + async def execute_command(self, *args): + self.last_command = args + if isinstance(self.reply, Exception): + raise self.reply + return self.reply + + +class TestHybridAsyncExecutor: + """Async executor mirrors the sync FT.HYBRID guard and parsing.""" + + async def test_async_version_guard(self): + """An 'unknown command' reply is rewrapped with the 8.4 requirement.""" + from sql_redis.executor import AsyncExecutor + + client = _FakeAsyncClient( + redis.ResponseError("ERR unknown command 'FT.HYBRID'") + ) + executor = AsyncExecutor(client, _FakeAsyncRegistry()) + + with pytest.raises(redis.ResponseError, match="8.4"): + await executor.execute(_HYBRID_SQL_NO_WHERE, params={"vec": b"\x00" * 16}) + + async def test_async_parses_rows(self): + """An async hybrid reply parses into rows with the fused score.""" + from sql_redis.executor import AsyncExecutor + + reply = [ + "total_results", + 1, + "results", + [["name", "iPhone 15", "hybrid_score", "0.5"]], + "warnings", + [], + ] + client = _FakeAsyncClient(reply) + executor = AsyncExecutor(client, _FakeAsyncRegistry()) + + result = await executor.execute( + _HYBRID_SQL_NO_WHERE, params={"vec": b"\x00" * 16} + ) + + assert result.rows[0]["name"] == "iPhone 15" + assert result.rows[0]["hybrid_score"] == "0.5" + + +class TestHybridFusionIntegration: + """End-to-end FT.HYBRID execution (requires Redis 8.4+).""" + + def test_returns_fused_rows(self, hybrid_executor: Executor, items_data: str): + """A hybrid fusion query returns rows with the combined-score column.""" + query_vector = float_vector_to_bytes([0.1, 0.2, 0.3, 0.4]) + + result = hybrid_executor.execute( + f""" + SELECT name, description, + hybrid_vector_search( + cosine_distance(embedding, :vec), + fulltext(description, 'smartphone features'), + rrf() + ) AS hybrid_score + FROM {items_data} + WHERE category = 'electronics' + ORDER BY hybrid_score DESC + LIMIT 5 + """, + params={"vec": query_vector}, + ) + + assert len(result.rows) >= 1 + assert "hybrid_score" in result.rows[0] + + def test_linear_fusion_executes(self, hybrid_executor: Executor, items_data: str): + """LINEAR fusion with an alpha weight executes end-to-end.""" + query_vector = float_vector_to_bytes([0.1, 0.2, 0.3, 0.4]) + + result = hybrid_executor.execute( + f""" + SELECT name, + hybrid_vector_search( + cosine_distance(embedding, :vec), + fulltext(description, 'smartphone'), + linear(alpha => 0.3) + ) AS hybrid_score + FROM {items_data} + LIMIT 5 + """, + params={"vec": query_vector}, + ) + + assert len(result.rows) >= 1 diff --git a/tests/test_schema_registry.py b/tests/test_schema_registry.py index 99026cd..91683cc 100644 --- a/tests/test_schema_registry.py +++ b/tests/test_schema_registry.py @@ -247,6 +247,42 @@ def test_parse_schema_incomplete_attribute(self): # Only field2 should be captured (field1 has no type) assert schema == {"field2": "TEXT"} + def test_parse_schema_dict_reply_with_bytes_keys(self): + """_parse_schema_from_info handles the redis-py 8.x / RESP3 map reply. + + redis-py 8.x applies a response callback to FT.INFO, returning a dict + with bytes keys whose ``attributes`` value is a list of dicts. + """ + fake_info = { + b"index_name": b"items", + b"attributes": [ + {b"identifier": b"title", b"attribute": b"title", b"type": b"TEXT"}, + {b"identifier": b"genre", b"attribute": b"genre", b"type": b"TAG"}, + { + b"identifier": b"embedding", + b"attribute": b"embedding", + b"type": b"VECTOR", + }, + ], + } + schema = _parse_schema_from_info(fake_info) + + assert schema == {"title": "TEXT", "genre": "TAG", "embedding": "VECTOR"} + + def test_parse_schema_dict_reply_without_attributes(self): + """A dict reply with no attributes section yields an empty schema.""" + assert _parse_schema_from_info({b"index_name": b"items"}) == {} + + def test_parse_schema_dict_attribute_missing_type(self): + """A dict attribute without a type is skipped.""" + fake_info = { + b"attributes": [ + {b"attribute": b"field1"}, # no type + {b"attribute": b"field2", b"type": b"TEXT"}, + ], + } + assert _parse_schema_from_info(fake_info) == {"field2": "TEXT"} + class TestSchemaRegistryRefresh: """Tests for schema refresh functionality."""