package edu.umass.cs.mallet.grmm.inference;

import edu.umass.cs.mallet.base.util.MalletLogger;
import edu.umass.cs.mallet.grmm.inference.AbstractBeliefPropagation;
import edu.umass.cs.mallet.grmm.types.Assignment;
import edu.umass.cs.mallet.grmm.types.DirectedModel;
import edu.umass.cs.mallet.grmm.types.Factor;
import edu.umass.cs.mallet.grmm.types.FactorGraph;
import edu.umass.cs.mallet.grmm.types.TableFactor;
import edu.umass.cs.mallet.grmm.types.Tree;
import edu.umass.cs.mallet.grmm.types.VarSet;
import edu.umass.cs.mallet.grmm.types.Variable;
import gnu.trove.THashMap;
import gnu.trove.THashSet;
import gnu.trove.TIntObjectHashMap;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Set;
import java.util.logging.Logger;
import org._3pq.jgrapht.Edge;
import org._3pq.jgrapht.Graph;
import org._3pq.jgrapht.graph.SimpleGraph;
import org._3pq.jgrapht.traverse.BreadthFirstIterator;

/* loaded from: input_file:WEB-INF/lib/mallet-0.1.3.jar:edu/umass/cs/mallet/grmm/inference/TRP.class */
public class TRP extends AbstractBeliefPropagation {
    /** Class logger; assigned in the static initializer via MalletLogger.getLogger. */
    private static Logger logger;
    /** Not read anywhere in this file. */
    private static final boolean reportSpanningTrees = false;
    /** Supplies the tree propagated on each iteration; initForGraph installs an AlmostRandomTreeFactory if this is null. */
    private TreeFactory factory;
    /** Decides when computeMarginals stops iterating; initForGraph installs a DefaultConvergenceTerminator if this is null. */
    private TerminationCondition terminator;
    /** Used by AlmostRandomTreeFactory to shuffle factors; replaced by setRandomSeed. */
    private Random random;
    /** Per-factor touch counts keyed by the factor's index in mdlCurrent (values are Integers); rebuilt by initForGraph. */
    private transient TIntObjectHashMap factorTouched;
    /** Value returned by isConverged(); reset to false by initForGraph. */
    private transient boolean hasConverged;
    /** Iteration count of the most recent computeMarginals run; returned by iterationsUsed(). */
    private transient int iterUsed;
    private static final long serialVersionUID = 1;
    /** Cached Class object for TRP (pre-Java-5 class-literal idiom); filled in by the static initializer. */
    static Class class$edu$umass$cs$mallet$grmm$inference$TRP;
    /**
     * True when assertions are disabled for this class; computed in the static initializer.
     * NOTE(review): decompiler-materialized; using the `assert` keyword in this class would likely
     * clash with this explicitly declared name, so checks use this flag directly.
     */
    static final boolean $assertionsDisabled;

    /* renamed from: edu.umass.cs.mallet.grmm.inference.TRP$1, reason: invalid class name */
    /**
     * Synthetic marker type produced by the compiler/decompiler. Its only use in this file is as the
     * parameter type of SimpleUnionFind's non-private constructor.
     */
    static class AnonymousClass1 {
    }

    /**
     * Default {@link TreeFactory}: builds a random spanning tree (or forest) of the factor graph,
     * preferring factors that no earlier tree has used so every factor is eventually covered.
     *
     * Each call shuffles the factors with the owning TRP's Random. Factors are then added greedily
     * whenever they do not close a cycle. Every chosen factor is marked as touched on the owner, and
     * the resulting bipartite variable/factor graph is converted into a rooted Tree.
     */
    public class AlmostRandomTreeFactory implements TreeFactory {
        /** The TRP whose random source and touch counts this factory uses. */
        private final TRP owner;

        public AlmostRandomTreeFactory(TRP trp) {
            this.owner = trp;
        }

