package edu.columbia.tjw.item;

import edu.columbia.tjw.item.ItemCurveType;
import edu.columbia.tjw.item.ItemRegressor;
import edu.columbia.tjw.item.ItemStatus;
import edu.columbia.tjw.item.fit.ItemParamGrid;
import edu.columbia.tjw.item.fit.PackedParameters;
import edu.columbia.tjw.item.fit.ParamFittingGrid;
import edu.columbia.tjw.item.util.LogLikelihood;
import edu.columbia.tjw.item.util.LogUtil;
import edu.columbia.tjw.item.util.MultiLogistic;
import java.util.Arrays;
import java.util.logging.Logger;

/* loaded from: input_file:edu/columbia/tjw/item/ItemModel.class */
/**
 * Transition-probability model built from {@link ItemParameters}.
 *
 * <p>For one observation, each entry {@code e} of the parameters receives a weight equal to the
 * product, over the entry's depth, of its raw regressor values. Each value is first passed through
 * the entry's {@link ItemCurve} when one is present. The power score of outcome {@code t} is
 * {@code sum_e betas[t][e] * weight[e]}. The outcome probabilities come from applying
 * {@link MultiLogistic#multiLogisticFunction} to the power scores.
 *
 * <p>Instances reuse internal scratch arrays across calls without any synchronization, so one
 * instance should not be used by several threads at once. {@link #m4clone()} returns a new
 * instance that shares the same parameter objects but has its own scratch arrays.
 *
 * @param <S> status type
 * @param <R> regressor type
 * @param <T> curve type
 */
public final class ItemModel<S extends ItemStatus<S>, R extends ItemRegressor<R>, T extends ItemCurveType<T>> implements Cloneable {
    private static final Logger LOG = LogUtil.getLogger(ItemModel.class);

    /** Numeric tolerance constant; currently not referenced within this class. */
    private static final double ROUNDING_TOLERANCE = 1.0E-8d;

    private final LogLikelihood<S> _likelihood;
    private final ItemParameters<S, R, T> _params;
    private final PackedParameters<S, R, T> _packed;
    /** Betas from {@code _params.getBetas()}, indexed {@code [outcome offset][entry]}. */
    private final double[][] _betas;
    /** Number of reachable outcomes of the status, i.e. the length of a probability vector. */
    private final int _reachableSize;
    /** Scratch: raw regressor values, one per unique regressor. */
    private final double[] _rawRegWorkspace;
    /** Scratch: entry weights, one per entry. */
    private final double[] _regWorkspace;
    /** Scratch: power scores, then (in place) outcome probabilities. */
    private final double[] _probWorkspace;
    /** Scratch: derivative of the relevant power score with respect to each packed parameter. */
    private final double[] _psDerivativeWorkspace;

    /**
     * Creates a model from unpacked parameters.
     *
     * @param itemParameters the model parameters; their packed form is generated here
     */
    public ItemModel(ItemParameters<S, R, T> itemParameters) {
        this(itemParameters, itemParameters.generatePacked());
    }

    /**
     * Creates a model from packed parameters. The packed parameters are copied, so later changes
     * to the argument do not affect this model.
     *
     * @param packedParameters the packed model parameters
     */
    public ItemModel(PackedParameters<S, R, T> packedParameters) {
        this(packedParameters.generateParams(), packedParameters.m20clone());
    }

    private ItemModel(ItemParameters<S, R, T> params, PackedParameters<S, R, T> packed) {
        final S status = params.getStatus();
        _params = params;
        _packed = packed;
        _betas = params.getBetas();
        _reachableSize = status.getReachableCount();
        _likelihood = new LogLikelihood<>(status);
        _rawRegWorkspace = new double[params.getUniqueRegressors().size()];
        _regWorkspace = new double[params.getEntryCount()];
        _probWorkspace = new double[_reachableSize];
        _psDerivativeWorkspace = new double[packed.size()];
    }

