package com.kotlinnlp.simplednn.core.layers;

import com.kotlinnlp.simplednn.core.arrays.DistributionArray;
import com.kotlinnlp.simplednn.core.arrays.ParamsArray;
import com.kotlinnlp.simplednn.core.layers.LayerType;
import com.kotlinnlp.simplednn.core.layers.helpers.ParamsErrorsCollector;
import com.kotlinnlp.simplednn.core.layers.models.feedforward.simple.FeedforwardLayer;
import com.kotlinnlp.simplednn.core.layers.models.merge.MergeLayer;
import com.kotlinnlp.simplednn.core.layers.models.recurrent.LayerContextWindow;
import com.kotlinnlp.simplednn.simplemath.ndarray.NDArray;
import com.kotlinnlp.simplednn.simplemath.ndarray.dense.DenseNDArray;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Iterator;
import java.util.List;
import kotlin.Metadata;
import kotlin.TypeCastException;
import kotlin.collections.CollectionsKt;
import kotlin.collections.IndexedValue;
import kotlin.jvm.internal.Intrinsics;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;

/* compiled from: StackedLayers.kt */
@Metadata(mv = {1, 1, 13}, bv = {1, 0, 3}, k = 1, d1 = {"��n\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n��\n\u0002\u0010 \n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n\u0002\b\u0002\n\u0002\u0010\b\n\u0002\b\u0005\n\u0002\u0018\u0002\n\u0002\b\t\n\u0002\u0018\u0002\n\u0002\b\u0004\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\b\u0002\n\u0002\u0010\u000b\n\u0002\b\u0004\n\u0002\u0018\u0002\n\u0002\b\u0006\n\u0002\u0010\u0002\n��\n\u0002\u0018\u0002\n\u0002\b\u0002\n\u0002\u0018\u0002\n��\b\u0016\u0018��*\u000e\b��\u0010\u0001*\b\u0012\u0004\u0012\u0002H\u00010\u00022\u00020\u0003B%\u0012\f\u0010\u0004\u001a\b\u0012\u0004\u0012\u00020\u00060\u0005\u0012\u0010\u0010\u0007\u001a\f\u0012\b\u0012\u0006\u0012\u0002\b\u00030\b0\u0005¢\u0006\u0002\u0010\tJ*\u0010\u001f\u001a\u0014\u0012\f\u0012\n\u0012\u0002\b\u00030 R\u00020!0\u0005j\u0002`\"2\u0006\u0010#\u001a\u00020\u001b2\b\b\u0002\u0010$\u001a\u00020%J\u0012\u0010&\u001a\f\u0012\b\u0012\u0006\u0012\u0002\b\u00030\u00110\u0005H\u0004J%\u0010'\u001a\u00020\u001b2\u0006\u0010(\u001a\u00028��2\u0006\u0010)\u001a\u00020*2\b\b\u0002\u0010+\u001a\u00020%¢\u0006\u0002\u0010,J\u001d\u0010'\u001a\u00020\u001b2\u0006\u0010(\u001a\u00028��2\b\b\u0002\u0010+\u001a\u00020%¢\u0006\u0002\u0010-J&\u0010'\u001a\u00020\u001b2\f\u0010(\u001a\b\u0012\u0004\u0012\u00028��0\u00052\u0006\u0010)\u001a\u00020*2\b\b\u0002\u0010+\u001a\u00020%J\u001e\u0010'\u001a\u00020\u001b2\f\u0010(\u001a\b\u0012\u0004\u0012\u00028��0\u00052\b\b\u0002\u0010+\u001a\u00020%J\u000e\u0010.\u001a\b\u0012\u0002\b\u0003\u0018\u00010\u0011H\u0016J\u000e\u0010/\u001a\b\u0012\u0002\b\u0003\u0018\u00010\u0011H\u0016J\u0016\u00100\u001a\u0002012\u0006\u0010)\u001a\u00020*2\u0006\u00102\u001a\u000203J\u000e\u00104\u001a\u0002012\u0006\u00105\u001a\u000206R\u001a\u0010\n\u001a\u00020\u000bX\u0086\u000e¢\u0006\u000e\n��\u001a\u0004\b\f\u0010\r\"\u0004\b\u000e\u0010\u000fR\u001d\u0010\u0010\u001a\b\u0012\u0004\u0012\u00028��0\u00118
F¢\u0006\f\u0012\u0004\b\u0012\u0010\u0013\u001a\u0004\b\u0014\u0010\u0015R\u001b\u0010\u0016\u001a\f\u0012\b\u0012\u0006\u0012\u0002\b\u00030\u00110\u0005¢\u0006\b\n��\u001a\u0004\b\u0017\u0010\u0018R\u0017\u0010\u0004\u001a\b\u0012\u0004\u0012\u00020\u00060\u0005¢\u0006\b\n��\u001a\u0004\b\u0019\u0010\u0018R\u001d\u0010\u001a\u001a\b\u0012\u0004\u0012\u00020\u001b0\u00118F¢\u0006\f\u0012\u0004\b\u001c\u0010\u0013\u001a\u0004\b\u001d\u0010\u0015R\u001b\u0010\u0007\u001a\f\u0012\b\u0012\u0006\u0012\u0002\b\u00030\b0\u0005¢\u0006\b\n��\u001a\u0004\b\u001e\u0010\u0018¨\u00067"}, d2 = {"Lcom/kotlinnlp/simplednn/core/layers/StackedLayers;", "InputNDArrayType", "Lcom/kotlinnlp/simplednn/simplemath/ndarray/NDArray;", "Lcom/kotlinnlp/simplednn/core/layers/models/recurrent/LayerContextWindow;", "layersConfiguration", "", "Lcom/kotlinnlp/simplednn/core/layers/LayerInterface;", "paramsPerLayer", "Lcom/kotlinnlp/simplednn/core/layers/LayerParameters;", "(Ljava/util/List;Ljava/util/List;)V", "curLayerIndex", "", "getCurLayerIndex", "()I", "setCurLayerIndex", "(I)V", "inputLayer", "Lcom/kotlinnlp/simplednn/core/layers/Layer;", "inputLayer$annotations", "()V", "getInputLayer", "()Lcom/kotlinnlp/simplednn/core/layers/Layer;", "layers", "getLayers", "()Ljava/util/List;", "getLayersConfiguration", "outputLayer", "Lcom/kotlinnlp/simplednn/simplemath/ndarray/dense/DenseNDArray;", "outputLayer$annotations", "getOutputLayer", "getParamsPerLayer", "backward", "Lcom/kotlinnlp/simplednn/core/arrays/ParamsArray$Errors;", "Lcom/kotlinnlp/simplednn/core/arrays/ParamsArray;", "Lcom/kotlinnlp/simplednn/core/optimizer/ParamsErrorsList;", "outputErrors", "propagateToInput", "", "buildLayers", "forward", "input", "stackedLayersContributions", "Lcom/kotlinnlp/simplednn/core/layers/StackedLayersParameters;", "useDropout", 
"(Lcom/kotlinnlp/simplednn/simplemath/ndarray/NDArray;Lcom/kotlinnlp/simplednn/core/layers/StackedLayersParameters;Z)Lcom/kotlinnlp/simplednn/simplemath/ndarray/dense/DenseNDArray;", "(Lcom/kotlinnlp/simplednn/simplemath/ndarray/NDArray;Z)Lcom/kotlinnlp/simplednn/simplemath/ndarray/dense/DenseNDArray;", "getNextState", "getPrevState", "propagateRelevance", "", "relevantOutcomesDistribution", "Lcom/kotlinnlp/simplednn/core/arrays/DistributionArray;", "setParamsErrorsCollector", "c", "Lcom/kotlinnlp/simplednn/core/layers/helpers/ParamsErrorsCollector;", "simplednn"})
/* loaded from: input_file:com/kotlinnlp/simplednn/core/layers/StackedLayers.class */
/**
 * A stack of layers built from a list of layer configurations and run in sequence.
 *
 * <p>The first layer is created from the first two configurations (its input and its output).
 * Every following layer receives the output array of the previous layer as a dense input. The
 * parameters of the i-th layer are the i-th element of {@code paramsPerLayer}. Consequently the
 * stack contains {@code layersConfiguration.size() - 1} layers.
 *
 * <p>The stack is a {@link LayerContextWindow} without any recurrent context: both
 * {@link #getPrevState()} and {@link #getNextState()} return {@code null}.
 *
 * <p>NOTE(review): this source was decompiled from Kotlin ({@code StackedLayers.kt}). The static
 * {@code *$default} and {@code *$annotations} methods mirror Kotlin-generated synthetic members.
 * They are kept so that compiled Kotlin callers keep linking.
 *
 * @param <InputNDArrayType> the type of the arrays fed to the input layer
 */
