package ws.palladian.classifiers;

import com.mashape.unirest.http.exceptions.UnirestException;
import java.io.BufferedReader;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import org.apache.hadoop.fs.FSDataOutputStream;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.input.PortableDataStream;
import org.datavec.image.loader.NativeImageLoader;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.Updater;
import org.deeplearning4j.nn.conf.distribution.NormalDistribution;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.modelimport.keras.InvalidKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.UnsupportedKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.trainedmodels.TrainedModelHelper;
import org.deeplearning4j.nn.modelimport.keras.trainedmodels.TrainedModels;
import org.deeplearning4j.nn.modelimport.keras.trainedmodels.Utils.ImageNetLabels;
import org.deeplearning4j.nn.transferlearning.FineTuneConfiguration;
import org.deeplearning4j.nn.transferlearning.TransferLearning;
import org.deeplearning4j.nn.transferlearning.TransferLearningHelper;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.spark.impl.graph.SparkComputationGraph;
import org.deeplearning4j.spark.impl.multilayer.evaluation.IEvaluateFlatMapFunction;
import org.deeplearning4j.spark.impl.multilayer.evaluation.IEvaluationReduceFunction;
import org.deeplearning4j.spark.impl.paramavg.ParameterAveragingTrainingMaster;
import org.deeplearning4j.util.ModelSerializer;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.preprocessor.VGG16ImagePreProcessor;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import scala.Tuple2;
import ws.palladian.core.Category;
import ws.palladian.core.CategoryEntries;
import ws.palladian.core.ImmutableCategory;
import ws.palladian.core.ImmutableCategoryEntries;
import ws.palladian.dataset.DatasetFeaturizer;
import ws.palladian.dataset.DatasetIteratorFeaturized;
import ws.palladian.dataset.ImageDataset;
import ws.palladian.extraction.multimedia.ImageHandler;
import ws.palladian.helper.StopWatch;
import ws.palladian.retrieval.parser.json.JsonException;

/* loaded from: input_file:ws/palladian/classifiers/ImageClassifier.class */
/**
 * Image classifier backed by a VGG16 {@link ComputationGraph}, plus helpers that build a transfer-learned VGG16
 * (new softmax output layer on top of {@link DatasetFeaturizer#featurizeExtractionLayer}) for the classes of an
 * {@link ImageDataset}, either locally or with a local Spark master.
 */
public class ImageClassifier {
    private static final Logger LOGGER = LoggerFactory.getLogger(ImageClassifier.class);

    /** Width and height images are scaled to before being fed to the network. */
    private static final int IMAGE_SIZE = 224;

    /** Number of color channels loaded per image. */
    private static final int IMAGE_CHANNELS = 3;

    /** Input size (nIn) of the replacement output layer, i.e. the width of the feature extraction layer. */
    private static final int FEATURE_SIZE = 4096;

    /** Number of passes over the training data (used for Spark and for local training). */
    private static final int NUM_EPOCHS = 3;

    /** During local training the model is evaluated on the test set every this many batches. */
    private static final int EVALUATION_INTERVAL = 10;

    /** Directory (on the Spark context's Hadoop file system) receiving the serialized training batches. */
    private static final String SPARK_TRAIN_PATH = "data/hdfstemp/train";

    /** Directory (on the Spark context's Hadoop file system) receiving the serialized test batches. */
    private static final String SPARK_TEST_PATH = "data/hdfstemp/test";

    /** The network used by {@link #classify}; {@code null} if loading failed in the constructor. */
    private final ComputationGraph vgg16;

    /** Class labels; {@code labels.get(k)} names output column {@code k} of the network. */
    private final List<String> labels;

    /**
     * Creates a classifier using the default pre-trained VGG16 model and the ImageNet labels.
     */
    public ImageClassifier() {
        this(null, null, ImageNetLabels.getLabels());
    }

