package weka.classifiers.trees;

import java.util.Collections;
import java.util.Enumeration;
import java.util.Vector;
import weka.classifiers.AbstractClassifier;
import weka.classifiers.Classifier;
import weka.classifiers.lazy.kstar.KStarConstants;
import weka.classifiers.meta.Bagging;
import weka.core.Capabilities;
import weka.core.Option;
import weka.core.OptionHandler;
import weka.core.RevisionUtils;
import weka.core.TechnicalInformation;
import weka.core.Utils;
import weka.core.WekaException;
import weka.gui.ProgrammaticProperty;
import weka.gui.knowledgeflow.KnowledgeFlowApp;

/* loaded from: input_file:weka/classifiers/trees/RandomForest.class */
/**
 * Bagged ensemble of {@link RandomTree}s ("random forest").
 *
 * <p>The base learner is fixed to {@link RandomTree}; bootstrap copies are always represented using
 * instance weights. Tree-level settings (number of features, maximum depth, tie breaking, seed,
 * debug, decimal places, batch size) are forwarded to the template tree held by this ensemble.
 * Optionally, the trees record impurity decreases so that an attribute-importance ranking can be
 * reported.
 */
public class RandomForest extends Bagging {
    static final long serialVersionUID = 1116839470751428698L;

    /** Whether trees should record impurity decreases so attribute importance can be reported. */
    protected boolean m_computeAttributeImportance;

    /**
     * Returns the default number of trees in the forest.
     *
     * @return 100
     */
    @Override // weka.classifiers.IteratedSingleClassifierEnhancer
    protected int defaultNumberOfIterations() {
        return 100;
    }

    /**
     * Creates a forest whose base learner is a {@link RandomTree} with capability checking disabled.
     * Bootstrap copies are represented using weights, and the default number of iterations is used.
     */
    public RandomForest() {
        RandomTree randomTree = new RandomTree();
        randomTree.setDoNotCheckCapabilities(true);
        super.setClassifier(randomTree);
        super.setRepresentCopiesUsingWeights(true);
        setNumIterations(defaultNumberOfIterations());
    }

    /**
     * Returns the capabilities of a freshly constructed {@link RandomTree}.
     *
     * @return the base learner's capabilities
     */
    @Override // weka.classifiers.SingleClassifierEnhancer, weka.classifiers.AbstractClassifier, weka.classifiers.Classifier, weka.core.CapabilitiesHandler
    public Capabilities getCapabilities() {
        return new RandomTree().getCapabilities();
    }

    /**
     * Returns the class name of the default base learner.
     *
     * @return {@code "weka.classifiers.trees.RandomTree"}
     */
    @Override // weka.classifiers.meta.Bagging, weka.classifiers.SingleClassifierEnhancer
    protected String defaultClassifierString() {
        return "weka.classifiers.trees.RandomTree";
    }

    /**
     * Returns the default options for the base learner.
     *
     * @return an array containing only {@code -do-not-check-capabilities}
     */
    @Override // weka.classifiers.SingleClassifierEnhancer
    protected String[] defaultClassifierOptions() {
        return new String[]{"-do-not-check-capabilities"};
    }

    /**
     * Returns a short description of the classifier, including the bibliographic reference.
     *
     * @return the description text
     */
    @Override // weka.classifiers.meta.Bagging
    public String globalInfo() {
        return "Class for constructing a forest of random trees.\n\nFor more information see: \n\n" + getTechnicalInformation().toString();
    }

