package ws.palladian.classification.utils;

import java.io.EOFException;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.Predicate;
import org.apache.commons.lang3.Validate;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import ws.palladian.core.AppendedVector;
import ws.palladian.core.FeatureVector;
import ws.palladian.core.FilteredVector;
import ws.palladian.core.InstanceBuilder;
import ws.palladian.core.dataset.AbstractDatasetFeatureVectorTransformer;
import ws.palladian.core.dataset.Dataset;
import ws.palladian.core.dataset.FeatureInformation;
import ws.palladian.core.dataset.FeatureInformationBuilder;
import ws.palladian.core.dataset.statistics.DatasetStatistics;
import ws.palladian.core.dataset.statistics.NominalValueStatistics;
import ws.palladian.core.featurevector.FlyweightVectorBuilder;
import ws.palladian.core.featurevector.FlyweightVectorSchema;
import ws.palladian.core.value.ImmutableIntegerValue;
import ws.palladian.core.value.NominalValue;
import ws.palladian.core.value.Value;
import ws.palladian.helper.StopWatch;
import ws.palladian.helper.collection.Vector;
import ws.palladian.helper.functional.Predicates;

/* loaded from: input_file:ws/palladian/classification/utils/DummyVariableCreator.class */
/**
 * Transforms nominal features into numeric dummy (indicator) features.
 *
 * <p>For every feature of type {@link NominalValue} in the dataset given to the constructor, a
 * {@link Mapper} is built from the values observed in that dataset. Each value {@code v} of a
 * nominal feature {@code f} becomes a dummy feature named {@code f + ":" + v}, which is {@code 1}
 * when the feature has that value. A nominal feature with at most two distinct values (counting
 * {@code null}) is mapped to a single dummy feature only. If its values include {@code "true"} and
 * {@code "false"}, that single dummy is the one for {@code "true"}.
 *
 * <p>All instance state is {@code transient}. {@code writeObject} writes a compact custom form (the
 * value-to-dummy-name mappings plus the two flags). {@code readObject} rebuilds the mappers from it.
 */
public class DummyVariableCreator extends AbstractDatasetFeatureVectorTransformer implements Serializable {
    private static final long serialVersionUID = 1;
    private static final Logger LOGGER = LoggerFactory.getLogger(DummyVariableCreator.class);

    /** Nominal feature name -> mapper for that feature; rebuilt on deserialization. */
    private transient Map<String, Mapper> mappers;
    /** Whether the original nominal features are kept alongside the created dummy features. */
    private transient boolean keepOriginalFeature;
    /** Whether dummy vectors are dense (explicit zeros) or sparse (only the {@code 1} entry). */
    private transient boolean dense;

    /**
     * Maps the values of one nominal feature to pre-built dummy {@link FeatureVector}s.
     *
     * <p>NOTE: this class was {@code private} in the original source. The decompiler widened its
     * visibility, so treat it as an implementation detail.
     */
    public static final class Mapper {
        /** Nominal value (as string) -> dummy vector appended for that value. */
        final Map<String, FeatureVector> mapping;
        /** Dummy vector appended for values not in {@link #mapping} (all zeros, or empty if sparse). */
        final FeatureVector missing;
        /** Feature information of the created dummy features (all {@link ImmutableIntegerValue}). */
        final FeatureInformation featureInformation;
        /** Nominal value -> dummy feature name; kept so the mapper can be serialized directly. */
        final Map<String, String> dummyNames;

        Mapper(String featureName, NominalValueStatistics statistics, boolean dense) {
            this(prepare(featureName, statistics), dense);
        }

        /**
         * @param dummyNames nominal value -> name of the dummy feature representing it.
         * @param dense {@code true} to create flyweight vectors containing every dummy feature
         *            (explicit zeros); {@code false} to create vectors holding only the {@code 1} entry.
         */
        Mapper(Map<String, String> dummyNames, boolean dense) {
            // Keep a private snapshot for serialization. All vectors below are built from the
            // caller's map, exactly as before, so feature order is unaffected.
            this.dummyNames = Collections.unmodifiableMap(new HashMap<>(dummyNames));
            Collection<String> names = dummyNames.values();
            FlyweightVectorSchema schema = dense ? new FlyweightVectorSchema(names.toArray(new String[0])) : null;
            this.mapping = new HashMap<>();
            for (Map.Entry<String, String> entry : dummyNames.entrySet()) {
                this.mapping.put(entry.getKey(), createDummyVector(schema, entry.getValue(), names, dense));
            }
            this.missing = createDummyVector(schema, null, names, dense);
            this.featureInformation = new FeatureInformationBuilder().set(names, ImmutableIntegerValue.class).m85create();
        }

        /** Builds the value -> dummy-name map for one nominal feature from its statistics. */
        private static Map<String, String> prepare(String featureName, NominalValueStatistics statistics) {
            Set<String> values = statistics.getValues();
            if (statistics.getNumUniqueValuesIncludingNull() <= 2) {
                // At most two distinct values: a single dummy feature carries all information.
                if (values.containsAll(Arrays.asList("true", "false"))) {
                    values = Collections.singleton("true");
                } else if (!values.isEmpty()) {
                    // NOTE(review): takes whichever value the set yields first; this is only
                    // deterministic if getValues() returns an ordered set -- confirm.
                    values = Collections.singleton(values.iterator().next());
                }
            }
            Map<String, String> dummyNames = new HashMap<>();
            for (String value : values) {
                dummyNames.put(value, featureName + ":" + value);
            }
            return dummyNames;
        }

