package org.apache.hama.ml.recommendation.cf;

import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.Map;
import java.util.Random;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.MapWritable;
import org.apache.hadoop.io.Text;
import org.apache.hama.HamaConfiguration;
import org.apache.hama.bsp.BSP;
import org.apache.hama.bsp.BSPPeer;
import org.apache.hama.bsp.sync.SyncException;
import org.apache.hama.commons.io.MatrixWritable;
import org.apache.hama.commons.io.VectorWritable;
import org.apache.hama.commons.math.DenseDoubleMatrix;
import org.apache.hama.commons.math.DenseDoubleVector;
import org.apache.hama.commons.math.DoubleMatrix;
import org.apache.hama.commons.math.DoubleVector;
import org.apache.hama.ml.recommendation.Preference;
import org.apache.hama.ml.recommendation.cf.OnlineCF;
import org.apache.hama.ml.recommendation.cf.function.OnlineUpdate;

/* loaded from: input_file:org/apache/hama/ml/recommendation/cf/OnlineTrainBSP.class */
/**
 * BSP job that trains an online collaborative-filtering model.
 *
 * <p>Each peer reads preferences plus optional user/item feature vectors from its input,
 * exchanges the feature vectors it needs with the peers that hold them, then repeatedly runs the
 * configured {@link OnlineUpdate.Function} over its preferences. Every {@code SKIP_COUNT}
 * iterations item factors and feature matrices are averaged across peers. Finally the user
 * factors, averaged item factors and (if present) feature data are written out.
 */
public class OnlineTrainBSP extends BSP<Text, VectorWritable, Text, VectorWritable, MapWritable> {
    protected static Log LOG = LogFactory.getLog(OnlineTrainBSP.class);
    // First-character markers that classify input records (preference / user features / item features).
    private String inputPreferenceDelim = null;
    private String inputUserDelim = null;
    private String inputItemDelim = null;
    private int ITERATION = 0;
    private int MATRIX_RANK = 0;
    // Normalize across peers every SKIP_COUNT iterations; 0 disables normalization.
    private int SKIP_COUNT = 0;
    // Factor vectors (length MATRIX_RANK) keyed by user id / item id.
    private HashMap<String, VectorWritable> usersMatrix = new HashMap<>();
    private HashMap<String, VectorWritable> itemsMatrix = new HashMap<>();
    // Feature-factorization matrices; null when this peer has no corresponding features.
    private DoubleMatrix userFeatureMatrix = null;
    private DoubleMatrix itemFeatureMatrix = null;
    // Raw input feature vectors keyed by user id / item id.
    private HashMap<String, VectorWritable> inpUsersFeatures = null;
    private HashMap<String, VectorWritable> inpItemsFeatures = null;
    private OnlineUpdate.Function function = null;
    private ArrayList<Preference<String, String>> preferences = new ArrayList<>();
    // Positions into preferences; shuffled before every pass.
    private ArrayList<Integer> indexes = new ArrayList<>();
    Random rnd = new Random();

    /**
     * Reads the job settings and instantiates the configured online update function.
     *
     * @throws IOException if no update function is configured or it cannot be instantiated
     */
    @Override
    public void setup(BSPPeer<Text, VectorWritable, Text, VectorWritable, MapWritable> peer) throws IOException, SyncException, InterruptedException {
        HamaConfiguration conf = peer.getConfiguration();
        this.ITERATION = conf.getInt(OnlineCF.Settings.CONF_ITERATION_COUNT, 100);
        this.MATRIX_RANK = conf.getInt(OnlineCF.Settings.CONF_MATRIX_RANK, 10);
        this.SKIP_COUNT = conf.getInt(OnlineCF.Settings.CONF_SKIP_COUNT, 5);
        this.inputItemDelim = conf.get(OnlineCF.Settings.CONF_INPUT_ITEM_DELIM, OnlineCF.Settings.DFLT_ITEM_DELIM);
        this.inputUserDelim = conf.get(OnlineCF.Settings.CONF_INPUT_USER_DELIM, OnlineCF.Settings.DFLT_USER_DELIM);
        this.inputPreferenceDelim = conf.get(OnlineCF.Settings.CONF_INPUT_PREFERENCES_DELIM, OnlineCF.Settings.DFLT_PREFERENCE_DELIM);

        Class<?> functionClass = conf.getClass(OnlineCF.Settings.CONF_ONLINE_UPDATE_FUNCTION, (Class<?>) null);
        if (functionClass == null) {
            // Previously swallowed, which only surfaced later as an NPE in computeValues().
            throw new IOException("No online update function configured under '"
                    + OnlineCF.Settings.CONF_ONLINE_UPDATE_FUNCTION + "'");
        }
        try {
            this.function = functionClass.asSubclass(OnlineUpdate.Function.class).newInstance();
        } catch (ReflectiveOperationException | ClassCastException e) {
            throw new IOException("Cannot instantiate online update function " + functionClass.getName(), e);
        }
    }

