package com.securityandsafetythings.examples.aiapp.aicore.aiLibs.aiInference.detector.vehicle;

import static com.google.common.base.Ascii.FF;

import android.app.Application;
import android.content.Context;
import android.graphics.Bitmap;
import android.graphics.Canvas;
import android.graphics.Matrix;
import android.util.Log;
import android.util.Pair;

import androidx.annotation.NonNull;

import com.qualcomm.qti.snpe.FloatTensor;
import com.qualcomm.qti.snpe.NeuralNetwork;
import com.qualcomm.qti.snpe.SNPE;
import com.securityandsafetythings.examples.aiapp.aicore.aiLibs.InferenceResult;
import com.securityandsafetythings.examples.aiapp.aicore.aiLibs.aiInference.DLCInference;
import com.securityandsafetythings.examples.aiapp.aicore.aiLibs.aiInference.detector.Bbox;
import com.securityandsafetythings.examples.aiapp.utilities.FileUtils;
import com.securityandsafetythings.examples.aiapp.utilities.ImageUtils;

import org.jetbrains.annotations.Contract;
import org.jetbrains.annotations.NotNull;
import org.opencv.android.Utils;
import org.opencv.core.Core;
import org.opencv.core.CvType;
import org.opencv.core.Mat;
import org.opencv.core.Scalar;
import org.opencv.imgproc.Imgproc;

import java.io.IOException;
import java.io.InputStream;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/**
 * Vehicle / road-user detector that runs an SCRFD-style, anchor-point model through SNPE.
 *
 * <p>Pipeline:
 * <ol>
 *   <li>{@link #preProcess(Bitmap)} scales the frame to fit {@code IMAGE_WIDTH x IMAGE_HEIGHT}
 *   (keeping aspect ratio) and pads the right/bottom with white. It then normalises every channel
 *   to {@code (p - 127.5) / 128} and writes the result into the input tensor.</li>
 *   <li>The network produces, for every stride in {@link #STRIDE_LIST}, a per-anchor class-score
 *   tensor and a per-anchor box-distance tensor.</li>
 *   <li>{@link #postProcess(InferenceResult)} decodes those into {@link Bbox}es grouped by label,
 *   sorts each group and applies per-label NMS.</li>
 * </ol>
 *
 * <p>NOTE(review): pre-processing buffers, scale ratios and the last output map are instance
 * fields mutated without synchronization, so instances should be used from one thread at a time.
 */
public class ScrfdVehicleDetector extends DLCInference {
    static String LOGTAG = ScrfdVehicleDetector.class.getSimpleName();

    /**
     * Output tensor names holding per-anchor class scores, index-aligned with {@link #STRIDE_LIST}.
     * NOTE(review): these are graph tensor names and differ from the layer names passed to
     * {@code setOutputLayers}; confirm them against the DLC whenever the model is re-exported.
     */
    private static final List<String> SCORE_OUTPUT_NAMES =
            Arrays.asList("479", "495", "511", "527", "543");

    /**
     * Output tensor names holding per-anchor box distances (left, top, right, bottom, in stride
     * units), index-aligned with {@link #STRIDE_LIST}.
     */
    private static final List<String> BOX_OUTPUT_NAMES =
            Arrays.asList("482", "498", "514", "530", "546");

    /** variable for model I/O */
    protected String INPUT_LAYER;
    protected float[] floatArrayInputValues;
    protected FloatTensor inputTensor;
    protected Map<String, FloatTensor> modelInputSource;
    protected Map<String, FloatTensor> modelOutputSource;

    /** private model params implement here*/
    // Scale factors (model pixels per source pixel) from the last resizeKeepRatio() call.
    private float mRatioWidth;
    private float mRatioHeight;
    private final int[] STRIDE_LIST = new int[]{8, 16, 32, 64, 128};
    private final String[] labels = new String[]{"moto", "truck", "bike", "car", "pedestrian", "bus", "ba_gac"};
    private final String DEFAULT_LABEL = "NONE";
    public final int NUM_ANCHOR_TYPES = 3; //horizon, vertical, square
    public static float CONFIDENCE_THRESHOLD = 0.4f;
    public static float IOU_THRESHOLD = 0.4f; // nms threshold


    /** variable for cache-pre-process process*/
    private Scalar normalized_subtraction = new Scalar(127.5f, 127.5f, 127.5f);
    private Scalar normalized_divide = new Scalar(128f, 128f, 128f);
    private Mat processMatRGBA;
    private Mat processMatRGB;
    private Mat processMat32F;
    private float aspectRatio = 1;
    private int borderType = Core.BORDER_CONSTANT;
    Scalar value = new Scalar(255, 255, 255);

