package io.thomasvitale.langchain4j.spring.weaviate;

import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.internal.Utils;
import dev.langchain4j.store.embedding.EmbeddingSearchRequest;
import dev.langchain4j.store.embedding.EmbeddingSearchResult;
import dev.langchain4j.store.embedding.EmbeddingStore;
import io.thomasvitale.langchain4j.spring.weaviate.client.WeaviateClientConfig;
import io.weaviate.client.Config;
import io.weaviate.client.WeaviateAuthClient;
import io.weaviate.client.WeaviateClient;
import io.weaviate.client.base.Result;
import io.weaviate.client.v1.auth.exception.AuthException;
import io.weaviate.client.v1.batch.model.ObjectGetResponse;
import io.weaviate.client.v1.data.model.WeaviateObject;
import io.weaviate.client.v1.graphql.model.GraphQLError;
import io.weaviate.client.v1.graphql.model.GraphQLResponse;
import io.weaviate.client.v1.graphql.query.argument.NearVectorArgument;
import io.weaviate.client.v1.graphql.query.builder.GetBuilder;
import io.weaviate.client.v1.graphql.query.fields.Field;
import io.weaviate.client.v1.graphql.query.fields.Fields;
import java.net.URI;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.stream.Collectors;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
import org.springframework.util.StringUtils;

/**
 * {@link EmbeddingStore} implementation that persists embeddings, together with the text
 * they were computed from, as objects of a single Weaviate class and retrieves them through
 * a GraphQL near-vector query.
 *
 * <p>Instances are created through {@link #builder()}.
 */
public class WeaviateEmbeddingStore implements EmbeddingStore<TextSegment> {

    /** Consistency level applied to batch writes when none is configured. */
    public static final String DEFAULT_CONSISTENCY_LEVEL = "ALL";

    /** Weaviate object class used when none is configured. */
    public static final String DEFAULT_OBJECT_CLASS_NAME = "LangChain4jClass";

    static final String ADDITIONAL_FIELD_NAME = "_additional";
    static final String ADDITIONAL_ID_FIELD_NAME = "id";
    static final String ADDITIONAL_CERTAINTY_FIELD_NAME = "certainty";
    static final String ADDITIONAL_VECTOR_FIELD_NAME = "vector";
    static final String CONTENT_FIELD_NAME = "text";

    private final WeaviateClient weaviateClient;
    private final String objectClassName;
    private final String consistencyLevel;

    /**
     * Fluent builder for {@link WeaviateEmbeddingStore}. Only the client configuration is
     * mandatory; the object class name and consistency level fall back to
     * {@link #DEFAULT_OBJECT_CLASS_NAME} and {@link #DEFAULT_CONSISTENCY_LEVEL}.
     */
    public static class Builder {
        private WeaviateClientConfig clientConfig;
        private String objectClassName;
        private String consistencyLevel;

        private Builder() {
        }

        /** Sets the connection settings (URL, headers, timeouts, API key) for the Weaviate client. */
        public Builder clientConfig(WeaviateClientConfig clientConfig) {
            this.clientConfig = clientConfig;
            return this;
        }

        /** Sets the Weaviate class under which objects are stored and searched. */
        public Builder objectClassName(String objectClassName) {
            this.objectClassName = objectClassName;
            return this;
        }

        /** Sets the consistency level passed to Weaviate on batch writes. */
        public Builder consistencyLevel(String consistencyLevel) {
            this.consistencyLevel = consistencyLevel;
            return this;
        }

        /**
         * Creates the store.
         *
         * @throws IllegalArgumentException if no client configuration was provided or the
         *         Weaviate client cannot be created
         */
        public WeaviateEmbeddingStore build() {
            return new WeaviateEmbeddingStore(this.clientConfig, this.objectClassName, this.consistencyLevel);
        }
    }

    private WeaviateEmbeddingStore(WeaviateClientConfig clientConfig, @Nullable String objectClassName,
            @Nullable String consistencyLevel) {
        Assert.notNull(clientConfig, "clientConfig cannot be null");

        this.objectClassName = StringUtils.hasText(objectClassName) ? objectClassName : DEFAULT_OBJECT_CLASS_NAME;
        this.consistencyLevel = StringUtils.hasText(consistencyLevel) ? consistencyLevel : DEFAULT_CONSISTENCY_LEVEL;

        Map<String, String> headers = Objects.requireNonNullElse(clientConfig.headers(), Map.of());
        String apiKey = Objects.requireNonNullElse(clientConfig.apiKey(), "");
        int connectTimeoutSeconds = (int) clientConfig.connectTimeout().toSeconds();
        int readTimeoutSeconds = (int) clientConfig.readTimeout().toSeconds();

        // NOTE(review): the read timeout is passed for the last two Config arguments; the final
        // one is presumably a write timeout - confirm against the Weaviate client's Config.
        Config config = new Config(clientConfig.url().getScheme(), computeFullHostFromUrl(clientConfig.url()),
                headers, connectTimeoutSeconds, readTimeoutSeconds, readTimeoutSeconds);
        try {
            this.weaviateClient = WeaviateAuthClient.apiKey(config, apiKey);
        } catch (AuthException e) {
            throw new IllegalArgumentException("Failed to create Weaviate client with API Key", e);
        }
    }

