# Source code for spykes.ml.tensorflow.poisson_models

from math import pi as PI

import tensorflow as tf
from tensorflow import keras as ks


class PoissonLayer(ks.layers.Layer):
    '''Defines a TensorFlow implementation of the NeuroPop layers.

    Two types of models are available.

    `The Generalized von Mises model by Amirikian & Georgopoulos (2000)
    <http://brain.umn.edu/pdfs/BA118.pdf>`_ is defined by

    .. math::

        f(x) = b + g * exp(k * cos(x - mu))

        f(x) = b + g * exp(k1 * cos(x) + k2 * sin(x))

    The Poisson generalized linear model is defined by

    .. math::

        f(x) = exp(k0 + k * cos(x - mu))

        f(x) = exp(k0 + k1 * cos(x) + k2 * sin(x))

    Implementation notes: the cosine and sine weights are computed as
    ``k1 * cos(mu)`` and ``k2 * sin(mu)``. In the generalized von Mises
    model the baseline ``b`` is passed through ``softplus`` before it is
    added, so the effective baseline is positive. The per-neuron parameters
    ``g``, ``b`` and ``k0`` have shape ``(1, num_neurons)`` and broadcast
    over the batch.

    Args:
        model_type (str): Can be either :data:`gvm`, the Generalized von
            Mises model, or :data:`glm`, the Poisson generalized linear
            model (case-insensitive).
        num_neurons (int): Number of neurons in the population (being
            inferred from the input features). This is the size of the
            output's last dimension.
        num_features (int): Number of input features. Convenience parameter
            for setting the input shape.
        mu_initializer (Keras initializer): The initializer for the
            :data:`mu`. Defaults to a fresh ``RandomUniform(-pi, pi)``.
        k_initializer (Keras initializer): The initializer for the
            :data:`k` (used for both ``k1`` and ``k2``). Defaults to a fresh
            ``RandomNormal(stddev=.2)``.
        g_initializer (Keras initializer): The initializer for the
            :data:`g`. Defaults to a fresh ``RandomNormal(stddev=.05)``.
            If :data:`model_type` is :data:`glm`, this is ignored.
        b_initializer (Keras initializer): The initializer for the
            :data:`b`. Defaults to a fresh ``RandomNormal(stddev=.1)``.
            If :data:`model_type` is :data:`glm`, this is ignored.
        k0_initializer (Keras initializer): The initializer for the
            :data:`k0`. Defaults to a fresh ``RandomNormal(stddev=.01)``.
            If :data:`model_type` is :data:`gvm`, this is ignored.

    Raises:
        ValueError: If :data:`model_type` is neither :data:`gvm` nor
            :data:`glm`, or (at build time) if the input is not 2D.
    '''

    def __init__(self,
                 model_type,
                 num_neurons,
                 num_features=None,
                 mu_initializer=None,
                 k_initializer=None,
                 g_initializer=None,
                 b_initializer=None,
                 k0_initializer=None,
                 **kwargs):
        if num_features is not None:
            kwargs['input_shape'] = (num_features,)
        super(PoissonLayer, self).__init__(**kwargs)

        self.model_type = model_type.lower()
        if self.model_type not in ('gvm', 'glm'):
            raise ValueError('Invalid model type: "{}" Must be either "gvm" '
                             '(generalised Von Mises model) or "glm" '
                             '(generalized linear model)'.format(model_type))

        self.num_neurons = num_neurons

        # Build default initializers per instance rather than as default
        # arguments. Default-argument instances are created once at import
        # time and would be shared by every PoissonLayer.
        if mu_initializer is None:
            mu_initializer = ks.initializers.RandomUniform(-PI, PI)
        if k_initializer is None:
            k_initializer = ks.initializers.RandomNormal(stddev=.2)
        if g_initializer is None:
            g_initializer = ks.initializers.RandomNormal(stddev=.05)
        if b_initializer is None:
            b_initializer = ks.initializers.RandomNormal(stddev=.1)
        if k0_initializer is None:
            k0_initializer = ks.initializers.RandomNormal(stddev=.01)

        self.mu_initializer = ks.initializers.get(mu_initializer)
        self.g_initializer = ks.initializers.get(g_initializer)
        self.b_initializer = ks.initializers.get(b_initializer)
        self.k_initializer = ks.initializers.get(k_initializer)
        self.k0_initializer = ks.initializers.get(k0_initializer)

    def build(self, input_shape):
        '''Creates the layer weights for a 2D ``(batch, features)`` input.

        Raises:
            ValueError: If ``input_shape`` does not have exactly two
                dimensions.
        '''
        if len(input_shape) != 2:
            raise ValueError('PoissonLayer expects a 2D input shape, '
                             'got {}'.format(input_shape))
        input_dim = input_shape[-1]

        # Per-feature, per-neuron tuning parameters.
        self.mu = self.add_weight(
            shape=(input_dim, self.num_neurons),
            initializer=self.mu_initializer,
            name='mu',
        )
        self.k1 = self.add_weight(
            shape=(input_dim, self.num_neurons),
            initializer=self.k_initializer,
            name='k1',
        )
        self.k2 = self.add_weight(
            shape=(input_dim, self.num_neurons),
            initializer=self.k_initializer,
            name='k2',
        )

        # The remaining parameters are per neuron. They are combined with the
        # (batch, num_neurons) output in `call`, so their last dimension must
        # be num_neurons, not input_dim.
        if self.model_type == 'gvm':
            # Adds generalized Von Mises parameters.
            self.g = self.add_weight(
                shape=(1, self.num_neurons),
                initializer=self.g_initializer,
                name='g',
            )
            self.b = self.add_weight(
                shape=(1, self.num_neurons),
                initializer=self.b_initializer,
                name='b',
            )
        else:
            # Adds generalized linear model parameters.
            self.k0 = self.add_weight(
                shape=(1, self.num_neurons),
                initializer=self.k0_initializer,
                name='k0',
            )

    def call(self, inputs):
        '''Computes the predicted firing rate of each neuron.'''
        # Expanded form of k * cos(x - mu) = k cos(x)cos(mu) + k sin(x)sin(mu),
        # with separate amplitudes k1 and k2 for the two terms.
        k1 = tf.matmul(tf.cos(inputs), self.k1 * tf.cos(self.mu))
        k2 = tf.matmul(tf.sin(inputs), self.k2 * tf.sin(self.mu))

        # Defines the two model formulations: "glm" vs "gvm".
        if self.model_type == 'glm':
            return tf.exp(k1 + k2 + self.k0)
        # softplus makes the gvm baseline strictly positive.
        return tf.nn.softplus(self.b) + self.g * tf.exp(k1 + k2)

    def get_config(self):
        '''Returns the layer config, including every constructor argument
        that `from_config` needs to rebuild the layer.'''
        config = {
            'model_type': self.model_type,
            'num_neurons': self.num_neurons,
            'mu_initializer': ks.initializers.serialize(self.mu_initializer),
            'g_initializer': ks.initializers.serialize(self.g_initializer),
            'b_initializer': ks.initializers.serialize(self.b_initializer),
            'k_initializer': ks.initializers.serialize(self.k_initializer),
            'k0_initializer': ks.initializers.serialize(self.k0_initializer),
        }
        base_config = super(PoissonLayer, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

    def compute_output_shape(self, input_shape):
        '''Replaces the last (feature) dimension with ``num_neurons``.'''
        output_shape = list(input_shape)
        output_shape[-1] = self.num_neurons
        return tuple(output_shape)