package java.net;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.io.BufferedOutputStream;
import java.nio.charset.StandardCharsets;
import java.security.AccessController;
import java.util.Iterator;
import jdk.internal.util.StaticProperty;
import sun.net.SocksProxy;
import sun.net.spi.DefaultProxySelector;
import sun.net.www.ParseUtil;
/**
 * A socket implementation that establishes outgoing TCP connections through a
 * SOCKS server, speaking SOCKS version 4 or version 5, while all socket I/O is
 * performed by the wrapped {@code delegate}.
 *
 * <p>The SOCKS server is either fixed at construction time (from a
 * {@link Proxy}) or, when none was given, chosen per connection with the
 * default {@link ProxySelector}. Only the client-side CONNECT command is
 * implemented; {@link #listen}, {@link #accept} and {@link #reset} throw
 * {@link InternalError}.
 */
class SocksSocketImpl extends DelegatingSocketImpl implements SocksConsts {
    /** Host string of the SOCKS server, or {@code null} while none is known. */
    private String server = null;
    /** Port of the SOCKS server. */
    private int serverPort = DEFAULT_PORT;
    /** The requested destination, recorded once the SOCKS server accepted the CONNECT. */
    private InetSocketAddress external_address;
    /** {@code true} to speak SOCKS V4 rather than V5. */
    private boolean useV4 = false;
    /** Input stream of the control connection to the SOCKS server. */
    private InputStream cmdIn = null;
    /** Output stream of the control connection to the SOCKS server. */
    private OutputStream cmdOut = null;

    /** Creates an instance whose SOCKS server will be chosen by the default ProxySelector. */
    SocksSocketImpl(SocketImpl delegate) {
        super(delegate);
    }

    /**
     * Creates an instance that always uses the given SOCKS proxy.
     *
     * @param proxy    the proxy; its address is used only if it is an {@link InetSocketAddress}
     * @param delegate the socket implementation that performs the actual I/O
     */
    SocksSocketImpl(Proxy proxy, SocketImpl delegate) {
        super(delegate);
        SocketAddress a = proxy.address();
        if (a instanceof InetSocketAddress) {
            InetSocketAddress ad = (InetSocketAddress) a;
            server = ad.getHostString();
            serverPort = ad.getPort();
        }
        useV4 = useV4(proxy);
    }

    /**
     * Decides whether SOCKS V4 must be used for {@code proxy}: either the proxy
     * itself requests protocol version 4, or the system-wide default is V4.
     */
    private static boolean useV4(Proxy proxy) {
        if (proxy instanceof SocksProxy
            && ((SocksProxy)proxy).protocolVersion() == 4) {
            return true;
        }
        return DefaultProxySelector.socksProxyVersion() == 4;
    }

    /**
     * Opens the control connection to the SOCKS server inside a privileged
     * block and captures its streams into {@code cmdIn}/{@code cmdOut}.
     *
     * @param timeout connect timeout in milliseconds, 0 for none
     */
    private synchronized void privilegedConnect(final String host,
                                                final int port,
                                                final int timeout)
        throws IOException
    {
        try {
            AccessController.doPrivileged(
                new java.security.PrivilegedExceptionAction<>() {
                    public Void run() throws IOException {
                        superConnectServer(host, port, timeout);
                        cmdIn = getInputStream();
                        cmdOut = getOutputStream();
                        return null;
                    }
                });
        } catch (java.security.PrivilegedActionException pae) {
            // run() only throws IOException.
            throw (IOException) pae.getException();
        }
    }

    /** Connects the delegate socket to the SOCKS server itself. */
    private void superConnectServer(String host, int port,
                                    int timeout) throws IOException {
        delegate.connect(new InetSocketAddress(host, port), timeout);
    }

    /**
     * Returns the number of milliseconds left before {@code deadlineMillis}.
     *
     * @param deadlineMillis absolute deadline ({@code System.currentTimeMillis()}
     *                       scale), or 0 for "no deadline"
     * @return 0 when there is no deadline, otherwise a positive timeout that is
     *         clamped to {@code Integer.MAX_VALUE}
     * @throws SocketTimeoutException if the deadline has already passed
     */
    private static int remainingMillis(long deadlineMillis) throws IOException {
        if (deadlineMillis == 0L)
            return 0;

        final long remaining = deadlineMillis - System.currentTimeMillis();
        if (remaining > 0)
            return (int) Math.min(remaining, Integer.MAX_VALUE);

        throw new SocketTimeoutException();
    }

