package mnist;

import com.kotlinnlp.simplednn.core.functionalities.activations.Sigmoid;
import com.kotlinnlp.simplednn.core.functionalities.updatemethods.adam.ADAMMethod;
import com.kotlinnlp.simplednn.dataset.BinaryOutputExample;
import com.kotlinnlp.simplednn.dataset.Corpus;
import com.kotlinnlp.simplednn.dataset.Shuffler;
import com.kotlinnlp.simplednn.deeplearning.competitivelearning.CLNetwork;
import com.kotlinnlp.simplednn.deeplearning.competitivelearning.CLNetworkModel;
import com.kotlinnlp.simplednn.deeplearning.competitivelearning.CLNetworkOptimizer;
import com.kotlinnlp.simplednn.helpers.training.CompetitiveLearningTrainingHelper;
import com.kotlinnlp.simplednn.helpers.validation.CompetitiveLearningValidationHelper;
import com.kotlinnlp.simplednn.simplemath.ndarray.dense.DenseNDArray;
import kotlin.Metadata;
import kotlin.collections.SetsKt;
import kotlin.jvm.internal.Intrinsics;
import org.jetbrains.annotations.NotNull;

/* compiled from: MNISTCompetitiveLearningTest.kt */
@Metadata(mv = {1, 1, 8}, bv = {1, 0, 2}, k = 1, d1 = {"��(\n\u0002\u0018\u0002\n\u0002\u0010��\n��\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\b\u0004\n\u0002\u0018\u0002\n��\n\u0002\u0010\u0002\n\u0002\b\u0002\u0018��2\u00020\u0001B\u0019\u0012\u0012\u0010\u0002\u001a\u000e\u0012\n\u0012\b\u0012\u0004\u0012\u00020\u00050\u00040\u0003¢\u0006\u0002\u0010\u0006J\u0006\u0010\u000b\u001a\u00020\fJ\b\u0010\r\u001a\u00020\fH\u0002R\u001d\u0010\u0002\u001a\u000e\u0012\n\u0012\b\u0012\u0004\u0012\u00020\u00050\u00040\u0003¢\u0006\b\n��\u001a\u0004\b\u0007\u0010\bR\u000e\u0010\t\u001a\u00020\nX\u0082\u0004¢\u0006\u0002\n��¨\u0006\u000e"}, d2 = {"Lmnist/MNISTCompetitiveLearningTest;", "", "dataset", "Lcom/kotlinnlp/simplednn/dataset/Corpus;", "Lcom/kotlinnlp/simplednn/dataset/BinaryOutputExample;", "Lcom/kotlinnlp/simplednn/simplemath/ndarray/dense/DenseNDArray;", "(Lcom/kotlinnlp/simplednn/dataset/Corpus;)V", "getDataset", "()Lcom/kotlinnlp/simplednn/dataset/Corpus;", "model", "Lcom/kotlinnlp/simplednn/deeplearning/competitivelearning/CLNetworkModel;", "start", "", "train", "simplednn"})
/* loaded from: input_file:mnist/MNISTCompetitiveLearningTest.class */
public final class MNISTCompetitiveLearningTest {

    /**
     * Corpus used by this test. Its training split is handed to the training helper and its
     * validation split is handed alongside a validation helper (see {@link #train()}).
     */
    @NotNull
    private final Corpus<BinaryOutputExample<DenseNDArray>> dataset;

    /**
     * Competitive-learning model shared by the training network, the optimizer and the
     * validation network created in {@link #train()}.
     */
    private final CLNetworkModel model;

    /**
     * Creates the test for the given corpus and builds a fresh model for it.
     *
     * @param dataset the corpus providing training and validation splits; must not be null
     */
    public MNISTCompetitiveLearningTest(@NotNull Corpus<BinaryOutputExample<DenseNDArray>> dataset) {
        // Kotlin-style non-null guard; the reported parameter name is "dataset".
        Intrinsics.checkParameterIsNotNull(dataset, "dataset");
        this.dataset = dataset;

        // One label per digit 0..9.
        final Integer[] digitLabels = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
        // NOTE(review): 784 and 50 are presumably the input size (28x28 pixels) and a hidden
        // size — confirm against CLNetworkModel. The trailing (null, null, 48, null) arguments
        // look like Kotlin default-argument placeholders and are kept exactly as decompiled.
        this.model = new CLNetworkModel(SetsKt.setOf(digitLabels), 784, 50, new Sigmoid(), null, null, 48, null);
    }

    /**
     * Returns the corpus this test was constructed with.
     *
     * @return the dataset, never null
     */
    @NotNull
    public final Corpus<BinaryOutputExample<DenseNDArray>> getDataset() {
        return this.dataset;
    }

    /** Entry point of the test; it currently only runs {@link #train()}. */
    public final void start() {
        this.train();
    }

    /**
     * Prints a section header, then trains a {@link CLNetwork} built on {@link #model} using an
     * ADAM-based {@link CLNetworkOptimizer}, validating with a second network on the same model.
     */
    private void train() {
        System.out.println("\n-- TRAINING");

        // Statement order mirrors the original nested expression's evaluation order.
        final CLNetwork trainingNetwork = new CLNetwork(this.model);
        // NOTE(review): 0.001 / 0.9 / 0.999 are presumably step size, beta1 and beta2; the
        // remaining (0.0d, null, 24, null) look like Kotlin default-argument placeholders.
        final ADAMMethod adam = new ADAMMethod(0.001d, 0.9d, 0.999d, 0.0d, null, 24, null);
        final CLNetworkOptimizer optimizer = new CLNetworkOptimizer(this.model, adam);
        // NOTE(review): the trailing `true` is presumably a verbosity flag — confirm.
        final CompetitiveLearningTrainingHelper trainer =
                new CompetitiveLearningTrainingHelper(trainingNetwork, optimizer, true);

        trainer.train(
                this.dataset.getTraining(),
                15, // presumably the number of epochs — confirm
                1,  // presumably the batch size — confirm
                this.dataset.getValidation(),
                new CompetitiveLearningValidationHelper(new CLNetwork(this.model)),
                // NOTE(review): presumably (pseudo-random enabled, seed 1) — confirm.
                new Shuffler(true, 1L));
    }
}