public class StackedLayers<InputNDArrayType extends NDArray<InputNDArrayType>> implements LayerContextWindow {

    /** The layers of the stack, ordered from the input layer to the output layer. */
    @NotNull
    private final List<Layer<?>> layers;

    /** The index of the layer currently processed by a forward, backward or relevance pass. */
    private int curLayerIndex;

    /** The layer configurations (one more than the number of layers). */
    @NotNull
    private final List<LayerInterface> layersConfiguration;

    /** The parameters of each layer, indexed like {@link #getLayers()}. */
    @NotNull
    private final List<LayerParameters<?>> paramsPerLayer;

    /**
     * Returns the layers of the stack.
     *
     * @return the layers, from the input layer to the output layer
     */
    @NotNull
    public final List<Layer<?>> getLayers() {
        return this.layers;
    }

    /**
     * Returns the index of the layer currently being processed.
     *
     * @return the index of the layer last visited by a forward, backward or relevance pass
     */
    public final int getCurLayerIndex() {
        return this.curLayerIndex;
    }

    /**
     * Sets the index of the layer currently being processed.
     *
     * @param curLayerIndex the new current layer index
     */
    public final void setCurLayerIndex(int curLayerIndex) {
        this.curLayerIndex = curLayerIndex;
    }

    /** Kotlin-generated synthetic placeholder for the annotations of the {@code inputLayer} property. */
    public static /* synthetic */ void inputLayer$annotations() {
    }