        /**
         * Returns the next spanning tree for {@code factorGraph}.
         *
         * @throws RuntimeException if the tree cannot be built; checked exceptions from
         *     graphToTree are wrapped, unchecked ones propagate unchanged
         */
        @Override // edu.umass.cs.mallet.grmm.inference.TRP.TreeFactory
        public Tree nextTree(FactorGraph factorGraph) {
            List chosen = chooseAcyclicFactors(factorGraph);
            for (Iterator it = chosen.iterator(); it.hasNext();) {
                this.owner.touchFactor((Factor) it.next());
            }
            SimpleGraph graph = buildBipartiteGraph(factorGraph, chosen);
            try {
                return TRP.graphToTree(graph);
            } catch (RuntimeException e) {
                throw e;
            } catch (Exception e) {
                // graphToTree declares a checked Exception; callers of nextTree expect unchecked only.
                throw new RuntimeException(e);
            }
        }

        /**
         * Greedily selects factors, in shuffled order, whose variables are not already connected.
         * Untouched factors are considered first; the remaining ones fill in afterwards.
         */
        private List chooseAcyclicFactors(FactorGraph factorGraph) {
            List candidates = new ArrayList(factorGraph.factors());
            Collections.shuffle(candidates, this.owner.random);
            SimpleUnionFind components = new SimpleUnionFind();
            List chosen = new ArrayList(factorGraph.numVariables());
            // Collect skipped factors instead of Iterator.remove() on an ArrayList (quadratic).
            List leftovers = new ArrayList();

            for (Iterator it = candidates.iterator(); it.hasNext();) {
                Factor factor = (Factor) it.next();
                if (!this.owner.isFactorTouched(factor) && components.noPairConnected(factor.varSet())) {
                    chosen.add(factor);
                    components.unionAll(factor);
                } else {
                    leftovers.add(factor);
                }
            }
            for (Iterator it = leftovers.iterator(); it.hasNext();) {
                Factor factor = (Factor) it.next();
                if (components.noPairConnected(factor.varSet())) {
                    chosen.add(factor);
                    components.unionAll(factor);
                }
            }
            return chosen;
        }

        /** Builds the undirected graph with every variable as a vertex plus one vertex per chosen factor. */
        private SimpleGraph buildBipartiteGraph(FactorGraph factorGraph, List factors) {
            SimpleGraph graph = new SimpleGraph();
            for (Iterator it = factorGraph.variablesIterator(); it.hasNext();) {
                graph.addVertex((Variable) it.next());
            }
            for (Iterator it = factors.iterator(); it.hasNext();) {
                Factor factor = (Factor) it.next();
                graph.addVertex(factor);
                for (Iterator vars = factor.varSet().iterator(); vars.hasNext();) {
                    graph.addEdge(factor, (Variable) vars.next());
                }
            }
            return graph;
        }
    }

    /**
     * Stops TRP once {@code trp.hasConverged(delta)} reports convergence.
     *
     * Every check also snapshots the current messages via {@code trp.copyOldMessages()} for the next
     * comparison. It records the outcome on the TRP so that {@code isConverged()} reflects it.
     */
    public static class ConvergenceTerminator implements TerminationCondition {
        /** Threshold handed to {@code TRP.hasConverged(double)}; 0.01 unless specified. */
        double delta;

        public ConvergenceTerminator() {
            this(0.01d);
        }

        public ConvergenceTerminator(double d) {
            this.delta = d;
        }

        /** No per-run state besides delta, so there is nothing to reset. */
        @Override // edu.umass.cs.mallet.grmm.inference.TRP.TerminationCondition
        public void reset() {
        }

        /** Returns false once messages have converged within {@code delta}. */
        @Override // edu.umass.cs.mallet.grmm.inference.TRP.TerminationCondition
        public boolean shouldContinue(TRP trp) {
            boolean converged = trp.hasConverged(this.delta);
            // Record the result; otherwise nothing ever sets this flag and isConverged() is always false.
            trp.hasConverged = converged;
            trp.copyOldMessages();
            return !converged;
        }

        @Override // edu.umass.cs.mallet.grmm.inference.TRP.TerminationCondition
        public Object clone() throws CloneNotSupportedException {
            return super.clone();
        }
    }

