#ifndef _YOLOV5_POSTPROCESS_H_
#define _YOLOV5_POSTPROCESS_H_

#include <stdio.h>
#include <stdlib.h>

#include <algorithm>
#include <cmath>
#include <functional>
#include <string>
#include <vector>

#include <example_utils.hpp>
#include <hailo/hailort.h>
#ifdef DEBUG
#include "npy.hpp"
#endif

// Grid side length (rows == cols) of the three output feature maps; _decode()
// passes these to extract_boxes() for fm1, fm2 and fm3 respectively.
constexpr int FEATURE_MAP_SIZE1 = 20;
constexpr int FEATURE_MAP_SIZE2 = 40;
constexpr int FEATURE_MAP_SIZE3 = 80;
// Channels per anchor: 0-3 box coordinates, 4 box confidence, 5-84 class scores.
constexpr int FEATURE_MAP_CHANNELS = 85;
// Side length used to convert between normalized and pixel box coordinates.
constexpr int IMAGE_SIZE = 640;
// Number of anchors evaluated at every grid cell.
constexpr int ANCHORS_NUM = 3;
// Same-class boxes whose IoU reaches this value are suppressed in _decode().
constexpr float32_t IOU_THRESHOLD = 0.45f;
// Initial capacity reserved for the detection list in _decode(); it is not
// enforced as an upper limit on the number of detections.
constexpr int MAX_BOXES = 100;
// Channel index (within one anchor) of the box confidence.
constexpr int CONF_CHANNEL_OFFSET = 4;
// Channel index (within one anchor) of the first class score.
constexpr int CLASS_CHANNEL_OFFSET = 5;


/*
    One decoded detection: the box corners, its final confidence score and its class id.
    extract_boxes() fills the corners in pixels of the IMAGE_SIZE x IMAGE_SIZE input.
*/
struct DetectionObject {
    float ymin;
    float xmin;
    float ymax;
    float xmax;
    float confidence;
    int class_id;

    DetectionObject(float32_t ymin, float32_t xmin, float32_t ymax, float32_t xmax, float32_t confidence, int class_id)
        : ymin(ymin),
          xmin(xmin),
          ymax(ymax),
          xmax(xmax),
          confidence(confidence),
          class_id(class_id)
    {
    }

    // Deliberately inverted ordering: an object compares "less" when it is MORE
    // confident, so std::sort() arranges detections by descending confidence.
    bool operator<(const DetectionObject &s2) const {
        return s2.confidence < confidence;
    }
};

/*
    Dequantizes a raw network output value: (input - qp_zp) * qp_scale.
    Inputs:
        input:    raw (quantized) value
        qp_scale: quantization scale
        qp_zp:    quantization zero point
    Declared inline because this header defines it; a non-inline definition would
    produce duplicate symbols when the header is included from several source files.
*/
inline float32_t fix_scale(float32_t input, float32_t qp_scale, float32_t qp_zp)
{
  return (input - qp_zp) * qp_scale;
}

/*
    Intersection-over-union of two boxes.
    Returns a value in [0, 1] for well-formed boxes. When the union area is not
    positive (e.g. both boxes are degenerate) it returns 0 instead of the NaN a
    0/0 division would yield.
    Declared inline because it is defined in a header.
*/
inline float32_t iou_calc(const DetectionObject &box_1, const DetectionObject &box_2) {
    // Overlap extents are clamped at zero so disjoint boxes give an empty intersection.
    const float32_t overlap_width =
        std::max(std::min(box_1.xmax, box_2.xmax) - std::max(box_1.xmin, box_2.xmin), 0.0f);
    const float32_t overlap_height =
        std::max(std::min(box_1.ymax, box_2.ymax) - std::max(box_1.ymin, box_2.ymin), 0.0f);
    const float32_t area_of_overlap = overlap_width * overlap_height;

    const float32_t box_1_area = (box_1.ymax - box_1.ymin) * (box_1.xmax - box_1.xmin);
    const float32_t box_2_area = (box_2.ymax - box_2.ymin) * (box_2.xmax - box_2.xmin);
    const float32_t area_of_union = box_1_area + box_2_area - area_of_overlap;

    // Written as a negated comparison so that a NaN union is rejected as well.
    if (!(area_of_union > 0.0f)) {
        return 0.0f;
    }
    return area_of_overlap / area_of_union;
}