    /**
     * Returns the first layer of the stack, which receives the external input.
     *
     * @return the input layer
     * @throws java.util.NoSuchElementException if the stack has no layers
     */
    @NotNull
    @SuppressWarnings("unchecked")
    public final Layer<InputNDArrayType> getInputLayer() {
        // The first layer is built from the user-provided input configuration, so its input type
        // is the stack's InputNDArrayType.
        return (Layer<InputNDArrayType>) CollectionsKt.first(this.layers);
    }

    /** Kotlin-generated synthetic placeholder for the annotations of the {@code outputLayer} property. */
    public static /* synthetic */ void outputLayer$annotations() {
    }

    /**
     * Returns the last layer of the stack, whose output is the output of the whole stack.
     *
     * @return the output layer
     * @throws java.util.NoSuchElementException if the stack has no layers
     */
    @NotNull
    @SuppressWarnings("unchecked")
    public final Layer<DenseNDArray> getOutputLayer() {
        return (Layer<DenseNDArray>) CollectionsKt.last(this.layers);
    }

    /**
     * Runs the forward of all the layers on a single input array.
     *
     * @param input the input array of the first layer
     * @param useDropout whether the layers must apply dropout
     * @return the output values of the last layer
     */
    @NotNull
    public final DenseNDArray forward(@NotNull InputNDArrayType input, boolean useDropout) {
        Intrinsics.checkParameterIsNotNull(input, "input");
        getInputLayer().setInput(input);
        return forwardAllLayers(null, useDropout);
    }

    /** Kotlin default-argument bridge for {@link #forward(NDArray, boolean)} ({@code useDropout = false}). */
    @NotNull
    @SuppressWarnings({"rawtypes", "unchecked"})
    public static /* synthetic */ DenseNDArray forward$default(StackedLayers stackedLayers, NDArray input, boolean useDropout, int defaultsMask, Object superMarker) {
        if (superMarker != null) {
            throw new UnsupportedOperationException("Super calls with default arguments not supported in this target, function: forward");
        }
        boolean dropout = (defaultsMask & 2) == 0 && useDropout;
        return stackedLayers.forward(input, dropout);
    }

    /**
     * Runs the forward of all the layers on a single input array, passing each layer its own
     * contributions object.
     *
     * @param input the input array of the first layer
     * @param stackedLayersContributions the contributions; its i-th element goes to the i-th layer
     * @param useDropout whether the layers must apply dropout
     * @return the output values of the last layer
     */
    @NotNull
    public final DenseNDArray forward(@NotNull InputNDArrayType input, @NotNull StackedLayersParameters stackedLayersContributions, boolean useDropout) {
        Intrinsics.checkParameterIsNotNull(input, "input");
        Intrinsics.checkParameterIsNotNull(stackedLayersContributions, "stackedLayersContributions");
        getInputLayer().setInput(input);
        return forwardAllLayers(stackedLayersContributions, useDropout);
    }

