package org.apache.cassandra.transport;
import java.io.IOException;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.UnknownHostException;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicBoolean;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLEngine;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.atomic.AtomicInteger;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import io.netty.bootstrap.ServerBootstrap;
import io.netty.buffer.ByteBuf;
import io.netty.channel.*;
import io.netty.channel.epoll.EpollEventLoopGroup;
import io.netty.channel.epoll.EpollServerSocketChannel;
import io.netty.channel.group.ChannelGroup;
import io.netty.channel.group.DefaultChannelGroup;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.nio.NioServerSocketChannel;
import io.netty.handler.codec.ByteToMessageDecoder;
import io.netty.handler.ssl.SslHandler;
import io.netty.util.Version;
import io.netty.util.concurrent.EventExecutor;
import io.netty.util.concurrent.GlobalEventExecutor;
import io.netty.util.internal.logging.InternalLoggerFactory;
import io.netty.util.internal.logging.Slf4JLoggerFactory;
import org.apache.cassandra.config.DatabaseDescriptor;
import org.apache.cassandra.config.EncryptionOptions;
import org.apache.cassandra.db.marshal.AbstractType;
import org.apache.cassandra.net.ResourceLimits;
import org.apache.cassandra.security.SSLFactory;
import org.apache.cassandra.service.*;
import org.apache.cassandra.transport.messages.EventMessage;
import org.apache.cassandra.utils.FBUtilities;
public class Server implements CassandraDaemon.Server
{
// Runs once when Server is initialized: sets Netty's default InternalLoggerFactory
// to a Slf4JLoggerFactory.
static
{
InternalLoggerFactory.setDefaultFactory(new Slf4JLoggerFactory());
}
// Class-wide SLF4J logger.
private static final Logger logger = LoggerFactory.getLogger(Server.class);
// Result of NativeTransportService.useEpoll(), read once; selects the Epoll vs NIO Netty classes used below.
private static final boolean useEpoll = NativeTransportService.useEpoll();
// Tracks the channels owned by this server: the listening channel added in start() and the
// connections registered through addConnection().
private final ConnectionTracker connectionTracker = new ConnectionTracker();
// Factory handed to Frame.Decoder; every connection it creates is a ServerConnection bound to
// this server's connection tracker.
private final Connection.Factory connectionFactory = new Connection.Factory()
{
public Connection newConnection(Channel channel, ProtocolVersion version)
{
return new ServerConnection(channel, version, connectionTracker);
}
};
// Address that start() binds to, as resolved by the Builder.
public final InetSocketAddress socket;
// Whether start() installs one of the SSL-capable channel initializers; overwritten from the
// Builder in the constructor.
public boolean useSSL = false;
// Set to true at the end of a successful start() and flipped back to false by stop().
private final AtomicBoolean isRunning = new AtomicBoolean(false);
// Event loop group given to the ServerBootstrap; supplied by the Builder or created in the constructor.
private EventLoopGroup workerGroup;
// Passed to Frame.Decoder and Message.ProtocolEncoder when a channel pipeline is built.
private final ProtocolVersionLimit protocolVersionLimit;
private Server (Builder builder)
{
this.socket = builder.getSocket();
this.useSSL = builder.useSSL;
this.protocolVersionLimit = builder.getProtocolVersionLimit();
if (builder.workerGroup != null)
{
workerGroup = builder.workerGroup;
}
else
{
if (useEpoll)
workerGroup = new EpollEventLoopGroup();
else
workerGroup = new NioEventLoopGroup();
}
EventNotifier notifier = new EventNotifier(this);
StorageService.instance.register(notifier);
MigrationManager.instance.register(notifier);
}
/**
 * Stops the server if it is currently running.
 * <p>
 * Only the caller that flips the running flag from {@code true} to {@code false} goes on to
 * close the tracked channels; any other call is a no-op.
 */
public void stop()
{
    final boolean wasRunning = isRunning.compareAndSet(true, false);
    if (wasRunning)
        close();
}
/**
 * Reports the current value of the running flag.
 *
 * @return {@code true} once {@link #start()} has completed successfully and until
 *         {@link #stop()} resets the flag
 */
public boolean isRunning()
{
    return isRunning.get();
}
/**
 * Binds the configured socket and starts accepting CQL client connections.
 * <p>
 * Returns immediately when the server is already running. Otherwise a {@link ServerBootstrap}
 * is configured with the child channel options below and with the channel initializer matching
 * the encryption settings (plain, mandatory SSL, or optional SSL), and the call blocks until
 * the bind attempt completes. On success the listening channel is added to the connection
 * tracker and the running flag is set.
 *
 * @throws IllegalStateException if the socket could not be bound; the failure reported by
 *                               Netty is attached as the cause
 */
public synchronized void start()
{
    if (isRunning())
        return;

    // Configure the server.
    ServerBootstrap bootstrap = new ServerBootstrap()
                                .channel(useEpoll ? EpollServerSocketChannel.class : NioServerSocketChannel.class)
                                .childOption(ChannelOption.TCP_NODELAY, true)
                                .childOption(ChannelOption.SO_LINGER, 0)
                                .childOption(ChannelOption.SO_KEEPALIVE, DatabaseDescriptor.getRpcKeepAlive())
                                .childOption(ChannelOption.ALLOCATOR, CBUtil.allocator)
                                .childOption(ChannelOption.WRITE_BUFFER_HIGH_WATER_MARK, 32 * 1024)
                                .childOption(ChannelOption.WRITE_BUFFER_LOW_WATER_MARK, 8 * 1024);
    if (workerGroup != null)
        bootstrap = bootstrap.group(workerGroup);

    if (this.useSSL)
    {
        final EncryptionOptions.ClientEncryptionOptions clientEnc = DatabaseDescriptor.getClientEncryptionOptions();

        if (clientEnc.optional)
        {
            logger.info("Enabling optionally encrypted CQL connections between client and server");
            bootstrap.childHandler(new OptionalSecureInitializer(this, clientEnc));
        }
        else
        {
            logger.info("Enabling encrypted CQL connections between client and server");
            bootstrap.childHandler(new SecureInitializer(this, clientEnc));
        }
    }
    else
    {
        bootstrap.childHandler(new Initializer(this));
    }

    // Bind and start to accept incoming connections.
    logger.info("Using Netty Version: {}", Version.identify().entrySet());
    logger.info("Starting listening for CQL clients on {} ({})...", socket, this.useSSL ? "encrypted" : "unencrypted");

    ChannelFuture bindFuture = bootstrap.bind(socket);
    if (!bindFuture.awaitUninterruptibly().isSuccess())
    {
        // Keep the underlying failure (e.g. address already in use) so it is not lost to the caller.
        throw new IllegalStateException(String.format("Failed to bind port %d on %s.", socket.getPort(), socket.getAddress().getHostAddress()),
                                        bindFuture.cause());
    }

    // Track the listening channel so that closing the tracker also stops accepting connections.
    connectionTracker.allChannels.add(bindFuture.channel());
    isRunning.set(true);
}
/**
 * Returns the number of connected clients as reported by the connection tracker.
 *
 * @return the tracker's connected-client count
 */
public int getConnectedClients()
{
    final int clients = connectionTracker.getConnectedClients();
    return clients;
}
/**
 * Closes every channel held by the connection tracker, then logs that the server stopped
 * listening. Called by {@link #stop()} after the running flag has been cleared.
 */
private void close()
{
    // Close opened connections (this includes the listening channel added in start()).
    connectionTracker.closeAll();

    logger.info("Stop listening for CQL clients");
}
/**
 * Fluent builder for {@link Server}.
 * <p>
 * A port, a host and a protocol version limit must be provided before {@link #build()};
 * SSL defaults to off and the event loop group is optional.
 */
public static class Builder
{
    private EventLoopGroup workerGroup;
    // Not read or written anywhere in this file.
    private EventExecutor eventExecutorGroup;
    private boolean useSSL = false;
    private InetAddress hostAddr;
    // -1 marks "no port configured yet".
    private int port = -1;
    // Lazily built from hostAddr/port; reset whenever either of them changes.
    private InetSocketAddress socket;
    private ProtocolVersionLimit versionLimit;

