diff --git a/tensorflow_quantum/python/util.py b/tensorflow_quantum/python/util.py index 9972f26b5..f945df916 100644 --- a/tensorflow_quantum/python/util.py +++ b/tensorflow_quantum/python/util.py @@ -133,6 +133,10 @@ def random_symbol_circuit(qubits, Symbols are randomly included in the gates of the first `n_moments` moments of the resulting circuit. Then, parameterized H gates are added as subsequent moments for any remaining unused symbols. + + Args: + include_channels: If True, supported noise channels may be inserted into + the circuit. Channels are never parameterized or controlled. """ supported_ops = get_supported_gates() if include_channels: @@ -178,7 +182,12 @@ def random_circuit_resolver_batch(qubits, n_moments=15, p=0.9, include_channels=False): - """Generate a batch of random circuits and symbolless resolvers.""" + """Generate a batch of random circuits and symbolless resolvers. + + Args: + include_channels: If True, supported noise channels may be inserted into + each circuit. Channels are never parameterized or controlled. + """ supported_ops = get_supported_gates() if include_channels: for chan, n in get_supported_channels().items(): @@ -222,7 +231,12 @@ def random_symbol_circuit_resolver_batch(qubits, p=0.9, include_scalars=True, include_channels=False): - """Generate a batch of random circuits and resolvers.""" + """Generate a batch of random circuits and resolvers. + + Args: + include_channels: If True, supported noise channels may be inserted into + each circuit. Channels are never parameterized or controlled. + """ return_circuits = [] return_resolvers = [] for _ in range(batch_size): diff --git a/tensorflow_quantum/python/util_test.py b/tensorflow_quantum/python/util_test.py index d22c118a8..1d9dd0e2d 100644 --- a/tensorflow_quantum/python/util_test.py +++ b/tensorflow_quantum/python/util_test.py @@ -21,6 +21,8 @@ sys.path = NEW_PATH # pylint: enable=wrong-import-position +import random + import numpy as np import tensorflow as tf from absl.testing import parameterized @@ -47,6 +49,50 @@ def _exponential(theta, op): return np.eye(op_mat.shape[0]) * np.cos(theta) - 1j * op_mat * np.sin(theta) +def _supported_channel_types(): + return tuple(type(channel) for channel in util.get_supported_channels()) + + +def _op_base_gate_and_controlled(op): + """Return the underlying gate and whether the op is controlled.""" + is_controlled = False + sub_op = op + if isinstance(op, cirq.ControlledOperation): + sub_op = op.sub_operation + is_controlled = True + + gate = sub_op.gate + if isinstance(gate, cirq.ControlledGate): + gate = gate.sub_gate + is_controlled = True + + return gate, is_controlled + + +def _op_has_supported_channel(op): + gate, _ = _op_base_gate_and_controlled(op) + return isinstance(gate, _supported_channel_types()) + + +def _circuit_has_supported_channel(circuit): + return any( + _op_has_supported_channel(op) for moment in circuit for op in moment) + + +def _assert_channel_ops_valid(testcase, circuit): + """Assert channels are neither parameterized nor controlled.""" + for moment in circuit: + for op in moment: + gate, is_controlled = _op_base_gate_and_controlled(op) + if isinstance(gate, _supported_channel_types()): + testcase.assertFalse( + is_controlled, + msg=f'Found controlled channel operation: {op}') + testcase.assertFalse( + cirq.is_parameterized(op), + msg=f'Found parameterized channel operation: {op}') + + BITS = list(cirq.GridQubit.rect(1, 10) + cirq.LineQubit.range(2)) @@ -139,6 +185,111 @@ def test_random_symbol_circuit_resolver_batch_shapes_and_types( isinstance(value, float) for value in resolver.param_dict.values())) + @parameterized.named_parameters( + ('without_channels', False), + ('with_channels', True), + ) + def test_random_circuit_resolver_batch_channel_content( + self, include_channels): + """Confirm channel ops are generated and handled as expected.""" + random.seed(0) + np.random.seed(0) + qubits = cirq.GridQubit.rect(1, 3) + batch_size = 8 + + circuits, _ = util.random_circuit_resolver_batch( + qubits, batch_size, n_moments=30, include_channels=include_channels) + + for circuit in circuits: + if include_channels: + _assert_channel_ops_valid(self, circuit) + else: + self.assertFalse(_circuit_has_supported_channel(circuit)) + + if include_channels: + self.assertTrue( + any(_circuit_has_supported_channel(c) for c in circuits)) + + serialized = util.convert_to_tensor(circuits, + deterministic_proto_serialize=True) + round_tripped = util.from_tensor(serialized) + self.assertAllEqual( + serialized, + util.convert_to_tensor(round_tripped, + deterministic_proto_serialize=True)) + + @parameterized.named_parameters( + ('without_channels', False), + ('with_channels', True), + ) + def test_random_symbol_circuit_resolver_batch_channel_content( + self, include_channels): + """Confirm symbol circuits with channels keep symbols on gates only.""" + random.seed(0) + np.random.seed(0) + qubits = cirq.GridQubit.rect(1, 3) + symbols = ['alpha', 'beta', 'gamma'] + batch_size = 8 + + circuits, _ = util.random_symbol_circuit_resolver_batch( + qubits, + symbols, + batch_size, + n_moments=30, + include_channels=include_channels) + + for circuit in circuits: + self.assertSetEqual(set(util.get_circuit_symbols(circuit)), + set(symbols)) + if include_channels: + _assert_channel_ops_valid(self, circuit) + else: + self.assertFalse(_circuit_has_supported_channel(circuit)) + + if include_channels: + self.assertTrue( + any(_circuit_has_supported_channel(c) for c in circuits)) + + serialized = util.convert_to_tensor(circuits, + deterministic_proto_serialize=True) + round_tripped = util.from_tensor(serialized) + self.assertAllEqual( + serialized, + util.convert_to_tensor(round_tripped, + deterministic_proto_serialize=True)) + + @parameterized.named_parameters( + ('without_channels', False), + ('with_channels', True), + ) + def test_random_symbol_circuit_channel_content(self, include_channels): + """Confirm random_symbol_circuit handles channels correctly.""" + random.seed(0) + np.random.seed(0) + qubits = cirq.GridQubit.rect(1, 3) + symbols = ['alpha', 'beta', 'gamma'] + + circuit = util.random_symbol_circuit(qubits, + symbols, + n_moments=30, + include_channels=include_channels) + + self.assertSetEqual(set(util.get_circuit_symbols(circuit)), + set(symbols)) + if include_channels: + _assert_channel_ops_valid(self, circuit) + self.assertTrue(_circuit_has_supported_channel(circuit)) + else: + self.assertFalse(_circuit_has_supported_channel(circuit)) + + serialized = util.convert_to_tensor([circuit], + deterministic_proto_serialize=True) + round_tripped = util.from_tensor(serialized) + self.assertAllEqual( + serialized, + util.convert_to_tensor(round_tripped, + deterministic_proto_serialize=True)) + @parameterized.parameters(_items_to_tensorize()) def test_convert_to_tensor(self, item): """Test that the convert_to_tensor function works correctly by manually