package dev.langchain4j.store.embedding;

import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.rag.content.retriever.EmbeddingStoreContentRetrieverTest;
import org.assertj.core.api.WithAssertions;
import org.assertj.core.data.Percentage;
import org.junit.jupiter.api.Test;

/* loaded from: input_file:dev/langchain4j/store/embedding/CosineSimilarityTest.class */
/**
 * Unit tests for {@link CosineSimilarity}: similarity between two embeddings and
 * conversion of a relevance score into a cosine similarity.
 *
 * <p>Expected values are written as plain literals so that each assertion can be read
 * on its own. NOTE(review): the {@code 0.0d} literals below replace references to
 * {@code EmbeddingStoreContentRetrieverTest.DEFAULT_MIN_SCORE}, a constant owned by an
 * unrelated test class; it is presumed to equal {@code 0.0} - confirm against that class.
 */
class CosineSimilarityTest implements WithAssertions {

    /** Comparing embeddings of different lengths must be rejected with a descriptive message. */
    @Test
    void bad() {
        Embedding threeDimensional = Embedding.from(new float[]{1.0f, 1.0f, 1.0f});
        Embedding fourDimensional = Embedding.from(new float[]{1.0f, 1.0f, 1.0f, 1.0f});

        assertThatExceptionOfType(IllegalArgumentException.class)
                .isThrownBy(() -> CosineSimilarity.between(threeDimensional, fourDimensional))
                .withMessage("Length of vector a (3) must be equal to the length of vector b (4)");
    }

    /** Two all-zero embeddings are expected to yield a similarity of zero. */
    @Test
    void zeros() {
        Embedding firstZero = Embedding.from(new float[]{0.0f, 0.0f, 0.0f});
        Embedding secondZero = Embedding.from(new float[]{0.0f, 0.0f, 0.0f});

        // A percentage tolerance around an expected value of 0 leaves no slack,
        // so this effectively requires the result to be exactly zero.
        assertThat(CosineSimilarity.between(firstZero, secondZero))
                .isCloseTo(0.0d, Percentage.withPercentage(1.0d));
    }

    /** An embedding compared with itself scores 1; compared with its negation it scores -1. */
    @Test
    void should_calculate_cosine_similarity() {
        Embedding vector = Embedding.from(new float[]{1.0f, -1.0f, 1.0f});
        Embedding negatedVector = Embedding.from(new float[]{-1.0f, 1.0f, -1.0f});

        assertThat(CosineSimilarity.between(vector, vector))
                .isCloseTo(1.0d, Percentage.withPercentage(1.0d));
        assertThat(CosineSimilarity.between(vector, negatedVector))
                .isCloseTo(-1.0d, Percentage.withPercentage(1.0d));
    }

    /** Relevance scores 0, 0.5 and 1 are expected to convert to similarities -1, 0 and 1. */
    @Test
    void should_convert_relevance_score_into_cosine_similarity() {
        assertThat(CosineSimilarity.fromRelevanceScore(0.0d)).isEqualTo(-1.0d);
        assertThat(CosineSimilarity.fromRelevanceScore(0.5d)).isEqualTo(0.0d);
        assertThat(CosineSimilarity.fromRelevanceScore(1.0d)).isEqualTo(1.0d);
    }
}
