package edu.columbia.tjw.item.util;

import edu.columbia.tjw.item.fit.calculator.BlockResult;

/* loaded from: input_file:edu/columbia/tjw/item/util/IceTools.class */
/**
 * Static helpers that compute ICE-style penalty sums, and their per-parameter extra derivative
 * terms, from the diagonal entries exposed by a {@link BlockResult}.
 *
 * <p>Terminology used below: "D2 entries" are the values of {@code BlockResult.getD2Entry(i)},
 * "J-diagonal entries" are the values of {@code BlockResult.getJDiagEntry(i)}, and "shift gradient
 * entries" are the values of {@code BlockResult.getShiftGradientEntry(i)}, for
 * {@code 0 <= i < getDerivativeDimension()}. NOTE(review): D2/J are presumably the diagonals of the
 * gradient outer-product and the second-derivative (information) matrices; confirm against
 * {@code BlockResult}.
 *
 * <p>The class holds no mutable state; every method is a pure function of its arguments.
 */
public final class IceTools {
    /** Relative tolerance used for the D2 ("I-term") cutoff: one ulp of 4.0, i.e. 2^-50. */
    public static final double EPSILON = Math.ulp(4.0d);

    /** Square root of {@link #EPSILON} (2^-25), used for the J-diagonal cutoff. */
    public static final double SQRT_EPSILON = Math.sqrt(EPSILON);

    /**
     * Computes a smooth weight between 0 and 1. The weight approaches 1 when {@code value} is large
     * relative to {@code cutoff}, and approaches 0 as {@code value} falls toward zero.
     *
     * @param value the quantity being weighted; non-positive values always get weight 0
     * @param cutoff the transition scale; must be strictly positive whenever {@code value > 0}
     * @return {@code 0} if {@code value <= 0}, otherwise {@code exp(-cutoff / value)}
     * @throws IllegalArgumentException if {@code value > 0} and {@code cutoff <= 0}
     */
    public static double computeWeight(double value, double cutoff) {
        if (value <= 0.0d) {
            return 0.0d;
        }
        if (cutoff <= 0.0d) {
            throw new IllegalArgumentException("Invalid Cutoff term.");
        }
        return Math.exp(-(cutoff / value));
    }

    /**
     * Returns the cutoff used to regularise J-diagonal entries. This is the largest absolute
     * J-diagonal entry scaled by {@link #SQRT_EPSILON}.
     *
     * @param blockResult the block whose J-diagonal entries are scanned
     * @return {@code max_i |J_ii| * SQRT_EPSILON}, or {@code 0} when the derivative dimension is 0
     */
    public static double computeJDiagCutoff(BlockResult blockResult) {
        final int dimension = blockResult.getDerivativeDimension();
        double maxAbs = 0.0d;
        for (int i = 0; i < dimension; i++) {
            maxAbs = Math.max(maxAbs, Math.abs(blockResult.getJDiagEntry(i)));
        }
        return maxAbs * SQRT_EPSILON;
    }

    /**
     * Returns the cutoff below which D2 entries (and, in the derivative helpers, shift gradient
     * entries) are ignored. This is the largest absolute D2 entry scaled by {@link #EPSILON}.
     *
     * @param blockResult the block whose D2 entries are scanned
     * @return {@code max_i |D2_i| * EPSILON}, or {@code 0} when the derivative dimension is 0
     */
    public static double computeITermCutoff(BlockResult blockResult) {
        final int dimension = blockResult.getDerivativeDimension();
        double maxAbs = 0.0d;
        for (int i = 0; i < dimension; i++) {
            maxAbs = Math.max(maxAbs, Math.abs(blockResult.getD2Entry(i)));
        }
        return maxAbs * EPSILON;
    }

    /**
     * Computes the unregularised ICE sum {@code sum_i D2_i / J_ii}. Only indices whose D2 entry is
     * at least {@link #computeITermCutoff(BlockResult)} contribute.
     *
     * <p>NOTE(review): unlike the other variants, this divides by the raw J-diagonal entry. A zero
     * {@code J_ii} therefore yields an infinite term, and a negative one a negative term. Confirm
     * this is intended.
     *
     * @param blockResult the block to summarise
     * @return the sum, or {@code 0} when the I-term cutoff is zero (e.g. all D2 entries are zero)
     */
    public static double computeIceSum(BlockResult blockResult) {
        final int dimension = blockResult.getDerivativeDimension();
        final double cutoff = computeITermCutoff(blockResult);
        if (cutoff == 0.0d) {
            return 0.0d;
        }
        double sum = 0.0d;
        for (int i = 0; i < dimension; i++) {
            final double d2 = blockResult.getD2Entry(i);
            if (d2 >= cutoff) {
                sum += d2 / blockResult.getJDiagEntry(i);
            }
        }
        return sum;
    }

