package org.deeplearning4j.parallelism.inference.observers;

import java.util.Collections;
import java.util.List;
import java.util.Observable;
import lombok.Generated;
import lombok.NonNull;
import org.deeplearning4j.parallelism.inference.InferenceObservable;
import org.nd4j.common.primitives.Pair;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.shade.guava.base.Preconditions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/parallelism/inference/observers/BasicInferenceObservable.class */
/**
 * Single-request {@link InferenceObservable}.
 *
 * <p>Holds exactly one set of input arrays (plus optional input masks) and exposes it as a
 * one-element batch list via {@link #getInputBatches()}. A result arrives either through
 * {@link #setOutputBatches(List)} (which must contain exactly one output entry) or through
 * {@link #setOutputException(Exception)}; in both cases registered {@link java.util.Observer}s
 * are notified via {@link Observable#notifyObservers()}.
 *
 * <p>Depending on the constructor used, the request may additionally carry the names
 * ({@link #layersToOutputTo}) or indices ({@link #layerIndicesOutputTo}) of the layers to output
 * from; the input-only constructors leave both {@code null}.
 */
public class BasicInferenceObservable extends Observable implements InferenceObservable {

    @Generated
    private static final Logger log = LoggerFactory.getLogger(BasicInferenceObservable.class);
    /** Input arrays for this request. */
    private INDArray[] input;
    /** Optional masks matching {@link #input}; may be {@code null}. */
    private INDArray[] inputMasks;
    // NOTE(review): never assigned in this class, so getId() always returns 0 — confirm intent.
    private long id;
    /** Output arrays, set by {@link #setOutputBatches(List)}. */
    private INDArray[] output;
    /** Failure recorded by {@link #setOutputException(Exception)}; rethrown from {@link #getOutput()}. */
    protected Exception exception;
    /** Layer names to output from, or {@code null} when not specified. */
    protected String[] layersToOutputTo;
    /** Layer indices to output from, or {@code null} when not specified. */
    protected int[] layerIndicesOutputTo;

    /**
     * Creates an observable that outputs from the given layer indices, without input masks.
     *
     * @param layerIndicesOutputTo indices of the layers to output from
     * @param inputs input arrays
     */
    public BasicInferenceObservable(int[] layerIndicesOutputTo, INDArray... inputs) {
        this(layerIndicesOutputTo, inputs, (INDArray[]) null);
    }

    /**
     * Creates an observable that outputs from the given layer indices.
     *
     * @param layerIndicesOutputTo indices of the layers to output from
     * @param inputs input arrays
     * @param inputMasks input masks; may be {@code null}
     */
    public BasicInferenceObservable(int[] layerIndicesOutputTo, INDArray[] inputs, INDArray[] inputMasks) {
        this.layerIndicesOutputTo = layerIndicesOutputTo;
        this.input = inputs;
        this.inputMasks = inputMasks;
    }

    /**
     * Creates an observable that outputs from the named layers, without input masks.
     *
     * @param layersToOutputTo names of the layers to output from
     * @param inputs input arrays
     */
    public BasicInferenceObservable(String[] layersToOutputTo, INDArray... inputs) {
        this(layersToOutputTo, inputs, (INDArray[]) null);
    }

    /**
     * Creates an observable that outputs from the named layers.
     *
     * @param layersToOutputTo names of the layers to output from
     * @param inputs input arrays
     * @param inputMasks input masks; may be {@code null}
     */
    public BasicInferenceObservable(String[] layersToOutputTo, INDArray[] inputs, INDArray[] inputMasks) {
        this.layersToOutputTo = layersToOutputTo;
        this.input = inputs;
        this.inputMasks = inputMasks;
    }

    /**
     * Creates an observable with inputs only (no masks, no explicit output layers).
     *
     * @param inputs input arrays
     */
    public BasicInferenceObservable(INDArray... inputs) {
        this(inputs, (INDArray[]) null);
    }

