package weka.filters.supervised.attribute;

import java.util.ArrayList;
import java.util.Collections;
import java.util.Enumeration;
import java.util.Vector;
import weka.classifiers.functions.NonNegativeLogisticRegression;
import weka.core.Attribute;
import weka.core.Capabilities;
import weka.core.DenseInstance;
import weka.core.EuclideanDistance;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.ManhattanDistance;
import weka.core.Option;
import weka.core.RevisionUtils;
import weka.core.TechnicalInformation;
import weka.core.TechnicalInformationHandler;
import weka.core.Utils;
import weka.core.neighboursearch.LinearNNSearch;
import weka.filters.SimpleBatchFilter;
import weka.filters.SupervisedFilter;

/* loaded from: input_file:weka/filters/supervised/attribute/SupervisedAttributeScaler.class */
/**
 * Supervised batch filter that multiplies every non-class attribute by a
 * learned, non-negative weight.
 *
 * <p>The weights are learned on the first batch: each instance is paired with
 * its nearest neighbours, every pair is turned into an instance of absolute
 * (Manhattan) or squared (Euclidean) attribute differences labelled with
 * whether the two class values agree, and a
 * {@link NonNegativeLogisticRegression} model is fitted to that pairwise data.
 * Later batches are rescaled with the weights learned on the first one.
 *
 * <p>Valid options:
 * <ul>
 *   <li>{@code -assume-Euclidean-distance}: learn weights for Euclidean distance.</li>
 *   <li>{@code -K <num>}: number of neighbours to pair each instance with (default 30).</li>
 * </ul>
 */
public class SupervisedAttributeScaler extends SimpleBatchFilter implements SupervisedFilter, TechnicalInformationHandler {
    static final long serialVersionUID = -4448107323933117974L;

    /** Number of neighbours used when the -K option is not given. */
    private static final int DEFAULT_NUM_NEIGHBOURS = 30;

    /** Learned weights, indexed by attribute position; null until {@link #initFilter} has run. */
    protected double[] m_weights = null;

    /** Whether squared differences (Euclidean) are used instead of absolute differences (Manhattan). */
    protected boolean m_AssumeEuclideanDistance = false;

    /** Number of nearest neighbours each instance is paired with. */
    protected int m_numNeighbours = DEFAULT_NUM_NEIGHBOURS;

    /**
     * Returns a description of this filter, followed by the citation produced
     * by {@link #getTechnicalInformation()}.
     *
     * @return the description text
     */
    public String globalInfo() {
        return "Rescales the attributes in a classification problem based on their discriminative power. This is useful as a pre-processing step for learning algorithms such as the k-nearest-neighbour method, to replace simple normalization. Each attribute is rescaled by multiplying it with a learned weight. All attributes excluding the class are assumed to be numeric and missing values are not permitted.\n\nThe attribute weights are learned by taking the original labeled dataset with N instances and creating a new dataset with N*K instances, where K is the number of neighbours selected. To this end, each instance in the original dataset is paired with its K nearest neighbours, creating K pairs. Then, an instance in the new dataset is created for each pair, with the same number of attributes as in the original data. An attribute's value in this new instance is set to the absolute difference between the corresponding attribute values in the pair of original instances. The new instance's label depends on whether the two instances in the pair have the same class label or not, yielding a two-class classification problem. A logistic regression model with non-negative coefficients is learned from this data and the resulting coefficients are used as weights to rescale the original data.\n\nThis process assumes that distance in the original space is measured using Manhattan distance because the absolute difference is taken between attribute values. The method can optionally be used to learn weights for a Euclidean distance. In this case, squared differences are taken rather than absolute differences, and the square root of the learned coefficients is used to rescale the attributes in the original data.\n\nThe approach is based on the Probabilistic Global Distance Metric Learning method included in the experimental comparison in\n\n" + getTechnicalInformation().toString();
    }

    /**
     * Returns the bibliographic reference for the method this filter is based on.
     *
     * @return the technical information describing the publication
     */
    public TechnicalInformation getTechnicalInformation() {
        TechnicalInformation info = new TechnicalInformation(TechnicalInformation.Type.INPROCEEDINGS);
        info.setValue(TechnicalInformation.Field.AUTHOR, "L. Yang and R. Jin and R. Sukthankar and Y. Liu");
        info.setValue(TechnicalInformation.Field.TITLE, "An efficient algorithm for local distance metric learning");
        info.setValue(TechnicalInformation.Field.BOOKTITLE, "Proceedings of the National Conference on Artificial Intelligence");
        info.setValue(TechnicalInformation.Field.YEAR, "2006");
        info.setValue(TechnicalInformation.Field.PAGES, "543-548");
        info.setValue(TechnicalInformation.Field.PUBLISHER, "AAAI Press");
        return info;
    }

