package jdk.tools.jaotc.utils;
import java.io.ByteArrayOutputStream;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.ArrayList;
import java.util.List;
public class NativeOrderOutputStream {
private final PatchableByteOutputStream os = new PatchableByteOutputStream();
private final byte[] backingArray = new byte[8];
private final ByteBuffer buffer;
private final List<Patchable> patches = new ArrayList<>();
private int size;
public NativeOrderOutputStream() {
buffer = ByteBuffer.wrap(backingArray);
buffer.order(ByteOrder.nativeOrder());
}
public static int alignUp(int value, int alignment) {
if (Integer.bitCount(alignment) != 1) {
throw new IllegalArgumentException("Must be a power of 2");
}
int aligned = (value + (alignment - 1)) & -alignment;
if (aligned < value || (aligned & (alignment - 1)) != 0) {
throw new RuntimeException("Error aligning: " + value + " -> " + aligned);
}
return aligned;
}
public NativeOrderOutputStream putLong(long value) {
fillLong(value);
os.write(backingArray, 0, 8);
size += 8;
return this;
}
public NativeOrderOutputStream putInt(int value) {
fillInt(value);
os.write(backingArray, 0, 4);
size += 4;
return this;
}
public NativeOrderOutputStream align(int alignment) {
int aligned = alignUp(position(), alignment);
int diff = aligned - position();
if (diff > 0) {
byte[] b = new byte[diff];
os.write(b, 0, b.length);
size += diff;
}
assert aligned == position();
return this;
}
public int position() {
return os.size();
}
private void fillInt(int value) {
buffer.putInt(0, value);
}
private void fillLong(long value) {
buffer.putLong(0, value);
}
private NativeOrderOutputStream putAt(byte[] data, int length, int position) {
os.writeAt(data, length, position);
return this;
}
public NativeOrderOutputStream put(byte[] data) {
os.write(data, 0, data.length);
size += data.length;
return this;
}
public byte[] array() {
checkPatches();
byte[] bytes = os.toByteArray();
assert size == bytes.length;
return bytes;
}
private void checkPatches() {
for (Patchable patch : patches) {
if (!patch.patched()) {
throw new RuntimeException("Not all patches patched, missing at offset: " + patch);
}
}
}
public PatchableInt patchableInt() {
int position = os.size();
PatchableInt patchableInt = new PatchableInt(position);
putInt(0);
patches.add(patchableInt);
return patchableInt;
}
public abstract class Patchable {
private final int position;
private boolean patched = false;
Patchable(int position) {
this.position = position;
}
protected boolean patched() {
return patched;
}
protected void patch(int length) {
putAt(backingArray, length, position);
patched = true;
}
public int position() {
return position;
}
}
public class PatchableInt extends Patchable {
private int value = 0;
public PatchableInt(int position) {
super(position);
}
public void set(int value) {
this.value = value;
fillInt(value);
patch(4);
}
public int value() {
return value;
}
@Override
public String toString() {
final StringBuilder sb = new StringBuilder("PatchableInt{");
sb.append("position=").append(super.position()).append(", ");
sb.append("patched=").append(patched()).append(", ");
sb.append("value=").append(value);
sb.append('}');
return sb.toString();
}
}
private static class PatchableByteOutputStream extends ByteArrayOutputStream {
public void writeAt(byte[] data, int length, int position) {
long end = (long)position + (long)length;
if (buf.length < end) {
throw new IllegalArgumentException("Array not properly sized");
}
System.arraycopy(data, 0, buf, position, length);
}
}
}