    /**
     * Returns the bibliographic reference (Breiman, "Random Forests", Machine Learning 45(1), 2001).
     *
     * @return the technical information
     */
    @Override // weka.classifiers.meta.Bagging, weka.core.TechnicalInformationHandler
    public TechnicalInformation getTechnicalInformation() {
        TechnicalInformation technicalInformation = new TechnicalInformation(TechnicalInformation.Type.ARTICLE);
        technicalInformation.setValue(TechnicalInformation.Field.AUTHOR, "Leo Breiman");
        technicalInformation.setValue(TechnicalInformation.Field.YEAR, "2001");
        technicalInformation.setValue(TechnicalInformation.Field.TITLE, "Random Forests");
        technicalInformation.setValue(TechnicalInformation.Field.JOURNAL, "Machine Learning");
        technicalInformation.setValue(TechnicalInformation.Field.VOLUME, "45");
        technicalInformation.setValue(TechnicalInformation.Field.NUMBER, "1");
        technicalInformation.setValue(TechnicalInformation.Field.PAGES, "5-32");
        return technicalInformation;
    }

    /**
     * Sets the base learner, which must be a {@link RandomTree}.
     *
     * @param classifier the base learner
     * @throws IllegalArgumentException if {@code classifier} is not a {@link RandomTree}
     */
    @Override // weka.classifiers.SingleClassifierEnhancer
    @ProgrammaticProperty
    public void setClassifier(Classifier classifier) {
        if (!(classifier instanceof RandomTree)) {
            throw new IllegalArgumentException("RandomForest: Argument of setClassifier() must be a RandomTree.");
        }
        super.setClassifier(classifier);
    }

    /**
     * Only {@code true} is accepted: a random forest always represents bootstrap copies via weights.
     *
     * @param z must be {@code true}
     * @throws IllegalArgumentException if {@code z} is {@code false}
     */
    @Override // weka.classifiers.meta.Bagging
    @ProgrammaticProperty
    public void setRepresentCopiesUsingWeights(boolean z) {
        if (!z) {
            throw new IllegalArgumentException("RandomForest: Argument of setRepresentCopiesUsingWeights() must be true.");
        }
        super.setRepresentCopiesUsingWeights(z);
    }

    /** @return the tip text for the number-of-features property, taken from the base tree */
    public String numFeaturesTipText() {
        return ((RandomTree) getClassifier()).KValueTipText();
    }

    /** @return the number of randomly chosen features per split (the base tree's K value) */
    public int getNumFeatures() {
        return ((RandomTree) getClassifier()).getKValue();
    }

    /** @param i the number of randomly chosen features per split (forwarded as the tree's K value) */
    public void setNumFeatures(int i) {
        ((RandomTree) getClassifier()).setKValue(i);
    }

    /** @return the tip text for the compute-attribute-importance property */
    public String computeAttributeImportanceTipText() {
        return "Compute attribute importance via mean impurity decrease";
    }

    /**
     * Enables or disables collection of impurity decreases in the base tree. This affects only trees
     * built afterwards.
     *
     * @param z whether to compute attribute importance
     */
    public void setComputeAttributeImportance(boolean z) {
        this.m_computeAttributeImportance = z;
        ((RandomTree) getClassifier()).setComputeImpurityDecreases(z);
    }

    /** @return whether attribute importance is computed */
    public boolean getComputeAttributeImportance() {
        return this.m_computeAttributeImportance;
    }

    /** @return the tip text for the maximum-depth property, taken from the base tree */
    public String maxDepthTipText() {
        return ((RandomTree) getClassifier()).maxDepthTipText();
    }

    /** @return the maximum tree depth configured on the base tree */
    public int getMaxDepth() {
        return ((RandomTree) getClassifier()).getMaxDepth();
    }

    /** @param i the maximum tree depth, forwarded to the base tree */
    public void setMaxDepth(int i) {
        ((RandomTree) getClassifier()).setMaxDepth(i);
    }

    /** @return the tip text for the break-ties-randomly property, taken from the base tree */
    public String breakTiesRandomlyTipText() {
        return ((RandomTree) getClassifier()).breakTiesRandomlyTipText();
    }

    /** @return whether the base tree breaks ties randomly */
    public boolean getBreakTiesRandomly() {
        return ((RandomTree) getClassifier()).getBreakTiesRandomly();
    }

    /** @param z whether the base tree should break ties randomly */
    public void setBreakTiesRandomly(boolean z) {
        ((RandomTree) getClassifier()).setBreakTiesRandomly(z);
    }

