package edu.umass.cs.mallet.grmm.test;

import edu.umass.cs.mallet.grmm.inference.BruteForceInferencer;
import edu.umass.cs.mallet.grmm.types.AbstractTableFactor;
import edu.umass.cs.mallet.grmm.types.CPT;
import edu.umass.cs.mallet.grmm.types.DirectedModel;
import edu.umass.cs.mallet.grmm.types.DiscreteFactor;
import edu.umass.cs.mallet.grmm.types.FactorGraph;
import edu.umass.cs.mallet.grmm.types.Factors;
import edu.umass.cs.mallet.grmm.types.LogTableFactor;
import edu.umass.cs.mallet.grmm.types.TableFactor;
import edu.umass.cs.mallet.grmm.types.Variable;
import junit.framework.Test;
import junit.framework.TestCase;
import junit.framework.TestSuite;
import junit.textui.TestRunner;

/* loaded from: input_file:WEB-INF/lib/mallet-0.1.3.jar:edu/umass/cs/mallet/grmm/test/TestDirectedModel.class */
/**
 * JUnit 3 tests for {@link DirectedModel}: joint computation, rejection of
 * invalid CPTs, CPT lookup by variable, and adding a second CPT for a
 * variable that already has one.
 *
 * <p>The fixture is built in the constructor. JUnit 3 creates one instance of
 * the test class per test method, so every test sees freshly built factors.
 */
public class TestDirectedModel extends TestCase {
    // CPTs obtained by normalizing duplicates of the raw factors below.
    private final CPT pA;
    private final CPT pB;
    private final CPT pC;

    // Raw (unnormalized) factors over A, over B, and over (A, B, C).
    private final DiscreteFactor fA;
    private final DiscreteFactor fB;
    private final DiscreteFactor fC;

    // All variables of the model, in the order {A, B, C}.
    private final Variable[] vars;
    private final Variable A;
    private final Variable B;
    private final Variable C;

    /**
     * Builds the shared fixture.
     *
     * @param str name of the test method this instance will run
     */
    public TestDirectedModel(String str) {
        super(str);

        // Each variable is created with argument 2 (presumably the number of
        // outcomes, i.e. binary variables -- confirm against Variable).
        this.A = new Variable(2);
        this.B = new Variable(2);
        this.C = new Variable(2);
        this.vars = new Variable[]{this.A, this.B, this.C};

        this.fA = LogTableFactor.makeFromValues(this.A, new double[]{1.0d, 4.0d});
        this.fB = LogTableFactor.makeFromValues(this.B, new double[]{3.0d, 2.0d});
        // Eight values: one per joint assignment of three variables created with 2.
        this.fC = new TableFactor(this.vars, new double[]{3.0d, 7.0d, 5.0d, 5.0d, 9.0d, 1.0d, 6.0d, 4.0d});

        // Normalize duplicates so the raw factors stay untouched for later use.
        this.pA = Factors.normalizeAsCpt((AbstractTableFactor) this.fA.duplicate(), this.A);
        this.pB = Factors.normalizeAsCpt((AbstractTableFactor) this.fB.duplicate(), this.B);
        this.pC = Factors.normalizeAsCpt((AbstractTableFactor) this.fC.duplicate(), this.C);
    }

    /**
     * Checks that a {@link FactorGraph} and a {@link DirectedModel} built over
     * the same variables yield the same joint under brute-force inference.
     */
    public void testSimpleModel() {
        FactorGraph factorGraph = new FactorGraph(this.vars);
        factorGraph.addFactor(this.pA);
        factorGraph.addFactor(this.pB);
        // The undirected graph deliberately receives the raw factor fC, while
        // the directed model below receives its CPT form pC.
        // NOTE(review): every consecutive pair of fC values sums to 10, so fC
        // looks like a constant multiple of pC -- presumably the joints agree
        // only after normalization; confirm against BruteForceInferencer.joint.
        factorGraph.addFactor(this.fC);

        DirectedModel directedModel = new DirectedModel(this.vars);
        directedModel.addFactor(this.pA);
        directedModel.addFactor(this.pB);
        directedModel.addFactor(this.pC);

        BruteForceInferencer inferencer = new BruteForceInferencer();
        DiscreteFactor undirectedJoint = (DiscreteFactor) inferencer.joint(factorGraph);
        DiscreteFactor directedJoint = (DiscreteFactor) inferencer.joint(directedModel);
        comparePotentials(undirectedJoint, directedJoint);
    }

