package com.kotlinnlp.simplednn.core.functionalities.losses;

import com.kotlinnlp.simplednn.core.functionalities.losses.LossCalculator;
import com.kotlinnlp.simplednn.simplemath.ndarray.NDArray;
import com.kotlinnlp.simplednn.simplemath.ndarray.dense.DenseNDArray;
import java.util.List;
import kotlin.Metadata;
import kotlin.jvm.internal.DefaultConstructorMarker;
import kotlin.jvm.internal.Intrinsics;
import org.jetbrains.annotations.NotNull;

/* compiled from: SoftmaxCrossEntropyCalculator.kt */
@Metadata(mv = {1, 1, 13}, bv = {1, 0, 3}, k = 1, d1 = {"��\u001c\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\b\u0002\n\u0002\u0018\u0002\n\u0002\b\u0003\n\u0002\u0010\b\n\u0002\b\u0003\b\u0016\u0018�� \n2\u00020\u0001:\u0001\nB\u0005¢\u0006\u0002\u0010\u0002J\u0018\u0010\u0003\u001a\u00020\u00042\u0006\u0010\u0005\u001a\u00020\u00042\u0006\u0010\u0006\u001a\u00020\u0004H\u0016J\u0016\u0010\u0003\u001a\u00020\u00042\u0006\u0010\u0005\u001a\u00020\u00042\u0006\u0010\u0007\u001a\u00020\bJ\u0018\u0010\t\u001a\u00020\u00042\u0006\u0010\u0005\u001a\u00020\u00042\u0006\u0010\u0006\u001a\u00020\u0004H\u0016¨\u0006\u000b"}, d2 = {"Lcom/kotlinnlp/simplednn/core/functionalities/losses/SoftmaxCrossEntropyCalculator;", "Lcom/kotlinnlp/simplednn/core/functionalities/losses/LossCalculator;", "()V", "calculateErrors", "Lcom/kotlinnlp/simplednn/simplemath/ndarray/dense/DenseNDArray;", "output", "outputGold", "goldIndex", "", "calculateLoss", "Companion", "simplednn"})
/* loaded from: input_file:com/kotlinnlp/simplednn/core/functionalities/losses/SoftmaxCrossEntropyCalculator.class */
/*
 * Loss calculator pairing a softmax output with the cross-entropy loss.
 * The loss requires a one-hot gold vector. The errors are expressed as
 * (output - gold), either against a full gold array or against a single
 * gold index.
 */
public class SoftmaxCrossEntropyCalculator implements LossCalculator {
    /* Lower bound applied to the predicted gold probability before the logarithm, keeping the loss finite. */
    private static final double EPS = 1.0E-8d;
    public static final Companion Companion = new Companion(null);

    /* compiled from: SoftmaxCrossEntropyCalculator.kt */
    @Metadata(mv = {1, 1, 13}, bv = {1, 0, 3}, k = 1, d1 = {"��\u0012\n\u0002\u0018\u0002\n\u0002\u0010��\n\u0002\b\u0002\n\u0002\u0010\u0006\n��\b\u0086\u0003\u0018��2\u00020\u0001B\u0007\b\u0002¢\u0006\u0002\u0010\u0002R\u000e\u0010\u0003\u001a\u00020\u0004X\u0082T¢\u0006\u0002\n��¨\u0006\u0005"}, d2 = {"Lcom/kotlinnlp/simplednn/core/functionalities/losses/SoftmaxCrossEntropyCalculator$Companion;", "", "()V", "EPS", "", "simplednn"})
    /* loaded from: input_file:com/kotlinnlp/simplednn/core/functionalities/losses/SoftmaxCrossEntropyCalculator$Companion.class */
    public static final class Companion {
        private Companion() {
        }

        public /* synthetic */ Companion(DefaultConstructorMarker defaultConstructorMarker) {
            this();
        }
    }

