package com.kotlinnlp.simplednn.core.neuralprocessor;

import com.kotlinnlp.simplednn.core.arrays.ParamsArray;
import com.kotlinnlp.simplednn.core.neuralprocessor.NeuralProcessor;
import com.kotlinnlp.simplednn.core.optimizer.ParamsOptimizer;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Iterator;
import java.util.List;
import kotlin.Metadata;
import kotlin.collections.CollectionsKt;
import kotlin.jvm.internal.DefaultConstructorMarker;
import kotlin.jvm.internal.Intrinsics;
import org.jetbrains.annotations.NotNull;

/* compiled from: ChainProcessor.kt */
@Metadata(mv = {1, 1, 13}, bv = {1, 0, 3}, k = 1, d1 = {"��F\n\u0002\u0018\u0002\n��\n\u0002\u0010��\n\u0002\b\u0004\n\u0002\u0018\u0002\n\u0002\b\u0002\n\u0002\u0010 \n\u0002\b\u0002\n\u0002\u0010\b\n\u0002\b\t\n\u0002\u0010\u000b\n\u0002\b\u0005\n\u0002\u0010\u0002\n\u0002\b\t\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\b\u0003\u0018��*\n\b��\u0010\u0001 ��*\u00020\u0002*\n\b\u0001\u0010\u0003 \u0001*\u00020\u0002*\n\b\u0002\u0010\u0004 ��*\u00020\u0002*\n\b\u0003\u0010\u0005 \u0001*\u00020\u0002*\b\b\u0004\u0010\u0006*\u00020\u00022\u001a\u0012\u0004\u0012\u0002H\u0001\u0012\u0004\u0012\u0002H\u0003\u0012\u0004\u0012\u0002H\u0004\u0012\u0004\u0012\u0002H\u00050\u0007Bu\u0012\u001e\u0010\b\u001a\u001a\u0012\u0004\u0012\u00028��\u0012\u0004\u0012\u00028\u0004\u0012\u0004\u0012\u00028\u0004\u0012\u0004\u0012\u00028\u00030\u0007\u0012$\u0010\t\u001a \u0012\u001c\u0012\u001a\u0012\u0004\u0012\u00028\u0004\u0012\u0004\u0012\u00028\u0004\u0012\u0004\u0012\u00028\u0004\u0012\u0004\u0012\u00028\u00040\u00070\n\u0012\u001e\u0010\u000b\u001a\u001a\u0012\u0004\u0012\u00028\u0004\u0012\u0004\u0012\u00028\u0001\u0012\u0004\u0012\u00028\u0002\u0012\u0004\u0012\u00028\u00040\u0007\u0012\b\b\u0002\u0010\f\u001a\u00020\r¢\u0006\u0002\u0010\u000eJ\u0015\u0010\u001c\u001a\u00020\u001d2\u0006\u0010\u001e\u001a\u00028\u0002H\u0016¢\u0006\u0002\u0010\u001fJ\u0015\u0010 \u001a\u00028\u00012\u0006\u0010!\u001a\u00028��H\u0016¢\u0006\u0002\u0010\"J\u0015\u0010#\u001a\u00028\u00032\u0006\u0010$\u001a\u00020\u0017H\u0016¢\u0006\u0002\u0010%J\"\u0010&\u001a\u0014\u0012\f\u0012\n\u0012\u0002\b\u00030'R\u00020(0\nj\u0002`)2\u0006\u0010$\u001a\u00020\u0017H\u0016J7\u0010*\u001a\u00028\u0004* \u0012\u001c\u0012\u001a\u0012\u0004\u0012\u00028\u0004\u0012\u0004\u0012\u00028\u0004\u0012\u0004\u0012\u00028\u0004\u0012\u0004\u0012\u00028\u00040\u00070\n2\u0006\u0010\u001e\u001a\u00028\u0004H\u0002¢\u0006\u0002\u0010+J7\u0010 \u001a\u00028\u0004* 
\u0012\u001c\u0012\u001a\u0012\u0004\u0012\u00028\u0004\u0012\u0004\u0012\u00028\u0004\u0012\u0004\u0012\u00028\u0004\u0012\u0004\u0012\u00028\u00040\u00070\n2\u0006\u0010!\u001a\u00028\u0004H\u0002¢\u0006\u0002\u0010+J<\u0010&\u001a\u0014\u0012\f\u0012\n\u0012\u0002\b\u00030'R\u00020(0\nj\u0002`)*\u0018\u0012\u0014\u0012\u0012\u0012\u0002\b\u0003\u0012\u0002\b\u0003\u0012\u0002\b\u0003\u0012\u0002\b\u00030\u00070\n2\u0006\u0010$\u001a\u00020\u0017H\u0002R/\u0010\t\u001a \u0012\u001c\u0012\u001a\u0012\u0004\u0012\u00028\u0004\u0012\u0004\u0012\u00028\u0004\u0012\u0004\u0012\u00028\u0004\u0012\u0004\u0012\u00028\u00040\u00070\n¢\u0006\b\n��\u001a\u0004\b\u000f\u0010\u0010R\u0014\u0010\f\u001a\u00020\rX\u0096\u0004¢\u0006\b\n��\u001a\u0004\b\u0011\u0010\u0012R)\u0010\b\u001a\u001a\u0012\u0004\u0012\u00028��\u0012\u0004\u0012\u00028\u0004\u0012\u0004\u0012\u00028\u0004\u0012\u0004\u0012\u00028\u00030\u0007¢\u0006\b\n��\u001a\u0004\b\u0013\u0010\u0014R)\u0010\u000b\u001a\u001a\u0012\u0004\u0012\u00028\u0004\u0012\u0004\u0012\u00028\u0001\u0012\u0004\u0012\u00028\u0002\u0012\u0004\u0012\u00028\u00040\u0007¢\u0006\b\n��\u001a\u0004\b\u0015\u0010\u0014R\u0014\u0010\u0016\u001a\u00020\u0017X\u0096\u0004¢\u0006\b\n��\u001a\u0004\b\u0018\u0010\u0019R\u0014\u0010\u001a\u001a\u00020\u0017X\u0096\u0004¢\u0006\b\n��\u001a\u0004\b\u001b\u0010\u0019¨\u0006,"}, d2 = {"Lcom/kotlinnlp/simplednn/core/neuralprocessor/ChainProcessor;", "InputType", "", "OutputType", "ErrorsType", "InputErrorsType", "HiddenIOType", "Lcom/kotlinnlp/simplednn/core/neuralprocessor/NeuralProcessor;", "inputProcessor", "hiddenProcessors", "", "outputProcessor", "id", "", "(Lcom/kotlinnlp/simplednn/core/neuralprocessor/NeuralProcessor;Ljava/util/List;Lcom/kotlinnlp/simplednn/core/neuralprocessor/NeuralProcessor;I)V", "getHiddenProcessors", "()Ljava/util/List;", "getId", "()I", "getInputProcessor", "()Lcom/kotlinnlp/simplednn/core/neuralprocessor/NeuralProcessor;", "getOutputProcessor", "propagateToInput", "", 
"getPropagateToInput", "()Z", "useDropout", "getUseDropout", "backward", "", "outputErrors", "(Ljava/lang/Object;)V", "forward", "input", "(Ljava/lang/Object;)Ljava/lang/Object;", "getInputErrors", "copy", "(Z)Ljava/lang/Object;", "getParamsErrors", "Lcom/kotlinnlp/simplednn/core/arrays/ParamsArray$Errors;", "Lcom/kotlinnlp/simplednn/core/arrays/ParamsArray;", "Lcom/kotlinnlp/simplednn/core/optimizer/ParamsErrorsList;", "backwardAndGetInputErrors", "(Ljava/util/List;Ljava/lang/Object;)Ljava/lang/Object;", "simplednn"})
/* loaded from: input_file:com/kotlinnlp/simplednn/core/neuralprocessor/ChainProcessor.class */
/**
 * A {@link NeuralProcessor} that chains an input processor, zero or more hidden processors and an
 * output processor.
 *
 * <p>The forward pass feeds the input through the input processor, then through each hidden
 * processor in list order, then through the output processor. The backward pass walks the same
 * chain in reverse, handing each processor the input errors of the processor after it.
 *
 * @param <InputType> the input type of the whole chain (consumed by the input processor)
 * @param <OutputType> the output type of the whole chain (produced by the output processor)
 * @param <ErrorsType> the type of the errors given to {@link #backward}
 * @param <InputErrorsType> the type of the input errors returned by {@link #getInputErrors}
 * @param <HiddenIOType> the type flowing between the input, hidden and output processors
 */
