package edu.umass.cs.mallet.projects.seg_plus_coref.coreference;

import edu.umass.cs.mallet.base.classify.MaxEnt;
import edu.umass.cs.mallet.base.classify.MaxEntTrainer;
import edu.umass.cs.mallet.base.pipe.Pipe;
import edu.umass.cs.mallet.base.types.InstanceList;
import edu.umass.cs.mallet.base.types.Labeling;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Iterator;

/* loaded from: input_file:WEB-INF/lib/mallet-0.1.3.jar:edu/umass/cs/mallet/projects/seg_plus_coref/coreference/TreeModel.class */
/**
 * Scores a clustering of citations with a MaxEnt classifier trained on
 * publication/citation pairs.
 *
 * <p>Training instances come from {@link PubCitIterator}s run through the supplied
 * {@link Pipe}. At scoring time each (publication, citation) pair is wrapped in a
 * {@link NodePair} and classified. The scoring code only reads label locations 0 and 1
 * and looks for a label named {@code "no"}; the other label is treated as the positive
 * one, so the classifier is assumed to be binary.
 */
public class TreeModel {
    MaxEnt treeModel;
    InstanceList ilist;
    Pipe instancePipe;
    /** When true, every citation is scored against every publication (see computeTreeObjFnPubs2). */
    boolean multiTree = false;

    /**
     * Builds training instances from a single {@link PubCitIterator} and trains the model.
     *
     * @param pipe pipe used to build training instances; its data alphabet is frozen
     *     after the instances are added
     * @param arrayList first argument to {@link PubCitIterator} (presumably publications —
     *     confirm against PubCitIterator)
     * @param arrayList2 second argument to {@link PubCitIterator}
     */
    public TreeModel(Pipe pipe, ArrayList arrayList, ArrayList arrayList2) {
        this.instancePipe = pipe;
        this.ilist = new InstanceList(pipe);
        this.ilist.add(new PubCitIterator(arrayList, arrayList2));
        trainOnCollectedInstances(pipe);
    }

    /**
     * Builds training instances from three {@link PubCitIterator}s and trains the model.
     * The iterators are built from the argument pairs (arrayList, arrayList4),
     * (arrayList2, arrayList5) and (arrayList3, arrayList6), in that order.
     *
     * @param pipe pipe used to build training instances; its data alphabet is frozen
     *     after the instances are added
     */
    public TreeModel(Pipe pipe, ArrayList arrayList, ArrayList arrayList2, ArrayList arrayList3, ArrayList arrayList4, ArrayList arrayList5, ArrayList arrayList6) {
        this.instancePipe = pipe;
        this.ilist = new InstanceList(pipe);
        this.ilist.add(new PubCitIterator(arrayList, arrayList4));
        this.ilist.add(new PubCitIterator(arrayList2, arrayList5));
        this.ilist.add(new PubCitIterator(arrayList3, arrayList6));
        trainOnCollectedInstances(pipe);
    }

    /**
     * Freezes the pipe's data alphabet and trains the MaxEnt classifier on {@code ilist}.
     * Shared by both constructors so they train identically.
     */
    private void trainOnCollectedInstances(Pipe pipe) {
        // Presumably stops features first seen at scoring time from being added to the
        // alphabet — confirm against Alphabet.stopGrowth().
        pipe.getDataAlphabet().stopGrowth();
        System.out.println(" >>>> Training Tree Model <<<< ");
        this.treeModel = (MaxEnt) new MaxEntTrainer().train(this.ilist, null, null, null, null);
    }

    /**
     * Selects the scoring mode used by computeTreeObjFn.
     *
     * @param z true to score every citation against every publication
     */
    public void setMultiTree(boolean z) {
        this.multiTree = z;
    }

    /**
     * Scores a clustering of citations without verbose output.
     *
     * @param collection clusters of citations; each element is a Collection of Citation
     * @return the objective value, as computed by computeTreeObjFn(Collection, boolean)
     */
    public double computeTreeObjFn(Collection collection) {
        return computeTreeObjFn(collection, false);
    }

