package smile.validation.metric;

import smile.math.MathEx;

/* loaded from: input_file:smile/validation/metric/Recall.class */
/**
 * Recall (sensitivity, true positive rate) of a classifier: the fraction of
 * samples that truly belong to a class which were also predicted as that class.
 *
 * <p>Without an averaging strategy the metric is the binary recall of class
 * {@code 1}. With a strategy, the per-class results are combined as described
 * in {@link #of(int[], int[], Averaging)}.
 */
public class Recall implements ClassificationMetric {
    private static final long serialVersionUID = 2;

    /** Shared instance with no averaging strategy (binary recall of class 1). */
    public static final Recall instance = new Recall();

    /** Averaging strategy; {@code null} selects plain binary recall. */
    private final Averaging strategy;

    /** Creates a binary recall metric (no averaging strategy). */
    public Recall() {
        this(null);
    }

    /**
     * Creates a recall metric using the given averaging strategy.
     *
     * @param averaging the averaging strategy, or {@code null} for binary recall
     */
    public Recall(Averaging averaging) {
        this.strategy = averaging;
    }

    /**
     * Computes recall with this instance's averaging strategy.
     *
     * @param truth the true class labels
     * @param prediction the predicted class labels
     * @return the recall
     */
    @Override // smile.validation.metric.ClassificationMetric
    public double score(int[] truth, int[] prediction) {
        return of(truth, prediction, this.strategy);
    }

    /**
     * Returns {@code "Recall"} when no strategy is set, otherwise the strategy
     * name followed by {@code "-Recall"}.
     */
    @Override
    public String toString() {
        return this.strategy == null ? "Recall" : String.valueOf(this.strategy) + "-Recall";
    }

    /**
     * Computes binary recall after verifying that every label is 0 or 1.
     *
     * @param truth the true class labels, each 0 or 1
     * @param prediction the predicted class labels, each 0 or 1
     * @return the recall of class 1; NaN when {@code truth} contains no 1
     * @throws IllegalArgumentException if any label is neither 0 nor 1, or the
     *         arrays differ in length
     */
    public static double of(int[] truth, int[] prediction) {
        for (int label : truth) {
            if (label != 0 && label != 1) {
                throw new IllegalArgumentException("Recall can only be applied to binary classification: " + label);
            }
        }
        for (int label : prediction) {
            if (label != 0 && label != 1) {
                throw new IllegalArgumentException("Recall can only be applied to binary classification: " + label);
            }
        }
        return of(truth, prediction, null);
    }

    /**
     * Computes recall, optionally averaged over classes.
     *
     * <ul>
     *   <li>{@code null}: true positives of class 1 divided by the number of
     *       true class-1 samples.</li>
     *   <li>{@code Micro}: number of matching labels divided by the total
     *       number of samples.</li>
     *   <li>{@code Macro}: unweighted mean of the per-class recalls.</li>
     *   <li>{@code Weighted}: mean of the per-class recalls weighted by the
     *       number of true samples in each class.</li>
     * </ul>
     *
     * <p>The number of classes is taken as the largest label seen in either
     * array plus one. A class without any true samples has recall NaN (0/0),
     * and that NaN propagates into the macro and weighted averages.
     *
     * @param truth the true class labels (used as array indices, so they must
     *        be non-negative)
     * @param prediction the predicted class labels
     * @param averaging the averaging strategy; may be {@code null} only when
     *        there are at most two classes
     * @return the recall
     * @throws IllegalArgumentException if the arrays differ in length, or if
     *         {@code averaging} is {@code null} with more than two classes
     */
    public static double of(int[] truth, int[] prediction, Averaging averaging) {
        if (truth.length != prediction.length) {
            throw new IllegalArgumentException(String.format("The vector sizes don't match: %d != %d.", truth.length, prediction.length));
        }

        int numClasses = Math.max(MathEx.max(truth), MathEx.max(prediction)) + 1;
        if (numClasses > 2 && averaging == null) {
            throw new IllegalArgumentException("Averaging strategy is null for multi-class");
        }

        // Macro and weighted averaging need one counter per class; the binary
        // and micro cases accumulate into a single counter.
        boolean perClass = averaging == Averaging.Macro || averaging == Averaging.Weighted;
        int[] truePositives = new int[perClass ? numClasses : 1];
        int[] classSize = new int[numClasses];
        int length = truth.length;

        for (int label : truth) {
            classSize[label]++;
        }

        if (averaging == null) {
            for (int i = 0; i < length; i++) {
                if (prediction[i] == 1 && truth[i] == 1) {
                    truePositives[0]++;
                }
            }
        } else if (averaging == Averaging.Micro) {
            for (int i = 0; i < length; i++) {
                if (truth[i] == prediction[i]) {
                    truePositives[0]++;
                }
            }
        } else {
            for (int i = 0; i < length; i++) {
                if (truth[i] == prediction[i]) {
                    truePositives[truth[i]]++;
                }
            }
        }

        // The numerators are cast to double on purpose: int / int would
        // truncate every recall to 0 or 1 and throw on an empty class.
        double[] recall = new double[truePositives.length];
        if (truePositives.length == 1) {
            int denominator;
            if (averaging == null) {
                // When every label is 0 there is no slot for class 1 at all;
                // treat the positive count as 0 instead of indexing past the end.
                denominator = numClasses > 1 ? classSize[1] : 0;
            } else {
                denominator = length;
            }
            recall[0] = (double) truePositives[0] / denominator;
        } else {
            for (int c = 0; c < truePositives.length; c++) {
                recall[c] = (double) truePositives[c] / classSize[c];
            }
        }

        if (averaging == Averaging.Macro) {
            return MathEx.mean(recall);
        }
        if (averaging != Averaging.Weighted) {
            return recall[0];
        }

        double weightedSum = 0.0;
        for (int c = 0; c < numClasses; c++) {
            weightedSum += recall[c] * classSize[c];
        }
        return weightedSum / length;
    }
}
