package org.apache.shardingsphere.proxy.frontend.mysql.authentication;

import com.google.common.base.Strings;
import io.netty.channel.ChannelHandlerContext;
import java.net.InetSocketAddress;
import java.net.SocketAddress;
import java.util.Optional;
import org.apache.shardingsphere.db.protocol.CommonConstants;
import org.apache.shardingsphere.db.protocol.mysql.constant.MySQLCapabilityFlag;
import org.apache.shardingsphere.db.protocol.mysql.constant.MySQLCharacterSet;
import org.apache.shardingsphere.db.protocol.mysql.constant.MySQLConnectionPhase;
import org.apache.shardingsphere.db.protocol.mysql.constant.MySQLConstants;
import org.apache.shardingsphere.db.protocol.mysql.constant.MySQLStatusFlag;
import org.apache.shardingsphere.db.protocol.mysql.packet.generic.MySQLErrPacket;
import org.apache.shardingsphere.db.protocol.mysql.packet.generic.MySQLOKPacket;
import org.apache.shardingsphere.db.protocol.mysql.packet.handshake.MySQLAuthSwitchRequestPacket;
import org.apache.shardingsphere.db.protocol.mysql.packet.handshake.MySQLAuthSwitchResponsePacket;
import org.apache.shardingsphere.db.protocol.mysql.packet.handshake.MySQLHandshakePacket;
import org.apache.shardingsphere.db.protocol.mysql.packet.handshake.MySQLHandshakeResponse41Packet;
import org.apache.shardingsphere.db.protocol.mysql.payload.MySQLPacketPayload;
import org.apache.shardingsphere.db.protocol.payload.PacketPayload;
import org.apache.shardingsphere.dialect.mysql.vendor.MySQLVendorError;
import org.apache.shardingsphere.proxy.backend.context.ProxyContext;
import org.apache.shardingsphere.proxy.frontend.authentication.AuthenticationEngine;
import org.apache.shardingsphere.proxy.frontend.authentication.AuthenticationResult;
import org.apache.shardingsphere.proxy.frontend.authentication.AuthenticationResultBuilder;
import org.apache.shardingsphere.proxy.frontend.connection.ConnectionIdGenerator;
import org.apache.shardingsphere.proxy.frontend.mysql.authentication.authenticator.MySQLAuthenticator;
import org.apache.shardingsphere.proxy.frontend.mysql.command.query.binary.MySQLStatementIDGenerator;

/* loaded from: input_file:org/apache/shardingsphere/proxy/frontend/mysql/authentication/MySQLAuthenticationEngine.class */
/**
 * Authentication engine for MySQL frontend connections.
 *
 * <p>Drives the server side of the MySQL connection phase:
 * <ol>
 *     <li>{@link #handshake} sends the initial handshake packet;</li>
 *     <li>{@link #authenticate} parses the client's {@code HandshakeResponse41}, records the negotiated
 *     character set on the channel, and, if the client's auth plugin differs from the one configured
 *     for the user, sends an {@code AuthSwitchRequest} and waits for the switch response;</li>
 *     <li>finally the credentials are checked and an OK or ERR packet is written.</li>
 * </ol>
 *
 * <p>The engine keeps per-connection state (phase, sequence id, auth response).
 * NOTE(review): presumably one instance is created per client channel; confirm with the caller.</p>
 */
public final class MySQLAuthenticationEngine implements AuthenticationEngine {

    /** Status flags advertised in the final OK packet (autocommit only). */
    private static final int DEFAULT_STATUS_FLAG = MySQLStatusFlag.SERVER_STATUS_AUTOCOMMIT.getValue();

    private final MySQLAuthenticationHandler authenticationHandler = new MySQLAuthenticationHandler();

    private MySQLConnectionPhase connectionPhase = MySQLConnectionPhase.INITIAL_HANDSHAKE;

    /** Sequence id of the most recent packet received from or sent to the client. */
    private int sequenceId;

    /** Auth response bytes supplied by the client in the handshake or auth-switch response. */
    private byte[] authResponse;

