Source code for psychrnn.backend.loss_functions

from __future__ import division

import tensorflow as tf
tf.compat.v1.disable_eager_execution()


class LossFunction(object):
    """ Set the loss function for the :class:`~psychrnn.backend.rnn.RNN` model.

    Arguments:
        params (dict): Dictionary of parameters including the following keys:

            :Dictionary Keys:
                * **loss_function** (*str*) -- String indicating what loss function to use. If :data:`params["loss_function"]` is not `mean_squared_error` or `binary_cross_entropy`, :data:`params[params["loss_function"]]` defines the custom loss function. Default: "mean_squared_error".
                * **params[params["loss_function"]]** (*function, optional*) -- Defines the custom loss function. Must have the same signature as :func:`mean_squared_error` and :func:`binary_cross_entropy`.

    """

    def __init__(self, params):
        self.type = params.get("loss_function", "mean_squared_error")
        if self.type != "mean_squared_error" and self.type != "binary_cross_entropy":
            self.custom_loss_function = params.get(self.type, None)
            if self.custom_loss_function is None:
                raise UserWarning("Loss type is '" + self.type + "' but '" + self.type
                                  + "' is not an entry in params. Did you mean 'mean_squared_error' or "
                                  "'binary_cross_entropy'? If not, you must pass a function in to params as '"
                                  + self.type + "'.")
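    # Usage sketch (illustrative, not part of the original source): a custom
    # loss is registered by putting its name in params["loss_function"] and
    # passing the callable under that same key. The name "my_l1_loss" below is
    # a hypothetical example; any key works as long as the function has the
    # same signature as mean_squared_error / binary_cross_entropy:
    #
    #     def my_l1_loss(predictions, y, output_mask):
    #         # Masked mean absolute error instead of squared error.
    #         return tf.reduce_mean(input_tensor=tf.abs(output_mask * (predictions - y)))
    #
    #     loss_fn = LossFunction({"loss_function": "my_l1_loss", "my_l1_loss": my_l1_loss})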
    def set_model_loss(self, model):
        """ Returns the model loss, calculated as indicated by :attr:`type` (inferred from :data:`params["loss_function"]`).

        ``'mean_squared_error'`` indicates :func:`mean_squared_error`, ``'binary_cross_entropy'`` indicates :func:`binary_cross_entropy`. If :attr:`type` is not one of the above options, :attr:`custom_loss_function` is used. The custom loss function would have been passed in to :data:`params` as :data:`params[type]`.

        Args:
            model (:class:`~psychrnn.backend.rnn.RNN` object): Model for which to calculate the loss.

        Returns:
            tf.Tensor(dtype=float): Model loss.

        """
        loss = 0
        if self.type == "mean_squared_error":
            loss = self.mean_squared_error(model.predictions, model.y, model.output_mask)
        elif self.type == "binary_cross_entropy":
            loss = self.binary_cross_entropy(model.predictions, model.y, model.output_mask)
        else:
            loss = self.custom_loss_function(model.predictions, model.y, model.output_mask)
        return loss
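    # Sketch of how set_model_loss is typically consumed (an assumption based
    # on the docstring, not code from this module): the RNN model exposes
    # `predictions`, `y`, and `output_mask` tensors, and the returned scalar
    # is handed to a TF1-style optimizer.
    #
    #     loss_tensor = LossFunction(params).set_model_loss(model)
    #     train_op = tf.compat.v1.train.AdamOptimizer(1e-3).minimize(loss_tensor)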
    def mean_squared_error(self, predictions, y, output_mask):
        """ Mean squared error.

        ``loss = mean(square(output_mask * (predictions - y)))``

        Args:
            predictions (*tf.Tensor(dtype=float, shape =(*:attr:`N_batch`, :attr:`N_steps`, :attr:`N_out` *))*): Network output.
            y (*tf.Tensor(dtype=float, shape =(*?, :attr:`N_steps`, :attr:`N_out` *))*): Target output.
            output_mask (*tf.Tensor(dtype=float, shape =(*?, :attr:`N_steps`, :attr:`N_out` *))*): Output mask for :attr:`N_batch` trials. True when the network should aim to match the target output, False when the target output can be ignored.

        Returns:
            tf.Tensor(dtype=float): Mean squared error.

        """
        return tf.reduce_mean(input_tensor=tf.square(output_mask * (predictions - y)))
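    # Worked example (illustrative numbers): with predictions = [1., 0.],
    # y = [0., 0.] and output_mask = [1., 0.], the masked squared errors are
    # [1., 0.], so the loss is mean([1., 0.]) = 0.5. Note the mean runs over
    # all entries, masked and unmasked alike, so trials with many masked
    # timesteps contribute proportionally less to the loss.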
    def binary_cross_entropy(self, predictions, y, output_mask):
        """ Binary cross-entropy.

        Binary label values are assumed to be 0 and 1.

        ``loss = mean(output_mask * -(y * log(predictions) + (1-y) * log(1-predictions)))``

        Args:
            predictions (*tf.Tensor(dtype=float, shape =(*:attr:`N_batch`, :attr:`N_steps`, :attr:`N_out` *))*): Network output.
            y (*tf.Tensor(dtype=float, shape =(*?, :attr:`N_steps`, :attr:`N_out` *))*): Target output.
            output_mask (*tf.Tensor(dtype=float, shape =(*?, :attr:`N_steps`, :attr:`N_out` *))*): Output mask for :attr:`N_batch` trials. True when the network should aim to match the target output, False when the target output can be ignored.

        Returns:
            tf.Tensor(dtype=float): Binary cross-entropy.

        """
        epsilon = 1e-07  # default epsilon used in TensorFlow
        # Clip predictions away from 0 and 1 so the logs stay finite.
        predictions = tf.clip_by_value(predictions, epsilon, 1. - epsilon)
        return tf.reduce_mean(input_tensor=output_mask * -(y * tf.math.log(predictions + epsilon)
                                                           + (1 - y) * tf.math.log(1 - predictions + epsilon)))
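

# Minimal runnable sketch (not part of the original module) exercising both
# built-in losses in TF1-style graph mode. All shapes and values below are
# illustrative assumptions.
if __name__ == "__main__":
    import numpy as np

    lf = LossFunction({"loss_function": "binary_cross_entropy"})
    # One trial, two timesteps, one output unit: (N_batch=1, N_steps=2, N_out=1).
    predictions = tf.constant(np.array([[[0.9], [0.1]]]), dtype=tf.float32)
    y = tf.constant(np.array([[[1.0], [0.0]]]), dtype=tf.float32)
    mask = tf.constant(np.ones((1, 2, 1)), dtype=tf.float32)  # nothing masked out

    bce = lf.binary_cross_entropy(predictions, y, mask)
    mse = lf.mean_squared_error(predictions, y, mask)
    with tf.compat.v1.Session() as sess:
        # Both losses should be small since predictions are close to the labels.
        print(sess.run([bce, mse]))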