package com.datastax.oss.driver.internal.core.protocol;
import com.datastax.oss.driver.shaded.guava.common.collect.ImmutableList;
import com.datastax.oss.protocol.internal.Frame;
import com.datastax.oss.protocol.internal.FrameCodec;
import com.datastax.oss.protocol.internal.PrimitiveCodec;
import com.datastax.oss.protocol.internal.Segment;
import com.datastax.oss.protocol.internal.SegmentBuilder;
import edu.umd.cs.findbugs.annotations.NonNull;
import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelPromise;
import io.netty.util.concurrent.Future;
import io.netty.util.concurrent.GenericFutureListener;
import java.util.ArrayList;
import java.util.List;
import net.jcip.annotations.NotThreadSafe;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
 * A {@link SegmentBuilder} whose payloads are Netty {@link ByteBuf}s and whose per-frame "state" is
 * the {@link ChannelPromise} associated with the write of that frame.
 *
 * <p>Segments produced by the parent class are written to the {@link ChannelHandlerContext} given
 * at construction time. The remaining hooks only emit trace-level logs.
 *
 * <p>This class is annotated {@link NotThreadSafe}: nothing in it is synchronized.
 */
@NotThreadSafe
public class ByteBufSegmentBuilder extends SegmentBuilder<ByteBuf, ChannelPromise> {

  private static final Logger LOG = LoggerFactory.getLogger(ByteBufSegmentBuilder.class);

  private final ChannelHandlerContext context;
  private final String logPrefix;

  /**
   * @param context the handler context used to create promises and to write segments.
   * @param primitiveCodec forwarded as-is to the parent constructor.
   * @param frameCodec forwarded as-is to the parent constructor.
   * @param logPrefix the prefix inserted between brackets at the start of every log message.
   */
  public ByteBufSegmentBuilder(
      @NonNull ChannelHandlerContext context,
      @NonNull PrimitiveCodec<ByteBuf> primitiveCodec,
      @NonNull FrameCodec<ByteBuf> frameCodec,
      @NonNull String logPrefix) {
    super(primitiveCodec, frameCodec);
    this.context = context;
    this.logPrefix = logPrefix;
  }

  /**
   * Combines the promises of several frames into a single promise for the segment.
   *
   * <p>A single-element list is returned unwrapped. Otherwise a fresh promise is created, and its
   * outcome (success, or failure with the same cause) is replicated onto every frame promise once
   * it completes.
   *
   * <p>NOTE(review): presumably invoked by the parent class when several frames are packed into
   * one self-contained segment -- confirm against {@code SegmentBuilder}.
   *
   * @param framePromises the promises of the individual frames.
   * @return the promise to associate with the segment write.
   */
  @Override
  @NonNull
  protected ChannelPromise mergeStates(@NonNull List<ChannelPromise> framePromises) {
    if (framePromises.size() == 1) {
      return framePromises.get(0);
    }
    // Snapshot the list so that later changes made by the caller cannot affect the fan-out.
    List<ChannelPromise> dependents = ImmutableList.copyOf(framePromises);
    ChannelPromise segmentPromise = context.newPromise();
    segmentPromise.addListener(outcome -> propagate(outcome, dependents));
    return segmentPromise;
  }

  /** Replicates the result of a completed future onto each of the given promises, in order. */
  private static void propagate(Future<?> outcome, List<ChannelPromise> dependents) {
    if (!outcome.isSuccess()) {
      Throwable cause = outcome.cause();
      for (ChannelPromise dependent : dependents) {
        dependent.setFailure(cause);
      }
      return;
    }
    for (ChannelPromise dependent : dependents) {
      dependent.setSuccess();
    }
  }

  /**
   * Creates one promise per slice of a frame, all tied back to the frame's own promise through a
   * shared {@link SliceWriteListener}.
   *
   * @param framePromise the promise of the frame being split.
   * @param sliceCount the number of slice promises to create.
   * @return the new slice promises, in creation order.
   */
  @Override
  @NonNull
  protected List<ChannelPromise> splitState(@NonNull ChannelPromise framePromise, int sliceCount) {
    List<ChannelPromise> slicePromises = new ArrayList<>(sliceCount);
    for (int created = 0; created < sliceCount; created++) {
      slicePromises.add(context.newPromise());
    }
    // The listener reads the list size in its constructor, so it must only be built once the list
    // is fully populated.
    SliceWriteListener tracker = new SliceWriteListener(framePromise, slicePromises);
    for (ChannelPromise slicePromise : slicePromises) {
      slicePromise.addListener(tracker);
    }
    return slicePromises;
  }