    /**
     * Loads the DLC model from a raw resource and prepares reusable I/O buffers.
     *
     * <p>If the model cannot be read, the error is logged and {@code network} stays null.
     * {@link #runInference(InferenceResult)} then throws {@link IllegalStateException}.
     *
     * @param context     used to open the raw model resource
     * @param application passed to the SNPE network builder
     * @param modelResId  raw resource id of the .dlc model
     * @param runtimeMode preferred SNPE runtime; CPU is configured as the fallback
     */
    public ScrfdVehicleDetector(@NotNull Context context, Application application, int modelResId, NeuralNetwork.Runtime runtimeMode) {
        super(896, 512);
        this.context = context;
        this.application = application;
        aspectRatio = IMAGE_WIDTH / (float) IMAGE_HEIGHT;

        // try-with-resources so the raw-resource stream is always closed after the model is built.
        try (InputStream modelInputStream = context.getResources().openRawResource(modelResId)) {
            network = new SNPE.NeuralNetworkBuilder(application)
                    .setDebugEnabled(false)
                    .setRuntimeOrder(
                            runtimeMode,
                            NeuralNetwork.Runtime.CPU
                    )
                    .setModel(modelInputStream, modelInputStream.available())
                    .setOutputLayers("Sigmoid_240", "Reshape_243", "Sigmoid_224", "Reshape_227", "Sigmoid_208",
                            "Reshape_211", "Sigmoid_192", "Reshape_195", "Sigmoid_176", "Reshape_179")
                    .setCpuFallbackEnabled(true)
                    .setPerformanceProfile(NeuralNetwork.PerformanceProfile.BURST)
                    .build();
        } catch (IOException e) {
            Log.e(LOGTAG, "Failed to load vehicle detection model (resId=" + modelResId + ")", e);
        }

        /* setup input/output layer and pre-pair Mat for pre-process if model build success */
        if (network != null) {
            floatArrayInputValues = new float[IMAGE_WIDTH * IMAGE_HEIGHT * 3];
            INPUT_LAYER = network.getInputTensorsNames().iterator().next();
            inputTensor = network.createFloatTensor(network.getInputTensorsShapes().get(INPUT_LAYER));
            modelInputSource = new HashMap<>();
            // Register the reused input tensor immediately so release() frees it even if
            // no inference was ever run.
            modelInputSource.put(INPUT_LAYER, inputTensor);
            processMatRGBA = new Mat();
            processMatRGB = new Mat();
            processMat32F = new Mat();
        }
    }

    /**
     * Runs pre-processing, network execution and decoding on one frame.
     *
     * @param inputInference carries the frame to process (via {@code getProcessBitmap()})
     * @return detections grouped by label after NMS
     * @throws IllegalStateException if the model failed to load in the constructor
     */
    @Override
    public VDResult runInference(@NonNull InferenceResult inputInference) {
        if (network == null) {
            throw new IllegalStateException("Vehicle detection model is not loaded");
        }
        Bitmap inputBitmap = inputInference.getProcessBitmap();

        long startPre = System.currentTimeMillis();
        preProcess(inputBitmap);
        long pre = System.currentTimeMillis() - startPre;

        // execute() hands back a new output map each call. Free the previous run's native
        // tensors before replacing the reference; the SNPE Java samples release outputs after
        // every run.
        releaseOutputTensors();

        long startDetect = System.currentTimeMillis();
        modelOutputSource = network.execute(modelInputSource);
        long runtime = System.currentTimeMillis() - startDetect;

        long startDecode = System.currentTimeMillis();
        VDResult vdResult = postProcess(inputInference);
        long decodeRuntime = System.currentTimeMillis() - startDecode;

        Log.d(LOGTAG + "_checkRuntime", "check vehicle runtime: pre = " + pre + "ms" + " | runtime = " + runtime + "ms | decode = " + decodeRuntime + "ms");
        return vdResult;
    }

    /**
     * Releases the network, the input/output tensors and the OpenCV work buffers.
     * Safe to call when the model failed to load.
     */
    @Override
    public void release() {
        if (network != null) {
            network.release();
        }
        if (modelInputSource != null) {
            releaseTensors(modelInputSource);
        }
        releaseOutputTensors();
        for (Mat mat : new Mat[]{processMatRGBA, processMatRGB, processMat32F}) {
            if (mat != null) {
                mat.release();
            }
        }
    }

