package fr.insee.vtl.spark;

import fr.insee.vtl.model.Dataset;
import fr.insee.vtl.model.Structured;
import java.time.Instant;
import java.time.LocalDate;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.DecimalType;
import org.apache.spark.sql.types.Metadata;
import org.apache.spark.sql.types.MetadataBuilder;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import scala.Predef;
import scala.collection.JavaConverters;

/**
 * A VTL {@link Dataset} backed by a Spark {@code Dataset<Row>}.
 *
 * <p>When built from a Spark dataset, {@code IntegerType} columns are cast to {@code LongType},
 * and {@code FloatType} and {@code DecimalType} columns are cast to {@code DoubleType}. This
 * makes the Spark column types line up with the VTL types reported by {@link #toVtlType(DataType)}.
 *
 * <p>Each component's role is resolved in this order:
 * <ol>
 *   <li>the explicit role map passed to the constructor;</li>
 *   <li>the {@code vtlRole} column metadata;</li>
 *   <li>otherwise {@link Dataset.Role#MEASURE}.</li>
 * </ol>
 *
 * <p>The {@link Structured.DataStructure} is computed lazily on first access and cached. The
 * cache is not synchronized.
 */
public class SparkDataset implements Dataset {

    /** Spark column-metadata key under which a component's VTL role name is stored. */
    private static final String VTL_ROLE_KEY = "vtlRole";

    private final org.apache.spark.sql.Dataset<Row> sparkDataset;
    private final Map<String, Dataset.Role> roles;
    private Structured.DataStructure dataStructure = null;

    /**
     * Wraps a Spark dataset, normalizing numeric column types and tagging every column with its
     * VTL role in the column metadata.
     *
     * @param dataset the Spark dataset to wrap
     * @param map     explicit roles by column name; these take precedence over any
     *                {@code vtlRole} metadata already present on the columns
     * @throws NullPointerException if {@code dataset} or {@code map} is null
     */
    public SparkDataset(org.apache.spark.sql.Dataset<Row> dataset, Map<String, Dataset.Role> map) {
        Objects.requireNonNull(dataset, "dataset");
        this.roles = Objects.requireNonNull(map, "map");
        // The structure is derived from the pre-cast schema. toVtlType maps the pre- and post-cast
        // numeric types to the same VTL classes, so the result is the same either way.
        Structured.DataStructure structure = fromSparkSchema(dataset.schema(), map);
        this.sparkDataset = addMetadata(castIfNeeded(dataset), structure);
    }

    /**
     * Wraps a Spark dataset, normalizing numeric column types.
     *
     * <p>No explicit roles are given. Roles therefore come from existing {@code vtlRole} column
     * metadata, or default to {@link Dataset.Role#MEASURE}.
     *
     * @param dataset the Spark dataset to wrap
     * @throws NullPointerException if {@code dataset} is null
     */
    public SparkDataset(org.apache.spark.sql.Dataset<Row> dataset) {
        this.sparkDataset = castIfNeeded(Objects.requireNonNull(dataset, "dataset"));
        this.roles = Collections.emptyMap();
    }

    /**
     * Materializes an in-memory VTL dataset as a Spark data frame.
     *
     * <p>Each data point becomes one {@link Row}. Its values are taken in data-point order, so
     * they are expected to follow the component order of {@code dataset.getDataStructure()}.
     *
     * @param dataset      the VTL dataset to copy into Spark
     * @param map          explicit roles by column name
     * @param sparkSession the session used to create the data frame
     * @throws NullPointerException if any argument is null
     */
    public SparkDataset(Dataset dataset, Map<String, Dataset.Role> map, SparkSession sparkSession) {
        this.roles = Objects.requireNonNull(map, "map");
        List<Row> rows = dataset.getDataPoints().stream()
                .map(dataPoint -> RowFactory.create(dataPoint.toArray(new Object[0])))
                .collect(Collectors.toList());
        this.sparkDataset = sparkSession.createDataFrame(rows, toSparkSchema(dataset.getDataStructure()));
    }

    /**
     * Casts the numeric columns that have no direct VTL counterpart.
     *
     * <p>{@code IntegerType} becomes {@code LongType}. {@code FloatType} and any
     * {@code DecimalType} become {@code DoubleType}. All other columns are left unchanged.
     */
    private static org.apache.spark.sql.Dataset<Row> castIfNeeded(org.apache.spark.sql.Dataset<Row> dataset) {
        org.apache.spark.sql.Dataset<Row> casted = dataset;
        for (StructField field : JavaConverters.asJavaCollection(dataset.schema())) {
            DataType type = field.dataType();
            String name = field.name();
            if (DataTypes.IntegerType.sameType(type)) {
                casted = casted.withColumn(name, casted.col(name).cast(DataTypes.LongType));
            } else if (DataTypes.FloatType.sameType(type)) {
                casted = casted.withColumn(name, casted.col(name).cast(DataTypes.DoubleType));
            } else if (DecimalType.class.equals(type.getClass())) {
                casted = casted.withColumn(name, casted.col(name).cast(DataTypes.DoubleType));
            }
        }
        return casted;
    }

    /**
     * Re-attaches each column with the metadata produced by {@link #toSparkSchema} for
     * {@code dataStructure}. This records the component's VTL role on the column.
     */
    private static org.apache.spark.sql.Dataset<Row> addMetadata(org.apache.spark.sql.Dataset<Row> dataset, Structured.DataStructure dataStructure) {
        org.apache.spark.sql.Dataset<Row> tagged = dataset;
        for (StructField field : JavaConverters.asJavaCollection(toSparkSchema(dataStructure))) {
            String name = field.name();
            tagged = tagged.withColumn(name, tagged.col(name), field.metadata());
        }
        return tagged;
    }

