package ai.djl.mxnet.engine;

import ai.djl.mxnet.jna.JnaUtils;
import ai.djl.mxnet.jna.MxnetLibrary;
import ai.djl.mxnet.jna.NativeResource;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.training.ParameterServer;
import ai.djl.training.optimizer.Optimizer;
import com.sun.jna.Pointer;
import java.util.Arrays;

/**
 * A {@link ParameterServer} backed by a native MXNet KVStore of type {@code "device"}.
 *
 * <p>An {@link Optimizer} is registered as the store's string-keyed updater, so values pushed to
 * the store are handed to {@link Optimizer#update} by native code.
 */
public class MxParameterServer extends NativeResource implements ParameterServer {

    /**
     * Updater registered with the native store. It is held in a field so the JNA callback object
     * stays strongly reachable for as long as native code may invoke it. JNA does not retain
     * callback objects passed to native code on its own.
     */
    private final OptimizerCallback callback;

    /**
     * Per-call counter. Its negation is passed as the push/pull priority, so each successive
     * {@link #update} call is issued with a smaller priority value than the previous one.
     */
    private int priority;

    /** Adapts an {@link Optimizer} to MXNet's string-keyed KVStore updater callback. */
    private static final class OptimizerCallback implements MxnetLibrary.MXKVStoreStrUpdater {

        private final Optimizer optimizer;

        OptimizerCallback(Optimizer optimizer) {
            this.optimizer = optimizer;
        }

        /**
         * Called from native code to apply an update for {@code parameterId}.
         *
         * <p>Both native arrays are wrapped by a short-lived sub-manager of the system manager.
         * The sub-manager is closed exactly once, whether the update completes or throws.
         *
         * @param parameterId the key being updated
         * @param recv the value received by the store (MXNet's {@code recv}), passed to the
         *     optimizer as its third argument
         * @param local the value held by the store (MXNet's {@code local}), passed to the
         *     optimizer as its second argument
         * @param userData opaque user data registered with the updater; unused here
         */
        @Override // ai.djl.mxnet.jna.MxnetLibrary.MXKVStoreStrUpdater
        public void apply(String parameterId, Pointer recv, Pointer local, Pointer userData) {
            // NOTE(review): mo10newSubManager is the decompiler's name for the covariant
            // newSubManager() returning MxNDManager; confirm against MxNDManager.
            // try-with-resources replaces the decompiled hand-rolled version, which could
            // call close() twice when the first close() threw.
            try (MxNDManager manager = MxNDManager.getSystemManager().mo10newSubManager()) {
                optimizer.update(parameterId, manager.create(local), manager.create(recv));
            }
        }
    }

    /**
     * Creates a parameter server whose native store routes updates through {@code optimizer}.
     *
     * @param optimizer the optimizer invoked from the native store's updater callback
     */
    public MxParameterServer(Optimizer optimizer) {
        super(createKVStore());
        callback = new OptimizerCallback(optimizer);
        JnaUtils.parameterStoreSetUpdater(getHandle(), null, callback, null);
        priority = 0;
    }

    /**
     * Initializes the store with the given arrays, all registered under {@code parameterId}.
     *
     * @param parameterId the key shared by every array
     * @param value the initial values to place in the store
     */
    public void init(String parameterId, NDArray[] value) {
        JnaUtils.parameterStoreInit(
                getHandle(), value.length, repeatKey(parameterId, value.length), new NDList(value));
    }

    /**
     * Pushes {@code grads} to the store and pulls the result into {@code params}, all under
     * {@code parameterId}.
     *
     * @param parameterId the key shared by every pushed and pulled array
     * @param grads the arrays pushed to the store
     * @param params the arrays that receive the pulled values
     */
    public void update(String parameterId, NDArray[] grads, NDArray[] params) {
        JnaUtils.parameterStorePushPull(
                getHandle(),
                grads.length,
                repeatKey(parameterId, grads.length),
                params.length,
                repeatKey(parameterId, params.length),
                new NDList(grads),
                new NDList(params),
                -priority);
        priority++;
    }

    /** Returns an array holding {@code count} copies of {@code key}, one per array sent. */
    private static String[] repeatKey(String key, int count) {
        String[] keys = new String[count];
        Arrays.fill(keys, key);
        return keys;
    }

    /** Creates the native KVStore, of type {@code "device"}, that backs this server. */
    private static Pointer createKVStore() {
        return JnaUtils.parameterStoreCreate("device");
    }

    /**
     * Releases the native store. The handle is swapped to {@code null} with {@code getAndSet}, so
     * only the first call frees it and later calls do nothing.
     */
    @Override // ai.djl.mxnet.jna.NativeResource, java.lang.AutoCloseable
    public void close() {
        Pointer pointer = this.handle.getAndSet(null);
        if (pointer != null) {
            JnaUtils.parameterStoreClose(pointer);
        }
    }
}
