package ws.palladian.classification.nb;

import java.util.Iterator;
import ws.palladian.core.AbstractLearner;
import ws.palladian.core.Instance;
import ws.palladian.core.dataset.Dataset;
import ws.palladian.core.value.NominalValue;
import ws.palladian.core.value.NumericValue;
import ws.palladian.core.value.Value;
import ws.palladian.helper.collection.Bag;
import ws.palladian.helper.collection.LazyMatrix;
import ws.palladian.helper.collection.MapMatrix;
import ws.palladian.helper.collection.Matrix;
import ws.palladian.helper.collection.Vector;
import ws.palladian.helper.math.SlimStats;
import ws.palladian.helper.math.Stats;

/* loaded from: input_file:ws/palladian/classification/nb/NaiveBayesLearner.class */
/**
 * Learner that produces a {@link NaiveBayesModel} by making a single pass over a
 * {@link Dataset} and collecting per-category statistics for its features.
 */
public final class NaiveBayesLearner extends AbstractLearner<NaiveBayesModel> {
    /**
     * Trains a model from the given dataset.
     *
     * <p>While iterating the instances, three accumulators are filled:
     * <ul>
     * <li>a {@link Bag} to which every instance's category is added,</li>
     * <li>a lazily populated matrix of bags, addressed by (feature name, nominal
     * string value), to which the instance's category is added for each
     * {@link NominalValue},</li>
     * <li>a lazily populated matrix of {@link Stats}, addressed by (feature name,
     * category), to which the number is added for each {@link NumericValue}.</li>
     * </ul>
     * Feature values that are neither nominal nor numeric are skipped. Afterwards the
     * mean and the standard deviation of every collected {@link Stats} cell are copied
     * into two separate matrices, and everything is handed to the model constructor.
     *
     * @param dataset the instances to learn from
     * @return the model built from the collected counts and statistics
     */
    @Override // ws.palladian.core.Learner
    public NaiveBayesModel train(Dataset dataset) {
        Bag categoryCounts = new Bag();
        LazyMatrix nominalCounts = new LazyMatrix(Bag::new);
        LazyMatrix numericStats = new LazyMatrix(SlimStats::new);

        for (Iterator<Instance> instances = dataset.iterator2(); instances.hasNext();) {
            Instance instance = instances.next();
            String category = instance.getCategory();
            categoryCounts.add(category);

            for (Vector.VectorEntry feature : instance.getVector()) {
                String featureName = (String) feature.key();
                Value featureValue = (Value) feature.value();

                if (featureValue instanceof NominalValue) {
                    // Record that this (feature, value) pair was seen with the current category.
                    String nominal = ((NominalValue) featureValue).getString();
                    Bag seenCategories = (Bag) nominalCounts.get(featureName, nominal);
                    seenCategories.add(category);
                } else if (featureValue instanceof NumericValue) {
                    // Feed the number into the running statistics for (feature, category).
                    Stats running = (Stats) numericStats.get(featureName, category);
                    double number = ((NumericValue) featureValue).getDouble();
                    running.add(Double.valueOf(number));
                }
            }
        }

        // Reduce the collected statistics to the two figures the model is given.
        MapMatrix means = new MapMatrix();
        MapMatrix standardDeviations = new MapMatrix();
        for (Matrix.MatrixVector<Vector.VectorEntry> row : numericStats.rows()) {
            String rowKey = (String) row.key();
            for (Vector.VectorEntry cell : row) {
                String cellKey = (String) cell.key();
                Stats cellStats = (Stats) cell.value();
                means.set(cellKey, rowKey, Double.valueOf(cellStats.getMean()));
                standardDeviations.set(cellKey, rowKey, Double.valueOf(cellStats.getStandardDeviation()));
            }
        }

        return new NaiveBayesModel(nominalCounts.getMatrix(), categoryCounts, means, standardDeviations);
    }

    /**
     * Returns the simple (unqualified) name of this object's runtime class.
     *
     * @return the simple class name
     */
    @Override // ws.palladian.core.AbstractLearner
    public String toString() {
        Class<?> runtimeType = getClass();
        return runtimeType.getSimpleName();
    }
}
