package io.vertx.circuitbreaker.impl;
import io.vertx.circuitbreaker.CircuitBreaker;
import io.vertx.circuitbreaker.CircuitBreakerOptions;
import io.vertx.circuitbreaker.CircuitBreakerState;
import io.vertx.circuitbreaker.OpenCircuitException;
import io.vertx.circuitbreaker.TimeoutException;
import io.vertx.core.Context;
import io.vertx.core.Future;
import io.vertx.core.Handler;
import io.vertx.core.Promise;
import io.vertx.core.Vertx;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Function;
/**
 * Default {@link CircuitBreaker} implementation.
 * <p>
 * The breaker starts {@link CircuitBreakerState#CLOSED}. Failures are counted in a {@link RollingCounter}
 * spanning {@link CircuitBreakerOptions#getFailuresRollingWindow()}; once the count reaches
 * {@link CircuitBreakerOptions#getMaxFailures()} the breaker opens. When a reset timeout is configured
 * (anything other than {@code -1}), an open breaker moves to {@link CircuitBreakerState#HALF_OPEN} after that
 * delay, where a single trial call decides whether it closes again or re-opens.
 * <p>
 * The state, the failure counter and the state-change handlers are guarded by this instance's monitor.
 */
public class CircuitBreakerImpl implements CircuitBreaker {

  private static final Handler<Void> NOOP = (v) -> {
  };

  private final Vertx vertx;
  private final CircuitBreakerOptions options;
  private final String name;
  /** Id of the periodic notification timer, or {@code -1} when periodic notifications are disabled. */
  private final long periodicUpdateTask;

  private Handler<Void> openHandler = NOOP;
  private Handler<Void> halfOpenHandler = NOOP;
  private Handler<Void> closeHandler = NOOP;

  // Raw type so that execute()/executeAndReport() can pass it to the generic executeWithFallback methods.
  private Function fallback = null;

  private CircuitBreakerState state = CircuitBreakerState.CLOSED;
  private RollingCounter rollingFailures;

  // Number of calls admitted since the breaker last entered HALF_OPEN; only the first one is executed.
  private final AtomicInteger passed = new AtomicInteger();

  private CircuitBreakerMetrics metrics;

  // Maps the 1-based retry number to the delay (ms) before that retry; a value <= 0 retries immediately.
  private Function<Integer, Long> retryPolicy = retry -> 0L;

  /**
   * Creates a breaker and publishes its initial state on the notification address (if any).
   *
   * @param name    the breaker name, must not be {@code null}
   * @param vertx   the Vert.x instance, must not be {@code null}
   * @param options the configuration; {@code null} means default options. A copy is taken.
   */
  public CircuitBreakerImpl(String name, Vertx vertx, CircuitBreakerOptions options) {
    Objects.requireNonNull(name);
    Objects.requireNonNull(vertx);
    this.vertx = vertx;
    this.name = name;
    this.options = options == null ? new CircuitBreakerOptions() : new CircuitBreakerOptions(options);

    // Always read from this.options: the parameter may be null.
    this.metrics = new CircuitBreakerMetrics(vertx, this, this.options);
    this.rollingFailures = new RollingCounter(this.options.getFailuresRollingWindow() / 1000, TimeUnit.SECONDS);

    sendUpdateOnEventBus();
    if (this.options.getNotificationPeriod() > 0) {
      this.periodicUpdateTask = vertx.setPeriodic(this.options.getNotificationPeriod(), l -> sendUpdateOnEventBus());
    } else {
      this.periodicUpdateTask = -1;
    }
  }

  /**
   * Stops the periodic notification timer (if any) and closes the metrics.
   */
  @Override
  public CircuitBreaker close() {
    if (this.periodicUpdateTask != -1) {
      vertx.cancelTimer(this.periodicUpdateTask);
    }
    metrics.close();
    return this;
  }

  /** Sets the handler invoked each time the breaker transitions to OPEN. */
  @Override
  public synchronized CircuitBreaker openHandler(Handler<Void> handler) {
    Objects.requireNonNull(handler);
    openHandler = handler;
    return this;
  }

