# -*- coding: utf-8 -*-
"""
kinetic_model_package

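Runs the kinetic model: kinetic_model() integrates the slice-wise
interneuron / pyramidal population equations with scipy's solve_ivp.

Input (arguments of kinetic_model, in order):
listS,conteggioPYR,dictProbab,initDF,W_INT_INT,W_INT_PYR,W_PYR_INT,W_PYR_PYR,maxTime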
    
"""
import time
import pandas as pd
import numpy as np
    
    
def _square_table(lookup, size):
    """Collect ``lookup(h, k)`` for ``h, k in range(size)`` into a float array."""
    cells = [[lookup(h, k) for k in range(size)] for h in range(size)]
    # The reshape keeps a proper (0, 0) shape when size == 0.
    return np.array(cells, dtype=float).reshape(size, size)


def _transition_tensor(activate, deactivate):
    """Build the transition probabilities ``C[h, k, l, n, m]``.

    ``C[h, k, l, n, m]`` is the probability that a unit of slice ``h`` that is
    in state ``n`` ends in state ``l`` after meeting a unit of slice ``k`` that
    is in state ``m`` (0 = inactive, 1 = active).

    Parameters:
    activate (ndarray): (S, S) probability of switching 0 -> 1 when the
        partner is active.
    deactivate (ndarray): (S, S) probability of switching 1 -> 0 when the
        partner is active.

    Returns:
    ndarray of shape (S, S, 2, 2, 2).
    """
    size = activate.shape[0]
    tensor = np.zeros((size, size, 2, 2, 2))
    # Inactive partner (m == 0): the unit always keeps its state.
    tensor[:, :, 0, 0, 0] = 1.0
    tensor[:, :, 1, 1, 0] = 1.0
    # Active partner (m == 1): the unit may switch state.
    tensor[:, :, 0, 0, 1] = 1.0 - activate
    tensor[:, :, 1, 0, 1] = activate
    tensor[:, :, 0, 1, 1] = deactivate
    tensor[:, :, 1, 1, 1] = 1.0 - deactivate
    return tensor


def _check_rates(label, weights, transition, rates):
    """Print a diagnostic if the probabilities or rates of one coupling are inconsistent.

    For every (h, k, n, m) the transition probabilities must sum to 1 over the
    final state l, and the rates must therefore sum to the weight W[h, k].
    """
    if not np.allclose(transition.sum(axis=2), 1.0):
        print('no ok1 C ' + label)
    if not np.allclose(rates.sum(axis=2), weights[:, :, None, None]):
        print('no ok ' + label)
    print('verificato ' + label)


def kinetic_model(listS, conteggioPYR, dictProbab, initDF, W_INT_INT, W_INT_PYR, W_PYR_INT, W_PYR_PYR, maxTime):
    """
    Integrate the kinetic model for interneuron and pyramidal populations.

    Each slice h carries four unknowns: the inactive/active interneuron
    populations and the inactive/active pyramidal populations. For a
    population of type s in slice h the right-hand side is

        dN[h, l]/dt = sum over t, k, n, m of
                          W_st[h, k] * C_st[h, k, l, n, m] * N_s[h, n] * N_t[k, m]
                      - N_s[h, l] * sum over t, k, m of W_st[h, k] * N_t[k, m]

    so the total (inactive + active) of every population in every slice is
    conserved whenever the transition probabilities sum to 1.

    Parameters:
    listS: sized sequence with one entry per slice; only its length (the
        number of slices S) is used.
    conteggioPYR: accepted for interface compatibility; not used by the
        computation.
    dictProbab: mapping with keys 'p1', 'p2', 'q1', 'q2', each indexable as
        [h][k] for h, k in range(S). With an active partner:
        p1 = interneuron switches on (partner pyramidal),
        p2 = pyramidal switches on (partner pyramidal),
        q1 = interneuron switches off (partner interneuron),
        q2 = pyramidal switches off (partner interneuron).
    initDF: initial conditions, indexed as initDF[h, c] with columns
        c = 0 inactive interneurons, 1 active interneurons,
        2 inactive pyramidal, 3 active pyramidal.
    W_INT_INT, W_INT_PYR, W_PYR_INT, W_PYR_PYR: interaction weights, read as
        W.loc[h].iat[k] (row = affected slice, column = partner slice); the
        first type in the name is the affected population.
    maxTime: end of the integration interval [0, maxTime].

    Returns:
    The object returned by scipy.integrate.solve_ivp. The 4*S rows of
    ``sol.y`` are ordered: interneurons (slice 0 inactive, slice 0 active,
    slice 1 inactive, ...) followed by pyramidal cells in the same order.
    """
    # Imported here because the module itself does not import scipy.
    from scipy.integrate import solve_ivp

    S = len(listS)  # slice count

    # Initial conditions, laid out as N0[type, slice, state].
    N0 = np.zeros((2, S, 2))
    for h in range(S):
        N0[0, h, 0] = initDF[h, 0]  # interneuron, inactive
        N0[0, h, 1] = initDF[h, 1]  # interneuron, active
        N0[1, h, 0] = initDF[h, 2]  # pyramidal, inactive
        N0[1, h, 1] = initDF[h, 3]  # pyramidal, active

    print(time.localtime())

    # Read every table exactly once, so the ODE right-hand side never touches
    # pandas or the probability dictionary again.
    w_ii = _square_table(lambda h, k: W_INT_INT.loc[h].iat[k], S)
    w_ip = _square_table(lambda h, k: W_INT_PYR.loc[h].iat[k], S)
    w_pi = _square_table(lambda h, k: W_PYR_INT.loc[h].iat[k], S)
    w_pp = _square_table(lambda h, k: W_PYR_PYR.loc[h].iat[k], S)
    p1 = _square_table(lambda h, k: dictProbab['p1'][h][k], S)
    p2 = _square_table(lambda h, k: dictProbab['p2'][h][k], S)
    q1 = _square_table(lambda h, k: dictProbab['q1'][h][k], S)
    q2 = _square_table(lambda h, k: dictProbab['q2'][h][k], S)
    never = np.zeros((S, S))

    # Transition probabilities per coupling (affected type, partner type):
    # only pyramidal partners can switch a unit on, only interneuron partners
    # can switch it off.
    c_ii = _transition_tensor(never, q1)
    c_ip = _transition_tensor(p1, never)
    c_pi = _transition_tensor(never, q2)
    c_pp = _transition_tensor(p2, never)

    # Rates D[h, k, l, n, m] = W[h, k] * C[h, k, l, n, m].
    d_ii = w_ii[:, :, None, None, None] * c_ii
    d_ip = w_ip[:, :, None, None, None] * c_ip
    d_pi = w_pi[:, :, None, None, None] * c_pi
    d_pp = w_pp[:, :, None, None, None] * c_pp

    _check_rates('int int', w_ii, c_ii, d_ii)
    _check_rates('int pyr', w_ip, c_ip, d_ip)
    _check_rates('pyr int', w_pi, c_pi, d_pi)
    _check_rates('pyr pyr', w_pp, c_pp, d_pp)

    def M_rhs(tMRHS, NL):
        """Right-hand side of the ODE system for the flat state vector NL."""
        state = np.asarray(NL, dtype=float).reshape(2, S, 2)
        n_int, n_pyr = state[0], state[1]

        # Gain: every (start state n, partner state m) pair that ends in l.
        gain_int = (np.einsum('hklnm,hn,km->hl', d_ii, n_int, n_int)
                    + np.einsum('hklnm,hn,km->hl', d_ip, n_int, n_pyr))
        gain_pyr = (np.einsum('hklnm,hn,km->hl', d_pp, n_pyr, n_pyr)
                    + np.einsum('hklnm,hn,km->hl', d_pi, n_pyr, n_int))

        # Loss: a unit leaves its state bookkeeping at the total encounter
        # rate, which depends only on the partner totals per slice.
        int_total = n_int.sum(axis=1)
        pyr_total = n_pyr.sum(axis=1)
        loss_int = n_int * (w_ii @ int_total + w_ip @ pyr_total)[:, None]
        loss_pyr = n_pyr * (w_pi @ int_total + w_pp @ pyr_total)[:, None]

        return np.concatenate(((gain_int - loss_int).ravel(),
                               (gain_pyr - loss_pyr).ravel()))

    print(time.localtime())

    # Flatten in (type, slice, state) order, matching the reshape in M_rhs.
    N0L = [N0[i, h, l] for i in range(2) for h in range(S) for l in range(2)]

    sol = solve_ivp(M_rhs, [0, maxTime], N0L, method='BDF', rtol=1e-5, atol=1e-8)

    print('sol')
    print(time.localtime())

    return sol