package org.apache.hama.ml.kmeans;

import com.google.common.base.Preconditions;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Random;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.NullWritable;
import org.apache.hadoop.io.SequenceFile;
import org.apache.hama.HamaConfiguration;
import org.apache.hama.bsp.BSP;
import org.apache.hama.bsp.BSPJob;
import org.apache.hama.bsp.BSPPeer;
import org.apache.hama.bsp.SequenceFileInputFormat;
import org.apache.hama.bsp.SequenceFileOutputFormat;
import org.apache.hama.bsp.TextOutputFormat;
import org.apache.hama.bsp.sync.SyncException;
import org.apache.hama.commons.io.VectorWritable;
import org.apache.hama.commons.math.DenseDoubleVector;
import org.apache.hama.commons.math.DoubleVector;
import org.apache.hama.commons.math.NamedDoubleVector;
import org.apache.hama.ml.distance.DistanceMeasurer;
import org.apache.hama.ml.distance.EuclidianDistance;
import org.apache.hama.util.ReflectionUtils;

/* loaded from: input_file:org/apache/hama/ml/kmeans/KMeansBSP.class */
public final class KMeansBSP extends BSP<VectorWritable, NullWritable, IntWritable, VectorWritable, CenterMessage> {
    /** Configuration key: path of the sequence file the first peer writes the final centers to (skipped when unset). */
    public static final String CENTER_OUT_PATH = "center.out.path";
    /** Configuration key: maximum number of iterations; read in setup with a default of -1, and values <= 0 disable the limit. */
    public static final String MAX_ITERATIONS_KEY = "k.means.max.iterations";
    /** Configuration key: boolean (default false) that makes each peer keep its input vectors in memory. */
    public static final String CACHING_ENABLED_KEY = "k.means.caching.enabled";
    /** Configuration key: class name of the DistanceMeasurer to instantiate; EuclidianDistance is used when unset. */
    public static final String DISTANCE_MEASURE_CLASS = "distance.measure.class";
    /** Configuration key: path of the sequence file holding the initial centers, read in setup. */
    public static final String CENTER_IN_PATH = "center.in.path";
    private static final Log LOG = LogFactory.getLog(KMeansBSP.class);
    /** Current cluster centers; loaded in setup and replaced by updateCenters. */
    private DoubleVector[] centers;
    /** In-memory copy of this peer's input vectors; stays null when caching is disabled. */
    private List<DoubleVector> cache;
    /** Iteration limit read from MAX_ITERATIONS_KEY. */
    private int maxIterations;
    /** Distance function used to pick the nearest center. */
    private DistanceMeasurer distanceMeasurer;
    /** Peer configuration captured in setup; used later to locate the center output path. */
    private Configuration conf;

    /**
     * Prepares this peer: loads the initial centers from the sequence file
     * configured under {@link #CENTER_IN_PATH}, instantiates the distance
     * measurer, and reads the iteration limit and the caching flag.
     *
     * @param bSPPeer the peer whose configuration supplies all settings
     * @throws IOException if the file system cannot be obtained
     * @throws RuntimeException if reading the centers file fails (wrapping the
     *         IOException) or if the configured distance measurer class cannot
     *         be found (wrapping the ClassNotFoundException)
     * @throws IllegalArgumentException if the centers file contains no center
     */
    public final void setup(BSPPeer<VectorWritable, NullWritable, IntWritable, VectorWritable, CenterMessage> bSPPeer) throws IOException, InterruptedException {
        this.conf = bSPPeer.getConfiguration();
        Path centerPath = new Path(this.conf.get(CENTER_IN_PATH));
        FileSystem fileSystem = FileSystem.get(this.conf);

        List<DoubleVector> initialCenters = new ArrayList<DoubleVector>();
        try {
            SequenceFile.Reader reader = new SequenceFile.Reader(fileSystem, centerPath, this.conf);
            try {
                VectorWritable key = new VectorWritable();
                NullWritable value = NullWritable.get();
                while (reader.next(key, value)) {
                    initialCenters.add(key.getVector());
                }
            } finally {
                // Close exactly once, whether or not reading succeeded.
                reader.close();
            }
        } catch (IOException e) {
            throw new RuntimeException(e);
        }

        Preconditions.checkArgument(initialCenters.size() > 0, "Centers file must contain at least a single center!");
        this.centers = initialCenters.toArray(new DoubleVector[initialCenters.size()]);

        String measurerClass = this.conf.get(DISTANCE_MEASURE_CLASS);
        if (measurerClass == null) {
            this.distanceMeasurer = new EuclidianDistance();
        } else {
            try {
                this.distanceMeasurer = (DistanceMeasurer) ReflectionUtils.newInstance(measurerClass);
            } catch (ClassNotFoundException e) {
                // Keep the cause so the failing class lookup stays visible in the stack trace.
                throw new RuntimeException("Wrong DistanceMeasurer implementation " + measurerClass + " provided", e);
            }
        }

        this.maxIterations = this.conf.getInt(MAX_ITERATIONS_KEY, -1);
        if (this.conf.getBoolean(CACHING_ENABLED_KEY, false)) {
            this.cache = new ArrayList<DoubleVector>();
        }
    }

