package com.securityandsafetythings.examples.aiapp.aicore.aiLibs.aiInference.extractor;

import android.app.Application;
import android.content.Context;
import android.graphics.Bitmap;
import android.util.Log;

import com.qualcomm.qti.snpe.FloatTensor;
import com.qualcomm.qti.snpe.NeuralNetwork;
import com.qualcomm.qti.snpe.SNPE;
import com.securityandsafetythings.examples.aiapp.aicore.aiLibs.InferenceResult;
import com.securityandsafetythings.examples.aiapp.aicore.aiLibs.aiInference.DLCInference;
import com.securityandsafetythings.examples.aiapp.aicore.aiLibs.algorithm.nativewarpper.align.FaceAlignResult;
import com.securityandsafetythings.examples.aiapp.aicore.aiLibs.algorithm.nativewarpper.align.FaceAligner;
import com.securityandsafetythings.examples.aiapp.aicore.aiLibs.algorithm.nativewarpper.sorttrack.TrackBox;

import org.jetbrains.annotations.NotNull;
import org.opencv.core.Core;
import org.opencv.core.Mat;
import org.opencv.core.MatOfFloat;
import org.opencv.core.Scalar;

import java.io.IOException;
import java.io.InputStream;
import java.util.HashMap;
import java.util.Map;

/**
 * Face feature extractor backed by an SNPE (DLC) model.
 *
 * <p>Takes a face that has already been aligned (a {@link FaceAlignResult}, e.g. derived from a
 * retina-face detection) and produces an L2-normalised {@code 1 x FEATURE_SIZE} embedding wrapped
 * in a {@link Feature}.
 *
 * <p>Instances reuse a single input tensor, float buffer and scratch {@link Mat} without any
 * synchronisation, so one instance must not be used from several threads at the same time.
 */
public class MobileFaceExtractorV2 extends DLCInference {

    private static final String TAG = "MobileFaceExtractorV2";

    /** Side length (pixels) of the square aligned face the model consumes. */
    private static final int INPUT_SIZE = 112;

    /** Layer requested as network output when building the SNPE network. */
    private static final String OUTPUT_LAYER_NAME = "Gemm_271";

    /**
     * Key under which the embedding appears in the map returned by {@code network.execute}.
     * NOTE(review): this differs from {@link #OUTPUT_LAYER_NAME}; presumably it is the tensor name
     * produced by that layer — confirm against the DLC if extraction ever fails.
     */
    private static final String OUTPUT_TENSOR_NAME = "output0";

    /**
     * Pre-processing is {@code (pixel - 0.5) / 0.5} per channel.
     * Presumably maps [0, 1] inputs to [-1, 1] — assumes the aligned face is float in [0, 1].
     */
    private static final Scalar NORMALIZE_MEAN = new Scalar(0.5f, 0.5f, 0.5f);
    private static final Scalar NORMALIZE_STD = new Scalar(0.5f, 0.5f, 0.5f);

    /** Variables for model I/O (protected so subclasses can customise pre/post-processing). */
    protected String INPUT_LAYER;
    protected float[] floatArrayInputValues;
    protected FloatTensor inputTensor;
    protected Map<String, FloatTensor> modelInputSource;
    protected Map<String, FloatTensor> modelOutputSource;

    /** Length of the embedding vector produced by the model. */
    public static final int FEATURE_SIZE = 512;

    /** Reused destination for the normalised face, so the caller's aligned Mat is never modified. */
    private Mat normalizedFace;