    /**
     * @param useSSL whether the server should use one of the SSL-capable channel initializers
     * @return this builder
     */
    public Builder withSSL(boolean useSSL)
    {
        this.useSSL = useSSL;
        return this;
    }

    /**
     * @param eventLoopGroup the event loop group the server should use instead of creating its own
     * @return this builder
     */
    public Builder withEventLoopGroup(EventLoopGroup eventLoopGroup)
    {
        this.workerGroup = eventLoopGroup;
        return this;
    }

    /**
     * Sets the host to bind to and discards any previously computed socket address.
     *
     * @param host the address to bind to
     * @return this builder
     */
    public Builder withHost(InetAddress host)
    {
        this.socket = null;
        this.hostAddr = host;
        return this;
    }

    /**
     * Sets the port to bind to and discards any previously computed socket address.
     *
     * @param port the port to bind to
     * @return this builder
     */
    public Builder withPort(int port)
    {
        this.socket = null;
        this.port = port;
        return this;
    }

    /**
     * @param limit the protocol version limit handed to the frame decoder and message encoder
     * @return this builder
     */
    public Builder withProtocolVersionLimit(ProtocolVersionLimit limit)
    {
        this.versionLimit = limit;
        return this;
    }

    /**
     * @return the configured protocol version limit
     * @throws IllegalArgumentException if no limit has been set
     */
    ProtocolVersionLimit getProtocolVersionLimit()
    {
        if (versionLimit != null)
            return versionLimit;

        throw new IllegalArgumentException("Missing protocol version limiter");
    }

