package com.kotlinnlp.tokensencoder.ensamble.affine;

import com.kotlinnlp.neuralparser.language.Token;
import com.kotlinnlp.simplednn.core.layers.LayerParameters;
import com.kotlinnlp.simplednn.core.layers.LayerStructure;
import com.kotlinnlp.simplednn.core.layers.models.merge.affine.AffineLayerParameters;
import com.kotlinnlp.simplednn.core.layers.models.merge.affine.AffineLayerStructure;
import com.kotlinnlp.simplednn.core.layers.models.merge.affine.AffineLayersPool;
import com.kotlinnlp.simplednn.core.optimizer.IterableParams;
import com.kotlinnlp.simplednn.core.optimizer.ParamsErrorsAccumulator;
import com.kotlinnlp.simplednn.simplemath.ndarray.dense.DenseNDArray;
import com.kotlinnlp.tokensencoder.TokensEncoder;
import com.kotlinnlp.tokensencoder.TokensEncoderBuilder;
import com.kotlinnlp.tokensencoder.TokensEncoderFactory;
import com.kotlinnlp.tokensencoder.TokensEncoderModel;
import com.kotlinnlp.tokensencoder.TokensEncoderParameters;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import kotlin.Metadata;
import kotlin.collections.CollectionsKt;
import kotlin.jvm.internal.DefaultConstructorMarker;
import kotlin.jvm.internal.Intrinsics;
import org.jetbrains.annotations.NotNull;

/* compiled from: AffineTokensEncoder.kt */
@Metadata(mv = {1, 1, 10}, bv = {1, 0, 2}, k = 1, d1 = {"��\\\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0010\u000b\n\u0002\b\u0002\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\b\u0002\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n��\n\u0002\u0010 \n\u0002\u0018\u0002\n��\n\u0002\u0010!\n\u0002\u0018\u0002\n\u0002\b\u0002\n\u0002\u0010\u0002\n\u0002\b\b\n\u0002\u0018\u0002\n\u0002\b\u0003\n\u0002\u0018\u0002\n\u0002\b\u0003\b\u0016\u0018��2\u00020\u0001B\u0015\u0012\u0006\u0010\u0002\u001a\u00020\u0003\u0012\u0006\u0010\u0004\u001a\u00020\u0005¢\u0006\u0002\u0010\u0006J\u0016\u0010\u0015\u001a\u00020\u00162\f\u0010\u0017\u001a\b\u0012\u0004\u0012\u00020\r0\u000fH\u0016J$\u0010\u0018\u001a\b\u0012\u0004\u0012\u00020\r0\u000f2\f\u0010\u0019\u001a\b\u0012\u0004\u0012\u00020\r0\u00132\u0006\u0010\u001a\u001a\u00020\rH\u0002J\u0016\u0010\u001b\u001a\u00020\r2\f\u0010\u001c\u001a\b\u0012\u0004\u0012\u00020\r0\u000fH\u0002J\u001c\u0010\u001d\u001a\b\u0012\u0004\u0012\u00020\r0\u000f2\f\u0010\u001e\u001a\b\u0012\u0004\u0012\u00020\u001f0\u000fH\u0016J\u000e\u0010 \u001a\b\u0012\u0004\u0012\u00020\r0\u0013H\u0002J\b\u0010!\u001a\u00020\tH\u0002J\u0010\u0010\"\u001a\u00020#2\u0006\u0010$\u001a\u00020\u0005H\u0016J\b\u0010%\u001a\u00020\u0016H\u0002R\u0014\u0010\u0007\u001a\b\u0012\u0004\u0012\u00020\t0\bX\u0082\u000e¢\u0006\u0002\n��R\u000e\u0010\n\u001a\u00020\tX\u0082.¢\u0006\u0002\n��R\u0014\u0010\u000b\u001a\b\u0012\u0004\u0012\u00020\r0\fX\u0082\u0004¢\u0006\u0002\n��R\u0014\u0010\u000e\u001a\b\u0012\u0004\u0012\u00020\u00100\u000fX\u0082\u0004¢\u0006\u0002\n��R\u000e\u0010\u0002\u001a\u00020\u0003X\u0082\u0004¢\u0006\u0002\n��R\u000e\u0010\u0004\u001a\u00020\u0005X\u0082\u0004¢\u0006\u0002\n��R\u001a\u0010\u0011\u001a\u000e\u0012\n\u0012\b\u0012\u0004\u0012\u00020\r0\u00130\u0012X\u0082\u0004¢\u0006\u0002\n��R\u0014\u0010\u0014\u001a\b\u0012\u0004\u0012\u00020\u00010\u0012X\u0082\u0004¢\u0006\u0002\n��¨\u0006&"}, d2 = 
{"Lcom/kotlinnlp/tokensencoder/ensamble/affine/AffineTokensEncoder;", "Lcom/kotlinnlp/tokensencoder/TokensEncoder;", "model", "Lcom/kotlinnlp/tokensencoder/ensamble/affine/AffineTokensEncoderModel;", "trainingMode", "", "(Lcom/kotlinnlp/tokensencoder/ensamble/affine/AffineTokensEncoderModel;Z)V", "affineErrorsAccumulator", "Lcom/kotlinnlp/simplednn/core/optimizer/ParamsErrorsAccumulator;", "Lcom/kotlinnlp/simplednn/core/layers/models/merge/affine/AffineLayerParameters;", "affineLayerParamsErrors", "affineLayersPool", "Lcom/kotlinnlp/simplednn/core/layers/models/merge/affine/AffineLayersPool;", "Lcom/kotlinnlp/simplednn/simplemath/ndarray/dense/DenseNDArray;", "encodersBuilder", "", "Lcom/kotlinnlp/tokensencoder/TokensEncoderBuilder;", "usedAffineLayers", "", "Lcom/kotlinnlp/simplednn/core/layers/models/merge/affine/AffineLayerStructure;", "usedEncoders", "backward", "", "errors", "backwardAffineLayer", "layer", "outputErrors", "doAffineTransform", "inputVectors", "encode", "tokens", "Lcom/kotlinnlp/neuralparser/language/Token;", "getAffineLayer", "getAffineParamsErrors", "getParamsErrors", "Lcom/kotlinnlp/tokensencoder/TokensEncoderParameters;", "copy", "reset", "tokensencoder"})
/* loaded from: input_file:com/kotlinnlp/tokensencoder/ensamble/affine/AffineTokensEncoder.class */
/**
 * Tokens encoder that runs an ensemble of sub-encoders on the same tokens. For each token, it
 * merges the vectors produced by the sub-encoders through an affine layer.
 *
 * <p>NOTE(review): this source was obtained by decompiling {@code AffineTokensEncoder.kt}. The
 * {@code $default} and {@code access$} members used below are Kotlin compiler artifacts.
 */
