package ws.palladian.classification;

import java.util.ArrayList;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Set;
import libsvm.svm;
import libsvm.svm_node;
import libsvm.svm_parameter;
import libsvm.svm_print_interface;
import libsvm.svm_problem;
import org.apache.commons.lang.Validate;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import ws.palladian.classification.utils.DummyVariableCreator;
import ws.palladian.classification.utils.Normalization;
import ws.palladian.classification.utils.Normalizer;
import ws.palladian.classification.utils.ZScoreNormalizer;
import ws.palladian.core.AbstractLearner;
import ws.palladian.core.FeatureVector;
import ws.palladian.core.Instance;
import ws.palladian.core.dataset.Dataset;
import ws.palladian.core.value.NumericValue;
import ws.palladian.core.value.Value;
import ws.palladian.helper.collection.CollectionHelper;
import ws.palladian.helper.collection.Vector;
import ws.palladian.helper.io.CloseableIterator;
import ws.palladian.helper.io.CloseableIterator;

/* loaded from: input_file:ws/palladian/classification/LibSvmLearner.class */
public final class LibSvmLearner extends AbstractLearner<LibSvmModel> {
    private static final Logger LOGGER = LoggerFactory.getLogger(LibSvmLearner.class);
    private static final Normalizer NORMALIZER = new ZScoreNormalizer();
    private final LibSvmKernel kernel;

    /* JADX INFO: Access modifiers changed from: package-private */
    public static void redirectLogOutput() {
        svm.svm_set_print_string_function(new svm_print_interface() { // from class: ws.palladian.classification.LibSvmLearner.1
            public void print(String str) {
                LibSvmLearner.LOGGER.debug(str);
            }
        });
    }

    public LibSvmLearner(LibSvmKernel libSvmKernel) {
        Validate.notNull(libSvmKernel, "kernel must not be null");
        this.kernel = libSvmKernel;
    }

    /* renamed from: train, reason: merged with bridge method [inline-methods] */
    public LibSvmModel m2train(Dataset dataset) {
        Validate.notNull(dataset, "dataset must not be null");
        Normalization calculate = NORMALIZER.calculate(dataset);
        DummyVariableCreator dummyVariableCreator = new DummyVariableCreator(dataset, false, false);
        ArrayList arrayList = new ArrayList();
        List<String> arrayList2 = new ArrayList<>();
        CloseableIterator it = dataset.iterator();
        while (it.hasNext()) {
            Instance instance = (Instance) it.next();
            for (Vector.VectorEntry vectorEntry : dummyVariableCreator.convert(instance.getVector())) {
                if ((((Value) vectorEntry.value()) instanceof NumericValue) && !arrayList.contains(vectorEntry.key())) {
                    arrayList.add(vectorEntry.key());
                }
            }
            if (!arrayList2.contains(instance.getCategory())) {
                arrayList2.add(instance.getCategory());
            }
        }
        if (arrayList2.size() < 2) {
            throw new IllegalStateException("The training data contains less than two different classes. Training not possible on such a dataset.");
        }
        svm_parameter parameter = getParameter();
        svm_problem createProblem = createProblem(dataset, parameter, arrayList, arrayList2, calculate, dummyVariableCreator);
        String svm_check_parameter = svm.svm_check_parameter(createProblem, parameter);
        if (svm_check_parameter != null) {
            throw new IllegalStateException(svm_check_parameter);
        }
        return new LibSvmModel(svm.svm_train(createProblem, parameter), arrayList, arrayList2, calculate, dummyVariableCreator);
    }

    /* JADX WARN: Type inference failed for: r1v6, types: [libsvm.svm_node[], libsvm.svm_node[][]] */
    private svm_problem createProblem(Iterable<? extends Instance> iterable, svm_parameter svm_parameterVar, List<String> list, List<String> list2, Normalization normalization, DummyVariableCreator dummyVariableCreator) {
        svm_problem svm_problemVar = new svm_problem();
        svm_problemVar.l = CollectionHelper.count(iterable.iterator());
        svm_problemVar.x = new svm_node[svm_problemVar.l];
        svm_problemVar.y = new double[svm_problemVar.l];
        int i = 0;
        for (Instance instance : iterable) {
            svm_problemVar.y[i] = list2.indexOf(instance.getCategory());
            svm_problemVar.x[i] = convertFeatureVector(instance.getVector(), list, normalization, dummyVariableCreator);
            i++;
        }
        return svm_problemVar;
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    public static svm_node[] convertFeatureVector(FeatureVector featureVector, List<String> list, Normalization normalization, DummyVariableCreator dummyVariableCreator) {
        FeatureVector convert = dummyVariableCreator.convert(normalization.normalize(featureVector));
        ArrayList arrayList = new ArrayList();
        for (int i = 0; i < list.size(); i++) {
            NumericValue numericValue = (Value) convert.get(list.get(i));
            if (numericValue instanceof NumericValue) {
                NumericValue numericValue2 = numericValue;
                svm_node svm_nodeVar = new svm_node();
                svm_nodeVar.index = i;
                svm_nodeVar.value = numericValue2.getDouble();
                arrayList.add(svm_nodeVar);
            }
        }
        return (svm_node[]) arrayList.toArray(new svm_node[arrayList.size()]);
    }

    private svm_parameter getParameter() {
        svm_parameter svm_parameterVar = new svm_parameter();
        this.kernel.apply(svm_parameterVar);
        svm_parameterVar.svm_type = 0;
        svm_parameterVar.degree = 3;
        svm_parameterVar.coef0 = 0.0d;
        svm_parameterVar.nu = 0.5d;
        svm_parameterVar.cache_size = 100.0d;
        svm_parameterVar.eps = 0.001d;
        svm_parameterVar.p = 0.1d;
        svm_parameterVar.shrinking = 1;
        svm_parameterVar.probability = 1;
        svm_parameterVar.nr_weight = 0;
        svm_parameterVar.weight_label = new int[0];
        svm_parameterVar.weight = new double[0];
        return svm_parameterVar;
    }

    public String toString() {
        return getClass().getSimpleName() + " (" + this.kernel.getClass().getSimpleName() + ")";
    }

    static {
        redirectLogOutput();
    }
}
