package com.marklogic.spark.reader;

import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.node.ArrayNode;
import com.fasterxml.jackson.databind.node.ObjectNode;
import com.marklogic.spark.reader.filter.OpticFilter;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Consumer;
import java.util.function.Function;
import org.apache.spark.sql.connector.expressions.Expression;
import org.apache.spark.sql.connector.expressions.NamedReference;
import org.apache.spark.sql.connector.expressions.SortDirection;
import org.apache.spark.sql.connector.expressions.SortOrder;
import org.apache.spark.sql.connector.expressions.aggregate.AggregateFunc;
import org.apache.spark.sql.connector.expressions.aggregate.Aggregation;
import org.apache.spark.sql.connector.expressions.aggregate.Avg;
import org.apache.spark.sql.connector.expressions.aggregate.Count;
import org.apache.spark.sql.connector.expressions.aggregate.CountStar;
import org.apache.spark.sql.connector.expressions.aggregate.Max;
import org.apache.spark.sql.connector.expressions.aggregate.Min;
import org.apache.spark.sql.connector.expressions.aggregate.Sum;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:com/marklogic/spark/reader/PlanUtil.class */
/**
 * Static helpers that build the JSON ("ns"/"fn"/"args") form of Optic plan operations from
 * Spark DataSource V2 pushdown objects: grouping/aggregation, limit, ordering, selection and
 * filtering.
 *
 * <p>Every operation produced here has the shape {@code {"ns": "op", "fn": <name>, "args": [...]}}.
 */
public abstract class PlanUtil {

    private static final Logger logger = LoggerFactory.getLogger(PlanUtil.class);
    private static final ObjectMapper objectMapper = new ObjectMapper();

    /**
     * Converters from a supported Spark aggregate function type to its Optic equivalent.
     * Written only by the static initializer directly below; read-only afterwards.
     * Lookup is by exact runtime class (see {@link #buildGroupByAggregation}).
     */
    private static final Map<Class<? extends AggregateFunc>, Function<AggregateFunc, OpticFunction>>
        aggregateFunctionHandlers = new HashMap<>();

    static {
        aggregateFunctionHandlers.put(Avg.class, func -> {
            Avg avg = (Avg) func;
            return new OpticFunction("avg", avg.column(), avg.isDistinct());
        });
        aggregateFunctionHandlers.put(Count.class, func -> {
            Count count = (Count) func;
            return new OpticFunction("count", count.column(), count.isDistinct());
        });
        // Max and Min expose no distinct flag, so they always use the non-distinct form.
        aggregateFunctionHandlers.put(Max.class, func -> new OpticFunction("max", ((Max) func).column()));
        aggregateFunctionHandlers.put(Min.class, func -> new OpticFunction("min", ((Min) func).column()));
        aggregateFunctionHandlers.put(Sum.class, func -> {
            Sum sum = (Sum) func;
            return new OpticFunction("sum", sum.column(), sum.isDistinct());
        });
    }

    /**
     * The Optic function name, target column and distinct flag derived from one Spark aggregate.
     */
    public static class OpticFunction {
        final String functionName;
        final String columnName;
        final boolean distinct;

        OpticFunction(String str, Expression expression) {
            this(str, expression, false);
        }

        /**
         * @param str        Optic function name, e.g. {@code "sum"}
         * @param expression the aggregated Spark expression, resolved via {@link #expressionToColumnName}
         * @param z          whether the aggregate should operate on distinct values only
         * @throws IllegalArgumentException if the expression's first reference is multi-part
         */
        OpticFunction(String str, Expression expression, boolean z) {
            this.functionName = str;
            this.columnName = PlanUtil.expressionToColumnName(expression);
            this.distinct = z;
        }
    }

