package edu.columbia.tjw.item.fit.param;

import edu.columbia.tjw.item.ItemCurveType;
import edu.columbia.tjw.item.ItemModel;
import edu.columbia.tjw.item.ItemParameters;
import edu.columbia.tjw.item.ItemRegressor;
import edu.columbia.tjw.item.ItemSettings;
import edu.columbia.tjw.item.ItemStatus;
import edu.columbia.tjw.item.ParamFilter;
import edu.columbia.tjw.item.fit.EntropyCalculator;
import edu.columbia.tjw.item.fit.FittingProgressChain;
import edu.columbia.tjw.item.fit.ParamFittingGrid;
import edu.columbia.tjw.item.optimize.ConvergenceException;
import edu.columbia.tjw.item.optimize.MultivariateOptimizer;
import edu.columbia.tjw.item.optimize.MultivariatePoint;
import edu.columbia.tjw.item.optimize.OptimizationResult;
import edu.columbia.tjw.item.util.LogUtil;
import java.util.Arrays;
import java.util.Collection;
import java.util.logging.Logger;

/* loaded from: input_file:edu/columbia/tjw/item/fit/param/ParamFitter.class */
/**
 * Refits the unfrozen beta coefficients of an {@link ItemParameters} using a
 * {@link MultivariateOptimizer} over a {@link LogisticModelFunction}, and records the outcome on a
 * {@link FittingProgressChain}.
 *
 * <p>The most recently used parameters and the function built from them are cached; the cache is
 * keyed on reference identity, so passing a different (even if equal) {@code ItemParameters}
 * instance rebuilds the function.
 *
 * @param <S> status type
 * @param <R> regressor type
 * @param <T> curve type
 */
public final class ParamFitter<S extends ItemStatus<S>, R extends ItemRegressor<R>, T extends ItemCurveType<T>> {
    private static final Logger LOG = LogUtil.getLogger(ParamFitter.class);

    private final MultivariateOptimizer _optimizer;
    private final ItemSettings _settings;
    private final Collection<ParamFilter<S, R, T>> _filters;
    private final EntropyCalculator<S, R, T> _calc;

    // Identity-keyed cache of the last parameters fitted and their model function.
    // Written only from the synchronized fit(...) method below.
    ItemParameters<S, R, T> _cacheParams;
    LogisticModelFunction<S, R, T> _cacheFunction;

    /**
     * Creates a fitter.
     *
     * @param entropyCalculator scores fitted parameters and supplies the grid the fit runs over
     * @param itemSettings settings; its block size configures the optimizer, and it is passed to
     *     every {@link LogisticModelFunction} this fitter builds
     * @param collection filters consulted to decide which betas are frozen (excluded from the fit)
     */
    public ParamFitter(EntropyCalculator<S, R, T> entropyCalculator, ItemSettings itemSettings, Collection<ParamFilter<S, R, T>> collection) {
        this._calc = entropyCalculator;
        this._filters = collection;
        // The remaining constructor arguments are interpreted by MultivariateOptimizer itself.
        this._optimizer = new MultivariateOptimizer(itemSettings.getBlockSize(), 300, 20, 0.1d);
        this._settings = itemSettings;
    }

    /**
     * Refits starting from the chain's current best parameters.
     *
     * @param fittingProgressChain chain supplying the starting point and receiving the result
     * @return the fit result (see {@link #fit(FittingProgressChain, ItemParameters)})
     * @throws ConvergenceException propagated from the underlying optimization
     */
    public ParamFitResult<S, R, T> fit(FittingProgressChain<S, R, T> fittingProgressChain) throws ConvergenceException {
        return fit(fittingProgressChain, fittingProgressChain.getBestParameters());
    }