public final class ChainProcessor<InputType, OutputType, ErrorsType, InputErrorsType, HiddenIOType> implements NeuralProcessor<InputType, OutputType, ErrorsType, InputErrorsType> {

    /** True when at least one processor of the chain uses dropout. */
    private final boolean useDropout;

    /** Mirrors the {@code propagateToInput} flag of the input processor. */
    private final boolean propagateToInput;

    @NotNull
    private final NeuralProcessor<InputType, HiddenIOType, HiddenIOType, InputErrorsType> inputProcessor;

    @NotNull
    private final List<NeuralProcessor<HiddenIOType, HiddenIOType, HiddenIOType, HiddenIOType>> hiddenProcessors;

    @NotNull
    private final NeuralProcessor<HiddenIOType, OutputType, ErrorsType, HiddenIOType> outputProcessor;

    private final int id;

    /**
     * Builds a chain of processors.
     *
     * @param inputProcessor the first processor of the chain
     * @param hiddenProcessors the intermediate processors, applied in list order during the forward pass
     * @param outputProcessor the last processor of the chain
     * @param id the identifier of this processor
     */
    @SuppressWarnings("unchecked")
    public ChainProcessor(
            @NotNull NeuralProcessor<? super InputType, ? extends HiddenIOType, ? super HiddenIOType, ? extends InputErrorsType> inputProcessor,
            @NotNull List<? extends NeuralProcessor<? super HiddenIOType, ? extends HiddenIOType, ? super HiddenIOType, ? extends HiddenIOType>> hiddenProcessors,
            @NotNull NeuralProcessor<? super HiddenIOType, ? extends OutputType, ? super ErrorsType, ? extends HiddenIOType> outputProcessor,
            int id) {
        Intrinsics.checkParameterIsNotNull(inputProcessor, "inputProcessor");
        Intrinsics.checkParameterIsNotNull(hiddenProcessors, "hiddenProcessors");
        Intrinsics.checkParameterIsNotNull(outputProcessor, "outputProcessor");

        // The constructor exposes Kotlin's variance as Java wildcards, while the fields use exact
        // type arguments; these unchecked casts make that narrowing explicit (the decompiled code
        // assigned them directly, which javac rejects).
        this.inputProcessor = (NeuralProcessor<InputType, HiddenIOType, HiddenIOType, InputErrorsType>) inputProcessor;
        this.hiddenProcessors = (List<NeuralProcessor<HiddenIOType, HiddenIOType, HiddenIOType, HiddenIOType>>) (List<?>) hiddenProcessors;
        this.outputProcessor = (NeuralProcessor<HiddenIOType, OutputType, ErrorsType, HiddenIOType>) outputProcessor;
        this.id = id;

        // Each final field is assigned exactly once. The decompiled version fell through and always
        // overwrote useDropout with true.
        this.useDropout = this.inputProcessor.getUseDropout()
                || anyUsesDropout(this.hiddenProcessors)
                || this.outputProcessor.getUseDropout();
        this.propagateToInput = this.inputProcessor.getPropagateToInput();
    }

