package io.vertx.ext.web.handler.graphql.impl;
import graphql.ExecutionInput;
import graphql.GraphQL;
import io.vertx.core.Context;
import io.vertx.core.MultiMap;
import io.vertx.core.buffer.Buffer;
import io.vertx.core.http.HttpHeaders;
import io.vertx.core.http.HttpMethod;
import io.vertx.core.impl.NoStackTraceThrowable;
import io.vertx.core.json.Json;
import io.vertx.core.json.JsonArray;
import io.vertx.core.json.JsonObject;
import io.vertx.ext.web.MIMEHeader;
import io.vertx.ext.web.RoutingContext;
import io.vertx.ext.web.handler.graphql.GraphQLHandler;
import io.vertx.ext.web.handler.graphql.GraphQLHandlerOptions;
import io.vertx.ext.web.handler.graphql.GraphiQLOptions;
import org.dataloader.DataLoaderRegistry;
import java.io.IOException;
import java.io.InputStream;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.Scanner;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Executor;
import java.util.function.Function;
import java.util.stream.Stream;
import static io.vertx.core.http.HttpMethod.GET;
import static io.vertx.core.http.HttpMethod.POST;
import static java.util.stream.Collectors.toList;
/**
 * {@link GraphQLHandler} implementation that executes GraphQL queries received over HTTP
 * {@code GET} and {@code POST} and, when enabled in the options, serves the GraphiQL page.
 *
 * <p>The three factory fields are written by {@code synchronized} setters and read under the
 * same monitor, so they may be reconfigured while requests are being handled.
 */
public class GraphQLHandlerImpl implements GraphQLHandler {

  private static final String GRAPHIQL_RESOURCE = "io/vertx/ext/web/handler/graphql/graphiql.html";
  private static final String GRAPHIQL_PLACEHOLDER = "<!-- VERTX-WEB-GRAPHIQL-REPLACEMENT -->";
  private static final String DEFAULT_CONTENT_TYPE = "application/json";

  private static final Function<RoutingContext, Object> DEFAULT_QUERY_CONTEXT_FACTORY = rc -> rc;
  private static final Function<RoutingContext, DataLoaderRegistry> DEFAULT_DATA_LOADER_REGISTRY_FACTORY = rc -> null;
  private static final Function<RoutingContext, MultiMap> DEFAULT_GRAPHIQL_REQUEST_HEADERS_FACTORY = rc -> null;

  private final GraphQL graphQL;
  private final GraphQLHandlerOptions options;

  private Function<RoutingContext, Object> queryContextFactory = DEFAULT_QUERY_CONTEXT_FACTORY;
  private Function<RoutingContext, DataLoaderRegistry> dataLoaderRegistryFactory = DEFAULT_DATA_LOADER_REGISTRY_FACTORY;
  private Function<RoutingContext, MultiMap> graphiQLRequestHeadersFactory = DEFAULT_GRAPHIQL_REQUEST_HEADERS_FACTORY;

  /**
   * @param graphQL the engine used to execute queries, must not be {@code null}
   * @param options the handler options, must not be {@code null}
   * @throws NullPointerException if either argument is {@code null}
   */
  public GraphQLHandlerImpl(GraphQL graphQL, GraphQLHandlerOptions options) {
    Objects.requireNonNull(graphQL, "graphQL");
    Objects.requireNonNull(options, "options");
    this.graphQL = graphQL;
    this.options = options;
  }

  /**
   * Sets the factory producing the execution context object of each query.
   * Passing {@code null} restores the default, which uses the {@link RoutingContext} itself.
   */
  @Override
  public synchronized GraphQLHandler queryContext(Function<RoutingContext, Object> factory) {
    queryContextFactory = factory != null ? factory : DEFAULT_QUERY_CONTEXT_FACTORY;
    return this;
  }

  /**
   * Sets the factory producing the {@link DataLoaderRegistry} of each query.
   * Passing {@code null} restores the default, which provides no registry.
   */
  @Override
  public synchronized GraphQLHandler dataLoaderRegistry(Function<RoutingContext, DataLoaderRegistry> factory) {
    dataLoaderRegistryFactory = factory != null ? factory : DEFAULT_DATA_LOADER_REGISTRY_FACTORY;
    return this;
  }

  /**
   * Sets the factory producing per-request headers injected into the GraphiQL page.
   * Passing {@code null} restores the default, which provides no extra headers.
   */
  @Override
  public synchronized GraphQLHandler graphiQLRequestHeaders(Function<RoutingContext, MultiMap> factory) {
    graphiQLRequestHeadersFactory = factory != null ? factory : DEFAULT_GRAPHIQL_REQUEST_HEADERS_FACTORY;
    return this;
  }

