package ws.palladian.kaggle.restaurants.aggregation;

import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Set;
import ws.palladian.classification.evaluation.ConfusionMatrixEvaluator;
import ws.palladian.classification.liblinear.LibLinearClassifier;
import ws.palladian.classification.liblinear.LibLinearLearner;
import ws.palladian.classification.liblinear.LibLinearModel;
import ws.palladian.classification.utils.CsvDatasetReader;
import ws.palladian.classification.utils.CsvDatasetReaderConfig;
import ws.palladian.core.FeatureVector;
import ws.palladian.core.Instance;
import ws.palladian.core.InstanceBuilder;
import ws.palladian.core.dataset.Dataset;
import ws.palladian.core.dataset.DefaultDataset;
import ws.palladian.core.value.ImmutableStringValue;
import ws.palladian.core.value.NumericValue;
import ws.palladian.core.value.Value;
import ws.palladian.helper.collection.DefaultMultiMap;
import ws.palladian.helper.collection.LazyMap;
import ws.palladian.helper.collection.MultiMap;
import ws.palladian.helper.date.DateHelper;
import ws.palladian.helper.functional.Factory;
import ws.palladian.helper.io.CloseableIterator;
import ws.palladian.helper.io.FileHelper;
import ws.palladian.helper.math.ConfusionMatrix;
import ws.palladian.helper.math.FatStats;
import ws.palladian.helper.math.Stats;
import ws.palladian.kaggle.restaurants.dataset.Label;
import ws.palladian.kaggle.restaurants.utils.Config;

/**
 * Aggregates the label probabilities of two classification runs into per-business features and
 * predicts business labels with one binary LibLinear model per {@link Label}.
 * <p>
 * The input files contain one row per classified item (presumably one row per photo; TODO confirm).
 * Each row has a {@code businessId} column and one numeric probability column per {@link Label}.
 * For every business, label and run, the mean and the maximum probability over the business's rows
 * become features. The features are named {@code <label>_<run>_mean_probability} and
 * {@code <label>_<run>_max_probability}, where run is {@code 1} or {@code 2}.
 * <p>
 * {@link #train()} and {@link #classify()} share the same aggregation and feature code. This keeps
 * the features seen during training identical to those seen during prediction.
 */
public class MultiClassificationAggregator {

    /** Probabilities of the first classification run on the test data. */
    private static final String TEST_PROBABILITIES_RUN_1 = "/Volumes/iMac HD/Research/Yelp_Kaggle_Restaurants/data/joined_classified_test_with_bizIds_2016-04-02_13-36-15.csv";

    /** Probabilities of the second classification run on the test data. */
    private static final String TEST_PROBABILITIES_RUN_2 = "/Volumes/iMac HD/Research/Yelp_Kaggle_Restaurants/data/joined_classified_test_with_bizIds_2016-04-13_00-28-31.csv";

    /** Probabilities of the first classification run on the training data. */
    private static final String TRAIN_PROBABILITIES_RUN_1 = "/Volumes/iMac HD/Research/Yelp_Kaggle_Restaurants/data/classified_train_true_2016-04-02_08-30-41.csv";

    /** Probabilities of the second classification run on the training data. */
    private static final String TRAIN_PROBABILITIES_RUN_2 = "/Volumes/iMac HD/Research/Yelp_Kaggle_Restaurants/data/classified_train_true_2016-04-12_22-59-48.csv";

    /** A label is assigned when the model's probability for category {@code "true"} exceeds this. */
    private static final double DECISION_THRESHOLD = 0.5;

    /**
     * Entry point. It runs {@link #classify()}; call {@link #train()} instead to rebuild the models.
     *
     * @param args ignored
     * @throws IOException if reading the inputs fails
     */
    public static void main(String[] args) throws IOException {
        classify();
    }

    /**
     * Predicts the labels of every business in the test probability files and writes a submission
     * CSV to {@code /Users/pk/Desktop/submission_multi_<datetime>.csv}.
     * <p>
     * The models are loaded from the files configured under
     * {@code model.aggregation.multi.<label in lower case>}. Each output line is
     * {@code <businessId>,<space-separated label ids>}.
     *
     * @throws IOException if reading a probability file fails
     */
    private static void classify() throws IOException {
        Map<String, Map<String, Stats>> probabilityStats = new HashMap<>();
        collectProbabilities(read(TEST_PROBABILITIES_RUN_1), "_1", probabilityStats);
        collectProbabilities(read(TEST_PROBABILITIES_RUN_2), "_2", probabilityStats);

        Map<Label, LibLinearModel> models = new HashMap<>();
        for (Label label : Label.values()) {
            // Locale.ROOT: a locale-sensitive lower-casing (e.g. Turkish dotless i) would break the key.
            String configKey = "model.aggregation.multi." + label.toString().toLowerCase(Locale.ROOT);
            models.put(label, (LibLinearModel) FileHelper.deserialize(Config.getFilePath(configKey).getAbsolutePath()));
        }

        LibLinearClassifier classifier = new LibLinearClassifier();
        StringBuilder submission = new StringBuilder();
        submission.append("business_id,labels").append('\n');
        for (Map.Entry<String, Map<String, Stats>> business : probabilityStats.entrySet()) {
            FeatureVector features = toFeatures(business.getValue()).create();
            StringBuilder line = new StringBuilder();
            line.append(business.getKey()).append(',');
            for (Label label : Label.values()) {
                if (classifier.classify(features, models.get(label)).getProbability("true") > DECISION_THRESHOLD) {
                    line.append(label.getLabelId()).append(' ');
                }
            }
            // trim() drops the trailing space after the last label id.
            submission.append(line.toString().trim()).append('\n');
        }
        FileHelper.writeToFile("/Users/pk/Desktop/submission_multi_" + DateHelper.getCurrentDatetime() + ".csv", submission.toString());
    }