    /**
     * Reads exactly {@code data.length} bytes from the SOCKS server.
     *
     * <p>Before each read, SO_TIMEOUT is set to the time remaining until
     * {@code deadlineMillis}, so that the overall deadline is honored. The
     * original SO_TIMEOUT is restored afterwards.
     *
     * @return the number of bytes read, which is always {@code data.length}
     *         when the method returns normally
     * @throws SocketTimeoutException if the deadline expires
     * @throws SocketException        if the stream ends before the reply is complete
     */
    private int readSocksReply(InputStream in, byte[] data, long deadlineMillis) throws IOException {
        int len = data.length;
        int received = 0;
        int originalTimeout = (int) getOption(SocketOptions.SO_TIMEOUT);
        try {
            while (received < len) {
                int count;
                int remaining = remainingMillis(deadlineMillis);
                setOption(SocketOptions.SO_TIMEOUT, remaining);
                try {
                    count = in.read(data, received, len - received);
                } catch (SocketTimeoutException e) {
                    SocketTimeoutException ste = new SocketTimeoutException("Connect timed out");
                    ste.initCause(e);
                    throw ste;
                }
                if (count < 0)
                    throw new SocketException("Malformed reply from SOCKS server");
                received += count;
            }
        } finally {
            setOption(SocketOptions.SO_TIMEOUT, originalTimeout);
        }
        return received;
    }

    /**
     * Reads a fixed-size field of a SOCKS V5 reply.
     *
     * @throws SocketException if fewer than {@code count} bytes could be read
     */
    private byte[] readReplyField(InputStream in, int count, long deadlineMillis) throws IOException {
        byte[] field = new byte[count];
        if (readSocksReply(in, field, deadlineMillis) != count)
            throw new SocketException("Reply from SOCKS server badly formatted");
        return field;
    }

    /**
     * Rejects a field whose length cannot be expressed in the single length
     * byte that precedes it in a SOCKS V5 message.
     */
    private static void checkFieldLength(byte[] field, String what) throws SocketException {
        if (field.length > 255)
            throw new SocketException("SOCKS: " + what + " too long");
    }

    /** Creates a {@link SocketException} with the given message and cause. */
    private static SocketException socketException(String message, Throwable cause) {
        SocketException se = new SocketException(message);
        se.initCause(cause);
        return se;
    }

    /**
     * Performs the authentication step for the method the V5 server selected.
     *
     * <p>For NO_AUTH there is nothing to do. For USER_PASSW, credentials are
     * requested from the installed {@link Authenticator}. If it supplies none,
     * the current user name is sent with an empty password. They are sent in
     * the username/password sub-negotiation (RFC 1929). If the server rejects
     * them, both streams are closed.
     *
     * @param method the authentication method byte chosen by the server
     * @return {@code true} if authentication succeeded; {@code false} for an
     *         unsupported method, a missing user name, or a rejection
     * @throws SocketException if the user name or password is longer than 255
     *                         bytes once encoded
     */
    private boolean authenticate(byte method, InputStream in,
                                 BufferedOutputStream out,
                                 long deadlineMillis) throws IOException {
        if (method == NO_AUTH)
            return true;
        if (method != USER_PASSW)
            return false;

        String userName;
        String password = null;
        final InetAddress addr = InetAddress.getByName(server);
        PasswordAuthentication pw =
            java.security.AccessController.doPrivileged(
                new java.security.PrivilegedAction<>() {
                    public PasswordAuthentication run() {
                        return Authenticator.requestPasswordAuthentication(
                            server, addr, serverPort, "SOCKS5", "SOCKS authentication", null);
                    }
                });
        if (pw != null) {
            userName = pw.getUserName();
            password = new String(pw.getPassword());
        } else {
            userName = StaticProperty.userName();
        }
        if (userName == null)
            return false;

        // Encode first so the length prefixes are byte counts, not char counts.
        byte[] userBytes = userName.getBytes(StandardCharsets.ISO_8859_1);
        byte[] passwordBytes = (password == null)
            ? null : password.getBytes(StandardCharsets.ISO_8859_1);
        checkFieldLength(userBytes, "user name");
        if (passwordBytes != null)
            checkFieldLength(passwordBytes, "password");

        out.write(1); // sub-negotiation version
        out.write(userBytes.length);
        out.write(userBytes);
        if (passwordBytes != null) {
            out.write(passwordBytes.length);
            out.write(passwordBytes);
        } else {
            out.write(0);
        }
        out.flush();

        byte[] data = new byte[2];
        int i = readSocksReply(in, data, deadlineMillis);
        if (i != 2 || data[1] != 0) {
            // Credentials rejected: drop the control connection.
            out.close();
            in.close();
            return false;
        }
        return true;
    }

