[Bug #21449] Fix Set#divide{|a,b|} using Union-find structure (#13680)

* [Bug #21449] Fix Set#divide{|a,b|} using Union-find structure

Implements Union-find structure with path compression.
Since divide{|a,b|} calls the given block n**2 times in the worst case, there is no need to implement union-by-rank or union-by-size optimization.

* Avoid internal arrays from being modified from block passed to Set#divide

Internal arrays can be modified from yielded block through ObjectSpace.
Freeze readonly array, use ALLOCV_N instead of mutable array.
This commit is contained in:
tomoya ishida 2025-06-24 02:56:04 +09:00 committed by GitHub
parent db6f397987
commit 67346a7d94
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 64 additions and 66 deletions

124
set.c
View file

@ -843,66 +843,72 @@ set_i_classify(VALUE set)
return args[0]; return args[0];
} }
struct set_divide_args { // Union-find with path compression
VALUE self; static long
VALUE set_class; set_divide_union_find_root(long *uf_parents, long index, long *tmp_array)
VALUE final_set;
VALUE hash;
VALUE current_set;
VALUE current_item;
unsigned long ni;
unsigned long nj;
};
static VALUE
set_divide_block0(RB_BLOCK_CALL_FUNC_ARGLIST(j, arg))
{ {
struct set_divide_args *args = (struct set_divide_args *)arg; long root = uf_parents[index];
if (args->nj > args->ni) { long update_size = 0;
VALUE i = args->current_item; while (root != index) {
if (RTEST(rb_yield_values(2, i, j)) && RTEST(rb_yield_values(2, j, i))) { tmp_array[update_size++] = index;
VALUE hash = args->hash; index = root;
if (args->current_set == Qnil) { root = uf_parents[index];
VALUE set = rb_hash_aref(hash, j);
if (set == Qnil) {
VALUE both[2] = {i, j};
set = set_s_create(2, both, args->set_class);
rb_hash_aset(hash, i, set);
rb_hash_aset(hash, j, set);
set_i_add(args->final_set, set);
} }
else { for (long j = 0; j < update_size; j++) {
set_i_add(set, i); long idx = tmp_array[j];
rb_hash_aset(hash, i, set); uf_parents[idx] = root;
} }
args->current_set = set; return root;
} }
else {
set_i_add(args->current_set, j); static void
rb_hash_aset(hash, j, args->current_set); set_divide_union_find_merge(long *uf_parents, long i, long j, long *tmp_array)
} {
} long root_i = set_divide_union_find_root(uf_parents, i, tmp_array);
} long root_j = set_divide_union_find_root(uf_parents, j, tmp_array);
args->nj++; if (root_i != root_j) uf_parents[root_j] = root_i;
return j;
} }
static VALUE static VALUE
set_divide_block(RB_BLOCK_CALL_FUNC_ARGLIST(i, arg)) set_divide_arity2(VALUE set)
{ {
struct set_divide_args *args = (struct set_divide_args *)arg; VALUE tmp, uf;
VALUE hash = args->hash; long size, *uf_parents, *tmp_array;
args->current_set = rb_hash_aref(hash, i); VALUE set_class = rb_obj_class(set);
args->current_item = i; VALUE items = set_i_to_a(set);
args->nj = 0; rb_ary_freeze(items);
rb_block_call(args->self, id_each, 0, 0, set_divide_block0, arg); size = RARRAY_LEN(items);
if (args->current_set == Qnil) { tmp_array = ALLOCV_N(long, tmp, size);
VALUE set = set_s_create(1, &i, args->set_class); uf_parents = ALLOCV_N(long, uf, size);
rb_hash_aset(hash, i, set); for (long i = 0; i < size; i++) {
set_i_add(args->final_set, set); uf_parents[i] = i;
} }
args->ni++; for (long i = 0; i < size - 1; i++) {
return i; VALUE item1 = RARRAY_AREF(items, i);
for (long j = i + 1; j < size; j++) {
VALUE item2 = RARRAY_AREF(items, j);
if (RTEST(rb_yield_values(2, item1, item2)) &&
RTEST(rb_yield_values(2, item2, item1))) {
set_divide_union_find_merge(uf_parents, i, j, tmp_array);
}
}
}
VALUE final_set = set_s_create(0, 0, rb_cSet);
VALUE hash = rb_hash_new();
for (long i = 0; i < size; i++) {
VALUE v = RARRAY_AREF(items, i);
long root = set_divide_union_find_root(uf_parents, i, tmp_array);
VALUE set = rb_hash_aref(hash, LONG2FIX(root));
if (set == Qnil) {
set = set_s_create(0, 0, set_class);
rb_hash_aset(hash, LONG2FIX(root), set);
set_i_add(final_set, set);
}
set_i_add(set, v);
}
ALLOCV_END(tmp);
ALLOCV_END(uf);
return final_set;
} }
static void set_merge_enum_into(VALUE set, VALUE arg); static void set_merge_enum_into(VALUE set, VALUE arg);
@ -936,19 +942,7 @@ set_i_divide(VALUE set)
RETURN_SIZED_ENUMERATOR(set, 0, 0, set_enum_size); RETURN_SIZED_ENUMERATOR(set, 0, 0, set_enum_size);
if (rb_block_arity() == 2) { if (rb_block_arity() == 2) {
VALUE final_set = set_s_create(0, 0, rb_cSet); return set_divide_arity2(set);
struct set_divide_args args = {
.self = set,
.set_class = rb_obj_class(set),
.final_set = final_set,
.hash = rb_hash_new(),
.current_set = 0,
.current_item = 0,
.ni = 0,
.nj = 0
};
rb_block_call(set, id_each, 0, 0, set_divide_block, (VALUE)&args);
return final_set;
} }
VALUE values = rb_hash_values(set_i_classify(set)); VALUE values = rb_hash_values(set_i_classify(set));

View file

@ -781,6 +781,10 @@ class TC_Set < Test::Unit::TestCase
ret.each { |s| n += s.size } ret.each { |s| n += s.size }
assert_equal(set.size, n) assert_equal(set.size, n)
assert_equal(set, ret.flatten) assert_equal(set, ret.flatten)
set = Set[2,12,9,11,13,4,10,15,3,8,5,0,1,7,14]
ret = set.divide { |a,b| (a - b).abs == 1 }
assert_equal(2, ret.size)
end end
def test_freeze def test_freeze