package org.deeplearning4j.models.sequencevectors.graph.walkers.impl;

import java.util.Collections;
import java.util.Comparator;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Random;
import java.util.concurrent.atomic.AtomicInteger;
import lombok.Generated;
import lombok.NonNull;
import org.deeplearning4j.models.sequencevectors.graph.enums.SamplingMode;
import org.deeplearning4j.models.sequencevectors.graph.primitives.IGraph;
import org.deeplearning4j.models.sequencevectors.graph.primitives.Vertex;
import org.deeplearning4j.models.sequencevectors.graph.walkers.GraphWalker;
import org.deeplearning4j.models.sequencevectors.sequence.Sequence;
import org.deeplearning4j.models.sequencevectors.sequence.SequenceElement;
import org.nd4j.common.util.ArrayUtil;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/models/sequencevectors/graph/walkers/impl/NearestVertexWalker.class */
/**
 * Graph walker that emits one {@link Sequence} per vertex of the source graph.
 *
 * <p>Each sequence is labelled with the value of the starting vertex and is filled with the values
 * of vertices connected to it. When {@code walkLength} is 0 all connected vertices are used;
 * otherwise at most {@code walkLength} of them are picked according to the configured
 * {@link SamplingMode}. When {@code depth} is greater than 1 the walker additionally merges in the
 * walks of selected neighbours (see {@link #walk(Vertex, int)}).
 *
 * <p>Instances are created through {@link Builder}.
 *
 * @param <V> element type stored in the graph vertices
 */
public class NearestVertexWalker<V extends SequenceElement> implements GraphWalker<V> {

    @Generated
    private static final Logger log = LoggerFactory.getLogger(NearestVertexWalker.class);
    protected IGraph<V, ?> sourceGraph;
    /** Vertex indices in the order in which {@link #next()} visits them. */
    protected int[] order;
    protected Random rng;
    protected int depth;
    protected int walkLength = 0;
    protected long seed = 0;
    protected SamplingMode samplingMode = SamplingMode.RANDOM;
    /** Index into {@link #order} of the next vertex to be walked. */
    private final AtomicInteger position = new AtomicInteger(0);

    /**
     * Fluent builder for {@link NearestVertexWalker}.
     *
     * @param <V> element type stored in the graph vertices
     */
    public static class Builder<V extends SequenceElement> {
        protected IGraph<V, ?> sourceGraph;
        protected long seed;
        protected int walkLength = 0;
        protected SamplingMode samplingMode = SamplingMode.RANDOM;
        protected int depth = 1;

        /**
         * @param graph graph to walk over
         * @throws NullPointerException if {@code graph} is null
         */
        public Builder(@NonNull IGraph<V, ?> graph) {
            if (graph == null) {
                throw new NullPointerException("graph is marked non-null but is null");
            }
            this.sourceGraph = graph;
        }

        /**
         * Sets the seed of the random generator used to shuffle the visiting order.
         *
         * @param seed seed value
         * @return this builder
         */
        public Builder<V> setSeed(long seed) {
            this.seed = seed;
            return this;
        }

        /**
         * Sets the maximum number of connected vertices taken per walked vertex.
         * A value of 0 (the default) means every connected vertex is taken.
         *
         * @param walkLength number of neighbours to sample, or 0 for all
         * @return this builder
         */
        public Builder<V> setWalkLength(int walkLength) {
            this.walkLength = walkLength;
            return this;
        }

        /**
         * Sets the walk depth. Values above 1 enable recursive expansion of sampled neighbours;
         * the default is 1.
         *
         * @param depth maximum depth
         * @return this builder
         */
        public Builder<V> setDepth(int depth) {
            this.depth = depth;
            return this;
        }

        /**
         * Sets how neighbours are selected when {@code walkLength} is non-zero.
         *
         * @param mode sampling mode
         * @return this builder
         * @throws NullPointerException if {@code mode} is null
         */
        public Builder<V> setSamplingMode(@NonNull SamplingMode mode) {
            if (mode == null) {
                throw new NullPointerException("mode is marked non-null but is null");
            }
            this.samplingMode = mode;
            return this;
        }

