package ws.palladian.classification.language;

import java.io.IOException;
import java.util.Set;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import ws.palladian.classification.text.DictionaryModel;
import ws.palladian.classification.text.FeatureSetting;
import ws.palladian.classification.text.FeatureSettingBuilder;
import ws.palladian.classification.text.PalladianTextClassifier;
import ws.palladian.classification.text.evaluation.Dataset;
import ws.palladian.classification.text.evaluation.TextDatasetIterator;
import ws.palladian.core.Category;
import ws.palladian.core.CategoryEntries;
import ws.palladian.core.CategoryEntriesBuilder;
import ws.palladian.helper.StopWatch;
import ws.palladian.helper.constants.Language;
import ws.palladian.helper.io.FileHelper;

/* loaded from: input_file:ws/palladian/classification/language/PalladianLangDetect.class */
/**
 * Language classifier backed by a Palladian {@link DictionaryModel} that is loaded from disk.
 *
 * <p>Classification results can optionally be restricted to a caller-supplied set of category
 * names (see {@link #setPossibleClasses(Set)}).
 */
public class PalladianLangDetect implements LanguageClassifier {
    private static final Logger LOGGER = LoggerFactory.getLogger(PalladianLangDetect.class);

    /** Classifier configured with the feature setting stored in {@link #dictionaryModel}. */
    private final PalladianTextClassifier textClassifier;

    /** The trained model that is deserialized in the constructor. */
    private final DictionaryModel dictionaryModel;

    /**
     * When non-null, only categories whose name is contained in this set are kept in results.
     * The reference is stored as given (not copied).
     */
    private Set<String> possibleClasses = null;

    /**
     * Loads a previously trained model and prepares a classifier for it.
     *
     * <p>NOTE(review): {@code FileHelper.deserialize} presumably uses Java native deserialization;
     * only load model files from trusted sources.
     *
     * @param str path of the serialized {@link DictionaryModel}
     * @throws IllegalStateException if the model cannot be read (wraps the {@link IOException})
     */
    public PalladianLangDetect(String str) {
        try {
            this.dictionaryModel = (DictionaryModel) FileHelper.deserialize(str);
            // The classifier must use the same feature setting the model was trained with.
            this.textClassifier = new PalladianTextClassifier(this.dictionaryModel.getFeatureSetting());
        } catch (IOException e) {
            throw new IllegalStateException("Could not deserialize model from \"" + str + "\"", e);
        }
    }

    /**
     * Returns the set used to restrict classification results.
     *
     * @return the set passed to {@link #setPossibleClasses(Set)}, or {@code null} if results are
     *     not restricted
     */
    public Set<String> getPossibleClasses() {
        return this.possibleClasses;
    }

    /**
     * Restricts classification results to the given category names. Pass {@code null} to disable
     * the restriction. The set is stored by reference, so later changes to it take effect.
     *
     * @param set allowed category names, or {@code null}
     */
    public void setPossibleClasses(Set<String> set) {
        this.possibleClasses = set;
    }

    /**
     * Trains a model with the default feature setting and serializes it to disk.
     *
     * @param dataset the training data
     * @param str the model name; used as the file name (with a {@code .gz} suffix appended)
     * @param str2 the output directory; concatenated directly in front of {@code str}, so it
     *     must end with a path separator
     * @throws IllegalStateException if the model cannot be written
     */
    public static void train(Dataset dataset, String str, String str2) {
        train(dataset, str, str2, null);
    }

    /**
     * Trains a model and serializes it to {@code str2 + str + ".gz"}.
     *
     * @param dataset the training data
     * @param str the model name; used as the file name (with a {@code .gz} suffix appended)
     * @param str2 the output directory; concatenated directly in front of {@code str}, so it
     *     must end with a path separator
     * @param featureSetting the feature setting to use, or {@code null} for the default
     *     {@code FeatureSettingBuilder.chars(4, 7)} setting (presumably character n-grams of
     *     length 4 to 7)
     * @throws IllegalStateException if the model cannot be written (wraps the {@link IOException})
     */
    public static void train(Dataset dataset, String str, String str2, FeatureSetting featureSetting) {
        StopWatch stopWatch = new StopWatch();
        String str3 = str2 + str + ".gz";
        try {
            FeatureSetting effectiveSetting =
                    featureSetting != null ? featureSetting : FeatureSettingBuilder.chars(4, 7).m40create();
            PalladianTextClassifier classifier = new PalladianTextClassifier(effectiveSetting);
            FileHelper.serialize(
                    classifier.train((ws.palladian.core.dataset.Dataset) new TextDatasetIterator(dataset)),
                    str3);
            LOGGER.info("finished training classifier in {}", stopWatch.getElapsedTimeString());
        } catch (IOException e) {
            throw new IllegalStateException("Error while serializing to \"" + str3 + "\".", e);
        }
    }

    /**
     * Classifies the text and maps the most likely category to a {@link Language}.
     *
     * <p>Assumes the model's category names are ISO 639-1 codes. NOTE(review): if
     * {@link #setPossibleClasses(Set)} filters out every category, the result depends on how
     * {@code getMostLikelyCategory()} and {@code Language.getByIso6391} handle an empty result;
     * confirm.
     *
     * @param str the text to classify
     * @return the detected language
     */
    @Override
    public Language classify(String str) {
        return Language.getByIso6391(classifyAsCategoryEntry(str).getMostLikelyCategory());
    }

    /**
     * Classifies the text and returns all category scores. If a restriction is configured via
     * {@link #setPossibleClasses(Set)}, only the allowed categories are included.
     *
     * @param str the text to classify
     * @return the (possibly narrowed) category entries
     */
    public CategoryEntries classifyAsCategoryEntry(String str) {
        return narrowCategories(this.textClassifier.classify(str, this.dictionaryModel));
    }

    /**
     * Keeps only the entries whose category name is in {@link #possibleClasses}.
     *
     * <p>Probabilities are copied unchanged; this method does not renormalize them itself.
     *
     * @param categoryEntries the unfiltered entries
     * @return the input unchanged if no restriction is set, otherwise a newly built filtered copy
     */
    private CategoryEntries narrowCategories(CategoryEntries categoryEntries) {
        if (this.possibleClasses == null) {
            return categoryEntries;
        }
        CategoryEntriesBuilder categoryEntriesBuilder = new CategoryEntriesBuilder();
        for (Category category : categoryEntries) {
            if (this.possibleClasses.contains(category.getName())) {
                categoryEntriesBuilder.set(category.getName(), category.getProbability());
            }
        }
        return categoryEntriesBuilder.m70create();
    }

    /**
     * Trains and writes a language model.
     *
     * <p>Optional arguments: {@code [datasetPath [modelName [modelDirectory]]]}. Each omitted
     * argument falls back to the previously hard-coded default.
     *
     * @param strArr command-line arguments
     * @throws IOException declared for compatibility with earlier versions
     */
    public static void main(String[] strArr) throws IOException {
        String datasetPath = strArr.length > 0
                ? strArr[0]
                : "H:\\PalladianData\\Datasets\\JRCLanguageCorpus\\indexAll22Languages_ipc20.txt";
        String modelName = strArr.length > 1 ? strArr[1] : "jrc22Languages20ipc";
        String modelDirectory = strArr.length > 2 ? strArr[2] : "data/models/palladian/language/";

        Dataset dataset = new Dataset();
        dataset.setPath(datasetPath);
        // Previously this was set twice (before and after setPath); once is enough.
        dataset.setFirstFieldLink(true);
        dataset.setSeparationString(" ");
        train(dataset, modelName, modelDirectory);
    }
}
