package org.apache.hama.ml.ann;

import com.google.common.base.Preconditions;
import com.google.common.collect.Lists;
import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import org.apache.commons.lang.math.RandomUtils;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.NullWritable;
import org.apache.hadoop.io.WritableUtils;
import org.apache.hama.HamaConfiguration;
import org.apache.hama.bsp.BSPJob;
import org.apache.hama.bsp.NullOutputFormat;
import org.apache.hama.bsp.SequenceFileInputFormat;
import org.apache.hama.commons.io.MatrixWritable;
import org.apache.hama.commons.io.VectorWritable;
import org.apache.hama.commons.math.DenseDoubleMatrix;
import org.apache.hama.commons.math.DenseDoubleVector;
import org.apache.hama.commons.math.DoubleDoubleFunction;
import org.apache.hama.commons.math.DoubleFunction;
import org.apache.hama.commons.math.DoubleMatrix;
import org.apache.hama.commons.math.DoubleVector;
import org.apache.hama.commons.math.FunctionFactory;
import org.apache.hama.ml.ann.AbstractLayeredNeuralNetwork;
import org.apache.hama.ml.util.FeatureTransformer;
import org.mortbay.log.Log;

/**
 * A layered (feed-forward) neural network whose weights are held in memory as one matrix per pair
 * of adjacent layers and trained with back-propagation / gradient descent.
 *
 * <p>Every non-final layer carries an extra bias neuron at index 0 (see {@link #addLayer}). The
 * weight matrix at index {@code i} maps the output of layer {@code i} (bias included) to the
 * non-bias neurons of layer {@code i + 1}, and {@code squashingFunctionList.get(i)} is the
 * activation that {@link #forward} applies when producing the output of layer {@code i + 1}.
 *
 * <p>{@link #trainInternal} runs training as a BSP job using
 * {@code SmallLayeredNeuralNetworkTrainer}.
 */
public class SmallLayeredNeuralNetwork extends AbstractLayeredNeuralNetwork {
    /** Weights between consecutive layers; entry {@code i} connects layer {@code i} to {@code i + 1}. */
    protected List<DoubleMatrix> weightMatrixList;
    /** Weight updates of the previous training step, used for the momentum term. */
    protected List<DoubleMatrix> prevWeightUpdatesList;
    /** Activation functions; entry {@code i} is applied to produce the output of layer {@code i + 1}. */
    protected List<DoubleFunction> squashingFunctionList;
    /**
     * Index of the layer most recently added with {@code isFinalLayer == true}.
     * NOTE(review): not restored by readFields(), so it is 0 after deserialization — confirm no
     * caller relies on it for a loaded model.
     */
    protected int finalLayerIdx;

    /** Creates an empty network; layers are added with {@link #addLayer}. */
    public SmallLayeredNeuralNetwork() {
        this.layerSizeList = Lists.newArrayList();
        this.weightMatrixList = Lists.newArrayList();
        this.prevWeightUpdatesList = Lists.newArrayList();
        this.squashingFunctionList = Lists.newArrayList();
    }

    /**
     * Delegates to the superclass constructor with {@code modelPath}.
     * NOTE(review): presumably the superclass loads a serialized model from this path (which
     * would populate the lists through {@link #readFields}) — confirm.
     *
     * @param modelPath location passed to the superclass constructor
     */
    public SmallLayeredNeuralNetwork(String modelPath) {
        super(modelPath);
    }

