package ws.palladian.classification;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeMap;
import org.apache.commons.lang3.Validate;
import weka.classifiers.Classifier;
import weka.core.Attribute;
import weka.core.FastVector;
import weka.core.Instances;
import weka.core.SparseInstance;
import ws.palladian.core.AbstractLearner;
import ws.palladian.core.FeatureVector;
import ws.palladian.core.Instance;
import ws.palladian.core.dataset.Dataset;
import ws.palladian.core.value.NominalValue;
import ws.palladian.core.value.NullValue;
import ws.palladian.core.value.NumericValue;
import ws.palladian.core.value.Value;
import ws.palladian.helper.collection.CollectionHelper;
import ws.palladian.helper.collection.Vector;
import ws.palladian.helper.io.CloseableIterator;

/* loaded from: input_file:ws/palladian/classification/WekaLearner.class */
/**
 * Learner which converts a Palladian {@link Dataset} into Weka {@link Instances} and trains a wrapped
 * Weka {@link Classifier} on them.
 * <p>
 * Every feature becomes a Weka attribute: {@link NominalValue}s become nominal attributes, and
 * {@link NumericValue}s become numeric attributes. The instance category is stored in an additional
 * nominal attribute, which is appended last and used as the class attribute.
 */
public final class WekaLearner extends AbstractLearner<WekaModel> {

    /**
     * Placeholder label stored at value index 0 of the class attribute, ahead of all real categories.
     * NOTE(review): presumably this keeps real categories away from index 0, which sparse instances
     * treat as an implicit (omitted) value. Confirm against {@code WekaModel}.
     */
    static final String DUMMY_CLASS = "wekadummyclass";

    /** Name of the nominal Weka attribute which holds each instance's category. */
    private static final String TARGET_CLASS_ATTRIBUTE = "palladianWekaTargetClass";

    /** The wrapped Weka classifier; it is built on every call of {@link #m5train(Dataset)}. */
    private final Classifier classifier;

    /**
     * Create a new learner wrapping the given Weka classifier.
     *
     * @param classifier The Weka classifier to train, not <code>null</code>.
     * @throws NullPointerException In case the classifier is <code>null</code>.
     */
    public WekaLearner(Classifier classifier) {
        Validate.notNull(classifier, "classifier must not be null.");
        this.classifier = classifier;
    }

    /**
     * Convert the given dataset to Weka {@link Instances} and build the wrapped classifier on them.
     * <p>
     * NOTE(review): the decompiler renamed this method from {@code train(Dataset)}, because it merged
     * it with its bridge method. The name is kept so that existing callers stay unaffected.
     *
     * @param dataset The training data, not <code>null</code>.
     * @return A {@link WekaModel} holding the trained classifier and the generated Weka instances.
     * @throws NullPointerException In case the dataset is <code>null</code>.
     * @throws IllegalStateException In case Weka fails to build the classifier.
     */
    public WekaModel m5train(Dataset dataset) {
        Validate.notNull(dataset, "dataset must not be null");
        Instances instances = new Instances("dataset", new FastVector(), CollectionHelper.count(dataset.iterator()));

        // First pass: create the feature attributes and remember every instance's values and category;
        // the Weka instances can only be built once all attributes (including the class) exist.
        List<Map<Integer, Double>> featureSets = new ArrayList<>();
        Set<String> categories = new HashSet<>();
        List<String> instanceCategories = new ArrayList<>();
        for (Instance instance : dataset) {
            featureSets.add(createWekaFeatureSet(instance.getVector(), instances, dataset));
            categories.add(instance.getCategory());
            instanceCategories.add(instance.getCategory());
        }

        // The class attribute is appended as the last attribute; value index 0 is the dummy label.
        FastVector classValues = new FastVector();
        classValues.addElement(DUMMY_CLASS);
        for (String category : categories) {
            classValues.addElement(category);
        }
        instances.insertAttributeAt(new Attribute(TARGET_CLASS_ATTRIBUTE, classValues), instances.numAttributes());
        int numAttributes = instances.numAttributes();
        int classIndex = numAttributes - 1;
        // use the attribute actually held by the dataset (insertAttributeAt may store a copy)
        Attribute classAttribute = instances.attribute(classIndex);

        // Second pass: build one sparse Weka instance per training instance.
        for (int i = 0; i < featureSets.size(); i++) {
            Map<Integer, Double> featureSet = featureSets.get(i);
            int[] indices = new int[featureSet.size() + 1];
            double[] values = new double[featureSet.size() + 1];
            int position = 0;
            // featureSet is sorted by attribute index; SparseInstance requires ascending indices
            for (Map.Entry<Integer, Double> entry : featureSet.entrySet()) {
                indices[position] = entry.getKey();
                values[position] = entry.getValue();
                position++;
            }
            // the class attribute has the highest index, so appending it keeps the indices sorted
            indices[position] = classIndex;
            values[position] = classAttribute.indexOfValue(instanceCategories.get(i));
            // maxNumValues is the total number of attributes, not the number of stored values
            SparseInstance sparseInstance = new SparseInstance(1.0, values, indices, numAttributes);
            sparseInstance.setDataset(instances);
            instances.add(sparseInstance);
        }
        instances.compactify();
        // set by position; a lookup by name could hit a feature which happens to share the name
        instances.setClassIndex(classIndex);
        try {
            classifier.buildClassifier(instances);
            return new WekaModel(classifier, instances);
        } catch (Exception e) {
            throw new IllegalStateException("An exception occurred while building the classifier: " + e.getMessage(), e);
        }
    }

