package edu.umass.cs.mallet.grmm.inference;

import edu.umass.cs.mallet.grmm.types.Assignment;
import edu.umass.cs.mallet.grmm.types.Factor;
import edu.umass.cs.mallet.grmm.types.TableFactor;
import edu.umass.cs.mallet.grmm.types.Tree;
import edu.umass.cs.mallet.grmm.types.VarSet;
import edu.umass.cs.mallet.grmm.types.Variable;
import gnu.trove.THashSet;
import gnu.trove.TIntObjectHashMap;
import gnu.trove.TIntObjectIterator;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Set;

/* loaded from: input_file:WEB-INF/lib/mallet-0.1.3.jar:edu/umass/cs/mallet/grmm/inference/JunctionTree.class */
public class JunctionTree extends Tree {

    /** Capacity of the cluster-potential array; reused by {@link #decompact()}. */
    private final int numNodes;

    /**
     * Separator sets keyed by {@link #hashIdxIdx(int, int)} of the two adjacent clusters'
     * vertex indices. Values are {@link Sepset} instances.
     */
    private final TIntObjectHashMap sepsets = new TIntObjectHashMap();

    /**
     * Cluster potentials, indexed by the cluster's vertex index ({@code lookupIndex}).
     * Individual entries may be null until set. The whole array is null between
     * {@link #compact()} and {@link #decompact()}.
     */
    private Factor[] cpfs;

    /** The separator on an edge between two clusters: their shared variables and a potential over them. */
    public static class Sepset {
        /** Variables shared by the two adjacent clusters. */
        Set set;
        /** Potential over {@link #set}; replaced by clearCPFs() and setSepsetPot(). */
        Factor ptl;

        Sepset(Set set, Factor ptl) {
            this.set = set;
            this.ptl = ptl;
        }
    }

    /**
     * Creates an empty junction tree with room for {@code numNodes} cluster potentials.
     *
     * @param numNodes number of clusters; every cluster's vertex index must be below this value,
     *                 because indices are used directly as array subscripts
     */
    public JunctionTree(int numNodes) {
        this.numNodes = numNodes;
        this.cpfs = new Factor[numNodes];
    }

    /**
     * Adds an edge between two clusters (via {@link Tree#addNode}) and creates its separator.
     * <p>
     * The separator holds the intersection of the two clusters' variables and a fresh
     * {@link TableFactor} over that intersection.
     *
     * @param first  a {@link VarSet} cluster (presumably the parent, per {@code Tree.addNode}; confirm)
     * @param second a {@link VarSet} cluster (presumably the child)
     */
    @Override // edu.umass.cs.mallet.grmm.types.Tree
    public void addNode(Object first, Object second) {
        super.addNode(first, second);
        VarSet a = (VarSet) first;
        VarSet b = (VarSet) second;
        // Keep the static type Set so the same TableFactor constructor overload is selected.
        Set shared = a.intersection(b);
        putSepset(lookupIndex(a), lookupIndex(b), new Sepset(shared, new TableFactor(shared)));
    }

    /**
     * Packs an unordered pair of vertex indices into a single int key, so (i, j) and (j, i)
     * address the same separator. The smaller index occupies the high 16 bits. The key is
     * collision-free only while both indices are in [0, 65536).
     */
    private int hashIdxIdx(int i, int j) {
        assert i >= 0 && i < 65536 && j >= 0 && j < 65536
                : "vertex index out of range for sepset key: " + i + ", " + j;
        int lo = Math.min(i, j);
        int hi = Math.max(i, j);
        return (lo << 16) | hi;
    }

    /** Stores the separator for the edge between vertex indices {@code i} and {@code j}. */
    private void putSepset(int i, int j, Sepset sepset) {
        this.sepsets.put(hashIdxIdx(i, j), sepset);
    }

    /** Returns the separator between vertex indices {@code i} and {@code j}, or null if none was stored. */
    private Sepset getSepset(int i, int j) {
        return (Sepset) this.sepsets.get(hashIdxIdx(i, j));
    }

    /** Returns the potential currently attached to {@code varSet}'s cluster (may be null). */
    public Factor getCPF(VarSet varSet) {
        return this.cpfs[lookupIndex(varSet)];
    }

    /** Replaces the potential attached to {@code varSet}'s cluster. */
    public void setCPF(VarSet varSet, Factor factor) {
        this.cpfs[lookupIndex(varSet)] = factor;
    }

