package org.apache.mahout.classifier.sgd;

import com.google.common.base.Charsets;
import com.google.common.collect.ConcurrentHashMultiset;
import com.google.common.collect.HashMultiset;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.google.common.collect.Multiset;
import com.google.common.collect.Ordering;
import com.google.common.io.Files;
import java.io.BufferedReader;
import java.io.File;
import java.io.IOException;
import java.io.Reader;
import java.io.StringReader;
import java.text.SimpleDateFormat;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.Date;
import java.util.Iterator;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Random;
import java.util.Set;
import java.util.TreeMap;
import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.analysis.TokenStream;
import org.apache.lucene.analysis.standard.StandardAnalyzer;
import org.apache.lucene.analysis.tokenattributes.CharTermAttribute;
import org.apache.lucene.util.Version;
import org.apache.mahout.classifier.sgd.ModelDissector;
import org.apache.mahout.common.RandomUtils;
import org.apache.mahout.ep.State;
import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.RandomAccessSparseVector;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.function.DoubleFunction;
import org.apache.mahout.math.function.Functions;
import org.apache.mahout.vectorizer.encoders.ConstantValueEncoder;
import org.apache.mahout.vectorizer.encoders.Dictionary;
import org.apache.mahout.vectorizer.encoders.FeatureVectorEncoder;
import org.apache.mahout.vectorizer.encoders.StaticWordValueEncoder;

/* loaded from: input_file:org/apache/mahout/classifier/sgd/TrainNewsGroups.class */
public final class TrainNewsGroups {
    private static final int FEATURES = 10000;
    private static final long DATE_REFERENCE = 853286460;
    private static final long MONTH = 2592000;
    private static final long WEEK = 604800;
    private static final Random rand = RandomUtils.getRandom();
    private static final String[] LEAK_LABELS = {"none", "month-year", "day-month-year"};
    private static final SimpleDateFormat[] DATE_FORMATS = {new SimpleDateFormat("", Locale.ENGLISH), new SimpleDateFormat("MMM-yyyy", Locale.ENGLISH), new SimpleDateFormat("dd-MMM-yyyy HH:mm:ss", Locale.ENGLISH)};
    private static final Analyzer analyzer = new StandardAnalyzer(Version.LUCENE_30);
    private static final FeatureVectorEncoder encoder = new StaticWordValueEncoder("body");
    private static final FeatureVectorEncoder bias = new ConstantValueEncoder("Intercept");
    private static Multiset<String> overallCounts;

    private TrainNewsGroups() {
    }

    public static void main(String[] strArr) throws IOException {
        double d;
        double d2;
        double d3;
        double d4;
        File file = new File(strArr[0]);
        overallCounts = HashMultiset.create();
        int parseInt = strArr.length > 1 ? Integer.parseInt(strArr[1]) : 0;
        Dictionary dictionary = new Dictionary();
        encoder.setProbes(2);
        AdaptiveLogisticRegression adaptiveLogisticRegression = new AdaptiveLogisticRegression(20, FEATURES, new L1());
        adaptiveLogisticRegression.setInterval(800);
        adaptiveLogisticRegression.setAveragingWindow(500);
        ArrayList newArrayList = Lists.newArrayList();
        for (File file2 : file.listFiles()) {
            if (file2.isDirectory()) {
                dictionary.intern(file2.getName());
                newArrayList.addAll(Arrays.asList(file2.listFiles()));
            }
        }
        Collections.shuffle(newArrayList);
        System.out.printf("%d training files\n", Integer.valueOf(newArrayList.size()));
        double d5 = 0.0d;
        double d6 = 0.0d;
        int i = 0;
        double d7 = 0.0d;
        int[] iArr = {1, 2, 5};
        for (File file3 : newArrayList.subList(0, 3000)) {
            int intern = dictionary.intern(file3.getParentFile().getName());
            adaptiveLogisticRegression.train(intern, encodeFeatureVector(file3, intern, parseInt));
            i++;
            int i2 = iArr[((int) Math.floor(d7)) % iArr.length];
            int pow = (int) Math.pow(10.0d, Math.floor(d7 / iArr.length));
            State best = adaptiveLogisticRegression.getBest();
            double d8 = 0.0d;
            double d9 = 0.0d;
            if (best != null) {
                CrossFoldLearner learner = best.getPayload().getLearner();
                d6 = learner.percentCorrect();
                d5 = learner.logLikelihood();
                OnlineLogisticRegression onlineLogisticRegression = (OnlineLogisticRegression) learner.getModels().get(0);
                onlineLogisticRegression.close();
                Matrix beta = onlineLogisticRegression.getBeta();
                d = beta.aggregate(Functions.MAX, Functions.ABS);
                d2 = beta.aggregate(Functions.PLUS, new DoubleFunction() { // from class: org.apache.mahout.classifier.sgd.TrainNewsGroups.1
                    public double apply(double d10) {
                        return Math.abs(d10) > 1.0E-6d ? 1.0d : 0.0d;
                    }
                });
                d3 = beta.aggregate(Functions.PLUS, new DoubleFunction() { // from class: org.apache.mahout.classifier.sgd.TrainNewsGroups.2
                    public double apply(double d10) {
                        return d10 > 0.0d ? 1.0d : 0.0d;
                    }
                });
                d4 = beta.aggregate(Functions.PLUS, Functions.ABS);
                d8 = adaptiveLogisticRegression.getBest().getMappedParams()[0];
                d9 = adaptiveLogisticRegression.getBest().getMappedParams()[1];
            } else {
                d = 0.0d;
                d2 = 0.0d;
                d3 = 0.0d;
                d4 = 0.0d;
            }
            if (i % (i2 * pow) == 0) {
                if (adaptiveLogisticRegression.getBest() != null) {
                    ModelSerializer.writeBinary("/tmp/news-group-" + i + ".model", (OnlineLogisticRegression) adaptiveLogisticRegression.getBest().getPayload().getLearner().getModels().get(0));
                }
                d7 += 0.25d;
                System.out.printf("%.2f\t%.2f\t%.2f\t%.2f\t%.8g\t%.8g\t", Double.valueOf(d), Double.valueOf(d2), Double.valueOf(d3), Double.valueOf(d4), Double.valueOf(d8), Double.valueOf(d9));
                System.out.printf("%d\t%.3f\t%.2f\t%s\n", Integer.valueOf(i), Double.valueOf(d5), Double.valueOf(d6 * 100.0d), LEAK_LABELS[parseInt % 3]);
            }
        }
        adaptiveLogisticRegression.close();
        dissect(parseInt, dictionary, adaptiveLogisticRegression, newArrayList);
        System.out.println("exiting main");
        ModelSerializer.writeBinary("/tmp/news-group.model", (OnlineLogisticRegression) adaptiveLogisticRegression.getBest().getPayload().getLearner().getModels().get(0));
        ArrayList newArrayList2 = Lists.newArrayList();
        System.out.printf("Word counts\n", new Object[0]);
        Iterator it = overallCounts.elementSet().iterator();
        while (it.hasNext()) {
            newArrayList2.add(Integer.valueOf(overallCounts.count((String) it.next())));
        }
        Collections.sort(newArrayList2, Ordering.natural().reverse());
        int i3 = 0;
        Iterator it2 = newArrayList2.iterator();
        while (it2.hasNext()) {
            System.out.printf("%d\t%d\n", Integer.valueOf(i3), (Integer) it2.next());
            i3++;
            if (i3 > 1000) {
                return;
            }
        }
    }

