package com.kotlinnlp.simplednn.deeplearning.attention.pointernetwork;

import com.kotlinnlp.simplednn.core.arrays.AugmentedArray;
import com.kotlinnlp.simplednn.core.arrays.ParamsArray;
import com.kotlinnlp.simplednn.core.layers.helpers.ParamsErrorsCollector;
import com.kotlinnlp.simplednn.core.layers.models.attention.attentionmechanism.AttentionMechanismLayer;
import com.kotlinnlp.simplednn.core.optimizer.ParamsErrorsAccumulator;
import com.kotlinnlp.simplednn.simplemath.ndarray.NDArray;
import com.kotlinnlp.simplednn.simplemath.ndarray.Shape;
import com.kotlinnlp.simplednn.simplemath.ndarray.dense.DenseNDArray;
import com.kotlinnlp.simplednn.simplemath.ndarray.dense.DenseNDArrayFactory;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import kotlin.Metadata;
import kotlin.collections.CollectionsKt;
import kotlin.collections.IntIterator;
import kotlin.jvm.internal.Intrinsics;
import kotlin.ranges.RangesKt;
import org.jetbrains.annotations.NotNull;

/* compiled from: BackwardHelper.kt */
@Metadata(mv = {1, 1, 13}, bv = {1, 0, 3}, k = 1, d1 = {"��L\n\u0002\u0018\u0002\n\u0002\u0010��\n��\n\u0002\u0018\u0002\n\u0002\b\u0002\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0010 \n\u0002\u0018\u0002\n\u0002\b\u0006\n\u0002\u0010\b\n\u0002\b\u0004\n\u0002\u0010\u0002\n\u0002\b\u0005\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n��\n\u0002\u0010\u000b\n\u0002\b\u0004\u0018��2\u00020\u0001B\r\u0012\u0006\u0010\u0002\u001a\u00020\u0003¢\u0006\u0002\u0010\u0004J\u0014\u0010\u0016\u001a\u00020\u00172\f\u0010\u0018\u001a\b\u0012\u0004\u0012\u00020\u000b0\nJ\u0016\u0010\u0019\u001a\u00020\u000b2\f\u0010\u0018\u001a\b\u0012\u0004\u0012\u00020\u000b0\nH\u0002J\u0016\u0010\u001a\u001a\b\u0012\u0004\u0012\u00020\u000b0\n2\u0006\u0010\u0018\u001a\u00020\u000bH\u0002J\u0010\u0010\u001b\u001a\u00020\u00172\u0006\u0010\u0018\u001a\u00020\u000bH\u0002J\u001e\u0010\u001c\u001a\u0010\u0012\f\u0012\n\u0012\u0002\b\u00030\u001dR\u00020\u001e0\n2\b\b\u0002\u0010\u001f\u001a\u00020 J\b\u0010!\u001a\u00020\u0017H\u0002J\b\u0010\"\u001a\u00020\u0017H\u0002J\b\u0010#\u001a\u00020\u0017H\u0002R\u000e\u0010\u0005\u001a\u00020\u0006X\u0082\u000e¢\u0006\u0002\n��R\u000e\u0010\u0007\u001a\u00020\bX\u0082\u000e¢\u0006\u0002\n��R0\u0010\f\u001a\b\u0012\u0004\u0012\u00020\u000b0\n2\f\u0010\t\u001a\b\u0012\u0004\u0012\u00020\u000b0\n@BX\u0080.¢\u0006\u000e\n��\u001a\u0004\b\r\u0010\u000e\"\u0004\b\u000f\u0010\u0010R\u000e\u0010\u0002\u001a\u00020\u0003X\u0082\u0004¢\u0006\u0002\n��R\u000e\u0010\u0011\u001a\u00020\u0012X\u0082\u000e¢\u0006\u0002\n��R0\u0010\u0013\u001a\b\u0012\u0004\u0012\u00020\u000b0\n2\f\u0010\t\u001a\b\u0012\u0004\u0012\u00020\u000b0\n@BX\u0080.¢\u0006\u000e\n��\u001a\u0004\b\u0014\u0010\u000e\"\u0004\b\u0015\u0010\u0010¨\u0006$"}, d2 = {"Lcom/kotlinnlp/simplednn/deeplearning/attention/pointernetwork/BackwardHelper;", "", "networkProcessor", "Lcom/kotlinnlp/simplednn/deeplearning/attention/pointernetwork/PointerNetworkProcessor;", 
"(Lcom/kotlinnlp/simplednn/deeplearning/attention/pointernetwork/PointerNetworkProcessor;)V", "attentionErrorsAccumulator", "Lcom/kotlinnlp/simplednn/core/optimizer/ParamsErrorsAccumulator;", "attentionParamsErrorsCollector", "Lcom/kotlinnlp/simplednn/core/layers/helpers/ParamsErrorsCollector;", "<set-?>", "", "Lcom/kotlinnlp/simplednn/simplemath/ndarray/dense/DenseNDArray;", "inputSequenceErrors", "getInputSequenceErrors$simplednn", "()Ljava/util/List;", "setInputSequenceErrors", "(Ljava/util/List;)V", "stateIndex", "", "vectorsErrors", "getVectorsErrors$simplednn", "setVectorsErrors", "backward", "", "outputErrors", "backwardAttentionArrays", "backwardAttentionScores", "backwardStep", "getParamsErrors", "Lcom/kotlinnlp/simplednn/core/arrays/ParamsArray$Errors;", "Lcom/kotlinnlp/simplednn/core/arrays/ParamsArray;", "copy", "", "initBackward", "initInputSequenceErrors", "initVectorsErrors", "simplednn"})
/* loaded from: input_file:com/kotlinnlp/simplednn/deeplearning/attention/pointernetwork/BackwardHelper.class */
/*
 * Backward helper of a PointerNetworkProcessor.
 *
 * For each forward step (processed from the last to the first) it propagates the given output
 * errors through the attention mechanism used at that step and then through the merge processor.
 * The errors of the input sequence are accumulated across steps. The errors of the vector of each
 * step are stored separately.
 *
 * NOTE(review): instances keep mutable state between calls without any synchronization.
 */
