package org.apache.hama.ml.regression;

import java.io.IOException;
import java.math.BigDecimal;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.apache.hadoop.io.DoubleWritable;
import org.apache.hama.bsp.BSP;
import org.apache.hama.bsp.BSPPeer;
import org.apache.hama.bsp.sync.SyncException;
import org.apache.hama.commons.io.VectorWritable;
import org.apache.hama.commons.math.DenseDoubleVector;
import org.apache.hama.commons.math.DoubleVector;
import org.apache.hama.commons.util.KeyValuePair;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/apache/hama/ml/regression/GradientDescentBSP.class */
/**
 * Distributed batch gradient descent implemented as a Hama BSP job.
 *
 * <p>Each peer reads its own partition of input records whose key is the feature vector and whose
 * value is the target. In every iteration the peers exchange their local cost and local partial
 * derivative sums via messages. Each peer aggregates the same set of contributions and updates its
 * own copy of {@code theta}. The peer whose index equals {@code numPeers / 2} is the master: it
 * chooses the initial theta and writes {@code (theta, cost)} records to the output.
 *
 * <p>Configuration keys read in {@link #setup} / {@link #getInitialTheta}:
 * {@link #INITIAL_THETA_VALUES} (int, default 1), {@link #ALPHA} (float, default 0.003),
 * {@link #COST_THRESHOLD} (float, default 0.1), {@link #ITERATIONS_THRESHOLD} (int, default 10000)
 * and {@link #REGRESSION_MODEL_CLASS} (default {@code LinearRegressionModel}).
 */
public class GradientDescentBSP extends BSP<VectorWritable, DoubleWritable, VectorWritable, DoubleWritable, VectorWritable> {
    private static final Logger log = LoggerFactory.getLogger(GradientDescentBSP.class);
    /** Configuration key: value (read as int) assigned to every component of the initial theta. */
    public static final String INITIAL_THETA_VALUES = "gd.initial.theta";
    /** Configuration key: step size multiplied with the aggregated derivatives. */
    public static final String ALPHA = "gd.alpha";
    /** Configuration key: computation stops once the aggregated cost drops below this value. */
    public static final String COST_THRESHOLD = "gd.cost.threshold";
    /** Configuration key: computation stops once this many iterations have been performed. */
    public static final String ITERATIONS_THRESHOLD = "gd.iterations.threshold";
    /** Configuration key: {@link RegressionModel} implementation to instantiate. */
    public static final String REGRESSION_MODEL_CLASS = "gd.regression.model";

    /** True on the single peer that chooses the initial theta and writes the output. */
    private boolean master;
    /** Current parameter vector; every peer keeps its own copy. */
    private DoubleVector theta;
    /** Aggregated cost of the most recently checked iteration. */
    private double cost;
    private double costThreshold;
    private float alpha;
    private RegressionModel regressionModel;
    private int iterationsThreshold;
    /** Total number of input items over all peers. */
    private int m;

    /**
     * Reads the job configuration, elects the master peer and instantiates the regression model.
     *
     * @throws IOException if the configured regression model cannot be instantiated
     */
    @Override
    public void setup(BSPPeer<VectorWritable, DoubleWritable, VectorWritable, DoubleWritable, VectorWritable> bSPPeer) throws IOException, SyncException, InterruptedException {
        this.master = bSPPeer.getPeerIndex() == bSPPeer.getNumPeers() / 2;
        this.cost = Double.MAX_VALUE;
        this.costThreshold = bSPPeer.getConfiguration().getFloat(COST_THRESHOLD, 0.1f);
        this.iterationsThreshold = bSPPeer.getConfiguration().getInt(ITERATIONS_THRESHOLD, 10000);
        this.alpha = bSPPeer.getConfiguration().getFloat(ALPHA, 0.003f);
        try {
            Class<?> modelClass = bSPPeer.getConfiguration().getClass(REGRESSION_MODEL_CLASS, LinearRegressionModel.class);
            // Constructor.newInstance wraps constructor failures instead of sneaky-throwing them
            // like the deprecated Class.newInstance does.
            this.regressionModel = (RegressionModel) modelClass.getDeclaredConstructor().newInstance();
        } catch (Exception e) {
            throw new IOException("cannot instantiate the regression model configured under " + REGRESSION_MODEL_CLASS, e);
        }
    }

