package ai.djl.nn;

import ai.djl.MalformedModelException;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.training.ParameterStore;
import ai.djl.util.PairList;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Locale;
import java.util.function.Function;

/* loaded from: input_file:ai/djl/nn/SequentialBlock.class */
/**
 * A {@link Block} that applies a list of child blocks one after another, feeding the output of
 * each child into the next one.
 *
 * <p>Children are kept in insertion order. {@link #getChildren()} names each child with a
 * zero-padded index followed by the child's simple class name.
 */
public class SequentialBlock extends AbstractBlock {

    /**
     * Encoding version written by {@link #saveParameters}. Streams tagged with version 1 are still
     * accepted by {@link #loadParameters} but carry no input shapes.
     */
    private static final byte VERSION = 2;

    /** The child blocks, in the order they are applied. */
    private final List<Block> blocks = new ArrayList<>();

    /**
     * Appends the given blocks, in order, to the end of this sequence.
     *
     * @param blockArr the blocks to append
     * @return this block, for chaining
     */
    public SequentialBlock addAll(Block... blockArr) {
        blocks.addAll(Arrays.asList(blockArr));
        return this;
    }

    /**
     * Appends the given blocks, in iteration order, to the end of this sequence.
     *
     * @param collection the blocks to append
     * @return this block, for chaining
     */
    public SequentialBlock addAll(Collection<Block> collection) {
        blocks.addAll(collection);
        return this;
    }

    /**
     * Appends a single block to the end of this sequence.
     *
     * @param block the block to append
     * @return this block, for chaining
     */
    public SequentialBlock add(Block block) {
        blocks.add(block);
        return this;
    }

    /**
     * Appends a step implemented by {@code function}, wrapped in a {@link LambdaBlock}.
     *
     * @param function the function applied to the output of the previous block
     * @return this block, for chaining
     */
    public SequentialBlock add(Function<NDList, NDList> function) {
        blocks.add(new LambdaBlock(function));
        return this;
    }

    /**
     * Removes the last block of this sequence.
     *
     * @throws IndexOutOfBoundsException if the sequence is empty
     */
    public void removeLastBlock() {
        if (blocks.isEmpty()) {
            // Same exception type ArrayList.remove(-1) would throw, but with a clearer message.
            throw new IndexOutOfBoundsException("The sequential block is empty");
        }
        blocks.remove(blocks.size() - 1);
    }

    /**
     * Replaces the last block of this sequence with {@code block}.
     *
     * @param block the new last block
     * @throws IndexOutOfBoundsException if the sequence is empty
     */
    public void replaceLastBlock(Block block) {
        removeLastBlock();
        blocks.add(block);
    }

    /**
     * Runs the input through every child block in order.
     *
     * <p>The {@code pairList} argument is not passed on to the children; each child is invoked via
     * its two-argument {@code forward}.
     *
     * @param parameterStore the parameter store handed to every child
     * @param nDList the input to the first child
     * @param pairList extra parameters (not forwarded to the children)
     * @return the output of the last child, or {@code nDList} itself if there are no children
     */
    @Override
    public NDList forward(
            ParameterStore parameterStore, NDList nDList, PairList<String, Object> pairList) {
        NDList current = nDList;
        for (Block block : blocks) {
            current = block.forward(parameterStore, current);
        }
        return current;
    }

    /**
     * Initializes every child in order, threading each child's output shapes into the next one.
     *
     * @param nDManager the manager used for initialization
     * @param dataType the data type of the parameters
     * @param shapeArr the input shapes of the first child
     * @return the output shapes of the whole sequence
     * @throws IllegalArgumentException if the sequence is empty
     */
    @Override
    public Shape[] initialize(NDManager nDManager, DataType dataType, Shape... shapeArr) {
        beforeInitialize(shapeArr);
        Shape[] current = shapeArr;
        for (Block child : getChildren().values()) {
            current = child.initialize(nDManager, dataType, current);
        }
        return getOutputShapes(nDManager, shapeArr);
    }

