package com.securityandsafetythings.examples.aiapp.aicore.aiLibs.aiInference.detector.face;

import android.app.Application;
import android.content.Context;
import android.graphics.Bitmap;
import android.graphics.Matrix;
import android.graphics.PointF;
import android.util.Log;

import androidx.annotation.NonNull;

import com.qualcomm.qti.snpe.FloatTensor;
import com.qualcomm.qti.snpe.NeuralNetwork;
import com.qualcomm.qti.snpe.SNPE;
import com.securityandsafetythings.examples.aiapp.aicore.aiLibs.InferenceResult;
import com.securityandsafetythings.examples.aiapp.aicore.aiLibs.aiInference.DLCInference;
import com.securityandsafetythings.examples.aiapp.aicore.aiLibs.aiInference.detector.Anchor;
import com.securityandsafetythings.examples.aiapp.aicore.aiLibs.aiInference.detector.Bbox;

import org.jetbrains.annotations.Contract;
import org.jetbrains.annotations.NotNull;
import org.opencv.android.Utils;
import org.opencv.core.Core;
import org.opencv.core.CvType;
import org.opencv.core.Mat;
import org.opencv.core.Point;
import org.opencv.core.Scalar;
import org.opencv.imgproc.Imgproc;

import java.io.IOException;
import java.io.InputStream;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

public class RetinaFaceDetector extends DLCInference {
    static String LOGTAG = RetinaFaceDetector.class.getSimpleName();
    public final List<Anchor> anchors = new ArrayList<>();

    /** variable for cache-pre-process process*/
    private Scalar normlized_subtraction = new Scalar(104.0f, 117.0f, 123.0f);
    private Mat processMatRGBA;
    private Mat processMatRGB;
    private Mat processMat32F;

    /** variable for model I/O */
    protected String INPUT_LAYER;
    protected float[] floatArrayInputValues;
    protected FloatTensor inputTensor;
    protected Map<String, FloatTensor> modelInputSource;
    protected Map<String, FloatTensor> modelOutputSource;

    /** private model params*/
    private float mRatioWidth;
    private float mRatioHeight;
    private float CONFIDENCE_THRESHOLD;
    private float IOU_THRESHOLD;

    public RetinaFaceDetector(@NotNull Context context, Application application, int modelResId, NeuralNetwork.Runtime runtimeMode) {
        super(512, 288);
        /* try to build neural network */
        final InputStream modelInputStream = context.getResources().openRawResource(modelResId);
        try {
            network = new SNPE.NeuralNetworkBuilder(application)
                    .setDebugEnabled(false)
                    .setRuntimeOrder(
                            runtimeMode,
                            NeuralNetwork.Runtime.CPU
                    )
                    .setModel(modelInputStream, modelInputStream.available())
                    .setOutputLayers("concatenation_3", "concatenation_4", "concatenation_5")
                    //.setUseUserSuppliedBuffers()
                    .setCpuFallbackEnabled(true)
                    .setPerformanceProfile(NeuralNetwork.PerformanceProfile.HIGH_PERFORMANCE)
                    .build();
        } catch (IOException e) {
            e.printStackTrace();
        }

        /* setup input/output layer and pre-pair Mat for pre-process if model build success */
        if (network != null){
            floatArrayInputValues = new float[IMAGE_WIDTH * IMAGE_HEIGHT * 3];
            INPUT_LAYER = network.getInputTensorsNames().iterator().next();
            inputTensor = network.createFloatTensor(network.getInputTensorsShapes().get(INPUT_LAYER));
            modelInputSource = new HashMap<>();
            createAnchor();
            processMatRGBA = new Mat();
            processMatRGB = new Mat();
            processMat32F = new Mat();
        }
    }

    @Override
    protected void setModelConfig() {
        CONFIDENCE_THRESHOLD = 0.7f;
        IOU_THRESHOLD = 0.4f;
    }

    private void createAnchor() {
        final float[][] featureMap = new float[3][3];
        final float[][] minSizes = {{10, 20}, {32, 64}, {128, 256}};
        final float[] steps = {8, 16, 32};
        for (int i = 0; i < 3; ++i) {
            featureMap[i][0] = (float) Math.ceil(IMAGE_HEIGHT / steps[i]);
            featureMap[i][1] = (float) Math.ceil(IMAGE_WIDTH / steps[i]);
        }
        for (int k = 0; k < 3; ++k) {
            for (int i = 0; i < featureMap[k][0]; ++i) {
                for (int j = 0; j < featureMap[k][1]; ++j) {
                    for (int l = 0; l < 2; ++l) {
                        final float s_ky = minSizes[k][l] / IMAGE_HEIGHT;
                        final float s_kx = minSizes[k][l] / IMAGE_WIDTH;
                        final float cx = (float) (j + 0.5) * steps[k] / IMAGE_WIDTH;
                        final float cy = (float) (i + 0.5) * steps[k] / IMAGE_HEIGHT;
                        final Anchor anchor = new Anchor(cx, cy, s_kx, s_ky);
                        anchors.add(anchor);
                    }
                }
            }
        }
    }

