package ws.palladian.classification.text.nbsvm;

import de.bwaldvogel.liblinear.Parameter;
import de.bwaldvogel.liblinear.SolverType;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import org.apache.commons.math3.util.FastMath;
import ws.palladian.classification.liblinear.LibLinearLearner;
import ws.palladian.classification.utils.NoNormalizer;
import ws.palladian.core.AbstractLearner;
import ws.palladian.core.FeatureVector;
import ws.palladian.core.Instance;
import ws.palladian.core.InstanceBuilder;
import ws.palladian.core.dataset.AbstractDatasetFeatureVectorTransformer;
import ws.palladian.core.dataset.Dataset;
import ws.palladian.core.value.NumericValue;
import ws.palladian.extraction.text.vector.TextVectorizer;
import ws.palladian.helper.collection.CollectionHelper;
import ws.palladian.helper.collection.Vector;
import ws.palladian.helper.functional.Predicates;
import ws.palladian.helper.io.CloseableIterator;

/* loaded from: input_file:ws/palladian/classification/text/nbsvm/NbSvmLearner.class */
public class NbSvmLearner extends AbstractLearner<NbSvmModel> {
    private static final String TRUE_CATEGORY = "1";
    private static final String FALSE_CATEGORY = "0";
    private static final float ALPHA = 1.0f;
    private final TextVectorizer vectorizer;
    private final LibLinearLearner learner;

    public NbSvmLearner(TextVectorizer textVectorizer) {
        this(textVectorizer, new Parameter(SolverType.L2R_LR, 1.0d, 0.01d));
    }

    public NbSvmLearner(TextVectorizer textVectorizer, Parameter parameter) {
        this.vectorizer = textVectorizer;
        this.learner = new LibLinearLearner(parameter, 1.0d, new NoNormalizer());
    }

    /* renamed from: train, reason: merged with bridge method [inline-methods] */
    public NbSvmModel m15train(Dataset dataset) {
        Dataset filterFeatures = dataset.transform(this.vectorizer).filterFeatures(Predicates.not(Predicates.equal("text")));
        HashSet hashSet = new HashSet();
        CloseableIterator it = filterFeatures.iterator();
        while (it.hasNext()) {
            Iterator it2 = ((Instance) it.next()).getVector().iterator();
            while (it2.hasNext()) {
                hashSet.add(((Vector.VectorEntry) it2.next()).key());
            }
        }
        final Map createIndexMap = CollectionHelper.createIndexMap(new ArrayList(hashSet));
        int size = hashSet.size();
        float[] fArr = new float[size];
        float[] fArr2 = new float[size];
        Arrays.fill(fArr, ALPHA);
        Arrays.fill(fArr2, ALPHA);
        CloseableIterator it3 = filterFeatures.iterator();
        while (it3.hasNext()) {
            Instance instance = (Instance) it3.next();
            for (Vector.VectorEntry vectorEntry : instance.getVector()) {
                String str = (String) vectorEntry.key();
                float f = ((NumericValue) vectorEntry.value()).getFloat();
                if (instance.getCategory().equals(TRUE_CATEGORY)) {
                    int intValue = ((Integer) createIndexMap.get(str)).intValue();
                    fArr[intValue] = fArr[intValue] + f;
                } else {
                    if (!instance.getCategory().equals(FALSE_CATEGORY)) {
                        throw new IllegalStateException(String.format("Instance must currently be of category '%s' or '%s'", FALSE_CATEGORY, TRUE_CATEGORY));
                    }
                    int intValue2 = ((Integer) createIndexMap.get(str)).intValue();
                    fArr2[intValue2] = fArr2[intValue2] + f;
                }
            }
        }
        float f2 = 0.0f;
        float f3 = 0.0f;
        for (int i = 0; i < size; i++) {
            f2 += fArr[i];
            f3 += fArr2[i];
        }
        final float[] fArr3 = new float[size];
        for (int i2 = 0; i2 < size; i2++) {
            fArr3[i2] = (float) FastMath.log((fArr[i2] / f2) / (fArr2[i2] / f3));
        }
        return new NbSvmModel(this.learner.m8train(filterFeatures.transform(new AbstractDatasetFeatureVectorTransformer() { // from class: ws.palladian.classification.text.nbsvm.NbSvmLearner.1
            public FeatureVector apply(FeatureVector featureVector) {
                return NbSvmLearner.transform(createIndexMap, fArr3, featureVector);
            }
        })), createIndexMap, fArr3);
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    public static FeatureVector transform(Map<String, Integer> map, float[] fArr, FeatureVector featureVector) {
        InstanceBuilder instanceBuilder = new InstanceBuilder();
        Iterator it = featureVector.iterator();
        while (it.hasNext()) {
            Vector.VectorEntry vectorEntry = (Vector.VectorEntry) it.next();
            String str = (String) vectorEntry.key();
            float f = ((NumericValue) vectorEntry.value()).getFloat();
            Integer num = map.get(str);
            if (num != null) {
                instanceBuilder.set((String) vectorEntry.key(), f * fArr[num.intValue()]);
            }
        }
        return instanceBuilder.create();
    }

    public String toString() {
        return String.format("%s [learner=%s, vectorizer=%s]", getClass().getSimpleName(), this.learner, this.vectorizer);
    }
}