    /** Result of the fast-path phase; supplies username and database for the final login check. */
    private AuthenticationResult currentAuthResult;

    /**
     * Allocates a connection id, sends the initial handshake packet and enters the fast-path phase.
     *
     * @param channelHandlerContext channel handler context of the client connection
     * @return the allocated connection id
     */
    public int handshake(ChannelHandlerContext channelHandlerContext) {
        int connectionId = ConnectionIdGenerator.getInstance().nextId();
        connectionPhase = MySQLConnectionPhase.AUTH_PHASE_FAST_PATH;
        channelHandlerContext.writeAndFlush(new MySQLHandshakePacket(connectionId, authenticationHandler.getAuthPluginData()));
        MySQLStatementIDGenerator.getInstance().registerConnection(connectionId);
        return connectionId;
    }

    /**
     * Processes one client packet of the connection phase.
     *
     * <p>On login failure an ERR packet is written, the channel is closed and a "continued" result is
     * returned. On success an OK packet is written and a "finished" result is returned.</p>
     *
     * @param channelHandlerContext channel handler context of the client connection
     * @param packetPayload payload of the received packet (expected to be a {@link MySQLPacketPayload})
     * @return the authentication result for this step
     * @throws IllegalStateException if called before {@link #handshake} started the connection phase
     */
    public AuthenticationResult authenticate(ChannelHandlerContext channelHandlerContext, PacketPayload packetPayload) {
        if (MySQLConnectionPhase.AUTH_PHASE_FAST_PATH == connectionPhase) {
            currentAuthResult = authPhaseFastPath(channelHandlerContext, packetPayload);
            if (!currentAuthResult.isFinished()) {
                return currentAuthResult;
            }
        } else if (MySQLConnectionPhase.AUTHENTICATION_METHOD_MISMATCH == connectionPhase) {
            authenticationMethodMismatch((MySQLPacketPayload) packetPayload);
        }
        if (null == currentAuthResult) {
            // Only reachable when handshake() was never called; fail with a clear message instead of an NPE below.
            throw new IllegalStateException("MySQL authentication attempted before the initial handshake was sent");
        }
        String hostAddress = getHostAddress(channelHandlerContext);
        Optional<MySQLVendorError> loginError = authenticationHandler.login(currentAuthResult.getUsername(), hostAddress, authResponse, currentAuthResult.getDatabase());
        if (loginError.isPresent()) {
            channelHandlerContext.writeAndFlush(createErrorPacket(loginError.get(), hostAddress));
            channelHandlerContext.close();
            return AuthenticationResultBuilder.continued();
        }
        channelHandlerContext.writeAndFlush(new MySQLOKPacket(++sequenceId, DEFAULT_STATUS_FLAG));
        return AuthenticationResultBuilder.finished(currentAuthResult.getUsername(), hostAddress, currentAuthResult.getDatabase());
    }