    /**
     * Main BSP loop: partition input, exchange features, train, then save the model.
     */
    @Override
    public void bsp(BSPPeer<Text, VectorWritable, Text, VectorWritable, MapWritable> peer) throws IOException, SyncException, InterruptedException {
        LOG.info(peer.getPeerName() + ") collecting input data");
        // Filled by collectInput with delimiter-prefixed ids whose features this peer needs.
        // (Previously null was passed, so no feature request was ever sent.)
        HashSet<Text> requiredUserFeatures = new HashSet<>();
        HashSet<Text> requiredItemFeatures = new HashSet<>();
        collectInput(peer, requiredUserFeatures, requiredItemFeatures);
        askForFeatures(peer, requiredUserFeatures, requiredItemFeatures);
        peer.sync();
        sendRequiredFeatures(peer);
        peer.sync();
        collectFeatures(peer);
        LOG.info(peer.getPeerName() + ") collected: " + this.usersMatrix.size() + " users, " + this.itemsMatrix.size() + " items, " + this.preferences.size() + " preferences");
        for (int i = 0; i < this.ITERATION; i++) {
            computeValues();
            // SKIP_COUNT == 0 would make the modulo throw ArithmeticException; treat it as "never".
            if (this.SKIP_COUNT != 0 && (i + 1) % this.SKIP_COUNT == 0) {
                normalizeWithBroadcastingValues(peer);
            }
        }
        saveModel(peer);
    }

    /**
     * Averages item factors and feature matrices across all peers and adopts the results locally.
     * Every peer performs the same sequence of sync() barriers regardless of which matrices it holds.
     */
    private void normalizeWithBroadcastingValues(BSPPeer<Text, VectorWritable, Text, VectorWritable, MapWritable> peer) throws IOException, SyncException, InterruptedException {
        peer.sync();
        normalizeItemFactorizedValues(peer);
        peer.sync();
        // Called unconditionally: guarding on a peer-local null check let peers diverge in their
        // number of sync() calls when only some of them held feature matrices.
        this.itemFeatureMatrix = normalizeMatrix(peer, this.itemFeatureMatrix, OnlineCF.Settings.MSG_ITEM_FEATURE_MATRIX, true);
        this.userFeatureMatrix = normalizeMatrix(peer, this.userFeatureMatrix, OnlineCF.Settings.MSG_USER_FEATURE_MATRIX, true);
    }

