package io.vertx.mqtt.impl;
import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelDuplexHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.DecoderResult;
import io.netty.handler.codec.mqtt.MqttConnectMessage;
import io.netty.handler.codec.mqtt.MqttConnectReturnCode;
import io.netty.handler.codec.mqtt.MqttUnacceptableProtocolVersionException;
import io.netty.handler.timeout.IdleState;
import io.netty.handler.timeout.IdleStateEvent;
import io.netty.handler.timeout.IdleStateHandler;
import io.vertx.core.Handler;
import io.vertx.core.VertxException;
import io.vertx.core.impl.NetSocketInternal;
import io.vertx.core.logging.Logger;
import io.vertx.core.logging.LoggerFactory;
import io.vertx.core.net.impl.VertxHandler;
import io.vertx.mqtt.MqttAuth;
import io.vertx.mqtt.MqttEndpoint;
import io.vertx.mqtt.MqttServerOptions;
import io.vertx.mqtt.MqttWill;
import io.vertx.mqtt.messages.MqttPublishMessage;
import io.vertx.mqtt.messages.MqttSubscribeMessage;
import io.vertx.mqtt.messages.MqttUnsubscribeMessage;
import java.util.UUID;
import java.util.concurrent.TimeUnit;
public class MqttServerConnection {
private static final Logger log = LoggerFactory.getLogger(MqttServerConnection.class);
private Handler<MqttEndpoint> endpointHandler;
private Handler<Throwable> exceptionHandler;
private final NetSocketInternal so;
private MqttEndpointImpl endpoint;
private final ChannelHandlerContext chctx;
private final MqttServerOptions options;
void init(Handler<MqttEndpoint> endpointHandler, Handler<Throwable> rejectHandler) {
this.endpointHandler = endpointHandler;
this.exceptionHandler = rejectHandler;
}
public MqttServerConnection(NetSocketInternal so, MqttServerOptions options) {
this.so = so;
this.chctx = so.channelHandlerContext();
this.options = options;
}
void handleMessage(Object msg) {
if (msg instanceof io.netty.handler.codec.mqtt.MqttMessage) {
io.netty.handler.codec.mqtt.MqttMessage mqttMessage = (io.netty.handler.codec.mqtt.MqttMessage) msg;
DecoderResult result = mqttMessage.decoderResult();
if (result.isFailure()) {
Throwable cause = result.cause();
if (cause instanceof MqttUnacceptableProtocolVersionException) {
endpoint = new MqttEndpointImpl(so, null, null, null, false, 0, null, 0);
endpoint.reject(MqttConnectReturnCode.CONNECTION_REFUSED_UNACCEPTABLE_PROTOCOL_VERSION);
} else {
chctx.pipeline().fireExceptionCaught(result.cause());
}
return;
}
if (!result.isFinished()) {
chctx.pipeline().fireExceptionCaught(new Exception("Unfinished message"));
return;
}
switch (mqttMessage.fixedHeader().messageType()) {
case CONNECT:
handleConnect((MqttConnectMessage) msg);
break;
case SUBSCRIBE:
io.netty.handler.codec.mqtt.MqttSubscribeMessage subscribe = (io.netty.handler.codec.mqtt.MqttSubscribeMessage) mqttMessage;
MqttSubscribeMessage mqttSubscribeMessage = MqttSubscribeMessage.create(
subscribe.variableHeader().messageId(),
subscribe.payload().topicSubscriptions());
this.handleSubscribe(mqttSubscribeMessage);
break;
case UNSUBSCRIBE:
io.netty.handler.codec.mqtt.MqttUnsubscribeMessage unsubscribe = (io.netty.handler.codec.mqtt.MqttUnsubscribeMessage) mqttMessage;
MqttUnsubscribeMessage mqttUnsubscribeMessage = MqttUnsubscribeMessage.create(
unsubscribe.variableHeader().messageId(),
unsubscribe.payload().topics());
this.handleUnsubscribe(mqttUnsubscribeMessage);
break;
case PUBLISH:
io.netty.handler.codec.mqtt.MqttPublishMessage publish = (io.netty.handler.codec.mqtt.MqttPublishMessage) mqttMessage;
ByteBuf newBuf = VertxHandler.safeBuffer(publish.payload(), this.chctx.alloc());
MqttPublishMessage mqttPublishMessage = MqttPublishMessage.create(
publish.variableHeader().packetId(),
publish.fixedHeader().qosLevel(),
publish.fixedHeader().isDup(),
publish.fixedHeader().isRetain(),
publish.variableHeader().topicName(),
newBuf);
this.handlePublish(mqttPublishMessage);
break;
case PUBACK:
io.netty.handler.codec.mqtt.MqttPubAckMessage mqttPubackMessage = (io.netty.handler.codec.mqtt.MqttPubAckMessage) mqttMessage;
this.handlePuback(mqttPubackMessage.variableHeader().messageId());
break;
case PUBREC:
int pubrecMessageId = ((io.netty.handler.codec.mqtt.MqttMessageIdVariableHeader) mqttMessage.variableHeader()).messageId();
this.handlePubrec(pubrecMessageId);
break;
case PUBREL:
int pubrelMessageId = ((io.netty.handler.codec.mqtt.MqttMessageIdVariableHeader) mqttMessage.variableHeader()).messageId();
this.handlePubrel(pubrelMessageId);
break;
case PUBCOMP:
int pubcompMessageId = ((io.netty.handler.codec.mqtt.MqttMessageIdVariableHeader) mqttMessage.variableHeader()).messageId();
this.handlePubcomp(pubcompMessageId);
break;
case PINGREQ:
this.handlePingreq();
break;
case DISCONNECT:
this.handleDisconnect();
break;
default:
this.chctx.fireExceptionCaught(new Exception("Wrong MQTT message type " + mqttMessage.fixedHeader().messageType()));
break;
}
} else {
this.chctx.fireExceptionCaught(new Exception("Wrong message type " + msg.getClass().getName()));
}
}
private void handleConnect(MqttConnectMessage msg) {
if (endpoint != null) {
endpoint.close();
return;
}
MqttWill will =
new MqttWill(msg.variableHeader().isWillFlag(),
msg.payload().willTopic(),
msg.payload().willMessageInBytes(),
msg.variableHeader().willQos(),
msg.variableHeader().isWillRetain());
MqttAuth auth = (msg.variableHeader().hasUserName() &&
msg.variableHeader().hasPassword()) ?
new MqttAuth(
msg.payload().userName(),
msg.payload().password()) : null;
boolean isZeroBytes = (msg.payload().clientIdentifier() == null) ||
msg.payload().clientIdentifier().isEmpty();
String clientIdentifier = null;
if (!isZeroBytes) {
clientIdentifier = msg.payload().clientIdentifier();
} else if (this.options.isAutoClientId()) {
clientIdentifier = UUID.randomUUID().toString();
}
this.endpoint =
new MqttEndpointImpl(
so,
clientIdentifier,
auth,
will,
msg.variableHeader().isCleanSession(),
msg.variableHeader().version(),
msg.variableHeader().name(),
msg.variableHeader().keepAliveTimeSeconds());
chctx.pipeline().remove("idle");
chctx.pipeline().remove("timeoutOnConnect");
if (msg.variableHeader().keepAliveTimeSeconds() != 0) {
int timeout = msg.variableHeader().keepAliveTimeSeconds() +
msg.variableHeader().keepAliveTimeSeconds() / 2;
chctx.pipeline().addBefore("handler", "idle", new IdleStateHandler(timeout, 0, 0));
chctx.pipeline().addBefore("handler", "keepAliveHandler", new ChannelDuplexHandler() {
@Override
public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception {
if (evt instanceof IdleStateEvent) {
IdleStateEvent e = (IdleStateEvent) evt;
if (e.state() == IdleState.READER_IDLE) {
endpoint.close();
}
}
}
});
}
if (isZeroBytes && !msg.variableHeader().isCleanSession()) {
if (this.exceptionHandler != null) {
this.exceptionHandler.handle(new VertxException("With zero-length client-id, clean session MUST be true"));
}
this.endpoint.reject(MqttConnectReturnCode.CONNECTION_REFUSED_IDENTIFIER_REJECTED);
} else {
this.so.exceptionHandler(t -> {
this.endpoint.handleException(t);
});
this.so.closeHandler(v -> this.endpoint.handleClosed());
this.endpointHandler.handle(this.endpoint);
}
}
void handleSubscribe(MqttSubscribeMessage msg) {
synchronized (this.so) {
if (this.checkConnected()) {
this.endpoint.handleSubscribe(msg);
}
}
}
void handleUnsubscribe(MqttUnsubscribeMessage msg) {
synchronized (this.so) {
if (this.checkConnected()) {
this.endpoint.handleUnsubscribe(msg);
}
}
}
void handlePublish(MqttPublishMessage msg) {
synchronized (this.so) {
if (this.checkConnected()) {
this.endpoint.handlePublish(msg);
}
}
}
void handlePuback(int pubackMessageId) {
synchronized (this.so) {
if (this.checkConnected()) {
this.endpoint.handlePuback(pubackMessageId);
}
}
}
void handlePubrec(int pubrecMessageId) {
synchronized (this.so) {
if (this.checkConnected()) {
this.endpoint.handlePubrec(pubrecMessageId);
}
}
}
void handlePubrel(int pubrelMessageId) {
synchronized (this.so) {
if (this.checkConnected()) {
this.endpoint.handlePubrel(pubrelMessageId);
}
}
}
void handlePubcomp(int pubcompMessageId) {
synchronized (this.so) {
if (this.checkConnected()) {
this.endpoint.handlePubcomp(pubcompMessageId);
}
}
}
void handlePingreq() {
synchronized (this.so) {
if (this.checkConnected()) {
this.endpoint.handlePingreq();
}
}
}
void handleDisconnect() {
synchronized (this.so) {
if (this.checkConnected()) {
this.endpoint.handleDisconnect();
}
}
}
private boolean checkConnected() {
synchronized (this.so) {
if ((this.endpoint != null) && (this.endpoint.isConnected())) {
return true;
} else {
so.close();
throw new IllegalStateException("Received an MQTT packet from a not connected client (CONNECT not sent yet)");
}
}
}
}