package org.datavec.image.recordreader.objdetect;

import java.io.DataInputStream;
import java.io.File;
import java.io.IOException;
import java.net.URI;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import org.datavec.api.records.Record;
import org.datavec.api.records.metadata.RecordMetaDataImageURI;
import org.datavec.api.split.FileSplit;
import org.datavec.api.split.InputSplit;
import org.datavec.api.util.files.FileFromPathIterator;
import org.datavec.api.util.files.URIUtil;
import org.datavec.api.util.ndarray.RecordConverter;
import org.datavec.api.writable.NDArrayWritable;
import org.datavec.api.writable.Writable;
import org.datavec.api.writable.batch.NDArrayRecordBatch;
import org.datavec.image.data.Image;
import org.datavec.image.loader.NativeImageLoader;
import org.datavec.image.recordreader.BaseImageRecordReader;
import org.datavec.image.transform.ImageTransform;
import org.datavec.image.util.ImageUtils;
import org.nd4j.common.base.Preconditions;
import org.nd4j.linalg.api.concurrency.AffinityManager;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;

/* loaded from: input_file:org/datavec/image/recordreader/objdetect/ObjectDetectionRecordReader.class */
/**
 * Record reader for object detection: produces, per minibatch, an image feature array and a grid-based
 * bounding-box label array.
 *
 * <p>Output shapes (NCHW mode, the default):
 * <ul>
 *   <li>features: {@code [minibatch, channels, height, width]}</li>
 *   <li>labels: {@code [minibatch, 4 + numClasses, gridH, gridW]}</li>
 * </ul>
 * When {@code nchw == false}, both arrays are permuted with {@code (0, 2, 3, 1)}, i.e. to
 * {@code [minibatch, height, width, channels]} and {@code [minibatch, gridH, gridW, 4 + numClasses]}.
 *
 * <p>For each object, the grid cell containing the object's (scaled) centre receives: channels 0-1 = top-left
 * (x1, y1) and channels 2-3 = bottom-right (x2, y2), both expressed in grid units; and channel {@code 4 + k} = 1.0,
 * where {@code k} is the index of the object's label in the sorted label list built by {@link #initialize}.
 */
public class ObjectDetectionRecordReader extends BaseImageRecordReader {
    /** Number of grid cells along the image width. */
    private final int gridW;
    /** Number of grid cells along the image height. */
    private final int gridH;
    /** Source of bounding boxes/labels per image path; when null, {@code appendLabel} is false. */
    private final ImageObjectLabelProvider labelProvider;
    /** If false, outputs are permuted from NCHW to NHWC. */
    private final boolean nchw;
    /** The most recently loaded image (used for record metadata). */
    protected Image currentImage;

    /**
     * Creates an NCHW reader without an image transform.
     *
     * @param height        output image height
     * @param width         output image width
     * @param channels      output image channels
     * @param gridH         number of label grid cells along the height
     * @param gridW         number of label grid cells along the width
     * @param labelProvider provider of the image objects (bounding boxes + labels) for each path
     */
    public ObjectDetectionRecordReader(int height, int width, int channels, int gridH, int gridW,
                                       ImageObjectLabelProvider labelProvider) {
        this(height, width, channels, gridH, gridW, true, labelProvider);
    }

    /**
     * Creates a reader without an image transform.
     *
     * @param height        output image height
     * @param width         output image width
     * @param channels      output image channels
     * @param gridH         number of label grid cells along the height
     * @param gridW         number of label grid cells along the width
     * @param nchw          true for NCHW output, false for NHWC
     * @param labelProvider provider of the image objects (bounding boxes + labels) for each path
     */
    public ObjectDetectionRecordReader(int height, int width, int channels, int gridH, int gridW, boolean nchw,
                                       ImageObjectLabelProvider labelProvider) {
        this(height, width, channels, gridH, gridW, nchw, labelProvider, null);
    }

    /**
     * Creates an NCHW reader with an optional image transform.
     *
     * @param height         output image height
     * @param width          output image width
     * @param channels       output image channels
     * @param gridH          number of label grid cells along the height
     * @param gridW          number of label grid cells along the width
     * @param labelProvider  provider of the image objects (bounding boxes + labels) for each path
     * @param imageTransform transform applied when loading images; bounding boxes are mapped through it
     */
    public ObjectDetectionRecordReader(int height, int width, int channels, int gridH, int gridW,
                                       ImageObjectLabelProvider labelProvider, ImageTransform imageTransform) {
        this(height, width, channels, gridH, gridW, true, labelProvider, imageTransform);
    }

