package org.noear.solon.ai.rag.repository;

import com.tencent.tcvectordb.client.VectorDBClient;
import com.tencent.tcvectordb.model.Collection;
import com.tencent.tcvectordb.model.Database;
import com.tencent.tcvectordb.model.DocField;
import com.tencent.tcvectordb.model.Document;
import com.tencent.tcvectordb.model.param.collection.CreateCollectionParam;
import com.tencent.tcvectordb.model.param.collection.Embedding;
import com.tencent.tcvectordb.model.param.collection.FieldType;
import com.tencent.tcvectordb.model.param.collection.FilterIndex;
import com.tencent.tcvectordb.model.param.collection.HNSWParams;
import com.tencent.tcvectordb.model.param.collection.IndexType;
import com.tencent.tcvectordb.model.param.collection.MetricType;
import com.tencent.tcvectordb.model.param.collection.ParamsSerializer;
import com.tencent.tcvectordb.model.param.collection.VectorIndex;
import com.tencent.tcvectordb.model.param.dml.DeleteParam;
import com.tencent.tcvectordb.model.param.dml.HNSWSearchParams;
import com.tencent.tcvectordb.model.param.dml.InsertParam;
import com.tencent.tcvectordb.model.param.dml.QueryParam;
import com.tencent.tcvectordb.model.param.dml.SearchByEmbeddingItemsParam;
import com.tencent.tcvectordb.model.param.entity.AffectRes;
import com.tencent.tcvectordb.model.param.entity.SearchRes;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import org.noear.solon.Utils;
import org.noear.solon.ai.rag.Document;
import org.noear.solon.ai.rag.RepositoryLifecycle;
import org.noear.solon.ai.rag.RepositoryStorable;
import org.noear.solon.ai.rag.repository.tcvectordb.EmbeddingModelEnum;
import org.noear.solon.ai.rag.repository.tcvectordb.FilterTransformer;
import org.noear.solon.ai.rag.repository.tcvectordb.MetadataField;
import org.noear.solon.ai.rag.util.ListUtil;
import org.noear.solon.ai.rag.util.QueryCondition;
import org.noear.solon.ai.rag.util.SimilarityUtil;
import org.noear.solon.lang.Preview;

@Preview("3.1")
/* loaded from: input_file:org/noear/solon/ai/rag/repository/TcVectorDbRepository.class */
public class TcVectorDbRepository implements RepositoryStorable, RepositoryLifecycle {
    /** Name of the document field that carries the raw text; also configured as the embedding source field. */
    public static final String TEXT_FIELD_NAME = "__text";
    /** Name of the vector field used for the vector index and as the embedding target field. */
    public static final String VECTOR_FIELD_NAME = "vector";
    /** Builder that produced this repository; read as the configuration source by every operation. */
    private final Builder config;
    /** Collection handle obtained by initRepository(); set back to null by dropRepository(). */
    private Collection collection;
    /** True once initRepository() has completed successfully; cleared by dropRepository(). */
    private boolean initialized;

    /* loaded from: input_file:org/noear/solon/ai/rag/repository/TcVectorDbRepository$Builder.class */
    /**
     * Fluent configuration holder for {@link TcVectorDbRepository}.
     *
     * <p>Every option except the client has a default. Setters that accept object values
     * ignore {@code null} (and empty strings for names) so that the default stays in effect.
     */
    public static class Builder {
        private final VectorDBClient client;
        private EmbeddingModelEnum embeddingModel = EmbeddingModelEnum.BGE_LARGE_ZH_V1P5;
        private String databaseName = "solon_ai_db";
        private String collectionName = "solon_ai";
        private int shardNum = 1;
        private int replicaNum = 0;
        private MetricType metricType = MetricType.COSINE;
        private IndexType indexType = IndexType.HNSW;
        private ParamsSerializer indexParams = null;
        private List<MetadataField> metadataFields = new ArrayList<MetadataField>();
        private int hnswM = 16;
        private int hnswSearchEf = 500;
        private int hnswConstructionEf = 400;
        private int embeddingBatchSize = 10;

