from __future__ import absolute_import
import numpy as np
import matplotlib.pyplot as plt
from .. import utils
from ..config import DEFAULT_POPULATION_COLORS
[docs]class NeuroVis(object):
'''This class is used to visualize firing activity of single neurons.
This class implements several conveniences for visualizing firing
activity of single neurons.
Args:
spiketimes (Numpy array): Array of spike times.
name (str): The name of the visualization.
'''
def __init__(self, spiketimes, name='neuron'):
self.name = name
self.spiketimes = np.squeeze(np.sort(spiketimes))
n_seconds = (self.spiketimes[-1] - self.spiketimes[0])
n_spikes = np.size(spiketimes)
self.firingrate = (n_spikes / n_seconds)
[docs] def get_raster(self, event=None, conditions=None, df=None,
window=[-100, 500], binsize=10, plot=True,
sortby=None, sortorder='descend'):
'''Computes the raster and plots it.
Args:
event (str): Column/key name of DataFrame/dictionary "data" which
contains event times in milliseconds (e.g.
stimulus/trial/fixation onset, etc.)
conditions (str): Column/key name of DataFrame/dictionary
:data:`data` which contains the conditions by which the trials
must be grouped.
df (DataFrame or dictionary): The dataframe containing the data,
or a dictionary with the equivalent structure.
window (list of 2 elements): Time interval to consider, in
milliseconds.
binsize (int): Bin size in milliseconds
plot (bool): If True then plot
sortby (str or list): If :data:`rate`, sort by firing rate. If
:data:`latency`, sort by peak latency. If a list, integers to
be used as sorting indices.
sortorder (str): Direction to sort, either :data:`descend` or
:data:`ascend`.
Returns:
dict: :data:`rasters` with keys :data:`event`, :data:`conditions`,
:data:`binsize`, :data:`window`, and :data:`data`.
:data:`rasters['data']` is a dictionary where each value is a
raster for each unique entry of :data:`df['conditions']`.
'''
if not type(df) is dict:
df = df.reset_index()
window = [np.floor(window[0] / binsize) * binsize,
np.ceil(window[1] / binsize) * binsize]
# Get a set of binary indicators for trials of interest
if conditions:
trials = dict()
for cond_id in np.sort(df[conditions].unique()):
trials[cond_id] = \
np.where((df[conditions] == cond_id).apply(
lambda x: (0, 1)[x]).values)[0]
else:
trials = dict()
trials[0] = np.where(np.ones(np.size(df[event])))[0]
# Initialize rasters
rasters = {
'event': event,
'conditions': conditions,
'window': window,
'binsize': binsize,
'data': {},
}
# Loop over each raster
for cond_id in trials:
# Select events relevant to this raster
selected_events = df[event][trials[cond_id]]
raster = []
bin_template = 1e-3 * \
np.arange(window[0], window[1] + binsize, binsize)
for event_time in selected_events:
bins = event_time + bin_template
# consider only spikes within window
searchsorted_idx = np.squeeze(np.searchsorted(self.spiketimes,
[event_time + 1e-3 *
window[0],
event_time + 1e-3 *
window[1]]))
# bin the spikes into time bins
spike_counts = np.histogram(
self.spiketimes[searchsorted_idx[0]:searchsorted_idx[1]],
bins)[0]
raster.append(spike_counts)
rasters['data'][cond_id] = np.array(raster)
# Show the raster
if plot is True:
self.plot_raster(rasters, cond_id=None, sortby=sortby,
sortorder=sortorder)
# Return all the rasters
return rasters
[docs] def plot_raster(self, rasters, cond_id=None, cond_name=None, sortby=None,
sortorder='descend', cmap='Greys', has_title=True):
'''Plot a single raster.
Args:
rasters (dict): Output of get_raster method
cond_id (str): Which raster to plot indicated by the key in
:data:`rasters['data']`. If None then all are plotted.
cond_name (str): Name to appear in the title.
sortby (str or list): If :data:`rate`, sort by firing rate. If
:data:`latency`, sort by peak latency. If a list, integers to
be used as sorting indices.
sortorder (str): Direction to sort in, either :data:`descend` or
:data:`ascend`.
cmap (str): Colormap for raster.
has_title (bool): If True then adds title.
'''
window = rasters['window']
binsize = rasters['binsize']
xtics = [window[0], 0, window[1]]
xtics = [str(i) for i in xtics]
xtics_loc = [-0.5, (-window[0]) / binsize - 0.5,
(window[1] - window[0]) / binsize - 0.5]
if cond_id is None:
for cond in list(rasters['data']):
self.plot_raster(rasters, cond_id=cond, cond_name=cond_name,
sortby=sortby, sortorder=sortorder, cmap=cmap,
has_title=has_title)
plt.show()
else:
raster = rasters['data'][cond_id]
if len(raster) > 0:
sort_idx = utils.get_sort_indices(
data=raster,
by=sortby,
order=sortorder,
)
raster_sorted = raster[sort_idx]
plt.imshow(raster_sorted, aspect='auto',
interpolation='none', cmap=plt.get_cmap(cmap))
plt.axvline(
(-window[0]) / binsize - 0.5, color='r', linestyle='--')
plt.ylabel('trials')
plt.xlabel('time [ms]')
plt.xticks(xtics_loc, xtics)
if has_title:
if cond_id:
if cond_name:
plt.title('neuron %s. %s' %
(self.name, cond_name))
else:
plt.title('neuron %s. %s: %s' %
(self.name, rasters['conditions'],
cond_id))
else:
plt.title('neuron %s' % self.name)
ax = plt.gca()
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
ax.spines['bottom'].set_visible(False)
ax.spines['left'].set_visible(False)
plt.tick_params(axis='x', which='both', top='off')
plt.tick_params(axis='y', which='both', right='off')
else:
print('No trials for this condition!')
[docs] def get_psth(self, event=None, df=None, conditions=None, cond_id=None,
window=[-100, 500], binsize=10, plot=True, event_name=None,
conditions_names=None, ylim=None,
colors=DEFAULT_POPULATION_COLORS):
'''Compute the PSTH and plot it.
Args:
event (str): Column/key name of DataFrame/dictionary :data:`data`
which contains event times in milliseconds (e.g.
stimulus/trial/fixation onset, etc.)
conditions (str): Column/key name of DataFrame/dictionary
:data:`data` which contains the conditions by which the trials
must be grouped.
cond_id (list): Which psth to plot indicated by the key in
:data:`all_psth['data']``. If None then all are plotted.
df (DataFrame or dictionary): The dataframe containing the data.
window (list of 2 elements): Time interval to consider, in
milliseconds.
binsize (int): Bin size in milliseconds.
plot (bool): If True then plot.
event_name (string): Legend name for event. Default is the actual
event name
conditions_names (TODO): Legend names for conditions. Default are
the unique values in :data:`df['conditions']`.
ylim (list): The lower and upper limits for Y.
colors (list): The colors for the plot.
Returns:
dict: :data:`rasters` with keys :data:`event`, :data:`conditions`,
:data:`binsize`, :data:`window`, and :data:`data`.
:data:`rasters['data']` is a dictionary where each value is a
raster for each unique entry of :data:`df['conditions']`.
'''
window = [np.floor(window[0] / binsize) * binsize,
np.ceil(window[1] / binsize) * binsize]
# Get all the rasters first
rasters = self.get_raster(event=event, df=df,
conditions=conditions,
window=window, binsize=binsize, plot=False)
# Initialize PSTH
psth = dict()
psth['window'] = window
psth['binsize'] = binsize
psth['event'] = event
psth['conditions'] = conditions
psth['data'] = dict()
# Compute the PSTH
for cond_id in np.sort(list(rasters['data'])):
psth['data'][cond_id] = dict()
raster = rasters['data'][cond_id]
mean_psth = np.mean(raster, axis=0) / (1e-3 * binsize)
std_psth = np.sqrt(np.var(raster, axis=0)) / (1e-3 * binsize)
sem_psth = std_psth / np.sqrt(float(np.shape(raster)[0]))
psth['data'][cond_id]['mean'] = mean_psth
psth['data'][cond_id]['sem'] = sem_psth
if plot is True:
if not event_name:
event_name = event
conditions_names = list(psth['data'])
self.plot_psth(psth, ylim=ylim, event_name=event_name,
conditions_names=conditions_names,
colors=colors)
return psth
[docs] def plot_psth(self, psth, event_name='event_onset', conditions_names=None,
cond_id=None, ylim=None, colors=DEFAULT_POPULATION_COLORS):
'''Plots PSTH.
Args:
psth (dict): Output of :meth:`get_psth`.
event_name (string): Legend name for event. Default is the actual
event name.
conditions_names (list of str): Legend names for conditions.
Default are the keys in :data:`psth['data']`.
cond_id (list): Which psth to plot indicated by the key in
:data:`all_psth['data']``. If None then all are plotted.
ylim (list): The lower and upper limits for Y.
colors (list): The colors for the plot.
'''
window = psth['window']
binsize = psth['binsize']
conditions = psth['conditions']
if cond_id is None:
keys = np.sort(list(psth['data'].keys()))
else:
keys = cond_id
if conditions_names is None:
conditions_names = keys
scale = 0.1
y_min = (1.0 - scale) * np.nanmin([np.min(
psth['data'][psth_idx]['mean'])
for psth_idx in psth['data']])
y_max = (1.0 + scale) * np.nanmax([np.max(
psth['data'][psth_idx]['mean'])
for psth_idx in psth['data']])
legend = [event_name]
time_bins = np.arange(window[0], window[1], binsize) + binsize / 2.0
if ylim:
plt.plot([0, 0], ylim, color='k', ls='--')
else:
plt.plot([0, 0], [y_min, y_max], color='k', ls='--')
for i, cond_id in enumerate(keys):
if np.all(np.isnan(psth['data'][cond_id]['mean'])):
plt.plot(0, 0, alpha=1.0, color=colors[i % len(colors)])
else:
plt.plot(time_bins, psth['data'][cond_id]['mean'],
color=colors[i % len(colors)], lw=1.5)
for i, cond_id in enumerate(keys):
if conditions is not None:
if conditions_names is not None:
legend.append('%s' % conditions_names[i])
else:
legend.append('%s' % str(cond_id))
else:
legend.append('all')
if not np.all(np.isnan(psth['data'][cond_id]['mean'])):
plt.fill_between(time_bins, psth['data'][cond_id]['mean'] -
psth['data'][cond_id]['sem'],
psth['data'][cond_id]['mean'] +
psth['data'][cond_id]['sem'],
color=colors[i % len(colors)],
alpha=0.2)
if conditions:
plt.title('neuron %s: %s' % (self.name, conditions))
else:
plt.title('neuron %s' % self.name)
plt.xlabel('time [ms]')
plt.ylabel('spikes per second [spks/s]')
if ylim:
plt.ylim(ylim)
else:
plt.ylim([y_min, y_max])
ax = plt.gca()
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
plt.tick_params(axis='y', right='off')
plt.tick_params(axis='x', top='off')
plt.legend(legend, frameon=False)
[docs] def get_spikecounts(self, event=None, df=None,
window=np.array([50.0, 100.0])):
'''Counts spikes in the dataframe.
Args:
event (str): Column/key name of DataFrame/dictionary :data:`data`
which contains event times in milliseconds (e.g.
stimulus/trial/fixation onset, etc.)
window (list of 2 elements): Time interval to consider, in
milliseconds.
Return:
array: An :data:`n x 1` array of spike counts.
'''
events = df[event].values
spiketimes = self.spiketimes
spikecounts = np.asarray([
np.sum(np.all((
spiketimes >= e + 1e-3 * window[0],
spiketimes <= e + 1e-3 * window[1],
), axis=0))
for e in events
])
return spikecounts