package dev.langchain4j.community.store.embedding.alloydb;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.SerializationFeature;
import com.pgvector.PGvector;
import dev.langchain4j.community.store.embedding.alloydb.filter.AlloyDBFilterMapper;
import dev.langchain4j.community.store.embedding.alloydb.index.BaseIndex;
import dev.langchain4j.community.store.embedding.alloydb.index.DistanceStrategy;
import dev.langchain4j.community.store.embedding.alloydb.index.ScaNNIndex;
import dev.langchain4j.community.store.embedding.alloydb.index.query.QueryOptions;
import dev.langchain4j.data.document.Metadata;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.internal.Utils;
import dev.langchain4j.store.embedding.EmbeddingMatch;
import dev.langchain4j.store.embedding.EmbeddingSearchRequest;
import dev.langchain4j.store.embedding.EmbeddingSearchResult;
import dev.langchain4j.store.embedding.EmbeddingStore;
import dev.langchain4j.store.embedding.RelevanceScore;
import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Statement;
import java.sql.Types;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.UUID;
import java.util.stream.Collectors;

/**
 * {@link EmbeddingStore} backed by an existing PostgreSQL table in AlloyDB that keeps one embedding per row.
 *
 * <p>On construction the store reads the table's columns from {@code information_schema.columns} and
 * validates the configuration against them:
 * <ul>
 *   <li>the id column must exist;</li>
 *   <li>the content column must be {@code text} or a character type;</li>
 *   <li>the embedding column must be reported as {@code USER-DEFINED}, e.g. a pgvector {@code vector}.</li>
 * </ul>
 *
 * <p>Metadata is written to dedicated columns ({@code metadataColumns}). Any keys left over go to a JSON
 * column ({@code metadataJsonColumn}), which is only used if it exists in the table.
 *
 * <p>Every operation obtains a connection from the {@link AlloyDBEngine} and closes it before returning.
 */
public class AlloyDBEmbeddingStore implements EmbeddingStore<TextSegment> {
    private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper().enable(SerializationFeature.INDENT_OUTPUT);

    /** Suffix appended to the table name to form the default vector index name. */
    private static final String DEFAULT_INDEX_SUFFIX = "langchainvectorindex";

    private final AlloyDBFilterMapper filterMapper = new AlloyDBFilterMapper();
    private final AlloyDBEngine engine;
    private final String tableName;
    private final String schemaName;
    private final String contentColumn;
    private final String embeddingColumn;
    private final String idColumn;
    /** Columns holding individual metadata keys; private copy, possibly extended from the table's columns. */
    private final List<String> metadataColumns;
    private final DistanceStrategy distanceStrategy;
    private final QueryOptions queryOptions;
    /** JSON metadata column, or {@code null} when the table has no such column (set during verification). */
    private String metadataJsonColumn;
    private final String insertQuery;
    private final String deleteQuery;

    /** Builder for {@link AlloyDBEmbeddingStore}; all settings except engine and table name are optional. */
    public static class Builder {
        private final AlloyDBEngine engine;
        private final String tableName;
        private String schemaName = "public";
        private String contentColumn = "content";
        private String embeddingColumn = "embedding";
        private String idColumn = "langchain4j_id";
        private List<String> metadataColumns = new ArrayList<>();
        private String metadataJsonColumn = "langchain4j_metadata";
        private List<String> ignoreMetadataColumnNames = new ArrayList<>();
        private DistanceStrategy distanceStrategy = DistanceStrategy.COSINE_DISTANCE;
        private QueryOptions queryOptions;

        /**
         * @param engine engine that supplies database connections
         * @param tableName name of the existing vector store table
         */
        public Builder(AlloyDBEngine engine, String tableName) {
            this.engine = engine;
            this.tableName = tableName;
        }

        /** @param schemaName schema containing the table (default {@code "public"}) */
        public Builder schemaName(String schemaName) {
            this.schemaName = schemaName;
            return this;
        }

        /** @param contentColumn column holding the segment text (default {@code "content"}) */
        public Builder contentColumn(String contentColumn) {
            this.contentColumn = contentColumn;
            return this;
        }