    /**
     * Creates a reader.
     *
     * @param height         output image height
     * @param width          output image width
     * @param channels       output image channels
     * @param gridH          number of label grid cells along the height
     * @param gridW          number of label grid cells along the width
     * @param nchw           true for NCHW output, false for NHWC
     * @param labelProvider  provider of the image objects (bounding boxes + labels) for each path
     * @param imageTransform transform applied when loading images (may be null); bounding boxes are mapped through it
     */
    public ObjectDetectionRecordReader(int height, int width, int channels, int gridH, int gridW, boolean nchw,
                                       ImageObjectLabelProvider labelProvider, ImageTransform imageTransform) {
        super(height, width, channels, null, null);
        this.gridW = gridW;
        this.gridH = gridH;
        this.nchw = nchw;
        this.labelProvider = labelProvider;
        this.appendLabel = labelProvider != null;
        this.imageTransform = imageTransform;
    }

    /** Returns the next single example; equivalent to {@code next(1).get(0)}. */
    @Override
    public List<Writable> next() {
        return next(1).get(0);
    }

    /**
     * Initializes the reader: creates the image loader if needed, collects the set of distinct object labels over
     * all locations in the split, and sets up the file iterator.
     *
     * @param split the input split to read from
     * @throws IllegalArgumentException if the split has no locations
     * @throws IllegalStateException    if this reader was built without a label provider
     */
    @Override
    public void initialize(InputSplit split) throws IOException {
        if (imageLoader == null) {
            imageLoader = new NativeImageLoader(height, width, channels, imageTransform);
        }
        inputSplit = split;
        URI[] locations = split.locations();
        if (locations == null || locations.length < 1) {
            throw new IllegalArgumentException("No path locations found in the split.");
        }
        Preconditions.checkState(labelProvider != null,
                "Cannot initialize ObjectDetectionRecordReader: no ImageObjectLabelProvider was provided");

        // Gather the distinct class names over every image in the split.
        HashSet<String> labelSet = new HashSet<>();
        for (URI location : locations) {
            for (ImageObject obj : labelProvider.getImageObjectsForPath(location)) {
                labelSet.add(obj.getLabel());
            }
        }
        iter = new FileFromPathIterator(inputSplit.locationsPathIterator());

        // Sort so that the label -> index mapping does not depend on file iteration order.
        labels = new ArrayList<>(labelSet);
        Collections.sort(labels);
    }

    /**
     * Reads up to {@code num} files (directories are skipped, so the batch may be smaller) and returns a single
     * batch of [features, labels] arrays.
     *
     * @param num maximum number of files to consume from the iterator
     * @return an {@link NDArrayRecordBatch} wrapping the feature and label arrays
     */
    @Override
    public List<List<Writable>> next(int num) {
        List<File> files = new ArrayList<>(num);
        List<List<ImageObject>> objectsPerFile = new ArrayList<>(num);
        for (int i = 0; i < num && hasNext(); i++) {
            File f = iter.next();
            currentFile = f;
            if (!f.isDirectory()) {
                files.add(f);
                objectsPerFile.add(labelProvider.getImageObjectsForPath(f.getPath()));
            }
        }

        int numClasses = labels.size();
        INDArray features = Nd4j.create(new long[]{files.size(), channels, height, width});
        INDArray outLabels = Nd4j.create(new long[]{files.size(), 4 + numClasses, gridH, gridW});
        for (int i = 0; i < files.size(); i++) {
            File f = files.get(i);
            currentFile = f;
            try {
                invokeListeners(f);
                Image image = imageLoader.asImageMatrix(f);
                currentImage = image;
                Nd4j.getAffinityManager().ensureLocation(image.getImage(), AffinityManager.Location.DEVICE);
                features.put(new INDArrayIndex[]{NDArrayIndex.point(i), NDArrayIndex.all(), NDArrayIndex.all(),
                        NDArrayIndex.all()}, image.getImage());
                label(image, objectsPerFile.get(i), outLabels, i);
            } catch (IOException e) {
                throw new RuntimeException("Error loading image file: " + f, e);
            }
        }

        if (!nchw) {
            // NCHW -> NHWC
            features = features.permute(new long[]{0, 2, 3, 1});
            outLabels = outLabels.permute(new long[]{0, 2, 3, 1});
        }
        return new NDArrayRecordBatch(Arrays.asList(features, outLabels));
    }