    /**
     * Sends a SOCKS V4 CONNECT request for {@code endpoint} and processes the
     * 8-byte reply. On success {@code external_address} is set. On any
     * rejection the streams are closed and a {@link SocketException} is thrown.
     *
     * @throws SocketException if the endpoint is not IPv4, or the reply is
     *                         malformed or reports a failure
     */
    private void connectV4(InputStream in, OutputStream out,
                           InetSocketAddress endpoint,
                           long deadlineMillis) throws IOException {
        if (!(endpoint.getAddress() instanceof Inet4Address)) {
            throw new SocketException("SOCKS V4 requires IPv4 only addresses");
        }
        out.write(PROTO_VERS4);
        out.write(CONNECT);
        out.write((endpoint.getPort() >> 8) & 0xff);
        out.write(endpoint.getPort() & 0xff);
        out.write(endpoint.getAddress().getAddress());
        String userName = getUserName();
        // A missing user name is sent as an empty, NUL-terminated user id.
        if (userName != null)
            out.write(userName.getBytes(StandardCharsets.ISO_8859_1));
        out.write(0);
        out.flush();

        byte[] data = new byte[8];
        int n = readSocksReply(in, data, deadlineMillis);
        if (n != 8)
            throw new SocketException("Reply from SOCKS server has bad length: " + n);
        if (data[0] != 0 && data[0] != 4)
            throw new SocketException("Reply from SOCKS server has bad version");
        SocketException ex = null;
        switch (data[1]) {
        case 90: // request granted
            external_address = endpoint;
            break;
        case 91:
            ex = new SocketException("SOCKS request rejected");
            break;
        case 92:
            ex = new SocketException("SOCKS server couldn't reach destination");
            break;
        case 93:
            ex = new SocketException("SOCKS authentication failed");
            break;
        default:
            ex = new SocketException("Reply from SOCKS server contains bad status");
            break;
        }
        if (ex != null) {
            in.close();
            out.close();
            throw ex;
        }
    }

    @Override
    protected void connect(String host, int port) throws IOException {
        connect(new InetSocketAddress(host, port), 0);
    }

    @Override
    protected void connect(InetAddress address, int port) throws IOException {
        connect(new InetSocketAddress(address, port), 0);
    }

    /**
     * Connects to {@code endpoint} through a SOCKS server, or directly when the
     * default ProxySelector yields no SOCKS proxy.
     *
     * <p>With a V5 server this negotiates an authentication method
     * (NO_AUTH or USER_PASSW), authenticates, and issues a CONNECT. It falls
     * back to V4 if the server does not answer the V5 greeting.
     *
     * @param endpoint the destination; must be an {@link InetSocketAddress}
     * @param timeout  overall timeout in milliseconds, 0 for none
     * @throws IllegalArgumentException if {@code endpoint} is null or not an InetSocketAddress
     * @throws UnknownHostException     if V4 must be used and the endpoint is unresolved
     * @throws SocketException          on proxy or protocol failures
     */
    @Override
    protected void connect(SocketAddress endpoint, int timeout) throws IOException {
        final long deadlineMillis;
        if (timeout == 0) {
            deadlineMillis = 0L;
        } else {
            long finish = System.currentTimeMillis() + timeout;
            deadlineMillis = finish < 0 ? Long.MAX_VALUE : finish;
        }

        SecurityManager security = System.getSecurityManager();
        if (!(endpoint instanceof InetSocketAddress))
            throw new IllegalArgumentException("Unsupported address type");
        InetSocketAddress epoint = (InetSocketAddress) endpoint;
        if (security != null) {
            if (epoint.isUnresolved())
                security.checkConnect(epoint.getHostName(),
                                      epoint.getPort());
            else
                security.checkConnect(epoint.getAddress().getHostAddress(),
                                      epoint.getPort());
        }

        if (server == null) {
            if (!connectToSelectedProxy(epoint, deadlineMillis))
                return; // connected directly; no SOCKS handshake
        } else {
            try {
                privilegedConnect(server, serverPort, remainingMillis(deadlineMillis));
            } catch (IOException e) {
                throw socketException(e.getMessage(), e);
            }
        }

        BufferedOutputStream out = new BufferedOutputStream(cmdOut, 512);
        InputStream in = cmdIn;

        if (useV4) {
            if (epoint.isUnresolved())
                throw new UnknownHostException(epoint.toString());
            connectV4(in, out, epoint, deadlineMillis);
            return;
        }

        // V5 greeting: offer two methods, NO_AUTH and USER_PASSW.
        out.write(PROTO_VERS);
        out.write(2);
        out.write(NO_AUTH);
        out.write(USER_PASSW);
        out.flush();
        byte[] data = new byte[2];
        int i = readSocksReply(in, data, deadlineMillis);
        if (i != 2 || ((int) data[0]) != PROTO_VERS) {
            // The reply is not a V5 answer: retry with V4 on the same connection.
            if (epoint.isUnresolved())
                throw new UnknownHostException(epoint.toString());
            connectV4(in, out, epoint, deadlineMillis);
            return;
        }
        if (((int) data[1]) == NO_METHODS)
            throw new SocketException("SOCKS : No acceptable methods");
        if (!authenticate(data[1], in, out, deadlineMillis)) {
            throw new SocketException("SOCKS : authentication failed");
        }

        writeV5ConnectRequest(out, epoint);
        readV5ConnectReply(in, out, deadlineMillis);
        external_address = epoint;
    }

