package com.securityandsafetythings.examples.aiapp.aicore.aiprocess;

import static com.securityandsafetythings.examples.aiapp.aicore.aiLibs.InferenceResult.ResultName.faceRecognized;
import static com.securityandsafetythings.examples.aiapp.aicore.aiLibs.InferenceResult.ResultName.trueFace;
import static com.securityandsafetythings.examples.aiapp.aicore.aiLibs.algorithm.algorithmbasic.FaceQualifier.IQA_BLUR_THRESHOLD;

import android.app.Application;
import android.content.Context;
import android.graphics.Canvas;

import androidx.annotation.NonNull;

import com.securityandsafetythings.examples.aiapp.aicore.aiLibs.InferenceResult;
import com.securityandsafetythings.examples.aiapp.aicore.aiLibs.aiInference.classifier.TFClassifyResult;
import com.securityandsafetythings.examples.aiapp.aicore.aiLibs.aiInference.classifier.TrueFaceClassifier;
import com.securityandsafetythings.examples.aiapp.aicore.aiLibs.aiInference.detector.face.FDResult;
import com.securityandsafetythings.examples.aiapp.aicore.aiLibs.aiInference.detector.face.RetinaFaceDetector;
import com.securityandsafetythings.examples.aiapp.aicore.aiLibs.aiInference.extractor.MobileFaceExtractorV2;
import com.securityandsafetythings.examples.aiapp.aicore.aiLibs.algorithm.algorithmbasic.FaceFeatureSearcher;
import com.securityandsafetythings.examples.aiapp.aicore.aiLibs.algorithm.algorithmbasic.FaceQualifier;
import com.securityandsafetythings.examples.aiapp.aicore.aiLibs.algorithm.algorithmbasic.FaceRecognizeResult;
import com.securityandsafetythings.examples.aiapp.aicore.aiLibs.algorithm.algorithmbasic.IQAResult;
import com.securityandsafetythings.examples.aiapp.aicore.aiLibs.algorithm.nativewarpper.align.FaceAligner;
import com.securityandsafetythings.examples.aiapp.aicore.aiLibs.algorithm.nativewarpper.motion.MotionDetector;
import com.securityandsafetythings.examples.aiapp.aicore.aiLibs.algorithm.nativewarpper.motion.MotionResult;
import com.securityandsafetythings.examples.aiapp.aicore.aiLibs.algorithm.nativewarpper.sorttrack.KalmanSortTracker;
import com.securityandsafetythings.examples.aiapp.aicore.aiLibs.algorithm.nativewarpper.sorttrack.TrackBox;
import com.securityandsafetythings.examples.aiapp.aicore.aiLibs.algorithm.nativewarpper.sorttrack.TrackResult;

import org.opencv.core.Mat;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;

public class AIFaceProcessor extends AIProcess {
    static final String LOGTAG = AIFaceProcessor.class.getSimpleName();

    /** Algorithm objects. */
    // NOTE(review): faceDetector, faceExtractor and trueFaceClassifier are not constructed in the
    // constructor below (their initialization is commented out), so onProcess() will throw a
    // NullPointerException unless they are assigned elsewhere — TODO confirm.
    private RetinaFaceDetector faceDetector;
    private MobileFaceExtractorV2 faceExtractor;
    private FaceAligner faceAligner;
    private TrueFaceClassifier trueFaceClassifier;
    private KalmanSortTracker sortTracker;
    private FaceQualifier faceQualifier;
    private FaceFeatureSearcher faceFeatureSearcher;

    /** Tracking & history. Entries older than TRACK_TIMEOUT ms (by TrackBox#getTimeLife) are evicted. */
    public static final long TRACK_TIMEOUT = 10000; //10s
    /** Tracks still collecting votes, keyed by track id. */
    private ConcurrentMap<Integer, TrackBox> mapTmpTrackCache;
    /** Tracks whose vote settled on the searcher's "unknown" outcome, keyed by track id. */
    private ConcurrentMap<Integer, TrackBox> mapUnknownTrackCache;
    /** Tracks with a decided recognition result that is reused on later frames, keyed by track id. */
    private ConcurrentMap<Integer, TrackBox> mapRecognizedCache;

