package org.apache.mahout.clustering.lda;

import java.io.File;
import java.io.FileWriter;
import java.io.IOException;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.PriorityQueue;
import java.util.Queue;
import org.apache.commons.cli2.CommandLine;
import org.apache.commons.cli2.Group;
import org.apache.commons.cli2.OptionException;
import org.apache.commons.cli2.builder.ArgumentBuilder;
import org.apache.commons.cli2.builder.DefaultOptionBuilder;
import org.apache.commons.cli2.builder.GroupBuilder;
import org.apache.commons.cli2.commandline.Parser;
import org.apache.commons.cli2.option.DefaultOption;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileStatus;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.DoubleWritable;
import org.apache.hadoop.io.SequenceFile;
import org.apache.mahout.common.CommandLineUtil;
import org.apache.mahout.common.IntPairWritable;
import org.apache.mahout.utils.clustering.ClusterDumper;
import org.apache.mahout.utils.vectors.VectorHelper;

/* loaded from: input_file:org/apache/mahout/clustering/lda/LDAPrintTopics.class */
/**
 * Command-line tool that reads the sequence files of an LDA state directory and reports, for
 * every topic, the highest-scoring dictionary words. The result is either printed to standard
 * output or written as one {@code topic-N} file per topic into an output directory.
 */
public final class LDAPrintTopics {

    /**
     * Immutable (score, word) pair whose natural ordering is ascending by score, so that the
     * head of a {@link PriorityQueue} of pairs is always the lowest-scoring entry.
     */
    public static class StringDoublePair implements Comparable<StringDoublePair> {
        private final double score;
        private final String word;

        StringDoublePair(double score, String word) {
            this.score = score;
            this.word = word;
        }

        /** Orders pairs by ascending score only; the word is ignored. */
        @Override
        public int compareTo(StringDoublePair stringDoublePair) {
            return Double.compare(this.score, stringDoublePair.score);
        }

        /** Two pairs are equal when both their score and their word are equal. */
        @Override
        public boolean equals(Object obj) {
            if (!(obj instanceof StringDoublePair)) {
                return false;
            }
            StringDoublePair other = (StringDoublePair) obj;
            // Double.compare (not ==) keeps equals reflexive for NaN and in agreement with
            // hashCode for 0.0 / -0.0, both of which are based on the same bit pattern.
            return Double.compare(this.score, other.score) == 0 && this.word.equals(other.word);
        }

        @Override
        public int hashCode() {
            long bits = Double.doubleToLongBits(this.score);
            // Fold the high word in: the low 32 bits alone are zero for many common doubles.
            return ((int) (bits ^ (bits >>> 32))) ^ this.word.hashCode();
        }
    }

    /** Not instantiable: this class only hosts static entry points. */
    private LDAPrintTopics() {
    }

    /**
     * Appends empty queues until {@code queues} holds at least {@code index + 1} elements, so
     * that position {@code index} can be read. Does nothing when {@code index} is negative or
     * already covered.
     */
    private static void ensureQueueSize(Collection<PriorityQueue<StringDoublePair>> queues, int index) {
        for (int size = queues.size(); size <= index; size++) {
            queues.add(new PriorityQueue<>());
        }
    }