  /** Sets the handler invoked each time the breaker transitions from OPEN to HALF_OPEN. */
  @Override
  public synchronized CircuitBreaker halfOpenHandler(Handler<Void> handler) {
    Objects.requireNonNull(handler);
    halfOpenHandler = handler;
    return this;
  }

  /** Sets the handler invoked each time the breaker transitions back to CLOSED. */
  @Override
  public synchronized CircuitBreaker closeHandler(Handler<Void> handler) {
    Objects.requireNonNull(handler);
    closeHandler = handler;
    return this;
  }

  /** Sets the default fallback used by {@link #execute} and {@link #executeAndReport}. */
  @Override
  public <T> CircuitBreaker fallback(Function<Throwable, T> handler) {
    Objects.requireNonNull(handler);
    fallback = handler;
    return this;
  }

  /**
   * Clears the failure counter and, where allowed, closes the breaker.
   *
   * @param force when {@code false}, an OPEN breaker stays open (only HALF_OPEN is closed); when {@code true},
   *              an OPEN breaker is closed too
   * @return this breaker
   */
  public synchronized CircuitBreaker reset(boolean force) {
    rollingFailures.reset();

    if (state == CircuitBreakerState.CLOSED) {
      // Already closed: clearing the counter is all there is to do.
      return this;
    }
    if (!force && state == CircuitBreakerState.OPEN) {
      return this;
    }

    state = CircuitBreakerState.CLOSED;
    closeHandler.handle(null);
    sendUpdateOnEventBus();
    return this;
  }

  /** Equivalent to {@code reset(false)}. */
  @Override
  public synchronized CircuitBreaker reset() {
    return reset(false);
  }

  /** Publishes the current metrics snapshot on the notification address, when one is configured. */
  private synchronized void sendUpdateOnEventBus() {
    String address = options.getNotificationAddress();
    if (address != null) {
      vertx.eventBus().publish(address, metrics.toJson());
    }
  }

  /**
   * Forces the breaker into the OPEN state and, if a reset timeout is configured, schedules the move to
   * HALF_OPEN.
   */
  @Override
  public synchronized CircuitBreaker open() {
    state = CircuitBreakerState.OPEN;
    openHandler.handle(null);
    sendUpdateOnEventBus();

    long period = options.getResetTimeout();
    if (period != -1) {
      vertx.setTimer(period, l -> attemptReset());
    }
    return this;
  }

  /** Returns the number of failures currently inside the rolling window. */
  @Override
  public synchronized long failureCount() {
    return rollingFailures.count();
  }

  @Override
  public synchronized CircuitBreakerState state() {
    return state;
  }

  /**
   * Moves an OPEN breaker to HALF_OPEN and re-arms the single trial call. No-op in any other state
   * (e.g. if the breaker was force-reset in the meantime).
   */
  private synchronized CircuitBreaker attemptReset() {
    if (state == CircuitBreakerState.OPEN) {
      passed.set(0);
      state = CircuitBreakerState.HALF_OPEN;
      halfOpenHandler.handle(null);
      sendUpdateOnEventBus();
    }
    return this;
  }

