package io.vertx.mqtt.impl;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.handler.codec.mqtt.MqttConnAckVariableHeader;
import io.netty.handler.codec.mqtt.MqttConnectReturnCode;
import io.netty.handler.codec.mqtt.MqttFixedHeader;
import io.netty.handler.codec.mqtt.MqttMessageFactory;
import io.netty.handler.codec.mqtt.MqttMessageIdVariableHeader;
import io.netty.handler.codec.mqtt.MqttMessageType;
import io.netty.handler.codec.mqtt.MqttPublishVariableHeader;
import io.netty.handler.codec.mqtt.MqttQoS;
import io.netty.handler.codec.mqtt.MqttSubAckPayload;
import io.vertx.core.AsyncResult;
import io.vertx.core.Future;
import io.vertx.core.Handler;
import io.vertx.core.buffer.Buffer;
import io.vertx.core.impl.NetSocketInternal;
import io.vertx.core.logging.Logger;
import io.vertx.core.logging.LoggerFactory;
import io.vertx.core.net.SocketAddress;
import io.vertx.mqtt.MqttAuth;
import io.vertx.mqtt.MqttEndpoint;
import io.vertx.mqtt.MqttTopicSubscription;
import io.vertx.mqtt.MqttWill;
import javax.net.ssl.SSLSession;
import java.util.List;
import java.util.stream.Collectors;
/**
 * Server-side representation of a single connected MQTT client.
 * <p>
 * All mutable state of the endpoint (registered handlers, auto-ack flags, connection flags and
 * the packet-identifier counter) is guarded by the monitor of the underlying
 * {@link NetSocketInternal}. The package-private {@code handleXxx} methods dispatch incoming
 * protocol events to the user-registered handlers. Presumably they are invoked by the
 * connection/decoder layer; verify against the caller.
 */
public class MqttEndpointImpl implements MqttEndpoint {

  // Largest value handed out or accepted as a packet identifier.
  private static final int MAX_MESSAGE_ID = 65535;

  private static final Logger log = LoggerFactory.getLogger(MqttEndpointImpl.class);

  // Underlying connection; its monitor also guards every mutable field below.
  private final NetSocketInternal conn;

  private String clientIdentifier;
  private final MqttAuth auth;
  private final MqttWill will;
  private final boolean isCleanSession;
  private final int protocolVersion;
  private final String protocolName;
  private final int keepAliveTimeoutSeconds;

  private Handler<io.vertx.mqtt.messages.MqttSubscribeMessage> subscribeHandler;
  private Handler<io.vertx.mqtt.messages.MqttUnsubscribeMessage> unsubscribeHandler;
  private Handler<io.vertx.mqtt.messages.MqttPublishMessage> publishHandler;
  private Handler<Integer> pubackHandler;
  private Handler<Integer> pubrecHandler;
  private Handler<Integer> pubrelHandler;
  private Handler<Integer> pubcompHandler;
  private Handler<Void> disconnectHandler;
  private Handler<Void> pingreqHandler;
  private Handler<Void> closeHandler;
  private Handler<Throwable> exceptionHandler;

  private boolean isConnected;
  private boolean isClosed;
  // Last packet identifier produced by nextMessageId(); 0 until the first allocation.
  private int messageIdCounter;

  private boolean isSubscriptionAutoAck;
  private boolean isPublishAutoAck;
  private boolean isAutoKeepAlive = true;

  /**
   * Creates an endpoint for an incoming connection.
   *
   * @param conn                    underlying network connection (also used as the lock)
   * @param clientIdentifier        client identifier sent by the client
   * @param auth                    authentication data, as provided by the caller
   * @param will                    will data, as provided by the caller
   * @param isCleanSession          clean-session flag
   * @param protocolVersion         protocol version
   * @param protocolName            protocol name
   * @param keepAliveTimeoutSeconds keep-alive timeout in seconds
   */
  public MqttEndpointImpl(NetSocketInternal conn, String clientIdentifier, MqttAuth auth, MqttWill will, boolean isCleanSession, int protocolVersion, String protocolName, int keepAliveTimeoutSeconds) {
    this.conn = conn;
    this.clientIdentifier = clientIdentifier;
    this.auth = auth;
    this.will = will;
    this.isCleanSession = isCleanSession;
    this.protocolVersion = protocolVersion;
    this.protocolName = protocolName;
    this.keepAliveTimeoutSeconds = keepAliveTimeoutSeconds;
  }

