package jdk.incubator.http.internal.websocket;
import jdk.incubator.http.WebSocket;
import jdk.incubator.http.internal.common.Log;
import jdk.incubator.http.internal.common.Pair;
import jdk.incubator.http.internal.websocket.OpeningHandshake.Result;
import jdk.incubator.http.internal.websocket.OutgoingMessage.Binary;
import jdk.incubator.http.internal.websocket.OutgoingMessage.Close;
import jdk.incubator.http.internal.websocket.OutgoingMessage.Context;
import jdk.incubator.http.internal.websocket.OutgoingMessage.Ping;
import jdk.incubator.http.internal.websocket.OutgoingMessage.Pong;
import jdk.incubator.http.internal.websocket.OutgoingMessage.Text;
import java.io.IOException;
import java.net.ProtocolException;
import java.net.URI;
import java.nio.ByteBuffer;
import java.util.Queue;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionStage;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Consumer;
import java.util.function.Function;
import static java.util.Objects.requireNonNull;
import static java.util.concurrent.CompletableFuture.failedFuture;
import static jdk.incubator.http.internal.common.Pair.pair;
import static jdk.incubator.http.internal.websocket.StatusCodes.CLOSED_ABNORMALLY;
import static jdk.incubator.http.internal.websocket.StatusCodes.NO_STATUS_CODE;
import static jdk.incubator.http.internal.websocket.StatusCodes.isLegalToSendFromClient;
/**
 * Implementation of {@link WebSocket} on top of a {@code RawChannel}.
 *
 * <p>Outgoing messages are put on a queue and written one at a time by a
 * {@code Transmitter}, driven by a {@code CooperativeHandler}. Incoming
 * messages are read by a {@code Receiver} and dispatched to the
 * {@link Listener} while holding {@code lock}.
 *
 * <p>The channel is closed once both a Close message has been received
 * ({@code closeReceived}) and a Close message has been sent
 * ({@code closeSent}), or when {@link #abort()} is called.
 */
final class WebSocketImpl implements WebSocket {

    /** The URI this WebSocket was opened to (used by {@link #toString()}). */
    private final URI uri;
    /** The negotiated subprotocol; {@link #toString()} treats an empty string as "none". */
    private final String subprotocol;
    private final RawChannel channel;
    private final Listener listener;

    /**
     * Set once a terminal listener method ({@code onClose} or
     * {@code onError}) has been invoked. Guarded by {@code lock}.
     */
    private boolean lastMethodInvoked;
    /** True while a user-initiated send (text/binary/ping/pong) is in flight. */
    private final AtomicBoolean outstandingSend = new AtomicBoolean();
    /** Drives {@link #sendFirst(Runnable)} so queued messages are sent one at a time. */
    private final CooperativeHandler sendHandler =
            new CooperativeHandler(this::sendFirst);
    /** Messages waiting to be transmitted, each paired with the future to complete. */
    private final Queue<Pair<OutgoingMessage, CompletableFuture<WebSocket>>>
            queue = new ConcurrentLinkedQueue<>();
    private final Context context = new OutgoingMessage.Context();
    private final Transmitter transmitter;
    private final Receiver receiver;
    /** Becomes true once the channel has been closed (or abort() was called). */
    private final AtomicBoolean closed = new AtomicBoolean();
    /** Serializes invocations of the listener. */
    private final Object lock = new Object();
    /** Completed (at most once) when a Close message is received or input ends. */
    private final CompletableFuture<?> closeReceived = new CompletableFuture<>();
    /** Completed (at most once) when our Close message has been sent. */
    private final CompletableFuture<?> closeSent = new CompletableFuture<>();

    /**
     * Performs the opening handshake described by {@code b} and, on success,
     * creates a WebSocket and invokes the listener's {@code onOpen}.
     *
     * @param b the builder holding the URI, listener and handshake options
     * @return a future completed with the new WebSocket, or failed with an
     *         {@link IllegalArgumentException} if the handshake could not be
     *         set up, or with whatever error the handshake itself reports
     */
    static CompletableFuture<WebSocket> newInstanceAsync(BuilderImpl b) {
        Function<Result, WebSocket> newWebSocket = r -> {
            WebSocketImpl ws = new WebSocketImpl(b.getUri(),
                                                 r.subprotocol,
                                                 r.channel,
                                                 b.getListener());
            ws.signalOpen();
            return ws;
        };
        OpeningHandshake h;
        try {
            h = new OpeningHandshake(b);
        } catch (IllegalArgumentException e) {
            return failedFuture(e);
        }
        return h.send().thenApply(newWebSocket);
    }

