package com.intel.analytics.bigdl.dllib.models.rnn;

import com.intel.analytics.bigdl.dllib.feature.dataset.DataSet$;
import com.intel.analytics.bigdl.dllib.feature.dataset.FixedLength;
import com.intel.analytics.bigdl.dllib.feature.dataset.LocalDataSet;
import com.intel.analytics.bigdl.dllib.feature.dataset.MiniBatch;
import com.intel.analytics.bigdl.dllib.feature.dataset.PaddingParam;
import com.intel.analytics.bigdl.dllib.feature.dataset.Sample;
import com.intel.analytics.bigdl.dllib.feature.dataset.SampleToMiniBatch$;
import com.intel.analytics.bigdl.dllib.feature.dataset.text.Dictionary;
import com.intel.analytics.bigdl.dllib.feature.dataset.text.Dictionary$;
import com.intel.analytics.bigdl.dllib.feature.dataset.text.LabeledSentence;
import com.intel.analytics.bigdl.dllib.feature.dataset.text.LabeledSentenceToSample$;
import com.intel.analytics.bigdl.dllib.feature.dataset.text.TextToLabeledSentence$;
import com.intel.analytics.bigdl.dllib.feature.dataset.text.utils.SentenceToken$;
import com.intel.analytics.bigdl.dllib.models.rnn.Utils;
import com.intel.analytics.bigdl.dllib.nn.CrossEntropyCriterion$;
import com.intel.analytics.bigdl.dllib.nn.Module$;
import com.intel.analytics.bigdl.dllib.nn.TimeDistributedCriterion$;
import com.intel.analytics.bigdl.dllib.nn.abstractnn.AbstractModule;
import com.intel.analytics.bigdl.dllib.optim.Loss$mcF$sp;
import com.intel.analytics.bigdl.dllib.optim.ValidationMethod;
import com.intel.analytics.bigdl.dllib.tensor.Tensor;
import com.intel.analytics.bigdl.dllib.tensor.Tensor$;
import com.intel.analytics.bigdl.dllib.tensor.TensorNumericMath$TensorNumeric$NumericFloat$;
import com.intel.analytics.bigdl.dllib.utils.Engine$;
import com.intel.analytics.bigdl.dllib.utils.T$;
import com.intel.analytics.bigdl.package$;
import org.apache.spark.SparkContext;
import org.apache.spark.rdd.RDD;
import scala.Array$;
import scala.Predef$;
import scala.Serializable;
import scala.Some;
import scala.collection.Seq;
import scala.collection.mutable.StringBuilder;
import scala.math.Ordering$Int$;
import scala.reflect.ClassTag$;
import scala.runtime.AbstractFunction1;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;
import scala.runtime.ScalaRunTime$;

/* compiled from: Test.scala */
/* loaded from: input_file:com/intel/analytics/bigdl/dllib/models/rnn/Test$$anonfun$main$1.class */
/**
 * Decompiled Scala closure: the first anonymous function defined in {@code main} of Test.scala
 * (per the decompiler header). Given parsed test parameters, it loads a saved model and either
 * evaluates it on {@code <folder>/test.txt} or runs it over sentences read from {@code folder}.
 *
 * <p>NOTE(review): this is compiler/decompiler output. The numbered
 * {@code Test$$anonfun$main$1$$anonfun$N} classes instantiated below are other synthetic closures
 * from the same Scala source and are not visible in this file; comments about what they do are
 * hedged accordingly.
 */
public final class Test$$anonfun$main$1 extends AbstractFunction1<Utils.TestParams, BoxedUnit> implements Serializable {
    public static final long serialVersionUID = 0;