    /**
     * Default stopping rule. It keeps iterating while some factor has never appeared in a tree, and
     * otherwise until messages converge. It never runs past an iteration cap. Warnings are logged
     * when the cap is hit.
     */
    public static class DefaultConvergenceTerminator implements TerminationCondition {
        ConvergenceTerminator cterminator;
        IterationTerminator iterminator;
        String msg;

        /** Convergence threshold 0.001, at most 1000 iterations. */
        public DefaultConvergenceTerminator() {
            this(0.001d, 1000);
        }

        public DefaultConvergenceTerminator(double delta, int maxIterations) {
            this.iterminator = new IterationTerminator(maxIterations);
            this.cterminator = new ConvergenceTerminator(delta);
            this.msg = "***TRP quitting: over " + maxIterations + " iterations";
        }

        @Override // edu.umass.cs.mallet.grmm.inference.TRP.TerminationCondition
        public void reset() {
            this.iterminator.reset();
            this.cterminator.reset();
        }

        @Override // edu.umass.cs.mallet.grmm.inference.TRP.TerminationCondition
        public boolean shouldContinue(TRP trp) {
            boolean someFactorUnused = !trp.allEdgesTouched();
            if (!this.iterminator.shouldContinue(trp)) {
                TRP.logger.warning(this.msg);
                if (someFactorUnused) {
                    TRP.logger.warning("***TRP warning: Not all edges used!");
                }
                return false;
            }
            // Convergence is only consulted once every factor has been in at least one tree.
            return someFactorUnused || this.cterminator.shouldContinue(trp);
        }

        /** Deep-copies both sub-terminators so the copy keeps independent counters. */
        @Override // edu.umass.cs.mallet.grmm.inference.TRP.TerminationCondition
        public Object clone() throws CloneNotSupportedException {
            DefaultConvergenceTerminator copy = (DefaultConvergenceTerminator) super.clone();
            copy.cterminator = (ConvergenceTerminator) this.cterminator.clone();
            copy.iterminator = (IterationTerminator) this.iterminator.clone();
            return copy;
        }
    }

    /**
     * Allows exactly {@code max} iterations: {@code shouldContinue} returns true for its first
     * {@code max} calls after a reset and false afterwards.
     */
    public static class IterationTerminator implements TerminationCondition {
        /** Number of shouldContinue calls since the last reset. */
        int current;
        int max;

        @Override // edu.umass.cs.mallet.grmm.inference.TRP.TerminationCondition
        public void reset() {
            this.current = 0;
        }

        public IterationTerminator(int i) {
            this.max = i;
            reset();
        }

        @Override // edu.umass.cs.mallet.grmm.inference.TRP.TerminationCondition
        public boolean shouldContinue(TRP trp) {
            this.current++;
            boolean keepGoing = this.current <= this.max;
            // Log only when actually stopping (previously this fired at current == max, one
            // iteration before the loop really quit).
            if (!keepGoing) {
                TRP.logger.finest("***TRP quitting: Iteration " + this.current + " >= " + this.max);
            }
            return keepGoing;
        }

        @Override // edu.umass.cs.mallet.grmm.inference.TRP.TerminationCondition
        public Object clone() throws CloneNotSupportedException {
            return super.clone();
        }
    }

    /**
     * Minimal union-find over arbitrary objects (variables and factors). AlmostRandomTreeFactory
     * uses it to keep its chosen factor set acyclic.
     *
     * Each component is represented by a shared Set object. Two elements are connected iff
     * findSet returns the identical Set for both.
     */
    private static class SimpleUnionFind {
        /** Maps every element seen so far to the Set object of its component. */
        private final Map obj2set;

        private SimpleUnionFind() {
            this.obj2set = new THashMap();
        }

        /** Returns obj's component, creating and registering a singleton component on first sight. */
        private Set findSet(Object obj) {
            Set set = (Set) this.obj2set.get(obj);
            if (set == null) {
                set = new THashSet();
                set.add(obj);
                this.obj2set.put(obj, set);
            }
            return set;
        }

