package org.apache.lucene.store;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.nio.ByteBuffer;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.function.Consumer;
import java.util.function.IntConsumer;
import java.util.function.IntFunction;
import org.apache.lucene.util.Accountable;
import org.apache.lucene.util.BitUtil;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.RamUsageEstimator;
import org.apache.lucene.util.UnicodeUtil;
public final class ByteBuffersDataOutput extends DataOutput implements Accountable {
/**
 * Shared zero-capacity buffer. It is the initial value of {@code currentBlock} (so the first
 * write always triggers a block allocation) and the single element returned by the buffer-list
 * methods when no block has been allocated.
 */
private final static ByteBuffer EMPTY = ByteBuffer.allocate(0);
/** Shared empty array returned by {@link #toArrayCopy()} when there are no blocks. */
private final static byte [] EMPTY_BYTE_ARRAY = {};
/** Block allocator that creates heap buffers via {@link ByteBuffer#allocate(int)}. */
public final static IntFunction<ByteBuffer> ALLOCATE_BB_ON_HEAP = ByteBuffer::allocate;
/**
 * Block-reuse callback that refuses to recycle: it throws a {@link RuntimeException} for any
 * buffer passed to it. {@code rewriteToBlockSize} compares against this exact instance to skip
 * recycling; {@link #reset()} does not, so resetting an output that holds blocks and uses this
 * callback throws.
 */
public final static Consumer<ByteBuffer> NO_REUSE = (bb) -> {
throw new RuntimeException("reset() is not allowed on this buffer.");
};
/**
 * A simple FIFO pool of byte buffers placed in front of a delegate allocator.
 *
 * <p>Buffers handed to {@link #reuse(ByteBuffer)} are rewound and queued. {@link #allocate(int)}
 * drains the queue from the front until it finds a buffer whose {@code remaining()} equals the
 * requested size; buffers that do not match are dropped from the pool. When the queue runs dry
 * the delegate is asked for a new buffer.
 */
public static final class ByteBufferRecycler {
  /** Recycled buffers, oldest first. */
  private final ArrayDeque<ByteBuffer> pool = new ArrayDeque<>();
  /** Fallback allocator used when the pool has no buffer of the requested size. */
  private final IntFunction<ByteBuffer> allocator;

  /**
   * @param delegate allocator invoked with the requested size when no pooled buffer matches
   * @throws NullPointerException if {@code delegate} is null
   */
  public ByteBufferRecycler(IntFunction<ByteBuffer> delegate) {
    this.allocator = Objects.requireNonNull(delegate);
  }

  /**
   * Returns a buffer with {@code size} remaining bytes, preferring a pooled one.
   *
   * @param size requested number of remaining bytes
   * @return a pooled buffer of exactly that size, or whatever the delegate returns
   */
  public ByteBuffer allocate(int size) {
    // pollFirst() yields null once the pool is empty (ArrayDeque never stores null elements).
    for (ByteBuffer candidate = pool.pollFirst(); candidate != null; candidate = pool.pollFirst()) {
      if (candidate.remaining() == size) {
        return candidate;
      }
      // Wrong size: the candidate has already been removed from the pool and is simply dropped.
    }
    return allocator.apply(size);
  }

