package io.debezium.ai.embeddings;

import io.debezium.data.Envelope;
import java.time.Instant;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.kafka.connect.data.Schema;
import org.apache.kafka.connect.data.SchemaBuilder;
import org.apache.kafka.connect.data.Struct;
import org.apache.kafka.connect.source.SourceRecord;
import org.assertj.core.api.Assertions;
import org.junit.Test;

/**
 * Tests for {@link FieldToEmbedding}, covering how the {@code field.source} and
 * {@code field.embedding} options decide where the computed embedding is written: next to a
 * top-level field, inside the nested {@code after} struct, or (when {@code field.embedding} is
 * absent) as the whole record value.
 *
 * <p>Every test expects the embedding to contain {@code 0.0f, 1.0f, 2.0f, 3.0f}.
 * NOTE(review): these values presumably come from the embedding model configured for the test
 * environment; confirm where that model is defined.
 */
public class FieldToEmbeddingTest {

    /** Row schema: {@code id} (int64), {@code price} (float32) and {@code product} (string). */
    public static final Schema VALUE_SCHEMA = SchemaBuilder.struct()
            .name("mysql.inventory.products.Value")
            .field("id", Schema.INT64_SCHEMA)
            .field("price", Schema.FLOAT32_SCHEMA)
            .field("product", Schema.STRING_SCHEMA)
            .build();

    /** Sample row used as the {@code after} state of the change event. */
    public static final Struct ROW = new Struct(VALUE_SCHEMA)
            .put("id", 101L)
            .put("price", 20.0f)
            .put("product", "a product");

    /** Envelope schema wrapping {@link #VALUE_SCHEMA}; the source block is a plain string schema. */
    public static final Envelope ENVELOPE = Envelope.defineSchema()
            .withName("mysql.inventory.products.Envelope")
            .withRecord(VALUE_SCHEMA)
            .withSource(Schema.STRING_SCHEMA)
            .build();

    /**
     * Envelope value built from {@link #ROW} with no source struct. The tests below expect its
     * {@code op} field to be {@code "c"}.
     */
    public static final Struct PAYLOAD = ENVELOPE.create(ROW, (Struct) null, Instant.now());

    /** Record fed to the transformation; partition and offset maps are empty. */
    public static final SourceRecord SOURCE_RECORD = new SourceRecord(
            new HashMap<>(), new HashMap<>(), "topic", ENVELOPE.schema(), PAYLOAD);

    @Test
    public void testNonNestedFieldIsEmbeddedNonNested() {
        final FieldToEmbedding fieldToEmbedding = new FieldToEmbedding();
        fieldToEmbedding.configure(Map.of("field.source", "op", "field.embedding", "op_embedding"));

        final Struct value = (Struct) fieldToEmbedding.apply(SOURCE_RECORD).value();

        // The source field keeps its value; the embedding is added beside it at the top level.
        Assertions.assertThat(value.getString("op")).isEqualTo("c");
        assertContainsExpectedEmbedding(value.getArray("op_embedding"));
    }

    @Test
    public void testNestedFieldIsEmbeddedNested() {
        final FieldToEmbedding fieldToEmbedding = new FieldToEmbedding();
        fieldToEmbedding.configure(Map.of("field.source", "after.product", "field.embedding", "after.prod_embedding"));

        final Struct after = ((Struct) fieldToEmbedding.apply(SOURCE_RECORD).value()).getStruct("after");

        // Exact match: a substring check would also accept a mangled value such as "a product!".
        Assertions.assertThat(after.getString("product")).isEqualTo("a product");
        assertContainsExpectedEmbedding(after.getArray("prod_embedding"));
    }

    @Test
    public void testNestedFieldIsWithSameName() {
        final FieldToEmbedding fieldToEmbedding = new FieldToEmbedding();
        fieldToEmbedding.configure(Map.of("field.source", "after.product", "field.embedding", "after.product_embedding"));

        final Struct after = ((Struct) fieldToEmbedding.apply(SOURCE_RECORD).value()).getStruct("after");

        // The embedding field name starts with the source field name; the source must stay intact.
        Assertions.assertThat(after.getString("product")).isEqualTo("a product");
        assertContainsExpectedEmbedding(after.getArray("product_embedding"));
    }

    @Test
    @SuppressWarnings("unchecked") // value() is typed Object; this test expects it to be the embedding list
    public void testNoEmbeddingsConfProvided() {
        final FieldToEmbedding fieldToEmbedding = new FieldToEmbedding();
        fieldToEmbedding.configure(Map.of("field.source", "after.product"));

        // Without field.embedding, the whole record value is expected to be the embedding list.
        final List<Object> embedding = (List<Object>) fieldToEmbedding.apply(SOURCE_RECORD).value();
        assertContainsExpectedEmbedding(embedding);
    }

    /**
     * Asserts that {@code embedding} contains the {@link Float} values 0.0, 1.0, 2.0 and 3.0.
     *
     * @param embedding the embedding values produced by the transformation
     */
    private static void assertContainsExpectedEmbedding(List<Object> embedding) {
        Assertions.assertThat(embedding).contains(0.0f, 1.0f, 2.0f, 3.0f);
    }
}