  /** @return the current client identifier (may have been changed via {@link #setClientIdentifier}) */
  public String clientIdentifier() {
    synchronized (this.conn) {
      return this.clientIdentifier;
    }
  }

  /** @return the authentication data passed at construction */
  public MqttAuth auth() {
    return this.auth;
  }

  /** @return the will data passed at construction */
  public MqttWill will() {
    return this.will;
  }

  /** @return the clean-session flag passed at construction */
  public boolean isCleanSession() {
    return this.isCleanSession;
  }

  /** @return the protocol version passed at construction */
  public int protocolVersion() {
    return this.protocolVersion;
  }

  /** @return the protocol name passed at construction */
  public String protocolName() {
    return this.protocolName;
  }

  /** @return the keep-alive timeout in seconds passed at construction */
  public int keepAliveTimeSeconds() {
    return this.keepAliveTimeoutSeconds;
  }

  /** @return the last packet identifier allocated for an outgoing PUBLISH, or 0 if none yet */
  public int lastMessageId() {
    synchronized (this.conn) {
      return this.messageIdCounter;
    }
  }

  /**
   * Enables or disables automatic SUBACK/UNSUBACK replies.
   *
   * @param isSubscriptionAutoAck {@code true} to acknowledge subscriptions automatically
   */
  public void subscriptionAutoAck(boolean isSubscriptionAutoAck) {
    synchronized (this.conn) {
      this.isSubscriptionAutoAck = isSubscriptionAutoAck;
    }
  }

  /** @return whether SUBACK/UNSUBACK are sent automatically */
  public boolean isSubscriptionAutoAck() {
    synchronized (this.conn) {
      return this.isSubscriptionAutoAck;
    }
  }

  /**
   * Enables or disables automatic acknowledgement of the QoS 1/2 publish flow
   * (PUBACK, PUBREC, PUBREL, PUBCOMP).
   *
   * @param isPublishAutoAck {@code true} to acknowledge automatically
   * @return this endpoint
   */
  public MqttEndpoint publishAutoAck(boolean isPublishAutoAck) {
    synchronized (this.conn) {
      this.isPublishAutoAck = isPublishAutoAck;
    }
    return this;
  }

  /** @return whether the publish flow is acknowledged automatically */
  public boolean isPublishAutoAck() {
    synchronized (this.conn) {
      return this.isPublishAutoAck;
    }
  }

  /**
   * Enables or disables automatic PINGRESP replies to PINGREQ.
   *
   * @param isAutoKeepAlive {@code true} to answer pings automatically (default)
   * @return this endpoint
   */
  public MqttEndpoint autoKeepAlive(boolean isAutoKeepAlive) {
    synchronized (this.conn) {
      this.isAutoKeepAlive = isAutoKeepAlive;
    }
    return this;
  }

  /** @return whether PINGREQ is answered automatically */
  public boolean isAutoKeepAlive() {
    synchronized (this.conn) {
      return this.isAutoKeepAlive;
    }
  }

  /** @return {@code true} once the connection has been accepted and until it is closed */
  public boolean isConnected() {
    synchronized (this.conn) {
      return this.isConnected;
    }
  }

  /**
   * Replaces the client identifier.
   *
   * @param clientIdentifier new identifier
   * @return this endpoint
   */
  public MqttEndpoint setClientIdentifier(String clientIdentifier) {
    synchronized (this.conn) {
      this.clientIdentifier = clientIdentifier;
    }
    return this;
  }

  /**
   * Sets the handler invoked when the client sends DISCONNECT.
   *
   * @throws IllegalStateException if the endpoint is closed
   */
  public MqttEndpointImpl disconnectHandler(Handler<Void> handler) {
    synchronized (this.conn) {
      this.checkClosed();
      this.disconnectHandler = handler;
      return this;
    }
  }