void extract_boxes(std::vector<uint8_t> &fm, float32_t qp_zp, float32_t qp_scale, int feature_map_size, 
		   int* anchors, std::vector<DetectionObject>& objects, float32_t& thr) {
    float32_t  confidence, x, y, h, w, xmin, ymin, xmax, ymax, conf_max = 0.0f;
    int add = 0, anchor = 0, chosen_row = 0, chosen_col = 0, chosen_cls = -1;
    float32_t cls_prob, prob_max;
    // channels 0-3 are box coordinates, channel 4 is the confidence, and channels 5-84 are classes
    for (int row = 0; row < feature_map_size; ++row) {
        for (int col = 0; col < feature_map_size; ++col) {
            prob_max = 0;
            for (int a = 0; a < ANCHORS_NUM; ++a) {
                add = FEATURE_MAP_CHANNELS * ANCHORS_NUM * feature_map_size * row + FEATURE_MAP_CHANNELS * ANCHORS_NUM * col + FEATURE_MAP_CHANNELS * a + CONF_CHANNEL_OFFSET;
                confidence = fix_scale(fm[add], qp_scale,  qp_zp);
                for (int c = CLASS_CHANNEL_OFFSET; c < FEATURE_MAP_CHANNELS; ++c) {
                    add = FEATURE_MAP_CHANNELS * ANCHORS_NUM * feature_map_size * row + FEATURE_MAP_CHANNELS * ANCHORS_NUM * col + FEATURE_MAP_CHANNELS * a + c;
                    // final confidence: box confidence * class probability
                    cls_prob = fm[add];
                    if (cls_prob > prob_max) {
		                conf_max = fix_scale(cls_prob, qp_scale,  qp_zp) * confidence;
                        chosen_cls = c - CLASS_CHANNEL_OFFSET + 1;
                        prob_max = cls_prob;
                        anchor = a;
                        chosen_row = row;
                        chosen_col = col;
                    }
                }
            }
            if (conf_max >= thr) {
                add = FEATURE_MAP_CHANNELS * ANCHORS_NUM * feature_map_size * chosen_row + FEATURE_MAP_CHANNELS * ANCHORS_NUM * chosen_col + FEATURE_MAP_CHANNELS * anchor;
                x = (fix_scale(fm[add], qp_scale,  qp_zp) * 2.0f - 0.5f + chosen_col) / feature_map_size;
                y = (fix_scale(fm[add + 1], qp_scale,  qp_zp) * 2.0f - 0.5f +  chosen_row) / feature_map_size;
                w = pow(2.0f * (fix_scale(fm[add + 2], qp_scale,  qp_zp)), 2.0f) * anchors[anchor * 2] / IMAGE_SIZE;
                h = pow(2.0f * (fix_scale(fm[add + 3], qp_scale,  qp_zp)), 2.0f) * anchors[anchor * 2 + 1] / IMAGE_SIZE;
                xmin = (x - (w / 2.0f)) * IMAGE_SIZE;
                ymin = (y - (h / 2.0f)) * IMAGE_SIZE;
                xmax = (x + (w / 2.0f)) * IMAGE_SIZE;
                ymax = (y + (h / 2.0f)) * IMAGE_SIZE;
                objects.push_back(DetectionObject(ymin, xmin, ymax, xmax, conf_max, chosen_cls));
            }
        }
    }
}
size_t _decode(std::vector<uint8_t> &fm1, std::vector<uint8_t> &fm2, std::vector<uint8_t> &fm3, int* anchors1, int* anchors2, int* anchors3,
// xt::xarray<float32_t, xt::layout_type::row_major> _decode(std::vector<uint8_t> &fm1, std::vector<uint8_t> &fm2, std::vector<uint8_t> &fm3, int* anchors1, int* anchors2, int* anchors3,
                 qp_zp_scale_t qp_zp_scale, float32_t& thr, std::vector<float32_t> &results) {

    size_t num_boxes = 0;
    std::vector<DetectionObject> objects;
    // std::vector<std::vector<float32_t>> results;
    objects.reserve(MAX_BOXES);

    // feature map1/2/3
    extract_boxes(fm1, qp_zp_scale.qp_zp_1, qp_zp_scale.qp_scale_1, FEATURE_MAP_SIZE1, anchors1, objects, thr);
    extract_boxes(fm2, qp_zp_scale.qp_zp_2, qp_zp_scale.qp_scale_2, FEATURE_MAP_SIZE2, anchors2, objects, thr);
    extract_boxes(fm3, qp_zp_scale.qp_zp_3, qp_zp_scale.qp_scale_3, FEATURE_MAP_SIZE3, anchors3, objects, thr);

    num_boxes = objects.size();

    // filter by overlapping boxes
    if (objects.size() > 0) {
        std::sort(objects.begin(), objects.end());
        for (unsigned int i = 0; i < objects.size(); ++i) {
            if (objects[i].confidence <= thr)
                continue;
            for (unsigned int j = i + 1; j < objects.size(); ++j) {
                if (objects[i].class_id == objects[j].class_id && objects[j].confidence >= thr) {
                    if (iou_calc(objects[i], objects[j]) >= IOU_THRESHOLD) {
                        objects[j].confidence = 0;
                        num_boxes -= 1;
                    }
                }
            }
        }
    }

    // copy the results
    if (num_boxes > 0) {
        int box_ptr = 0;
        // xt::xarray<int>::shape_type shape({num_boxes, 6});
        // xt::xarray<float32_t, xt::layout_type::row_major> results(shape);

        // results = (float32_t *)calloc(num_boxes * 6, sizeof(float32_t));
        results.resize(num_boxes * 6);
        for (const auto &obj: objects) {
            if (obj.confidence >= thr) {
                /*
                results(box_ptr, 0) = obj.ymin / IMAGE_SIZE;
                results(box_ptr, 1) = obj.xmin / IMAGE_SIZE;
                results(box_ptr, 2) = obj.ymax / IMAGE_SIZE;
                results(box_ptr, 3) = obj.xmax / IMAGE_SIZE;
                results(box_ptr, 4) = (float32_t)obj.class_id;
                results(box_ptr, 5) = obj.confidence;
                */
                results[box_ptr*6 + 0] = obj.ymin / IMAGE_SIZE;
                results[box_ptr*6 + 1] = obj.xmin / IMAGE_SIZE;
                results[box_ptr*6 + 2] = obj.ymax / IMAGE_SIZE;
                results[box_ptr*6 + 3] = obj.xmax / IMAGE_SIZE;
                results[box_ptr*6 + 4] = (float32_t)obj.class_id;
                results[box_ptr*6 + 5] = obj.confidence;                
                box_ptr += 1;
            }
        }
        // return results;
        return num_boxes;
    } else {
        // return xt::zeros<float32_t>({6});
        results.resize(0);
        // float32_t results[6] = {0, 0, 0, 0, 0, 0};
        return 0;
    }
}

