package ws.palladian.dataset;

import java.io.File;
import java.io.IOException;
import java.util.Objects;
import java.util.Random;

import org.datavec.api.io.filters.BalancedPathFilter;
import org.datavec.api.io.labels.ParentPathLabelGenerator;
import org.datavec.api.split.FileSplit;
import org.datavec.api.split.InputSplit;
import org.datavec.image.loader.BaseImageLoader;
import org.datavec.image.recordreader.ImageRecordReader;
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
import org.deeplearning4j.nn.modelimport.keras.trainedmodels.TrainedModels;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Builds DL4J {@link DataSetIterator}s over a folder of images, split into a training and a testing part.
 *
 * <p>Images are read from {@link ImageDataset#getFolderedPath()} and labelled with the name of their parent
 * directory ({@link ParentPathLabelGenerator}). Typical usage: construct, call {@link #setup(int)} once, then
 * request {@link #getTrainingIterator()} / {@link #getTestingIterator()}.
 */
public class DatasetIterator {
    private static final Logger LOGGER = LoggerFactory.getLogger(DatasetIterator.class);
    private static final String[] ALLOWED_EXTENSIONS = BaseImageLoader.ALLOWED_FORMATS;

    /** Height/width/channels every image is loaded as; matches the VGG16 preprocessor applied in makeIterator. */
    private static final int IMAGE_HEIGHT = 224;
    private static final int IMAGE_WIDTH = 224;
    private static final int IMAGE_CHANNELS = 3;

    /** Label index passed to RecordReaderDataSetIterator for records produced by ImageRecordReader. */
    private static final int LABEL_INDEX = 1;

    /** Fixed seed so file shuffling and balanced sampling are repeatable. */
    private static final long RNG_SEED = 13;

    /**
     * Shared by all instances. NOTE(review): because this RNG is static, the split produced by {@link #setup(int)}
     * depends on how many splits were already made in this JVM; confirm whether that is intended.
     */
    private static final Random RNG = new Random(RNG_SEED);

    private final ParentPathLabelGenerator labelMaker = new ParentPathLabelGenerator();
    private final int batchSize;
    private final ImageDataset imageDataset;

    /** Training split; {@code null} until {@link #setup(int)} has been called. */
    private InputSplit trainData;
    /** Testing split; {@code null} until {@link #setup(int)} has been called. */
    private InputSplit testData;

    /**
     * Creates an iterator factory for the given dataset.
     *
     * @param imageDataset the dataset whose folder is read; must not be {@code null}
     * @param batchSize number of examples per mini-batch; must be greater than 0
     * @throws NullPointerException if {@code imageDataset} is {@code null}
     * @throws IllegalArgumentException if {@code batchSize <= 0}
     */
    public DatasetIterator(ImageDataset imageDataset, int batchSize) {
        if (batchSize <= 0) {
            throw new IllegalArgumentException("Batch size has to be greater than 0, was " + batchSize);
        }
        this.imageDataset = Objects.requireNonNull(imageDataset, "imageDataset");
        this.batchSize = batchSize;
    }

    /**
     * Returns a new iterator over the training split; each call creates a fresh record reader.
     *
     * @throws IllegalStateException if {@link #setup(int)} has not been called yet
     * @throws IOException if the record reader cannot be initialized
     */
    public DataSetIterator getTrainingIterator() throws IOException {
        return makeIterator(this.trainData);
    }

    /**
     * Returns a new iterator over the testing split; each call creates a fresh record reader.
     *
     * @throws IllegalStateException if {@link #setup(int)} has not been called yet
     * @throws IOException if the record reader cannot be initialized
     */
    public DataSetIterator getTestingIterator() throws IOException {
        return makeIterator(this.testData);
    }

    /**
     * Splits the dataset folder into a training and a testing part, using a balanced path filter.
     *
     * @param trainPercentage share of the data used for training, strictly between 0 and 100; the remaining
     *        {@code 100 - trainPercentage} goes to testing
     * @throws IllegalArgumentException if {@code trainPercentage} is not in the open range (0, 100)
     * @throws IOException declared for I/O failures while listing or sampling the image files
     */
    public void setup(int trainPercentage) throws IOException {
        // Validate before touching the file system or consuming the shared RNG.
        if (trainPercentage <= 0) {
            throw new IllegalArgumentException("Percentage of data set aside for training has to be more than 0%. Test percentage = 100 - training percentage, has to be greater than 0");
        }
        if (trainPercentage >= 100) {
            throw new IllegalArgumentException("Percentage of data set aside for training has to be less than 100%. Test percentage = 100 - training percentage, has to be greater than 0");
        }
        String folder = this.imageDataset.getFolderedPath();
        FileSplit fileSplit = new FileSplit(new File(folder), ALLOWED_EXTENSIONS, RNG);
        BalancedPathFilter balancedPathFilter = new BalancedPathFilter(RNG, ALLOWED_EXTENSIONS, this.labelMaker);
        int testPercentage = 100 - trainPercentage;
        // The two values are relative weights for the resulting splits: [0] = training, [1] = testing.
        InputSplit[] sample = fileSplit.sample(balancedPathFilter, new double[]{trainPercentage, testPercentage});
        this.trainData = sample[0];
        this.testData = sample[1];
        LOGGER.debug("Split {} into {}% training and {}% testing data", folder, trainPercentage, testPercentage);
    }

    /**
     * Creates an image record reader over the given split and wraps it in a batched DataSetIterator with the
     * VGG16 preprocessor attached.
     *
     * @throws IllegalStateException if {@code inputSplit} is {@code null} (i.e. {@link #setup(int)} not called)
     */
    private DataSetIterator makeIterator(InputSplit inputSplit) throws IOException {
        if (inputSplit == null) {
            throw new IllegalStateException("setup(int) has to be called before requesting an iterator");
        }
        ImageRecordReader imageRecordReader = new ImageRecordReader(IMAGE_HEIGHT, IMAGE_WIDTH, IMAGE_CHANNELS, this.labelMaker);
        imageRecordReader.initialize(inputSplit);
        RecordReaderDataSetIterator recordReaderDataSetIterator = new RecordReaderDataSetIterator(imageRecordReader, this.batchSize, LABEL_INDEX, this.imageDataset.getNumberOfClasses());
        recordReaderDataSetIterator.setPreProcessor(TrainedModels.VGG16.getPreProcessor());
        return recordReaderDataSetIterator;
    }
}
