package org.apache.hama.ml.perception;

import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import java.lang.reflect.InvocationTargetException;
import java.net.URI;
import java.net.URISyntaxException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.Random;
import org.apache.commons.lang.SerializationUtils;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FSDataInputStream;
import org.apache.hadoop.fs.FSDataOutputStream;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.NullWritable;
import org.apache.hadoop.io.Writable;
import org.apache.hadoop.io.WritableUtils;
import org.apache.hama.HamaConfiguration;
import org.apache.hama.bsp.BSPJob;
import org.apache.hama.bsp.NullOutputFormat;
import org.apache.hama.bsp.SequenceFileInputFormat;
import org.apache.hama.commons.io.MatrixWritable;
import org.apache.hama.commons.io.VectorWritable;
import org.apache.hama.commons.math.DenseDoubleMatrix;
import org.apache.hama.commons.math.DenseDoubleVector;
import org.apache.hama.commons.math.DoubleFunction;
import org.apache.hama.commons.math.DoubleVector;
import org.apache.hama.commons.math.FunctionFactory;
import org.apache.hama.ml.util.FeatureTransformer;
import org.mortbay.log.Log;

/* loaded from: input_file:org/apache/hama/ml/perception/SmallMultiLayerPerceptron.class */
/**
 * A multi-layer perceptron that keeps all of its weights in memory, one dense matrix per pair of
 * adjacent layers, and that can be written to / restored from a Hadoop file system through the
 * {@link Writable} contract.
 *
 * <p>Weight layout, as used by {@link #forward(int, double[])}: {@code weightMatrice[i]} has
 * {@code layerSizeArray[i] + 1} rows and {@code layerSizeArray[i + 1]} columns. Row 0 multiplies
 * the constant {@code 1.0} bias element stored at index 0 of every non-output activation array, so
 * neuron {@code j} of layer {@code i} owns row {@code j + 1}.
 */
public final class SmallMultiLayerPerceptron extends MultiLayerPerceptron implements Writable {
    /** Weights between layer {@code i} and layer {@code i + 1}; see the class comment for the layout. */
    private DenseDoubleMatrix[] weightMatrice;

    /** Weight updates of the previous step; read for the momentum term of the output layer. */
    private DenseDoubleMatrix[] prevWeightUpdateMatrices;

    /**
     * Creates an untrained perceptron with randomly initialised weights.
     *
     * <p>All arguments are forwarded unchanged to the superclass constructor, which defines their
     * meaning. NOTE(review): presumably learning rate, regularization weight, momentum, squashing
     * function name, cost function name and layer sizes, in that order - confirm against
     * {@code MultiLayerPerceptron}.
     *
     * @param d first numeric hyper-parameter, forwarded to the superclass
     * @param d2 second numeric hyper-parameter, forwarded to the superclass
     * @param d3 third numeric hyper-parameter, forwarded to the superclass
     * @param str first name argument, forwarded to the superclass
     * @param str2 second name argument, forwarded to the superclass
     * @param iArr layer sizes, forwarded to the superclass
     */
    public SmallMultiLayerPerceptron(double d, double d2, double d3, String str, String str2, int[] iArr) {
        super(d, d2, d3, str, str2, iArr);
        initializeWeightMatrix();
        initializePrevWeightUpdateMatrix();
    }

    /**
     * Creates a perceptron and, when a path is given, tries to restore it from that model file.
     *
     * <p>Loading is best effort: an {@link IOException} raised while reading is printed and the
     * instance is returned without usable weights, so callers that need a loaded model must not
     * rely on this constructor failing.
     *
     * @param str model path forwarded to the superclass; {@code null} skips loading
     */
    public SmallMultiLayerPerceptron(String str) {
        super(str);
        if (str == null) {
            return;
        }
        try {
            readFromModel();
            initializePrevWeightUpdateMatrix();
        } catch (IOException e) {
            // Deliberately not rethrown: the original contract of this constructor is best effort.
            e.printStackTrace();
        }
    }

