from typing import Union, Mapping, Optional, Sequence
import numpy as np
class Operation(object):
    """Abstract operation for building computing graph.

    Concrete operations are expected to set ``self.shape`` (and optionally
    ``self.name`` / ``self.inputs``) *before* calling :meth:`__init__`, and to
    implement :meth:`_get_name`, :meth:`_get_op_name`, :meth:`_forward` and
    :meth:`_backward`.
    """

    #: The counter for giving each operation a unique index.
    #: Stored in a one-element list so that it is mutated in place and shared
    #: by every instance instead of being shadowed by an instance attribute.
    __op_counter = [0]

    #: Collection of existing operations.
    #: Every constructed operation is registered here and stays referenced.
    __op_collection = {}

    #: The key for extracting step information from session.
    STEP_KEY = '__step__'

    def __init__(self, **kwargs):
        """Finish the initialization shared by all operations.

        :param kwargs: Optional settings; only ``name`` is read here and it is
            used as display name when the subclass has not set ``self.name``.
        :raise NotImplementedError: If the subclass has not defined ``shape``.
        """
        if not hasattr(self, 'name'):
            if 'name' in kwargs:
                self.name: str = kwargs['name']
            else:
                self.name: str = self._get_name()
        if not hasattr(self, 'shape'):
            # Declares the attribute, then fails: the shape has to be provided
            # by the subclass before this initializer runs.
            self.shape: Sequence = None
            raise NotImplementedError('Shape not defined')
        if not hasattr(self, 'inputs'):
            self.inputs: Sequence['Operation'] = []
        self.gradient: Optional['Operation'] = None
        # The index must be assigned before the registration below because
        # ``__hash__`` is derived from it.
        self._op_index = self.__op_counter[0]
        self.__op_counter[0] += 1
        self._op_name = self._get_op_name()
        self.__op_collection[self] = self
        # Cache of the most recent forward pass, see :meth:`forward`.
        self._last_step = -1
        self._last_forward = None

    def _get_name(self) -> str:
        """Get the name for display."""
        raise NotImplementedError('Get name not implemented')

    def _get_op_name(self) -> str:
        """Get the name for indexing."""
        raise NotImplementedError('Get operation name not implemented')

    @property
    def dim(self) -> int:
        """The number of dimensions, i.e. the length of ``shape``."""
        return len(self.shape)

    def isscalar(self) -> bool:
        """Check whether the shape is ``()`` or ``(1,)``."""
        return self.shape in [(), (1,)]

    def forward(self, feed_dict: Mapping[Union[str, 'Operation'], np.ndarray] = None) -> np.ndarray:
        """Do the calculations to get the output of the operations.

        When ``feed_dict`` contains :attr:`STEP_KEY`, the output is cached for
        that step value: a later call with the same step returns the cached
        output without calling :meth:`_forward` again.

        :param feed_dict: Contains the real values of placeholders, see :class:`OpPlaceholder`.
        :return: A numpy array.
        """
        if feed_dict is None:
            feed_dict = {}
        has_step = self.STEP_KEY in feed_dict
        # NOTE: ``_last_step`` starts at -1, so a step of -1 before any real
        # computation returns the (empty) initial cache.
        if has_step and feed_dict[self.STEP_KEY] == self._last_step:
            return self._last_forward
        output = self._forward(feed_dict)
        if has_step:
            self._last_step = feed_dict[self.STEP_KEY]
            self._last_forward = output
        return output

    def _forward(self, feed_dict: Mapping[Union[str, 'Operation'], np.ndarray]) -> np.ndarray:
        """Forward operation to be implemented."""
        raise NotImplementedError('Forward operation not implemented')

    def backward(self, gradient: 'Operation' = None) -> None:
        """Update gradients recursively.

        The gradient is stored in ``self.gradient`` and then handed to
        :meth:`_backward`.

        :param gradient: Current gradient. When omitted, a constant of ones
            with the shape of this operation is used.
        """
        if gradient is None:
            from .op_constant import OpConstant
            gradient = OpConstant(np.ones(self.shape), name='ones%s' % str(self.shape))
        self.gradient = gradient
        self._backward(gradient)

    def _backward(self, gradient: 'Operation') -> None:
        """Backward operation to be implemented."""
        raise NotImplementedError('Backward operation not implemented')

    def _broadcast_shape(self, x: 'Operation', y: 'Operation'):
        """Set ``self.shape`` to the broadcast result of the shapes of two operands.

        Trailing dimensions are aligned; two sizes are compatible when they
        are equal or when one of them is 1.

        :param x: First operand, only its ``shape`` is read.
        :param y: Second operand, only its ``shape`` is read.
        :raise ValueError: If the two shapes cannot be broadcast together.
        """
        x_shape, y_shape = tuple(x.shape), tuple(y.shape)
        min_dim = min(len(x_shape), len(y_shape))
        trailing = []
        for i in range(1, min_dim + 1):
            x_size, y_size = x_shape[-i], y_shape[-i]
            if x_size != 1 and y_size != 1 and x_size != y_size:
                raise ValueError('Cannot broadcast with shape %s and %s' % (str(x.shape), str(y.shape)))
            # Take the size that is not the broadcast 1 (``max`` would be
            # wrong for a zero-length axis paired with 1).
            trailing.append(y_size if x_size == 1 else x_size)
        # The leading dimensions come from the operand with the higher rank.
        # Slice with an explicit stop: ``shape[:-min_dim]`` is empty when
        # ``min_dim`` is 0, which would drop the whole shape for rank-0 operands.
        longer = x_shape if len(x_shape) >= len(y_shape) else y_shape
        leading = longer[:len(longer) - min_dim]
        self.shape = tuple(leading) + tuple(reversed(trailing))

    def _broadcast_backward(self, gradient: 'Operation'):
        """Reduce a gradient that has a broadcast shape back to ``self.shape``.

        Extra leading axes and axes where this operation has size 1 but the
        gradient is larger are summed; the extra leading axes are then removed.

        :param gradient: The gradient with the broadcast shape.
        :return: ``gradient`` itself when the shapes already match, otherwise
            the summed (and squeezed) gradient.
        """
        if self.shape == gradient.shape:
            return gradient
        expand_dim = len(gradient.shape) - len(self.shape)
        axes = list(range(expand_dim))
        axes.extend(
            expand_dim + i
            for i, size in enumerate(self.shape)
            if size == 1 and gradient.shape[expand_dim + i] > 1
        )
        # A single axis is passed as a plain integer, several as a tuple.
        axis = axes[0] if len(axes) == 1 else tuple(axes)
        gradient = gradient.sum(axis=axis, keepdims=True)
        if expand_dim:
            gradient = gradient.squeeze(axis=tuple(range(expand_dim)))
        return gradient

    def transpose(self, axes: Optional[Sequence[int]] = None, **kwargs) -> 'Operation':
        """See :class:`OpTranspose`."""
        from .op_transpose import OpTranspose
        return OpTranspose(self, axes, **kwargs)

    def reshape(self, shape: Sequence[int], **kwargs) -> 'Operation':
        """See :class:`OpReshape`."""
        from .op_reshape import OpReshape
        return OpReshape(self, shape, **kwargs)

    def flatten(self, **kwargs) -> 'Operation':
        """See :class:`OpFlatten`."""
        from .op_flatten import OpFlatten
        return OpFlatten(self, **kwargs)

    def expand_dims(self, axis: Optional[int] = None, **kwargs) -> 'Operation':
        """See :class:`OpExpandDims`."""
        from .op_expand_dims import OpExpandDims
        return OpExpandDims(self, axis, **kwargs)

    def squeeze(self, axis=None, **kwargs) -> 'Operation':
        """See :class:`OpSqueeze`."""
        from .op_squeeze import OpSqueeze
        return OpSqueeze(self, axis, **kwargs)

    def sum(self, axis: Optional[Union[int, Sequence[int]]] = None, keepdims: bool = False, **kwargs) -> 'Operation':
        """See :class:`OpSum`."""
        from .op_sum import OpSum
        return OpSum(self, axis, keepdims, **kwargs)

    def dot(self, x: 'Operation', **kwargs) -> 'Operation':
        """See :class:`OpDot`."""
        from .op_dot import OpDot
        return OpDot(self, x, **kwargs)

    @staticmethod
    def _wrap_operand(value) -> 'Operation':
        """Return ``value`` as is if it is an operation, else wrapped in :class:`OpConstant`."""
        if isinstance(value, Operation):
            return value
        from .op_constant import OpConstant
        return OpConstant(value)

    def __add__(self, other) -> 'Operation':
        """See :class:`OpAdd`."""
        from .op_add import OpAdd
        return OpAdd(self, self._wrap_operand(other))

    def __radd__(self, other) -> 'Operation':
        """See :class:`OpAdd`."""
        from .op_add import OpAdd
        return OpAdd(self._wrap_operand(other), self)

    def __sub__(self, other) -> 'Operation':
        """See :class:`OpSubtract`."""
        from .op_subtract import OpSubtract
        return OpSubtract(self, self._wrap_operand(other))

    def __rsub__(self, other) -> 'Operation':
        """See :class:`OpSubtract`."""
        from .op_subtract import OpSubtract
        return OpSubtract(self._wrap_operand(other), self)

    def __mul__(self, other) -> 'Operation':
        """See :class:`OpMultiply`."""
        from .op_multiply import OpMultiply
        return OpMultiply(self, self._wrap_operand(other))

    def __rmul__(self, other) -> 'Operation':
        """See :class:`OpMultiply`."""
        from .op_multiply import OpMultiply
        return OpMultiply(self._wrap_operand(other), self)

    def __neg__(self) -> 'Operation':
        """See :class:`OpNegative`."""
        from .op_negative import OpNegative
        return OpNegative(self)

    def simplify(self) -> 'Operation':
        """Return the result of passing this operation to ``simplify`` from the ``simple`` package."""
        from ..simple import simplify
        return simplify(self)

    def __hash__(self):
        """Hash by the unique operation index."""
        return hash(self._op_index)

    def __eq__(self, other: object):
        """Two operations are equal when they share the same operation index.

        Comparison with anything that is not an operation is delegated back
        to Python (which falls back to identity) instead of raising.
        """
        if not isinstance(other, Operation):
            return NotImplemented
        return self._op_index == other._op_index

    def __str__(self):
        """Return the display name."""
        return self.name

    def __unicode__(self):
        """Same as :meth:`__str__`."""
        return self.__str__()