ZJIT: Support invalidating constant patch points (#13998)

This commit is contained in:
Stan Lo 2025-07-28 22:48:41 +01:00 committed by GitHub
parent 23000e7123
commit a0d0b84bad
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 108 additions and 3 deletions

View file

@ -889,6 +889,48 @@ class TestZJIT < Test::Unit::TestCase
end end
end end
def test_constant_invalidation
assert_compiles '123', <<~RUBY, call_threshold: 2, insns: [:opt_getconstant_path]
class C; end
def test = C
test
test
C = 123
test
RUBY
end
def test_constant_path_invalidation
assert_compiles '["Foo::C", "Foo::C", "Bar::C"]', <<~RUBY, call_threshold: 2, insns: [:opt_getconstant_path]
module A
module B; end
end
module Foo
C = "Foo::C"
end
module Bar
C = "Bar::C"
end
A::B = Foo
def test = A::B::C
result = []
result << test
result << test
A::B = Bar
result << test
result
RUBY
end
def test_dupn def test_dupn
assert_compiles '[[1], [1, 1], :rhs, [nil, :rhs]]', <<~RUBY, insns: [:dupn] assert_compiles '[[1], [1, 1], :rhs, [nil, :rhs]]', <<~RUBY, insns: [:dupn]
def test(array) = (array[1, 2] ||= :rhs) def test(array) = (array[1, 2] ||= :rhs)

View file

@ -148,6 +148,7 @@ rb_clear_constant_cache_for_id(ID id)
} }
rb_yjit_constant_state_changed(id); rb_yjit_constant_state_changed(id);
rb_zjit_constant_state_changed(id);
} }
static void static void

2
zjit.h
View file

@ -14,6 +14,7 @@ void rb_zjit_profile_enable(const rb_iseq_t *iseq);
void rb_zjit_bop_redefined(int redefined_flag, enum ruby_basic_operators bop); void rb_zjit_bop_redefined(int redefined_flag, enum ruby_basic_operators bop);
void rb_zjit_cme_invalidate(const rb_callable_method_entry_t *cme); void rb_zjit_cme_invalidate(const rb_callable_method_entry_t *cme);
void rb_zjit_invalidate_ep_is_bp(const rb_iseq_t *iseq); void rb_zjit_invalidate_ep_is_bp(const rb_iseq_t *iseq);
void rb_zjit_constant_state_changed(ID id);
void rb_zjit_iseq_mark(void *payload); void rb_zjit_iseq_mark(void *payload);
void rb_zjit_iseq_update_references(void *payload); void rb_zjit_iseq_update_references(void *payload);
#else #else
@ -24,6 +25,7 @@ static inline void rb_zjit_profile_enable(const rb_iseq_t *iseq) {}
static inline void rb_zjit_bop_redefined(int redefined_flag, enum ruby_basic_operators bop) {} static inline void rb_zjit_bop_redefined(int redefined_flag, enum ruby_basic_operators bop) {}
static inline void rb_zjit_cme_invalidate(const rb_callable_method_entry_t *cme) {} static inline void rb_zjit_cme_invalidate(const rb_callable_method_entry_t *cme) {}
static inline void rb_zjit_invalidate_ep_is_bp(const rb_iseq_t *iseq) {} static inline void rb_zjit_invalidate_ep_is_bp(const rb_iseq_t *iseq) {}
static inline void rb_zjit_constant_state_changed(ID id) {}
#endif // #if USE_YJIT #endif // #if USE_YJIT
#endif // #ifndef ZJIT_H #endif // #ifndef ZJIT_H

View file