        /**
         * Creates a builder bound to the given client.
         *
         * @param vectorDBClient client used for all database operations
         * @throws IllegalArgumentException if {@code vectorDBClient} is null
         */
        public Builder(VectorDBClient vectorDBClient) {
            if (vectorDBClient == null) {
                throw new IllegalArgumentException("Client must not be null or empty");
            }
            this.client = vectorDBClient;
        }

        /** Sets the database name; an empty value keeps the current one. */
        public Builder databaseName(String str) {
            if (Utils.isNotEmpty(str)) {
                this.databaseName = str;
            }
            return this;
        }

        /** Sets the collection name; an empty value keeps the current one. */
        public Builder collectionName(String str) {
            if (Utils.isNotEmpty(str)) {
                this.collectionName = str;
            }
            return this;
        }

        /** Sets the embedding model; {@code null} keeps the current one. */
        public Builder embeddingModel(EmbeddingModelEnum embeddingModelEnum) {
            if (embeddingModelEnum != null) {
                this.embeddingModel = embeddingModelEnum;
            }
            return this;
        }

        /** Sets the shard count passed to the collection on creation. */
        public Builder shardNum(int i) {
            this.shardNum = i;
            return this;
        }

        /** Sets the replica count passed to the collection on creation. */
        public Builder replicaNum(int i) {
            this.replicaNum = i;
            return this;
        }

        /** Sets the vector metric type; {@code null} keeps the current one. */
        public Builder metricType(MetricType metricType) {
            if (metricType != null) {
                this.metricType = metricType;
            }
            return this;
        }

        /** Sets the vector index type; {@code null} keeps the current one. */
        public Builder indexType(IndexType indexType) {
            if (indexType != null) {
                this.indexType = indexType;
            }
            return this;
        }

        /**
         * Sets explicit index parameters; {@code null} keeps the current value. When no
         * parameters are set, the repository builds HNSW parameters from {@code hnswM} and
         * {@code hnswConstructionEf}.
         */
        public Builder indexParams(ParamsSerializer paramsSerializer) {
            if (paramsSerializer != null) {
                this.indexParams = paramsSerializer;
            }
            return this;
        }

        /** Sets the HNSW {@code M} value used when no explicit index parameters are given. */
        public Builder hnswM(int i) {
            this.hnswM = i;
            return this;
        }

        /** Sets the HNSW {@code ef} value used for searches. */
        public Builder hnswSearchEf(int i) {
            this.hnswSearchEf = i;
            return this;
        }

        /** Sets the HNSW construction {@code ef} used when no explicit index parameters are given. */
        public Builder hnswConstructionEf(int i) {
            this.hnswConstructionEf = i;
            return this;
        }

        /**
         * Replaces the metadata fields that get a filter index when the collection is created.
         *
         * @param list fields to index; {@code null} clears the configured fields
         */
        public Builder metadataFields(List<MetadataField> list) {
            // Copy instead of aliasing: addMetadataField must neither mutate the caller's list
            // nor fail when the caller passed a fixed-size/immutable list or null.
            List<MetadataField> copy = new ArrayList<MetadataField>();
            if (list != null) {
                copy.addAll(list);
            }
            this.metadataFields = copy;
            return this;
        }

        /** Appends one metadata field to index; {@code null} is ignored like in the other setters. */
        public Builder addMetadataField(MetadataField metadataField) {
            if (metadataField != null) {
                this.metadataFields.add(metadataField);
            }
            return this;
        }

        /** Sets the batch size used to partition inserts; non-positive values keep the current one. */
        public Builder embeddingBatchSize(int i) {
            if (i > 0) {
                this.embeddingBatchSize = i;
            }
            return this;
        }

