package org.graalvm.compiler.graph.test.graphio;
import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.channels.Channels;
import java.nio.channels.WritableByteChannel;
import java.util.Collection;
import java.util.Collections;
import java.util.Map;
import java.util.Objects;
import org.graalvm.graphio.GraphOutput;
import org.graalvm.graphio.GraphStructure;
import org.graalvm.graphio.GraphTypes;
import static org.junit.Assert.assertSame;
import org.junit.Test;
import java.lang.reflect.Field;
import static org.junit.Assert.assertEquals;
public final class GraphOutputTest {
@Test
@SuppressWarnings("static-method")
public void testWritableByteChannel() throws IOException {
ByteArrayOutputStream out = new ByteArrayOutputStream();
WritableByteChannel channel = Channels.newChannel(out);
ByteBuffer data = generateData(1 << 17);
GraphOutput<?, ?> graphOutput = GraphOutput.newBuilder(new MockGraphStructure()).protocolVersion(6, 0).embedded(true).build(channel);
try (GraphOutput<?, ?> closable = graphOutput) {
assertTrue(closable.isOpen());
closable.write(data);
}
assertFalse(graphOutput.isOpen());
assertArrayEquals(data.array(), out.toByteArray());
}
@Test
@SuppressWarnings("static-method")
public void testWriteDuringPrint() throws IOException {
ByteArrayOutputStream out = new ByteArrayOutputStream();
WritableByteChannel channel = Channels.newChannel(out);
class Action implements Runnable {
GraphOutput<MockGraph, ?> out;
@Override
public void run() {
try {
ByteBuffer data = ByteBuffer.allocate(16);
data.limit(16);
out.write(data);
} catch (IOException ioe) {
throw new RuntimeException(ioe);
}
}
}
Action action = new Action();
try (GraphOutput<MockGraph, ?> graphOutput = GraphOutput.newBuilder(new MockGraphStructure(action)).protocolVersion(6, 0).build(channel)) {
action.out = graphOutput;
try {
graphOutput.print(new MockGraph(), Collections.emptyMap(), 0, "Mock Graph");
fail("Expected IllegalStateException");
} catch (IllegalStateException ise) {
}
}
}
@Test
@SuppressWarnings("static-method")
public void testEmbeddedWritableByteChannel() throws IOException {
ByteArrayOutputStream expected = new ByteArrayOutputStream();
WritableByteChannel expectedChannel = Channels.newChannel(expected);
Map<Object, Object> properties = Collections.singletonMap("version.id", 1);
try (GraphOutput<MockGraph, ?> graphOutput = GraphOutput.newBuilder(new MockGraphStructure()).protocolVersion(6, 0).build(expectedChannel)) {
graphOutput.print(new MockGraph(), properties, 1, "Graph 1");
graphOutput.write(ByteBuffer.allocate(0));
graphOutput.print(new MockGraph(), properties, 2, "Graph 1");
}
ByteArrayOutputStream embedded = new ByteArrayOutputStream();
SharedWritableByteChannel embeddChannel = new SharedWritableByteChannel(Channels.newChannel(embedded));
try {
try (GraphOutput<MockGraph, ?> baseOutput = GraphOutput.newBuilder(new MockGraphStructure()).protocolVersion(6, 0).build(embeddChannel)) {
try (GraphOutput<MockGraph, ?> embeddedOutput = GraphOutput.newBuilder(new MockGraphStructure()).protocolVersion(6, 0).embedded(true).build((WritableByteChannel) baseOutput)) {
embeddedOutput.print(new MockGraph(), properties, 1, "Graph 1");
baseOutput.print(new MockGraph(), properties, 2, "Graph 1");
}
}
} finally {
embeddChannel.realClose();
}
assertArrayEquals(expected.toByteArray(), embedded.toByteArray());
}
@Test
@SuppressWarnings({"static-method", "unchecked"})
public void testClassOfEnumValueWithImplementation() throws ClassNotFoundException, ReflectiveOperationException {
Class<? extends GraphTypes> defaultTypesClass = (Class<? extends GraphTypes>) Class.forName("org.graalvm.graphio.DefaultGraphTypes");
Field f = defaultTypesClass.getDeclaredField("DEFAULT");
f.setAccessible(true);
GraphTypes types = (GraphTypes) f.get(null);
Object clazz = types.enumClass(CustomEnum.ONE);
assertSame(CustomEnum.class, clazz);
}
@Test
@SuppressWarnings({"static-method", "unchecked"})
public void testBuilderPromotesVersion() throws Exception {
ByteArrayOutputStream out = new ByteArrayOutputStream();
try (WritableByteChannel channel = Channels.newChannel(out)) {
GraphOutput.Builder<MockGraph, Void, ?> builder = GraphOutput.newBuilder(new MockGraphStructure()).attr("test", "failed");
try (GraphOutput<MockGraph, ?> graphOutput = builder.build(channel)) {
graphOutput.print(new MockGraph(), Collections.emptyMap(), 0, "Mock Graph");
} catch (IllegalStateException ise) {
}
}
byte[] bytes = out.toByteArray();
assertEquals("Major version 7 must be auto-selected", 7, bytes[4]);
}
@Test
@SuppressWarnings({"static-method", "unchecked"})
public void testTooOldVersionFails() throws Exception {
ByteArrayOutputStream out = new ByteArrayOutputStream();
try (WritableByteChannel channel = Channels.newChannel(out)) {
GraphOutput.Builder<MockGraph, Void, ?> builder = GraphOutput.newBuilder(new MockGraphStructure()).protocolVersion(6, 1);
try {
builder.attr("test", "failed");
fail("Should have failed, attr() requires version 7.0");
} catch (IllegalStateException ex) {
}
try (GraphOutput<MockGraph, ?> graphOutput = builder.build(channel)) {
graphOutput.print(new MockGraph(), Collections.emptyMap(), 0, "Mock Graph");
} catch (IllegalStateException ise) {
}
}
}
@Test
@SuppressWarnings({"static-method", "unchecked"})
public void testVersionDowngradeFails() throws Exception {
ByteArrayOutputStream out = new ByteArrayOutputStream();
try (WritableByteChannel channel = Channels.newChannel(out)) {
GraphOutput.Builder<MockGraph, Void, ?> builder = GraphOutput.newBuilder(new MockGraphStructure());
builder.attr("test", "failed");
try {
builder.protocolVersion(6, 0);
fail("Should fail, cannot downgrade from required version.");
} catch (IllegalArgumentException e) {
}
try (GraphOutput<MockGraph, ?> graphOutput = builder.build(channel)) {
graphOutput.print(new MockGraph(), Collections.emptyMap(), 0, "Mock Graph");
} catch (IllegalStateException ise) {
}
}
}
@Test
@SuppressWarnings({"static-method", "unchecked"})
public void testManualAncientVersion() throws Exception {
ByteArrayOutputStream out = new ByteArrayOutputStream();
try (WritableByteChannel channel = Channels.newChannel(out)) {
GraphOutput.Builder<MockGraph, Void, ?> builder = GraphOutput.newBuilder(new MockGraphStructure()).protocolVersion(3, 0);
try (GraphOutput<MockGraph, ?> graphOutput = builder.build(channel)) {
graphOutput.print(new MockGraph(), Collections.emptyMap(), 0, "Mock Graph");
}
}
byte[] bytes = out.toByteArray();
assertEquals("Protocol version 3 was requested", 3, bytes[4]);
}
@Test
@SuppressWarnings({"static-method", "unchecked"})
public void testManualVersionUpgradeOK() throws Exception {
ByteArrayOutputStream out = new ByteArrayOutputStream();
try (WritableByteChannel channel = Channels.newChannel(out)) {
GraphOutput.Builder<MockGraph, Void, ?> builder = GraphOutput.newBuilder(new MockGraphStructure());
builder.attr("some", "thing");
builder.protocolVersion(7, 0);
try (GraphOutput<MockGraph, ?> graphOutput = builder.build(channel)) {
graphOutput.print(new MockGraph(), Collections.emptyMap(), 0, "Mock Graph");
}
}
byte[] bytes = out.toByteArray();
assertEquals("Protocol version 7 was requested", 7, bytes[4]);
}
private static ByteBuffer generateData(int size) {
ByteBuffer buffer = ByteBuffer.allocate(size);
for (int i = 0; i < size; i++) {
buffer.put(i, (byte) i);
}
buffer.limit(size);
return buffer;
}
private static final class SharedWritableByteChannel implements WritableByteChannel {
private final WritableByteChannel delegate;
SharedWritableByteChannel(WritableByteChannel delegate) {
Objects.requireNonNull(delegate, "Delegate must be non null.");
this.delegate = delegate;
}
@Override
public int write(ByteBuffer bb) throws IOException {
return delegate.write(bb);
}
@Override
public boolean isOpen() {
return delegate.isOpen();
}
@Override
public void close() throws IOException {
}
void realClose() throws IOException {
delegate.close();
}
}
private static final class MockGraphStructure implements GraphStructure<MockGraph, Void, Void, Void> {
private final Runnable enterAction;
MockGraphStructure() {
this.enterAction = null;
}
MockGraphStructure(Runnable enterAction) {
this.enterAction = enterAction;
}
@Override
public MockGraph graph(MockGraph currentGraph, Object obj) {
onEnter();
return null;
}
@Override
public Iterable<? extends Void> nodes(MockGraph graph) {
onEnter();
return Collections.emptySet();
}
@Override
public int nodesCount(MockGraph graph) {
onEnter();
return 0;
}
@Override
public int nodeId(Void node) {
onEnter();
return 0;
}
@Override
public boolean nodeHasPredecessor(Void node) {
onEnter();
return false;
}
@Override
public void nodeProperties(MockGraph graph, Void node, Map<String, ? super Object> properties) {
onEnter();
}
@Override
public Void node(Object obj) {
onEnter();
return null;
}
@Override
public Void nodeClass(Object obj) {
onEnter();
return null;
}
@Override
public Void classForNode(Void node) {
onEnter();
return null;
}
@Override
public String nameTemplate(Void nodeClass) {
onEnter();
return null;
}
@Override
public Object nodeClassType(Void nodeClass) {
onEnter();
return null;
}
@Override
public Void portInputs(Void nodeClass) {
onEnter();
return null;
}
@Override
public Void portOutputs(Void nodeClass) {
onEnter();
return null;
}
@Override
public int portSize(Void port) {
onEnter();
return 0;
}
@Override
public boolean edgeDirect(Void port, int index) {
onEnter();
return false;
}
@Override
public String edgeName(Void port, int index) {
onEnter();
return null;
}
@Override
public Object edgeType(Void port, int index) {
onEnter();
return null;
}
@Override
public Collection<? extends Void> edgeNodes(MockGraph graph, Void node, Void port, int index) {
onEnter();
return null;
}
private void onEnter() {
if (enterAction != null) {
enterAction.run();
}
}
}
private static final class MockGraph {
}
private enum {
ONE() {
@Override
public String () {
return "one";
}
},
TWO() {
@Override
public String () {
return "two";
}
}
}
}