package io.vertx.mysqlclient.impl.codec;
import io.netty.buffer.ByteBuf;
import io.vertx.core.impl.logging.Logger;
import io.vertx.core.impl.logging.LoggerFactory;
import io.vertx.mysqlclient.MySQLAuthenticationPlugin;
import io.vertx.mysqlclient.SslMode;
import io.vertx.mysqlclient.impl.MySQLDatabaseMetadata;
import io.vertx.mysqlclient.impl.command.InitialHandshakeCommand;
import io.vertx.mysqlclient.impl.protocol.CapabilitiesFlag;
import io.vertx.mysqlclient.impl.util.BufferUtils;
import io.vertx.mysqlclient.impl.util.CachingSha2Authenticator;
import io.vertx.mysqlclient.impl.util.Native41Authenticator;
import io.vertx.sqlclient.impl.Connection;
import io.vertx.sqlclient.impl.command.CommandResponse;
import java.nio.charset.StandardCharsets;
import java.util.Map;
import static io.vertx.mysqlclient.impl.protocol.CapabilitiesFlag.*;
import static io.vertx.mysqlclient.impl.protocol.Packets.*;
/**
 * Codec for {@link InitialHandshakeCommand}.
 *
 * <p>It decodes the server's initial handshake packet and, depending on the configured {@link SslMode},
 * upgrades the socket to TLS. It then sends the handshake response and processes the following
 * authentication packets until the server answers with an OK or an error packet.
 */
class InitialHandshakeCommandCodec extends AuthenticationCommandBaseCodec<Connection, InitialHandshakeCommand> {

  private static final Logger LOGGER = LoggerFactory.getLogger(InitialHandshakeCommandCodec.class);

  /** Length of {@code auth-plugin-data-part-1} in the initial handshake packet. */
  private static final int AUTH_PLUGIN_DATA_PART1_LENGTH = 8;

  /**
   * Minimum length of {@code auth-plugin-data-part-2}.
   *
   * <p>The part is {@code MAX(13, length of auth-plugin-data - 8)} bytes long.
   */
  private static final int AUTH_PLUGIN_DATA_PART2_MIN_LENGTH = 13;

  private static final int ST_CONNECTING = 0;
  private static final int ST_AUTHENTICATING = 1;
  private static final int ST_CONNECTED = 2;

  private int status = ST_CONNECTING;

  InitialHandshakeCommandCodec(InitialHandshakeCommand cmd) {
    super(cmd);
  }

  /**
   * Dispatches an incoming packet payload according to the current connection phase.
   *
   * <p>The first packet is treated as the server's initial handshake; every later packet is
   * treated as part of the authentication exchange. Packets received after the connection is
   * established are ignored here.
   */
  @Override
  void decodePayload(ByteBuf payload, int payloadLength) {
    switch (status) {
      case ST_CONNECTING:
        handleInitialHandshake(payload);
        status = ST_AUTHENTICATING;
        break;
      case ST_AUTHENTICATING:
        handleAuthentication(payload);
        break;
    }
  }

