package com.kotlinnlp.simplednn.deeplearning.competitivelearning;

import com.kotlinnlp.simplednn.core.functionalities.losses.MSECalculator;
import com.kotlinnlp.simplednn.core.neuralnetwork.NetworkParameters;
import com.kotlinnlp.simplednn.core.neuralnetwork.NeuralNetwork;
import com.kotlinnlp.simplednn.core.neuralprocessor.feedforward.FeedforwardNeuralProcessor;
import com.kotlinnlp.simplednn.simplemath.SimplemathKt;
import com.kotlinnlp.simplednn.simplemath.ndarray.dense.DenseNDArray;
import java.util.Collections;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Set;
import kotlin.Metadata;
import kotlin.Pair;
import kotlin.TuplesKt;
import kotlin.collections.CollectionsKt;
import kotlin.collections.MapsKt;
import kotlin.jvm.internal.Intrinsics;
import kotlin.ranges.RangesKt;
import org.jetbrains.annotations.NotNull;

/* compiled from: CLNetwork.kt */
@Metadata(mv = {1, 1, 8}, bv = {1, 0, 2}, k = 1, d1 = {"��N\n\u0002\u0018\u0002\n\u0002\u0010��\n��\n\u0002\u0018\u0002\n\u0002\b\u0002\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n\u0002\b\u0003\n\u0002\u0010%\n\u0002\u0010\b\n\u0002\u0010\u0006\n��\n\u0002\u0010$\n\u0002\b\u0004\n\u0002\u0018\u0002\n��\n\u0002\u0010\u000b\n��\n\u0002\u0018\u0002\n\u0002\b\u0006\u0018��2\u00020\u0001B\r\u0012\u0006\u0010\u0002\u001a\u00020\u0003¢\u0006\u0002\u0010\u0004J\u001c\u0010\u0015\u001a\u000e\u0012\u0004\u0012\u00020\u000e\u0012\u0004\u0012\u00020\u00070\u00162\b\b\u0002\u0010\u0017\u001a\u00020\u0018J\u001a\u0010\u0019\u001a\u000e\u0012\u0004\u0012\u00020\u000e\u0012\u0004\u0012\u00020\u001a0\u00162\u0006\u0010\u0017\u001a\u00020\u0018J\u0016\u0010\u001b\u001a\u00020\u000f2\u0006\u0010\u001c\u001a\u00020\u00072\u0006\u0010\u001d\u001a\u00020\u000eJ\u000e\u0010\u001e\u001a\u00020\u000e2\u0006\u0010\u001c\u001a\u00020\u0007J\u000e\u0010\u001f\u001a\u00020\u000e2\u0006\u0010\u001c\u001a\u00020\u0007R\u0014\u0010\u0005\u001a\b\u0012\u0004\u0012\u00020\u00070\u0006X\u0082.¢\u0006\u0002\n��R\u000e\u0010\b\u001a\u00020\tX\u0082\u0004¢\u0006\u0002\n��R\u0011\u0010\u0002\u001a\u00020\u0003¢\u0006\b\n��\u001a\u0004\b\n\u0010\u000bR\u001a\u0010\f\u001a\u000e\u0012\u0004\u0012\u00020\u000e\u0012\u0004\u0012\u00020\u000f0\rX\u0082\u0004¢\u0006\u0002\n��R\u001d\u0010\u0010\u001a\u000e\u0012\u0004\u0012\u00020\u000e\u0012\u0004\u0012\u00020\u000f0\u00118F¢\u0006\u0006\u001a\u0004\b\u0012\u0010\u0013R \u0010\u0014\u001a\u0014\u0012\u0004\u0012\u00020\u000e\u0012\n\u0012\b\u0012\u0004\u0012\u00020\u00070\u00060\u0011X\u0082\u0004¢\u0006\u0002\n��¨\u0006 "}, d2 = {"Lcom/kotlinnlp/simplednn/deeplearning/competitivelearning/CLNetwork;", "", "model", "Lcom/kotlinnlp/simplednn/deeplearning/competitivelearning/CLNetworkModel;", "(Lcom/kotlinnlp/simplednn/deeplearning/competitivelearning/CLNetworkModel;)V", "lastBackwardProcessor", 
"Lcom/kotlinnlp/simplednn/core/neuralprocessor/feedforward/FeedforwardNeuralProcessor;", "Lcom/kotlinnlp/simplednn/simplemath/ndarray/dense/DenseNDArray;", "lossCalculator", "Lcom/kotlinnlp/simplednn/core/functionalities/losses/MSECalculator;", "getModel", "()Lcom/kotlinnlp/simplednn/deeplearning/competitivelearning/CLNetworkModel;", "mutableScores", "", "", "", "outputScores", "", "getOutputScores", "()Ljava/util/Map;", "processors", "getInputErrors", "Lkotlin/Pair;", "copy", "", "getParamsErrors", "Lcom/kotlinnlp/simplednn/core/neuralnetwork/NetworkParameters;", "learn", "inputArray", "classId", "predict", "predictByLoss", "simplednn"})
/* loaded from: input_file:com/kotlinnlp/simplednn/deeplearning/competitivelearning/CLNetwork.class */
public final class CLNetwork {
    // Per-class scores keyed by class id; written by predict() and predictByLoss(),
    // exposed through getOutputScores().
    private final Map<Integer, Double> mutableScores;
    // One feed-forward processor per class id, built in the constructor.
    private final Map<Integer, FeedforwardNeuralProcessor<DenseNDArray>> processors;
    // Loss function used by predictByLoss() and learn().
    private final MSECalculator lossCalculator;
    // Processor used by the most recent learn() call; stays null until learn() runs
    // (the accessors below throw via Intrinsics when it is still null).
    private FeedforwardNeuralProcessor<DenseNDArray> lastBackwardProcessor;

