package dev.langchain4j.community.store.embedding.clickhouse;

import com.clickhouse.client.api.Client;
import com.clickhouse.client.api.data_formats.internal.BinaryStreamReader;
import com.clickhouse.client.api.insert.InsertResponse;
import com.clickhouse.client.api.metrics.ServerMetrics;
import com.clickhouse.client.api.query.GenericRecord;
import com.clickhouse.client.api.query.Records;
import com.clickhouse.data.ClickHouseDataType;
import com.clickhouse.data.ClickHouseFormat;
import dev.langchain4j.data.document.Metadata;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.internal.Utils;
import dev.langchain4j.internal.ValidationUtils;
import dev.langchain4j.store.embedding.EmbeddingMatch;
import dev.langchain4j.store.embedding.EmbeddingSearchRequest;
import dev.langchain4j.store.embedding.EmbeddingSearchResult;
import dev.langchain4j.store.embedding.EmbeddingStore;
import dev.langchain4j.store.embedding.RelevanceScore;
import dev.langchain4j.store.embedding.filter.Filter;
import dev.langchain4j.store.embedding.filter.MetadataFilterBuilder;
import java.io.ByteArrayInputStream;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * {@link EmbeddingStore} backed by a ClickHouse table.
 *
 * <p>On construction the store issues {@code CREATE DATABASE IF NOT EXISTS} and {@code CREATE TABLE IF NOT EXISTS}
 * for the database/table named in {@link ClickHouseSettings}. The table has an id column, a nullable text column, an
 * {@code Array(Float64)} embedding column whose length is constrained to the configured dimension and indexed with a
 * cosine-distance {@code vector_similarity} index, and one nullable column per configured metadata key. Searches rank
 * rows by {@code cosineDistance} to the query vector.
 *
 * <p>When no {@link Client} is supplied, one is built from the settings' URL and credentials. {@link #close()} closes
 * whichever client the store holds, including a caller-supplied one.
 */
public class ClickHouseEmbeddingStore implements EmbeddingStore<TextSegment>, AutoCloseable {

    private static final Logger log = LoggerFactory.getLogger(ClickHouseEmbeddingStore.class);

    /** Alias under which the search query returns the cosine distance of each row. */
    private static final String DISTANCE_ALIAS = "dist";

    private final Client client;
    private final ClickHouseSettings settings;
    private final ClickHouseMetadataFilterMapper filterMapper;

    /** Fluent builder for {@link ClickHouseEmbeddingStore}; fields left unset are passed to the constructor as {@code null}. */
    public static class Builder {
        private Client client;
        private ClickHouseSettings settings;

        /**
         * Sets the client to use; leave unset to let the store build one from the settings.
         *
         * @param client ClickHouse client
         * @return this builder
         */
        public Builder client(Client client) {
            this.client = client;
            return this;
        }

        /**
         * Sets the connection, schema and timeout settings (required).
         *
         * @param clickHouseSettings store settings
         * @return this builder
         */
        public Builder settings(ClickHouseSettings clickHouseSettings) {
            this.settings = clickHouseSettings;
            return this;
        }

        /**
         * Creates the store; this connects to ClickHouse and creates the database and table if missing.
         *
         * @return a new store
         */
        public ClickHouseEmbeddingStore build() {
            return new ClickHouseEmbeddingStore(this.client, this.settings);
        }
    }

    /**
     * Creates the store and makes sure its database and table exist.
     *
     * @param client             client to use, or {@code null} to build one from {@code clickHouseSettings}
     * @param clickHouseSettings connection, schema and timeout settings; validated with
     *                           {@code ValidationUtils.ensureNotNull}
     */
    public ClickHouseEmbeddingStore(Client client, ClickHouseSettings clickHouseSettings) {
        this.settings = ValidationUtils.ensureNotNull(clickHouseSettings, "settings");
        this.filterMapper = new ClickHouseMetadataFilterMapper(
                clickHouseSettings.getColumnMap(), clickHouseSettings.getMetadataTypeMap());
        // Build the fallback client only when needed: Optional.orElse(...) evaluated its argument eagerly and
        // created (and leaked) an extra Client even when the caller supplied one.
        this.client = client != null ? client : buildDefaultClient(clickHouseSettings);
        createDatabase();
        createTable();
    }