    /**
     * Averages {@code localMatrix} over all peers that hold one, using the middle peer as master.
     *
     * <p>Must be called by every peer (with a null matrix if it has none). It performs one sync(), plus a
     * second one when {@code broadcast} is true.
     *
     * @param localMatrix this peer's contribution, or null to contribute nothing
     * @param msgTag message key used for this matrix
     * @param broadcast whether the master sends the average back to every peer
     * @return if {@code broadcast} is false, the average on the master (null on other peers or
     *     when nobody contributed). If true, the broadcast average, or {@code localMatrix} itself
     *     when it is null or no average was produced.
     */
    private DoubleMatrix normalizeMatrix(BSPPeer<Text, VectorWritable, Text, VectorWritable, MapWritable> peer, DoubleMatrix localMatrix, IntWritable msgTag, boolean broadcast) throws IOException, SyncException, InterruptedException {
        String master = peer.getPeerName(peer.getNumPeers() / 2);
        boolean isMaster = peer.getPeerName().equals(master);

        if (localMatrix != null) {
            MapWritable contribution = new MapWritable();
            contribution.put(msgTag, new MatrixWritable(localMatrix));
            peer.send(master, contribution);
        }
        peer.sync();

        DoubleMatrix average = null;
        if (isMaster) {
            DoubleMatrix sum = null;
            int contributors = 0;
            MapWritable msg;
            while ((msg = peer.getCurrentMessage()) != null) {
                DoubleMatrix received = ((MatrixWritable) msg.get(msgTag)).getMatrix();
                // Use the returned matrix: the original discarded add()/divide() results and
                // always produced the zero-initialized accumulator.
                sum = (sum == null) ? received : sum.add(received);
                contributors++;
            }
            if (contributors > 0) {
                average = sum.divide(contributors);
            }
        }
        if (!broadcast) {
            return average;
        }

        if (isMaster && average != null) {
            MapWritable result = new MapWritable();
            result.put(msgTag, new MatrixWritable(average));
            for (String peerName : peer.getAllPeerNames()) {
                peer.send(peerName, result);
            }
        }
        peer.sync();
        // Always drain the broadcast; only adopt it if this peer actually uses the matrix.
        MapWritable reply = peer.getCurrentMessage();
        if (localMatrix == null || reply == null) {
            return localMatrix;
        }
        return ((MatrixWritable) reply.get(msgTag)).getMatrix();
    }

    /**
     * Flattens a matrix row-major into a vector whose element 0 holds MATRIX_RANK.
     */
    private VectorWritable convertMatrixToVector(DoubleMatrix matrix) {
        int rows = matrix.getRowCount();
        int cols = matrix.getColumnCount();
        DenseDoubleVector packed = new DenseDoubleVector(rows * cols + 1);
        packed.set(0, this.MATRIX_RANK);
        int pos = 1;
        for (int r = 0; r < rows; r++) {
            for (int c = 0; c < cols; c++) {
                packed.set(pos++, matrix.get(r, c));
            }
        }
        return new VectorWritable(packed);
    }

    /**
     * Averages item factor vectors across the peers that share each item and stores the averages
     * back into itemsMatrix on every such peer (two supersteps).
     */
    private void normalizeItemFactorizedValues(BSPPeer<Text, VectorWritable, Text, VectorWritable, MapWritable> peer) throws IOException, SyncException, InterruptedException {
        sendItemFactorizedValues(peer);
        peer.sync();
        HashMap<Text, LinkedList<IntWritable>> senders = new HashMap<>();
        HashMap<Text, DoubleVector> averages = new HashMap<>();
        getNormalizedItemFactorizedValues(peer, averages, senders);
        sendTo(peer, senders, averages);
        peer.sync();
        receiveSyncedItemFactorizedValues(peer);
    }

    /**
     * Sends each averaged item vector back to every peer that contributed to it.
     */
    private void sendTo(BSPPeer<Text, VectorWritable, Text, VectorWritable, MapWritable> peer, HashMap<Text, LinkedList<IntWritable>> senders, HashMap<Text, DoubleVector> averages) throws IOException {
        for (Map.Entry<Text, DoubleVector> entry : averages.entrySet()) {
            MapWritable msg = new MapWritable();
            msg.put(OnlineCF.Settings.MSG_ITEM_MATRIX, entry.getKey());
            msg.put(OnlineCF.Settings.MSG_VALUE, new VectorWritable(entry.getValue()));
            for (IntWritable sender : senders.get(entry.getKey())) {
                peer.send(peer.getPeerName(sender.get()), msg);
            }
        }
    }

