package org.noear.solon.ai.rag.repository;

import java.io.IOException;
import java.io.UncheckedIOException;
import java.math.BigDecimal;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.noear.solon.Utils;
import org.noear.solon.ai.embedding.EmbeddingModel;
import org.noear.solon.ai.rag.Document;
import org.noear.solon.ai.rag.RepositoryLifecycle;
import org.noear.solon.ai.rag.RepositoryStorable;
import org.noear.solon.ai.rag.repository.chroma.ChromaClient;
import org.noear.solon.ai.rag.repository.chroma.CollectionResponse;
import org.noear.solon.ai.rag.repository.chroma.FilterTransformer;
import org.noear.solon.ai.rag.repository.chroma.QueryResponse;
import org.noear.solon.ai.rag.util.ListUtil;
import org.noear.solon.ai.rag.util.QueryCondition;
import org.noear.solon.ai.rag.util.SimilarityUtil;
import org.noear.solon.lang.Preview;

@Preview("3.1")
/**
 * Document repository backed by a Chroma collection.
 *
 * <p>On construction the repository looks up the configured collection by name and creates it
 * (with {@code hnsw:space = cosine}) if it does not exist yet. Documents are embedded with the
 * configured {@link EmbeddingModel} before being written, and searches embed the query text the
 * same way.
 *
 * <p>NOTE(review): the cached collection id is a plain mutable field with no synchronization;
 * confirm the intended threading model before sharing one instance across threads.
 */
public final class ChromaRepository implements RepositoryStorable, RepositoryLifecycle {
    /** Metadata key under which a document's URL is stored alongside its content. */
    private static final String METADATA_URL = "url";

    private final Builder config;
    private String collectionId;

    /**
     * Fluent builder for {@link ChromaRepository}; obtain one via
     * {@link ChromaRepository#builder(EmbeddingModel, ChromaClient)}.
     */
    public static class Builder {
        private final EmbeddingModel embeddingModel;
        private final ChromaClient client;
        private String collectionName;

        private Builder(EmbeddingModel embeddingModel, ChromaClient chromaClient) {
            this.collectionName = "solon_ai";
            this.embeddingModel = embeddingModel;
            this.client = chromaClient;
        }

        /**
         * Sets the name of the Chroma collection to use (defaults to {@code "solon_ai"}).
         *
         * @param collectionName collection name
         * @return this builder
         */
        public Builder collectionName(String collectionName) {
            this.collectionName = collectionName;
            return this;
        }

        /**
         * Builds the repository, resolving (or creating) the collection immediately.
         *
         * @return a ready-to-use repository
         * @throws UncheckedIOException if the collection cannot be found or created
         */
        public ChromaRepository build() {
            return new ChromaRepository(this);
        }
    }

    private ChromaRepository(Builder builder) {
        this.config = builder;
        try {
            initRepository();
        } catch (IOException e) {
            // UncheckedIOException is a RuntimeException, so existing callers are unaffected,
            // but the I/O nature of the failure is now visible in the type.
            throw new UncheckedIOException("Failed to initialize Chroma repository: " + e.getMessage(), e);
        }
    }

    /**
     * Resolves the collection id. The existing collection is looked up by name first, and a new
     * collection is created if none is found. Does nothing if the id is already known.
     *
     * @throws IOException if the collection can neither be found nor created
     */
    public void initRepository() throws IOException {
        if (collectionId != null) {
            return;
        }

        CollectionResponse existing = config.client.getCollectionStats(config.collectionName);
        if (existing != null) {
            collectionId = existing.getId();
        }

        if (collectionId == null) {
            createNewCollection();
        }

        if (collectionId == null) {
            throw new IOException("Failed to create or find collection: " + config.collectionName);
        }
    }