    /**
     * Runs one test session for {@code testParams}.
     *
     * <p>Loads the {@link Dictionary} from {@code testParams.folder()}, creates a dedicated
     * {@link SparkContext} (app name "Test rnn on text", {@code spark.task.maxFailures=1}),
     * initializes the BigDL engine and loads the model from {@code testParams.modelSnapshot()}
     * (a Scala {@code Option}; {@code get()} fails if no snapshot was supplied).
     *
     * <ul>
     *   <li>If {@code testParams.evaluate()} is true, the collected output of
     *       {@code SequencePreprocess} over {@code <folder>/test.txt} is turned into padded
     *       mini-batches and the model is evaluated with a {@code Loss} built from
     *       {@code TimeDistributedCriterion(CrossEntropyCriterion)}; each result is passed to
     *       {@code anonfun$apply$3}.</li>
     *   <li>Otherwise the sentences from {@code Utils.readSentence(folder)} are converted to
     *       dictionary indices, run through the model in a Spark job ({@code anonfun$6}), mapped
     *       back through the dictionary ({@code anonfun$8}) and passed to
     *       {@code anonfun$apply$7}.</li>
     * </ul>
     *
     * <p>The SparkContext is stopped before this method returns, whether it completes normally
     * or throws.
     *
     * @param testParams parsed test options (folder, model snapshot, batch size, mode flags, ...)
     */
    public final void apply(Utils.TestParams testParams) {
        Dictionary apply = Dictionary$.MODULE$.apply(testParams.folder());
        SparkContext sparkContext = new SparkContext(Engine$.MODULE$.createSparkConf(Engine$.MODULE$.createSparkConf$default$1()).setAppName("Test rnn on text").set("spark.task.maxFailures", "1"));
        // Everything after the context exists runs under try/finally: previously any exception
        // (engine init, missing snapshot, bad input, evaluation/job failure) skipped stop() and
        // leaked a running SparkContext.
        try {
            // NOTE(review): Engine.init runs after the SparkContext is created; presumably it
            // depends on that, so do not reorder without checking.
            Engine$.MODULE$.init();
            AbstractModule loadModule = Module$.MODULE$.loadModule((String) testParams.modelSnapshot().get(), Module$.MODULE$.loadModule$default$2(), ClassTag$.MODULE$.Float(), TensorNumericMath$TensorNumeric$NumericFloat$.MODULE$);
            if (testParams.evaluate()) {
                // Token arrays for the test split, collected to the driver.
                String[][] strArr = (String[][]) SequencePreprocess$.MODULE$.apply(new StringBuilder().append(testParams.folder()).append("/test.txt").toString(), sparkContext, testParams.sentFile(), testParams.tokenFile()).collect();
                // Max over the ints anonfun$1 produces per token array (presumably the sentence
                // length); used below as the fixed padding length for features and labels.
                int unboxToInt = BoxesRunTime.unboxToInt(Predef$.MODULE$.intArrayOps((int[]) Predef$.MODULE$.refArrayOps(strArr).map(new Test$$anonfun$main$1$$anonfun$1(this), Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.Int()))).max(Ordering$Int$.MODULE$));
                int vocabSize = apply.getVocabSize() + 1;
                int index = apply.getIndex(SentenceToken$.MODULE$.start());
                int index2 = apply.getIndex(SentenceToken$.MODULE$.end());
                // Feature padding: a length-vocabSize tensor with 1.0 at (end-token index + 1).
                Tensor resize = Tensor$.MODULE$.apply(ClassTag$.MODULE$.Float(), TensorNumericMath$TensorNumeric$NumericFloat$.MODULE$).resize(vocabSize);
                resize.setValue(index2 + 1, BoxesRunTime.boxToFloat(1.0f));
                // text -> LabeledSentence -> Sample -> MiniBatch, padding features with `resize`
                // and labels with a one-element tensor holding (start-token index + 1).
                LocalDataSet local = DataSet$.MODULE$.array(strArr).transform(TextToLabeledSentence$.MODULE$.apply(apply, ClassTag$.MODULE$.Float(), TensorNumericMath$TensorNumeric$NumericFloat$.MODULE$), ClassTag$.MODULE$.apply(LabeledSentence.class)).transform(LabeledSentenceToSample$.MODULE$.apply(vocabSize, LabeledSentenceToSample$.MODULE$.apply$default$2(), LabeledSentenceToSample$.MODULE$.apply$default$3(), ClassTag$.MODULE$.Float(), TensorNumericMath$TensorNumeric$NumericFloat$.MODULE$), ClassTag$.MODULE$.apply(Sample.class)).transform(SampleToMiniBatch$.MODULE$.apply(testParams.batchSize(), new Some(new PaddingParam(new Some(new Tensor[]{resize}), new FixedLength(new int[]{unboxToInt}), ClassTag$.MODULE$.Float())), new Some(new PaddingParam(new Some(new Tensor[]{Tensor$.MODULE$.apply(T$.MODULE$.apply(BoxesRunTime.boxToFloat(index + 1.0f), (Seq<Object>) Predef$.MODULE$.genericWrapArray(new Object[0])), ClassTag$.MODULE$.Float(), TensorNumericMath$TensorNumeric$NumericFloat$.MODULE$)}), new FixedLength(new int[]{unboxToInt}), ClassTag$.MODULE$.Float())), SampleToMiniBatch$.MODULE$.apply$default$4(), SampleToMiniBatch$.MODULE$.apply$default$5(), ClassTag$.MODULE$.Float(), TensorNumericMath$TensorNumeric$NumericFloat$.MODULE$), ClassTag$.MODULE$.apply(MiniBatch.class)).toLocal();
                package$ package_ = package$.MODULE$;
                TimeDistributedCriterion$ timeDistributedCriterion$ = TimeDistributedCriterion$.MODULE$;
                CrossEntropyCriterion$ crossEntropyCriterion$ = CrossEntropyCriterion$.MODULE$;
                // NOTE(review): result is discarded and the weights argument below is passed as
                // null directly; this looks like a decompiler artifact of default-argument
                // evaluation. Kept to preserve the exact call sequence.
                CrossEntropyCriterion$.MODULE$.apply$default$1();
                // Evaluate with Loss(TimeDistributedCriterion(CrossEntropyCriterion, true)); the
                // `true` is presumably sizeAverage — confirm against the BigDL signature.
                Predef$.MODULE$.refArrayOps(loadModule.evaluate(local, new ValidationMethod[]{new Loss$mcF$sp(package_.convCriterion(timeDistributedCriterion$.apply$mFc$sp(crossEntropyCriterion$.apply$mFc$sp(null, CrossEntropyCriterion$.MODULE$.apply$default$2(), ClassTag$.MODULE$.Float(), TensorNumericMath$TensorNumeric$NumericFloat$.MODULE$), true, TimeDistributedCriterion$.MODULE$.apply$default$3(), ClassTag$.MODULE$.Float(), TensorNumericMath$TensorNumeric$NumericFloat$.MODULE$)), ClassTag$.MODULE$.Float(), TensorNumericMath$TensorNumeric$NumericFloat$.MODULE$)})).foreach(new Test$$anonfun$main$1$$anonfun$apply$3(this));
            } else {
                // Empty float tensor handed to anonfun$6 (presumably a reusable buffer — confirm).
                Tensor apply2 = Tensor$.MODULE$.apply(ClassTag$.MODULE$.Float(), TensorNumericMath$TensorNumeric$NumericFloat$.MODULE$);
                // Sentences -> float index arrays via the dictionary (anonfun$2) -> LabeledSentence (anonfun$3).
                LabeledSentence[] labeledSentenceArr = (LabeledSentence[]) Predef$.MODULE$.refArrayOps((float[][]) Predef$.MODULE$.refArrayOps(Utils$.MODULE$.readSentence(testParams.folder())).map(new Test$$anonfun$main$1$$anonfun$2(this, apply), Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(ScalaRunTime$.MODULE$.arrayClass(Float.TYPE))))).map(new Test$$anonfun$main$1$$anonfun$3(this), Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(LabeledSentence.class)));
                int vocabSize2 = apply.getVocabSize() + 1;
                int batchSize = testParams.batchSize();
                RDD parallelize = sparkContext.parallelize(Predef$.MODULE$.wrapRefArray(labeledSentenceArr), sparkContext.parallelize$default$2(), ClassTag$.MODULE$.apply(LabeledSentence.class));
                RDD mapPartitions = parallelize.mapPartitions(new Test$$anonfun$main$1$$anonfun$4(this, vocabSize2), parallelize.mapPartitions$default$2(), ClassTag$.MODULE$.apply(Sample.class));
                RDD mapPartitions2 = mapPartitions.mapPartitions(new Test$$anonfun$main$1$$anonfun$5(this, batchSize), mapPartitions.mapPartitions$default$2(), ClassTag$.MODULE$.apply(MiniBatch.class));
                // Run the model per partition (anonfun$6; the literals 2 and 3 are presumably the
                // time and feature dimension indices — confirm), collect, flatten (anonfun$7),
                // map back through the dictionary (anonfun$8) and hand each row to anonfun$apply$7.
                Predef$.MODULE$.refArrayOps((String[][]) Predef$.MODULE$.refArrayOps((float[][]) Predef$.MODULE$.refArrayOps((Object[]) mapPartitions2.mapPartitions(new Test$$anonfun$main$1$$anonfun$6(this, loadModule, 2, 3, apply2, testParams), mapPartitions2.mapPartitions$default$2(), ClassTag$.MODULE$.apply(ScalaRunTime$.MODULE$.arrayClass(ScalaRunTime$.MODULE$.arrayClass(Float.TYPE)))).collect()).flatMap(new Test$$anonfun$main$1$$anonfun$7(this), Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(ScalaRunTime$.MODULE$.arrayClass(Float.TYPE))))).map(new Test$$anonfun$main$1$$anonfun$8(this, apply), Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(ScalaRunTime$.MODULE$.arrayClass(String.class))))).foreach(new Test$$anonfun$main$1$$anonfun$apply$7(this));
            }
        } finally {
            sparkContext.stop();
        }
    }

    /** Erased bridge for {@code Function1.apply(Object)}: delegates to the typed overload and returns {@code BoxedUnit.UNIT}. */
    public final /* bridge */ /* synthetic */ Object apply(Object obj) {
        apply((Utils.TestParams) obj);
        return BoxedUnit.UNIT;
    }
}