    /** @return the status of the underlying parameters */
    public S getStatus() {
        return _params.getStatus();
    }

    /** @return the (unpacked) parameters this model was built from */
    public final ItemParameters<S, R, T> getParams() {
        return _params;
    }

    /** @return the number of packed parameters, i.e. the required gradient length */
    public int getDerivativeSize() {
        return _packed.size();
    }

    /**
     * Computes the log-likelihood contribution of a single observation.
     *
     * @param grid  the data grid
     * @param index row index within {@code grid}
     * @return {@code LogLikelihood.logLikelihood} evaluated at the observed next status, or
     *     {@code 0.0} when that status maps to a negative offset (presumably an outcome this
     *     model cannot produce)
     */
    public final double logLikelihood(ParamFittingGrid<S, R, T> grid, int index) {
        final double[] probabilities = _probWorkspace;
        transitionProbability(grid, index, probabilities);
        final int actualOffset = _likelihood.ordinalToOffset(grid.getNextStatus(index));
        if (actualOffset < 0) {
            return 0.0d;
        }
        return _likelihood.logLikelihood(probabilities, actualOffset);
    }

    /**
     * Computes the weight of one entry: the product over the entry's depth of the raw regressor
     * values. Each value is passed through the curve at that level when a curve is present.
     *
     * @throws IllegalArgumentException if {@code rawRegressors} has the wrong length
     */
    private double computeEntryWeight(double[] rawRegressors, int entry) {
        if (rawRegressors.length != _rawRegWorkspace.length) {
            throw new IllegalArgumentException("Length mismatch.");
        }
        final int depth = _params.getEntryDepth(entry);
        double weight = 1.0d;
        for (int level = 0; level < depth; level++) {
            final double x = rawRegressors[_params.getEntryRegressorOffset(entry, level)];
            final ItemCurve<T> curve = _params.getEntryCurve(entry, level);
            weight *= (null == curve) ? x : curve.transform(x);
        }
        return weight;
    }

    /** Adds {@code weight(entry) * betas[t][entry]} to {@code powerScores[t]} for every outcome t. */
    private void addEntryPowerScores(double[] rawRegressors, int entry, double[] powerScores) {
        if (powerScores.length != _betas.length) {
            throw new IllegalArgumentException("Mismatch.");
        }
        final double weight = computeEntryWeight(rawRegressors, entry);
        for (int t = 0; t < _betas.length; t++) {
            powerScores[t] += weight * _betas[t][entry];
        }
    }

    /** Writes the weight of every entry into {@code entryWeights}. */
    private void fillEntryWeights(double[] rawRegressors, double[] entryWeights) {
        final int entryCount = _params.getEntryCount();
        if (entryWeights.length != entryCount) {
            throw new IllegalArgumentException("Length mismatch.");
        }
        for (int e = 0; e < entryCount; e++) {
            entryWeights[e] = computeEntryWeight(rawRegressors, e);
        }
    }

