from mrth import * 
from optparse import OptionParser
oprs = OptionParser("USAGE: python3 %prog [options] stkdata stkdata .....")
oprs.add_option("-s", "--save-json" , dest="sjson" , default=False,   help="save into json"                       , type  = "str")
oprs.add_option("-l", "--load-json" , dest="ljson" , default=False,   help="load from json"                       , type  = "str")

oprs.add_option( "--no-show"        , dest="shown"  , default=True ,  help="Do not show graphs on screen"         , action="store_false")
oprs.add_option( "--png"            , dest="savepng", default=False,  help="Save png file"                        , action="store_true" )
oprs.add_option( "--svg"            , dest="savesvg", default=False,  help="Save svg file"                        , action="store_true" )
oprs.add_option( "--skip-first"     , dest="skip"   , default=-1   ,  help="skip first # hash"                    , type  ="int"        )
oprs.add_option( "--show-lines"     , dest="showlns", default=False,  help="Show individual lines along with mean", action="store_true")
oprs.add_option( '-m',"--cor-min"   , dest="cormin" , default=-0.2 ,  help="minimal correlation to show"          , type  ="float")
oprs.add_option( '-M',"--cor-max"   , dest="cormax" , default= 1.0 ,  help="maximal correlation to show"          , type  ="float")
oprs.add_option( '-Y',"--pro-max"   , dest="promax" , default= 0.25,  help="maximal of proportion"                , type  ="float")
oprs.add_option( '-Z',"--m-cor-max" , dest="meancm" , default= 0.4 ,  help="maximal of mean corelation"           , type  ="float")
oprs.add_option( '-R',"--recompute" , dest="recom"  , default=False,  help="Recompute correlation"                , action="store_true")
oprs.add_option( '-K',"--kernel"    , dest="kernel" , default=20   ,  help="Positive kernel size"                 , type="int")
oprs.add_option( '-L',"--time-lag"  , dest="tlag"   , default=False,  help="Time lag"                             , type="int")
oprs.add_option( "--use-ganglion"   , dest="rgc"    , default=False,  help="Compute metric for rGC"               , action="store_true")
oprs.add_option('--num-of-cores'    , dest="ncores" , default=os.cpu_count(), \
                                                                      help="Use number of cores"                   , type  = "int")
oprs.add_option( '-S',"--swarm"     , dest="swarm"  ,  default=False,  help="show swarm plot instead of stderr"    , action= "store_true")
oprs.add_option( '-V',"--symbol"    , dest="symbol" ,  default=None ,  help="Show a symbol and error bar"          , type  = "str")
oprs.add_option( '-v',"--sym-size"  , dest="synsize",  default=9    ,  help="symbol size"                          , type  = "float")
oprs.add_option( "--save-and-exit"  , dest="sae"    ,  default=False,  help="Save JSON file and exit"              , action= "store_true")
oprs.add_option( "-T","--title"     , dest="title"  ,  default=None,   help="Title"                                , type  = "str")
#oprs.add_option( '-0',"--hide-rGC"  , dest="hidrgc" , default=True ,  help="Hide rGC distribution"                , action="store_false")

opts, args = oprs.parse_args() 

import lzma as xz    
from simtoolkit import tree
from simtoolkit import data as stkdata

if opts.swarm:
    import pandas as pd
    from seaborn import swarmplot

def getCorDist(sp,ncells):
    rkm  = arange(-float(opts.kernel*20),float(opts.kernel*20+1),1)
    rkp  = exp(-rkm**2/float(opts.kernel  )**2)
    rkp /= sum(rkp)
    rkm  = exp(-rkm**2/float(opts.kernel*4)**2)
    rkm /= sum(rkm)
    rka = rkp-rkm
    FRdur = int( ceil(amax(sp[:,0])+10.) )
    FR = zeros( (ncells+1,FRdur) )
    for t,n in sp:
        FR[int(round(n)),int(round(t))] += 1
    FR[:,0] = 0.
    for n in range(ncells+1):
        if opts.tlag and opts.tlag != 0:
            shift = opts.tlag*n
            FR[n,  :   ] = convolve(FR[n,:],rka,mode='same')
            # print(shift, FR[n,shift:].shape, FR[n,:-shift].shape)
            if n != 0:
                FR[n,shift:] = FR[n,:-shift]
                FR[n,:shift] = zeros(shift)
        else:
            FR[n,:] = convolve(FR[n,:],rka,mode='same')
    xcor = corrcoef(FR)
    xcor = array([ xcor[n1,n2] for n1 in range(ncells) for n2 in range(n1+1,ncells) ])
    #xcor[~isfinite(xcor)] = 0
    xcor = xcor[isfinite(xcor)]

    LGNh,LGNb = histogram(xcor, bins=201, range=(-1.,1.) )
    LGNh,LGNb = LGNh/sum(LGNh), (LGNb[:-1]+LGNb[1:])/2.
    return json.dumps([column_stack((LGNb,LGNh)).tolist(),mean(xcor)])


