package org.deeplearning4j.models.embeddings.learning.impl.elements;

import java.util.ArrayList;
import java.util.List;
import java.util.NoSuchElementException;
import java.util.Queue;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.atomic.AtomicLong;
import lombok.Generated;
import lombok.NonNull;
import org.apache.commons.lang3.RandomUtils;
import org.deeplearning4j.models.embeddings.WeightLookupTable;
import org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable;
import org.deeplearning4j.models.embeddings.learning.ElementsLearningAlgorithm;
import org.deeplearning4j.models.embeddings.loader.VectorsConfiguration;
import org.deeplearning4j.models.sequencevectors.interfaces.SequenceIterator;
import org.deeplearning4j.models.sequencevectors.sequence.Sequence;
import org.deeplearning4j.models.sequencevectors.sequence.SequenceElement;
import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.nlp.CbowInference;
import org.nd4j.linalg.api.ops.impl.nlp.CbowRound;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.util.DeviceLocalNDArray;
import org.nd4j.shade.guava.cache.Cache;
import org.nd4j.shade.guava.cache.CacheBuilder;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/models/embeddings/learning/impl/elements/CBOW.class */
public class CBOW<T extends SequenceElement> implements ElementsLearningAlgorithm<T> {
    // Vocabulary supplied to configure(); applySubsampling() reads totalWordOccurrences() from it.
    private VocabCache<T> vocabCache;
    // Lookup table supplied to configure(); cast to InMemoryLookupTable wherever its weights are accessed.
    private WeightLookupTable<T> lookupTable;
    // Configuration supplied to configure(); read again in doExec() and when deciding whether a batch is full.
    private VectorsConfiguration configuration;
    private static final Logger logger = LoggerFactory.getLogger(CBOW.class);
    // Window size copied from the configuration in configure().
    protected int window;
    // Copied from the configuration in configure(); not read anywhere else in this class.
    protected boolean useAdaGrad;
    // Negative-sampling setting copied from the configuration; values > 0 enable the syn1Neg / negTable inputs.
    protected double negative;
    // Subsampling setting copied from the configuration; values > 0 enable applySubsampling().
    protected double sampling;
    // Optional set of window sizes; when non-empty, learnSequence() picks one at random per sequence.
    protected int[] variableWindows;
    // Device-local wrappers around the lookup table's arrays, created in configure().
    protected DeviceLocalNDArray syn0;
    protected DeviceLocalNDArray syn1;
    protected DeviceLocalNDArray syn1Neg;
    protected DeviceLocalNDArray expTable;
    protected DeviceLocalNDArray table;
    // Defaults to the number of available processors; overwritten from the configuration in configure().
    protected int workers = Runtime.getRuntime().availableProcessors();
    // Cache of reusable IterationArrays queues keyed by (itemSize, maxCols); the maximum number of entries
    // comes from the system property "org.eclipse.deeplearning4j.nlp.cachesize" (default 10000).
    private Cache<IterationArraysKey, Queue<IterationArrays>> iterationArrays = CacheBuilder.newBuilder().maximumSize(Integer.parseInt(System.getProperty("org.eclipse.deeplearning4j.nlp.cachesize", "10000"))).build();
    // Upper bound on how many IterationArrays doExec() returns to a single queue; taken from the system
    // property "org.eclipse.deeplearning4j.nlp.queuesize" (default 1000).
    protected int maxQueueSize = Integer.parseInt(System.getProperty("org.eclipse.deeplearning4j.nlp.queuesize", "1000"));
    // Per-thread list of batch items waiting to be executed; starts out unset (null) for every thread.
    protected ThreadLocal<List<BatchItem<T>>> batches = new ThreadLocal<>();

    /* loaded from: input_file:org/deeplearning4j/models/embeddings/learning/impl/elements/CBOW$IterationArraysKey.class */
    /**
     * Cache key pairing a batch size ({@code itemSize}) with a column count ({@code maxCols}).
     * Two keys are equal when both values match; the hash code is derived from the same two values.
     */
    public static class IterationArraysKey {
        private int itemSize;
        private int maxCols;

