package org.vesalainen.math;

import java.io.Serializable;
import java.util.Objects;
import org.vesalainen.math.matrix.DoubleMatrix;

/**
 * Nonlinear least-squares fitting with a Levenberg-Marquardt style iteration.
 *
 * <p>Given a model {@link Function} {@code f(param, x)}, sample inputs {@code x} and observed
 * outputs {@code y}, {@link #optimize} searches for parameters that lower the cost
 * {@code ||f(param, x) - y||^2 / M}, where {@code M} is the number of samples (rows of
 * {@code x}).
 *
 * <p>Each outer iteration computes the residual {@code r = f(param, x) - y} and the Jacobian
 * {@code J} (one row per parameter, one column per sample). From these it builds the gradient
 * {@code d = J r / M} and the Gauss-Newton matrix {@code H = J J^T / M}. Candidate steps solve
 * {@code (H + lambda I) delta = d}. Lambda is divided by 10 after an improving step and
 * multiplied by 10 after a rejected one.
 *
 * <p>The Jacobian comes from the optional {@link JacobianFactory}. When none is supplied, a
 * forward-difference approximation is used.
 *
 * <p>Working matrices are kept in instance fields and reused between calls. A single instance
 * must therefore not be used by several threads at the same time.
 */
public class LevenbergMarquardt implements Serializable {

    private static final long serialVersionUID = 1L;

    /** Step used for forward-difference derivatives when no {@link JacobianFactory} is given. */
    private static final double DELTA = 1.0E-8d;

    // NOTE: private field names below are part of the serialized form; keep them stable.

    /** Maximum number of outer iterations (each one re-evaluates the Jacobian). */
    private int iter1 = 25;
    /** Number of damped step trials per outer iteration. */
    private int iter2 = 5;
    /** The outer loop stops once an accepted step lowers the cost by no more than this. */
    private double maxDifference = DELTA;
    /** Damping factor lambda used at the start of each {@link #optimize} call. */
    private double initialLambda = 1.0d;

    private final Function func;
    /** Analytic Jacobian source; {@code null} selects forward differences. */
    private final JacobianFactory jacobianFactory;

    /** Current parameter estimate, stored as a column vector. */
    private final DoubleMatrix param = new DoubleMatrix(1, 1);
    private double initialCost;
    private double finalCost;

    // Working storage, resized by configure() and reused across calls.
    /** Gradient {@code J r / M}. */
    private final DoubleMatrix d = new DoubleMatrix(1, 1);
    /** Gauss-Newton matrix {@code J J^T / M}. */
    private final DoubleMatrix H = new DoubleMatrix(1, 1);
    /** Solution of {@code A * negDelta = d}; subtracted from {@code param} to form a candidate. */
    private final DoubleMatrix negDelta = new DoubleMatrix(1, 1);
    /** Candidate parameters under evaluation. */
    private final DoubleMatrix tempParam = new DoubleMatrix(1, 1);
    /** Damped system matrix {@code H + lambda I}. */
    private final DoubleMatrix A = new DoubleMatrix(1, 1);
    /** Model output scratch (one row per sample). */
    private final DoubleMatrix temp0 = new DoubleMatrix(1, 1);
    /** Perturbed model output scratch used by the numerical Jacobian. */
    private final DoubleMatrix temp1 = new DoubleMatrix(1, 1);
    /** Residual {@code f(param, x) - y}. */
    private final DoubleMatrix tempDH = new DoubleMatrix(1, 1);
    /** Jacobian: one row per parameter, one column per sample. */
    private final DoubleMatrix jacobian = new DoubleMatrix(1, 1);

    /** Model evaluated by the optimizer. */
    public interface Function {
        /**
         * Evaluates the model for every sample.
         *
         * @param param parameter values
         * @param x sample inputs, one row per sample
         * @param y output to be filled by the implementation; the optimizer passes a column
         *     vector with one row per sample
         */
        void compute(DoubleMatrix param, DoubleMatrix x, DoubleMatrix y);
    }

    /** Supplies analytic partial derivatives of a {@link Function}. */
    public interface JacobianFactory {
        /**
         * Fills {@code jacobian} with the partial derivatives of the model outputs.
         *
         * @param param parameter values
         * @param x sample inputs, one row per sample
         * @param jacobian output, pre-sized to (number of parameters) x (number of samples);
         *     element (i, j) is the derivative of the output for sample j with respect to
         *     parameter i
         */
        void computeJacobian(DoubleMatrix param, DoubleMatrix x, DoubleMatrix jacobian);
    }

    /**
     * Creates an optimizer that approximates the Jacobian by forward differences.
     *
     * @param function model to fit
     * @throws NullPointerException if {@code function} is null
     */
    public LevenbergMarquardt(Function function) {
        this(function, null);
    }

