package edu.umass.cs.mallet.projects.seg_plus_coref.anaphora;

import edu.umass.cs.mallet.base.classify.Classifier;
import edu.umass.cs.mallet.base.classify.MaxEnt;
import edu.umass.cs.mallet.base.classify.MaxEntTrainer;
import edu.umass.cs.mallet.base.classify.Trial;
import edu.umass.cs.mallet.base.pipe.Pipe;
import edu.umass.cs.mallet.base.pipe.SerialPipes;
import edu.umass.cs.mallet.base.pipe.Target2Label;
import edu.umass.cs.mallet.base.pipe.iterator.FileIterator;
import edu.umass.cs.mallet.base.types.FeatureVector;
import edu.umass.cs.mallet.base.types.Instance;
import edu.umass.cs.mallet.base.types.InstanceList;
import edu.umass.cs.mallet.base.types.LabelVector;
import edu.umass.cs.mallet.base.types.Matrix2;
import edu.umass.cs.mallet.projects.seg_plus_coref.clustering.ClusterEvaluate;
import edu.umass.cs.mallet.projects.seg_plus_coref.clustering.Clusterer;
import edu.umass.cs.mallet.projects.seg_plus_coref.clustering.Clustering;
import edu.umass.cs.mallet.projects.seg_plus_coref.clustering.KeyClustering;
import edu.umass.cs.mallet.projects.seg_plus_coref.clustering.MappedGraph;
import edu.umass.cs.mallet.projects.seg_plus_coref.clustering.MortonClustering;
import edu.umass.cs.mallet.projects.seg_plus_coref.clustering.PairEvaluate;
import java.io.File;
import java.lang.reflect.Array;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Set;
import org.apache.commons.configuration.tree.DefaultExpressionEngine;
import org.apache.solr.schema.JsonPreAnalyzedParser;
import salvo.jesus.graph.WeightedEdge;

/* loaded from: input_file:WEB-INF/lib/mallet-0.1.3.jar:edu/umass/cs/mallet/projects/seg_plus_coref/anaphora/TUIGraph.class */
/**
 * Experiment driver for mention-pair coreference. It trains a MaxEnt mention-pair classifier,
 * then compares two clusterings built from the classifier's output against the key clustering
 * derived from the annotations:
 *
 * <ul>
 *   <li>a greedy "Morton" clustering, where each referent links to its best-scoring antecedent;
 *   <li>a graph clustering over model-weighted edges.
 * </ul>
 */
public class TUIGraph {
    /** Surface strings treated as pronouns by {@link #referentPronoun}; matching is case-sensitive. */
    public static final String[] pronouns = {"He", "he", "Him", "him", "His", "his", "She", "she", "Her", "her", "hers", "it", "It", "its", "Its", "itself", "himself", "herself"};

    /** Number of entries in {@link #pronouns}; retained for existing callers. */
    public static final int pronounsSize = 18;

    /**
     * Sentinel meaning "no 'yes' score seen yet". Presumably below any real classifier score
     * (TODO confirm).
     */
    private static final double NO_SCORE = -10000.0d;

    /** Prefix of the key-cluster ids given to mentions that carry no entity reference. */
    private static final String SINGLETON_KEY_PREFIX = "s";