    /**
     * Parses the command line, loads the dictionary, collects the top words of every topic and
     * prints them or writes them to the output directory.
     *
     * @param strArr command-line arguments; {@code --input} and {@code --dict} are required
     * @throws OptionException if the arguments cannot be parsed (usage is printed first)
     * @throws IllegalArgumentException if the dictionary type is neither {@code text} nor
     *         {@code sequencefile}
     * @throws IOException if the output directory cannot be created or data cannot be read or
     *         written
     */
    public static void main(String[] strArr) throws Exception {
        DefaultOptionBuilder optionBuilder = new DefaultOptionBuilder();
        ArgumentBuilder argumentBuilder = new ArgumentBuilder();
        GroupBuilder groupBuilder = new GroupBuilder();

        DefaultOption inputOpt = optionBuilder.withLongName("input").withRequired(true)
                .withArgument(argumentBuilder.withName("input").withMinimum(1).withMaximum(1).create())
                .withDescription("Path to an LDA output (a state)").withShortName("i").create();
        DefaultOption dictOpt = optionBuilder.withLongName("dict").withRequired(true)
                .withArgument(argumentBuilder.withName("dict").withMinimum(1).withMaximum(1).create())
                .withDescription("Dictionary to read in, in the same format as one created by org.apache.mahout.utils.vectors.lucene.Driver")
                .withShortName("d").create();
        DefaultOption outputOpt = optionBuilder.withLongName(ClusterDumper.OUTPUT_OPTION).withRequired(false)
                .withArgument(argumentBuilder.withName(ClusterDumper.OUTPUT_OPTION).withMinimum(1).withMaximum(1).create())
                .withDescription("Output directory to write top words").withShortName("o").create();
        DefaultOption wordsOpt = optionBuilder.withLongName("words").withRequired(false)
                .withArgument(argumentBuilder.withName("words").withMinimum(0).withMaximum(1).withDefault("20").create())
                .withDescription("Number of words to print").withShortName("w").create();
        DefaultOption dictTypeOpt = optionBuilder.withLongName(ClusterDumper.DICTIONARY_TYPE_OPTION).withRequired(false)
                .withArgument(argumentBuilder.withName(ClusterDumper.DICTIONARY_TYPE_OPTION).withMinimum(1).withMaximum(1).create())
                .withDescription("The dictionary file type (text|sequencefile)").withShortName("dt").create();
        DefaultOption helpOpt = optionBuilder.withLongName("help").withDescription("Print out help")
                .withShortName("h").create();

        // The help option must be part of the group, otherwise the parser rejects -h/--help as
        // an unknown option and the hasOption(helpOpt) branch below can never be taken.
        Group group = groupBuilder.withName("Options").withOption(dictOpt).withOption(outputOpt)
                .withOption(wordsOpt).withOption(inputOpt).withOption(dictTypeOpt).withOption(helpOpt).create();
        try {
            Parser parser = new Parser();
            parser.setGroup(group);
            CommandLine commandLine = parser.parse(strArr);
            if (commandLine.hasOption(helpOpt)) {
                CommandLineUtil.printHelp(group);
                return;
            }
            String input = commandLine.getValue(inputOpt).toString();
            String dictFile = commandLine.getValue(dictOpt).toString();
            int numWords = 20;
            if (commandLine.hasOption(wordsOpt)) {
                numWords = Integer.parseInt(commandLine.getValue(wordsOpt).toString());
            }
            Configuration configuration = new Configuration();
            String dictionaryType = commandLine.hasOption(dictTypeOpt)
                    ? commandLine.getValue(dictTypeOpt).toString()
                    : "text";
            List<String> dictionary;
            if ("text".equals(dictionaryType)) {
                dictionary = Arrays.asList(VectorHelper.loadTermDictionary(new File(dictFile)));
            } else {
                if (!"sequencefile".equals(dictionaryType)) {
                    throw new IllegalArgumentException("Invalid dictionary format");
                }
                dictionary = Arrays.asList(VectorHelper.loadTermDictionary(configuration,
                        FileSystem.get(new Path(dictFile).toUri(), configuration), dictFile));
            }
            List<List<String>> topWords = topWordsForTopics(input, configuration, dictionary, numWords);
            if (commandLine.hasOption(outputOpt)) {
                File outputDir = new File(commandLine.getValue(outputOpt).toString());
                if (!outputDir.exists() && !outputDir.mkdirs()) {
                    throw new IOException("Could not create directory: " + outputDir);
                }
                writeTopWords(topWords, outputDir);
            } else {
                printTopWords(topWords);
            }
        } catch (OptionException e) {
            // Show usage for any parse failure, then let the caller see the original error.
            CommandLineUtil.printHelp(group);
            throw e;
        }
    }