        /**
         * Merges the components of obj and obj2. The smaller component is folded into the larger
         * one (union by size), so total relabeling work is O(n log n) rather than O(n^2).
         */
        private void union(Object obj, Object obj2) {
            Set first = findSet(obj);
            Set second = findSet(obj2);
            if (first == second) {
                return;
            }
            Set larger = first.size() >= second.size() ? first : second;
            Set smaller = (larger == first) ? second : first;
            larger.addAll(smaller);
            for (Iterator it = smaller.iterator(); it.hasNext();) {
                this.obj2set.put(it.next(), larger);
            }
        }

        /** Returns true if no two variables of {@code varSet} are already in the same component. */
        public boolean noPairConnected(VarSet varSet) {
            for (int i = 0; i < varSet.size(); i++) {
                for (int i2 = i + 1; i2 < varSet.size(); i2++) {
                    if (findSet(varSet.get(i)) == findSet(varSet.get(i2))) {
                        return false;
                    }
                }
            }
            return true;
        }

        /** Connects {@code factor} with each of its variables. */
        public void unionAll(Factor factor) {
            VarSet varSet = factor.varSet();
            for (int i = 0; i < varSet.size(); i++) {
                union(varSet.get(i), factor);
            }
        }

        /** Synthetic-access constructor kept for callers that pass the marker type. */
        SimpleUnionFind(AnonymousClass1 anonymousClass1) {
            this();
        }
    }

    /**
     * Strategy that decides when the TRP main loop stops. Implementations may hold state. TRP
     * resets them in initForGraph and clones them when the TRP is cloned.
     */
    public interface TerminationCondition extends Cloneable, Serializable {
        /** Clears any per-run state before inference on a new graph. */
        void reset();

        /** Checked before each iteration; returns true if TRP should run another one. */
        boolean shouldContinue(TRP trp);

        /** Returns an independent copy of this condition. */
        Object clone() throws CloneNotSupportedException;
    }

    /** Supplies the tree over which TRP propagates messages on each iteration. */
    public interface TreeFactory extends Serializable {
        /** Returns the tree to use for the next iteration on {@code factorGraph}. */
        Tree nextTree(FactorGraph factorGraph);
    }

    /**
     * A {@link TreeFactory} that cycles through a fixed list of trees in round-robin order, ignoring
     * the factor graph argument.
     *
     * The position is tracked with an int rather than an Iterator. TreeFactory is Serializable, and
     * list iterators are not, so holding one made any TRP using this factory unserializable.
     */
    public static class TreeListFactory implements TreeFactory {
        /** Trees to hand out; the List constructor keeps the caller's list (not a copy). */
        private final List lst;
        /** Index of the tree returned by the next call. */
        private int next;

        public TreeListFactory(List list) {
            if (list == null) {
                throw new NullPointerException("list");
            }
            this.lst = list;
            this.next = 0;
        }

        public TreeListFactory(Tree[] treeArr) {
            this(new ArrayList(Arrays.asList(treeArr)));
        }

        /**
         * Returns the next tree, wrapping around to the first after the last.
         *
         * @throws java.util.NoSuchElementException if the list is empty
         */
        @Override // edu.umass.cs.mallet.grmm.inference.TRP.TreeFactory
        public Tree nextTree(FactorGraph factorGraph) {
            if (this.lst.isEmpty()) {
                throw new java.util.NoSuchElementException("TreeListFactory has no trees to return");
            }
            if (this.next >= this.lst.size()) {
                this.next = 0;
            }
            return (Tree) this.lst.get(this.next++);
        }
    }

    /** Creates a TRP that installs the default tree factory and termination condition on first use. */
    public TRP() {
        this((TreeFactory) null, (TerminationCondition) null);
    }

    /** Creates a TRP using {@code treeFactory} and the default termination condition. */
    public TRP(TreeFactory treeFactory) {
        this(treeFactory, (TerminationCondition) null);
    }

