Source code for cirq.study.flatten_expressions

# Copyright 2019 The Cirq Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Resolves symbolic expressions to unique symbols."""

from typing import overload, Any, Callable, List, Optional, Tuple, Union

import sympy

from cirq import protocols
from cirq.study import resolver, sweeps, sweepable


def flatten(val: Any) -> Tuple[Any, 'ExpressionMap']:
    """Returns a flattened copy of `val` plus the expression-to-symbol map.

    `val` can be a `Circuit`, `Gate`, `Operation`, or other type.  Every
    parameter of `val` is handled as follows:
    - A number is left unchanged.
    - A symbol is left unchanged.
    - An expression is replaced with a single new symbol.

    The replacement for `sympy.Symbol('x') + 1` is `sympy.Symbol('<x + 1>')`.
    Should a different expression happen to produce the same string
    `'x + 1'`, a counter is appended to keep the names distinct:
    `sympy.Symbol('<x + 1>_1')`.

    Alongside the flattened value, a dictionary from the expressions and
    symbols found in `val` to the symbols used in the flattened copy is
    returned, e.g.
    `cirq.ExpressionMap({sympy.Symbol('x')+1: sympy.Symbol('<x + 1>')})`.
    That `ExpressionMap` can convert a sweep over the symbols of `val` (such
    as `sympy.Symbol('x')`) into a sweep over the flattened symbols (such as
    `sympy.Symbol('<x + 1>')`).

    Args:
        val: The value to copy and substitute parameter expressions with
            flattened symbols.

    Returns:
        The tuple (new value, expression map) where new value and expression
        map are described above.

    Examples:
        >>> qubit = cirq.LineQubit(0)
        >>> a = sympy.Symbol('a')
        >>> circuit = cirq.Circuit(
        ...     cirq.X(qubit) ** (a/4),
        ...     cirq.Y(qubit) ** (1-a/2),
        ... )
        >>> print(circuit)
        0: ───X^(a/4)───Y^(1 - a/2)───

        >>> sweep = cirq.Linspace(a, start=0, stop=3, length=4)
        >>> print(cirq.ListSweep(sweep))
        Sweep:
        {'a': 0.0}
        {'a': 1.0}
        {'a': 2.0}
        {'a': 3.0}

        >>> c_flat, expr_map = cirq.flatten(circuit)
        >>> print(c_flat)
        0: ───X^(<a/4>)───Y^(<1 - a/2>)───
        >>> expr_map
        cirq.ExpressionMap({a/4: <a/4>, 1 - a/2: <1 - a/2>})

        >>> new_sweep = expr_map.transform_sweep(sweep)
        >>> print(new_sweep)
        Sweep:
        {'<a/4>': 0.0, '<1 - a/2>': 1.0}
        {'<a/4>': 0.25, '<1 - a/2>': 0.5}
        {'<a/4>': 0.5, '<1 - a/2>': 0.0}
        {'<a/4>': 0.75, '<1 - a/2>': -0.5}

        >>> for params in sweep:  # Original
        ...     print(circuit,
        ...           '=>',
        ...           cirq.resolve_parameters(circuit, params))
        0: ───X^(a/4)───Y^(1 - a/2)─── => 0: ───X^0───Y───
        0: ───X^(a/4)───Y^(1 - a/2)─── => 0: ───X^0.25───Y^0.5───
        0: ───X^(a/4)───Y^(1 - a/2)─── => 0: ───X^0.5───Y^0───
        0: ───X^(a/4)───Y^(1 - a/2)─── => 0: ───X^0.75───Y^-0.5───

        >>> for params in new_sweep:  # Flattened
        ...     print(c_flat, '=>', end=' ')
        ...     print(cirq.resolve_parameters(c_flat, params))
        0: ───X^(<a/4>)───Y^(<1 - a/2>)─── => 0: ───X^0───Y───
        0: ───X^(<a/4>)───Y^(<1 - a/2>)─── => 0: ───X^0.25───Y^0.5───
        0: ───X^(<a/4>)───Y^(<1 - a/2>)─── => 0: ───X^0.5───Y^0───
        0: ───X^(<a/4>)───Y^(<1 - a/2>)─── => 0: ───X^0.75───Y^-0.5───
    """
    param_flattener = _ParamFlattener()
    # Flattening must run first: it is what populates `param_dict`, which is
    # then copied into the returned ExpressionMap.
    flattened = param_flattener.flatten(val)
    return flattened, ExpressionMap(param_flattener.param_dict)