    /**
     * Creates an observable with inputs and optional masks (no explicit output layers).
     *
     * @param inputs input arrays
     * @param inputMasks input masks; may be {@code null}
     */
    public BasicInferenceObservable(INDArray[] inputs, INDArray[] inputMasks) {
        this.input = inputs;
        this.inputMasks = inputMasks;
    }

    /**
     * Replaces the input arrays and clears the masks.
     *
     * @param inputs new input arrays
     * @throws NullPointerException if {@code inputs} is {@code null}
     */
    @Override // org.deeplearning4j.parallelism.inference.InferenceObservable
    public void addInput(@NonNull INDArray... inputs) {
        if (inputs == null) {
            throw new NullPointerException("input is marked non-null but is null");
        }
        addInput(inputs, null);
    }

    /**
     * Replaces the input arrays and masks. Note this overwrites rather than appends.
     *
     * @param inputs new input arrays
     * @param inputMasks new input masks; may be {@code null}
     * @throws NullPointerException if {@code inputs} is {@code null}
     */
    @Override // org.deeplearning4j.parallelism.inference.InferenceObservable
    public void addInput(@NonNull INDArray[] inputs, INDArray[] inputMasks) {
        if (inputs == null) {
            throw new NullPointerException("input is marked non-null but is null");
        }
        this.input = inputs;
        this.inputMasks = inputMasks;
    }

    /**
     * Stores the single output entry and notifies observers.
     *
     * @param outputBatches output list; must contain exactly one element
     * @throws NullPointerException if {@code outputBatches} is {@code null}
     * @throws IllegalArgumentException if {@code outputBatches} does not have size 1
     */
    @Override // org.deeplearning4j.parallelism.inference.InferenceObservable
    public void setOutputBatches(@NonNull List<INDArray[]> outputBatches) {
        if (outputBatches == null) {
            throw new NullPointerException("output is marked non-null but is null");
        }
        // Template form: the message is only formatted if the check fails.
        Preconditions.checkArgument(outputBatches.size() == 1,
                "Expected size 1 output: got size %s", outputBatches.size());
        this.output = outputBatches.get(0);
        setChanged();
        notifyObservers();
    }

    /**
     * Returns the held inputs as a one-element list of (inputs, masks) pairs.
     *
     * @return a singleton list; the pair's masks may be {@code null}
     */
    @Override // org.deeplearning4j.parallelism.inference.InferenceObservable
    public List<Pair<INDArray[], INDArray[]>> getInputBatches() {
        Pair<INDArray[], INDArray[]> batch = new Pair<>(this.input, this.inputMasks);
        return Collections.singletonList(batch);
    }

    /**
     * Records a failure and notifies observers; it is rethrown by {@link #getOutput()}.
     *
     * @param exception the failure to record
     */
    @Override // org.deeplearning4j.parallelism.inference.InferenceObservable
    public void setOutputException(Exception exception) {
        this.exception = exception;
        setChanged();
        notifyObservers();
    }

    /**
     * Returns the output arrays, first rethrowing any recorded exception.
     *
     * @return the output set by {@link #setOutputBatches(List)}, or {@code null} if none yet
     * @throws RuntimeException if an exception was recorded
     */
    @Override // org.deeplearning4j.parallelism.inference.InferenceObservable
    public INDArray[] getOutput() {
        checkOutputException();
        return this.output;
    }

    /**
     * Rethrows the recorded exception, if any.
     *
     * <p>A {@link RuntimeException} is rethrown as-is; a checked exception is wrapped in a
     * {@link RuntimeException} with the original as cause. Does nothing if no exception
     * was recorded.
     */
    /* JADX INFO: Access modifiers changed from: protected */
    public void checkOutputException() {
        if (this.exception == null) {
            return;
        }
        if (this.exception instanceof RuntimeException) {
            throw (RuntimeException) this.exception;
        }
        throw new RuntimeException("Exception encountered while getting output: " + this.exception.getMessage(), this.exception);
    }

    /** @return the id field (see review note on the field declaration). */
    @Generated
    public long getId() {
        return this.id;
    }
}
