#!/usr/bin/env python3
'''
A simple Python example showing how to:
  1. Load a compiled network (HEF) onto a Hailo-8 device
  2. Set up the device for running inference
  3. Stream images into the device
  4. Read out the results
'''

from __future__ import division

import os
import time
import argparse as ap
import numpy as np
from zenlog import logging as logger
from colorama import Fore
from tqdm import trange
import ctypes
from multiprocessing import Process, Value, Array
from hailo_platform import (HEF, PcieDevice, EthernetDevice,HailoStreamInterface, InferVStreams, ConfigureParams,
                            InputVStreamParams, OutputVStreamParams, InputVStreams, OutputVStreams, FormatType)

import yolov5_postprocess
import cv2 

from past.utils import old_div
import tensorflow as tf

# Number of video frames that are pre-loaded below; also the default value of
# the --iterations command-line option.
TOTAL_IMAGE = 1200

# Pre-loaded video frames, each resized to 640x640.  _recv_process draws the
# detections of frame N onto g_img2[N] before displaying it.
g_img2 = []

g_cap = cv2.VideoCapture("./full_mov_slow.mp4")
try:
    for i in range(TOTAL_IMAGE):
        g_ret, g_imgs = g_cap.read()
        if not g_ret:
            # read() returns (False, None) when the file cannot be opened or
            # the video is shorter than TOTAL_IMAGE frames.  Fail with a
            # readable message instead of an obscure cv2.resize assertion.
            raise RuntimeError(
                "Could not read frame {} of {} from ./full_mov_slow.mp4".format(i, TOTAL_IMAGE))
        # The third positional parameter of cv2.resize is `dst`, not the
        # interpolation mode, and cv2.IMREAD_UNCHANGED is an imread flag, so
        # the flag that used to be passed here was dropped; the default
        # (bilinear) interpolation is used.
        g_imgs = cv2.resize(g_imgs, (640, 640))
        g_img2.append(g_imgs)
finally:
    # Always give the video file handle back, even when loading fails.
    g_cap.release()



# Class-id -> class-name table, built once at import time instead of on
# every call of get_coco_name_from_int.
_COCO_CLASS_NAMES = {
    0: "__background__",
    1: "person",
    2: "bicycle",
    3: "car",
    4: "motorcycle",
    5: "airplane",
    6: "bus",
    7: "train",
    8: "truck",
    9: "boat",
    10: "traffic light",
    11: "fire hydrant",
    12: "stop sign",
    13: "parking meter",
    14: "bench",
    15: "bird",
    16: "cat",
    17: "dog",
    18: "horse",
    19: "sheep",
    20: "cow",
    21: "elephant",
    22: "bear",
    23: "zebra",
    24: "giraffe",
    25: "backpack",
    26: "umbrella",
    27: "handbag",
    28: "tie",
    29: "suitcase",
    30: "frisbee",
    31: "skis",
    32: "snowboard",
    33: "sports ball",
    34: "kite",
    35: "baseball bat",
    36: "baseball glove",
    37: "skateboard",
    38: "surfboard",
    39: "tennis racket",
    40: "bottle",
    41: "wine glass",
    42: "cup",
    43: "fork",
    44: "knife",
    45: "spoon",
    46: "bowl",
    47: "banana",
    48: "apple",
    49: "sandwich",
    50: "orange",
    51: "broccoli",
    52: "carrot",
    53: "hot dog",
    54: "pizza",
    55: "donut",
    56: "cake",
    57: "chair",
    58: "couch",
    59: "potted plant",
    60: "bed",
    61: "dining table",
    62: "toilet",
    63: "tv",
    64: "laptop",
    65: "mouse",
    66: "remote",
    67: "keyboard",
    68: "cell phone",
    69: "microwave",
    70: "oven",
    71: "toaster",
    72: "sink",
    73: "refrigerator",
    74: "book",
    75: "clock",
    76: "vase",
    77: "scissors",
    78: "teddy bear",
    79: "hair drier",
    80: "toothbrush",
}


