package org.apache.hama.ml.perception;

import java.io.IOException;
import java.util.Arrays;
import java.util.BitSet;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.NullWritable;
import org.apache.hama.bsp.BSPPeer;
import org.apache.hama.bsp.sync.SyncException;
import org.apache.hama.commons.io.VectorWritable;
import org.apache.hama.commons.math.DenseDoubleMatrix;
import org.apache.hama.commons.math.DoubleMatrix;
import org.apache.hama.ml.ann.NeuralNetworkTrainer;

/**
 * Trains a {@link SmallMultiLayerPerceptron} with mini-batch gradient descent on Hama BSP.
 *
 * <p>Every peer keeps a full in-memory copy of the model. In each round a peer first adopts the
 * weights broadcast by peer 0 (if any). It then computes the mean weight update over one local
 * mini-batch and sends it to peer 0. Peer 0 averages the updates it received, applies them to its
 * model and broadcasts the resulting weights plus a termination flag to all peers.
 */
class SmallMLPTrainer extends NeuralNetworkTrainer {
    /** Used on peer 0 only: owner indices of peers whose messages reported exhausted input. */
    private BitSet statusSet;
    /** Number of records this peer has successfully read, across all iterations. */
    private int numTrainingInstanceRead = 0;
    /** Termination flag decided on peer 0 and broadcast to the other peers. */
    private boolean terminateTraining = false;
    private SmallMultiLayerPerceptron inMemoryPerceptron;
    /** Layer sizes parsed from the configuration; stays {@code null} when an existing model is loaded. */
    private int[] layerSizeArray;

    SmallMLPTrainer() {
    }

    /**
     * Reads the training configuration and creates the in-memory model.
     *
     * <p>When {@code existingModelPath} is set, the model is loaded from that path. Otherwise a new
     * model is built from {@code learningRate}, {@code regularization}, {@code momentum},
     * {@code squashingFunctionName}, {@code costFunctionName} and the whitespace-separated
     * {@code layerSizeArray}.
     *
     * @param peer the BSP peer this trainer runs on
     */
    @Override // org.apache.hama.ml.ann.NeuralNetworkTrainer
    protected void extraSetup(BSPPeer<LongWritable, VectorWritable, NullWritable, NullWritable, MLPMessage> peer) {
        this.trainingMode = this.conf.get("training.mode", "minibatch.gradient.descent");
        this.batchSize = this.conf.getInt("training.batch.size", 100);
        this.statusSet = new BitSet(peer.getConfiguration().getInt("tasks", 1));

        if (isBlank(this.conf.get("modelPath"))) {
            // Best effort, as before: report the missing output path but let setup continue.
            LOG.error("Please specify output model path.");
        }

        String existingModelPath = this.conf.get("existingModelPath");
        if (!isBlank(existingModelPath)) {
            this.inMemoryPerceptron = new SmallMultiLayerPerceptron(existingModelPath);
            LOG.info("Training with existing model.");
            return;
        }

        double learningRate = Double.parseDouble(this.conf.get("learningRate"));
        double regularization = Double.parseDouble(this.conf.get("regularization"));
        double momentum = Double.parseDouble(this.conf.get("momentum"));
        String squashingFunctionName = this.conf.get("squashingFunctionName");
        String costFunctionName = this.conf.get("costFunctionName");
        // Split on whitespace runs so that e.g. "10  5 1" does not yield empty tokens.
        String[] layerSizes = this.conf.get("layerSizeArray").trim().split("\\s+");
        this.layerSizeArray = new int[layerSizes.length];
        for (int i = 0; i < layerSizes.length; i++) {
            this.layerSizeArray[i] = Integer.parseInt(layerSizes[i]);
        }
        this.inMemoryPerceptron = new SmallMultiLayerPerceptron(learningRate, regularization, momentum,
                squashingFunctionName, costFunctionName, this.layerSizeArray);
        LOG.info("Training model from scratch.");
    }

    /**
     * Logs how many records this peer read. On peer 0, it also writes the learned model to
     * {@code modelPath}.
     *
     * @param peer the BSP peer this trainer runs on
     */
    @Override // org.apache.hama.ml.ann.NeuralNetworkTrainer
    protected void extraCleanup(BSPPeer<LongWritable, VectorWritable, NullWritable, NullWritable, MLPMessage> peer) {
        LOG.info(String.format("Task %d totally read %d records.\n",
                Integer.valueOf(peer.getPeerIndex()), Integer.valueOf(this.numTrainingInstanceRead)));
        if (peer.getPeerIndex() != 0) {
            return;
        }
        String modelPath = this.conf.get("modelPath");
        try {
            LOG.info(String.format("Master write learned model to %s\n", modelPath));
            this.inMemoryPerceptron.writeModelToFile(modelPath);
        } catch (IOException e) {
            // Keep the cause so the actual I/O failure is visible in the task log.
            LOG.error("Please set a correct model path.", e);
        }
    }

