Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions kafka/net/transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ def __init__(self, net, sock, host=None):
self._closed = False
self._write_buffer = deque()
self._writing = False
self._read_task = None
self._write_task = None
self._protocol = None
self._read = False
self._write = True
Expand Down Expand Up @@ -80,7 +82,7 @@ def resume_reading(self):
data_received() method.
"""
if not self._read:
self._net.call_soon(self._read_from_sock)
self._read_task = self._net.call_soon(self._read_from_sock)
self._read = True
log.debug('%s: Resumed reading', self)

Expand Down Expand Up @@ -173,7 +175,7 @@ def write(self, data):
self._write_buffer.append(data)
if not self._writing:
self._writing = True
self._net.call_soon(self._write_to_sock)
self._write_task = self._net.call_soon(self._write_to_sock)

def writelines(self, list_of_data):
"""Write a list (or any iterable) of data bytes to the transport."""
Expand All @@ -182,7 +184,7 @@ def writelines(self, list_of_data):
self._write_buffer.extend(list_of_data)
if not self._writing:
self._writing = True
self._net.call_soon(self._write_to_sock)
self._write_task = self._net.call_soon(self._write_to_sock)

async def _write_to_sock(self):
try:
Expand Down Expand Up @@ -272,6 +274,10 @@ def _close(self, error=None):
except OSError:
pass
sock.close()
for task in (self._read_task, self._write_task):
if task is not None:
self._net.cancel(task)
self._read_task = self._write_task = None
proto = self._protocol
self._protocol = None
if proto is not None:
Expand Down
57 changes: 56 additions & 1 deletion test/net/test_transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import kafka.errors as Errors
from kafka.future import Future
from kafka.net.selector import NetworkSelector
from kafka.net.selector import NetworkSelector, TaskState
from kafka.net.transport import KafkaTCPTransport


Expand Down Expand Up @@ -365,3 +365,58 @@ def test_str_closed(self, net):
t._closed = True
s = str(t)
assert 'closed' in s


class TestTransportWaiterCleanup:
    """Regression coverage for task cleanup on a locally-initiated shutdown.

    When the transport is closed or aborted from our side, the socket
    read/write coroutine tasks that are suspended in the event loop have to be
    reclaimed by the transport itself.

    The tests only pass when ``KafkaTCPTransport._close`` cancels its
    read/write waiter tasks via ``net.cancel(task)``; per the original notes,
    the selector's WAIT_IO branch in ``cancel()`` then runs the io_guard
    finalizer and discards the task.
    """

    @staticmethod
    def _waiting_on_io(net):
        """Return the pending tasks whose state is ``TaskState.WAIT_IO``."""
        return [
            pending
            for pending in net._pending_tasks
            if pending.state is TaskState.WAIT_IO
        ]

    def test_local_close_reclaims_parked_reader(self, net, socketpair):
        peer_sock, local_sock = socketpair
        transport = KafkaTCPTransport(net, local_sock)
        transport.set_protocol(MagicMock())
        tasks_before = len(net._pending_tasks)

        transport.resume_reading()
        # One pass of the loop lets the reader start and suspend in WAIT_IO
        # on wait_read(local_sock).
        net.drain()
        assert len(self._waiting_on_io(net)) == 1, 'reader did not park as expected'
        assert len(net._pending_tasks) == tasks_before + 1

        # With nothing buffered for writing, close() tears the socket down
        # synchronously. The peer never sent a byte, so no I/O event will ever
        # wake the suspended reader -- close() has to reclaim it on its own.
        transport.close()
        # Spinning the loop again should be neither required nor helpful.
        net.drain()

        assert transport._sock is None
        assert len(net._pending_tasks) == tasks_before, (
            'parked reader leaked into _pending_tasks after local close')
        assert not self._waiting_on_io(net)

    def test_local_close_reclaims_parked_writer(self, net, socketpair):
        peer_sock, local_sock = socketpair
        transport = KafkaTCPTransport(net, local_sock)
        transport.set_protocol(MagicMock())
        tasks_before = len(net._pending_tasks)

        # write() schedules _write_to_sock; its loop suspends on the first
        # wait_write before any bytes leave the buffer. drain() advances it
        # exactly to that suspension point and returns once nothing is ready.
        transport.write(b'data')
        net.drain()
        assert len(self._waiting_on_io(net)) == 1, 'writer did not park as expected'

        transport.abort(error=Errors.KafkaConnectionError('boom'))
        net.drain()

        assert transport._sock is None
        assert len(net._pending_tasks) == tasks_before, (
            'parked writer leaked into _pending_tasks after abort')
        assert not self._waiting_on_io(net)