package org.apache.cassandra.transport;
import java.util.ArrayList;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.EnumSet;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import io.netty.buffer.ByteBuf;
import io.netty.channel.*;
import io.netty.handler.codec.MessageToMessageDecoder;
import io.netty.handler.codec.MessageToMessageEncoder;
import com.google.common.base.Predicate;
import com.google.common.collect.ImmutableSet;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.apache.cassandra.concurrent.LocalAwareExecutorService;
import org.apache.cassandra.config.DatabaseDescriptor;
import org.apache.cassandra.exceptions.OverloadedException;
import org.apache.cassandra.metrics.ClientMetrics;
import org.apache.cassandra.net.ResourceLimits;
import org.apache.cassandra.service.ClientWarn;
import org.apache.cassandra.transport.messages.*;
import org.apache.cassandra.service.QueryState;
import org.apache.cassandra.utils.JVMStabilityInspector;
import static org.apache.cassandra.concurrent.SharedExecutorPool.SHARED;
public abstract class Message
{
/** Logger used by this class and by the pipeline handlers nested in it. */
protected static final Logger logger = LoggerFactory.getLogger(Message.class);

/**
 * Fragments of {@link IOException} messages; an I/O failure whose message contains one of these
 * is logged at trace level by {@code UnexpectedChannelExceptionHandler} instead of info.
 */
private static final Set<String> ioExceptionsAtDebugLevel = ImmutableSet.of("Connection reset by peer",
                                                                            "Broken pipe",
                                                                            "Connection timed out");
/**
 * Codec for one concrete kind of {@link Message}; adds nothing to {@code CBCodec} beyond
 * bounding the type parameter to messages.
 *
 * @param <M> the message type handled by the codec
 */
public interface Codec<M extends Message> extends CBCodec<M>
{
}
/**
 * Whether a frame is a request or a response. The direction is carried in bit 0x80 of the
 * version value: clear for requests, set for responses.
 */
public enum Direction
{
    REQUEST, RESPONSE;

    /** Bit of the version value that marks a response. */
    private static final int RESPONSE_BIT = 0x80;

    /** Mask that keeps only the low seven bits of the version value. */
    private static final int VERSION_MASK = 0x7F;

    /**
     * Reads the direction out of a version value that has the direction bit folded in.
     *
     * @param versionWithDirection version value including the direction bit
     * @return {@code RESPONSE} when bit 0x80 is set, {@code REQUEST} otherwise
     */
    public static Direction extractFromVersion(int versionWithDirection)
    {
        boolean responseBitSet = (versionWithDirection & RESPONSE_BIT) != 0;
        return responseBitSet ? RESPONSE : REQUEST;
    }

    /**
     * Folds this direction into a raw version value.
     *
     * @param rawVersion version value to combine with the direction
     * @return for requests, the low seven bits of {@code rawVersion}; for responses,
     *         {@code rawVersion} with bit 0x80 set
     */
    public int addToVersion(int rawVersion)
    {
        if (this == REQUEST)
            return rawVersion & VERSION_MASK;

        return rawVersion | RESPONSE_BIT;
    }
}
/**
 * The message types known to this class. Each type carries its numeric opcode, the
 * direction in which it travels and the codec used to serialize it.
 */
public enum Type
{
    ERROR          (0,  Direction.RESPONSE, ErrorMessage.codec),
    STARTUP        (1,  Direction.REQUEST,  StartupMessage.codec),
    READY          (2,  Direction.RESPONSE, ReadyMessage.codec),
    AUTHENTICATE   (3,  Direction.RESPONSE, AuthenticateMessage.codec),
    CREDENTIALS    (4,  Direction.REQUEST,  CredentialsMessage.codec),
    OPTIONS        (5,  Direction.REQUEST,  OptionsMessage.codec),
    SUPPORTED      (6,  Direction.RESPONSE, SupportedMessage.codec),
    QUERY          (7,  Direction.REQUEST,  QueryMessage.codec),
    RESULT         (8,  Direction.RESPONSE, ResultMessage.codec),
    PREPARE        (9,  Direction.REQUEST,  PrepareMessage.codec),
    EXECUTE        (10, Direction.REQUEST,  ExecuteMessage.codec),
    REGISTER       (11, Direction.REQUEST,  RegisterMessage.codec),
    EVENT          (12, Direction.RESPONSE, EventMessage.codec),
    BATCH          (13, Direction.REQUEST,  BatchMessage.codec),
    AUTH_CHALLENGE (14, Direction.RESPONSE, AuthChallenge.codec),
    AUTH_RESPONSE  (15, Direction.REQUEST,  AuthResponse.codec),
    AUTH_SUCCESS   (16, Direction.RESPONSE, AuthSuccess.codec);

