package com.kotlinnlp.simplednn.core.layers.models.attention.attentionmechanism;

import com.kotlinnlp.simplednn.core.arrays.AugmentedArray;
import com.kotlinnlp.simplednn.core.functionalities.activations.ActivationFunction;
import com.kotlinnlp.simplednn.core.functionalities.activations.SoftmaxBase;
import com.kotlinnlp.simplednn.core.layers.Layer;
import com.kotlinnlp.simplednn.core.layers.LayerParameters;
import com.kotlinnlp.simplednn.core.layers.LayerType;
import com.kotlinnlp.simplednn.core.layers.helpers.BackwardHelper;
import com.kotlinnlp.simplednn.core.layers.helpers.ForwardHelper;
import com.kotlinnlp.simplednn.core.layers.helpers.RelevanceHelper;
import com.kotlinnlp.simplednn.simplemath.ndarray.dense.DenseNDArray;
import com.kotlinnlp.simplednn.simplemath.ndarray.dense.DenseNDArrayFactory;
import com.kotlinnlp.utils.ItemsPool;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Iterator;
import java.util.List;
import kotlin.Metadata;
import kotlin.TypeCastException;
import kotlin.collections.CollectionsKt;
import kotlin.jvm.internal.DefaultConstructorMarker;
import kotlin.jvm.internal.Intrinsics;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;

/* compiled from: AttentionMechanismLayer.kt */
@Metadata(mv = {1, 1, 13}, bv = {1, 0, 3}, k = 1, d1 = {"��T\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n��\n\u0002\u0010 \n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0010\u0006\n��\n\u0002\u0010\b\n\u0002\b\u0005\n\u0002\u0018\u0002\n\u0002\b\u0003\n\u0002\u0018\u0002\n\u0002\b\u0007\n\u0002\u0018\u0002\n\u0002\b\u0003\u0018��2\u00020\u00012\b\u0012\u0004\u0012\u00020\u00030\u0002BM\u0012\u0012\u0010\u0004\u001a\u000e\u0012\n\u0012\b\u0012\u0004\u0012\u00020\u00030\u00060\u0005\u0012\u0006\u0010\u0007\u001a\u00020\b\u0012\n\u0010\t\u001a\u0006\u0012\u0002\b\u00030\n\u0012\n\b\u0002\u0010\u000b\u001a\u0004\u0018\u00010\f\u0012\b\b\u0002\u0010\r\u001a\u00020\u000e\u0012\b\b\u0002\u0010\u000f\u001a\u00020\u0010¢\u0006\u0002\u0010\u0011R\u001a\u0010\u0012\u001a\b\u0012\u0004\u0012\u00020\u00030\u0006X\u0080\u0004¢\u0006\b\n��\u001a\u0004\b\u0013\u0010\u0014R\u0014\u0010\u0015\u001a\u00020\u0016X\u0094\u0004¢\u0006\b\n��\u001a\u0004\b\u0017\u0010\u0018R\u0014\u0010\u0019\u001a\u00020\u001aX\u0094\u0004¢\u0006\b\n��\u001a\u0004\b\u001b\u0010\u001cR\u0014\u0010\u000f\u001a\u00020\u0010X\u0096\u0004¢\u0006\b\n��\u001a\u0004\b\u001d\u0010\u001eR\u001d\u0010\u0004\u001a\u000e\u0012\n\u0012\b\u0012\u0004\u0012\u00020\u00030\u00060\u0005¢\u0006\b\n��\u001a\u0004\b\u001f\u0010 R\u0016\u0010!\u001a\u0004\u0018\u00010\"X\u0094\u0004¢\u0006\b\n��\u001a\u0004\b#\u0010$¨\u0006%"}, d2 = {"Lcom/kotlinnlp/simplednn/core/layers/models/attention/attentionmechanism/AttentionMechanismLayer;", "Lcom/kotlinnlp/utils/ItemsPool$IDItem;", "Lcom/kotlinnlp/simplednn/core/layers/Layer;", "Lcom/kotlinnlp/simplednn/simplemath/ndarray/dense/DenseNDArray;", "inputArrays", "", "Lcom/kotlinnlp/simplednn/core/arrays/AugmentedArray;", "inputType", "Lcom/kotlinnlp/simplednn/core/layers/LayerType$Input;", "params", "Lcom/kotlinnlp/simplednn/core/layers/LayerParameters;", "activation", 
"Lcom/kotlinnlp/simplednn/core/functionalities/activations/ActivationFunction;", "dropout", "", "id", "", "(Ljava/util/List;Lcom/kotlinnlp/simplednn/core/layers/LayerType$Input;Lcom/kotlinnlp/simplednn/core/layers/LayerParameters;Lcom/kotlinnlp/simplednn/core/functionalities/activations/ActivationFunction;DI)V", "attentionMatrix", "getAttentionMatrix$simplednn", "()Lcom/kotlinnlp/simplednn/core/arrays/AugmentedArray;", "backwardHelper", "Lcom/kotlinnlp/simplednn/core/layers/models/attention/attentionmechanism/AttentionMechanismBackwardHelper;", "getBackwardHelper", "()Lcom/kotlinnlp/simplednn/core/layers/models/attention/attentionmechanism/AttentionMechanismBackwardHelper;", "forwardHelper", "Lcom/kotlinnlp/simplednn/core/layers/models/attention/attentionmechanism/AttentionMechanismForwardHelper;", "getForwardHelper", "()Lcom/kotlinnlp/simplednn/core/layers/models/attention/attentionmechanism/AttentionMechanismForwardHelper;", "getId", "()I", "getInputArrays", "()Ljava/util/List;", "relevanceHelper", "Lcom/kotlinnlp/simplednn/core/layers/helpers/RelevanceHelper;", "getRelevanceHelper", "()Lcom/kotlinnlp/simplednn/core/layers/helpers/RelevanceHelper;", "simplednn"})
/* loaded from: input_file:com/kotlinnlp/simplednn/core/layers/models/attention/attentionmechanism/AttentionMechanismLayer.class */
/**
 * Attention mechanism layer (decompiled from {@code AttentionMechanismLayer.kt}).
 *
 * <p>Takes a non-empty sequence of input arrays, all of the size declared by its
 * {@link AttentionMechanismLayerParameters}, and stacks their values into an attention matrix.
 * The forward and backward passes are delegated to {@link AttentionMechanismForwardHelper} and
 * {@link AttentionMechanismBackwardHelper}. When the activation argument is omitted (Kotlin
 * default, see the synthetic constructor), {@link SoftmaxBase} is used.
 */
