package edu.umass.cs.mallet.base.minimize.tests;

import edu.umass.cs.mallet.base.fst.Transducer;
import edu.umass.cs.mallet.base.minimize.Minimizable;
import edu.umass.cs.mallet.base.types.Matrix;
import edu.umass.cs.mallet.base.util.MalletLogger;
import java.util.logging.Logger;
import junit.framework.Test;
import junit.framework.TestCase;
import junit.framework.TestSuite;
import junit.textui.TestRunner;

/* loaded from: input_file:WEB-INF/lib/mallet-0.1.3.jar:edu/umass/cs/mallet/base/minimize/tests/TestMinimizable.class */
/**
 * JUnit tests and reusable checking helpers for {@link Minimizable} implementations.
 *
 * <p>The static {@code test*} helpers take the object under test as an argument. They therefore
 * are not picked up automatically by {@link #suite()}, and are presumably meant to be called from
 * tests of concrete minimizables. {@link #testTestCostAndGradient()} runs the gradient check on a
 * {@code Quadratic}.
 */
public class TestMinimizable extends TestCase {
    private static final Logger logger = MalletLogger.getLogger(TestMinimizable.class.getName());

    public TestMinimizable(String str) {
        super(str);
    }

    /**
     * Checks that a minimizable's bulk and per-element parameter accessors agree with each other.
     *
     * <p>The value {@code i} is written into the parameter at flat index {@code i}. This is done
     * first through {@code setParameters} and then element by element through
     * {@code setParameter}, and the values are read back after each write path. On return the
     * minimizable's parameters hold these index values.
     *
     * @param minimizable the object under test; its parameters are overwritten
     * @return {@code true}; any mismatch fails through a JUnit assertion instead
     */
    public static boolean testGetSetParameters(Minimizable minimizable) {
        System.out.println("TestMinimizable testGetSetParameters");
        Matrix params = minimizable.getNewMatrix();
        minimizable.getParameters(params);

        // Path 1: bulk write, bulk read.
        for (int i = 0; i < params.singleSize(); i++) {
            params.setSingleValue(i, i);
        }
        minimizable.setParameters(params);
        params.setAll(0.0);
        minimizable.getParameters(params);
        assertEachValueEqualsIndex(params, "setParameters/getParameters");

        // Path 2: clear, then write element by element through multi-dimensional indices.
        params.setAll(0.0);
        minimizable.setParameters(params);
        int[] indices = new int[params.getNumDimensions()];
        for (int i = 0; i < params.singleSize(); i++) {
            params.singleToIndices(i, indices);
            minimizable.setParameter(indices, i);
        }

        // Read back through the bulk accessor ...
        params.setAll(0.0);
        minimizable.getParameters(params);
        assertEachValueEqualsIndex(params, "setParameter/getParameters");

        // ... and through the per-element accessor.
        for (int i = 0; i < params.singleSize(); i++) {
            params.singleToIndices(i, indices);
            double actual = minimizable.getParameter(indices);
            assertTrue("setParameter/getParameter: value at singleIndex " + i + " is " + actual,
                    actual == (double) i);
        }
        return true;
    }

    /** Asserts that each entry of {@code values} equals its own flat (single) index. */
    private static void assertEachValueEqualsIndex(Matrix values, String accessPath) {
        for (int i = 0; i < values.singleSize(); i++) {
            double actual = values.singleValue(i);
            assertTrue(accessPath + ": value at singleIndex " + i + " is " + actual,
                    actual == (double) i);
        }
    }