    /**
     * Voting parameters used to decide the label of a detected face.
     * A distance d is turned into a vote score with convertDistanceToScore(d) = LIMIT_MAX_DISTANCE - ALPHA_WEIGHT * d.
     * A track is promoted once its accumulated score for a label reaches TOTAL_DECISION_SCORE.
     * caculateVoteWeightIndex() recomputes ALPHA_WEIGHT, TOTAL_DECISION_SCORE and SCORE_UNKNOWN_DEFAULT
     * so that MIN_VOTE_WEIGHT votes at DISTANCE_GUARANTEE, or MAX_VOTE_WEIGHT votes at DISTANCE_ACCEPT,
     * reach TOTAL_DECISION_SCORE.
     * NOTE(review): with the current parameters caculateVoteWeightIndex() would yield ALPHA ~1.625,
     * TOTAL ~2.44 and UNKNOWN ~0.81. These differ from the hard-coded defaults below; confirm which is intended.
     */
    public static float TOTAL_DECISION_SCORE = 2.1f;
    public static float ALPHA_WEIGHT = 26f/17f;
    public static float SCORE_UNKNOWN_DEFAULT = 0.7f;
    public static final float DISTANCE_ACCEPT = 1.3f;
    public static float DISTANCE_GUARANTEE = 1.1f;
    public static float MAX_VOTE_WEIGHT = 5;
    public static float MIN_VOTE_WEIGHT = 3;
    public static final float LIMIT_MAX_DISTANCE = 2.6f;

    /** Thresholds. */
    public static final float FDFR_EYES_MIN = 0.25f;
    public static final float EYE_DISTANCE_THRES = 5f;
    /** Minimum true-face classifier probability for a box to be treated as a real face. */
    public static final double TRUE_FACE_PROB_THRES = 0.8;

    public AIFaceProcessor(Context mContext, Application mApplication, AIProcess nextProcess) {
        super(mContext, mApplication, nextProcess);
        /*this.faceDetector = new RetinaFaceDetector(
                mContext,
                mApplication,
                R.raw.retina_mb_nosm_h288_w512_quantized,
                NeuralNetwork.Runtime.DSP
        );
        this.faceExtractor = new MobileFaceExtractorV2(
                mContext,
                mApplication,
                R.raw.frv4_mask,
                NeuralNetwork.Runtime.GPU_FLOAT16
        );
        this.trueFaceClassifier = new TrueFaceClassifier(
                mContext,
                mApplication,
                R.raw.model_filter_v2_snpe1_55,
                NeuralNetwork.Runtime.DSP
        );*/
        this.faceAligner = new FaceAligner();
        this.sortTracker = new KalmanSortTracker(0.3f);
        this.faceQualifier = new FaceQualifier();
        this.mapTmpTrackCache = new ConcurrentHashMap<>();
        this.mapRecognizedCache = new ConcurrentHashMap<>();
        this.mapUnknownTrackCache = new ConcurrentHashMap<>();
        this.faceFeatureSearcher = new FaceFeatureSearcher();
        faceFeatureSearcher.setCallback();
        this.setName("FaceThread");
    }

    /**
     * Processes one frame: detect and track faces, evict stale cache entries, then recognize.
     *
     * @param inputInference must be a {@link MotionResult} (it is cast unconditionally)
     * @return the tracks recognized or decided unknown on this frame
     */
    @Override
    public TrackResult onProcess(InferenceResult inputInference) {
        MotionResult motionResult = (MotionResult) inputInference;
        TrackResult detectTrackResult = FaceDetect(motionResult);
        removeTimeoutTrack();
        return FaceRecognize(detectTrackResult);
    }