    /**
     * Builds an {@code op.group-by} operation.
     *
     * <p>The operation has two argument arrays:
     * <ol>
     *   <li>one {@code schema-col} per grouping column;</li>
     *   <li>one aggregate per supported aggregate function.</li>
     * </ol>
     * {@code COUNT(*)} becomes {@code op.count("count", null)}. The other supported aggregates
     * (see {@link #aggregateFunctionHandlers}) get the Spark expression's {@code toString()} as their
     * first argument (presumably the output alias), the target column, and, when distinct,
     * {@code {"values": "distinct"}}. Unsupported aggregates are logged at INFO level and omitted.
     *
     * @param list        names of the grouping columns
     * @param aggregation the Spark aggregation being pushed down
     * @return the group-by operation node
     */
    public static ObjectNode buildGroupByAggregation(List<String> list, Aggregation aggregation) {
        return newOperation("group-by", args -> {
            ArrayNode groupingColumns = args.addArray();
            list.forEach(columnName -> populateSchemaCol(groupingColumns.addObject(), columnName));

            ArrayNode aggregates = args.addArray();
            for (AggregateFunc aggregateFunc : aggregation.aggregateExpressions()) {
                if (aggregateFunc instanceof CountStar) {
                    // The "count" alias here is what buildOrderBy maps a "COUNT(*)" sort column to.
                    aggregates.addObject().put("ns", "op").put("fn", "count")
                        .putArray("args").add("count").add(objectMapper.nullNode());
                    continue;
                }

                // Single lookup instead of containsKey + get; handler values are never null.
                Function<AggregateFunc, OpticFunction> handler = aggregateFunctionHandlers.get(aggregateFunc.getClass());
                if (handler == null) {
                    logger.info("Unsupported aggregate function, will not be pushed to Optic: {}", aggregateFunc);
                    continue;
                }

                OpticFunction opticFunction = handler.apply(aggregateFunc);
                ArrayNode functionArgs = aggregates.addObject()
                    .put("ns", "op")
                    .put("fn", opticFunction.functionName)
                    .putArray("args");
                functionArgs.add(aggregateFunc.toString());
                populateSchemaCol(functionArgs.addObject(), opticFunction.columnName);
                if (opticFunction.distinct) {
                    functionArgs.addObject().put("values", "distinct");
                }
            }
        });
    }

    /**
     * Builds an {@code op.limit} operation whose only argument is {@code i}.
     *
     * @param i the row limit
     * @return the limit operation node
     */
    public static ObjectNode buildLimit(int i) {
        return newOperation("limit", args -> args.add(i));
    }

    /**
     * Builds an {@code op.order-by} operation with one argument array.
     *
     * <p>That array holds one {@code op.asc} or {@code op.desc} per sort order, each wrapping a
     * {@code schema-col}. A sort column named {@code COUNT(*)} is rewritten to {@code count}. This
     * matches the alias {@link #buildGroupByAggregation} gives to {@code COUNT(*)}.
     *
     * @param sortOrderArr the Spark sort orders, in priority order
     * @return the order-by operation node
     * @throws IllegalArgumentException if a sort expression references a multi-part field name
     */
    public static ObjectNode buildOrderBy(SortOrder[] sortOrderArr) {
        return newOperation("order-by", args -> {
            ArrayNode orderings = args.addArray();
            for (SortOrder sortOrder : sortOrderArr) {
                String direction = SortDirection.ASCENDING.equals(sortOrder.direction()) ? "asc" : "desc";
                ArrayNode directionArgs = orderings.addObject().put("ns", "op").put("fn", direction).putArray("args");
                String columnName = expressionToColumnName(sortOrder.expression());
                if ("COUNT(*)".equals(columnName)) {
                    logger.debug("Adjusting `COUNT(*)` column to be `count`");
                    columnName = "count";
                }
                populateSchemaCol(directionArgs.addObject(), columnName);
            }
        });
    }

    /**
     * Builds an {@code op.select} operation with one argument array containing a
     * {@code schema-col} for every field of {@code structType}, in field order.
     *
     * @param structType the schema whose fields are selected
     * @return the select operation node
     */
    public static ObjectNode buildSelect(StructType structType) {
        return newOperation("select", args -> {
            ArrayNode columns = args.addArray();
            for (StructField structField : structType.fields()) {
                populateSchemaCol(columns.addObject(), structField.name());
            }
        });
    }