    /**
     * Computes derivatives of {@code -log p} for one observation, where {@code p} is the model
     * probability of the observed next status. The derivatives are taken with respect to each
     * packed parameter.
     *
     * <p>All outputs are overwritten, not accumulated. If the observed next status maps to a
     * negative offset, every supplied output is zero-filled.
     *
     * @param grid                 the data grid
     * @param index                row index within {@code grid}
     * @param derivative           output of length {@link #getDerivativeSize()}; receives the
     *                             gradient, computed as {@code -(1/p) * dp/dtheta}
     * @param secondDerivative     optional output (may be {@code null}); receives the diagonal of
     *                             the second-derivative values
     * @param fullSecondDerivative optional square output (may be {@code null}); receives the full
     *                             symmetric second-derivative matrix
     * @throws IllegalArgumentException if {@code derivative} or {@code fullSecondDerivative} is
     *                                  mis-sized
     */
    public void computeGradient(ParamFittingGrid<S, R, T> grid, int index, double[] derivative,
            double[] secondDerivative, double[][] fullSecondDerivative) {
        final int size = _packed.size();
        if (derivative.length != size) {
            throw new IllegalArgumentException("Derivative size mismatch.");
        }
        final double[] probabilities = _probWorkspace;
        final double[] psDerivatives = _psDerivativeWorkspace;
        final double[] entryWeights = _regWorkspace;
        final double[] rawRegressors = _rawRegWorkspace;

        grid.getRegressors(index, rawRegressors);
        fillEntryWeights(rawRegressors, entryWeights);
        rawPowerScores(entryWeights, probabilities);
        MultiLogistic.multiLogisticFunction(probabilities, probabilities);

        final int actualOffset = _likelihood.ordinalToOffset(grid.getNextStatus(index));
        if (actualOffset < 0) {
            // Zero every output, not just the gradient, so callers that reuse buffers do not
            // see second derivatives left over from a previous observation.
            Arrays.fill(derivative, 0.0d);
            if (null != secondDerivative) {
                Arrays.fill(secondDerivative, 0.0d);
            }
            if (null != fullSecondDerivative) {
                for (final double[] row : fullSecondDerivative) {
                    Arrays.fill(row, 0.0d);
                }
            }
            return;
        }

        final double actualProbability = probabilities[actualOffset];
        for (int k = 0; k < size; k++) {
            final int entry = _packed.getEntry(k);
            final int transition = _packed.getTransition(k);
            // (delta - p_t) * p: the derivative of p with respect to the power score of
            // `transition`, assuming multiLogisticFunction is the softmax.
            final double indicator = (transition == actualOffset) ? 1.0d : 0.0d;
            final double probDerivative = (indicator - probabilities[transition]) * actualProbability;
            // Reuse the weight cached by fillEntryWeights above instead of recomputing it.
            final double weight = entryWeights[entry];
            final double psDerivative = _packed.isBeta(k)
                    ? weight
                    : _packed.getEntryBeta(k) * computeWeightDerivative(rawRegressors, k, weight, entry);
            psDerivatives[k] = psDerivative;
            derivative[k] = probDerivative * psDerivative;
        }

        if (null != fullSecondDerivative || null != secondDerivative) {
            // Uses the unscaled dp/dtheta values currently held in `derivative`.
            fillSecondDerivatives(rawRegressors, actualOffset, actualProbability, probabilities,
                    psDerivatives, derivative, secondDerivative, fullSecondDerivative);
        }

        // Turn dp/dtheta into d(-log p)/dtheta.
        final double scale = -1.0d / actualProbability;
        for (int k = 0; k < derivative.length; k++) {
            derivative[k] *= scale;
        }
    }

    /**
     * Computes the derivative of an entry's weight with respect to the curve parameter at packed
     * index {@code packedIndex}. Because the weight is a product, this equals
     * {@code weight * curve'(x) / curve(x)} for the curve at that parameter's depth.
     */
    private double computeWeightDerivative(double[] rawRegressors, int packedIndex, double weight, int entry) {
        final int depth = _packed.getDepth(packedIndex);
        final ItemCurve<T> curve = _params.getEntryCurve(entry, depth);
        final double x = rawRegressors[_params.getEntryRegressorOffset(entry, depth)];
        final int curveIndex = _packed.getCurveIndex(packedIndex);
        final double transform = curve.transform(x);
        final double curveDerivative = curve.derivative(curveIndex, x);
        if (curveDerivative == 0.0d) {
            // Also avoids a 0/0 NaN when the transform itself is zero.
            return 0.0d;
        }
        return weight * (curveDerivative / transform);
    }

