package com.securityandsafetythings.examples.aiapp.aicore.aiLibs.aiInference.detector.human;

import android.app.Application;
import android.content.Context;
import android.graphics.Bitmap;
import android.graphics.Matrix;

import com.qualcomm.qti.snpe.FloatTensor;
import com.qualcomm.qti.snpe.NeuralNetwork;
import com.qualcomm.qti.snpe.SNPE;
import com.securityandsafetythings.examples.aiapp.aicore.aiLibs.InferenceResult;
import com.securityandsafetythings.examples.aiapp.aicore.aiLibs.aiInference.DLCInference;
import com.securityandsafetythings.examples.aiapp.aicore.aiLibs.aiInference.detector.Bbox;

import org.jetbrains.annotations.Contract;
import org.jetbrains.annotations.NotNull;
import org.opencv.android.Utils;
import org.opencv.core.Core;
import org.opencv.core.CvType;
import org.opencv.core.Mat;
import org.opencv.core.Scalar;
import org.opencv.imgproc.Imgproc;

import java.io.IOException;
import java.io.InputStream;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

public class ScrfdPersonDetector extends DLCInference {
    static String LOGTAG = ScrfdPersonDetector.class.getSimpleName();

    /** Number of anchors predicted at every feature-map location (both share the same centre point). */
    private static final int NUM_ANCHORS_PER_LOCATION = 2;
    /** Box-distance values per anchor: left, top, right, bottom offsets (in units of the stride). */
    private static final int DISTANCES_PER_ANCHOR = 4;

    /**
     * Names of the per-anchor score output tensors, index-aligned with {@link #STRIDE_LIST}.
     * NOTE(review): these are the tensor names returned by {@code execute()}, not the layer names passed
     * to {@code setOutputLayers}; presumably "498" is the output of Sigmoid_158, "519" of Sigmoid_179, etc.
     * — confirm against the DLC if the model is re-exported.
     */
    private static final String[] SCORE_OUTPUT_NAMES = {"498", "519", "540", "561", "582"};
    /** Names of the per-anchor box-distance output tensors, index-aligned with {@link #STRIDE_LIST}. */
    private static final String[] DISTANCE_OUTPUT_NAMES = {"501", "522", "543", "564", "585"};

    /** variable for model I/O */
    protected String INPUT_LAYER;
    protected float[] floatArrayInputValues;
    protected FloatTensor inputTensor;
    protected Map<String, FloatTensor> modelInputSource;
    protected Map<String, FloatTensor> modelOutputSource;

    /** Scale factors from the source frame to the model input size; refreshed on every preProcess call. */
    private float mRatioWidth;
    private float mRatioHeight;
    public static float CONFIDENCE_THRESHOLD = 0.4f;
    public static float TRUST_THRESHOLD = 0.95f;
    public static float IOU_THRESHOLD = 0.3f;
    private final int[] STRIDE_LIST = new int[]{8, 16, 32, 64, 128};

    /** Per-channel input normalisation: (pixel - 127.5) / 128. */
    private final Scalar normalizeMean = new Scalar(127.5f, 127.5f, 127.5f);
    private final Scalar normalizeScale = new Scalar(128f, 128f, 128f);
    /** Working Mats reused across frames to avoid re-allocating them per inference. */
    private Mat processMatRGBA;
    private Mat processMatRGB;
    private Mat processMat32F;