public class AffineTokensEncoder implements TokensEncoder {

    /** One builder per sub-model of {@link #model}; used to create fresh sub-encoders on every {@link #encode} call. */
    private final List<TokensEncoderBuilder> encodersBuilder;

    /** Sub-encoders created by the last {@link #encode} call, in the same order as {@link #encodersBuilder}. */
    private final List<TokensEncoder> usedEncoders;

    /** Pool providing the affine layers that merge the sub-encoders' outputs. */
    private final AffineLayersPool<DenseNDArray> affineLayersPool;

    /** Affine layers taken from the pool during the last {@link #encode} call: one per token, in token order. */
    private final List<AffineLayerStructure<DenseNDArray>> usedAffineLayers;

    /** Lazily created buffer, reused for the params errors of every affine layer during {@link #backward}. */
    private AffineLayerParameters affineLayerParamsErrors;

    /** Accumulates the affine params errors of all the tokens; averaged at the end of {@link #backward}. */
    private final ParamsErrorsAccumulator<AffineLayerParameters> affineErrorsAccumulator;

    private final AffineTokensEncoderModel model;

    /** Forwarded to {@code TokensEncoderFactory} when building the sub-encoders. */
    private final boolean trainingMode;

    /**
     * Encodes the given tokens with every sub-encoder, then merges the vectors produced for each
     * token through a dedicated affine layer.
     *
     * <p>Any state left by a previous call is discarded first: the sub-encoders are rebuilt from
     * their builders, and the affine layers are released back to the pool.
     *
     * @param list the tokens to encode
     * @return one merged encoding per token, in input order
     */
    @Override // com.kotlinnlp.tokensencoder.TokensEncoder
    @NotNull
    public List<DenseNDArray> encode(@NotNull List<Token> list) {
        Intrinsics.checkParameterIsNotNull(list, "tokens");
        reset();

        // Build all the sub-encoders first, then register them, keeping the builders' order.
        List<TokensEncoder> freshEncoders = new ArrayList<>(this.encodersBuilder.size());
        for (TokensEncoderBuilder builder : this.encodersBuilder) {
            freshEncoders.add(builder.invoke());
        }
        this.usedEncoders.addAll(freshEncoders);

        // tokensVectors.get(i) collects the vectors produced for the i-th token, in sub-encoder order.
        int tokensCount = list.size();
        List<List<DenseNDArray>> tokensVectors = new ArrayList<>(tokensCount);
        for (int i = 0; i < tokensCount; i++) {
            tokensVectors.add(new ArrayList<>());
        }

        for (TokensEncoder encoder : this.usedEncoders) {
            int tokenIndex = 0;
            for (DenseNDArray vector : encoder.encode(list)) {
                tokensVectors.get(tokenIndex++).add(vector);
            }
        }

        List<DenseNDArray> encodings = new ArrayList<>(tokensCount);
        for (List<DenseNDArray> vectors : tokensVectors) {
            encodings.add(doAffineTransform(vectors));
        }
        return encodings;
    }

