package ai.djl.nn;

import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.training.initializer.Initializer;
import ai.djl.util.Pair;
import ai.djl.util.PairList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;

/* loaded from: input_file:ai/djl/nn/AbstractBlock.class */
/**
 * Skeletal {@link Block} implementation providing the bookkeeping that is common to blocks:
 * describing the recorded input shapes, propagating initializers, aggregating direct and child
 * parameters, and reporting/clearing parameter state.
 *
 * <p>Subclasses supply {@code getDirectParameters()} and {@code getChildren()} (declared outside
 * this class); everything here is built on top of those two accessors.
 */
public abstract class AbstractBlock implements Block {
    /** Input shapes recorded by {@link #beforeInitialize(Shape[])}; {@code null} until then. */
    protected Shape[] inputShapes;

    /** Names paired with {@link #inputShapes} by {@link #describeInput()}. */
    protected List<String> inputNames = Collections.singletonList("data");

    /**
     * Returns the input names paired with the input shapes recorded for this block.
     *
     * @return a list pairing {@link #inputNames} with {@link #inputShapes}
     * @throws IllegalStateException if the block is not initialized, or if no input shapes have
     *     been recorded yet via {@link #beforeInitialize(Shape[])}
     */
    @Override // ai.djl.nn.Block
    public PairList<String, Shape> describeInput() {
        if (!isInitialized()) {
            throw new IllegalStateException("Parameter of this block are not initialised");
        }
        // isInitialized() only inspects parameters, so a block whose parameters are all
        // initialized (or that has none) can still reach this point without recorded shapes.
        // Fail with a descriptive error instead of the NullPointerException that
        // Arrays.asList(null) would raise.
        if (this.inputShapes == null) {
            throw new IllegalStateException("Input shapes of this block have not been set");
        }
        return new PairList<>(this.inputNames, Arrays.asList(this.inputShapes));
    }

    /**
     * Applies {@code initializer} to every direct parameter of this block and then forwards it to
     * every child block.
     *
     * <p>Direct parameters receive {@code false} as the second argument of
     * {@code Parameter.setInitializer} (presumably a "do not overwrite" flag -- confirm against
     * {@code Parameter}).
     *
     * @param initializer the initializer to propagate
     */
    @Override // ai.djl.nn.Block
    public void setInitializer(Initializer initializer) {
        for (Parameter parameter : getDirectParameters()) {
            parameter.setInitializer(initializer, false);
        }
        for (Block child : getChildren().values()) {
            child.setInitializer(initializer);
        }
    }

    /**
     * Applies {@code initializer} to the first direct parameter whose name equals {@code str}.
     *
     * <p>Only direct parameters are searched; child blocks are not consulted. The parameter
     * receives {@code true} as the second argument of {@code Parameter.setInitializer}
     * (presumably an "overwrite" flag -- confirm against {@code Parameter}).
     *
     * @param initializer the initializer to set
     * @param str the name of the direct parameter to update
     * @throws IllegalArgumentException if no direct parameter has the given name
     */
    @Override // ai.djl.nn.Block
    public void setInitializer(Initializer initializer, String str) {
        Parameter target = getDirectParameters().stream()
                .filter(parameter -> parameter.getName().equals(str))
                .findFirst()
                .orElseThrow(() -> new IllegalArgumentException("Could not find parameter " + str));
        target.setInitializer(initializer, true);
    }

    /**
     * Collects this block's direct parameters followed by the parameters of all child blocks.
     *
     * <p>Direct parameters are keyed by their own name; child parameters are keyed by the child's
     * key and the parameter's key joined with an underscore.
     *
     * @return a newly built list of all parameters reachable from this block
     */
    @Override // ai.djl.nn.Block
    public ParameterList getParameters() {
        ParameterList parameterList = new ParameterList();
        for (Parameter parameter : getDirectParameters()) {
            parameterList.add(parameter.getName(), parameter);
        }
        getChildrenParameters().forEach(parameterList::add);
        return parameterList;
    }

    /**
     * Records the input shapes so that {@link #describeInput()} can report them later.
     *
     * @param shapeArr the input shapes; the array reference is stored as-is, not copied
     */
    /* JADX INFO: Access modifiers changed from: protected */
    public void beforeInitialize(Shape[] shapeArr) {
        this.inputShapes = shapeArr;
    }

    /**
     * Reports whether every parameter returned by {@link #getParameters()} is initialized.
     *
     * @return {@code true} if all parameters are initialized, which includes the case where there
     *     are no parameters at all; {@code false} as soon as one uninitialized parameter is found
     */
    @Override // ai.djl.nn.Block
    public boolean isInitialized() {
        for (Parameter parameter : getParameters().values()) {
            if (!parameter.isInitialized()) {
                return false;
            }
        }
        return true;
    }

    /** Closes every parameter returned by {@link #getParameters()}. */
    @Override // ai.djl.nn.Block
    public void clear() {
        for (Parameter parameter : getParameters().values()) {
            parameter.close();
        }
    }

    /**
     * Not supported by this base class.
     *
     * @param dataType ignored
     * @throws UnsupportedOperationException always
     */
    @Override // ai.djl.nn.Block
    public void cast(DataType dataType) {
        throw new UnsupportedOperationException("Not implemented yet.");
    }

    /**
     * Gathers the parameters of every child block, prefixing each parameter key with the child's
     * key and an underscore.
     *
     * @return a newly built list of the children's parameters, in child iteration order
     */
    private ParameterList getChildrenParameters() {
        ParameterList parameterList = new ParameterList();
        for (Pair<String, Block> child : getChildren()) {
            for (Pair<String, Parameter> entry : child.getValue().getParameters()) {
                parameterList.add(child.getKey() + "_" + entry.getKey(), entry.getValue());
            }
        }
        return parameterList;
    }
}
