#import matplotlib.pyplot as plt
#plt.ion()
import numpy as np
import ctadpole
import random
# Number of cells per connectome: size of the n x n matrices and of every
# per-cell loop in this file.
n=1382
# Number of sampled connectomes read and averaged by build_probability_matrix().
m=200

per = 0.5 # desired fraction of the total connections (not referenced by the code visible in this file)
id_chosen = 4 # cell type whose outgoing connections are made less probable (build_new_connectome hard-codes the same value, 4)

def build_probability_matrix():
    """Average the m sampled connectomes into a connection-probability matrix.

    For k = 1..m the files ``DendriteGrad<k>.txt`` and
    ``inc_connectGrad<k>.txt`` are loaded with ``ctadpole.load_cells``.  The
    cells of every sample are reordered so that cell types 0..6 form
    consecutive blocks (file order is kept inside a block) and are then
    renumbered 0..n-1 in that order; the ``id`` attribute of the loaded cells
    is overwritten.

    Files written, relative to the current working directory:

    * ``many connectomes axon growth/A_connectome<k>.txt`` -- transposed 0/1
      matrix of sample k (before transposing, entry [i, j] is 1 when cell i
      has an incoming connection from cell j);
    * ``P_1000.txt`` -- n x n matrix whose entry [i, j] is the fraction of
      samples in which cell i has an incoming connection from cell j;
    * ``pos_1000.txt`` and ``dendrite_bounds_1000.txt`` -- per-cell means,
      over the samples, of ``pos`` and of the two ``dendrite_bounds`` values.

    Returns nothing.  Raises IndexError if a sample holds fewer than n cells
    of types 0..6.
    """
    # path of the sample connectomes
    path = '/home/andrea/Desktop/computer_right/Evenbody CONNECTOME 22January2013 with SYN_PROBABILITY/connectomes/'
    all_cellists = []  # one type-ordered cell list per sampled connectome
    for connectome_id in range(1, m + 1):
        print(connectome_id)
        cellist = ctadpole.load_cells(
            path + 'DendriteGrad' + str(connectome_id) + '.txt',
            path + 'inc_connectGrad' + str(connectome_id) + '.txt',
            properties_file_path=None)

        # Group the cells by type 0..6.  This is done in pure Python on
        # purpose: concatenating per-type index lists with numpy yields float
        # indices as soon as one type has no cells, and floats cannot index a
        # list.
        new_cellist = [cell
                       for type_id in range(7)
                       for cell in cellist
                       if cell.type_id == type_id]
        # Renumber so that a cell's id equals its row/column in the matrices.
        for i in range(n):
            new_cellist[i].id = i

        all_cellists.append(new_cellist)

    # ===== BUILD THE PROBABILITY MATRIX BETWEEN EACH CELL AND THE MEAN VALUES =====
    P = np.zeros((n, n))
    pos = np.zeros((n, 1))
    dendrite_bounds = np.zeros((n, 2))
    for count, cells in enumerate(all_cellists, 1):
        print(count)
        A = np.zeros((n, n))
        for i in range(n):
            cell = cells[i]
            # set(): several connections coming from the same cell count once,
            # which also makes the fancy-indexed "+= 1" below safe.
            inc_conn = list(set(src.id for src in cell.incoming_connections))
            A[i, inc_conn] = 1
            P[i, inc_conn] += 1

            # running sums, turned into means after the loop
            pos[i] += cell.pos
            dendrite_bounds[i, 0] += cell.dendrite_bounds[0]
            dendrite_bounds[i, 1] += cell.dendrite_bounds[1]
        np.savetxt("many connectomes axon growth/A_connectome" + str(count) + ".txt",
                   A.transpose())

    P = P / m
    pos = pos / m
    dendrite_bounds = dendrite_bounds / m
    np.savetxt("P_1000.txt", P)
    np.savetxt("pos_1000.txt", pos)
    np.savetxt("dendrite_bounds_1000.txt", dendrite_bounds)

# ==== divide the cellist between left and right sides ====
def sort_cellist(cellist):
    """Return a new list with the cells of one body side before the others.

    Cells whose ``body_side`` equals 1 (collected as the "left" group) come
    first, all remaining cells follow.  Inside each group the cells keep the
    order they have in *cellist*.  The input list itself is not modified.
    """
    left_cellist = [x for x in cellist if x.body_side == 1]
    right_cellist = [x for x in cellist if not (x.body_side == 1)]
    # Plain list concatenation keeps every cell as an opaque object; going
    # through a numpy array would try to unpack sequence-like cells into
    # extra dimensions.
    return left_cellist + right_cellist