  /**
   * Runs {@code command} according to the current state and reports the outcome to {@code userFuture}.
   * <ul>
   *   <li>CLOSED: the command runs (with retries when {@code maxRetries > 0}); a failure is counted and may
   *   open the breaker, a success clears the failure counter.</li>
   *   <li>OPEN: the call is short-circuited with {@link OpenCircuitException}, through {@code fallback} if
   *   one is given.</li>
   *   <li>HALF_OPEN: only the first call is executed; its failure re-opens the breaker, its success closes
   *   it. Every other call is short-circuited.</li>
   * </ul>
   * On an executed failure, {@code fallback} is used only when {@link CircuitBreakerOptions#isFallbackOnFailure()}.
   */
  @Override
  public <T> CircuitBreaker executeAndReportWithFallback(
    Promise<T> userFuture,
    Handler<Promise<T>> command,
    Function<Throwable, T> fallback) {

    Context context = vertx.getOrCreateContext();

    CircuitBreakerState currentState;
    synchronized (this) {
      currentState = state;
    }

    CircuitBreakerMetrics.Operation call = metrics.enqueue();

    // Completed once the operation (including any retries) succeeds, fails or times out.
    Promise<T> operationResult = Promise.promise();

    if (currentState == CircuitBreakerState.CLOSED) {
      // Register the CLOSED completion handler only here. The HALF_OPEN branch installs its own handler,
      // and depending on the Future implementation a second setHandler either replaces the first or runs
      // alongside it (double-counting the failure and completing userFuture twice).
      operationResult.future().setHandler(event -> context.runOnContext(v -> {
        if (event.failed()) {
          incrementFailures();
          call.failed();
          if (options.isFallbackOnFailure()) {
            invokeFallback(event.cause(), userFuture, fallback, call);
          } else {
            userFuture.fail(event.cause());
          }
        } else {
          call.complete();
          reset();
          userFuture.complete(event.result());
        }
      }));

      if (options.getMaxRetries() > 0) {
        executeOperation(context, command, retryFuture(context, 0, command, operationResult, call), call);
      } else {
        executeOperation(context, command, operationResult, call);
      }
    } else if (currentState == CircuitBreakerState.OPEN) {
      call.shortCircuited();
      invokeFallback(OpenCircuitException.INSTANCE, userFuture, fallback, call);
    } else if (currentState == CircuitBreakerState.HALF_OPEN) {
      if (passed.incrementAndGet() == 1) {
        // This call is the trial: its outcome decides between re-opening and closing.
        operationResult.future().setHandler(event -> {
          if (event.failed()) {
            open();
            call.failed();
            if (options.isFallbackOnFailure()) {
              invokeFallback(event.cause(), userFuture, fallback, call);
            } else {
              userFuture.fail(event.cause());
            }
          } else {
            call.complete();
            reset();
            userFuture.complete(event.result());
          }
        });
        executeOperation(context, command, operationResult, call);
      } else {
        call.shortCircuited();
        invokeFallback(OpenCircuitException.INSTANCE, userFuture, fallback, call);
      }
    }
    return this;
  }

  /**
   * Creates the promise handed to attempt number {@code retryCount}.
   * <p>
   * On success it resets the breaker and completes {@code operationResult}. On failure, while the breaker is
   * still CLOSED, it schedules the next attempt after the retry-policy delay; the last attempt writes straight
   * into {@code operationResult}. If the breaker is no longer CLOSED, {@code operationResult} is failed with
   * {@link OpenCircuitException}.
   */
  private <T> Promise<T> retryFuture(Context context, int retryCount, Handler<Promise<T>> command, Promise<T>
    operationResult, CircuitBreakerMetrics.Operation call) {
    Promise<T> retry = Promise.promise();

    retry.future().setHandler(event -> {
      if (event.succeeded()) {
        reset();
        context.runOnContext(v -> {
          operationResult.complete(event.result());
        });
        return;
      }

      CircuitBreakerState currentState;
      synchronized (this) {
        currentState = state;
      }

      if (currentState == CircuitBreakerState.CLOSED) {
        if (retryCount < options.getMaxRetries() - 1) {
          executeRetryWithTimeout(retryCount, l -> {
            context.runOnContext(v -> {
              // Keep passing the same metrics operation down the chain so that the final attempt's timeout
              // or error is recorded whatever maxRetries is (it used to be dropped for maxRetries >= 2).
              executeOperation(context, command,
                retryFuture(context, retryCount + 1, command, operationResult, call), call);
            });
          });
        } else {
          executeRetryWithTimeout(retryCount, (l) -> {
            context.runOnContext(v -> {
              executeOperation(context, command, operationResult, call);
            });
          });
        }
      } else {
        context.runOnContext(v -> operationResult.fail(OpenCircuitException.INSTANCE));
      }
    });
    return retry;
  }

