package smile.graph;

import java.lang.invoke.MethodHandles;
import java.lang.invoke.MethodType;
import java.lang.invoke.SerializedLambda;
import java.lang.runtime.ObjectMethods;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.PriorityQueue;
import java.util.Set;
import java.util.stream.IntStream;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import smile.math.MathEx;
import smile.math.distance.Distance;
import smile.math.distance.Metric;
import smile.neighbor.RandomProjectionTree;

/* loaded from: input_file:smile/graph/NearestNeighborGraph.class */
public final class NearestNeighborGraph extends Record {
    private final int k;
    private final int[][] neighbors;
    private final double[][] distances;
    private final int[] index;
    private static final Logger logger = LoggerFactory.getLogger((Class<?>) NearestNeighborGraph.class);

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:smile/graph/NearestNeighborGraph$CandidateGenerator.class */
    public interface CandidateGenerator {
        int[] generate(int i, int i2, int i3);
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:smile/graph/NearestNeighborGraph$Neighbor.class */
    public static class Neighbor implements Comparable<Neighbor> {
        public int index;
        public double distance;

        public Neighbor(int i, double d) {
            this.index = i;
            this.distance = d;
        }

        public int hashCode() {
            return this.index;
        }

        @Override // java.lang.Comparable
        public int compareTo(Neighbor neighbor) {
            return Double.compare(neighbor.distance, this.distance);
        }
    }

    public NearestNeighborGraph(int i, int[][] iArr, double[][] dArr) {
        this(i, iArr, dArr, IntStream.range(0, iArr.length).toArray());
    }

    public NearestNeighborGraph(int i, int[][] iArr, double[][] dArr, int[] iArr2) {
        this.k = i;
        this.neighbors = iArr;
        this.distances = dArr;
        this.index = iArr2;
    }

    public int size() {
        return this.neighbors.length;
    }

    public AdjacencyList graph(boolean z) {
        int length = this.neighbors.length;
        AdjacencyList adjacencyList = new AdjacencyList(length, z);
        IntStream.range(0, length).forEach(i -> {
            int[] iArr = this.neighbors[i];
            double[] dArr = this.distances[i];
            for (int i = 0; i < iArr.length; i++) {
                adjacencyList.setWeight(i, iArr[i], dArr[i]);
            }
        });
        return adjacencyList;
    }

    public static NearestNeighborGraph of(double[][] dArr, int i) {
        return of(dArr, MathEx::distance, i);
    }

    public NearestNeighborGraph largest(boolean z) {
        int[][] bfcc = graph(z).bfcc();
        if (bfcc.length == 1) {
            return this;
        }
        int[] iArr = (int[]) Arrays.stream(bfcc).max(Comparator.comparing(iArr2 -> {
            return Integer.valueOf(iArr2.length);
        })).orElseThrow(NoSuchElementException::new);
        logger.info("{} connected components, largest one has {} samples.", Integer.valueOf(bfcc.length), Integer.valueOf(iArr.length));
        int length = this.neighbors.length;
        int[] iArr3 = new int[length];
        for (int i = 0; i < length; i++) {
            iArr3[iArr[i]] = i;
        }
        int[][] iArr4 = new int[length][this.k];
        double[][] dArr = new double[length][this.k];
        for (int i2 = 0; i2 < length; i2++) {
            dArr[i2] = this.distances[iArr[i2]];
            int[] iArr5 = this.neighbors[iArr[i2]];
            for (int i3 = 0; i3 < this.k; i3++) {
                iArr4[i2][i3] = iArr3[iArr5[i3]];
            }
        }
        return new NearestNeighborGraph(this.k, iArr4, dArr, iArr);
    }

    public static <T> NearestNeighborGraph of(T[] tArr, Distance<T> distance, int i) {
        return toGraph(build(tArr, distance, i, (i2, i3, i4) -> {
            return IntStream.range(0, i2).toArray();
        }), i);
    }

    public static <T> NearestNeighborGraph random(T[] tArr, Distance<T> distance, int i) {
        List<PriorityQueue<Neighbor>> build = build(tArr, distance, i, NearestNeighborGraph::rejectionSample);
        extend(build);
        return toGraph(build, i);
    }

