Search

Search Algorithms Written in Qwerty

Grover’s Algorithm

Algorithm (grover.py):
import math
from qwerty import *

def grover(oracle, n_iter, n_shots):
  """Run Grover search and return the observed bitstrings accepted by oracle.

  Args:
    oracle: classical Qwerty function marking the items being searched
      for (annotated as cfunc[N,1] in grover_iter below).
    n_iter: number of Grover iterations to apply, e.g. from get_n_iter.
    n_shots: number of times the kernel is executed and measured.

  Returns:
    The set of distinct measured bitstrings r for which oracle(r) is
    true. Outcomes rejected by the oracle are dropped, so the set is
    empty if no shot landed on a marked item.
  """
  # One Grover iteration: flip the phase of marked states via the oracle,
  # then apply the diffuser, which is the (negated) basis translation that
  # sends '+'[N] to -'+'[N], i.e. a reflection about the uniform
  # superposition.
  @qpu[N](oracle)
  def grover_iter(oracle: cfunc[N,1],
                  q: qubit[N]) \
                 -> qubit[N]:
    return \
      q | oracle.phase \
        | -('+'[N] >> -'+'[N])

  # Kernel: start in '+'[N], apply grover_iter I times, then measure in
  # the standard basis. I is a compile-time dimension supplied below via
  # kernel[[n_iter]].
  @qpu[N,I](grover_iter)
  def kernel(grover_iter: qfunc[N]) \
            -> bit[N]:
    return \
      '+'[N] | (grover_iter
                for _ in range(I)) \
             | std[N].measure

  kern_inst = kernel[[n_iter]]
  results = kern_inst(shots=n_shots)
  # Deduplicate first so the classical oracle runs once per distinct
  # outcome rather than once per shot.
  return {r for r in set(results)
            if oracle(r)}

def get_n_iter(n_qubits, n_answers):
  """Return the Grover iteration count that maximizes success probability.

  Uses the closest-integer rule R = CI(arccos(sqrt(M/N)) / theta), where
  N = 2**n_qubits is the size of the search space, M = n_answers is the
  number of marked items, and theta = 2*arccos(sqrt((N-M)/N)) is the
  rotation angle of a single Grover iteration.

  Args:
    n_qubits: width of the search register (N = 2**n_qubits).
    n_answers: number of marked items M; must satisfy 0 < M <= N.

  Returns:
    A non-negative int (0 when every item is marked).

  Raises:
    ValueError: if n_answers is outside 1..2**n_qubits. Without this
      check M == 0 would surface as a ZeroDivisionError (theta == 0) and
      M > N or M < 0 as an opaque "math domain error".
  """
  n = 2**n_qubits
  m = n_answers
  if not 0 < m <= n:
    raise ValueError(
        f'n_answers must be between 1 and {n} (2**n_qubits), got {m}')
  theta = 2*math.acos(
            math.sqrt((n-m)/n))
  ratio = math.acos(math.sqrt(m/n))/theta
  # Closest integer with exact .5 ties rounded down; round() would use
  # banker's rounding instead.
  return math.ceil(ratio - 0.5)
Driver (grover_driver.py):
from qwerty import *
from grover import grover, get_n_iter
import sys

# Classical oracle: accepts only the all-ones bitstring (AND of all bits).
@classical[N]
def all_ones(x: bit[N]) -> bit:
  return x.and_reduce()
n_ans = 1  # all_ones accepts exactly one input (11...1)

# Usage: python grover_driver.py <n_qubits>
n_qubits = int(sys.argv[1])
# Instantiate the oracle at the requested width.
oracle = all_ones[[n_qubits]]
n_iter = get_n_iter(n_qubits, n_ans)
answers = grover(oracle, n_iter,
                 n_shots=32)
# Print every distinct accepted bitstring that was observed.
for answer in answers:
  print(answer)

Fixed-Point Amplitude Amplification

