Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions aiohttp/pytest_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,8 +199,10 @@ def _passthrough_loop_context(
else:
# this shadows loop_context's standard behavior
loop = setup_test_loop()
yield loop
teardown_test_loop(loop, fast=fast)
try:
yield loop
finally:
teardown_test_loop(loop, fast=fast)


def pytest_pycollect_makeitem(collector, name, obj): # type: ignore[no-untyped-def]
Expand Down
9 changes: 3 additions & 6 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
from aiohttp.compression_utils import ZLibBackend, ZLibBackendProtocol, set_zlib_backend
from aiohttp.helpers import TimerNoop
from aiohttp.http import WS_KEY, HttpVersion11
from aiohttp.test_utils import get_unused_port_socket, loop_context
from aiohttp.test_utils import REUSE_ADDRESS, loop_context


def pytest_configure(config: pytest.Config) -> None:
Expand Down Expand Up @@ -387,11 +387,8 @@ def unused_port_socket() -> Iterator[socket.socket]:
race condition between checking if the port is in use and
binding to it later in the test.
"""
s = get_unused_port_socket("127.0.0.1")
try:
with socket.create_server(("127.0.0.1", 0), reuse_port=REUSE_ADDRESS) as s:
yield s
finally:
s.close()


@pytest.fixture(params=["zlib", "zlib_ng.zlib_ng", "isal.isal_zlib"])
Expand Down Expand Up @@ -433,7 +430,7 @@ def maker(
session = ClientSession()
sessions.append(session)
default_args: ClientRequestArgs = {
"loop": loop,
"loop": asyncio.get_running_loop(),
"params": {},
"headers": CIMultiDict[str](),
"skip_auto_headers": None,
Expand Down
89 changes: 40 additions & 49 deletions tests/test_client_proto.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,19 +13,19 @@
from aiohttp.http_parser import HttpParser, RawResponseMessage


async def test_force_close(loop: asyncio.AbstractEventLoop) -> None:
async def test_force_close() -> None:
"""Ensure that the force_close method sets the should_close attribute to True.

This is used externally in aiodocker
https://github.com/aio-libs/aiodocker/issues/920
"""
proto = ResponseHandler(loop=loop)
proto = ResponseHandler(loop=asyncio.get_running_loop())
proto.force_close()
assert proto.should_close


async def test_oserror(loop: asyncio.AbstractEventLoop) -> None:
proto = ResponseHandler(loop=loop)
async def test_oserror() -> None:
proto = ResponseHandler(loop=asyncio.get_running_loop())
transport = mock.Mock()
proto.connection_made(transport)
proto.connection_lost(OSError())
Expand All @@ -34,9 +34,9 @@ async def test_oserror(loop: asyncio.AbstractEventLoop) -> None:
assert isinstance(proto.exception(), ClientOSError)


async def test_pause_resume_on_error(loop: asyncio.AbstractEventLoop) -> None:
async def test_pause_resume_on_error() -> None:
parser = mock.create_autospec(HttpParser, spec_set=True, instance=True)
proto = ResponseHandler(loop=loop)
proto = ResponseHandler(loop=asyncio.get_running_loop())
proto._parser = parser
transport = mock.Mock()
proto.connection_made(transport)
Expand All @@ -48,8 +48,8 @@ async def test_pause_resume_on_error(loop: asyncio.AbstractEventLoop) -> None:
assert not proto._reading_paused


async def test_client_proto_bad_message(loop: asyncio.AbstractEventLoop) -> None:
proto = ResponseHandler(loop=loop)
async def test_client_proto_bad_message() -> None:
proto = ResponseHandler(loop=asyncio.get_running_loop())
transport = mock.Mock()
proto.connection_made(transport)
proto.set_response_params()
Expand All @@ -60,8 +60,8 @@ async def test_client_proto_bad_message(loop: asyncio.AbstractEventLoop) -> None
assert isinstance(proto.exception(), http.HttpProcessingError)


async def test_uncompleted_message(loop: asyncio.AbstractEventLoop) -> None:
proto = ResponseHandler(loop=loop)
async def test_uncompleted_message() -> None:
proto = ResponseHandler(loop=asyncio.get_running_loop())
transport = mock.Mock()
proto.connection_made(transport)
proto.set_response_params(read_until_eof=True)
Expand All @@ -78,8 +78,8 @@ async def test_uncompleted_message(loop: asyncio.AbstractEventLoop) -> None:
assert dict(exc.message.headers) == {"Location": "http://python.org/"}


async def test_data_received_after_close(loop: asyncio.AbstractEventLoop) -> None:
proto = ResponseHandler(loop=loop)
async def test_data_received_after_close() -> None:
proto = ResponseHandler(loop=asyncio.get_running_loop())
transport = mock.Mock()
proto.connection_made(transport)
proto.set_response_params(read_until_eof=True)
Expand All @@ -92,9 +92,8 @@ async def test_data_received_after_close(loop: asyncio.AbstractEventLoop) -> Non
assert isinstance(proto.exception(), http.HttpProcessingError)


async def test_multiple_responses_one_byte_at_a_time(
loop: asyncio.AbstractEventLoop,
) -> None:
async def test_multiple_responses_one_byte_at_a_time() -> None:
loop = asyncio.get_running_loop()
proto = ResponseHandler(loop=loop)
proto.connection_made(mock.Mock())
conn = mock.Mock(protocol=proto)
Expand Down Expand Up @@ -128,9 +127,8 @@ async def test_multiple_responses_one_byte_at_a_time(
await response.read() == payload


async def test_unexpected_exception_during_data_received(
loop: asyncio.AbstractEventLoop,
) -> None:
async def test_unexpected_exception_during_data_received() -> None:
loop = asyncio.get_running_loop()
proto = ResponseHandler(loop=loop)

class PatchableHttpResponseParser(http.HttpResponseParser):
Expand Down Expand Up @@ -164,7 +162,8 @@ class PatchableHttpResponseParser(http.HttpResponseParser):
assert isinstance(proto.exception(), http.HttpProcessingError)


async def test_client_protocol_readuntil_eof(loop: asyncio.AbstractEventLoop) -> None:
async def test_client_protocol_readuntil_eof() -> None:
loop = asyncio.get_running_loop()
proto = ResponseHandler(loop=loop)
transport = mock.Mock()
proto.connection_made(transport)
Expand Down Expand Up @@ -203,32 +202,32 @@ async def test_client_protocol_readuntil_eof(loop: asyncio.AbstractEventLoop) ->
assert response.content.is_eof()


async def test_empty_data(loop: asyncio.AbstractEventLoop) -> None:
proto = ResponseHandler(loop=loop)
async def test_empty_data() -> None:
proto = ResponseHandler(loop=asyncio.get_running_loop())
proto.data_received(b"")

# do nothing


async def test_schedule_timeout(loop: asyncio.AbstractEventLoop) -> None:
proto = ResponseHandler(loop=loop)
async def test_schedule_timeout() -> None:
proto = ResponseHandler(loop=asyncio.get_running_loop())
proto.set_response_params(read_timeout=1)
assert proto._read_timeout_handle is None
proto.start_timeout()
assert proto._read_timeout_handle is not None


async def test_drop_timeout(loop: asyncio.AbstractEventLoop) -> None:
proto = ResponseHandler(loop=loop)
async def test_drop_timeout() -> None:
proto = ResponseHandler(loop=asyncio.get_running_loop())
proto.set_response_params(read_timeout=1)
proto.start_timeout()
assert proto._read_timeout_handle is not None
proto._drop_timeout()
assert proto._read_timeout_handle is None


async def test_reschedule_timeout(loop: asyncio.AbstractEventLoop) -> None:
proto = ResponseHandler(loop=loop)
async def test_reschedule_timeout() -> None:
proto = ResponseHandler(loop=asyncio.get_running_loop())
proto.set_response_params(read_timeout=1)
proto.start_timeout()
assert proto._read_timeout_handle is not None
Expand All @@ -238,23 +237,21 @@ async def test_reschedule_timeout(loop: asyncio.AbstractEventLoop) -> None:
assert proto._read_timeout_handle is not h


async def test_eof_received(loop: asyncio.AbstractEventLoop) -> None:
proto = ResponseHandler(loop=loop)
async def test_eof_received() -> None:
proto = ResponseHandler(loop=asyncio.get_running_loop())
proto.set_response_params(read_timeout=1)
proto.start_timeout()
assert proto._read_timeout_handle is not None
proto.eof_received()
assert proto._read_timeout_handle is None


async def test_connection_lost_sets_transport_to_none(
loop: asyncio.AbstractEventLoop, mocker: MockerFixture
) -> None:
async def test_connection_lost_sets_transport_to_none(mocker: MockerFixture) -> None:
"""Ensure that the transport is set to None when the connection is lost.

This ensures the writer knows that the connection is closed.
"""
proto = ResponseHandler(loop=loop)
proto = ResponseHandler(loop=asyncio.get_running_loop())
proto.connection_made(mocker.Mock())
assert proto.transport is not None

Expand All @@ -263,11 +260,9 @@ async def test_connection_lost_sets_transport_to_none(
assert proto.transport is None


async def test_connection_lost_exception_is_marked_retrieved(
loop: asyncio.AbstractEventLoop,
) -> None:
async def test_connection_lost_exception_is_marked_retrieved() -> None:
"""Test that connection_lost properly handles exceptions without warnings."""
proto = ResponseHandler(loop=loop)
proto = ResponseHandler(loop=asyncio.get_running_loop())
proto.connection_made(mock.Mock())

# Access closed property before connection_lost to ensure future is created
Expand All @@ -286,11 +281,9 @@ async def test_connection_lost_exception_is_marked_retrieved(
assert exc.__cause__ is ssl_error


async def test_closed_property_lazy_creation(
loop: asyncio.AbstractEventLoop,
) -> None:
async def test_closed_property_lazy_creation() -> None:
"""Test that closed future is created lazily."""
proto = ResponseHandler(loop=loop)
proto = ResponseHandler(loop=asyncio.get_running_loop())

# Initially, the closed future should not be created
assert proto._closed is None
Expand All @@ -305,11 +298,9 @@ async def test_closed_property_lazy_creation(
assert proto.closed is closed_future


async def test_closed_property_after_connection_lost(
loop: asyncio.AbstractEventLoop,
) -> None:
async def test_closed_property_after_connection_lost() -> None:
"""Test that closed property returns None after connection_lost if never accessed."""
proto = ResponseHandler(loop=loop)
proto = ResponseHandler(loop=asyncio.get_running_loop())
proto.connection_made(mock.Mock())

# Don't access proto.closed before connection_lost
Expand All @@ -319,9 +310,9 @@ async def test_closed_property_after_connection_lost(
assert proto.closed is None


async def test_abort(loop: asyncio.AbstractEventLoop) -> None:
async def test_abort() -> None:
"""Test the abort() method."""
proto = ResponseHandler(loop=loop)
proto = ResponseHandler(loop=asyncio.get_running_loop())

# Create a mock transport
transport = mock.Mock()
Expand All @@ -345,9 +336,9 @@ async def test_abort(loop: asyncio.AbstractEventLoop) -> None:
mock_drop_timeout.assert_called_once()


async def test_abort_without_transport(loop: asyncio.AbstractEventLoop) -> None:
async def test_abort_without_transport() -> None:
"""Test abort() when transport is None."""
proto = ResponseHandler(loop=loop)
proto = ResponseHandler(loop=asyncio.get_running_loop())

# Mock _drop_timeout method using patch.object
with mock.patch.object(proto, "_drop_timeout") as mock_drop_timeout:
Expand Down
Loading
Loading