    /**
     * Creates a WebSocket over an already-established channel.
     *
     * @throws NullPointerException if any argument is {@code null}
     */
    WebSocketImpl(URI uri,
                  String subprotocol,
                  RawChannel channel,
                  Listener listener)
    {
        this.uri = requireNonNull(uri);
        this.subprotocol = requireNonNull(subprotocol);
        this.channel = requireNonNull(channel);
        this.listener = requireNonNull(listener);
        this.transmitter = new Transmitter(channel);
        this.receiver = new Receiver(messageConsumerOf(listener), channel);

        // Once the closing handshake has completed in both directions the
        // underlying channel is no longer needed.
        CompletableFuture.allOf(closeReceived, closeSent)
                .whenComplete((result, error) -> {
                    try {
                        channel.close();
                    } catch (IOException e) {
                        Log.logError(e);
                    } finally {
                        closed.set(true);
                    }
                });
    }

    /**
     * Invokes {@code listener.onOpen}; an exception it throws is reported
     * via {@link #signalError(Throwable)}.
     */
    private void signalOpen() {
        synchronized (lock) {
            try {
                listener.onOpen(this);
            } catch (Exception e) {
                signalError(e);
            }
        }
    }

    /**
     * Delivers {@code error} to {@code listener.onError} unless a terminal
     * listener method has already been invoked, in which case the error is
     * only logged. Stops the receiver before notifying the listener.
     */
    private void signalError(Throwable error) {
        synchronized (lock) {
            if (lastMethodInvoked) {
                Log.logError(error);
            } else {
                lastMethodInvoked = true;
                receiver.close();
                try {
                    listener.onError(this, error);
                } catch (Exception e) {
                    Log.logError(e);
                }
            }
        }
    }

    /**
     * Handles the end of the inbound stream: either a Close message from the
     * peer or the input ending without one ({@code CLOSED_ABNORMALLY}).
     *
     * <p>Stops reading, notifies the listener via {@code onClose}, and, once
     * the stage returned by the listener completes, replies with a Close
     * message. No reply is sent if our own Close has already been sent.
     * RFC 6455 only requires a reply from an endpoint that has not already
     * sent a Close.
     *
     * @param statusCode the status code received (or a pseudo-code such as
     *                   {@code NO_STATUS_CODE} / {@code CLOSED_ABNORMALLY})
     * @param reason     the reason received
     * @throws InternalError if invoked more than once
     */
    private void processClose(int statusCode, String reason) {
        receiver.close();
        try {
            channel.shutdownInput();
        } catch (IOException e) {
            Log.logError(e);
        }
        boolean alreadyCompleted = !closeReceived.complete(null);
        if (alreadyCompleted) {
            // No messages are pulled from the channel after the first Close,
            // so this must only ever happen once.
            throw new InternalError();
        }
        // Pseudo-codes must not be sent on the wire; reply with a normal
        // closure instead.
        int code;
        if (statusCode == NO_STATUS_CODE || statusCode == CLOSED_ABNORMALLY) {
            code = NORMAL_CLOSURE;
        } else {
            code = statusCode;
        }
        CompletionStage<?> readyToClose = signalClose(statusCode, reason);
        if (readyToClose == null) {
            readyToClose = CompletableFuture.completedFuture(null);
        }
        readyToClose.whenComplete((r, error) -> {
            if (closeSent.isDone()) {
                // Our Close has already gone out (client-initiated close).
                // Completing closeReceived above has therefore already
                // triggered channel.close(). A second Close would be written
                // to a shut-down channel and would trip the InternalError in
                // enqueueClose.
                return;
            }
            enqueueClose(new Close(code, ""))
                    .whenComplete((r1, error1) -> {
                        if (error1 != null) {
                            Log.logError(error1);
                        }
                    });
        });
    }

    /**
     * Invokes {@code listener.onClose} unless a terminal listener method has
     * already been invoked (in which case the close is only traced).
     *
     * @return the stage returned by the listener, or {@code null} if the
     *         listener was not invoked, threw, or itself returned null
     */
    private CompletionStage<?> signalClose(int statusCode, String reason) {
        synchronized (lock) {
            if (lastMethodInvoked) {
                Log.logTrace("Close: {0}, ''{1}''", statusCode, reason);
            } else {
                lastMethodInvoked = true;
                receiver.close();
                try {
                    return listener.onClose(this, statusCode, reason);
                } catch (Exception e) {
                    Log.logError(e);
                }
            }
        }
        return null;
    }