        /** Creates a key with both values left at zero. */
        @Generated
        public IterationArraysKey() {
        }

        /**
         * Creates a fully populated key.
         *
         * @param itemSize value for {@code itemSize}
         * @param maxCols  value for {@code maxCols}
         */
        @Generated
        public IterationArraysKey(int itemSize, int maxCols) {
            this.itemSize = itemSize;
            this.maxCols = maxCols;
        }

        /** @return a new, empty builder */
        @Generated
        public static IterationArraysKeyBuilder builder() {
            return new IterationArraysKeyBuilder();
        }

        @Generated
        public int getItemSize() {
            return itemSize;
        }

        @Generated
        public void setItemSize(int itemSize) {
            this.itemSize = itemSize;
        }

        @Generated
        public int getMaxCols() {
            return maxCols;
        }

        @Generated
        public void setMaxCols(int maxCols) {
            this.maxCols = maxCols;
        }

        /** Hook consulted by {@link #equals(Object)} so that both sides agree they are comparable. */
        @Generated
        protected boolean canEqual(Object obj) {
            return obj instanceof IterationArraysKey;
        }

        @Generated
        public boolean equals(Object obj) {
            if (this == obj) {
                return true;
            }
            if (!(obj instanceof IterationArraysKey)) {
                return false;
            }
            IterationArraysKey other = (IterationArraysKey) obj;
            if (!other.canEqual(this)) {
                return false;
            }
            if (getItemSize() != other.getItemSize()) {
                return false;
            }
            return getMaxCols() == other.getMaxCols();
        }

        @Generated
        public int hashCode() {
            // Same arithmetic as ((1 * 59 + itemSize) * 59) + maxCols.
            int hash = 59 + getItemSize();
            hash = (hash * 59) + getMaxCols();
            return hash;
        }

        @Generated
        public String toString() {
            int items = getItemSize();
            int cols = getMaxCols();
            return "CBOW.IterationArraysKey(itemSize=" + items + ", maxCols=" + cols + ")";
        }

        /** Fluent builder for {@link IterationArraysKey}. */
        @Generated
        public static class IterationArraysKeyBuilder {

            @Generated
            private int itemSize;

            @Generated
            private int maxCols;

            @Generated
            IterationArraysKeyBuilder() {
            }

            @Generated
            public IterationArraysKeyBuilder itemSize(int itemSize) {
                this.itemSize = itemSize;
                return this;
            }

            @Generated
            public IterationArraysKeyBuilder maxCols(int maxCols) {
                this.maxCols = maxCols;
                return this;
            }

            /** @return a key holding the values collected so far */
            @Generated
            public IterationArraysKey build() {
                return new IterationArraysKey(itemSize, maxCols);
            }

            @Generated
            public String toString() {
                return "CBOW.IterationArraysKey.IterationArraysKeyBuilder(itemSize=" + itemSize + ", maxCols=" + maxCols + ")";
            }
        }
    }

    /**
     * @return the worker count currently stored in this instance
     */
    public int getWorkers() {
        return workers;
    }

    /**
     * Replaces the stored worker count.
     *
     * @param i new worker count
     */
    public void setWorkers(int i) {
        workers = i;
    }

    /**
     * Returns the calling thread's list of pending batch items, creating an empty list on first use.
     *
     * @return the per-thread batch list; never {@code null}
     */
    public List<BatchItem<T>> getBatch() {
        List<BatchItem<T>> current = this.batches.get();
        if (current == null) {
            // Typed list instead of a raw ArrayList, so no unchecked assignment into the thread-local.
            current = new ArrayList<>();
            this.batches.set(current);
        }
        return current;
    }

    /**
     * Appends one item to the calling thread's pending batch.
     *
     * @param batchItem item to queue
     */
    public void addBatchItem(BatchItem<T> batchItem) {
        List<BatchItem<T>> pending = getBatch();
        pending.add(batchItem);
    }

    /**
     * Removes every item from the calling thread's pending batch.
     */
    public void clearBatch() {
        List<BatchItem<T>> pending = getBatch();
        pending.clear();
    }

