package org.apache.cassandra.transport;
import java.io.Closeable;
import java.io.IOException;
import java.net.InetSocketAddress;
import java.nio.ByteBuffer;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.SynchronousQueue;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLEngine;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import io.netty.bootstrap.Bootstrap;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelOption;
import io.netty.channel.SimpleChannelInboundHandler;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.util.internal.logging.InternalLoggerFactory;
import io.netty.util.internal.logging.Slf4JLoggerFactory;
import org.apache.cassandra.cql3.QueryOptions;
import org.apache.cassandra.db.ConsistencyLevel;
import org.apache.cassandra.security.SSLFactory;
import org.apache.cassandra.transport.messages.CredentialsMessage;
import org.apache.cassandra.transport.messages.ErrorMessage;
import org.apache.cassandra.transport.messages.EventMessage;
import org.apache.cassandra.transport.messages.ExecuteMessage;
import org.apache.cassandra.transport.messages.PrepareMessage;
import org.apache.cassandra.transport.messages.QueryMessage;
import org.apache.cassandra.transport.messages.ResultMessage;
import org.apache.cassandra.transport.messages.StartupMessage;
import org.apache.cassandra.utils.MD5Digest;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelPipeline;
import io.netty.handler.ssl.SslHandler;
import static org.apache.cassandra.config.EncryptionOptions.ClientEncryptionOptions;
/**
 * A minimal, blocking client for the native transport protocol, built on Netty.
 *
 * <p>Each call to {@link #execute(Message.Request)} writes one request to the channel and then
 * blocks the calling thread until the response handler hands over the next non-event response.
 * Server-pushed events are delivered to the optional {@link EventHandler} instead.
 *
 * <p>NOTE(review): responses are handed to callers purely in arrival order through a single
 * queue; nothing visible here correlates a response with its request, so this class is
 * presumably meant to be driven by one thread at a time -- confirm before sharing an instance.
 */
public class SimpleClient implements Closeable
{
    static
    {
        // Route Netty's internal logging through SLF4J.
        InternalLoggerFactory.setDefaultFactory(new Slf4JLoggerFactory());
    }

    private static final Logger logger = LoggerFactory.getLogger(SimpleClient.class);

    /** Host name or address this client connects to. */
    public final String host;
    /** TCP port this client connects to. */
    public final int port;
    private final ClientEncryptionOptions encryptionOptions;

    protected final ResponseHandler responseHandler = new ResponseHandler();
    protected final Connection.Tracker tracker = new ConnectionTracker();
    protected final ProtocolVersion version;
    // Assigned by Initializer.initChannel when the channel is set up.
    protected Connection connection;
    protected Bootstrap bootstrap;
    protected Channel channel;
    // Future of the most recent write; close() waits on it before closing the channel.
    protected ChannelFuture lastWriteFuture;

    // Always hands back the single Connection created by the channel initializer.
    private final Connection.Factory connectionFactory = new Connection.Factory()
    {
        @Override
        public Connection newConnection(Channel channel, ProtocolVersion version)
        {
            return connection;
        }
    };

    /**
     * Creates a client for the given protocol version; beta versions are rejected.
     *
     * @throws IllegalArgumentException if {@code version} is a beta version
     */
    public SimpleClient(String host, int port, ProtocolVersion version, ClientEncryptionOptions encryptionOptions)
    {
        this(host, port, version, false, encryptionOptions);
    }

    /** Creates a client using {@link ProtocolVersion#CURRENT}. */
    public SimpleClient(String host, int port, ClientEncryptionOptions encryptionOptions)
    {
        this(host, port, ProtocolVersion.CURRENT, encryptionOptions);
    }

    /** Creates a client for the given protocol version with default-constructed encryption options. */
    public SimpleClient(String host, int port, ProtocolVersion version)
    {
        this(host, port, version, new ClientEncryptionOptions());
    }

