package org.deeplearning4j.text.documentiterator;

import java.util.List;
import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory;
import org.nd4j.shade.guava.collect.Lists;

/* loaded from: input_file:org/deeplearning4j/text/documentiterator/ShardedLabelAwareIterator.class */
/**
 * A {@link LabelAwareIterator} decorator that splits every document produced by an underlying
 * iterator into consecutive shards of at most {@code documentSizeLimit} tokens and serves each
 * shard as a separate {@link LabelledDocument}.
 *
 * <p>Each source document is tokenized with the configured {@link TokenizerFactory}; the token
 * list is partitioned with {@code Lists.partition} and the tokens of each partition are re-joined
 * with a single space to form the shard content.
 *
 * <p>The iterator keeps mutable cursor state ({@code docBatches}, {@code currentBatch}) without
 * any synchronization.
 */
public class ShardedLabelAwareIterator implements LabelAwareIterator {
    /** Source of the documents that get sharded. */
    private LabelAwareIterator subIterator;
    /** Maximum number of tokens per shard; passed as the partition size to {@code Lists.partition}. */
    private int documentSizeLimit;
    /** Creates the tokenizers used to split document content into tokens. */
    private TokenizerFactory tokenizerFactory;
    /** Token shards of the source document currently being served; {@code null} before the first document. */
    private List<List<String>> docBatches;
    /** Index into {@link #docBatches} of the next shard to serve. */
    private int currentBatch = 0;

    /**
     * Creates a sharding iterator on top of an existing iterator.
     *
     * @param subIterator the iterator supplying the documents to shard
     * @param tokenizerFactory factory for the tokenizers applied to each document's content
     * @param documentSizeLimit maximum number of tokens in a single shard
     */
    public ShardedLabelAwareIterator(LabelAwareIterator subIterator, TokenizerFactory tokenizerFactory, int documentSizeLimit) {
        this.subIterator = subIterator;
        this.documentSizeLimit = documentSizeLimit;
        this.tokenizerFactory = tokenizerFactory;
    }

    /**
     * Tokenizes the given document, partitions its tokens into shards and rewinds the shard cursor.
     *
     * @param labelledDocument the source document to split
     */
    private void shardDocument(LabelledDocument labelledDocument) {
        List<String> tokens = this.tokenizerFactory.create(labelledDocument.getContent()).getTokens();
        this.docBatches = Lists.partition(tokens, this.documentSizeLimit);
        this.currentBatch = 0;
    }

    /**
     * Makes sure {@link #currentBatch} points at an unserved shard, pulling and sharding further
     * documents from the sub-iterator while the current document is exhausted (or produced no
     * tokens at all).
     *
     * @return {@code true} if a shard is available at {@code currentBatch}; {@code false} if the
     *         sub-iterator ran out of documents first
     */
    private boolean advanceToAvailableBatch() {
        while (this.docBatches == null || this.docBatches.isEmpty() || this.currentBatch >= this.docBatches.size()) {
            if (!this.subIterator.hasNextDocument()) {
                return false;
            }
            shardDocument(this.subIterator.nextDocument());
        }
        return true;
    }

    /**
     * Reports whether another shard can be served.
     *
     * <p>This may consume documents from the sub-iterator in order to locate the next non-empty
     * one, but it never advances past a shard that has not been returned yet.
     *
     * @return {@code true} if {@link #nextDocument()} would return a non-null document
     */
    @Override
    public boolean hasNextDocument() {
        return advanceToAvailableBatch();
    }

    /**
     * Returns the next shard as a new {@link LabelledDocument}.
     *
     * @return the next shard, or {@code null} when the sub-iterator has no more documents
     */
    @Override
    public LabelledDocument nextDocument() {
        if (!advanceToAvailableBatch()) {
            return null;
        }
        LabelledDocument shard = new LabelledDocument();
        // NOTE(review): labels are taken from the sub-iterator's labels source rather than from
        // the source document that produced this shard - confirm this is the intended labelling.
        shard.setLabels(this.subIterator.getLabelsSource().getLabels());
        shard.setContent(String.join(" ", this.docBatches.get(this.currentBatch)));
        this.currentBatch++;
        return shard;
    }

    /** Resets the sub-iterator and discards the shards of the document currently being served. */
    @Override
    public void reset() {
        this.subIterator.reset();
        this.docBatches = null;
        this.currentBatch = 0;
    }

    /**
     * Returns the labels source of the sub-iterator.
     *
     * @return the value of {@code subIterator.getLabelsSource()}
     */
    @Override
    public LabelsSource getLabelsSource() {
        return this.subIterator.getLabelsSource();
    }

    /** Does nothing; in particular the sub-iterator is not shut down by this call. */
    @Override
    public void shutdown() {
    }

    /**
     * Equivalent to {@link #hasNextDocument()}.
     *
     * @return {@code true} if another shard is available
     */
    @Override
    public boolean hasNext() {
        return hasNextDocument();
    }

    /**
     * Equivalent to {@link #nextDocument()}.
     *
     * @return the next shard, or {@code null} when no documents remain
     */
    @Override
    public LabelledDocument next() {
        return nextDocument();
    }

    /** @return the iterator supplying the documents to shard */
    public LabelAwareIterator getSubIterator() {
        return this.subIterator;
    }

    /** @param subIterator the iterator supplying the documents to shard */
    public void setSubIterator(LabelAwareIterator subIterator) {
        this.subIterator = subIterator;
    }

    /** @return the maximum number of tokens per shard */
    public int getDocumentSizeLimit() {
        return this.documentSizeLimit;
    }

    /** @param documentSizeLimit the maximum number of tokens per shard; applies to documents sharded after this call */
    public void setDocumentSizeLimit(int documentSizeLimit) {
        this.documentSizeLimit = documentSizeLimit;
    }

    /** @return the factory for the tokenizers applied to document content */
    public TokenizerFactory getTokenizerFactory() {
        return this.tokenizerFactory;
    }

    /** @param tokenizerFactory the factory for the tokenizers applied to document content */
    public void setTokenizerFactory(TokenizerFactory tokenizerFactory) {
        this.tokenizerFactory = tokenizerFactory;
    }

    /** @return the token shards of the document currently being served, or {@code null} if none has been sharded yet */
    public List<List<String>> getDocBatches() {
        return this.docBatches;
    }

    /** @param docBatches replacement token shards; the shard cursor is left unchanged */
    public void setDocBatches(List<List<String>> docBatches) {
        this.docBatches = docBatches;
    }

    /** @return the index of the next shard to serve */
    public int getCurrentBatch() {
        return this.currentBatch;
    }

    /** @param currentBatch the index of the next shard to serve */
    public void setCurrentBatch(int currentBatch) {
        this.currentBatch = currentBatch;
    }
}
