package edu.umass.cs.mallet.grmm.test;

import edu.umass.cs.mallet.grmm.inference.BruteForceInferencer;
import edu.umass.cs.mallet.grmm.inference.LoopyBP;
import edu.umass.cs.mallet.grmm.inference.RandomGraphs;
import edu.umass.cs.mallet.grmm.inference.TRP;
import edu.umass.cs.mallet.grmm.types.Assignment;
import edu.umass.cs.mallet.grmm.types.AssignmentIterator;
import edu.umass.cs.mallet.grmm.types.Factor;
import edu.umass.cs.mallet.grmm.types.FactorGraph;
import edu.umass.cs.mallet.grmm.types.TableFactor;
import edu.umass.cs.mallet.grmm.types.UndirectedGrid;
import edu.umass.cs.mallet.grmm.types.Variable;
import java.io.File;
import java.io.FileWriter;
import java.io.IOException;
import java.io.PrintWriter;
import java.util.Iterator;
import java.util.Random;
import junit.framework.Test;
import junit.framework.TestCase;
import junit.framework.TestSuite;
import junit.textui.TestRunner;

/* loaded from: input_file:WEB-INF/lib/mallet-0.1.3.jar:edu/umass/cs/mallet/grmm/test/TestRandomGraphs.class */
public class TestRandomGraphs extends TestCase {
    /**
     * Cached class object from the decompiled original. It is kept for compatibility,
     * but {@link #suite()} now uses the class literal directly.
     */
    static Class class$edu$umass$cs$mallet$grmm$test$TestRandomGraphs;

    public TestRandomGraphs(String str) {
        super(str);
    }

    /** Builds a suite containing every {@code test*} method of this class. */
    public static Test suite() {
        return new TestSuite(TestRandomGraphs.class);
    }

    /**
     * Runs the inferencer from {@code TRP.createForMaxProduct()} on {@code grid}.
     * It then writes the grid together with the inferencer's {@code bestAssignment()}
     * to {@code fileName} via {@code printAsDot}.
     * The writer is always closed, even if printing fails.
     *
     * @param grid     the grid to run inference on and print
     * @param fileName path of the output file, relative to the working directory
     * @throws IOException if the file cannot be opened for writing
     */
    private static void writeBestAssignmentAsDot(UndirectedGrid grid, String fileName) throws IOException {
        TRP inferencer = TRP.createForMaxProduct();
        inferencer.computeMarginals(grid);
        Assignment best = inferencer.bestAssignment();
        PrintWriter out = new PrintWriter(new FileWriter(new File(fileName)));
        try {
            grid.printAsDot(out, best);
        } finally {
            // Previously the writer leaked whenever printAsDot threw.
            out.close();
        }
    }

    /**
     * Builds five random attractive 5x5 grids from a fixed seed.
     * For each one it dumps the grid to stdout and writes {@code attract.<i>.dot}.
     */
    public void testAttractiveGraphs() throws IOException {
        Random random = new Random(31421L);
        for (int i = 0; i < 5; i++) {
            UndirectedGrid grid = RandomGraphs.randomAttractiveGrid(5, 0.5d, random);
            System.out.println("************");
            grid.dump();
            writeBestAssignmentAsDot(grid, "attract." + i + ".dot");
        }
    }

    /** Builds five random repulsive 5x5 grids and writes {@code repulse.<i>.dot} for each. */
    public void testRepulsiveGraphs() throws IOException {
        Random random = new Random(31421L);
        for (int i = 0; i < 5; i++) {
            UndirectedGrid grid = RandomGraphs.randomRepulsiveGrid(5, 0.5d, random);
            writeBestAssignmentAsDot(grid, "repulse." + i + ".dot");
        }
    }

    /** Builds five random frustrated 5x5 grids and writes {@code mixed.<i>.dot} for each. */
    public void testFrustratedGraphs() throws IOException {
        Random random = new Random(31421L);
        for (int i = 0; i < 5; i++) {
            UndirectedGrid grid = RandomGraphs.randomFrustratedGrid(5, 0.5d, random);
            writeBestAssignmentAsDot(grid, "mixed." + i + ".dot");
        }
    }