    @Override
    public CompletableFuture<WebSocket> sendText(CharSequence message,
                                                 boolean isLast)
    {
        return enqueueExclusively(new Text(message, isLast));
    }

    @Override
    public CompletableFuture<WebSocket> sendBinary(ByteBuffer message,
                                                   boolean isLast)
    {
        return enqueueExclusively(new Binary(message, isLast));
    }

    @Override
    public CompletableFuture<WebSocket> sendPing(ByteBuffer message) {
        return enqueueExclusively(new Ping(message));
    }

    @Override
    public CompletableFuture<WebSocket> sendPong(ByteBuffer message) {
        return enqueueExclusively(new Pong(message));
    }

    /**
     * Sends a Close message. Invalid arguments are reported through the
     * returned future (as {@link IllegalArgumentException}), not thrown.
     */
    @Override
    public CompletableFuture<WebSocket> sendClose(int statusCode,
                                                  String reason) {
        if (!isLegalToSendFromClient(statusCode)) {
            return failedFuture(
                    new IllegalArgumentException("statusCode: " + statusCode));
        }
        Close msg;
        try {
            msg = new Close(statusCode, reason);
        } catch (IllegalArgumentException e) {
            return failedFuture(e);
        }
        return enqueueClose(msg);
    }

    /**
     * Queues a Close message. When it has been sent (or failed), shuts down
     * the channel's output and completes {@code closeSent}.
     *
     * <p>NOTE(review): {@code closeSent} can only be completed once, so a
     * second call makes the returned future fail with {@link InternalError}.
     */
    private CompletableFuture<WebSocket> enqueueClose(Close m) {
        return enqueue(m).whenComplete((r, error) -> {
            try {
                channel.shutdownOutput();
            } catch (IOException e) {
                Log.logError(e);
            }
            boolean alreadyCompleted = !closeSent.complete(null);
            if (alreadyCompleted) {
                throw new InternalError();
            }
        });
    }

    /**
     * Queues a user-initiated message, allowing at most one such send to be
     * outstanding at a time.
     *
     * @return a future for the send, or a failed future with
     *         {@link IllegalStateException} if the WebSocket is closed or
     *         another send is still outstanding
     */
    private CompletableFuture<WebSocket> enqueueExclusively(OutgoingMessage m)
    {
        if (closed.get()) {
            return failedFuture(new IllegalStateException("Closed"));
        }
        if (!outstandingSend.compareAndSet(false, true)) {
            return failedFuture(new IllegalStateException("Outstanding send"));
        }
        return enqueue(m).whenComplete((r, e) -> outstandingSend.set(false));
    }

    /**
     * Adds {@code m} to the send queue and kicks the send handler.
     *
     * @return a future completed when {@code m} has been transmitted
     */
    private CompletableFuture<WebSocket> enqueue(OutgoingMessage m) {
        CompletableFuture<WebSocket> cf = new CompletableFuture<>();
        boolean added = queue.add(pair(m, cf));
        if (!added) {
            // ConcurrentLinkedQueue is unbounded; add should never fail
            throw new InternalError();
        }
        sendHandler.handle();
        return cf;
    }

    /**
     * Sends the message at the head of the queue, if any.
     *
     * <p>{@code whenSent} is run exactly once on every path: immediately
     * when the queue is empty, otherwise after the message's future has been
     * completed. This includes the case where the message could not even be
     * handed to the transmitter. Without that, a synchronously rejected
     * message (for example, one that does not fit the current fragmentation
     * context) would leave the handler waiting forever and stall every
     * message queued behind it.
     *
     * @param whenSent signals the send handler that this round is finished
     */
    private void sendFirst(Runnable whenSent) {
        Pair<OutgoingMessage, CompletableFuture<WebSocket>> p = queue.poll();
        if (p == null) {
            whenSent.run();
            return;
        }
        OutgoingMessage message = p.first;
        CompletableFuture<WebSocket> cf = p.second;
        // Ensures the completion logic runs once, even if the transmitter
        // were to invoke the callback and then throw.
        AtomicBoolean finished = new AtomicBoolean();
        Consumer<Exception> h = e -> {
            if (!finished.compareAndSet(false, true)) {
                if (e != null) {
                    Log.logError(e);
                }
                return;
            }
            if (e == null) {
                cf.complete(WebSocketImpl.this);
            } else {
                cf.completeExceptionally(e);
            }
            sendHandler.handle();
            whenSent.run();
        };
        try {
            message.contextualize(context);
            transmitter.send(message, h);
        } catch (Exception t) {
            h.accept(t);
        }
    }

