diff --git a/src/openfermion/circuits/vpe_circuits_test.py b/src/openfermion/circuits/vpe_circuits_test.py index 95b04a747..781af140e 100644 --- a/src/openfermion/circuits/vpe_circuits_test.py +++ b/src/openfermion/circuits/vpe_circuits_test.py @@ -15,6 +15,7 @@ import cirq from openfermion.measurements import get_phase_function +from openfermion.testing import retry_once_with_later_random_values from .vpe_circuits import vpe_single_circuit, vpe_circuits_single_timestep @@ -35,6 +36,7 @@ def test_single_circuit(): assert data_counts[1] == 100 +@retry_once_with_later_random_values def test_single_timestep(): q0 = cirq.GridQubit(0, 0) q1 = cirq.GridQubit(0, 1) diff --git a/src/openfermion/testing/__init__.py b/src/openfermion/testing/__init__.py index 344e8add9..92b655a01 100644 --- a/src/openfermion/testing/__init__.py +++ b/src/openfermion/testing/__init__.py @@ -26,4 +26,8 @@ module_importable, ) -from .wrapped import assert_equivalent_repr, assert_implements_consistent_protocols +from .wrapped import ( + assert_equivalent_repr, + assert_implements_consistent_protocols, + retry_once_with_later_random_values, +) diff --git a/src/openfermion/testing/wrapped.py b/src/openfermion/testing/wrapped.py index 7fd8f2a19..6c5784b7e 100644 --- a/src/openfermion/testing/wrapped.py +++ b/src/openfermion/testing/wrapped.py @@ -65,3 +65,28 @@ def assert_implements_consistent_protocols( global_vals=global_vals, # coverage: ignore local_vals=local_vals, # coverage: ignore ) # coverage: ignore + + +def retry_once_with_later_random_values(testfunc: Any) -> Any: + """Marks a test function for one retry with later random values. + + This decorator is intended for test functions which occasionally fail + for specific random seeds from pytest-randomly. + """ + try: + return cirq.testing.retry_once_with_later_random_values(testfunc) + except AttributeError: + # decorator not available in cirq < 1.5.0 + import functools + import warnings + + @functools.wraps(testfunc) + def wrapped_func(*args, **kwargs) -> Any: + try: + return testfunc(*args, **kwargs) + except AssertionError: + pass + warnings.warn("Retrying in case we got a failing seed from pytest-randomly.") + return testfunc(*args, **kwargs) + + return wrapped_func diff --git a/src/openfermion/testing/wrapped_test.py b/src/openfermion/testing/wrapped_test.py new file mode 100644 index 000000000..fb340af23 --- /dev/null +++ b/src/openfermion/testing/wrapped_test.py @@ -0,0 +1,83 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +from unittest import mock +import openfermion.testing.wrapped as wrapped + + +class MockCirqTesting: + @property + def retry_once_with_later_random_values(self): + raise AttributeError("No such attribute") + + +class MockCirq: + testing = MockCirqTesting() + + +def test_retry_once_fallback_success(): + with mock.patch('openfermion.testing.wrapped.cirq', MockCirq()): + + call_count = 0 + + @wrapped.retry_once_with_later_random_values + def successful_test(): + nonlocal call_count + call_count += 1 + return "Success" + + assert successful_test() == "Success" + assert call_count == 1 + + +def test_retry_once_fallback_flaky(): + with mock.patch('openfermion.testing.wrapped.cirq', MockCirq()): + + call_count = 0 + + @wrapped.retry_once_with_later_random_values + def flaky_test(): + nonlocal call_count + call_count += 1 + if call_count == 1: + raise AssertionError("Failed first time") + return "Success" + + with pytest.warns( + UserWarning, match="Retrying in case we got a failing seed from pytest-randomly." + ): + assert flaky_test() == "Success" + + assert call_count == 2 + + +def test_retry_once_fallback_failure(): + with mock.patch('openfermion.testing.wrapped.cirq', MockCirq()): + + call_count = 0 + + @wrapped.retry_once_with_later_random_values + def failing_test(): + nonlocal call_count + call_count += 1 + raise AssertionError("Failed both times") + + with pytest.warns( + UserWarning, match="Retrying in case we got a failing seed from pytest-randomly." + ): + with pytest.raises(AssertionError, match="Failed both times"): + failing_test() + + assert call_count == 2