public final class AttentionMechanismLayer extends Layer<DenseNDArray> implements ItemsPool.IDItem {

    /** Matrix whose rows are the values of {@link #inputArrays}, in sequence order. */
    @NotNull
    private final AugmentedArray<DenseNDArray> attentionMatrix;

    /** Helper that performs the forward pass of this layer. */
    @NotNull
    private final AttentionMechanismForwardHelper forwardHelper;

    /** Helper that performs the backward pass of this layer. */
    @NotNull
    private final AttentionMechanismBackwardHelper backwardHelper;

    /** Always {@code null}: this layer does not provide relevance propagation. */
    @Nullable
    private final RelevanceHelper relevanceHelper;

    /** The attention sequence (validated non-empty and of the expected size). */
    @NotNull
    private final List<AugmentedArray<DenseNDArray>> inputArrays;

    /** Identifier of this layer as an {@link ItemsPool.IDItem}. */
    private final int id;

    /**
     * Returns the attention matrix built from the input arrays.
     *
     * <p>The {@code $simplednn} suffix is Kotlin's name mangling for an {@code internal} member.
     *
     * @return the attention matrix
     */
    @NotNull
    public final AugmentedArray<DenseNDArray> getAttentionMatrix$simplednn() {
        return this.attentionMatrix;
    }

    /*
     * NOTE(review): the decompiler renamed this from getForwardHelper() while merging the
     * covariant bridge method; in Kotlin the getter returns AttentionMechanismForwardHelper and
     * is protected.
     */
    @Override // com.kotlinnlp.simplednn.core.layers.Layer
    @NotNull
    public ForwardHelper<DenseNDArray> getForwardHelper2() {
        return this.forwardHelper;
    }

    /*
     * NOTE(review): the decompiler renamed this from getBackwardHelper() while merging the
     * covariant bridge method; in Kotlin the getter returns AttentionMechanismBackwardHelper and
     * is protected.
     */
    @Override // com.kotlinnlp.simplednn.core.layers.Layer
    @NotNull
    public BackwardHelper<DenseNDArray> getBackwardHelper2() {
        return this.backwardHelper;
    }

    /**
     * Returns the relevance helper of this layer.
     *
     * @return always {@code null}, since the field is never assigned a helper
     */
    @Override // com.kotlinnlp.simplednn.core.layers.Layer
    @Nullable
    protected RelevanceHelper getRelevanceHelper() {
        return this.relevanceHelper;
    }

    /**
     * Returns the attention sequence this layer was built on.
     *
     * @return the input arrays
     */
    @NotNull
    public final List<AugmentedArray<DenseNDArray>> getInputArrays() {
        return this.inputArrays;
    }

    /**
     * Returns the identifier of this layer.
     *
     * @return the id passed at construction
     */
    @Override // com.kotlinnlp.simplednn.core.layers.Layer
    public int getId() {
        return this.id;
    }