    /**
     * Picks a SOCKS server with the default {@link ProxySelector} and opens the
     * control connection to it. If the selector is absent, returns no proxies,
     * or returns a non-SOCKS (or null) proxy, the delegate connects directly to
     * {@code epoint} instead.
     *
     * <p>Each failing SOCKS proxy is reported through
     * {@link ProxySelector#connectFailed}, and the next candidate is tried.
     *
     * @return {@code true} if a SOCKS control connection was established
     *         ({@code server}, {@code serverPort} and {@code useV4} are set);
     *         {@code false} if the delegate connected directly to {@code epoint}
     * @throws SocketException if every SOCKS proxy failed, or a proxy address
     *                         is not an InetSocketAddress
     * @throws IOException     if the selector rejects the URI
     */
    private boolean connectToSelectedProxy(InetSocketAddress epoint,
                                           long deadlineMillis) throws IOException {
        ProxySelector sel = java.security.AccessController.doPrivileged(
            new java.security.PrivilegedAction<>() {
                public ProxySelector run() {
                    return ProxySelector.getDefault();
                }
            });
        if (sel == null) {
            delegate.connect(epoint, remainingMillis(deadlineMillis));
            return false;
        }

        // Bracket IPv6 literals so the host string is valid in a URI authority.
        String host = epoint.getHostString();
        if (epoint.getAddress() instanceof Inet6Address
                && !host.startsWith("[") && host.indexOf(':') >= 0) {
            host = "[" + host + "]";
        }
        URI uri;
        try {
            uri = new URI("socket://" + ParseUtil.encodePath(host) + ":" + epoint.getPort());
        } catch (URISyntaxException e) {
            assert false : e;
            uri = null;
        }

        final Iterator<Proxy> iProxy;
        try {
            iProxy = sel.select(uri).iterator();
        } catch (IllegalArgumentException iae) {
            throw new IOException("Failed to select a proxy", iae);
        }
        if (iProxy == null || !iProxy.hasNext()) {
            delegate.connect(epoint, remainingMillis(deadlineMillis));
            return false;
        }

        IOException savedExc = null;
        while (iProxy.hasNext()) {
            Proxy p = iProxy.next();
            if (p == null || p.type() != Proxy.Type.SOCKS) {
                delegate.connect(epoint, remainingMillis(deadlineMillis));
                return false;
            }
            if (!(p.address() instanceof InetSocketAddress))
                throw new SocketException("Unknown address type for proxy: " + p);
            InetSocketAddress proxyAddress = (InetSocketAddress) p.address();
            server = proxyAddress.getHostString();
            serverPort = proxyAddress.getPort();
            useV4 = useV4(p);
            try {
                privilegedConnect(server, serverPort, remainingMillis(deadlineMillis));
                return true;
            } catch (IOException e) {
                sel.connectFailed(uri, p.address(), e);
                server = null;
                serverPort = -1;
                savedExc = e;
            }
        }
        // The loop ran at least once and every attempt failed, so savedExc is set.
        throw socketException("Can't connect to SOCKS proxy:" + savedExc.getMessage(), savedExc);
    }