    private static <T> List<PriorityQueue<Neighbor>> build(T[] tArr, Distance<T> distance, int i, CandidateGenerator candidateGenerator) {
        if (i < 2) {
            throw new IllegalArgumentException("k must be greater than 1: " + i);
        }
        int length = tArr.length;
        ArrayList arrayList = new ArrayList(length);
        for (int i2 = 0; i2 < length; i2++) {
            arrayList.add(new PriorityQueue());
        }
        IntStream.range(0, length).parallel().forEach(i3 -> {
            Object obj = tArr[i3];
            PriorityQueue priorityQueue = (PriorityQueue) arrayList.get(i3);
            for (int i3 : candidateGenerator.generate(length, i, i3)) {
                if (i3 != i3) {
                    double d = distance.d(obj, tArr[i3]);
                    if (priorityQueue.size() < i) {
                        priorityQueue.offer(new Neighbor(i3, d));
                    } else if (d < ((Neighbor) priorityQueue.peek()).distance) {
                        Neighbor neighbor = (Neighbor) priorityQueue.poll();
                        neighbor.index = i3;
                        neighbor.distance = d;
                        priorityQueue.offer(neighbor);
                    }
                }
            }
        });
        return arrayList;
    }

    private static void extend(List<PriorityQueue<Neighbor>> list) {
        int size = list.size();
        ArrayList arrayList = new ArrayList(size);
        ArrayList arrayList2 = new ArrayList(size);
        for (int i = 0; i < size; i++) {
            arrayList.add(new HashSet());
            arrayList2.add(new HashSet());
        }
        for (int i2 = 0; i2 < size; i2++) {
            Set set = (Set) arrayList.get(i2);
            Iterator<Neighbor> it = list.get(i2).iterator();
            while (it.hasNext()) {
                Neighbor next = it.next();
                set.add(Integer.valueOf(next.index));
                ((Set) arrayList2.get(next.index)).add(new Neighbor(i2, next.distance));
            }
        }
        for (int i3 = 0; i3 < size; i3++) {
            Set set2 = (Set) arrayList.get(i3);
            PriorityQueue<Neighbor> priorityQueue = list.get(i3);
            for (Neighbor neighbor : (Set) arrayList2.get(i3)) {
                if (!set2.contains(Integer.valueOf(neighbor.index)) && neighbor.distance < priorityQueue.peek().distance) {
                    Neighbor poll = priorityQueue.poll();
                    poll.index = neighbor.index;
                    poll.distance = neighbor.distance;
                    priorityQueue.offer(poll);
                }
            }
        }
    }

    private static NearestNeighborGraph toGraph(List<PriorityQueue<Neighbor>> list, int i) {
        int size = list.size();
        int[][] iArr = new int[size][i];
        double[][] dArr = new double[size][i];
        for (int i2 = 0; i2 < size; i2++) {
            PriorityQueue<Neighbor> priorityQueue = list.get(i2);
            int size2 = priorityQueue.size();
            while (!priorityQueue.isEmpty()) {
                Neighbor poll = priorityQueue.poll();
                size2--;
                if (size2 < i) {
                    iArr[i2][size2] = poll.index;
                    dArr[i2][size2] = poll.distance;
                }
            }
        }
        return new NearestNeighborGraph(i, iArr, dArr);
    }

    public static NearestNeighborGraph descent(double[][] dArr, int i) {
        return descent(dArr, i, 5, i, 50, 50, 0.001d);
    }

    public static NearestNeighborGraph descent(double[][] dArr, int i, int i2, int i3, int i4, int i5, double d) {
        int length = dArr.length;
        ArrayList arrayList = new ArrayList(dArr.length);
        ArrayList arrayList2 = new ArrayList(dArr.length);
        for (int i6 = 0; i6 < dArr.length; i6++) {
            arrayList.add(new PriorityQueue());
            arrayList2.add(new HashSet());
        }
        for (int i7 = 0; i7 < i2; i7++) {
            for (int[] iArr : RandomProjectionTree.of(dArr, i3, false).leafSamples()) {
                for (int i8 = 0; i8 < iArr.length; i8++) {
                    int i9 = iArr[i8];
                    double[] dArr2 = dArr[i9];
                    for (int i10 = i8 + 1; i10 < iArr.length; i10++) {
                        int i11 = iArr[i10];
                        double distance = MathEx.distance(dArr2, dArr[i11]);
                        updateHeap((PriorityQueue) arrayList.get(i9), (Set) arrayList2.get(i9), i, i11, distance);
                        updateHeap((PriorityQueue) arrayList.get(i11), (Set) arrayList2.get(i11), i, i9, distance);
                    }
                }
            }
        }
        return descent(dArr, MathEx::distance, arrayList, i, i4, i5, d);
    }