    /**
     * Drains incoming item-vector messages and averages them per item.
     *
     * @param averages filled with item id to mean of the received vectors
     * @param senders filled with item id to the indexes of peers that sent a vector for it
     */
    private void getNormalizedItemFactorizedValues(BSPPeer<Text, VectorWritable, Text, VectorWritable, MapWritable> peer, HashMap<Text, DoubleVector> averages, HashMap<Text, LinkedList<IntWritable>> senders) throws IOException {
        HashMap<Text, Integer> counts = new HashMap<>();
        MapWritable msg;
        while ((msg = peer.getCurrentMessage()) != null) {
            Text itemId = (Text) msg.get(OnlineCF.Settings.MSG_ITEM_MATRIX);
            VectorWritable value = (VectorWritable) msg.get(OnlineCF.Settings.MSG_VALUE);
            IntWritable sender = (IntWritable) msg.get(OnlineCF.Settings.MSG_SENDER_ID);
            if (!averages.containsKey(itemId)) {
                averages.put(itemId, new DenseDoubleVector(this.MATRIX_RANK, 0.0d));
                counts.put(itemId, 0);
                senders.put(itemId, new LinkedList<IntWritable>());
            }
            averages.put(itemId, averages.get(itemId).add(value.getVector()));
            counts.put(itemId, counts.get(itemId) + 1);
            senders.get(itemId).add(sender);
        }
        for (Map.Entry<Text, DoubleVector> entry : averages.entrySet()) {
            entry.setValue(entry.getValue().multiply(1.0d / counts.get(entry.getKey())));
        }
    }

    /**
     * Replaces local item factors with the averaged vectors received from their owners.
     */
    private void receiveSyncedItemFactorizedValues(BSPPeer<Text, VectorWritable, Text, VectorWritable, MapWritable> peer) throws IOException {
        MapWritable msg;
        while ((msg = peer.getCurrentMessage()) != null) {
            this.itemsMatrix.put(msg.get(OnlineCF.Settings.MSG_ITEM_MATRIX).toString(), (VectorWritable) msg.get(OnlineCF.Settings.MSG_VALUE));
        }
    }

    /**
     * Sends every local item factor vector to the peer that owns that item id (by hash).
     */
    private void sendItemFactorizedValues(BSPPeer<Text, VectorWritable, Text, VectorWritable, MapWritable> peer) throws IOException, SyncException, InterruptedException {
        int numPeers = peer.getNumPeers();
        IntWritable self = new IntWritable(peer.getPeerIndex());
        for (Map.Entry<String, VectorWritable> entry : this.itemsMatrix.entrySet()) {
            MapWritable msg = new MapWritable();
            msg.put(OnlineCF.Settings.MSG_ITEM_MATRIX, new Text(entry.getKey()));
            msg.put(OnlineCF.Settings.MSG_VALUE, entry.getValue());
            msg.put(OnlineCF.Settings.MSG_SENDER_ID, self);
            peer.send(peer.getPeerName(peerIndexFor(entry.getKey(), numPeers)), msg);
        }
    }

    /**
     * Maps a key to a peer index in [0, numPeers). Java's % keeps the dividend's sign, so the raw
     * remainder of a negative hashCode() would be a negative index. NOTE(review): this uses the
     * same formula as Hama's HashPartitioner (abs(hash % n)); confirm that the partitioner in use matches.
     */
    private static int peerIndexFor(Object key, int numPeers) {
        return Math.abs(key.hashCode() % numPeers);
    }

