package com.securityandsafetythings.examples.aiapp.aicore.aiLibs.aiInference.detector.licensePlate;

import android.app.Application;
import android.content.Context;
import android.graphics.Bitmap;
import android.graphics.Matrix;
import android.graphics.PointF;
import android.util.Log;

import com.qualcomm.qti.snpe.FloatTensor;
import com.qualcomm.qti.snpe.NeuralNetwork;
import com.qualcomm.qti.snpe.SNPE;
import com.securityandsafetythings.examples.aiapp.aicore.aiLibs.InferenceResult;
import com.securityandsafetythings.examples.aiapp.aicore.aiLibs.aiInference.DLCInference;
import com.securityandsafetythings.examples.aiapp.aicore.aiLibs.aiInference.detector.Bbox;

import org.jetbrains.annotations.Contract;
import org.jetbrains.annotations.NotNull;
import org.opencv.android.Utils;
import org.opencv.core.Core;
import org.opencv.core.CvType;
import org.opencv.core.Mat;
import org.opencv.core.Point;
import org.opencv.core.Rect;
import org.opencv.core.Scalar;
import org.opencv.core.Size;
import org.opencv.imgproc.Imgproc;

import java.io.IOException;
import java.io.InputStream;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/**
 * License plate detector that runs an SCRFD-style DLC model through SNPE.
 *
 * <p>Two entry points are provided: {@link #runInference(InferenceResult)} scales the whole
 * process bitmap to the model input, while {@link #runInference(Bbox)} crops a square region
 * around a previously detected box and maps the detections back to the coordinates of the
 * source frame.
 *
 * <p>An instance reuses its input buffer, tensors and OpenCV Mats across calls and performs no
 * synchronization, so a single instance must not run inference from several threads at once.
 */
public class ScrfdLicensePlateDetector extends DLCInference {
    static String LOGTAG = ScrfdLicensePlateDetector.class.getSimpleName();

    /** Name of the model input tensor, taken from the built network. */
    protected String INPUT_LAYER;
    /** Reusable buffer holding the normalized image before it is written to the tensor. */
    protected float[] floatArrayInputValues;
    /** Reusable input tensor; registered in {@link #modelInputSource}. */
    protected FloatTensor inputTensor;
    /** Input map handed to {@code network.execute}. */
    protected Map<String, FloatTensor> modelInputSource;
    /** Output of the last {@code network.execute}; released and cleared by {@link #postProcess}. */
    protected Map<String, FloatTensor> modelOutputSource;

    /** Scale factors from source pixels to model-input pixels for the current frame. */
    private float mRatioWidth;
    private float mRatioHeight;
    /** Top-left corner of the current crop in source coordinates (0,0 for full-frame inference). */
    private int cropPointX = 0;
    private int cropPointY = 0;

    /** Strides of the three detection heads, ordered from finest to coarsest. */
    private final int[] STRIDE_LIST = new int[]{8, 16, 32};
    // NOTE(review): labels and DEFAULT_LABEL are not referenced anywhere in this class.
    private final String[] labels = new String[]{"moto", "truck", "bike", "car", "pedestrian", "bus", "ba_gac"};
    private final String DEFAULT_LABEL = "NONE";
    /** Number of anchors evaluated at every grid cell. */
    public final int NUM_ANCHOR_TYPES = 2;
    public static float CONFIDENCE_THRESHOLD = 0.4f;
    public static float IOU_THRESHOLD = 0.4f; // nms threshold

    /** Values per anchor in the box-distance output (left, top, right, bottom). */
    private static final int BOX_VALUES_PER_ANCHOR = 4;
    /** Key points decoded per detection. */
    private static final int NUM_KEY_POINTS = 5;
    /** Values per anchor in the key-point output (x and y for every key point). */
    private static final int KEY_POINT_VALUES_PER_ANCHOR = NUM_KEY_POINTS * 2;

    /** Cached objects for the pre-processing pipeline. */
    private final Size size;
    private final Scalar normalized_subtraction = new Scalar(127.5f, 127.5f, 127.5f);
    private final Scalar normalized_divide = new Scalar(128f, 128f, 128f);
    private Mat processMatRGBA;
    private Mat processMatRGB;
    private Mat processMat32F;

