package org.apache.tez.dag.app;

import com.google.common.collect.Maps;
import java.io.IOException;
import java.net.InetSocketAddress;
import java.net.URISyntaxException;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import org.apache.hadoop.classification.InterfaceAudience;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.ipc.ProtocolSignature;
import org.apache.hadoop.ipc.RPC;
import org.apache.hadoop.ipc.Server;
import org.apache.hadoop.net.NetUtils;
import org.apache.hadoop.security.Credentials;
import org.apache.hadoop.security.authorize.PolicyProvider;
import org.apache.hadoop.security.token.Token;
import org.apache.hadoop.yarn.api.records.ContainerId;
import org.apache.hadoop.yarn.api.records.LocalResource;
import org.apache.hadoop.yarn.util.ConverterUtils;
import org.apache.tez.common.ContainerTask;
import org.apache.tez.common.TezConverterUtils;
import org.apache.tez.common.TezLocalResource;
import org.apache.tez.common.TezTaskUmbilicalProtocol;
import org.apache.tez.common.TezUtils;
import org.apache.tez.common.security.JobTokenIdentifier;
import org.apache.tez.common.security.JobTokenSecretManager;
import org.apache.tez.common.security.TokenCache;
import org.apache.tez.dag.api.TezException;
import org.apache.tez.dag.api.TezUncheckedException;
import org.apache.tez.dag.api.event.VertexStateUpdate;
import org.apache.tez.dag.app.security.authorize.TezAMPolicyProvider;
import org.apache.tez.dag.records.TezTaskAttemptID;
import org.apache.tez.runtime.api.impl.TaskSpec;
import org.apache.tez.runtime.api.impl.TezHeartbeatRequest;
import org.apache.tez.runtime.api.impl.TezHeartbeatResponse;
import org.apache.tez.serviceplugins.api.ContainerEndReason;
import org.apache.tez.serviceplugins.api.TaskAttemptEndReason;
import org.apache.tez.serviceplugins.api.TaskCommunicator;
import org.apache.tez.serviceplugins.api.TaskCommunicatorContext;
import org.apache.tez.serviceplugins.api.TaskHeartbeatRequest;
import org.apache.tez.serviceplugins.api.TaskHeartbeatResponse;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Default {@link TaskCommunicator} that exposes a Hadoop RPC server implementing
 * {@link TezTaskUmbilicalProtocol}. Containers pull their assigned task through
 * {@code getTask} and report progress/events through {@code heartbeat}.
 *
 * <p>Per-container state is kept in {@link ContainerInfo}. Within this class every
 * read-modify-write of that state happens while holding the {@code ContainerInfo}'s own monitor.
 */
@InterfaceAudience.Private
public class TezTaskCommunicatorImpl extends TaskCommunicator {
    private static final Logger LOG = LoggerFactory.getLogger(TezTaskCommunicatorImpl.class);

    /** Returned to requests from invalid / unregistered containers; its shouldDie flag is set. */
    private static final ContainerTask TASK_FOR_INVALID_JVM = new ContainerTask(
            (TaskSpec) null, true, (Map<String, TezLocalResource>) null, (Credentials) null, false);

    private final TezTaskUmbilicalProtocol taskUmbilical;
    /** Containers registered through {@link #registerRunningContainer}, keyed by container id. */
    protected final ConcurrentMap<ContainerId, ContainerInfo> registeredContainers;
    /** The container each currently registered task attempt was assigned to. */
    protected final ConcurrentMap<TezTaskAttemptID, ContainerId> attemptToContainerMap;
    protected final String tokenIdentifier;
    protected final Token<JobTokenIdentifier> sessionToken;
    protected final Configuration conf;
    /** Externally reachable address of the RPC server; set by {@link #startRpcServer()}. */
    protected InetSocketAddress address;
    /** The RPC server; {@code null} before start and after {@link #stopRpcServer()}. */
    protected volatile Server server;

