package org.deeplearning4j.models.embeddings.learning.impl.elements;

import java.time.Duration;
import java.util.ArrayList;
import java.util.List;
import java.util.NoSuchElementException;
import java.util.Queue;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.atomic.AtomicLong;
import lombok.Generated;
import lombok.NonNull;
import org.apache.commons.lang3.RandomUtils;
import org.deeplearning4j.models.embeddings.WeightLookupTable;
import org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable;
import org.deeplearning4j.models.embeddings.learning.ElementsLearningAlgorithm;
import org.deeplearning4j.models.embeddings.loader.VectorsConfiguration;
import org.deeplearning4j.models.embeddings.wordvectors.WordVectorsImpl;
import org.deeplearning4j.models.sequencevectors.interfaces.SequenceIterator;
import org.deeplearning4j.models.sequencevectors.sequence.Sequence;
import org.deeplearning4j.models.sequencevectors.sequence.SequenceElement;
import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.nlp.SkipGramInference;
import org.nd4j.linalg.api.ops.impl.nlp.SkipGramRound;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.util.DeviceLocalNDArray;
import org.nd4j.shade.guava.cache.Cache;
import org.nd4j.shade.guava.cache.CacheBuilder;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/models/embeddings/learning/impl/elements/SkipGram.class */
/**
 * Skip-gram implementation of {@link ElementsLearningAlgorithm}.
 *
 * <p>For every element of a sequence, (element, context-element) pairs inside a (randomly shrunk)
 * window are collected as {@link BatchItem}s in a per-thread batch. Once the batch reaches
 * {@code configuration.getBatchSize()} it is trained in one native call: {@link SkipGramRound} for
 * batches of more than one item, {@link SkipGramInference} for a single item.
 */
public class SkipGram<T extends SequenceElement> implements ElementsLearningAlgorithm<T> {

    @Generated
    private static final Logger log = LoggerFactory.getLogger(SkipGram.class);
    protected VocabCache<T> vocabCache;
    protected WeightLookupTable<T> lookupTable;
    protected VectorsConfiguration configuration;
    protected int window;
    protected boolean useAdaGrad;
    protected double negative;
    protected double sampling;
    protected int[] variableWindows;
    protected int vectorLength;
    protected DeviceLocalNDArray syn0;
    protected DeviceLocalNDArray syn1;
    protected DeviceLocalNDArray syn1Neg;
    protected DeviceLocalNDArray table;
    protected DeviceLocalNDArray expTable;
    /** Upper bound on pooled {@link IterationArrays} kept per cache key. */
    protected int maxQueueSize = Integer.parseInt(System.getProperty("org.eclipse.deeplearning4j.nlp.queuesize", "1000"));
    // NOTE(review): weakKeys() makes Guava compare keys by identity (==), so the freshly built
    // IterationArraysKey in doExec can never match an existing entry and pooled arrays are
    // effectively never reused. Dropping weakKeys() would enable reuse, but only if
    // IterationArraysKey defines value equality and IterationArrays resets fully on reuse -
    // TODO confirm both before changing.
    private Cache<IterationArraysKey, Queue<IterationArrays>> iterationArrays = CacheBuilder.newBuilder().maximumSize(Integer.parseInt(System.getProperty("org.eclipse.deeplearning4j.nlp.cachesize", "1000"))).weakKeys().expireAfterWrite(Duration.ofMinutes(5)).build();
    protected int workers = Runtime.getRuntime().availableProcessors();
    /** Pending training pairs; each thread accumulates its own batch. */
    protected ThreadLocal<List<BatchItem<T>>> batches = new ThreadLocal<>();

    public int getWorkers() {
        return this.workers;
    }

    public void setWorkers(int i) {
        this.workers = i;
    }

    /**
     * Returns the calling thread's pending batch, creating it on first use.
     *
     * @return the mutable per-thread batch, never {@code null}
     */
    public List<BatchItem<T>> getBatch() {
        List<BatchItem<T>> batch = this.batches.get();
        if (batch == null) {
            batch = new ArrayList<>();
            this.batches.set(batch);
        }
        return batch;
    }

