package ws.palladian.classification.evaluation;

import java.util.Iterator;
import java.util.Objects;
import java.util.Random;
import java.util.function.Predicate;
import ws.palladian.classification.evaluation.CrossValidator;
import ws.palladian.core.Instance;
import ws.palladian.core.dataset.Dataset;
import ws.palladian.helper.collection.AbstractIterator2;
import ws.palladian.helper.collection.CollectionHelper;
import ws.palladian.helper.functional.Factory;
import ws.palladian.helper.functional.Predicates;

/* loaded from: input_file:ws/palladian/classification/evaluation/RandomCrossValidator.class */
/**
 * A {@link CrossValidator} which distributes the instances of a dataset randomly over a fixed number of folds.
 * <p>
 * The fold assignment is computed once in the constructor: the instance at iteration position {@code p} is first given
 * fold {@code p % numFolds} and the resulting array is then shuffled, so fold sizes differ by at most one. The
 * assignment is stored per iteration position, which means the train/test subsets are only meaningful as long as the
 * dataset is iterated in the same order and with the same number of instances as during construction.
 */
public class RandomCrossValidator implements CrossValidator {
    /** The dataset which is split into folds. */
    private final Dataset data;

    /** The number of folds; always at least 2. */
    private final int numFolds;

    /** Fold index for each instance, addressed by the instance's position in iteration order. */
    private final int[] foldAssignments;

    /**
     * Stateful predicate which accepts those items whose position is assigned to a given fold. The tested object itself
     * is ignored; only the number of previous {@link #test(Object)} invocations determines the result.
     * <p>
     * NOTE(review): this presumably relies on {@code Dataset.subset} invoking the predicate exactly once per instance
     * and in iteration order; a fresh filter must be used for every pass — confirm against the {@code Dataset}
     * implementation.
     */
    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:ws/palladian/classification/evaluation/RandomCrossValidator$FoldAssignmentFilter.class */
    public final class FoldAssignmentFilter implements Predicate<Object> {
        /** The fold whose members are accepted. */
        private final int fold;

        /** Position of the next item to be tested; advances by one on every call to {@link #test(Object)}. */
        private int currentIndex;

        /**
         * @param i The fold whose members this filter accepts.
         */
        public FoldAssignmentFilter(int i) {
            this.fold = i;
        }

        /**
         * Checks whether the item at the current position belongs to this filter's fold and advances the position.
         *
         * @param obj Ignored.
         * @return {@code true} if the current position is assigned to this filter's fold.
         * @throws ArrayIndexOutOfBoundsException If invoked more often than there are fold assignments.
         */
        @Override // java.util.function.Predicate
        public boolean test(Object obj) {
            int position = this.currentIndex++;
            return RandomCrossValidator.this.foldAssignments[position] == this.fold;
        }
    }

    /**
     * Iterator which produces the folds {@code 0} to {@code numFolds - 1} in ascending order.
     */
    /* loaded from: input_file:ws/palladian/classification/evaluation/RandomCrossValidator$FoldIterator.class */
    private final class FoldIterator extends AbstractIterator2<CrossValidator.Fold> {
        /** Index of the fold which is returned next. */
        private int currentFold;

        private FoldIterator() {
            this.currentFold = 0;
        }

        /**
         * Supplies the next fold, or signals the end of the iteration via {@code finished()} once all folds have been
         * returned.
         * <p>
         * NOTE(review): according to the decompiler annotation below this method was originally named {@code getNext};
         * confirm that it still overrides the corresponding method of {@code AbstractIterator2}.
         */
        /* JADX INFO: Access modifiers changed from: protected */
        /* renamed from: getNext, reason: merged with bridge method [inline-methods] */
        public CrossValidator.Fold m13getNext() {
            if (this.currentFold == RandomCrossValidator.this.numFolds) {
                return (CrossValidator.Fold) finished();
            }
            return new RandomFold(this.currentFold++);
        }
    }

    /**
     * A single fold: the test set consists of the instances assigned to this fold, the training set of all remaining
     * instances.
     */
    /* loaded from: input_file:ws/palladian/classification/evaluation/RandomCrossValidator$RandomFold.class */
    public final class RandomFold implements CrossValidator.Fold {
        /** Zero-based index of this fold. */
        private final int fold;

        private RandomFold(int i) {
            this.fold = i;
        }