    /**
     * Lists the options of this filter followed by the options of the superclass.
     *
     * @return an enumeration of all available options
     */
    public Enumeration<Option> listOptions() {
        Vector<Option> options = new Vector<Option>();
        options.addElement(new Option("\tIf set, weights are learned for Euclidean distance.\n", "-assume-Euclidean-distance", 0, "-assume-Euclidean-distance"));
        options.addElement(new Option("\tThe number of neighbours to use (default: 30).\n", "-K", 1, "-K"));
        options.addAll(Collections.list(super.listOptions()));
        return options.elements();
    }

    /**
     * Returns the current settings as a command-line option array: the
     * Euclidean flag (only when set), {@code -K <num>}, then the superclass options.
     *
     * @return the option array
     */
    public String[] getOptions() {
        ArrayList<String> options = new ArrayList<String>();
        if (getAssumeEuclideanDistance()) {
            options.add("-assume-Euclidean-distance");
        }
        options.add("-K");
        options.add("" + getNumNeighbours());
        Collections.addAll(options, super.getOptions());
        return options.toArray(new String[options.size()]);
    }

    /**
     * Parses the given options. A missing {@code -K} resets the number of
     * neighbours to its default of 30.
     *
     * @param strArr the options to parse
     * @throws Exception if an option is malformed (e.g. a non-integer {@code -K}
     *         value) or unknown options remain
     */
    public void setOptions(String[] strArr) throws Exception {
        setAssumeEuclideanDistance(Utils.getFlag("assume-Euclidean-distance", strArr));
        String neighbours = Utils.getOption('K', strArr);
        if (neighbours.length() != 0) {
            setNumNeighbours(Integer.parseInt(neighbours));
        } else {
            setNumNeighbours(DEFAULT_NUM_NEIGHBOURS);
        }
        super.setOptions(strArr);
        Utils.checkForRemainingOptions(strArr);
    }

    /**
     * Returns the tip text for the assumeEuclideanDistance property.
     *
     * @return the tip text
     */
    public String assumeEuclideanDistanceTipText() {
        return "Whether to assume Euclidean distance rather than Manhattan distance.";
    }

    /**
     * @return true if weights are learned for Euclidean distance
     */
    public boolean getAssumeEuclideanDistance() {
        return this.m_AssumeEuclideanDistance;
    }

    /**
     * @param z true to learn weights for Euclidean rather than Manhattan distance
     */
    public void setAssumeEuclideanDistance(boolean z) {
        this.m_AssumeEuclideanDistance = z;
    }

    /**
     * Returns the tip text for the numNeighbours property.
     *
     * @return the tip text
     */
    public String numNeighboursTipText() {
        return "The number of neighbours to use.";
    }

    /**
     * @return the number of neighbours each instance is paired with
     */
    public int getNumNeighbours() {
        return this.m_numNeighbours;
    }

    /**
     * @param i the number of neighbours each instance is paired with
     */
    public void setNumNeighbours(int i) {
        this.m_numNeighbours = i;
    }

    /**
     * Returns a copy of the given data as the output format; the filter does
     * not add, remove or retype attributes.
     *
     * @param instances the input format
     * @return a copy of {@code instances}
     */
    protected Instances determineOutputFormat(Instances instances) {
        return new Instances(instances);
    }

    /**
     * Learns the attribute weights from the given data and stores them in
     * {@link #m_weights}.
     *
     * @param instances the training data
     * @throws Exception if the neighbour search or the logistic regression fails
     */
    public void initFilter(Instances instances) throws Exception {
        int numInstances = instances.numInstances();
        int numAttributes = instances.numAttributes();
        int classIndex = instances.classIndex();
        boolean euclidean = getAssumeEuclideanDistance();

        this.m_weights = new double[numAttributes];

        LinearNNSearch search = createNeighbourSearch(instances, euclidean);

        // Requested neighbour count, capped at numInstances - 1.
        int k = Math.min(this.m_numNeighbours, numInstances - 1);

        Instances pairwise = new Instances("pairwise_data", pairwiseAttributes(instances), numInstances * k);
        for (int i = 0; i < numInstances; i++) {
            Instance current = instances.instance(i);
            search.addInstanceInfo(current);
            Instances neighbours = search.kNearestNeighbours(current, k);
            // Only the first k returned neighbours are used.
            for (int n = 0; n < k; n++) {
                double[] values = pairwiseValues(current, neighbours.instance(n), numAttributes, classIndex, euclidean);
                pairwise.add(new DenseInstance(1.0d, values));
            }
        }
        // Shifted by one because of the leading "-1" attribute.
        pairwise.setClassIndex(classIndex + 1);

        NonNegativeLogisticRegression model = new NonNegativeLogisticRegression();
        model.buildClassifier(pairwise);
        double[] coefficients = model.getCoefficients();

        // Coefficient 0 is skipped; coefficient c is stored as the weight at index c - 1.
        // NOTE(review): this presumes getCoefficients() is indexed by pairwise attribute
        // position so that c - 1 is the original attribute index - confirm against
        // NonNegativeLogisticRegression, in particular when the class is not the last attribute.
        for (int c = 1; c < coefficients.length; c++) {
            double weight = coefficients[c];
            this.m_weights[c - 1] = euclidean ? Math.sqrt(weight) : weight;
        }
    }

