# -*- coding: utf-8 -*-
from __future__ import print_function, division, absolute_import

# This is moose.server.
# It accepts simulation requests on a specified TCP port (default 31417).
# It simulates the given file (usually an archive file, e.g. tar.bz2), sends
# back artefacts generated by the simulation (mostly images), and streams data
# from moose.Tables back to the client.

# Module metadata.
__author__ = "Dilawar Singh"
__copyright__ = "Copyright 2019, Dilawar Singh"
__version__ = "1.0.0"
__maintainer__ = "Dilawar Singh"
__email__ = "dilawars@ncbs.res.in"
__status__ = "Development"

import sys
import re
import os 
import time
import math
import shutil
import socket 
import signal
import tarfile 
import tempfile 
import threading 
import logging
import subprocess

# create a logger for this server.
# Configure the root logger: records are appended to moose_server.log (via
# basicConfig) and mirrored on the console by an extra StreamHandler.
# basicConfig must run first: it is a no-op once the root logger has handlers.
logging.basicConfig(level=logging.DEBUG,
                    format='%(asctime)s %(name)-12s %(levelname)-8s %(message)s',
                    datefmt='%m-%d %H:%M',
                    filename='moose_server.log',
                    filemode='a')

formatter = logging.Formatter('%(name)-12s: %(levelname)-8s %(message)s')
console = logging.StreamHandler()
console.setFormatter(formatter)
console.setLevel(logging.DEBUG)

_logger = logging.getLogger('')
_logger.addHandler(console)

# Only `serve` is exported by `from moose.server import *`.
__all__ = [ 'serve' ]

# Set to True (by signal_handler / the __main__ guard) to make the accept loop
# in start_server() stop.
stop_all_ = False
# Listening server socket; created by start_server(), closed by signal_handler().
sock_     = None
# Stop flags for streamer_client threads, keyed by thread name.
stop_streamer_ = {}

# Every message on the wire is prefixed by its payload size, written as a
# zero-padded decimal number exactly prefixL_ bytes long. A binary length
# field would be more compact; a fixed-width decimal is simpler to debug.
prefixL_  = 9

# Python snippet appended to every user script before it is run. It saves each
# open matplotlib figure to its own PNG file (results.<i>.png) so the images
# can be sent back to the client.
matplotlibText = """
print( '>>>> saving all figues')
import matplotlib.pyplot as plt
def multipage(filename, figs=None, dpi=200):
    from matplotlib.backends.backend_pdf import PdfPages
    pp = PdfPages(filename)
    if figs is None:
        figs = [plt.figure(n) for n in plt.get_fignums()]
    for fig in figs:
        fig.savefig(pp, format='pdf', dpi=dpi)
    pp.close()

def saveall(prefix='results', figs=None):
    if figs is None:
        figs = [plt.figure(n) for n in plt.get_fignums()]
    for i, fig in enumerate(figs):
        outfile = '%s.%d.png' % (prefix, i)
        fig.savefig(outfile)
        print( '>>>> %s saved.' % outfile )
    plt.close()

try:
    saveall()
except Exception as e:
    print( '>>>> Error in saving: %s' % e )
    quit(0)
"""


def execute(cmd):
    """execute: Execute a given command and stream its standard output.

    :param cmd: sequence of program arguments (argv), passed to
        ``subprocess.Popen`` without a shell.

    Return:
    ------
        An iterator over the lines of the command's stdout (text mode).

    Raises:
    ------
        subprocess.CalledProcessError if the command exits with a non-zero
        status after its output has been fully consumed.

    If the caller stops iterating early (the generator is closed or garbage
    collected), the pipe is closed and the child is killed and reaped so no
    zombie or orphaned process is left behind.
    """
    popen = subprocess.Popen(cmd, stdout=subprocess.PIPE, universal_newlines=True)
    finished = False
    try:
        for stdout_line in iter(popen.stdout.readline, ""):
            yield stdout_line
        finished = True
    finally:
        popen.stdout.close()
        if not finished and popen.poll() is None:
            # Consumer gave up: do not leave the child running unattended.
            popen.kill()
        return_code = popen.wait()
    if return_code:
        raise subprocess.CalledProcessError(return_code, cmd)


