package org.deeplearning4j.nn.modelimport.keras.layers.convolutional;

import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import lombok.Generated;
import org.deeplearning4j.nn.api.layers.LayerConstraint;
import org.deeplearning4j.nn.conf.CNN2DFormat;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.DepthwiseConvolution2D;
import org.deeplearning4j.nn.modelimport.keras.KerasLayer;
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.utils.KerasActivationUtils;
import org.deeplearning4j.nn.modelimport.keras.utils.KerasConstraintUtils;
import org.deeplearning4j.nn.modelimport.keras.utils.KerasInitilizationUtils;
import org.deeplearning4j.nn.modelimport.keras.utils.KerasLayerUtils;
import org.deeplearning4j.nn.modelimport.keras.utils.KerasRegularizerUtils;
import org.deeplearning4j.nn.weights.IWeightInit;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/nn/modelimport/keras/layers/convolutional/KerasDepthwiseConvolution2D.class */
/**
 * Imports a Keras depth-wise 2D convolution layer as a DL4J {@link DepthwiseConvolution2D}.
 *
 * <p>The DL4J layer is built from the Keras layer configuration map. It covers:
 * <ul>
 *   <li>kernel size, strides, padding and dilation;</li>
 *   <li>depth multiplier and activation;</li>
 *   <li>the depth-wise initializer, regularizer and constraint;</li>
 *   <li>the bias constraint.</li>
 * </ul>
 * The data format is NHWC when the dim ordering is {@code TENSORFLOW}, and NCHW otherwise.
 */
public class KerasDepthwiseConvolution2D extends KerasConvolution {

    @Generated
    private static final Logger log = LoggerFactory.getLogger(KerasDepthwiseConvolution2D.class);

    /** Key under which the depth-wise kernel is stored in {@code this.weights}. */
    private static final String WEIGHT_KEY = "W";

    /** Key under which the bias is stored in {@code this.weights}. */
    private static final String BIAS_KEY = "b";

    /** Name of the bias parameter in Keras 2 weight maps. */
    private static final String KERAS_2_BIAS_NAME = "bias";

    /** Name of the bias parameter in Keras 1 weight maps. */
    private static final String KERAS_1_BIAS_NAME = "b";

    /**
     * Creates an unconfigured layer for the given Keras version.
     *
     * @param kerasVersion Keras version, forwarded unchanged to the superclass constructor
     * @throws UnsupportedKerasConfigurationException if the superclass rejects the version
     */
    public KerasDepthwiseConvolution2D(Integer kerasVersion) throws UnsupportedKerasConfigurationException {
        super(kerasVersion);
    }

