ZJIT: Support invalidating on method redefinition (#13875)

ZJIT: Support invalidating method redefinition

This commit adds support for the MethodRedefined invariant to be invalidated
when a method is redefined.

Changes:
- Added CME pointer to the MethodRedefined invariant in HIR
- Updated all places where MethodRedefined invariants are created to
    include the CME pointer
- Added handling for MethodRedefined invariants in gen_patch_point to
    call track_cme_assumption, which registers the patch point for
    invalidation when rb_zjit_cme_invalidate is called

This ensures that when a method is redefined, all JIT code that
depends on that method will be properly invalidated.
This commit is contained in:
Stan Lo 2025-07-18 16:36:51 +01:00 committed by GitHub
parent dafc4e131e
commit 8df61bfc92
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 168 additions and 69 deletions

View file

@ -1066,6 +1066,31 @@ class TestZJIT < Test::Unit::TestCase
}, call_threshold: 2
end
# ZJIT currently only generates a MethodRedefined patch point when the method
# is called on the top-level self.
def test_method_redefinition_with_top_self
assert_runs '["original", "redefined"]', %q{
def foo
"original"
end
def test = foo
test; test
result1 = test
# Redefine the method
def foo
"redefined"
end
result2 = test
[result1, result2]
}, call_threshold: 2
end
def test_module_name_with_guard_passes
assert_compiles '"Integer"', %q{
def test(mod)

View file

@ -122,6 +122,7 @@ vm_cme_invalidate(rb_callable_method_entry_t *cme)
RB_DEBUG_COUNTER_INC(cc_cme_invalidate);
rb_yjit_cme_invalidate(cme);
rb_zjit_cme_invalidate(cme);
}
static int

2
zjit.h
View file