    /**
     * Creates the configured collection using cosine space and stores its id.
     * A null response leaves {@link #collectionId} unset, so {@link #initRepository()} reports a
     * descriptive IOException instead of this method failing with a NullPointerException.
     */
    private void createNewCollection() throws IOException {
        Map<String, Object> metadata = new HashMap<>();
        metadata.put("description", "Collection created by Solon AI");
        metadata.put("created_at", Long.valueOf(System.currentTimeMillis()));
        metadata.put("hnsw:space", "cosine");

        CollectionResponse created = config.client.createCollection(config.collectionName, metadata);
        if (created != null) {
            collectionId = created.getId();
        }
    }

    /**
     * @return whatever {@link ChromaClient#isHealthy()} reports
     */
    public boolean isHealthy() {
        return config.client.isHealthy();
    }

    /**
     * Embeds and stores the given documents. Documents without an id are assigned a random UUID
     * (mutating the passed {@link Document} instances). Work is done in the batches produced by
     * {@link ListUtil#partition}.
     *
     * @param list documents to insert; null or empty is a no-op
     * @throws IOException if the embedding model or the Chroma client fails
     */
    public void insert(List<Document> list) throws IOException {
        if (Utils.isEmpty(list)) {
            return;
        }

        for (Document document : list) {
            if (Utils.isEmpty(document.getId())) {
                document.id(Utils.uuid());
            }
        }

        for (List<Document> batch : ListUtil.partition(list)) {
            config.embeddingModel.embed(batch);
            addDocuments(batch);
        }
    }

    /**
     * Sends one already-embedded batch to Chroma. The document URL, if any, is copied into the
     * stored metadata under {@value #METADATA_URL} so it can be restored on search.
     */
    private void addDocuments(List<Document> batch) throws IOException {
        List<String> ids = new ArrayList<>(batch.size());
        List<List<Float>> embeddings = new ArrayList<>(batch.size());
        List<String> contents = new ArrayList<>(batch.size());
        List<Map<String, Object>> metadatas = new ArrayList<>(batch.size());

        for (Document document : batch) {
            ids.add(document.getId());
            embeddings.add(floatArrayToList(document.getEmbedding()));

            // Copy so the caller's metadata map is not modified; tolerate documents with no metadata.
            Map<String, Object> metadata = new HashMap<>();
            if (document.getMetadata() != null) {
                metadata.putAll(document.getMetadata());
            }
            if (!Utils.isEmpty(document.getUrl())) {
                metadata.put(METADATA_URL, document.getUrl());
            }
            metadatas.add(metadata);

            contents.add(document.getContent());
        }

        config.client.addDocuments(collectionId, ids, embeddings, contents, metadatas);
    }

    /**
     * Deletes documents by id.
     *
     * @param ids document ids; null or empty is a no-op
     * @throws IOException if the Chroma client fails
     */
    public void delete(String... ids) throws IOException {
        if (ids == null || ids.length == 0) {
            return;
        }
        config.client.deleteDocuments(collectionId, new ArrayList<>(Arrays.asList(ids)));
    }

    /**
     * Checks whether a document with the given id exists in the collection.
     *
     * @param id document id
     * @return {@code false} for an empty id, otherwise the client's answer
     * @throws IOException if the Chroma client fails
     */
    public boolean exists(String id) throws IOException {
        if (Utils.isEmpty(id)) {
            return false;
        }
        return config.client.documentExists(collectionId, id);
    }

    /**
     * Embeds the query text, runs a vector query against Chroma (with the transformed filter
     * expression and the condition's limit), and post-filters the hits with
     * {@link SimilarityUtil#refilter}.
     *
     * @param condition search condition; a null condition or null query yields an empty list
     * @return matching documents
     * @throws IOException wrapping any failure raised while embedding, querying or filtering
     */
    public List<Document> search(QueryCondition condition) throws IOException {
        if (condition == null || condition.getQuery() == null) {
            return new ArrayList<>();
        }

        try {
            List<Float> queryVector = floatArrayToList(config.embeddingModel.embed(condition.getQuery()));
            QueryResponse response = config.client.queryDocuments(
                    collectionId,
                    queryVector,
                    condition.getLimit(),
                    FilterTransformer.getInstance().transform(condition.getFilterExpression()));
            return SimilarityUtil.refilter(parseQueryResponse(response).stream(), condition);
        } catch (Exception e) {
            throw new IOException("Failed to search documents: " + e.getMessage(), e);
        }
    }