    /** Frees the native output tensors of the last {@code execute()} call, if any. */
    private void releaseOutputTensors() {
        if (modelOutputSource == null) {
            return;
        }
        for (FloatTensor tensor : modelOutputSource.values()) {
            tensor.release();
        }
        modelOutputSource = null;
    }

    /** No extra configuration is needed for this model. */
    @Override
    protected void setModelConfig() {

    }

    /**
     * Fits, pads and normalises {@code inputBitmap}, then writes it into the input tensor.
     *
     * <p>The image is scaled with its aspect ratio kept and anchored at the top-left. The
     * right and/or bottom is filled with white, so decoded box coordinates need no offset,
     * only division by the scale ratio.
     */
    @Override
    protected void preProcess(Bitmap inputBitmap) {
        Bitmap processBitmap = resizeKeepRatio(inputBitmap);

        Utils.bitmapToMat(processBitmap, processMatRGBA);
        // Pixels are now copied into the Mat; free the per-frame scaled copy
        // (never the caller's bitmap).
        if (processBitmap != inputBitmap) {
            processBitmap.recycle();
        }

        // bitmapToMat yields RGBA. The previous raw code 3 is COLOR_RGBA2BGR, which silently
        // swapped R/B despite the RGB intent.
        // NOTE(review): if this model was trained on BGR input, switch back to COLOR_RGBA2BGR.
        Imgproc.cvtColor(processMatRGBA, processMatRGB, Imgproc.COLOR_RGBA2RGB);

        // Pad each side independently: rounding in the resize may leave both dimensions short.
        final int right = Math.max(0, IMAGE_WIDTH - processMatRGB.cols());
        final int bottom = Math.max(0, IMAGE_HEIGHT - processMatRGB.rows());
        if (right > 0 || bottom > 0) {
            Log.d(LOGTAG, "check preprocess add padding right = " + right + " | bot = " + bottom);
            Core.copyMakeBorder(processMatRGB, processMatRGB, 0, bottom, 0, right, borderType, value);
        }

        processMatRGB.convertTo(processMat32F, CvType.CV_32F);
        Core.subtract(processMat32F, normalized_subtraction, processMat32F);
        Core.divide(processMat32F, normalized_divide, processMat32F);

        /* Put pre-processed image to tensor and pre-pair run inference from model */
        prepairModelTensorsInput(processMat32F);
    }

    /**
     * Copies the normalised float Mat into the reusable input tensor.
     *
     * @param preprocessedMat float Mat produced by {@link #preProcess(Bitmap)}
     */
    protected void prepairModelTensorsInput(@NotNull Mat preprocessedMat) {
        preprocessedMat.get(0, 0, floatArrayInputValues);
        inputTensor.write(floatArrayInputValues, 0, floatArrayInputValues.length);
        modelInputSource.put(INPUT_LAYER, inputTensor);
    }

    /** Unused: this detector pre-processes from a {@link Bitmap} instead. */
    @Override
    protected void preProcess(InferenceResult inferenceResult) {

    }

    /**
     * Reads the score and box-distance outputs of the last run and decodes them.
     * Output tensors whose names are not in the lookup tables are ignored.
     */
    @Override
    protected VDResult postProcess(InferenceResult inferenceResult) {
        Map<Integer, float[]> scoreList = new HashMap<>();
        Map<Integer, float[]> vectorBoxDistanceList = new HashMap<>();

        for (Map.Entry<String, FloatTensor> output : modelOutputSource.entrySet()) {
            final String tensorName = output.getKey();

            final int scoreIndex = SCORE_OUTPUT_NAMES.indexOf(tensorName);
            if (scoreIndex >= 0) {
                scoreList.put(STRIDE_LIST[scoreIndex], readTensor(output.getValue()));
                continue;
            }

            final int boxIndex = BOX_OUTPUT_NAMES.indexOf(tensorName);
            if (boxIndex >= 0) {
                vectorBoxDistanceList.put(STRIDE_LIST[boxIndex], readTensor(output.getValue()));
            }
        }
        return decodeScrfdVehicle(inferenceResult, scoreList, vectorBoxDistanceList);
    }

    /** Copies the full contents of {@code tensor} into a new float array. */
    private static float[] readTensor(FloatTensor tensor) {
        float[] values = new float[tensor.getSize()];
        tensor.read(values, 0, values.length);
        return values;
    }

