package edu.columbia.tjw.item.fit.param;

import edu.columbia.tjw.item.ItemCurveType;
import edu.columbia.tjw.item.ItemModel;
import edu.columbia.tjw.item.ItemParameters;
import edu.columbia.tjw.item.ItemRegressor;
import edu.columbia.tjw.item.ItemSettings;
import edu.columbia.tjw.item.ItemStatus;
import edu.columbia.tjw.item.ParamFilter;
import edu.columbia.tjw.item.fit.ParamFittingGrid;
import edu.columbia.tjw.item.optimize.ConvergenceException;
import edu.columbia.tjw.item.optimize.EvaluationResult;
import edu.columbia.tjw.item.optimize.MultivariateOptimizer;
import edu.columbia.tjw.item.optimize.MultivariatePoint;
import edu.columbia.tjw.item.optimize.OptimizationResult;
import edu.columbia.tjw.item.util.LogUtil;
import java.util.Arrays;
import java.util.Collection;
import java.util.Objects;
import java.util.logging.Logger;

/* loaded from: input_file:edu/columbia/tjw/item/fit/param/ParamFitter.class */
/**
 * Fits the non-frozen beta coefficients of an {@link ItemModel} by numerically optimizing the objective
 * exposed through a {@link LogisticModelFunction}.
 *
 * <p>The fitter holds a current model, which {@link #fitAndUpdate} replaces when a fit is accepted.
 * No synchronization is performed, so instances should not be shared across threads without external
 * locking.</p>
 *
 * @param <S> the status type
 * @param <R> the regressor type
 * @param <T> the curve type
 */
public final class ParamFitter<S extends ItemStatus<S>, R extends ItemRegressor<R>, T extends ItemCurveType<T>> {
    private static final Logger LOG = LogUtil.getLogger(ParamFitter.class);
    private ItemModel<S, R, T> _model;
    private final MultivariateOptimizer _optimizer;
    private final ItemSettings _settings;

    /**
     * Creates a fitter that starts from the given model.
     *
     * @param itemModel    the starting model; must not be null
     * @param itemSettings settings supplying the optimizer block size and passed to generated functions;
     *                     must not be null
     * @throws NullPointerException if either argument is null
     */
    public ParamFitter(ItemModel<S, R, T> itemModel, ItemSettings itemSettings) {
        this._model = Objects.requireNonNull(itemModel, "itemModel");
        this._settings = Objects.requireNonNull(itemSettings, "itemSettings");
        // NOTE(review): 300 / 20 / 0.1 are tuning arguments whose meaning is defined by
        // MultivariateOptimizer's constructor -- confirm before changing them.
        this._optimizer = new MultivariateOptimizer(itemSettings.getBlockSize(), 300, 20, 0.1d);
    }

    /**
     * Builds an objective function over every beta coefficient that is not frozen by
     * {@code itemParameters.betaIsFrozen(...)} under the supplied filters.
     *
     * <p>Coefficients are enumerated by reachable-status index (outer loop) and then entry index
     * (inner loop). The generated function receives the starting values of the selected betas, along with
     * the (status, entry) index pair of each one.</p>
     *
     * @param itemParameters   the parameters whose betas are to be fitted
     * @param paramFittingGrid the data grid the function evaluates against
     * @param collection       filters consulted to decide which betas are frozen
     * @return a function whose free variables are the non-frozen betas
     */
    public LogisticModelFunction<S, R, T> generateFunction(ItemParameters<S, R, T> itemParameters,
            ParamFittingGrid<S, R, T> paramFittingGrid, Collection<ParamFilter<S, R, T>> collection) {
        final S fromStatus = itemParameters.getStatus();
        final int reachableCount = fromStatus.getReachableCount();
        final int entryCount = itemParameters.getEntryCount();
        final int capacity = reachableCount * entryCount;

        final double[] startingBetas = new double[capacity];
        final int[] statusIndices = new int[capacity];
        final int[] entryIndices = new int[capacity];
        int freeCount = 0;

        for (int statusIndex = 0; statusIndex < reachableCount; statusIndex++) {
            final S toStatus = fromStatus.getReachable().get(statusIndex);

            for (int entryIndex = 0; entryIndex < entryCount; entryIndex++) {
                if (itemParameters.betaIsFrozen(toStatus, entryIndex, collection)) {
                    continue;
                }

                startingBetas[freeCount] = itemParameters.getBeta(statusIndex, entryIndex);
                statusIndices[freeCount] = statusIndex;
                entryIndices[freeCount] = entryIndex;
                freeCount++;
            }
        }

        // Trim the arrays to the number of free (non-frozen) coefficients actually found.
        return new LogisticModelFunction<>(Arrays.copyOf(startingBetas, freeCount),
                Arrays.copyOf(statusIndices, freeCount), Arrays.copyOf(entryIndices, freeCount), itemParameters,
                paramFittingGrid, new ItemModel<S, R, T>(itemParameters), this._settings);
    }