    /**
     * Runs the k-means iterations: every round assigns the local vectors to
     * their nearest center, synchronizes with the other peers, folds the
     * received partial sums into new centers and reopens the input.
     * <p>
     * The loop ends as soon as a round changes no center, or once a positive
     * iteration limit is smaller than the peer's superstep count. Afterwards
     * the final assignments are written out.
     *
     * @param bSPPeer the peer executing this task
     */
    public final void bsp(BSPPeer<VectorWritable, NullWritable, IntWritable, VectorWritable, CenterMessage> bSPPeer) throws IOException, InterruptedException, SyncException {
        while (true) {
            assignCenters(bSPPeer);
            bSPPeer.sync();
            long changedCenters = updateCenters(bSPPeer);
            bSPPeer.reopenInput();
            if (changedCenters == 0) {
                // Converged: no center changed during this round.
                break;
            }
            if (this.maxIterations > 0 && this.maxIterations < bSPPeer.getSuperstepCount()) {
                // The configured iteration budget is used up.
                break;
            }
        }
        LOG.info("Finished! Writing the assignments...");
        recalculateAssignmentsAndWrite(bSPPeer);
        LOG.info("Done.");
    }

    /**
     * Consumes all pending center messages, averages the partial sums per
     * center index and replaces every center whose averaged value differs
     * from the current one.
     *
     * @param bSPPeer the peer whose message queue is drained
     * @return the number of centers that were replaced
     */
    private long updateCenters(BSPPeer<VectorWritable, NullWritable, IntWritable, VectorWritable, CenterMessage> bSPPeer) throws IOException {
        final int k = this.centers.length;
        DoubleVector[] sums = new DoubleVector[k];
        int[] counts = new int[k];

        // Fold every received partial sum into the slot of its center.
        for (CenterMessage message = (CenterMessage) bSPPeer.getCurrentMessage();
                message != null;
                message = (CenterMessage) bSPPeer.getCurrentMessage()) {
            int slot = message.getCenterIndex();
            DoubleVector partial = message.getData();
            counts[slot] += message.getIncrementCounter();
            sums[slot] = sums[slot] == null ? partial : sums[slot].addUnsafe(partial);
        }

        // Turn the sums into means; slots that received nothing stay null.
        for (int slot = 0; slot < k; slot++) {
            if (sums[slot] != null) {
                sums[slot] = sums[slot].divide(counts[slot]);
            }
        }

        long replaced = 0;
        for (int slot = 0; slot < k; slot++) {
            DoubleVector mean = sums[slot];
            if (mean == null) {
                continue;
            }
            // A center counts as changed when the summed absolute difference is positive.
            double delta = this.centers[slot].subtractUnsafe(mean).abs().sum();
            if (delta > 0.0d) {
                this.centers[slot] = mean;
                replaced++;
            }
        }
        return replaced;
    }