    /**
     * Trains the mention-pair model and prints per-document clustering evaluations.
     *
     * @param strArr either {@code [trainDir, testDir]}; with any other argument count,
     *     hard-coded default directories are used
     */
    public static void main(String[] strArr) {
        String trainDir;
        String testDir;
        if (strArr.length == 2) {
            trainDir = strArr[0];
            testDir = strArr[1];
        } else {
            // NOTE(review): with these defaults the model trains on "test-annotated" and is
            // evaluated on "mini-train"; the names look swapped -- confirm intent.
            trainDir = "/usr/dan/users8/wellner/data/all-docs/test-annotated";
            testDir = "/usr/dan/users8/wellner/data/all-docs/mini-train";
        }
        XMLFileFilter xmlFilter = new XMLFileFilter(".*xml");
        FileIterator trainFiles = new FileIterator(new File(trainDir), xmlFilter);
        FileIterator testFiles = new FileIterator(new File(testDir), xmlFilter);
        ArrayList pairFilters = new ArrayList();
        pairFilters.add(new MentionPairFilter());
        MentionPairIterator trainPairs = new MentionPairIterator(trainFiles, "MUC", true, true, true, pairFilters);
        MentionPairIterator testPairs = new MentionPairIterator(testFiles, "MUC", true, true, true, pairFilters);
        SerialPipes pipes = new SerialPipes(new Pipe[]{new Target2Label(), new AffixOfMentionPair(), new MentionPairHeadIdentical(), new MentionPairIdentical(), new MentionPairSentenceDistance(), new PartOfSpeechMentionPair(), new HobbsDistanceMentionPair(), new MentionPairAntecedentPosition(), new NullAntecedentFeatureExtractor(), new ModifierWordFeatures(), new MentionPair2FeatureVector()});
        InstanceList trainList = new InstanceList(pipes);
        trainList.add(trainPairs);
        InstanceList testList = new InstanceList(pipes);
        testList.add(testPairs);
        // NOTE(review): the result of split() is discarded, so training uses all of trainList.
        // The call is kept only in case split() has side effects -- confirm and remove if not.
        trainList.split(new double[]{0.7d, 0.3d});
        MaxEnt maxEnt = (MaxEnt) new MaxEntTrainer().train(trainList);

        // Building a Trial classifies the whole list, so build one per list and reuse it.
        // The messages say "Accuracy", but the value printed is labelF1.
        Trial trainTrial = new Trial(maxEnt, trainList);
        System.out.println("Training Accuracy on \"yes\" = " + trainTrial.labelF1("yes"));
        System.out.println("Training Accuracy on \"no\" = " + trainTrial.labelF1("no"));
        Trial testTrial = new Trial(maxEnt, testList);
        System.out.println("Testing Accuracy on \"yes\" = " + testTrial.labelF1("yes"));
        System.out.println("Testing Accuracy on \"no\" = " + testTrial.labelF1("no"));

        // NOTE(review): result discarded; kept in case the call has side effects.
        MentionPairIterator.partitionIntoDocumentInstances(trainList);
        Set<List> testDocuments = MentionPairIterator.partitionIntoDocumentInstances(testList);
        Clusterer clusterer = new Clusterer();
        for (List documentPairs : testDocuments) {
            evaluateDocument(documentPairs, maxEnt, clusterer);
        }
    }

    /**
     * Evaluates one document. It builds the key, Morton and model-graph clusterings, prints
     * them, then prints cluster-level F1, pairwise F1 and error analysis for the document.
     */
    private static void evaluateDocument(List documentPairs, MaxEnt maxEnt, Clusterer clusterer) {
        MappedGraph mappedGraph = new MappedGraph();
        KeyClustering keyClustering = collectAllKeyClusters(documentPairs);
        keyClustering.print();
        Clustering mortonClustering = getMortonClustering(documentPairs, maxEnt);
        System.out.println("Number of pairs: " + documentPairs.size());
        for (Iterator it = documentPairs.iterator(); it.hasNext(); ) {
            constructEdgesUsingModel(mappedGraph, maxEnt, (Instance) it.next());
        }
        clusterer.setGraph(mappedGraph);
        Clustering modelClustering = clusterer.getClustering();
        System.out.println("Model clusters: ");
        modelClustering.printDetailed();
        System.out.println("Key clusters: ");
        keyClustering.printDetailed();

        ClusterEvaluate mortonEval = new ClusterEvaluate(keyClustering, mortonClustering);
        mortonEval.evaluate();
        System.out.println("F1 morton is : " + mortonEval.getF1());
        ClusterEvaluate modelEval = new ClusterEvaluate(keyClustering, modelClustering);
        modelEval.evaluate();
        System.out.println("F1 using model is : " + modelEval.getF1());
        // Key-vs-key comparison; used as a sanity check of the scorer.
        ClusterEvaluate keyEval = new ClusterEvaluate(keyClustering, keyClustering);
        keyEval.evaluate();
        System.out.println("F1 using keykey is : " + keyEval.getF1());

        System.out.println("Pairwise key:morton");
        PairEvaluate mortonPairs = new PairEvaluate(keyClustering, mortonClustering);
        mortonPairs.evaluate();
        System.out.println("Morton pairF1: " + mortonPairs.getF1());
        System.out.println("Pairwise key:model");
        PairEvaluate modelPairs = new PairEvaluate(keyClustering, modelClustering);
        modelPairs.evaluate();
        System.out.println("Model pairF1: " + modelPairs.getF1());

        System.out.println("\n\n Error analysis: MORTON");
        mortonEval.printErrors(true);
        System.out.println("\n\n Error analysis: Model");
        modelEval.printErrors(true);
        System.out.println("Mapping: ");
        mappedGraph.printMap();
    }

