# ------------------------------------------------
#                    Imports
# ------------------------------------------------

from matplotlib import pyplot as plt
from datetime import datetime
import os
import json
import re
import sys, getopt


# ------------------------------------------------
#                   Globals
# ------------------------------------------------

# Directory that holds the exported conversation JSON files; the --path
# command-line option is meant to replace it.
DATA_PATH = './data/'

# Month filter. The --month option sets MONTH_MODE to True and stores its
# argument split on '/' (['mm', 'yyyy'], as strings) in MONTH.
MONTH_MODE = False
MONTH = None

# Pie-chart label of the bucket that groups the smallest participants
# (French for "the others").
OTHER_LABEL = 'Les Autres'

# Keys used in the exported JSON documents
PARTICIPANTS = 'participants'
MESSAGES = 'messages'
NAME = 'name'
CONTENT = 'content'
TIMESTAMP = 'timestamp_ms'  # epoch timestamp in milliseconds (divided by 1000 before use)
SENDER = 'sender_name'

# User-facing option summary (in French), printed by -h/--help and by
# printHelp() on invalid arguments.
# NOTE(review): the --month entry has no description; it restricts the
# analysis to the messages sent during the given month.
HELP = """General options :
    -h, --help          Consulter l'aide
    --path=<path>       Redéfinir le chemin d'accès aux données (par défaut ./data)
    --month <mm/yyyy>
    """

# ------------------------------------------------
#                   Functions
# ------------------------------------------------

def printHelp():
    """Print the usage line and the option summary, then exit with status 2.

    Called when the command line cannot be parsed.
    """
    script_name = os.path.basename(__file__)
    print(f'Usage:\n {script_name} <command> [option]\n')
    print(HELP)
    sys.exit(2)

def handleArguments(argv):
    """Parse the command-line options and store them in the module settings.

    Options:
        -h, --help          print HELP and exit with status 0
        --path=<path>       set DATA_PATH
        --month=<mm/yyyy>   set MONTH_MODE to True and MONTH to ['mm', 'yyyy']
                            (both parts must be integers, month in 1-12)

    An unknown option, a missing option argument or a malformed month makes
    printHelp() print the usage and exit with status 2. Positional arguments
    are ignored.

    Args:
        argv: the command-line arguments without the program name
            (sys.argv[1:]).
    """
    # DATA_PATH must be declared global too: otherwise "--path" only binds a
    # local variable and the option silently has no effect.
    global DATA_PATH
    global MONTH_MODE
    global MONTH

    try:
        opts, _args = getopt.getopt(argv, 'h', ['help', 'path=', 'month='])
    except getopt.GetoptError:
        printHelp()  # exits

    for opt, arg in opts:
        if opt in ('-h', '--help'):
            print(HELP)
            sys.exit()
        elif opt == '--path':
            DATA_PATH = arg
        elif opt == '--month':
            parts = arg.split('/')
            try:
                # Unpacking raises ValueError unless there are exactly two
                # fields; int() raises ValueError on non-numeric text.
                month, _year = (int(part) for part in parts)
            except ValueError:
                printHelp()  # exits
            if not 1 <= month <= 12:
                printHelp()  # exits
            MONTH_MODE = True
            MONTH = parts

def readBrokenFbJson(datafile_path):
    r"""Load a Facebook JSON export and repair its text encoding.

    The export writes each byte of a UTF-8 encoded character as its own
    \u00XX escape (e.g. "é" becomes \u00c3\u00a9). A plain json.load would
    therefore produce mojibake. Each such escape is turned back into the raw
    byte, and the whole file is then decoded as UTF-8.
    See https://stackoverflow.com/questions/50008296/facebook-json-badly-encoded

    Only escapes for bytes 0x80-0xFF are rewritten. Those are the only bytes
    found in multi-byte UTF-8 sequences. Leaving \u0000-\u007f escapes intact
    keeps escaped quotes, backslashes and control characters valid JSON. An
    escape whose backslash is itself escaped (\\u00e9, i.e. the literal text
    "\u00e9") is left alone as well.

    Returns:
        The decoded JSON document.
    Raises:
        OSError if the file cannot be read; UnicodeDecodeError or
        json.JSONDecodeError if the repaired content is not valid.
    """
    # Group 1: an even-length run of backslashes (escaped backslashes) that is
    # not preceded by another backslash. Group 2: the hex byte of a \u00XX
    # escape in the 0x80-0xFF range.
    pattern = re.compile(rb'(?<!\\)((?:\\\\)*)\\u00([89a-fA-F][0-9a-fA-F])')

    def restore_byte(match):
        return match.group(1) + bytes.fromhex(match.group(2).decode('ascii'))

    with open(datafile_path, 'rb') as data_file:
        binary_data = data_file.read()

    repaired = pattern.sub(restore_byte, binary_data)
    return json.loads(repaired.decode('utf8'))

def computeData():
    """Load every exported conversation file found in DATA_PATH.

    Every regular file directly inside DATA_PATH is parsed with
    readBrokenFbJson, with no filtering on name or extension.
    Sub-directories are skipped because they cannot be opened as files.

    Returns:
        (participants, messages): the result of cleanParticipants on all
        participant entries and of cleanMessages on all message entries.
    """
    participants, messages = [], []

    # os.listdir order is arbitrary; sorting makes runs reproducible.
    for filename in sorted(os.listdir(DATA_PATH)):
        # os.path.join works whether or not DATA_PATH ends with a separator
        # (e.g. "--path=./data", as the help text suggests).
        datafile_path = os.path.join(DATA_PATH, filename)
        if not os.path.isfile(datafile_path):
            continue

        datacontent = readBrokenFbJson(datafile_path)
        if datacontent is None:  # a file whose JSON content is just `null`
            continue

        participants += datacontent[PARTICIPANTS]
        messages += datacontent[MESSAGES]

    return cleanParticipants(participants), cleanMessages(messages)

