diff --git a/CHANGELOG.md b/CHANGELOG.md index a592dec..58e6094 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,14 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## [2.4.0] - 2026-04-21 + +### Added +- `unique_text` parameter in `create_table` to enforce a uniqueness constraint on the text column +- `on_conflict` parameter in the `add` method to control duplicate-entry handling with options: `"error"`, `"ignore"`, and `"replace"` +- Validation for `on_conflict` values to ensure only accepted options are used +- Tests covering unique text constraints and all conflict resolution strategies + ## [2.3.0] - 2025-02-15 ### Added diff --git a/pyproject.toml b/pyproject.toml index d5f4852..df113e7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "sqlite-vec-client" -version = "2.3.0" +version = "2.4.0" description = "A lightweight Python client around sqlite-vec for CRUD and similarity search." readme = "README.md" requires-python = ">=3.9" diff --git a/sqlite_vec_client/base.py b/sqlite_vec_client/base.py index 2bc5f51..4029daa 100644 --- a/sqlite_vec_client/base.py +++ b/sqlite_vec_client/base.py @@ -30,6 +30,7 @@ validate_limit, validate_metadata_filters, validate_offset, + validate_on_conflict, validate_table_name, validate_top_k, ) @@ -144,12 +145,15 @@ def create_table( self, dim: int, distance: Literal["L1", "L2", "cosine"] = "cosine", + unique_text: bool = False, ) -> None: """Create base table, vector table, and triggers to keep them in sync. Args: dim: Embedding dimension (must be positive) distance: Distance metric for similarity search + unique_text: If True, enforce uniqueness on the text column. + This enables ``on_conflict`` options in :meth:`add`. Raises: TableNameError: If table name is invalid @@ -172,6 +176,14 @@ def create_table( ; """ ) + if unique_text: + self.connection.execute( + f""" + CREATE UNIQUE INDEX IF NOT EXISTS {self.table}_text_unique + ON {self.table}(text) + ; + """ + ) self.connection.execute( f""" CREATE VIRTUAL TABLE IF NOT EXISTS {self.table}_vec USING vec0( @@ -325,6 +337,7 @@ def add( texts: list[Text], embeddings: list[Embeddings], metadata: list[Metadata] | None = None, + on_conflict: Literal["error", "ignore", "replace"] = "error", ) -> Rowids: """Insert texts with embeddings (and optional metadata) and return rowids. @@ -332,14 +345,22 @@ def add( texts: List of text strings embeddings: List of embedding vectors metadata: Optional list of metadata dicts + on_conflict: How to handle duplicate texts when a UNIQUE index on + ``text`` exists (see ``create_table(unique_text=True)``). + + - ``"error"`` (default): raise on conflict. + - ``"ignore"``: silently skip duplicate texts. + - ``"replace"``: update metadata and embedding of + existing records that share the same text. Returns: - List of rowids for inserted records + List of rowids for inserted (or upserted) records Raises: - ValidationError: If list lengths don't match + ValidationError: If list lengths don't match or on_conflict is invalid TableNotFoundError: If table doesn't exist """ + validate_on_conflict(on_conflict) validate_embeddings_match(texts, embeddings, metadata) expected_dim = self._ensure_dimension() for embedding in embeddings: @@ -356,19 +377,49 @@ def add( cur = self.connection.cursor() - # Get max rowid before insert - max_before = cur.execute( - f"SELECT COALESCE(MAX(rowid), 0) FROM {self.table}" - ).fetchone()[0] + if on_conflict == "ignore": + sql = ( + f"INSERT OR IGNORE INTO {self.table}" + f"(text, metadata, text_embedding) VALUES (?,?,?)" + ) + elif on_conflict == "replace": + sql = ( + f"INSERT INTO {self.table}(text, metadata, text_embedding) " + f"VALUES (?,?,?) " + f"ON CONFLICT(text) DO UPDATE SET " + f"metadata=excluded.metadata, " + f"text_embedding=excluded.text_embedding" + ) + else: + sql = ( + f"INSERT INTO {self.table}" + f"(text, metadata, text_embedding) VALUES (?,?,?)" + ) - cur.executemany( - f"""INSERT INTO {self.table}(text, metadata, text_embedding) - VALUES (?,?,?)""", - data_input, - ) + if on_conflict in ("error", "ignore"): + max_before = cur.execute( + f"SELECT COALESCE(MAX(rowid), 0) FROM {self.table}" + ).fetchone()[0] + + cur.executemany(sql, data_input) - # Calculate rowids from max_before - rowids = list(range(max_before + 1, max_before + len(texts) + 1)) + if on_conflict == "error": + rowids = list(range(max_before + 1, max_before + len(texts) + 1)) + elif on_conflict == "ignore": + cur.execute( + f"SELECT rowid FROM {self.table} " + f"WHERE rowid > ? ORDER BY rowid", + [max_before], + ) + rowids = [row[0] for row in cur.fetchall()] + else: + placeholders = ",".join(["?"] * len(texts)) + cur.execute( + f"SELECT rowid FROM {self.table} " + f"WHERE text IN ({placeholders}) ORDER BY rowid", + texts, + ) + rowids = [row[0] for row in cur.fetchall()] if not self._in_transaction: self.connection.commit() diff --git a/sqlite_vec_client/validation.py b/sqlite_vec_client/validation.py index 49762a5..fe7f358 100644 --- a/sqlite_vec_client/validation.py +++ b/sqlite_vec_client/validation.py @@ -123,6 +123,25 @@ def validate_embedding_dimension(embedding: list[float], expected_dim: int) -> N ) +_VALID_ON_CONFLICT = frozenset({"error", "ignore", "replace"}) + + +def validate_on_conflict(on_conflict: str) -> None: + """Validate the on_conflict parameter for add(). + + Args: + on_conflict: Conflict resolution strategy + + Raises: + ValidationError: If the value is not one of 'error', 'ignore', 'replace' + """ + if on_conflict not in _VALID_ON_CONFLICT: + raise ValidationError( + f"on_conflict must be one of {sorted(_VALID_ON_CONFLICT)}, " + f"got '{on_conflict}'" + ) + + def validate_metadata_filters(filters: dict[str, Any]) -> None: """Validate metadata filters dictionary. diff --git a/tests/conftest.py b/tests/conftest.py index 2a400e2..54ea907 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -31,6 +31,13 @@ def client_with_table(client: SQLiteVecClient) -> SQLiteVecClient: return client +@pytest.fixture +def client_with_unique_table(client: SQLiteVecClient) -> SQLiteVecClient: + """Provide a client with table created with unique_text=True.""" + client.create_table(dim=3, distance="cosine", unique_text=True) + return client + + @pytest.fixture def sample_embeddings() -> list[list[float]]: """Provide sample 3D embeddings for testing.""" diff --git a/tests/test_client.py b/tests/test_client.py index 4ebd8bc..c25d679 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -1,5 +1,7 @@ """Integration tests for SQLiteVecClient.""" +import sqlite3 + import pytest from sqlite_vec_client import ( @@ -94,6 +96,155 @@ def test_add_invalid_embedding_dimension( client_with_table.add(texts=sample_texts, embeddings=invalid_embeddings) +@pytest.mark.integration +class TestUniqueText: + """Tests for unique_text constraint and on_conflict parameter.""" + + def test_unique_text_rejects_duplicates_by_default( + self, client_with_unique_table, sample_embeddings + ): + """Duplicate text raises IntegrityError when on_conflict='error'.""" + client_with_unique_table.add( + texts=["hello"], embeddings=[sample_embeddings[0]] + ) + with pytest.raises(sqlite3.IntegrityError): + client_with_unique_table.add( + texts=["hello"], embeddings=[sample_embeddings[1]] + ) + + def test_on_conflict_ignore_skips_duplicates( + self, client_with_unique_table, sample_embeddings + ): + """Duplicate texts are silently skipped with on_conflict='ignore'.""" + client_with_unique_table.add( + texts=["hello"], embeddings=[sample_embeddings[0]] + ) + rowids = client_with_unique_table.add( + texts=["hello", "world"], + embeddings=[sample_embeddings[1], sample_embeddings[2]], + on_conflict="ignore", + ) + assert len(rowids) == 1 + assert client_with_unique_table.count() == 2 + + def test_on_conflict_ignore_all_duplicates( + self, client_with_unique_table, sample_embeddings + ): + """All duplicates skipped returns empty rowids.""" + client_with_unique_table.add( + texts=["hello"], embeddings=[sample_embeddings[0]] + ) + rowids = client_with_unique_table.add( + texts=["hello"], + embeddings=[sample_embeddings[1]], + on_conflict="ignore", + ) + assert rowids == [] + assert client_with_unique_table.count() == 1 + + def test_on_conflict_ignore_preserves_original( + self, client_with_unique_table, sample_embeddings + ): + """Ignored duplicates do not overwrite the original record.""" + client_with_unique_table.add( + texts=["hello"], + embeddings=[sample_embeddings[0]], + metadata=[{"version": 1}], + ) + client_with_unique_table.add( + texts=["hello"], + embeddings=[sample_embeddings[1]], + metadata=[{"version": 2}], + on_conflict="ignore", + ) + record = client_with_unique_table.get(1) + assert record[2] == {"version": 1} + + def test_on_conflict_replace_updates_existing( + self, client_with_unique_table, sample_embeddings + ): + """Duplicate texts are updated with on_conflict='replace'.""" + client_with_unique_table.add( + texts=["hello"], + embeddings=[sample_embeddings[0]], + metadata=[{"version": 1}], + ) + rowids = client_with_unique_table.add( + texts=["hello"], + embeddings=[sample_embeddings[1]], + metadata=[{"version": 2}], + on_conflict="replace", + ) + assert len(rowids) == 1 + assert client_with_unique_table.count() == 1 + record = client_with_unique_table.get(rowids[0]) + assert record[2] == {"version": 2} + assert record[3] == pytest.approx(sample_embeddings[1], abs=1e-6) + + def test_on_conflict_replace_mixed_insert_and_update( + self, client_with_unique_table, sample_embeddings + ): + """Replace mode handles a mix of new and existing texts.""" + client_with_unique_table.add( + texts=["hello"], embeddings=[sample_embeddings[0]] + ) + rowids = client_with_unique_table.add( + texts=["hello", "world"], + embeddings=[sample_embeddings[1], sample_embeddings[2]], + on_conflict="replace", + ) + assert len(rowids) == 2 + assert client_with_unique_table.count() == 2 + + def test_on_conflict_replace_keeps_rowid( + self, client_with_unique_table, sample_embeddings + ): + """Replace mode preserves the original rowid.""" + original_rowids = client_with_unique_table.add( + texts=["hello"], embeddings=[sample_embeddings[0]] + ) + new_rowids = client_with_unique_table.add( + texts=["hello"], + embeddings=[sample_embeddings[1]], + on_conflict="replace", + ) + assert new_rowids == original_rowids + + def test_on_conflict_replace_vec_table_synced( + self, client_with_unique_table, sample_embeddings + ): + """Replace mode keeps the vector table in sync for similarity search.""" + client_with_unique_table.add( + texts=["hello"], embeddings=[sample_embeddings[0]] + ) + client_with_unique_table.add( + texts=["hello"], + embeddings=[sample_embeddings[1]], + on_conflict="replace", + ) + results = client_with_unique_table.similarity_search( + embedding=sample_embeddings[1], top_k=1 + ) + assert results[0][1] == "hello" + + def test_on_conflict_invalid_value(self, client_with_unique_table): + """Invalid on_conflict value raises ValidationError.""" + with pytest.raises(ValidationError): + client_with_unique_table.add( + texts=["hello"], + embeddings=[[0.1, 0.2, 0.3]], + on_conflict="bad", + ) + + def test_without_unique_text_allows_duplicates( + self, client_with_table, sample_embeddings + ): + """Without unique_text, duplicate texts are allowed.""" + client_with_table.add(texts=["hello"], embeddings=[sample_embeddings[0]]) + client_with_table.add(texts=["hello"], embeddings=[sample_embeddings[1]]) + assert client_with_table.count() == 2 + + @pytest.mark.integration class TestSimilaritySearch: """Tests for similarity_search method.""" diff --git a/tests/test_validation.py b/tests/test_validation.py index 5f6c15c..683f81f 100644 --- a/tests/test_validation.py +++ b/tests/test_validation.py @@ -14,6 +14,7 @@ validate_limit, validate_metadata_filters, validate_offset, + validate_on_conflict, validate_table_name, validate_top_k, ) @@ -239,3 +240,23 @@ def test_non_string_keys(self): """Test that non-string keys raise error.""" with pytest.raises(ValidationError, match="must be string"): validate_metadata_filters({123: "value"}) + + +@pytest.mark.unit +class TestValidateOnConflict: + """Tests for validate_on_conflict function.""" + + def test_valid_values(self): + """Test that valid on_conflict values pass validation.""" + for value in ["error", "ignore", "replace"]: + validate_on_conflict(value) + + def test_invalid_value(self): + """Test that invalid on_conflict value raises error.""" + with pytest.raises(ValidationError, match="on_conflict must be one of"): + validate_on_conflict("bad") + + def test_empty_value(self): + """Test that empty string raises error.""" + with pytest.raises(ValidationError, match="on_conflict must be one of"): + validate_on_conflict("") diff --git a/uv.lock b/uv.lock new file mode 100644 index 0000000..98f7add --- /dev/null +++ b/uv.lock @@ -0,0 +1,26 @@ +version = 1 +revision = 3 +requires-python = ">=3.9" + +[[package]] +name = "sqlite-vec" +version = "0.1.9" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/68/85/9fad0045d8e7c8df3e0fa5a56c630e8e15ad6e5ca2e6106fceb666aa6638/sqlite_vec-0.1.9-py3-none-macosx_10_6_x86_64.whl", hash = "sha256:1b62a7f0a060d9475575d4e599bbf94a13d85af896bc1ce86ee80d1b5b48e5fb", size = 131171, upload-time = "2026-03-31T08:02:31.717Z" }, + { url = "https://files.pythonhosted.org/packages/a4/3d/3677e0cd2f92e5ebc43cd29fbf565b75582bff1ccfa0b8327c7508e1084f/sqlite_vec-0.1.9-py3-none-macosx_11_0_arm64.whl", hash = "sha256:1d52e30513bae4cc9778ddbf6145610434081be4c3afe57cd877893bad9f6b6c", size = 165434, upload-time = "2026-03-31T08:02:32.712Z" }, + { url = "https://files.pythonhosted.org/packages/00/d4/f2b936d3bdc38eadcbd2a87875815db36430fab0363182ba5d12cd8e0b51/sqlite_vec-0.1.9-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4e921e592f24a5f9a18f590b6ddd530eb637e2d474e3b1972f9bbeb773aa3cb9", size = 160076, upload-time = "2026-03-31T08:02:33.796Z" }, + { url = "https://files.pythonhosted.org/packages/6f/ad/6afd073b0f817b3e03f9e37ad626ae341805891f23c74b5292818f49ac63/sqlite_vec-0.1.9-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux1_x86_64.whl", hash = "sha256:1515727990b49e79bcaf75fdee2ffc7d461f8b66905013231251f1c8938e7786", size = 163388, upload-time = "2026-03-31T08:02:34.888Z" }, + { url = "https://files.pythonhosted.org/packages/42/89/81b2907cda14e566b9bf215e2ad82fc9b349edf07d2010756ffdb902f328/sqlite_vec-0.1.9-py3-none-win_amd64.whl", hash = "sha256:4a28dc12fa4b53d7b1dced22da2488fade444e96b5d16fd2d698cd670675cf32", size = 292804, upload-time = "2026-03-31T08:02:36.035Z" }, +] + +[[package]] +name = "sqlite-vec-client" +version = "2.4.0" +source = { editable = "." } +dependencies = [ + { name = "sqlite-vec" }, +] + +[package.metadata] +requires-dist = [{ name = "sqlite-vec", specifier = ">=0.1.6" }]