    /** Kotlin default-argument bridge for the contributions overload of {@code forward} ({@code useDropout = false}). */
    @NotNull
    @SuppressWarnings({"rawtypes", "unchecked"})
    public static /* synthetic */ DenseNDArray forward$default(StackedLayers stackedLayers, NDArray input, StackedLayersParameters stackedLayersContributions, boolean useDropout, int defaultsMask, Object superMarker) {
        if (superMarker != null) {
            throw new UnsupportedOperationException("Super calls with default arguments not supported in this target, function: forward");
        }
        boolean dropout = (defaultsMask & 4) == 0 && useDropout;
        return stackedLayers.forward(input, stackedLayersContributions, dropout);
    }

    /**
     * Runs the forward of all the layers on multiple input arrays.
     *
     * <p>Requires the first layer to be a Merge layer. The i-th element of {@code input} is
     * assigned to its i-th input.
     *
     * @param input the input arrays of the first (Merge) layer
     * @param useDropout whether the layers must apply dropout
     * @return the output values of the last layer
     * @throws IllegalArgumentException if the first layer is not a Merge layer
     */
    @NotNull
    public final DenseNDArray forward(@NotNull List<? extends InputNDArrayType> input, boolean useDropout) {
        Intrinsics.checkParameterIsNotNull(input, "input");
        setMergeLayerInputs(input);
        return forwardAllLayers(null, useDropout);
    }

    /** Kotlin default-argument bridge for the multiple-input overload of {@code forward} ({@code useDropout = false}). */
    @NotNull
    @SuppressWarnings({"rawtypes", "unchecked"})
    public static /* synthetic */ DenseNDArray forward$default(StackedLayers stackedLayers, List input, boolean useDropout, int defaultsMask, Object superMarker) {
        if (superMarker != null) {
            throw new UnsupportedOperationException("Super calls with default arguments not supported in this target, function: forward");
        }
        boolean dropout = (defaultsMask & 2) == 0 && useDropout;
        return stackedLayers.forward(input, dropout);
    }

    /**
     * Runs the forward of all the layers on multiple input arrays, passing each layer its own
     * contributions object.
     *
     * @param input the input arrays of the first (Merge) layer
     * @param stackedLayersContributions the contributions; its i-th element goes to the i-th layer
     * @param useDropout whether the layers must apply dropout
     * @return the output values of the last layer
     * @throws IllegalArgumentException if the first layer is not a Merge layer
     */
    @NotNull
    public final DenseNDArray forward(@NotNull List<? extends InputNDArrayType> input, @NotNull StackedLayersParameters stackedLayersContributions, boolean useDropout) {
        Intrinsics.checkParameterIsNotNull(input, "input");
        Intrinsics.checkParameterIsNotNull(stackedLayersContributions, "stackedLayersContributions");
        setMergeLayerInputs(input);
        return forwardAllLayers(stackedLayersContributions, useDropout);
    }

    /** Kotlin default-argument bridge for the multiple-input contributions overload of {@code forward} ({@code useDropout = false}). */
    @NotNull
    @SuppressWarnings({"rawtypes", "unchecked"})
    public static /* synthetic */ DenseNDArray forward$default(StackedLayers stackedLayers, List input, StackedLayersParameters stackedLayersContributions, boolean useDropout, int defaultsMask, Object superMarker) {
        if (superMarker != null) {
            throw new UnsupportedOperationException("Super calls with default arguments not supported in this target, function: forward");
        }
        boolean dropout = (defaultsMask & 4) == 0 && useDropout;
        return stackedLayers.forward(input, stackedLayersContributions, dropout);
    }

    /**
     * Assigns the i-th element of {@code input} to the i-th input of the first layer.
     *
     * @param input the input arrays
     * @throws IllegalArgumentException if the first layer is not a Merge layer
     */
    @SuppressWarnings("unchecked")
    private void setMergeLayerInputs(@NotNull List<? extends InputNDArrayType> input) {
        Layer<InputNDArrayType> inputLayer = getInputLayer();
        if (!(inputLayer instanceof MergeLayer)) {
            throw new IllegalArgumentException("Cannot call the forward with multiple inputs if the first layer is not a Merge layer.");
        }
        MergeLayer<InputNDArrayType> mergeLayer = (MergeLayer<InputNDArrayType>) inputLayer;
        int index = 0;
        for (InputNDArrayType array : input) {
            mergeLayer.setInput(index++, array);
        }
    }

