package org.jboss.resteasy.core;
import org.jboss.resteasy.annotations.Stream;
import org.jboss.resteasy.core.ResteasyContext.CloseableContext;
import org.jboss.resteasy.plugins.providers.sse.OutboundSseEventImpl;
import org.jboss.resteasy.plugins.providers.sse.SseConstants;
import org.jboss.resteasy.plugins.providers.sse.SseImpl;
import org.jboss.resteasy.resteasy_jaxrs.i18n.Messages;
import org.jboss.resteasy.specimpl.BuiltResponse;
import org.jboss.resteasy.specimpl.BuiltResponseEntityNotBacked;
import org.jboss.resteasy.specimpl.MultivaluedTreeMap;
import org.jboss.resteasy.spi.AsyncResponseProvider;
import org.jboss.resteasy.spi.AsyncStreamProvider;
import org.jboss.resteasy.spi.Dispatcher;
import org.jboss.resteasy.spi.HttpRequest;
import org.jboss.resteasy.spi.HttpResponse;
import org.jboss.resteasy.spi.ResteasyAsynchronousResponse;
import org.reactivestreams.Publisher;
import org.reactivestreams.Subscriber;
import org.reactivestreams.Subscription;
import javax.ws.rs.core.MediaType;
import javax.ws.rs.core.MultivaluedMap;
import javax.ws.rs.core.Response;
import javax.ws.rs.sse.OutboundSseEvent;
import javax.ws.rs.sse.SseEventSink;
import java.io.IOException;
import java.lang.reflect.ParameterizedType;
import java.lang.reflect.Type;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletionException;
import java.util.concurrent.CompletionStage;
import java.util.function.BiConsumer;
import java.util.function.Consumer;
public abstract class AsyncResponseConsumer
{
protected Map<Class<?>, Object> contextDataMap;
protected ResourceMethodInvoker method;
protected SynchronousDispatcher dispatcher;
protected ResteasyAsynchronousResponse asyncResponse;
protected boolean isComplete;
public AsyncResponseConsumer(final ResourceMethodInvoker method)
{
this.method = method;
contextDataMap = ResteasyContext.getContextDataMap();
dispatcher = (SynchronousDispatcher) contextDataMap.get(Dispatcher.class);
HttpRequest httpRequest = (HttpRequest) contextDataMap.get(HttpRequest.class);
if(httpRequest.getAsyncContext().isSuspended())
asyncResponse = httpRequest.getAsyncContext().getAsyncResponse();
else
asyncResponse = httpRequest.getAsyncContext().suspend();
}
public static AsyncResponseConsumer makeAsyncResponseConsumer(ResourceMethodInvoker method, AsyncResponseProvider<?> asyncResponseProvider) {
return new CompletionStageResponseConsumer(method, asyncResponseProvider);
}
public static AsyncResponseConsumer makeAsyncResponseConsumer(ResourceMethodInvoker method, AsyncStreamProvider<?> asyncStreamProvider) {
if(method.isSse())
{
return new AsyncGeneralStreamingSseResponseConsumer(method, asyncStreamProvider);
}
Stream stream = method.getMethod().getAnnotation(Stream.class);
if (stream != null)
{
if (Stream.MODE.RAW.equals(stream.value()))
{
return new AsyncRawStreamingResponseConsumer(method, asyncStreamProvider);
}
else
{
return new AsyncGeneralStreamingSseResponseConsumer(method, asyncStreamProvider);
}
}
return new AsyncStreamCollectorResponseConsumer(method, asyncStreamProvider);
}
protected void doComplete() {
asyncResponse.complete();
}
public final synchronized void complete(Throwable t)
{
if (!isComplete)
{
isComplete = true;
doComplete();
asyncResponse.completionCallbacks(t);
ResteasyContext.removeContextDataLevel();
}
}
protected void internalResume(Object entity, Consumer<Throwable> onComplete)
{
try(CloseableContext c = ResteasyContext.addCloseableContextDataLevel(contextDataMap)){
HttpRequest httpRequest = (HttpRequest) contextDataMap.get(HttpRequest.class);
HttpResponse httpResponse = (HttpResponse) contextDataMap.get(HttpResponse.class);
BuiltResponse builtResponse = createResponse(entity, httpRequest);
try
{
sendBuiltResponse(builtResponse, httpRequest, httpResponse, e -> {
if(e != null)
{
exceptionWhileResuming(e);
}
onComplete.accept(e);
});
}
catch (Throwable e)
{
exceptionWhileResuming(e);
onComplete.accept(e);
}
}
}
private void exceptionWhileResuming(Throwable e)
{
try
{
internalResume(e, t -> {});
}
catch(Throwable t2)
{
}
complete(e);
}
protected void sendBuiltResponse(BuiltResponse builtResponse, HttpRequest httpRequest, HttpResponse httpResponse, Consumer<Throwable> onComplete) throws IOException
{
boolean sendHeaders = sendHeaders();
ServerResponseWriter.writeNomapResponse(builtResponse, httpRequest, httpResponse, dispatcher.getProviderFactory(), onComplete, sendHeaders);
}
protected abstract boolean ();
protected void internalResume(Throwable t, Consumer<Throwable> onComplete)
{
try(CloseableContext c = ResteasyContext.addCloseableContextDataLevel(contextDataMap)){
HttpRequest httpRequest = (HttpRequest) contextDataMap.get(HttpRequest.class);
HttpResponse httpResponse = (HttpResponse) contextDataMap.get(HttpResponse.class);
try {
dispatcher.writeException(httpRequest, httpResponse, t, onComplete);
}catch(Throwable t2) {
dispatcher.unhandledAsynchronousException(httpResponse, t);
onComplete.accept(t);
}
}
}
protected BuiltResponse createResponse(Object entity, HttpRequest httpRequest)
{
BuiltResponse builtResponse = null;
if (entity == null)
{
builtResponse = (BuiltResponse) Response.noContent().build();
}
else if (entity instanceof BuiltResponse)
{
builtResponse = (BuiltResponse) entity;
}
else if (entity instanceof Response)
{
Response r = (Response) entity;
Headers<Object> metadata = new Headers<Object>();
metadata.putAll(r.getMetadata());
builtResponse = new BuiltResponseEntityNotBacked(r.getStatus(), r.getStatusInfo().getReasonPhrase(),
metadata, r.getEntity(), method.getMethodAnnotations());
}
else
{
if (method == null)
{
throw new IllegalStateException(Messages.MESSAGES.unknownMediaTypeResponseEntity());
}
BuiltResponse jaxrsResponse = (BuiltResponse) Response.ok(entity).build();
Type unwrappedType = ((ParameterizedType)method.getGenericReturnType()).getActualTypeArguments()[0];
Type newType = adaptGenericType(unwrappedType);
jaxrsResponse.setGenericType(newType);
jaxrsResponse.addMethodAnnotations(method.getMethodAnnotations());
builtResponse = jaxrsResponse;
}
return builtResponse;
}
protected Type adaptGenericType(Type unwrappedType)
{
return unwrappedType;
}
private static class CompletionStageResponseConsumer extends AsyncResponseConsumer implements BiConsumer<Object, Throwable>
{
private AsyncResponseProvider<?> asyncResponseProvider;
CompletionStageResponseConsumer(final ResourceMethodInvoker method, final AsyncResponseProvider<?> asyncResponseProvider)
{
super(method);
this.asyncResponseProvider = asyncResponseProvider;
}
@Override
protected boolean ()
{
return true;
}
@Override
public void accept(Object t, Throwable u)
{
if (t != null || u == null)
{
internalResume(t, x -> complete(null));
}
else
{
if(u instanceof CompletionException) {
u = u.getCause();
}
Throwable throwable = u;
internalResume(throwable, x -> complete(throwable));
}
}
@Override
public void subscribe(Object rtn)
{
@SuppressWarnings({ "unchecked", "rawtypes" })
CompletionStage<?> stage = ((AsyncResponseProvider)asyncResponseProvider).toCompletionStage(rtn);
stage.whenComplete(this);
}
}
private abstract static class AsyncStreamResponseConsumer extends AsyncResponseConsumer implements Subscriber<Object>
{
protected Subscription subscription;
private AsyncStreamProvider<?> asyncStreamProvider;
AsyncStreamResponseConsumer(final ResourceMethodInvoker method, final AsyncStreamProvider<?> asyncStreamProvider)
{
super(method);
this.asyncStreamProvider = asyncStreamProvider;
}
@Override
protected void doComplete()
{
if(subscription != null)
subscription.cancel();
super.doComplete();
}
@Override
public void onComplete()
{
complete(null);
}
@Override
public void onError(Throwable t)
{
internalResume(t, x -> complete(t));
}
protected void addNextElement(Object element)
{
internalResume(element, t -> {
if(t != null)
complete(t);
});
}
@Override
public void onNext(Object v)
{
addNextElement(v);
}
@Override
public void onSubscribe(Subscription subscription)
{
this.subscription = subscription;
subscription.request(1);
}
@Override
public void subscribe(Object rtn)
{
@SuppressWarnings({ "unchecked", "rawtypes" })
Publisher<?> publisher = ((AsyncStreamProvider)asyncStreamProvider).toAsyncStream(rtn);
publisher.subscribe(this);
}
}
private static class AsyncRawStreamingResponseConsumer extends AsyncStreamResponseConsumer
{
private boolean sentEntity;
private boolean onCompleteReceived;
private volatile boolean sendingEvent;
AsyncRawStreamingResponseConsumer(final ResourceMethodInvoker method, final AsyncStreamProvider<?> asyncStreamProvider)
{
super(method, asyncStreamProvider);
}
@Override
protected void sendBuiltResponse(BuiltResponse builtResponse, HttpRequest httpRequest, HttpResponse httpResponse, Consumer<Throwable> onComplete) throws IOException
{
ServerResponseWriter.setResponseMediaType(builtResponse, httpRequest, httpResponse, dispatcher.getProviderFactory(), method);
boolean resetMediaType = false;
String mediaTypeString = builtResponse.getHeaderString("Content-Type");
if (mediaTypeString == null)
{
mediaTypeString = MediaType.APPLICATION_OCTET_STREAM;
resetMediaType = true;
}
MediaType mediaType = MediaType.valueOf(mediaTypeString);
Stream[] streams = method.getMethod().getAnnotationsByType(Stream.class);
if (streams.length > 0)
{
Stream stream = streams[0];
if (stream.includeStreaming())
{
Map<String, String> map = new HashMap<String, String>(mediaType.getParameters());
map.put(Stream.INCLUDE_STREAMING_PARAMETER, "true");
mediaType = new MediaType(mediaType.getType(), mediaType.getSubtype(), map);
resetMediaType = true;
}
}
if (resetMediaType)
{
MultivaluedMap<String, Object> headerMap = new MultivaluedTreeMap<String, Object>();
headerMap.putAll(builtResponse.getHeaders());
headerMap.remove("Content-Type");
headerMap.add("Content-Type", mediaType);
builtResponse.setMetadata(headerMap);
}
super.sendBuiltResponse(builtResponse, httpRequest, httpResponse, onComplete);
sentEntity = true;
}
protected void addNextElement(Object element)
{
sendingEvent = true;
internalResume(element, t -> {
synchronized(this) {
sendingEvent = false;
if(onCompleteReceived) {
super.onComplete();
}
else if(t != null)
{
complete(t);
}
else
{
subscription.request(1);
}
}
});
}
@Override
public synchronized void onComplete()
{
onCompleteReceived = true;
if(sendingEvent == false)
super.onComplete();
}
@Override
protected boolean ()
{
return !sentEntity;
}
}
private static class AsyncStreamCollectorResponseConsumer extends AsyncStreamResponseConsumer
{
private List<Object> collector = new ArrayList<Object>();
AsyncStreamCollectorResponseConsumer(final ResourceMethodInvoker method, final AsyncStreamProvider<?> asyncStreamProvider)
{
super(method, asyncStreamProvider);
}
@Override
protected boolean ()
{
return true;
}
@Override
protected void addNextElement(Object element)
{
collector.add(element);
subscription.request(1);
}
@Override
public void onComplete()
{
internalResume(collector, t -> complete(t));
}
@Override
protected Type adaptGenericType(Type unwrappedType)
{
return new ParameterizedType()
{
@Override
public Type[] getActualTypeArguments() {
return new Type[]{unwrappedType};
}
@Override
public Type getOwnerType() {
return null;
}
@Override
public Type getRawType() {
return List.class;
}
};
}
}
private static class AsyncGeneralStreamingSseResponseConsumer extends AsyncStreamResponseConsumer
{
private SseImpl sse;
private SseEventSink sseEventSink;
private boolean onCompleteReceived;
private volatile boolean sendingEvent;
private AsyncGeneralStreamingSseResponseConsumer(final ResourceMethodInvoker method, final AsyncStreamProvider<?> asyncStreamProvider)
{
super(method, asyncStreamProvider);
sse = new SseImpl();
sseEventSink = ResteasyContext.getContextData(SseEventSink.class);
}
@Override
protected void doComplete()
{
if(subscription != null)
subscription.cancel();
sseEventSink.close();
}
@Override
protected void addNextElement(Object element)
{
super.addNextElement(element);
}
@Override
public synchronized void onComplete()
{
onCompleteReceived = true;
if(sendingEvent == false)
super.onComplete();
}
@Override
protected void sendBuiltResponse(BuiltResponse builtResponse, HttpRequest httpRequest, HttpResponse httpResponse, Consumer<Throwable> onComplete)
{
ServerResponseWriter.setResponseMediaType(builtResponse, httpRequest, httpResponse, dispatcher.getProviderFactory(), method);
MediaType elementType = null;
if (builtResponse.getEntity() instanceof OutboundSseEvent)
{
OutboundSseEvent entity = (OutboundSseEvent)builtResponse.getEntity();
elementType = entity.getMediaType();
}
MediaType contentType = null;
Object o = httpResponse.getOutputHeaders().getFirst("Content-Type");
if (o != null)
{
if (o instanceof String)
{
contentType = MediaType.valueOf((String) o);
}
else if (o instanceof MediaType)
{
contentType = (MediaType) o;
}
else
{
throw new RuntimeException(Messages.MESSAGES.expectedStringOrMediaType(o));
}
if (elementType == null)
{
String et = contentType.getParameters().get(SseConstants.SSE_ELEMENT_MEDIA_TYPE);
elementType = et != null ? MediaType.valueOf(et) : MediaType.TEXT_PLAIN_TYPE;
}
}
else
{
throw new RuntimeException(Messages.MESSAGES.expectedStringOrMediaType(o));
}
OutboundSseEvent event = sse.newEventBuilder()
.mediaType(elementType)
.data(builtResponse.getEntityClass(), builtResponse.getEntity())
.build();
if ("application".equals(contentType.getType())
&& "x-stream-general".equals(contentType.getSubtype())
&& event instanceof OutboundSseEventImpl)
{
((OutboundSseEventImpl) event).setEscape(true);
}
sendingEvent = true;
try {
sseEventSink.send(event).whenComplete((val, ex) -> {
synchronized(this) {
sendingEvent = false;
if(onCompleteReceived)
super.onComplete();
else if(ex != null)
{
complete(ex);
onComplete.accept(ex);
}
else
{
subscription.request(1);
onComplete.accept(ex);
}
}
});
}catch(Exception x) {
complete(x);
onComplete.accept(x);
}
}
@Override
protected boolean ()
{
return false;
}
}
public abstract void subscribe(Object rtn);
}