#!/usr/bin/env python
# -*- coding: utf-8 -*-
## x-y-t receptive field maps
import numpy as np
import matplotlib.pyplot as plt
import os
from scipy.ndimage.filters import gaussian_filter
# Root directory containing the simulation results read by loadPST
data_path = "/home/pablo/Desktop/Biophysical_thalamocortical_system/thalamocortical/results/"
##################
### Parameters ###
##################
# Neurons per side of each square layer (all layers except INs); PST arrays
# are allocated with N*N rows.
N = 10.0
# Stimulus: spatial impulse response probed with flashing spots laid out on
# an Npoints x Npoints grid, numbered row-major (row = pos // Npoints).
Npoints = 40.0
raw_stimulus = np.arange(0.0, Npoints*Npoints, 1.0)
# To speed up computations, keep only the spots inside a central square.
# The borders are expressed in neuron units (each between 0 and N).
row_top_limit = 3.0
row_bottom_limit = 7.0
col_left_limit = 3.0
col_right_limit = 7.0
_spot_spacing = N/Npoints  # distance between neighbouring spots, in neuron units
stimulus = [
    pos for pos in raw_stimulus
    if row_top_limit <= int(pos/Npoints)*_spot_spacing <= row_bottom_limit
    and col_left_limit <= np.remainder(pos, Npoints)*_spot_spacing <= col_right_limit
]
# Cell population whose PSTs are analysed
#ID = "retinaON"
#ID = "RC-ON"
ID = "PY_v-ON"
# Simulation parameters
tsim = 300.0
binsize = 5.0
numbertrials = 100.0
# Combination
cc = "comb0"
selected_cell = 55
##################
### Plots ########
##################
# Load PST
def loadPST(stim, N, tsim, binsize, neuron, add_path):
    """Load the averaged PST histograms stored for one stimulus position.

    The file read is ``data_path + add_path + "/stim" + str(stim) + "/PST" + neuron``.
    Each line is a comma-separated list of per-bin values; line ``i`` fills
    row ``i`` of the result, and only the first ``int(tsim/binsize)`` values
    of each line are used.

    Parameters
    ----------
    stim : stimulus identifier, formatted with ``str()`` into the path.
    N : neurons per side; the result has ``int(N*N)`` rows.
    tsim : simulation length (same time unit as ``binsize``).
    binsize : PST bin width.
    neuron : suffix appended to ``"PST"`` in the file name (may be "").
    add_path : sub-directory of ``data_path`` holding the ``stim*`` folders.

    Returns
    -------
    numpy.ndarray of shape ``(int(N*N), int(tsim/binsize))``. Rows for which
    the file has no line stay zero.

    Raises
    ------
    OSError/IOError if the file cannot be opened. IndexError if the file has
    more than ``int(N*N)`` lines or a line has fewer than
    ``int(tsim/binsize)`` values. ValueError on non-numeric fields.
    """
    n_bins = int(tsim/binsize)
    PST_avg = np.zeros((int(N*N), n_bins))
    filename = data_path + add_path + "/stim" + str(stim) + "/PST" + neuron
    # Context manager releases the handle even if parsing fails.
    with open(filename, "r") as pst_file:
        for row, line in enumerate(pst_file):
            fields = line.rstrip('\n').split(',')
            # Read exactly the allocated bins: np.arange(0, tsim/binsize)
            # would produce one extra (out-of-range) index whenever
            # tsim/binsize is not an integer.
            PST_avg[row, :] = [float(fields[col]) for col in range(n_bins)]
    return PST_avg
# Create PSTs
def createPST(cellID, stimulus, N, tsim, binsize, comb, type):
    """Load the PST array of a cell population for the given stimuli.

    Retina populations ("retinaON"/"retinaOFF") are read from
    ``retina/<type>/<ON|OFF>`` with no neuron suffix. Any other ``cellID`` is
    read from ``xt_plots/<type>/<comb>`` with ``cellID`` as the file suffix.

    Every stimulus in ``stimulus`` is loaded via ``loadPST``, but only the
    array for the first one is returned. NOTE(review): callers in this file
    always pass a one-element list. An empty list raises IndexError.
    """
    if cellID == "retinaON":
        root, leaf, neuron_tag = "retina/", "ON", ""
    elif cellID == "retinaOFF":
        root, leaf, neuron_tag = "retina/", "OFF", ""
    else:
        root, leaf, neuron_tag = "xt_plots/", comb, cellID
    histograms = [
        loadPST(s, N, tsim, binsize, neuron_tag, root + type + "/" + leaf)
        for s in stimulus
    ]
    return histograms[0]