    /**
     * Runs the gradient descent loop until {@link #checkCost} reports convergence or the iteration
     * limit.
     *
     * <p>Every peer executes the same sequence of {@code sync()} calls. Because the aggregation
     * helpers produce identical totals on all peers, all peers leave the loop in the same iteration.
     */
    @Override
    public void bsp(BSPPeer<VectorWritable, DoubleWritable, VectorWritable, DoubleWritable, VectorWritable> bSPPeer) throws IOException, SyncException, InterruptedException {
        getInitialTheta(bSPPeer);

        // Count the local items and share the count so that every peer learns the global total m.
        int localItems = 0;
        while (bSPPeer.readNext() != null) {
            localItems++;
        }
        broadcastVector(bSPPeer, new double[]{localItems});
        bSPPeer.sync();
        aggregateItemsNumber(bSPPeer, localItems);
        bSPPeer.reopenInput();

        int iterations = 0;
        while (true) {
            // Exchange local costs; every peer evaluates the stop condition on the same total.
            double localCost = calculateLocalCost(bSPPeer);
            broadcastVector(bSPPeer, new double[]{localCost});
            bSPPeer.sync();
            double totalCost = aggregateTotalCost(bSPPeer, localCost);
            if (checkCost(bSPPeer, iterations, totalCost)) {
                return;
            }
            bSPPeer.sync();
            bSPPeer.reopenInput();

            // Exchange local partial derivative sums and take one step.
            double[] localDerivatives = calculatePartialDerivatives(bSPPeer);
            broadcastVector(bSPPeer, localDerivatives);
            bSPPeer.sync();
            updateTheta(aggregatePartialDerivatives(bSPPeer, localDerivatives));
            if (log.isDebugEnabled()) {
                log.debug("{}: new theta for cost {} is {}", new Object[]{bSPPeer.getPeerName(), Double.valueOf(this.cost), this.theta});
            }
            if (this.master) {
                bSPPeer.write(new VectorWritable(this.theta), new DoubleWritable(this.cost));
            }
            bSPPeer.reopenInput();
            bSPPeer.sync();
            iterations++;
        }
    }

    /**
     * Drains the current messages and sums their first component with {@code d}.
     *
     * <p>The sum is order-independent (see {@link #sortedComponentSum}), so every peer obtains a
     * bit-identical total. Otherwise peers could disagree in {@link #checkCost}, and one would
     * stop while the others block in {@code sync()}.
     */
    private double aggregateTotalCost(BSPPeer<VectorWritable, DoubleWritable, VectorWritable, DoubleWritable, VectorWritable> bSPPeer, double d) throws IOException {
        return sortedComponentSum(collectContributions(bSPPeer, new double[]{d}), 0);
    }

    /**
     * Drains the current messages and returns the component-wise sum of {@code dArr} and every
     * received vector. Each component is summed order-independently, so all peers derive the same
     * theta.
     */
    private double[] aggregatePartialDerivatives(BSPPeer<VectorWritable, DoubleWritable, VectorWritable, DoubleWritable, VectorWritable> bSPPeer, double[] dArr) throws IOException {
        List<double[]> contributions = collectContributions(bSPPeer, dArr);
        double[] totals = new double[dArr.length];
        for (int j = 0; j < totals.length; j++) {
            totals[j] = sortedComponentSum(contributions, j);
        }
        return totals;
    }

    /**
     * Returns {@code local} followed by the contents of every pending message, draining the message
     * queue.
     */
    private static List<double[]> collectContributions(BSPPeer<VectorWritable, DoubleWritable, VectorWritable, DoubleWritable, VectorWritable> bSPPeer, double[] local) throws IOException {
        List<double[]> contributions = new ArrayList<double[]>();
        contributions.add(local);
        VectorWritable message;
        while ((message = bSPPeer.getCurrentMessage()) != null) {
            contributions.add(message.getVector().toArray());
        }
        return contributions;
    }

    /**
     * Sums component {@code index} of all contributions in ascending order. The result depends
     * only on the multiset of values, not on the order in which they were received.
     */
    private static double sortedComponentSum(List<double[]> contributions, int index) {
        double[] column = new double[contributions.size()];
        for (int k = 0; k < column.length; k++) {
            column[k] = contributions.get(k)[index];
        }
        // Floating-point addition is not associative; a canonical order makes the sum reproducible.
        Arrays.sort(column);
        double sum = 0.0d;
        for (double value : column) {
            sum += value;
        }
        return sum;
    }

    /**
     * Replaces theta with {@code theta - alpha * derivatives}, component by component.
     * The derivatives are the raw sums from {@link #calculatePartialDerivatives}; they are not
     * divided by {@code m} here.
     */
    private void updateTheta(double[] dArr) {
        double[] updated = new double[this.theta.getLength()];
        for (int j = 0; j < updated.length; j++) {
            updated[j] = this.theta.get(j) - (dArr[j] * this.alpha);
        }
        this.theta = new DenseDoubleVector(updated);
    }

    /** Drains the current messages, adds their first component to {@code localItems} and stores the total in {@code m}. */
    private void aggregateItemsNumber(BSPPeer<VectorWritable, DoubleWritable, VectorWritable, DoubleWritable, VectorWritable> bSPPeer, int localItems) throws IOException {
        int total = localItems;
        VectorWritable message;
        while ((message = bSPPeer.getCurrentMessage()) != null) {
            total = (int) (total + message.getVector().get(0));
        }
        this.m = total;
    }