def find_files( dirname, ext=None, name_contains=None, text_regex_search=None):
    """Recursively list files under `dirname` that pass every given filter.

    :param dirname: directory to walk (``os.walk`` order is preserved).
    :param ext: extension without the leading dot (e.g. ``'png'``); only
        files whose name ends in ``.<ext>`` are kept. A file named exactly
        ``png`` (no dot) does not match.
    :param name_contains: substring that must appear in the file name.
    :param text_regex_search: regex that must match somewhere in the file's
        text. Files that cannot be read or decoded as text are excluded.
    :return: list of matching paths (joined with their directory).
    """
    files = []
    for root, _subdirs, fnames in os.walk(dirname):
        for fname in fnames:
            fpath = os.path.join(root, fname)
            if ext is not None:
                _, dot, suffix = fname.rpartition('.')
                if not dot or suffix != ext:
                    continue
            if name_contains and name_contains not in fname:
                continue
            if text_regex_search:
                # Only read the file when the cheap filters already passed.
                try:
                    with open(fpath, 'r') as fh:
                        txt = fh.read()
                except (IOError, OSError, UnicodeDecodeError):
                    continue  # binary/unreadable files cannot match a text regex
                if re.search(text_regex_search, txt) is None:
                    continue
            files.append(fpath)
    return files

def prefix_data_with_size(data, width=None):
    """Frame `data` by prepending its byte length as a zero-padded decimal.

    :param data: bytes, or text which is UTF-8 encoded first so that the
        prefix counts bytes, not characters.
    :param width: number of prefix digits; defaults to the module's prefixL_.
    :return: ``<length, zero-padded to width><data>`` as bytes.
    :raises ValueError: if the length does not fit in `width` digits.
    """
    if width is None:
        width = prefixL_
    if not isinstance(data, bytes):
        data = data.encode('utf-8')
    size = str(len(data))
    if len(size) > width:
        raise ValueError('Message of %s bytes does not fit in a %d-byte size prefix.'
                         % (size, width))
    return size.zfill(width).encode('ascii') + data

# Signal handler (installed for SIGINT by main()).
def signal_handler(signum, frame):
    """Stop the server: set stop_all_, close the listening socket, exit(1).

    :param signum: signal number (unused).
    :param frame: current stack frame (unused).
    """
    global stop_all_
    global sock_
    _logger.info( "User terminated all processes." )
    stop_all_ = True
    # The signal may arrive before start_server() has created the socket.
    if sock_ is not None:
        try:
            sock_.close()
        except socket.error as e:
            _logger.warning("Error while closing server socket: %s" % e)
    time.sleep(1)
    # sys.exit instead of quit(): quit is injected by the `site` module and is
    # not guaranteed to exist (e.g. under `python -S` or frozen executables).
    sys.exit(1)


def split_data( data ):
    """Split a framed message into ``(size_header, payload)``.

    The header is the first prefixL_ bytes with surrounding whitespace
    stripped; the payload is everything after it.
    """
    header = data[:prefixL_]
    payload = data[prefixL_:]
    return header.strip(), payload

def send_msg(msg, conn, prefix='LOG'):
    """Send `msg` to the client as one framed, tagged message.

    The payload is ``<PREFIX>`` followed by the message bytes, framed by
    prefix_data_with_size().

    :param msg: text (UTF-8 encoded before sending) or raw bytes.
    :param conn: connected client socket.
    :param prefix: message tag, e.g. 'LOG', 'TAB', 'TAR' or 'EOS'.
    :return: False if `msg` is blank and nothing was sent, True otherwise.
    """
    if not msg.strip():
        return False
    if prefix in ('TAB', 'TAR'):
        # Table and tar payloads are binary/bulky: log only their size.
        _logger.debug( 'Sending msg with size %d' % len(msg))
    else:
        _logger.debug(msg)
    # Build the frame as bytes: '<%s>%s' % (...) yields str on Python 3, which
    # either fails in the bytes framing or embeds the repr of a bytes msg.
    if not isinstance(msg, bytes):
        msg = msg.encode('utf-8')
    payload = b'<' + prefix.encode('ascii') + b'>' + msg
    conn.sendall(prefix_data_with_size(payload))
    return True