    @Override
    public String getCodeName() {
        return "SkipGram";
    }

    /**
     * Binds this algorithm to a vocabulary, lookup table and configuration.
     *
     * <p>The lookup table is cast to {@link InMemoryLookupTable}; if negative sampling is enabled and
     * {@code syn1Neg} is missing, weights are reset so it gets allocated.
     *
     * @throws NullPointerException if any argument is {@code null}
     */
    @Override
    public void configure(@NonNull VocabCache<T> vocabCache, @NonNull WeightLookupTable<T> weightLookupTable, @NonNull VectorsConfiguration vectorsConfiguration) {
        if (vocabCache == null) {
            throw new NullPointerException("vocabCache is marked non-null but is null");
        }
        if (weightLookupTable == null) {
            throw new NullPointerException("lookupTable is marked non-null but is null");
        }
        if (vectorsConfiguration == null) {
            throw new NullPointerException("configuration is marked non-null but is null");
        }
        this.vocabCache = vocabCache;
        this.lookupTable = weightLookupTable;
        this.configuration = vectorsConfiguration;
        InMemoryLookupTable<?> inMemory = (InMemoryLookupTable<?>) weightLookupTable;
        if (vectorsConfiguration.getNegative().doubleValue() > 0.0d && inMemory.getSyn1Neg() == null) {
            log.info("Initializing syn1Neg...");
            inMemory.setUseHS(vectorsConfiguration.isUseHierarchicSoftmax().booleanValue());
            inMemory.setNegative(vectorsConfiguration.getNegative().doubleValue());
            weightLookupTable.resetWeights(false);
        }
        this.syn0 = new DeviceLocalNDArray(inMemory.getSyn0());
        this.syn1 = new DeviceLocalNDArray(inMemory.getSyn1());
        this.syn1Neg = new DeviceLocalNDArray(inMemory.getSyn1Neg());
        // expTable is created with syn0's data type, falling back to DOUBLE when syn0 is absent.
        this.expTable = new DeviceLocalNDArray(Nd4j.create(inMemory.getExpTable(), new long[]{inMemory.getExpTable().length}, this.syn0.get() == null ? DataType.DOUBLE : this.syn0.get().dataType()));
        this.table = new DeviceLocalNDArray(inMemory.getTable());
        this.window = vectorsConfiguration.getWindow().intValue();
        this.useAdaGrad = vectorsConfiguration.isUseAdaGrad().booleanValue();
        this.negative = vectorsConfiguration.getNegative().doubleValue();
        this.sampling = vectorsConfiguration.getSampling().doubleValue();
        this.variableWindows = vectorsConfiguration.getVariableWindows();
        this.workers = vectorsConfiguration.getWorkers();
        this.vectorLength = vectorsConfiguration.getLayersSize().intValue();
    }

    /** No pretraining is needed for skip-gram. */
    @Override
    public void pretrain(SequenceIterator<T> sequenceIterator) {
    }

    /**
     * Randomly drops frequent elements from a sequence.
     *
     * <p>Each element is kept when {@code (sqrt(f / (s * N)) + 1) * (s * N) / f} is at least a
     * pseudo-random value in [0, 1). Here {@code f} is the element frequency, {@code s} the sampling
     * rate and {@code N} the total word occurrences.
     *
     * @param sequence   source sequence (not modified)
     * @param atomicLong shared random state, advanced once per element
     * @return {@code sequence} itself when sampling is disabled, otherwise a new filtered sequence
     *     carrying the same id and labels
     * @throws NullPointerException if any argument is {@code null}
     */
    public Sequence<T> applySubsampling(@NonNull Sequence<T> sequence, @NonNull AtomicLong atomicLong) {
        if (sequence == null) {
            throw new NullPointerException("sequence is marked non-null but is null");
        }
        if (atomicLong == null) {
            throw new NullPointerException("nextRandom is marked non-null but is null");
        }
        if (this.sampling <= 0.0d) {
            return sequence;
        }
        Sequence<T> sampled = new Sequence<>();
        sampled.setSequenceId(sequence.getSequenceId());
        if (sequence.getSequenceLabels() != null) {
            sampled.setSequenceLabels(sequence.getSequenceLabels());
        }
        if (sequence.getSequenceLabel() != null) {
            sampled.setSequenceLabel(sequence.getSequenceLabel());
        }
        // Loop-invariant: sampling rate scaled by corpus size.
        double threshold = this.sampling * this.vocabCache.totalWordOccurrences();
        for (T element : sequence.getElements()) {
            double keepProbability = ((Math.sqrt(element.getElementFrequency() / threshold) + 1.0d) * threshold) / element.getElementFrequency();
            long random = advanceRandom(atomicLong);
            if (keepProbability >= (random & 65535) / 65536.0d) {
                sampled.addElement(element);
            }
        }
        return sampled;
    }

