package io.debezium.ai.embeddings;

import java.io.IOException;
import java.time.Duration;
import java.util.List;
import java.util.Map;
import org.apache.kafka.connect.data.Struct;
import org.apache.kafka.connect.source.SourceRecord;
import org.assertj.core.api.Assertions;
import org.assertj.core.data.Offset;
import org.junit.AfterClass;
import org.junit.BeforeClass;
import org.junit.Test;
import org.testcontainers.ollama.OllamaContainer;
import org.testcontainers.utility.DockerImageName;

/**
 * Integration test for {@link FieldToEmbedding} backed by an Ollama server running in a
 * Testcontainers-managed Docker container.
 *
 * <p>The container is started and the {@code all-minilm} model is pulled once for the whole class.
 * Each test configures the SMT (once with the current option names, once with the legacy
 * {@code embeddings.}-prefixed names), applies it to {@code FieldToEmbeddingTest.SOURCE_RECORD},
 * and checks the embedding written into {@code after.prod_embedding}.
 */
public class EmbeddingsOllamaIT {
    private static final String OLLAMA_TEST_MODEL = "all-minilm";
    private static final String OLLAMA_IMAGE_NAME = "mirror.gcr.io/ollama/ollama:0.6.2";

    /** Number of components the test expects in the embedding produced by the test model. */
    private static final int EXPECTED_EMBEDDING_SIZE = 384;

    /** Tolerance for comparing individual embedding components against reference values. */
    private static final Offset<Float> TOLERANCE = Offset.offset(0.001f);

    private static final OllamaContainer ollama = new OllamaContainer(
            DockerImageName.parse(OLLAMA_IMAGE_NAME).asCompatibleSubstituteFor("ollama/ollama"))
            .withStartupTimeout(Duration.ofSeconds(180));

    private final FieldToEmbedding<SourceRecord> embeddingSmt = new FieldToEmbedding<>();

    /**
     * Starts the Ollama container and pulls the test model once, before any test runs.
     *
     * @throws IOException if executing the pull command in the container fails
     * @throws InterruptedException if interrupted while waiting for the pull command
     */
    @BeforeClass
    public static void startDatabase() throws IOException, InterruptedException {
        ollama.start();
        // Pulling here (instead of per test) avoids redundant work and fails fast with a clear message.
        pullModel(OLLAMA_TEST_MODEL);
    }

    /** Stops the Ollama container started in {@link #startDatabase()}. */
    @AfterClass
    public static void stopDatabase() {
        ollama.stop();
    }

    @Test
    public void testOllamaEmbeddings() {
        assertEmbeddingsForConfig(embeddingsConfig(""));
    }

    @Test
    public void testOllamaEmbeddingsWithLegacyConfig() {
        assertEmbeddingsForConfig(embeddingsConfig("embeddings."));
    }

    /**
     * Runs {@code ollama pull <model>} inside the container and asserts that it succeeded.
     *
     * @param model name of the model to pull
     */
    private static void pullModel(String model) throws IOException, InterruptedException {
        var result = ollama.execInContainer("ollama", "pull", model);
        Assertions.assertThat(result.getExitCode())
                .as("'ollama pull %s' failed: %s", model, result.getStderr())
                .isZero();
    }

    /**
     * Builds the SMT configuration used by both tests.
     *
     * @param prefix prefix prepended to every option name; {@code ""} for the current names,
     *               {@code "embeddings."} for the legacy names
     * @return an immutable configuration map
     */
    private static Map<String, Object> embeddingsConfig(String prefix) {
        return Map.of(
                prefix + "field.source", "after.product",
                prefix + "field.embedding", "after.prod_embedding",
                prefix + "ollama.url", ollama.getEndpoint(),
                prefix + "ollama.model.name", OLLAMA_TEST_MODEL,
                prefix + "operation.timeout.ms", 20000);
    }

    /**
     * Configures the SMT with {@code config}, applies it to the shared test record, and verifies
     * that the source field is preserved and the embedding field has the expected size and
     * leading values.
     *
     * @param config SMT configuration to apply
     */
    private void assertEmbeddingsForConfig(Map<String, ?> config) {
        embeddingSmt.configure(config);
        Struct value = (Struct) embeddingSmt.apply(FieldToEmbeddingTest.SOURCE_RECORD).value();
        Struct after = value.getStruct("after");

        Assertions.assertThat(after.getString("product")).contains("a product");

        List<Float> embedding = after.getArray("prod_embedding");
        Assertions.assertThat(embedding).hasSize(EXPECTED_EMBEDDING_SIZE);
        // Reference values for the first components of the test model's output.
        Assertions.assertThat(embedding.get(0)).isCloseTo(-0.07157089f, TOLERANCE);
        Assertions.assertThat(embedding.get(1)).isCloseTo(0.022460647f, TOLERANCE);
        Assertions.assertThat(embedding.get(2)).isCloseTo(-0.02369636f, TOLERANCE);
    }
}