    /**
     * Computes the output shapes by threading {@code shapeArr} through every child in order.
     *
     * @param nDManager the manager passed to each child
     * @param shapeArr the input shapes of the first child
     * @return the output shapes of the last child
     * @throws IllegalArgumentException if the sequence is empty
     */
    @Override
    public Shape[] getOutputShapes(NDManager nDManager, Shape[] shapeArr) {
        if (blocks.isEmpty()) {
            throw new IllegalArgumentException("The sequential block is empty");
        }
        Shape[] current = shapeArr;
        for (Block block : blocks) {
            current = block.getOutputShapes(nDManager, current);
        }
        return current;
    }

    /**
     * Returns an empty list: a sequential block owns no parameters of its own.
     *
     * @return an empty list
     */
    @Override
    public List<Parameter> getDirectParameters() {
        return Collections.emptyList();
    }

    /**
     * Always fails, since a sequential block has no direct parameters.
     *
     * @throws IllegalArgumentException always
     */
    @Override
    public Shape getParameterShape(String str, Shape[] shapeArr) {
        throw new IllegalArgumentException("SequentialBlocks have no parameters");
    }

    /**
     * Returns the children in order, named {@code "<index>:<SimpleClassName>"}.
     *
     * <p>The index is zero-padded to the number of decimal digits in the child count; for example,
     * with 12 children the first one is named {@code "00:<SimpleClassName>"}.
     *
     * @return the named children
     */
    @Override
    public BlockList getChildren() {
        int size = blocks.size();
        BlockList children = new BlockList(size);
        if (size == 0) {
            // Math.log10(0) is -Infinity; there is nothing to name anyway.
            return children;
        }
        int width = (int) Math.log10(size) + 1;
        String format = "%0" + width + "d:%s";
        for (int i = 0; i < size; i++) {
            Block block = blocks.get(i);
            // Locale.ROOT keeps the padded index in ASCII digits regardless of the JVM locale,
            // so child names are stable across machines.
            String name =
                    String.format(Locale.ROOT, format, i, block.getClass().getSimpleName());
            children.add(name, block);
        }
        return children;
    }

    /**
     * Writes the encoding version, the input shapes, and then every child's parameters in order.
     *
     * @param dataOutputStream the stream to write to
     * @throws IOException if writing fails
     */
    @Override
    public void saveParameters(DataOutputStream dataOutputStream) throws IOException {
        dataOutputStream.writeByte(VERSION);
        saveInputShapes(dataOutputStream);
        for (Block block : blocks) {
            block.saveParameters(dataOutputStream);
        }
    }

    /**
     * Reads parameters written by {@link #saveParameters} (or by the older version-1 format) and
     * loads every child's parameters in order.
     *
     * @param nDManager the manager for the loaded arrays
     * @param dataInputStream the stream to read from
     * @throws IOException if reading fails
     * @throws MalformedModelException if the encoding version is neither 1 nor {@code VERSION}
     */
    @Override
    public void loadParameters(NDManager nDManager, DataInputStream dataInputStream)
            throws IOException, MalformedModelException {
        byte version = dataInputStream.readByte();
        if (version == VERSION) {
            readInputShapes(dataInputStream);
        } else if (version != 1) {
            // Version 1 streams carry no input shapes; anything else is unknown.
            throw new MalformedModelException("Unsupported encoding version: " + ((int) version));
        }
        for (Block block : blocks) {
            block.loadParameters(nDManager, dataInputStream);
        }
    }

    /**
     * Returns {@code "Sequential(\n"}, then one tab-indented line per child, then {@code ")"}.
     *
     * @return a multi-line description of this block
     */
    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder(200);
        sb.append("Sequential(\n");
        for (Block block : blocks) {
            sb.append('\t').append(block).append('\n');
        }
        sb.append(')');
        return sb.toString();
    }
}
