package io.vertx.ext.web.handler.graphql.impl;
import graphql.ExecutionInput;
import graphql.ExecutionResult;
import io.vertx.core.Future;
import io.vertx.core.Handler;
import io.vertx.core.Promise;
import io.vertx.core.buffer.Buffer;
import io.vertx.core.http.ServerWebSocket;
import io.vertx.core.impl.ContextInternal;
import io.vertx.core.impl.logging.Logger;
import io.vertx.core.impl.logging.LoggerFactory;
import io.vertx.core.json.DecodeException;
import io.vertx.core.json.JsonObject;
import io.vertx.ext.web.common.WebEnvironment;
import io.vertx.ext.web.handler.graphql.ApolloWSConnectionInitEvent;
import io.vertx.ext.web.handler.graphql.ApolloWSMessage;
import io.vertx.ext.web.handler.graphql.ApolloWSMessageType;
import org.dataloader.DataLoaderRegistry;
import org.reactivestreams.Publisher;
import org.reactivestreams.Subscriber;
import org.reactivestreams.Subscription;
import java.io.PrintWriter;
import java.io.StringWriter;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.Executor;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicReference;
import static io.vertx.ext.web.handler.graphql.ApolloWSMessageType.*;
class ApolloWSConnectionHandler {
private static final Logger log = LoggerFactory.getLogger(ApolloWSConnectionHandler.class);
private static final short WS_INTERNAL_ERROR = 1011;
private final ApolloWSHandlerImpl apolloWSHandler;
private final ServerWebSocket serverWebSocket;
private final ContextInternal context;
private final Executor executor;
private final ConcurrentMap<String, Subscription> subscriptions;
private final Promise<Object> connectionPromise;
private final AtomicBoolean connectionInitialized;
ApolloWSConnectionHandler(ApolloWSHandlerImpl apolloWSHandler, ContextInternal context, ServerWebSocket serverWebSocket) {
this.apolloWSHandler = apolloWSHandler;
this.context = context;
this.serverWebSocket = serverWebSocket;
this.executor = task -> context.runOnContext(v -> task.run());
subscriptions = new ConcurrentHashMap<>();
connectionPromise = context.promise();
connectionInitialized = new AtomicBoolean(false);
}
void handleConnection() {
Handler<ServerWebSocket> ch = apolloWSHandler.getConnectionHandler();
if (ch != null) {
ch.handle(serverWebSocket);
}
serverWebSocket.binaryMessageHandler(this::handleBinaryMessage);
serverWebSocket.textMessageHandler(this::handleTextMessage);
serverWebSocket.closeHandler(this::close);
}
private void handleBinaryMessage(Buffer buffer) {
handleMessage(new JsonObject(buffer));
}
private void handleTextMessage(String text) {
handleMessage(new JsonObject(text));
}
private void handleMessage(JsonObject jsonObject) {
String opId = jsonObject.getString("id");
ApolloWSMessageType type = from(jsonObject.getString("type"));
if (type == null) {
sendMessage(opId, ERROR, "Unknown message type: " + jsonObject.getString("type"));
return;
}
ApolloWSMessageImpl message = new ApolloWSMessageImpl(serverWebSocket, type, jsonObject);
Handler<ApolloWSMessage> mh = apolloWSHandler.getMessageHandler();
if (mh != null) {
mh.handle(message);
}
Handler<ApolloWSConnectionInitEvent> connectionInitHandler = apolloWSHandler.getConnectionInitHandler();
switch (type) {
case CONNECTION_INIT:
if (!connectionInitialized.compareAndSet(false, true)) {
sendMessage(opId, ERROR, "CONNECTION_INIT can only be sent once")
.onComplete(v -> serverWebSocket.close(WS_INTERNAL_ERROR));
break;
}
if (connectionInitHandler != null) {
connectionInitHandler.handle(new ApolloWSConnectionInitEvent() {
@Override
public ApolloWSMessage message() {
return message;
}
@Override
public boolean tryComplete(Object o) {
return connectionPromise.tryComplete(o);
}
@Override
public boolean tryFail(Throwable throwable) {
return connectionPromise.tryFail(throwable);
}
@Override
public Future<Object> future() {
return connectionPromise.future();
}
});
} else {
connectionPromise.complete();
}
connectionPromise.future().onComplete(ar -> {
if (ar.succeeded()) {
connect();
} else {
sendMessage(opId, CONNECTION_ERROR, ar.cause().getMessage())
.onComplete(v -> serverWebSocket.close(WS_INTERNAL_ERROR));
}
});
break;
case CONNECTION_TERMINATE:
serverWebSocket.close();
break;
case START:
if (!connectionInitialized.get()) {
sendMessage(opId, ERROR, "CONNECTION_INIT has to be sent before START")
.onComplete(v -> serverWebSocket.close(WS_INTERNAL_ERROR));
break;
}
connectionPromise.future().onComplete(ar -> {
if (ar.succeeded()) {
ApolloWSMessage messageWithParams = new ApolloWSMessageImpl(serverWebSocket, type, jsonObject, ar.result());
start(messageWithParams);
} else {
sendMessage(opId, ERROR, ar.cause().getMessage());
stop(opId);
}
});
break;
case STOP:
stop(opId);
break;
default:
sendMessage(opId, ERROR, "Unsupported message type: " + type);
break;
}
}
private void connect() {
sendMessage(null, CONNECTION_ACK, null);
long keepAlive = apolloWSHandler.getKeepAlive();
if (keepAlive > 0) {
sendMessage(null, CONNECTION_KEEP_ALIVE, null);
context.setPeriodic(keepAlive, timerId -> {
if (serverWebSocket.isClosed()) {
context.owner().cancelTimer(timerId);
} else {
sendMessage(null, CONNECTION_KEEP_ALIVE, null);
}
});
}
}
private void start(ApolloWSMessage message) {
String opId = message.content().getString("id");
if (subscriptions.containsKey(opId)) {
stop(opId);
}
GraphQLQuery payload = new GraphQLQuery(message.content().getJsonObject("payload"));
ExecutionInput.Builder builder = ExecutionInput.newExecutionInput();
builder.query(payload.getQuery());
builder.context(apolloWSHandler.getQueryContext().apply(message));
DataLoaderRegistry registry = apolloWSHandler.getDataLoaderRegistry().apply(message);
if (registry != null) {
builder.dataLoaderRegistry(registry);
}
Locale locale = apolloWSHandler.getLocale().apply(message);
if (locale != null) {
builder.locale(locale);
}
String operationName = payload.getOperationName();
if (operationName != null) {
builder.operationName(operationName);
}
Map<String, Object> variables = payload.getVariables();
if (variables != null) {
builder.variables(variables);
}
apolloWSHandler.getGraphQL().executeAsync(builder).whenCompleteAsync((executionResult, throwable) -> {
if (throwable == null) {
if (executionResult.getData() instanceof Publisher) {
subscribe(opId, executionResult);
} else {
sendMessage(opId, DATA, new JsonObject(executionResult.toSpecification()));
sendMessage(opId, COMPLETE, null);
}
} else {
if (log.isDebugEnabled()) {
log.debug("Failed to execute GraphQL query, opId=" + opId, throwable);
}
sendMessage(opId, ERROR, toJsonObject(throwable));
}
}, executor);
}
private void subscribe(String opId, ExecutionResult executionResult) {
Publisher<ExecutionResult> publisher = executionResult.getData();
AtomicReference<Subscription> subscriptionRef = new AtomicReference<>();
publisher.subscribe(new Subscriber<ExecutionResult>() {
@Override
public void onSubscribe(Subscription s) {
subscriptionRef.set(s);
subscriptions.put(opId, s);
s.request(1);
}
@Override
public void onNext(ExecutionResult er) {
sendMessage(opId, DATA, new JsonObject(er.toSpecification()));
subscriptionRef.get().request(1);
}
@Override
public void onError(Throwable t) {
if (log.isDebugEnabled()) {
log.debug("GraphQL subscription terminated with error, opId=" + opId, t);
}
sendMessage(opId, ERROR, toJsonObject(t));
subscriptions.remove(opId);
}
@Override
public void onComplete() {
sendMessage(opId, COMPLETE, null);
subscriptions.remove(opId);
}
});
}
private void stop(String opId) {
Subscription subscription = subscriptions.get(opId);
if (subscription != null) {
subscription.cancel();
subscriptions.remove(opId);
}
}
private JsonObject toJsonObject(Throwable t) {
JsonObject res = new JsonObject().put("message", t.toString());
if (WebEnvironment.development()) {
StringWriter sw = new StringWriter();
try (PrintWriter writer = new PrintWriter(sw)) {
t.printStackTrace(writer);
writer.flush();
}
res.put("extensions", new JsonObject()
.put("exception", new JsonObject()
.put("stacktrace", sw.toString())));
}
return res;
}
private Future<Void> sendMessage(String opId, ApolloWSMessageType type, Object payload) {
Objects.requireNonNull(type, "type is null");
JsonObject message = new JsonObject();
if (opId != null) {
message.put("id", opId);
}
message.put("type", type.getText());
if (payload != null) {
message.put("payload", payload);
}
return serverWebSocket.writeTextMessage(message.toString());
}
private void close(Void v) {
subscriptions.values().forEach(Subscription::cancel);
Handler<ServerWebSocket> eh = apolloWSHandler.getEndHandler();
if (eh != null) {
eh.handle(serverWebSocket);
}
}
}