package hivemall.nlp.tokenizer;

import hivemall.UDFWithOptions;
import hivemall.utils.hadoop.HiveUtils;
import hivemall.utils.io.HttpUtils;
import hivemall.utils.io.IOUtils;
import hivemall.utils.lang.ExceptionUtils;
import hivemall.utils.lang.Preconditions;
import io.netty.handler.codec.rtsp.RtspHeaders;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.Reader;
import java.io.StringReader;
import java.net.HttpURLConnection;
import java.nio.charset.CharsetDecoder;
import java.nio.charset.CodingErrorAction;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.Options;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.udf.UDFType;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDF;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.io.Text;
import org.apache.lucene.analysis.TokenStream;
import org.apache.lucene.analysis.ja.JapaneseAnalyzer;
import org.apache.lucene.analysis.ja.JapaneseTokenizer;
import org.apache.lucene.analysis.ja.dict.UserDictionary;
import org.apache.lucene.analysis.ja.tokenattributes.PartOfSpeechAttribute;
import org.apache.lucene.analysis.tokenattributes.CharTermAttribute;
import org.apache.lucene.analysis.util.CharArraySet;
import org.apache.lucene.codecs.lucene50.Lucene50PostingsFormat;

/**
 * Hive UDF {@code tokenize_ja}: splits a Japanese text line into tokens with Lucene's Kuromoji
 * {@link JapaneseAnalyzer}.
 *
 * <p>Arguments (everything after the first must be a constant):
 * <ol>
 * <li>the line to tokenize ({@code null} input yields {@code null});</li>
 * <li>either a tokenization mode ({@code normal}, {@code search}, {@code extended},
 * {@code default}, case-insensitive) or an option string starting with {@code -}, e.g.
 * {@code "-mode search -pos"};</li>
 * <li>stop words as {@code array<string>}: {@code null} selects the analyzer's default stop set,
 * an empty array disables stop-word removal;</li>
 * <li>stop tags as {@code array<string>}: {@code null} selects the analyzer's default stop tags,
 * an empty array disables stop-tag removal;</li>
 * <li>a user dictionary, either as {@code array<string>} of CSV rows or as a constant URL string
 * from which the CSV (UTF-8) is downloaded.</li>
 * </ol>
 *
 * <p>Returns {@code array<string>} of tokens, or, with the {@code -pos} option, a struct with
 * {@code tokens} and {@code pos} (part-of-speech) arrays of equal length.
 */
@UDFType(deterministic = true, stateful = false)
@Description(name = "tokenize_ja", value = "_FUNC_(String line [, const string mode = \"normal\", const array<string> stopWords, const array<string> stopTags, const array<string> userDict (or string userDictURL)]) - returns tokenized strings in array<string>", extended = "select tokenize_ja(\"kuromojiを使った分かち書きのテストです。第二引数にはnormal/search/extendedを指定できます。デフォルトではnormalモードです。\");\n\n> [\"kuromoji\",\"使う\",\"分かち書き\",\"テスト\",\"第\",\"二\",\"引数\",\"normal\",\"search\",\"extended\",\"指定\",\"デフォルト\",\"normal\",\" モード\"]\n")
public final class KuromojiUDF extends UDFWithOptions {
    /** Connect timeout for downloading a user dictionary from a URL. */
    private static final int CONNECT_TIMEOUT_MS = 10000;
    /** Read timeout for downloading a user dictionary from a URL. */
    private static final int READ_TIMEOUT_MS = 60000;
    /** Upper bound (32 MiB) passed to HttpUtils.getLimitedInputStream for the dictionary body. */
    private static final long MAX_INPUT_STREAM_SIZE = 32L * 1024L * 1024L;

    private JapaneseTokenizer.Mode _mode;
    /** When true, the result is a struct of tokens and part-of-speech tags. */
    private boolean _returnPos;
    /** Reusable 2-slot result holder used only when {@link #_returnPos} is set. */
    private transient Object[] _result;

    /** Raw stop words; {@code null} means "use the analyzer's default stop set". */
    @Nullable
    private String[] _stopWordsArray;
    private Set<String> _stopTags;

    /** User dictionary source: a {@code String[]} of CSV rows, a {@code String} URL, or null. */
    @Nullable
    private Object _userDictObj;
    /** Built lazily on the first {@link #evaluate} call. */
    private transient JapaneseAnalyzer _analyzer;

    @Override
    protected Options getOptions() {
        final Options opts = new Options();
        opts.addOption("mode", true,
            "The tokenization mode. One of ['normal', 'search', 'extended', 'default' (normal)]");
        opts.addOption("pos", false, "Return part-of-speech information");
        return opts;
    }

