package com.marklogic.spark.reader;

import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.node.ArrayNode;
import com.fasterxml.jackson.databind.node.ObjectNode;
import com.marklogic.client.DatabaseClient;
import com.marklogic.client.FailedRequestException;
import com.marklogic.client.expression.PlanBuilder;
import com.marklogic.client.impl.DatabaseClientImpl;
import com.marklogic.client.io.JacksonHandle;
import com.marklogic.client.io.StringHandle;
import com.marklogic.client.row.RawQueryDSLPlan;
import com.marklogic.client.row.RowManager;
import com.marklogic.spark.ContextSupport;
import com.marklogic.spark.Options;
import com.marklogic.spark.reader.PlanAnalysis;
import com.marklogic.spark.reader.filter.OpticFilter;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.connector.expressions.Expression;
import org.apache.spark.sql.connector.expressions.SortOrder;
import org.apache.spark.sql.connector.expressions.aggregate.AggregateFunc;
import org.apache.spark.sql.connector.expressions.aggregate.Aggregation;
import org.apache.spark.sql.connector.expressions.aggregate.Avg;
import org.apache.spark.sql.connector.expressions.aggregate.Count;
import org.apache.spark.sql.connector.expressions.aggregate.CountStar;
import org.apache.spark.sql.connector.expressions.aggregate.Max;
import org.apache.spark.sql.connector.expressions.aggregate.Min;
import org.apache.spark.sql.connector.expressions.aggregate.Sum;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:com/marklogic/spark/reader/ReadContext.class */
/**
 * Holds the state needed to read rows from MarkLogic via an Optic DSL query.
 *
 * <p>This state consists of:
 * <ul>
 *   <li>the analyzed plan (partitions and buckets);</li>
 *   <li>the current Spark schema;</li>
 *   <li>the server timestamp used for point-in-time reads;</li>
 *   <li>any filters pushed down by Spark.</li>
 * </ul>
 *
 * <p>Spark pushdowns (filters, limit, top-N, aggregation, required schema) are applied by inserting
 * operators into the bounded plan held by {@link PlanAnalysis}.
 */
public class ReadContext extends ContextSupport {

    static final long serialVersionUID = 1;

    private static final Logger logger = LoggerFactory.getLogger(ReadContext.class);

    // Batch size passed to plan analysis when Options.READ_BATCH_SIZE is not set.
    private static final long DEFAULT_BATCH_SIZE = 100000;

    // Fragment of the FailedRequestException message that plan analysis produces when no rows are found.
    // That case is logged and results in no partitions rather than a failure.
    private static final String NO_ROWS_ERROR_FRAGMENT = "$tableId as xs:string -- Invalid coercion: () as xs:string";

    // Null when plan analysis found no rows; see planAnalysisFoundNoRows().
    private PlanAnalysis planAnalysis;
    private StructType schema;
    private long serverTimestamp;
    private List<OpticFilter> opticFilters;

    /**
     * Validates the Optic query and analyzes it into partitions and buckets.
     *
     * <p>When rows were found, also captures the server timestamp at which every bucket will later be read.
     *
     * @param map        connector options; must contain a non-blank {@link Options#READ_OPTIC_QUERY}
     * @param structType the initial Spark schema of the rows to be read
     * @throws IllegalArgumentException if no Optic query is defined
     * @throws RuntimeException         if plan analysis fails for a reason other than finding no rows
     */
    public ReadContext(Map<String, String> map, StructType structType) {
        super(map);
        this.schema = structType;

        long partitionCount = getNumericOption(Options.READ_NUM_PARTITIONS,
            SparkSession.active().sparkContext().defaultMinPartitions(), 1L);
        long batchSize = getNumericOption(Options.READ_BATCH_SIZE, DEFAULT_BATCH_SIZE, 0L);

        String dslQuery = map.get(Options.READ_OPTIC_QUERY);
        if (dslQuery == null || dslQuery.trim().length() < 1) {
            throw new IllegalArgumentException(String.format("No Optic query found; must define %s", Options.READ_OPTIC_QUERY));
        }

        DatabaseClient client = connectToMarkLogic();
        // A single RowManager serves both the plan construction and the column-info request below.
        RowManager rowManager = client.newRowManager();
        RawQueryDSLPlan dslPlan = rowManager.newRawQueryDSLPlan(new StringHandle(dslQuery));
        try {
            this.planAnalysis = new PlanAnalyzer((DatabaseClientImpl) client)
                .analyzePlan(dslPlan.getHandle(), partitionCount, batchSize);
        } catch (FailedRequestException e) {
            handlePlanAnalysisError(dslQuery, e);
        }

        if (this.planAnalysis == null) {
            return;
        }
        if (logger.isInfoEnabled()) {
            logger.info("Partition count: {}; number of requests that will be made to MarkLogic: {}",
                this.planAnalysis.partitions.size(), this.planAnalysis.getAllBuckets().size());
        }
        // The timestamp of this request is reused for every bucket so that all reads see the same
        // point-in-time view of the database.
        StringHandle columnInfo = rowManager.columnInfo(dslPlan, new StringHandle());
        this.serverTimestamp = columnInfo.getServerTimestamp();
        if (logger.isDebugEnabled()) {
            logger.debug("Will use server timestamp: {}", this.serverTimestamp);
        }
    }

