package org.jruby.ext.securerandom;
import org.jruby.RubyBignum;
import org.jruby.RubyFixnum;
import org.jruby.RubyFloat;
import org.jruby.RubyInteger;
import org.jruby.RubyRange;
import org.jruby.RubyString;
import org.jruby.anno.JRubyMethod;
import org.jruby.anno.JRubyModule;
import org.jruby.runtime.ThreadContext;
import org.jruby.runtime.builtin.IRubyObject;
import org.jruby.util.ConvertBytes;
import java.math.BigInteger;
import java.security.SecureRandom;
/**
 * Native implementation of the Ruby {@code SecureRandom} module's module-level methods.
 *
 * <p>All randomness is drawn from the {@link SecureRandom} returned by
 * {@code ThreadContext#getSecureRandom()}.
 */
@JRubyModule(name = "SecureRandom")
public class RubySecureRandom {

    /** Number of random bytes used when no size (or {@code nil}) is given. */
    private static final int DEFAULT_BYTE_COUNT = 16;

    /** {@code SecureRandom.random_bytes}: returns a String of 16 random bytes. */
    @JRubyMethod(meta = true, name = "random_bytes")
    public static IRubyObject random_bytes(ThreadContext context, IRubyObject self) {
        return RubyString.newStringNoCopy(context.runtime, nextBytes(context, DEFAULT_BYTE_COUNT));
    }

    /**
     * {@code SecureRandom.random_bytes(n)}: returns a String of {@code n} random bytes.
     * A {@code nil} size means 16 bytes.
     */
    @JRubyMethod(meta = true, name = "random_bytes")
    public static IRubyObject random_bytes(ThreadContext context, IRubyObject self, IRubyObject n) {
        return RubyString.newStringNoCopy(context.runtime, nextBytes(context, n));
    }

    /** {@code SecureRandom.gen_random(n)} / {@code SecureRandom.bytes(n)}: same as {@code random_bytes(n)}. */
    @JRubyMethod(meta = true, name = {"gen_random", "bytes"})
    public static IRubyObject gen_random(ThreadContext context, IRubyObject self, IRubyObject n) {
        return random_bytes(context, self, n);
    }

    /** {@code SecureRandom.hex}: hex-encodes 16 random bytes via {@code ConvertBytes.twosComplementToHexBytes}. */
    @JRubyMethod(meta = true)
    public static IRubyObject hex(ThreadContext context, IRubyObject self) {
        return RubyString.newStringNoCopy(context.runtime,
                ConvertBytes.twosComplementToHexBytes(nextBytes(context, DEFAULT_BYTE_COUNT), false));
    }

    /** {@code SecureRandom.hex(n)}: hex-encodes {@code n} random bytes ({@code nil} means 16). */
    @JRubyMethod(meta = true)
    public static IRubyObject hex(ThreadContext context, IRubyObject self, IRubyObject n) {
        return RubyString.newStringNoCopy(context.runtime,
                ConvertBytes.twosComplementToHexBytes(nextBytes(context, n), false));
    }

    /**
     * {@code SecureRandom.uuid}: formats 16 random bytes with {@code ConvertBytes.bytesToUUIDBytes}.
     * NOTE(review): version/variant bits are presumably set by that helper — confirm in ConvertBytes.
     */
    @JRubyMethod(meta = true)
    public static IRubyObject uuid(ThreadContext context, IRubyObject self) {
        return RubyString.newStringNoCopy(context.runtime,
                ConvertBytes.bytesToUUIDBytes(nextBytes(context, DEFAULT_BYTE_COUNT), false));
    }