public final class BackwardHelper {

    /** Errors of each element of the input sequence, summed over all steps (set by backward). */
    @NotNull
    private List<DenseNDArray> inputSequenceErrors;

    /** Errors of the vector of each forward step (set by backward). */
    @NotNull
    private List<DenseNDArray> vectorsErrors;

    /** Index of the forward step currently being back-propagated. */
    private int stateIndex;

    /** Collector given to each attention mechanism layer during its backward. */
    private final ParamsErrorsCollector attentionParamsErrorsCollector;

    /** Accumulates (and finally averages) the params errors of the attention mechanisms. */
    private final ParamsErrorsAccumulator attentionErrorsAccumulator;

    private final PointerNetworkProcessor networkProcessor;

    /**
     * Returns the errors of the input sequence computed by the last call of {@link #backward}.
     *
     * @return one errors array per element of the input sequence
     * @throws kotlin.UninitializedPropertyAccessException if backward has never been called
     */
    @NotNull
    public final List<DenseNDArray> getInputSequenceErrors$simplednn() {
        List<DenseNDArray> errors = this.inputSequenceErrors;
        if (errors == null) {
            Intrinsics.throwUninitializedPropertyAccessException("inputSequenceErrors");
        }
        return errors;
    }

    private final void setInputSequenceErrors(List<DenseNDArray> list) {
        this.inputSequenceErrors = list;
    }

    /**
     * Returns the errors of the vectors computed by the last call of {@link #backward}.
     *
     * @return one errors array per forward step
     * @throws kotlin.UninitializedPropertyAccessException if backward has never been called
     */
    @NotNull
    public final List<DenseNDArray> getVectorsErrors$simplednn() {
        List<DenseNDArray> errors = this.vectorsErrors;
        if (errors == null) {
            Intrinsics.throwUninitializedPropertyAccessException("vectorsErrors");
        }
        return errors;
    }

    private final void setVectorsErrors(List<DenseNDArray> list) {
        this.vectorsErrors = list;
    }

    /**
     * Runs the backward pass over all the forward steps, from the last to the first.
     *
     * Resets the input-sequence errors, the vectors errors and the attention params errors before
     * starting. It averages the accumulated attention params errors at the end.
     *
     * @param list the output errors, where element i belongs to the i-th forward step (it is paired
     *     with the i-th used attention mechanism of the processor)
     */
    public final void backward(@NotNull List<DenseNDArray> list) {
        Intrinsics.checkParameterIsNotNull(list, "outputErrors");
        initBackward();
        for (int index = list.size() - 1; index >= 0; index--) {
            this.stateIndex = index;
            backwardStep(list.get(index));
        }
        this.attentionErrorsAccumulator.averageErrors();
    }

    /**
     * Returns the params errors of the merge processor followed by the accumulated params errors of
     * the attention mechanisms.
     *
     * @param z whether the returned errors should be copies (passed through as {@code copy})
     * @return the concatenated params errors
     */
    @NotNull
    public final List<ParamsArray.Errors<?>> getParamsErrors(boolean z) {
        return CollectionsKt.plus(
            this.networkProcessor.getMergeProcessor$simplednn().getParamsErrors(z),
            this.attentionErrorsAccumulator.getParamsErrors(z));
    }

