# Source code for psychrnn.backend.models.basic

from __future__ import division

from psychrnn.backend.rnn import RNN
import tensorflow as tf
tf.compat.v1.disable_eager_execution()


class Basic(RNN):
    """ The basic continuous time recurrent neural network model.

    Basic implementation of :class:`psychrnn.backend.rnn.RNN` with a simple
    RNN, enabling biological constraints.

    Args:
        params (dict): See :class:`psychrnn.backend.rnn.RNN` for details.
    """

    def recurrent_timestep(self, rnn_in, state):
        """ Recurrent time step.

        Given input and previous state, outputs the next state of the network.

        Arguments:
            rnn_in (*tf.Tensor(dtype=float, shape=(?*, :attr:`N_in` *))*): Input to the rnn at a certain time point.
            state (*tf.Tensor(dtype=float, shape=(* :attr:`N_batch` , :attr:`N_rec` *))*): State of network at previous time point.

        Returns:
            new_state (*tf.Tensor(dtype=float, shape=(* :attr:`N_batch` , :attr:`N_rec` *))*): New state of the network.

        """
        # Euler-discretized continuous-time update: decay the previous state
        # by (1 - alpha), add alpha-scaled recurrent drive, input drive and
        # bias, plus Gaussian process noise scaled so its variance matches
        # the discretization step (sqrt(2 * alpha) * rec_noise).
        new_state = ((1 - self.alpha) * state) \
            + self.alpha * (
                tf.matmul(
                    self.transfer_function(state),
                    self.get_effective_W_rec(),
                    transpose_b=True,
                    name="1")
                + tf.matmul(
                    rnn_in,
                    self.get_effective_W_in(),
                    transpose_b=True,
                    name="2")
                + self.b_rec) \
            + tf.sqrt(2.0 * self.alpha * self.rec_noise * self.rec_noise) \
            * tf.random.normal(tf.shape(input=state), mean=0.0, stddev=1.0)

        return new_state

    def output_timestep(self, state):
        """ Returns the output node activity for a given timestep.

        Arguments:
            state (*tf.Tensor(dtype=float, shape=(* :attr:`N_batch` , :attr:`N_rec` *))*): State of network at a given timepoint for each trial in the batch.

        Returns:
            output (*tf.Tensor(dtype=float, shape=(* :attr:`N_batch` , :attr:`N_out` *))*): Output of the network at a given timepoint for each trial in the batch.

        """
        # Linear readout of the (transfer-function-transformed) state.
        output = tf.matmul(
            self.transfer_function(state),
            self.get_effective_W_out(),
            transpose_b=True,
            name="3") \
            + self.b_out

        return output

    def forward_pass(self):
        """ Run the RNN on a batch of task inputs.

        Iterates over timesteps, running the :func:`recurrent_timestep` and
        :func:`output_timestep`.

        Implements :func:`psychrnn.backend.rnn.RNN.forward_pass`.

        Returns:
            tuple:
            * **predictions** (*tf.Tensor(*:attr:`N_batch`, :attr:`N_steps`, :attr:`N_out` *))*) -- Network output on inputs found in self.x within the tf network.
            * **states** (*tf.Tensor(*:attr:`N_batch`, :attr:`N_steps`, :attr:`N_rec` *))*) -- State variable values over the course of the trials found in self.x within the tf network.

        """
        # Unstack along the time axis so we can iterate timestep by timestep.
        rnn_inputs = tf.unstack(self.x, axis=1)
        state = self.init_state
        rnn_outputs = []
        rnn_states = []
        for rnn_input in rnn_inputs:
            state = self.recurrent_timestep(rnn_input, state)
            output = self.output_timestep(state)
            rnn_outputs.append(output)
            rnn_states.append(state)

        # Stacked lists are (N_steps, N_batch, dim); transpose back to
        # (N_batch, N_steps, dim) for the caller.
        return tf.transpose(a=rnn_outputs, perm=[1, 0, 2]), \
            tf.transpose(a=rnn_states, perm=[1, 0, 2])
class BasicScan(Basic):
    """ The basic continuous time recurrent neural network model implemented
    with `tf.scan <https://www.tensorflow.org/api_docs/python/tf/scan>`_ .

    Produces the same results as :class:`Basic`, with possible differences in
    execution time.

    Args:
        params (dict): See :class:`psychrnn.backend.rnn.RNN` for details.
    """

    def recurrent_timestep(self, state, rnn_in):
        """ Wrapper function for :func:`psychrnn.backend.models.basic.Basic.recurrent_timestep`.

        The argument order is swapped relative to :class:`Basic` because
        ``tf.scan`` passes the accumulator (state) first and the sequence
        element (input) second.

        Arguments:
            state (*tf.Tensor(dtype=float, shape=(* :attr:`N_batch` , :attr:`N_rec` *))*): State of network at previous time point.
            rnn_in (*tf.Tensor(dtype=float, shape=(?*, :attr:`N_in` *))*): Input to the rnn at a certain time point.

        Returns:
            new_state (*tf.Tensor(dtype=float, shape=(* :attr:`N_batch` , :attr:`N_rec` *))*): New state of the network.

        """
        return super(BasicScan, self).recurrent_timestep(rnn_in, state)

    def output_timestep(self, dummy, state):
        """ Wrapper function for :func:`psychrnn.backend.models.basic.Basic.output_timestep`.

        Includes additional dummy argument to facilitate
        `tf.scan <https://www.tensorflow.org/api_docs/python/tf/scan>`_.

        Arguments:
            dummy: Dummy variable provided by tf.scan. Not actually used by the function.
            state (*tf.Tensor(dtype=float, shape=(* :attr:`N_batch` , :attr:`N_rec` *))*): State of network at a given timepoint for each trial in the batch.

        Returns:
            output (*tf.Tensor(dtype=float, shape=(* :attr:`N_batch` , :attr:`N_out` *))*): Output of the network at a given timepoint for each trial in the batch.

        """
        return super(BasicScan, self).output_timestep(state)

    def forward_pass(self):
        """ Run the RNN on a batch of task inputs.

        Iterates over timesteps, running the :func:`recurrent_timestep` and
        :func:`output_timestep` via ``tf.scan``.

        Implements :func:`psychrnn.backend.rnn.RNN.forward_pass`.

        Returns:
            tuple:
            * **predictions** (*tf.Tensor(*:attr:`N_batch`, :attr:`N_steps`, :attr:`N_out` *))*) -- Network output on inputs found in self.x within the tf network.
            * **states** (*tf.Tensor(*:attr:`N_batch`, :attr:`N_steps`, :attr:`N_rec` *))*) -- State variable values over the course of the trials found in self.x within the tf network.

        """
        state = self.init_state
        # Scan over time (axis 0 after transposing x to time-major),
        # threading the state through as the accumulator.
        rnn_states = \
            tf.scan(
                self.recurrent_timestep,
                tf.transpose(a=self.x, perm=[1, 0, 2]),
                initializer=state,
                parallel_iterations=1)
        # Map each state to an output; the zeros initializer only fixes the
        # accumulator's shape/dtype — output_timestep ignores it.
        rnn_outputs = \
            tf.scan(
                self.output_timestep,
                rnn_states,
                initializer=tf.zeros([self.N_batch, self.N_out]),
                parallel_iterations=1)

        # Scan results are time-major; transpose back to (N_batch, N_steps, dim).
        return tf.transpose(a=rnn_outputs, perm=[1, 0, 2]), \
            tf.transpose(a=rnn_states, perm=[1, 0, 2])