package edu.columbia.tjw.item.fit.curve;

import edu.columbia.tjw.item.ItemCurve;
import edu.columbia.tjw.item.ItemCurveFactory;
import edu.columbia.tjw.item.ItemCurveParams;
import edu.columbia.tjw.item.ItemCurveType;
import edu.columbia.tjw.item.ItemRegressor;
import edu.columbia.tjw.item.ItemSettings;
import edu.columbia.tjw.item.ItemStatus;
import edu.columbia.tjw.item.algo.QuantileStatistics;
import edu.columbia.tjw.item.util.LogUtil;
import java.util.Arrays;
import java.util.Random;
import java.util.logging.Logger;
import org.apache.commons.math3.analysis.MultivariateFunction;
import org.apache.commons.math3.exception.MaxCountExceededException;
import org.apache.commons.math3.exception.NotStrictlyPositiveException;
import org.apache.commons.math3.exception.TooManyEvaluationsException;
import org.apache.commons.math3.optim.InitialGuess;
import org.apache.commons.math3.optim.MaxEval;
import org.apache.commons.math3.optim.MaxIter;
import org.apache.commons.math3.optim.OptimizationData;
import org.apache.commons.math3.optim.PointValuePair;
import org.apache.commons.math3.optim.nonlinear.scalar.GoalType;
import org.apache.commons.math3.optim.nonlinear.scalar.MultiStartMultivariateOptimizer;
import org.apache.commons.math3.optim.nonlinear.scalar.ObjectiveFunction;
import org.apache.commons.math3.optim.nonlinear.scalar.noderiv.PowellOptimizer;
import org.apache.commons.math3.random.RandomVectorGenerator;

/* loaded from: input_file:edu/columbia/tjw/item/fit/curve/RawCurveCalibrator.class */
/**
 * Polishes the parameters of a depth-1 curve entry by fitting it directly against binned
 * {@link QuantileStatistics}, using a multi-start Powell optimizer.
 *
 * <p>The minimized objective is the count-weighted mean squared difference between each
 * bucket's mean response ({@code getMeanY(i)}) and the curve prediction
 * {@code intercept + beta * curve.transform(bucketMean(i))}.
 *
 * @param <S> status type
 * @param <R> regressor type
 * @param <T> curve type
 */
public class RawCurveCalibrator<S extends ItemStatus<S>, R extends ItemRegressor<R>, T extends ItemCurveType<T>> {
    private static final Logger LOG = LogUtil.getLogger(RawCurveCalibrator.class);

    /**
     * Objective function for the polish step. A candidate parameter vector is turned into
     * {@link ItemCurveParams} (via the template params and factory). The function returns the
     * count-weighted mean squared residual between each bucket's mean response and the
     * candidate curve's prediction at that bucket's mean.
     */
    private static final class InnerFunction<S extends ItemStatus<S>, R extends ItemRegressor<R>, T extends ItemCurveType<T>> implements MultivariateFunction {
        private final ItemCurveFactory<R, T> _factory;
        private final QuantileStatistics _dist;
        private final ItemCurveParams<R, T> _params;

        /**
         * @param itemCurveFactory factory used to rebuild params from a point
         * @param quantileStatistics binned statistics the curve is fitted against
         * @param itemCurveParams template params whose structure each candidate point fills in
         */
        public InnerFunction(ItemCurveFactory<R, T> itemCurveFactory, QuantileStatistics quantileStatistics, ItemCurveParams<R, T> itemCurveParams) {
            this._factory = itemCurveFactory;
            this._dist = quantileStatistics;
            this._params = itemCurveParams;
        }

        /**
         * Evaluates the objective at {@code point}.
         *
         * @param point candidate parameter vector
         * @return count-weighted mean squared residual, or {@code NaN} when the statistics hold
         *     a total count below 1
         * @throws IllegalArgumentException if the rebuilt params have entry depth greater than 1
         */
        @Override
        public double value(double[] point) {
            final ItemCurveParams<R, T> candidate = new ItemCurveParams<>(this._params, this._factory, point);
            if (candidate.getEntryDepth() > 1) {
                throw new IllegalArgumentException("Raw calibration only available for entries of depth 1.");
            }
            final ItemCurve<T> curve = candidate.getCurve(0);
            final double totalCount = this._dist.getQuantApprox().getTotalCount();
            if (totalCount < 1.0d) {
                // No data: the objective is undefined.
                return Double.NaN;
            }

            final double intercept = candidate.getIntercept();
            final double beta = candidate.getBeta();
            final int size = this._dist.getSize();
            double weightedSquaredError = 0.0d;
            for (int i = 0; i < size; i++) {
                final double bucketMean = this._dist.getQuantApprox().getBucketMean(i);
                final double prediction = intercept + (beta * curve.transform(bucketMean));
                final double residual = this._dist.getMeanY(i) - prediction;
                weightedSquaredError += this._dist.getCount(i) * residual * residual;
            }
            return weightedSquaredError / totalCount;
        }
    }

    /**
     * Supplies the extra start points for the multi-start optimizer. Each generated
     * coordinate is {@code (u - 0.5) * base[i]}, with {@code u} drawn from
     * {@link Random#nextDouble()}.
     *
     * <p>NOTE(review): this samples around zero (within half of each base coordinate's
     * magnitude), not around the base point itself. Confirm that this is the intended
     * perturbation.
     */
    private static final class VectorGenerator implements RandomVectorGenerator {
        private final Random _rand;
        private final double[] _base;