    /**
     * Converts a Ruby size argument to a byte count and returns that many random bytes.
     *
     * @param n {@code nil} (meaning 16) or an object convertible to Integer
     * @throws org.jruby.exceptions.RaiseException ArgumentError when negative, RangeError when the
     *         size does not fit in a Java {@code int}
     */
    private static byte[] nextBytes(ThreadContext context, IRubyObject n) {
        if (n.isNil()) return nextBytes(context, DEFAULT_BYTE_COUNT);

        final RubyInteger size = n.convertToInteger();
        // Narrowing with getIntValue() would silently truncate large sizes, so range-check first.
        if (!(size instanceof RubyFixnum)) {
            throw context.runtime.newRangeError("bignum too big to convert into `int'");
        }
        final long value = ((RubyFixnum) size).getLongValue();
        if (value < 0) throw context.runtime.newArgumentError("negative argument: " + value);
        if (value > Integer.MAX_VALUE) {
            throw context.runtime.newRangeError("integer " + value + " too big to convert to `int'");
        }
        return nextBytes(context, (int) value);
    }

    /**
     * Returns {@code size} bytes filled by the context's {@link SecureRandom}.
     *
     * @throws org.jruby.exceptions.RaiseException ArgumentError when {@code size} is negative
     */
    private static byte[] nextBytes(ThreadContext context, int size) {
        if (size < 0) throw context.runtime.newArgumentError("negative argument: " + size);
        final byte[] bytes = new byte[size];
        getSecureRandom(context).nextBytes(bytes);
        return bytes;
    }

    /** {@code SecureRandom.random_number} / {@code rand}: a Float from {@link SecureRandom#nextDouble()}, in [0, 1). */
    @JRubyMethod(meta = true, name = {"random_number", "rand"})
    public static IRubyObject random_number(ThreadContext context, IRubyObject self) {
        return randomDouble(context);
    }

    /**
     * {@code SecureRandom.random_number(n)} / {@code rand(n)}.
     *
     * <ul>
     *   <li>{@code nil}, or a Fixnum/Float/Bignum that is zero or negative: a Float in [0, 1).</li>
     *   <li>Positive Integer {@code n}: an Integer in {@code [0, n)}.</li>
     *   <li>Positive Float {@code n}: a Float in {@code [0, n)}.</li>
     *   <li>Range with Integer bounds or Float bounds: a value inside the range. An empty range
     *       yields a Float in [0, 1).</li>
     * </ul>
     *
     * @throws org.jruby.exceptions.RaiseException ArgumentError for any other argument
     */
    @JRubyMethod(meta = true, name = {"random_number", "rand"})
    public static IRubyObject random_number(ThreadContext context, IRubyObject self, IRubyObject n) {
        if (n.isNil()) return randomDouble(context);

        if (n instanceof RubyFixnum) {
            final long bound = ((RubyFixnum) n).getLongValue();
            // 0 is not a usable bound: fall back to a float like negative bounds do.
            return bound <= 0 ? randomDouble(context) : randomInteger(context, 0, bound - 1);
        }
        if (n instanceof RubyFloat) {
            final double bound = ((RubyFloat) n).getDoubleValue();
            // nextDown makes the upper bound genuinely exclusive (subtracting MIN_VALUE is a no-op).
            return bound <= 0 ? randomDouble(context) : randomDouble(context, 0, Math.nextDown(bound));
        }
        if (n instanceof RubyBignum) {
            final BigInteger bound = ((RubyBignum) n).getBigIntegerValue();
            return bound.signum() <= 0 ? randomDouble(context) : randomBignum(context, BigInteger.ZERO, bound);
        }
        if (n instanceof RubyRange) {
            final RubyRange range = (RubyRange) n;
            final IRubyObject beg = range.begin(context);
            final IRubyObject end = range.end(context);
            final boolean exclude = range.isExcludeEnd();

            if (beg instanceof RubyFixnum && end instanceof RubyFixnum) {
                final long lower = ((RubyFixnum) beg).getLongValue();
                long upper = ((RubyFixnum) end).getLongValue();
                // Empty ranges (including a...a) have no integer to return.
                if (lower > upper || (exclude && lower == upper)) return randomDouble(context);
                if (exclude) upper--; // safe: upper > lower >= Long.MIN_VALUE here
                return randomInteger(context, lower, upper);
            }
            if (beg instanceof RubyInteger && end instanceof RubyInteger) {
                final BigInteger lower = ((RubyInteger) beg).getBigIntegerValue();
                final BigInteger upper = ((RubyInteger) end).getBigIntegerValue();
                final BigInteger upperExc = exclude ? upper : upper.add(BigInteger.ONE);
                // An empty range would otherwise lead to BigInteger.mod(0).
                if (upperExc.compareTo(lower) <= 0) return randomDouble(context);
                return randomBignum(context, lower, upperExc);
            }
            if (beg instanceof RubyFloat && end instanceof RubyFloat) {
                final double lower = ((RubyFloat) beg).getDoubleValue();
                double upper = ((RubyFloat) end).getDoubleValue();
                if (lower > upper || (exclude && lower == upper)) return randomDouble(context);
                if (exclude) upper = Math.nextDown(upper);
                return randomDouble(context, lower, upper);
            }
        }
        throw context.runtime.newArgumentError("invalid argument - " + n.anyToString());
    }

