package edu.columbia.tjw.item.fit.curve;

import edu.columbia.tjw.item.ItemCurve;
import edu.columbia.tjw.item.ItemCurveFactory;
import edu.columbia.tjw.item.ItemCurveParams;
import edu.columbia.tjw.item.ItemCurveType;
import edu.columbia.tjw.item.ItemModel;
import edu.columbia.tjw.item.ItemParameters;
import edu.columbia.tjw.item.ItemRegressor;
import edu.columbia.tjw.item.ItemSettings;
import edu.columbia.tjw.item.ItemStatus;
import edu.columbia.tjw.item.ParamFilter;
import edu.columbia.tjw.item.data.ItemStatusGrid;
import edu.columbia.tjw.item.fit.ParamFittingGrid;
import edu.columbia.tjw.item.fit.param.ParamFitter;
import edu.columbia.tjw.item.optimize.ConvergenceException;
import edu.columbia.tjw.item.util.EnumFamily;
import edu.columbia.tjw.item.util.LogUtil;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import java.util.SortedSet;
import java.util.TreeSet;
import java.util.logging.Logger;

/**
 * Base class for searching for, generating and recalibrating curve entries of an {@link ItemModel}.
 * <p>
 * Throughout this class a log likelihood value is treated as "lower is better": a fit is only accepted when
 * its value drops below the starting value, and an ending value above the starting value is considered
 * worse. Subclasses supply the current parameters and the actual curve optimization.
 *
 * @param <S> status type
 * @param <R> regressor type
 * @param <T> curve type
 */
public abstract class CurveFitter<S extends ItemStatus<S>, R extends ItemRegressor<R>, T extends ItemCurveType<T>> {
    private static final Logger LOG = LogUtil.getLogger(CurveFitter.class);

    /** Maximum number of interaction terms {@link #generateCurve} will try to append to a new curve. */
    private static final int MAX_INTERACTION_DEPTH = 4;

    private final EnumFamily<T> _family;
    private final ItemSettings _settings;
    private final ItemStatusGrid<S, R> _grid;
    private final ItemCurveFactory<R, T> _factory;

    /**
     * Immutable outcome of fitting one candidate entry: the resulting parameters, the fitted entry, and the
     * log likelihood before and after the fit.
     */
    public static final class FitResult<S extends ItemStatus<S>, R extends ItemRegressor<R>, T extends ItemCurveType<T>> {
        private final double _startingLogL;
        private final double _logL;
        private final double _llImprovement;
        private final int _rowCount;
        private final S _toState;
        private final ItemParameters<S, R, T> _params;
        private final ItemCurveParams<R, T> _curveParams;

        /**
         * @param itemParameters  the full parameter set after the fit
         * @param itemCurveParams the fitted entry's curve parameters
         * @param s               status the entry is restricted to; null for unrestricted entries
         * @param d               log likelihood after the fit
         * @param d2              log likelihood before the fit
         * @param i               number of rows the likelihood was evaluated over
         */
        public FitResult(ItemParameters<S, R, T> itemParameters, ItemCurveParams<R, T> itemCurveParams, S s, double d, double d2, int i) {
            _params = itemParameters;
            _curveParams = itemCurveParams;
            _toState = s;
            _logL = d;
            _startingLogL = d2;
            // Positive when the fit lowered the (lower-is-better) log likelihood value.
            _llImprovement = d2 - d;
            _rowCount = i;
        }

        public S getToState() {
            return _toState;
        }

        public ItemCurveParams<R, T> getCurveParams() {
            return _curveParams;
        }

        /** Builds a fresh model from the fitted parameters. */
        public ItemModel<S, R, T> getModel() {
            return new ItemModel<>(_params);
        }

        public double getLogLikelihood() {
            return _logL;
        }

        public double improvementPerParameter() {
            return _llImprovement / getEffectiveParamCount();
        }

        /** AIC difference divided by the effective parameter count; more negative is better. */
        public double aicPerParameter() {
            return calculateAicDifference() / getEffectiveParamCount();
        }

        /**
         * One less than {@code getCurveParams().size()}.
         * NOTE(review): presumably excludes a parameter that is not counted against the fit — confirm.
         * A value of zero makes the per-parameter metrics divide by zero.
         */
        public int getEffectiveParamCount() {
            return _curveParams.size() - 1;
        }