    /**
     * Loads the DLC model from a raw resource and prepares the input tensor.
     *
     * <p>If the model cannot be read, the error is logged and the network is left unset.
     * Every subsequent {@code runInference} call then throws {@link IllegalStateException}.
     *
     * @param context     context used to open the raw model resource
     * @param application application handed to the SNPE network builder
     * @param modelResId  raw resource id of the DLC model
     * @param runtimeMode preferred SNPE runtime; CPU is configured as the fallback
     */
    public MobileFaceExtractorV2(@NotNull Context context, Application application, int modelResId, NeuralNetwork.Runtime runtimeMode) {
        super(INPUT_SIZE, INPUT_SIZE);

        // try-with-resources: the stream is only needed while build() reads the model.
        try (InputStream modelInputStream = context.getResources().openRawResource(modelResId)) {
            network = new SNPE.NeuralNetworkBuilder(application)
                    .setDebugEnabled(false)
                    .setRuntimeOrder(runtimeMode, NeuralNetwork.Runtime.CPU)
                    .setModel(modelInputStream, modelInputStream.available())
                    .setOutputLayers(OUTPUT_LAYER_NAME)
                    .setCpuFallbackEnabled(true)
                    .setPerformanceProfile(NeuralNetwork.PerformanceProfile.HIGH_PERFORMANCE)
                    .build();
        } catch (IOException e) {
            Log.e(TAG, "Failed to load face-extractor model, resId=" + modelResId, e);
        }

        // Set up the input buffer and tensor only if the model was built successfully.
        if (network != null) {
            floatArrayInputValues = new float[IMAGE_WIDTH * IMAGE_HEIGHT * 3];
            INPUT_LAYER = network.getInputTensorsNames().iterator().next();
            inputTensor = network.createFloatTensor(network.getInputTensorsShapes().get(INPUT_LAYER));
            modelInputSource = new HashMap<>();
            // Register the tensor up front so release() frees it even if no inference ever ran.
            modelInputSource.put(INPUT_LAYER, inputTensor);
        }
    }

    /** No-op: this extractor configures its model entirely in the constructor. */
    @Override
    protected void setModelConfig() {
    }

    /**
     * Unsupported entry point: this extractor needs an aligned face, not a raw frame.
     * Use {@link #preProcess(InferenceResult)} with a {@link FaceAlignResult} instead.
     */
    @Override
    @Deprecated
    protected void preProcess(Bitmap inputBitmap) {
        // intentionally empty: extraction requires the aligned face from a detection result
    }

    /**
     * Normalises the aligned face and writes it into the model's input tensor.
     *
     * <p>The normalisation is written into a scratch Mat, so the Mat owned by the
     * {@link FaceAlignResult} is left untouched. Calling this twice on the same result therefore
     * yields the same input.
     *
     * @param inputInference must be a {@link FaceAlignResult}
     */
    @Override
    protected void preProcess(InferenceResult inputInference) {
        FaceAlignResult faceAlignResult = (FaceAlignResult) inputInference;
        Mat alignedFace = faceAlignResult.getAlignedFaceMat();

        if (normalizedFace == null) {
            normalizedFace = new Mat();
        }
        Core.subtract(alignedFace, NORMALIZE_MEAN, normalizedFace);
        Core.divide(normalizedFace, NORMALIZE_STD, normalizedFace);

        prepairModelTensorsInput(normalizedFace);
    }

    /**
     * Copies a pre-processed Mat into the input tensor.
     *
     * <p>Values are copied in the channel-interleaved order stored in the Mat. {@code Mat.get}
     * into a {@code float[]} requires a 32-bit float Mat. The Mat is assumed to hold
     * {@code IMAGE_WIDTH * IMAGE_HEIGHT * 3} values.
     *
     * @param preprocessedMat normalised face image
     */
    protected void prepairModelTensorsInput(@NotNull Mat preprocessedMat) {
        preprocessedMat.get(0, 0, floatArrayInputValues);
        inputTensor.write(floatArrayInputValues, 0, floatArrayInputValues.length);
        modelInputSource.put(INPUT_LAYER, inputTensor);
    }

    /**
     * Extracts a face feature from an aligned face and logs per-stage throughput.
     *
     * @param inputInference must be an instance of {@link FaceAlignResult}
     * @return the L2-normalised embedding
     * @throws IllegalStateException if the model failed to load or produced no embedding output
     */
    @Override
    public Feature runInference(@NotNull InferenceResult inputInference) {
        requireNetwork();
        FaceAlignResult faceAlignResult = (FaceAlignResult) inputInference;

        long startPre = System.currentTimeMillis();
        preProcess(faceAlignResult);
        long endPre = System.currentTimeMillis();

        long startRun = System.currentTimeMillis();
        executeModel();
        long endRun = System.currentTimeMillis();

        long startPost = System.currentTimeMillis();
        Feature feature = postProcess(inputInference);
        long endPost = System.currentTimeMillis();

        float fpsPre = ratePerSecond(startPre, endPre);
        float fpsRun = ratePerSecond(startRun, endRun);
        float fpsPost = ratePerSecond(startPost, endPost);
        Log.d("_face re", "preProcess: " + fpsPre + "run: " + fpsRun + "post: " + fpsPost);

        return feature;
    }

