package com.zilliz.spark.connector.filter;

import org.apache.spark.broadcast.Broadcast;
import org.apache.spark.ml.linalg.Vector;
import org.apache.spark.ml.linalg.Vectors$;
import org.apache.spark.rdd.RDD;
import org.apache.spark.sql.Column;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.Row$;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.expressions.UserDefinedFunction;
import org.apache.spark.sql.functions$;
import org.apache.spark.sql.types.ArrayType;
import org.apache.spark.sql.types.DoubleType$;
import org.apache.spark.sql.types.FloatType$;
import org.apache.spark.sql.types.StructType;
import scala.MatchError;
import scala.None$;
import scala.Option;
import scala.Predef$;
import scala.Some;
import scala.Tuple2;
import scala.collection.ArrayOps$;
import scala.collection.IndexedSeqOps;
import scala.collection.IterableOnceOps;
import scala.collection.immutable.$colon;
import scala.collection.immutable.Nil$;
import scala.collection.immutable.Seq;
import scala.collection.mutable.ArrayBuffer;
import scala.collection.mutable.ArrayBuffer$;
import scala.math.Ordering$DeprecatedDoubleOrdering$;
import scala.reflect.ClassTag$;
import scala.reflect.api.Mirror;
import scala.reflect.api.TypeCreator;
import scala.reflect.api.Types;
import scala.reflect.api.Universe;
import scala.reflect.runtime.package$;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;
import scala.runtime.ScalaRunTime$;

/* compiled from: VectorBruteForceSearch.scala */
/* loaded from: input_file:com/zilliz/spark/connector/filter/VectorBruteForceSearch$.class */
/**
 * Singleton holder (decompiled Scala {@code object}) for a brute-force cosine-similarity
 * search over a Spark {@code Dataset<Row>}.
 *
 * <p>The entry point is {@link #filterSimilarVectors}: it converts an {@code Array[Float]}
 * column into dense ML vectors, scores every row against a broadcast query vector, keeps the
 * best candidates of each partition and finally orders and limits the merged result.
 */
public final class VectorBruteForceSearch$ {
    public static final VectorBruteForceSearch$ MODULE$ = new VectorBruteForceSearch$();

    /** Name of the temporary column holding the dense-vector form of the input vector column. */
    private static final String DENSE_VECTOR_COLUMN = "_converted_dense_vector_";

    /** Name of the temporary row-id column that is added when the caller supplies no ID column. */
    private static final String ROW_ID_COLUMN = "_brute_force_search_row_id_";

    /**
     * UDF that turns a sequence of numbers into a dense ML {@code Vector}; a null input yields null.
     * The two reified type tags describe the return type ({@code org.apache.spark.ml.linalg.Vector})
     * and the argument type ({@code scala.Seq[scala.Float]}).
     */
    private static final UserDefinedFunction arrayFloatToDenseVectorUDF = functions$.MODULE$.udf(seq -> {
        if (seq == null) {
            return null;
        }
        // NOTE(review): decompiled identity lambda; the Scala source presumably widened each
        // Float element to Double here — confirm against the original .scala file.
        return Vectors$.MODULE$.dense((double[]) ((IterableOnceOps) seq.map(f -> {
            return f;
        })).toArray(ClassTag$.MODULE$.Double()));
    }, package$.MODULE$.universe().TypeTag().apply(package$.MODULE$.universe().runtimeMirror(MODULE$.getClass().getClassLoader()), new TypeCreator() { // from class: com.zilliz.spark.connector.filter.VectorBruteForceSearch$$typecreator1$1
        public <U extends Universe> Types.TypeApi apply(Mirror<U> mirror) {
            mirror.universe();
            return mirror.staticClass("org.apache.spark.ml.linalg.Vector").asType().toTypeConstructor();
        }
    }), package$.MODULE$.universe().TypeTag().apply(package$.MODULE$.universe().runtimeMirror(MODULE$.getClass().getClassLoader()), new TypeCreator() { // from class: com.zilliz.spark.connector.filter.VectorBruteForceSearch$$typecreator2$1
        public <U extends Universe> Types.TypeApi apply(Mirror<U> mirror) {
            Universe universe = mirror.universe();
            return universe.internal().reificationSupport().TypeRef(universe.internal().reificationSupport().SingleType(universe.internal().reificationSupport().SingleType(universe.internal().reificationSupport().thisPrefix(mirror.RootClass()), mirror.staticPackage("scala")), mirror.staticModule("scala.package")), universe.internal().reificationSupport().selectType(mirror.staticModule("scala.package").asModule().moduleClass(), "Seq"), new $colon.colon(mirror.staticClass("scala.Float").asType().toTypeConstructor(), Nil$.MODULE$));
        }
    }));