    /** Allocates every weight matrix and fills it with uniform random values in [-0.5, 0.5). */
    private void initializeWeightMatrix() {
        // One initializer (and one Random) is enough for all layers; apply() ignores its argument.
        DoubleFunction randomInitializer = new DoubleFunction() {
            private final Random rnd = new Random();

            public double apply(double d) {
                return this.rnd.nextDouble() - 0.5d;
            }

            public double applyDerivative(double d) {
                throw new UnsupportedOperationException("Not supported");
            }
        };

        int matrixCount = this.numberOfLayers - 1;
        this.weightMatrice = new DenseDoubleMatrix[matrixCount];
        for (int layer = 0; layer < matrixCount; layer++) {
            // The extra row holds the bias weights.
            DenseDoubleMatrix weights =
                    new DenseDoubleMatrix(this.layerSizeArray[layer] + 1, this.layerSizeArray[layer + 1]);
            weights.applyToElements(randomInitializer);
            this.weightMatrice[layer] = weights;
        }
    }

    /** Allocates the previous-update matrices with the same shapes as the weight matrices. */
    private void initializePrevWeightUpdateMatrix() {
        int matrixCount = this.numberOfLayers - 1;
        this.prevWeightUpdateMatrices = new DenseDoubleMatrix[matrixCount];
        for (int layer = 0; layer < matrixCount; layer++) {
            this.prevWeightUpdateMatrices[layer] =
                    new DenseDoubleMatrix(this.layerSizeArray[layer] + 1, this.layerSizeArray[layer + 1]);
        }
    }

    /**
     * Runs a forward pass and returns the activations of the output layer.
     *
     * @param doubleVector input features; its dimension must equal {@code layerSizeArray[0]}
     * @return the output layer activations
     * @throws IllegalStateException if the feature dimension does not match the input layer
     */
    @Override // org.apache.hama.ml.perception.MultiLayerPerceptron
    public DoubleVector outputWrapper(DoubleVector doubleVector) {
        List<double[]> layerOutputs = outputInternal(doubleVector);
        double[] outputLayer = layerOutputs.get(layerOutputs.size() - 1);
        return new DenseDoubleVector(outputLayer);
    }

    /**
     * Runs a forward pass and keeps the activations of every layer.
     *
     * @param features raw input features; they are passed through the feature transformer here
     * @return one array per layer, input layer first; every array except the last starts with the
     *         constant bias element {@code 1.0}
     * @throws IllegalStateException if the feature dimension does not match the input layer
     */
    private List<double[]> outputInternal(DoubleVector features) {
        int inputSize = this.layerSizeArray[0];
        if (inputSize != features.getDimension()) {
            // The reported size is the one the check actually compares against.
            throw new IllegalStateException("Input feature dimension incorrect! The dimension of input layer is " + inputSize + ", but the dimension of input feature is " + features.getDimension());
        }

        double[] activations = new double[inputSize + 1];
        activations[0] = 1.0d; // bias element
        DoubleVector transformed = this.featureTransformer.transform(features);
        for (int k = 0; k < transformed.getDimension(); k++) {
            activations[k + 1] = transformed.get(k);
        }

        List<double[]> layerOutputs = new ArrayList<double[]>();
        layerOutputs.add(activations);
        for (int layer = 0; layer < this.numberOfLayers - 1; layer++) {
            activations = forward(layer, activations);
            layerOutputs.add(activations);
        }
        return layerOutputs;
    }

    /**
     * Computes the activations of layer {@code fromLayer + 1} from those of layer {@code fromLayer}.
     *
     * @param fromLayer index of the layer whose activations are given
     * @param input activations of {@code fromLayer}, bias element at index 0
     * @return activations of the next layer; a leading bias element {@code 1.0} is included unless
     *         the next layer is the output layer
     */
    private double[] forward(int fromLayer, double[] input) {
        int toLayer = fromLayer + 1;
        int toSize = this.layerSizeArray[toLayer];
        boolean isOutputLayer = toLayer >= this.layerSizeArray.length - 1;

        // Hidden layers reserve slot 0 for the bias, so their neurons start at offset 1.
        int offset = isOutputLayer ? 0 : 1;
        double[] result = new double[toSize + offset];
        if (!isOutputLayer) {
            result[0] = 1.0d;
        }

        DenseDoubleMatrix weights = this.weightMatrice[fromLayer];
        int inputLength = this.layerSizeArray[fromLayer] + 1;
        for (int neuron = 0; neuron < toSize; neuron++) {
            int slot = neuron + offset;
            for (int source = 0; source < inputLength; source++) {
                result[slot] = result[slot] + (weights.get(source, neuron) * input[source]);
            }
            result[slot] = this.squashingFunction.apply(result[slot]);
        }
        return result;
    }