    /**
     * Entry point of the BSP job. Only the {@code minibatch.gradient.descent} mode is supported.
     */
    @Override // org.apache.hama.ml.ann.NeuralNetworkTrainer
    public void bsp(BSPPeer<LongWritable, VectorWritable, NullWritable, NullWritable, MLPMessage> peer)
            throws IOException, SyncException, InterruptedException {
        LOG.info("Start training...");
        if (this.trainingMode.equalsIgnoreCase("minibatch.gradient.descent")) {
            LOG.info("Training Mode: minibatch.gradient.descent");
            trainByMinibatch(peer);
        } else {
            // Previously an unknown mode silently skipped training.
            LOG.warn("Unsupported training mode: " + this.trainingMode);
        }
        LOG.info(String.format("Task %d finished.", Integer.valueOf(peer.getPeerIndex())));
    }

    /**
     * Runs {@code training.iteration} passes over the local input. Each pass is a sequence of
     * rounds; every round has one local-update superstep and one merge superstep on peer 0.
     */
    private void trainByMinibatch(BSPPeer<LongWritable, VectorWritable, NullWritable, NullWritable, MLPMessage> peer)
            throws IOException, SyncException, InterruptedException {
        int iterations = this.conf.getInt("training.iteration", 1);
        LOG.info("# of Training Iteration: " + iterations);
        for (int iteration = 0; iteration < iterations; iteration++) {
            if (peer.getPeerIndex() == 0) {
                LOG.info(String.format("Iteration [%d] begins...", Integer.valueOf(iteration)));
            }
            peer.reopenInput();
            if (peer.getPeerIndex() == 0) {
                // Fresh pass: no peer has exhausted its input yet.
                this.statusSet = new BitSet(peer.getConfiguration().getInt("tasks", 1));
            }
            this.terminateTraining = false;
            peer.sync();

            // NOTE(review): a peer leaves this loop as soon as its own input is exhausted (or
            // peer 0 signals termination), while other peers may still be in their loop calling
            // sync(). The barrier counts of different peers can then drift apart. Confirm whether
            // all peers are meant to stop together.
            boolean finished;
            do {
                finished = updateWeights(peer);
                peer.sync();
                if (peer.getPeerIndex() == 0) {
                    mergeUpdate(peer);
                }
                peer.sync();
            } while (!finished);
        }
    }

    /**
     * Runs on peer 0. Averages the weight updates received in this superstep, applies them to the
     * local model, and broadcasts the new weights and the termination flag to every peer.
     */
    private void mergeUpdate(BSPPeer<LongWritable, VectorWritable, NullWritable, NullWritable, MLPMessage> peer)
            throws IOException {
        DenseDoubleMatrix[] mergedUpdates = getZeroWeightMatrices();
        int numMessages = peer.getNumCurrentMessages();
        while (peer.getNumCurrentMessages() > 0) {
            SmallMLPMessage message = (SmallMLPMessage) peer.getCurrentMessage();
            if (message.isTerminated()) {
                this.statusSet.set(message.getOwner());
            }
            DoubleMatrix[] updates = message.getWeightUpdatedMatrices();
            for (int m = 0; m < mergedUpdates.length; m++) {
                mergedUpdates[m] = (DenseDoubleMatrix) mergedUpdates[m].add(updates[m]);
            }
        }
        if (numMessages != 0) {
            for (int m = 0; m < mergedUpdates.length; m++) {
                mergedUpdates[m] = (DenseDoubleMatrix) mergedUpdates[m].divide(numMessages);
            }
            // Stop once every task has reported exhausted input.
            if (this.statusSet.cardinality() == this.conf.getInt("tasks", 1)) {
                this.terminateTraining = true;
            }
            this.inMemoryPerceptron.updateWeightMatrices(mergedUpdates);
            this.inMemoryPerceptron.setPrevWeightUpdateMatrices(mergedUpdates);
        }
        for (String peerName : peer.getAllPeerNames()) {
            peer.send(peerName, new SmallMLPMessage(peer.getPeerIndex(), this.terminateTraining,
                    this.inMemoryPerceptron.getWeightMatrices(), this.inMemoryPerceptron.getPrevWeightUpdateMatrices()));
        }
    }