    /**
     * Builds the attribute list of the pairwise dataset: a leading numeric
     * attribute named "-1", then one attribute per original attribute, with the
     * class replaced by a two-valued nominal attribute.
     */
    private static ArrayList<Attribute> pairwiseAttributes(Instances data) {
        int numAttributes = data.numAttributes();
        int classIndex = data.classIndex();
        ArrayList<Attribute> attributes = new ArrayList<Attribute>(numAttributes + 1);
        attributes.add(new Attribute("-1"));
        for (int i = 0; i < numAttributes; i++) {
            if (i == classIndex) {
                ArrayList<String> labels = new ArrayList<String>(2);
                labels.add("different_class_values");
                labels.add("same_class_values");
                attributes.add(new Attribute("Class", labels));
            } else {
                attributes.add((Attribute) data.attribute(i).copy());
            }
        }
        return attributes;
    }

    /**
     * Creates a linear nearest-neighbour search over the data using an
     * unnormalized Euclidean or Manhattan distance.
     */
    private static LinearNNSearch createNeighbourSearch(Instances data, boolean euclidean) throws Exception {
        LinearNNSearch search = new LinearNNSearch();
        // Each distance is held in a variable of its own type: ManhattanDistance is
        // not a subtype of EuclideanDistance, so they cannot share an
        // EuclideanDistance-typed local.
        if (euclidean) {
            EuclideanDistance distance = new EuclideanDistance();
            distance.setDontNormalize(true);
            search.setDistanceFunction(distance);
        } else {
            ManhattanDistance distance = new ManhattanDistance();
            distance.setDontNormalize(true);
            search.setDistanceFunction(distance);
        }
        search.setInstances(data);
        return search;
    }

    /**
     * Builds the value vector of one pairwise instance: -1 in slot 0, then the
     * absolute or squared difference per attribute, and 1 (same class) or
     * 0 (different class) in the class slot.
     */
    private static double[] pairwiseValues(Instance first, Instance second, int numAttributes, int classIndex, boolean squared) {
        double[] values = new double[numAttributes + 1];
        values[0] = -1.0d;
        for (int a = 0; a < numAttributes; a++) {
            if (a == classIndex) {
                values[a + 1] = first.classValue() == second.classValue() ? 1.0d : 0.0d;
            } else {
                double diff = first.value(a) - second.value(a);
                values[a + 1] = squared ? diff * diff : Math.abs(diff);
            }
        }
        return values;
    }

    /**
     * Returns the capabilities of this filter: nominal class and numeric
     * attributes only.
     *
     * @return the capabilities
     */
    public Capabilities getCapabilities() {
        Capabilities capabilities = super.getCapabilities();
        capabilities.disableAll();
        capabilities.enable(Capabilities.Capability.NOMINAL_CLASS);
        capabilities.enable(Capabilities.Capability.NUMERIC_ATTRIBUTES);
        return capabilities;
    }

    /**
     * Rescales a batch. The weights are learned from the batch itself if the
     * first batch has not been completed yet; otherwise the stored weights are
     * applied.
     *
     * @param instances the batch to rescale
     * @return a new dataset with every non-class value multiplied by its weight
     * @throws Exception if learning the weights fails
     */
    protected Instances process(Instances instances) throws Exception {
        if (!isFirstBatchDone()) {
            initFilter(instances);
        }
        int numInstances = instances.numInstances();
        int numAttributes = instances.numAttributes();
        int classIndex = instances.classIndex();

        Instances scaled = new Instances(instances, numInstances);
        for (int i = 0; i < numInstances; i++) {
            Instance original = instances.instance(i);
            double[] values = new double[numAttributes];
            for (int a = 0; a < numAttributes; a++) {
                values[a] = original.value(a);
                if (a != classIndex) {
                    values[a] *= this.m_weights[a];
                }
            }
            // Carry the original instance weight over instead of resetting it to 1.
            scaled.add(new DenseInstance(original.weight(), values));
        }
        return scaled;
    }

    /**
     * Returns the revision string.
     *
     * @return the revision
     */
    public String getRevision() {
        return RevisionUtils.extract("$Revision: 8034 $");
    }

    /**
     * Runs the filter from the command line.
     *
     * @param strArr the command-line arguments
     */
    public static void main(String[] strArr) {
        runFilter(new SupervisedAttributeScaler(), strArr);
    }
}