  /**
   * Sets the handler invoked for SUBSCRIBE messages.
   *
   * @throws IllegalStateException if the endpoint is closed
   */
  public MqttEndpointImpl subscribeHandler(Handler<io.vertx.mqtt.messages.MqttSubscribeMessage> handler) {
    synchronized (this.conn) {
      this.checkClosed();
      this.subscribeHandler = handler;
      return this;
    }
  }

  /**
   * Sets the handler invoked for UNSUBSCRIBE messages.
   *
   * @throws IllegalStateException if the endpoint is closed
   */
  public MqttEndpointImpl unsubscribeHandler(Handler<io.vertx.mqtt.messages.MqttUnsubscribeMessage> handler) {
    synchronized (this.conn) {
      this.checkClosed();
      this.unsubscribeHandler = handler;
      return this;
    }
  }

  /**
   * Sets the handler invoked for PUBLISH messages from the client.
   *
   * @throws IllegalStateException if the endpoint is closed
   */
  public MqttEndpointImpl publishHandler(Handler<io.vertx.mqtt.messages.MqttPublishMessage> handler) {
    synchronized (this.conn) {
      this.checkClosed();
      this.publishHandler = handler;
      return this;
    }
  }

  /**
   * Sets the handler invoked with the packet id of each received PUBACK.
   *
   * @throws IllegalStateException if the endpoint is closed
   */
  public MqttEndpointImpl publishAcknowledgeHandler(Handler<Integer> handler) {
    synchronized (this.conn) {
      this.checkClosed();
      this.pubackHandler = handler;
      return this;
    }
  }

  /**
   * Sets the handler invoked with the packet id of each received PUBREC.
   *
   * @throws IllegalStateException if the endpoint is closed
   */
  public MqttEndpointImpl publishReceivedHandler(Handler<Integer> handler) {
    synchronized (this.conn) {
      this.checkClosed();
      this.pubrecHandler = handler;
      return this;
    }
  }

  /**
   * Sets the handler invoked with the packet id of each received PUBREL.
   *
   * @throws IllegalStateException if the endpoint is closed
   */
  public MqttEndpointImpl publishReleaseHandler(Handler<Integer> handler) {
    synchronized (this.conn) {
      this.checkClosed();
      this.pubrelHandler = handler;
      return this;
    }
  }

  /**
   * Sets the handler invoked with the packet id of each received PUBCOMP.
   *
   * @throws IllegalStateException if the endpoint is closed
   */
  public MqttEndpointImpl publishCompletionHandler(Handler<Integer> handler) {
    synchronized (this.conn) {
      this.checkClosed();
      this.pubcompHandler = handler;
      return this;
    }
  }

  /**
   * Sets the handler invoked for PINGREQ.
   *
   * @throws IllegalStateException if the endpoint is closed
   */
  public MqttEndpointImpl pingHandler(Handler<Void> handler) {
    synchronized (this.conn) {
      this.checkClosed();
      this.pingreqHandler = handler;
      return this;
    }
  }

  /**
   * Sets the handler invoked when the underlying connection is closed.
   *
   * @throws IllegalStateException if the endpoint is closed
   */
  public MqttEndpointImpl closeHandler(Handler<Void> handler) {
    synchronized (this.conn) {
      this.checkClosed();
      this.closeHandler = handler;
      return this;
    }
  }

  /**
   * Sets the handler invoked for connection exceptions.
   *
   * @throws IllegalStateException if the endpoint is closed
   */
  public MqttEndpointImpl exceptionHandler(Handler<Throwable> handler) {
    synchronized (this.conn) {
      this.checkClosed();
      this.exceptionHandler = handler;
      return this;
    }
  }