  /**
   * Parses the server's initial handshake packet and negotiates the client capability flags.
   *
   * <p>It then sends the handshake response. When the {@link SslMode} requires or prefers (and the
   * server supports) TLS, an SSL request is sent first and the socket is upgraded.
   *
   * <p>Failures are reported through {@code completionHandler} rather than thrown.
   */
  private void handleInitialHandshake(ByteBuf payload) {
    encoder.clientCapabilitiesFlag = cmd.initialCapabilitiesFlags();
    encoder.encodingCharset = cmd.charsetEncoding();

    payload.skipBytes(1); // protocol version (not used)
    String serverVersion = BufferUtils.readNullTerminatedString(payload, StandardCharsets.US_ASCII);
    MySQLDatabaseMetadata md = MySQLDatabaseMetadata.parse(serverVersion);
    encoder.socketConnection.metaData = md;
    if (supportsDeprecateEof(md)) {
      encoder.clientCapabilitiesFlag |= CLIENT_DEPRECATE_EOF;
    }

    payload.skipBytes(4); // connection id (not used)
    this.authPluginData = new byte[NONCE_LENGTH];
    payload.readBytes(authPluginData, 0, AUTH_PLUGIN_DATA_PART1_LENGTH);
    payload.skipBytes(1); // filler byte (not used)
    int lowerServerCapabilitiesFlags = payload.readUnsignedShortLE();
    payload.skipBytes(1); // character set (not used)
    payload.skipBytes(2); // status flags (not used)
    int capabilityFlagsUpper = payload.readUnsignedShortLE();
    final int serverCapabilitiesFlags = lowerServerCapabilitiesFlags | (capabilityFlagsUpper << 16);

    // The length byte is always present on the wire.
    // Its value is only meaningful when the server supports CLIENT_PLUGIN_AUTH.
    short lenOfAuthPluginData = payload.readUnsignedByte();
    if ((serverCapabilitiesFlags & CLIENT_PLUGIN_AUTH) == 0) {
      lenOfAuthPluginData = 0;
    }
    payload.skipBytes(10); // reserved

    // auth-plugin-data-part-2 spans MAX(13, lenOfAuthPluginData - 8) bytes.
    // Only the bytes that fit into the nonce are copied; the remainder is skipped.
    // This keeps a server that advertises more auth data from overflowing authPluginData.
    // NOTE(review): assumes NONCE_LENGTH - 8 <= 13, i.e. NONCE_LENGTH <= 21.
    int part2Length = Math.max(AUTH_PLUGIN_DATA_PART2_MIN_LENGTH, lenOfAuthPluginData - AUTH_PLUGIN_DATA_PART1_LENGTH);
    int part2NonceBytes = NONCE_LENGTH - AUTH_PLUGIN_DATA_PART1_LENGTH;
    payload.readBytes(authPluginData, AUTH_PLUGIN_DATA_PART1_LENGTH, part2NonceBytes);
    payload.skipBytes(part2Length - part2NonceBytes);

    final String serverAuthPluginName = BufferUtils.readNullTerminatedString(payload, StandardCharsets.UTF_8);

    boolean upgradeToSsl;
    SslMode sslMode = cmd.sslMode();
    switch (sslMode) {
      case DISABLED:
        upgradeToSsl = false;
        break;
      case PREFERRED:
        upgradeToSsl = isTlsSupportedByServer(serverCapabilitiesFlags);
        break;
      case REQUIRED:
      case VERIFY_CA:
      case VERIFY_IDENTITY:
        if (!isTlsSupportedByServer(serverCapabilitiesFlags)) {
          // The server did not advertise CLIENT_SSL, so an SSL request would not be honoured.
          // Fail fast with a clear message instead of failing later with an obscure error.
          completionHandler.handle(CommandResponse.failure(new IllegalStateException(
            "SSL mode " + sslMode + " requires an encrypted connection, but the server does not support SSL")));
          return;
        }
        upgradeToSsl = true;
        break;
      default:
        completionHandler.handle(CommandResponse.failure(new IllegalStateException("Unknown SSL mode to handle: " + sslMode)));
        return;
    }

    if (upgradeToSsl) {
      encoder.clientCapabilitiesFlag |= CLIENT_SSL;
      sendSslRequest();
      encoder.socketConnection.upgradeToSsl(upgrade -> {
        if (upgrade.succeeded()) {
          doSendHandshakeResponseMessage(serverAuthPluginName, cmd.authenticationPlugin(), authPluginData, serverCapabilitiesFlags);
        } else {
          completionHandler.handle(CommandResponse.failure(upgrade.cause()));
        }
      });
    } else {
      doSendHandshakeResponseMessage(serverAuthPluginName, cmd.authenticationPlugin(), authPluginData, serverCapabilitiesFlags);
    }
  }

  /**
   * Returns whether the client should request {@code CLIENT_DEPRECATE_EOF}.
   *
   * <p>The flag is withheld only for 5.x servers older than 5.7.5.
   */
  private static boolean supportsDeprecateEof(MySQLDatabaseMetadata md) {
    boolean olderThan575 = md.majorVersion() == 5
      && (md.minorVersion() < 7 || (md.minorVersion() == 7 && md.microVersion() < 5));
    return !olderThan575;
  }