    /**
     * Turns {@code objectNode} into an {@code op.schema-col} node with exactly three arguments.
     *
     * <p>First, one leading and one trailing backtick are stripped from {@code str}. The result is
     * then split on {@code '.'}:
     * <ul>
     *   <li>3 parts: all three are used;</li>
     *   <li>2 parts: {@code [null, part0, part1]};</li>
     *   <li>anything else: {@code [null, null, part0]}.</li>
     * </ul>
     * The three slots presumably correspond to Optic's (schema, view, column); confirm against the
     * Optic documentation.
     *
     * <p>NOTE(review): names with more than three dot-separated parts silently keep only the first
     * part. A name consisting only of dots splits to an empty array and fails with
     * {@link ArrayIndexOutOfBoundsException}.
     *
     * @param objectNode the node to populate
     * @param str        the possibly qualified, possibly backtick-quoted column name
     */
    public static void populateSchemaCol(ObjectNode objectNode, String str) {
        String[] parts = removeTickMarksFromColumnName(str).split("\\.");
        ArrayNode colArgs = objectNode.put("ns", "op").put("fn", "schema-col").putArray("args");
        if (parts.length == 3) {
            colArgs.add(parts[0]).add(parts[1]).add(parts[2]);
        } else if (parts.length == 2) {
            colArgs.add(objectMapper.nullNode()).add(parts[0]).add(parts[1]);
        } else {
            colArgs.add(objectMapper.nullNode()).add(objectMapper.nullNode()).add(parts[0]);
        }
    }

    /**
     * Strips at most one leading and at most one trailing backtick. Unmatched ticks are removed too.
     */
    private static String removeTickMarksFromColumnName(String str) {
        String trimmed = str.startsWith("`") ? str.substring(1) : str;
        return trimmed.endsWith("`") ? trimmed.substring(0, trimmed.length() - 1) : trimmed;
    }

    /**
     * Builds an {@code op.where} operation.
     *
     * <p>A single filter is written directly as the where argument. Otherwise the filters are
     * wrapped in one {@code op.and}.
     *
     * <p>NOTE(review): an empty list also takes the {@code op.and} path and yields an {@code and}
     * with no arguments; callers presumably never pass an empty list.
     *
     * @param list the filters to combine
     * @return the where operation node
     */
    public static ObjectNode buildWhere(List<OpticFilter> list) {
        return newOperation("where", args -> {
            ArrayNode target = list.size() == 1
                ? args
                : args.addObject().put("ns", "op").put("fn", "and").putArray("args");
            list.forEach(opticFilter -> opticFilter.populateArg(target.addObject()));
        });
    }

    /**
     * Creates {@code {"ns": "op", "fn": str, "args": []}} and lets {@code consumer} fill the args.
     */
    private static ObjectNode newOperation(String str, Consumer<ArrayNode> consumer) {
        ObjectNode operation = objectMapper.createObjectNode().put("ns", "op").put("fn", str);
        consumer.accept(operation.putArray("args"));
        return operation;
    }

    /**
     * Resolves the column name for a Spark expression.
     *
     * <p>If the expression has no references, its {@code describe()} text is returned; for example,
     * {@code COUNT(*)} ends up here. Otherwise the single field name of its first reference is
     * returned.
     *
     * @param expression the Spark expression
     * @return the column name
     * @throws IllegalArgumentException if the first reference does not have exactly one field name
     */
    public static String expressionToColumnName(Expression expression) {
        NamedReference[] references = expression.references();
        if (references == null || references.length < 1) {
            return expression.describe();
        }
        String[] fieldNames = references[0].fieldNames();
        if (fieldNames.length != 1) {
            throw new IllegalArgumentException("Unsupported expression: " + expression + "; expecting expression to have exactly one field name.");
        }
        return fieldNames[0];
    }
}