    /**
     * Builds a greedy best-antecedent ("Morton") clustering.
     *
     * <p>Consecutive instances with the same referent (compared by identity) form one group.
     * Within a group, the antecedent with the highest "yes" score is linked to the referent.
     * If no "yes" score is seen for the group's first pair, the referent starts with no
     * antecedent. Each decision is printed as it is added.
     *
     * @param list mention-pair {@link Instance}s, presumably grouped by referent (TODO confirm)
     * @param classifier model whose "yes" label value scores each pair
     * @return the resulting clustering; empty when {@code list} is empty
     */
    public static Clustering getMortonClustering(List list, Classifier classifier) {
        MortonClustering mortonClustering = new MortonClustering();
        Mention currentReferent = null;
        Mention bestAntecedent = null;
        double bestScore = NO_SCORE;
        for (Iterator it = list.iterator(); it.hasNext(); ) {
            Instance instance = (Instance) it.next();
            MentionPair mentionPair = (MentionPair) instance.getSource();
            Mention referent = mentionPair.getReferent();
            Mention antecedent = mentionPair.getAntecedent();
            // Recomputed for every pair, so a pair without a "yes" label never reuses an old score.
            double yesScore = yesScore(classifier.classify(instance).getLabelVector());
            if (referent != currentReferent) {
                if (currentReferent != null) {
                    addMortonDecision(mortonClustering, currentReferent, bestAntecedent);
                }
                currentReferent = referent;
                if (yesScore > NO_SCORE) {
                    bestAntecedent = antecedent;
                    bestScore = yesScore;
                } else {
                    bestAntecedent = null;
                    bestScore = NO_SCORE;
                }
            } else if (yesScore > bestScore) {
                bestAntecedent = antecedent;
                bestScore = yesScore;
            }
        }
        // Flush the last group; skipped for empty input instead of dereferencing null.
        if (currentReferent != null) {
            addMortonDecision(mortonClustering, currentReferent, bestAntecedent);
        }
        return mortonClustering;
    }

    /** Returns the value of the last "yes" label in {@code labels}, or {@link #NO_SCORE} if absent. */
    private static double yesScore(LabelVector labels) {
        double score = NO_SCORE;
        for (int i = 0; i < labels.singleSize(); i++) {
            if (labels.labelAtLocation(i).toString().equals("yes")) {
                score = labels.valueAtLocation(i);
            }
        }
        return score;
    }

    /** Records one Morton decision (antecedent may be null for "no antecedent") and logs it. */
    private static void addMortonDecision(MortonClustering clustering, Mention referent, Mention antecedent) {
        if (antecedent != null) {
            clustering.addToClustering(referent, antecedent);
            System.out.println("merging: " + referent.getString() + ":" + antecedent.getString());
        } else {
            clustering.addToClustering(referent);
            System.out.println("merging: " + referent.getString() + ":NULL");
        }
    }

    /**
     * Collects the distinct non-null mentions from a list of mention-pair instances.
     *
     * @return mentions in first-seen order (antecedent before referent within a pair),
     *     de-duplicated with {@code equals}
     */
    public static List getMentionsFromPairs(List list) {
        ArrayList mentions = new ArrayList();
        for (Object element : list) {
            MentionPair mentionPair = (MentionPair) ((Instance) element).getSource();
            Mention antecedent = mentionPair.getAntecedent();
            Mention referent = mentionPair.getReferent();
            if (antecedent != null && !mentions.contains(antecedent)) {
                mentions.add(antecedent);
            }
            if (referent != null && !mentions.contains(referent)) {
                mentions.add(referent);
            }
        }
        return mentions;
    }