    /**
     * Creates an optimizer.
     *
     * @param function model to fit
     * @param jacobianFactory analytic Jacobian, or {@code null} to use forward differences
     * @throws NullPointerException if {@code function} is null
     */
    public LevenbergMarquardt(Function function, JacobianFactory jacobianFactory) {
        // Fail here rather than with an NPE deep inside optimize().
        this.func = Objects.requireNonNull(function, "function");
        this.jacobianFactory = jacobianFactory;
    }

    /**
     * Returns the cost of the starting parameters of the last {@link #optimize} call.
     *
     * @return the initial cost, or NaN if the last call had no samples
     */
    public double getInitialCost() {
        return this.initialCost;
    }

    /**
     * Returns the cost reached by the last {@link #optimize} call.
     *
     * @return the final cost, or NaN if the last call returned {@code false}
     */
    public double getFinalCost() {
        return this.finalCost;
    }

    /**
     * Returns the internal parameter matrix. It is the live working copy and is overwritten
     * by later {@link #optimize} calls, so copy it if it must be retained.
     *
     * @return the current parameter estimate as a column vector
     */
    public DoubleMatrix getParameters() {
        return this.param;
    }

    /**
     * Fits the model to the data, starting from {@code initParam}.
     *
     * @param initParam starting parameter values, copied into the internal parameter vector
     * @param x sample inputs, one row per sample
     * @param y observed outputs, a column vector with one row per sample
     * @return {@code true} when the iteration finished; {@code false} when {@code x} has no
     *     rows or a damped linear system could not be solved. In the latter case
     *     {@link #getParameters()} may already hold partially improved values.
     * @throws IllegalArgumentException if {@code y} and {@code x} have different row counts
     *     or {@code y} is not a single column
     */
    public boolean optimize(DoubleMatrix initParam, DoubleMatrix x, DoubleMatrix y) {
        if (x.rows() == 0) {
            // Nothing to fit. Don't leave costs from a previous run visible.
            this.initialCost = Double.NaN;
            this.finalCost = Double.NaN;
            return false;
        }
        configure(initParam, x, y);
        this.initialCost = cost(this.param, x, y);
        if (adjustParam(x, y, this.initialCost)) {
            return true;
        }
        this.finalCost = Double.NaN;
        return false;
    }

    /**
     * Runs the damped iteration on {@code param} in place and records {@link #finalCost}.
     *
     * @return {@code false} if a damped linear system could not be solved
     */
    private boolean adjustParam(DoubleMatrix x, DoubleMatrix y, double prevCost) {
        double lambda = this.initialLambda;
        // Large sentinel so the first outer iteration always runs.
        double difference = 1000.0d;
        for (int iter = 0; iter < this.iter1 && difference > this.maxDifference; iter++) {
            computeDandH(this.param, x, y);
            boolean foundBetter = false;
            // NOTE(review): after a step is accepted, the remaining trials keep using d and H
            // computed at the previous point. Textbook LM would recompute them; kept as-is.
            for (int trial = 0; trial < this.iter2; trial++) {
                computeA(this.A, this.H, lambda);
                if (!DoubleMatrix.solve(this.A, this.d, this.negDelta)) {
                    return false;
                }
                DoubleMatrix.subtract(this.param, this.negDelta, this.tempParam);
                double cost = cost(this.tempParam, x, y);
                if (cost < prevCost) {
                    // Accept the candidate and trust the quadratic model more.
                    foundBetter = true;
                    this.param.set(this.tempParam);
                    difference = prevCost - cost;
                    prevCost = cost;
                    lambda /= 10.0d;
                } else {
                    // Reject the candidate and move towards smaller, gradient-descent-like steps.
                    lambda *= 10.0d;
                }
            }
            if (!foundBetter) {
                // No trial improved the cost, so further outer iterations would not either.
                break;
            }
        }
        this.finalCost = prevCost;
        return true;
    }

    /**
     * Validates the inputs, sizes the working matrices and loads the starting parameters.
     *
     * @param initParam starting parameter values
     * @param x sample inputs
     * @param y observed outputs (single column)
     * @throws IllegalArgumentException if {@code y} and {@code x} have different row counts
     *     or {@code y} is not a single column
     */
    protected void configure(DoubleMatrix initParam, DoubleMatrix x, DoubleMatrix y) {
        if (y.getNumRows() != x.getNumRows()) {
            throw new IllegalArgumentException("Different vector lengths");
        }
        if (y.getNumCols() != 1) {
            throw new IllegalArgumentException("Inputs must be a column vector");
        }
        int numParams = initParam.getNumElements();
        int numSamples = y.getNumRows();
        // Parameter-sized buffers only need reshaping when the parameter count changes.
        if (this.param.getNumElements() != numParams) {
            this.param.reshape(numParams, 1, false);
            this.d.reshape(numParams, 1, false);
            this.H.reshape(numParams, numParams, false);
            this.negDelta.reshape(numParams, 1, false);
            this.tempParam.reshape(numParams, 1, false);
            this.A.reshape(numParams, numParams, false);
        }
        this.param.set(initParam);
        this.temp0.reshape(numSamples, 1, false);
        this.temp1.reshape(numSamples, 1, false);
        this.tempDH.reshape(numSamples, 1, false);
        this.jacobian.reshape(numParams, numSamples, false);
    }

