package io.vertx.ext.web.handler.graphql.impl;
import graphql.GraphQL;
import io.vertx.core.Handler;
import io.vertx.core.MultiMap;
import io.vertx.core.http.ServerWebSocket;
import io.vertx.core.impl.ContextInternal;
import io.vertx.ext.web.RoutingContext;
import io.vertx.ext.web.handler.graphql.ApolloWSConnectionInitEvent;
import io.vertx.ext.web.handler.graphql.ApolloWSHandler;
import io.vertx.ext.web.handler.graphql.ApolloWSMessage;
import io.vertx.ext.web.handler.graphql.ApolloWSOptions;
import org.dataloader.DataLoaderRegistry;
import java.util.Locale;
import java.util.Objects;
import java.util.function.Function;
import static io.vertx.core.http.HttpHeaders.*;
/**
 * Default {@link ApolloWSHandler} implementation.
 * <p>
 * This class holds:
 * <ul>
 *   <li>the {@link GraphQL} engine;</li>
 *   <li>the keep-alive value read from {@link ApolloWSOptions#getKeepAlive()};</li>
 *   <li>the user-configured callbacks and per-message factories.</li>
 * </ul>
 * A request that asks for a WebSocket upgrade is upgraded and handed to a new {@code ApolloWSConnectionHandler}.
 * Any other request is passed to the next route handler.
 * <p>
 * All configuration setters and the package-private getters are {@code synchronized}. Handlers may therefore be
 * configured on one thread and read on another.
 */
public class ApolloWSHandlerImpl implements ApolloWSHandler {

  /** Default query context factory: returns the incoming {@link ApolloWSMessage} itself. */
  private static final Function<ApolloWSMessage, Object> DEFAULT_QUERY_CONTEXT_FACTORY = context -> context;
  /** Default data loader registry factory: always returns {@code null}. */
  private static final Function<ApolloWSMessage, DataLoaderRegistry> DEFAULT_DATA_LOADER_REGISTRY_FACTORY = rc -> null;
  /** Default locale factory: always returns {@code null}. */
  private static final Function<ApolloWSMessage, Locale> DEFAULT_LOCALE_FACTORY = rc -> null;

  private final GraphQL graphQL;
  /** Keep-alive setting copied from the options at construction time. */
  private final long keepAlive;

  // Factories are never null: their setters fall back to the defaults above.
  private Function<ApolloWSMessage, Object> queryContextFactory = DEFAULT_QUERY_CONTEXT_FACTORY;
  private Function<ApolloWSMessage, DataLoaderRegistry> dataLoaderRegistryFactory = DEFAULT_DATA_LOADER_REGISTRY_FACTORY;
  private Function<ApolloWSMessage, Locale> localeFactory = DEFAULT_LOCALE_FACTORY;

  // Callbacks may be null: they have no default and their setters accept null.
  private Handler<ServerWebSocket> connectionHandler;
  private Handler<ApolloWSConnectionInitEvent> connectionInitHandler;
  private Handler<ServerWebSocket> endHandler;
  private Handler<ApolloWSMessage> messageHandler;

  /**
   * Creates a handler backed by the given GraphQL engine.
   *
   * @param graphQL the GraphQL engine, must not be {@code null}
   * @param options the handler options, must not be {@code null}; only the keep-alive value is read
   * @throws NullPointerException if {@code graphQL} or {@code options} is {@code null}
   */
  public ApolloWSHandlerImpl(GraphQL graphQL, ApolloWSOptions options) {
    Objects.requireNonNull(graphQL, "graphQL");
    Objects.requireNonNull(options, "options");
    this.graphQL = graphQL;
    this.keepAlive = options.getKeepAlive();
  }

  /** @return the GraphQL engine passed to the constructor */
  GraphQL getGraphQL() {
    return graphQL;
  }

  /** @return the keep-alive value read from the options at construction time */
  long getKeepAlive() {
    return keepAlive;
  }

  @Override
  public synchronized ApolloWSHandler connectionHandler(Handler<ServerWebSocket> connectionHandler) {
    this.connectionHandler = connectionHandler;
    return this;
  }

