package pl.edu.icm.yadda.analysis.classification.svm;

import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Formatter;
import java.util.Iterator;
import java.util.List;
import java.util.Locale;
import java.util.Set;
import java.util.StringTokenizer;
import libsvm.svm;
import libsvm.svm_model;
import libsvm.svm_node;
import libsvm.svm_parameter;
import libsvm.svm_problem;
import org.apache.commons.collections.iterators.ArrayIterator;
import org.apache.commons.configuration.tree.DefaultExpressionEngine;
import org.apache.lucene.analysis.shingle.ShingleFilter;
import pl.edu.icm.yadda.analysis.AnalysisException;
import pl.edu.icm.yadda.analysis.classification.features.FeatureVectorBuilder;
import pl.edu.icm.yadda.analysis.classification.hmm.training.TrainingElement;
import pl.edu.icm.yadda.analysis.classification.tools.FeatureLimits;
import pl.edu.icm.yadda.analysis.classification.tools.FeatureVectorScaler;
import pl.edu.icm.yadda.analysis.classification.tools.LinearScaling;
import pl.edu.icm.yadda.analysis.textr.ZoneClassifier;
import pl.edu.icm.yadda.analysis.textr.model.BxDocument;
import pl.edu.icm.yadda.analysis.textr.model.BxPage;
import pl.edu.icm.yadda.analysis.textr.model.BxZone;
import pl.edu.icm.yadda.analysis.textr.model.BxZoneLabel;

/* loaded from: input_file:WEB-INF/lib/yadda-analysis-impl-1.12.5.jar:pl/edu/icm/yadda/analysis/classification/svm/SVMZoneClassifier.class */
/**
 * Zone classifier backed by a LIBSVM model.
 *
 * <p>Each zone is turned into a feature vector by a {@link FeatureVectorBuilder}. The vector is
 * scaled to {@code [0, 1]} by a {@link FeatureVectorScaler} using {@link LinearScaling}, and then
 * classified with {@code svm.svm_predict}. Zone labels are encoded as SVM targets
 * {@code ordinal() * LABEL_SCALE}.
 */
public class SVMZoneClassifier implements ZoneClassifier {
    /** Template parameters; {@link #getDefaultParam()} hands out copies of it. */
    protected static final svm_parameter defaultParameter;

    /**
     * Factor applied to {@link BxZoneLabel#ordinal()} to build the SVM target value; predictions
     * are divided by it to recover the label. Previously trained and saved models were encoded
     * with this factor, so changing it would break them.
     */
    private static final int LABEL_SCALE = 100;

    protected FeatureVectorBuilder<BxZone, BxPage> featureVectorBuilder;
    protected FeatureVectorScaler scaler;
    /** Feature names captured from the first training observation. */
    protected String[] features;
    private svm_parameter param;
    private svm_problem problem;
    private svm_model model;

    static {
        defaultParameter = new svm_parameter();
        defaultParameter.svm_type = svm_parameter.C_SVC;
        defaultParameter.C = 2048.0d;
        defaultParameter.kernel_type = svm_parameter.POLY;
        defaultParameter.degree = 3;
        defaultParameter.gamma = 128.0d;
        defaultParameter.coef0 = 0.5d;
        defaultParameter.nu = 0.5d;
        defaultParameter.cache_size = 100.0d;
        defaultParameter.eps = 0.001d;
        defaultParameter.p = 0.1d;
        defaultParameter.shrinking = 1;
        defaultParameter.probability = 0;
        defaultParameter.nr_weight = 0;
        defaultParameter.weight_label = new int[0];
        defaultParameter.weight = new double[0];
    }

    /**
     * Creates a classifier that scales features linearly to {@code [0, 1]} and trains with a copy
     * of the default SVM parameters.
     *
     * @param featureVectorBuilder builds the feature vector of a zone in its page context
     */
    public SVMZoneClassifier(FeatureVectorBuilder<BxZone, BxPage> featureVectorBuilder) {
        this.featureVectorBuilder = featureVectorBuilder;
        this.scaler = new FeatureVectorScaler(Integer.valueOf(featureVectorBuilder.size()),
                Double.valueOf(0.0d), Double.valueOf(1.0d));
        this.scaler.setStrategy(new LinearScaling());
        this.param = getDefaultParam();
    }

