package com.securityandsafetythings.examples.aiapp.aicore.aiLibs.aiInference.extractor;

import android.app.Application;
import android.content.Context;
import android.graphics.Bitmap;
import android.graphics.Matrix;
import android.util.Log;

import androidx.annotation.NonNull;

import com.qualcomm.qti.snpe.FloatTensor;
import com.qualcomm.qti.snpe.NeuralNetwork;
import com.qualcomm.qti.snpe.SNPE;
import com.securityandsafetythings.examples.aiapp.aicore.aiLibs.InferenceResult;
import com.securityandsafetythings.examples.aiapp.aicore.aiLibs.aiInference.DLCInference;
import com.securityandsafetythings.examples.aiapp.aicore.aiLibs.aiInference.detector.Bbox;
import com.securityandsafetythings.examples.aiapp.aicore.aiLibs.aiInference.detector.licensePlate.LPDResult;
import com.securityandsafetythings.examples.aiapp.aicore.aiLibs.algorithm.nativewarpper.align.LPAlignResult;
import com.securityandsafetythings.examples.aiapp.utilities.FileUtils;
import com.securityandsafetythings.examples.aiapp.utilities.ImageUtils;

import org.jetbrains.annotations.NotNull;
import org.opencv.android.Utils;
import org.opencv.core.CvType;
import org.opencv.core.Mat;
import org.opencv.core.Size;
import org.opencv.imgproc.Imgproc;

import java.io.IOException;
import java.io.InputStream;
import java.util.HashMap;
import java.util.Map;

/**
 * Licence-plate character recogniser (OCR) backed by an SNPE DLC model.
 *
 * <p>Input images are scaled to {@code IMAGE_WIDTH x IMAGE_HEIGHT} (156 and 32 are passed to the
 * super constructor — presumably width and height, confirm against {@link DLCInference}), converted
 * to 3-channel float data and written into a single reusable input tensor. The {@code logits}
 * output is decoded greedily: the arg-max label is taken at every time step, consecutive repeats
 * are collapsed and the blank label is dropped.
 *
 * <p>Not thread-safe: pre-processing reuses cached {@link Mat}s, one input tensor and the
 * {@link #modelOutputSource} field across calls.
 */
public class LicencePlateExtractor extends DLCInference {
    static final String LOGTAG = LicencePlateExtractor.class.getSimpleName();

    /** Output layer requested from the SNPE builder. */
    private static final String OUTPUT_LAYER = "logits";
    /** Key under which the output tensor is looked up in the map returned by execute(). */
    private static final String OUTPUT_TENSOR = "logits:0";
    /** Label that means "no character" and is skipped by the decoder. */
    private static final char BLANK_LABEL = ' ';
    /** Output classes in model order; the last entry is the blank label. */
    private static final char[] LABELS = {'A','B','C','D','E','F','G','H','J','K','L','M','N','P','Q','R','S','T','U','V','W','X','Y','Z','0','1','2','3','4','5','6','7','8','9',' '};

    /** Name of the (first) model input layer. */
    protected String INPUT_LAYER;
    /** Reusable host buffer holding IMAGE_WIDTH * IMAGE_HEIGHT * 3 floats for the input tensor. */
    protected float[] floatArrayInputValues;
    /** Reusable SNPE input tensor, created once per network. */
    protected FloatTensor inputTensor;
    /** Input map passed to {@code network.execute}. */
    protected Map<String, FloatTensor> modelInputSource;
    /** Output map of the current inference; released and cleared after each run. */
    protected Map<String, FloatTensor> modelOutputSource;

    /** Scale factors from the last source bitmap to the model input size. */
    private float mRatioWidth;
    private float mRatioHeight;

    /** Target size for the aligned-plate pre-processing path. */
    private final Size size = new Size(IMAGE_WIDTH, IMAGE_HEIGHT);
    /** Cached intermediate Mats, reused across calls to avoid per-frame native allocations. */
    private Mat processMatRGBA;
    private Mat processMatRGB;
    private Mat processMatResized;
    private Mat processMat32F;

