Source code for pennylane.transforms.core.transform_dispatcher
# Copyright 2023 Xanadu Quantum Technologies Inc.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This module contains the transform dispatcher and the transform container.
"""
import functools
import os
import warnings
from collections.abc import Callable, Sequence
from copy import copy
from pennylane import capture, math
from pennylane.capture.autograph import wraps
from pennylane.exceptions import TransformError
from pennylane.measurements import MeasurementProcess
from pennylane.operation import Operator
from pennylane.pytrees import flatten
from pennylane.queuing import AnnotatedQueue, QueuingManager, apply
from pennylane.tape import QuantumScript
from pennylane.typing import ResultBatch
@functools.lru_cache
def _create_transform_primitive():
try:
# pylint: disable=import-outside-toplevel
from pennylane.capture.custom_primitives import QmlPrimitive
except ImportError:
return None
transform_prim = QmlPrimitive("transform")
transform_prim.multiple_results = True
transform_prim.prim_type = "transform"
# pylint: disable=too-many-arguments, disable=unused-argument
@transform_prim.def_impl
def _impl(*all_args, inner_jaxpr, args_slice, consts_slice, targs_slice, tkwargs, transform):
args = all_args[slice(*args_slice)]
consts = all_args[slice(*consts_slice)]
return capture.eval_jaxpr(inner_jaxpr, consts, *args)
@transform_prim.def_abstract_eval
def _abstract_eval(*_, inner_jaxpr, **__):
return [out.aval for out in inner_jaxpr.outvars]
return transform_prim
def _create_plxpr_fallback_transform(tape_transform):
# pylint: disable=import-outside-toplevel
try:
import jax
from pennylane.tape import plxpr_to_tape
except ImportError:
return None
def plxpr_fallback_transform(jaxpr, consts, targs, tkwargs, *args):
# Restore tkwargs from hashable tuple to dict
tkwargs = dict(tkwargs)
def wrapper(*inner_args):
tape = plxpr_to_tape(jaxpr, consts, *inner_args)
with capture.pause():
tapes, _ = tape_transform(tape, *targs, **tkwargs)
if len(tapes) > 1:
raise TransformError(
f"Cannot apply {tape_transform.__name__} transform with program "
"capture enabled. Only transforms that return a single QuantumTape "
"and null processing function are usable with program capture."
)
for op in tapes[0].operations:
data, struct = jax.tree_util.tree_flatten(op)
jax.tree_util.tree_unflatten(struct, data)
out = []
for mp in tapes[0].measurements:
data, struct = jax.tree_util.tree_flatten(mp)
out.append(jax.tree_util.tree_unflatten(struct, data))
return tuple(out)
abstracted_axes, abstract_shapes = capture.determine_abstracted_axes(args)
return jax.make_jaxpr(wrapper, abstracted_axes=abstracted_axes)(*abstract_shapes, *args)
return plxpr_fallback_transform
[docs]
def specific_apply_transform(transform, obj, *targs, **tkwargs):
    """Default type-specific application used by ``TransformDispatcher._apply_transform``.

    No type-specific behaviour is registered at this level, so this simply defers to the
    transform's generic registration via ``transform.generic_apply_transform``.
    """
    fallback = transform.generic_apply_transform
    return fallback(obj, *targs, **tkwargs)
[docs]
@functools.singledispatch
def generic_apply_transform(obj, transform, *targs, **tkwargs):
    """Apply a generic transform to a specific type of object.

    This singledispatch function backs ``TransformDispatcher.generic_apply_transform``.
    Its arguments are ordered differently (``obj`` first) so that ``singledispatch`` can
    dispatch on the type of ``obj``.

    This default implementation handles objects that are not a registered dispatch target
    (e.g., not a QNode, tape, etc.). It returns a ``BoundTransform`` that stores ``obj``,
    together with any additional ``targs`` and ``tkwargs``, as transform parameters. This
    enables patterns like::

        decompose(gate_set=gate_set) + merge_rotations(1e-6)

    where transforms are called with just configuration parameters and combined into a
    ``CompilePipeline``.
    """
    # ``obj`` is not a transformable target, so it is treated as the first transform
    # parameter rather than as the thing being transformed.
    bound_args = (obj,) + targs
    return BoundTransform(transform, args=bound_args, kwargs=tkwargs)
# pragma: no cover
def _dummy_register(obj): # just used for sphinx
if isinstance(obj, type): # pragma: no cover
return lambda arg: arg # pragma: no cover
return obj # pragma: no cover
[docs]
class TransformDispatcher: # pylint: disable=too-many-instance-attributes
r"""Converts a transform that has the signature ``(tape -> Sequence(tape), fn)`` to a transform dispatcher
that can act on :class:`pennylane.tape.QuantumTape`, quantum function, :class:`pennylane.QNode`,
:class:`pennylane.devices.Device`.
.. warning::
This class is developer-facing and should not be used directly. Instead, use
:func:`qml.transform <pennylane.transform>` if you would like to make a custom
transform.
.. seealso:: :func:`~.pennylane.transform`
"""
def __new__(cls, *args, **__):
if os.environ.get("SPHINX_BUILD") == "1":
# If called during a Sphinx documentation build,
# simply return the original function rather than
# instantiating the object. This allows the signature to
# be correctly displayed in the documentation.
warnings.warn(
"Transforms have been disabled, as a Sphinx "
"build has been detected via SPHINX_BUILD='1'. If this is not the "
"case, please set the environment variable SPHINX_BUILD='0'.",
UserWarning,
)
args[0].custom_qnode_transform = lambda x: x
args[0].register = _dummy_register
return args[0]
return super().__new__(cls)
# pylint: disable=too-many-arguments,too-many-positional-arguments
def __init__(
self,
transform: Callable | None = None,
pass_name: None | str = None,
*,
expand_transform: Callable | None = None,
classical_cotransform: Callable | None = None,
is_informative: bool = False,
final_transform: bool = False,
use_argnum_in_expand: bool = False,
plxpr_transform=None,
):
if transform is None and pass_name is None:
raise ValueError(
"Transforms must currently define either a tape transform or a pass_name"
)
self._transform = transform
self._expand_transform = expand_transform
self._classical_cotransform = classical_cotransform
self._is_informative = is_informative
# is_informative supersedes final_transform
self._final_transform = is_informative or final_transform
self._custom_qnode_transform = None
self._pass_name = pass_name
self._use_argnum_in_expand = use_argnum_in_expand
if transform:
functools.update_wrapper(self, transform)
self._apply_transform = functools.singledispatch(
functools.partial(specific_apply_transform, self)
)
self._plxpr_transform = plxpr_transform or _create_plxpr_fallback_transform(self._transform)
@property
def pass_name(self) -> None | str:
"""The name of the equivalent MLIR pass."""
return self._pass_name
@property
def register(self):
"""Returns a decorator for registering a specific application behavior for a given transform
and a new class.
.. code-block:: python
@qml.transform
def printer(tape):
print("I have a tape: ", tape)
return (tape, ), lambda x: x[0]
@printer.register
def _(obj: qml.operation.Operator, *targs, **tkwargs):
print("I have an operator:", obj)
return obj
>>> printer(qml.X(0))
I have an operator: X(0)
X(0)
"""
return self._apply_transform.register
[docs]
def generic_apply_transform(self, obj, *targs, **tkwargs):
"""generic_apply_transform(obj, *targs, **tkwargs)
Generic application of a transform that forms the default for all transforms.
Args:
obj: The object we want to transform
*targs: The arguments for the transform
**tkwargs: The keyword arguments for the transform.
"""
return generic_apply_transform(obj, self, *targs, **tkwargs)
[docs]
@staticmethod
def generic_register(arg=None):
"""Returns a decorator for registering a default application behavior for a transform for a new class.
Given a special new class, we can register how transforms should apply to them via:
.. code-block:: python
class Subroutine:
def __repr__(self):
return f"<Subroutine: {self.ops}>"
def __init__(self, ops):
self.ops = ops
from pennylane.transforms.core import TransformDispatcher
@TransformDispatcher.generic_register
def apply_to_subroutine(obj: Subroutine, transform, *targs, **tkwargs):
tape = qml.tape.QuantumScript(obj.ops)
batch, _ = transform(tape, *targs, **tkwargs)
return Subroutine(batch[0].operations)
>>> qml.transforms.cancel_inverses(Subroutine([qml.Y(0), qml.X(0), qml.X(0)]))
<Subroutine: [Y(0)]>
The type can also be explicitly provided like:
.. code-block:: python
@TransformDispatcher.generic_register(Subroutine)
def apply_to_subroutine(obj: Subroutine, transform, *targs, **tkwargs):
tape = qml.tape.QuantumScript(obj.ops)
batch, _ = transform(tape, *targs, **tkwargs)
return Subroutine(batch[0].operations)
to more explicitly force registration for a given type.
"""
return generic_apply_transform.register(arg) # pylint: disable=no-member
def __call__(self, obj=None, *targs, **tkwargs): # pylint: disable=keyword-arg-before-vararg
# If called with only keyword arguments (no positional args), return a BoundTransform
# This enables patterns like: decompose(gate_set=gate_set) + merge_rotations(1e-6)
if obj is None:
if tkwargs:
return BoundTransform(self, args=targs, kwargs=tkwargs)
raise TypeError(
f"{self!r} requires at least one argument. "
"Provide a tape, qfunc, QNode, or device to transform, "
"or provide keyword arguments to create a BoundTransform for composition."
)
return self._apply_transform(obj, *targs, **tkwargs)
def __repr__(self):
name = self._transform.__name__ if self._transform else self.pass_name
return f"<transform: {name}>"
def __add__(self, other):
"""Add two dispatchers or a dispatcher and a container to create a CompilePipeline.
When adding dispatchers, they are converted to containers with no args or kwargs.
For dispatcher + program, Python falls back to CompilePipeline.__radd__.
Args:
other: Another TransformDispatcher or BoundTransform to add.
Returns:
CompilePipeline: A new program with this dispatcher followed by the other.
"""
# Convert this dispatcher to a container (no args/kwargs) and delegate
return BoundTransform(self) + other
def __mul__(self, n):
"""Multiply a dispatcher by an integer to create a program with repeated dispatchers.
Args:
n (int): Number of times to repeat this dispatcher.
Returns:
CompilePipeline: A new program with this dispatcher repeated n times.
"""
# Convert to container (no args/kwargs) and delegate
return BoundTransform(self) * n
__rmul__ = __mul__
@property
def transform(self):
"""The quantum transform."""
return self._transform
@property
def expand_transform(self):
"""The expand transform."""
return self._expand_transform
@property
def classical_cotransform(self):
"""The classical co-transform."""
return self._classical_cotransform
@property
def plxpr_transform(self):
"""Function for transforming plxpr."""
return self._plxpr_transform
@property
def is_informative(self):
"""``True`` if the transform is informative."""
return self._is_informative
@property
def final_transform(self):
"""``True`` if the transformed tapes must be executed."""
return self._final_transform
[docs]
def custom_qnode_transform(self, fn):
"""Register a custom QNode execution wrapper function
for the batch transform.
**Example**
.. code-block:: python3
@transform
def my_transform(tape, *targs, **tkwargs):
...
return tapes, processing_fn
@my_transform.custom_qnode_transform
def my_custom_qnode_wrapper(self, qnode, targs, tkwargs):
new_tkwargs = dict(tkwargs)
new_tkwargs['shots'] = 100
return self.generic_apply_transform(qnode, *targs, **new_tkwargs)
The custom QNode execution wrapper must have arguments
``self`` (the batch transform object), ``qnode`` (the input QNode
to transform and execute), ``targs`` and ``tkwargs`` (the transform
arguments and keyword arguments respectively).
It should return a QNode that accepts the *same* arguments as the
input QNode with the transform applied.
The default :meth:`~.generic_apply_transform` method may be called
if only pre- or post-processing dependent on QNode arguments is required.
"""
# unfortunately, we don't have access to qml.QNode here, or in the places where
# transforms are defining custom qnode transforms, so we still need to have this
# "hold onto until later" approach
# potentially can remove this patch by moving source code
self._custom_qnode_transform = fn
[docs]
def default_qnode_transform(self, qnode, targs, tkwargs):
"""
The default method that takes in a QNode and returns another QNode
with the transform applied.
"""
# same comment as custom_qnode_transform :(
qnode = copy(qnode)
if self.expand_transform:
qnode.transform_program.push_back(
BoundTransform(
TransformDispatcher(self._expand_transform),
args=targs,
kwargs=tkwargs,
use_argnum=self._use_argnum_in_expand,
)
)
qnode.transform_program.push_back(
BoundTransform(
self,
args=targs,
kwargs=tkwargs,
)
)
return qnode
[docs]
class BoundTransform:  # pylint: disable=too-many-instance-attributes
    """A transform with bound inputs.

    Args:
        transform: Any transform. If it is not already a ``TransformDispatcher``, it is
            wrapped in one using ``transform_config``.
        args (Sequence[Any]): The positional arguments to use with the transform.
        kwargs (Dict | None): The keyword arguments for use with the transform.

    Keyword Args:
        use_argnum (bool): An advanced option used in conjunction with calculating
            classical cotransforms of jax workflows.
        **transform_config: Forwarded to ``TransformDispatcher`` when ``transform`` is not
            already a dispatcher.

    Raises:
        ValueError: if ``transform_config`` is provided together with a ``TransformDispatcher``.

    .. seealso:: :func:`~.pennylane.transform`

    >>> bound_t = BoundTransform(qml.transforms.merge_rotations, (), {"atol": 1e-4})
    >>> bound_t
    <merge_rotations((), {'atol': 0.0001})>

    The class can also be created by directly calling the transform with its inputs:

    >>> qml.transforms.merge_rotations(atol=1e-4)
    <merge_rotations((), {'atol': 0.0001})>

    These objects can now be directly applied to anything individual transforms can apply to:

    .. code-block:: python

        @bound_t
        @qml.qnode(qml.device('null.qubit'))
        def c(x):
            qml.RX(x, 0)
            qml.RX(-x + 1e-6, 0)
            qml.RY(x, 1)
            qml.RY(-x + 1e-2, 1)
            return qml.probs(wires=(0,1))

    If we draw this circuit, we can see that the ``merge_rotations`` transform was applied with a
    tolerance of ``1e-4``. The ``RX`` gates sufficiently close to zero disappear, while the ``RY`` gates
    that are further from zero remain.

    >>> print(qml.draw(c)(1.0))
    0: ───────────┤ ╭Probs
    1: ──RY(0.01)─┤ ╰Probs

    Repeated versions of the bound transform can be created with multiplication:

    >>> bound_t * 3
    CompilePipeline(merge_rotations, merge_rotations, merge_rotations)

    And it can be used in conjunction with both individual transforms, bound transforms, and
    compile pipelines.

    >>> bound_t + qml.transforms.cancel_inverses
    CompilePipeline(merge_rotations, cancel_inverses)
    >>> bound_t + qml.transforms.cancel_inverses + bound_t
    CompilePipeline(merge_rotations, cancel_inverses, merge_rotations)

    """

    def __hash__(self):
        # ``__eq__`` compares ``kwargs`` with dict equality, which ignores insertion order.
        # A ``frozenset`` of the items keeps the hash consistent with that: equal objects
        # must hash equally. An insertion-ordered tuple would break this.
        hashable_kwargs = frozenset(self.kwargs.items())
        return hash((self.transform, self.pass_name, self.args, hashable_kwargs))

    def __init__(
        self,
        transform: TransformDispatcher,
        args: tuple | list = (),
        kwargs: None | dict = None,
        *,
        use_argnum: bool = False,
        **transform_config,
    ):
        if not isinstance(transform, TransformDispatcher):
            transform = TransformDispatcher(transform, **transform_config)
        elif transform_config:
            raise ValueError(
                f"transform_config kwargs {transform_config} cannot be passed if a TransformDispatcher is provided."
            )
        self._transform_dispatcher = transform
        self._args = tuple(args)
        self._kwargs = kwargs or {}
        self._use_argnum = use_argnum

    def __repr__(self):
        name = self.transform.__name__ if self.transform else self.pass_name
        return f"<{name}({self._args}, {self._kwargs})>"

    def __call__(self, obj):
        return self._transform_dispatcher(obj, *self.args, **self.kwargs)

    def __iter__(self):
        # Yields, in order: transform, args, kwargs, classical_cotransform,
        # plxpr_transform, is_informative, final_transform.
        return iter(
            (
                self._transform_dispatcher.transform,
                self._args,
                self._kwargs,
                self._transform_dispatcher._classical_cotransform,
                self._transform_dispatcher._plxpr_transform,
                self._transform_dispatcher._is_informative,
                self._transform_dispatcher.final_transform,
            )
        )

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, BoundTransform):
            return False

        return (
            self.args == other.args
            and self.transform == other.transform
            and self.pass_name == other.pass_name
            and self.kwargs == other.kwargs
            and self.classical_cotransform == other.classical_cotransform
            and self.is_informative == other.is_informative
            and self.final_transform == other.final_transform
        )

    @property
    def transform(self) -> Callable | None:
        """The raw tape transform definition for the transform."""
        return self._transform_dispatcher.transform

    @property
    def pass_name(self) -> None | str:
        """The name of the corresponding Catalyst pass, if it exists."""
        return self._transform_dispatcher.pass_name

    @property
    def args(self) -> tuple:
        """The stored quantum transform's ``args``."""
        return self._args

    @property
    def kwargs(self) -> dict:
        """The stored quantum transform's ``kwargs``."""
        return self._kwargs

    @property
    def classical_cotransform(self) -> None | Callable:
        """The stored quantum transform's classical co-transform."""
        return self._transform_dispatcher.classical_cotransform

    @property
    def plxpr_transform(self) -> None | Callable:
        """The stored quantum transform's PLxPR transform.

        **UNMAINTAINED AND EXPERIMENTAL**
        """
        return self._transform_dispatcher.plxpr_transform

    @property
    def is_informative(self) -> bool:
        """Whether or not a transform is informative. If true the transform is queued at the end
        of the transform program and the tapes or qnode aren't executed.

        This property is rare, but used by such transforms as ``qml.transforms.commutation_dag``.
        """
        return self._transform_dispatcher.is_informative

    @property
    def final_transform(self) -> bool:
        """Whether or not the transform must be the last one to be executed
        in a ``CompilePipeline``.

        This property is ``True`` for most gradient transforms.
        """
        return self._transform_dispatcher.final_transform

    def __add__(self, other):
        """Add two containers or a container and a dispatcher to create a CompilePipeline.

        For container + program, Python falls back to CompilePipeline.__radd__.

        Args:
            other: Another BoundTransform or TransformDispatcher to add.

        Returns:
            CompilePipeline: A new program with this container followed by the other.

        Raises:
            TransformError: if both operands are final transforms.
        """
        # Convert dispatcher to container if needed
        if isinstance(other, TransformDispatcher):
            other = BoundTransform(other)

        if isinstance(other, BoundTransform):
            # Import here to avoid circular import
            # pylint: disable=import-outside-toplevel
            from .compile_pipeline import CompilePipeline

            if self.final_transform and other.final_transform:
                raise TransformError(
                    f"Both {self} and {other} are final transforms and cannot be combined."
                )
            return CompilePipeline([self, other])

        # For CompilePipeline, Python falls back to program.__radd__(container)
        return NotImplemented

    def __mul__(self, n):
        """Multiply a container by an integer to create a program with repeated containers.

        Args:
            n (int): Number of times to repeat this container.

        Returns:
            CompilePipeline: A new program with this container repeated n times.

        Raises:
            ValueError: if ``n`` is negative.
            TransformError: if this is a final transform and ``n > 1``.
        """
        # Import here to avoid circular import
        from .compile_pipeline import CompilePipeline  # pylint: disable=import-outside-toplevel

        if not isinstance(n, int):
            return NotImplemented

        if n < 0:
            raise ValueError("Cannot multiply transform container by negative integer")

        if self.final_transform and n > 1:
            raise TransformError(
                f"{self} is a final transform and cannot be applied more than once."
            )

        return CompilePipeline([self] * n)

    __rmul__ = __mul__
