package org.springframework.http.server.reactive;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.channels.FileChannel;
import java.nio.file.Path;
import java.nio.file.StandardOpenOption;
import io.undertow.server.HttpServerExchange;
import io.undertow.server.handlers.Cookie;
import io.undertow.server.handlers.CookieImpl;
import org.reactivestreams.Processor;
import org.reactivestreams.Publisher;
import org.xnio.channels.Channels;
import org.xnio.channels.StreamSinkChannel;
import reactor.core.publisher.Mono;
import org.springframework.core.io.buffer.DataBuffer;
import org.springframework.core.io.buffer.DataBufferFactory;
import org.springframework.core.io.buffer.DataBufferUtils;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpStatus;
import org.springframework.http.ResponseCookie;
import org.springframework.http.ZeroCopyHttpOutputMessage;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
class UndertowServerHttpResponse extends AbstractListenerServerHttpResponse implements ZeroCopyHttpOutputMessage {
private final HttpServerExchange exchange;
private final UndertowServerHttpRequest request;
@Nullable
private StreamSinkChannel responseChannel;
UndertowServerHttpResponse(
HttpServerExchange exchange, DataBufferFactory bufferFactory, UndertowServerHttpRequest request) {
super(bufferFactory, createHeaders(exchange));
Assert.notNull(exchange, "HttpServerExchange must not be null");
this.exchange = exchange;
this.request = request;
}
private static HttpHeaders (HttpServerExchange exchange) {
UndertowHeadersAdapter headersMap =
new UndertowHeadersAdapter(exchange.getResponseHeaders());
return new HttpHeaders(headersMap);
}
@SuppressWarnings("unchecked")
@Override
public <T> T getNativeResponse() {
return (T) this.exchange;
}
@Override
public HttpStatus getStatusCode() {
HttpStatus httpStatus = super.getStatusCode();
return httpStatus != null ? httpStatus : HttpStatus.resolve(this.exchange.getStatusCode());
}
@Override
protected void applyStatusCode() {
Integer statusCode = getStatusCodeValue();
if (statusCode != null) {
this.exchange.setStatusCode(statusCode);
}
}
@Override
protected void () {
}
@Override
protected void applyCookies() {
for (String name : getCookies().keySet()) {
for (ResponseCookie httpCookie : getCookies().get(name)) {
Cookie cookie = new CookieImpl(name, httpCookie.getValue());
if (!httpCookie.getMaxAge().isNegative()) {
cookie.setMaxAge((int) httpCookie.getMaxAge().getSeconds());
}
if (httpCookie.getDomain() != null) {
cookie.setDomain(httpCookie.getDomain());
}
if (httpCookie.getPath() != null) {
cookie.setPath(httpCookie.getPath());
}
cookie.setSecure(httpCookie.isSecure());
cookie.setHttpOnly(httpCookie.isHttpOnly());
this.exchange.getResponseCookies().putIfAbsent(name, cookie);
}
}
}
@Override
public Mono<Void> writeWith(Path file, long position, long count) {
return doCommit(() ->
Mono.defer(() -> {
try (FileChannel source = FileChannel.open(file, StandardOpenOption.READ)) {
StreamSinkChannel destination = this.exchange.getResponseChannel();
Channels.transferBlocking(destination, source, position, count);
return Mono.empty();
}
catch (IOException ex) {
return Mono.error(ex);
}
}));
}
@Override
protected Processor<? super Publisher<? extends DataBuffer>, Void> createBodyFlushProcessor() {
return new ResponseBodyFlushProcessor();
}
private ResponseBodyProcessor createBodyProcessor() {
if (this.responseChannel == null) {
this.responseChannel = this.exchange.getResponseChannel();
}
return new ResponseBodyProcessor(this.responseChannel);
}
private class ResponseBodyProcessor extends AbstractListenerWriteProcessor<DataBuffer> {
private final StreamSinkChannel channel;
@Nullable
private volatile ByteBuffer byteBuffer;
private volatile boolean writePossible;
public ResponseBodyProcessor(StreamSinkChannel channel) {
super(request.getLogPrefix());
Assert.notNull(channel, "StreamSinkChannel must not be null");
this.channel = channel;
this.channel.getWriteSetter().set(c -> {
this.writePossible = true;
onWritePossible();
});
this.channel.suspendWrites();
}
@Override
protected boolean isWritePossible() {
this.channel.resumeWrites();
return this.writePossible;
}
@Override
protected boolean write(DataBuffer dataBuffer) throws IOException {
ByteBuffer buffer = this.byteBuffer;
if (buffer == null) {
return false;
}
this.writePossible = false;
int total = buffer.remaining();
int written = writeByteBuffer(buffer);
if (logger.isTraceEnabled()) {
logger.trace(getLogPrefix() + "Wrote " + written + " of " + total + " bytes");
}
else if (rsWriteLogger.isTraceEnabled()) {
rsWriteLogger.trace(getLogPrefix() + "Wrote " + written + " of " + total + " bytes");
}
if (written != total) {
return false;
}
this.writePossible = true;
DataBufferUtils.release(dataBuffer);
this.byteBuffer = null;
return true;
}
private int writeByteBuffer(ByteBuffer byteBuffer) throws IOException {
int written;
int totalWritten = 0;
do {
written = this.channel.write(byteBuffer);
totalWritten += written;
}
while (byteBuffer.hasRemaining() && written > 0);
return totalWritten;
}
@Override
protected void dataReceived(DataBuffer dataBuffer) {
super.dataReceived(dataBuffer);
this.byteBuffer = dataBuffer.asByteBuffer();
}
@Override
protected boolean isDataEmpty(DataBuffer dataBuffer) {
return (dataBuffer.readableByteCount() == 0);
}
@Override
protected void writingComplete() {
this.channel.getWriteSetter().set(null);
this.channel.resumeWrites();
}
@Override
protected void writingFailed(Throwable ex) {
cancel();
onError(ex);
}
@Override
protected void discardData(DataBuffer dataBuffer) {
DataBufferUtils.release(dataBuffer);
}
}
private class ResponseBodyFlushProcessor extends AbstractListenerWriteFlushProcessor<DataBuffer> {
public ResponseBodyFlushProcessor() {
super(request.getLogPrefix());
}
@Override
protected Processor<? super DataBuffer, Void> createWriteProcessor() {
return UndertowServerHttpResponse.this.createBodyProcessor();
}
@Override
protected void flush() throws IOException {
StreamSinkChannel channel = UndertowServerHttpResponse.this.responseChannel;
if (channel != null) {
if (rsWriteFlushLogger.isTraceEnabled()) {
rsWriteFlushLogger.trace(getLogPrefix() + "flush");
}
channel.flush();
}
}
@Override
protected void flushingFailed(Throwable t) {
cancel();
onError(t);
}
@Override
protected boolean isWritePossible() {
StreamSinkChannel channel = UndertowServerHttpResponse.this.responseChannel;
if (channel != null) {
channel.resumeWrites();
return true;
}
return false;
}
@Override
protected boolean isFlushPending() {
return false;
}
}
}