package com.kotlinnlp.simplednn.core.layers;

import com.kotlinnlp.simplednn.core.arrays.AugmentedArray;
import com.kotlinnlp.simplednn.core.functionalities.activations.ActivationFunction;
import com.kotlinnlp.simplednn.core.layers.LayerType;
import com.kotlinnlp.simplednn.core.layers.models.feedforward.highway.HighwayLayer;
import com.kotlinnlp.simplednn.core.layers.models.feedforward.simple.FeedforwardLayer;
import com.kotlinnlp.simplednn.core.layers.models.merge.affine.AffineLayer;
import com.kotlinnlp.simplednn.core.layers.models.merge.affine.AffineLayerParameters;
import com.kotlinnlp.simplednn.core.layers.models.merge.avg.AvgLayer;
import com.kotlinnlp.simplednn.core.layers.models.merge.avg.AvgLayerParameters;
import com.kotlinnlp.simplednn.core.layers.models.merge.biaffine.BiaffineLayer;
import com.kotlinnlp.simplednn.core.layers.models.merge.biaffine.BiaffineLayerParameters;
import com.kotlinnlp.simplednn.core.layers.models.merge.concat.ConcatLayer;
import com.kotlinnlp.simplednn.core.layers.models.merge.concat.ConcatLayerParameters;
import com.kotlinnlp.simplednn.core.layers.models.merge.product.ProductLayer;
import com.kotlinnlp.simplednn.core.layers.models.merge.product.ProductLayerParameters;
import com.kotlinnlp.simplednn.core.layers.models.merge.sum.SumLayer;
import com.kotlinnlp.simplednn.core.layers.models.merge.sum.SumLayerParameters;
import com.kotlinnlp.simplednn.core.layers.models.recurrent.LayerContextWindow;
import com.kotlinnlp.simplednn.core.layers.models.recurrent.RecurrentLayerUnit;
import com.kotlinnlp.simplednn.core.layers.models.recurrent.cfn.CFNLayer;
import com.kotlinnlp.simplednn.core.layers.models.recurrent.deltarnn.DeltaRNNLayer;
import com.kotlinnlp.simplednn.core.layers.models.recurrent.gru.GRULayer;
import com.kotlinnlp.simplednn.core.layers.models.recurrent.indrnn.IndRNNLayer;
import com.kotlinnlp.simplednn.core.layers.models.recurrent.lstm.LSTMLayer;
import com.kotlinnlp.simplednn.core.layers.models.recurrent.ran.RANLayer;
import com.kotlinnlp.simplednn.core.layers.models.recurrent.simple.SimpleRecurrentLayer;
import com.kotlinnlp.simplednn.simplemath.ndarray.NDArray;
import com.kotlinnlp.simplednn.simplemath.ndarray.dense.DenseNDArray;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import kotlin.Metadata;
import kotlin.NoWhenBranchMatchedException;
import kotlin.collections.CollectionsKt;
import kotlin.jvm.internal.Intrinsics;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;

