Neuropixels Example

Use Spykes to analyze data recorded with UCL’s Neuropixels probes.

# Authors: Mayank Agrawal <>
# License: MIT
import numpy as np
import pandas as pd
from spykes.plot.neurovis import NeuroVis
from spykes.plot.popvis import PopVis
from spykes.io.datasets import load_neuropixels_data
import matplotlib.pyplot as plt

# Matplotlib 3.6 renamed the bundled seaborn styles to 'seaborn-v0_8-*' and
# 3.8 removed the old names, so fall back to the new name when the old one
# is unavailable.
try:
    plt.style.use('seaborn-ticks')
except OSError:
    plt.style.use('seaborn-v0_8-ticks')


Neuropixels is a new recording technique from UCL’s Cortex Lab that can record data from hundreds of neurons. Below we show how to work with this data in Spykes.

0 Download Data

Download all data here.

1 Read In Data

folder_names = ['posterior', 'frontal']
Fs = 30000.0  # sampling rate: spike sample indices / Fs -> seconds

# One list of NeuroVis objects per brain region, filled in below.
striatum = list()
motor_ctx = list()
thalamus = list()
hippocampus = list()
visual_ctx = list()

# Depth bands per probe, in the units of the channel y-coordinates from
# channel_positions.npy. Each entry is (lower, upper, region list).
# Both bounds are exclusive, as in the original example, so a cluster whose
# mean depth falls exactly on a boundary is not assigned to any region.
region_bands = {
    'frontal': [
        (0, 1550, striatum),
        (1550, 3840, motor_ctx),
    ],
    'posterior': [
        (0, 1634, thalamus),
        (1634, 2797, hippocampus),
        (2797, 3840, visual_ctx),
    ],
}

# Much of this code is adapted from Cortex Lab's MATLAB analysis scripts.

data_dict = load_neuropixels_data()

for name in folder_names:

    clusters = np.squeeze(data_dict[name + '/spike_clusters.npy'])
    spike_times = np.squeeze(data_dict[name + '/spike_times.npy']) / Fs
    spike_templates = np.squeeze(data_dict[name + '/spike_templates.npy'])
    temps = np.squeeze(data_dict[name + '/templates.npy'])
    winv = np.squeeze(data_dict[name + '/whitening_mat_inv.npy'])
    y_coords = np.squeeze(data_dict[name + '/channel_positions.npy'])[:, 1]

    # Frontal times need to be aligned with the posterior probe's clock.
    # The correction is linear: t * time_correction[0] + time_correction[1].
    if name == 'frontal':
        time_correction = data_dict['timeCorrection.npy']
        spike_times *= time_correction[0]
        spike_times += time_correction[1]

    # Rows appear to be (cluster_id, label) pairs.
    data = data_dict[name + '/cluster_groups.csv']
    cids = np.array([x[0] for x in data])
    cfg = np.array([x[1] for x in data])

    # Find clusters labelled 'good' and keep only their spikes.
    good_clusters = cids[cfg == 'good']
    good_indices = np.isin(clusters, good_clusters)

    real_spikes = spike_times[good_indices]
    real_clusters = clusters[good_indices]
    real_spike_templates = spike_templates[good_indices]

    # Count the spikes in each cluster, then order the spikes by cluster so
    # each cluster occupies one contiguous run. A stable sort keeps spikes
    # within a cluster in their original (time) order.
    counts_per_cluster = np.bincount(real_clusters)

    sort_idx = np.argsort(real_clusters, kind='stable')
    sorted_spikes = real_spikes[sort_idx]
    sorted_spike_templates = real_spike_templates[sort_idx]

    # Find the depth of each spike (translated from Cortex Lab's MATLAB code).
    # Un-whiten every template: temps is indexed as temps[template, :, :] and
    # each slice is right-multiplied by the inverse whitening matrix.
    temps_unw = np.matmul(temps, winv).astype(np.float64)

    # Peak-to-peak amplitude of each template on each channel, and each
    # template's largest channel amplitude.
    temp_chan_amps = np.ptp(temps_unw, axis=1)
    temps_amps = np.max(temp_chan_amps, axis=1)

    # Ignore channels weaker than 30% of the template's peak channel.
    temp_chan_amps[temp_chan_amps < temps_amps[:, np.newaxis] * 0.3] = 0

    # Template depth = amplitude-weighted mean of the channel y-coordinates.
    temp_depths = temp_chan_amps.dot(y_coords) / temp_chan_amps.sum(axis=1)

    sorted_spike_depths = temp_depths[sorted_spike_templates]

    # Create one neuron per non-empty cluster and file it under its region.
    accumulator = 0

    for idx, count in enumerate(counts_per_cluster):

        if count == 0:
            continue

        cluster_slice = slice(accumulator, accumulator + count)
        accumulator += count

        cluster_spike_times = sorted_spikes[cluster_slice]
        neuron = NeuroVis(spiketimes=cluster_spike_times, name='%d' % (idx))
        cluster_depth = np.mean(sorted_spike_depths[cluster_slice])

        for lower, upper, region in region_bands[name]:
            if lower < cluster_depth < upper:
                region.append(neuron)
                break