    /**
     * Fills the second-derivative outputs for {@link #computeGradient}. Only the upper triangle
     * ({@code m >= k}) is computed and then mirrored. When {@code fullSecondDerivative} is
     * {@code null}, only the diagonal is computed.
     *
     * <p>NOTE(review): the sign conventions of the terms below are subtle. Validate any change
     * against a finite-difference check of {@link #computeGradient}.
     *
     * @param probDerivatives dp/dtheta per packed parameter, before the {@code -1/p} scaling
     */
    private void fillSecondDerivatives(double[] rawRegressors, int actualOffset, double actualProbability,
            double[] probabilities, double[] psDerivatives, double[] probDerivatives,
            double[] secondDerivative, double[][] fullSecondDerivative) {
        if (null != fullSecondDerivative && fullSecondDerivative.length != probDerivatives.length) {
            throw new IllegalArgumentException("Mismatched sizes! " + fullSecondDerivative.length + " != " + probDerivatives.length);
        }
        final double probSquared = actualProbability * actualProbability;
        for (int k = 0; k < probDerivatives.length; k++) {
            if (null != fullSecondDerivative && fullSecondDerivative[k].length != probDerivatives.length) {
                throw new IllegalArgumentException("Mismatched sizes! " + fullSecondDerivative[k].length + " != " + probDerivatives.length);
            }
            final int transK = _packed.getTransition(k);
            final double probK = probabilities[transK];
            final double psDerivK = psDerivatives[k];
            final double probDerivK = probDerivatives[k];
            final int entryK = _packed.getEntry(k);
            final double residualK = (transK == actualOffset ? 1.0d : 0.0d) - probK;
            // Without a full matrix to fill, only the diagonal element (m == k) is needed.
            final int end = (fullSecondDerivative != null) ? probDerivatives.length : k + 1;
            for (int m = k; m < end; m++) {
                final int transM = _packed.getTransition(m);
                final double psDerivM = psDerivatives[m];
                final double probM = probabilities[transM];
                final int entryM = _packed.getEntry(m);
                final double sameTransition = (transK == transM) ? 1.0d : 0.0d;
                final double probDerivM = probDerivatives[m];

                // The power-score curvature only contributes when both parameters act on the
                // same entry.
                final double curvatureTerm = (entryK != entryM)
                        ? 0.0d
                        : ((-powerScoreSecondDerivative(rawRegressors, k, m, transK, entryK)) * actualProbability) * residualK;
                final double firstOrderTerm = (psDerivK * probDerivM) * residualK;
                final double crossTerm = (((psDerivK * psDerivM) * actualProbability) * probK) * (sameTransition - probM);
                final double value = (-((curvatureTerm + firstOrderTerm) + crossTerm)) / actualProbability
                        + (probDerivK * probDerivM) / probSquared;

                if (m == k && null != secondDerivative) {
                    // Previously unguarded: a null diagonal output with a non-null matrix threw NPE.
                    secondDerivative[m] = -value;
                }
                if (null != fullSecondDerivative) {
                    fullSecondDerivative[k][m] = -value;
                    fullSecondDerivative[m][k] = -value;
                }
            }
        }
    }

