Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 47 additions & 0 deletions sql_redis/analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
ComputedField,
Condition,
DateFunctionSpec,
HybridSearchSpec,
ParsedQuery,
)

Expand All @@ -22,6 +23,19 @@ class VectorSearchAnalysis:
alias: str


@dataclass
class HybridSearchAnalysis:
"""Analyzed FT.HYBRID fusion search details.

Wraps the parsed spec and adds the resolved KNN ``k`` (from LIMIT). Field
types are validated during analysis (text leg must be TEXT, vector leg must
be VECTOR).
"""

spec: HybridSearchSpec
k: int


@dataclass
class AnalyzedQuery:
"""Result of analyzing a parsed SQL query with schema context."""
Expand All @@ -34,6 +48,7 @@ class AnalyzedQuery:
groupby_fields: list[str] = field(default_factory=list)
is_global_aggregation: bool = False
vector_search: VectorSearchAnalysis | None = None
hybrid_search: HybridSearchAnalysis | None = None
has_prefilter: bool = False

def get_field_type(self, field_name: str) -> str | None:
Expand Down Expand Up @@ -121,6 +136,11 @@ def analyze(self, parsed: ParsedQuery) -> AnalyzedQuery:
if parsed.vector_search:
referenced_fields.add(parsed.vector_search.field)

# Fields from hybrid fusion search (both legs)
if parsed.hybrid_search:
referenced_fields.add(parsed.hybrid_search.vector_field)
referenced_fields.add(parsed.hybrid_search.text_field)

# Fields from date functions (YEAR, MONTH, etc.)
for date_func in parsed.date_functions:
referenced_fields.add(date_func.field)
Expand All @@ -141,6 +161,10 @@ def analyze(self, parsed: ParsedQuery) -> AnalyzedQuery:
# KNN similarity; the alias is a computed column, not an indexed
# field, so it must not be looked up in the schema.
alias_names.add(parsed.vector_search.alias)
if parsed.hybrid_search is not None and parsed.hybrid_search.alias:
# ORDER BY <combined-score-alias> sorts by the fused score; like the
# vector alias, it is a computed column, not an indexed field.
alias_names.add(parsed.hybrid_search.alias)

# Fields from GROUP BY (exclude aliases since they're computed)
for field_name in parsed.groupby_fields:
Expand Down Expand Up @@ -180,4 +204,27 @@ def analyze(self, parsed: ParsedQuery) -> AnalyzedQuery:
# Has prefilter if there are conditions
result.has_prefilter = len(parsed.conditions) > 0

# Analyze hybrid fusion search
if parsed.hybrid_search:
spec = parsed.hybrid_search
vector_type = schema.get(spec.vector_field)
if vector_type != "VECTOR":
raise ValueError(
f"hybrid_vector_search() vector leg field "
f"'{spec.vector_field}' must be a VECTOR field, "
f"got {vector_type}."
)
text_type = schema.get(spec.text_field)
if text_type != "TEXT":
raise ValueError(
f"hybrid_vector_search() text leg field "
f"'{spec.text_field}' must be a TEXT field, got {text_type}."
)
result.hybrid_search = HybridSearchAnalysis(
spec=spec,
k=parsed.limit or spec.k or 10,
)
# Conditions become per-leg filters.
result.has_prefilter = len(parsed.conditions) > 0

return result
102 changes: 86 additions & 16 deletions sql_redis/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,54 @@ class QueryResult:
class _ScoreParseMixin:
"""Shared helpers for score-related response parsing."""

@staticmethod
def _is_unknown_command(error_msg: str) -> bool:
"""Return True when a ResponseError means the command is unsupported."""
lowered = error_msg.lower()
return "unknown command" in lowered or "unknown subcommand" in lowered

@staticmethod
def _parse_hybrid_reply(raw_result) -> tuple[Any, list[dict]]:
"""Parse an FT.HYBRID reply into (count, rows).

FT.HYBRID does not use the FT.AGGREGATE array shape. The reply is a map
``{total_results: N, results: [...], warnings: [...], ...}`` that arrives
either as a dict (redis-py 8.x / RESP3) or as a flat list
(``[total_results, N, results, [...], ...]``) on RESP2. Each result row
is likewise a dict or a flat ``[field, val, ...]`` list. Keys/values may
be bytes or str depending on the client's decode_responses setting.
"""
if isinstance(raw_result, dict):
reply = raw_result
else:
reply = dict(zip(raw_result[::2], raw_result[1::2]))

def _field(name: str):
if name in reply:
return reply[name]
return reply.get(name.encode())

