# Source code for tensortrade.agents.agent

from abc import ABCMeta, abstractmethod
from typing import Callable, Optional

import numpy as np

from tensortrade.base import Identifiable


class Agent(Identifiable, metaclass=ABCMeta):
    """Abstract base class declaring the interface every agent must implement.

    ``Agent`` uses ``ABCMeta`` and marks ``restore``, ``save``,
    ``get_action`` and ``train`` as abstract. A subclass that does not
    override all four cannot be instantiated.
    """

    @abstractmethod
    def restore(self, path: str, **kwargs):
        """Restore the agent from ``path``.

        Presumably the counterpart of :meth:`save`; the storage format is
        left to the implementing subclass.

        Args:
            path: Location to restore from.
            **kwargs: Subclass-specific options.

        Raises:
            NotImplementedError: If called on the base implementation.
        """
        raise NotImplementedError()

    @abstractmethod
    def save(self, path: str, **kwargs):
        """Save the agent to ``path``.

        Args:
            path: Location to save to.
            **kwargs: Subclass-specific options.

        Raises:
            NotImplementedError: If called on the base implementation.
        """
        raise NotImplementedError()

    @abstractmethod
    def get_action(self, state: np.ndarray, **kwargs) -> int:
        """Return the action the agent selects for ``state``.

        Args:
            state: Observation array; its shape is defined by the
                environment the subclass is used with (not fixed here).
            **kwargs: Subclass-specific options.

        Returns:
            The chosen action as an ``int``.

        Raises:
            NotImplementedError: If called on the base implementation.
        """
        raise NotImplementedError()

    @abstractmethod
    def train(
        self,
        n_steps: Optional[int] = None,
        n_episodes: int = 10000,
        save_every: Optional[int] = None,
        save_path: Optional[str] = None,
        callback: Optional[Callable] = None,
        **kwargs,
    ):
        """Train the agent.

        The exact meaning of each argument is defined by the implementing
        subclass; the notes below describe the presumed intent.

        Args:
            n_steps: Presumably a cap on training steps; ``None`` by default.
            n_episodes: Presumably the number of training episodes;
                defaults to 10000.
            save_every: Presumably how often to checkpoint; ``None`` by
                default.
            save_path: Presumably where checkpoints are written; ``None`` by
                default.
            callback: Optional callable hook supplied by the caller; when and
                with what arguments it is invoked is subclass-defined.
            **kwargs: Subclass-specific options.

        Raises:
            NotImplementedError: If called on the base implementation.
        """
        raise NotImplementedError()