package edu.columbia.tjw.item.util;

import edu.columbia.tjw.item.algo.DoubleVector;
import edu.columbia.tjw.item.algo.VectorTools;
import edu.columbia.tjw.item.fit.calculator.BlockResult;
import org.apache.commons.math3.analysis.function.Multiply;

/* loaded from: input_file:edu/columbia/tjw/item/util/IceTools.class */
/**
 * Static numerical helpers for the ICE sums computed from a {@link BlockResult}: relative
 * cutoffs, smooth J-diagonal weights, the ICE / ICE2 / ICE3 sums, and the per-parameter
 * "extra derivative" vectors.
 *
 * <p>In the {@link BlockResult} overloads, the "I-terms" are the entries of
 * {@code getDerivativeSquared()} and the "J diagonal" is {@code getJDiag()}. I-terms below a
 * cutoff relative to their largest magnitude are skipped. Presumably this is so that
 * round-off-sized entries do not contribute; confirm against the fitting code.
 */
public final class IceTools {
    /** Relative cutoff factor for I-terms: the ulp of 4.0. */
    public static final double EPSILON = Math.ulp(4.0d);

    /** Relative cutoff factor for J-diagonal entries: {@code sqrt(EPSILON)}. */
    public static final double SQRT_EPSILON = Math.sqrt(EPSILON);

    /** Natural log of 3, used by the ICE2 blend fraction. */
    private static final double LOG_3 = Math.log(3.0d);

    /**
     * Computes the smooth weight {@code exp(-cutoff / value)}.
     *
     * <p>The result is 0 for non-positive values. For positive values it rises toward 1 as
     * {@code value} grows relative to {@code cutoff}.
     *
     * @param value the value to weight (a J-diagonal entry at every call site in this class)
     * @param cutoff the scale of the weight; must be strictly positive when {@code value > 0}
     * @return 0 if {@code value <= 0}, otherwise {@code exp(-cutoff / value)}
     * @throws IllegalArgumentException if {@code value > 0} and {@code cutoff <= 0}
     */
    public static double computeWeight(final double value, final double cutoff) {
        if (value <= 0.0d) {
            return 0.0d;
        }
        if (cutoff <= 0.0d) {
            throw new IllegalArgumentException("Invalid Cutoff term.");
        }
        return Math.exp(-(cutoff / value));
    }

    /**
     * Returns the J-diagonal cutoff: the largest absolute entry times {@link #SQRT_EPSILON}.
     *
     * @param jDiag the J diagonal
     * @return a non-negative cutoff; 0 when every entry is 0
     */
    public static double computeJDiagCutoff(final DoubleVector jDiag) {
        return VectorTools.maxAbsElement(jDiag) * SQRT_EPSILON;
    }

    /**
     * Returns {@link #computeJDiagCutoff(DoubleVector)} applied to {@code blockResult.getJDiag()}.
     *
     * @param blockResult the block whose J diagonal is used
     * @return the J-diagonal cutoff
     */
    public static double computeJDiagCutoff(final BlockResult blockResult) {
        return computeJDiagCutoff(blockResult.getJDiag());
    }

    /**
     * Returns the I-term cutoff: the largest absolute entry times {@link #EPSILON}.
     *
     * @param iTerms the I-terms (squared derivatives at the call sites in this class)
     * @return a non-negative cutoff; 0 when every entry is 0
     */
    public static double computeITermCutoff(final DoubleVector iTerms) {
        return VectorTools.maxAbsElement(iTerms) * EPSILON;
    }

    /**
     * Returns {@link #computeITermCutoff(DoubleVector)} applied to
     * {@code blockResult.getDerivativeSquared()}.
     *
     * @param blockResult the block whose squared derivatives are used
     * @return the I-term cutoff
     */
    public static double computeITermCutoff(final BlockResult blockResult) {
        return computeITermCutoff(blockResult.getDerivativeSquared());
    }