    /**
     * Creates a client without connecting; call {@link #connect(boolean)} to open the channel.
     *
     * @param host              host name or address to connect to
     * @param port              TCP port to connect to
     * @param version           protocol version to use
     * @param useBeta           whether a beta {@code version} is acceptable
     * @param encryptionOptions client encryption settings; when {@code enabled} an SSL handler is installed
     * @throws IllegalArgumentException if {@code version} is a beta version and {@code useBeta} is false
     */
    public SimpleClient(String host, int port, ProtocolVersion version, boolean useBeta, ClientEncryptionOptions encryptionOptions)
    {
        this.host = host;
        this.port = port;
        if (version.isBeta() && !useBeta)
            throw new IllegalArgumentException(String.format("Beta version of server used (%s), but USE_BETA flag is not set", version));

        this.version = version;
        this.encryptionOptions = encryptionOptions;
    }

    /** Creates a client using {@link ProtocolVersion#CURRENT} and default-constructed encryption options. */
    public SimpleClient(String host, int port)
    {
        this(host, port, new ClientEncryptionOptions());
    }

    /**
     * Connects without asking the server to throw on overload.
     *
     * @see #connect(boolean, boolean)
     */
    public void connect(boolean useCompression) throws IOException
    {
        connect(useCompression, false);
    }

    /**
     * Opens the channel and sends a STARTUP message requesting CQL version {@code 3.0.0}.
     *
     * @param useCompression  when true, requests {@code snappy} compression and installs the Snappy
     *                        compressor on the connection
     * @param throwOnOverload when true, adds the throw-on-overload startup option; the flag is also
     *                        recorded on the connection either way
     * @throws IOException if the channel cannot be established
     */
    public void connect(boolean useCompression, boolean throwOnOverload) throws IOException
    {
        establishConnection();

        Map<String, String> options = new HashMap<>();
        options.put(StartupMessage.CQL_VERSION, "3.0.0");

        if (throwOnOverload)
            options.put(StartupMessage.THROW_ON_OVERLOAD, "1");
        connection.setThrowOnOverload(throwOnOverload);

        if (useCompression)
        {
            options.put(StartupMessage.COMPRESSION, "snappy");
            connection.setCompressor(FrameCompressor.SnappyCompressor.instance);
        }

        execute(new StartupMessage(options));
    }

    /** Sets the handler that receives server-pushed events; {@code null} causes events to be dropped. */
    public void setEventHandler(EventHandler eventHandler)
    {
        responseHandler.eventHandler = eventHandler;
    }

    /**
     * Builds the Netty bootstrap and synchronously connects to {@link #host}:{@link #port}.
     *
     * @throws IOException if the SSL context cannot be created or the connection attempt fails;
     *                     on a failed connection attempt the event loop group is shut down first
     */
    protected void establishConnection() throws IOException
    {
        // Build the initializer before the event loop group: SecureInitializer's constructor can
        // throw, and creating the group first would leak its threads in that case.
        Initializer initializer = encryptionOptions.enabled ? new SecureInitializer() : new Initializer();

        bootstrap = new Bootstrap()
                    .group(new NioEventLoopGroup())
                    .channel(io.netty.channel.socket.nio.NioSocketChannel.class)
                    .option(ChannelOption.TCP_NODELAY, true)
                    .handler(initializer);

        ChannelFuture future = bootstrap.connect(new InetSocketAddress(host, port));

        // Wait until the connection attempt succeeds or fails.
        channel = future.awaitUninterruptibly().channel();
        if (!future.isSuccess())
        {
            bootstrap.group().shutdownGracefully();
            throw new IOException("Connection Error", future.cause());
        }
    }

    /** Sends a credentials message carrying a copy of the given key/value pairs. */
    public void login(Map<String, String> credentials)
    {
        CredentialsMessage msg = new CredentialsMessage();
        msg.credentials.putAll(credentials);
        execute(msg);
    }

