package defpackage;

import edu.umass.cs.mallet.base.fst.Transducer;
import java.io.BufferedReader;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.StringTokenizer;
import java.util.Vector;
import libsvm.svm;
import libsvm.svm_model;
import libsvm.svm_node;
import libsvm.svm_parameter;
import libsvm.svm_print_interface;
import libsvm.svm_problem;
import org.springframework.asm.Opcodes;

/* loaded from: input_file:svm_train.class */
/**
 * Command-line front end for libsvm training.
 *
 * <p>Reads a training file in which each line is a label followed by {@code index:value}
 * pairs. Depending on the {@code -v} option, it then either trains a model and saves it
 * with {@code svm.svm_save_model}, or runs n-fold cross-validation and prints the result.
 *
 * <p>Usage errors and several input-format errors print a message and call
 * {@code System.exit(1)}.
 */
class svm_train {
    // svm_type / kernel_type codes, as enumerated in the usage text of exit_with_help().
    private static final int C_SVC = 0;
    private static final int EPSILON_SVR = 3;
    private static final int NU_SVR = 4;
    private static final int RBF = 2;
    private static final int PRECOMPUTED = 4;

    private svm_parameter param;
    private svm_problem prob;
    private svm_model model;
    private String input_file_name;
    private String model_file_name;
    private String error_msg;
    private int cross_validation;
    private int nr_fold;

    /** Print sink installed by {@code -q}: discards every string libsvm asks to print. */
    private static final svm_print_interface svm_print_null = new svm_print_interface() {
        @Override
        public void print(String str) {
        }
    };

    svm_train() {
    }

    /** Prints the usage text to stdout and terminates with exit status 1. */
    private static void exit_with_help() {
        System.out.print(
                "Usage: svm_train [options] training_set_file [model_file]\n"
                + "options:\n"
                + "-s svm_type : set type of SVM (default 0)\n"
                + "\t0 -- C-SVC\t\t(multi-class classification)\n"
                + "\t1 -- nu-SVC\t\t(multi-class classification)\n"
                + "\t2 -- one-class SVM\n"
                + "\t3 -- epsilon-SVR\t(regression)\n"
                + "\t4 -- nu-SVR\t\t(regression)\n"
                + "-t kernel_type : set type of kernel function (default 2)\n"
                + "\t0 -- linear: u'*v\n"
                + "\t1 -- polynomial: (gamma*u'*v + coef0)^degree\n"
                + "\t2 -- radial basis function: exp(-gamma*|u-v|^2)\n"
                + "\t3 -- sigmoid: tanh(gamma*u'*v + coef0)\n"
                + "\t4 -- precomputed kernel (kernel values in training_set_file)\n"
                + "-d degree : set degree in kernel function (default 3)\n"
                + "-g gamma : set gamma in kernel function (default 1/num_features)\n"
                + "-r coef0 : set coef0 in kernel function (default 0)\n"
                + "-c cost : set the parameter C of C-SVC, epsilon-SVR, and nu-SVR (default 1)\n"
                + "-n nu : set the parameter nu of nu-SVC, one-class SVM, and nu-SVR (default 0.5)\n"
                + "-p epsilon : set the epsilon in loss function of epsilon-SVR (default 0.1)\n"
                + "-m cachesize : set cache memory size in MB (default 100)\n"
                + "-e epsilon : set tolerance of termination criterion (default 0.001)\n"
                + "-h shrinking : whether to use the shrinking heuristics, 0 or 1 (default 1)\n"
                + "-b probability_estimates : whether to train a SVC or SVR model for probability estimates, 0 or 1 (default 0)\n"
                + "-wi weight : set the parameter C of class i to weight*C, for C-SVC (default 1)\n"
                + "-v n : n-fold cross validation mode\n"
                + "-q : quiet mode (no outputs)\n");
        System.exit(1);
    }

