package com.kotlinnlp.simplednn.core.layers.models.recurrent.ran;

import com.kotlinnlp.simplednn.core.layers.ArrayExtensionsKt;
import com.kotlinnlp.simplednn.core.layers.Layer;
import com.kotlinnlp.simplednn.core.layers.LayerParameters;
import com.kotlinnlp.simplednn.core.layers.helpers.ForwardHelper;
import com.kotlinnlp.simplednn.core.layers.models.LinearParams;
import com.kotlinnlp.simplednn.core.layers.models.recurrent.RecurrentLinearParams;
import com.kotlinnlp.simplednn.simplemath.ndarray.NDArray;
import com.kotlinnlp.simplednn.simplemath.ndarray.dense.DenseNDArray;
import kotlin.Metadata;
import kotlin.TypeCastException;
import kotlin._Assertions;
import kotlin.jvm.internal.Intrinsics;
import org.jetbrains.annotations.NotNull;

/* compiled from: RANForwardHelper.kt */
@Metadata(mv = {1, 1, 13}, bv = {1, 0, 3}, k = 1, d1 = {"��:\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n\u0002\b\u0004\n\u0002\u0010\u0002\n��\n\u0002\u0018\u0002\n\u0002\b\u0003\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n\u0002\b\u0003\n\u0002\u0018\u0002\n��\u0018��*\u000e\b��\u0010\u0001*\b\u0012\u0004\u0012\u0002H\u00010\u00022\b\u0012\u0004\u0012\u0002H\u00010\u0003B\u0013\u0012\f\u0010\u0004\u001a\b\u0012\u0004\u0012\u00028��0\u0005¢\u0006\u0002\u0010\u0006J(\u0010\t\u001a\u00020\n2\u0006\u0010\u000b\u001a\u00020\f2\u0006\u0010\r\u001a\u00020\f2\u0006\u0010\u000e\u001a\u00020\f2\u0006\u0010\u000f\u001a\u00020\u0010H\u0002J\b\u0010\u0011\u001a\u00020\nH\u0016J\u0014\u0010\u0011\u001a\u00020\n2\n\u0010\u000f\u001a\u0006\u0012\u0002\b\u00030\u0012H\u0016J \u0010\u0013\u001a\u00020\n2\u0006\u0010\u000f\u001a\u00020\u00102\u0006\u0010\r\u001a\u00020\f2\u0006\u0010\u000e\u001a\u00020\fH\u0002J\u0016\u0010\u0014\u001a\u00020\n2\f\u0010\u0015\u001a\b\u0012\u0002\b\u0003\u0018\u00010\u0016H\u0002J\u001e\u0010\u0014\u001a\u00020\n2\f\u0010\u0015\u001a\b\u0012\u0002\b\u0003\u0018\u00010\u00162\u0006\u0010\u000f\u001a\u00020\u0010H\u0002R\u001a\u0010\u0004\u001a\b\u0012\u0004\u0012\u00028��0\u0005X\u0094\u0004¢\u0006\b\n��\u001a\u0004\b\u0007\u0010\b¨\u0006\u0017"}, d2 = {"Lcom/kotlinnlp/simplednn/core/layers/models/recurrent/ran/RANForwardHelper;", "InputNDArrayType", "Lcom/kotlinnlp/simplednn/simplemath/ndarray/NDArray;", "Lcom/kotlinnlp/simplednn/core/layers/helpers/ForwardHelper;", "layer", "Lcom/kotlinnlp/simplednn/core/layers/models/recurrent/ran/RANLayer;", "(Lcom/kotlinnlp/simplednn/core/layers/models/recurrent/ran/RANLayer;)V", "getLayer", "()Lcom/kotlinnlp/simplednn/core/layers/models/recurrent/ran/RANLayer;", "addGatesRecurrentContribution", "", "yPrev", "Lcom/kotlinnlp/simplednn/simplemath/ndarray/dense/DenseNDArray;", "bInG", "bForG", "layerContributions", 
"Lcom/kotlinnlp/simplednn/core/layers/models/recurrent/ran/RANLayerParameters;", "forward", "Lcom/kotlinnlp/simplednn/core/layers/LayerParameters;", "forwardGates", "setGates", "prevStateLayer", "Lcom/kotlinnlp/simplednn/core/layers/Layer;", "simplednn"})
/* loaded from: input_file:com/kotlinnlp/simplednn/core/layers/models/recurrent/ran/RANForwardHelper.class */
public final class RANForwardHelper<InputNDArrayType extends NDArray<InputNDArrayType>> extends ForwardHelper<InputNDArrayType> {

