Source code for psychrnn.backend.models.lstm

from __future__ import division

from psychrnn.backend.rnn import RNN
import tensorflow as tf

tf.compat.v1.disable_eager_execution()


class LSTM(RNN):
    """ LSTM (Long Short Term Memory) recurrent network model

    LSTM implementation of :class:`psychrnn.backend.rnn.RNN`. Because the LSTM is structured differently from the basic RNN, biological constraints such as Dale's law, autapses, and connectivity masks are not enabled.

    Args:
       params (dict): See :class:`psychrnn.backend.rnn.RNN` for details.

    """
    def __init__(self, params):
        # ----------------------------------
        # Call RNN constructor
        # ----------------------------------
        super(LSTM, self).__init__(params)

        # ----------------------------------
        # Add new variables for gates
        # TODO better LSTM initialization
        # ----------------------------------
        self.N_concat = self.N_in + self.N_rec

        self.init_hidden_initializer = tf.compat.v1.random_normal_initializer(mean=0, stddev=0.1)
        self.init_cell_initializer = tf.compat.v1.random_normal_initializer(mean=0, stddev=0.1)

        self.W_f_initializer = tf.compat.v1.random_normal_initializer(mean=0, stddev=0.1)
        self.W_i_initializer = tf.compat.v1.random_normal_initializer(mean=0, stddev=0.1)
        self.W_c_initializer = tf.compat.v1.random_normal_initializer(mean=0, stddev=0.1)
        self.W_o_initializer = tf.compat.v1.random_normal_initializer(mean=0, stddev=0.1)

        self.b_f_initializer = tf.compat.v1.constant_initializer(1.0)
        self.b_i_initializer = tf.compat.v1.constant_initializer(1.0)
        self.b_c_initializer = tf.compat.v1.constant_initializer(1.0)
        self.b_o_initializer = tf.compat.v1.constant_initializer(1.0)

        # ----------------------------------
        # TensorFlow initializations
        # ----------------------------------
        with tf.compat.v1.variable_scope(self.name) as scope:
            # Trainable initial hidden and cell states
            self.init_hidden = tf.compat.v1.get_variable('init_hidden', [self.N_batch, self.N_rec],
                                                          initializer=self.init_hidden_initializer,
                                                          trainable=True)
            self.init_cell = tf.compat.v1.get_variable('init_cell', [self.N_batch, self.N_rec],
                                                        initializer=self.init_cell_initializer,
                                                        trainable=True)

            # Gate weight matrices acting on the concatenated [hidden, input] vector
            self.W_f = tf.compat.v1.get_variable('W_f', [self.N_concat, self.N_rec],
                                                 initializer=self.W_f_initializer, trainable=True)
            self.W_i = tf.compat.v1.get_variable('W_i', [self.N_concat, self.N_rec],
                                                 initializer=self.W_i_initializer, trainable=True)
            self.W_c = tf.compat.v1.get_variable('W_c', [self.N_concat, self.N_rec],
                                                 initializer=self.W_c_initializer, trainable=True)
            self.W_o = tf.compat.v1.get_variable('W_o', [self.N_concat, self.N_rec],
                                                 initializer=self.W_o_initializer, trainable=True)

            # Gate biases
            self.b_f = tf.compat.v1.get_variable('b_f', [self.N_rec],
                                                 initializer=self.b_f_initializer, trainable=True)
            self.b_i = tf.compat.v1.get_variable('b_i', [self.N_rec],
                                                 initializer=self.b_i_initializer, trainable=True)
            self.b_c = tf.compat.v1.get_variable('b_c', [self.N_rec],
                                                 initializer=self.b_c_initializer, trainable=True)
            self.b_o = tf.compat.v1.get_variable('b_o', [self.N_rec],
                                                 initializer=self.b_o_initializer, trainable=True)
    def recurrent_timestep(self, rnn_in, hidden, cell):
        """ Recurrent time step.

        Given input and previous state, outputs the next state of the network.

        Arguments:
            rnn_in (*tf.Tensor(dtype=float, shape=(?*, :attr:`N_in` *))*): Input to the rnn at a certain time point.
            hidden (*tf.Tensor(dtype=float, shape=(* :attr:`N_batch` , :attr:`N_rec` *))*): Hidden units state of network at previous time point.
            cell (*tf.Tensor(dtype=float, shape=(* :attr:`N_batch` , :attr:`N_rec` *))*): Cell state of the network at previous time point.

        Returns:
            tuple:
            * **new_hidden** (*tf.Tensor(dtype=float, shape=(* :attr:`N_batch` , :attr:`N_rec` *))*) -- New hidden unit state of the network.
            * **new_cell** (*tf.Tensor(dtype=float, shape=(* :attr:`N_batch` , :attr:`N_rec` *))*) -- New cell state of the network.

        """
        # Gate activations computed from the concatenated [hidden, input] vector
        f = tf.nn.sigmoid(tf.matmul(tf.concat([hidden, rnn_in], 1), self.W_f) + self.b_f)  # forget gate
        i = tf.nn.sigmoid(tf.matmul(tf.concat([hidden, rnn_in], 1), self.W_i) + self.b_i)  # input gate
        c = tf.nn.tanh(tf.matmul(tf.concat([hidden, rnn_in], 1), self.W_c) + self.b_c)     # candidate cell state
        o = tf.nn.sigmoid(tf.matmul(tf.concat([hidden, rnn_in], 1), self.W_o) + self.b_o)  # output gate

        new_cell = f * cell + i * c
        # Note: the new cell state is squashed with a sigmoid here, not the tanh of the standard LSTM formulation.
        new_hidden = o * tf.nn.sigmoid(new_cell)

        return new_hidden, new_cell
    def output_timestep(self, hidden):
        """Returns the output node activity for a given timestep.

        Arguments:
            hidden (*tf.Tensor(dtype=float, shape=(* :attr:`N_batch` , :attr:`N_rec` *))*): Hidden units of network at a given timepoint for each trial in the batch.

        Returns:
            output (*tf.Tensor(dtype=float, shape=(* :attr:`N_batch` , :attr:`N_out` *))*): Output of the network at a given timepoint for each trial in the batch.

        """
        output = tf.matmul(hidden, self.W_out, transpose_b=True) + self.b_out

        return output
    def forward_pass(self):
        """ Run the LSTM on a batch of task inputs.

        Iterates over timesteps, running the :func:`recurrent_timestep` and :func:`output_timestep`

        Implements :func:`psychrnn.backend.rnn.RNN.forward_pass`.

        Returns:
            tuple:
            * **predictions** (*tf.Tensor(dtype=float, shape=(* :attr:`N_batch` , :attr:`N_steps` , :attr:`N_out` *))*) -- Network output on inputs found in self.x within the tf network.
            * **hidden** (*tf.Tensor(dtype=float, shape=(* :attr:`N_batch` , :attr:`N_steps` , :attr:`N_rec` *))*) -- Hidden unit values over the course of the trials found in self.x within the tf network.

        """
        rnn_inputs = tf.unstack(self.x, axis=1)
        hidden = self.init_hidden
        cell = self.init_cell
        rnn_outputs = []
        rnn_states = []
        for rnn_input in rnn_inputs:
            hidden, cell = self.recurrent_timestep(rnn_input, hidden, cell)
            output = self.output_timestep(hidden)
            rnn_outputs.append(output)
            rnn_states.append(hidden)

        return tf.transpose(a=rnn_outputs, perm=[1, 0, 2]), tf.transpose(a=rnn_states, perm=[1, 0, 2])
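
For readers who want to check the gate arithmetic without building the TensorFlow graph, here is a minimal NumPy sketch of a single recurrent_timestep update. The toy dimensions, random weights, and sigmoid helper are illustrative assumptions and not part of the module source; the sketch simply mirrors the equations above, including the sigmoid (rather than the conventional tanh) applied to the new cell state when forming the hidden state.

import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

N_in, N_rec = 3, 5                       # toy sizes for illustration
N_concat = N_in + N_rec
rng = np.random.default_rng(0)

# Toy parameters standing in for W_f/W_i/W_c/W_o and b_f/b_i/b_c/b_o above.
W = {g: rng.normal(0, 0.1, (N_concat, N_rec)) for g in 'fico'}
b = {g: np.ones(N_rec) for g in 'fico'}

rnn_in = rng.normal(size=(1, N_in))      # one trial's input at one timestep
hidden = np.zeros((1, N_rec))            # previous hidden state
cell = np.zeros((1, N_rec))              # previous cell state

concat = np.concatenate([hidden, rnn_in], axis=1)
f = sigmoid(concat @ W['f'] + b['f'])    # forget gate
i = sigmoid(concat @ W['i'] + b['i'])    # input gate
c = np.tanh(concat @ W['c'] + b['c'])    # candidate cell state
o = sigmoid(concat @ W['o'] + b['o'])    # output gate

new_cell = f * cell + i * c
new_hidden = o * sigmoid(new_cell)
print(new_hidden.shape, new_cell.shape)  # (1, 5) (1, 5)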
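
As a usage sketch, the snippet below builds this LSTM on one of PsychRNN's bundled tasks and runs training and a forward pass. The task choice (PerceptualDiscrimination), the parameter values, and the train/test/destruct calls follow the library's documented quickstart workflow, but treat the exact keys and numbers as assumptions to adapt to your own task.

from psychrnn.tasks.perceptual_discrimination import PerceptualDiscrimination
from psychrnn.backend.models.lstm import LSTM

# Task and network parameters (illustrative values).
pd = PerceptualDiscrimination(dt=10, tau=100, T=2000, N_batch=50)
params = pd.get_task_params()         # supplies N_in, N_out, N_steps, N_batch, dt, tau, ...
params['name'] = 'lstm_example'       # variable scope name used in __init__ above
params['N_rec'] = 50                  # number of recurrent (LSTM) units

model = LSTM(params)                  # builds the graph defined in this module
model.train(pd)                       # train on batches generated by the task

x, y, mask, _ = pd.get_trial_batch()  # one batch of inputs and targets
outputs, states = model.test(x)       # predictions and hidden states from forward_pass
model.destruct()                      # release the TensorFlow graph/session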