package edu.columbia.tjw.item.fit.curve;

import edu.columbia.tjw.item.ItemCurve;
import edu.columbia.tjw.item.ItemCurveParams;
import edu.columbia.tjw.item.ItemCurveType;
import edu.columbia.tjw.item.ItemParameters;
import edu.columbia.tjw.item.ItemRegressor;
import edu.columbia.tjw.item.ItemSettings;
import edu.columbia.tjw.item.ItemStatus;
import edu.columbia.tjw.item.data.ItemFittingGrid;
import edu.columbia.tjw.item.fit.EntropyCalculator;
import edu.columbia.tjw.item.fit.FitResult;
import edu.columbia.tjw.item.fit.FittingProgressChain;
import edu.columbia.tjw.item.fit.base.BaseFitter;
import edu.columbia.tjw.item.fit.param.ParamFitter;
import edu.columbia.tjw.item.optimize.ConvergenceException;
import edu.columbia.tjw.item.util.LogUtil;
import edu.columbia.tjw.item.util.MathFunctions;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Set;
import java.util.SortedSet;
import java.util.TreeSet;
import java.util.logging.Logger;
import org.apache.commons.math3.util.Pair;

/* loaded from: input_file:edu/columbia/tjw/item/fit/curve/CurveFitter.class */
/**
 * Drives the curve-related steps of fitting an item model: re-calibrating the
 * curves a model already contains, generating one new curve from a set of
 * candidate regressors, and extending a newly added curve with interaction
 * terms.
 *
 * <p>All results are offered to a {@link FittingProgressChain}; this class only
 * reports whether the chain accepted them.
 *
 * @param <S> the status type of the model
 * @param <R> the regressor type of the model
 * @param <T> the curve type of the model
 */
public final class CurveFitter<S extends ItemStatus<S>, R extends ItemRegressor<R>, T extends ItemCurveType<T>> {
    private static final Logger LOG = LogUtil.getLogger(CurveFitter.class);

    /** Source of the random generator, calibrate size, improvement ratio and AIC cutoff used below. */
    private final ItemSettings _settings;

    /** Provides the entropy calculator (and through it the fitting grid). */
    private final BaseFitter<S, R, T> _base;

    /** Re-fits model parameters after a curve or interaction has been added. */
    private final ParamFitter<S, R, T> _paramFitter;

    /** Performs the individual curve calibrations. */
    private final CurveParamsFitter<S, R, T> _fitter;

    /**
     * Creates a curve fitter on top of the given base fitter.
     *
     * @param itemSettings the fitting settings
     * @param baseFitter   the base fitter shared with the helper fitters built here
     * @throws NullPointerException if either argument is null
     */
    public CurveFitter(ItemSettings itemSettings, BaseFitter<S, R, T> baseFitter) {
        if (null == itemSettings) {
            throw new NullPointerException("Settings cannot be null.");
        }
        // Every operation of this class dereferences the base fitter, so fail
        // here with a clear message instead of somewhere deep inside a fit.
        if (null == baseFitter) {
            throw new NullPointerException("Base fitter cannot be null.");
        }
        this._settings = itemSettings;
        this._base = baseFitter;
        this._paramFitter = new ParamFitter<>(this._base);
        this._fitter = new CurveParamsFitter<>(this._settings, this._base);
    }

