diff --git a/bindings/python/example/example.py b/bindings/python/example/example.py index 52cefe1e..7049ea26 100644 --- a/bindings/python/example/example.py +++ b/bindings/python/example/example.py @@ -933,8 +933,15 @@ async def main(): print(f"Error with partitioned KV table: {e}") traceback.print_exc() + + print("\n--- New: async context manager demo ---") + async with await fluss.FlussConnection.create(config) as demo_conn: + demo_table = await demo_conn.get_table(table_path) + async with demo_table.new_append().create_writer() as writer: + writer.append({"id": 1, "name": "demo", "score": 1.0}) + # auto-flushes on exit # Close connection - conn.close() + await conn.close() print("\nConnection closed") diff --git a/bindings/python/fluss/__init__.pyi b/bindings/python/fluss/__init__.pyi index 02edcdb3..2dce7bf9 100644 --- a/bindings/python/fluss/__init__.pyi +++ b/bindings/python/fluss/__init__.pyi @@ -245,7 +245,7 @@ class FlussConnection: async def create(config: Config) -> FlussConnection: ... def get_admin(self) -> FlussAdmin: ... async def get_table(self, table_path: TablePath) -> FlussTable: ... - def close(self) -> None: ... + async def close(self) -> None: ... def __enter__(self) -> FlussConnection: ... def __exit__( self, @@ -253,6 +253,13 @@ class FlussConnection: exc_value: Optional[BaseException], traceback: Optional[TracebackType], ) -> bool: ... + async def __aenter__(self) -> FlussConnection: ... + async def __aexit__( + self, + exc_type: Optional[type], + exc_value: Optional[BaseException], + traceback: Optional[TracebackType], + ) -> bool: ... def __repr__(self) -> str: ... class ServerNode: @@ -611,6 +618,27 @@ class AppendWriter: def write_arrow_batch(self, batch: pa.RecordBatch) -> WriteResultHandle: ... def write_pandas(self, df: pd.DataFrame) -> None: ... async def flush(self) -> None: ... + async def __aenter__(self) -> AppendWriter: + """ + Enter the async context manager. + + Returns: + The AppendWriter instance. + """ + ... + async def __aexit__( + self, + exc_type: Optional[type], + exc_value: Optional[BaseException], + traceback: Optional[TracebackType], + ) -> bool: + """ + Exit the async context manager. + + On exit, the writer is automatically flushed to ensure + all pending records are sent and acknowledged. + """ + ... def __repr__(self) -> str: ... class UpsertWriter: @@ -644,6 +672,27 @@ class UpsertWriter: async def flush(self) -> None: """Flush all pending upsert/delete operations to the server.""" ... + async def __aenter__(self) -> UpsertWriter: + """ + Enter the async context manager. + + Returns: + The UpsertWriter instance. + """ + ... + async def __aexit__( + self, + exc_type: Optional[type], + exc_value: Optional[BaseException], + traceback: Optional[TracebackType], + ) -> bool: + """ + Exit the async context manager. + + On exit, the writer is automatically flushed to ensure + all pending records are sent and acknowledged. + """ + ... def __repr__(self) -> str: ... @@ -807,6 +856,15 @@ class LogScanner: You must call subscribe(), subscribe_buckets(), or subscribe_partition() first. """ + ... + + async def __aenter__(self) -> LogScanner: ... + async def __aexit__( + self, + exc_type: Optional[type], + exc_value: Optional[BaseException], + traceback: Optional[TracebackType], + ) -> bool: ... def __repr__(self) -> str: ... def __aiter__(self) -> AsyncIterator[Union[ScanRecord, RecordBatch]]: ... diff --git a/bindings/python/src/connection.rs b/bindings/python/src/connection.rs index a8d2d9e3..ccf2d5c5 100644 --- a/bindings/python/src/connection.rs +++ b/bindings/python/src/connection.rs @@ -18,6 +18,7 @@ use crate::*; use pyo3_async_runtimes::tokio::future_into_py; use std::sync::Arc; +use std::time::Duration; /// Connection to a Fluss cluster #[pyclass] @@ -82,9 +83,19 @@ impl FlussConnection { }) } - // Close the connection - fn close(&mut self) -> PyResult<()> { - Ok(()) + /// Close the connection (async). + /// + /// Gracefully shuts down the connection by draining any pending write batches. + /// This method is awaitable. + fn close<'py>(&self, py: Python<'py>) -> PyResult> { + let inner = self.inner.clone(); + + future_into_py(py, async move { + inner + .close(Duration::MAX) + .await + .map_err(|e| FlussError::from_core_error(&e)) + }) } // Enter the runtime context (for 'with' statement) @@ -100,10 +111,36 @@ impl FlussConnection { _exc_value: Option>, _traceback: Option>, ) -> PyResult { - self.close()?; + // Sync exit cannot await the graceful drain, so it's a no-op here. + // Users should use 'async with' for graceful shutdown. Ok(false) } + // Enter the async runtime context (for 'async with' statement) + fn __aenter__<'py>(slf: PyRef<'py, Self>, py: Python<'py>) -> PyResult> { + let py_slf = slf.into_pyobject(py)?.unbind(); + future_into_py(py, async move { Ok(py_slf) }) + } + + // Exit the async runtime context (for 'async with' statement) + #[pyo3(signature = (_exc_type=None, _exc_value=None, _traceback=None))] + fn __aexit__<'py>( + &self, + py: Python<'py>, + _exc_type: Option>, + _exc_value: Option>, + _traceback: Option>, + ) -> PyResult> { + let inner = self.inner.clone(); + future_into_py(py, async move { + inner + .close(Duration::MAX) + .await + .map_err(|e| FlussError::from_core_error(&e))?; + Ok(false) + }) + } + fn __repr__(&self) -> String { "FlussConnection()".to_string() } diff --git a/bindings/python/src/table.rs b/bindings/python/src/table.rs index c1b46734..9bf9c6ff 100644 --- a/bindings/python/src/table.rs +++ b/bindings/python/src/table.rs @@ -989,6 +989,32 @@ impl AppendWriter { }) } + // Enter the async runtime context (for 'async with' statement) + fn __aenter__<'py>(slf: PyRef<'py, Self>, py: Python<'py>) -> PyResult> { + let py_slf = slf.into_pyobject(py)?.unbind(); + future_into_py(py, async move { Ok(py_slf) }) + } + + // Exit the async runtime context (for 'async with' statement) + /// On exit, the writer is automatically flushed. + #[pyo3(signature = (_exc_type=None, _exc_value=None, _traceback=None))] + fn __aexit__<'py>( + &self, + py: Python<'py>, + _exc_type: Option>, + _exc_value: Option>, + _traceback: Option>, + ) -> PyResult> { + let inner = self.inner.clone(); + future_into_py(py, async move { + inner + .flush() + .await + .map_err(|e| FlussError::from_core_error(&e))?; + Ok(false) + }) + } + fn __repr__(&self) -> String { "AppendWriter()".to_string() } diff --git a/bindings/python/src/upsert.rs b/bindings/python/src/upsert.rs index 02ad7fa4..cba28683 100644 --- a/bindings/python/src/upsert.rs +++ b/bindings/python/src/upsert.rs @@ -108,6 +108,32 @@ impl UpsertWriter { }) } + // Enter the async runtime context (for 'async with' statement) + fn __aenter__<'py>(slf: PyRef<'py, Self>, py: Python<'py>) -> PyResult> { + let py_slf = slf.into_pyobject(py)?.unbind(); + future_into_py(py, async move { Ok(py_slf) }) + } + + // Exit the async runtime context (for 'async with' statement) + /// On exit, the writer is automatically flushed. + #[pyo3(signature = (_exc_type=None, _exc_value=None, _traceback=None))] + fn __aexit__<'py>( + &self, + py: Python<'py>, + _exc_type: Option>, + _exc_value: Option>, + _traceback: Option>, + ) -> PyResult> { + let writer = self.writer.clone(); + future_into_py(py, async move { + writer + .flush() + .await + .map_err(|e| FlussError::from_core_error(&e))?; + Ok(false) + }) + } + fn __repr__(&self) -> String { "UpsertWriter()".to_string() } diff --git a/bindings/python/test/conftest.py b/bindings/python/test/conftest.py index 47c92807..c1622f29 100644 --- a/bindings/python/test/conftest.py +++ b/bindings/python/test/conftest.py @@ -96,7 +96,7 @@ async def _connect(bootstrap_servers): nodes = await admin.get_server_nodes() if any(n.server_type == "TabletServer" for n in nodes): return conn - conn.close() + await conn.close() last_err = RuntimeError("No TabletServer available yet") except Exception as e: last_err = e diff --git a/bindings/python/test/test_context_manager.py b/bindings/python/test/test_context_manager.py new file mode 100644 index 00000000..3a1c3aef --- /dev/null +++ b/bindings/python/test/test_context_manager.py @@ -0,0 +1,117 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import asyncio +import pytest +import pyarrow as pa +import time +import fluss + +def _poll_records(scanner, expected_count, timeout_s=10): + """Poll a record-based scanner until expected_count records are collected.""" + collected = [] + deadline = time.monotonic() + timeout_s + while len(collected) < expected_count and time.monotonic() < deadline: + records = scanner.poll(5000) + collected.extend(records) + return collected + +@pytest.mark.asyncio +async def test_connection_context_manager(plaintext_bootstrap_servers): + config = fluss.Config({"bootstrap.servers": plaintext_bootstrap_servers}) + async with await fluss.FlussConnection.create(config) as conn: + admin = conn.get_admin() + nodes = await admin.get_server_nodes() + assert len(nodes) > 0 + + +@pytest.mark.asyncio +async def test_append_writer_success_flush(connection, admin): + table_path = fluss.TablePath("fluss", "test_append_ctx_success") + await admin.drop_table(table_path, ignore_if_not_exists=True) + + schema = fluss.Schema(pa.schema([pa.field("a", pa.int32())])) + await admin.create_table(table_path, fluss.TableDescriptor(schema)) + + table = await connection.get_table(table_path) + + async with table.new_append().create_writer() as writer: + writer.append({"a": 1}) + writer.append({"a": 2}) + # No explicit flush here + + # After context exit, data should be flushed + scanner = await table.new_scan().create_log_scanner() + scanner.subscribe(0, fluss.EARLIEST_OFFSET) + records = _poll_records(scanner, expected_count=2) + assert len(records) == 2 + assert sorted([r.row["a"] for r in records]) == [1, 2] + +@pytest.mark.asyncio +async def test_connection_drain_on_close(plaintext_bootstrap_servers, admin): + table_path = fluss.TablePath("fluss", "test_conn_drain") + await admin.drop_table(table_path, ignore_if_not_exists=True) + schema = fluss.Schema(pa.schema([pa.field("a", pa.int32())])) + await admin.create_table(table_path, fluss.TableDescriptor(schema)) + + config = fluss.Config({"bootstrap.servers": plaintext_bootstrap_servers}) + async with await fluss.FlussConnection.create(config) as conn: + table = await conn.get_table(table_path) + writer = table.new_append().create_writer() + writer.append({"a": 123}) + # No explicit flush, no writer context exit. + # Rely on connection.__aexit__ -> close() to drain. + + # Re-connect with a new connection to verify data arrived + async with await fluss.FlussConnection.create(config) as conn2: + table2 = await conn2.get_table(table_path) + scanner = await table2.new_scan().create_log_scanner() + scanner.subscribe(0, fluss.EARLIEST_OFFSET) + records = _poll_records(scanner, expected_count=1) + assert len(records) == 1 + assert records[0].row["a"] == 123 + +@pytest.mark.asyncio +async def test_upsert_writer_context_manager(connection, admin): + table_path = fluss.TablePath("fluss", "test_upsert_ctx") + await admin.drop_table(table_path, ignore_if_not_exists=True) + + schema = fluss.Schema(pa.schema([pa.field("id", pa.int32()), pa.field("v", pa.string())]), primary_keys=["id"]) + await admin.create_table(table_path, fluss.TableDescriptor(schema)) + + table = await connection.get_table(table_path) + + # Success path: verify it flushes + async with table.new_upsert().create_writer() as writer: + writer.upsert({"id": 1, "v": "a"}) + + lookuper = table.new_lookup().create_lookuper() + res = await lookuper.lookup({"id": 1}) + assert res is not None + assert res["v"] == "a" + +@pytest.mark.asyncio +async def test_connection_context_manager_exception(plaintext_bootstrap_servers): + config = fluss.Config({"bootstrap.servers": plaintext_bootstrap_servers}) + class TestException(Exception): pass + + try: + async with await fluss.FlussConnection.create(config) as conn: + raise TestException("connection error") + except TestException: + pass + # If we reach here without hanging, the connection __aexit__ gracefully handled the error \ No newline at end of file diff --git a/bindings/python/test/test_sasl_auth.py b/bindings/python/test/test_sasl_auth.py index 9dd2ddda..6889f1ab 100644 --- a/bindings/python/test/test_sasl_auth.py +++ b/bindings/python/test/test_sasl_auth.py @@ -45,7 +45,7 @@ async def test_sasl_connect_with_valid_credentials(sasl_bootstrap_servers): # Cleanup await admin.drop_database(db_name, ignore_if_not_exists=True, cascade=True) - conn.close() + await conn.close() async def test_sasl_connect_with_second_user(sasl_bootstrap_servers): @@ -62,7 +62,7 @@ async def test_sasl_connect_with_second_user(sasl_bootstrap_servers): # Basic operation to confirm functional connection assert not await admin.database_exists("some_nonexistent_db_alice") - conn.close() + await conn.close() async def test_sasl_connect_with_wrong_password(sasl_bootstrap_servers): diff --git a/crates/fluss/src/client/connection.rs b/crates/fluss/src/client/connection.rs index 62d440be..a3ffd755 100644 --- a/crates/fluss/src/client/connection.rs +++ b/crates/fluss/src/client/connection.rs @@ -28,10 +28,6 @@ use parking_lot::RwLock; use std::sync::Arc; use std::time::Duration; -// TODO: implement `close(&self, timeout: Duration)` to gracefully shut down the -// writer client (drain pending batches, then force-close on timeout). -// Java's FlussConnection.close() calls writerClient.close(Long.MAX_VALUE). -// WriterClient::close() already exists but is never called from the public API. pub struct FlussConnection { metadata: Arc, network_connects: Arc, @@ -73,6 +69,19 @@ impl FlussConnection { }) } + /// Gracefully shut down the connection, draining any pending write batches. + /// + /// If a writer client has been created, this method will signal it to drain + /// its buffers and wait for the background sender task to complete, bounded + /// by the provided timeout. + pub async fn close(&self, timeout: Duration) -> Result<()> { + let writer_client = self.writer_client.write().take(); + if let Some(client) = writer_client { + client.close(timeout).await?; + } + Ok(()) + } + pub fn get_metadata(&self) -> Arc { self.metadata.clone() }