    /** Creates a TRP using the default tree factory and {@code terminationCondition}. */
    public TRP(TerminationCondition terminationCondition) {
        this((TreeFactory) null, terminationCondition);
    }

    /**
     * Creates a TRP. Either argument may be null, in which case initForGraph supplies a default.
     */
    public TRP(TreeFactory treeFactory, TerminationCondition terminationCondition) {
        random = new Random();
        iterUsed = 0;
        factory = treeFactory;
        terminator = terminationCondition;
    }

    /** Returns a default TRP whose messager is an AbstractBeliefPropagation.MaxProductMessageStrategy. */
    public static TRP createForMaxProduct() {
        TRP maxProduct = new TRP();
        AbstractBeliefPropagation.MaxProductMessageStrategy strategy =
            new AbstractBeliefPropagation.MaxProductMessageStrategy();
        maxProduct.setMessager(strategy);
        return maxProduct;
    }

    /**
     * Replaces the termination condition used by later inference runs.
     *
     * @return this TRP, for call chaining
     */
    public TRP setTerminator(TerminationCondition terminationCondition) {
        terminator = terminationCondition;
        return this;
    }

    /** Replaces the random source (used when shuffling factors to build trees) with one seeded by {@code j}. */
    public void setRandomSeed(long j) {
        Random seeded = new Random(j);
        random = seeded;
    }

    /** Returns the convergence flag; initForGraph resets it to false at the start of each run. */
    public boolean isConverged() {
        return hasConverged;
    }

    /** Returns the number of iterations used by the most recent computeMarginals call. */
    public int iterationsUsed() {
        return iterUsed;
    }

    /**
     * Prepares per-graph state. Clears the factor touch counts and the convergence flag. Installs
     * the default tree factory and termination condition when none were supplied; otherwise resets
     * the supplied terminator.
     * NOTE(review): the decompiler reports this method was originally protected.
     */
    @Override // edu.umass.cs.mallet.grmm.inference.AbstractBeliefPropagation
    public void initForGraph(FactorGraph factorGraph) {
        super.initForGraph(factorGraph);
        // Capacity hint only; keys are factor indices.
        this.factorTouched = new TIntObjectHashMap(factorGraph.numVariables());
        this.hasConverged = false;
        this.factory = (this.factory != null) ? this.factory : new AlmostRandomTreeFactory(this);
        if (this.terminator != null) {
            this.terminator.reset();
        } else {
            this.terminator = new DefaultConvergenceTerminator();
        }
    }

    /**
     * Converts an acyclic undirected graph into a rooted {@link Tree} by breadth-first traversal.
     *
     * The root is the first vertex (in vertexSet order) that has at least one edge. The original
     * code always used the first vertex. When that vertex was isolated (e.g. a variable with no
     * factor), the tree was a single node and no messages were ever propagated. If no vertex has
     * an edge, the first vertex is used. Only the root's connected component ends up in the tree.
     * NOTE(review): the decompiler reports this method was originally private.
     *
     * @throws RuntimeException if the graph has no vertices
     */
    public static Tree graphToTree(Graph graph) throws Exception {
        if (graph.vertexSet().size() <= 0) {
            throw new RuntimeException("Empty graph.");
        }
        Object root = chooseRoot(graph);
        Tree tree = new Tree();
        tree.add(root);
        BreadthFirstIterator bfs = new BreadthFirstIterator(graph, root);
        while (bfs.hasNext()) {
            Object node = bfs.next();
            for (Iterator edges = graph.edgesOf(node).iterator(); edges.hasNext();) {
                Object neighbor = ((Edge) edges.next()).oppositeVertex(node);
                // Every neighbor except the BFS parent becomes a child of this node.
                if (tree.getParent(node) != neighbor) {
                    tree.addNode(node, neighbor);
                    if (!$assertionsDisabled && tree.getParent(neighbor) != node) {
                        throw new AssertionError();
                    }
                }
            }
        }
        return tree;
    }