    /**
     * @return the fixed code name of this algorithm, {@code "CBOW"}
     */
    @Override
    public String getCodeName() {
        final String codeName = "CBOW";
        return codeName;
    }

    /**
     * Binds this algorithm to a vocabulary, a lookup table and a configuration, and wraps the lookup
     * table's arrays in device-local holders.
     *
     * <p>The lookup table is cast to {@link InMemoryLookupTable}; any other implementation fails with a
     * {@link ClassCastException}. When the configured negative value is positive and the table has no
     * {@code syn1Neg} yet, the table's weights are reset (without re-initialising existing ones) so that
     * it gets created.
     *
     * @param vocabCache           vocabulary to use; must not be {@code null}
     * @param weightLookupTable    lookup table to train; must not be {@code null}
     * @param vectorsConfiguration training configuration; must not be {@code null}
     * @throws NullPointerException if any argument is {@code null}
     */
    @Override
    public void configure(@NonNull VocabCache<T> vocabCache, @NonNull WeightLookupTable<T> weightLookupTable, @NonNull VectorsConfiguration vectorsConfiguration) {
        if (vocabCache == null) {
            throw new NullPointerException("vocabCache is marked non-null but is null");
        }
        if (weightLookupTable == null) {
            throw new NullPointerException("lookupTable is marked non-null but is null");
        }
        if (vectorsConfiguration == null) {
            throw new NullPointerException("configuration is marked non-null but is null");
        }

        this.vocabCache = vocabCache;
        this.lookupTable = weightLookupTable;
        this.configuration = vectorsConfiguration;

        // Snapshot the scalar settings this class reads repeatedly.
        this.window = vectorsConfiguration.getWindow();
        this.useAdaGrad = vectorsConfiguration.isUseAdaGrad();
        this.negative = vectorsConfiguration.getNegative();
        this.sampling = vectorsConfiguration.getSampling();
        this.workers = vectorsConfiguration.getWorkers();

        InMemoryLookupTable<T> inMemoryTable = (InMemoryLookupTable<T>) weightLookupTable;

        boolean negativeSamplingRequested = vectorsConfiguration.getNegative() > 0.0d;
        if (negativeSamplingRequested && inMemoryTable.getSyn1Neg() == null) {
            logger.info("Initializing syn1Neg...");
            inMemoryTable.setUseHS(vectorsConfiguration.isUseHierarchicSoftmax());
            inMemoryTable.setNegative(vectorsConfiguration.getNegative());
            weightLookupTable.resetWeights(false);
        }

        this.syn0 = new DeviceLocalNDArray(inMemoryTable.getSyn0());
        this.syn1 = new DeviceLocalNDArray(inMemoryTable.getSyn1());
        this.syn1Neg = new DeviceLocalNDArray(inMemoryTable.getSyn1Neg());
        // The exp table is materialised as a 1-d array using syn0's data type (DOUBLE when syn0 holds no array).
        this.expTable = new DeviceLocalNDArray(Nd4j.create(
                inMemoryTable.getExpTable(),
                new long[]{inMemoryTable.getExpTable().length},
                this.syn0.get() == null ? DataType.DOUBLE : this.syn0.get().dataType()));
        this.table = new DeviceLocalNDArray(inMemoryTable.getTable());

        this.variableWindows = vectorsConfiguration.getVariableWindows();
    }

    /**
     * No-op: this implementation does nothing with the supplied iterator.
     *
     * @param sequenceIterator ignored
     */
    @Override
    public void pretrain(SequenceIterator<T> sequenceIterator) {
        // Intentionally empty.
    }

    /**
     * Executes whatever items are still pending in the calling thread's batch, then empties it.
     * Does nothing when there is no batch list or it is empty.
     */
    @Override
    public void finish() {
        if (this.batches == null) {
            return;
        }
        List<BatchItem<T>> pending = this.batches.get();
        boolean nothingQueued = pending == null || pending.isEmpty();
        if (nothingQueued) {
            return;
        }
        doExec(pending, null);
        this.batches.get().clear();
    }