    // The model passed to the constructor; returned by getModel().
    @NotNull
    private final CLNetworkModel model;

    /**
     * Returns the per-class scores written by the most recent calls to
     * {@link #predict} or {@link #predictByLoss}, keyed by class id.
     *
     * <p>The result is a read-only, live view of the internal score map: later
     * predictions are visible through it, but callers cannot modify it (mutators
     * throw {@link UnsupportedOperationException}). This prevents external code
     * from corrupting the internal state.
     *
     * @return an unmodifiable view of the class-id to score map
     */
    @NotNull
    public final Map<Integer, Double> getOutputScores() {
        return Collections.unmodifiableMap(this.mutableScores);
    }

    /**
     * Predicts the class of the given input by similarity.
     *
     * <p>Each class processor is run forward on the raw input; the class score is
     * {@code SimplemathKt.similarity} between the normalized input and the
     * normalized processor output. Every score is stored in the score map, and the
     * class with the highest score is returned. On ties the class that comes first
     * in the iteration order of the processors map wins.
     *
     * @param denseNDArray the input array
     * @return the id of the class with the highest score
     * @throws NullPointerException (raised through {@code Intrinsics.throwNpe()})
     *         if there are no processors
     */
    public final int predict(@NotNull DenseNDArray denseNDArray) {
        Intrinsics.checkParameterIsNotNull(denseNDArray, "inputArray");
        final DenseNDArray normalizedInput = SimplemathKt.normalize(denseNDArray);
        Integer bestClassId = null;
        double bestScore = 0.0;
        for (Map.Entry<Integer, FeedforwardNeuralProcessor<DenseNDArray>> entry : this.processors.entrySet()) {
            final Integer classId = entry.getKey();
            // Mask 2 tells the Kotlin default-argument bridge to use the default for the
            // second forward() parameter, so the literal "false" is only a placeholder.
            final DenseNDArray output =
                    FeedforwardNeuralProcessor.forward$default(entry.getValue(), denseNDArray, false, 2, null);
            final double score = SimplemathKt.similarity(normalizedInput, SimplemathKt.normalize(output));
            this.mutableScores.put(classId, Double.valueOf(score));
            // Double.compare (not "<") keeps the original ordering semantics, NaN included;
            // the strict comparison keeps the first entry on ties.
            if (bestClassId == null || Double.compare(bestScore, score) < 0) {
                bestClassId = classId;
                bestScore = score;
            }
        }
        if (bestClassId == null) {
            // No processors at all: fail exactly like the original did.
            Intrinsics.throwNpe();
        }
        return bestClassId.intValue();
    }