def run(cmd, conn, cwd=None):
    """Execute `cmd` and forward each stdout line to the client as a 'LOG' message.

    :param cmd: command line string (split on whitespace) or an argv list/tuple.
    :param conn: connected client socket.
    :param cwd: directory to run in. The process-wide working directory is
        changed for the duration of the call and always restored afterwards.

    Failures of the command are reported to the client as
    "Simulation failed: ..." rather than raised.
    """
    _logger.info( "Executing %s" % cmd )
    argv = list(cmd) if isinstance(cmd, (list, tuple)) else cmd.split()
    oldCWD = os.getcwd()
    if cwd is not None:
        os.chdir(cwd)
    try:
        for line in execute(argv):
            if line:
                send_msg(line, conn)
    except Exception as e:
        send_msg("Simulation failed: %s" % e, conn)
    finally:
        # Restore the cwd even if reporting to the client itself failed.
        os.chdir(oldCWD)

def _recv_exactly(conn, nbytes):
    """Read exactly `nbytes` bytes from `conn`; fewer only if the peer closes first."""
    flags = getattr(socket, 'MSG_WAITALL', 0)
    chunks = []
    remaining = nbytes
    while remaining > 0:
        chunk = conn.recv(remaining, flags)
        if not chunk:
            break  # EOF: the peer closed the connection.
        chunks.append(chunk)
        remaining -= len(chunk)
    return b''.join(chunks)


def recv_input(conn, size=1024, prefix_len=None):
    """Receive one size-prefixed message from `conn` and return its payload.

    The first `prefix_len` bytes (default prefixL_) hold the payload length as
    a decimal number; the submitting client must send them.

    :param conn: connected socket.
    :param size: unused; kept for backward compatibility.
    :param prefix_len: width of the size prefix, defaults to prefixL_.
    :return: the payload bytes. b'' if the header is missing or malformed or
        a socket error occurs. Fewer bytes than announced if the peer closed
        early. Every failure is logged.
    """
    if prefix_len is None:
        prefix_len = prefixL_
    try:
        header = _recv_exactly(conn, prefix_len)
        if len(header) < prefix_len:
            _logger.error("Connection closed before the %d-byte size prefix arrived."
                          % prefix_len)
            return b''
        try:
            nbytes = int(header)
        except ValueError:
            _logger.error("MSG FORMAT: %d bytes are size of msg." % prefix_len)
            return b''
        data = _recv_exactly(conn, nbytes)
    except socket.error as e:
        _logger.error("Failed to receive data: %s" % e)
        return b''
    if len(data) < nbytes:
        _logger.error("Connection closed after %d of %d bytes." % (len(data), nbytes))
    return data

def writeTarfile( data ):
    """Store a received payload as ``data.tar.bz2`` in a fresh temporary directory.

    :param data: raw bytes received from the client.
    :return: path of the written file, or None if it is not a tar archive.
        In that case the temporary directory is removed again.
    """
    tdir = tempfile.mkdtemp()
    tfile = os.path.join(tdir, 'data.tar.bz2')
    _logger.info( "Writing %d bytes to %s" % (len(data), tfile))
    with open(tfile, 'wb' ) as f:
        f.write(data)
    # Leaving the `with` block flushes and closes the file, so it can be
    # re-opened and inspected immediately; no sleep is required.
    if not tarfile.is_tarfile(tfile):
        _logger.warning( 'Not a valid tar file: %s' % tfile)
        shutil.rmtree(tdir, ignore_errors=True)
        return None
    return tfile

def suffixMatplotlibStmt( filename ):
    """Write a copy of `filename` with matplotlibText appended.

    The copy is named ``<filename>.1.py``. The original text is followed by a
    newline and then the figure-saving snippet.

    :return: path of the new script.
    """
    with open(filename, 'r') as src:
        original = src.read()
    patched = '%s.1.py' % filename
    with open(patched, 'w') as dst:
        dst.write(original + '\n' + matplotlibText)
    return patched

