package edu.umass.cs.mallet.grmm.inference.gbp;

import edu.umass.cs.mallet.base.util.MalletLogger;
import edu.umass.cs.mallet.base.util.Timing;
import edu.umass.cs.mallet.grmm.inference.AbstractInferencer;
import edu.umass.cs.mallet.grmm.types.Assignment;
import edu.umass.cs.mallet.grmm.types.AssignmentIterator;
import edu.umass.cs.mallet.grmm.types.DiscreteFactor;
import edu.umass.cs.mallet.grmm.types.Factor;
import edu.umass.cs.mallet.grmm.types.FactorGraph;
import edu.umass.cs.mallet.grmm.types.LogTableFactor;
import edu.umass.cs.mallet.grmm.types.VarSet;
import edu.umass.cs.mallet.grmm.types.Variable;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.Iterator;
import java.util.logging.Level;
import java.util.logging.Logger;
import org.apache.lucene.analysis.shingle.ShingleFilter;

/* loaded from: input_file:WEB-INF/lib/mallet-0.1.3.jar:edu/umass/cs/mallet/grmm/inference/gbp/ParentChildGBP.class */
/**
 * Generalized belief propagation (GBP) using parent-to-child messages on a region graph.
 *
 * <p>{@link #computeMarginals(FactorGraph)} asks the configured {@link RegionGraphGenerator} to
 * build a {@link RegionGraph} for the model. It then repeatedly sends one message along every edge
 * of that graph through the configured {@link MessageStrategy}. It stops once every message differs
 * from its previous value by less than {@link #THRESHOLD} (per {@code DiscreteFactor.almostEquals}),
 * or after {@link #MAX_ITER} sweeps, whichever comes first. Marginals are read off the belief of a
 * region that contains the queried variables.
 *
 * <p>Inertia is enabled by default. When it is on, each sweep's messages are combined with the
 * previous sweep's messages via {@code MessageStrategy.averageMessages} using
 * {@link #getInertiaWeight()} (default 0.5). This is presumably a damped weighted average; see the
 * strategy implementation to confirm.
 *
 * <p>Instances keep mutable per-run state (region graph and message arrays) and perform no
 * synchronization.
 */
public class ParentChildGBP extends AbstractInferencer {

    private static final Logger logger = MalletLogger.getLogger(ParentChildGBP.class.getName());

    /** Convergence tolerance passed to {@code almostEquals} when comparing successive messages. */
    private static final double THRESHOLD = 0.001d;

    /** Maximum number of message-passing sweeps before giving up. */
    private static final int MAX_ITER = 500;

    /**
     * Orders edges by the number of variables in their child ("to") region, smallest first.
     * Used with {@link Collections#sort}, which is stable, so ties keep the region graph's order.
     */
    private static final Comparator<RegionEdge> SMALLEST_CHILD_FIRST = new Comparator<RegionEdge>() {
        @Override
        public int compare(RegionEdge e1, RegionEdge e2) {
            int n1 = e1.to.vars.size();
            int n2 = e2.to.vars.size();
            return (n1 < n2) ? -1 : ((n1 == n2) ? 0 : 1);
        }
    };

    private final RegionGraphGenerator regioner;
    private final MessageStrategy sender;
    private boolean useInertia = true;
    private double inertiaWeight = 0.5d;

    // State produced by the most recent computeMarginals() call; null until it has run.
    private MessageArray oldMessages;
    private MessageArray newMessages;
    private RegionGraph rg;
    private FactorGraph mdl;

    /**
     * Creates an inferencer that uses a {@link FullMessageStrategy}.
     *
     * @param regionGraphGenerator builds the region graph for each model passed to
     *     {@link #computeMarginals(FactorGraph)}
     */
    public ParentChildGBP(RegionGraphGenerator regionGraphGenerator) {
        this(regionGraphGenerator, new FullMessageStrategy());
    }

    /**
     * Creates an inferencer with an explicit region-graph generator and message strategy.
     *
     * @param regionGraphGenerator builds the region graph for each model
     * @param messageStrategy computes and stores the messages sent along region-graph edges
     */
    public ParentChildGBP(RegionGraphGenerator regionGraphGenerator, MessageStrategy messageStrategy) {
        this.regioner = regionGraphGenerator;
        this.sender = messageStrategy;
    }

    /** Returns an inferencer using {@link BPRegionGenerator} and a {@link FullMessageStrategy}. */
    public static ParentChildGBP makeBPInferencer() {
        return new ParentChildGBP(new BPRegionGenerator(), new FullMessageStrategy());
    }

    /**
     * Returns an inferencer using {@link Kikuchi4SquareRegionGenerator} and a
     * {@link FullMessageStrategy}.
     */
    public static ParentChildGBP makeKikuchiInferencer() {
        return new ParentChildGBP(new Kikuchi4SquareRegionGenerator(), new FullMessageStrategy());
    }