    /**
     * Appends a layer to the network and, unless it is the first layer, a randomly initialized
     * weight matrix connecting it to the previous layer.
     *
     * @param size number of neurons in the layer, excluding the bias neuron
     * @param isFinalLayer whether this is the output layer; non-final layers get an extra bias neuron
     * @param squashingFunction activation for this layer's output (not stored for the first layer)
     * @return the index of the newly added layer
     * @throws IllegalArgumentException if {@code size} is not positive
     */
    @Override // org.apache.hama.ml.ann.AbstractLayeredNeuralNetwork
    public int addLayer(int size, boolean isFinalLayer, DoubleFunction squashingFunction) {
        Preconditions.checkArgument(size > 0, "Size of layer must be larger than 0.");
        // Non-final layers carry an extra bias neuron at index 0.
        int sizeWithBias = isFinalLayer ? size : size + 1;
        this.layerSizeList.add(sizeWithBias);
        int layerIdx = this.layerSizeList.size() - 1;
        if (isFinalLayer) {
            this.finalLayerIdx = layerIdx;
        }
        if (layerIdx > 0) {
            int prevLayerSize = this.layerSizeList.get(layerIdx - 1);
            // The bias neuron of this layer has no incoming weights, so only `size` rows are needed;
            // the columns cover every neuron (bias included) of the previous layer.
            DoubleMatrix weightMatrix = new DenseDoubleMatrix(size, prevLayerSize);
            // Initialize each weight to RandomUtils.nextDouble() - 0.5.
            weightMatrix.applyToElements(new DoubleFunction() {
                public double apply(double value) {
                    return RandomUtils.nextDouble() - 0.5d;
                }

                public double applyDerivative(double value) {
                    throw new UnsupportedOperationException("");
                }
            });
            this.weightMatrixList.add(weightMatrix);
            this.prevWeightUpdatesList.add(new DenseDoubleMatrix(size, prevLayerSize));
            this.squashingFunctionList.add(squashingFunction);
        }
        return layerIdx;
    }

    /**
     * Adds {@code matrices[i]} to the i-th weight matrix, for every index of {@code matrices}.
     *
     * @param matrices weight updates, one per weight matrix (may be shorter than the weight list)
     */
    public void updateWeightMatrices(DoubleMatrix[] matrices) {
        for (int i = 0; i < matrices.length; i++) {
            this.weightMatrixList.set(i, this.weightMatrixList.get(i).add(matrices[i]));
        }
    }

    /** Replaces the stored previous weight updates (momentum state) with {@code matrices}. */
    public void setPrevWeightMatrices(DoubleMatrix[] matrices) {
        this.prevWeightUpdatesList.clear();
        Collections.addAll(this.prevWeightUpdatesList, matrices);
    }

    /**
     * Element-wise accumulates {@code increments} into {@code accumulators}, replacing each
     * {@code accumulators[i]} with {@code accumulators[i].add(increments[i])}.
     */
    public static void matricesAdd(DoubleMatrix[] accumulators, DoubleMatrix[] increments) {
        for (int i = 0; i < accumulators.length; i++) {
            accumulators[i] = accumulators[i].add(increments[i]);
        }
    }

    /** Returns the weight matrices as a new array (the matrices themselves are not copied). */
    public DoubleMatrix[] getWeightMatrices() {
        return this.weightMatrixList.toArray(new DoubleMatrix[this.weightMatrixList.size()]);
    }

    /** Replaces all weight matrices with the contents of {@code matrices}. */
    public void setWeightMatrices(DoubleMatrix[] matrices) {
        List<DoubleMatrix> replacement = Lists.newArrayList();
        Collections.addAll(replacement, matrices);
        this.weightMatrixList = replacement;
    }

    /** Returns the previous weight updates as a new array (the matrices themselves are not copied). */
    public DoubleMatrix[] getPrevMatricesUpdates() {
        return this.prevWeightUpdatesList.toArray(new DoubleMatrix[this.prevWeightUpdatesList.size()]);
    }

    /**
     * Replaces the weight matrix at {@code index}.
     *
     * @throws IllegalArgumentException if {@code index} is outside {@code [0, number of matrices)}
     */
    public void setWeightMatrix(int index, DoubleMatrix matrix) {
        Preconditions.checkArgument(0 <= index && index < this.weightMatrixList.size(),
                "index [%s] should be in range[%s, %s].", index, 0, this.weightMatrixList.size());
        this.weightMatrixList.set(index, matrix);
    }