        /**
         * Creates the dummy vector for {@code activeName}.
         *
         * @param schema flyweight schema over all dummy names; only used (and non-null) when dense.
         * @param activeName dummy feature set to {@code 1}, or {@code null} for the "missing" vector.
         * @param allNames all dummy feature names of this mapper.
         * @param dense whether to emit explicit zeros for all other dummy features.
         */
        private static FeatureVector createDummyVector(FlyweightVectorSchema schema, String activeName,
                Collection<String> allNames, boolean dense) {
            if (!dense) {
                InstanceBuilder instanceBuilder = new InstanceBuilder();
                if (activeName != null) {
                    instanceBuilder.set(activeName, 1);
                }
                return instanceBuilder.create();
            }
            FlyweightVectorBuilder builder = schema.builder();
            for (String name : allNames) {
                builder.set(name, ImmutableIntegerValue.valueOf(name.equals(activeName) ? 1 : 0));
            }
            return builder.m99create();
        }

        public FeatureInformation getFeatureInformation() {
            return this.featureInformation;
        }

        /**
         * Returns the dummy vector for {@code value.toString()}, or the "missing" vector if that
         * value was not observed when the mapper was built.
         */
        public FeatureVector getAppendedFeatureVector(Value value) {
            FeatureVector featureVector = this.mapping.get(value.toString());
            return featureVector != null ? featureVector : this.missing;
        }

        @Override
        public String toString() {
            return this.mapping.toString();
        }
    }

    /**
     * Equivalent to {@code DummyVariableCreator(dataset, false, true)}.
     *
     * @deprecated use {@link #DummyVariableCreator(Dataset, boolean, boolean)}.
     */
    @Deprecated
    public DummyVariableCreator(Dataset dataset) {
        this(dataset, false);
    }

    /**
     * Equivalent to {@code DummyVariableCreator(dataset, keepOriginalFeature, true)}.
     *
     * @deprecated use {@link #DummyVariableCreator(Dataset, boolean, boolean)}.
     */
    @Deprecated
    public DummyVariableCreator(Dataset dataset, boolean keepOriginalFeature) {
        this(dataset, keepOriginalFeature, true);
    }

    /**
     * Creates dummy-variable mappers for every nominal feature of {@code dataset}.
     *
     * @param dataset dataset whose nominal feature values define the dummy features; must not be null.
     * @param keepOriginalFeature whether transformed vectors keep the original nominal features.
     * @param dense whether dummy vectors contain explicit zeros for inactive dummy features.
     */
    public DummyVariableCreator(Dataset dataset, boolean keepOriginalFeature, boolean dense) {
        Validate.notNull(dataset, "dataset must not be null", new Object[0]);
        this.keepOriginalFeature = keepOriginalFeature;
        this.dense = dense;
        this.mappers = buildMappers(dataset, dense);
        if (getNominalFeatureCount() > 0) {
            LOGGER.info("# nominal features which will be mapped: {}", Integer.valueOf(getNominalFeatureCount()));
            LOGGER.info("# created features: {}", Integer.valueOf(getCreatedNumericFeatures().size()));
        }
    }

    /** Builds one {@link Mapper} per nominal feature, using value statistics computed over the dataset. */
    private static Map<String, Mapper> buildMappers(Dataset dataset, boolean dense) {
        Map<String, Mapper> result = new HashMap<>();
        Set<String> nominalFeatures = dataset.getFeatureInformation().getFeatureNamesOfType(NominalValue.class);
        if (nominalFeatures.isEmpty()) {
            LOGGER.debug("No nominal features in dataset.");
            return result;
        }
        LOGGER.debug("Determine domain for dataset ...");
        StopWatch stopWatch = new StopWatch();
        // Restrict statistics to the nominal features; the others are not needed here.
        DatasetStatistics statistics = new DatasetStatistics(dataset.filterFeatures(Predicates.equal(nominalFeatures)));
        for (String featureName : nominalFeatures) {
            NominalValueStatistics valueStatistics = (NominalValueStatistics) statistics.getValueStatistics(featureName);
            result.put(featureName, new Mapper(featureName, valueStatistics, dense));
        }
        LOGGER.debug("... finished determining domain in {}", stopWatch);
        return result;
    }

    /**
     * Derives the output feature information. Features without a mapper pass through. For a mapped
     * nominal feature, the original entry is kept only if {@code keepOriginalFeature} is set, and
     * its dummy features are added.
     */
    @Override
    public FeatureInformation getFeatureInformation(FeatureInformation featureInformation) {
        FeatureInformationBuilder builder = new FeatureInformationBuilder();
        for (FeatureInformation.FeatureInformationEntry entry : featureInformation) {
            Mapper mapper = this.mappers.get(entry.getName());
            if (mapper == null || this.keepOriginalFeature) {
                builder.set(entry);
            }
            if (mapper != null) {
                builder.add(mapper.getFeatureInformation());
            }
        }
        return builder.m85create();
    }