def flatten_with_sweep(val: Any,
                       sweep: Union[sweeps.Sweep, List[resolver.ParamResolver]]
                      ) -> Tuple[Any, sweeps.Sweep]:
    """Flattens `val` and converts `sweep` to match the flattened value.

    `val` can be a `Circuit`, `Gate`, `Operation`, or other type.  A copy of
    `val` is made in which symbols and expressions are replaced with new
    symbols, and the given sweep over the symbols of `val` is transformed
    into a sweep over those new symbols.

    Every parameter of `val` is handled as follows:
    - A number is left unchanged.
    - A symbol is left unchanged, and the new sweep uses the same symbol with
      the same values.
    - An expression is replaced with a symbol, and the new sweep assigns that
      symbol the evaluated value of the expression.

    The replacement for `sympy.Symbol('x') + 1` is `sympy.Symbol('<x + 1>')`.
    Should a different expression happen to produce the same string
    `'x + 1'`, a counter is appended to keep the names distinct:
    `sympy.Symbol('<x + 1>_1')`.

    Args:
        val: The value to copy and substitute parameter expressions with
            flattened symbols.
        sweep: A sweep over parameters used by `val`.

    Returns:
        The tuple (new value, new sweep) where new value is `val` with
        flattened expressions and new sweep is the equivalent sweep over it.
    """
    flattened, expression_map = flatten(val)
    return flattened, expression_map.transform_sweep(sweep)
def flatten_with_params(val: Any, params: resolver.ParamResolverOrSimilarType
                       ) -> Tuple[Any, resolver.ParamDictType]:
    """Flattens `val` and converts `params` to match the flattened value.

    `val` can be a `Circuit`, `Gate`, `Operation`, or other type.  A copy of
    `val` is made in which symbols and expressions are replaced with new
    symbols, and the given symbol values for `val` are transformed into an
    equivalent dictionary keyed by those new symbols.

    Every parameter of `val` is handled as follows:
    - A number is left unchanged.
    - A symbol is left unchanged, and the new dictionary uses the same symbol
      with the same value.
    - An expression is replaced with a symbol, and the new dictionary assigns
      that symbol the evaluated value of the expression.

    The replacement for `sympy.Symbol('x') + 1` is `sympy.Symbol('<x + 1>')`.
    Should a different expression happen to produce the same string
    `'x + 1'`, a counter is appended to keep the names distinct:
    `sympy.Symbol('<x + 1>_1')`.

    Args:
        val: The value to copy and substitute parameter expressions with
            flattened symbols.
        params: A dictionary or `ParamResolver` where the keys are
            `sympy.Symbol`s used by `val` and the values are numbers.

    Returns:
        The tuple (new value, new params) where new value is `val` with
        flattened expressions and new params is a dictionary mapping the new
        symbols like `sympy.Symbol('<x + 1>')` to numbers like
        `params['x'] + 1`.
    """
    flattened, expression_map = flatten(val)
    return flattened, expression_map.transform_params(params)
