from neuron import h
import numpy as np
import pylab as p
import time
import sys
import os
import h5py as h5
import itertools as it
from DS2M0Purk import DS2M0Purk
def save_dict(fid, group, data):
for key,value in data.iteritems():
if isinstance(value, dict):
new_group = fid.create_group(group.name + '/' + key)
save_dict(fid, new_group, value)
elif type(value) in (int,float,tuple,str):
group.attrs.create(key,value)
else:
group.create_dataset(key, data=np.array(value), compression='gzip', compression_opts=9)
def save_h5_file(filename, **kwargs):
with h5.File(filename, 'w') as fid:
save_dict(fid, fid, kwargs)
def make_output_filename(prefix='', extension='.out'):
filename = prefix
if prefix != '' and prefix[-1] != '_':
filename = filename + '_'
now = time.localtime(time.time())
filename = filename + '%d%02d%02d-%02d%02d%02d' % \
(now.tm_year, now.tm_mon, now.tm_mday, now.tm_hour, now.tm_min, now.tm_sec)
if extension[0] != '.':
extension = '.' + extension
suffix = ''
k = 0
while os.path.exists(filename + suffix + extension):
k = k+1
suffix = '_%d' % k
return filename + suffix + extension
def somatic_current_injection(cell, amplitude=3., tbefore=100., tstim=500., tafter=100., celsius=37.):
Vrest = -68.
print('Inserting stimulus...')
stim = h.IClamp(cell.soma(0.5))
stim.delay = tbefore
stim.dur = tstim
stim.amp = float(amplitude)
print('Setting up recorders...')
rec = {}
for lbl in 't','vsoma','vdend','spikes':
rec[lbl] = h.Vector()
rec['t'].record(h._ref_t)
rec['vsoma'].record(cell.soma(0.5)._ref_v)
rec['vdend'].record(cell.dendrites[1511](0.5)._ref_v)
apc = h.APCount(cell.soma(0.5))
apc.record(rec['spikes'])
print('Setting up the simulation...')
h.load_file('stdrun.hoc')
h.t = 0.
h.dt = 0.02
h.v_init = Vrest
h.celsius = celsius
h.tstop = tbefore + tstim + tafter
print('Running model...')
h.finitialize(h.v_init)
h.run()
return np.array(rec['t'])*1e-3,np.array(rec['vsoma']),np.array(rec['vdend']),np.array(rec['spikes'])*1e-3
def synaptic_activation(cell, exc_rate, inh_rate, tstop=5000., seed=None, celsius=37.):
if not cell.with_synapses:
raise Exception('No synapses present')
cell.add_offset_current()
if seed is None:
np.random.seed(int(time.time()))
else:
np.random.seed(seed)
print('Computing the activation times of the synapses...')
if exc_rate > 0:
nspikes = (tstop/1e3)*exc_rate
for syn in cell.synapses['granule_cells']:
isi = -np.log(np.random.uniform(size=nspikes)) / exc_rate * 1e3
cell.set_presynaptic_spike_times(syn, np.cumsum(isi))
if inh_rate > 0:
nspikes = (tstop/1e3)*inh_rate
for syn in cell.synapses['stellate_cells']:
isi = -np.log(np.random.uniform(size=nspikes)) / inh_rate * 1e3
cell.set_presynaptic_spike_times(syn, np.cumsum(isi))
print('Setting up recorders...')
rec = {}
for lbl in 't','vsoma','vdend','spikes':
rec[lbl] = h.Vector()
rec['t'].record(h._ref_t)
rec['vsoma'].record(cell.soma(0.5)._ref_v)
rec['vdend'].record(cell.dendrites[1511](0.5)._ref_v)
apc = h.APCount(cell.soma(0.5))
apc.record(rec['spikes'])
print('Setting up the simulation...')
h.load_file('stdrun.hoc')
h.t = 0.
h.dt = 0.02
h.v_init = -69.
h.celsius = celsius
h.tstop = tstop
print('Running model...')
h.finitialize(h.v_init)
h.run()
return np.array(rec['t'])*1e-3,np.array(rec['vsoma']),np.array(rec['vdend']),np.array(rec['spikes'])*1e-3
def compute_PRC(cell, firing_rate, exc_rate=None, inh_rate=None, seed=None, which_isi=0, ttran=3000., tstop=5000., ntrials=100, pulse_amp=0.5, pulse_dur=0.5, celsius=37., jittered_isi=False):
if cell.with_synapses: # PRC with synaptic activation
cell.add_offset_current()
np.random.seed(seed)
print('Computing the activation times of the synapses...')
presyn_spike_times = {}
if exc_rate > 0:
nspikes = (tstop/1e3)*exc_rate
presyn_spike_times['exc'] = np.zeros((len(cell.synapses['granule_cells']),nspikes))
for i,syn in enumerate(cell.synapses['granule_cells']):
isi = -np.log(np.random.uniform(size=nspikes)) / exc_rate * 1e3
cell.set_presynaptic_spike_times(syn, np.cumsum(isi))
presyn_spike_times['exc'][i,:] = np.cumsum(isi)
if inh_rate > 0:
nspikes = (tstop/1e3)*inh_rate
presyn_spike_times['inh'] = np.zeros((len(cell.synapses['stellate_cells']),nspikes))
for i,syn in enumerate(cell.synapses['stellate_cells']):
isi = -np.log(np.random.uniform(size=nspikes)) / inh_rate * 1e3
cell.set_presynaptic_spike_times(syn, np.cumsum(isi))
presyn_spike_times['inh'][i,:] = np.cumsum(isi)
else:
print('Not adding any stimulation to the cell.')
pulse = h.IClamp(cell.soma(0.5))
pulse.delay = 1e8
pulse.dur = pulse_dur
pulse.amp = pulse_amp
print('Setting up recorders...')
rec = {}
for lbl in 't','vsoma','spikes':
rec[lbl] = h.Vector()
rec['t'].record(h._ref_t)
rec['vsoma'].record(cell.soma(0.5)._ref_v)
apc = h.APCount(cell.soma(0.5))
apc.record(rec['spikes'])
print('Setting up the simulation...')
h.load_file('stdrun.hoc')
h.t = 0.
h.dt = 0.02
h.v_init = -69.
h.celsius = celsius
h.tstop = ttran + 100
h.finitialize(h.v_init)
print('Running the model...')
sys.stdout.flush()
h.run()
found = False
while h.tstop <= tstop:
t = np.array(rec['t'])
v = np.array(rec['vsoma'])
spikes = np.array(rec['spikes'])
spikes = spikes[spikes>ttran]
isi = np.diff(spikes)
if firing_rate <= 0:
firing_rate = 1000./isi[0]
try:
idx = np.where((1000./isi >= firing_rate) & (1000./isi <= firing_rate+1.))[0][which_isi]
found = True
break
except:
print('No spikes @ %g Hz in the first %g ms. Continuing...' % (firing_rate,h.tstop))
h.tstop += 100.
h.continuerun(h.tstop)
if not found:
print('Unable to find the right ISI in the first %g ms.' % h.tstop)
print np.sort(1000./isi)
#ax = p.subplot(211)
#p.plot(t,v,'k')
#p.subplot(212, sharex=ax)
#p.plot(spikes[1:],1000./isi,'ko')
#p.show()
sys.exit(0)
tbefore = 2.
tafter = 5.
t0 = spikes[idx]
t1 = spikes[idx+1]
idx = int(t0/h.dt) + np.argmax(v[int(t0/h.dt):int((t0+2)/h.dt)])
t0 = t[idx]
idx = int(t1/h.dt) + np.argmax(v[int(t1/h.dt):int((t1+2)/h.dt)])
t1 = t[idx]
print('Good ISI found between %g and %g ms.' % (t0,t1))
if cell.with_synapses and jittered_isi:
print('Adding all other objects...')
rs = np.random.RandomState(int(time.time()))
for j in range(1,ntrials):
np.random.seed(seed)
for k,syn in enumerate(it.chain(cell.synapses['granule_cells'],cell.synapses['stellate_cells'])):
s,st,nc = cell.insert_synapse(syn['sec'], syn['tau'], syn['E'], 0.)
if syn in cell.synapses['granule_cells']:
isi = -np.log(np.random.uniform(size=len(syn['spike_times'][0]))) / exc_rate * 1e3
else:
isi = -np.log(np.random.uniform(size=len(syn['spike_times'][0]))) / inh_rate * 1e3
spks = np.cumsum(isi)
idx, = np.where((spks>t0) & (spks<t1))
if len(idx) > 0:
isi[idx] += rs.uniform(size=len(idx))
spks = np.cumsum(isi)
vec = h.Vector(spks)
st.play(vec)
syn['syn'].append(s)
syn['stim'].append(st)
syn['conn'].append(nc)
syn['spike_times'].append(vec)
print('Running the model again...')
sys.stdout.flush()
rec['t'].resize(0)
rec['vsoma'].resize(0)
rec['spikes'].resize(0)
apc.n = 0
h.t = 0
h.tstop = t0 - tbefore
h.finitialize(h.v_init)
h.run()
print('Saving the state...')
ss = h.SaveState()
ss.save()
h.dt = 0.005
nsamples = np.round((t1-t0+tbefore+tafter) / h.dt)
V = np.zeros((ntrials,nsamples))
spike_times = np.nan + np.zeros((ntrials,2))
perturbation_times = np.zeros(ntrials)
for i in range(ntrials):
sys.stdout.write('\rTrial [%02d/%02d] ' % (i+1,ntrials))
sys.stdout.flush()
ss.restore()
rec['t'].resize(0)
rec['vsoma'].resize(0)
rec['spikes'].resize(0)
apc.n = 0
if cell.with_synapses and jittered_isi:
for syn in it.chain(cell.synapses['granule_cells'],cell.synapses['stellate_cells']):
for j in range(ntrials):
syn['conn'][j].weight[0] = 0
syn['conn'][i].weight[0] = syn['w']
if i > 0:
pulse.delay = t0 + i*(t1+5-t0)/ntrials
perturbation_times[i] = pulse.delay
h.continuerun(t1+tafter)
idx, = np.where(np.array(rec['spikes']) > t0-tbefore)
try:
spike_times[i,:] = np.array(rec['spikes'])[idx[0]:idx[0]+2]
except:
pass
V[i,:] = np.array(rec['vsoma'])[:nsamples]
t = np.array(rec['t'])[idx[0]:idx[0]+nsamples]
sys.stdout.write('\n')
#p.figure()
#for i in range(ntrials):
# p.plot(t,V[i,:])
#p.show()
if cell.with_synapses:
return t,V,perturbation_times,spike_times,presyn_spike_times,1000./(t1-t0)
return t,V,perturbation_times,spike_times,1000./(t1-t0)
def usage():
print('')
print('This script can be used to compute the PRC of the DeSchutter-Bower Purkinje cell model.')
print('')
print('Usage:')
print('')
print(' %s iclamp [amplitude] Simulate the injection of a constant current step lasting 10 seconds.' % os.path.basename(sys.argv[0]))
print(' The default amplitude is 0.1 nA.')
print(' %s synapses Simulate the activation of synapses across the dendritic tree, lasting 10 seconds.' % os.path.basename(sys.argv[0]))
print(' %s PRC_syn Compute the PRC during synaptic activation.' % os.path.basename(sys.argv[0]))
print(' %s PRC_iclamp Compute the PRC during current clamp stimulation.' % os.path.basename(sys.argv[0]))
print(' %s [-h|--help] Print this help message and exit.' % os.path.basename(sys.argv[0]))
print('')
print('Author: Daniele Linaro - danielelinaro@gmail.com')
print('')
def main():
if sys.argv[1] in ('-h','--help'):
usage()
sys.exit(0)
elif sys.argv[1] == 'iclamp':
cell = DS2M0Purk(g_granule=None,g_stellate=None) # no synapses
# Somatic current injection
try:
amplitude = float(sys.argv[2])
except:
amplitude = 0.1
tbefore = 0
tstim = 10000.
tafter = 0.
celsius = 28.
t,vsoma,vdend,spikes = somatic_current_injection(cell, amplitude, tbefore, tstim, tafter, celsius)
tbefore *= 1e-3
tstim *= 1e-3
print('Firing rate: {0} Hz.'.format(len(np.intersect1d(np.where(spikes>=tbefore+1)[0],\
np.where(spikes<=tbefore+tstim)[0]))/(tstim-1)))
elif sys.argv[1] == 'synapses':
# Dendritic synapse activation
cell = DS2M0Purk(g_granule=0.7e-3, g_stellate=[7000.,1400.])
t,vsoma,vdend,spikes = synaptic_activation(cell,exc_rate=35.,inh_rate=2, tstop=10000., seed=5061983)
save_h5_file(make_output_filename('synaptic_activation_','.h5'), dt=t[1]-t[0], V=vsoma)
elif sys.argv[1] == 'PRC_syn':
# Computation of the PRC
g_granule = 0.7e-3
g_stellate = [7000.,1400.]
cell = DS2M0Purk('PM9',g_granule, g_stellate)
exc_rate = 35.
inh_rate = 2.
try:
firing_rate = float(sys.argv[2])
except:
firing_rate = 85.
print('Target firing rate: %.1f Hz.' % firing_rate)
try:
pulse_amp = float(sys.argv[3])
except:
pulse_amp = 0.5
print('Pulse amplitude: %.2f nA.' % pulse_amp)
which_isi = 0
ttran = 5000.
tstop = 10000.
ntrials = 50
seed = 1416394853 # int(time.time())
pulse_dur = 0.5
temperature = 37.
t,V,perturbation_times,spike_times,presyn_spike_times,actual_firing_rate = \
compute_PRC(cell=cell, firing_rate=firing_rate, exc_rate=exc_rate, inh_rate=inh_rate, seed=seed, which_isi=which_isi, ttran=ttran, \
tstop=tstop, ntrials=ntrials, pulse_amp=pulse_amp, pulse_dur=pulse_dur, celsius=temperature, jittered_isi=False)
save_h5_file(make_output_filename('prc_','.h5'),dt=t[1]-t[0],V=V,perturbation_times=perturbation_times,
spike_times=spike_times,exc_rate=exc_rate,inh_rate=inh_rate,firing_rate=actual_firing_rate,
which_isi=which_isi,ttran=ttran,tstop=tstop,ntrials=ntrials,g_granule=g_granule,g_stellate=g_stellate,
presyn_spike_times=presyn_spike_times,seed=seed,pulse_amp=pulse_amp,pulse_dur=pulse_dur, temperature=temperature)
elif sys.argv[1] == 'PRC_iclamp':
# Computation of the PRC
cell = DS2M0Purk('PM10',None,None)
try:
firing_rate = float(sys.argv[2])
except:
firing_rate = 85.
print('Target firing rate: %.1f Hz.' % firing_rate)
try:
offset = float(sys.argv[3])
except:
offset = 0.2
print('DC offset: %.2f nA.' % offset)
try:
pulse_amp = float(sys.argv[4])
except:
pulse_amp = 0.5
print('Pulse amplitude: %.2f nA.' % pulse_amp)
cell.add_offset_current(offset)
which_isi = 0
ttran = 3000.
tstop = 3100.
ntrials = 50
pulse_dur = 0.5
t,V,perturbation_times,spike_times,actual_firing_rate = \
compute_PRC(cell, firing_rate=firing_rate, exc_rate=None, inh_rate=None, seed=None, which_isi=which_isi, \
ttran=ttran, tstop=tstop, ntrials=ntrials, pulse_amp=pulse_amp, pulse_dur=pulse_dur, celsius=28)
save_h5_file(make_output_filename('prc_','.h5'),dt=t[1]-t[0],V=V,perturbation_times=perturbation_times,
spike_times=spike_times,firing_rate=actual_firing_rate,
which_isi=which_isi,ttran=ttran,tstop=tstop,ntrials=ntrials,
pulse_amp=pulse_amp,pulse_dur=pulse_dur)
try:
print np.sort(1./np.diff(spikes))
p.figure()
p.plot(t,vsoma,'k',label='Soma')
p.plot(t,vdend,'r',label='Dendrite')
p.xlabel('Time (s)')
p.ylabel('Membrane voltage (mV)')
p.legend(loc='best')
p.figure()
p.plot(spikes[1:],1./np.diff(spikes),'ko')
p.xlabel('Time (s)')
p.ylabel('1/ISI (s^-1)')
p.show()
except:
pass
if __name__ == '__main__':
main()