    /**
     * Mutable per-container bookkeeping. Fields other than {@code containerId}, {@code host} and
     * {@code port} are read and written by this class only while holding this object's monitor.
     */
    public static final class ContainerInfo {
        final ContainerId containerId;
        public final String host;
        public final int port;
        /** Last heartbeat response; re-sent if the same request id arrives again. */
        TezHeartbeatResponse lastResponse = null;
        /** Task currently assigned to the container, or {@code null} when idle. */
        TaskSpec taskSpec = null;
        /** Request id of the last heartbeat that was processed. */
        long lastRequestId = 0;
        Map<String, LocalResource> additionalLRs = null;
        Credentials credentials = null;
        boolean credentialsChanged = false;
        /** Whether the container already fetched {@link #taskSpec} via getTask. */
        boolean taskPulled = false;

        ContainerInfo(ContainerId containerId, String host, int port) {
            this.containerId = containerId;
            this.host = host;
            this.port = port;
        }

        /** Clears the task assignment; heartbeat sequencing state is intentionally kept. */
        void reset() {
            this.taskSpec = null;
            this.additionalLRs = null;
            this.credentials = null;
            this.credentialsChanged = false;
            this.taskPulled = false;
        }
    }

    /** RPC endpoint invoked by task containers. */
    private class TezTaskUmbilicalProtocolImpl implements TezTaskUmbilicalProtocol {
        private TezTaskUmbilicalProtocolImpl() {
        }

        /**
         * Hands the requesting container its assigned task.
         *
         * @return the task to run, {@code null} when nothing new is available, or a should-die
         *     task when the request carries no container id or the container is not registered
         */
        public ContainerTask getTask(org.apache.tez.common.ContainerContext containerContext) throws IOException {
            ContainerTask task;
            if (containerContext == null || containerContext.getContainerIdentifier() == null) {
                LOG.info("Invalid task request with an empty containerContext or containerId");
                task = TASK_FOR_INVALID_JVM;
            } else {
                ContainerId containerId = ConverterUtils.toContainerId(containerContext.getContainerIdentifier());
                LOG.debug("Container with id: {} asked for a task", containerId);
                task = TezTaskCommunicatorImpl.this.getContainerTask(containerId);
                if (task != null && !task.shouldDie()) {
                    TezTaskAttemptID attemptId = task.getTaskSpec().getTaskAttemptID();
                    TezTaskCommunicatorImpl.this.getContext().taskSubmitted(attemptId, containerId);
                    TezTaskCommunicatorImpl.this.getContext().taskStartedRemotely(attemptId);
                }
            }
            LOG.debug("getTask returning task: {}", task);
            return task;
        }

        public boolean canCommit(TezTaskAttemptID tezTaskAttemptID) throws IOException {
            return TezTaskCommunicatorImpl.this.getContext().canCommit(tezTaskAttemptID);
        }