    /**
     * Runs face detection and SORT tracking, and keeps only the tracks that pass the motion-score check.
     *
     * @param inputInference motion result carrying the frame and its foreground mask
     * @return a TrackResult with the motion-filtered tracks, plus the detector and motion results attached
     */
    @NonNull
    private TrackResult FaceDetect(@NonNull MotionResult inputInference) {
        Mat fgMat = inputInference.getForegroundMat();
        boolean vibeMode = inputInference.getRuntimeMode();

        /* STEP1: execute model */
        FDResult fdResult = faceDetector.runInference(inputInference);

        /* STEP2: do sort track */
        TrackResult trackResult = sortTracker.runInference(fdResult);

        // WARNING: the first frame in the SORT tracker can produce a null result.
        List<TrackBox> listTrackBox = (trackResult != null) ? trackResult.getTrackList() : null;

        /* STEP3: keep only boxes whose motion score passes the threshold check */
        List<TrackBox> listPassLogicTrack = new ArrayList<>();
        if (listTrackBox != null) {
            for (TrackBox trackBox : listTrackBox) {
                float motionScore = MotionDetector.getMotionScoreOfTrack(fgMat, trackBox);
                boolean[] resultThres = MotionDetector.checkStateMotionImage(motionScore, vibeMode);
                if (resultThres[0]) {
                    trackBox.setMotionScore(motionScore);
                    listPassLogicTrack.add(trackBox);
                }
            }
        }

        ///* STEP3: render raw detect box */
        /*if (EasySharedPreference.getInstance().getIsDrawAi()) {
            Renderer.renderAllBoxFace(
                    new Canvas(inputInference.getProcessBitmap()),
                    listTrackBox,
                    null,
                    null,
                    fps
            );
        }*/

        /* Combine result(s) */
        TrackResult finalCombineResult = new TrackResult(inputInference, listPassLogicTrack);
        finalCombineResult.addResult(fdResult);
        finalCombineResult.addResult(inputInference);
        return finalCombineResult;
    }

    /**
     * Aligns the detected faces, then for each one:
     * <ul>
     *   <li>filters out non-true faces;</li>
     *   <li>reuses the cached result of an already-recognized track;</li>
     *   <li>otherwise extracts and searches features and votes on a label.</li>
     * </ul>
     *
     * @param inputInference tracks produced by FaceDetect
     * @return a TrackResult holding only the tracks recognized, or decided unknown, on this frame
     */
    public TrackResult FaceRecognize(TrackResult inputInference) {
        /* STEP1: Align face */
        faceAligner.runInference(inputInference);

        /* STEP2: Filter true-face + Tracking + Search feature */
        List<TrackBox> listTrackDetectResult = inputInference.getTrackList();
        List<TrackBox> listFinalResultTrack = new ArrayList<>();
        for (TrackBox trackBox : listTrackDetectResult) {
            if (!isTrueFace(trackBox)) {
                continue;
            }
            if (reuseRecognizedResult(trackBox) || recognizeNewTrack(trackBox)) {
                listFinalResultTrack.add(trackBox);
            }
        }

        /* STEP3: Rendering */
        /*if (EasySharedPreference.getInstance().getIsDrawAi()) {
            Renderer.renderAllBoxFace(
                    new Canvas(inputInference.getProcessBitmap()),
                    null,
                    listTrackDetectResult,
                    listFinalResultTrack,
                    fps
            );
        }*/

        return new TrackResult(inputInference, listFinalResultTrack);
    }

    /** Runs the true-face classifier on the box and returns whether it is accepted as a real face. */
    private boolean isTrueFace(TrackBox trackBox) {
        trueFaceClassifier.runInference(trackBox);
        TFClassifyResult tfResults = (TFClassifyResult) trackBox.getInferenceResult(trueFace);
        return tfResults != null && tfResults.getFinalProb() >= TRUE_FACE_PROB_THRES;
    }