    /**
     * Builds the SNPE network from a raw resource and prepares the reusable input buffers.
     *
     * <p>If the model cannot be read, the error is logged and the instance is left without a
     * network; {@link #runInference} then throws {@link IllegalStateException}.
     *
     * @param context     context used to open the raw model resource
     * @param application application passed to the SNPE builder
     * @param modelResId  raw resource id of the DLC model
     * @param runtimeMode preferred runtime; CPU is configured as the fallback
     */
    public LicencePlateExtractor(@NotNull Context context, Application application, int modelResId, NeuralNetwork.Runtime runtimeMode) {
        super(156, 32);
        this.context = context;

        // try-with-resources so the raw-resource stream is always closed after the build.
        try (InputStream modelInputStream = context.getResources().openRawResource(modelResId)) {
            network = new SNPE.NeuralNetworkBuilder(application)
                    .setDebugEnabled(false)
                    .setRuntimeOrder(
                            runtimeMode,
                            NeuralNetwork.Runtime.CPU
                    )
                    .setModel(modelInputStream, modelInputStream.available())
                    .setOutputLayers(OUTPUT_LAYER)
                    .setCpuFallbackEnabled(true)
                    .setPerformanceProfile(NeuralNetwork.PerformanceProfile.BURST)
                    .build();
        } catch (IOException e) {
            // Best effort, as before: keep going without a network; runInference reports it.
            Log.e(LOGTAG, "Failed to load licence-plate OCR model (resId=" + modelResId + ")", e);
        }

        /* Set up the input/output layer and pre-allocate Mats only if the model was built. */
        if (network != null) {
            floatArrayInputValues = new float[IMAGE_WIDTH * IMAGE_HEIGHT * 3];
            INPUT_LAYER = network.getInputTensorsNames().iterator().next();
            inputTensor = network.createFloatTensor(network.getInputTensorsShapes().get(INPUT_LAYER));
            modelInputSource = new HashMap<>();
            // Registered up front so release() frees the tensor even if no inference ever ran.
            modelInputSource.put(INPUT_LAYER, inputTensor);
            processMatRGBA = new Mat();
            processMatRGB = new Mat();
            processMatResized = new Mat();
            processMat32F = new Mat();
        }
    }

    /** No extra configuration: the network is fully configured in the constructor. */
    @Override
    protected void setModelConfig() {

    }

    /**
     * Scales a bitmap to the model input size, converts it to 3-channel float data and writes it
     * into the input tensor.
     *
     * @param inputBitmap source image of any size
     */
    @Override
    protected void preProcess(@NonNull Bitmap inputBitmap) {
        mRatioWidth = ((float) IMAGE_WIDTH) / ((float) inputBitmap.getWidth());
        mRatioHeight = ((float) IMAGE_HEIGHT) / ((float) inputBitmap.getHeight());

        // createScaledBitmap guarantees exactly IMAGE_WIDTH x IMAGE_HEIGHT, matching the tensor layout.
        Bitmap processBitmap = Bitmap.createScaledBitmap(inputBitmap, IMAGE_WIDTH, IMAGE_HEIGHT, false);

        // bitmapToMat needs ARGB_8888 (or RGB_565); convert the *scaled* bitmap so the resize
        // is kept (previously the unscaled input was copied here, discarding the resize).
        if (processBitmap.getConfig() != Bitmap.Config.ARGB_8888) {
            processBitmap = processBitmap.copy(Bitmap.Config.ARGB_8888, false);
        }

        Utils.bitmapToMat(processBitmap, processMatRGBA);
        Imgproc.cvtColor(processMatRGBA, processMatRGB, Imgproc.COLOR_RGBA2RGB);
        processMatRGB.convertTo(processMat32F, CvType.CV_32F);

        /* Put the pre-processed image into the tensor, ready for inference. */
        prepairModelTensorsInput(processMat32F);
    }

