package ws.palladian.classification.nb;

import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import org.apache.commons.lang3.Validate;
import org.apache.commons.math3.util.FastMath;
import ws.palladian.core.Model;
import ws.palladian.helper.collection.Bag;
import ws.palladian.helper.collection.Matrix;

/**
 * Naive Bayes model holding the statistics gathered during training:
 * <ul>
 * <li>{@code nominalCounts}: for each (feature name, nominal feature value) pair, a bag counting the categories that
 * value co-occurred with;</li>
 * <li>{@code categories}: a bag of category counts, from which the priors are derived;</li>
 * <li>{@code sampleMeans} / {@code standardDeviations}: per (feature name, category) mean and standard deviation of
 * numeric features, used for a Gaussian density estimate.</li>
 * </ul>
 * Nominal features are scored via {@link #getProbability(String, String, String, double)}, numeric features via
 * {@link #getDensity(String, double, String)}.
 */
public final class NaiveBayesModel implements Model {

    private static final long serialVersionUID = 3L;

    /** 2&pi;, used in the Gaussian normalization terms. */
    private static final double TWO_PI = 2 * Math.PI;

    /** (featureName, featureValue) -> bag of categories observed with that value. */
    private final Matrix<String, Bag<String>> nominalCounts;

    /** Category counts; the basis for priors and Laplace smoothing. */
    private final Bag<String> categories;

    /** (featureName, category) -> sample mean of a numeric feature. */
    private final Matrix<String, Double> sampleMeans;

    /** (featureName, category) -> standard deviation of a numeric feature. */
    private final Matrix<String, Double> standardDeviations;

    /**
     * Lazily computed per-feature normalization for {@link #getDensity(String, double, String)}. It is transient, so it
     * is not serialized and is recomputed on first use after deserialization.
     */
    private transient Map<String, Double> densityNormalization;

    // NOTE(review): the decompiler reports this constructor was package-private originally; it is kept public here so
    // existing callers keep compiling.
    public NaiveBayesModel(Matrix<String, Bag<String>> nominalCounts, Bag<String> categories,
            Matrix<String, Double> sampleMeans, Matrix<String, Double> standardDeviations) {
        this.nominalCounts = nominalCounts;
        this.categories = categories;
        this.sampleMeans = sampleMeans;
        this.standardDeviations = standardDeviations;
    }

    /**
     * Returns the prior of a category, i.e. its count relative to the total count of all categories.
     *
     * @param category the category, not {@code null}.
     * @return the prior in [0, 1]; {@code NaN} if the model holds no categories at all.
     */
    public double getPrior(String category) {
        Validate.notNull(category, "category must not be null");
        // Cast before dividing: int / int would truncate every prior to 0 (or 1).
        return (double) categories.count(category) / categories.size();
    }

    /**
     * Returns the Laplace-smoothed conditional probability P(featureValue | category) for a nominal feature:
     * {@code (count(value, category) + laplace) / (count(category) + laplace * numberOfCategories)}.
     *
     * @param featureName the name of the nominal feature, not {@code null}.
     * @param featureValue the value of the feature, not {@code null}.
     * @param category the category, not {@code null}.
     * @param laplace the Laplace corrector, must be &ge; 0. With 0 and an unseen category the result is {@code NaN}.
     * @return the smoothed conditional probability.
     * @throws NullPointerException if any of the string arguments is {@code null}.
     * @throws IllegalArgumentException if {@code laplace} is negative (or NaN).
     */
    public double getProbability(String featureName, String featureValue, String category, double laplace) {
        Validate.notNull(featureName, "featureName must not be null");
        Validate.notNull(featureValue, "featureValue must not be null");
        Validate.notNull(category, "category must not be null");
        Validate.isTrue(laplace >= 0.0, "laplace corrector must be equal or greater than zero");
        Bag<String> valueCounts = nominalCounts.get(featureName, featureValue);
        // An unseen (featureName, featureValue) pair contributes a zero count.
        int valueCount = valueCounts != null ? valueCounts.count(category) : 0;
        int numCategories = categories.unique().size();
        return (valueCount + laplace) / (categories.count(category) + laplace * numCategories);
    }

