package edu.columbia.tjw.item.fit.calculator;

import edu.columbia.tjw.item.ItemSettings;
import edu.columbia.tjw.item.algo.GKQuantileBreakdown;
import edu.columbia.tjw.item.algo.VarianceCalculator;
import edu.columbia.tjw.item.optimize.OptimizationTarget;
import edu.columbia.tjw.item.util.IceTools;
import edu.columbia.tjw.item.util.MathTools;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.linear.SingularValueDecomposition;

/* loaded from: input_file:edu/columbia/tjw/item/fit/calculator/FitPointAnalyzer.class */
public final class FitPointAnalyzer {
    /** One ulp of 4.0; not referenced anywhere else in this class. */
    private static final double EPSILON = Math.ulp(4.0d);
    /** Default z-score threshold; its only use is to initialise {@link #_minStdDev}. */
    private static final double Z_SCORE_CUTOFF = 5.0d;
    /**
     * Batch size, in blocks, used by the four-argument compare: blocks are computed in steps of this
     * many, and the early-stop test is only considered once at least this many blocks were visited.
     */
    private final int _superBlockSize;
    /** Threshold handed to the four-argument compare by generateComparision and exposed by getSigmaTarget. */
    private final double _minStdDev = Z_SCORE_CUTOFF;
    /** Optimization target that selects the branch taken in the objective and derivative methods. */
    private final OptimizationTarget _target;
    /** Settings object; this class only reads the L2 lambda from it. */
    private final ItemSettings _settings;

    /* JADX INFO: Access modifiers changed from: package-private */
    /* renamed from: edu.columbia.tjw.item.fit.calculator.FitPointAnalyzer$1, reason: invalid class name */
    /* loaded from: input_file:edu/columbia/tjw/item/fit/calculator/FitPointAnalyzer$1.class */
    /*
     * Lookup table translating OptimizationTarget ordinals into small case numbers, for code that
     * switches on $SwitchMap...[target.ordinal()]. This appears to be the compiler-generated switch
     * map as recovered by the decompiler (note the "synthetic" markers).
     *
     * Case numbers assigned below:
     *   ENTROPY = 1, L2 = 2, ICE_SIMPLE = 3, ICE2 = 4, ICE_STABLE_B = 5, ICE_RAW = 6, ICE = 7, ICE_B = 8.
     * Any constant not listed keeps the int[] default of 0 and therefore matches no numbered case.
     */
    public static /* synthetic */ class AnonymousClass1 {
        // One slot per enum constant, indexed by ordinal(); filled in by the static initializer.
        static final /* synthetic */ int[] $SwitchMap$edu$columbia$tjw$item$optimize$OptimizationTarget = new int[OptimizationTarget.values().length];

        static {
            // Each assignment is guarded separately so that one missing constant does not prevent the
            // remaining entries from being filled in. NoSuchFieldError is deliberately ignored
            // (presumably to tolerate an enum class that lacks the constant at run time).
            try {
                $SwitchMap$edu$columbia$tjw$item$optimize$OptimizationTarget[OptimizationTarget.ENTROPY.ordinal()] = 1;
            } catch (NoSuchFieldError e) {
                // intentionally ignored: the slot keeps its default value 0
            }
            try {
                $SwitchMap$edu$columbia$tjw$item$optimize$OptimizationTarget[OptimizationTarget.L2.ordinal()] = 2;
            } catch (NoSuchFieldError e2) {
                // intentionally ignored: the slot keeps its default value 0
            }
            try {
                $SwitchMap$edu$columbia$tjw$item$optimize$OptimizationTarget[OptimizationTarget.ICE_SIMPLE.ordinal()] = 3;
            } catch (NoSuchFieldError e3) {
                // intentionally ignored: the slot keeps its default value 0
            }
            try {
                $SwitchMap$edu$columbia$tjw$item$optimize$OptimizationTarget[OptimizationTarget.ICE2.ordinal()] = 4;
            } catch (NoSuchFieldError e4) {
                // intentionally ignored: the slot keeps its default value 0
            }
            try {
                $SwitchMap$edu$columbia$tjw$item$optimize$OptimizationTarget[OptimizationTarget.ICE_STABLE_B.ordinal()] = 5;
            } catch (NoSuchFieldError e5) {
                // intentionally ignored: the slot keeps its default value 0
            }
            try {
                $SwitchMap$edu$columbia$tjw$item$optimize$OptimizationTarget[OptimizationTarget.ICE_RAW.ordinal()] = 6;
            } catch (NoSuchFieldError e6) {
                // intentionally ignored: the slot keeps its default value 0
            }
            try {
                $SwitchMap$edu$columbia$tjw$item$optimize$OptimizationTarget[OptimizationTarget.ICE.ordinal()] = 7;
            } catch (NoSuchFieldError e7) {
                // intentionally ignored: the slot keeps its default value 0
            }
            try {
                $SwitchMap$edu$columbia$tjw$item$optimize$OptimizationTarget[OptimizationTarget.ICE_B.ordinal()] = 8;
            } catch (NoSuchFieldError e8) {
                // intentionally ignored: the slot keeps its default value 0
            }
        }
    }

