package edu.umass.cs.mallet.grmm.test;

import edu.umass.cs.mallet.base.fst.Transducer;
import edu.umass.cs.mallet.base.util.ArrayUtils;
import edu.umass.cs.mallet.grmm.inference.RandomGraphs;
import edu.umass.cs.mallet.grmm.types.Factor;
import edu.umass.cs.mallet.grmm.types.FactorGraph;
import edu.umass.cs.mallet.grmm.types.HashVarSet;
import edu.umass.cs.mallet.grmm.types.TableFactor;
import edu.umass.cs.mallet.grmm.types.UndirectedModel;
import edu.umass.cs.mallet.grmm.types.Variable;
import edu.umass.cs.mallet.grmm.util.Graphs;
import java.io.File;
import java.io.FileWriter;
import java.io.IOException;
import java.io.PrintWriter;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Random;
import java.util.Set;
import junit.framework.Test;
import junit.framework.TestCase;
import junit.framework.TestSuite;
import junit.textui.TestRunner;
import org._3pq.jgrapht.GraphHelper;
import org._3pq.jgrapht.UndirectedGraph;

/* loaded from: input_file:WEB-INF/lib/mallet-0.1.3.jar:edu/umass/cs/mallet/grmm/test/TestUndirectedModel.class */
/**
 * JUnit 3 tests for the factor-graph / undirected-model classes in GRMM.
 *
 * <p>The tests build small {@link FactorGraph} instances (or take the shared models from
 * {@code TestInference.createTestModels()}) and check factor lookup, adjacency, the
 * behaviour after removing a variable, and the conversion to a JGraphT graph.
 */
public class TestUndirectedModel extends TestCase {
    /**
     * Compiler-generated cache slot for this class's {@code Class} object. No longer read by
     * {@link #suite()}; retained only so that anything that still refers to it keeps linking.
     */
    static Class class$edu$umass$cs$mallet$grmm$test$TestUndirectedModel;

    /**
     * Creates a test case that runs the test method with the given name.
     *
     * @param str name of the test method to run
     */
    public TestUndirectedModel(String str) {
        super(str);
    }

    /**
     * Writes a random 3x4 grid model to {@code grmm-model.dot} in the working directory.
     * There are no assertions; the test passes if nothing is thrown.
     *
     * @throws IOException if the output file cannot be opened
     */
    public void testOutputToDot() throws IOException {
        UndirectedModel grid = TestInference.createRandomGrid(3, 4, 2, new Random(4234L));
        PrintWriter out = new PrintWriter(new FileWriter(new File("grmm-model.dot")));
        try {
            grid.printAsDot(out);
        } finally {
            // Close even when printAsDot throws, so the file handle is never leaked.
            out.close();
        }
        System.out.println("Now you can open up grmm-model.dot in Graphviz.");
    }

    /**
     * Adds two single-variable factors over the same variable and checks that
     * {@code factorOf(variable)} returns their element-wise product.
     */
    public void testMultipleNodePotentials() {
        Variable var = new Variable(2);
        FactorGraph graph = new FactorGraph(new Variable[]{var});
        graph.addFactor(new TableFactor(var, new double[]{0.5d, 0.5d}));
        graph.addFactor(new TableFactor(var, new double[]{0.25d, 0.25d}));

        double[] expected = {0.125d, 0.125d};
        double[] actual = ((TableFactor) graph.factorOf(var)).toValueArray();
        // 0.5 * 0.25 is exactly representable, so an exact comparison is safe here.
        assertTrue("Arrays not equal\n  Expected " + ArrayUtils.toString(expected)
                + "\n  Actual " + ArrayUtils.toString(actual),
                Arrays.equals(expected, actual));
    }