    /**
     * Reads the state written by {@link #write}: after the superclass fields, the number of
     * squashing functions followed by their names, then the number of weight matrices followed by
     * the matrices. A zero-filled previous-update matrix is created for every weight matrix.
     */
    @Override // org.apache.hama.ml.ann.AbstractLayeredNeuralNetwork, org.apache.hama.ml.ann.NeuralNetwork
    public void readFields(DataInput input) throws IOException {
        super.readFields(input);

        int numFunctions = input.readInt();
        this.squashingFunctionList = Lists.newArrayList();
        for (int i = 0; i < numFunctions; i++) {
            this.squashingFunctionList.add(FunctionFactory.createDoubleFunction(WritableUtils.readString(input)));
        }

        int numMatrices = input.readInt();
        this.weightMatrixList = Lists.newArrayList();
        this.prevWeightUpdatesList = Lists.newArrayList();
        for (int i = 0; i < numMatrices; i++) {
            DoubleMatrix matrix = MatrixWritable.read(input);
            this.weightMatrixList.add(matrix);
            // Momentum state is not persisted; restart it from zero.
            this.prevWeightUpdatesList.add(new DenseDoubleMatrix(matrix.getRowCount(), matrix.getColumnCount()));
        }
    }

    /**
     * Writes the superclass fields, then the squashing-function names (count first), then the
     * weight matrices (count first). Previous weight updates are not written.
     */
    @Override // org.apache.hama.ml.ann.AbstractLayeredNeuralNetwork, org.apache.hama.ml.ann.NeuralNetwork
    public void write(DataOutput output) throws IOException {
        super.write(output);
        output.writeInt(this.squashingFunctionList.size());
        for (DoubleFunction function : this.squashingFunctionList) {
            WritableUtils.writeString(output, function.getFunctionName());
        }
        output.writeInt(this.weightMatrixList.size());
        for (DoubleMatrix matrix : this.weightMatrixList) {
            MatrixWritable.write(matrix, output);
        }
    }

    /** Returns the weight matrix connecting layer {@code layerIdx} to layer {@code layerIdx + 1}. */
    @Override // org.apache.hama.ml.ann.AbstractLayeredNeuralNetwork
    public DoubleMatrix getWeightsByLayer(int layerIdx) {
        return this.weightMatrixList.get(layerIdx);
    }

    /**
     * Computes the network output for a single instance.
     *
     * @param instance raw features; its dimension must equal the input layer size minus the bias
     * @return the output-layer activations, without the bias entry
     * @throws IllegalArgumentException if the instance has the wrong dimension
     */
    @Override // org.apache.hama.ml.ann.AbstractLayeredNeuralNetwork
    public DoubleVector getOutput(DoubleVector instance) {
        int inputDimension = this.layerSizeList.get(0) - 1;
        Preconditions.checkArgument(inputDimension == instance.getDimension(),
                "The dimension of input instance should be %s.", inputDimension);
        DoubleVector transformed = this.featureTransformer.transform(instance);
        // NOTE(review): inference uses a bias input of 0.99999 while trainByInstance() and forward()
        // use 1.0; kept unchanged to preserve existing outputs — confirm this is intended.
        DoubleVector instanceWithBias = prependBias(transformed, transformed.getDimension(), 0.99999d);
        List<DoubleVector> internalResults = getOutputInternal(instanceWithBias);
        DoubleVector output = internalResults.get(internalResults.size() - 1);
        // Drop the bias entry that forward() prepends to every layer's output.
        return output.sliceUnsafe(1, output.getDimension() - 1);
    }

    /**
     * Runs a forward pass and returns every layer's output.
     *
     * @param instanceWithBias input-layer values, bias entry at index 0 included
     * @return the input itself followed by the output of each subsequent layer (each with a
     *     leading bias entry of 1.0)
     */
    public List<DoubleVector> getOutputInternal(DoubleVector instanceWithBias) {
        List<DoubleVector> layerOutputs = Lists.newArrayList();
        DoubleVector current = instanceWithBias;
        layerOutputs.add(current);
        for (int layer = 0; layer < this.layerSizeList.size() - 1; layer++) {
            current = forward(layer, current);
            layerOutputs.add(current);
        }
        return layerOutputs;
    }

