package pl.edu.icm.yadda.analysis.hmm.probability.decisiontree;

import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Set;
import pl.edu.icm.yadda.analysis.hmm.features.FeatureVector;
import pl.edu.icm.yadda.analysis.hmm.training.HMMTrainingElement;
import pl.edu.icm.yadda.analysis.textr.tools.ProbabilityDistribution;

/* loaded from: input_file:WEB-INF/lib/yadda-analysis-impl-1.10.0.jar:pl/edu/icm/yadda/analysis/hmm/probability/decisiontree/DecisionTreeBuilder.class */
/**
 * Builds binary decision trees from labelled training elements whose observations are
 * {@link FeatureVector}s.
 *
 * <p>Every node stores the distribution of the labels that reached it. An inner node additionally
 * tests one feature against a numeric cut: elements whose feature value is {@code <= cut} go to the
 * left subtree, all others to the right one. A feature that has been used for a split is not
 * offered again to the subtrees below that split.
 */
public class DecisionTreeBuilder {
    /** Training-set size below which a node is turned into a leaf when the caller gives no limit. */
    private static final int DEFAULT_STOP_EXPANDING = 20;

    /** A candidate split: the name of the tested feature and the threshold it is compared with. */
    public static class NodeDecision {
        String testedFeature;
        double cut;

        /**
         * @param str name of the feature this decision tests
         * @param d   threshold the feature value is compared with
         */
        public NodeDecision(String str, double d) {
            this.testedFeature = str;
            this.cut = d;
        }

        /** @return {@code true} when the tested feature of {@code featureVector} is {@code <= cut} */
        public boolean isLeft(FeatureVector featureVector) {
            return featureVector.getFeature(this.testedFeature) <= this.cut;
        }

        /** @return {@code true} when the tested feature of {@code featureVector} is {@code > cut} */
        public boolean isRight(FeatureVector featureVector) {
            return featureVector.getFeature(this.testedFeature) > this.cut;
        }
    }

    /**
     * Builds a tree using the default stop-expanding limit.
     *
     * @param set  training elements
     * @param set2 names of the features that may be tested
     * @return the root of the tree, or {@code null} when {@code set} is empty
     */
    public static <T extends Comparable> DecisionTree<T> buildDecisionTree(Set<HMMTrainingElement<T, FeatureVector>> set, Set<String> set2) {
        return buildDecisionTree(set, set2, DEFAULT_STOP_EXPANDING);
    }

    /**
     * Builds a tree with an explicit stop-expanding limit.
     *
     * @param set  training elements
     * @param set2 names of the features that may be tested
     * @param i    a node reached by fewer than {@code i} training elements becomes a leaf
     * @return the root of the tree, or {@code null} when {@code set} is empty
     */
    public static <T extends Comparable> SimpleDecisionTree<T> buildDecisionTree(Set<HMMTrainingElement<T, FeatureVector>> set, Set<String> set2, int i) {
        return constructNode(set, set2, i);
    }

    /**
     * Recursively builds the subtree for the given training elements.
     *
     * <p>A leaf is produced when only one distinct label is present, when no feature is left,
     * when fewer than {@code minSize} elements reached the node, or when no usable split exists.
     *
     * @return the subtree root, or {@code null} when {@code trainingSet} is empty
     */
    // ProbabilityDistribution is used raw: its type parameters are not visible from this file.
    @SuppressWarnings({"rawtypes", "unchecked"})
    private static <T extends Comparable> SimpleDecisionTree<T> constructNode(Set<HMMTrainingElement<T, FeatureVector>> trainingSet, Set<String> features, int minSize) {
        if (trainingSet.isEmpty()) {
            return null;
        }
        ProbabilityDistribution labelDistribution = new ProbabilityDistribution();
        for (HMMTrainingElement<T, FeatureVector> element : trainingSet) {
            labelDistribution.addEvent(element.getLabel());
        }
        if (labelDistribution.getEvents().size() == 1 || features.isEmpty() || trainingSet.size() < minSize) {
            return new SimpleDecisionTree<>(labelDistribution);
        }
        NodeDecision decision = chooseDecision(trainingSet, features);
        if (decision == null) {
            return new SimpleDecisionTree<>(labelDistribution);
        }

        // The feature chosen here is withheld from both subtrees.
        Set<String> remainingFeatures = new HashSet<>(features);
        remainingFeatures.remove(decision.testedFeature);

        Set<HMMTrainingElement<T, FeatureVector>> leftElements = new HashSet<>();
        Set<HMMTrainingElement<T, FeatureVector>> rightElements = new HashSet<>();
        for (HMMTrainingElement<T, FeatureVector> element : trainingSet) {
            if (decision.isLeft(element.getObservation())) {
                leftElements.add(element);
            } else {
                rightElements.add(element);
            }
        }
        return new SimpleDecisionTree<>(labelDistribution, constructNode(leftElements, remainingFeatures, minSize), constructNode(rightElements, remainingFeatures, minSize), decision.testedFeature, decision.cut);
    }

