This page was generated from docs/notebooks/Minimal_Example.ipynb. Interactive online version: Open In Colab.

Hello World!

A popular 2-alternative forced choice perceptual discrimination task is the random dot motion (RDM) task. In RDM, the subject observes dots moving in different directions. The RDM task is a forced choice task: although individual dots can move in any direction, the coherently moving dots drift in one of only two directions. The subject must make a choice towards one of the two directions at the end of the stimulus period (Britten et al., 1992).

To make it possible for an RNN to complete this task, we model it as two simultaneous noisy inputs, one into each of two input channels representing the two directions. The network must determine which channel has the higher mean input and respond by driving the corresponding output unit to 1 and the other output unit to 0.2. We've included this example task in PsychRNN as PerceptualDiscrimination.
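Before training anything, it can help to see what a single trial looks like. The snippet below is a minimal sketch: it draws one batch from the task and plots the noisy inputs and target outputs for the first trial. PerceptualDiscrimination and get_trial_batch are the actual PsychRNN APIs (also used in the cell below); the rest is just matplotlib.

from matplotlib import pyplot as plt
from psychrnn.tasks.perceptual_discrimination import PerceptualDiscrimination

task = PerceptualDiscrimination(dt = 10, tau = 100, T = 2000, N_batch = 128)
x, y, mask, params = task.get_trial_batch() # x has shape (N_batch, N_steps, N_in)

plt.plot(x[0]) # the two noisy input channels for the first trial
plt.plot(y[0]) # target outputs: 1 for the stronger channel, 0.2 for the other
plt.xlabel('time step')
plt.show()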

To get started, let’s train a basic model in PsychRNN on this 2-alternative forced choice perceptual discrimination task and test how it does on task input. For simplicity, we will use the model defaults.

from matplotlib import pyplot as plt
%matplotlib inline

# ---------------------- Import the package ---------------------------
from psychrnn.tasks.perceptual_discrimination import PerceptualDiscrimination
from psychrnn.backend.models.basic import Basic

# ---------------------- Set up a basic model ---------------------------
pd = PerceptualDiscrimination(dt = 10, tau = 100, T = 2000, N_batch = 128)
network_params = pd.get_task_params() # get the params passed in and defined in pd
network_params['name'] = 'model' # name the model uniquely if running multiple models in unison
network_params['N_rec'] = 50 # set the number of recurrent units in the model
model = Basic(network_params) # instantiate a basic vanilla RNN

# ---------------------- Train a basic model ---------------------------
model.train(pd) # train model to perform pd task

# ---------------------- Test the trained model ---------------------------
x, target_output, mask, trial_params = pd.get_trial_batch() # get pd task inputs and outputs
model_output, model_state = model.test(x) # run the model on input x

# ---------------------- Plot the results ---------------------------
plt.plot(model_output[0])

# ---------------------- Teardown the model -------------------------
model.destruct()
Iter 1280, Minibatch Loss= 0.173182
Iter 2560, Minibatch Loss= 0.110828
Iter 3840, Minibatch Loss= 0.089823
Iter 5120, Minibatch Loss= 0.090613
Iter 6400, Minibatch Loss= 0.082815
Iter 7680, Minibatch Loss= 0.084474
Iter 8960, Minibatch Loss= 0.084676
Iter 10240, Minibatch Loss= 0.082200
Iter 11520, Minibatch Loss= 0.076985
Iter 12800, Minibatch Loss= 0.080215
Iter 14080, Minibatch Loss= 0.078905
Iter 15360, Minibatch Loss= 0.074752
Iter 16640, Minibatch Loss= 0.071335
Iter 17920, Minibatch Loss= 0.062578
Iter 19200, Minibatch Loss= 0.040781
Iter 20480, Minibatch Loss= 0.033980
Iter 21760, Minibatch Loss= 0.043546
Iter 23040, Minibatch Loss= 0.025923
Iter 24320, Minibatch Loss= 0.022121
Iter 25600, Minibatch Loss= 0.018697
Iter 26880, Minibatch Loss= 0.018280
Iter 28160, Minibatch Loss= 0.021877
Iter 29440, Minibatch Loss= 0.016974
Iter 30720, Minibatch Loss= 0.020424
Iter 32000, Minibatch Loss= 0.022463
Iter 33280, Minibatch Loss= 0.018284
Iter 34560, Minibatch Loss= 0.029344
Iter 35840, Minibatch Loss= 0.014679
Iter 37120, Minibatch Loss= 0.024408
Iter 38400, Minibatch Loss= 0.016357
Iter 39680, Minibatch Loss= 0.023122
Iter 40960, Minibatch Loss= 0.019468
Iter 42240, Minibatch Loss= 0.018826
Iter 43520, Minibatch Loss= 0.013565
Iter 44800, Minibatch Loss= 0.015086
Iter 46080, Minibatch Loss= 0.018692
Iter 47360, Minibatch Loss= 0.012970
Iter 48640, Minibatch Loss= 0.018514
Iter 49920, Minibatch Loss= 0.016651
Optimization finished!
[figure: the model's two output units plotted over time for the first test trial]
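Rather than judging performance from a single trial's plot, you can score the whole test batch with the task's accuracy_function, which PsychRNN tasks implement (for this task it reports the fraction of trials on which the more active output unit matches the target). A one-line sketch, reusing the arrays computed above:

frac_correct = pd.accuracy_function(target_output, model_output, mask)
print(frac_correct) # fraction of trials decided in the correct direction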

Congratulations! You’ve successfully trained and tested your first model! Continue to Simple Example to learn how to define more useful models.