    /** Executes a query with no bound values at the given consistency level. */
    public ResultMessage execute(String query, ConsistencyLevel consistency)
    {
        return execute(query, Collections.<ByteBuffer>emptyList(), consistency);
    }

    /**
     * Executes a query with the given bound values.
     *
     * @param query            the query string
     * @param values           serialized values to bind
     * @param consistencyLevel consistency level passed to the query options
     * @return the server's result message
     * @throws RuntimeException if the server answers with an error message
     */
    public ResultMessage execute(String query, List<ByteBuffer> values, ConsistencyLevel consistencyLevel)
    {
        Message.Response msg = execute(new QueryMessage(query, QueryOptions.forInternalCalls(consistencyLevel, values)));
        assert msg instanceof ResultMessage;
        return (ResultMessage)msg;
    }

    /**
     * Prepares a statement.
     *
     * @throws RuntimeException if the server answers with an error message
     */
    public ResultMessage.Prepared prepare(String query)
    {
        Message.Response msg = execute(new PrepareMessage(query));
        assert msg instanceof ResultMessage.Prepared;
        return (ResultMessage.Prepared)msg;
    }

    /**
     * Executes a previously prepared statement.
     *
     * @param statementId raw bytes of the prepared statement id
     * @param values      serialized values to bind
     * @param consistency consistency level passed to the query options
     * @throws RuntimeException if the server answers with an error message
     */
    public ResultMessage executePrepared(byte[] statementId, List<ByteBuffer> values, ConsistencyLevel consistency)
    {
        Message.Response msg = execute(new ExecuteMessage(MD5Digest.wrap(statementId), QueryOptions.forInternalCalls(consistency, values)));
        assert msg instanceof ResultMessage;
        return (ResultMessage)msg;
    }

    /**
     * Waits for the last write to complete, closes the channel and shuts down the event loop group.
     * Safe to call on a client that was never (successfully) connected.
     */
    @Override
    public void close()
    {
        // Wait until all messages are flushed before closing the channel.
        if (lastWriteFuture != null)
            lastWriteFuture.awaitUninterruptibly();

        // channel / bootstrap are only assigned by establishConnection(), so guard against
        // close() being invoked before connect() or after an early connect failure.
        if (channel != null)
            channel.close().awaitUninterruptibly();

        // Shut down all event loop threads.
        if (bootstrap != null)
            bootstrap.group().shutdownGracefully();
    }

    /**
     * Writes the request to the channel and blocks until the next response is available.
     *
     * @param request the request to send; it is attached to this client's connection first
     * @return the response, never an {@link ErrorMessage}
     * @throws RuntimeException wrapping the server error if the response is an {@link ErrorMessage},
     *                          or wrapping the {@link InterruptedException} if the wait is interrupted
     *                          (the thread's interrupt flag is restored in that case)
     */
    public Message.Response execute(Message.Request request)
    {
        try
        {
            request.attach(connection);
            lastWriteFuture = channel.writeAndFlush(request);
            Message.Response msg = responseHandler.responses.take();
            if (msg instanceof ErrorMessage)
                throw new RuntimeException((Throwable)((ErrorMessage)msg).error);
            return msg;
        }
        catch (InterruptedException e)
        {
            // Preserve the interrupt for callers further up the stack.
            Thread.currentThread().interrupt();
            throw new RuntimeException(e);
        }
    }

    /** Callback for server-pushed events. */
    public interface EventHandler
    {
        void onEvent(Event event);
    }

    /** An {@link EventHandler} that simply collects every event in an unbounded queue. */
    public static class SimpleEventHandler implements EventHandler
    {
        public final LinkedBlockingQueue<Event> queue = new LinkedBlockingQueue<>();

        @Override
        public void onEvent(Event event)
        {
            queue.add(event);
        }
    }

