Source code for psychrnn.backend.curriculum

from __future__ import division
from __future__ import print_function

from os import makedirs, path
import numpy as np

import warnings

def default_metric(curriculum_params, input_data, correct_output, output_mask, output, epoch, losses, verbosity):
    """ Default metric to use to evaluate performance when using Curriculum learning.

    The stage advances when the current stage's accuracy is at least that stage's threshold.

    Arguments:
        curriculum_params (dict): Dictionary of the :class:`Curriculum` object parameters. This metric reads the following keys:

            :Dictionary Keys:
                * **stage** (*int*) -- Current training stage (initial stage is 0). Used to index ``accuracies`` and ``thresholds``.
                * **accuracies** (*list of functions with the signature of* :func:`psychrnn.tasks.task.Task.accuracy_function`) -- Accuracy function to use at each stage. Called as ``accuracy(correct_output, output, output_mask)``.
                * **thresholds** (*list of float*) -- Per-stage threshold that the accuracy must reach to move to the next stage.

            Other keys of the :class:`Curriculum` parameters (``stop_training``, ``metric_values``, ``tasks``, ``metric``, ``metric_epoch``, ``output_file``) may be present but are not used here.
        input_data (ndarray): Task inputs. Not used by this metric.
        correct_output (ndarray(dtype=float, shape = (:attr:`N_batch`, :attr:`N_steps`, :attr:`N_out`))): Correct (target) task output given input_data.
        output_mask (ndarray(dtype=float, shape = (:attr:`N_batch`, :attr:`N_steps`, :attr:`N_out`))): Output mask for the task. True when the network should aim to match the target output, False when the target output can be ignored.
        output (ndarray(dtype=float, shape = (:attr:`N_batch`, :attr:`N_steps`, :attr:`N_out`))): The network's output given input_data.
        epoch (*int*): The epoch number in training. Not used by this metric.
        losses (*list of float*): List of losses computed during training. Not used by this metric.
        verbosity (*bool*): If True, prints the accuracy every time it is computed.

    Returns:
        tuple:
        * **advance** (*bool*) -- True if the accuracy is >= the threshold for the current stage. False otherwise.
        * **metric_value** (*float*) -- Value of the computed accuracy.
    """
    current_stage = curriculum_params['stage']
    stage_accuracy_fn = curriculum_params['accuracies'][current_stage]
    stage_accuracy = stage_accuracy_fn(correct_output, output, output_mask)
    # Look the threshold up before printing so that a missing threshold
    # fails before any output is produced.
    stage_threshold = curriculum_params['thresholds'][current_stage]

    if verbosity:
        print("Accuracy: " + str(stage_accuracy))

    should_advance = stage_accuracy >= stage_threshold
    return should_advance, stage_accuracy