    /**
     * Computes the plain ICE sum: {@code sum(iTerms[k] / jDiag[k])} over every {@code k} whose
     * I-term is at least the I-term cutoff.
     *
     * <p>NOTE(review): the J entries are used as-is. A zero or negative J entry for a retained
     * term yields an infinite or negative contribution. Use the ICE2/ICE3 variants when that
     * matters.
     *
     * @param iTerms the I-terms (squared derivatives)
     * @param jDiag the J diagonal; must be at least as long as {@code iTerms}
     * @return the sum, or 0 when every I-term is 0
     */
    public static double computeIceSum(final DoubleVector iTerms, final DoubleVector jDiag) {
        final double cutoff = computeITermCutoff(iTerms);
        if (cutoff == 0.0d) {
            return 0.0d;
        }
        double sum = 0.0d;
        for (int k = 0; k < iTerms.getSize(); k++) {
            final double iTerm = iTerms.getEntry(k);
            if (iTerm >= cutoff) {
                sum += iTerm / jDiag.getEntry(k);
            }
        }
        return sum;
    }

    /**
     * Returns {@link #computeIceSum(DoubleVector, DoubleVector)} for the block's squared
     * derivatives and J diagonal.
     *
     * @param blockResult the block to evaluate
     * @return the plain ICE sum
     */
    public static double computeIceSum(final BlockResult blockResult) {
        return computeIceSum(blockResult.getDerivativeSquared(), blockResult.getJDiag());
    }

    /**
     * Computes the ICE2 sum, in which each denominator blends {@code |J|} with the I-term itself.
     *
     * <p>With blend fraction {@code f = 1 / (ln(3) * size)}, each retained term is
     * {@code iTerm / (|jDiag| * (1 - f) + iTerm * f)}.
     *
     * @param iTerms the I-terms (squared derivatives)
     * @param jDiag the J diagonal; must be at least as long as {@code iTerms}
     * @param size the block size (from {@code BlockResult.getSize()} in the overload)
     * @return the sum, or 0 when every I-term is 0 or {@code size == 0}
     */
    public static double computeIce2Sum(final DoubleVector iTerms, final DoubleVector jDiag,
            final int size) {
        // With size == 0 the blend fraction is infinite and every term evaluates to NaN
        // (Inf - Inf). Treat an empty block as contributing nothing, matching
        // fillIceExtraDerivative and fillIceStableBExtraDerivative.
        if (size == 0) {
            return 0.0d;
        }
        final double cutoff = computeITermCutoff(iTerms);
        if (cutoff == 0.0d) {
            return 0.0d;
        }
        final double blend = ice2BlendFraction(size);
        final double jShare = 1.0d - blend;
        double sum = 0.0d;
        for (int k = 0; k < iTerms.getSize(); k++) {
            final double iTerm = iTerms.getEntry(k);
            if (iTerm >= cutoff) {
                sum += iTerm / ((Math.abs(jDiag.getEntry(k)) * jShare) + (iTerm * blend));
            }
        }
        return sum;
    }

    /**
     * Returns {@link #computeIce2Sum(DoubleVector, DoubleVector, int)} for the block's squared
     * derivatives, J diagonal and size.
     *
     * @param blockResult the block to evaluate
     * @return the ICE2 sum
     */
    public static double computeIce2Sum(final BlockResult blockResult) {
        return computeIce2Sum(blockResult.getDerivativeSquared(), blockResult.getJDiag(),
                blockResult.getSize());
    }

    /**
     * Maps every J-diagonal entry through {@link #computeWeight(double, double)}.
     *
     * <p>The cutoff used is {@link #computeJDiagCutoff(DoubleVector)}. When that cutoff is 0,
     * every entry is 0 and so every weight is 0; no exception is thrown.
     *
     * @param jDiag the J diagonal
     * @return a vector of weights, each in [0, 1]
     */
    public static DoubleVector computeJWeight(final DoubleVector jDiag) {
        final double cutoff = computeJDiagCutoff(jDiag);
        return DoubleVector.apply(entry -> computeWeight(entry, cutoff), jDiag);
    }