        /**
         * Creates the repository from the current settings.
         * The repository constructor runs {@code initRepository()} immediately.
         */
        public TcVectorDbRepository build() {
            return new TcVectorDbRepository(this);
        }
    }

    /**
     * Builds a repository from the given configuration and initializes it right away.
     *
     * @param builder configuration source, kept for the lifetime of the repository
     */
    private TcVectorDbRepository(Builder builder) {
        this.config = builder;
        this.initialized = false;
        this.initRepository();
    }

    /**
     * Ensures the configured database and collection exist and caches the collection handle.
     *
     * <p>Does nothing when the repository is already initialized. A missing database is created;
     * a missing collection is created with a primary-key index on {@code id}, the vector index,
     * one filter index per configured metadata field, and the embedding configuration.
     *
     * @throws RuntimeException wrapping any failure raised while talking to the client
     */
    public void initRepository() {
        if (this.initialized) {
            return;
        }
        try {
            final String dbName = this.config.databaseName;
            final String collName = this.config.collectionName;

            Database database;
            if (this.config.client.listDatabase().contains(dbName)) {
                database = this.config.client.database(dbName);
            } else {
                database = this.config.client.createDatabase(dbName);
            }

            boolean collectionMissing = database.listCollections().stream()
                    .noneMatch(existing -> collName.equals(existing.getCollection()));
            if (collectionMissing) {
                CreateCollectionParam.Builder params = CreateCollectionParam.newBuilder()
                        .withName(collName)
                        .withShardNum(this.config.shardNum)
                        .withReplicaNum(this.config.replicaNum)
                        .withDescription("Collection created by Solon AI")
                        .addField(new FilterIndex("id", FieldType.String, IndexType.PRIMARY_KEY));
                params.addField(getVectorIndex());
                for (MetadataField field : this.config.metadataFields) {
                    params.addField(new FilterIndex(field.getName(), field.getFieldType(), IndexType.FILTER));
                }
                // The service embeds the text field into the vector field with the configured model.
                params.withEmbedding(Embedding.newBuilder()
                        .withVectorField(VECTOR_FIELD_NAME)
                        .withField(TEXT_FIELD_NAME)
                        .withModelName(this.config.embeddingModel.getModelName())
                        .build());
                database.createCollection(params.build());
            }

            this.collection = database.describeCollection(collName);
            this.initialized = true;
        } catch (Exception e) {
            throw new RuntimeException("Failed to initialize VectorDB repository: " + e.getMessage(), e);
        }
    }

    /**
     * Drops the configured collection and marks this repository as uninitialized.
     *
     * <p>When the configured database does not exist there is nothing to drop, so only the
     * local state is reset; the database is never created as a side effect of dropping.
     */
    public void dropRepository() {
        if (this.config.client.listDatabase().contains(this.config.databaseName)) {
            this.config.client.database(this.config.databaseName).dropCollection(this.config.collectionName);
        }
        this.collection = null;
        this.initialized = false;
    }

    /**
     * Builds the vector index definition for {@link #VECTOR_FIELD_NAME}.
     *
     * <p>Uses the explicitly configured index parameters when present; otherwise falls back to
     * HNSW parameters built from the configured {@code hnswM} and {@code hnswConstructionEf}.
     *
     * @return the vector index sized to the embedding model's dimension
     */
    private VectorIndex getVectorIndex() {
        Integer dimension = Integer.valueOf(this.config.embeddingModel.getDimension());
        if (this.config.indexParams == null) {
            HNSWParams defaults = new HNSWParams(this.config.hnswM, this.config.hnswConstructionEf);
            return new VectorIndex(VECTOR_FIELD_NAME, dimension, this.config.indexType, this.config.metricType, defaults);
        }
        return new VectorIndex(VECTOR_FIELD_NAME, dimension, this.config.indexType, this.config.metricType, this.config.indexParams);
    }

