package org.noear.solon.ai.rag.util;

import java.io.IOException;
import java.util.Comparator;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.noear.solon.ai.rag.Document;

/**
 * Static helpers for scoring {@link Document}s by cosine similarity against a query embedding
 * and for re-filtering scored documents by a similarity threshold and result limit.
 */
public final class SimilarityUtil {
    /** Result limit used by {@link #refilter(Stream)}. */
    private static final int DEFAULT_LIMIT = 4;

    /** Minimum score used by {@link #refilter(Stream)} and {@link #refilter(Stream, int)}. */
    private static final double DEFAULT_SIMILARITY_THRESHOLD = 0.4d;

    /**
     * Re-filters documents using the default limit (4) and similarity threshold (0.4).
     *
     * @param stream documents whose scores have already been set
     * @return at most 4 documents with score {@code >= 0.4}, highest score first
     */
    public static List<Document> refilter(Stream<Document> stream) {
        return refilter(stream, DEFAULT_LIMIT);
    }

    /**
     * Re-filters documents using the default similarity threshold (0.4).
     *
     * @param stream documents whose scores have already been set
     * @param limit maximum number of documents to return
     * @return at most {@code limit} documents with score {@code >= 0.4}, highest score first
     * @throws IllegalArgumentException if {@code limit} is negative (from {@link Stream#limit})
     */
    public static List<Document> refilter(Stream<Document> stream, int limit) {
        return refilter(stream, limit, DEFAULT_SIMILARITY_THRESHOLD);
    }

    /**
     * Keeps documents whose score meets the threshold, sorts them by score in descending order,
     * and truncates the result to {@code limit} entries.
     *
     * @param stream documents whose scores have already been set
     * @param limit maximum number of documents to return
     * @param similarityThreshold minimum (inclusive) score a document must have to be kept
     * @return the filtered, sorted and truncated documents
     * @throws IllegalArgumentException if {@code limit} is negative (from {@link Stream#limit})
     */
    public static List<Document> refilter(Stream<Document> stream, int limit, double similarityThreshold) {
        return stream
                .filter(document -> similarityCheck(document, similarityThreshold))
                // A method reference lets the compiler infer Document before reversed() is applied.
                .sorted(Comparator.comparingDouble(Document::getScore).reversed())
                .limit(limit)
                .collect(Collectors.toList());
    }

    /**
     * Re-filters documents using the limit and similarity threshold carried by a query condition.
     *
     * <p>When {@code queryCondition.isDisableRefilter()} is {@code false}, documents must first
     * pass {@code queryCondition.doFilter}. When it is {@code true}, that predicate is skipped.
     * The threshold/limit refilter is applied in both cases.
     *
     * <p>NOTE(review): the flag name suggests disabling the whole refilter, but only
     * {@code doFilter} is skipped. Confirm this is intended. No call visible here throws
     * {@link IOException}; the clause is kept so existing callers keep compiling.
     *
     * @param stream documents whose scores have already been set
     * @param queryCondition source of the filter predicate, limit and similarity threshold
     * @return the filtered, sorted and truncated documents
     * @throws IOException declared for source compatibility
     */
    public static List<Document> refilter(Stream<Document> stream, QueryCondition queryCondition) throws IOException {
        Stream<Document> candidates = queryCondition.isDisableRefilter()
                ? stream
                : stream.filter(queryCondition::doFilter);
        return refilter(candidates, queryCondition.getLimit(), queryCondition.getSimilarityThreshold());
    }

    /**
     * Scores a document in place by the cosine similarity between its embedding and the query.
     *
     * @param document document whose embedding is compared; its {@code score(double)} is called
     * @param queryEmbedding query embedding
     * @return whatever {@code document.score(double)} returns (presumably {@code document}
     *     itself — confirm against {@code Document})
     * @throws IllegalArgumentException if either embedding is null, their lengths differ,
     *     or either has zero norm
     */
    public static Document score(Document document, float[] queryEmbedding) {
        return document.score(cosineSimilarity(queryEmbedding, document.getEmbedding()));
    }

    /**
     * Builds a new document carrying the source document's id, content and metadata, with its
     * score set to the cosine similarity between the source embedding and the query.
     * The embedding itself is not passed to the new document's constructor.
     *
     * @param document source document; it is not modified
     * @param queryEmbedding query embedding
     * @return the newly constructed, scored document
     * @throws IllegalArgumentException if either embedding is null, their lengths differ,
     *     or either has zero norm
     */
    public static Document copyAndScore(Document document, float[] queryEmbedding) {
        return new Document(document.getId(), document.getContent(), document.getMetadata(),
                cosineSimilarity(queryEmbedding, document.getEmbedding()));
    }

    /**
     * Tests whether a document's score meets a threshold.
     *
     * @param document scored document
     * @param similarityThreshold minimum (inclusive) score
     * @return {@code true} if {@code document.getScore() >= similarityThreshold}
     */
    public static boolean similarityCheck(Document document, double similarityThreshold) {
        return document.getScore() >= similarityThreshold;
    }

    /**
     * Computes the cosine similarity {@code a·b / (|a| |b|)} of two vectors.
     *
     * @throws IllegalArgumentException if either vector is null, their lengths differ,
     *     or either has zero norm
     */
    private static double cosineSimilarity(float[] a, float[] b) {
        if (a == null || b == null) {
            throw new IllegalArgumentException("Embed must not be null");
        }
        if (a.length != b.length) {
            throw new IllegalArgumentException("Embed length must be equal");
        }
        double dot = dotProduct(a, b);
        double squaredNormA = sumOfSquares(a);
        double squaredNormB = sumOfSquares(b);
        if (squaredNormA == 0.0d || squaredNormB == 0.0d) {
            throw new IllegalArgumentException("Embed cannot be zero norm");
        }
        // Taking the roots separately keeps the product from overflowing for large magnitudes.
        return dot / (Math.sqrt(squaredNormA) * Math.sqrt(squaredNormB));
    }

    /**
     * Dot product of two equal-length vectors (callers validate the lengths).
     *
     * <p>The sum is accumulated in {@code double} rather than {@code float}. This avoids precision
     * loss over long embeddings and float overflow or underflow of the intermediate sum.
     */
    private static double dotProduct(float[] a, float[] b) {
        double sum = 0.0d;
        for (int k = 0; k < a.length; k++) {
            sum += (double) a[k] * b[k];
        }
        return sum;
    }

    /** Squared Euclidean norm of a vector, i.e. {@code v·v}. */
    private static double sumOfSquares(float[] v) {
        return dotProduct(v, v);
    }
}
