From 9e3bcc9a3e3e0222f0eaaa9f6e30be22231e5890 Mon Sep 17 00:00:00 2001 From: Dana Powers Date: Mon, 29 Jun 2026 12:16:25 -0700 Subject: [PATCH 1/2] Coalesce multiple wakeups for optimization --- kafka/net/wakeup_notifier.py | 30 ++++++--- test/net/test_wakeup_notifier.py | 105 +++++++++++++++++++++++++++++++ 2 files changed, 128 insertions(+), 7 deletions(-) diff --git a/kafka/net/wakeup_notifier.py b/kafka/net/wakeup_notifier.py index e2120dd7e..67041deb5 100644 --- a/kafka/net/wakeup_notifier.py +++ b/kafka/net/wakeup_notifier.py @@ -28,8 +28,19 @@ def __init__(self, net): # next ``__call__``. All accesses run on the IO thread (notify # routes through call_soon_threadsafe), so no lock is needed. self._pending = False + # Coalescing guard: True once a ``_wakeup`` has been scheduled via + # ``notify()`` but has not yet run on the IO thread. Lets ``notify()`` + # skip the redundant ``call_soon_threadsafe`` (Task alloc + socketpair + # write + selector wakeup) when a wake is already in flight. Set on + # user threads, cleared by ``_wakeup`` on the IO thread; cross-thread + # access is GIL-atomic and the check-then-set in ``notify()`` can at + # worst schedule one redundant wake, never drop one (see ``notify``). + self._scheduled = False def _wakeup(self): + # Clear the coalescing guard first so a notify() racing with this + # callback schedules a fresh wake rather than being dropped. + self._scheduled = False if self._fut is not None and not self._fut.is_done: self._fut.success(None) else: @@ -57,13 +68,18 @@ async def __call__(self, timeout_secs=None): self._net.cancel(timer) def notify(self): - # Always queue _wakeup on the IO thread. Skipping the queue when - # ``self._fut is None`` would re-introduce the lost-wakeup race: - # the check could pass before another thread enters ``__call__`` - # and creates the future. Routing through the IO thread is one - # call_soon_threadsafe (~microseconds) and lets ``_wakeup`` decide - # under single-threaded semantics whether to signal or latch. + # Coalesce: if a _wakeup is already scheduled and not yet consumed, + # skip. The state this notify() announces was mutated before this + # call, and the already-pending _wakeup runs (and drains) later, so it + # will observe that state -- no lost wakeup. We deliberately do NOT + # skip based on ``self._fut is None``. ``_scheduled`` is cleared by + # ``_wakeup`` the instant it runs, so a True value guarantees a wake + # is still pending. The check-then-set is GIL-atomic per access; a + # lost race between two threads only schedules one redundant wake. + if self._scheduled: + return + self._scheduled = True try: self._net.call_soon_threadsafe(self._wakeup) except ReferenceError: - pass + self._scheduled = False diff --git a/test/net/test_wakeup_notifier.py b/test/net/test_wakeup_notifier.py index c687ef9b1..7554e5fd6 100644 --- a/test/net/test_wakeup_notifier.py +++ b/test/net/test_wakeup_notifier.py @@ -123,6 +123,111 @@ async def task(): 'a single latch should be consumed by one __call__; took %.3fs' % elapsed) + def test_coalesced_notify_still_latches(self, net, notifier): + """Coalescing guard: a burst of notify() before any awaiter schedules + exactly one _wakeup (the rest are skipped because _scheduled is True), + and that single _wakeup still latches _pending -- so the skips do NOT + lose the wakeup. The next __call__ returns immediately.""" + async def task(): + notifier.notify() # schedules _wakeup; _scheduled=True + assert notifier._scheduled is True + notifier.notify() # coalesced: _scheduled already True + notifier.notify() # coalesced + await net.sleep(0) # let the one queued _wakeup run + assert notifier._scheduled is False, '_wakeup must clear the guard' + assert notifier._pending is True, 'coalesced burst must still latch' + start = time.monotonic() + await notifier(timeout_secs=5.0) + assert notifier._pending is False, 'latch consumed' + return time.monotonic() - start + + elapsed = net.run(task) + assert elapsed < 0.5, ( + 'latched wakeup should fire __call__ immediately; took %.3fs' % elapsed) + + def test_coalescing_reschedules_after_wakeup_runs(self, net, notifier): + """Once a _wakeup has run (clearing _scheduled), a subsequent notify() + must schedule a fresh wake rather than be silently skipped. Verifies + _scheduled is not a one-way latch that would swallow later notifies.""" + async def task(): + notifier.notify() + await net.sleep(0) # _wakeup runs, _scheduled=False + await notifier(timeout_secs=5.0) # consume first latched wake + assert notifier._scheduled is False + + # A brand new notify() after the first cycle must wake again. + notifier.notify() + assert notifier._scheduled is True + await net.sleep(0) + assert notifier._pending is True + start = time.monotonic() + await notifier(timeout_secs=5.0) + return time.monotonic() - start + + elapsed = net.run(task) + assert elapsed < 0.5, ( + 'second cycle should wake immediately; took %.3fs' % elapsed) + + def test_no_lost_wakeup_under_concurrent_notify_stress(self): + """Probabilistic regression guard for the coalescing path: a consumer + coroutine awaits the notifier every iteration (max race exposure) and + drains a shared queue; many cross-thread producers append work and + notify(). Each notify must either wake the current await, latch for the + next one, or be coalesced into an already-pending wake -- never lost. + + A long (5s) awaiter timeout means a genuinely lost-and-uncovered wakeup + would stall the consumer for ~5s; the test asserts the tail (time from + 'all producers done' to 'consumer drained everything') stays well under + that. With the latch + coalescing correct, it finishes in milliseconds. + """ + import collections + net = NetworkSelector() + net.start() + try: + notifier = WakeupNotifier(net) + work = collections.deque() + N = 4000 + n_threads = 4 + per = N // n_threads + consumed = [] + finished = threading.Event() + + async def consumer(): + while len(consumed) < N: + await notifier(timeout_secs=5.0) # always await -> every + # notify is a race chance + while work: # drain all available + try: + consumed.append(work.popleft()) + except IndexError: + break + finished.set() + + def producer(base): + for i in range(per): + work.append((base, i)) # deque append is atomic + notifier.notify() + + net.call_soon_with_future(consumer) # thread-safe schedule + threads = [threading.Thread(target=producer, args=(b,), daemon=True) + for b in range(n_threads)] + for t in threads: + t.start() + for t in threads: + t.join() + t_joined = time.monotonic() + ok = finished.wait(timeout=15) + tail = time.monotonic() - t_joined + + assert ok, 'consumer never finished -- a wakeup was lost' + assert len(consumed) == N, ( + 'consumed %d of %d -- lost work' % (len(consumed), N)) + assert tail < 2.0, ( + 'consumer stalled %.2fs after producers finished -- a tail ' + 'wakeup was lost and only the 5s timeout recovered it' % tail) + finally: + net.stop() + def test_notify_from_other_thread(self, net, notifier): """notify() is safe to call from another thread; the wakeup routes through call_soon_threadsafe to the IO thread.""" From 8912d0a16b586d3218bdc2c4150b6157288c48e6 Mon Sep 17 00:00:00 2001 From: Dana Powers Date: Mon, 29 Jun 2026 12:46:31 -0700 Subject: [PATCH 2/2] producer: run sender on net IO thread via coroutine --- kafka/producer/kafka.py | 39 +++++--- kafka/producer/sender.py | 176 ++++++++++++++++++++------------- test/producer/test_producer.py | 17 +++- test/producer/test_sender.py | 78 ++++++++++----- 4 files changed, 205 insertions(+), 105 deletions(-) diff --git a/kafka/producer/kafka.py b/kafka/producer/kafka.py index 6d1dd079e..20df36f00 100644 --- a/kafka/producer/kafka.py +++ b/kafka/producer/kafka.py @@ -525,7 +525,10 @@ def __init__(self, **configs): metrics=self._metrics, metric_group_prefix='producer', wakeup_timeout_ms=self.config['max_block_ms'], **self.config) - manager = client._manager + self._client = client + self._manager = manager = client._manager + self._net = client._net + self._net.start() # We currently depend on eager-resolution of api_version. # If it wasn't provided as a config option, we need to bootstrap @@ -629,7 +632,6 @@ def __init__(self, **configs): transaction_manager=self._transaction_manager, guarantee_message_order=guarantee_message_order, **self.config) - self._sender.daemon = True self._sender.start() self._closed = False @@ -693,25 +695,40 @@ def __getattr__(self, name): log.info("%s: Closing the Kafka producer with %s secs timeout.", str(self), timeout) self.flush(timeout) - invoked_from_callback = bool(threading.current_thread() is self._sender) + on_io_thread = bool(self._net._io_thread is not None + and threading.current_thread() is self._net._io_thread) if timeout > 0: - if invoked_from_callback: + if on_io_thread: log.warning("%s: Overriding close timeout %s secs to 0 in order to" " prevent useless blocking due to self-join. This" " means you have incorrectly invoked close with a" " non-zero timeout from the producer call-back.", str(self), timeout) - else: - # Try to close gracefully. - if self._sender is not None: - self._sender.initiate_close() - self._sender.join(timeout) - - if self._sender is not None and self._sender.is_alive(): + elif self._sender is not None: + self._sender.initiate_close() + try: + self._manager.run(self._manager.wait_for, + self._sender._loop_future, timeout * 1000) + except Errors.KafkaTimeoutError: + pass + + if self._sender is not None and self._sender.is_running(): log.info("%s: Proceeding to force close the producer since pending" " requests could not be completed within timeout %s.", str(self), timeout) self._sender.force_close() + if not on_io_thread: + try: + self._manager.run(self._manager.wait_for, + self._sender._loop_future, self.config['retry_backoff_ms']) + except Errors.KafkaTimeoutError: + pass + + if not on_io_thread: + try: + self._client.close() + except Exception: + log.exception("%s: Failed to close network client", str(self)) if self._metrics: self._metrics.close() diff --git a/kafka/producer/sender.py b/kafka/producer/sender.py index e362a26d7..02f197cc1 100644 --- a/kafka/producer/sender.py +++ b/kafka/producer/sender.py @@ -2,12 +2,12 @@ import copy import heapq import logging -import threading import time from kafka import errors as Errors from kafka.metrics.measurable import AnonMeasurable from kafka.metrics.stats import Avg, Max, Rate +from kafka.net.wakeup_notifier import WakeupNotifier from kafka.producer.transaction_manager import TransactionManager from kafka.protocol.producer import ProduceRequest, ProduceResponse from kafka.structs import TopicPartition @@ -24,11 +24,15 @@ _PartitionProduceResponse = ProduceResponse.TopicProduceResponse.PartitionProduceResponse -class Sender(threading.Thread): +class Sender: """ - The background thread that handles the sending of produce requests to the - Kafka cluster. This thread makes metadata requests to renew its view of the - cluster and then sends produce requests to the appropriate nodes. + Drives the sending of produce requests to the Kafka cluster. + + Runs as an ``async def _sender_loop`` coroutine on the shared ``kafka/net/`` + IO thread (scheduled via ``manager.call_soon`` in ``start``), alongside the + cluster metadata-refresh and coordinator-heartbeat loops. It drains the + accumulator, dispatches produce requests, and sleeps until the next deadline + on a thread-safe ``WakeupNotifier``. """ DEFAULT_CONFIG = { 'max_request_size': 1048576, @@ -45,19 +49,23 @@ class Sender(threading.Thread): } def __init__(self, client, metadata, accumulator, **configs): - super().__init__() self.config = copy.copy(self.DEFAULT_CONFIG) for key in self.config: if key in configs: self.config[key] = configs.pop(key) - self.name = self.config['client_id'] + '-network-thread' self._client = client + self._manager = client._manager + self._net = client._net self._accumulator = accumulator self._metadata = client.cluster self._running = True self._force_close = False self._topics_to_add = set() + self._wakeup = WakeupNotifier(self._net) + # Future returned by manager.call_soon(self._sender_loop); resolves when + # the loop coroutine returns. close() blocks on it. + self._loop_future = None if self.config['metrics']: self._sensors = SenderMetrics(self.config['metrics'], self._client, self._metadata) else: @@ -103,18 +111,32 @@ def _get_expired_inflight_batches(self, now=None): del self._in_flight_batches[tp] return expired_batches - def run(self): - """The main run loop for the sender thread.""" - log.debug("%s: Starting Kafka producer I/O thread.", str(self)) + def start(self): + """Schedule the sender coroutine on the IO thread. Idempotent. + + Returns the Future that resolves when the loop coroutine completes + (after the graceful drain on close). + """ + if self._loop_future is None or self._loop_future.is_done: + self._loop_future = self._manager.call_soon(self._sender_loop) + return self._loop_future + + def is_running(self): + return self._loop_future is not None and not self._loop_future.is_done + + async def _sender_loop(self): + """The main loop for the sender, run as a coroutine on the IO thread.""" + log.debug("%s: Starting Kafka producer I/O loop.", str(self)) # main loop, runs until close is called while self._running: try: - self.run_once() + await self._run_once() except Exception: - log.exception("%s: Uncaught error in kafka producer I/O thread", str(self)) + log.exception("%s: Uncaught error in kafka producer I/O loop", str(self)) + await self._net.sleep(self.config['retry_backoff_ms'] / 1000) - log.debug("%s: Beginning shutdown of Kafka producer I/O thread, sending" + log.debug("%s: Beginning shutdown of Kafka producer I/O loop, sending" " remaining records.", str(self)) # okay we stopped accepting requests but there may still be @@ -124,23 +146,19 @@ def run(self): and (self._accumulator.has_undrained() or self._client.in_flight_request_count() > 0)): try: - self.run_once() + await self._run_once() except Exception: - log.exception("%s: Uncaught error in kafka producer I/O thread", str(self)) + log.exception("%s: Uncaught error in kafka producer I/O loop", str(self)) + await self._net.sleep(self.config['retry_backoff_ms'] / 1000) if self._force_close: # We need to fail all the incomplete batches and wake up the # threads waiting on the futures. self._accumulator.abort_incomplete_batches() - try: - self._client.close() - except Exception: - log.exception("%s: Failed to close network client", str(self)) - - log.debug("%s: Shutdown of Kafka producer I/O thread has completed.", str(self)) + log.debug("%s: Shutdown of Kafka producer I/O loop has completed.", str(self)) - def run_once(self): + async def _run_once(self): """Run a single iteration of sending.""" while self._topics_to_add: self._metadata.add_topic(self._topics_to_add.pop()) @@ -155,9 +173,15 @@ def run_once(self): # below blocks new sends until the response arrives. self._transaction_manager.init_producer_id() - if self._transaction_manager.has_in_flight_transactional_request() or self._maybe_send_pending_request(): + if self._transaction_manager.has_in_flight_transactional_request(): # as long as there are outstanding transactional requests, we simply wait for them to return - self._client.poll(timeout_ms=self.config['retry_backoff_ms']) + await self._wakeup(self.config['retry_backoff_ms'] / 1000) + return + + if await self._maybe_send_pending_request(): + # A transactional request was dispatched or is backing off + # (the latter already awaited its delay); gate produce until + # the next iteration. return # do not continue sending if the transaction manager is in a failed state, if there @@ -170,7 +194,7 @@ def run_once(self): last_error = self._transaction_manager.last_error if last_error is not None: self._maybe_abort_batches(last_error) - self._client.poll(timeout_ms=self.config['retry_backoff_ms']) + await self._wakeup(self.config['retry_backoff_ms'] / 1000) return elif self._transaction_manager.has_abortable_error(): # Attempt to get the last error that caused this abort. @@ -185,7 +209,7 @@ def run_once(self): self._transaction_manager.authentication_failed(e) poll_timeout_ms = self._send_producer_data() - self._client.poll(timeout_ms=poll_timeout_ms) + await self._wakeup(poll_timeout_ms / 1000) def _send_producer_data(self, now=None): now = time.monotonic() if now is None else now @@ -203,7 +227,7 @@ def _send_producer_data(self, now=None): not_ready_timeout_ms = float('inf') for node in list(ready_nodes): if not self._client.is_ready(node): - node_delay_ms = self._client.connection_delay(node) + node_delay_ms = self._manager.connection_delay(node) log.debug('%s: Node %s not ready; delaying produce of accumulated batch (%f ms)', str(self), node, node_delay_ms) self._client.maybe_connect(node, wakeup=False) ready_nodes.remove(node) @@ -282,14 +306,20 @@ def _send_producer_data(self, now=None): for node_id, request in requests.items(): batches = batches_by_node[node_id] log.debug('%s: Sending Produce Request: %r', str(self), request) - (self._client.send(node_id, request, wakeup=False) + (self._manager.send(request, node_id=node_id) .add_callback( self._handle_produce_response, node_id, time.monotonic(), batches) .add_errback( self._failed_produce, batches, node_id)) return poll_timeout_ms - def _maybe_send_pending_request(self): + async def _maybe_send_pending_request(self): + """Dispatch the next pending transactional/idempotent coordinator request. + + Returns True if a request was dispatched or is backing off (in which case + the produce path is gated until the next loop iteration), False if there + is no pending request to send. + """ if self._transaction_manager.is_completing() and self._accumulator.has_incomplete: if self._transaction_manager.is_aborting(): # KIP-654: prefer the last error that triggered the abort; @@ -310,44 +340,48 @@ def _maybe_send_pending_request(self): return False log.debug("%s: Sending transactional request %s", str(self), next_request_handler.request) - while self._running and not self._force_close: - target_node = None - try: - if next_request_handler.needs_coordinator(): - target_node = self._transaction_manager.coordinator(next_request_handler.coordinator_type) - if target_node is None: - self._transaction_manager.lookup_coordinator_for_request(next_request_handler) - break - elif not self._client.await_ready(target_node, timeout_ms=self.config['request_timeout_ms']): - self._transaction_manager.lookup_coordinator_for_request(next_request_handler) - target_node = None - break - else: - target_node = self._client.least_loaded_node() - if target_node is None: - self._client.poll(timeout_ms=self.config['retry_backoff_ms'], - future=self._metadata.request_update()) - elif not self._client.await_ready(target_node, timeout_ms=self.config['request_timeout_ms']): - continue - - if target_node is not None: - if next_request_handler.is_retry: - time.sleep(self.config['retry_backoff_ms'] / 1000) - txn_correlation_id = self._transaction_manager.next_in_flight_request_correlation_id() - future = self._client.send(target_node, next_request_handler.request) - future.add_both(next_request_handler.on_complete, txn_correlation_id) - return True - - except Exception as e: - log.warning("%s: Got an exception when trying to find a node to send a transactional request to. Going to back off and retry: %s", str(self), e) - if next_request_handler.needs_coordinator(): + backoff = self.config['retry_backoff_ms'] / 1000 + try: + if next_request_handler.needs_coordinator(): + target_node = self._transaction_manager.coordinator(next_request_handler.coordinator_type) + if target_node is None or not self._client.is_ready(target_node): + # Coordinator unknown or its connection isn't ready: re-look + # up the coordinator (it may have moved) and back off. + if target_node is not None: + self._client.maybe_connect(target_node, wakeup=False) + backoff = max(backoff, self._manager.connection_delay(target_node)) self._transaction_manager.lookup_coordinator_for_request(next_request_handler) - break + self._transaction_manager.retry(next_request_handler) + await self._net.sleep(backoff) + return True + else: + target_node = self._manager.least_loaded_node() + if target_node is None: + # No known broker -- force a metadata refresh and back off. + self._metadata.request_update() + self._transaction_manager.retry(next_request_handler) + await self._net.sleep(backoff) + return True + if not self._client.is_ready(target_node): + self._client.maybe_connect(target_node, wakeup=False) + self._transaction_manager.retry(next_request_handler) + await self._net.sleep(max(backoff, self._manager.connection_delay(target_node))) + return True - if target_node is None: + if next_request_handler.is_retry: + await self._net.sleep(backoff) + txn_correlation_id = self._transaction_manager.next_in_flight_request_correlation_id() + future = self._manager.send(next_request_handler.request, node_id=target_node) + future.add_both(next_request_handler.on_complete, txn_correlation_id) + return True + + except Exception as e: + log.warning("%s: Got an exception when trying to find a node to send a transactional request to. Going to back off and retry: %s", str(self), e) + if next_request_handler.needs_coordinator(): + self._transaction_manager.lookup_coordinator_for_request(next_request_handler) self._transaction_manager.retry(next_request_handler) - - return True + await self._net.sleep(backoff) + return True def _maybe_abort_batches(self, exc): if self._accumulator.has_incomplete: @@ -383,6 +417,9 @@ def _failed_produce(self, batches, node_id, error): log.error("%s: Error sending produce request to node %d: %s", str(self), node_id, error) # trace for batch in batches: self._complete_batch_with_exception(batch, error) + # Completing batches frees in-flight capacity, unmutes partitions, and + # may re-enqueue retries; wake the loop to re-drain promptly. + self.wakeup() def _handle_produce_response(self, node_id, send_time, batches, response): """Handle a produce response.""" @@ -401,6 +438,9 @@ def _handle_produce_response(self, node_id, send_time, batches, response): synthetic = _PartitionProduceResponse(error_code=0) for batch in batches: self._complete_batch(batch, synthetic) + # Completing batches frees in-flight capacity, unmutes partitions, and + # may re-enqueue retries; wake the loop to re-drain promptly. + self.wakeup() def _record_exceptions_fn(self, top_level_exception, record_errors, error_message): """Returns a fn mapping batch_index to exception""" @@ -705,8 +745,12 @@ def _produce_request(self, node_id, acks, timeout, batches): ) def wakeup(self): - """Wake up the selector associated with this send thread.""" - self._client.wakeup() + """Wake the sender loop early (e.g. when a sendable batch is appended). + + Thread-safe: ``WakeupNotifier.notify`` routes through + ``call_soon_threadsafe``, so user threads may call this directly. + """ + self._wakeup.notify() def bootstrap_connected(self): return self._client.bootstrap_connected() diff --git a/test/producer/test_producer.py b/test/producer/test_producer.py index d8d92b996..25d44c68d 100644 --- a/test/producer/test_producer.py +++ b/test/producer/test_producer.py @@ -94,15 +94,22 @@ def _producer_for_send_test(partitioner): """Build a real KafkaProducer but replace the accumulator + sender with mocks so ``send()`` doesn't try to actually push data. - __init__ already starts a real Sender thread; we stop and join it before - swapping in the mock so it isn't orphaned (close() would otherwise act on - the mock and leak the real daemon thread). MockBroker keeps it off the - real network.""" + __init__ already starts the real Sender coroutine on the IO thread; we stop + it and wait on its loop Future before swapping in the mock so it isn't + orphaned (close() would otherwise act on the mock). MockBroker keeps it off + the real network.""" producer = _mock_producer(partitioner=partitioner) producer._sender.initiate_close() - producer._sender.join(2) + producer._manager.run(producer._manager.wait_for, producer._sender._loop_future, 2000) producer._accumulator = MagicMock() producer._sender = MagicMock() + # close() now blocks on the sender's loop Future; give the mock an + # already-resolved one (and is_running()==False) so teardown doesn't hang. + from kafka.future import Future + _done = Future() + _done.success(None) + producer._sender._loop_future = _done + producer._sender.is_running.return_value = False producer._metadata = MagicMock() producer._metadata.topics.return_value = {'t'} producer._metadata.partitions_for_topic.return_value = set(range(20)) diff --git a/test/producer/test_sender.py b/test/producer/test_sender.py index e127251f8..124959a26 100644 --- a/test/producer/test_sender.py +++ b/test/producer/test_sender.py @@ -30,6 +30,30 @@ _PartitionProduceResponse = ProduceResponse.TopicProduceResponse.PartitionProduceResponse +class _RecordingWakeup: + """Stand-in for ``Sender._wakeup`` (a WakeupNotifier) in unit tests. + + The sender loop sleeps via ``await self._wakeup(timeout_secs)``; this + records the timeouts so tests can assert on them without driving a real + timer, and provides a no-op ``notify`` for ``Sender.wakeup``. + """ + def __init__(self): + self.calls = [] + + async def __call__(self, timeout_secs=None): + self.calls.append(timeout_secs) + + def notify(self): + pass + + +def _drive(sender, coro_method): + """Run one of the sender's coroutine methods to completion on the test + selector (no IO thread is started in unit tests, so manager.run drives + the loop on the calling thread).""" + return sender._manager.run(coro_method) + + def _partition_response(error_cls=None, **kwargs): """Test helper that constructs a PartitionProduceResponse. @@ -397,7 +421,7 @@ def test_transaction_aborted_error_on_user_abort_with_undrained_batches(client, # Short-circuit the EndTxnHandler dispatch so we don't need a live # coordinator -- the abort happens before next_request_handler is consulted. mocker.patch.object(tm, 'next_request_handler', return_value=None) - sender._maybe_send_pending_request() + _drive(sender, sender._maybe_send_pending_request) assert future.failed() assert isinstance(future.exception, Errors.TransactionAbortedError) @@ -461,38 +485,41 @@ def test_failed_produce(sender, mocker): def test_run_once(sender, mocker): """The plain (non-transactional) iteration: drain pending topic adds into - cluster metadata, send producer data, and poll the client with the - timeout _send_producer_data returned.""" - mocker.patch.object(sender, '_client') + cluster metadata, send producer data, and sleep for the timeout + _send_producer_data returned.""" mocker.patch.object(sender, '_send_producer_data', return_value=42) + wakeup = _RecordingWakeup() + sender._wakeup = wakeup spy_add_topic = mocker.spy(sender._metadata, 'add_topic') sender.add_topic('foo-topic') - sender.run_once() + _drive(sender, sender._run_once) spy_add_topic.assert_called_once_with('foo-topic') assert not sender._topics_to_add sender._send_producer_data.assert_called_once() - sender._client.poll.assert_called_once_with(timeout_ms=42) + assert wakeup.calls == [42 / 1000] def test_run_once_gates_on_transactional_request(sender, transaction_manager, mocker): """An idempotent producer without a producer_id enqueues InitProducerId and waits on the transactional request instead of sending produce data.""" sender._transaction_manager = transaction_manager - mocker.patch.object(sender, '_client') mocker.patch.object(sender, '_send_producer_data') - mocker.patch.object(sender, '_maybe_send_pending_request', return_value=True) + mocker.patch.object(sender, '_maybe_send_pending_request', + new_callable=mocker.AsyncMock, return_value=True) + wakeup = _RecordingWakeup() + sender._wakeup = wakeup spy_init = mocker.spy(transaction_manager, 'init_producer_id') assert not transaction_manager.has_producer_id() - sender.run_once() + _drive(sender, sender._run_once) spy_init.assert_called_once() - sender._maybe_send_pending_request.assert_called_once() + sender._maybe_send_pending_request.assert_awaited_once() sender._send_producer_data.assert_not_called() - sender._client.poll.assert_called_once_with( - timeout_ms=sender.config['retry_backoff_ms']) + # _maybe_send_pending_request handled its own backoff; the loop just gates. + assert wakeup.calls == [] def test_run_once_fatal_error_aborts_batches(sender, transaction_manager, mocker): @@ -501,16 +528,16 @@ def test_run_once_fatal_error_aborts_batches(sender, transaction_manager, mocker transaction_manager.set_producer_id_and_epoch(ProducerIdAndEpoch(1000, 0)) error = Errors.ProducerFencedError() transaction_manager.transition_to_fatal_error(error) - mocker.patch.object(sender, '_client') mocker.patch.object(sender, '_send_producer_data') mocker.patch.object(sender, '_maybe_abort_batches') + wakeup = _RecordingWakeup() + sender._wakeup = wakeup - sender.run_once() + _drive(sender, sender._run_once) sender._maybe_abort_batches.assert_called_once_with(error) sender._send_producer_data.assert_not_called() - sender._client.poll.assert_called_once_with( - timeout_ms=sender.config['retry_backoff_ms']) + assert wakeup.calls == [sender.config['retry_backoff_ms'] / 1000] def test_run_once_abortable_error_aborts_undrained_batches(client, mocker): @@ -533,15 +560,16 @@ def test_run_once_abortable_error_aborts_undrained_batches(client, mocker): accumulator = RecordAccumulator(transaction_manager=tm) sender = Sender(client, cluster, accumulator, transaction_manager=tm) - mocker.patch.object(sender, '_client') mocker.patch.object(sender, '_send_producer_data', return_value=0) + wakeup = _RecordingWakeup() + sender._wakeup = wakeup spy_abort = mocker.spy(accumulator, 'abort_undrained_batches') - sender.run_once() + _drive(sender, sender._run_once) spy_abort.assert_called_once_with(error) sender._send_producer_data.assert_called_once() - sender._client.poll.assert_called_once_with(timeout_ms=0) + assert wakeup.calls == [0 / 1000] def test__send_producer_data_expiry_time_reset(sender, accumulator, mocker): @@ -1425,19 +1453,23 @@ def test_second_in_flight_error_does_not_cascade_bumps(self, sender, accumulator assert len(second_init_handlers) == 1 # still just one def test_sender_loop_gates_on_bumping_state(self, sender, accumulator, mocker): - """When in BUMPING_PRODUCER_EPOCH, run_once short-circuits before + """When in BUMPING_PRODUCER_EPOCH, _run_once short-circuits before sending produce data.""" from kafka.producer.transaction_manager import TransactionState as _TS tm = self._make_txn_manager() sender._transaction_manager = tm tm._current_state = _TS.BUMPING_PRODUCER_EPOCH mocker.patch.object(sender, '_send_producer_data') - mocker.patch.object(sender._client, 'poll') + # No pending coordinator request; fall through to the bumping gate. + mocker.patch.object(sender, '_maybe_send_pending_request', + new_callable=mocker.AsyncMock, return_value=False) + wakeup = _RecordingWakeup() + sender._wakeup = wakeup - sender.run_once() + _drive(sender, sender._run_once) sender._send_producer_data.assert_not_called() - sender._client.poll.assert_called_once() + assert wakeup.calls == [sender.config['retry_backoff_ms'] / 1000] class TestProducerAcks: