package io.vertx.mysqlclient.impl.codec;
import io.netty.buffer.ByteBuf;
import io.vertx.core.Handler;
import io.vertx.mysqlclient.MySQLException;
import io.vertx.mysqlclient.impl.protocol.CapabilitiesFlag;
import io.vertx.mysqlclient.impl.datatype.DataType;
import io.vertx.mysqlclient.impl.protocol.ColumnDefinition;
import io.vertx.mysqlclient.impl.util.BufferUtils;
import io.vertx.sqlclient.impl.command.CommandBase;
import io.vertx.sqlclient.impl.command.CommandResponse;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import static io.vertx.mysqlclient.impl.protocol.Packets.*;
/**
 * Base class for per-command codecs: encodes a command into MySQL packets through a
 * {@link MySQLEncoder} and decodes response payloads, reporting the outcome to
 * {@link #completionHandler}.
 *
 * @param <R> result type of the command
 * @param <C> command type handled by this codec
 */
abstract class CommandCodec<R, C extends CommandBase<R>> {

  /** Receives the command's final response (success or failure). */
  Handler<? super CommandResponse<R>> completionHandler;
  public Throwable failure;
  public R result;
  /** The command handled by this codec. */
  final C cmd;
  /** Encoder the command is written through; assigned by {@link #encode(MySQLEncoder)}. */
  MySQLEncoder encoder;
  /** Sequence id for the next packet to send; reset to 0 by {@link #encode(MySQLEncoder)}. */
  int sequenceId;

  CommandCodec(C cmd) {
    this.cmd = cmd;
  }

  /**
   * Decodes one response packet payload for this command.
   *
   * @param payload the packet payload (without the 4-byte packet header)
   * @param payloadLength length of the payload in bytes
   */
  abstract void decodePayload(ByteBuf payload, int payloadLength);

  /**
   * Starts encoding the command: remembers the encoder and resets the packet sequence id to 0.
   * Subclasses extend this to write their command packets.
   */
  void encode(MySQLEncoder encoder) {
    this.encoder = encoder;
    this.sequenceId = 0;
  }

  /** Allocates an I/O buffer from the channel's allocator. */
  ByteBuf allocateBuffer() {
    return encoder.chctx.alloc().ioBuffer();
  }

  /** Allocates an I/O buffer with the given initial capacity from the channel's allocator. */
  ByteBuf allocateBuffer(int capacity) {
    return encoder.chctx.alloc().ioBuffer(capacity);
  }

  /**
   * Sends a packet, splitting it into several protocol packets when the payload reaches
   * {@code PACKET_PAYLOAD_LENGTH_LIMIT}.
   *
   * @param packet buffer whose first 4 bytes are the packet header, followed by the payload
   * @param payloadLength length of the payload (excluding the 4-byte header)
   */
  void sendPacket(ByteBuf packet, int payloadLength) {
    if (payloadLength >= PACKET_PAYLOAD_LENGTH_LIMIT) {
      sendSplitPacket(packet);
    } else {
      sendNonSplitPacket(packet);
    }
  }

  /**
   * Discards the 4-byte header of {@code packet} and re-sends its payload as consecutive
   * packets of at most {@code PACKET_PAYLOAD_LENGTH_LIMIT} bytes, each with its own header and
   * an incremented sequence id. The last packet carries the remaining bytes. It may be empty
   * when the payload length is an exact multiple of the limit.
   */
  private void sendSplitPacket(ByteBuf packet) {
    ByteBuf payload = packet.skipBytes(4);
    while (payload.readableBytes() >= PACKET_PAYLOAD_LENGTH_LIMIT) {
      ByteBuf packetHeader = allocateBuffer(4);
      packetHeader.writeMediumLE(PACKET_PAYLOAD_LENGTH_LIMIT);
      packetHeader.writeByte(sequenceId++);
      encoder.chctx.write(packetHeader);
      // retained slice: the underlying buffer stays alive until the final write below releases it
      encoder.chctx.write(payload.readRetainedSlice(PACKET_PAYLOAD_LENGTH_LIMIT));
    }
    ByteBuf packetHeader = allocateBuffer(4);
    packetHeader.writeMediumLE(payload.readableBytes());
    packetHeader.writeByte(sequenceId++);
    encoder.chctx.write(packetHeader);
    encoder.chctx.writeAndFlush(payload);
  }

  /**
   * Writes and flushes {@code packet} as-is and advances the sequence id. The caller is
   * expected to have written the header (length and current sequence id) already.
   */
  void sendNonSplitPacket(ByteBuf packet) {
    sequenceId++;
    encoder.chctx.writeAndFlush(packet);
  }

  /**
   * Wraps {@code payload} in a packet header carrying the current sequence id and sends it.
   * Large payloads are split via {@link #sendPacket(ByteBuf, int)}.
   */
  final void sendBytesAsPacket(byte[] payload) {
    int payloadLength = payload.length;
    ByteBuf packet = allocateBuffer(payloadLength + 4);
    // for oversized payloads this header is discarded and rewritten by sendSplitPacket
    packet.writeMediumLE(payloadLength);
    packet.writeByte(sequenceId);
    packet.writeBytes(payload);
    // route through sendPacket so payloads >= PACKET_PAYLOAD_LENGTH_LIMIT are split correctly
    sendPacket(packet, payloadLength);
  }