    /**
     * Executes whatever items are still pending in the calling thread's batch, forwarding the given
     * array to {@code doExec}, then empties the batch. Does nothing when there is no batch list or it is empty.
     *
     * @param iNDArray array handed to {@code doExec} as its second argument; may be {@code null}
     */
    @Override
    public void finish(INDArray iNDArray) {
        if (this.batches == null) {
            return;
        }
        List<BatchItem<T>> pending = this.batches.get();
        boolean nothingQueued = pending == null || pending.isEmpty();
        if (nothingQueued) {
            return;
        }
        doExec(pending, iNDArray);
        this.batches.get().clear();
    }

    /**
     * Queues one batch item per element of the sequence (after optional subsampling) and executes the
     * calling thread's batch once it has reached the configured batch size.
     *
     * @param sequence   sequence to learn from
     * @param atomicLong random state; advanced once per element (and again by subsampling, when enabled)
     * @param d          learning rate recorded on every queued item
     * @return always {@code 0.0}
     */
    @Override
    public double learnSequence(Sequence<T> sequence, AtomicLong atomicLong, double d) {
        Sequence<T> effective = this.sampling > 0.0d ? applySubsampling(sequence, atomicLong) : sequence;

        int currentWindow = this.window;
        if (this.variableWindows != null && this.variableWindows.length != 0) {
            currentWindow = this.variableWindows[RandomUtils.nextInt(0, this.variableWindows.length)];
        }

        List<T> elements = effective.getElements();
        for (int position = 0; position < elements.size(); position++) {
            atomicLong.set(Math.abs((atomicLong.get() * 25214903917L) + 11));
            // NOTE(review): the int cast can produce a negative value, so this offset may be negative — confirm intended.
            int offset = ((int) atomicLong.get()) % currentWindow;
            // cbow() appends to the list it is given, so it must receive a real list rather than null.
            // A fresh list per call is used because the receiving code merges the whole list into the
            // thread's batch each time; reusing one list would queue earlier items again.
            cbow(position, elements, offset, atomicLong, d, currentWindow, new ArrayList<BatchItem<T>>());
        }

        // getBatch() never returns null, so only the size needs checking.
        List<BatchItem<T>> pending = getBatch();
        if (pending.size() >= this.configuration.getBatchSize().intValue()) {
            doExec(pending, null);
            pending.clear();
        }
        return 0.0d;
    }

    /**
     * @return always {@code false}; this implementation never reports early termination
     */
    @Override
    public boolean isEarlyTerminationHit() {
        final boolean earlyTermination = false;
        return earlyTermination;
    }

    /**
     * Executes the CBOW op for the given batch items outside of any workspace.
     *
     * <p>A single item is run through {@code CbowInference}; two or more items are packed into arrays and
     * run through {@code CbowRound}. A {@code null} or empty list is a no-op.
     *
     * @param list     batch items to execute
     * @param iNDArray inference vector; when non-{@code null} the op runs for
     *                 {@code iterations * epochs} iterations instead of one
     * @return always {@code 0.0}
     */
    public double doExec(List<BatchItem<T>> list, INDArray iNDArray) {
        // Nothing to run; the single-item path would otherwise fail on list.get(0).
        if (list == null || list.isEmpty()) {
            return 0.0d;
        }
        try (MemoryWorkspace ignored = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
            boolean useHierarchicSoftmax = this.configuration.isUseHierarchicSoftmax().booleanValue();
            boolean useNegativeSampling = this.configuration.getNegative().doubleValue() > 0.0d;
            if (list.size() == 1) {
                execSingleItem(list.get(0), iNDArray, useHierarchicSoftmax, useNegativeSampling);
            } else {
                execBatch(list, iNDArray, useHierarchicSoftmax);
            }
            return 0.0d;
        }
    }

