package com.intel.analytics.bigdl.dllib.utils.tf;

import com.intel.analytics.bigdl.dllib.nn.abstractnn.AbstractModule;
import com.intel.analytics.bigdl.dllib.nn.abstractnn.Activity;
import com.intel.analytics.bigdl.dllib.nn.tf.AssignGrad;
import com.intel.analytics.bigdl.dllib.tensor.Tensor;
import com.intel.analytics.bigdl.dllib.tensor.TensorNumericMath;
import com.intel.analytics.bigdl.dllib.utils.Log4Error$;
import com.intel.analytics.bigdl.dllib.utils.Node;
import scala.Predef$;
import scala.Serializable;
import scala.Tuple2;
import scala.collection.Seq;
import scala.collection.TraversableLike;
import scala.collection.immutable.Set;
import scala.reflect.ClassTag;
import scala.runtime.AbstractFunction1;

/* JADX INFO: Add missing generic type declarations: [T] */
/* compiled from: TensorflowLoader.scala */
/* loaded from: input_file:com/intel/analytics/bigdl/dllib/utils/tf/TensorflowLoader$$anonfun$20.class */
/**
 * Decompiled body of an anonymous Scala function from {@code TensorflowLoader.scala} that maps a
 * graph node either to itself or to a new {@link AssignGrad} node fed by it.
 *
 * <p>The field names are the compiler-generated ones and are deliberately left untouched: this
 * class is {@link Serializable}, and default Java serialization matches fields by name.
 *
 * @param <T> type parameter of the {@link AbstractModule}s carried by the nodes
 */
public final class TensorflowLoader$$anonfun$20<T> extends AbstractFunction1<Node<AbstractModule<Activity, Activity, T>>, Node<AbstractModule<Activity, Activity, T>>> implements Serializable {
    public static final long serialVersionUID = 0;
    private final ClassTag evidence$5$1;
    private final TensorNumericMath.TensorNumeric ev$3;
    private final Context context$2;

    /**
     * Captures the values the function body needs.
     *
     * @param classTag      class tag forwarded to every {@link AssignGrad} this function builds
     * @param tensorNumeric numeric implementation forwarded to every {@link AssignGrad}
     * @param context       loader context queried for gradient-assignment entries and tensors
     */
    public TensorflowLoader$$anonfun$20(ClassTag classTag, TensorNumericMath.TensorNumeric tensorNumeric, Context context) {
        this.context$2 = context;
        this.ev$3 = tensorNumeric;
        this.evidence$5$1 = classTag;
    }

    /**
     * Wraps {@code node} in an {@link AssignGrad} when exactly one gradient-assignment entry of
     * the context is accepted by the sibling predicate {@code TensorflowLoader$$anonfun$20$$anonfun$21}.
     *
     * @param node candidate node
     * @return a new {@link AssignGrad} node whose input is {@code node} when exactly one entry
     *     matched; otherwise {@code node} itself
     */
    public final Node<AbstractModule<Activity, Activity, T>> apply(Node<AbstractModule<Activity, Activity, T>> node) {
        // Option.get() is unguarded here; presumably the caller only invokes this function when
        // assignGrads is defined -- TODO confirm against TensorflowLoader.
        TraversableLike gradEntries = (TraversableLike) this.context$2.assignGrads().get();
        // NOTE(review): the predicate lives in another class; it presumably matches an entry
        // against this node (e.g. by name) -- confirm in TensorflowLoader$$anonfun$20$$anonfun$21.
        Set matched = (Set) gradEntries.filter(new TensorflowLoader$$anonfun$20$$anonfun$21(this, node));

        // More than one matching entry is reported through Log4Error (presumably by throwing).
        Log4Error$ log4Error = Log4Error$.MODULE$;
        log4Error.invalidInputError(matched.size() <= 1, "Invalid gradients output", log4Error.invalidInputError$default$3());

        if (matched.size() != 1) {
            // Zero matches: the node passes through untouched.
            return node;
        }

        // The first component of the single matching entry is used as the lookup key into the
        // context; the second component of the looked-up value is the tensor AssignGrad wraps.
        Tuple2 soleEntry = (Tuple2) matched.head();
        String contextKey = (String) soleEntry._1();
        Node[] upstream = new Node[]{node};
        return new AssignGrad((Tensor) this.context$2.apply(contextKey)._2(), this.evidence$5$1, this.ev$3)
                .inputs((Seq) Predef$.MODULE$.wrapRefArray(upstream));
    }
}