    private static boolean updateHeap(PriorityQueue<Neighbor> priorityQueue, Set<Integer> set, int i, int i2, double d) {
        if (set.contains(Integer.valueOf(i2))) {
            return false;
        }
        if (priorityQueue.size() < i) {
            priorityQueue.add(new Neighbor(i2, d));
            set.add(Integer.valueOf(i2));
            return true;
        }
        if (d >= priorityQueue.peek().distance) {
            return false;
        }
        Neighbor poll = priorityQueue.poll();
        set.remove(Integer.valueOf(poll.index));
        set.add(Integer.valueOf(i2));
        poll.distance = d;
        poll.index = i2;
        priorityQueue.offer(poll);
        return true;
    }

    public static <T> NearestNeighborGraph descent(T[] tArr, Metric<T> metric, int i) {
        return descent(tArr, metric, i, 50, 10, 0.001d);
    }

    public static <T> NearestNeighborGraph descent(T[] tArr, Metric<T> metric, int i, int i2, int i3, double d) {
        if (i < 2) {
            throw new IllegalArgumentException("k must be greater than 1: " + i);
        }
        List<PriorityQueue<Neighbor>> build = build(tArr, metric, i, NearestNeighborGraph::rejectionSample);
        extend(build);
        return descent(tArr, metric, build, i, i2, i3, d);
    }

    private static <T> NearestNeighborGraph descent(T[] tArr, Metric<T> metric, List<PriorityQueue<Neighbor>> list, int i, int i2, int i3, double d) {
        int length = tArr.length;
        ArrayList arrayList = new ArrayList(tArr.length);
        for (int i4 = 0; i4 < tArr.length; i4++) {
            arrayList.add(new HashSet());
        }
        for (int i5 = 0; i5 < length; i5++) {
            Set set = (Set) arrayList.get(i5);
            Iterator<Neighbor> it = list.get(i5).iterator();
            while (it.hasNext()) {
                set.add(Integer.valueOf(it.next().index));
            }
        }
        for (int i6 = 1; i6 <= i3; i6++) {
            int i7 = 0;
            int[][] generateCandidates = generateCandidates(list, i2);
            for (int i8 = 0; i8 < length; i8++) {
                for (int i9 : generateCandidates[i8]) {
                    double d2 = metric.d(tArr[i8], tArr[i9]);
                    if (updateHeap(list.get(i8), (Set) arrayList.get(i8), i, i9, d2)) {
                        i7++;
                    }
                    if (updateHeap(list.get(i9), (Set) arrayList.get(i9), i, i8, d2)) {
                        i7++;
                    }
                }
            }
            logger.info("NearestNeighborDescent iteration {}: {}", Integer.valueOf(i6), Integer.valueOf(i7));
            if (i7 <= d * i * length) {
                break;
            }
        }
        return toGraph(list, i);
    }

    /* JADX WARN: Type inference failed for: r0v8, types: [int[], int[][]] */
    private static int[][] generateCandidates(List<PriorityQueue<Neighbor>> list, int i) {
        int size = list.size();
        ArrayList arrayList = new ArrayList(size);
        for (int i2 = 0; i2 < size; i2++) {
            arrayList.add(new HashSet());
        }
        for (int i3 = 0; i3 < size; i3++) {
            Iterator<Neighbor> it = list.get(i3).iterator();
            while (it.hasNext()) {
                Neighbor next = it.next();
                int i4 = next.index;
                double d = next.distance;
                Iterator<Neighbor> it2 = list.get(i4).iterator();
                while (it2.hasNext()) {
                    Neighbor next2 = it2.next();
                    int i5 = next2.index;
                    double d2 = next2.distance;
                    ((Set) arrayList.get(i3)).add(new Neighbor(i5, d + d2));
                    ((Set) arrayList.get(i5)).add(new Neighbor(i3, d + d2));
                }
            }
        }
        ?? r0 = new int[size];
        for (int i6 = 0; i6 < size; i6++) {
            ArrayList arrayList2 = new ArrayList((Collection) arrayList.get(i6));
            arrayList2.sort(Comparator.comparingDouble(neighbor -> {
                return neighbor.distance;
            }));
            r0[i6] = arrayList2.stream().limit(i).mapToInt(neighbor2 -> {
                return neighbor2.index;
            }).toArray();
        }
        return r0;
    }