        /**
         * @return The subset of the data whose positions are <em>not</em> assigned to this fold. A new
         *         {@link FoldAssignmentFilter} is created on every {@code create} call of the factory.
         */
        @Override // ws.palladian.core.dataset.split.TrainTestSplit
        public Dataset getTrain() {
            return RandomCrossValidator.this.data.subset(new Factory<Predicate<Object>>() { // from class: ws.palladian.classification.evaluation.RandomCrossValidator.RandomFold.1
                // NOTE(review): decompiler renamed this method from "create"; confirm it still implements Factory.
                /* renamed from: create, reason: merged with bridge method [inline-methods] */
                public Predicate<Object> m14create() {
                    return Predicates.not(new FoldAssignmentFilter(RandomFold.this.fold));
                }
            });
        }

        /**
         * @return The subset of the data whose positions are assigned to this fold. A new
         *         {@link FoldAssignmentFilter} is created on every {@code create} call of the factory.
         */
        @Override // ws.palladian.core.dataset.split.TrainTestSplit
        public Dataset getTest() {
            return RandomCrossValidator.this.data.subset(new Factory<Predicate<Object>>() { // from class: ws.palladian.classification.evaluation.RandomCrossValidator.RandomFold.2
                // NOTE(review): decompiler renamed this method from "create"; confirm it still implements Factory.
                /* renamed from: create, reason: merged with bridge method [inline-methods] */
                public Predicate<Object> m15create() {
                    return new FoldAssignmentFilter(RandomFold.this.fold);
                }
            });
        }

        /**
         * @return The zero-based index of this fold.
         */
        @Override // ws.palladian.classification.evaluation.CrossValidator.Fold
        public int getFold() {
            return this.fold;
        }

        @Override
        public String toString() {
            return "Fold " + this.fold;
        }
    }

    /**
     * Creates a cross validator with as many folds as the dataset has instances, i.e. every fold's test set contains
     * a single instance.
     *
     * @param dataset The dataset, not <code>null</code>.
     * @throws IllegalArgumentException If the dataset contains fewer than two instances.
     */
    public RandomCrossValidator(Dataset dataset) {
        this(dataset, CollectionHelper.count(dataset.iterator()));
    }

    /**
     * Creates a cross validator with the given number of folds.
     *
     * @param dataset The dataset, not <code>null</code>.
     * @param i The number of folds, at least 2.
     * @throws NullPointerException If the dataset is <code>null</code>.
     * @throws IllegalArgumentException If the number of folds is smaller than 2.
     */
    public RandomCrossValidator(Dataset dataset, int i) {
        this.data = Objects.requireNonNull(dataset);
        // Two folds is the smallest sensible split and must be accepted, as the message states.
        if (i < 2) {
            throw new IllegalArgumentException("numFolds must be at least 2");
        }
        this.numFolds = i;
        this.foldAssignments = calculateFoldAssignments(dataset, i);
    }

    /**
     * Creates a random, balanced assignment of positions to folds.
     *
     * @param instances The instances; only their number is used.
     * @param folds The number of folds.
     * @return An array with one entry per instance holding a fold index in {@code [0, folds)}.
     */
    private static int[] calculateFoldAssignments(Iterable<? extends Instance> instances, int folds) {
        int numInstances = CollectionHelper.count(instances.iterator());
        int[] assignments = new int[numInstances];
        // Round-robin assignment keeps the fold sizes balanced; the shuffle afterwards randomizes membership.
        for (int idx = 0; idx < numInstances; idx++) {
            assignments[idx] = idx % folds;
        }
        shuffle(assignments);
        return assignments;
    }

    /**
     * Shuffles the given array in place (Fisher-Yates) using a freshly created, unseeded {@link Random}.
     *
     * @param values The array to shuffle.
     */
    private static void shuffle(int[] values) {
        Random random = new Random();
        for (int last = values.length - 1; last > 0; last--) {
            int pick = random.nextInt(last + 1);
            int tmp = values[pick];
            values[pick] = values[last];
            values[last] = tmp;
        }
    }

    /**
     * @return A new iterator over all folds, starting with fold 0.
     */
    @Override // java.lang.Iterable
    public Iterator<CrossValidator.Fold> iterator() {
        return new FoldIterator();
    }

    /**
     * @return The number of folds.
     */
    @Override // ws.palladian.classification.evaluation.CrossValidator
    public int getNumFolds() {
        return this.numFolds;
    }
}
