package io.vertx.ext.stomp.impl;
import io.vertx.core.AsyncResult;
import io.vertx.core.Handler;
import io.vertx.core.MultiMap;
import io.vertx.core.Vertx;
import io.vertx.core.buffer.Buffer;
import io.vertx.core.eventbus.DeliveryOptions;
import io.vertx.core.eventbus.Message;
import io.vertx.core.eventbus.MessageConsumer;
import io.vertx.core.json.JsonObject;
import io.vertx.ext.bridge.PermittedOptions;
import io.vertx.ext.stomp.*;
import io.vertx.ext.stomp.utils.Headers;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
public class EventBusBridge extends Topic {
private final BridgeOptions options;
private final Map<String, Pattern> expressions = new HashMap<>();
private final Map<String, MessageConsumer<?>> registry = new HashMap<>();
public EventBusBridge(Vertx vertx, BridgeOptions options) {
super(vertx, null);
this.options = options;
}
@Override
public String destination() {
return "<<bridge>>";
}
@Override
public synchronized Destination subscribe(StompServerConnection connection, Frame frame) {
String address = frame.getDestination();
if (checkMatches(false, address, null)) {
Subscription subscription = new Subscription(connection, frame);
subscriptions.add(subscription);
if (!registry.containsKey(address)) {
registry.put(address, vertx.eventBus().consumer(address, msg -> {
if (!checkMatches(false, address, msg.body())) {
return;
}
if (options.isPointToPoint()) {
Optional<Subscription> chosen = subscriptions.stream().filter(s -> s.destination.equals(address)).findAny();
if (chosen.isPresent()) {
Frame stompFrame = transform(msg, chosen.get());
chosen.get().connection.write(stompFrame);
}
} else {
subscriptions.stream().filter(s -> s.destination.equals(address)).forEach(s -> {
Frame stompFrame = transform(msg, s);
s.connection.write(stompFrame);
});
}
}));
}
return this;
}
return null;
}
@Override
public synchronized boolean unsubscribe(StompServerConnection connection, Frame frame) {
for (Subscription subscription : new ArrayList<>(subscriptions)) {
if (subscription.connection.equals(connection)
&& subscription.id.equals(frame.getId())) {
boolean r = subscriptions.remove(subscription);
Optional<Subscription> any = subscriptions.stream().filter(s -> s.destination.equals(subscription.destination)).findAny();
if (!any.isPresent()) {
MessageConsumer<?> consumer = registry.remove(subscription.destination);
if (consumer != null) {
consumer.unregister();
}
}
return r;
}
}
return false;
}
@Override
public synchronized Destination unsubscribeConnection(StompServerConnection connection) {
new ArrayList<>(subscriptions)
.stream()
.filter(subscription -> subscription.connection.equals(connection))
.forEach(s -> {
subscriptions.remove(s);
Optional<Subscription> any = subscriptions.stream().filter(s2 -> s2.destination.equals(s.destination))
.findAny();
if (!any.isPresent()) {
MessageConsumer<?> consumer = registry.remove(s.destination);
if (consumer != null) {
consumer.unregister();
}
}
});
return this;
}
private Frame transform(Message<Object> msg, Subscription subscription) {
String messageId = UUID.randomUUID().toString();
Frame frame = new Frame();
frame.setCommand(Frame.Command.MESSAGE);
final Headers headers = Headers.create(frame.getHeaders())
.add(Frame.SUBSCRIPTION, subscription.id)
.add(Frame.MESSAGE_ID, messageId)
.add(Frame.DESTINATION, msg.address());
if (!"auto".equals(subscription.ackMode)) {
headers.add(Frame.ACK, messageId);
}
if (msg.replyAddress() != null) {
headers.put("reply-address", msg.replyAddress());
}
for (Map.Entry<String, String> entry : msg.headers()) {
headers.putIfAbsent(entry.getKey(), entry.getValue());
}
frame.setHeaders(headers);
Object body = msg.body();
if (body != null) {
if (body instanceof String) {
frame.setBody(Buffer.buffer((String) body));
} else if (body instanceof Buffer) {
frame.setBody((Buffer) body);
} else if (body instanceof JsonObject) {
frame.setBody(Buffer.buffer(((JsonObject) body).encode()));
} else {
throw new IllegalStateException("Illegal body - unsupported body type: " + body.getClass().getName());
}
}
if (body != null && frame.getHeader(Frame.CONTENT_LENGTH) == null) {
frame.addHeader(Frame.CONTENT_LENGTH, Integer.toString(frame.getBody().length()));
}
return frame;
}
@Override
public Destination dispatch(StompServerConnection connection, Frame frame) {
String address = frame.getDestination();
if (checkMatches(true, address, frame.getBody())) {
final String replyAddress = frame.getHeader("reply-address");
if (replyAddress != null) {
send(address, frame, (AsyncResult<Message<Object>> res) -> {
if (res.failed()) {
Throwable cause = res.cause();
connection.write(Frames.createErrorFrame("Message dispatch error", Headers.create(Frame.DESTINATION,
address, "reply-address", replyAddress), cause.getMessage())).close();
} else {
Optional<Subscription> subscription = subscriptions.stream()
.filter(s -> s.connection.equals(connection) && s.destination.equals(replyAddress))
.findFirst();
if (subscription.isPresent()) {
Frame stompFrame = transform(res.result(), subscription.get());
subscription.get().connection.write(stompFrame);
}
}
});
} else {
send(address, frame, null);
}
} else {
connection.write(Frames.createErrorFrame("Access denied", Headers.create(Frame.DESTINATION,
address), "Access denied to " + address)).close();
return null;
}
return this;
}
private void send(String address, Frame frame, Handler<AsyncResult<Message<Object>>> replyHandler) {
if (options.isPointToPoint()) {
vertx.eventBus().request(address, frame.getBody(),
new DeliveryOptions().setHeaders(toMultimap(frame.getHeaders())), replyHandler);
} else {
vertx.eventBus().publish(address, frame.getBody(),
new DeliveryOptions().setHeaders(toMultimap(frame.getHeaders())));
}
}
private MultiMap toMultimap(Map<String, String> headers) {
return MultiMap.caseInsensitiveMultiMap().addAll(headers);
}
public boolean matches(String address, Buffer payload) {
return checkMatches(false, address, payload) || checkMatches(true, address, payload);
}
public boolean matches(String address) {
return checkMatches(false, address, null) || checkMatches(true, address, null);
}
private boolean regexMatches(String matchRegex, String address) {
Pattern pattern = expressions.get(matchRegex);
if (pattern == null) {
pattern = Pattern.compile(matchRegex);
expressions.put(matchRegex, pattern);
}
Matcher m = pattern.matcher(address);
return m.matches();
}
private boolean checkMatches(boolean inbound, String address, Object body) {
List<PermittedOptions> matches = inbound ? options.getInboundPermitteds() : options.getOutboundPermitteds();
for (PermittedOptions matchHolder : matches) {
String matchAddress = matchHolder.getAddress();
String matchRegex;
if (matchAddress == null) {
matchRegex = matchHolder.getAddressRegex();
} else {
matchRegex = null;
}
boolean addressOK;
if (matchAddress == null) {
addressOK = matchRegex == null || regexMatches(matchRegex, address);
} else {
addressOK = matchAddress.equals(address);
}
if (addressOK) {
return structureMatches(matchHolder.getMatch(), body);
}
}
return false;
}
private boolean structureMatches(JsonObject match, Object body) {
if (match == null || body == null) {
return true;
}
try {
JsonObject object;
if (body instanceof JsonObject) {
object = (JsonObject) body;
} else if (body instanceof Buffer) {
object = new JsonObject(((Buffer) body).toString("UTF-8"));
} else if (body instanceof String) {
object = new JsonObject((String) body);
} else {
return false;
}
for (String fieldName : match.fieldNames()) {
Object mv = match.getValue(fieldName);
Object bv = object.getValue(fieldName);
if (mv instanceof JsonObject) {
if (!structureMatches((JsonObject) mv, bv)) {
return false;
}
} else if (!match.getValue(fieldName).equals(object.getValue(fieldName))) {
return false;
}
}
return true;
} catch (Exception e) {
return false;
}
}
}