class Curriculum(object):
    """ Curriculum object.

    Allows training on a sequence of tasks when Curriculum is passed into :func:`~psychrnn.backend.rnn.RNN.train`.

    Arguments:
        tasks (list of :class:`~psychrnn.tasks.task.Task` objects): List of tasks to use in the curriculum.
        metric (function, optional): Function for calculating whether the stage advances and what the metric value is at each metric_epoch. Default: :func:`default_metric`.

            :Arguments:
                * **curriculum_params** (*dict*) -- Dictionary of the :class:`Curriculum` object parameters (the Curriculum instance's ``__dict__``), containing the following keys:

                    :Dictionary Keys:
                        * **stop_training** (*bool*) -- True if the network has finished training and completed all stages.
                        * **stage** (*int*) -- Current training stage (initial stage is 0).
                        * **metric_values** (*list of [float, int]*) -- List of metric values and the stage at which each metric value was computed.
                        * **tasks** (*list of :class:`psychrnn.tasks.task.Task` objects*) -- List of tasks in the curriculum.
                        * **metric** (*function*) -- What metric function to use. :func:`default_metric` is an example of one in terms of inputs and outputs taken.
                        * **accuracies** (*list of functions*) -- Accuracy function to use at each stage.
                        * **thresholds** (*list of float*) -- Thresholds for each stage that accuracy must reach to move to the next stage.
                        * **metric_epoch** (*int*) -- Calculate the metric and test if the model should advance to the next stage every :data:`metric_epoch` training epochs.
                        * **output_file** (*str*) -- Optional path for saving out the metric value and stage. If the .npy filename extension is not included, it will be appended.

                * **input_data** (*ndarray(dtype=float, shape =(*:attr:`N_batch`, :attr:`N_steps`, :attr:`N_out` *))*) -- Task inputs.
                * **correct_output** (*ndarray(dtype=float, shape = (*:attr:`N_batch`, :attr:`N_steps`, :attr:`N_out` *))*) -- Correct (target) task output given input_data.
                * **output_mask** (*ndarray(dtype=float, shape = (*:attr:`N_batch`, :attr:`N_steps`, :attr:`N_out` *))*) -- Output mask for the task. True when the network should aim to match the target output, False when the target output can be ignored.
                * **output** (*ndarray(dtype=float, shape = (*:attr:`N_batch`, :attr:`N_steps`, :attr:`N_out` *))*) -- The network's output given input_data.
                * **epoch** (*int*) -- The epoch number in training.
                * **losses** (*list of float*) -- List of losses, computed during training.
                * **verbosity** (*bool*) -- Whether to print information as training progresses.

            :Returns:
                *tuple*

                * **advance** (*bool*) -- True if the stage should be advanced. False otherwise.
                * **metric_value** (*float*) -- Value of the computed metric.

        accuracies (list of functions, optional): Optional list of functions to use to calculate network performance for the purposes of advancing tasks. Used by :func:`default_metric` to compute accuracy. Default: ``[tasks[i].accuracy_function for i in range(len(tasks))]``. The default is only built when ``accuracies`` is not passed, so tasks need not define ``accuracy_function`` when custom accuracies are supplied.
        thresholds (list of float, optional): Optional list of thresholds. If metric = default_metric, accuracies must reach the threshold for a given stage in order to advance to the next stage. Default: ``[.9 for i in range(len(tasks))]``
        metric_epoch (int): Calculate the metric and test if the model should advance to the next stage every :data:`metric_epoch` training epochs. Default: 10
        output_file (str): Optional path for saving out the metric value and stage. The values are written with :func:`numpy.save`, so if the .npy filename extension is not included, it will be appended. Any missing parent directory is created on construction. Default: None.
    """

    def __init__(self, tasks, **kwargs):
        self.stop_training = False
        self.stage = 0
        # One [metric_value, stage] pair is appended per call to metric_test.
        self.metric_values = []

        # List of tasks that make up the curriculum.
        self.tasks = tasks

        # Optional function with parameters as in default_metric that returns
        # whether to advance the stage, and the accuracy / metric value.
        # Looked up lazily so the module-level default is only needed when used.
        self.metric = kwargs['metric'] if 'metric' in kwargs else default_metric

        # Optional list of accuracy functions to use for each task. Build the
        # default only when none was supplied: evaluating it eagerly would
        # require every task to define accuracy_function even when unused.
        if 'accuracies' in kwargs:
            self.accuracies = kwargs['accuracies']
        else:
            self.accuracies = [task.accuracy_function for task in tasks]
        assert len(self.accuracies) == len(self.tasks), \
            "accuracies must have one entry per task"

        # Optional list of accuracy cutoff values to use with each task.
        self.thresholds = kwargs.get('thresholds', [.9 for _ in range(len(tasks))])
        assert len(self.thresholds) == len(self.tasks), \
            "thresholds must have one entry per task"

        # How often (in training epochs) to check the metric.
        self.metric_epoch = kwargs.get('metric_epoch', 10)

        # Optional path to save out metric values and stages to.
        self.output_file = kwargs.get('output_file', None)
        if self.output_file is not None:
            output_dir = path.dirname(self.output_file)
            # Create the parent directory up front so the final save cannot fail on it.
            if output_dir != "" and not path.exists(output_dir):
                makedirs(output_dir)

    def metric_test(self, input_data, correct_output, output_mask, test_output, epoch, losses, verbosity = False):
        """Evaluates whether to advance the stage to the next task or not.

        Calls ``self.metric`` with this object's ``__dict__`` as ``curriculum_params``.
        It records ``[metric_value, stage]`` in :attr:`metric_values` and increments
        :attr:`stage` when the metric says to advance. After the last stage is
        completed, it sets :attr:`stop_training`. If :attr:`output_file` is set, it
        also saves :attr:`metric_values` there with :func:`numpy.save`.

        Arguments:
            input_data (ndarray(dtype=float, shape =(:attr:`N_batch`, :attr:`N_steps`, :attr:`N_out` ))): Task inputs.
            correct_output (ndarray(dtype=float, shape = (:attr:`N_batch`, :attr:`N_steps`, :attr:`N_out`))): Correct (target) task output given input_data.
            output_mask (ndarray(dtype=float, shape = (:attr:`N_batch`, :attr:`N_steps`, :attr:`N_out`))): Output mask for the task. True when the network should aim to match the target output, False when the target output can be ignored.
            test_output (ndarray(dtype=float, shape = (:attr:`N_batch`, :attr:`N_steps`, :attr:`N_out`))): The network's output given input_data.
            epoch (*int*): The epoch number in training.
            losses (*list of float*): List of losses, computed during training.
            verbosity (*bool, optional*): Whether to print information as metric is computed and stages advanced. Default: False

        Returns:
            True if stage advances, False otherwise.
        """
        advance, metric_value = self.metric(self.__dict__, input_data, correct_output, output_mask,
                                            test_output, epoch, losses, verbosity)
        self.metric_values.append([metric_value, self.stage])
        if not advance:
            return False

        self.stage += 1
        if self.stage == len(self.tasks):
            # Every stage is complete: signal training to stop and persist the history.
            self.stop_training = True
            if self.output_file is not None:
                np.save(self.output_file, self.metric_values)
                if verbosity:
                    print("Metric values saved in file: %s" % self.output_file)
        if verbosity:
            print("Stage " + str(self.stage))
        return True

    def get_generator_function(self):
        """ Deprecated method to return a generator for the current task.

        Emits a :class:`DeprecationWarning`. Use :func:`batch_generator` instead.
        """
        warnings.warn("Curriculum.get_generator_function is deprecated. Use Curriculum.batch_generator instead.", DeprecationWarning)
        return self.batch_generator()

    def batch_generator(self):
        """ Returns a generator for the current task.

        Delegates to ``batch_generator()`` of ``self.tasks[self.stage]``.

        Returns:
            Generator[tuple, None, None]: generator iterator for the current task

        Yields:
            tuple:

            * **stimulus** (*ndarray(dtype=float, shape =(*:attr:`N_batch`, :attr:`N_steps`, :attr:`N_out` *))*): Task stimuli for :attr:`N_batch` trials.
            * **target_output** (*ndarray(dtype=float, shape =(*:attr:`N_batch`, :attr:`N_steps`, :attr:`N_out` *))*): Target output for the network on :attr:`N_batch` trials given the :data:`stimulus`.
            * **output_mask** (*ndarray(dtype=bool, shape =(*:attr:`N_batch`, :attr:`N_steps`, :attr:`N_out` *))*): Output mask for :attr:`N_batch` trials. True when the network should aim to match the target output, False when the target output can be ignored.
            * **trial_params** (*ndarray(dtype=dict, shape =(*:attr:`N_batch` *,))*): Array of dictionaries containing the trial parameters produced by :func:`generate_trial_params` for each trial in :attr:`N_batch`.
        """
        return self.tasks[self.stage].batch_generator()