    /** Looks up the standard deviation of a numeric feature for a category; {@code null} if not learned. */
    private Double getStandardDeviation(String featureName, String category) {
        return standardDeviations.get(featureName, category);
    }

    /** Looks up the sample mean of a numeric feature for a category; {@code null} if not learned. */
    private Double getMean(String featureName, String category) {
        return sampleMeans.get(featureName, category);
    }

    /**
     * Returns the Gaussian density of a numeric feature value for a category. The density is divided by the sum, over
     * all categories with a positive standard deviation for this feature, of the peak densities {@code 1/(sigma*sqrt(2*pi))}.
     *
     * @param featureName the name of the numeric feature, not {@code null}.
     * @param value the observed feature value.
     * @param category the category, not {@code null}.
     * @return the normalized density, or 0 if no usable (strictly positive) standard deviation or no mean was learned
     *         for this feature/category.
     */
    public double getDensity(String featureName, double value, String category) {
        Validate.notNull(featureName, "featureName must not be null");
        Validate.notNull(category, "category must not be null");
        Double standardDeviation = getStandardDeviation(featureName, category);
        Double mean = getMean(featureName, category);
        // Same criterion as calcDensityNormalization (sigma > 0); the negated form also rejects NaN. It guarantees the
        // normalizer for this feature exists and is non-zero.
        if (standardDeviation == null || mean == null || !(standardDeviation > 0.0)) {
            return 0.0;
        }
        double sigma = standardDeviation;
        double variance = sigma * sigma;
        double diff = value - mean;
        double gaussian = (1.0 / Math.sqrt(TWO_PI * variance)) * FastMath.exp(-(diff * diff) / (2.0 * variance));
        return gaussian / getDensityNormalization(featureName);
    }

    /** Returns the cached density normalizer for a feature, computing the cache on first access. */
    private double getDensityNormalization(String featureName) {
        if (densityNormalization == null) {
            densityNormalization = calcDensityNormalization(standardDeviations);
        }
        return densityNormalization.get(featureName);
    }

    /**
     * Computes, for each feature (column key), the sum over its categories of {@code 1/(sigma*sqrt(2*pi))}, considering
     * only strictly positive standard deviations.
     *
     * @param matrix the (featureName, category) -> standard deviation matrix.
     * @return an unmodifiable map featureName -> normalizer.
     */
    private static Map<String, Double> calcDensityNormalization(Matrix<String, Double> matrix) {
        Map<String, Double> normalization = new HashMap<>();
        double sqrtTwoPi = Math.sqrt(TWO_PI);
        for (String featureName : matrix.getColumnKeys()) {
            double sum = 0.0;
            for (Double deviation : matrix.getColumn(featureName).values()) {
                if (deviation > 0.0) {
                    sum += 1.0 / (deviation * sqrtTwoPi);
                }
            }
            normalization.put(featureName, sum);
        }
        return Collections.unmodifiableMap(normalization);
    }

    @Override
    public String toString() {
        return "NaiveBayesModel [nominalCounts=" + nominalCounts + ", categories=" + categories + ", sampleMeans="
                + sampleMeans + ", standardDeviations=" + standardDeviations + "]";
    }

    @Override // ws.palladian.core.Model
    public Set<String> getCategories() {
        return categories.uniqueItems();
    }

    /**
     * Returns the names of all features known to this model, nominal and numeric.
     *
     * @return an unmodifiable set of feature names.
     */
    public Set<String> getLearnedFeatures() {
        Set<String> features = new HashSet<>(nominalCounts.getColumnKeys());
        features.addAll(sampleMeans.getColumnKeys());
        return Collections.unmodifiableSet(features);
    }
}
