package org.springframework.http.client.reactive;
import java.util.Collection;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.BiFunction;
import io.netty.buffer.ByteBufAllocator;
import io.netty.handler.codec.http.cookie.Cookie;
import io.netty.handler.codec.http.cookie.DefaultCookie;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import reactor.core.publisher.Flux;
import reactor.netty.Connection;
import reactor.netty.NettyInbound;
import reactor.netty.http.client.HttpClientResponse;
import org.springframework.core.io.buffer.DataBuffer;
import org.springframework.core.io.buffer.NettyDataBufferFactory;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.http.HttpStatus;
import org.springframework.http.ResponseCookie;
import org.springframework.lang.Nullable;
import org.springframework.util.CollectionUtils;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;
class ReactorClientHttpResponse implements ClientHttpResponse {
private static final Log logger = LogFactory.getLog(ReactorClientHttpResponse.class);
private final HttpClientResponse response;
private final HttpHeaders ;
private final NettyInbound inbound;
private final NettyDataBufferFactory bufferFactory;
private final AtomicInteger state = new AtomicInteger();
private final String logPrefix;
public ReactorClientHttpResponse(HttpClientResponse response, Connection connection) {
this.response = response;
MultiValueMap<String, String> adapter = new NettyHeadersAdapter(response.responseHeaders());
this.headers = HttpHeaders.readOnlyHttpHeaders(adapter);
this.inbound = connection.inbound();
this.bufferFactory = new NettyDataBufferFactory(connection.outbound().alloc());
this.logPrefix = (logger.isDebugEnabled() ? "[" + connection.channel().id().asShortText() + "] " : "");
}
@Deprecated
public ReactorClientHttpResponse(HttpClientResponse response, NettyInbound inbound, ByteBufAllocator alloc) {
this.response = response;
MultiValueMap<String, String> adapter = new NettyHeadersAdapter(response.responseHeaders());
this.headers = HttpHeaders.readOnlyHttpHeaders(adapter);
this.inbound = inbound;
this.bufferFactory = new NettyDataBufferFactory(alloc);
this.logPrefix = "";
}
@Override
public Flux<DataBuffer> getBody() {
return this.inbound.receive()
.doOnSubscribe(s -> {
if (this.state.compareAndSet(0, 1)) {
return;
}
if (this.state.get() == 2) {
throw new IllegalStateException(
"The client response body has been released already due to cancellation.");
}
})
.map(byteBuf -> {
byteBuf.retain();
return this.bufferFactory.wrap(byteBuf);
});
}
@Override
public HttpHeaders () {
return this.headers;
}
@Override
public HttpStatus getStatusCode() {
return HttpStatus.valueOf(getRawStatusCode());
}
@Override
public int getRawStatusCode() {
return this.response.status().code();
}
@Override
public MultiValueMap<String, ResponseCookie> getCookies() {
MultiValueMap<String, ResponseCookie> result = new LinkedMultiValueMap<>();
this.response.cookies().values().stream()
.flatMap(Collection::stream)
.forEach(cookie -> result.add(cookie.name(),
ResponseCookie.fromClientResponse(cookie.name(), cookie.value())
.domain(cookie.domain())
.path(cookie.path())
.maxAge(cookie.maxAge())
.secure(cookie.isSecure())
.httpOnly(cookie.isHttpOnly())
.sameSite(getSameSite(cookie))
.build()));
return CollectionUtils.unmodifiableMultiValueMap(result);
}
@Nullable
private static String getSameSite(Cookie cookie) {
if (cookie instanceof DefaultCookie) {
DefaultCookie defaultCookie = (DefaultCookie) cookie;
if (defaultCookie.sameSite() != null) {
return defaultCookie.sameSite().name();
}
}
return null;
}
void releaseAfterCancel(HttpMethod method) {
if (mayHaveBody(method) && this.state.compareAndSet(0, 2)) {
if (logger.isDebugEnabled()) {
logger.debug(this.logPrefix + "Releasing body, not yet subscribed.");
}
this.inbound.receive().doOnNext(byteBuf -> {}).subscribe(byteBuf -> {}, ex -> {});
}
}
private boolean mayHaveBody(HttpMethod method) {
int code = this.getRawStatusCode();
return !((code >= 100 && code < 200) || code == 204 || code == 205 ||
method.equals(HttpMethod.HEAD) || getHeaders().getContentLength() == 0);
}
@Override
public String toString() {
return "ReactorClientHttpResponse{" +
"request=[" + this.response.method().name() + " " + this.response.uri() + "]," +
"status=" + getRawStatusCode() + '}';
}
}