def readafile(f):
    #DB>>
    # print(f,type(f))
    #<<DB
    if type(f) is str:
        _, fext = os.path.splitext(f)
        if   fext == '.stkdata':
            return readafile((f,f))
        elif fext == '.xz':
            with xz.open(f,'r') as fd:
                return readafile((f,fd))
    elif type(f) is tuple:
        f,fd = f
        with stkdata(fd,'ro') as sd:
            try:
                hashline = sd["/hash", -1]
                model    = tree().imp(sd["/"+hashline+"/model",-1])
                if model.check("/network/syn/sig2gsyn"):
                    s2s  = model["/network/syn/sig2gsyn"]
                elif model.check("/network/syn/geom/o2o"):
                    s2s  = 0
                else:
                    print(f"File {f}: there is no sig2syn and o2o isn't true (-.-)")
                    return json.dumps(None)
                ncells = sd['/'+hashline+'/n-neurons',-1]
                NMADp  = model['/network/syn/NMDA/p']
                AMPAp  = model['/network/syn/AMPA/p']
            except BaseException as e:
                print(f"Cannot read {f} file: {e}")
                return json.dumps(None)
            
            #for hashid,hashline in enumerate(sd["/hash"]):
                #if hashid < opts.skip : continue            
            if opts.rgc :
                hashline = sd["/hash",0]
                if not "/"+hashline+"/CorrDist/rGC" in sd or opts.recom:
                    xspikes = sd["/"+hashline+"/rGC/spikes",-1]
                    ncells  = int(amax(xspikes[:,1])+1)
                    #DB>>
                    print(f"Using {ncells} rGC from {f}")
                    #<<DB
                    return json.dumps(json.loads(getCorDist(xspikes,ncells )) + [s2s,NMADp,AMPAp])
                else:
                    corr = sd[f"/{hashline}/CorrDist/rGC", -1]
                    return json.dumps([corr.tolist(),average(corr[:,0], weights=corr[:,1]),s2s,NMADp,AMPAp] )
            else:
                for hashid,hashline in enumerate(sd["/hash",None][opts.skip:]):
                    if not "/"+hashline+"/CorrDist/LGN" in sd or opts.recom:
                        xspikes = None
                        for spk in sd["/"+hashline+"/spikes"]:
                            if xspikes is None: xspikes = spk
                            else:
                                xspikes = append(xspikes,spk,axis=0)
                        if xspikes is None: return json.dumps(None)
                        return json.dumps(json.loads(getCorDist(xspikes,ncells )) + [s2s,NMADp,AMPAp])
                    else:
                        corr = sd[f"/{hashline}/CorrDist/LGN", -1]
                        return json.dumps([corr.tolist(),average(corr[:,0], weights=corr[:,1]),s2s,NMADp,AMPAp] )
                # labels.append([s2s,len(corrdist)-1])
            # if "/"+hashline+"/CorrDist/rGC" in sd and rgdist is None:
                # rgdist = sd["/"+hashline+"/CorrDist/rGC",-1]
    else:
        print(f"Unknow type {f}: {type(f)}")
        return json.dumps(None)
            
        

if opts.ljson:
    with open(opts.ljson) as fd:
        j = json.load(fd)
    args        = j['files']
    corrdist    = [ array(json.loads(xcorr)) for xcorr in j['cordist'] ]
    meancordist = array(j['meancordist'])
    labels      = array([ json.loads(l) for l in j['labels'] ])
else:
    if len(args) == 0:
        print(f"Need at least one data file")
        print(f"python3 {sys.argv[0]} -h for more information")
        exit(1)

    if opts.ncores < 2:
        result      = [ [array(corr), xmean, s2s, NMDAp, AMPAp] for corr, xmean, s2s, NMDAp, AMPAp in [ json.loads(readafile(f)) for f in args ] ]
        meancordist = array([   xmean             for _, xmean, s2s, NMDAp, AMPAp in result ])
        labels      = array([ [s2s, NMDAp, AMPAp] for _,     _, s2s, NMDAp, AMPAp in result ])
        corrdist    = [ array(corr) for corr, _, _, _, _ in result ]
    else:
        import multiprocessing as mp
        print(f"TOTAL TASKS : {len(args)} / KERNEL : {opts.kernel} / TIMELAG : {opts.tlag}")
        pool = mp.Pool(processes=opts.ncores)
        result = [ pool.apply_async(readafile,[tsk]) for tsk in args ]
        pool.close()
        pool.join()
        result      = [json.loads(r.get()) for r in result]
        result    = [ r for r in result if not r is None ]
        meancordist = array([   xmean             for _, xmean, s2s, NMDAp, AMPAp in result ])
        labels      = array([ [s2s, NMDAp, AMPAp] for _,     _, s2s, NMDAp, AMPAp in result ])
        corrdist    = [ array(corr) for corr, _, _, _, _ in result ]
    if opts.sjson:
        with open(opts.sjson, 'w') as fd:
            json.dump({
                'rGC'         : opts.rgc,
                'timelag'     : opts.tlag,
                'recompute'   : opts.recom,
                'kernel'      : opts.kernel,
                'files'       : args,
                'cordist'     : [ json.dumps(xcorr.tolist())  for xcorr in corrdist],
                'meancordist' : meancordist.tolist(),
                'labels'      : [ json.dumps(l.tolist()) for l in labels ]
            },fd, sort_keys=True, indent=4)
    if opts.sae:
        exit(0)