    /**
     * Evaluates the objective of {@link #generateFunction} at the current betas of {@code itemParameters},
     * over every row of the grid.
     *
     * @param itemParameters   the parameters to evaluate
     * @param paramFittingGrid the data grid to evaluate against
     * @param collection       filters consulted to decide which betas are frozen
     * @return the mean objective value reported by the evaluation result
     */
    public double computeLogLikelihood(ItemParameters<S, R, T> itemParameters,
            ParamFittingGrid<S, R, T> paramFittingGrid, Collection<ParamFilter<S, R, T>> collection) {
        final LogisticModelFunction<S, R, T> function = generateFunction(itemParameters, paramFittingGrid,
                collection);
        return evaluateMean(function, new MultivariatePoint(function.getBeta()));
    }

    /**
     * Optimizes the non-frozen betas of the current model.
     *
     * <p>The current model is not modified; see {@link #fitAndUpdate} for that.</p>
     *
     * @param paramFittingGrid the data grid to fit against
     * @param collection       filters consulted to decide which betas are frozen
     * @return a model carrying the optimized betas, or {@code null} if the optimized objective is worse
     *         than the starting objective, or is NaN
     * @throws ConvergenceException if the optimizer throws it
     */
    public ItemModel<S, R, T> fit(ParamFittingGrid<S, R, T> paramFittingGrid,
            Collection<ParamFilter<S, R, T>> collection) throws ConvergenceException {
        final ItemParameters<S, R, T> startingParams = this._model.getParams();
        LOG.info("Fitting Coefficients: " + startingParams);

        final LogisticModelFunction<S, R, T> function = generateFunction(startingParams, paramFittingGrid,
                collection);
        final double[] beta = function.getBeta();
        final MultivariatePoint startingPoint = new MultivariatePoint(beta);

        final double startingValue = evaluateMean(function, startingPoint);
        LOG.info("\n\n -->Log Likelihood: " + startingValue);

        final OptimizationResult<MultivariatePoint> result = this._optimizer.optimize(function, startingPoint);

        if (!result.converged()) {
            LOG.info("Exhausted dataset before convergence, moving on.");
        }

        final MultivariatePoint optimum = result.getOptimum();

        for (int i = 0; i < beta.length; i++) {
            beta[i] = optimum.getElement(i);
        }

        final double minValue = result.minValue();
        LOG.info("LL improvement: " + startingValue + " -> " + minValue);

        // A NaN objective compares false against everything, so without this check it would be
        // accepted as an "improvement" and NaN coefficients would be written into the model.
        if (Double.isNaN(minValue)) {
            LOG.warning("Optimizer produced a NaN objective; keeping existing parameters.");
            return null;
        }

        if (minValue > startingValue) {
            return null;
        }

        final ItemModel<S, R, T> updated = this._model.updateParameters(function.generateParams(beta));
        LOG.info("Updated Coefficients: " + updated.getParams());
        return updated;
    }

    /**
     * Runs {@link #fit}. If it returns a non-null model, that model replaces this fitter's current model.
     *
     * @param paramFittingGrid the data grid to fit against
     * @param collection       filters consulted to decide which betas are frozen
     * @return the new model, or {@code null} if the fit was rejected (the current model is then unchanged)
     * @throws ConvergenceException if the optimizer throws it
     */
    public ItemModel<S, R, T> fitAndUpdate(ParamFittingGrid<S, R, T> paramFittingGrid,
            Collection<ParamFilter<S, R, T>> collection) throws ConvergenceException {
        final ItemModel<S, R, T> fitted = fit(paramFittingGrid, collection);

        if (null != fitted) {
            this._model = fitted;
        }

        return fitted;
    }

    /**
     * Evaluates {@code function} at {@code point} over all of its rows and returns the mean result.
     * Shared by {@link #computeLogLikelihood} and {@link #fit}, so that both use identical evaluation.
     */
    private double evaluateMean(LogisticModelFunction<S, R, T> function, MultivariatePoint point) {
        final EvaluationResult result = function.generateResult();
        function.value(point, 0, function.numRows(), result);
        return result.getMean();
    }

    /**
     * Writes {@code values[k]} into {@code betas[statusIndices[k]][entryIndices[k]]}, then returns
     * {@code itemParameters.updateBetas(betas)}.
     *
     * <p>NOTE(review): currently unused within this class. This also mutates the array returned by
     * {@code getBetas()}. If that array is not a defensive copy, {@code itemParameters} itself is
     * modified -- confirm before using.</p>
     */
    private ItemParameters<S, R, T> updateParams(ItemParameters<S, R, T> itemParameters, int[] statusIndices,
            int[] entryIndices, double[] values) {
        final double[][] betas = itemParameters.getBetas();

        for (int i = 0; i < values.length; i++) {
            betas[statusIndices[i]][entryIndices[i]] = values[i];
        }

        return itemParameters.updateBetas(betas);
    }
}