        /** @param embeddingColumn vector column holding the embedding (default {@code "embedding"}) */
        public Builder embeddingColumn(String embeddingColumn) {
            this.embeddingColumn = embeddingColumn;
            return this;
        }

        /** @param idColumn UUID id column (default {@code "langchain4j_id"}) */
        public Builder idColumn(String idColumn) {
            this.idColumn = idColumn;
            return this;
        }

        /** @param metadataColumns columns that each store one metadata key; exclusive with ignore list */
        public Builder metadataColumns(List<String> metadataColumns) {
            this.metadataColumns = metadataColumns;
            return this;
        }

        /** @param metadataJsonColumn JSON column for remaining metadata (default {@code "langchain4j_metadata"}) */
        public Builder metadataJsonColumn(String metadataJsonColumn) {
            this.metadataJsonColumn = metadataJsonColumn;
            return this;
        }

        /**
         * @param ignoreMetadataColumnNames columns to exclude when deriving metadata columns from the table.
         *     Cannot be combined with {@link #metadataColumns(List)}.
         */
        public Builder ignoreMetadataColumnNames(List<String> ignoreMetadataColumnNames) {
            this.ignoreMetadataColumnNames = ignoreMetadataColumnNames;
            return this;
        }

        /** @param distanceStrategy distance used for search and relevance scoring (default cosine) */
        public Builder distanceStrategy(DistanceStrategy distanceStrategy) {
            this.distanceStrategy = distanceStrategy;
            return this;
        }

        /** @param queryOptions index query parameters applied (via {@code SET LOCAL}) to each search */
        public Builder queryOptions(QueryOptions queryOptions) {
            this.queryOptions = queryOptions;
            return this;
        }

        /** Builds the store, validating the table layout against the database. */
        public AlloyDBEmbeddingStore build() {
            return new AlloyDBEmbeddingStore(this);
        }
    }

    /**
     * Creates the store from a builder and validates the table's columns.
     *
     * @throws IllegalArgumentException if both metadata columns and ignored column names are given
     * @throws IllegalStateException if a configured column is missing or has an unsupported type
     * @throws RuntimeException wrapping any {@link SQLException} raised while reading the table layout
     */
    public AlloyDBEmbeddingStore(Builder builder) {
        this.engine = builder.engine;
        this.tableName = builder.tableName;
        this.schemaName = builder.schemaName;
        this.contentColumn = builder.contentColumn;
        this.embeddingColumn = builder.embeddingColumn;
        this.idColumn = builder.idColumn;
        this.metadataJsonColumn = builder.metadataJsonColumn;
        // Private, mutable copy: the list may be extended below and the caller's list must not change.
        this.metadataColumns =
                builder.metadataColumns == null ? new ArrayList<>() : new ArrayList<>(builder.metadataColumns);
        this.distanceStrategy = builder.distanceStrategy;
        this.queryOptions = builder.queryOptions;
        verifyEmbeddingStoreColumns(builder.ignoreMetadataColumnNames);
        this.insertQuery = generateInsertQuery();
        this.deleteQuery = String.format(
                "DELETE FROM %s.%s WHERE %s = ANY(?)",
                quoteIdentifier(schemaName), quoteIdentifier(tableName), quoteIdentifier(idColumn));
    }

    /** Wraps a SQL identifier in double quotes, doubling any embedded quote characters. */
    private static String quoteIdentifier(String identifier) {
        return "\"" + identifier.replace("\"", "\"\"") + "\"";
    }

    /** Returns whether a JSON metadata column is in use (verification nulls it when absent from the table). */
    private boolean hasMetadataJsonColumn() {
        return metadataJsonColumn != null;
    }

    /** Default vector index name: the table name followed by {@link #DEFAULT_INDEX_SUFFIX}. */
    private String defaultIndexName() {
        return tableName + DEFAULT_INDEX_SUFFIX;
    }

