package quickdt.randomForest;

import com.google.common.collect.Lists;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import quickdt.AbstractInstance;
import quickdt.Misc;
import quickdt.PredictiveModelBuilder;
import quickdt.Tree;
import quickdt.TreeBuilder;

/**
 * Builds a {@link RandomForest} by training several {@link Tree}s with a shared
 * {@link TreeBuilder}; the individual tree builds run in parallel on a fixed-size
 * thread pool.
 *
 * <p>Defaults: 8 trees, bagging disabled, 8 executor threads. Settings are changed
 * through the fluent {@link #numTrees(int)}, {@link #useBagging(boolean)} and
 * {@link #executorThreadCount(int)} methods.
 *
 * <p>Each call to {@link #buildPredictiveModel(Iterable)} creates its own thread pool
 * and shuts it down when the call finishes, so one builder instance can be used to
 * build any number of forests.
 */
public class RandomForestBuilder implements PredictiveModelBuilder<RandomForest> {
    private static final Logger logger = LoggerFactory.getLogger(RandomForestBuilder.class);

    /** Builder used to train every tree of the forest. */
    private final TreeBuilder treeBuilder;
    /** Number of trees to train per forest. */
    private int numTrees = 8;
    /** When true, each tree is trained on a bootstrap sample instead of the full data. */
    private boolean useBagging = false;
    /** Size of the thread pool created for each build. */
    private int executorThreadCount = 8;

    /** Creates a forest builder backed by a default-configured {@link TreeBuilder}. */
    public RandomForestBuilder() {
        this(new TreeBuilder());
    }

    /**
     * Creates a forest builder that trains its trees with the given builder.
     *
     * @param treeBuilder builder used for every tree of the forest
     */
    public RandomForestBuilder(TreeBuilder treeBuilder) {
        this.treeBuilder = treeBuilder;
    }

    /**
     * Sets how many trees each forest contains.
     *
     * @param numTrees number of trees to build
     * @return this builder, for chaining
     */
    public RandomForestBuilder numTrees(int numTrees) {
        this.numTrees = numTrees;
        return this;
    }

    /**
     * Enables or disables bagging (training each tree on a bootstrap sample).
     *
     * @param useBagging true to train each tree on a sample drawn with replacement
     * @return this builder, for chaining
     */
    public RandomForestBuilder useBagging(boolean useBagging) {
        this.useBagging = useBagging;
        return this;
    }

    /**
     * Sets the number of threads used to build trees in parallel.
     *
     * @param executorThreadCount size of the fixed thread pool created per build
     * @return this builder, for chaining
     */
    public RandomForestBuilder executorThreadCount(int executorThreadCount) {
        this.executorThreadCount = executorThreadCount;
        return this;
    }

    /**
     * Trains {@code numTrees} trees in parallel and combines them into a forest.
     *
     * <p>A tree whose build fails is logged and left out, so the returned forest may
     * contain fewer trees than requested.
     *
     * @param trainingData instances to train on; it is handed to every tree build
     * @return a forest made of the trees that were built successfully
     */
    @Override
    public RandomForest buildPredictiveModel(Iterable<? extends AbstractInstance> trainingData) {
        logger.info("Building random forest with {} trees, bagging {}", numTrees, useBagging);
        // A fresh pool per build: a pool that has been shut down rejects new tasks,
        // so it must not be kept around for the next call.
        final ExecutorService executor = Executors.newFixedThreadPool(executorThreadCount);
        try {
            List<Future<Tree>> pendingTrees = Lists.newArrayListWithCapacity(numTrees);
            for (int treeIndex = 0; treeIndex < numTrees; treeIndex++) {
                pendingTrees.add(submitTreeBuild(executor, trainingData, treeIndex));
            }
            List<Tree> trees = Lists.newArrayListWithCapacity(numTrees);
            for (Future<Tree> pendingTree : pendingTrees) {
                collectTreeFutures(trees, pendingTree);
            }
            return new RandomForest(trees);
        } finally {
            // Release the worker threads even when submitting or collecting fails.
            executor.shutdown();
        }
    }

    /** Queues the build of one tree on the given executor and returns its future. */
    private Future<Tree> submitTreeBuild(ExecutorService executor,
                                         final Iterable<? extends AbstractInstance> trainingData,
                                         final int treeIndex) {
        return executor.submit(new Callable<Tree>() {
            @Override
            public Tree call() throws Exception {
                return buildModel(trainingData, treeIndex);
            }
        });
    }

    /**
     * Builds a single tree, on a bootstrap sample of the data when bagging is enabled.
     *
     * @param trainingData instances to train on
     * @param treeIndex zero-based index of the tree, used only for logging
     * @return the trained tree
     */
    public Tree buildModel(Iterable<? extends AbstractInstance> trainingData, int treeIndex) {
        logger.info("Building tree {} of {}", treeIndex, numTrees);
        Iterable<? extends AbstractInstance> treeData =
                useBagging ? getBootstrapSampling(trainingData) : trainingData;
        return treeBuilder.buildPredictiveModel(treeData);
    }

    /**
     * Waits for one tree build and appends its result to {@code trees}. A failed build
     * is logged and skipped rather than propagated.
     */
    private void collectTreeFutures(List<Tree> trees, Future<Tree> pendingTree) {
        try {
            trees.add(pendingTree.get());
        } catch (InterruptedException e) {
            // Restore the interrupt flag so callers can still observe the interruption.
            Thread.currentThread().interrupt();
            logger.error("Error retrieving tree", e);
        } catch (Exception e) {
            logger.error("Error retrieving tree", e);
        }
    }

    /**
     * Draws a bootstrap sample: as many instances as the input contains, picked
     * uniformly at random with replacement. An empty input yields an empty sample.
     */
    private static List<AbstractInstance> getBootstrapSampling(Iterable<? extends AbstractInstance> trainingData) {
        List<AbstractInstance> population = Lists.<AbstractInstance>newArrayList(trainingData);
        int populationSize = population.size();
        List<AbstractInstance> sample = Lists.newArrayListWithCapacity(populationSize);
        for (int drawn = 0; drawn < populationSize; drawn++) {
            sample.add(population.get(Misc.random.nextInt(populationSize)));
        }
        return sample;
    }
}
