"""
(C) Asaph Zylbertal 01.03.2015, HUJI, Jerusalem, Israel

Basic model functions: stimulation and recording
If you use this model in your research please cite:
****************

"""

import numpy as np
import neuron

class mitral_neuron(object):
    """Base helpers for stimulating, recording and running a NEURON cell model.

    No ``__init__`` is defined here. The methods read attributes that must be
    set up elsewhere (presumably by a subclass that builds the morphology):

    * ``root``     -- section into whose middle (0.5) current is injected
    * ``cv``       -- NEURON CVode object (``active()`` / ``re_init()``)
    * ``E``        -- initial membrane potential passed to ``finitialize``
    * ``sim_time`` -- simulation duration used by ``run_model`` and DC/noise stims
    """

    def __del__(self):
        """Drop section references and delete all NEURON sections."""
        self.soma = None
        self.basl = None
        self.apic1 = None
        self.tuft1 = None
        self.apic2 = None
        self.tuft2 = None
        self.hlck = None
        self.iseg = None
        self.axon = None
        self.root = None

        # NOTE: "forall delete_section()" deletes *every* section currently
        # defined in NEURON, not only the sections belonging to this object.
        neuron.h("forall delete_section()")

    def init_square_stim(self, amp, delay=100, dur=400):
        """Attach a single square current pulse to the middle of ``self.root``.

        amp   -- pulse amplitude (IClamp ``amp``).
        delay -- pulse onset (IClamp ``delay``); default 100, as before.
        dur   -- pulse duration (IClamp ``dur``); default 400, as before.

        The IClamp is kept in ``self.stim`` so that a Python reference to the
        point process survives after this method returns.
        """
        stim = neuron.h.IClamp(self.root(0.5))
        stim.delay = delay
        stim.amp = amp
        stim.dur = dur
        self.stim = stim

    @staticmethod
    def _pulse_delays(delay, duration, freq):
        """Return the onset time of each pulse in a train.

        ``int(duration * freq)`` pulses are produced, the first at ``delay`` and
        then every ``1 / freq``. ``duration`` and ``1 / freq`` must use the same
        time unit. ``float`` avoids integer truncation under Python 2.
        """
        pulse_num = int(duration * freq)
        return [delay + i / float(freq) for i in range(pulse_num)]

    def init_train_stim(self, delay, duration, freq, pulse_duration, amp, dc,
                        limit_dc=False, noise_stdv=0):
        """Attach a pulse train, a DC current and optional noise to ``self.root``.

        delay          -- onset of the first pulse.
        duration       -- train duration; ``int(duration*freq)`` pulses are made.
        freq           -- pulse rate (pulses per time unit of ``duration``).
        pulse_duration -- duration of each pulse.
        amp            -- amplitude of each pulse.
        dc             -- amplitude of a constant background current.
        limit_dc       -- if True the DC current spans only [delay, delay+duration];
                          otherwise it spans the whole run (0 .. self.sim_time).
        noise_stdv     -- if > 0, add an IClamp whose amplitude follows Gaussian
                          noise (mean 0, this std) sampled at ``int(sim_time)``
                          evenly spaced times over [0, sim_time].

        All IClamps are stored, in creation order (pulses, DC, noise), in
        ``self.stim``. The noise vectors are kept in ``self.nstim_vec`` /
        ``self.nstim_t``.
        """
        stim = []
        for pulse_delay in self._pulse_delays(delay, duration, freq):
            pulse = neuron.h.IClamp(self.root(0.5))
            pulse.delay = pulse_delay
            pulse.amp = amp
            pulse.dur = pulse_duration
            stim.append(pulse)

        # Referenced by object rather than by the pulse-loop index, so a train
        # with zero pulses no longer raises NameError.
        dc_stim = neuron.h.IClamp(self.root(0.5))
        dc_stim.amp = dc
        if limit_dc:
            dc_stim.delay = delay
            dc_stim.dur = duration
        else:
            dc_stim.delay = 0
            dc_stim.dur = self.sim_time
        stim.append(dc_stim)

        if noise_stdv > 0:
            noise_stim = neuron.h.IClamp(self.root(0.5))
            noise_stim.delay = 0
            noise_stim.dur = self.sim_time
            stim.append(noise_stim)

            # numpy needs an integer sample count; sim_time may be a float.
            n_samples = int(self.sim_time)
            noise_t = np.linspace(0, self.sim_time, n_samples)
            # Keep references to both the time vector and the amplitude vector
            # so neither is garbage-collected while Vector.play uses them.
            self.nstim_t = neuron.h.Vector(noise_t)
            self.nstim_vec = neuron.h.Vector(
                np.random.normal(0, noise_stdv, n_samples))
            self.nstim_vec.play(noise_stim._ref_amp, self.nstim_t)

        self.stim = stim

    def init_recording(self, seg):
        """Start recording voltage of ``seg`` into ``self.rec_v`` and time into ``self.rec_t``."""
        self.rec_v = self.init_vec_recording(seg._ref_v)
        self.rec_t = self.init_vec_recording(neuron.h._ref_t)

    def init_vec_recording(self, ref):
        """Return a new NEURON Vector recording the variable behind pointer ``ref``."""
        vec = neuron.h.Vector()
        vec.record(ref)
        return vec

    def stop_recording(self):
        """Delete the recording vectors ``rec_v``, ``rec_t`` and ``rec_f`` if present.

        ``rec_f`` is not created in this class; presumably it is set elsewhere.
        """
        for name in ('rec_v', 'rec_t', 'rec_f'):
            if hasattr(self, name):
                delattr(self, name)

    @staticmethod
    def _chunk_is_steady(va, init_run_chunk, min_slope, spike_thresh=30.):
        """Return True if a chunk of voltage samples counts as steady state.

        A chunk is steady when it has more than one sample, never exceeds
        ``spike_thresh``, and ``|first - last| / init_run_chunk < min_slope``.
        """
        va = np.asarray(va)
        if va.size < 2:
            return False
        slope = abs(va[0] - va[-1]) / init_run_chunk
        return bool(slope < min_slope and va.max() <= spike_thresh)

    def init_steady_state(self, test_seg, init_run_chunk=500., min_slope=0.001,
                          max_run=2000000., spike_thresh=30.):
        """Run the model until a steady state is reached, then save it.

        The model is initialised with ``finitialize(self.E)`` and advanced in
        chunks of ``init_run_chunk``. After each chunk, the new voltage samples
        of ``test_seg`` are checked with ``_chunk_is_steady``. The search stops
        when a chunk is steady, or once the run point exceeds ``max_run``.

        The final state is saved into ``self.steady`` (a ``SaveState``) in
        either case. Returns the simulation time reached. A value greater than
        ``max_run`` means the search gave up without finding steady state.
        """
        v = neuron.h.Vector()
        v.record(test_seg._ref_v)

        self.steady = neuron.h.SaveState()

        if self.cv.active() == 1:
            self.cv.re_init()

        neuron.h.finitialize(self.E)
        neuron.h.fcurrent()

        chunk_start = 0
        chunks_so_far = 0
        while True:
            chunks_so_far += 1
            run_point = init_run_chunk * chunks_so_far
            neuron.run(run_point)

            va = np.array(v)[chunk_start:]
            if len(va) > 1:
                # The next chunk starts at the first sample not yet examined.
                chunk_start += len(va)
                if self._chunk_is_steady(va, init_run_chunk, min_slope,
                                         spike_thresh):
                    break
            # Checked on every iteration so the loop always terminates, even
            # if a chunk produces fewer than two samples.
            if run_point > max_run:
                break

        del v

        self.steady.save()
        return run_point

    def run_model(self):
        """Run the simulation from t = 0 up to ``self.sim_time``.

        ``finitialize`` is not called here. The run continues from the current
        model state (presumably one restored from ``self.steady`` by the caller).
        """
        neuron.h.t = 0.
        if self.cv.active() == 1:
            self.cv.re_init()

        neuron.run(self.sim_time)