    /**
     * Queues skip-gram pairs for every element of {@code sequence}, then trains the thread's batch
     * if it has reached the configured batch size.
     *
     * @param sequence   sequence to learn from (subsampled first when sampling is enabled)
     * @param atomicLong shared random state
     * @param d          learning rate attached to every queued pair
     * @return always 0.0; no score is computed
     * @throws NullPointerException if {@code sequence} or {@code atomicLong} is {@code null}
     */
    @Override
    public double learnSequence(@NonNull Sequence<T> sequence, @NonNull AtomicLong atomicLong, double d) {
        if (sequence == null) {
            throw new NullPointerException("sequence is marked non-null but is null");
        }
        if (atomicLong == null) {
            throw new NullPointerException("nextRandom is marked non-null but is null");
        }
        Sequence<T> effective = this.sampling > 0.0d ? applySubsampling(sequence, atomicLong) : sequence;
        int currentWindow = this.window;
        if (this.variableWindows != null && this.variableWindows.length != 0) {
            currentWindow = this.variableWindows[RandomUtils.nextInt(0, this.variableWindows.length)];
        }
        List<T> elements = effective.getElements();
        double score = 0.0d;
        for (int position = 0; position < elements.size(); position++) {
            long random = advanceRandom(atomicLong);
            // floorMod keeps the shrink offset in [0, window); a plain % on the int cast went
            // negative for half of all random values and widened the window instead.
            int shrink = currentWindow > 0 ? Math.floorMod((int) random, currentWindow) : 0;
            score = skipGram(position, elements, shrink, atomicLong, d, currentWindow);
        }
        List<BatchItem<T>> batch = getBatch();
        if (batch.size() >= this.configuration.getBatchSize().intValue()) {
            doExec(batch, null);
            batch.clear();
        }
        return score;
    }

    /** Discards the calling thread's pending pairs without training them. */
    public void clearBatch() {
        getBatch().clear();
    }

    /** Trains whatever pairs are still pending in the calling thread's batch, regardless of size. */
    @Override
    public void finish() {
        flushBatch(null);
    }

    /**
     * Trains whatever pairs are still pending in the calling thread's batch, using the given
     * inference vector.
     *
     * @param iNDArray inference vector forwarded to the native op; may be {@code null}
     */
    @Override
    public void finish(INDArray iNDArray) {
        flushBatch(iNDArray);
    }

    /** Executes and clears the pending batch if it holds anything. */
    private void flushBatch(INDArray inferenceVector) {
        List<BatchItem<T>> batch = this.batches == null ? null : this.batches.get();
        if (batch == null || batch.isEmpty()) {
            return;
        }
        doExec(batch, inferenceVector);
        batch.clear();
    }

    @Override
    public boolean isEarlyTerminationHit() {
        return false;
    }

    /** Appends one pair to the calling thread's batch without triggering execution. */
    public void addBatchItem(BatchItem<T> batchItem) {
        getBatch().add(batchItem);
    }

