package com.marklogic.spark.reader;

import com.marklogic.spark.Options;
import com.marklogic.spark.reader.filter.FilterFactory;
import com.marklogic.spark.reader.filter.OpticFilter;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.stream.Stream;
import org.apache.spark.sql.connector.expressions.SortOrder;
import org.apache.spark.sql.connector.expressions.aggregate.AggregateFunc;
import org.apache.spark.sql.connector.expressions.aggregate.Aggregation;
import org.apache.spark.sql.connector.expressions.aggregate.Avg;
import org.apache.spark.sql.connector.expressions.aggregate.Count;
import org.apache.spark.sql.connector.expressions.aggregate.CountStar;
import org.apache.spark.sql.connector.expressions.aggregate.Max;
import org.apache.spark.sql.connector.expressions.aggregate.Min;
import org.apache.spark.sql.connector.expressions.aggregate.Sum;
import org.apache.spark.sql.connector.read.Scan;
import org.apache.spark.sql.connector.read.ScanBuilder;
import org.apache.spark.sql.connector.read.SupportsPushDownAggregates;
import org.apache.spark.sql.connector.read.SupportsPushDownFilters;
import org.apache.spark.sql.connector.read.SupportsPushDownLimit;
import org.apache.spark.sql.connector.read.SupportsPushDownRequiredColumns;
import org.apache.spark.sql.connector.read.SupportsPushDownTopN;
import org.apache.spark.sql.sources.Filter;
import org.apache.spark.sql.types.StructType;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:com/marklogic/spark/reader/MarkLogicScanBuilder.class */
public class MarkLogicScanBuilder implements ScanBuilder, SupportsPushDownFilters, SupportsPushDownLimit, SupportsPushDownTopN, SupportsPushDownAggregates, SupportsPushDownRequiredColumns {
    private final ReadContext readContext;
    private List<Filter> pushedFilters;
    private static final Logger logger = LoggerFactory.getLogger((Class<?>) MarkLogicScanBuilder.class);
    private static final Set<Class<? extends AggregateFunc>> SUPPORTED_AGGREGATE_FUNCTIONS = new HashSet() { // from class: com.marklogic.spark.reader.MarkLogicScanBuilder.1
        {
            add(Avg.class);
            add(Count.class);
            add(CountStar.class);
            add(Max.class);
            add(Min.class);
            add(Sum.class);
        }
    };

    public MarkLogicScanBuilder(ReadContext readContext) {
        this.readContext = readContext;
    }

    public Scan build() {
        if (logger.isDebugEnabled()) {
            logger.debug("Creating new scan");
        }
        return new MarkLogicScan(this.readContext);
    }

    public Filter[] pushFilters(Filter[] filterArr) {
        this.pushedFilters = new ArrayList();
        if (this.readContext.planAnalysisFoundNoRows()) {
            return filterArr;
        }
        ArrayList arrayList = new ArrayList();
        ArrayList arrayList2 = new ArrayList();
        if (logger.isDebugEnabled()) {
            logger.debug("Filter count: {}", Integer.valueOf(filterArr.length));
        }
        for (Filter filter : filterArr) {
            OpticFilter planFilter = FilterFactory.toPlanFilter(filter);
            if (planFilter != null) {
                if (logger.isInfoEnabled()) {
                    logger.info("Pushing down filter: {}", filter);
                }
                arrayList2.add(planFilter);
                this.pushedFilters.add(filter);
            } else {
                if (logger.isDebugEnabled()) {
                    logger.debug("Unsupported filter, will be handled by Spark: {}", filter);
                }
                arrayList.add(filter);
            }
        }
        this.readContext.pushDownFiltersIntoOpticQuery(arrayList2);
        return (Filter[]) arrayList.toArray(new Filter[0]);
    }

    public Filter[] pushedFilters() {
        return (Filter[]) this.pushedFilters.toArray(new Filter[0]);
    }

    public boolean pushLimit(int i) {
        if (this.readContext.planAnalysisFoundNoRows()) {
            return false;
        }
        if (logger.isInfoEnabled()) {
            logger.info("Pushing down limit: {}", Integer.valueOf(i));
        }
        this.readContext.pushDownLimit(i);
        return true;
    }

    public boolean pushTopN(SortOrder[] sortOrderArr, int i) {
        if (this.readContext.planAnalysisFoundNoRows()) {
            return false;
        }
        if (logger.isInfoEnabled()) {
            logger.info("Pushing down topN: {}; limit: {}", Arrays.asList(sortOrderArr), Integer.valueOf(i));
        }
        this.readContext.pushDownTopN(sortOrderArr, i);
        return true;
    }

    public boolean isPartiallyPushed() {
        return this.readContext.getBucketCount() > 1;
    }

    public boolean supportCompletePushDown(Aggregation aggregation) {
        if (this.readContext.planAnalysisFoundNoRows() || pushDownAggregatesIsDisabled()) {
            return false;
        }
        if (hasUnsupportedAggregateFunction(aggregation)) {
            logger.info("Aggregation contains one or more unsupported functions, so not pushing aggregation to MarkLogic: {}", describeAggregation(aggregation));
            return false;
        }
        if (this.readContext.getBucketCount() <= 1) {
            return true;
        }
        logger.info("Multiple requests will be made to MarkLogic; aggregation will be applied by Spark as well: {}", describeAggregation(aggregation));
        return false;
    }

    public boolean pushAggregation(Aggregation aggregation) {
        if (this.readContext.planAnalysisFoundNoRows() || hasUnsupportedAggregateFunction(aggregation)) {
            return false;
        }
        if (pushDownAggregatesIsDisabled()) {
            logger.info("Push down of aggregates is disabled; Spark will handle all aggregations.");
            return false;
        }
        logger.info("Pushing down aggregation: {}", describeAggregation(aggregation));
        this.readContext.pushDownAggregation(aggregation);
        return true;
    }

    public void pruneColumns(StructType structType) {
        if (this.readContext.planAnalysisFoundNoRows()) {
            return;
        }
        if (structType.equals(this.readContext.getSchema())) {
            if (logger.isDebugEnabled()) {
                logger.debug("The schema to push down is equal to the existing schema, so not pushing it down.");
            }
        } else {
            if (logger.isDebugEnabled()) {
                logger.debug("Pushing down required schema: {}", structType.json());
            }
            this.readContext.pushDownRequiredSchema(structType);
        }
    }

    private boolean hasUnsupportedAggregateFunction(Aggregation aggregation) {
        return Stream.of((Object[]) aggregation.aggregateExpressions()).anyMatch(aggregateFunc -> {
            return !SUPPORTED_AGGREGATE_FUNCTIONS.contains(aggregateFunc.getClass());
        });
    }

    private String describeAggregation(Aggregation aggregation) {
        return String.format("groupBy: %s; aggregates: %s", Arrays.asList(aggregation.groupByExpressions()), Arrays.asList(aggregation.aggregateExpressions()));
    }

    private boolean pushDownAggregatesIsDisabled() {
        return "false".equalsIgnoreCase(this.readContext.getProperties().get(Options.READ_PUSH_DOWN_AGGREGATES));
    }
}
