package org.apache.hama.ml.ann;

import com.google.common.base.Preconditions;
import java.util.Map;
import org.apache.hadoop.fs.Path;
import org.apache.hama.commons.math.DenseDoubleVector;
import org.apache.hama.commons.math.DoubleMatrix;
import org.apache.hama.commons.math.DoubleVector;
import org.apache.hama.commons.math.FunctionFactory;
import org.apache.hama.ml.ann.AbstractLayeredNeuralNetwork;
import org.apache.hama.ml.util.FeatureTransformer;

/**
 * Auto-encoder with a single hidden layer, implemented as a thin wrapper around a
 * {@link SmallLayeredNeuralNetwork}.
 *
 * <p>When built with {@link #AutoEncoder(int, int)}, the wrapped network has three layers:
 * an input layer, a hidden layer holding the encoded representation, and an output layer
 * whose size equals the input size. All three use the "Sigmoid" squashing function, and the
 * network uses the {@code UNSUPERVISED} learning style with a "SquaredError" cost function.
 *
 * <p>{@link #encode(DoubleVector)} maps an input vector to the hidden layer, and
 * {@link #decode(DoubleVector)} maps a hidden-layer vector to the output layer.
 */
public class AutoEncoder {

    /** Guava message template; formatted only when a dimension check fails. */
    private static final String DIMENSION_MISMATCH =
        "The dimension of input instance is %s, but the model requires dimension %s.";

    private final SmallLayeredNeuralNetwork model;

    /**
     * Creates an untrained auto-encoder.
     *
     * @param inputDimensions size of the input layer; the output layer is created with the same size
     * @param compressedDimensions size of the hidden (encoded) layer
     */
    public AutoEncoder(int inputDimensions, int compressedDimensions) {
        this.model = new SmallLayeredNeuralNetwork();
        this.model.addLayer(inputDimensions, false, FunctionFactory.createDoubleFunction("Sigmoid"));
        this.model.addLayer(compressedDimensions, false, FunctionFactory.createDoubleFunction("Sigmoid"));
        // Only the last layer is added with `true` (presumably the "final layer" flag).
        this.model.addLayer(inputDimensions, true, FunctionFactory.createDoubleFunction("Sigmoid"));
        this.model.setLearningStyle(AbstractLayeredNeuralNetwork.LearningStyle.UNSUPERVISED);
        this.model.setCostFunction(FunctionFactory.createDoubleDoubleFunction("SquaredError"));
    }

    /**
     * Creates an auto-encoder backed by a model loaded via
     * {@code new SmallLayeredNeuralNetwork(modelPath)}.
     *
     * @param modelPath location of a previously saved model
     */
    public AutoEncoder(String modelPath) {
        this.model = new SmallLayeredNeuralNetwork(modelPath);
    }

    /**
     * Sets the learning rate of the underlying model.
     *
     * @param learningRate the learning rate
     * @return this instance, for chaining
     */
    public AutoEncoder setLearningRate(double learningRate) {
        this.model.setLearningRate(learningRate);
        return this;
    }

    /**
     * Sets the momentum weight of the underlying model.
     *
     * <p>The misspelled name mirrors the underlying model API and is kept for compatibility;
     * prefer {@link #setMomentumWeight(double)}.
     *
     * @param momentumWeight the momentum weight
     * @return this instance, for chaining
     */
    public AutoEncoder setMomemtumWeight(double momentumWeight) {
        this.model.setMomemtumWeight(momentumWeight);
        return this;
    }

    /**
     * Correctly spelled alias of {@link #setMomemtumWeight(double)}.
     *
     * @param momentumWeight the momentum weight
     * @return this instance, for chaining
     */
    public AutoEncoder setMomentumWeight(double momentumWeight) {
        return setMomemtumWeight(momentumWeight);
    }

    /**
     * Sets the regularization weight of the underlying model.
     *
     * @param regularizationWeight the regularization weight
     * @return this instance, for chaining
     */
    public AutoEncoder setRegularizationWeight(double regularizationWeight) {
        this.model.setRegularizationWeight(regularizationWeight);
        return this;
    }

    /**
     * Sets the path of the underlying model.
     *
     * @param modelPath the model path
     * @return this instance, for chaining
     */
    public AutoEncoder setModelPath(String modelPath) {
        this.model.setModelPath(modelPath);
        return this;
    }

    /**
     * Trains the underlying model on the data at {@code dataInputPath}.
     *
     * @param dataInputPath location of the training data
     * @param trainingParams options forwarded unchanged to the model's {@code train} method
     */
    public void train(Path dataInputPath, Map<String, String> trainingParams) {
        this.model.train(dataInputPath, trainingParams);
    }

