Source code for auto_diff.sess.session

from typing import Union, Mapping, List
from auto_diff.op.operation import Operation

__all__ = ['Session']


class Session(object):
    """Evaluate operations against a feed dictionary.

    All sessions share a single step counter: ``__step`` is a class-level,
    one-element list, so incrementing it through any instance is visible to
    every other instance. The current counter value is handed to each
    ``Operation.forward`` call under the key ``Operation.STEP_KEY``.
    """

    # A one-element list rather than a plain int, so that ``+=`` through an
    # instance mutates the shared class-level value instead of creating a
    # shadowing instance attribute.
    __step = [0]

    def __init__(self):
        """Create a session and advance the shared step counter once."""
        self.prepare()

    def prepare(self):
        """Advance the shared step counter by one.

        NOTE(review): presumably this lets operations tell successive
        evaluations apart -- confirm against ``Operation.forward``.
        """
        self.__step[0] += 1

    def run(self,
            fetches: Union[Operation, List[Operation], Mapping[str, Operation]],
            feed_dict: Union[Mapping, None] = None):
        """Call ``forward`` on the requested operation(s).

        :param fetches: A single operation, a list of operations, or a
            mapping from names to operations.
        :param feed_dict: Optional mapping passed on to every ``forward``
            call. It is copied before the step value is added, so the
            caller's object is never modified.
        :return: The ``forward`` result for a single operation, a list of
            results for a list, or a ``dict`` of results keyed like
            ``fetches`` for a mapping.
        :raises NotImplementedError: If ``fetches`` is none of the
            supported types.
        """
        # Local import: the file-level ``Mapping`` is the typing alias used
        # for annotations; the ABC is what the runtime check needs.
        from collections.abc import Mapping as MappingABC

        # Work on a private copy so the internal step key does not leak into
        # (or fail on) the mapping owned by the caller.
        feed = {} if feed_dict is None else dict(feed_dict)
        feed[Operation.STEP_KEY] = self.__step[0]

        if isinstance(fetches, Operation):
            return fetches.forward(feed)
        if isinstance(fetches, list):
            return [fetch.forward(feed) for fetch in fetches]
        # Accept any mapping, matching the annotation, not only ``dict``.
        if isinstance(fetches, MappingABC):
            return {key: fetch.forward(feed) for key, fetch in fetches.items()}
        raise NotImplementedError('Unknown type of fetches: %s' % type(fetches))