#!/usr/bin/env python3

import os
import time
import argparse as ap
import numpy as np
from zenlog import logging as logger
from colorama import Fore
from tqdm import trange
import ctypes
from multiprocessing import Process, Value, Array
from hailo_platform import HEF, HailoPcieObject, InputStreamParams, OutputStreamParams, FormatType
from hailo_platform import HailoUdpControllerObject, HailoPcieObject, SendPipeline, RecvPipeline
from hailo_platform.drivers.hailort.pyhailort import PcieDevice, HailoRTException

def arg_prep():
    """Parse the benchmark's command-line options.

    Returns:
        argparse.Namespace with the attributes ``hef``, ``mode``,
        ``iterations``, ``power``, ``source``, ``interface`` and ``fps``.
    """
    parser = ap.ArgumentParser()
    # Mandatory: without a HEF path, HEF(None) fails later with an obscure error.
    parser.add_argument('--hef', help='Point to the HEF', required=True)
    parser.add_argument('--mode', help='Choose the communication mode [hw-only|full]', type=str, default='full')
    parser.add_argument('--iterations', help='How many repeats on the picture stream (Default:100)', type=int, default=100)
    parser.add_argument('--power', help='Enable Power measurement', default=False, action='store_true')
    parser.add_argument('--source', help='Specify image or video source', default=None)
    parser.add_argument('--interface', help='Specify the physical interface, pcie or udp', default='pcie')
    # 0 disables the FPS emulation: the senders only sleep when fps > 0.
    parser.add_argument('--fps', help='Emulate a source FPS', type=int, default=0)
    return parser.parse_args()

def _pre_infer_hef(stream, input_data):
    dst_buffer = ctypes.create_string_buffer(stream.stream_information.hw_frame_size)
    stream.transform(input_data, ctypes.addressof(dst_buffer), False)
    return dst_buffer.raw


def _recv_process_hw_only_hef(output_streams, iterations, end_time, recv_times):
    """Receive ``iterations`` frames directly from the HW output streams.

    Intended to run in a child process. Every stream in *output_streams* is
    read once per frame. The time at which all outputs of frame ``i`` have
    arrived is stored in ``recv_times[i]`` for the first ``len(recv_times)``
    frames. Frames that fail are not timestamped, so their slot keeps its
    initial value. ``end_time.value`` is set when the loop finishes.

    Args:
        output_streams: Sequence of output streams exposing ``recv()``.
        iterations: Number of frames to receive.
        end_time: Shared double (``multiprocessing.Value``) written on completion.
        recv_times: Shared double array (``multiprocessing.Array``) of per-frame
            receive timestamps, indexed by frame number.
    """
    logger.info('RECV process HW-only Started')
    n_samples = len(recv_times)
    # Buffer timestamps locally and copy them to the shared array once the
    # timed loop is over.
    local_recv_times = {}
    for i in trange(iterations, desc='INFO:Recv...', position=0, bar_format="{l_bar}%s{bar}%s{r_bar}" % (Fore.GREEN, Fore.RESET)):
        try:
            for output_stream in output_streams:
                output_stream.recv()
        except HailoRTException as e:
            # NOTE(review): status 0x4 is presumably HAILO_TIMEOUT - confirm.
            # Timeouts in the last 10% of the run are tolerated silently;
            # any other failure is reported but does not abort the benchmark.
            if i > iterations * 0.9 and '0x4' in str(e):
                continue
            logger.warning("RECV frame {} failed: {}".format(i, e))
            continue
        # A frame counts as received once all of its outputs have arrived;
        # storing by index keeps recv_times aligned with send_times.
        if i < n_samples:
            local_recv_times[i] = time.time()
    for i, t in local_recv_times.items():
        recv_times[i] = t
    end_time.value = time.time()

def _recv_process_full_hef(activated_network, iterations, end_time, recv_times):
    """Receive ``iterations`` frames through a ``RecvPipeline`` (full mode).

    Intended to run in a child process. The outputs are read in
    ``activated_network.target.sorted_output_layer_names`` order, once per
    frame. The time at which all outputs of frame ``i`` have arrived is
    stored in ``recv_times[i]`` for the first ``len(recv_times)`` frames.
    ``end_time.value`` is set when all frames were received.

    Args:
        activated_network: Activated network group to receive from.
        iterations: Number of frames to receive.
        end_time: Shared double (``multiprocessing.Value``) written on completion.
        recv_times: Shared double array (``multiprocessing.Array``) of per-frame
            receive timestamps, indexed by frame number.
    """
    logger.info('RECV process Full Started')
    with RecvPipeline(activated_network) as recv_pipeline:
        output_names = activated_network.target.sorted_output_layer_names
        output_streams = [recv_pipeline.get_output_by_name(output_name) for output_name in output_names]
        for output_name in output_names:
            logger.info("Output: {}".format(output_name))
        n_samples = len(recv_times)
        for i in trange(iterations, desc='INFO:Recv...', position=0, bar_format="{l_bar}%s{bar}%s{r_bar}" % (Fore.GREEN, Fore.RESET)):
            logger.debug("[{}] Recv frame {}/{}".format(time.time(), i, iterations))
            for output_stream in output_streams:
                output_stream.recv()
            # Timestamp once per frame, after its last output arrived.
            if i < n_samples:
                recv_times[i] = time.time()
        end_time.value = time.time()
        logger.debug("[{}] Finished Recving {} frames".format(end_time.value, iterations))