    /**
     * Treats the known "no rows" analysis error as an empty result.
     *
     * @throws RuntimeException wrapping {@code ex} for every other error
     *                          (including one with a null message)
     */
    private void handlePlanAnalysisError(String query, FailedRequestException ex) {
        String message = ex.getMessage();
        if (message != null && message.contains(NO_ROWS_ERROR_FRAGMENT)) {
            logger.info("No rows were found, so will not create any partitions.");
            return;
        }
        throw new RuntimeException(String.format("Unable to run Optic DSL query %s; cause: %s", query, message), ex);
    }

    /**
     * Reads the rows of one bucket at the server timestamp captured by the constructor.
     *
     * @param rowManager used to build and execute the bucket's plan
     * @param partition  the partition containing the bucket; used only for logging
     * @param bucket     supplies the lower/upper bound parameter values
     * @return an iterator over the "rows" array of the response, or an empty iterator when the response
     *         is null or has no "rows" field
     * @throws RuntimeException if the server timestamp is less than 1
     *                          (e.g. the constructor found no rows)
     */
    public Iterator<JsonNode> readRowsInBucket(RowManager rowManager, PlanAnalysis.Partition partition, PlanAnalysis.Bucket bucket) {
        if (logger.isDebugEnabled()) {
            logger.debug("Getting rows for partition {} and bucket {} at server timestamp {}", partition, bucket, this.serverTimestamp);
        }
        if (this.serverTimestamp < 1) {
            throw new RuntimeException(String.format("Unable to read rows; invalid server timestamp: %d", this.serverTimestamp));
        }
        PlanBuilder.Plan plan = buildPlanForBucket(rowManager, bucket);
        JacksonHandle handle = new JacksonHandle();
        handle.setPointInTimeQueryTimestamp(this.serverTimestamp);
        JsonNode result = rowManager.resultDoc(plan, handle).get();
        if (result == null || !result.has("rows")) {
            return Collections.emptyIterator();
        }
        return result.get("rows").iterator();
    }

    /**
     * Binds the bucket's bounds into the bounded plan, then lets each pushed-down filter bind its value.
     */
    private PlanBuilder.Plan buildPlanForBucket(RowManager rowManager, PlanAnalysis.Bucket bucket) {
        PlanBuilder.Plan plan = rowManager.newRawPlanDefinition(new JacksonHandle(this.planAnalysis.boundedPlan))
            .bindParam("ML_LOWER_BOUND", bucket.lowerBound)
            .bindParam("ML_UPPER_BOUND", bucket.upperBound);
        if (this.opticFilters != null) {
            for (OpticFilter filter : this.opticFilters) {
                plan = filter.bindFilterValue(plan);
            }
        }
        return plan;
    }

    /**
     * Pushes Spark filters down into the Optic plan.
     *
     * <p>The filters are remembered so their values can be bound per bucket. The operator built by
     * {@code PlanUtil.buildWhere} is added to the plan.
     */
    public void pushDownFiltersIntoOpticQuery(List<OpticFilter> list) {
        this.opticFilters = list;
        addOperatorToPlan(PlanUtil.buildWhere(list));
    }

    /**
     * Adds the operator built by {@code PlanUtil.buildLimit} to the plan.
     */
    public void pushDownLimit(int i) {
        addOperatorToPlan(PlanUtil.buildLimit(i));
    }

    /**
     * Adds an order-by operator for the given sort orders, followed by a limit operator.
     */
    public void pushDownTopN(SortOrder[] sortOrderArr, int i) {
        addOperatorToPlan(PlanUtil.buildOrderBy(sortOrderArr));
        pushDownLimit(i);
    }