@ -4,7 +4,7 @@ use std::ffi::{c_int};
use crate::asm::Label; use crate::asm::Label;
use crate::backend::current::{Reg, ALLOC_REGS}; use crate::backend::current::{Reg, ALLOC_REGS};
use crate::invariants::{track_bop_assumption, track_cme_assumption}; use crate::invariants::{track_bop_assumption, track_cme_assumption, track_stable_constant_names_assumption};
use crate::gc::{get_or_create_iseq_payload, append_gc_offsets}; use crate::gc::{get_or_create_iseq_payload, append_gc_offsets};
use crate::state::ZJITState; use crate::state::ZJITState;
use crate::{asm::CodeBlock, cruby::*, options::debug, virtualmem::CodePtr}; use crate::{asm::CodeBlock, cruby::*, options::debug, virtualmem::CodePtr};
@ -505,6 +505,10 @@ fn gen_patch_point(jit: &mut JITState, asm: &mut Assembler, invariant: &Invarian
let side_exit_ptr = cb.resolve_label(label); let side_exit_ptr = cb.resolve_label(label);
track_cme_assumption(cme, code_ptr, side_exit_ptr); track_cme_assumption(cme, code_ptr, side_exit_ptr);
} }
Invariant::StableConstantNames { idlist } => {
let side_exit_ptr = cb.resolve_label(label);
track_stable_constant_names_assumption(idlist, code_ptr, side_exit_ptr);
}
_ => { _ => {
debug!("ZJIT: gen_patch_point: unimplemented invariant {invariant:?}"); debug!("ZJIT: gen_patch_point: unimplemented invariant {invariant:?}");
return; return;

View file

@ -259,7 +259,7 @@ pub struct VALUE(pub usize);
/// An interned string. See [ids] and methods this type. /// An interned string. See [ids] and methods this type.
/// `0` is a sentinal value for IDs. /// `0` is a sentinal value for IDs.
#[repr(transparent)] #[repr(transparent)]
#[derive(Clone, Copy, PartialEq, Eq, Debug)] #[derive(Clone, Copy, PartialEq, Eq, Debug, Hash)]
pub struct ID(pub ::std::os::raw::c_ulong); pub struct ID(pub ::std::os::raw::c_ulong);
/// Pointer to an ISEQ /// Pointer to an ISEQ

View file

@ -1,6 +1,6 @@
use std::{collections::{HashMap, HashSet}}; use std::{collections::{HashMap, HashSet}};
use crate::{backend::lir::{asm_comment, Assembler}, cruby::{rb_callable_method_entry_t, ruby_basic_operators, src_loc, with_vm_lock, IseqPtr, RedefinitionFlag}, hir::Invariant, options::debug, state::{zjit_enabled_p, ZJITState}, virtualmem::CodePtr}; use crate::{backend::lir::{asm_comment, Assembler}, cruby::{rb_callable_method_entry_t, ruby_basic_operators, src_loc, with_vm_lock, IseqPtr, RedefinitionFlag, ID}, hir::Invariant, options::debug, state::{zjit_enabled_p, ZJITState}, virtualmem::CodePtr};
#[derive(Debug, Eq, Hash, PartialEq)] #[derive(Debug, Eq, Hash, PartialEq)]
struct Jump { struct Jump {
@ -23,6 +23,9 @@ pub struct Invariants {
/// Map from CME to patch points that assume the method hasn't been redefined /// Map from CME to patch points that assume the method hasn't been redefined
cme_patch_points: HashMap<*const rb_callable_method_entry_t, HashSet<Jump>>, cme_patch_points: HashMap<*const rb_callable_method_entry_t, HashSet<Jump>>,
/// Map from constant ID to patch points that assume the constant hasn't been redefined
constant_state_patch_points: HashMap<ID, HashSet<Jump>>,
} }
/// Called when a basic operator is redefined. Note that all the blocks assuming /// Called when a basic operator is redefined. Note that all the blocks assuming
@ -116,6 +119,30 @@ pub fn track_cme_assumption(
}); });
} }
/// Track a patch point for each constant name in a constant path assumption.
pub fn track_stable_constant_names_assumption(
idlist: *const ID,
patch_point_ptr: CodePtr,
side_exit_ptr: CodePtr
) {
let invariants = ZJITState::get_invariants();
let mut idx = 0;
loop {
let id = unsafe { *idlist.wrapping_add(idx) };
if id.0 == 0 {
break;
}
invariants.constant_state_patch_points.entry(id).or_default().insert(Jump {
from: patch_point_ptr,
to: side_exit_ptr,
});
idx += 1;
}
}
/// Called when a method is redefined. Invalidates all JIT code that depends on the CME. /// Called when a method is redefined. Invalidates all JIT code that depends on the CME.
#[unsafe(no_mangle)] #[unsafe(no_mangle)]
pub extern "C" fn rb_zjit_cme_invalidate(cme: *const rb_callable_method_entry_t) { pub extern "C" fn rb_zjit_cme_invalidate(cme: *const rb_callable_method_entry_t) {
@ -144,3 +171,32 @@ pub extern "C" fn rb_zjit_cme_invalidate(cme: *const rb_callable_method_entry_t)
} }
}); });
} }
/// Called when a constant is redefined. Invalidates all JIT code that depends on the constant.
#[unsafe(no_mangle)]
pub extern "C" fn rb_zjit_constant_state_changed(id: ID) {
// If ZJIT isn't enabled, do nothing
if !zjit_enabled_p() {
return;
}
with_vm_lock(src_loc!(), || {
let invariants = ZJITState::get_invariants();
if let Some(jumps) = invariants.constant_state_patch_points.get(&id) {
let cb = ZJITState::get_code_block();
debug!("Constant state changed: {:?}", id);
// Invalidate all patch points for this constant ID
for jump in jumps {
cb.with_write_ptr(jump.from, |cb| {
let mut asm = Assembler::new();
asm_comment!(asm, "Constant state changed: {:?}", id);
asm.jmp(jump.to.into());
asm.compile(cb).expect("can write existing code");
});
}
cb.mark_all_executable();
}
});
}