    /**
     * Trains one binary LibLinear model per {@link Label} from the training probability files and
     * the ground truth in the {@code dataset.yelp.restaurants.train.csv} file.
     * <p>
     * For each label, the method prints the label and a confusion matrix to stdout. The confusion
     * matrix is computed on the model's own training data, so it is optimistic. The model is then
     * serialized to {@code /Users/pk/Desktop/aggregation_multi_<LABEL>_<datetime>.ser.gz}.
     * NOTE(review): {@link #classify()} loads models from config keys instead, so these files
     * presumably have to be registered there manually.
     *
     * @throws IOException if reading a dataset fails
     * @throws IllegalStateException if a ground-truth business has no classified rows
     */
    private static void train() throws IOException {
        // Aggregated once; the statistics do not depend on the label being trained.
        Map<String, Map<String, Stats>> probabilityStats = new HashMap<>();
        collectProbabilities(read(TRAIN_PROBABILITIES_RUN_1), "_1", probabilityStats);
        collectProbabilities(read(TRAIN_PROBABILITIES_RUN_2), "_2", probabilityStats);

        Dataset groundTruth = readGroundTruth();
        for (Label label : Label.values()) {
            String labelId = String.valueOf(label.getLabelId());
            List<Instance> instances = new ArrayList<>();
            try (CloseableIterator<Instance> iterator = groundTruth.iterator()) {
                while (iterator.hasNext()) {
                    FeatureVector row = iterator.next().getVector();
                    String businessId = row.get("business_id").toString();
                    Map<String, Stats> businessStats = probabilityStats.get(businessId);
                    if (businessStats == null) {
                        throw new IllegalStateException("No classified probabilities for business " + businessId);
                    }
                    Set<String> labelIds = new HashSet<>(Arrays.asList(row.get("labels").toString().split(" ")));
                    instances.add(toFeatures(businessStats).create(labelIds.contains(labelId)));
                }
            }
            DefaultDataset dataset = new DefaultDataset(instances);
            LibLinearModel model = new LibLinearLearner().train(dataset);
            ConfusionMatrix evaluation = new ConfusionMatrixEvaluator().evaluate(new LibLinearClassifier(), model, dataset);
            System.out.println(label);
            System.out.println(evaluation);
            FileHelper.serialize(model, "/Users/pk/Desktop/aggregation_multi_" + label.toString() + "_" + DateHelper.getCurrentDatetime() + ".ser.gz");
        }
    }

    /**
     * Adds every row's per-label probability to the statistics of the row's business.
     * Every row is counted, including rows whose probabilities repeat those of another row.
     *
     * @param dataset rows with a {@code businessId} column and one numeric column per {@link Label}
     * @param runSuffix suffix appended to the label name to tell the runs apart ({@code "_1"}, {@code "_2"})
     * @param target businessId -> (label + suffix) -> statistics; updated in place
     * @throws IOException if closing the dataset iterator fails
     */
    private static void collectProbabilities(Dataset dataset, String runSuffix,
            Map<String, Map<String, Stats>> target) throws IOException {
        try (CloseableIterator<Instance> iterator = dataset.iterator()) {
            while (iterator.hasNext()) {
                FeatureVector row = iterator.next().getVector();
                String businessId = row.get("businessId").toString();
                for (Label label : Label.values()) {
                    double probability = ((NumericValue) row.get(label.toString())).getDouble();
                    statsFor(target, businessId, label.toString() + runSuffix).add(Double.valueOf(probability));
                }
            }
        }
    }

    /** Returns the statistics for the given business and key, creating empty ones on first access. */
    private static Stats statsFor(Map<String, Map<String, Stats>> stats, String businessId, String key) {
        Map<String, Stats> businessStats = stats.get(businessId);
        if (businessStats == null) {
            businessStats = new HashMap<>();
            stats.put(businessId, businessStats);
        }
        Stats keyStats = businessStats.get(key);
        if (keyStats == null) {
            keyStats = FatStats.FACTORY.create();
            businessStats.put(key, keyStats);
        }
        return keyStats;
    }

    /**
     * Creates a builder holding the mean and max probability features of one business. This is the
     * single place that defines the feature names, shared by training and classification.
     */
    private static InstanceBuilder toFeatures(Map<String, Stats> businessStats) {
        InstanceBuilder builder = new InstanceBuilder();
        for (Map.Entry<String, Stats> entry : businessStats.entrySet()) {
            String prefix = entry.getKey().toLowerCase(Locale.ROOT);
            builder.set(prefix + "_mean_probability", entry.getValue().getMean());
            builder.set(prefix + "_max_probability", entry.getValue().getMax());
        }
        return builder;
    }

    /** Creates a reader for the training ground truth (columns {@code business_id}, {@code labels}). */
    private static Dataset readGroundTruth() {
        CsvDatasetReaderConfig.Builder config = CsvDatasetReaderConfig.filePath(Config.getFilePath("dataset.yelp.restaurants.train.csv"));
        config.readClassFromLastColumn(false);
        config.setFieldSeparator(',');
        config.treatAsNullValue("");
        config.parser("business_id", ImmutableStringValue.PARSER);
        return config.create();
    }

    /** Creates a reader for a probability CSV, parsing {@code businessId} as a string value. */
    private static Dataset read(String str) {
        CsvDatasetReaderConfig.Builder config = CsvDatasetReaderConfig.filePath(new File(str));
        config.parser("businessId", ImmutableStringValue.PARSER);
        return config.create();
    }
}
