package com.oracle.truffle.tools.chromeinspector.test;
import org.graalvm.polyglot.Context;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import java.io.ByteArrayOutputStream;
import java.io.EOFException;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStreamWriter;
import java.net.Socket;
import java.nio.charset.StandardCharsets;
import java.util.HashSet;
import java.util.Set;
import static java.lang.Integer.parseInt;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotEquals;
import static org.junit.Assert.assertNotNull;
/**
 * Checks how the inspector HTTP endpoint responds to different {@code Host} headers.
 *
 * <p>Each check opens a raw socket to {@code 127.0.0.1}, sends a minimal HTTP/1.1 {@code GET}
 * request with a chosen {@code Host} header (or none, or with varied header-name capitalization),
 * and inspects the response status line and body. Requests with a missing host or a non-IP,
 * non-{@code localhost} host name are expected to get a {@code 400 Bad Request} with a DNS-rebind
 * error body. IP literals and {@code localhost}, with or without a port, are expected not to.
 */
public class InspectorDnsRebindAttackTest {

    /** Port the raw HTTP requests are sent to. NOTE(review): assumed to be the inspector's default port; confirm. */
    private static final int INSPECTOR_PORT = 9229;

    /** Upper bound for a single blocking read, so a non-responding server fails the test instead of hanging it. */
    private static final int READ_TIMEOUT_MILLIS = 30_000;

    /** Header-name prefix (including the colon) of the response body length, matched case-insensitively. */
    private static final String CONTENT_LENGTH_PREFIX = "content-length:";

    /** Status line expected for rejected requests (the trailing space is part of what the test compares). */
    private static final String BAD_REQUEST_STATUS = "HTTP/1.1 400 Bad Request ";

    private Context context;
    private ByteArrayOutputStream errorOutput;

    @Before
    public void setUp() {
        errorOutput = new ByteArrayOutputStream();
        final String testPath = "testPath-" + SecureInspectorPathGenerator.getToken();
        context = Context.newBuilder().option("inspect.Path", testPath).err(errorOutput).build();
    }

    @After
    public void tearDown() {
        if (context != null) {
            context.close();
            context = null;
        }
        errorOutput.reset();
    }

    @Test
    public void testMissingHost() throws IOException {
        testDnsRebindForInvalidHost(null);
    }

    @Test
    public void testInvalidHostWithoutPort() throws IOException {
        testDnsRebindForInvalidHost("evil.example.com");
    }

    @Test
    public void testInvalidHostWithStandardPort() throws IOException {
        testDnsRebindForInvalidHost("evil.example.com:9229");
    }

    @Test
    public void testInvalidHostWithNonstandardPort() throws IOException {
        testDnsRebindForInvalidHost("evil.example.com:9228");
    }

    @Test
    public void testLocalhostIpv4() throws IOException {
        testDnsRebindForValidAddress("127.0.0.1");
    }

    @Test
    public void testLocalhostIpv6() throws IOException {
        testDnsRebindForValidAddress("[::1]");
    }

    @Test
    public void testLocalhostHostname() throws IOException {
        testDnsRebindForValidAddress("localhost");
    }

    @Test
    public void testLocalNetworkIp() throws IOException {
        testDnsRebindForValidAddress("192.168.6.6");
    }

    @Test
    public void testPublicIpv4() throws IOException {
        testDnsRebindForValidAddress("1.2.3.4");
    }

    @Test
    public void testIpv6Long() throws IOException {
        testDnsRebindForValidAddress("[2001:0db8:85a3:0000:0000:8a2e:0370:7334]");
    }

    @Test
    public void testIpv6Short() throws IOException {
        testDnsRebindForValidAddress("[2001:0db8:85a3::0370:7334]");
    }

    /** Checks that {@code address} is accepted bare, with port 9229 and with port 9228. */
    private void testDnsRebindForValidAddress(String address) throws IOException {
        testDnsRebindForValidHost(address);
        testDnsRebindForValidHost(address + ":9229");
        testDnsRebindForValidHost(address + ":9228");
    }

    /** Checks that requests carrying {@code host} (or no host header when {@code null}) are rejected. */
    private void testDnsRebindForInvalidHost(String host) throws IOException {
        testDnsRebindForHost(host, false);
    }

    /** Checks that requests carrying {@code host} are not rejected as DNS rebind attempts. */
    private void testDnsRebindForValidHost(String host) throws IOException {
        testDnsRebindForHost(host, true);
    }

    /** Runs the host check against several request paths, including an unknown one. */
    private void testDnsRebindForHost(String host, boolean valid) throws IOException {
        testDnsRebind(host, "/", valid);
        testDnsRebind(host, "/json", valid);
        testDnsRebind(host, "/json/version", valid);
        testDnsRebind(host, "/some-nonsense", valid);
    }

    /** Runs the host check with several capitalizations of the {@code Host} header name. */
    private void testDnsRebind(String host, String path, boolean valid) throws IOException {
        testDnsRebindForHostCapitalization("host", host, path, valid);
        testDnsRebindForHostCapitalization("Host", host, path, valid);
        testDnsRebindForHostCapitalization("HoSt", host, path, valid);
        testDnsRebindForHostCapitalization("HOST", host, path, valid);
    }