    /**
     * Convert a Palladian feature vector to a map from Weka attribute index to attribute value.
     * <p>
     * Attributes for features which are not yet present in {@code instances} are appended to it.
     * Nominal values are mapped to the index of the value within their nominal attribute, and numeric
     * values are used as they are. Values which are neither nominal nor numeric are skipped.
     *
     * @param featureVector The features of one instance.
     * @param instances The Weka dataset whose attributes are looked up and extended.
     * @param iterable All training instances; scanned to collect the possible values of a new nominal
     *            attribute.
     * @return The attribute values, iterated in ascending attribute index order.
     */
    private Map<Integer, Double> createWekaFeatureSet(FeatureVector featureVector, Instances instances,
            Iterable<? extends Instance> iterable) {
        Map<Integer, Double> featureSet = new TreeMap<>();
        for (Vector.VectorEntry<String, Value> entry : featureVector) {
            String name = entry.key();
            Value value = entry.value();
            if (value instanceof NominalValue) {
                Attribute attribute = instances.attribute(name);
                if (attribute == null) {
                    // a nominal Weka attribute needs its complete value set when it is created
                    instances.insertAttributeAt(new Attribute(name, getValues(name, iterable)), instances.numAttributes());
                    attribute = instances.attribute(name);
                }
                featureSet.put(attribute.index(), (double) attribute.indexOfValue(((NominalValue) value).getString()));
            } else if (value instanceof NumericValue) {
                Attribute attribute = instances.attribute(name);
                if (attribute == null) {
                    instances.insertAttributeAt(new Attribute(name), instances.numAttributes());
                    attribute = instances.attribute(name);
                }
                featureSet.put(attribute.index(), ((NumericValue) value).getDouble());
            }
        }
        return featureSet;
    }

    /**
     * Collect the distinct nominal values which the given feature takes over all instances.
     *
     * @param name The feature name.
     * @param dataset The instances to scan.
     * @return The distinct string values; instances where the feature is absent or
     *         {@link NullValue#NULL} are ignored.
     */
    private FastVector getValues(String name, Iterable<? extends Instance> dataset) {
        Set<String> distinctValues = new HashSet<>();
        for (Instance instance : dataset) {
            Value value = instance.getVector().get(name);
            if (value != null && value != NullValue.NULL) {
                distinctValues.add(((NominalValue) value).getString());
            }
        }
        FastVector result = new FastVector(distinctValues.size());
        for (String distinctValue : distinctValues) {
            result.addElement(distinctValue);
        }
        return result;
    }

    /**
     * @return The string representation of the wrapped Weka classifier.
     */
    @Override
    public String toString() {
        return classifier.toString();
    }
}