        /**
         * Processes a container heartbeat.
         *
         * <p>A request whose id equals the last processed id is answered with the cached previous
         * response. When the request names a task attempt, that attempt must be registered to
         * this container and the request id must be exactly one greater than the last one.
         *
         * @throws TezException if the attempt is not registered to this container or the request
         *     id is out of sequence
         */
        public TezHeartbeatResponse heartbeat(TezHeartbeatRequest request) throws IOException, TezException {
            ContainerId containerId = ConverterUtils.toContainerId(request.getContainerIdentifier());
            long requestId = request.getRequestId();
            LOG.debug("Received heartbeat from container, request={}", request);

            ContainerInfo containerInfo = registeredContainers.get(containerId);
            if (containerInfo == null) {
                LOG.warn("Received task heartbeat from unknown container with id: {}, asking it to die", containerId);
                TezHeartbeatResponse dieResponse = new TezHeartbeatResponse();
                dieResponse.setLastRequestId(requestId);
                dieResponse.setShouldDie();
                return dieResponse;
            }

            synchronized (containerInfo) {
                if (containerInfo.lastRequestId == requestId) {
                    LOG.warn("Old sequenceId received: {}, Re-sending last response to client", requestId);
                    return containerInfo.lastResponse;
                }
                TezHeartbeatResponse response = new TezHeartbeatResponse();
                TezTaskAttemptID attemptId = request.getCurrentTaskAttemptID();
                if (attemptId != null) {
                    // Already holding containerInfo's monitor; no nested lock is needed.
                    ContainerId registeredContainer = attemptToContainerMap.get(attemptId);
                    if (registeredContainer == null || !registeredContainer.equals(containerId)) {
                        throw new TezException("Attempt " + attemptId + " is not recognized for heartbeat");
                    }
                    long expectedRequestId = containerInfo.lastRequestId + 1;
                    if (requestId != expectedRequestId) {
                        throw new TezException("Container " + containerId + " has invalid request id. Expected: "
                                + expectedRequestId + " and actual: " + requestId);
                    }
                    TaskHeartbeatResponse taskResponse = TezTaskCommunicatorImpl.this.getContext().heartbeat(
                            new TaskHeartbeatRequest(request.getContainerIdentifier(), request.getCurrentTaskAttemptID(),
                                    request.getEvents(), request.getStartIndex(), request.getPreRoutedStartIndex(),
                                    request.getMaxEvents()));
                    response.setEvents(taskResponse.getEvents());
                    response.setNextFromEventId(taskResponse.getNextFromEventId());
                    response.setNextPreRoutedEventId(taskResponse.getNextPreRoutedEventId());
                }
                response.setLastRequestId(requestId);
                containerInfo.lastRequestId = requestId;
                containerInfo.lastResponse = response;
                return response;
            }
        }

        public long getProtocolVersion(String protocol, long clientVersion) throws IOException {
            // NOTE(review): hard-coded version; presumably mirrors the protocol's versionID constant — confirm.
            return 19L;
        }

        public ProtocolSignature getProtocolSignature(String protocol, long clientVersion, int clientMethodsHash)
                throws IOException {
            return ProtocolSignature.getProtocolSignature(this, protocol, clientVersion, clientMethodsHash);
        }
    }

    /**
     * @throws TezUncheckedException if the initial user payload cannot be parsed into a Configuration
     */
    public TezTaskCommunicatorImpl(TaskCommunicatorContext taskCommunicatorContext) {
        super(taskCommunicatorContext);
        this.registeredContainers = new ConcurrentHashMap<>();
        this.attemptToContainerMap = new ConcurrentHashMap<>();
        this.taskUmbilical = new TezTaskUmbilicalProtocolImpl();
        this.tokenIdentifier = taskCommunicatorContext.getApplicationAttemptId().getApplicationId().toString();
        this.sessionToken = TokenCache.getSessionToken(taskCommunicatorContext.getAMCredentials());
        try {
            this.conf = TezUtils.createConfFromUserPayload(getContext().getInitialUserPayload());
        } catch (IOException e) {
            throw new TezUncheckedException(
                    "Unable to parse user payload for " + TezTaskCommunicatorImpl.class.getSimpleName(), e);
        }
    }

    @Override
    public void start() {
        startRpcServer();
    }

    @Override
    public void shutdown() {
        stopRpcServer();
    }

    /**
     * Builds and starts the umbilical RPC server on an ephemeral port (optionally constrained by
     * {@code tez.am.task.am.port-range}), secured with the session job token, and records its
     * connect address.
     *
     * @throws TezUncheckedException wrapping any IOException from server creation
     */
    protected void startRpcServer() {
        try {
            JobTokenSecretManager secretManager = new JobTokenSecretManager();
            secretManager.addTokenForJob(tokenIdentifier, sessionToken);
            server = new RPC.Builder(conf)
                    .setProtocol(TezTaskUmbilicalProtocol.class)
                    .setBindAddress("0.0.0.0")
                    .setPort(0)
                    .setInstance(taskUmbilical)
                    .setNumHandlers(conf.getInt("tez.am.task.listener.thread-count", 30))
                    .setPortRangeConfig("tez.am.task.am.port-range")
                    .setSecretManager(secretManager)
                    .build();
            if (conf.getBoolean("hadoop.security.authorization", false)) {
                refreshServiceAcls(conf, new TezAMPolicyProvider());
            }
            server.start();
            InetSocketAddress connectAddress = NetUtils.getConnectAddress(server);
            address = NetUtils.createSocketAddrForHost(
                    connectAddress.getAddress().getCanonicalHostName(), connectAddress.getPort());
            LOG.info("Instantiated TezTaskCommunicator RPC at {}", address);
        } catch (IOException e) {
            throw new TezUncheckedException(e);
        }
    }

