# Source code for cirq.sim.wave_function_simulator

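# Copyright 2019 The Cirq Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.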

"""Abstract classes for simulations which keep track of wave functions."""

import abc

from typing import Any, cast, Dict, Iterator, Sequence, TYPE_CHECKING, Tuple

import numpy as np

from cirq import circuits, ops, study, value
from cirq.sim import simulator, wave_function

if TYPE_CHECKING:
import cirq

class SimulatesIntermediateWaveFunction(simulator.SimulatesAmplitudes,
                                        simulator.SimulatesIntermediateState,
                                        metaclass=abc.ABCMeta):
    """A simulator that accesses its wave function as it does its simulation.

    Implementors of this interface should implement the _simulator_iterator
    method."""

    @abc.abstractmethod
    def _simulator_iterator(
            self,
            circuit: circuits.Circuit,
            param_resolver: study.ParamResolver,
            qubit_order: ops.QubitOrderOrList,
            initial_state: np.ndarray,
    ) -> Iterator:
        """Steps through the Moments of a Circuit, one WaveFunctionStepResult each.

        Args:
            circuit: The circuit to simulate.
            param_resolver: A ParamResolver used to assign values to
                Symbols in the circuit.
            qubit_order: Fixes the canonical ordering of the qubits, which is
                commonly what defines the ordering of the computational basis
                states (for example when interpreting the initial state).
            initial_state: The starting state of the simulation. Its accepted
                form is up to the concrete simulator; consult the implementing
                class's documentation.

        Yields:
            A WaveFunctionStepResult for each simulated Moment of the Circuit.
        """
        raise NotImplementedError()

    def _create_simulator_trial_result(
            self,
            params: study.ParamResolver,
            measurements: Dict[str, np.ndarray],
            final_simulator_state: 'WaveFunctionSimulatorState',
    ) -> 'WaveFunctionTrialResult':
        """Packages the outcome of one simulation run as a trial result.

        Args:
            params: The parameter resolver used for the run.
            measurements: Measurement results keyed by measurement key.
            final_simulator_state: The simulator state after the last step.

        Returns:
            A WaveFunctionTrialResult built from the given pieces.
        """
        return WaveFunctionTrialResult(
            params=params,
            measurements=measurements,
            final_simulator_state=final_simulator_state,
        )

    def compute_amplitudes_sweep(
            self,
            program: 'cirq.Circuit',
            bitstrings: Sequence[int],
            params: study.Sweepable,
            qubit_order: ops.QubitOrderOrList = ops.QubitOrder.DEFAULT,
    ) -> Sequence[Sequence[complex]]:
        """Looks up selected amplitudes of the final state for each sweep point.

        The program is simulated once per parameter setting via
        `simulate_sweep`, and each run's `final_state` is indexed by
        `bitstrings`.

        Args:
            program: The circuit to simulate.
            bitstrings: Integer indices into the final state vector.
            params: The parameters to sweep over.
            qubit_order: Ordering of the qubits, forwarded to `simulate_sweep`.

        Returns:
            A list with one entry per trial result; each entry holds the
            amplitudes of that run's final state at the requested indices.

        Raises:
            ValueError: `bitstrings` is a numpy array with more than one
                dimension.
        """
        if isinstance(bitstrings, np.ndarray) and bitstrings.ndim > 1:
            raise ValueError('The list of bitstrings must be input as a '
                             '1-dimensional array of ints. Got an array with '
                             f'shape {bitstrings.shape}.')

        # NumPy reads a tuple index as one coordinate per axis rather than as
        # a list of positions, so switch it to a list to get advanced
        # (fancy) indexing along the single axis.
        # https://docs.scipy.org/doc/numpy/reference/arrays.indexing.html
        if isinstance(bitstrings, tuple):
            bitstrings = list(bitstrings)

        trial_results = self.simulate_sweep(program, params, qubit_order)
        return [
            cast(WaveFunctionTrialResult, result).final_state[bitstrings]
            for result in trial_results
        ]

class WaveFunctionStepResult(simulator.StepResult, metaclass=abc.ABCMeta):
    """A StepResult for simulators that track a wave function.

    Concrete subclasses must implement `_simulator_state`, which reports the
    simulator's state (a WaveFunctionSimulatorState) after this step.
    """

    @abc.abstractmethod
    def _simulator_state(self) -> 'WaveFunctionSimulatorState':
        """Returns the simulator_state of the simulator after this step.

        The exact contents of the simulator_state depend on the implementation
        of the simulation; see the documentation of the implementing class for
        details.
        """
        raise NotImplementedError()

@value.value_equality(unhashable=True)
class WaveFunctionSimulatorState:
    """A state vector paired with the qubit map that describes its layout.

    Attributes:
        state_vector: The amplitudes of the simulated state.
        qubit_map: For each qubit, its index in the state vector's basis
            ordering.
    """

    def __init__(self, state_vector: np.ndarray,
                 qubit_map: Dict[ops.Qid, int]) -> None:
        self.state_vector = state_vector
        self.qubit_map = qubit_map
        # Derived once here so _qid_shape_ can simply hand it back.
        self._qid_shape = simulator._qubit_map_to_shape(qubit_map)

    def _qid_shape_(self) -> Tuple[int, ...]:
        """Returns the qid shape computed from the qubit map at construction."""
        return self._qid_shape

    def __repr__(self) -> str:
        # repr(ndarray) starts with "array(", so prefixing "np." yields a
        # numpy-qualified constructor call.
        vector_text = f'np.{self.state_vector!r}'
        map_text = f'{self.qubit_map!r}'
        return ('cirq.WaveFunctionSimulatorState('
                f'state_vector={vector_text}, '
                f'qubit_map={map_text})')

    def _value_equality_values_(self) -> Any:
        # Convert the ndarray to nested lists so equality compares the values
        # rather than relying on elementwise ndarray comparison.
        vector_as_list = self.state_vector.tolist()
        return vector_as_list, self.qubit_map

@value.value_equality(unhashable=True)
class WaveFunctionTrialResult(wave_function.StateVectorMixin,
                              simulator.SimulationTrialResult):
    """A SimulationTrialResult that includes the StateVectorMixin methods.

    Attributes:
        final_state: The final wave function of the system.
    """

    def __init__(self,
                 params: study.ParamResolver,
                 measurements: Dict[str, np.ndarray],
                 final_simulator_state: WaveFunctionSimulatorState) -> None:
        super().__init__(params=params,
                         measurements=measurements,
                         final_simulator_state=final_simulator_state,
                         qubit_map=final_simulator_state.qubit_map)
        # Direct alias to the final simulator state's vector.
        self.final_state = final_simulator_state.state_vector

    def state_vector(self) -> np.ndarray:
        """Return the wave function at the end of the computation.

        The state is returned in the computational basis with these basis
        states defined by the qubit_map. In particular the value in the
        qubit_map is the index of the qubit, and these are translated into
        binary vectors where the last qubit is the 1s bit of the index, the
        second-to-last is the 2s bit of the index, and so forth (i.e. big
        endian ordering).

        Example:
            qubit_map: {QubitA: 0, QubitB: 1, QubitC: 2}
            Then the returned vector will have indices mapped to qubit basis
            states like the following table

               |     | QubitA | QubitB | QubitC |
               | :-: | :----: | :----: | :----: |
               |  0  |   0    |   0    |   0    |
               |  1  |   0    |   0    |   1    |
               |  2  |   0    |   1    |   0    |
               |  3  |   0    |   1    |   1    |
               |  4  |   1    |   0    |   0    |
               |  5  |   1    |   0    |   1    |
               |  6  |   1    |   1    |   0    |
               |  7  |   1    |   1    |   1    |
        """
        return self._final_simulator_state.state_vector

    def _value_equality_values_(self):
        # Measurement arrays become lists, ordered by key, so equality
        # compares values and is independent of dict insertion order.
        measurements = {k: v.tolist() for k, v in
                        sorted(self.measurements.items())}
        return (self.params, measurements, self._final_simulator_state)

    def __str__(self) -> str:
        """Shows the measurements and the final state.

        The final state is rendered in Dirac notation (3 decimals) when fewer
        than 16 amplitudes exceed 0.001 in magnitude; otherwise the raw
        vector is printed.
        """
        samples = super().__str__()
        final = self.state_vector()
        # Vectorized count of non-negligible amplitudes. This avoids a
        # Python-level loop over every entry of the state vector, which can
        # be very large.
        num_significant = np.count_nonzero(np.abs(final) > 0.001)
        if num_significant < 16:
            wave = self.dirac_notation(3)
        else:
            wave = str(final)
        return f'measurements: {samples}\noutput vector: {wave}'

    def _repr_pretty_(self, p: Any, cycle: bool) -> None:
        """Text output in Jupyter."""
        if cycle:
            # There should never be a cycle.  This is just in case.
            p.text('WaveFunctionTrialResult(...)')
        else:
            p.text(str(self))

    def __repr__(self) -> str:
        return (f'cirq.WaveFunctionTrialResult(params={self.params!r}, '
                f'measurements={self.measurements!r}, '
                f'final_simulator_state={self._final_simulator_state!r})')