    /**
     * @return a new server configured from this builder
     */
    public Server build()
    {
        return new Server(this);
    }

    /**
     * Returns the socket address to bind to, creating and caching it on first use.
     *
     * @return the socket address built from the configured host and port
     * @throws IllegalStateException if the port or the host is missing
     */
    private InetSocketAddress getSocket()
    {
        if (socket == null)
        {
            // The port is validated before the host, so a builder missing both reports the port.
            if (port == -1)
                throw new IllegalStateException("Missing port number");
            if (hostAddr == null)
                throw new IllegalStateException("Missing host");

            socket = new InetSocketAddress(hostAddr, port);
        }
        return socket;
    }
}
/**
 * Keeps track of the channels belonging to a server: one group holding every channel, plus one
 * group per {@link Event.Type} holding the channels registered for that type of event.
 */
public static class ConnectionTracker implements Connection.Tracker
{
    // Every channel added through addConnection(); Server.start() also adds its listening channel here.
    public final ChannelGroup allChannels = new DefaultChannelGroup(GlobalEventExecutor.INSTANCE);
    // One channel group per event type, fully populated in the constructor.
    private final EnumMap<Event.Type, ChannelGroup> groups = new EnumMap<>(Event.Type.class);

    public ConnectionTracker()
    {
        for (Event.Type type : Event.Type.values())
            groups.put(type, new DefaultChannelGroup(type.toString(), GlobalEventExecutor.INSTANCE));
    }

    /**
     * Starts tracking a channel.
     *
     * @param ch         the channel to track
     * @param connection the connection bound to the channel (not used by this implementation)
     */
    public void addConnection(Channel ch, Connection connection)
    {
        allChannels.add(ch);
    }

    /**
     * Subscribes a channel to events of the given type.
     *
     * @param type the event type to subscribe to
     * @param ch   the channel that should receive those events
     */
    public void register(Event.Type type, Channel ch)
    {
        groups.get(type).add(ch);
    }