    /**
     * Runs one local mini-batch and sends its mean weight update to peer 0.
     *
     * <p>It first adopts the weights broadcast by peer 0 in the previous round, if a message is
     * pending. If that message signals termination, it returns {@code true} without reading input
     * or sending anything.
     *
     * @return {@code true} if peer 0 requested termination or the local input is exhausted
     */
    private boolean updateWeights(BSPPeer<LongWritable, VectorWritable, NullWritable, NullWritable, MLPMessage> peer)
            throws IOException {
        if (peer.getNumCurrentMessages() > 0) {
            SmallMLPMessage message = (SmallMLPMessage) peer.getCurrentMessage();
            this.terminateTraining = message.isTerminated();
            this.inMemoryPerceptron.setWeightMatrices(message.getWeightUpdatedMatrices());
            this.inMemoryPerceptron.setPrevWeightUpdateMatrices(message.getPrevWeightsUpdatedMatrices());
            if (this.terminateTraining) {
                return true;
            }
        }

        DenseDoubleMatrix[] batchUpdates = getZeroWeightMatrices();
        LongWritable recordId = new LongWritable();
        VectorWritable instance = new VectorWritable();
        int numRead = 0;
        int numTrained = 0;
        // Stays false when batchSize <= 0; as before, that reports the input as exhausted.
        boolean hasMore = false;
        while (numRead < this.batchSize) {
            hasMore = peer.readNext(recordId, instance);
            if (!hasMore) {
                // No new record was read, so do not train on or count the leftover holder contents.
                break;
            }
            numRead++;
            this.numTrainingInstanceRead++;
            try {
                DoubleMatrix[] instanceUpdates = this.inMemoryPerceptron.trainByInstance(instance.getVector());
                for (int m = 0; m < batchUpdates.length; m++) {
                    batchUpdates[m] = (DenseDoubleMatrix) batchUpdates[m].add(instanceUpdates[m]);
                }
                numTrained++;
            } catch (Exception e) {
                // Best effort: skip the bad record and keep training on the rest of the batch.
                LOG.error("Failed to train on record " + recordId.get(), e);
            }
        }
        // Average over the instances that actually contributed; an empty batch stays all-zero
        // instead of becoming NaN.
        if (numTrained > 0) {
            for (int m = 0; m < batchUpdates.length; m++) {
                batchUpdates[m] = (DenseDoubleMatrix) batchUpdates[m].divide(numTrained);
            }
        }
        LOG.info(String.format("Task %d has read %d records.",
                Integer.valueOf(peer.getPeerIndex()), Integer.valueOf(this.numTrainingInstanceRead)));
        peer.send(peer.getPeerName(0), new SmallMLPMessage(peer.getPeerIndex(), !hasMore, batchUpdates));
        return !hasMore;
    }

    /**
     * Returns one all-zero matrix per weight layer, shaped like the model's weight matrices.
     *
     * <p>When the layer sizes were parsed from the configuration, matrix {@code i} is
     * {@code (layerSize[i] + 1) x layerSize[i + 1]}. When the model was loaded from
     * {@code existingModelPath}, no layer sizes were parsed, so the shapes of the loaded weight
     * matrices are mirrored instead.
     */
    private DenseDoubleMatrix[] getZeroWeightMatrices() {
        if (this.layerSizeArray == null) {
            DoubleMatrix[] currentWeights = this.inMemoryPerceptron.getWeightMatrices();
            DenseDoubleMatrix[] zeros = new DenseDoubleMatrix[currentWeights.length];
            for (int m = 0; m < currentWeights.length; m++) {
                zeros[m] = new DenseDoubleMatrix(currentWeights[m].getRowCount(), currentWeights[m].getColumnCount());
            }
            return zeros;
        }
        DenseDoubleMatrix[] zeros = new DenseDoubleMatrix[this.layerSizeArray.length - 1];
        for (int m = 0; m < zeros.length; m++) {
            // The extra row holds the bias weights of layer m.
            zeros[m] = new DenseDoubleMatrix(this.layerSizeArray[m] + 1, this.layerSizeArray[m + 1]);
        }
        return zeros;
    }

    /** Returns {@code true} if {@code value} is {@code null} or contains only whitespace. */
    private static boolean isBlank(String value) {
        return value == null || value.trim().isEmpty();
    }

    /**
     * Renders weight matrices as text. Each matrix is printed as a {@code Matrix [i]} header
     * followed by one line per row.
     */
    protected static String weightsToString(DenseDoubleMatrix[] denseDoubleMatrixArr) {
        StringBuilder sb = new StringBuilder();
        for (int i = 0; i < denseDoubleMatrixArr.length; i++) {
            sb.append(String.format("Matrix [%d]\n", Integer.valueOf(i)));
            for (double[] row : denseDoubleMatrixArr[i].getValues()) {
                sb.append(Arrays.toString(row));
                sb.append('\n');
            }
            sb.append('\n');
        }
        return sb.toString();
    }
}