    /**
     * Runs {@code nr_fold}-fold cross-validation on {@link #prob} and prints the result.
     *
     * <p>For every svm_type other than epsilon-SVR and nu-SVR, it prints the percentage of
     * exactly matching predictions. For those two regression types, it prints the mean
     * squared error and the squared correlation coefficient between predictions and labels.
     */
    private void do_cross_validation() {
        int total = prob.l;
        double[] target = new double[total];
        svm.svm_cross_validation(prob, param, nr_fold, target);

        if (param.svm_type != EPSILON_SVR && param.svm_type != NU_SVR) {
            int correct = 0;
            for (int k = 0; k < total; k++) {
                if (target[k] == prob.y[k]) {
                    correct++;
                }
            }
            System.out.print("Cross Validation Accuracy = " + ((100.0d * correct) / total) + "%\n");
            return;
        }

        double totalError = 0.0d;
        double sumV = 0.0d;
        double sumY = 0.0d;
        double sumVV = 0.0d;
        double sumYY = 0.0d;
        double sumVY = 0.0d;
        for (int k = 0; k < total; k++) {
            double y = prob.y[k];
            double v = target[k];
            totalError += (v - y) * (v - y);
            sumV += v;
            sumY += y;
            sumVV += v * v;
            sumYY += y * y;
            sumVY += v * y;
        }
        double cov = (total * sumVY) - (sumV * sumY);
        double varV = (total * sumVV) - (sumV * sumV);
        double varY = (total * sumYY) - (sumY * sumY);
        System.out.print("Cross Validation Mean squared error = " + (totalError / total) + "\n");
        System.out.print("Cross Validation Squared correlation coefficient = "
                + ((cov * cov) / (varV * varY)) + "\n");
    }

    /**
     * Parses arguments, loads the training file, and validates the parameters.
     * It then either cross-validates or trains a model and saves it to {@link #model_file_name}.
     */
    private void run(String[] argv) throws IOException {
        parse_command_line(argv);
        read_problem();
        error_msg = svm.svm_check_parameter(prob, param);
        if (error_msg != null) {
            System.err.print("ERROR: " + error_msg + "\n");
            System.exit(1);
        }
        if (cross_validation != 0) {
            do_cross_validation();
        } else {
            model = svm.svm_train(prob, param);
            svm.svm_save_model(model_file_name, model);
        }
    }

    public static void main(String[] argv) throws IOException {
        new svm_train().run(argv);
    }

    /**
     * Parses a double, exiting with status 1 if the value is NaN or infinite.
     *
     * @throws NumberFormatException if {@code str} is not a valid number
     */
    private static double atof(String str) {
        double value = Double.parseDouble(str);
        if (Double.isNaN(value) || Double.isInfinite(value)) {
            System.err.print("NaN or Infinity in input\n");
            System.exit(1);
        }
        return value;
    }

    /** @throws NumberFormatException if {@code str} is not a valid int */
    private static int atoi(String str) {
        return Integer.parseInt(str);
    }

    /**
     * Fills {@link #param}, {@link #cross_validation}, {@link #nr_fold},
     * {@link #input_file_name} and {@link #model_file_name} from the command line.
     *
     * <p>Options come first. Each is {@code -<letter>} followed by a value, except
     * {@code -q}, which takes none. Then come the training file and an optional model file.
     * If the model file is omitted, it defaults to the training file's basename plus
     * {@code ".model"}.
     */
    private void parse_command_line(String[] argv) {
        // Stays null unless -q is given; handed to svm.svm_set_print_string_function below.
        svm_print_interface print_func = null;

        param = new svm_parameter();
        param.svm_type = C_SVC;
        param.kernel_type = RBF;
        param.degree = 3;
        param.gamma = 0.0d; // 0 is a sentinel: read_problem replaces it with 1/max_index
        param.coef0 = 0.0d;
        param.nu = 0.5d;
        param.cache_size = 100.0d;
        param.C = 1.0d;
        param.eps = 0.001d;
        param.p = 0.1d;
        param.shrinking = 1;
        param.probability = 0;
        param.nr_weight = 0;
        param.weight_label = new int[0];
        param.weight = new double[0];
        cross_validation = 0;

        int i;
        for (i = 0; i < argv.length; i++) {
            String opt = argv[i];
            if (opt.length() == 0 || opt.charAt(0) != '-') {
                break; // first positional argument: the training file
            }
            if (++i >= argv.length) {
                exit_with_help();
            }
            // A bare "-" has no option letter; '\0' routes it to "Unknown option".
            char letter = opt.length() > 1 ? opt.charAt(1) : '\0';
            switch (letter) {
                case 'b':
                    param.probability = atoi(argv[i]);
                    break;
                case 'c':
                    param.C = atof(argv[i]);
                    break;
                case 'd':
                    param.degree = atoi(argv[i]);
                    break;
                case 'e':
                    param.eps = atof(argv[i]);
                    break;
                case 'g':
                    param.gamma = atof(argv[i]);
                    break;
                case 'h':
                    param.shrinking = atoi(argv[i]);
                    break;
                case 'm':
                    param.cache_size = atof(argv[i]);
                    break;
                case 'n':
                    param.nu = atof(argv[i]);
                    break;
                case 'p':
                    param.p = atof(argv[i]);
                    break;
                case 'q':
                    print_func = svm_print_null;
                    i--; // -q takes no value, so do not consume the next argument
                    break;
                case 'r':
                    param.coef0 = atof(argv[i]);
                    break;
                case 's':
                    param.svm_type = atoi(argv[i]);
                    break;
                case 't':
                    param.kernel_type = atoi(argv[i]);
                    break;
                case 'v':
                    cross_validation = 1;
                    nr_fold = atoi(argv[i]);
                    if (nr_fold < 2) {
                        System.err.print("n-fold cross validation: n must >= 2\n");
                        exit_with_help();
                    }
                    break;
                case 'w': {
                    // -w<label> <weight>: append one (label, weight) pair to the arrays.
                    int n = ++param.nr_weight;
                    int[] oldLabels = param.weight_label;
                    param.weight_label = new int[n];
                    System.arraycopy(oldLabels, 0, param.weight_label, 0, n - 1);
                    double[] oldWeights = param.weight;
                    param.weight = new double[n];
                    System.arraycopy(oldWeights, 0, param.weight, 0, n - 1);
                    param.weight_label[n - 1] = atoi(opt.substring(2));
                    param.weight[n - 1] = atof(argv[i]);
                    break;
                }
                default:
                    System.err.print("Unknown option: " + opt + "\n");
                    exit_with_help();
                    break;
            }
        }

        svm.svm_set_print_string_function(print_func);

        if (i >= argv.length) {
            exit_with_help();
        }
        input_file_name = argv[i];
        if (i < argv.length - 1) {
            model_file_name = argv[i + 1];
        } else {
            // Relative path: the training file's basename (after the last '/') plus ".model".
            model_file_name = argv[i].substring(argv[i].lastIndexOf('/') + 1) + ".model";
        }
    }