  /**
   * Runs {@code action} after the delay returned by the retry policy for retry number {@code retryCount + 1},
   * or immediately when that delay is not positive.
   */
  private void executeRetryWithTimeout(int retryCount, Handler<Void> action) {
    long retryTimeout = retryPolicy.apply(retryCount + 1);

    if (retryTimeout > 0) {
      vertx.setTimer(retryTimeout, (l) -> {
        action.handle(null);
      });
    } else {
      action.handle(null);
    }
  }

  /**
   * Completes {@code userFuture} from {@code fallback}, or fails it with {@code reason} when there is no
   * fallback. If the fallback throws, {@code userFuture} is failed with that exception.
   */
  private <T> void invokeFallback(Throwable reason, Promise<T> userFuture,
                                  Function<Throwable, T> fallback, CircuitBreakerMetrics.Operation operation) {
    if (fallback == null) {
      // No fallback, simply fail the user future.
      userFuture.fail(reason);
      return;
    }

    try {
      T apply = fallback.apply(reason);
      operation.fallbackSucceed();
      userFuture.complete(apply);
    } catch (Exception e) {
      // Record the metric before completing the user future, as on the success path.
      operation.fallbackFailed();
      userFuture.fail(e);
    }
  }

  /**
   * Runs a single attempt of {@code operation} and forwards its outcome to {@code operationResult}.
   * <p>
   * The result is applied on {@code context}, and only if {@code operationResult} is not already complete
   * (a timeout may have won the race). A synchronous exception from the operation counts as a failure.
   * {@code call} may be {@code null}, in which case the timeout/error is not reported to the metrics.
   */
  private <T> void executeOperation(Context context, Handler<Promise<T>> operation, Promise<T> operationResult,
                                    CircuitBreakerMetrics.Operation call) {
    long timeoutTimerId = scheduleTimeout(context, operationResult, call);

    try {
      Promise<T> passedFuture = Promise.promise();
      passedFuture.future().setHandler(ar -> {
        // The attempt finished on its own: release the timeout timer instead of letting it fire as a no-op.
        cancelTimeout(timeoutTimerId);
        context.runOnContext(v -> {
          if (operationResult.future().isComplete()) {
            return;
          }
          if (ar.failed()) {
            operationResult.fail(ar.cause());
          } else {
            operationResult.complete(ar.result());
          }
        });
      });

      operation.handle(passedFuture);
    } catch (Throwable e) {
      cancelTimeout(timeoutTimerId);
      context.runOnContext(v -> {
        if (!operationResult.future().isComplete()) {
          if (call != null) {
            call.error();
          }
          operationResult.fail(e);
        }
      });
    }
  }

  /**
   * Arms the per-attempt timeout, if one is configured. When it fires, it fails {@code operationResult} with
   * {@link TimeoutException} unless the attempt has already completed.
   *
   * @return the timer id, or {@code -1} when {@link CircuitBreakerOptions#getTimeout()} is {@code -1}
   */
  private <T> long scheduleTimeout(Context context, Promise<T> operationResult,
                                   CircuitBreakerMetrics.Operation call) {
    if (options.getTimeout() == -1) {
      return -1;
    }
    return vertx.setTimer(options.getTimeout(), (l) -> context.runOnContext(v -> {
      // Check if the operation has not already been completed.
      if (!operationResult.future().isComplete()) {
        if (call != null) {
          call.timeout();
        }
        operationResult.fail(TimeoutException.INSTANCE);
      }
      // Else: operation has completed.
    }));
  }

  /** Cancels a timer created by {@link #scheduleTimeout}; no-op for the {@code -1} sentinel. */
  private void cancelTimeout(long timerId) {
    if (timerId != -1) {
      vertx.cancelTimer(timerId);
    }
  }

