diff --git a/integration/test_collection_diversity_hybrid.py b/integration/test_collection_diversity_hybrid.py new file mode 100644 index 000000000..ef0ae7b08 --- /dev/null +++ b/integration/test_collection_diversity_hybrid.py @@ -0,0 +1,149 @@ +"""Integration tests for hybrid search + MMR diversity selection. + +``DiversitySelection`` passed inside ``HybridVector.near_vector`` / +``HybridVector.near_text`` is applied by the server as a post-fusion MMR pass +(Weaviate >= 1.39.0). These tests assert that ``balance=0`` (pure diversity) +produces a different ordering than ``balance=1`` (pure relevance), and that +``mmr.limit`` caps the result count. + +The equivalent ``near_vector`` behaviour is covered in +``test_collection_diversity.py``. +""" + +import pytest + +from integration.conftest import CollectionFactory +from weaviate.classes.query import Diversity, HybridVector +from weaviate.collections.classes.config import Configure, DataType, Property +from weaviate.collections.classes.data import DataObject + +MIN_VERSION = (1, 39, 0) + + +def _skip_if_unsupported(collection) -> None: + if collection._connection._weaviate_version.is_lower_than(*MIN_VERSION): + pytest.skip("Hybrid diversity selection requires Weaviate >= 1.39.0") + + +def _create_clustered_collection(collection_factory: CollectionFactory): + """Create a collection with 3 tight clusters (a, b, c) of vectors in 3D.""" + collection = collection_factory( + properties=[Property(name="text", data_type=DataType.TEXT)], + vectorizer_config=Configure.Vectorizer.none(), + ) + _skip_if_unsupported(collection) + collection.data.insert_many( + [ + DataObject(properties={"text": "a1"}, vector=[1.0, 0.0, 0.0]), + DataObject(properties={"text": "a2"}, vector=[0.95, 0.05, 0.0]), + DataObject(properties={"text": "a3"}, vector=[0.9, 0.1, 0.0]), + DataObject(properties={"text": "b1"}, vector=[0.0, 1.0, 0.0]), + DataObject(properties={"text": "b2"}, vector=[0.05, 0.95, 0.0]), + DataObject(properties={"text": "c1"}, vector=[0.0, 0.0, 1.0]), + ] + ) + return collection + + +def _create_large_collection(collection_factory: CollectionFactory, n_items: int = 50): + """Create a collection with enough items (>25) that a small mmr.limit is distinguishable from the server's default limit.""" + collection = collection_factory( + properties=[Property(name="text", data_type=DataType.TEXT)], + vectorizer_config=Configure.Vectorizer.none(), + ) + _skip_if_unsupported(collection) + collection.data.insert_many( + [ + DataObject(properties={"text": f"t{i}"}, vector=[1.0 - 0.001 * i, 0.0, 0.0]) + for i in range(n_items) + ] + ) + return collection + + +def test_hybrid_near_vector_balance_0_differs_from_balance_1( + collection_factory: CollectionFactory, +) -> None: + """Hybrid near-vector: balance=0 (diversity) must reorder vs balance=1 (relevance).""" + collection = _create_clustered_collection(collection_factory) + balance_0 = collection.query.hybrid( + query=None, + vector=HybridVector.near_vector( + vector=[1.0, 0.0, 0.0], + diversity_selection=Diversity.mmr(limit=3, balance=0.0), + ), + limit=3, + ).objects + balance_1 = collection.query.hybrid( + query=None, + vector=HybridVector.near_vector( + vector=[1.0, 0.0, 0.0], + diversity_selection=Diversity.mmr(limit=3, balance=1.0), + ), + limit=3, + ).objects + assert [o.uuid for o in balance_0] != [o.uuid for o in balance_1] + + +def test_hybrid_near_vector_balance_1_matches_baseline( + collection_factory: CollectionFactory, +) -> None: + """Hybrid near-vector with MMR balance=1 (pure relevance) matches the plain baseline.""" + collection = _create_clustered_collection(collection_factory) + baseline = collection.query.hybrid( + query=None, + vector=HybridVector.near_vector(vector=[1.0, 0.0, 0.0]), + limit=3, + ).objects + mmr_balance_1 = collection.query.hybrid( + query=None, + vector=HybridVector.near_vector( + vector=[1.0, 0.0, 0.0], + diversity_selection=Diversity.mmr(limit=3, balance=1.0), + ), + limit=3, + ).objects + assert [o.uuid for o in baseline] == [o.uuid for o in mmr_balance_1] + + +def test_hybrid_alpha_1_balance_0_differs_from_balance_1( + collection_factory: CollectionFactory, +) -> None: + """Hybrid with explicit alpha=1.0 (pure vector) applies MMR like near_vector.""" + collection = _create_clustered_collection(collection_factory) + balance_0 = collection.query.hybrid( + query="irrelevant", + alpha=1.0, + vector=HybridVector.near_vector( + vector=[1.0, 0.0, 0.0], + diversity_selection=Diversity.mmr(limit=3, balance=0.0), + ), + limit=3, + ).objects + balance_1 = collection.query.hybrid( + query="irrelevant", + alpha=1.0, + vector=HybridVector.near_vector( + vector=[1.0, 0.0, 0.0], + diversity_selection=Diversity.mmr(limit=3, balance=1.0), + ), + limit=3, + ).objects + assert [o.uuid for o in balance_0] != [o.uuid for o in balance_1] + + +def test_hybrid_respects_mmr_limit( + collection_factory: CollectionFactory, +) -> None: + """Hybrid respects mmr.limit as the result-count cap when no outer limit is set.""" + mmr_limit = 5 + collection = _create_large_collection(collection_factory, n_items=50) + + result = collection.query.hybrid( + query=None, + vector=HybridVector.near_vector( + vector=[1.0, 0.0, 0.0], + diversity_selection=Diversity.mmr(limit=mmr_limit, balance=0.5), + ), + ).objects + assert len(result) == mmr_limit diff --git a/test/collection/test_hybrid_diversity.py b/test/collection/test_hybrid_diversity.py new file mode 100644 index 000000000..679324081 --- /dev/null +++ b/test/collection/test_hybrid_diversity.py @@ -0,0 +1,63 @@ +"""Unit tests: hybrid search wires diversity_selection into the gRPC request. + +Hybrid diversity is a post-fusion, hybrid-level operation, so the +``HybridVector.near_vector`` / ``HybridVector.near_text`` ``diversity_selection`` +argument must populate the top-level ``Hybrid.selection.mmr`` in the +SearchRequest proto (not the nested ``near_vector`` / ``near_text`` selection). +""" + +from weaviate.collections.grpc.query import _QueryGRPC +from weaviate.classes.query import Diversity, HybridVector +from weaviate.util import _ServerVersion + + +def _builder() -> _QueryGRPC: + return _QueryGRPC( + weaviate_version=_ServerVersion(1, 39, 0), + name="Dummy", + tenant=None, + consistency_level=None, + validate_arguments=True, + uses_125_api=True, + uses_127_api=True, + ) + + +def test_hybrid_near_vector_sets_top_level_selection() -> None: + req = _builder().hybrid( + query=None, + vector=HybridVector.near_vector( + vector=[1.0, 0.0, 0.0], + diversity_selection=Diversity.mmr(limit=7, balance=0.0), + ), + limit=7, + ) + # Canonical location: top-level Hybrid.selection, not the nested near_vector. + mmr = req.hybrid_search.selection.mmr + assert mmr.limit == 7 + assert mmr.balance == 0.0 + assert not req.hybrid_search.near_vector.HasField("selection") + + +def test_hybrid_near_text_sets_top_level_selection() -> None: + req = _builder().hybrid( + query=None, + vector=HybridVector.near_text( + query="cats", + diversity_selection=Diversity.mmr(limit=3, balance=0.5), + ), + limit=3, + ) + mmr = req.hybrid_search.selection.mmr + assert mmr.limit == 3 + assert mmr.balance == 0.5 + assert not req.hybrid_search.near_text.HasField("selection") + + +def test_hybrid_without_selection_leaves_it_unset() -> None: + req = _builder().hybrid( + query=None, + vector=HybridVector.near_vector(vector=[1.0, 0.0, 0.0]), + limit=5, + ) + assert not req.hybrid_search.HasField("selection") diff --git a/weaviate/collections/classes/grpc.py b/weaviate/collections/classes/grpc.py index aeca327a7..8cebd0960 100644 --- a/weaviate/collections/classes/grpc.py +++ b/weaviate/collections/classes/grpc.py @@ -760,6 +760,7 @@ class _HybridNearBase(_WeaviateInput): distance: Optional[float] = None certainty: Optional[float] = None + diversity_selection: Optional[MMR] = None class _HybridNearText(_HybridNearBase): @@ -772,6 +773,7 @@ class _HybridNearVector: # can't be a Pydantic model because of validation issu vector: NearVectorInputType distance: Optional[float] certainty: Optional[float] + diversity_selection: Optional[MMR] def __init__( self, @@ -779,10 +781,12 @@ def __init__( vector: NearVectorInputType, distance: Optional[float] = None, certainty: Optional[float] = None, + diversity_selection: Optional[MMR] = None, ) -> None: self.vector = vector self.distance = distance self.certainty = certainty + self.diversity_selection = diversity_selection HybridVectorType = Union[NearVectorInputType, _HybridNearText, _HybridNearVector] @@ -897,6 +901,7 @@ def near_text( distance: Optional[float] = None, move_to: Optional[Move] = None, move_away: Optional[Move] = None, + diversity_selection: Optional[MMR] = None, ) -> _HybridNearText: """Define a near text search to be used within a hybrid query. @@ -906,6 +911,7 @@ def near_text( distance: The maximum distance to search. If not specified, the default distance specified by the server is used. move_to: Define the concepts that should be moved towards in the vector space during the search. move_away: Define the concepts that should be moved away from in the vector space during the search. + diversity_selection: Apply diversity selection (e.g. MMR) to the hybrid results. Requires Weaviate >= 1.39.0. Returns: A `_HybridNearText` object to be used in the `vector` parameter of the `query.hybrid` and `generate.hybrid` search methods. @@ -916,6 +922,7 @@ def near_text( certainty=certainty, move_to=move_to, move_away=move_away, + diversity_selection=diversity_selection, ) @staticmethod @@ -924,12 +931,14 @@ def near_vector( *, certainty: Optional[float] = None, distance: Optional[float] = None, + diversity_selection: Optional[MMR] = None, ) -> _HybridNearVector: """Define a near vector search to be used within a hybrid query. Args: certainty: The minimum similarity score to return. If not specified, the default certainty specified by the server is used. distance: The maximum distance to search. If not specified, the default distance specified by the server is used. + diversity_selection: Apply diversity selection (e.g. MMR) to the hybrid results. Requires Weaviate >= 1.39.0. Returns: A `_HybridNearVector` object to be used in the `vector` parameter of the `query.hybrid` and `generate.hybrid` search methods. @@ -938,6 +947,7 @@ def near_vector( vector=vector, distance=distance, certainty=certainty, + diversity_selection=diversity_selection, ) diff --git a/weaviate/collections/grpc/shared.py b/weaviate/collections/grpc/shared.py index f13e5c4b1..a4f3214e9 100644 --- a/weaviate/collections/grpc/shared.py +++ b/weaviate/collections/grpc/shared.py @@ -639,6 +639,15 @@ def _parse_hybrid( near_text, near_vector, vector_bytes, vectors = None, None, None, None + # Hybrid diversity selection is a post-fusion, hybrid-level operation, so + # it is carried on the top-level Hybrid.selection field rather than on the + # near_text / near_vector sub-query. + hybrid_selection = ( + vector.diversity_selection + if isinstance(vector, (_HybridNearText, _HybridNearVector)) + else None + ) + if vector is None: pass elif isinstance(vector, list) and len(vector) > 0 and isinstance(vector[0], float): @@ -739,6 +748,7 @@ def _parse_hybrid( vector_bytes=vector_bytes, vector_distance=distance, vectors=vectors, + selection=self._diversity_selection_to_grpc(hybrid_selection), bm25_search_operator=base_search_pb2.SearchOperatorOptions( operator=bm25_operator.operator, minimum_or_tokens_match=bm25_operator.minimum_should_match