/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.cassandra.io.util;

import java.io.Closeable;
import java.io.File;
import java.io.FilterInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.RandomAccessFile;
import java.util.concurrent.atomic.AtomicBoolean;

import static org.apache.cassandra.utils.Throwables.maybeFail;
import static org.apache.cassandra.utils.Throwables.merge;

Adds mark/reset functionality to another input stream by caching read bytes to a memory buffer and spilling to disk if necessary. When the stream is marked via mark() or mark(int), up to maxMemBufferSize bytes will be cached in memory (heap). If more than maxMemBufferSize bytes are read while the stream is marked, the following bytes are cached in the spillFile for up to maxDiskBufferSize bytes. Please note that successive calls to mark() and reset() will write sequentially to the same spillFile until maxDiskBufferSize is reached. At this point, if fewer than maxDiskBufferSize bytes are currently cached in the spillFile, the remaining bytes are written to the beginning of the file, treating the spillFile as a circular buffer. If more than maxMemBufferSize + maxDiskBufferSize bytes are read while the stream is marked, the following reset() invocation will throw an IOException.
/** * Adds mark/reset functionality to another input stream by caching read bytes to a memory buffer and * spilling to disk if necessary. * * When the stream is marked via {@link #mark()} or {@link #mark(int)}, up to * <code>maxMemBufferSize</code> will be cached in memory (heap). If more than * <code>maxMemBufferSize</code> bytes are read while the stream is marked, the * following bytes are cached on the <code>spillFile</code> for up to <code>maxDiskBufferSize</code>. * * Please note that successive calls to {@link #mark()} and {@link #reset()} will write * sequentially to the same <code>spillFile</code> until <code>maxDiskBufferSize</code> is reached. * At this point, if less than <code>maxDiskBufferSize</code> bytes are currently cached on the * <code>spillFile</code>, the remaining bytes are written to the beginning of the file, * treating the <code>spillFile</code> as a circular buffer. * * If more than <code>maxMemBufferSize + maxDiskBufferSize</code> are cached while the stream is marked, * the following {@link #reset()} invocation will throw a {@link IllegalStateException}. * */
public class RewindableDataInputStreamPlus extends FilterInputStream implements RewindableDataInput, Closeable { private boolean marked = false; private boolean exhausted = false; private AtomicBoolean closed = new AtomicBoolean(false); protected int memAvailable = 0; protected int diskTailAvailable = 0; protected int diskHeadAvailable = 0; private final File spillFile; private final int initialMemBufferSize; private final int maxMemBufferSize; private final int maxDiskBufferSize; private volatile byte memBuffer[]; private int memBufferSize; private RandomAccessFile spillBuffer; private final DataInputPlus dataReader; public RewindableDataInputStreamPlus(InputStream in, int initialMemBufferSize, int maxMemBufferSize, File spillFile, int maxDiskBufferSize) { super(in); dataReader = new DataInputStreamPlus(this); this.initialMemBufferSize = initialMemBufferSize; this.maxMemBufferSize = maxMemBufferSize; this.spillFile = spillFile; this.maxDiskBufferSize = maxDiskBufferSize; } /* RewindableDataInput methods */
Marks the current position of a stream to return to this position later via the reset(DataPosition) method.
Returns:An empty @link{DataPosition} object
/** * Marks the current position of a stream to return to this position later via the {@link #reset(DataPosition)} method. * @return An empty @link{DataPosition} object */
public DataPosition mark() { mark(0); return new RewindableDataInputPlusMark(); }
Rewinds to the previously marked position via the mark() method.
Params:
  • mark – it's not possible to return to a custom position, so this parameter is ignored.
Throws:
/** * Rewinds to the previously marked position via the {@link #mark()} method. * @param mark it's not possible to return to a custom position, so this parameter is ignored. * @throws IOException if an error ocurs while resetting */
public void reset(DataPosition mark) throws IOException { reset(); } public long bytesPastMark(DataPosition mark) { return maxMemBufferSize - memAvailable + (diskTailAvailable == -1? 0 : maxDiskBufferSize - diskHeadAvailable - diskTailAvailable); } protected static class RewindableDataInputPlusMark implements DataPosition { } /* InputStream methods */ public boolean markSupported() { return true; }
Marks the current position of a stream to return to this position later via the reset() method.
Params:
  • readlimit – the maximum amount of bytes to cache
/** * Marks the current position of a stream to return to this position * later via the {@link #reset()} method. * @param readlimit the maximum amount of bytes to cache */
public synchronized void mark(int readlimit) { if (marked) throw new IllegalStateException("Cannot mark already marked stream."); if (memAvailable > 0 || diskHeadAvailable > 0 || diskTailAvailable > 0) throw new IllegalStateException("Can only mark stream after reading previously marked data."); marked = true; memAvailable = maxMemBufferSize; diskHeadAvailable = -1; diskTailAvailable = -1; } public synchronized void reset() throws IOException { if (!marked) throw new IOException("Must call mark() before calling reset()."); if (exhausted) throw new IOException(String.format("Read more than capacity: %d bytes.", maxMemBufferSize + maxDiskBufferSize)); memAvailable = maxMemBufferSize - memAvailable; memBufferSize = memAvailable; if (diskTailAvailable == -1) { diskHeadAvailable = 0; diskTailAvailable = 0; } else { int initialPos = diskTailAvailable > 0 ? 0 : (int)getIfNotClosed(spillBuffer).getFilePointer(); int diskMarkpos = initialPos + diskHeadAvailable; getIfNotClosed(spillBuffer).seek(diskMarkpos); diskHeadAvailable = diskMarkpos - diskHeadAvailable; diskTailAvailable = (maxDiskBufferSize - diskTailAvailable) - diskMarkpos; } marked = false; } public int available() throws IOException { return super.available() + (marked? 
0 : memAvailable + diskHeadAvailable + diskTailAvailable); } public int read() throws IOException { int read = readOne(); if (read == -1) return read; if (marked) { //mark exhausted if (isExhausted(1)) { exhausted = true; return read; } writeOne(read); } return read; } public int read(byte[] b, int off, int len) throws IOException { int readBytes = readMulti(b, off, len); if (readBytes == -1) return readBytes; if (marked) { //check we have space on buffer if (isExhausted(readBytes)) { exhausted = true; return readBytes; } writeMulti(b, off, readBytes); } return readBytes; } private void maybeCreateDiskBuffer() throws IOException { if (spillBuffer == null) { if (!spillFile.getParentFile().exists()) spillFile.getParentFile().mkdirs(); spillFile.createNewFile(); this.spillBuffer = new RandomAccessFile(spillFile, "rw"); } } private int readOne() throws IOException { if (!marked) { if (memAvailable > 0) { int pos = memBufferSize - memAvailable; memAvailable--; return getIfNotClosed(memBuffer)[pos] & 0xff; } if (diskTailAvailable > 0 || diskHeadAvailable > 0) { int read = getIfNotClosed(spillBuffer).read(); if (diskTailAvailable > 0) diskTailAvailable--; else if (diskHeadAvailable > 0) diskHeadAvailable++; if (diskTailAvailable == 0) spillBuffer.seek(0); return read; } } return getIfNotClosed(in).read(); } private boolean isExhausted(int readBytes) { return exhausted || readBytes > memAvailable + (long)(diskTailAvailable == -1? maxDiskBufferSize : diskTailAvailable + diskHeadAvailable); } private int readMulti(byte[] b, int off, int len) throws IOException { int readBytes = 0; if (!marked) { if (memAvailable > 0) { readBytes += memAvailable < len ? memAvailable : len; int pos = memBufferSize - memAvailable; System.arraycopy(memBuffer, pos, b, off, readBytes); memAvailable -= readBytes; off += readBytes; len -= readBytes; } if (len > 0 && diskTailAvailable > 0) { int readFromTail = diskTailAvailable < len? 
diskTailAvailable : len; readFromTail = getIfNotClosed(spillBuffer).read(b, off, readFromTail); readBytes += readFromTail; diskTailAvailable -= readFromTail; off += readFromTail; len -= readFromTail; if (diskTailAvailable == 0) spillBuffer.seek(0); } if (len > 0 && diskHeadAvailable > 0) { int readFromHead = diskHeadAvailable < len? diskHeadAvailable : len; readFromHead = getIfNotClosed(spillBuffer).read(b, off, readFromHead); readBytes += readFromHead; diskHeadAvailable -= readFromHead; off += readFromHead; len -= readFromHead; } } if (len > 0) readBytes += getIfNotClosed(in).read(b, off, len); return readBytes; } private void writeMulti(byte[] b, int off, int len) throws IOException { if (memAvailable > 0) { if (memBuffer == null) memBuffer = new byte[initialMemBufferSize]; int pos = maxMemBufferSize - memAvailable; int memWritten = memAvailable < len? memAvailable : len; if (pos + memWritten >= getIfNotClosed(memBuffer).length) growMemBuffer(pos, memWritten); System.arraycopy(b, off, memBuffer, pos, memWritten); off += memWritten; len -= memWritten; memAvailable -= memWritten; } if (len > 0) { if (diskTailAvailable == -1) { maybeCreateDiskBuffer(); diskHeadAvailable = (int)spillBuffer.getFilePointer(); diskTailAvailable = maxDiskBufferSize - diskHeadAvailable; } if (len > 0 && diskTailAvailable > 0) { int diskTailWritten = diskTailAvailable < len? diskTailAvailable : len; getIfNotClosed(spillBuffer).write(b, off, diskTailWritten); off += diskTailWritten; len -= diskTailWritten; diskTailAvailable -= diskTailWritten; if (diskTailAvailable == 0) spillBuffer.seek(0); } if (len > 0 && diskTailAvailable > 0) { int diskHeadWritten = diskHeadAvailable < len? 
diskHeadAvailable : len; getIfNotClosed(spillBuffer).write(b, off, diskHeadWritten); } } } private void writeOne(int value) throws IOException { if (memAvailable > 0) { if (memBuffer == null) memBuffer = new byte[initialMemBufferSize]; int pos = maxMemBufferSize - memAvailable; if (pos == getIfNotClosed(memBuffer).length) growMemBuffer(pos, 1); getIfNotClosed(memBuffer)[pos] = (byte)value; memAvailable--; return; } if (diskTailAvailable == -1) { maybeCreateDiskBuffer(); diskHeadAvailable = (int)spillBuffer.getFilePointer(); diskTailAvailable = maxDiskBufferSize - diskHeadAvailable; } if (diskTailAvailable > 0 || diskHeadAvailable > 0) { getIfNotClosed(spillBuffer).write(value); if (diskTailAvailable > 0) diskTailAvailable--; else if (diskHeadAvailable > 0) diskHeadAvailable--; if (diskTailAvailable == 0) spillBuffer.seek(0); return; } } public int read(byte[] b) throws IOException { return read(b, 0, b.length); } private void growMemBuffer(int pos, int writeSize) { int newSize = Math.min(2 * (pos + writeSize), maxMemBufferSize); byte newBuffer[] = new byte[newSize]; System.arraycopy(memBuffer, 0, newBuffer, 0, pos); memBuffer = newBuffer; } public long skip(long n) throws IOException { long skipped = 0; if (marked) { //if marked, we need to cache skipped bytes while (n-- > 0 && read() != -1) { skipped++; } return skipped; } if (memAvailable > 0) { skipped += memAvailable < n ? memAvailable : n; memAvailable -= skipped; n -= skipped; } if (n > 0 && diskTailAvailable > 0) { int skipFromTail = diskTailAvailable < n? diskTailAvailable : (int)n; getIfNotClosed(spillBuffer).skipBytes(skipFromTail); diskTailAvailable -= skipFromTail; skipped += skipFromTail; n -= skipFromTail; if (diskTailAvailable == 0) spillBuffer.seek(0); } if (n > 0 && diskHeadAvailable > 0) { int skipFromHead = diskHeadAvailable < n? 
diskHeadAvailable : (int)n; getIfNotClosed(spillBuffer).skipBytes(skipFromHead); diskHeadAvailable -= skipFromHead; skipped += skipFromHead; n -= skipFromHead; } if (n > 0) skipped += getIfNotClosed(in).skip(n); return skipped; } private <T> T getIfNotClosed(T in) throws IOException { if (closed.get()) throw new IOException("Stream closed"); return in; } public void close() throws IOException { close(true); } public void close(boolean closeUnderlying) throws IOException { if (closed.compareAndSet(false, true)) { Throwable fail = null; if (closeUnderlying) { try { super.close(); } catch (IOException e) { fail = merge(fail, e); } } try { if (spillBuffer != null) { this.spillBuffer.close(); this.spillBuffer = null; } } catch (IOException e) { fail = merge(fail, e); } try { if (spillFile.exists()) { spillFile.delete(); } } catch (Throwable e) { fail = merge(fail, e); } maybeFail(fail, IOException.class); } } /* DataInputPlus methods */ public void readFully(byte[] b) throws IOException { dataReader.readFully(b); } public void readFully(byte[] b, int off, int len) throws IOException { dataReader.readFully(b, off, len); } public int skipBytes(int n) throws IOException { return dataReader.skipBytes(n); } public boolean readBoolean() throws IOException { return dataReader.readBoolean(); } public byte readByte() throws IOException { return dataReader.readByte(); } public int readUnsignedByte() throws IOException { return dataReader.readUnsignedByte(); } public short readShort() throws IOException { return dataReader.readShort(); } public int readUnsignedShort() throws IOException { return dataReader.readUnsignedShort(); } public char readChar() throws IOException { return dataReader.readChar(); } public int readInt() throws IOException { return dataReader.readInt(); } public long readLong() throws IOException { return dataReader.readLong(); } public float readFloat() throws IOException { return dataReader.readFloat(); } public double readDouble() throws IOException { return 
dataReader.readDouble(); } public String readLine() throws IOException { return dataReader.readLine(); } public String readUTF() throws IOException { return dataReader.readUTF(); } }