/*
 * Decompiled with CFR 0.152.
 */
package org.apache.seatunnel.engine.server.task;

import com.hazelcast.cluster.Address;
import com.hazelcast.logging.ILogger;
import com.hazelcast.logging.Logger;
import java.io.IOException;
import java.net.URL;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.stream.Collectors;
import lombok.NonNull;
import org.apache.seatunnel.api.serialization.Serializer;
import org.apache.seatunnel.api.sink.SinkAggregatedCommitter;
import org.apache.seatunnel.engine.common.utils.ExceptionUtil;
import org.apache.seatunnel.engine.core.dag.actions.SinkAction;
import org.apache.seatunnel.engine.server.checkpoint.ActionSubtaskState;
import org.apache.seatunnel.engine.server.checkpoint.CheckpointBarrier;
import org.apache.seatunnel.engine.server.checkpoint.operation.TaskAcknowledgeOperation;
import org.apache.seatunnel.engine.server.execution.ProgressState;
import org.apache.seatunnel.engine.server.execution.TaskLocation;
import org.apache.seatunnel.engine.server.task.CoordinatorTask;
import org.apache.seatunnel.engine.server.task.record.Barrier;
import org.apache.seatunnel.engine.server.task.statemachine.SeaTunnelTaskState;

/**
 * Coordinator task that hosts the {@link SinkAggregatedCommitter} of a sink.
 *
 * <p>Sink writers announce themselves through {@link #receivedWriterRegister} and deliver their
 * commit information per checkpoint through {@link #receivedWriterCommitInfo}. Once
 * {@link #triggerBarrier} has been invoked {@code maxWriterSize} times for the same barrier id
 * (the sink's parallelism), the commit information cached for that barrier is combined into one
 * aggregated commit info, serialized and acknowledged to the master. Aggregated commit infos are
 * committed in {@link #notifyCheckpointComplete} and aborted in {@link #notifyCheckpointAborted}.
 *
 * @param <CommandInfoT> type of the commit information received from the sink writers
 * @param <AggregatedCommitInfoT> type produced by combining the writers' commit information
 */
