package org.apache.hama.ml.regression;

import java.math.BigDecimal;
import java.math.MathContext;
import org.apache.hama.commons.math.DoubleVector;

/* loaded from: input_file:org/apache/hama/ml/regression/LogisticRegressionModel.class */
/**
 * Logistic regression model.
 *
 * <p>The hypothesis is the sigmoid {@code h = 1 / (1 + e^(-theta . x))} and the per-item cost is
 * the log-loss {@code -(y * ln(h) + (1 - y) * ln(1 - h)) / m}. Both are evaluated with
 * {@link BigDecimal} arithmetic rounded to {@link MathContext#DECIMAL128}; logarithms themselves
 * are taken in {@code double} precision.
 */
public class LogisticRegressionModel implements RegressionModel {
    private static final MathContext DEFAULT_PRECISION = MathContext.DECIMAL128;

    /** Log-loss cost function that {@link #calculateCostForItem} delegates to. */
    private final CostFunction costFunction = new CostFunction() {
        /**
         * Returns {@code -(y * ln(h) + (1 - y) * ln(1 - h)) / m}, with {@code h} computed by
         * {@link LogisticRegressionModel#applyHypothesisWithPrecision}; the {@code hypothesis}
         * argument is not consulted.
         *
         * @throws ArithmeticException if {@code m} is zero
         * @throws NumberFormatException if a term with a non-zero weight needs the logarithm of
         *     zero, i.e. the prediction is saturated at exactly 0 or 1 against the label
         */
        @Override
        public BigDecimal calculateCostForItem(DoubleVector x, double y, int m, DoubleVector theta,
                HypothesisFunction hypothesis) {
            BigDecimal h = applyHypothesisWithPrecision(theta, x);
            BigDecimal oneMinusH = BigDecimal.ONE.subtract(h, DEFAULT_PRECISION);
            BigDecimal logLikelihood = weightedLn(BigDecimal.valueOf(y), h)
                    .add(weightedLn(BigDecimal.valueOf(1.0d - y), oneMinusH));
            // negate() rather than (-1) * m, which overflows for Integer.MIN_VALUE.
            return logLikelihood.divide(BigDecimal.valueOf(m).negate(), DEFAULT_PRECISION);
        }
    };

    /**
     * Returns the sigmoid hypothesis for {@code theta} and {@code x}; identical to
     * {@link #applyHypothesisWithPrecision}.
     */
    @Override
    public BigDecimal applyHypothesis(DoubleVector theta, DoubleVector x) {
        return applyHypothesisWithPrecision(theta, x);
    }

    /**
     * Computes the sigmoid {@code 1 / (1 + e^(-theta . x))}.
     *
     * <p>Kept {@code public} for compatibility with existing callers.
     *
     * @param theta vector whose negation is dotted with {@code x}
     * @param x second operand of the dot product
     * @return the hypothesis value; when it is indistinguishable from 1 as a {@code double}, it
     *     carries enough extra digits that {@code 1 - h} remains accurate to DECIMAL128 precision
     * @throws NumberFormatException if the dot product is NaN
     */
    public BigDecimal applyHypothesisWithPrecision(DoubleVector theta, DoubleVector x) {
        double exponent = theta.multiply(-1.0d).dotUnsafe(x);
        double expTerm = Math.exp(exponent);
        if (Double.isInfinite(expTerm)) {
            // e^exponent overflowed (exponent > ~709.78) and BigDecimal.valueOf would reject it.
            // Here 1 / (1 + e^exponent) equals e^-exponent up to a relative error below 1e-308,
            // far under DECIMAL128 precision; it may underflow to 0.0, the sigmoid's limit.
            return BigDecimal.valueOf(Math.exp(-exponent));
        }
        BigDecimal e = BigDecimal.valueOf(expTerm);
        BigDecimal denominator = BigDecimal.valueOf(1.0d).add(e);
        BigDecimal h = BigDecimal.ONE.divide(denominator, DEFAULT_PRECISION);
        if (h.doubleValue() == 1.0d && e.signum() > 0) {
            // This close to 1 the 34-digit quotient keeps few or none of the digits of 1 - h, so
            // ln(1 - h) in the cost would be inaccurate or ln(0). Rebuild h as 1 - e / (1 + e)
            // with an unrounded subtraction so that 1 - h keeps full DECIMAL128 precision.
            h = BigDecimal.ONE.subtract(e.divide(denominator, DEFAULT_PRECISION));
        }
        return h;
    }

    /**
     * Returns the natural logarithm of {@code value}, evaluated in {@code double} precision.
     *
     * <p>Kept {@code public} for compatibility with existing callers.
     *
     * @throws NumberFormatException if {@code value.doubleValue()} is not positive, because
     *     {@link Math#log} then yields -Infinity or NaN, which {@link BigDecimal#valueOf(double)}
     *     rejects
     */
    public BigDecimal ln(BigDecimal value) {
        return BigDecimal.valueOf(Math.log(value.doubleValue()));
    }

    /**
     * Returns {@code weight * ln(value)}, taking the product as zero when {@code weight} is zero
     * (the usual {@code 0 * ln(0) = 0} convention), so an irrelevant term never evaluates
     * {@code ln(0)}.
     */
    private BigDecimal weightedLn(BigDecimal weight, BigDecimal value) {
        if (weight.signum() == 0) {
            return BigDecimal.ZERO;
        }
        return weight.multiply(ln(value));
    }

    /**
     * Returns the log-loss of item {@code x} with label {@code y}, scaled by {@code -1/m}, under
     * weights {@code theta}; delegates to this model's cost function.
     */
    @Override
    public BigDecimal calculateCostForItem(DoubleVector x, double y, int m, DoubleVector theta) {
        return costFunction.calculateCostForItem(x, y, m, theta, this);
    }
}