    // Pipeline handlers shared by every SimpleClient instance.
    private static final Message.ProtocolDecoder messageDecoder = new Message.ProtocolDecoder();
    private static final Message.ProtocolEncoder messageEncoder = new Message.ProtocolEncoder(ProtocolVersionLimit.SERVER_DEFAULT);
    private static final Frame.Decompressor frameDecompressor = new Frame.Decompressor();
    private static final Frame.Compressor frameCompressor = new Frame.Compressor();
    private static final Frame.Encoder frameEncoder = new Frame.Encoder();

    /** No-op tracker: connections are not recorded and no channel is ever reported as registered. */
    private static class ConnectionTracker implements Connection.Tracker
    {
        public void addConnection(Channel ch, Connection connection) {}

        public boolean isRegistered(Event.Type type, Channel ch)
        {
            return false;
        }
    }

    /** Creates the {@link Connection} for a new channel and installs the protocol pipeline. */
    private class Initializer extends ChannelInitializer<Channel>
    {
        @Override
        protected void initChannel(Channel channel) throws Exception
        {
            connection = new Connection(channel, version, tracker);
            channel.attr(Connection.attributeKey).set(connection);

            ChannelPipeline pipeline = channel.pipeline();
            pipeline.addLast("frameDecoder", new Frame.Decoder(connectionFactory, ProtocolVersionLimit.SERVER_DEFAULT));
            pipeline.addLast("frameEncoder", frameEncoder);

            pipeline.addLast("frameDecompressor", frameDecompressor);
            pipeline.addLast("frameCompressor", frameCompressor);

            pipeline.addLast("messageDecoder", messageDecoder);
            pipeline.addLast("messageEncoder", messageEncoder);

            pipeline.addLast("handler", responseHandler);
        }
    }

    /** Same pipeline as {@link Initializer}, with a client-mode SSL handler placed first. */
    private class SecureInitializer extends Initializer
    {
        private final SSLContext sslContext;

        public SecureInitializer() throws IOException
        {
            this.sslContext = SSLFactory.createSSLContext(encryptionOptions, true);
        }

        @Override
        protected void initChannel(Channel channel) throws Exception
        {
            super.initChannel(channel);
            SSLEngine sslEngine = sslContext.createSSLEngine();
            sslEngine.setUseClientMode(true);
            // Restrict the engine to the configured cipher suites among those it supports.
            String[] suites = SSLFactory.filterCipherSuites(sslEngine.getSupportedCipherSuites(), encryptionOptions.cipher_suites);
            sslEngine.setEnabledCipherSuites(suites);
            channel.pipeline().addFirst("ssl", new SslHandler(sslEngine));
        }
    }

    /**
     * Terminal pipeline handler: forwards events to the {@link EventHandler} (if any) and hands
     * every other response to the thread blocked in {@link SimpleClient#execute(Message.Request)}.
     */
    @ChannelHandler.Sharable
    private static class ResponseHandler extends SimpleChannelInboundHandler<Message.Response>
    {
        // Fair SynchronousQueue: put() blocks until a caller is waiting in take().
        public final BlockingQueue<Message.Response> responses = new SynchronousQueue<>(true);
        // volatile: assigned through setEventHandler() and read from this handler's callbacks.
        public volatile EventHandler eventHandler;

        @Override
        public void channelRead0(ChannelHandlerContext ctx, Message.Response r)
        {
            try
            {
                if (r instanceof EventMessage)
                {
                    // Read the field once so a concurrent reset to null cannot cause an NPE.
                    EventHandler handler = eventHandler;
                    if (handler != null)
                        handler.onEvent(((EventMessage) r).event);
                }
                else
                    responses.put(r);
            }
            catch (InterruptedException ie)
            {
                // Restore the interrupt flag before propagating.
                Thread.currentThread().interrupt();
                throw new RuntimeException(ie);
            }
        }

        @Override
        public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception
        {
            // Only log when this handler is the last one in the pipeline, then keep propagating.
            if (this == ctx.pipeline().last())
                logger.error("Exception in response", cause);
            ctx.fireExceptionCaught(cause);
        }
    }
}