    /**
     * Assigns every local vector to its nearest center and sends the
     * per-center partial sum and count to all peers.
     * <p>
     * Vectors come from the in-memory cache when it is enabled and already
     * filled; otherwise they are read from the peer's input, and stored in the
     * cache if caching is enabled.
     *
     * @param bSPPeer the peer providing the input and the messaging
     */
    private void assignCenters(BSPPeer<VectorWritable, NullWritable, IntWritable, VectorWritable, CenterMessage> bSPPeer) throws IOException {
        DoubleVector[] partialSums = new DoubleVector[this.centers.length];
        int[] memberCounts = new int[this.centers.length];

        boolean cachingEnabled = this.cache != null;
        if (cachingEnabled && !this.cache.isEmpty()) {
            for (DoubleVector cached : this.cache) {
                assignCentersInternal(partialSums, memberCounts, cached);
            }
        } else {
            NullWritable value = NullWritable.get();
            VectorWritable key = new VectorWritable();
            while (bSPPeer.readNext(key, value)) {
                DoubleVector copy = key.getVector().deepCopy();
                if (cachingEnabled) {
                    this.cache.add(copy);
                }
                assignCentersInternal(partialSums, memberCounts, copy);
            }
        }

        for (int center = 0; center < partialSums.length; center++) {
            if (partialSums[center] == null) {
                // No local vector was assigned to this center.
                continue;
            }
            for (String peerName : bSPPeer.getAllPeerNames()) {
                bSPPeer.send(peerName, new CenterMessage(center, memberCounts[center], partialSums[center]));
            }
        }
    }

    /**
     * Adds one vector to the running sum of its nearest center and increments
     * that center's member count.
     *
     * @param doubleVectorArr per-center running sums (null until first use)
     * @param iArr per-center member counts
     * @param doubleVector the vector to assign
     */
    private void assignCentersInternal(DoubleVector[] doubleVectorArr, int[] iArr, DoubleVector doubleVector) {
        final int target = getNearestCenter(doubleVector);
        DoubleVector runningSum = doubleVectorArr[target];
        // The first member becomes the sum itself; later members are added to it.
        doubleVectorArr[target] = runningSum == null ? doubleVector : runningSum.addUnsafe(doubleVector);
        iArr[target]++;
    }

    /**
     * Finds the index of the center with the smallest measured distance to
     * the given vector. Ties keep the lowest index; index 0 is returned when
     * no distance is below {@code Double.MAX_VALUE}.
     *
     * @param doubleVector the vector to classify
     * @return index into the centers array
     */
    private int getNearestCenter(DoubleVector doubleVector) {
        int bestIndex = 0;
        double bestDistance = Double.MAX_VALUE;
        int index = 0;
        for (DoubleVector center : this.centers) {
            double distance = this.distanceMeasurer.measureDistance(center, doubleVector);
            if (distance < bestDistance) {
                bestDistance = distance;
                bestIndex = index;
            }
            index++;
        }
        return bestIndex;
    }

    /**
     * Writes the final (center index, vector) assignment for every local
     * vector to the peer's output. Vectors are taken from the cache when
     * caching is enabled, otherwise they are read from the peer's input.
     * <p>
     * The peer whose name equals the name at peer index 0 additionally writes
     * the final centers to the file configured under {@link #CENTER_OUT_PATH},
     * if that key is set.
     *
     * @param bSPPeer the peer providing the input and collecting the output
     * @throws IOException if reading the input or writing any output fails
     */
    private void recalculateAssignmentsAndWrite(BSPPeer<VectorWritable, NullWritable, IntWritable, VectorWritable, CenterMessage> bSPPeer) throws IOException {
        NullWritable nullWritable = NullWritable.get();
        IntWritable clusterIndex = new IntWritable();
        if (this.cache == null) {
            VectorWritable vectorWritable = new VectorWritable();
            while (bSPPeer.readNext(vectorWritable, nullWritable)) {
                clusterIndex.set(getNearestCenter(vectorWritable.getVector()));
                bSPPeer.write(clusterIndex, vectorWritable);
            }
        } else {
            for (DoubleVector cached : this.cache) {
                clusterIndex.set(getNearestCenter(cached));
                bSPPeer.write(clusterIndex, new VectorWritable(cached));
            }
        }

        // Only one peer persists the centers.
        if (!bSPPeer.getPeerName().equals(bSPPeer.getPeerName(0))) {
            return;
        }
        String centerOutPath = this.conf.get(CENTER_OUT_PATH);
        if (centerOutPath == null) {
            return;
        }
        SequenceFile.Writer centerWriter = SequenceFile.createWriter(FileSystem.get(this.conf), this.conf, new Path(centerOutPath), VectorWritable.class, NullWritable.class, SequenceFile.CompressionType.NONE);
        try {
            for (DoubleVector center : this.centers) {
                centerWriter.append(new VectorWritable(center), nullWritable);
            }
        } finally {
            // Release the file handle even when an append fails.
            centerWriter.close();
        }
    }