    /**
     * Creates the layer, validating the attention sequence and building the attention matrix.
     *
     * @param list the input arrays forming the attention sequence; must be non-null, non-empty,
     *     and each of the size returned by {@link AttentionMechanismLayerParameters#getInputSize()}
     * @param input the input type of the layer
     * @param layerParameters the layer parameters; must be an {@link AttentionMechanismLayerParameters}
     * @param activationFunction the activation set on the output array, or {@code null} for none
     * @param d the dropout
     * @param i the id of this layer
     * @throws IllegalArgumentException if {@code list} is empty or an input array has an
     *     unexpected size
     * @throws TypeCastException if the layer parameters returned by {@link #getParams()} are null
     */
    /* JADX WARN: Multi-variable type inference failed */
    public AttentionMechanismLayer(@NotNull List<? extends AugmentedArray<DenseNDArray>> list, @NotNull LayerType.Input input, @NotNull LayerParameters<?> layerParameters, @Nullable ActivationFunction activationFunction, double d, int i) {
        // requireNonEmpty is evaluated before list.get(0), so an empty (or null) sequence is
        // rejected with the intended message instead of an IndexOutOfBoundsException.
        // The trailing (0, 64, null) target Kotlin's synthetic default-arguments constructor of
        // Layer: mask 64 = 1 << 6 requests the default of the seventh parameter, for which the
        // 0 is only a placeholder (same bit convention as the synthetic constructor below).
        super((AugmentedArray) requireNonEmpty(list).get(0), input, new AugmentedArray(list.size()), layerParameters, activationFunction, d, 0, 64, null);
        Intrinsics.checkParameterIsNotNull(input, "inputType");
        Intrinsics.checkParameterIsNotNull(layerParameters, "params");
        this.inputArrays = list;
        this.id = i;

        // Validate sizes before stacking the arrays, so a mismatch yields this descriptive
        // message rather than whatever DenseNDArrayFactory.arrayOf does with ragged rows.
        int expectedSize = expectedInputSize();
        for (AugmentedArray<DenseNDArray> inputArray : this.inputArrays) {
            if (((DenseNDArray) inputArray.getValues()).getLength() != expectedSize) {
                throw new IllegalArgumentException(
                    String.format("The input arrays must have the expected size (%d).", expectedSize));
            }
        }

        // One matrix row per input array, in sequence order.
        List<double[]> rows = new ArrayList<>(this.inputArrays.size());
        for (AugmentedArray<DenseNDArray> inputArray : this.inputArrays) {
            rows.add(((DenseNDArray) inputArray.getValues()).toDoubleArray());
        }
        this.attentionMatrix = AugmentedArray.Companion.invoke(DenseNDArrayFactory.INSTANCE.arrayOf(rows));
        this.forwardHelper = new AttentionMechanismForwardHelper(this);
        this.backwardHelper = new AttentionMechanismBackwardHelper(this);

        if (getActivationFunction() != null) {
            getOutputArray().setActivation(getActivationFunction());
        }
    }

    /**
     * Synthetic constructor emitted by the Kotlin compiler for default arguments.
     *
     * <p>Each set bit of {@code i2} selects the default of the corresponding parameter:
     * bit 3 the activation ({@link SoftmaxBase}), bit 4 the dropout ({@code 0.0}) and bit 5 the
     * id ({@code 0}).
     */
    public /* synthetic */ AttentionMechanismLayer(List list, LayerType.Input input, LayerParameters layerParameters, ActivationFunction activationFunction, double d, int i, int i2, DefaultConstructorMarker defaultConstructorMarker) {
        this(list, input, layerParameters, (i2 & 8) != 0 ? new SoftmaxBase() : activationFunction, (i2 & 16) != 0 ? 0.0d : d, (i2 & 32) != 0 ? 0 : i);
    }

    /**
     * Checks that the attention sequence is neither null nor empty and returns it unchanged.
     *
     * <p>Static so that it can be evaluated inside the {@code super(...)} call, before the first
     * element of the sequence is read.
     *
     * @param inputArrays the attention sequence
     * @param <E> the element type of the sequence
     * @return {@code inputArrays}
     * @throws IllegalArgumentException if {@code inputArrays} is empty
     */
    private static <E> List<E> requireNonEmpty(List<E> inputArrays) {
        Intrinsics.checkParameterIsNotNull(inputArrays, "inputArrays");
        if (inputArrays.isEmpty()) {
            throw new IllegalArgumentException("The attention sequence cannot be empty.");
        }
        return inputArrays;
    }

    /**
     * Returns the input size declared by this layer's parameters.
     *
     * @return {@link AttentionMechanismLayerParameters#getInputSize()} of {@link #getParams()}
     * @throws TypeCastException if {@link #getParams()} returns {@code null}
     */
    private int expectedInputSize() {
        LayerParameters<?> params = getParams();
        if (params == null) {
            throw new TypeCastException("null cannot be cast to non-null type com.kotlinnlp.simplednn.core.layers.models.attention.attentionmechanism.AttentionMechanismLayerParameters");
        }
        return ((AttentionMechanismLayerParameters) params).getInputSize();
    }
}
