#!/usr/bin/env python3

import os
from glob import glob
import time
from multiprocessing import Process
import numpy as np

import argparse as ap
from hailo_platform import HailoUdpControllerObject, HailoPcieObject
from tqdm import trange
from colorama import Fore
from hailo_platform.drivers.hw_object import InferTypesName
from hailo_platform.drivers.hailo_controller.power_measurement import (DvmTypes, PowerMeasurementTypes,
                                                                       SamplingPeriod, AveragingFactor)
# NOTE(review): module-level value that no code visible in this file reads or
# writes — confirm whether it is still needed.
sent_time = 0

def arg_prep(argv=None):
    """Parse the command-line options of the streaming benchmark.

    Args:
        argv (list[str] | None): Arguments to parse. ``None`` (the default)
            makes argparse read ``sys.argv[1:]``.

    Returns:
        argparse.Namespace: with ``ip``, ``jlf_dir``, ``mode``, ``iterations``,
        ``power`` and ``source`` attributes.
    """
    parser = ap.ArgumentParser()
    parser.add_argument('--ip', help='Set the IP of the Hailo device', type=str, default='10.0.0.100')
    parser.add_argument('--jlf-dir', help='Point to the dir that contains 4 JLFs (boot, meta_data, config, params)', default='./JLFs')
    # choices= makes argparse reject an unknown mode up front, instead of the
    # script loading the model onto the board and then running nothing.
    parser.add_argument('--mode', help='Choose the communication mode [hw-only|full]', type=str,
                        choices=('hw-only', 'full'), default='full')
    parser.add_argument('--iterations', help='How many repeats on the picture stream (Default:100)', type=int, default=100)
    parser.add_argument('--power', help='Enable Power measurement', default=False, action='store_true')
    # NOTE(review): --source is parsed but not used by any code visible in this file.
    parser.add_argument('--source', help='Specify image or video source', default=None)
    return parser.parse_args(argv)

def get_jlfs(jlf_dir):
    """Read every ``*.mem`` and ``*.jlf`` file in ``jlf_dir`` as raw bytes.

    All ``.mem`` files come before all ``.jlf`` files. Within each extension the
    paths are sorted, so the load order is reproducible; ``glob`` on its own
    returns files in an arbitrary, filesystem-dependent order.

    Args:
        jlf_dir (str): Directory containing the compiled model files.

    Returns:
        list[bytes]: The file contents in the order described above. The list
        is empty when no matching file exists.
    """
    # NOTE(review): confirm whether target.load_jlfs is sensitive to this order.
    paths = []
    for ext in ('mem', 'jlf'):
        paths.extend(sorted(glob(os.path.join(jlf_dir, '*.{}'.format(ext)))))

    jlfs = []
    for path in paths:
        with open(path, 'rb') as f:
            jlfs.append(f.read())
    return jlfs

def _initialize_board(remote_ip, jlf_dir, arch='pcie'):
    """Setup initialization function that loads a model to the device."""

    print('Initializing hardware object...')
    if arch=='udp':
        target = HailoUdpControllerObject(remote_ip)
    elif arch=='pcie':
        target = HailoPcieObject(arch)
    else:
        print("Error arch given: {}".format(arch))

    print('Loading compiled JLFs to device...')
    jlfs = get_jlfs(jlf_dir)
    target.load_jlfs(jlfs)
    return target

def _recv_proc(target, images_count):
    """Drain ``images_count`` results from every output layer of ``target``.

    This is started as a child ``Process`` by ``_run_full_streaming``. The
    received data is discarded; only a progress bar is shown.

    Args:
        target: Hailo device object that is already inside ``use_device``.
        images_count (int): Number of images to receive per output layer.
    """
    target.setup_recv()
    layer_names = target.sorted_output_layer_names
    print('-I- Started Receive for {}'.format(images_count))
    progress_fmt = "{l_bar}%s{bar}%s{r_bar}" % (Fore.GREEN, Fore.RESET)
    for _ in trange(images_count, desc='-I- Recv...', position=0, bar_format=progress_fmt):
        for layer_name in layer_names:
            target.recv(layer_name)