    /**
     * Builds a BSP job that runs this class on sequence-file input and emits
     * IntWritable keys with VectorWritable values.
     *
     * @param configuration base configuration, wrapped in a HamaConfiguration
     * @param path input path of the job
     * @param path2 output path of the job
     * @param z true to write text output, false to write sequence files
     * @return the configured, not yet submitted job
     */
    public static BSPJob createJob(Configuration configuration, Path path, Path path2, boolean z) throws IOException {
        HamaConfiguration hamaConf = new HamaConfiguration(configuration);
        BSPJob job = new BSPJob(hamaConf, KMeansBSP.class);
        job.setJobName("KMeans Clustering");
        job.setJarByClass(KMeansBSP.class);
        job.setBspClass(KMeansBSP.class);

        // Input side.
        job.setInputPath(path);
        job.setOutputPath(path2);
        job.setInputFormat(SequenceFileInputFormat.class);

        // Output side: format depends on the caller's choice.
        if (!z) {
            job.setOutputFormat(SequenceFileOutputFormat.class);
        } else {
            job.setOutputFormat(TextOutputFormat.class);
        }
        job.setOutputKeyClass(IntWritable.class);
        job.setOutputValueClass(VectorWritable.class);
        return job;
    }

    /**
     * Command line entry point: generates random input and initial centers,
     * then runs the clustering job with caching enabled.
     * <p>
     * Arguments: input path, output path, vector count, k, vector dimension,
     * max iterations, and optionally the number of BSP tasks. The task count
     * is only applied when exactly seven arguments are given.
     *
     * @param strArr the command line arguments
     */
    public static void main(String[] strArr) throws IOException, ClassNotFoundException, InterruptedException {
        if (strArr.length < 6) {
            LOG.info("USAGE: <INPUT_PATH> <OUTPUT_PATH> <COUNT> <K> <DIMENSION OF VECTORS> <MAXITERATIONS> <optional: num of tasks>");
            return;
        }

        Configuration settings = new Configuration();
        final int count = Integer.parseInt(strArr[2]);
        final int k = Integer.parseInt(strArr[3]);
        final int dimension = Integer.parseInt(strArr[4]);
        final int iterations = Integer.parseInt(strArr[5]);
        settings.setInt(MAX_ITERATIONS_KEY, iterations);

        Path input = new Path(strArr[0]);
        Path output = new Path(strArr[1]);
        // NOTE(review): the centers file is placed beneath the input path, while
        // prepareInput also opens a writer directly on the input path itself -
        // confirm the target file system tolerates that layout.
        Path centerIn = new Path(input, "center/cen.seq");
        Path centerOut = new Path(output, "center/center_output.seq");
        settings.set(CENTER_IN_PATH, centerIn.toString());
        settings.set(CENTER_OUT_PATH, centerOut.toString());
        settings.set("bsp.local.tasks.maximum", String.valueOf(Runtime.getRuntime().availableProcessors()));
        settings.setBoolean(CACHING_ENABLED_KEY, true);

        // The job is created only after every setting above has been applied.
        BSPJob job = createJob(settings, input, output, false);

        StringBuilder summary = new StringBuilder();
        summary.append("N: ").append(count);
        summary.append(" k: ").append(k);
        summary.append(" Dimension: ").append(dimension);
        summary.append(" Iterations: ").append(iterations);
        LOG.info(summary.toString());

        FileSystem fileSystem = FileSystem.get(settings);
        prepareInput(count, k, dimension, settings, input, centerIn, output, fileSystem);

        boolean taskCountGiven = strArr.length == 7;
        if (taskCountGiven) {
            job.setNumBspTask(Integer.parseInt(strArr[6]));
        }
        job.waitForCompletion(true);
    }