    /**
     * Builds per-label boxes for every stride, sorts them and applies NMS.
     *
     * @throws IllegalStateException if the scores or distances for any stride are missing
     *                               (e.g. output tensor names no longer match the model)
     */
    @NonNull
    @Contract("_, _, _ -> new")
    private VDResult decodeScrfdVehicle(InferenceResult previousInference, Map<Integer, float[]> scoreList, Map<Integer, float[]> vectorBoxDistanceList) {
        Map<String, List<Bbox>> rawMapResults = new HashMap<>();

        for (String label : labels) {
            rawMapResults.put(label, new ArrayList<>());
        }

        long startDecodeSCRFD = System.currentTimeMillis();
        for (int stride : STRIDE_LIST) {
            float[] scores = scoreList.get(stride);
            float[] distances = vectorBoxDistanceList.get(stride);
            if (scores == null || distances == null) {
                throw new IllegalStateException("Missing SCRFD output for stride " + stride
                        + " (scores present: " + (scores != null)
                        + ", boxes present: " + (distances != null) + ")");
            }
            buildBbox(previousInference, rawMapResults, distances, scores, stride);
        }

        //Sort rawMapResults after build bbox
        sortMapBuiltBbox(rawMapResults);

        long decodeTime = System.currentTimeMillis() - startDecodeSCRFD;
        Log.d(LOGTAG, "check decodeTime scrfd:" + decodeTime);

        //NMS
        return new VDResult(previousInference, nms(rawMapResults));
    }

    /**
     * Decodes one stride's outputs into boxes, appended to {@code outputMap} under their best label.
     *
     * <p>Anchor points lie on a {@code stride}-spaced grid with {@link #NUM_ANCHOR_TYPES} anchors
     * per point. Anchor {@code k} uses {@code scoreList[k * labels.length + c]} for class {@code c}
     * and {@code distanceList[4k .. 4k+3]} for the left/top/right/bottom distances in stride units.
     * Coordinates are mapped back to the source image by dividing by the resize ratios.
     */
    private void buildBbox(final InferenceResult previousInference,
                           final Map<String, List<Bbox>> outputMap,
                           float[] distanceList, float[] scoreList,
                           int stride
    ) {
        final int numClasses = labels.length;
        int currInd = -1;

        for (int y = 0; y < IMAGE_HEIGHT; y += stride) {
            for (int x = 0; x < IMAGE_WIDTH; x += stride) {
                for (int i = 0; i < NUM_ANCHOR_TYPES; ++i) {
                    ++currInd;

                    // Best class for THIS anchor only. The previous code kept the running
                    // maximum across anchors, dropping any detection weaker than an earlier one.
                    String bestLabel = null;
                    float bestScore = 0.0f;
                    final int scoreOffset = currInd * numClasses;
                    for (int j = 0; j < numClasses; ++j) {
                        final float score = scoreList[scoreOffset + j];
                        if (score >= CONFIDENCE_THRESHOLD && score > bestScore) {
                            bestLabel = labels[j];
                            bestScore = score;
                        }
                    }

                    if (bestLabel == null) {
                        continue;
                    }

                    final int boxOffset = currInd * 4;
                    float x1 = (x - distanceList[boxOffset] * stride) / mRatioWidth;
                    float y1 = (y - distanceList[boxOffset + 1] * stride) / mRatioHeight;
                    float x2 = (x + distanceList[boxOffset + 2] * stride) / mRatioWidth;
                    float y2 = (y + distanceList[boxOffset + 3] * stride) / mRatioHeight;

                    outputMap.get(bestLabel).add(new Bbox(previousInference, x1, y1, x2, y2, bestScore, bestLabel));
                }
            }
        }
    }

    /** Sorts every label's box list in place by {@link Bbox}'s natural ordering. */
    private void sortMapBuiltBbox(@NonNull final Map<String, List<Bbox>> outputMap) {
        for (Map.Entry<String, List<Bbox>> entry : outputMap.entrySet()) {
            Collections.sort(entry.getValue());
        }
    }