        /**
         * Returns {@code 2 * (k - improvement * rows)}. Negative values indicate the fit paid for its extra
         * parameters. NOTE(review): assumes the log likelihood values are per-row averages, hence the
         * scaling by the row count — confirm.
         */
        public double calculateAicDifference() {
            return 2.0d * (getEffectiveParamCount() - (_llImprovement * _rowCount));
        }

        @Override
        public String toString() {
            return "Fit result[" + _llImprovement + "]: \n" + _curveParams.toString();
        }
    }

    /**
     * Filter that freezes every beta except the intercept, a chosen target entry, and the entry matching a
     * given set of curve parameters. No curve is ever forbidden.
     * NOTE(review): not instantiated anywhere in this class; candidate for removal.
     */
    private final class ParamFilterImpl implements ParamFilter<S, R, T> {
        private final int _targetEntry;
        private final ItemCurveParams<R, T> _curveParams;

        public ParamFilterImpl(int i, ItemCurveParams<R, T> itemCurveParams) {
            _targetEntry = i;
            _curveParams = itemCurveParams;
        }

        @Override
        public boolean betaIsFrozen(ItemParameters<S, R, T> itemParameters, S s, int i) {
            final boolean isFree = i == itemParameters.getInterceptIndex()
                    || i == _targetEntry
                    || i == itemParameters.getEntryIndex(_curveParams);
            return !isFree;
        }

        @Override
        public boolean curveIsForbidden(ItemParameters<S, R, T> itemParameters, S s, ItemCurveParams<R, T> itemCurveParams) {
            return false;
        }
    }

    /**
     * @param itemCurveFactory factory used to create curves; must not be null
     * @param itemSettings     fitting settings; must not be null
     * @param itemStatusGrid   data the fits are evaluated against
     * @throws NullPointerException if the factory or the settings are null
     */
    public CurveFitter(ItemCurveFactory<R, T> itemCurveFactory, ItemSettings itemSettings, ItemStatusGrid<S, R> itemStatusGrid) {
        _settings = Objects.requireNonNull(itemSettings, "Settings cannot be null.");
        _factory = Objects.requireNonNull(itemCurveFactory, "Factory cannot be null.");
        _family = itemCurveFactory.getFamily();
        _grid = itemStatusGrid;
    }

    /**
     * Evaluates the log likelihood of {@code itemParameters} over {@code itemStatusGrid} using a
     * {@link ParamFitter}.
     */
    public double computeLogLikelihood(ItemParameters<S, R, T> itemParameters, ItemStatusGrid<S, R> itemStatusGrid) {
        final ParamFitter<S, R, T> fitter = new ParamFitter<>(new ItemModel<>(itemParameters), _settings);
        final ParamFittingGrid<S, R, T> fittingGrid = new ParamFittingGrid<>(itemParameters, itemStatusGrid);
        return fitter.computeLogLikelihood(itemParameters, fittingGrid, null);
    }

    /**
     * Recalibrates every status-restricted entry once, in order, via {@link #calibrateCurve}. An entry that
     * fails to converge is skipped.
     *
     * @return the model after the sweep
     * @throws IllegalStateException if an entry can no longer be found, or if a calibration step made the
     *                               log likelihood worse
     */
    public final ItemModel<S, R, T> calibrateCurves() {
        LOG.info("Starting curve calibration sweep.");
        final ItemParameters<S, R, T> initialParams = getParams();
        final int entryCount = initialParams.getEntryCount();
        ItemModel<S, R, T> model = new ItemModel<>(initialParams);

        // Snapshot the restricted entries by value. Each one is re-located in the current model's params on
        // every pass rather than trusting the original index.
        final List<ItemCurveParams<R, T>> restrictedEntries = new ArrayList<>();
        for (int entry = 0; entry < entryCount; entry++) {
            if (initialParams.getEntryStatusRestrict(entry) != null) {
                restrictedEntries.add(initialParams.getEntryCurveParams(entry));
            }
        }

        for (final ItemCurveParams<R, T> entryParams : restrictedEntries) {
            final ItemParameters<S, R, T> before = model.getParams();
            final int entryIndex = before.getEntryIndex(entryParams);
            if (entryIndex == -1) {
                throw new IllegalStateException("Impossible.");
            }
            final S statusRestrict = before.getEntryStatusRestrict(entryIndex);
            if (null == statusRestrict) {
                throw new IllegalStateException("Impossible.");
            }

            final double startLL = computeLogLikelihood(before, _grid);
            try {
                model = calibrateCurve(entryIndex, statusRestrict);
            } catch (ConvergenceException e) {
                // Best effort: keep the current model and move on to the next entry.
                LOG.info("Trouble converging, moving on to next curve.");
                LOG.info(e.getMessage());
            }

            final double endLL = computeLogLikelihood(model.getParams(), _grid);
            if (endLL > startLL) {
                LOG.warning("Ending LL is worse than starting: " + startLL + " -> " + endLL);
                LOG.info("Curve calibration starting params: " + before);
                LOG.warning("Starting entry: " + entryParams);
                LOG.warning("Ending params: " + model.getParams());
                throw new IllegalStateException("Impossible.");
            }
        }

        LOG.info("Finished curve calibration sweep: " + getParams());
        return model;
    }