    /** Returns {@code host} or {@code host:port}, depending on whether the URI declares a port. */
    private static String computeFullHostFromUrl(URI uri) {
        return uri.getPort() == -1 ? uri.getHost() : uri.getHost() + ":" + uri.getPort();
    }

    /**
     * Stores an embedding without text content under a newly generated id.
     *
     * @return the generated id
     */
    @Override
    public String add(Embedding embedding) {
        Assert.notNull(embedding, "embedding cannot be null");
        String id = Utils.randomUUID();
        sendAddEmbeddingsRequest(List.of(id), List.of(embedding), null);
        return id;
    }

    /** Stores an embedding without text content under the given id. */
    @Override
    public void add(String id, Embedding embedding) {
        Assert.hasText(id, "id cannot be null or empty");
        Assert.notNull(embedding, "embedding cannot be null");
        sendAddEmbeddingsRequest(List.of(id), List.of(embedding), null);
    }

    /**
     * Stores an embedding and, when present, the text of its segment under a newly generated id.
     *
     * @return the generated id
     */
    @Override
    public String add(Embedding embedding, @Nullable TextSegment textSegment) {
        Assert.notNull(embedding, "embedding cannot be null");
        String id = Utils.randomUUID();
        sendAddEmbeddingsRequest(List.of(id), List.of(embedding), textSegment == null ? null : List.of(textSegment));
        return id;
    }

    /**
     * Stores all embeddings without text content.
     *
     * @return the generated ids, in the same order as {@code embeddings}
     */
    @Override
    public List<String> addAll(List<Embedding> embeddings) {
        Assert.notNull(embeddings, "embeddings cannot be null");
        List<String> ids = generateIds(embeddings);
        sendAddEmbeddingsRequest(ids, embeddings, null);
        return ids;
    }

    /**
     * Stores all embeddings; when {@code textSegments} is non-empty, the i-th segment's text is
     * stored with the i-th embedding.
     *
     * @return the generated ids, in the same order as {@code embeddings}
     * @throws IllegalArgumentException if {@code textSegments} is non-empty and its size differs
     *         from that of {@code embeddings}
     */
    @Override
    public List<String> addAll(List<Embedding> embeddings, @Nullable List<TextSegment> textSegments) {
        Assert.notNull(embeddings, "embeddings cannot be null");
        List<String> ids = generateIds(embeddings);
        sendAddEmbeddingsRequest(ids, embeddings, textSegments);
        return ids;
    }

    /** Generates one random id per embedding. */
    private static List<String> generateIds(List<Embedding> embeddings) {
        return embeddings.stream().map(embedding -> Utils.randomUUID()).toList();
    }

    /**
     * Writes the embeddings to Weaviate in one batch and fails if the batch as a whole, or any
     * object in it, is reported as erroneous.
     */
    private void sendAddEmbeddingsRequest(List<String> ids, List<Embedding> embeddings,
            @Nullable List<TextSegment> textSegments) {
        if (embeddings.isEmpty()) {
            // Nothing to store, so skip the round trip to Weaviate.
            return;
        }
        boolean hasTextSegments = !CollectionUtils.isEmpty(textSegments);
        Assert.isTrue(ids.size() == embeddings.size(), "ids and embeddings must have the same size");
        Assert.isTrue(!hasTextSegments || textSegments.size() == embeddings.size(),
                "textSegments and embeddings must have the same size");

        List<WeaviateObject> objects = new ArrayList<>(embeddings.size());
        for (int i = 0; i < embeddings.size(); i++) {
            String text = hasTextSegments ? textSegments.get(i).text() : "";
            objects.add(toWeaviateObject(ids.get(i), embeddings.get(i), text));
        }

        Result<ObjectGetResponse[]> result = this.weaviateClient.batch()
                .objectsBatcher()
                .withObjects(objects.toArray(new WeaviateObject[0]))
                .withConsistencyLevel(this.consistencyLevel)
                .run();

        List<String> errors = new ArrayList<>();
        if (result.hasErrors()) {
            errors.add(result.getError().getMessages().stream()
                    .map(message -> message.getMessage())
                    .collect(Collectors.joining("\n")));
        } else if (result.getResult() != null) {
            // The batch call itself succeeded; individual objects may still have been rejected.
            for (ObjectGetResponse response : result.getResult()) {
                if (response.getResult() != null && response.getResult().getErrors() != null) {
                    errors.add(response.getResult().getErrors().getError().stream()
                            .map(error -> error.getMessage())
                            .collect(Collectors.joining("\n")));
                }
            }
        }
        if (!errors.isEmpty()) {
            throw new RuntimeException("Failed to add documents because: \n" + errors);
        }
    }

