Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion kafka/net/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,7 +365,7 @@ def stop(self, timeout_ms=None):
waiters with KafkaConnectionError. Idempotent."""
t = self._io_thread
if t is None:
self._net.poll(drain=True)
self._net.drain()
return
self._io_thread = None
self._net.stop()
Expand Down
156 changes: 78 additions & 78 deletions kafka/net/selector.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,10 @@ def run_until_done(self, task_or_future):
self._poll_once()
return task_or_future

def drain(self):
while self._ready:
self._poll_once()

def call_at(self, when, task):
if not isinstance(task, Task):
task = Task(task)
Expand Down Expand Up @@ -218,23 +222,20 @@ def sleep(self, delay):

def _sleep(self, delay):
self.call_later(delay, self._current)
self._current = None

def wait_write(self, fileobj):
return KernelEvent('_wait_write', fileobj)

def _wait_write(self, fileobj):
self.register_event(fileobj, selectors.EVENT_WRITE, self._current)
self._current.push_stack(lambda: self.unregister_event(fileobj, selectors.EVENT_WRITE))
self._current = None

def wait_read(self, fileobj):
return KernelEvent('_wait_read', fileobj)

def _wait_read(self, fileobj):
self.register_event(fileobj, selectors.EVENT_READ, self._current)
self._current.push_stack(lambda: self.unregister_event(fileobj, selectors.EVENT_READ))
self._current = None

def _schedule_tasks(self):
while self._scheduled and self._scheduled[0][0] <= time.monotonic():
Expand Down Expand Up @@ -289,9 +290,24 @@ def add_writer(self, fileobj, task):
def remove_writer(self, fileobj):
self.unregister_event(fileobj, selectors.EVENT_WRITE)

def poll(self, timeout_ms=None, future=None, drain=False):
if drain and future:
raise ValueError('Cannot set both drain and future')
def poll(self, timeout_ms=None, future=None):
log_trace('poll: enter')
start_at = time.monotonic()
inner_timeout = timeout_ms / 1000 if timeout_ms is not None else None
if future is not None and future.is_done:
inner_timeout = 0
while True:
self._poll_once(inner_timeout)
if future is None or future.is_done:
break
elif timeout_ms is not None:
inner_timeout = (timeout_ms / 1000) - (time.monotonic() - start_at)
if inner_timeout <= 0:
break
log_trace('poll: exit')

def _poll_once(self, timeout=None):
log_trace('_poll_once: enter')
if not self._poll_lock.acquire(blocking=False):
# Lock contended. Distinguish recursive (this thread is already
# in poll, e.g. via a task callback) from concurrent (a different
Expand All @@ -302,84 +318,68 @@ def poll(self, timeout_ms=None, future=None, drain=False):
raise RuntimeError('Concurrent access to net.poll!')
self._poll_owner = threading.current_thread()
try:
log_trace('poll: enter')
start_at = time.monotonic()
inner_timeout = timeout_ms / 1000 if timeout_ms is not None else None
if future is not None and future.is_done:
inner_timeout = 0
while (not drain) or (drain and self._ready):
self._poll_once(inner_timeout)
if future is None or future.is_done:
break
elif timeout_ms is not None:
inner_timeout = (timeout_ms / 1000) - (time.monotonic() - start_at)
if inner_timeout <= 0:
break
finally:
self._poll_owner = None
self._poll_lock.release()
log_trace('poll: exit')

def _poll_once(self, timeout=None):
log_trace('_poll_once: enter')
if self._ready:
timeout = 0
else:
scheduled_timeout = self._next_scheduled_timeout(time.monotonic())
if scheduled_timeout is not None:
timeout = min(timeout, scheduled_timeout) if timeout is not None else scheduled_timeout
if timeout is not None:
if timeout > MAX_TIMEOUT:
timeout = MAX_TIMEOUT
elif timeout < 0:
if self._ready:
timeout = 0
else:
scheduled_timeout = self._next_scheduled_timeout(time.monotonic())
if scheduled_timeout is not None:
timeout = min(timeout, scheduled_timeout) if timeout is not None else scheduled_timeout
if timeout is not None:
if timeout > MAX_TIMEOUT:
timeout = MAX_TIMEOUT
elif timeout < 0:
timeout = 0
elif not self._selector.get_map():
timeout = 0
elif not self._selector.get_map():
timeout = 0

ready_events = self._selector.select(timeout)
log_trace('_poll_once: %d ready_events', len(ready_events))
self._process_events(ready_events)
self._schedule_tasks()

threshold = self.config['slow_task_threshold_secs']
n = len(self._ready)
for i in range(n):
self._current = self._ready.popleft()
step_start = time.monotonic() if threshold else None
try:
log_trace('Calling task %s', self._current)
event = self._current()

except StopIteration:
pass
ready_events = self._selector.select(timeout)
log_trace('_poll_once: %d ready_events', len(ready_events))
self._process_events(ready_events)
self._schedule_tasks()

except BaseException as e:
log.exception(e)
threshold = self.config['slow_task_threshold_secs']
n = len(self._ready)
for i in range(n):
self._current = self._ready.popleft()
step_start = time.monotonic() if threshold else None
try:
log_trace('Calling task %s', self._current)
event = self._current()

except StopIteration:
pass

except BaseException as e:
log.exception(e)

else:
if isinstance(event, KernelEvent):
log_trace('kernel event %s', event.method)
getattr(self, event.method)(*event.args)
elif isinstance(event, Future):
event.add_both(lambda _, task=self._current: self.call_soon(task))
else:
raise RuntimeError('Unhandled event type: %s' % event)

if threshold:
elapsed = time.monotonic() - step_start
if elapsed > threshold:
msg = (
'Task %r ran for %.3fs (>%.3fs threshold). It is '
'blocking the event loop -- likely a tight sync loop '
'inside a coroutine. Other pollers will time out.'
% (self._current, elapsed, threshold))
if self.config['raise_on_slow_task']:
self._current = None
raise RuntimeError(msg)
log.warning(msg)
if isinstance(event, KernelEvent):
log_trace('kernel event %s', event.method)
getattr(self, event.method)(*event.args)
elif isinstance(event, Future):
event.add_both(lambda _, task=self._current: self.call_soon(task))
else:
raise RuntimeError('Unhandled event type: %s' % event)

if threshold:
elapsed = time.monotonic() - step_start
if elapsed > threshold:
msg = (
'Task %r ran for %.3fs (>%.3fs threshold). It is '
'blocking the event loop -- likely a tight sync loop '
'inside a coroutine. Other pollers will time out.'
% (self._current, elapsed, threshold))
if self.config['raise_on_slow_task']:
self._current = None
raise RuntimeError(msg)
log.warning(msg)

self._current = None

self._current = None
log_trace('_poll_once: exit')
finally:
self._poll_owner = None
self._poll_lock.release()
log_trace('_poll_once: exit')

def wakeup(self):
try:
Expand Down
Loading