package com.intel.analytics.bigdl.dllib.feature.text;

import com.intel.analytics.bigdl.dllib.feature.dataset.MiniBatch;
import com.intel.analytics.bigdl.dllib.feature.dataset.PaddingParam;
import com.intel.analytics.bigdl.dllib.feature.dataset.Sample;
import com.intel.analytics.bigdl.dllib.feature.dataset.SampleToMiniBatch;
import com.intel.analytics.bigdl.dllib.feature.dataset.SampleToMiniBatch$;
import com.intel.analytics.bigdl.dllib.feature.dataset.Transformer;
import com.intel.analytics.bigdl.dllib.models.utils.ModelBroadcast;
import com.intel.analytics.bigdl.dllib.models.utils.ModelBroadcast$;
import com.intel.analytics.bigdl.dllib.nn.abstractnn.AbstractModule;
import com.intel.analytics.bigdl.dllib.nn.abstractnn.Activity;
import com.intel.analytics.bigdl.dllib.tensor.Tensor;
import com.intel.analytics.bigdl.dllib.tensor.TensorNumericMath;
import com.intel.analytics.bigdl.dllib.utils.Log4Error$;
import org.apache.spark.SparkContext;
import org.apache.spark.rdd.RDD;
import scala.Option;
import scala.Predef$;
import scala.Serializable;
import scala.Some;
import scala.StringContext;
import scala.collection.Seq;
import scala.collection.Seq$;
import scala.reflect.ClassTag;
import scala.reflect.ClassTag$;
import scala.runtime.BoxesRunTime;

/* compiled from: TextPredictor.scala */
/* loaded from: input_file:com/intel/analytics/bigdl/dllib/feature/text/TextPredictor$.class */
/**
 * Java-visible form of the Scala companion object {@code TextPredictor}.
 *
 * <p>Provides a factory for {@link TextPredictor} plus the static-style helpers used to run a
 * model over text features, either distributed over a Spark RDD or locally over a batch.
 * Exposed as a serializable singleton through {@link #MODULE$}.
 */
public final class TextPredictor$ implements Serializable {
    /**
     * The singleton instance, created eagerly when the class is initialized.
     *
     * <p>The previous form assigned {@code null} to this {@code final} field and then
     * reassigned it from the constructor. That is not legal Java (a final field cannot be
     * assigned twice); initializing it directly is the equivalent, compilable form.
     */
    public static final TextPredictor$ MODULE$ = new TextPredictor$();

    /**
     * Creates a {@link TextPredictor} for the given model.
     *
     * @param abstractModule the model to run
     * @param i              integer forwarded to the {@code TextPredictor} constructor
     *                       (presumably the per-partition batch size, consistent with
     *                       {@link #predict}; TODO confirm)
     * @param classTag       Scala class tag for the numeric type {@code T}
     * @param tensorNumeric  numeric operations for {@code T}
     * @return a new predictor wrapping {@code abstractModule}
     */
    public <T> TextPredictor<T> apply(AbstractModule<Activity, Activity, T> abstractModule, int i, ClassTag<T> classTag, TensorNumericMath.TensorNumeric<T> tensorNumeric) {
        return new TextPredictor<>(abstractModule, i, classTag, tensorNumeric);
    }

    /**
     * Runs {@code abstractModule} over every feature of a distributed text set.
     *
     * <p>The model is broadcast once via {@link ModelBroadcast}. A {@link SampleToMiniBatch}
     * transformer with total batch size {@code partitions * i} and partition count
     * {@code partitions} is also broadcast. Each partition is then processed by the generated
     * {@code TextPredictor$$anonfun$1}. The resulting RDD is wrapped back into a text set that
     * carries the input's word index.
     *
     * @param distributedTextSet the input features
     * @param abstractModule     the model to run
     * @param i                  batch size multiplier per partition (the total batch size is
     *                           this times the number of partitions)
     * @param z                  flag forwarded to the per-partition function (presumably
     *                           "share output buffer", matching {@link #predict$default$4};
     *                           TODO confirm)
     * @param classTag           Scala class tag for {@code T}
     * @param tensorNumeric      numeric operations for {@code T}
     * @return a new distributed text set holding the predicted features
     */
    public <T> DistributedTextSet predict(DistributedTextSet distributedTextSet, AbstractModule<Activity, Activity, T> abstractModule, int i, boolean z, ClassTag<T> classTag, TensorNumericMath.TensorNumeric<T> tensorNumeric) {
        RDD<TextFeature> rdd = distributedTextSet.rdd();
        SparkContext sparkContext = rdd.sparkContext();

        // One shared copy of the model per executor instead of one per task closure.
        ModelBroadcast<T> broadcast = ModelBroadcast$.MODULE$.apply(classTag, tensorNumeric).broadcast(sparkContext, abstractModule);

        // Batch across all partitions: total batch = partitions * i, with the partition count
        // passed explicitly so each partition gets its share.
        int length = rdd.partitions().length;
        Option<Object> some = new Some<>(BoxesRunTime.boxToInteger(length));
        Option<PaddingParam<T>> apply$default$2 = SampleToMiniBatch$.MODULE$.apply$default$2();
        Option<PaddingParam<T>> apply$default$3 = SampleToMiniBatch$.MODULE$.apply$default$3();
        boolean apply$default$5 = SampleToMiniBatch$.MODULE$.apply$default$5();
        SampleToMiniBatch<T> toBatch = SampleToMiniBatch$.MODULE$.apply(length * i, apply$default$2, apply$default$3, some, apply$default$5, classTag, tensorNumeric);

        RDD<TextFeature> predicted = rdd.mapPartitions(
                new TextPredictor$$anonfun$1(i, z, classTag, tensorNumeric, broadcast,
                        sparkContext.broadcast(toBatch, ClassTag$.MODULE$.apply(SampleToMiniBatch.class))),
                rdd.mapPartitions$default$2(),
                ClassTag$.MODULE$.apply(TextFeature.class));

        return (DistributedTextSet) TextSet$.MODULE$.rdd(predicted).setWordIndex(distributedTextSet.getWordIndex());
    }

