package ws.palladian.extraction.location.scope;

import java.io.File;
import java.io.IOException;
import java.util.Collection;
import java.util.Collections;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Set;
import org.apache.commons.lang3.Validate;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import ws.palladian.classification.dt.QuickDtClassifier;
import ws.palladian.classification.dt.QuickDtLearner;
import ws.palladian.classification.dt.QuickDtModel;
import ws.palladian.classification.featureselection.FeatureSelector;
import ws.palladian.classification.utils.CsvDatasetReader;
import ws.palladian.core.Classifier;
import ws.palladian.core.Instance;
import ws.palladian.core.InstanceBuilder;
import ws.palladian.core.Learner;
import ws.palladian.core.Model;
import ws.palladian.extraction.location.Location;
import ws.palladian.extraction.location.LocationAnnotation;
import ws.palladian.extraction.location.LocationExtractor;
import ws.palladian.extraction.location.LocationExtractorUtils;
import ws.palladian.extraction.location.LocationFilters;
import ws.palladian.extraction.location.LocationSet;
import ws.palladian.extraction.location.PalladianLocationExtractor;
import ws.palladian.extraction.location.disambiguation.ClassifiableLocation;
import ws.palladian.extraction.location.disambiguation.ConfigurableFeatureExtractor;
import ws.palladian.extraction.location.disambiguation.FeatureBasedDisambiguation;
import ws.palladian.extraction.location.evaluation.LocationDocument;
import ws.palladian.extraction.location.evaluation.TudLoc2013DatasetIterable;
import ws.palladian.extraction.location.persistence.LocationDatabase;
import ws.palladian.helper.NoProgress;
import ws.palladian.helper.ProgressReporter;
import ws.palladian.helper.StopWatch;
import ws.palladian.helper.collection.CollectionHelper;
import ws.palladian.helper.geo.GeoCoordinate;
import ws.palladian.helper.io.FileHelper;
import ws.palladian.helper.math.Stats;
import ws.palladian.persistence.DatabaseManagerFactory;

/**
 * Scope detector which determines the main location ("scope") of a text by ranking the extracted
 * candidate locations with a trained {@link QuickDtModel}.
 *
 * <p>For every distinct location extracted from the text, a feature vector is built. The features are:
 * <ul>
 * <li>distances to the midpoint and center of all coordinate-bearing locations;</li>
 * <li>occurrence frequency;</li>
 * <li>descendant/ancestor percentages;</li>
 * <li>hierarchy depth, population, type and disambiguation trust;</li>
 * <li>relative text offsets;</li>
 * <li>distance statistics to the other locations.</li>
 * </ul>
 * The candidate with the highest probability for class {@code "true"} is returned as the scope.
 *
 * <p>NOTE(review): the fraction-valued features (occurrence frequency, descendant/ancestor
 * percentage, offsets) are computed with floating-point division. Models trained by a version of
 * this class which truncated these values with integer division should be re-trained.
 */
public final class FeatureBasedScopeDetector extends AbstractRankingScopeDetector {

    private static final Logger LOGGER = LoggerFactory.getLogger(FeatureBasedScopeDetector.class);

    /**
     * Maximum distance between a document's annotated main location and the closest candidate for
     * that candidate to be used as positive training example. Unit is whatever
     * {@code GeoCoordinate#distance} returns (presumably kilometers — TODO confirm).
     */
    private static final int POSITIVE_DISTANCE_THRESHOLD = 50;

    /** Name used as prefix in {@link #toString()}. */
    private static final String NAME = "FeatureBased";

    /** Number of trees of the random forest trained in {@link #train(Iterable, LocationExtractor)}. */
    private static final int NUM_TREES = 100;

    /** Class label of positive (scope) candidates. */
    private static final String POSITIVE_CLASS = "true";

    private final QuickDtModel scopeModel;

    private final QuickDtClassifier classifier;