    /**
     * Checks the configured columns against the table and finalizes the metadata configuration.
     *
     * <p>Clears {@link #metadataJsonColumn} if the table lacks it. When ignored names are given, every other
     * column that is not id/content/embedding/JSON metadata becomes a metadata column.
     */
    private void verifyEmbeddingStoreColumns(List<String> ignoreMetadataColumnNames) {
        boolean hasIgnoredNames = ignoreMetadataColumnNames != null && !ignoreMetadataColumnNames.isEmpty();
        if (!metadataColumns.isEmpty() && hasIgnoredNames) {
            throw new IllegalArgumentException(
                    "Cannot use both metadataColumns and ignoreMetadataColumns at the same time.");
        }

        Map<String, String> columnTypes = readColumnTypes();

        if (!columnTypes.containsKey(idColumn)) {
            throw new IllegalStateException("Id column, " + idColumn + ", does not exist.");
        }
        if (!columnTypes.containsKey(contentColumn)) {
            throw new IllegalStateException("Content column, " + contentColumn + ", does not exist.");
        }
        String contentType = columnTypes.get(contentColumn);
        if (!contentType.equalsIgnoreCase("text") && !contentType.contains("char")) {
            throw new IllegalStateException(
                    "Content column, is type " + contentType + ". It must be a type of character string.");
        }
        if (!columnTypes.containsKey(embeddingColumn)) {
            throw new IllegalStateException("Embedding column, " + embeddingColumn + ", does not exist.");
        }
        if (!columnTypes.get(embeddingColumn).equalsIgnoreCase("USER-DEFINED")) {
            throw new IllegalStateException("Embedding column, " + embeddingColumn + ", is not type Vector.");
        }
        if (!columnTypes.containsKey(metadataJsonColumn)) {
            metadataJsonColumn = null;
        }
        for (String column : metadataColumns) {
            if (!columnTypes.containsKey(column)) {
                throw new IllegalStateException("Metadata column, " + column + ", does not exist.");
            }
        }

        if (hasIgnoredNames) {
            // Work on a copy so the caller's (possibly immutable) list is never modified.
            List<String> excluded = new ArrayList<>(ignoreMetadataColumnNames);
            excluded.add(idColumn);
            excluded.add(contentColumn);
            excluded.add(embeddingColumn);
            // The JSON column is written separately; also listing it as a metadata column would make the
            // INSERT name the same column twice.
            if (hasMetadataJsonColumn()) {
                excluded.add(metadataJsonColumn);
            }
            for (String column : columnTypes.keySet()) {
                if (!excluded.contains(column)) {
                    metadataColumns.add(column);
                }
            }
        }
    }

    /** Reads {@code column_name -> data_type} for the store's table from {@code information_schema}. */
    private Map<String, String> readColumnTypes() {
        String query = "SELECT column_name, data_type FROM information_schema.columns"
                + " WHERE table_name = ? AND table_schema = ?";
        Map<String, String> columnTypes = new HashMap<>();
        try (Connection connection = engine.getConnection();
                PreparedStatement statement = connection.prepareStatement(query)) {
            // Bound parameters rather than string formatting, so names cannot alter the query.
            statement.setString(1, tableName);
            statement.setString(2, schemaName);
            try (ResultSet rows = statement.executeQuery()) {
                while (rows.next()) {
                    columnTypes.put(rows.getString("column_name"), rows.getString("data_type"));
                }
            }
        } catch (SQLException e) {
            throw new RuntimeException(
                    "Exception caught when verifying vector store table: \"" + schemaName + "\".\"" + tableName
                            + "\"",
                    e);
        }
        return columnTypes;
    }

    /**
     * Builds the INSERT statement with one placeholder per column.
     *
     * <p>Column order is id, embedding, content, metadata columns, then the optional JSON column. This order
     * must match the binding order in {@link #bindRow}.
     */
    private String generateInsertQuery() {
        List<String> columns = new ArrayList<>();
        columns.add(idColumn);
        columns.add(embeddingColumn);
        columns.add(contentColumn);
        columns.addAll(metadataColumns);
        if (hasMetadataJsonColumn()) {
            columns.add(metadataJsonColumn);
        }
        String columnList =
                columns.stream().map(AlloyDBEmbeddingStore::quoteIdentifier).collect(Collectors.joining(", "));
        String placeholders = String.join(", ", Collections.nCopies(columns.size(), "?"));
        return String.format(
                "INSERT INTO %s.%s (%s) VALUES (%s)",
                quoteIdentifier(schemaName), quoteIdentifier(tableName), columnList, placeholders);
    }