fname,_ = os.path.splitext(args[0])
comname = list(fname)
for f in  args[1:]:
    fname,_ = os.path.splitext(f)
    comname = [ x if x == y else 'X' for x,y in zip(comname,list(fname)) ]

comname = "".join(comname)


s2sunic = unique(labels[:,0])


def catrange(x,y):
    y  = y[where(logical_and(x>=opts.cormin,x<=opts.cormax))]
    x  = x[where(logical_and(x>=opts.cormin,x<=opts.cormax))]
    return x,y

f1 = figure(1, figsize=(16,8))
if opts.title is not None:
    suptitle(opts.title,fontsize=23)
cmap = get_cmap("rainbow")
#cmap = get_cmap("tab10")
subplot(121)
# if rgdist is not None and opts.hidrgc:
    # x,y = rgdist[:,0],rgdist[:,1]
    # x,y = catrange(x,y)
    # plot(x,y,"k-",label="rGC",lw=5 if opts.showlns else 3)


for s in s2sunic:
    si   = int(round(s-1))
    idx, =  where(labels[:,0] == s)
    c = cmap(1-si/9. )
    meancor = None
    for ki,k in  enumerate(idx):
        cd = corrdist[k]
        meancor = cd if meancor is None else column_stack((meancor,cd[:,1]))
        if opts.showlns:                    
            x,y = cd[:,0], cd[:,1]
            x,y = catrange(x,y)
            plot(x,y,"-",c=c,lw=0.75)
                
    if meancor is not None:
        x, y = meancor[:,0],mean(meancor[:,1:],axis=1)
        x,y = catrange(x,y)
        plot(x,y,"-",c=c,lw=3,label=r"$\sigma^2$="+f"{s}")
        #DB>>
        # print(si,s,idx)
        # print(meancor)
        # print(x)
        # print(y)
        #<<DB
            
# if igsyn==0 and idly == 0:
legend(loc='best',fontsize=10)
if type(opts.promax) is float:
    ylim(bottom=0,top=opts.promax)

meanmap = {}
ax2= subplot(122)

if opts.swarm:
    df = pd.DataFrame()
    c  ={}
    for s in s2sunic:
        si   = int(round(s-1))
        idx, =  where(labels[:,0] == s)
        df[s] = [meancordist[k] for k in idx]
        c[s]  = cmap(1-si/9. )
    bar(arange(s2sunic.shape[0]),df.mean(),0.25,fc='None',ec='k')
    _, caplines, _ = errorbar(arange(s2sunic.shape[0]),df.mean(),yerr=df.std(),color='k',linestyle='None',lw=1,lolims=True)
    caplines[0].set_marker('_')
    swarmplot(data = df,palette=c,size=10)
    # ylim(bottom=0)
    # #plot(ALLM[:,0],ALLM[:,i+1],'ko')
    # ylim(bottom=0)
elif opts.symbol is not None:
    print(opts.symbol)
    for s in s2sunic:
        si   = int(round(s-1))
        idx, =  where(labels[:,0] == s)
        c = cmap(1-si/9. )
        meancor = [meancordist[k] for k in idx]
        if len(meancor) > 2:
            y = mean(meancor)
            z = std(meancor)
            plot([si],[y],opts.symbol,c=c,ms=opts.synsize)
            errorbar([si],[y],yerr=[z],color='k')
            
else:
    for s in s2sunic:
        si   = int(round(s-1))
        idx, =  where(labels[:,0] == s)
        c = cmap(1-si/9. )
        meancor = [meancordist[k] for k in idx]
        if len(meancor) > 2:
            y = mean(meancor)
            z = std(meancor)
            bar([si],[y],color=c)
            errorbar([si],[y],yerr=[z],color='k')
if type(opts.meancm) is float:
    ylim(bottom=0,top=opts.meancm)
        
        



if opts.savepng:
    f1.savefig(comname+"-distribution-and-meancorrelation.png")
if opts.savesvg:
    f1.savefig(comname+"-distribution-and-meancorrelation.svg")
if opts.shown  : show()