    @Override
    public FeatureVector apply(FeatureVector featureVector) {
        return convert(featureVector);
    }

    /**
     * Converts the nominal features of {@code featureVector} into dummy features.
     *
     * <p>Returns {@code featureVector} itself when there are no mappers. Otherwise returns an
     * {@link AppendedVector} made of the input vector, followed by one dummy vector per mapped
     * nominal feature encountered while iterating the input. If {@code keepOriginalFeature} is not
     * set, the mapped nominal features are filtered out of the input vector.
     * NOTE(review): a mapped feature that does not appear in the input's iteration gets no dummy
     * vector appended.
     *
     * @param featureVector the vector to convert; must not be null.
     */
    public FeatureVector convert(FeatureVector featureVector) {
        Validate.notNull(featureVector, "featureVector must not be null", new Object[0]);
        if (this.mappers.isEmpty()) {
            return featureVector;
        }
        List<FeatureVector> parts = new ArrayList<>();
        if (this.keepOriginalFeature) {
            parts.add(featureVector);
        } else {
            parts.add(new FilteredVector(featureVector, (Predicate<? super String>) Predicates.not(Predicates.equal(this.mappers.keySet()))));
        }
        Iterator<?> entries = featureVector.iterator();
        while (entries.hasNext()) {
            Vector.VectorEntry<?, ?> entry = (Vector.VectorEntry<?, ?>) entries.next();
            Mapper mapper = this.mappers.get((String) entry.key());
            if (mapper != null) {
                parts.add(mapper.getAppendedFeatureVector((Value) entry.value()));
            }
        }
        return new AppendedVector(parts);
    }

    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder();
        sb.append("NumericFeatureMapper\n");
        for (Map.Entry<String, Mapper> entry : this.mappers.entrySet()) {
            sb.append(entry.getKey()).append(":").append(entry.getValue()).append('\n');
        }
        sb.append('\n');
        sb.append("# nominal features: ").append(getNominalFeatureCount()).append('\n');
        sb.append("# created numeric features: ").append(getCreatedNumericFeatures().size());
        return sb.toString();
    }

    /** @return the number of nominal features which are mapped. */
    final int getNominalFeatureCount() {
        return this.mappers.size();
    }

    /** @return the names of all dummy features created by the mappers. */
    Set<String> getCreatedNumericFeatures() {
        Set<String> created = new HashSet<>();
        for (Mapper mapper : this.mappers.values()) {
            created.addAll(mapper.getFeatureInformation().getFeatureNames());
        }
        return created;
    }

    /**
     * Writes the compact serialized form: the mapper count, then per mapper the feature name, the
     * number of values and each (value, dummy name) pair, then {@code keepOriginalFeature} and
     * {@code dense}.
     */
    private void writeObject(ObjectOutputStream out) throws IOException {
        out.writeInt(getNominalFeatureCount());
        for (Map.Entry<String, Mapper> mapperEntry : this.mappers.entrySet()) {
            out.writeObject(mapperEntry.getKey());
            // Write the stored names directly instead of scanning each dummy vector for its
            // 1-entry. The previous scan looped forever if no such entry was found.
            Map<String, String> dummyNames = mapperEntry.getValue().dummyNames;
            out.writeInt(dummyNames.size());
            for (Map.Entry<String, String> nameEntry : dummyNames.entrySet()) {
                out.writeObject(nameEntry.getKey());
                out.writeObject(nameEntry.getValue());
            }
        }
        out.writeBoolean(this.keepOriginalFeature);
        out.writeBoolean(this.dense);
    }

    /**
     * Reads the form written by {@code writeObject} and rebuilds the mappers.
     *
     * <p>SECURITY: Java native deserialization; do not feed untrusted streams.
     */
    private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
        Map<String, Map<String, String>> dummyNamesByFeature = new HashMap<>();
        int numMappers = in.readInt();
        for (int i = 0; i < numMappers; i++) {
            String featureName = (String) in.readObject();
            int numValues = in.readInt();
            Map<String, String> dummyNames = new HashMap<>();
            for (int j = 0; j < numValues; j++) {
                String value = (String) in.readObject();
                String dummyName = (String) in.readObject();
                dummyNames.put(value, dummyName);
            }
            dummyNamesByFeature.put(featureName, dummyNames);
        }
        try {
            this.keepOriginalFeature = in.readBoolean();
            this.dense = in.readBoolean();
        } catch (EOFException ignored) {
            // Streams written before these flags existed end here. Any flag not read keeps the JVM
            // default (false), because constructors and initializers are not run for this class.
            LOGGER.debug("Serialized form without flags; using keepOriginalFeature={}, dense={}",
                    this.keepOriginalFeature, this.dense);
        }
        this.mappers = new HashMap<>();
        for (Map.Entry<String, Map<String, String>> entry : dummyNamesByFeature.entrySet()) {
            this.mappers.put(entry.getKey(), new Mapper(entry.getValue(), this.dense));
        }
    }
}