    /** Adds an embedding under a random UUID and returns that id. */
    public String add(Embedding embedding) {
        String id = Utils.randomUUID();
        addInternal(id, embedding, null);
        return id;
    }

    /** Adds an embedding under the given id, which must be a UUID string. */
    public void add(String id, Embedding embedding) {
        addInternal(id, embedding, null);
    }

    /** Adds an embedding with its text segment under a random UUID and returns that id. */
    public String add(Embedding embedding, TextSegment textSegment) {
        String id = Utils.randomUUID();
        addInternal(id, embedding, textSegment);
        return id;
    }

    /** Adds embeddings without text under random UUIDs; returns the ids in input order. */
    public List<String> addAll(List<Embedding> embeddings) {
        List<String> ids = embeddings.stream().map(e -> Utils.randomUUID()).collect(Collectors.toList());
        addAll(ids, embeddings, Collections.<TextSegment>nCopies(ids.size(), null));
        return ids;
    }

    /** Adds embeddings with their text segments under random UUIDs; returns the ids in input order. */
    public List<String> addAll(List<Embedding> embeddings, List<TextSegment> textSegments) {
        List<String> ids = embeddings.stream().map(e -> Utils.randomUUID()).collect(Collectors.toList());
        addAll(ids, embeddings, textSegments);
        return ids;
    }

    /**
     * Finds the stored embeddings nearest to the request's query embedding.
     *
     * <p>At most {@code maxResults} rows are fetched, ordered by the distance operator. Only matches whose
     * relevance score is {@code >= minScore} are returned. If query options are configured, they are applied
     * with {@code SET LOCAL} inside a transaction that is scoped to this search.
     *
     * @throws RuntimeException wrapping SQL or JSON metadata errors
     */
    public EmbeddingSearchResult<TextSegment> search(EmbeddingSearchRequest embeddingSearchRequest) {
        String query = buildSearchQuery(embeddingSearchRequest);
        try (Connection connection = engine.getConnection()) {
            PGvector.registerTypes(connection);
            // SET LOCAL only takes effect inside a transaction block; under auto-commit PostgreSQL ignores it.
            boolean manageTransaction = queryOptions != null && connection.getAutoCommit();
            if (manageTransaction) {
                connection.setAutoCommit(false);
            }
            boolean succeeded = false;
            try {
                applyQueryOptions(connection);
                List<EmbeddingMatch<TextSegment>> matches = fetchMatches(connection, query, embeddingSearchRequest);
                if (manageTransaction) {
                    connection.commit();
                }
                succeeded = true;
                return new EmbeddingSearchResult<>(matches);
            } finally {
                if (manageTransaction) {
                    if (!succeeded) {
                        connection.rollback();
                    }
                    connection.setAutoCommit(true);
                }
            }
        } catch (JsonProcessingException e) {
            throw new RuntimeException("Exception caught when processing JSON metadata", e);
        } catch (SQLException e) {
            throw new RuntimeException(
                    "Exception caught when searching in store table: \"" + schemaName + "\".\"" + tableName + "\"",
                    e);
        }
    }

    /**
     * Builds the nearest-neighbour SELECT.
     *
     * <p>Parameters: 1 = query vector (for the distance column), 2 = query vector (for ORDER BY), 3 = limit.
     */
    private String buildSearchQuery(EmbeddingSearchRequest request) {
        List<String> columns = new ArrayList<>(metadataColumns);
        columns.add(idColumn);
        columns.add(contentColumn);
        columns.add(embeddingColumn);
        if (hasMetadataJsonColumn()) {
            columns.add(metadataJsonColumn);
        }
        String selectList =
                columns.stream().map(AlloyDBEmbeddingStore::quoteIdentifier).collect(Collectors.joining(", "));
        String filter = filterMapper.map(request.filter());
        String whereClause = Utils.isNotNullOrBlank(filter) ? String.format("WHERE %s", filter) : "";
        String embedding = quoteIdentifier(embeddingColumn);
        return String.format(
                "SELECT %s, %s(%s, ?) as distance FROM %s.%s %s ORDER BY %s %s ? LIMIT ?;",
                selectList,
                distanceStrategy.getSearchFunction(),
                embedding,
                quoteIdentifier(schemaName),
                quoteIdentifier(tableName),
                whereClause,
                embedding,
                distanceStrategy.getOperator());
    }

