package fr.insee.vtl.spark;

import fr.insee.vtl.model.AggregationExpression;
import fr.insee.vtl.model.BooleanExpression;
import fr.insee.vtl.model.Dataset;
import fr.insee.vtl.model.DatasetExpression;
import fr.insee.vtl.model.ProcessingEngine;
import fr.insee.vtl.model.ProcessingEngineFactory;
import fr.insee.vtl.model.ResolvableExpression;
import fr.insee.vtl.model.Structured;
import java.lang.invoke.SerializedLambda;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import javax.script.ScriptEngine;
import org.apache.spark.api.java.function.MapFunction;
import org.apache.spark.sql.Column;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.catalyst.encoders.RowEncoder;
import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import scala.collection.JavaConverters;

/**
 * {@link ProcessingEngine} that executes VTL dataset operations on Apache Spark.
 *
 * <p>Input datasets that are not already Spark-backed are converted into {@link SparkDataset}s
 * using this engine's {@link SparkSession}. Union and aggregation are not implemented yet and
 * throw {@link UnsupportedOperationException}.
 */
public class SparkProcessingEngine implements ProcessingEngine {

    /** Session used to materialize datasets that are not already Spark-backed. */
    private final SparkSession spark;

    /**
     * Factory registering this engine under the name {@code "spark"}.
     *
     * <p>The session is taken from the script engine binding {@code $vtl.spark.session} when it is
     * set, otherwise from Spark's active (or default) session.
     */
    public static class Factory implements ProcessingEngineFactory {
        private static final String SPARK_SESSION = "$vtl.spark.session";

        public String getName() {
            return "spark";
        }

        /**
         * Creates a Spark engine for the given script engine.
         *
         * @param scriptEngine the script engine whose bindings may hold a {@link SparkSession}
         * @return a new engine bound to the resolved session
         * @throws IllegalArgumentException if the binding holds something other than a
         *     {@link SparkSession}, or if there is no binding and Spark has no active or default
         *     session
         */
        public ProcessingEngine getProcessingEngine(ScriptEngine scriptEngine) {
            Object bound = scriptEngine.get(SPARK_SESSION);
            if (bound != null) {
                if (!(bound instanceof SparkSession)) {
                    throw new IllegalArgumentException("$vtl.spark.session was not a spark session");
                }
                return new SparkProcessingEngine((SparkSession) bound);
            }
            SparkSession active;
            try {
                // SparkSession.active() never returns null: it throws IllegalStateException when
                // there is neither an active nor a default session.
                active = SparkSession.active();
            } catch (IllegalStateException e) {
                throw new IllegalArgumentException("no active spark session", e);
            }
            return new SparkProcessingEngine(active);
        }
    }

    /**
     * Creates an engine backed by the given session.
     *
     * @param sparkSession session used to materialize non-Spark datasets; must not be null
     * @throws NullPointerException if {@code sparkSession} is null
     */
    public SparkProcessingEngine(SparkSession sparkSession) {
        this.spark = Objects.requireNonNull(sparkSession, "sparkSession");
    }

    /**
     * Creates an engine backed by Spark's currently active (or default) session.
     * {@link SparkSession#active()} throws if no such session exists.
     */
    public SparkProcessingEngine() {
        this.spark = SparkSession.active();
    }

    /**
     * Builds a mutable map from component name to role.
     *
     * <p>Callers add and remove entries, so the result is copied into an explicit {@link HashMap}
     * instead of relying on the unspecified map type returned by {@link Collectors#toMap}. As with
     * {@code toMap}, duplicate names raise {@link IllegalStateException} and a null role raises
     * {@link NullPointerException}.
     */
    private static Map<String, Dataset.Role> getRoleMap(Collection<Structured.Component> collection) {
        Map<String, Dataset.Role> roles = collection.stream()
                .collect(Collectors.toMap(Structured.Component::getName, Structured.Component::getRole));
        return new HashMap<>(roles);
    }

    /** Builds a mutable name-to-role map from the structure of {@code dataset}. */
    private static Map<String, Dataset.Role> getRoleMap(Dataset dataset) {
        return getRoleMap((Collection<Structured.Component>) dataset.getDataStructure().values());
    }