    /**
     * Offers a (score, word) pair to a bounded min-queue that keeps at most {@code limit}
     * entries. When the queue is full the lowest-scoring entry is evicted, but only if the new
     * score is strictly greater.
     */
    private static void maybeEnqueue(Queue<StringDoublePair> queue, String word, double score, int limit) {
        if (limit <= 0) {
            // Nothing can be kept; without this guard peek() on the empty queue returns null
            // and dereferencing it throws NullPointerException.
            return;
        }
        if (queue.size() >= limit && score > queue.peek().score) {
            queue.poll();
        }
        if (queue.size() < limit) {
            queue.add(new StringDoublePair(score, word));
        }
    }

    /** Prints every topic's words to standard output under a "Topic N" header. */
    private static void printTopWords(List<List<String>> topWords) {
        for (int topic = 0; topic < topWords.size(); topic++) {
            System.out.println("Topic " + topic);
            System.out.println("===========");
            for (String word : topWords.get(topic)) {
                System.out.println(word);
            }
        }
    }

    /**
     * Reads every {@code part-*} sequence file under {@code dir} and returns, for each topic
     * index, up to {@code numWords} dictionary words ordered from highest to lowest score.
     *
     * @param dir directory holding the {@code part-*} files with (topic, word) keys and score
     *        values
     * @param configuration Hadoop configuration used to open the file system and the readers
     * @param dictionary maps a word index to its string
     * @param numWords maximum number of words kept per topic
     * @return one list per topic index, position {@code t} holding the words of topic {@code t}
     * @throws IOException if a file cannot be read
     */
    private static List<List<String>> topWordsForTopics(String dir, Configuration configuration,
            List<String> dictionary, int numWords) throws IOException {
        FileSystem fileSystem = new Path(dir).getFileSystem(configuration);
        List<PriorityQueue<StringDoublePair>> queues = new ArrayList<>();
        IntPairWritable key = new IntPairWritable();
        DoubleWritable value = new DoubleWritable();
        for (FileStatus status : fileSystem.globStatus(new Path(dir, "part-*"))) {
            SequenceFile.Reader reader = new SequenceFile.Reader(fileSystem, status.getPath(), configuration);
            try {
                while (reader.next(key, value)) {
                    int topic = key.getFirst();
                    int wordIndex = key.getSecond();
                    ensureQueueSize(queues, topic);
                    // Entries with a negative topic or word index are skipped
                    // (presumably bookkeeping records rather than real words -- TODO confirm).
                    if (wordIndex >= 0 && topic >= 0) {
                        maybeEnqueue(queues.get(topic), dictionary.get(wordIndex), value.get(), numWords);
                    }
                }
            } finally {
                // Release the reader even when reading or the dictionary lookup fails.
                reader.close();
            }
        }
        List<List<String>> result = new ArrayList<>(queues.size());
        for (PriorityQueue<StringDoublePair> queue : queues) {
            LinkedList<String> words = new LinkedList<>();
            // Drain with poll(): the queue's iterator makes no ordering guarantee, whereas
            // poll() yields ascending scores, so prepending leaves the best word first.
            StringDoublePair pair;
            while ((pair = queue.poll()) != null) {
                words.addFirst(pair.word);
            }
            result.add(words);
        }
        return result;
    }

    /**
     * Writes each topic's words to the file {@code topic-N} inside {@code outputDir}, using the
     * same layout as the console output.
     *
     * @throws IOException if a file cannot be opened or a write fails
     */
    private static void writeTopWords(List<List<String>> topWords, File outputDir) throws IOException {
        for (int topic = 0; topic < topWords.size(); topic++) {
            File topicFile = new File(outputDir, "topic-" + topic);
            try (PrintWriter out = new PrintWriter(new FileWriter(topicFile))) {
                out.println("Topic " + topic);
                out.println("===========");
                for (String word : topWords.get(topic)) {
                    out.println(word);
                }
                // PrintWriter never throws on write; checkError() flushes and reports failures.
                if (out.checkError()) {
                    throw new IOException("Failed to write top words to " + topicFile);
                }
            }
        }
    }
}