    /**
     * Divides every edge weight in the graph by the largest absolute edge weight. When that
     * maximum is zero (no edges, or all weights zero), the graph is left unchanged, because
     * dividing would turn zero weights into NaN.
     */
    public static void normalizeGraphEdges(MappedGraph mappedGraph) {
        Set<WeightedEdge> edgeSet = mappedGraph.getGraph().getEdgeSet();
        double maxAbsWeight = 0.0d;
        for (WeightedEdge edge : edgeSet) {
            double abs = Math.abs(edge.getWeight());
            if (abs > maxAbsWeight) {
                maxAbsWeight = abs;
            }
        }
        if (maxAbsWeight == 0.0d) {
            return;
        }
        for (WeightedEdge edge : edgeSet) {
            edge.setWeight(edge.getWeight() / maxAbsWeight);
        }
    }

    /**
     * Builds the key (gold) clustering for one document's mention pairs.
     *
     * <p>A pair with an entity reference puts its non-null mentions into the cluster keyed by
     * that reference. Every remaining mention gets its own singleton key: "s0", "s1", ... in
     * first-seen order.
     */
    public static KeyClustering collectAllKeyClusters(List list) {
        // Every mention seen; mentions assigned to an entity key are removed below.
        LinkedHashSet unassigned = new LinkedHashSet();
        for (Object element : list) {
            MentionPair mentionPair = (MentionPair) ((Instance) element).getSource();
            Mention antecedent = mentionPair.getAntecedent();
            Mention referent = mentionPair.getReferent();
            if (antecedent != null) {
                unassigned.add(antecedent);
            }
            if (referent != null) {
                unassigned.add(referent);
            }
        }
        KeyClustering keyClustering = new KeyClustering();
        for (Object element : list) {
            MentionPair mentionPair = (MentionPair) ((Instance) element).getSource();
            String entityReference = mentionPair.getEntityReference();
            if (entityReference == null) {
                continue;
            }
            Mention antecedent = mentionPair.getAntecedent();
            Mention referent = mentionPair.getReferent();
            if (antecedent != null) {
                keyClustering.addToClustering(entityReference, antecedent);
                unassigned.remove(antecedent);
            }
            if (referent != null) {
                keyClustering.addToClustering(entityReference, referent);
                unassigned.remove(referent);
            }
        }
        int singletonIndex = 0;
        for (Object mention : unassigned) {
            keyClustering.addToClustering(SINGLETON_KEY_PREFIX + singletonIndex, (Mention) mention);
            singletonIndex++;
        }
        return keyClustering;
    }

    /**
     * Adds a "yes"-labelled pair to a set of mention clusters (a Set of Sets).
     *
     * <p>A null pair starts a new singleton cluster for its referent. Any other pair merges
     * every existing cluster that contains either mention, together with both mentions, into
     * one new cluster. Instances not labelled "yes" are ignored.
     */
    private static void coalesceNewPair(Set set, Instance instance) {
        if (!instance.getLabeling().toString().equals("yes")) {
            return;
        }
        MentionPair mentionPair = (MentionPair) instance.getSource();
        Mention referent = mentionPair.getReferent();
        if (mentionPair.nullPair()) {
            for (Iterator it = set.iterator(); it.hasNext(); ) {
                Set cluster = (Set) it.next();
                // Warn only about clusters that really already contain this referent.
                if (cluster.contains(referent)) {
                    System.out.println(new StringBuffer().append("Creating ").append(referent).append(" when it already exists in ").append(cluster).toString());
                }
            }
            LinkedHashSet singleton = new LinkedHashSet();
            singleton.add(referent);
            set.add(singleton);
            return;
        }
        Mention antecedent = mentionPair.getAntecedent();
        // Remove each matching cluster from the outer set before combining it. Mutating a Set
        // while it is an element of a hash-based Set would corrupt the outer Set.
        LinkedHashSet merged = new LinkedHashSet();
        for (Iterator it = set.iterator(); it.hasNext(); ) {
            Set cluster = (Set) it.next();
            if (cluster.contains(antecedent) || cluster.contains(referent)) {
                merged.addAll(cluster);
                it.remove();
            }
        }
        merged.add(antecedent);
        merged.add(referent);
        set.add(merged);
    }