    /** Default value ({@code false}) for the fourth parameter of {@link #predict}. */
    public <T> boolean predict$default$4() {
        return false;
    }

    /**
     * Runs {@code abstractModule} over a local batch of text features.
     *
     * <p>The features are mapped by the generated {@code TextPredictor$$anonfun$2} (presumably
     * to {@link Sample}s) and turned into mini-batches by {@code transformer`. Each mini-batch
     * is expanded through {@code TextPredictor$$anonfun$3} (presumably one model output per
     * feature). Each output is then paired with its feature by {@code zip}, and each pair is
     * handed to {@code TextPredictor$$anonfun$predictTextBatch$1}.
     *
     * @param abstractModule the model to run
     * @param seq            the features to predict; this same sequence is returned
     * @param transformer    converts the sample iterator into mini-batches
     * @param z              flag forwarded to the per-batch function (presumably "share output
     *                       buffer"; TODO confirm)
     * @param classTag       Scala class tag for {@code T}
     * @param tensorNumeric  numeric operations for {@code T}
     * @return {@code seq} itself (results appear to be attached to the features in place —
     *         TODO confirm against {@code TextPredictor$$anonfun$predictTextBatch$1})
     */
    public <T> Seq<TextFeature> predictTextBatch(AbstractModule<Activity, Activity, T> abstractModule, Seq<TextFeature> seq, Transformer<Sample<T>, MiniBatch<T>> transformer, boolean z, ClassTag<T> classTag, TensorNumericMath.TensorNumeric<T> tensorNumeric) {
        Seq samples = (Seq) seq.map(new TextPredictor$$anonfun$2(), Seq$.MODULE$.canBuildFrom());
        seq.toIterator()
                .zip(transformer.apply(samples.toIterator())
                        .flatMap(new TextPredictor$$anonfun$3(abstractModule, z, classTag, tensorNumeric)))
                .foreach(new TextPredictor$$anonfun$predictTextBatch$1());
        return seq;
    }

    /** Default value ({@code false}) for the fourth parameter of {@link #predictTextBatch}. */
    public <T> boolean predictTextBatch$default$4() {
        return false;
    }

    /**
     * Splits a batched output tensor along dimension 1 into one tensor per sample.
     *
     * @param tensor        the batched output
     * @param z             if {@code true}, split {@code tensor} directly; otherwise split a
     *                      clone of it
     * @param i             the expected batch size; must equal {@code tensor.size(1)}
     * @param classTag      Scala class tag for {@code T}
     * @param tensorNumeric numeric operations for {@code T}
     * @return the slices produced by {@code split(1)}
     * Raises an invalid-operation error via {@code Log4Error} if {@code tensor.size(1) != i}.
     */
    public <T> Tensor<T>[] com$intel$analytics$bigdl$dllib$feature$text$TextPredictor$$splitTensor(Tensor<T> tensor, boolean z, int i, ClassTag<T> classTag, TensorNumericMath.TensorNumeric<T> tensorNumeric) {
        Tensor<T> m2011clone = z ? tensor : tensor.m2011clone();
        int size = m2011clone.size(1);
        // The caller-supplied batch size (i) is the requirement; the tensor's first dimension
        // (size) is what was actually observed. These were previously reported the other way round.
        Log4Error$.MODULE$.invalidOperationError(i == size, new StringContext(Predef$.MODULE$.wrapRefArray(new String[]{"batchSize is required to be ", ", while the actual batchSize is ", ""})).s(Predef$.MODULE$.genericWrapArray(new Object[]{BoxesRunTime.boxToInteger(i), BoxesRunTime.boxToInteger(size)})), Log4Error$.MODULE$.invalidOperationError$default$3(), Log4Error$.MODULE$.invalidOperationError$default$4());
        return m2011clone.split(1);
    }

    /** Preserves singleton identity across Java deserialization. */
    private Object readResolve() {
        return MODULE$;
    }

    /** Private: use {@link #MODULE$}. */
    private TextPredictor$() {
    }
}