    /**
     * Re-calibrates, in random order, the entries of the chain's best parameters
     * that have a status restriction.
     *
     * <p>Unless {@code z} is set, the sweep stops once at least
     * {@code getCalibrateSize()} entries have been visited and the accumulated
     * (start - end) log likelihood change, divided by the number of entries
     * visited so far plus one, is below {@code d * getImprovementRatio()}.
     * A NaN {@code d} is not rejected; it disables that early exit because the
     * comparison is then always false.
     *
     * @param d                    the improvement target; must not be negative
     * @param z                    if true, visit every entry regardless of progress
     * @param fittingProgressChain the chain that supplies the current parameters
     *                             and receives the calibration results
     * @return true if the chain's log likelihood after the sweep differs (after
     *         rounding) from its value before the sweep
     * @throws IllegalArgumentException if {@code d} is negative
     * @throws IllegalStateException    if an entry unexpectedly has no status
     *                                  restriction, or if the rounded comparison
     *                                  reports the ending log likelihood below
     *                                  the starting one
     */
    public final boolean calibrateCurves(double d, boolean z, FittingProgressChain<S, R, T> fittingProgressChain) {
        if (d < 0.0d) {
            throw new IllegalArgumentException("Improvement target must be nonnegative.");
        }
        LOG.info("Starting curve calibration sweep.");

        final ItemParameters<S, R, T> initialParams = fittingProgressChain.getBestParameters();
        final int entryCount = initialParams.getEntryCount();
        final List<ItemCurveParams<R, T>> curveEntries = new ArrayList<>();
        for (int i = 0; i < entryCount; i++) {
            if (initialParams.getEntryStatusRestrict(i) != null) {
                curveEntries.add(initialParams.getEntryCurveParams(i));
            }
        }
        Collections.shuffle(curveEntries, this._settings.getRandom());

        final int minEntries = this._settings.getCalibrateSize();
        final double requiredAverage = d * this._settings.getImprovementRatio();
        double totalChange = 0.0d;
        final double sweepStartLl = fittingProgressChain.getLogLikelihood();

        for (int i = 0; i < curveEntries.size(); i++) {
            final double averageChange = totalChange / (i + 1);
            if (!z && i >= minEntries && averageChange < requiredAverage) {
                break;
            }

            // Entries are looked up again on every pass because earlier
            // calibrations may have replaced the chain's best parameters.
            final ItemCurveParams<R, T> entry = curveEntries.get(i);
            final ItemParameters<S, R, T> currentParams = fittingProgressChain.getBestParameters();
            final int entryIndex = currentParams.getEntryIndex(entry);
            if (entryIndex == -1) {
                LOG.warning("Should not be possible, but skipping.");
                continue;
            }

            final S toStatus = currentParams.getEntryStatusRestrict(entryIndex);
            if (null == toStatus) {
                throw new IllegalStateException("Impossible.");
            }

            final double startingLl = fittingProgressChain.getLogLikelihood();
            try {
                calibrateCurve(entryIndex, toStatus, fittingProgressChain);
                final double endingLl = fittingProgressChain.getLogLikelihood();
                final double change = startingLl - endingLl;
                if (MathFunctions.doubleCompareRounded(endingLl, startingLl) < 0) {
                    LOG.warning("Ending LL is worse than starting: " + startingLl + " -> " + endingLl);
                    LOG.info("Curve calibration starting params: " + currentParams);
                    LOG.warning("Starting entry: " + entry);
                    LOG.warning("Ending params: " + fittingProgressChain.getBestParameters());
                    throw new IllegalStateException("Impossible.");
                }
                totalChange += change;
            } catch (ConvergenceException e) {
                // Only this entry is abandoned; the sweep moves on to the next one.
                LOG.info("Trouble converging, done calibrating.");
                LOG.info(e.getMessage());
            }
        }

        final boolean changed = 0 != MathFunctions.doubleCompareRounded(sweepStartLl, fittingProgressChain.getLogLikelihood());
        LOG.info("Finished curve calibration sweep[" + changed + "]: " + fittingProgressChain.getBestParameters());
        return changed;
    }

    /**
     * Picks the best candidate curve over the given regressors (see
     * {@link #findBest}) and pushes it onto the chain. If the chain accepts it
     * and the settings allow interaction curves, interaction terms are attempted
     * as well.
     *
     * @param fittingProgressChain the chain supplying the latest results and
     *                             receiving the new curve
     * @param set                  the regressors to try curves on
     * @return true if a candidate was found and the chain accepted it; the
     *         outcome of the interaction step does not affect the result
     */
    public final boolean generateCurve(FittingProgressChain<S, R, T> fittingProgressChain, Set<R> set) {
        final CurveFitResult<S, R, T> best = findBest(set, fittingProgressChain.getLatestResults());
        if (null == best || !fittingProgressChain.pushResults("CurveGeneration", best)) {
            return false;
        }
        LOG.info("Generated curve[" + best.aicPerParameter() + "][" + best.getStartingLogLikelihood() + " -> " + best.getLogLikelihood() + "][" + best.getToState() + "]: " + best.getCurveParams());
        if (this._settings.getAllowInteractionCurves()) {
            LOG.info("Now calculating interactions.");
            if (!generateInteractions(fittingProgressChain, best)) {
                LOG.info("Interaction terms were not better.");
            }
        }
        LOG.info("New Parameters[" + best.getLogLikelihood() + "]: \n" + best.getModelParams().toString());
        return true;
    }