@TransformDispatcher.generic_register
def _apply_to_tape(obj: QuantumScript, transform, *targs, **tkwargs):
    """Apply ``transform`` to a single tape.

    If the transform has an expand transform, that is applied first, and the main tape
    transform is then applied to every expanded tape. The returned processing function:

    * routes each expanded tape's contiguous chunk of results through its own
      post-processing function;
    * passes the resulting list to the expansion's post-processing.

    Returns:
        tuple[Sequence[QuantumScript], Callable]: the transformed tapes and the processing
        function. For informative transforms, the result of calling the processing
        function on the transformed tapes is returned instead.

    Raises:
        NotImplementedError: if the transform has no tape transform (e.g. it only defines a
            ``pass_name``).
    """
    if transform.transform is None:
        raise NotImplementedError(f"transform {transform} has no defined tape transform.")

    if transform.expand_transform:
        expanded_tapes, expand_processing = transform.expand_transform(obj, *targs, **tkwargs)

        transformed_tapes = []
        # One (post-processing function, slice of batch results) pair per expanded tape.
        processing_and_slices = []
        start = 0

        for tape in expanded_tapes:
            intermediate_tapes, post_processing_fn = transform.transform(tape, *targs, **tkwargs)
            transformed_tapes.extend(intermediate_tapes)

            end = start + len(intermediate_tapes)
            processing_and_slices.append((post_processing_fn, slice(start, end)))
            start = end

        def processing_fn(results):
            # ``chunk`` rather than ``slice`` so the builtin is not shadowed.
            processed_results = [fn(results[chunk]) for fn, chunk in processing_and_slices]
            return expand_processing(processed_results)

    else:
        transformed_tapes, processing_fn = transform.transform(obj, *targs, **tkwargs)

    if transform.is_informative:
        return processing_fn(transformed_tapes)
    return transformed_tapes, processing_fn