    /**
     * Array convenience overload of
     * {@link #computeIce3Sum(DoubleVector, DoubleVector, DoubleVector, boolean, boolean)}.
     *
     * <p>The array is wrapped via {@code DoubleVector.of(derivative, false)}, so it should not
     * be modified while this call runs.
     *
     * @param derivative the derivatives, or their squares when {@code isSquared} is true
     * @param jDiag the J diagonal
     * @param jWeight per-entry weights, e.g. from {@link #computeJWeight(DoubleVector)}
     * @param isSquared whether {@code derivative} already holds squared values
     * @param scaleByWeight whether each term ratio is multiplied by its weight
     * @return the ICE3 sum
     */
    public static double computeIce3Sum(final double[] derivative, final DoubleVector jDiag,
            final DoubleVector jWeight, final boolean isSquared, final boolean scaleByWeight) {
        return computeIce3Sum(DoubleVector.of(derivative, false), jDiag, jWeight, isSquared,
                scaleByWeight);
    }

    /**
     * Computes the ICE3 sum, in which each denominator blends J with the I-term using
     * per-entry weights.
     *
     * <p>Each retained term is {@link #computeTermRatio(double, double, double)}. It is
     * optionally multiplied by the weight when {@code scaleByWeight} is set.
     *
     * @param derivative the derivatives, or their squares when {@code isSquared} is true;
     *     when false, they are squared element-wise here
     * @param jDiag the J diagonal
     * @param jWeight per-entry weights, e.g. from {@link #computeJWeight(DoubleVector)}
     * @param isSquared whether {@code derivative} already holds squared values (I-terms)
     * @param scaleByWeight whether each term ratio is multiplied by its weight
     * @return the sum, or 0 when every I-term is 0
     */
    public static double computeIce3Sum(final DoubleVector derivative, final DoubleVector jDiag,
            final DoubleVector jWeight, final boolean isSquared, final boolean scaleByWeight) {
        final DoubleVector iTerms =
                isSquared ? derivative : DoubleVector.apply(new Multiply(), derivative, derivative);
        final int dimension = iTerms.getSize();
        final double cutoff = computeITermCutoff(iTerms);
        if (cutoff == 0.0d) {
            return 0.0d;
        }
        double sum = 0.0d;
        for (int k = 0; k < dimension; k++) {
            final double iTerm = iTerms.getEntry(k);
            if (iTerm >= cutoff) {
                final double weight = jWeight.getEntry(k);
                double term = computeTermRatio(iTerm, jDiag.getEntry(k), weight);
                if (scaleByWeight) {
                    term *= weight;
                }
                sum += term;
            }
        }
        return sum;
    }

    /**
     * Computes {@code iTerm / (jDiag * weight + iTerm * (1 - weight))}.
     *
     * <p>A weight of 1 gives {@code iTerm / jDiag}, and a weight of 0 gives 1.
     *
     * @param iTerm the I-term (squared derivative)
     * @param jDiag the matching J-diagonal entry
     * @param weight the blend weight, typically from {@link #computeWeight(double, double)}
     * @return the blended ratio
     */
    public static double computeTermRatio(final double iTerm, final double jDiag,
            final double weight) {
        return iTerm / ((jDiag * weight) + (iTerm * (1.0d - weight)));
    }

    /**
     * Returns the ICE3 sum for the block.
     *
     * <p>It uses the block's squared derivatives and J diagonal, with weights from
     * {@link #computeJWeight(DoubleVector)}, and does not scale the ratios by those weights.
     *
     * @param blockResult the block to evaluate
     * @return the ICE3 sum
     */
    public static double computeIce3Sum(final BlockResult blockResult) {
        return computeIce3Sum(blockResult.getDerivativeSquared(), blockResult.getJDiag(),
                computeJWeight(blockResult.getJDiag()), true, false);
    }