    public static Builder builder() {
        return new Builder();
    }

    /** Closes the underlying client (also when it was supplied by the caller). */
    @Override
    public void close() throws Exception {
        this.client.close();
    }

    @Override
    public String add(Embedding embedding) {
        String id = Utils.randomUUID();
        add(id, embedding);
        return id;
    }

    @Override
    public void add(String id, Embedding embedding) {
        addInternal(id, embedding, null);
    }

    @Override
    public String add(Embedding embedding, TextSegment textSegment) {
        String id = Utils.randomUUID();
        addInternal(id, embedding, textSegment);
        return id;
    }

    @Override
    public List<String> addAll(List<Embedding> embeddings) {
        List<String> ids = randomIds(embeddings.size());
        addAll(ids, embeddings, null);
        return ids;
    }

    @Override
    public List<String> addAll(List<Embedding> embeddings, List<TextSegment> embedded) {
        List<String> ids = randomIds(embeddings.size());
        addAll(ids, embeddings, embedded);
        return ids;
    }

    /**
     * Returns the stored entries closest (by cosine distance) to the request's query embedding.
     *
     * <p>At most {@code maxResults} rows are fetched (optionally restricted by the request filter); rows whose
     * relevance score is below {@code minScore} are dropped afterwards.
     *
     * @throws RuntimeException wrapping any failure while building, running or reading the query
     */
    @Override
    public EmbeddingSearchResult<TextSegment> search(EmbeddingSearchRequest embeddingSearchRequest) {
        // try-with-resources closes the result set even when converting a row fails.
        try (Records records = this.client
                .queryRecords(buildQuerySql(embeddingSearchRequest))
                .get(timeoutMillis(), TimeUnit.MILLISECONDS)) {
            List<EmbeddingMatch<TextSegment>> matches = new ArrayList<>();
            for (GenericRecord row : records) {
                EmbeddingMatch<TextSegment> match = toEmbeddingMatch(row);
                if (match.score() >= embeddingSearchRequest.minScore()) {
                    matches.add(match);
                }
            }
            return new EmbeddingSearchResult<>(matches);
        } catch (Exception e) {
            throw propagate(e);
        }
    }

    /**
     * Deletes the rows whose id column is one of {@code ids}.
     *
     * @param ids ids to delete; validated with {@code ValidationUtils.ensureNotEmpty}
     */
    @Override
    public void removeAll(Collection<String> ids) {
        ValidationUtils.ensureNotEmpty(ids, "ids");
        removeAll(MetadataFilterBuilder.metadataKey(this.settings.getColumnMapping("id")).isIn(ids));
    }

    /**
     * Deletes the rows matching {@code filter} and waits for the statement to complete.
     *
     * @param filter row filter; validated with {@code ValidationUtils.ensureNotNull}
     */
    @Override
    public void removeAll(Filter filter) {
        ValidationUtils.ensureNotNull(filter, "filter");
        executeCommand(String.format("DELETE FROM %s.%s WHERE %s",
                this.settings.getDatabase(), this.settings.getTable(), this.filterMapper.map(filter)));
    }

    /** Truncates the table (if it exists) and waits for the statement to complete. */
    @Override
    public void removeAll() {
        executeCommand(String.format("TRUNCATE TABLE IF EXISTS %s.%s",
                this.settings.getDatabase(), this.settings.getTable()));
    }

    /** Adds a single entry; {@code textSegment} may be {@code null}. */
    private void addInternal(String id, Embedding embedding, TextSegment textSegment) {
        addAll(Collections.singletonList(id),
                Collections.singletonList(embedding),
                textSegment == null ? null : Collections.singletonList(textSegment));
    }