    /** Default-argument bridge: calls {@link #getParamsErrors} with {@code true} when bit 0 of {@code i} is set. */
    @NotNull
    public static /* synthetic */ List getParamsErrors$default(BackwardHelper backwardHelper, boolean z, int i, Object obj) {
        if ((i & 1) != 0) {
            z = true;
        }
        return backwardHelper.getParamsErrors(z);
    }

    /**
     * Back-propagates the output errors of the current step ({@link #stateIndex}) and stores the
     * resulting vector errors of that step.
     */
    private final void backwardStep(DenseNDArray denseNDArray) {
        DenseNDArray stepVectorErrors = backwardAttentionArrays(backwardAttentionScores(denseNDArray));
        getVectorsErrors$simplednn().get(this.stateIndex).assignValues((NDArray<?>) stepVectorErrors);
    }

    /**
     * Back-propagates the given errors through the attention mechanism used at the current step,
     * accumulating its params errors.
     *
     * @return the errors of each input array of the attention mechanism
     */
    private final List<DenseNDArray> backwardAttentionScores(DenseNDArray denseNDArray) {
        AttentionMechanismLayer attentionMechanismLayer =
            this.networkProcessor.getUsedAttentionMechanisms$simplednn().get(this.stateIndex);
        attentionMechanismLayer.setErrors(denseNDArray);
        attentionMechanismLayer.setParamsErrorsCollector(this.attentionParamsErrorsCollector);
        ParamsErrorsAccumulator.accumulate$default(
            this.attentionErrorsAccumulator, (List) attentionMechanismLayer.backward(true), false, 2, (Object) null);

        List<AugmentedArray<DenseNDArray>> inputArrays = attentionMechanismLayer.getInputArrays();
        List<DenseNDArray> inputArraysErrors = new ArrayList<>(inputArrays.size());
        for (AugmentedArray<DenseNDArray> inputArray : inputArrays) {
            inputArraysErrors.add(inputArray.getErrors());
        }
        return inputArraysErrors;
    }

    /**
     * Back-propagates the attention input errors through the merge processor.
     *
     * Per merged element, the first input errors are added to the corresponding input-sequence
     * errors. The second input errors are summed into the returned vector errors.
     *
     * @return the summed vector errors of the current step
     */
    private final DenseNDArray backwardAttentionArrays(List<DenseNDArray> list) {
        // Sized with the vector size (not the input size): the result is copied into an element of
        // vectorsErrors, whose arrays are allocated with getVectorSize() in initVectorsErrors().
        DenseNDArray vectorErrors = DenseNDArrayFactory.INSTANCE.zeros(
            new Shape(this.networkProcessor.getModel().getVectorSize(), 0, 2, null));
        this.networkProcessor.getMergeProcessor$simplednn().backward2(list);

        List<DenseNDArray> sequenceErrors = getInputSequenceErrors$simplednn();
        int index = 0;
        for (Object obj : this.networkProcessor.getMergeProcessor$simplednn().getInputsErrors(true)) {
            List<?> mergeInputsErrors = (List<?>) obj;
            vectorErrors.assignSum((NDArray<?>) mergeInputsErrors.get(1));
            sequenceErrors.get(index).assignSum((NDArray<?>) mergeInputsErrors.get(0));
            index++;
        }
        return vectorErrors;
    }

    /** Resets all the errors structures before a new backward. */
    private final void initBackward() {
        initInputSequenceErrors();
        initVectorsErrors();
        this.attentionErrorsAccumulator.clear();
    }

    /** Allocates one zero errors array (input size) per element of the input sequence. */
    private final void initInputSequenceErrors() {
        this.inputSequenceErrors = zerosArrays(
            this.networkProcessor.getInputSequence$simplednn().size(),
            this.networkProcessor.getModel().getInputSize());
    }

    /** Allocates one zero errors array (vector size) per forward step. */
    private final void initVectorsErrors() {
        this.vectorsErrors = zerosArrays(
            this.networkProcessor.getForwardCount$simplednn(),
            this.networkProcessor.getModel().getVectorSize());
    }

    /** Builds a list of {@code count} new zero arrays whose first dimension is {@code size}. */
    private static List<DenseNDArray> zerosArrays(int count, int size) {
        List<DenseNDArray> arrays = new ArrayList<>(count);
        for (int i = 0; i < count; i++) {
            arrays.add(DenseNDArrayFactory.INSTANCE.zeros(new Shape(size, 0, 2, null)));
        }
        return arrays;
    }

    public BackwardHelper(@NotNull PointerNetworkProcessor pointerNetworkProcessor) {
        Intrinsics.checkParameterIsNotNull(pointerNetworkProcessor, "networkProcessor");
        this.networkProcessor = pointerNetworkProcessor;
        this.attentionParamsErrorsCollector = new ParamsErrorsCollector();
        this.attentionErrorsAccumulator = new ParamsErrorsAccumulator();
    }
}