/* compiled from: LayerFactory.kt */
@Metadata(mv = {1, 1, 13}, bv = {1, 0, 3}, k = 1, d1 = {"��T\n\u0002\u0018\u0002\n\u0002\u0010��\n\u0002\b\u0002\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n\u0002\b\u0002\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0010 \n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0010\u0006\n��\n\u0002\u0010\b\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\bÆ\u0002\u0018��2\u00020\u0001B\u0007\b\u0002¢\u0006\u0002\u0010\u0002J5\u0010\u0003\u001a\u0006\u0012\u0002\b\u00030\u00042\u0006\u0010\u0005\u001a\u00020\u00062\u0006\u0010\u0007\u001a\u00020\u00062\n\u0010\b\u001a\u0006\u0012\u0002\b\u00030\t2\n\b\u0002\u0010\n\u001a\u0004\u0018\u00010\u000bH\u0086\u0002Jc\u0010\u0003\u001a\b\u0012\u0004\u0012\u0002H\f0\u0004\"\u000e\b��\u0010\f*\b\u0012\u0004\u0012\u0002H\f0\r2\u0012\u0010\u000e\u001a\u000e\u0012\n\u0012\b\u0012\u0004\u0012\u0002H\f0\u00100\u000f2\u0006\u0010\u0011\u001a\u00020\u00122\u0006\u0010\u0007\u001a\u00020\u00062\n\u0010\b\u001a\u0006\u0012\u0002\b\u00030\t2\u0006\u0010\u0013\u001a\u00020\u00142\n\b\u0002\u0010\n\u001a\u0004\u0018\u00010\u000bH\u0086\u0002Jy\u0010\u0003\u001a\b\u0012\u0004\u0012\u0002H\f0\u0004\"\u000e\b��\u0010\f*\b\u0012\u0004\u0012\u0002H\f0\r2\u0012\u0010\u000e\u001a\u000e\u0012\n\u0012\b\u0012\u0004\u0012\u0002H\f0\u00100\u000f2\u0006\u0010\u0011\u001a\u00020\u00122\u0006\u0010\u0015\u001a\u00020\u00162\n\u0010\b\u001a\u0006\u0012\u0002\b\u00030\t2\u0006\u0010\u0017\u001a\u00020\u00182\n\b\u0002\u0010\u0019\u001a\u0004\u0018\u00010\u001a2\b\b\u0002\u0010\u0013\u001a\u00020\u00142\n\b\u0002\u0010\n\u001a\u0004\u0018\u00010\u000bH\u0086\u0002¨\u0006\u001b"}, d2 = {"Lcom/kotlinnlp/simplednn/core/layers/LayerFactory;", "", "()V", "invoke", "Lcom/kotlinnlp/simplednn/core/layers/Layer;", "inputConfiguration", "Lcom/kotlinnlp/simplednn/core/layers/LayerInterface;", "outputConfiguration", "params", "Lcom/kotlinnlp/simplednn/core/layers/LayerParameters;", "contextWindow", 
"Lcom/kotlinnlp/simplednn/core/layers/models/recurrent/LayerContextWindow;", "InputNDArrayType", "Lcom/kotlinnlp/simplednn/simplemath/ndarray/NDArray;", "inputArrays", "", "Lcom/kotlinnlp/simplednn/core/arrays/AugmentedArray;", "inputType", "Lcom/kotlinnlp/simplednn/core/layers/LayerType$Input;", "dropout", "", "outputSize", "", "connectionType", "Lcom/kotlinnlp/simplednn/core/layers/LayerType$Connection;", "activationFunction", "Lcom/kotlinnlp/simplednn/core/functionalities/activations/ActivationFunction;", "simplednn"})
/* loaded from: input_file:com/kotlinnlp/simplednn/core/layers/LayerFactory.class */
public final class LayerFactory {
    /*
     * Factory of layers. It is a singleton: use INSTANCE, since the constructor is private.
     *
     * Layers are built either from an input/output pair of LayerInterface configurations or
     * from explicit input arrays plus an output size and connection type. The switch on the
     * connection type picks the concrete Layer implementation.
     */

    /** The only instance of this factory. */
    public static final LayerFactory INSTANCE = new LayerFactory();

