package org.apache.cassandra.io.util;
import java.io.Closeable;
import java.io.File;
import java.io.FilterInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.RandomAccessFile;
import java.util.concurrent.atomic.AtomicBoolean;
import static org.apache.cassandra.utils.Throwables.maybeFail;
import static org.apache.cassandra.utils.Throwables.merge;
/**
 * A {@link FilterInputStream} that can be rewound to a previously marked position.
 * <p>
 * After {@link #mark()} (or {@link #mark(int)}), every byte read from the underlying stream is also recorded. Bytes go
 * first into an in-memory buffer that grows from {@code initialMemBufferSize} up to {@code maxMemBufferSize} bytes.
 * Once that is full, they go into a spill file of at most {@code maxDiskBufferSize} bytes, which is written as a
 * circular buffer.
 * <p>
 * A subsequent {@link #reset()} replays the recorded bytes (memory first, then disk) before reading continues from
 * the underlying stream. If more bytes are read after the mark than both buffers can hold, the stream is flagged as
 * exhausted and {@link #reset()} fails with an {@link IOException}.
 * <p>
 * A new mark may only be placed once all previously replayed bytes have been consumed. The spill file is deleted
 * when the stream is closed.
 */
public class RewindableDataInputStreamPlus extends FilterInputStream implements RewindableDataInput, Closeable
{
    /** Whether a mark is active, i.e. bytes read are currently being recorded. */
    private boolean marked = false;
    /** Set once the bytes read since the mark overflowed both buffers; makes {@link #reset()} fail. */
    private boolean exhausted = false;
    private final AtomicBoolean closed = new AtomicBoolean(false);

    // While marked: remaining free recording capacity of each region. On the disk fields, -1 means
    // "spill file not used since the mark yet".
    // While not marked: number of recorded bytes still to be replayed from each region.
    // The disk is used circularly: the "tail" region runs from the file position at which recording started
    // to maxDiskBufferSize, and the "head" region wraps around from offset 0.
    protected int memAvailable = 0;
    protected int diskTailAvailable = 0;
    protected int diskHeadAvailable = 0;

    private final File spillFile;
    private final int initialMemBufferSize;
    private final int maxMemBufferSize;
    private final int maxDiskBufferSize;

    private volatile byte memBuffer[];
    /** Number of valid bytes in {@link #memBuffer} being replayed after the last {@link #reset()}. */
    private int memBufferSize;
    private RandomAccessFile spillBuffer;

    /** Decodes DataInput primitives by reading through this stream, so they are recorded while marked. */
    private final DataInputPlus dataReader;

    /**
     * @param in                   the underlying stream
     * @param initialMemBufferSize initial size of the in-memory recording buffer
     * @param maxMemBufferSize     maximum number of bytes recorded in memory after a mark
     * @param spillFile            file used to record bytes that no longer fit in memory
     * @param maxDiskBufferSize    maximum number of bytes recorded in the spill file after a mark
     */
    public RewindableDataInputStreamPlus(InputStream in, int initialMemBufferSize, int maxMemBufferSize,
                                         File spillFile, int maxDiskBufferSize)
    {
        super(in);
        dataReader = new DataInputStreamPlus(this);
        this.initialMemBufferSize = initialMemBufferSize;
        this.maxMemBufferSize = maxMemBufferSize;
        this.spillFile = spillFile;
        this.maxDiskBufferSize = maxDiskBufferSize;
    }

    /* RewindableDataInput methods */

    /**
     * Marks the current position; equivalent to {@link #mark(int)}.
     *
     * @return an opaque mark object; its identity is not used by {@link #reset(DataPosition)}
     */
    public DataPosition mark()
    {
        mark(0);
        return new RewindableDataInputPlusMark();
    }

    /** Rewinds to the last mark. The argument is ignored: only a single outstanding mark is supported. */
    public void reset(DataPosition mark) throws IOException
    {
        reset();
    }