  /** Executes {@code operation} using the default fallback, returning a future for its result. */
  @Override
  public <T> Future<T> executeWithFallback(Handler<Promise<T>> operation, Function<Throwable, T> fallback) {
    Promise<T> future = Promise.promise();
    executeAndReportWithFallback(future, operation, fallback);
    return future.future();
  }

  /** Executes {@code operation} using the fallback set through {@link #fallback(Function)}, if any. */
  public <T> Future<T> execute(Handler<Promise<T>> operation) {
    return executeWithFallback(operation, fallback);
  }

  /** Executes {@code operation}, reporting to {@code resultFuture}, using the default fallback. */
  @Override
  public <T> CircuitBreaker executeAndReport(Promise<T> resultFuture, Handler<Promise<T>> operation) {
    return executeAndReportWithFallback(resultFuture, operation, fallback);
  }

  @Override
  public String name() {
    return name;
  }

  /**
   * Records one failure and opens the breaker when the rolling count reaches {@code maxFailures}. Otherwise
   * (or when already OPEN) it publishes an update instead.
   */
  private synchronized void incrementFailures() {
    rollingFailures.increment();
    if (rollingFailures.count() >= options.getMaxFailures() && state != CircuitBreakerState.OPEN) {
      // open() publishes the update itself.
      open();
    } else {
      sendUpdateOnEventBus();
    }
  }

  /**
   * @return the current metrics.
   */
  public CircuitBreakerMetrics getMetrics() {
    return metrics;
  }

  /** Returns this breaker's private copy of its options. */
  public CircuitBreakerOptions options() {
    return options;
  }

  /**
   * Sets the function computing the delay (ms) before each retry from the 1-based retry number.
   *
   * @param retryPolicy the policy, must not be {@code null}
   */
  @Override
  public CircuitBreaker retryPolicy(Function<Integer, Long> retryPolicy) {
    Objects.requireNonNull(retryPolicy);
    this.retryPolicy = retryPolicy;
    return this;
  }

  /**
   * Counts events in time slots of {@code windowTimeUnit}, keeping at most {@code timeUnitsInWindow} slots.
   * <p>
   * Not thread-safe: the enclosing breaker only touches it from {@code synchronized} methods.
   * <p>
   * NOTE(review): with {@code timeUnitsInWindow == 0} (e.g. a rolling window shorter than one second), every
   * increment evicts its own slot immediately, so {@link #count()} always returns 0.
   */
  public static class RollingCounter {
    // Slot start (in windowTimeUnit since the epoch) -> events in that slot, in insertion (= time) order.
    private final Map<Long, Long> window;
    private final long timeUnitsInWindow;
    private final TimeUnit windowTimeUnit;

    public RollingCounter(long timeUnitsInWindow, TimeUnit windowTimeUnit) {
      this.windowTimeUnit = windowTimeUnit;
      this.window = new LinkedHashMap<>((int) timeUnitsInWindow + 1);
      this.timeUnitsInWindow = timeUnitsInWindow;
    }

    /** Adds one event to the current slot, evicting the oldest slot when the window is full. */
    public void increment() {
      long timeSlot = windowTimeUnit.convert(System.currentTimeMillis(), TimeUnit.MILLISECONDS);
      // merge() updates an existing key in place, so the LinkedHashMap's insertion order is preserved.
      window.merge(timeSlot, 1L, Long::sum);
      if (window.size() > timeUnitsInWindow) {
        Iterator<Long> iterator = window.keySet().iterator();
        if (iterator.hasNext()) {
          iterator.next();
          iterator.remove();
        }
      }
    }

    /** Returns the number of events in slots that start no earlier than one window ago. */
    public long count() {
      long windowStartTime = windowTimeUnit.convert(
        System.currentTimeMillis() - windowTimeUnit.toMillis(timeUnitsInWindow), TimeUnit.MILLISECONDS);
      return window.entrySet().stream()
        .filter(entry -> entry.getKey() >= windowStartTime)
        .mapToLong(Map.Entry::getValue)
        .sum();
    }

    /** Discards all recorded events. */
    public void reset() {
      window.clear();
    }
  }
}