    /**
     * Builds a layer from an input configuration and an output configuration.
     *
     * <p>One new {@link AugmentedArray} is created for each size returned by
     * {@code layerInterface.getSizes()}. The output size, connection type and activation function
     * are read from {@code layerInterface2}. The dropout is read from {@code layerInterface}.
     *
     * @param layerInterface the input configuration
     * @param layerInterface2 the output configuration; its connection type must not be null
     * @param layerParameters the parameters of the layer to build
     * @param layerContextWindow the context window; required (non-null) only by the recurrent
     *     connection types
     * @return the new layer
     * @throws IllegalArgumentException if {@code layerInterface2} has a null connection type
     */
    @NotNull
    public final Layer<?> invoke(@NotNull LayerInterface layerInterface, @NotNull LayerInterface layerInterface2, @NotNull LayerParameters<?> layerParameters, @Nullable LayerContextWindow layerContextWindow) {
        Intrinsics.checkParameterIsNotNull(layerInterface, "inputConfiguration");
        Intrinsics.checkParameterIsNotNull(layerInterface2, "outputConfiguration");
        Intrinsics.checkParameterIsNotNull(layerParameters, "params");

        // Read the connection type once so the value we validate is the value we pass on.
        LayerType.Connection connectionType = layerInterface2.getConnectionType();
        if (connectionType == null) {
            throw new IllegalArgumentException("Output layer configurations must have a not null connectionType");
        }

        LayerType.Input inputType = layerInterface.getType();
        switch (inputType) {
            case Dense:
            case Sparse:
            case SparseBinary:
                // All supported input types are built identically here: one array per declared size.
                return invoke(
                        buildInputArrays(layerInterface),
                        inputType,
                        layerInterface2.getSize(),
                        layerParameters,
                        connectionType,
                        layerInterface2.getActivationFunction(),
                        layerInterface.getDropout(),
                        layerContextWindow);
            default:
                throw new NoWhenBranchMatchedException();
        }
    }

    /**
     * Synthetic bridge for the Kotlin default arguments of
     * {@link #invoke(LayerInterface, LayerInterface, LayerParameters, LayerContextWindow)}.
     * When bit {@code 8} of the mask {@code i} is set, the context window defaults to null.
     */
    @NotNull
    public static /* synthetic */ Layer invoke$default(LayerFactory layerFactory, LayerInterface layerInterface, LayerInterface layerInterface2, LayerParameters layerParameters, LayerContextWindow layerContextWindow, int i, Object obj) {
        if ((i & 8) != 0) {
            layerContextWindow = (LayerContextWindow) null;
        }
        return layerFactory.invoke(layerInterface, layerInterface2, layerParameters, layerContextWindow);
    }

    /**
     * Builds a layer from given input arrays and an output configuration.
     *
     * @param list the input arrays
     * @param layerInterface the output configuration; it provides the output size, the activation
     *     function and the connection type, which must not be null
     * @param layerParameters the parameters of the layer to build
     * @param input the type of the input
     * @param d the dropout
     * @param layerContextWindow the context window; required (non-null) only by the recurrent
     *     connection types
     * @return the new layer
     */
    @NotNull
    public final <InputNDArrayType extends NDArray<InputNDArrayType>> Layer<InputNDArrayType> invoke(@NotNull List<? extends AugmentedArray<InputNDArrayType>> list, @NotNull LayerType.Input input, @NotNull LayerInterface layerInterface, @NotNull LayerParameters<?> layerParameters, double d, @Nullable LayerContextWindow layerContextWindow) {
        Intrinsics.checkParameterIsNotNull(list, "inputArrays");
        Intrinsics.checkParameterIsNotNull(input, "inputType");
        Intrinsics.checkParameterIsNotNull(layerInterface, "outputConfiguration");
        Intrinsics.checkParameterIsNotNull(layerParameters, "params");

        int size = layerInterface.getSize();
        ActivationFunction activationFunction = layerInterface.getActivationFunction();
        LayerType.Connection connectionType = layerInterface.getConnectionType();
        if (connectionType == null) {
            // Fails through Intrinsics.throwNpe(), unlike the IllegalArgumentException thrown by
            // the (LayerInterface, LayerInterface, ...) overload; kept for compatibility.
            Intrinsics.throwNpe();
        }
        return invoke(list, input, size, layerParameters, connectionType, activationFunction, d, layerContextWindow);
    }