public class SinkAggregatedCommitterTask<CommandInfoT, AggregatedCommitInfoT>
        extends CoordinatorTask {
    private static final ILogger LOGGER = Logger.getLogger(SinkAggregatedCommitterTask.class);
    private static final long serialVersionUID = 5906594537520393503L;

    /** Current state of the task's state machine; advanced by {@link #stateProcess()}. */
    private SeaTunnelTaskState currState;
    private final SinkAction<?, ?, CommandInfoT, AggregatedCommitInfoT> sink;
    /** Parallelism of the sink; used as the number of writers/barriers to wait for. */
    private final int maxWriterSize;
    private final SinkAggregatedCommitter<CommandInfoT, AggregatedCommitInfoT> aggregatedCommitter;
    /** Transient: obtained from the sink again in {@link #init()}. */
    private transient Serializer<AggregatedCommitInfoT> aggregatedCommitInfoSerializer;
    /** Writer task id -> address the writer registered with. */
    private Map<Long, Address> writerAddressMap;
    /** Checkpoint id -> commit information received from the writers, not yet combined. */
    private ConcurrentMap<Long, List<CommandInfoT>> commitInfoCache;
    /** Checkpoint id -> aggregated commit info waiting to be committed or aborted. */
    private ConcurrentMap<Long, List<AggregatedCommitInfoT>> checkpointCommitInfoMap;
    /** Barrier id -> number of times {@link #triggerBarrier} was invoked for it so far. */
    private Map<Long, Integer> checkpointBarrierCounter;
    /** Completed when {@link #close()} has run. */
    private CompletableFuture<Void> completableFuture;
    /** Set once at least {@code maxWriterSize} writers have registered. */
    private volatile boolean receivedSinkWriter;

    /**
     * Creates the task.
     *
     * @param jobID id of the job this task belongs to
     * @param taskID location of this task
     * @param sink sink action; its parallelism determines how many writers are awaited
     * @param aggregatedCommitter committer used to combine, commit and abort commit information
     */
    public SinkAggregatedCommitterTask(long jobID, TaskLocation taskID, SinkAction<?, ?, CommandInfoT, AggregatedCommitInfoT> sink, SinkAggregatedCommitter<CommandInfoT, AggregatedCommitInfoT> aggregatedCommitter) {
        super(jobID, taskID);
        this.sink = sink;
        this.aggregatedCommitter = aggregatedCommitter;
        this.maxWriterSize = sink.getParallelism();
        this.receivedSinkWriter = false;
    }

    /**
     * Initializes the parent task, resets the state machine to {@code INIT}, creates the
     * bookkeeping maps and fetches the aggregated-commit-info serializer from the sink.
     *
     * @throws Exception if the parent initialization fails or the sink provides no serializer
     */
    @Override
    public void init() throws Exception {
        super.init();
        currState = SeaTunnelTaskState.INIT;
        checkpointBarrierCounter = new ConcurrentHashMap<>();
        commitInfoCache = new ConcurrentHashMap<>();
        writerAddressMap = new ConcurrentHashMap<>();
        checkpointCommitInfoMap = new ConcurrentHashMap<>();
        completableFuture = new CompletableFuture<>();
        aggregatedCommitInfoSerializer = sink.getSink().getAggregatedCommitInfoSerializer().get();
        LOGGER.info("starting seatunnel sink aggregated committer task, sink name: " + sink.getName());
    }

    /**
     * Records the address of a sink writer. Once as many distinct writer task ids as the sink's
     * parallelism are known, the task is allowed to leave the {@code STARTING} state.
     *
     * @param writerID location of the registering writer task
     * @param address address of the registering writer
     */
    public void receivedWriterRegister(TaskLocation writerID, Address address) {
        writerAddressMap.put(writerID.getTaskID(), address);
        if (writerAddressMap.size() >= maxWriterSize) {
            receivedSinkWriter = true;
        }
    }

    /**
     * Performs one step of the state machine and reports the resulting progress.
     *
     * @return the current progress state
     * @throws Exception if the state transition fails
     */
    @Override
    @NonNull
    public ProgressState call() throws Exception {
        stateProcess();
        return progress.toState();
    }

    /**
     * Advances the state machine by at most one transition, driven by the inherited lifecycle
     * flags ({@code restoreComplete}, {@code startCalled}, {@code prepareCloseStatus},
     * {@code closeCalled}) and by {@code receivedSinkWriter}. While {@code RUNNING} or
     * {@code PREPARE_CLOSE} without a pending transition, the call sleeps for 100 ms.
     *
     * @throws Exception if reporting the status, sleeping or closing fails
     * @throws IllegalArgumentException if the current state is not handled here
     */
    protected void stateProcess() throws Exception {
        switch (currState) {
            case INIT: {
                currState = SeaTunnelTaskState.WAITING_RESTORE;
                reportTaskStatus(SeaTunnelTaskState.WAITING_RESTORE);
                break;
            }
            case WAITING_RESTORE: {
                if (restoreComplete) {
                    currState = SeaTunnelTaskState.READY_START;
                    reportTaskStatus(SeaTunnelTaskState.READY_START);
                }
                break;
            }
            case READY_START: {
                if (startCalled) {
                    currState = SeaTunnelTaskState.STARTING;
                }
                break;
            }
            case STARTING: {
                if (receivedSinkWriter) {
                    currState = SeaTunnelTaskState.RUNNING;
                }
                break;
            }
            case RUNNING: {
                if (prepareCloseStatus) {
                    currState = SeaTunnelTaskState.PREPARE_CLOSE;
                } else {
                    Thread.sleep(100L);
                }
                break;
            }
            case PREPARE_CLOSE: {
                if (closeCalled) {
                    currState = SeaTunnelTaskState.CLOSED;
                } else {
                    Thread.sleep(100L);
                }
                break;
            }
            case CLOSED: {
                close();
                return;
            }
            case CANCELLING: {
                close();
                currState = SeaTunnelTaskState.CANCELED;
                return;
            }
            default: {
                throw new IllegalArgumentException("Unknown Enumerator State: " + currState);
            }
        }
    }

    /**
     * Closes the aggregated committer, marks the progress as done and completes the internal
     * future.
     *
     * @throws IOException if closing the aggregated committer fails; the progress is still
     *     marked as done and the future is still completed in that case
     */
    @Override
    public void close() throws IOException {
        try {
            aggregatedCommitter.close();
        } finally {
            // Run the bookkeeping even when the committer fails to close, so a failing
            // committer cannot leave the progress unfinished or the future uncompleted.
            progress.done();
            completableFuture.complete(null);
        }
    }

    /**
     * Counts one arrival of the given barrier. Nothing else happens until the barrier has been
     * triggered {@code maxWriterSize} times. On that last arrival the prepare-close flag is
     * recorded if the barrier requests it, and for snapshot barriers the cached commit
     * information is combined, stored for the later commit/abort and acknowledged to the master
     * in serialized form.
     *
     * @param barrier the barrier that arrived
     * @throws Exception if combining, serializing or sending the acknowledgement fails
     */
    @Override
    public void triggerBarrier(Barrier barrier) throws Exception {
        final long barrierId = barrier.getId();
        int arrived = checkpointBarrierCounter.merge(barrierId, 1, Integer::sum);
        if (arrived != maxWriterSize) {
            return;
        }
        // The counter has served its purpose; drop it so the map does not grow by one entry
        // per checkpoint for the whole lifetime of the job.
        checkpointBarrierCounter.remove(barrierId);

        if (barrier.prepareClose()) {
            prepareCloseStatus = true;
            prepareCloseBarrierId.set(barrierId);
        }
        if (!barrier.snapshot()) {
            return;
        }

        List<CommandInfoT> cachedCommitInfos = commitInfoCache.get(barrierId);
        if (cachedCommitInfos != null) {
            AggregatedCommitInfoT combined = aggregatedCommitter.combine(cachedCommitInfos);
            checkpointCommitInfoMap.put(barrierId, Collections.singletonList(combined));
            // The raw commit infos are now represented by the aggregated one; release them.
            commitInfoCache.remove(barrierId);
        }
        List<AggregatedCommitInfoT> aggregated = checkpointCommitInfoMap.get(barrierId);
        if (aggregated == null) {
            aggregated = Collections.emptyList();
        }
        List<byte[]> states = SinkAggregatedCommitterTask.serializeStates(aggregatedCommitInfoSerializer, aggregated);
        getExecutionContext().sendToMaster(new TaskAcknowledgeOperation(taskLocation, (CheckpointBarrier) barrier, Collections.singletonList(new ActionSubtaskState(sink.getId(), -1, states))));
    }

    /**
     * Deserializes every state entry of the given action states into aggregated commit infos,
     * commits them and marks the restore as complete.
     *
     * @param actionStateList states to restore from
     * @throws Exception if deserialization or the commit fails
     */
    @Override
    public void restoreState(List<ActionSubtaskState> actionStateList) throws Exception {
        List<AggregatedCommitInfoT> aggregatedCommitInfos = actionStateList.stream()
                .map(ActionSubtaskState::getState)
                .flatMap(Collection::stream)
                .map(bytes -> ExceptionUtil.sneaky(() -> aggregatedCommitInfoSerializer.deserialize(bytes)))
                .collect(Collectors.toList());
        aggregatedCommitter.commit(aggregatedCommitInfos);
        restoreComplete = true;
    }

    /**
     * Caches commit information delivered by a writer for the given checkpoint.
     *
     * @param checkpointID checkpoint the commit information belongs to
     * @param commitInfos commit information of one writer
     */
    public void receivedWriterCommitInfo(long checkpointID, CommandInfoT commitInfos) {
        // Use the list returned by computeIfAbsent directly instead of looking it up again.
        commitInfoCache.computeIfAbsent(checkpointID, id -> new CopyOnWriteArrayList<>()).add(commitInfos);
    }

    /**
     * @return a new set containing the jar URLs of the sink action
     */
    @Override
    public Set<URL> getJarsUrl() {
        return new HashSet<URL>(sink.getJarUrls());
    }

    /**
     * Commits all aggregated commit infos stored for checkpoints up to and including
     * {@code checkpointId}, in ascending checkpoint order, removes them from the pending map and
     * then calls {@code tryClose}.
     *
     * @param checkpointId id of the completed checkpoint
     * @throws Exception if the commit or {@code tryClose} fails
     */
    @Override
    public void notifyCheckpointComplete(long checkpointId) throws Exception {
        // Sort the due ids so commits happen in checkpoint order rather than in the map's
        // unspecified iteration order.
        List<Long> dueCheckpointIds = checkpointCommitInfoMap.keySet().stream()
                .filter(id -> id <= checkpointId)
                .sorted()
                .collect(Collectors.toList());
        List<AggregatedCommitInfoT> toCommit = new ArrayList<>();
        for (Long dueId : dueCheckpointIds) {
            List<AggregatedCommitInfoT> infos = checkpointCommitInfoMap.remove(dueId);
            if (infos != null) {
                toCommit.addAll(infos);
            }
        }
        aggregatedCommitter.commit(toCommit);
        tryClose(checkpointId);
    }

    /**
     * Aborts the aggregated commit info stored for the given checkpoint, removes it from the
     * pending map and then calls {@code tryClose}. When nothing is stored for the checkpoint the
     * committer is invoked with an empty list rather than {@code null}.
     *
     * @param checkpointId id of the aborted checkpoint
     * @throws Exception if the abort or {@code tryClose} fails
     */
    @Override
    public void notifyCheckpointAborted(long checkpointId) throws Exception {
        List<AggregatedCommitInfoT> aborted = checkpointCommitInfoMap.remove(checkpointId);
        if (aborted == null) {
            aborted = Collections.emptyList();
        }
        aggregatedCommitter.abort(aborted);
        tryClose(checkpointId);
    }
}