    /**
     * Propagates the given output errors through the affine layers used by the last
     * {@link #encode} call, and then into the sub-encoders.
     *
     * @param list the errors of the encodings returned by {@link #encode}, one per token in the
     *     same order (the i-th element is applied to the i-th used affine layer)
     */
    @Override // com.kotlinnlp.tokensencoder.TokensEncoder
    public void backward(@NotNull List<DenseNDArray> list) {
        Intrinsics.checkParameterIsNotNull(list, "errors");

        // Backward each token's affine layer, collecting the errors of its inputs.
        List<List<DenseNDArray>> inputsErrors = new ArrayList<>(list.size());
        int tokenIndex = 0;
        for (DenseNDArray outputErrors : list) {
            inputsErrors.add(backwardAffineLayer(this.usedAffineLayers.get(tokenIndex++), outputErrors));
        }

        this.affineErrorsAccumulator.averageErrors();

        // The affine inputs were set in sub-encoder order, so the k-th input error of every token
        // belongs to the k-th sub-encoder.
        int encoderIndex = 0;
        for (TokensEncoder encoder : this.usedEncoders) {
            List<DenseNDArray> encoderErrors = new ArrayList<>(inputsErrors.size());
            for (List<DenseNDArray> tokenInputErrors : inputsErrors) {
                encoderErrors.add(tokenInputErrors.get(encoderIndex));
            }
            encoder.backward(encoderErrors);
            encoderIndex++;
        }
    }

    /**
     * Returns the params errors of all the sub-encoders together with the accumulated affine
     * params errors.
     *
     * @param z the Kotlin {@code copy} argument (name taken from the class metadata), forwarded
     *     unchanged to every sub-encoder and to the affine errors accumulator
     * @return the errors wrapped in an {@code AffineTokensEncoderParams}
     */
    @Override // com.kotlinnlp.tokensencoder.TokensEncoder
    @NotNull
    public TokensEncoderParameters getParamsErrors(boolean z) {
        List<TokensEncoderParameters> encodersErrors = new ArrayList<>(this.usedEncoders.size());
        for (TokensEncoder encoder : this.usedEncoders) {
            encodersErrors.add(encoder.getParamsErrors(z));
        }
        return new AffineTokensEncoderParams(encodersErrors, this.affineErrorsAccumulator.getParamsErrors(z));
    }

    /**
     * Merges the vectors of one token by forwarding them through a new affine layer taken from
     * the pool.
     *
     * @param list the token's vectors, one per sub-encoder; the k-th becomes the k-th affine input
     * @return the affine layer output values
     */
    private final DenseNDArray doAffineTransform(List<DenseNDArray> list) {
        AffineLayerStructure<DenseNDArray> layer = getAffineLayer();
        int inputIndex = 0;
        for (DenseNDArray vector : list) {
            layer.setInput(inputIndex++, vector);
        }
        // Kotlin default-argument bridge: mask 1 leaves forward()'s first argument at its default.
        LayerStructure.forward$default(layer, false, 1, (Object) null);
        return layer.getOutputArray().getValues();
    }