    /**
     * Runs the forward of every layer in order, updating {@link #getCurLayerIndex()} before each.
     *
     * @param contributions per-layer contributions, or {@code null} to use the plain forward
     * @param useDropout whether the layers must apply dropout
     * @return the output values of the last layer
     */
    @NotNull
    private DenseNDArray forwardAllLayers(@Nullable StackedLayersParameters contributions, boolean useDropout) {
        int index = 0;
        for (Layer<?> layer : this.layers) {
            this.curLayerIndex = index;
            if (contributions == null) {
                layer.forward(useDropout);
            } else {
                layer.forward(contributions.getParamsPerLayer().get(index), useDropout);
            }
            index++;
        }
        return getOutputLayer().getOutputArray().getValues();
    }

    /**
     * Runs the backward of all the layers, from the last to the first.
     *
     * @param outputErrors the errors of the output layer
     * @param propagateToInput whether the first layer must propagate the errors to its input
     * @return the parameters errors of all the layers, collected from the last layer to the first
     */
    @NotNull
    public final List<ParamsArray.Errors<?>> backward(@NotNull DenseNDArray outputErrors, boolean propagateToInput) {
        Intrinsics.checkParameterIsNotNull(outputErrors, "outputErrors");
        getOutputLayer().setErrors(outputErrors);
        List<ParamsArray.Errors<?>> paramsErrors = new ArrayList<>();
        for (int i = this.layers.size() - 1; i >= 0; i--) {
            this.curLayerIndex = i;
            // Every layer except the first is always asked to propagate to its input; only the
            // first layer honours the caller's propagateToInput flag.
            paramsErrors.addAll(this.layers.get(i).backward(i > 0 || propagateToInput));
        }
        return paramsErrors;
    }

    /** Kotlin default-argument bridge for {@link #backward(DenseNDArray, boolean)} ({@code propagateToInput = false}). */
    @NotNull
    @SuppressWarnings("rawtypes")
    public static /* synthetic */ List backward$default(StackedLayers stackedLayers, DenseNDArray outputErrors, boolean propagateToInput, int defaultsMask, Object superMarker) {
        if (superMarker != null) {
            throw new UnsupportedOperationException("Super calls with default arguments not supported in this target, function: backward");
        }
        boolean propagate = (defaultsMask & 2) == 0 && propagateToInput;
        return stackedLayers.backward(outputErrors, propagate);
    }

    /**
     * Propagates the relevance from the output layer down to the input of the first layer.
     *
     * @param stackedLayersContributions the contributions; its i-th element goes to the i-th layer
     * @param relevantOutcomesDistribution the distribution set as output relevance of the last layer
     * @throws IllegalArgumentException if any layer is not a {@link FeedforwardLayer}
     */
    public final void propagateRelevance(@NotNull StackedLayersParameters stackedLayersContributions, @NotNull DistributionArray relevantOutcomesDistribution) {
        Intrinsics.checkParameterIsNotNull(stackedLayersContributions, "stackedLayersContributions");
        Intrinsics.checkParameterIsNotNull(relevantOutcomesDistribution, "relevantOutcomesDistribution");
        for (Layer<?> layer : this.layers) {
            if (!(layer instanceof FeedforwardLayer)) {
                throw new IllegalArgumentException("The relevance propagation requires that all the layers must be feed-forward.");
            }
        }
        CollectionsKt.last(this.layers).setOutputRelevance(relevantOutcomesDistribution);
        for (int i = this.layers.size() - 1; i >= 0; i--) {
            this.curLayerIndex = i;
            // Safe cast: every layer was verified to be a FeedforwardLayer above.
            ((FeedforwardLayer<?>) this.layers.get(i)).setInputRelevance(stackedLayersContributions.getParamsPerLayer().get(i));
        }
    }

    /**
     * Returns the previous state of the context window.
     *
     * @return always {@code null}: stacked layers have no recurrent context
     */
    @Override // com.kotlinnlp.simplednn.core.layers.models.recurrent.LayerContextWindow
    @Nullable
    public Layer<?> getPrevState() {
        return null;
    }