    /**
     * Predicts the class of the given input by reconstruction loss.
     *
     * <p>Each class processor is run forward on the input; the class score is the
     * average of the loss computed by the loss calculator between the processor
     * output and the input itself. Every score is stored in the score map, and the
     * class with the lowest loss is returned. On ties the class that comes first in
     * the iteration order of the processors map wins.
     *
     * @param denseNDArray the input array
     * @return the id of the class with the lowest average loss
     * @throws NullPointerException (raised through {@code Intrinsics.throwNpe()})
     *         if there are no processors
     */
    public final int predictByLoss(@NotNull DenseNDArray denseNDArray) {
        Intrinsics.checkParameterIsNotNull(denseNDArray, "inputArray");
        Integer bestClassId = null;
        double lowestLoss = 0.0;
        for (Map.Entry<Integer, FeedforwardNeuralProcessor<DenseNDArray>> entry : this.processors.entrySet()) {
            final Integer classId = entry.getKey();
            // Mask 2 tells the Kotlin default-argument bridge to use the default for the
            // second forward() parameter, so the literal "false" is only a placeholder.
            final DenseNDArray output =
                    FeedforwardNeuralProcessor.forward$default(entry.getValue(), denseNDArray, false, 2, null);
            final double loss = this.lossCalculator.calculateLoss(output, denseNDArray).avg();
            this.mutableScores.put(classId, Double.valueOf(loss));
            // Double.compare (not ">") keeps the original ordering semantics, NaN included;
            // the strict comparison keeps the first entry on ties.
            if (bestClassId == null || Double.compare(lowestLoss, loss) > 0) {
                bestClassId = classId;
                lowestLoss = loss;
            }
        }
        if (bestClassId == null) {
            // No processors at all: fail exactly like the original did.
            Intrinsics.throwNpe();
        }
        return bestClassId.intValue();
    }

    /**
     * Runs one forward/backward step on the processor of the given class, using the
     * input itself as the expected output.
     *
     * <p>The selected processor is remembered as the "last backward processor", which
     * {@link #getInputErrors} and {@link #getParamsErrors} read afterwards. It is
     * recorded before the forward pass, as in the original code.
     *
     * @param denseNDArray the input array (also used as the gold output)
     * @param i the id of the class whose processor must learn the input
     * @return the average loss between the processor output and the input
     * @throws IllegalStateException if no processor exists for class {@code i}
     */
    public final double learn(@NotNull DenseNDArray denseNDArray, int i) {
        Intrinsics.checkParameterIsNotNull(denseNDArray, "inputArray");
        final FeedforwardNeuralProcessor<DenseNDArray> processor = this.processors.get(Integer.valueOf(i));
        if (processor == null) {
            throw new IllegalStateException("Unknown class: " + i);
        }
        this.lastBackwardProcessor = processor;
        // Mask 2: use the default for the second forward() parameter.
        final DenseNDArray output = FeedforwardNeuralProcessor.forward$default(processor, denseNDArray, false, 2, null);
        final DenseNDArray outputErrors = this.lossCalculator.calculateErrors(output, denseNDArray);
        // Mask 6: use the defaults for the second and third backward() parameters.
        FeedforwardNeuralProcessor.backward$default(processor, outputErrors, false, null, 6, null);
        return this.lossCalculator.calculateLoss(output, denseNDArray).avg();
    }

