package org.noear.solon.ai.rag.repository;

import io.qdrant.client.PointIdFactory;
import io.qdrant.client.QdrantClient;
import io.qdrant.client.QueryFactory;
import io.qdrant.client.ValueFactory;
import io.qdrant.client.VectorsFactory;
import io.qdrant.client.WithPayloadSelectorFactory;
import io.qdrant.client.WithVectorsSelectorFactory;
import io.qdrant.client.grpc.Collections;
import io.qdrant.client.grpc.JsonWithInt;
import io.qdrant.client.grpc.Points;
import java.io.IOException;
import java.time.Duration;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.UUID;
import java.util.concurrent.ExecutionException;
import java.util.stream.Collectors;
import org.noear.snack.ONode;
import org.noear.solon.Utils;
import org.noear.solon.ai.embedding.EmbeddingModel;
import org.noear.solon.ai.rag.Document;
import org.noear.solon.ai.rag.RepositoryLifecycle;
import org.noear.solon.ai.rag.RepositoryStorable;
import org.noear.solon.ai.rag.repository.qdrant.FilterTransformer;
import org.noear.solon.ai.rag.repository.qdrant.QdrantValueUtil;
import org.noear.solon.ai.rag.util.ListUtil;
import org.noear.solon.ai.rag.util.QueryCondition;
import org.noear.solon.ai.rag.util.SimilarityUtil;
import org.noear.solon.lang.Preview;

@Preview("3.1")
/* loaded from: input_file:org/noear/solon/ai/rag/repository/QdrantRepository.class */
public class QdrantRepository implements RepositoryStorable, RepositoryLifecycle {
    private final Builder config;

    /* loaded from: input_file:org/noear/solon/ai/rag/repository/QdrantRepository$Builder.class */
    public static class Builder {
        private final EmbeddingModel embeddingModel;
        private final QdrantClient client;
        private String collectionName = "solon_ai";
        private String contentFieldName = "content";
        private String metadataFieldName = "metadata";

        public Builder(EmbeddingModel embeddingModel, QdrantClient qdrantClient) {
            this.embeddingModel = embeddingModel;
            this.client = qdrantClient;
        }

        public Builder collectionName(String str) {
            this.collectionName = str;
            return this;
        }

        public QdrantRepository build() {
            return new QdrantRepository(this);
        }
    }

    private QdrantRepository(Builder builder) {
        this.config = builder;
        initRepository();
    }

    public void initRepository() {
        try {
            if (!((Boolean) this.config.client.collectionExistsAsync(this.config.collectionName).get()).booleanValue()) {
                this.config.client.createCollectionAsync(this.config.collectionName, Collections.VectorParams.newBuilder().setSize(this.config.embeddingModel.dimensions()).setDistance(Collections.Distance.Cosine).build()).get();
            }
        } catch (IOException | InterruptedException | ExecutionException e) {
            throw new RuntimeException("Failed to initialize Qdrant repository", e);
        }
    }

    public void dropRepository() {
        try {
            this.config.client.deleteCollectionAsync(this.config.collectionName).get();
        } catch (InterruptedException | ExecutionException e) {
            throw new RuntimeException("Failed to drop Qdrant repository", e);
        }
    }

    public void insert(List<Document> list) throws IOException {
        if (Utils.isEmpty(list)) {
            return;
        }
        for (List list2 : ListUtil.partition(list, this.config.embeddingModel.batchSize())) {
            this.config.embeddingModel.embed(list2);
            try {
                this.config.client.upsertAsync(Points.UpsertPoints.newBuilder().setCollectionName(this.config.collectionName).addAllPoints((List) list2.stream().map(this::toPointStruct).collect(Collectors.toList())).build()).get();
            } catch (InterruptedException | ExecutionException e) {
                throw new IOException("Failed to insert documents from Qdrant", e);
            }
        }
    }

    public void delete(String... strArr) throws IOException {
        try {
            this.config.client.deleteAsync(this.config.collectionName, (List) Arrays.stream(strArr).map(str -> {
                return Points.PointId.newBuilder().setUuid(str).build();
            }).collect(Collectors.toList())).get();
        } catch (InterruptedException | ExecutionException e) {
            throw new IOException("Failed to delete documents from Qdrant", e);
        }
    }

    public boolean exists(String str) throws IOException {
        try {
            return ((List) this.config.client.retrieveAsync(Points.GetPoints.newBuilder().setCollectionName(this.config.collectionName).addIds(PointIdFactory.id(UUID.fromString(str))).build(), (Duration) null).get()).size() > 0;
        } catch (InterruptedException | ExecutionException e) {
            throw new IOException("Failed to check document existence in Qdrant", e);
        }
    }

    public List<Document> search(QueryCondition queryCondition) throws IOException {
        try {
            Points.QueryPoints.Builder withVectors = Points.QueryPoints.newBuilder().setCollectionName(this.config.collectionName).setQuery(QueryFactory.nearest(this.config.embeddingModel.embed(queryCondition.getQuery()))).setLimit(queryCondition.getLimit()).setScoreThreshold((float) queryCondition.getSimilarityThreshold()).setWithPayload(WithPayloadSelectorFactory.include(Arrays.asList(this.config.contentFieldName, this.config.metadataFieldName))).setWithVectors(WithVectorsSelectorFactory.enable(true));
            Points.Filter transform = FilterTransformer.getInstance().transform(queryCondition.getFilterExpression());
            if (transform != null) {
                withVectors.setFilter(transform);
            }
            return SimilarityUtil.refilter(((List) this.config.client.queryAsync(withVectors.build()).get()).stream().map(this::toDocument), queryCondition);
        } catch (InterruptedException | ExecutionException e) {
            throw new IOException("Failed to search documents in Qdrant", e);
        }
    }

    private Points.PointStruct toPointStruct(Document document) {
        if (document.getId() == null) {
            document.id(Utils.uuid());
        }
        Map<String, JsonWithInt.Value> fromMap = QdrantValueUtil.fromMap(document.getMetadata());
        fromMap.put(this.config.contentFieldName, ValueFactory.value(document.getContent()));
        if (document.getMetadata() != null) {
            fromMap.put("metadata", ValueFactory.value(ONode.stringify(document.getMetadata())));
        }
        return Points.PointStruct.newBuilder().setId(PointIdFactory.id(UUID.fromString(document.getId()))).setVectors(VectorsFactory.vectors(document.getEmbedding())).putAllPayload(fromMap).build();
    }

    private Document toDocument(Points.ScoredPoint scoredPoint) {
        String uuid = scoredPoint.getId().getUuid();
        float score = scoredPoint.getScore();
        Map payloadMap = scoredPoint.getPayloadMap();
        String stringValue = ((JsonWithInt.Value) payloadMap.get(this.config.contentFieldName)).getStringValue();
        Map map = null;
        if (payloadMap.containsKey(this.config.metadataFieldName)) {
            map = (Map) ONode.deserialize(((JsonWithInt.Value) payloadMap.get(this.config.metadataFieldName)).getStringValue(), Map.class);
        }
        return new Document(uuid, stringValue, map, score).embedding(listToFloatArray(scoredPoint.getVectors().getVector().getDataList()));
    }

    private float[] listToFloatArray(List<Float> list) {
        float[] fArr = new float[list.size()];
        for (int i = 0; i < list.size(); i++) {
            fArr[i] = list.get(i).floatValue();
        }
        return fArr;
    }

    public static Builder builder(EmbeddingModel embeddingModel, QdrantClient qdrantClient) {
        return new Builder(embeddingModel, qdrantClient);
    }
}