    /** Executes {@code SET LOCAL <setting>} for each configured query option; no-op without options. */
    private void applyQueryOptions(Connection connection) throws SQLException {
        if (queryOptions == null) {
            return;
        }
        try (Statement statement = connection.createStatement()) {
            for (String setting : queryOptions.getParameterSettings()) {
                // execute(), not executeQuery(): SET returns no result set.
                statement.execute(String.format("SET LOCAL %s;", setting));
            }
        }
    }

    /** Runs the search query and converts each row that meets {@code minScore} into a match. */
    private List<EmbeddingMatch<TextSegment>> fetchMatches(
            Connection connection, String query, EmbeddingSearchRequest request)
            throws SQLException, JsonProcessingException {
        PGvector queryVector = new PGvector(request.queryEmbedding().vector());
        List<EmbeddingMatch<TextSegment>> matches = new ArrayList<>();
        try (PreparedStatement statement = connection.prepareStatement(query)) {
            statement.setObject(1, queryVector);
            statement.setObject(2, queryVector);
            statement.setInt(3, request.maxResults());
            try (ResultSet rows = statement.executeQuery()) {
                while (rows.next()) {
                    double score = calculateRelevanceScore(rows.getDouble("distance"));
                    if (score >= request.minScore()) {
                        matches.add(toMatch(rows, score));
                    }
                }
            }
        }
        return matches;
    }

    /**
     * Maps the current row to a match.
     *
     * <p>Non-null metadata column values and the JSON column's keys are merged into the segment's metadata.
     * The segment is {@code null} when the content column is null.
     */
    private EmbeddingMatch<TextSegment> toMatch(ResultSet row, double score)
            throws SQLException, JsonProcessingException {
        String id = row.getString(idColumn);
        Embedding embedding = Embedding.from(((PGvector) row.getObject(embeddingColumn)).toArray());
        String text = row.getString(contentColumn);

        Map<String, Object> metadata = new HashMap<>();
        for (String column : metadataColumns) {
            Object value = row.getObject(column);
            if (value != null) {
                metadata.put(column, value);
            }
        }
        if (hasMetadataJsonColumn()) {
            String json = Utils.getOrDefault(row.getString(metadataJsonColumn), "{}");
            @SuppressWarnings("unchecked") // a JSON object always deserializes to a String-keyed map
            Map<String, Object> jsonMetadata = OBJECT_MAPPER.readValue(json, Map.class);
            metadata.putAll(jsonMetadata);
        }

        TextSegment segment = text != null ? new TextSegment(text, Metadata.from(metadata)) : null;
        return new EmbeddingMatch<>(score, id, embedding, segment);
    }

    /**
     * Deletes the rows whose ids are in {@code ids}.
     *
     * @param ids UUID strings; must be non-null and non-empty
     * @throws IllegalArgumentException if {@code ids} is null/empty or an id is not a valid UUID
     * @throws RuntimeException wrapping any {@link SQLException}
     */
    public void removeAll(Collection<String> ids) {
        if (ids == null || ids.isEmpty()) {
            throw new IllegalArgumentException("ids cannot be null or empty");
        }
        try (Connection connection = engine.getConnection();
                PreparedStatement statement = connection.prepareStatement(deleteQuery)) {
            Object[] uuids = ids.stream().map(UUID::fromString).toArray();
            statement.setArray(1, connection.createArrayOf("uuid", uuids));
            statement.executeUpdate();
        } catch (SQLException e) {
            throw new RuntimeException(
                    String.format(
                            "Exception caught when deleting from vector store table: \"%s\".\"%s\"",
                            schemaName, tableName),
                    e);
        }
    }

