package com.kotlinnlp.simplednn.core.layers.models.feedforward.highway;

import com.kotlinnlp.simplednn.core.arrays.AugmentedArray;
import com.kotlinnlp.simplednn.core.arrays.AugmentedArrayExtensionsKt;
import com.kotlinnlp.simplednn.core.layers.LayerParameters;
import com.kotlinnlp.simplednn.core.layers.helpers.BackwardHelper;
import com.kotlinnlp.simplednn.simplemath.ndarray.NDArray;
import com.kotlinnlp.simplednn.simplemath.ndarray.dense.DenseNDArray;
import kotlin.Metadata;
import kotlin.TypeCastException;
import kotlin.jvm.internal.Intrinsics;
import org.jetbrains.annotations.NotNull;

/* compiled from: HighwayBackwardHelper.kt */
@Metadata(mv = {1, 1, 13}, bv = {1, 0, 3}, k = 1, d1 = {"��&\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n\u0002\b\u0004\n\u0002\u0010\u0002\n\u0002\b\u0004\n\u0002\u0010\u000b\n��\u0018��*\u000e\b��\u0010\u0001*\b\u0012\u0004\u0012\u0002H\u00010\u00022\b\u0012\u0004\u0012\u0002H\u00010\u0003B\u0013\u0012\f\u0010\u0004\u001a\b\u0012\u0004\u0012\u00028��0\u0005¢\u0006\u0002\u0010\u0006J\b\u0010\t\u001a\u00020\nH\u0002J\b\u0010\u000b\u001a\u00020\nH\u0002J\b\u0010\f\u001a\u00020\nH\u0002J\u0010\u0010\r\u001a\u00020\n2\u0006\u0010\u000e\u001a\u00020\u000fH\u0014R\u001a\u0010\u0004\u001a\b\u0012\u0004\u0012\u00028��0\u0005X\u0094\u0004¢\u0006\b\n��\u001a\u0004\b\u0007\u0010\b¨\u0006\u0010"}, d2 = {"Lcom/kotlinnlp/simplednn/core/layers/models/feedforward/highway/HighwayBackwardHelper;", "InputNDArrayType", "Lcom/kotlinnlp/simplednn/simplemath/ndarray/NDArray;", "Lcom/kotlinnlp/simplednn/core/layers/helpers/BackwardHelper;", "layer", "Lcom/kotlinnlp/simplednn/core/layers/models/feedforward/highway/HighwayLayer;", "(Lcom/kotlinnlp/simplednn/core/layers/models/feedforward/highway/HighwayLayer;)V", "getLayer", "()Lcom/kotlinnlp/simplednn/core/layers/models/feedforward/highway/HighwayLayer;", "assignGatesGradients", "", "assignLayerGradients", "assignParamsGradients", "execBackward", "propagateToInput", "", "simplednn"})
/* loaded from: input_file:com/kotlinnlp/simplednn/core/layers/models/feedforward/highway/HighwayBackwardHelper.class */
public final class HighwayBackwardHelper<InputNDArrayType extends NDArray<InputNDArrayType>> extends BackwardHelper<InputNDArrayType> {

    /** The highway layer on which this helper runs the backward step. */
    @NotNull
    private final HighwayLayer<InputNDArrayType> layer;

    /**
     * Creates a backward helper bound to the given highway layer.
     *
     * @param highwayLayer the layer to back-propagate through (must not be null)
     */
    public HighwayBackwardHelper(@NotNull HighwayLayer<InputNDArrayType> highwayLayer) {
        super(highwayLayer);
        Intrinsics.checkParameterIsNotNull(highwayLayer, "layer");
        this.layer = highwayLayer;
    }

    /**
     * Returns the highway layer this helper operates on.
     *
     * <p>Covariant override of the supertype's layer getter; the compiler emits the bridge method.
     *
     * @return the highway layer passed to the constructor
     */
    @Override
    @NotNull
    public HighwayLayer<InputNDArrayType> getLayer() {
        return this.layer;
    }

    /**
     * Returns the highway layer this helper operates on.
     *
     * @return the same instance as {@link #getLayer()}
     * @deprecated decompiler-generated alias kept only for source compatibility; use {@link #getLayer()}.
     */
    @Deprecated
    @NotNull
    public HighwayLayer<InputNDArrayType> getLayer2() {
        return getLayer();
    }