# xt map
def xt_map(intervals, type):
    """Build a 2-D spatial response map of ``selected_cell`` for one condition.

    For every flashed-spot position in the module-level ``stimulus`` list
    (assumed to form a square grid, stored row-major), load the PST array via
    ``createPST``. Sum row ``selected_cell`` over the bin range
    ``intervals[0]:intervals[1]`` and divide by ``numbertrials`` times the
    number of bins.

    Parameters
    ----------
    intervals : sequence of two ints
        Start (inclusive) and stop (exclusive) PST bin indices.
    type : str
        Result sub-directory of the stimulus condition
        (e.g. "RF_1_matched_old").

    Returns
    -------
    numpy.ndarray of shape ``(side, side)`` with ``side = sqrt(len(stimulus))``.

    Raises
    ------
    ValueError if ``len(stimulus)`` is not a perfect square.
    """
    n_stim = len(stimulus)
    # np.zeros needs integer dimensions; np.sqrt returns a float.
    side = int(round(np.sqrt(n_stim)))
    if side * side != n_stim:
        raise ValueError("stimulus must contain a square number of positions, got %d" % n_stim)
    # Retina results are not organised by combination, so no comb sub-dir.
    comb = "" if ID in ("retinaON", "retinaOFF") else cc
    start, stop = int(intervals[0]), int(intervals[1])
    norm = numbertrials*(stop - start)
    data = np.zeros((side, side))
    for n, stim in enumerate(stimulus):
        # Row-major layout: stimulus n sits at (n // side, n % side).
        row, col = divmod(n, side)
        PST = createPST(ID, [stim], N, tsim, binsize, comb, type)
        data[row, col] = np.sum(PST[int(selected_cell), start:stop])/norm
    return data
# Spatiotemporal RF
def spatiotemporalRF(intervals, vertical):
    """Build a space-time receptive-field map from successive time windows.

    For each consecutive pair ``intervals[n], intervals[n+1]`` (times in the
    same unit as ``binsize``), the spatial maps of the "RF_1_matched_old" and
    "RF_2_matched_old" conditions are computed with ``xt_map`` over the
    corresponding PST bins. Their difference is then collapsed to a 1-D
    spatial profile by averaging over one spatial axis.

    Parameters
    ----------
    intervals : sequence of window boundary times.
    vertical : bool
        If true, element ``y`` of each profile is the mean of row ``y`` of the
        difference map (averaged across columns). Otherwise it is the mean of
        column ``y`` (averaged across rows).

    Returns
    -------
    numpy.ndarray of shape ``(len(intervals)-1, side)`` with
    ``side = sqrt(len(stimulus))``.
    """
    # Integer side: NumPy rejects float shapes and float indices.
    side = int(round(np.sqrt(len(stimulus))))
    data = np.zeros((len(intervals)-1, side))
    # vertical -> collapse columns (axis 1); otherwise collapse rows (axis 0)
    axis = 1 if vertical else 0
    for n in range(len(intervals)-1):
        print("time = %s" % intervals[n])
        bins = [int(intervals[n]/binsize), int(intervals[n+1]/binsize)]
        data_1 = xt_map(bins, "RF_1_matched_old")
        data_2 = xt_map(bins, "RF_2_matched_old")
        aux_data = data_1 - data_2
        data[n, :] = np.sum(aux_data, axis=axis)/side
    return data
### Plotting ###
## Topographical responses
# Time window (same unit as tsim/binsize) over which responses are averaged;
# converted to PST bin indices for xt_map.
start = 200.0
stop = 220.0
intervals = [int(start/binsize),int(stop/binsize)]
print("computing ON response")
data_1 = xt_map(intervals,"RF_1_matched_old")
print("computing OFF response")
data_2 = xt_map(intervals,"RF_2_matched_old")
# Net map: ON-condition response minus OFF-condition response
data = data_1 - data_2
# Spatial axes for the contour plot: 17 samples from -2 to 2. Their length
# must equal the side of the stimulus grid (sqrt(len(stimulus)), which is 17
# with the current row/col limits), otherwise contourf rejects the shapes.
# NOTE(review): presumably offsets from the centre of the selected square in
# neuron units (spot spacing N/Npoints = 0.25) -- confirm.
x = np.arange(-2.0, 2.25, 0.25)
y = np.arange(-2.0, 2.25, 0.25)
#x = np.arange(Npoints)
#y = np.arange(Npoints)
X, Y = np.meshgrid(x, y)
### Spatiotemporal response (alternative analysis: uncomment to plot a
### space-time map instead of the topographical map above)
#start = 100.0
#stop = 300.0
#intervals = np.arange(start,stop,10.0)
#data = spatiotemporalRF(intervals,False)
#x = np.arange(-2.0, 2.25, 0.25)
#y = intervals[0:len(intervals)-1]
#X, Y = np.meshgrid(x, y)
#### Contour plot
fig = plt.figure()
Gax = plt.subplot2grid((1,1), (0,0))
# Smooth the map with a Gaussian filter before contouring (sigma is in
# array-index units, i.e. grid samples).
sigma = 1.5
data2 = gaussian_filter(data, sigma)
#CS = plt.contourf(X, Y, data2, 11, cmap=plt.cm.coolwarm)
# Hand-picked contour levels. NOTE(review): presumably tuned to this dataset;
# values outside [-13.5, 4.5] are left unfilled by contourf.
CS = plt.contourf(X, Y, data2, [-13.5,-13.,-11.3,-7.5,-6.,-4.5,-3.,-1.5,0.,1.5,3.,4.5], cmap=plt.cm.coolwarm)
# Hide axis ticks and colour-bar ticks
plt.setp(Gax, yticks=[])
plt.setp(Gax, xticks=[])
cbar = plt.colorbar(CS,ticks=[])
#cbar = plt.colorbar(CS)
#plt.rcParams.update({'font.size': 22})
#### x/y profiles (row/column 8 is the centre of a 17x17 map)
#fig = plt.figure()
#plt.plot(np.arange(-2.0, 2.25, 0.25),data[8,:])
#plt.plot(np.arange(-2.0, 2.25, 0.25),data[:,8])
plt.show()