package io.vertx.mysqlclient.impl.codec;
import io.netty.buffer.ByteBuf;
import io.vertx.mysqlclient.impl.datatype.DataType;
import io.vertx.mysqlclient.impl.datatype.DataTypeCodec;
import io.vertx.mysqlclient.impl.protocol.CommandType;
import io.vertx.sqlclient.Tuple;
import io.vertx.sqlclient.impl.command.CommandResponse;
import io.vertx.sqlclient.impl.command.ExtendedQueryCommand;
import static io.vertx.mysqlclient.impl.protocol.Packets.ERROR_PACKET_HEADER;
import static io.vertx.mysqlclient.impl.protocol.Packets.EnumCursorType.CURSOR_TYPE_NO_CURSOR;
import static io.vertx.mysqlclient.impl.protocol.Packets.EnumCursorType.CURSOR_TYPE_READ_ONLY;
/**
 * Codec driving an {@link ExtendedQueryCommand} against an already prepared statement.
 *
 * <p>Two execution modes are supported:
 * <ul>
 *   <li>{@code cmd.fetch() > 0}: the statement is executed with {@code CURSOR_TYPE_READ_ONLY}; once
 *       the column definitions have been decoded, rows are pulled in batches with
 *       {@code COM_STMT_FETCH}. If the statement's cursor is already open, the command goes
 *       straight to fetching.</li>
 *   <li>otherwise: the statement is executed with {@code CURSOR_TYPE_NO_CURSOR} and the response is
 *       decoded by the base codec.</li>
 * </ul>
 *
 * @param <R> the result type produced by the command's collector
 */
class ExtendedQueryCommandCodec<R> extends ExtendedQueryCommandBaseCodec<R, ExtendedQueryCommand<R>> {

  ExtendedQueryCommandCodec(ExtendedQueryCommand<R> cmd) {
    super(cmd);
    boolean continuingOpenCursor = cmd.fetch() > 0 && statement.isCursorOpen;
    if (continuingOpenCursor) {
      // On this path the codec only decodes fetched rows, never column definitions,
      // so take them from the statement's row description.
      columnDefinitions = statement.rowDesc.columnDefinitions();
    }
  }

  /**
   * Sends the request for this command.
   *
   * <p>With an open cursor, a {@code COM_STMT_FETCH} is sent. Otherwise the parameters are bound
   * and a {@code COM_STMT_EXECUTE} is sent. A binding failure completes the command with a
   * failure and sends nothing.
   */
  @Override
  void encode(MySQLEncoder encoder) {
    super.encode(encoder);

    if (statement.isCursorOpen) {
      // Keep reading rows from the cursor opened by an earlier execution.
      decoder = new RowResultDecoder<>(cmd.collector(), statement.rowDesc);
      sendStatementFetchCommand(statement.statementId, cmd.fetch());
      return;
    }

    Tuple params = cmd.params();
    String bindError = statement.bindParameters(params);
    if (bindError != null) {
      completionHandler.handle(CommandResponse.failure(bindError));
      return;
    }

    boolean openCursor = cmd.fetch() > 0;
    // Parameter types are always sent when opening a cursor; otherwise the statement decides.
    boolean sendTypes = openCursor || statement.sendTypesToServer();
    byte cursorType = openCursor ? CURSOR_TYPE_READ_ONLY : CURSOR_TYPE_NO_CURSOR;
    sendStatementExecuteCommand(statement, sendTypes, params, cursorType);
  }

  /**
   * Decodes one server packet for this command.
   *
   * <p>With an open cursor, the packet is either an error or a batch of fetched rows. With a
   * positive fetch size, the packet belongs to the cursor-opening {@code COM_STMT_EXECUTE}
   * response. Otherwise decoding is delegated to the base codec.
   */
  @Override
  void decodePayload(ByteBuf payload, int payloadLength) {
    if (statement.isCursorOpen) {
      int header = payload.getUnsignedByte(payload.readerIndex());
      if (header == ERROR_PACKET_HEADER) {
        handleErrorPacketPayload(payload);
      } else {
        handleRows(payload, payloadLength);
      }
      return;
    }

    if (cmd.fetch() > 0) {
      decodeCursorOpeningResponse(payload);
    } else {
      super.decodePayload(payload, payloadLength);
    }
  }

