package org.deeplearning4j.nn.modelimport.keras.layers;

import com.google.gson.Gson;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import lombok.Generated;
import org.apache.commons.lang3.ArrayUtils;
import org.deeplearning4j.config.DL4JClassLoading;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.layers.AbstractLayer;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.nd4j.TFGraphRunnerService;
import org.nd4j.common.primitives.Pair;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.shade.protobuf.TextFormat;
import org.nd4j.shade.protobuf.util.JsonFormat;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.tensorflow.framework.AttrValue;
import org.tensorflow.framework.GraphDef;
import org.tensorflow.framework.NodeDef;

/* loaded from: input_file:org/deeplearning4j/nn/modelimport/keras/layers/TFOpLayerImpl.class */
public class TFOpLayerImpl extends AbstractLayer<TFOpLayer> {

    @Generated
    private static final Logger log = LoggerFactory.getLogger(TFOpLayerImpl.class);
    private Map nodeDef;
    private Map constants;
    private List<String> inputNames;
    TFGraphRunnerService graphRunnerService;

    public TFOpLayerImpl(Map map, Map map2, NeuralNetConfiguration neuralNetConfiguration, DataType dataType) {
        super(neuralNetConfiguration, dataType);
        this.nodeDef = map;
        this.constants = map2;
        setGraphRunner();
    }

    public Pair<Gradient, INDArray> backpropGradient(INDArray iNDArray, LayerWorkspaceMgr layerWorkspaceMgr) {
        throw new RuntimeException("Backprop through TFOpLayerImpl is not supported yet. TFOpLayerImpl is created when importing TensorFlow 2.0 Keras models (tf.keras) into DL4J, that contains TensorFlow operations not just Keras layers.");
    }

    private void setGraphRunner() {
        try {
            String json = new Gson().toJson(this.nodeDef);
            NodeDef.Builder newBuilder = NodeDef.newBuilder();
            JsonFormat.parser().merge(json, newBuilder);
            NodeDef build = newBuilder.build();
            ArrayList arrayList = new ArrayList();
            HashMap hashMap = new HashMap();
            HashMap hashMap2 = new HashMap();
            this.inputNames = new ArrayList();
            List asList = Arrays.asList(build.getName());
            Map attrMap = build.getAttrMap();
            for (int i = 0; i < build.getInputCount(); i++) {
                String[] split = build.getInput(i).split("/");
                String str = split.length == 1 ? "T" : "T" + split[split.length - 1];
                arrayList.add(build.getInput(i));
                hashMap.put(build.getInput(i), ((AttrValue) attrMap.get(str)).getType().toString());
                if (this.constants.containsKey(String.valueOf(i))) {
                    hashMap2.put(build.getInput(i), Nd4j.create((List) this.constants.get(String.valueOf(i))));
                } else {
                    this.inputNames.add(build.getInput(i));
                }
            }
            String str2 = "node{\n" + build.toString() + "\n}\nversions {\n producer: 22\n}";
            for (int i2 = 0; i2 < arrayList.size(); i2++) {
                String str3 = (String) arrayList.get(i2);
                str2 = "node{\nname: \"" + str3 + "\"\nop: \"Placeholder\"\nattr{\nkey: \"dtype\"\n value {\n type: " + ((String) hashMap.get(str3)) + "}\n}\n}\n" + str2;
            }
            GraphDef.Builder newBuilder2 = GraphDef.newBuilder();
            TextFormat.getParser().merge(str2, newBuilder2);
            byte[] byteArray = newBuilder2.build().toByteString().toByteArray();
            Iterator it = DL4JClassLoading.loadService(TFGraphRunnerService.class).iterator();
            if (!it.hasNext()) {
                throw new RuntimeException("The model contains a Tensorflow Op, which requires the nd4j-tensorflow dependency to execute.");
            }
            this.graphRunnerService = ((TFGraphRunnerService) it.next()).init(arrayList, asList, byteArray, hashMap2, hashMap);
        } catch (Exception e) {
            throw new RuntimeException("Error parsing protobuf", e);
        }
    }