    /**
     * Adds two pairwise factors over the same two variables and checks that both are
     * reported by {@code allFactorsOverVars} and that {@code factorOf(v1, v2)} is
     * (approximately) their product.
     */
    public void testMultipleEdgePotentials() {
        Variable first = new Variable(2);
        Variable second = new Variable(2);
        Variable[] vars = {first, second};
        FactorGraph graph = new FactorGraph(vars);

        TableFactor factorA = new TableFactor(vars, new double[]{0.5d, 0.5d, 0.5d, 0.5d});
        graph.addFactor(factorA);
        TableFactor factorB = new TableFactor(vars, new double[]{0.25d, 0.25d, 0.5d, 0.5d});
        graph.addFactor(factorB);

        Factor combined = graph.factorOf(first, second);

        Collection overBoth = graph.allFactorsOverVars(new HashVarSet(vars));
        assertEquals(2, overBoth.size());
        assertTrue(overBoth.contains(factorA));
        assertTrue(overBoth.contains(factorB));

        double[] expected = {0.125d, 0.125d, 0.25d, 0.25d};
        double[] actual = ((TableFactor) combined).toValueArray();
        assertTrue("Arrays not equal\n  Expected " + ArrayUtils.toString(expected)
                + "\n  Actual " + ArrayUtils.toString(actual),
                new TableFactor(vars, expected).almostEquals(combined, 1.0E-10d));
    }

    /**
     * Checks that a single factor over three variables makes every pair of them adjacent.
     */
    public void testPotentialConnections() {
        Variable a = new Variable(2);
        Variable b = new Variable(2);
        Variable c = new Variable(2);
        FactorGraph graph = new FactorGraph();
        graph.addFactor(new TableFactor(new Variable[]{a, b, c}, new double[8]));

        assertTrue(graph.isAdjacent(a, b));
        assertTrue(graph.isAdjacent(b, c));
        assertTrue(graph.isAdjacent(a, c));
    }

    /**
     * Builds a three-variable model with factors on (v1, v2) and (v1, v3) only, then checks
     * adjacency (v2 and v3 must not be adjacent) and that both pairwise factors can be
     * looked up.
     */
    public void testThreeNodeModel() {
        Random random = new Random(23534709L);
        FactorGraph graph = new FactorGraph();
        Variable hub = new Variable(2);
        Variable left = new Variable(2);
        Variable right = new Variable(2);
        // Keep the two generate calls in this order: they consume the same seeded Random.
        graph.addFactor(hub, left, RandomGraphs.generateMixedPotentialValues(random, 1.5d));
        graph.addFactor(hub, right, RandomGraphs.generateMixedPotentialValues(random, 1.5d));

        assertTrue(graph.isAdjacent(hub, right));
        assertTrue(graph.isAdjacent(hub, left));
        assertTrue(!graph.isAdjacent(left, right));
        assertNotNull(graph.factorOf(hub, left));
        assertNotNull(graph.factorOf(hub, right));
    }

    /**
     * Runs {@link #verifyCachesConsistent(FactorGraph)} on every model returned by
     * {@code TestInference.createTestModels()}.
     */
    public void testUndirectedCaches() {
        Iterator models = TestInference.createTestModels().iterator();
        while (models.hasNext()) {
            verifyCachesConsistent((FactorGraph) models.next());
        }
    }

    /**
     * For every one- and two-variable factor in the graph, asserts that the
     * {@code factorOf} lookups return that very factor instance, and that the pairwise
     * lookup gives the same instance for either argument order. Factors over more
     * than two variables are not checked.
     *
     * @param factorGraph graph whose factor lookups are verified
     */
    private void verifyCachesConsistent(FactorGraph factorGraph) {
        // Explicit iterator + cast, matching how the other tests in this class walk factors.
        Iterator factors = factorGraph.factors().iterator();
        while (factors.hasNext()) {
            Factor factor = (Factor) factors.next();
            Object[] vars = factor.varSet().toArray();
            if (vars.length == 1) {
                assertSame(factor, factorGraph.factorOf((Variable) vars[0]));
            } else if (vars.length == 2) {
                Variable first = (Variable) vars[0];
                Variable second = (Variable) vars[1];
                Factor forward = factorGraph.factorOf(first, second);
                Factor backward = factorGraph.factorOf(second, first);
                assertSame(factor, forward);
                assertSame(forward, backward);
            }
        }
    }

    /**
     * Duplicates each test model, removes the element returned by {@code get(0)}, and then
     * checks that every remaining variable has an index in {@code [0, numVariables())} and
     * that the factor lookups are still consistent.
     */
    public void testUndirectedCachesAfterRemove() {
        Iterator models = TestInference.createTestModels().iterator();
        while (models.hasNext()) {
            FactorGraph copy = ((FactorGraph) models.next()).duplicate();
            copy.remove(copy.get(0));

            Iterator vars = copy.variablesIterator();
            while (vars.hasNext()) {
                int index = copy.getIndex((Variable) vars.next());
                assertTrue(index >= 0);
                assertTrue(index < copy.numVariables());
            }
            verifyCachesConsistent(copy);
        }
    }