    /**
     * Resets every cluster potential to a new {@link TableFactor} over that cluster's variables,
     * and every separator potential to a new {@link TableFactor} over the separator's variables.
     */
    public void clearCPFs() {
        for (int i = 0; i < this.cpfs.length; i++) {
            this.cpfs[i] = new TableFactor((VarSet) lookupVertex(i));
        }
        TIntObjectIterator it = this.sepsets.iterator();
        while (it.hasNext()) {
            it.advance();
            Sepset sepset = (Sepset) it.value();
            sepset.ptl = new TableFactor(sepset.set);
        }
    }

    /**
     * Returns a new set containing every separator potential.
     * Because the result is a Set, factors that are equal under {@code equals} appear once.
     */
    public Set sepsetPotentials() {
        THashSet result = new THashSet();
        TIntObjectIterator it = this.sepsets.iterator();
        while (it.hasNext()) {
            it.advance();
            result.add(((Sepset) it.value()).ptl);
        }
        return result;
    }

    /**
     * Replaces the separator potential on the edge between two clusters.
     * Throws NullPointerException if the two clusters share no edge.
     */
    public void setSepsetPot(Factor factor, VarSet varSet, VarSet varSet2) {
        getSepset(lookupIndex(varSet), lookupIndex(varSet2)).ptl = factor;
    }

    /**
     * Returns the separator potential on the edge between two clusters.
     * Throws NullPointerException if the two clusters share no edge.
     */
    public Factor getSepsetPot(VarSet varSet, VarSet varSet2) {
        return getSepset(lookupIndex(varSet), lookupIndex(varSet2)).ptl;
    }

    /** Returns a new set of all non-null cluster potentials. */
    public Collection clusterPotentials() {
        HashSet result = new HashSet();
        for (int i = 0; i < this.cpfs.length; i++) {
            if (this.cpfs[i] != null) {
                result.add(this.cpfs[i]);
            }
        }
        return result;
    }

    /**
     * Returns the variables shared by two adjacent clusters.
     * Throws NullPointerException if the two clusters share no edge.
     */
    public Set getSepset(VarSet varSet, VarSet varSet2) {
        return getSepset(lookupIndex(varSet), lookupIndex(varSet2)).set;
    }

    /**
     * Marginalizes, onto {@code variable}, the potential of the lowest-weight cluster that
     * contains it (see {@link #findParentCluster(Variable)}).
     */
    public Factor lookupMarginal(Variable variable) {
        return getCPF(findParentCluster(variable)).marginalize(variable);
    }

    /**
     * Computes the sum of the log-values of all non-null cluster potentials at
     * {@code assignment}, minus the sum of the log-values of all separator potentials.
     */
    public double lookupLogJoint(Assignment assignment) {
        double total = 0.0d;
        for (int i = 0; i < this.cpfs.length; i++) {
            if (this.cpfs[i] != null) {
                total += this.cpfs[i].logValue(assignment);
            }
        }
        TIntObjectIterator it = this.sepsets.iterator();
        while (it.hasNext()) {
            it.advance();
            total -= ((Sepset) it.value()).ptl.logValue(assignment);
        }
        return total;
    }

    /**
     * Returns the cluster containing {@code variable} with the smallest {@code weight()}.
     * Returns null if no cluster contains it. Ties keep the first cluster in vertex-iteration order.
     */
    public VarSet findParentCluster(Variable variable) {
        int bestWeight = Integer.MAX_VALUE;
        VarSet best = null;
        Iterator vertices = getVerticesIterator();
        while (vertices.hasNext()) {
            VarSet candidate = (VarSet) vertices.next();
            if (candidate.contains(variable) && candidate.weight() < bestWeight) {
                best = candidate;
                bestWeight = candidate.weight();
            }
        }
        return best;
    }

    /**
     * Returns the cluster containing every element of {@code collection} with the smallest
     * {@code weight()}. Returns null if no cluster contains them all. Ties keep the first
     * cluster in vertex-iteration order.
     */
    public VarSet findParentCluster(Collection collection) {
        int bestWeight = Integer.MAX_VALUE;
        VarSet best = null;
        Iterator vertices = getVerticesIterator();
        while (vertices.hasNext()) {
            VarSet candidate = (VarSet) vertices.next();
            if (candidate.containsAll(collection) && candidate.weight() < bestWeight) {
                best = candidate;
                bestWeight = candidate.weight();
            }
        }
        return best;
    }