    /** Adds a single row; {@code textSegment} may be {@code null}. */
    private void addInternal(String id, Embedding embedding, TextSegment textSegment) {
        addAll(Collections.singletonList(id), Collections.singletonList(embedding), Collections.singletonList(textSegment));
    }

    /**
     * Inserts one row per id in a single JDBC batch.
     *
     * @param ids UUID strings, one per row
     * @param embeddings embeddings, same size as {@code ids}
     * @param embedded text segments (elements may be {@code null}), same size as {@code ids}
     * @throws IllegalArgumentException if the lists differ in size or an id is not a valid UUID
     * @throws RuntimeException wrapping SQL or JSON metadata errors
     */
    public void addAll(List<String> ids, List<Embedding> embeddings, List<TextSegment> embedded) {
        if (ids.size() != embeddings.size() || embeddings.size() != embedded.size()) {
            throw new IllegalArgumentException(
                    "List parameters ids and embeddings and textSegments shouldn't be different sizes!");
        }
        if (ids.isEmpty()) {
            return; // nothing to insert; avoid borrowing a connection
        }
        try (Connection connection = engine.getConnection();
                PreparedStatement statement = connection.prepareStatement(insertQuery)) {
            PGvector.registerTypes(connection);
            for (int i = 0; i < ids.size(); i++) {
                bindRow(statement, ids.get(i), embeddings.get(i), embedded.get(i));
                statement.addBatch();
            }
            statement.executeBatch();
        } catch (JsonProcessingException e) {
            throw new RuntimeException("Exception caught when processing JSON metadata", e);
        } catch (SQLException e) {
            throw new RuntimeException(
                    "Exception caught when inserting into vector store table: \"" + schemaName + "\".\"" + tableName
                            + "\"",
                    e);
        }
    }

    /**
     * Binds one row's parameters in the column order produced by {@link #generateInsertQuery()}.
     *
     * <p>Metadata keys that match a metadata column go into that column. Any remaining keys are serialized
     * into the JSON column. When the segment has no metadata at all, the JSON column is bound to SQL NULL.
     */
    private void bindRow(PreparedStatement statement, String id, Embedding embedding, TextSegment segment)
            throws SQLException, JsonProcessingException {
        statement.setObject(1, UUID.fromString(id), Types.OTHER);
        statement.setObject(2, new PGvector(embedding.vector()));
        statement.setString(3, segment != null ? segment.text() : null);

        // Copy: keys consumed by dedicated columns are removed, leaving the remainder for the JSON column.
        Map<String, Object> metadata =
                segment != null ? new HashMap<>(segment.metadata().toMap()) : new HashMap<>();
        boolean hasMetadata = !metadata.isEmpty();

        int index = 4;
        for (String column : metadataColumns) {
            statement.setObject(index++, metadata.remove(column));
        }
        if (hasMetadataJsonColumn()) {
            if (hasMetadata) {
                statement.setObject(index, OBJECT_MAPPER.writeValueAsString(metadata), Types.OTHER);
            } else {
                statement.setObject(index, null);
            }
        }
    }