count = _field("total_results") or 0
results = _field("results") or []
rows = [
dict(row) if isinstance(row, dict) else dict(zip(row[::2], row[1::2]))
for row in results
]
return count, rows

@staticmethod
def _inject_vector_param(cmd: list[str | bytes], vector_param: bytes) -> None:
"""Replace the vector PARAMS value with the actual bytes, in place.

Only the ``$vector`` token in the PARAMS value position (the one
preceded by the param name ``vector``) is replaced. Query-side
references to ``$vector`` (FT.SEARCH KNN expressions, FT.HYBRID VSIM)
must stay as parameter references so Redis resolves them from PARAMS.
"""
for i, arg in enumerate(cmd):
if arg == "$vector" and i > 0 and cmd[i - 1] == "vector":
cmd[i] = vector_param

@staticmethod
def _has_return_0(args: list[str]) -> bool:
"""Return True when the args contain 'RETURN 0' (no document fields)."""
Expand Down Expand Up @@ -188,17 +236,24 @@ def execute(self, sql: str, *, params: dict | None = None) -> QueryResult:
vector_param = value
break

# Replace $vector placeholder with actual bytes
# Replace the $vector PARAMS value with actual bytes (query/VSIM
# references to $vector stay as parameter references).
if vector_param:
for i, arg in enumerate(cmd):
if arg == "$vector":
cmd[i] = vector_param
self._inject_vector_param(cmd, vector_param)

# Execute command
try:
raw_result = self._client.execute_command(*cmd)
except redis.ResponseError as e:
error_msg = str(e)
if translated.command == "FT.HYBRID" and self._is_unknown_command(
error_msg
):
raise redis.ResponseError(
f"{error_msg}. hybrid_vector_search() translates to FT.HYBRID, "
"which requires Redis 8.4+ (RediSearch with hybrid search) "
"and redis-py >= 7.1.0."
) from e
_ismissing_signatures = (
"Unknown function",
"No such function",
Expand All @@ -217,10 +272,14 @@ def execute(self, sql: str, *, params: dict | None = None) -> QueryResult:
raise

# Parse result based on command type
count = raw_result[0] if raw_result else 0
rows = []

if translated.command == "FT.SEARCH":
# FT.SEARCH/FT.AGGREGATE replies are arrays with the count first;
# FT.HYBRID replies are maps (dict) and set count during parsing below.
count = raw_result[0] if isinstance(raw_result, list) and raw_result else 0
rows: list[dict] = []

if translated.command == "FT.HYBRID":
count, rows = self._parse_hybrid_reply(raw_result)
elif translated.command == "FT.SEARCH":
# Use the explicit score_alias signal rather than scanning args
# for the literal token "WITHSCORES", which could false-positive
# if a returned field happened to be named "WITHSCORES".
Expand Down Expand Up @@ -323,17 +382,24 @@ async def execute(self, sql: str, *, params: dict | None = None) -> QueryResult:
vector_param = value
break

# Replace $vector placeholder with actual bytes
# Replace the $vector PARAMS value with actual bytes (query/VSIM
# references to $vector stay as parameter references).
if vector_param:
for i, arg in enumerate(cmd):
if arg == "$vector":
cmd[i] = vector_param
self._inject_vector_param(cmd, vector_param)

# Execute command asynchronously
try:
raw_result = await self._client.execute_command(*cmd)
except redis.ResponseError as e:
error_msg = str(e)
if translated.command == "FT.HYBRID" and self._is_unknown_command(
error_msg
):
raise redis.ResponseError(
f"{error_msg}. hybrid_vector_search() translates to FT.HYBRID, "
"which requires Redis 8.4+ (RediSearch with hybrid search) "
"and redis-py >= 7.1.0."
) from e
_ismissing_signatures = (
"Unknown function",
"No such function",
Expand All @@ -352,10 +418,14 @@ async def execute(self, sql: str, *, params: dict | None = None) -> QueryResult:
raise

# Parse result based on command type
count = raw_result[0] if raw_result else 0
rows = []

if translated.command == "FT.SEARCH":
# FT.SEARCH/FT.AGGREGATE replies are arrays with the count first;
# FT.HYBRID replies are maps (dict) and set count during parsing below.
count = raw_result[0] if isinstance(raw_result, list) and raw_result else 0
rows: list[dict] = []

if translated.command == "FT.HYBRID":
count, rows = self._parse_hybrid_reply(raw_result)
elif translated.command == "FT.SEARCH":
with_scores = translated.score_alias is not None
no_content = self._has_return_0(translated.args)

Expand Down
Loading
Loading