def get_coco_name_from_int(cls):
    """Return the class name for a numeric class id.

    Args:
        cls: Class id; ids 0..80 are known (0 is "__background__").

    Returns:
        The class name, or the string "None" when the id is not in the table.
    """
    return _COCO_CLASS_NAMES.get(cls, "None")


# File and directory names used as defaults by this example.
# The ImageNet label list comes from:
# https://github.com/anishathalye/imagenet-simple-labels
# NOTE(review): IMAGENET_LBL_FILE and INPUT_IMAGE_DIR are not referenced by
# any function in this file - presumably left over from an image
# classification variant of this example; confirm before removing.
IMAGENET_LBL_FILE = 'imagenet-simple-labels.json'
INPUT_IMAGE_DIR   = './images_net'
# File name of a compiled network (HEF); judging by the name, a YOLOv5m model.
HEF_FILE          = 'yolov5m.hef'



def arg_prep():
    """Parse the command-line options of this example.

    Returns:
        argparse.Namespace with the attributes ``hef``, ``mode``,
        ``iterations``, ``power``, ``source``, ``interface`` and ``fps``.
    """
    parser = ap.ArgumentParser()
    # Fall back to the module-level HEF_FILE so that running without --hef
    # no longer hands None to the HEF loader.
    parser.add_argument('--hef', help='Point to the HEF (Default:%(default)s)', default=HEF_FILE)
    parser.add_argument('--mode', help='Choose the communication mode [hw-only|full]', type=str, default='full')
    # %(default)s keeps the help text in sync with the real default value.
    parser.add_argument('--iterations', help='How many repeats on the picture stream (Default:%(default)s)',
                        type=int, default=TOTAL_IMAGE)
    parser.add_argument('--power', help='Enable Power measurement', default=False, action='store_true')
    parser.add_argument('--source', help='Specify image or video source', default=None)
    parser.add_argument('--interface', help='Specify the physical interface, pcie or udp', default='pcie')
    parser.add_argument('--fps', help='Emulate a source FPS', type=int, default=0)
    # parser.add_argument('--ip', help='Set the IP of the Hailo device', type=str, default='10.0.0.100')
    return parser.parse_args()










def _recv_process(activated_network, num_frames, end_time, recv_times):
    """Receive inference results, post-process them and display detections.

    Runs as a separate process.  For each of ``num_frames`` frames one buffer
    is read from every output vstream, in iteration order.  The first two
    buffers are held back; when the third one (index 2) arrives, all three are
    handed to ``yolov5_postprocess.run``.  Every sixth frame the resulting
    boxes and a running FPS figure are drawn onto the pre-loaded frame
    ``g_img2[y]`` and shown in an OpenCV window.

    Args:
        activated_network: Network group used to create the output vstream
            parameters and the output vstreams.
        num_frames: Number of frames to receive.
        end_time: Shared value; its ``.value`` is set to ``time.time()`` once
            all frames have been received.
        recv_times: Not written by this function; kept so the signature
            matches the caller.

    Reads the module globals ``g_img2`` (frames to draw on) and ``g_start``
    (inference start time set by run_hef).
    """
    vstreams_params = OutputVStreamParams.make_from_network_group(activated_network)

    logger.info('RECV process Full Started')
    with OutputVStreams(activated_network, vstreams_params) as vstreams:
        outputs = dict()
        for y in range(num_frames):
            for x, vstream in enumerate(vstreams):
                if x != 2:
                    # Hold back the buffers of the other output vstreams
                    # until the one at index 2 arrives.
                    outputs[x] = vstream.recv()
                    continue

                # NOTE(review): the numeric arguments are presumably the
                # per-output quantization zero-points/scales and a score
                # threshold - confirm against yolov5_postprocess.run.
                detections = yolov5_postprocess.run(vstream.recv(), outputs[0], outputs[1], 0.0, 0, 0.0, 0.003921144641935825, 0.0, 0.00392112368717789659, 0.2)

                if y % 6 != 0:
                    # Only every sixth frame is drawn and displayed.
                    continue

                frame = g_img2[y]
                # The first four columns of every detection are scaled by the
                # 640-pixel frame size to get pixel coordinates.
                boxes = 640 * (np.reshape(np.array([det[:4] for det in detections]), (-1, 4)))

                elapsed = time.time() - g_start
                fpsl = int(y / elapsed) if elapsed > 0 else 0
                strl = "FPS: " + str(fpsl)

                for box in boxes:
                    # cv2 points are (x, y); the box columns are used here as
                    # (y_min, x_min, y_max, x_max).
                    cv2.rectangle(frame, (int(box[1]), int(box[0])), (int(box[3]), int(box[2])), (80, 220, 80), 3)
                # Draw the FPS label once per displayed frame.  It used to be
                # redrawn for every detection, and was missing entirely on
                # frames without detections.
                cv2.putText(frame, strl, (50, 50), cv2.FONT_HERSHEY_PLAIN, 2.0, (0, 0, 255), 4)

                cv2.imshow("Image", frame)
                cv2.waitKey(1)

    end_time.value = time.time()
    logger.info("[{}] Finished Recving {} frames".format(end_time.value, num_frames))