    /**
     * Computes the ICE2 sum. Each denominator blends {@code |J_ii|} with {@code D2_i} using the
     * mixing weight {@code w = 1 / (ln(3) * size)}:
     * {@code sum_i D2_i / (|J_ii| * (1 - w) + D2_i * w)}. Only indices whose D2 entry is at least
     * the I-term cutoff contribute.
     *
     * @param blockResult the block to summarise
     * @return the sum, or {@code 0} when the I-term cutoff is zero or the block size is zero
     */
    public static double computeIce2Sum(BlockResult blockResult) {
        final int dimension = blockResult.getDerivativeDimension();
        final double cutoff = computeITermCutoff(blockResult);
        final int size = blockResult.getSize();

        // With size == 0 the mixing weight 1 / (ln 3 * 0) is infinite, which would turn every
        // qualifying term into NaN. Treat an empty block as contributing nothing, matching
        // fillIceExtraDerivative.
        if (cutoff == 0.0d || size == 0) {
            return 0.0d;
        }
        final double mix = 1.0d / (Math.log(3.0d) * size);
        final double jScale = 1.0d - mix;

        double sum = 0.0d;
        for (int i = 0; i < dimension; i++) {
            final double d2 = blockResult.getD2Entry(i);
            if (d2 >= cutoff) {
                sum += d2 / ((Math.abs(blockResult.getJDiagEntry(i)) * jScale) + (d2 * mix));
            }
        }
        return sum;
    }

    /**
     * Computes a {@link #computeWeight(double, double)} weight for every element of
     * {@code values}. The shared cutoff is the largest absolute element scaled by
     * {@link #SQRT_EPSILON}. This mirrors {@link #computeJDiagCutoff(BlockResult)}, so the input is
     * presumably a vector of J-diagonal entries.
     *
     * @param values the values to weight; non-positive entries get weight 0
     * @return a new array of the same length holding the per-element weights
     */
    public static double[] computeJWeight(double[] values) {
        // NOTE(review): assumes MathTools.maxAbsElement returns max_i |values[i]|; confirm.
        final double cutoff = MathTools.maxAbsElement(values) * SQRT_EPSILON;
        // Every slot is overwritten below, so a fresh array is enough; no need to clone the input.
        final double[] weights = new double[values.length];
        for (int i = 0; i < values.length; i++) {
            weights[i] = computeWeight(values[i], cutoff);
        }
        return weights;
    }

    /**
     * Array-based ICE3 sum, computed as {@code sum_i D2_i / (J_i * w_i + D2_i * (1 - w_i))}. Only
     * indices whose {@code d2Entries} value is at least {@code max_i |d2Entries[i]| * EPSILON}
     * contribute. The arrays play the roles of the D2 entries, the J-diagonal entries and the
     * weights in {@link #computeIce3Sum(BlockResult)}.
     *
     * @param d2Entries the numerators; also determine the cutoff
     * @param jDiagEntries per-index J-like values; must be indexable at every contributing index
     * @param jWeights per-index mixing weights; must be indexable at every contributing index
     * @return the sum, or {@code 0} when the cutoff is zero
     */
    public static double computeIce3Sum(double[] d2Entries, double[] jDiagEntries, double[] jWeights) {
        final int length = d2Entries.length;
        final double cutoff = MathTools.maxAbsElement(d2Entries) * EPSILON;
        if (cutoff == 0.0d) {
            return 0.0d;
        }
        double sum = 0.0d;
        for (int i = 0; i < length; i++) {
            final double d2 = d2Entries[i];
            if (d2 >= cutoff) {
                final double jDiag = jDiagEntries[i];
                final double weight = jWeights[i];
                sum += d2 / ((jDiag * weight) + (d2 * (1.0d - weight)));
            }
        }
        return sum;
    }