    /**
     * Optimizes the unfrozen betas of {@code itemParameters} and pushes the outcome onto the chain
     * under the label {@code "ParamFit"}.
     *
     * <p>The optimized betas are accepted only when the optimizer's minimum does not exceed the
     * chain's current log-likelihood value, i.e. smaller values are treated as better (presumably
     * a negative log-likelihood / entropy — TODO confirm). When rejected, the chain's best
     * parameters are reported unchanged as both the starting and ending parameters.
     *
     * @param fittingProgressChain chain whose current log-likelihood is the acceptance threshold
     * @param itemParameters parameters whose betas seed the optimization
     * @return the fit result; its ending params/LL are what was pushed onto the chain
     * @throws ConvergenceException propagated from the underlying optimization
     */
    public synchronized ParamFitResult<S, R, T> fit(FittingProgressChain<S, R, T> fittingProgressChain, ItemParameters<S, R, T> itemParameters) throws ConvergenceException {
        final double startingLL = fittingProgressChain.getLogLikelihood();

        // Identity check on purpose: rebuilding the function is expensive, so reuse it only for
        // the exact same parameters object.
        if (itemParameters != _cacheParams) {
            _cacheParams = itemParameters;
            _cacheFunction = generateFunction(itemParameters);
        }

        final LogisticModelFunction<S, R, T> function = _cacheFunction;
        final double[] startingBeta = function.getBeta();
        final MultivariatePoint startingPoint = new MultivariatePoint(startingBeta);
        final int numRows = function.numRows();

        final OptimizationResult<MultivariatePoint> result = _optimizer.optimize(function, startingPoint);

        // Copy the optimum into a fresh array rather than writing it back into startingBeta:
        // that array may belong to the cached function, and mutating it would leave the cache
        // holding betas that no longer match _cacheParams (even when this fit is rejected).
        final double[] fittedBeta = extractBeta(result.getOptimum(), startingBeta.length);

        final double minValue = result.minValue();
        LOG.info("Fitting coefficients, LL improvement: " + startingLL + " -> " + minValue + "(" + (minValue - startingLL) + ")");

        if (!result.converged()) {
            LOG.info("Exhausted dataset before convergence, moving on.");
        }

        final ParamFitResult<S, R, T> fitResult;

        if (minValue > startingLL) {
            // No improvement: report the chain's best parameters unchanged.
            final ItemParameters<S, R, T> best = fittingProgressChain.getBestParameters();
            fitResult = new ParamFitResult<>(best, best, startingLL, startingLL, numRows);
        } else {
            final ItemParameters<S, R, T> fittedParams = function.generateParams(fittedBeta);
            final double fittedEntropy = _calc.computeEntropy(fittedParams).getEntropy();
            fitResult = new ParamFitResult<>(itemParameters, fittedParams, fittedEntropy, startingLL, numRows);
        }

        fittingProgressChain.pushResults("ParamFit", fitResult.getEndingParams(), fitResult.getEndingLL());
        return fitResult;
    }

    /**
     * Copies the first {@code dimension} coordinates of {@code point} into a new array.
     */
    private static double[] extractBeta(final MultivariatePoint point, final int dimension) {
        final double[] beta = new double[dimension];

        for (int i = 0; i < dimension; i++) {
            beta[i] = point.getElement(i);
        }

        return beta;
    }

    /**
     * Builds a {@link LogisticModelFunction} over every beta of {@code itemParameters} that is not
     * frozen by this fitter's filters.
     *
     * <p>Betas are visited in (reachable-status index, entry index) order; for each unfrozen one,
     * its current value and both indices are recorded in parallel arrays.
     *
     * @param itemParameters parameters to build the function from
     * @return a function whose free variables are the unfrozen betas
     */
    private LogisticModelFunction<S, R, T> generateFunction(ItemParameters<S, R, T> itemParameters) {
        final S fromStatus = itemParameters.getStatus();
        final int reachableCount = fromStatus.getReachableCount();
        final int entryCount = itemParameters.getEntryCount();

        // Sized for the worst case (nothing frozen); trimmed to the used length at the end.
        final int maxSize = reachableCount * entryCount;
        final double[] betaValues = new double[maxSize];
        final int[] statusIndices = new int[maxSize];
        final int[] entryIndices = new int[maxSize];
        int count = 0;

        for (int statusIndex = 0; statusIndex < reachableCount; statusIndex++) {
            final S toStatus = fromStatus.getReachable().get(statusIndex);

            for (int entry = 0; entry < entryCount; entry++) {
                if (itemParameters.betaIsFrozen(toStatus, entry, _filters)) {
                    continue;
                }

                betaValues[count] = itemParameters.getBeta(statusIndex, entry);
                statusIndices[count] = statusIndex;
                entryIndices[count] = entry;
                count++;
            }
        }

        return new LogisticModelFunction<>(Arrays.copyOf(betaValues, count), Arrays.copyOf(statusIndices, count), Arrays.copyOf(entryIndices, count), itemParameters, new ParamFittingGrid(itemParameters, this._calc.getGrid()), new ItemModel(itemParameters), this._settings);
    }
}