    private static void dissect(int i, Dictionary dictionary, AdaptiveLogisticRegression adaptiveLogisticRegression, Iterable<File> iterable) throws IOException {
        CrossFoldLearner learner = adaptiveLogisticRegression.getBest().getPayload().getLearner();
        learner.close();
        TreeMap newTreeMap = Maps.newTreeMap();
        ModelDissector modelDissector = new ModelDissector();
        encoder.setTraceDictionary(newTreeMap);
        bias.setTraceDictionary(newTreeMap);
        for (File file : permute(iterable, rand).subList(0, 500)) {
            int intern = dictionary.intern(file.getParentFile().getName());
            newTreeMap.clear();
            modelDissector.update(encodeFeatureVector(file, intern, i), newTreeMap, learner);
        }
        ArrayList newArrayList = Lists.newArrayList(dictionary.values());
        for (ModelDissector.Weight weight : modelDissector.summary(100)) {
            System.out.printf("%s\t%.1f\t%s\t%.1f\t%s\t%.1f\t%s\n", weight.getFeature(), Double.valueOf(weight.getWeight()), newArrayList.get(weight.getMaxImpact() + 1), Double.valueOf(weight.getCategory(1)), Double.valueOf(weight.getWeight(1)), Double.valueOf(weight.getCategory(2)), Double.valueOf(weight.getWeight(2)));
        }
    }

    private static Vector encodeFeatureVector(File file, int i, int i2) throws IOException {
        long nextDouble = (long) (1000.0d * (DATE_REFERENCE + (i * MONTH) + (604800.0d * rand.nextDouble())));
        ConcurrentHashMultiset create = ConcurrentHashMultiset.create();
        BufferedReader newReader = Files.newReader(file, Charsets.UTF_8);
        try {
            String readLine = newReader.readLine();
            countWords(analyzer, create, new StringReader(DATE_FORMATS[i2 % 3].format(new Date(nextDouble))));
            while (readLine != null && readLine.length() > 0) {
                boolean z = (readLine.startsWith("From:") || readLine.startsWith("Subject:") || readLine.startsWith("Keywords:") || readLine.startsWith("Summary:")) && i2 < 6;
                do {
                    StringReader stringReader = new StringReader(readLine);
                    if (z) {
                        countWords(analyzer, create, stringReader);
                    }
                    readLine = newReader.readLine();
                    if (readLine != null) {
                    }
                } while (readLine.startsWith(" "));
            }
            if (i2 < 3) {
                countWords(analyzer, create, newReader);
            }
            RandomAccessSparseVector randomAccessSparseVector = new RandomAccessSparseVector(FEATURES);
            bias.addToVector("", 1.0d, randomAccessSparseVector);
            Iterator it = create.elementSet().iterator();
            while (it.hasNext()) {
                encoder.addToVector((String) it.next(), Math.log(1 + create.count(r0)), randomAccessSparseVector);
            }
            return randomAccessSparseVector;
        } finally {
            newReader.close();
        }
    }

    private static void countWords(Analyzer analyzer2, Collection<String> collection, Reader reader) throws IOException {
        TokenStream tokenStream = analyzer2.tokenStream("text", reader);
        tokenStream.addAttribute(CharTermAttribute.class);
        while (tokenStream.incrementToken()) {
            collection.add(tokenStream.getAttribute(CharTermAttribute.class).toString());
        }
        overallCounts.addAll(collection);
    }

    private static List<File> permute(Iterable<File> iterable, Random random) {
        ArrayList newArrayList = Lists.newArrayList();
        for (File file : iterable) {
            int nextInt = random.nextInt(newArrayList.size() + 1);
            if (nextInt == newArrayList.size()) {
                newArrayList.add(file);
            } else {
                newArrayList.add(newArrayList.get(nextInt));
                newArrayList.set(nextInt, file);
            }
        }
        return newArrayList;
    }
}