    /**
     * The layer this helper operates on. It is assigned once in the constructor
     * and exposed through {@link #getLayer2()}.
     */
    @NotNull
    private final RANLayer<InputNDArrayType> layer;

    /* JADX WARN: Multi-variable type inference failed */
    /**
     * Runs the forward pass of the layer without tracking contributions.
     *
     * <p>Steps, in order:
     * <ol>
     *   <li>fetch the previous state from the layer context window (may be {@code null});</li>
     *   <li>compute the gates and the candidate via {@link #setGates(Layer)};</li>
     *   <li>assign the product of the input gate and the candidate to the output values;</li>
     *   <li>if a previous state exists, add the product of its not-activated output
     *       and the forget gate to the output values;</li>
     *   <li>activate the output array.</li>
     * </ol>
     */
    @Override // com.kotlinnlp.simplednn.core.layers.helpers.ForwardHelper
    public void forward() {
        final Layer<?> prevStateLayer = getLayer2().getLayerContextWindow().getPrevState();

        // The gates and the candidate must be ready before the output is assembled from them.
        setGates(prevStateLayer);

        final DenseNDArray y = getLayer2().getOutputArray().getValues();
        final DenseNDArray c = getLayer2().getCandidate().getValues();
        final DenseNDArray inG = (DenseNDArray) getLayer2().getInputGate().getValues();
        final DenseNDArray forG = (DenseNDArray) getLayer2().getForgetGate().getValues();

        y.assignProd(inG, c);

        if (prevStateLayer != null) {
            // Recurrent term: previous (not activated) output combined with the forget gate.
            final DenseNDArray yPrev = prevStateLayer.getOutputArray().getValuesNotActivated();
            y.assignSum((NDArray<?>) yPrev.prod((NDArray<?>) forG));
        }

        getLayer2().getOutputArray().activate();
    }

    /* JADX WARN: Multi-variable type inference failed */
    /**
     * Runs the forward pass of the layer while recording contributions into the
     * given parameters structure.
     *
     * <p>The flow mirrors {@link #forward()}, with gates computed by
     * {@link #setGates(Layer, RANLayerParameters)}. When a previous state exists,
     * the recurrent term is written into the values of the contributions'
     * candidate biases (that array is overwritten) and then added to the output.
     *
     * @param layerParameters the structure receiving the contributions; it is cast to
     *                        {@link RANLayerParameters}, so any other type fails the cast
     */
    @Override // com.kotlinnlp.simplednn.core.layers.helpers.ForwardHelper
    public void forward(@NotNull LayerParameters<?> layerParameters) {
        Intrinsics.checkParameterIsNotNull(layerParameters, "layerContributions");

        final Layer<?> prevStateLayer = getLayer2().getLayerContextWindow().getPrevState();
        final RANLayerParameters contributions = (RANLayerParameters) layerParameters;

        // The gates and the candidate must be ready before the output is assembled from them.
        setGates(prevStateLayer, contributions);

        final DenseNDArray y = getLayer2().getOutputArray().getValues();
        final DenseNDArray c = getLayer2().getCandidate().getValues();
        final DenseNDArray inG = (DenseNDArray) getLayer2().getInputGate().getValues();
        final DenseNDArray forG = (DenseNDArray) getLayer2().getForgetGate().getValues();

        y.assignProd(inG, c);

        if (prevStateLayer != null) {
            final DenseNDArray yPrev = prevStateLayer.getOutputArray().getValuesNotActivated();
            // The candidate-bias slot of the contributions is reused as storage for the recurrent term.
            final DenseNDArray yRec = contributions.getCandidate().getBiases().getValues();
            yRec.assignProd(yPrev, forG);
            y.assignSum((NDArray<?>) yRec);
        }

        getLayer2().getOutputArray().activate();
    }