    /**
     * Records {@code totalCost} and decides whether to stop.
     *
     * @return true when the cost is zero, below the threshold, or the iteration limit is reached
     * @throws RuntimeException if the cost increased compared with the previous iteration
     */
    private boolean checkCost(BSPPeer<VectorWritable, DoubleWritable, VectorWritable, DoubleWritable, VectorWritable> bSPPeer, int iteration, double totalCost) {
        if (iteration > 0 && this.cost < totalCost) {
            throw new RuntimeException("gradient descent failed to converge with alpha " + this.alpha);
        }
        this.cost = totalCost;
        if (totalCost == 0.0d || totalCost < this.costThreshold || iteration >= this.iterationsThreshold) {
            return true;
        }
        if (log.isDebugEnabled()) {
            log.debug("{}: current cost is {}", bSPPeer.getPeerName(), Double.valueOf(this.cost));
        }
        return false;
    }

    /** Sums the regression model's per-item cost over this peer's input partition. */
    private double calculateLocalCost(BSPPeer<VectorWritable, DoubleWritable, VectorWritable, DoubleWritable, VectorWritable> bSPPeer) throws IOException {
        double localCost = 0.0d;
        KeyValuePair<VectorWritable, DoubleWritable> item;
        while ((item = bSPPeer.readNext()) != null) {
            localCost += this.regressionModel.calculateCostForItem(item.getKey().getVector(), item.getValue().get(), this.m, this.theta).doubleValue();
        }
        return localCost;
    }

    /** Sends {@code dArr} as a vector message to every peer except this one. */
    private void broadcastVector(BSPPeer<VectorWritable, DoubleWritable, VectorWritable, DoubleWritable, VectorWritable> bSPPeer, double[] dArr) throws IOException {
        String self = bSPPeer.getPeerName();
        for (String peerName : bSPPeer.getAllPeerNames()) {
            if (!peerName.equals(self)) {
                bSPPeer.send(peerName, new VectorWritable(new DenseDoubleVector(dArr)));
            }
        }
    }

    /**
     * Computes, over this peer's partition, {@code sum((h(theta, x) - y) * x[j])} for every
     * component {@code j} of theta.
     */
    private double[] calculatePartialDerivatives(BSPPeer<VectorWritable, DoubleWritable, VectorWritable, DoubleWritable, VectorWritable> bSPPeer) throws IOException {
        int length = this.theta.getLength();
        double[] derivatives = new double[length];
        KeyValuePair<VectorWritable, DoubleWritable> item;
        while ((item = bSPPeer.readNext()) != null) {
            DoubleVector x = item.getKey().getVector();
            BigDecimal residual = this.regressionModel.applyHypothesis(this.theta, x).subtract(BigDecimal.valueOf(item.getValue().get()));
            for (int j = 0; j < length; j++) {
                derivatives[j] += residual.multiply(BigDecimal.valueOf(x.get(j))).doubleValue();
            }
        }
        return derivatives;
    }

    /** On the master peer, writes the final theta and cost to the output. */
    @Override
    public void cleanup(BSPPeer<VectorWritable, DoubleWritable, VectorWritable, DoubleWritable, VectorWritable> bSPPeer) throws IOException {
        if (!this.master) {
            return;
        }
        bSPPeer.write(new VectorWritable(this.theta), new DoubleWritable(this.cost));
        if (log.isInfoEnabled()) {
            log.info("{}:computation finished with cost {} and theta {}", new Object[]{bSPPeer.getPeerName(), Double.valueOf(this.cost), this.theta});
        }
    }

    /**
     * Initialises theta if it is not set yet. The master creates a vector of the input dimension
     * filled with {@link #INITIAL_THETA_VALUES} and broadcasts it. The other peers receive it after
     * the superstep.
     *
     * @throws IOException if a non-master peer receives no initial theta
     */
    void getInitialTheta(BSPPeer<VectorWritable, DoubleWritable, VectorWritable, DoubleWritable, VectorWritable> bSPPeer) throws IOException, SyncException, InterruptedException {
        if (this.theta != null) {
            return;
        }
        if (this.master) {
            this.theta = new DenseDoubleVector(getXSize(bSPPeer), bSPPeer.getConfiguration().getInt(INITIAL_THETA_VALUES, 1));
            broadcastVector(bSPPeer, this.theta.toArray());
            if (log.isDebugEnabled()) {
                log.debug("{}: sending theta", bSPPeer.getPeerName());
            }
            bSPPeer.sync();
            return;
        }
        if (log.isDebugEnabled()) {
            log.debug("{}: getting theta", bSPPeer.getPeerName());
        }
        bSPPeer.sync();
        VectorWritable received = bSPPeer.getCurrentMessage();
        if (received == null) {
            throw new IOException(bSPPeer.getPeerName() + ": no initial theta received from the master peer");
        }
        this.theta = received.getVector();
    }

    /**
     * Returns the dimension of the first input vector of this peer and rewinds the input.
     *
     * @throws IOException if no vector could be read
     */
    private int getXSize(BSPPeer<VectorWritable, DoubleWritable, VectorWritable, DoubleWritable, VectorWritable> bSPPeer) throws IOException {
        VectorWritable first = new VectorWritable();
        bSPPeer.readNext(first, new DoubleWritable());
        bSPPeer.reopenInput();
        if (first.getVector() == null) {
            throw new IOException("cannot read input vector size");
        }
        return first.getVector().getDimension();
    }
}