    /**
     * Handles the client's handshake response: stores charset and auth data, rejects unknown databases,
     * and decides whether an auth-method switch is required.
     */
    private AuthenticationResult authPhaseFastPath(ChannelHandlerContext channelHandlerContext, PacketPayload packetPayload) {
        MySQLHandshakeResponse41Packet handshakeResponse = new MySQLHandshakeResponse41Packet((MySQLPacketPayload) packetPayload);
        authResponse = handshakeResponse.getAuthResponse();
        sequenceId = handshakeResponse.getSequenceId();
        MySQLCharacterSet characterSet = MySQLCharacterSet.findById(handshakeResponse.getCharacterSet());
        channelHandlerContext.channel().attr(CommonConstants.CHARSET_ATTRIBUTE_KEY).set(characterSet.getCharset());
        channelHandlerContext.channel().attr(MySQLConstants.MYSQL_CHARACTER_SET_ATTRIBUTE_KEY).set(characterSet);
        String database = handshakeResponse.getDatabase();
        if (!Strings.isNullOrEmpty(database) && !ProxyContext.getInstance().databaseExists(database)) {
            // NOTE(review): this reports an unknown database before credentials are verified, which lets an
            // unauthenticated client probe database names; confirm this is acceptable.
            channelHandlerContext.writeAndFlush(new MySQLErrPacket(++sequenceId, MySQLVendorError.ER_BAD_DB_ERROR, new Object[]{database}));
            channelHandlerContext.close();
            return AuthenticationResultBuilder.continued();
        }
        String username = handshakeResponse.getUsername();
        String hostAddress = getHostAddress(channelHandlerContext);
        MySQLAuthenticator authenticator = authenticationHandler.getAuthenticator(username, hostAddress);
        if (!isClientPluginAuth(handshakeResponse) || authenticator.getAuthenticationMethodName().equals(handshakeResponse.getAuthPluginName())) {
            return AuthenticationResultBuilder.finished(username, hostAddress, database);
        }
        // Client used a different auth plugin than the one configured for this user: ask it to switch.
        connectionPhase = MySQLConnectionPhase.AUTHENTICATION_METHOD_MISMATCH;
        channelHandlerContext.writeAndFlush(new MySQLAuthSwitchRequestPacket(++sequenceId, authenticator.getAuthenticationMethodName(), authenticationHandler.getAuthPluginData()));
        return AuthenticationResultBuilder.continued(username, hostAddress, database);
    }

    /** Returns whether the client advertised the {@code CLIENT_PLUGIN_AUTH} capability. */
    private boolean isClientPluginAuth(MySQLHandshakeResponse41Packet handshakeResponse) {
        return 0 != (handshakeResponse.getCapabilityFlags() & MySQLCapabilityFlag.CLIENT_PLUGIN_AUTH.getValue());
    }

    /** Reads the auth-switch response, replacing the stored sequence id and auth response. */
    private void authenticationMethodMismatch(MySQLPacketPayload payload) {
        MySQLAuthSwitchResponsePacket switchResponse = new MySQLAuthSwitchResponsePacket(payload);
        sequenceId = switchResponse.getSequenceId();
        authResponse = switchResponse.getAuthPluginResponse();
    }

    /**
     * Builds the ERR packet for a failed login. {@code ER_DBACCESS_DENIED_ERROR} is reported as-is;
     * every other error is reported as {@code ER_ACCESS_DENIED_ERROR}.
     */
    private MySQLErrPacket createErrorPacket(MySQLVendorError vendorError, String hostAddress) {
        int errorSequenceId = ++sequenceId;
        if (MySQLVendorError.ER_DBACCESS_DENIED_ERROR == vendorError) {
            return new MySQLErrPacket(errorSequenceId, MySQLVendorError.ER_DBACCESS_DENIED_ERROR, new Object[]{currentAuthResult.getUsername(), hostAddress, currentAuthResult.getDatabase()});
        }
        return new MySQLErrPacket(errorSequenceId, MySQLVendorError.ER_ACCESS_DENIED_ERROR, new Object[]{currentAuthResult.getUsername(), hostAddress, getErrorMessage()});
    }

    /** Returns "YES" if the client supplied a non-empty auth response, otherwise "NO" (null counts as empty). */
    private String getErrorMessage() {
        return null == authResponse || 0 == authResponse.length ? "NO" : "YES";
    }

    /**
     * Returns the client's host address: the literal IP for a resolved socket address, the host string for an
     * unresolved one, and {@code toString()} for any other address type.
     *
     * @throws IllegalStateException if the channel has no remote address
     */
    private String getHostAddress(ChannelHandlerContext channelHandlerContext) {
        SocketAddress remoteAddress = channelHandlerContext.channel().remoteAddress();
        if (null == remoteAddress) {
            throw new IllegalStateException("Client channel has no remote address");
        }
        if (remoteAddress instanceof InetSocketAddress) {
            InetSocketAddress inetSocketAddress = (InetSocketAddress) remoteAddress;
            // getAddress() is null for an unresolved address; use the host string rather than throwing an NPE.
            return inetSocketAddress.isUnresolved() ? inetSocketAddress.getHostString() : inetSocketAddress.getAddress().getHostAddress();
        }
        return remoteAddress.toString();
    }
}