    /**
     * Builds the SNPE network from a raw resource and allocates the reusable buffers.
     *
     * <p>If the model stream cannot be read, the error is logged and the detector is left
     * without a network; {@code runInference} then fails with an {@link IllegalStateException}.
     *
     * @param context     context used to open the raw model resource
     * @param application application handed to the SNPE network builder
     * @param modelResId  raw resource id of the DLC model
     * @param runtimeMode preferred SNPE runtime; CPU is always added as the second choice
     */
    public ScrfdLicensePlateDetector(@NotNull Context context, Application application, int modelResId, NeuralNetwork.Runtime runtimeMode) {
        super(512, 512);
        size = new Size(IMAGE_WIDTH, IMAGE_HEIGHT);
        // try-with-resources closes the model stream whether or not the build succeeds.
        try (InputStream modelInputStream = context.getResources().openRawResource(modelResId)) {
            network = new SNPE.NeuralNetworkBuilder(application)
                    .setDebugEnabled(false)
                    .setRuntimeOrder(
                            runtimeMode,
                            NeuralNetwork.Runtime.CPU
                    )
                    .setModel(modelInputStream, modelInputStream.available())
                    .setOutputLayers("Sigmoid_207", "Reshape_210", "Reshape_213", "Sigmoid_182", "Reshape_185",
                            "Reshape_188", "Sigmoid_157", "Reshape_160", "Reshape_163")
                    //.setUseUserSuppliedBuffers()
                    .setCpuFallbackEnabled(true)
                    .setPerformanceProfile(NeuralNetwork.PerformanceProfile.BURST)
                    .build();
        } catch (IOException e) {
            Log.e(LOGTAG, "Failed to read or close the license plate model stream", e);
        }

        /* Set up the input tensor and the pre-processing Mats only when the model was built. */
        if (network != null) {
            floatArrayInputValues = new float[IMAGE_WIDTH * IMAGE_HEIGHT * 3];
            INPUT_LAYER = network.getInputTensorsNames().iterator().next();
            inputTensor = network.createFloatTensor(network.getInputTensorsShapes().get(INPUT_LAYER));
            modelInputSource = new HashMap<>();
            // Register the tensor right away so release() frees it even if no inference ever ran.
            modelInputSource.put(INPUT_LAYER, inputTensor);
            processMatRGBA = new Mat();
            processMatRGB = new Mat();
            processMat32F = new Mat();
        }
    }

    /**
     * Detects license plates on the whole process bitmap of {@code inputInference}.
     *
     * @param inputInference inference result that supplies the bitmap to analyse
     * @return the detections after non-maximum suppression
     * @throws IllegalStateException if the network is not available
     */
    @Override
    public LPDResult runInference(InferenceResult inputInference) {
        requireNetwork();
        final Bitmap inputBitmap = inputInference.getProcessBitmap();

        final long startPre = System.currentTimeMillis();
        preProcess(inputBitmap);
        final long pre = System.currentTimeMillis() - startPre;

        return executeAndDecode(inputInference, pre);
    }

    /**
     * Detects license plates inside a square crop around {@code bbox}.
     * Called from the vehicle detection pipeline.
     *
     * @param bbox box whose region should be searched
     * @return the detections after non-maximum suppression, in the coordinates of the Mat the
     *         crop was taken from
     * @throws IllegalStateException if the network is not available
     */
    public LPDResult runInference(Bbox bbox) {
        requireNetwork();

        final long startPre = System.currentTimeMillis();
        preProcess((InferenceResult) bbox);
        final long pre = System.currentTimeMillis() - startPre;

        return executeAndDecode(bbox, pre);
    }

    /** Runs the network on the prepared input, decodes the output and logs the stage timings. */
    private LPDResult executeAndDecode(InferenceResult source, long pre) {
        final long startDetect = System.currentTimeMillis();
        modelOutputSource = network.execute(modelInputSource);
        final long runtime = System.currentTimeMillis() - startDetect;

        final long startDecode = System.currentTimeMillis();
        final LPDResult result = postProcess(source);
        final long decodeRuntime = System.currentTimeMillis() - startDecode;

        Log.d(LOGTAG + "_checkRuntime", "check LPD runtime: pre = "+ pre +"ms" + " | runtime = " + runtime + "ms | decode = " + decodeRuntime+"ms");
        return result;
    }

