package defpackage;

import com.kotlinnlp.simplednn.core.functionalities.activations.Softmax;
import com.kotlinnlp.simplednn.core.functionalities.activations.Tanh;
import com.kotlinnlp.simplednn.core.functionalities.losses.SoftmaxCrossEntropyCalculator;
import com.kotlinnlp.simplednn.core.functionalities.outputevaluation.ClassificationEvaluation;
import com.kotlinnlp.simplednn.core.functionalities.updatemethods.learningrate.LearningRateMethod;
import com.kotlinnlp.simplednn.core.neuralnetwork.NeuralNetwork;
import com.kotlinnlp.simplednn.core.neuralnetwork.preset.CFN;
import com.kotlinnlp.simplednn.core.neuralprocessor.recurrent.RecurrentNeuralProcessor;
import com.kotlinnlp.simplednn.core.optimizer.ParamsOptimizer;
import com.kotlinnlp.simplednn.dataset.Corpus;
import com.kotlinnlp.simplednn.dataset.SequenceExample;
import com.kotlinnlp.simplednn.dataset.Shuffler;
import com.kotlinnlp.simplednn.helpers.training.SequenceTrainingHelper;
import com.kotlinnlp.simplednn.helpers.validation.SequenceValidationHelper;
import com.kotlinnlp.simplednn.helpers.validation.ValidationHelper;
import com.kotlinnlp.simplednn.simplemath.ndarray.dense.DenseNDArray;
import java.util.Arrays;
import kotlin.Metadata;
import kotlin.jvm.internal.Intrinsics;
import org.jetbrains.annotations.NotNull;

/* compiled from: ProgressiveSumTest.kt */
@Metadata(mv = {1, 1, 8}, bv = {1, 0, 2}, k = 1, d1 = {"��(\n\u0002\u0018\u0002\n\u0002\u0010��\n��\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\b\u0004\n\u0002\u0018\u0002\n��\n\u0002\u0010\u0002\n\u0002\b\u0003\u0018��2\u00020\u0001B\u0019\u0012\u0012\u0010\u0002\u001a\u000e\u0012\n\u0012\b\u0012\u0004\u0012\u00020\u00050\u00040\u0003¢\u0006\u0002\u0010\u0006J\b\u0010\u000b\u001a\u00020\fH\u0002J\u0006\u0010\r\u001a\u00020\fJ\b\u0010\u000e\u001a\u00020\fH\u0002R\u001d\u0010\u0002\u001a\u000e\u0012\n\u0012\b\u0012\u0004\u0012\u00020\u00050\u00040\u0003¢\u0006\b\n��\u001a\u0004\b\u0007\u0010\bR\u000e\u0010\t\u001a\u00020\nX\u0082\u0004¢\u0006\u0002\n��¨\u0006\u000f"}, d2 = {"LProgressiveSumTest;", "", "dataset", "Lcom/kotlinnlp/simplednn/dataset/Corpus;", "Lcom/kotlinnlp/simplednn/dataset/SequenceExample;", "Lcom/kotlinnlp/simplednn/simplemath/ndarray/dense/DenseNDArray;", "(Lcom/kotlinnlp/simplednn/dataset/Corpus;)V", "getDataset", "()Lcom/kotlinnlp/simplednn/dataset/Corpus;", "neuralNetwork", "Lcom/kotlinnlp/simplednn/core/neuralnetwork/NeuralNetwork;", "initialValidation", "", "start", "train", "simplednn"})
/* loaded from: input_file:ProgressiveSumTest.class */
public final class ProgressiveSumTest {

    /** Network built once in the constructor and shared by the validation and training phases. */
    private final NeuralNetwork neuralNetwork;

    /** Corpus supplying the training and validation sequence examples. */
    @NotNull
    private final Corpus<SequenceExample<DenseNDArray>> dataset;