    /**
     * Computes the ICE3 sum {@code sum_i D2_i / (J_ii * w_i + D2_i * (1 - w_i))}, where
     * {@code w_i = computeWeight(J_ii, computeJDiagCutoff(blockResult))}. Only indices whose D2
     * entry is at least the I-term cutoff contribute.
     *
     * <p>When {@code J_ii} is large relative to the J cutoff, {@code w_i -> 1} and the term tends to
     * {@code D2_i / J_ii}. When {@code J_ii <= 0}, {@code w_i = 0} and the term is exactly 1.
     *
     * @param blockResult the block to summarise
     * @return the sum, or {@code 0} when the I-term cutoff is zero
     * @throws IllegalArgumentException propagated from {@link #computeWeight(double, double)} if a
     *     positive J-diagonal entry meets a non-positive J cutoff
     */
    public static double computeIce3Sum(BlockResult blockResult) {
        final int dimension = blockResult.getDerivativeDimension();
        final double cutoff = computeITermCutoff(blockResult);
        if (cutoff == 0.0d) {
            return 0.0d;
        }
        final double jCutoff = computeJDiagCutoff(blockResult);
        double sum = 0.0d;
        for (int i = 0; i < dimension; i++) {
            final double d2 = blockResult.getD2Entry(i);
            if (d2 >= cutoff) {
                final double jDiag = blockResult.getJDiagEntry(i);
                final double weight = computeWeight(jDiag, jCutoff);
                sum += d2 / ((jDiag * weight) + (d2 * (1.0d - weight)));
            }
        }
        return sum;
    }

    /**
     * Builds the per-index extra derivative term for the ICE2-style denominator. The term is
     * {@code shift_i / ((max(J_ii, 0) * (1 - w) + D2_i * w) * size)} with
     * {@code w = 1 / (ln(3) * size)}.
     *
     * <p>An index is filled only when both its shift gradient entry and its D2 entry are at least
     * the I-term cutoff. All other entries stay 0; in particular, negative shift gradient entries
     * are never populated.
     *
     * <p>NOTE(review): this clamps {@code J_ii} with {@code max(J_ii, 0)}, while
     * {@link #computeIce2Sum(BlockResult)} uses {@code |J_ii|}. The two differ for negative
     * entries; confirm which is intended.
     *
     * @param blockResult the block to differentiate
     * @return a new array of length {@code getDerivativeDimension()}; all zeros when the I-term
     *     cutoff is zero or the block size is zero
     */
    public static double[] fillIceExtraDerivative(BlockResult blockResult) {
        final int dimension = blockResult.getDerivativeDimension();
        final int size = blockResult.getSize();
        final double[] result = new double[dimension];
        final double cutoff = computeITermCutoff(blockResult);
        if (cutoff == 0.0d || size == 0) {
            return result;
        }
        final double mix = 1.0d / (Math.log(3.0d) * size);
        final double jScale = 1.0d - mix;

        for (int i = 0; i < dimension; i++) {
            final double shift = blockResult.getShiftGradientEntry(i);
            // Written as ">=" (not "!(<)") so NaN entries are skipped, as before.
            if (shift >= cutoff) {
                final double d2 = blockResult.getD2Entry(i);
                if (d2 >= cutoff) {
                    final double jDiag = Math.max(blockResult.getJDiagEntry(i), 0.0d);
                    result[i] = shift / (((jDiag * jScale) + (d2 * mix)) * size);
                }
            }
        }
        return result;
    }

    /**
     * Builds the per-index extra derivative term for the ICE3-style denominator. The term is
     * {@code shift_i / ((J_ii * w_i + D2_i * (1 - w_i)) * size)} with
     * {@code w_i = computeWeight(J_ii, computeJDiagCutoff(blockResult))}.
     *
     * <p>An index is filled only when both its shift gradient entry and its D2 entry are at least
     * the I-term cutoff; all other entries stay 0.
     *
     * @param blockResult the block to differentiate
     * @return a new array of length {@code getDerivativeDimension()}; all zeros when the I-term
     *     cutoff, the J cutoff, or the block size is zero
     */
    public static double[] fillIce3ExtraDerivative(BlockResult blockResult) {
        final int dimension = blockResult.getDerivativeDimension();
        final int size = blockResult.getSize();
        final double[] result = new double[dimension];
        final double cutoff = computeITermCutoff(blockResult);
        final double jCutoff = computeJDiagCutoff(blockResult);
        // Returning early on a zero J cutoff also ensures computeWeight never sees a zero cutoff.
        if (cutoff == 0.0d || jCutoff == 0.0d || size == 0) {
            return result;
        }
        for (int i = 0; i < dimension; i++) {
            final double shift = blockResult.getShiftGradientEntry(i);
            if (shift >= cutoff) {
                final double d2 = blockResult.getD2Entry(i);
                if (d2 >= cutoff) {
                    final double jDiag = blockResult.getJDiagEntry(i);
                    final double weight = computeWeight(jDiag, jCutoff);
                    result[i] = shift / (((jDiag * weight) + (d2 * (1.0d - weight))) * size);
                }
            }
        }
        return result;
    }
}