    /**
     * Checks that a random frustrated 10x10 grid has the expected size and degree profile.
     * Expected: 100 variables and 280 factors (presumably 180 edge factors plus 100
     * per-variable factors; confirm against RandomGraphs). Degrees are those of a
     * 10x10 lattice: 4 corners of degree 2, 32 border nodes of degree 3 and
     * 64 interior nodes of degree 4.
     */
    public void testFrustratedIsGrid() throws IOException {
        Random random = new Random(0L);
        for (int i = 0; i < 100; i++) {
            UndirectedGrid grid = RandomGraphs.randomFrustratedGrid(10, 1.0d, random);
            assertEquals(280, grid.factors().size());
            assertEquals(100, grid.numVariables());
            // degreeCounts[d] = number of variables with degree d.
            int[] degreeCounts = new int[5];
            for (int v = 0; v < grid.numVariables(); v++) {
                int degree = grid.getDegree(grid.get(v));
                // Range check comes before indexing so a bad degree fails the assertion
                // rather than throwing ArrayIndexOutOfBoundsException.
                assertTrue(degree >= 2 && degree <= 4);
                degreeCounts[degree]++;
            }
            // Expected value first, matching JUnit's assertEquals(expected, actual).
            assertEquals(0, degreeCounts[0]);
            assertEquals(0, degreeCounts[1]);
            assertEquals(4, degreeCounts[2]);
            assertEquals(32, degreeCounts[3]);
            assertEquals(64, degreeCounts[4]);
        }
    }

    /**
     * Checks that a uniform 3x3 grid has 9 variables and 12 factors.
     * Every entry of its brute-force joint must have log value -9 log 2, i.e. it is
     * uniform over 2^9 assignments (presumably binary variables).
     */
    public void testUniformGrid() {
        UndirectedGrid grid = (UndirectedGrid) RandomGraphs.createUniformGrid(3);
        assertEquals(9, grid.numVariables());
        assertEquals(12, grid.factors().size());
        TableFactor joint = (TableFactor) new BruteForceInferencer().joint(grid);
        AssignmentIterator it = joint.assignmentIterator();
        while (it.hasNext()) {
            assertEquals((-9.0d) * Math.log(2.0d), joint.logValue(it), 0.001d);
            // Without advance() this loop never terminated.
            it.advance();
        }
    }

    /**
     * Builds a 3x3 grid with observation variables, using uniform generators for both
     * factor kinds, and runs LoopyBP on it. Checks 18 variables and 21 factors, and
     * that every entry of every marginal has log value -log 2.
     */
    public void testUniformGridWithObservations() {
        FactorGraph graph = RandomGraphs.createGridWithObs(new RandomGraphs.UniformFactorGenerator(), new RandomGraphs.UniformFactorGenerator(), 3);
        assertEquals(18, graph.numVariables());
        assertEquals(21, graph.factors().size());
        LoopyBP loopyBP = new LoopyBP();
        loopyBP.computeMarginals(graph);
        Iterator variables = graph.variablesIterator();
        while (variables.hasNext()) {
            Factor marginal = loopyBP.lookupMarginal((Variable) variables.next());
            AssignmentIterator it = marginal.assignmentIterator();
            while (it.hasNext()) {
                assertEquals(-Math.log(2.0d), marginal.logValue(it), 0.001d);
                it.advance();
            }
        }
    }

    /**
     * Runs the named test methods when arguments are given, otherwise the whole suite,
     * using the JUnit text runner.
     */
    public static void main(String[] strArr) throws Throwable {
        TestSuite testSuite;
        if (strArr.length > 0) {
            testSuite = new TestSuite();
            for (String str : strArr) {
                testSuite.addTest(new TestRandomGraphs(str));
            }
        } else {
            testSuite = (TestSuite) suite();
        }
        TestRunner.run(testSuite);
    }

    /**
     * Legacy class-lookup helper from the decompiled original, kept for compatibility.
     *
     * @throws NoClassDefFoundError if {@code str} cannot be loaded; the
     *         {@link ClassNotFoundException} is attached as its cause
     */
    static Class class$(String str) {
        try {
            return Class.forName(str);
        } catch (ClassNotFoundException e) {
            // initCause returns Throwable, so the cast is needed for this to compile.
            throw (NoClassDefFoundError) new NoClassDefFoundError(str).initCause(e);
        }
    }
}
