package org.deeplearning4j.parallelism;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.locks.ReentrantReadWriteLock;
import lombok.Generated;
import lombok.NonNull;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.api.ModelAdapter;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.parallelism.inference.LoadBalanceMode;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/parallelism/InplaceParallelInference.class */
/**
 * {@link ParallelInference} implementation that executes inference on the calling thread instead of
 * handing requests to background workers.
 *
 * <p>{@link #init()} builds one {@link ModelHolder} per device reported by Nd4j's affinity manager.
 * Every {@code output(...)} call is routed through {@link ModelSelector} to the holder registered for
 * the device the calling thread is bound to, and runs on one of that holder's model replicas.
 */
public class InplaceParallelInference extends ParallelInference {

    @Generated
    private static final Logger log = LoggerFactory.getLogger(InplaceParallelInference.class);

    /** ComputationGraph layer names to return activations for; handed to every {@link ModelHolder}. */
    protected String[] layersToOutputTo;
    /** MultiLayerNetwork layer index to return activations for; handed to every {@link ModelHolder}. */
    protected int[] layerIndicesOutputTo;
    /** One holder per device, added in device-id order by {@link #init()}. */
    protected List<ModelHolder> holders = new CopyOnWriteArrayList<>();
    /** Maps device ids to holders and dispatches calls from the current thread. */
    protected ModelSelector selector = new ModelSelector();
    protected final Object locker = new Object();

    /**
     * Owns the replicas of the source model for a single device and lends them out to inference calls
     * according to the configured {@link LoadBalanceMode}.
     *
     * <p>{@link #modelLock} guards the replica set: the inference paths hold its read lock for the whole
     * acquire/use/release cycle, while {@link #updateModel(Model)} takes the write lock and rebuilds the
     * replicas via {@link #init()}.
     */
    public static class ModelHolder {
        private static final int DEFAULT_WORKERS = 4;
        private static final boolean DEFAULT_ROOT_DEVICE = true;

        /** Model whose configuration and parameters the replicas are built from. */
        protected Model sourceModel;
        /** Number of replicas to build; must be at least 1. */
        protected int workers;
        /** Replicas built by {@link #init()}. */
        protected List<Model> replicas;
        /** When {@code true}, replicas share the source model's parameter array instead of a copy. */
        protected boolean rootDevice;
        protected LoadBalanceMode loadBalanceMode;
        /** ComputationGraph layers to output; {@code null} means the graph's regular outputs. */
        protected String[] layersToOutputTo;
        /** MultiLayerNetwork layer to output (only element 0 is used); {@code null} means the network output. */
        protected int[] layerIndicesOutputTo;
        /** Device the parameter copy is replicated to when {@link #rootDevice} is {@code false}. */
        protected int targetDeviceId;
        /** Cursor for {@link LoadBalanceMode#ROUND_ROBIN}. */
        protected final AtomicLong position = new AtomicLong(0L);
        protected final ReentrantReadWriteLock modelLock = new ReentrantReadWriteLock();
        /** Idle replicas for {@link LoadBalanceMode#FIFO}. */
        protected final BlockingQueue<Model> queue = new LinkedBlockingQueue<>();
        protected transient boolean isCG;
        protected transient boolean isMLN;

        /** Builder for {@link ModelHolder}; options left unset get the same defaults as {@link #ModelHolder()}. */
        @Generated
        public static class ModelHolderBuilder {
            private Model sourceModel;
            private boolean workers$set;
            private int workers$value;
            private boolean replicas$set;
            private List<Model> replicas$value;
            private boolean rootDevice$set;
            private boolean rootDevice$value;
            private boolean loadBalanceMode$set;
            private LoadBalanceMode loadBalanceMode$value;
            private String[] layersToOutputTo;
            private int[] layerIndicesOutputTo;
            private int targetDeviceId;
            private boolean isCG$set;
            private boolean isCG$value;
            private boolean isMLN$set;
            private boolean isMLN$value;

            ModelHolderBuilder() {
            }

            public ModelHolderBuilder sourceModel(Model model) {
                this.sourceModel = model;
                return this;
            }

            public ModelHolderBuilder workers(int workers) {
                this.workers$value = workers;
                this.workers$set = true;
                return this;
            }

            public ModelHolderBuilder replicas(List<Model> replicas) {
                this.replicas$value = replicas;
                this.replicas$set = true;
                return this;
            }

            public ModelHolderBuilder rootDevice(boolean rootDevice) {
                this.rootDevice$value = rootDevice;
                this.rootDevice$set = true;
                return this;
            }

            public ModelHolderBuilder loadBalanceMode(LoadBalanceMode loadBalanceMode) {
                this.loadBalanceMode$value = loadBalanceMode;
                this.loadBalanceMode$set = true;
                return this;
            }

