Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/openfermion/circuits/vpe_circuits_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import cirq

from openfermion.measurements import get_phase_function
from openfermion.testing import retry_once_with_later_random_values

from .vpe_circuits import vpe_single_circuit, vpe_circuits_single_timestep

Expand All @@ -35,6 +36,7 @@ def test_single_circuit():
assert data_counts[1] == 100


@retry_once_with_later_random_values
def test_single_timestep():
q0 = cirq.GridQubit(0, 0)
q1 = cirq.GridQubit(0, 1)
Expand Down
6 changes: 5 additions & 1 deletion src/openfermion/testing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,4 +26,8 @@
module_importable,
)

from .wrapped import assert_equivalent_repr, assert_implements_consistent_protocols
from .wrapped import (
assert_equivalent_repr,
assert_implements_consistent_protocols,
retry_once_with_later_random_values,
)
25 changes: 25 additions & 0 deletions src/openfermion/testing/wrapped.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,3 +65,28 @@ def assert_implements_consistent_protocols(
global_vals=global_vals, # coverage: ignore
local_vals=local_vals, # coverage: ignore
) # coverage: ignore


def retry_once_with_later_random_values(testfunc: Any) -> Any:
"""Marks a test function for one retry with later random values.

This decorator is intended for test functions which occasionally fail
for specific random seeds from pytest-randomly.
"""
try:
return cirq.testing.retry_once_with_later_random_values(testfunc)
except AttributeError:
# decorator not available in cirq < 1.5.0
import functools
import warnings

@functools.wraps(testfunc)
def wrapped_func(*args, **kwargs) -> Any:
try:
return testfunc(*args, **kwargs)
except AssertionError:
pass
warnings.warn("Retrying in case we got a failing seed from pytest-randomly.")
return testfunc(*args, **kwargs)

return wrapped_func
83 changes: 83 additions & 0 deletions src/openfermion/testing/wrapped_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
# Copyright 2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest
from unittest import mock
import openfermion.testing.wrapped as wrapped


class MockCirqTesting:
@property
def retry_once_with_later_random_values(self):
raise AttributeError("No such attribute")


class MockCirq:
testing = MockCirqTesting()


def test_retry_once_fallback_success():
with mock.patch('openfermion.testing.wrapped.cirq', MockCirq()):

call_count = 0

@wrapped.retry_once_with_later_random_values
def successful_test():
nonlocal call_count
call_count += 1
return "Success"

assert successful_test() == "Success"
assert call_count == 1


def test_retry_once_fallback_flaky():
with mock.patch('openfermion.testing.wrapped.cirq', MockCirq()):

call_count = 0

@wrapped.retry_once_with_later_random_values
def flaky_test():
nonlocal call_count
call_count += 1
if call_count == 1:
raise AssertionError("Failed first time")
return "Success"

with pytest.warns(
UserWarning, match="Retrying in case we got a failing seed from pytest-randomly."
):
assert flaky_test() == "Success"

assert call_count == 2


def test_retry_once_fallback_failure():
with mock.patch('openfermion.testing.wrapped.cirq', MockCirq()):

call_count = 0

@wrapped.retry_once_with_later_random_values
def failing_test():
nonlocal call_count
call_count += 1
raise AssertionError("Failed both times")

with pytest.warns(
UserWarning, match="Retrying in case we got a failing seed from pytest-randomly."
):
with pytest.raises(AssertionError, match="Failed both times"):
failing_test()

assert call_count == 2
Loading