def _power_proc(target, time_to_measure_power):
    """Measure VDD_CORE power on ``target`` for a while and print Avg/Max.

    This is started as a child ``Process`` by ``_run_full_streaming``. The
    measurement window is raised to at least 3 seconds, because shorter windows
    give unreliable readings.

    Args:
        target: Hailo device object exposing ``control``.
        time_to_measure_power (float): Requested measurement window in seconds.
    """
    min_measure_sec = 3
    delay_milliseconds = 1
    target.control.set_power_measurement(0, DvmTypes.VDD_CORE, PowerMeasurementTypes.POWER)
    target.control.start_power_measurement(delay=delay_milliseconds, averaging_factor=AveragingFactor.AVERAGE_64)
    # Clamp before reporting so the "Measure power for" line shows the window
    # that is actually slept.
    if time_to_measure_power < min_measure_sec:
        print('-W- time_to_measure_power was updated to {} from {:.3f} sec, since it\'s unreliable to measure power over short times'.format(min_measure_sec, time_to_measure_power))
        time_to_measure_power = min_measure_sec
    print('-I- Measure power for {:.2f} sec, Power measure delay is {}ms'.format(time_to_measure_power, delay_milliseconds))
    try:
        time.sleep(time_to_measure_power)
        measurements = target.control.get_power_measurement(0, DvmTypes.VDD_CORE, PowerMeasurementTypes.POWER)
    finally:
        # Always stop the measurement that was started above, even if the
        # sleep or the read fails.
        target.control.stop_power_measurement()
    print('-I- Power Avg/Max is {:.3f}/{:.3f} mW'.format(measurements.average, measurements.max_value))

def _run_full_streaming(target, images_count, power):
    """Benchmark full inference, with host-side input translation and output rescaling.

    This process sends ``images_count`` copies of one random image, while a
    child process (``_recv_proc``) receives the outputs. When ``power`` is set, a
    second child process (``_power_proc``) measures power once more than 10% of
    the images have been sent. Throughput and FPS are printed. Both are timed
    from the first send until the last send returns; waiting for the receiver
    is not included.

    Args:
        target: Hailo device object returned by ``_initialize_board``.
        images_count (int): Number of images to send.
        power (bool): Whether to run the power-measurement process.
    """
    print('-I- Running full inference...')
    p_recv = Process(target=_recv_proc, args=(target, images_count))
    power_process_started = False
    p_power = None
    # NOTE(review): these dataflow configurations are adjusted here but, unlike
    # in the hw-only path, are not passed to use_device() — confirm intent.
    input_conf, output_conf = target.generate_dataflow_configuration()
    for outp in output_conf.values():
        outp.should_send_sync = True
    for inp in input_conf.values():
        inp.should_send_sync = False
    with target.use_device(translate_input=True, rescale_outputs=True, python_pipeline=False):
        p_recv.start()
        target.setup_send()
        # list() so the batch prefix below also works if the shape is a tuple.
        input_shape = list(target.get_input_shape()[1:])
        print('-I- Loading dataset, each image has a shape of {}'.format(input_shape))
        dataset_shape = [1] + input_shape
        # np.random.random_integers is deprecated. randint's upper bound is
        # exclusive, so 21 keeps the inclusive [2, 20] range.
        dataset = np.random.randint(2, 21, dataset_shape).astype(np.float32)
        start_time = time.time()
        try:
            for image_id in range(images_count):
                target.send(dataset)

                # Power sampling starts once 10% of the images have been sent.
                # It lasts 6x the time that took, i.e. until roughly 70% are sent.
                if power and not power_process_started and image_id > images_count * 0.1:
                    time_to_measure_power = (time.time() - start_time) * 6
                    p_power = Process(target=_power_proc, args=(target, time_to_measure_power))
                    p_power.start()
                    power_process_started = True
            end_time = time.time()
        finally:
            if p_power is not None:
                p_power.join()
            p_recv.join()
    elapsed = end_time - start_time
    # image_size counts elements per image.
    # NOTE(review): confirm what the *8 factor represents for the "MB/sec" label.
    image_size = int(input_shape[1] * input_shape[2] * input_shape[0])
    print('-I----------------------------')
    print('-I- Throughput:        {:.3f} MB/sec'.format(images_count * image_size * 8 / (elapsed * 1024 * 1024)))
    print('-I- FPS:               {:.3f}'.format(images_count / elapsed))
    print('-I----------------------------')

