# Source code for auto_diff.op.op_constant

from typing import Mapping, Union
import numpy as np
from .operation import Operation
from .op_placeholder import OpPlaceholder


class OpConstant(Operation):
    """Contains a constant.

    The value passed to ``__init__`` is normalised once, stored in ``self.x``
    and returned verbatim by ``_forward``. ``_backward`` does nothing.
    """

    def __init__(self, x: Union[int, float, list, np.ndarray], **kwargs):
        """
        :param x: The constant value.

            - Integer and bool scalars are converted to ``float``. This covers
              Python ``int``/``bool`` and NumPy integer/bool scalars such as
              ``np.int64``.
            - Any other value that is neither a scalar (per ``np.isscalar``)
              nor an ``np.ndarray`` is converted with
              ``np.array(x, dtype=np.float64)``. Lists and tuples take this
              path.
            - An ``np.ndarray`` is stored as-is: same object, same dtype.

        :param kwargs: Forwarded unchanged to the base ``Operation``
            initialiser.
        """
        # NumPy integer/bool scalars (e.g. an element of ``np.arange``) are
        # promoted exactly like Python ints, so every scalar constant ends up
        # floating point rather than only those built from Python ints.
        if isinstance(x, (int, np.integer, np.bool_)):
            x = float(x)
        if not np.isscalar(x) and not isinstance(x, np.ndarray):
            x = np.array(x, dtype=np.float64)
        self.x = x
        # Scalars are reported with shape (1,); arrays expose their own shape.
        if np.isscalar(x):
            self.shape = (1,)
        else:
            self.shape = x.shape
        # ``x`` and ``shape`` are assigned before the base initialiser runs.
        # NOTE(review): presumably ``Operation.__init__`` may read them
        # (e.g. through ``_get_name``); confirm before reordering.
        super(OpConstant, self).__init__(**kwargs)

    def _get_name(self) -> str:
        """Return a display name for the constant.

        :return: ``str(value)`` for a scalar, otherwise ``'C'`` followed by
            the array's shape, e.g. ``'C(2, 3)'``.
        """
        if np.isscalar(self.x):
            return str(self.x)
        return 'C%s' % str(self.x.shape)

    def _get_op_name(self) -> str:
        """Return the operation identifier ``'c_<index>'``.

        ``self._op_index`` is not assigned in this class; it is presumably
        set by the ``Operation`` base class.
        """
        return 'c_%d' % self._op_index

    def _forward(self, feed_dict: Mapping[Union[str, OpPlaceholder], np.ndarray]) -> np.ndarray:
        """Returns the constant.

        :param feed_dict: Ignored. A constant does not depend on fed values.
        :return: The stored ``self.x`` itself (not a copy).
        """
        return self.x

    def _backward(self, gradient: Operation) -> None:
        """No backward operation needed.

        A constant has no inputs, so there is nowhere to propagate
        ``gradient`` to.
        """
        pass