    /**
     * Picks the feature and cut whose split yields the lowest size-weighted sum of the label
     * entropies of the two sides.
     *
     * <p>For each feature the elements are sorted by feature value (ties broken by label). A cut is
     * only evaluated between two neighbouring elements that carry different labels. Elements
     * sharing one feature value are always moved to the left side together, so a cut never
     * separates equal values.
     *
     * @return the best decision found, or {@code null} when no candidate cut exists
     */
    // ProbabilityDistribution and the label's Comparable are used raw, as in the method signature.
    @SuppressWarnings({"rawtypes", "unchecked"})
    private static <T extends Comparable> NodeDecision chooseDecision(Set<HMMTrainingElement<T, FeatureVector>> trainingSet, Set<String> features) {
        String bestFeature = null;
        double bestCut = -1.0d;
        double bestEntropy = 0.0d;
        ArrayList<HMMTrainingElement<T, FeatureVector>> elements = new ArrayList<>(trainingSet);
        final int size = elements.size();

        for (final String feature : features) {
            Collections.sort(elements, new Comparator<HMMTrainingElement<T, FeatureVector>>() {
                @Override
                public int compare(HMMTrainingElement<T, FeatureVector> first, HMMTrainingElement<T, FeatureVector> second) {
                    int byValue = Double.compare(first.getObservation().getFeature(feature), second.getObservation().getFeature(feature));
                    if (byValue != 0) {
                        return byValue;
                    }
                    return ((Comparable) first.getLabel()).compareTo(second.getLabel());
                }
            });

            // Initially every label is on the right side; labels migrate left as the scan advances.
            ProbabilityDistribution leftLabels = new ProbabilityDistribution();
            ProbabilityDistribution rightLabels = new ProbabilityDistribution();
            for (HMMTrainingElement<T, FeatureVector> element : elements) {
                rightLabels.addEvent(element.getLabel());
            }

            int leftCount = 0;
            for (int idx = 0; idx < size - 1; idx++) {
                HMMTrainingElement<T, FeatureVector> current = elements.get(idx);
                HMMTrainingElement<T, FeatureVector> next = elements.get(idx + 1);

                if (leftCount <= idx) {
                    // Move the whole run of elements sharing this feature value to the left side.
                    double runValue = elements.get(leftCount).getObservation().getFeature(feature);
                    while (leftCount < size && elements.get(leftCount).getObservation().getFeature(feature) == runValue) {
                        HMMTrainingElement<T, FeatureVector> moved = elements.get(leftCount);
                        leftLabels.addEvent(moved.getLabel());
                        rightLabels.removeEvent(moved.getLabel());
                        leftCount++;
                    }
                }

                // Labels are compared by value, not by reference, so that equal labels held in
                // distinct objects are not mistaken for a class boundary.
                if (sameLabel(current.getLabel(), next.getLabel())) {
                    continue;
                }

                double entropy = ((leftLabels.getEntropy() * leftCount) / size) + ((rightLabels.getEntropy() * (size - leftCount)) / size);
                if (bestFeature == null || entropy < bestEntropy) {
                    double currentValue = current.getObservation().getFeature(feature);
                    double nextValue = next.getObservation().getFeature(feature);
                    // Skip the candidate when both neighbours sit in the run of the largest value:
                    // nothing would remain on the right side.
                    if (currentValue != nextValue || currentValue != elements.get(size - 1).getObservation().getFeature(feature)) {
                        bestFeature = feature;
                        bestCut = (currentValue + nextValue) / 2.0d;
                        bestEntropy = entropy;
                    }
                }
            }
        }

        if (bestFeature == null) {
            return null;
        }
        return new NodeDecision(bestFeature, bestCut);
    }

    /** Null-safe value equality of two labels. */
    private static boolean sameLabel(Object first, Object second) {
        return first == second || (first != null && first.equals(second));
    }
}
