package smile.feature.imputation;

import java.lang.invoke.SerializedLambda;
import java.util.Arrays;
import smile.data.AbstractTuple;
import smile.data.DataFrame;
import smile.data.Tuple;
import smile.data.measure.NominalScale;
import smile.data.transform.Transform;
import smile.data.type.StructField;
import smile.math.MathEx;
import smile.math.distance.Distance;
import smile.neighbor.KNNSearch;
import smile.neighbor.LinearSearch;
import smile.neighbor.Neighbor;

/* loaded from: input_file:smile/feature/imputation/KNNImputer.class */
/**
 * Missing value imputation with k-nearest neighbors.
 *
 * <p>For every tuple passed to {@link #apply(Tuple)}, the {@code k} nearest rows of the
 * reference data frame are searched once. Each field that {@code SimpleImputer.isMissing}
 * reports as missing is then filled from those neighbors:
 * <ul>
 *   <li>boolean, char and nominal fields use the mode of the neighbors' values;</li>
 *   <li>other numeric fields use the mean of the neighbors' values;</li>
 *   <li>any other field, or a field where no neighbor has a usable value, yields {@code null}.</li>
 * </ul>
 */
public class KNNImputer implements Transform {
    /** The number of nearest neighbors consulted for each imputation. */
    private final int k;
    /** Nearest neighbor search over the rows of the reference data frame. */
    private final KNNSearch<Tuple, Tuple> knn;

    /**
     * Constructor.
     *
     * @param data the reference data whose rows are the neighbor candidates.
     * @param k the number of nearest neighbors to use.
     * @param distance the distance function between tuples.
     */
    public KNNImputer(DataFrame data, int k, Distance<Tuple> distance) {
        this.k = k;
        // The explicit type witness widens the row stream to Tuple so the
        // resulting search index is typed KNNSearch<Tuple, Tuple>.
        this.knn = LinearSearch.of(data.stream().<Tuple>map(row -> row).toList(), distance);
    }

    /**
     * Constructor that measures distance with
     * {@code MathEx.squaredDistanceWithMissingValues} over the given columns.
     *
     * @param data the reference data whose rows are the neighbor candidates.
     * @param k the number of nearest neighbors to use.
     * @param columns the columns used to compute the distance; passed to
     *                {@code Tuple.toArray(String...)}.
     */
    public KNNImputer(DataFrame data, int k, String... columns) {
        // Distance is a serializable functional interface; javac emits the
        // synthetic $deserializeLambda$ support for this lambda automatically.
        this(data, k, (Distance<Tuple>) (x, y) ->
                MathEx.squaredDistanceWithMissingValues(x.toArray(columns), y.toArray(columns)));
    }

    /**
     * Returns a view of {@code x} with missing fields imputed.
     *
     * <p>The nearest neighbors are searched once, eagerly. Each missing field is then
     * imputed on demand whenever {@code get} is called on the returned tuple. Fields
     * that are not missing are returned unchanged.
     *
     * @param x the tuple to impute.
     * @return a tuple with the same schema as {@code x}.
     */
    @Override
    public Tuple apply(Tuple x) {
        Neighbor<Tuple, Tuple>[] neighbors = knn.search(x, k);
        return new AbstractTuple(x.schema()) {
            @Override
            public Object get(int i) {
                Object value = x.get(i);
                if (!SimpleImputer.isMissing(value)) {
                    return value;
                }
                return impute(schema.field(i), neighbors, i);
            }
        };
    }

    /**
     * Imputes the value of field {@code i} from the given neighbors.
     *
     * @return the imputed value, or {@code null} when the field type is not supported
     *         or no neighbor has a usable value for the field.
     */
    private static Object impute(StructField field, Neighbor<Tuple, Tuple>[] neighbors, int i) {
        if (field.dtype().isBoolean()) {
            int[] values = neighborInts(neighbors, i);
            if (values.length == 0) {
                return null;
            }
            return Boolean.valueOf(MathEx.mode(values) != 0);
        }

        if (field.dtype().isChar()) {
            int[] values = neighborInts(neighbors, i);
            if (values.length == 0) {
                return null;
            }
            return Character.valueOf((char) MathEx.mode(values));
        }

        if (field.measure() instanceof NominalScale) {
            int[] values = neighborInts(neighbors, i);
            if (values.length == 0) {
                return null;
            }
            return Integer.valueOf(MathEx.mode(values));
        }

        if (field.dtype().isNumeric()) {
            double[] values = neighborDoubles(neighbors, i);
            if (values.length == 0) {
                return null;
            }
            return Double.valueOf(MathEx.mean(values));
        }

        return null;
    }

    /** Collects the neighbors' integer values of field {@code i}, dropping the Integer.MIN_VALUE sentinel. */
    private static int[] neighborInts(Neighbor<Tuple, Tuple>[] neighbors, int i) {
        int[] values = Arrays.stream(neighbors)
                .mapToInt(neighbor -> neighbor.key().getInt(i))
                .toArray();
        return MathEx.omit(values, Integer.MIN_VALUE);
    }

    /**
     * Collects the neighbors' floating-point values of field {@code i}. It drops both the
     * Integer.MIN_VALUE sentinel and NaN, so a neighbor that is itself missing this field
     * cannot turn the mean into NaN.
     */
    private static double[] neighborDoubles(Neighbor<Tuple, Tuple>[] neighbors, int i) {
        return Arrays.stream(neighbors)
                .mapToDouble(neighbor -> neighbor.key().getDouble(i))
                .filter(v -> v != Integer.MIN_VALUE && !Double.isNaN(v))
                .toArray();
    }
}