  /**
   * Sends CONNACK with the given return code. A refused connection is then closed; an accepted
   * one is marked connected. Callers hold the {@code conn} lock.
   */
  private MqttEndpointImpl connack(MqttConnectReturnCode returnCode, boolean sessionPresent) {
    MqttFixedHeader fixedHeader =
      new MqttFixedHeader(MqttMessageType.CONNACK, false, MqttQoS.AT_MOST_ONCE, false, 0);
    MqttConnAckVariableHeader variableHeader =
      new MqttConnAckVariableHeader(returnCode, sessionPresent);
    io.netty.handler.codec.mqtt.MqttMessage connack = MqttMessageFactory.newMessage(fixedHeader, variableHeader, null);
    this.write(connack);

    if (returnCode != MqttConnectReturnCode.CONNECTION_ACCEPTED) {
      // The peer may already have gone away; close() throws on an already-closed endpoint.
      if (!this.isClosed) {
        this.close();
      }
    } else if (!this.isClosed) {
      // Never flag an endpoint whose connection is already closed as connected.
      this.isConnected = true;
    }
    return this;
  }

  @Override
  public MqttEndpoint accept() {
    return accept(false);
  }

  /**
   * Accepts the connection by sending a CONNACK with {@code CONNECTION_ACCEPTED}.
   *
   * @param sessionPresent value of the session-present flag in the CONNACK
   * @return this endpoint
   * @throws IllegalStateException if the connection was already accepted
   */
  public MqttEndpointImpl accept(boolean sessionPresent) {
    synchronized (conn) {
      if (this.isConnected) {
        throw new IllegalStateException("Connection already accepted");
      }
      return this.connack(MqttConnectReturnCode.CONNECTION_ACCEPTED, sessionPresent);
    }
  }

  /**
   * Rejects the connection by sending a CONNACK with the given code, then closes it.
   *
   * @param returnCode refusal code; must not be {@code CONNECTION_ACCEPTED}
   * @return this endpoint
   * @throws IllegalArgumentException if {@code returnCode} is {@code CONNECTION_ACCEPTED}
   */
  public MqttEndpointImpl reject(MqttConnectReturnCode returnCode) {
    synchronized (conn) {
      if (returnCode == MqttConnectReturnCode.CONNECTION_ACCEPTED) {
        throw new IllegalArgumentException("Need to use the 'accept' method for accepting connection");
      }
      // session-present is meaningless here: the connection is closed right after.
      return this.connack(returnCode, false);
    }
  }

  /**
   * Sends SUBACK carrying one granted QoS per requested subscription.
   *
   * @throws IllegalStateException if the connection has not been accepted
   */
  public MqttEndpointImpl subscribeAcknowledge(int subscribeMessageId, List<MqttQoS> grantedQoSLevels) {
    MqttFixedHeader fixedHeader =
      new MqttFixedHeader(MqttMessageType.SUBACK, false, MqttQoS.AT_MOST_ONCE, false, 0);
    MqttMessageIdVariableHeader variableHeader =
      MqttMessageIdVariableHeader.from(subscribeMessageId);
    MqttSubAckPayload payload = new MqttSubAckPayload(grantedQoSLevels.stream().mapToInt(MqttQoS::value).toArray());
    io.netty.handler.codec.mqtt.MqttMessage suback = MqttMessageFactory.newMessage(fixedHeader, variableHeader, payload);
    this.write(suback);
    return this;
  }

  /**
   * Sends UNSUBACK for the given packet id.
   *
   * @throws IllegalStateException if the connection has not been accepted
   */
  public MqttEndpointImpl unsubscribeAcknowledge(int unsubscribeMessageId) {
    MqttFixedHeader fixedHeader =
      new MqttFixedHeader(MqttMessageType.UNSUBACK, false, MqttQoS.AT_MOST_ONCE, false, 0);
    MqttMessageIdVariableHeader variableHeader =
      MqttMessageIdVariableHeader.from(unsubscribeMessageId);
    io.netty.handler.codec.mqtt.MqttMessage unsuback = MqttMessageFactory.newMessage(fixedHeader, variableHeader, null);
    this.write(unsuback);
    return this;
  }

  /**
   * Sends PUBACK for the given packet id.
   *
   * @throws IllegalStateException if the connection has not been accepted
   */
  public MqttEndpointImpl publishAcknowledge(int publishMessageId) {
    MqttFixedHeader fixedHeader =
      new MqttFixedHeader(MqttMessageType.PUBACK, false, MqttQoS.AT_MOST_ONCE, false, 0);
    MqttMessageIdVariableHeader variableHeader =
      MqttMessageIdVariableHeader.from(publishMessageId);
    io.netty.handler.codec.mqtt.MqttMessage puback = MqttMessageFactory.newMessage(fixedHeader, variableHeader, null);
    this.write(puback);
    return this;
  }