    private static int[] rejectionSample(int i, int i2, int i3) {
        if (i2 > i) {
            throw new IllegalArgumentException();
        }
        int[] iArr = new int[i2];
        for (int i4 = 0; i4 < i2; i4++) {
            boolean z = true;
            while (z) {
                z = false;
                iArr[i4] = MathEx.randomInt(i);
                if (iArr[i4] == i3) {
                    z = true;
                } else {
                    int i5 = 0;
                    while (true) {
                        if (i5 >= i4) {
                            break;
                        }
                        if (iArr[i4] == iArr[i5]) {
                            z = true;
                            break;
                        }
                        i5++;
                    }
                }
            }
        }
        return iArr;
    }

    @Override // java.lang.Record
    public final String toString() {
        return (String) ObjectMethods.bootstrap(MethodHandles.lookup(), "toString", MethodType.methodType(String.class, NearestNeighborGraph.class), NearestNeighborGraph.class, "k;neighbors;distances;index", "FIELD:Lsmile/graph/NearestNeighborGraph;->k:I", "FIELD:Lsmile/graph/NearestNeighborGraph;->neighbors:[[I", "FIELD:Lsmile/graph/NearestNeighborGraph;->distances:[[D", "FIELD:Lsmile/graph/NearestNeighborGraph;->index:[I").dynamicInvoker().invoke(this) /* invoke-custom */;
    }

    @Override // java.lang.Record
    public final int hashCode() {
        return (int) ObjectMethods.bootstrap(MethodHandles.lookup(), "hashCode", MethodType.methodType(Integer.TYPE, NearestNeighborGraph.class), NearestNeighborGraph.class, "k;neighbors;distances;index", "FIELD:Lsmile/graph/NearestNeighborGraph;->k:I", "FIELD:Lsmile/graph/NearestNeighborGraph;->neighbors:[[I", "FIELD:Lsmile/graph/NearestNeighborGraph;->distances:[[D", "FIELD:Lsmile/graph/NearestNeighborGraph;->index:[I").dynamicInvoker().invoke(this) /* invoke-custom */;
    }

    @Override // java.lang.Record
    public final boolean equals(Object obj) {
        return (boolean) ObjectMethods.bootstrap(MethodHandles.lookup(), "equals", MethodType.methodType(Boolean.TYPE, NearestNeighborGraph.class, Object.class), NearestNeighborGraph.class, "k;neighbors;distances;index", "FIELD:Lsmile/graph/NearestNeighborGraph;->k:I", "FIELD:Lsmile/graph/NearestNeighborGraph;->neighbors:[[I", "FIELD:Lsmile/graph/NearestNeighborGraph;->distances:[[D", "FIELD:Lsmile/graph/NearestNeighborGraph;->index:[I").dynamicInvoker().invoke(this, obj) /* invoke-custom */;
    }

    public int k() {
        return this.k;
    }

    public int[][] neighbors() {
        return this.neighbors;
    }

    public double[][] distances() {
        return this.distances;
    }

    public int[] index() {
        return this.index;
    }

    private static /* synthetic */ Object $deserializeLambda$(SerializedLambda serializedLambda) {
        String implMethodName = serializedLambda.getImplMethodName();
        boolean z = -1;
        switch (implMethodName.hashCode()) {
            case 288459765:
                if (implMethodName.equals("distance")) {
                    z = false;
                    break;
                }
                break;
        }
        switch (z) {
            case false:
                if (serializedLambda.getImplMethodKind() == 6 && serializedLambda.getFunctionalInterfaceClass().equals("smile/math/distance/Distance") && serializedLambda.getFunctionalInterfaceMethodName().equals("d") && serializedLambda.getFunctionalInterfaceMethodSignature().equals("(Ljava/lang/Object;Ljava/lang/Object;)D") && serializedLambda.getImplClass().equals("smile/math/MathEx") && serializedLambda.getImplMethodSignature().equals("([D[D)D")) {
                    return MathEx::distance;
                }
                if (serializedLambda.getImplMethodKind() == 6 && serializedLambda.getFunctionalInterfaceClass().equals("smile/math/distance/Metric") && serializedLambda.getFunctionalInterfaceMethodName().equals("d") && serializedLambda.getFunctionalInterfaceMethodSignature().equals("(Ljava/lang/Object;Ljava/lang/Object;)D") && serializedLambda.getImplClass().equals("smile/math/MathEx") && serializedLambda.getImplMethodSignature().equals("([D[D)D")) {
                    return MathEx::distance;
                }
                break;
        }
        throw new IllegalArgumentException("Invalid lambda deserialization");
    }
}