    @Override
    protected CommandLine processOptions(String optionValue) throws UDFArgumentException {
        final CommandLine cl = parseOptions(optionValue);
        if (cl.hasOption("mode")) {
            this._mode = tokenizationMode(cl.getOptionValue("mode"));
        }
        this._returnPos = cl.hasOption("pos");
        return cl;
    }

    @Override
    public ObjectInspector initialize(ObjectInspector[] arguments) throws UDFArgumentException {
        final int arglen = arguments.length;
        if (arglen < 1 || arglen > 5) {
            showHelp("Invalid number of arguments for `tokenize_ja`: " + arglen);
        }

        // Reset all configuration so a repeated initialize() never sees values from a prior call.
        this._mode = JapaneseTokenizer.Mode.NORMAL;
        this._returnPos = false;
        this._stopWordsArray = null;
        this._userDictObj = null;

        if (arglen >= 2) {
            final String arg1 = HiveUtils.getConstString(arguments[1]);
            if (arg1 != null) {
                if (arg1.startsWith("-")) {
                    processOptions(arg1);
                } else {
                    this._mode = tokenizationMode(arg1);
                }
            }
        }
        if (arglen >= 3 && !HiveUtils.isVoidOI(arguments[2])) {
            this._stopWordsArray = HiveUtils.getConstStringArray(arguments[2]);
        }
        this._stopTags =
                (arglen >= 4) ? stopTags(arguments[3]) : JapaneseAnalyzer.getDefaultStopTags();
        if (arglen >= 5) {
            if (HiveUtils.isConstListOI(arguments[4])) {
                this._userDictObj = HiveUtils.getConstStringArray(arguments[4]);
            } else if (HiveUtils.isConstString(arguments[4])) {
                this._userDictObj = HiveUtils.getConstString(arguments[4]);
            } else {
                throw new UDFArgumentException(
                    "User dictionary MUST be given as an array of constant string or constant string (URL)");
            }
        }

        // Drop (and release) any analyzer built from a previous configuration.
        if (_analyzer != null) {
            IOUtils.closeQuietly(_analyzer);
        }
        this._analyzer = null;

        final ObjectInspector stringListOI = ObjectInspectorFactory.getStandardListObjectInspector(
            PrimitiveObjectInspectorFactory.writableStringObjectInspector);
        if (!_returnPos) {
            return stringListOI;
        }

        this._result = new Object[2];
        final List<String> fieldNames = new ArrayList<>(2);
        final List<ObjectInspector> fieldOIs = new ArrayList<>(2);
        fieldNames.add("tokens");
        fieldOIs.add(stringListOI);
        fieldNames.add("pos");
        fieldOIs.add(stringListOI);
        return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
    }

    @Override
    public Object evaluate(GenericUDF.DeferredObject[] arguments) throws HiveException {
        final JapaneseAnalyzer analyzer = getOrCreateAnalyzer();

        final Object arg0 = arguments[0].get();
        if (arg0 == null) {
            return null;
        }
        final String line = arg0.toString();

        try {
            if (_returnPos) {
                return parseLine(analyzer, line, _result);
            }
            return parseLine(analyzer, line);
        } catch (IOException e) {
            // The analyzer may be in an inconsistent state: release it and clear the field so
            // that a later call rebuilds it instead of reusing a closed analyzer.
            IOUtils.closeQuietly(analyzer);
            this._analyzer = null;
            throw new HiveException(e);
        }
    }

    /**
     * Returns the cached analyzer, building it (stop words, user dictionary, mode, stop tags) on
     * first use.
     *
     * @throws UDFArgumentException if the user dictionary cannot be loaded
     */
    @Nonnull
    private JapaneseAnalyzer getOrCreateAnalyzer() throws UDFArgumentException {
        if (_analyzer == null) {
            final CharArraySet stopWords = stopWords(_stopWordsArray);
            final UserDictionary userDict;
            if (_userDictObj instanceof String[]) {
                userDict = userDictionary((String[]) _userDictObj);
            } else if (_userDictObj instanceof String) {
                userDict = userDictionary((String) _userDictObj);
            } else {
                userDict = null;
            }
            this._analyzer = new JapaneseAnalyzer(userDict, _mode, stopWords, _stopTags);
        }
        return _analyzer;
    }