    /**
     * Returns a copy of {@code source} with one more (regressor, curve) pair
     * appended; intercept and beta are carried over unchanged.
     */
    private ItemCurveParams<R, T> appendToCurveParams(ItemCurveParams<R, T> source, ItemCurve<T> curve, R regressor) {
        final List<ItemCurve<T>> curves = new ArrayList<>(source.getCurves());
        final List<R> regressors = new ArrayList<>(source.getRegressors());
        regressors.add(regressor);
        curves.add(curve);
        return new ItemCurveParams<>(source.getIntercept(), source.getBeta(), regressors, curves);
    }

    /**
     * Collects, in sorted order and without duplicates, every regressor used by
     * an entry that has no status restriction and is not the intercept entry.
     */
    private SortedSet<R> getFlagRegs(ItemParameters<S, R, T> params) {
        final int entryCount = params.getEntryCount();
        final SortedSet<R> regressors = new TreeSet<>();
        for (int entry = 0; entry < entryCount; entry++) {
            if (params.getEntryStatusRestrict(entry) != null || params.getInterceptIndex() == entry) {
                continue;
            }
            final int depth = params.getEntryDepth(entry);
            for (int level = 0; level < depth; level++) {
                regressors.add(params.getEntryRegressor(entry, level));
            }
        }
        return regressors;
    }

    /**
     * Tries extending the curve parameters of {@code current} with one more
     * (regressor, curve) pair and fits the result.
     *
     * @param regressor  the regressor to append
     * @param baseParams the model parameters the extension is applied to
     * @param current    the result whose curve parameters are being extended
     * @param curve      the curve to pair with the regressor; may be null
     * @param toStatus   the target status; when null the extension is added via
     *                   {@code addBeta} and fitted directly
     * @return the fitted result, or null if {@code baseParams} reports the
     *         extended curve as forbidden
     */
    private CurveFitResult<S, R, T> generateSingleInteraction(R regressor, ItemParameters<S, R, T> baseParams, CurveFitResult<S, R, T> current, ItemCurve<T> curve, S toStatus) {
        final ItemCurveParams<R, T> extended = appendToCurveParams(current.getCurveParams(), curve, regressor);
        if (baseParams.curveIsForbidden(toStatus, extended)) {
            return null;
        }

        // A private chain, seeded from the current fit, so that intermediate
        // results are not pushed onto the caller's chain.
        final FittingProgressChain<S, R, T> localChain = new FittingProgressChain<>(this._settings, "SingleInteraction", current.getFitResult().getParams(), getSize(), this._base.getCalc(), true);

        if (null == toStatus) {
            final FitResult<S, R, T> fitted = this._paramFitter.fit(localChain, baseParams.addBeta(extended, null));
            final ItemParameters<S, R, T> fittedParams = fitted.getParams();
            return new CurveFitResult<>(fitted, fittedParams.getEntryCurveParams(fittedParams.getEntryCount() - 1, true), toStatus, getSize());
        }

        final CurveFitResult<S, R, T> calibrated = this._fitter.doCalibration(extended, baseParams, current.getFitResult(), toStatus);
        if (!localChain.pushResults("ParameterExpansion", calibrated)) {
            return calibrated;
        }

        // Follow up with a parameter re-fit; it replaces the calibrated result
        // only when its information criterion diff is below the AIC cutoff.
        final FitResult<S, R, T> refit = this._paramFitter.fit(localChain, calibrated.getModelParams());
        if (refit.getInformationCriterionDiff() >= this._settings.getAicCutoff()) {
            return calibrated;
        }
        final ItemParameters<S, R, T> refitParams = refit.getParams();
        return new CurveFitResult<>(refit, refitParams.getEntryCurveParams(refitParams.getEntryCount() - 1, true), toStatus, getSize());
    }