/*
    Given all parameters this function returns boxes with class and confidence
    Inputs:
        feature map1: 20x20x255
        feature map2: 40x40x255
        feature map3: 80x80x255
        qp_zp_scale:  quantization zero point / scale of each feature map
        thr:          minimum confidence of a reported box
    Outputs:
        results: final boxes for display (Nx6, flattened row by row) -
                 ymin, xmin, ymax, xmax, class, conf.
                 Coordinates are divided by IMAGE_SIZE; class ids are 1-based
                 (see get_coco_name_from_int). Empty when nothing was detected.
    Returns N, the number of boxes written to results.

    Declared inline because it is defined in a header.
*/
inline size_t get_detections(std::vector<uint8_t> fm1, std::vector<uint8_t> fm2, std::vector<uint8_t> fm3,
                           qp_zp_scale_t qp_zp_scale, float32_t thr, std::vector<float32_t> &results) {

    // Anchor (width, height) pairs, one set of ANCHORS_NUM pairs per feature map.
    int anchors1[] = {116, 90, 156, 198, 373, 326};
    int anchors2[] = {30,  61, 62,  45,  59,  119};
    int anchors3[] = {10,  13, 16,  30,  33,  23};

    // The local copies are lvalues and bind to _decode's reference parameters directly.
    return _decode(fm1, fm2, fm3, anchors1, anchors2, anchors3, qp_zp_scale, thr, results);
}

//https://stackoverflow.com/questions/28562401/resize-an-image-to-a-square-but-keep-aspect-ratio-c-opencv
/*
    Resizes an image into a target_width x target_width square while keeping its aspect
    ratio; the longer side fills the square and the remainder is padded evenly on both
    sides with the given gray level.
    Inputs:
        img:          source image, must not be empty
        target_width: side length of the square output, must be positive
        color:        value used for every channel of the padding
    Returns the padded square image (same type as img).
    Throws cv::Exception (via CV_Assert) for an empty image or a non-positive target size.
*/
inline cv::Mat letterbox( const cv::Mat& img, int target_width = 640, int color = 114 )
{
    // An empty image would otherwise lead to a division by zero below.
    CV_Assert( !img.empty() && target_width > 0 );

    const int width = img.cols;
    const int height = img.rows;

    cv::Mat square( target_width, target_width, img.type(), cv::Scalar(color, color, color) );

    const int max_dim = std::max( width, height );
    const float32_t scale = static_cast<float32_t>( target_width ) / max_dim;

    // Scales the shorter side; clamped so that float rounding or an extreme aspect
    // ratio can never yield an empty or oversized region.
    const auto scaled_side = [&]( int side ) {
        return std::min( target_width, std::max( 1, static_cast<int>( side * scale ) ) );
    };

    cv::Rect roi;
    if (width >= height) {
        roi.width = target_width;
        roi.x = 0;
        roi.height = scaled_side( height );
        roi.y = ( target_width - roi.height ) / 2;
    } else {
        roi.y = 0;
        roi.height = target_width;
        roi.width = scaled_side( width );
        roi.x = ( target_width - roi.width ) / 2;
    }

    // square( roi ) is a view into `square`, so the resized image lands inside the padding.
    cv::resize( img, square( roi ), roi.size() );
    return square;
}