    /**
     * Writes and flushes the event, wrapped in an {@link EventMessage}, to every channel
     * registered for the event's type.
     *
     * @param event the event to broadcast
     */
    public void send(Event event)
    {
        groups.get(event.type).writeAndFlush(new EventMessage(event));
    }

    /**
     * Closes every tracked channel and blocks until the close operation has completed.
     */
    public void closeAll()
    {
        allChannels.close().awaitUninterruptibly();
    }

    /**
     * Returns the number of tracked channels minus one, or zero when no channel is tracked.
     * The subtraction presumably excludes the listening channel that Server.start() adds.
     *
     * @return the number of connected clients, never negative
     */
    public int getConnectedClients()
    {
        // Read the size once: the group can change concurrently, and two separate reads could
        // observe a non-empty group followed by an empty one and yield -1.
        final int channels = allChannels.size();
        return channels != 0 ? channels - 1 : 0;
    }
}
// Single limit shared by every EndpointPayloadTracker (passed as the second argument of each
// tracker's ResourceLimits.EndpointAndGlobal); its capacity comes from
// DatabaseDescriptor.getNativeTransportMaxConcurrentRequestsInBytes().
private static final ResourceLimits.Concurrent globalRequestPayloadInFlight = new ResourceLimits.Concurrent(DatabaseDescriptor.getNativeTransportMaxConcurrentRequestsInBytes());
/**
 * Reference-counted holder of the request payload limits for one client address.
 * <p>
 * Trackers are shared through a static map keyed by {@link InetAddress}: {@link #get} hands out
 * the tracker for an address (creating it if needed) and takes a reference on it, and
 * {@link #release} gives that reference back. When the last reference is released the tracker
 * is retired by setting its count to a negative value, after which it can no longer be acquired,
 * and it is removed from the map.
 */
public static class EndpointPayloadTracker
{
// One tracker per client address.
private static final ConcurrentMap<InetAddress, EndpointPayloadTracker> requestPayloadInFlightPerEndpoint = new ConcurrentHashMap<>();
// Number of outstanding references; a negative value marks a retired tracker.
private final AtomicInteger refCount = new AtomicInteger(0);
// The address this tracker is registered under in the map.
private final InetAddress endpoint;
// Pairs a per-tracker limit (capacity from getNativeTransportMaxConcurrentRequestsInBytesPerIp())
// with the limit shared by all trackers.
final ResourceLimits.EndpointAndGlobal endpointAndGlobalPayloadsInFlight = new ResourceLimits.EndpointAndGlobal(new ResourceLimits.Concurrent(DatabaseDescriptor.getNativeTransportMaxConcurrentRequestsInBytesPerIp()),
globalRequestPayloadInFlight);
private EndpointPayloadTracker(InetAddress endpoint)
{
this.endpoint = endpoint;
}
/**
 * Returns the tracker for the given address with one reference taken on it.
 * Each successful call should be matched by a call to {@link #release()}.
 *
 * @param endpoint the client address
 * @return the acquired tracker for that address
 */
public static EndpointPayloadTracker get(InetAddress endpoint)
{
while (true)
{
EndpointPayloadTracker result = requestPayloadInFlightPerEndpoint.computeIfAbsent(endpoint, EndpointPayloadTracker::new);
if (result.acquire())
return result;
// The mapped tracker was already retired: drop that exact mapping (if it is still there)
// and retry, so that a fresh tracker gets created.
requestPayloadInFlightPerEndpoint.remove(endpoint, result);
}
}
/**
 * Tries to take a reference: the count is incremented unless it is negative.
 *
 * @return true if the resulting count is positive, false if the tracker has been retired
 */
private boolean acquire()
{
return 0 < refCount.updateAndGet(i -> i < 0 ? i : i + 1);
}
/**
 * Gives back one reference. A count of 1 is replaced by -1 rather than 0, which retires the
 * tracker; when the resulting count is -1 the tracker is removed from the map, provided the
 * map still holds this very instance for its address.
 */
public void release()
{
if (-1 == refCount.updateAndGet(i -> i == 1 ? -1 : i - 1))
requestPayloadInFlightPerEndpoint.remove(endpoint, this);
}
}
/**
 * Channel initializer for unencrypted connections; also the base class of the SSL variants,
 * which add their own handler in front of the pipeline built here.
 */
private static class Initializer extends ChannelInitializer<Channel>
{
    // Handler instances that are reused for every channel initialized by this class.
    private static final Message.ProtocolDecoder messageDecoder = new Message.ProtocolDecoder();
    private static final Frame.Decompressor frameDecompressor = new Frame.Decompressor();
    private static final Frame.Compressor frameCompressor = new Frame.Compressor();
    private static final Frame.Encoder frameEncoder = new Frame.Encoder();
    private static final Message.ExceptionHandler exceptionHandler = new Message.ExceptionHandler();
    private static final ConnectionLimitHandler connectionLimitHandler = new ConnectionLimitHandler();