    /**
     * Performs one online training step of the underlying model with a single instance.
     *
     * @param trainingInstance the training instance
     */
    public void trainOnline(DoubleVector trainingInstance) {
        this.model.trainOnline(trainingInstance);
    }

    /**
     * Returns the model's weight matrix for layer 0, which {@link #encode(DoubleVector)} uses.
     *
     * @return the encoding weight matrix
     */
    public DoubleMatrix getEncodeWeightMatrix() {
        return this.model.getWeightsByLayer(0);
    }

    /**
     * Returns the model's weight matrix for layer 1, which {@link #decode(DoubleVector)} uses.
     *
     * @return the decoding weight matrix
     */
    public DoubleMatrix getDecodeWeightMatrix() {
        return this.model.getWeightsByLayer(1);
    }

    /**
     * Feeds {@code inputInstance} forward through a single weight layer.
     *
     * <p>The input is prefixed with a constant {@code 1.0} (presumably the bias unit). The
     * result is then multiplied by the encode or decode weight matrix and squashed with the
     * model's squashing function at index {@code inputLayer}.
     *
     * @param inputInstance the vector to transform, without the leading constant
     * @param inputLayer 0 to use the encoding weights; any other value uses the decoding weights
     * @return the squashed layer output
     */
    private DoubleVector transform(DoubleVector inputInstance, int inputLayer) {
        int dimension = inputInstance.getDimension();
        DenseDoubleVector internalInstance = new DenseDoubleVector(dimension + 1);
        internalInstance.set(0, 1.0d);
        for (int idx = 0; idx < dimension; idx++) {
            internalInstance.set(idx + 1, inputInstance.get(idx));
        }
        DoubleMatrix weights = inputLayer == 0 ? getEncodeWeightMatrix() : getDecodeWeightMatrix();
        return weights.multiplyVectorUnsafe(internalInstance)
            .applyToElements(this.model.getSquashingFunction(inputLayer));
    }

    /**
     * Validates that {@code instance} matches the size of layer {@code layerIdx}.
     *
     * <p>The required dimension is {@code model.getLayerSize(layerIdx) - 1}; the {@code - 1}
     * presumably excludes the bias unit counted in the layer size.
     *
     * @param instance the vector to validate
     * @param layerIdx index of the layer the vector is fed into
     * @throws NullPointerException if {@code instance} is null
     * @throws IllegalArgumentException if the dimension does not match
     */
    private void checkInputDimension(DoubleVector instance, int layerIdx) {
        Preconditions.checkNotNull(instance, "The input instance must not be null.");
        int actual = instance.getDimension();
        // Read the expected size from the same layer that is checked, so the error message
        // always reports the dimension that was actually required.
        int expected = this.model.getLayerSize(layerIdx) - 1;
        Preconditions.checkArgument(actual == expected, DIMENSION_MISMATCH, actual, expected);
    }

    /**
     * Encodes an input vector into the hidden-layer representation.
     *
     * @param inputInstance vector of dimension {@code model.getLayerSize(0) - 1}
     * @return the hidden-layer activations
     * @throws NullPointerException if {@code inputInstance} is null
     * @throws IllegalArgumentException if the dimension of {@code inputInstance} is wrong
     */
    public DoubleVector encode(DoubleVector inputInstance) {
        checkInputDimension(inputInstance, 0);
        return transform(inputInstance, 0);
    }

    /**
     * Decodes a hidden-layer vector into the output-layer representation.
     *
     * @param inputInstance vector of dimension {@code model.getLayerSize(1) - 1}
     * @return the output-layer activations
     * @throws NullPointerException if {@code inputInstance} is null
     * @throws IllegalArgumentException if the dimension of {@code inputInstance} is wrong
     */
    public DoubleVector decode(DoubleVector inputInstance) {
        checkInputDimension(inputInstance, 1);
        return transform(inputInstance, 1);
    }

    /**
     * Returns the full network output (encode followed by decode) for {@code inputInstance},
     * as computed by the underlying model's {@code getOutput}.
     *
     * @param inputInstance the input vector
     * @return the model output
     */
    public DoubleVector getOutput(DoubleVector inputInstance) {
        return this.model.getOutput(inputInstance);
    }

    /**
     * Sets the feature transformer on the underlying model.
     *
     * @param featureTransformer the transformer to use
     */
    public void setFeatureTransformer(FeatureTransformer featureTransformer) {
        this.model.setFeatureTransformer(featureTransformer);
    }
}