    /** Runs {@code CbowInference} for exactly one batch item. */
    private void execSingleItem(BatchItem<T> item, INDArray inferenceVector, boolean useHierarchicSoftmax, boolean useNegativeSampling) {
        T word = item.getWord();
        int[] context = item.getWindowWords().clone();
        boolean[] statuses = item.getWordStatuses().clone();
        int codeLength = word.getCodeLength();
        byte[] codes = new byte[codeLength];
        int[] points = new int[codeLength];
        long randomValue = item.getRandomValue();
        double alpha = item.getAlpha();
        int numLabel = item.getNumLabel();

        // 1 = locked, 0 = not locked, -1 = no status available for that context position.
        int[] lockedWords = new int[context.length];
        for (int i = 0; i < context.length; i++) {
            if (i < statuses.length) {
                lockedWords[i] = statuses[i] ? 1 : 0;
            } else {
                lockedWords[i] = -1;
            }
        }

        if (useHierarchicSoftmax) {
            for (int i = 0; i < codeLength; i++) {
                int point = word.getPoints().get(i).intValue();
                // Entries with a negative point are left at their zero defaults.
                if (point >= 0) {
                    codes[i] = word.getCodes().get(i).byteValue();
                    points[i] = point;
                }
            }
        }

        if (this.negative <= 0.0d) {
            this.syn1Neg.set(Nd4j.empty(this.syn0.get().dataType()));
        } else if (this.syn1Neg == null) {
            ((InMemoryLookupTable) this.lookupTable).initNegative();
            this.syn1Neg = new DeviceLocalNDArray(((InMemoryLookupTable) this.lookupTable).getSyn1Neg());
        }

        Nd4j.getExecutioner().exec(CbowInference.builder()
                .target(word.getIndex())
                .ngStarter(word.getIndex())
                .negTable(this.table.get() == null ? Nd4j.empty(this.syn0.get().dataType()) : this.table.get())
                .expTable(this.expTable.get() == null ? Nd4j.empty(this.syn0.get().dataType()) : this.expTable.get())
                .syn0(this.syn0.get())
                .syn1(!useHierarchicSoftmax ? Nd4j.empty(this.syn0.get().dataType()) : this.syn1.get())
                .syn1Neg(this.syn1Neg.get())
                .alpha(alpha)
                .context(context)
                .indices((useHierarchicSoftmax || useNegativeSampling) ? points : new int[0])
                .codes(useHierarchicSoftmax ? codes : new byte[0])
                .lockedWords(lockedWords)
                .randomValue((int) randomValue)
                .numWorkers(this.workers)
                .numLabels(numLabel)
                .nsRounds(useNegativeSampling ? (int) this.negative : 0)
                .preciseMode(this.configuration.isPreciseMode().booleanValue())
                .inferenceVector(inferenceVector != null ? inferenceVector : Nd4j.empty(this.syn0.get().dataType()))
                .iterations(inferenceVector != null ? this.configuration.getIterations().intValue() * this.configuration.getEpochs().intValue() : 1)
                .build());
    }

