import sys, logging, os, importlib, gzip, json
from multiprocessing import Lock
from numpy import *
sys.path.append('.')
try:
    from .evaluator import Evaluator
except:
    from pyneuronautofit import Evaluator


from project import fitmeneuron, param_nslh, simulator
try:
	from project import downsampler
except:
	downsampler = None

if   simulator == "neuron":
    from neuron import h
elif simulator == "brian2":
    import brian2 as br
    raise NotImplemented("Brian's runner isn't implemented yet")
else:
    raise RuntimeError(f"Unknown simulator {simulator}")

class RunAndTest():
    def __init__(self,evaluator:Evaluator,
        params:(dict,None)   = None,
        celsius:(float,None) = None,
        init:(dict,None)     = None,
        dt:(float,None)      = None,
        lock:(Lock,None)=None,logname:(str,None)=None,loglevel:str="INFO"):
        self.eval      = evaluator
        self.params    = params        
        self.celsius   = celsius
        self.init      = init
        self.N         = len(self.eval.TestCurr)
        self.dt        = dt
        self.lock      = lock
        self.logger    = "RanAndTest.log" if logname is None else (logname+f"-RunAndTest-{os.getpid():09d}.log")
        self.logger    = "threadlog/"+self.logger
        self.loglevel  = loglevel
        if simulator == "neuron":
            self.__sim_run__ = self.__nrn_run__
        
    @property
    def __run__(self): return self.__sim_run__
    
    def log_error(self,message):
        if not self.loglevel in "ERROR WARNING INFO DEBUG".split(): return
        if self.lock is None: self.logger.error(message)
        else:
            with self.lock:
                #self.logger.error(message)
                with open(self.logger,"a") as fd:
                    fd.write(message+"\n")
    def log_info(self,message):
        if not self.loglevel in "INFO DEBUG".split(): return
        if self.lock is None: self.logger.info(message)
        else:
            with self.lock:
                #self.logger.info(message)
                with open(self.logger,"a") as fd:
                    fd.write(message+"\n")
    def log_debug(self,message):
        if not self.loglevel == "DEBUG": return
        if self.lock is None: self.logger.debug(message)
        else:
            with self.lock:
                #self.logger.debug(message)
                with open(self.logger,"a") as fd:
                    fd.write(message+"\n")

        
    def __nrn_SetPrm__(self, n:fitmeneuron, pname:(str,tuple), val:float)->None:
        if type(pname) is tuple:
            for pn in pname:
                exec(f"n.{pn} = {val}")
        else:
            exec(f"n.{pname} = {val}")
    
    def __nrn_run__(self, params:(dict,None)=None, init:(dict,None)=None, view:(bool,int)=False)->list:
        self.log_debug("Running evaluation")
        if params is None and self.params is None:
            self.log_error(f"Parameters are not given")
            raise RuntimeError(f"Parameters are not given")
        elif params is None: params = self.params
        if init is None : init = self.init
        
        h.celsius = 36. if self.celsius is None else self.celsius
        
        pop   = [ fitmeneuron(nid=x+1) for x in range(self.N) ]
        stims = [
            [
                h.IClamp(0.5,sec=n.soma),
                h.Vector(self.eval.TestCurr[nid]*1e-3), # rec in pA and nrn in nA
                h.Vector(arange(self.eval.TestCurr[nid].shape[0])*self.eval.expdt )
            ] for nid,n in enumerate(pop)
        ]
            
        recs  = [ h.Vector() for n in pop ]
        
        for nid,(n,(ic,ival,itime),rc) in enumerate( zip(pop,stims,recs) ):
            for pname in params:
                self.__nrn_SetPrm__(n,pname,params[pname])
            if init is not None:
                for var in init:
                    self.__nrn_SetPrm__(n,var,init[var ])
            
            pop[nid].setcable()
            ic.amp = self.eval.TestCurr[nid][0]*1e-3
            ic.delay = -1000.
            ic.dur   = 1e9
            itime.x[0] = -1000.
            ival.play(stims[nid][0]._ref_amp,stims[nid][2],1)
            rc.record(n.soma(0.5)._ref_v)
            
        h.celsius = 36. if self.celsius is None else self.celsius
        
        if self.dt is not None:
            if self.dt > 0.:
                h.dt = self.dt
            else:
                h.dt = self.eval.expdt*abs(self.dt)