            public ModelHolderBuilder layersToOutputTo(String[] layersToOutputTo) {
                this.layersToOutputTo = layersToOutputTo;
                return this;
            }

            public ModelHolderBuilder layerIndicesOutputTo(int[] layerIndicesOutputTo) {
                this.layerIndicesOutputTo = layerIndicesOutputTo;
                return this;
            }

            public ModelHolderBuilder targetDeviceId(int targetDeviceId) {
                this.targetDeviceId = targetDeviceId;
                return this;
            }

            public ModelHolderBuilder isCG(boolean isCG) {
                this.isCG$value = isCG;
                this.isCG$set = true;
                return this;
            }

            public ModelHolderBuilder isMLN(boolean isMLN) {
                this.isMLN$value = isMLN;
                this.isMLN$set = true;
                return this;
            }

            /** Creates the holder; {@link ModelHolder#init()} still has to be called before use. */
            public ModelHolder build() {
                return new ModelHolder(
                        sourceModel,
                        workers$set ? workers$value : DEFAULT_WORKERS,
                        replicas$set ? replicas$value : new ArrayList<Model>(),
                        rootDevice$set ? rootDevice$value : DEFAULT_ROOT_DEVICE,
                        loadBalanceMode$set ? loadBalanceMode$value : LoadBalanceMode.ROUND_ROBIN,
                        layersToOutputTo,
                        layerIndicesOutputTo,
                        targetDeviceId,
                        isCG$set && isCG$value,
                        isMLN$set && isMLN$value);
            }

            @Override
            public String toString() {
                return "InplaceParallelInference.ModelHolder.ModelHolderBuilder(sourceModel=" + this.sourceModel + ", workers$value=" + this.workers$value + ", replicas$value=" + this.replicas$value + ", rootDevice$value=" + this.rootDevice$value + ", loadBalanceMode$value=" + this.loadBalanceMode$value + ", layersToOutputTo=" + Arrays.deepToString(this.layersToOutputTo) + ", layerIndicesOutputTo=" + Arrays.toString(this.layerIndicesOutputTo) + ", targetDeviceId=" + this.targetDeviceId + ", isCG$value=" + this.isCG$value + ", isMLN$value=" + this.isMLN$value + ")";
            }
        }

        /**
         * (Re)builds {@link #workers} replicas of {@link #sourceModel}, discarding any previous ones.
         *
         * @throws ND4JIllegalStateException if {@link #workers} is less than 1
         * @throws UnsupportedOperationException if the source model is neither a ComputationGraph nor a
         *     MultiLayerNetwork (previously this silently produced an empty, unusable pool)
         */
        protected synchronized void init() {
            if (workers < 1) {
                throw new ND4JIllegalStateException("Workers must be positive value");
            }
            boolean graph = sourceModel instanceof ComputationGraph;
            boolean network = sourceModel instanceof MultiLayerNetwork;
            if (!graph && !network) {
                throw new UnsupportedOperationException("Unsupported model type: "
                        + (sourceModel == null ? "null" : sourceModel.getClass().getName()));
            }

            // Drop replicas of the previous source model. In FIFO mode the queue holds the same
            // objects, so it must be emptied too, otherwise stale replicas keep serving requests.
            replicas.clear();
            queue.clear();
            isCG = graph;
            isMLN = network;

            // On the root device the replicas share the source parameters in place; elsewhere they
            // get their own copy, moved to the target device.
            INDArray params = rootDevice ? sourceModel.params() : sourceModel.params().unsafeDuplication(true);
            if (!rootDevice) {
                // Use the returned array: the backend may hand back a new, device-local copy rather
                // than relocating the argument in place.
                params = Nd4j.getAffinityManager().replicateToDevice(targetDeviceId, params);
            }

            for (int i = 0; i < workers; i++) {
                Model replica;
                if (graph) {
                    ComputationGraph cg = new ComputationGraph(ComputationGraphConfiguration.fromJson(
                            ((ComputationGraph) sourceModel).getConfiguration().toJson()));
                    cg.init(params, false);
                    replica = cg;
                } else {
                    MultiLayerNetwork mln = new MultiLayerNetwork(MultiLayerConfiguration.fromJson(
                            ((MultiLayerNetwork) sourceModel).getLayerWiseConfigurations().toJson()));
                    mln.init(params, false);
                    replica = mln;
                }
                Nd4j.getExecutioner().commit();
                replicas.add(replica);
                if (loadBalanceMode == LoadBalanceMode.FIFO) {
                    queue.add(replica);
                }
            }
        }

