diff --git a/tensorflow_quantum/python/util_test.py b/tensorflow_quantum/python/util_test.py index d22c118a8..57b66b1de 100644 --- a/tensorflow_quantum/python/util_test.py +++ b/tensorflow_quantum/python/util_test.py @@ -536,6 +536,31 @@ def test_get_circuit_symbols_error(self): 'cirq.Circuit'): util.get_circuit_symbols(param) + def test_check_commutability(self): + """Confirm that check_commutability works correctly.""" + q0, q1 = cirq.GridQubit.rect(1, 2) + + # Commutable terms + pauli_sum = cirq.X(q0) + cirq.X(q1) + util.check_commutability(pauli_sum) + + pauli_sum = cirq.X(q0) * cirq.Z(q1) + cirq.Z(q0) * cirq.X(q1) + util.check_commutability(pauli_sum) + + # Non-commutable terms + pauli_sum = cirq.X(q0) + cirq.Z(q0) + with self.assertRaisesRegex(ValueError, + expected_regex='non-commutable'): + util.check_commutability(pauli_sum) + + # Single term + pauli_sum = cirq.PauliSum.from_pauli_strings([cirq.X(q0)]) + util.check_commutability(pauli_sum) + + # Empty PauliSum + pauli_sum = cirq.PauliSum() + util.check_commutability(pauli_sum) + class ExponentialUtilFunctionsTest(tf.test.TestCase): """Test that Exponential utility functions work."""