#DB>>
        # dbn=pop[0]
        # names = [ n for n in params]+[ n for n in "soma.diam soma.nseg axon.nseg soma.eca".split() ]
        # for n in names:
            # if type(n) is tuple or type(n) is list:
                # for xn in n:
                    # v = eval("dbn."+xn)
                    # print(f"{xn:<45s} = {v}")
            # else:
                # v = eval("dbn."+n)
                # print(f"{n:<45s} = {v}")
        # print(f"celcius = {h.celsius}")
        # print(f"dt = {h.dt}")
#<<DB

        trec   = h.Vector()
        trec.record(h._ref_t)
        
    
        try:
            h.finitialize()
            h.fcurrent()
            h.frecord_init()
            h.t = -1000.   # 1000 ms for transient
        except BaseException as e :
            self.log_error(f"STUCK IN PRERUN WITH PARAMETERS:{params}")
            self.log_error(f"                      EXCEPTION:{e}")
            if view:raise
            else   :return self.eval.clone(None)
        
        try:    
            while h.t < self.eval.tmax :h.fadvance()
        except BaseException as e :
            self.log_error(f"STUCK IN RUN WITH PARAMETERS:{params}")
            self.log_error(f"                   EXCEPTION:{e}")
            if view:raise
            else   :return self.eval.clone(None)
    
    
    
        atrec    = copy(array(trec))
        zerot    = where(atrec > 0.)[0][0]-1
        atrec    = atrec[zerot:]
        arecs    = [ copy(array(v)[zerot:]) for v in recs ]

        tscale   = int(round(self.eval.expdt/h.dt) )
        scldrecs = [ mean( v[:v.shape[0]-v.shape[0]%tscale].reshape((v.shape[0]//tscale,tscale)),axis=1) for v in arecs ]

        if   int(view) == 2:
            return self.eval.diff(self.eval.clone(scldrecs),marks=True,nummark=True),atrec,arecs
        elif int(view) == 3:
            e = self.eval.clone(scldrecs)
            return self.eval.diff(e,marks=True),atrec,arecs,e,scldrecs,tscale
        elif view: 
            return atrec,arecs
        else:
            del pop,stims,recs,trec#,h,dLGN
            return self.eval.clone(scldrecs)

    def __call__(self, params=None)->list:
        return self.__run__(params=params,view=False) - self.eval

def ReadArXive(fname,selection):
    def read_arxive(fd):
        arx = json.load(fd)
        if type(arx) is dict:
            markers    = arx['markers']    if 'markers'    in arx else None
            bvalues    = arx['bvalues']    if 'bvalues'    in arx else None
            parameters = arx['parameters'] if 'parameters' in arx else None
            target     = arx['target']     if 'target'     in arx else None
            evaluation = arx['evaluation'] if 'evaluation' in arx else None
            version    = arx['version']    if 'version'    in arx else None
            cmd        = arx['cmd']        if 'cmd'        in arx else None

            arXive  = [
                    [ p['fitness'],p['parameters'] ]
                    for r in 'final records unique model models'.split() if r in arx
                    for p in arx[r] if not p is None
            ]
        else:
            sys.stderr.write(f"Wrong format of arXive {type(arx)}\n")
            exit(1)
        return arXive,markers,bvalues,parameters,version,cmd,target,evaluation
    
    _, fext = os.path.splitext(fname)
    logging.info( "=========================")
    if fext == ".py":
        try:
            mod = importlib.import_module(fname)
            collections = mod.selected
        except BaseException as e :
            print(f"Cannot import selected from {args[0]}")
            raise
    elif fext == ".gz":
        logging.info(f"Reading GZIP {fname}")
        with gzip.open(fname,'r') as fd:
            collections,markers,bvalues,parameters,\
            version,cmd,target,evaluation = read_arxive(fd)
    elif fext == ".json":
        logging.info(f"Reading JSON {fname}")
        with open(fname,'r') as fd:
            collections,markers,bvalues,parameters,\
            version,cmd,target,evaluation = read_arxive(fd)
    else:
        sys.stderr.write(f"Unknown input file extension {fext}")
        exit(1)
    if selection is not None:
        logging.info( "=========================")
        logging.info(f"Filtering collections by selection {selection}")
        if type(selection) is str:
            try:
                selection = eval(selection)
            except BaseException as e :
                logging.error(f" Cannot convert selection {selection} into a python object:{e}")
                logging.error( " ====== !!! FULL STOP !!! ======")
                exit(1)
        if   type(selection) is int: selection = [selection]
        elif type(selection) is tuple and len(selection) == 2:
            if selection[0] >= len(collections):
                logging.error(f" Left boundary of selection{selection[0]} is bigger than arXive size {len(collection)}")
                logging.error( " ====== !!! FULL STOP !!! ======")
                exit(1)
            selection = [ i for i in range(selection[0],len(collections)) if i <= selection[1]]
        elif type(selection) is tuple and len(selection) == 3:
            if selection[0] >= len(collections):
                logging.error(f" Left boundary of selection{selection[0]} is bigger than arXive size {len(collection)}")
                logging.error( " ====== !!! FULL STOP !!! ======")
                exit(1)
            selection = [ i for i in range(selection[0],len(collections),selection[1]) if i <= selection[2]]
        elif type(selection) is list:
            proselection = []
            for x in selection:
                if   type(x) is int  :proselection.append(x)
                elif type(x) is list :
                    if len(x) == 2  or len(x) == 3:
                        for y in range(*x):
                            proselection.append(y)
                    else:
                        sys.stderr.write(f"Range selection should have 2 or 3 numbers {len(x)} is given\n")
                        exit(1)
                else:
                    sys.stderr.write(f"Unknown input type of selection {x}\n")
                    exit(1)
            selection = proselection
        else:
            logging.error(f" incorrect selector or selector size {selection}")
            logging.error( " ====== !!! FULL STOP !!! ======")
            exit(1)
        logging.info(f"Selection = {selection}")
        collections = [ collections[i] for i in selection]
    logging.info( "==================== DONE")
    
    return collections,markers,bvalues,parameters,\
            version,cmd,target,evaluation,selection


if __name__ == '__main__':
    import sys,os,importlib,gzip,json
    
    from optparse import OptionParser
    oprs = OptionParser("USAGE: %prog [flags] file_with_parameters")    
    oprs.add_option("-i", "--input",       dest="input",   default=None,   help="input file should be abf,json,or npz. -v needs only abf",type="str")
    oprs.add_option("-M", "--Mode",        dest="mode",    default=None,   help="Mode for model evaluation (don't use)",                  type="str")
    oprs.add_option("-K", "--masK",        dest="mask",    default=None,   help="mask for evaluation",                                    type="str")
    oprs.add_option("-T", "--Threshold",   dest="thrsh",   default=None,   help="spike threshold",                                        type="float")
    oprs.add_option("-L", "--Left",        dest="left",    default=None,   help="left sample for spike shape",                            type="int")
    oprs.add_option("-R", "--Right",       dest="rght",    default=None,   help="right sample for spike shape",                           type="int")
    oprs.add_option("-C", "--Count",       dest="count",   default=None,   help="number of spikes to analyze in spike shape and width",   type="int")
    oprs.add_option("-Z", "--spike-Zoom",  dest="spwtgh",  default=None,                                                                  type="float",\
        help="if positive absolute weight of voltage diff during spike; if negative relataed scaler")
    oprs.add_option("-Q","--v-dvdt-size",  dest="vpvsize", default=None,                                                                  type='int',\
        help="v dv/dt histogram size")
    oprs.add_option("-t", "--temperature", dest="celsius", default=35.,    help="temperature in celsius",                                 type="float")

    oprs.add_option(      "--dt",          dest="simdt",   default=None,                                                                  type="float",\
        help="if positive absolute simulation dt; if negative scaler for recorded dt")

    oprs.add_option("-s", "--sort",        dest="sort",    default=False,  help="sort parameters fist (it changes neurons IDs!)",         action="store_true")
    oprs.add_option("-n", "--neuron-ids",  dest="nrnid",   default=None,                                                                  type="str",\
        help="Neuron selector: if int-select neuron in the file; if tuple(int,int) - from,to selection; if tuple(int,int,int) - from,step,to selection; if list [int,....] -list of selected neurons. All ID from the top of the list.")
    oprs.add_option("-N", "--Num-threads", dest="nthrs",   default=os.cpu_count(),                                                        type="int",\
        help="Number of thread avalible for process. If not set all will be used. Set to 0 to stop multithreading")
    
    oprs.add_option("-d", "--diff",        dest="diff",  default=False,  help="Print out differences",                                  action="store_true")
    oprs.add_option("-c", "--collapsed-diff",dest="cdiff",default=False, help="Print out collapsed differences",                        action="store_true")
    oprs.add_option("-v", "--view",        dest="view",  default=False,  help="Show graphs with voltages",                              action="store_true")
    oprs.add_option("-V", "--save-view",   dest="Gsave", default=None,   help="Save graphs in to a file instead of showing them",       type="str")
    oprs.add_option("-W", "--fig-size",    dest="Fsize", default=None,   help="The size of the figure in WxH format",                   type="str")
    oprs.add_option("-O", "--view-current",dest="showa", default=False,  help="Show figure with currents",                              action="store_true")
    oprs.add_option("-B", "--numBer-records",dest="nrec", default=False,  help="Show number of records in archive and exit",            action="store_true")
    oprs.add_option(      "--save-json",   dest="sjson", default=None,   help="save abf in json format",                                type="str")
    oprs.add_option(      "--save-numpy",  dest="snpz",  default=None,   help="save abf in numpy compressed format",                    type="str")
    oprs.add_option(      "--show-vector", dest="svec",  default=False,  help="print out vector and exit",                              action="store_true")

    oprs.add_option("-l", "--log-level"   ,        dest="ll",     default="INFO",          type='str',
        help="Level of logging.[CRITICAL, ERROR, WARNING, INFO, or DEBUG] (default INFO)") 
    oprs.add_option(     "--log-file"     ,       dest="lf",     default=None,           type='str',
        help="save log to file")

    opt, args = oprs.parse_args()
    

    if opt.lf is None:
        logging.basicConfig(format='%(asctime)s:%(name)-10s:%(lineno)-4d:%(levelname)-8s:%(message)s', level=eval("logging."+opt.ll) )
    else:
        logging.basicConfig(filename=opt.lf, format='%(asctime)s:%(name)-10s:%(lineno)-4d:%(levelname)-8s:%(message)s', level=eval("logging."+opt.ll) )

       
    if len(args) != 1 and opt.sjson is None and opt.snpz is None and not opt.svec:
        logging.error(f" Need a json[.gz] file with arXive")
        logging.error( " ====== !!! FULL STOP !!! ======")
        exit(1)
    
    logging.info( "=========================")
    logging.info(f" Number of active threads : {opt.nthrs}")
    
    collections,markers,bvalues,parameters,version,\
        cmd,target,evaluation,selection = ReadArXive(args[0],opt.nrnid)

    ### Altering Parameters and Evaluation if it's needed >>>
    if parameters is not None: prm_nslh = parameters
    if opt.input is not None: target = opt.input
    if evaluation is None: evaluation = {}
    if opt.mode is not None:
        evaluation['mode'] = opt.mode
    if opt.mask is not None:
        try:
            evaluation['mask'] = eval(opt.mask)
        except BaseException as e :
            logging.info(f"Cannot evaluate condition mask {opt.mask}:{e}")
            logging.info(f"Treating {opt.mask} as filename")
            try:
                with open(opt.mask) as fd:
                    evaluation['mask'] = eval(fd.read())
            except BaseException as e :
                logging.error(f" Cannot read condition mask from the {opt.mask}:{e}")
                logging.error( " ====== !!! FULL STOP !!! ======")
                exit(1)
    if opt.thrsh is not None:
        evaluation['spikethreshold'] = opt.thrsh
    if opt.left is not None:
        evaluation['prespike'] = opt.left
    if opt.rght is not None:
        evaluation['postspike'] = opt.rght
    if opt.count is not None:
        evaluation['spikecount'] = opt.count
    if opt.cdiff: evaluation['collapse_tests'] = opt.cdiff    
    if opt.spwtgh is not None:
        evaluation['spikeweight'] = opt.spwtgh
    if opt.vpvsize is not None:
        evaluation['vpvsize'] = opt.vpvsize
    ### <<<--------------------------------------------------

    logging.info(f"Target file           :{target}")
    logging.info(f"Evaluation parameters :{evaluation}")

    if target is None:
        logging.error(f" Target file is not present in arXive and was not set by -i option")
        logging.error( " ====== !!! FULL STOP !!! ======")
        exit(1)

    evaluation['downsampler']=downsampler
    evaluator = Evaluator(target, savetruedata=opt.view, **evaluation) 
    if opt.sjson is not None or opt.snpz is not None:
        logging.infor( " ====== ================= ======")
        if opt.sjson is not None:
            logging.info(f" Exporting JSON into {opt.sjson}")
            evaluator.exportJSON(opt.sjson)
        if opt.snpz is not None:
            logging.info(f" Exporting NumPy archive into {opt.snpz}")
            evaluator.exportNPZ(opt.snpz)
        logging.infor( " ====== ===== DONE ====== ======")
        exit(0)    
    if opt.svec:
        print(evaluator.vector())
        exit(0)
    if prm_nslh is None:
        logging.error(f" Parameter ranges are not given in project.py or arXive")
        logging.error( " ====== !!! FULL STOP !!! ======")
        exit(1)


    ### Run Evaluation in parallel >>>
    import multiprocessing as mp
    def worker(p):
        if   type(p) is list:
            prm = {}
            for (n,s,l,h),v in zip(prm_nslh,p):
                if type(n) is list or type(n) is tuple:
                    for r in n:
                        prm[r] = v
                else:
                    prm[n]=v
        elif type(p) is dict:
            prm = p
        else:
            raise RuntimeError(f"Unsupported parameters type {type(p)}. It should be list or dictionary")
            exit(1)
        if opt.view:
            if opt.diff or opt.cdiff:
                fitness = RunAndTest(evaluator,celsius=opt.celsius,dt=opt.simdt, params=prm).__run__(view=2)
            else:
                fitness = RunAndTest(evaluator,celsius=opt.celsius,dt=opt.simdt, params=prm).__run__(view=opt.view)
        else:
            fitness = RunAndTest(evaluator,celsius=opt.celsius,dt=opt.simdt)(params=prm)
        return fitness
    if opt.nthrs > 0:
        pool = mp.Pool(processes=opt.nthrs)
        result = [pool.apply_async(worker,[p]) for _,p in collections]
        pool.close()
        pool.join()
        result = [r.get() for r in result]
    else:
        result = [ worker(p) for p in collections]
    ### <<<---------------------------


    if opt.view:
        from matplotlib.pyplot import *
        def keypass(event):
                if   event.key == "down"      : keypass.recid -= 1
                elif event.key == "up"        : keypass.recid += 1
                elif event.key == "home"      : keypass.recid  = 0
                elif event.key == "end"       : keypass.recid  = keypass.recid = len(result) - 1
                # elif event.key == "enter"      : 
                    # cand = candidates[keypass.recid] if modmode else af_ec2mod(candidates[keypass.recid],prm_ranges)
                    # print("Model # ",keypass.recid)
                    # print("Fitness ",fitness[keypass.recid])
                    # for p,(pname,pscale,lo,hi) in zip(cand,param_ranges):
                        # print(" > {:<33s}:{:g}".format("/".join(pname) if type(pname) is tuple else pname,p))
                    # print()
                    # return

                if keypass.recid < 0            : keypass.recid = 0
                if keypass.recid >= len(result) : keypass.recid = len(result) - 1
                
                if opt.diff or opt.cdiff:
                    vdiff,trec,recs = result[keypass.recid]
                    print(keypass.recid,vdiff)
                    t=arange(len(vdiff))
                    vdiff = array([v for m,v in vdiff ])
                    yline.set_xdata(t[where(vdiff > 0)])
                    yline.set_ydata(vdiff[where(vdiff > 0.)])

                else:
                    trec,recs = result[keypass.recid]
                    print(keypass.recid)
                suptit.set_text(f"N #{selection[keypass.recid]}")
                for xline,md in zip(xlines,recs):
                    xline.set_ydata(md)
                f1.canvas.draw()
                

        
        if opt.Fsize is not None:
            try:
                opt.Fsize = eval(opt.Fsize)
            except BaseException as e :
                logging.error(f" Cannot convert figure size {opt.Fsize} into python object:{e}")
                exit(1)
        
        nh =  len(evaluator.TestCurr)   //4 + (1 if len(evaluator.TestCurr)%4 else 0)
        if opt.showa:
            if opt.Fsize is None:
                f2=figure(2,figsize=(16,9)) if opt.Gsave is None else figure(2,figsize=(64,36))
            else:
                f2=figure(2,figsize=opt.Fsize)
            for cid,c in enumerate(evaluator.TestCurr):
                subplot(nh,4,cid+1)
                plot(arange(c.shape[0])*evaluator.expdt,c)
        if opt.Fsize is None:
            f1=figure(1,figsize=(16,9)) if opt.Gsave is None else figure(1,figsize=(64,36))
        else:
             f1=figure(1,figsize=opt.Fsize)
        
        suptit = suptitle(f"N #{selection[0]}",fontsize=18)

        if opt.diff or opt.cdiff:
            subplots = [ subplot2grid( (nh,7),(cid//4,cid%4) ) for cid,c in enumerate(evaluator.TestCurr) ]
        else:
            subplots = [      subplot(    nh,4,      cid+1     ) for cid,c in enumerate(evaluator.TestCurr) ]
        
        
        for sp,rec in zip(subplots,evaluator.TrueData):
            sp.plot(arange(rec.shape[0])*evaluator.expdt,rec)

        xlines = []
        if opt.diff or opt.cdiff:
            vdiff,trec,recs = result[0]
        else:
            trec,recs = result[0]

        for sp,rec in zip(subplots,recs):
            l, = sp.plot(trec,array(rec))
            xlines.append(l)
        keypass.recid = 0
        if opt.diff or opt.cdiff:
            adiff = array([ [ v for m,v in d ] for d,_,_ in result ])
            difmin = amin(adiff[where(adiff > 0)])
            difmax = amax(adiff[where(adiff > 0)])
            saxis = subplot2grid( (nh,7),(0,4), rowspan=nh, colspan=3 )
            t=arange(len(vdiff))
            adiff = array([v for m,v in vdiff ])
            yline, = saxis.semilogy(t[where(adiff > 0)],adiff[where(adiff > 0.)],"ko")
            saxis.set_ylim(difmin/2,difmax*2)
            saxis.set_xticks(t)
            saxis.set_xticklabels( [m for m,v in vdiff] )
            setp(saxis.get_xticklabels(), rotation=90, ha="right", rotation_mode="anchor")
        
        
                
        if opt.Gsave is None:
            f1.canvas.mpl_connect('key_press_event', keypass)
            show()
        else:
            class Ek():
                def __init__(self, key):
                    self.key = key
            keypass(Ek("home"))
            for i in range(len(result)):
                f1.savefig(opt.Gsave.format(selection[i]))
                keypass(Ek("up"))

    elif opt.diff or opt.cdiff:
        res = RunAndTest(evaluator,celsius=opt.celsius,dt=opt.simdt)(params=prms)
        for i,r in enumerate(res):
            print(f"{i},{r.tolist()}".replace("[","").replace("]","").replace(" ",""))
            #print("  :", evaluator.spikezoomers[i])
        #print("len=",len(res))
        # from matplotlib.pyplot import *
        # ax1 = subplot(121)
        # ax2 = subplot(122,sharex=ax1)
        # for i,r in enumerate(res):
            # ax1.plot(r)
            # ax2.plot(evaluator.spikezoomers[i])
        # show()
        
    else:
        ev  = RunAndTest(evaluator,celsius=opt.celsius,dt=opt.simdt).__run__(params=prms)
        print(ev.vector())