package org.deeplearning4j.datasets.base;

import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.List;
import org.apache.commons.io.IOUtils;
import org.deeplearning4j.common.resources.DL4JResources;
import org.deeplearning4j.common.resources.ResourceType;
import org.nd4j.common.resources.Downloader;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j;

/**
 * Loads the Iris dataset as a list of single-example {@link DataSet}s.
 *
 * <p>The data file ({@code iris.dat}) is cached in the DL4J dataset resource directory
 * ({@code ResourceType.DATASET}, sub-directory {@code "iris"}) and fetched via
 * {@link Downloader} on first use. Each line of the file is parsed as four comma-separated
 * numeric features followed by an integer class label in {@code [0, 3)}, which is one-hot
 * encoded.
 */
public class IrisUtils {
    /** Location of the data file, relative to the DL4J resources base URL. */
    private static final String IRIS_RELATIVE_URL = "datasets/iris.dat";
    /** Checksum handed to {@link Downloader} for the downloaded data file. */
    private static final String MD5 = "1c21400a78061197eac64c6748844216";
    /** Number of numeric feature columns at the start of each line. */
    private static final int NUM_FEATURES = 4;
    /** Number of classes; labels are one-hot vectors of this length. */
    private static final int NUM_CLASSES = 3;

    private IrisUtils() {
    }

    /**
     * Loads the examples stored on lines {@code [from, to)} of the Iris data file,
     * downloading the file first if it is not present locally.
     *
     * @param from index of the first line to load (inclusive, 0-based)
     * @param to   index of the line to stop at (exclusive)
     * @return one {@link DataSet} per loaded line, each holding a length-4 feature vector and a
     *         length-3 one-hot label vector; empty when {@code from == to}
     * @throws IOException              if the data file cannot be downloaded or read
     * @throws IllegalArgumentException if the range is outside the file, or a line in the range
     *                                  is malformed (too few columns, label not in [0, 3))
     * @throws NumberFormatException    if a feature or label value is not numeric
     */
    public static List<DataSet> loadIris(int from, int to) throws IOException {
        File file = new File(DL4JResources.getDirectory(ResourceType.DATASET, "iris"), "iris.dat");
        if (!file.exists()) {
            // NOTE(review): the trailing 3 is presumably the number of download attempts.
            Downloader.download("Iris", DL4JResources.getURL(IRIS_RELATIVE_URL), file, MD5, 3);
        }

        List<String> lines;
        try (InputStream in = new FileInputStream(file)) {
            // Explicit charset: the one-argument overload uses the platform default.
            lines = IOUtils.readLines(in, StandardCharsets.UTF_8);
        }

        if (from < 0 || to < from || to > lines.size()) {
            throw new IllegalArgumentException("Invalid line range [" + from + ", " + to
                    + ") for Iris data file with " + lines.size() + " lines");
        }
        int count = to - from;
        List<DataSet> result = new ArrayList<>(count);
        if (count == 0) {
            // Previously an empty/reversed range produced one bogus all-ones example.
            return result;
        }

        // Always 2-D ([count, 4]) so putRow/getRow behave the same for one row as for many.
        INDArray features = Nd4j.ones(new int[]{count, NUM_FEATURES});
        for (int row = 0; row < count; row++) {
            String[] split = lines.get(from + row).split(",");
            addRow(features, row, split);
            double[] label = oneHot(split[split.length - 1].trim());
            result.add(new DataSet(features.getRow(row, false),
                    Nd4j.create(label, new long[]{NUM_CLASSES})));
        }
        return result;
    }

    /**
     * Parses the first {@value #NUM_FEATURES} columns of a line and writes them into row
     * {@code i} of {@code iNDArray}.
     *
     * @param iNDArray destination matrix with at least {@code i + 1} rows
     * @param i        destination row index
     * @param strArr   the comma-split columns of one data-file line
     * @throws IllegalArgumentException if the line has fewer than features + label columns
     */
    private static void addRow(INDArray iNDArray, int i, String[] strArr) {
        if (strArr.length < NUM_FEATURES + 1) {
            throw new IllegalArgumentException("Malformed Iris line (expected "
                    + (NUM_FEATURES + 1) + " columns): " + String.join(",", strArr));
        }
        double[] values = new double[NUM_FEATURES];
        for (int col = 0; col < NUM_FEATURES; col++) {
            values[col] = Double.parseDouble(strArr[col]);
        }
        iNDArray.putRow(i, Nd4j.create(values));
    }

    /**
     * One-hot encodes an integer class label.
     *
     * @param label the label column text, already trimmed
     * @return a length-{@value #NUM_CLASSES} vector with a single 1.0 at the label index
     * @throws IllegalArgumentException if the label is outside {@code [0, NUM_CLASSES)}
     */
    private static double[] oneHot(String label) {
        int cls = Integer.parseInt(label);
        if (cls < 0 || cls >= NUM_CLASSES) {
            throw new IllegalArgumentException(
                    "Iris class label out of range [0, " + NUM_CLASSES + "): " + label);
        }
        double[] encoded = new double[NUM_CLASSES];
        encoded[cls] = 1.0d;
        return encoded;
    }
}