    /**
     * Copies every field of an {@code svm_parameter}.
     *
     * <p>The class-weight arrays are copied too, so mutating the copy never affects the source.
     * This matters for copies of {@link #defaultParameter}.
     *
     * @param source parameters to copy
     * @return an independent copy of {@code source}
     */
    protected static svm_parameter clone(svm_parameter source) {
        svm_parameter copy = new svm_parameter();
        copy.svm_type = source.svm_type;
        copy.C = source.C;
        copy.kernel_type = source.kernel_type;
        copy.degree = source.degree;
        copy.gamma = source.gamma;
        copy.coef0 = source.coef0;
        copy.nu = source.nu;
        copy.cache_size = source.cache_size;
        copy.eps = source.eps;
        copy.p = source.p;
        copy.shrinking = source.shrinking;
        copy.probability = source.probability;
        copy.nr_weight = source.nr_weight;
        copy.weight_label = source.weight_label == null ? null : source.weight_label.clone();
        copy.weight = source.weight == null ? null : source.weight.clone();
        return copy;
    }

    /**
     * Returns a fresh copy of the default training parameters.
     *
     * <p>The defaults are a C-SVC with a polynomial kernel (degree 3, gamma 128, coef0 0.5), C 2048.
     *
     * @return a copy that callers may modify freely
     */
    public static svm_parameter getDefaultParam() {
        return clone(defaultParameter);
    }

    /**
     * Trains the SVM model on the given labelled observations.
     *
     * <p>Feature limits for scaling are computed from the same training set.
     *
     * @param list training elements; must not be empty
     * @throws IllegalArgumentException if {@code list} is {@code null} or empty
     */
    public void buildClassifier(List<TrainingElement<BxZoneLabel>> list) {
        if (list == null || list.isEmpty()) {
            throw new IllegalArgumentException("Training set must not be empty");
        }
        if (this.features == null) {
            this.features = (String[]) list.get(0).getObservation().getFeatureNames()
                    .toArray(new String[0]);
        }
        this.scaler.setFeatureLimits(list);
        this.problem = buildDatasetForTraining(list);
        this.model = svm.svm_train(this.problem, this.param);
    }

    /**
     * Predicts the label of a single zone with the current model.
     *
     * @param bxZone zone to classify; its context is used to build features
     * @return the predicted label, decoded from the SVM target with {@link #LABEL_SCALE}
     */
    public BxZoneLabel predictZoneLabel(BxZone bxZone) {
        return toLabel(svm.svm_predict(this.model, buildDatasetForClassification(bxZone)));
    }

    /**
     * Assigns a predicted label to every zone of the document, in place.
     *
     * @param bxDocument document whose zones are relabelled
     * @return the same document instance
     */
    @Override // pl.edu.icm.yadda.analysis.textr.ZoneClassifier
    public BxDocument classifyZones(BxDocument bxDocument) throws AnalysisException {
        for (BxZone bxZone : bxDocument.asZones()) {
            bxZone.setLabel(predictZoneLabel(bxZone));
        }
        return bxDocument;
    }

    /** Encodes a zone label as the SVM training target. */
    private static double toTarget(BxZoneLabel label) {
        return label.ordinal() * LABEL_SCALE;
    }

    /** Decodes an SVM prediction produced by {@link #toTarget} back into a zone label. */
    private static BxZoneLabel toLabel(double prediction) {
        return BxZoneLabel.values()[((int) prediction) / LABEL_SCALE];
    }