    /** Pre-process for free size image -> auto scale image to correct process-size */
    @Override
    protected void preProcess(@NotNull final Bitmap inputBitmap) {
        mRatioWidth =   ((float) IMAGE_WIDTH) / ((float)inputBitmap.getWidth());
        mRatioHeight =  ((float) IMAGE_HEIGHT) / ((float)inputBitmap.getHeight());

        final Matrix scalingMatrix = new Matrix();
        scalingMatrix.postScale(mRatioWidth, mRatioHeight);
        Bitmap processBitmap = Bitmap.createBitmap(inputBitmap,
                0,
                0,
                inputBitmap.getWidth(),
                inputBitmap.getHeight(),
                scalingMatrix,
                false
        );

        /*if (inputBitmap.getConfig() != Bitmap.Config.ARGB_8888 || !inputBitmap.isMutable()){
            processMedia = inputBitmap.copy(Bitmap.Config.ARGB_8888, true);
        }*/
        Utils.bitmapToMat(processBitmap, processMatRGBA);
        Imgproc.cvtColor(processMatRGBA , processMatRGB , 3);//COLOR_RGBA2RGB
        processMatRGB.convertTo(processMat32F, CvType.CV_32F);
        Core.subtract(processMat32F, normlized_subtraction, processMat32F);

        /* Put pre-processed image to tensor and pre-pair run inference from model */
        prepairModelTensorsInput(processMat32F);
    }

    @Override
    protected void preProcess(InferenceResult inferenceResult) {

    }

    protected void prepairModelTensorsInput(@NotNull Mat preprocessedMat){
        preprocessedMat.get(0,0, floatArrayInputValues);
        inputTensor.write(floatArrayInputValues, 0, floatArrayInputValues.length);
        modelInputSource.put(INPUT_LAYER,inputTensor);
    }

    @Override
    public FDResult runInference(@NotNull InferenceResult inputInference) {
        Bitmap inputBitmap = inputInference.getProcessBitmap();
//        preProcess(inputBitmap);
//        Mat inputMat = inputInference.getProcessRBBMat();
        long startPre = System.currentTimeMillis();
        preProcess(inputBitmap);
//        preProcess(inputMat);
        long endPre = System.currentTimeMillis();
        float fpsPre = 1000f / (float) ((endPre- startPre) > 0 ? (endPre- startPre) : 1) ;
//        long pre = endPre- startPre;


        long startrun = System.currentTimeMillis();
        modelOutputSource = network.execute(modelInputSource);
        long endrun = System.currentTimeMillis();
        float fpsrun = 1000f / (float) ((endrun- startrun) > 0 ? (endrun - startrun) : 1) ;

        long startpost = System.currentTimeMillis();
        FDResult fdResult =postProcess(inputInference);
        long endpost = System.currentTimeMillis();
        float fpspost = 1000f / (float) ((endpost- startpost) > 0 ? (endpost - startpost) : 1) ;

        Log.d(LOGTAG+ "_face detection", "preProcess: "+ fpsPre + "run: "+ fpsrun+ "post: "+ fpspost);

        return fdResult;
    }

    @Override
    protected FDResult postProcess(InferenceResult inferenceResult) {
        float[] locs = {};
        float[] landmarks = {};
        float[] confidences = {};
        for (Map.Entry<String, FloatTensor> output : modelOutputSource.entrySet()) {
            FloatTensor outputTensor = output.getValue();
            switch (output.getKey()) {
                case "loc0":
                    locs = new float[outputTensor.getSize()];
                    outputTensor.read(locs, 0, locs.length);
                    break;
                case "landmark0":
                    landmarks = new float[outputTensor.getSize()];
                    outputTensor.read(landmarks, 0, landmarks.length);
                    break;
                case "conf0":
                    confidences = new float[outputTensor.getSize()];
                    outputTensor.read(confidences, 0, confidences.length);
                    break;
            }
        }
        return buildBbox(inferenceResult, locs, confidences, landmarks);
    }