    /** Returns whether messages are averaged with the previous sweep's messages (default true). */
    public boolean getUseInertia() {
        return this.useInertia;
    }

    /** Enables or disables averaging each sweep's messages with the previous sweep's. */
    public void setUseInertia(boolean z) {
        this.useInertia = z;
    }

    /** Returns the weight passed to {@code MessageStrategy.averageMessages} (default 0.5). */
    public double getInertiaWeight() {
        return this.inertiaWeight;
    }

    /** Sets the weight passed to {@code MessageStrategy.averageMessages}. */
    public void setInertiaWeight(double d) {
        this.inertiaWeight = d;
    }

    /**
     * Returns the approximate marginal of a single variable.
     *
     * @throws IllegalStateException if {@link #computeMarginals(FactorGraph)} has not been called
     * @throws IllegalArgumentException if no region contains {@code variable}
     */
    @Override // edu.umass.cs.mallet.grmm.inference.AbstractInferencer, edu.umass.cs.mallet.grmm.inference.Inferencer
    public Factor lookupMarginal(Variable variable) {
        requireComputed();
        Region region = this.rg.findContainingRegion(variable);
        if (region == null) {
            throw new IllegalArgumentException("Could not find region containing variable " + variable
                    + " in region graph " + this.rg);
        }
        return computeBelief(region).marginalize(variable);
    }

    /**
     * Returns the approximate joint marginal over a set of variables.
     *
     * @throws IllegalStateException if {@link #computeMarginals(FactorGraph)} has not been called
     * @throws IllegalArgumentException if no single region contains all of {@code varSet}
     */
    @Override // edu.umass.cs.mallet.grmm.inference.AbstractInferencer, edu.umass.cs.mallet.grmm.inference.Inferencer
    public Factor lookupMarginal(VarSet varSet) {
        requireComputed();
        Region region = this.rg.findContainingRegion(varSet);
        if (region == null) {
            throw new IllegalArgumentException("Could not find region containing clique " + varSet
                    + " in region graph " + this.rg);
        }
        return computeBelief(region).marginalize(varSet);
    }

    /** Computes the belief of {@code region} from the most recent message set. */
    private Factor computeBelief(Region region) {
        return computeBelief(region, this.newMessages);
    }

    /**
     * Computes the normalized belief of a region under the parent-child message scheme.
     *
     * <p>The belief is the product of:
     * <ul>
     *   <li>the region's own factors;</li>
     *   <li>messages from each parent of the region;</li>
     *   <li>for every descendant D of the region, messages into D from those parents of D that are
     *       neither the region itself nor one of its descendants.</li>
     * </ul>
     *
     * @param region the region whose belief is wanted
     * @param messageArray the messages to use
     * @return a normalized {@link LogTableFactor} over {@code region.vars}
     */
    static Factor computeBelief(Region region, MessageArray messageArray) {
        LogTableFactor belief = new LogTableFactor(region.vars);
        Iterator it = region.factors.iterator();
        while (it.hasNext()) {
            belief.multiplyBy((Factor) it.next());
        }
        Iterator it2 = region.parents.iterator();
        while (it2.hasNext()) {
            belief.multiplyBy(messageArray.getMessage((Region) it2.next(), region));
        }
        for (Region descendant : region.descendants) {
            for (Region parentOfDescendant : descendant.parents) {
                // Only messages entering the region's sub-graph from outside are included.
                if (parentOfDescendant != region && !region.descendants.contains(parentOfDescendant)) {
                    belief.multiplyBy(messageArray.getMessage(parentOfDescendant, descendant));
                }
            }
        }
        belief.normalize();
        return belief;
    }

    /**
     * Returns the model's log factor product at {@code assignment} plus the region-graph free energy
     * computed from the current beliefs.
     *
     * <p>The free energy is presumably an approximation of {@code -log Z}, which would make this an
     * approximate log joint probability.
     *
     * @throws IllegalStateException if {@link #computeMarginals(FactorGraph)} has not been called
     */
    @Override // edu.umass.cs.mallet.grmm.inference.AbstractInferencer, edu.umass.cs.mallet.grmm.inference.Inferencer
    public double lookupLogJoint(Assignment assignment) {
        requireComputed();
        return this.mdl.logProduct(assignment) + computeFreeEnergy(this.rg);
    }

