package org.apache.hama.ml.ann;

import com.google.common.base.Preconditions;
import com.google.common.collect.Lists;
import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import java.util.Iterator;
import java.util.List;
import org.apache.hadoop.io.WritableUtils;
import org.apache.hama.commons.math.DoubleDoubleFunction;
import org.apache.hama.commons.math.DoubleFunction;
import org.apache.hama.commons.math.DoubleMatrix;
import org.apache.hama.commons.math.DoubleVector;
import org.apache.hama.commons.math.FunctionFactory;

/* loaded from: input_file:org/apache/hama/ml/ann/AbstractLayeredNeuralNetwork.class */
/**
 * Base class for layered neural network models.
 *
 * <p>Holds the hyper-parameters shared by layered networks and (de)serializes them after the
 * state written by {@link NeuralNetwork}. The shared hyper-parameters are:
 * <ul>
 *   <li>regularization weight</li>
 *   <li>momentum weight</li>
 *   <li>cost function</li>
 *   <li>per-layer sizes</li>
 *   <li>training method</li>
 *   <li>learning style</li>
 * </ul>
 */
abstract class AbstractLayeredNeuralNetwork extends NeuralNetwork {
    private static final double DEFAULT_REGULARIZATION_WEIGHT = 0.0d;
    private static final double DEFAULT_MOMENTUM_WEIGHT = 0.1d;

    /** Training error; NOTE(review): presumably updated by {@link #calculateTrainingError}. */
    double trainingError;

    /** Regularization weight, validated to lie in [0, 1.0). */
    protected double regularizationWeight;

    /** Momentum weight, validated to lie in [0, 1.0]. */
    protected double momentumWeight;

    /** Cost function; serialized by its function name. */
    protected DoubleDoubleFunction costFunction;

    /** Size of each layer, indexed by layer position. */
    protected List<Integer> layerSizeList;

    protected TrainingMethod trainingMethod;
    protected LearningStyle learningStyle;

    /** Whether the network is trained with or without target labels. */
    public enum LearningStyle {
        UNSUPERVISED,
        SUPERVISED
    }

    /** Optimization algorithm used for training. */
    public enum TrainingMethod {
        GRADIENT_DESCENT
    }

    /**
     * Creates a network with the default hyper-parameters.
     *
     * <p>Defaults:
     * <ul>
     *   <li>regularization weight 0.0</li>
     *   <li>momentum weight 0.1</li>
     *   <li>{@link TrainingMethod#GRADIENT_DESCENT}</li>
     *   <li>{@link LearningStyle#SUPERVISED}</li>
     * </ul>
     */
    public AbstractLayeredNeuralNetwork() {
        this.regularizationWeight = DEFAULT_REGULARIZATION_WEIGHT;
        this.momentumWeight = DEFAULT_MOMENTUM_WEIGHT;
        this.trainingMethod = TrainingMethod.GRADIENT_DESCENT;
        this.learningStyle = LearningStyle.SUPERVISED;
    }

    /**
     * Creates a network by delegating to {@code NeuralNetwork(String)}.
     *
     * <p>NOTE(review): presumably the superclass loads a saved model from {@code modelPath}
     * (via {@link #readFields}). Defaults are intentionally NOT assigned here: an assignment
     * after {@code super(...)} would overwrite any values loaded during that call. Confirm
     * against {@code NeuralNetwork}.
     *
     * @param modelPath value passed through to the superclass constructor
     */
    public AbstractLayeredNeuralNetwork(String modelPath) {
        super(modelPath);
    }

    /**
     * Sets the regularization weight.
     *
     * @param regularizationWeight weight in the half-open range [0, 1.0)
     * @throws IllegalArgumentException if the weight is out of range
     */
    public void setRegularizationWeight(double regularizationWeight) {
        // Bounds are literal: they define the valid range and must not track the default value.
        Preconditions.checkArgument(regularizationWeight >= 0.0d && regularizationWeight < 1.0d,
                "Regularization weight must be in range [0, 1.0)");
        this.regularizationWeight = regularizationWeight;
    }

    /** @return the current regularization weight */
    public double getRegularizationWeight() {
        return this.regularizationWeight;
    }

    /**
     * Sets the momentum weight. The misspelled method name is kept for API compatibility.
     *
     * @param momentumWeight weight in the closed range [0, 1.0]
     * @throws IllegalArgumentException if the weight is out of range
     */
    public void setMomemtumWeight(double momentumWeight) {
        Preconditions.checkArgument(momentumWeight >= 0.0d && momentumWeight <= 1.0d,
                "Momentum weight must be in range [0, 1.0]");
        this.momentumWeight = momentumWeight;
    }

    /** @return the current momentum weight (method name misspelling kept for API compatibility) */
    public double getMomemtumWeight() {
        return this.momentumWeight;
    }

    public void setTrainingMethod(TrainingMethod trainingMethod) {
        this.trainingMethod = trainingMethod;
    }