    /**
     * Runs one SGD-style pass of the update function over all local preferences in random order.
     */
    private void computeValues() {
        // Fisher-Yates shuffle; rnd.nextInt(bound) avoids Math.abs(Integer.MIN_VALUE) < 0.
        for (int size = this.indexes.size(); size > 1; size--) {
            int pick = this.rnd.nextInt(size);
            int last = size - 1;
            Integer picked = this.indexes.get(pick);
            this.indexes.set(pick, this.indexes.get(last));
            this.indexes.set(last, picked);
        }
        OnlineUpdate.InputStructure input = new OnlineUpdate.InputStructure();
        for (Integer index : this.indexes) {
            Preference<String, String> preference = this.preferences.get(index);
            input.user = this.usersMatrix.get(preference.getUserId());
            input.item = this.itemsMatrix.get(preference.getItemId());
            input.expectedScore = preference.getValue();
            input.userFeatures = this.inpUsersFeatures != null ? this.inpUsersFeatures.get(preference.getUserId()) : null;
            input.itemFeatures = this.inpItemsFeatures != null ? this.inpItemsFeatures.get(preference.getItemId()) : null;
            input.userFeatureFactorized = this.userFeatureMatrix;
            input.itemFeatureFactorized = this.itemFeatureMatrix;
            OnlineUpdate.OutputStructure output = this.function.compute(input);
            this.usersMatrix.put(preference.getUserId(), output.userFactorized);
            this.itemsMatrix.put(preference.getItemId(), output.itemFactorized);
            this.userFeatureMatrix = output.userFeatureFactorized;
            this.itemFeatureMatrix = output.itemFeatureFactorized;
        }
    }

    /**
     * Writes user factors, owner-averaged item factors and, where present, feature vectors and the
     * master-averaged feature matrices. Every peer performs the same sync() sequence.
     */
    private void saveModel(BSPPeer<Text, VectorWritable, Text, VectorWritable, MapWritable> peer) throws IOException, SyncException, InterruptedException {
        LOG.info(peer.getPeerName() + ") saving " + this.usersMatrix.size() + " users");
        for (Map.Entry<String, VectorWritable> entry : this.usersMatrix.entrySet()) {
            peer.write(new Text(OnlineCF.Settings.DFLT_MODEL_USER_DELIM + entry.getKey()), entry.getValue());
        }
        sendItemFactorizedValues(peer);
        peer.sync();
        HashMap<Text, LinkedList<IntWritable>> senders = new HashMap<>();
        HashMap<Text, DoubleVector> averagedItems = new HashMap<>();
        getNormalizedItemFactorizedValues(peer, averagedItems, senders);
        saveItemFactorizedValues(peer, averagedItems);

        if (this.itemFeatureMatrix != null) {
            for (Map.Entry<String, VectorWritable> entry : this.inpItemsFeatures.entrySet()) {
                peer.write(new Text(OnlineCF.Settings.DFLT_MODEL_ITEM_FEATURES_DELIM + entry.getKey()), entry.getValue());
            }
        }
        // Unconditional so all peers hit the same barrier even if only some hold the matrix.
        DoubleMatrix itemAverage = normalizeMatrix(peer, this.itemFeatureMatrix, OnlineCF.Settings.MSG_ITEM_FEATURE_MATRIX, false);
        if (itemAverage != null) {
            peer.write(new Text(OnlineCF.Settings.DFLT_MODEL_ITEM_MTX_FEATURES_DELIM + OnlineCF.Settings.MSG_ITEM_FEATURE_MATRIX.toString()), convertMatrixToVector(itemAverage));
        }

        if (this.userFeatureMatrix != null) {
            for (Map.Entry<String, VectorWritable> entry : this.inpUsersFeatures.entrySet()) {
                peer.write(new Text(OnlineCF.Settings.DFLT_MODEL_USER_FEATURES_DELIM + entry.getKey()), entry.getValue());
            }
        }
        DoubleMatrix userAverage = normalizeMatrix(peer, this.userFeatureMatrix, OnlineCF.Settings.MSG_USER_FEATURE_MATRIX, false);
        if (userAverage != null) {
            peer.write(new Text(OnlineCF.Settings.DFLT_MODEL_USER_MTX_FEATURES_DELIM + OnlineCF.Settings.MSG_USER_FEATURE_MATRIX.toString()), convertMatrixToVector(userAverage));
        }
    }

    /**
     * Writes the averaged item factor vectors owned by this peer.
     */
    private void saveItemFactorizedValues(BSPPeer<Text, VectorWritable, Text, VectorWritable, MapWritable> peer, HashMap<Text, DoubleVector> averagedItems) throws IOException {
        LOG.info(peer.getPeerName() + ") saving " + averagedItems.size() + " items");
        for (Map.Entry<Text, DoubleVector> entry : averagedItems.entrySet()) {
            peer.write(new Text(OnlineCF.Settings.DFLT_MODEL_ITEM_DELIM + entry.getKey().toString()), new VectorWritable(entry.getValue()));
        }
    }