    /**
     * Tokenizes {@code line} and stores the tokens in {@code result[0]} and their part-of-speech
     * tags in {@code result[1]}.
     *
     * @param result a reusable 2-element holder
     * @return {@code result}
     * @throws IOException if the token stream fails
     */
    @Nonnull
    private static Object[] parseLine(@Nonnull JapaneseAnalyzer analyzer, @Nonnull String line,
            @Nonnull Object[] result) throws IOException {
        Objects.requireNonNull(result);
        Preconditions.checkArgument(result.length == 2);

        final List<Text> tokens = new ArrayList<>(32);
        final List<Text> posTags = new ArrayList<>(32);
        TokenStream stream = null;
        try {
            stream = analyzer.tokenStream("", line);
            if (stream != null) {
                analyzeTokens(stream, tokens, posTags);
            }
        } finally {
            IOUtils.closeQuietly(stream);
        }
        result[0] = tokens;
        result[1] = posTags;
        return result;
    }

    /**
     * Tokenizes {@code line} into a fresh list of tokens.
     *
     * @throws IOException if the token stream fails
     */
    @Nonnull
    private static List<Text> parseLine(@Nonnull JapaneseAnalyzer analyzer, @Nonnull String line)
            throws IOException {
        final List<Text> tokens = new ArrayList<>(32);
        TokenStream stream = null;
        try {
            stream = analyzer.tokenStream("", line);
            if (stream != null) {
                analyzeTokens(stream, tokens);
            }
        } finally {
            IOUtils.closeQuietly(stream);
        }
        return tokens;
    }

    @Override
    public void close() throws IOException {
        IOUtils.closeQuietly(this._analyzer);
        this._analyzer = null;
    }

    /**
     * Maps a case-insensitive mode name to a Kuromoji tokenization mode.
     *
     * @throws UDFArgumentException for any name other than NORMAL/SEARCH/EXTENDED/DEFAULT
     */
    @Nonnull
    private static JapaneseTokenizer.Mode tokenizationMode(@Nonnull String modeName)
            throws UDFArgumentException {
        if ("NORMAL".equalsIgnoreCase(modeName)) {
            return JapaneseTokenizer.Mode.NORMAL;
        }
        if ("SEARCH".equalsIgnoreCase(modeName)) {
            return JapaneseTokenizer.Mode.SEARCH;
        }
        if ("EXTENDED".equalsIgnoreCase(modeName)) {
            return JapaneseTokenizer.Mode.EXTENDED;
        }
        if ("DEFAULT".equalsIgnoreCase(modeName)) {
            return JapaneseTokenizer.DEFAULT_MODE;
        }
        throw new UDFArgumentException(
            "Expected NORMAL|SEARCH|EXTENDED|DEFAULT but got an unexpected mode: " + modeName);
    }

    /**
     * Builds the stop-word set: the analyzer default for {@code null}, an empty set for an empty
     * array, otherwise a case-insensitive set of the non-null entries.
     */
    @Nonnull
    private static CharArraySet stopWords(@Nullable String[] array) throws UDFArgumentException {
        if (array == null) {
            return JapaneseAnalyzer.getDefaultStopSet();
        }
        if (array.length == 0) {
            return CharArraySet.EMPTY_SET;
        }
        // Skip null entries: CharArraySet cannot hold null (mirrors stopTags()).
        final List<String> words = new ArrayList<>(array.length);
        for (String word : array) {
            if (word != null) {
                words.add(word);
            }
        }
        return new CharArraySet(words, true);
    }

    /**
     * Builds the stop-tag set from a constant {@code array<string>}: the analyzer default for a
     * void/null argument, an empty set for an empty array, otherwise the non-null entries.
     */
    @Nonnull
    private static Set<String> stopTags(@Nonnull ObjectInspector oi) throws UDFArgumentException {
        if (HiveUtils.isVoidOI(oi)) {
            return JapaneseAnalyzer.getDefaultStopTags();
        }
        final String[] array = HiveUtils.getConstStringArray(oi);
        if (array == null) {
            return JapaneseAnalyzer.getDefaultStopTags();
        }
        if (array.length == 0) {
            return Collections.emptySet();
        }
        final Set<String> tags = new HashSet<>(array.length);
        for (String tag : array) {
            if (tag != null) {
                tags.add(tag);
            }
        }
        return tags;
    }

