Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 43 additions & 5 deletions kafka/net/selector.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,16 @@ def exception(self):
class NetworkSelector:
DEFAULT_CONFIG = {
'selector': selectors.DefaultSelector,
# Warn (or, in debug mode, raise) when a single ready-task step takes
# longer than this many seconds. A coroutine that hits this threshold
# is blocking the event loop -- common cause is a tight sync loop
# over a synchronously-raising await (see cluster._refresh_loop hang
# where RuntimeError from a closed manager was caught and retried).
# Mirrors asyncio's loop.slow_callback_duration. Set to 0 to disable.
'slow_task_threshold_secs': 0.1,
# When True, raise RuntimeError on slow tasks instead of just warning.
# Useful in tests so livelocks fail loudly.
'raise_on_slow_task': False,
}

def __init__(self, **configs):
Expand All @@ -126,7 +136,14 @@ def __init__(self, **configs):
if key in configs:
self.config[key] = configs[key]

self._lock = threading.Lock()
# Used by poll() as both a mutex (cross-thread concurrent-entry guard)
# and the in-loop flag. acquire(blocking=False) doubles as the
# "is anyone in poll() right now?" check. Held only across poll()'s
# body; never held by anything else.
# _poll_owner tracks which thread holds the lock so we can produce
# an accurate diagnostic (recursive vs concurrent) on contention.
self._poll_lock = threading.Lock()
self._poll_owner = None
self._closed = False
self._stop = False
self._selector = self.config['selector']()
Expand Down Expand Up @@ -273,10 +290,15 @@ def remove_writer(self, fileobj):
self.unregister_event(fileobj, selectors.EVENT_WRITE)

def poll(self, timeout_ms=None, future=None):
if self._current:
raise RuntimeError('Recursive access to net.poll!')
elif not self._lock.acquire(blocking=False):
if not self._poll_lock.acquire(blocking=False):
# Lock contended. Distinguish recursive (this thread is already
# in poll, e.g. via a task callback) from concurrent (a different
# thread is in poll). Same-thread reentry of a non-RLock fails
# the same way as cross-thread contention.
if self._poll_owner is threading.current_thread():
raise RuntimeError('Recursive access to net.poll!')
raise RuntimeError('Concurrent access to net.poll!')
self._poll_owner = threading.current_thread()
try:
log_trace('poll: enter')
start_at = time.monotonic()
Expand All @@ -292,7 +314,8 @@ def poll(self, timeout_ms=None, future=None):
if inner_timeout <= 0:
break
finally:
self._lock.release()
self._poll_owner = None
self._poll_lock.release()
log_trace('poll: exit')

def _poll_once(self, timeout=None):
Expand All @@ -316,9 +339,11 @@ def _poll_once(self, timeout=None):
self._process_events(ready_events)
self._schedule_tasks()

threshold = self.config['slow_task_threshold_secs']
n = len(self._ready)
for i in range(n):
self._current = self._ready.popleft()
step_start = time.monotonic() if threshold else None
try:
log_trace('Calling task %s', self._current)
event = self._current()
Expand All @@ -338,6 +363,19 @@ def _poll_once(self, timeout=None):
else:
raise RuntimeError('Unhandled event type: %s' % event)

if threshold:
elapsed = time.monotonic() - step_start
if elapsed > threshold:
msg = (
'Task %r ran for %.3fs (>%.3fs threshold). It is '
'blocking the event loop -- likely a tight sync loop '
'inside a coroutine. Other pollers will time out.'
% (self._current, elapsed, threshold))
if self.config['raise_on_slow_task']:
self._current = None
raise RuntimeError(msg)
log.warning(msg)

self._current = None
log_trace('_poll_once: exit')

Expand Down
138 changes: 138 additions & 0 deletions test/net/test_selector.py
Original file line number Diff line number Diff line change
Expand Up @@ -516,3 +516,141 @@ async def resolver():
net.call_soon(resolver)
net.poll(timeout_ms=1000, future=done)
assert results == [('a', 'b')]


class TestSlowTaskMonitor:
"""Detection for tasks that hog the event loop (livelock guard).

See task #44: a coroutine in a tight sync loop never yields back to the
selector. From the outside this looks like a hang; with monitoring it
becomes a clean warning (or, in raise-mode, a RuntimeError).
"""