    /**
     * Queues one batch item per context element around {@code position}.
     *
     * @param position      index of the centre element in {@code elements}
     * @param elements      the sequence's elements
     * @param shrink        random offset in [0, window) that narrows the effective window
     * @param nextRandom    shared random state, advanced once per queued pair
     * @param alpha         learning rate stored on each pair
     * @param currentWindow window radius
     * @return always 0.0
     */
    private double skipGram(int position, List<T> elements, int shrink, AtomicLong nextRandom, double alpha, int currentWindow) {
        T word = elements.get(position);
        if (word == null || word.isLocked()) {
            return 0.0d;
        }
        int end = (currentWindow * 2) + 1 - shrink;
        for (int a = shrink; a < end; a++) {
            if (a == currentWindow) {
                continue; // the centre element itself
            }
            int contextIndex = position - currentWindow + a;
            if (contextIndex < 0 || contextIndex >= elements.size()) {
                continue;
            }
            T lastWord = elements.get(contextIndex);
            long random = advanceRandom(nextRandom);
            addBatchItem(new BatchItem<>(word, lastWord, random, alpha));
        }
        return 0.0d;
    }

    /**
     * Optionally appends {@code batchItem} to the thread's batch, then trains the batch if it is
     * non-empty and has reached the configured batch size.
     *
     * @param batchItem pair to add, or {@code null} to only check the batch
     * @return always 0.0
     */
    public double iterateSample(BatchItem<T> batchItem) {
        List<BatchItem<T>> batch = getBatch();
        if (batchItem != null) {
            batch.add(batchItem);
        }
        if (!batch.isEmpty() && batch.size() >= this.configuration.getBatchSize().intValue()) {
            return doExec(batch, null).doubleValue();
        }
        return 0.0d;
    }

    /**
     * Trains all pairs in {@code list} with one native op and consumes the list (it is cleared).
     *
     * <p>Lists with more than one item use {@link SkipGramRound}; a single item uses
     * {@link SkipGramInference}. Pairs with a missing word, identical indices, a "STOP" or UNK
     * label, or a negative context index (unless an inference vector is given) are skipped.
     *
     * @param list     pairs to train; an empty list is a no-op
     * @param iNDArray inference vector, or {@code null} for regular training
     * @return always 0.0
     */
    public Double doExec(List<BatchItem<T>> list, INDArray iNDArray) {
        if (list.isEmpty()) {
            return 0.0d;
        }
        try (MemoryWorkspace ignored = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
            if (list.size() > 1) {
                execBatch(list, iNDArray);
            } else {
                execSingle(list, iNDArray);
            }
        }
        return 0.0d;
    }

