package ws.palladian.classification.numeric;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import org.apache.commons.lang3.Validate;
import org.apache.commons.lang3.tuple.Pair;
import ws.palladian.core.Category;
import ws.palladian.core.CategoryEntries;
import ws.palladian.core.CategoryEntriesBuilder;
import ws.palladian.core.Classifier;
import ws.palladian.core.FeatureVector;
import ws.palladian.helper.collection.CollectionHelper;
import ws.palladian.helper.collection.EntryValueComparator;
import ws.palladian.helper.collection.FixedSizePriorityQueue;

/* loaded from: input_file:ws/palladian/classification/numeric/KnnClassifier.class */
/**
 * Nearest-neighbour classifier operating on a {@link KnnModel}.
 *
 * <p>For every training example in the model the distance to the instance being classified is
 * computed; the examples retained by a size-bounded priority queue then vote for their category
 * with a weight of {@code 1 / (distance + 1e-9)}.
 *
 * <p>Instances hold only final configuration and no mutable state.
 */
public final class KnnClassifier implements Classifier<KnnModel> {

    /** Added to every distance before inverting it, so a distance of zero does not divide by zero. */
    private static final double DISTANCE_EPSILON = 1.0E-9d;

    /** Default number of neighbours used by the no-argument constructor. */
    private static final int DEFAULT_K = 3;

    /** Capacity of the neighbour queue used by {@link #classify(FeatureVector, KnnModel)}. */
    private final int k;

    /** Whether the string vector is passed to the distance computation in addition to the numeric one. */
    private final boolean useTextualFeatures;

    /**
     * Creates a classifier that only uses the numeric vector for distance computation.
     *
     * @param i number of neighbours (k); must be greater than zero
     * @throws IllegalArgumentException if {@code i} is not greater than zero
     */
    public KnnClassifier(int i) {
        this(i, false);
    }

    /**
     * Creates a classifier.
     *
     * @param i number of neighbours (k); must be greater than zero
     * @param z {@code true} to pass the string vector to the distance computation as well
     * @throws IllegalArgumentException if {@code i} is not greater than zero
     */
    public KnnClassifier(int i, boolean z) {
        Validate.isTrue(i > 0, "k must be greater zero");
        this.k = i;
        this.useTextualFeatures = z;
    }

    /** Creates a classifier with k = 3 that only uses the numeric vector. */
    public KnnClassifier() {
        this(DEFAULT_K);
    }

    /**
     * Classifies the given feature vector using the configured k.
     *
     * @param featureVector the instance to classify
     * @param knnModel the model providing categories and training examples
     * @return category entries containing every category of the model; categories that received
     *         no vote keep the initial score of zero
     */
    @Override // ws.palladian.core.Classifier
    public CategoryEntries classify(FeatureVector featureVector, KnnModel knnModel) {
        return classify(featureVector, knnModel, this.k);
    }

    /**
     * Classifies the given feature vector with {@code i} neighbours instead of the configured k
     * and returns the names of the resulting categories.
     *
     * <p>Note that the returned strings are category names taken from the classification result,
     * not identifiers of individual training examples.
     *
     * @param featureVector the instance to classify
     * @param knnModel the model providing categories and training examples
     * @param i number of neighbours to use for this call; also the upper index passed to
     *          {@code CollectionHelper.getSublist} (presumably limiting the result to at most
     *          {@code i} names - confirm against that helper)
     * @return category names in the iteration order of the classification result
     */
    public List<String> getNeighbors(FeatureVector featureVector, KnnModel knnModel, int i) {
        // The neighbour count is passed as an argument rather than by temporarily overwriting the
        // field: that approach left k modified when classification threw and was unsafe when one
        // classifier instance was shared between threads.
        List<String> categoryNames = new ArrayList<>();
        Iterator<Category> categories = classify(featureVector, knnModel, i).iterator();
        while (categories.hasNext()) {
            categoryNames.add(categories.next().getName());
        }
        return CollectionHelper.getSublist(categoryNames, 0, i);
    }

    @Override
    public String toString() {
        return getClass().getSimpleName() + " (k=" + this.k + ")";
    }

    /** Runs the distance-weighted vote using a neighbour queue of capacity {@code neighborCount}. */
    private CategoryEntries classify(FeatureVector featureVector, KnnModel knnModel, int neighborCount) {
        // Start with every known category at score zero so all of them appear in the result.
        CategoryEntriesBuilder builder = new CategoryEntriesBuilder().set(knnModel.getCategories(), 0.0d);
        double[] numericVector = knnModel.getNormalizedVectorForClassification(featureVector);
        String[] stringVector = knnModel.getStringVectorForClassification(featureVector);

        // NOTE(review): presumably the DESCENDING comparator makes the bounded queue keep the
        // entries with the smallest distances - confirm against FixedSizePriorityQueue.
        FixedSizePriorityQueue<Pair<String, Double>> neighbors =
                new FixedSizePriorityQueue<>(neighborCount, new EntryValueComparator<>(CollectionHelper.Order.DESCENDING));
        for (TrainingExample example : knnModel.getTrainingExamples()) {
            double distance = this.useTextualFeatures
                    ? example.distance(numericVector, stringVector)
                    : example.distance(numericVector);
            neighbors.add(Pair.of(example.category, distance));
        }

        // Each retained neighbour votes for its category, weighted by inverse distance.
        for (Pair<String, Double> neighbor : neighbors.asList()) {
            builder.add(neighbor.getKey(), 1.0d / (neighbor.getValue() + DISTANCE_EPSILON));
        }
        // NOTE(review): "m76create" looks like a decompiler-renamed method (likely "create");
        // kept as-is because the builder's declaration is not visible here.
        return builder.m76create();
    }
}