def cleanParticipants(rawParticipants):
    """Return the set of distinct participant names.

    Args:
        rawParticipants: participant dicts, each carrying a NAME key.
    """
    return {entry[NAME] for entry in rawParticipants}

def cleanMessages(rawMessages):
    """Keep the messages that have text and return them in chronological order.

    A message is kept when it has a CONTENT key. Entries without one are
    dropped; presumably these are non-text items (verify against the export
    format). When MONTH_MODE is on, the message must also have been sent
    during the month given by MONTH = ['mm', 'yyyy']. That date is computed
    with datetime.fromtimestamp, i.e. in the local time zone of the machine
    running the script.

    Returns:
        The kept message dicts, sorted by ascending TIMESTAMP.
    """
    # Parse the month filter once rather than once per message.
    selected = (int(MONTH[0]), int(MONTH[1])) if MONTH_MODE else None

    def sent_in_selected_month(message):
        # TIMESTAMP is in milliseconds; fromtimestamp expects seconds.
        sent = datetime.fromtimestamp(message[TIMESTAMP] / 1000)
        return (sent.month, sent.year) == selected

    kept = [
        message for message in rawMessages
        if CONTENT in message and (not MONTH_MODE or sent_in_selected_month(message))
    ]
    return sorted(kept, key=lambda message: message[TIMESTAMP])

# TODO: try an incremental-search approach.
# On the 14/10/2021 dataset, 33679 messages were kept for a final count of
# 34120, i.e. an estimated loss of about 1.3%.
def filterMessages(messages):
    """Keep only the messages whose CONTENT contains a number.

    A message is kept when its text contains a run of two or more digits
    anywhere, or when the whole text is a single digit ("$" also tolerates
    one trailing newline). A lone digit inside other text, e.g. "a 7 b", is
    not enough.

    Args:
        messages: message dicts that all have a CONTENT key (such as the
            output of cleanMessages).
    Returns:
        The matching messages, in their original order.
    """
    # Raw string: "\d" in a plain literal is an invalid escape sequence
    # (DeprecationWarning, and SyntaxWarning since Python 3.12).
    number_pattern = re.compile(r'\d{2,}|^\d$')
    return [msg for msg in messages if number_pattern.search(msg[CONTENT])]

def computeParticipation(messages):
    """Count the messages sent by each SENDER.

    Returns:
        A list of (sender, count) pairs sorted by ascending count. Senders
        with equal counts keep the order in which they first appear, because
        the sort is stable and dicts preserve insertion order.
    """
    tally = {}
    for msg in messages:
        countParticipation(tally, msg[SENDER], msg)

    ranking = list(tally.items())
    ranking.sort(key=lambda pair: pair[1])
    return ranking

def countParticipation(participations, sender, message):
    """Increment the count of *sender* in the *participations* dict, in place.

    A sender not yet present starts at 1. *message* is not used here.
    """
    previous = participations.get(sender, 0)
    participations[sender] = previous + 1

def mergeSmallParticipation(rawParticipation, threshold = 1):
    """Group the smallest participants into a single OTHER_LABEL bucket.

    Args:
        rawParticipation: (label, count) pairs sorted by ascending count, as
            returned by computeParticipation. The order matters: only the
            leading run of entries below the threshold is merged.
        threshold: minimum share, in percent of the total count, that an
            entry needs to keep its own slot.

    Returns:
        (values, labels): two parallel lists. When at least one entry falls
        below the threshold, the first element is the merged OTHER_LABEL
        bucket. Otherwise the entries are returned unchanged, with no empty
        bucket.
    """
    labels = [label for label, _ in rawParticipation]
    values = [count for _, count in rawParticipation]

    total = sum(values)
    if total == 0:
        # No entries (or only zero counts): there is no share to compute.
        return values, labels

    # Index of the first entry that reaches the threshold. If none does,
    # every entry belongs to the merged bucket.
    cut = next(
        (idx for idx, value in enumerate(values) if 100 * value / total >= threshold),
        len(values),
    )
    if cut == 0:
        return values, labels
    return [sum(values[:cut])] + values[cut:], [OTHER_LABEL] + labels[cut:]

def displayParticipation(participation):
    """Show the participation as a matplotlib pie chart.

    The entries first go through mergeSmallParticipation with its default
    threshold. The wedges are then drawn clockwise starting from the top
    (startangle=90, counterclock=False), with each label rotated to its
    wedge's angle.
    """
    merged_values, merged_labels = mergeSmallParticipation(participation)

    pie_style = {'startangle': 90, 'counterclock': False, 'rotatelabels': True}
    plt.figure(figsize=(8, 7), tight_layout=True)
    plt.pie(merged_values, labels=merged_labels, **pie_style)
    plt.show()


def consoleDisplay(participations):
    """Print every entry of *participations* on its own line (its str() form)."""
    for entry in participations:
        print(entry)
# ------------------------------------------------
#                   Main Code
# ------------------------------------------------

def main(argv):
    """Script entry point.

    Parses *argv*, loads and filters the messages, then reports each
    sender's message count on the console and as a pie chart.
    """
    handleArguments(argv)

    _participants, allMessages = computeData()
    numberedMessages = filterMessages(allMessages)
    ranking = computeParticipation(numberedMessages)

    consoleDisplay(ranking)
    displayParticipation(ranking)

if __name__ == "__main__":
    main(sys.argv[1:])