    /**
     * Returns a uniformly distributed Integer in {@code [lower, upper]} (both inclusive).
     * The caller must ensure {@code lower <= upper}.
     */
    private static RubyInteger randomInteger(final ThreadContext context, final long lower, final long upper) {
        final long span = upper - lower + 1;
        if (span <= 0) {
            // The range has more than Long.MAX_VALUE values, so the long subtraction overflowed.
            return randomBignum(context, BigInteger.valueOf(lower), BigInteger.valueOf(upper).add(BigInteger.ONE));
        }
        return context.runtime.newFixnum(lower + nextLong(getSecureRandom(context), span));
    }

    /** Returns an unbiased long in {@code [0, bound)} using rejection sampling. Requires {@code bound > 0}. */
    private static long nextLong(final SecureRandom random, final long bound) {
        final long mask = bound - 1;
        if ((bound & mask) == 0) return random.nextLong() & mask; // power of two: mask the bits directly
        long bits, val;
        do {
            bits = random.nextLong() >>> 1;  // uniform in [0, 2^63)
            val = bits % bound;
        } while (bits - val + mask < 0);     // overflow means bits fell in the incomplete final block
        return val;
    }

    /**
     * Returns a uniformly distributed Integer in {@code [lower, upperExc)}, normalized to a Fixnum
     * when it fits. Requires {@code upperExc > lower}.
     */
    private static RubyInteger randomBignum(final ThreadContext context, final BigInteger lower, final BigInteger upperExc) {
        final BigInteger bound = upperExc.subtract(lower);
        final BigInteger offset = nextBigInteger(getSecureRandom(context), bound, bound.bitLength());
        return RubyBignum.bignorm(context.runtime, lower.add(offset));
    }

    /** Extra random bits drawn beyond {@code bound.bitLength()} so that rejections are very rare. */
    private static final int BI_ADD_BITS = 96;

    /** Returns an unbiased BigInteger in {@code [0, bound)} using rejection sampling. Requires {@code bound > 0}. */
    private static BigInteger nextBigInteger(final SecureRandom random, final BigInteger bound, final int bits) {
        final int sampleBits = bits + BI_ADD_BITS;
        while (true) {
            final BigInteger val = new BigInteger(sampleBits, random);
            final BigInteger rnd = val.mod(bound);
            // Accept only when val's whole block of `bound` values lies below 2^sampleBits.
            if (val.add(bound).subtract(rnd).subtract(BigInteger.ONE).bitLength() < sampleBits) {
                return rnd;
            }
        }
    }

    /** Returns a Float interpolated between {@code lower} and {@code upper} by {@link SecureRandom#nextDouble()}. */
    private static RubyFloat randomDouble(final ThreadContext context, final double lower, final double upper) {
        final double rnd = getSecureRandom(context).nextDouble();
        return context.runtime.newFloat(rnd * upper + (1.0 - rnd) * lower);
    }

    /** Returns a Float from {@link SecureRandom#nextDouble()}, which is in [0, 1). */
    private static RubyFloat randomDouble(final ThreadContext context) {
        return context.runtime.newFloat(getSecureRandom(context).nextDouble());
    }

    /** Returns the {@link SecureRandom} held by the given thread context. */
    private static SecureRandom getSecureRandom(ThreadContext context) {
        return context.getSecureRandom();
    }
}