    /**
     * Computes the weight updates produced by a single training instance; the weights themselves
     * are not modified.
     *
     * @param doubleVector training instance: the first {@code layerSizeArray[0]} elements are the
     *        features, the remaining elements are the expected outputs; {@code null} yields
     *        all-zero update matrices
     * @return one update matrix per weight matrix, with matching shapes
     * @throws Exception declared for compatibility with existing callers
     */
    public DenseDoubleMatrix[] trainByInstance(DoubleVector doubleVector) throws Exception {
        int layerCount = this.layerSizeArray.length;
        DenseDoubleMatrix[] weightUpdates = new DenseDoubleMatrix[layerCount - 1];
        for (int layer = 0; layer < weightUpdates.length; layer++) {
            weightUpdates[layer] = new DenseDoubleMatrix(this.layerSizeArray[layer] + 1, this.layerSizeArray[layer + 1]);
        }
        if (doubleVector == null) {
            return weightUpdates;
        }

        int inputSize = this.layerSizeArray[0];
        int outputSize = this.layerSizeArray[layerCount - 1];
        int lastHiddenLength = this.layerSizeArray[layerCount - 2] + 1;

        double[] instance = doubleVector.toArray();
        // NOTE(review): the features are transformed here and again inside outputInternal(), while
        // outputWrapper() transforms only once. This is harmless only for an idempotent
        // transformer - confirm whether the double transformation is intended.
        double[] features = this.featureTransformer.transform(doubleVector.sliceUnsafe(0, inputSize - 1)).toArray();
        double[] labels = Arrays.copyOfRange(instance, inputSize, instance.length);

        List<double[]> layerOutputs = outputInternal(new DenseDoubleVector(features));
        double[] actual = layerOutputs.get(layerOutputs.size() - 1);
        double[] lastHidden = layerOutputs.get(layerOutputs.size() - 2);

        DenseDoubleMatrix outputUpdate = weightUpdates[weightUpdates.length - 1];
        DenseDoubleMatrix previousUpdate = this.prevWeightUpdateMatrices[this.prevWeightUpdateMatrices.length - 1];

        // Output layer: delta from the cost derivative, then the update including momentum.
        double[] delta = new double[outputSize];
        for (int out = 0; out < delta.length; out++) {
            delta[out] = this.costFunction.applyDerivative(labels[out], actual[out]);
            if (this.regularization != 0.0d) {
                // NOTE(review): the row index is bounded by the output layer size although the rows
                // of this matrix correspond to the previous layer (plus bias). Kept as in the
                // original because the intended averaging is not clear from this file.
                DenseDoubleMatrix lastWeights = this.weightMatrice[this.weightMatrice.length - 1];
                double weightSum = 0.0d;
                for (int row = 0; row < outputSize; row++) {
                    weightSum += lastWeights.get(row, out);
                }
                delta[out] = delta[out] + (this.regularization * (weightSum / outputSize));
            }
            delta[out] = delta[out] * this.squashingFunction.applyDerivative(actual[out]);
            for (int source = 0; source < lastHiddenLength; source++) {
                outputUpdate.set(source, out, ((-this.learningRate) * delta[out] * lastHidden[source]) + (this.momentum * previousUpdate.get(source, out)));
            }
        }

        // Hidden layers, from the last one down to the first.
        for (int layer = layerCount - 2; layer >= 1; layer--) {
            delta = backpropagate(layer, delta, layerOutputs, weightUpdates);
        }
        return weightUpdates;
    }