    /** Returns the first vertex with at least one incident edge, or the first vertex if none has one. */
    private static Object chooseRoot(Graph graph) {
        Iterator vertices = graph.vertexSet().iterator();
        Object first = vertices.next();
        if (!graph.edgesOf(first).isEmpty()) {
            return first;
        }
        while (vertices.hasNext()) {
            Object candidate = vertices.next();
            if (!graph.edgesOf(candidate).isEmpty()) {
                return candidate;
            }
        }
        return first;
    }

    /**
     * Runs TRP on {@code factorGraph}. Each iteration asks the tree factory for a tree and
     * propagates messages over it, until the termination condition says stop. Stores the
     * iteration count for iterationsUsed().
     */
    @Override // edu.umass.cs.mallet.grmm.inference.AbstractInferencer, edu.umass.cs.mallet.grmm.inference.Inferencer
    public void computeMarginals(FactorGraph factorGraph) {
        resetMessagesSentAtStart();
        initForGraph(factorGraph);
        int iterations;
        for (iterations = 0; this.terminator.shouldContinue(this); iterations++) {
            logger.finer("TRP iteration " + iterations);
            Tree tree = this.factory.nextTree(factorGraph);
            propagate(tree);
        }
        this.iterUsed = iterations;
        logger.info("TRP used " + iterations + " iterations.");
        doneWithGraph(factorGraph);
    }

    /** One sweep over {@code tree}: child-to-parent messages first, then parent-to-child messages. */
    private void propagate(Tree tree) {
        Object top = tree.getRoot();
        lambdaPropagation(tree, null, top);
        piPropagation(tree, top);
    }

    /**
     * Post-order pass: recursively handles the subtree under {@code node}, then sends node's message
     * to {@code parent}. The root is passed with a null parent and sends nothing.
     */
    private void lambdaPropagation(Tree tree, Object parent, Object node) {
        logger.finer("TRP lambdaPropagation from " + parent);
        for (Object child : tree.getChildren(node)) {
            lambdaPropagation(tree, node, child);
        }
        if (parent == null) {
            return;
        }
        sendMessage(this.mdlCurrent, node, parent);
    }

    /** Pre-order pass: sends {@code node}'s message to each child, then recurses into that child. */
    private void piPropagation(Tree tree, Object node) {
        logger.finer("TRP piPropagation from " + node);
        Iterator children = tree.getChildren(node).iterator();
        while (children.hasNext()) {
            Object child = children.next();
            sendMessage(this.mdlCurrent, node, child);
            piPropagation(tree, child);
        }
    }

    /**
     * Dispatches a tree edge to the typed sendMessage overload: factor-to-variable or
     * variable-to-factor. Nodes of any other type are silently ignored.
     */
    private void sendMessage(FactorGraph factorGraph, Object from, Object to) {
        if (from instanceof Factor) {
            Factor sourceFactor = (Factor) from;
            sendMessage(factorGraph, sourceFactor, (Variable) to);
            return;
        }
        if (from instanceof Variable) {
            Variable sourceVariable = (Variable) from;
            sendMessage(factorGraph, sourceVariable, (Factor) to);
        }
    }

    /**
     * Returns true if every factor in the current graph has been included in at least one tree.
     * NOTE(review): the decompiler reports this method was originally private.
     */
    public boolean allEdgesTouched() {
        for (Iterator factors = this.mdlCurrent.factorsIterator(); factors.hasNext();) {
            Factor factor = (Factor) factors.next();
            int idx = this.mdlCurrent.getIndex(factor);
            boolean untouched = getNumTouches(idx) == 0;
            if (untouched) {
                logger.finest("***TRP continuing: factor " + idx + " not touched.");
                return false;
            }
        }
        return true;
    }

    /**
     * Records that {@code factor} was included in a tree.
     * NOTE(review): the decompiler reports this method was originally private.
     */
    public void touchFactor(Factor factor) {
        incrementTouches(this.mdlCurrent.getIndex(factor));
    }