def _type_block_bounds(type_counts):
    """Return half-open (start, stop) index pairs of consecutive blocks with the given sizes."""
    bounds = []
    start = 0
    for size in type_counts:
        # Blocks are half-open, so the next block starts exactly at this stop.
        bounds.append((start, start + size))
        start += size
    return bounds


def build_new_connectome(num):
    """Sample a new connectome from the probability matrix and simulate it.

    Steps:

    1. Load the mean positions (``pos.txt``), mean dendrite bounds
       (``dendrite_bounds.txt``) and the probability matrix ``P_1000.txt``,
       where P[i, j] is the probability that cell i receives a connection
       from cell j.
    2. Load sample connectome 1 as a template, order its cells by type 0..6
       and renumber them 0..n-1 (the template cells' ``id`` is overwritten).
    3. Create one fresh ``ctadpole.Cell`` per template cell, using the mean
       position and dendrite bounds and sharing the template's property
       dictionary and axon points.
    4. Set P[i, j] to 0 for every presynaptic cell j of type 4 whose position
       is greater than that of cell i, then draw each connection as an
       independent Bernoulli trial with probability P[i, j].
    5. Save the 7 x 7 matrix of summed probabilities between cell types to
       ``Q_con`` (rows presynaptic type, columns postsynaptic type).
    6. Reorder the cells with ``sort_cellist``, renumber them, save them with
       ``ctadpole.save_cells`` and save the sampled adjacency matrix (still in
       type order) as ``figures/A<num>.npy``.
    7. Run ``main.main_full(str(num))``.

    Parameters:
        num: identifier of the generated connectome; it is converted with
            ``str`` for the output file names and for ``main.main_full``.

    Returns:
        Whatever ``main.main_full`` returns.
    """
    pos = np.loadtxt("pos.txt")
    dendrite_bounds = np.loadtxt("dendrite_bounds.txt")
    P = np.loadtxt("P_1000.txt")
    A = np.zeros(P.shape)
    path = '../PyNN_Tadpole_Simulation (copy)/txt files/'
    sample_cellist = ctadpole.load_cells(path + 'DendriteGrad1.txt',
                                         path + 'inc_connectGrad1.txt',
                                         properties_file_path=None)

    # Group the template cells by type 0..6.  Pure Python on purpose:
    # concatenating per-type index lists with numpy yields float indices as
    # soon as one type has no cells.
    new_cellist = [cell
                   for type_id in range(7)
                   for cell in sample_cellist
                   if cell.type_id == type_id]
    for i in range(n):
        new_cellist[i].id = i

    out_cellist = []
    for i in range(n):
        template = new_cellist[i]
        cell = ctadpole.Cell(template.id, template.type_id, template.relative_id, pos[i],
                             dendrite_bounds=(dendrite_bounds[i, 0], dendrite_bounds[i, 1]),
                             body_side=template.body_side)
        # These attributes are shared with the template, not copied.
        cell.properties_dict = template.properties_dict
        cell.primary_axon_points = template.primary_axon_points
        cell.secondary_axon_points = template.secondary_axon_points
        out_cellist.append(cell)

    # Reseed numpy from Python's generator so every call draws a new sample.
    np.random.seed(int(random.getrandbits(32)))

    # Remove connections coming from type-4 cells located at a greater
    # position than the receiving cell.
    # NOTE(review): 4 equals the module-level id_chosen; presumably this
    # should track that setting -- confirm.
    from_type4 = [j for j in range(n) if out_cellist[j].type_id == 4]
    for i in range(n):
        for j in from_type4:
            if pos[i] < pos[j]:
                P[i, j] = 0

    # One Bernoulli draw per ordered pair of cells.
    for i in range(n):
        for j in range(n):
            if np.random.binomial(1, P[i, j]):
                A[i, j] = 1
                out_cellist[i].incoming_connections.append(out_cellist[j])

    # Number of cells of each type; the cells are still ordered by type here.
    ids = [cell.type_id for cell in out_cellist]
    number_types = [ids.count(type_id) for type_id in range(7)]
    print(number_types)

    # ==== PROBABILITY OF CONNECTIONS BETWEEN DIFFERENT TYPES OF CELLS IN Q_conn_tot ====
    # After transposing, P[i, j] is the probability that i connects to j:
    # i is presynaptic, j is postsynaptic.
    P = P.transpose()
    bounds = _type_block_bounds(number_types)
    Q_conn_tot = np.zeros((7, 7))
    for k, (pre_start, pre_stop) in enumerate(bounds):
        for l, (post_start, post_stop) in enumerate(bounds):
            Q_conn_tot[k, l] = np.sum(P[pre_start:pre_stop, post_start:post_stop])
    np.savetxt("Q_con", Q_conn_tot)

    # ==== SORT CELLISTS WITH LEFT-RIGHT POSITIONS ====
    out_cellist = sort_cellist(out_cellist)
    for i in range(n):
        out_cellist[i].id = i

    path = '../PyNN_Tadpole_Simulation (copy)/connectome files/'
    ctadpole.save_cells(out_cellist,
                        path + 'cell_file' + str(num) + '.txt',
                        path + 'connectome' + str(num) + '.txt')
    np.save(path + "figures/A" + str(num) + ".npy", A)

    # Imported here rather than at module level, as in the original layout.
    import main
    pop = main.main_full(str(num))
    return pop