    /**
     * Scores a clustering of citations.
     *
     * <p>Each element of {@code collection} is a cluster (a {@link Collection} of
     * {@link Citation}s). Each cluster becomes one {@link Publication}, seeded with the
     * cluster's first citation, and the remaining citations are added to it. Empty
     * clusters are skipped because there is no citation to seed a publication with.
     *
     * @param collection clusters of citations
     * @param z verbose flag forwarded to computeTreeObjFnPubs; ignored when multiTree is set
     * @return computeTreeObjFnPubs(publications, z), or, when multiTree is set,
     *     computeTreeObjFnPubs2(publications, allCitations) where allCitations holds
     *     every citation in cluster iteration order
     */
    public double computeTreeObjFn(Collection collection, boolean z) {
        ArrayList publications = new ArrayList();
        ArrayList allCitations = new ArrayList();
        for (Iterator clusters = collection.iterator(); clusters.hasNext();) {
            ArrayList members = new ArrayList((Collection) clusters.next());
            if (members.isEmpty()) {
                continue; // previously crashed on members.get(0)
            }
            Publication publication = new Publication((Citation) members.get(0));
            for (int i = 1; i < members.size(); i++) {
                publication.addNewCitation((Citation) members.get(i));
            }
            publications.add(publication);
            allCitations.addAll(members);
        }
        if (!this.multiTree) {
            return computeTreeObjFnPubs(publications, z);
        }
        return computeTreeObjFnPubs2(publications, allCitations);
    }

    /**
     * Scores every publication against every citation.
     *
     * <p>For each (publication, citation) pair the classifier is consulted. The sum gains
     * (positive value - "no" value) when the citation belongs to the publication's
     * citations, and ("no" value - positive value) otherwise.
     *
     * @param arrayList publications (elements are cast to Publication)
     * @param arrayList2 citations to test against each publication (cast to Citation)
     * @return the summed score over all pairs
     */
    public double computeTreeObjFnPubs2(ArrayList arrayList, ArrayList arrayList2) {
        double total = 0.0d;
        for (int i = 0; i < arrayList.size(); i++) {
            Publication publication = (Publication) arrayList.get(i);
            Collection members = publication.getCitations();
            for (Iterator it = arrayList2.iterator(); it.hasNext();) {
                Citation citation = (Citation) it.next();
                Labeling labeling = classifyPair(publication, citation);
                double positive = positiveValue(labeling);
                double negative = negativeValue(labeling);
                // Agreement between classifier and clustering adds to the score, disagreement subtracts.
                total += members.contains(citation) ? positive - negative : negative - positive;
            }
        }
        return total;
    }

    /**
     * Sums, over each publication and each of its own citations, the value the classifier
     * assigns to the positive (non-"no") label for that pair.
     *
     * @param arrayList publications (elements are cast to Publication)
     * @param z when true, prints each publication, each citation and its score to stdout
     * @return the summed positive-label value
     */
    public double computeTreeObjFnPubs(ArrayList arrayList, boolean z) {
        double total = 0.0d;
        for (int i = 0; i < arrayList.size(); i++) {
            Publication publication = (Publication) arrayList.get(i);
            if (z) {
                System.out.println("\n\n PUBLICATION: ");
                System.out.println("  String: " + publication.getString());
            }
            // Explicit iterator + cast: valid whether getCitations() returns a raw or a
            // generic Collection (enhanced-for over a raw Collection cannot bind to Citation).
            for (Iterator it = publication.getCitations().iterator(); it.hasNext();) {
                Citation citation = (Citation) it.next();
                double score = positiveValue(classifyPair(publication, citation));
                total += score;
                if (z) {
                    System.out.println("\n  CITATION: " + citation.print() + " -> " + score);
                    System.out.println(citation.getString());
                }
            }
        }
        return total;
    }

    /** Classifies a (publication, citation) pair and returns the resulting labeling. */
    private Labeling classifyPair(Publication publication, Citation citation) {
        return this.treeModel.classify(new NodePair(publication, citation)).getLabeling();
    }

    /** Returns true when the label at location 0 of {@code labeling} is named "no". */
    private static boolean isNoAtFirstLocation(Labeling labeling) {
        return labeling.labelAtLocation(0).toString().equals("no");
    }

    /** Value of the label that is not "no"; only locations 0 and 1 are read (binary assumption). */
    private static double positiveValue(Labeling labeling) {
        return labeling.valueAtLocation(isNoAtFirstLocation(labeling) ? 1 : 0);
    }

    /** Value of the "no" label; only locations 0 and 1 are read (binary assumption). */
    private static double negativeValue(Labeling labeling) {
        return labeling.valueAtLocation(isNoAtFirstLocation(labeling) ? 0 : 1);
    }
}