  /**
   * Sends PUBREC for the given packet id.
   *
   * @throws IllegalStateException if the connection has not been accepted
   */
  public MqttEndpointImpl publishReceived(int publishMessageId) {
    MqttFixedHeader fixedHeader =
      new MqttFixedHeader(MqttMessageType.PUBREC, false, MqttQoS.AT_MOST_ONCE, false, 0);
    MqttMessageIdVariableHeader variableHeader =
      MqttMessageIdVariableHeader.from(publishMessageId);
    io.netty.handler.codec.mqtt.MqttMessage pubrec = MqttMessageFactory.newMessage(fixedHeader, variableHeader, null);
    this.write(pubrec);
    return this;
  }

  /**
   * Sends PUBREL for the given packet id.
   *
   * @throws IllegalStateException if the connection has not been accepted
   */
  public MqttEndpointImpl publishRelease(int publishMessageId) {
    // The fixed header uses QoS 1 here, unlike the other acknowledgements in this class.
    MqttFixedHeader fixedHeader =
      new MqttFixedHeader(MqttMessageType.PUBREL, false, MqttQoS.AT_LEAST_ONCE, false, 0);
    MqttMessageIdVariableHeader variableHeader =
      MqttMessageIdVariableHeader.from(publishMessageId);
    io.netty.handler.codec.mqtt.MqttMessage pubrel = MqttMessageFactory.newMessage(fixedHeader, variableHeader, null);
    this.write(pubrel);
    return this;
  }

  /**
   * Sends PUBCOMP for the given packet id.
   *
   * @throws IllegalStateException if the connection has not been accepted
   */
  public MqttEndpointImpl publishComplete(int publishMessageId) {
    MqttFixedHeader fixedHeader =
      new MqttFixedHeader(MqttMessageType.PUBCOMP, false, MqttQoS.AT_MOST_ONCE, false, 0);
    MqttMessageIdVariableHeader variableHeader =
      MqttMessageIdVariableHeader.from(publishMessageId);
    io.netty.handler.codec.mqtt.MqttMessage pubcomp = MqttMessageFactory.newMessage(fixedHeader, variableHeader, null);
    this.write(pubcomp);
    return this;
  }

  @Override
  public MqttEndpointImpl publish(String topic, Buffer payload, MqttQoS qosLevel, boolean isDup, boolean isRetain) {
    return publish(topic, payload, qosLevel, isDup, isRetain, null);
  }

  @Override
  public MqttEndpointImpl publish(String topic, Buffer payload, MqttQoS qosLevel, boolean isDup, boolean isRetain, Handler<AsyncResult<Integer>> publishSentHandler) {
    return publish(topic, payload, qosLevel, isDup, isRetain, this.nextMessageId(), publishSentHandler);
  }

  /**
   * Sends a PUBLISH with an explicit packet id. The payload bytes are copied.
   *
   * @param publishSentHandler if non-null, completed with the packet id after the message is
   *                           handed to the connection for writing
   * @return this endpoint
   * @throws IllegalArgumentException if {@code messageId} is negative or above {@value #MAX_MESSAGE_ID}
   * @throws IllegalStateException    if the connection has not been accepted
   */
  @Override
  public MqttEndpointImpl publish(String topic, Buffer payload, MqttQoS qosLevel, boolean isDup, boolean isRetain, int messageId, Handler<AsyncResult<Integer>> publishSentHandler) {
    if (messageId > MAX_MESSAGE_ID || messageId < 0) {
      throw new IllegalArgumentException("messageId must be non-negative integer not larger than " + MAX_MESSAGE_ID);
    }
    MqttFixedHeader fixedHeader =
      new MqttFixedHeader(MqttMessageType.PUBLISH, isDup, qosLevel, isRetain, 0);
    MqttPublishVariableHeader variableHeader =
      new MqttPublishVariableHeader(topic, messageId);
    ByteBuf buf = Unpooled.copiedBuffer(payload.getBytes());
    io.netty.handler.codec.mqtt.MqttMessage publish = MqttMessageFactory.newMessage(fixedHeader, variableHeader, buf);
    this.write(publish);
    if (publishSentHandler != null) {
      publishSentHandler.handle(Future.succeededFuture(variableHeader.packetId()));
    }
    return this;
  }

