Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 16 additions & 2 deletions tensorflow_quantum/python/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,10 @@ def random_symbol_circuit(qubits,
Symbols are randomly included in the gates of the first `n_moments` moments
of the resulting circuit. Then, parameterized H gates are added as
subsequent moments for any remaining unused symbols.

Args:
include_channels: If True, supported noise channels may be inserted into
the circuit. Channels are never parameterized or controlled.
"""
supported_ops = get_supported_gates()
if include_channels:
Expand Down Expand Up @@ -178,7 +182,12 @@ def random_circuit_resolver_batch(qubits,
n_moments=15,
p=0.9,
include_channels=False):
"""Generate a batch of random circuits and symbolless resolvers."""
"""Generate a batch of random circuits and symbolless resolvers.

Args:
include_channels: If True, supported noise channels may be inserted into
each circuit. Channels are never parameterized or controlled.
"""
supported_ops = get_supported_gates()
if include_channels:
for chan, n in get_supported_channels().items():
Expand Down Expand Up @@ -222,7 +231,12 @@ def random_symbol_circuit_resolver_batch(qubits,
p=0.9,
include_scalars=True,
include_channels=False):
"""Generate a batch of random circuits and resolvers."""
"""Generate a batch of random circuits and resolvers.

Args:
include_channels: If True, supported noise channels may be inserted into
each circuit. Channels are never parameterized or controlled.
"""
return_circuits = []
return_resolvers = []
for _ in range(batch_size):
Expand Down
151 changes: 151 additions & 0 deletions tensorflow_quantum/python/util_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
sys.path = NEW_PATH
# pylint: enable=wrong-import-position

import random

import numpy as np
import tensorflow as tf
from absl.testing import parameterized
Expand All @@ -47,6 +49,50 @@ def _exponential(theta, op):
return np.eye(op_mat.shape[0]) * np.cos(theta) - 1j * op_mat * np.sin(theta)


def _supported_channel_types():
return tuple(type(channel) for channel in util.get_supported_channels())


def _op_base_gate_and_controlled(op):
"""Return the underlying gate and whether the op is controlled."""
is_controlled = False
sub_op = op
if isinstance(op, cirq.ControlledOperation):
sub_op = op.sub_operation
is_controlled = True

gate = sub_op.gate
if isinstance(gate, cirq.ControlledGate):
gate = gate.sub_gate
is_controlled = True

return gate, is_controlled


def _op_has_supported_channel(op):
gate, _ = _op_base_gate_and_controlled(op)
return isinstance(gate, _supported_channel_types())


def _circuit_has_supported_channel(circuit):
return any(
_op_has_supported_channel(op) for moment in circuit for op in moment)


def _assert_channel_ops_valid(testcase, circuit):
"""Assert channels are neither parameterized nor controlled."""
for moment in circuit:
for op in moment:
gate, is_controlled = _op_base_gate_and_controlled(op)
if isinstance(gate, _supported_channel_types()):
testcase.assertFalse(
is_controlled,
msg=f'Found controlled channel operation: {op}')
testcase.assertFalse(
cirq.is_parameterized(op),
msg=f'Found parameterized channel operation: {op}')


BITS = list(cirq.GridQubit.rect(1, 10) + cirq.LineQubit.range(2))


Expand Down Expand Up @@ -139,6 +185,111 @@ def test_random_symbol_circuit_resolver_batch_shapes_and_types(
isinstance(value, float)
for value in resolver.param_dict.values()))

@parameterized.named_parameters(
('without_channels', False),
('with_channels', True),
)
def test_random_circuit_resolver_batch_channel_content(
self, include_channels):
"""Confirm channel ops are generated and handled as expected."""
random.seed(0)
Comment thread
rosspeili marked this conversation as resolved.
np.random.seed(0)
qubits = cirq.GridQubit.rect(1, 3)
batch_size = 8

circuits, _ = util.random_circuit_resolver_batch(
qubits, batch_size, n_moments=30, include_channels=include_channels)

for circuit in circuits:
if include_channels:
_assert_channel_ops_valid(self, circuit)
else:
self.assertFalse(_circuit_has_supported_channel(circuit))

if include_channels:
self.assertTrue(
any(_circuit_has_supported_channel(c) for c in circuits))

serialized = util.convert_to_tensor(circuits,
deterministic_proto_serialize=True)
round_tripped = util.from_tensor(serialized)
self.assertAllEqual(
serialized,
util.convert_to_tensor(round_tripped,
deterministic_proto_serialize=True))

@parameterized.named_parameters(
('without_channels', False),
('with_channels', True),
)
def test_random_symbol_circuit_resolver_batch_channel_content(
self, include_channels):
"""Confirm symbol circuits with channels keep symbols on gates only."""
random.seed(0)
Comment thread
rosspeili marked this conversation as resolved.
np.random.seed(0)
qubits = cirq.GridQubit.rect(1, 3)
symbols = ['alpha', 'beta', 'gamma']
batch_size = 8

circuits, _ = util.random_symbol_circuit_resolver_batch(
qubits,
symbols,
batch_size,
n_moments=30,
include_channels=include_channels)

for circuit in circuits:
self.assertSetEqual(set(util.get_circuit_symbols(circuit)),
set(symbols))
if include_channels:
_assert_channel_ops_valid(self, circuit)
else:
self.assertFalse(_circuit_has_supported_channel(circuit))

if include_channels:
self.assertTrue(
any(_circuit_has_supported_channel(c) for c in circuits))

serialized = util.convert_to_tensor(circuits,
deterministic_proto_serialize=True)
round_tripped = util.from_tensor(serialized)
self.assertAllEqual(
serialized,
util.convert_to_tensor(round_tripped,
deterministic_proto_serialize=True))

@parameterized.named_parameters(
('without_channels', False),
('with_channels', True),
)
def test_random_symbol_circuit_channel_content(self, include_channels):
"""Confirm random_symbol_circuit handles channels correctly."""
random.seed(0)
Comment thread
rosspeili marked this conversation as resolved.
np.random.seed(0)
qubits = cirq.GridQubit.rect(1, 3)
symbols = ['alpha', 'beta', 'gamma']

circuit = util.random_symbol_circuit(qubits,
symbols,
n_moments=30,
include_channels=include_channels)

self.assertSetEqual(set(util.get_circuit_symbols(circuit)),
set(symbols))
if include_channels:
_assert_channel_ops_valid(self, circuit)
self.assertTrue(_circuit_has_supported_channel(circuit))
else:
self.assertFalse(_circuit_has_supported_channel(circuit))

serialized = util.convert_to_tensor([circuit],
deterministic_proto_serialize=True)
round_tripped = util.from_tensor(serialized)
self.assertAllEqual(
serialized,
util.convert_to_tensor(round_tripped,
deterministic_proto_serialize=True))

@parameterized.parameters(_items_to_tensorize())
def test_convert_to_tensor(self, item):
"""Test that the convert_to_tensor function works correctly by manually
Expand Down
Loading