    /**
     * Synthetic bridge for the Kotlin default arguments of
     * {@link #invoke(List, LayerType.Input, LayerInterface, LayerParameters, double, LayerContextWindow)}.
     * When bit {@code 32} of the mask {@code i} is set, the context window defaults to null.
     */
    @NotNull
    public static /* synthetic */ Layer invoke$default(LayerFactory layerFactory, List list, LayerType.Input input, LayerInterface layerInterface, LayerParameters layerParameters, double d, LayerContextWindow layerContextWindow, int i, Object obj) {
        if ((i & 32) != 0) {
            layerContextWindow = (LayerContextWindow) null;
        }
        return layerFactory.invoke(list, input, layerInterface, layerParameters, d, layerContextWindow);
    }

    /**
     * Builds the layer that matches the given connection type.
     *
     * <p>The single-input connections (feedforward, highway and all recurrent ones) use only the
     * first input array. Biaffine uses the first two input arrays. Affine, Concat, Sum, Avg and
     * Product receive the whole list. Affine, Biaffine, Concat, Sum, Avg and Product cast
     * {@code layerParameters} to their specific parameters class.
     *
     * @param list the input arrays
     * @param input the type of the input
     * @param i the output size
     * @param layerParameters the parameters of the layer to build
     * @param connection the connection type, which selects the layer implementation
     * @param activationFunction the activation function (may be null; not used by the merge
     *     layers Concat, Sum, Avg and Product)
     * @param d the dropout (not used by the merge layers Concat, Sum, Avg and Product)
     * @param layerContextWindow the context window; must be non-null for the recurrent connection
     *     types
     * @return the new layer
     * @throws NoWhenBranchMatchedException if the connection type is not handled
     */
    @NotNull
    public final <InputNDArrayType extends NDArray<InputNDArrayType>> Layer<InputNDArrayType> invoke(@NotNull List<? extends AugmentedArray<InputNDArrayType>> list, @NotNull LayerType.Input input, int i, @NotNull LayerParameters<?> layerParameters, @NotNull LayerType.Connection connection, @Nullable ActivationFunction activationFunction, double d, @Nullable LayerContextWindow layerContextWindow) {
        Intrinsics.checkParameterIsNotNull(list, "inputArrays");
        Intrinsics.checkParameterIsNotNull(input, "inputType");
        Intrinsics.checkParameterIsNotNull(layerParameters, "params");
        Intrinsics.checkParameterIsNotNull(connection, "connectionType");
        switch (connection) {
            case Feedforward:
                return new FeedforwardLayer((AugmentedArray) CollectionsKt.first(list), input, AugmentedArray.Companion.zeros(i), layerParameters, activationFunction, d, 0, 64, null);
            case Highway:
                return new HighwayLayer((AugmentedArray) CollectionsKt.first(list), input, AugmentedArray.Companion.zeros(i), layerParameters, activationFunction, d, 0, 64, null);
            case Affine:
                return new AffineLayer(list, input, AugmentedArray.Companion.zeros(i), (AffineLayerParameters) layerParameters, activationFunction, d, 0, 64, null);
            case Biaffine:
                // Reads exactly two inputs; fewer than two makes list.get(...) throw.
                return new BiaffineLayer(list.get(0), list.get(1), input, AugmentedArray.Companion.zeros(i), (BiaffineLayerParameters) layerParameters, activationFunction, d, 0, 128, null);
            case Concat:
                return new ConcatLayer(list, input, AugmentedArray.Companion.zeros(i), (ConcatLayerParameters) layerParameters, 0, 16, null);
            case Sum:
                return new SumLayer(list, input, AugmentedArray.Companion.zeros(i), (SumLayerParameters) layerParameters, 0, 16, null);
            case Avg:
                return new AvgLayer(list, input, AugmentedArray.Companion.zeros(i), (AvgLayerParameters) layerParameters, 0, 16, null);
            case Product:
                return new ProductLayer(list, input, AugmentedArray.Companion.zeros(i), (ProductLayerParameters) layerParameters, 0, 16, null);
            // Recurrent layers: arguments are evaluated left to right, so the first input is read and
            // the output array is allocated before the context window is null-checked.
            case SimpleRecurrent:
                return new SimpleRecurrentLayer((AugmentedArray) CollectionsKt.first(list), input, new RecurrentLayerUnit(i), layerParameters, requireContextWindow(layerContextWindow), activationFunction, d);
            case GRU:
                return new GRULayer((AugmentedArray) CollectionsKt.first(list), input, AugmentedArray.Companion.zeros(i), layerParameters, requireContextWindow(layerContextWindow), activationFunction, d);
            case LSTM:
                return new LSTMLayer((AugmentedArray) CollectionsKt.first(list), input, AugmentedArray.Companion.zeros(i), layerParameters, requireContextWindow(layerContextWindow), activationFunction, d);
            case CFN:
                return new CFNLayer((AugmentedArray) CollectionsKt.first(list), input, AugmentedArray.Companion.zeros(i), layerParameters, requireContextWindow(layerContextWindow), activationFunction, d);
            case RAN:
                return new RANLayer((AugmentedArray) CollectionsKt.first(list), input, AugmentedArray.Companion.zeros(i), layerParameters, requireContextWindow(layerContextWindow), activationFunction, d);
            case DeltaRNN:
                return new DeltaRNNLayer((AugmentedArray) CollectionsKt.first(list), input, AugmentedArray.Companion.zeros(i), layerParameters, requireContextWindow(layerContextWindow), activationFunction, d);
            case IndRNN:
                return new IndRNNLayer((AugmentedArray) CollectionsKt.first(list), input, AugmentedArray.Companion.zeros(i), layerParameters, requireContextWindow(layerContextWindow), activationFunction, d);
            default:
                throw new NoWhenBranchMatchedException();
        }
    }