    /**
     * Computes the output of layer {@code fromLayer + 1} from the output of layer {@code fromLayer}.
     *
     * @return the squashed activations with a bias entry of 1.0 prepended at index 0
     */
    protected DoubleVector forward(int fromLayer, DoubleVector intermediateOutput) {
        DoubleVector activated = this.weightMatrixList.get(fromLayer)
                .multiplyVectorUnsafe(intermediateOutput)
                .applyToElements(this.squashingFunctionList.get(fromLayer));
        return prependBias(activated, activated.getDimension(), 1.0d);
    }

    /**
     * Returns a new vector of dimension {@code count + 1} holding {@code bias} at index 0 followed
     * by the first {@code count} entries of {@code values}.
     */
    private static DoubleVector prependBias(DoubleVector values, int count, double bias) {
        DoubleVector withBias = new DenseDoubleVector(count + 1);
        withBias.set(0, bias);
        for (int i = 0; i < count; i++) {
            withBias.set(i + 1, values.get(i));
        }
        return withBias;
    }

    /** Trains on one instance and immediately applies the resulting weight updates. */
    public void trainOnline(DoubleVector trainingInstance) {
        updateWeightMatrices(trainByInstance(trainingInstance));
    }

    /**
     * Computes the weight updates for a single training instance and records its training error.
     *
     * <p>In SUPERVISED mode the instance is the features followed by the labels. In UNSUPERVISED
     * mode it contains only the features, and the transformed features serve as the labels.
     *
     * @return one weight-update matrix per weight matrix
     * @throws IllegalArgumentException if the instance has the wrong dimension or the training
     *     method is not supported
     * @throws IllegalStateException if the learning style is neither SUPERVISED nor UNSUPERVISED
     */
    @Override // org.apache.hama.ml.ann.AbstractLayeredNeuralNetwork
    public DoubleMatrix[] trainByInstance(DoubleVector trainingInstance) {
        int inputDimension = this.layerSizeList.get(0) - 1;
        // Validate before slicing so a malformed instance gets the descriptive message rather
        // than an index error from sliceUnsafe.
        boolean supervised;
        if (this.learningStyle == AbstractLayeredNeuralNetwork.LearningStyle.SUPERVISED) {
            int labelDimension = this.layerSizeList.get(this.layerSizeList.size() - 1);
            Preconditions.checkArgument(inputDimension + labelDimension == trainingInstance.getDimension(),
                    "The dimension of training instance is %s, but requires %s.",
                    trainingInstance.getDimension(), inputDimension + labelDimension);
            supervised = true;
        } else if (this.learningStyle == AbstractLayeredNeuralNetwork.LearningStyle.UNSUPERVISED) {
            Preconditions.checkArgument(inputDimension == trainingInstance.getDimension(),
                    "The dimension of training instance is %s, but requires %s.",
                    trainingInstance.getDimension(), inputDimension);
            supervised = false;
        } else {
            throw new IllegalStateException(
                    String.format("Learning style %s is not supported.", this.learningStyle));
        }

        DoubleVector transformedFeatures =
                this.featureTransformer.transform(trainingInstance.sliceUnsafe(inputDimension));
        DoubleVector labels = supervised
                ? trainingInstance.sliceUnsafe(inputDimension, trainingInstance.getDimension() - 1)
                : transformedFeatures.deepCopy();

        DoubleVector inputWithBias = prependBias(transformedFeatures, inputDimension, 1.0d);
        List<DoubleVector> internalResults = getOutputInternal(inputWithBias);
        DoubleVector output = internalResults.get(internalResults.size() - 1);
        calculateTrainingError(labels, output.deepCopy().sliceUnsafe(1, output.getDimension() - 1));

        if (this.trainingMethod == AbstractLayeredNeuralNetwork.TrainingMethod.GRADIENT_DESCENT) {
            return trainByInstanceGradientDescent(labels, internalResults);
        }
        throw new IllegalArgumentException(
                String.format("Training method %s is not supported.", this.trainingMethod));
    }