    /**
     * Returns the cluster whose variables are exactly those in {@code variableArr}
     * (mutual containment, so order and duplicates are ignored), or null if there is none.
     */
    public VarSet findCluster(Variable[] variableArr) {
        List wanted = Arrays.asList(variableArr);
        Iterator vertices = getVerticesIterator();
        while (vertices.hasNext()) {
            VarSet candidate = (VarSet) vertices.next();
            if (candidate.containsAll(wanted) && wanted.containsAll(candidate)) {
                return candidate;
            }
        }
        return null;
    }

    /** Calls {@code normalize()} on every non-null cluster potential and every separator potential. */
    public void normalizeAll() {
        for (int i = 0; i < this.cpfs.length; i++) {
            if (this.cpfs[i] != null) {
                this.cpfs[i].normalize();
            }
        }
        TIntObjectIterator it = this.sepsets.iterator();
        while (it.hasNext()) {
            it.advance();
            ((Sepset) it.value()).ptl.normalize();
        }
    }

    /** Returns the vertex index of {@code varSet} (same as {@code lookupIndex}). */
    int getId(VarSet varSet) {
        return lookupIndex(varSet);
    }

    /** Prints the tree, each non-null cluster potential, and each separator potential to stdout. */
    @Override // edu.umass.cs.mallet.grmm.types.Tree
    public void dump() {
        System.out.println(this);
        System.out.println("Vertex CPFs");
        for (int i = 0; i < this.cpfs.length; i++) {
            if (this.cpfs[i] != null) {
                System.out.println("CPF " + i + " " + this.cpfs[i].dump());
            }
        }
        System.out.println("sepset CPFs");
        TIntObjectIterator it = this.sepsets.iterator();
        while (it.hasNext()) {
            it.advance();
            System.out.println(((Sepset) it.value()).ptl.dump());
        }
        System.out.println("/End JT");
    }

    /**
     * Debugging variant of {@link #lookupLogJoint(Assignment)}. It computes the same quantity
     * and prints the running total after each cluster and separator term.
     *
     * @return the same value {@code lookupLogJoint(assignment)} returns
     */
    public double dumpLogJoint(Assignment assignment) {
        double accum = 0.0d;
        for (int i = 0; i < this.cpfs.length; i++) {
            if (this.cpfs[i] != null) {
                accum += this.cpfs[i].logValue(assignment);
                System.out.println("CPF " + i + " accum = " + accum);
            }
        }
        TIntObjectIterator it = this.sepsets.iterator();
        while (it.hasNext()) {
            it.advance();
            Factor ptl = ((Sepset) it.value()).ptl;
            accum -= ptl.logValue(assignment);
            System.out.println("Sepset " + ptl.varSet() + " accum " + accum);
        }
        return accum;
    }

    /** Returns true if any non-null cluster potential or any separator potential reports NaN. */
    public boolean isNaN() {
        for (int i = 0; i < this.cpfs.length; i++) {
            // Unset clusters are null; skip them, as the other whole-tree methods do.
            if (this.cpfs[i] != null && this.cpfs[i].isNaN()) {
                return true;
            }
        }
        TIntObjectIterator it = this.sepsets.iterator();
        while (it.hasNext()) {
            it.advance();
            if (((Sepset) it.value()).ptl.isNaN()) {
                return true;
            }
        }
        return false;
    }

    /**
     * Returns the sum of the entropies of all non-null cluster potentials minus the sum of the
     * entropies of all separator potentials.
     * <p>
     * This iterates the underlying storage rather than the Sets returned by
     * {@link #clusterPotentials()} and {@link #sepsetPotentials()}. Those Sets would count only
     * once any separators that are the same object or equal under {@code Factor.equals}.
     */
    public double entropy() {
        double total = 0.0d;
        for (int i = 0; i < this.cpfs.length; i++) {
            if (this.cpfs[i] != null) {
                total += this.cpfs[i].entropy();
            }
        }
        TIntObjectIterator it = this.sepsets.iterator();
        while (it.hasNext()) {
            it.advance();
            total -= ((Sepset) it.value()).ptl.entropy();
        }
        return total;
    }

    /** Re-allocates the cluster-potential array and resets all potentials via {@link #clearCPFs()}. */
    public void decompact() {
        this.cpfs = new Factor[this.numNodes];
        clearCPFs();
    }

    /**
     * Drops all cluster potentials to release memory. Until {@link #decompact()} is called,
     * every method that reads cluster potentials throws NullPointerException.
     */
    public void compact() {
        this.cpfs = null;
    }
}