def _capture_apply(obj, transform, *targs, **tkwargs):
    """Wrap ``obj`` so that, with program capture enabled, calling it binds the transform primitive.

    The wrapper traces ``obj`` to a jaxpr and then binds the ``"transform"`` primitive. The
    operands are laid out as ``[*flat_args, *jaxpr.consts, *targs]``, and ``(start, stop)``
    ranges locate each group. The primitive's results are unflattened into the original
    output structure.
    """

    @wraps(obj)
    def qfunc_transformed(*args, **kwargs):
        import jax  # pylint: disable=import-outside-toplevel

        flat_qfunc = capture.flatfn.FlatFn(obj)
        jaxpr = jax.make_jaxpr(flat_qfunc)(*args, **kwargs)
        flat_args = jax.tree_util.tree_leaves(args)

        n_args = len(flat_args)
        n_consts = len(jaxpr.consts)

        # Ranges are passed as (start, stop) tuples, not ``slice`` objects. Consumers (e.g.
        # the primitive's impl) rebuild them with ``slice(*...)``, which cannot unpack a
        # ``slice``. Tuples are also hashable on every supported Python version.
        args_slice = (0, n_args)
        consts_slice = (n_args, n_args + n_consts)
        targs_slice = (n_args + n_consts, None)

        results = _create_transform_primitive().bind(
            *flat_args,
            *jaxpr.consts,
            *targs,
            inner_jaxpr=jaxpr.jaxpr,
            args_slice=args_slice,
            consts_slice=consts_slice,
            targs_slice=targs_slice,
            # Hashable form; consumers restore it with ``dict(tkwargs)``.
            tkwargs=tuple(tkwargs.items()),
            transform=transform,
        )

        assert flat_qfunc.out_tree is not None
        return jax.tree_util.tree_unflatten(flat_qfunc.out_tree, results)

    return qfunc_transformed