    /**
     * Back-propagates the error of one instance and returns the resulting weight updates. The
     * updates are also stored as the previous updates for the momentum term.
     *
     * @param labels expected output-layer values
     * @param internalResults per-layer outputs as returned by {@link #getOutputInternal}
     */
    private DoubleMatrix[] trainByInstanceGradientDescent(DoubleVector labels, List<DoubleVector> internalResults) {
        DoubleVector output = internalResults.get(internalResults.size() - 1);

        // One zero-initialized update matrix per weight matrix, filled in by backpropagate().
        DenseDoubleMatrix[] weightUpdates = new DenseDoubleMatrix[this.weightMatrixList.size()];
        for (int m = 0; m < weightUpdates.length; m++) {
            DoubleMatrix weights = this.weightMatrixList.get(m);
            weightUpdates[m] = new DenseDoubleMatrix(weights.getRowCount(), weights.getColumnCount());
        }

        // Delta of the output layer. output.get(0) is the bias entry prepended by forward(),
        // hence the i + 1 offset.
        int outputSize = this.layerSizeList.get(this.layerSizeList.size() - 1);
        DoubleVector delta = new DenseDoubleVector(outputSize);
        DoubleFunction outputSquashingFunction = this.squashingFunctionList.get(this.squashingFunctionList.size() - 1);
        DoubleMatrix lastWeightMatrix = this.weightMatrixList.get(this.weightMatrixList.size() - 1);
        for (int i = 0; i < delta.getDimension(); i++) {
            double outputValue = output.get(i + 1);
            // NOTE(review): the regularization term adds regularizationWeight times the sum of the
            // neuron's incoming weights, and only at the output layer — confirm this is intended.
            double costDerivative = this.costFunction.applyDerivative(labels.get(i), outputValue)
                    + this.regularizationWeight * lastWeightMatrix.getRowVector(i).sum();
            delta.set(i, costDerivative * outputSquashingFunction.applyDerivative(outputValue));
        }

        // Walk backwards from the last hidden layer down to the input layer.
        for (int layer = this.layerSizeList.size() - 2; layer >= 0; layer--) {
            delta = backpropagate(layer, delta, internalResults, weightUpdates[layer]);
        }
        setPrevWeightMatrices(weightUpdates);
        return weightUpdates;
    }

    /**
     * Fills {@code weightUpdateMatrix} (the update for the weights from layer {@code curLayerIdx}
     * to {@code curLayerIdx + 1}) and returns the delta of layer {@code curLayerIdx}.
     *
     * @param nextLayerDelta delta of layer {@code curLayerIdx + 1}; includes a bias entry at
     *     index 0 unless that layer is the output layer
     * @return delta of layer {@code curLayerIdx} (bias entry included); for the input layer the
     *     value is not scaled by any activation derivative and is unused by the caller
     */
    private DoubleVector backpropagate(int curLayerIdx, DoubleVector nextLayerDelta,
            List<DoubleVector> outputCache, DenseDoubleMatrix weightUpdateMatrix) {
        DoubleVector curLayerOutput = outputCache.get(curLayerIdx);
        DoubleMatrix weightMatrix = this.weightMatrixList.get(curLayerIdx);
        DoubleMatrix prevWeightUpdate = this.prevWeightUpdatesList.get(curLayerIdx);

        // The bias neuron of a non-output layer has no incoming weights, so its delta is dropped.
        if (curLayerIdx != this.layerSizeList.size() - 2) {
            nextLayerDelta = nextLayerDelta.slice(1, nextLayerDelta.getDimension() - 1);
        }

        DoubleVector delta = weightMatrix.transpose().multiplyVector(nextLayerDelta);
        if (curLayerIdx > 0) {
            // forward(curLayerIdx - 1, ...) produced this layer's output with
            // squashingFunctionList.get(curLayerIdx - 1), so that is the derivative to apply here.
            DoubleFunction squashingFunction = this.squashingFunctionList.get(curLayerIdx - 1);
            for (int i = 0; i < delta.getDimension(); i++) {
                delta.set(i, delta.get(i) * squashingFunction.applyDerivative(curLayerOutput.get(i)));
            }
        }

        // Gradient step plus momentum from the previous update.
        for (int row = 0; row < weightUpdateMatrix.getRowCount(); row++) {
            for (int col = 0; col < weightUpdateMatrix.getColumnCount(); col++) {
                weightUpdateMatrix.set(row, col,
                        (-this.learningRate) * nextLayerDelta.get(row) * curLayerOutput.get(col)
                                + this.momentumWeight * prevWeightUpdate.get(row, col));
            }
        }
        return delta;
    }

