Source code for psychrnn.backend.simulation

from __future__ import division

import numpy as np

from abc import ABCMeta, abstractmethod

# abstract class python 2 & 3 compatible
ABC = ABCMeta('ABC', (object,), {})

def relu(x):
    """NumPy implementation of `tf.nn.relu <https://www.tensorflow.org/api_docs/python/tf/nn/relu>`_

    Arguments:
        x (ndarray): array for which relu is computed.

    Returns:
        ndarray: np.maximum(x,0)

    """
    return np.maximum(x, 0)
def sigmoid(x):
    """NumPy implementation of `tf.nn.sigmoid <https://www.tensorflow.org/api_docs/python/tf/math/sigmoid>`_

    Arguments:
        x (ndarray): array for which sigmoid is computed.

    Returns:
        ndarray: 1/(1 + np.exp(-x))

    """
    return 1 / (1 + np.exp(-x))
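
# ----------------------------------------------------------------------
# Illustrative example (not part of the original module): a minimal
# sanity check that the NumPy helpers above implement the documented
# formulas. The function name and test values are arbitrary assumptions
# chosen for demonstration only.
# ----------------------------------------------------------------------
def _demo_transfer_functions():
    x = np.array([-2.0, 0.0, 3.0])
    assert np.array_equal(relu(x), np.array([0.0, 0.0, 3.0]))   # max(x, 0)
    assert np.allclose(sigmoid(x), 1.0 / (1.0 + np.exp(-x)))    # logistic
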
class Simulator(ABC):
    """The simulator class.

    Note:
        The base Simulator class is not itself a functioning Simulator.
        run_trials and rnn_step must be implemented to define a functioning Simulator.

    Args:
        rnn_model (:class:`psychrnn.backend.rnn.RNN` object, optional): Uses the :class:`psychrnn.backend.rnn.RNN` object to set :attr:`alpha` and :attr:`rec_noise`. Also used to initialize weights if :data:`weights` and :data:`weights_path` are not passed in. Default: None.
        weights_path (str, optional): Where to load weights from. Takes precedence over rnn_model weights. Default: :data:`rnn_model.get_weights() <rnn_model>`. np.load(:data:`weights_path`) should return something of the form :data:`weights`.
        transfer_function (function, optional): Function that takes an ndarray as input and outputs an ndarray of the same shape with the transfer / activation function applied. NumPy implementation of a TensorFlow transfer function. Default: :func:`relu`.
        weights (dict, optional): Takes precedence over both weights_path and rnn_model. Default: np.load(:data:`weights_path`). Dictionary containing the following keys:

            :Dictionary Keys:
                * **init_state** (*ndarray(dtype=float, shape=(1,* :attr:`N_rec` *))*) -- Initial state of the network's recurrent units.
                * **W_in** (*ndarray(dtype=float, shape=(*:attr:`N_rec`, :attr:`N_in` *))*) -- Input weights.
                * **W_rec** (*ndarray(dtype=float, shape=(*:attr:`N_rec`, :attr:`N_rec` *))*) -- Recurrent weights.
                * **W_out** (*ndarray(dtype=float, shape=(*:attr:`N_out`, :attr:`N_rec` *))*) -- Output weights.
                * **b_rec** (*ndarray(dtype=float, shape=(*:attr:`N_rec`, *))*) -- Recurrent bias.
                * **b_out** (*ndarray(dtype=float, shape=(*:attr:`N_out`, *))*) -- Output bias.

        params (dict, optional):

            :Dictionary Keys:
                * **rec_noise** (*float, optional*) -- Amount of recurrent noise to add to the network. Default: 0
                * **alpha** (*float, optional*) -- The number of unit time constants per simulation timestep. Default: (1.0 * dt) / tau
                * **dt** (*float, optional*) -- The simulation timestep. Used to calculate alpha if alpha is not passed in. Required if alpha is not in params and rnn_model is None.
                * **tau** (*float*) -- The intrinsic time constant of neural state decay. Used to calculate alpha if alpha is not passed in. Required if alpha is not in params and rnn_model is None.

    """
    def __init__(self, rnn_model=None, params=None, weights_path=None, weights=None, transfer_function=relu):

        # ----------------------------------
        # Extract params
        # ----------------------------------
        self.transfer_function = transfer_function
        if rnn_model is not None:
            self.alpha = rnn_model.alpha
            self.rec_noise = rnn_model.rec_noise
            if rnn_model.transfer_function.__name__ != self.transfer_function.__name__:
                raise UserWarning("The rnn_model transfer function is " + str(rnn_model.transfer_function)
                                  + " and the current transfer function is " + str(self.transfer_function)
                                  + ". You should make sure these functions do the same thing -- their names do not match.")
            if params is not None:
                raise UserWarning("params was passed in but will not be used. rnn_model takes precedence.")
        else:
            self.rec_noise = params.get('rec_noise', 0)
            if params.get('alpha') is not None:
                self.alpha = params['alpha']
            else:
                dt = params['dt']
                tau = params['tau']
                self.alpha = (1.0 * dt) / tau

        # ----------------------------------
        # Initialize weights
        # ----------------------------------
        self.weights = weights
        if weights is not None:
            if weights_path is not None or rnn_model is not None:
                raise UserWarning("Weights and either rnn_model or weights_path were passed in. Weights from rnn_model and weights_path will be ignored.")
        elif weights_path is not None:
            if rnn_model is not None:
                raise UserWarning("rnn_model and weights_path were both passed in. Weights from rnn_model will be ignored.")
            self.weights = np.load(weights_path)
        elif rnn_model is not None:
            self.weights = rnn_model.get_weights()
        else:
            raise UserWarning("Either weights, rnn_model, or weights_path must be passed in.")

        self.W_in = self.weights['W_in']
        self.W_rec = self.weights['W_rec']
        self.W_out = self.weights['W_out']
        self.b_rec = self.weights['b_rec']
        self.b_out = self.weights['b_out']
        self.init_state = self.weights['init_state']

    # ----------------------------------------------
    # t_connectivity allows for ablation experiments
    # ----------------------------------------------
    @abstractmethod
    def rnn_step(self, state, rnn_in, t_connectivity):
        """Given input and previous state, outputs the next state and output of the network.

        Note:
            This is an abstract function that must be defined in a child class.

        Arguments:
            state (ndarray(dtype=float, shape=(:attr:`N_batch`, :attr:`N_rec`))): State of network at previous time point.
            rnn_in (ndarray(dtype=float, shape=(:attr:`N_batch`, :attr:`N_in`))): Input to the network at the current time point.
            t_connectivity (ndarray(dtype=float, shape=(:attr:`N_rec`, :attr:`N_rec`))): Matrix for ablating / perturbing W_rec.

        Returns:
            tuple:
            * **new_output** (*ndarray(dtype=float, shape=(*:attr:`N_batch`, :attr:`N_out` *))*) -- Output of the network at a given timepoint for each trial in the batch.
            * **new_state** (*ndarray(dtype=float, shape=(*:attr:`N_batch`, :attr:`N_rec` *))*) -- New state of the network for each trial in the batch.

        """
        pass
    @abstractmethod
    def run_trials(self, trial_input, t_connectivity=None):
        """Test the network on a certain task input, optionally including ablation terms.

        A NumPy implementation of :func:`~psychrnn.backend.rnn.RNN.test` with additional options for ablation.

        N_batch here is flexible and will be inferred from trial_input.

        Arguments:
            trial_input (*ndarray(dtype=float, shape=(*:attr:`N_batch`, :attr:`N_steps`, :attr:`N_in` *))*): Task stimulus to run the network on. Stimulus from :func:`psychrnn.tasks.task.Task.get_trial_batch`, or from next(:func:`psychrnn.tasks.task.Task.batch_generator`). If you want the network to run autonomously, without input, set the input to an array of zeros; N_steps will still indicate how long to run the network.
            t_connectivity (*ndarray(dtype=float, shape=(*:attr:`N_steps`, :attr:`N_rec`, :attr:`N_rec` *))*, optional): Matrix for ablating / perturbing W_rec. Passed step by step to rnn_step.

        Returns:
            tuple:
            * **outputs** (*ndarray(dtype=float, shape=(*:attr:`N_batch`, :attr:`N_steps`, :attr:`N_out` *))*) -- Output time series of the network for each trial in the batch.
            * **states** (*ndarray(dtype=float, shape=(*:attr:`N_batch`, :attr:`N_steps`, :attr:`N_rec` *))*) -- Activity of recurrent units during each trial.

        """
        pass
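
# ----------------------------------------------------------------------
# Illustrative sketch (not part of the original module): the smallest
# possible Simulator subclass, showing the interface the two abstract
# methods above require. The noise-free linear step used here is an
# arbitrary assumption for demonstration; it is not one of PsychRNN's
# shipped models (see BasicSimulator below for the real implementation).
# ----------------------------------------------------------------------
class _DemoLeakySimulator(Simulator):
    def rnn_step(self, state, rnn_in, t_connectivity):
        # Leaky update with elementwise ablation of W_rec by t_connectivity.
        new_state = (1 - self.alpha) * state + self.alpha * (
            np.matmul(state, np.transpose(self.W_rec * t_connectivity))
            + np.matmul(rnn_in, np.transpose(self.W_in))
            + self.b_rec)
        new_output = np.matmul(new_state, np.transpose(self.W_out)) + self.b_out
        return new_output, new_state

    def run_trials(self, trial_input, t_connectivity=None):
        batch_size, n_steps, _ = trial_input.shape
        state = np.repeat(self.init_state, batch_size, 0)
        outputs, states = [], []
        for t in range(n_steps):
            # Full connectivity (all ones) leaves W_rec untouched.
            t_conn = np.ones_like(self.W_rec) if t_connectivity is None else t_connectivity[t]
            output, state = self.rnn_step(state, trial_input[:, t, :], t_conn)
            outputs.append(output)
            states.append(state)
        # Stack per-timestep results into (N_batch, N_steps, ...) arrays.
        return np.stack(outputs, 1), np.stack(states, 1)
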
class BasicSimulator(Simulator):
    """:class:`Simulator` implementation for :class:`psychrnn.backend.models.basic.Basic` and for :class:`psychrnn.backend.models.basic.BasicScan`.

    See :class:`Simulator` for arguments.
    """
    def rnn_step(self, state, rnn_in, t_connectivity):
        """Given input and previous state, outputs the next state and output of the network as a NumPy implementation of :class:`psychrnn.backend.models.basic.Basic.recurrent_timestep` and of :class:`psychrnn.backend.models.basic.Basic.output_timestep`.

        Additionally takes in :data:`t_connectivity`. If :data:`t_connectivity` is all ones, :func:`rnn_step`'s output will match that of :class:`psychrnn.backend.models.basic.Basic.recurrent_timestep` and :class:`psychrnn.backend.models.basic.Basic.output_timestep`. Otherwise :data:`W_rec` is multiplied elementwise by :data:`t_connectivity`, ablating / perturbing the recurrent connectivity.

        Arguments:
            state (ndarray(dtype=float, shape=(:attr:`N_batch`, :attr:`N_rec`))): State of network at previous time point.
            rnn_in (ndarray(dtype=float, shape=(:attr:`N_batch`, :attr:`N_in`))): Input to the network at the current time point.
            t_connectivity (ndarray(dtype=float, shape=(:attr:`N_rec`, :attr:`N_rec`))): Matrix for ablating / perturbing W_rec.

        Returns:
            tuple:
            * **new_output** (*ndarray(dtype=float, shape=(*:attr:`N_batch`, :attr:`N_out` *))*) -- Output of the network at a given timepoint for each trial in the batch.
            * **new_state** (*ndarray(dtype=float, shape=(*:attr:`N_batch`, :attr:`N_rec` *))*) -- New state of the network for each trial in the batch.

        """
        # Leaky integration of the recurrent state, with elementwise ablation
        # of W_rec by t_connectivity and additive Gaussian recurrent noise.
        new_state = ((1 - self.alpha) * state) \
            + self.alpha * (
                np.matmul(self.transfer_function(state), np.transpose(self.W_rec * t_connectivity))
                + np.matmul(rnn_in, np.transpose(self.W_in))
                + self.b_rec) \
            + np.sqrt(2.0 * self.alpha * self.rec_noise * self.rec_noise) \
            * np.random.normal(loc=0.0, scale=1.0, size=state.shape)

        new_output = np.matmul(self.transfer_function(new_state), np.transpose(self.W_out)) + self.b_out

        return new_output, new_state
    def run_trials(self, trial_input, t_connectivity=None):
        """Test the network on a certain task input, optionally including ablation terms.

        A NumPy implementation of :func:`~psychrnn.backend.rnn.RNN.test` with additional options for ablation.

        N_batch here is flexible and will be inferred from trial_input.

        Repeatedly calls :func:`rnn_step` to build output and states over the entire timecourse of the :data:`trial_input`.

        Arguments:
            trial_input (*ndarray(dtype=float, shape=(*:attr:`N_batch`, :attr:`N_steps`, :attr:`N_in` *))*): Task stimulus to run the network on. Stimulus from :func:`psychrnn.tasks.task.Task.get_trial_batch`, or from next(:func:`psychrnn.tasks.task.Task.batch_generator`). To run the network autonomously without input, set the input to an array of zeros; N_steps will still indicate for how many steps to run the network.
            t_connectivity (*ndarray(dtype=float, shape=(*:attr:`N_steps`, :attr:`N_rec`, :attr:`N_rec` *))*, optional): Matrix for ablating / perturbing W_rec. Passed step by step to :func:`rnn_step`.

        Returns:
            tuple:
            * **outputs** (*ndarray(dtype=float, shape=(*:attr:`N_batch`, :attr:`N_steps`, :attr:`N_out` *))*) -- Output time series of the network for each trial in the batch.
            * **states** (*ndarray(dtype=float, shape=(*:attr:`N_batch`, :attr:`N_steps`, :attr:`N_rec` *))*) -- Activity of recurrent units during each trial.

        """
        batch_size = trial_input.shape[0]
        # Split the input into per-timestep slices of shape (N_batch, N_in).
        rnn_inputs = np.squeeze(np.split(trial_input, trial_input.shape[1], axis=1))
        state = np.expand_dims(self.init_state[0, :], 0)
        state = np.repeat(state, batch_size, 0)
        rnn_outputs = []
        rnn_states = []
        for i, rnn_input in enumerate(rnn_inputs):
            if t_connectivity is not None:
                output, state = self.rnn_step(state, rnn_input, t_connectivity[i])
            else:
                output, state = self.rnn_step(state, rnn_input, np.ones_like(self.W_rec))
            rnn_outputs.append(output)
            rnn_states.append(state)

        return np.swapaxes(np.array(rnn_outputs), 0, 1), np.swapaxes(np.array(rnn_states), 0, 1)
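
# ----------------------------------------------------------------------
# Illustrative example (not part of the original module): building a
# BasicSimulator from an in-memory weights dict and running it on a
# batch of trials with recurrent unit 0 ablated at every timestep. The
# sizes, the simple weight values, and the choice of unit 0 are all
# arbitrary assumptions; in practice the weights would come from a
# trained model via rnn_model or weights_path.
# ----------------------------------------------------------------------
def _demo_basic_simulator():
    N_batch, N_steps, N_in, N_rec, N_out = 10, 50, 2, 4, 1
    weights = {
        'init_state': np.zeros((1, N_rec)),
        'W_in': np.ones((N_rec, N_in)),
        'W_rec': 0.1 * np.eye(N_rec),
        'W_out': np.ones((N_out, N_rec)),
        'b_rec': np.zeros(N_rec),
        'b_out': np.zeros(N_out),
    }
    # alpha defaults to (1.0 * dt) / tau when not passed in directly.
    sim = BasicSimulator(weights=weights, params={'dt': 10, 'tau': 100, 'rec_noise': 0.0})

    trial_input = np.random.rand(N_batch, N_steps, N_in)
    # Zeroing column 0 of each timestep's connectivity matrix silences
    # unit 0's outgoing recurrent connections (W_rec[i, 0] is the weight
    # from unit 0 to unit i).
    t_connectivity = np.ones((N_steps, N_rec, N_rec))
    t_connectivity[:, :, 0] = 0

    outputs, states = sim.run_trials(trial_input, t_connectivity=t_connectivity)
    assert outputs.shape == (N_batch, N_steps, N_out)
    assert states.shape == (N_batch, N_steps, N_rec)
    return outputs, states
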
class LSTMSimulator(Simulator):
    """:class:`Simulator` implementation for :class:`psychrnn.backend.models.lstm.LSTM`.

    See :class:`Simulator` for arguments.

    The contents of weights / np.load(weights_path) must also include the following additional keys:

    :Dictionary Keys:
        * **init_cell** (*ndarray(dtype=float, shape=(*:attr:`N_batch`, :attr:`N_rec` *))*) -- Initial state of the cell state.
        * **init_hidden** (*ndarray(dtype=float, shape=(*:attr:`N_batch`, :attr:`N_rec` *))*) -- Initial state of the hidden state.
        * **W_f** (*ndarray(dtype=float, shape=(*:attr:`N_rec` + :attr:`N_in`, :attr:`N_rec` *))*) -- f term weights
        * **W_i** (*ndarray(dtype=float, shape=(*:attr:`N_rec` + :attr:`N_in`, :attr:`N_rec` *))*) -- i term weights
        * **W_c** (*ndarray(dtype=float, shape=(*:attr:`N_rec` + :attr:`N_in`, :attr:`N_rec` *))*) -- c term weights
        * **W_o** (*ndarray(dtype=float, shape=(*:attr:`N_rec` + :attr:`N_in`, :attr:`N_rec` *))*) -- o term weights
        * **b_f** (*ndarray(dtype=float, shape=(*:attr:`N_rec`, *))*) -- f term bias.
        * **b_i** (*ndarray(dtype=float, shape=(*:attr:`N_rec`, *))*) -- i term bias.
        * **b_c** (*ndarray(dtype=float, shape=(*:attr:`N_rec`, *))*) -- c term bias.
        * **b_o** (*ndarray(dtype=float, shape=(*:attr:`N_rec`, *))*) -- o term bias.

    """
    def __init__(self, rnn_model=None, params=None, weights_path=None, weights=None):
        super(LSTMSimulator, self).__init__(rnn_model=rnn_model, params=params, weights_path=weights_path, weights=weights)

        self.init_hidden = self.weights['init_hidden']
        self.init_cell = self.weights['init_cell']

        self.W_f = self.weights['W_f']
        self.W_i = self.weights['W_i']
        self.W_c = self.weights['W_c']
        self.W_o = self.weights['W_o']

        self.b_f = self.weights['b_f']
        self.b_i = self.weights['b_i']
        self.b_c = self.weights['b_c']
        self.b_o = self.weights['b_o']
    def rnn_step(self, hidden, cell, rnn_in):
        """Given input and previous state, outputs the next state and output of the network as a NumPy implementation of :class:`psychrnn.backend.models.lstm.LSTM.recurrent_timestep` and of :class:`psychrnn.backend.models.lstm.LSTM.output_timestep`.

        Arguments:
            hidden (ndarray(dtype=float, shape=(:attr:`N_batch`, :attr:`N_rec`))): Hidden units state of network at previous time point.
            cell (ndarray(dtype=float, shape=(:attr:`N_batch`, :attr:`N_rec`))): Cell state of the network at previous time point.
            rnn_in (ndarray(dtype=float, shape=(:attr:`N_batch`, :attr:`N_in`))): Input to the network at the current time point.

        Returns:
            tuple:
            * **new_output** (*ndarray(dtype=float, shape=(*:attr:`N_batch`, :attr:`N_out` *))*) -- Output of the network at a given timepoint for each trial in the batch.
            * **new_hidden** (*ndarray(dtype=float, shape=(*:attr:`N_batch`, :attr:`N_rec` *))*) -- New hidden unit state of the network.
            * **new_cell** (*ndarray(dtype=float, shape=(*:attr:`N_batch`, :attr:`N_rec` *))*) -- New cell state of the network.

        """
        # Forget, input, candidate-cell, and output gates, each computed from
        # the concatenated previous hidden state and current input.
        f = sigmoid(np.matmul(np.concatenate([hidden, rnn_in], 1), self.W_f) + self.b_f)
        i = sigmoid(np.matmul(np.concatenate([hidden, rnn_in], 1), self.W_i) + self.b_i)
        c = np.tanh(np.matmul(np.concatenate([hidden, rnn_in], 1), self.W_c) + self.b_c)
        o = sigmoid(np.matmul(np.concatenate([hidden, rnn_in], 1), self.W_o) + self.b_o)

        new_cell = f * cell + i * c
        # Note: sigmoid is applied to the cell state here (rather than the
        # conventional tanh), matching the TensorFlow model this class mirrors.
        new_hidden = o * sigmoid(new_cell)

        new_output = np.matmul(new_hidden, np.transpose(self.W_out)) + self.b_out

        return new_output, new_hidden, new_cell
    def run_trials(self, trial_input):
        """Test the network on a certain task input.

        A NumPy implementation of :func:`~psychrnn.backend.rnn.RNN.test`.

        N_batch here is flexible and will be inferred from trial_input.

        Repeatedly calls :func:`rnn_step` to build output and states over the entire timecourse of the :data:`trial_input`.

        Arguments:
            trial_input (*ndarray(dtype=float, shape=(*:attr:`N_batch`, :attr:`N_steps`, :attr:`N_in` *))*): Task stimulus to run the network on. Stimulus from :func:`psychrnn.tasks.task.Task.get_trial_batch`, or from next(:func:`psychrnn.tasks.task.Task.batch_generator`). To run the network autonomously without input, set the input to an array of zeros; N_steps will still indicate for how many steps to run the network.

        Returns:
            tuple:
            * **outputs** (*ndarray(dtype=float, shape=(*:attr:`N_batch`, :attr:`N_steps`, :attr:`N_out` *))*) -- Output time series of the network for each trial in the batch.
            * **states** (*ndarray(dtype=float, shape=(*:attr:`N_batch`, :attr:`N_steps`, :attr:`N_rec` *))*) -- Activity of recurrent units during each trial.

        """
        batch_size = trial_input.shape[0]
        # Split the input into per-timestep slices of shape (N_batch, N_in).
        rnn_inputs = np.squeeze(np.split(trial_input, trial_input.shape[1], axis=1))

        cell = np.expand_dims(self.init_cell[0, :], 0)
        cell = np.repeat(cell, batch_size, 0)
        hidden = np.expand_dims(self.init_hidden[0, :], 0)
        hidden = np.repeat(hidden, batch_size, 0)

        rnn_outputs = []
        rnn_states = []
        for rnn_input in rnn_inputs:
            output, hidden, cell = self.rnn_step(hidden, cell, rnn_input)
            rnn_outputs.append(output)
            rnn_states.append(hidden)

        return np.swapaxes(np.array(rnn_outputs), 0, 1), np.swapaxes(np.array(rnn_states), 0, 1)
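
# ----------------------------------------------------------------------
# Illustrative example (not part of the original module): the extra keys
# an LSTMSimulator weights dict needs on top of the base Simulator keys,
# followed by a run on random input. Sizes and the zero-valued weights
# are arbitrary assumptions; real values would come from a trained LSTM
# model via rnn_model or weights_path.
# ----------------------------------------------------------------------
def _demo_lstm_simulator():
    N_batch, N_steps, N_in, N_rec, N_out = 10, 50, 2, 4, 1
    weights = {
        # Keys required by the base Simulator. W_in, W_rec, and b_rec are
        # not used by the LSTM step but are still read in Simulator.__init__.
        'init_state': np.zeros((1, N_rec)),
        'W_in': np.zeros((N_rec, N_in)),
        'W_rec': np.zeros((N_rec, N_rec)),
        'W_out': np.zeros((N_out, N_rec)),
        'b_rec': np.zeros(N_rec),
        'b_out': np.zeros(N_out),
        # LSTM-specific keys (only row 0 of the initial states is used).
        'init_hidden': np.zeros((1, N_rec)),
        'init_cell': np.zeros((1, N_rec)),
        'W_f': np.zeros((N_rec + N_in, N_rec)),
        'W_i': np.zeros((N_rec + N_in, N_rec)),
        'W_c': np.zeros((N_rec + N_in, N_rec)),
        'W_o': np.zeros((N_rec + N_in, N_rec)),
        'b_f': np.zeros(N_rec),
        'b_i': np.zeros(N_rec),
        'b_c': np.zeros(N_rec),
        'b_o': np.zeros(N_rec),
    }
    sim = LSTMSimulator(weights=weights, params={'dt': 10, 'tau': 100})

    outputs, states = sim.run_trials(np.random.rand(N_batch, N_steps, N_in))
    assert outputs.shape == (N_batch, N_steps, N_out)
    assert states.shape == (N_batch, N_steps, N_rec)
    return outputs, states
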