        /**
         * @param dArr base vector (copied defensively)
         * @param random source of randomness
         */
        public VectorGenerator(double[] dArr, Random random) {
            this._base = dArr.clone();
            this._rand = random;
        }

        @Override
        public double[] nextVector() {
            final double[] vector = this._base.clone();
            for (int i = 0; i < vector.length; i++) {
                vector[i] = (this._rand.nextDouble() - 0.5d) * vector[i];
            }
            return vector;
        }
    }

    /**
     * Attempts to improve {@code itemCurveParams} by minimizing the raw binned squared error
     * with a multi-start Powell optimizer.
     *
     * <p>The polished params are returned only if they strictly improve the objective.
     * Otherwise, or if the optimizer exhausts its budget (e.g.
     * {@link TooManyEvaluationsException}), the input params are returned unchanged. When
     * {@link ItemSettings#getBoundCentrality()} is set, the polished curve's centrality is
     * passed to {@link ItemCurveFactory#boundCentrality} with the first and last bucket means.
     *
     * @param itemCurveFactory factory used to rebuild curves and bound centrality
     * @param itemSettings supplies the number of starts, randomness, and the centrality flag
     * @param quantileStatistics binned statistics to fit against
     * @param itemCurveParams starting params; must have entry depth 1
     * @return the improved params, or {@code itemCurveParams} if no improvement was found
     * @throws IllegalArgumentException if {@code itemCurveParams} has entry depth greater than 1
     */
    public static <S extends ItemStatus<S>, R extends ItemRegressor<R>, T extends ItemCurveType<T>> ItemCurveParams<R, T> polishCurveParameters(ItemCurveFactory<R, T> itemCurveFactory, ItemSettings itemSettings, QuantileStatistics quantileStatistics, ItemCurveParams<R, T> itemCurveParams) {
        if (itemCurveParams.getEntryDepth() > 1) {
            throw new IllegalArgumentException("Only valid on entries of depth 1.");
        }
        final InnerFunction<S, R, T> objective = new InnerFunction<S, R, T>(itemCurveFactory, quantileStatistics, itemCurveParams);
        final double[] startPoint = itemCurveParams.generatePoint();
        final double startValue = objective.value(startPoint);

        // With an undefined starting objective no result can satisfy "bestValue < startValue",
        // so optimizing would only burn evaluations (and could fail). Return early instead.
        if (Double.isNaN(startValue)) {
            LOG.info("Polish skipped, objective is undefined at the starting point.");
            return itemCurveParams;
        }

        final int starts = Math.max(1, itemSettings.getPolishMultiStartPoints());
        final MultiStartMultivariateOptimizer optimizer = new MultiStartMultivariateOptimizer(
                new PowellOptimizer(0.001d, 0.001d), starts, new VectorGenerator(startPoint, itemSettings.getRandom()));

        final PointValuePair best;
        try {
            best = optimizer.optimize(new OptimizationData[]{
                new ObjectiveFunction(objective),
                GoalType.MINIMIZE,
                new InitialGuess(startPoint),
                new MaxIter(starts * 100),
                new MaxEval(starts * 300)});
        } catch (MaxCountExceededException | NotStrictlyPositiveException e) {
            // MaxCountExceededException covers both TooManyEvaluationsException and
            // TooManyIterationsException. When every start fails, the multi-start optimizer
            // re-throws the last start's exception. Once an earlier start has used the whole
            // MaxEval budget, later starts fail building a non-positive MaxEval, which is
            // reported as NotStrictlyPositiveException.
            LOG.info("Polish failed, evaluation/iteration budget exhausted: " + e.getMessage());
            return itemCurveParams;
        }

        final double bestValue = best.getValue().doubleValue();
        final double[] bestPoint = best.getPointRef();
        // The multi-start wrapper never advances its own iteration counter, so report the
        // total number of evaluations used across all starts instead.
        LOG.info("Polish run completed (" + startValue + " -> " + bestValue + ")[evaluations=" + optimizer.getEvaluations() + "]: " + Arrays.toString(startPoint) + " -> " + Arrays.toString(bestPoint));

        if (!(bestValue < startValue)) {
            return itemCurveParams;
        }

        final ItemCurveParams<R, T> polished = new ItemCurveParams<>(itemCurveParams, itemCurveFactory, bestPoint);
        if (!itemSettings.getBoundCentrality()) {
            return polished;
        }
        return boundCentralityToObservedRange(itemCurveFactory, quantileStatistics, polished);
    }

    /**
     * Passes the curve of {@code params} through {@link ItemCurveFactory#boundCentrality}, using
     * the first and last bucket means of {@code dist} as bounds. If the factory returns a
     * different curve, the result is rebuilt with the same intercept, beta and regressor.
     */
    private static <R extends ItemRegressor<R>, T extends ItemCurveType<T>> ItemCurveParams<R, T> boundCentralityToObservedRange(ItemCurveFactory<R, T> factory, QuantileStatistics dist, ItemCurveParams<R, T> params) {
        final double lowerMean = dist.getQuantApprox().getBucketMean(0);
        final double upperMean = dist.getQuantApprox().getBucketMean(dist.getSize() - 1);
        final ItemCurve<T> curve = params.getCurve(0);
        final ItemCurve<T> bounded = factory.boundCentrality(curve, lowerMean, upperMean);
        if (bounded == curve) {
            return params;
        }
        return new ItemCurveParams<>(params.getIntercept(), params.getBeta(), params.getRegressor(0), bounded);
    }
}