    /**
     * Converts the training elements into a LIBSVM problem.
     *
     * <p>Each observation is scaled and stored as a dense row of nodes with 0-based indices. Each
     * row is sized to its own feature count, so no null nodes are left behind.
     *
     * @param list training elements
     * @return the problem with one row and one encoded target per element
     */
    protected svm_problem buildDatasetForTraining(List<TrainingElement<BxZoneLabel>> list) {
        svm_problem dataset = new svm_problem();
        dataset.l = list.size();
        dataset.x = new svm_node[dataset.l][];
        dataset.y = new double[dataset.l];
        int row = 0;
        for (TrainingElement<BxZoneLabel> element : list) {
            List<svm_node> nodes = new ArrayList<svm_node>();
            int index = 0;
            for (Double value : this.scaler.scaleFeatureVector(element.getObservation()).getFeatures()) {
                svm_node node = new svm_node();
                node.index = index++;
                node.value = value.doubleValue();
                nodes.add(node);
            }
            dataset.x[row] = nodes.toArray(new svm_node[nodes.size()]);
            dataset.y[row] = toTarget(element.getLabel());
            row++;
        }
        return dataset;
    }

    /**
     * Builds the scaled, dense node row for a single zone, matching the training row layout.
     *
     * @param bxZone zone to encode; its context is passed to the feature builder
     * @return nodes with 0-based indices, one per scaled feature
     */
    protected svm_node[] buildDatasetForClassification(BxZone bxZone) {
        List<svm_node> nodes = new ArrayList<svm_node>();
        int index = 0;
        for (Double value : this.scaler.scaleFeatureVector(
                this.featureVectorBuilder.getFeatureVector(bxZone, bxZone.getContext())).getFeatures()) {
            svm_node node = new svm_node();
            node.index = index++;
            node.value = value.doubleValue();
            nodes.add(node);
        }
        return nodes.toArray(new svm_node[nodes.size()]);
    }

    /**
     * Computes per-feature weights from the trained model's support vectors.
     *
     * <p>For every class pair (i, j), the weight of feature f is the sum over the support vectors
     * of class i of {@code sv_coef[j-1][sv] * SV[sv][f]}, plus the same sum over class j using
     * {@code sv_coef[i]}. This follows LIBSVM's one-vs-one coefficient layout.
     *
     * <p>NOTE(review): each pair overwrites the previous result, so with more than two classes only
     * the last pair (nr_class-2, nr_class-1) is returned. Confirm whether accumulation was intended.
     * The values are hyperplane weights only for a linear kernel; the default parameters use a
     * polynomial kernel.
     *
     * @return one weight per feature of the first support vector
     */
    public double[] getWeights() {
        double[][] coefficients = this.model.sv_coef;
        int supportVectorCount = this.model.SV.length;
        double[][] supportVectors = new double[supportVectorCount][this.featureVectorBuilder.size()];
        for (int sv = 0; sv < supportVectorCount; sv++) {
            for (int f = 0; f < this.model.SV[sv].length; f++) {
                supportVectors[sv][f] = this.model.SV[sv][f].value;
            }
        }

        int classCount = this.model.nr_class;
        int dimension = this.model.SV[0].length;
        // partial[c][row][f]: contribution of class c's support vectors, weighted by sv_coef[row].
        double[][][] partial = new double[classCount][classCount - 1][dimension];
        for (int f = 0; f < dimension; f++) {
            for (int row = 0; row < classCount - 1; row++) {
                int start = 0;
                for (int c = 0; c < classCount; c++) {
                    if (c > 0) {
                        start += this.model.nSV[c - 1];
                    }
                    int end = start + this.model.nSV[c];
                    double sum = 0.0d;
                    for (int sv = start; sv < end; sv++) {
                        sum += coefficients[row][sv] * supportVectors[sv][f];
                    }
                    partial[c][row][f] = sum;
                }
            }
        }

        double[] weights = new double[dimension];
        for (int i = 0; i < classCount - 1; i++) {
            for (int j = i + 1; j < classCount; j++) {
                for (int f = 0; f < dimension; f++) {
                    weights[f] = partial[i][j - 1][f] + partial[j][i][f];
                }
            }
        }
        return weights;
    }

