package org.deeplearning4j.nn.modelimport.keras.layers.attention;

import java.util.List;
import java.util.Map;
import java.util.Objects;
import lombok.Generated;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.graph.DotProductAttentionVertex;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.modelimport.keras.KerasLayer;
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.utils.KerasLayerUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/nn/modelimport/keras/layers/attention/KerasAttentionLayer.class */
/**
 * Imports a Keras {@code Attention} layer configuration and maps it onto a
 * {@link DotProductAttentionVertex}.
 *
 * <p>Only the {@code "dot"} score mode is accepted; any other value is rejected when the layer is
 * built from a configuration map.
 */
public class KerasAttentionLayer extends KerasLayer {

    @Generated
    private static final Logger log = LoggerFactory.getLogger(KerasAttentionLayer.class);

    /** Value of the {@code use_scale} entry read from the Keras layer configuration. */
    private boolean useScale;

    /** Value of the {@code dropout} entry read from the Keras layer configuration. */
    private double dropOut;

    /** Value of the {@code score_mode} entry read from the Keras layer configuration. */
    private String scoreMode;

    /** Names of the inbound layers feeding this layer. */
    private List<String> inputNames;

    // Keys and accepted values of the Keras layer configuration. They are initialised exactly once
    // here; constructors must not assign them again because the fields are final.
    private final String LAYER_DROP_OUT = "dropout";
    private final String LAYER_SCORE_MODE = "score_mode";
    private final String LAYER_SCORE_MODE_DOT = "dot";
    private final String LAYER_SCORE_MODE_CONCAT = "concat";
    private final String LAYER_USE_SCALE = "use_scale";

    /**
     * Creates an empty layer for the given Keras version.
     *
     * @param num Keras major version, forwarded to the superclass constructor
     * @throws UnsupportedKerasConfigurationException propagated from the superclass constructor
     */
    public KerasAttentionLayer(Integer num) throws UnsupportedKerasConfigurationException {
        super(num);
    }

    /**
     * Creates an empty layer using the superclass defaults.
     *
     * @throws UnsupportedKerasConfigurationException propagated from the superclass constructor
     */
    public KerasAttentionLayer() throws UnsupportedKerasConfigurationException {
    }