    /**
     * Create a new scope detector.
     *
     * @param locationExtractor The extractor used to obtain location annotations from text, passed to
     *            the superclass.
     * @param quickDtModel The trained scope model, not <code>null</code>.
     * @throws NullPointerException if {@code quickDtModel} is <code>null</code> (thrown by
     *             {@code Validate.notNull}).
     */
    public FeatureBasedScopeDetector(LocationExtractor locationExtractor, QuickDtModel quickDtModel) {
        super(locationExtractor);
        Validate.notNull(quickDtModel, "scopeModel must not be null", new Object[0]);
        this.scopeModel = quickDtModel;
        this.classifier = new QuickDtClassifier();
    }

    /**
     * Determine the scope among the given annotations.
     *
     * @param collection The location annotations of a text, not <code>null</code>.
     * @return The candidate location with the highest probability for the positive class, or
     *         <code>null</code> if there are no annotations or none of them has a coordinate.
     */
    @Override
    public Location getScope(Collection<LocationAnnotation> collection) {
        Validate.notNull(collection, "annotations must not be null", new Object[0]);
        if (collection.isEmpty()) {
            return null;
        }
        double bestProbability = Double.NEGATIVE_INFINITY;
        Location bestLocation = null;
        for (ClassifiableLocation candidate : extractFeatures(collection)) {
            double probability = classifier.classify(candidate.getFeatureVector(), scopeModel)
                    .getProbability(POSITIVE_CLASS);
            LOGGER.trace("{} : {}", candidate.getLocation().getPrimaryName(), probability);
            // the null check guarantees a result for the first candidate, even if its probability is NaN
            if (bestLocation == null || probability > bestProbability) {
                bestProbability = probability;
                bestLocation = candidate.getLocation();
            }
        }
        return bestLocation;
    }

    /**
     * Build one feature vector for each distinct location in the given annotations.
     *
     * @param collection The location annotations of a text.
     * @return The candidates with their feature vectors. The set is empty if no annotated location has
     *         a coordinate.
     */
    private static Set<ClassifiableLocation> extractFeatures(Collection<LocationAnnotation> collection) {
        // keep the list: it contains duplicates, which are needed for the occurrence frequency
        List<Location> locations = CollectionHelper.convertList(collection,
                LocationExtractorUtils.ANNOTATION_LOCATION_FUNCTION);
        LocationSet locationSet = new LocationSet(locations);
        LocationSet coordinateLocations = locationSet.where(LocationFilters.coordinate());
        if (coordinateLocations.isEmpty()) {
            return Collections.emptySet();
        }
        GeoCoordinate midpoint = coordinateLocations.midpoint();
        GeoCoordinate center = coordinateLocations.center();

        // largest start offset of all annotations; starts at 1 to avoid division by zero
        int maxOffset = 1;
        for (LocationAnnotation annotation : collection) {
            maxOffset = Math.max(maxOffset, annotation.getStartPosition());
        }
        // double, so that all fraction features below use floating-point division
        double numAnnotations = collection.size();

        Set<ClassifiableLocation> candidates = new HashSet<>();
        for (Location location : locationSet) {
            GeoCoordinate coordinate = (GeoCoordinate) CollectionHelper
                    .coalesce(new GeoCoordinate[] {location.getCoordinate(), GeoCoordinate.NULL});

            // aggregate trust and first/last offset over all annotations of this location
            double maxTrust = 0.0d;
            int firstOffset = Integer.MAX_VALUE;
            int lastOffset = Integer.MIN_VALUE;
            for (LocationAnnotation annotation : collection) {
                if (annotation.getLocation().equals(location)) {
                    maxTrust = Math.max(maxTrust, annotation.getTrust());
                    firstOffset = Math.min(firstOffset, annotation.getStartPosition());
                    lastOffset = Math.max(lastOffset, annotation.getStartPosition());
                }
            }

            Stats distanceStats = coordinateLocations.distanceStats(location);
            int numDescendants = locationSet.where(LocationFilters.descendantOf(location)).size();
            int numAncestors = locationSet.where(LocationFilters.ancestorOf(location)).size();
            long population = ((Long) CollectionHelper
                    .coalesce(new Long[] {location.getPopulation(), 0L})).longValue();

            InstanceBuilder instanceBuilder = new InstanceBuilder();
            instanceBuilder.set("midpointDistance", midpoint.distance(coordinate));
            instanceBuilder.set("centerpointDistance", center.distance(coordinate));
            instanceBuilder.set("occurrenceFrequency", Collections.frequency(locations, location) / numAnnotations);
            instanceBuilder.set("descendantPercentage", numDescendants / numAnnotations);
            instanceBuilder.set("ancestorPercentage", numAncestors / numAnnotations);
            instanceBuilder.set("hierarchyDepth", location.getAncestorIds().size());
            instanceBuilder.set("population", population);
            instanceBuilder.set("locationType", location.getType().toString());
            instanceBuilder.set("disambiguationTrust", maxTrust);
            instanceBuilder.set("offsetFirst", (double) firstOffset / maxOffset);
            instanceBuilder.set("offsetLast", (double) lastOffset / maxOffset);
            instanceBuilder.set("offsetSpread", (double) (lastOffset - firstOffset) / maxOffset);
            instanceBuilder.set("minDistanceToOthers", zeroIfNaN(distanceStats.getMin()));
            instanceBuilder.set("maxDistanceToOthers", zeroIfNaN(distanceStats.getMax()));
            instanceBuilder.set("meanDistanceToOthers", zeroIfNaN(distanceStats.getMean()));
            instanceBuilder.set("medianDistanceToOthers", zeroIfNaN(distanceStats.getMedian()));
            candidates.add(new ClassifiableLocation(location, instanceBuilder.create()));
        }
        return candidates;
    }