    /**
     * Trains the model with a BSP job running {@code SmallLayeredNeuralNetworkTrainer} and then
     * reloads the model from {@code modelPath}.
     *
     * <p>All training parameters are copied into the job configuration. A {@code "modelPath"}
     * parameter overrides the current model path, and {@code "tasks"} sets the number of BSP
     * tasks (default 1). Input is read as sequence files of {@code LongWritable} keys and
     * {@code VectorWritable} values.
     *
     * @throws IllegalArgumentException if no model path is available
     * @throws IOException if the job fails or the model cannot be written or read
     */
    @Override // org.apache.hama.ml.ann.NeuralNetwork
    protected void trainInternal(Path dataInputPath, Map<String, String> trainingParams)
            throws IOException, InterruptedException, ClassNotFoundException {
        Configuration conf = new Configuration();
        for (Map.Entry<String, String> entry : trainingParams.entrySet()) {
            conf.set(entry.getKey(), entry.getValue());
        }

        String modelPathParam = trainingParams.get("modelPath");
        if (modelPathParam != null) {
            this.modelPath = modelPathParam;
        }
        if (this.modelPath == null) {
            throw new IllegalArgumentException("Please specify the modelPath for model, either through setModelPath() or add 'modelPath' to the training parameters.");
        }
        conf.set("modelPath", this.modelPath);
        // Write the current model out before the job starts.
        writeModelToFile();

        BSPJob job = new BSPJob(new HamaConfiguration(conf), SmallLayeredNeuralNetworkTrainer.class);
        job.setJobName("Small scale Neural Network training");
        job.setJarByClass(SmallLayeredNeuralNetworkTrainer.class);
        job.setBspClass(SmallLayeredNeuralNetworkTrainer.class);
        job.setInputPath(dataInputPath);
        job.setInputFormat(SequenceFileInputFormat.class);
        job.setInputKeyClass(LongWritable.class);
        job.setInputValueClass(VectorWritable.class);
        job.setOutputKeyClass(NullWritable.class);
        job.setOutputValueClass(NullWritable.class);
        job.setOutputFormat(NullOutputFormat.class);

        int numTasks = conf.getInt("tasks", 1);
        Log.info(String.format("Number of tasks: %d", numTasks));
        job.setNumBspTask(numTasks);
        // Do not silently reload a model that the failed job never updated.
        if (!job.waitForCompletion(true)) {
            throw new IOException(String.format(
                    "Neural network training job failed; model was not reloaded from %s.", this.modelPath));
        }

        Log.info(String.format("Reload model from %s.", this.modelPath));
        readFromModel();
    }

    /**
     * Sets {@code trainingError} to the sum of the cost function applied element-wise to the
     * labels and the output.
     */
    @Override // org.apache.hama.ml.ann.AbstractLayeredNeuralNetwork
    protected void calculateTrainingError(DoubleVector labels, DoubleVector output) {
        this.trainingError = labels.deepCopy().applyToElements(output, this.costFunction).sum();
    }