[docs]
@TransformDispatcher.generic_register
def apply_to_callable(obj: Callable, transform, *targs, **tkwargs):
    """Apply a transform to a Callable object.

    Returns a wrapper (carrying ``obj``'s metadata via ``functools.wraps``). When program
    capture is enabled, the wrapper delegates to the capture path. Otherwise it does the
    following when called:

    * records ``obj`` into a tape;
    * transforms that tape;
    * re-queues the resulting operations and measurements;
    * returns the measurements shaped like ``obj``'s original output.

    Raises:
        TransformError: immediately if ``obj`` is a ``QJIT`` object. When the wrapper is
            called, if the transform does not produce exactly one tape.
    """
    if obj.__class__.__name__ == "QJIT":
        raise TransformError(
            "Functions that are wrapped / decorated with qjit cannot subsequently be"
            f" transformed with a PennyLane transform (attempted {transform})."
            f" For the desired affect, ensure that qjit is applied after {transform}."
        )

    @functools.wraps(obj)
    def qfunc_transformed(*args, **kwargs):
        if capture.enabled():
            captured_qfunc = _capture_apply(obj, transform, *targs, **tkwargs)
            return captured_qfunc(*args, **kwargs)

        # Operators passed in as arguments are removed from the active queuing context.
        flat_inputs, _ = flatten((args, kwargs), lambda leaf: isinstance(leaf, Operator))
        for candidate in flat_inputs:
            if isinstance(candidate, Operator):
                QueuingManager.remove(candidate)

        with AnnotatedQueue() as recorded:
            qfunc_output = obj(*args, **kwargs)
        tape = QuantumScript.from_queue(recorded)

        with QueuingManager.stop_recording():
            # Informative transforms use the raw tape transform; others go through dispatch.
            runner = transform.transform if transform.is_informative else transform
            transformed_tapes, processing_fn = runner(tape, *targs, **tkwargs)

        if len(transformed_tapes) != 1:
            raise TransformError(
                "Impossible to dispatch your transform on quantum function, because more than "
                "one tape is returned"
            )

        (transformed_tape,) = transformed_tapes

        if transform.is_informative:
            return processing_fn(transformed_tapes)

        for op in transformed_tape.operations:
            apply(op)

        measurements = [apply(mp) for mp in transformed_tape.measurements]

        if not measurements:
            return qfunc_output

        # Shape the re-queued measurements like the original qfunc output.
        if isinstance(qfunc_output, MeasurementProcess):
            return measurements[0] if len(measurements) == 1 else tuple(measurements)

        if isinstance(qfunc_output, (tuple, list)):
            return type(qfunc_output)(measurements)

        return math.asarray(measurements, like=math.get_interface(qfunc_output))

    return qfunc_transformed