    /**
     * Resolves an expression to a {@link SparkDataset}, wrapping it with this engine's session
     * when the resolved dataset is not already Spark-backed.
     */
    private SparkDataset asSparkDataset(DatasetExpression datasetExpression) {
        if (datasetExpression instanceof SparkDatasetExpression) {
            return ((SparkDatasetExpression) datasetExpression).resolve(Map.of());
        }
        Dataset resolved = datasetExpression.resolve(Map.of());
        if (resolved instanceof SparkDataset) {
            return (SparkDataset) resolved;
        }
        return new SparkDataset(resolved, getRoleMap(resolved), this.spark);
    }

    /**
     * Executes a VTL {@code calc} clause.
     *
     * <p>The Spark SQL translations in {@code map3} are tried first through {@code selectExpr}. If
     * Spark rejects them, every row is evaluated with the VTL expressions in {@code map} instead.
     *
     * @param datasetExpression the input dataset
     * @param map calc target column name to the VTL expression computing it
     * @param map2 roles of the calculated columns; these override the input roles
     * @param map3 calc target column name to a Spark SQL expression string (presumably the SQL
     *     translation of the matching entry in {@code map}; verify against callers)
     * @return the dataset with calculated columns replaced in place or appended
     */
    public DatasetExpression executeCalc(DatasetExpression datasetExpression, Map<String, ResolvableExpression> map, Map<String, Dataset.Role> map2, Map<String, String> map3) {
        SparkDataset source = asSparkDataset(datasetExpression);
        org.apache.spark.sql.Dataset<Row> sparkDataset = source.getSparkDataset();
        StructType sourceSchema = sparkDataset.schema();
        List<String> sourceNames = Arrays.asList(sourceSchema.fieldNames());

        // Source columns keep their position and new calc targets are appended in map order.
        // The row-level fallback relies on this to copy untouched values by index.
        List<String> targetNames = new ArrayList<>(sourceNames);
        for (String name : map.keySet()) {
            if (!targetNames.contains(name)) {
                targetNames.add(name);
            }
        }

        // Calculated columns take the type of their VTL expression; the others keep their field.
        List<StructField> targetFields = new ArrayList<>(targetNames.size());
        for (String name : targetNames) {
            if (map.containsKey(name)) {
                targetFields.add(DataTypes.createStructField(name, SparkDataset.fromVtlType(map.get(name).getType()), true));
            } else {
                targetFields.add(sourceSchema.apply(name));
            }
        }
        StructType targetSchema = DataTypes.createStructType(targetFields);

        Map<String, Dataset.Role> roleMap = getRoleMap(source);
        roleMap.putAll(map2);

        try {
            List<String> selectExpressions = new ArrayList<>();
            for (String name : targetNames) {
                if (!map3.containsKey(name)) {
                    selectExpressions.add(name);
                }
            }
            for (Map.Entry<String, String> entry : map3.entrySet()) {
                selectExpressions.add(String.format("%s as %s", entry.getValue(), entry.getKey()));
            }
            org.apache.spark.sql.Dataset<Row> selected = sparkDataset.selectExpr(selectExpressions.toArray(new String[0]));
            return new SparkDatasetExpression(new SparkDataset(selected, roleMap));
        } catch (Exception ignored) {
            // Deliberate best-effort: any failure of the SQL path (unparseable or unresolvable
            // expression, missing translation) falls back to per-row VTL evaluation.
            org.apache.spark.sql.Dataset<Row> computed = sparkDataset.map(calcRowMapper(targetSchema, map), RowEncoder.apply(targetSchema));
            return new SparkDatasetExpression(new SparkDataset(computed, roleMap));
        }
    }