    /**
     * Returns the next state of the context window.
     *
     * @return always {@code null}: stacked layers have no recurrent context
     */
    @Override // com.kotlinnlp.simplednn.core.layers.models.recurrent.LayerContextWindow
    @Nullable
    public Layer<?> getNextState() {
        return null;
    }

    /**
     * Sets the same parameters-errors collector on every layer.
     *
     * @param c the collector to set
     */
    public final void setParamsErrorsCollector(@NotNull ParamsErrorsCollector c) {
        Intrinsics.checkParameterIsNotNull(c, "c");
        for (Layer<?> layer : this.layers) {
            layer.setParamsErrorsCollector(c);
        }
    }

    /**
     * Builds the layers from {@link #getLayersConfiguration()} and {@link #getParamsPerLayer()}.
     *
     * <p>Declared {@code protected} in the compiled class. It is called by the constructor after
     * both configuration fields are assigned.
     *
     * @return the {@code layersConfiguration.size() - 1} layers, from the input to the output
     * @throws IllegalArgumentException if fewer than two configurations are given, if any
     *     configuration after the first is not dense, or if any configuration after the second
     *     has a Merge connection
     */
    /* JADX INFO: Access modifiers changed from: protected */
    @NotNull
    public final List<Layer<?>> buildLayers() {
        List<LayerInterface> configs = this.layersConfiguration;
        // At least an input and an output configuration are needed to build one layer. Without
        // this check subList(...) below would fail with an opaque range error.
        if (configs.size() < 2) {
            throw new IllegalArgumentException("At least two layers configurations are required (input and output), found: " + configs.size());
        }
        for (LayerInterface config : configs.subList(1, configs.size())) {
            if (config.getType() != LayerType.Input.Dense) {
                throw new IllegalArgumentException("The last layers must be dense.");
            }
        }
        for (LayerInterface config : configs.subList(2, configs.size())) {
            LayerType.Connection connectionType = config.getConnectionType();
            if (connectionType == null) {
                Intrinsics.throwNpe();
            }
            if (connectionType.getProperty() == LayerType.Property.Merge) {
                throw new IllegalArgumentException("Only the first layer can be a Merge layer.");
            }
        }
        int layersCount = configs.size() - 1;
        List<Layer<?>> built = new ArrayList<>(layersCount);
        Layer<?> previous = LayerFactory.INSTANCE.invoke(configs.get(0), configs.get(1), this.paramsPerLayer.get(0), this);
        built.add(previous);
        for (int i = 1; i < layersCount; i++) {
            // Each following layer reads the output array of the previous one as a dense input;
            // its dropout comes from configuration i, its output from configuration i + 1.
            previous = LayerFactory.INSTANCE.invoke(
                    CollectionsKt.listOf(previous.getOutputArray()),
                    LayerType.Input.Dense,
                    configs.get(i + 1),
                    this.paramsPerLayer.get(i),
                    configs.get(i).getDropout(),
                    this);
            built.add(previous);
        }
        return built;
    }

    /**
     * Returns the layer configurations.
     *
     * @return the configurations this stack was built from
     */
    @NotNull
    public final List<LayerInterface> getLayersConfiguration() {
        return this.layersConfiguration;
    }

    /**
     * Returns the parameters of each layer.
     *
     * @return the parameters, indexed like {@link #getLayers()}
     */
    @NotNull
    public final List<LayerParameters<?>> getParamsPerLayer() {
        return this.paramsPerLayer;
    }

    /**
     * Creates the stack and immediately builds its layers.
     *
     * @param layersConfiguration the layer configurations (at least two)
     * @param paramsPerLayer the parameters of each layer to build
     * @throws IllegalArgumentException if the configuration is invalid (see {@link #buildLayers()})
     */
    @SuppressWarnings("unchecked")
    public StackedLayers(@NotNull List<LayerInterface> layersConfiguration, @NotNull List<? extends LayerParameters<?>> paramsPerLayer) {
        Intrinsics.checkParameterIsNotNull(layersConfiguration, "layersConfiguration");
        Intrinsics.checkParameterIsNotNull(paramsPerLayer, "paramsPerLayer");
        this.layersConfiguration = layersConfiguration;
        // The list is only ever read, so viewing it as List<LayerParameters<?>> is safe.
        this.paramsPerLayer = (List<LayerParameters<?>>) paramsPerLayer;
        this.layers = buildLayers();
    }
}