        /**
         * Picks a replica for one inference call. In FIFO mode this blocks until a replica is idle;
         * in ROUND_ROBIN mode replicas are handed out cyclically and may be shared between threads.
         *
         * @throws InterruptedException if interrupted while waiting for an idle FIFO replica
         * @throws ND4JIllegalStateException for an unknown load-balance mode
         */
        protected Model acquireModel() throws InterruptedException {
            modelLock.readLock().lock();
            try {
                switch (loadBalanceMode) {
                    case FIFO:
                        return queue.take();
                    case ROUND_ROBIN: {
                        // floorMod keeps the index non-negative even if the counter ever wraps.
                        int index = (int) Math.floorMod(position.getAndIncrement(), (long) replicas.size());
                        return replicas.get(index);
                    }
                    default:
                        throw new ND4JIllegalStateException("Unknown LoadBalanceMode was specified: [" + loadBalanceMode + "]");
                }
            } finally {
                modelLock.readLock().unlock();
            }
        }

        /**
         * Returns a replica obtained from {@link #acquireModel()}; must be called exactly once per
         * acquisition. Only FIFO mode has anything to give back.
         */
        protected void releaseModel(Model model) {
            modelLock.readLock().lock();
            try {
                switch (loadBalanceMode) {
                    case FIFO:
                        queue.add(model);
                        break;
                    case ROUND_ROBIN:
                        break;
                    default:
                        throw new ND4JIllegalStateException("Unknown LoadBalanceMode was specified: [" + loadBalanceMode + "]");
                }
            } finally {
                modelLock.readLock().unlock();
            }
        }

        /**
         * Runs inference on one replica, releasing it exactly once even if inference fails.
         *
         * @param input input arrays; a MultiLayerNetwork accepts only one
         * @param inputMasks optional input masks, may be {@code null}
         * @return the network outputs, or the activations of the configured output layer(s)
         * @throws ND4JIllegalStateException if a MultiLayerNetwork receives more than one input or mask
         * @throws UnsupportedOperationException if the holder has no supported model
         * @throws RuntimeException wrapping an {@link InterruptedException} (interrupt flag restored)
         */
        protected INDArray[] output(INDArray[] input, INDArray[] inputMasks) {
            // Held across acquire/use/release so updateModel() cannot swap replicas mid-call.
            modelLock.readLock().lock();
            try {
                if (isCG) {
                    Model model = acquireModel();
                    try {
                        ComputationGraph graph = (ComputationGraph) model;
                        return layersToOutputTo != null
                                ? graph.output(Arrays.asList(layersToOutputTo), false, input, inputMasks)
                                : graph.output(false, input, inputMasks);
                    } finally {
                        releaseModel(model);
                    }
                }
                if (!isMLN) {
                    throw new UnsupportedOperationException();
                }
                if (input.length > 1 || (inputMasks != null && inputMasks.length > 1)) {
                    throw new ND4JIllegalStateException("MultilayerNetwork can't have multiple inputs");
                }
                Model model = acquireModel();
                try {
                    MultiLayerNetwork network = (MultiLayerNetwork) model;
                    INDArray result;
                    if (layerIndicesOutputTo != null) {
                        // Element 0 of this list is the network input itself; the requested
                        // layer's activations are the last element.
                        List<INDArray> activations = network.feedForwardToLayer(layerIndicesOutputTo[0], input[0], false);
                        result = activations.get(activations.size() - 1);
                    } else {
                        result = network.output(input[0], false, inputMasks == null ? null : inputMasks[0], (INDArray) null);
                    }
                    return new INDArray[] {result};
                } finally {
                    releaseModel(model);
                }
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                throw new RuntimeException(e);
            } finally {
                modelLock.readLock().unlock();
            }
        }

        /**
         * Replaces the source model and rebuilds all replicas under the write lock, so no inference
         * call observes a half-built pool.
         */
        protected void updateModel(@NonNull Model model) {
            if (model == null) {
                throw new NullPointerException("model is marked non-null but is null");
            }
            modelLock.writeLock().lock();
            try {
                this.sourceModel = model;
                init();
            } finally {
                modelLock.writeLock().unlock();
            }
        }

        @Generated
        public static ModelHolderBuilder builder() {
            return new ModelHolderBuilder();
        }

        /** Creates a holder with default settings and no source model. */
        public ModelHolder() {
            this(null, DEFAULT_WORKERS, new ArrayList<Model>(), DEFAULT_ROOT_DEVICE, LoadBalanceMode.ROUND_ROBIN,
                    null, null, 0, false, false);
        }

