package ws.palladian.classification.quickml;

import java.util.Iterator;
import java.util.Objects;
import java.util.Set;
import org.apache.commons.lang3.StringUtils;
import quickml.supervised.classifier.Classifier;
import quickml.supervised.ensembles.randomForest.randomDecisionForest.RandomDecisionForest;
import quickml.supervised.tree.decisionTree.DecisionTree;
import quickml.supervised.tree.decisionTree.nodes.DTCatBranch;
import quickml.supervised.tree.decisionTree.nodes.DTLeaf;
import quickml.supervised.tree.decisionTree.nodes.DTNumBranch;
import quickml.supervised.tree.decisionTree.valueCounters.ClassificationCounter;
import quickml.supervised.tree.nodes.Branch;
import quickml.supervised.tree.nodes.Node;
import ws.palladian.core.Model;

/* loaded from: input_file:ws/palladian/classification/quickml/QuickMlModel.class */
/**
 * Palladian {@link Model} that wraps a trained QuickML {@link Classifier} together with the set of class
 * labels supplied at construction time.
 * <p>
 * The tree structure can be inspected via {@link #traverseModel(TreeVisitor)} (also used by
 * {@link #toString()}). Only {@link RandomDecisionForest} and {@link DecisionTree} classifiers are supported
 * there; any other classifier type leads to an {@link IllegalStateException}.
 */
public class QuickMlModel implements Model {
    private static final long serialVersionUID = 1;
    private final Classifier classifier;
    private final Set<String> classes;

    /**
     * {@link TreeVisitor} that renders every visited tree as indented text, one node per line, indented by
     * the node's depth using tab characters.
     */
    private static final class ToStringVistor implements TreeVisitor {
        private final StringBuilder builder = new StringBuilder();
        private int treeCount = 0;

        @Override // ws.palladian.classification.quickml.TreeVisitor
        public void tree(DecisionTree tree) {
            // Separate consecutive trees (e.g. the members of a forest) by an empty line.
            if (treeCount > 0) {
                builder.append('\n');
            }
            treeCount++;
            builder.append("Tree ").append(treeCount).append(":\n");
        }

        @Override // ws.palladian.classification.quickml.TreeVisitor
        public void categoricalBranch(DTCatBranch branch, boolean trueChild) {
            appendBranch(branch, trueChild);
        }

        @Override // ws.palladian.classification.quickml.TreeVisitor
        public void numericalBranch(DTNumBranch branch, boolean trueChild) {
            appendBranch(branch, trueChild);
        }

        @Override // ws.palladian.classification.quickml.TreeVisitor
        public void leaf(DTLeaf leaf, boolean trueChild) {
            builder.append(getIndent(leaf.getDepth())).append(trueChild).append(": ").append(leaf).append('\n');
        }

        /** Returns {@code depth} tab characters. */
        private static String getIndent(int depth) {
            return StringUtils.repeat('\t', depth);
        }

        /**
         * Appends one line for a branch node. The "true: " / "false: " prefix is omitted for branches at
         * depth 0.
         */
        private void appendBranch(Branch<?> branch, boolean trueChild) {
            int depth = branch.getDepth();
            builder.append(getIndent(depth));
            if (depth > 0) {
                builder.append(trueChild).append(": ");
            }
            builder.append(branch).append('\n');
        }

        @Override
        public String toString() {
            return builder.toString();
        }
    }

    /**
     * Creates a new model.
     *
     * @param classifier the trained QuickML classifier to wrap.
     * @param set the class labels, stored and returned by {@link #getCategories()} as-is (not copied).
     */
    // NOTE(review): the decompiler reports this constructor was package-private in the original source.
    public QuickMlModel(Classifier classifier, Set<String> set) {
        this.classifier = classifier;
        this.classes = set;
    }

    /** @return the wrapped QuickML classifier. */
    public Classifier getClassifier() {
        return classifier;
    }

    /** @return the class labels given at construction time (the same instance, not a copy). */
    public Set<String> getCategories() {
        return classes;
    }

    /** Renders all trees of the wrapped classifier as indented text. */
    @Override
    public String toString() {
        return traverseModel(new ToStringVistor()).toString();
    }

    /**
     * Walks every decision tree of the wrapped classifier depth-first and reports each tree, branch and
     * leaf to the given visitor.
     * <p>
     * For a {@link RandomDecisionForest}, its trees are visited in the iteration order of its
     * {@code decisionTrees} collection. A single {@link DecisionTree} is visited directly.
     *
     * @param visitor the visitor that receives the callbacks; must not be {@code null}.
     * @param <TV> the concrete visitor type, returned unchanged to allow call chaining.
     * @return the given visitor.
     * @throws NullPointerException if {@code visitor} is {@code null}.
     * @throws IllegalStateException if the classifier is neither a forest nor a single decision tree.
     */
    public <TV extends TreeVisitor> TV traverseModel(TV visitor) {
        Objects.requireNonNull(visitor, "visitor must not be null");
        if (classifier instanceof RandomDecisionForest) {
            // The cast is required: the tree list is a member of RandomDecisionForest, not of Classifier.
            RandomDecisionForest forest = (RandomDecisionForest) classifier;
            for (DecisionTree tree : forest.decisionTrees) {
                traverseTree(tree, visitor);
            }
        } else if (classifier instanceof DecisionTree) {
            traverseTree((DecisionTree) classifier, visitor);
        } else {
            throw new IllegalStateException("Unsupported classifer type: " + classifier.getClass().getName());
        }
        return visitor;
    }

    /** Reports the tree itself, then walks its nodes starting at the root (which is flagged as {@code true}). */
    private static void traverseTree(DecisionTree tree, TreeVisitor visitor) {
        visitor.tree(tree);
        traverseNode(tree.root, visitor, true);
    }

    /**
     * Reports {@code node} to the visitor and recurses into its children (true child first, then false child).
     *
     * @param node the node to visit; nodes that are neither a categorical/numerical branch nor a leaf
     *            (including {@code null}) are skipped silently.
     * @param visitor the callback receiver.
     * @param trueChild {@code true} if the node was reached via its parent's true child, or is the root;
     *            {@code false} if it was reached via the false child.
     */
    private static void traverseNode(Node<ClassificationCounter> node, TreeVisitor visitor, boolean trueChild) {
        if (node instanceof DTCatBranch) {
            DTCatBranch branch = (DTCatBranch) node;
            visitor.categoricalBranch(branch, trueChild);
            traverseNode(branch.getTrueChild(), visitor, true);
            traverseNode(branch.getFalseChild(), visitor, false);
        } else if (node instanceof DTNumBranch) {
            DTNumBranch branch = (DTNumBranch) node;
            visitor.numericalBranch(branch, trueChild);
            traverseNode(branch.getTrueChild(), visitor, true);
            traverseNode(branch.getFalseChild(), visitor, false);
        } else if (node instanceof DTLeaf) {
            visitor.leaf((DTLeaf) node, trueChild);
        }
    }
}