  /**
   * Completes the command from a response that should be an OK, EOF or ERROR packet.
   *
   * <ul>
   *   <li>OK/EOF: the command succeeds with a {@code null} result.</li>
   *   <li>ERROR: the command fails with the decoded {@link MySQLException}.</li>
   *   <li>Any other header: the command fails with an {@link IllegalStateException}.</li>
   * </ul>
   *
   * <p>The header is peeked without moving the reader index.
   */
  void handleOkPacketOrErrorPacketPayload(ByteBuf payload) {
    int header = payload.getUnsignedByte(payload.readerIndex());
    switch (header) {
      case EOF_PACKET_HEADER:
      case OK_PACKET_HEADER:
        completionHandler.handle(CommandResponse.success(null));
        break;
      case ERROR_PACKET_HEADER:
        handleErrorPacketPayload(payload);
        break;
      default:
        // ignoring an unexpected header would leave the command without any response
        completionHandler.handle(CommandResponse.failure(new IllegalStateException(
          "Unexpected packet header 0x" + Integer.toHexString(header)
            + ", expected an OK, EOF or ERROR packet")));
        break;
    }
  }

  /**
   * Decodes an ERROR packet payload and fails the command with a {@link MySQLException}.
   *
   * <p>Layout read here: 1-byte header, 2-byte error code, 1 skipped byte, 5-byte SQL state,
   * and the rest of the payload as the message.
   */
  void handleErrorPacketPayload(ByteBuf payload) {
    payload.skipBytes(1);
    int errorCode = payload.readUnsignedShortLE();
    // skips one byte preceding the SQL state (presumably the '#' state marker)
    payload.skipBytes(1);
    String sqlState = BufferUtils.readFixedLengthString(payload, 5, StandardCharsets.UTF_8);
    String errorMessage = readRestOfPacketString(payload, StandardCharsets.UTF_8);
    completionHandler.handle(CommandResponse.failure(new MySQLException(errorMessage, errorCode, sqlState)));
  }

  /**
   * Decodes the leading fields of an OK packet payload: header, affected rows, last insert id
   * and server status flags. Warnings, status info and session state info are not decoded.
   * They are reported as {@code 0}, {@code null} and {@code null} respectively.
   */
  OkPacket decodeOkPacketPayload(ByteBuf payload) {
    payload.skipBytes(1);
    long affectedRows = BufferUtils.readLengthEncodedInteger(payload);
    long lastInsertId = BufferUtils.readLengthEncodedInteger(payload);
    int serverStatusFlags = payload.readUnsignedShortLE();
    String statusInfo = null;
    String sessionStateInfo = null;
    return new OkPacket(affectedRows, lastInsertId, serverStatusFlags, 0, statusInfo, sessionStateInfo);
  }

  /** Decodes an EOF packet payload: header, 2-byte warning count, 2-byte server status flags. */
  EofPacket decodeEofPacketPayload(ByteBuf payload) {
    payload.skipBytes(1);
    int numberOfWarnings = payload.readUnsignedShortLE();
    int serverStatusFlags = payload.readUnsignedShortLE();
    return new EofPacket(numberOfWarnings, serverStatusFlags);
  }

  /** Reads all remaining readable bytes of {@code payload} as a string in {@code charset}. */
  String readRestOfPacketString(ByteBuf payload, Charset charset) {
    return BufferUtils.readFixedLengthString(payload, payload.readableBytes(), charset);
  }

  /**
   * Decodes a column definition packet payload into a {@link ColumnDefinition}.
   *
   * <p>The six length-encoded strings (catalog through original name) are decoded as UTF-8.
   */
  ColumnDefinition decodeColumnDefinitionPacketPayload(ByteBuf payload) {
    String catalog = BufferUtils.readLengthEncodedString(payload, StandardCharsets.UTF_8);
    String schema = BufferUtils.readLengthEncodedString(payload, StandardCharsets.UTF_8);
    String table = BufferUtils.readLengthEncodedString(payload, StandardCharsets.UTF_8);
    String orgTable = BufferUtils.readLengthEncodedString(payload, StandardCharsets.UTF_8);
    String name = BufferUtils.readLengthEncodedString(payload, StandardCharsets.UTF_8);
    String orgName = BufferUtils.readLengthEncodedString(payload, StandardCharsets.UTF_8);
    // length of the fixed-length fields: value unused, read only to advance past it
    BufferUtils.readLengthEncodedInteger(payload);
    int characterSet = payload.readUnsignedShortLE();
    long columnLength = payload.readUnsignedIntLE();
    DataType type = DataType.valueOf(payload.readUnsignedByte());
    int flags = payload.readUnsignedShortLE();
    byte decimals = payload.readByte();
    return new ColumnDefinition(catalog, schema, table, orgTable, name, orgName, characterSet, columnLength, type, flags, decimals);
  }

  /** Skips a 5-byte EOF packet payload unless the client has CLIENT_DEPRECATE_EOF enabled. */
  void skipEofPacketIfNeeded(ByteBuf payload) {
    if (!isDeprecatingEofFlagEnabled()) {
      payload.skipBytes(5);
    }
  }

  /** Returns whether the encoder's client capabilities include CLIENT_DEPRECATE_EOF. */
  boolean isDeprecatingEofFlagEnabled() {
    return (encoder.clientCapabilitiesFlag & CapabilitiesFlag.CLIENT_DEPRECATE_EOF) != 0;
  }
}