    /** Computes the residual, the Jacobian, the gradient {@code d} and the matrix {@code H}. */
    private void computeDandH(DoubleMatrix point, DoubleMatrix x, DoubleMatrix y) {
        // Residual r = f(point, x) - y.
        this.func.compute(point, x, this.tempDH);
        DoubleMatrix.subtractEquals(this.tempDH, y);
        if (this.jacobianFactory != null) {
            this.jacobianFactory.computeJacobian(point, x, this.jacobian);
        } else {
            computeNumericalJacobian(point, x, this.jacobian);
        }
        int numParams = point.getNumElements();
        int numSamples = y.getNumElements();
        // d = J r / M
        for (int p = 0; p < numParams; p++) {
            double sum = 0.0d;
            for (int s = 0; s < numSamples; s++) {
                sum += this.tempDH.get(s, 0) * this.jacobian.get(p, s);
            }
            this.d.set(p, 0, sum / numSamples);
        }
        // H = J J^T / M
        DoubleMatrix.multTransB(this.jacobian, this.jacobian, this.H);
        DoubleMatrix.scale(1.0d / numSamples, this.H);
    }

    /** Writes {@code hessian + lambda * I} into {@code target}. */
    private void computeA(DoubleMatrix target, DoubleMatrix hessian, double lambda) {
        int numParams = this.param.getNumElements();
        target.set(hessian);
        for (int i = 0; i < numParams; i++) {
            target.set(i, i, target.get(i, i) + lambda);
        }
    }

    /**
     * Returns the mean squared residual {@code ||f(param, x) - y||^2 / x.rows()}.
     *
     * <p>The model output is written into an internal scratch matrix that {@link #configure}
     * sizes. Calling this before {@link #optimize} presumably requires the {@link Function} to
     * cope with that buffer's size (TODO confirm against DoubleMatrix semantics).
     *
     * @param param parameters to evaluate
     * @param x sample inputs
     * @param y observed outputs
     * @return the cost
     */
    public double cost(DoubleMatrix param, DoubleMatrix x, DoubleMatrix y) {
        this.func.compute(param, x, this.temp0);
        double norm = DoubleMatrix.diffNorm(this.temp0, y);
        return (norm * norm) / x.rows();
    }

    /**
     * Approximates the Jacobian by forward differences with step {@code DELTA}.
     *
     * <p>Each parameter is temporarily perturbed and then restored to its exact original value.
     *
     * @param point parameters (column vector); modified during the call but restored
     * @param x sample inputs
     * @param deriv output; element (i, j) receives d(output j)/d(parameter i)
     */
    protected void computeNumericalJacobian(DoubleMatrix point, DoubleMatrix x, DoubleMatrix deriv) {
        final double invDelta = 1.0d / DELTA;
        this.func.compute(point, x, this.temp0);
        for (int i = 0; i < point.rows(); i++) {
            double original = point.get(i, 0);
            point.set(i, 0, original + DELTA);
            this.func.compute(point, x, this.temp1);
            // temp1 = (f(p + delta) - f(p)) / delta
            DoubleMatrix.add(invDelta, this.temp1, -invDelta, this.temp0, this.temp1);
            for (int s = 0; s < this.temp1.rows(); s++) {
                deriv.set(i, s, this.temp1.get(s, 0));
            }
            // Restore exactly: (p + DELTA) - DELTA is not always p in floating point.
            point.set(i, 0, original);
        }
    }

    /**
     * Sets the maximum number of outer iterations (default 25).
     *
     * @param i iteration limit
     */
    public void setIter1(int i) {
        this.iter1 = i;
    }

    /**
     * Sets the number of damped step trials per outer iteration (default 5).
     *
     * @param i trials per iteration
     */
    public void setIter2(int i) {
        this.iter2 = i;
    }

    /**
     * Sets the cost-decrease threshold below which iteration stops (default 1e-8).
     *
     * @param d threshold
     */
    public void setMaxDifference(double d) {
        this.maxDifference = d;
    }

    /**
     * Sets the damping factor used at the start of each optimization (default 1.0).
     *
     * @param lambda initial damping factor
     */
    public void setInitialLambda(double lambda) {
        this.initialLambda = lambda;
    }
}