    /**
     * Builds a user dictionary from CSV rows (one row per element; null elements are skipped).
     *
     * @return the dictionary, or {@code null} if {@code rows} is null
     * @throws UDFArgumentException if Kuromoji rejects the rows
     */
    @Nullable
    private static UserDictionary userDictionary(@Nullable String[] rows)
            throws UDFArgumentException {
        if (rows == null) {
            return null;
        }
        final StringBuilder buf = new StringBuilder();
        for (String row : rows) {
            if (row != null) {
                buf.append(row).append('\n');
            }
        }
        final String csv = buf.toString();
        try {
            return UserDictionary.open(new StringReader(csv));
        } catch (Throwable e) {
            throw new UDFArgumentException(
                "Failed to create user dictionary based on the given array<string>: " + csv + '\n'
                        + ExceptionUtils.prettyPrintStackTrace(e));
        }
    }

    /**
     * Downloads a UTF-8 CSV user dictionary over HTTP(S) and parses it.
     *
     * <p>Each failure stage (connection, response code, stream acquisition, parsing) reports its
     * own error message. The reader and the connection are always released.
     *
     * @return the dictionary, or {@code null} if {@code userDictURL} is null
     * @throws UDFArgumentException on any connection, HTTP, or parse failure
     */
    @Nullable
    private static UserDictionary userDictionary(@Nullable String userDictURL)
            throws UDFArgumentException {
        if (userDictURL == null) {
            return null;
        }

        final HttpURLConnection conn;
        try {
            conn = HttpUtils.getHttpURLConnection(userDictURL);
        } catch (IOException | IllegalArgumentException e) {
            throw new UDFArgumentException("Failed to create HTTP connection to the URL: "
                    + userDictURL + '\n' + ExceptionUtils.prettyPrintStackTrace(e));
        }

        try {
            conn.setRequestProperty("Accept-Encoding", "gzip");
            conn.setConnectTimeout(CONNECT_TIMEOUT_MS);
            conn.setReadTimeout(READ_TIMEOUT_MS);

            final int responseCode;
            try {
                responseCode = conn.getResponseCode();
            } catch (IOException e) {
                throw new UDFArgumentException("Failed to get response code: " + userDictURL
                        + '\n' + ExceptionUtils.prettyPrintStackTrace(e));
            }
            if (responseCode != 200) {
                throw new UDFArgumentException("Got invalid response code: " + responseCode);
            }

            final InputStream is;
            try {
                is = IOUtils.decodeInputStream(
                    HttpUtils.getLimitedInputStream(conn, MAX_INPUT_STREAM_SIZE));
            } catch (IOException | NullPointerException e) {
                throw new UDFArgumentException("Failed to get input stream from the connection: "
                        + userDictURL + '\n' + ExceptionUtils.prettyPrintStackTrace(e));
            }

            // Strict UTF-8: malformed or unmappable bytes raise an error instead of being replaced.
            final CharsetDecoder decoder = StandardCharsets.UTF_8.newDecoder()
                                                                 .onMalformedInput(CodingErrorAction.REPORT)
                                                                 .onUnmappableCharacter(CodingErrorAction.REPORT);
            final Reader reader = new InputStreamReader(is, decoder);
            try {
                return UserDictionary.open(reader);
            } catch (Throwable e) {
                throw new UDFArgumentException(
                    "Failed to parse the file in CSV format (UTF-8 encoding is expected): "
                            + userDictURL + '\n' + ExceptionUtils.prettyPrintStackTrace(e));
            } finally {
                // UserDictionary.open does not close its reader; release the HTTP stream here.
                IOUtils.closeQuietly(reader);
            }
        } finally {
            conn.disconnect();
        }
    }

    /** Drains {@code stream} into {@code tokens}, following the reset/increment/end contract. */
    private static void analyzeTokens(@Nonnull TokenStream stream, @Nonnull List<Text> tokens)
            throws IOException {
        final CharTermAttribute termAttr = stream.addAttribute(CharTermAttribute.class);
        stream.reset();
        while (stream.incrementToken()) {
            tokens.add(new Text(termAttr.toString()));
        }
        stream.end();
    }

    /**
     * Drains {@code stream} into parallel lists of tokens and part-of-speech tags, following the
     * reset/increment/end contract.
     */
    private static void analyzeTokens(@Nonnull TokenStream stream, @Nonnull List<Text> tokens,
            @Nonnull List<Text> posTags) throws IOException {
        final CharTermAttribute termAttr = stream.addAttribute(CharTermAttribute.class);
        final PartOfSpeechAttribute posAttr = stream.addAttribute(PartOfSpeechAttribute.class);
        stream.reset();
        while (stream.incrementToken()) {
            tokens.add(new Text(termAttr.toString()));
            posTags.add(new Text(posAttr.getPartOfSpeech()));
        }
        stream.end();
    }

    @Override
    public String getDisplayString(String[] children) {
        return "tokenize_ja(" + Arrays.toString(children) + ')';
    }
}