  /**
   * Sends PINGRESP.
   *
   * @throws IllegalStateException if the connection has not been accepted
   */
  public MqttEndpointImpl pong() {
    MqttFixedHeader fixedHeader =
      new MqttFixedHeader(MqttMessageType.PINGRESP, false, MqttQoS.AT_MOST_ONCE, false, 0);
    io.netty.handler.codec.mqtt.MqttMessage pingresp = MqttMessageFactory.newMessage(fixedHeader, null, null);
    this.write(pingresp);
    return this;
  }

  /** Dispatches SUBSCRIBE to the user handler, then auto-acks with the requested QoS levels if enabled. */
  void handleSubscribe(io.vertx.mqtt.messages.MqttSubscribeMessage msg) {
    synchronized (this.conn) {
      if (this.subscribeHandler != null) {
        this.subscribeHandler.handle(msg);
      }
      if (this.isSubscriptionAutoAck) {
        this.subscribeAcknowledge(msg.messageId(), msg.topicSubscriptions()
          .stream()
          .map(MqttTopicSubscription::qualityOfService)
          .collect(Collectors.toList()));
      }
    }
  }

  /** Dispatches UNSUBSCRIBE to the user handler, then sends UNSUBACK if auto-ack is enabled. */
  void handleUnsubscribe(io.vertx.mqtt.messages.MqttUnsubscribeMessage msg) {
    synchronized (this.conn) {
      if (this.unsubscribeHandler != null) {
        this.unsubscribeHandler.handle(msg);
      }
      if (this.isSubscriptionAutoAck) {
        this.unsubscribeAcknowledge(msg.messageId());
      }
    }
  }

  /**
   * Dispatches PUBLISH to the user handler. If auto-ack is enabled, replies with PUBACK for
   * QoS 1 or PUBREC for QoS 2.
   */
  void handlePublish(io.vertx.mqtt.messages.MqttPublishMessage msg) {
    synchronized (this.conn) {
      if (this.publishHandler != null) {
        this.publishHandler.handle(msg);
      }
      if (this.isPublishAutoAck) {
        switch (msg.qosLevel()) {
          case AT_LEAST_ONCE:
            this.publishAcknowledge(msg.messageId());
            break;
          case EXACTLY_ONCE:
            this.publishReceived(msg.messageId());
            break;
          default:
            // No acknowledgement is sent for any other QoS level.
            break;
        }
      }
    }
  }

  /** Dispatches a received PUBACK to the user handler. */
  void handlePuback(int pubackMessageId) {
    synchronized (this.conn) {
      if (this.pubackHandler != null) {
        this.pubackHandler.handle(pubackMessageId);
      }
    }
  }

  /** Dispatches a received PUBREC to the user handler, then sends PUBREL if auto-ack is enabled. */
  void handlePubrec(int pubrecMessageId) {
    synchronized (this.conn) {
      if (this.pubrecHandler != null) {
        this.pubrecHandler.handle(pubrecMessageId);
      }
      if (this.isPublishAutoAck) {
        this.publishRelease(pubrecMessageId);
      }
    }
  }

  /** Dispatches a received PUBREL to the user handler, then sends PUBCOMP if auto-ack is enabled. */
  void handlePubrel(int pubrelMessageId) {
    synchronized (this.conn) {
      if (this.pubrelHandler != null) {
        this.pubrelHandler.handle(pubrelMessageId);
      }
      if (this.isPublishAutoAck) {
        this.publishComplete(pubrelMessageId);
      }
    }
  }

  /** Dispatches a received PUBCOMP to the user handler. */
  void handlePubcomp(int pubcompMessageId) {
    synchronized (this.conn) {
      if (this.pubcompHandler != null) {
        this.pubcompHandler.handle(pubcompMessageId);
      }
    }
  }