    /**
     * Builds the layer from a Keras layer configuration without enforcing training configuration.
     *
     * @param map Keras layer configuration
     * @throws InvalidKerasConfigurationException if the configuration is invalid
     * @throws UnsupportedKerasConfigurationException if the configuration is not supported
     */
    public KerasAttentionLayer(Map<String, Object> map) throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        this(map, false);
    }

    /**
     * Builds the layer from a Keras layer configuration.
     *
     * @param map Keras layer configuration
     * @param z   flag forwarded unchanged to the superclass constructor
     * @throws InvalidKerasConfigurationException if the score mode is anything other than
     *         {@code "dot"}, or if the superclass or helper utilities reject the configuration
     * @throws UnsupportedKerasConfigurationException if the configuration is not supported
     */
    public KerasAttentionLayer(Map<String, Object> map, boolean z) throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        super(map, z);
        Map<String, Object> innerConfig = KerasLayerUtils.getInnerLayerConfigFromConfig(map, this.conf);
        this.useScale = Boolean.parseBoolean(innerConfig.getOrDefault(LAYER_USE_SCALE, "false").toString());
        this.dropOut = Double.parseDouble(innerConfig.getOrDefault(LAYER_DROP_OUT, "0.0").toString());
        this.inputNames = KerasLayerUtils.getInboundLayerNamesFromConfig(map, this.conf);

        String configuredScoreMode = innerConfig.getOrDefault(LAYER_SCORE_MODE, LAYER_SCORE_MODE_DOT).toString();
        if (!configuredScoreMode.equals(LAYER_SCORE_MODE_DOT)) {
            throw new InvalidKerasConfigurationException("Invalid score mode " + configuredScoreMode);
        }
        // Remember the validated mode so getScoreMode(), equals(), hashCode() and toString()
        // reflect the configuration instead of always seeing null.
        this.scoreMode = configuredScoreMode;

        // NOTE(review): the builder is given the inherited `dropout` field, not the `dropOut`
        // value parsed above. Which of the two dropoutProbability() expects cannot be told from
        // this file, so the existing behaviour is preserved - confirm against KerasLayer and
        // DotProductAttentionVertex.
        this.vertex = new DotProductAttentionVertex.Builder()
                .dropoutProbability(this.dropout)
                .scale(this.useScale ? 0.2d : 1.0d)
                .inputNames(this.inputNames)
                .build();
    }

    /**
     * Sizes the attention vertex from the first input type and returns the vertex's output type.
     *
     * <p>Feed-forward and recurrent inputs use their size, convolutional inputs use their channel
     * count. If an input preprocessor exists for the first input, the vertex output type is
     * computed from the preprocessor's output type instead of the raw input type.
     *
     * @param inputTypeArr input types; only the first element is used
     * @return output type reported by the attention vertex
     * @throws InvalidKerasConfigurationException if no input type is supplied, or if the first
     *         input type is {@code CNN3D} or {@code CNNFlat}
     */
    @Override // org.deeplearning4j.nn.modelimport.keras.KerasLayer
    public InputType getOutputType(InputType... inputTypeArr) throws InvalidKerasConfigurationException {
        if (inputTypeArr == null || inputTypeArr.length == 0) {
            throw new InvalidKerasConfigurationException("Attention layer requires at least one input type");
        }
        InputType firstInput = inputTypeArr[0];
        InputPreProcessor preprocessor = getInputPreprocessor(firstInput);
        DotProductAttentionVertex attention = getAttentionVertex();

        switch (firstInput.getType()) {
            case FF:
                InputType.InputTypeFeedForward feedForward = (InputType.InputTypeFeedForward) firstInput;
                attention.setNIn(feedForward.getSize());
                attention.setNOut(feedForward.getSize());
                break;
            case CNN:
                InputType.InputTypeConvolutional convolutional = (InputType.InputTypeConvolutional) firstInput;
                attention.setNIn(convolutional.getChannels());
                attention.setNOut(convolutional.getChannels());
                break;
            case RNN:
                InputType.InputTypeRecurrent recurrent = (InputType.InputTypeRecurrent) firstInput;
                attention.setNIn(recurrent.getSize());
                attention.setNOut(recurrent.getSize());
                break;
            case CNN3D:
            case CNNFlat:
                throw new InvalidKerasConfigurationException("Unsupported input type for attention layer: " + firstInput.getType());
        }

        InputType vertexInput = preprocessor != null ? preprocessor.getOutputType(firstInput) : firstInput;
        return attention.getOutputType(-1, new InputType[]{vertexInput});
    }

    /** Returns the vertex held by this layer, viewed as a {@link DotProductAttentionVertex}. */
    private DotProductAttentionVertex getAttentionVertex() {
        return (DotProductAttentionVertex) this.vertex;
    }

    /** @return the {@code use_scale} flag */
    @Generated
    public boolean isUseScale() {
        return this.useScale;
    }

    /** @return the {@code dropout} value parsed from the layer configuration */
    @Generated
    public double getDropOut() {
        return this.dropOut;
    }

    /** @return the score mode, or {@code null} if none has been set */
    @Generated
    public String getScoreMode() {
        return this.scoreMode;
    }

    /** @return the inbound layer names */
    @Generated
    public List<String> getInputNames() {
        return this.inputNames;
    }

    /** @return the configuration key for the dropout value */
    @Generated
    public String getLAYER_DROP_OUT() {
        return LAYER_DROP_OUT;
    }

    /** @return the configuration key for the score mode */
    @Generated
    public String getLAYER_SCORE_MODE() {
        return LAYER_SCORE_MODE;
    }

    /** @return the configuration value naming the dot-product score mode */
    @Generated
    public String getLAYER_SCORE_MODE_DOT() {
        return LAYER_SCORE_MODE_DOT;
    }

    /** @return the configuration value naming the concat score mode */
    @Generated
    public String getLAYER_SCORE_MODE_CONCAT() {
        return LAYER_SCORE_MODE_CONCAT;
    }

    /** @return the configuration key for the scale flag */
    @Generated
    public String getLAYER_USE_SCALE() {
        return LAYER_USE_SCALE;
    }

    /** @param z new value of the {@code use_scale} flag */
    @Generated
    public void setUseScale(boolean z) {
        this.useScale = z;
    }

    /** @param d new dropout value */
    @Generated
    public void setDropOut(double d) {
        this.dropOut = d;
    }

    /** @param str new score mode */
    @Generated
    public void setScoreMode(String str) {
        this.scoreMode = str;
    }

    /** @param list new inbound layer names */
    @Generated
    public void setInputNames(List<String> list) {
        this.inputNames = list;
    }

    /**
     * Renders every property of this layer, each label paired with its own value.
     *
     * @return a string of the form {@code KerasAttentionLayer(name=value, ...)}
     */
    @Override
    @Generated
    public String toString() {
        return "KerasAttentionLayer(useScale=" + isUseScale()
                + ", dropOut=" + getDropOut()
                + ", scoreMode=" + getScoreMode()
                + ", inputNames=" + getInputNames()
                + ", LAYER_DROP_OUT=" + getLAYER_DROP_OUT()
                + ", LAYER_SCORE_MODE=" + getLAYER_SCORE_MODE()
                + ", LAYER_SCORE_MODE_DOT=" + getLAYER_SCORE_MODE_DOT()
                + ", LAYER_SCORE_MODE_CONCAT=" + getLAYER_SCORE_MODE_CONCAT()
                + ", LAYER_USE_SCALE=" + getLAYER_USE_SCALE() + ")";
    }

    /**
     * Compares the properties declared by this class; superclass state is not consulted.
     *
     * @param obj object to compare with
     * @return {@code true} if {@code obj} is a compatible layer whose properties all match
     */
    @Override
    @Generated
    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        if (!(obj instanceof KerasAttentionLayer)) {
            return false;
        }
        KerasAttentionLayer other = (KerasAttentionLayer) obj;
        if (!other.canEqual(this)) {
            return false;
        }
        return isUseScale() == other.isUseScale()
                && Double.compare(getDropOut(), other.getDropOut()) == 0
                && Objects.equals(getScoreMode(), other.getScoreMode())
                && Objects.equals(getInputNames(), other.getInputNames())
                && Objects.equals(getLAYER_DROP_OUT(), other.getLAYER_DROP_OUT())
                && Objects.equals(getLAYER_SCORE_MODE(), other.getLAYER_SCORE_MODE())
                && Objects.equals(getLAYER_SCORE_MODE_DOT(), other.getLAYER_SCORE_MODE_DOT())
                && Objects.equals(getLAYER_SCORE_MODE_CONCAT(), other.getLAYER_SCORE_MODE_CONCAT())
                && Objects.equals(getLAYER_USE_SCALE(), other.getLAYER_USE_SCALE());
    }

    /**
     * Lets subclasses opt out of equality with instances of this class.
     *
     * @param obj candidate for comparison
     * @return {@code true} if {@code obj} is a {@code KerasAttentionLayer}
     */
    @Generated
    protected boolean canEqual(Object obj) {
        return obj instanceof KerasAttentionLayer;
    }

    /**
     * Hashes the same properties that {@link #equals(Object)} compares.
     *
     * @return hash code derived from the properties declared by this class
     */
    @Override
    @Generated
    public int hashCode() {
        int result = 59 + (isUseScale() ? 79 : 97);
        result = result * 59 + Double.hashCode(getDropOut());
        result = result * 59 + hashOrDefault(getScoreMode());
        result = result * 59 + hashOrDefault(getInputNames());
        result = result * 59 + hashOrDefault(getLAYER_DROP_OUT());
        result = result * 59 + hashOrDefault(getLAYER_SCORE_MODE());
        result = result * 59 + hashOrDefault(getLAYER_SCORE_MODE_DOT());
        result = result * 59 + hashOrDefault(getLAYER_SCORE_MODE_CONCAT());
        result = result * 59 + hashOrDefault(getLAYER_USE_SCALE());
        return result;
    }

    /** Returns the value's hash code, or 43 when the value is {@code null}. */
    private static int hashOrDefault(Object value) {
        return value == null ? 43 : value.hashCode();
    }
}
