Source code for spykes.plot.popvis

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import matplotlib.pyplot as plt
import copy
from collections import defaultdict

from fractions import gcd

from .neurovis import NeuroVis
from .. import utils
from ..config import DEFAULT_POPULATION_COLORS

# Defines the default colors for a PSTH plot.
DEFAULT_PSTH_COLORS = ['Blues', 'Reds', 'Greens']


[docs]class PopVis(object): '''Facilitates visualization of neuron population firing activity. Args: neuron_list (list of NeuroVis objects): The list of neurons to visualize (see the NeuroVis class in :class:`spykes.plot.neurovis`). Attributes: n_neurons (int): The number of neurons in the visualization. name (str): The name of this visualization. ''' def __init__(self, neuron_list, name='PopVis'): self.neuron_list = neuron_list self._name = name @property def name(self): return self._name @property def n_neurons(self): return len(self.neuron_list)
[docs] def get_all_psth(self, event=None, df=None, conditions=None, window=[-100, 500], binsize=10, conditions_names=None, plot=True, colors=DEFAULT_PSTH_COLORS): '''Iterates through all neurons and computes their PSTH's. Args: event (str): Column/key name of DataFrame/dictionary "data" which contains event times in milliseconds (e.g. stimulus/trial/ fixation onset, etc.). df (DataFrame or dictionary): The data to use. conditions (str): Column/key name of DataFrame/dictionary :data:`df` which contains the conditions by which the trials must be grouped window (list of 2 elements): Time interval to consider, in milliseconds. binsize (int): Bin size, in milliseconds. conditions_names (list of str): Legend names for conditions. Default are the unique values in :data:`df['conditions']`. plot (bool): If set, automatically plot; otherwise, don't. colors (list): List of colors for heatmap (only if plot is True). Returns: dict: With keys :data:`event`, :data:`conditions`, :data:`binsize`, :data:`window`, and :data:`data`. Each entry in :data:`psth['data']` is itself a dictionary with keys of each :data:`cond_id` that correspond to the means for that condition. ''' all_psth = { 'window': window, 'binsize': binsize, 'event': event, 'conditions': conditions, 'data': defaultdict(list), } for i, neuron in enumerate(self.neuron_list): psth = neuron.get_psth( event=event, df=df, conditions=conditions, window=window, binsize=binsize, plot=False, ) for cond_id in np.sort(list(psth['data'].keys())): all_psth['data'][cond_id].append(psth['data'][cond_id]['mean']) for cond_id in np.sort(list(all_psth['data'].keys())): all_psth['data'][cond_id] = np.stack(all_psth['data'][cond_id]) if plot is True: self.plot_heat_map( all_psth, conditions_names=conditions_names, colors=colors, ) return all_psth
[docs] def plot_heat_map(self, psth_dict, cond_id=None, conditions_names=None, sortby=None, sortorder='descend', normalize=None, neuron_names=True, colors=None, show=False): '''Plots heat map for neuron population Args: psth_dict (dict): With keys :data:`event`, :data:`conditions`, :data:`binsize`, :data:`window`, and :data:`data`. Each entry in :data:`psth['data']` is itself a dictionary with keys of each :data:`cond_id` that correspond to the means for that condition. cond_id (str): Which psth to plot indicated by the key in :data:`all_psth['data']`. If None then all are plotted. conditions_names (str or list of str): Name(s) to appear in the title. sortby (str or list): If :data:`rate`, sort by firing rate. If :data:`latency`, sort by peak latency. If a list or array is provided, it must correspond to integer indices to be used as sorting indices. If no sort order is provided, the data is returned as-is. sortorder (str): The direction to sort in, either :data:`descend` or :data:`ascend`. normalize (str): If :data:`all`, divide all PSTHs by highest peak firing rate in all neurons. If :data:`each`, divide each PSTH by its own peak firing rate. If None, do not normalize. neuron_names (bool): Whether or not to list the names of neurons on the side. colors (list of str): List of colors for the heatmap (as strings). show (bool): If set, show the plot once finished. ''' if colors is None: colors = ['Blues', 'Reds', 'Greens'] window = psth_dict['window'] binsize = psth_dict['binsize'] conditions = psth_dict['conditions'] if cond_id is None: keys = np.sort(list(psth_dict['data'].keys())) else: keys = cond_id if conditions_names is None: conditions_names = keys for i, cond_id in enumerate(keys): # Sorts and norms the data. orig_data = psth_dict['data'][cond_id] normed_data = self._get_normed_data(orig_data, normalize=normalize) sort_idx = utils.get_sort_indices( normed_data, by=sortby, order=sortorder, ) data = normed_data[sort_idx, :] plt.subplot(len(keys), 1, i + 1) plt.pcolormesh(data, cmap=colors[i % len(colors)]) # Makes it visually appealing. xtic_len = gcd(abs(window[0]), window[1]) xtic_labels = range(window[0], window[1] + xtic_len, xtic_len) xtic_locs = [(j - window[0]) / binsize for j in xtic_labels] if 0 not in xtic_labels: xtic_labels.append(0) xtic_locs.append(-window[0] / binsize) plt.xticks(xtic_locs, xtic_labels) plt.axvline((-window[0]) / binsize, color='r', linestyle='--') if neuron_names: unsorted_ylabels = [neuron.name for neuron in self.neuron_list] ylabels = [unsorted_ylabels[j] for j in sort_idx] else: ylabels = ["" for neuron in self.neuron_list] plt.yticks(np.arange(data.shape[0]) + 0.5, ylabels) ax = plt.gca() ax.invert_yaxis() ax.set_frame_on(False) plt.tick_params(axis='x', which='both', top='off') plt.tick_params(axis='y', which='both', left='off', right='off') plt.xlabel('time [ms]') plt.ylabel('Neuron') plt.title("%s: %s" % (conditions, conditions_names[i])) plt.colorbar() if show: plt.show()
[docs] def plot_population_psth(self, all_psth=None, event=None, df=None, conditions=None, cond_id=None, window=[-100, 500], binsize=10, conditions_names=None, event_name='event_onset', ylim=None, colors=DEFAULT_POPULATION_COLORS, show=False): '''Plots population PSTH's. This involves two steps. First, it normalizes each neuron's PSTH across the conditions. Second, it averages out and plots population PSTH. Args: all_psth (dict): With keys :data:`event`, :data:`conditions`, :data:`binsize`, :data:`window`, and :data:`data`. Each entry in :data:`psth['data']` is itself a dictionary with keys of each :data:`cond_id` that correspond to the means for that condition. event (str): Column/key name of the :data:`df` DataFrame/dictionary which contains event times in milliseconds (stimulus/trial/fixation onset, etc.). df (DataFrame or dictionary): DataFrame containing the data to plot, or a dictionary corresponding to such a DataFrame. conditions (str): Column/key name of DataFrame/dictionary :data:`df` which contains the conditions by which the trials must be grouped. cond_id (list): Which psth to plot indicated by the key in :data:`all_psth['data']``. If :data:`None` then all are plotted. window (list of 2 elements): Time interval to consider, in milliseconds. binsize (int): Bin size, in milliseconds. conditions_names (list): Legend names for conditions. Default are the unique values in :data:`df['conditions']`. event_name (string): Legend name for event. Default is the actual event name ylim (list): The minimum and maximum values for Y. colors (list): The colors to use for the plot (as strings). show (bool): If set, show the plot once it has been created. ''' # placeholder in order to use NeuroVis functionality base_neuron = NeuroVis(spiketimes=range(10), name=self.name) if all_psth is None: psth = self.get_all_psth( event=event, df=df, conditions=conditions, window=window, binsize=binsize, plot=False, ) else: psth = copy.deepcopy(all_psth) keys = np.sort(list(psth['data'].keys())) # Normalizes each neuron across all conditions. for i in range(self.n_neurons): max_rates = list() for key in keys: max_rates.append(np.amax(psth['data'][key][i, :])) norm_factor = max(max_rates) for key in keys: psth['data'][key][i, :] /= norm_factor # Averages out all the neurons and plots. for i, key in enumerate(keys): normed_data = psth['data'][key] psth['data'][key] = dict() psth['data'][key]['mean'] = np.nanmean(normed_data, axis=0) psth['data'][key]['sem'] = \ np.nanvar(normed_data, axis=0) / (len(self.neuron_list)**.5) # Plots the PSTH. base_neuron.plot_psth( psth=psth, event_name=event_name, cond_id=cond_id, conditions_names=conditions_names, ylim=ylim, colors=colors, ) # Adds the appropriate title. plt.title("%s Population PSTH: %s" % (self.name, psth['conditions'])) if show: plt.show()
def _get_normed_data(self, data, normalize): '''Normalizes all PSTH data Args: data (2-D numpy array): Array with shape :data:`(n_neurons, n_bins)` normalize (str): If :data:`all`, divide all PSTHs by highest peak firing rate in all neurons. If :data:`each`, divide each PSTH by its own peak firing rate. If None, do not normalize. Returns: array: The original data array, divided such that all values fall between 0 and 1. ''' max_rates = np.amax(data, axis=1) # Computes the normalization factors. if normalize == 'all': norm_factors = np.ones([data.shape[0], 1]) * np.amax(max_rates) elif normalize == 'each': norm_factors = ( np.reshape(max_rates, (max_rates.shape[0], 1)) * np.ones((1, data.shape[1])) ) elif normalize is None: norm_factors = np.ones([data.shape[0], 1]) else: raise ValueError('Invalid norm factors: {}'.format(norm_factors)) return data / norm_factors