    /**
     * Loads the SCRFD person-detection DLC from a raw resource and prepares the input tensor.
     *
     * <p>If the model cannot be read, the error is printed and {@code network} stays {@code null};
     * {@link #runInference} then fails fast with an {@link IllegalStateException}.
     *
     * @param context     used to open the raw model resource
     * @param application passed to the SNPE network builder
     * @param modelResId  raw resource id of the .dlc model
     * @param runtimeMode preferred SNPE runtime; CPU is configured as the fallback
     */
    public ScrfdPersonDetector(@NotNull Context context, Application application, int modelResId, NeuralNetwork.Runtime runtimeMode) {
        super(640, 384);
        // try-with-resources: the model stream was previously never closed.
        try (InputStream modelInputStream = context.getResources().openRawResource(modelResId)) {
            network = new SNPE.NeuralNetworkBuilder(application)
                    .setDebugEnabled(false)
                    .setRuntimeOrder(
                            runtimeMode,
                            NeuralNetwork.Runtime.CPU
                    )
                    .setModel(modelInputStream, modelInputStream.available())
                    .setOutputLayers("Sigmoid_158", "Sigmoid_179", "Sigmoid_200", "Sigmoid_221", "Sigmoid_242",
                            "Reshape_161", "Reshape_182", "Reshape_203", "Reshape_224", "Reshape_245")
                    .setCpuFallbackEnabled(true)
                    .setPerformanceProfile(NeuralNetwork.PerformanceProfile.HIGH_PERFORMANCE)
                    .build();
        } catch (IOException e) {
            e.printStackTrace();
        }

        /* Set up the input layer/tensor and the reusable pre-processing Mats only if the model built. */
        if (network != null) {
            floatArrayInputValues = new float[IMAGE_WIDTH * IMAGE_HEIGHT * 3];
            INPUT_LAYER = network.getInputTensorsNames().iterator().next();
            inputTensor = network.createFloatTensor(network.getInputTensorsShapes().get(INPUT_LAYER));
            modelInputSource = new HashMap<>();
            // Register the (single, reused) input tensor now so release() frees it even if no inference ran.
            modelInputSource.put(INPUT_LAYER, inputTensor);
            processMatRGBA = new Mat();
            processMatRGB = new Mat();
            processMat32F = new Mat();
        }
    }

    /** Resets the shared detection thresholds to this model's defaults. */
    @Override
    protected void setModelConfig() {
        CONFIDENCE_THRESHOLD = 0.4f;
        IOU_THRESHOLD = 0.3f;
        TRUST_THRESHOLD = 0.95f;
    }

    /**
     * Resizes {@code inputBitmap} to the model input size, normalises it and writes it into the input tensor.
     * Also records the scale ratios used later to map boxes back to the source frame.
     */
    @Override
    protected void preProcess(@NotNull final Bitmap inputBitmap) {
        mRatioWidth = ((float) IMAGE_WIDTH) / ((float) inputBitmap.getWidth());
        mRatioHeight = ((float) IMAGE_HEIGHT) / ((float) inputBitmap.getHeight());

        Bitmap processBitmap = Bitmap.createScaledBitmap(inputBitmap, IMAGE_WIDTH, IMAGE_HEIGHT, false);
        if (processBitmap.getConfig() != Bitmap.Config.ARGB_8888) {
            // Convert the *scaled* bitmap so the resize is kept; bitmapToMat does not need a mutable bitmap.
            processBitmap = processBitmap.copy(Bitmap.Config.ARGB_8888, false);
            if (processBitmap == null) {
                throw new IllegalStateException(LOGTAG + ": unable to convert input bitmap to ARGB_8888");
            }
        }
        Utils.bitmapToMat(processBitmap, processMatRGBA);
        // NOTE(review): an older comment here claimed COLOR_RGBA2RGB, but the value used (3) is COLOR_RGBA2BGR,
        // so the network is fed BGR. Kept as-is to preserve deployed behaviour; confirm the channel order the
        // DLC expects.
        Imgproc.cvtColor(processMatRGBA, processMatRGB, Imgproc.COLOR_RGBA2BGR);
        processMatRGB.convertTo(processMat32F, CvType.CV_32F);
        Core.subtract(processMat32F, normalizeMean, processMat32F);
        Core.divide(processMat32F, normalizeScale, processMat32F);

        /* Put the pre-processed image into the input tensor ready for execution. */
        prepairModelTensorsInput(processMat32F);
    }

    /** Not used by this detector; input comes from {@link InferenceResult#getProcessBitmap()} instead. */
    @Override
    protected void preProcess(InferenceResult inferenceResult) {

    }

