package dev.langchain4j.store.embedding.tablestore;

import com.alicloud.openservices.tablestore.SyncClient;
import com.alicloud.openservices.tablestore.core.utils.ValueUtil;
import com.alicloud.openservices.tablestore.model.CapacityUnit;
import com.alicloud.openservices.tablestore.model.Column;
import com.alicloud.openservices.tablestore.model.ColumnType;
import com.alicloud.openservices.tablestore.model.ColumnValue;
import com.alicloud.openservices.tablestore.model.CreateTableRequest;
import com.alicloud.openservices.tablestore.model.DeleteRowRequest;
import com.alicloud.openservices.tablestore.model.DeleteTableRequest;
import com.alicloud.openservices.tablestore.model.Direction;
import com.alicloud.openservices.tablestore.model.GetRangeRequest;
import com.alicloud.openservices.tablestore.model.GetRangeResponse;
import com.alicloud.openservices.tablestore.model.PrimaryKeyBuilder;
import com.alicloud.openservices.tablestore.model.PrimaryKeySchema;
import com.alicloud.openservices.tablestore.model.PrimaryKeyType;
import com.alicloud.openservices.tablestore.model.PrimaryKeyValue;
import com.alicloud.openservices.tablestore.model.PutRowRequest;
import com.alicloud.openservices.tablestore.model.RangeRowQueryCriteria;
import com.alicloud.openservices.tablestore.model.ReservedThroughput;
import com.alicloud.openservices.tablestore.model.Row;
import com.alicloud.openservices.tablestore.model.RowDeleteChange;
import com.alicloud.openservices.tablestore.model.RowPutChange;
import com.alicloud.openservices.tablestore.model.TableMeta;
import com.alicloud.openservices.tablestore.model.TableOptions;
import com.alicloud.openservices.tablestore.model.search.CreateSearchIndexRequest;
import com.alicloud.openservices.tablestore.model.search.DeleteSearchIndexRequest;
import com.alicloud.openservices.tablestore.model.search.FieldSchema;
import com.alicloud.openservices.tablestore.model.search.FieldType;
import com.alicloud.openservices.tablestore.model.search.IndexSchema;
import com.alicloud.openservices.tablestore.model.search.ListSearchIndexRequest;
import com.alicloud.openservices.tablestore.model.search.SearchHit;
import com.alicloud.openservices.tablestore.model.search.SearchIndexInfo;
import com.alicloud.openservices.tablestore.model.search.SearchQuery;
import com.alicloud.openservices.tablestore.model.search.SearchRequest;
import com.alicloud.openservices.tablestore.model.search.SearchResponse;
import com.alicloud.openservices.tablestore.model.search.query.Query;
import com.alicloud.openservices.tablestore.model.search.query.QueryBuilders;
import com.alicloud.openservices.tablestore.model.search.sort.ScoreSort;
import com.alicloud.openservices.tablestore.model.search.sort.Sort;
import com.alicloud.openservices.tablestore.model.search.vector.VectorDataType;
import com.alicloud.openservices.tablestore.model.search.vector.VectorMetricType;
import com.alicloud.openservices.tablestore.model.search.vector.VectorOptions;
import dev.langchain4j.data.document.Metadata;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.internal.Exceptions;
import dev.langchain4j.internal.ValidationUtils;
import dev.langchain4j.store.embedding.EmbeddingMatch;
import dev.langchain4j.store.embedding.EmbeddingSearchRequest;
import dev.langchain4j.store.embedding.EmbeddingSearchResult;
import dev.langchain4j.store.embedding.EmbeddingStore;
import dev.langchain4j.store.embedding.filter.Filter;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.UUID;
import java.util.function.Consumer;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * An {@link EmbeddingStore} backed by an Alibaba Cloud Tablestore table and a search index.
 *
 * <p>Every embedding is written as one row whose single string primary key column ({@code pkName})
 * holds the embedding id. The vector is serialized into the string column {@code embeddingField},
 * the segment text is written to {@code textField}, and each metadata entry becomes its own
 * attribute column. Similarity search runs a KNN vector query against the search index.
 *
 * <p>{@link #init()} creates the table and the search index when they do not exist yet.
 */
public class TablestoreEmbeddingStore implements EmbeddingStore<TextSegment> {

    private static final String DEFAULT_TABLE_NAME = "langchain4j_embedding_store_ots_v1";
    private static final String DEFAULT_INDEX_NAME = "langchain4j_embedding_ots_index_v1";
    private static final String DEFAULT_TABLE_PK_NAME = "id";
    private static final String DEFAULT_TEXT_FIELD_NAME = "default_content";
    private static final String DEFAULT_VECTOR_FIELD_NAME = "default_embedding";
    private static final VectorMetricType DEFAULT_VECTOR_METRIC_TYPE = VectorMetricType.COSINE;

    /** Maximum number of rows requested per GetRange page when scanning the whole table. */
    private static final int SCAN_PAGE_SIZE = 5000;

    private final Logger log = LoggerFactory.getLogger(getClass());
    private final SyncClient client;
    private final String tableName;
    private final String searchIndexName;
    private final String pkName;
    private final String textField;
    private final String embeddingField;
    private final int vectorDimension;
    private final VectorMetricType vectorMetricType;
    /** Search-index schema: the text field, the vector field, then the user-supplied fields. */
    private final List<FieldSchema> metadataSchemaList;

    /**
     * Creates a store with the default table, index and field names, the cosine metric and no
     * additional indexed metadata fields.
     *
     * @param client          Tablestore client used for every request
     * @param vectorDimension dimension of the stored vectors; must be greater than zero
     */
    public TablestoreEmbeddingStore(SyncClient client, int vectorDimension) {
        this(client, vectorDimension, Collections.emptyList());
    }

    /**
     * Creates a store with the default table, index and field names and the cosine metric.
     *
     * @param client             Tablestore client used for every request
     * @param vectorDimension    dimension of the stored vectors; must be greater than zero
     * @param metadataSchemaList extra search-index fields for metadata; must not be null
     */
    public TablestoreEmbeddingStore(SyncClient client, int vectorDimension, List<FieldSchema> metadataSchemaList) {
        this(client, DEFAULT_TABLE_NAME, DEFAULT_INDEX_NAME, DEFAULT_TABLE_PK_NAME, DEFAULT_TEXT_FIELD_NAME,
                DEFAULT_VECTOR_FIELD_NAME, vectorDimension, DEFAULT_VECTOR_METRIC_TYPE, metadataSchemaList);
    }

    /**
     * Creates a fully customised store.
     *
     * @param client             Tablestore client used for every request
     * @param tableName          name of the data table
     * @param searchIndexName    name of the search index on that table
     * @param pkName             name of the single string primary key column holding the embedding id
     * @param textField          column holding the segment text (indexed as TEXT with the MaxWord analyzer)
     * @param embeddingField     column holding the serialized vector (indexed as a FLOAT_32 VECTOR)
     * @param vectorDimension    dimension of the stored vectors; must be greater than zero
     * @param vectorMetricType   similarity metric configured on the vector index field
     * @param metadataSchemaList extra search-index fields for metadata; must not be null and must not
     *                           reuse {@code textField} or {@code embeddingField}
     * @throws IllegalArgumentException if an argument fails validation or a metadata field name
     *                                  collides with the text or embedding field
     */
    public TablestoreEmbeddingStore(SyncClient client, String tableName, String searchIndexName, String pkName,
                                    String textField, String embeddingField, int vectorDimension,
                                    VectorMetricType vectorMetricType, List<FieldSchema> metadataSchemaList) {
        this.client = ValidationUtils.ensureNotNull(client, "client");
        this.tableName = ValidationUtils.ensureNotBlank(tableName, "tableName");
        this.searchIndexName = ValidationUtils.ensureNotBlank(searchIndexName, "searchIndexName");
        this.pkName = ValidationUtils.ensureNotBlank(pkName, "pkName");
        this.textField = ValidationUtils.ensureNotBlank(textField, "textField");
        this.embeddingField = ValidationUtils.ensureNotBlank(embeddingField, "embeddingField");
        this.vectorDimension = ValidationUtils.ensureGreaterThanZero(vectorDimension, "vectorDimension");
        this.vectorMetricType = ValidationUtils.ensureNotNull(vectorMetricType, "vectorMetricType");
        ValidationUtils.ensureNotNull(metadataSchemaList, "metadataSchemaList");

        List<FieldSchema> schemas = new ArrayList<>(metadataSchemaList.size() + 2);
        schemas.add(new FieldSchema(this.textField, FieldType.TEXT)
                .setIndex(true)
                .setAnalyzer(FieldSchema.Analyzer.MaxWord));
        schemas.add(new FieldSchema(this.embeddingField, FieldType.VECTOR)
                .setIndex(true)
                .setVectorOptions(new VectorOptions(VectorDataType.FLOAT_32, this.vectorDimension, this.vectorMetricType)));
        for (FieldSchema schema : metadataSchemaList) {
            String fieldName = schema.getFieldName();
            // Exceptions.illegalArgument formats with String.format, so the placeholder must be %s.
            if (this.textField.equals(fieldName)) {
                throw Exceptions.illegalArgument(
                        "the custom meta data field name matches the system text field:%s", this.textField);
            }
            if (this.embeddingField.equals(fieldName)) {
                throw Exceptions.illegalArgument(
                        "the custom meta data field name matches the system embedding field:%s", this.embeddingField);
            }
            schemas.add(schema);
        }
        this.metadataSchemaList = Collections.unmodifiableList(schemas);
    }

    /** Creates the table and then the search index, skipping whichever already exists. */
    public void init() {
        createTableIfNotExist();
        createSearchIndexIfNotExist();
    }

    public SyncClient getClient() {
        return this.client;
    }

    public String getTableName() {
        return this.tableName;
    }

    public String getSearchIndexName() {
        return this.searchIndexName;
    }

    public String getPkName() {
        return this.pkName;
    }

    public String getTextField() {
        return this.textField;
    }

    public String getEmbeddingField() {
        return this.embeddingField;
    }

    public int getVectorDimension() {
        return this.vectorDimension;
    }

    public VectorMetricType getVectorMetricType() {
        return this.vectorMetricType;
    }

    /**
     * Returns the unmodifiable search-index schema: the text field, the vector field, then the
     * user-supplied metadata fields.
     */
    public List<FieldSchema> getMetadataSchemaList() {
        return this.metadataSchemaList;
    }

    /** Stores {@code embedding} under a freshly generated random UUID and returns that id. */
    @Override
    public String add(Embedding embedding) {
        String id = UUID.randomUUID().toString();
        innerAdd(id, embedding, null);
        return id;
    }

    /** Stores {@code embedding} under the caller-supplied {@code id}, overwriting any existing row. */
    @Override
    public void add(String id, Embedding embedding) {
        innerAdd(id, embedding, null);
    }

    /** Stores {@code embedding} with its text segment under a random UUID and returns that id. */
    @Override
    public String add(Embedding embedding, TextSegment textSegment) {
        String id = UUID.randomUUID().toString();
        innerAdd(id, embedding, textSegment);
        return id;
    }

    /** Stores all embeddings without text segments; ids are generated by the interface default. */
    @Override
    public List<String> addAll(List<Embedding> embeddings) {
        return addAll(embeddings, null);
    }

    /**
     * Stores each embedding (and, when given, its matching segment) under the id at the same index.
     *
     * <p>Rows are written one by one. A failing row does not stop the remaining writes. After all
     * attempts, an {@link IllegalStateException} is thrown if anything failed, with each failure
     * attached as a suppressed exception.
     *
     * @param ids        row ids; must be the same size as {@code embeddings}
     * @param embeddings vectors to store
     * @param embedded   optional segments; when non-null must be the same size as {@code embeddings}
     */
    @Override
    public void addAll(List<String> ids, List<Embedding> embeddings, List<TextSegment> embedded) {
        ValidationUtils.ensureNotNull(ids, "ids");
        ValidationUtils.ensureNotNull(embeddings, "embeddings");
        ValidationUtils.ensureEq(ids.size(), embeddings.size(),
                "the size of ids should be the same as the size of embeddings");
        if (embedded != null) {
            ValidationUtils.ensureEq(embeddings.size(), embedded.size(),
                    "the size of embeddings should be the same as the size of embedded");
        }
        List<Exception> failures = new ArrayList<>();
        for (int index = 0; index < embeddings.size(); index++) {
            try {
                innerAdd(ids.get(index), embeddings.get(index), embedded != null ? embedded.get(index) : null);
            } catch (Exception e) {
                failures.add(e);
            }
        }
        throwIfAnyFailed("Add all embeddings with error, failed:", failures);
    }

    /** Deletes the row with the given id; {@code id} must not be blank. */
    @Override
    public void remove(String id) {
        ValidationUtils.ensureNotBlank(id, "id");
        innerDelete(id);
    }

    /**
     * Deletes every row whose id is in {@code ids}. All deletions are attempted. Failures are
     * collected into one {@link IllegalStateException} that is thrown at the end.
     *
     * @throws IllegalArgumentException if {@code ids} is null or empty
     */
    @Override
    public void removeAll(Collection<String> ids) {
        if (ids == null || ids.isEmpty()) {
            throw Exceptions.illegalArgument("ids cannot be null or empty");
        }
        this.log.debug("remove all:{}", ids);
        List<Exception> failures = new ArrayList<>();
        for (String id : ids) {
            try {
                remove(id);
            } catch (Exception e) {
                failures.add(e);
            }
        }
        throwIfAnyFailed("remove all embeddings with error, failed:", failures);
    }

    /**
     * Scans the whole table and deletes every row whose metadata (rebuilt from its columns)
     * matches {@code filter}. The filter is evaluated client-side.
     */
    @Override
    public void removeAll(Filter filter) {
        if (filter == null) {
            throw Exceptions.illegalArgument("filter cannot be null");
        }
        forEachAllData(Collections.emptyList(), row -> {
            if (filter.test(rowToMetadata(row))) {
                remove(rowId(row));
            }
        });
    }

    /** Scans the whole table and deletes every row, one request per row. */
    @Override
    public void removeAll() {
        this.log.debug("remove all");
        forEachAllData(Collections.emptyList(), row -> innerDelete(rowId(row)));
    }

    /**
     * Runs a KNN vector query (top {@code maxResults}, restricted by the request filter) against
     * the search index. Results are ordered by score. Hits scoring below {@code minScore} are
     * dropped client-side.
     */
    @Override
    public EmbeddingSearchResult<TextSegment> search(EmbeddingSearchRequest request) {
        float[] queryVector = request.queryEmbedding().vector();
        int maxResults = request.maxResults();
        this.log.debug("search ([...{}...], {}, {})", queryVector.length, maxResults, request.minScore());

        Query knnQuery = QueryBuilders.knnVector(this.embeddingField, maxResults, queryVector)
                .filter(mapFilterToQuery(request.filter()))
                .build();
        SearchQuery searchQuery = SearchQuery.newBuilder()
                .query(knnQuery)
                .getTotalCount(false)
                .limit(maxResults)
                .offset(0)
                .sort(new Sort(Collections.singletonList(new ScoreSort())))
                .build();
        SearchRequest searchRequest = SearchRequest.newBuilder()
                .tableName(this.tableName)
                .indexName(this.searchIndexName)
                .searchQuery(searchQuery)
                .returnAllColumns(true)
                .build();

        SearchResponse response = this.client.search(searchRequest);
        this.log.debug("search requestId:{}", response.getRequestId());
        return searchResponseToEmbeddingSearchResult(request, response);
    }

    /** Translates a LangChain4j metadata filter into a Tablestore query; overridable for customisation. */
    protected Query mapFilterToQuery(Filter filter) {
        return TablestoreMetadataFilterMapper.map(filter);
    }

    /** Converts search hits into matches, keeping only hits whose score is at least {@code minScore}. */
    private EmbeddingSearchResult<TextSegment> searchResponseToEmbeddingSearchResult(
            EmbeddingSearchRequest request, SearchResponse response) {
        List<SearchHit> hits = response.getSearchHits();
        List<EmbeddingMatch<TextSegment>> matches = new ArrayList<>(hits.size());
        for (SearchHit hit : hits) {
            Double score = hit.getScore();
            if (score.doubleValue() < request.minScore()) {
                continue;
            }
            Row row = hit.getRow();
            Column textColumn = row.getLatestColumn(this.textField);
            Column embeddingColumn = row.getLatestColumn(this.embeddingField);
            String text = textColumn != null ? textColumn.getValue().asString() : null;
            float[] vector = embeddingColumn != null
                    ? TablestoreUtils.parseEmbeddingString(embeddingColumn.getValue().asString())
                    : null;
            TextSegment segment = (text != null && vector != null) ? new TextSegment(text, rowToMetadata(row)) : null;
            matches.add(new EmbeddingMatch<>(score, rowId(row), new Embedding(vector), segment));
        }
        return new EmbeddingSearchResult<>(matches);
    }

    /** Creates the table with a single STRING primary key unless a table with that name exists. */
    private void createTableIfNotExist() {
        if (tableExists()) {
            this.log.info("table:{} already exists", this.tableName);
            return;
        }
        TableMeta tableMeta = new TableMeta(this.tableName);
        tableMeta.addPrimaryKeyColumn(new PrimaryKeySchema(this.pkName, PrimaryKeyType.STRING));
        // TableOptions(timeToLive, maxVersions): -1 and 1 are passed; presumably "never expire" and
        // "keep one version" per the Tablestore SDK — confirm against SDK docs.
        CreateTableRequest createTableRequest = new CreateTableRequest(tableMeta, new TableOptions(-1, 1));
        createTableRequest.setReservedThroughput(new ReservedThroughput(new CapacityUnit(0, 0)));
        this.client.createTable(createTableRequest);
        this.log.info("create table:{}", this.tableName);
    }

    /** Creates the search index with {@link #getMetadataSchemaList()} unless it already exists. */
    private void createSearchIndexIfNotExist() {
        if (searchIndexExists()) {
            this.log.info("index:{} already exists", this.searchIndexName);
            return;
        }
        IndexSchema indexSchema = new IndexSchema();
        indexSchema.setFieldSchemas(this.metadataSchemaList);
        CreateSearchIndexRequest createSearchIndexRequest = new CreateSearchIndexRequest();
        createSearchIndexRequest.setTableName(this.tableName);
        createSearchIndexRequest.setIndexName(this.searchIndexName);
        createSearchIndexRequest.setIndexSchema(indexSchema);
        this.client.createSearchIndex(createSearchIndexRequest);
        this.log.info("create index:{}", this.searchIndexName);
    }

    /** Deletes every search index listed on the table (not only this store's index), then the table. */
    protected void deleteTableAndIndex() {
        deleteIndex(listSearchIndex());
        deleteTable();
    }

    private boolean tableExists() {
        return this.client.listTable().getTableNames().contains(this.tableName);
    }

    private boolean searchIndexExists() {
        for (SearchIndexInfo info : listSearchIndex()) {
            if (info.getIndexName().equals(this.searchIndexName)) {
                return true;
            }
        }
        return false;
    }

    private void deleteIndex(List<SearchIndexInfo> indexes) {
        for (SearchIndexInfo info : indexes) {
            DeleteSearchIndexRequest request = new DeleteSearchIndexRequest();
            request.setTableName(info.getTableName());
            request.setIndexName(info.getIndexName());
            this.client.deleteSearchIndex(request);
            this.log.info("delete table:{}, index:{}", info.getTableName(), info.getIndexName());
        }
    }

    private void deleteTable() {
        this.client.deleteTable(new DeleteTableRequest(this.tableName));
        this.log.info("delete table:{}", this.tableName);
    }

    private List<SearchIndexInfo> listSearchIndex() {
        ListSearchIndexRequest request = new ListSearchIndexRequest();
        request.setTableName(this.tableName);
        return this.client.listSearchIndex(request).getIndexInfos();
    }

    /**
     * Writes one row: the serialized vector, plus the segment text and one column per metadata entry
     * when {@code textSegment} is non-null.
     *
     * @throws IllegalArgumentException if {@code id} is blank, {@code embedding} is null, or a
     *                                  metadata key collides with the text or embedding field
     * @throws RuntimeException         wrapping any failure of the put request itself
     */
    protected void innerAdd(String id, Embedding embedding, TextSegment textSegment) {
        // Same rule as remove(): a row stored under a blank id could never be removed again.
        ValidationUtils.ensureNotBlank(id, "id");
        ValidationUtils.ensureNotNull(embedding, "embedding");
        RowPutChange rowPutChange = new RowPutChange(this.tableName, primaryKeyOf(id).build());
        rowPutChange.addColumn(new Column(this.embeddingField,
                ColumnValue.fromString(TablestoreUtils.embeddingToString(embedding.vector()))));
        if (textSegment != null) {
            String text = textSegment.text();
            if (text != null) {
                rowPutChange.addColumn(new Column(this.textField, ColumnValue.fromString(text)));
            }
            Metadata metadata = textSegment.metadata();
            if (metadata != null) {
                for (Map.Entry<String, Object> entry : metadata.toMap().entrySet()) {
                    String key = entry.getKey();
                    Object value = entry.getValue();
                    if (this.textField.equals(key)) {
                        throw Exceptions.illegalArgument(
                                "there is a metadata(%s,%s) that is consistent with the name of the text field:%s",
                                key, value, this.textField);
                    }
                    if (this.embeddingField.equals(key)) {
                        throw Exceptions.illegalArgument(
                                "there is a metadata(%s,%s) that is consistent with the name of the vector field:%s",
                                key, value, this.embeddingField);
                    }
                    rowPutChange.addColumn(new Column(key, toColumnValue(value)));
                }
            }
        }
        try {
            this.client.putRow(new PutRowRequest(rowPutChange));
            if (this.log.isDebugEnabled()) {
                this.log.debug("add id:{}, textSegment:{}, embedding:{}",
                        id, textSegment, TablestoreUtils.maxLogOrNull(embedding.toString()));
            }
        } catch (Exception e) {
            // Truncate the vector as the debug log does; a full embedding dump can be enormous.
            throw new RuntimeException(String.format("add embedding data failed, id:%s, textSegment:%s,embedding:%s",
                    id, textSegment, TablestoreUtils.maxLogOrNull(embedding.toString())), e);
        }
    }

    /**
     * Converts a metadata value to a column value. Floats are widened to DOUBLE columns and UUIDs are
     * stored as strings; everything else goes through {@link ValueUtil#toColumnValue(Object)}.
     */
    private static ColumnValue toColumnValue(Object value) {
        if (value instanceof Float) {
            return ColumnValue.fromDouble(((Float) value).doubleValue());
        }
        if (value instanceof UUID) {
            return ColumnValue.fromString(value.toString());
        }
        return ValueUtil.toColumnValue(value);
    }

    /**
     * Deletes the row with the given id.
     *
     * @throws RuntimeException wrapping any failure of the delete request
     */
    protected void innerDelete(String id) {
        try {
            this.client.deleteRow(new DeleteRowRequest(new RowDeleteChange(this.tableName, primaryKeyOf(id).build())));
            this.log.debug("delete id:{}", id);
        } catch (Exception e) {
            throw new RuntimeException(String.format("delete embedding data failed, id:%s", id), e);
        }
    }

    /**
     * Scans the full primary-key range (latest version only) page by page, passing every row to
     * {@code consumer}. An empty {@code columnsToGet} is passed through unchanged; the callers use it
     * to fetch all columns.
     */
    private void forEachAllData(Collection<String> columnsToGet, Consumer<Row> consumer) {
        PrimaryKeyBuilder startKey = PrimaryKeyBuilder.createPrimaryKeyBuilder();
        startKey.addPrimaryKeyColumn(this.pkName, PrimaryKeyValue.INF_MIN);
        PrimaryKeyBuilder endKey = PrimaryKeyBuilder.createPrimaryKeyBuilder();
        endKey.addPrimaryKeyColumn(this.pkName, PrimaryKeyValue.INF_MAX);

        RangeRowQueryCriteria criteria = new RangeRowQueryCriteria(this.tableName);
        criteria.setInclusiveStartPrimaryKey(startKey.build());
        criteria.setExclusiveEndPrimaryKey(endKey.build());
        criteria.setMaxVersions(1);
        criteria.setLimit(SCAN_PAGE_SIZE);
        criteria.addColumnsToGet(columnsToGet);
        criteria.setDirection(Direction.FORWARD);

        while (true) {
            GetRangeResponse response = this.client.getRange(new GetRangeRequest(criteria));
            for (Row row : response.getRows()) {
                consumer.accept(row);
            }
            // A null next-start key means the whole range has been read.
            if (response.getNextStartPrimaryKey() == null) {
                return;
            }
            criteria.setInclusiveStartPrimaryKey(response.getNextStartPrimaryKey());
        }
    }

    /** Builds a primary key containing only the id column. */
    private PrimaryKeyBuilder primaryKeyOf(String id) {
        PrimaryKeyBuilder builder = PrimaryKeyBuilder.createPrimaryKeyBuilder();
        builder.addPrimaryKeyColumn(this.pkName, PrimaryKeyValue.fromString(id));
        return builder;
    }

    /** Reads the string id stored in the row's primary key column. */
    private String rowId(Row row) {
        return row.getPrimaryKey().getPrimaryKeyColumn(this.pkName).getValue().asString();
    }

    /**
     * Rebuilds metadata from a row's attribute columns, skipping the text and embedding columns.
     * STRING, INTEGER (as long) and DOUBLE columns are supported; other types are logged and skipped.
     */
    private Metadata rowToMetadata(Row row) {
        Metadata metadata = new Metadata();
        for (Column column : row.getColumns()) {
            String name = column.getName();
            if (name.equals(this.embeddingField) || name.equals(this.textField)) {
                continue;
            }
            ColumnValue value = column.getValue();
            ColumnType type = value.getType();
            switch (type) {
                case STRING:
                    metadata.put(name, value.asString());
                    break;
                case INTEGER:
                    metadata.put(name, value.asLong());
                    break;
                case DOUBLE:
                    metadata.put(name, value.asDouble());
                    break;
                default:
                    this.log.warn("unsupported columnType:{}, key:{}, value:{}", type, name, value);
                    break;
            }
        }
        return metadata;
    }

    /** Throws one IllegalStateException carrying every collected failure as suppressed, if any. */
    private static void throwIfAnyFailed(String messagePrefix, List<Exception> failures) {
        if (failures.isEmpty()) {
            return;
        }
        IllegalStateException error = new IllegalStateException(messagePrefix + failures.size());
        for (Exception failure : failures) {
            error.addSuppressed(failure);
        }
        throw error;
    }
}