def streamer_client(socketPath, conn):
    """Relay table data from the MOOSE socket streamer to the client.

    1. Wait for the unix socket at `socketPath` to appear, and connect to it.
    2. Forward every non-blank chunk (up to 1024 bytes) to `conn` as a 'TAB'
       message.
    3. Stop when either condition holds:
       - this thread's flag in stop_streamer_ is set and no more data arrives
         within the poll timeout;
       - the streamer closes its end.
    4. Remove the socket file afterwards.

    :param socketPath: filesystem path of the streamer's unix socket.
    :param conn: connected client socket.
    """
    name = threading.current_thread().name
    _logger.debug( "Trying to connect to server at : %s" % socketPath )
    while not os.path.exists( socketPath ):
        if stop_streamer_[name]:
            return
        time.sleep(0.1)

    stClient = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
    try:
        try:
            stClient.connect(socketPath)
        except socket.error as e:
            _logger.warning('Could not connect: %s' % e)
            return

        # Send streaming data back to the client. The C++ streamer sends
        # fixed-size messages (see its implementation).
        _logger.info( "Socket Streamer is connected with server." )
        # Short timeout so the stop flag is polled regularly.
        stClient.settimeout(0.05)
        send_msg( b'Now streaming table data.', conn, 'TAB')
        while True:
            stopping = stop_streamer_[name]
            try:
                data = stClient.recv(1024)
            except socket.timeout:
                if stopping:
                    break  # asked to stop and nothing left to drain
                continue
            if not data:
                break  # streamer closed its end; nothing more will arrive
            if data.strip():
                send_msg(data, conn, 'TAB')
    finally:
        stClient.close()

    # A unix socket is not a regular file, so test with exists(), not isfile().
    if os.path.exists(socketPath):
        try:
            os.unlink(socketPath)
        except OSError as e:
            _logger.warning('Could not remove %s: %s' % (socketPath, e))

def run_file(filename, conn, cwd=None):
    """Run a user script while a streamer_client thread relays table data.

    :param filename: script to run. It is first copied with the
        figure-saving snippet appended (suffixMatplotlibStmt).
    :param conn: connected client socket.
    :param cwd: working directory for the run.

    The streamer thread is always told to stop, even if the run raises. Its
    stop flag and temporary socket directory are cleaned up once the thread
    has finished.
    """
    sockDir = tempfile.mkdtemp()
    socketPath = os.path.join(sockDir, 'SOCK_TABLE_STREAMER')
    # Set the environment variable so that the socket streamer can start;
    # the child process inherits os.environ.
    # NOTE(review): os.environ is process-wide, so concurrent runs from
    # different client threads overwrite each other's address.
    os.environ['MOOSE_STREAMER_ADDRESS'] = socketPath
    streamerThread = threading.Thread(target=streamer_client
            , args=(socketPath, conn,))
    stop_streamer_[streamerThread.name] = False
    streamerThread.daemon = True
    streamerThread.start()
    try:
        filename = suffixMatplotlibStmt(filename)
        run( "%s %s" % (sys.executable, filename), conn, cwd)
    finally:
        stop_streamer_[streamerThread.name] = True
        streamerThread.join( timeout = 1)
        if streamerThread.is_alive():
            # Keep its flag: the thread still needs to read True to exit.
            _logger.error( "The socket streamer client is still running...")
        else:
            del stop_streamer_[streamerThread.name]
            shutil.rmtree(sockDir, ignore_errors=True)

def _is_inside(base, path):
    """True if the resolved `path` is `base` itself or lies below it."""
    return path == base or path.startswith(base.rstrip(os.sep) + os.sep)


def _safe_tar_members(tf, base):
    """Yield members of `tf` that stay inside `base`; log and skip the others.

    Rejects members whose name escapes `base` (absolute paths, ``..``) and
    links whose target escapes it. The generator is consumed lazily by
    extractall, so already-extracted symlinks are taken into account by
    os.path.realpath.
    """
    for member in tf.getmembers():
        target = os.path.realpath(os.path.join(base, member.name))
        if not _is_inside(base, target):
            _logger.warning('Skipping %s: it would extract outside %s' % (member.name, base))
            continue
        linkTarget = None
        if member.issym():
            linkTarget = os.path.realpath(os.path.join(
                base, os.path.dirname(member.name), member.linkname))
        elif member.islnk():
            linkTarget = os.path.realpath(os.path.join(base, member.linkname))
        if linkTarget is not None and not _is_inside(base, linkTarget):
            _logger.warning('Skipping link %s -> %s: target is outside %s'
                            % (member.name, member.linkname, base))
            continue
        yield member