    /**
     * If the track is already recognized, attaches the previous recognition result to the current box
     * and refreshes the cache entry.
     *
     * @return true when a cached result was reused
     */
    private boolean reuseRecognizedResult(TrackBox trackBox) {
        int trackId = trackBox.getTrackID();
        TrackBox cachedTrack = mapRecognizedCache.get(trackId);
        if (cachedTrack == null) {
            return false;
        }
        FaceRecognizeResult previousFaceRecognizeResult =
                (FaceRecognizeResult) cachedTrack.getInferenceResult(faceRecognized);
        if (previousFaceRecognizeResult == null) {
            // A recognized entry without a recognition result is unusable; drop it and search again.
            mapRecognizedCache.remove(trackId);
            return false;
        }
        trackBox.addResult(previousFaceRecognizeResult);
        mapRecognizedCache.put(trackId, trackBox);
        return true;
    }

    /**
     * Runs IQA, feature extraction and search for a track with no decided result, then accepts or votes.
     * A guaranteed-distance match is accepted immediately. Otherwise, if the IQA blur and eye-ratio checks
     * pass, the track receives one vote:
     * <ul>
     *   <li>for the searched label when the distance is below DISTANCE_ACCEPT;</li>
     *   <li>an unknown vote when the eye-distance check passes.</li>
     * </ul>
     *
     * @return true when the track reached a final decision on this frame
     */
    private boolean recognizeNewTrack(TrackBox trackBox) {
        int trackId = trackBox.getTrackID();

        /* Extract and search feature, calculate face IQA */
        IQAResult iqaResult = faceQualifier.runInference(trackBox);
        faceExtractor.runInference(trackBox);
        FaceRecognizeResult faceRecognizeResult = faceFeatureSearcher.runInference(trackBox);
        float minDistance = faceRecognizeResult.getMinDistance();

        if (minDistance < DISTANCE_GUARANTEE) {
            // Attach explicitly, as every other caching path does: reuseRecognizedResult() later reads it
            // back from the cached box. NOTE(review): the searcher may already attach it — harmless if so.
            trackBox.addResult(faceRecognizeResult);
            markRecognized(trackId, trackBox);
            return true;
        }
        // Written as a negation so NaN IQA values are rejected exactly as before.
        if (!(iqaResult.getBlur() > IQA_BLUR_THRESHOLD && iqaResult.getEyeRatio() > FDFR_EYES_MIN)) {
            return false;
        }
        if (minDistance < DISTANCE_ACCEPT) {
            // Unreliable but acceptable match: vote for the searched label, weighted by distance.
            return voteForLabel(trackId, trackBox, faceRecognizeResult, convertDistanceToScore(minDistance), false);
        }
        // The label must be unknown; only vote when the post-IQA eye distance check passes.
        // NOTE(review): mapUnknownTrackCache is never consulted before this search, so a track already decided
        // unknown is re-searched and re-voted on every frame — confirm this is intended.
        if (iqaResult.getEyeDistance() > EYE_DISTANCE_THRES) {
            return voteForLabel(trackId, trackBox, faceRecognizeResult, SCORE_UNKNOWN_DEFAULT, true);
        }
        return false;
    }

    /**
     * Adds one vote of {@code score} for the searcher's label to the track's pending tally in mapTmpTrackCache.
     * The first vote for a track is only recorded. On later votes the track is promoted, to the recognized or
     * unknown cache, once the accumulated score for the label reaches TOTAL_DECISION_SCORE.
     *
     * @param isUnknownVote true to promote into mapUnknownTrackCache instead of mapRecognizedCache
     * @return true when the track was promoted by this vote
     */
    private boolean voteForLabel(int trackId, TrackBox trackBox, FaceRecognizeResult faceRecognizeResult,
                                 float score, boolean isUnknownVote) {
        String currentLabel = faceRecognizeResult.getLabel();
        TrackBox pendingTrack = mapTmpTrackCache.get(trackId);
        if (pendingTrack == null) {
            // New track: start its tally with this vote (the same score scale as later votes).
            faceRecognizeResult.setScoreMapByLabel(currentLabel, score);
            trackBox.addResult(faceRecognizeResult);
            mapTmpTrackCache.put(trackId, trackBox);
            return false;
        }

        // Existing track: stack this vote onto the cached score map and carry it to the new box.
        FaceRecognizeResult cacheFRResult = (FaceRecognizeResult) pendingTrack.getInferenceResult(faceRecognized);
        cacheFRResult.setScoreMapByLabel(currentLabel, score);
        if (cacheFRResult.getScoreMapByLabel(currentLabel) >= TOTAL_DECISION_SCORE) {
            cacheFRResult.setLabel(currentLabel);
            trackBox.addResult(cacheFRResult);
            if (isUnknownVote) {
                mapUnknownTrackCache.put(trackId, trackBox);
                mapTmpTrackCache.remove(trackId);
            } else {
                markRecognized(trackId, trackBox);
            }
            return true;
        }
        trackBox.addResult(cacheFRResult);
        mapTmpTrackCache.put(trackId, trackBox);
        return false;
    }