    /** Fails fast with a readable message instead of a NullPointerException deep in the pipeline. */
    private void requireNetwork() {
        if (network == null) {
            throw new IllegalStateException(
                    "License plate network is not available (model failed to load or was released)");
        }
    }

    /**
     * Releases the network, every tensor still held and the cached Mats.
     * Safe to call when the model never loaded and safe to call more than once.
     */
    @Override
    public void release() {
        if (network != null) {
            network.release();
            network = null;
        }
        if (modelInputSource != null) {
            releaseTensors(modelInputSource);
            // Forget the released tensors so a second release() does not free them twice.
            modelInputSource.clear();
        }
        inputTensor = null;
        if (modelOutputSource != null) {
            // Only non-null when an inference was interrupted between execute and postProcess.
            releaseTensors(modelOutputSource);
            modelOutputSource = null;
        }
        releaseMat(processMatRGBA);
        releaseMat(processMatRGB);
        releaseMat(processMat32F);
    }

    /** Releases the data of {@code mat} when it was allocated. */
    private static void releaseMat(Mat mat) {
        if (mat != null) {
            mat.release();
        }
    }

    /** No model configuration is required for this detector. */
    @Override
    protected void setModelConfig() {

    }

    /**
     * Scales the whole bitmap to the model input size, normalizes it and writes it to the
     * input tensor.
     *
     * @param inputBitmap bitmap to analyse
     */
    @Override
    protected void preProcess(Bitmap inputBitmap) {
        // Full-frame inference has no crop offset; clear any offset left by a Bbox inference.
        cropPointX = 0;
        cropPointY = 0;

        mRatioWidth = ((float) IMAGE_WIDTH) / ((float) inputBitmap.getWidth());
        mRatioHeight = ((float) IMAGE_HEIGHT) / ((float) inputBitmap.getHeight());

        final Matrix scalingMatrix = new Matrix();
        scalingMatrix.postScale(mRatioWidth, mRatioHeight);

        Bitmap processBitmap = Bitmap.createBitmap(inputBitmap,
                0, 0,
                inputBitmap.getWidth(), inputBitmap.getHeight(),
                scalingMatrix, false);

        // Convert the SCALED bitmap (not the input) so the resize above is kept.
        // Utils.bitmapToMat only reads the pixels, so a mutable copy is not required.
        if (processBitmap.getConfig() != Bitmap.Config.ARGB_8888) {
            processBitmap = processBitmap.copy(Bitmap.Config.ARGB_8888, false);
        }

        Utils.bitmapToMat(processBitmap, processMatRGBA);
        // Named constant instead of the raw code 3: in OpenCV 3 is COLOR_RGBA2BGR, which swaps
        // the red and blue channels relative to the Bbox path.
        // NOTE(review): confirm the model expects RGB channel order.
        Imgproc.cvtColor(processMatRGBA, processMatRGB, Imgproc.COLOR_RGBA2RGB);
        normalizeAndLoad(processMatRGB);
    }

    /** Converts {@code rgb} to float, applies (x - 127.5) / 128 and loads the input tensor. */
    private void normalizeAndLoad(Mat rgb) {
        rgb.convertTo(processMat32F, CvType.CV_32F);
        Core.subtract(processMat32F, normalized_subtraction, processMat32F);
        Core.divide(processMat32F, normalized_divide, processMat32F);

        /* Put the pre-processed image into the tensor that is passed to the model. */
        prepairModelTensorsInput(processMat32F);
    }

    /**
     * Copies the pre-processed image into the input tensor and registers it as model input.
     *
     * @param preprocessedMat normalized float image at model input size
     */
    protected void prepairModelTensorsInput(@NotNull Mat preprocessedMat) {
        preprocessedMat.get(0, 0, floatArrayInputValues);
        inputTensor.write(floatArrayInputValues, 0, floatArrayInputValues.length);
        modelInputSource.put(INPUT_LAYER, inputTensor);
    }