    /* loaded from: input_file:edu/columbia/tjw/item/fit/calculator/FitPointAnalyzer$FitPointComparison.class */
    /**
     * Immutable result of comparing two fit points over a given number of blocks.
     *
     * This is an inner (non-static) class: the objective values are computed through the
     * enclosing analyzer, so they follow that analyzer's optimization target.
     */
    public final class FitPointComparison {
        private final FitPoint _pointA;
        private final FitPoint _pointB;
        private final int _blockCount;
        private final double _aVal;
        private final double _bVal;
        private final double _dev;
        private final double _zScore;
        private final double _relativeError;

        /**
         * Builds the comparison.
         *
         * When {@code i} is below 1 nothing is evaluated and every numeric result is 0.0
         * (the supplied deviation is discarded as well). Otherwise both objectives are computed
         * over {@code i} blocks and:
         * z-score = (aVal - bVal) / d, relative error = |aVal - bVal| / (aVal + bVal).
         * Neither division is guarded, so a zero denominator produces a non-finite value.
         *
         * @param fitPoint  point A
         * @param fitPoint2 point B
         * @param i         number of blocks over which both objectives are evaluated
         * @param d         deviation used as the z-score denominator
         */
        public FitPointComparison(FitPoint fitPoint, FitPoint fitPoint2, int i, double d) {
            _pointA = fitPoint;
            _pointB = fitPoint2;
            _blockCount = i;

            // Start from the "nothing evaluated" defaults and overwrite them only when there is
            // at least one block to look at.
            double valA = 0.0d;
            double valB = 0.0d;
            double deviation = 0.0d;
            double zScore = 0.0d;
            double relativeError = 0.0d;

            if (i >= 1) {
                valA = FitPointAnalyzer.this.computeObjective(fitPoint, i);
                valB = FitPointAnalyzer.this.computeObjective(fitPoint2, i);
                deviation = d;
                zScore = (valA - valB) / deviation;
                relativeError = Math.abs(valA - valB) / (valA + valB);
            }

            _aVal = valA;
            _bVal = valB;
            _dev = deviation;
            _zScore = zScore;
            _relativeError = relativeError;
        }

        /** @return point A as passed to the constructor */
        public FitPoint getPointA() {
            return _pointA;
        }

        /** @return point B as passed to the constructor */
        public FitPoint getPointB() {
            return _pointB;
        }

        /** @return the number of blocks the objectives were evaluated over */
        public int getBlockCount() {
            return _blockCount;
        }

        /** @return the objective value of point A (0.0 when no block was evaluated) */
        public double getValA() {
            return _aVal;
        }

        /** @return the objective value of point B (0.0 when no block was evaluated) */
        public double getValB() {
            return _bVal;
        }

        /** @return the deviation used as z-score denominator (0.0 when no block was evaluated) */
        public double getDev() {
            return _dev;
        }

        /** @return (valA - valB) / dev, or 0.0 when no block was evaluated */
        public double getZScore() {
            return _zScore;
        }

        /** @return |valA - valB| / (valA + valB), or 0.0 when no block was evaluated */
        public double getRelativeError() {
            return _relativeError;
        }
    }

    /**
     * Creates an analyzer for the given optimization target.
     *
     * @param i                  super-block size: how many blocks are computed per step when two
     *                           points are compared
     * @param optimizationTarget target that selects the objective and derivative formulas
     * @param itemSettings       settings from which the L2 lambda is read
     */
    public FitPointAnalyzer(int i, OptimizationTarget optimizationTarget, ItemSettings itemSettings) {
        // Plain field capture; no validation is performed on any argument.
        _settings = itemSettings;
        _target = optimizationTarget;
        _superBlockSize = i;
    }