@TransformDispatcher.generic_register
def _apply_to_sequence(obj: Sequence, transform, *targs, **tkwargs):
    """Apply ``transform`` to every tape in a sequence and batch the results.

    Returns:
        tuple[tuple[QuantumScript], Callable]: all generated tapes, concatenated in input
        order, and a processing function. The processing function splits a batch of results
        back into one chunk per input tape, applies that tape's post-processing function,
        and returns a tuple with one entry per input tape.

    Raises:
        TransformError: if any element of ``obj`` is not a ``QuantumScript``.
    """
    for t in obj:
        if not isinstance(t, QuantumScript):
            # Report the offending element; ``obj[0]`` may itself be a valid tape.
            raise TransformError(
                f"Transforms can only apply to sequences of QuantumScript, not {type(t)}"
            )

    execution_tapes = []
    batch_fns = []
    tape_counts = []

    for t in obj:
        # Preprocess the tapes by applying transforms
        # to each tape, and storing corresponding tapes
        # for execution, processing functions, and list of tape lengths.
        new_tapes, fn = transform(t, *targs, **tkwargs)
        execution_tapes.extend(new_tapes)
        batch_fns.append(fn)
        tape_counts.append(len(new_tapes))

    def processing_fn(res: ResultBatch) -> ResultBatch:
        """Applies a batch of post-processing functions to results.

        Args:
            res (ResultBatch): the results of executing a batch of circuits.

        Returns:
            ResultBatch: results that have undergone classical post processing.

        Closure variables:
            tape_counts: the number of tapes outputted from each application of the transform.
            batch_fns: the post processing functions to apply to each sub-batch.
        """
        count = 0
        final_results = []

        for f, s in zip(batch_fns, tape_counts):
            # apply any batch transform post-processing
            new_res = f(res[count : count + s])
            final_results.append(new_res)
            count += s

        return tuple(final_results)

    return tuple(execution_tapes), processing_fn
# ``TransformContainer`` is an alias bound to the very same class object as ``BoundTransform``.
# NOTE(review): presumably kept so code using the older "container" name keeps working — confirm.
TransformContainer = BoundTransform
_modules/pennylane/transforms/core/transform_dispatcher
Download Python script
Download Notebook
View on GitHub