    /**
     * Builds a Spark schema from a VTL data structure.
     *
     * <p>Every field is nullable. Each field carries the component's role name under the
     * {@code vtlRole} metadata key.
     *
     * @throws UnsupportedOperationException if a component type has no Spark mapping
     */
    public static StructType toSparkSchema(Structured.DataStructure dataStructure) {
        List<StructField> fields = new ArrayList<>();
        for (Structured.Component component : dataStructure.values()) {
            Metadata metadata = new MetadataBuilder()
                    .putString(VTL_ROLE_KEY, component.getRole().name())
                    .build();
            fields.add(DataTypes.createStructField(
                    component.getName(), fromVtlType(component.getType()), true, metadata));
        }
        return DataTypes.createStructType(fields);
    }

    /**
     * Builds a VTL data structure from a Spark schema.
     *
     * <p>Roles are resolved from {@code map}, then from {@code vtlRole} metadata, then default
     * to {@link Dataset.Role#MEASURE}. Nullability is left unspecified ({@code null}).
     *
     * @throws UnsupportedOperationException if a column type has no VTL mapping
     */
    public static Structured.DataStructure fromSparkSchema(StructType structType, Map<String, Dataset.Role> map) {
        List<Structured.Component> components = new ArrayList<>();
        for (StructField field : JavaConverters.asJavaCollection(structType)) {
            components.add(new Structured.Component(
                    field.name(), toVtlType(field.dataType()), resolveRole(field, map), (Boolean) null));
        }
        return new Structured.DataStructure(components);
    }

    /**
     * Returns the role of {@code field}. The explicit map wins, then the {@code vtlRole}
     * metadata; otherwise the role is {@link Dataset.Role#MEASURE}.
     */
    private static Dataset.Role resolveRole(StructField field, Map<String, Dataset.Role> roles) {
        String name = field.name();
        if (roles.containsKey(name)) {
            return roles.get(name);
        }
        Metadata metadata = field.metadata();
        if (metadata.contains(VTL_ROLE_KEY)) {
            return Dataset.Role.valueOf(metadata.getString(VTL_ROLE_KEY));
        }
        return Dataset.Role.MEASURE;
    }

    /**
     * Maps a Spark SQL type to the Java class used as its VTL type.
     *
     * <p>NOTE(review): DateType and TimestampType are both reported as {@link Instant}, but
     * castIfNeeded does not convert them. The values collected by {@link #getDataPoints()} are
     * therefore whatever Spark's external type is: java.sql.Date / java.sql.Timestamp by
     * default, or LocalDate / Instant when spark.sql.datetime.java8API.enabled is set. Also,
     * {@link #fromVtlType} maps Instant back to TimestampType, not DateType. Confirm this
     * asymmetry is intended.
     *
     * @throws UnsupportedOperationException if the type has no VTL mapping
     */
    public static Class<?> toVtlType(DataType dataType) {
        if (DataTypes.StringType.sameType(dataType)) {
            return String.class;
        }
        if (DataTypes.IntegerType.sameType(dataType) || DataTypes.LongType.sameType(dataType)) {
            return Long.class;
        }
        if (DataTypes.FloatType.sameType(dataType) || DataTypes.DoubleType.sameType(dataType)) {
            return Double.class;
        }
        if (DataTypes.BooleanType.sameType(dataType)) {
            return Boolean.class;
        }
        if (DecimalType.class.equals(dataType.getClass())) {
            return Double.class;
        }
        if (DataTypes.DateType.sameType(dataType) || DataTypes.TimestampType.sameType(dataType)) {
            return Instant.class;
        }
        throw new UnsupportedOperationException("unsupported type " + dataType);
    }

    /**
     * Maps a VTL type class to the Spark SQL type used to store it.
     *
     * @throws UnsupportedOperationException if the class has no Spark mapping
     */
    public static DataType fromVtlType(Class<?> cls) {
        if (String.class.equals(cls)) {
            return DataTypes.StringType;
        }
        if (Long.class.equals(cls)) {
            return DataTypes.LongType;
        }
        if (Double.class.equals(cls)) {
            return DataTypes.DoubleType;
        }
        if (Boolean.class.equals(cls)) {
            return DataTypes.BooleanType;
        }
        if (Instant.class.equals(cls)) {
            return DataTypes.TimestampType;
        }
        if (LocalDate.class.equals(cls)) {
            return DataTypes.DateType;
        }
        throw new UnsupportedOperationException("unsupported type " + cls);
    }

    /** Returns the underlying Spark dataset, with casts and role metadata applied. */
    public org.apache.spark.sql.Dataset<Row> getSparkDataset() {
        return this.sparkDataset;
    }

    /**
     * Collects the whole Spark dataset to the driver and converts each row into a VTL data
     * point. The row's values are used in column order.
     */
    @Override
    public List<Structured.DataPoint> getDataPoints() {
        return this.sparkDataset.collectAsList().stream()
                .map(row -> new Structured.DataPoint(getDataStructure(), JavaConverters.seqAsJavaList(row.toSeq())))
                .collect(Collectors.toList());
    }

    /**
     * Returns the VTL structure of this dataset. It is derived from the Spark schema and the
     * role map on first call, then cached.
     */
    @Override
    public Structured.DataStructure getDataStructure() {
        if (this.dataStructure == null) {
            this.dataStructure = fromSparkSchema(this.sparkDataset.schema(), this.roles);
        }
        return this.dataStructure;
    }
}