    /** Sets the debug flag on both this ensemble and the base tree. */
    @Override // weka.classifiers.AbstractClassifier
    public void setDebug(boolean z) {
        super.setDebug(z);
        ((RandomTree) getClassifier()).setDebug(z);
    }

    /** Sets the number of decimal places on both this ensemble and the base tree. */
    @Override // weka.classifiers.AbstractClassifier
    public void setNumDecimalPlaces(int i) {
        super.setNumDecimalPlaces(i);
        ((RandomTree) getClassifier()).setNumDecimalPlaces(i);
    }

    /** Sets the batch size on both this ensemble and the base tree. */
    @Override // weka.classifiers.AbstractClassifier, weka.core.BatchPredictor
    public void setBatchSize(String str) {
        super.setBatchSize(str);
        ((RandomTree) getClassifier()).setBatchSize(str);
    }

    /** Sets the random seed on both this ensemble and the base tree. */
    @Override // weka.classifiers.RandomizableParallelIteratedSingleClassifierEnhancer, weka.core.Randomizable
    public void setSeed(int i) {
        super.setSeed(i);
        ((RandomTree) getClassifier()).setSeed(i);
    }

    /**
     * Returns a textual description of the model.
     *
     * <p>When attribute importance is enabled and available, the description ends with attributes
     * ranked by average impurity decrease (the class attribute is skipped).
     *
     * @return the model description, or a placeholder if no model has been built
     */
    @Override // weka.classifiers.meta.Bagging
    public String toString() {
        if (this.m_Classifiers == null) {
            return "RandomForest: No model built yet.";
        }
        StringBuilder sb = new StringBuilder("RandomForest\n\n");
        sb.append(super.toString());
        if (getComputeAttributeImportance()) {
            try {
                double[] nodeCounts = new double[this.m_data.numAttributes()];
                double[] impurityScores = computeAverageImpurityDecreasePerAttribute(nodeCounts);
                int[] sortedIndices = Utils.sort(impurityScores);
                sb.append("\n\nAttribute importance based on average impurity decrease (and number of nodes using that attribute)\n\n");
                // Walk the sorted indices backwards so the highest scores are listed first.
                for (int rank = sortedIndices.length - 1; rank >= 0; rank--) {
                    int attIndex = sortedIndices[rank];
                    if (attIndex != this.m_data.classIndex()) {
                        sb.append(Utils.doubleToString(impurityScores[attIndex], 10, getNumDecimalPlaces()))
                            .append(" (")
                            .append(Utils.doubleToString(nodeCounts[attIndex], 6, 0))
                            .append(")  ")
                            .append(this.m_data.attribute(attIndex).name())
                            .append("\n");
                    }
                }
            } catch (WekaException ignored) {
                // The importance section is best-effort: if the statistics are unavailable the
                // section is omitted rather than failing the whole model description.
            }
        }
        return sb.toString();
    }

