package com.intel.analytics.bigdl.ppml.common;

import com.intel.analytics.bigdl.ppml.base.DataHolder;
import com.intel.analytics.bigdl.ppml.base.StorageHolder;
import java.util.HashMap;
import java.util.Map;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

/* loaded from: input_file:com/intel/analytics/bigdl/ppml/common/Aggregator.class */
/**
 * Base class for server-side aggregation of client uploads, organized per {@link FLPhase}.
 *
 * <p>Subclasses register one {@link StorageHolder} per supported phase in
 * {@link #aggregateTypeMap} (typically from {@link #initStorage()}) and implement
 * {@link #aggregate(FLPhase)} to combine the uploads held in that storage.
 *
 * <p>{@link #putClientData} acts as a per-phase barrier. Each caller stores its upload and blocks
 * until {@link #clientNum} uploads have been collected for that phase. The caller that completes
 * the set runs {@link #aggregate(FLPhase)} and then releases the other callers of the same phase.
 * Storage mutation, version checks and aggregation all happen while holding this object's monitor.
 */
public abstract class Aggregator {
    /** Number of uploads required before a phase is aggregated; must be set via {@link #setClientNum}. */
    protected Integer clientNum;
    private final Logger logger = LogManager.getLogger(getClass());
    protected String returnMessage = "";
    /** Per-phase storage of client uploads. Phases without an entry are rejected by {@link #putClientData}. */
    public Map<FLPhase, StorageHolder> aggregateTypeMap = new HashMap<>();

    /**
     * Number of completed aggregations per phase. Guarded by {@code this}.
     *
     * <p>A caller of {@link #putClientData} waits until the count for its own phase advances past the
     * value it observed. This makes the wait immune to spurious wake-ups and to notifications caused
     * by aggregation of a different phase (both share this object's monitor).
     */
    private final Map<FLPhase, Long> completedRounds = new HashMap<>();

    /**
     * Creates the aggregator and immediately calls {@link #initStorage()}.
     *
     * <p>NOTE: {@code initStorage()} is overridable and runs before any subclass field initializers,
     * so implementations must not rely on subclass instance fields being initialized yet.
     */
    public Aggregator() {
        initStorage();
    }

    /**
     * Stores a free-form message that can later be read back with {@link #getReturnMessage()}.
     *
     * @param str the message to store
     */
    public void setReturnMessage(String str) {
        this.returnMessage = str;
    }

    /**
     * Returns the message last stored with {@link #setReturnMessage(String)} (empty by default).
     *
     * @return the stored message
     */
    public String getReturnMessage() {
        return this.returnMessage;
    }

    /**
     * Sets how many client uploads {@link #putClientData} collects per phase before aggregating.
     *
     * @param num required number of uploads
     */
    public void setClientNum(Integer num) {
        this.clientNum = num;
    }

    /** Initializes per-phase storage; called once from the constructor. */
    public abstract void initStorage();

    /**
     * Combines the uploads collected for {@code fLPhase}.
     *
     * <p>Invoked by {@link #putClientData} while holding this object's monitor, once at least
     * {@link #clientNum} uploads have been stored for the phase.
     *
     * @param fLPhase the phase whose storage should be aggregated
     */
    public abstract void aggregate(FLPhase fLPhase);

    /**
     * Stores one client's upload for a phase and blocks until that phase has been aggregated.
     *
     * <p>The call that brings the phase's upload count to {@link #clientNum} runs
     * {@link #aggregate(FLPhase)} and wakes every caller waiting on the same phase. Other calls wait
     * until that aggregation has completed.
     *
     * @param fLPhase    phase the upload belongs to
     * @param str        client identifier, used as the storage key and in log messages
     * @param i          client-side version, or {@code -1} to skip the version check
     * @param dataHolder the uploaded data
     * @throws IllegalArgumentException if no storage is registered for {@code fLPhase}, or if the
     *                                  client version does not match the storage version
     * @throws IllegalStateException    if {@link #setClientNum} has not been called
     * @throws InterruptedException     if the thread is interrupted while waiting for aggregation
     */
    public void putClientData(FLPhase fLPhase, String str, int i, DataHolder dataHolder)
            throws IllegalArgumentException, InterruptedException {
        this.logger.debug(str + " getting data to update from server: " + fLPhase.toString());
        StorageHolder storageHolder = this.aggregateTypeMap.get(fLPhase);
        if (storageHolder == null) {
            throw new IllegalArgumentException("No storage registered for phase: " + fLPhase);
        }
        synchronized (this) {
            // Read the version under the monitor: aggregate() runs under the same monitor and may
            // update the storage concurrently otherwise.
            if (i != -1) {
                checkVersion(storageHolder.getVersion(), i);
            }
            this.logger.debug(str + " version check pass, version: " + i);
            // Fail before mutating storage rather than with an NPE after the upload is stored.
            if (this.clientNum == null) {
                throw new IllegalStateException("clientNum has not been set");
            }
            storageHolder.putClientData(str, dataHolder);
            this.logger.debug(str + " client data uploaded to server: " + fLPhase.toString());
            this.logger.debug("Server received data " + storageHolder.getClientDataSize() + "/" + this.clientNum);
            long observedRound = this.completedRounds.getOrDefault(fLPhase, 0L);
            if (storageHolder.getClientDataSize() >= this.clientNum.intValue()) {
                this.logger.debug("Server received all client data, start aggregate.");
                aggregate(fLPhase);
                // NOTE(review): if aggregate() throws, the round never advances and waiters of this
                // phase keep blocking (same as before); consider an explicit failure signal.
                this.completedRounds.put(fLPhase, observedRound + 1);
                notifyAll();
            } else {
                // Loop on the per-phase round counter: wait() may return spuriously, and notifyAll()
                // from another phase's aggregation must not release this caller early.
                while (this.completedRounds.getOrDefault(fLPhase, 0L) == observedRound) {
                    wait();
                }
            }
        }
    }

    /**
     * Verifies that a client's version matches the server-side version.
     *
     * @param i  server (storage) version
     * @param i2 client version
     * @throws IllegalArgumentException if the two versions differ
     */
    protected void checkVersion(int i, int i2) throws IllegalArgumentException {
        if (i != i2) {
            throw new IllegalArgumentException("Version miss match, got server version: " + i + ", client version: " + i2);
        }
    }
}