    /** Copies the interleaved float pixels of {@code preprocessedMat} into the model input tensor. */
    protected void prepairModelTensorsInput(@NotNull Mat preprocessedMat){
        preprocessedMat.get(0, 0, floatArrayInputValues);
        inputTensor.write(floatArrayInputValues, 0, floatArrayInputValues.length);
        modelInputSource.put(INPUT_LAYER, inputTensor);
    }

    /**
     * Runs pre-processing, the network and SCRFD decoding/NMS on the frame held by {@code inputInference}.
     *
     * @throws IllegalStateException if the model failed to load or has already been released
     */
    @Override
    public HDResult runInference(@NotNull InferenceResult inputInference) {
        if (network == null || modelInputSource == null) {
            throw new IllegalStateException(LOGTAG + ": model is not loaded (build failed or already released)");
        }
        preProcess(inputInference.getProcessBitmap());
        modelOutputSource = network.execute(modelInputSource);
        try {
            return postProcess(inputInference);
        } finally {
            // Output tensors are produced per execute() call; free them so native memory does not grow per frame.
            if (modelOutputSource != null) {
                releaseTensors(modelOutputSource);
                modelOutputSource = null;
            }
        }
    }

    /** Reads the score and box-distance tensors for every stride and decodes them into detections. */
    @Override
    protected HDResult postProcess(InferenceResult inputInference) {
        Map<Integer, float[]> scoreList = new HashMap<>();
        Map<Integer, float[]> vectorBoxDistanceList = new HashMap<>();

        for (int i = 0; i < STRIDE_LIST.length; i++) {
            FloatTensor scoreTensor = modelOutputSource.get(SCORE_OUTPUT_NAMES[i]);
            if (scoreTensor != null) {
                scoreList.put(STRIDE_LIST[i], readTensor(scoreTensor));
            }
            FloatTensor distanceTensor = modelOutputSource.get(DISTANCE_OUTPUT_NAMES[i]);
            if (distanceTensor != null) {
                vectorBoxDistanceList.put(STRIDE_LIST[i], readTensor(distanceTensor));
            }
        }
        return decodeScrfdPerson(inputInference, scoreList, vectorBoxDistanceList);
    }

    /** Copies the full contents of {@code tensor} into a new float array. */
    @NotNull
    private static float[] readTensor(@NotNull FloatTensor tensor) {
        float[] values = new float[tensor.getSize()];
        tensor.read(values, 0, values.length);
        return values;
    }

    /**
     * Builds candidate boxes for every stride, orders them globally and applies NMS.
     *
     * @throws IllegalStateException if the outputs for any stride are missing
     */
    @NotNull
    @Contract("_, _, _ -> new")
    private HDResult decodeScrfdPerson(InferenceResult previousInference, Map<Integer, float[]> scoreList, Map<Integer, float[]> vectorBoxDistanceList){
        List<Bbox> candidates = new ArrayList<>();
        for (int stride : STRIDE_LIST) {
            float[] scores = scoreList.get(stride);
            float[] distances = vectorBoxDistanceList.get(stride);
            if (scores == null || distances == null) {
                throw new IllegalStateException(LOGTAG + ": missing SCRFD output for stride " + stride);
            }
            candidates.addAll(buildBbox(previousInference, distances, scores, stride));
        }
        // Greedy NMS keeps boxes in list order, so sort across ALL strides (not per stride) using Bbox's
        // natural ordering — presumably descending confidence; see Bbox.compareTo.
        Collections.sort(candidates);
        return new HDResult(previousInference, nms(candidates));
    }