    /**
     * Upserts the given documents into the collection.
     *
     * <p>Documents with an empty id are assigned one from {@code Utils.uuid()} (the passed-in
     * objects are modified). The list is split with {@code ListUtil.partition} using the
     * configured embedding batch size and each part is upserted in its own request. The content
     * is stored both as the document body and in {@link #TEXT_FIELD_NAME}; every metadata entry
     * becomes an additional document field.
     *
     * @param list documents to store; {@code null} or empty is a no-op
     * @throws IOException if the client call fails or the service reports a non-zero result code
     */
    public void insert(List<org.noear.solon.ai.rag.Document> list) throws IOException {
        if (list == null || list.isEmpty()) {
            return;
        }
        for (org.noear.solon.ai.rag.Document document : list) {
            if (Utils.isEmpty(document.getId())) {
                document.id(Utils.uuid());
            }
        }
        for (List<org.noear.solon.ai.rag.Document> batch : ListUtil.partition(list, this.config.embeddingBatchSize)) {
            List<com.tencent.tcvectordb.model.Document> rows = new ArrayList<com.tencent.tcvectordb.model.Document>(batch.size());
            for (org.noear.solon.ai.rag.Document source : batch) {
                com.tencent.tcvectordb.model.Document.Builder row = com.tencent.tcvectordb.model.Document.newBuilder()
                        .withId(source.getId())
                        .withDoc(source.getContent())
                        .addDocField(new DocField(TEXT_FIELD_NAME, source.getContent()));
                Map<String, Object> metadata = source.getMetadata();
                if (metadata != null && !metadata.isEmpty()) {
                    for (Map.Entry<String, Object> entry : metadata.entrySet()) {
                        row.addDocField(new DocField(entry.getKey(), entry.getValue()));
                    }
                }
                rows.add(row.build());
            }

            AffectRes upsert;
            try {
                upsert = this.collection.upsert(InsertParam.newBuilder().addAllDocument(rows).withBuildIndex(true).build());
            } catch (Exception e) {
                // Same contract as delete/exists/search: client failures surface as IOException.
                throw new IOException("Failed to insert documents: " + e.getMessage(), e);
            }
            // Checked outside the try block so this IOException is not wrapped a second time.
            if (upsert.getCode() != 0) {
                throw new IOException("Failed to insert documents: " + upsert.getMsg());
            }
        }
    }

    /**
     * Deletes the documents with the given ids.
     *
     * @param strArr ids to delete; {@code null} or empty is a no-op
     * @throws IOException if the client call fails or the service reports a non-zero result code
     */
    public void delete(String... strArr) throws IOException {
        if (strArr == null || strArr.length == 0) {
            return;
        }
        AffectRes result;
        try {
            result = this.collection.delete(DeleteParam.newBuilder().addAllDocumentId(Arrays.asList(strArr)).build());
        } catch (Exception e) {
            throw new IOException("Failed to delete documents: " + e.getMessage(), e);
        }
        // Mirror insert(): a non-zero result code is a failure even when no exception was thrown.
        // Checked outside the try block so this IOException is not wrapped a second time.
        if (result != null && result.getCode() != 0) {
            throw new IOException("Failed to delete documents: " + result.getMsg());
        }
    }

    /**
     * Reports whether a document with the given id is stored in the collection.
     *
     * @param str document id to look up
     * @return {@code true} if a query for that id returns at least one document; {@code false}
     *         for a null or empty id, without querying
     * @throws IOException if the query fails
     */
    public boolean exists(String str) throws IOException {
        // insert() replaces empty ids with a generated one, so an empty id can never match.
        if (Utils.isEmpty(str)) {
            return false;
        }
        try {
            QueryParam query = QueryParam.newBuilder()
                    .withDocumentIds(Collections.singletonList(str))
                    .withLimit(1L)
                    .build();
            return Utils.isNotEmpty(this.collection.query(query));
        } catch (Exception e) {
            throw new IOException("Failed to check document existence: " + e.getMessage(), e);
        }
    }