    /**
     * Propagates the error one layer back and fills the update matrix feeding that layer.
     *
     * @param layer index of the hidden layer whose deltas are computed
     * @param nextDelta deltas of layer {@code layer + 1}
     * @param layerOutputs activations of every layer as returned by {@code outputInternal}
     * @param weightUpdates update matrices; entry {@code layer - 1} is overwritten
     * @return deltas of {@code layer}, one per neuron (no bias entry)
     */
    private double[] backpropagate(int layer, double[] nextDelta, List<double[]> layerOutputs, DenseDoubleMatrix[] weightUpdates) {
        int previousLayer = layer - 1;
        double[] delta = new double[this.layerSizeArray[layer]];
        double[] currentOutput = layerOutputs.get(layer);
        double[] previousOutput = layerOutputs.get(previousLayer);
        DenseDoubleMatrix outgoingWeights = this.weightMatrice[layer];
        DenseDoubleMatrix incomingUpdate = weightUpdates[previousLayer];

        for (int neuron = 0; neuron < delta.length; neuron++) {
            // Row 0 of the weight matrix is the bias row, so this neuron's outgoing weights live in
            // row neuron + 1 - the same shift used for its activation below.
            int weightRow = neuron + 1;
            for (int next = 0; next < nextDelta.length; next++) {
                delta[neuron] = delta[neuron] + (outgoingWeights.get(weightRow, next) * nextDelta[next]);
            }
            delta[neuron] = delta[neuron] * this.squashingFunction.applyDerivative(currentOutput[neuron + 1]);

            // NOTE(review): unlike the output layer, no momentum term is added here - confirm.
            for (int source = 0; source < incomingUpdate.getRowCount(); source++) {
                incomingUpdate.set(source, neuron, (-this.learningRate) * delta[neuron] * previousOutput[source]);
            }
        }
        return delta;
    }

    /**
     * Trains the model with a BSP job and reloads the resulting model from {@code modelPath}.
     *
     * @param path location of the training data, read as a sequence file of
     *        {@code LongWritable}/{@code VectorWritable} pairs
     * @param map job parameters copied into the job configuration; {@code "modelPath"} names the
     *        file the model is reloaded from and {@code "tasks"} the number of BSP tasks (default 1)
     * @throws IOException if the job or the model reload fails
     * @throws InterruptedException if waiting for the job is interrupted
     * @throws ClassNotFoundException propagated from the job submission
     */
    @Override // org.apache.hama.ml.perception.MultiLayerPerceptron
    public void train(Path path, Map<String, String> map) throws IOException, InterruptedException, ClassNotFoundException {
        Configuration conf = new Configuration();
        for (Map.Entry<String, String> parameter : map.entrySet()) {
            conf.set(parameter.getKey(), parameter.getValue());
        }

        // Without an existing model file the trainer has to be told how to build the network.
        boolean hasModelPath = this.modelPath != null && this.modelPath.trim().length() != 0;
        if (!hasModelPath) {
            conf.set("MLPType", this.MLPType);
            conf.set("learningRate", String.valueOf(this.learningRate));
            conf.set("regularization", String.valueOf(this.regularization));
            conf.set("momentum", String.valueOf(this.momentum));
            conf.set("squashingFunctionName", this.squashingFunctionName);
            conf.set("costFunctionName", this.costFunctionName);
            StringBuilder layerSizes = new StringBuilder();
            for (int size : this.layerSizeArray) {
                layerSizes.append(size).append(' ');
            }
            conf.set("layerSizeArray", layerSizes.toString());
        }

        BSPJob job = new BSPJob(new HamaConfiguration(conf), SmallMLPTrainer.class);
        job.setJobName("Small scale MLP training");
        job.setJarByClass(SmallMLPTrainer.class);
        job.setBspClass(SmallMLPTrainer.class);
        job.setInputPath(path);
        job.setInputFormat(SequenceFileInputFormat.class);
        job.setInputKeyClass(LongWritable.class);
        job.setInputValueClass(VectorWritable.class);
        job.setOutputKeyClass(NullWritable.class);
        job.setOutputValueClass(NullWritable.class);
        job.setOutputFormat(NullOutputFormat.class);
        job.setNumBspTask(conf.getInt("tasks", 1));
        // NOTE(review): the job's success flag is ignored and "modelPath" may be absent from the
        // map; both cases fall through to the reload below - confirm this is intended.
        job.waitForCompletion(true);

        String trainedModelPath = map.get("modelPath");
        Log.info(String.format("Reload model from %s.", trainedModelPath));
        this.modelPath = trainedModelPath;
        readFromModel();
    }