def _send_process(configured_network, num_frames, input_shapes, fps, send_times):
    """Read frames from the demo video and stream them into the device.

    Runs as a separate process.  Each frame is read from
    ``./full_mov_slow.mp4``, resized to 640x640, given a leading batch axis
    and sent to every input vstream of the network group.

    Args:
        configured_network: Network group used to create the input vstream
            parameters and the input vstreams.
        num_frames: Number of frames to send.
        input_shapes: Mapping of input layer name to shape; only logged.
        fps: When greater than zero, frames are sent no faster than this
            rate; otherwise they are sent as fast as possible.
        send_times: Not written by this function; kept so the signature
            matches the caller.

    Raises:
        RuntimeError: If the video cannot deliver ``num_frames`` frames.
    """
    vstreams_params = InputVStreamParams.make_from_network_group(configured_network)
    with InputVStreams(configured_network, vstreams_params) as vstreams:
        for name, shape in input_shapes.items():
            logger.info("Input: {} with shape: {}".format(name, shape))

        frame_period = 0.0
        if fps > 0.0:
            logger.info("Emulate source FPS: {}".format(fps))
            frame_period = 1.0 / fps

        cap = cv2.VideoCapture("./full_mov_slow.mp4")
        try:
            next_send = time.time()
            for i in range(num_frames):
                grabbed, frame = cap.read()
                if not grabbed:
                    # read() returns (False, None) at end of video or when the
                    # file could not be opened.
                    raise RuntimeError(
                        "Could not read frame {} of {} from ./full_mov_slow.mp4".format(i, num_frames))

                # The third positional parameter of cv2.resize is `dst`, so
                # the imread flag that used to be passed there was dropped;
                # the default (bilinear) interpolation is used.
                batch = np.expand_dims(np.asarray(cv2.resize(frame, (640, 640))), axis=0)

                if frame_period > 0.0:
                    # Pace the sender so it emulates a source running at `fps`.
                    wait = next_send - time.time()
                    if wait > 0.0:
                        time.sleep(wait)
                    next_send = max(next_send + frame_period, time.time())

                # One frame is read per iteration and the same frame goes to
                # every input vstream.
                for vstream in vstreams:
                    vstream.send(batch)
        finally:
            cap.release()

def _stop_workers(workers):
    """Terminate every worker process that is still running."""
    for worker in workers:
        # is_alive() is False for a process that was never started, so this
        # is safe even when start() itself failed.
        if worker.is_alive():
            worker.terminate()