    private final Server server;

    public Initializer(Server server)
    {
        this.server = server;
    }

    /**
     * Populates the pipeline of a newly accepted channel: an optional connection limit handler
     * first, then frame decoding/encoding, frame (de)compression, message decoding/encoding,
     * the dispatcher, and finally the exception handler.
     *
     * @param channel the channel being initialized
     */
    protected void initChannel(Channel channel) throws Exception
    {
        final ChannelPipeline handlers = channel.pipeline();

        // The limit handler is only installed when at least one of the two limits is positive.
        final boolean limitConnections = DatabaseDescriptor.getNativeTransportMaxConcurrentConnections() > 0
                                         || DatabaseDescriptor.getNativeTransportMaxConcurrentConnectionsPerIp() > 0;
        if (limitConnections)
            handlers.addFirst("connectionLimitHandler", connectionLimitHandler);

        handlers.addLast("frameDecoder", new Frame.Decoder(server.connectionFactory, server.protocolVersionLimit));
        handlers.addLast("frameEncoder", frameEncoder);

        handlers.addLast("frameDecompressor", frameDecompressor);
        handlers.addLast("frameCompressor", frameCompressor);

        handlers.addLast("messageDecoder", messageDecoder);
        handlers.addLast("messageEncoder", new Message.ProtocolEncoder(server.protocolVersionLimit));

        // The dispatcher gets the payload tracker of the client's address; the tracker is
        // acquired here, while the pipeline is being built.
        final boolean legacyFlusher = DatabaseDescriptor.useNativeTransportLegacyFlusher();
        final InetAddress clientAddress = ((InetSocketAddress) channel.remoteAddress()).getAddress();
        handlers.addLast("executor", new Message.Dispatcher(legacyFlusher, EndpointPayloadTracker.get(clientAddress)));

        handlers.addLast("exceptionHandler", exceptionHandler);
    }
}
/**
 * Base class of the SSL-capable initializers: builds the {@link SSLContext} once at
 * construction time and creates a server-mode {@link SslHandler} on demand.
 */
protected abstract static class AbstractSecureIntializer extends Initializer
{
    private final SSLContext sslContext;
    private final EncryptionOptions encryptionOptions;

    /**
     * @param server            the server the initialized channels belong to
     * @param encryptionOptions the encryption settings used to build the SSL context and engines
     * @throws RuntimeException if the SSL context cannot be created
     */
    protected AbstractSecureIntializer(Server server, EncryptionOptions encryptionOptions)
    {
        super(server);
        this.encryptionOptions = encryptionOptions;
        this.sslContext = newContext(encryptionOptions);
    }