  /**
   * Restricts the client capabilities to those the server supports and sends the handshake response.
   *
   * @param serverAuthPluginName plugin announced by the server; used when {@code authPlugin} is DEFAULT
   * @param authPlugin the plugin configured on the command
   * @param nonce the scramble received in the initial handshake
   * @param serverCapabilitiesFlags capability flags advertised by the server
   */
  private void doSendHandshakeResponseMessage(String serverAuthPluginName, MySQLAuthenticationPlugin authPlugin, byte[] nonce, int serverCapabilitiesFlags) {
    Map<String, String> clientConnectionAttributes = cmd.connectionAttributes();
    encoder.clientCapabilitiesFlag &= serverCapabilitiesFlags;
    String clientPluginName = authPlugin == MySQLAuthenticationPlugin.DEFAULT ? serverAuthPluginName : authPlugin.value;
    sendHandshakeResponseMessage(cmd.username(), cmd.password(), cmd.database(), nonce, clientPluginName, clientConnectionAttributes);
  }

  /**
   * Handles one packet of the authentication phase, keyed on its first byte.
   *
   * <p>The first byte is peeked without advancing the reader index:
   * <ul>
   *   <li>OK completes the command with the connection.</li>
   *   <li>ERROR is delegated to the error handler.</li>
   *   <li>Auth-switch and auth-more-data packets are answered using the password bytes.</li>
   *   <li>Anything else fails the command.</li>
   * </ul>
   */
  private void handleAuthentication(ByteBuf payload) {
    int header = payload.getUnsignedByte(payload.readerIndex());
    switch (header) {
      case OK_PACKET_HEADER:
        status = ST_CONNECTED;
        completionHandler.handle(CommandResponse.success(cmd.connection()));
        break;
      case ERROR_PACKET_HEADER:
        handleErrorPacketPayload(payload);
        break;
      case AUTH_SWITCH_REQUEST_STATUS_FLAG:
        handleAuthSwitchRequest(cmd.password().getBytes(StandardCharsets.UTF_8), payload);
        break;
      case AUTH_MORE_DATA_STATUS_FLAG:
        handleAuthMoreData(cmd.password().getBytes(StandardCharsets.UTF_8), payload);
        break;
      default:
        completionHandler.handle(CommandResponse.failure(new IllegalStateException("Unhandled state with header: " + header)));
    }
  }

  /**
   * Answers an auth-switch request.
   *
   * <p>The method reads the requested plugin name and the new nonce, then sends the password
   * encoded for that plugin. Unsupported plugins fail the command.
   */
  private void handleAuthSwitchRequest(byte[] password, ByteBuf payload) {
    payload.skipBytes(1); // status flag already inspected by the caller
    String pluginName = BufferUtils.readNullTerminatedString(payload, StandardCharsets.UTF_8);
    byte[] nonce = new byte[NONCE_LENGTH];
    // The switch request may carry fewer than NONCE_LENGTH bytes of plugin data.
    // Read what is available instead of throwing; unread slots stay zero.
    payload.readBytes(nonce, 0, Math.min(NONCE_LENGTH, payload.readableBytes()));
    byte[] authResponse;
    switch (pluginName) {
      case "mysql_native_password":
        authResponse = Native41Authenticator.encode(password, nonce);
        break;
      case "caching_sha2_password":
        authResponse = CachingSha2Authenticator.encode(password, nonce);
        break;
      case "mysql_clear_password":
        authResponse = password;
        break;
      default:
        completionHandler.handle(CommandResponse.failure(new UnsupportedOperationException("Unsupported authentication method: " + pluginName)));
        return;
    }
    sendBytesAsPacket(authResponse);
  }