class _ParamFlattener(resolver.ParamResolver):
    """A `ParamResolver` that resolves sympy expressions to unique symbols.

    Instances are mutable: whenever one is used to resolve parameters (with
    `cirq.resolve_parameters` or `_ParamFlattener.flatten`) every expression
    or symbol it has not seen before is assigned a symbol and the pair is
    recorded in `param_dict`.  This makes it suitable for replacing sympy
    expressions in circuits with single symbols and for transforming
    parameter sweeps to match.
    """

    def __new__(cls, *args, **kwargs):
        """Disables the behavior of `ParamResolver.__new__`.

        The constructor arguments are deliberately not forwarded to the
        parent `__new__`.
        """
        return super().__new__(cls)

    def __init__(
            self,
            param_dict: Optional[resolver.ParamResolverOrSimilarType] = None,
            *,  # Force keyword args
            get_param_name: Callable[[
                sympy.Basic,
            ], str] = None):
        """Initializes a new _ParamFlattener.

        Args:
            param_dict: A default initial mapping from some parameter names,
                symbols, or expressions to other symbols or values.  Only
                sympy expressions and symbols not specified in `param_dict`
                will be flattened.
            get_param_name: A callback function that returns a new parameter
                name for a given sympy expression or symbol.  If this function
                returns the same value for two different expressions, `'_#'`
                is appended to the name to avoid name collision where `#` is
                the number of previous collisions.  By default, returns the
                expression string surrounded by angle brackets e.g. `'<x+1>'`.
        """
        if hasattr(self, '_taken_symbols'):
            # Already initialized; leave the existing state untouched.
            return

        if isinstance(param_dict, resolver.ParamResolver):
            initial = param_dict.param_dict
        elif param_dict:
            initial = param_dict
        else:
            initial = {}

        # Names given as strings are stored as sympy symbols, for both keys
        # and values.
        super().__init__({
            _ensure_not_str(key): _ensure_not_str(target)
            for key, target in initial.items()
        })

        self.get_param_name = (self.default_get_param_name
                               if get_param_name is None else get_param_name)
        # Targets already present in the initial mapping count as taken so
        # that newly generated symbols never reuse them.
        self._taken_symbols = set(self.param_dict.values())

    @staticmethod
    def default_get_param_name(val: sympy.Basic) -> str:
        """Returns a symbol's own name, or `'<expr>'` for any other value."""
        if isinstance(val, sympy.Symbol):
            return val.name
        return f'<{val!s}>'

    def _next_symbol(self, val: sympy.Basic) -> sympy.Symbol:
        """Returns a symbol for `val` that is not in `_taken_symbols`.

        The name comes from `get_param_name`; on a collision, `_1`, `_2`, ...
        is appended until an unused symbol is found.  The returned symbol is
        not recorded here.
        """
        base_name = self.get_param_name(val)
        candidate = sympy.Symbol(base_name)
        attempt = 0
        while candidate in self._taken_symbols:
            attempt += 1
            candidate = sympy.Symbol(f'{base_name}_{attempt}')
        return candidate

    def value_of(self, value: Union[sympy.Basic, float, str]
                ) -> Union[sympy.Basic, float]:
        """Resolves a symbol or expression to a new symbol unique to that value.

        - If value is an int or float, returns it.
        - If value is a str, treat it as a symbol with that name and continue.
        - Otherwise return a symbol unique to the given value.  Return
          `param_dict[value]` if it exists or create a new symbol and add it
          to `param_dict`.

        Args:
            value: The sympy.Symbol, sympy expression, name, or float to
                resolve to a unique symbol or float.

        Returns:
            The unique symbol or value of the parameter as resolved by this
            resolver.
        """
        if isinstance(value, (int, float)):
            return value

        key = sympy.Symbol(value) if isinstance(value, str) else value

        known = self.param_dict.get(key)
        if known is not None:
            return known

        # First time this key is seen: allocate a symbol and remember it.
        fresh = self._next_symbol(key)
        self.param_dict[key] = fresh
        self._taken_symbols.add(fresh)
        return fresh

    # Default object truth, equality, and hash
    __eq__ = object.__eq__
    __ne__ = object.__ne__
    __hash__ = object.__hash__

    def __bool__(self) -> bool:
        return True

    def __repr__(self) -> str:
        # The callback is only shown when it differs from the default.
        if self.get_param_name == self.default_get_param_name:
            return f'_ParamFlattener({self.param_dict!r})'
        return (f'_ParamFlattener({self.param_dict!r}, '
                f'get_param_name={self.get_param_name!r})')

    def flatten(self, val: Any) -> Any:
        """Returns a copy of `val` with any symbols or expressions replaced
        with new symbols.  `val` can be a `Circuit`, `Gate`, `Operation`, or
        other type.

        This method mutates the `_ParamFlattener` by storing any new mappings
        from expression to symbol that it uses on `val`.

        Args:
            val: The value to copy with substituted parameters.
        """
        return protocols.resolve_parameters(val, self)