Algorithm (fix_pt_amp.py):
from qwerty import *
from fix_pt_phases import get_phases

def fix_pt_amp(a, oracle, orig_prob,
               new_prob=0.98,
               n_shots=2048,
               histogram=False):
  """Run fixed-point amplitude amplification and measure the result.

  Args:
    a: reversible state preparation on N qubits (annotated rev_qfunc[N]
      in amp_iter); the state it prepares from '0'[N] is amplified.
    oracle: classical function marking good states (cfunc[N,1]).
    orig_prob: probability of measuring a good state after a, passed to
      get_phases. NOTE(review): presumably only a lower bound is needed
      (callers pass underestimates) - confirm against get_phases.
    new_prob: target success probability, passed to get_phases.
    n_shots: number of kernel executions.
    histogram: forwarded to the kernel call (the driver passes True and
      hands the result to print_histogram).

  Returns:
    The kernel's measurement results for the first N qubits, in the form
    selected by histogram.
  """
  # Phase angles for the fixed-point sequence. amp_iter consumes them in
  # pairs (phis[2*K], phis[2*K+1]), so len(phis) == 2*D iterations.
  phis = get_phases(orig_prob,
                    new_prob)

  # Iteration K of D on N data qubits plus one ancilla (the last qubit):
  #   1. XOR the oracle into the ancilla, apply std.rotate(phis[2*K]) to
  #      the ancilla, and XOR the oracle again to uncompute it;
  #   2. undo a, apply std.rotate(phis[2*K+1]) to the ancilla between two
  #      ancilla flips controlled on the data register being '0'[N], then
  #      reapply a.
  @qpu[N,K,D](phis, a, oracle)
  def amp_iter(phis: angle[2*D],
               a: rev_qfunc[N],
               oracle: cfunc[N,1],
               q: qubit[N+1]) \
              -> qubit[N+1]:
    return \
      q | oracle.xor_embed \
        | id[N] + \
          std.rotate(phis[[2*K]]) \
        | oracle.xor_embed \
        | ~a + id \
        | '0'[N] & std.flip \
        | id[N] + \
          std.rotate(phis[[2*K+1]]) \
        | '0'[N] & std.flip \
        | a + id

  # Kernel: start from '0'[N+1], apply a to the data qubits, run
  # amp_iter[[0]] .. amp_iter[[D-1]], then measure the N data qubits and
  # discard the ancilla.
  @qpu[N,D](phis, a, amp_iter)
  def kernel(
      phis: angle[2*D], a: qfunc[N],
      amp_iter: qfunc[N+1][[...]]) \
            -> bit[N]:
    return '0'[N+1] \
           | a + id \
           | (amp_iter[[k]]
              for k in range(D)) \
           | std[N].measure + discard

  return kernel(shots=n_shots,
                histogram=histogram)
Driver (fix_pt_amp_driver.py):
from qwerty import *
from fix_pt_amp import fix_pt_amp
import sys

# State preparation for the amplitude-amplification example: '+'[N].prep,
# a uniform superposition when applied to '0'[N]. It is @reversible
# because fix_pt_amp also applies its inverse (~a).
@qpu[N]
@reversible
def a(q: qubit[N]) -> qubit[N]:
    return q | '+'[N].prep

# Classical oracle. Python precedence makes this ~(x[:N-1].and_reduce()):
# it accepts every x whose first N-1 bits are NOT all 1.
# NOTE(review): if "first N-1 bits all 0" was intended, write
# (~x[:N-1]).and_reduce() (compare (~pat).and_reduce() in match.py).
# Both readings accept 00 and 01 when N == 2, which matches the sample
# output below.
@classical[N]
def oracle_(x: bit[N]) -> bit:
    return ~x[:N-1].and_reduce()