    /**
     * Crops a square region around the given box, resizes it to the model input size,
     * normalizes it and writes it to the input tensor.
     *
     * <p>The crop side is the longer side of the box, limited to the shorter side of the frame
     * so the square always fits; the crop origin is shifted back inside the frame when needed.
     *
     * @param inferenceResult must be a {@link Bbox}
     * @throws ClassCastException       if {@code inferenceResult} is not a {@link Bbox}
     * @throws IllegalArgumentException if the box carries no image data
     */
    @Override
    protected void preProcess(InferenceResult inferenceResult) {
        final Bbox bbox = (Bbox) inferenceResult;
        // Presumably an RGB image, judging by the accessor name — verify against Bbox.
        final Mat originMatRGB = bbox.getProcessRBBMat();
        if (originMatRGB == null || originMatRGB.empty()) {
            throw new IllegalArgumentException("Bbox has no image data to crop from");
        }

        final float widthCrop = bbox.x2 - bbox.x1;
        final float heightCrop = bbox.y2 - bbox.y1;

        // A square crop must fit in both dimensions, and must not be empty for degenerate boxes.
        final int maxSide = Math.min(originMatRGB.rows(), originMatRGB.cols());
        final int cropSize = clamp((int) Math.max(widthCrop, heightCrop), 1, maxSide);

        cropPointX = clamp((int) bbox.x1, 0, originMatRGB.cols() - cropSize);
        cropPointY = clamp((int) bbox.y1, 0, originMatRGB.rows() - cropSize);

        final Rect rectCrop = new Rect(cropPointX, cropPointY, cropSize, cropSize);

        Log.d(LOGTAG, "check crop size = " + cropPointX + " " + cropPointY + " " + cropSize);

        // The crop is square and the model input is square, so one ratio serves both axes.
        mRatioWidth = IMAGE_WIDTH / (float) cropSize;
        mRatioHeight = mRatioWidth;

        final Mat cropMatRGB = new Mat(originMatRGB, rectCrop);
        try {
            Imgproc.resize(cropMatRGB, processMatRGB, size);
        } finally {
            // The sub-Mat is only a view on the source image; drop it as soon as it is copied.
            cropMatRGB.release();
        }
        normalizeAndLoad(processMatRGB);
    }

    /** Limits {@code value} to the range [{@code min}, {@code max}]; {@code min} wins on conflict. */
    private static int clamp(int value, int min, int max) {
        return Math.max(min, Math.min(value, max));
    }

    /**
     * Copies the raw model outputs into float arrays grouped by stride, releases the output
     * tensors and decodes the detections.
     *
     * <p>The switch keys are the names under which the network reports its outputs.
     * NOTE(review): they differ from the layer names passed to {@code setOutputLayers};
     * presumably they are the tensor names produced by those layers — verify against the DLC.
     *
     * @param inferenceResult inference the detections are attached to
     * @return the decoded detections after non-maximum suppression
     * @throws IllegalStateException if no model output is pending
     */
    @Override
    protected LPDResult postProcess(InferenceResult inferenceResult) {
        if (modelOutputSource == null) {
            throw new IllegalStateException("postProcess called without a pending model output");
        }

        final Map<Integer, float[]> scoreList = new HashMap<>();
        final Map<Integer, float[]> vectorBoxDistanceList = new HashMap<>();
        final Map<Integer, float[]> keyPointList = new HashMap<>();

        try {
            for (Map.Entry<String, FloatTensor> output : modelOutputSource.entrySet()) {
                final FloatTensor outputTensor = output.getValue();
                switch (output.getKey()) {
                    // Stride 32 head.
                    case "497":
                        scoreList.put(STRIDE_LIST[2], readTensor(outputTensor));
                        break;
                    case "500":
                        vectorBoxDistanceList.put(STRIDE_LIST[2], readTensor(outputTensor));
                        break;
                    case "503":
                        keyPointList.put(STRIDE_LIST[2], readTensor(outputTensor));
                        break;

                    // Stride 16 head.
                    case "472":
                        scoreList.put(STRIDE_LIST[1], readTensor(outputTensor));
                        break;
                    case "475":
                        vectorBoxDistanceList.put(STRIDE_LIST[1], readTensor(outputTensor));
                        break;
                    case "478":
                        keyPointList.put(STRIDE_LIST[1], readTensor(outputTensor));
                        break;

                    // Stride 8 head.
                    case "447":
                        scoreList.put(STRIDE_LIST[0], readTensor(outputTensor));
                        break;
                    case "450":
                        vectorBoxDistanceList.put(STRIDE_LIST[0], readTensor(outputTensor));
                        break;
                    case "453":
                        keyPointList.put(STRIDE_LIST[0], readTensor(outputTensor));
                        break;

                    default:
                        // Outputs with other names are ignored.
                        break;
                }
            }
        } finally {
            // Every value has been copied out, so the output tensors can be freed now;
            // otherwise each inference would leave one set of tensors behind.
            releaseTensors(modelOutputSource);
            modelOutputSource = null;
        }
        return decodeScrfdLicensePlate(inferenceResult, scoreList, vectorBoxDistanceList, keyPointList);
    }