    public final int opcode;
    public final Direction direction;
    public final Codec<?> codec;

    /** Lookup table indexed by opcode; slots with no matching type stay {@code null}. */
    private static final Type[] opcodeIdx;
    static
    {
        int maxOpcode = -1;
        for (Type type : Type.values())
            maxOpcode = Math.max(maxOpcode, type.opcode);

        opcodeIdx = new Type[maxOpcode + 1];
        for (Type type : Type.values())
        {
            // Two types sharing an opcode would make decoding ambiguous, so fail at class init.
            if (opcodeIdx[type.opcode] != null)
                throw new IllegalStateException("Duplicate opcode");
            opcodeIdx[type.opcode] = type;
        }
    }

    Type(int opcode, Direction direction, Codec<?> codec)
    {
        this.opcode = opcode;
        this.direction = direction;
        this.codec = codec;
    }

    /**
     * Resolves an opcode to its message type and checks it travels in the expected direction.
     *
     * @param opcode    the numeric opcode to resolve
     * @param direction the direction the frame was received in
     * @return the matching type
     * @throws ProtocolException if the opcode is negative, out of range or unassigned, or if
     *                           the type's direction differs from {@code direction}
     */
    public static Type fromOpcode(int opcode, Direction direction)
    {
        // Reject negative values too: they would otherwise escape as an
        // ArrayIndexOutOfBoundsException rather than a protocol error.
        if (opcode < 0 || opcode >= opcodeIdx.length)
            throw new ProtocolException(String.format("Unknown opcode %d", opcode));

        Type t = opcodeIdx[opcode];
        if (t == null)
            throw new ProtocolException(String.format("Unknown opcode %d", opcode));

        if (t.direction != direction)
            throw new ProtocolException(String.format("Wrong protocol direction (expected %s, got %s) for opcode %d (%s)",
                                                      t.direction,
                                                      direction,
                                                      opcode,
                                                      t));
        return t;
    }
}
/** The kind of message this is; fixed at construction. */
public final Type type;

/** Connection this message is attached to, or {@code null} until {@link #attach} is called. */
protected Connection connection;

private int streamId;
private Frame sourceFrame;
private Map<String, ByteBuffer> customPayload;

/** When non-null, the encoder uses this version for the outgoing frame instead of the connection's. */
protected ProtocolVersion forcedProtocolVersion;

/**
 * @param type the kind of message being created
 */
protected Message(Type type)
{
    this.type = type;
}

/**
 * @return the connection previously passed to {@link #attach}, or {@code null}
 */
public Connection connection()
{
    return connection;
}

/**
 * Associates this message with a connection.
 *
 * @param connection the connection to remember
 */
public void attach(Connection connection)
{
    this.connection = connection;
}

/**
 * @return the stream id last set on this message (0 if never set)
 */
public int getStreamId()
{
    return streamId;
}

/**
 * @param streamId the stream id to record
 * @return this message, to allow chaining
 */
public Message setStreamId(int streamId)
{
    this.streamId = streamId;
    return this;
}

/**
 * @return the frame recorded through {@link #setSourceFrame}, or {@code null}
 */
public Frame getSourceFrame()
{
    return sourceFrame;
}

/**
 * @param sourceFrame the frame to record on this message
 */
public void setSourceFrame(Frame sourceFrame)
{
    this.sourceFrame = sourceFrame;
}

/**
 * @return the custom payload, or {@code null} when none was set
 */
public Map<String, ByteBuffer> getCustomPayload()
{
    return customPayload;
}

/**
 * @param customPayload the custom payload to carry; may be {@code null}
 */
public void setCustomPayload(Map<String, ByteBuffer> customPayload)
{
    this.customPayload = customPayload;
}
/**
 * Base class of every message whose type travels in the {@code REQUEST} direction.
 */
public static abstract class Request extends Message
{
    /** Set once tracing has been requested for this request. */
    protected boolean tracingRequested;