    /**
     * Finds the best new curve over the given regressors. If settings allow it, the method then greedily
     * appends up to {@link #MAX_INTERACTION_DEPTH} interaction terms while each one lowers the AIC per
     * parameter.
     *
     * @throws ConvergenceException if no curve improves the model, or the best one misses the AIC cutoff
     */
    public final ItemModel<S, R, T> generateCurve(Set<R> set, Collection<ParamFilter<S, R, T>> collection) throws ConvergenceException {
        FitResult<S, R, T> best = findBest(set, collection);
        if (null == best) {
            throw new ConvergenceException("Unable to improve model.");
        }

        final double aicPerParameter = best.aicPerParameter();
        LOG.info("Generated curve[" + aicPerParameter + "][" + best._startingLogL + " -> " + best._logL + "][" + best.getToState() + "]: " + best.getCurveParams());

        if (aicPerParameter > _settings.getAicCutoff()) {
            LOG.info("AIC improvement is not large enough.");
            throw new ConvergenceException("No curves could be added with sufficient AIC improvement: " + aicPerParameter);
        }

        if (_settings.getAllowInteractionCurves()) {
            LOG.info("Now calculating interactions.");
            for (int depth = 0; depth < MAX_INTERACTION_DEPTH; depth++) {
                final FitResult<S, R, T> expanded = generateInteractionTerm(best);
                final double currentAic = best.aicPerParameter();
                final double expandedAic = expanded.aicPerParameter();
                if (expandedAic >= currentAic) {
                    LOG.info("Interaction terms were not better.");
                    break;
                }
                LOG.info("Added interaction term[" + currentAic + " -> " + expandedAic + "]");
                best = expanded;
            }
        }

        final ItemModel<S, R, T> model = best.getModel();
        LOG.info("New Parameters: \n" + model.getParams().toString());
        return model;
    }

    /** Returns a copy of {@code itemCurveParams} with {@code r} and {@code itemCurve} (may be null) appended. */
    private ItemCurveParams<R, T> appendToCurveParams(ItemCurveParams<R, T> itemCurveParams, ItemCurve<T> itemCurve, R r) {
        final List<ItemCurve<T>> curves = new ArrayList<>(itemCurveParams.getCurves());
        final List<R> regressors = new ArrayList<>(itemCurveParams.getRegressors());
        curves.add(itemCurve);
        regressors.add(r);
        return new ItemCurveParams<>(itemCurveParams.getIntercept(), itemCurveParams.getBeta(), regressors, curves);
    }

    /** Collects every regressor used by status-unrestricted, non-intercept entries. */
    private SortedSet<R> getFlagRegs(ItemParameters<S, R, T> itemParameters) {
        final SortedSet<R> flags = new TreeSet<>();
        final int entryCount = itemParameters.getEntryCount();
        for (int entry = 0; entry < entryCount; entry++) {
            if (itemParameters.getEntryStatusRestrict(entry) != null || itemParameters.getInterceptIndex() == entry) {
                continue;
            }
            final int depth = itemParameters.getEntryDepth(entry);
            for (int k = 0; k < depth; k++) {
                flags.add(itemParameters.getEntryRegressor(entry, k));
            }
        }
        return flags;
    }