    /** Packs two or more batch items into reusable arrays and runs {@code CbowRound} on them. */
    private void execBatch(List<BatchItem<T>> items, INDArray inferenceVector, boolean useHierarchicSoftmax) {
        int maxCodeLength = -1;
        for (BatchItem<T> item : items) {
            maxCodeLength = Math.max(maxCodeLength, item.getWord().getCodeLength());
        }
        int maxCols = Math.max(1, maxCodeLength);
        // NOTE(review): the window-word column count is derived from the longest code length, not from the
        // window-word arrays themselves — confirm this is intended.
        int maxWinWordsCols = maxCodeLength;

        IterationArraysKey key = IterationArraysKey.builder().itemSize(items.size()).maxCols(maxCols).build();
        Queue<IterationArrays> queue = this.iterationArrays.getIfPresent(key);
        if (queue == null) {
            queue = new ConcurrentLinkedQueue<>();
            this.iterationArrays.put(key, queue);
        }
        // poll() returns null when another thread emptied the queue first, so no exception handling is
        // needed and every freshly built instance uses the same constructor.
        IterationArrays arrays = queue.poll();
        if (arrays == null) {
            arrays = new IterationArrays(items.size(), maxCols, maxWinWordsCols);
        } else {
            arrays.initCodes();
        }

        int[][] windowWords = arrays.inputWindowWordsArr;
        int[][] windowStatuses = arrays.inputWindowWordStatuses;
        int[] currentWindowIndexes = arrays.currentWindowIndexes;
        double[] alphas = arrays.alphas;
        int[][] indices = arrays.indicesArr;
        int[][] codes = arrays.codesArr;
        long[] randomValues = arrays.randomValues;
        int[] numLabels = arrays.numLabels;

        // Loop-invariant: make sure syn1Neg exists before the op reads it.
        if (this.negative > 0.0d && this.syn1Neg == null) {
            ((InMemoryLookupTable) this.lookupTable).initNegative();
            this.syn1Neg = new DeviceLocalNDArray(((InMemoryLookupTable) this.lookupTable).getSyn1Neg());
        }

        boolean hasNumLabels = false;
        INDArray targetArr = Nd4j.createFromArray(currentWindowIndexes);
        for (int row = 0; row < items.size(); row++) {
            BatchItem<T> item = items.get(row);
            T word = item.getWord();
            // NOTE(review): every item writes its index to slot 0 of the target array — confirm whether
            // slot `row` was intended.
            targetArr.putScalar(0L, word.getIndex());
            currentWindowIndexes[0] = word.getIndex();

            int[] itemWindow = item.getWindowWords();
            boolean[] itemStatuses = item.getWordStatuses();
            for (int col = 0; col < maxWinWordsCols; col++) {
                if (col < itemWindow.length) {
                    windowWords[row][col] = itemWindow[col];
                    windowStatuses[row][col] = itemStatuses[col] ? 1 : 0;
                } else {
                    // Padding for rows with fewer window words than columns.
                    windowWords[row][col] = -1;
                    windowStatuses[row][col] = -1;
                }
            }

            alphas[row] = item.getAlpha();
            randomValues[row] = item.getRandomValue();
            numLabels[row] = item.getNumLabel();
            if (item.getNumLabel() > 0) {
                hasNumLabels = true;
            }

            if (useHierarchicSoftmax) {
                for (int col = 0; col < word.getCodeLength(); col++) {
                    int point = word.getPoints().get(col).intValue();
                    if (point >= 0) {
                        codes[row][col] = word.getCodes().get(col).byteValue();
                        indices[row][col] = point;
                    }
                }
            }
        }

        INDArray contextArr = Nd4j.createFromArray(windowWords);
        INDArray lockedWordsArr = Nd4j.createFromArray(windowStatuses);
        INDArray numLabelsArr = Nd4j.createFromArray(numLabels);
        INDArray indicesArr = Nd4j.createFromArray(indices);
        INDArray codesArr = Nd4j.createFromArray(codes);
        INDArray alphasArr = Nd4j.createFromArray(alphas);
        INDArray randomValuesArr = Nd4j.createFromArray(randomValues);

        Nd4j.getExecutioner().exec(CbowRound.builder()
                .target(targetArr)
                .context(contextArr)
                .lockedWords(lockedWordsArr)
                .ngStarter(targetArr)
                .syn0(this.syn0.get())
                .syn1(useHierarchicSoftmax ? this.syn1.get() : Nd4j.empty(this.syn0.get().dataType()))
                .syn1Neg(this.negative > 0.0d ? this.syn1Neg.get() : Nd4j.empty(this.syn0.get().dataType()))
                .expTable(this.expTable.get())
                .negTable(this.negative > 0.0d ? this.table.get() : Nd4j.empty(this.syn0.get().dataType()))
                .indices(useHierarchicSoftmax ? indicesArr : Nd4j.empty(DataType.INT32))
                .codes(useHierarchicSoftmax ? codesArr : Nd4j.empty(DataType.INT8))
                .nsRounds((int) this.negative)
                .alpha(alphasArr)
                .nextRandom(randomValuesArr)
                .inferenceVector(inferenceVector != null ? inferenceVector : Nd4j.empty(this.syn0.get().dataType()))
                .numLabels(hasNumLabels ? numLabelsArr : Nd4j.empty(DataType.INT32))
                .trainWords(this.configuration.isTrainElementsVectors().booleanValue())
                .numWorkers(this.workers)
                .iterations(inferenceVector != null ? this.configuration.getIterations().intValue() * this.configuration.getEpochs().intValue() : 1)
                .build());

        // Release every temporary array created above, including the locked-words array.
        Nd4j.close(new INDArray[]{targetArr, contextArr, lockedWordsArr, alphasArr, randomValuesArr, codesArr, numLabelsArr, indicesArr});

        // Hand the scratch arrays back for reuse unless the queue is already full.
        if (queue.size() < this.maxQueueSize) {
            queue.add(arrays);
        }
        // getBatch() creates the thread's list on demand, so this is safe on a thread that never queued anything.
        getBatch().clear();
    }