    /**
     * Runs a text-embedding similarity search for the given condition.
     *
     * <p>The query text is sent as the single embedding item, the configured HNSW search
     * {@code ef} is applied, and a non-positive limit falls back to 10. A filter expression,
     * when present, is translated by {@code FilterTransformer}. The hits are converted to
     * repository documents and passed through {@code SimilarityUtil.refilter}.
     *
     * @param queryCondition search condition
     * @return documents returned by {@code SimilarityUtil.refilter}
     * @throws IllegalArgumentException if {@code queryCondition} is null
     * @throws IOException if building or executing the search fails
     */
    public List<org.noear.solon.ai.rag.Document> search(QueryCondition queryCondition) throws IOException {
        if (queryCondition == null) {
            throw new IllegalArgumentException("QueryCondition must not be null");
        }
        try {
            int limit = queryCondition.getLimit();
            if (limit <= 0) {
                limit = 10;
            }
            SearchByEmbeddingItemsParam.Builder request = SearchByEmbeddingItemsParam.newBuilder()
                    .withEmbeddingItems(Collections.singletonList(queryCondition.getQuery()))
                    .withParams(new HNSWSearchParams(this.config.hnswSearchEf))
                    .withLimit(limit);
            if (queryCondition.getFilterExpression() != null) {
                request.withFilter(FilterTransformer.getInstance().transform(queryCondition.getFilterExpression()));
            }
            SearchRes response = this.collection.searchByEmbeddingItems(request.build());
            List<org.noear.solon.ai.rag.Document> hits = getDocuments(response);
            return SimilarityUtil.refilter(hits.stream(), queryCondition);
        } catch (Exception e) {
            throw new IOException("Failed to search documents: " + e.getMessage(), e);
        }
    }

    /**
     * Flattens a search response into repository documents.
     *
     * <p>Each hit contributes its id, body, score, and the metadata produced by
     * {@link #toMetadata(List)} from its document fields.
     *
     * @param searchRes response returned by the collection search
     * @return the converted hits in response order; empty when the response holds no documents
     */
    private static List<org.noear.solon.ai.rag.Document> getDocuments(SearchRes searchRes) {
        List<org.noear.solon.ai.rag.Document> results = new ArrayList<>();
        if (!Utils.isEmpty(searchRes.getDocuments())) {
            // The response holds one list of hits per embedding item; all lists are merged.
            for (List<com.tencent.tcvectordb.model.Document> group : searchRes.getDocuments()) {
                for (com.tencent.tcvectordb.model.Document hit : group) {
                    Map<String, Object> metadata = toMetadata(hit.getDocFields());
                    results.add(new org.noear.solon.ai.rag.Document(hit.getId(), hit.getDoc(), metadata, hit.getScore().doubleValue()));
                }
            }
        }
        return results;
    }

    /**
     * Converts document fields into a metadata map, leaving out {@link #TEXT_FIELD_NAME}.
     *
     * @param list document fields of a hit; may be null or empty
     * @return a new mutable map from field name to field value, never null
     */
    private static Map<String, Object> toMetadata(List<DocField> list) {
        Map<String, Object> metadata = new HashMap<>();
        if (list != null) {
            for (DocField field : list) {
                String name = field.getName();
                if (TEXT_FIELD_NAME.equals(name)) {
                    // The text field duplicates the document content, so it is not metadata.
                    continue;
                }
                metadata.put(name, field.getValue());
            }
        }
        return metadata;
    }

    /**
     * Entry point for configuring a repository.
     *
     * @param vectorDBClient client the repository will use
     * @return a new builder bound to {@code vectorDBClient}
     * @throws IllegalArgumentException if {@code vectorDBClient} is null
     */
    public static Builder builder(VectorDBClient vectorDBClient) {
        Builder configured = new Builder(vectorDBClient);
        return configured;
    }
}