  /**
   * Sends the fixed-size SSL request packet.
   *
   * <p>The packet is a 4-byte header plus a 32-byte payload: capabilities, max packet size,
   * collation and 23 zero bytes. It is sent before the socket is upgraded to TLS.
   */
  private void sendSslRequest() {
    ByteBuf packet = allocateBuffer(36);
    packet.writeMediumLE(32);
    packet.writeByte(sequenceId);
    packet.writeIntLE(encoder.clientCapabilitiesFlag);
    packet.writeIntLE(PACKET_PAYLOAD_LENGTH_LIMIT);
    packet.writeByte(cmd.collation().collationId());
    packet.writeZero(23);
    sendNonSplitPacket(packet);
  }

  /**
   * Builds and sends the handshake response packet.
   *
   * <p>Optional fields (auth response format, database, plugin name, connection attributes) are
   * written according to the negotiated client capability flags.
   *
   * <p>An unknown plugin name falls back to {@code mysql_native_password}, with a warning. The
   * payload length is patched into the header once the packet is complete.
   */
  private void sendHandshakeResponseMessage(String username, String password, String database, byte[] nonce, String clientPluginName, Map<String, String> clientConnectionAttributes) {
    ByteBuf packet = allocateBuffer();
    int packetStartIdx = packet.writerIndex();
    packet.writeMediumLE(0); // placeholder for the payload length, set below
    packet.writeByte(sequenceId);
    int clientCapabilitiesFlags = encoder.clientCapabilitiesFlag;
    packet.writeIntLE(clientCapabilitiesFlags);
    packet.writeIntLE(PACKET_PAYLOAD_LENGTH_LIMIT);
    packet.writeByte(cmd.collation().collationId());
    packet.writeZero(23);
    BufferUtils.writeNullTerminatedString(packet, username, StandardCharsets.UTF_8);
    String authMethod = clientPluginName;
    if (password.isEmpty()) {
      packet.writeByte(0);
    } else {
      byte[] authResponse;
      switch (authMethod) {
        case "mysql_native_password":
          authResponse = Native41Authenticator.encode(password.getBytes(StandardCharsets.UTF_8), nonce);
          break;
        case "caching_sha2_password":
          authResponse = CachingSha2Authenticator.encode(password.getBytes(StandardCharsets.UTF_8), nonce);
          break;
        case "mysql_clear_password":
          authResponse = password.getBytes(StandardCharsets.UTF_8);
          break;
        default:
          LOGGER.warn("Unknown authentication method: " + authMethod + ", the client will try to use mysql_native_password instead.");
          authMethod = "mysql_native_password";
          authResponse = Native41Authenticator.encode(password.getBytes(StandardCharsets.UTF_8), nonce);
          break;
      }
      if ((clientCapabilitiesFlags & CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA) != 0) {
        BufferUtils.writeLengthEncodedInteger(packet, authResponse.length);
        packet.writeBytes(authResponse);
      } else if ((clientCapabilitiesFlags & CLIENT_SECURE_CONNECTION) != 0) {
        packet.writeByte(authResponse.length);
        packet.writeBytes(authResponse);
      } else {
        packet.writeByte(0);
      }
    }
    if ((clientCapabilitiesFlags & CLIENT_CONNECT_WITH_DB) != 0) {
      BufferUtils.writeNullTerminatedString(packet, database, StandardCharsets.UTF_8);
    }
    if ((clientCapabilitiesFlags & CLIENT_PLUGIN_AUTH) != 0) {
      BufferUtils.writeNullTerminatedString(packet, authMethod, StandardCharsets.UTF_8);
    }
    if ((clientCapabilitiesFlags & CLIENT_CONNECT_ATTRS) != 0) {
      encodeConnectionAttributes(clientConnectionAttributes, packet);
    }
    int payloadLength = packet.writerIndex() - packetStartIdx - 4;
    packet.setMediumLE(packetStartIdx, payloadLength);
    sendPacket(packet, payloadLength);
  }

  /** Returns whether the server advertised the {@code CLIENT_SSL} capability. */
  private boolean isTlsSupportedByServer(int serverCapabilitiesFlags) {
    return (serverCapabilitiesFlags & CLIENT_SSL) != 0;
  }
}