package ws.palladian.classification.text.evaluation;

import java.util.Objects;
import org.apache.commons.lang3.Validate;
import ws.palladian.classification.evaluation.roc.RocCurves;
import ws.palladian.classification.text.BayesScorer;
import ws.palladian.classification.text.DictionaryModel;
import ws.palladian.classification.text.FeatureSetting;
import ws.palladian.classification.text.PalladianTextClassifier;
import ws.palladian.classification.text.PruningStrategies;
import ws.palladian.classification.text.evaluation.PalladianTextClassifierOptimizerConfig;
import ws.palladian.core.CategoryEntries;
import ws.palladian.core.Classifier;
import ws.palladian.helper.NoProgress;
import ws.palladian.helper.ProgressReporter;
import ws.palladian.helper.functional.Filter;
import ws.palladian.helper.io.FileHelper;

/* loaded from: input_file:ws/palladian/classification/text/evaluation/PalladianTextClassifierOptimizer.class */
/**
 * Evaluates every combination of feature setting, pruning strategy and scorer configured in a
 * {@link PalladianTextClassifierOptimizerConfig} and appends one semicolon-separated result line per
 * combination to a CSV file.
 *
 * @param <R> type of the evaluation result produced by the configured evaluator
 */
public final class PalladianTextClassifierOptimizer<R> {
    private final PalladianTextClassifierOptimizerConfig<R> config;

    /**
     * Creates an optimizer for the given configuration.
     *
     * @param config supplies the feature settings, pruning strategies, scorers, dictionary builder
     *        and evaluator; must not be {@code null}
     * @throws NullPointerException if {@code config} is {@code null}
     */
    public PalladianTextClassifierOptimizer(PalladianTextClassifierOptimizerConfig<R> config) {
        this.config = Objects.requireNonNull(config, "config must not be null");
    }

    /**
     * Runs the optimization.
     * <p>
     * For each feature setting, a dictionary model is trained once on {@code training} (using the
     * configured dictionary builder). For each pruning strategy, that trained model is wrapped in a
     * {@code PruningSimulatedDictionaryModel}. Each wrapped model is then evaluated on
     * {@code validation} with every configured scorer. Every evaluation appends one line
     * {@code featureSetting;scorer;pruningStrategy;<evaluator columns>;numTerms;numEntries} to
     * {@code resultCsv}. The header line is written just before the first result line of this run.
     * The file is appended to, so re-running against an existing file adds another header and more
     * rows.
     *
     * @param training data used to train the dictionary models; must not be {@code null}
     * @param validation data used to evaluate each combination; must not be {@code null}
     * @param resultCsv path of the CSV file to append results to; must not be empty
     * @param progress receives progress updates; may be {@code null}, in which case
     *        {@link NoProgress#INSTANCE} is used
     * @throws NullPointerException if {@code training}, {@code validation} or {@code resultCsv} is
     *         {@code null}
     * @throws IllegalArgumentException if {@code resultCsv} is empty
     */
    public void runOptimization(ws.palladian.core.dataset.Dataset training,
            ws.palladian.core.dataset.Dataset validation, String resultCsv, ProgressReporter progress) {
        Validate.notNull(training, "training must not be null");
        Validate.notNull(validation, "validation must not be null");
        Validate.notEmpty(resultCsv, "resultCsv must not be empty");
        ProgressReporter reporter = progress != null ? progress : NoProgress.INSTANCE;

        // One evaluation per feature setting x pruning strategy x scorer (the inner loops below).
        int numEvaluations = config.getFeatureSettings().size() * config.getPruningStrategies().size()
                * config.getScorers().size();
        reporter.startTask("Evaluating feature settings", numEvaluations);

        boolean headerWritten = false;
        for (FeatureSetting featureSetting : config.getFeatureSettings()) {
            // Training is the expensive step; do it once per feature setting and reuse the model.
            DictionaryModel trainedModel = new PalladianTextClassifier(featureSetting,
                    config.getDictionaryBuilder()).train(training);
            for (Filter<? super CategoryEntries> pruningStrategy : config.getPruningStrategies()) {
                // Always wrap the unpruned model; wrapping the previous wrapper would stack strategies.
                DictionaryModel prunedModel = new PruningSimulatedDictionaryModel(trainedModel, pruningStrategy);
                for (PalladianTextClassifier.Scorer scorer : config.getScorers()) {
                    PalladianTextClassifier classifier = new PalladianTextClassifier(featureSetting, scorer);
                    R result = config.getEvaluator().evaluate(classifier, prunedModel, validation);
                    if (!headerWritten) {
                        // The evaluator's header may depend on the result, so it is built lazily here.
                        FileHelper.appendFile(resultCsv, "featureSetting;scorer;pruningStrategy;"
                                + config.getEvaluator().getCsvHeader(result) + ";numTerms;numEntries\n");
                        headerWritten = true;
                    }
                    StringBuilder line = new StringBuilder();
                    line.append(featureSetting).append(';');
                    line.append(scorer).append(';');
                    line.append(pruningStrategy).append(';');
                    line.append(config.getEvaluator().getCsvLine(result)).append(';');
                    line.append(prunedModel.getNumUniqTerms()).append(';');
                    line.append(prunedModel.getNumEntries()).append('\n');
                    FileHelper.appendFile(resultCsv, line);
                    reporter.increment();
                }
            }
        }
        reporter.finishTask();
    }

    /**
     * Example run on the 20newsgroups-18828 index splits.
     *
     * @param args optional overrides: {@code args[0]} training index file, {@code args[1]} validation
     *        index file, {@code args[2]} result CSV path; the original hard-coded values are used
     *        for any argument that is not given
     */
    public static void main(String[] args) {
        String trainingIndex = args.length > 0 ? args[0]
                : "/Users/pk/Dropbox/Uni/Datasets/20newsgroups-18828/index_split1.txt";
        String validationIndex = args.length > 1 ? args[1]
                : "/Users/pk/Dropbox/Uni/Datasets/20newsgroups-18828/index_split2.txt";
        String resultCsv = args.length > 2 ? args[2] : "optimizationResult.csv";

        TextDatasetIterator training = new TextDatasetIterator(trainingIndex, " ", true);
        TextDatasetIterator validation = new TextDatasetIterator(validationIndex, " ", true);
        // NOTE(review): "true" is presumably the positive category for the ROC evaluation — confirm.
        PalladianTextClassifierOptimizerConfig.Builder builder = PalladianTextClassifierOptimizerConfig
                .withEvaluator(new RocCurves.RocCurvesEvaluator("true"));
        builder.setFeatureSettings(new FeatureSettingGenerator().chars(5, 8).words(1, 3).m41create());
        builder.setPruningStrategies(PruningStrategies.none(), PruningStrategies.termCount(2));
        builder.setScorers(new BayesScorer(BayesScorer.Options.LAPLACE, BayesScorer.Options.COMPLEMENT));
        builder.m43create().runOptimization(training, validation, resultCsv, null);
    }
}
