package dev.langchain4j.community.store.embedding.duckdb;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
import dev.langchain4j.data.document.Metadata;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.internal.Utils;
import dev.langchain4j.internal.ValidationUtils;
import dev.langchain4j.store.embedding.EmbeddingMatch;
import dev.langchain4j.store.embedding.EmbeddingSearchRequest;
import dev.langchain4j.store.embedding.EmbeddingSearchResult;
import dev.langchain4j.store.embedding.EmbeddingStore;
import dev.langchain4j.store.embedding.filter.Filter;
import java.sql.Array;
import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Statement;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import org.duckdb.DuckDBConnection;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * {@link EmbeddingStore} backed by a single DuckDB table.
 *
 * <p>Each row stores an id ({@code UUID}), the embedding vector ({@code FLOAT[]}), the optional
 * segment text and the optional segment metadata serialized as JSON. Similarity search uses
 * DuckDB's {@code list_cosine_similarity}, rescaled from [-1, 1] to [0, 1] so it can be compared
 * against {@link EmbeddingSearchRequest#minScore()}.
 *
 * <p>Every operation runs on its own connection obtained from {@link DuckDBConnection#duplicate()},
 * which is closed again when the operation finishes (also on failure).
 */
public class DuckDBEmbeddingStore implements EmbeddingStore<TextSegment> {
    private static final Logger log = LoggerFactory.getLogger(DuckDBEmbeddingStore.class);
    private static final String CREATE_TABLE_TEMPLATE = "create table if not exists %s (id UUID, embedding FLOAT[], text TEXT NULL, metadata JSON NULL);\n";
    private static final String SEARCH_QUERY_TEMPLATE = "select id, embedding, text, metadata, (list_cosine_similarity(embedding,%s)+1.0)/2.0 as score\nfrom %s\nwhere score >= %s %s\norder by score DESC\nlimit %d\n";
    private static final String INSERT_QUERY_TEMPLATE = "insert into %s (id, embedding, text, metadata) values (?,?,?,?)\n";
    private static final String DELETE_BY_IDS_QUERY_TEMPLATE = "delete from %s where id in ?\n";
    private static final String DELETE_QUERY_TEMPLATE = "delete from %s where %s\n";
    private static final String TRUNCATE_QUERY_TEMPLATE = "truncate table %s\n";

    private final String tableName;
    private final DuckDBConnection duckDBConnection;
    private final DuckDBMetadataFilterMapper jsonFilterMapper = new DuckDBMetadataFilterMapper();
    private final ObjectMapper jsonMetadataSerializer = new ObjectMapper();

    /** Fluent builder for {@link DuckDBEmbeddingStore}. */
    public static class Builder {
        private String filePath;
        private String tableName;

        /**
         * Sets the database file path.
         *
         * @param filePath path appended to {@code jdbc:duckdb:}; {@code null} means in-memory
         * @return this builder
         */
        public Builder filePath(String filePath) {
            this.filePath = filePath;
            return this;
        }

        /**
         * Sets the table that holds the embeddings.
         *
         * @param tableName table name; {@code null} falls back to {@code "embeddings"}
         * @return this builder
         */
        public Builder tableName(String tableName) {
            this.tableName = tableName;
            return this;
        }

        /**
         * Configures an in-memory database (no file path).
         *
         * @return this builder
         */
        public Builder inMemory() {
            return filePath(null);
        }

        /**
         * Configures an in-memory database.
         *
         * @param ignored not used
         * @return this builder
         * @deprecated the argument has no effect; use {@link #inMemory()} instead
         */
        @Deprecated
        public Builder inMemory(String ignored) {
            return inMemory();
        }

        /** Creates the store, opening the connection and creating the table if needed. */
        public DuckDBEmbeddingStore build() {
            return new DuckDBEmbeddingStore(this.filePath, this.tableName);
        }
    }

    /**
     * Opens a DuckDB connection and creates the embeddings table if it does not exist yet.
     *
     * @param filePath  database file path, or {@code null} for an in-memory database
     * @param tableName table name, or {@code null} for {@code "embeddings"}
     * @throws DuckDBSQLException if the connection cannot be opened or the table cannot be created
     */
    public DuckDBEmbeddingStore(String filePath, String tableName) {
        String jdbcUrl = filePath != null ? "jdbc:duckdb:" + filePath : "jdbc:duckdb:";
        this.tableName = Utils.getOrDefault(tableName, "embeddings");
        try {
            this.duckDBConnection = (DuckDBConnection) DriverManager.getConnection(jdbcUrl);
        } catch (SQLException e) {
            throw new DuckDBSQLException("Unable to load duckdb connection", e);
        }
        try {
            initTable();
        } catch (RuntimeException e) {
            // Do not leak the primary connection when the store cannot be initialised.
            try {
                this.duckDBConnection.close();
            } catch (SQLException closeFailure) {
                e.addSuppressed(closeFailure);
            }
            throw e;
        }
    }

    /** Creates a store over an in-memory database using the default table name. */
    public static DuckDBEmbeddingStore inMemory() {
        return new DuckDBEmbeddingStore(null, null);
    }

    /** Returns a new {@link Builder}. */
    public static Builder builder() {
        return new Builder();
    }

    /** Adds an embedding under a random UUID and returns that id. */
    public String add(Embedding embedding) {
        String id = Utils.randomUUID();
        add(id, embedding);
        return id;
    }

    /** Adds an embedding under the given id, without text or metadata. */
    public void add(String id, Embedding embedding) {
        addInternal(id, embedding, null);
    }

    /** Adds an embedding together with its text segment under a random UUID and returns that id. */
    public String add(Embedding embedding, TextSegment textSegment) {
        String id = Utils.randomUUID();
        addInternal(id, embedding, textSegment);
        return id;
    }

    /** Adds the embeddings under freshly generated UUIDs and returns them in input order. */
    public List<String> addAll(List<Embedding> embeddings) {
        return addAll(embeddings, null);
    }

    /**
     * Adds the embeddings (and, if given, their segments) under freshly generated UUIDs.
     *
     * @param embeddings embeddings to store
     * @param segments   segments aligned by index with {@code embeddings}, or {@code null}
     * @return the generated ids, in input order
     */
    public List<String> addAll(List<Embedding> embeddings, List<TextSegment> segments) {
        List<String> ids = embeddings.stream().map(ignored -> Utils.randomUUID()).toList();
        addAll(ids, embeddings, segments);
        return ids;
    }

    /**
     * Deletes the rows whose id is in {@code ids}.
     *
     * <p>The ids are bound as a {@code UUID} array; presumably they must be valid UUID strings.
     * TODO confirm how DuckDB reports malformed ids.
     *
     * @throws DuckDBSQLException on any SQL failure
     */
    public void removeAll(Collection<String> ids) {
        ValidationUtils.ensureNotEmpty(ids, "ids");
        String sql = String.format(DELETE_BY_IDS_QUERY_TEMPLATE, this.tableName);
        try (Connection connection = this.duckDBConnection.duplicate();
                PreparedStatement statement = connection.prepareStatement(sql)) {
            statement.setObject(1, connection.createArrayOf("UUID", ids.toArray()));
            statement.execute();
        } catch (SQLException e) {
            throw new DuckDBSQLException("Unable to remove embeddings by ids", e);
        }
    }

    /**
     * Deletes every row matching {@code filter}, translated to SQL by {@link DuckDBMetadataFilterMapper}.
     *
     * @throws DuckDBSQLException on any SQL failure
     */
    public void removeAll(Filter filter) {
        ValidationUtils.ensureNotNull(filter, "filter");
        String sql = String.format(DELETE_QUERY_TEMPLATE, this.tableName, this.jsonFilterMapper.map(filter));
        try (Connection connection = this.duckDBConnection.duplicate();
                PreparedStatement statement = connection.prepareStatement(sql)) {
            log.debug("{}", sql);
            statement.execute();
        } catch (SQLException e) {
            throw new DuckDBSQLException("Unable to remove embeddings with filter", e);
        }
    }

    /**
     * Removes every row by truncating the table.
     *
     * @throws DuckDBSQLException on any SQL failure
     */
    public void removeAll() {
        String sql = String.format(TRUNCATE_QUERY_TEMPLATE, this.tableName);
        try (Connection connection = this.duckDBConnection.duplicate();
                Statement statement = connection.createStatement()) {
            statement.execute(sql);
        } catch (SQLException e) {
            throw new DuckDBSQLException("Unable to remove all embeddings", e);
        }
    }

    /**
     * Returns up to {@code maxResults} rows whose rescaled cosine similarity is at least
     * {@code minScore}, optionally restricted by the request's filter, best match first.
     *
     * @throws DuckDBSQLException on SQL failure or unreadable stored metadata
     */
    public EmbeddingSearchResult<TextSegment> search(EmbeddingSearchRequest embeddingSearchRequest) {
        String filterClause = embeddingSearchRequest.filter() != null
                ? "and " + this.jsonFilterMapper.map(embeddingSearchRequest.filter())
                : "";
        String sql = String.format(
                SEARCH_QUERY_TEMPLATE,
                embeddingToParam(embeddingSearchRequest.queryEmbedding()),
                this.tableName,
                Double.valueOf(embeddingSearchRequest.minScore()),
                filterClause,
                Integer.valueOf(embeddingSearchRequest.maxResults()));
        try (Connection connection = this.duckDBConnection.duplicate();
                PreparedStatement statement = connection.prepareStatement(sql)) {
            log.debug("{}", sql);
            List<EmbeddingMatch<TextSegment>> matches = new ArrayList<>();
            try (ResultSet rows = statement.executeQuery()) {
                while (rows.next()) {
                    matches.add(toMatch(rows));
                }
            }
            return new EmbeddingSearchResult<>(matches);
        } catch (SQLException | JsonProcessingException e) {
            throw new DuckDBSQLException("Error while searching embeddings", e);
        }
    }

    /** Converts the current row of a search result into an {@link EmbeddingMatch}. */
    private EmbeddingMatch<TextSegment> toMatch(ResultSet row) throws SQLException, JsonProcessingException {
        String id = row.getString("id");
        String text = row.getString("text");
        double score = row.getDouble("score");
        Embedding embedding = new Embedding(toFloatArray(row.getArray("embedding")));
        // Rows stored without a segment have no text; they are returned without a TextSegment.
        TextSegment segment = null;
        if (text != null) {
            segment = TextSegment.from(text, Metadata.from(readMetadata(row.getString("metadata"))));
        }
        return new EmbeddingMatch<>(score, id, embedding, segment);
    }

    /** Parses stored JSON metadata; SQL NULL and JSON {@code null} both yield an empty map. */
    private Map<String, Object> readMetadata(String json) throws JsonProcessingException {
        if (json == null) {
            return Collections.emptyMap();
        }
        Map<String, Object> parsed =
                this.jsonMetadataSerializer.readValue(json, new TypeReference<HashMap<String, Object>>() {});
        return parsed != null ? parsed : Collections.emptyMap();
    }

    /** Copies a JDBC {@code FLOAT[]} value into a primitive array. */
    private static float[] toFloatArray(Array sqlArray) throws SQLException {
        Object[] values = (Object[]) sqlArray.getArray();
        float[] vector = new float[values.length];
        for (int i = 0; i < values.length; i++) {
            vector[i] = ((Number) values[i]).floatValue();
        }
        return vector;
    }

    /** Stores a single embedding with an optional segment. */
    private void addInternal(String id, Embedding embedding, TextSegment textSegment) {
        addAll(
                Collections.singletonList(id),
                Collections.singletonList(embedding),
                textSegment == null ? null : Collections.singletonList(textSegment));
    }

    /**
     * Inserts the given rows in a single JDBC batch.
     *
     * @param ids        row ids (bound into the {@code UUID} column)
     * @param embeddings embeddings aligned by index with {@code ids}
     * @param segments   optional segments aligned by index; {@code null} list or element means
     *                   the row gets SQL NULL text and metadata
     * @throws IllegalArgumentException if the list sizes disagree
     * @throws DuckDBSQLException       on SQL failure or if metadata cannot be serialized
     */
    public void addAll(List<String> ids, List<Embedding> embeddings, List<TextSegment> segments) {
        if (Utils.isNullOrEmpty(ids) || Utils.isNullOrEmpty(embeddings)) {
            log.info("[no embeddings to add to DuckDB]");
            return;
        }
        ValidationUtils.ensureTrue(ids.size() == embeddings.size(), "ids size is not equal to embeddings size");
        ValidationUtils.ensureTrue(
                segments == null || embeddings.size() == segments.size(),
                "embeddings size is not equal to embedded size");
        String sql = String.format(INSERT_QUERY_TEMPLATE, this.tableName);
        try (Connection connection = this.duckDBConnection.duplicate();
                PreparedStatement statement = connection.prepareStatement(sql)) {
            for (int i = 0; i < ids.size(); i++) {
                TextSegment segment = segments == null ? null : segments.get(i);
                statement.setString(1, ids.get(i));
                statement.setObject(2, connection.createArrayOf("float", embeddings.get(i).vectorAsList().toArray()));
                statement.setString(3, segment == null ? null : segment.text());
                // Store SQL NULL (the column is declared "JSON NULL") rather than the JSON text "null".
                statement.setString(4, segment == null
                        ? null
                        : this.jsonMetadataSerializer.writeValueAsString(segment.metadata().toMap()));
                statement.addBatch();
            }
            statement.executeBatch();
        } catch (SQLException | JsonProcessingException e) {
            throw new DuckDBSQLException("Unable to add embeddings in DuckDB", e);
        }
    }

    /** Creates the embeddings table if it does not exist yet. */
    private void initTable() {
        String sql = String.format(CREATE_TABLE_TEMPLATE, this.tableName);
        try (Connection connection = this.duckDBConnection.duplicate();
                Statement statement = connection.createStatement()) {
            log.debug("{}", sql);
            statement.execute(sql);
        } catch (SQLException e) {
            throw new DuckDBSQLException(String.format("Failed to init duckDB table:  '%s'", sql), e);
        }
    }

    /**
     * Renders an embedding as a DuckDB list literal cast to {@code float[]},
     * e.g. {@code [0.1,0.2]::float[]}, for inlining into the search query.
     */
    protected String embeddingToParam(Embedding embedding) {
        return embedding.vectorAsList().stream()
                .map(Object::toString)
                .collect(Collectors.joining(",", "[", "]"))
                .concat("::float[]");
    }
}