    /**
     * Returns the number of bytes recorded since the last mark. The value is derived from the recording counters,
     * so it is only meaningful while a mark is active. The argument is ignored.
     */
    public long bytesPastMark(DataPosition mark)
    {
        return maxMemBufferSize - memAvailable + (diskTailAvailable == -1? 0 : maxDiskBufferSize - diskHeadAvailable - diskTailAvailable);
    }

    /** Opaque marker returned by {@link #mark()}. */
    protected static class RewindableDataInputPlusMark implements DataPosition
    {
    }

    /* InputStream methods */

    @Override
    public boolean markSupported()
    {
        return true;
    }

    /**
     * Starts recording bytes read from now on so that they can be replayed by {@link #reset()}.
     * The {@code readlimit} argument is ignored; capacity is bounded by the constructor's buffer sizes.
     *
     * @throws IllegalStateException if a mark is already active, or if replayed bytes from a previous
     *                               reset have not all been consumed yet
     */
    @Override
    public synchronized void mark(int readlimit)
    {
        if (marked)
            throw new IllegalStateException("Cannot mark already marked stream.");

        if (memAvailable > 0 || diskHeadAvailable > 0 || diskTailAvailable > 0)
            throw new IllegalStateException("Can only mark stream after reading previously marked data.");

        marked = true;
        memAvailable = maxMemBufferSize;
        diskHeadAvailable = -1;
        diskTailAvailable = -1;
    }

    /**
     * Rewinds to the last mark. The recording counters are converted into replay counters and the spill file is
     * positioned at the first recorded disk byte.
     *
     * @throws IOException if no mark is active, or if the recorded data overflowed the buffers
     */
    @Override
    public synchronized void reset() throws IOException
    {
        if (!marked)
            throw new IOException("Must call mark() before calling reset().");

        if (exhausted)
            throw new IOException(String.format("Read more than capacity: %d bytes.", maxMemBufferSize + maxDiskBufferSize));

        memAvailable = maxMemBufferSize - memAvailable;
        memBufferSize = memAvailable;

        if (diskTailAvailable == -1)
        {
            // Nothing was spilled to disk since the mark.
            diskHeadAvailable = 0;
            diskTailAvailable = 0;
        }
        else
        {
            // If the tail still had room we never wrapped, so the recording started at offset diskHeadAvailable.
            // Otherwise the file pointer sits at the end of the head region; adding the unused head space
            // recovers the recording start position.
            int initialPos = diskTailAvailable > 0 ? 0 : (int)getIfNotClosed(spillBuffer).getFilePointer();
            int diskMarkpos = initialPos + diskHeadAvailable;
            getIfNotClosed(spillBuffer).seek(diskMarkpos);

            diskHeadAvailable = diskMarkpos - diskHeadAvailable;
            diskTailAvailable = (maxDiskBufferSize - diskTailAvailable) - diskMarkpos;
        }

        marked = false;
    }

    /** Returns the underlying stream's estimate plus, when not marked, the bytes still to be replayed. */
    @Override
    public int available() throws IOException
    {
        return super.available() + (marked? 0 : memAvailable + diskHeadAvailable + diskTailAvailable);
    }

    /** Reads one byte, recording it if a mark is active and there is still buffer capacity. */
    @Override
    public int read() throws IOException
    {
        int read = readOne();
        if (read == -1)
            return read;

        if (marked)
        {
            // Capacity exceeded: keep serving data, but remember that reset() is no longer possible.
            if (isExhausted(1))
            {
                exhausted = true;
                return read;
            }
            writeOne(read);
        }

        return read;
    }

    /** Reads up to {@code len} bytes, recording them if a mark is active and there is still buffer capacity. */
    @Override
    public int read(byte[] b, int off, int len) throws IOException
    {
        int readBytes = readMulti(b, off, len);
        if (readBytes == -1)
            return readBytes;

        if (marked)
        {
            // Capacity exceeded: keep serving data, but remember that reset() is no longer possible.
            if (isExhausted(readBytes))
            {
                exhausted = true;
                return readBytes;
            }
            writeMulti(b, off, readBytes);
        }

        return readBytes;
    }