    @NonNull
    @Contract("processedBitmap , locs, confidences, landmarks -> new DetectionResult")
    private FDResult buildBbox(InferenceResult previousInference, float[] locs, float[] confidences, float[] landmarks) {
        List<Bbox> bboxes = new ArrayList<>();
        for (int i = 0; i < anchors.size(); ++i) {
            float cx = confidences[i * 2];
            float cy = confidences[i * 2 + 1];
            float conf = (float) (Math.exp(cy) / (Math.exp(cx) + Math.exp(cy)));
            if (conf > CONFIDENCE_THRESHOLD) {
                Anchor tmp = anchors.get(i);
                Anchor tmp1 = new Anchor();
                //box result = new box();

                tmp1.cx = (float) (tmp.cx + locs[i * 4] * 0.1 * tmp.sx);
                tmp1.cy = (float) (tmp.cy + locs[i * 4 + 1] * 0.1 * tmp.sy);
                tmp1.sx = (float) (tmp.sx * Math.exp(locs[i * 4 + 2] * 0.2));
                tmp1.sy = (float) (tmp.sy * Math.exp(locs[i * 4 + 3] * 0.2));

                // Extract bbox and confidences
                float x1 = (tmp1.cx - tmp1.sx / 2) * IMAGE_WIDTH ;
                x1 = x1 < 0 ? 0 : x1 / mRatioWidth;

                float y1 = (tmp1.cy - tmp1.sy / 2) * IMAGE_HEIGHT ;
                y1 = y1 < 0 ? 0 : y1 / mRatioHeight;

                float x2 = (tmp1.cx + tmp1.sx / 2) * IMAGE_WIDTH;
                x2 = x2 > IMAGE_WIDTH ? IMAGE_WIDTH : x2 / mRatioWidth;

                float y2 = (tmp1.cy + tmp1.sy / 2) * IMAGE_HEIGHT;
                y2 = y2 > IMAGE_HEIGHT ? IMAGE_HEIGHT : y2 / mRatioHeight;

                // extracting landmark
                Point[] resultLandmark = new Point[5];
                for (int j = 0; j < 5; ++j) {
                    float lx = (tmp.cx + (landmarks[i * 10 + j * 2]) * 0.1f * tmp.sx) * IMAGE_WIDTH / mRatioWidth;
                    float ly = (tmp.cy + (landmarks[i * 10 + j * 2 + 1]) * 0.1f * tmp.sy) * IMAGE_HEIGHT / mRatioHeight;
                    resultLandmark[j] = new Point(lx, ly);
                }
                bboxes.add(new Bbox(previousInference, x1, y1, x2, y2, conf, resultLandmark));
            }
        }
        Collections.sort(bboxes);
        bboxes = nms(bboxes);
        return new FDResult(previousInference, bboxes);
    }

    @NotNull
    private List<Bbox> nms(@NotNull List<Bbox> bboxes) {
        List<Bbox> selected = new ArrayList<>();
        for (Bbox boxA : bboxes) {
            boolean shouldSelect = true;
            // Does the current box overlap one of the selected boxes more than the
            // given threshold amount? Then it's too similar, so don't keep it.
            for (Bbox boxB : selected) {
                if (IOU(boxA, boxB) > IOU_THRESHOLD) {
                    shouldSelect = false;
                    break;
                }
            }
            // This bounding box did not overlap too much with any previously selected
            // bounding box, so we'll keep it.
            if (shouldSelect) {
                selected.add(boxA);
            }
        }
        return selected;
    }

    private float IOU(@NotNull Bbox a, Bbox b ) {
        float areaA = (a.x2 - a.x1) * (a.y2 - a.y1);
        if (areaA <= 0) {
            return 0;
        }
        float areaB = (b.x2 - b.x1) * (b.y2 - b.y1);
        if (areaB <= 0) {
            return 0;
        }
        float intersectionMinX = Math.max(a.x1, b.x1);
        float intersectionMinY = Math.max(a.y1, b.y1);
        float intersectionMaxX = Math.min(a.x2, b.x2);
        float intersectionMaxY = Math.min(a.y2, b.y2);
        float intersectionArea = Math.max(intersectionMaxY - intersectionMinY, 0) *
                Math.max(intersectionMaxX - intersectionMinX, 0);
        return intersectionArea / (areaA + areaB - intersectionArea);
    }

    @Override
    public void release() {
        network.release();
        releaseTensors(modelInputSource);
    }

}