    private INDArray runGraph(INDArray iNDArray) {
        HashMap hashMap = new HashMap();
        hashMap.put(this.inputNames.get(0), iNDArray);
        return ((INDArray[]) this.graphRunnerService.run(hashMap).values().toArray(new INDArray[0]))[0];
    }

    public long[] getOutputShape(long[] jArr) {
        long[] clone = ArrayUtils.clone(jArr);
        for (int i = 0; i < clone.length; i++) {
            if (clone[i] < 0) {
                clone[i] = 1;
            }
        }
        return runGraph(Nd4j.zeros(clone)).shape();
    }

    public INDArray activate(boolean z, LayerWorkspaceMgr layerWorkspaceMgr) {
        return runGraph(this.input);
    }

    public boolean isPretrainLayer() {
        return false;
    }

    public void clearNoiseWeightParams() {
    }

    @Generated
    public Map getNodeDef() {
        return this.nodeDef;
    }

    @Generated
    public Map getConstants() {
        return this.constants;
    }

    @Generated
    public List<String> getInputNames() {
        return this.inputNames;
    }

    @Generated
    public TFGraphRunnerService getGraphRunnerService() {
        return this.graphRunnerService;
    }

    @Generated
    public void setNodeDef(Map map) {
        this.nodeDef = map;
    }

    @Generated
    public void setConstants(Map map) {
        this.constants = map;
    }

    @Generated
    public void setInputNames(List<String> list) {
        this.inputNames = list;
    }

    @Generated
    public void setGraphRunnerService(TFGraphRunnerService tFGraphRunnerService) {
        this.graphRunnerService = tFGraphRunnerService;
    }

    @Generated
    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        if (!(obj instanceof TFOpLayerImpl)) {
            return false;
        }
        TFOpLayerImpl tFOpLayerImpl = (TFOpLayerImpl) obj;
        if (!tFOpLayerImpl.canEqual(this)) {
            return false;
        }
        Map nodeDef = getNodeDef();
        Map nodeDef2 = tFOpLayerImpl.getNodeDef();
        if (nodeDef == null) {
            if (nodeDef2 != null) {
                return false;
            }
        } else if (!nodeDef.equals(nodeDef2)) {
            return false;
        }
        Map constants = getConstants();
        Map constants2 = tFOpLayerImpl.getConstants();
        if (constants == null) {
            if (constants2 != null) {
                return false;
            }
        } else if (!constants.equals(constants2)) {
            return false;
        }
        List<String> inputNames = getInputNames();
        List<String> inputNames2 = tFOpLayerImpl.getInputNames();
        if (inputNames == null) {
            if (inputNames2 != null) {
                return false;
            }
        } else if (!inputNames.equals(inputNames2)) {
            return false;
        }
        TFGraphRunnerService graphRunnerService = getGraphRunnerService();
        TFGraphRunnerService graphRunnerService2 = tFOpLayerImpl.getGraphRunnerService();
        return graphRunnerService == null ? graphRunnerService2 == null : graphRunnerService.equals(graphRunnerService2);
    }

    @Generated
    protected boolean canEqual(Object obj) {
        return obj instanceof TFOpLayerImpl;
    }

    @Generated
    public int hashCode() {
        Map nodeDef = getNodeDef();
        int hashCode = (1 * 59) + (nodeDef == null ? 43 : nodeDef.hashCode());
        Map constants = getConstants();
        int hashCode2 = (hashCode * 59) + (constants == null ? 43 : constants.hashCode());
        List<String> inputNames = getInputNames();
        int hashCode3 = (hashCode2 * 59) + (inputNames == null ? 43 : inputNames.hashCode());
        TFGraphRunnerService graphRunnerService = getGraphRunnerService();
        return (hashCode3 * 59) + (graphRunnerService == null ? 43 : graphRunnerService.hashCode());
    }

    @Generated
    public String toString() {
        return "TFOpLayerImpl(nodeDef=" + getNodeDef() + ", constants=" + getConstants() + ", inputNames=" + getInputNames() + ", graphRunnerService=" + getGraphRunnerService() + ")";
    }
}
