1
0
Fork 0
mirror of https://github.com/lua/lua.git synced 2025-07-24 04:32:35 +00:00

Small optimization in 'project' from math.random

When computing the Mersenne number, instead of spreading 1's a fixed
number of times (with shifts of 1, 2, 4, 8, 16, and 32), spread only
until the number becomes a Mersenne number.
This commit is contained in:
Roberto Ierusalimschy 2025-03-27 15:22:40 -03:00
commit 37a1b72706

View file

@ -533,7 +533,7 @@ typedef struct {
** Project the random integer 'ran' into the interval [0, n]. ** Project the random integer 'ran' into the interval [0, n].
** Because 'ran' has 2^B possible values, the projection can only be ** Because 'ran' has 2^B possible values, the projection can only be
** uniform when the size of the interval is a power of 2 (exact ** uniform when the size of the interval is a power of 2 (exact
** division). Otherwise, to get a uniform projection into [0, n], we ** division). So, to get a uniform projection into [0, n], we
** first compute 'lim', the smallest Mersenne number not smaller than ** first compute 'lim', the smallest Mersenne number not smaller than
** 'n'. We then project 'ran' into the interval [0, lim]. If the result ** 'n'. We then project 'ran' into the interval [0, lim]. If the result
** is inside [0, n], we are done. Otherwise, we try with another 'ran', ** is inside [0, n], we are done. Otherwise, we try with another 'ran',
@ -541,26 +541,14 @@ typedef struct {
*/ */
static lua_Unsigned project (lua_Unsigned ran, lua_Unsigned n, static lua_Unsigned project (lua_Unsigned ran, lua_Unsigned n,
RanState *state) { RanState *state) {
if ((n & (n + 1)) == 0) /* is 'n + 1' a power of 2? */ lua_Unsigned lim = n; /* to compute the Mersenne number */
return ran & n; /* no bias */ int sh; /* how much to spread bits to the right in 'lim' */
else { /* spread '1' bits in 'lim' until it becomes a Mersenne number */
lua_Unsigned lim = n; for (sh = 1; (lim & (lim + 1)) != 0; sh *= 2)
/* compute the smallest (2^b - 1) not smaller than 'n' */ lim |= (lim >> sh); /* spread '1's to the right */
lim |= (lim >> 1); while ((ran &= lim) > n) /* project 'ran' into [0..lim] and test */
lim |= (lim >> 2); ran = I2UInt(nextrand(state->s)); /* not inside [0..n]? try again */
lim |= (lim >> 4); return ran;
lim |= (lim >> 8);
lim |= (lim >> 16);
#if (LUA_MAXUNSIGNED >> 31) >= 3
lim |= (lim >> 32); /* integer type has more than 32 bits */
#endif
lua_assert((lim & (lim + 1)) == 0 /* 'lim + 1' is a power of 2, */
&& lim >= n /* not smaller than 'n', */
&& (lim >> 1) < n); /* and it is the smallest one */
while ((ran &= lim) > n) /* project 'ran' into [0..lim] */
ran = I2UInt(nextrand(state->s)); /* not inside [0..n]? try again */
return ran;
}
} }