package com.securityandsafetythings.examples.aiapp.aicore.aiLibs.aiInference.classifier;

import android.app.Application;
import android.content.Context;
import android.graphics.Bitmap;
import android.util.Log;

import androidx.annotation.NonNull;

import com.qualcomm.qti.snpe.FloatTensor;
import com.qualcomm.qti.snpe.NeuralNetwork;
import com.qualcomm.qti.snpe.SNPE;
import com.securityandsafetythings.examples.aiapp.aicore.aiLibs.InferenceResult;
import com.securityandsafetythings.examples.aiapp.aicore.aiLibs.aiInference.DLCInference;
import com.securityandsafetythings.examples.aiapp.aicore.aiLibs.aiInference.detector.Bbox;
import com.securityandsafetythings.examples.aiapp.aicore.aiLibs.algorithm.nativewarpper.align.FaceAlignResult;
import com.securityandsafetythings.examples.aiapp.aicore.aiLibs.algorithm.nativewarpper.sorttrack.TrackBox;

import org.jetbrains.annotations.NotNull;
import org.opencv.core.Core;
import org.opencv.core.Mat;
import org.opencv.core.Scalar;

import java.io.IOException;
import java.io.InputStream;
import java.util.HashMap;
import java.util.Map;

public class TrueFaceClassifier extends DLCInference {
    static final String LOGTAG = TrueFaceClassifier.class.getSimpleName();
    /** variable for model I/O */
    protected String INPUT_LAYER;
    protected float[] floatArrayInputValues;
    protected FloatTensor inputTensor;
    protected Map<String, FloatTensor> modelInputSource;
    protected Map<String, FloatTensor> modelOutputSource;

    /** variable for cache-pre-process process*/
    private final Scalar normalized_subtraction = new Scalar(0.485f, 0.456f, 0.406f);
    private final Scalar normalized_divide = new Scalar(0.229f, 0.224f, 0.225f);
    private Mat processMat;

    public TrueFaceClassifier(@NotNull Context context, Application application, int modelResId, NeuralNetwork.Runtime runtimeMode) {
        super(112,112);
        final InputStream modelInputStream = context.getResources().openRawResource(modelResId);
        try {
            network = new SNPE.NeuralNetworkBuilder(application)
                    .setDebugEnabled(false)
                    .setRuntimeOrder(
                            runtimeMode,
                            NeuralNetwork.Runtime.CPU
                    )
                    .setModel(modelInputStream, modelInputStream.available())
                    .setOutputLayers("Gemm_297")
                    //.setUseUserSuppliedBuffers()
                    .setCpuFallbackEnabled(true)
                    .setPerformanceProfile(NeuralNetwork.PerformanceProfile.HIGH_PERFORMANCE)
                    .build();
        } catch (IOException e) {
            e.printStackTrace();
        }

        /* setup input/output layer and pre-pair Mat for pre-process if model build success */
        if (network != null){
            floatArrayInputValues = new float[IMAGE_WIDTH * IMAGE_HEIGHT * 3];
            INPUT_LAYER = network.getInputTensorsNames().iterator().next();
            inputTensor = network.createFloatTensor(network.getInputTensorsShapes().get(INPUT_LAYER));
            modelInputSource = new HashMap<>();
            processMat = new Mat();
        }
    }

    @Override
    protected void setModelConfig() {

    }

    @Override
    protected void preProcess(Bitmap inputBitmap) {

    }

    @Override
    protected void preProcess(InferenceResult inputInference) {
        FaceAlignResult faceAlignResult= (FaceAlignResult) inputInference;

        Mat preprocessedAlignFace = faceAlignResult.getAlignedFaceMat();
        Core.subtract(preprocessedAlignFace, normalized_subtraction, processMat);
        Core.divide(processMat, normalized_divide, processMat);
        prepairModelTensorsInput(processMat);
    }

    /** This method process for single align face box and return if the result is a real face
     * or not throw score trueFace in TFClassifyResult
     * Input: Bbox from FDResult
     * Output: TFClassifyResult, added TFClassifyResult to Bbox also
     * */
    @Override
    public TFClassifyResult runInference(InferenceResult inputInference) {
        Bbox bbox = (Bbox) inputInference;
        FaceAlignResult faceAlignResult = (FaceAlignResult) bbox.getInferenceResult(InferenceResult.ResultName.faceAlign);
        preProcess(faceAlignResult);
        modelOutputSource = network.execute(modelInputSource);
        TFClassifyResult tfClassifyResult = postProcess(bbox);
        bbox.addResult(tfClassifyResult);
        return postProcess(inputInference);
    }

    /** This method process for single align face box and return the inputObject
     * with TFClassifyResult added into
     * Input: TrackBox <- {FaceAlignResult (for pre-process), Bbox (for post-process)}
     * Output: TFClassifyResult & input TrackBox with added TFClassifyResult
     * */
    public TFClassifyResult runInference(TrackBox inputInference) {
        TrackBox trackBox = (TrackBox) inputInference;
        FaceAlignResult faceAlignResult = (FaceAlignResult) trackBox.getInferenceResult(InferenceResult.ResultName.faceAlign);
        preProcess(faceAlignResult);
        modelOutputSource = network.execute(modelInputSource);
        TFClassifyResult tfClassifyResult = postProcess(inputInference.getInferenceResult(InferenceResult.ResultName.box));
        inputInference.addResult(tfClassifyResult);
        return tfClassifyResult;
    }

    protected void prepairModelTensorsInput(@NotNull Mat preprocessedMat){
        preprocessedMat.get(0,0, floatArrayInputValues);

        //Log.d(LOGTAG, "check input: array length= " + floatArrayInputValues.length + " | buffer size = " + inputTensor.getSize());
        inputTensor.write(floatArrayInputValues, 0, floatArrayInputValues.length/2);
        modelInputSource.put(INPUT_LAYER,inputTensor);
    }

    @Override
    protected TFClassifyResult postProcess(@NonNull InferenceResult inputInference) {
        Bbox box = (Bbox) inputInference;
        float[] outs = {};
        for (Map.Entry<String, FloatTensor> output : modelOutputSource.entrySet()) {
            FloatTensor outputTensor = output.getValue();
            switch (output.getKey()) {
                case "output0":
                    outs = new float[outputTensor.getSize()];
                    outputTensor.read(outs, 0, outs.length);
                    //Log.d(LOGTAG, "convertOutputs " + outputTensor.getSize());
                    break;
            }
        }
        double faceProb = Math.exp(outs[0]) / (Math.exp(outs[0]) + Math.exp(outs[1]));
        float finalProb = (float)(0.6 * faceProb + 0.4 * box.getConfidence());
        return new TFClassifyResult(inputInference, faceProb, outs, finalProb);
    }

    @Override
    public void release() {
        network.release();
        releaseTensors(modelInputSource);
    }

}
