from __future__ import division
import numpy as np
import tensorflow as tf
from warnings import warn
tf.compat.v1.disable_eager_execution()
[docs]class WeightInitializer(object):
""" Base Weight Initialization class.
Initializes biological constraints and network weights, optionally loading weights from a file or from passed in arrays.
Keyword Arguments:
N_in (int): The number of network inputs.
N_rec (int): The number of recurrent units in the network.
N_out (int): The number of network outputs.
load_weights_path (str, optional): Path to load weights from using np.load. Weights saved at that path should be in the form saved out by :func:`psychrnn.backend.rnn.RNN.save` Default: None.
input_connectivity (ndarray(dtype=float, shape=(:attr:`N_rec`, :attr:`N_in`)), optional): Connectivity mask for the input layer. 1 where connected, 0 where unconnected. Default: np.ones((:attr:`N_rec`, :attr:`N_in`)).
rec_connectivity (ndarray(dtype=float, shape=(:attr:`N_rec`, :attr:`N_rec`)), optional): Connectivity mask for the recurrent layer. 1 where connected, 0 where unconnected. Default: np.ones((:attr:`N_rec`, :attr:`N_rec`)).
output_connectivity (ndarray(dtype=float, shape=(:attr:`N_out`, :attr:`N_rec`)), optional): Connectivity mask for the output layer. 1 where connected, 0 where unconnected. Default: np.ones((:attr:`N_out`, :attr:`N_rec`)).
autapses (bool, optional): If False, self connections are not allowed in N_rec, and diagonal of :data:`rec_connectivity` will be set to 0. Default: True.
dale_ratio (float, optional): Dale's ratio, used to construct Dale_rec and Dale_out. 0 <= dale_ratio <=1 if dale_ratio should be used. ``dale_ratio * N_rec`` recurrent units will be excitatory, the rest will be inhibitory. Default: None
which_rand_init (str, optional): Which random initialization to use for W_in and W_out. Will also be used for W_rec if :data:`which_rand_W_rec_init` is not passed in. Options: :func:`'const_unif' <const_unif_init>`, :func:`'const_gauss' <const_gauss_init>`, :func:`'glorot_unif' <glorot_unif_init>`, :func:`'glorot_gauss' <glorot_gauss_init>`. Default: :func:`'glorot_gauss' <glorot_gauss_init>`.
which_rand_W_rec_init (str, optional): Which random initialization to use for W_rec. Options: :func:`'const_unif' <const_unif_init>`, :func:`'const_gauss' <const_gauss_init>`, :func:`'glorot_unif' <glorot_unif_init>`, :func:`'glorot_gauss' <glorot_gauss_init>`. Default: :data:`which_rand_init`.
init_minval (float, optional): Used by :func:`const_unif_init` as :attr:`minval` if ``'const_unif'`` is passed in for :data:`which_rand_init` or :data:`which_rand_W_rec_init`. Default: -.1.
init_maxval (float, optional): Used by :func:`const_unif_init` as :attr:`maxval` if ``'const_unif'`` is passed in for :data:`which_rand_init` or :data:`which_rand_W_rec_init`. Default: .1.
W_in (ndarray(dtype=float, shape=(:attr:`N_rec`, :attr:`N_in` )), optional): Input weights. Default: Initialized using the function indicated by :data:`which_rand_init`
W_rec (ndarray(dtype=float, shape=(:attr:`N_rec`, :attr:`N_rec` )), optional): Recurrent weights. Default: Initialized using the function indicated by :data:`which_rand_W_rec_init`
W_out (ndarray(dtype=float, shape=(:attr:`N_out`, :attr:`N_rec` )), optional): Output weights. Defualt: Initialized using the function indicated by :data:`which_rand_init`
b_rec (ndarray(dtype=float, shape=(:attr:`N_rec`, )), optional): Recurrent bias. Default: np.zeros(:attr:`N_rec`)
b_out (ndarray(dtype=float, shape=(:attr:`N_out`, )), optional): Output bias. Default: np.zeros(:attr:`N_out`)
Dale_rec (ndarray(dtype=float, shape=(:attr:`N_rec`, :attr:`N_rec`)), optional): Diagonal matrix with ones and negative ones on the diagonal. If :data:`dale_ratio` is not ``None``, indicates whether a recurrent unit is excitatory(1) or inhibitory(-1). Default: constructed based on :data:`dale_ratio`
Dale_out (ndarray(dtype=float, shape=(:attr:`N_rec`, :attr:`N_rec`)), optional): Diagonal matrix with ones and zeroes on the diagonal. If :data:`dale_ratio` is not ``None``, indicates whether a recurrent unit is excitatory(1) or inhibitory(0). Inhibitory neurons do not contribute to the output. Default: constructed based on :data:`dale_ratio`
init_state (ndarray(dtype=float, shape=(1, :attr:`N_rec` )), optional): Initial state of the network's recurrent units. Default: .1 + .01 * np.random.randn(:data:`N_rec` ).
Attributes:
initializations (dict): Dictionary containing entries for :data:`input_connectivity`, :data:`rec_connectivity`, :data:`output_connectivity`,:data:`transfer_function`, :data:`dale_ratio`, :data:`Dale_rec`, :data:`Dale_out`, :data:`W_in`, :data:`W_rec`, :data:`W_out`, :data:`b_rec`, :data:`b_out`, and :data:`init_state`.
"""
def __init__(self, **kwargs):
# ----------------------------------
# Required parameters
# ----------------------------------
self.load_weights_path = kwargs.get('load_weights_path', None)
N_in = self.N_in = kwargs.get('N_in')
N_rec = self.N_rec = kwargs.get('N_rec')
N_out = self.N_out = kwargs.get('N_out')
self.autapses = kwargs.get('autapses', True)
self.initializations = dict()
if self.load_weights_path is not None:
# ----------------------------------
# Load saved weights
# ----------------------------------
self.initializations = dict(np.load(self.load_weights_path, allow_pickle = True))
if 'dale_ratio' in self.initializations.keys():
if type(self.initializations['dale_ratio']) == np.ndarray:
self.initializations['dale_ratio'] = self.initializations['dale_ratio'].item()
else:
warn("You are loading weights from a model trained with an old version (<1.0). Dale's formatting has changed. Dale's rule will not be applied even if the model was previously trained using Dale's. To change this behavior, add the correct dale ratio to the 'dale_ratio' field to the file that weights are being loaded from, " + self.load_weights_path + ".")
self.initializations['dale_ratio'] = None;
if 'transfer_function' in self.initializations.keys():
self.initializations['transfer_function'] = self.initializations['transfer_function'].item()
else:
warn("You are loading weights from a model trained with an old version (<=1.0). Transfer function formatting has changed--if your model was previously trained with a custom transfer function, with your old format you need to pass in the `transfer_function` paramter to ensure the same model behavior as when trained. To automatically use the correct transfer function, add the correct transfer function under the key 'transfer_function' to the file that weights are being loaded from, " + self.load_weights_path + ".")
self.initializations['transfer_function'] = kwargs.get('transfer_function');
else:
if kwargs.get('W_rec', None) is None and type(self).__name__=='WeightInitializer':
warn("This network may not train since the eigenvalues of W_rec are not regulated in any way.")
# ----------------------------------
# Optional Parameters
# ----------------------------------
self.rand_init = kwargs.get('which_rand_init', 'glorot_gauss')
self.rand_W_rec_init = self.get_rand_init_func(kwargs.get('which_rand_W_rec_init', self.rand_init))
self.rand_init = self.get_rand_init_func(self.rand_init)
if self.rand_init == self.const_unif_init or self.rand_W_rec_init == self.const_unif_init:
self.init_minval = kwargs.get('init_minval', -.1)
self.init_maxval = kwargs.get('init_maxval', .1)
# ----------------------------------
# Biological Constraints
# ----------------------------------
# Connectivity constraints
self.initializations['input_connectivity'] = kwargs.get('input_connectivity',np.ones([N_rec, N_in]))
assert(self.initializations['input_connectivity'].shape == (N_rec, N_in))
self.initializations['rec_connectivity'] = kwargs.get('rec_connectivity',np.ones([N_rec, N_rec]))
assert(self.initializations['rec_connectivity'].shape == (N_rec, N_rec))
self.initializations['output_connectivity'] = kwargs.get('output_connectivity', np.ones([N_out, N_rec]))
assert(self.initializations['output_connectivity'].shape == (N_out, N_rec))
# Autapses constraint
if not self.autapses:
self.initializations['rec_connectivity'][np.eye(N_rec) == 1] = 0
# Dale's constraint
self.initializations['dale_ratio'] = dale_ratio = kwargs.get('dale_ratio', None)
if type(self.initializations['dale_ratio']) == np.ndarray:
self.initializations['dale_ratio'] = dale_ratio = self.initializations['dale_ratio'].item()
if dale_ratio is not None and (dale_ratio <0 or dale_ratio > 1):
print("Need 0 <= dale_ratio <= 1. dale_ratio was: " + str(dale_ratio))
raise
dale_vec = np.ones(N_rec)
if dale_ratio is not None:
dale_vec[int(dale_ratio * N_rec):] = -1
dale_rec = np.diag(dale_vec)
dale_vec[int(dale_ratio * N_rec):] = 0
dale_out = np.diag(dale_vec)
else:
dale_rec = np.diag(dale_vec)
dale_out = np.diag(dale_vec)
self.initializations['Dale_rec'] = kwargs.get('Dale_rec', dale_rec)
assert(self.initializations['Dale_rec'].shape == (N_rec, N_rec))
self.initializations['Dale_out'] = kwargs.get('Dale_out', dale_out)
assert(self.initializations['Dale_out'].shape == (N_rec, N_rec))
# ----------------------------------
# Default initializations / optional loading from params
# ----------------------------------
self.initializations['transfer_function'] = kwargs.get('transfer_function', tf.nn.relu)
self.initializations['W_in'] = kwargs.get('W_in', self.rand_init(self.initializations['input_connectivity']))
assert(self.initializations['W_in'].shape == (N_rec, N_in))
self.initializations['W_out'] = kwargs.get('W_out', self.rand_init(self.initializations['output_connectivity']))
assert(self.initializations['W_out'].shape == (N_out, N_rec))
self.initializations['W_rec'] = kwargs.get('W_rec', self.rand_W_rec_init(self.initializations['rec_connectivity']))
assert(self.initializations['W_rec'].shape == (N_rec, N_rec))
self.initializations['b_rec'] = kwargs.get('b_rec',np.zeros(N_rec))
assert(self.initializations['b_rec'].shape == (N_rec,))
self.initializations['b_out'] = kwargs.get('b_out',np.zeros(N_out))
assert(self.initializations['b_out'].shape == (N_out,))
self.initializations['init_state'] = kwargs.get('init_state', .1 + .01 * np.random.randn(N_rec))
assert(self.initializations['init_state'].size == N_rec)
return
[docs] def get_rand_init_func(self, which_rand_init):
"""Maps initialization function names (strings) to generating functions.
Arguments:
which_rand_init (str): Maps to ``[which_rand_init]_init``. Options are :func:`'const_unif' <const_unif_init>`, :func:`'const_gauss' <const_gauss_init>`, :func:`'glorot_unif' <glorot_unif_init>`, :func:`'glorot_gauss' <glorot_gauss_init>`.
Returns:
function: ``self.[which_rand_init]_init``
"""
mapping = {
'const_unif': self.const_unif_init,
'const_gauss': self.const_gauss_init,
'glorot_unif': self.glorot_unif_init,
'glorot_gauss': self.glorot_gauss_init}
return mapping[which_rand_init]
[docs] def const_gauss_init(self, connectivity):
""" Initialize ndarray of shape :data:`connectivity` with values from a normal distribution.
Arguments:
connectivity (ndarray): 1 where connected, 0 where unconnected.
Returns:
ndarray(dtype=float, shape=connectivity.shape)
"""
return np.random.randn(connectivity.shape[0], connectivity.shape[1])
[docs] def const_unif_init(self, connectivity):
""" Initialize ndarray of shape :data:`connectivity` with values uniform distribution with minimum :data:`init_minval` and maximum :data:`init_maxval` as set in :class:`WeightInitializer`.
Arguments:
connectivity (ndarray): 1 where connected, 0 where unconnected.
Returns:
ndarray(dtype=float, shape=connectivity.shape)
"""
minval = self.init_minval
maxval = self.init_maxval
return (maxval-minval) * np.random.rand(connectivity.shape[0], connectivity.shape[1]) + minval
[docs] def glorot_unif_init(self, connectivity):
""" Initialize ndarray of shape :data:`connectivity` with values from a glorot uniform distribution.
Draws samples from a uniform distribution within [-limit, limit] where `limit`
is `sqrt(6 / (fan_in + fan_out))` where `fan_in` is the number of input units and `fan_out` is the number of output units. Respects the :data:`connectivity` matrix.
Arguments:
connectivity (ndarray): 1 where connected, 0 where unconnected.
Returns:
ndarray(dtype=float, shape=connectivity.shape)
"""
init = np.zeros(connectivity.shape)
fan_in = np.sum(connectivity, axis = 1)
init += np.tile(fan_in, (connectivity.shape[1],1)).T
fan_out = np.sum(connectivity, axis = 0)
init += np.tile(fan_out, (connectivity.shape[0],1))
return np.random.uniform(-np.sqrt(6/init), np.sqrt(6/init))
[docs] def glorot_gauss_init(self, connectivity):
""" Initialize ndarray of shape :data:`connectivity` with values from a glorot normal distribution.
Draws samples from a normal distribution centered on 0 with `stddev
= sqrt(2 / (fan_in + fan_out))` where `fan_in` is the number of input units and `fan_out` is the number of output units. Respects the :data:`connectivity` matrix.
Arguments:
connectivity (ndarray): 1 where connected, 0 where unconnected.
Returns:
ndarray(dtype=float, shape=connectivity.shape)
"""
init = np.zeros(connectivity.shape)
fan_in = np.sum(connectivity, axis = 1)
init += np.tile(fan_in, (connectivity.shape[1],1)).T
fan_out = np.sum(connectivity, axis = 0)
init += np.tile(fan_out, (connectivity.shape[0],1))
return np.random.normal(0, np.sqrt(2/init))
[docs] def get_dale_ratio(self):
""" Returns the dale_ratio.
:math:`0 \\leq dale\\_ratio \\leq 1` if dale_ratio should be used, dale_ratio = None otherwise. ``dale_ratio * N_rec`` recurrent units will be excitatory, the rest will be inhibitory.
Returns:
float: Dale ratio, None if no dale ratio is set.
"""
return self.initializations['dale_ratio']
[docs] def get_transfer_function(self):
""" Returns the transfer function.
Returns:
function: transfer function, Default: `tf.nn.relu <https://www.tensorflow.org/api_docs/python/tf/nn/relu>`_.
"""
return self.initializations['transfer_function']
[docs] def get(self, tensor_name):
""" Get :data:`tensor_name` from :attr:`initializations` as a Tensor.
Arguments:
tensor_name (str): The name of the tensor to get. See :attr:`initializations` for options.
Returns:
Tensor object
"""
return tf.compat.v1.constant_initializer(self.initializations[tensor_name])
[docs] def save(self, save_path):
""" Save :attr:`initializations` to :data:`save_path`.
Arguments:
save_path (str): File path for saving the initializations. The .npz extension will be appended if not already provided.
"""
np.savez(save_path, **self.initializations)
return
[docs] def balance_dale_ratio(self):
""" If dale_ratio is not None, balances :attr:`initializations['W_rec'] <initializations>` 's excitatory and inhibitory weights so the network will train.
"""
dale_ratio = self.get_dale_ratio()
if dale_ratio is not None:
dale_vec = np.ones(self.N_rec)
dale_vec[int(dale_ratio * self.N_rec):] = dale_ratio/(1-dale_ratio)
dale_rec = np.diag(dale_vec) / np.linalg.norm(np.matmul(self.initializations['rec_connectivity'], np.diag(dale_vec)), axis=1)[:,np.newaxis]
self.initializations['W_rec'] = np.matmul(self.initializations['W_rec'], dale_rec)
return
[docs]class GaussianSpectralRadius(WeightInitializer):
"""Generate random gaussian weights with specified spectral radius.
If Dale is set, balances the random gaussian weights between excitatory and inhibitory using :func:`balance_dale_ratio` before applying the specified spectral radius.
Keyword Args:
spec_rad (float, optional): The spectral radius to initialize W_rec with. Default: 1.1.
Other Keyword Args:
See :class:`~psychrnn.backend.initializations.WeightInitializer` for details.
"""
def __init__(self, **kwargs):
super(GaussianSpectralRadius, self).__init__(**kwargs)
self.spec_rad = kwargs.get('spec_rad', 1.1)
self.initializations['W_rec'] = np.random.randn(self.N_rec, self.N_rec)
# balance weights for dale ratio training to proceed normally
self.balance_dale_ratio()
self.initializations['W_rec'] = self.spec_rad * self.initializations['W_rec'] / np.max(np.abs(np.linalg.eig(self.initializations['W_rec'])[0]))
return
[docs]class AlphaIdentity(WeightInitializer):
'''Generate recurrent weights :math:`w(i,i) = alpha`, :math:`w(i,j) = 0` where :math:`i \\neq j`.
If Dale is set, balances the alpha excitatory and inhibitory weights using :func:`~psychrnn.backend.initializations.WeightInitializer.balance_dale_ratio`, so w(i,i) will not be exactly equal to alpha.
Keyword Args:
alpha (float): The value of alpha to set w(i,i) to in W_rec.
Other Keyword Args:
See :class:`WeightInitializer` for details.
'''
def __init__(self, **kwargs):
super(AlphaIdentity, self).__init__(**kwargs)
self.alpha = kwargs.get('alpha')
self.initializations['W_rec'] = np.eye(self.N_rec) * self.alpha
# balance weights for dale ratio training to proceed normally
self.balance_dale_ratio()
return