[Bug #21449] Fix Set#divide{|a,b|} using Union-find structure (#13680)

* [Bug #21449] Fix Set#divide{|a,b|} using Union-find structure

Implements Union-find structure with path compression.
Since divide{|a,b|} calls the given block n**2 times in the worst case, there is no need to implement union-by-rank or union-by-size optimization.

* Avoid internal arrays from being modified from block passed to Set#divide

Internal arrays can be modified from yielded block through ObjectSpace.
Freeze readonly array, use ALLOCV_N instead of mutable array.
This commit is contained in:
tomoya ishida 2025-06-24 02:56:04 +09:00 committed by GitHub
parent db6f397987
commit 67346a7d94
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 64 additions and 66 deletions

126
set.c
View file

@ -843,66 +843,72 @@ set_i_classify(VALUE set)
return args[0];
}
struct set_divide_args {
VALUE self;
VALUE set_class;
VALUE final_set;
VALUE hash;
VALUE current_set;
VALUE current_item;
unsigned long ni;
unsigned long nj;
};
static VALUE
set_divide_block0(RB_BLOCK_CALL_FUNC_ARGLIST(j, arg))
// Union-find with path compression
static long
set_divide_union_find_root(long *uf_parents, long index, long *tmp_array)
{
struct set_divide_args *args = (struct set_divide_args *)arg;
if (args->nj > args->ni) {
VALUE i = args->current_item;
if (RTEST(rb_yield_values(2, i, j)) && RTEST(rb_yield_values(2, j, i))) {
VALUE hash = args->hash;
if (args->current_set == Qnil) {
VALUE set = rb_hash_aref(hash, j);
if (set == Qnil) {
VALUE both[2] = {i, j};
set = set_s_create(2, both, args->set_class);
rb_hash_aset(hash, i, set);
rb_hash_aset(hash, j, set);
set_i_add(args->final_set, set);
}
else {
set_i_add(set, i);
rb_hash_aset(hash, i, set);
}
args->current_set = set;
}
else {
set_i_add(args->current_set, j);
rb_hash_aset(hash, j, args->current_set);
}
}
long root = uf_parents[index];
long update_size = 0;
while (root != index) {
tmp_array[update_size++] = index;
index = root;
root = uf_parents[index];
}
args->nj++;
return j;
for (long j = 0; j < update_size; j++) {
long idx = tmp_array[j];
uf_parents[idx] = root;
}
return root;
}
static void
set_divide_union_find_merge(long *uf_parents, long i, long j, long *tmp_array)
{
long root_i = set_divide_union_find_root(uf_parents, i, tmp_array);
long root_j = set_divide_union_find_root(uf_parents, j, tmp_array);
if (root_i != root_j) uf_parents[root_j] = root_i;
}
static VALUE
set_divide_block(RB_BLOCK_CALL_FUNC_ARGLIST(i, arg))
set_divide_arity2(VALUE set)
{
struct set_divide_args *args = (struct set_divide_args *)arg;
VALUE hash = args->hash;
args->current_set = rb_hash_aref(hash, i);
args->current_item = i;
args->nj = 0;
rb_block_call(args->self, id_each, 0, 0, set_divide_block0, arg);
if (args->current_set == Qnil) {
VALUE set = set_s_create(1, &i, args->set_class);
rb_hash_aset(hash, i, set);
set_i_add(args->final_set, set);
VALUE tmp, uf;
long size, *uf_parents, *tmp_array;
VALUE set_class = rb_obj_class(set);
VALUE items = set_i_to_a(set);
rb_ary_freeze(items);
size = RARRAY_LEN(items);
tmp_array = ALLOCV_N(long, tmp, size);
uf_parents = ALLOCV_N(long, uf, size);
for (long i = 0; i < size; i++) {
uf_parents[i] = i;
}
args->ni++;
return i;
for (long i = 0; i < size - 1; i++) {
VALUE item1 = RARRAY_AREF(items, i);
for (long j = i + 1; j < size; j++) {
VALUE item2 = RARRAY_AREF(items, j);
if (RTEST(rb_yield_values(2, item1, item2)) &&
RTEST(rb_yield_values(2, item2, item1))) {
set_divide_union_find_merge(uf_parents, i, j, tmp_array);
}
}
}
VALUE final_set = set_s_create(0, 0, rb_cSet);
VALUE hash = rb_hash_new();
for (long i = 0; i < size; i++) {
VALUE v = RARRAY_AREF(items, i);
long root = set_divide_union_find_root(uf_parents, i, tmp_array);
VALUE set = rb_hash_aref(hash, LONG2FIX(root));
if (set == Qnil) {
set = set_s_create(0, 0, set_class);
rb_hash_aset(hash, LONG2FIX(root), set);
set_i_add(final_set, set);
}
set_i_add(set, v);
}
ALLOCV_END(tmp);
ALLOCV_END(uf);
return final_set;
}
static void set_merge_enum_into(VALUE set, VALUE arg);
@ -936,19 +942,7 @@ set_i_divide(VALUE set)
RETURN_SIZED_ENUMERATOR(set, 0, 0, set_enum_size);
if (rb_block_arity() == 2) {
VALUE final_set = set_s_create(0, 0, rb_cSet);
struct set_divide_args args = {
.self = set,
.set_class = rb_obj_class(set),
.final_set = final_set,
.hash = rb_hash_new(),
.current_set = 0,
.current_item = 0,
.ni = 0,
.nj = 0
};
rb_block_call(set, id_each, 0, 0, set_divide_block, (VALUE)&args);
return final_set;
return set_divide_arity2(set);
}
VALUE values = rb_hash_values(set_i_classify(set));