    /**
     * Creates the layer from its Keras config.
     *
     * <p>This uses no previous layers and enforces the training configuration.
     *
     * @param layerConfig Keras layer configuration map
     */
    public KerasDepthwiseConvolution2D(Map<String, Object> layerConfig)
            throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        this(layerConfig, Collections.emptyMap(), true);
    }

    /**
     * Creates the layer from its Keras config and the already-imported previous layers.
     *
     * <p>The training configuration is enforced.
     *
     * @param layerConfig    Keras layer configuration map
     * @param previousLayers previously imported layers, used to derive {@code nIn}
     */
    public KerasDepthwiseConvolution2D(Map<String, Object> layerConfig,
                                       Map<String, ? extends KerasLayer> previousLayers)
            throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        this(layerConfig, previousLayers, true);
    }

    /**
     * Creates the layer from its Keras config and previous layers, with no explicit inbound layer names.
     *
     * @param layerConfig           Keras layer configuration map
     * @param previousLayers        previously imported layers, used to derive {@code nIn}
     * @param enforceTrainingConfig forwarded to the superclass and to the weight-initializer lookup
     */
    public KerasDepthwiseConvolution2D(Map<String, Object> layerConfig,
                                       Map<String, ? extends KerasLayer> previousLayers,
                                       boolean enforceTrainingConfig)
            throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        this(layerConfig, previousLayers, null, enforceTrainingConfig);
    }

    /**
     * Creates the layer and builds the underlying {@link DepthwiseConvolution2D}.
     *
     * <p>{@code nIn} is read from the previous layers. {@code nOut} is {@code nIn * depthMultiplier}.
     * When the config declares a bias, the bias is initialised to 0.
     *
     * @param layerConfig           Keras layer configuration map
     * @param previousLayers        previously imported layers, used to derive {@code nIn}
     * @param inboundNames          extra inbound layer names to record; may be {@code null}
     * @param enforceTrainingConfig forwarded to the superclass and to the weight-initializer lookup
     * @throws InvalidKerasConfigurationException     if the Keras config is malformed
     * @throws UnsupportedKerasConfigurationException if the config uses unsupported features
     */
    public KerasDepthwiseConvolution2D(Map<String, Object> layerConfig,
                                       Map<String, ? extends KerasLayer> previousLayers,
                                       List<String> inboundNames,
                                       boolean enforceTrainingConfig)
            throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        super(layerConfig, enforceTrainingConfig);
        if (inboundNames != null) {
            this.inboundLayerNames.addAll(inboundNames);
        }

        int majorVersion = this.kerasMajorVersion.intValue();

        this.hasBias = KerasLayerUtils.getHasBiasFromConfig(layerConfig, this.conf);
        // One trainable parameter for the kernel, plus one for the bias if present.
        this.numTrainableParams = this.hasBias ? 2 : 1;

        long[] dilationRate = KerasConvolutionUtils.getDilationRateLong(layerConfig, 2, this.conf, false);
        IWeightInit depthWiseInit = KerasInitilizationUtils.getWeightInitFromConfig(
                layerConfig, this.conf.getLAYER_FIELD_DEPTH_WISE_INIT(), enforceTrainingConfig,
                this.conf, majorVersion);
        long nIn = getNInFromConfig(previousLayers);
        int depthMultiplier = KerasConvolutionUtils.getDepthMultiplier(layerConfig, this.conf);

        // Regularization is read from the depth-wise regularizer field, not the generic kernel one.
        this.weightL1Regularization = KerasRegularizerUtils.getWeightRegularizerFromConfig(
                layerConfig, this.conf, this.conf.getLAYER_FIELD_DEPTH_WISE_REGULARIZER(),
                this.conf.getREGULARIZATION_TYPE_L1());
        this.weightL2Regularization = KerasRegularizerUtils.getWeightRegularizerFromConfig(
                layerConfig, this.conf, this.conf.getLAYER_FIELD_DEPTH_WISE_REGULARIZER(),
                this.conf.getREGULARIZATION_TYPE_L2());

        LayerConstraint biasConstraint = KerasConstraintUtils.getConstraintsFromConfig(
                layerConfig, this.conf.getLAYER_FIELD_B_CONSTRAINT(), this.conf, majorVersion);
        LayerConstraint depthWiseWeightConstraint = KerasConstraintUtils.getConstraintsFromConfig(
                layerConfig, this.conf.getLAYER_FIELD_DEPTH_WISE_CONSTRAINT(), this.conf, majorVersion);

        DepthwiseConvolution2D.Builder builder = new DepthwiseConvolution2D.Builder()
                .name(this.layerName)
                .dropOut(this.dropout)
                .nIn(nIn)
                .nOut(nIn * depthMultiplier)
                .activation(KerasActivationUtils.getIActivationFromConfig(layerConfig, this.conf))
                .weightInit(depthWiseInit)
                .depthMultiplier(depthMultiplier)
                .l1(this.weightL1Regularization)
                .l2(this.weightL2Regularization)
                .convolutionMode(KerasConvolutionUtils.getConvolutionModeFromConfig(layerConfig, this.conf))
                .kernelSize(KerasConvolutionUtils.getKernelSizeFromConfigLong(layerConfig, 2, this.conf, majorVersion))
                .hasBias(this.hasBias)
                .dataFormat(this.dimOrder == KerasLayer.DimOrder.TENSORFLOW ? CNN2DFormat.NHWC : CNN2DFormat.NCHW)
                .stride(KerasConvolutionUtils.getStrideFromConfigLong(layerConfig, 2, this.conf));

        long[] padding = KerasConvolutionUtils.getPaddingFromBorderModeConfigLong(layerConfig, 2, this.conf, majorVersion);

        if (this.hasBias) {
            builder.biasInit(0.0d);
        }
        // Optional settings are only applied when the config provides them.
        if (padding != null) {
            builder.padding(padding);
        }
        if (dilationRate != null) {
            builder.dilation(dilationRate);
        }
        if (biasConstraint != null) {
            builder.constrainBias(new LayerConstraint[]{biasConstraint});
        }
        if (depthWiseWeightConstraint != null) {
            builder.constrainWeights(new LayerConstraint[]{depthWiseWeightConstraint});
        }

        this.layer = builder.build();
        this.layer.setDefaultValueOverriden(true);
    }

    /**
     * Stores the Keras weights under the parameter keys used by this layer.
     *
     * <p>The depth-wise kernel is stored under {@code "W"}. When the layer has a bias, the bias is
     * stored under {@code "b"}. The bias is read from {@code "bias"} for Keras 2 and from
     * {@code "b"} for Keras 1. The previous weights map is discarded even if validation fails.
     *
     * @param kerasWeights weights read from the Keras model, keyed by Keras parameter name
     * @throws InvalidKerasConfigurationException if the depth-wise kernel is missing, or if a
     *                                            required bias is missing for the Keras version
     */
    @Override
    public void setWeights(Map<String, INDArray> kerasWeights) throws InvalidKerasConfigurationException {
        this.weights = new HashMap<>();
        String kernelName = this.conf.getLAYER_PARAM_NAME_DEPTH_WISE_KERNEL();
        if (!kerasWeights.containsKey(kernelName)) {
            throw new InvalidKerasConfigurationException(
                    "Keras DepthwiseConvolution2D layer does not contain parameter " + kernelName);
        }
        this.weights.put(WEIGHT_KEY, kerasWeights.get(kernelName));
        if (this.hasBias) {
            this.weights.put(BIAS_KEY, findBias(kerasWeights));
        }
    }

    /**
     * Looks up the bias array using the parameter name for the current Keras major version.
     *
     * @param kerasWeights weights read from the Keras model
     * @return the bias array
     * @throws InvalidKerasConfigurationException if no bias exists under the expected name
     */
    private INDArray findBias(Map<String, INDArray> kerasWeights) throws InvalidKerasConfigurationException {
        int majorVersion = this.kerasMajorVersion.intValue();
        if (majorVersion == 2 && kerasWeights.containsKey(KERAS_2_BIAS_NAME)) {
            return kerasWeights.get(KERAS_2_BIAS_NAME);
        }
        if (majorVersion == 1 && kerasWeights.containsKey(KERAS_1_BIAS_NAME)) {
            return kerasWeights.get(KERAS_1_BIAS_NAME);
        }
        throw new InvalidKerasConfigurationException("Keras DepthwiseConvolution2D layer does not contain bias parameter");
    }

    /**
     * Returns the underlying DL4J layer built by the constructor.
     *
     * @return the configured {@link DepthwiseConvolution2D}
     */
    public DepthwiseConvolution2D getDepthwiseConvolution2DLayer() {
        // The inherited field is typed more generally; the constructor always stores a DepthwiseConvolution2D.
        return (DepthwiseConvolution2D) this.layer;
    }

    /**
     * Computes the output type of this layer for a single input type.
     *
     * @param inputTypeArr the input types; exactly one is required
     * @return the output type reported by the underlying DL4J layer
     * @throws InvalidKerasConfigurationException if zero inputs, more than one input, or a
     *                                            {@code null} array is supplied
     */
    @Override
    public InputType getOutputType(InputType... inputTypeArr) throws InvalidKerasConfigurationException {
        // Reject empty/null input as well as multiple inputs; indexing [0] would otherwise throw
        // an unchecked ArrayIndexOutOfBoundsException / NullPointerException.
        if (inputTypeArr == null || inputTypeArr.length != 1) {
            int received = inputTypeArr == null ? 0 : inputTypeArr.length;
            throw new InvalidKerasConfigurationException(
                    "Keras depth-wise convolution 2D layer accepts only one input (received " + received + ")");
        }
        return getDepthwiseConvolution2DLayer().getOutputType(-1, inputTypeArr[0]);
    }

    @Override
    @Generated
    public String toString() {
        return "KerasDepthwiseConvolution2D()";
    }

    /** Instances are equal when the other object is (and accepts) a {@code KerasDepthwiseConvolution2D}. */
    @Override
    @Generated
    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        return (obj instanceof KerasDepthwiseConvolution2D) && ((KerasDepthwiseConvolution2D) obj).canEqual(this);
    }

    @Override
    @Generated
    protected boolean canEqual(Object obj) {
        return obj instanceof KerasDepthwiseConvolution2D;
    }

    /** Constant hash, consistent with {@link #equals(Object)}, which compares no fields. */
    @Override
    @Generated
    public int hashCode() {
        return 1;
    }
}
