Improve performance of bignum[beg, len] (#14007)

Implement rb_big_aref2.
Taking a small slice from large bignum was slow in rb_int_aref2.
This commit is contained in:
tomoya ishida 2025-07-30 01:34:13 +09:00 committed by GitHub
parent 46d106f7ab
commit a66e4f2154
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 131 additions and 18 deletions

View file

@ -6757,6 +6757,73 @@ rb_big_aref(VALUE x, VALUE y)
return (xds[s1] & bit) ? INT2FIX(1) : INT2FIX(0);
}
VALUE
rb_big_aref2(VALUE x, VALUE beg, VALUE len)
{
BDIGIT *xds, *vds;
VALUE v;
size_t copy_begin, xn, shift;
ssize_t begin, length, end;
bool negative_add_one;
beg = rb_to_int(beg);
len = rb_to_int(len);
length = NUM2SSIZET(len);
begin = NUM2SSIZET(beg);
end = NUM2SSIZET(rb_int_plus(beg, len));
shift = begin < 0 ? -begin : 0;
xn = BIGNUM_LEN(x);
xds = BDIGITS(x);
if (length < 0) return rb_big_rshift(x, beg);
if (length == 0 || end <= 0) return INT2FIX(0);
if (begin < 0) begin = 0;
if ((size_t)(end - 1) / BITSPERDIG >= xn) {
/* end > xn * BITSPERDIG */
end = xn * BITSPERDIG;
}
if ((size_t)begin / BITSPERDIG < xn) {
/* begin < xn * BITSPERDIG */
size_t shift_bits, copy_end;
copy_begin = begin / BITSPERDIG;
shift_bits = begin % BITSPERDIG;
copy_end = (end - 1) / BITSPERDIG + 1;
v = bignew(copy_end - copy_begin, 1);
vds = BDIGITS(v);
MEMCPY(vds, xds + copy_begin, BDIGIT, copy_end - copy_begin);
negative_add_one = (vds[0] & ((1 << shift_bits) - 1)) == 0;
v = bignorm(v);
if (shift_bits) v = rb_int_rshift(v, SIZET2NUM(shift_bits));
}
else {
/* Out of range */
v = INT2FIX(0);
negative_add_one = false;
copy_begin = begin = end = 0;
}
if (BIGNUM_NEGATIVE_P(x)) {
size_t mask_size = length - shift;
VALUE mask = rb_int_minus(rb_int_lshift(INT2FIX(1), SIZET2NUM(mask_size)), INT2FIX(1));
v = rb_int_xor(v, mask);
for (size_t i = 0; negative_add_one && i < copy_begin; i++) {
if (xds[i]) negative_add_one = false;
}
if (negative_add_one) v = rb_int_plus(v, INT2FIX(1));
v = rb_int_and(v, mask);
}
else {
size_t mask_size = (size_t)end - begin;
VALUE mask = rb_int_minus(rb_int_lshift(INT2FIX(1), SIZET2NUM(mask_size)), INT2FIX(1));
v = rb_int_and(v, mask);
}
RB_GC_GUARD(x);
if (shift) v = rb_int_lshift(v, SSIZET2NUM(shift));
return v;
}
VALUE
rb_big_hash(VALUE x)
{

View file

@ -121,6 +121,7 @@ VALUE rb_integer_float_eq(VALUE x, VALUE y);
VALUE rb_str_convert_to_inum(VALUE str, int base, int badcheck, int raise_exception);
VALUE rb_big_comp(VALUE x);
VALUE rb_big_aref(VALUE x, VALUE y);
VALUE rb_big_aref2(VALUE num, VALUE beg, VALUE len);
VALUE rb_big_abs(VALUE x);
VALUE rb_big_size_m(VALUE big);
VALUE rb_big_bit_length(VALUE big);

View file

@ -85,6 +85,7 @@ VALUE rb_int_cmp(VALUE x, VALUE y);
VALUE rb_int_equal(VALUE x, VALUE y);
VALUE rb_int_divmod(VALUE x, VALUE y);
VALUE rb_int_and(VALUE x, VALUE y);
VALUE rb_int_xor(VALUE x, VALUE y);
VALUE rb_int_lshift(VALUE x, VALUE y);
VALUE rb_int_rshift(VALUE x, VALUE y);
VALUE rb_int_div(VALUE x, VALUE y);

View file

@ -5115,8 +5115,8 @@ fix_xor(VALUE x, VALUE y)
*
*/
static VALUE
int_xor(VALUE x, VALUE y)
VALUE
rb_int_xor(VALUE x, VALUE y)
{
if (FIXNUM_P(x)) {
return fix_xor(x, y);
@ -5288,10 +5288,23 @@ generate_mask(VALUE len)
return rb_int_minus(rb_int_lshift(INT2FIX(1), len), INT2FIX(1));
}
static VALUE
int_aref2(VALUE num, VALUE beg, VALUE len)
{
if (RB_TYPE_P(num, T_BIGNUM)) {
return rb_big_aref2(num, beg, len);
}
else {
num = rb_int_rshift(num, beg);
VALUE mask = generate_mask(len);
return rb_int_and(num, mask);
}
}
static VALUE
int_aref1(VALUE num, VALUE arg)
{
VALUE orig_num = num, beg, end;
VALUE beg, end;
int excl;
if (rb_range_values(arg, &beg, &end, &excl)) {
@ -5311,22 +5324,19 @@ int_aref1(VALUE num, VALUE arg)
return INT2FIX(0);
}
}
num = rb_int_rshift(num, beg);
int cmp = compare_indexes(beg, end);
if (!NIL_P(end) && cmp < 0) {
VALUE len = rb_int_minus(end, beg);
if (!excl) len = rb_int_plus(len, INT2FIX(1));
VALUE mask = generate_mask(len);
num = rb_int_and(num, mask);
return int_aref2(num, beg, len);
}
else if (cmp == 0) {
if (excl) return INT2FIX(0);
num = orig_num;
arg = beg;
goto one_bit;
}
return num;
return rb_int_rshift(num, beg);
}
one_bit:
@ -5339,15 +5349,6 @@ one_bit:
return Qnil;
}
static VALUE
int_aref2(VALUE num, VALUE beg, VALUE len)
{
num = rb_int_rshift(num, beg);
VALUE mask = generate_mask(len);
num = rb_int_and(num, mask);
return num;
}
/*
* call-seq:
* self[offset] -> 0 or 1
@ -6366,7 +6367,7 @@ Init_Numeric(void)
rb_define_method(rb_cInteger, "&", rb_int_and, 1);
rb_define_method(rb_cInteger, "|", int_or, 1);
rb_define_method(rb_cInteger, "^", int_xor, 1);
rb_define_method(rb_cInteger, "^", rb_int_xor, 1);
rb_define_method(rb_cInteger, "[]", int_aref, -1);
rb_define_method(rb_cInteger, "<<", rb_int_lshift, 1);

View file

@ -605,6 +605,49 @@ class TestBignum < Test::Unit::TestCase
assert_equal(1, (-2**(BIGNUM_MIN_BITS*4))[BIGNUM_MIN_BITS*4])
end
def test_aref2
x = (0x123456789abcdef << (BIGNUM_MIN_BITS + 32)) | 0x12345678
assert_equal(x, x[0, x.bit_length])
assert_equal(x >> 10, x[10, x.bit_length])
assert_equal(0x45678, x[0, 20])
assert_equal(0x6780, x[-4, 16])
assert_equal(0x123456, x[x.bit_length - 21, 40])
assert_equal(0x6789ab, x[x.bit_length - 41, 24])
assert_equal(0, x[-20, 10])
assert_equal(0, x[x.bit_length + 10, 10])
assert_equal(0, x[5, 0])
assert_equal(0, (-x)[5, 0])
assert_equal(x >> 5, x[5, -1])
assert_equal(x << 5, x[-5, -1])
assert_equal((-x) >> 5, (-x)[5, -1])
assert_equal((-x) << 5, (-x)[-5, -1])
assert_equal(x << 5, x[-5, FIXNUM_MAX])
assert_equal(x >> 5, x[5, FIXNUM_MAX])
assert_equal(0, x[FIXNUM_MIN, 100])
assert_equal(0, (-x)[FIXNUM_MIN, 100])
y = (x << 160) | 0x1234_0000_0000_0000_1234_0000_0000_0000
assert_equal(0xffffedcc00, (-y)[40, 40])
assert_equal(0xfffffffedc, (-y)[52, 40])
assert_equal(0xffffedcbff, (-y)[104, 40])
assert_equal(0xfffff6e5d4, (-y)[y.bit_length - 20, 40])
assert_equal(0, (-y)[-20, 10])
assert_equal(0xfff, (-y)[y.bit_length + 10, 12])
z = (1 << (BIGNUM_MIN_BITS * 2)) - 1
assert_equal(0x400, (-z)[-10, 20])
assert_equal(1, (-z)[0, 20])
assert_equal(0, (-z)[10, 20])
assert_equal(1, (-z)[0, z.bit_length])
assert_equal(0, (-z)[z.bit_length - 10, 10])
assert_equal(0x400, (-z)[z.bit_length - 10, 11])
assert_equal(0xfff, (-z)[z.bit_length, 12])
assert_equal(0xfff00, (-z)[z.bit_length - 8, 20])
end
def test_hash
assert_nothing_raised { T31P.hash }
end