        public ModelHolder(Model sourceModel, int workers, List<Model> replicas, boolean rootDevice,
                           LoadBalanceMode loadBalanceMode, String[] layersToOutputTo, int[] layerIndicesOutputTo,
                           int targetDeviceId, boolean isCG, boolean isMLN) {
            this.sourceModel = sourceModel;
            this.workers = workers;
            this.replicas = replicas;
            this.rootDevice = rootDevice;
            this.loadBalanceMode = loadBalanceMode;
            this.layersToOutputTo = layersToOutputTo;
            this.layerIndicesOutputTo = layerIndicesOutputTo;
            this.targetDeviceId = targetDeviceId;
            this.isCG = isCG;
            this.isMLN = isMLN;
        }
    }

    /** Routes inference calls to the {@link ModelHolder} registered for the calling thread's device. */
    protected static class ModelSelector {
        /** Device id to holder; filled by {@link InplaceParallelInference#init()}. */
        protected Map<Integer, ModelHolder> map;
        /** Not consulted by this class; each holder applies its own mode. */
        protected final LoadBalanceMode loadBalanceMode;

        public ModelSelector() {
            this(LoadBalanceMode.ROUND_ROBIN);
        }

        public ModelSelector(LoadBalanceMode loadBalanceMode) {
            this.map = new HashMap<>();
            this.loadBalanceMode = loadBalanceMode;
        }

        protected void addModelHolder(@NonNull Integer device, @NonNull ModelHolder holder) {
            if (device == null) {
                throw new NullPointerException("device is marked non-null but is null");
            }
            if (holder == null) {
                throw new NullPointerException("holder is marked non-null but is null");
            }
            map.put(device, holder);
        }

        /** Returns the holder for the device the given thread is bound to, or {@code null} if none. */
        public ModelHolder getModelForThread(long threadId) {
            return map.get(Nd4j.getAffinityManager().getDeviceForThread(threadId));
        }

        public INDArray[] output(INDArray[] input, INDArray[] inputMasks) {
            return getModelForThisThread().output(input, inputMasks);
        }

        public ModelHolder getModelForThisThread() {
            return getModelForThread(Thread.currentThread().getId());
        }
    }

    /** Builds and registers one {@link ModelHolder} per device; only the current thread's device shares parameters in place. */
    @Override
    public void init() {
        int deviceCount = Nd4j.getAffinityManager().getNumberOfDevices();
        int rootDeviceId = Nd4j.getAffinityManager().getDeviceForCurrentThread();
        for (int device = 0; device < deviceCount; device++) {
            ModelHolder holder = ModelHolder.builder()
                    .sourceModel(this.model)
                    .workers(this.workers)
                    .layerIndicesOutputTo(this.layerIndicesOutputTo)
                    .layersToOutputTo(this.layersToOutputTo)
                    .loadBalanceMode(this.loadBalanceMode)
                    .targetDeviceId(device)
                    .rootDevice(device == rootDeviceId)
                    .build();
            holder.init();
            holders.add(holder);
            selector.addModelHolder(device, holder);
        }
    }

    /** Swaps the model in every holder; each holder rebuilds its replicas under its own write lock. */
    @Override
    public synchronized void updateModel(@NonNull Model model) {
        if (model == null) {
            throw new NullPointerException("model is marked non-null but is null");
        }
        for (ModelHolder holder : holders) {
            holder.updateModel(model);
        }
    }

    /** Returns the current source model of every holder, in holder order. */
    @Override
    protected synchronized Model[] getCurrentModelsFromWorkers() {
        // Snapshot first so the array length always matches the holders iterated.
        ModelHolder[] snapshot = holders.toArray(new ModelHolder[0]);
        Model[] models = new Model[snapshot.length];
        for (int i = 0; i < snapshot.length; i++) {
            models[i] = snapshot[i].sourceModel;
        }
        return models;
    }

    @Override
    public INDArray[] output(INDArray[] input, INDArray[] inputMasks) {
        return selector.output(input, inputMasks);
    }

    /**
     * Applies {@code modelAdapter} to a replica from the calling thread's device holder.
     *
     * @throws RuntimeException wrapping an {@link InterruptedException} (interrupt flag restored)
     */
    public <T> T output(@NonNull ModelAdapter<T> modelAdapter, INDArray[] inputs, INDArray[] inputMasks, INDArray[] labelsMasks) {
        if (modelAdapter == null) {
            throw new NullPointerException("adapter is marked non-null but is null");
        }
        ModelHolder holder = selector.getModelForThisThread();
        // Hold the read lock for the whole acquire/apply/release cycle so updateModel() cannot
        // rebuild the replica pool while this thread is still using one of its models.
        holder.modelLock.readLock().lock();
        try {
            Model model = holder.acquireModel();
            try {
                return (T) modelAdapter.apply(model, inputs, inputMasks, labelsMasks);
            } finally {
                holder.releaseModel(model);
            }
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new RuntimeException(e);
        } finally {
            holder.modelLock.readLock().unlock();
        }
    }
}