    /**
     * Inserts one row per (id, embedding[, segment]) triple as a JSON payload.
     *
     * <p>Does nothing (apart from an info log) when {@code ids} or {@code embeddings} is null/empty.
     *
     * @param ids        row ids; must have the same size as {@code embeddings}
     * @param embeddings vectors to store
     * @param embedded   optional text segments (may be {@code null}); if present must match {@code embeddings} in size
     * @throws RuntimeException wrapping any failure while serializing or inserting the rows
     */
    @Override
    public void addAll(List<String> ids, List<Embedding> embeddings, List<TextSegment> embedded) {
        if (Utils.isNullOrEmpty(ids) || Utils.isNullOrEmpty(embeddings)) {
            log.info("ClickhouseEmbeddingStore don't add empty embeddings to ClickHouse");
            return;
        }
        ValidationUtils.ensureTrue(ids.size() == embeddings.size(), "ids size is not equal to embeddings size");
        ValidationUtils.ensureTrue(embedded == null || embeddings.size() == embedded.size(),
                "embeddings size is not equal to embedded size");

        List<Map<String, Object>> rows = new ArrayList<>(ids.size());
        for (int i = 0; i < ids.size(); i++) {
            rows.add(toInsertData(ids.get(i), embeddings.get(i), embedded == null ? null : embedded.get(i)));
        }

        // NOTE(review): unlike every other statement in this class, the insert target is the bare table name (not
        // qualified with settings.getDatabase()); this relies on the client's default database — confirm.
        try (InsertResponse response = this.client
                .insert(this.settings.getTable(),
                        new ByteArrayInputStream(ClickHouseJsonUtils.toJson(rows).getBytes(StandardCharsets.UTF_8)),
                        ClickHouseFormat.JSON)
                .get(timeoutMillis(), TimeUnit.MILLISECONDS)) {
            if (log.isDebugEnabled()) {
                log.debug("Insert finished: {} rows written",
                        response.getMetrics().getMetric(ServerMetrics.NUM_ROWS_WRITTEN).getLong());
            }
        } catch (Exception e) {
            throw propagate(e);
        }
    }

    /** Builds the fallback client from the settings, enabling the experimental vector-similarity index. */
    private static Client buildDefaultClient(ClickHouseSettings settings) {
        return new Client.Builder()
                .addEndpoint(settings.getUrl())
                .setUsername(settings.getUsername())
                .setPassword(settings.getPassword())
                .serverSetting("allow_experimental_vector_similarity_index", "1")
                .build();
    }

    /** Generates {@code count} random ids. */
    private static List<String> randomIds(int count) {
        List<String> ids = new ArrayList<>(count);
        for (int i = 0; i < count; i++) {
            ids.add(Utils.randomUUID());
        }
        return ids;
    }

    /** Configured operation timeout in milliseconds. */
    private long timeoutMillis() {
        return this.settings.getTimeout().longValue();
    }

    /**
     * Runs a statement, waits (up to the configured timeout) for it to finish and closes its response.
     * Previously the returned future was ignored, so failures could be lost and responses were never closed.
     */
    private void executeCommand(String sql) {
        try {
            this.client.execute(sql).get(timeoutMillis(), TimeUnit.MILLISECONDS).close();
        } catch (RuntimeException e) {
            // Client-side runtime failures propagate unchanged, as they did before.
            throw e;
        } catch (Exception e) {
            throw propagate(e);
        }
    }

    /** Wraps {@code e} in a RuntimeException, restoring the thread's interrupt flag if it was an interruption. */
    private static RuntimeException propagate(Exception e) {
        if (e instanceof InterruptedException) {
            Thread.currentThread().interrupt();
        }
        return new RuntimeException(e);
    }

    private void createDatabase() {
        executeCommand(String.format("CREATE DATABASE IF NOT EXISTS %s", this.settings.getDatabase()));
    }

    /** Creates the table if absent: id/text/embedding columns, one nullable column per metadata key, vector index. */
    private void createTable() {
        List<String> metadataColumns = new ArrayList<>();
        if (this.settings.containsMetadata()) {
            for (Map.Entry<String, ClickHouseDataType> entry : this.settings.getMetadataTypeMap().entrySet()) {
                metadataColumns.add(String.format("%s Nullable(%s)", entry.getKey(), entry.getValue().name()));
            }
        }
        String idColumn = this.settings.getColumnMapping("id");
        String textColumn = this.settings.getColumnMapping("text");
        String embeddingColumn = this.settings.getColumnMapping("embedding");
        String metadataSql = metadataColumns.isEmpty() ? "" : String.join(",", metadataColumns) + ", ";
        // ORDER BY uses the mapped id column; the former hard-coded "id" failed when that column was renamed.
        executeCommand(String.format(
                "CREATE TABLE IF NOT EXISTS %s.%s(%s String,%s Nullable(String),%s Array(Float64),%sCONSTRAINT cons_vec_len CHECK length(%s) = %d,INDEX vec_idx %s TYPE vector_similarity('hnsw', 'cosineDistance', %d) GRANULARITY 1000) ENGINE = MergeTree ORDER BY %s SETTINGS index_granularity = 8192",
                this.settings.getDatabase(), this.settings.getTable(),
                idColumn, textColumn, embeddingColumn,
                metadataSql,
                embeddingColumn, this.settings.getDimension(),
                embeddingColumn, this.settings.getDimension(),
                idColumn));
    }