    /**
     * @param type the request's message type
     * @throws IllegalArgumentException if {@code type} is not a request type
     */
    protected Request(Type type)
    {
        super(type);

        boolean isRequestType = type.direction == Direction.REQUEST;
        if (!isRequestType)
            throw new IllegalArgumentException();
    }

    /**
     * Executes this request.
     *
     * @param queryState         state handed to the execution
     * @param queryStartNanoTime timestamp supplied by the caller; the dispatcher passes a
     *                           {@code System.nanoTime()} reading taken before execution
     * @return the response to send back
     */
    public abstract Response execute(QueryState queryState, long queryStartNanoTime);

    /**
     * @return {@code true} once {@link #setTracingRequested()} has been called
     */
    public boolean isTracingRequested()
    {
        return tracingRequested;
    }

    /** Marks this request as having tracing requested. */
    public void setTracingRequested()
    {
        tracingRequested = true;
    }
}
/**
 * Base class of every message whose type travels in the {@code RESPONSE} direction.
 */
public static abstract class Response extends Message
{
    /** Tracing session id carried by the response, or {@code null}. */
    protected UUID tracingId;

    /** Warnings carried by the response, or {@code null}. */
    protected List<String> warnings;

    /**
     * @param type the response's message type
     * @throws IllegalArgumentException if {@code type} is not a response type
     */
    protected Response(Type type)
    {
        super(type);

        boolean isResponseType = type.direction == Direction.RESPONSE;
        if (!isResponseType)
            throw new IllegalArgumentException();
    }

    /**
     * @return the tracing id, or {@code null} when none was set
     */
    public UUID getTracingId()
    {
        return tracingId;
    }

    /**
     * @param tracingId the tracing id to carry
     * @return this message, to allow chaining
     */
    public Message setTracingId(UUID tracingId)
    {
        this.tracingId = tracingId;
        return this;
    }

    /**
     * @return the warnings, or {@code null} when none were set
     */
    public List<String> getWarnings()
    {
        return warnings;
    }