# Usage: python fix_pt_amp_driver.py <n_qubits>
n_qubits = int(sys.argv[1])
oracle = oracle_[[n_qubits]]
# Probability of any single bitstring under the uniform state prepared by
# a. oracle_ accepts more than one string, so this is presumably meant as
# a lower bound on the true success probability.
orig_prob = 1/2**n_qubits
res = fix_pt_amp(a, oracle,
                 orig_prob, 0.98,
                 histogram=True)
print_histogram(res)
# Prints (for n_qubits = 2):
# 00 -> 48.93%
# 01 -> 49.46%
# 10 -> 0.44%
# 11 -> 1.17%

Niroula–Nam String Matching

Algorithm (match.py):
import math
from qwerty import *
from fix_pt_amp import fix_pt_amp

def match(string, pat):
  """Return the starting indices at which ``pat`` occurs in ``string``.

  Niroula-Nam style matching on top of fix_pt_amp. A k-qubit offset
  register is put in uniform superposition ('+'[K].prep), and string and
  pat are loaded as basis states. shift_and_cmp then XORs the first M bits
  of string.rotl(off) into the pattern register. The oracle accepts
  outcomes whose pattern register is all zeros, i.e. offsets where the
  rotated string agrees with pat.

  Args:
    string: text to search, a Qwerty ``bit`` value of length n.
    pat: pattern to find, a Qwerty ``bit`` value of length m.

  Returns:
    A set of int offsets ``off`` with ``off + m <= n`` whose measured
    outcome was accepted by the oracle. Only outcomes actually observed
    in the sampled shots are reported. An empty set is returned when
    the pattern is longer than the string.
  """
  n, m = len(string), len(pat)
  # No offset can fit a pattern longer than the string.
  if m > n:
    return set()
  # Enough offset qubits to address every position 0..n-1.
  k = math.ceil(math.log2(n))

  # (off, string, pat) -> (off, string, string.rotl(off)[:M] ^ pat)
  @classical[K(k),N(n),M(m)]
  def shift_and_cmp(off: bit[K],
                    string: bit[N],
                    pat: bit[M]) \
                   -> bit[K+N+M]:
    return off, \
           string, \
           string.rotl(off)[:M] ^ pat

  # State preparation for fix_pt_amp: superpose all offsets, load string
  # and pat, then apply shift_and_cmp in place. It is passed as its own
  # inverse, since it only XORs a function of (off, string) into pat.
  @qpu[K(k),N,M](string, pat,
                 shift_and_cmp)
  @reversible
  def a(string: bit[N], pat: bit[M],
        shift_and_cmp: cfunc[K+N+M],
        q: qubit[K+N+M]) \
       -> qubit[K+N+M]:
    return \
      q | '+'[K].prep + string.prep \
                      + pat.prep \
        | shift_and_cmp \
          .inplace(shift_and_cmp)

  # Accepts outcomes whose comparison register is all zeros.
  @classical[K(k),N(n),M(m)]
  def oracle(off: bit[K],
             string: bit[N],
             pat: bit[M]) -> bit:
    return (~pat).and_reduce()

  # The offset register is uniform over 2**k >= n offsets, so a single
  # matching offset is measured with probability 1/2**k. Fixed-point
  # amplification is given a lower bound on that probability; 1/n would
  # overestimate it whenever n is not a power of two.
  ret = fix_pt_amp(a, oracle, 1/2**k)
  # The first k bits of each outcome hold the offset register.
  hits = {int(result[:k])
          for result in set(ret)
          if oracle(result)}
  # rotl is cyclic: an offset with off + m > n compares pat against bits
  # that wrap past the end of the string, and offsets >= n name no
  # position at all. Neither is a substring match.
  return {off for off in hits if off + m <= n}
Driver (match_driver.py):
from qwerty import *
from match import match
import sys

# Usage: python match_driver.py <string> <pattern>
# Both arguments are bitstrings, e.g. 1010 and 10.
string = bit.from_str(sys.argv[1])
pat = bit.from_str(sys.argv[2])
print('Matching indices:')
for index in match(string, pat):
  print(index)
# Output for inputs 1010 and 10: 0, 2