    /**
     * Computes the cross-entropy loss of a prediction against a one-hot gold vector.
     *
     * <p>The result is built from {@code output.zerosLike()}. Only the gold position
     * (the arg-max of {@code outputGold}) is then set, to {@code -log(p)}, where
     * {@code p} is the predicted value at that position, clipped from below to
     * {@code EPS}.
     *
     * @param output the predicted values (presumably softmax probabilities; not checked here)
     * @param outputGold the gold vector; must satisfy {@code isOneHotEncoder()}
     * @return an array shaped like {@code output} holding the loss at the gold index
     * @throws IllegalArgumentException if {@code outputGold} is not a one-hot encoder
     */
    @Override // com.kotlinnlp.simplednn.core.functionalities.losses.LossCalculator
    @NotNull
    public DenseNDArray calculateLoss(@NotNull DenseNDArray output, @NotNull DenseNDArray outputGold) {
        Intrinsics.checkParameterIsNotNull(output, "output");
        Intrinsics.checkParameterIsNotNull(outputGold, "outputGold");
        if (!outputGold.isOneHotEncoder()) {
            throw new IllegalArgumentException("The gold output must be a one hot encoder to calculate the loss with the cross-entropy function.");
        }

        // outputGold is one-hot, so its arg-max is the gold class.
        // The call goes through Kotlin's default-argument bridge; its arguments are kept as compiled.
        int goldIndex = NDArray.DefaultImpls.argMaxIndex$default(outputGold, 0, 1, null);
        double predicted = output.get(goldIndex).doubleValue();
        DenseNDArray loss = output.zerosLike();

        // Clip from below. Any value failing `>= EPS` (including NaN) falls back to EPS.
        double clipped = EPS;
        if (predicted >= EPS) {
            clipped = predicted;
        }

        loss.set(goldIndex, Double.valueOf(-Math.log(clipped)));
        return loss;
    }

    /**
     * Computes the errors as the difference {@code output - outputGold}.
     *
     * <p>This is presumably the usual combined softmax + cross-entropy gradient
     * with respect to the softmax input. Unlike {@link #calculateLoss}, no
     * one-hot check is performed on the gold array here.
     *
     * @param output the predicted values
     * @param outputGold the gold values
     * @return the result of {@code output.sub(outputGold)}
     */
    @Override // com.kotlinnlp.simplednn.core.functionalities.losses.LossCalculator
    @NotNull
    public DenseNDArray calculateErrors(@NotNull DenseNDArray output, @NotNull DenseNDArray outputGold) {
        Intrinsics.checkParameterIsNotNull(output, "output");
        Intrinsics.checkParameterIsNotNull(outputGold, "outputGold");
        DenseNDArray errors = output.sub(outputGold);
        return errors;
    }

    /**
     * Computes the errors against a gold class given by its index.
     *
     * <p>The result equals {@code output} minus a one-hot vector at {@code goldIndex}.
     * It is computed on a copy, so {@code output} itself is left unmodified.
     *
     * @param output the predicted values
     * @param goldIndex the index of the gold class
     * @return a copy of {@code output} with {@code 1.0} subtracted at {@code goldIndex}
     */
    @NotNull
    public final DenseNDArray calculateErrors(@NotNull DenseNDArray output, int goldIndex) {
        Intrinsics.checkParameterIsNotNull(output, "output");
        DenseNDArray errors = output.copy();
        double shifted = errors.get(goldIndex).doubleValue() - 1.0d;
        errors.set(goldIndex, Double.valueOf(shifted));
        return errors;
    }

    /**
     * Computes the errors of a whole sequence by delegating to the
     * {@link LossCalculator} default implementation.
     *
     * @param outputSequence the predicted arrays
     * @param outputGoldSequence the gold arrays
     * @return whatever the {@code LossCalculator} default implementation returns
     */
    @Override // com.kotlinnlp.simplednn.core.functionalities.losses.LossCalculator
    @NotNull
    public List<DenseNDArray> calculateErrors(@NotNull List<DenseNDArray> outputSequence, @NotNull List<DenseNDArray> outputGoldSequence) {
        Intrinsics.checkParameterIsNotNull(outputSequence, "outputSequence");
        Intrinsics.checkParameterIsNotNull(outputGoldSequence, "outputGoldSequence");
        List<DenseNDArray> sequenceErrors = LossCalculator.DefaultImpls.calculateErrors(this, outputSequence, outputGoldSequence);
        return sequenceErrors;
    }

    /**
     * Computes the mean loss of a whole sequence by delegating to the
     * {@link LossCalculator} default implementation.
     *
     * @param outputSequence the predicted arrays
     * @param outputGoldSequence the gold arrays
     * @return the value computed by the {@code LossCalculator} default implementation
     */
    @Override // com.kotlinnlp.simplednn.core.functionalities.losses.LossCalculator
    public double calculateMeanLoss(@NotNull List<DenseNDArray> outputSequence, @NotNull List<DenseNDArray> outputGoldSequence) {
        Intrinsics.checkParameterIsNotNull(outputSequence, "outputSequence");
        Intrinsics.checkParameterIsNotNull(outputGoldSequence, "outputGoldSequence");
        double meanLoss = LossCalculator.DefaultImpls.calculateMeanLoss(this, outputSequence, outputGoldSequence);
        return meanLoss;
    }
}