    /**
     * Answers feature requests from other peers with the locally read feature vectors. Requests
     * for features this peer does not hold are skipped: the original put a null map value, which
     * cannot be sent, or hit an NPE when this peer read no features at all.
     */
    private void sendRequiredFeatures(BSPPeer<Text, VectorWritable, Text, VectorWritable, MapWritable> peer) throws IOException, SyncException, InterruptedException {
        MapWritable request;
        while ((request = peer.getCurrentMessage()) != null) {
            int requester = ((IntWritable) request.get(OnlineCF.Settings.MSG_SENDER_ID)).get();
            MapWritable reply = new MapWritable();
            if (request.containsKey(OnlineCF.Settings.MSG_INP_ITEM_FEATURES)) {
                // Request keys carry a one-character delimiter prefix; strip it to get the id.
                String itemId = request.get(OnlineCF.Settings.MSG_INP_ITEM_FEATURES).toString().substring(1);
                VectorWritable features = this.inpItemsFeatures == null ? null : this.inpItemsFeatures.get(itemId);
                if (features == null) {
                    continue;
                }
                reply.put(OnlineCF.Settings.MSG_INP_ITEM_FEATURES, new Text(itemId));
                reply.put(OnlineCF.Settings.MSG_VALUE, features);
            } else if (request.containsKey(OnlineCF.Settings.MSG_INP_USER_FEATURES)) {
                String userId = request.get(OnlineCF.Settings.MSG_INP_USER_FEATURES).toString().substring(1);
                VectorWritable features = this.inpUsersFeatures == null ? null : this.inpUsersFeatures.get(userId);
                if (features == null) {
                    continue;
                }
                reply.put(OnlineCF.Settings.MSG_INP_USER_FEATURES, new Text(userId));
                reply.put(OnlineCF.Settings.MSG_VALUE, features);
            } else {
                continue;
            }
            peer.send(peer.getPeerName(requester), reply);
        }
    }

    /**
     * Replaces the feature maps with the feature vectors received for this peer's preferences and
     * creates random feature matrices sized from them (null when none were received).
     */
    private void collectFeatures(BSPPeer<Text, VectorWritable, Text, VectorWritable, MapWritable> peer) throws IOException {
        this.inpItemsFeatures = new HashMap<>();
        this.inpUsersFeatures = new HashMap<>();
        int userFeatureLength = 0;
        int itemFeatureLength = 0;
        MapWritable msg;
        while ((msg = peer.getCurrentMessage()) != null) {
            if (msg.containsKey(OnlineCF.Settings.MSG_INP_ITEM_FEATURES)) {
                VectorWritable features = (VectorWritable) msg.get(OnlineCF.Settings.MSG_VALUE);
                this.inpItemsFeatures.put(msg.get(OnlineCF.Settings.MSG_INP_ITEM_FEATURES).toString(), features);
                itemFeatureLength = features.getVector().getLength();
            } else if (msg.containsKey(OnlineCF.Settings.MSG_INP_USER_FEATURES)) {
                VectorWritable features = (VectorWritable) msg.get(OnlineCF.Settings.MSG_VALUE);
                this.inpUsersFeatures.put(msg.get(OnlineCF.Settings.MSG_INP_USER_FEATURES).toString(), features);
                userFeatureLength = features.getVector().getLength();
            }
        }
        if (!this.inpItemsFeatures.isEmpty()) {
            this.itemFeatureMatrix = new DenseDoubleMatrix(this.MATRIX_RANK, itemFeatureLength, this.rnd);
        }
        if (!this.inpUsersFeatures.isEmpty()) {
            this.userFeatureMatrix = new DenseDoubleMatrix(this.MATRIX_RANK, userFeatureLength, this.rnd);
        }
    }