    /**
     * Greedy per-label non-maximum suppression. Each list in {@code inputMap} is replaced by
     * the boxes kept.
     *
     * <p>A box is kept unless its IoU with an already-kept box exceeds {@link #IOU_THRESHOLD}.
     * NOTE(review): this assumes the lists are already sorted best-score-first by Bbox's
     * natural ordering; confirm Bbox.compareTo.
     *
     * @return the same map instance, mutated
     */
    @NotNull
    private Map<String, List<Bbox>> nms(@NotNull final Map<String, List<Bbox>> inputMap) {
        for (Map.Entry<String, List<Bbox>> bboxesEntry : inputMap.entrySet()) {
            List<Bbox> selected = new ArrayList<>();

            for (Bbox boxA : bboxesEntry.getValue()) {
                boolean shouldSelect = true;
                // Drop the box if it overlaps any already-selected box by more than the threshold.
                for (Bbox boxB : selected) {
                    if (IOU(boxA, boxB) > IOU_THRESHOLD) {
                        shouldSelect = false;
                        break;
                    }
                }
                if (shouldSelect) {
                    selected.add(boxA);
                }
            }

            bboxesEntry.setValue(selected);
        }
        return inputMap;
    }

    /**
     * Scales {@code inputBitmap} uniformly so it fits inside {@code IMAGE_WIDTH x IMAGE_HEIGHT}.
     * Records the scale in {@code mRatioWidth}/{@code mRatioHeight} for box decoding.
     *
     * @return the scaled bitmap, guaranteed ARGB_8888 and mutable. This may be a new bitmap
     *         or, when no scaling is needed and the input already qualifies, the input itself.
     */
    public Bitmap resizeKeepRatio(Bitmap inputBitmap) {
        final int srcWidth = inputBitmap.getWidth();
        final int srcHeight = inputBitmap.getHeight();
        float currentRatio = (float) srcWidth / (float) srcHeight;

        final float scale;
        if (currentRatio < aspectRatio) {
            // Taller than the model input: height is the limiting side.
            Log.d(LOGTAG, "resizeKeepRatio:  currentRatio = " + currentRatio + " | aspect ratio = " + aspectRatio);
            scale = ((float) IMAGE_HEIGHT) / ((float) srcHeight);
        } else {
            // Wider than (or equal to) the model input: width is the limiting side.
            scale = ((float) IMAGE_WIDTH) / ((float) srcWidth);
        }
        mRatioWidth = scale;
        mRatioHeight = scale;

        final Matrix scalingMatrix = new Matrix();
        scalingMatrix.postScale(scale, scale);
        Bitmap processBitmap = Bitmap.createBitmap(inputBitmap, 0, 0, srcWidth, srcHeight, scalingMatrix, false);

        // Convert the SCALED bitmap. The previous code copied inputBitmap here, which silently
        // discarded the resize whenever the input was immutable or not ARGB_8888.
        if (processBitmap.getConfig() != Bitmap.Config.ARGB_8888 || !processBitmap.isMutable()) {
            Bitmap converted = processBitmap.copy(Bitmap.Config.ARGB_8888, true);
            if (processBitmap != inputBitmap) {
                processBitmap.recycle();
            }
            processBitmap = converted;
        }
        return processBitmap;
    }

    /**
     * Intersection-over-union of two axis-aligned boxes; 0 if either box has non-positive area.
     */
    private float IOU(@NotNull Bbox a, Bbox b) {
        float areaA = (a.x2 - a.x1) * (a.y2 - a.y1);
        if (areaA <= 0) {
            return 0;
        }
        float areaB = (b.x2 - b.x1) * (b.y2 - b.y1);
        if (areaB <= 0) {
            return 0;
        }
        float intersectionMinX = Math.max(a.x1, b.x1);
        float intersectionMinY = Math.max(a.y1, b.y1);
        float intersectionMaxX = Math.min(a.x2, b.x2);
        float intersectionMaxY = Math.min(a.y2, b.y2);
        float intersectionArea = Math.max(intersectionMaxY - intersectionMinY, 0) *
                Math.max(intersectionMaxX - intersectionMinX, 0);
        return intersectionArea / (areaA + areaB - intersectionArea);
    }

    /**
     * Returns a new ARGB_8888 bitmap enlarged by {@code padding_x} columns and {@code padding_y}
     * rows of opaque white. {@code Src} is drawn at the top-left, so the padding lands on the
     * right/bottom, matching {@link #preProcess(Bitmap)} and the box decoding (which applies no
     * offset).
     */
    public Bitmap pad(Bitmap Src, int padding_x, int padding_y) {
        Bitmap outputimage = Bitmap.createBitmap(Src.getWidth() + padding_x, Src.getHeight() + padding_y, Bitmap.Config.ARGB_8888);
        Canvas can = new Canvas(outputimage);
        // Opaque white. Guava's Ascii.FF is the form-feed byte (12), not 0xFF, so it must not be
        // used here.
        can.drawARGB(255, 255, 255, 255);
        can.drawBitmap(Src, 0, 0, null);
        return outputimage;
    }
}