    /**
     * Computes the cosine similarity of two vectors: {@code dot(a, b) / (|a|_2 * |b|_2)}.
     *
     * @param vector  first vector (must not be null)
     * @param vector2 second vector (must not be null)
     * @return the cosine similarity, or {@code 0.0} when either vector has a zero L2 norm
     */
    private double cosineSimilarity(Vector vector, Vector vector2) {
        double dot = vector.dot(vector2);
        double norm = Vectors$.MODULE$.norm(vector, 2.0d);
        double norm2 = Vectors$.MODULE$.norm(vector2, 2.0d);
        // Guard against division by zero for all-zero vectors.
        if (norm == 0.0d || norm2 == 0.0d) {
            return 0.0d;
        }
        return dot / (norm * norm2);
    }

    /** Accessor for the shared array-to-dense-vector UDF. */
    private UserDefinedFunction arrayFloatToDenseVectorUDF() {
        return arrayFloatToDenseVectorUDF;
    }

    /**
     * Returns the rows of {@code dataset} whose vector column is most similar (by cosine
     * similarity) to the query vector {@code seq}.
     *
     * <p>Each partition keeps at most {@code i} rows whose similarity is {@code >= d}; the
     * partial results are then ordered by the added {@code similarity} column (descending) and
     * limited to {@code i} rows. The temporary helper columns are dropped from the result.
     * Rows whose vector column is null are skipped rather than scored.
     *
     * @param dataset input data
     * @param seq     query vector
     * @param i       maximum number of rows to return (also the per-partition cut-off)
     * @param d       minimum similarity a row must reach to be kept
     * @param str     name of the vector column; its type must be {@code Array[Float]}
     * @param option  optional name of an ID column; when empty, a temporary
     *                {@code monotonically_increasing_id} column is added and later removed
     * @return the input columns plus a {@code similarity} (double) column
     * @throws IllegalArgumentException if the vector column or the ID column is missing, or the
     *                                  vector column is not {@code Array[Float]}; the same
     *                                  exception type is raised while rows are processed when an
     *                                  ID value is not an {@code Integer} or {@code Long}
     *                                  (including null)
     */
    public Dataset<Row> filterSimilarVectors(Dataset<Row> dataset, Seq<Object> seq, int i, double d, String str, Option<String> option) {
        String str2;
        SparkSession sparkSession = dataset.sparkSession();
        StructType schema = dataset.schema();
        if (!ArrayOps$.MODULE$.contains$extension(Predef$.MODULE$.refArrayOps(schema.fieldNames()), str)) {
            throw new IllegalArgumentException(new StringBuilder(42).append("DataFrame does not contain vector column: ").append(str).toString());
        }
        ArrayType dataType = schema.apply(str).dataType();
        if (!(dataType instanceof ArrayType) || !FloatType$.MODULE$.equals(dataType.elementType())) {
            throw new IllegalArgumentException(new StringBuilder(54).append("Vector column '").append(str).append("' must be of type Array[Float]. Found: ").append(schema.apply(str).dataType()).toString());
        }
        Dataset withColumn = dataset.withColumn(DENSE_VECTOR_COLUMN, arrayFloatToDenseVectorUDF().apply(ScalaRunTime$.MODULE$.wrapRefArray(new Column[]{functions$.MODULE$.col(str)})));
        if (option instanceof Some) {
            String str4 = (String) ((Some) option).value();
            if (!ArrayOps$.MODULE$.contains$extension(Predef$.MODULE$.refArrayOps(schema.fieldNames()), str4)) {
                throw new IllegalArgumentException(new StringBuilder(38).append("DataFrame does not contain ID column: ").append(str4).toString());
            }
            str2 = str4;
        } else {
            if (!None$.MODULE$.equals(option)) {
                throw new MatchError(option);
            }
            Predef$.MODULE$.println("Warning: No ID column provided. Adding a default row index, which may cause a shuffle. For production, consider ensuring your DataFrame has a unique identifier column.");
            withColumn = withColumn.withColumn(ROW_ID_COLUMN, functions$.MODULE$.monotonically_increasing_id());
            str2 = ROW_ID_COLUMN;
        }
        String str5 = str2;
        // NOTE(review): decompiled identity lambda; the Scala source presumably widened each
        // Float element of the query to Double here — confirm against the original .scala file.
        Broadcast broadcast = sparkSession.sparkContext().broadcast(Vectors$.MODULE$.dense((double[]) ((IterableOnceOps) seq.map(f -> {
            return f;
        })).toArray(ClassTag$.MODULE$.Double())), ClassTag$.MODULE$.apply(Vector.class));
        Broadcast broadcast2 = sparkSession.sparkContext().broadcast(BoxesRunTime.boxToDouble(d), ClassTag$.MODULE$.Double());
        Broadcast broadcast3 = sparkSession.sparkContext().broadcast(BoxesRunTime.boxToInteger(i), ClassTag$.MODULE$.Int());
        RDD rdd = withColumn.rdd();
        Dataset<Row> drop = sparkSession.createDataFrame(rdd.mapPartitions(iterator -> {
            Vector vector = (Vector) broadcast.value();
            int unboxToInt = BoxesRunTime.unboxToInt(broadcast3.value());
            ArrayBuffer arrayBuffer = (ArrayBuffer) ArrayBuffer$.MODULE$.apply(Nil$.MODULE$);
            iterator.foreach(row -> {
                // The ID value is only validated here; it is not otherwise used for scoring.
                Object as = row.getAs(str5);
                if (!(as instanceof Integer) && !(as instanceof Long)) {
                    // A null ID must surface as the documented IllegalArgumentException,
                    // not as a NullPointerException from calling getClass() on null.
                    throw new IllegalArgumentException(new StringBuilder(58).append("ID column '").append(str5).append("' must be a numeric type (Int or Long). Found: ").append(as == null ? "null" : as.getClass().getName()).toString());
                }
                // The conversion UDF maps a null array to a null vector; such rows cannot be
                // scored, so skip them instead of failing inside cosineSimilarity.
                Vector candidate = (Vector) row.getAs(DENSE_VECTOR_COLUMN);
                if (candidate == null) {
                    return BoxedUnit.UNIT;
                }
                double cosineSimilarity = MODULE$.cosineSimilarity(vector, candidate);
                return cosineSimilarity >= BoxesRunTime.unboxToDouble(broadcast2.value()) ? arrayBuffer.$plus$eq(new Tuple2(row, BoxesRunTime.boxToDouble(cosineSimilarity))) : BoxedUnit.UNIT;
            });
            // Sort by negated similarity (i.e. descending) and keep this partition's best rows.
            return ((IndexedSeqOps) ((IndexedSeqOps) arrayBuffer.sortBy(tuple2 -> {
                return BoxesRunTime.boxToDouble($anonfun$filterSimilarVectors$4(tuple2));
            }, Ordering$DeprecatedDoubleOrdering$.MODULE$)).take(unboxToInt)).iterator();
        }, rdd.mapPartitions$default$2(), ClassTag$.MODULE$.apply(Tuple2.class)).map(tuple2 -> {
            if (tuple2 == null) {
                throw new MatchError(tuple2);
            }
            // Append the similarity score as an extra trailing field of the row.
            return Row$.MODULE$.fromSeq((Seq) ((Row) tuple2._1()).toSeq().$colon$plus(BoxesRunTime.boxToDouble(tuple2._2$mcD$sp())));
        }, ClassTag$.MODULE$.apply(Row.class)), withColumn.schema().add("similarity", DoubleType$.MODULE$)).orderBy(ScalaRunTime$.MODULE$.wrapRefArray(new Column[]{functions$.MODULE$.desc("similarity")})).limit(i).drop(DENSE_VECTOR_COLUMN);
        // The synthetic row id exists only when no ID column was supplied.
        return option.isEmpty() ? drop.drop(ROW_ID_COLUMN) : drop;
    }

    /** Scala-generated default for the 3rd parameter of {@code filterSimilarVectors} ({@code i}). */
    public int filterSimilarVectors$default$3() {
        return 10;
    }

    /** Scala-generated default for the 4th parameter of {@code filterSimilarVectors} ({@code d}). */
    public double filterSimilarVectors$default$4() {
        return 0.0d;
    }

    /** Scala-generated default for the 5th parameter of {@code filterSimilarVectors} ({@code str}). */
    public String filterSimilarVectors$default$5() {
        return "vector";
    }

    /** Scala-generated default for the 6th parameter of {@code filterSimilarVectors} ({@code option}). */
    public Option<String> filterSimilarVectors$default$6() {
        return None$.MODULE$;
    }

    /** Sort key for the per-partition ranking: the negated similarity, so ascending order is best-first. */
    public static final /* synthetic */ double $anonfun$filterSimilarVectors$4(Tuple2 tuple2) {
        return -tuple2._2$mcD$sp();
    }

    /** Private constructor: the only instance is {@link #MODULE$}. */
    private VectorBruteForceSearch$() {
    }
}