    /**
     * Reads cluster centers from a sequence file and numbers them in file
     * order, starting at 0.
     *
     * @param configuration configuration used to open the sequence file
     * @param path not used by this method
     * @param path2 sequence file containing the centers as VectorWritable keys
     * @param fileSystem file system that holds {@code path2}
     * @return map from center index to center vector
     * @throws IOException if the file cannot be opened or read
     */
    public static HashMap<Integer, DoubleVector> readClusterCenters(Configuration configuration, Path path, Path path2, FileSystem fileSystem) throws IOException {
        HashMap<Integer, DoubleVector> centersByIndex = new HashMap<>();
        SequenceFile.Reader reader = new SequenceFile.Reader(fileSystem, path2, configuration);
        try {
            VectorWritable key = new VectorWritable();
            int index = 0;
            while (reader.next(key, NullWritable.get())) {
                centersByIndex.put(Integer.valueOf(index), key.getVector());
                index++;
            }
        } finally {
            // Do not leak the reader when a record cannot be read.
            reader.close();
        }
        return centersByIndex;
    }

    /**
     * Reads the textual job output (all files matching {@code part-*} below
     * the output path) and renders each line as
     * {@code "<vector> belongs to cluster <index>"}.
     * <p>
     * Each line is split on a tab character; the first field is used as the
     * cluster index and the second as the vector. Reading stops as soon as
     * the result holds at least {@code maxLines} entries.
     * <p>
     * This body was reconstructed from the decompiler's instruction dump,
     * which the decompiler itself could not restructure.
     *
     * @param conf not used by this method
     * @param out output directory of the job
     * @param fs file system that holds {@code out}
     * @param maxLines number of entries after which reading stops
     * @return the rendered lines, in reading order
     * @throws IOException if a part file cannot be listed, opened or read
     * @throws ArrayIndexOutOfBoundsException if a line has no tab separator
     */
    public static List<String> readOutput(Configuration conf, Path out, FileSystem fs, int maxLines) throws IOException {
        List<String> rendered = new ArrayList<>();
        org.apache.hadoop.fs.FileStatus[] partFiles = fs.globStatus(new Path(out + "/part-*"));
        if (partFiles == null) {
            // Defensive: nothing matched, so there is nothing to read.
            return rendered;
        }
        for (org.apache.hadoop.fs.FileStatus partFile : partFiles) {
            BufferedReader reader = new BufferedReader(new InputStreamReader(fs.open(partFile.getPath())));
            try {
                String line;
                while ((line = reader.readLine()) != null) {
                    String[] fields = line.split("\t");
                    rendered.add(fields[1] + " belongs to cluster " + fields[0]);
                    if (rendered.size() >= maxLines) {
                        return rendered;
                    }
                }
            } finally {
                // Also runs on the early return above.
                reader.close();
            }
        }
        return rendered;
    }