    /**
     * Synthetic constructor generated by the Kotlin compiler for default arguments.
     *
     * <p>When bit {@code 8} of {@code mask} is set, {@code id} is replaced by its default value {@code 0}.
     */
    @SuppressWarnings({"unchecked", "rawtypes"})
    public /* synthetic */ ChainProcessor(NeuralProcessor inputProcessor, List hiddenProcessors, NeuralProcessor outputProcessor, int id, int mask, DefaultConstructorMarker marker) {
        this(inputProcessor, hiddenProcessors, outputProcessor, (mask & 8) != 0 ? 0 : id);
    }

    @Override
    public boolean getUseDropout() {
        return this.useDropout;
    }

    @Override
    public boolean getPropagateToInput() {
        return this.propagateToInput;
    }

    /**
     * Runs the forward pass through the whole chain.
     *
     * @param input the input of the input processor
     * @return the output of the output processor
     */
    @Override
    @NotNull
    public OutputType forward(@NotNull InputType input) {
        Intrinsics.checkParameterIsNotNull(input, "input");
        HiddenIOType hidden = forward(this.hiddenProcessors, this.inputProcessor.forward(input));
        return this.outputProcessor.forward(hidden);
    }

    /**
     * Runs the backward pass: output processor first, then the hidden processors in reverse order,
     * then the input processor.
     *
     * @param outputErrors the errors of the chain output
     */
    @Override
    public void backward(@NotNull ErrorsType outputErrors) {
        Intrinsics.checkParameterIsNotNull(outputErrors, "outputErrors");
        this.outputProcessor.backward(outputErrors);
        // `false` is the `copy` argument of getInputErrors.
        HiddenIOType hiddenErrors = backwardAndGetInputErrors(this.hiddenProcessors, this.outputProcessor.getInputErrors(false));
        this.inputProcessor.backward(hiddenErrors);
    }