    /**
     * Pushes a Spark aggregation into the plan and replaces the schema with the aggregation's output schema.
     *
     * <p>The output schema contains the group-by columns followed by one column per supported aggregate
     * function. Unsupported functions are logged and get no column. If the user did not set
     * {@link Options#READ_BATCH_SIZE}, each partition's buckets are merged so that each partition makes
     * a single request.
     *
     * @throws IllegalArgumentException if a referenced column is not in the current schema
     */
    public void pushDownAggregation(Aggregation aggregation) {
        List<String> groupByColumnNames = Stream.of(aggregation.groupByExpressions())
            .map(PlanUtil::expressionToColumnName)
            .collect(Collectors.toList());
        addOperatorToPlan(PlanUtil.buildGroupByAggregation(groupByColumnNames, aggregation));

        // Column lookups below read the current (pre-aggregation) schema, so it is replaced only at the end.
        StructType newSchema = buildSchemaWithColumnNames(groupByColumnNames);
        for (AggregateFunc func : aggregation.aggregateExpressions()) {
            if (func instanceof Avg) {
                newSchema = newSchema.add(func.toString(), DataTypes.DoubleType);
            } else if (func instanceof Count) {
                newSchema = newSchema.add(func.toString(), DataTypes.LongType);
            } else if (func instanceof CountStar) {
                newSchema = newSchema.add("count", DataTypes.LongType);
            } else if (func instanceof Max) {
                newSchema = newSchema.add(func.toString(), findColumnInSchema(((Max) func).column()).dataType());
            } else if (func instanceof Min) {
                newSchema = newSchema.add(func.toString(), findColumnInSchema(((Min) func).column()).dataType());
            } else if (func instanceof Sum) {
                newSchema = newSchema.add(func.toString(), findColumnInSchema(((Sum) func).column()).dataType());
            } else {
                logger.info("Unsupported aggregate function: {}", func);
            }
        }

        if (!getProperties().containsKey(Options.READ_BATCH_SIZE)) {
            logger.info("Batch size was not overridden, so modifying each partition to make a single request to improve performance of pushed down aggregation.");
            List<PlanAnalysis.Partition> mergedPartitions = this.planAnalysis.partitions.stream()
                .map(PlanAnalysis.Partition::mergeBuckets)
                .collect(Collectors.toList());
            this.planAnalysis = new PlanAnalysis(this.planAnalysis.boundedPlan, mergedPartitions);
        }
        this.schema = newSchema;
    }

    /**
     * Builds a schema holding the current schema's fields for the given column names, in the given order.
     *
     * @throws IllegalArgumentException if a name is not in the current schema
     */
    private StructType buildSchemaWithColumnNames(List<String> list) {
        StructType result = new StructType();
        for (String columnName : list) {
            StructField field = findFieldByName(columnName);
            if (field == null) {
                throw new IllegalArgumentException("Unable to find column in schema; column name: " + columnName);
            }
            result = result.add(field);
        }
        return result;
    }

    // Returns the first field of the current schema whose name equals the given name, or null if none matches.
    private StructField findFieldByName(String name) {
        for (StructField field : this.schema.fields()) {
            if (name.equals(field.name())) {
                return field;
            }
        }
        return null;
    }

    // Resolves an aggregate's column expression to its field in the current schema.
    private StructField findColumnInSchema(Expression expression) {
        return findColumnInSchema(expression, PlanUtil.expressionToColumnName(expression));
    }

    private StructField findColumnInSchema(Expression expression, String str) {
        StructField field = findFieldByName(str);
        if (field == null) {
            throw new IllegalArgumentException("Unable to find column in schema for expression: " + expression.describe());
        }
        return field;
    }

    /**
     * Replaces the schema with the columns Spark requires and adds a matching select operator to the plan.
     */
    public void pushDownRequiredSchema(StructType structType) {
        this.schema = structType;
        addOperatorToPlan(PlanUtil.buildSelect(structType));
    }

    /**
     * Returns true when the constructor's plan analysis reported that no rows exist.
     */
    public boolean planAnalysisFoundNoRows() {
        return this.planAnalysis == null;
    }

    /**
     * Inserts an operator into the bounded plan's "$optic" args array.
     *
     * <p>The operator goes just before the array's current last element, so that element stays last.
     * NOTE(review): assumes planAnalysis is non-null; callers presumably check
     * planAnalysisFoundNoRows() first.
     */
    private void addOperatorToPlan(ObjectNode objectNode) {
        if (logger.isDebugEnabled()) {
            logger.debug("Adding operator to plan: {}", objectNode);
        }
        ArrayNode args = (ArrayNode) this.planAnalysis.boundedPlan.get("$optic").get("args");
        args.insert(args.size() - 1, objectNode);
    }

    /**
     * Returns the current schema, reflecting any pushed-down required schema or aggregation.
     */
    public StructType getSchema() {
        return this.schema;
    }

    /**
     * Returns the plan analysis, or null when no rows were found.
     */
    public PlanAnalysis getPlanAnalysis() {
        return this.planAnalysis;
    }

    /**
     * Returns the total number of buckets across all partitions, or 0 when no rows were found.
     */
    public long getBucketCount() {
        return this.planAnalysis != null ? this.planAnalysis.getAllBuckets().size() : 0L;
    }
}
