package edu.umass.cs.mallet.projects.seg_plus_coref.condclust.pipe;

import edu.umass.cs.mallet.base.classify.Classifier;
import edu.umass.cs.mallet.base.pipe.Pipe;
import edu.umass.cs.mallet.base.types.Instance;
import edu.umass.cs.mallet.base.types.Labeling;
import edu.umass.cs.mallet.projects.seg_plus_coref.condclust.types.NodeClusterPair;
import edu.umass.cs.mallet.projects.seg_plus_coref.coreference.Citation;
import edu.umass.cs.mallet.projects.seg_plus_coref.coreference.NodePair;
import java.util.Collection;

/* loaded from: input_file:WEB-INF/lib/mallet-0.1.3.jar:edu/umass/cs/mallet/projects/seg_plus_coref/condclust/pipe/ClusterHomogeneity.class */
/**
 * Pipe that adds cluster-homogeneity features to the {@link NodeClusterPair} carried as an
 * instance's data.
 *
 * <p>For clusters with more than one citation, this pipe does the following:
 *
 * <ol>
 *   <li>It scores every pair of citations in the cluster with {@link #classifier}.
 *   <li>It builds a similarity matrix from those scores, with 1.0 on the diagonal.
 *   <li>It adds a bucketed binary feature for the similarity between the pair's node and the
 *       cluster "prototype". The prototype is the citation whose row in the matrix has the largest
 *       sum.
 *   <li>It adds a second bucketed binary feature for the average similarity over the whole
 *       matrix.
 * </ol>
 *
 * <p>Clusters with zero or one element receive no features.
 */
public class ClusterHomogeneity extends Pipe {
    /** Pairwise classifier used to score {@link NodePair}s of citations. */
    Classifier classifier;

    /**
     * Creates the pipe.
     *
     * @param classifier pairwise classifier. Its labeling is turned into a similarity score, and
     *     its instance pipe is used to featurize each {@link NodePair}.
     */
    public ClusterHomogeneity(Classifier classifier) {
        this.classifier = classifier;
    }

    /**
     * Adds the homogeneity features to the instance's {@link NodeClusterPair}, in place.
     *
     * @param instance an instance whose data is a {@link NodeClusterPair}. The pair's cluster is
     *     expected to be a {@link Collection} of {@link Citation}s, and its node a {@link Citation}.
     * @return the same {@code instance}
     */
    @Override // edu.umass.cs.mallet.base.pipe.Pipe
    public Instance pipe(Instance instance) {
        NodeClusterPair nodeClusterPair = (NodeClusterPair) instance.getData();
        Collection<?> cluster = (Collection<?>) nodeClusterPair.getCluster();
        // An empty or singleton cluster has no internal pairs, so skip the O(n^2) matrix work.
        if (cluster.size() <= 1) {
            return instance;
        }
        Citation[] citations = cluster.toArray(new Citation[0]);
        double[][] similarities = buildSimilarityMatrix(citations);
        double average = getAverage(similarities, citations.length);
        Citation prototype = getCitationClosestToAll(citations, similarities);
        double prototypeSimilarity =
                getSimilarity(prototype, (Citation) nodeClusterPair.getNode());
        // Feature names (and the "CH_" prefix on only the first) are kept as-is, so that
        // previously trained models still see the same features.
        setBucketFeature(nodeClusterPair, "CH_PrototypeNodeSimilarity", prototypeSimilarity);
        setBucketFeature(nodeClusterPair, "WithinClusterSimilarity", average);
        return instance;
    }

    /**
     * Builds the pairwise similarity matrix for {@code citations}.
     *
     * <p>Each unordered pair is scored once and mirrored, which assumes {@link #getSimilarity} is
     * symmetric. The diagonal is set to 1.0.
     */
    private double[][] buildSimilarityMatrix(Citation[] citations) {
        int n = citations.length;
        double[][] matrix = new double[n][n];
        for (int i = 0; i < n; i++) {
            matrix[i][i] = 1.0d;
            for (int j = i + 1; j < n; j++) {
                double similarity = getSimilarity(citations[i], citations[j]);
                matrix[i][j] = similarity;
                matrix[j][i] = similarity;
            }
        }
        return matrix;
    }

    /**
     * Sets exactly one binary feature, named {@code prefix} followed by a bucket suffix.
     *
     * <p>The suffix depends on {@code value}:
     *
     * <ul>
     *   <li>{@code High} if greater than 0.9
     *   <li>{@code Med} if greater than 0.75
     *   <li>{@code Weak} if greater than 0.5
     *   <li>{@code Min} if greater than 0.3
     *   <li>{@code None} otherwise, including NaN
     * </ul>
     */
    private static void setBucketFeature(NodeClusterPair pair, String prefix, double value) {
        String bucket;
        if (value > 0.9d) {
            bucket = "High";
        } else if (value > 0.75d) {
            bucket = "Med";
        } else if (value > 0.5d) {
            bucket = "Weak";
        } else if (value > 0.3d) {
            bucket = "Min";
        } else {
            bucket = "None";
        }
        pair.setFeatureValue(prefix + bucket, 1.0d);
    }

    /**
     * Returns the mean of all {@code i * i} entries of the top-left {@code i x i} sub-matrix.
     *
     * <p>The 1.0 diagonal (self-similarity) is included, so the result is biased upward for
     * small clusters. This is kept for compatibility with existing feature values.
     */
    private double getAverage(double[][] dArr, int i) {
        double sum = 0.0d;
        for (int row = 0; row < i; row++) {
            for (int col = 0; col < i; col++) {
                sum += dArr[row][col];
            }
        }
        // Compute the divisor in double so a large i cannot overflow int multiplication.
        return sum / ((double) i * i);
    }

    /**
     * Scores a citation pair with {@link #classifier}.
     *
     * <p>The result is the value of the label that is not named "no" minus the value of the "no"
     * label. It reads only locations 0 and 1 of the labeling, so it assumes a two-label
     * classifier (TODO confirm). If the labels' values are probabilities, the result lies in
     * [-1, 1].
     */
    private double getSimilarity(Citation citation, Citation citation2) {
        NodePair nodePair = new NodePair(citation, citation2);
        Labeling labeling = this.classifier.classify(new Instance(nodePair, "unknown", null, nodePair, this.classifier.getInstancePipe())).getLabeling();
        return labeling.labelAtLocation(0).toString().equals("no") ? labeling.valueAtLocation(1) - labeling.valueAtLocation(0) : labeling.valueAtLocation(0) - labeling.valueAtLocation(1);
    }

    /**
     * Returns the citation whose row in {@code dArr} has the largest sum, i.e. the one most
     * similar to all others.
     *
     * <p>Ties go to the earliest row. If every row sum is NaN, the first citation is returned
     * rather than failing with an out-of-bounds index. {@code citationArr} must be non-empty.
     */
    private Citation getCitationClosestToAll(Citation[] citationArr, double[][] dArr) {
        double bestSum = Double.NEGATIVE_INFINITY;
        int best = 0;
        for (int row = 0; row < citationArr.length; row++) {
            double rowSum = 0.0d;
            for (int col = 0; col < citationArr.length; col++) {
                rowSum += dArr[row][col];
            }
            if (rowSum > bestSum) {
                bestSum = rowSum;
                best = row;
            }
        }
        return citationArr[best];
    }
}