    /** Builds the nearest-neighbour query: selected columns plus the cosine distance, ascending, limited. */
    private String buildQuerySql(EmbeddingSearchRequest embeddingSearchRequest) {
        Filter filter = embeddingSearchRequest.filter();
        String whereClause = filter == null ? "" : String.format("WHERE %s", this.filterMapper.map(filter));
        String referenceVector = embeddingSearchRequest.queryEmbedding().vectorAsList().stream()
                .map(String::valueOf)
                .collect(Collectors.joining(",", "[", "]"));
        List<String> columns = new ArrayList<>(Arrays.asList(
                this.settings.getColumnMapping("id"),
                this.settings.getColumnMapping("text"),
                this.settings.getColumnMapping("embedding")));
        if (this.settings.containsMetadata()) {
            columns.addAll(this.settings.getMetadataTypeMap().keySet());
        }
        return String.format(
                "WITH %s AS reference_vector SELECT %s, %s FROM %s.%s %s ORDER BY cosineDistance(%s, reference_vector) AS %s ASC LIMIT %d",
                referenceVector, String.join(",", columns), DISTANCE_ALIAS,
                this.settings.getDatabase(), this.settings.getTable(), whereClause,
                this.settings.getColumnMapping("embedding"), DISTANCE_ALIAS,
                embeddingSearchRequest.maxResults());
    }

    /**
     * Converts a result row into a match. A text segment (with the non-null metadata columns) is attached only when
     * the text column is non-null; the score is derived from {@code 1 - cosineDistance}.
     */
    private EmbeddingMatch<TextSegment> toEmbeddingMatch(GenericRecord row) {
        String id = row.getString(this.settings.getColumnMapping("id"));
        String text = row.getString(this.settings.getColumnMapping("text"));
        // Read through the configured mapping: the former hard-coded "embedding" broke renamed columns.
        List<?> values = ((BinaryStreamReader.ArrayValue) row.getObject(this.settings.getColumnMapping("embedding")))
                .asList();
        float[] vector = new float[values.size()];
        for (int i = 0; i < vector.length; i++) {
            vector[i] = ((Number) values.get(i)).floatValue();
        }

        TextSegment textSegment = null;
        if (text != null) {
            Metadata metadata = new Metadata();
            if (this.settings.containsMetadata()) {
                Map<String, Object> metadataValues = new HashMap<>();
                for (String key : this.settings.getMetadataTypeMap().keySet()) {
                    Object value = row.getObject(key);
                    if (value != null) {
                        metadataValues.put(key, value);
                    }
                }
                metadata = Metadata.from(metadataValues);
            }
            textSegment = TextSegment.from(text, metadata);
        }
        double score = RelevanceScore.fromCosineSimilarity(1.0d - row.getDouble(DISTANCE_ALIAS));
        return new EmbeddingMatch<>(score, id, Embedding.from(vector), textSegment);
    }

    /** Builds the JSON row for one entry; metadata columns absent from the segment are written as null. */
    private Map<String, Object> toInsertData(String id, Embedding embedding, TextSegment textSegment) {
        Map<String, Object> row = new HashMap<>();
        Map<String, Object> metadata = textSegment == null ? null : textSegment.metadata().toMap();
        row.put(this.settings.getColumnMapping("id"), id);
        row.put(this.settings.getColumnMapping("embedding"), embedding.vectorAsList().toArray(new Float[0]));
        row.put(this.settings.getColumnMapping("text"), textSegment == null ? null : textSegment.text());
        if (this.settings.containsMetadata()) {
            for (String key : this.settings.getMetadataTypeMap().keySet()) {
                row.put(key, metadata == null ? null : metadata.get(key));
            }
        }
        return row;
    }
}