    /**
     * Compares two fit points and reduces the result to its z-score.
     *
     * @param fitPoint  point A
     * @param fitPoint2 point B
     * @return the z-score of the comparison produced by {@code generateComparision}
     */
    public double compare(FitPoint fitPoint, FitPoint fitPoint2) {
        final FitPointComparison comparison = generateComparision(fitPoint, fitPoint2);
        return comparison.getZScore();
    }

    /**
     * Runs the full comparison of two fit points using this analyzer's default threshold.
     *
     * @param fitPoint  point A
     * @param fitPoint2 point B
     * @return the comparison, with the points reported in the order given here
     */
    public FitPointComparison generateComparision(FitPoint fitPoint, FitPoint fitPoint2) {
        final double threshold = _minStdDev;
        // The arguments are in caller order, so the "swapped" flag starts out false.
        final boolean swapped = false;
        return compare(fitPoint, fitPoint2, threshold, swapped);
    }

    /**
     * @return the threshold that {@code generateComparision} passes to the detailed compare;
     *         it is initialised from a constant (5.0) and never changes
     */
    public double getSigmaTarget() {
        return _minStdDev;
    }

    /**
     * Computes the target-specific term that is added to the raw derivative of a fit point.
     *
     * Per target:
     * ENTROPY returns a zero vector of the point's dimension;
     * L2 returns the point's parameters scaled by 2 * lambda;
     * ICE_SIMPLE and ICE2 return IceTools.fillIceExtraDerivative of the aggregated second-derivative result;
     * ICE_STABLE_B returns IceTools.fillIce3ExtraDerivative of the aggregated second-derivative result;
     * ICE_RAW, ICE and ICE_B return the aggregated scaled gradient divided by the point's size
     * (ICE_B uses getScaledGradient2, the other two getScaledGradient).
     *
     * Side effects: may trigger computation on {@code fitPoint}; for ICE_RAW / ICE / ICE_B the point
     * is cleared first, and {@code fitPoint2} (when non-null) is computed as well.
     *
     * @param fitPoint  point whose adjustment is wanted
     * @param fitPoint2 optional reference point, only used for ICE_RAW / ICE / ICE_B; when
     *                  {@code null}, {@code fitPoint} serves as its own reference
     * @return the adjustment vector
     * @throws UnsupportedOperationException if the target is none of the handled constants
     */
    public double[] getDerivativeAdjustment(FitPoint fitPoint, FitPoint fitPoint2) {
        // Switch on the enum itself. The previous form went through an ordinal lookup table and used
        // an unrelated constant (GKQuantileBreakdown.USE_SIMPLE_BUCKETS) as the label for ENTROPY.
        switch (_target) {
            case ENTROPY:
                return new double[fitPoint.getDimension()];
            case L2: {
                final double lambda = _settings.getL2Lambda();
                final double[] scaledParams = fitPoint.getParameters();
                // NOTE(review): relies on scalarMultiply scaling the array in place and on
                // getParameters() handing out a copy -- confirm against FitPoint / MathTools.
                MathTools.scalarMultiply(2.0d * lambda, scaledParams);
                return scaledParams;
            }
            case ICE_SIMPLE:
            case ICE2:
                fitPoint.computeAll(BlockCalculationType.SECOND_DERIVATIVE);
                return IceTools.fillIceExtraDerivative(fitPoint.getAggregated(BlockCalculationType.SECOND_DERIVATIVE));
            case ICE_STABLE_B:
                fitPoint.computeAll(BlockCalculationType.SECOND_DERIVATIVE);
                return IceTools.fillIce3ExtraDerivative(fitPoint.getAggregated(BlockCalculationType.SECOND_DERIVATIVE));
            case ICE_RAW:
            case ICE:
            case ICE_B: {
                // Discard earlier results, then recompute this point against the reference's aggregate.
                fitPoint.clear();
                final FitPoint reference = (fitPoint2 != null) ? fitPoint2 : fitPoint;
                reference.computeAll(BlockCalculationType.FIRST_DERIVATIVE);
                fitPoint.computeAll(BlockCalculationType.FIRST_DERIVATIVE,
                        reference.getAggregated(BlockCalculationType.FIRST_DERIVATIVE));

                final BlockResult aggregated = fitPoint.getAggregated(BlockCalculationType.FIRST_DERIVATIVE);
                final double[] gradient = (_target == OptimizationTarget.ICE_B)
                        ? aggregated.getScaledGradient2()
                        : aggregated.getScaledGradient();

                // Normalise by the point's size; the size does not change inside the loop.
                final double size = fitPoint.getSize();
                for (int k = 0; k < gradient.length; k++) {
                    gradient[k] /= size;
                }
                return gradient;
            }
            default:
                throw new UnsupportedOperationException("Unknown target type.");
        }
    }