    /**
     * Loads {@link #input_file_name} into {@link #prob}.
     *
     * <p>Each line is tokenized on whitespace and ':'. The first token is the label and the
     * remaining tokens are read as (index, value) pairs. If gamma is still 0, it is set to
     * 1/max_index, where max_index is the largest final index seen on any line. For the
     * precomputed kernel, each row must start with {@code 0:serial} and
     * 1 &lt;= serial &lt;= max_index.
     *
     * @throws IOException if the file cannot be opened or read
     */
    private void read_problem() throws IOException {
        List<Double> labels = new ArrayList<Double>();
        List<svm_node[]> rows = new ArrayList<svm_node[]>();
        int max_index = 0;

        BufferedReader reader = new BufferedReader(new FileReader(input_file_name));
        try {
            String line;
            int lineNo = 0;
            while ((line = reader.readLine()) != null) {
                lineNo++;
                StringTokenizer st = new StringTokenizer(line, " \t\n\r\f:");
                if (!st.hasMoreTokens()) {
                    // Blank line: report it instead of failing with NoSuchElementException.
                    System.err.print("Wrong input format at line " + lineNo + "\n");
                    System.exit(1);
                }
                labels.add(atof(st.nextToken()));

                // A trailing unpaired token is ignored (integer division).
                int pairs = st.countTokens() / 2;
                svm_node[] row = new svm_node[pairs];
                for (int j = 0; j < pairs; j++) {
                    svm_node node = new svm_node();
                    node.index = atoi(st.nextToken());
                    node.value = atof(st.nextToken());
                    row[j] = node;
                }
                if (pairs > 0) {
                    // Assumes indices are ascending within a line, so the last one is the largest.
                    max_index = Math.max(max_index, row[pairs - 1].index);
                }
                rows.add(row);
            }
        } finally {
            reader.close();
        }

        prob = new svm_problem();
        prob.l = labels.size();
        prob.x = new svm_node[prob.l][];
        prob.y = new double[prob.l];
        for (int k = 0; k < prob.l; k++) {
            prob.x[k] = rows.get(k);
            prob.y[k] = labels.get(k);
        }

        if (param.gamma == 0.0d && max_index > 0) {
            param.gamma = 1.0d / max_index;
        }

        if (param.kernel_type == PRECOMPUTED) {
            for (int k = 0; k < prob.l; k++) {
                svm_node[] row = prob.x[k];
                // An empty row has no 0:serial column; report it rather than index out of bounds.
                if (row.length == 0 || row[0].index != 0) {
                    System.err.print("Wrong kernel matrix: first column must be 0:sample_serial_number\n");
                    System.exit(1);
                }
                int serial = (int) row[0].value;
                if (serial <= 0 || serial > max_index) {
                    System.err.print("Wrong input format: sample_serial_number out of range\n");
                    System.exit(1);
                }
            }
        }
    }
}