    /** Prints every element of {@code set} that is itself a Set, each via {@link #printCluster}. */
    private static void printClusters(Set set) {
        System.out.println("[[[");
        for (Object obj : set) {
            if (obj instanceof Set) {
                printCluster((Set) obj);
            }
        }
        System.out.println("]]]");
    }

    /** Prints one cluster of {@link Mention}s in parentheses, one mention per line. */
    private static void printCluster(Set set) {
        System.out.print("(");
        for (Object member : set) {
            Mention mention = (Mention) member;
            mention.getMalletPhrase().printPreTerms();
            System.out.print(" - " + mention.getUniqueEntityIndex());
            System.out.println("++" + mention);
        }
        System.out.println(") ");
    }

    /** Returns true if the mention's string exactly matches one of {@link #pronouns}. */
    private static boolean referentPronoun(Mention mention) {
        String text = mention.getString();
        for (String pronoun : pronouns) {
            if (pronoun.equals(text)) {
                return true;
            }
        }
        return false;
    }

    /** Returns true if the part of speech of the mention's head pre-terminal is "NNP". */
    private static boolean referentNNP(Mention mention) {
        MalletPreTerm headPreTerm = mention.getMalletPhrase().getHeadPreTerm();
        return "NNP".equals(headPreTerm.getPartOfSpeech());
    }

    /**
     * Adds one instance to the graph. A non-null pair becomes an edge weighted by
     * {@link #pairScore}. A null pair whose referent is a pronoun becomes a lone vertex.
     * Exceptions from the graph are printed and otherwise ignored.
     */
    private static void constructEdgesUsingModel(MappedGraph mappedGraph, MaxEnt maxEnt, Instance instance) {
        MentionPair mentionPair = (MentionPair) instance.getSource();
        Object antecedent = mentionPair.getAntecedent();
        Mention referent = mentionPair.getReferent();
        if (!mentionPair.nullPair()) {
            // Scored outside the try so scoring failures still propagate, as before.
            double score = pairScore(maxEnt, instance);
            try {
                mappedGraph.addEdgeMap(antecedent, referent, score);
            } catch (Exception e) {
                e.printStackTrace();
            }
            return;
        }
        if (referentPronoun(referent)) {
            try {
                mappedGraph.addVertexMap(referent);
            } catch (Exception e) {
                e.printStackTrace();
            }
        }
    }

    /**
     * Views the MaxEnt parameters as a 2-row matrix and returns the instance's feature dot
     * product with row 0 minus its dot product with row 1. This assumes the model has exactly
     * two labels; which label each row corresponds to is not established here (TODO confirm).
     */
    private static double pairScore(MaxEnt maxEnt, Instance instance) {
        double[] parameters = maxEnt.getParameters();
        Matrix2 weights = new Matrix2(parameters, 2, parameters.length / 2);
        FeatureVector features = (FeatureVector) instance.getData();
        return weights.rowDotProduct(0, features) - weights.rowDotProduct(1, features);
    }

    /**
     * Adds one instance to the graph using gold targets instead of the model. A null pair adds
     * its referent as a vertex. A pair with an entity reference adds an edge of weight 1000.
     * Exceptions from the graph are printed and otherwise ignored.
     */
    private static void constructEdgesUsingTargets(MappedGraph mappedGraph, Instance instance) {
        MentionPair mentionPair = (MentionPair) instance.getSource();
        Mention antecedent = mentionPair.getAntecedent();
        Mention referent = mentionPair.getReferent();
        if (mentionPair.nullPair()) {
            try {
                mappedGraph.addVertexMap(referent);
            } catch (Exception e) {
                e.printStackTrace();
            }
            return;
        }
        if (mentionPair.getEntityReference() != null) {
            try {
                mappedGraph.addEdgeMap(antecedent, referent, 1000.0d);
            } catch (Exception e) {
                e.printStackTrace();
            }
        }
    }
}
