# Source code for tensortrade.agents.parallel.parallel_dqn_model

import uuid
import random
import numpy as np
import tensorflow as tf

from typing import Callable


class ParallelDQNModel:
    """Policy/target network pair used by a parallel DQN agent.

    ``policy_network`` selects actions. ``target_network`` has the same
    architecture, is marked ``trainable = False``, and only changes when
    weights are copied into it via :meth:`update_target_network` or
    :meth:`update_networks`.

    Attributes:
        n_actions: Size of the environment's action space (read from
            ``action_space.n``).
        observation_shape: Shape of one observation (read from
            ``observation_space.shape``).
        policy_network: Keras model mapping a batch of observations to one
            output per action.
        target_network: Frozen clone of ``policy_network``.
        id: Random UUID string; used in file names written by :meth:`save`.
        episode_id: Initialised to ``None``; not otherwise used in this class.
    """

    def __init__(self,
                 create_env: Callable[[], 'TradingEnvironment'],
                 policy_network: tf.keras.Model = None):
        """Create the policy and target networks.

        Args:
            create_env: Zero-argument factory returning an environment. It is
                called once here only to read the action and observation
                spaces.
            policy_network: Optional pre-built Keras model. When ``None``, the
                default convolutional network from
                :meth:`_build_policy_network` is used.
        """
        # NOTE(review): this temporary environment is never closed; confirm
        # whether TradingEnvironment holds resources that should be released.
        temp_env = create_env()

        self.n_actions = temp_env.action_space.n
        self.observation_shape = temp_env.observation_space.shape

        if policy_network is None:
            policy_network = self._build_policy_network()
        self.policy_network = policy_network

        # clone_model copies the architecture only (with freshly initialised
        # weights), so copy the policy weights to make both networks start
        # out identical.
        self.target_network = tf.keras.models.clone_model(self.policy_network)
        self.target_network.trainable = False
        self.update_target_network()

        self.id = str(uuid.uuid4())
        self.episode_id = None

    def _build_policy_network(self):
        """Return the default Conv1D network sized from the environment spaces."""
        # NOTE(review): the output head is sigmoid followed by softmax, which
        # bounds every output to [0, 1] and forces them to sum to 1. DQN
        # Q-value targets are usually unbounded (linear output) - confirm this
        # head is intended before relying on the outputs as Q-values.
        network = tf.keras.Sequential([
            tf.keras.layers.InputLayer(input_shape=self.observation_shape),
            tf.keras.layers.Conv1D(filters=64, kernel_size=6, padding="same", activation="tanh"),
            tf.keras.layers.MaxPooling1D(pool_size=2),
            tf.keras.layers.Conv1D(filters=32, kernel_size=3, padding="same", activation="tanh"),
            tf.keras.layers.MaxPooling1D(pool_size=2),
            tf.keras.layers.Flatten(),
            tf.keras.layers.Dense(self.n_actions, activation="sigmoid"),
            tf.keras.layers.Dense(self.n_actions, activation="softmax")
        ])

        return network

    def restore(self, path: str, **kwargs):
        """Load the policy network from ``path`` and rebuild the target network.

        Args:
            path: Location of a model readable by ``tf.keras.models.load_model``
                (for example a file written by :meth:`save`).
            **kwargs: Accepted for interface compatibility; unused.
        """
        self.policy_network = tf.keras.models.load_model(path)
        self.target_network = tf.keras.models.clone_model(self.policy_network)
        self.target_network.trainable = False
        # Without this copy the target network would hold freshly initialised
        # weights instead of the restored ones.
        self.update_target_network()

    def save(self, path: str, **kwargs):
        """Save the policy network as an ``.hdf5`` file.

        The file name is ``policy_network__<id>.hdf5``, or
        ``policy_network__<id>__<episode zero-padded to 3>.hdf5`` when an
        ``episode`` keyword is given (including ``episode=0``).

        Args:
            path: Prefix the file name is appended to verbatim, so it must end
                with a path separator when it names a directory.
            **kwargs: ``episode`` - optional episode number to embed in the
                file name.
        """
        episode = kwargs.get('episode', None)

        # ``is not None`` so that episode 0 still gets its suffix.
        if episode is not None:
            filename = "policy_network__" + self.id + "__" + str(episode).zfill(3) + ".hdf5"
        else:
            filename = "policy_network__" + self.id + ".hdf5"

        self.policy_network.save(path + filename)

    def get_action(self, state: np.ndarray, **kwargs) -> int:
        """Choose an action for ``state`` with an epsilon-greedy rule.

        Args:
            state: A single observation; a batch axis of size 1 is added
                before it is passed to the policy network.
            **kwargs: ``threshold`` - probability (default 0) of returning a
                uniformly random action instead of the network's argmax.

        Returns:
            The chosen action index as a Python ``int``.
        """
        threshold: float = kwargs.get('threshold', 0)

        if random.random() < threshold:
            return int(np.random.choice(self.n_actions))

        outputs = self.policy_network(np.expand_dims(state, 0))
        return int(np.argmax(outputs))

    def update_networks(self, model: 'ParallelDQNModel'):
        """Copy both the policy and target weights from another model.

        Args:
            model: Source model with architectures matching this one.
        """
        self.policy_network.set_weights(model.policy_network.get_weights())
        self.target_network.set_weights(model.target_network.get_weights())

    def update_target_network(self):
        """Overwrite the target network's weights with the policy network's."""
        self.target_network.set_weights(self.policy_network.get_weights())