package org.jboss.resteasy.plugins.providers.sse.client;
import java.io.IOException;
import java.util.Date;
import java.util.List;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Consumer;
import javax.ws.rs.ServiceUnavailableException;
import javax.ws.rs.client.Entity;
import javax.ws.rs.client.Invocation;
import javax.ws.rs.client.WebTarget;
import javax.ws.rs.core.MediaType;
import javax.ws.rs.core.Response.Status.Family;
import javax.ws.rs.ext.Providers;
import javax.ws.rs.sse.InboundSseEvent;
import javax.ws.rs.sse.SseEventSource;
import org.jboss.resteasy.client.jaxrs.ResteasyWebTarget;
import org.jboss.resteasy.client.jaxrs.internal.ClientConfiguration;
import org.jboss.resteasy.client.jaxrs.internal.ClientInvocation;
import org.jboss.resteasy.client.jaxrs.internal.ClientResponse;
import org.jboss.resteasy.plugins.providers.sse.SseConstants;
import org.jboss.resteasy.plugins.providers.sse.SseEventInputImpl;
import org.jboss.resteasy.resteasy_jaxrs.i18n.Messages;
public class SseEventSourceImpl implements SseEventSource
{
public static final long RECONNECT_DEFAULT = 500;
private final WebTarget target;
private final long reconnectDelay;
private final SseEventSourceScheduler sseEventSourceScheduler;
private enum State {
PENDING, OPEN, CLOSED
}
private final AtomicReference<State> state = new AtomicReference<>(State.PENDING);
private final List<Consumer<InboundSseEvent>> onEventConsumers = new CopyOnWriteArrayList<>();
private final List<Consumer<Throwable>> onErrorConsumers = new CopyOnWriteArrayList<>();
private final List<Runnable> onCompleteConsumers = new CopyOnWriteArrayList<>();
private final boolean alwaysReconnect;
private volatile ClientResponse response;
public static class SourceBuilder extends Builder
{
private WebTarget target = null;
private long reconnect = RECONNECT_DEFAULT;
private String name = null;
private ScheduledExecutorService executor;
private boolean alwaysReconnect = true;
public SourceBuilder()
{
}
public Builder named(String name)
{
this.name = name;
return this;
}
public SseEventSource build()
{
return new SseEventSourceImpl(target, name, reconnect, false, executor, alwaysReconnect);
}
@Override
public Builder target(WebTarget target)
{
if (target == null)
{
throw new NullPointerException();
}
this.target = target;
return this;
}
@Override
public Builder reconnectingEvery(long delay, TimeUnit unit)
{
reconnect = unit.toMillis(delay);
return this;
}
public Builder executor(ScheduledExecutorService executor)
{
this.executor = executor;
return this;
}
public Builder alwaysReconnect(boolean alwaysReconnect) {
this.alwaysReconnect = alwaysReconnect;
return this;
}
}
public SseEventSourceImpl(final WebTarget target)
{
this(target, true);
}
public SseEventSourceImpl(final WebTarget target, final boolean open)
{
this(target, null, RECONNECT_DEFAULT, open, null, true);
}
private SseEventSourceImpl(final WebTarget target, final String name, final long reconnectDelay, final boolean open, final ScheduledExecutorService executor, final boolean alwaysReconnect)
{
if (target == null)
{
throw new IllegalArgumentException(Messages.MESSAGES.webTargetIsNotSetForEventSource());
}
this.target = target;
this.reconnectDelay = reconnectDelay;
this.alwaysReconnect = alwaysReconnect;
if (executor == null)
{
ScheduledExecutorService scheduledExecutor = null;
if (target instanceof ResteasyWebTarget)
{
scheduledExecutor = ((ResteasyWebTarget) target).getResteasyClient().getScheduledExecutor();
}
if (name != null) {
this.sseEventSourceScheduler = new SseEventSourceScheduler(scheduledExecutor, name);
} else {
this.sseEventSourceScheduler = new SseEventSourceScheduler(scheduledExecutor, String.format("sse-event-source(%s)", target.getUri()));
}
}
else
{
if (name != null) {
this.sseEventSourceScheduler = new SseEventSourceScheduler(executor, name);
} else {
this.sseEventSourceScheduler = new SseEventSourceScheduler(executor, String.format("sse-event-source(%s)", target.getUri()));
}
}
if (open)
{
open();
}
}
@Override
public void open()
{
open(null);
}
public void open(String lastEventId)
{
open(lastEventId, "GET", null, MediaType.SERVER_SENT_EVENTS_TYPE);
}
public void open(String lastEventId, String verb, Entity<?> entity, MediaType... mediaTypes)
{
if (!state.compareAndSet(State.PENDING, State.OPEN))
{
throw new IllegalStateException(Messages.MESSAGES.eventSourceIsNotReadyForOpen());
}
EventHandler handler = new EventHandler(reconnectDelay, lastEventId, verb, entity, mediaTypes);
sseEventSourceScheduler.schedule(handler, 0, TimeUnit.SECONDS);
handler.awaitConnected();
}
@Override
public boolean isOpen()
{
return state.get() == State.OPEN;
}
@Override
public void register(Consumer<InboundSseEvent> onEvent)
{
if (onEvent == null)
{
throw new IllegalArgumentException();
}
onEventConsumers.add(onEvent);
}
@Override
public void register(Consumer<InboundSseEvent> onEvent, Consumer<Throwable> onError)
{
if (onEvent == null)
{
throw new IllegalArgumentException();
}
if (onError == null)
{
throw new IllegalArgumentException();
}
onEventConsumers.add(onEvent);
onErrorConsumers.add(onError);
}
@Override
public void register(Consumer<InboundSseEvent> onEvent, Consumer<Throwable> onError, Runnable onComplete)
{
if (onEvent == null)
{
throw new IllegalArgumentException();
}
if (onError == null)
{
throw new IllegalArgumentException();
}
if (onComplete == null)
{
throw new IllegalArgumentException();
}
onEventConsumers.add(onEvent);
onErrorConsumers.add(onError);
onCompleteConsumers.add(onComplete);
}
@Override
public boolean close(final long timeout, final TimeUnit unit)
{
internalClose();
try
{
return sseEventSourceScheduler.awaitTermination(timeout, unit);
}
catch (InterruptedException e)
{
onErrorConsumers.forEach(consumer -> {
consumer.accept(e);
});
Thread.currentThread().interrupt();
return false;
}
}
private void internalClose()
{
if (state.getAndSet(State.CLOSED) == State.CLOSED)
{
return;
}
final ClientResponse clientResponse = response;
if (clientResponse != null)
{
try
{
clientResponse.releaseConnection(false);
}
catch (IOException e)
{
onErrorConsumers.forEach(consumer -> {
consumer.accept(e);
});
}
}
sseEventSourceScheduler.shutdownNow();
onCompleteConsumers.forEach(Runnable::run);
}
private class EventHandler implements Runnable
{
private final CountDownLatch connectedLatch;
private String lastEventId;
private long reconnectDelay;
private String verb;
private Entity<?> entity;
private MediaType[] mediaTypes;
EventHandler(final long reconnectDelay, final String lastEventId, final String verb, final Entity<?> entity, final MediaType... mediaTypes)
{
this.connectedLatch = new CountDownLatch(1);
this.reconnectDelay = reconnectDelay;
this.lastEventId = lastEventId;
this.verb = verb;
this.entity = entity;
this.mediaTypes = mediaTypes;
}
private EventHandler(final EventHandler anotherHandler)
{
this.connectedLatch = anotherHandler.connectedLatch;
this.reconnectDelay = anotherHandler.reconnectDelay;
this.lastEventId = anotherHandler.lastEventId;
this.verb = anotherHandler.verb;
this.entity = anotherHandler.entity;
this.mediaTypes = anotherHandler.mediaTypes;
}
@Override
public void run()
{
if (state.get() != State.OPEN)
{
return;
}
SseEventInputImpl eventInput = null;
try
{
final Invocation.Builder requestBuilder = buildRequest(mediaTypes);
Invocation request = null;
if (entity == null)
{
request = requestBuilder.build(verb);
}
else
{
request = requestBuilder.build(verb, entity);
}
final ClientResponse clientResponse = (ClientResponse) request.invoke();
response = clientResponse;
if (Family.SUCCESSFUL.equals(clientResponse.getStatusInfo().getFamily()))
{
onConnection();
eventInput = clientResponse.readEntity(SseEventInputImpl.class);
if (eventInput == null && !alwaysReconnect)
{
internalClose();
return;
}
}
else
{
clientResponse.bufferEntity();
ClientInvocation.handleErrorStatus(clientResponse);
}
}
catch (ServiceUnavailableException ex)
{
if (ex.hasRetryAfter())
{
onConnection();
Date requestTime = new Date();
long localReconnectDelay = ex.getRetryTime(requestTime).getTime() - requestTime.getTime();
onErrorConsumers.forEach(consumer -> {
consumer.accept(ex);
});
reconnect(localReconnectDelay);
}
else
{
onUnrecoverableError(ex);
}
return;
}
catch (Throwable e)
{
onUnrecoverableError(e);
return;
}
final Providers providers = (ClientConfiguration) target.getConfiguration();
while (!Thread.currentThread().isInterrupted() && state.get() == State.OPEN)
{
if (eventInput == null || eventInput.isClosed())
{
reconnect(reconnectDelay);
break;
}
try
{
InboundSseEvent event = eventInput.read(providers);
if (event != null)
{
onEvent(event);
}
else if (!alwaysReconnect)
{
internalClose();
break;
}
}
catch (IOException e)
{
reconnect(reconnectDelay);
break;
}
}
}
public void awaitConnected()
{
try
{
connectedLatch.await();
}
catch (InterruptedException ex)
{
Thread.currentThread().interrupt();
}
}
private void onConnection()
{
connectedLatch.countDown();
}
private void onUnrecoverableError(Throwable throwable)
{
connectedLatch.countDown();
onErrorConsumers.forEach(consumer -> {
consumer.accept(throwable);
});
internalClose();
}
private void onEvent(final InboundSseEvent event)
{
if (event.getId() != null)
{
lastEventId = event.getId();
}
if (event.isReconnectDelaySet())
{
reconnectDelay = event.getReconnectDelay();
}
onEventConsumers.forEach(consumer -> {
consumer.accept(event);
});
}
private Invocation.Builder buildRequest(MediaType... mediaTypes)
{
final Invocation.Builder request = (mediaTypes != null && mediaTypes.length > 0) ? target.request(mediaTypes) : target.request();
if (lastEventId != null && !lastEventId.isEmpty())
{
request.header(SseConstants.LAST_EVENT_ID_HEADER, lastEventId);
}
return request;
}
private void reconnect(final long delay)
{
if (state.get() != State.OPEN)
{
return;
}
EventHandler processor = new EventHandler(this);
sseEventSourceScheduler.schedule(processor, delay, TimeUnit.MILLISECONDS);
}
}
}