package sun.nio.ch;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodHandles;
import java.lang.invoke.MethodHandles.Lookup;
import java.lang.invoke.MethodType;
import java.lang.invoke.VarHandle;
import java.net.DatagramPacket;
import java.net.DatagramSocket;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.NetworkInterface;
import java.net.MulticastSocket;
import java.net.SocketAddress;
import java.net.SocketException;
import java.net.SocketOption;
import java.net.StandardSocketOptions;
import java.nio.ByteBuffer;
import java.nio.channels.AlreadyConnectedException;
import java.nio.channels.ClosedChannelException;
import java.nio.channels.DatagramChannel;
import java.nio.channels.MembershipKey;
import java.security.AccessController;
import java.security.PrivilegedAction;
import java.security.PrivilegedExceptionAction;
import java.util.Objects;
import java.util.Set;
import java.util.concurrent.locks.ReentrantLock;
import static java.util.concurrent.TimeUnit.MILLISECONDS;
public class DatagramSocketAdaptor
extends MulticastSocket
{
private final DatagramChannelImpl dc;
private volatile int timeout;
private DatagramSocketAdaptor(DatagramChannelImpl dc) throws IOException {
super( DatagramSockets.NO_DELEGATE);
this.dc = dc;
}
static DatagramSocket create(DatagramChannelImpl dc) {
try {
return new DatagramSocketAdaptor(dc);
} catch (IOException e) {
throw new Error(e);
}
}
private void connectInternal(SocketAddress remote) throws SocketException {
try {
dc.connect(remote, false);
} catch (ClosedChannelException e) {
} catch (Exception x) {
Net.translateToSocketException(x);
}
}
@Override
public void bind(SocketAddress local) throws SocketException {
if (local != null) {
local = Net.asInetSocketAddress(local);
} else {
local = new InetSocketAddress(0);
}
try {
dc.bind(local);
} catch (Exception x) {
Net.translateToSocketException(x);
}
}
@Override
public void connect(InetAddress address, int port) {
if (address == null)
throw new IllegalArgumentException("Address can't be null");
try {
connectInternal(new InetSocketAddress(address, port));
} catch (SocketException x) {
throw new UncheckedIOException(x);
}
}
@Override
public void connect(SocketAddress remote) throws SocketException {
if (remote == null)
throw new IllegalArgumentException("Address can't be null");
connectInternal(Net.asInetSocketAddress(remote));
}
@Override
public void disconnect() {
try {
dc.disconnect();
} catch (IOException x) {
throw new UncheckedIOException(x);
}
}
@Override
public boolean isBound() {
return dc.localAddress() != null;
}
@Override
public boolean isConnected() {
return dc.remoteAddress() != null;
}
@Override
public InetAddress getInetAddress() {
InetSocketAddress remote = dc.remoteAddress();
return (remote != null) ? remote.getAddress() : null;
}
@Override
public int getPort() {
InetSocketAddress remote = dc.remoteAddress();
return (remote != null) ? remote.getPort() : -1;
}
@Override
public SocketAddress getRemoteSocketAddress() {
return dc.remoteAddress();
}
@Override
public SocketAddress getLocalSocketAddress() {
InetSocketAddress local = dc.localAddress();
if (local == null || isClosed())
return null;
InetAddress addr = local.getAddress();
if (addr.isAnyLocalAddress())
return local;
SecurityManager sm = System.getSecurityManager();
if (sm != null) {
try {
sm.checkConnect(addr.getHostAddress(), -1);
} catch (SecurityException x) {
return new InetSocketAddress(local.getPort());
}
}
return local;
}
@Override
public void send(DatagramPacket p) throws IOException {
ByteBuffer bb = null;
try {
InetSocketAddress target;
synchronized (p) {
int len = p.getLength();
bb = Util.getTemporaryDirectBuffer(len);
bb.put(p.getData(), p.getOffset(), len);
bb.flip();
if (p.getAddress() == null) {
InetSocketAddress remote = dc.remoteAddress();
if (remote == null) {
throw new IllegalArgumentException("Address not set");
}
p.setAddress(remote.getAddress());
p.setPort(remote.getPort());
target = remote;
} else {
target = (InetSocketAddress) p.getSocketAddress();
}
}
try {
dc.blockingSend(bb, target);
} catch (AlreadyConnectedException e) {
throw new IllegalArgumentException("Connected and packet address differ");
} catch (ClosedChannelException e) {
var exc = new SocketException("Socket closed");
exc.initCause(e);
throw exc;
}
} finally {
if (bb != null) {
Util.offerFirstTemporaryDirectBuffer(bb);
}
}
}
@Override
public void receive(DatagramPacket p) throws IOException {
int bufLength = DatagramPackets.getBufLength(p);
ByteBuffer bb = Util.getTemporaryDirectBuffer(bufLength);
try {
long nanos = MILLISECONDS.toNanos(timeout);
SocketAddress sender = dc.blockingReceive(bb, nanos);
bb.flip();
synchronized (p) {
int len = Math.min(bb.limit(), DatagramPackets.getBufLength(p));
bb.get(p.getData(), p.getOffset(), len);
DatagramPackets.setLength(p, len);
p.setSocketAddress(sender);
}
} catch (ClosedChannelException e) {
var exc = new SocketException("Socket closed");
exc.initCause(e);
throw exc;
} finally {
Util.offerFirstTemporaryDirectBuffer(bb);
}
}
@Override
public InetAddress getLocalAddress() {
if (isClosed())
return null;
InetSocketAddress local = dc.localAddress();
if (local == null)
local = new InetSocketAddress(0);
InetAddress result = local.getAddress();
SecurityManager sm = System.getSecurityManager();
if (sm != null) {
try {
sm.checkConnect(result.getHostAddress(), -1);
} catch (SecurityException x) {
return new InetSocketAddress(0).getAddress();
}
}
return result;
}
@Override
public int getLocalPort() {
if (isClosed())
return -1;
InetSocketAddress local = dc.localAddress();
if (local != null) {
return local.getPort();
}
return 0;
}
@Override
public void setSoTimeout(int timeout) throws SocketException {
if (isClosed())
throw new SocketException("Socket is closed");
if (timeout < 0)
throw new IllegalArgumentException("timeout < 0");
this.timeout = timeout;
}
@Override
public int getSoTimeout() throws SocketException {
if (isClosed())
throw new SocketException("Socket is closed");
return timeout;
}
private void setBooleanOption(SocketOption<Boolean> name, boolean value)
throws SocketException
{
try {
dc.setOption(name, value);
} catch (IOException x) {
Net.translateToSocketException(x);
}
}
private void setIntOption(SocketOption<Integer> name, int value)
throws SocketException
{
try {
dc.setOption(name, value);
} catch (IOException x) {
Net.translateToSocketException(x);
}
}
private boolean getBooleanOption(SocketOption<Boolean> name) throws SocketException {
try {
return dc.getOption(name).booleanValue();
} catch (IOException x) {
Net.translateToSocketException(x);
return false;
}
}
private int getIntOption(SocketOption<Integer> name) throws SocketException {
try {
return dc.getOption(name).intValue();
} catch (IOException x) {
Net.translateToSocketException(x);
return -1;
}
}
@Override
public void setSendBufferSize(int size) throws SocketException {
if (size <= 0)
throw new IllegalArgumentException("Invalid send size");
setIntOption(StandardSocketOptions.SO_SNDBUF, size);
}
@Override
public int getSendBufferSize() throws SocketException {
return getIntOption(StandardSocketOptions.SO_SNDBUF);
}
@Override
public void setReceiveBufferSize(int size) throws SocketException {
if (size <= 0)
throw new IllegalArgumentException("Invalid receive size");
setIntOption(StandardSocketOptions.SO_RCVBUF, size);
}
@Override
public int getReceiveBufferSize() throws SocketException {
return getIntOption(StandardSocketOptions.SO_RCVBUF);
}
@Override
public void setReuseAddress(boolean on) throws SocketException {
setBooleanOption(StandardSocketOptions.SO_REUSEADDR, on);
}
@Override
public boolean getReuseAddress() throws SocketException {
return getBooleanOption(StandardSocketOptions.SO_REUSEADDR);
}
@Override
public void setBroadcast(boolean on) throws SocketException {
setBooleanOption(StandardSocketOptions.SO_BROADCAST, on);
}
@Override
public boolean getBroadcast() throws SocketException {
return getBooleanOption(StandardSocketOptions.SO_BROADCAST);
}
@Override
public void setTrafficClass(int tc) throws SocketException {
setIntOption(StandardSocketOptions.IP_TOS, tc);
}
@Override
public int getTrafficClass() throws SocketException {
return getIntOption(StandardSocketOptions.IP_TOS);
}
@Override
public void close() {
try {
dc.close();
} catch (IOException x) {
throw new Error(x);
}
}
@Override
public boolean isClosed() {
return !dc.isOpen();
}
@Override
public DatagramChannel getChannel() {
return dc;
}
@Override
public <T> DatagramSocket setOption(SocketOption<T> name, T value) throws IOException {
dc.setOption(name, value);
return this;
}
@Override
public <T> T getOption(SocketOption<T> name) throws IOException {
return dc.getOption(name);
}
@Override
public Set<SocketOption<?>> supportedOptions() {
return dc.supportedOptions();
}
private final ReentrantLock sendLock = new ReentrantLock();
private final Object outgoingInterfaceLock = new Object();
private NetworkInterface outgoingNetworkInterface;
private InetAddress outgoingInetAddress;
@Override
@Deprecated
public void setTTL(byte ttl) throws IOException {
setTimeToLive(Byte.toUnsignedInt(ttl));
}
@Override
public void setTimeToLive(int ttl) throws IOException {
sendLock.lock();
try {
setIntOption(StandardSocketOptions.IP_MULTICAST_TTL, ttl);
} finally {
sendLock.unlock();
}
}
@Override
@Deprecated
public byte getTTL() throws IOException {
return (byte) getTimeToLive();
}
@Override
public int getTimeToLive() throws IOException {
sendLock.lock();
try {
return getIntOption(StandardSocketOptions.IP_MULTICAST_TTL);
} finally {
sendLock.unlock();
}
}
@Override
@Deprecated
public void joinGroup(InetAddress group) throws IOException {
Objects.requireNonNull(group);
try {
joinGroup(new InetSocketAddress(group, 0), null);
} catch (IllegalArgumentException iae) {
throw (SocketException) new SocketException("joinGroup failed").initCause(iae);
}
}
@Override
@Deprecated
public void leaveGroup(InetAddress group) throws IOException {
Objects.requireNonNull(group);
try {
leaveGroup(new InetSocketAddress(group, 0), null);
} catch (IllegalArgumentException iae) {
throw (SocketException) new SocketException("leaveGroup failed").initCause(iae);
}
}
private static InetAddress checkGroup(SocketAddress mcastaddr) throws SocketException {
if (mcastaddr == null || !(mcastaddr instanceof InetSocketAddress))
throw new IllegalArgumentException("Unsupported address type");
InetAddress group = ((InetSocketAddress) mcastaddr).getAddress();
if (group == null)
throw new IllegalArgumentException("Unresolved address");
if (!group.isMulticastAddress())
throw new SocketException("Not a multicast address");
return group;
}
@Override
public void joinGroup(SocketAddress mcastaddr, NetworkInterface netIf) throws IOException {
InetAddress group = checkGroup(mcastaddr);
NetworkInterface ni = (netIf != null) ? netIf : defaultNetworkInterface();
if (isClosed())
throw new SocketException("Socket is closed");
synchronized (this) {
MembershipKey key = dc.findMembership(group, ni);
if (key != null) {
SecurityManager sm = System.getSecurityManager();
if (sm != null)
sm.checkMulticast(group);
throw new SocketException("Already a member of group");
}
dc.join(group, ni);
}
}
@Override
public void leaveGroup(SocketAddress mcastaddr, NetworkInterface netIf) throws IOException {
InetAddress group = checkGroup(mcastaddr);
NetworkInterface ni = (netIf != null) ? netIf : defaultNetworkInterface();
if (isClosed())
throw new SocketException("Socket is closed");
SecurityManager sm = System.getSecurityManager();
if (sm != null)
sm.checkMulticast(group);
synchronized (this) {
MembershipKey key = dc.findMembership(group, ni);
if (key == null)
throw new SocketException("Not a member of group");
key.drop();
}
}
@Override
@Deprecated
public void setInterface(InetAddress inf) throws SocketException {
if (inf == null)
throw new SocketException("Invalid value 'null'");
NetworkInterface ni = NetworkInterface.getByInetAddress(inf);
if (ni == null) {
String address = inf.getHostAddress();
throw new SocketException("No network interface with address " + address);
}
synchronized (outgoingInterfaceLock) {
setNetworkInterface(ni);
outgoingNetworkInterface = ni;
outgoingInetAddress = inf;
}
}
@Override
@Deprecated
public InetAddress getInterface() throws SocketException {
synchronized (outgoingInterfaceLock) {
NetworkInterface ni = outgoingNetworkInterface();
if (ni != null) {
if (ni.equals(outgoingNetworkInterface)) {
return outgoingInetAddress;
} else {
PrivilegedAction<InetAddress> pa;
pa = () -> ni.inetAddresses().findFirst().orElse(null);
InetAddress ia = AccessController.doPrivileged(pa);
if (ia == null)
throw new SocketException("Network interface has no IP address");
outgoingNetworkInterface = ni;
outgoingInetAddress = ia;
return ia;
}
}
}
return anyInetAddress();
}
@Override
public void setNetworkInterface(NetworkInterface netIf) throws SocketException {
try {
setOption(StandardSocketOptions.IP_MULTICAST_IF, netIf);
} catch (IOException e) {
Net.translateToSocketException(e);
}
}
@Override
public NetworkInterface getNetworkInterface() throws SocketException {
NetworkInterface ni = outgoingNetworkInterface();
if (ni == null) {
ni = anyNetworkInterface();
}
return ni;
}
@Override
@Deprecated
public void setLoopbackMode(boolean disable) throws SocketException {
boolean enable = !disable;
setBooleanOption(StandardSocketOptions.IP_MULTICAST_LOOP, enable);
}
@Override
@Deprecated
public boolean getLoopbackMode() throws SocketException {
boolean enabled = getBooleanOption(StandardSocketOptions.IP_MULTICAST_LOOP);
return !enabled;
}
@Override
@Deprecated
public void send(DatagramPacket p, byte ttl) throws IOException {
sendLock.lock();
try {
int oldValue = getTimeToLive();
try {
setTTL(ttl);
send(p);
} finally {
setTimeToLive(oldValue);
}
} finally {
sendLock.unlock();
}
}
private NetworkInterface outgoingNetworkInterface() throws SocketException {
try {
return getOption(StandardSocketOptions.IP_MULTICAST_IF);
} catch (IOException e) {
Net.translateToSocketException(e);
return null;
}
}
private NetworkInterface defaultNetworkInterface() throws SocketException {
NetworkInterface ni = outgoingNetworkInterface();
if (ni == null)
ni = NetworkInterfaces.getDefault();
if (ni == null)
ni = anyNetworkInterface();
return ni;
}
private NetworkInterface anyNetworkInterface() {
InetAddress[] addrs = new InetAddress[1];
addrs[0] = anyInetAddress();
return NetworkInterfaces.newNetworkInterface(addrs[0].getHostName(), 0, addrs);
}
private InetAddress anyInetAddress() {
return new InetSocketAddress(0).getAddress();
}
private static class DatagramPackets {
private static final VarHandle LENGTH;
private static final VarHandle BUF_LENGTH;
static {
try {
PrivilegedExceptionAction<Lookup> pa = () ->
MethodHandles.privateLookupIn(DatagramPacket.class, MethodHandles.lookup());
MethodHandles.Lookup l = AccessController.doPrivileged(pa);
LENGTH = l.findVarHandle(DatagramPacket.class, "length", int.class);
BUF_LENGTH = l.findVarHandle(DatagramPacket.class, "bufLength", int.class);
} catch (Exception e) {
throw new ExceptionInInitializerError(e);
}
}
static void setLength(DatagramPacket p, int value) {
synchronized (p) {
LENGTH.set(p, value);
}
}
static int getBufLength(DatagramPacket p) {
synchronized (p) {
return (int) BUF_LENGTH.get(p);
}
}
}
private static class NetworkInterfaces {
static final MethodHandle GET_DEFAULT;
static final MethodHandle CONSTRUCTOR;
static {
try {
PrivilegedExceptionAction<Lookup> pa = () ->
MethodHandles.privateLookupIn(NetworkInterface.class, MethodHandles.lookup());
MethodHandles.Lookup l = AccessController.doPrivileged(pa);
MethodType methodType = MethodType.methodType(NetworkInterface.class);
GET_DEFAULT = l.findStatic(NetworkInterface.class, "getDefault", methodType);
methodType = MethodType.methodType(void.class, String.class, int.class, InetAddress[].class);
CONSTRUCTOR = l.findConstructor(NetworkInterface.class, methodType);
} catch (Exception e) {
throw new ExceptionInInitializerError(e);
}
}
static NetworkInterface getDefault() {
try {
return (NetworkInterface) GET_DEFAULT.invokeExact();
} catch (Throwable e) {
throw new InternalError(e);
}
}
static NetworkInterface newNetworkInterface(String name, int index, InetAddress[] addrs) {
try {
return (NetworkInterface) CONSTRUCTOR.invoke(name, index, addrs);
} catch (Throwable e) {
throw new InternalError(e);
}
}
}
private static class DatagramSockets {
private static final SocketAddress NO_DELEGATE;
static {
try {
PrivilegedExceptionAction<Lookup> pa = () ->
MethodHandles.privateLookupIn(DatagramSocket.class, MethodHandles.lookup());
MethodHandles.Lookup l = AccessController.doPrivileged(pa);
NO_DELEGATE = (SocketAddress)
l.findStaticVarHandle(DatagramSocket.class, "NO_DELEGATE",
SocketAddress.class).get();
} catch (Exception e) {
throw new ExceptionInInitializerError(e);
}
}
}
}