package org.apache.hama.ml.regression;

import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.Map;
import org.apache.hadoop.fs.Path;
import org.apache.hama.commons.math.DoubleVector;
import org.apache.hama.commons.math.FunctionFactory;
import org.apache.hama.ml.ann.SmallLayeredNeuralNetwork;
import org.apache.hama.ml.util.FeatureTransformer;

/* loaded from: input_file:org/apache/hama/ml/regression/LogisticRegression.class */
/**
 * Logistic regression implemented as a thin facade over a {@link SmallLayeredNeuralNetwork}.
 *
 * <p>A freshly constructed model consists of an input layer plus a single sigmoid output unit,
 * trained with the cross-entropy cost function. All configuration, training, prediction and
 * persistence calls are delegated to the wrapped network.
 */
public class LogisticRegression {
    /** The underlying network that holds the weights and does all the work. */
    private final SmallLayeredNeuralNetwork ann;

    /**
     * Creates an untrained logistic regression model.
     *
     * <p>The wrapped network gets two layers. The first is a layer of {@code dimension} units with
     * a "Sigmoid" squashing function. The second is a single-unit layer, flagged as the final layer,
     * also with "Sigmoid". The cost function is "CrossEntropy".
     *
     * @param dimension number of input features; validation of the value is left to
     *     {@code SmallLayeredNeuralNetwork.addLayer}
     */
    public LogisticRegression(int dimension) {
        this.ann = new SmallLayeredNeuralNetwork();
        this.ann.addLayer(dimension, false, FunctionFactory.createDoubleFunction("Sigmoid"));
        // A single output neuron yields the probability estimate of the positive class.
        this.ann.addLayer(1, true, FunctionFactory.createDoubleFunction("Sigmoid"));
        this.ann.setCostFunction(FunctionFactory.createDoubleDoubleFunction("CrossEntropy"));
    }

    /**
     * Creates a model backed by a network constructed from the given path.
     *
     * <p>NOTE(review): presumably this loads a previously written model from {@code modelPath};
     * confirm against {@code SmallLayeredNeuralNetwork(String)}.
     *
     * @param modelPath path passed straight to the network constructor
     */
    public LogisticRegression(String modelPath) {
        this.ann = new SmallLayeredNeuralNetwork(modelPath);
    }

    /**
     * Sets the learning rate of the underlying network.
     *
     * @param learningRate the new learning rate
     * @return this instance, for call chaining
     */
    public LogisticRegression setLearningRate(double learningRate) {
        this.ann.setLearningRate(learningRate);
        return this;
    }

    /**
     * Returns the learning rate of the underlying network.
     *
     * @return the current learning rate
     */
    public double getLearningRate() {
        return this.ann.getLearningRate();
    }

    /**
     * Sets the momentum weight of the underlying network. The misspelled name mirrors the
     * network's own API and is kept for compatibility.
     *
     * @param momentumWeight the new momentum weight
     * @return this instance, for call chaining
     */
    public LogisticRegression setMomemtumWeight(double momentumWeight) {
        this.ann.setMomemtumWeight(momentumWeight);
        return this;
    }

    /**
     * Returns the momentum weight of the underlying network.
     *
     * @return the current momentum weight
     */
    public double getMomemtumWeight() {
        return this.ann.getMomemtumWeight();
    }

    /**
     * Sets the regularization weight of the underlying network.
     *
     * @param regularizationWeight the new regularization weight
     * @return this instance, for call chaining
     */
    public LogisticRegression setRegularizationWeight(double regularizationWeight) {
        this.ann.setRegularizationWeight(regularizationWeight);
        return this;
    }

    /**
     * Returns the regularization weight of the underlying network.
     *
     * @return the current regularization weight
     */
    public double getRegularizationWeight() {
        return this.ann.getRegularizationWeight();
    }

    /**
     * Performs one online training step on a single instance.
     *
     * <p>NOTE(review): the expected layout of {@code trainingInstance} (features followed by the
     * label?) is defined by {@code SmallLayeredNeuralNetwork.trainOnline}; confirm there.
     *
     * @param trainingInstance the training instance forwarded to the network
     */
    public void trainOnline(DoubleVector trainingInstance) {
        this.ann.trainOnline(trainingInstance);
    }

    /**
     * Trains the model on the data at the given path, delegating to the network's {@code train}.
     *
     * @param dataInputPath location of the training data
     * @param trainingParams training parameters forwarded verbatim to the network
     */
    public void train(Path dataInputPath, Map<String, String> trainingParams) {
        this.ann.train(dataInputPath, trainingParams);
    }

    /**
     * Computes the network output for the given instance.
     *
     * @param instance the input instance
     * @return the network's output vector, which holds one value because there is one output unit
     */
    public DoubleVector getOutput(DoubleVector instance) {
        return this.ann.getOutput(instance);
    }

    /**
     * Sets the model path on the underlying network.
     *
     * @param modelPath the path forwarded to the network
     */
    public void setModelPath(String modelPath) {
        this.ann.setModelPath(modelPath);
    }

    /**
     * Writes the model via the underlying network. The destination is presumably the path set
     * with {@link #setModelPath(String)}.
     *
     * @throws UncheckedIOException if the network fails to write the model; the original
     *     {@link IOException} is preserved as the cause
     */
    public void writeModelToFile() {
        try {
            this.ann.writeModelToFile();
        } catch (IOException e) {
            // Previously this only printed the stack trace, so callers silently believed the model
            // was saved. Rethrow unchecked so the method signature stays source-compatible.
            throw new UncheckedIOException("Failed to write the logistic regression model", e);
        }
    }

    /**
     * Returns row 0 of the network's layer-0 weight matrix.
     *
     * <p>This row holds the weights feeding the single output unit. NOTE(review): whether it also
     * contains a bias weight, and at which index, depends on the network's layout; confirm before
     * relying on positions.
     *
     * @return the learned weight vector
     */
    public DoubleVector getWeights() {
        return this.ann.getWeightsByLayer(0).getRowVector(0);
    }

    /**
     * Sets the feature transformer used by the underlying network.
     *
     * @param featureTransformer the transformer forwarded to the network
     */
    public void setFeatureTransformer(FeatureTransformer featureTransformer) {
        this.ann.setFeatureTransformer(featureTransformer);
    }
}