def extract_files(tfile, to):
    """Extract the tar archive `tfile` into directory `to`.

    The archive comes from the network, so members that would land outside
    `to` are skipped. The stdlib 'data' extraction filter is also applied
    where this Python version provides it.

    :return: all member names listed in the archive. Names that could not
        be extracted are logged as errors.
    """
    userFiles = []
    base = os.path.realpath(to)
    with tarfile.open(tfile, 'r' ) as f:
        userFiles = f.getnames( )
        extraKw = {}
        if hasattr(tarfile, 'data_filter'):
            extraKw['filter'] = 'data'
        try:
            f.extractall(base, members=_safe_tar_members(f, base), **extraKw)
        except Exception as e:
            _logger.warning( e)
    # Now check that every file was extracted (relative to `to`, not the cwd).
    for name in userFiles:
        if not os.path.exists(os.path.join(to, name)):
            _logger.error( "File %s could not be extracted." % name )
    return userFiles

def prepareMatplotlib( cwd ):
    """Create ``<cwd>/matplotlibrc`` containing the single line ``interactive : True``.

    It overwrites any existing file of that name.
    """
    rcPath = os.path.join(cwd, 'matplotlibrc')
    rcFile = open(rcPath, 'w')
    try:
        rcFile.write('interactive : True')
    finally:
        rcFile.close()

def send_bz2(conn, data):
    """Send the bytes of a bz2-compressed tar archive to the client, tagged 'TAR'."""
    send_msg(data, conn, prefix='TAR')

def sendResults(tdir, conn, notTheseFiles):
    """Pack the PNG files under `tdir` into a tar.bz2 and send it as 'TAR'.

    :param tdir: directory the simulation ran in.
    :param conn: connected client socket.
    :param notTheseFiles: paths, as returned by find_files(), that existed
        before the simulation. They are excluded so only new files are sent.

    Files are stored in the archive under their base name. The temporary
    archive is always removed, even if sending fails.
    """
    exclude = set(notTheseFiles or ())
    resdir = tempfile.mkdtemp()
    try:
        resfile = os.path.join(resdir, 'results.tar.bz2')
        with tarfile.open( resfile, 'w|bz2') as tf:
            for f in find_files(tdir, ext='png'):
                if f in exclude:
                    continue
                _logger.info( "Adding file %s" % f )
                tf.add(f, os.path.basename(f))

        # Now send the tar file back to the client. The archive is complete
        # once the `with` block above has closed it.
        with open(resfile, 'rb' ) as fh:
            data = fh.read()
        _logger.info( 'Total bytes to send to client: %d' % len(data))
        send_bz2(conn, data)
    finally:
        shutil.rmtree(resdir, ignore_errors=True)

def find_files_to_run( files ):
    """Choose which of the received files to execute.

    1. Every file whose base name contains ``__main`` is run. The client may
       send several such files.
    2. Otherwise, a single received file is run.
    3. Otherwise, run every readable file that both defines ``main(`` and
       calls ``main(`` at the start of some (optionally indented) line.
       Directory entries and files that cannot be decoded as text are ignored.

    :param files: paths/names as listed in the received archive.
    :return: list of files to run (possibly empty).
    """
    toRun = [f for f in files if '__main' in os.path.basename(f)]
    if toRun:
        return toRun
    # Else guess.
    if len(files) == 1:
        return files

    for f in files:
        if not os.path.isfile(f):
            continue  # tar member lists include directories
        try:
            with open(f, 'r' ) as fh:
                txt = fh.read()
        except (IOError, OSError, UnicodeDecodeError):
            continue
        # MULTILINE so '^' matches at every line start; '\(' so `main()`
        # without arguments is recognised too.
        if re.search(r'def\s+main\(', txt) and \
                re.search(r'^[ \t]*main\(', txt, re.MULTILINE):
            toRun.append(f)
    return toRun

def simulate( tfile, conn ):
    """Extract the archive `tfile` into its own directory and run its entry script(s).

    :param tfile: path of the uploaded tar archive.
    :param conn: client socket; output of each run is forwarded to it.
    :return: ``(status, msg)``. `status` is 0 on success and 1 on failure;
        `msg` describes what went wrong ('' on success).
    """
    tdir = os.path.dirname( tfile )
    # NOTE(review): os.chdir changes the cwd of the whole process, so
    # concurrent client threads can interfere with each other here.
    os.chdir( tdir )
    userFiles = extract_files(tfile, tdir)
    # Now simulate.
    toRun = find_files_to_run(userFiles)
    if not toRun:
        # Callers unpack the result, so always return a (status, msg) pair.
        return 1, ("No file to run was found in the archive "
                   "(name the entry script so it contains '__main').")
    prepareMatplotlib(tdir)
    errors = []
    for _file in toRun:
        try:
            run_file(_file, conn, tdir)
        except Exception as e:
            errors.append(str(e))
    return (1 if errors else 0), '; '.join(errors)