    /** Stores the track as recognized and removes it from the pending and unknown caches. */
    private void markRecognized(int trackId, TrackBox trackBox) {
        mapRecognizedCache.put(trackId, trackBox);
        mapTmpTrackCache.remove(trackId);
        mapUnknownTrackCache.remove(trackId);
    }

    /** Evicts every cached track whose age (now - getTimeLife()) exceeds TRACK_TIMEOUT. */
    private void removeTimeoutTrack() {
        long currentTime = System.currentTimeMillis();
        removeExpired(mapTmpTrackCache, currentTime);
        removeExpired(mapUnknownTrackCache, currentTime);
        removeExpired(mapRecognizedCache, currentTime);
    }

    /** Removes entries of {@code cache} older than TRACK_TIMEOUT relative to {@code currentTime}. */
    private static void removeExpired(Map<Integer, TrackBox> cache, long currentTime) {
        Iterator<TrackBox> iterator = cache.values().iterator();
        while (iterator.hasNext()) {
            if (currentTime - iterator.next().getTimeLife() > TRACK_TIMEOUT) {
                iterator.remove();
            }
        }
    }

    /**
     * Recomputes ALPHA_WEIGHT, TOTAL_DECISION_SCORE and SCORE_UNKNOWN_DEFAULT from the distance and vote-weight
     * parameters. The choice makes MIN_VOTE_WEIGHT * score(DISTANCE_GUARANTEE) and
     * MAX_VOTE_WEIGHT * score(DISTANCE_ACCEPT) both equal TOTAL_DECISION_SCORE.
     * Mutates the shared static fields.
     */
    public void caculateVoteWeightIndex() {
        ALPHA_WEIGHT = LIMIT_MAX_DISTANCE * (MAX_VOTE_WEIGHT - MIN_VOTE_WEIGHT) / (MAX_VOTE_WEIGHT * DISTANCE_ACCEPT - MIN_VOTE_WEIGHT * DISTANCE_GUARANTEE);
        TOTAL_DECISION_SCORE = MIN_VOTE_WEIGHT * (LIMIT_MAX_DISTANCE - ALPHA_WEIGHT * DISTANCE_GUARANTEE);
        // Must run after ALPHA_WEIGHT is updated, because convertDistanceToScore reads it.
        SCORE_UNKNOWN_DEFAULT = convertDistanceToScore(DISTANCE_GUARANTEE);
    }

    /** Converts a feature distance into a vote score: LIMIT_MAX_DISTANCE - ALPHA_WEIGHT * distance. */
    public static float convertDistanceToScore(float distance) {
        return LIMIT_MAX_DISTANCE - ALPHA_WEIGHT * distance;
    }

    /** Releases every algorithm object that was created and clears all track caches. */
    @Override
    public void onProcessEnd() {
        if (faceDetector != null) {
            faceDetector.release();
        }
        if (faceExtractor != null) {
            faceExtractor.release();
        }
        if (trueFaceClassifier != null) {
            trueFaceClassifier.release();
        }
        if (faceAligner != null) {
            faceAligner.release();
        }
        if (sortTracker != null) {
            sortTracker.release();
        }
        if (faceQualifier != null) {
            faceQualifier.release();
        }
        if (faceFeatureSearcher != null) {
            faceFeatureSearcher.release();
        }
        mapTmpTrackCache.clear();
        mapUnknownTrackCache.clear();
        mapRecognizedCache.clear();
    }
}