def _send_process_hw_only_hef(input_streams, iterations, fps, send_times):
    """Send ``iterations`` random frames directly to the HW input streams.

    Intended to run in a child process. *input_streams* maps each input
    stream to its shape. Every stream gets its own random uint8 frame of shape
    ``(1,) + shape``. The frame is converted once with ``_pre_infer_hef`` and
    re-sent on every iteration. When ``fps > 0`` the loop sleeps ``1 / fps``
    seconds before each frame (not before each input). The send time of frame
    ``i`` is stored in ``send_times[i]`` for the first ``len(send_times)``
    frames.

    Args:
        input_streams: Mapping ``{input_stream: shape}``.
        iterations: Number of frames to send.
        fps: Source frame rate to emulate; ``0`` sends as fast as possible.
        send_times: Shared double array (``multiprocessing.Array``) of per-frame
            send timestamps, indexed by frame number.
    """
    logger.info('SEND process HW-only Started')
    if fps > 0:
        logger.info("Emulate source FPS: {}".format(fps))
    # One pre-converted buffer per stream: inputs may differ in shape/size,
    # so each stream must be sent its own frame.
    frames = [
        (stream, _pre_infer_hef(stream, np.random.randint(256, size=(1,) + shape, dtype=np.uint8)))
        for stream, shape in input_streams.items()
    ]
    n_samples = len(send_times)
    local_send_times = {}
    for i in range(iterations):
        if fps > 0:
            time.sleep(1.0 / fps)
        # Exactly one timestamp per frame, taken before its first input is sent.
        if i < n_samples:
            local_send_times[i] = time.time()
        try:
            for stream, frame in frames:
                stream.send(frame)
        except HailoRTException as e:
            # NOTE(review): status 0x4 is presumably HAILO_TIMEOUT - confirm.
            # Timeouts in the last 10% of the run are tolerated silently;
            # any other failure is reported but does not abort the benchmark.
            if i > iterations * 0.9 and '0x4' in str(e):
                continue
            logger.warning("SEND frame {} failed: {}".format(i, e))
    for i, t in local_send_times.items():
        send_times[i] = t

def _send_process_full_hef(activated_network, iterations, input_shapes, fps, send_times):
    """Send ``iterations`` random frames through a ``SendPipeline`` (full mode).

    Intended to run in a child process. Each input named in *input_shapes*
    gets its own random uint8 frame of shape ``(1,) + shape``, and all inputs
    are sent once per frame. When ``fps > 0`` the loop sleeps ``1 / fps``
    seconds before each frame (not before each input). The send time of frame
    ``i`` is stored in ``send_times[i]`` for the first ``len(send_times)``
    frames.

    Args:
        activated_network: Activated network group to send to.
        iterations: Number of frames to send.
        input_shapes: Mapping ``{input_layer_name: shape}``.
        fps: Source frame rate to emulate; ``0`` sends as fast as possible.
        send_times: Shared double array (``multiprocessing.Array``) of per-frame
            send timestamps, indexed by frame number.
    """
    logger.info('SEND process Full Started')
    with SendPipeline(activated_network) as send_pipeline:
        # One buffer per input: inputs may differ in shape, so each stream
        # must be sent its own data.
        frames = []
        for input_name, shape in input_shapes.items():
            stream = send_pipeline.get_input_by_name(input_name)
            data = np.random.randint(256, size=(1,) + shape, dtype=np.uint8)
            frames.append((stream, data))
            logger.info("Input: {} with shape: {}".format(input_name, shape))
        if fps > 0:
            logger.info("Emulate source FPS: {}".format(fps))
        n_samples = len(send_times)
        local_send_times = {}
        for i in range(iterations):
            logger.debug("[{}] Send frame {}/{}".format(time.time(), i, iterations))
            if fps > 0:
                time.sleep(1.0 / fps)
            # Exactly one timestamp per frame, taken before its first input is sent.
            if i < n_samples:
                local_send_times[i] = time.time()
            try:
                for stream, data in frames:
                    stream.send(data)
            except HailoRTException as e:
                # NOTE(review): status 0x4 is presumably HAILO_TIMEOUT - confirm.
                # Timeouts in the last 10% of the run are tolerated silently;
                # any other failure is reported but does not abort the benchmark.
                if i > iterations * 0.9 and '0x4' in str(e):
                    continue
                logger.warning("SEND frame {} failed: {}".format(i, e))
        for i, t in local_send_times.items():
            send_times[i] = t

def _stop_processes(processes):
    """Terminate every still-running process in *processes*, then reap the started ones."""
    for process in processes:
        # is_alive() is False for a never-started process, whose terminate()
        # would otherwise fail.
        if process.is_alive():
            process.terminate()
    for process in processes:
        # Joining a process that was never started (pid is None) is an error.
        if process.pid is not None:
            process.join()