    /** Copies the whole content of {@code tensor} into a new float array. */
    private static float[] readTensor(FloatTensor tensor) {
        final float[] values = new float[tensor.getSize()];
        tensor.read(values, 0, values.length);
        return values;
    }

    /**
     * Decodes the boxes of every stride, orders them and applies non-maximum suppression.
     * A stride whose outputs are missing is skipped with a warning.
     *
     * @param previousInference     inference the detections are attached to
     * @param scoreList             per-stride confidence values
     * @param vectorBoxDistanceList per-stride box distances
     * @param keyPointList          per-stride key-point offsets
     * @return a new result holding the surviving detections
     */
    @NotNull
    @Contract("_, _, _, _ -> new")
    private LPDResult decodeScrfdLicensePlate(
            InferenceResult previousInference,
            Map<Integer, float[]> scoreList,
            Map<Integer, float[]> vectorBoxDistanceList,
            Map<Integer, float[]> keyPointList
    ) {
        final List<Bbox> listBboxes = new ArrayList<>();

        final long startDecodeSCRFD = System.currentTimeMillis();
        for (int stride : STRIDE_LIST) {
            final float[] distances = vectorBoxDistanceList.get(stride);
            final float[] scores = scoreList.get(stride);
            final float[] keyPoints = keyPointList.get(stride);
            if (distances == null || scores == null || keyPoints == null) {
                Log.w(LOGTAG, "Missing model output for stride " + stride + "; skipping this scale");
                continue;
            }
            listBboxes.addAll(buildBbox(previousInference, distances, scores, keyPoints, stride));
        }
        // Greedy NMS needs ONE ordering over all strides; sorting each stride separately would
        // let a box from an earlier stride suppress a better box from a later one.
        // Presumably Bbox's natural order puts the most confident box first — verify in Bbox.
        Collections.sort(listBboxes);
        final long decodeTime = System.currentTimeMillis() - startDecodeSCRFD;
        Log.d(LOGTAG, "check decodeTime scrfd:" + decodeTime);

        return new LPDResult(previousInference, nms(listBboxes));
    }

    /**
     * Greedy non-maximum suppression: walks the boxes in the given order and keeps a box only
     * when its IoU with every box kept so far is at most {@link #IOU_THRESHOLD}.
     *
     * @param bboxes candidate boxes, already ordered by priority
     * @return the boxes that were kept, in the same order
     */
    @NotNull
    private List<Bbox> nms(@NotNull List<Bbox> bboxes) {
        final List<Bbox> selected = new ArrayList<>();
        for (Bbox candidate : bboxes) {
            boolean overlapsKeptBox = false;
            for (Bbox kept : selected) {
                if (IOU(candidate, kept) > IOU_THRESHOLD) {
                    overlapsKeptBox = true;
                    break;
                }
            }
            if (!overlapsKeptBox) {
                selected.add(candidate);
            }
        }
        return selected;
    }