def load_cells(cell_file_path, connectome_file_path, properties_file_path=None):
    """Build the list of cells, with incoming connections, from text files.

    Parameters:
        cell_file_path: whitespace-separated text file with one cell per
            line: 1-based cell id, 1-based type id, position and two dendrite
            bound values (columns 4 and 5 are stored in swapped order).
        connectome_file_path: text file whose first line is skipped; data
            line k describes cell k.  Its first integer c is followed by
            (c - 1) // 2 pairs "1-based source cell id, 1-based source type".
        properties_file_path: optional JSON file with a "properties" list;
            every entry carries a 1-based "id" and its remaining keys except
            "id" and "type" become the cell's ``properties_dict``.

    Returns:
        The list of ``Cell`` objects in cell-file order.  Ids and type ids
        are converted to 0-based, ``relative_id`` is the 0-based rank of the
        cell among the cells of its type, and ``body_side`` is 1 for ids
        below half the number of cells and 2 otherwise.

    Raises:
        Exception: if a connectome line refers to a cell that ``get_cell``
            cannot find, or if a source cell's type differs from the one
            stored in the cell list.
    """
    type_counts = {}  # type_id -> relative id handed to the latest cell of that type
    cell_list = []
    with open(cell_file_path, "rt") as cell_file:
        for l in cell_file:
            toks = l.split()
            cell_id = int(float(toks[0])) - 1
            type_id = int(float(toks[1])) - 1
            # First cell of a type gets 0, the next one 1, and so on.
            relative_id = type_counts.get(type_id, -1) + 1
            type_counts[type_id] = relative_id

            pos = float(toks[2])
            dendrite_bounds = (float(toks[4]), float(toks[3]))
            # Don't set body_side yet until we know how many cells there are
            cell_list.append(Cell(cell_id, type_id, relative_id, pos, dendrite_bounds))

    # Floor division keeps the threshold an integer on Python 2 and 3 alike.
    half = len(cell_list) // 2
    for c in cell_list:
        c.body_side = 1 if c.id < half else 2

    with open(connectome_file_path) as connectome_file:
        next(connectome_file, None)  # Skip first line
        for line, l in enumerate(connectome_file, 1):
            cell = get_cell(cell_list, line - 1)
            if cell is None:
                raise Exception("%s:%d: no cell with ID %d." %
                                (connectome_file_path, line, line))

            cols = [int(s) for s in l.split()]
            for i in range((cols[0] - 1) // 2):
                inc_id = cols[1 + i*2] - 1
                inc_type = cols[2 + i*2] - 1
                inc_cell = get_cell(cell_list, inc_id)
                if inc_cell is None:
                    raise Exception("%s:%d: can't find cell ID %d." % (connectome_file_path, line, inc_id+1))
                if inc_cell.type_id != inc_type:
                    raise Exception("%s:%d: connectome and cell list specify different cell types (%d <> %d)." %
                                    (connectome_file_path, line, inc_cell.type_id, inc_type))

                # NOTE(review): the target is picked by list position while
                # the check above went through get_cell; presumably both
                # agree because ids follow file order -- confirm.
                cell_list[line - 1].incoming_connections.append(inc_cell)

    if properties_file_path is not None:
        # json is not among the imports at the top of this file.
        import json
        with open(properties_file_path) as js_file:
            js_root = json.load(js_file)
        for js in js_root["properties"]:
            cell = cell_list[js["id"] - 1]
            cell.properties_dict = {k: js[k] for k in js.keys() if k not in ["id", "type"]}

    return cell_list