    /**
     * Builds the batch item for the element at position {@code i} and queues it for execution.
     *
     * <p>The context consists of the in-range elements at positions {@code i - i3 + a} for
     * {@code a} in {@code [i2, window * 2 + 1 - i2)}, skipping {@code a == i3} (the element itself).
     *
     * @param i          position of the current element in {@code list}
     * @param list       elements of the sequence
     * @param i2         start offset of the window scan
     * @param atomicLong random state; its current value is stored on the batch item
     * @param d          learning rate stored on the batch item
     * @param i3         window size used to centre the scan on {@code i}
     * @param list2      list the new item is appended to before it is handed to the batching step;
     *                   may be {@code null}, in which case a temporary list is used
     */
    public void cbow(int i, List<T> list, int i2, AtomicLong atomicLong, double d, int i3, List<BatchItem<T>> list2) {
        // NOTE(review): the upper bound uses the configured window, not the i3 argument — confirm intended.
        int end = ((this.window * 2) + 1) - i2;
        T current = list.get(i);

        List<Integer> contextIndexes = new ArrayList<>();
        List<Boolean> contextLocked = new ArrayList<>();
        for (int a = i2; a < end; a++) {
            if (a == i3) {
                continue;
            }
            int position = (i - i3) + a;
            if (position >= 0 && position < list.size()) {
                T neighbour = list.get(position);
                contextIndexes.add(neighbour.getIndex());
                contextLocked.add(neighbour.isLocked());
            }
        }

        int[] windowWords = new int[contextIndexes.size()];
        boolean[] statuses = new boolean[windowWords.length];
        for (int k = 0; k < windowWords.length; k++) {
            windowWords[k] = contextIndexes.get(k);
            statuses[k] = contextLocked.get(k);
        }

        // Callers may pass no list at all; fall back to a temporary one instead of failing on add().
        List<BatchItem<T>> target = list2;
        if (target == null) {
            target = new ArrayList<>();
        }
        target.add(new BatchItem<>(current, windowWords, statuses, atomicLong.get(), d));
        iterateBatchesIfReady(target);
    }

    /**
     * Merges the given items into the calling thread's batch and executes the batch once it has
     * reached the configured batch size.
     *
     * <p>When the thread has no batch list yet, {@code list} itself becomes the thread's list.
     *
     * @param list items to merge
     * @return the value returned by {@code doExec} when the batch was executed, otherwise {@code 0.0}
     */
    private double iterateBatchesIfReady(List<BatchItem<T>> list) {
        List<BatchItem<T>> pending = this.batches.get();
        if (pending == null) {
            pending = list;
            this.batches.set(pending);
        } else if (pending != list) {
            // When the caller passes the thread's own list, its items are already queued;
            // addAll(self) would append a second copy of every item.
            pending.addAll(list);
        }

        if (pending.size() < this.configuration.getBatchSize().intValue()) {
            return 0.0d;
        }
        double result = doExec(pending, null);
        pending.clear();
        return result;
    }