    /** Trains a multi-item batch with {@link SkipGramRound}, reusing pooled host arrays when possible. */
    private void execBatch(List<BatchItem<T>> list, INDArray inferenceVector) {
        int maxCodeLength = 1;
        for (BatchItem<T> item : list) {
            maxCodeLength = Math.max(maxCodeLength, item.getWord().getCodeLength());
        }
        IterationArraysKey key = IterationArraysKey.builder().itemSize(list.size()).maxCols(maxCodeLength).build();
        Queue<IterationArrays> queue = this.iterationArrays.getIfPresent(key);
        IterationArrays arrays;
        if (queue == null) {
            queue = new ConcurrentLinkedQueue<>();
            this.iterationArrays.put(key, queue);
            arrays = new IterationArrays(list.size(), maxCodeLength);
        } else {
            arrays = queue.poll();
            if (arrays == null) {
                arrays = new IterationArrays(list.size(), maxCodeLength);
            } else {
                arrays.initCodes();
            }
        }

        int[][] indices = arrays.indicesArr;
        int[][] codes = arrays.codesArr;
        long[] randomValues = arrays.randomValues;
        double[] alphas = arrays.alphas;
        int[] targets = arrays.targets;
        int[] ngStarters = arrays.ngStarters;
        boolean useHS = this.configuration.isUseHierarchicSoftmax().booleanValue();
        boolean useNegative = this.negative > 0.0d;

        for (int row = 0; row < list.size(); row++) {
            BatchItem<T> item = list.get(row);
            T word = item.getWord();
            T lastWord = item.getLastWord();
            randomValues[row] = item.getRandomValue();
            if (!isTrainablePair(word, lastWord, inferenceVector)) {
                // Arrays may come from the pool: clear a stale learning rate so a rejected pair
                // is not trained with a previous batch's values (presumably alpha 0 makes the
                // row a no-op in the native op - TODO confirm).
                alphas[row] = 0.0d;
                continue;
            }
            targets[row] = lastWord.getIndex();
            ngStarters[row] = word.getIndex();
            alphas[row] = item.getAlpha();
            if (useHS) {
                for (int c = 0; c < word.getCodeLength(); c++) {
                    byte code = word.getCodes().get(c).byteValue();
                    int point = word.getPoints().get(c).intValue();
                    if (point < this.vocabCache.numWords() && point >= 0) {
                        codes[row][c] = code;
                        indices[row][c] = point;
                    }
                }
            }
        }
        ensureSyn1NegInitialized();

        DataType weightType = this.syn0.get().dataType();
        INDArray alphasArr = Nd4j.createFromArray(alphas);
        INDArray ngStartersArr = useNegative ? Nd4j.createFromArray(ngStarters) : null;
        INDArray randomArr = Nd4j.createFromArray(randomValues);
        INDArray targetsArr = Nd4j.createFromArray(targets);
        INDArray codesArr = useHS ? Nd4j.createFromArray(codes) : null;
        INDArray indicesArr = useHS ? Nd4j.createFromArray(indices) : null;
        SkipGramRound op = SkipGramRound.builder()
                .target(targetsArr)
                .expTable(this.expTable.get())
                .ngStarter(useNegative ? ngStartersArr : Nd4j.empty(DataType.INT32))
                .syn0(this.syn0.get())
                .syn1(useHS ? this.syn1.get() : Nd4j.empty(weightType))
                .syn1Neg(useNegative ? this.syn1Neg.get() : Nd4j.empty(weightType))
                .negTable(useNegative ? this.table.get() : Nd4j.empty(weightType))
                .indices(useHS ? indicesArr : Nd4j.empty(DataType.INT32))
                .codes(useHS ? codesArr : Nd4j.empty(DataType.INT8))
                .alpha(alphasArr)
                .randomValue(randomArr)
                .inferenceVector(inferenceVector != null ? inferenceVector : Nd4j.empty(weightType))
                .preciseMode(this.configuration.isPreciseMode().booleanValue())
                .numWorkers(this.workers)
                .iterations(inferenceVector != null ? this.configuration.getIterations().intValue() * this.configuration.getEpochs().intValue() : 1)
                .build();
        try {
            Nd4j.getExecutioner().exec(op);
        } finally {
            // Release the temporary device arrays even when execution fails.
            op.inputArguments().clear();
            Nd4j.close(targetsArr, codesArr, indicesArr, alphasArr, ngStartersArr, randomArr);
        }
        list.clear();
        if (queue.size() < this.maxQueueSize) {
            queue.add(arrays);
        }
    }

    /** Trains a single pair with {@link SkipGramInference}; the list is cleared either way. */
    private void execSingle(List<BatchItem<T>> list, INDArray inferenceVector) {
        BatchItem<T> item = list.get(0);
        T word = item.getWord();
        T lastWord = item.getLastWord();
        if (!isTrainablePair(word, lastWord, inferenceVector)) {
            list.clear();
            return;
        }
        boolean useHS = this.configuration.isUseHierarchicSoftmax().booleanValue();
        boolean useNegative = this.negative > 0.0d;
        // Allocated only after validation: word is guaranteed non-null here.
        int codeLength = word.getCodeLength();
        byte[] codes = new byte[codeLength];
        int[] points = new int[codeLength];
        if (useHS) {
            for (int c = 0; c < codeLength; c++) {
                byte code = word.getCodes().get(c).byteValue();
                int point = word.getPoints().get(c).intValue();
                if (point < this.vocabCache.numWords() && point >= 0) {
                    codes[c] = code;
                    points[c] = point;
                }
            }
        }
        ensureSyn1NegInitialized();
        DataType weightType = this.syn0.get().dataType();
        // NOTE(review): this path reads workers from the configuration while the batch path
        // uses this.workers (which setWorkers can change) - confirm which is intended.
        Nd4j.getExecutioner().exec(SkipGramInference.builder()
                .inferenceVector(inferenceVector != null ? inferenceVector : Nd4j.empty(weightType))
                .randomValue((int) item.getRandomValue())
                .syn0(this.syn0.get())
                .negTable(useNegative ? this.table.get() : Nd4j.empty(weightType))
                .expTable(this.expTable.get())
                .syn1(useHS ? this.syn1.get() : Nd4j.empty(weightType))
                .syn1Neg(useNegative ? this.syn1Neg.get() : Nd4j.empty(weightType))
                .alpha(new double[]{item.getAlpha()})
                .iteration(1)
                .ngStarter(word.getIndex())
                .indices(points)
                .target(lastWord.getIndex())
                .codes(codes)
                .preciseMode(this.configuration.getPreciseMode().booleanValue())
                .numWorkers(this.configuration.getWorkers())
                .build());
        list.clear();
    }