    /**
     * Decodes the anchors of one stride into boxes in source-frame coordinates, keeping only those whose
     * score exceeds {@link #CONFIDENCE_THRESHOLD}. Anchors are laid out row-major over the feature map with
     * {@link #NUM_ANCHORS_PER_LOCATION} consecutive anchors per location.
     *
     * @throws IllegalStateException if the arrays are too short for the expected anchor grid
     */
    @NotNull
    private List<Bbox> buildBbox (InferenceResult previousInference, float[] distanceList, float[] scoreList, int stride){
        final int rows = (IMAGE_HEIGHT + stride - 1) / stride;
        final int cols = (IMAGE_WIDTH + stride - 1) / stride;
        final int anchorCount = rows * cols * NUM_ANCHORS_PER_LOCATION;
        if (scoreList.length < anchorCount || distanceList.length < anchorCount * DISTANCES_PER_ANCHOR) {
            throw new IllegalStateException(LOGTAG + ": stride " + stride + " expects " + anchorCount
                    + " anchors but got " + scoreList.length + " scores / " + distanceList.length + " distances");
        }

        List<Bbox> listBbox = new ArrayList<>();
        int anchor = 0;
        for (int y = 0; y < IMAGE_HEIGHT; y += stride) {
            for (int x = 0; x < IMAGE_WIDTH; x += stride) {
                for (int a = 0; a < NUM_ANCHORS_PER_LOCATION; a++, anchor++) {
                    float score = scoreList[anchor];
                    if (score > CONFIDENCE_THRESHOLD) {
                        int base = anchor * DISTANCES_PER_ANCHOR;
                        // Distances are scaled by the stride, then divided by the resize ratio to map back
                        // to the source frame.
                        float x1 = (x - distanceList[base] * stride) / mRatioWidth;
                        float y1 = (y - distanceList[base + 1] * stride) / mRatioHeight;
                        float x2 = (x + distanceList[base + 2] * stride) / mRatioWidth;
                        float y2 = (y + distanceList[base + 3] * stride) / mRatioHeight;
                        listBbox.add(new Bbox(previousInference, x1, y1, x2, y2, score));
                    }
                }
            }
        }
        return listBbox;
    }

    /**
     * Greedy non-maximum suppression: walks {@code bboxes} in order and keeps a box only if its IoU with
     * every already-kept box is at most {@link #IOU_THRESHOLD}. Expects {@code bboxes} pre-sorted.
     */
    @NotNull
    private List<Bbox> nms(@NotNull List<Bbox> bboxes) {
        List<Bbox> selected = new ArrayList<>();
        for (Bbox boxA : bboxes) {
            boolean shouldSelect = true;
            // Too similar to an already-kept box? Then drop it.
            for (Bbox boxB : selected) {
                if (IOU(boxA, boxB) > IOU_THRESHOLD) {
                    shouldSelect = false;
                    break;
                }
            }
            if (shouldSelect) {
                selected.add(boxA);
            }
        }
        return selected;
    }

    /** Intersection-over-union of two boxes; 0 if either box has non-positive area. */
    private float IOU(@NotNull Bbox a, Bbox b ) {
        float areaA = (a.x2 - a.x1) * (a.y2 - a.y1);
        if (areaA <= 0) {
            return 0;
        }
        float areaB = (b.x2 - b.x1) * (b.y2 - b.y1);
        if (areaB <= 0) {
            return 0;
        }
        float intersectionMinX = Math.max(a.x1, b.x1);
        float intersectionMinY = Math.max(a.y1, b.y1);
        float intersectionMaxX = Math.min(a.x2, b.x2);
        float intersectionMaxY = Math.min(a.y2, b.y2);
        float intersectionArea = Math.max(intersectionMaxY - intersectionMinY, 0) *
                Math.max(intersectionMaxX - intersectionMinX, 0);
        return intersectionArea / (areaA + areaB - intersectionArea);
    }

    /**
     * Frees the input tensor and the network. Safe to call when the model failed to load, and safe to call
     * more than once.
     */
    @Override
    public void release() {
        // Release tensors before the network that created them (inputTensor comes from network.createFloatTensor).
        if (modelInputSource != null) {
            releaseTensors(modelInputSource);
            modelInputSource.clear();
        }
        if (network != null) {
            network.release();
            network = null;
        }
    }

}