    /**
     * Writes and flushes a SOCKS V5 CONNECT request for {@code epoint}. An
     * unresolved endpoint is sent as a domain name, otherwise as an IPv4 or
     * IPv6 address. The port follows in network byte order.
     *
     * @throws SocketException if the encoded host name exceeds 255 bytes
     */
    private static void writeV5ConnectRequest(OutputStream out,
                                              InetSocketAddress epoint) throws IOException {
        // Validate before writing anything so no partial request is buffered.
        byte[] hostName = null;
        if (epoint.isUnresolved()) {
            hostName = epoint.getHostName().getBytes(StandardCharsets.ISO_8859_1);
            checkFieldLength(hostName, "host name");
        }

        out.write(PROTO_VERS);
        out.write(CONNECT);
        out.write(0); // reserved
        if (hostName != null) {
            out.write(DOMAIN_NAME);
            out.write(hostName.length);
            out.write(hostName);
        } else if (epoint.getAddress() instanceof Inet6Address) {
            out.write(IPV6);
            out.write(epoint.getAddress().getAddress());
        } else {
            out.write(IPV4);
            out.write(epoint.getAddress().getAddress());
        }
        out.write((epoint.getPort() >> 8) & 0xff);
        out.write(epoint.getPort() & 0xff);
        out.flush();
    }

    /**
     * Reads the SOCKS V5 reply to a CONNECT request. On success, the bound
     * address and port that follow the header are consumed and discarded. On
     * any failure status, or an unknown status or address type, both streams
     * are closed and a {@link SocketException} is thrown.
     */
    private void readV5ConnectReply(InputStream in, OutputStream out,
                                    long deadlineMillis) throws IOException {
        byte[] data = new byte[4];
        int n = readSocksReply(in, data, deadlineMillis);
        if (n != 4)
            throw new SocketException("Reply from SOCKS server has bad length");

        SocketException ex = null;
        switch (data[1]) {
        case REQUEST_OK:
            switch (data[3]) {
            case IPV4:
                readReplyField(in, 4, deadlineMillis);
                readReplyField(in, 2, deadlineMillis);
                break;
            case DOMAIN_NAME:
                int len = readReplyField(in, 1, deadlineMillis)[0] & 0xFF;
                readReplyField(in, len, deadlineMillis);
                readReplyField(in, 2, deadlineMillis);
                break;
            case IPV6:
                readReplyField(in, 16, deadlineMillis);
                readReplyField(in, 2, deadlineMillis);
                break;
            default:
                ex = new SocketException("Reply from SOCKS server contains wrong code");
                break;
            }
            break;
        case GENERAL_FAILURE:
            ex = new SocketException("SOCKS server general failure");
            break;
        case NOT_ALLOWED:
            ex = new SocketException("SOCKS: Connection not allowed by ruleset");
            break;
        case NET_UNREACHABLE:
            ex = new SocketException("SOCKS: Network unreachable");
            break;
        case HOST_UNREACHABLE:
            ex = new SocketException("SOCKS: Host unreachable");
            break;
        case CONN_REFUSED:
            ex = new SocketException("SOCKS: Connection refused");
            break;
        case TTL_EXPIRED:
            ex = new SocketException("SOCKS: TTL expired");
            break;
        case CMD_NOT_SUPPORTED:
            ex = new SocketException("SOCKS: Command not supported");
            break;
        case ADDR_TYPE_NOT_SUP:
            ex = new SocketException("SOCKS: address type not supported");
            break;
        default:
            // Unassigned status codes must not be mistaken for success.
            ex = new SocketException("Reply from SOCKS server contains bad status");
            break;
        }
        if (ex != null) {
            in.close();
            out.close();
            throw ex;
        }
    }

    /** Not supported: this implementation only makes outgoing connections. */
    @Override
    protected void listen(int backlog) {
        throw new InternalError("should not get here");
    }

    /** Not supported: this implementation only makes outgoing connections. */
    @Override
    protected void accept(SocketImpl s) {
        throw new InternalError("should not get here");
    }

    /**
     * Returns the destination address requested through the SOCKS server once
     * it is known, otherwise the delegate's remote address.
     */
    @Override
    protected InetAddress getInetAddress() {
        if (external_address != null)
            return external_address.getAddress();
        else
            return delegate.getInetAddress();
    }

    /**
     * Returns the destination port requested through the SOCKS server once it
     * is known, otherwise the delegate's remote port.
     */
    @Override
    protected int getPort() {
        if (external_address != null)
            return external_address.getPort();
        else
            return delegate.getPort();
    }

    /** Closes the delegate, which carries the control connection. */
    @Override
    protected void close() throws IOException {
        delegate.close();
    }

    /** Returns the user name sent in SOCKS V4 requests. */
    private String getUserName() {
        return StaticProperty.userName();
    }

    /** Not supported by this implementation. */
    @Override
    void reset() {
        throw new InternalError("should not get here");
    }
}