    /**
     * @param warnings the warnings to carry; may be {@code null}
     * @return this message, to allow chaining
     */
    public Message setWarnings(List<String> warnings)
    {
        this.warnings = warnings;
        return this;
    }
}
@ChannelHandler.Sharable
public static class ProtocolDecoder extends MessageToMessageDecoder<Frame>
{
public void decode(ChannelHandlerContext ctx, Frame frame, List results)
{
boolean isRequest = frame.header.type.direction == Direction.REQUEST;
boolean isTracing = frame.header.flags.contains(Frame.Header.Flag.TRACING);
boolean isCustomPayload = frame.header.flags.contains(Frame.Header.Flag.CUSTOM_PAYLOAD);
boolean hasWarning = frame.header.flags.contains(Frame.Header.Flag.WARNING);
UUID tracingId = isRequest || !isTracing ? null : CBUtil.readUUID(frame.body);
List<String> warnings = isRequest || !hasWarning ? null : CBUtil.readStringList(frame.body);
Map<String, ByteBuffer> customPayload = !isCustomPayload ? null : CBUtil.readBytesMap(frame.body);
try
{
if (isCustomPayload && frame.header.version.isSmallerThan(ProtocolVersion.V4))
throw new ProtocolException("Received frame with CUSTOM_PAYLOAD flag for native protocol version < 4");
Message message = frame.header.type.codec.decode(frame.body, frame.header.version);
message.setStreamId(frame.header.streamId);
message.setSourceFrame(frame);
message.setCustomPayload(customPayload);
if (isRequest)
{
assert message instanceof Request;
Request req = (Request)message;
Connection connection = ctx.channel().attr(Connection.attributeKey).get();
req.attach(connection);
if (isTracing)
req.setTracingRequested();
}
else
{
assert message instanceof Response;
if (isTracing)
((Response)message).setTracingId(tracingId);
if (hasWarning)
((Response)message).setWarnings(warnings);
}
results.add(message);
}
catch (Throwable ex)
{
frame.release();
throw ErrorMessage.wrap(ex, frame.header.streamId);
}
}
}
@ChannelHandler.Sharable
public static class ProtocolEncoder extends MessageToMessageEncoder<Message>
{
private final ProtocolVersionLimit versionCap;
ProtocolEncoder(ProtocolVersionLimit versionCap)
{
this.versionCap = versionCap;
}
public void encode(ChannelHandlerContext ctx, Message message, List results)
{
Connection connection = ctx.channel().attr(Connection.attributeKey).get();
ProtocolVersion version = connection == null ? versionCap.getMaxVersion() : connection.getVersion();
EnumSet<Frame.Header.Flag> flags = EnumSet.noneOf(Frame.Header.Flag.class);
Codec<Message> codec = (Codec<Message>)message.type.codec;
try
{
int messageSize = codec.encodedSize(message, version);
ByteBuf body;
if (message instanceof Response)
{
UUID tracingId = ((Response)message).getTracingId();
Map<String, ByteBuffer> customPayload = message.getCustomPayload();
if (tracingId != null)
messageSize += CBUtil.sizeOfUUID(tracingId);
List<String> warnings = ((Response)message).getWarnings();
if (warnings != null)
{
if (version.isSmallerThan(ProtocolVersion.V4))
throw new ProtocolException("Must not send frame with WARNING flag for native protocol version < 4");
messageSize += CBUtil.sizeOfStringList(warnings);
}
if (customPayload != null)
{
if (version.isSmallerThan(ProtocolVersion.V4))
throw new ProtocolException("Must not send frame with CUSTOM_PAYLOAD flag for native protocol version < 4");
messageSize += CBUtil.sizeOfBytesMap(customPayload);
}
body = CBUtil.allocator.buffer(messageSize);
if (tracingId != null)
{
CBUtil.writeUUID(tracingId, body);
flags.add(Frame.Header.Flag.TRACING);
}
if (warnings != null)
{
CBUtil.writeStringList(warnings, body);
flags.add(Frame.Header.Flag.WARNING);
}
if (customPayload != null)
{
CBUtil.writeBytesMap(customPayload, body);
flags.add(Frame.Header.Flag.CUSTOM_PAYLOAD);
}
}
else
{
assert message instanceof Request;
if (((Request)message).isTracingRequested())
flags.add(Frame.Header.Flag.TRACING);
Map<String, ByteBuffer> payload = message.getCustomPayload();
if (payload != null)
messageSize += CBUtil.sizeOfBytesMap(payload);
body = CBUtil.allocator.buffer(messageSize);
if (payload != null)
{
CBUtil.writeBytesMap(payload, body);
flags.add(Frame.Header.Flag.CUSTOM_PAYLOAD);
}
}
try
{
codec.encode(message, body, version);
}
catch (Throwable e)
{
body.release();
throw e;
}
ProtocolVersion responseVersion = message.forcedProtocolVersion == null
? version
: message.forcedProtocolVersion;
if (responseVersion.isBeta())
flags.add(Frame.Header.Flag.USE_BETA);
results.add(Frame.create(message.type, message.getStreamId(), responseVersion, flags, body));
}
catch (Throwable e)
{
throw ErrorMessage.wrap(e, message.getStreamId());
}
}
}
public static class Dispatcher extends SimpleChannelInboundHandler<Request>
{
/** Executor shared by all dispatchers; requests are processed on it rather than on the event loop. */
private static final LocalAwareExecutorService requestExecutor =
    SHARED.newExecutor(DatabaseDescriptor.getNativeTransportMaxThreads(),
                       Integer.MAX_VALUE,
                       "transport",
                       "Native-Transport-Requests");

/** Sum of the body sizes of this channel's requests that have been accepted but not yet released. */
private long channelPayloadBytesInFlight;

/** Tracks in-flight payload bytes against the limits checked in {@code shouldHandleRequest}. */
private final Server.EndpointPayloadTracker endpointPayloadTracker;

/** Whether auto-read is currently switched off on this channel because a limit was exceeded. */
private boolean paused;
/**
 * A response waiting to be written, together with what is needed to release the request's
 * resources once it has been flushed.
 */
private static class FlushItem
{
    final ChannelHandlerContext ctx;
    final Object response;
    final Frame sourceFrame;
    final Dispatcher dispatcher;