    /**
     * Convenience overload: derivative of the objective at {@code fitPoint} with no separate
     * reference point.
     *
     * @param fitPoint point to differentiate at
     * @return the derivative vector
     */
    public double[] getDerivative(FitPoint fitPoint) {
        final FitPoint noReference = null;
        return getDerivative(fitPoint, noReference);
    }

    /**
     * Computes the derivative of the configured objective at {@code fitPoint}.
     *
     * For ENTROPY this is the aggregated first-derivative result's derivative unchanged. For every
     * other target it is the aggregated derivative plus {@link #getDerivativeAdjustment}:
     * L2 uses the first-derivative aggregate;
     * ICE_SIMPLE, ICE2 and ICE_STABLE_B use the second-derivative aggregate;
     * ICE_RAW, ICE and ICE_B clear the point and recompute its first derivative against the
     * reference point's aggregate before reading the derivative.
     *
     * @param fitPoint  point to differentiate at; may be computed and, for ICE_RAW / ICE / ICE_B, cleared
     * @param fitPoint2 optional reference point for ICE_RAW / ICE / ICE_B; {@code null} means
     *                  {@code fitPoint} is its own reference
     * @return the derivative vector
     * @throws UnsupportedOperationException if the target is none of the handled constants
     */
    public double[] getDerivative(FitPoint fitPoint, FitPoint fitPoint2) {
        // Switch on the enum itself rather than on an ordinal lookup table whose first label was an
        // unrelated constant (GKQuantileBreakdown.USE_SIMPLE_BUCKETS) that merely equals 1.
        switch (_target) {
            case ENTROPY:
                fitPoint.computeAll(BlockCalculationType.FIRST_DERIVATIVE);
                return fitPoint.getAggregated(BlockCalculationType.FIRST_DERIVATIVE).getDerivative();
            case L2: {
                fitPoint.computeAll(BlockCalculationType.FIRST_DERIVATIVE);
                final double[] gradient = fitPoint.getAggregated(BlockCalculationType.FIRST_DERIVATIVE).getDerivative();
                // For L2 the adjustment is 2 * lambda * parameters, one entry per parameter.
                final double[] penalty = getDerivativeAdjustment(fitPoint, fitPoint2);
                return addAdjustment(gradient, penalty, penalty.length);
            }
            case ICE_SIMPLE:
            case ICE2:
            case ICE_STABLE_B: {
                fitPoint.computeAll(BlockCalculationType.SECOND_DERIVATIVE);
                final BlockResult aggregated = fitPoint.getAggregated(BlockCalculationType.SECOND_DERIVATIVE);
                final int dimension = aggregated.getDerivativeDimension();
                final double[] gradient = aggregated.getDerivative();
                final double[] adjustment = getDerivativeAdjustment(fitPoint, fitPoint2);
                return addAdjustment(gradient, adjustment, dimension);
            }
            case ICE_RAW:
            case ICE:
            case ICE_B: {
                fitPoint.clear();
                final FitPoint reference = (fitPoint2 != null) ? fitPoint2 : fitPoint;
                reference.computeAll(BlockCalculationType.FIRST_DERIVATIVE);
                fitPoint.computeAll(BlockCalculationType.FIRST_DERIVATIVE,
                        reference.getAggregated(BlockCalculationType.FIRST_DERIVATIVE));
                final BlockResult aggregated = fitPoint.getAggregated(BlockCalculationType.FIRST_DERIVATIVE);
                final int dimension = aggregated.getDerivativeDimension();
                final double[] gradient = aggregated.getDerivative();
                // NOTE(review): getDerivativeAdjustment clears and recomputes the point a second time
                // for these targets; kept as-is to preserve the original call sequence.
                final double[] adjustment = getDerivativeAdjustment(fitPoint, fitPoint2);
                return addAdjustment(gradient, adjustment, dimension);
            }
            default:
                throw new UnsupportedOperationException("Unknown target type.");
        }
    }