typedef cv::Vec<uchar, 12> Vec12b;

/*
    Packs every 2x2 block of 3-channel pixels into one 12-channel pixel.
    For the output pixel (r, c) the source pixels (2r + i, 2c + j), i, j in {0, 1}, are
    stored at channels (i * 2 + j) * 3 .. (i * 2 + j) * 3 + 2.
    Inputs:
        input: 8-bit 3-channel image (CV_8UC3)
    Returns an image of (input.rows / 2) x (input.cols / 2) pixels with 12 channels;
    a 640x640 input therefore yields the 320x320x12 layout. A trailing odd row or
    column of the input is ignored.
    Throws cv::Exception (via CV_Assert) when a non-empty input is not CV_8UC3.
*/
inline cv::Mat yolov5_input_reshape(cv::Mat &input)
{
    CV_Assert( input.empty() || input.type() == CV_8UC3 );

    const int out_rows = input.rows / 2;
    const int out_cols = input.cols / 2;
    cv::Mat output( out_rows, out_cols, CV_8UC(12) );

    // Mat::at takes (row, column): the outer index walks rows, the inner one columns.
    for (int out_r = 0; out_r < out_rows; ++out_r) {
        for (int out_c = 0; out_c < out_cols; ++out_c) {
            Vec12b &output_pixel = output.at<Vec12b>(out_r, out_c);
            for (int inner_i = 0; inner_i < 2; ++inner_i) {
                for (int inner_j = 0; inner_j < 2; ++inner_j) {
                    const cv::Vec3b &pixel = input.at<cv::Vec3b>(2 * out_r + inner_i, 2 * out_c + inner_j);
                    const int channel_base = (inner_i * 2 + inner_j) * 3;
                    output_pixel.val[channel_base + 0] = pixel.val[0];
                    output_pixel.val[channel_base + 1] = pixel.val[1];
                    output_pixel.val[channel_base + 2] = pixel.val[2];
                }
            }
        }
    }
    return output;
}

/*
    Prepares a frame for the network by letterboxing it with letterbox()'s default
    target size and padding color. The 12-channel repacking done by
    yolov5_input_reshape() is not applied here.
    Inputs:
        org_frame: original frame
    Returns the letterboxed frame.
*/
inline cv::Mat yolov5_pre_process(cv::Mat &org_frame)
{
    cv::Mat letter = letterbox(org_frame);
#ifdef DEBUG
    // Debug aid only: encoding and writing a JPEG for every frame is too costly for
    // regular builds, so the dump is limited to DEBUG builds.
    cv::imwrite("letter_boxed.jpg", letter);
#endif
    return letter;
}

/*
    Maps a class id to its COCO label.
    Inputs:
        cls: class id; 0 is "__background__", 1..80 are the object classes
    Returns the label, or "N/A" for an id outside 0..80.
    Declared inline because it is defined in a header.
*/
inline std::string get_coco_name_from_int(int cls)
{
    // Indexed directly by class id.
    static constexpr const char* COCO_NAMES[] = {
        "__background__",
        "person", "bicycle", "car", "motorcycle", "airplane",
        "bus", "train", "truck", "boat", "traffic light",
        "fire hydrant", "stop sign", "parking meter", "bench", "bird",
        "cat", "dog", "horse", "sheep", "cow",
        "elephant", "bear", "zebra", "giraffe", "backpack",
        "umbrella", "handbag", "tie", "suitcase", "frisbee",
        "skis", "snowboard", "sports ball", "kite", "baseball bat",
        "baseball glove", "skateboard", "surfboard", "tennis racket", "bottle",
        "wine glass", "cup", "fork", "knife", "spoon",
        "bowl", "banana", "apple", "sandwich", "orange",
        "broccoli", "carrot", "hot dog", "pizza", "donut",
        "cake", "chair", "couch", "potted plant", "bed",
        "dining table", "toilet", "tv", "laptop", "mouse",
        "remote", "keyboard", "cell phone", "microwave", "oven",
        "toaster", "sink", "refrigerator", "book", "clock",
        "vase", "scissors", "teddy bear", "hair drier", "toothbrush"
    };
    constexpr int NAMES_COUNT = static_cast<int>(sizeof(COCO_NAMES) / sizeof(COCO_NAMES[0]));
    static_assert(NAMES_COUNT == 81, "expected the background label plus 80 COCO classes");

    if (cls < 0 || cls >= NAMES_COUNT) {
        return "N/A";
    }
    return COCO_NAMES[cls];
}

#endif
