package io.debezium.ai.embeddings;

import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.model.embedding.EmbeddingModel;
import io.debezium.DebeziumException;
import io.debezium.Module;
import io.debezium.config.Configuration;
import io.debezium.config.Field;
import io.debezium.data.vector.FloatVector;
import io.debezium.transforms.ConnectRecordUtil;
import io.debezium.transforms.SmtManager;
import io.debezium.util.BoundedConcurrentHashMap;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.ServiceLoader;
import org.apache.kafka.common.config.ConfigDef;
import org.apache.kafka.common.config.ConfigException;
import org.apache.kafka.connect.components.Versioned;
import org.apache.kafka.connect.connector.ConnectRecord;
import org.apache.kafka.connect.data.Schema;
import org.apache.kafka.connect.data.Struct;
import org.apache.kafka.connect.transforms.Transformation;
import org.apache.kafka.connect.transforms.util.Requirements;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:io/debezium/ai/embeddings/FieldToEmbedding.class */
/**
 * Kafka Connect single message transformation which reads a string field from the record value,
 * passes it to an {@link EmbeddingModel} and emits a new record carrying the resulting embedding
 * vector.
 *
 * <p>The source field is configured via {@code embeddings.field.source} and may be a dot-separated
 * path into nested structs. When {@code embeddings.field.embedding} is set, the embedding is
 * added to the original value under that name; when it is not set, the record value is replaced
 * by the embedding itself.
 *
 * <p>The model factory is discovered through {@link ServiceLoader} once, during class
 * initialization, and is held in a static field, so it is the same object for every instance of
 * this transformation.
 *
 * @param <R> the concrete {@link ConnectRecord} type processed by this transformation
 */
public class FieldToEmbedding<R extends ConnectRecord<R>> implements Transformation<R>, Versioned {
    public static final String EMBEDDINGS_PREFIX = "embeddings.";

    private SmtManager<R> smtManager;
    private String sourceField;
    private String embeddingsField;
    // The configured source field split on dots; the last element is the string field itself,
    // the preceding elements are the structs to descend through.
    private List<String> sourceFieldPath;
    private EmbeddingModel model;

    private static final String NESTING_SPLIT_REG_EXP = "\\.";
    private static final int CACHE_SIZE = 64;
    // Maps an original value schema to the schema extended by the embeddings field, so the
    // extended schema is built only on a cache miss. Constructed with CACHE_SIZE as its bound.
    private final BoundedConcurrentHashMap<Schema, Schema> schemaUpdateCache = new BoundedConcurrentHashMap<>(CACHE_SIZE);
    private static final Logger LOGGER = LoggerFactory.getLogger(FieldToEmbedding.class);

    private static final Field TEXT_FIELD = Field.create("embeddings.field.source")
            .withDisplayName("Name of the record field from which embeddings should be created.")
            .withType(ConfigDef.Type.STRING)
            .withWidth(ConfigDef.Width.SHORT)
            .withImportance(ConfigDef.Importance.HIGH)
            .required()
            .withDescription("Name of the field of the record which content will be used as an input for embeddings. Supports also nested fields.");

    private static final Field EMBEDDINGS_FIELD = Field.create("embeddings.field.embedding")
            .withDisplayName("Name of the field which would contain the embeddings of the input field.")
            .withType(ConfigDef.Type.STRING)
            .withWidth(ConfigDef.Width.SHORT)
            .withImportance(ConfigDef.Importance.HIGH)
            .withDescription("Name of the field which which will be appended to the record and which would contain the embeddings of the content `filed.source` field. Supports also nested fields.");

    private static final Schema EMBEDDING_SCHEMA = FloatVector.schema();

    // Must be initialized before ALL_FIELDS, which asks the factory for its own config fields.
    private static final EmbeddingsModelFactory MODEL_FACTORY = EmbeddingsModelFactoryLoader.getModelFactory();

    public static final Field.Set ALL_FIELDS = Field.setOf(TEXT_FIELD, EMBEDDINGS_FIELD).with(MODEL_FACTORY.getConfigFields());

    /**
     * Locates the {@link EmbeddingsModelFactory} implementation via {@link ServiceLoader}.
     *
     * @param <R> unused; retained so that existing references to this public type keep compiling
     */
    /* loaded from: input_file:io/debezium/ai/embeddings/FieldToEmbedding$EmbeddingsModelFactoryLoader.class */
    public static class EmbeddingsModelFactoryLoader<R extends ConnectRecord<R>> {
        private static final Logger LOGGER = LoggerFactory.getLogger(EmbeddingsModelFactoryLoader.class);

        /**
         * Returns the first factory reported by the service loader and logs a warning listing
         * all discovered factories when more than one is available.
         *
         * @return the first discovered factory
         * @throws DebeziumException if no factory implementation is found
         */
        static EmbeddingsModelFactory getModelFactory() {
            ServiceLoader<EmbeddingsModelFactory> loader = ServiceLoader.load(EmbeddingsModelFactory.class);
            EmbeddingsModelFactory factory = loader.findFirst()
                    .orElseThrow(() -> new DebeziumException("No implementation of Debezium embeddings model factory found."));

            if (loader.stream().count() > 1) {
                LOGGER.warn("More then one Debezium embeddings model factory found. Order of loading is not defined and you may load different factory than you intended.");
                LOGGER.warn("Found following factories:");
                loader.stream().forEach(provider -> LOGGER.warn(provider.get().getClass().getName()));
            }
            return factory;
        }
    }