    /**
     * Converts a tab-separated text file into a sequence file of vectors and
     * writes the first {@code i} vectors as initial centers.
     * <p>
     * The converted file is {@code textinput/in.seq}, created next to the
     * input when the input is a file and below it otherwise. Any existing
     * data at {@code path3}, {@code path2} and the converted file is deleted
     * first.
     *
     * @param i number of leading rows that are also written as centers
     * @param configuration configuration used for the sequence file writers
     * @param path the text input to read
     * @param path2 destination of the centers sequence file
     * @param path3 path that is deleted recursively before converting
     * @param fileSystem file system holding all paths
     * @param z true when the first column of each row is a name rather than
     *        a number; the row then becomes a NamedDoubleVector
     * @return the path of the converted sequence file
     * @throws IOException if any file operation fails
     * @throws NumberFormatException if a numeric column cannot be parsed
     */
    public static Path prepareInputText(int i, Configuration configuration, Path path, Path path2, Path path3, FileSystem fileSystem, boolean z) throws IOException {
        Path converted = fileSystem.isFile(path) ? new Path(path.getParent(), "textinput/in.seq") : new Path(path, "textinput/in.seq");
        if (fileSystem.exists(path3)) {
            fileSystem.delete(path3, true);
        }
        if (fileSystem.exists(path2)) {
            fileSystem.delete(path2, true);
        }
        if (fileSystem.exists(converted)) {
            fileSystem.delete(converted, true);
        }

        NullWritable value = NullWritable.get();
        // With a name column, the numeric values start one column later.
        final int firstValueColumn = z ? 1 : 0;

        // Each resource is closed in its own finally block so that a failure
        // while parsing or writing cannot leak the ones opened earlier.
        SequenceFile.Writer centerWriter = new SequenceFile.Writer(fileSystem, configuration, path2, VectorWritable.class, NullWritable.class);
        try {
            SequenceFile.Writer dataWriter = SequenceFile.createWriter(fileSystem, configuration, converted, VectorWritable.class, NullWritable.class, SequenceFile.CompressionType.NONE);
            try {
                BufferedReader lines = new BufferedReader(new InputStreamReader(fileSystem.open(path)));
                try {
                    int row = 0;
                    String line;
                    while ((line = lines.readLine()) != null) {
                        String[] columns = line.split("\t");
                        int dimension = columns.length - firstValueColumn;
                        DenseDoubleVector vector = new DenseDoubleVector(dimension);
                        for (int column = 0; column < dimension; column++) {
                            vector.set(column, Double.parseDouble(columns[column + firstValueColumn]));
                        }
                        VectorWritable record = z ? new VectorWritable(new NamedDoubleVector(columns[0], vector)) : new VectorWritable(vector);
                        dataWriter.append(record, value);
                        if (i > row) {
                            centerWriter.append(record, value);
                        }
                        row++;
                    }
                } finally {
                    lines.close();
                }
            } finally {
                dataWriter.close();
            }
        } finally {
            centerWriter.close();
        }
        return converted;
    }

    /**
     * Generates random input vectors and writes the first {@code i2} of them
     * as initial centers. Every component is a random integer in
     * {@code [0, i)}. Existing data at {@code path3}, {@code path2} and
     * {@code path} is deleted recursively first.
     *
     * @param i number of vectors to generate
     * @param i2 number of leading vectors that are also written as centers
     * @param i3 dimension of every vector
     * @param configuration configuration used for the sequence file writers
     * @param path destination of the generated input sequence file
     * @param path2 destination of the centers sequence file
     * @param path3 path that is deleted before generating
     * @param fileSystem file system holding all paths
     * @throws IOException if any file operation fails
     */
    public static void prepareInput(int i, int i2, int i3, Configuration configuration, Path path, Path path2, Path path3, FileSystem fileSystem) throws IOException {
        if (fileSystem.exists(path3)) {
            fileSystem.delete(path3, true);
        }
        if (fileSystem.exists(path2)) {
            fileSystem.delete(path2, true);
        }
        if (fileSystem.exists(path)) {
            fileSystem.delete(path, true);
        }

        NullWritable value = NullWritable.get();
        // Both writers are closed in finally blocks so a failed append cannot leak them.
        SequenceFile.Writer centerWriter = SequenceFile.createWriter(fileSystem, configuration, path2, VectorWritable.class, NullWritable.class, SequenceFile.CompressionType.NONE);
        try {
            SequenceFile.Writer dataWriter = SequenceFile.createWriter(fileSystem, configuration, path, VectorWritable.class, NullWritable.class, SequenceFile.CompressionType.NONE);
            try {
                Random random = new Random();
                for (int row = 0; row < i; row++) {
                    double[] components = new double[i3];
                    for (int column = 0; column < i3; column++) {
                        components[column] = random.nextInt(i);
                    }
                    VectorWritable record = new VectorWritable(new DenseDoubleVector(components));
                    dataWriter.append(record, value);
                    if (i2 > row) {
                        centerWriter.append(record, value);
                    }
                }
            } finally {
                dataWriter.close();
            }
        } finally {
            centerWriter.close();
        }
    }
}