  /** Dispatches PINGREQ to the user handler, then sends PINGRESP if auto keep-alive is enabled. */
  void handlePingreq() {
    synchronized (this.conn) {
      if (this.pingreqHandler != null) {
        this.pingreqHandler.handle(null);
      }
      if (this.isAutoKeepAlive) {
        this.pong();
      }
    }
  }

  /**
   * Dispatches DISCONNECT to the user handler, then closes the connection on the server side.
   * The connection is closed whether or not a handler is registered.
   */
  void handleDisconnect() {
    synchronized (this.conn) {
      if (this.disconnectHandler != null) {
        this.disconnectHandler.handle(null);
      }
      // The handler may already have closed the endpoint; close() would then throw.
      if (!this.isClosed) {
        this.close();
      }
    }
  }

  /** Marks the endpoint closed and notifies the close handler. */
  void handleClosed() {
    synchronized (this.conn) {
      this.cleanup();
      if (this.closeHandler != null) {
        this.closeHandler.handle(null);
      }
    }
  }

  /** Dispatches a connection exception to the user handler, if any. */
  void handleException(Throwable t) {
    synchronized (this.conn) {
      if (this.exceptionHandler != null) {
        this.exceptionHandler.handle(t);
      }
    }
  }

  /**
   * Closes the underlying connection and marks the endpoint closed.
   *
   * @throws IllegalStateException if the endpoint is already closed
   */
  public void close() {
    synchronized (this.conn) {
      checkClosed();
      this.conn.close();
      this.cleanup();
    }
  }

  /**
   * @return local address of the connection
   * @throws IllegalStateException if the endpoint is closed
   */
  public SocketAddress localAddress() {
    synchronized (this.conn) {
      this.checkClosed();
      return conn.localAddress();
    }
  }

  /**
   * @return remote address of the connection
   * @throws IllegalStateException if the endpoint is closed
   */
  public SocketAddress remoteAddress() {
    synchronized (this.conn) {
      this.checkClosed();
      return conn.remoteAddress();
    }
  }

  /**
   * @return whether the connection uses SSL/TLS
   * @throws IllegalStateException if the endpoint is closed
   */
  public boolean isSsl() {
    synchronized (this.conn) {
      this.checkClosed();
      return conn.isSsl();
    }
  }

  /**
   * @return the SSL session of the connection, as reported by the underlying socket
   * @throws IllegalStateException if the endpoint is closed
   */
  public SSLSession sslSession() {
    synchronized (this.conn) {
      this.checkClosed();
      return this.conn.sslSession();
    }
  }

  /**
   * Writes a message to the connection. Every message type except CONNACK requires the
   * connection to have been accepted first.
   */
  private void write(io.netty.handler.codec.mqtt.MqttMessage mqttMessage) {
    synchronized (this.conn) {
      if (mqttMessage.fixedHeader().messageType() != MqttMessageType.CONNACK) {
        this.checkConnected();
      }
      this.conn.writeMessage(mqttMessage);
    }
  }

  // Callers hold the conn lock.
  private void checkClosed() {
    if (this.isClosed) {
      throw new IllegalStateException("MQTT endpoint is closed");
    }
  }

  // Callers hold the conn lock.
  private void checkConnected() {
    if (!this.isConnected) {
      throw new IllegalStateException("Connection not accepted yet");
    }
  }

  // Idempotent transition to the closed state; callers hold the conn lock.
  private void cleanup() {
    if (!this.isClosed) {
      this.isClosed = true;
      this.isConnected = false;
    }
  }

  /**
   * Allocates the next packet identifier for an outgoing PUBLISH. The counter cycles through
   * 1..MAX_MESSAGE_ID and never yields 0. It is synchronized because publish() may be called
   * from several threads concurrently, and unsynchronized allocation could hand out duplicate ids.
   */
  private int nextMessageId() {
    synchronized (this.conn) {
      // Both 0 (nothing allocated yet) and MAX_MESSAGE_ID wrap to 1.
      this.messageIdCounter = ((this.messageIdCounter % MAX_MESSAGE_ID) != 0) ? this.messageIdCounter + 1 : 1;
      return this.messageIdCounter;
    }
  }
}