    /**
     * Computes, per attribute, the impurity decrease averaged over all nodes (in all trees) that
     * split on that attribute.
     *
     * @param dArr optional output array. For each attribute, the number of nodes using it is added
     *     onto the existing value. If {@code null}, a fresh array is used internally.
     * @return per-attribute average impurity decrease (0 for attributes used by no node)
     * @throws WekaException if the forest has not been built, importance is disabled, or a tree did
     *     not record impurity statistics
     * @throws IllegalArgumentException if {@code dArr} is shorter than the number of attributes
     */
    public double[] computeAverageImpurityDecreasePerAttribute(double[] dArr) throws WekaException {
        if (this.m_Classifiers == null) {
            throw new WekaException("Classifier has not been built yet!");
        }
        if (!getComputeAttributeImportance()) {
            throw new WekaException("Stats for attribute importance have not been collected!");
        }
        int numAttributes = this.m_data.numAttributes();
        double[] nodeCounts = dArr;
        if (nodeCounts == null) {
            nodeCounts = new double[numAttributes];
        } else if (nodeCounts.length < numAttributes) {
            throw new IllegalArgumentException("RandomForest: node count array has length "
                + nodeCounts.length + " but " + numAttributes + " attributes are required.");
        }
        double[] impurityScores = new double[numAttributes];
        for (Classifier classifier : this.m_Classifiers) {
            double[][] treeDecreases = ((RandomTree) classifier).getImpurityDecreases();
            if (treeDecreases == null) {
                // Guards against a tree without recorded statistics, which would otherwise cause a
                // NullPointerException below.
                // NOTE(review): presumably this happens when importance was enabled only after the
                // forest was built -- confirm against RandomTree.getImpurityDecreases().
                throw new WekaException("Stats for attribute importance have not been collected!");
            }
            for (int att = 0; att < numAttributes; att++) {
                // Column 0 holds the summed impurity decrease, column 1 the node count.
                impurityScores[att] += treeDecreases[att][0];
                nodeCounts[att] += treeDecreases[att][1];
            }
        }
        for (int att = 0; att < numAttributes; att++) {
            if (nodeCounts[att] > 0.0) {
                impurityScores[att] /= nodeCounts[att];
            }
        }
        return impurityScores;
    }

    /**
     * Lists the forest-level options, followed by the options of the base tree.
     *
     * @return an enumeration of all available options
     */
    @Override // weka.classifiers.meta.Bagging, weka.classifiers.RandomizableParallelIteratedSingleClassifierEnhancer, weka.classifiers.ParallelIteratedSingleClassifierEnhancer, weka.classifiers.IteratedSingleClassifierEnhancer, weka.classifiers.SingleClassifierEnhancer, weka.classifiers.AbstractClassifier, weka.core.OptionHandler
    public Enumeration<Option> listOptions() {
        Vector<Option> options = new Vector<>();
        options.addElement(new Option("\tSize of each bag, as a percentage of the\n\ttraining set size. (default 100)", "P", 1, "-P"));
        options.addElement(new Option("\tCalculate the out of bag error.", "O", 0, "-O"));
        options.addElement(new Option("\tWhether to store out of bag predictions in internal evaluation object.", "store-out-of-bag-predictions", 0, "-store-out-of-bag-predictions"));
        options.addElement(new Option("\tWhether to output complexity-based statistics when out-of-bag evaluation is performed.", "output-out-of-bag-complexity-statistics", 0, "-output-out-of-bag-complexity-statistics"));
        options.addElement(new Option("\tPrint the individual classifiers in the output", "print", 0, "-print"));
        options.addElement(new Option("\tCompute and output attribute importance (mean impurity decrease method)", "attribute-importance", 0, "-attribute-importance"));
        options.addElement(new Option("\tNumber of iterations.\n\t(current value " + getNumIterations() + ")", "I", 1, "-I <num>"));
        options.addElement(new Option("\tNumber of execution slots.\n\t(default 1 - i.e. no parallelism)\n\t(use 0 to auto-detect number of cores)", "num-slots", 1, "-num-slots <num>"));
        options.addAll(Collections.list(((OptionHandler) getClassifier()).listOptions()));
        return options.elements();
    }

