package edu.umass.cs.mallet.grmm.types;

import edu.umass.cs.mallet.grmm.util.Graphs;
import gnu.trove.THashSet;
import java.util.Collections;
import java.util.Set;
import org._3pq.jgrapht.UndirectedGraph;
import org._3pq.jgrapht.alg.ConnectivityInspector;

/* loaded from: input_file:edu/umass/cs/mallet/grmm/types/UndirectedModel.class */
/**
 * A {@link FactorGraph} that additionally records the pairwise "edges" of the model.
 *
 * <p>Whenever a factor passed to {@link #addFactor(Factor)} has a variable set of exactly two
 * variables, that variable set is stored as an edge and exposed via {@link #getEdgeSet()}.
 */
public class UndirectedModel extends FactorGraph {

    /**
     * Variable sets of the two-variable factors seen by {@link #addFactor(Factor)}.
     * Java runs this initializer right after the superclass constructor, which is the same point
     * at which the constructors previously assigned it.
     */
    private final Set edges = new THashSet();

    /** Creates a model with no variables. */
    public UndirectedModel() {
        super();
    }

    /**
     * Creates a model over the given variables.
     *
     * @param vars variables passed straight to the {@link FactorGraph} constructor
     */
    public UndirectedModel(Variable[] vars) {
        super(vars);
    }

    /**
     * Creates a model via {@code FactorGraph(int)}.
     *
     * @param i forwarded unchanged to the superclass; presumably a capacity or size hint —
     *     TODO confirm against FactorGraph
     */
    public UndirectedModel(int i) {
        super(i);
    }

    /**
     * Returns a read-only view of the recorded edges.
     *
     * <p>Each element is the variable set of a two-variable factor that went through
     * {@link #addFactor(Factor)}. The view reflects later additions. NOTE(review): whether other
     * {@code FactorGraph} overloads of {@code addFactor} route through {@code addFactor(Factor)}
     * (and so also record edges) is not visible here — confirm in FactorGraph.
     *
     * @return unmodifiable view of the edge set
     */
    public Set getEdgeSet() {
        return Collections.unmodifiableSet(edges);
    }

    /**
     * Adds a factor to the graph and, when it touches exactly two variables, records its
     * variable set as an edge.
     *
     * @param factor the factor to add
     */
    @Override
    public void addFactor(Factor factor) {
        super.addFactor(factor);
        if (factor.varSet().size() == 2) {
            edges.add(factor.varSet());
        }
    }

    /**
     * Builds a Boltzmann machine over binary variables.
     *
     * <p>One binary {@link Variable} is created per bias. Each variable {@code i} gets a unary
     * table with entries {@code [1.0, exp(biases[i])]}. For every pair {@code i < j} with a
     * non-zero {@code weights[i][j]}, a pairwise factor with entries
     * {@code [1.0, 1.0, 1.0, exp(weights[i][j])]} is added. Only the strict upper triangle of
     * {@code weights} is read; the diagonal and lower triangle are ignored.
     *
     * @param weights square interaction-weight matrix; rows {@code 0..n-2} must have at least
     *     {@code n} entries, where {@code n = weights.length}
     * @param biases one bias per variable
     * @return the constructed model
     * @throws IllegalArgumentException if the number of weight rows differs from the number of
     *     biases, or if a row that must be read is null or too short
     */
    public static UndirectedModel createBoltzmannMachine(double[][] weights, double[] biases) {
        if (weights.length != biases.length) {
            throw new IllegalArgumentException("Number of weights " + weights.length
                    + " not equal to number of biases " + biases.length);
        }
        int n = weights.length;

        // Only entries weights[i][j] with j > i are read, so each row 0..n-2 needs at least n
        // entries. The last row is never consulted, which keeps previously valid inputs valid.
        for (int i = 0; i + 1 < n; i++) {
            if (weights[i] == null) {
                throw new IllegalArgumentException("Row " + i + " of weight matrix is null");
            }
            if (weights[i].length < n) {
                throw new IllegalArgumentException("Row " + i + " of weight matrix has "
                        + weights[i].length + " entries; expected at least " + n);
            }
        }

        Variable[] vars = new Variable[n];
        for (int i = 0; i < n; i++) {
            vars[i] = new Variable(2);
        }

        UndirectedModel model = new UndirectedModel(vars);
        for (int i = 0; i < n; i++) {
            // Unary table: entry 0 is 1.0, entry 1 is exp(bias).
            model.addFactor(new TableFactor(vars[i], new double[] {1.0, Math.exp(biases[i])}));
            for (int j = i + 1; j < n; j++) {
                // Zero weights add no factor (and hence no edge).
                if (weights[i][j] != 0.0) {
                    // Only the last table entry differs from 1.0; presumably that is the
                    // (1, 1) joint state — TODO confirm TableFactor's entry ordering.
                    model.addFactor(vars[i], vars[j],
                            new double[] {1.0, 1.0, 1.0, Math.exp(weights[i][j])});
                }
            }
        }
        return model;
    }

    /**
     * Reports whether a path joins two variables in this model's graph.
     *
     * <p>The graph is rebuilt with {@code Graphs.mdlToGraph(this)} on every call.
     *
     * @param variable first variable
     * @param variable2 second variable
     * @return {@code true} if both variables are vertices of the graph and a path connects them;
     *     {@code false} if either is absent or no path exists
     */
    public boolean isConnected(Variable variable, Variable variable2) {
        UndirectedGraph graph = Graphs.mdlToGraph(this);
        if (!graph.containsVertex(variable) || !graph.containsVertex(variable2)) {
            return false;
        }
        return new ConnectivityInspector(graph).pathExists(variable, variable2);
    }
}
