package smile.clustering;

import java.lang.reflect.Array;
import java.util.function.ToDoubleBiFunction;
import java.util.stream.IntStream;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import smile.math.MathEx;
import smile.math.distance.Distance;

/* loaded from: input_file:smile/clustering/CLARANS.class */
/**
 * Clustering Large Applications based upon RANdomized Search (CLARANS).
 * <p>
 * A k-medoids method: starting from seeded medoids, it repeatedly replaces one
 * randomly chosen medoid with a random non-medoid data point (a "random
 * neighbor" of the current configuration) and keeps the swap whenever the total
 * distortion decreases. The search stops after {@code maxNeighbor} consecutive
 * random neighbors fail to improve the distortion.
 * <p>
 * Medoids are actual data points, so only a {@link Distance} between data
 * points is required.
 *
 * @param <T> the type of the data and of the medoids.
 */
public class CLARANS<T> extends CentroidClustering<T, T> {
    private static final long serialVersionUID = 2;
    private static final Logger logger = LoggerFactory.getLogger(CLARANS.class);

    /** The distance measure between data points and medoids. */
    private final Distance<T> distance;

    /**
     * Constructor.
     *
     * @param distortion the total distortion, i.e. the sum of the distances of
     *                   the samples to their medoids.
     * @param medoids    the medoid of each cluster.
     * @param y          the cluster label of each sample.
     * @param distance   the distance measure.
     */
    public CLARANS(double distortion, T[] medoids, int[] y, Distance<T> distance) {
        super(distortion, medoids, y);
        this.distance = distance;
    }

    @Override // smile.clustering.CentroidClustering
    protected double distance(T x, T y) {
        return this.distance.d(x, y);
    }

    /**
     * Clusters the data into k clusters with a default search budget of
     * {@code round(0.0125 * k * (n - k))} random neighbors, clamped into
     * {@code [1, n]} so the budget itself is never rejected.
     *
     * @param data     the data set.
     * @param distance the distance measure.
     * @param k        the number of clusters.
     * @param <T>      the type of the data.
     * @return the clustering model.
     * @throws IllegalArgumentException if {@code k < 1} or {@code k >= data.length}.
     */
    public static <T> CLARANS<T> fit(T[] data, Distance<T> distance, int k) {
        int n = data.length;
        long heuristic = Math.round(0.0125d * k * (n - k));
        // The heuristic is 0 for small data sets and exceeds n for large k; both
        // would be rejected by the 4-argument overload, so clamp it into range.
        int maxNeighbor = (int) Math.max(1, Math.min(n, heuristic));
        return fit(data, distance, k, maxNeighbor);
    }

    /**
     * Clusters the data into k clusters.
     *
     * @param data        the data set.
     * @param distance    the distance measure.
     * @param k           the number of clusters.
     * @param maxNeighbor the number of consecutive non-improving random
     *                    neighbors after which the search stops. It is raised
     *                    to at least {@code min(100, k * (n - k))}.
     * @param <T>         the type of the data.
     * @return the clustering model.
     * @throws IllegalArgumentException if {@code k < 1}, {@code k >= data.length},
     *                                  {@code maxNeighbor <= 0} or
     *                                  {@code maxNeighbor > data.length}.
     */
    public static <T> CLARANS<T> fit(T[] data, Distance<T> distance, int k, int maxNeighbor) {
        if (k < 1) {
            throw new IllegalArgumentException("Invalid number of clusters: " + k);
        }
        if (maxNeighbor <= 0) {
            throw new IllegalArgumentException("Invalid maxNeighbors: " + maxNeighbor);
        }
        int n = data.length;
        if (k >= n) {
            throw new IllegalArgumentException("Too large k: " + k);
        }
        if (maxNeighbor > n) {
            throw new IllegalArgumentException("Too large maxNeighbor: " + maxNeighbor);
        }

        // Examine at least min(100, k(n-k)) neighbors. k(n-k) is the number of
        // possible single swaps; computed in long to avoid int overflow.
        long swaps = (long) k * (n - k);
        maxNeighbor = (int) Math.max(maxNeighbor, Math.min(100L, swaps));

        @SuppressWarnings("unchecked") // the array is created with data's own component type
        T[] medoids = (T[]) Array.newInstance(data.getClass().getComponentType(), k);
        int[] y = new int[n];
        double[] d = seed(data, medoids, y, distance);
        double distortion = MathEx.sum(d);

        // Working copies mutated by each random-neighbor trial; reverted on failure.
        T[] trialMedoids = medoids.clone();
        int[] trialY = y.clone();
        double[] trialD = d.clone();

        int neighbor = 1;
        while (neighbor <= maxNeighbor) {
            double trialDistortion = getRandomNeighbor(data, trialMedoids, trialY, trialD, distance);
            if (trialDistortion < distortion) {
                logger.info("Distortion after {} random neighbors reduces to {} ", neighbor, trialDistortion);
                distortion = trialDistortion;
                System.arraycopy(trialMedoids, 0, medoids, 0, k);
                System.arraycopy(trialY, 0, y, 0, n);
                System.arraycopy(trialD, 0, d, 0, n);
                // An improvement restarts the count of consecutive failures.
                neighbor = 1;
            } else {
                System.arraycopy(medoids, 0, trialMedoids, 0, k);
                System.arraycopy(y, 0, trialY, 0, n);
                System.arraycopy(d, 0, trialD, 0, n);
                neighbor++;
            }
        }

        logger.info("Final distortion: {}", distortion);
        return new CLARANS<>(distortion, medoids, y, distance);
    }