    /**
     * Restores the model from the layout produced by {@link #write(DataOutput)}.
     *
     * @param dataInput source to read from
     * @throws IOException if reading fails or the stored feature transformer class cannot be
     *         instantiated through a public no-argument constructor
     */
    public void readFields(DataInput dataInput) throws IOException {
        this.MLPType = WritableUtils.readString(dataInput);
        this.learningRate = dataInput.readDouble();
        this.regularization = dataInput.readDouble();
        this.momentum = dataInput.readDouble();
        this.numberOfLayers = dataInput.readInt();
        this.squashingFunctionName = WritableUtils.readString(dataInput);
        this.costFunctionName = WritableUtils.readString(dataInput);
        this.squashingFunction = FunctionFactory.createDoubleFunction(this.squashingFunctionName);
        this.costFunction = FunctionFactory.createDoubleDoubleFunction(this.costFunctionName);

        this.layerSizeArray = new int[this.numberOfLayers];
        for (int layer = 0; layer < this.numberOfLayers; layer++) {
            this.layerSizeArray[layer] = dataInput.readInt();
        }

        this.weightMatrice = new DenseDoubleMatrix[this.numberOfLayers - 1];
        for (int layer = 0; layer < this.weightMatrice.length; layer++) {
            this.weightMatrice[layer] = (DenseDoubleMatrix) MatrixWritable.read(dataInput);
        }

        byte[] serializedClass = new byte[dataInput.readInt()];
        dataInput.readFully(serializedClass);

        // SECURITY NOTE: this is Java-native deserialization of bytes taken from the model file.
        // Only load model files from trusted locations.
        Class<?> transformerClass = (Class<?>) SerializationUtils.deserialize(serializedClass);
        try {
            // Ask explicitly for the no-argument constructor instead of "the first one found".
            this.featureTransformer = (FeatureTransformer) transformerClass.getConstructor().newInstance();
        } catch (NoSuchMethodException e) {
            throw transformerFailure(transformerClass, e);
        } catch (InstantiationException e) {
            throw transformerFailure(transformerClass, e);
        } catch (IllegalAccessException e) {
            throw transformerFailure(transformerClass, e);
        } catch (InvocationTargetException e) {
            throw transformerFailure(transformerClass, e);
        }
    }

    /** Builds the exception reported when the stored feature transformer cannot be created. */
    private static IOException transformerFailure(Class<?> transformerClass, Exception cause) {
        return new IOException("Cannot instantiate feature transformer " + transformerClass.getName(), cause);
    }

    /**
     * Serialises the model: type name, the three numeric hyper-parameters, layer count, function
     * names, layer sizes, weight matrices and finally the serialised feature transformer class.
     *
     * @param dataOutput sink to write to
     * @throws IOException if writing fails
     */
    public void write(DataOutput dataOutput) throws IOException {
        WritableUtils.writeString(dataOutput, this.MLPType);
        dataOutput.writeDouble(this.learningRate);
        dataOutput.writeDouble(this.regularization);
        dataOutput.writeDouble(this.momentum);
        dataOutput.writeInt(this.numberOfLayers);
        WritableUtils.writeString(dataOutput, this.squashingFunctionName);
        WritableUtils.writeString(dataOutput, this.costFunctionName);

        for (int layer = 0; layer < this.numberOfLayers; layer++) {
            dataOutput.writeInt(this.layerSizeArray[layer]);
        }
        for (int layer = 0; layer < this.numberOfLayers - 1; layer++) {
            new MatrixWritable(this.weightMatrice[layer]).write(dataOutput);
        }

        // Only the transformer's class is stored; readFields() creates a fresh instance of it.
        byte[] serializedClass = SerializationUtils.serialize(this.featureTransformer.getClass());
        dataOutput.writeInt(serializedClass.length);
        dataOutput.write(serializedClass);
    }