    /**
     * Computes the input gate, the forget gate and the candidate from the current input.
     *
     * <p>Each of the three units is forwarded with its own weights and biases. If a
     * previous state is given, its not-activated output is added as a recurrent
     * contribution to the input and forget gates. Finally the two gates are activated;
     * the candidate is not activated by this method.
     *
     * @param layer the layer in the previous state, or {@code null} if there is none
     * @throws TypeCastException if the layer parameters are {@code null}
     */
    private final void setGates(Layer<?> layer) {
        final LayerParameters<?> rawParams = getLayer2().getParams();
        if (rawParams == null) {
            throw new TypeCastException("null cannot be cast to non-null type com.kotlinnlp.simplednn.core.layers.models.recurrent.ran.RANLayerParameters");
        }
        final InputNDArrayType x = getLayer2().getInputArray().getValues();
        final RANLayerParameters params = (RANLayerParameters) rawParams;

        ArrayExtensionsKt.forward(
                getLayer2().getInputGate(),
                params.getInputGate().getWeights().getValues(),
                params.getInputGate().getBiases().getValues(),
                x);
        ArrayExtensionsKt.forward(
                getLayer2().getForgetGate(),
                params.getForgetGate().getWeights().getValues(),
                params.getForgetGate().getBiases().getValues(),
                x);
        ArrayExtensionsKt.forward(
                getLayer2().getCandidate(),
                params.getCandidate().getWeights().getValues(),
                params.getCandidate().getBiases().getValues(),
                x);

        if (layer != null) {
            // Only the two gates receive a recurrent contribution from the previous output.
            final DenseNDArray yPrev = layer.getOutputArray().getValuesNotActivated();
            getLayer2().getInputGate().addRecurrentContribution(params.getInputGate(), yPrev);
            getLayer2().getForgetGate().addRecurrentContribution(params.getForgetGate(), yPrev);
        }

        getLayer2().getInputGate().activate();
        getLayer2().getForgetGate().activate();
    }

    /**
     * Computes the gates and the candidate while recording contributions.
     *
     * <p>When a previous state is present, the input-gate and forget-gate biases are
     * divided by 2 before being passed on, and the same halved arrays are handed to
     * both {@link #forwardGates} and {@link #addGatesRecurrentContribution}
     * (presumably so the bias is shared between the input and the recurrent term;
     * confirm against the inherited helpers). Without a previous state the biases
     * are used as they are. The two gates are activated at the end.
     *
     * @param layer              the layer in the previous state, or {@code null} if there is none
     * @param rANLayerParameters the structure receiving the contributions
     * @throws TypeCastException if the layer parameters are {@code null}
     */
    private final void setGates(Layer<?> layer, RANLayerParameters rANLayerParameters) {
        final LayerParameters<?> rawParams = getLayer2().getParams();
        if (rawParams == null) {
            throw new TypeCastException("null cannot be cast to non-null type com.kotlinnlp.simplednn.core.layers.models.recurrent.ran.RANLayerParameters");
        }
        final RANLayerParameters params = (RANLayerParameters) rawParams;
        final RecurrentLinearParams inputGateParams = params.getInputGate();
        final RecurrentLinearParams forgetGateParams = params.getForgetGate();
        final DenseNDArray fullInputBias = inputGateParams.getBiases().getValues();
        final DenseNDArray fullForgetBias = forgetGateParams.getBiases().getValues();

        final DenseNDArray bInG;
        final DenseNDArray bForG;
        if (layer == null) {
            bInG = fullInputBias;
            bForG = fullForgetBias;
        } else {
            bInG = fullInputBias.div(2.0d);
            bForG = fullForgetBias.div(2.0d);
        }

        forwardGates(rANLayerParameters, bInG, bForG);

        if (layer != null) {
            addGatesRecurrentContribution(layer.getOutputArray().getValuesNotActivated(), bInG, bForG, rANLayerParameters);
        }

        getLayer2().getInputGate().activate();
        getLayer2().getForgetGate().activate();
    }

    /* JADX WARN: Multi-variable type inference failed */
    /**
     * Forwards the input through the candidate, the input gate and the forget gate,
     * recording the contributions of each through the inherited {@code forwardArray}.
     *
     * <p>The candidate uses its own biases; the two gates use the bias arrays
     * supplied by the caller.
     *
     * @param rANLayerParameters the structure receiving the contributions
     * @param denseNDArray       the biases to use for the input gate
     * @param denseNDArray2      the biases to use for the forget gate
     * @throws AssertionError    if assertions are enabled and the input is not dense
     * @throws TypeCastException if the layer parameters or the input values are {@code null}
     */
    private final void forwardGates(RANLayerParameters rANLayerParameters, DenseNDArray denseNDArray, DenseNDArray denseNDArray2) {
        final boolean inputIsDense = getLayer2().getInputArray().getValues() instanceof DenseNDArray;
        if (!inputIsDense && _Assertions.ENABLED) {
            throw new AssertionError("Forwarding with contributions requires the input to be dense.");
        }

        final LayerParameters<?> rawParams = getLayer2().getParams();
        if (rawParams == null) {
            throw new TypeCastException("null cannot be cast to non-null type com.kotlinnlp.simplednn.core.layers.models.recurrent.ran.RANLayerParameters");
        }

        final InputNDArrayType rawInput = getLayer2().getInputArray().getValues();
        if (rawInput == null) {
            throw new TypeCastException("null cannot be cast to non-null type com.kotlinnlp.simplednn.simplemath.ndarray.dense.DenseNDArray");
        }
        final DenseNDArray x = (DenseNDArray) rawInput;

        final RANLayerParameters params = (RANLayerParameters) rawParams;
        final LinearParams candidateParams = params.getCandidate();

        // Candidate: forwarded with its own weights and biases.
        forwardArray(
                rANLayerParameters.getCandidate().getWeights().getValues(),
                x,
                getLayer2().getCandidate().getValues(),
                candidateParams.getWeights().getValues(),
                candidateParams.getBiases().getValues());

        // Input gate: forwarded with the caller-supplied biases.
        forwardArray(
                rANLayerParameters.getInputGate().getWeights().getValues(),
                x,
                (DenseNDArray) getLayer2().getInputGate().getValues(),
                params.getInputGate().getWeights().getValues(),
                denseNDArray);

        // Forget gate: forwarded with the caller-supplied biases.
        forwardArray(
                rANLayerParameters.getForgetGate().getWeights().getValues(),
                x,
                (DenseNDArray) getLayer2().getForgetGate().getValues(),
                params.getForgetGate().getWeights().getValues(),
                denseNDArray2);
    }

