# Copyright 2019 The TensorTrade Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import deque
from typing import List, Optional

from tensortrade.data.stream import Node
class DataFeed(Node):
    """A stream node that evaluates a graph of upstream nodes step by step.

    The feed gathers every node reachable through ``inputs``, orders them so
    each node runs after its own inputs, runs them on every step, and its
    :meth:`forward` builds a ``{name: value}`` dict of its direct inputs.
    """

    def __init__(self, nodes: Optional[List[Node]] = None):
        """Create a feed, optionally wiring it to ``nodes``.

        Parameters
        ----------
        nodes : List[Node], optional
            Nodes to use as this feed's direct inputs. They are passed to
            ``Node.__call__``, which presumably records them in ``self.inputs``
            (``forward`` and ``__add__`` read them from there).
        """
        super().__init__("")

        # Evaluation order of the upstream graph; populated by compile().
        self.process = None
        self.compiled = False

        if nodes:
            self.__call__(*nodes)

    @staticmethod
    def _gather(node, vertices, edges):
        """Depth-first collect the edges of the graph feeding into ``node``.

        Parameters
        ----------
        node : Node
            The node to start from.
        vertices : list
            Nodes already visited; extended in place.
        edges : list
            Accumulator of ``(input_node, node)`` pairs; extended in place.

        Returns
        -------
        list
            ``edges``, with the edges of every newly visited node appended.
        """
        if node not in vertices:
            vertices += [node]

            for input_node in node.inputs:
                edges += [(input_node, node)]

            # Recursive: extremely deep node chains may hit Python's
            # recursion limit.
            for input_node in node.inputs:
                DataFeed._gather(input_node, vertices, edges)

        return edges

    def gather(self):
        """Return every ``(input_node, node)`` edge reachable upstream of this feed."""
        return self._gather(self, [], [])

    @staticmethod
    def toposort(edges):
        """Order the source nodes of ``edges`` so each runs after its inputs.

        Uses Kahn's algorithm with in-degree counts. This is linear in the
        number of nodes and edges. Because it never iterates a ``set``, it
        yields a deterministic order for a given edge list.

        Parameters
        ----------
        edges : list of tuple
            ``(source, target)`` pairs, as produced by :meth:`gather`.

        Returns
        -------
        list
            Every node that is the source of some edge, ordered so that each
            node appears after all nodes with an edge into it. Nodes with no
            outgoing edges (such as the feed itself) are not included. Nodes
            on, or downstream of, a cycle are also omitted.
        """
        successors = {}
        in_degree = {}

        for source, target in edges:
            successors.setdefault(source, []).append(target)
            in_degree.setdefault(source, 0)
            # Duplicate edges are counted twice here and also listed twice in
            # ``successors``, so the counts still reach zero.
            in_degree[target] = in_degree.get(target, 0) + 1

        ready = deque(node for node, degree in in_degree.items() if degree == 0)
        process = []

        while ready:
            node = ready.popleft()

            # Sinks (nodes that feed nothing, e.g. the DataFeed) are not
            # part of the process list; the feed runs itself separately.
            if node in successors:
                process.append(node)

            for target in successors.get(node, ()):
                in_degree[target] -= 1
                if in_degree[target] == 0:
                    ready.append(target)

        return process

    def compile(self):
        """Gather the upstream graph, fix its evaluation order and reset it."""
        edges = self.gather()

        self.process = self.toposort(edges)
        # Mark as compiled before resetting so reset() does not compile again.
        self.compiled = True
        self.reset()

    def run(self):
        """Run every upstream node in order, then this feed itself.

        Compiles the feed on first use.
        """
        if not self.compiled:
            self.compile()

        for node in self.process:
            node.run()

        # Node.run presumably stores forward()'s result in self.value.
        super().run()

    def forward(self):
        """Return a ``{name: value}`` dict of this feed's direct inputs."""
        return {node.name: node.value for node in self.inputs}

    def next(self):
        """Advance one step, notify listeners and return the new value.

        Returns
        -------
        object
            ``self.value`` after :meth:`run`. It is presumably the dict built
            by :meth:`forward`, and it is also passed to each listener's
            ``on_next``.
        """
        self.run()

        for listener in self.listeners:
            listener.on_next(self.value)

        return self.value

    def has_next(self) -> bool:
        """Return whether every upstream node can produce another value.

        Compiles the feed first if needed, because ``process`` is only
        populated by :meth:`compile`. Returns ``True`` when there are no
        upstream nodes.
        """
        if not self.compiled:
            self.compile()

        return all(node.has_next() for node in self.process)

    def __add__(self, other):
        """Combine two feeds into a new feed over the union of their inputs.

        Inputs are de-duplicated while keeping first-seen order. The
        listeners of both feeds are attached to the result.

        Returns
        -------
        DataFeed
            A new, uncompiled feed. If ``other`` is not a ``DataFeed``,
            returns ``NotImplemented`` so Python can try ``other.__radd__``
            or raise ``TypeError``.
        """
        if not isinstance(other, DataFeed):
            return NotImplemented

        # dict.fromkeys de-duplicates like set() but keeps a stable order.
        nodes = list(dict.fromkeys(self.inputs + other.inputs))
        feed = DataFeed(nodes)

        for listener in self.listeners + other.listeners:
            feed.attach(listener)

        return feed

    def reset(self):
        """Reset every upstream node, compiling the feed first if needed."""
        if not self.compiled:
            # compile() resets the nodes itself once the order is known.
            self.compile()
            return

        for node in self.process:
            node.reset()