package edu.columbia.tjw.item.fit.param;

import edu.columbia.tjw.item.ItemCurveType;
import edu.columbia.tjw.item.ItemModel;
import edu.columbia.tjw.item.ItemParameters;
import edu.columbia.tjw.item.ItemRegressor;
import edu.columbia.tjw.item.ItemSettings;
import edu.columbia.tjw.item.ItemStatus;
import edu.columbia.tjw.item.ParamFilter;
import edu.columbia.tjw.item.data.ItemStatusGrid;
import edu.columbia.tjw.item.fit.ParamFittingGrid;
import edu.columbia.tjw.item.optimize.ConvergenceException;
import edu.columbia.tjw.item.optimize.EvaluationResult;
import edu.columbia.tjw.item.optimize.MultivariateOptimizer;
import edu.columbia.tjw.item.optimize.MultivariatePoint;
import edu.columbia.tjw.item.optimize.OptimizationResult;
import edu.columbia.tjw.item.util.LogUtil;
import java.util.Arrays;
import java.util.Collection;
import java.util.logging.Logger;

/* loaded from: input_file:edu/columbia/tjw/item/fit/param/ParamFitter.class */
/**
 * Optimizes the non-frozen beta coefficients of an item model over a fitting grid.
 *
 * <p>On construction the fitter collects every beta that is not frozen by the supplied
 * filters into a {@link LogisticModelFunction}; {@link #fit()} then runs a
 * {@link MultivariateOptimizer} over that function.
 *
 * @param <S> the status type
 * @param <R> the regressor type
 * @param <T> the curve type
 */
public final class ParamFitter<S extends ItemStatus<S>, R extends ItemRegressor<R>, T extends ItemCurveType<T>> {
    private static final Logger LOG = LogUtil.getLogger(ParamFitter.class);
    private final ItemModel<S, R, T> _model;
    private final MultivariateOptimizer _optimizer;
    private final ItemSettings _settings;
    private final ParamFittingGrid<S, R, T> _grid;
    private final Collection<ParamFilter<S, R, T>> _filters;
    private final LogisticModelFunction<S, R, T> _function;

    /**
     * Builds a fitter for the given parameters and data.
     *
     * @param itemParameters the starting parameters; wrapped in a new {@link ItemModel}
     * @param itemStatusGrid the data used to build the {@link ParamFittingGrid}
     * @param itemSettings   supplies the optimizer block size and is passed on to the model function
     * @param collection     filters consulted when deciding whether a beta is frozen
     */
    public ParamFitter(ItemParameters<S, R, T> itemParameters, ItemStatusGrid<S, R> itemStatusGrid, ItemSettings itemSettings, Collection<ParamFilter<S, R, T>> collection) {
        this._model = new ItemModel<>(itemParameters);
        this._grid = new ParamFittingGrid<>(itemParameters, itemStatusGrid);
        this._filters = collection;
        // NOTE(review): the meaning of the literals 300, 20 and 0.1 is defined by
        // MultivariateOptimizer's constructor, which is not visible here.
        this._optimizer = new MultivariateOptimizer(itemSettings.getBlockSize(), 300, 20, 0.1d);
        this._settings = itemSettings;
        // Must be the last assignment: generateFunction() reads _model, _filters, _grid and
        // _settings. As a field initializer it would run before this constructor body and
        // dereference a null _model.
        this._function = generateFunction();
    }

    /**
     * Evaluates the model function over all of its rows at the function's current beta vector.
     *
     * <p>NOTE(review): {@code itemParameters} is not read by this method; the evaluation always
     * uses the function built at construction time. Confirm that this is intended.
     *
     * @param itemParameters currently unused
     * @return the mean reported by the evaluation result
     */
    public double computeLogLikelihood(ItemParameters<S, R, T> itemParameters) {
        double[] beta = this._function.getBeta();
        EvaluationResult result = this._function.generateResult();
        this._function.value(new MultivariatePoint(beta), 0, this._function.numRows(), result);
        return result.getMean();
    }

    /**
     * Runs the optimizer starting from the function's current beta vector.
     *
     * <p>If the optimizer's minimum value is greater than the starting value, the result carries
     * the original parameters unchanged with the starting value reported for both figures;
     * otherwise it carries parameters regenerated from the optimized beta vector.
     *
     * @return the fit result, holding the starting parameters, the resulting parameters, the
     *         two objective values and the row count
     * @throws ConvergenceException declared by this method; presumably propagated from the
     *                              optimizer (its declaration is not visible here)
     */
    public ParamFitResult<S, R, T> fit() throws ConvergenceException {
        double[] beta = this._function.getBeta();
        MultivariatePoint startingPoint = new MultivariatePoint(beta);
        int numRows = this._function.numRows();

        // Evaluate at the starting point so the improvement can be measured afterwards.
        EvaluationResult startingResult = this._function.generateResult();
        this._function.value(startingPoint, 0, numRows, startingResult);
        double startingValue = startingResult.getMean();

        OptimizationResult<MultivariatePoint> optimized = this._optimizer.optimize(this._function, startingPoint);
        MultivariatePoint optimum = optimized.getOptimum();
        // Copy the optimum back into the beta array element by element.
        for (int i = 0; i < beta.length; i++) {
            beta[i] = optimum.getElement(i);
        }

        double minValue = optimized.minValue();
        LOG.info("Fitting coefficients, LL improvement: " + startingValue + " -> " + minValue + "(" + (minValue - startingValue) + ")");
        if (!optimized.converged()) {
            LOG.info("Exhausted dataset before convergence, moving on.");
        }

        ItemParameters<S, R, T> startingParams = this._model.getParams();
        if (minValue > startingValue) {
            // The optimizer ended up worse than where it started: keep the original parameters.
            return new ParamFitResult<>(startingParams, startingParams, startingValue, startingValue, numRows);
        }
        return new ParamFitResult<>(startingParams, this._function.generateParams(beta), minValue, startingValue, numRows);
    }

    /**
     * Builds the function to optimize from every beta that the filters do not freeze.
     *
     * <p>For each reachable status and each parameter entry that is not frozen, records the
     * current beta value together with its (status index, entry index) position. The three
     * arrays are then trimmed to the number of betas actually collected.
     *
     * @return a model function over the unfrozen betas
     */
    private LogisticModelFunction<S, R, T> generateFunction() {
        ItemParameters<S, R, T> params = this._model.getParams();
        S fromStatus = params.getStatus();
        int reachableCount = fromStatus.getReachableCount();
        int entryCount = params.getEntryCount();

        // Upper bound on the number of free betas; trimmed after the scan.
        int maxSize = reachableCount * entryCount;
        double[] betaValues = new double[maxSize];
        int[] statusIndices = new int[maxSize];
        int[] entryIndices = new int[maxSize];
        int count = 0;

        for (int statusIndex = 0; statusIndex < reachableCount; statusIndex++) {
            S toStatus = fromStatus.getReachable().get(statusIndex);
            for (int entryIndex = 0; entryIndex < entryCount; entryIndex++) {
                if (params.betaIsFrozen(toStatus, entryIndex, this._filters)) {
                    continue;
                }
                betaValues[count] = params.getBeta(statusIndex, entryIndex);
                statusIndices[count] = statusIndex;
                entryIndices[count] = entryIndex;
                count++;
            }
        }

        return new LogisticModelFunction<>(
                Arrays.copyOf(betaValues, count),
                Arrays.copyOf(statusIndices, count),
                Arrays.copyOf(entryIndices, count),
                params,
                this._grid,
                new ItemModel<>(params),
                this._settings);
    }
}