def run_hef(target, hef, iterations, streaming_mode, fps):
    """Configure the HEF on the target and run streaming inference.

    The first network group of the HEF is configured over the Ethernet stream
    interface and activated.  A sender process (_send_process) and a receiver
    process (_recv_process) are then started, and after both have finished
    the run time and average FPS are logged.

    Args:
        target: Device object providing ``configure(hef, params)``.
        hef: Loaded HEF object.
        iterations: Number of frames to send and receive.
        streaming_mode: Currently unused.
        fps: Source frame rate handed to the sender process.

    Sets the module global ``g_start`` to the inference start time before the
    worker processes are started.
    NOTE(review): the receiver reads ``g_start``, so this relies on the
    workers inheriting module globals (fork start method) - confirm on
    platforms that use spawn.
    """
    global g_start
    logger.info("Loading HEF to target")

    # Configure network groups
    configure_params = ConfigureParams.create_from_hef(hef=hef, interface=HailoStreamInterface.ETH)
    network_groups = target.configure(hef, configure_params)
    network_group = network_groups[0]
    network_group_params = network_group.create_params()

    # The vstream parameters are created inside the worker processes, so none
    # are prepared here.
    with network_group.activate(network_group_params):
        send_times = Array(ctypes.c_double, [0.0] * 100)
        recv_times = Array(ctypes.c_double, [0.0] * 100)
        end_time = Value(ctypes.c_double, 0.0)

        input_shapes = {layer_info.name: layer_info.shape for layer_info in hef.get_input_layers_info()}
        recv_process = Process(target=_recv_process, args=(network_group, iterations, end_time, recv_times))
        send_process = Process(target=_send_process, args=(network_group, iterations, input_shapes, fps, send_times))
        workers = (send_process, recv_process)

        start = time.time()
        g_start = start
        try:
            logger.info("[{}] Starting Inference".format(start))
            send_process.start()
            recv_process.start()
            # Wait inside the try block so that Ctrl-C during inference is
            # handled below; only start() used to be covered.
            send_process.join()
            recv_process.join()
        except KeyboardInterrupt:
            logger.info("Interrupted by the user, stopping..")
            _stop_workers(workers)
        except Exception:
            logger.exception("Exception happened, stopping..")
            _stop_workers(workers)
        finally:
            for worker in workers:
                # join() is only valid for a process that has been started.
                if worker.pid is not None:
                    worker.join()
        logger.info("[{}] Finished Inference".format(end_time.value))

        elapsed = end_time.value - start
        if elapsed <= 0.0:
            # end_time keeps its initial 0.0 when the receiver never finished.
            logger.warning("Receiver did not report an end time; no statistics available")
            return

        # Only pairs in which both timestamps were actually recorded count.
        latencies = [r - t for r, t in zip(recv_times, send_times) if r > 0.0 and t > 0.0]
        logger.info("-------------------------------------")
        logger.info(" Infer Time:      {:.3f} sec".format(elapsed))
        logger.info(" Average FPS:     {:.3f}".format(iterations / elapsed))
        if latencies:
            logger.info(" Average Latency: {:.3f} ms".format(np.average(latencies) * 1000.0))
        else:
            logger.info(" Average Latency: n/a (no timestamps recorded)")
        logger.info("-------------------------------------")

def main():
    """Script entry point: parse the options, load the HEF and run it.

    The device is always opened as an EthernetDevice at a fixed IP address.
    Of the parsed options only hef, iterations, mode and fps are used here.
    """
    logger.basicConfig(level=logger.INFO)
    options = arg_prep()

    hef_path = options.hef
    logger.info('Reading HEF from: {}'.format(hef_path))
    compiled_network = HEF(hef_path)

    # The device is used as a context manager so that it is released as soon
    # as the run has finished.
    device_address = '10.0.0.93'
    with EthernetDevice(device_address) as target:
        run_hef(target, compiled_network, options.iterations, options.mode, options.fps)

    # PCIe alternative, kept for reference:
    #target = PcieDevice(hw_arch='hailo8')
    #with PcieDevice() as target:
    #    run_hef(target, hef, args.iterations, args.mode, args.fps)


if __name__ == "__main__":
    main()