    /** Stops the RPC server if it is running. */
    protected void stopRpcServer() {
        if (server != null) {
            server.stop();
            server = null;
        }
    }

    protected Configuration getConf() {
        return conf;
    }

    private void refreshServiceAcls(Configuration configuration, PolicyProvider policyProvider) {
        server.refreshServiceAcl(configuration, policyProvider);
    }

    /**
     * @throws TezUncheckedException if {@code containerId} is already registered
     */
    @Override
    public void registerRunningContainer(ContainerId containerId, String host, int port) {
        if (registeredContainers.putIfAbsent(containerId, new ContainerInfo(containerId, host, port)) != null) {
            throw new TezUncheckedException("Multiple registrations for containerId: " + containerId);
        }
    }

    /** Forgets the container and the attempt mapping of any task it was still assigned. */
    @Override
    public void registerContainerEnd(ContainerId containerId, ContainerEndReason endReason, String diagnostics) {
        ContainerInfo containerInfo = registeredContainers.remove(containerId);
        if (containerInfo == null) {
            return;
        }
        synchronized (containerInfo) {
            if (containerInfo.taskSpec != null && containerInfo.taskSpec.getTaskAttemptID() != null) {
                // Only drop the mapping if it still points at this container.
                attemptToContainerMap.remove(containerInfo.taskSpec.getTaskAttemptID(), containerId);
            }
        }
    }

    /**
     * Assigns a task attempt to a registered, idle container.
     *
     * @throws NullPointerException if the container is not registered
     * @throws TezUncheckedException if the container already has a task, or the attempt is already
     *     registered to some container; in both cases no state is modified
     */
    @Override
    public void registerRunningTaskAttempt(ContainerId containerId, TaskSpec taskSpec,
            Map<String, LocalResource> additionalResources, Credentials credentials,
            boolean credentialsChanged, int priority) {
        ContainerInfo containerInfo = registeredContainers.get(containerId);
        if (containerInfo == null) {
            throw new NullPointerException(String.format("Cannot register task attempt %s to unknown container %s",
                    taskSpec.getTaskAttemptID(), containerId));
        }
        synchronized (containerInfo) {
            if (containerInfo.taskSpec != null) {
                throw new TezUncheckedException("Cannot register task: " + taskSpec.getTaskAttemptID()
                        + " to container: " + containerId + " , with pre-existing assignment: "
                        + containerInfo.taskSpec.getTaskAttemptID());
            }
            // Claim the attempt id before touching containerInfo so a rejected registration
            // does not leave the container holding a task it was never given.
            ContainerId existing = attemptToContainerMap.putIfAbsent(taskSpec.getTaskAttemptID(), containerId);
            if (existing != null) {
                throw new TezUncheckedException("Attempting to register an already registered taskAttempt with id: "
                        + taskSpec.getTaskAttemptID() + " to containerId: " + containerId
                        + ". Already registered to containerId: " + existing);
            }
            containerInfo.taskSpec = taskSpec;
            containerInfo.additionalLRs = additionalResources;
            containerInfo.credentials = credentials;
            containerInfo.credentialsChanged = credentialsChanged;
            containerInfo.taskPulled = false;
        }
    }