    /**
     * Flattens a Chroma query response (one row per query embedding) into documents.
     * Missing or short {@code documents}/{@code metadatas}/{@code distances} rows are tolerated
     * instead of causing a NullPointerException or an IndexOutOfBoundsException.
     *
     * @return the parsed documents; empty for a null or error response
     */
    private List<Document> parseQueryResponse(QueryResponse response) {
        List<Document> results = new ArrayList<>();
        if (response == null || response.hasError()) {
            return results;
        }

        List<List<String>> ids = response.getIds();
        if (ids == null || ids.isEmpty()) {
            return results;
        }
        List<List<String>> documents = response.getDocuments();
        List<List<Map<String, Object>>> metadatas = response.getMetadatas();
        List<List<BigDecimal>> distances = response.getDistances();

        for (int row = 0; row < ids.size(); row++) {
            List<String> idRow = ids.get(row);
            if (idRow == null || idRow.isEmpty()) {
                continue;
            }
            List<String> contentRow = elementAt(documents, row);
            List<Map<String, Object>> metadataRow = elementAt(metadatas, row);
            List<BigDecimal> distanceRow = elementAt(distances, row);

            for (int col = 0; col < idRow.size(); col++) {
                results.add(toDocument(
                        idRow.get(col),
                        elementAt(contentRow, col),
                        elementAt(metadataRow, col),
                        elementAt(distanceRow, col)));
            }
        }
        return results;
    }

    /**
     * Builds a single {@link Document} from one query hit.
     * The score is {@code 1 - clamp(distance, 0, 1)}. With cosine space (used for collections
     * created by this class), Chroma's distance is {@code 1 - cosine similarity}, so
     * negatively-correlated hits score 0. A missing distance yields score 0.
     * NOTE(review): collections found rather than created here may use another space; confirm
     * the score mapping is still appropriate for them.
     */
    private static Document toDocument(String id, String content, Map<String, Object> metadata,
                                       BigDecimal distance) {
        Map<String, Object> safeMetadata = metadata != null ? metadata : new HashMap<String, Object>();
        double score = distance == null
                ? 0.0d
                : 1.0d - Math.min(1.0d, Math.max(0.0d, distance.doubleValue()));

        Document document = new Document(id, content, safeMetadata, score);
        Object url = safeMetadata.get(METADATA_URL);
        if (url instanceof String) {
            document.url((String) url);
        }
        return document;
    }

    /** Returns {@code list.get(index)}, or {@code null} if the list is null or too short. */
    private static <T> T elementAt(List<T> list, int index) {
        if (list == null || index < 0 || index >= list.size()) {
            return null;
        }
        return list.get(index);
    }

    /**
     * Deletes the backing collection (if one is resolved) and forgets its id.
     *
     * @throws IOException if the Chroma client fails
     */
    public void dropRepository() throws IOException {
        if (collectionId != null) {
            config.client.deleteCollection(collectionId);
            collectionId = null;
        }
    }

    /** Boxes a float vector into a list; a null array becomes an empty list. */
    private static List<Float> floatArrayToList(float[] values) {
        if (values == null) {
            return new ArrayList<>();
        }
        List<Float> boxed = new ArrayList<>(values.length);
        for (float value : values) {
            boxed.add(Float.valueOf(value));
        }
        return boxed;
    }

    /**
     * Creates a builder for a repository using the given embedding model and Chroma client.
     *
     * @param embeddingModel model used to embed documents and queries
     * @param chromaClient   client used to talk to Chroma
     * @return a new builder
     */
    public static Builder builder(EmbeddingModel embeddingModel, ChromaClient chromaClient) {
        return new Builder(embeddingModel, chromaClient);
    }
}
