package org.noear.solon.ai.rag.repository;

import com.google.gson.Gson;
import com.google.gson.JsonObject;
import io.milvus.v2.client.MilvusClientV2;
import io.milvus.v2.common.DataType;
import io.milvus.v2.common.IndexParam;
import io.milvus.v2.service.collection.request.AddFieldReq;
import io.milvus.v2.service.collection.request.CreateCollectionReq;
import io.milvus.v2.service.collection.request.DropCollectionReq;
import io.milvus.v2.service.collection.request.GetLoadStateReq;
import io.milvus.v2.service.collection.request.HasCollectionReq;
import io.milvus.v2.service.vector.request.DeleteReq;
import io.milvus.v2.service.vector.request.GetReq;
import io.milvus.v2.service.vector.request.InsertReq;
import io.milvus.v2.service.vector.request.SearchReq;
import io.milvus.v2.service.vector.request.data.FloatVec;
import io.milvus.v2.service.vector.response.SearchResp;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;
import org.noear.solon.Utils;
import org.noear.solon.ai.embedding.EmbeddingModel;
import org.noear.solon.ai.rag.Document;
import org.noear.solon.ai.rag.RepositoryLifecycle;
import org.noear.solon.ai.rag.RepositoryStorable;
import org.noear.solon.ai.rag.repository.milvus.FilterTransformer;
import org.noear.solon.ai.rag.util.ListUtil;
import org.noear.solon.ai.rag.util.QueryCondition;
import org.noear.solon.ai.rag.util.SimilarityUtil;
import org.noear.solon.lang.Preview;

/**
 * Document repository backed by a single Milvus collection.
 *
 * <p>The collection (default name {@code "solon_ai"}) is created on construction if it does
 * not already exist, with these fields:
 * <ul>
 *   <li>{@code id}: VarChar primary key, max length 64, not auto-generated;</li>
 *   <li>{@code embedding}: float vector sized by {@link EmbeddingModel#dimensions()}, with an
 *       IVF_FLAT / COSINE index;</li>
 *   <li>{@code content}: VarChar, max length 65535;</li>
 *   <li>{@code metadata}: JSON.</li>
 * </ul>
 *
 * <p>NOTE(review): {@link #insert(List)} serializes each {@link Document} with Gson as-is, so
 * Document's serialized field names are expected to line up with the schema above. Confirm
 * against {@code Document}.
 */
@Preview("3.1")
public class MilvusRepository implements RepositoryStorable, RepositoryLifecycle {
    private static final String FIELD_ID = "id";
    private static final String FIELD_EMBEDDING = "embedding";
    private static final String FIELD_CONTENT = "content";
    private static final String FIELD_METADATA = "metadata";

    private final Gson gson;
    private final EmbeddingModel embeddingModel;
    private final MilvusClientV2 client;
    private final String collectionName;

    /** Collects the settings for a {@link MilvusRepository}. */
    public static class Builder {
        private final EmbeddingModel embeddingModel;
        private final MilvusClientV2 client;
        private String collectionName = "solon_ai";

        /**
         * @param embeddingModel model used to embed documents and queries
         * @param milvusClientV2 Milvus client used for all collection operations
         */
        public Builder(EmbeddingModel embeddingModel, MilvusClientV2 milvusClientV2) {
            this.embeddingModel = embeddingModel;
            this.client = milvusClientV2;
        }

        /**
         * Sets the collection name. The default is {@code "solon_ai"}.
         *
         * @param str collection name
         * @return this builder
         */
        public Builder collectionName(String str) {
            this.collectionName = str;
            return this;
        }

        /**
         * Builds the repository. The Milvus collection is created if it does not already exist.
         *
         * @return a new repository
         */
        public MilvusRepository build() {
            return new MilvusRepository(this);
        }
    }

    private MilvusRepository(Builder builder) {
        this.gson = new Gson();
        // Copy the builder's values so later calls on the same Builder cannot retarget this
        // already-built repository.
        this.embeddingModel = Objects.requireNonNull(builder.embeddingModel, "embeddingModel");
        this.client = Objects.requireNonNull(builder.client, "client");
        if (Utils.isEmpty(builder.collectionName)) {
            throw new IllegalArgumentException("collectionName must not be empty");
        }
        this.collectionName = builder.collectionName;
        initRepository();
    }