    /** Map NaN to {@code 0}; the distance statistics may be NaN (e.g. when there are no other values). */
    private static double zeroIfNaN(double value) {
        return Double.isNaN(value) ? 0.0d : value;
    }

    @Override
    public String toString() {
        return NAME + ":" + classifier.getClass().getSimpleName();
    }

    /**
     * Train a random forest scope model from annotated documents.
     *
     * @param iterable The training documents, not <code>null</code>.
     * @param locationExtractor The extractor used to obtain candidate locations, not <code>null</code>.
     * @return The trained model.
     */
    public static QuickDtModel train(Iterable<LocationDocument> iterable, LocationExtractor locationExtractor) {
        Validate.notNull(iterable, "documentIterator must not be null", new Object[0]);
        Validate.notNull(locationExtractor, "extractor must not be null", new Object[0]);
        Collection<Instance> instances = createInstances(iterable, locationExtractor);
        StopWatch stopWatch = new StopWatch();
        QuickDtModel model = QuickDtLearner.randomForest(NUM_TREES).train(instances);
        LOGGER.info("Trained model in {}", stopWatch.getElapsedTimeString());
        return model;
    }

    /**
     * Create training instances from annotated documents.
     *
     * <p>Documents without annotations, without a main location, or whose main location has no
     * coordinate are skipped. The candidate closest to the main location is labeled positive, but only
     * if its distance does not exceed {@link #POSITIVE_DISTANCE_THRESHOLD}; all other candidates are
     * labeled negative. Candidates without a coordinate are measured against {@code GeoCoordinate.NULL}.
     *
     * @param iterable The training documents, not <code>null</code>.
     * @param locationExtractor The extractor used to obtain candidate locations, not <code>null</code>.
     * @return The instances, collected in a set (equal instances per {@code Instance#equals} collapse).
     */
    public static Collection<Instance> createInstances(Iterable<LocationDocument> iterable,
            LocationExtractor locationExtractor) {
        Validate.notNull(iterable, "documents must not be null", new Object[0]);
        Validate.notNull(locationExtractor, "extractor must not be null", new Object[0]);
        Set<Instance> instances = new HashSet<>();
        for (LocationDocument locationDocument : iterable) {
            List<LocationAnnotation> annotations = locationExtractor.getAnnotations(locationDocument.getText());
            Location mainLocation = locationDocument.getMainLocation();
            if (annotations.isEmpty() || mainLocation == null || mainLocation.getCoordinate() == null) {
                continue;
            }
            GeoCoordinate mainCoordinate = mainLocation.getCoordinate();
            Set<ClassifiableLocation> candidates = extractFeatures(annotations);
            if (candidates.isEmpty()) {
                LOGGER.warn("No candidate locations with coordinates, skipping document");
                continue;
            }

            ClassifiableLocation positive = null;
            double minDistance = Double.MAX_VALUE;
            for (ClassifiableLocation candidate : candidates) {
                GeoCoordinate candidateCoordinate = (GeoCoordinate) CollectionHelper.coalesce(
                        new GeoCoordinate[] {candidate.getLocation().getCoordinate(), GeoCoordinate.NULL});
                double distance = mainCoordinate.distance(candidateCoordinate);
                if (distance < minDistance) {
                    minDistance = distance;
                    positive = candidate;
                }
            }
            if (minDistance > POSITIVE_DISTANCE_THRESHOLD) {
                // closest candidate is too far away; only negative examples from this document
                positive = null;
                LOGGER.warn("Could not determine positive candidate, distance to closest is {}", minDistance);
            } else {
                LOGGER.trace("Distance between actual and training candidate is {}", minDistance);
            }

            for (ClassifiableLocation candidate : candidates) {
                instances.add(new InstanceBuilder().add(candidate.getFeatureVector()).create(candidate == positive));
            }
        }
        return instances;
    }