    /* JADX WARN: Multi-variable type inference failed */
    /**
     * Adds the recurrent contribution of the previous output to the input gate and
     * the forget gate through the inherited {@code addRecurrentContribution},
     * recording the contributions.
     *
     * <p>For each gate the call receives the recurrent weights and the biases taken
     * from the contributions structure, the gate values of the current layer, the
     * gate's recurrent weights from the layer parameters and the supplied bias array.
     *
     * @param denseNDArray       the not-activated output of the previous state
     * @param denseNDArray2      the biases to use for the input gate
     * @param denseNDArray3      the biases to use for the forget gate
     * @param rANLayerParameters the structure receiving the contributions
     * @throws TypeCastException if the layer parameters are {@code null}
     */
    private final void addGatesRecurrentContribution(DenseNDArray denseNDArray, DenseNDArray denseNDArray2, DenseNDArray denseNDArray3, RANLayerParameters rANLayerParameters) {
        final LayerParameters<?> rawParams = getLayer2().getParams();
        if (rawParams == null) {
            throw new TypeCastException("null cannot be cast to non-null type com.kotlinnlp.simplednn.core.layers.models.recurrent.ran.RANLayerParameters");
        }
        final RANLayerParameters params = (RANLayerParameters) rawParams;
        final RecurrentLinearParams inputGateParams = params.getInputGate();
        final RecurrentLinearParams forgetGateParams = params.getForgetGate();

        // Readable aliases for the decompiler-generated parameter names.
        final DenseNDArray yPrev = denseNDArray;
        final DenseNDArray bInG = denseNDArray2;
        final DenseNDArray bForG = denseNDArray3;

        addRecurrentContribution(
                rANLayerParameters.getInputGate().getRecurrentWeights().getValues(),
                yPrev,
                rANLayerParameters.getInputGate().getBiases().getValues(),
                (DenseNDArray) getLayer2().getInputGate().getValues(),
                inputGateParams.getRecurrentWeights().getValues(),
                bInG);

        addRecurrentContribution(
                rANLayerParameters.getForgetGate().getRecurrentWeights().getValues(),
                yPrev,
                rANLayerParameters.getForgetGate().getBiases().getValues(),
                (DenseNDArray) getLayer2().getForgetGate().getValues(),
                forgetGateParams.getRecurrentWeights().getValues(),
                bForG);
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Returns the layer this helper operates on.
     *
     * <p>The decompiler renamed this accessor from {@code getLayer}.
     *
     * @return the layer passed to the constructor, never {@code null}
     */
    @Override // com.kotlinnlp.simplednn.core.layers.helpers.ForwardHelper
    @NotNull
    /* renamed from: getLayer */
    public RANLayer<InputNDArrayType> getLayer2() {
        return layer;
    }

    /* JADX WARN: 'super' call moved to the top of the method (can break code semantics) */
    /**
     * Creates a forward helper bound to the given layer.
     *
     * @param rANLayer the layer to operate on; must not be {@code null}
     */
    public RANForwardHelper(@NotNull RANLayer<InputNDArrayType> rANLayer) {
        super(rANLayer);
        // Kotlin-generated null check; the message names the original parameter ("layer").
        Intrinsics.checkParameterIsNotNull(rANLayer, "layer");
        this.layer = rANLayer;
    }
}