    /**
     * @param ctx         context the response is written to
     * @param response    the object to write
     * @param sourceFrame frame of the request being answered
     * @param dispatcher  dispatcher that accounts for the request's payload
     */
    private FlushItem(ChannelHandlerContext ctx, Object response, Frame sourceFrame, Dispatcher dispatcher)
    {
        this.ctx = ctx;
        this.response = response;
        this.sourceFrame = sourceFrame;
        this.dispatcher = dispatcher;
    }

    /** Hands this item back to its dispatcher for release. */
    public void release()
    {
        dispatcher.releaseItem(this);
    }
}
/**
 * Base class of the per-event-loop tasks that write queued responses and flush their channels.
 */
private static abstract class Flusher implements Runnable
{
    final EventLoop eventLoop;

    /** Items waiting to be written by the next run. */
    final ConcurrentLinkedQueue<FlushItem> queued = new ConcurrentLinkedQueue<>();

    /** True while this flusher considers itself scheduled on the event loop. */
    final AtomicBoolean scheduled = new AtomicBoolean();

    /** Contexts written to since the last flush. */
    final HashSet<ChannelHandlerContext> channels = new HashSet<>();

    /** Items written since the last flush, awaiting release. */
    final List<FlushItem> flushed = new ArrayList<>();

    public Flusher(EventLoop eventLoop)
    {
        this.eventLoop = eventLoop;
    }

    /** Submits this flusher to its event loop unless it is already marked as scheduled. */
    void start()
    {
        // Cheap read first so the common already-scheduled case skips the CAS.
        if (scheduled.get())
            return;

        if (scheduled.compareAndSet(false, true))
            eventLoop.execute(this);
    }
}
/**
 * Flusher that batches writes over several runs before flushing, and that keeps rescheduling
 * itself on the event loop every 10000 nanoseconds until it has seen more than five
 * consecutive runs with nothing queued.
 */
private static final class LegacyFlusher extends Flusher
{
// Number of run() invocations since the channels were last flushed.
int runsSinceFlush = 0;
// Number of consecutive run() invocations that found nothing queued.
int runsWithNoWork = 0;
private LegacyFlusher(EventLoop eventLoop)
{
super(eventLoop);
}
/**
 * Writes everything currently queued, flushes when the batching thresholds are met, then
 * either reschedules itself or stops after a long enough idle streak.
 */
public void run()
{
boolean doneWork = false;
FlushItem flush;
// Drain the queue, writing each response without flushing yet.
while ( null != (flush = queued.poll()) )
{
channels.add(flush.ctx);
flush.ctx.write(flush.response, flush.ctx.voidPromise());
flushed.add(flush);
doneWork = true;
}
runsSinceFlush++;
// Flush when this run found nothing new, on the third run since the last flush,
// or once more than 50 written items are pending.
if (!doneWork || runsSinceFlush > 2 || flushed.size() > 50)
{
for (ChannelHandlerContext channel : channels)
channel.flush();
// Items are released only after their channel has been flushed.
for (FlushItem item : flushed)
item.release();
channels.clear();
flushed.clear();
runsSinceFlush = 0;
}
if (doneWork)
{
runsWithNoWork = 0;
}
else
{
if (++runsWithNoWork > 5)
{
// Mark as unscheduled, then re-check the queue: a producer may have queued an item
// and had start() skip scheduling because the flag was still true. Keep running
// only if there is work and this thread wins the flag back.
scheduled.set(false);
if (queued.isEmpty() || !scheduled.compareAndSet(false, true))
return;
}
}
// Still considered scheduled: run again shortly.
eventLoop.schedule(this, 10000, TimeUnit.NANOSECONDS);
}
}
/**
 * Flusher that writes everything queued and flushes straight away, once per scheduling.
 */
private static final class ImmediateFlusher extends Flusher
{
    private ImmediateFlusher(EventLoop eventLoop)
    {
        super(eventLoop);
    }

