From 11afb649c4086ee8e2827884c008d6f9c04c59cc Mon Sep 17 00:00:00 2001 From: mhucka Date: Fri, 5 Jun 2026 01:43:13 +0000 Subject: [PATCH] Add test to cover check_commutability This adds a unit test for `util.check_commutability()`, a previously-uncovered function. --- tensorflow_quantum/python/util_test.py | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/tensorflow_quantum/python/util_test.py b/tensorflow_quantum/python/util_test.py index d22c118a8..57b66b1de 100644 --- a/tensorflow_quantum/python/util_test.py +++ b/tensorflow_quantum/python/util_test.py @@ -536,6 +536,31 @@ def test_get_circuit_symbols_error(self): 'cirq.Circuit'): util.get_circuit_symbols(param) + def test_check_commutability(self): + """Confirm that check_commutability works correctly.""" + q0, q1 = cirq.GridQubit.rect(1, 2) + + # Commutable terms + pauli_sum = cirq.X(q0) + cirq.X(q1) + util.check_commutability(pauli_sum) + + pauli_sum = cirq.X(q0) * cirq.Z(q1) + cirq.Z(q0) * cirq.X(q1) + util.check_commutability(pauli_sum) + + # Non-commutable terms + pauli_sum = cirq.X(q0) + cirq.Z(q0) + with self.assertRaisesRegex(ValueError, + expected_regex='non-commutable'): + util.check_commutability(pauli_sum) + + # Single term + pauli_sum = cirq.PauliSum.from_pauli_strings([cirq.X(q0)]) + util.check_commutability(pauli_sum) + + # Empty PauliSum + pauli_sum = cirq.PauliSum() + util.check_commutability(pauli_sum) + class ExponentialUtilFunctionsTest(tf.test.TestCase): """Test that Exponential utility functions work."""