    /** Returns the activation applied to produce the output of layer {@code idx + 1}. */
    public DoubleFunction getSquashingFunction(int idx) {
        return this.squashingFunctionList.get(idx);
    }

    // The methods below re-declare inherited public accessors; each only delegates to the
    // superclass implementation.

    @Override // org.apache.hama.ml.ann.AbstractLayeredNeuralNetwork
    public int getLayerSize(int i) {
        return super.getLayerSize(i);
    }

    @Override // org.apache.hama.ml.ann.AbstractLayeredNeuralNetwork
    public void setCostFunction(DoubleDoubleFunction doubleDoubleFunction) {
        super.setCostFunction(doubleDoubleFunction);
    }

    @Override // org.apache.hama.ml.ann.AbstractLayeredNeuralNetwork
    public AbstractLayeredNeuralNetwork.LearningStyle getLearningStyle() {
        return super.getLearningStyle();
    }

    @Override // org.apache.hama.ml.ann.AbstractLayeredNeuralNetwork
    public void setLearningStyle(AbstractLayeredNeuralNetwork.LearningStyle learningStyle) {
        super.setLearningStyle(learningStyle);
    }

    @Override // org.apache.hama.ml.ann.AbstractLayeredNeuralNetwork
    public AbstractLayeredNeuralNetwork.TrainingMethod getTrainingMethod() {
        return super.getTrainingMethod();
    }

    @Override // org.apache.hama.ml.ann.AbstractLayeredNeuralNetwork
    public void setTrainingMethod(AbstractLayeredNeuralNetwork.TrainingMethod trainingMethod) {
        super.setTrainingMethod(trainingMethod);
    }

    @Override // org.apache.hama.ml.ann.AbstractLayeredNeuralNetwork
    public double getMomemtumWeight() {
        return super.getMomemtumWeight();
    }

    @Override // org.apache.hama.ml.ann.AbstractLayeredNeuralNetwork
    public void setMomemtumWeight(double d) {
        super.setMomemtumWeight(d);
    }

    @Override // org.apache.hama.ml.ann.AbstractLayeredNeuralNetwork
    public double getRegularizationWeight() {
        return super.getRegularizationWeight();
    }

    @Override // org.apache.hama.ml.ann.AbstractLayeredNeuralNetwork
    public void setRegularizationWeight(double d) {
        super.setRegularizationWeight(d);
    }

    @Override // org.apache.hama.ml.ann.NeuralNetwork
    public FeatureTransformer getFeatureTransformer() {
        return super.getFeatureTransformer();
    }

    @Override // org.apache.hama.ml.ann.NeuralNetwork
    public void setFeatureTransformer(FeatureTransformer featureTransformer) {
        super.setFeatureTransformer(featureTransformer);
    }

    @Override // org.apache.hama.ml.ann.NeuralNetwork
    public String getModelPath() {
        return super.getModelPath();
    }

    @Override // org.apache.hama.ml.ann.NeuralNetwork
    public void setModelPath(String str) {
        super.setModelPath(str);
    }

    @Override // org.apache.hama.ml.ann.NeuralNetwork
    public void writeModelToFile() throws IOException {
        super.writeModelToFile();
    }

    @Override // org.apache.hama.ml.ann.NeuralNetwork
    public void train(Path path, Map map) {
        super.train(path, map);
    }

    @Override // org.apache.hama.ml.ann.NeuralNetwork
    public String getModelType() {
        return super.getModelType();
    }

    @Override // org.apache.hama.ml.ann.NeuralNetwork
    public void isLearningRateDecay(boolean z) {
        super.isLearningRateDecay(z);
    }

    @Override // org.apache.hama.ml.ann.NeuralNetwork
    public double getLearningRate() {
        return super.getLearningRate();
    }

    @Override // org.apache.hama.ml.ann.NeuralNetwork
    public void setLearningRate(double d) {
        super.setLearningRate(d);
    }
}