    /**
     * Sends one request and asserts whether it was rejected as a DNS rebind attempt.
     *
     * @param hostCapitalization spelling of the {@code Host} header name to send
     * @param host value of the {@code Host} header, or {@code null} to omit the header
     * @param path request target
     * @param valid {@code true} if the request must NOT be rejected, {@code false} if it must be
     * @throws IOException on socket errors, read timeout, or premature end of the response
     */
    private void testDnsRebindForHostCapitalization(String hostCapitalization, String host, String path, boolean valid) throws IOException {
        try (Socket socket = new Socket("127.0.0.1", INSPECTOR_PORT);
                OutputStreamWriter out = new OutputStreamWriter(socket.getOutputStream(), StandardCharsets.US_ASCII);
                InputStream in = socket.getInputStream()) {
            // Without a timeout a server that never answers would block the whole suite forever.
            socket.setSoTimeout(READ_TIMEOUT_MILLIS);
            out.write(buildRequest(hostCapitalization, host, path));
            out.flush();

            final String httpStatus = readAsciiLine(in);
            final Set<String> headers = new HashSet<>();
            int length = -1;
            String line;
            // Header section ends at the first empty line.
            while (!(line = readAsciiLine(in)).isEmpty()) {
                assertNotNull(line);
                // Case-insensitive and locale-independent, unlike toLowerCase() with the default locale.
                if (line.regionMatches(true, 0, CONTENT_LENGTH_PREFIX, 0, CONTENT_LENGTH_PREFIX.length())) {
                    length = parseInt(line.substring(CONTENT_LENGTH_PREFIX.length()).trim());
                }
                headers.add(line);
            }
            assertNotEquals("No Content-Length header in response '" + httpStatus + "', headers: " + headers, -1, length);

            final String body = new String(readBytes(in, length), StandardCharsets.UTF_8);
            final String errorMessage = expectedRejectionMessage(host);
            if (valid) {
                assertNotEquals(BAD_REQUEST_STATUS, httpStatus);
                assertNotEquals(errorMessage, body);
            } else {
                assertEquals(BAD_REQUEST_STATUS, httpStatus);
                assertEquals(errorMessage, body);
                // TODO(review): the rejection is presumably also logged to errorOutput, but that is
                // not verified here. The server's exact log format (e.g. a trailing line separator)
                // is not visible from this test. Confirm it, then assert something like
                // assertTrue(errorOutput.toString().contains(errorMessage)).
            }
        }
    }

    /**
     * Builds a minimal HTTP/1.1 GET request.
     *
     * @param hostHeaderName spelling of the {@code Host} header name
     * @param host header value, or {@code null} to omit the {@code Host} header entirely
     * @param path request target
     * @return the request text, including the terminating empty line
     */
    private static String buildRequest(String hostHeaderName, String host, String path) {
        final StringBuilder request = new StringBuilder();
        request.append("GET ").append(path).append(" HTTP/1.1\r\n");
        request.append("User-Agent: Test\r\n");
        if (host != null) {
            request.append(hostHeaderName).append(": ").append(host).append("\r\n");
        }
        request.append("Accept: text/html, image/gif, image/jpeg, *; q=.2, */*; q=.2\r\n");
        request.append("\r\n");
        return request.toString();
    }

    /** Returns the response body the server is expected to send when rejecting {@code host} ({@code null} = missing header). */
    private static String expectedRejectionMessage(String host) {
        final String badHost = host != null ? "Bad host " + host + ". Please use IP address." : "Missing host header. Use an up-to-date client.";
        return badHost + " This request cannot be served because it looks like DNS rebind attack.";
    }

    /**
     * Reads one {@code '\n'}-terminated line, mapping each byte to one char, and drops a trailing
     * {@code '\r'} if present. An empty line (including a bare {@code "\n"}) yields {@code ""}.
     *
     * @throws EOFException if the stream ends before a line terminator is read
     */
    private static String readAsciiLine(InputStream in) throws IOException {
        final StringBuilder sb = new StringBuilder();
        int c;
        while ((c = in.read()) != '\n') {
            if (c == -1) {
                throw new EOFException("Stream ended before end of line; partial line: '" + sb + "'");
            }
            sb.append((char) c);
        }
        final int last = sb.length() - 1;
        // Guard against a bare "\n" line, where the builder is empty.
        if (last >= 0 && sb.charAt(last) == '\r') {
            sb.setLength(last);
        }
        return sb.toString();
    }

    /**
     * Reads exactly {@code length} bytes from {@code in}.
     *
     * @throws EOFException if the stream ends before {@code length} bytes were read
     */
    private static byte[] readBytes(InputStream in, int length) throws IOException {
        int missing = length;
        int pos = 0;
        final byte[] buffer = new byte[length];
        while (missing > 0) {
            final int read = in.read(buffer, pos, missing);
            if (read == -1) {
                throw new EOFException();
            }
            missing -= read;
            pos += read;
        }
        return buffer;
    }
}