package smile.regression;

import java.util.Objects;
import java.util.Properties;
import smile.data.DataFrame;
import smile.data.formula.Formula;
import smile.data.type.StructType;
import smile.math.MathEx;
import smile.math.matrix.Matrix;

/* loaded from: input_file:smile/regression/ElasticNet.class */
/**
 * Elastic net regression: a linear model whose coefficients are penalized by both an L1
 * (LASSO) term and an L2 (ridge) term.
 *
 * <p>The fit is computed by reducing the problem to a LASSO problem on an augmented data
 * set. The standardized design matrix, scaled by {@code c = 1 / sqrt(1 + lambda2)}, is
 * stacked on top of a diagonal block with entries {@code c * sqrt(lambda2)}. The centered
 * response is padded with zeros. The augmented problem is handed to {@code LASSO.train}.
 * Finally, the resulting coefficients are mapped back to the original column scale.
 */
public class ElasticNet {
    /**
     * Fits an elastic net model with hyper-parameters read from a {@link Properties} object.
     *
     * <p>Recognized keys:
     * <ul>
     *   <li>{@code smile.elastic_net.lambda1}: the L1 shrinkage parameter (required);</li>
     *   <li>{@code smile.elastic_net.lambda2}: the L2 shrinkage parameter (required);</li>
     *   <li>{@code smile.elastic_net.tolerance}: the tolerance passed to the LASSO solver
     *       (default {@code 1E-4});</li>
     *   <li>{@code smile.elastic_net.iterations}: the maximum number of iterations passed to
     *       the LASSO solver (default {@code 1000}).</li>
     * </ul>
     *
     * @param formula a symbolic description of the model to be fitted.
     * @param data the data frame of the explanatory and response variables.
     * @param params the hyper-parameters.
     * @return the fitted model.
     * @throws NullPointerException if a required property is missing.
     * @throws NumberFormatException if a property value cannot be parsed.
     * @throws IllegalArgumentException if a shrinkage parameter is not positive or the design
     *     matrix has a constant column.
     */
    public static LinearModel fit(Formula formula, DataFrame data, Properties params) {
        double lambda1 = Double.parseDouble(requireProperty(params, "smile.elastic_net.lambda1"));
        double lambda2 = Double.parseDouble(requireProperty(params, "smile.elastic_net.lambda2"));
        double tol = Double.parseDouble(params.getProperty("smile.elastic_net.tolerance", "1E-4"));
        int maxIter = Integer.parseInt(params.getProperty("smile.elastic_net.iterations", "1000"));
        return fit(formula, data, lambda1, lambda2, tol, maxIter);
    }

    /**
     * Fits an elastic net model with solver tolerance {@code 1E-4} and at most 1000 iterations.
     *
     * @param formula a symbolic description of the model to be fitted.
     * @param data the data frame of the explanatory and response variables.
     * @param lambda1 the L1 shrinkage parameter; must be positive.
     * @param lambda2 the L2 shrinkage parameter; must be positive.
     * @return the fitted model.
     * @throws IllegalArgumentException if a shrinkage parameter is not positive or the design
     *     matrix has a constant column.
     */
    public static LinearModel fit(Formula formula, DataFrame data, double lambda1, double lambda2) {
        return fit(formula, data, lambda1, lambda2, 1.0E-4, 1000);
    }

    /**
     * Fits an elastic net model.
     *
     * @param formula a symbolic description of the model to be fitted.
     * @param data the data frame of the explanatory and response variables.
     * @param lambda1 the L1 shrinkage parameter; must be positive (use Ridge otherwise).
     * @param lambda2 the L2 shrinkage parameter; must be positive (use LASSO otherwise).
     * @param tol the tolerance passed to the LASSO solver.
     * @param maxIter the maximum number of iterations passed to the LASSO solver.
     * @return the fitted model.
     * @throws IllegalArgumentException if a shrinkage parameter is not positive, or a column
     *     of the design matrix has zero or undefined standard deviation.
     */
    public static LinearModel fit(Formula formula, DataFrame data, double lambda1, double lambda2, double tol, int maxIter) {
        if (lambda1 <= 0.0) {
            throw new IllegalArgumentException("Please use Ridge instead, wrong L1 portion setting: " + lambda1);
        }
        if (lambda2 <= 0.0) {
            throw new IllegalArgumentException("Please use LASSO instead, wrong L2 portion setting: " + lambda2);
        }

        // Scaling factor of the augmented problem; also used to map the solution back.
        double c = 1.0 / Math.sqrt(1.0 + lambda2);
        Formula expanded = formula.expand(data.schema());
        StructType schema = expanded.bind(data.schema());
        Matrix x = expanded.matrix(data, false);
        double[] y = expanded.y(data).toDoubleArray();

        int n = x.nrow();
        int p = x.ncol();
        double[] center = x.colMeans();
        double[] scale = x.colSds();
        // Standardization divides by scale[j]; reject columns that would yield Infinity/NaN.
        checkScale(scale);

        // Centered response followed by p zeros, which pair with the diagonal padding rows.
        double ym = MathEx.mean(y);
        double[] y2 = new double[n + p];
        for (int i = 0; i < n; i++) {
            y2[i] = y[i] - ym;
        }

        // Rows 0..n-1: standardized design matrix times c; rows n..n+p-1: diagonal padding.
        Matrix x2 = new Matrix(n + p, p);
        double padding = c * Math.sqrt(lambda2);
        for (int j = 0; j < p; j++) {
            for (int i = 0; i < n; i++) {
                x2.set(i, j, c * (x.get(i, j) - center[j]) / scale[j]);
            }
            x2.set(n + j, j, padding);
        }

        double[] w = LASSO.train(x2, y2, lambda1 * c, tol, maxIter);
        // Undo the c-scaling and the standardization to get coefficients on the original scale.
        for (int j = 0; j < p; j++) {
            w[j] = c * w[j] / scale[j];
        }

        double b = ym - MathEx.dot(w, center);
        return new LinearModel(expanded, schema, x, y, w, b);
    }

    /**
     * Returns the value of a mandatory property.
     *
     * @throws NullPointerException naming {@code key} if the property is absent.
     */
    private static String requireProperty(Properties params, String key) {
        return Objects.requireNonNull(params.getProperty(key), () -> "Missing required property: " + key);
    }

    /**
     * Rejects design-matrix columns whose standard deviation is zero (in machine precision)
     * or NaN, since standardizing such a column would divide by zero.
     *
     * @throws IllegalArgumentException identifying the first offending column index.
     */
    private static void checkScale(double[] scale) {
        for (int j = 0; j < scale.length; j++) {
            double sd = scale[j];
            if (Double.isNaN(sd) || MathEx.isZero(sd)) {
                throw new IllegalArgumentException("Column " + j + " of the design matrix is constant or has undefined standard deviation: " + sd);
            }
        }
    }
}