    // Creates the SSL context, converting the checked IOException into a RuntimeException.
    private static SSLContext newContext(EncryptionOptions options)
    {
        try
        {
            return SSLFactory.createSSLContext(options, options.require_client_auth);
        }
        catch (IOException e)
        {
            throw new RuntimeException("Failed to setup secure pipeline", e);
        }
    }

    /**
     * Creates a handler around a fresh server-mode SSL engine whose enabled cipher suites are
     * the supported suites filtered by the configured ones, and which requires client
     * authentication when the options ask for it.
     *
     * @return a new SSL handler, to be used for a single channel
     */
    protected final SslHandler createSslHandler()
    {
        final SSLEngine engine = sslContext.createSSLEngine();
        engine.setUseClientMode(false);

        final String[] enabledSuites = SSLFactory.filterCipherSuites(engine.getSupportedCipherSuites(),
                                                                     encryptionOptions.cipher_suites);
        engine.setEnabledCipherSuites(enabledSuites);
        engine.setNeedClientAuth(encryptionOptions.require_client_auth);

        return new SslHandler(engine);
    }
}
/**
 * Initializer for servers where encryption is optional: builds the regular pipeline and puts
 * a detection handler in front of it that decides, from the first bytes received, whether an
 * {@link SslHandler} is needed for the connection.
 */
private static class OptionalSecureInitializer extends AbstractSecureIntializer
{
    public OptionalSecureInitializer(Server server, EncryptionOptions encryptionOptions)
    {
        super(server, encryptionOptions);
    }

    protected void initChannel(final Channel channel) throws Exception
    {
        super.initChannel(channel);

        final ByteToMessageDecoder sslDetector = new ByteToMessageDecoder()
        {
            @Override
            protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) throws Exception
            {
                // Wait until at least 5 bytes are readable before inspecting the buffer
                // (presumably the minimum SslHandler.isEncrypted needs to decide).
                if (in.readableBytes() < 5)
                    return;

                if (!SslHandler.isEncrypted(in))
                {
                    // Plain connection: the detector simply removes itself from the pipeline.
                    ctx.pipeline().remove(this);
                    return;
                }

                // Encrypted connection: the detector is swapped for a real SSL handler.
                final SslHandler tls = createSslHandler();
                ctx.pipeline().replace(this, "ssl", tls);
            }
        };
        channel.pipeline().addFirst("sslDetectionHandler", sslDetector);
    }
}
/**
 * Initializer for servers where encryption is mandatory: every channel gets an
 * {@link SslHandler} at the very front of the regular pipeline.
 */
private static class SecureInitializer extends AbstractSecureIntializer
{
    public SecureInitializer(Server server, EncryptionOptions encryptionOptions)
    {
        super(server, encryptionOptions);
    }

    protected void initChannel(Channel channel) throws Exception
    {
        // The SSL handler is created before the rest of the pipeline is built, then placed first.
        final SslHandler tls = createSslHandler();
        super.initChannel(channel);
        channel.pipeline().addFirst("ssl", tls);
    }
}
/**
 * Immutable pair holding the last status and the last topology change recorded for an
 * endpoint; either part may be {@code null}.
 */
private static class LatestEvent
{
    public final Event.StatusChange.Status status;
    public final Event.TopologyChange.Change topology;

    private LatestEvent(Event.StatusChange.Status status, Event.TopologyChange.Change topology)
    {
        this.status = status;
        this.topology = topology;
    }

    @Override
    public String toString()
    {
        return String.format("Status %s, Topology %s", status, topology);
    }

    /**
     * Builds an entry with the given status and the topology of the previous entry.
     *
     * @param status the new status
     * @param prev   the previous entry, or {@code null} to leave the topology unset
     * @return the new entry
     */
    public static LatestEvent forStatusChange(Event.StatusChange.Status status, LatestEvent prev)
    {
        Event.TopologyChange.Change carriedTopology = null;
        if (prev != null)
            carriedTopology = prev.topology;

        return new LatestEvent(status, carriedTopology);
    }