def _send_proc(images_count, input_dataflow, data):
    for _ in range(images_count):
        input_dataflow.send(data)

def _run_hw_only_streaming(target, images_count):
    """Benchmark hardware-only streaming on ``target``.

    One random image is pre-processed once (``pre_infer``). A child process
    (``_send_proc``) then pushes it ``images_count`` times into the first input
    stream, while this process reads one ``recv(1)`` result per image from every
    output stream. Throughput and FPS are printed. Both are timed from just
    before the sender starts until the last result is received.

    Args:
        target: Hailo device object returned by ``_initialize_board``.
        images_count (int): Number of images to push through the device.
    """
    print('-I- Running hw-only inference')
    in_cfg, out_cfg = target.generate_dataflow_configuration()
    for out_item in out_cfg.values():
        out_item.should_send_sync = True
    for in_item in in_cfg.values():
        in_item.should_send_sync = False
    with target.use_device(InferTypesName.diy, input_dataflow_configuration=in_cfg,
                           output_dataflow_configuration=out_cfg, translate_input=True):
        full_shape = target.get_input_shape()
        per_image_shape = full_shape[1:]
        batch_shape = [1] + per_image_shape
        print('-I- Loading dataset, each image has a shape of {}'.format(per_image_shape))
        dataset = np.random.randint(2, 20, batch_shape).astype(np.float32)
        in_streams, out_streams = target.get_streams()
        # NOTE(review): only the first input stream is fed — confirm this is
        # sufficient for multi-input networks.
        first_input = in_streams[next(iter(in_streams))]
        payload = first_input.pre_infer(dataset)
        sender = Process(target=_send_proc, args=(images_count, first_input, payload))
        started_at = time.time()
        sender.start()
        try:
            progress_fmt = "{l_bar}%s{bar}%s{r_bar}" % (Fore.GREEN, Fore.RESET)
            for _ in trange(images_count, desc='-I- Recv...', position=0, bar_format=progress_fmt):
                for out_stream in out_streams.values():
                    # The two returned values are unpacked (and discarded) exactly as before.
                    _r1, _r2 = out_stream.recv(1)
            finished_at = time.time()
            frame_elems = int(full_shape[1] * full_shape[2] * full_shape[3])
            duration = finished_at - started_at
            print('-I----------------------------')
            print('-I- Throughput: {:.3f} MB/sec'.format(images_count * frame_elems * 8 / (duration * 1024 * 1024)))
            print('-I- FPS: {:.3f}'.format(images_count / duration))
            print('-I----------------------------')
        finally:
            sender.join()

def streaming_example(remote_ip, streaming_mode, jlf_dir, iterations, power, arch='pcie'):
    """Streaming example: load the model onto the device and run one benchmark.

    Args:
        remote_ip (str): Board IP address (used only when ``arch == 'udp'``).
        streaming_mode (str): ``'hw-only'`` skips the pre-infer and post-infer
            steps on the host; ``'full'`` performs them.
        jlf_dir (str): Directory holding the compiled ``.mem``/``.jlf`` files.
        iterations (int): How many images to run.
        power (bool): Measure power during the run. Only used in ``'full'`` mode.
        arch (str): Device transport passed to ``_initialize_board``; defaults
            to ``'pcie'``.

    Raises:
        ValueError: If ``streaming_mode`` is neither ``'hw-only'`` nor ``'full'``.
    """
    # Validate before touching the device, so that a bad mode does not load
    # the model and then silently run nothing.
    if streaming_mode not in ('hw-only', 'full'):
        raise ValueError("Unknown streaming_mode {!r}; expected 'hw-only' or 'full'".format(streaming_mode))

    target = _initialize_board(remote_ip, jlf_dir, arch=arch)

    if streaming_mode == 'hw-only':
        _run_hw_only_streaming(target, iterations)
    else:
        _run_full_streaming(target, iterations, power)

if __name__ == '__main__':
    # Parse the command line and run a single benchmark pass with it.
    cli_args = arg_prep()
    streaming_example(
        remote_ip=cli_args.ip,
        jlf_dir=cli_args.jlf_dir,
        streaming_mode=cli_args.mode,
        iterations=cli_args.iterations,
        power=cli_args.power,
    )