  /**
   * Dispatches on the HTTP method: {@code GET} and {@code POST} are processed, anything else
   * fails the routing context with status 405.
   */
  @Override
  public void handle(RoutingContext rc) {
    HttpMethod method = rc.request().method();
    if (method == GET) {
      handleGet(rc);
    } else if (method == POST) {
      Buffer body = rc.getBody();
      if (body == null) {
        // No body available on the routing context: read it from the request.
        rc.request().bodyHandler(buffer -> handlePost(rc, buffer));
      } else {
        handlePost(rc, body);
      }
    } else {
      rc.fail(405);
    }
  }

  /**
   * Serves GraphiQL when it is enabled and the client accepts HTML; otherwise executes the
   * query described by the {@code query}, {@code operationName} and {@code variables}
   * query parameters.
   */
  private void handleGet(RoutingContext rc) {
    if (options.getGraphiQLOptions().isEnabled()) {
      Stream<String> accept = rc.parsedHeaders().accept().stream().map(MIMEHeader::subComponent);
      if (accept.anyMatch(sub -> "html".equalsIgnoreCase(sub))) {
        handleGraphiQL(rc);
        return;
      }
    }
    String query = rc.queryParams().get("query");
    if (query == null) {
      failQueryMissing(rc);
      return;
    }
    Map<String, Object> variables;
    try {
      variables = getVariablesFromQueryParam(rc);
    } catch (Exception e) {
      rc.fail(400, e);
      return;
    }
    executeOne(rc, new GraphQLQuery(query, rc.queryParams().get("operationName"), variables));
  }

  /**
   * Loads the GraphiQL page from the classpath, injects the configured script snippet and
   * sends it. Any failure to load the page fails the routing context with status 500 so the
   * request is always answered.
   */
  private void handleGraphiQL(RoutingContext rc) {
    String html;
    ClassLoader classLoader = getClass().getClassLoader();
    try (InputStream stream = classLoader.getResourceAsStream(GRAPHIQL_RESOURCE)) {
      if (stream == null) {
        rc.fail(500, new NoStackTraceThrowable("GraphiQL page resource is missing"));
        return;
      }
      Scanner scanner = new Scanner(stream, "UTF-8").useDelimiter("\\A");
      String source = scanner.hasNext() ? scanner.next() : "";
      // Scanner hides read failures; surface them instead of serving a truncated page.
      IOException readFailure = scanner.ioException();
      if (readFailure != null) {
        throw readFailure;
      }
      String replacement = replacement(rc);
      html = replacement.isEmpty() ? source : source.replace(GRAPHIQL_PLACEHOLDER, replacement);
    } catch (IOException e) {
      rc.fail(500, e);
      return;
    }
    rc.response().putHeader(HttpHeaders.CONTENT_TYPE, "text/html;charset=utf-8").end(html);
  }

  /**
   * Builds the JavaScript snippet substituted for the placeholder in the GraphiQL page.
   * Every value is emitted as an escaped, single-quoted JavaScript string literal.
   *
   * @return the snippet, empty when nothing is configured
   */
  private String replacement(RoutingContext rc) {
    GraphiQLOptions graphiQLOptions = options.getGraphiQLOptions();
    StringBuilder builder = new StringBuilder();
    if (graphiQLOptions.getGraphQLUri() != null) {
      builder.append("var graphQLUri = '").append(escapeJs(graphiQLOptions.getGraphQLUri())).append("';");
    }
    MultiMap headers = MultiMap.caseInsensitiveMultiMap();
    Map<String, String> fixedHeaders = graphiQLOptions.getHeaders();
    if (fixedHeaders != null) {
      fixedHeaders.forEach(headers::add);
    }
    Function<RoutingContext, MultiMap> rh;
    synchronized (this) {
      rh = this.graphiQLRequestHeadersFactory;
    }
    MultiMap dynamicHeaders = rh.apply(rc);
    if (dynamicHeaders != null) {
      headers.addAll(dynamicHeaders);
    }
    headers.forEach(header -> builder
      .append("headers['").append(escapeJs(header.getKey()))
      .append("'] = '").append(escapeJs(header.getValue())).append("';"));
    if (graphiQLOptions.getQuery() != null) {
      builder.append("parameters['query'] = '").append(escapeJs(graphiQLOptions.getQuery())).append("';");
    }
    if (graphiQLOptions.getVariables() != null) {
      builder.append("parameters['variables'] = '").append(escapeJs(graphiQLOptions.getVariables().encode())).append("';");
    }
    return builder.toString();
  }