    /**
     * Writes the grid labels for one example into {@code outLabels} (expected in NCHW layout).
     *
     * @param image      the loaded image (its original width/height are used when no transform is set)
     * @param objects    the objects (bounding boxes + labels) of this image, in original-image pixel coordinates
     * @param outLabels  label array of shape {@code [minibatch, 4 + numClasses, gridH, gridW]}
     * @param exampleIdx index of this example along dimension 0 of {@code outLabels}
     * @throws IllegalStateException if an object's centre falls outside the grid or its label is unknown
     */
    private void label(Image image, List<ImageObject> objects, INDArray outLabels, int exampleIdx) {
        int srcW = image.getOrigW();
        int srcH = image.getOrigH();
        for (ImageObject original : objects) {
            ImageObject obj = original;
            double cx = obj.getXCenterPixels();
            double cy = obj.getYCenterPixels();
            if (imageTransform != null) {
                // Box coordinates must be mapped through the transform, and scaled relative to the transformed image.
                srcW = imageTransform.getCurrentImage().getWidth();
                srcH = imageTransform.getCurrentImage().getHeight();
                float[] pts = imageTransform.query(obj.getX1(), obj.getY1(), obj.getX2(), obj.getY2());
                obj = new ImageObject(
                        Math.round(Math.min(pts[0], pts[2])), Math.round(Math.min(pts[1], pts[3])),
                        Math.round(Math.max(pts[0], pts[2])), Math.round(Math.max(pts[1], pts[3])),
                        obj.getLabel());
                cx = obj.getXCenterPixels();
                cy = obj.getYCenterPixels();
                if (cx < 0.0d || cx >= srcW || cy < 0.0d || cy >= srcH) {
                    // The transform (e.g. a crop) moved the object's centre out of the image: no grid cell owns it.
                    continue;
                }
            }

            double[] center = ImageUtils.translateCoordsScaleImage(cx, cy, srcW, srcH, width, height);
            double[] topLeft = ImageUtils.translateCoordsScaleImage(obj.getX1(), obj.getY1(), srcW, srcH, width, height);
            double[] bottomRight = ImageUtils.translateCoordsScaleImage(obj.getX2(), obj.getY2(), srcW, srcH, width, height);

            // Convert from output-image pixels to grid units.
            int gridX = (int) ((center[0] / width) * gridW);
            int gridY = (int) ((center[1] / height) * gridH);
            double x1 = (topLeft[0] / width) * gridW;
            double y1 = (topLeft[1] / height) * gridH;
            double x2 = (bottomRight[0] / width) * gridW;
            double y2 = (bottomRight[1] / height) * gridH;

            Preconditions.checkState(gridY >= 0 && gridY < outLabels.size(2), "Invalid image center in Y axis: calculated grid location of %s, must be between 0 (inclusive) and %s (exclusive). Object label center is outside of image bounds. Image object: %s", Integer.valueOf(gridY), Long.valueOf(outLabels.size(2)), obj);
            Preconditions.checkState(gridX >= 0 && gridX < outLabels.size(3), "Invalid image center in X axis: calculated grid location of %s, must be between 0 (inclusive) and %s (exclusive). Object label center is outside of image bounds. Image object: %s", Integer.valueOf(gridX), Long.valueOf(outLabels.size(3)), obj);

            // Without this check, indexOf == -1 would write the class indicator into channel 3 (y2).
            int classIdx = labels.indexOf(obj.getLabel());
            Preconditions.checkState(classIdx >= 0, "Unknown label \"%s\" for image object %s: label is not in the label set collected by initialize(): %s", obj.getLabel(), obj, labels);

            outLabels.putScalar(exampleIdx, 0L, gridY, gridX, x1);
            outLabels.putScalar(exampleIdx, 1L, gridY, gridX, y1);
            outLabels.putScalar(exampleIdx, 2L, gridY, gridX, x2);
            outLabels.putScalar(exampleIdx, 3L, gridY, gridX, y2);
            outLabels.putScalar(exampleIdx, 4 + classIdx, gridY, gridX, 1.0d);
        }
    }

    /**
     * Loads one image from a stream and, if a label provider is set, appends its grid label array
     * (shape {@code [1, 4 + numClasses, gridH, gridW]}, permuted to NHWC when {@code nchw} is false).
     *
     * @param uri             location of the image, used to look up its objects
     * @param dataInputStream stream containing the image data
     * @return the image record, optionally followed by an {@link NDArrayWritable} with the labels
     */
    @Override
    public List<Writable> record(URI uri, DataInputStream dataInputStream) throws IOException {
        invokeListeners(uri);
        if (imageLoader == null) {
            imageLoader = new NativeImageLoader(height, width, channels, imageTransform);
        }
        Image image = imageLoader.asImageMatrix(dataInputStream);
        if (!nchw) {
            image.setImage(image.getImage().permute(new long[]{0, 2, 3, 1}));
        }
        Nd4j.getAffinityManager().ensureLocation(image.getImage(), AffinityManager.Location.DEVICE);
        List<Writable> record = RecordConverter.toRecord(image.getImage());
        if (appendLabel) {
            List<ImageObject> objects = labelProvider.getImageObjectsForPath(uri.getPath());
            INDArray labelArr = Nd4j.create(new long[]{1, 4 + labels.size(), gridH, gridW});
            label(image, objects, labelArr, 0);
            if (!nchw) {
                labelArr = labelArr.permute(new long[]{0, 2, 3, 1});
            }
            record.add(new NDArrayWritable(labelArr));
        }
        return record;
    }

    /** Returns the next example together with image metadata (URI and original channels/height/width). */
    @Override
    public Record nextRecord() {
        List<Writable> values = next();
        URI uri = URIUtil.fileToURI(currentFile);
        return new org.datavec.api.records.impl.Record(values, new RecordMetaDataImageURI(uri,
                BaseImageRecordReader.class, currentImage.getOrigC(), currentImage.getOrigH(), currentImage.getOrigW()));
    }
}