    /**
     * Clears the scheduled flag, writes every queued item, then flushes the touched channels
     * and releases the written items.
     */
    public void run()
    {
        // Cleared before draining, so a start() issued during the drain schedules another run.
        scheduled.set(false);

        boolean wroteAnything = false;
        for (FlushItem item = queued.poll(); item != null; item = queued.poll())
        {
            channels.add(item.ctx);
            item.ctx.write(item.response, item.ctx.voidPromise());
            flushed.add(item);
            wroteAnything = true;
        }

        if (!wroteAnything)
            return;

        channels.forEach(ChannelHandlerContext::flush);
        flushed.forEach(FlushItem::release);

        channels.clear();
        flushed.clear();
    }
}
/** The flusher serving each event loop; entries are created lazily by {@code flush}. */
private static final ConcurrentMap<EventLoop, Flusher> flusherLookup = new ConcurrentHashMap<>();

/** Selects {@code LegacyFlusher} over {@code ImmediateFlusher} when a flusher has to be created. */
private final boolean useLegacyFlusher;

/**
 * @param useLegacyFlusher       whether flushers created by this dispatcher use the legacy strategy
 * @param endpointPayloadTracker tracker charged for the payload of in-flight requests
 */
public Dispatcher(boolean useLegacyFlusher, Server.EndpointPayloadTracker endpointPayloadTracker)
{
    // false is handed to the superclass constructor; the source frame is released explicitly
    // in releaseItem instead.
    super(false);
    this.endpointPayloadTracker = endpointPayloadTracker;
    this.useLegacyFlusher = useLegacyFlusher;
}
/**
 * Hands an accepted request over to the request executor for processing.
 *
 * @param ctx     context of the channel the request arrived on
 * @param request the decoded request
 */
@Override
public void channelRead0(ChannelHandlerContext ctx, Request request)
{
    if (!shouldHandleRequest(ctx, request))
        return;

    requestExecutor.submit(() -> processRequest(ctx, request));
}
/**
 * Charges the request's frame size against the in-flight payload limits and decides how to
 * react when they are exceeded.
 * <p>
 * When the allocation fails and the connection asks for errors on overload, the request is
 * discarded and an {@code OverloadedException} is thrown, wrapped with the request's stream
 * id. Otherwise the size is allocated regardless, auto-read is switched off on the channel
 * and the request is still handled.
 *
 * @param ctx     context of the channel the request arrived on
 * @param request the request to account for
 * @return {@code true} whenever the method returns normally
 */
private boolean shouldHandleRequest(ChannelHandlerContext ctx, Request request)
{
    long frameSize = request.getSourceFrame().header.bodySizeInBytes;
    ResourceLimits.EndpointAndGlobal endpointAndGlobalPayloadsInFlight = endpointPayloadTracker.endpointAndGlobalPayloadsInFlight;

    if (endpointAndGlobalPayloadsInFlight.tryAllocate(frameSize) != ResourceLimits.Outcome.SUCCESS)
    {
        if (request.connection.isThrowOnOverload())
        {
            ClientMetrics.instance.markRequestDiscarded();
            logger.trace("Discarded request of size: {}. InflightChannelRequestPayload: {}, InflightEndpointRequestPayload: {}, InflightOverallRequestPayload: {}, Request: {}",
                         frameSize,
                         channelPayloadBytesInFlight,
                         endpointAndGlobalPayloadsInFlight.endpoint().using(),
                         endpointAndGlobalPayloadsInFlight.global().using(),
                         request);
            throw ErrorMessage.wrap(new OverloadedException("Server is in overloaded state. Cannot accept more requests at this point"),
                                    request.getSourceFrame().header.streamId);
        }

        // Apply backpressure on the channel, but still handle this request.
        endpointAndGlobalPayloadsInFlight.allocate(frameSize);
        ctx.channel().config().setAutoRead(false);

        // Further requests can still arrive while already paused. Count the pause once,
        // since releaseItem and channelInactive only unpause once per paused period.
        // NOTE(review): assumes pauseConnection/unpauseConnection maintain a paired counter - confirm.
        if (!paused)
        {
            ClientMetrics.instance.pauseConnection();
            paused = true;
        }
    }

    channelPayloadBytesInFlight += frameSize;
    return true;
}
/**
 * Releases what a flushed item was holding: its source frame and its share of the in-flight
 * payload accounting. Auto-read is re-enabled on a paused channel once either nothing is in
 * flight on it any more or the tracker reports being below its limit.
 *
 * @param item the item whose response has been flushed
 */
private void releaseItem(FlushItem item)
{
    long releasedBytes = item.sourceFrame.header.bodySizeInBytes;
    item.sourceFrame.release();

    channelPayloadBytesInFlight -= releasedBytes;
    ResourceLimits.Outcome releaseOutcome = endpointPayloadTracker.endpointAndGlobalPayloadsInFlight.release(releasedBytes);

    ChannelConfig channelConfig = item.ctx.channel().config();

    if (!paused)
        return;

    boolean channelDrained = channelPayloadBytesInFlight == 0;
    boolean belowLimit = releaseOutcome == ResourceLimits.Outcome.BELOW_LIMIT;
    if (!channelDrained && !belowLimit)
        return;

    paused = false;
    ClientMetrics.instance.unpauseConnection();
    channelConfig.setAutoRead(true);
}
/**
 * Executes a request and queues its response for flushing. Any failure is turned into an
 * error message carrying the request's stream id and queued in place of the response.
 * Captured client warnings are reset in every case.
 *
 * @param ctx     context of the channel the request arrived on
 * @param request the request to execute
 */
void processRequest(ChannelHandlerContext ctx, Request request)
{
    final Response response;
    final ServerConnection connection;
    long queryStartNanoTime = System.nanoTime();

    try
    {
        assert request.connection() instanceof ServerConnection;
        connection = (ServerConnection)request.connection();

        // Warnings are only captured for protocol versions 4 and above.
        if (connection.getVersion().isGreaterOrEqualTo(ProtocolVersion.V4))
            ClientWarn.instance.captureWarnings();

        QueryState queryState = connection.validateNewMessage(request.type, connection.getVersion(), request.getStreamId());
        logger.trace("Received: {}, v={}", request, connection.getVersion());

        response = request.execute(queryState, queryStartNanoTime);
        response.setStreamId(request.getStreamId());
        response.setWarnings(ClientWarn.instance.getWarnings());
        response.attach(connection);
        connection.applyStateTransition(request.type, response.type);
    }
    catch (Throwable failure)
    {
        JVMStabilityInspector.inspectThrowable(failure);
        UnexpectedChannelExceptionHandler unexpectedHandler = new UnexpectedChannelExceptionHandler(ctx.channel(), true);
        Message errorResponse = ErrorMessage.fromException(failure, unexpectedHandler).setStreamId(request.getStreamId());
        flush(new FlushItem(ctx, errorResponse, request.getSourceFrame(), this));
        return;
    }
    finally
    {
        ClientWarn.instance.resetWarnings();
    }

    logger.trace("Responding: {}, v={}", response, connection.getVersion());
    flush(new FlushItem(ctx, response, request.getSourceFrame(), this));
}
/**
 * Releases the endpoint payload tracker, clears a pending pause (updating the metric) and
 * forwards the event down the pipeline.
 *
 * @param ctx context of the channel that became inactive
 */
@Override
public void channelInactive(ChannelHandlerContext ctx)
{
    endpointPayloadTracker.release();

    boolean wasPaused = paused;
    paused = false;
    if (wasPaused)
        ClientMetrics.instance.unpauseConnection();

    ctx.fireChannelInactive();
}
/**
 * Queues an item on the flusher of its channel's event loop, creating that flusher on first
 * use, and makes sure the flusher is scheduled.
 *
 * @param item the response to write and flush
 */
private void flush(FlushItem item)
{
    EventLoop loop = item.ctx.channel().eventLoop();

    Flusher flusher = flusherLookup.get(loop);
    if (flusher == null)
    {
        Flusher fresh = useLegacyFlusher ? new LegacyFlusher(loop) : new ImmediateFlusher(loop);
        // Another thread may have registered a flusher for this loop in the meantime; use theirs.
        Flusher existing = flusherLookup.putIfAbsent(loop, fresh);
        flusher = existing == null ? fresh : existing;
    }

    flusher.queued.add(item);
    flusher.start();
}
/** Shuts down the shared request executor, if there is one. */
public static void shutdown()
{
    if (requestExecutor == null)
        return;

    requestExecutor.shutdown();
}
}
/**
 * Inbound handler that reports pipeline exceptions to the client as error messages.
 */
@ChannelHandler.Sharable
public static final class ExceptionHandler extends ChannelInboundHandlerAdapter
{
    /**
     * Builds an error message from the exception and, if the channel is still open, writes
     * and flushes it. For a {@code ProtocolException} the context is additionally closed once
     * that write has completed.
     *
     * @param ctx   context of the channel the exception occurred on
     * @param cause the exception that reached this handler
     */
    @Override
    public void exceptionCaught(final ChannelHandlerContext ctx, Throwable cause)
    {
        UnexpectedChannelExceptionHandler unexpectedHandler = new UnexpectedChannelExceptionHandler(ctx.channel(), false);
        ErrorMessage errorMessage = ErrorMessage.fromException(cause, unexpectedHandler);

        if (!ctx.channel().isOpen())
            return;

        ChannelFuture written = ctx.writeAndFlush(errorMessage);
        if (!(cause instanceof ProtocolException))
            return;

        written.addListener((ChannelFutureListener) completed -> ctx.close());
    }
}
/**
 * Predicate that logs an unexpected exception raised while serving a channel and always
 * reports it as handled.
 */
static final class UnexpectedChannelExceptionHandler implements Predicate<Throwable>
{
    private final Channel channel;
    private final boolean alwaysLogAtError;

