package org.noear.solon.ai.rag.splitter;

import com.knuddels.jtokkit.Encodings;
import com.knuddels.jtokkit.api.Encoding;
import com.knuddels.jtokkit.api.EncodingRegistry;
import com.knuddels.jtokkit.api.EncodingType;
import com.knuddels.jtokkit.api.IntArrayList;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;

/* loaded from: input_file:org/noear/solon/ai/rag/splitter/TokenSizeTextSplitter.class */
/**
 * Splits text into chunks whose size is bounded by a number of tokens rather than characters.
 *
 * <p>The text is tokenized with a JTokkit {@link Encoding} ({@link EncodingType#CL100K_BASE}
 * unless changed through {@link #setEncodingType(EncodingType)}). Each window of at most
 * {@code chunkSize} tokens is decoded back to text and, where possible, cut after the last
 * {@code '.'}, {@code '?'}, {@code '!'} or line feed so that chunks end on a sentence or line
 * boundary.
 */
public class TokenSizeTextSplitter extends TextSplitter {
    private EncodingRegistry encodingRegistry;
    private EncodingType encodingType;
    private final int chunkSize;
    private final int minChunkSizeChars;
    private final int minChunkLengthToEmbed;
    private final int maxChunkCount;
    private final boolean keepSeparator;

    /** Creates a splitter with a chunk size of 500 tokens and the other defaults. */
    public TokenSizeTextSplitter() {
        this(500);
    }

    /**
     * Creates a splitter with the given chunk size and a minimum chunk size of 300 characters.
     *
     * @param chunkSize maximum number of tokens per chunk; must be positive
     */
    public TokenSizeTextSplitter(int chunkSize) {
        this(chunkSize, 300);
    }

    /**
     * Creates a splitter with {@code minChunkLengthToEmbed = 5}, {@code maxChunkCount = 1000}
     * and {@code keepSeparator = true}.
     *
     * @param chunkSize         maximum number of tokens per chunk; must be positive
     * @param minChunkSizeChars a sentence/line break is only used as the cut point when its
     *                          character index is greater than this value
     */
    public TokenSizeTextSplitter(int chunkSize, int minChunkSizeChars) {
        this(chunkSize, minChunkSizeChars, 5, 1000, true);
    }

    /**
     * Creates a fully configured splitter.
     *
     * @param chunkSize             maximum number of tokens per chunk; must be positive
     * @param minChunkSizeChars     a sentence/line break is only used as the cut point when its
     *                              character index is greater than this value
     * @param minChunkLengthToEmbed chunks whose trimmed length is not greater than this value
     *                              are dropped
     * @param maxChunkCount         maximum number of windows processed by the main loop; any
     *                              tokens left afterwards are considered as one final chunk
     * @param keepSeparator         when {@code false}, occurrences of
     *                              {@link System#lineSeparator()} inside a chunk are replaced
     *                              by a single space
     * @throws IllegalArgumentException if {@code chunkSize} is not positive
     */
    public TokenSizeTextSplitter(int chunkSize, int minChunkSizeChars, int minChunkLengthToEmbed,
                                 int maxChunkCount, boolean keepSeparator) {
        // A zero chunk size would make splitText spin forever on empty windows, and a
        // negative one would fail later inside subList; reject both up front.
        if (chunkSize <= 0) {
            throw new IllegalArgumentException("chunkSize must be positive: " + chunkSize);
        }
        this.encodingRegistry = Encodings.newLazyEncodingRegistry();
        this.encodingType = EncodingType.CL100K_BASE;
        this.chunkSize = chunkSize;
        this.minChunkSizeChars = minChunkSizeChars;
        this.minChunkLengthToEmbed = minChunkLengthToEmbed;
        this.maxChunkCount = maxChunkCount;
        this.keepSeparator = keepSeparator;
    }

    /**
     * Replaces the encoding registry used to look up the tokenizer.
     *
     * @param encodingRegistry the new registry; {@code null} is ignored and the current
     *                         registry is kept
     */
    public void setEncodingRegistry(EncodingRegistry encodingRegistry) {
        if (encodingRegistry != null) {
            this.encodingRegistry = encodingRegistry;
        }
    }