    /**
     * Builds the list of interaction candidates: every regressor from
     * {@link #getFlagRegs} paired with a null curve and, when {@code toStatus}
     * is non-null, every (regressor, curve) pair of the model whose curve is
     * non-null.
     */
    private List<Pair<R, ItemCurve<T>>> extractRegs(ItemParameters<S, R, T> params, S toStatus) {
        final List<Pair<R, ItemCurve<T>>> candidates = new ArrayList<>();
        for (final R flag : getFlagRegs(params)) {
            candidates.add(new Pair<R, ItemCurve<T>>(flag, null));
        }
        if (null == toStatus) {
            return candidates;
        }
        for (int entry = 0; entry < params.getEntryCount(); entry++) {
            final int depth = params.getEntryDepth(entry);
            for (int level = 0; level < depth; level++) {
                final ItemCurve<T> curve = params.getEntryCurve(entry, level);
                if (null != curve) {
                    candidates.add(new Pair<R, ItemCurve<T>>(params.getEntryRegressor(entry, level), curve));
                }
            }
        }
        return candidates;
    }

    /**
     * Exhaustively tries interaction terms for a freshly generated curve, using
     * the curve's own parameters, target status and AIC-per-parameter value.
     *
     * @param fittingProgressChain the chain receiving accepted interactions
     * @param curveFitResult       the curve result to extend
     * @return true if at least one interaction was accepted by the chain
     */
    public boolean generateInteractions(FittingProgressChain<S, R, T> fittingProgressChain, CurveFitResult<S, R, T> curveFitResult) {
        return generateInteractions(fittingProgressChain, curveFitResult.getCurveParams(), curveFitResult.getToState(), curveFitResult.aicPerParameter(), curveFitResult.getFitResult().getPrev().getEntropy(), true);
    }

    /**
     * Tries, in random order, to extend {@code itemCurveParams} with one more
     * regressor (optionally paired with a curve) and pushes each acceptable
     * extension onto the chain. Each accepted extension becomes the starting
     * point for the following attempts.
     *
     * <p>Unless {@code z} is set, the search stops once at least
     * {@code getCalibrateSize()} attempts have been made while both the chain's
     * AIC diff and its log likelihood remain at or above thresholds that scale
     * with the number of attempts.
     *
     * @param fittingProgressChain the chain supplying the base parameters and
     *                             receiving accepted interactions
     * @param itemCurveParams      the curve parameters to extend
     * @param s                    the target status; may be null
     * @param d                    the bound that a candidate's information
     *                             criterion diff per effective parameter must be
     *                             below to be pushed
     * @param d2                   currently unused
     * @param z                    if true, try every candidate
     * @return true if at least one interaction was accepted by the chain
     */
    public boolean generateInteractions(FittingProgressChain<S, R, T> fittingProgressChain, ItemCurveParams<R, T> itemCurveParams, S s, double d, double d2, boolean z) {
        final ItemParameters<S, R, T> baseParams = fittingProgressChain.getBestParameters();
        final List<Pair<R, ItemCurve<T>>> candidates = extractRegs(baseParams, s);
        Collections.shuffle(candidates, this._settings.getRandom());

        final FitResult<S, R, T> startingFit = fittingProgressChain.getLatestFrame().getFitResults();
        final double startingEntropy = startingFit.getEntropy();
        final double aicStep = this._settings.getImprovementRatio() * fittingProgressChain.getLatestFrame().getAicDiff();
        CurveFitResult<S, R, T> current = new CurveFitResult<>(startingFit, itemCurveParams, s, fittingProgressChain.getRowCount());
        final boolean singlePlainRegressor = itemCurveParams.getEntryDepth() == 1 && itemCurveParams.getCurve(0) == null;

        int attempts = 0;
        boolean accepted = false;
        for (int i = 0; i < candidates.size(); i++) {
            final double currentLl = fittingProgressChain.getLogLikelihood();
            final double aicThreshold = aicStep * (attempts + 1);
            final double currentAicDiff = fittingProgressChain.getLatestFrame().getAicDiff();
            final double llThreshold = startingEntropy - ((startingEntropy * 0.001d) * (attempts + 1));
            if (!z && attempts >= this._settings.getCalibrateSize() && currentAicDiff >= aicThreshold && currentLl > llThreshold) {
                break;
            }

            final Pair<R, ItemCurve<T>> candidate = candidates.get(i);
            final R regressor = candidate.getFirst();
            final ItemCurve<T> curve = candidate.getSecond();
            if (isExcludedInteraction(regressor, curve, s, itemCurveParams, singlePlainRegressor)) {
                continue;
            }

            attempts++;
            final CurveFitResult<S, R, T> attempt = generateSingleInteraction(regressor, baseParams, current, curve, s);
            if (null != attempt && attempt.getFitResult().getInformationCriterionDiff() / attempt.getEffectiveParamCount() < d && fittingProgressChain.pushResults("CurveInteractions", attempt.getFitResult())) {
                accepted = true;
                current = attempt;
            }
        }
        return accepted;
    }