  /** Advances through the COM_STMT_EXECUTE response of a statement executed with a cursor. */
  private void decodeCursorOpeningResponse(ByteBuf payload) {
    switch (commandHandlerState) {
      case INIT: {
        int header = payload.getUnsignedByte(payload.readerIndex());
        if (header == ERROR_PACKET_HEADER) {
          handleErrorPacketPayload(payload);
        } else {
          handleResultsetColumnCountPacketBody(payload);
        }
        break;
      }
      case HANDLING_COLUMN_DEFINITION:
        handleResultsetColumnDefinitions(payload);
        break;
      case COLUMN_DEFINITIONS_DECODING_COMPLETED:
        skipEofPacketIfNeeded(payload);
        startFetchingFromCursor();
        break;
      case HANDLING_ROW_DATA_OR_END_PACKET:
        startFetchingFromCursor();
        break;
      default:
        throw new IllegalStateException("Unexpected state for decoding COM_STMT_EXECUTE response with cursor opening");
    }
  }

  /** Finishes column-definition handling, marks the cursor open and requests the first batch. */
  private void startFetchingFromCursor() {
    handleResultsetColumnDefinitionsDecodingCompleted();
    // The fetch is a new command exchange, so its packet sequence starts again from zero.
    sequenceId = 0;
    decoder = new RowResultDecoder<>(cmd.collector(), statement.rowDesc);
    statement.isCursorOpen = true;
    sendStatementFetchCommand(statement.statementId, cmd.fetch());
  }

  /**
   * Writes a COM_STMT_EXECUTE packet and sends it with {@code sendPacket}.
   *
   * <p>Packet layout: a 4-byte header (3-byte little-endian payload length, patched at the end,
   * then the sequence id), followed by:
   * <ul>
   *   <li>the command byte;</li>
   *   <li>the statement id (4 bytes LE);</li>
   *   <li>the cursor type;</li>
   *   <li>an iteration count of 1 (4 bytes LE).</li>
   * </ul>
   *
   * <p>If the statement has parameters, the following are appended:
   * <ul>
   *   <li>a NULL bitmap (bit {@code i} set when parameter {@code i} is null);</li>
   *   <li>the new-params-bound flag;</li>
   *   <li>optionally, a 2-byte entry per parameter: type id, then a zero flags byte;</li>
   *   <li>the binary-encoded non-null values.</li>
   * </ul>
   */
  private void sendStatementExecuteCommand(MySQLPreparedStatement statement, boolean sendTypesToServer, Tuple params, byte cursorType) {
    ByteBuf packet = allocateBuffer();
    int headerIdx = packet.writerIndex();
    packet.writeMediumLE(0); // payload length placeholder
    packet.writeByte(sequenceId);

    packet.writeByte(CommandType.COM_STMT_EXECUTE);
    packet.writeIntLE((int) statement.statementId);
    packet.writeByte(cursorType);
    packet.writeIntLE(1); // iteration count

    DataType[] types = statement.bindingTypes();
    int paramCount = types.length;
    if (paramCount > 0) {
      byte[] nullBitmap = new byte[(paramCount + 7) / 8];
      int bitmapIdx = packet.writerIndex();
      // Reserve room for the bitmap now; it is filled in after the values are scanned.
      packet.writeBytes(nullBitmap);
      packet.writeBoolean(sendTypesToServer);

      if (sendTypesToServer) {
        for (DataType type : types) {
          packet.writeByte(type.id);
          packet.writeByte(0);
        }
      }

      for (int idx = 0; idx < paramCount; idx++) {
        Object value = params.getValue(idx);
        if (value == null) {
          nullBitmap[idx / 8] |= (1 << (idx & 7));
        } else {
          DataTypeCodec.encodeBinary(types[idx], value, encoder.encodingCharset, packet);
        }
      }

      packet.setBytes(bitmapIdx, nullBitmap);
    }

    int payloadLength = packet.writerIndex() - headerIdx - 4;
    packet.setMediumLE(headerIdx, payloadLength);
    sendPacket(packet, payloadLength);
  }

  /**
   * Writes a COM_STMT_FETCH packet asking for {@code count} rows from the statement's cursor.
   *
   * <p>Unlike the execute packet, this one is written and flushed directly on the encoder's
   * channel context rather than through {@code sendPacket}.
   */
  private void sendStatementFetchCommand(long statementId, int count) {
    ByteBuf packet = allocateBuffer();
    int headerIdx = packet.writerIndex();
    packet.writeMediumLE(0); // payload length placeholder
    packet.writeByte(sequenceId);

    packet.writeByte(CommandType.COM_STMT_FETCH);
    packet.writeIntLE((int) statementId);
    packet.writeIntLE(count);

    int payloadLength = packet.writerIndex() - headerIdx - 4;
    packet.setMediumLE(headerIdx, payloadLength);
    encoder.chctx.writeAndFlush(packet);
  }
}