package io.vertx.core.eventbus.impl;
import io.vertx.codegen.annotations.Nullable;
import io.vertx.core.*;
import io.vertx.core.eventbus.Message;
import io.vertx.core.eventbus.MessageConsumer;
import io.vertx.core.eventbus.ReplyException;
import io.vertx.core.eventbus.ReplyFailure;
import io.vertx.core.eventbus.impl.clustered.ClusteredMessage;
import io.vertx.core.impl.Arguments;
import io.vertx.core.impl.ContextInternal;
import io.vertx.core.logging.Logger;
import io.vertx.core.logging.LoggerFactory;
import io.vertx.core.spi.metrics.EventBusMetrics;
import io.vertx.core.streams.ReadStream;
import java.util.*;
/**
 * A consumer registration on the event bus for a single {@code address}.
 *
 * <p>Instances act both as the user-facing {@link MessageConsumer} and as the {@link Handler} the
 * event bus delivers messages to. Delivery is flow-controlled by a {@code demand} counter. While
 * demand is zero (see {@link #pause()} / {@link #fetch(long)}), incoming messages are buffered up
 * to {@link #getMaxBufferedMessages()}. Any overflow is passed to the discard handler, or is
 * logged and dropped when none is set.
 *
 * <p>When constructed with a {@code timeout} other than {@code -1}, a timer is armed. If it fires,
 * it unregisters the consumer and fails {@code asyncResultHandler} with a
 * {@link ReplyFailure#TIMEOUT} {@link ReplyException}.
 *
 * <p>Mutable state is guarded by this object's monitor. This class invokes the message handler
 * and the discard handler only after releasing that monitor.
 */
public class HandlerRegistration<T> implements MessageConsumer<T>, Handler<Message<T>> {

  private static final Logger log = LoggerFactory.getLogger(HandlerRegistration.class);

  /** Default cap on messages buffered while the consumer has no outstanding demand. */
  public static final int DEFAULT_MAX_BUFFERED_MESSAGES = 1000;

  private final Vertx vertx;
  private final EventBusMetrics metrics;
  private final EventBusImpl eventBus;
  private final String address;
  private final String repliedAddress;
  private final boolean localOnly;
  private final Handler<AsyncResult<Message<T>>> asyncResultHandler;
  // Id of the reply-timeout timer, or -1 when no timer is armed.
  private long timeoutID = -1;
  private HandlerHolder<T> registered;
  private Handler<Message<T>> handler;
  private ContextInternal handlerContext;
  // Outcome of the registration; null until setResult() or an early unregister decides it.
  private AsyncResult<Void> result;
  private Handler<AsyncResult<Void>> completionHandler;
  private Handler<Void> endHandler;
  private Handler<Message<T>> discardHandler;
  private int maxBufferedMessages = DEFAULT_MAX_BUFFERED_MESSAGES;
  private final Queue<Message<T>> pending = new ArrayDeque<>(8);
  // Messages the consumer may still receive; Long.MAX_VALUE means unbounded (flowing mode).
  private long demand = Long.MAX_VALUE;
  private Object metric;

  /**
   * Creates a registration; it is not attached to the event bus until {@link #handler} is called
   * with a non-null handler.
   *
   * @param vertx              Vert.x instance used for timers and context scheduling
   * @param metrics            event bus metrics SPI, may be {@code null}
   * @param eventBus           event bus this consumer registers with
   * @param address            address to consume from
   * @param repliedAddress     address this consumer replies to, or {@code null} when it is not a
   *                           reply consumer
   * @param localOnly          passed through to {@code EventBusImpl.addRegistration}
   * @param asyncResultHandler handler failed when the reply timeout fires
   * @param timeout            reply timeout in ms, or {@code -1} for no timeout
   */
  public HandlerRegistration(Vertx vertx, EventBusMetrics metrics, EventBusImpl eventBus, String address,
                             String repliedAddress, boolean localOnly,
                             Handler<AsyncResult<Message<T>>> asyncResultHandler, long timeout) {
    this.vertx = vertx;
    this.metrics = metrics;
    this.eventBus = eventBus;
    this.address = address;
    this.repliedAddress = repliedAddress;
    this.localOnly = localOnly;
    this.asyncResultHandler = asyncResultHandler;
    if (timeout != -1) {
      timeoutID = vertx.setTimer(timeout, tid -> {
        if (metrics != null) {
          metrics.replyFailure(address, ReplyFailure.TIMEOUT);
        }
        sendAsyncResultFailure(new ReplyException(ReplyFailure.TIMEOUT, "Timed out after waiting " + timeout + "(ms) for a reply. address: " + address + ", repliedAddress: " + repliedAddress));
      });
    }
  }