    /**
     * Creates a classifier from a Keras VGG16 model loaded through {@link TrainedModelHelper}.
     *
     * <p>If loading fails the error is logged and the instance has no model; {@link #classify} will then throw.
     *
     * @param pathToH5 path to the Keras weights file; only used when {@code pathToJson} is non-null too,
     *         otherwise the helper's default paths are kept
     * @param pathToJson path to the Keras model configuration; only used when {@code pathToH5} is non-null too
     * @param labels class labels in the order of the network's output columns
     */
    public ImageClassifier(String pathToH5, String pathToJson, List<String> labels) {
        this.labels = labels;
        TrainedModelHelper trainedModelHelper = new TrainedModelHelper(TrainedModels.VGG16);
        if (pathToH5 != null && pathToJson != null) {
            trainedModelHelper.setPathToH5(pathToH5);
            trainedModelHelper.setPathToJSON(pathToJson);
        }
        ComputationGraph model = null;
        try {
            model = trainedModelHelper.loadModel();
        } catch (IOException | InvalidKerasConfigurationException | UnsupportedKerasConfigurationException e) {
            // Best effort (as before): keep the instance, but without a model.
            LOGGER.error("Could not load VGG16 model", e);
        }
        this.vgg16 = model;
    }

    /**
     * Creates a classifier from a model written by {@link ModelSerializer}, e.g. by {@link #transferLearn}.
     *
     * <p>If restoring fails the error is logged and the instance has no model; {@link #classify} will then throw.
     *
     * @param modelPath path of the serialized {@link ComputationGraph}
     * @param labels class labels in the order of the network's output columns
     */
    public ImageClassifier(String modelPath, List<String> labels) {
        this.labels = labels;
        File modelFile = new File(modelPath);
        ComputationGraph model = null;
        try {
            model = ModelSerializer.restoreComputationGraph(modelFile);
        } catch (IOException e) {
            LOGGER.error("Could not restore model from {}", modelFile, e);
        }
        this.vgg16 = model;
    }

    /**
     * Classifies an image and returns the 10 most likely categories.
     *
     * @param file the image file
     * @return see {@link #classify(File, int)}
     */
    public CategoryEntries classify(File file) {
        return classify(file, 10);
    }

    /**
     * Classifies an image and returns its most likely categories.
     *
     * @param file the image file, scaled to 224x224 and preprocessed with {@link VGG16ImagePreProcessor}
     * @param i number of categories to return; values below 1 are treated as 1, values above the number of
     *         network outputs are capped to that number
     * @return the top categories in descending order of probability, with the best one as most likely; if the
     *         image cannot be read, an empty result whose most likely category is "unknown" with probability 0
     * @throws IllegalStateException if no model could be loaded in the constructor
     */
    public CategoryEntries classify(File file, int i) {
        if (vgg16 == null) {
            throw new IllegalStateException("No model loaded; see the log for the cause");
        }
        NativeImageLoader imageLoader = new NativeImageLoader(IMAGE_SIZE, IMAGE_SIZE, IMAGE_CHANNELS);
        Category mostLikely = new ImmutableCategory("unknown", 0.0d);
        try {
            INDArray image = imageLoader.asMatrix(file);
            new VGG16ImagePreProcessor().transform(image);
            INDArray output = vgg16.output(false, new INDArray[]{image})[0];
            // One image is fed in, so row 0 holds its scores. Work on a copy so picked classes can be masked.
            INDArray scores = output.getRow(0).dup();
            // Asking for more classes than there are outputs would re-pick masked classes and corrupt the result.
            int topN = Math.max(1, Math.min(i, (int) scores.length()));
            Map<String, Category> entries = new LinkedHashMap<>();
            for (int rank = 0; rank < topN; rank++) {
                int best = Nd4j.argMax(scores, new int[]{1}).getInt(new int[]{0, 0});
                float probability = scores.getFloat(0, best);
                // Mask with -Infinity (not 0) so a picked class never wins a tie against underflowed scores.
                scores.putScalar(0, best, Double.NEGATIVE_INFINITY);
                String label = labels.get(best);
                Category category = new ImmutableCategory(label, probability);
                entries.put(label, category);
                if (rank == 0) {
                    mostLikely = category;
                }
            }
            return new ImmutableCategoryEntries(entries, mostLikely);
        } catch (IOException e) {
            LOGGER.error("Could not classify image {}", file, e);
            return new ImmutableCategoryEntries(new HashMap<String, Category>(), mostLikely);
        }
    }