  /** @return the configured connection handler, or {@code null} if none was set */
  synchronized Handler<ServerWebSocket> getConnectionHandler() {
    return connectionHandler;
  }

  @Override
  public synchronized ApolloWSHandler connectionInitHandler(Handler<ApolloWSConnectionInitEvent> connectionInitHandler) {
    // Synchronized like every other setter. Without it, the write is not guaranteed to be visible to
    // threads reading through the synchronized getter below.
    this.connectionInitHandler = connectionInitHandler;
    return this;
  }

  /** @return the configured connection-init handler, or {@code null} if none was set */
  synchronized Handler<ApolloWSConnectionInitEvent> getConnectionInitHandler() {
    return connectionInitHandler;
  }

  @Override
  public synchronized ApolloWSHandler messageHandler(Handler<ApolloWSMessage> messageHandler) {
    this.messageHandler = messageHandler;
    return this;
  }

  /** @return the configured message handler, or {@code null} if none was set */
  synchronized Handler<ApolloWSMessage> getMessageHandler() {
    return messageHandler;
  }

  @Override
  public synchronized ApolloWSHandler endHandler(Handler<ServerWebSocket> endHandler) {
    this.endHandler = endHandler;
    return this;
  }

  /** @return the configured end handler, or {@code null} if none was set */
  synchronized Handler<ServerWebSocket> getEndHandler() {
    return endHandler;
  }

  @Override
  public synchronized ApolloWSHandler queryContext(Function<ApolloWSMessage, Object> factory) {
    queryContextFactory = factory != null ? factory : DEFAULT_QUERY_CONTEXT_FACTORY;
    return this;
  }

  /** @return the query context factory; never {@code null} */
  synchronized Function<ApolloWSMessage, Object> getQueryContext() {
    return queryContextFactory;
  }

  @Override
  public synchronized ApolloWSHandler dataLoaderRegistry(Function<ApolloWSMessage, DataLoaderRegistry> factory) {
    dataLoaderRegistryFactory = factory != null ? factory : DEFAULT_DATA_LOADER_REGISTRY_FACTORY;
    return this;
  }

  /** @return the data loader registry factory; never {@code null} */
  synchronized Function<ApolloWSMessage, DataLoaderRegistry> getDataLoaderRegistry() {
    return dataLoaderRegistryFactory;
  }

  @Override
  public synchronized ApolloWSHandler locale(Function<ApolloWSMessage, Locale> factory) {
    localeFactory = factory != null ? factory : DEFAULT_LOCALE_FACTORY;
    return this;
  }

  /** @return the locale factory; never {@code null} */
  synchronized Function<ApolloWSMessage, Locale> getLocale() {
    return localeFactory;
  }

  /**
   * Upgrades WebSocket requests and hands them to a new connection handler; passes other requests on.
   * <p>
   * A request counts as an upgrade request when it has a {@code Connection} header and an {@code Upgrade} header
   * whose value is {@code websocket} (compared case-insensitively).
   * <ul>
   *   <li>The current Vert.x context is captured before the upgrade and passed to the connection handler.</li>
   *   <li>If the upgrade fails, the failure goes to {@link RoutingContext#fail(Throwable)}, so the router's
   *       failure handling runs instead of the request being silently dropped.</li>
   *   <li>Requests that are not upgrade requests are passed to {@link RoutingContext#next()}.</li>
   * </ul>
   *
   * @param routingContext the routing context of the incoming request
   */
  @Override
  public void handle(RoutingContext routingContext) {
    MultiMap headers = routingContext.request().headers();
    if (headers.contains(CONNECTION) && headers.contains(UPGRADE, WEBSOCKET, true)) {
      ContextInternal context = (ContextInternal) routingContext.vertx().getOrCreateContext();
      routingContext.request().toWebSocket()
        .onFailure(routingContext::fail)
        .onSuccess(ws -> {
          ApolloWSConnectionHandler connectionHandler = new ApolloWSConnectionHandler(this, context, ws);
          connectionHandler.handleConnection();
        });
    } else {
      routingContext.next();
    }
  }
}