# Source code for tensortrade.base.context

import threading
import json
import yaml

from typing import Union, List
from collections import UserDict

from .registry import registered_names, get_major_component_names
from tensortrade.instruments import Instrument, USD


class TradingContext(UserDict):
    """A class for objects that put themselves in a `Context` using the `with` statement.

    The implementation for this class is heavily borrowed from the pymc3
    library and adapted with the design goals of TensorTrade in mind.

    Arguments:
        base_instrument: Instrument stored in the mapping and in the `shared` context.
        shared: A context that is shared between all components that are made
            under the overarching `TradingContext`.
        exchanges: A context that is specific to components with a registered
            name of `exchanges`.
        actions: A context that is specific to components with a registered
            name of `actions`.
        rewards: A context that is specific to components with a registered
            name of `rewards`.
        features: A context that is specific to components with a registered
            name of `features`.
        slippage: Configuration exposed through the `slippage` property.

    Any other keyword whose name is not a registered component name is merged
    into the `shared` context (taking precedence over entries of `shared`).

    Warnings:
        If there is a conflict in the contexts of different components because
        they were initialized under different contexts, it can have undesirable
        effects. Therefore, a warning should be made to the user indicating that
        using components together that have conflicting contexts can lead to
        unwanted behavior.

    Reference:
        - https://github.com/pymc-devs/pymc3/blob/master/pymc3/model.py
    """

    # Thread-local holder; each thread lazily gets its own `stack` attribute
    # in `get_contexts`, so `with` blocks in different threads don't interfere.
    contexts = threading.local()

    def __init__(self, base_instrument: Instrument = USD, **config):
        super().__init__(base_instrument=base_instrument, **config)

        # Query the registry once instead of once per name / per config key.
        registered = tuple(registered_names())
        major = tuple(get_major_component_names())

        # Non-major registered components get their config section exposed
        # directly as an instance attribute (empty dict when absent).
        for name in registered:
            if name not in major:
                setattr(self, name, config.get(name, {}))

        # Top-level keys that are not component names are treated as shared settings.
        config_items = {k: v for k, v in config.items() if k not in registered}

        self._exchanges = config.get('exchanges', {})
        self._actions = config.get('actions', {})
        self._rewards = config.get('rewards', {})
        self._features = config.get('features', {})
        self._slippage = config.get('slippage', {})

        # Precedence (lowest -> highest): base_instrument, explicit `shared`
        # section, loose top-level config items.
        self._shared = {
            'base_instrument': base_instrument,
            **config.get('shared', {}),
            **config_items,
        }

    @property
    def shared(self) -> dict:
        return self._shared

    @property
    def exchanges(self) -> dict:
        return self._exchanges

    @property
    def actions(self) -> dict:
        return self._actions

    @property
    def rewards(self) -> dict:
        return self._rewards

    @property
    def features(self) -> dict:
        return self._features

    @property
    def slippage(self) -> dict:
        return self._slippage

    def __enter__(self):
        """Adds a new context to the context stack.

        This method is used for a `with` statement and adds a `TradingContext`
        to the context stack. The new context on the stack is then used by
        every class that subclasses `Component` during the initialization of
        its instances.
        """
        type(self).get_contexts().append(self)
        return self

    def __exit__(self, typ, value, traceback):
        # Pop unconditionally (also when the body raised) so the stack stays balanced.
        type(self).get_contexts().pop()

    @classmethod
    def get_contexts(cls):
        """Return the current thread's context stack.

        On first access in a thread the stack is created and seeded with a
        default `TradingContext()`, so `get_context` always has something to
        return.
        """
        if not hasattr(cls.contexts, 'stack'):
            cls.contexts.stack = [TradingContext()]
        return cls.contexts.stack

    @classmethod
    def get_context(cls):
        """Gets the deepest context on the stack."""
        return cls.get_contexts()[-1]

    @classmethod
    def from_json(cls, path: str):
        """Build a context from a JSON file whose top-level object holds the config.

        Returns an instance of `cls`, so subclasses get their own type back.
        """
        with open(path, "rb") as fp:
            config = json.load(fp)
        return cls(**config)

    @classmethod
    def from_yaml(cls, path: str):
        """Build a context from a YAML file whose top-level mapping holds the config.

        Returns an instance of `cls`, so subclasses get their own type back.
        """
        with open(path, "rb") as fp:
            # NOTE(security): FullLoader can construct Python-specific objects
            # and has had arbitrary-code-execution issues in older PyYAML
            # releases (fixed in 5.4). Only load trusted config files; consider
            # yaml.safe_load if no Python-specific tags are needed.
            config = yaml.load(fp, Loader=yaml.FullLoader)
        return cls(**config)
class Context(UserDict):
    """A context that is injected into every instance of a class that is a subclass of component.

    Every keyword given at construction is stored as a mapping entry and is
    also mirrored as an instance attribute (e.g. ``ctx['fee']`` and ``ctx.fee``),
    except for names that would clobber the object's own internal state.

    Arguments:
        base_instrument: The exchange symbol of the instrument to store/measure value in.
    """

    def __init__(self, base_instrument: Instrument = USD, **kwargs):
        super(Context, self).__init__(base_instrument=base_instrument, **kwargs)
        self._base_instrument = base_instrument

        # Mirror entries as attributes for dotted access. Skip names already in
        # the instance dict: a 'data' key would otherwise replace UserDict's
        # backing store and silently corrupt the mapping.
        self.__dict__.update(
            {k: v for k, v in self.data.items() if k not in self.__dict__}
        )

    @property
    def base_instrument(self) -> Instrument:
        return self._base_instrument

    def __str__(self):
        # Iterate the stored entries. The previous `self.__slots__` resolved to
        # the empty tuple inherited from collections.abc.MutableMapping, so the
        # output was always '<Context: >'.
        data = ['{}={}'.format(k, v) for k, v in self.data.items()]
        return '<{}: {}>'.format(self.__class__.__name__, ', '.join(data))