    /**
     * Loads a model saved by {@link #saveModel(String)}.
     *
     * <p>Reads the {@code <path>.range} scaling file and the LIBSVM model file at {@code path}.
     *
     * @param str path of the LIBSVM model file; the range file is {@code str + ".range"}
     * @throws IOException if a file cannot be read or the bounds line is missing
     * @throws RuntimeException if the range file uses y scaling or bounds other than 0 and 1
     */
    @Override // pl.edu.icm.yadda.analysis.textr.ZoneClassifier
    public void loadModel(String str) throws IOException {
        Double lowerBound;
        Double upperBound;
        List<FeatureLimits> limits = new ArrayList<FeatureLimits>();
        BufferedReader reader = new BufferedReader(new FileReader(str + ".range"));
        try {
            if (reader.read() != 'x') {
                throw new RuntimeException("y scaling not supported");
            }
            reader.readLine(); // consume the rest of the "x" header line
            String boundsLine = reader.readLine();
            if (boundsLine == null) {
                throw new IOException("Missing feature bounds line in " + str + ".range");
            }
            StringTokenizer bounds = new StringTokenizer(boundsLine);
            lowerBound = Double.valueOf(Double.parseDouble(bounds.nextToken()));
            upperBound = Double.valueOf(Double.parseDouble(bounds.nextToken()));
            if (lowerBound.doubleValue() != 0.0d || upperBound.doubleValue() != 1.0d) {
                throw new RuntimeException("Feature lower bound and upper bound mustbe set in range file to resepctively 0 and 1");
            }
            String line;
            while ((line = reader.readLine()) != null) {
                StringTokenizer tokens = new StringTokenizer(line);
                Integer.parseInt(tokens.nextToken()); // feature index column: validated, not used
                limits.add(new FeatureLimits(Double.valueOf(Double.parseDouble(tokens.nextToken())),
                        Double.valueOf(Double.parseDouble(tokens.nextToken()))));
            }
        } finally {
            reader.close();
        }
        this.scaler = new FeatureVectorScaler(Integer.valueOf(limits.size()), lowerBound, upperBound);
        this.scaler.setStrategy(new LinearScaling());
        // NOTE(review): the parsed per-feature limits are still not handed to the scaler. No
        // FeatureVectorScaler setter taking List<FeatureLimits> is visible from here. Confirm the
        // scaler API and apply them, otherwise scaling after load will not match training.
        this.model = svm.svm_load_model(str);
    }

    /**
     * Saves the model to {@code str} and the scaling limits to {@code str + ".range"}.
     *
     * <p>Numbers are formatted with {@link Locale#ROOT}, so the range file always uses '.' as the
     * decimal separator and can be parsed back by {@link #loadModel(String)}.
     *
     * @param str path of the LIBSVM model file
     * @throws IOException if either file cannot be written
     */
    @Override // pl.edu.icm.yadda.analysis.textr.ZoneClassifier
    public void saveModel(String str) throws IOException {
        Formatter formatter = new Formatter(new StringBuilder(), Locale.ROOT);
        formatter.format("x\n");
        formatter.format("%.16g %.16g\n", Double.valueOf(0.0d), Double.valueOf(1.0d));
        for (int i = 0; i < this.featureVectorBuilder.size(); i++) {
            formatter.format("%d %.16g %.16g\n", Integer.valueOf(i),
                    Double.valueOf(this.scaler.getLimits()[i].getMin()),
                    Double.valueOf(this.scaler.getLimits()[i].getMax()));
        }
        BufferedWriter writer = new BufferedWriter(new FileWriter(str + ".range"));
        try {
            writer.write(formatter.toString());
        } finally {
            writer.close();
        }
        svm.svm_save_model(str, this.model);
    }

    /**
     * Prints each feature name with its weight from {@link #getWeights()} to standard output.
     *
     * @param featureVectorBuilder unused; the names come from this classifier's own builder
     */
    public void printWeigths(FeatureVectorBuilder<BxZone, BxPage> featureVectorBuilder) {
        Set<String> featureNames = this.featureVectorBuilder.getFeatureNames();
        double[] weights = getWeights();
        assert featureNames.size() == weights.length;
        int i = 0;
        for (String name : featureNames) {
            if (i >= weights.length) {
                break;
            }
            System.out.println(name + " " + weights[i]);
            i++;
        }
    }

    /**
     * Replaces the parameters used by the next {@link #buildClassifier(List)} call.
     *
     * @param svm_parameterVar parameters to use; stored by reference
     */
    public void setParameter(svm_parameter svm_parameterVar) {
        this.param = svm_parameterVar;
    }
}