    /**
     * Computes the free energy of {@code regionGraph}.
     *
     * <p>The result is the sum over regions R of {@code c_R * (U_R - H_R)}, where:
     * <ul>
     *   <li>{@code c_R} is the region's counting number;</li>
     *   <li>{@code U_R} is the expectation, under R's belief, of minus the log product of R's
     *       factors;</li>
     *   <li>{@code H_R} is the entropy of R's belief.</li>
     * </ul>
     * Assignments with zero belief contribute nothing to {@code U_R} (the 0 log 0 = 0 convention).
     * Without this, a hard-zero factor value would produce {@code 0 * Infinity = NaN}.
     */
    private double computeFreeEnergy(RegionGraph regionGraph) {
        double averageEnergy = 0.0d;
        double entropy = 0.0d;
        Iterator regionIt = regionGraph.iterator();
        while (regionIt.hasNext()) {
            Region region = (Region) regionIt.next();
            Factor belief = computeBelief(region);
            entropy += region.countingNumber * belief.entropy();

            LogTableFactor potential = new LogTableFactor(belief.varSet());
            Iterator factorIt = region.factors.iterator();
            while (factorIt.hasNext()) {
                potential.multiplyBy((Factor) factorIt.next());
            }

            double regionEnergy = 0.0d;
            AssignmentIterator assnIt = belief.assignmentIterator();
            while (assnIt.hasNext()) {
                Assignment assn = assnIt.assignment();
                double b = belief.value(assn);
                if (b != 0.0d) {
                    regionEnergy += b * -potential.logValue(assn);
                }
                assnIt.advance();
            }
            averageEnergy += region.countingNumber * regionEnergy;
        }
        return averageEnergy - entropy;
    }

    /**
     * Builds the region graph for {@code factorGraph} and runs message passing.
     *
     * <p>Message passing stops when the messages converge or after {@link #MAX_ITER} sweeps. A
     * warning is logged only if the final sweep still did not converge.
     */
    @Override // edu.umass.cs.mallet.grmm.inference.AbstractInferencer, edu.umass.cs.mallet.grmm.inference.Inferencer
    public void computeMarginals(FactorGraph factorGraph) {
        Timing timing = new Timing();
        this.mdl = factorGraph;
        this.rg = this.regioner.constructRegionGraph(factorGraph);
        RegionEdge[] sendingOrder = chooseMessageSendingOrder();
        this.newMessages = new MessageArray(this.rg);
        timing.tick("GBP Region Graph construction");

        int iter = 0;
        boolean converged = false;
        while (!converged && iter < MAX_ITER) {
            // Messages are read from oldMessages and written into a fresh copy.
            this.oldMessages = this.newMessages;
            this.newMessages = this.oldMessages.duplicate();
            this.sender.setMessageArray(this.oldMessages, this.newMessages);
            for (RegionEdge edge : sendingOrder) {
                this.sender.sendMessage(edge);
            }
            if (logger.isLoggable(Level.FINER)) {
                timing.tick("GBP iteration " + iter);
            }
            iter++;
            if (this.useInertia) {
                this.newMessages = this.sender.averageMessages(this.rg, this.oldMessages, this.newMessages,
                        this.inertiaWeight);
            }
            converged = hasConverged();
        }
        logger.info("GBP: Used " + iter + " iterations.");
        // Tracked explicitly so that converging on the final sweep is not reported as a failure.
        if (!converged) {
            logger.warning("***WARNING: GBP not converged!");
        }
    }

    /** Returns all region-graph edges, ordered by increasing child-region size (stable). */
    private RegionEdge[] chooseMessageSendingOrder() {
        ArrayList<RegionEdge> edges = new ArrayList<RegionEdge>();
        Iterator edgeIterator = this.rg.edgeIterator();
        while (edgeIterator.hasNext()) {
            edges.add((RegionEdge) edgeIterator.next());
        }
        Collections.sort(edges, SMALLEST_CHILD_FIRST);
        return edges.toArray(new RegionEdge[edges.size()]);
    }

    /**
     * Returns true when every edge's message in {@code newMessages} is within {@link #THRESHOLD}
     * of the same edge's message in {@code oldMessages}.
     */
    private boolean hasConverged() {
        Iterator edgeIterator = this.rg.edgeIterator();
        while (edgeIterator.hasNext()) {
            RegionEdge edge = (RegionEdge) edgeIterator.next();
            DiscreteFactor previous = this.oldMessages.getMessage(edge.from, edge.to);
            DiscreteFactor current = this.newMessages.getMessage(edge.from, edge.to);
            if (previous == null) {
                // An edge with no previous message is expected to have no current message either.
                assert current == null;
            } else if (!previous.almostEquals(current, THRESHOLD)) {
                return false;
            }
        }
        return true;
    }

    /**
     * Prints every edge's current message to standard output (debugging aid).
     *
     * @throws IllegalStateException if {@link #computeMarginals(FactorGraph)} has not been called
     */
    public void dump() {
        requireComputed();
        Iterator edgeIterator = this.rg.edgeIterator();
        while (edgeIterator.hasNext()) {
            RegionEdge edge = (RegionEdge) edgeIterator.next();
            System.out.println("Message: " + edge.from + " --> " + edge.to + " "
                    + this.newMessages.getMessage(edge.from, edge.to));
        }
    }

    /** Fails fast with a clear message instead of a NullPointerException on an unused inferencer. */
    private void requireComputed() {
        if (this.rg == null) {
            throw new IllegalStateException("computeMarginals(FactorGraph) must be called before querying this inferencer");
        }
    }
}