    /**
     * Converts each test model with {@code Graphs.mdlToGraph} and checks the vertex count,
     * that the edge count equals the number of two-variable factors, and that each
     * vertex's neighbourhood (plus the vertex itself) covers the variables of every factor
     * touching that vertex.
     */
    public void testMdlToGraph() {
        Iterator models = TestInference.createTestModels().iterator();
        while (models.hasNext()) {
            UndirectedModel model = (UndirectedModel) models.next();
            UndirectedGraph graph = Graphs.mdlToGraph(model);
            Set vertices = graph.vertexSet();
            assertEquals(model.numVariables(), vertices.size());

            int pairwiseFactors = 0;
            Iterator factors = model.factors().iterator();
            while (factors.hasNext()) {
                if (((Factor) factors.next()).varSet().size() == 2) {
                    pairwiseFactors++;
                }
            }
            assertEquals(pairwiseFactors, graph.edgeSet().size());

            Iterator vertexIt = vertices.iterator();
            while (vertexIt.hasNext()) {
                Variable vertex = (Variable) vertexIt.next();
                // NOTE(review): trivially true, since vertex comes from this same set;
                // kept so the test's behaviour is unchanged.
                assertTrue(vertices.contains(vertex));

                Set neighborhood = new HashSet(GraphHelper.neighborListOf(graph, vertex));
                neighborhood.add(vertex);
                Iterator touching = model.allFactorsOfVar(vertex).iterator();
                while (touching.hasNext()) {
                    assertTrue(neighborhood.containsAll(((Factor) touching.next()).varSet()));
                }
            }
        }
    }

    /**
     * Checks {@code factorOf} when given a set of variables: the factor's own var set and
     * an equal {@link HashSet} both find the factor, while a strict subset finds nothing.
     */
    public void testFactorOfSet() {
        Variable[] vars = new Variable[3];
        for (int i = 0; i < vars.length; i++) {
            vars[i] = new Variable(2);
        }
        // The table values are never asserted on; only the factor's identity matters here.
        TableFactor factor = new TableFactor(vars,
                new double[]{Transducer.ZERO_COST, 1.0d, 2.0d, 3.0d, 4.0d, 5.0d, 6.0d, 7.0d});
        FactorGraph graph = new FactorGraph(vars);
        graph.addFactor(factor);

        assertSame(factor, graph.factorOf(factor.varSet()));

        HashSet varSet = new HashSet(factor.varSet());
        assertSame(factor, graph.factorOf(varSet));

        varSet.remove(vars[0]);
        assertNull(graph.factorOf(varSet));
    }

    /**
     * Builds a suite containing every test method of this class.
     *
     * @return suite over {@code TestUndirectedModel}
     */
    public static Test suite() {
        return new TestSuite(TestUndirectedModel.class);
    }

    /**
     * Command-line entry point. With arguments, runs only the test methods named by the
     * arguments; with none, runs the full {@link #suite()}.
     *
     * @param strArr names of test methods to run, or empty to run them all
     * @throws Throwable declared for compatibility with the original signature
     */
    public static void main(String[] strArr) throws Throwable {
        TestSuite selected;
        if (strArr.length > 0) {
            selected = new TestSuite();
            for (String testName : strArr) {
                selected.addTest(new TestUndirectedModel(testName));
            }
        } else {
            selected = (TestSuite) suite();
        }
        TestRunner.run(selected);
    }

    /**
     * Compiler-generated helper that resolves a class by name, converting a
     * {@link ClassNotFoundException} into a {@link NoClassDefFoundError} with the original
     * exception attached as its cause. Retained for compatibility; {@link #suite()} uses a
     * class literal instead.
     *
     * @param str fully qualified class name
     * @return the resolved class
     * @throws NoClassDefFoundError if the class cannot be found
     */
    static Class class$(String str) {
        try {
            return Class.forName(str);
        } catch (ClassNotFoundException e) {
            // initCause returns Throwable, so throw the error object itself rather than
            // the result of the call (which would be a checked type and not compile).
            NoClassDefFoundError error = new NoClassDefFoundError(str);
            error.initCause(e);
            throw error;
        }
    }
}