  /**
   * Escapes a value for inclusion in a single-quoted JavaScript string literal embedded in an
   * HTML page.
   *
   * @param value the value to escape; converted with {@link String#valueOf(Object)}
   * @return the escaped text
   */
  private static String escapeJs(Object value) {
    String text = String.valueOf(value);
    StringBuilder escaped = new StringBuilder(text.length() + 8);
    for (int i = 0; i < text.length(); i++) {
      char c = text.charAt(i);
      switch (c) {
        case '\\':
          escaped.append("\\\\");
          break;
        case '\'':
          escaped.append("\\'");
          break;
        case '\n':
          escaped.append("\\n");
          break;
        case '\r':
          escaped.append("\\r");
          break;
        case '<':
          // Prevents a literal "</script>" from terminating the enclosing script element.
          escaped.append("\\x3C");
          break;
        default:
          escaped.append(c);
      }
    }
    return escaped.toString();
  }

  /**
   * Handles a {@code POST} request. A {@code query} query parameter takes precedence over the
   * body; otherwise the body is interpreted according to the request content type
   * ({@code application/json} or {@code application/graphql}), anything else fails with 415.
   */
  private void handlePost(RoutingContext rc, Buffer body) {
    Map<String, Object> variables;
    try {
      variables = getVariablesFromQueryParam(rc);
    } catch (Exception e) {
      rc.fail(400, e);
      return;
    }
    String query = rc.queryParams().get("query");
    if (query != null) {
      executeOne(rc, new GraphQLQuery(query, rc.queryParams().get("operationName"), variables));
      return;
    }
    switch (getContentType(rc)) {
      case "application/json":
        handlePostJson(rc, body, rc.queryParams().get("operationName"), variables);
        break;
      case "application/graphql":
        executeOne(rc, new GraphQLQuery(body.toString(), rc.queryParams().get("operationName"), variables));
        break;
      default:
        rc.fail(415);
    }
  }

  /**
   * Decodes a JSON body into a single query or a batch and dispatches accordingly.
   * A body that cannot be decoded fails with status 400.
   */
  private void handlePostJson(RoutingContext rc, Buffer body, String operationName, Map<String, Object> variables) {
    GraphQLInput graphQLInput;
    try {
      graphQLInput = Json.decodeValue(body, GraphQLInput.class);
    } catch (Exception e) {
      rc.fail(400, e);
      return;
    }
    if (graphQLInput instanceof GraphQLBatch) {
      handlePostBatch(rc, (GraphQLBatch) graphQLInput, operationName, variables);
    } else if (graphQLInput instanceof GraphQLQuery) {
      handlePostQuery(rc, (GraphQLQuery) graphQLInput, operationName, variables);
    } else {
      rc.fail(500);
    }
  }

  /**
   * Validates a batch and executes it. Fails with status 400 when batching is disabled or a
   * query text is missing. Non-null {@code operationName} and {@code variables} (taken from
   * the query string by the caller) override the values of every query in the batch.
   */
  private void handlePostBatch(RoutingContext rc, GraphQLBatch batch, String operationName, Map<String, Object> variables) {
    if (!options.isRequestBatchingEnabled()) {
      rc.fail(400);
      return;
    }
    for (GraphQLQuery query : batch) {
      if (query.getQuery() == null) {
        failQueryMissing(rc);
        return;
      }
      if (operationName != null) {
        query.setOperationName(operationName);
      }
      if (variables != null) {
        query.setVariables(variables);
      }
    }
    executeBatch(rc, batch);
  }

  /**
   * Executes every query of the batch and replies with a JSON array of the results, in batch
   * order. If any execution fails, the routing context is failed with that error.
   */
  private void executeBatch(RoutingContext rc, GraphQLBatch batch) {
    List<CompletableFuture<JsonObject>> results = batch.stream()
      .map(q -> execute(rc, q))
      .collect(toList());
    CompletableFuture.allOf(results.toArray(new CompletableFuture<?>[0])).whenCompleteAsync((v, throwable) -> {
      if (throwable != null) {
        // Joining a failed future would throw here and leave the request unanswered.
        sendResponse(rc, null, throwable);
        return;
      }
      JsonArray jsonArray = results.stream()
        .map(CompletableFuture::join)
        .collect(JsonArray::new, JsonArray::add, JsonArray::addAll);
      sendResponse(rc, jsonArray.toBuffer(), null);
    }, contextExecutor(rc));
  }