    /**
     * Returns whether a (word, context) pair should be trained. Both must be present with
     * distinct indices and no "STOP"/UNK labels. The context must have a non-negative index
     * unless an inference vector is supplied.
     */
    private boolean isTrainablePair(T word, T lastWord, INDArray inferenceVector) {
        return word != null
                && lastWord != null
                && (lastWord.getIndex() >= 0 || inferenceVector != null)
                && word.getIndex() != lastWord.getIndex()
                && !word.getLabel().equals("STOP")
                && !lastWord.getLabel().equals("STOP")
                && !word.getLabel().equals(WordVectorsImpl.DEFAULT_UNK)
                && !lastWord.getLabel().equals(WordVectorsImpl.DEFAULT_UNK);
    }

    /**
     * Lazily initialises negative-sampling weights.
     *
     * <p>NOTE(review): {@code configure} always assigns a non-null wrapper to {@code syn1Neg}, so
     * this null check likely never fires after configuration. Confirm whether it should instead
     * inspect {@code syn1Neg.get()}.
     */
    private void ensureSyn1NegInitialized() {
        if (this.negative > 0.0d && this.syn1Neg == null) {
            ((InMemoryLookupTable<?>) this.lookupTable).initNegative();
            this.syn1Neg = new DeviceLocalNDArray(((InMemoryLookupTable<?>) this.lookupTable).getSyn1Neg());
        }
    }

    /**
     * Advances the shared random state by one linear-congruential step
     * ({@code x * 25214903917 + 11}, then {@code Math.abs}) and returns the new value.
     *
     * <p>{@code Math.abs} leaves {@code Long.MIN_VALUE} negative, so callers must not assume the
     * result is non-negative.
     */
    private static long advanceRandom(AtomicLong state) {
        state.set(Math.abs((state.get() * 25214903917L) + 11));
        return state.get();
    }

    @Generated
    public DeviceLocalNDArray getSyn0() {
        return this.syn0;
    }

    @Generated
    public DeviceLocalNDArray getSyn1() {
        return this.syn1;
    }

    @Generated
    public DeviceLocalNDArray getSyn1Neg() {
        return this.syn1Neg;
    }

    @Generated
    public DeviceLocalNDArray getTable() {
        return this.table;
    }

    @Generated
    public DeviceLocalNDArray getExpTable() {
        return this.expTable;
    }

    @Generated
    public void setSyn0(DeviceLocalNDArray deviceLocalNDArray) {
        this.syn0 = deviceLocalNDArray;
    }

    @Generated
    public void setSyn1(DeviceLocalNDArray deviceLocalNDArray) {
        this.syn1 = deviceLocalNDArray;
    }

    @Generated
    public void setSyn1Neg(DeviceLocalNDArray deviceLocalNDArray) {
        this.syn1Neg = deviceLocalNDArray;
    }

    @Generated
    public void setTable(DeviceLocalNDArray deviceLocalNDArray) {
        this.table = deviceLocalNDArray;
    }

    @Generated
    public void setExpTable(DeviceLocalNDArray deviceLocalNDArray) {
        this.expTable = deviceLocalNDArray;
    }
}