  /**
   * Sets the buffer cap. Messages already buffered beyond the new cap are dropped oldest-first and
   * handed to the discard handler, if any, after the lock is released.
   *
   * @throws IllegalArgumentException if {@code maxBufferedMessages} is negative
   */
  @Override
  public MessageConsumer<T> setMaxBufferedMessages(int maxBufferedMessages) {
    Arguments.require(maxBufferedMessages >= 0, "Max buffered messages cannot be negative");
    List<Message<T>> discarded;
    Handler<Message<T>> theDiscardHandler;
    synchronized (this) {
      this.maxBufferedMessages = maxBufferedMessages;
      int overflow = pending.size() - maxBufferedMessages;
      if (overflow <= 0) {
        return this;
      }
      theDiscardHandler = this.discardHandler;
      if (theDiscardHandler == null) {
        while (pending.size() > maxBufferedMessages) {
          pending.poll();
        }
        return this;
      }
      discarded = new ArrayList<>(overflow);
      while (pending.size() > maxBufferedMessages) {
        discarded.add(pending.poll());
      }
    }
    // Invoke user code outside the monitor.
    for (Message<T> msg : discarded) {
      theDiscardHandler.handle(msg);
    }
    return this;
  }

  @Override
  public synchronized int getMaxBufferedMessages() {
    return maxBufferedMessages;
  }

  @Override
  public String address() {
    return address;
  }

  /**
   * Sets the handler notified of the registration outcome. If the outcome is already known, the
   * handler is scheduled immediately with it.
   */
  @Override
  public synchronized void completionHandler(Handler<AsyncResult<Void>> completionHandler) {
    Objects.requireNonNull(completionHandler);
    if (result != null) {
      AsyncResult<Void> value = result;
      vertx.runOnContext(v -> completionHandler.handle(value));
    } else {
      this.completionHandler = completionHandler;
    }
  }

  @Override
  public void unregister() {
    doUnregister(null);
  }

  @Override
  public void unregister(Handler<AsyncResult<Void>> completionHandler) {
    doUnregister(completionHandler);
  }

  /** Unregisters this consumer and then fails {@code asyncResultHandler} with {@code failure}. */
  public void sendAsyncResultFailure(ReplyException failure) {
    unregister();
    asyncResultHandler.handle(Future.failedFuture(failure));
  }

  /**
   * Detaches this consumer from the event bus. It clears the handler, cancels the reply timer,
   * drains buffered messages to the discard handler (outside the lock), and completes
   * {@code doneHandler}, with the end handler run first when one is set.
   */
  private void doUnregister(Handler<AsyncResult<Void>> doneHandler) {
    Deque<Message<T>> discarded;
    Handler<Message<T>> theDiscardHandler;
    synchronized (this) {
      handler = null;
      if (timeoutID != -1) {
        vertx.cancelTimer(timeoutID);
        // The timer is gone; don't try to cancel it again on a later unregister.
        timeoutID = -1;
      }
      if (endHandler != null) {
        Handler<Void> theEndHandler = endHandler;
        Handler<AsyncResult<Void>> userDoneHandler = doneHandler;
        doneHandler = ar -> {
          theEndHandler.handle(null);
          if (userDoneHandler != null) {
            userDoneHandler.handle(ar);
          }
        };
      }
      HandlerHolder<T> holder = registered;
      if (pending.isEmpty()) {
        discarded = null;
      } else {
        discarded = new ArrayDeque<>(pending);
        pending.clear();
      }
      theDiscardHandler = this.discardHandler;
      if (holder != null) {
        registered = null;
        eventBus.removeRegistration(holder, doneHandler);
      } else {
        callHandlerAsync(Future.succeededFuture(), doneHandler);
      }
      if (result == null) {
        // Registration never completed: settle it as failed so completion handlers hear about it.
        result = Future.failedFuture("Consumer unregistered before registration completed");
        callHandlerAsync(result, completionHandler);
      } else if (holder != null && result.succeeded()) {
        // Only report an unregistration for a live registration whose metric was actually
        // registered (see setResult); avoids double counting on repeated unregister calls.
        EventBusMetrics metrics = eventBus.metrics;
        if (metrics != null) {
          metrics.handlerUnregistered(metric);
        }
      }
    }
    if (theDiscardHandler != null && discarded != null) {
      Message<T> msg;
      while ((msg = discarded.poll()) != null) {
        theDiscardHandler.handle(msg);
      }
    }
  }

