package ws.palladian.classification.liblinear;

import de.bwaldvogel.liblinear.Feature;
import de.bwaldvogel.liblinear.FeatureNode;
import de.bwaldvogel.liblinear.Linear;
import de.bwaldvogel.liblinear.Parameter;
import de.bwaldvogel.liblinear.Problem;
import de.bwaldvogel.liblinear.SolverType;
import java.io.OutputStream;
import java.io.PrintStream;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import org.apache.commons.lang3.Validate;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import ws.palladian.classification.utils.DummyVariableCreator;
import ws.palladian.classification.utils.Normalization;
import ws.palladian.classification.utils.Normalizer;
import ws.palladian.classification.utils.ZScoreNormalizer;
import ws.palladian.core.AbstractLearner;
import ws.palladian.core.FeatureVector;
import ws.palladian.core.Instance;
import ws.palladian.core.dataset.Dataset;
import ws.palladian.core.value.NumericValue;
import ws.palladian.core.value.Value;
import ws.palladian.helper.collection.CollectionHelper;
import ws.palladian.helper.collection.Vector;
import ws.palladian.helper.io.CloseableIterator;
import ws.palladian.helper.io.Slf4JOutputStream;

/* loaded from: input_file:ws/palladian/classification/liblinear/LibLinearLearner.class */
public final class LibLinearLearner extends AbstractLearner<LibLinearModel> {
    private static final Logger LOGGER = LoggerFactory.getLogger(LibLinearLearner.class);
    private final Parameter parameter;
    private final double bias;
    private final Normalizer normalizer;

    public LibLinearLearner(Parameter parameter, double d, Normalizer normalizer) {
        Validate.notNull(parameter, "parameter must not be null", new Object[0]);
        Validate.notNull(normalizer, "normalizer must not be null", new Object[0]);
        if (parameter.getSolverType().isSupportVectorRegression()) {
            throw new UnsupportedOperationException("Support vector regression is not supported by this learner. This learner is for classification only!");
        }
        this.parameter = parameter;
        this.bias = d;
        this.normalizer = normalizer;
    }

    public LibLinearLearner(Normalizer normalizer) {
        this(new Parameter(SolverType.L2R_LR, 1.0d, 0.01d), 1.0d, normalizer);
    }

    public LibLinearLearner() {
        this(new ZScoreNormalizer());
    }

    /* renamed from: train, reason: merged with bridge method [inline-methods] */
    public LibLinearModel m8train(Dataset dataset) {
        Validate.notNull(dataset, "dataset must not be null", new Object[0]);
        Normalization calculate = this.normalizer.calculate(dataset);
        DummyVariableCreator dummyVariableCreator = new DummyVariableCreator(dataset, false, false);
        ArrayList arrayList = new ArrayList(dataset.transform(dummyVariableCreator).getFeatureInformation().getFeatureNames());
        Map createIndexMap = CollectionHelper.createIndexMap(arrayList);
        Problem problem = new Problem();
        LOGGER.debug("# Features = {}", Integer.valueOf(arrayList.size()));
        problem.n = arrayList.size();
        if (this.bias >= 0.0d) {
            LOGGER.debug("Add bias correction {}", Double.valueOf(this.bias));
            problem.bias = this.bias;
            problem.n++;
        }
        ArrayList arrayList2 = new ArrayList();
        ArrayList arrayList3 = new ArrayList();
        ArrayList arrayList4 = new ArrayList();
        CloseableIterator it = dataset.iterator();
        while (it.hasNext()) {
            Instance instance = (Instance) it.next();
            problem.l++;
            if (problem.l % 10000 == 0) {
                LOGGER.debug("Created {} training instances", Integer.valueOf(problem.l));
            }
            arrayList3.add(makeInstance(createIndexMap, dummyVariableCreator.convert(calculate.normalize(instance.getVector())), this.bias));
            if (!arrayList2.contains(instance.getCategory())) {
                arrayList2.add(instance.getCategory());
            }
            arrayList4.add(Integer.valueOf(arrayList2.indexOf(instance.getCategory())));
        }
        problem.x = (Feature[][]) arrayList3.toArray(new Feature[0]);
        problem.y = new double[problem.l];
        for (int i = 0; i < arrayList4.size(); i++) {
            problem.y[i] = ((Integer) arrayList4.get(i)).intValue();
        }
        LOGGER.debug("n={}, l={}", Integer.valueOf(problem.n), Integer.valueOf(problem.l));
        return new LibLinearModel(Linear.train(problem, this.parameter), createIndexMap, arrayList2, calculate, dummyVariableCreator);
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    public static Feature[] makeInstance(Map<String, Integer> map, FeatureVector featureVector, double d) {
        ArrayList arrayList = new ArrayList();
        Iterator it = featureVector.iterator();
        while (it.hasNext()) {
            Vector.VectorEntry vectorEntry = (Vector.VectorEntry) it.next();
            NumericValue numericValue = (Value) vectorEntry.value();
            Integer num = map.get(vectorEntry.key());
            if (num != null && !numericValue.isNull() && (numericValue instanceof NumericValue)) {
                double d2 = numericValue.getDouble();
                if (Math.abs(d2) >= 2.802596928649634E-45d) {
                    arrayList.add(new FeatureNode(num.intValue() + 1, d2));
                }
            }
        }
        if (d >= 0.0d) {
            arrayList.add(new FeatureNode(map.size() + 1, d));
        }
        Collections.sort(arrayList, new Comparator<Feature>() { // from class: ws.palladian.classification.liblinear.LibLinearLearner.1
            @Override // java.util.Comparator
            public int compare(Feature feature, Feature feature2) {
                return Integer.compare(feature.getIndex(), feature2.getIndex());
            }
        });
        return (Feature[]) arrayList.toArray(new Feature[arrayList.size()]);
    }

    public String toString() {
        return String.format("%s [SolverType=%s, C=%s, eps=%s, bias=%s]", getClass().getSimpleName(), this.parameter.getSolverType(), Double.valueOf(this.parameter.getC()), Double.valueOf(this.parameter.getEps()), Double.valueOf(this.bias));
    }

    static {
        Linear.setDebugOutput(new PrintStream((OutputStream) new Slf4JOutputStream(LOGGER, Slf4JOutputStream.Level.DEBUG)));
    }
}
