# Source code for spykes.ml.strf

import numpy as np
import matplotlib.pyplot as plt


class STRF(object):
    '''Allows the estimation of spatiotemporal receptive fields.

    Args:
        patch_size (int): Dimension of the square patch spanned by the
            spatial basis.
        sigma (float): Standard deviation of the Gaussian distribution.
        n_spatial_basis (int): Number of spatial basis functions for the
            Gaussian basis (has to be a perfect square).
        n_temporal_basis (int): Number of temporal basis functions.
    '''

    def __init__(self, patch_size=100, sigma=0.5, n_spatial_basis=25,
                 n_temporal_basis=3):
        self.patch_size = patch_size
        self.sigma = sigma
        self.n_spatial_basis = n_spatial_basis
        self.n_temporal_basis = n_temporal_basis

    def make_2d_gaussian(self, center=(0, 0)):
        '''Makes a 2D Gaussian filter centered at an arbitrary point.

        The standard deviation is taken from :data:`self.sigma`.

        Args:
            center (tuple): The coordinates of the center of the Gaussian,
                specified as :data:`(x, y)`, i.e. :data:`(col, row)` pixel
                offsets. The center of the image is :data:`(0, 0)`.

        Returns:
            numpy array: The Gaussian mask of size
            :data:`(patch_size, patch_size)`, scaled so its peak is 1.
        '''
        sigma = self.sigma
        half_rows = (self.patch_size - 1.) / 2.
        half_cols = (self.patch_size - 1.) / 2.
        # Pixel coordinates symmetric around zero, so (0, 0) is the center.
        y, x = np.ogrid[-half_rows: half_rows + 1, -half_cols: half_cols + 1]
        # center is (x, y): make_gaussian_basis passes the column first.
        x0, y0 = center[0], center[1]
        gaussian_mask = np.exp(
            -((x - x0) ** 2 + (y - y0) ** 2) / (2. * sigma ** 2))
        peak = gaussian_mask.max()
        if peak <= 0:
            # Every value underflowed (center far outside the patch);
            # normalizing would divide by zero and yield NaNs.
            return gaussian_mask
        # Zero out values that are negligible relative to the peak.
        eps = np.finfo(gaussian_mask.dtype).eps
        gaussian_mask[gaussian_mask < eps * peak] = 0
        gaussian_mask = 1. / peak * gaussian_mask
        return gaussian_mask

    def make_gaussian_basis(self):
        '''Makes a list of Gaussian filters tiled on a square grid.

        Returns:
            list: A list where each entry is a 2D array of size
            :data:`(patch_size, patch_size)` specifying the spatial basis.
            Entries are in row-major order: index ``y * n_tiles + x``.

        Raises:
            ValueError: If ``n_spatial_basis`` is not a perfect square.
        '''
        n_tiles = int(round(np.sqrt(self.n_spatial_basis)))
        if n_tiles ** 2 != self.n_spatial_basis:
            raise ValueError(
                'n_spatial_basis must be a perfect square, got {}'.format(
                    self.n_spatial_basis))
        n_pixels = self.patch_size
        # Centers are n_pixels / (n_tiles + 1) apart, with the same gap
        # between the outermost centers and the patch edges.
        gap = n_pixels / (n_tiles + 1.)
        centers = np.linspace(start=-n_pixels / 2. + gap,
                              stop=n_pixels / 2. - gap,
                              num=n_tiles)
        # y outer, x inner: matches the (x, y) decoding of the spatial
        # index in design_prior_covariance.
        return [self.make_2d_gaussian(center=(centers[x], centers[y]))
                for y in range(n_tiles)
                for x in range(n_tiles)]

    def make_cosine_basis(self):
        '''Makes a spatial cosine and sine basis.

        Each pixel holds the cosine (first entry) or sine (second entry) of
        its polar angle around the patch center.

        Returns:
            list: A list of two 2D arrays of size
            :data:`(patch_size, patch_size)` specifying the spatial basis.
        '''
        patch_size = self.patch_size
        rows = np.arange(patch_size)[:, np.newaxis]
        cols = np.arange(patch_size)[np.newaxis, :]
        # Row index grows downwards, so flip it to get an upward y axis.
        theta = np.arctan2(patch_size / 2 - rows, cols - patch_size / 2)
        return [np.cos(theta), np.sin(theta)]

    def visualize_gaussian_basis(self, spatial_basis, color='Greys',
                                 show=True):
        '''Plots spatial basis functions in a tile of images.

        Args:
            spatial_basis (list): A list where each entry is a 2D array of
                size :data:`(patch_size, patch_size)` specifying the spatial
                basis.
            color (str): The color map for the figure.
            show (bool): Whether or not to show the image when it is
                plotted.
        '''
        # ceil keeps the grid large enough for non-square basis counts.
        n_tiles = int(np.ceil(np.sqrt(len(spatial_basis))))
        plt.figure(figsize=(7, 7))
        for i, basis_image in enumerate(spatial_basis):
            plt.subplot(n_tiles, n_tiles, i + 1)
            plt.imshow(basis_image, cmap=color)
            plt.axis('off')
        if show:
            plt.show()

    def project_to_spatial_basis(self, image, spatial_basis):
        '''Projects a given image into a spatial basis.

        Args:
            image (numpy array): Image that must be projected into the
                spatial basis; 2D array of size
                :data:`(patch_size, patch_size)`.
            spatial_basis (list): A list where each entry is a 2D array of
                size :data:`(patch_size, patch_size)` specifying the spatial
                basis.

        Returns:
            numpy array: A 1D array of coefficients for the projection, one
            per basis function (element-wise product summed over pixels).
        '''
        return np.array([np.sum(basis_image * image)
                         for basis_image in spatial_basis], dtype=float)

    def make_image_from_spatial_basis(self, basis, weights):
        '''Recovers an image from a basis given a set of weights.

        Args:
            basis (list): A list where each entry is a 2D array of size
                :data:`(patch_size, patch_size)` specifying the spatial
                basis.
            weights (numpy array): A 1D array of coefficients, at least one
                per basis function.

        Returns:
            numpy array: 2D array of size :data:`(patch_size, patch_size)`,
            the weighted sum of the basis images.
        '''
        image = np.zeros(basis[0].shape)
        for b, basis_image in enumerate(basis):
            image += weights[b] * basis_image
        return image

    def make_raised_cosine_temporal_basis(self, time_points, centers,
                                          widths):
        '''Makes a series of raised cosine temporal basis functions.

        Each function is ``0.5 * (cos(arg) + 1)`` with
        ``arg = (t - center) * pi / (2 * width)`` clipped to
        ``[-pi, pi]``, so it is zero when ``|t - center| >= 2 * width``.

        Args:
            time_points (numpy array or list): Time points at which the
                basis functions are computed.
            centers (numpy array or list): Coordinates at which each basis
                function is centered; 1D array of size
                :data:`(n_temporal_basis)`.
            widths (numpy array or list): Widths, one per basis function;
                1D array of size :data:`(n_temporal_basis)`.

        Returns:
            numpy array: 2D array of size :data:`(n_timepoints, n_basis)`.
        '''
        time_points = np.asarray(time_points, dtype=float)
        temporal_basis = list()
        for idx, center in enumerate(centers):
            arg_cos = (time_points - center) * np.pi / widths[idx] / 2.
            # Clip so the function stays at zero outside one full period.
            arg_cos = np.clip(arg_cos, -np.pi, np.pi)
            temporal_basis.append(0.5 * (np.cos(arg_cos) + 1.))
        return np.transpose(np.array(temporal_basis))

    def convolve_with_temporal_basis(self, design_matrix, temporal_basis):
        '''Convolves a design matrix with a temporal basis function.

        Convolves each column of the design matrix with each temporal basis
        function (``np.convolve`` with ``mode='same'``).

        Args:
            design_matrix (numpy array): 2D array of size
                :data:`(n_samples, n_features)`.
            temporal_basis (numpy array): 2D array of size
                :data:`(n_timepoints, n_basis)`, as returned by
                :meth:`make_raised_cosine_temporal_basis`.

        Returns:
            numpy array: 2D array of size
            :data:`(n_samples, n_features * n_basis)`; the columns for each
            feature are consecutive (feature-major order). This assumes
            ``n_timepoints <= n_samples``; otherwise ``mode='same'`` yields
            ``n_timepoints`` rows.
        '''
        n_temporal_basis = temporal_basis.shape[1]
        n_features = design_matrix.shape[1]
        convolved_design_matrix = [
            np.convolve(design_matrix[:, feat], temporal_basis[:, b],
                        mode='same')
            for feat in range(n_features)
            for b in range(n_temporal_basis)]
        return np.transpose(np.array(convolved_design_matrix))

    def design_prior_covariance(self, sigma_temporal=2., sigma_spatial=5.):
        '''Designs a prior covariance matrix for STRF estimation.

        Args:
            sigma_temporal (float): Standard deviation of temporal prior
                covariance.
            sigma_spatial (float): Standard deviation of spatial prior
                covariance.

        Returns:
            numpy array: 2D array of size
            :data:`(n_spatial_basis * n_temporal_basis,
            n_spatial_basis * n_temporal_basis)`. Rows and columns are
            ordered so that all temporal bases are consecutive for each
            spatial basis, and the matrix is scaled so its maximum is 1.
        '''
        n_spatial_basis = self.n_spatial_basis
        n_temporal_basis = self.n_temporal_basis
        n_features = n_temporal_basis * n_spatial_basis
        # NOTE(review): the (x, y) grid decoding assumes n_spatial_basis is
        # a perfect square, as for the Gaussian basis.
        side = np.sqrt(n_spatial_basis)

        # Feature index i = s * n_temporal_basis + t.
        feature_idx = np.arange(n_features)
        s = np.floor(feature_idx / float(n_temporal_basis))
        t = feature_idx % n_temporal_basis
        # Spatial index -> row-major (x, y) grid coordinates.
        x = s % side
        y = np.floor(s / side)

        # Pairwise squared distances via broadcasting (symmetric by design).
        dx2 = (x[:, np.newaxis] - x[np.newaxis, :]) ** 2
        dy2 = (y[:, np.newaxis] - y[np.newaxis, :]) ** 2
        dt2 = (t[:, np.newaxis] - t[np.newaxis, :]) ** 2

        sp_covariance = np.exp(-1. / (sigma_spatial ** 2) * (dx2 + dy2))
        te_covariance = np.exp(-1. / (sigma_temporal ** 2) * dt2)
        prior_covariance = sp_covariance * te_covariance
        prior_covariance = 1. / np.max(prior_covariance) * prior_covariance
        return prior_covariance