diff --git a/docs/source/release-history/v11.0.0.md b/docs/source/release-history/v11.0.0.md index d0ccce7f..9f264933 100644 --- a/docs/source/release-history/v11.0.0.md +++ b/docs/source/release-history/v11.0.0.md @@ -13,3 +13,11 @@ Release date: 2026-xx-x Python 3.9 reached [end-of-life](https://devguide.python.org/developer-workflow/development-cycle/index.html#end-of-life-branches) on [October 31, 2025](https://devguide.python.org/versions/). It is no longer receiving any updates, even security updates. The MSS project has chosen to end support for Python 3.9, in order to focus our resources on current versions of Python. + +### Windows Improvements + +Improved error handling when interacting with Win32 API, which will improve diagnostics of issues. + +### General Improvements + +The MSS context object will now always surface inner exceptions, even if `__exit__` may also generate an exception during tear-down. diff --git a/src/mss/base.py b/src/mss/base.py index 85cee1c8..f5f373bf 100644 --- a/src/mss/base.py +++ b/src/mss/base.py @@ -16,6 +16,7 @@ if TYPE_CHECKING: from collections.abc import Callable, Iterator + from types import TracebackType from mss.models import Monitor, Monitors, Size @@ -260,9 +261,21 @@ def __enter__(self) -> Self: """For the cool call `with MSS() as mss:`.""" return self - def __exit__(self, *_: object) -> None: + def __exit__( + self, + _exc_type: type[BaseException] | None, + exc_value: BaseException | None, + _traceback: TracebackType | None, + ) -> None: """For the cool call `with MSS() as mss:`.""" - self.close() + try: + self.close() + except Exception: + # This extra work is needed so that exceptions generated during __exit__ + # will not swallow exceptions that caused the __exit__ to be called + if exc_value is not None: + return + raise def close(self) -> None: """Clean up. diff --git a/src/mss/windows/gdi.py b/src/mss/windows/gdi.py index 18d41c88..02f5ba62 100644 --- a/src/mss/windows/gdi.py +++ b/src/mss/windows/gdi.py @@ -109,23 +109,39 @@ class DISPLAY_DEVICEW(Structure): # noqa: N801 MONITORNUMPROC = WINFUNCTYPE(BOOL, HMONITOR, HDC, POINTER(RECT), LPARAM) -def _errcheck(result: int | _Pointer, func: Callable, arguments: tuple) -> tuple: - """If the result is zero, raise an exception.""" +def _check_result(result: int | _Pointer, func: Callable, arguments: tuple) -> tuple: + """Raise if ``result`` is 0/NULL for functions that do not document GetLastError.""" if not result: - # Notably, the errno that is in winerror may not be relevant. Use the winerror and strerror attributes - # instead. - winerror = WinError() details = { "func": func.__name__, "args": arguments, - "error_code": winerror.winerror, - "error_msg": winerror.strerror, + "error_msg": "The function returned a failure value.", } - if winerror.winerror == 0: - # Some functions return NULL/0 on failure without setting last error. (Example: CreateDIBSection - # with an invalid HDC.) - msg = f"Windows graphics function failed (no error provided): {func.__name__}" + msg = f"Windows graphics function returned failure: {func.__name__}" + raise ScreenShotError(msg, details=details) + return arguments + + +def _check_result_with_last_error(result: int | _Pointer, func: Callable, arguments: tuple) -> tuple: + """Raise if ``result`` is 0/NULL for functions that document GetLastError.""" + if not result: + error_code = ctypes.get_last_error() + details = { + "func": func.__name__, + "args": arguments, + "error_code": error_code, + } + if error_code == 0: + # Do not use WinError(0) here: its message is "The operation completed successfully", which makes the + # failure look like a success. + details["error_msg"] = ( + "The function returned a failure value, but no Windows last-error code was available." + ) + msg = f"Windows graphics function returned failure but no last-error code was available: {func.__name__}" raise ScreenShotError(msg, details=details) + # Notably, the errno that is in winerror may not be relevant. Use the winerror and strerror attributes instead. + winerror = WinError(error_code) + details["error_msg"] = winerror.strerror msg = f"Windows graphics function failed: {func.__name__}: {winerror.strerror}" raise ScreenShotError(msg, details=details) from winerror return arguments @@ -138,15 +154,20 @@ def _errcheck(result: int | _Pointer, func: Callable, arguments: tuple) -> tuple # Note: keep it sorted by cfunction. CFUNCTIONS: CFunctionsErrChecked = { # Syntax: cfunction: (attr, argtypes, restype, errcheck) - "BitBlt": ("gdi32", [HDC, INT, INT, INT, INT, HDC, INT, INT, DWORD], BOOL, _errcheck), - "CreateCompatibleDC": ("gdi32", [HDC], HDC, _errcheck), + "BitBlt": ("gdi32", [HDC, INT, INT, INT, INT, HDC, INT, INT, DWORD], BOOL, _check_result_with_last_error), + "CreateCompatibleDC": ("gdi32", [HDC], HDC, _check_result), # CreateDIBSection: ppvBits (4th param) receives a pointer to the DIB pixel data. # hSection is NULL and offset is 0 to have the system allocate the memory. - "CreateDIBSection": ("gdi32", [HDC, POINTER(BITMAPINFO), UINT, POINTER(LPVOID), HANDLE, DWORD], HBITMAP, _errcheck), - "DeleteDC": ("gdi32", [HDC], HDC, _errcheck), - "DeleteObject": ("gdi32", [HGDIOBJ], BOOL, _errcheck), + "CreateDIBSection": ( + "gdi32", + [HDC, POINTER(BITMAPINFO), UINT, POINTER(LPVOID), HANDLE, DWORD], + HBITMAP, + _check_result_with_last_error, + ), + "DeleteDC": ("gdi32", [HDC], BOOL, _check_result), + "DeleteObject": ("gdi32", [HGDIOBJ], BOOL, _check_result), "EnumDisplayDevicesW": ("user32", [POINTER(WORD), DWORD, POINTER(DISPLAY_DEVICEW), DWORD], BOOL, None), - "EnumDisplayMonitors": ("user32", [HDC, LPCRECT, MONITORNUMPROC, LPARAM], BOOL, _errcheck), + "EnumDisplayMonitors": ("user32", [HDC, LPCRECT, MONITORNUMPROC, LPARAM], BOOL, _check_result), # GdiFlush flushes the calling thread's current batch of GDI operations. # This ensures DIB memory is fully updated before reading. "GdiFlush": ("gdi32", [], BOOL, None), @@ -154,12 +175,12 @@ def _errcheck(result: int | _Pointer, func: Callable, arguments: tuple) -> tuple # parameter is valid but the value is actually 0 (e.g., SM_CLEANBOOT on a normal boot). Thus, we do not attach an # errcheck function here. "GetSystemMetrics": ("user32", [INT], INT, None), - "GetMonitorInfoW": ("user32", [HMONITOR, POINTER(MONITORINFOEXW)], BOOL, _errcheck), - "GetWindowDC": ("user32", [HWND], HDC, _errcheck), - "ReleaseDC": ("user32", [HWND, HDC], INT, _errcheck), + "GetMonitorInfoW": ("user32", [HMONITOR, POINTER(MONITORINFOEXW)], BOOL, _check_result), + "GetWindowDC": ("user32", [HWND], HDC, _check_result), + "ReleaseDC": ("user32", [HWND, HDC], INT, _check_result), # SelectObject returns NULL on error the way we call it. If it's called to select a region, it returns HGDI_ERROR # on error. - "SelectObject": ("gdi32", [HDC, HGDIOBJ], HGDIOBJ, _errcheck), + "SelectObject": ("gdi32", [HDC, HGDIOBJ], HGDIOBJ, _check_result), } diff --git a/src/tests/test_implementation.py b/src/tests/test_implementation.py index 335e26f6..278fe3ca 100644 --- a/src/tests/test_implementation.py +++ b/src/tests/test_implementation.py @@ -20,12 +20,13 @@ from mss.base import MSS, MSSImplementation from mss.exception import ScreenShotError from mss.screenshot import ScreenShot +from tests.thread_helpers import run_threads if TYPE_CHECKING: from collections.abc import Callable from typing import Any - from mss.models import Monitor, Monitors + from mss.models import Monitor, Monitors, Size try: from datetime import UTC @@ -54,12 +55,55 @@ def monitors(self) -> Monitors: return [] +class MSSCloseRaises(MSSImplementation): + """Implementation whose cleanup fails.""" + + def __init__(self, close_error: Exception) -> None: + super().__init__() + self.close_error = close_error + + def cursor(self) -> None: + pass + + def grab(self, _: Monitor) -> bytearray | tuple[bytearray, Size]: + return bytearray() + + def monitors(self) -> Monitors: + return [] + + def close(self) -> None: + raise self.close_error + + @pytest.mark.parametrize("cls", [MSS0, MSS1, MSS2]) def test_incomplete_class(cls: type[MSSImplementation]) -> None: with pytest.raises(TypeError): cls() +def test_context_manager_keeps_body_exception_when_close_fails(monkeypatch: pytest.MonkeyPatch) -> None: + body_error = RuntimeError("body failed") + close_error = RuntimeError("close failed") + impl = MSSCloseRaises(close_error) + monkeypatch.setattr("mss.base._choose_impl", lambda **_kwargs: impl) + + with pytest.raises(RuntimeError, match="body failed") as exc, MSS(): + raise body_error + + assert exc.value is body_error + + +def test_context_manager_reports_close_failure_after_clean_exit(monkeypatch: pytest.MonkeyPatch) -> None: + close_error = RuntimeError("close failed") + impl = MSSCloseRaises(close_error) + monkeypatch.setattr("mss.base._choose_impl", lambda **_kwargs: impl) + + with pytest.raises(RuntimeError, match="close failed") as exc, MSS(): + pass + + assert exc.value is close_error + + def test_bad_monitor(mss_impl: Callable[..., MSS]) -> None: with mss_impl() as sct, pytest.raises(ScreenShotError): sct.shot(mon=222) @@ -303,17 +347,8 @@ def record() -> None: checkpoint[threading.current_thread()] = True - checkpoint: dict = {} - t1 = threading.Thread(target=record) - t2 = threading.Thread(target=record) - - t1.start() - time.sleep(0.5) - t2.start() - - t1.join() - t2.join() - + checkpoint: dict[threading.Thread, bool] = {} + run_threads(record, record, start_delay=0.5) assert len(checkpoint) == 2 def test_issue_169(self, backend: str) -> None: diff --git a/src/tests/test_setup.py b/src/tests/test_setup.py index cb629d16..fde4dcea 100644 --- a/src/tests/test_setup.py +++ b/src/tests/test_setup.py @@ -114,6 +114,7 @@ def test_sdist() -> None: f"mss-{__version__}/src/tests/third_party/__init__.py", f"mss-{__version__}/src/tests/third_party/test_numpy.py", f"mss-{__version__}/src/tests/third_party/test_pil.py", + f"mss-{__version__}/src/tests/thread_helpers.py", f"mss-{__version__}/src/xcbproto/README.md", f"mss-{__version__}/src/xcbproto/gen_xcb_to_py.py", f"mss-{__version__}/src/xcbproto/randr.xml", diff --git a/src/tests/test_windows.py b/src/tests/test_windows.py index 41a43646..612a98f9 100644 --- a/src/tests/test_windows.py +++ b/src/tests/test_windows.py @@ -4,16 +4,17 @@ from __future__ import annotations -import threading +import ctypes import pytest import mss from mss.exception import ScreenShotError +from tests.thread_helpers import run_threads try: import mss.windows - from mss.windows.gdi import MSSImplGdi + from mss.windows.gdi import MSSImplGdi, _check_result, _check_result_with_last_error except ImportError: pytestmark = pytest.mark.skip @@ -33,6 +34,70 @@ def test_factory_gdi_backend() -> None: assert type(gdi_sct._impl) is MSSImplGdi +def test_check_result_with_last_error_zero_is_not_reported_as_success() -> None: + """A failed Windows API call may leave ``GetLastError()`` set to 0.""" + + def fake_func() -> int: + return 0 + + previous_last_error = ctypes.get_last_error() + try: + ctypes.set_last_error(0) + with pytest.raises(ScreenShotError, match="returned failure but no last-error code was available") as exc: + _check_result_with_last_error(0, fake_func, ()) + finally: + ctypes.set_last_error(previous_last_error) + + assert exc.value.details["error_code"] == 0 + assert exc.value.details["error_msg"] == ( + "The function returned a failure value, but no Windows last-error code was available." + ) + + +def test_check_result_with_last_error_reports_error_code() -> None: + """A failed Windows API call should report a non-zero ``GetLastError()`` value.""" + + def fake_func() -> int: + return 0 + + error_code = 8 + previous_last_error = ctypes.get_last_error() + try: + ctypes.set_last_error(error_code) + with pytest.raises(ScreenShotError, match="Windows graphics function failed: fake_func:") as exc: + _check_result_with_last_error(0, fake_func, ()) + finally: + ctypes.set_last_error(previous_last_error) + + assert exc.value.details["func"] == "fake_func" + assert exc.value.details["args"] == () + assert exc.value.details["error_code"] == error_code + assert exc.value.details["error_msg"] + assert isinstance(exc.value.__cause__, OSError) + assert exc.value.__cause__.winerror == error_code + + +def test_check_result_ignores_stale_last_error() -> None: + """Some Windows APIs do not document ``GetLastError()`` diagnostics.""" + + def fake_func() -> None: + pass + + previous_last_error = ctypes.get_last_error() + try: + ctypes.set_last_error(8) + with pytest.raises(ScreenShotError, match="Windows graphics function returned failure: fake_func") as exc: + _check_result(0, fake_func, ()) + finally: + ctypes.set_last_error(previous_last_error) + + assert exc.value.details == { + "func": "fake_func", + "args": (), + "error_msg": "The function returned a failure value.", + } + + def test_region_caching() -> None: """The region to grab is cached, ensure this is well-done.""" with mss.MSS() as sct: @@ -92,12 +157,7 @@ def test_thread_safety() -> None: The following code will throw a ScreenShotError exception if thread-safety is not guaranteed. """ # Let thread 1 finished ahead of thread 2 - thread1 = threading.Thread(target=run_child_thread, args=(30,)) - thread2 = threading.Thread(target=run_child_thread, args=(50,)) - thread1.start() - thread2.start() - thread1.join() - thread2.join() + run_threads(lambda: run_child_thread(30), lambda: run_child_thread(50)) def run_child_thread_bbox(loops: int, bbox: tuple[int, int, int, int]) -> None: @@ -111,9 +171,7 @@ def test_thread_safety_regions() -> None: The following code will throw a ScreenShotError exception if thread-safety is not guaranteed. """ - thread1 = threading.Thread(target=run_child_thread_bbox, args=(100, (0, 0, 100, 100))) - thread2 = threading.Thread(target=run_child_thread_bbox, args=(100, (0, 0, 50, 1))) - thread1.start() - thread2.start() - thread1.join() - thread2.join() + run_threads( + lambda: run_child_thread_bbox(100, (0, 0, 100, 100)), + lambda: run_child_thread_bbox(100, (0, 0, 50, 1)), + ) diff --git a/src/tests/thread_helpers.py b/src/tests/thread_helpers.py new file mode 100644 index 00000000..9cb6b9b1 --- /dev/null +++ b/src/tests/thread_helpers.py @@ -0,0 +1,31 @@ +"""Helpers for tests that need to run work on background threads.""" + +from __future__ import annotations + +import threading +import time +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from collections.abc import Callable + + +def run_threads(*targets: Callable[[], None], start_delay: float = 0.0) -> None: + errors: list[Exception] = [] + + def record(target: Callable[[], None]) -> None: + try: + target() + except Exception as exc: # noqa: BLE001 - transport worker failures to the main test thread. + errors.append(exc) + + threads = [threading.Thread(target=record, args=(target,)) for target in targets] + for index, thread in enumerate(threads): + thread.start() + if start_delay and index < len(threads) - 1: + time.sleep(start_delay) + for thread in threads: + thread.join() + + if errors: + raise errors[0]