    /**
     * Collects the params errors of every processor, in chain order (input, hidden..., output).
     *
     * @param copy forwarded unchanged to each processor's {@code getParamsErrors}
     * @return a new list with the concatenated params errors
     */
    @Override
    @NotNull
    public List<ParamsArray.Errors<?>> getParamsErrors(boolean copy) {
        List<ParamsArray.Errors<?>> errors = new ArrayList<>(this.inputProcessor.getParamsErrors(copy));
        errors.addAll(getParamsErrors(this.hiddenProcessors, copy));
        errors.addAll(this.outputProcessor.getParamsErrors(copy));
        return errors;
    }

    /**
     * Returns the input errors of the input processor, i.e. the input errors of the whole chain.
     *
     * @param copy forwarded unchanged to the input processor
     */
    @Override
    @NotNull
    public InputErrorsType getInputErrors(boolean copy) {
        return this.inputProcessor.getInputErrors(copy);
    }

    /**
     * Propagates the given errors using the default {@link NeuralProcessor} implementation.
     *
     * @param flag forwarded unchanged to the default implementation
     */
    @Override
    @NotNull
    @SuppressWarnings("unchecked")
    public InputErrorsType propagateErrors(@NotNull ErrorsType errors, @NotNull ParamsOptimizer optimizer, boolean flag) {
        Intrinsics.checkParameterIsNotNull(errors, "errors");
        Intrinsics.checkParameterIsNotNull(optimizer, "optimizer");
        return (InputErrorsType) NeuralProcessor.DefaultImpls.propagateErrors(this, errors, optimizer, flag);
    }

    @NotNull
    public final NeuralProcessor<InputType, HiddenIOType, HiddenIOType, InputErrorsType> getInputProcessor() {
        return this.inputProcessor;
    }

    @NotNull
    public final List<NeuralProcessor<HiddenIOType, HiddenIOType, HiddenIOType, HiddenIOType>> getHiddenProcessors() {
        return this.hiddenProcessors;
    }

    @NotNull
    public final NeuralProcessor<HiddenIOType, OutputType, ErrorsType, HiddenIOType> getOutputProcessor() {
        return this.outputProcessor;
    }

    public int getId() {
        return this.id;
    }

    /** Returns true if at least one of the given processors uses dropout (false for an empty list). */
    private static boolean anyUsesDropout(@NotNull List<? extends NeuralProcessor<?, ?, ?, ?>> processors) {
        for (NeuralProcessor<?, ?, ?, ?> processor : processors) {
            if (processor.getUseDropout()) {
                return true;
            }
        }
        return false;
    }

    /** Feeds {@code input} through the given processors in list order and returns the last output. */
    private HiddenIOType forward(
            @NotNull List<NeuralProcessor<HiddenIOType, HiddenIOType, HiddenIOType, HiddenIOType>> processors,
            HiddenIOType input) {
        HiddenIOType current = input;
        for (NeuralProcessor<HiddenIOType, HiddenIOType, HiddenIOType, HiddenIOType> processor : processors) {
            current = processor.forward(current);
        }
        return current;
    }

    /**
     * Runs the backward pass of the given processors in reverse list order, passing each one the
     * input errors of the processor after it, and returns the input errors of the first one
     * (or {@code outputErrors} itself when the list is empty).
     */
    private HiddenIOType backwardAndGetInputErrors(
            @NotNull List<NeuralProcessor<HiddenIOType, HiddenIOType, HiddenIOType, HiddenIOType>> processors,
            HiddenIOType outputErrors) {
        HiddenIOType errors = outputErrors;
        // asReversed is a view, so no copy of the list is made.
        for (NeuralProcessor<HiddenIOType, HiddenIOType, HiddenIOType, HiddenIOType> processor : CollectionsKt.asReversed(processors)) {
            processor.backward(errors);
            errors = processor.getInputErrors(false);
        }
        return errors;
    }

    /** Concatenates the params errors of the given processors, in list order. */
    private static List<ParamsArray.Errors<?>> getParamsErrors(@NotNull List<? extends NeuralProcessor<?, ?, ?, ?>> processors, boolean copy) {
        List<ParamsArray.Errors<?>> errors = new ArrayList<>();
        for (NeuralProcessor<?, ?, ?, ?> processor : processors) {
            errors.addAll(processor.getParamsErrors(copy));
        }
        return errors;
    }
}