    @Override
    public void request(long n) {
        receiver.request(n);
    }

    @Override
    public String getSubprotocol() {
        return subprotocol;
    }

    @Override
    public boolean isClosed() {
        return closed.get();
    }

    /**
     * Closes the channel immediately, marks this WebSocket closed and
     * notifies the listener with {@code CLOSED_ABNORMALLY}, unless a
     * terminal listener method was already invoked.
     *
     * @throws IOException if closing the channel fails; the WebSocket is
     *                     still marked closed and the listener still notified
     */
    @Override
    public void abort() throws IOException {
        try {
            channel.close();
        } finally {
            closed.set(true);
            signalClose(CLOSED_ABNORMALLY, "");
        }
    }

    @Override
    public String toString() {
        return super.toString()
                + "[" + (closed.get() ? "CLOSED" : "OPEN") + "]: " + uri
                + (!subprotocol.isEmpty() ? ", subprotocol=" + subprotocol : "");
    }

    /**
     * Builds the consumer that the {@link Receiver} feeds with incoming
     * messages.
     *
     * <p>Each data/control callback first acknowledges the message to the
     * receiver, then invokes the listener while holding {@code lock}. An
     * exception thrown by the listener is routed to
     * {@link #signalError(Throwable)}.
     */
    private MessageStreamConsumer messageConsumerOf(Listener listener) {
        return new MessageStreamConsumer() {

            @Override
            public void onText(MessagePart part, CharSequence data) {
                receiver.acknowledge();
                synchronized (WebSocketImpl.this.lock) {
                    try {
                        listener.onText(WebSocketImpl.this, data, part);
                    } catch (Exception e) {
                        signalError(e);
                    }
                }
            }

            @Override
            public void onBinary(MessagePart part, ByteBuffer data) {
                receiver.acknowledge();
                synchronized (WebSocketImpl.this.lock) {
                    try {
                        listener.onBinary(WebSocketImpl.this, data.slice(), part);
                    } catch (Exception e) {
                        signalError(e);
                    }
                }
            }

            /**
             * Automatically replies with a Pong carrying a copy of the Ping
             * payload, then passes an independent view of the payload to
             * the listener.
             */
            @Override
            public void onPing(ByteBuffer data) {
                receiver.acknowledge();
                // Take the listener's view before put(data) consumes data
                ByteBuffer slice = data.slice();
                ByteBuffer copy = ByteBuffer.allocate(data.remaining())
                        .put(data)
                        .flip();
                CompletableFuture<WebSocket> pongSent = enqueue(new Pong(copy));
                pongSent.whenComplete(
                        (r, error) -> {
                            if (error != null) {
                                WebSocketImpl.this.signalError(error);
                            }
                        }
                );
                synchronized (WebSocketImpl.this.lock) {
                    try {
                        listener.onPing(WebSocketImpl.this, slice);
                    } catch (Exception e) {
                        signalError(e);
                    }
                }
            }

            @Override
            public void onPong(ByteBuffer data) {
                receiver.acknowledge();
                synchronized (WebSocketImpl.this.lock) {
                    try {
                        listener.onPong(WebSocketImpl.this, data.slice());
                    } catch (Exception e) {
                        signalError(e);
                    }
                }
            }

            @Override
            public void onClose(int statusCode, CharSequence reason) {
                receiver.acknowledge();
                processClose(statusCode, reason.toString());
            }

            /**
             * For a {@code FailWebSocketException}: sends a Close with its
             * status code, closes the channel, and reports a
             * {@link ProtocolException} (caused by the original error) to
             * the listener. Any other error goes straight to the listener.
             */
            @Override
            public void onError(Exception error) {
                if (!(error instanceof FailWebSocketException)) {
                    signalError(error);
                } else {
                    Exception ex = (Exception) new ProtocolException().initCause(error);
                    int code = ((FailWebSocketException) error).getStatusCode();
                    enqueueClose(new Close(code, ""))
                            .whenComplete((r, e) -> {
                                if (e != null) {
                                    ex.addSuppressed(e);
                                }
                                try {
                                    channel.close();
                                } catch (IOException e1) {
                                    ex.addSuppressed(e1);
                                } finally {
                                    closed.set(true);
                                }
                                signalError(ex);
                            });
                }
            }

            /** Input ended without a Close message. */
            @Override
            public void onComplete() {
                processClose(CLOSED_ABNORMALLY, "");
            }
        };
    }
}