    /**
     * Returns the row function used by the {@code calc} fallback.
     *
     * <p>The method is static so the serialized closure captures only the schema and the
     * expressions, not this engine, which holds the {@link SparkSession}. The expressions map is
     * shipped to executors, so it presumably needs to be serializable (NOTE(review): confirm).
     *
     * @param targetSchema output schema; source columns occupy the same leading indices as in the
     *     input row
     * @param expressions calc target column name to the VTL expression computing it
     */
    private static MapFunction<Row, Row> calcRowMapper(StructType targetSchema, Map<String, ResolvableExpression> expressions) {
        String[] names = targetSchema.fieldNames();
        return row -> {
            SparkRowMap rowMap = new SparkRowMap(row);
            Object[] values = new Object[names.length];
            for (int i = 0; i < names.length; i++) {
                if (expressions.containsKey(names[i])) {
                    values[i] = expressions.get(names[i]).resolve(rowMap);
                } else {
                    // Untouched source column: it keeps the same index in the input row.
                    values[i] = row.get(i);
                }
            }
            return new GenericRowWithSchema(values, targetSchema);
        };
    }

    /**
     * Executes a VTL {@code filter} clause.
     *
     * <p>The Spark SQL condition {@code str} is tried first. If Spark rejects it, rows are filtered
     * with the VTL {@code booleanExpression} instead.
     *
     * @param datasetExpression the input dataset
     * @param booleanExpression VTL condition used by the fallback path
     * @param str Spark SQL condition string
     * @return the filtered dataset, with the input roles
     */
    public DatasetExpression executeFilter(DatasetExpression datasetExpression, BooleanExpression booleanExpression, String str) {
        SparkDataset source = asSparkDataset(datasetExpression);
        org.apache.spark.sql.Dataset<Row> sparkDataset = source.getSparkDataset();
        Map<String, Dataset.Role> roleMap = getRoleMap(source);
        try {
            return new SparkDatasetExpression(new SparkDataset(sparkDataset.filter(str), roleMap));
        } catch (Exception ignored) {
            // Deliberate best-effort: fall back to row-level VTL evaluation.
            return new SparkDatasetExpression(new SparkDataset(sparkDataset.filter(new SparkFilterFunction(booleanExpression)), roleMap));
        }
    }

    /**
     * Executes a VTL {@code rename} clause.
     *
     * @param datasetExpression the input dataset
     * @param map old column name to new column name; columns not listed keep their name
     * @return the renamed dataset, whose roles follow their columns
     */
    public DatasetExpression executeRename(DatasetExpression datasetExpression, Map<String, String> map) {
        SparkDataset source = asSparkDataset(datasetExpression);
        List<Column> columns = new ArrayList<>();
        for (String name : source.getColumnNames()) {
            Column column = new Column(name);
            columns.add(map.containsKey(name) ? column.as(map.get(name)) : column);
        }
        org.apache.spark.sql.Dataset<Row> renamed = source.getSparkDataset().select(columns.toArray(new Column[0]));

        // Rebuild the role map instead of moving entries in place. An in-place remove/put loses
        // roles when names are swapped (e.g. {a -> b, b -> a}).
        Map<String, Dataset.Role> renamedRoles = new HashMap<>();
        for (Map.Entry<String, Dataset.Role> entry : getRoleMap(source).entrySet()) {
            renamedRoles.put(map.getOrDefault(entry.getKey(), entry.getKey()), entry.getValue());
        }
        return new SparkDatasetExpression(new SparkDataset(renamed, renamedRoles));
    }

    /**
     * Executes a VTL projection ({@code keep}).
     *
     * @param datasetExpression the input dataset
     * @param list names of the columns to keep, in output order
     * @return the projected dataset; the role map passed on still contains every input column
     */
    public DatasetExpression executeProject(DatasetExpression datasetExpression, List<String> list) {
        SparkDataset source = asSparkDataset(datasetExpression);
        Column[] columns = list.stream().map(Column::new).toArray(Column[]::new);
        return new SparkDatasetExpression(new SparkDataset(source.getSparkDataset().select(columns), getRoleMap(source)));
    }

    /** Not implemented for Spark yet. */
    public DatasetExpression executeUnion(List<DatasetExpression> list) {
        throw new UnsupportedOperationException("TODO");
    }

    /** Not implemented for Spark yet. */
    public DatasetExpression executeAggr(DatasetExpression datasetExpression, Structured.DataStructure dataStructure, Map<String, AggregationExpression> map, Function<Structured.DataPoint, Map<String, Object>> function) {
        throw new UnsupportedOperationException("TODO");
    }