    /**
     * Tries appending each flag regressor (with a null curve) to every entry that does not already use it.
     * Returns the candidate with the lowest negative AIC per parameter, or null if none beats {@code d}
     * with a negative AIC.
     *
     * @param d starting log likelihood the candidates must improve on
     */
    public FitResult<S, R, T> generateFlagInteraction(double d) {
        final ItemParameters<S, R, T> params = getParams();
        final int entryCount = params.getEntryCount();
        final SortedSet<R> flagRegs = getFlagRegs(params);
        FitResult<S, R, T> best = null;

        for (int entry = 0; entry < entryCount; entry++) {
            final ItemCurveParams<R, T> entryParams = params.getEntryCurveParams(entry, true);
            final S statusRestrict = params.getEntryStatusRestrict(entry);
            final Set<R> present = new TreeSet<>(entryParams.getRegressors());

            for (final R flag : flagRegs) {
                if (present.contains(flag)) {
                    continue;
                }
                final ItemCurveParams<R, T> expanded = appendToCurveParams(entryParams, null, flag);
                FitResult<S, R, T> candidate = null;

                if (null == statusRestrict) {
                    candidate = fitUnrestrictedEntry(params, expanded, d);
                } else {
                    try {
                        candidate = fitEntryExpansion(params, expanded, statusRestrict, false, d);
                    } catch (ConvergenceException e) {
                        LOG.info("Convergence exception, moving on: " + e.toString());
                    }
                }

                if (null != candidate && candidate.aicPerParameter() < 0.0d) {
                    best = pickBetter(best, candidate);
                }
            }
        }
        return best;
    }

    /**
     * Adds {@code expanded} through {@code params.addBeta(expanded, null)}, refits all parameters, and
     * returns a result if the refit log likelihood is below {@code startingLL}. Returns null if the fit does
     * not improve or fails to converge.
     */
    private FitResult<S, R, T> fitUnrestrictedEntry(ItemParameters<S, R, T> params, ItemCurveParams<R, T> expanded, double startingLL) {
        final ItemParameters<S, R, T> withEntry = params.addBeta(expanded, null);
        final ParamFittingGrid<S, R, T> fittingGrid = new ParamFittingGrid<>(withEntry, _grid);
        final ParamFitter<S, R, T> fitter = new ParamFitter<>(new ItemModel<>(withEntry), _settings);

        try {
            final ItemParameters<S, R, T> fitted = fitter.fit(fittingGrid, null).getParams();
            final double fittedLL = fitter.computeLogLikelihood(fitted, fittingGrid, null);
            if (fittedLL < startingLL) {
                // The newly added entry is the last one in the fitted parameters.
                final ItemCurveParams<R, T> added = fitted.getEntryCurveParams(fitted.getEntryCount() - 1, true);
                return new FitResult<>(fitted, added, null, fittedLL, startingLL, fittingGrid.size());
            }
        } catch (ConvergenceException e) {
            LOG.info("Convergence exception, moving on: " + e.toString());
        }
        return null;
    }

    /**
     * Tries appending each regressor and curve pair used by any non-intercept entry, when that regressor is
     * not already in {@code itemCurveParams}. Returns the expansion with the lowest AIC per parameter, or
     * null if no expansion succeeded.
     */
    private FitResult<S, R, T> generateInteractionTerm(ItemCurveParams<R, T> itemCurveParams, S s, double d) {
        final ItemParameters<S, R, T> params = getParams();
        final int entryCount = params.getEntryCount();
        final SortedSet<R> existingRegs = new TreeSet<>();
        final SortedSet<R> triedFlags = new TreeSet<>();

        for (int k = 0; k < itemCurveParams.getEntryDepth(); k++) {
            existingRegs.add(itemCurveParams.getRegressor(k));
        }

        FitResult<S, R, T> best = null;
        for (int entry = 0; entry < entryCount; entry++) {
            if (entry == params.getInterceptIndex()) {
                continue;
            }
            final int depth = params.getEntryDepth(entry);
            for (int k = 0; k < depth; k++) {
                final R reg = params.getEntryRegressor(entry, k);
                if (existingRegs.contains(reg)) {
                    continue;
                }
                final ItemCurve<T> curve = params.getEntryCurve(entry, k);

                // A flag (null curve) yields the same expansion whichever entry it was found in, so each
                // flag only needs one fit attempt.
                if (null == curve && !triedFlags.add(reg)) {
                    continue;
                }

                try {
                    final FitResult<S, R, T> candidate = fitEntryExpansion(params, appendToCurveParams(itemCurveParams, curve, reg), s, false, d);
                    best = pickBetter(best, candidate);
                } catch (ConvergenceException e) {
                    LOG.info("Convergence exception, moving on: " + e.toString());
                }
            }
        }
        return best;
    }