    /** Takes an affine layer from the pool and records it as used by the current encoding. */
    private final AffineLayerStructure<DenseNDArray> getAffineLayer() {
        AffineLayerStructure<DenseNDArray> layer = this.affineLayersPool.getItem();
        this.usedAffineLayers.add(layer);
        return layer;
    }

    /**
     * Backwards one affine layer and accumulates its params errors.
     *
     * @param affineLayerStructure the affine layer used for one token
     * @param denseNDArray the errors of that layer's output
     * @return the errors of the layer's inputs
     */
    private final List<DenseNDArray> backwardAffineLayer(AffineLayerStructure<DenseNDArray> affineLayerStructure, DenseNDArray denseNDArray) {
        IterableParams paramsErrors = getAffineParamsErrors();
        affineLayerStructure.setErrors(denseNDArray);
        affineLayerStructure.backward((LayerParameters) paramsErrors, true, (Double) null);
        // Kotlin default-argument bridge: mask 2 leaves accumulate()'s second argument at its default.
        ParamsErrorsAccumulator.accumulate$default(this.affineErrorsAccumulator, paramsErrors, false, 2, (Object) null);
        return affineLayerStructure.getInputErrors(true);
    }

    /**
     * Returns the params-errors buffer. On first use, it is created as a copy of the params of the
     * last used affine layer.
     *
     * <p>NOTE(review): reusing a single buffer assumes that all pooled layers share the same params
     * (the pool is built from the single {@code model.getAffineParams()}) — confirm.
     */
    private final AffineLayerParameters getAffineParamsErrors() {
        if (this.affineLayerParamsErrors == null) {
            this.affineLayerParamsErrors = CollectionsKt.last(this.usedAffineLayers).getParams().copy();
        }
        return this.affineLayerParamsErrors;
    }

    /** Discards the state of the previous encoding: pooled layers, used layers, sub-encoders and accumulated errors. */
    private final void reset() {
        this.affineLayersPool.releaseAll();
        this.usedAffineLayers.clear();
        this.usedEncoders.clear();
        this.affineErrorsAccumulator.reset();
    }

    /**
     * Creates the encoder, preparing one builder per sub-model of the given model.
     *
     * @param affineTokensEncoderModel the ensemble model
     * @param z whether the sub-encoders are built in training mode
     * @throws IllegalArgumentException if a sub-model is itself an {@code AffineTokensEncoderModel}
     */
    public AffineTokensEncoder(@NotNull AffineTokensEncoderModel affineTokensEncoderModel, boolean z) {
        Intrinsics.checkParameterIsNotNull(affineTokensEncoderModel, "model");
        this.model = affineTokensEncoderModel;
        this.trainingMode = z;

        List<TokensEncoderModel> models = this.model.getModels();
        List<TokensEncoderBuilder> builders = new ArrayList<>(models.size());
        for (TokensEncoderModel subModel : models) {
            // Nested affine ensembles are rejected (Kotlin `require`, hence the generic message).
            if (subModel instanceof AffineTokensEncoderModel) {
                throw new IllegalArgumentException("Failed requirement.");
            }
            builders.add(TokensEncoderFactory.INSTANCE.invoke(subModel, this.trainingMode));
        }
        this.encodersBuilder = builders;
        this.usedEncoders = new ArrayList<>();
        // Kotlin default-argument constructor: mask 4 leaves the third argument (0.0d placeholder) at its default.
        this.affineLayersPool = new AffineLayersPool<>(this.model.getAffineParams(), this.model.getActivation$tokensencoder(), 0.0d, 4, (DefaultConstructorMarker) null);
        this.usedAffineLayers = new ArrayList<>();
        this.affineErrorsAccumulator = new ParamsErrorsAccumulator<>();
    }

    /** Kotlin synthetic accessor for the lateinit {@code affineLayerParamsErrors} property. */
    @NotNull
    public static final /* synthetic */ AffineLayerParameters access$getAffineLayerParamsErrors$p(AffineTokensEncoder affineTokensEncoder) {
        AffineLayerParameters affineLayerParameters = affineTokensEncoder.affineLayerParamsErrors;
        if (affineLayerParameters == null) {
            Intrinsics.throwUninitializedPropertyAccessException("affineLayerParamsErrors");
        }
        return affineLayerParameters;
    }
}