    /**
     * @param channel          the channel the exception relates to; only used in the log message
     * @param alwaysLogAtError when {@code true}, every exception is logged at error level,
     *                         including I/O exceptions
     */
    UnexpectedChannelExceptionHandler(Channel channel, boolean alwaysLogAtError)
    {
        this.channel = channel;
        this.alwaysLogAtError = alwaysLogAtError;
    }

    /**
     * Logs the exception. Unless {@code alwaysLogAtError} is set, an {@link IOException} is
     * logged at trace level when its message contains one of the known connection-failure
     * fragments and at info level otherwise (including when it has no message at all).
     * Everything else is logged at error level.
     *
     * @param exception the exception to log
     * @return always {@code true}
     */
    @Override
    public boolean apply(Throwable exception)
    {
        String message;
        try
        {
            message = "Unexpected exception during request; channel = " + channel;
        }
        catch (Exception ignore)
        {
            // Rendering the channel failed; log without it rather than lose the original exception.
            message = "Unexpected exception during request; channel = <unprintable>";
        }

        if (!alwaysLogAtError && exception instanceof IOException)
        {
            // getMessage() may be null (e.g. an IOException built without a message); such an
            // exception matches none of the fragments.
            String errorMessage = exception.getMessage();
            // Substring match rather than equality, so messages that embed a fragment also match.
            boolean logAtTrace = errorMessage != null
                                 && ioExceptionsAtDebugLevel.stream().anyMatch(errorMessage::contains);

            if (logAtTrace)
            {
                logger.trace(message, exception);
            }
            else
            {
                logger.info(message, exception);
            }
        }
        else
        {
            logger.error(message, exception);
        }
        return true;
    }
}
}