    /**
     * Returns the best single interaction expansion of {@code fitResult} if it lowers the AIC per parameter.
     * Otherwise returns {@code fitResult} itself.
     */
    private FitResult<S, R, T> generateInteractionTerm(FitResult<S, R, T> fitResult) {
        final FitResult<S, R, T> expanded = generateInteractionTerm(fitResult.getCurveParams(), fitResult.getToState(), fitResult.getLogLikelihood());
        return pickBetter(fitResult, expanded);
    }

    /**
     * Returns whichever result has the strictly lower AIC per parameter. Ties and incomparable (NaN)
     * values keep {@code current}. A null argument always loses to a non-null one. Logs when
     * {@code candidate} replaces a non-null {@code current}.
     */
    private FitResult<S, R, T> pickBetter(FitResult<S, R, T> current, FitResult<S, R, T> candidate) {
        if (null == candidate) {
            return current;
        }
        if (null == current) {
            return candidate;
        }
        final double currentAic = current.aicPerParameter();
        final double candidateAic = candidate.aicPerParameter();
        if (candidateAic < currentAic) {
            LOG.info("Found improved result[" + currentAic + " -> " + candidateAic + "]: " + candidate.getCurveParams());
            return candidate;
        }
        return current;
    }

    /**
     * Searches every reachable status, regressor and curve type for the single new curve with the most
     * negative AIC difference. Candidates forbidden by {@code collection} are skipped, as are candidates
     * that fail to converge or reject their arguments.
     *
     * @return the best result, or null if no candidate has a negative AIC difference
     */
    private FitResult<S, R, T> findBest(Set<R> set, Collection<ParamFilter<S, R, T>> collection) {
        final ItemParameters<S, R, T> params = getParams();
        FitResult<S, R, T> best = null;
        double bestAic = 0.0d;

        for (final S toStatus : params.getStatus().getReachable()) {
            for (final R reg : set) {
                for (final T curveType : _family.getMembers()) {
                    try {
                        // Cheap placeholder curve so the filters can veto this combination before fitting.
                        final ItemCurveParams<R, T> probe = new ItemCurveParams<>(0.0d, 0.0d, reg, _factory.generateCurve(curveType, 0, new double[curveType.getParamCount()]));
                        if (params.curveIsForbidden(toStatus, probe, collection)) {
                            continue;
                        }

                        final FitResult<S, R, T> candidate = findBest(curveType, reg, toStatus);
                        if (null == candidate) {
                            continue;
                        }
                        if (params.curveIsForbidden(toStatus, candidate.getCurveParams(), collection)) {
                            LOG.info("Generated curve, but it is forbidden by filters, dropping: " + candidate.getCurveParams());
                            continue;
                        }

                        final double aic = candidate.calculateAicDifference();
                        if (aic < bestAic) {
                            LOG.info("New Best: " + candidate + " -> " + aic + " vs. " + bestAic);
                            bestAic = aic;
                            best = candidate;
                        }
                    } catch (ConvergenceException e) {
                        LOG.info("Trouble converging, moving on to next curve.");
                        LOG.info(e.getMessage());
                    } catch (IllegalArgumentException e2) {
                        LOG.info("Argument trouble (" + reg + "), moving on to next curve.");
                        LOG.info(e2.getMessage());
                    }
                }
            }
        }
        return best;
    }

    /** Recalibrates entry {@code i} (restricted to status {@code s}) and returns the resulting model. */
    protected abstract ItemModel<S, R, T> calibrateCurve(int i, S s) throws ConvergenceException;

    /**
     * Fits {@code itemCurveParams} as an expansion of an entry of {@code itemParameters} for status
     * {@code s}, measured against starting log likelihood {@code d}. Callers in this class treat a null
     * return as "no result".
     */
    public abstract FitResult<S, R, T> fitEntryExpansion(ItemParameters<S, R, T> itemParameters, ItemCurveParams<R, T> itemCurveParams, S s, boolean z, double d) throws ConvergenceException;

    /** Returns the current model parameters that fits start from. */
    protected abstract ItemParameters<S, R, T> getParams();

    /** Fits the best curve of type {@code t} on regressor {@code r} for transitions to {@code s}. */
    protected abstract FitResult<S, R, T> findBest(T t, R r, S s) throws ConvergenceException;
}