    /**
     * Moves to a random neighbor of the current configuration. One randomly
     * chosen medoid is replaced by a random non-medoid data point, and the
     * labels and distances are updated in place.
     *
     * @param data     the data set.
     * @param medoids  the current medoids; one entry is overwritten.
     * @param y        the cluster label of each sample; updated in place.
     * @param d        the distance of each sample to its medoid; updated in place.
     * @param distance the distance measure.
     * @param <T>      the type of the data.
     * @return the total distortion of the new configuration.
     */
    private static <T> double getRandomNeighbor(T[] data, T[] medoids, int[] y, double[] d,
                                                ToDoubleBiFunction<T, T> distance) {
        int n = data.length;
        int k = medoids.length;

        int cluster = MathEx.randomInt(k);
        T medoid = getRandomMedoid(data, medoids);
        medoids[cluster] = medoid;

        // Each task writes only y[i] and d[i], and medoids is not modified
        // while the stream runs.
        IntStream.range(0, n).parallel().forEach(i -> {
            double dist = distance.applyAsDouble(data[i], medoid);
            if (d[i] > dist) {
                // The new medoid is closer than the sample's current one.
                y[i] = cluster;
                d[i] = dist;
            } else if (y[i] == cluster) {
                // The sample's medoid was replaced by a farther point: find the
                // nearest medoid among all of them.
                d[i] = dist;
                for (int j = 0; j < k; j++) {
                    if (j != cluster) {
                        double dj = distance.applyAsDouble(data[i], medoids[j]);
                        if (d[i] > dj) {
                            y[i] = j;
                            d[i] = dj;
                        }
                    }
                }
            }
        });

        return MathEx.sum(d);
    }

    /**
     * Returns a data point chosen uniformly at random among the positions whose
     * element is not (by reference) one of the current medoids.
     *
     * @throws IllegalStateException if every data element is already a medoid,
     *                               which can happen when data holds repeated
     *                               references to the same object.
     */
    private static <T> T getRandomMedoid(T[] data, T[] medoids) {
        int n = data.length;
        for (int attempt = 1; ; attempt++) {
            T candidate = data[MathEx.randomInt(n)];
            if (!contains(medoids, candidate)) {
                return candidate;
            }
            // Rejection sampling normally succeeds quickly. After n misses, verify
            // once that a candidate exists at all so the loop cannot spin forever.
            if (attempt == n && !hasNonMedoid(data, medoids)) {
                throw new IllegalStateException("Every data point is already a medoid");
            }
        }
    }

    /** Returns true if some element of data is not (by reference) a medoid. */
    private static <T> boolean hasNonMedoid(T[] data, T[] medoids) {
        for (T x : data) {
            if (!contains(medoids, x)) {
                return true;
            }
        }
        return false;
    }

    /** Returns true if the array holds the very same object (reference equality). */
    private static <T> boolean contains(T[] array, T x) {
        for (T element : array) {
            if (element == x) {
                return true;
            }
        }
        return false;
    }
}