  /**
   * Writes the segment to the handler context, using the given promise for the write.
   *
   * @param segment the segment to write.
   * @param segmentPromise the promise passed along with the write.
   */
  @Override
  protected void processSegment(
      @NonNull Segment<ByteBuf> segment, @NonNull ChannelPromise segmentPromise) {
    context.write(segment, segmentPromise);
  }

  /**
   * Traces the stream id and length of a frame that exceeds {@link Segment#MAX_PAYLOAD_LENGTH},
   * along with the number of segments it is split into.
   */
  @Override
  protected void onLargeFrameSplit(@NonNull Frame frame, int frameLength, int sliceCount) {
    LOG.trace(
        "[{}] Frame {} is too large ({} > {}), splitting into {} segments",
        logPrefix,
        frame.streamId,
        frameLength,
        Segment.MAX_PAYLOAD_LENGTH,
        sliceCount);
  }

  /**
   * Traces the payload size and frame count of the current segment when it is reported full. The
   * {@code frame} and {@code frameLength} arguments are not part of the message.
   */
  @Override
  protected void onSegmentFull(
      @NonNull Frame frame, int frameLength, int currentPayloadLength, int currentFrameCount) {
    LOG.trace(
        "[{}] Current self-contained segment is full ({}/{} bytes, {} frames), processing now",
        logPrefix,
        currentPayloadLength,
        Segment.MAX_PAYLOAD_LENGTH,
        currentFrameCount);
  }

  /**
   * Traces the stream id of a frame added to the current segment, with the resulting payload size
   * and frame count. The {@code frameLength} argument is not part of the message.
   */
  @Override
  protected void onSmallFrameAdded(
      @NonNull Frame frame, int frameLength, int currentPayloadLength, int currentFrameCount) {
    LOG.trace(
        "[{}] Added frame {} to current self-contained segment "
            + "(bringing it to {}/{} bytes, {} frames)",
        logPrefix,
        frame.streamId,
        currentPayloadLength,
        Segment.MAX_PAYLOAD_LENGTH,
        currentFrameCount);
  }

  /** Traces the payload size and frame count of the last segment when it is flushed. */
  @Override
  protected void onLastSegmentFlushed(int currentPayloadLength, int currentFrameCount) {
    LOG.trace(
        "[{}] Flushing last self-contained segment ({}/{} bytes, {} frames)",
        logPrefix,
        currentPayloadLength,
        Segment.MAX_PAYLOAD_LENGTH,
        currentFrameCount);
  }

  /**
   * Listener shared by all the slice promises of one frame: it completes the frame's promise once
   * every slice has succeeded, or fails it as soon as one slice does not succeed.
   *
   * <p>The counter is a plain, unsynchronized {@code int}. NOTE(review): this presumably relies on
   * all slice completions being notified from a single thread -- confirm with the callers.
   */
  @NotThreadSafe
  static class SliceWriteListener implements GenericFutureListener<Future<Void>> {

    private final ChannelPromise parentPromise;
    private final List<ChannelPromise> slicePromises;
    private int remainingSlices;

    /**
     * @param parentPromise the promise of the frame that was split.
     * @param slicePromises the promises of all the slices; the list's current size is used as the
     *     number of successes to wait for.
     */
    SliceWriteListener(@NonNull ChannelPromise parentPromise, List<ChannelPromise> slicePromises) {
      this.parentPromise = parentPromise;
      this.slicePromises = slicePromises;
      this.remainingSlices = slicePromises.size();
    }

    /**
     * Handles the completion of one slice.
     *
     * <p>Does nothing if the parent promise is already done. On success, decrements the counter
     * and succeeds the parent when it reaches zero. Otherwise, fails the parent with the slice's
     * cause and requests cancellation of every slice promise.
     */
    @Override
    public void operationComplete(@NonNull Future<Void> future) {
      if (parentPromise.isDone()) {
        // The frame outcome is already decided; this includes the notifications that follow the
        // cancellations requested below.
        return;
      }
      if (future.isSuccess()) {
        remainingSlices -= 1;
        if (remainingSlices == 0) {
          parentPromise.setSuccess();
        }
        return;
      }
      parentPromise.setFailure(future.cause());
      for (ChannelPromise slicePromise : slicePromises) {
        // The argument is the "mayInterruptIfRunning" flag of java.util.concurrent.Future#cancel.
        slicePromise.cancel(false);
      }
    }
  }
}