    /**
     * Builds an entry with the given topology change and the status of the previous entry.
     *
     * @param change the new topology change
     * @param prev   the previous entry, or {@code null} to leave the status unset
     * @return the new entry
     */
    public static LatestEvent forTopologyChange(Event.TopologyChange.Change change, LatestEvent prev)
    {
        Event.StatusChange.Status carriedStatus = null;
        if (prev != null)
            carriedStatus = prev.status;

        return new LatestEvent(carriedStatus, change);
    }
}
private static class EventNotifier extends MigrationListener implements IEndpointLifecycleSubscriber
{
private final Server server;
private final Map<InetAddress, LatestEvent> latestEvents = new ConcurrentHashMap<>();
private final Set<InetAddress> endpointsPendingJoinedNotification = ConcurrentHashMap.newKeySet();
private static final InetAddress bindAll;
static
{
try
{
bindAll = InetAddress.getByAddress(new byte[4]);
}
catch (UnknownHostException e)
{
throw new AssertionError(e);
}
}
private EventNotifier(Server server)
{
this.server = server;
}
private InetAddress getRpcAddress(InetAddress endpoint)
{
try
{
InetAddress rpcAddress = InetAddress.getByName(StorageService.instance.getRpcaddress(endpoint));
return rpcAddress.equals(bindAll) ? endpoint : rpcAddress;
}
catch (UnknownHostException e)
{
logger.error("Problem retrieving RPC address for {}", endpoint, e);
return endpoint;
}
}
private void send(InetAddress endpoint, Event.NodeEvent event)
{
if (logger.isTraceEnabled())
logger.trace("Sending event for endpoint {}, rpc address {}", endpoint, event.nodeAddress());
if (!endpoint.equals(FBUtilities.getBroadcastAddress()) &&
event.nodeAddress().equals(FBUtilities.getBroadcastRpcAddress()))
return;
send(event);
}
private void send(Event event)
{
server.connectionTracker.send(event);
}
public void onJoinCluster(InetAddress endpoint)
{
if (!StorageService.instance.isRpcReady(endpoint))
endpointsPendingJoinedNotification.add(endpoint);
else
onTopologyChange(endpoint, Event.TopologyChange.newNode(getRpcAddress(endpoint), server.socket.getPort()));
}
public void onLeaveCluster(InetAddress endpoint)
{
onTopologyChange(endpoint, Event.TopologyChange.removedNode(getRpcAddress(endpoint), server.socket.getPort()));
}
public void onMove(InetAddress endpoint)
{
onTopologyChange(endpoint, Event.TopologyChange.movedNode(getRpcAddress(endpoint), server.socket.getPort()));
}
public void onUp(InetAddress endpoint)
{
if (endpointsPendingJoinedNotification.remove(endpoint))
onJoinCluster(endpoint);
onStatusChange(endpoint, Event.StatusChange.nodeUp(getRpcAddress(endpoint), server.socket.getPort()));
}
public void onDown(InetAddress endpoint)
{
onStatusChange(endpoint, Event.StatusChange.nodeDown(getRpcAddress(endpoint), server.socket.getPort()));
}
private void onTopologyChange(InetAddress endpoint, Event.TopologyChange event)
{
if (logger.isTraceEnabled())
logger.trace("Topology changed event : {}, {}", endpoint, event.change);
LatestEvent prev = latestEvents.get(endpoint);
if (prev == null || prev.topology != event.change)
{
LatestEvent ret = latestEvents.put(endpoint, LatestEvent.forTopologyChange(event.change, prev));
if (ret == prev)
send(endpoint, event);
}
}
private void onStatusChange(InetAddress endpoint, Event.StatusChange event)
{
if (logger.isTraceEnabled())
logger.trace("Status changed event : {}, {}", endpoint, event.status);
LatestEvent prev = latestEvents.get(endpoint);
if (prev == null || prev.status != event.status)
{
LatestEvent ret = latestEvents.put(endpoint, LatestEvent.forStatusChange(event.status, null));
if (ret == prev)
send(endpoint, event);
}
}
public void onCreateKeyspace(String ksName)
{
send(new Event.SchemaChange(Event.SchemaChange.Change.CREATED, ksName));
}
public void onCreateColumnFamily(String ksName, String cfName)
{
send(new Event.SchemaChange(Event.SchemaChange.Change.CREATED, Event.SchemaChange.Target.TABLE, ksName, cfName));
}
public void onCreateUserType(String ksName, String typeName)
{
send(new Event.SchemaChange(Event.SchemaChange.Change.CREATED, Event.SchemaChange.Target.TYPE, ksName, typeName));
}
public void onCreateFunction(String ksName, String functionName, List<AbstractType<?>> argTypes)
{
send(new Event.SchemaChange(Event.SchemaChange.Change.CREATED, Event.SchemaChange.Target.FUNCTION,
ksName, functionName, AbstractType.asCQLTypeStringList(argTypes)));
}
public void onCreateAggregate(String ksName, String aggregateName, List<AbstractType<?>> argTypes)
{
send(new Event.SchemaChange(Event.SchemaChange.Change.CREATED, Event.SchemaChange.Target.AGGREGATE,
ksName, aggregateName, AbstractType.asCQLTypeStringList(argTypes)));
}
public void onUpdateKeyspace(String ksName)
{
send(new Event.SchemaChange(Event.SchemaChange.Change.UPDATED, ksName));
}
public void onUpdateColumnFamily(String ksName, String cfName, boolean affectsStatements)
{
send(new Event.SchemaChange(Event.SchemaChange.Change.UPDATED, Event.SchemaChange.Target.TABLE, ksName, cfName));
}
public void onUpdateUserType(String ksName, String typeName)
{
send(new Event.SchemaChange(Event.SchemaChange.Change.UPDATED, Event.SchemaChange.Target.TYPE, ksName, typeName));
}
public void onUpdateFunction(String ksName, String functionName, List<AbstractType<?>> argTypes)
{
send(new Event.SchemaChange(Event.SchemaChange.Change.UPDATED, Event.SchemaChange.Target.FUNCTION,
ksName, functionName, AbstractType.asCQLTypeStringList(argTypes)));
}
public void onUpdateAggregate(String ksName, String aggregateName, List<AbstractType<?>> argTypes)
{
send(new Event.SchemaChange(Event.SchemaChange.Change.UPDATED, Event.SchemaChange.Target.AGGREGATE,
ksName, aggregateName, AbstractType.asCQLTypeStringList(argTypes)));
}
public void onDropKeyspace(String ksName)
{
send(new Event.SchemaChange(Event.SchemaChange.Change.DROPPED, ksName));
}
public void onDropColumnFamily(String ksName, String cfName)
{
send(new Event.SchemaChange(Event.SchemaChange.Change.DROPPED, Event.SchemaChange.Target.TABLE, ksName, cfName));
}
public void onDropUserType(String ksName, String typeName)
{
send(new Event.SchemaChange(Event.SchemaChange.Change.DROPPED, Event.SchemaChange.Target.TYPE, ksName, typeName));
}
public void onDropFunction(String ksName, String functionName, List<AbstractType<?>> argTypes)
{
send(new Event.SchemaChange(Event.SchemaChange.Change.DROPPED, Event.SchemaChange.Target.FUNCTION,
ksName, functionName, AbstractType.asCQLTypeStringList(argTypes)));
}
public void onDropAggregate(String ksName, String aggregateName, List<AbstractType<?>> argTypes)
{
send(new Event.SchemaChange(Event.SchemaChange.Change.DROPPED, Event.SchemaChange.Target.AGGREGATE,
ksName, aggregateName, AbstractType.asCQLTypeStringList(argTypes)));
}
}
}