package com.qualcomm.qti.snpe.imageclassifiers.thread;

import com.qualcomm.qti.snpe.imageclassifiers.detector.Bbox;

/**
 * Post-processing helper that decides whether two predicted bounding boxes "match",
 * based on an overlap metric compared against a threshold and, optionally, on their labels.
 *
 * <p>The only overlap metric implemented here is {@code "IOS"} (intersection over the
 * smaller of the two box areas).
 */
public class PostProcessPrediction {

    /** Strategy for scoring the overlap between two predicted boxes. */
    public interface ICalculateMatch {
        /**
         * Computes an overlap score for two boxes.
         *
         * @param mPred1 first box
         * @param mPred2 second box
         * @return the overlap score; {@link #hasMatch} requires it to exceed the match threshold
         */
        float run(Bbox mPred1, Bbox mPred2);
    }

    /** Metric name selecting intersection-over-smaller-area scoring. */
    private static final String METRIC_IOS = "IOS";

    /** Score must be strictly greater than this for two boxes to match. */
    float matchThreshold;
    /** Name of the overlap metric requested at construction time. */
    String matchMetric;
    /** When true, labels are ignored and only the overlap score decides a match. */
    boolean classAgnostic;
    /** Scoring strategy selected from {@link #matchMetric}; null if the metric is unsupported. */
    ICalculateMatch calculateMatch;

    /**
     * Creates a matcher.
     *
     * @param matchThreshold_ score a pair must strictly exceed to match
     * @param matchMetric_    overlap metric name; only {@code "IOS"} is currently supported
     * @param classAgnostic_  if true, boxes may match regardless of their labels
     */
    public PostProcessPrediction(float matchThreshold_, String matchMetric_, boolean classAgnostic_) {
        this.matchThreshold = matchThreshold_;
        this.matchMetric = matchMetric_;
        this.classAgnostic = classAgnostic_;

        // Compare by content, not by reference: metric names read from config or built at
        // runtime are distinct String objects and would never be == to the literal.
        if (METRIC_IOS.equals(this.matchMetric)) {
            this.calculateMatch = new ICalculateMatch() {
                @Override
                public float run(Bbox mPred1, Bbox mPred2) {
                    return calculateBoxIOS(mPred1, mPred2);
                }
            };
        }
    }

    /**
     * Computes the area of a box as {@code (x2 - x1) * (y2 - y1)}.
     *
     * <p>Assumes (x1, y1) is the left-top and (x2, y2) the right-bottom corner; an inverted
     * box therefore yields a negative (or zero) area.
     *
     * @param mBbox the box
     * @return the signed area
     */
    public float calculateArea(Bbox mBbox) {
        return (mBbox.x2 - mBbox.x1) * (mBbox.y2 - mBbox.y1);
    }

    /**
     * Computes the area of the overlap between two boxes.
     *
     * @param mBbox1 first box
     * @param mBbox2 second box
     * @return the intersection area, or 0 if the boxes do not overlap
     */
    public float calculateIntersectionArea(Bbox mBbox1, Bbox mBbox2) {
        float left = Math.max(mBbox1.x1, mBbox2.x1);
        float top = Math.max(mBbox1.y1, mBbox2.y1);
        float right = Math.min(mBbox1.x2, mBbox2.x2);
        float bottom = Math.min(mBbox1.y2, mBbox2.y2);
        // Clamp at zero so disjoint boxes produce no (negative) overlap.
        float width = Math.max(right - left, 0f);
        float height = Math.max(bottom - top, 0f);
        return width * height;
    }

    /**
     * Computes intersection-over-smaller-area (IOS) for two boxes.
     *
     * @param mPred1 first box
     * @param mPred2 second box
     * @return intersection area divided by the smaller of the two box areas, or 0 when the
     *         smaller area is not positive (degenerate or inverted box)
     */
    public float calculateBoxIOS(Bbox mPred1, Bbox mPred2) {
        float mArea1 = this.calculateArea(mPred1);
        float mArea2 = this.calculateArea(mPred2);
        float smallerArea = Math.min(mArea1, mArea2);
        // Guard against 0/0 (NaN) and division by a negative area from an inverted box.
        if (smallerArea <= 0f) {
            return 0f;
        }
        float intersect = calculateIntersectionArea(mPred1, mPred2);
        return intersect / smallerArea;
    }

    /**
     * Decides whether two predictions match.
     *
     * <p>A match requires the configured metric's score to strictly exceed the threshold
     * and, unless {@link #classAgnostic} is set, the two labels to be equal.
     *
     * @param mPred1 first prediction
     * @param mPred2 second prediction
     * @return true if the predictions match
     * @throws IllegalStateException if this instance was built with an unsupported metric
     */
    public boolean hasMatch(Bbox mPred1, Bbox mPred2) {
        if (this.calculateMatch == null) {
            throw new IllegalStateException("Unsupported match metric: " + this.matchMetric);
        }
        boolean thresholdCondition = this.calculateMatch.run(mPred1, mPred2) > this.matchThreshold;
        // Check the flag first so labels are not even read when class is irrelevant.
        boolean categoryCondition = this.classAgnostic || labelsEqual(mPred1, mPred2);
        return thresholdCondition && categoryCondition;
    }

    /** Null-safe equality of the two boxes' labels. */
    private static boolean labelsEqual(Bbox a, Bbox b) {
        Object labelA = a.getLabel();
        Object labelB = b.getLabel();
        return labelA == null ? labelB == null : labelA.equals(labelB);
    }
}