package com.kotlinnlp.simplednn.core.layers.models.recurrent.ltm;

import com.kotlinnlp.simplednn.core.arrays.AugmentedArray;
import com.kotlinnlp.simplednn.core.layers.Layer;
import com.kotlinnlp.simplednn.core.layers.LayerParameters;
import com.kotlinnlp.simplednn.core.layers.helpers.ForwardHelper;
import com.kotlinnlp.simplednn.simplemath.ndarray.NDArray;
import com.kotlinnlp.simplednn.simplemath.ndarray.dense.DenseNDArray;
import kotlin.Metadata;
import kotlin.NotImplementedError;
import kotlin.TypeCastException;
import kotlin.jvm.internal.Intrinsics;
import org.jetbrains.annotations.NotNull;

/* compiled from: LTMForwardHelper.kt */
@Metadata(mv = {1, 1, 13}, bv = {1, 0, 3}, k = 1, d1 = {"��.\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n\u0002\b\u0004\n\u0002\u0010\u0002\n��\n\u0002\u0018\u0002\n\u0002\b\u0002\n\u0002\u0018\u0002\n\u0002\b\u0002\u0018��*\u000e\b��\u0010\u0001*\b\u0012\u0004\u0012\u0002H\u00010\u00022\b\u0012\u0004\u0012\u0002H\u00010\u0003B\u0013\u0012\f\u0010\u0004\u001a\b\u0012\u0004\u0012\u00028��0\u0005¢\u0006\u0002\u0010\u0006J\b\u0010\t\u001a\u00020\nH\u0016J\u0014\u0010\t\u001a\u00020\n2\n\u0010\u000b\u001a\u0006\u0012\u0002\b\u00030\fH\u0016J\u0016\u0010\r\u001a\u00020\n2\f\u0010\u000e\u001a\b\u0012\u0002\b\u0003\u0018\u00010\u000fH\u0002J\u0016\u0010\u0010\u001a\u00020\n2\f\u0010\u000e\u001a\b\u0012\u0002\b\u0003\u0018\u00010\u000fH\u0002R\u001a\u0010\u0004\u001a\b\u0012\u0004\u0012\u00028��0\u0005X\u0094\u0004¢\u0006\b\n��\u001a\u0004\b\u0007\u0010\b¨\u0006\u0011"}, d2 = {"Lcom/kotlinnlp/simplednn/core/layers/models/recurrent/ltm/LTMForwardHelper;", "InputNDArrayType", "Lcom/kotlinnlp/simplednn/simplemath/ndarray/NDArray;", "Lcom/kotlinnlp/simplednn/core/layers/helpers/ForwardHelper;", "layer", "Lcom/kotlinnlp/simplednn/core/layers/models/recurrent/ltm/LTMLayer;", "(Lcom/kotlinnlp/simplednn/core/layers/models/recurrent/ltm/LTMLayer;)V", "getLayer", "()Lcom/kotlinnlp/simplednn/core/layers/models/recurrent/ltm/LTMLayer;", "forward", "", "layerContributions", "Lcom/kotlinnlp/simplednn/core/layers/LayerParameters;", "forwardCell", "prevStateLayer", "Lcom/kotlinnlp/simplednn/core/layers/Layer;", "forwardInputGates", "simplednn"})
/* loaded from: input_file:com/kotlinnlp/simplednn/core/layers/models/recurrent/ltm/LTMForwardHelper.class */
/**
 * Forward helper for {@link LTMLayer}.
 *
 * <p>One forward step does the following:
 * <ol>
 *   <li>reads the previous state from the layer's context window (used only if it is an
 *       {@link LTMLayer});</li>
 *   <li>sets {@code x} to the layer input, summed with the previous output when one exists;</li>
 *   <li>forwards and activates the three input gates on {@code x};</li>
 *   <li>computes {@code c} as the product of input gates 1 and 2, plus the previous cell when one
 *       exists, then forwards and activates the cell on {@code c};</li>
 *   <li>assigns to the output the product of the cell and input gate 3.</li>
 * </ol>
 *
 * <p>NOTE(review): the products are done via {@code assignValuesByProd}/{@code assignProd} and are
 * presumably element-wise — confirm against the NDArray implementation.
 *
 * @param <InputNDArrayType> the type of the layer input array
 */
public final class LTMForwardHelper<InputNDArrayType extends NDArray<InputNDArrayType>> extends ForwardHelper<InputNDArrayType> {

    /** Message of the exception thrown when the layer has no parameters (mirrors the Kotlin cast failure). */
    private static final String NULL_PARAMS_MESSAGE =
        "null cannot be cast to non-null type com.kotlinnlp.simplednn.core.layers.models.recurrent.ltm.LTMLayerParameters";

    /** The LTM layer this helper performs the forward of. */
    @NotNull
    private final LTMLayer<InputNDArrayType> layer;

    /**
     * Creates a forward helper for the given LTM layer.
     *
     * @param layer the layer to forward; must not be null
     */
    public LTMForwardHelper(@NotNull LTMLayer<InputNDArrayType> layer) {
        super(layer);
        // Java requires the super() call first, so the Kotlin null check runs right after it.
        Intrinsics.checkParameterIsNotNull(layer, "layer");
        this.layer = layer;
    }