    /**
     * Copies a pre-processed float Mat into the reusable input tensor and registers it as the
     * model input.
     *
     * @param preprocessedMat CV_32F Mat whose data is read row-major from (0, 0)
     */
    protected void prepairModelTensorsInput(@NotNull Mat preprocessedMat) {
        preprocessedMat.get(0, 0, floatArrayInputValues);
        inputTensor.write(floatArrayInputValues, 0, floatArrayInputValues.length);
        modelInputSource.put(INPUT_LAYER, inputTensor);
    }

    /**
     * Runs OCR on the process bitmap of an inference result.
     *
     * @param inputInference result whose {@code getProcessBitmap()} is recognised
     * @return the decoded plate, or {@code null} if the model produced no {@code logits} output
     * @throws IllegalStateException if the model failed to load
     */
    @Override
    public OCRResult runInference(@NonNull InferenceResult inputInference) {
        requireNetwork();
        Bitmap inputBitmap = inputInference.getProcessBitmap();
        long startPre = System.currentTimeMillis();
        preProcess(inputBitmap);
        long pre = System.currentTimeMillis() - startPre;

        long runtime = executeNetwork();
        try {
            long startDecode = System.currentTimeMillis();
            OCRResult ocrResult = postProcess(inputInference);
            long decodeRuntime = System.currentTimeMillis() - startDecode;

            Log.d(LOGTAG + "_checkRuntime", "check LPR runtime: pre = "+ pre +"ms" + " | runtime = " + runtime + "ms | decode = " + decodeRuntime+"ms");

            return ocrResult;
        } finally {
            releaseOutputTensors();
        }
    }

    /**
     * Runs OCR on a bounding-box input. Pre- and post-processing go to the {@code preProcess} /
     * {@code postProcess} overloads applicable to {@link Bbox} (not defined in this class —
     * presumably inherited from {@link DLCInference}; confirm).
     *
     * @param inputInference the box to recognise
     * @return the OCR result produced by the applicable {@code postProcess} overload
     * @throws IllegalStateException if the model failed to load
     */
    public OCRResult runInference(@NonNull Bbox inputInference) {
        requireNetwork();
        long startPre = System.currentTimeMillis();
        preProcess(inputInference);
        long pre = System.currentTimeMillis() - startPre;

        long runtime = executeNetwork();
        try {
            long startDecode = System.currentTimeMillis();
            OCRResult ocrResult = postProcess(inputInference);
            long decodeRuntime = System.currentTimeMillis() - startDecode;

            Log.d(LOGTAG + "_checkRuntime", "check LPR runtime: pre = "+ pre +"ms" + " | runtime = " + runtime + "ms | decode = " + decodeRuntime+"ms");

            return ocrResult;
        } finally {
            releaseOutputTensors();
        }
    }

    /** Fails fast with a clear message instead of an NPE when the model was never built. */
    private void requireNetwork() {
        if (network == null) {
            throw new IllegalStateException("SNPE network is not available; the OCR model failed to load");
        }
    }

    /**
     * Executes the network on {@link #modelInputSource} and stores the outputs.
     *
     * @return execution time in milliseconds
     */
    private long executeNetwork() {
        final long start = System.currentTimeMillis();
        modelOutputSource = network.execute(modelInputSource);
        return System.currentTimeMillis() - start;
    }

    /** Releases the native output tensors of the last run; execute() allocates fresh ones each call. */
    private void releaseOutputTensors() {
        if (modelOutputSource != null) {
            releaseTensors(modelOutputSource);
            modelOutputSource = null;
        }
    }