    /**
     * Decodes the boxes of one stride whose score exceeds {@link #CONFIDENCE_THRESHOLD}.
     *
     * <p>The outputs are read as a row-major grid over the model input with
     * {@link #NUM_ANCHOR_TYPES} anchors per cell. Distances and key-point offsets are given in
     * units of the stride relative to the cell origin; results are mapped back to source
     * coordinates with the current scale ratios and crop offset.
     *
     * @param previousInference inference the boxes are attached to
     * @param distanceList      four distances per anchor
     * @param scoreList         one score per anchor
     * @param keyPoints         x and y offsets of every key point, per anchor
     * @param stride            stride of the head that produced the arrays
     * @return the decoded boxes in grid order; empty when the arrays are too short for the grid
     */
    @NotNull
    private List<Bbox> buildBbox(
            InferenceResult previousInference,
            float[] distanceList,
            float[] scoreList,
            float[] keyPoints,
            int stride
    ) {
        final List<Bbox> listBbox = new ArrayList<>();

        // Number of cells visited by the loops below (ceiling division).
        final int gridRows = (IMAGE_HEIGHT + stride - 1) / stride;
        final int gridCols = (IMAGE_WIDTH + stride - 1) / stride;
        final int anchorCount = gridRows * gridCols * NUM_ANCHOR_TYPES;
        if (scoreList.length < anchorCount
                || distanceList.length < anchorCount * BOX_VALUES_PER_ANCHOR
                || keyPoints.length < anchorCount * KEY_POINT_VALUES_PER_ANCHOR) {
            Log.w(LOGTAG, "Unexpected output size for stride " + stride + ": scores=" + scoreList.length
                    + " distances=" + distanceList.length + " keyPoints=" + keyPoints.length
                    + " expected anchors=" + anchorCount);
            return listBbox;
        }

        int anchorIndex = 0;
        for (int y = 0; y < IMAGE_HEIGHT; y += stride) {
            for (int x = 0; x < IMAGE_WIDTH; x += stride) {
                for (int k = 0; k < NUM_ANCHOR_TYPES; ++k, ++anchorIndex) {
                    final float score = scoreList[anchorIndex];
                    if (score > CONFIDENCE_THRESHOLD) {
                        final int boxBase = anchorIndex * BOX_VALUES_PER_ANCHOR;
                        final float x1 = (x - distanceList[boxBase] * stride) / mRatioWidth + cropPointX;
                        final float y1 = (y - distanceList[boxBase + 1] * stride) / mRatioHeight + cropPointY;
                        final float x2 = (x + distanceList[boxBase + 2] * stride) / mRatioWidth + cropPointX;
                        final float y2 = (y + distanceList[boxBase + 3] * stride) / mRatioHeight + cropPointY;

                        final int keyPointBase = anchorIndex * KEY_POINT_VALUES_PER_ANCHOR;
                        final Point[] resultKeyPoints = new Point[NUM_KEY_POINTS];
                        for (int j = 0; j < NUM_KEY_POINTS; ++j) {
                            final float kp_x = (x + keyPoints[keyPointBase + j * 2] * stride) / mRatioWidth + cropPointX;
                            final float kp_y = (y + keyPoints[keyPointBase + j * 2 + 1] * stride) / mRatioHeight + cropPointY;
                            resultKeyPoints[j] = new Point(kp_x, kp_y);
                        }
                        listBbox.add(new Bbox(previousInference, x1, y1, x2, y2, score, resultKeyPoints));
                    }
                }
            }
        }
        return listBbox;
    }

    /**
     * Intersection over union of two boxes.
     *
     * @param a first box
     * @param b second box
     * @return the IoU in [0, 1]; 0 when either box has no positive area
     */
    private float IOU(@NotNull Bbox a, Bbox b) {
        final float areaA = (a.x2 - a.x1) * (a.y2 - a.y1);
        if (areaA <= 0) {
            return 0;
        }
        final float areaB = (b.x2 - b.x1) * (b.y2 - b.y1);
        if (areaB <= 0) {
            return 0;
        }
        final float overlapWidth = Math.max(Math.min(a.x2, b.x2) - Math.max(a.x1, b.x1), 0);
        final float overlapHeight = Math.max(Math.min(a.y2, b.y2) - Math.max(a.y1, b.y1), 0);
        final float intersectionArea = overlapHeight * overlapWidth;
        return intersectionArea / (areaA + areaB - intersectionArea);
    }

}
