package ws.palladian.classification;

import java.util.Iterator;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.commons.math3.util.FastMath;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import ws.palladian.classification.evaluation.ConfusionMatrixEvaluator;
import ws.palladian.classification.evaluation.CrossValidator;
import ws.palladian.classification.evaluation.RandomCrossValidator;
import ws.palladian.core.AbstractLearner;
import ws.palladian.core.dataset.Dataset;

/* loaded from: input_file:ws/palladian/classification/SelfTuningLibSvmLearner.class */
/**
 * Learner that picks the two {@link RBFKernel} parameters (logged as "C" and "gamma") by an
 * exhaustive search over a grid of powers of two, scoring every combination with cross
 * validation, and then trains the final model on the complete dataset with the best combination.
 */
public class SelfTuningLibSvmLearner extends AbstractLearner<LibSvmModel> {
    private static final Logger LOGGER = LoggerFactory.getLogger(SelfTuningLibSvmLearner.class);

    /** Increment applied to the base-2 exponents while walking the parameter grid. */
    private static final int STEP_SIZE = 2;

    /** Number of folds handed to the cross validator; also the divisor for the average accuracy. */
    private final int numFolds;

    /**
     * @param i number of cross-validation folds used to score each parameter combination
     */
    public SelfTuningLibSvmLearner(int i) {
        this.numFolds = i;
    }

    /**
     * Searches C in 2^-5 .. 2^15 and gamma in 2^-15 .. 2^3 (exponent step {@link #STEP_SIZE}),
     * logs the averaged accuracy of every combination and returns a model trained on the whole
     * dataset with the best-scoring pair. On ties the combination encountered first wins.
     *
     * @param dataset data used both for the cross-validated search and for the final training
     * @return the model trained with the best parameter combination
     * @throws IllegalStateException if no combination produced a comparable (non-NaN) accuracy
     */
    /* renamed from: train, reason: merged with bridge method [inline-methods] */
    public LibSvmModel m4train(Dataset dataset) {
        RandomCrossValidator crossValidator = new RandomCrossValidator(dataset, this.numFolds);
        // Start below every possible accuracy so that a grid point scoring exactly 0 can still be
        // selected; starting at 0 left the best parameters unset and ended in a NullPointerException.
        double bestAccuracy = Double.NEGATIVE_INFINITY;
        double bestC = 0.0d;
        double bestGamma = 0.0d;
        boolean found = false;
        for (int log2c = -5; log2c <= 15; log2c += STEP_SIZE) {
            double c = FastMath.pow(2.0d, log2c);
            for (int log2gamma = -15; log2gamma <= 3; log2gamma += STEP_SIZE) {
                double gamma = FastMath.pow(2.0d, log2gamma);
                double avgAccuracy = averageAccuracy(crossValidator, c, gamma);
                // Strict comparison keeps the first maximum; a NaN average never compares greater.
                if (avgAccuracy > bestAccuracy) {
                    bestAccuracy = avgAccuracy;
                    bestC = c;
                    bestGamma = gamma;
                    found = true;
                }
                LOGGER.info("C = {}, gamma = {}, avg. accuracy = {}", new Object[]{c, gamma, avgAccuracy});
            }
        }
        if (!found) {
            throw new IllegalStateException(
                    "No parameter combination produced a valid accuracy during cross validation");
        }
        LOGGER.info("[BEST] C = {}, gamma = {}, avg. accuracy = {}", new Object[]{bestC, bestGamma, bestAccuracy});
        return new LibSvmLearner(new RBFKernel(bestC, bestGamma)).m2train(dataset);
    }

    /** Trains and evaluates one model per fold and returns the summed accuracy divided by {@code numFolds}. */
    private double averageAccuracy(RandomCrossValidator crossValidator, double c, double gamma) {
        double accuracySum = 0.0d;
        Iterator<?> folds = crossValidator.iterator();
        while (folds.hasNext()) {
            CrossValidator.Fold fold = (CrossValidator.Fold) folds.next();
            LibSvmModel model = new LibSvmLearner(new RBFKernel(c, gamma)).m2train(fold.getTrain());
            accuracySum += new ConfusionMatrixEvaluator().evaluate(new LibSvmClassifier(), model, fold.getTest()).getAccuracy();
        }
        return accuracySum / this.numFolds;
    }
}