    /**
     * Asserts that two factors have value arrays of equal length whose
     * entries agree within an absolute tolerance of 0.001.
     *
     * @param actual   factor under test
     * @param expected reference factor
     */
    private void comparePotentials(DiscreteFactor actual, DiscreteFactor expected) {
        double[] actualValues = actual.toValueArray();
        double[] expectedValues = expected.toValueArray();
        assertEquals(expectedValues.length, actualValues.length);
        for (int i = 0; i < expectedValues.length; i++) {
            assertEquals(expectedValues[i], actualValues[i], 0.001d);
        }
    }

    /**
     * Checks that, once pA, pB and pC are present, adding a CPT over (B, C)
     * for B or a CPT over (A, C) for A is rejected with an
     * {@link IllegalArgumentException}.
     */
    public void testCycleChecking() {
        DirectedModel directedModel = new DirectedModel(this.vars);
        directedModel.addFactor(this.pA);
        directedModel.addFactor(this.pB);
        directedModel.addFactor(this.pC);

        try {
            directedModel.addFactor(new CPT(new TableFactor(new Variable[]{this.B, this.C}), this.B));
            fail("Test failed: No exception thrown.");
        } catch (IllegalArgumentException expected) {
            // Expected: the model must refuse this CPT.
        }

        try {
            directedModel.addFactor(new CPT(new TableFactor(new Variable[]{this.A, this.C}), this.A));
            fail("Test failed: No exception thrown.");
        } catch (IllegalArgumentException expected) {
            // Expected: the model must refuse this CPT.
        }
    }

    /**
     * Checks that {@code getCptofVar} returns the very same CPT instance that
     * was added for each variable.
     */
    public void testCptOfVar() {
        DirectedModel directedModel = new DirectedModel(this.vars);
        directedModel.addFactor(this.pA);
        directedModel.addFactor(this.pB);
        directedModel.addFactor(this.pC);

        // Identity, not equality: the model must hand back the stored objects.
        assertSame(this.pA, directedModel.getCptofVar(this.A));
        assertSame(this.pB, directedModel.getCptofVar(this.B));
        assertSame(this.pC, directedModel.getCptofVar(this.C));
    }

    /**
     * Checks that the model holds three factors after adding pA, pB and pC,
     * then attempts to add another CPT for C over (B, C).
     */
    public void testFactorReplace() {
        DirectedModel directedModel = new DirectedModel(this.vars);
        directedModel.addFactor(this.pA);
        directedModel.addFactor(this.pB);
        directedModel.addFactor(this.pC);
        assertEquals(3, directedModel.factors().size());

        try {
            directedModel.addFactor(new CPT(new TableFactor(new Variable[]{this.B, this.C}), this.C));
        } catch (IllegalArgumentException tolerated) {
            // Both outcomes (accepted or rejected) currently pass this test.
            // NOTE(review): nothing is asserted after this call -- confirm the
            // intended behavior and pin it with an assertion.
        }
    }

    /**
     * Returns a suite containing every test method of this class.
     *
     * @return the suite built from this class
     */
    public static Test suite() {
        return new TestSuite(TestDirectedModel.class);
    }

    /**
     * Command-line entry point. With arguments, each argument is used as the
     * name of a single test to run; without arguments the full suite runs.
     *
     * @param strArr names of individual tests to run; may be empty
     */
    public static void main(String[] strArr) {
        TestSuite selected;
        if (strArr.length > 0) {
            selected = new TestSuite();
            for (String testName : strArr) {
                selected.addTest(new TestDirectedModel(testName));
            }
        } else {
            selected = (TestSuite) suite();
        }
        TestRunner.run(selected);
    }
}