  /** Schedules {@code completionHandler} with {@code result} via {@code vertx.runOnContext}; no-op when null. */
  private void callHandlerAsync(AsyncResult<Void> result, Handler<AsyncResult<Void>> completionHandler) {
    if (completionHandler != null) {
      vertx.runOnContext(v -> completionHandler.handle(result));
    }
  }

  synchronized void setHandlerContext(Context context) {
    handlerContext = (ContextInternal) context;
  }

  /**
   * Records the registration outcome. Only the first outcome is kept; later calls are ignored.
   * The completion handler, if set, is notified of both success and failure.
   */
  public synchronized void setResult(AsyncResult<Void> result) {
    if (this.result != null) {
      return;
    }
    this.result = result;
    if (result.failed()) {
      log.error("Failed to propagate registration for handler " + handler + " and address " + address, result.cause());
    } else if (metrics != null) {
      metric = metrics.handlerRegistered(address, repliedAddress);
    }
    // Previously only successes were forwarded, leaving an early completion handler waiting forever.
    callHandlerAsync(result, completionHandler);
  }

  /**
   * Entry point for messages from the event bus. The message is delivered immediately when there
   * is demand. Otherwise it is buffered or, if the buffer is full, discarded.
   */
  @Override
  public void handle(Message<T> message) {
    Handler<Message<T>> theHandler;
    ContextInternal ctx;
    Handler<Message<T>> overflowHandler;
    synchronized (this) {
      if (registered == null) {
        return;
      }
      if (demand == 0L) {
        if (pending.size() < maxBufferedMessages) {
          pending.add(message);
          return;
        }
        overflowHandler = discardHandler;
        if (overflowHandler == null) {
          log.warn("Discarding message as more than " + maxBufferedMessages + " buffered in paused consumer. address: " + address);
          return;
        }
        theHandler = null;
      } else {
        overflowHandler = null;
        // Preserve ordering: if messages are already buffered, deliver the oldest one first.
        if (!pending.isEmpty()) {
          pending.add(message);
          message = pending.poll();
        }
        if (demand != Long.MAX_VALUE) {
          demand--;
        }
        theHandler = handler;
      }
      ctx = handlerContext;
    }
    if (overflowHandler != null) {
      // User code runs outside the monitor, consistent with setMaxBufferedMessages/doUnregister.
      overflowHandler.handle(message);
      return;
    }
    deliver(theHandler, message, ctx);
  }