    /** Adds the first {@code count} entries of {@code adjustment} onto {@code gradient} in place and returns {@code gradient}. */
    private static double[] addAdjustment(double[] gradient, double[] adjustment, int count) {
        for (int k = 0; k < count; k++) {
            gradient[k] += adjustment[k];
        }
        return gradient;
    }

    /**
     * Evaluates the configured objective at {@code fitPoint} using its first {@code i} blocks.
     *
     * Per target:
     * ENTROPY returns the aggregated entropy mean;
     * L2 returns entropy mean + lambda * (parameters . parameters);
     * ICE_SIMPLE returns entropy mean + IceTools.computeIceSum / size;
     * ICE2 returns entropy mean + IceTools.computeIce2Sum / size;
     * ICE_STABLE_B, ICE and ICE_B return entropy mean + IceTools.computeIce3Sum / size;
     * ICE_RAW returns entropy mean + trace(inverse(secondDerivative) * fisherInformation) / size,
     * where the inverse comes from a singular value decomposition solver.
     *
     * @param fitPoint point to evaluate; blocks are computed on demand up to {@code i}
     * @param i        block index passed to {@code computeUntil}
     * @return the objective value
     * @throws UnsupportedOperationException if the target is none of the handled constants
     */
    public double computeObjective(FitPoint fitPoint, int i) {
        // Switch on the enum itself rather than on an ordinal lookup table whose first label was an
        // unrelated constant (GKQuantileBreakdown.USE_SIMPLE_BUCKETS) that merely equals 1.
        switch (_target) {
            case ENTROPY:
                fitPoint.computeUntil(i, BlockCalculationType.VALUE);
                return fitPoint.getAggregated(BlockCalculationType.VALUE).getEntropyMean();
            case L2: {
                fitPoint.computeUntil(i, BlockCalculationType.VALUE);
                final double entropy = fitPoint.getAggregated(BlockCalculationType.VALUE).getEntropyMean();
                final double lambda = _settings.getL2Lambda();
                final double[] params = fitPoint.getParameters();
                return entropy + (lambda * MathTools.dot(params, params));
            }
            case ICE_SIMPLE:
            case ICE2:
            case ICE_STABLE_B:
            case ICE:
            case ICE_B: {
                fitPoint.computeUntil(i, BlockCalculationType.FIRST_DERIVATIVE);
                final BlockResult aggregated = fitPoint.getAggregated(BlockCalculationType.FIRST_DERIVATIVE);
                final double entropy = aggregated.getEntropyMean();

                // Only the correction sum differs between these targets.
                final double iceSum;
                if (_target == OptimizationTarget.ICE_SIMPLE) {
                    iceSum = IceTools.computeIceSum(aggregated);
                } else if (_target == OptimizationTarget.ICE2) {
                    iceSum = IceTools.computeIce2Sum(aggregated);
                } else {
                    iceSum = IceTools.computeIce3Sum(aggregated);
                }
                return entropy + (iceSum / fitPoint.getSize());
            }
            case ICE_RAW: {
                fitPoint.computeUntil(i, BlockCalculationType.SECOND_DERIVATIVE);
                final BlockResult aggregated = fitPoint.getAggregated(BlockCalculationType.SECOND_DERIVATIVE);
                final double entropy = aggregated.getEntropyMean();
                // The entropy deviation and gradient were previously fetched here and then discarded;
                // those dead reads have been dropped.
                final RealMatrix inverse =
                        new SingularValueDecomposition(aggregated.getSecondDerivative()).getSolver().getInverse();
                final RealMatrix product = inverse.multiply(aggregated.getFisherInformation());
                // Sum of the diagonal of the product matrix.
                return entropy + (product.getTrace() / fitPoint.getSize());
            }
            default:
                throw new UnsupportedOperationException("Unknown target type.");
        }
    }

    /**
     * Returns the entropy-mean deviation of {@code fitPoint} over its first {@code i} blocks.
     * The value always comes from the VALUE aggregate, independent of the optimization target.
     *
     * @param fitPoint point to evaluate; blocks are computed on demand up to {@code i}
     * @param i        block index passed to {@code computeUntil}
     * @return the aggregated entropy-mean deviation
     */
    public double computeObjectiveStdDev(FitPoint fitPoint, int i) {
        final BlockCalculationType valueType = BlockCalculationType.VALUE;
        fitPoint.computeUntil(i, valueType);
        final BlockResult aggregated = fitPoint.getAggregated(valueType);
        return aggregated.getEntropyMeanDev();
    }