@ -12,6 +12,7 @@ void rb_zjit_compile_iseq(const rb_iseq_t *iseq, rb_execution_context_t *ec, boo
void rb_zjit_profile_insn(enum ruby_vminsn_type insn, rb_execution_context_t *ec);
void rb_zjit_profile_enable(const rb_iseq_t *iseq);
void rb_zjit_bop_redefined(int redefined_flag, enum ruby_basic_operators bop);
void rb_zjit_cme_invalidate(const rb_callable_method_entry_t *cme);
void rb_zjit_invalidate_ep_is_bp(const rb_iseq_t *iseq);
void rb_zjit_iseq_mark(void *payload);
void rb_zjit_iseq_update_references(void *payload);
@ -21,6 +22,7 @@ static inline void rb_zjit_compile_iseq(const rb_iseq_t *iseq, rb_execution_cont
static inline void rb_zjit_profile_insn(enum ruby_vminsn_type insn, rb_execution_context_t *ec) {}
static inline void rb_zjit_profile_enable(const rb_iseq_t *iseq) {}
static inline void rb_zjit_bop_redefined(int redefined_flag, enum ruby_basic_operators bop) {}
static inline void rb_zjit_cme_invalidate(const rb_callable_method_entry_t *cme) {}
static inline void rb_zjit_invalidate_ep_is_bp(const rb_iseq_t *iseq) {}
#endif // #if USE_YJIT

View file

@ -3,7 +3,7 @@ use std::rc::Rc;
use crate::asm::Label;
use crate::backend::current::{Reg, ALLOC_REGS};
use crate::invariants::track_bop_assumption;
use crate::invariants::{track_bop_assumption, track_cme_assumption};
use crate::gc::{get_or_create_iseq_payload, append_gc_offsets};
use crate::state::ZJITState;
use crate::{asm::CodeBlock, cruby::*, options::debug, virtualmem::CodePtr};
@ -494,6 +494,10 @@ fn gen_patch_point(jit: &mut JITState, asm: &mut Assembler, invariant: &Invarian
let side_exit_ptr = cb.resolve_label(label);
track_bop_assumption(klass, bop, code_ptr, side_exit_ptr);
}
Invariant::MethodRedefined { klass: _, method: _, cme } => {
let side_exit_ptr = cb.resolve_label(label);
track_cme_assumption(cme, code_ptr, side_exit_ptr);
}
_ => {
debug!("ZJIT: gen_patch_point: unimplemented invariant {invariant:?}");
return;

View file

@ -127,6 +127,8 @@ pub enum Invariant {
klass: VALUE,
/// The method ID of the method we want to assume unchanged
method: ID,
/// The callable method entry that we want to track
cme: *const rb_callable_method_entry_t,
},
/// A list of constant expression path segments that must have not been written to for the
/// following code to be valid.
@ -222,12 +224,13 @@ impl<'a> std::fmt::Display for InvariantPrinter<'a> {
}
write!(f, ")")
}
Invariant::MethodRedefined { klass, method } => {
Invariant::MethodRedefined { klass, method, cme } => {
let class_name = get_class_name(klass);
write!(f, "MethodRedefined({class_name}@{:p}, {}@{:p})",
write!(f, "MethodRedefined({class_name}@{:p}, {}@{:p}, cme:{:p})",
self.ptr_map.map_ptr(klass.as_ptr::<VALUE>()),
method.contents_lossy(),
self.ptr_map.map_id(method.0)
self.ptr_map.map_id(method.0),
self.ptr_map.map_ptr(cme)
)
}
Invariant::StableConstantNames { idlist } => {
@ -1537,7 +1540,7 @@ impl Function {
if !can_direct_send(iseq) {
self.push_insn_id(block, insn_id); continue;
}
self.push_insn(block, Insn::PatchPoint { invariant: Invariant::MethodRedefined { klass, method: mid }, state });
self.push_insn(block, Insn::PatchPoint { invariant: Invariant::MethodRedefined { klass, method: mid, cme }, state });
if let Some(expected) = guard_equal_to {
self_val = self.push_insn(block, Insn::GuardBitEquals { val: self_val, expected, state });
}
@ -1656,7 +1659,7 @@ impl Function {
// Filter for simple call sites (i.e. no splats etc.)
if ci_flags & VM_CALL_ARGS_SIMPLE != 0 {
// Commit to the replacement. Put PatchPoint.
fun.push_insn(block, Insn::PatchPoint { invariant: Invariant::MethodRedefined { klass: recv_class, method: method_id }, state });
fun.push_insn(block, Insn::PatchPoint { invariant: Invariant::MethodRedefined { klass: recv_class, method: method_id, cme: method }, state });
if let Some(guard_type) = guard_type {
// Guard receiver class
self_val = fun.push_insn(block, Insn::GuardType { val: self_val, guard_type, state });
@ -5514,9 +5517,9 @@ mod opt_tests {
assert_optimized_method_hir("test", expect![[r#"
fn test@<compiled>:5:
bb0(v0:BasicObject):
PatchPoint MethodRedefined(Object@0x1000, foo@0x1008)
v6:BasicObject[VALUE(0x1010)] = GuardBitEquals v0, VALUE(0x1010)
v7:BasicObject = SendWithoutBlockDirect v6, :foo (0x1018)
PatchPoint MethodRedefined(Object@0x1000, foo@0x1008, cme:0x1010)
v6:BasicObject[VALUE(0x1038)] = GuardBitEquals v0, VALUE(0x1038)
v7:BasicObject = SendWithoutBlockDirect v6, :foo (0x1040)
Return v7
"#]]);
}
@ -5554,9 +5557,9 @@ mod opt_tests {
assert_optimized_method_hir("test", expect![[r#"
fn test@<compiled>:6:
bb0(v0:BasicObject):
PatchPoint MethodRedefined(Object@0x1000, foo@0x1008)
v6:BasicObject[VALUE(0x1010)] = GuardBitEquals v0, VALUE(0x1010)
v7:BasicObject = SendWithoutBlockDirect v6, :foo (0x1018)
PatchPoint MethodRedefined(Object@0x1000, foo@0x1008, cme:0x1010)
v6:BasicObject[VALUE(0x1038)] = GuardBitEquals v0, VALUE(0x1038)
v7:BasicObject = SendWithoutBlockDirect v6, :foo (0x1040)
Return v7
"#]]);
}
@ -5573,9 +5576,9 @@ mod opt_tests {
fn test@<compiled>:3:
bb0(v0:BasicObject):
v2:Fixnum[3] = Const Value(3)
PatchPoint MethodRedefined(Object@0x1000, Integer@0x1008)
v7:BasicObject[VALUE(0x1010)] = GuardBitEquals v0, VALUE(0x1010)
v8:BasicObject = SendWithoutBlockDirect v7, :Integer (0x1018), v2
PatchPoint MethodRedefined(Object@0x1000, Integer@0x1008, cme:0x1010)
v7:BasicObject[VALUE(0x1038)] = GuardBitEquals v0, VALUE(0x1038)
v8:BasicObject = SendWithoutBlockDirect v7, :Integer (0x1040), v2
Return v8
"#]]);
}
@ -5595,9 +5598,9 @@ mod opt_tests {
bb0(v0:BasicObject):
v2:Fixnum[1] = Const Value(1)
v3:Fixnum[2] = Const Value(2)
PatchPoint MethodRedefined(Object@0x1000, foo@0x1008)
v8:BasicObject[VALUE(0x1010)] = GuardBitEquals v0, VALUE(0x1010)
v9:BasicObject = SendWithoutBlockDirect v8, :foo (0x1018), v2, v3
PatchPoint MethodRedefined(Object@0x1000, foo@0x1008, cme:0x1010)
v8:BasicObject[VALUE(0x1038)] = GuardBitEquals v0, VALUE(0x1038)
v9:BasicObject = SendWithoutBlockDirect v8, :foo (0x1040), v2, v3
Return v9
"#]]);
}
@ -5618,12 +5621,12 @@ mod opt_tests {
assert_optimized_method_hir("test", expect![[r#"
fn test@<compiled>:7:
bb0(v0:BasicObject):
PatchPoint MethodRedefined(Object@0x1000, foo@0x1008)
v8:BasicObject[VALUE(0x1010)] = GuardBitEquals v0, VALUE(0x1010)
v9:BasicObject = SendWithoutBlockDirect v8, :foo (0x1018)
PatchPoint MethodRedefined(Object@0x1000, bar@0x1020)
v11:BasicObject[VALUE(0x1010)] = GuardBitEquals v0, VALUE(0x1010)
v12:BasicObject = SendWithoutBlockDirect v11, :bar (0x1018)
PatchPoint MethodRedefined(Object@0x1000, foo@0x1008, cme:0x1010)
v8:BasicObject[VALUE(0x1038)] = GuardBitEquals v0, VALUE(0x1038)
v9:BasicObject = SendWithoutBlockDirect v8, :foo (0x1040)
PatchPoint MethodRedefined(Object@0x1000, bar@0x1048, cme:0x1050)
v11:BasicObject[VALUE(0x1038)] = GuardBitEquals v0, VALUE(0x1038)
v12:BasicObject = SendWithoutBlockDirect v11, :bar (0x1040)
Return v12
"#]]);
}
@ -6149,9 +6152,9 @@ mod opt_tests {
assert_optimized_method_hir("test", expect![[r#"
fn test@<compiled>:2:
bb0(v0:BasicObject, v1:BasicObject):
PatchPoint MethodRedefined(Integer@0x1000, itself@0x1008)
PatchPoint MethodRedefined(Integer@0x1000, itself@0x1008, cme:0x1010)
v7:Fixnum = GuardType v1, Fixnum
v8:BasicObject = CCall itself@0x1010, v7
v8:BasicObject = CCall itself@0x1038, v7
Return v8
"#]]);
}
@ -6165,8 +6168,8 @@ mod opt_tests {
fn test@<compiled>:2:
bb0(v0:BasicObject):
v3:ArrayExact = NewArray
PatchPoint MethodRedefined(Array@0x1000, itself@0x1008)
v8:BasicObject = CCall itself@0x1010, v3
PatchPoint MethodRedefined(Array@0x1000, itself@0x1008, cme:0x1010)
v8:BasicObject = CCall itself@0x1038, v3
Return v8
"#]]);
}
@ -6184,7 +6187,7 @@ mod opt_tests {
bb0(v0:BasicObject):
v1:NilClassExact = Const Value(nil)
v4:ArrayExact = NewArray
PatchPoint MethodRedefined(Array@0x1000, itself@0x1008)
PatchPoint MethodRedefined(Array@0x1000, itself@0x1008, cme:0x1010)
v7:Fixnum[1] = Const Value(1)
Return v7
"#]]);
@ -6207,7 +6210,7 @@ mod opt_tests {
PatchPoint SingleRactorMode
PatchPoint StableConstantNames(0x1000, M)
v11:ModuleExact[VALUE(0x1008)] = Const Value(VALUE(0x1008))
PatchPoint MethodRedefined(Module@0x1010, name@0x1018)
PatchPoint MethodRedefined(Module@0x1010, name@0x1018, cme:0x1020)
v7:Fixnum[1] = Const Value(1)
Return v7
"#]]);
@ -6226,7 +6229,7 @@ mod opt_tests {
bb0(v0:BasicObject):
v1:NilClassExact = Const Value(nil)
v4:ArrayExact = NewArray
PatchPoint MethodRedefined(Array@0x1000, length@0x1008)
PatchPoint MethodRedefined(Array@0x1000, length@0x1008, cme:0x1010)
v7:Fixnum[5] = Const Value(5)
Return v7
"#]]);
@ -6326,7 +6329,7 @@ mod opt_tests {
bb0(v0:BasicObject):
v1:NilClassExact = Const Value(nil)
v4:ArrayExact = NewArray
PatchPoint MethodRedefined(Array@0x1000, size@0x1008)
PatchPoint MethodRedefined(Array@0x1000, size@0x1008, cme:0x1010)
v7:Fixnum[5] = Const Value(5)
Return v7
"#]]);
@ -6359,8 +6362,8 @@ mod opt_tests {
fn test@<compiled>:2:
bb0(v0:BasicObject, v1:BasicObject):
v3:Fixnum[1] = Const Value(1)
PatchPoint MethodRedefined(Integer@0x1000, zero?@0x1008)
v8:BasicObject = SendWithoutBlockDirect v3, :zero? (0x1010)
PatchPoint MethodRedefined(Integer@0x1000, zero?@0x1008, cme:0x1010)
v8:BasicObject = SendWithoutBlockDirect v3, :zero? (0x1038)
Return v8
"#]]);
}
@ -6379,8 +6382,8 @@ mod opt_tests {
v2:NilClassExact = Const Value(nil)
v4:ArrayExact[VALUE(0x1000)] = Const Value(VALUE(0x1000))
v6:ArrayExact = ArrayDup v4
PatchPoint MethodRedefined(Array@0x1008, first@0x1010)
v11:BasicObject = SendWithoutBlockDirect v6, :first (0x1018)
PatchPoint MethodRedefined(Array@0x1008, first@0x1010, cme:0x1018)
v11:BasicObject = SendWithoutBlockDirect v6, :first (0x1040)
Return v11
"#]]);
}
@ -6399,8 +6402,8 @@ mod opt_tests {
PatchPoint SingleRactorMode
PatchPoint StableConstantNames(0x1000, M)
v9:ModuleExact[VALUE(0x1008)] = Const Value(VALUE(0x1008))
PatchPoint MethodRedefined(Module@0x1010, class@0x1018)
v11:BasicObject = SendWithoutBlockDirect v9, :class (0x1020)
PatchPoint MethodRedefined(Module@0x1010, class@0x1018, cme:0x1020)
v11:BasicObject = SendWithoutBlockDirect v9, :class (0x1048)
Return v11
"#]]);
}
@ -6499,8 +6502,8 @@ mod opt_tests {
bb0(v0:BasicObject):
v2:StringExact[VALUE(0x1000)] = Const Value(VALUE(0x1000))
v3:StringExact = StringCopy v2
PatchPoint MethodRedefined(String@0x1008, bytesize@0x1010)
v8:Fixnum = CCall bytesize@0x1018, v3
PatchPoint MethodRedefined(String@0x1008, bytesize@0x1010, cme:0x1018)
v8:Fixnum = CCall bytesize@0x1040, v3
Return v8
"#]]);
}
@ -6623,8 +6626,8 @@ mod opt_tests {
fn test@<compiled>:2:
bb0(v0:BasicObject, v1:BasicObject, v2:BasicObject):
v5:ArrayExact = NewArray v1, v2
PatchPoint MethodRedefined(Array@0x1000, length@0x1008)
v10:Fixnum = CCall length@0x1010, v5
PatchPoint MethodRedefined(Array@0x1000, length@0x1008, cme:0x1010)
v10:Fixnum = CCall length@0x1038, v5
Return v10
"#]]);
}
@ -6638,8 +6641,8 @@ mod opt_tests {
fn test@<compiled>:2:
bb0(v0:BasicObject, v1:BasicObject, v2:BasicObject):
v5:ArrayExact = NewArray v1, v2
PatchPoint MethodRedefined(Array@0x1000, size@0x1008)
v10:Fixnum = CCall size@0x1010, v5
PatchPoint MethodRedefined(Array@0x1000, size@0x1008, cme:0x1010)
v10:Fixnum = CCall size@0x1038, v5
Return v10
"#]]);
}
@ -6973,8 +6976,8 @@ mod opt_tests {
fn test@<compiled>:3:
bb0(v0:BasicObject):
v3:Fixnum[1] = Const Value(1)
PatchPoint MethodRedefined(Integer@0x1000, itself@0x1008)
v15:BasicObject = CCall itself@0x1010, v3
PatchPoint MethodRedefined(Integer@0x1000, itself@0x1008, cme:0x1010)
v15:BasicObject = CCall itself@0x1038, v3
Return v15
"#]]);
}
@ -7079,8 +7082,8 @@ mod opt_tests {
bb0(v0:BasicObject):
v2:ArrayExact[VALUE(0x1000)] = Const Value(VALUE(0x1000))
v4:ArrayExact = ArrayDup v2
PatchPoint MethodRedefined(Array@0x1008, max@0x1010)
v9:BasicObject = SendWithoutBlockDirect v4, :max (0x1018)
PatchPoint MethodRedefined(Array@0x1008, max@0x1010, cme:0x1018)
v9:BasicObject = SendWithoutBlockDirect v4, :max (0x1040)
Return v9
"#]]);
}
@ -7127,8 +7130,8 @@ mod opt_tests {
fn test@<compiled>:2:
bb0(v0:BasicObject):
v2:NilClassExact = Const Value(nil)
PatchPoint MethodRedefined(NilClass@0x1000, nil?@0x1008)
v7:TrueClassExact = CCall nil?@0x1010, v2
PatchPoint MethodRedefined(NilClass@0x1000, nil?@0x1008, cme:0x1010)
v7:TrueClassExact = CCall nil?@0x1038, v2
Return v7
"#]]);
}
@ -7145,7 +7148,7 @@ mod opt_tests {
fn test@<compiled>:3:
bb0(v0:BasicObject):
v2:NilClassExact = Const Value(nil)
PatchPoint MethodRedefined(NilClass@0x1000, nil?@0x1008)
PatchPoint MethodRedefined(NilClass@0x1000, nil?@0x1008, cme:0x1010)
v5:Fixnum[1] = Const Value(1)
Return v5
"#]]);
@ -7160,8 +7163,8 @@ mod opt_tests {
fn test@<compiled>:2:
bb0(v0:BasicObject):
v2:Fixnum[1] = Const Value(1)
PatchPoint MethodRedefined(Integer@0x1000, nil?@0x1008)
v7:FalseClassExact = CCall nil?@0x1010, v2
PatchPoint MethodRedefined(Integer@0x1000, nil?@0x1008, cme:0x1010)
v7:FalseClassExact = CCall nil?@0x1038, v2
Return v7
"#]]);
}
@ -7178,7 +7181,7 @@ mod opt_tests {
fn test@<compiled>:3:
bb0(v0:BasicObject):
v2:Fixnum[1] = Const Value(1)
PatchPoint MethodRedefined(Integer@0x1000, nil?@0x1008)
PatchPoint MethodRedefined(Integer@0x1000, nil?@0x1008, cme:0x1010)
v5:Fixnum[2] = Const Value(2)
Return v5
"#]]);
@ -7194,9 +7197,9 @@ mod opt_tests {
assert_optimized_method_hir("test", expect![[r#"
fn test@<compiled>:2:
bb0(v0:BasicObject, v1:BasicObject):
PatchPoint MethodRedefined(NilClass@0x1000, nil?@0x1008)
PatchPoint MethodRedefined(NilClass@0x1000, nil?@0x1008, cme:0x1010)
v7:NilClassExact = GuardType v1, NilClassExact
v8:TrueClassExact = CCall nil?@0x1010, v7
v8:TrueClassExact = CCall nil?@0x1038, v7
Return v8
"#]]);
}
@ -7211,9 +7214,9 @@ mod opt_tests {
assert_optimized_method_hir("test", expect![[r#"
fn test@<compiled>:2:
bb0(v0:BasicObject, v1:BasicObject):
PatchPoint MethodRedefined(FalseClass@0x1000, nil?@0x1008)
PatchPoint MethodRedefined(FalseClass@0x1000, nil?@0x1008, cme:0x1010)
v7:FalseClassExact = GuardType v1, FalseClassExact
v8:FalseClassExact = CCall nil?@0x1010, v7
v8:FalseClassExact = CCall nil?@0x1038, v7
Return v8
"#]]);
}
@ -7228,9 +7231,9 @@ mod opt_tests {
assert_optimized_method_hir("test", expect![[r#"
fn test@<compiled>:2:
bb0(v0:BasicObject, v1:BasicObject):
PatchPoint MethodRedefined(TrueClass@0x1000, nil?@0x1008)
PatchPoint MethodRedefined(TrueClass@0x1000, nil?@0x1008, cme:0x1010)
v7:TrueClassExact = GuardType v1, TrueClassExact
v8:FalseClassExact = CCall nil?@0x1010, v7
v8:FalseClassExact = CCall nil?@0x1038, v7
Return v8
"#]]);
}
@ -7245,9 +7248,9 @@ mod opt_tests {
assert_optimized_method_hir("test", expect![[r#"
fn test@<compiled>:2:
bb0(v0:BasicObject, v1:BasicObject):
PatchPoint MethodRedefined(Symbol@0x1000, nil?@0x1008)
PatchPoint MethodRedefined(Symbol@0x1000, nil?@0x1008, cme:0x1010)
v7:StaticSymbol = GuardType v1, StaticSymbol
v8:FalseClassExact = CCall nil?@0x1010, v7
v8:FalseClassExact = CCall nil?@0x1038, v7
Return v8
"#]]);
}
@ -7262,9 +7265,9 @@ mod opt_tests {
assert_optimized_method_hir("test", expect![[r#"
fn test@<compiled>:2:
bb0(v0:BasicObject, v1:BasicObject):
PatchPoint MethodRedefined(Integer@0x1000, nil?@0x1008)
PatchPoint MethodRedefined(Integer@0x1000, nil?@0x1008, cme:0x1010)
v7:Fixnum = GuardType v1, Fixnum
v8:FalseClassExact = CCall nil?@0x1010, v7
v8:FalseClassExact = CCall nil?@0x1038, v7
Return v8
"#]]);
}
@ -7279,9 +7282,9 @@ mod opt_tests {
assert_optimized_method_hir("test", expect![[r#"
fn test@<compiled>:2:
bb0(v0:BasicObject, v1:BasicObject):
PatchPoint MethodRedefined(Float@0x1000, nil?@0x1008)
PatchPoint MethodRedefined(Float@0x1000, nil?@0x1008, cme:0x1010)
v7:Flonum = GuardType v1, Flonum
v8:FalseClassExact = CCall nil?@0x1010, v7
v8:FalseClassExact = CCall nil?@0x1038, v7
Return v8
"#]]);
}
@ -7296,9 +7299,9 @@ mod opt_tests {
assert_optimized_method_hir("test", expect![[r#"
fn test@<compiled>:2:
bb0(v0:BasicObject, v1:BasicObject):
PatchPoint MethodRedefined(String@0x1000, nil?@0x1008)
PatchPoint MethodRedefined(String@0x1000, nil?@0x1008, cme:0x1010)
v7:StringExact = GuardType v1, StringExact
v8:FalseClassExact = CCall nil?@0x1010, v7
v8:FalseClassExact = CCall nil?@0x1038, v7
Return v8
"#]]);
}
@ -7338,4 +7341,23 @@ mod opt_tests {
Return v10
"#]]);
}
#[test]
fn test_method_redefinition_patch_point_on_top_level_method() {
eval("
def foo; end
def test = foo
test; test
");
assert_optimized_method_hir("test", expect![[r#"
fn test@<compiled>:3:
bb0(v0:BasicObject):
PatchPoint MethodRedefined(Object@0x1000, foo@0x1008, cme:0x1010)
v6:BasicObject[VALUE(0x1038)] = GuardBitEquals v0, VALUE(0x1038)
v7:BasicObject = SendWithoutBlockDirect v6, :foo (0x1040)
Return v7
"#]]);
}
}

View file

@ -1,6 +1,6 @@
use std::{collections::{HashMap, HashSet}};
use crate::{backend::lir::{asm_comment, Assembler}, cruby::{ruby_basic_operators, src_loc, with_vm_lock, IseqPtr, RedefinitionFlag}, hir::Invariant, options::debug, state::{zjit_enabled_p, ZJITState}, virtualmem::CodePtr};
use crate::{backend::lir::{asm_comment, Assembler}, cruby::{rb_callable_method_entry_t, ruby_basic_operators, src_loc, with_vm_lock, IseqPtr, RedefinitionFlag}, hir::Invariant, options::debug, state::{zjit_enabled_p, ZJITState}, virtualmem::CodePtr};
#[derive(Debug, Eq, Hash, PartialEq)]
struct Jump {
@ -20,6 +20,9 @@ pub struct Invariants {
/// Map from a class and its associated basic operator to a set of patch points
bop_patch_points: HashMap<(RedefinitionFlag, ruby_basic_operators), HashSet<Jump>>,
/// Map from CME to patch points that assume the method hasn't been redefined
cme_patch_points: HashMap<*const rb_callable_method_entry_t, HashSet<Jump>>,
}
/// Called when a basic operator is redefined. Note that all the blocks assuming
@ -99,3 +102,45 @@ pub fn track_bop_assumption(
to: side_exit_ptr,
});
}
/// Track a patch point for a callable method entry (CME).
pub fn track_cme_assumption(
cme: *const rb_callable_method_entry_t,
patch_point_ptr: CodePtr,
side_exit_ptr: CodePtr
) {
let invariants = ZJITState::get_invariants();
invariants.cme_patch_points.entry(cme).or_default().insert(Jump {
from: patch_point_ptr,
to: side_exit_ptr,
});
}
/// Called when a method is redefined. Invalidates all JIT code that depends on the CME.
#[unsafe(no_mangle)]
pub extern "C" fn rb_zjit_cme_invalidate(cme: *const rb_callable_method_entry_t) {
// If ZJIT isn't enabled, do nothing
if !zjit_enabled_p() {
return;
}
with_vm_lock(src_loc!(), || {
let invariants = ZJITState::get_invariants();
// Get the CMD's jumps and remove the entry from the map as it has been invalidated
if let Some(jumps) = invariants.cme_patch_points.remove(&cme) {
let cb = ZJITState::get_code_block();
debug!("CME is invalidated: {:?}", cme);
// Invalidate all patch points for this CME
for jump in jumps {
cb.with_write_ptr(jump.from, |cb| {
let mut asm = Assembler::new();
asm_comment!(asm, "CME is invalidated: {:?}", cme);
asm.jmp(jump.to.into());
asm.compile(cb).expect("can write existing code");
});
}
cb.mark_all_executable();
}
});
}