        /**
         * Creates the walker, initialises its visiting order to all vertex indices of the graph
         * and shuffles that order once.
         *
         * @return a configured walker
         */
        public NearestVertexWalker<V> build() {
            NearestVertexWalker<V> walker = new NearestVertexWalker<>();
            walker.sourceGraph = this.sourceGraph;
            walker.walkLength = this.walkLength;
            walker.samplingMode = this.samplingMode;
            walker.depth = this.depth;
            // Keep the walker's own seed field in sync with the generator it is given below.
            walker.seed = this.seed;

            walker.order = new int[this.sourceGraph.numVertices()];
            for (int idx = 0; idx < walker.order.length; idx++) {
                walker.order[idx] = idx;
            }

            walker.rng = new Random(this.seed);
            walker.reset(true);
            return walker;
        }
    }

    /**
     * Orders vertices by the number of vertices connected to them, largest first.
     *
     * @param <V> element type stored in the graph vertices
     * @param <E> edge value type of the graph
     */
    public class VertexComparator<V extends SequenceElement, E extends Number> implements Comparator<Vertex<V>> {
        private final IGraph<V, E> graph;

        /**
         * @param graph graph used to look up the connections of compared vertices
         * @throws NullPointerException if {@code graph} is null
         */
        public VertexComparator(@NonNull IGraph<V, E> graph) {
            if (graph == null) {
                throw new NullPointerException("graph is marked non-null but is null");
            }
            this.graph = graph;
        }

        @Override // java.util.Comparator
        public int compare(Vertex<V> first, Vertex<V> second) {
            int firstDegree = this.graph.getConnectedVertices(first.vertexID()).size();
            int secondDegree = this.graph.getConnectedVertices(second.vertexID()).size();
            // Arguments are swapped on purpose: a higher connection count sorts earlier.
            return Integer.compare(secondDegree, firstDegree);
        }
    }

    protected NearestVertexWalker() {
    }

    /**
     * @return true while there are vertices in the visiting order that have not been walked yet
     */
    @Override // org.deeplearning4j.models.sequencevectors.graph.walkers.GraphWalker
    public boolean hasNext() {
        return this.position.get() < this.order.length;
    }

    /**
     * Walks the next vertex of the visiting order, starting at depth 1.
     *
     * <p>There is no exhaustion check here: calling this after {@link #hasNext()} has become false
     * indexes past the end of the order array.
     *
     * @return sequence produced for the next vertex
     */
    @Override // org.deeplearning4j.models.sequencevectors.graph.walkers.GraphWalker
    public Sequence<V> next() {
        int vertexIdx = this.order[this.position.getAndIncrement()];
        return walk(this.sourceGraph.getVertex(vertexIdx), 1);
    }

    /**
     * Rewinds the walker to the first vertex of the visiting order.
     *
     * @param shuffle if true, the visiting order is reshuffled (Fisher-Yates) using this walker's
     *                random generator
     */
    @Override // org.deeplearning4j.models.sequencevectors.graph.walkers.GraphWalker
    public void reset(boolean shuffle) {
        this.position.set(0);
        if (!shuffle) {
            return;
        }
        log.trace("Calling shuffle() on entries...");
        for (int last = this.order.length - 1; last > 0; last--) {
            int pick = this.rng.nextInt(last + 1);
            int tmp = this.order[pick];
            this.order[pick] = this.order[last];
            this.order[last] = tmp;
        }
    }