def _log_summary(start, end, iterations, send_times, recv_times):
    """Log infer time, average FPS and average latency of one benchmark run.

    Latency is averaged only over frames that have both a send and a receive
    timestamp. Unrecorded slots keep their initial 0.0, so runs shorter than
    the timestamp buffer, or runs with failed frames, do not skew the average.
    """
    elapsed = end - start
    if elapsed <= 0.0:
        logger.warning("No finish time was reported by the RECV process, skipping statistics")
        return
    latencies = [recv - send for send, recv in zip(send_times[:], recv_times[:])
                 if send > 0.0 and recv > 0.0]
    logger.info("-------------------------------------")
    logger.info(" Infer Time:      {:.3f} sec".format(elapsed))
    logger.info(" Average FPS:     {:.3f}".format(iterations / elapsed))
    if latencies:
        logger.info(" Average Latency: {:.3f} ms".format(np.average(latencies) * 1000.0))
    else:
        logger.info(" Average Latency: n/a (no matched send/recv timestamps)")
    logger.info("-------------------------------------")


def run_hef(target, hef, iterations, streaming_mode, fps):
    """Load *hef* onto *target* and benchmark it with a send and a recv process.

    Args:
        target: Opened device exposing ``configure()`` (``main`` passes a
            ``HailoPcieObject``).
        hef: Loaded ``HEF`` object.
        iterations: Number of frames to stream.
        streaming_mode: ``'full'`` streams through ``SendPipeline``/``RecvPipeline``;
            any other value uses the raw HW streams (hw-only mode).
        fps: Source frame rate to emulate; ``0`` sends as fast as possible.

    Logs the infer time, the average FPS and the average latency over the
    first 100 frames. On user interrupt or on a failure while running the
    worker processes, the workers are stopped and no statistics are logged.
    """
    logger.info("Loading HEF to target")
    network = target.configure(hef)[0]

    application_params = network.create_params()
    # Explanation about definition of the streams-
    #    quantized   - Whether to scale and zero-point the values from 0-255, if the input is uint8 can use as-is
    #    format_type - The input type UINT8, UINT16, FLOAT32, AUTO
    input_streams_params = InputStreamParams.make_from_network_group(network, quantized=True, format_type=FormatType.UINT8)
    output_streams_params = OutputStreamParams.make_from_network_group(network, quantized=True, format_type=FormatType.UINT8)
    with network.activate(application_params, input_streams_params=input_streams_params,
                          output_streams_params=output_streams_params) as activated_network:
        # Per-frame timestamps are only kept for the first `latency_samples` frames.
        latency_samples = 100
        send_times = Array(ctypes.c_double, [0.0] * latency_samples)
        recv_times = Array(ctypes.c_double, [0.0] * latency_samples)
        end_time = Value(ctypes.c_double, 0.0)

        if streaming_mode == 'full':
            input_shapes = {layer_info.name: layer_info.shape for layer_info in hef.get_input_layers_info()}
            recv_process = Process(target=_recv_process_full_hef, args=(activated_network, iterations, end_time, recv_times))
            send_process = Process(target=_send_process_full_hef, args=(activated_network, iterations, input_shapes, fps, send_times))
        else:
            input_streams = {activated_network.get_input_by_name(l.name): l.shape for l in hef.get_input_layers_info()}
            send_process = Process(target=_send_process_hw_only_hef, args=(input_streams, iterations, fps, send_times))
            output_streams = [activated_network.get_output_by_name(l.name) for l in hef.get_output_layers_info()]
            recv_process = Process(target=_recv_process_hw_only_hef, args=(output_streams, iterations, end_time, recv_times))

        processes = (send_process, recv_process)
        start = time.time()
        logger.info("[{}] Starting Inference".format(start))
        try:
            for process in processes:
                process.start()
            # Ctrl-C normally arrives while waiting here, so join() must be
            # inside the try as well.
            for process in processes:
                process.join()
        except KeyboardInterrupt:
            logger.info("Interrupted by the user, stopping..")
            _stop_processes(processes)
            return
        except Exception:
            logger.exception("Exception happened, stopping..")
            _stop_processes(processes)
            return
        logger.info("[{}] Finished Inference".format(end_time.value))
        for process in processes:
            if process.exitcode != 0:
                logger.warning("{} exited with code {}".format(process.name, process.exitcode))

        _log_summary(start, end_time.value, iterations, send_times, recv_times)

def main():
    """Parse the CLI, open the PCIe device and run the HEF benchmark."""
    logger.basicConfig(level=logger.INFO)
    args = arg_prep()
    # Only a PCIe device is opened below, and --power/--source are not
    # consumed anywhere; warn instead of silently ignoring the request.
    if args.interface != 'pcie':
        logger.warning("Interface '{}' is not supported, using pcie".format(args.interface))
    if args.power or args.source is not None:
        logger.warning("--power and --source are not implemented and will be ignored")
    logger.info('Reading HEF from: {}'.format(args.hef))
    config = HEF(args.hef)
    with HailoPcieObject() as target:
        run_hef(target, config, args.iterations, args.mode, args.fps)


if __name__ == "__main__":
    main()