    /** Removes the attempt mapping and clears the owning container's task assignment. */
    @Override
    public void unregisterRunningTaskAttempt(TezTaskAttemptID tezTaskAttemptID, TaskAttemptEndReason endReason,
            String diagnostics) {
        ContainerId containerId = attemptToContainerMap.remove(tezTaskAttemptID);
        if (containerId == null) {
            LOG.warn("Unregister task attempt: {} from unknown container", tezTaskAttemptID);
            return;
        }
        ContainerInfo containerInfo = registeredContainers.get(containerId);
        if (containerInfo == null) {
            LOG.warn("Unregister task attempt: {} from non-registered container: {}", tezTaskAttemptID, containerId);
            return;
        }
        synchronized (containerInfo) {
            // The mapping was already removed above; removing again here could drop a fresh
            // registration of the same attempt made in between.
            containerInfo.reset();
        }
    }

    @Override
    public InetSocketAddress getAddress() {
        return address;
    }

    @Override
    public void onVertexStateUpdated(VertexStateUpdate vertexStateUpdate) {
    }

    @Override
    public void dagComplete(int dagIdentifier) {
    }

    @Override
    public Object getMetaInfo() {
        return address;
    }

    protected String getTokenIdentifier() {
        return tokenIdentifier;
    }

    protected Token<JobTokenIdentifier> getSessionToken() {
        return sessionToken;
    }

    public TezTaskUmbilicalProtocol getUmbilical() {
        return taskUmbilical;
    }

    /**
     * Resolves what to send a container that asked for work, and notifies the context that the
     * container is alive when it is registered.
     *
     * @return the should-die task for unregistered containers; {@code null} when no task is
     *     assigned or it was already delivered; otherwise the assigned task, which is then marked
     *     as delivered
     * @throws IOException if a local resource URI cannot be converted
     */
    public ContainerTask getContainerTask(ContainerId containerId) throws IOException {
        ContainerInfo containerInfo = registeredContainers.get(containerId);
        if (containerInfo == null) {
            if (getContext().isKnownContainer(containerId)) {
                LOG.info("Container with id: {} is valid, but no longer registered, and will be killed", containerId);
            } else {
                LOG.info("Container with id: {} is invalid and will be killed", containerId);
            }
            return TASK_FOR_INVALID_JVM;
        }
        synchronized (containerInfo) {
            getContext().containerAlive(containerId);
            if (containerInfo.taskSpec == null) {
                LOG.debug("No task assigned yet for running container: {}", containerId);
                return null;
            }
            if (containerInfo.taskPulled) {
                LOG.debug("Task {} already sent to container: {}", containerInfo.taskSpec.getTaskAttemptID(),
                        containerId);
                return null;
            }
            // Build first: if conversion fails the task must not be marked as delivered.
            ContainerTask task = constructContainerTask(containerInfo);
            containerInfo.taskPulled = true;
            return task;
        }
    }

    private ContainerTask constructContainerTask(ContainerInfo containerInfo) throws IOException {
        return new ContainerTask(containerInfo.taskSpec, false, convertLocalResourceMap(containerInfo.additionalLRs),
                containerInfo.credentials, containerInfo.credentialsChanged);
    }

    /** Converts YARN local resources to Tez ones; a {@code null} map yields an empty map. */
    private Map<String, TezLocalResource> convertLocalResourceMap(Map<String, LocalResource> resources)
            throws IOException {
        Map<String, TezLocalResource> converted = new HashMap<>();
        if (resources != null) {
            for (Map.Entry<String, LocalResource> entry : resources.entrySet()) {
                try {
                    converted.put(entry.getKey(), TezConverterUtils.convertYarnLocalResourceToTez(entry.getValue()));
                } catch (URISyntaxException e) {
                    throw new IOException(e);
                }
            }
        }
        return converted;
    }

    protected ContainerInfo getContainerInfo(ContainerId containerId) {
        return registeredContainers.get(containerId);
    }

    protected ContainerId getContainerForAttempt(TezTaskAttemptID tezTaskAttemptID) {
        return attemptToContainerMap.get(tezTaskAttemptID);
    }
}