    /**
     * Runs the backward step: gate errors first, then parameter gradients, then (optionally) the
     * errors of the layer input.
     *
     * @param propagateToInput whether to also assign the errors of the layer input
     * @throws IllegalArgumentException if the layer input values are not a {@link DenseNDArray}
     */
    @Override
    protected void execBackward(boolean propagateToInput) {
        if (!(getLayer().getInputArray().getValues() instanceof DenseNDArray)) {
            throw new IllegalArgumentException("Highway layer supports only dense input.");
        }

        // Order matters: the parameter and input gradients below read the gate errors set here.
        assignGatesGradients();
        assignParamsGradients();

        if (propagateToInput) {
            assignLayerGradients();
        }
    }

    /**
     * Assigns the errors of the transform gate and of the input unit.
     *
     * <p>With {@code y = T * H + (1 - T) * x} (element-wise), this sets:
     * {@code errT = (H - x) * gy * T'} and {@code errH = T * gy}, the latter multiplied by
     * {@code H'} only when the input unit has an activation.
     */
    private void assignGatesGradients() {
        HighwayLayer<InputNDArrayType> highway = getLayer();

        InputNDArrayType input = highway.getInputArray().getValues();
        if (input == null) {
            throw new TypeCastException("null cannot be cast to non-null type com.kotlinnlp.simplednn.simplemath.ndarray.dense.DenseNDArray");
        }
        DenseNDArray x = (DenseNDArray) input;

        DenseNDArray outputErrors = highway.getOutputArray().getErrors();
        AugmentedArray<DenseNDArray> inputUnit = highway.getInputUnit();
        AugmentedArray<DenseNDArray> transformGate = highway.getTransformGate();

        DenseNDArray h = inputUnit.getValues();
        DenseNDArray t = transformGate.getValues();
        DenseNDArray tDeriv = transformGate.calculateActivationDeriv();

        // dy/dT = H - x
        transformGate.assignErrorsByProd(h.sub(x), outputErrors);
        transformGate.getErrors().assignProd(tDeriv);

        // dy/dH = T
        inputUnit.assignErrorsByProd(t, outputErrors);
        if (inputUnit.getHasActivation()) {
            inputUnit.getErrors().assignProd(inputUnit.calculateActivationDeriv());
        }
    }

    /**
     * Assigns the weight and bias gradients of both the input unit and the transform gate, using
     * the errors previously set by {@link #assignGatesGradients()} and the layer input values.
     */
    private void assignParamsGradients() {
        HighwayLayerParameters params = highwayParams();
        HighwayLayer<InputNDArrayType> highway = getLayer();
        InputNDArrayType input = highway.getInputArray().getValues();

        highway.getInputUnit().assignParamsGradients(
            getErrors(params.getInput().getWeights()).getValues(),
            getErrors(params.getInput().getBiases()).getValues(),
            input);

        highway.getTransformGate().assignParamsGradients(
            getErrors(params.getTransformGate().getWeights()).getValues(),
            getErrors(params.getTransformGate().getBiases()).getValues(),
            input);
    }

    /**
     * Assigns the errors of the layer input.
     *
     * <p>The input {@code x} reaches the output through three paths, all of which must contribute:
     * the carry path {@code (1 - T) * gy}, the input unit {@code W_H^T * errH}, and the transform
     * gate {@code W_T^T * errT} (since {@code T} is itself a function of {@code x}).
     */
    private void assignLayerGradients() {
        HighwayLayerParameters params = highwayParams();
        HighwayLayer<InputNDArrayType> highway = getLayer();

        DenseNDArray t = highway.getTransformGate().getValues();

        // Carry path: (1 - T) * gy (reverseSub yields a fresh array, so T itself is not modified).
        DenseNDArray inputErrors = t.reverseSub(1.0d).assignProd(highway.getOutputArray().getErrors());

        NDArray<?> viaInputUnit = AugmentedArrayExtensionsKt.getInputErrors(
            highway.getInputUnit(), params.getInput().getWeights().getValues());
        NDArray<?> viaTransformGate = AugmentedArrayExtensionsKt.getInputErrors(
            highway.getTransformGate(), params.getTransformGate().getWeights().getValues());

        highway.getInputArray().assignErrors(inputErrors.assignSum(viaInputUnit).assignSum(viaTransformGate));
    }

    /**
     * Returns the layer parameters cast to {@link HighwayLayerParameters}.
     *
     * @throws TypeCastException if the layer parameters are null
     */
    @NotNull
    private HighwayLayerParameters highwayParams() {
        LayerParameters<?> params = getLayer().getParams();
        if (params == null) {
            throw new TypeCastException("null cannot be cast to non-null type com.kotlinnlp.simplednn.core.layers.models.feedforward.highway.HighwayLayerParameters");
        }
        return (HighwayLayerParameters) params;
    }
}