    /**
     * Creates the collection, its schema and its indexes if the collection does not exist yet.
     * Does nothing when it already exists.
     *
     * @throws RuntimeException wrapping any failure raised while creating the collection
     */
    public void initRepository() {
        boolean alreadyExists = client.hasCollection(
                HasCollectionReq.builder().collectionName(collectionName).build());
        if (alreadyExists) {
            return;
        }
        try {
            CreateCollectionReq.CollectionSchema schema = client.createSchema();
            schema.addField(AddFieldReq.builder().fieldName(FIELD_ID).dataType(DataType.VarChar)
                    .maxLength(64).isPrimaryKey(true).autoID(false).build());
            schema.addField(AddFieldReq.builder().fieldName(FIELD_EMBEDDING).dataType(DataType.FloatVector)
                    .dimension(Integer.valueOf(embeddingModel.dimensions())).build());
            schema.addField(AddFieldReq.builder().fieldName(FIELD_CONTENT).dataType(DataType.VarChar)
                    .maxLength(65535).build());
            schema.addField(AddFieldReq.builder().fieldName(FIELD_METADATA).dataType(DataType.JSON).build());

            List<IndexParam> indexParams = new ArrayList<>();
            indexParams.add(IndexParam.builder().fieldName(FIELD_ID).build());
            indexParams.add(IndexParam.builder()
                    .fieldName(FIELD_EMBEDDING)
                    .indexName("embedding_index")
                    .indexType(IndexParam.IndexType.IVF_FLAT)
                    .metricType(IndexParam.MetricType.COSINE)
                    .extraParams(Collections.singletonMap("nlist", 128))
                    .build());

            client.createCollection(CreateCollectionReq.builder()
                    .collectionName(collectionName)
                    .collectionSchema(schema)
                    .indexParams(indexParams)
                    .build());
            // The returned load state is not inspected here; the call is kept as in the original.
            client.getLoadState(GetLoadStateReq.builder().collectionName(collectionName).build());
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    /** Drops the whole collection, including all stored documents. */
    public void dropRepository() {
        client.dropCollection(DropCollectionReq.builder().collectionName(collectionName).build());
    }

    /**
     * Embeds and inserts documents, processed in batches of {@link EmbeddingModel#batchSize()}.
     * Documents without an id are assigned a random UUID, which mutates the passed document.
     *
     * @param list documents to insert; {@code null} or empty is a no-op
     * @throws IOException if embedding fails
     */
    public void insert(List<Document> list) throws IOException {
        if (Utils.isEmpty(list)) {
            return;
        }
        for (List<Document> batch : ListUtil.partition(list, embeddingModel.batchSize())) {
            embeddingModel.embed(batch);
            List<JsonObject> rows = batch.stream().map(this::toJsonObject).collect(Collectors.toList());
            client.insert(InsertReq.builder().collectionName(collectionName).data(rows).build());
        }
    }

    /**
     * Deletes documents by id.
     *
     * @param strArr ids to delete; {@code null} or empty is a no-op, and no request is sent
     */
    public void delete(String... strArr) throws IOException {
        if (strArr == null || strArr.length == 0) {
            // Nothing to match, so skip the round trip instead of sending an empty id list.
            return;
        }
        client.delete(DeleteReq.builder().collectionName(collectionName).ids(Arrays.asList(strArr)).build());
    }

    /**
     * Checks whether a document with the given id is stored.
     *
     * @param str document id
     * @return {@code true} if at least one row is returned for the id; {@code false} for a
     *         {@code null} or empty id
     */
    public boolean exists(String str) throws IOException {
        if (Utils.isEmpty(str)) {
            return false;
        }
        return !client.get(GetReq.builder().collectionName(collectionName)
                .ids(Collections.singletonList(str)).build()).getGetResults().isEmpty();
    }

    /**
     * Runs a vector similarity search for the query text. The optional filter expression is
     * translated into a Milvus filter. Results are passed through {@link SimilarityUtil#refilter}.
     *
     * @param queryCondition query text, result limit (used as topK) and optional filter
     * @return matching documents with content, metadata and score
     * @throws IOException if embedding the query fails
     */
    public List<Document> search(QueryCondition queryCondition) throws IOException {
        SearchReq.SearchReqBuilder request = SearchReq.builder()
                .collectionName(collectionName)
                .data(Collections.singletonList(new FloatVec(embeddingModel.embed(queryCondition.getQuery()))))
                .topK(queryCondition.getLimit())
                .outputFields(Arrays.asList(FIELD_CONTENT, FIELD_METADATA));

        if (queryCondition.getFilterExpression() != null) {
            String filter = FilterTransformer.getInstance().transform(queryCondition.getFilterExpression());
            if (Utils.isNotEmpty(filter)) {
                request.filter(filter);
            }
        }

        SearchResp response = client.search(request.build());
        return SimilarityUtil.refilter(response.getSearchResults().stream()
                .flatMap(List::stream)
                .map(this::toDocument), queryCondition);
    }

    /** Serializes a document into a Milvus row, assigning a random UUID id when none is set. */
    private JsonObject toJsonObject(Document document) {
        if (document.getId() == null) {
            document.id(Utils.uuid());
        }
        return gson.toJsonTree(document).getAsJsonObject();
    }

    /** Converts one search hit back into a {@link Document}. */
    @SuppressWarnings("unchecked")
    private Document toDocument(SearchResp.SearchResult searchResult) {
        Map<String, Object> entity = searchResult.getEntity();
        String content = (String) entity.get(FIELD_CONTENT);
        JsonObject metadataJson = (JsonObject) entity.get(FIELD_METADATA);
        Map<String, Object> metadata = null;
        if (metadataJson != null) {
            metadata = (Map<String, Object>) gson.fromJson(metadataJson, Map.class);
        }
        return new Document((String) searchResult.getId(), content, metadata, searchResult.getScore().floatValue())
                .embedding(toFloatArray(entity.get(FIELD_EMBEDDING)));
    }

    /**
     * Converts a raw vector value to {@code float[]}. Returns {@code null} when the value is
     * absent or of an unexpected type.
     *
     * <p>{@code embedding} is not requested in {@code outputFields}, so it is normally absent.
     * If it is present, the SDK may return it as a {@code List} of numbers rather than
     * {@code float[]}, and a hard cast would fail.
     */
    private static float[] toFloatArray(Object raw) {
        if (raw instanceof float[]) {
            return (float[]) raw;
        }
        if (raw instanceof List) {
            List<?> values = (List<?>) raw;
            float[] vector = new float[values.size()];
            for (int i = 0; i < vector.length; i++) {
                vector[i] = ((Number) values.get(i)).floatValue();
            }
            return vector;
        }
        return null;
    }

    /**
     * Creates a builder.
     *
     * @param embeddingModel model used to embed documents and queries
     * @param milvusClientV2 Milvus client
     * @return a new builder
     */
    public static Builder builder(EmbeddingModel embeddingModel, MilvusClientV2 milvusClientV2) {
        return new Builder(embeddingModel, milvusClientV2);
    }
}