    public TrainingMethod getTrainingMethod() {
        return this.trainingMethod;
    }

    public void setLearningStyle(LearningStyle learningStyle) {
        this.learningStyle = learningStyle;
    }

    public LearningStyle getLearningStyle() {
        return this.learningStyle;
    }

    /**
     * Sets the cost function. It must be non-null by the time {@link #write} is called.
     *
     * @param doubleDoubleFunction the cost function
     */
    public void setCostFunction(DoubleDoubleFunction doubleDoubleFunction) {
        this.costFunction = doubleDoubleFunction;
    }

    /**
     * Adds a layer to the network.
     *
     * <p>NOTE(review): parameter semantics are defined by subclasses. Presumably {@code i} is
     * the layer size, {@code z} flags the final layer and {@code doubleFunction} is the
     * activation; the return value is presumably the new layer's index. Confirm.
     */
    public abstract int addLayer(int i, boolean z, DoubleFunction doubleFunction);

    /**
     * Returns the size of the layer at the given position.
     *
     * @param layerIdx zero-based layer index
     * @return the size stored for that layer
     * @throws IllegalArgumentException if {@code layerIdx} is outside [0, layerCount - 1]
     */
    public int getLayerSize(int layerIdx) {
        // Guava's %s template is only formatted when the check fails,
        // avoiding a String.format on every successful call.
        Preconditions.checkArgument(layerIdx >= 0 && layerIdx < this.layerSizeList.size(),
                "Input must be in range [0, %s]\n", this.layerSizeList.size() - 1);
        return this.layerSizeList.get(layerIdx);
    }

    /** @return the live (not copied) list of layer sizes */
    protected List<Integer> getLayerSizeList() {
        return this.layerSizeList;
    }

    /** NOTE(review): presumably returns the weight matrix attached to layer {@code i}. */
    public abstract DoubleMatrix getWeightsByLayer(int i);

    /** NOTE(review): presumably computes weight updates for one training instance. */
    public abstract DoubleMatrix[] trainByInstance(DoubleVector doubleVector);

    /** Computes the network output for the given input vector (implemented by subclasses). */
    public abstract DoubleVector getOutput(DoubleVector doubleVector);

    /** NOTE(review): presumably updates {@link #trainingError} from two vectors; confirm argument order. */
    protected abstract void calculateTrainingError(DoubleVector doubleVector, DoubleVector doubleVector2);

    /**
     * Reads this class's state after the superclass state.
     *
     * <p>The fields are read in the same order {@link #write} emits them:
     * <ol>
     *   <li>regularization weight</li>
     *   <li>momentum weight</li>
     *   <li>cost-function name</li>
     *   <li>layer count followed by each layer size</li>
     *   <li>training method</li>
     *   <li>learning style</li>
     * </ol>
     */
    @Override
    public void readFields(DataInput dataInput) throws IOException {
        super.readFields(dataInput);
        this.regularizationWeight = dataInput.readDouble();
        this.momentumWeight = dataInput.readDouble();
        this.costFunction = FunctionFactory.createDoubleDoubleFunction(WritableUtils.readString(dataInput));

        int layerCount = dataInput.readInt();
        this.layerSizeList = Lists.newArrayList();
        for (int layer = 0; layer < layerCount; layer++) {
            this.layerSizeList.add(dataInput.readInt());
        }

        // readEnum is generic in the enum type, so no casts are needed.
        this.trainingMethod = WritableUtils.readEnum(dataInput, TrainingMethod.class);
        this.learningStyle = WritableUtils.readEnum(dataInput, LearningStyle.class);
    }

    /**
     * Writes the superclass state followed by this class's fields.
     *
     * <p>The order matches {@link #readFields}.
     *
     * @throws NullPointerException if any serialized reference field is unset. The check runs
     *     before anything is written, so the output is never left partially written.
     */
    @Override
    public void write(DataOutput dataOutput) throws IOException {
        Preconditions.checkNotNull(this.costFunction, "Cost function must be set before serialization.");
        Preconditions.checkNotNull(this.layerSizeList, "Layer sizes must be set before serialization.");
        Preconditions.checkNotNull(this.trainingMethod, "Training method must be set before serialization.");
        Preconditions.checkNotNull(this.learningStyle, "Learning style must be set before serialization.");

        super.write(dataOutput);
        dataOutput.writeDouble(this.regularizationWeight);
        dataOutput.writeDouble(this.momentumWeight);
        WritableUtils.writeString(dataOutput, this.costFunction.getFunctionName());
        dataOutput.writeInt(this.layerSizeList.size());
        for (int layerSize : this.layerSizeList) {
            dataOutput.writeInt(layerSize);
        }
        WritableUtils.writeEnum(dataOutput, this.trainingMethod);
        WritableUtils.writeEnum(dataOutput, this.learningStyle);
    }
}