    /**
     * Computes the second derivative of the power score with respect to packed parameters
     * {@code first} and {@code second}, both belonging to {@code entry}.
     *
     * @param transition currently unused
     */
    private double powerScoreSecondDerivative(double[] rawRegressors, int first, int second, int transition, int entry) {
        final boolean firstIsBeta = _packed.isBeta(first);
        final boolean secondIsBeta = _packed.isBeta(second);
        if (firstIsBeta && secondIsBeta) {
            // The power score is linear in the betas.
            return 0.0d;
        }
        final double weight = computeEntryWeight(rawRegressors, entry);
        if (firstIsBeta) {
            return computeWeightDerivative(rawRegressors, second, weight, entry);
        }
        if (secondIsBeta) {
            return computeWeightDerivative(rawRegressors, first, weight, entry);
        }

        // Both parameters are curve parameters.
        final double beta = _packed.getEntryBeta(first);
        final int depth = _packed.getDepth(first);
        final int secondDepth = _packed.getDepth(second);
        final int curveIndex = _packed.getCurveIndex(first);
        final int secondCurveIndex = _packed.getCurveIndex(second);
        final ItemCurve<T> curve = _params.getEntryCurve(entry, depth);
        final double x = rawRegressors[_params.getEntryRegressorOffset(entry, depth)];
        final double transform = curve.transform(x);
        if (depth == secondDepth) {
            // Same curve: use that curve's own second derivative.
            final double curveSecond = curve.secondDerivative(curveIndex, secondCurveIndex, x);
            if (curveSecond == 0.0d) {
                return 0.0d;
            }
            return beta * weight * (curveSecond / transform);
        }

        // Different curves in the same product: the mixed derivative factorizes.
        final ItemCurve<T> secondCurve = _params.getEntryCurve(entry, secondDepth);
        final double secondX = rawRegressors[_params.getEntryRegressorOffset(entry, secondDepth)];
        final double secondTransform = secondCurve.transform(secondX);
        final double firstCurveDerivative = curve.derivative(curveIndex, x);
        final double secondCurveDerivative = secondCurve.derivative(secondCurveIndex, secondX);
        if (firstCurveDerivative == 0.0d || secondCurveDerivative == 0.0d) {
            return 0.0d;
        }
        return beta * weight * (firstCurveDerivative / transform) * (secondCurveDerivative / secondTransform);
    }

    /**
     * Computes outcome probabilities for one row of a grid.
     *
     * @param grid          the data grid
     * @param index         row index within {@code grid}
     * @param probabilities output; must have one slot per outcome
     * @return the number of probabilities written ({@code _betas.length})
     */
    public int transitionProbability(ItemParamGrid<S, R, T> grid, int index, double[] probabilities) {
        grid.getRegressors(index, _rawRegWorkspace);
        return transitionProbability(_rawRegWorkspace, probabilities);
    }

    /**
     * Computes outcome probabilities from raw regressor values.
     *
     * @param rawRegressors one value per unique regressor
     * @param probabilities output; must have one slot per outcome
     * @return the number of probabilities written ({@code _betas.length})
     */
    public int transitionProbability(double[] rawRegressors, double[] probabilities) {
        multiLogisticFunction(rawRegressors, probabilities);
        return _betas.length;
    }

    /** Computes power scores from raw regressors, entry by entry. */
    private void powerScores(double[] rawRegressors, double[] powerScores) {
        Arrays.fill(powerScores, 0.0d);
        for (int e = 0; e < _params.getEntryCount(); e++) {
            addEntryPowerScores(rawRegressors, e, powerScores);
        }
    }

    /** Computes power scores from precomputed entry weights: {@code ps[t] = sum_e w[e] * betas[t][e]}. */
    private void rawPowerScores(double[] entryWeights, double[] powerScores) {
        final int entryCount = _regWorkspace.length;
        for (int t = 0; t < _betas.length; t++) {
            final double[] betaRow = _betas[t];
            double sum = 0.0d;
            for (int e = 0; e < entryCount; e++) {
                sum += entryWeights[e] * betaRow[e];
            }
            powerScores[t] = sum;
        }
    }

    /** Computes power scores into {@code output}, then applies the multi-logistic function in place. */
    private void multiLogisticFunction(double[] rawRegressors, double[] output) {
        powerScores(rawRegressors, output);
        MultiLogistic.multiLogisticFunction(output, output);
    }

    /**
     * Returns a new model built from the same parameter objects, with its own scratch arrays.
     *
     * @return an independent model that can be used from another thread
     */
    /* renamed from: clone, reason: merged with bridge method [inline-methods] */
    public final synchronized ItemModel<S, R, T> m4clone() {
        return new ItemModel<>(_params, _packed);
    }

    /**
     * Public override of {@link Object#clone()}. Without it, a {@link Cloneable} class would
     * expose nothing callable, and {@code Object.clone} would shallow-share the scratch arrays.
     *
     * @return the same result as {@link #m4clone()}
     */
    @Override
    public ItemModel<S, R, T> clone() {
        return m4clone();
    }
}