    /** Builds the Weaviate object holding the vector and its text content. */
    private WeaviateObject toWeaviateObject(String id, Embedding embedding, String text) {
        Map<String, Object> properties = new HashMap<>();
        properties.put(CONTENT_FIELD_NAME, text);
        return WeaviateObject.builder()
                .className(this.objectClassName)
                .id(id)
                .vector(embedding.vectorAsList().toArray(new Float[0]))
                .properties(properties)
                .build();
    }

    /**
     * Runs a near-vector GraphQL query for the request's embedding, using its minimum score as
     * the certainty threshold and its maximum number of results as the limit.
     *
     * @return the matches, or an empty result when the response carries no items
     * @throws IllegalArgumentException if the request is {@code null} or Weaviate reports errors
     */
    @Override
    public EmbeddingSearchResult<TextSegment> search(EmbeddingSearchRequest embeddingSearchRequest) {
        Assert.notNull(embeddingSearchRequest, "request cannot be null");

        NearVectorArgument nearVector = NearVectorArgument.builder()
                .vector(embeddingSearchRequest.queryEmbedding().vectorAsList().toArray(new Float[0]))
                .certainty((float) embeddingSearchRequest.minScore())
                .build();

        Field contentField = Field.builder().name(CONTENT_FIELD_NAME).build();
        Field additionalField = Field.builder()
                .name(ADDITIONAL_FIELD_NAME)
                .fields(new Field[] {
                        Field.builder().name(ADDITIONAL_ID_FIELD_NAME).build(),
                        Field.builder().name(ADDITIONAL_CERTAINTY_FIELD_NAME).build(),
                        Field.builder().name(ADDITIONAL_VECTOR_FIELD_NAME).build()
                })
                .build();

        String query = GetBuilder.builder()
                .className(this.objectClassName)
                .withNearVectorFilter(nearVector)
                .limit(embeddingSearchRequest.maxResults())
                .fields(Fields.builder().fields(new Field[] {contentField, additionalField}).build())
                .build()
                .buildQuery();

        Result<GraphQLResponse> result = this.weaviateClient.graphQL().raw().withQuery(query).run();
        if (result.hasErrors()) {
            throw new IllegalArgumentException(result.getError().getMessages().stream()
                    .map(message -> message.getMessage())
                    .collect(Collectors.joining("\n")));
        }

        GraphQLResponse response = result.getResult();
        if (response == null) {
            return new EmbeddingSearchResult<>(List.of());
        }
        GraphQLError[] errors = response.getErrors();
        if (errors != null && errors.length > 0) {
            throw new IllegalArgumentException(Arrays.stream(errors)
                    .map(error -> error.getMessage())
                    .collect(Collectors.joining("\n")));
        }

        return new EmbeddingSearchResult<>(extractItems(response).stream()
                .map(WeaviateAdapters::toEmbeddingMatch)
                .toList());
    }

    /**
     * Navigates the response payload: the value of the first top-level entry is itself a map,
     * and the value of that map's first entry is the list of result items.
     *
     * <p>NOTE(review): the item type {@code Map<String, ?>} is assumed to match the parameter of
     * {@code WeaviateAdapters.toEmbeddingMatch} - confirm against that class.
     */
    @SuppressWarnings("unchecked") // the payload is untyped; its shape follows from the query built in search
    private static List<Map<String, ?>> extractItems(GraphQLResponse response) {
        Object data = response.getData();
        if (data == null) {
            return List.of();
        }
        Map<String, Map<String, List<Map<String, ?>>>> itemsByClassByOperation =
                (Map<String, Map<String, List<Map<String, ?>>>>) data;
        return itemsByClassByOperation.entrySet().stream()
                .findFirst()
                .map(Map.Entry::getValue)
                .flatMap(itemsByClass -> itemsByClass.entrySet().stream().findFirst().map(Map.Entry::getValue))
                .orElse(List.of());
    }

    /** Returns a new builder for configuring a {@link WeaviateEmbeddingStore}. */
    public static Builder builder() {
        return new Builder();
    }
}
