This page was generated from docs/notebooks/SavingLoadingWeights.ipynb.

Accessing and Modifying Weights

In the Simple Example, we saved weights to ./weights/saved_weights, which is stored on disk as ./weights/saved_weights.npz. Here we load those weights and modify them by setting the recurrent connections among the first ten units to zero.

[2]:
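import numpy as np
# Copy the saved arrays into a mutable dict. allow_pickle=True is needed because
# some entries (e.g. dale_ratio) are stored as object arrays.
weights = dict(np.load('./weights/saved_weights.npz', allow_pickle=True))
# Zero the recurrent weights among the first 10 units (rows 0-9, columns 0-9)
weights['W_rec'][:10, :10] = 0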
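The sketch below walks through the load-and-modify step without the real file. A small in-memory archive stands in for ./weights/saved_weights.npz. The random 50x50 matrix and the use of io.BytesIO are hypothetical choices for illustration. The sketch shows two things. First, the object returned by np.load is a read-only mapping, so it is copied into a dict. Second, allow_pickle=True is needed for a None entry. It also confirms that the slice assignment zeros only the 10x10 block and leaves the rest of W_rec unchanged.

```python
import io

import numpy as np

# Hypothetical stand-in for ./weights/saved_weights.npz: a random 50x50
# recurrent matrix plus a dale_ratio entry saved as None, written to memory.
rng = np.random.default_rng(0)
buf = io.BytesIO()
np.savez(buf, W_rec=rng.standard_normal((50, 50)), dale_ratio=None)
buf.seek(0)

# dict(...) copies each array out of the read-only NpzFile into a mutable dict.
# allow_pickle=True is required because dale_ratio=None is stored as an object array.
with np.load(buf, allow_pickle=True) as npz:
    weights = dict(npz)
original = weights['W_rec'].copy()

# Zero the recurrent weights among the first 10 units (rows 0-9, columns 0-9).
weights['W_rec'][:10, :10] = 0

block_zeroed = bool(np.all(weights['W_rec'][:10, :10] == 0))
# Every entry outside the 10x10 block should be untouched.
outside = np.ones((50, 50), dtype=bool)
outside[:10, :10] = False
rest_unchanged = bool(np.array_equal(weights['W_rec'][outside], original[outside]))
print(block_zeroed, rest_unchanged)  # True True
```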

Here are all the entries available for modification. Entries ending in /Adam or /Adam_1 hold the Adam optimizer's per-variable state (its moment estimates). All other entries are read in when a model is loaded from weights.

[3]:
print(weights.keys())
dict_keys(['init_state', 'W_in', 'W_rec', 'W_out', 'b_rec', 'b_out', 'Dale_rec', 'Dale_out', 'input_connectivity', 'rec_connectivity', 'output_connectivity', 'init_state/Adam', 'init_state/Adam_1', 'W_in/Adam', 'W_in/Adam_1', 'W_rec/Adam', 'W_rec/Adam_1', 'W_out/Adam', 'W_out/Adam_1', 'b_rec/Adam', 'b_rec/Adam_1', 'b_out/Adam', 'b_out/Adam_1', 'dale_ratio'])
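The naming rule above can be applied mechanically to separate model entries from optimizer entries. This sketch uses only the key names printed above. It does not claim that a PsychRNN model requires the optimizer entries to be removed.

```python
# Key names copied from the weights.keys() output above.
keys = ['init_state', 'W_in', 'W_rec', 'W_out', 'b_rec', 'b_out',
        'Dale_rec', 'Dale_out', 'input_connectivity', 'rec_connectivity',
        'output_connectivity', 'init_state/Adam', 'init_state/Adam_1',
        'W_in/Adam', 'W_in/Adam_1', 'W_rec/Adam', 'W_rec/Adam_1',
        'W_out/Adam', 'W_out/Adam_1', 'b_rec/Adam', 'b_rec/Adam_1',
        'b_out/Adam', 'b_out/Adam_1', 'dale_ratio']

# str.endswith accepts a tuple of suffixes; these mark the Adam optimizer entries.
optimizer_suffixes = ('/Adam', '/Adam_1')
model_keys = [k for k in keys if not k.endswith(optimizer_suffixes)]
optimizer_keys = [k for k in keys if k.endswith(optimizer_suffixes)]
print(len(model_keys), len(optimizer_keys))  # 12 12
print(model_keys)
```

With the real dictionary, the same filter would be `{k: v for k, v in weights.items() if not k.endswith(optimizer_suffixes)}`.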

Save the modified weights to './weights/modified_saved_weights.npz'.

[4]:
np.savez('./weights/modified_saved_weights.npz', **weights)
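np.savez(path, **weights) stores each dictionary key as a named array. Keys containing a slash, such as the optimizer entries, come back under the same name. The sketch below runs that round trip on a hypothetical three-entry dictionary in a temporary directory. It then checks that the keys, the arrays and the None-valued dale_ratio all survive.

```python
import os
import tempfile

import numpy as np

# Hypothetical miniature weights dict; 'W_rec/Adam' mimics the optimizer
# entries listed above, and dale_ratio mirrors the stored None.
weights = {
    'W_rec': np.zeros((5, 5)),
    'W_rec/Adam': np.ones((5, 5)),
    'dale_ratio': np.array(None, dtype=object),
}

with tempfile.TemporaryDirectory() as tmpdir:
    path = os.path.join(tmpdir, 'modified_saved_weights.npz')
    # **weights turns each dict key into a named array inside the archive.
    np.savez(path, **weights)
    # Read everything back into memory before the directory is removed.
    with np.load(path, allow_pickle=True) as npz:
        reloaded = dict(npz)

same_keys = set(reloaded) == set(weights)
same_arrays = all(np.array_equal(reloaded[k], weights[k])
                  for k in ('W_rec', 'W_rec/Adam'))
print(same_keys, same_arrays, reloaded['dale_ratio'].item())  # True True None
```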

Loading a Model with Weights

[5]:
from psychrnn.backend.models.basic import Basic
[6]:
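network_params = {'N_batch': 50,     # batch size
                  'N_in': 2,         # number of input channels
                  'N_out': 2,        # number of output channels
                  'dt': 10,          # simulation time step
                  'tau': 100,        # time constant of neural state decay
                  'T': 2000,         # trial length
                  'N_steps': 200,    # time steps per trial (T / dt)
                  'N_rec': 50        # number of recurrent units
                 }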
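The time parameters above are related. The sketch below checks two relations, both of which are assumptions about the usual discretization of a leaky rate RNN rather than a reading of Basic's source. N_steps should equal T / dt. The per-step update fraction is taken as alpha = dt / tau.

```python
# Same values as network_params above.
network_params = {'N_batch': 50, 'N_in': 2, 'N_out': 2, 'dt': 10,
                  'tau': 100, 'T': 2000, 'N_steps': 200, 'N_rec': 50}

# Assumption: N_steps counts the dt-sized steps in a trial of length T.
assert network_params['N_steps'] == network_params['T'] // network_params['dt']

# Assumption: a leaky integrator discretized with step dt updates by alpha = dt / tau.
alpha = network_params['dt'] / network_params['tau']
print(alpha)  # 0.1
```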

Load from File

Set network parameters: give the model a name and point load_weights_path at the modified weights file.

[7]:
file_network_params = network_params.copy()
file_network_params['name'] = 'file'
file_network_params['load_weights_path'] = './weights/modified_saved_weights.npz'
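Copying network_params before adding model-specific keys keeps the base dictionary clean, so it can be reused for the next model. The sketch uses a trimmed, hypothetical version of network_params. dict.copy() is shallow, which is harmless here because the base values are plain numbers.

```python
# Trimmed, hypothetical copy of the base parameters above.
network_params = {'N_batch': 50, 'N_rec': 50, 'dt': 10}

# dict.copy() creates a new top-level dict, so adding keys to it
# leaves network_params untouched.
file_network_params = network_params.copy()
file_network_params['name'] = 'file'
file_network_params['load_weights_path'] = './weights/modified_saved_weights.npz'

print(sorted(network_params))       # ['N_batch', 'N_rec', 'dt']
print(sorted(file_network_params))  # ['N_batch', 'N_rec', 'dt', 'load_weights_path', 'name']
```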

Instantiate the model.

[8]:
fileModel = Basic(file_network_params)

Verify that the W_rec weights are modified as expected.

[9]:
print(fileModel.get_weights()['W_rec'][:10,:10])
[[0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]]
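Printing the block confirms the zeros but not that the remaining weights were loaded unchanged. The helper below, block_zeroed_rest_preserved, is hypothetical and checks both conditions. It uses np.allclose for the comparison, on the assumption that the framework may hold the weights at a different float precision than the saved file. The usage lines run it on synthetic matrices.

```python
import numpy as np


def block_zeroed_rest_preserved(W_loaded, W_saved, n=10):
    """Hypothetical helper: True if W_loaded[:n, :n] is all zero and every
    entry outside that block matches W_saved (up to float tolerance)."""
    W_loaded = np.asarray(W_loaded)
    W_saved = np.asarray(W_saved)
    if W_loaded.shape != W_saved.shape:
        return False
    outside = np.ones(W_saved.shape, dtype=bool)
    outside[:n, :n] = False
    return bool(np.all(W_loaded[:n, :n] == 0)
                and np.allclose(W_loaded[outside], W_saved[outside]))


# Usage with synthetic matrices (not the notebook's weights).
rng = np.random.default_rng(1)
W_saved = rng.standard_normal((50, 50))
W_loaded = W_saved.astype(np.float32)  # simulate a float32 copy
W_loaded[:10, :10] = 0
print(block_zeroed_rest_preserved(W_loaded, W_saved))  # True
print(block_zeroed_rest_preserved(W_saved, W_saved))   # False
```

In the notebook, the same check could be run before fileModel.destruct(). It would compare fileModel.get_weights()['W_rec'] against W_rec read from ./weights/saved_weights.npz. This assumes get_weights() returns W_rec exactly as it was loaded.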
[10]:
fileModel.destruct()  # close this model's session before building the next model

Load from Weights Dictionary

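Set network parameters. This time, pass the weights directly by merging the weights dictionary into the parameter dictionary.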

[11]:
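dict_network_params = network_params.copy()
dict_network_params['name'] = 'dict'
# Merge every loaded entry (W_in, W_rec, ..., dale_ratio) into the parameters
dict_network_params.update(weights)
# dale_ratio was saved as None, so it comes back as a 0-d object array wrapping None
type(dict_network_params['dale_ratio']) == np.ndarray and dict_network_params['dale_ratio'].item() is None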
[11]:
True
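The check above shows that dale_ratio is an ndarray wrapping None rather than None itself. The sketch below reproduces this with a round trip through an in-memory archive. It then shows a hypothetical helper, unwrap_scalars, that uses .item() to turn such 0-d object arrays back into plain Python values for inspection. The notebook passes the wrapped value to Basic unchanged, so the helper is only a convenience and not a required step.

```python
import io

import numpy as np

# Round-trip a None value through an in-memory .npz archive, as happened to dale_ratio.
buf = io.BytesIO()
np.savez(buf, dale_ratio=None)
buf.seek(0)
with np.load(buf, allow_pickle=True) as npz:
    dale_ratio = npz['dale_ratio']

print(type(dale_ratio).__name__, dale_ratio.dtype, dale_ratio.shape)  # ndarray object ()


def unwrap_scalars(params):
    """Hypothetical helper: replace 0-d object arrays with the Python object they hold."""
    return {
        k: v.item() if isinstance(v, np.ndarray) and v.dtype == object and v.ndim == 0 else v
        for k, v in params.items()
    }


cleaned = unwrap_scalars({'dale_ratio': dale_ratio, 'W_rec': np.zeros((2, 2))})
print(cleaned['dale_ratio'] is None, cleaned['W_rec'].shape)  # True (2, 2)
```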

Instantiate the model.

[12]:
dictModel = Basic(dict_network_params)

Verify that the W_rec weights are modified as expected.

[13]:
print(dictModel.get_weights()['W_rec'][:10,:10])
[[0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]]
[14]:
dictModel.destruct()