package ai.djl.mxnet.engine;

import ai.djl.BaseModel;
import ai.djl.Device;
import ai.djl.MalformedModelException;
import ai.djl.inference.Predictor;
import ai.djl.mxnet.jna.JnaUtils;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.types.DataType;
import ai.djl.nn.Parameter;
import ai.djl.training.Trainer;
import ai.djl.training.TrainingConfig;
import ai.djl.training.initializer.Initializer;
import ai.djl.translate.Translator;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.nio.file.FileVisitOption;
import java.nio.file.Files;
import java.nio.file.LinkOption;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:ai/djl/mxnet/engine/MxModel.class */
public class MxModel extends BaseModel {
    private static final Logger logger = LoggerFactory.getLogger(MxModel.class);
    private AtomicBoolean first;

    /* JADX INFO: Access modifiers changed from: package-private */
    public MxModel(String str, Device device) {
        super(str);
        Device defaultIfNull = Device.defaultIfNull(device);
        this.dataType = DataType.FLOAT32;
        this.properties = new ConcurrentHashMap();
        this.manager = MxNDManager.getSystemManager().mo9newSubManager(defaultIfNull);
        this.first = new AtomicBoolean(true);
    }

    public void load(Path path, String str, Map<String, Object> map) throws IOException, MalformedModelException {
        this.modelDir = path.toAbsolutePath();
        if (str == null) {
            str = this.modelName;
        }
        Path paramPathResolver = paramPathResolver(str, map);
        if (paramPathResolver == null) {
            str = this.modelDir.toFile().getName();
            paramPathResolver = paramPathResolver(str, map);
            if (paramPathResolver == null) {
                throw new IOException("Parameter file not found in: " + this.modelDir);
            }
        }
        if (this.block == null) {
            Path resolve = this.modelDir.resolve(str + "-symbol.json");
            if (Files.notExists(resolve, new LinkOption[0])) {
                throw new FileNotFoundException("Symbol file not found: " + resolve + ", please set block manually for imperative model.");
            }
            this.block = new MxSymbolBlock(this.manager, Symbol.load(this.manager, resolve.toAbsolutePath().toString()));
        }
        loadParameters(paramPathResolver, map);
    }

    public Trainer newTrainer(TrainingConfig trainingConfig) {
        Initializer initializer = trainingConfig.getInitializer();
        if (this.block == null) {
            throw new IllegalStateException("You must set a block for the model before creating a new trainer");
        }
        this.block.setInitializer(initializer);
        return new Trainer(this, trainingConfig);
    }

    public <I, O> Predictor<I, O> newPredictor(Translator<I, O> translator) {
        return new Predictor<>(this, translator, (JnaUtils.useThreadSafePredictor() || this.first.getAndSet(false)) ? false : true);
    }

    public void cast(DataType dataType) {
        throw new UnsupportedOperationException("Not implemented yet.");
    }

    public String[] getArtifactNames() {
        try {
            List<Path> list = (List) Files.walk(this.modelDir, new FileVisitOption[0]).filter(path -> {
                return Files.isRegularFile(path, new LinkOption[0]);
            }).collect(Collectors.toList());
            ArrayList arrayList = new ArrayList(list.size());
            for (Path path2 : list) {
                String name = path2.toFile().getName();
                if (!name.endsWith(".params") && !name.endsWith("-symbol.json")) {
                    arrayList.add(this.modelDir.relativize(path2).toString());
                }
            }
            return (String[]) arrayList.toArray(new String[0]);
        } catch (IOException e) {
            throw new AssertionError("Failed list files", e);
        }
    }

    public void close() {
        JnaUtils.waitAll();
        this.manager.close();
    }

    private void loadParameters(Path path, Map<String, Object> map) throws IOException, MalformedModelException {
        if (readParameters(path, map)) {
            return;
        }
        logger.debug("DJL formatted model not found, try to find MXNet model");
        NDList load = this.manager.load(path);
        MxSymbolBlock mxSymbolBlock = this.block;
        List<Parameter> allParameters = mxSymbolBlock.getAllParameters();
        LinkedHashMap linkedHashMap = new LinkedHashMap();
        allParameters.forEach(parameter -> {
        });
        Iterator it = load.iterator();
        while (it.hasNext()) {
            NDArray nDArray = (NDArray) it.next();
            String name = nDArray.getName();
            if (name == null) {
                throw new IllegalArgumentException("Array names must be present in parameter file");
            }
            ((Parameter) linkedHashMap.remove(name.split(":", 2)[1])).setArray(nDArray);
        }
        mxSymbolBlock.setInputNames(new ArrayList(linkedHashMap.keySet()));
        this.dataType = load.head().getDataType();
        logger.debug("MXNet Model {} ({}) loaded successfully.", path, this.dataType);
    }

    public String toString() {
        StringBuilder sb = new StringBuilder(200);
        sb.append("Model (\n\tName: ").append(this.modelName);
        if (this.modelDir != null) {
            sb.append("\n\tModel location: ").append(this.modelDir.toAbsolutePath());
        }
        sb.append("\n\tData Type: ").append(this.dataType);
        for (Map.Entry entry : this.properties.entrySet()) {
            sb.append("\n\t").append((String) entry.getKey()).append(": ").append((String) entry.getValue());
        }
        sb.append("\n)");
        return sb.toString();
    }
}