    /**
     * Decides whether a candidate (regressor, curve) pair must be skipped when
     * extending {@code base}.
     *
     * <p>A candidate with a curve is skipped only when there is no target status.
     * A candidate without a curve is skipped when its regressor already occurs
     * in {@code base}, or when {@code base} is a single curve-less regressor and
     * the candidate's ordinal is lower than that regressor's (presumably so the
     * same pair of plain regressors is not tried in both orders).
     */
    private boolean isExcludedInteraction(R regressor, ItemCurve<T> curve, S toStatus, ItemCurveParams<R, T> base, boolean singlePlainRegressor) {
        if (null != curve) {
            return null == toStatus;
        }
        if (singlePlainRegressor && regressor.ordinal() < base.getRegressor(0).ordinal()) {
            return true;
        }
        return base.getRegressors().contains(regressor);
    }

    /**
     * Calibrates curve additions for every combination of a status reachable
     * from the model's status (other than that status itself) and a regressor
     * in {@code set}.
     *
     * @param set       the regressors to try
     * @param fitResult the fit the candidates are built on
     * @return all candidate results, in generation order; possibly empty
     */
    public List<CurveFitResult<S, R, T>> generateCandidateResults(Set<R> set, FitResult<S, R, T> fitResult) {
        final List<CurveFitResult<S, R, T>> candidates = new ArrayList<>();
        final S fromStatus = fitResult.getParams().getStatus();
        for (final S toStatus : fromStatus.getReachable()) {
            if (toStatus.equals(fromStatus)) {
                continue;
            }
            for (final R regressor : set) {
                candidates.addAll(this._fitter.calibrateCurveAdditions(regressor, toStatus, fitResult));
            }
        }
        return candidates;
    }

    /**
     * Returns the candidate (see {@link #generateCandidateResults}) with the
     * lowest {@code calculateAicDifference()} value, considering only values
     * strictly below zero.
     *
     * @param set       the regressors to try
     * @param fitResult the fit the candidates are built on
     * @return the best candidate, or null if none has a negative AIC difference
     */
    public CurveFitResult<S, R, T> findBest(Set<R> set, FitResult<S, R, T> fitResult) {
        CurveFitResult<S, R, T> best = null;
        double bestAicDiff = 0.0d;
        for (final CurveFitResult<S, R, T> candidate : generateCandidateResults(set, fitResult)) {
            final double aicDiff = candidate.calculateAicDifference();
            if (aicDiff < bestAicDiff) {
                LOG.info("New Best: " + candidate + " -> " + aicDiff + " vs. " + bestAicDiff);
                bestAicDiff = aicDiff;
                best = candidate;
            }
        }
        return best;
    }

    /** Returns the size of the fitting grid. */
    private int getSize() {
        return getGrid().size();
    }

    /** Returns the fitting grid of the base fitter's calculator. */
    private ItemFittingGrid<S, R> getGrid() {
        return this._base.getCalc().getGrid();
    }

    /**
     * Re-calibrates one existing curve entry and pushes the result onto the chain.
     *
     * @return false if the calibration produced no result; otherwise whatever
     *         the chain reports for the push
     * @throws ConvergenceException if the calibration fails to converge
     */
    private boolean calibrateCurve(int entryIndex, S toStatus, FittingProgressChain<S, R, T> chain) throws ConvergenceException {
        final CurveFitResult<S, R, T> calibrated = this._fitter.calibrateExistingCurve(entryIndex, toStatus, chain.getLatestResults());
        if (null == calibrated) {
            return false;
        }
        return chain.pushResults("CurveCalibrate", calibrated);
    }
}