print("Striatum (n = %d)" % len(striatum))
print("Motor Cortex (n = %d)" % len(motor_ctx))
print("Thalamus (n = %d)" % len(thalamus))
print("Hippocampus (n = %d)" % len(hippocampus))
print("Visual Cortex (n = %d)" % len(visual_ctx))


Striatum (n = 200)
Motor Cortex (n = 243)
Thalamus (n = 244)
Hippocampus (n = 68)
Visual Cortex (n = 76)

2 Create Data Frame

# Trial table: one row per stimulus presentation, holding its start and
# stop times and the ID of the stimulus shown.
raw_data = data_dict['experiment1stimInfo.mat']

df = pd.DataFrame({
    'start': np.squeeze(raw_data['stimStarts']),
    'stop': np.squeeze(raw_data['stimStops']),
    'stimulus': np.squeeze(raw_data['stimIDs']),
})

# Show the first few trials (the table printed below).
print(df.head())



start        stop  stimulus
0   99.108333  101.108333         6
1  101.908333  103.908333         3
2  104.924667  106.924667        11
3  107.574667  109.574667         7
4  110.308000  112.308000         8

3 Start Plotting

3.1 Striatum

# Striatum: PSTHs of every striatal neuron, aligned to the 'start' column of
# df and split by stimulus ID. Shown first as a heat map, then as a
# population average.
psth_params = dict(
    event='start',
    df=df,
    conditions='stimulus',
    plot=False,
    binsize=100,
    window=[-500, 2000],
)

pop = PopVis(striatum, name='Striatum')
fig = plt.figure(figsize=(30, 20))
all_psth = pop.get_all_psth(**psth_params)

pop.plot_heat_map(all_psth, cond_id=[2, 7, 13], sortorder='descend',
                  neuron_names=False)
pop.plot_population_psth(all_psth=all_psth, cond_id=[1, 7, 12])

3.2 Frontal

# Frontal probe: striatum and motor cortex neurons pooled together.
pop = PopVis(striatum + motor_ctx, name='Frontal')

fig = plt.figure(figsize=(30, 20))

# Per-neuron PSTHs aligned to the 'start' column of df, split by stimulus ID.
all_psth = pop.get_all_psth(
    event='start', df=df, conditions='stimulus', plot=False, binsize=100,
    window=[-500, 2000])

pop.plot_heat_map(
    all_psth, cond_id=[2, 7, 13], sortorder='descend', neuron_names=False)
pop.plot_population_psth(all_psth=all_psth, cond_id=[1, 7, 12])

3.3 All Neurons

# Every neuron from both probes, across all five regions.
pop = PopVis(striatum + motor_ctx + thalamus + hippocampus + visual_ctx)

fig = plt.figure(figsize=(30, 20))

# Per-neuron PSTHs aligned to the 'start' column of df, split by stimulus ID.
all_psth = pop.get_all_psth(
    event='start', df=df, conditions='stimulus', plot=False, binsize=100,
    window=[-500, 2000])

pop.plot_heat_map(
    all_psth, cond_id=[2, 7, 13], sortorder='descend', neuron_names=False)
pop.plot_population_psth(all_psth=all_psth, cond_id=[1, 7, 12])

3.4 Striatum vs. Motor Cortex

# Compare striatum and motor cortex. Both populations' PSTHs are computed
# first, then each population PSTH is plotted in turn.
striatum_pop = PopVis(striatum, name='Striatum')
motor_ctx_pop = PopVis(motor_ctx, name='Motor Cortex')

striatum_psth, motor_ctx_psth = [
    region_pop.get_all_psth(
        event='start', df=df, conditions='stimulus', plot=False,
        binsize=100, window=[-500, 2000])
    for region_pop in (striatum_pop, motor_ctx_pop)
]

for region_pop, region_psth in ((striatum_pop, striatum_psth),
                                (motor_ctx_pop, motor_ctx_psth)):
    region_pop.plot_population_psth(all_psth=region_psth, cond_id=[1, 7, 12])

Total running time of the script: ( 1 minutes 4.994 seconds)

Gallery generated by Sphinx-Gallery