    /**
     * Replaces the encoding type used to look up the tokenizer.
     *
     * @param encodingType the new encoding type; {@code null} is ignored and the current
     *                     type is kept
     */
    public void setEncodingType(EncodingType encodingType) {
        if (encodingType != null) {
            this.encodingType = encodingType;
        }
    }

    /**
     * Splits {@code text} into token-bounded chunks.
     *
     * @param text the text to split; {@code null} or blank text yields an empty list
     * @return the chunks in document order, never {@code null}
     */
    @Override // org.noear.solon.ai.rag.splitter.TextSplitter
    protected List<String> splitText(String text) {
        Encoding encoding = this.encodingRegistry.getEncoding(this.encodingType);
        List<String> chunks = new ArrayList<>();
        if (text == null || text.trim().isEmpty()) {
            return chunks;
        }

        List<Integer> tokens = encodeTokens(encoding, text);
        int total = tokens.size();
        // Walk the token list with an offset instead of repeatedly re-slicing it, so that
        // views are never stacked on top of each other.
        int offset = 0;
        int processed = 0;
        while (offset < total && processed < this.maxChunkCount) {
            int windowSize = Math.min(this.chunkSize, total - offset);
            String chunkText = decodeTokens(encoding, tokens.subList(offset, offset + windowSize));

            if (chunkText.trim().isEmpty()) {
                // Whitespace-only window: skip it without counting it as a chunk.
                offset += windowSize;
                continue;
            }

            int cut = lastBreakIndex(chunkText);
            if (cut > 0 && cut > this.minChunkSizeChars) {
                chunkText = chunkText.substring(0, cut + 1);
            }

            String normalized = this.keepSeparator
                    ? chunkText.trim()
                    : chunkText.replace(System.lineSeparator(), " ").trim();
            if (normalized.length() > this.minChunkLengthToEmbed) {
                chunks.add(normalized);
            }

            // Advance by the token count of the text actually kept. Re-encoding may not
            // reproduce the original tokenization exactly, so make sure we always move
            // forward and never step past the end of the token list.
            int consumed = Math.max(1, encodeTokens(encoding, chunkText).size());
            offset += Math.min(consumed, total - offset);
            processed++;
        }

        if (offset < total) {
            // Tokens left over after hitting maxChunkCount. Line separators are always
            // flattened here, independent of keepSeparator.
            String tail = decodeTokens(encoding, tokens.subList(offset, total))
                    .replace(System.lineSeparator(), " ")
                    .trim();
            if (tail.length() > this.minChunkLengthToEmbed) {
                chunks.add(tail);
            }
        }
        return chunks;
    }

    /** Returns the index of the last '.', '?', '!' or line feed in {@code text}, or -1. */
    private static int lastBreakIndex(String text) {
        int sentenceEnd = Math.max(text.lastIndexOf('.'),
                Math.max(text.lastIndexOf('?'), text.lastIndexOf('!')));
        return Math.max(sentenceEnd, text.lastIndexOf('\n'));
    }

    /**
     * Encodes text into a list of token ids.
     *
     * @param encoding the tokenizer to use
     * @param text     the text to encode
     * @return the token ids as a boxed list
     * @throws NullPointerException if {@code text} is {@code null}
     */
    protected List<Integer> encodeTokens(Encoding encoding, String text) {
        Objects.requireNonNull(text, "text is null");
        return encoding.encode(text).boxed();
    }

    /**
     * Decodes a list of token ids back into text.
     *
     * @param encoding the tokenizer to use
     * @param tokens   the token ids to decode
     * @return the decoded text
     * @throws NullPointerException if {@code tokens} is {@code null}
     */
    protected String decodeTokens(Encoding encoding, List<Integer> tokens) {
        Objects.requireNonNull(tokens, "tokens is null");
        IntArrayList ids = new IntArrayList(tokens.size());
        for (Integer token : tokens) {
            ids.add(token);
        }
        return encoding.decode(ids);
    }
}