    /** Lazily creates the spill file (and its parent directory, if any) and opens it for read/write. */
    private void maybeCreateDiskBuffer() throws IOException
    {
        if (spillBuffer == null)
        {
            File parent = spillFile.getParentFile();
            // A bare file name has no parent directory to create.
            if (parent != null && !parent.exists())
                parent.mkdirs();
            spillFile.createNewFile();

            this.spillBuffer = new RandomAccessFile(spillFile, "rw");
        }
    }

    /**
     * Returns the next byte, serving replayed data (memory, then disk tail, then disk head) before the
     * underlying stream when not marked.
     */
    private int readOne() throws IOException
    {
        if (!marked)
        {
            if (memAvailable > 0)
            {
                int pos = memBufferSize - memAvailable;
                memAvailable--;
                return getIfNotClosed(memBuffer)[pos] & 0xff;
            }

            if (diskTailAvailable > 0)
            {
                int read = getIfNotClosed(spillBuffer).read();
                diskTailAvailable--;
                // The tail region ends at the end of the circular spill file: wrap to offset 0 for the head region.
                // Only do this once, when the tail becomes drained, so head reads advance normally.
                if (diskTailAvailable == 0)
                    spillBuffer.seek(0);
                return read;
            }

            if (diskHeadAvailable > 0)
            {
                int read = getIfNotClosed(spillBuffer).read();
                diskHeadAvailable--;
                return read;
            }
        }

        return getIfNotClosed(in).read();
    }

    /** Whether recording {@code readBytes} more bytes would exceed the remaining memory + disk capacity. */
    private boolean isExhausted(int readBytes)
    {
        return exhausted || readBytes > memAvailable + (long)(diskTailAvailable == -1? maxDiskBufferSize : diskTailAvailable + diskHeadAvailable);
    }

    /**
     * Reads up to {@code len} bytes, serving replayed data (memory, disk tail, disk head) before the underlying
     * stream when not marked.
     *
     * @return the number of bytes read, or -1 if nothing was read because the underlying stream is at EOF
     */
    private int readMulti(byte[] b, int off, int len) throws IOException
    {
        int readBytes = 0;

        if (!marked)
        {
            if (memAvailable > 0)
            {
                readBytes += memAvailable < len ? memAvailable : len;
                int pos = memBufferSize - memAvailable;
                System.arraycopy(getIfNotClosed(memBuffer), pos, b, off, readBytes);
                memAvailable -= readBytes;
                off += readBytes;
                len -= readBytes;
            }
            if (len > 0 && diskTailAvailable > 0)
            {
                int readFromTail = diskTailAvailable < len? diskTailAvailable : len;
                readFromTail = getIfNotClosed(spillBuffer).read(b, off, readFromTail);
                readBytes += readFromTail;
                diskTailAvailable -= readFromTail;
                off += readFromTail;
                len -= readFromTail;
                // Wrap to the start of the circular spill file for the head region.
                if (diskTailAvailable == 0)
                    spillBuffer.seek(0);
            }
            if (len > 0 && diskHeadAvailable > 0)
            {
                int readFromHead = diskHeadAvailable < len? diskHeadAvailable : len;
                readFromHead = getIfNotClosed(spillBuffer).read(b, off, readFromHead);
                readBytes += readFromHead;
                diskHeadAvailable -= readFromHead;
                off += readFromHead;
                len -= readFromHead;
            }
        }

        if (len > 0)
        {
            int fromUnderlying = getIfNotClosed(in).read(b, off, len);
            // EOF: report what was served from the replay buffers; -1 only when nothing at all was read.
            if (fromUnderlying == -1)
                return readBytes == 0 ? -1 : readBytes;
            readBytes += fromUnderlying;
        }

        return readBytes;
    }

