package org.apache.seatunnel.translation.spark.sink;

import java.io.IOException;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;
import org.apache.seatunnel.api.sink.SeaTunnelSink;
import org.apache.seatunnel.api.sink.SinkAggregatedCommitter;
import org.apache.seatunnel.api.table.type.SeaTunnelRow;
import org.apache.seatunnel.translation.spark.sink.write.SeaTunnelSparkDataWriterFactory;
import org.apache.seatunnel.translation.spark.sink.write.SeaTunnelSparkWriterCommitMessage;
import org.apache.spark.sql.connector.write.BatchWrite;
import org.apache.spark.sql.connector.write.DataWriterFactory;
import org.apache.spark.sql.connector.write.PhysicalWriteInfo;
import org.apache.spark.sql.connector.write.WriterCommitMessage;
import org.apache.spark.sql.connector.write.streaming.StreamingDataWriterFactory;
import org.apache.spark.sql.connector.write.streaming.StreamingWrite;

/**
 * Spark DataSource V2 write adapter that drives a {@link SeaTunnelSink} for both batch writes
 * ({@link BatchWrite}) and streaming writes ({@link StreamingWrite}).
 *
 * <p>Executors obtain writers from a {@link SeaTunnelSparkDataWriterFactory} built around the sink.
 * On the driver, the per-task commit messages are unwrapped, combined by the sink's optional
 * {@link SinkAggregatedCommitter}, and then committed or aborted. When the sink does not provide
 * an aggregated committer, driver-side commit and abort do nothing.
 *
 * @param <StateT> the sink's state type parameter
 * @param <CommitInfoT> the per-writer commit info type carried in the commit messages
 * @param <AggregatedCommitInfoT> the type produced by combining per-writer commit infos
 */
public class SeaTunnelBatchWrite<StateT, CommitInfoT, AggregatedCommitInfoT> implements BatchWrite, StreamingWrite {
    private final SeaTunnelSink<SeaTunnelRow, StateT, CommitInfoT, AggregatedCommitInfoT> sink;

    /** Driver-side committer, or {@code null} when the sink does not provide one. */
    private final SinkAggregatedCommitter<CommitInfoT, AggregatedCommitInfoT> aggregatedCommitter;

    /**
     * Creates the adapter and eagerly asks the sink for its (optional) aggregated committer.
     *
     * @param seaTunnelSink the sink to write into
     * @throws IOException propagated from {@code seaTunnelSink.createAggregatedCommitter()}
     */
    public SeaTunnelBatchWrite(SeaTunnelSink<SeaTunnelRow, StateT, CommitInfoT, AggregatedCommitInfoT> seaTunnelSink) throws IOException {
        this.sink = seaTunnelSink;
        this.aggregatedCommitter = seaTunnelSink.createAggregatedCommitter().orElse(null);
    }

    /**
     * Creates the factory Spark ships to executors to create per-partition data writers.
     *
     * @param physicalWriteInfo physical write information supplied by Spark (not used here)
     * @return a new {@link SeaTunnelSparkDataWriterFactory} wrapping the sink
     */
    @Override
    public DataWriterFactory createBatchWriterFactory(PhysicalWriteInfo physicalWriteInfo) {
        return new SeaTunnelSparkDataWriterFactory(this.sink);
    }

    /**
     * Commits the job on the driver by combining all writer commit messages and passing the
     * result to the aggregated committer. Does nothing if the sink has no aggregated committer.
     *
     * @param writerCommitMessageArr the commit messages returned by the data writers
     * @throws RuntimeException wrapping the {@link IOException} thrown by the aggregated committer
     */
    @Override
    public void commit(WriterCommitMessage[] writerCommitMessageArr) {
        if (this.aggregatedCommitter != null) {
            try {
                this.aggregatedCommitter.commit(combineCommitMessage(writerCommitMessageArr));
            } catch (IOException e) {
                throw new RuntimeException("SinkAggregatedCommitter commit failed in driver", e);
            }
        }
    }

    /**
     * Aborts the job on the driver by combining whatever writer commit messages are available
     * and passing the result to the aggregated committer. Does nothing if the sink has no
     * aggregated committer.
     *
     * @param writerCommitMessageArr the commit messages Spark has collected; entries may be
     *     {@code null} for writers that did not commit
     * @throws RuntimeException wrapping any exception thrown by the aggregated committer
     */
    @Override
    public void abort(WriterCommitMessage[] writerCommitMessageArr) {
        if (this.aggregatedCommitter != null) {
            try {
                this.aggregatedCommitter.abort(combineCommitMessage(writerCommitMessageArr));
            } catch (Exception e) {
                throw new RuntimeException("SinkAggregatedCommitter abort failed in driver", e);
            }
        }
    }

    /**
     * Creates the streaming writer factory.
     *
     * <p>This reuses the batch factory. The explicit cast is required because
     * {@link DataWriterFactory} and {@link StreamingDataWriterFactory} are unrelated interfaces.
     * The cast assumes {@link SeaTunnelSparkDataWriterFactory} implements both.
     *
     * @param physicalWriteInfo physical write information supplied by Spark (not used here)
     * @return the same kind of factory used for batch writes
     */
    @Override
    public StreamingDataWriterFactory createStreamingWriterFactory(PhysicalWriteInfo physicalWriteInfo) {
        return (StreamingDataWriterFactory) createBatchWriterFactory(physicalWriteInfo);
    }

    /**
     * Commits one streaming epoch. The epoch id is ignored; this delegates to the batch
     * {@link #commit(WriterCommitMessage[])}.
     *
     * @param j the epoch id (unused)
     * @param writerCommitMessageArr the commit messages for this epoch
     */
    @Override
    public void commit(long j, WriterCommitMessage[] writerCommitMessageArr) {
        commit(writerCommitMessageArr);
    }

    /**
     * Aborts one streaming epoch. The epoch id is ignored; this delegates to the batch
     * {@link #abort(WriterCommitMessage[])}.
     *
     * @param j the epoch id (unused)
     * @param writerCommitMessageArr the commit messages for this epoch; entries may be {@code null}
     */
    @Override
    public void abort(long j, WriterCommitMessage[] writerCommitMessageArr) {
        abort(writerCommitMessageArr);
    }

    /**
     * Unwraps the per-writer commit infos and combines them with the aggregated committer.
     *
     * @param writerCommitMessageArr the raw Spark commit messages; {@code null} entries are skipped
     * @return an empty list when there is no aggregated committer or no messages; otherwise a
     *     single-element list holding the combined commit info
     */
    @SuppressWarnings("unchecked") // the messages are produced by this sink's own writers
    private List<AggregatedCommitInfoT> combineCommitMessage(WriterCommitMessage[] writerCommitMessageArr) {
        if (this.aggregatedCommitter == null || writerCommitMessageArr.length == 0) {
            return Collections.emptyList();
        }
        List<CommitInfoT> commitInfos = Arrays.stream(writerCommitMessageArr)
                // Spark may hand over null slots (e.g. on abort, for writers that failed or whose
                // message never reached the driver); skip them instead of throwing an NPE.
                .filter(Objects::nonNull)
                .map(message -> ((SeaTunnelSparkWriterCommitMessage<CommitInfoT>) message).getMessage())
                // A writer may also have committed without producing commit info.
                .filter(Objects::nonNull)
                .collect(Collectors.toList());
        return Collections.singletonList(this.aggregatedCommitter.combine(commitInfos));
    }
}