    /**
     * Loads the model stored at {@code modelPath} and verifies that it was written by this class.
     *
     * @throws IOException if the path is not a valid URI or the file cannot be read
     * @throws IllegalStateException if the stored model type is not this class
     */
    @Override // org.apache.hama.ml.perception.MultiLayerPerceptron
    protected void readFromModel() throws IOException {
        FileSystem fileSystem;
        try {
            fileSystem = FileSystem.get(new URI(this.modelPath), new Configuration());
        } catch (URISyntaxException e) {
            // Surface the failure instead of returning as if the model had been loaded.
            throw new IOException("Invalid model path URI: " + this.modelPath, e);
        }

        FSDataInputStream in = fileSystem.open(new Path(this.modelPath));
        try {
            readFields(in);
        } finally {
            in.close();
        }

        String expectedType = getClass().getName();
        if (!expectedType.equals(this.MLPType)) {
            throw new IllegalStateException(String.format("Model type incorrect, cannot load model '%s' for '%s'.", this.MLPType, expectedType));
        }
    }

    /**
     * Writes the model to the given path on the default file system, overwriting an existing file.
     *
     * @param str destination path
     * @throws IOException if the file cannot be created or written
     */
    @Override // org.apache.hama.ml.perception.MultiLayerPerceptron
    public void writeModelToFile(String str) throws IOException {
        FSDataOutputStream out = FileSystem.get(new Configuration()).create(new Path(str), true);
        try {
            write(out);
        } finally {
            // Close even when write() fails so the stream is not leaked.
            out.close();
        }
    }

    /** Returns the internal weight matrix array itself (no copy is made). */
    public DenseDoubleMatrix[] getWeightMatrices() {
        return this.weightMatrice;
    }

    /** Returns the internal previous-update matrix array itself (no copy is made). */
    public DenseDoubleMatrix[] getPrevWeightUpdateMatrices() {
        return this.prevWeightUpdateMatrices;
    }

    /**
     * Replaces the weight matrices; the given array is stored as is.
     *
     * @param denseDoubleMatrixArr new weight matrices
     */
    public void setWeightMatrices(DenseDoubleMatrix[] denseDoubleMatrixArr) {
        this.weightMatrice = denseDoubleMatrixArr;
    }

    /**
     * Replaces the previous-update matrices; the given array is stored as is.
     *
     * @param denseDoubleMatrixArr new previous-update matrices
     */
    public void setPrevWeightUpdateMatrices(DenseDoubleMatrix[] denseDoubleMatrixArr) {
        this.prevWeightUpdateMatrices = denseDoubleMatrixArr;
    }

    /**
     * Adds the given updates to the weight matrices, layer by layer.
     *
     * @param denseDoubleMatrixArr one update matrix per weight matrix
     */
    public void updateWeightMatrices(DenseDoubleMatrix[] denseDoubleMatrixArr) {
        for (int layer = 0; layer < this.weightMatrice.length; layer++) {
            DenseDoubleMatrix updated = (DenseDoubleMatrix) this.weightMatrice[layer].add(denseDoubleMatrixArr[layer]);
            this.weightMatrice[layer] = updated;
        }
    }

    /**
     * Renders the matrices as text: a {@code Matrix [i]} header, one line per row, then a blank
     * line.
     *
     * @param denseDoubleMatrixArr matrices to render
     * @return the textual form
     */
    static String weightsToString(DenseDoubleMatrix[] denseDoubleMatrixArr) {
        StringBuilder text = new StringBuilder();
        int index = 0;
        for (DenseDoubleMatrix matrix : denseDoubleMatrixArr) {
            text.append(String.format("Matrix [%d]\n", index));
            for (double[] row : matrix.getValues()) {
                text.append(Arrays.toString(row)).append('\n');
            }
            text.append('\n');
            index++;
        }
        return text.toString();
    }

    /** Returns the fully qualified name of this class, which is also checked when a model is loaded. */
    @Override // org.apache.hama.ml.perception.MultiLayerPerceptron
    protected String getTypeName() {
        return getClass().getName();
    }
}