    /** Records {@code len} bytes of {@code b}: into memory while it has room, then into the circular spill file. */
    private void writeMulti(byte[] b, int off, int len) throws IOException
    {
        if (memAvailable > 0)
        {
            if (memBuffer == null)
                memBuffer = new byte[initialMemBufferSize];
            int pos = maxMemBufferSize - memAvailable;
            int memWritten = memAvailable < len? memAvailable : len;
            if (pos + memWritten >= getIfNotClosed(memBuffer).length)
                growMemBuffer(pos, memWritten);
            System.arraycopy(b, off, memBuffer, pos, memWritten);
            off += memWritten;
            len -= memWritten;
            memAvailable -= memWritten;
        }

        if (len > 0)
        {
            if (diskTailAvailable == -1)
            {
                // First spill since the mark: the tail runs from the current file position to the end of the
                // disk buffer, and the head (before the current position) is used after wrapping.
                maybeCreateDiskBuffer();
                diskHeadAvailable = (int)spillBuffer.getFilePointer();
                diskTailAvailable = maxDiskBufferSize - diskHeadAvailable;
            }

            if (diskTailAvailable > 0)
            {
                int diskTailWritten = diskTailAvailable < len? diskTailAvailable : len;
                getIfNotClosed(spillBuffer).write(b, off, diskTailWritten);
                off += diskTailWritten;
                len -= diskTailWritten;
                diskTailAvailable -= diskTailWritten;
                if (diskTailAvailable == 0)
                    spillBuffer.seek(0);
            }

            if (len > 0 && diskHeadAvailable > 0)
            {
                int diskHeadWritten = diskHeadAvailable < len? diskHeadAvailable : len;
                getIfNotClosed(spillBuffer).write(b, off, diskHeadWritten);
                diskHeadAvailable -= diskHeadWritten;
            }
        }
    }

    /** Records a single byte: into memory while it has room, then into the circular spill file. */
    private void writeOne(int value) throws IOException
    {
        if (memAvailable > 0)
        {
            if (memBuffer == null)
                memBuffer = new byte[initialMemBufferSize];
            int pos = maxMemBufferSize - memAvailable;
            if (pos == getIfNotClosed(memBuffer).length)
                growMemBuffer(pos, 1);
            getIfNotClosed(memBuffer)[pos] = (byte)value;
            memAvailable--;
            return;
        }

        if (diskTailAvailable == -1)
        {
            // First spill since the mark; see writeMulti.
            maybeCreateDiskBuffer();
            diskHeadAvailable = (int)spillBuffer.getFilePointer();
            diskTailAvailable = maxDiskBufferSize - diskHeadAvailable;
        }

        if (diskTailAvailable > 0)
        {
            getIfNotClosed(spillBuffer).write(value);
            diskTailAvailable--;
            // Wrap once, when the tail fills up, so that subsequent head writes advance normally.
            if (diskTailAvailable == 0)
                spillBuffer.seek(0);
        }
        else if (diskHeadAvailable > 0)
        {
            getIfNotClosed(spillBuffer).write(value);
            diskHeadAvailable--;
        }
    }

    @Override
    public int read(byte[] b) throws IOException
    {
        return read(b, 0, b.length);
    }

    /**
     * Replaces the memory buffer with one large enough for {@code pos + writeSize} bytes (doubled, capped at
     * {@code maxMemBufferSize}), preserving its first {@code pos} bytes.
     */
    private void growMemBuffer(int pos, int writeSize)
    {
        int newSize = Math.min(2 * (pos + writeSize), maxMemBufferSize);
        byte newBuffer[] = new byte[newSize];
        System.arraycopy(memBuffer, 0, newBuffer, 0, pos);
        memBuffer = newBuffer;
    }

