Source code for pennylane.templates.subroutines.qram

# Copyright 2018-2025 Xanadu Quantum Technologies Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Bucket-brigade QRAM with explicit bus routing for PennyLane. Every address bit
(``control_wires``) is loaded into a binary router tree that uses three qubits per node
(``dir``, ``portL``, ``portR``).

Address loading is performed **layer-by-layer** by routing a single top **bus** qubit
down to the active node using CSWAPs controlled by already-written upper routers,
depositing each low-order address bit into the node's direction qubit.

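In the data phase, each target qubit in turn is swapped onto the bus, routed down to the
selected leaf, and given the leaf write (a ``PauliZ`` wherever the stored bit is ``1``; the
Hadamards applied to the target before and after turn this phase flip into a bit flip), then
routed back up and restored to its target wire.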
"""
from collections import defaultdict
from dataclasses import dataclass
from typing import Sequence

from pennylane import math
from pennylane.decomposition import (
    add_decomps,
    controlled_resource_rep,
    register_resources,
    resource_rep,
)
from pennylane.operation import Operation
from pennylane.ops import CSWAP, SWAP, Hadamard, PauliZ, adjoint, ctrl
from pennylane.wires import Wires, WiresLike

# pylint: disable=consider-using-generator


# -----------------------------
# Wires Data Structure
# -----------------------------
@dataclass
class _QRAMWires:
    """Wire registers of a bucket-brigade QRAM together with tree look-up helpers.

    Internal tree nodes are addressed by ``(level, prefix)`` and stored in level order, so the
    per-node registers ``dir_wires``, ``portL_wires`` and ``portR_wires`` are all indexed by
    ``_node_index(level, prefix)``.
    """

    # Field order defines the positional constructor signature; keep it stable.
    control_wires: Wires
    target_wires: Wires
    bus_wire: Wires
    dir_wires: Wires
    portL_wires: Wires
    portR_wires: Wires

    # ---------- Tree helpers ----------
    def node_in_wire(self, level: int, prefix: int):
        """Return the wire that feeds node ``(level, prefix)``.

        The root (``level == 0``) is fed by the bus. Any other node is fed by a port of its
        parent ``(level - 1, prefix >> 1)``: the left port for an even ``prefix`` and the right
        port for an odd one.
        """
        if level != 0:
            parent_idx = _node_index(level - 1, prefix >> 1)
            feeding_ports = self.portR_wires if prefix % 2 else self.portL_wires
            return feeding_ports[parent_idx]
        return self.bus_wire[0]

    def router(self, level: int, prefix: int):
        """Return the direction (router) wire of node ``(level, prefix)``."""
        flat = _node_index(level, prefix)
        return self.dir_wires[flat]

    def portL(self, level: int, prefix: int):
        """Return the left-port wire of node ``(level, prefix)``."""
        flat = _node_index(level, prefix)
        return self.portL_wires[flat]

    def portR(self, level: int, prefix: int):
        """Return the right-port wire of node ``(level, prefix)``."""
        flat = _node_index(level, prefix)
        return self.portR_wires[flat]


# -----------------------------
# Utilities
# -----------------------------
def _level_offset(level: int) -> int:
    """Index offset of the first node at a given level (root=0). Offset = 2^level - 1."""
    return (1 << level) - 1


def _node_index(level: int, prefix_value: int) -> int:
    """Return the flat index (level order) of the internal node at `level` with prefix `prefix_value`."""
    return _level_offset(level) + prefix_value


# -----------------------------
# Select-prefix × Bucket-Brigade with explicit bus routing
# -----------------------------
class BBQRAM(Operation):  # pylint: disable=too-many-instance-attributes
    r"""Bucket-brigade QRAM with **explicit bus routing** using 3 qubits per node.

    Bucket-brigade QRAM achieves an :math:`O(\log N)` complexity instead of the typical
    :math:`O(N)`, where :math:`N` is the number of memory cells addressed. It does this by
    reducing the number of nodes that need to be visited in a tree which converts our binary
    address into a unary address at the leaves. In the end, the target wires' state corresponds
    to the data at the desired address.

    For more theoretical details on how this algorithm works, please consult
    `arXiv:0708.1879 <https://arxiv.org/pdf/0708.1879>`__.

    Args:
        bitstrings (Sequence[str]): The classical data as a sequence of bitstrings made of the
            characters ``"0"`` and ``"1"``. The size of the classical data must be
            :math:`2^{\texttt{len(control_wires)}}`.
        control_wires (WiresLike): The register that stores the index for the entry of the
            classical data we want to access.
        target_wires (WiresLike): The register in which the classical data gets loaded. The size
            of this register must equal each bitstring length in ``bitstrings``.
        work_wires (WiresLike): The additional wires required to funnel the desired entry of
            ``bitstrings`` into the target register. The size of the ``work_wires`` register must
            be :math:`1 + 3 ((2^\texttt{len(control_wires)}) - 1)`. More specifically, the
            ``work_wires`` register includes the bus, direction, left port and right port wires in
            that order. Each node in the tree contains one address (direction), one left port and
            one right port wire. The single bus wire is used for address loading and data routing.

    Raises:
        ValueError: if the ``bitstrings`` are not provided, the ``bitstrings`` are of the wrong
            length, a bitstring contains characters other than ``"0"`` and ``"1"``, the
            ``target_wires`` are of the wrong size, or the ``work_wires`` register size is not
            exactly equal to :math:`1 + 3 ((2^\texttt{len(control_wires)}) - 1)`.

    .. seealso:: :class:`~.QROM`, :class:`~.QROMStatePreparation`

    .. note::

        QRAM and QROM, though similar, have different applications and purposes. QRAM is
        intended for read-and-write capabilities, where the stored data can be loaded and
        changed. QROM is designed to only load stored data into a quantum register.

    **Example:**

    Consider the following example, where the classical data is a list of four bitstrings (each
    of length 3):

    .. code-block:: python

        bitstrings = ["010", "111", "110", "000"]
        bitstring_size = 3

    The number of wires needed to store a length-4 array is 2, which means that the
    ``control_wires`` register must contain 2 wires. Additionally, this lets us specify the
    number of work wires needed.

    .. code-block:: python

        num_control_wires = 2  # len(bitstrings) = 4 = 2**2
        num_work_wires = 1 + 3 * ((1 << num_control_wires) - 1)  # 10

    Now, we can define all three registers concretely and demonstrate ``BBQRAM`` in practice.
    In the following circuit, we prepare the state :math:`\vert 2 \rangle = \vert 10 \rangle`
    on the ``control_wires``, which indicates that we would like to access the entry at index 2
    (zero-indexed) of ``bitstrings``, which is ``"110"``. The ``target_wires`` register should
    therefore store this state after ``BBQRAM`` is applied.

    .. code-block:: python

        import pennylane as qml

        reg = qml.registers(
            {
                "control": num_control_wires,
                "target": bitstring_size,
                "work_wires": num_work_wires
            }
        )

        dev = qml.device("default.qubit")

        @qml.qnode(dev)
        def bb_quantum():
            # prepare an address, e.g., |10> (index 2)
            qml.BasisEmbedding(2, wires=reg["control"])

            qml.BBQRAM(
                bitstrings,
                control_wires=reg["control"],
                target_wires=reg["target"],
                work_wires=reg["work_wires"],
            )
            return qml.probs(wires=reg["target"])

    >>> import numpy as np
    >>> print(np.round(bb_quantum()))  # doctest: +SKIP
    [0. 0. 0. 0. 0. 0. 1. 0.]

    Note that ``"110"`` in binary is equal to 6 in decimal, which is the position of the only
    non-zero entry of the probability distribution over the ``target_wires`` register.
    """

    grad_method = None

    resource_keys = {"bitstrings"}

    @property
    def resource_params(self) -> dict:
        """Parameters consumed by the registered resource function (only ``bitstrings``)."""
        return {
            "bitstrings": self.hyperparameters["bitstrings"],
        }

    def __init__(
        self,
        bitstrings: Sequence[str],
        control_wires: WiresLike,
        target_wires: WiresLike,
        work_wires: WiresLike,
        id: str | None = None,
    ):  # pylint: disable=too-many-arguments
        # Materialise once so a one-shot iterable is not exhausted by the checks below.
        bitstrings = [] if bitstrings is None else list(bitstrings)
        if not bitstrings:
            raise ValueError("'bitstrings' cannot be empty.")

        m_set = {len(s) for s in bitstrings}
        if len(m_set) != 1:
            raise ValueError("All bitstrings must have equal length.")
        m = next(iter(m_set))

        # The decomposition writes a PauliZ only where a character equals "1"; any other
        # character would be silently ignored there while the resource count disagrees.
        if any(set(s) - {"0", "1"} for s in bitstrings):
            raise ValueError("Bitstrings may only contain the characters '0' and '1'.")

        control_wires = Wires(control_wires)
        n_k = len(control_wires)
        if (1 << n_k) != len(bitstrings):
            raise ValueError("len(bitstrings) must be 2^(len(control_wires)).")

        target_wires = Wires(target_wires)
        if m != len(target_wires):
            raise ValueError("len(target_wires) must equal bitstring length.")

        # Normalise before measuring: a single wire label (e.g. an int) has no len().
        work_wires = Wires(work_wires)
        num_nodes = (1 << n_k) - 1  # internal nodes of the router tree (0 when n_k == 0)
        if len(work_wires) != 1 + 3 * num_nodes:
            raise ValueError(f"work_wires must have length {1 + 3 * num_nodes}.")

        # Layout of work_wires: [bus | dir wires | left-port wires | right-port wires].
        bus_wire = Wires(work_wires[0])
        dir_wires = Wires(work_wires[1 : 1 + num_nodes])
        portL_wires = Wires(work_wires[1 + num_nodes : 1 + 2 * num_nodes])
        portR_wires = Wires(work_wires[1 + 2 * num_nodes : 1 + 3 * num_nodes])

        all_wires = (
            list(control_wires)
            + list(target_wires)
            + list(bus_wire)
            + list(dir_wires)
            + list(portL_wires)
            + list(portR_wires)
        )

        wire_manager = _QRAMWires(
            control_wires, target_wires, bus_wire, dir_wires, portL_wires, portR_wires
        )

        self._hyperparameters = {
            "wire_manager": wire_manager,
            "bitstrings": bitstrings,
        }

        super().__init__(wires=all_wires, id=id)

    @classmethod
    def _primitive_bind_call(cls, *args, **kwargs):
        return cls._primitive.bind(*args, **kwargs)
def _bucket_brigade_qram_resources(bitstrings):
    """Return the gate counts emitted by ``_bucket_brigade_qram_decomposition``.

    With ``n_k`` address bits and ``m`` target bits:

    * ``SWAP``: ``2 * (2**n_k - 1 + n_k)`` for address loading and unloading, plus ``2 * m`` for
      moving each target bit on and off the bus.
    * ``CSWAP`` and zero-controlled ``SWAP`` (one of each per routed node):
      ``2 * m * (2**n_k - 1)`` for routing every target bit down and back up, plus
      ``2 * (2**n_k - 1 - n_k)`` for routing the bus during address loading and unloading.
    * ``Hadamard``: ``2 * m``.
    * ``PauliZ``: one per ``"1"`` character in ``bitstrings``.
    """
    num_target_wires = len(bitstrings[0])
    n_k = int(math.log2(len(bitstrings)))
    num_nodes = (1 << n_k) - 1
    # Routing to level k touches 2**k - 1 nodes; summed over k = 0..n_k-1 this is 2**n_k - 1 - n_k.
    address_routing = num_nodes - n_k
    data_routing = num_nodes * num_target_wires

    resources = defaultdict(int)
    resources[resource_rep(SWAP)] = (num_nodes + n_k) * 2 + num_target_wires * 2
    resources[resource_rep(CSWAP)] = data_routing * 2 + address_routing * 2
    resources[
        controlled_resource_rep(
            base_class=SWAP, base_params={}, num_control_wires=1, num_zero_control_values=1
        )
    ] = (data_routing * 2 + address_routing * 2)
    resources[resource_rep(Hadamard)] += num_target_wires * 2
    # The leaf write applies PauliZ exactly where the stored character equals "1".
    resources[resource_rep(PauliZ)] += sum(s.count("1") for s in bitstrings)
    return resources


def _mark_routers_via_bus(wire_manager, n_k):
    """Write the address bits into the router (direction) wires **layer-by-layer** via the bus.

    For each address bit a_k (k = 0..n_k-1):
      1) SWAP(control_wires[k], bus) loads a_k onto the bus.
      2) The bus is routed down k levels (CSWAPs controlled by routers at levels < k).
      3) Every level-k node swaps its input wire into its direction wire; the input wire on the
         active path is the one carrying a_k.

    When ``n_k == 0`` there is no tree and no operation is applied.
    """
    bus = wire_manager.bus_wire[0]
    for k in range(n_k):
        # 1) load a_k into the bus
        SWAP(wires=[wire_manager.control_wires[k], bus])
        # 2) route down k levels (nothing to do for the root, k == 0)
        _route_bus_down_first_k_levels(wire_manager, k)
        # 3) deposit at the level-k node(s); for k == 0 the input wire is the bus itself
        for p in range(1 << k):
            SWAP(wires=[wire_manager.node_in_wire(k, p), wire_manager.router(k, p)])


def _route_bus_down_first_k_levels(wire_manager, k_levels):
    """Route the bus down the first ``k_levels`` of the tree using dir-controlled swaps.

    At every node of every level, the node's input wire is swapped into its right port when the
    direction wire is 1 and into its left port when the direction wire is 0.
    """
    for ell in range(k_levels):
        for p in range(1 << ell):
            in_w = wire_manager.node_in_wire(ell, p)
            L = wire_manager.portL(ell, p)
            R = wire_manager.portR(ell, p)
            d = wire_manager.router(ell, p)
            # dir == 1 => SWAP(in, R)
            CSWAP(wires=[d, in_w, R])
            # dir == 0 => SWAP(in, L)
            ctrl(SWAP(wires=[in_w, L]), control=[d], control_values=[0])


def _leaf_ops_for_bit(wire_manager, bitstrings, n_k, j):
    """Apply the leaf write for target bit index ``j`` and return the queued operations.

    The leaf of address ``p`` is the input wire of the (virtual) level-``n_k`` node ``p``: the
    left/right port of its parent router, or the bus itself when ``n_k == 0``. A ``PauliZ`` is
    applied on that wire whenever ``bitstrings[p][j] == "1"``.
    """
    ops = []
    for p in range(1 << n_k):
        if bitstrings[p][j] == "1":
            ops.append(PauliZ(wires=wire_manager.node_in_wire(n_k, p)))
    return ops


@register_resources(_bucket_brigade_qram_resources)
def _bucket_brigade_qram_decomposition(
    wires, wire_manager, bitstrings
):  # pylint: disable=unused-argument
    """Decompose ``BBQRAM``: load the address, query each target bit, then unload the address."""
    bus = wire_manager.bus_wire[0]
    n_k = len(wire_manager.control_wires)

    # 1) address loading
    _mark_routers_via_bus(wire_manager, n_k)

    # 2) for each target bit: load -> route down -> leaf op -> route up -> restore
    for j, tw in enumerate(wire_manager.target_wires):
        Hadamard(wires=[tw])
        SWAP(wires=[tw, bus])
        _route_bus_down_first_k_levels(wire_manager, n_k)
        _leaf_ops_for_bit(wire_manager, bitstrings, n_k, j)
        adjoint(_route_bus_down_first_k_levels, lazy=False)(wire_manager, n_k)
        SWAP(wires=[tw, bus])
        Hadamard(wires=[tw])

    # 3) address unloading
    adjoint(_mark_routers_via_bus, lazy=False)(wire_manager, n_k)


add_decomps(BBQRAM, _bucket_brigade_qram_decomposition)