package net.sf.javaml.classification;

import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import net.sf.javaml.core.Dataset;
import net.sf.javaml.core.Instance;
import net.sf.javaml.core.kdtree.KDTree;
import net.sf.javaml.core.kdtree.KeyDuplicateException;
import net.sf.javaml.core.kdtree.KeySizeException;
import net.sf.javaml.tools.InstanceTools;

/* loaded from: input_file:net/sf/javaml/classification/KDtreeKNN.class */
/**
 * k-nearest-neighbour classifier that indexes its training data in a {@link KDTree}.
 *
 * <p>Each training instance is stored in the tree under the attribute vector produced by
 * {@link InstanceTools#array(Instance)}. At query time the tree is asked for the {@code k} nearest
 * stored instances, and their class labels are counted as votes.
 */
public class KDtreeKNN extends AbstractClassifier {
    private static final long serialVersionUID = 1560149339188819924L;

    /** Number of neighbours requested from the tree per query. */
    private final int k;

    /** Spatial index over the training instances; {@code null} until {@link #buildClassifier} runs. */
    private KDTree tree;

    /** Training data given to {@link #buildClassifier}; used to enumerate the class labels. */
    private Dataset training;

    /**
     * Creates a classifier that votes among the {@code i} nearest neighbours.
     *
     * @param i number of neighbours to consult (the {@code k} of k-NN)
     */
    public KDtreeKNN(int i) {
        this.k = i;
    }

    /**
     * Indexes every instance of {@code dataset} in a new KD-tree keyed on its attribute values.
     *
     * <p>If the tree rejects an instance with {@link KeyDuplicateException}, only that instance is
     * skipped and the build continues. Previously the first duplicate silently dropped every
     * instance after it. A {@link KeySizeException} still aborts the build and prints a stack
     * trace, as before.
     *
     * @param dataset the training data; also retained to enumerate class labels at query time
     */
    @Override
    public void buildClassifier(Dataset dataset) {
        this.training = dataset;
        try {
            this.tree = new KDTree(dataset.noAttributes());
            for (Instance instance : dataset) {
                try {
                    this.tree.insert(InstanceTools.array(instance), instance);
                } catch (KeyDuplicateException ignored) {
                    // Same coordinates as an instance already in the tree: skip just this one
                    // instead of abandoning the rest of the dataset.
                }
            }
        } catch (KeySizeException e) {
            e.printStackTrace();
        }
    }

    /**
     * Returns a min-max normalised neighbour vote score for every class of the training data.
     *
     * <p>The {@code k} nearest training instances are fetched from the tree, and each adds one
     * vote to its class. Votes are then rescaled so the least-voted class maps to 0.0 and the
     * most-voted class maps to 1.0.
     *
     * <p>When all classes have the same vote count, the range is zero. This includes a training
     * set with a single class. In that case every class scores 1.0, where the old formula
     * produced NaN.
     *
     * @param instance the instance to classify
     * @return map from class label to normalised score, or {@code null} if the tree lookup throws
     *     {@link IllegalArgumentException} or {@link KeySizeException} (the stack trace is printed)
     */
    @Override
    public Map<Object, Double> classDistribution(Instance instance) {
        Object[] nearest;
        try {
            nearest = this.tree.nearest(InstanceTools.array(instance), this.k);
        } catch (IllegalArgumentException e) {
            e.printStackTrace();
            return null;
        } catch (KeySizeException e) {
            e.printStackTrace();
            return null;
        }
        Map<Object, Double> votes = countVotes(nearest);
        normalize(votes);
        return votes;
    }

    /** Builds a vote table: 0 for every training class, plus 1 per neighbour carrying that label. */
    private Map<Object, Double> countVotes(Object[] nearest) {
        Map<Object, Double> votes = new HashMap<Object, Double>();
        for (Object label : this.training.classes()) {
            votes.put(label, Double.valueOf(0.0d));
        }
        for (Object neighbour : nearest) {
            Object label = ((Instance) neighbour).classValue();
            Double current = votes.get(label);
            // Defensive: a label absent from training.classes() used to NPE when unboxed.
            double previous = current == null ? 0.0d : current.doubleValue();
            votes.put(label, Double.valueOf(previous + 1.0d));
        }
        return votes;
    }

    /**
     * Rescales the values of {@code votes} in place to [0, 1] by min-max normalisation. If every
     * value is equal, the range is zero and each entry is set to 1.0 rather than NaN.
     */
    private static void normalize(Map<Object, Double> votes) {
        double min = Double.POSITIVE_INFINITY;
        double max = Double.NEGATIVE_INFINITY;
        for (Double value : votes.values()) {
            min = Math.min(min, value.doubleValue());
            max = Math.max(max, value.doubleValue());
        }
        double range = max - min;
        for (Map.Entry<Object, Double> entry : votes.entrySet()) {
            double scaled = range > 0.0d ? (entry.getValue().doubleValue() - min) / range : 1.0d;
            entry.setValue(Double.valueOf(scaled));
        }
    }
}