    /**
     * Pre-processes the aligned plate stored under {@code licencePlateAlign}: resizes it to the
     * model input size, converts it to float and writes it into the input tensor.
     *
     * <p>The resize goes into a private cached Mat, so the Mat owned by the alignment result is
     * not modified.
     *
     * @param inferenceResult result holding an {@link LPAlignResult}
     * @throws IllegalStateException if no alignment result is present
     */
    @Override
    protected void preProcess(@NonNull InferenceResult inferenceResult) {
        LPAlignResult alignResult = (LPAlignResult) inferenceResult.getInferenceResult(InferenceResult.ResultName.licencePlateAlign);
        if (alignResult == null) {
            throw new IllegalStateException("No licencePlateAlign result available for OCR pre-processing");
        }
        Mat alignedVehicleMat = alignResult.getAlignedVehicleMat();
        Imgproc.resize(alignedVehicleMat, processMatResized, size);

        processMatResized.convertTo(processMat32F, CvType.CV_32F);

        /* Put the pre-processed image into the tensor, ready for inference. */
        prepairModelTensorsInput(processMat32F);
    }

    /**
     * Decodes the {@code logits:0} output of the last run into an {@link OCRResult}.
     *
     * @param inferenceResult source of the process media attached to the result
     * @return the decoded plate, or {@code null} if no output (or no {@code logits:0} tensor) exists
     */
    @Override
    protected OCRResult postProcess(InferenceResult inferenceResult) {
        if (modelOutputSource == null) {
            return null;
        }
        Log.d("decode_model", "check output size = " + modelOutputSource.size());

        FloatTensor outputTensor = modelOutputSource.get(OUTPUT_TENSOR);
        if (outputTensor == null) {
            return null;
        }
        Log.d("decode_model", "check output map name = " + OUTPUT_TENSOR);

        float[] netArr = new float[outputTensor.getSize()];
        outputTensor.read(netArr, 0, netArr.length);
        Log.d("decode_model", "check output arr size =  " + netArr.length);

        return new OCRResult(inferenceResult.getProcessMedia(), decodeLicencePlate(netArr));
    }

    /**
     * Greedy decoding of the per-step class scores: arg-max per step, collapse consecutive
     * repeats, drop the blank label.
     *
     * <p>The model output is expected to be 39 x 1 x 35 (steps x 1 x classes). The step count is
     * derived from the array length, so a differently sized output cannot index out of bounds;
     * a trailing partial step is ignored.
     *
     * @param netOut flattened scores, {@code LABELS.length} values per step
     * @return the decoded plate text (possibly empty)
     */
    private static String decodeLicencePlate(float[] netOut) {
        final int numClasses = LABELS.length;
        final int numSteps = netOut.length / numClasses;
        final StringBuilder plate = new StringBuilder(numSteps);
        char previous = BLANK_LABEL;

        for (int step = 0; step < numSteps; ++step) {
            final int offset = step * numClasses;
            // Strict '>' keeps the first index on ties, as before.
            int best = 0;
            for (int c = 1; c < numClasses; ++c) {
                if (netOut[offset + c] > netOut[offset + best]) {
                    best = c;
                }
            }
            final char current = LABELS[best];
            if (current != BLANK_LABEL && current != previous) {
                plate.append(current);
            }
            previous = current;
        }

        Log.d("decode_model", "final plate OCR =  " + plate);

        return plate.toString();
    }

    /**
     * Frees native resources. Tensors are released before the network that created them, and
     * the method is safe to call more than once or after a failed model load.
     */
    @Override
    public void release() {
        releaseOutputTensors();
        if (modelInputSource != null) {
            releaseTensors(modelInputSource);
            modelInputSource.clear();
        }
        if (network != null) {
            network.release();
            network = null;
        }
        releaseMat(processMatRGBA);
        releaseMat(processMatRGB);
        releaseMat(processMatResized);
        releaseMat(processMat32F);
    }

    /** Releases a Mat's native buffer if it was allocated. */
    private static void releaseMat(Mat mat) {
        if (mat != null) {
            mat.release();
        }
    }
}