    /**
     * Forwards the input through the layer, writing the result into its output array.
     */
    @Override // com.kotlinnlp.simplednn.core.layers.helpers.ForwardHelper
    public void forward() {
        Layer<?> prevState = this.layer.getLayerContextWindow().getPrevState();
        // Only an LTM previous state is used (Kotlin `as? LTMLayer`); anything else counts as absent.
        LTMLayer<?> prevStateLayer = (prevState instanceof LTMLayer) ? (LTMLayer<?>) prevState : null;

        forwardInputGates(prevStateLayer);
        forwardCell(prevStateLayer);

        DenseNDArray inputGate3 = this.layer.getInputGate3().getValues();
        this.layer.getOutputArray().getValues().assignProd(this.layer.getCell().getValues(), inputGate3);
    }

    /**
     * Forward that also saves layer contributions; not implemented for the LTM layer.
     *
     * @param layerContributions the contributions container; must not be null
     * @throws NotImplementedError always
     */
    @Override // com.kotlinnlp.simplednn.core.layers.helpers.ForwardHelper
    public void forward(@NotNull LayerParameters<?> layerContributions) {
        Intrinsics.checkParameterIsNotNull(layerContributions, "layerContributions");
        throw new NotImplementedError("An operation is not implemented: not implemented");
    }

    /**
     * Returns the layer parameters as {@link LTMLayerParameters}.
     *
     * @return the LTM parameters of the layer
     * @throws TypeCastException if the layer has no parameters
     * @throws ClassCastException if the parameters are not {@link LTMLayerParameters}
     */
    private LTMLayerParameters ltmParams() {
        LayerParameters<?> params = this.layer.getParams();
        if (params == null) {
            throw new TypeCastException(NULL_PARAMS_MESSAGE);
        }
        return (LTMLayerParameters) params;
    }

    /**
     * Sets {@code x} and forwards and activates the three input gates on it.
     *
     * <p>{@code x} is the previous output summed with the layer input when {@code prevStateLayer}
     * has an output. Otherwise {@code x} is the layer input itself.
     *
     * @param prevStateLayer the previous state layer, or null if there is none
     */
    private void forwardInputGates(Layer<?> prevStateLayer) {
        LTMLayerParameters params = ltmParams();

        DenseNDArray prevOutput = null;
        if (prevStateLayer != null) {
            AugmentedArray<DenseNDArray> prevOutputArray = prevStateLayer.getOutputArray();
            if (prevOutputArray != null) {
                prevOutput = prevOutputArray.getValues();
            }
        }

        // Kotlin: x = prevOutput?.sum(input) ?: input
        DenseNDArray sum = (prevOutput != null)
            ? prevOutput.sum((DenseNDArray) this.layer.getInputArray().getValues())
            : null;
        this.layer.setX(sum != null ? sum : this.layer.getInputArray().getValues());

        this.layer.getInputGate1().forward(
            params.getInputGate1().getWeights().getValues(),
            params.getInputGate1().getBiases().getValues(),
            this.layer.getX());
        this.layer.getInputGate2().forward(
            params.getInputGate2().getWeights().getValues(),
            params.getInputGate2().getBiases().getValues(),
            this.layer.getX());
        this.layer.getInputGate3().forward(
            params.getInputGate3().getWeights().getValues(),
            params.getInputGate3().getBiases().getValues(),
            this.layer.getX());

        this.layer.getInputGate1().activate();
        this.layer.getInputGate2().activate();
        this.layer.getInputGate3().activate();
    }

    /**
     * Computes {@code c} from input gates 1 and 2, then forwards and activates the cell on it.
     *
     * <p>The previous cell is added to {@code c} when {@code prevStateLayer} is an {@link LTMLayer}
     * whose cell is not null. Input gates must already be forwarded.
     *
     * @param prevStateLayer the previous state layer, or null if there is none
     */
    private void forwardCell(Layer<?> prevStateLayer) {
        LTMLayerParameters params = ltmParams();

        DenseNDArray inputGate1 = this.layer.getInputGate1().getValues();
        DenseNDArray inputGate2 = this.layer.getInputGate2().getValues();

        DenseNDArray prevCell = null;
        if (prevStateLayer instanceof LTMLayer) {
            AugmentedArray<DenseNDArray> prevCellArray = ((LTMLayer<?>) prevStateLayer).getCell();
            if (prevCellArray != null) {
                prevCell = prevCellArray.getValues();
            }
        }

        this.layer.getC().assignValuesByProd(inputGate1, inputGate2);
        if (prevCell != null) {
            this.layer.getC().getValues().assignSum(prevCell);
        }

        this.layer.getCell().forward(
            params.getCell().getWeights().getValues(),
            params.getCell().getBiases().getValues(),
            this.layer.getC().getValues());
        this.layer.getCell().activate();
    }

    /**
     * Returns the LTM layer this helper forwards.
     *
     * <p>This is the covariant override of {@link ForwardHelper#getLayer()}.
     *
     * @return the LTM layer
     */
    @Override // com.kotlinnlp.simplednn.core.layers.helpers.ForwardHelper
    @NotNull
    public LTMLayer<InputNDArrayType> getLayer() {
        return this.layer;
    }

    /**
     * Returns the LTM layer this helper forwards.
     *
     * <p>Alias of {@link #getLayer()}, kept for callers of the name generated by the decompiler.
     *
     * @return the LTM layer
     */
    @NotNull
    public LTMLayer<InputNDArrayType> getLayer2() {
        return this.layer;
    }
}