  /**
   * Invokes {@code theHandler} with {@code message}. Before that it sends a credit to the
   * producer's credit address when the header is present, and it records metrics. Exceptions
   * thrown by the handler are logged and reported to {@code context}.
   */
  private void deliver(Handler<Message<T>> theHandler, Message<T> message, ContextInternal context) {
    boolean local = true;
    if (message instanceof ClusteredMessage) {
      ClusteredMessage cmsg = (ClusteredMessage) message;
      if (cmsg.isFromWire()) {
        local = false;
      }
    }
    String creditsAddress = message.headers().get(MessageProducerImpl.CREDIT_ADDRESS_HEADER_NAME);
    if (creditsAddress != null) {
      eventBus.send(creditsAddress, 1);
    }
    try {
      if (metrics != null) {
        metrics.beginHandleMessage(metric, local);
      }
      theHandler.handle(message);
      if (metrics != null) {
        metrics.endHandleMessage(metric, null);
      }
    } catch (Exception e) {
      log.error("Failed to handleMessage. address: " + message.address(), e);
      if (metrics != null) {
        metrics.endHandleMessage(metric, e);
      }
      context.reportException(e);
    }
    checkNextTick();
  }

  /** When messages are buffered and there is demand, schedules delivery of the next one on the handler context. */
  private synchronized void checkNextTick() {
    if (!pending.isEmpty() && demand > 0L) {
      handlerContext.runOnContext(v -> {
        Message<T> message;
        Handler<Message<T>> theHandler;
        ContextInternal ctx;
        synchronized (HandlerRegistration.this) {
          if (demand == 0L || (message = pending.poll()) == null) {
            return;
          }
          if (demand != Long.MAX_VALUE) {
            demand--;
          }
          theHandler = handler;
          ctx = handlerContext;
        }
        deliver(theHandler, message, ctx);
      });
    }
  }

  /** Sets the handler that receives messages dropped because of buffer overflow or unregistration. */
  public synchronized void discardHandler(Handler<Message<T>> handler) {
    this.discardHandler = handler;
  }

  /**
   * Sets the message handler, registering with the event bus on first use. A {@code null} handler
   * unregisters the consumer.
   */
  @Override
  public MessageConsumer<T> handler(Handler<Message<T>> h) {
    if (h == null) {
      // Unregister without holding the monitor so discard callbacks don't run under the lock.
      unregister();
      return this;
    }
    synchronized (this) {
      handler = h;
      if (registered == null) {
        registered = eventBus.addRegistration(address, this, repliedAddress != null, localOnly);
      }
    }
    return this;
  }

  @Override
  public ReadStream<T> bodyStream() {
    return new BodyReadStream<>(this);
  }

  @Override
  public synchronized boolean isRegistered() {
    return registered != null;
  }

  /** Stops delivery; subsequent messages are buffered until demand is restored. */
  @Override
  public synchronized MessageConsumer<T> pause() {
    demand = 0L;
    return this;
  }

  /** Restores unbounded (flowing) delivery. */
  @Override
  public MessageConsumer<T> resume() {
    return fetch(Long.MAX_VALUE);
  }

  /**
   * Adds {@code amount} to the outstanding demand, saturating at {@code Long.MAX_VALUE}, and
   * schedules delivery of buffered messages if there is now demand.
   *
   * @throws IllegalArgumentException if {@code amount} is negative
   */
  @Override
  public synchronized MessageConsumer<T> fetch(long amount) {
    if (amount < 0) {
      throw new IllegalArgumentException("Fetch amount must be >= 0: " + amount);
    }
    demand += amount;
    if (demand < 0L) {
      // Overflowed past Long.MAX_VALUE: treat as unbounded.
      demand = Long.MAX_VALUE;
    }
    if (demand > 0L) {
      checkNextTick();
    }
    return this;
  }

  /**
   * Sets the handler run when the consumer is unregistered. The handler is run on the context
   * that is current (or created) at the time of this call.
   */
  @Override
  public synchronized MessageConsumer<T> endHandler(Handler<Void> endHandler) {
    if (endHandler != null) {
      Context endCtx = vertx.getOrCreateContext();
      this.endHandler = v1 -> endCtx.runOnContext(v2 -> endHandler.handle(null));
    } else {
      this.endHandler = null;
    }
    return this;
  }

  /** Ignored: this consumer does not report exceptions through this handler. */
  @Override
  public synchronized MessageConsumer<T> exceptionHandler(Handler<Throwable> handler) {
    return this;
  }

  public Handler<Message<T>> getHandler() {
    return handler;
  }

  public Object getMetric() {
    return metric;
  }
}