    /**
     * Featurizes the dataset, trains a new output layer on top of a freshly loaded VGG16 and saves the result.
     *
     * @param imageDataset the dataset to learn
     * @param file destination of the serialized model (readable with {@link #ImageClassifier(String, List)})
     * @param i passed through to {@link DatasetFeaturizer#featurizeDataset} (presumably a train/test split
     *         percentage, {@link #main} passes 80 — TODO confirm)
     */
    public void transferLearn(ImageDataset imageDataset, File file, int i) throws IOException, JsonException, UnsupportedKerasConfigurationException, InvalidKerasConfigurationException {
        StopWatch stopWatch = new StopWatch();
        new DatasetFeaturizer().featurizeDataset(imageDataset, i);
        ComputationGraph transferGraph = buildTransferGraph(loadPretrainedVgg16(), imageDataset.getNumberOfClasses());
        DatasetIteratorFeaturized featurized = new DatasetIteratorFeaturized(imageDataset, DatasetFeaturizer.featurizeExtractionLayer);
        DataSetIterator trainingIterator = featurized.getTrainingIterator();
        DataSetIterator testingIterator = featurized.getTestingIterator();
        TransferLearningHelper transferLearningHelper = new TransferLearningHelper(transferGraph);
        fitLocally(transferLearningHelper, trainingIterator, testingIterator);
        saveModel(transferGraph, file, stopWatch);
    }

    /**
     * Like {@link #transferLearn}, but first trains the unfrozen part of the network on a local Spark master
     * ({@code local[*]}) from batches serialized to {@value #SPARK_TRAIN_PATH} / {@value #SPARK_TEST_PATH},
     * evaluates it on Spark, and then continues training locally.
     *
     * @param imageDataset the dataset to learn
     * @param file destination of the serialized model
     * @param i passed through to {@link DatasetFeaturizer#featurizeDataset}
     */
    public void transferLearnOnSpark(ImageDataset imageDataset, File file, int i) throws IOException, JsonException, UnsupportedKerasConfigurationException, InvalidKerasConfigurationException {
        StopWatch stopWatch = new StopWatch();
        new DatasetFeaturizer().featurizeDataset(imageDataset, i);
        int numberOfClasses = imageDataset.getNumberOfClasses();
        ComputationGraph transferGraph = buildTransferGraph(loadPretrainedVgg16(), numberOfClasses);
        DatasetIteratorFeaturized featurized = new DatasetIteratorFeaturized(imageDataset, DatasetFeaturizer.featurizeExtractionLayer);
        DataSetIterator trainingIterator = featurized.getTrainingIterator();
        DataSetIterator testingIterator = featurized.getTestingIterator();
        TransferLearningHelper transferLearningHelper = new TransferLearningHelper(transferGraph);
        ParameterAveragingTrainingMaster trainingMaster = new ParameterAveragingTrainingMaster.Builder(16)
                .averagingFrequency(5)
                .workerPrefetchNumBatches(2)
                .batchSizePerWorker(16)
                .build();
        LOGGER.info(transferGraph.summary());

        SparkConf sparkConf = new SparkConf();
        sparkConf.setMaster("local[*]");
        sparkConf.setAppName("vgg16");
        JavaSparkContext javaSparkContext = new JavaSparkContext(sparkConf);
        try {
            FileSystem fileSystem = FileSystem.get(javaSparkContext.hadoopConfiguration());
            SparkComputationGraph sparkComputationGraph = new SparkComputationGraph(javaSparkContext, transferLearningHelper.unfrozenGraph(), trainingMaster);

            LOGGER.info("Writing train to hdfs");
            writeDataSets(fileSystem, trainingIterator, SPARK_TRAIN_PATH);
            LOGGER.info("Writing test to hdfs");
            writeDataSets(fileSystem, testingIterator, SPARK_TEST_PATH);

            for (int epoch = 0; epoch < NUM_EPOCHS; epoch++) {
                sparkComputationGraph.fit(SPARK_TRAIN_PATH);
                LOGGER.info("Epoch #" + epoch + " complete");
            }

            // Broadcast configuration and parameters of the SAME network; the original VGG16 configuration does
            // not match the unfrozen subgraph's parameters.
            // NOTE(review): confirm this multilayer evaluation function accepts a ComputationGraph configuration.
            ComputationGraph network = sparkComputationGraph.getNetwork();
            Evaluation sparkEvaluation = (Evaluation) javaSparkContext.binaryFiles(SPARK_TEST_PATH + "/*")
                    .map(new DataSetLoader())
                    .mapPartitions(new IEvaluateFlatMapFunction(
                            javaSparkContext.broadcast(network.getConfiguration().toJson()),
                            javaSparkContext.broadcast(network.params()),
                            16,
                            new Evaluation(numberOfClasses)))
                    .reduce(new IEvaluationReduceFunction());
            LOGGER.info("Spark evaluation stats:");
            LOGGER.info(sparkEvaluation.stats() + "\n");

            Evaluation evaluate = network.evaluate(testingIterator);
            LOGGER.info("Eval stats BEFORE fit.....");
            LOGGER.info(evaluate.stats() + "\n");
            testingIterator.reset();

            fitLocally(transferLearningHelper, trainingIterator, testingIterator);
        } finally {
            javaSparkContext.stop();
        }
        saveModel(transferGraph, file, stopWatch);
    }