    /**
     * Returns the input errors of the processor used by the last {@link #learn} call.
     *
     * @param z forwarded as-is to the processor's {@code getInputErrors}
     *          (named "copy" in the Kotlin metadata)
     * @return a pair of (processor id, input errors)
     * @throws RuntimeException raised through
     *         {@code Intrinsics.throwUninitializedPropertyAccessException} if
     *         {@link #learn} has never been called
     */
    @NotNull
    public final Pair<Integer, DenseNDArray> getInputErrors(boolean z) {
        final FeedforwardNeuralProcessor<DenseNDArray> processor = this.lastBackwardProcessor;
        if (processor == null) {
            Intrinsics.throwUninitializedPropertyAccessException("lastBackwardProcessor");
        }
        final Integer processorId = Integer.valueOf(processor.getId());
        return new Pair<>(processorId, processor.getInputErrors(z));
    }

    /**
     * Kotlin default-argument bridge for {@link #getInputErrors(boolean)}.
     *
     * @param cLNetwork the receiver instance
     * @param z the explicit argument, ignored when bit 0 of {@code i} is set
     * @param i bit mask of omitted arguments; bit 0 set means use the default ({@code true})
     * @param obj unused marker parameter
     * @return the result of {@code cLNetwork.getInputErrors(...)}
     */
    @NotNull
    public static /* bridge */ /* synthetic */ Pair getInputErrors$default(CLNetwork cLNetwork, boolean z, int i, Object obj) {
        // Default value is true, so "argument omitted" and "argument is true" collapse into one OR.
        final boolean copy = (i & 1) != 0 || z;
        return cLNetwork.getInputErrors(copy);
    }

    /**
     * Returns the parameter errors of the processor used by the last {@link #learn} call.
     *
     * @param z forwarded as-is to the processor's {@code getParamsErrors}
     *          (named "copy" in the Kotlin metadata)
     * @return a pair of (processor id, parameter errors)
     * @throws RuntimeException raised through
     *         {@code Intrinsics.throwUninitializedPropertyAccessException} if
     *         {@link #learn} has never been called
     */
    @NotNull
    public final Pair<Integer, NetworkParameters> getParamsErrors(boolean z) {
        final FeedforwardNeuralProcessor<DenseNDArray> processor = this.lastBackwardProcessor;
        if (processor == null) {
            Intrinsics.throwUninitializedPropertyAccessException("lastBackwardProcessor");
        }
        final Integer processorId = Integer.valueOf(processor.getId());
        return new Pair<>(processorId, processor.getParamsErrors(z));
    }

    /**
     * Returns the model this network was constructed with.
     *
     * @return the model passed to the constructor
     */
    @NotNull
    public final CLNetworkModel getModel() {
        return this.model;
    }

    /**
     * Builds a network on top of the given model, creating one feed-forward
     * processor per class declared by the model.
     *
     * <p>Processors are stored in a {@link LinkedHashMap}, so they are later visited
     * in the iteration order of {@code model.getClasses()}. Each processor receives
     * its class id as its own id.
     *
     * @param cLNetworkModel the model providing the class ids and the per-class networks
     */
    public CLNetwork(@NotNull CLNetworkModel cLNetworkModel) {
        Intrinsics.checkParameterIsNotNull(cLNetworkModel, "model");
        this.model = cLNetworkModel;
        this.mutableScores = new LinkedHashMap<>();
        final Set<Integer> classes = this.model.getClasses();
        // Same initial capacity the Kotlin associate() helper would have chosen.
        final int initialCapacity =
                RangesKt.coerceAtLeast(MapsKt.mapCapacity(CollectionsKt.collectionSizeOrDefault(classes, 10)), 16);
        final Map<Integer, FeedforwardNeuralProcessor<DenseNDArray>> processorsByClass =
                new LinkedHashMap<>(initialCapacity);
        for (int classId : classes) {
            // MapsKt.getValue is expected to fail fast when the model has no network
            // for this class (Kotlin's Map.getValue contract) rather than return null.
            final NeuralNetwork network =
                    (NeuralNetwork) MapsKt.getValue(this.model.getNetworks(), Integer.valueOf(classId));
            final FeedforwardNeuralProcessor<DenseNDArray> processor =
                    new FeedforwardNeuralProcessor<>(network, classId);
            processorsByClass.put(Integer.valueOf(classId), processor);
        }
        this.processors = processorsByClass;
        this.lossCalculator = new MSECalculator();
    }
}
