#This code is adapted from venny4py (https://github.com/timyerg/venny4py), licensed under the MIT License.
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from matplotlib.patches import Ellipse
import os

def draw_venn_from_counts(region_counts, set_labels, out='./', ext='png', dpi=300,
                          size=3.5, colors="bgrc", line_width=None, font_size=None,
                          legend_cols=2, column_spacing=4, edge_color='black'):
    """Draw a 4-set Venn diagram from precomputed region counts and save it.

    The figure is written to ``<out>/Venn_4sets.<ext>`` and also returned so
    the caller can tweak or display it further.

    Args:
        region_counts: Mapping from 4-character ``'0'``/``'1'`` strings to the
            value printed in the matching region. Missing keys are printed
            as ``0``. Key order follows venny4py's layout; presumably the
            i-th character flags membership in ``set_labels[i]`` (verify
            against your data before relying on it). The ``"0000"`` entry is
            printed outside the ellipses, near the right edge.
        set_labels: Exactly four legend labels, one per ellipse.
        out: Output directory; created if it does not exist.
        ext: File extension / format passed through the saved file name.
        dpi: Resolution used when saving the figure.
        size: Side length of the (square) figure in inches. Also drives the
            default line width and font size.
        colors: Indexable with at least four matplotlib colors, one per
            ellipse (the default string ``"bgrc"`` yields b, g, r, c).
        line_width: Ellipse outline width; defaults to ``size * 0.12``.
        font_size: Font size of the region counts; defaults to ``size * 2``.
            Legend text uses 1.1 times this value.
        legend_cols: Number of columns in the legend.
        column_spacing: Spacing between legend columns.
        edge_color: Color of the ellipse outlines.

    Returns:
        Tuple ``(fig, ax)`` with the matplotlib figure and axes.

    Raises:
        ValueError: If ``set_labels`` does not contain exactly four items.

    Example:
        region_counts = {
            "1000": 10, "0100": 12, "0010": 14, "0001": 9,
            "1100": 5, "1010": 3, "1001": 2, "0110": 4,
            "0101": 3, "0011": 1, "1110": 2, "1101": 1,
            "1011": 1, "0111": 1, "1111": 1,
        }
        labels = ["Set A", "Set B", "Set C", "Set D"]
        draw_venn_from_counts(region_counts, labels)
    """
    if len(set_labels) != 4:
        raise ValueError("Only 4-set Venn diagrams are supported in this function.")

    os.makedirs(out, exist_ok=True)

    lw = size * .12 if line_width is None else line_width
    fs = size * 2 if font_size is None else font_size

    # Ellipse geometry in axes data units (the axes span 0..100 both ways):
    # one (x centre, y centre, rotation angle) triple per set.
    ew, eh = 45, 75
    ellipse_specs = ((35, 35, 225), (48, 45, 225), (52, 45, 315), (65, 35, 315))

    # Text anchor for every region, keyed by its membership string. Keeping
    # key and coordinates side by side avoids the index bookkeeping that
    # parallel lists require. Insertion order is the drawing order.
    region_positions = {
        "0000": (92, 70),
        "1000": (14, 41),
        "0100": (34, 70),
        "0010": (66, 70),
        "0001": (86, 41),
        "1100": (26, 59),
        "1010": (28, 26),
        "1001": (50, 11),
        "0110": (50, 60),
        "0101": (72, 26),
        "0011": (74, 59),
        "1110": (37, 51),
        "1101": (60, 17),
        "1011": (40, 17),
        "0111": (63, 51),
        "1111": (50, 35),
    }

    # The display dpi is set on this figure only; mutating plt.rcParams here
    # would silently change every figure created later in the process.
    fig, ax = plt.subplots(1, 1, figsize=(size, size), dpi=200)
    ax.set_xlim(0, 100)
    ax.set_ylim(0, 100)
    ax.axis('off')

    # Each set is drawn twice: a translucent fill, then an opaque outline on
    # top so the edge is not dimmed by the fill's alpha.
    for i, (x, y, angle) in enumerate(ellipse_specs):
        ax.add_artist(Ellipse(xy=(x, y), width=ew, height=eh, fc=colors[i],
                              angle=angle, alpha=.3))
        ax.add_artist(Ellipse(xy=(x, y), width=ew, height=eh, fc='None',
                              angle=angle, ec=edge_color, lw=lw))

    # Print the count of every region; regions absent from the input show 0.
    for key, (x, y) in region_positions.items():
        ax.text(x, y, str(region_counts.get(key, 0)),
                ha='center', va='center', fontsize=fs)

    # Legend with one translucent swatch per set, centred at the top.
    handles = [mpatches.Patch(color=colors[i], label=lbl, alpha=.3)
               for i, lbl in enumerate(set_labels)]
    ax.legend(handles=handles, fontsize=fs * 1.1, frameon=False,
              bbox_to_anchor=(.5, .99), bbox_transform=ax.transAxes,
              loc='upper center', handlelength=1.5, ncol=legend_cols,
              columnspacing=column_spacing, handletextpad=.5)

    fig.savefig(os.path.join(out, f'Venn_4sets.{ext}'), dpi=dpi,
                bbox_inches='tight', facecolor='w')

    return fig, ax