package com.datastax.oss.driver.internal.core.protocol;
import com.datastax.oss.driver.api.core.context.DriverContext;
import com.datastax.oss.driver.internal.core.util.DependencyCheck;
import com.datastax.oss.driver.shaded.guava.common.annotations.VisibleForTesting;
import io.netty.buffer.ByteBuf;
import java.nio.ByteBuffer;
import net.jcip.annotations.ThreadSafe;
import net.jpountz.lz4.LZ4Compressor;
import net.jpountz.lz4.LZ4Factory;
import net.jpountz.lz4.LZ4FastDecompressor;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
 * {@link ByteBufCompressor} backed by the lz4-java library ({@code net.jpountz.lz4}).
 *
 * <p>Uses the fast compressor and fast decompressor of {@link LZ4Factory#fastestInstance()}. The
 * fast decompressor must be told the exact uncompressed length, which callers pass in explicitly.
 *
 * <p>Every compress/decompress method consumes all readable bytes of its input (the reader index is
 * moved to the writer index) and returns a newly allocated output buffer. If the operation fails,
 * that output buffer is released before the failure is propagated.
 */
@ThreadSafe
public class Lz4Compressor extends ByteBufCompressor {

  private static final Logger LOG = LoggerFactory.getLogger(Lz4Compressor.class);

  /** Size in bytes of the optional uncompressed-length prefix (a single {@code int}). */
  private static final int LENGTH_PREFIX_SIZE = 4;

  private final LZ4Compressor compressor;
  private final LZ4FastDecompressor decompressor;

  /**
   * Creates a compressor for the session owning the given context.
   *
   * @param context the driver context; only its session name is used (for logging).
   * @throws IllegalStateException if the LZ4 library is not on the classpath.
   */
  public Lz4Compressor(DriverContext context) {
    this(context.getSessionName());
  }

  /**
   * Creates a compressor, logging the selected LZ4 implementation under the given session name.
   *
   * @param sessionName used as the log prefix.
   * @throws IllegalStateException if the LZ4 library is not on the classpath (it is an optional
   *     dependency of the driver).
   */
  @VisibleForTesting
  Lz4Compressor(String sessionName) {
    if (!DependencyCheck.LZ4.isPresent()) {
      throw new IllegalStateException(
          "Could not find the LZ4 library on the classpath "
              + "(the driver declares it as an optional dependency, "
              + "so you need to declare it explicitly)");
    }
    LZ4Factory lz4Factory = LZ4Factory.fastestInstance();
    // SLF4J formats the argument lazily; no explicit toString() needed.
    LOG.info("[{}] Using {}", sessionName, lz4Factory);
    this.compressor = lz4Factory.fastCompressor();
    this.decompressor = lz4Factory.fastDecompressor();
  }

  /** Returns the algorithm name, {@code "lz4"}. */
  @Override
  public String algorithm() {
    return "lz4";
  }

  /**
   * Compresses all readable bytes of a direct input buffer into a new direct buffer.
   *
   * @param input the buffer to compress; its reader index is moved to its writer index.
   * @param prependWithUncompressedLength if true, the output starts with the uncompressed length
   *     written as an {@code int}.
   * @return a new direct buffer containing the (optionally length-prefixed) compressed bytes.
   */
  @Override
  protected ByteBuf compressDirect(ByteBuf input, boolean prependWithUncompressedLength) {
    int maxCompressedLength = compressor.maxCompressedLength(input.readableBytes());
    ByteBuf output =
        input
            .alloc()
            .directBuffer(
                (prependWithUncompressedLength ? LENGTH_PREFIX_SIZE : 0) + maxCompressedLength);
    try {
      ByteBuffer in = inputNioBuffer(input);
      // Mark the input as fully consumed; `in` was captured first, so it still spans the bytes.
      input.readerIndex(input.writerIndex());
      if (prependWithUncompressedLength) {
        output.writeInt(in.remaining());
      }
      ByteBuffer out = outputNioBuffer(output);
      int written =
          compressor.compress(
              in, in.position(), in.remaining(), out, out.position(), out.remaining());
      // Writing through the NIO view does not move the ByteBuf's writer index.
      output.writerIndex(output.writerIndex() + written);
    } catch (Throwable t) {
      // Release on any failure (including Errors) so the new buffer is not leaked; precise
      // rethrow keeps the thrown types unchecked.
      output.release();
      throw t;
    }
    return output;
  }

  /**
   * Compresses all readable bytes of an array-backed input buffer into a new heap buffer.
   *
   * @param input the buffer to compress (its backing array is accessed directly — presumably the
   *     caller only routes array-backed buffers here); its reader index is moved to its writer
   *     index.
   * @param prependWithUncompressedLength if true, the output starts with the uncompressed length
   *     written as an {@code int}.
   * @return a new heap buffer containing the (optionally length-prefixed) compressed bytes.
   */
  @Override
  protected ByteBuf compressHeap(ByteBuf input, boolean prependWithUncompressedLength) {
    int maxCompressedLength = compressor.maxCompressedLength(input.readableBytes());
    int inOffset = input.arrayOffset() + input.readerIndex();
    byte[] in = input.array();
    int len = input.readableBytes();
    input.readerIndex(input.writerIndex());
    ByteBuf output =
        input
            .alloc()
            .heapBuffer(
                (prependWithUncompressedLength ? LENGTH_PREFIX_SIZE : 0) + maxCompressedLength);
    try {
      if (prependWithUncompressedLength) {
        output.writeInt(len);
      }
      int offset = output.arrayOffset() + output.writerIndex();
      byte[] out = output.array();
      int written = compressor.compress(in, inOffset, len, out, offset);
      // Writing into the backing array does not move the ByteBuf's writer index.
      output.writerIndex(output.writerIndex() + written);
    } catch (Throwable t) {
      output.release();
      throw t;
    }
    return output;
  }

  /**
   * Reads the {@code int} length prefix written by the compress methods when {@code
   * prependWithUncompressedLength} is true, advancing the reader index past it.
   */
  @Override
  protected int readUncompressedLength(ByteBuf compressed) {
    return compressed.readInt();
  }

  /**
   * Decompresses all readable bytes of a direct input buffer into a new direct buffer.
   *
   * @param input the compressed bytes (without length prefix); its reader index is moved to its
   *     writer index.
   * @param uncompressedLength the exact size of the decompressed data.
   * @return a new direct buffer holding {@code uncompressedLength} decompressed bytes.
   * @throws IllegalArgumentException if the decompressor did not consume exactly the readable
   *     bytes of {@code input}.
   */
  @Override
  protected ByteBuf decompressDirect(ByteBuf input, int uncompressedLength) {
    int readable = input.readableBytes();
    ByteBuffer in = inputNioBuffer(input);
    input.readerIndex(input.writerIndex());
    ByteBuf output = input.alloc().directBuffer(uncompressedLength);
    try {
      ByteBuffer out = outputNioBuffer(output);
      // The fast decompressor needs the exact decoded size: pass the declared length (as
      // decompressHeap does), not whatever room the output view happens to have.
      int read =
          decompressor.decompress(in, in.position(), out, out.position(), uncompressedLength);
      if (read != readable) {
        throw new IllegalArgumentException("Compressed lengths mismatch");
      }
      output.writerIndex(output.writerIndex() + uncompressedLength);
    } catch (Throwable t) {
      output.release();
      throw t;
    }
    return output;
  }

  /**
   * Decompresses all readable bytes of an array-backed input buffer into a new heap buffer.
   *
   * @param input the compressed bytes (without length prefix), accessed via its backing array;
   *     its reader index is moved to its writer index.
   * @param uncompressedLength the exact size of the decompressed data.
   * @return a new heap buffer holding {@code uncompressedLength} decompressed bytes.
   * @throws IllegalArgumentException if the decompressor did not consume exactly the readable
   *     bytes of {@code input}.
   */
  @Override
  protected ByteBuf decompressHeap(ByteBuf input, int uncompressedLength) {
    byte[] in = input.array();
    int len = input.readableBytes();
    int inOffset = input.arrayOffset() + input.readerIndex();
    input.readerIndex(input.writerIndex());
    ByteBuf output = input.alloc().heapBuffer(uncompressedLength);
    try {
      int offset = output.arrayOffset() + output.writerIndex();
      byte[] out = output.array();
      int read = decompressor.decompress(in, inOffset, out, offset, uncompressedLength);
      if (read != len) {
        throw new IllegalArgumentException("Compressed lengths mismatch");
      }
      output.writerIndex(output.writerIndex() + uncompressedLength);
    } catch (Throwable t) {
      output.release();
      throw t;
    }
    return output;
  }
}