    /**
     * Extracts a face feature for a tracked box and attaches it to the box.
     *
     * @param inputInference tracked box that must carry a {@code faceAlign} result
     * @return the L2-normalised embedding (also added to {@code inputInference})
     * @throws IllegalArgumentException if the box carries no face-alignment result
     * @throws IllegalStateException    if the model failed to load or produced no embedding output
     */
    public Feature runInference(@NotNull TrackBox inputInference) {
        requireNetwork();
        FaceAlignResult faceAlignResult =
                (FaceAlignResult) inputInference.getInferenceResult(InferenceResult.ResultName.faceAlign);
        if (faceAlignResult == null) {
            throw new IllegalArgumentException("TrackBox has no faceAlign result to extract a feature from");
        }

        preProcess(faceAlignResult);
        executeModel();
        Feature feature = postProcess(inputInference);
        inputInference.addResult(feature);
        return feature;
    }

    /**
     * Reads the embedding from the last model outputs and turns it into a {@link Feature}.
     *
     * <p>The flat output is reshaped to {@code FEATURE_SIZE} rows and then transposed. For a
     * {@code FEATURE_SIZE}-element output this gives a {@code 1 x FEATURE_SIZE} row vector, which
     * is scaled to unit L2 norm ({@code Core.normalize} defaults).
     *
     * @throws IllegalStateException if the expected output tensor is missing
     */
    @Override
    protected Feature postProcess(InferenceResult inferenceResult) {
        FloatTensor outputTensor = modelOutputSource == null ? null : modelOutputSource.get(OUTPUT_TENSOR_NAME);
        if (outputTensor == null) {
            throw new IllegalStateException("Model output '" + OUTPUT_TENSOR_NAME + "' not found; available outputs: "
                    + (modelOutputSource == null ? "none" : modelOutputSource.keySet()));
        }

        float[] outs = new float[outputTensor.getSize()];
        outputTensor.read(outs, 0, outs.length);

        Mat cvMat = new MatOfFloat(outs).reshape(1, FEATURE_SIZE);
        Core.transpose(cvMat, cvMat);
        Core.normalize(cvMat, cvMat);

        return new Feature(inferenceResult, cvMat);
    }

    /**
     * Releases the network, all tensors owned by this extractor and the scratch Mat.
     * Null-safe if the model never loaded.
     */
    @Override
    public void release() {
        if (network != null) {
            network.release();
        }
        if (modelInputSource != null) {
            releaseTensors(modelInputSource);
            modelInputSource.clear();
        }
        if (modelOutputSource != null) {
            releaseTensors(modelOutputSource);
            modelOutputSource = null;
        }
        if (normalizedFace != null) {
            normalizedFace.release();
            normalizedFace = null;
        }
    }

    /** Fails fast with a clear message when the constructor could not build the network. */
    private void requireNetwork() {
        if (network == null) {
            throw new IllegalStateException("Face-extractor network was not built; see the model-loading error logged at construction");
        }
    }

    /** Runs the network on the prepared input and frees the output tensors of the previous run. */
    private void executeModel() {
        Map<String, FloatTensor> previousOutputs = modelOutputSource;
        modelOutputSource = network.execute(modelInputSource);
        // Per the SNPE Java API samples, the tensors returned by execute() are owned by the
        // caller and must be released; otherwise every inference leaks its outputs.
        if (previousOutputs != null && previousOutputs != modelOutputSource) {
            releaseTensors(previousOutputs);
        }
    }

    /** Converts an elapsed interval into a per-second rate, treating intervals below 1 ms as 1 ms. */
    private static float ratePerSecond(long startMillis, long endMillis) {
        return 1000f / (float) Math.max(endMillis - startMillis, 1L);
    }
}