  /**
   * Validates a single JSON query and executes it. Fails with status 400 when the query text
   * is missing. Non-null {@code operationName} and {@code variables} override the body values.
   */
  private void handlePostQuery(RoutingContext rc, GraphQLQuery query, String operationName, Map<String, Object> variables) {
    if (query.getQuery() == null) {
      failQueryMissing(rc);
      return;
    }
    if (operationName != null) {
      query.setOperationName(operationName);
    }
    if (variables != null) {
      query.setVariables(variables);
    }
    executeOne(rc, query);
  }

  /** Executes a single query and replies with its JSON result, or fails the routing context. */
  private void executeOne(RoutingContext rc, GraphQLQuery query) {
    execute(rc, query)
      .thenApply(JsonObject::toBuffer)
      .whenComplete((buffer, throwable) -> sendResponse(rc, buffer, throwable));
  }

  /**
   * Builds the execution input for {@code query} (operation name, variables, context object
   * and data loader registry when present) and runs it asynchronously.
   *
   * @return a future of the execution result in specification form; the conversion to
   *         {@link JsonObject} is scheduled through {@link #contextExecutor(RoutingContext)}
   */
  private CompletableFuture<JsonObject> execute(RoutingContext rc, GraphQLQuery query) {
    ExecutionInput.Builder builder = ExecutionInput.newExecutionInput();
    builder.query(query.getQuery());
    String operationName = query.getOperationName();
    if (operationName != null) {
      builder.operationName(operationName);
    }
    Map<String, Object> variables = query.getVariables();
    if (variables != null) {
      builder.variables(variables);
    }
    Function<RoutingContext, Object> qc;
    synchronized (this) {
      qc = queryContextFactory;
    }
    builder.context(qc.apply(rc));
    Function<RoutingContext, DataLoaderRegistry> dlr;
    synchronized (this) {
      dlr = dataLoaderRegistryFactory;
    }
    DataLoaderRegistry registry = dlr.apply(rc);
    if (registry != null) {
      builder.dataLoaderRegistry(registry);
    }
    return graphQL.executeAsync(builder.build())
      .thenApplyAsync(executionResult -> new JsonObject(executionResult.toSpecification()), contextExecutor(rc));
  }

  /**
   * Returns the lower-cased media type of the request, without parameters such as
   * {@code charset}, or {@code application/json} when the header is absent.
   */
  private String getContentType(RoutingContext rc) {
    String contentType = rc.request().headers().get(HttpHeaders.CONTENT_TYPE);
    if (contentType == null) {
      return DEFAULT_CONTENT_TYPE;
    }
    // "application/json; charset=utf-8" must still be recognised as JSON.
    int parametersStart = contentType.indexOf(';');
    String mediaType = parametersStart < 0 ? contentType : contentType.substring(0, parametersStart);
    return mediaType.trim().toLowerCase(Locale.ROOT);
  }

  /**
   * Parses the {@code variables} query parameter as a JSON object.
   *
   * @return the variables map, or {@code null} when the parameter is absent
   * @throws Exception if the parameter value cannot be parsed as a JSON object
   */
  private Map<String, Object> getVariablesFromQueryParam(RoutingContext rc) throws Exception {
    String variablesParam = rc.queryParams().get("variables");
    if (variablesParam == null) {
      return null;
    } else {
      return new JsonObject(variablesParam).getMap();
    }
  }

  /** Sends {@code buffer} as a JSON response, or fails the routing context with {@code throwable}. */
  private void sendResponse(RoutingContext rc, Buffer buffer, Throwable throwable) {
    if (throwable == null) {
      rc.response().putHeader(HttpHeaders.CONTENT_TYPE, "application/json").end(buffer);
    } else {
      rc.fail(throwable);
    }
  }

  /** Fails the routing context with status 400 because no query text was provided. */
  private void failQueryMissing(RoutingContext rc) {
    rc.fail(400, new NoStackTraceThrowable("Query is missing"));
  }

  /**
   * Returns an executor that runs commands on the context obtained from
   * {@code rc.vertx().getOrCreateContext()} at the time this method is called.
   */
  private Executor contextExecutor(RoutingContext rc) {
    Context ctx = rc.vertx().getOrCreateContext();
    return command -> ctx.runOnContext(v -> command.run());
  }
}