    /**
     * Reads the source and embeddings field names from the supplied configuration, configures
     * and validates the model factory, and obtains the embedding model from it.
     *
     * @param map the transformation configuration
     * @throws ConfigException if the source field is missing or blank
     */
    @Override
    public void configure(Map<String, ?> map) {
        Configuration config = Configuration.from(map);
        this.smtManager = new SmtManager<>(config);
        this.smtManager.validate(config, ALL_FIELDS);

        this.sourceField = config.getString(TEXT_FIELD);
        this.embeddingsField = config.getString(EMBEDDINGS_FIELD);
        MODEL_FACTORY.configure(config);
        // Validation guarantees sourceField is non-null before it is split below.
        validateConfiguration();

        this.sourceFieldPath = Arrays.asList(this.sourceField.split(NESTING_SPLIT_REG_EXP));
        this.model = MODEL_FACTORY.getModel();
    }

    /**
     * Computes the embedding of the configured source field and returns a record carrying it.
     *
     * @param r the incoming record
     * @return {@code r} unchanged when its value is {@code null}, its envelope is not valid or
     *         no source string could be obtained; otherwise a new record with the embedding
     */
    @Override
    public R apply(R r) {
        if (r.value() == null || !this.smtManager.isValidEnvelope(r)) {
            LOGGER.trace("Record {} has null value of invalid envelope and will be skipped.", r.value());
            return r;
        }
        String sourceString = getSourceString(r);
        if (sourceString == null) {
            return r;
        }
        return buildUpdatedRecord(r, sourceString);
    }

    /**
     * Returns a {@link ConfigDef} containing the source and embeddings field options.
     */
    @Override
    public ConfigDef config() {
        ConfigDef configDef = new ConfigDef();
        Field.group(configDef, null, TEXT_FIELD, EMBEDDINGS_FIELD);
        return configDef;
    }

    /** No-op. */
    @Override
    public void close() {
    }

    /**
     * Returns the version reported by {@link Module#version()}.
     */
    @Override
    public String version() {
        return Module.version();
    }

    /**
     * Checks that the source field is set to a non-blank value and delegates to the model
     * factory for validation of its own options.
     *
     * @throws ConfigException if the source field is {@code null} or blank
     */
    protected void validateConfiguration() {
        if (this.sourceField == null || this.sourceField.isBlank()) {
            throw new ConfigException(String.format("'%s' must be set to non-empty value.", TEXT_FIELD));
        }
        MODEL_FACTORY.validateConfiguration();
    }

    /**
     * Extracts the string to embed by walking the configured field path through the record value.
     *
     * @param r the record to read from
     * @return the value of the source field, or {@code null} when the record value is
     *         {@code null}, the envelope is not valid, the value schema is not a struct, or a
     *         struct on the path is absent
     * @throws IllegalArgumentException if the schema of a struct on the path is not of struct type
     */
    protected String getSourceString(R r) {
        if (r.value() == null || !this.smtManager.isValidEnvelope(r) || r.valueSchema().type() != Schema.Type.STRUCT) {
            LOGGER.debug("Skipping record {}, it has either null value or invalid structure", r);
            return null;
        }

        Struct current = Requirements.requireStruct(r.value(), "Obtaining source field for embeddings");
        // Descend through every path element except the last one, which names the string field.
        int parentCount = this.sourceFieldPath.size() - 1;
        for (int i = 0; i < parentCount; i++) {
            if (current.schema().type() != Schema.Type.STRUCT) {
                throw new IllegalArgumentException(String.format("Invalid field name %s, %s is not struct.", this.sourceField, current.schema().name()));
            }
            current = current.getStruct(this.sourceFieldPath.get(i));
            if (current == null) {
                LOGGER.debug("Skipping record {}, the structure is not present", r);
                return null;
            }
        }
        return current.getString(this.sourceFieldPath.getLast());
    }

    /**
     * Embeds {@code str} with the configured model and creates a new record containing the result.
     *
     * <p>When no embeddings field is configured, the new record's value is the embedding vector
     * with the embedding schema. Otherwise the original struct is extended by the embeddings
     * field, using a cached extended schema when one exists for the original schema.
     *
     * @param r   the original record; its value must be a struct
     * @param str the text to embed
     * @return a new record with the same topic, partition, key, timestamp and headers as
     *         {@code r} and the updated value schema and value
     */
    protected R buildUpdatedRecord(R r, String str) {
        Struct originalValue = Requirements.requireStruct(r.value(), "Original value must be struct");
        Embedding embedding = this.model.embed(TextSegment.from(str)).content();

        Schema updatedSchema;
        // Either a list of floats or a struct, depending on the branch below.
        Object updatedValue;
        if (this.embeddingsField == null) {
            updatedSchema = EMBEDDING_SCHEMA;
            updatedValue = embedding.vectorAsList();
        } else {
            List<ConnectRecordUtil.NewEntry> newEntries = List.of(
                    new ConnectRecordUtil.NewEntry(this.embeddingsField, EMBEDDING_SCHEMA, embedding.vectorAsList()));
            updatedSchema = this.schemaUpdateCache.computeIfAbsent(
                    originalValue.schema(),
                    originalSchema -> ConnectRecordUtil.makeNewSchema(originalSchema, newEntries));
            updatedValue = ConnectRecordUtil.makeUpdatedValue(originalValue, newEntries, updatedSchema);
        }

        return r.newRecord(r.topic(), r.kafkaPartition(), r.keySchema(), r.key(), updatedSchema, updatedValue, r.timestamp(), r.headers());
    }
}