    /**
     * Builds the sequence for a single vertex.
     *
     * <p>The sequence label is the vertex's own value. With {@code walkLength == 0} every connected
     * vertex is added. Otherwise up to {@code walkLength} connected vertices are chosen by
     * {@code samplingMode}: the most connected ones, a window around the middle of the
     * popularity-sorted list, the least connected ones, or a random subset. Requests larger than
     * the number of connected vertices are truncated to what is available.
     *
     * @param vertex vertex to start from
     * @param cDepth current depth of this walk (1 for a top-level walk)
     * @return the populated sequence
     * @throws ND4JIllegalStateException if the sampling mode is not one of the handled values
     */
    protected Sequence<V> walk(Vertex<V> vertex, int cDepth) {
        Sequence<V> sequence = new Sequence<>();
        List<Vertex<V>> neighbours = this.sourceGraph.getConnectedVertices(vertex.vertexID());
        sequence.setSequenceLabel(vertex.getValue());

        if (this.walkLength == 0) {
            // Unlimited walk: take every connected vertex as is.
            for (Vertex<V> neighbour : neighbours) {
                sequence.addElement(neighbour.getValue());
            }
            return sequence;
        }

        switch (this.samplingMode) {
            case MAX_POPULARITY: {
                sortByPopularity(neighbours);
                // Never read past the end when fewer neighbours exist than requested.
                int limit = Math.min(this.walkLength, neighbours.size());
                for (int idx = 0; idx < limit; idx++) {
                    Vertex<V> neighbour = neighbours.get(idx);
                    sequence.addElement(neighbour.getValue());
                    cDepth = expandNeighbour(sequence, neighbour, cDepth);
                }
                break;
            }
            case MEDIAN_POPULARITY: {
                sortByPopularity(neighbours);
                // Window centred on the middle of the sorted list; clamped so that a walkLength
                // larger than the list cannot yield a negative start index.
                int start = Math.max(0, (neighbours.size() / 2) - (this.walkLength / 2));
                for (int taken = 0, idx = start; taken < this.walkLength && idx < neighbours.size(); taken++, idx++) {
                    Vertex<V> neighbour = neighbours.get(idx);
                    sequence.addElement(neighbour.getValue());
                    cDepth = expandNeighbour(sequence, neighbour, cDepth);
                }
                break;
            }
            case MIN_POPULARITY: {
                sortByPopularity(neighbours);
                // Walk backwards from the last valid index (size - 1), i.e. least connected first.
                for (int taken = 0, idx = neighbours.size() - 1; taken < this.walkLength && idx >= 0; taken++, idx--) {
                    Vertex<V> neighbour = neighbours.get(idx);
                    sequence.addElement(neighbour.getValue());
                    cDepth = expandNeighbour(sequence, neighbour, cDepth);
                }
                break;
            }
            case RANDOM: {
                if (neighbours.size() <= this.walkLength) {
                    // Not enough neighbours to sample from: take all of them.
                    for (Vertex<V> neighbour : neighbours) {
                        sequence.addElement(neighbour.getValue());
                    }
                } else {
                    // Draw until walkLength distinct values were collected.
                    HashSet<V> sampled = new HashSet<>();
                    while (sampled.size() < this.walkLength) {
                        Vertex<V> picked = ArrayUtil.getRandomElement(neighbours);
                        sampled.add(picked.getValue());
                        cDepth = expandNeighbour(sequence, picked, cDepth);
                    }
                    sequence.addElements(sampled);
                }
                break;
            }
            default:
                throw new ND4JIllegalStateException("Unknown sampling mode was passed in: [" + this.samplingMode + "]");
        }
        return sequence;
    }

    /** Sorts the given vertices in place so that the most connected ones come first. */
    private void sortByPopularity(List<Vertex<V>> vertices) {
        Collections.sort(vertices, new VertexComparator<>(this.sourceGraph));
    }

    /**
     * Goes one level deeper from {@code neighbour} when the configured depth allows it and merges
     * the elements of that deeper walk into {@code sequence}, skipping labels already present.
     *
     * <p>NOTE(review): the returned depth is carried over to the caller's next sibling, so once it
     * reaches {@code depth} no further neighbours of the same vertex are expanded. This mirrors the
     * pre-existing behaviour; confirm it is intended.
     *
     * @return the depth to use for the next neighbour (incremented if an expansion happened)
     */
    private int expandNeighbour(Sequence<V> sequence, Vertex<V> neighbour, int currentDepth) {
        if (this.depth <= 1 || currentDepth >= this.depth) {
            return currentDepth;
        }
        int nextDepth = currentDepth + 1;
        for (V element : walk(neighbour, nextDepth).getElements()) {
            if (sequence.getElementByLabel(element.getLabel()) == null) {
                sequence.addElement(element);
            }
        }
        return nextDepth;
    }

    /**
     * @return always true; sequences produced by this walker carry a label
     */
    @Override // org.deeplearning4j.models.sequencevectors.graph.walkers.GraphWalker
    public boolean isLabelEnabled() {
        return true;
    }

    /**
     * @return the graph this walker operates on
     */
    @Override // org.deeplearning4j.models.sequencevectors.graph.walkers.GraphWalker
    @Generated
    public IGraph<V, ?> getSourceGraph() {
        return this.sourceGraph;
    }
}