    /**
     * Creates the test around the given corpus and builds its network via the {@code CFN} preset.
     *
     * <p>The call goes through Kotlin's synthetic {@code invoke$default} bridge. The bitmask
     * {@code 1638} (bits 1, 2, 5, 6, 9, 10) marks the placeholder arguments at those positions as
     * "use the default". That leaves {@code 1}, {@code 100}, {@code Tanh}, {@code 11} and
     * {@code Softmax} as the only explicitly supplied values.
     * NOTE(review): presumably input size 1, hidden size 100, output size 11. Confirm against the
     * CFN preset signature.
     *
     * @param corpus the dataset (Kotlin parameter name {@code dataset}). Must not be null.
     * @throws IllegalArgumentException if {@code corpus} is null (raised by Kotlin's
     *     {@code Intrinsics.checkParameterIsNotNull})
     */
    public ProgressiveSumTest(@NotNull Corpus<SequenceExample<DenseNDArray>> corpus) {
        Intrinsics.checkParameterIsNotNull(corpus, "dataset");
        this.dataset = corpus;
        Tanh hiddenActivation = new Tanh();
        Softmax outputActivation = new Softmax();
        this.neuralNetwork = CFN.invoke$default(
                CFN.INSTANCE,
                1, null, 0.0d,
                100, hiddenActivation, 0.0d, 0,
                11, outputActivation,
                null, null,
                1638, null);
    }

    /**
     * Returns the corpus this test was created with.
     *
     * @return the dataset passed to the constructor
     */
    @NotNull
    public final Corpus<SequenceExample<DenseNDArray>> getDataset() {
        return this.dataset;
    }

    /** Runs the whole test: first an accuracy check on the untrained network, then training. */
    public final void start() {
        initialValidation();
        train();
    }

    /**
     * Builds a fresh validation helper that scores the shared network with
     * {@code ClassificationEvaluation}. A new {@code RecurrentNeuralProcessor} is created on every
     * call. Its {@code $default} mask {@code 2} leaves the processor's second parameter at its
     * default.
     */
    private SequenceValidationHelper newAccuracyValidator() {
        return new SequenceValidationHelper(
                new RecurrentNeuralProcessor(this.neuralNetwork, 0, 2, null),
                new ClassificationEvaluation());
    }

    /**
     * Validates the network on the validation set before any training and prints the resulting
     * score as a percentage with two decimals.
     * NOTE(review): {@code String.format} uses the JVM default locale, so the decimal separator
     * may vary between machines.
     */
    private void initialValidation() {
        System.out.println("\n-- VALIDATION BEFORE TRAINING\n");
        // Mask 6 makes validate() use its defaults for both trailing placeholders (null, false).
        double score = ValidationHelper.validate$default(
                newAccuracyValidator(), this.dataset.getValidation(), null, false, 6, null);
        System.out.println(String.format("Accuracy: %.2f%%", 100.0d * score));
    }

    /**
     * Trains the shared network on the training set and validates on the validation set through
     * {@code SequenceTrainingHelper}.
     *
     * <p>Settings:
     * <ul>
     *   <li>Optimizer: {@code LearningRateMethod} with rate {@code 0.1}.</li>
     *   <li>Loss: softmax cross-entropy.</li>
     *   <li>Trainer: the 4th constructor argument is defaulted (mask 8) and {@code true} is passed
     *       explicitly as the 5th.</li>
     * </ul>
     * NOTE(review): {@code train(…, 4, 1, …)} presumably means 4 epochs with batch size 1, and
     * {@code Shuffler(true, 1L)} looks like shuffling with a fixed seed. Confirm against
     * SimpleDNN.
     */
    private void train() {
        System.out.println("\n-- TRAINING\n");
        RecurrentNeuralProcessor trainingProcessor =
                new RecurrentNeuralProcessor(this.neuralNetwork, 0, 2, null);
        ParamsOptimizer optimizer = new ParamsOptimizer(
                this.neuralNetwork.getModel(),
                new LearningRateMethod(0.1d, null, null, 6, null));
        SequenceTrainingHelper trainer = new SequenceTrainingHelper(
                trainingProcessor, optimizer, new SoftmaxCrossEntropyCalculator(), null, true, 8, null);
        trainer.train(
                this.dataset.getTraining(),
                4,
                1,
                this.dataset.getValidation(),
                newAccuracyValidator(),
                new Shuffler(true, 1L));
    }
}