    /**
     * Builds the per-parameter extra-derivative vector.
     *
     * <p>It uses the same blend fraction {@code f = 1 / (ln(3) * size)} as
     * {@link #computeIce2Sum(DoubleVector, DoubleVector, int)}. For each {@code k} where both
     * the shift gradient and the I-term reach the I-term cutoff, the entry is
     * {@code shiftGradient / ((max(jDiag, 0) * (1 - f) + iTerm * f) * size)}. All other entries
     * are 0.
     *
     * @param derivativeSquared the I-terms; also determines the result length and the cutoff
     * @param shiftGradient the shift gradient
     * @param jDiag the J diagonal; negative entries are clamped to 0
     * @param size the block size
     * @return the extra-derivative vector; all zeros when every I-term is 0 or {@code size == 0}
     */
    public static DoubleVector fillIceExtraDerivative(final DoubleVector derivativeSquared,
            final DoubleVector shiftGradient, final DoubleVector jDiag, final int size) {
        final int dimension = derivativeSquared.getSize();
        final double cutoff = computeITermCutoff(derivativeSquared);
        if (cutoff == 0.0d || size == 0) {
            return DoubleVector.constantVector(0.0d, dimension);
        }
        final double blend = ice2BlendFraction(size);
        final double jShare = 1.0d - blend;
        final double[] result = new double[dimension];
        for (int k = 0; k < dimension; k++) {
            final double gradient = shiftGradient.getEntry(k);
            if (gradient >= cutoff) {
                final double iTerm = derivativeSquared.getEntry(k);
                if (iTerm >= cutoff) {
                    final double j = Math.max(jDiag.getEntry(k), 0.0d);
                    result[k] = gradient / (((j * jShare) + (iTerm * blend)) * size);
                }
            }
        }
        return DoubleVector.of(result, false);
    }

    /**
     * Returns {@link #fillIceExtraDerivative(DoubleVector, DoubleVector, DoubleVector, int)}.
     *
     * <p>It is applied to the block's squared derivatives, shift gradient, J diagonal and size.
     *
     * @param blockResult the block to evaluate
     * @return the extra-derivative vector
     */
    public static DoubleVector fillIceExtraDerivative(final BlockResult blockResult) {
        return fillIceExtraDerivative(blockResult.getDerivativeSquared(),
                blockResult.getShiftGradient(), blockResult.getJDiag(), blockResult.getSize());
    }

    /**
     * Builds the extra-derivative vector for the weighted ("stable") variant.
     *
     * <p>Each denominator blends the J entry with the D2 entry using
     * {@code w = computeWeight(jDiag, jDiagCutoff)}. For each {@code k} where both the shift
     * gradient and the D2 entry reach the I-term cutoff, the entry is
     * {@code shiftGradient / ((jDiag * w + d2 * (1 - w)) * size)}. All other entries are 0.
     *
     * @param blockResult the block to evaluate
     * @return a vector of length {@code getDerivativeDimension()}; all zeros when either cutoff
     *     is 0 or the block size is 0
     */
    public static DoubleVector fillIceStableBExtraDerivative(final BlockResult blockResult) {
        final int dimension = blockResult.getDerivativeDimension();
        final int size = blockResult.getSize();
        final double iTermCutoff = computeITermCutoff(blockResult);
        final double jDiagCutoff = computeJDiagCutoff(blockResult);
        if (iTermCutoff == 0.0d || jDiagCutoff == 0.0d || size == 0) {
            return DoubleVector.constantVector(0.0d, dimension);
        }
        final double[] result = new double[dimension];
        for (int k = 0; k < dimension; k++) {
            final double gradient = blockResult.getShiftGradientEntry(k);
            if (gradient >= iTermCutoff) {
                final double d2 = blockResult.getD2Entry(k);
                if (d2 >= iTermCutoff) {
                    final double jDiag = blockResult.getJDiagEntry(k);
                    // jDiagCutoff > 0 here, so computeWeight cannot throw.
                    final double weight = computeWeight(jDiag, jDiagCutoff);
                    result[k] = gradient / (((jDiag * weight) + (d2 * (1.0d - weight))) * size);
                }
            }
        }
        return DoubleVector.of(result, false);
    }

    /**
     * Returns the ICE2 blend fraction {@code 1 / (ln(3) * size)}.
     *
     * <p>Callers must ensure {@code size != 0}; otherwise the result is infinite.
     */
    private static double ice2BlendFraction(final int size) {
        return 1.0d / (LOG_3 * size);
    }
}