    /**
     * Rank the features of a CSV training/validation dataset and print the ranking.
     *
     * @param file CSV file with the training data.
     * @param file2 CSV file with the validation data.
     * @param learner The learner used by the feature selector.
     * @param classifier The classifier used by the feature selector.
     */
    public static <M extends Model> void runFeatureElimination(File file, File file2, Learner<M> learner,
            Classifier<M> classifier) {
        FeatureSelector featureSelector = new FeatureSelector(learner, classifier,
                new FeatureSelector.FMeasureScorer(POSITIVE_CLASS));
        Iterable<? extends Instance> trainData = new CsvDatasetReader(file).readAll();
        Iterable<? extends Instance> validationData = new CsvDatasetReader(file2).readAll();
        CollectionHelper.print(featureSelector
                .rankFeatures(trainData, validationData, (ProgressReporter) NoProgress.INSTANCE).getAll());
    }

    /** Development entry point: train a scope model on the TUD-Loc-2013 training set and serialize it. */
    public static void main(String[] strArr) throws IOException {
        // NOTE(review): hard-coded, machine-specific paths
        LocationDatabase database = (LocationDatabase) DatabaseManagerFactory.create(LocationDatabase.class,
                "locations");
        QuickDtModel disambiguationModel = (QuickDtModel) FileHelper.deserialize(
                "/Users/pk/Dropbox/Uni/Dissertation_LocationLab/Models/location_disambiguation_all_train_1377442726898.model");
        FeatureBasedDisambiguation disambiguation = new FeatureBasedDisambiguation(disambiguationModel, 0.0d,
                new ConfigurableFeatureExtractor());
        PalladianLocationExtractor extractor = new PalladianLocationExtractor(database, disambiguation);
        Iterable<LocationDocument> trainingDocuments = new TudLoc2013DatasetIterable(
                new File("/Users/pk/Dropbox/Uni/Datasets/TUD-Loc-2013/1-training"));
        QuickDtModel scopeModel = train(trainingDocuments, extractor);
        FileHelper.serialize(scopeModel, new File("scopeDetection_tud-loc_quickDt.model").getPath());
    }
}