    /**
     * Returns the current settings as command-line options. The base tree's own
     * {@code -do-not-check-capabilities} flag is removed; the forest emits that flag itself based on
     * its own setting.
     *
     * @return the current option settings
     */
    @Override // weka.classifiers.meta.Bagging, weka.classifiers.RandomizableParallelIteratedSingleClassifierEnhancer, weka.classifiers.ParallelIteratedSingleClassifierEnhancer, weka.classifiers.IteratedSingleClassifierEnhancer, weka.classifiers.SingleClassifierEnhancer, weka.classifiers.AbstractClassifier, weka.core.OptionHandler
    public String[] getOptions() {
        Vector<String> options = new Vector<>();
        options.add("-P");
        options.add(String.valueOf(getBagSizePercent()));
        if (getCalcOutOfBag()) {
            options.add("-O");
        }
        if (getStoreOutOfBagPredictions()) {
            options.add("-store-out-of-bag-predictions");
        }
        if (getOutputOutOfBagComplexityStatistics()) {
            options.add("-output-out-of-bag-complexity-statistics");
        }
        if (getPrintClassifiers()) {
            options.add("-print");
        }
        if (getComputeAttributeImportance()) {
            options.add("-attribute-importance");
        }
        options.add("-I");
        options.add(String.valueOf(getNumIterations()));
        options.add("-num-slots");
        options.add(String.valueOf(getNumExecutionSlots()));
        if (getDoNotCheckCapabilities()) {
            options.add("-do-not-check-capabilities");
        }
        Vector<String> treeOptions = new Vector<>();
        Collections.addAll(treeOptions, ((OptionHandler) getClassifier()).getOptions());
        Option.deleteFlagString(treeOptions, "-do-not-check-capabilities");
        options.addAll(treeOptions);
        return options.toArray(new String[0]);
    }

    /**
     * Parses forest-level options. The remaining options are used to construct a new base
     * {@link RandomTree}.
     *
     * @param strArr the options; consumed options are removed from the array
     * @throws Exception if an option is invalid or unrecognised options remain
     */
    @Override // weka.classifiers.meta.Bagging, weka.classifiers.RandomizableParallelIteratedSingleClassifierEnhancer, weka.classifiers.ParallelIteratedSingleClassifierEnhancer, weka.classifiers.IteratedSingleClassifierEnhancer, weka.classifiers.SingleClassifierEnhancer, weka.classifiers.AbstractClassifier, weka.core.OptionHandler
    public void setOptions(String[] strArr) throws Exception {
        String bagSize = Utils.getOption('P', strArr);
        if (bagSize.length() != 0) {
            setBagSizePercent(Integer.parseInt(bagSize));
        } else {
            setBagSizePercent(100);
        }
        setCalcOutOfBag(Utils.getFlag('O', strArr));
        setStoreOutOfBagPredictions(Utils.getFlag("store-out-of-bag-predictions", strArr));
        setOutputOutOfBagComplexityStatistics(Utils.getFlag("output-out-of-bag-complexity-statistics", strArr));
        setPrintClassifiers(Utils.getFlag("print", strArr));
        setComputeAttributeImportance(Utils.getFlag("attribute-importance", strArr));
        String iterations = Utils.getOption('I', strArr);
        if (iterations.length() != 0) {
            setNumIterations(Integer.parseInt(iterations));
        } else {
            setNumIterations(defaultNumberOfIterations());
        }
        String slots = Utils.getOption("num-slots", strArr);
        if (slots.length() != 0) {
            setNumExecutionSlots(Integer.parseInt(slots));
        } else {
            setNumExecutionSlots(1);
        }
        RandomTree randomTree = (RandomTree) AbstractClassifier.forName(defaultClassifierString(), strArr);
        randomTree.setComputeImpurityDecreases(this.m_computeAttributeImportance);
        // Adopt the tree-level generic settings at the forest level...
        setDoNotCheckCapabilities(randomTree.getDoNotCheckCapabilities());
        setSeed(randomTree.getSeed());
        setDebug(randomTree.getDebug());
        setNumDecimalPlaces(randomTree.getNumDecimalPlaces());
        setBatchSize(randomTree.getBatchSize());
        // ...and always tell the tree itself not to check capabilities.
        randomTree.setDoNotCheckCapabilities(true);
        setClassifier(randomTree);
        Utils.checkForRemainingOptions(strArr);
    }

    /** @return the revision string */
    @Override // weka.classifiers.meta.Bagging, weka.classifiers.AbstractClassifier, weka.core.RevisionHandler
    public String getRevision() {
        return RevisionUtils.extract("$Revision: 13294 $");
    }

    /**
     * Command-line entry point.
     *
     * @param strArr the command-line options
     */
    public static void main(String[] strArr) {
        runClassifier(new RandomForest(), strArr);
    }
}