    /**
     * Requests the given (delimiter-prefixed) user and item features from the peers assumed to
     * hold them. NOTE(review): routing by key hash assumes the input was hash-partitioned on
     * these keys; confirm against the job's partitioning.
     */
    private void askForFeatures(BSPPeer<Text, VectorWritable, Text, VectorWritable, MapWritable> peer, HashSet<Text> requiredUserFeatures, HashSet<Text> requiredItemFeatures) throws IOException, SyncException, InterruptedException {
        int numPeers = peer.getNumPeers();
        IntWritable self = new IntWritable(peer.getPeerIndex());
        if (requiredUserFeatures != null) {
            for (Text userKey : requiredUserFeatures) {
                MapWritable request = new MapWritable();
                request.put(OnlineCF.Settings.MSG_INP_USER_FEATURES, userKey);
                request.put(OnlineCF.Settings.MSG_SENDER_ID, self);
                peer.send(peer.getPeerName(peerIndexFor(userKey, numPeers)), request);
            }
        }
        if (requiredItemFeatures != null) {
            for (Text itemKey : requiredItemFeatures) {
                MapWritable request = new MapWritable();
                request.put(OnlineCF.Settings.MSG_INP_ITEM_FEATURES, itemKey);
                request.put(OnlineCF.Settings.MSG_SENDER_ID, self);
                peer.send(peer.getPeerName(peerIndexFor(itemKey, numPeers)), request);
            }
        }
    }

    /**
     * Reads all local input records. Preferences get random initial user/item factors, and the
     * ids they reference are added to the required-feature sets (when non-null). User/item feature
     * records are stored in the input feature maps.
     */
    private void collectInput(BSPPeer<Text, VectorWritable, Text, VectorWritable, MapWritable> peer, HashSet<Text> requiredUserFeatures, HashSet<Text> requiredItemFeatures) throws IOException {
        Text key = new Text();
        VectorWritable value = new VectorWritable();
        int preferenceIndex = 0;
        while (peer.readNext(key, value)) {
            String record = key.toString();
            if (record.isEmpty()) {
                // No type marker; the original threw StringIndexOutOfBoundsException here.
                continue;
            }
            String kind = record.substring(0, 1);
            String id = record.substring(1);
            if (kind.equals(this.inputPreferenceDelim)) {
                String itemId = Long.toString((long) value.getVector().get(0));
                double score = value.getVector().get(1);
                if (!this.usersMatrix.containsKey(id)) {
                    this.usersMatrix.put(id, randomFactorVector());
                }
                if (!this.itemsMatrix.containsKey(itemId)) {
                    this.itemsMatrix.put(itemId, randomFactorVector());
                }
                this.preferences.add(new Preference<>(id, itemId, score));
                this.indexes.add(preferenceIndex++);
                if (requiredUserFeatures != null) {
                    requiredUserFeatures.add(new Text(this.inputUserDelim + id));
                }
                if (requiredItemFeatures != null) {
                    requiredItemFeatures.add(new Text(this.inputItemDelim + itemId));
                }
            } else if (kind.equals(this.inputUserDelim)) {
                if (this.inpUsersFeatures == null) {
                    this.inpUsersFeatures = new HashMap<>();
                }
                this.inpUsersFeatures.put(id, value);
                // The map now owns this object; readNext must not overwrite it on the next record.
                value = new VectorWritable();
            } else if (kind.equals(this.inputItemDelim)) {
                if (this.inpItemsFeatures == null) {
                    this.inpItemsFeatures = new HashMap<>();
                }
                this.inpItemsFeatures.put(id, value);
                value = new VectorWritable();
            }
        }
    }

    /**
     * Creates a MATRIX_RANK-length vector of uniform random values in [0, 1).
     */
    private VectorWritable randomFactorVector() {
        DenseDoubleVector vector = new DenseDoubleVector(this.MATRIX_RANK);
        for (int k = 0; k < this.MATRIX_RANK; k++) {
            vector.set(k, this.rnd.nextDouble());
        }
        return new VectorWritable(vector);
    }
}
