package ws.palladian.classification.text.vector;

import de.bwaldvogel.liblinear.Parameter;
import java.util.Objects;
import java.util.Set;
import ws.palladian.classification.liblinear.LibLinearClassifier;
import ws.palladian.classification.liblinear.LibLinearLearner;
import ws.palladian.classification.liblinear.LibLinearModel;
import ws.palladian.classification.utils.NoNormalizer;
import ws.palladian.core.AbstractLearner;
import ws.palladian.core.CategoryEntries;
import ws.palladian.core.Classifier;
import ws.palladian.core.FeatureVector;
import ws.palladian.core.Learner;
import ws.palladian.core.Model;
import ws.palladian.core.dataset.Dataset;
import ws.palladian.extraction.text.vector.ITextVectorizer;
import ws.palladian.helper.functional.Predicates;

/* loaded from: input_file:ws/palladian/classification/text/vector/TextVectorClassifier.class */
public class TextVectorClassifier<M extends Model> extends AbstractLearner<TextVectorModel<M>> implements Classifier<TextVectorModel<M>> {
    private final ITextVectorizer vectorizer;
    private final Learner<M> learner;
    private final Classifier<M> classifier;

    /* loaded from: input_file:ws/palladian/classification/text/vector/TextVectorClassifier$TextVectorModel.class */
    public static class TextVectorModel<M extends Model> implements Model {
        private static final long serialVersionUID = 1;
        private final M model;

        public TextVectorModel(M m) {
            this.model = m;
        }

        public Set<String> getCategories() {
            return this.model.getCategories();
        }
    }

    public static TextVectorClassifier<LibLinearModel> libLinear(ITextVectorizer iTextVectorizer) {
        return new TextVectorClassifier<>(iTextVectorizer, new LibLinearLearner(new NoNormalizer()), new LibLinearClassifier());
    }

    public static TextVectorClassifier<LibLinearModel> libLinear(ITextVectorizer iTextVectorizer, Parameter parameter) {
        return new TextVectorClassifier<>(iTextVectorizer, new LibLinearLearner(parameter, 1.0d, new NoNormalizer()), new LibLinearClassifier());
    }

    public TextVectorClassifier(ITextVectorizer iTextVectorizer, Learner<M> learner, Classifier<M> classifier) {
        this.vectorizer = iTextVectorizer;
        this.learner = learner;
        this.classifier = classifier;
    }

    /* renamed from: train, reason: merged with bridge method [inline-methods] */
    public TextVectorModel<M> m16train(Dataset dataset) {
        return new TextVectorModel<>(this.learner.train(dataset.transform(this.vectorizer).filterFeatures(Predicates.not(Predicates.equal("text")))));
    }

    public CategoryEntries classify(FeatureVector featureVector, TextVectorModel<M> textVectorModel) {
        return this.classifier.classify(this.vectorizer.apply(featureVector), ((TextVectorModel) textVectorModel).model);
    }

    public String toString() {
        return String.format("%s [learner=%s, vectorizer=%s]", getClass().getSimpleName(), this.learner, this.vectorizer);
    }
}
