This page was generated from docs/notebooks/BiologicalConstraints.ipynb.

Biological Constraints

The default RNN network has all-to-all connectivity and allows units to have both excitatory and inhibitory connections. However, this does not reflect the biology we know. PsychRNN includes a framework for easily specifying biological constraints on the model.

This example introduces the different options for biological constraints included in PsychRNN:

- Dale Ratio
- Autapses
- Connectivity
- Fixed Weights

[2]:
import numpy as np

from matplotlib import pyplot as plt
from matplotlib.colors import Normalize
%matplotlib inline

# ---------------------- Import the package ---------------------------
from psychrnn.tasks.perceptual_discrimination import PerceptualDiscrimination
from psychrnn.backend.models.basic import Basic

# ---------------------- Set up a basic model ---------------------------
pd = PerceptualDiscrimination(dt = 10, tau = 100, T = 2000, N_batch = 128)
network_params = pd.get_task_params() # get the params passed in and defined in pd
network_params['name'] = 'model' # name the model uniquely if running multiple models in unison
network_params['N_rec'] = 50 # set the number of recurrent units in the model

# -------------------- Set up variables that will be useful later -------
N_in = network_params['N_in']
N_rec = network_params['N_rec']
N_out = network_params['N_out']

This function plots a colormap of the weights.

[3]:
def plot_weights(weights, title=""):
    cmap = plt.set_cmap('RdBu_r')
    img = plt.matshow(weights, norm=Normalize(vmin=-.5, vmax=.5))
    plt.title(title)

    plt.colorbar()

Biologically Unconstrained

[4]:
basicModel = Basic(network_params) # instantiate a basic vanilla RNN we will compare to later on
[5]:
weights = basicModel.get_weights()
plot_weights(weights['W_rec'])
../_images/notebooks_BiologicalConstraints_7_1.png
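get_weights() returns a plain dictionary of NumPy arrays, so the matrices can also be inspected directly. A minimal sketch of such an inspection (only the 'W_rec' key is used in this notebook; the other available keys are whatever the dictionary happens to contain):

# Inspect the weights dictionary returned by get_weights()
print(sorted(weights.keys()))     # list the available entries
print(weights['W_rec'].shape)     # expected: (N_rec, N_rec) = (50, 50)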
[6]:
basicModel.destruct()

Dale Ratio

Dale’s Principle states that a neuron releases the same set of neurotransmitters at each of its synapses (Eccles et al., 1954). Since neurotransmitters tend to be either excitatory or inhibitory, theorists have taken this to mean that each neuron has exclusively either excitatory or inhibitory synapses (Song et al., 2016; Rajan and Abbott, 2006).

To set the dale ratio, simply set network_params['dale_ratio'] equal to the proportion of total recurrent neurons that should be excitatory. The remainder will be inhibitory.

The dale ratio can be combined with any other parameter settings except for network_params['initializer'], in which case the dale ratio needs to be passed directly into the initializer being used. The dale ratio is not enforced if LSTM is used as the RNN implementation.

Once the model is instantiated, it can be trained and tested as demonstrated in the Simple Example.

[7]:
dale_network_params = network_params.copy()
dale_network_params['name'] = 'dales_model'
dale_network_params['dale_ratio'] = .8
daleModel = Basic(dale_network_params)
[8]:
weights = daleModel.get_weights()
plot_weights(weights['W_rec'])
../_images/notebooks_BiologicalConstraints_11_1.png
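With dale_ratio = .8 and N_rec = 50, the first 40 recurrent units should be excitatory and the last 10 inhibitory, which appears in the plot as nonnegative and nonpositive columns. A quick numerical check of the same structure, assuming (as the plot suggests) that get_weights() returns the sign-constrained W_rec with columns corresponding to presynaptic units:

# Check the sign structure imposed by the dale ratio (assumed column convention)
n_exc = int(dale_network_params['dale_ratio'] * N_rec)    # number of excitatory units
print(np.all(weights['W_rec'][:, :n_exc] >= 0))           # excitatory columns -- expect True
print(np.all(weights['W_rec'][:, n_exc:] <= 0))           # inhibitory columns -- expect True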
[9]:
daleModel.destruct()

Autapses

To disallow autapses (self-connections), simply set network_params['autapses'] = False.

The autapses parameter can be combined with any other parameter settings except for network_params['initializer'], in which case the boolean for autapses needs to be passed directly into the initializer being used. The autapses constraint is not enforced if LSTM is used as the RNN implementation.

Once the model is instantiated, it can be trained and tested as demonstrated in the Simple Example.

[10]:
autapses_network_params = network_params.copy()
autapses_network_params['name'] = 'autapses_model'
autapses_network_params['autapses'] = False
autapsesModel = Basic(autapses_network_params)
[11]:
weights = autapsesModel.get_weights()
plot_weights(weights['W_rec'])
../_images/notebooks_BiologicalConstraints_15_1.png

Notice the white line on the diagonal (self-connections) above, where the weights are 0.
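The same thing can be checked numerically, again assuming the returned W_rec is the mask-applied matrix shown above:

# With autapses disallowed, the diagonal of W_rec should be exactly zero
print(np.allclose(np.diag(weights['W_rec']), 0))   # expect True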

[12]:
autapsesModel.destruct()

Connectivity

The brain is not all-to-all connected, so it can be useful to restrict and structure the connectivity of our RNNs.

The input_connectivity, rec_connectivity, and output_connectivity parameters allow us to do just that. Any subset of them can be combined with any other parameter settings except for network_params['initializer'], in which case the connectivity matrices need to be passed directly into the initializer being used. Connectivity is not enforced if LSTM is used as the RNN implementation.

Once the model is instantiated, it can be trained and tested as demonstrated in the Simple Example.

[13]:
modular_network_params = network_params.copy()
modular_network_params['name'] = 'modular_model'

# Set connectivity matrices to the default -- fully connected
input_connectivity = np.ones((N_rec, N_in))
rec_connectivity = np.ones((N_rec, N_rec))
output_connectivity = np.ones((N_out, N_rec))

# Specify certain connections to disallow. This can be done with the input and output connectivity matrices as well.
rec_connectivity[2*(N_rec//5):4*(N_rec//5),:2*(N_rec//5)] = 0
rec_connectivity[:2*(N_rec//5),2*(N_rec//5):4*(N_rec//5)] = 0

Plot the recurrent connectivity matrix

[14]:
plot_weights(rec_connectivity, "recurrent connectivity")
../_images/notebooks_BiologicalConstraints_21_1.png

Specify the connectivity matrices in network_params.

[15]:
modular_network_params['input_connectivity'] = input_connectivity
modular_network_params['rec_connectivity'] = rec_connectivity
modular_network_params['output_connectivity'] = output_connectivity
modularModel = Basic(modular_network_params)
[16]:
weights = modularModel.get_weights()

plot_weights(weights['W_rec'])
../_images/notebooks_BiologicalConstraints_24_1.png
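The disallowed blocks should contain only zeros in the instantiated model. A quick check against the same connectivity mask, assuming the returned W_rec already has the mask applied (as the plot suggests):

# Weights at positions disallowed by rec_connectivity should be exactly zero
print(np.all(weights['W_rec'][rec_connectivity == 0] == 0))   # expect True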
[17]:
modularModel.destruct()

Fixed Weights

We may assume that some parts of the brain are less plastic than others. Alternatively, we may want to specify particular weights within the model and train the rest of the network around them.

The fixed_weights parameter of the train() function allows us to do this.

Instantiate the model

[18]:
fixed_network_params = network_params.copy()
fixed_network_params['name'] = 'fixed_model'
fixedModel = Basic(fixed_network_params) # instantiate a basic vanilla RNN whose weights we will partially fix during training

Plot the model weights before training

[19]:
weights = fixedModel.get_weights()

plot_weights(weights['W_rec'])
../_images/notebooks_BiologicalConstraints_31_1.png
[20]:
# Set fixed weight matrices to the default -- fully trainable
W_in_fixed = np.zeros((N_rec,N_in))
W_rec_fixed = np.zeros((N_rec,N_rec))
W_out_fixed = np.zeros((N_out, N_rec))

# Specify certain weights to fix (1 = fixed, 0 = trainable).
W_rec_fixed[N_rec//5*4:, :4*N_rec//5] = 1
W_rec_fixed[:4*N_rec//5, N_rec//5*4:] = 1


# Specify the fixed weights parameters in train_params
train_params = {}
train_params['fixed_weights'] = {
    'W_in': W_in_fixed,
    'W_rec': W_rec_fixed,
    'W_out': W_out_fixed
}
[21]:
losses, initialTime, trainTime = fixedModel.train(pd, train_params)
Iter 1280, Minibatch Loss= 0.177185
Iter 2560, Minibatch Loss= 0.107636
Iter 3840, Minibatch Loss= 0.099301
Iter 5120, Minibatch Loss= 0.085224
Iter 6400, Minibatch Loss= 0.082593
Iter 7680, Minibatch Loss= 0.079836
Iter 8960, Minibatch Loss= 0.080765
Iter 10240, Minibatch Loss= 0.079680
Iter 11520, Minibatch Loss= 0.072564
Iter 12800, Minibatch Loss= 0.067365
Iter 14080, Minibatch Loss= 0.040751
Iter 15360, Minibatch Loss= 0.052333
Iter 16640, Minibatch Loss= 0.046463
Iter 17920, Minibatch Loss= 0.031513
Iter 19200, Minibatch Loss= 0.033700
Iter 20480, Minibatch Loss= 0.033375
Iter 21760, Minibatch Loss= 0.035751
Iter 23040, Minibatch Loss= 0.041844
Iter 24320, Minibatch Loss= 0.038133
Iter 25600, Minibatch Loss= 0.023348
Iter 26880, Minibatch Loss= 0.027589
Iter 28160, Minibatch Loss= 0.019354
Iter 29440, Minibatch Loss= 0.022398
Iter 30720, Minibatch Loss= 0.020543
Iter 32000, Minibatch Loss= 0.013847
Iter 33280, Minibatch Loss= 0.017195
Iter 34560, Minibatch Loss= 0.019519
Iter 35840, Minibatch Loss= 0.020920
Iter 37120, Minibatch Loss= 0.016392
Iter 38400, Minibatch Loss= 0.019325
Iter 39680, Minibatch Loss= 0.015266
Iter 40960, Minibatch Loss= 0.031248
Iter 42240, Minibatch Loss= 0.023118
Iter 43520, Minibatch Loss= 0.015399
Iter 44800, Minibatch Loss= 0.018544
Iter 46080, Minibatch Loss= 0.021445
Iter 47360, Minibatch Loss= 0.012260
Iter 48640, Minibatch Loss= 0.017937
Iter 49920, Minibatch Loss= 0.020652
Optimization finished!

Plot the weights after training:

[22]:
weights = fixedModel.get_weights()

plot_weights(weights['W_rec'])
../_images/notebooks_BiologicalConstraints_35_1.png
[23]:
fixedModel.destruct()

Unfortunately, it’s hard to see visually whether the weights actually stayed fixed or not. To make it more apparent, we will set all of the fixed weights to the same value, the average of their previous values.

[24]:
weights['W_rec'][N_rec//5*4:, :4*N_rec//5] = np.mean(weights['W_rec'][N_rec//5*4:, :4*N_rec//5])
weights['W_rec'][:4*N_rec//5, N_rec//5*4:] = np.mean(weights['W_rec'][:4*N_rec//5, N_rec//5*4:])

Now we make a new model, loading these modified weights.

[25]:
fixed_network_params = network_params.copy()
fixed_network_params['name'] = 'fixed_model_clearer'
for key, value in weights.items():
    fixed_network_params[key] = value
fixedModelClearer = Basic(fixed_network_params) # instantiate an RNN loading the revised weights from the previous model

Plot the model weights before training

[26]:
weights = fixedModelClearer.get_weights()

plot_weights(weights['W_rec'])
../_images/notebooks_BiologicalConstraints_42_1.png
[27]:
losses, initialTime, trainTime = fixedModelClearer.train(pd, train_params)
Iter 1280, Minibatch Loss= 0.050554
Iter 2560, Minibatch Loss= 0.024552
Iter 3840, Minibatch Loss= 0.021128
Iter 5120, Minibatch Loss= 0.028251
Iter 6400, Minibatch Loss= 0.019927
Iter 7680, Minibatch Loss= 0.016723
Iter 8960, Minibatch Loss= 0.013385
Iter 10240, Minibatch Loss= 0.016600
Iter 11520, Minibatch Loss= 0.020957
Iter 12800, Minibatch Loss= 0.012375
Iter 14080, Minibatch Loss= 0.019829
Iter 15360, Minibatch Loss= 0.020301
Iter 16640, Minibatch Loss= 0.019600
Iter 17920, Minibatch Loss= 0.017423
Iter 19200, Minibatch Loss= 0.010484
Iter 20480, Minibatch Loss= 0.014385
Iter 21760, Minibatch Loss= 0.017793
Iter 23040, Minibatch Loss= 0.009582
Iter 24320, Minibatch Loss= 0.014552
Iter 25600, Minibatch Loss= 0.010809
Iter 26880, Minibatch Loss= 0.012337
Iter 28160, Minibatch Loss= 0.017401
Iter 29440, Minibatch Loss= 0.012895
Iter 30720, Minibatch Loss= 0.016758
Iter 32000, Minibatch Loss= 0.011036
Iter 33280, Minibatch Loss= 0.007268
Iter 34560, Minibatch Loss= 0.008717
Iter 35840, Minibatch Loss= 0.014370
Iter 37120, Minibatch Loss= 0.012818
Iter 38400, Minibatch Loss= 0.021543
Iter 39680, Minibatch Loss= 0.011174
Iter 40960, Minibatch Loss= 0.010043
Iter 42240, Minibatch Loss= 0.015098
Iter 43520, Minibatch Loss= 0.012391
Iter 44800, Minibatch Loss= 0.011706
Iter 46080, Minibatch Loss= 0.015107
Iter 47360, Minibatch Loss= 0.012814
Iter 48640, Minibatch Loss= 0.009676
Iter 49920, Minibatch Loss= 0.009720
Optimization finished!

Plot the model weights after training. Now it is clear that the fixed weights haven’t changed.

[28]:
weights = fixedModelClearer.get_weights()

plot_weights(weights['W_rec'])
../_images/notebooks_BiologicalConstraints_45_1.png
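Because every entry in each fixed block was set to a single constant before training, this can also be checked numerically: if training truly left the fixed weights untouched, each block should still be constant. A minimal sketch:

# Each fixed block was constant before training; it should still be constant after
for block in (weights['W_rec'][N_rec//5*4:, :4*N_rec//5],
              weights['W_rec'][:4*N_rec//5, N_rec//5*4:]):
    print(np.allclose(block, block.flat[0]))   # expect True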
[29]:
fixedModelClearer.destruct()