Source code for cirq.protocols.mul

# Copyright 2018 The Cirq Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any

from cirq.protocols.resolve_parameters import is_parameterized

# This is a special indicator value used to determine whether or not the caller
# provided a 'default' argument.
RaiseTypeErrorIfNotProvided = ([],)  # type: Any


[docs]def mul(lhs: Any, rhs: Any, default: Any = RaiseTypeErrorIfNotProvided) -> Any: """Returns lhs * rhs, or else a default if the operator is not implemented. This method is mostly used by __pow__ methods trying to return NotImplemented instead of causing a TypeError. Args: lhs: Left hand side of the multiplication. rhs: Right hand side of the multiplication. default: Default value to return if the multiplication is not defined. If not default is specified, a type error is raised when the multiplication fails. Returns: The product of the two inputs, or else the default value if the product is not defined, or else raises a TypeError if no default is defined. Raises: TypeError: lhs doesn't have __mul__ or it returned NotImplemented AND lhs doesn't have __rmul__ or it returned NotImplemented AND a default value isn't specified. """ # Use left-hand-side's __mul__. left_mul = getattr(lhs, '__mul__', None) result = NotImplemented if left_mul is None else left_mul(rhs) # Fallback to right-hand-side's __rmul__. if result is NotImplemented: right_mul = getattr(rhs, '__rmul__', None) result = NotImplemented if right_mul is None else right_mul(lhs) # Don't build up factors of 1.0 vs sympy Symbols. if lhs == 1 and is_parameterized(rhs): result = rhs if rhs == 1 and is_parameterized(lhs): result = lhs if lhs == -1 and is_parameterized(rhs): result = -rhs if rhs == -1 and is_parameterized(lhs): result = -lhs # Output. if result is not NotImplemented: return result if default is not RaiseTypeErrorIfNotProvided: return default raise TypeError("unsupported operand type(s) for *: '{}' and '{}'".format( type(lhs), type(rhs)))
# pylint: enable=function-redefined, redefined-builtin