    /**
     * Returns a copy of the sequence in which each element is kept or dropped at random.
     *
     * <p>For every element the keep threshold is
     * {@code (sqrt(freq / (sampling * total)) + 1) * (sampling * total) / freq}, where {@code freq} is the
     * element frequency and {@code total} is the vocabulary's total word occurrences. The random state is
     * advanced once per element and the element is kept when the threshold is at least the low 16 bits
     * of the state divided by 65536. When sampling is not positive the input sequence is returned as is.
     *
     * @param sequence   sequence to subsample; must not be {@code null}
     * @param atomicLong random state, advanced once per element; must not be {@code null}
     * @return the subsampled sequence carrying the original id and labels, or {@code sequence} itself
     * @throws NullPointerException if either argument is {@code null}
     */
    public Sequence<T> applySubsampling(@NonNull Sequence<T> sequence, @NonNull AtomicLong atomicLong) {
        if (sequence == null) {
            throw new NullPointerException("sequence is marked non-null but is null");
        }
        if (atomicLong == null) {
            throw new NullPointerException("nextRandom is marked non-null but is null");
        }
        if (this.sampling <= 0.0d) {
            return sequence;
        }

        Sequence<T> kept = new Sequence<>();
        kept.setSequenceId(sequence.getSequenceId());
        if (sequence.getSequenceLabels() != null) {
            kept.setSequenceLabels(sequence.getSequenceLabels());
        }
        if (sequence.getSequenceLabel() != null) {
            kept.setSequenceLabel(sequence.getSequenceLabel());
        }

        for (T element : sequence.getElements()) {
            double totalOccurrences = this.vocabCache.totalWordOccurrences();
            double scaledTotal = this.sampling * totalOccurrences;
            double frequency = element.getElementFrequency();
            double keepThreshold = ((Math.sqrt(frequency / scaledTotal) + 1.0d) * scaledTotal) / frequency;

            atomicLong.set(Math.abs((atomicLong.get() * 25214903917L) + 11));
            double draw = (atomicLong.get() & 65535) / 65536.0d;
            if (keepThreshold >= draw) {
                kept.addElement(element);
            }
        }
        return kept;
    }

    /** @return the device-local holder currently stored in {@code syn0} */
    @Generated
    public DeviceLocalNDArray getSyn0() {
        return syn0;
    }

    /** @return the device-local holder currently stored in {@code syn1} */
    @Generated
    public DeviceLocalNDArray getSyn1() {
        return syn1;
    }

    /** @return the device-local holder currently stored in {@code syn1Neg} */
    @Generated
    public DeviceLocalNDArray getSyn1Neg() {
        return syn1Neg;
    }

    /** @return the device-local holder currently stored in {@code expTable} */
    @Generated
    public DeviceLocalNDArray getExpTable() {
        return expTable;
    }

    /** @return the device-local holder currently stored in {@code table} */
    @Generated
    public DeviceLocalNDArray getTable() {
        return table;
    }

    /**
     * Replaces the holder stored in {@code syn0}.
     *
     * @param deviceLocalNDArray new holder
     */
    @Generated
    public void setSyn0(DeviceLocalNDArray deviceLocalNDArray) {
        syn0 = deviceLocalNDArray;
    }

    /**
     * Replaces the holder stored in {@code syn1}.
     *
     * @param deviceLocalNDArray new holder
     */
    @Generated
    public void setSyn1(DeviceLocalNDArray deviceLocalNDArray) {
        syn1 = deviceLocalNDArray;
    }

    /**
     * Replaces the holder stored in {@code syn1Neg}.
     *
     * @param deviceLocalNDArray new holder
     */
    @Generated
    public void setSyn1Neg(DeviceLocalNDArray deviceLocalNDArray) {
        syn1Neg = deviceLocalNDArray;
    }

    /**
     * Replaces the holder stored in {@code expTable}.
     *
     * @param deviceLocalNDArray new holder
     */
    @Generated
    public void setExpTable(DeviceLocalNDArray deviceLocalNDArray) {
        expTable = deviceLocalNDArray;
    }

    /**
     * Replaces the holder stored in {@code table}.
     *
     * @param deviceLocalNDArray new holder
     */
    @Generated
    public void setTable(DeviceLocalNDArray deviceLocalNDArray) {
        table = deviceLocalNDArray;
    }
}