    /**
     * Returns true if {@code factor} has been included in at least one tree since initForGraph.
     * NOTE(review): the decompiler reports this method was originally private.
     */
    public boolean isFactorTouched(Factor factor) {
        return getNumTouches(this.mdlCurrent.getIndex(factor)) > 0;
    }

    /** Returns the touch count for the factor with index {@code i}, or 0 if it was never touched. */
    private int getNumTouches(int i) {
        Integer num = (Integer) this.factorTouched.get(i);
        return (num == null) ? 0 : num.intValue();
    }

    /** Increments the touch count for the factor with index {@code i}. */
    private void incrementTouches(int i) {
        // Integer.valueOf rather than the deprecated Integer(int) constructor.
        this.factorTouched.put(i, Integer.valueOf(getNumTouches(i) + 1));
    }

    /**
     * Not supported.
     *
     * @throws UnsupportedOperationException always
     */
    public Factor query(DirectedModel directedModel, Variable variable) {
        String reason = "GRMM doesn't yet do directed models.";
        throw new UnsupportedOperationException(reason);
    }

    /**
     * Builds an Assignment that gives each variable of the current graph the argmax of its marginal.
     * Marginals must be TableFactors; the cast fails otherwise.
     * NOTE(review): assumes computeMarginals has run and mdlCurrent still refers to that graph — confirm
     * what doneWithGraph does to mdlCurrent.
     */
    public Assignment bestAssignment() {
        int numVars = this.mdlCurrent.numVariables();
        int[] outcomes = new int[numVars];
        for (int v = 0; v < numVars; v++) {
            TableFactor marginal = (TableFactor) lookupMarginal(this.mdlCurrent.get(v));
            outcomes[v] = marginal.argmax();
        }
        return new Assignment(this.mdlCurrent, outcomes);
    }

    /**
     * Returns a shallow copy of this TRP with its own copy of the termination condition.
     *
     * If the tree factory is an AlmostRandomTreeFactory, it is bound to the TRP that created it.
     * Sharing it would make the copy read and update the ORIGINAL's touch counts and current graph.
     * The copy therefore gets a fresh factory bound to itself. Other factories are shared as before.
     *
     * @throws RuntimeException wrapping CloneNotSupportedException if the superclass refuses to clone
     */
    public Object clone() {
        try {
            TRP copy = (TRP) super.clone();
            if (this.terminator != null) {
                copy.terminator = (TerminationCondition) this.terminator.clone();
            }
            if (this.factory instanceof AlmostRandomTreeFactory) {
                copy.factory = copy.new AlmostRandomTreeFactory(copy);
            }
            return copy;
        } catch (CloneNotSupportedException e) {
            throw new RuntimeException(e);
        }
    }

    /** Serialization hook: writes non-transient fields only (per-run state is transient). */
    private void writeObject(ObjectOutputStream out) throws IOException {
        out.defaultWriteObject();
    }

    /** Serialization hook: restores non-transient fields; transient per-run state keeps default values. */
    private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
        in.defaultReadObject();
    }

    /**
     * Loads the class named {@code str} (pre-Java-5 class-literal helper).
     *
     * @throws NoClassDefFoundError if the class cannot be found; its cause is the ClassNotFoundException
     */
    static Class class$(String str) {
        try {
            return Class.forName(str);
        } catch (ClassNotFoundException e) {
            // initCause returns Throwable; without the cast this throws an undeclared checked Throwable
            // and does not compile.
            throw (NoClassDefFoundError) new NoClassDefFoundError(str).initCause(e);
        }
    }

    /** Resolves (and caches) TRP's Class, then derives the assertion flag and the class logger from it. */
    static {
        if (class$edu$umass$cs$mallet$grmm$inference$TRP == null) {
            class$edu$umass$cs$mallet$grmm$inference$TRP = class$("edu.umass.cs.mallet.grmm.inference.TRP");
        }
        // The cache is populated at this point, so both derived values come from the same Class.
        Class self = class$edu$umass$cs$mallet$grmm$inference$TRP;
        $assertionsDisabled = !self.desiredAssertionStatus();
        logger = MalletLogger.getLogger(self.getName());
    }
}