  /**
   * Rewinds the buffer (position back to zero, limit untouched) and appends it to the pool.
   *
   * @param buffer the buffer to make available for later {@link #allocate(int)} calls
   */
  public void reuse(ByteBuffer buffer) {
    buffer.rewind();
    pool.addLast(buffer);
  }
}
/**
 * Default initial block-size exponent; a block holds {@code 1 << bits} bytes. Also the lower
 * clamp applied when the exponent is derived from an expected size.
 */
public final static int DEFAULT_MIN_BITS_PER_BLOCK = 10;
/** Default upper bound for the block-size exponent. */
public final static int DEFAULT_MAX_BITS_PER_BLOCK = 26;
/**
 * Block count at which {@code appendBlock()} rewrites the data into blocks of twice the size
 * (while the maximum exponent has not been reached) instead of just adding another block.
 */
final static int MAX_BLOCKS_BEFORE_BLOCK_EXPANSION = 100;
/** Largest block-size exponent this instance may grow to. */
private final int maxBitsPerBlock;
/** Allocates a new block; called with the required block size in bytes. */
private final IntFunction<ByteBuffer> blockAllocate;
/** Receives blocks that are released by {@code reset()} or while rewriting to a larger block size. */
private final Consumer<ByteBuffer> blockReuse;
/** Current block-size exponent; starts at the configured minimum and is raised by rewrites. */
private int blockBits;
/** Allocated blocks, in write order. */
private final ArrayDeque<ByteBuffer> blocks = new ArrayDeque<>();
/** Block that receives writes; the shared empty buffer until the first block is allocated. */
private ByteBuffer currentBlock = EMPTY;
/**
 * Creates an output whose initial block size is derived from the expected number of bytes.
 * Blocks are allocated on the heap, may grow up to {@link #DEFAULT_MAX_BITS_PER_BLOCK} bits and
 * are not recyclable.
 *
 * @param expectedSize expected total number of bytes to be written
 */
public ByteBuffersDataOutput(long expectedSize) {
  this(
      computeBlockSizeBitsFor(expectedSize),
      DEFAULT_MAX_BITS_PER_BLOCK,
      ALLOCATE_BB_ON_HEAP,
      NO_REUSE);
}
/**
 * Creates an output with the default block-size range
 * ({@link #DEFAULT_MIN_BITS_PER_BLOCK} to {@link #DEFAULT_MAX_BITS_PER_BLOCK} bits), heap
 * allocation and no block recycling.
 */
public ByteBuffersDataOutput() {
  this(
      DEFAULT_MIN_BITS_PER_BLOCK,
      DEFAULT_MAX_BITS_PER_BLOCK,
      ALLOCATE_BB_ON_HEAP,
      NO_REUSE);
}
public ByteBuffersDataOutput(int minBitsPerBlock,
int maxBitsPerBlock,
IntFunction<ByteBuffer> blockAllocate,
Consumer<ByteBuffer> blockReuse) {
if (minBitsPerBlock < 10 ||
minBitsPerBlock > maxBitsPerBlock ||
maxBitsPerBlock > 31) {
throw new IllegalArgumentException(String.format(Locale.ROOT,
"Invalid arguments: %s %s",
minBitsPerBlock,
maxBitsPerBlock));
}
this.maxBitsPerBlock = maxBitsPerBlock;
this.blockBits = minBitsPerBlock;
this.blockAllocate = Objects.requireNonNull(blockAllocate, "Block allocator must not be null.");
this.blockReuse = Objects.requireNonNull(blockReuse, "Block reuse must not be null.");
}
/**
 * Appends a single byte, allocating a new block first when the current one is full.
 *
 * @param b the byte to write
 */
@Override
public void writeByte(byte b) {
  if (currentBlock.remaining() == 0) {
    appendBlock();
  }
  currentBlock.put(b);
}
/**
 * Appends {@code length} bytes from {@code src}, starting at {@code offset}. The data is split
 * across blocks as needed; a new block is appended whenever the current one is full.
 *
 * @param src source array
 * @param offset index of the first byte to copy
 * @param length number of bytes to copy; must not be negative
 */
@Override
public void writeBytes(byte[] src, int offset, int length) {
  assert length >= 0;
  int from = offset;
  int pending = length;
  while (pending > 0) {
    if (currentBlock.remaining() == 0) {
      appendBlock();
    }
    // Copy as much as fits into the current block, but never more than what is left to write.
    final int count = Math.min(pending, currentBlock.remaining());
    currentBlock.put(src, from, count);
    from += count;
    pending -= count;
  }
}
/**
 * Appends the first {@code length} bytes of {@code b}.
 *
 * @param b source array
 * @param length number of bytes to copy, counted from index 0
 */
@Override
public void writeBytes(byte[] b, int length) {
  final int start = 0;
  writeBytes(b, start, length);
}
/**
 * Appends the entire array.
 *
 * @param b source array; all of its bytes are written
 */
public void writeBytes(byte[] b) {
  final int count = b.length;
  writeBytes(b, 0, count);
}
/**
 * Appends the remaining bytes of {@code buffer}. The copy works on a duplicate, so the position
 * and limit of the caller's buffer are left unchanged.
 *
 * @param buffer source buffer; bytes between its position and limit are written
 */
public void writeBytes(ByteBuffer buffer) {
  final ByteBuffer source = buffer.duplicate();
  for (int pending = source.remaining(); pending > 0; ) {
    if (currentBlock.remaining() == 0) {
      appendBlock();
    }
    final int count = Math.min(pending, currentBlock.remaining());
    // Narrow the source window to exactly what fits into the current block.
    source.limit(source.position() + count);
    currentBlock.put(source);
    pending -= count;
  }
}
/**
 * Returns read-only views of the written data, one per block, each flipped so that it spans the
 * bytes written to that block.
 *
 * @return a new list of read-only buffers; holds only the shared empty buffer when no block has
 *     been allocated
 */
public ArrayList<ByteBuffer> toBufferList() {
  final ArrayList<ByteBuffer> views = new ArrayList<>(Math.max(blocks.size(), 1));
  if (blocks.isEmpty()) {
    views.add(EMPTY);
    return views;
  }
  for (ByteBuffer block : blocks) {
    final ByteBuffer view = block.asReadOnlyBuffer();
    view.flip();
    views.add(view);
  }
  return views;
}
/**
 * Returns writable duplicates of the blocks, each flipped so that it spans the bytes written to
 * that block. The duplicates share content with the internal blocks.
 *
 * @return a new list of duplicated buffers; holds only the shared empty buffer when no block has
 *     been allocated
 */
public ArrayList<ByteBuffer> toWriteableBufferList() {
  final ArrayList<ByteBuffer> views = new ArrayList<>(Math.max(blocks.size(), 1));
  if (blocks.isEmpty()) {
    views.add(EMPTY);
    return views;
  }
  for (ByteBuffer block : blocks) {
    final ByteBuffer view = block.duplicate();
    view.flip();
    views.add(view);
  }
  return views;
}
/**
 * Wraps the read-only views from {@link #toBufferList()} in a {@code ByteBuffersDataInput}.
 *
 * @return a data input over the bytes written so far
 */
public ByteBuffersDataInput toDataInput() {
  final ArrayList<ByteBuffer> readOnlyViews = toBufferList();
  return new ByteBuffersDataInput(readOnlyViews);
}
/**
 * Copies everything written so far into a single, newly allocated array.
 *
 * @return the written bytes; the shared empty array when no block has been allocated
 * @throws RuntimeException if the content does not fit into one byte array
 */
public byte[] toArrayCopy() {
  if (blocks.isEmpty()) {
    return EMPTY_BYTE_ARRAY;
  }

  final long total = size();
  if (total > Integer.MAX_VALUE) {
    throw new RuntimeException("Data exceeds maximum size of a single byte array: " + total);
  }

  final byte[] copy = new byte[Math.toIntExact(total)];
  int filled = 0;
  for (ByteBuffer view : toBufferList()) {
    final int count = view.remaining();
    view.get(copy, filled, count);
    filled += count;
  }
  return copy;
}
/**
 * Copies everything written so far to another {@link DataOutput}.
 *
 * <p>Blocks backed by an accessible array are written with one bulk
 * {@code writeBytes} call each. The blocks are inspected directly rather than through
 * {@link #toBufferList()}: that method returns read-only views, and
 * {@link ByteBuffer#hasArray()} is always {@code false} for a read-only buffer, which would
 * make the bulk path unreachable. Blocks without an accessible array (e.g. direct buffers) are
 * copied through a read-only, flipped view wrapped in a {@code ByteBuffersDataInput}.
 *
 * @param output destination of the copy
 * @throws IOException if writing to {@code output} fails
 */
public void copyTo(DataOutput output) throws IOException {
  for (ByteBuffer block : blocks) {
    if (block.hasArray()) {
      // Written bytes occupy [0, position) of the block; arrayOffset() maps index 0 into the array.
      output.writeBytes(block.array(), block.arrayOffset(), block.position());
    } else {
      final ByteBuffer view = block.asReadOnlyBuffer();
      view.flip();
      output.copyBytes(new ByteBuffersDataInput(Arrays.asList(view)), view.remaining());
    }
  }
}
/**
 * Returns the number of bytes written so far. Every block except the last one is counted at
 * the current block size; the last block contributes its write position.
 *
 * @return total number of bytes written, or 0 when no block has been allocated
 */
public long size() {
  final int blockCount = blocks.size();
  if (blockCount < 1) {
    return 0L;
  }
  final long precedingBlocks = blockCount - 1L;
  final long bytesInLastBlock = blocks.getLast().position();
  return precedingBlocks * blockSize() + bytesInLastBlock;
}
/**
 * Describes this output by its byte count, current block size and number of blocks.
 *
 * @return a human-readable summary formatted with {@link Locale#ROOT}
 */
@Override
public String toString() {
  final long byteCount = size();
  final int bytesPerBlock = blockSize();
  final int blockCount = blocks.size();
  return String.format(
      Locale.ROOT,
      "%,d bytes, block size: %,d, blocks: %,d",
      byteCount,
      bytesPerBlock,
      blockCount);
}
/**
 * Writes a short. When the current block has room for both bytes the value is stored with a
 * single {@code putShort}; otherwise the superclass implementation is used.
 *
 * @param v the value to write
 * @throws UncheckedIOException wrapping any {@link IOException} raised by the superclass path
 */
@Override
public void writeShort(short v) {
  if (currentBlock.remaining() >= Short.BYTES) {
    currentBlock.putShort(v);
    return;
  }
  try {
    super.writeShort(v);
  } catch (IOException cause) {
    throw new UncheckedIOException(cause);
  }
}
/**
 * Writes an int. When the current block has room for all four bytes the value is stored with a
 * single {@code putInt}; otherwise the superclass implementation is used.
 *
 * @param v the value to write
 * @throws UncheckedIOException wrapping any {@link IOException} raised by the superclass path
 */
@Override
public void writeInt(int v) {
  if (currentBlock.remaining() >= Integer.BYTES) {
    currentBlock.putInt(v);
    return;
  }
  try {
    super.writeInt(v);
  } catch (IOException cause) {
    throw new UncheckedIOException(cause);
  }
}
/**
 * Writes a long. When the current block has room for all eight bytes the value is stored with
 * a single {@code putLong}; otherwise the superclass implementation is used.
 *
 * @param v the value to write
 * @throws UncheckedIOException wrapping any {@link IOException} raised by the superclass path
 */
@Override
public void writeLong(long v) {
  if (currentBlock.remaining() >= Long.BYTES) {
    currentBlock.putLong(v);
    return;
  }
  try {
    super.writeLong(v);
  } catch (IOException cause) {
    throw new UncheckedIOException(cause);
  }
}
/**
 * Writes a string as a variable-length byte count followed by its UTF-8 bytes.
 *
 * <p>Strings of up to 1024 chars are encoded in one go through a {@code BytesRef}. Longer
 * strings are encoded through a fixed-size scratch array that is flushed into this output
 * whenever it fills up, so no array proportional to the string length is allocated here.
 *
 * @param v the string to write
 * @throws UncheckedIOException wrapping any {@link IOException} raised while writing
 */
@Override
public void writeString(String v) {
  final int maxCharsPerWindow = 1024;
  try {
    final int charCount = v.length();
    if (charCount > maxCharsPerWindow) {
      // The byte length must be written up front, so it is computed before encoding.
      writeVInt(UnicodeUtil.calcUTF16toUTF8Length(v, 0, charCount));
      final byte[] window = new byte[UnicodeUtil.MAX_UTF8_BYTES_PER_CHAR * maxCharsPerWindow];
      UTF16toUTF8(v, 0, charCount, window, flushed -> writeBytes(window, 0, flushed));
      return;
    }
    final BytesRef encoded = new BytesRef(v);
    writeVInt(encoded.length);
    writeBytes(encoded.bytes, encoded.offset, encoded.length);
  } catch (IOException cause) {
    throw new UncheckedIOException(cause);
  }
}
/**
 * Delegates to the superclass implementation, converting its checked {@link IOException} into
 * an {@link UncheckedIOException}.
 *
 * @param map the map to write
 * @throws UncheckedIOException wrapping any {@link IOException} raised by the superclass
 */
@Override
public void writeMapOfStrings(Map<String, String> map) {
  try {
    super.writeMapOfStrings(map);
  } catch (IOException cause) {
    final UncheckedIOException unchecked = new UncheckedIOException(cause);
    throw unchecked;
  }
}
/**
 * Delegates to the superclass implementation, converting its checked {@link IOException} into
 * an {@link UncheckedIOException}.
 *
 * @param set the set to write
 * @throws UncheckedIOException wrapping any {@link IOException} raised by the superclass
 */
@Override
public void writeSetOfStrings(Set<String> set) {
  try {
    super.writeSetOfStrings(set);
  } catch (IOException cause) {
    final UncheckedIOException unchecked = new UncheckedIOException(cause);
    throw unchecked;
  }
}
/**
 * Estimates memory use as one object reference per block plus the capacity of every block.
 *
 * @return the estimated number of bytes held by the blocks
 */
@Override
public long ramBytesUsed() {
  long total = RamUsageEstimator.NUM_BYTES_OBJECT_REF * blocks.size();
  for (ByteBuffer block : blocks) {
    total += block.capacity();
  }
  return total;
}
/**
 * Discards all written data. Every block is passed to the reuse callback, in write order, and
 * the output returns to its initial empty state. If the callback throws, the exception
 * propagates and the blocks are not cleared.
 */
public void reset() {
  for (ByteBuffer block : blocks) {
    blockReuse.accept(block);
  }
  blocks.clear();
  currentBlock = EMPTY;
}
/**
 * Creates an output that supports {@link #reset()}: its blocks come from a private
 * {@link ByteBufferRecycler} backed by heap allocation, and released blocks go back to that
 * recycler.
 *
 * @return a new output with the default block-size range and block recycling enabled
 */
public static ByteBuffersDataOutput newResettableInstance() {
  final ByteBufferRecycler recycler = new ByteBufferRecycler(ALLOCATE_BB_ON_HEAP);
  final IntFunction<ByteBuffer> allocate = recycler::allocate;
  final Consumer<ByteBuffer> release = recycler::reuse;
  return new ByteBuffersDataOutput(
      DEFAULT_MIN_BITS_PER_BLOCK,
      DEFAULT_MAX_BITS_PER_BLOCK,
      allocate,
      release);
}
/**
 * @return the current block size in bytes, i.e. two raised to the current block-size exponent
 */
private int blockSize() {
  final int bits = blockBits;
  return 1 << bits;
}
/**
 * Makes room for more data: either rewrites the existing data into larger blocks or allocates
 * one more block of the current block size, and points {@code currentBlock} at the block that
 * should receive the next write.
 *
 * <p>Once the block count reaches {@link #MAX_BLOCKS_BEFORE_BLOCK_EXPANSION} and the block-size
 * exponent is still below its maximum, the data is rewritten into blocks of twice the size. The
 * rewrite replaces every block instance, and the old ones may already have been handed to the
 * reuse callback. {@code currentBlock} is therefore re-pointed at the new last block before it
 * is checked for free space, so the early return never leaves a stale, released buffer as the
 * write target.
 */
private void appendBlock() {
  if (blocks.size() >= MAX_BLOCKS_BEFORE_BLOCK_EXPANSION && blockBits < maxBitsPerBlock) {
    rewriteToBlockSize(blockBits + 1);
    // The previous currentBlock is no longer part of 'blocks'; switch to the rewritten tail.
    currentBlock = blocks.getLast();
    if (currentBlock.hasRemaining()) {
      return;
    }
  }

  final int requiredBlockSize = 1 << blockBits;
  currentBlock = blockAllocate.apply(requiredBlockSize);
  // The allocator contract: the returned buffer's capacity equals the requested size.
  assert currentBlock.capacity() == requiredBlockSize;
  blocks.add(currentBlock);
}
/**
 * Copies all data into blocks of {@code 1 << targetBlockBits} bytes and adopts those blocks.
 *
 * <p>Old blocks are removed from the deque one at a time while they are copied, and each is
 * passed to the reuse callback unless that callback is {@link #NO_REUSE}. The new blocks come
 * from this instance's allocator. {@code currentBlock} is not updated by this method.
 *
 * @param targetBlockBits new block-size exponent; must not exceed the configured maximum
 */
private void rewriteToBlockSize(int targetBlockBits) {
  assert targetBlockBits <= maxBitsPerBlock;

  // A temporary output with a fixed block size does the re-chunking.
  final ByteBuffersDataOutput rewritten =
      new ByteBuffersDataOutput(targetBlockBits, targetBlockBits, blockAllocate, NO_REUSE);
  final boolean recycleOldBlocks = blockReuse != NO_REUSE;

  while (!blocks.isEmpty()) {
    final ByteBuffer old = blocks.removeFirst();
    old.flip();
    rewritten.writeBytes(old);
    if (recycleOldBlocks) {
      blockReuse.accept(old);
    }
  }

  assert blocks.isEmpty();
  blockBits = targetBlockBits;
  blocks.addAll(rewritten.blocks);
}
/**
 * Derives an initial block-size exponent from an expected byte count: the expected size is
 * divided by {@link #MAX_BLOCKS_BEFORE_BLOCK_EXPANSION}, passed through
 * {@code BitUtil.nextHighestPowerOfTwo}, and the resulting exponent is clamped to the default
 * minimum/maximum range.
 *
 * @param bytes expected total number of bytes
 * @return a block-size exponent between {@link #DEFAULT_MIN_BITS_PER_BLOCK} and
 *     {@link #DEFAULT_MAX_BITS_PER_BLOCK}
 */
private static int computeBlockSizeBitsFor(long bytes) {
  final long bytesPerBlock = bytes / MAX_BLOCKS_BEFORE_BLOCK_EXPANSION;
  final long rounded = BitUtil.nextHighestPowerOfTwo(bytesPerBlock);
  if (rounded == 0) {
    return DEFAULT_MIN_BITS_PER_BLOCK;
  }

  final int exponent = Long.numberOfTrailingZeros(rounded);
  final int capped = Math.min(exponent, DEFAULT_MAX_BITS_PER_BLOCK);
  return Math.max(capped, DEFAULT_MIN_BITS_PER_BLOCK);
}
/** Shift distance applied to the high-surrogate start value when deriving {@code SURROGATE_OFFSET}. */
private static final long HALF_SHIFT = 10;
/**
 * Constant added to {@code (high << 10) + low} in {@code UTF16toUTF8} to turn a surrogate pair
 * into its supplementary code point.
 */
private static final int SURROGATE_OFFSET =
Character.MIN_SUPPLEMENTARY_CODE_POINT -
(UnicodeUtil.UNI_SUR_HIGH_START << HALF_SHIFT) - UnicodeUtil.UNI_SUR_LOW_START;
/**
 * Encodes a range of UTF-16 chars as UTF-8 through a reusable scratch array.
 *
 * <p>Encoded bytes are collected in {@code buf}. Before each char is encoded, the buffer is
 * handed to {@code bufferFlusher} (with the number of valid bytes) if fewer than five free
 * bytes would otherwise remain, and filling restarts at index 0. A final flush is issued after
 * the last char, even when no bytes are pending. A valid surrogate pair becomes one four-byte
 * sequence; an unpaired surrogate is replaced by the bytes {@code EF BF BD} (U+FFFD).
 *
 * @param s source characters
 * @param offset index of the first char to encode
 * @param length number of chars to encode
 * @param buf scratch array that receives encoded bytes between flushes
 * @param bufferFlusher receives the count of valid bytes in {@code buf} on every flush
 * @return total number of UTF-8 bytes produced
 */
private static int UTF16toUTF8(final CharSequence s,
                               final int offset,
                               final int length,
                               byte[] buf,
                               IntConsumer bufferFlusher) {
  final int end = offset + length;
  int flushedBytes = 0;
  int pos = 0;
  int i = offset;
  while (i < end) {
    final int code = (int) s.charAt(i);

    // Keep headroom for the longest (four-byte) sequence before encoding the next char.
    if (pos + 4 >= buf.length) {
      bufferFlusher.accept(pos);
      flushedBytes += pos;
      pos = 0;
    }

    if (code < 0x80) {
      buf[pos++] = (byte) code;
    } else if (code < 0x800) {
      buf[pos++] = (byte) (0xC0 | (code >> 6));
      buf[pos++] = (byte) (0x80 | (code & 0x3F));
    } else if (code < 0xD800 || code > 0xDFFF) {
      buf[pos++] = (byte) (0xE0 | (code >> 12));
      buf[pos++] = (byte) (0x80 | ((code >> 6) & 0x3F));
      buf[pos++] = (byte) (0x80 | (code & 0x3F));
    } else {
      // Surrogate range: a high surrogate followed by a low surrogate forms one code point.
      boolean paired = false;
      if (code < 0xDC00 && (i < end - 1)) {
        final int low = (int) s.charAt(i + 1);
        if (low >= 0xDC00 && low <= 0xDFFF) {
          final int codePoint = (code << 10) + low + SURROGATE_OFFSET;
          i++; // the low surrogate has been consumed as well
          buf[pos++] = (byte) (0xF0 | (codePoint >> 18));
          buf[pos++] = (byte) (0x80 | ((codePoint >> 12) & 0x3F));
          buf[pos++] = (byte) (0x80 | ((codePoint >> 6) & 0x3F));
          buf[pos++] = (byte) (0x80 | (codePoint & 0x3F));
          paired = true;
        }
      }
      if (!paired) {
        // Unpaired surrogate: emit the replacement character instead.
        buf[pos++] = (byte) 0xEF;
        buf[pos++] = (byte) 0xBF;
        buf[pos++] = (byte) 0xBD;
      }
    }
    i++;
  }

  bufferFlusher.accept(pos);
  return flushedBytes + pos;
}
}