class ExpressionMap(dict):
    """A dictionary with sympy expressions and symbols for keys and sympy
    symbols for values.

    This is returned by `cirq.flatten`.  See `ExpressionMap.transform_sweep`
    and `ExpressionMap.transform_params`.
    """

    def __init__(self, *args, **kwargs):
        """Initializes the `ExpressionMap`.

        Takes the same arguments as the builtin `dict`.

        Keys must be sympy expressions or symbols (instances of
        `sympy.Basic`).
        """
        super().__init__(*args, **kwargs)

    def transform_sweep(self,
                        sweep: Union[sweeps.Sweep, List[resolver.ParamResolver]]
                       ) -> sweeps.Sweep:
        """Returns a sweep to use with a circuit flattened earlier with
        `cirq.flatten`.

        If `sweep` sweeps symbol `a` over (1.0, 2.0, 3.0) and this
        `ExpressionMap` maps `a/2+1` to the symbol `'<a/2 + 1>'` then this
        method returns a sweep that sweeps symbol `'<a/2 + 1>'` over
        (1.5, 2, 2.5).

        Only entries whose value is a `sympy.Symbol` or a `str` contribute to
        the result; the output keys are the string names of those values.

        See `cirq.flatten` for an example.

        Args:
            sweep: The sweep to transform.

        Returns:
            A `ListSweep` with one parameter dictionary per point of `sweep`.
        """
        sweep = sweepable.to_sweep(sweep)
        param_list = [{
            str(sym): protocols.resolve_parameters(formula, point)
            for formula, sym in self.items()
            if isinstance(sym, (sympy.Symbol, str))
        } for point in sweep]
        return sweeps.ListSweep(param_list)

    def transform_params(self, params: resolver.ParamResolverOrSimilarType
                        ) -> resolver.ParamDictType:
        """Returns a parameter dictionary to use with a circuit flattened
        earlier with `cirq.flatten`.

        If `params` maps symbol `a` to 3.0 and this `ExpressionMap` maps
        `a/2+1` to `'<a/2 + 1>'` then this method returns a dictionary that
        maps symbol `'<a/2 + 1>'` to 2.5.

        Only entries whose value is a `sympy.Symbol` or a `str` contribute to
        the result, which is the same rule `transform_sweep` applies.  The
        output keys are those values themselves.

        See `cirq.flatten` for an example.

        Args:
            params: The params to transform.

        Returns:
            A dictionary from each flattened symbol to the value of its
            original expression resolved with `params`.
        """
        # Filter on (Symbol, str) rather than any sympy.Basic: a string-named
        # target is a valid parameter key, whereas a non-symbol sympy value
        # (e.g. a number) is not something a parameter can be named after.
        return {
            sym: protocols.resolve_parameters(formula, params)
            for formula, sym in self.items()
            if isinstance(sym, (sympy.Symbol, str))
        }

    def __repr__(self) -> str:
        super_repr = super().__repr__()
        return f'cirq.ExpressionMap({super_repr})'
@overload
def _ensure_not_str(param: Union[sympy.Basic, str]) -> sympy.Basic:
    ...


@overload
def _ensure_not_str(param: float) -> float:
    ...


def _ensure_not_str(param):
    """Converts a `str` into a `sympy.Symbol` of that name.

    Any value that is not a `str` is returned unchanged.
    """
    return sympy.Symbol(param) if isinstance(param, str) else param