    /**
     * Inner-joins the aliased datasets on the identifier components of {@code list}.
     *
     * @param map alias to dataset, joined in the map's iteration order
     * @param list components of the resulting structure
     */
    public DatasetExpression executeInnerJoin(Map<String, DatasetExpression> map, List<Structured.Component> list) {
        return new SparkDatasetExpression(new SparkDataset(executeJoin(toAliasedDatasets(map), identifierNames(list), "inner"), getRoleMap(list)));
    }

    /**
     * Left-joins the aliased datasets on the identifier components of {@code list}.
     *
     * @param map alias to dataset, joined in the map's iteration order
     * @param list components of the resulting structure
     */
    public DatasetExpression executeLeftJoin(Map<String, DatasetExpression> map, List<Structured.Component> list) {
        return new SparkDatasetExpression(new SparkDataset(executeJoin(toAliasedDatasets(map), identifierNames(list), "left"), getRoleMap(list)));
    }

    /**
     * Computes the Cartesian product of the aliased datasets.
     *
     * @param map alias to dataset, joined in the map's iteration order
     * @param list components of the resulting structure
     */
    public DatasetExpression executeCrossJoin(Map<String, DatasetExpression> map, List<Structured.Component> list) {
        return new SparkDatasetExpression(new SparkDataset(executeJoin(toAliasedDatasets(map), List.of(), "cross"), getRoleMap(list)));
    }

    /**
     * Full-outer-joins the aliased datasets on the identifier components of {@code list}.
     *
     * @param map alias to dataset, joined in the map's iteration order
     * @param list components of the resulting structure
     */
    public DatasetExpression executeFullJoin(Map<String, DatasetExpression> map, List<Structured.Component> list) {
        return new SparkDatasetExpression(new SparkDataset(executeJoin(toAliasedDatasets(map), identifierNames(list), "outer"), getRoleMap(list)));
    }

    /** Resolves each dataset and aliases it with its map key, preserving the map's iteration order. */
    private List<org.apache.spark.sql.Dataset<Row>> toAliasedDatasets(Map<String, DatasetExpression> map) {
        List<org.apache.spark.sql.Dataset<Row>> aliased = new ArrayList<>(map.size());
        for (Map.Entry<String, DatasetExpression> entry : map.entrySet()) {
            aliased.add(asSparkDataset(entry.getValue()).getSparkDataset().as(entry.getKey()));
        }
        return aliased;
    }

    /** Returns the names of the components whose role is {@link Dataset.Role#IDENTIFIER}, in list order. */
    private static List<String> identifierNames(List<Structured.Component> list) {
        return list.stream()
                .filter(component -> Dataset.Role.IDENTIFIER.equals(component.getRole()))
                .map(Structured.Component::getName)
                .collect(Collectors.toList());
    }

    /**
     * Joins the datasets left to right.
     *
     * @param list datasets to join; must not be empty
     * @param list2 columns to join on ({@code using} semantics); ignored for cross joins
     * @param str Spark join type (e.g. {@code "inner"}, {@code "left"}, {@code "outer"}), or
     *     {@code "cross"} for a Cartesian product
     * @return the single dataset if {@code list} has one element, otherwise the folded join
     * @throws IllegalArgumentException if {@code list} is empty
     */
    public org.apache.spark.sql.Dataset<Row> executeJoin(List<org.apache.spark.sql.Dataset<Row>> list, List<String> list2, String str) {
        if (list.isEmpty()) {
            throw new IllegalArgumentException("at least one dataset is required to join");
        }
        boolean cross = "cross".equals(str);
        Iterator<org.apache.spark.sql.Dataset<Row>> it = list.iterator();
        org.apache.spark.sql.Dataset<Row> result = it.next();
        while (it.hasNext()) {
            org.apache.spark.sql.Dataset<Row> right = it.next();
            result = cross
                    ? result.crossJoin(right)
                    : result.join(right, JavaConverters.iterableAsScalaIterable(list2).toSeq(), str);
        }
        return result;
    }
}