def test_slow_task_warns_with_default_threshold(self, caplog):
net = NetworkSelector(slow_task_threshold_secs=0.01)
done = Future()

async def hog():
time.sleep(0.05) # synchronous sleep — does not yield to loop
done.success(True)

net.call_soon(hog)
with caplog.at_level('WARNING', logger='kafka.net.selector'):
net.poll(timeout_ms=1000, future=done)
assert any('blocking the event loop' in rec.message for rec in caplog.records), (
'expected slow-task warning, got: %r'
% [(r.levelname, r.message) for r in caplog.records])
assert done.succeeded()

def test_slow_task_below_threshold_no_warning(self, caplog):
net = NetworkSelector(slow_task_threshold_secs=0.5)
done = Future()

async def quick():
done.success(True)

net.call_soon(quick)
with caplog.at_level('WARNING', logger='kafka.net.selector'):
net.poll(timeout_ms=1000, future=done)
assert not any('blocking the event loop' in rec.message for rec in caplog.records)

def test_slow_task_disabled_when_threshold_zero(self, caplog):
net = NetworkSelector(slow_task_threshold_secs=0)
done = Future()

async def hog():
time.sleep(0.02)
done.success(True)

net.call_soon(hog)
with caplog.at_level('WARNING', logger='kafka.net.selector'):
net.poll(timeout_ms=1000, future=done)
assert not any('blocking the event loop' in rec.message for rec in caplog.records)

def test_slow_task_raise_mode(self):
net = NetworkSelector(slow_task_threshold_secs=0.01,
raise_on_slow_task=True)
done = Future()

async def hog():
time.sleep(0.05)
done.success(True)

net.call_soon(hog)
with pytest.raises(RuntimeError, match='blocking the event loop'):
net.poll(timeout_ms=1000, future=done)

def test_concurrent_poll_raises(self):
"""Two threads calling poll() simultaneously should raise instead of
racing on selector / task state."""
net = NetworkSelector()
gate = threading.Event()
done = Future()
errors = []

async def slow():
gate.set()
time.sleep(0.1)
done.success(True)

def driver_a():
net.call_soon(slow)
net.poll(timeout_ms=1000, future=done)

def driver_b():
gate.wait(timeout=1)
try:
net.poll(timeout_ms=10)
except RuntimeError as exc:
errors.append(str(exc))

ta = threading.Thread(target=driver_a)
tb = threading.Thread(target=driver_b)
ta.start()
tb.start()
ta.join(2)
tb.join(2)
assert errors and 'Concurrent access' in errors[0], (
'expected Concurrent access error, got: %r' % errors)

def test_recursive_poll_raises_recursive_error(self):
"""A task callback calling poll() reentrantly should be diagnosed as
recursive (same-thread), not as concurrent."""
net = NetworkSelector()
errors = []

async def reenter():
try:
net.poll(timeout_ms=10)
except RuntimeError as exc:
errors.append(str(exc))

net.call_soon(reenter)
net.poll(timeout_ms=100)
assert errors and 'Recursive access' in errors[0], (
'expected Recursive access error, got: %r' % errors)

def test_poll_lock_released_on_exception(self):
"""An exception in _poll_once must release the poll lock so the next
caller doesn't see a stale 'Concurrent access' error."""
net = NetworkSelector()

# Inject a coroutine that raises a base-level error to escape the
# per-task BaseException catch (StopIteration / Exception are caught).
# We use a custom signal: monkey-patch _poll_once to raise.
orig = net._poll_once
first_call = [True]

def _poll_once_raising(*args, **kwargs):
if first_call[0]:
first_call[0] = False
raise KeyboardInterrupt('simulated Ctrl-C')
return orig(*args, **kwargs)

net._poll_once = _poll_once_raising
with pytest.raises(KeyboardInterrupt):
net.poll(timeout_ms=10)

# Restore and verify the lock was released so the next poll succeeds.
net._poll_once = orig
net.poll(timeout_ms=10) # would raise 'Concurrent access' if leaked
Loading