package pl.edu.icm.cermine.tools.classification.hmm.model;

import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Set;
import pl.edu.icm.cermine.structure.tools.ProbabilityDistribution;
import pl.edu.icm.cermine.tools.classification.features.FeatureVector;

/**
 * Builds binary decision trees over the feature vectors of HMM training samples.
 *
 * <p>Every node stores the label distribution of the samples that reach it. An internal node
 * additionally tests a single feature against a threshold: samples whose value is
 * {@code <= cut} go to the left subtree, all others go to the right subtree.
 */
public final class DecisionTreeBuilder {

    /** Default minimum number of samples a node must contain before it may be split. */
    public static final int DEFAULT_STOP_EXPANDING = 20;

    /** A single threshold test on one feature. */
    public static class NodeDecision {
        private final String testedFeature;
        private final double cut;

        public NodeDecision(String testedFeature, double cut) {
            this.testedFeature = testedFeature;
            this.cut = cut;
        }

        /** Returns whether the vector's tested value is {@code <= cut} (false for NaN). */
        public boolean isLeft(FeatureVector featureVector) {
            return featureVector.getValue(this.testedFeature) <= this.cut;
        }

        /** Returns whether the vector's tested value is {@code > cut} (false for NaN). */
        public boolean isRight(FeatureVector featureVector) {
            return featureVector.getValue(this.testedFeature) > this.cut;
        }
    }

    private DecisionTreeBuilder() {
    }

    /**
     * Builds a decision tree using {@link #DEFAULT_STOP_EXPANDING} as the minimum size of a
     * node that may still be split.
     *
     * @param set training samples
     * @param list names of the features that may be tested
     * @return the root node, or {@code null} if {@code set} is empty
     */
    public static <S extends Comparable<S>> DecisionTree<S> buildDecisionTree(Set<HMMTrainingSample<S>> set, List<String> list) {
        return buildDecisionTree(set, list, DEFAULT_STOP_EXPANDING);
    }

    /**
     * Builds a decision tree.
     *
     * @param set training samples
     * @param list names of the features that may be tested
     * @param i minimum number of samples a node must contain to be split; smaller nodes
     *          become leaves
     * @return the root node, or {@code null} if {@code set} is empty
     */
    public static <S extends Comparable<S>> DecisionTree<S> buildDecisionTree(Set<HMMTrainingSample<S>> set, List<String> list, int i) {
        return constructNode(set, list, i);
    }

    /**
     * Recursively builds the subtree for {@code samples}. A node becomes a leaf when its samples
     * share one label, no features remain, it holds fewer than {@code stopExpanding} samples,
     * or no valid split exists.
     */
    private static <S extends Comparable<S>> DecisionTree<S> constructNode(
            Set<HMMTrainingSample<S>> samples, List<String> features, int stopExpanding) {
        if (samples.isEmpty()) {
            return null;
        }
        ProbabilityDistribution labels = new ProbabilityDistribution();
        for (HMMTrainingSample<S> sample : samples) {
            labels.addEvent(sample.getLabel());
        }
        if (labels.getEvents().size() == 1 || features.isEmpty() || samples.size() < stopExpanding) {
            return new DecisionTree<>(labels);
        }
        NodeDecision decision = chooseDecision(samples, features);
        if (decision == null) {
            return new DecisionTree<>(labels);
        }
        // A feature is tested at most once on any root-to-leaf path.
        List<String> remainingFeatures = new ArrayList<>(features);
        remainingFeatures.remove(decision.testedFeature);

        Set<HMMTrainingSample<S>> leftSamples = new HashSet<>();
        Set<HMMTrainingSample<S>> rightSamples = new HashSet<>();
        for (HMMTrainingSample<S> sample : samples) {
            if (decision.isLeft(sample.getObservation())) {
                leftSamples.add(sample);
            } else {
                rightSamples.add(sample);
            }
        }
        return new DecisionTree<>(labels,
                constructNode(leftSamples, remainingFeatures, stopExpanding),
                constructNode(rightSamples, remainingFeatures, stopExpanding),
                decision.testedFeature, decision.cut);
    }

    /**
     * Chooses the (feature, threshold) pair that minimises the size-weighted label entropy of
     * the two resulting partitions.
     *
     * <p>For every feature the samples are sorted by value. Each position where the value
     * strictly increases is evaluated as a cut, which guarantees that both sides are
     * non-empty. Ties on entropy keep the earliest candidate (features in list order).
     *
     * @return the best split, or {@code null} if no feature separates the samples
     */
    private static <S extends Comparable<S>> NodeDecision chooseDecision(
            Set<HMMTrainingSample<S>> samples, List<String> features) {
        String bestFeature = null;
        double bestCut = -1.0d;
        double bestEntropy = 0.0d;

        List<HMMTrainingSample<S>> sorted = new ArrayList<>(samples);
        int total = sorted.size();
        for (final String feature : features) {
            Collections.sort(sorted, new Comparator<HMMTrainingSample<S>>() {
                @Override
                public int compare(HMMTrainingSample<S> first, HMMTrainingSample<S> second) {
                    return Double.compare(first.getObservation().getValue(feature),
                            second.getObservation().getValue(feature));
                }
            });

            ProbabilityDistribution left = new ProbabilityDistribution();
            ProbabilityDistribution right = new ProbabilityDistribution();
            for (HMMTrainingSample<S> sample : sorted) {
                right.addEvent(sample.getLabel());
            }

            for (int k = 0; k < total - 1; k++) {
                HMMTrainingSample<S> current = sorted.get(k);
                left.addEvent(current.getLabel());
                right.removeEvent(current.getLabel());

                double lower = current.getObservation().getValue(feature);
                double upper = sorted.get(k + 1).getObservation().getValue(feature);
                // Cut only between distinct values, so equal values land on the same side.
                // NaN is never "less than", so NaN values never act as a bound.
                if (!(lower < upper)) {
                    continue;
                }

                int leftCount = k + 1;
                double entropy = ((left.getEntropy() * leftCount) / total)
                        + ((right.getEntropy() * (total - leftCount)) / total);
                if (bestFeature == null || entropy < bestEntropy) {
                    double cut = (lower + upper) / 2.0d;
                    // Overflow or rounding may push the midpoint out of [lower, upper);
                    // fall back to 'lower', which still induces the evaluated partition.
                    if (!(cut >= lower && cut < upper)) {
                        cut = lower;
                    }
                    bestFeature = feature;
                    bestCut = cut;
                    bestEntropy = entropy;
                }
            }
        }
        if (bestFeature == null) {
            return null;
        }
        return new NodeDecision(bestFeature, bestCut);
    }
}