    /**
     * Synthetic bridge for the Kotlin default arguments of
     * {@link #invoke(List, LayerType.Input, int, LayerParameters, LayerType.Connection, ActivationFunction, double, LayerContextWindow)}.
     * Mask bits in {@code i2}: {@code 32} means activation function = null, {@code 64} means
     * dropout = 0.0, and {@code 128} means context window = null.
     */
    @NotNull
    public static /* synthetic */ Layer invoke$default(LayerFactory layerFactory, List list, LayerType.Input input, int i, LayerParameters layerParameters, LayerType.Connection connection, ActivationFunction activationFunction, double d, LayerContextWindow layerContextWindow, int i2, Object obj) {
        if ((i2 & 32) != 0) {
            activationFunction = (ActivationFunction) null;
        }
        if ((i2 & 64) != 0) {
            d = 0.0d;
        }
        if ((i2 & 128) != 0) {
            layerContextWindow = (LayerContextWindow) null;
        }
        return layerFactory.invoke(list, input, i, layerParameters, connection, activationFunction, d, layerContextWindow);
    }

    /**
     * Creates one new {@link AugmentedArray} for each size declared by the given input
     * configuration, in the same order.
     *
     * <p>The list is raw on purpose. It is handed to the generic {@code invoke} overload for every
     * input type through an unchecked conversion, exactly as the former per-type branches did.
     */
    @SuppressWarnings("rawtypes")
    private static List buildInputArrays(LayerInterface inputConfiguration) {
        List<Integer> sizes = inputConfiguration.getSizes();
        List<AugmentedArray> inputArrays = new ArrayList<>(sizes.size());
        for (Integer size : sizes) {
            inputArrays.add(new AugmentedArray(size.intValue()));
        }
        return inputArrays;
    }

    /**
     * Returns the given context window. If it is null, it fails through
     * {@code Intrinsics.throwNpe()}, which is the same failure the recurrent branches raised before.
     */
    private static LayerContextWindow requireContextWindow(@Nullable LayerContextWindow layerContextWindow) {
        if (layerContextWindow == null) {
            Intrinsics.throwNpe();
        }
        return layerContextWindow;
    }

    private LayerFactory() {
    }
}