    /**
     * Compares two fit points block by block, computing further blocks only while needed.
     *
     * The point with fewer VALUE blocks already computed is processed as the leading point; if
     * that is not {@code fitPoint}, the call recurses with the arguments swapped and {@code z}
     * inverted. Blocks are computed in batches of the super-block size. Before each new batch,
     * once at least one super-block of indices has been visited, the loop stops early when the
     * running mean of the per-block entropy differences is non-zero and
     * |deviation * sqrt(blockSize)| has reached {@code d}.
     *
     * The deviation stored in the result is the variance calculator's mean deviation; with fewer
     * than three samples, or when that value is exactly zero, the average of the two points'
     * first-block entropy-mean deviations is used instead.
     *
     * The returned comparison always lists the points in the caller's original order. This now
     * also holds for the degenerate "no block evaluated" result, which previously ignored {@code z}.
     *
     * @param fitPoint  first point
     * @param fitPoint2 second point; must have the same block count as {@code fitPoint}
     * @param d         threshold for the early-stop test described above
     * @param z         {@code true} when the two points arrive swapped relative to the original
     *                  caller's order; external callers normally pass {@code false}
     * @return the comparison of the two points
     * @throws IllegalArgumentException if the block counts differ or corresponding blocks do not
     *                                  cover the same rows
     */
    public FitPointComparison compare(FitPoint fitPoint, FitPoint fitPoint2, double d, boolean z) {
        if (fitPoint.getBlockCount() != fitPoint2.getBlockCount()) {
            throw new IllegalArgumentException("Incomparable points.");
        }

        final BlockCalculationType valueType = BlockCalculationType.VALUE;
        if (fitPoint.getNextBlock(valueType) > fitPoint2.getNextBlock(valueType)) {
            // Make the less-computed point the leading one and remember that we swapped.
            return compare(fitPoint2, fitPoint, d, !z);
        }

        final VarianceCalculator diffStats = new VarianceCalculator();
        final double sqrtBlockSize = Math.sqrt(fitPoint.getBlockSize());
        final int blockCount = fitPoint.getBlockCount();

        for (int k = 0; k < blockCount; k++) {
            if (k >= fitPoint.getNextBlock(valueType)) {
                // Block k is not available yet: either stop here or compute the next batch.
                final boolean decided = k >= _superBlockSize
                        && diffStats.getMean() != 0.0d
                        && Math.abs(diffStats.getDev() * sqrtBlockSize) >= d;
                if (decided) {
                    break;
                }
                final int batchEnd = Math.min(k + _superBlockSize, blockCount);
                fitPoint.computeUntil(batchEnd, valueType);
                fitPoint2.computeUntil(batchEnd, valueType);
            }

            final BlockResult blockA = fitPoint.getBlock(k, valueType);
            final BlockResult blockB = fitPoint2.getBlock(k, valueType);
            if (blockA.getRowStart() != blockB.getRowStart() || blockA.getRowEnd() != blockB.getRowEnd()) {
                throw new IllegalArgumentException("Misaligned blocks.");
            }
            diffStats.update(blockA.getEntropyMean() - blockB.getEntropyMean());
        }

        if (fitPoint.getNextBlock(valueType) < 1) {
            // Nothing was evaluated. Honour the swap flag here too so point A / point B are reported
            // in the caller's order, exactly as in the normal return below.
            return z
                    ? new FitPointComparison(fitPoint2, fitPoint, 0, 0.0d)
                    : new FitPointComparison(fitPoint, fitPoint2, 0, 0.0d);
        }

        double meanDev = diffStats.getMeanDev();
        if (diffStats.getCount() < 3 || 0.0d == meanDev) {
            // Too few samples (or a degenerate zero): fall back to the first-block deviations.
            meanDev = 0.5d * (fitPoint.getBlock(0, valueType).getEntropyMeanDev()
                    + fitPoint2.getBlock(0, valueType).getEntropyMeanDev());
        }

        final int evaluatedBlocks = Math.max(fitPoint.getNextBlock(valueType), fitPoint2.getNextBlock(valueType));
        return z
                ? new FitPointComparison(fitPoint2, fitPoint, evaluatedBlocks, meanDev)
                : new FitPointComparison(fitPoint, fitPoint2, evaluatedBlocks, meanDev);
    }
}