def savePayload( conn ):
    """Receive one framed payload from `conn` and store it via writeTarfile().

    :return: ``(tar file path or None, number of bytes received)``.
    """
    payload = recv_input(conn)
    return writeTarfile(payload), len(payload)

def handle_client(conn, ip, port):
    """Serve a single simulation request on `conn`, then close the connection.

    1. Receive one framed tar archive and run it with simulate().
    2. Report a failure, if any, as a 'LOG' message.
    3. Send the 'EOS' marker, then the results archive (sendResults).

    :param conn: accepted client socket (always closed on return).
    :param ip: client address, used for logging.
    :param port: client port, used for logging.
    """
    _logger.info( "Serving request from %s:%s" % (ip, port) )
    try:
        tarfileName, nBytes = savePayload(conn)
        if tarfileName is None or not os.path.isfile(tarfileName):
            # writeTarfile() returns None for anything that is not a tar
            # archive; stop here instead of passing None to os.path.isfile().
            _logger.warning( "Could not recieve data." )
            send_msg("[ERROR] Received data (%d bytes) is not a valid tarfile. Retry"
                     % nBytes, conn)
            return

        # list of files before the simulation.
        notthesefiles = find_files(os.path.dirname(tarfileName))
        res, msg = simulate( tarfileName, conn )
        if 0 != res:
            send_msg( "Failed to run simulation: %s" % msg, conn)
            time.sleep(0.1)

        # Send results after EOS is sent.
        send_msg('All done', conn, 'EOS')
        sendResults(os.path.dirname(tarfileName), conn, notthesefiles)
    except socket.error as e:
        _logger.warning("Error while serving %s:%s: %s" % (ip, port, e))
    finally:
        conn.close()


def start_server( host, port, max_requests = 10 ):
    """Listen on host:port and handle each client in its own thread.

    Loops until stop_all_ becomes True. accept() times out every 10 s so the
    flag is re-checked even when no clients connect.

    :param host: address to bind to.
    :param port: TCP port to bind to.
    :param max_requests: listen backlog (pending connections).
    """
    global sock_
    sock_ = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    sock_.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
    try:
        sock_.bind( (host, port))
        _logger.info( "Server created %s:%s" %(host,port) )
    except Exception as e:
        _logger.error( "Failed to bind: %s" % e)
        sock_.close()
        sys.exit(1)

    sock_.listen(max_requests)
    sock_.settimeout(10)
    try:
        while not stop_all_:
            try:
                conn, (clientIp, clientPort) = sock_.accept()
            except socket.timeout:
                continue
            except socket.error:
                if stop_all_:
                    break  # socket was closed because we are shutting down
                raise
            t = threading.Thread(target=handle_client, args=(conn, clientIp, clientPort))
            t.start()
    finally:
        sock_.close()

def serve(host, port, max_requests=10):
    """Run the MOOSE simulation server on host:port; blocks until it stops.

    :param max_requests: listen backlog forwarded to start_server().
    """
    start_server(host, port, max_requests)

def main( args ):
    """Install the SIGINT handler, then serve on ``args.host:args.port``."""
    address = (args.host, args.port)
    signal.signal(signal.SIGINT, signal_handler)
    serve(*address)

if __name__ == '__main__':
    import argparse

    def _default_host():
        """IP address of this machine's hostname, or 127.0.0.1 if it does not resolve."""
        try:
            return socket.gethostbyname(socket.gethostname())
        except socket.error:
            return '127.0.0.1'

    # Argument parser. '-h' is taken by --host, so help is '--help' only.
    description = '''Run MOOSE server.'''
    parser = argparse.ArgumentParser(description=description, add_help=False)
    parser.add_argument( '--help', action='help', help='Show this msg and exit')
    parser.add_argument('--host', '-h'
        , required = False, default = None
        , help = 'Server Name'
        )
    parser.add_argument('--port', '-p'
        , required = False, default = 31417, type=int
        , help = 'Port number'
        )
    args = parser.parse_args()
    # Resolve the default lazily: an unresolvable hostname must not crash
    # startup when --host was given explicitly.
    if args.host is None:
        args.host = _default_host()
    try:
        main(args)
    except KeyboardInterrupt:
        stop_all_ = True
        sys.exit(1)