    /**
     * Skips up to {@code n} bytes. While marked, bytes are read one at a time so that they are recorded;
     * otherwise replay buffers are skipped first, then the underlying stream.
     */
    @Override
    public long skip(long n) throws IOException
    {
        long skipped = 0;

        if (marked)
        {
            // Skipped bytes must still be recorded for a later reset().
            while (n-- > 0 && read() != -1)
            {
                skipped++;
            }
            return skipped;
        }

        if (memAvailable > 0)
        {
            skipped += memAvailable < n ? memAvailable : n;
            memAvailable -= skipped;
            n -= skipped;
        }
        if (n > 0 && diskTailAvailable > 0)
        {
            int skipFromTail = diskTailAvailable < n? diskTailAvailable : (int)n;
            getIfNotClosed(spillBuffer).skipBytes(skipFromTail);
            diskTailAvailable -= skipFromTail;
            skipped += skipFromTail;
            n -= skipFromTail;
            if (diskTailAvailable == 0)
                spillBuffer.seek(0);
        }
        if (n > 0 && diskHeadAvailable > 0)
        {
            int skipFromHead = diskHeadAvailable < n? diskHeadAvailable : (int)n;
            getIfNotClosed(spillBuffer).skipBytes(skipFromHead);
            diskHeadAvailable -= skipFromHead;
            skipped += skipFromHead;
            n -= skipFromHead;
        }

        if (n > 0)
            skipped += getIfNotClosed(in).skip(n);

        return skipped;
    }

    /** Returns {@code in} unchanged, or throws if this stream has been closed. */
    private <T> T getIfNotClosed(T in) throws IOException
    {
        if (closed.get())
            throw new IOException("Stream closed");
        return in;
    }

    /** Closes this stream and the underlying stream, and deletes the spill file. */
    @Override
    public void close() throws IOException
    {
        close(true);
    }

    /**
     * Closes this stream (at most once), closing the spill file and deleting it. All failures are collected
     * and the merged failure is rethrown at the end.
     *
     * @param closeUnderlying whether to also close the wrapped input stream
     */
    public void close(boolean closeUnderlying) throws IOException
    {
        if (closed.compareAndSet(false, true))
        {
            Throwable fail = null;
            if (closeUnderlying)
            {
                try
                {
                    super.close();
                }
                catch (IOException e)
                {
                    fail = merge(fail, e);
                }
            }
            try
            {
                if (spillBuffer != null)
                {
                    this.spillBuffer.close();
                    this.spillBuffer = null;
                }
            } catch (IOException e)
            {
                fail = merge(fail, e);
            }
            try
            {
                if (spillFile.exists())
                {
                    spillFile.delete();
                }
            }
            catch (Throwable e)
            {
                fail = merge(fail, e);
            }
            maybeFail(fail, IOException.class);
        }
    }

    /* DataInput methods: delegated to dataReader, which reads through this stream. */

    public void readFully(byte[] b) throws IOException
    {
        dataReader.readFully(b);
    }

    public void readFully(byte[] b, int off, int len) throws IOException
    {
        dataReader.readFully(b, off, len);
    }

    public int skipBytes(int n) throws IOException
    {
        return dataReader.skipBytes(n);
    }

    public boolean readBoolean() throws IOException
    {
        return dataReader.readBoolean();
    }

    public byte readByte() throws IOException
    {
        return dataReader.readByte();
    }

    public int readUnsignedByte() throws IOException
    {
        return dataReader.readUnsignedByte();
    }

    public short readShort() throws IOException
    {
        return dataReader.readShort();
    }

    public int readUnsignedShort() throws IOException
    {
        return dataReader.readUnsignedShort();
    }

    public char readChar() throws IOException
    {
        return dataReader.readChar();
    }

    public int readInt() throws IOException
    {
        return dataReader.readInt();
    }

    public long readLong() throws IOException
    {
        return dataReader.readLong();
    }

    public float readFloat() throws IOException
    {
        return dataReader.readFloat();
    }

    public double readDouble() throws IOException
    {
        return dataReader.readDouble();
    }

    public String readLine() throws IOException
    {
        return dataReader.readLine();
    }

    public String readUTF() throws IOException
    {
        return dataReader.readUTF();
    }
}