    /**
     * Loads the pre-trained VGG16 model from the helper's default location.
     */
    private static ComputationGraph loadPretrainedVgg16() throws IOException, InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        TrainedModelHelper trainedModelHelper = new TrainedModelHelper(TrainedModels.VGG16);
        LOGGER.info("loading vgg16...");
        ComputationGraph model = trainedModelHelper.loadModel();
        LOGGER.info("...loaded");
        return model;
    }

    /**
     * Freezes {@code vgg16} up to the feature extraction layer and replaces its "predictions" vertex with a new
     * softmax output layer with {@code numberOfClasses} outputs.
     */
    private static ComputationGraph buildTransferGraph(ComputationGraph vgg16, int numberOfClasses) {
        FineTuneConfiguration fineTuneConfiguration = new FineTuneConfiguration.Builder()
                .learningRate(Double.valueOf(3.0E-5d))
                .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                .updater(Updater.NESTEROVS)
                .seed(12345L)
                .build();
        OutputLayer predictions = new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
                .nIn(FEATURE_SIZE)
                .nOut(numberOfClasses)
                .weightInit(WeightInit.DISTRIBUTION)
                .dist(new NormalDistribution(0.0d, 0.2d * (2.0d / (FEATURE_SIZE + numberOfClasses))))
                .activation(Activation.SOFTMAX)
                .build();
        return new TransferLearning.GraphBuilder(vgg16)
                .fineTuneConfiguration(fineTuneConfiguration)
                .setFeatureExtractor(new String[]{DatasetFeaturizer.featurizeExtractionLayer})
                .removeVertexKeepConnections("predictions")
                .addLayer("predictions", predictions, new String[]{DatasetFeaturizer.featurizeExtractionLayer})
                .build();
    }

    /**
     * Evaluates once, then trains the unfrozen layers on featurized batches for {@link #NUM_EPOCHS} epochs,
     * evaluating every {@link #EVALUATION_INTERVAL} batches. Both iterators are left rewound.
     */
    private static void fitLocally(TransferLearningHelper helper, DataSetIterator trainingIterator, DataSetIterator testingIterator) {
        Evaluation before = helper.unfrozenGraph().evaluate(testingIterator);
        LOGGER.info("Evaluation stats BEFORE fit:");
        LOGGER.info(before.stats() + "\n");
        testingIterator.reset();
        for (int epoch = 0; epoch < NUM_EPOCHS; epoch++) {
            int iteration = 0;
            while (trainingIterator.hasNext()) {
                helper.fitFeaturized(trainingIterator.next());
                if (iteration % EVALUATION_INTERVAL == 0) {
                    LOGGER.info("Evaluate model at iteration " + iteration + ":");
                    LOGGER.info(helper.unfrozenGraph().evaluate(testingIterator).stats());
                    testingIterator.reset();
                }
                iteration++;
            }
            trainingIterator.reset();
            LOGGER.info("epoch #" + epoch + " complete");
        }
    }

    /**
     * Serializes every batch of {@code iterator} to {@code directory/dataset<N>}, after deleting the directory so
     * batches of an earlier run are not picked up. The iterator is rewound afterwards so it can be reused.
     */
    private static void writeDataSets(FileSystem fileSystem, DataSetIterator iterator, String directory) throws IOException {
        Path directoryPath = new Path(directory);
        fileSystem.delete(directoryPath, true);
        int index = 0;
        while (iterator.hasNext()) {
            DataSet batch = iterator.next();
            try (FSDataOutputStream out = fileSystem.create(new Path(directoryPath, "dataset" + index))) {
                batch.save(out);
            }
            index++;
        }
        iterator.reset();
    }

    /**
     * Writes the model (without updater state) and logs the total time measured by {@code stopWatch}.
     */
    private static void saveModel(ComputationGraph graph, File file, StopWatch stopWatch) throws IOException {
        LOGGER.info("model build complete, saving now...");
        ModelSerializer.writeModel(graph, file, false);
        LOGGER.info("...model saved successfully, total time " + stopWatch.getElapsedTimeString());
    }

    /**
     * Spark function turning a serialized batch file into a {@link DataSet}. A static nested class, unlike an
     * anonymous one, does not capture the (non-serializable) enclosing {@link ImageClassifier}.
     */
    private static final class DataSetLoader implements Function<Tuple2<String, PortableDataStream>, DataSet> {
        private static final long serialVersionUID = 1L;

        @Override
        public DataSet call(Tuple2<String, PortableDataStream> file) throws Exception {
            DataSet dataSet = new DataSet();
            try (InputStream in = file._2().open()) {
                dataSet.load(in);
            }
            return dataSet;
        }
    }

    /**
     * Trains a transfer model on a hard-coded dataset; if the {@code System.exit} below is removed, it then
     * classifies images from URLs read line by line from stdin until end of input.
     */
    public static void main(String[] args) throws IOException, InvalidKerasConfigurationException, UnsupportedKerasConfigurationException, JsonException, UnirestException {
        ImageDataset imageDataset = new ImageDataset(new File("F:\\PalladianData\\Datasets\\spoonacular-menu-items\\dataset.json"));
        File modelFile = new File(new File(new File("data", "models"), "deeplearning4j"), imageDataset.getName() + ".zip");
        new ImageClassifier().transferLearnOnSpark(imageDataset, modelFile, 80);
        // NOTE: exits right after training; remove to run the interactive classification below.
        System.exit(0);
        ImageClassifier imageClassifier = new ImageClassifier(modelFile.getPath(), imageDataset.getClassNames());
        BufferedReader reader = new BufferedReader(new InputStreamReader(System.in, StandardCharsets.UTF_8));
        while (true) {
            System.out.println("url:");
            String url = reader.readLine();
            if (url == null) {
                break; // end of input
            }
            File imageFile = new File(ImageHandler.downloadAndSave(url, ""));
            try {
                CategoryEntries<Category> categories = imageClassifier.classify(imageFile);
                for (Category category : categories) {
                    System.out.println(category.getName() + " : " + category.getProbability());
                }
            } finally {
                imageFile.delete();
            }
        }
    }
}