    /**
     * Creates a vector index on the embedding column; a {@code null} index drops the default index instead.
     *
     * @param baseIndex index definition, or {@code null} to drop the default index
     * @param name index name; if blank, the index's own name or {@code <table>langchainvectorindex} is used
     * @param concurrently whether to use {@code CREATE INDEX CONCURRENTLY}; {@code null} means {@code false}
     * @throws RuntimeException wrapping any {@link SQLException}
     */
    public void applyVectorIndex(BaseIndex baseIndex, String name, Boolean concurrently) {
        if (baseIndex == null) {
            dropVectorIndex(null);
            return;
        }
        String indexName = name;
        if (Utils.isNullOrBlank(indexName)) {
            indexName = Utils.isNotNullOrBlank(baseIndex.getName()) ? baseIndex.getName() : defaultIndexName();
        }
        try (Connection connection = engine.getConnection();
                Statement statement = connection.createStatement()) {
            String indexFunction;
            if (baseIndex instanceof ScaNNIndex) {
                statement.execute("CREATE EXTENSION IF NOT EXISTS alloydb_scann");
                indexFunction = ((ScaNNIndex) baseIndex).getDistanceStrategy().getScannIndexFunction();
            } else {
                indexFunction = baseIndex.getDistanceStrategy().getIndexFunction();
            }
            // A partial index takes one WHERE predicate; several entries are combined with AND
            // (a comma-separated list is not valid SQL).
            String whereClause = baseIndex.getPartialIndexes() != null && !baseIndex.getPartialIndexes().isEmpty()
                    ? String.format("WHERE %s", String.join(" AND ", baseIndex.getPartialIndexes()))
                    : "";
            statement.execute(String.format(
                    "CREATE INDEX %s %s ON %s.%s USING %s (%s %s) %s %s;",
                    Boolean.TRUE.equals(concurrently) ? "CONCURRENTLY" : "",
                    indexName,
                    quoteIdentifier(schemaName),
                    quoteIdentifier(tableName),
                    baseIndex.getIndexType(),
                    quoteIdentifier(embeddingColumn),
                    indexFunction,
                    String.format("WITH %s", baseIndex.getIndexOptions()),
                    whereClause));
        } catch (SQLException e) {
            throw new RuntimeException(
                    "Exception caught when creating " + indexName + " index in vector store table: \"" + schemaName
                            + "\".\"" + tableName + "\"",
                    e);
        }
    }

    /**
     * Drops a vector index if it exists.
     *
     * @param name index name; if blank, {@code <table>langchainvectorindex} is used
     * @throws RuntimeException wrapping any {@link SQLException}
     */
    public void dropVectorIndex(String name) {
        String indexName = Utils.isNotNullOrBlank(name) ? name : defaultIndexName();
        executeStatement(
                String.format("DROP INDEX IF EXISTS %s;", indexName),
                "Exception caught when removing " + indexName + " index in vector store table: \"" + schemaName
                        + "\".\"" + tableName + "\"");
    }

    /**
     * Rebuilds a vector index.
     *
     * @param name index name; if blank, {@code <table>langchainvectorindex} is used
     * @throws RuntimeException wrapping any {@link SQLException}
     */
    public void reindex(String name) {
        String indexName = Utils.isNotNullOrBlank(name) ? name : defaultIndexName();
        executeStatement(
                String.format("REINDEX INDEX %s;", indexName),
                "Exception caught when reindexing " + indexName + " index in vector store table: \"" + schemaName
                        + "\".\"" + tableName + "\"");
    }

    /** Executes one statement that returns no rows, wrapping SQL failures with {@code errorMessage}. */
    private void executeStatement(String sql, String errorMessage) {
        try (Connection connection = engine.getConnection();
                Statement statement = connection.createStatement()) {
            statement.execute(sql);
        } catch (SQLException e) {
            throw new RuntimeException(errorMessage, e);
        }
    }

    /**
     * Converts a raw distance returned by the search function into a relevance score.
     *
     * @throws UnsupportedOperationException for distance strategies without a scoring rule
     */
    private double calculateRelevanceScore(double distance) {
        switch (distanceStrategy) {
            case EUCLIDEAN:
                return 1.0d - (distance / Math.sqrt(2.0d));
            case COSINE_DISTANCE:
                return RelevanceScore.fromCosineSimilarity(1.0d - distance);
            case INNER_PRODUCT:
                return distance > 0.0d ? 1.0d - distance : -1.0d * distance;
            default:
                throw new UnsupportedOperationException(String.format(
                        "Unable to calculate relevance score for search function: %s ",
                        distanceStrategy.getSearchFunction()));
        }
    }

    /** Returns a new builder for a store over {@code tableName} using {@code alloyDBEngine}. */
    public static Builder builder(AlloyDBEngine alloyDBEngine, String tableName) {
        return new Builder(alloyDBEngine, tableName);
    }
}