    /**
     * Compares the minimizable's analytic cost gradient with a forward finite-difference estimate
     * taken at its current parameters.
     *
     * <p>Each parameter is shifted in turn by {@code epsilon = 0.1 / |gradient|} and the cost is
     * re-evaluated. The resulting slopes form an empirical gradient, and the angle between it and
     * the analytic gradient must not exceed {@code 5 * epsilon}. After probing, the parameters
     * read at entry are written back with {@code setParameters}, so the caller continues from
     * the same point.
     *
     * @param byGradient the object under test
     * @return the angle, in radians, between the analytic and empirical gradients
     * @throws IllegalStateException if the analytic gradient norm is zero, NaN or infinite, if
     *     the angle is NaN, or if the angle exceeds the tolerance
     * @throws AssertionError if assertions are enabled and a finite-difference slope is NaN
     */
    public static double testCostAndGradientCurrentParameters(Minimizable.ByGradient byGradient) {
        Matrix parameters = byGradient.getParameters(byGradient.getNewMatrix());
        double cost = byGradient.getCost();
        Matrix analyticGradient = byGradient.getCostGradient(byGradient.getNewMatrix());
        Matrix empiricalGradient = (Matrix) analyticGradient.cloneMatrix();

        double gradientNorm = analyticGradient.twoNorm();
        // A zero/NaN/infinite norm would make epsilon infinite or NaN and the check meaningless;
        // fail before any parameter is perturbed.
        if (!(gradientNorm > 0.0d) || Double.isInfinite(gradientNorm)) {
            throw new IllegalStateException("Gradient/Cost error: analytic gradient norm is "
                    + gradientNorm + "; cannot choose a finite-difference step");
        }
        // The probe step shrinks as the gradient grows; the angle tolerance is a fixed 5x of it.
        double epsilon = 0.1d / gradientNorm;
        double tolerance = epsilon * 5.0d;
        System.out.println("epsilon = " + epsilon + " tolerance=" + tolerance);

        for (int i = 0; i < parameters.singleSize(); i++) {
            double original = parameters.singleValue(i);
            parameters.setSingleValue(i, original + epsilon);
            byGradient.setParameters(parameters);
            double perturbedCost = byGradient.getCost();
            double slope = (perturbedCost - cost) / epsilon;
            double analytic = analyticGradient.singleValue(i);
            System.out.println("cost=" + cost + " epsCost=" + perturbedCost + " slope[" + i
                    + "] = " + slope + " gradient[]=" + analytic);
            assert !Double.isNaN(slope) : "finite-difference slope is NaN at singleIndex " + i;
            logger.fine("TestMinimizable checking singleIndex " + i
                    + ": gradient slope = " + analytic
                    + ", cost+epsilon slope = " + slope
                    + ": slope difference = " + Math.abs(slope - analytic));
            empiricalGradient.setSingleValue(i, slope);
            parameters.setSingleValue(i, original);
        }
        // Without this, the minimizable is left with its last coordinate shifted by epsilon.
        byGradient.setParameters(parameters);

        System.out.println("empiricalGradient.twoNorm = " + empiricalGradient.twoNorm());
        double angle = angleBetween(analyticGradient, empiricalGradient);
        logger.info("TestMinimizable angle = " + angle);
        if (Double.isNaN(angle)) {
            throw new IllegalStateException("Gradient/Cost error: angle is NaN!");
        }
        if (angle > tolerance) {
            throw new IllegalStateException("Gradient/Cost mismatch: angle=" + angle);
        }
        return angle;
    }

    /**
     * Returns the angle in radians between two vectors, scaling both to unit two-norm in place.
     *
     * <p>The cosine is clamped to [-1, 1] so that rounding error on (anti)parallel vectors cannot
     * push it just outside the domain of {@link Math#acos}. A NaN cosine is passed through
     * unchanged, so the result is still NaN in that case.
     */
    private static double angleBetween(Matrix a, Matrix b) {
        a.timesEquals(1.0d / a.twoNorm());
        b.timesEquals(1.0d / b.twoNorm());
        double cosine = a.dotProduct(b);
        return Math.acos(Math.max(-1.0d, Math.min(1.0d, cosine)));
    }

    /**
     * Runs {@link #testCostAndGradientCurrentParameters} twice.
     *
     * <p>The first run is at the all-zero parameter point. The second run is after a small step,
     * of size 1e-4 times the cost gradient, against that gradient at the origin.
     *
     * @param byGradient the object under test; its parameters are overwritten
     * @return {@code true}; failures surface as exceptions from the gradient check
     */
    public static boolean testCostAndGradient(Minimizable.ByGradient byGradient) {
        Matrix point = byGradient.getNewMatrix();
        point.setAll(0.0);
        byGradient.setParameters(point);
        testCostAndGradientCurrentParameters(byGradient);

        // Re-establish the origin explicitly so the step direction is the gradient there,
        // independent of whatever state the check above leaves behind.
        point.setAll(0.0);
        byGradient.setParameters(point);
        Matrix step = byGradient.getNewMatrix();
        byGradient.getCostGradient(step);
        step.timesEquals(-1.0E-4d);
        point.plusEquals(step);
        byGradient.setParameters(point);
        testCostAndGradientCurrentParameters(byGradient);
        return true;
    }

    /** Runs the gradient check on a {@code Quadratic} built with arguments (10, 2, 3). */
    public void testTestCostAndGradient() {
        testCostAndGradient(new Quadratic(10.0d, 2.0d, 3.0d));
    }

    /** Builds a suite that contains every no-argument {@code test*} method of this class. */
    public static Test suite() {
        return new TestSuite(TestMinimizable.class);
    }

    /**
     * No per-test fixture is needed.
     *
     * <p>The decompiler reports this method was originally {@code protected}. It is kept
     * {@code public} so that existing callers still compile.
     */
    @Override
    public void setUp() {
    }

    public static void main(String[] strArr) {
        TestRunner.run(suite());
    }
}
