diff --git a/kafka/net/compat.py b/kafka/net/compat.py index cc959d3fe..fe71cd511 100644 --- a/kafka/net/compat.py +++ b/kafka/net/compat.py @@ -136,7 +136,6 @@ def poll(self, timeout_ms=None, future=None): def close(self, node_id=None): self._manager.close(node_id=node_id) if node_id is None: - self._manager.stop() self._net.close() def least_loaded_node(self, bootstrap_fallback=False): diff --git a/kafka/net/manager.py b/kafka/net/manager.py index 138744d38..13033d836 100644 --- a/kafka/net/manager.py +++ b/kafka/net/manager.py @@ -337,7 +337,7 @@ def connection_delay(self, node_id): return 0 return max(0, self._backoff[node_id][1] - time.monotonic()) - def close(self, node_id=None): + def close(self, node_id=None, timeout_ms=None): if node_id is not None: conn = self._conns.get(node_id) if conn is not None: @@ -348,6 +348,7 @@ def close(self, node_id=None): for conn in list(self._conns.values()): conn.close() self.cluster.close() + self.stop(timeout_ms) def start(self): """Spawn a daemon IO thread that owns the event loop. Idempotent.""" @@ -359,15 +360,16 @@ def start(self): self._io_thread = t t.start() - def stop(self, timeout=None): + def stop(self, timeout_ms=None): """Signal the IO thread to exit and join it. Fails any pending run() waiters with KafkaConnectionError. Idempotent.""" t = self._io_thread if t is None: + self._net.poll(drain=True) return self._io_thread = None self._net.stop() - t.join(timeout) + t.join(timeout_ms / 1000 if timeout_ms is not None else None) with self._pending_waiters_lock: waiters = list(self._pending_waiters.items()) self._pending_waiters.clear() diff --git a/kafka/net/selector.py b/kafka/net/selector.py index 0ddc68ca6..870388c2d 100644 --- a/kafka/net/selector.py +++ b/kafka/net/selector.py @@ -289,7 +289,9 @@ def add_writer(self, fileobj, task): def remove_writer(self, fileobj): self.unregister_event(fileobj, selectors.EVENT_WRITE) - def poll(self, timeout_ms=None, future=None): + def poll(self, timeout_ms=None, future=None, drain=False): + if drain and future: + raise ValueError('Cannot set both drain and future') if not self._poll_lock.acquire(blocking=False): # Lock contended. Distinguish recursive (this thread is already # in poll, e.g. via a task callback) from concurrent (a different @@ -305,7 +307,7 @@ def poll(self, timeout_ms=None, future=None): inner_timeout = timeout_ms / 1000 if timeout_ms is not None else None if future is not None and future.is_done: inner_timeout = 0 - while True: + while (not drain) or (drain and self._ready): self._poll_once(inner_timeout) if future is None or future.is_done: break