diff --git a/test/ruby/test_zjit.rb b/test/ruby/test_zjit.rb index d598b094a4..1166533bd9 100644 --- a/test/ruby/test_zjit.rb +++ b/test/ruby/test_zjit.rb @@ -1066,6 +1066,31 @@ class TestZJIT < Test::Unit::TestCase }, call_threshold: 2 end + # ZJIT currently only generates a MethodRedefined patch point when the method + # is called on the top-level self. + def test_method_redefinition_with_top_self + assert_runs '["original", "redefined"]', %q{ + def foo + "original" + end + + def test = foo + + test; test + + result1 = test + + # Redefine the method + def foo + "redefined" + end + + result2 = test + + [result1, result2] + }, call_threshold: 2 + end + def test_module_name_with_guard_passes assert_compiles '"Integer"', %q{ def test(mod) diff --git a/vm_method.c b/vm_method.c index e63804d34d..84d0ed2f9e 100644 --- a/vm_method.c +++ b/vm_method.c @@ -122,6 +122,7 @@ vm_cme_invalidate(rb_callable_method_entry_t *cme) RB_DEBUG_COUNTER_INC(cc_cme_invalidate); rb_yjit_cme_invalidate(cme); + rb_zjit_cme_invalidate(cme); } static int diff --git a/zjit.h b/zjit.h index 84df6d009e..724ae4abd0 100644 --- a/zjit.h +++ b/zjit.h @@ -12,6 +12,7 @@ void rb_zjit_compile_iseq(const rb_iseq_t *iseq, rb_execution_context_t *ec, boo void rb_zjit_profile_insn(enum ruby_vminsn_type insn, rb_execution_context_t *ec); void rb_zjit_profile_enable(const rb_iseq_t *iseq); void rb_zjit_bop_redefined(int redefined_flag, enum ruby_basic_operators bop); +void rb_zjit_cme_invalidate(const rb_callable_method_entry_t *cme); void rb_zjit_invalidate_ep_is_bp(const rb_iseq_t *iseq); void rb_zjit_iseq_mark(void *payload); void rb_zjit_iseq_update_references(void *payload); @@ -21,6 +22,7 @@ static inline void rb_zjit_compile_iseq(const rb_iseq_t *iseq, rb_execution_cont static inline void rb_zjit_profile_insn(enum ruby_vminsn_type insn, rb_execution_context_t *ec) {} static inline void rb_zjit_profile_enable(const rb_iseq_t *iseq) {} static inline void rb_zjit_bop_redefined(int redefined_flag, enum ruby_basic_operators bop) {} +static inline void rb_zjit_cme_invalidate(const rb_callable_method_entry_t *cme) {} static inline void rb_zjit_invalidate_ep_is_bp(const rb_iseq_t *iseq) {} #endif // #if USE_YJIT diff --git a/zjit/src/codegen.rs b/zjit/src/codegen.rs index 49dbf46dce..6ce3915978 100644 --- a/zjit/src/codegen.rs +++ b/zjit/src/codegen.rs @@ -3,7 +3,7 @@ use std::rc::Rc; use crate::asm::Label; use crate::backend::current::{Reg, ALLOC_REGS}; -use crate::invariants::track_bop_assumption; +use crate::invariants::{track_bop_assumption, track_cme_assumption}; use crate::gc::{get_or_create_iseq_payload, append_gc_offsets}; use crate::state::ZJITState; use crate::{asm::CodeBlock, cruby::*, options::debug, virtualmem::CodePtr}; @@ -494,6 +494,10 @@ fn gen_patch_point(jit: &mut JITState, asm: &mut Assembler, invariant: &Invarian let side_exit_ptr = cb.resolve_label(label); track_bop_assumption(klass, bop, code_ptr, side_exit_ptr); } + Invariant::MethodRedefined { klass: _, method: _, cme } => { + let side_exit_ptr = cb.resolve_label(label); + track_cme_assumption(cme, code_ptr, side_exit_ptr); + } _ => { debug!("ZJIT: gen_patch_point: unimplemented invariant {invariant:?}"); return; diff --git a/zjit/src/hir.rs b/zjit/src/hir.rs index 41efce6624..53e55b4428 100644 --- a/zjit/src/hir.rs +++ b/zjit/src/hir.rs @@ -127,6 +127,8 @@ pub enum Invariant { klass: VALUE, /// The method ID of the method we want to assume unchanged method: ID, + /// The callable method entry that we want to track + cme: *const rb_callable_method_entry_t, }, /// A list of constant expression path segments that must have not been written to for the /// following code to be valid. @@ -222,12 +224,13 @@ impl<'a> std::fmt::Display for InvariantPrinter<'a> { } write!(f, ")") } - Invariant::MethodRedefined { klass, method } => { + Invariant::MethodRedefined { klass, method, cme } => { let class_name = get_class_name(klass); - write!(f, "MethodRedefined({class_name}@{:p}, {}@{:p})", + write!(f, "MethodRedefined({class_name}@{:p}, {}@{:p}, cme:{:p})", self.ptr_map.map_ptr(klass.as_ptr::()), method.contents_lossy(), - self.ptr_map.map_id(method.0) + self.ptr_map.map_id(method.0), + self.ptr_map.map_ptr(cme) ) } Invariant::StableConstantNames { idlist } => { @@ -1537,7 +1540,7 @@ impl Function { if !can_direct_send(iseq) { self.push_insn_id(block, insn_id); continue; } - self.push_insn(block, Insn::PatchPoint { invariant: Invariant::MethodRedefined { klass, method: mid }, state }); + self.push_insn(block, Insn::PatchPoint { invariant: Invariant::MethodRedefined { klass, method: mid, cme }, state }); if let Some(expected) = guard_equal_to { self_val = self.push_insn(block, Insn::GuardBitEquals { val: self_val, expected, state }); } @@ -1656,7 +1659,7 @@ impl Function { // Filter for simple call sites (i.e. no splats etc.) if ci_flags & VM_CALL_ARGS_SIMPLE != 0 { // Commit to the replacement. Put PatchPoint. - fun.push_insn(block, Insn::PatchPoint { invariant: Invariant::MethodRedefined { klass: recv_class, method: method_id }, state }); + fun.push_insn(block, Insn::PatchPoint { invariant: Invariant::MethodRedefined { klass: recv_class, method: method_id, cme: method }, state }); if let Some(guard_type) = guard_type { // Guard receiver class self_val = fun.push_insn(block, Insn::GuardType { val: self_val, guard_type, state }); @@ -5514,9 +5517,9 @@ mod opt_tests { assert_optimized_method_hir("test", expect![[r#" fn test@:5: bb0(v0:BasicObject): - PatchPoint MethodRedefined(Object@0x1000, foo@0x1008) - v6:BasicObject[VALUE(0x1010)] = GuardBitEquals v0, VALUE(0x1010) - v7:BasicObject = SendWithoutBlockDirect v6, :foo (0x1018) + PatchPoint MethodRedefined(Object@0x1000, foo@0x1008, cme:0x1010) + v6:BasicObject[VALUE(0x1038)] = GuardBitEquals v0, VALUE(0x1038) + v7:BasicObject = SendWithoutBlockDirect v6, :foo (0x1040) Return v7 "#]]); } @@ -5554,9 +5557,9 @@ mod opt_tests { assert_optimized_method_hir("test", expect![[r#" fn test@:6: bb0(v0:BasicObject): - PatchPoint MethodRedefined(Object@0x1000, foo@0x1008) - v6:BasicObject[VALUE(0x1010)] = GuardBitEquals v0, VALUE(0x1010) - v7:BasicObject = SendWithoutBlockDirect v6, :foo (0x1018) + PatchPoint MethodRedefined(Object@0x1000, foo@0x1008, cme:0x1010) + v6:BasicObject[VALUE(0x1038)] = GuardBitEquals v0, VALUE(0x1038) + v7:BasicObject = SendWithoutBlockDirect v6, :foo (0x1040) Return v7 "#]]); } @@ -5573,9 +5576,9 @@ mod opt_tests { fn test@:3: bb0(v0:BasicObject): v2:Fixnum[3] = Const Value(3) - PatchPoint MethodRedefined(Object@0x1000, Integer@0x1008) - v7:BasicObject[VALUE(0x1010)] = GuardBitEquals v0, VALUE(0x1010) - v8:BasicObject = SendWithoutBlockDirect v7, :Integer (0x1018), v2 + PatchPoint MethodRedefined(Object@0x1000, Integer@0x1008, cme:0x1010) + v7:BasicObject[VALUE(0x1038)] = GuardBitEquals v0, VALUE(0x1038) + v8:BasicObject = SendWithoutBlockDirect v7, :Integer (0x1040), v2 Return v8 "#]]); } @@ -5595,9 +5598,9 @@ mod opt_tests { bb0(v0:BasicObject): v2:Fixnum[1] = Const Value(1) v3:Fixnum[2] = Const Value(2) - PatchPoint MethodRedefined(Object@0x1000, foo@0x1008) - v8:BasicObject[VALUE(0x1010)] = GuardBitEquals v0, VALUE(0x1010) - v9:BasicObject = SendWithoutBlockDirect v8, :foo (0x1018), v2, v3 + PatchPoint MethodRedefined(Object@0x1000, foo@0x1008, cme:0x1010) + v8:BasicObject[VALUE(0x1038)] = GuardBitEquals v0, VALUE(0x1038) + v9:BasicObject = SendWithoutBlockDirect v8, :foo (0x1040), v2, v3 Return v9 "#]]); } @@ -5618,12 +5621,12 @@ mod opt_tests { assert_optimized_method_hir("test", expect![[r#" fn test@:7: bb0(v0:BasicObject): - PatchPoint MethodRedefined(Object@0x1000, foo@0x1008) - v8:BasicObject[VALUE(0x1010)] = GuardBitEquals v0, VALUE(0x1010) - v9:BasicObject = SendWithoutBlockDirect v8, :foo (0x1018) - PatchPoint MethodRedefined(Object@0x1000, bar@0x1020) - v11:BasicObject[VALUE(0x1010)] = GuardBitEquals v0, VALUE(0x1010) - v12:BasicObject = SendWithoutBlockDirect v11, :bar (0x1018) + PatchPoint MethodRedefined(Object@0x1000, foo@0x1008, cme:0x1010) + v8:BasicObject[VALUE(0x1038)] = GuardBitEquals v0, VALUE(0x1038) + v9:BasicObject = SendWithoutBlockDirect v8, :foo (0x1040) + PatchPoint MethodRedefined(Object@0x1000, bar@0x1048, cme:0x1050) + v11:BasicObject[VALUE(0x1038)] = GuardBitEquals v0, VALUE(0x1038) + v12:BasicObject = SendWithoutBlockDirect v11, :bar (0x1040) Return v12 "#]]); } @@ -6149,9 +6152,9 @@ mod opt_tests { assert_optimized_method_hir("test", expect![[r#" fn test@:2: bb0(v0:BasicObject, v1:BasicObject): - PatchPoint MethodRedefined(Integer@0x1000, itself@0x1008) + PatchPoint MethodRedefined(Integer@0x1000, itself@0x1008, cme:0x1010) v7:Fixnum = GuardType v1, Fixnum - v8:BasicObject = CCall itself@0x1010, v7 + v8:BasicObject = CCall itself@0x1038, v7 Return v8 "#]]); } @@ -6165,8 +6168,8 @@ mod opt_tests { fn test@:2: bb0(v0:BasicObject): v3:ArrayExact = NewArray - PatchPoint MethodRedefined(Array@0x1000, itself@0x1008) - v8:BasicObject = CCall itself@0x1010, v3 + PatchPoint MethodRedefined(Array@0x1000, itself@0x1008, cme:0x1010) + v8:BasicObject = CCall itself@0x1038, v3 Return v8 "#]]); } @@ -6184,7 +6187,7 @@ mod opt_tests { bb0(v0:BasicObject): v1:NilClassExact = Const Value(nil) v4:ArrayExact = NewArray - PatchPoint MethodRedefined(Array@0x1000, itself@0x1008) + PatchPoint MethodRedefined(Array@0x1000, itself@0x1008, cme:0x1010) v7:Fixnum[1] = Const Value(1) Return v7 "#]]); @@ -6207,7 +6210,7 @@ mod opt_tests { PatchPoint SingleRactorMode PatchPoint StableConstantNames(0x1000, M) v11:ModuleExact[VALUE(0x1008)] = Const Value(VALUE(0x1008)) - PatchPoint MethodRedefined(Module@0x1010, name@0x1018) + PatchPoint MethodRedefined(Module@0x1010, name@0x1018, cme:0x1020) v7:Fixnum[1] = Const Value(1) Return v7 "#]]); @@ -6226,7 +6229,7 @@ mod opt_tests { bb0(v0:BasicObject): v1:NilClassExact = Const Value(nil) v4:ArrayExact = NewArray - PatchPoint MethodRedefined(Array@0x1000, length@0x1008) + PatchPoint MethodRedefined(Array@0x1000, length@0x1008, cme:0x1010) v7:Fixnum[5] = Const Value(5) Return v7 "#]]); @@ -6326,7 +6329,7 @@ mod opt_tests { bb0(v0:BasicObject): v1:NilClassExact = Const Value(nil) v4:ArrayExact = NewArray - PatchPoint MethodRedefined(Array@0x1000, size@0x1008) + PatchPoint MethodRedefined(Array@0x1000, size@0x1008, cme:0x1010) v7:Fixnum[5] = Const Value(5) Return v7 "#]]); @@ -6359,8 +6362,8 @@ mod opt_tests { fn test@:2: bb0(v0:BasicObject, v1:BasicObject): v3:Fixnum[1] = Const Value(1) - PatchPoint MethodRedefined(Integer@0x1000, zero?@0x1008) - v8:BasicObject = SendWithoutBlockDirect v3, :zero? (0x1010) + PatchPoint MethodRedefined(Integer@0x1000, zero?@0x1008, cme:0x1010) + v8:BasicObject = SendWithoutBlockDirect v3, :zero? (0x1038) Return v8 "#]]); } @@ -6379,8 +6382,8 @@ mod opt_tests { v2:NilClassExact = Const Value(nil) v4:ArrayExact[VALUE(0x1000)] = Const Value(VALUE(0x1000)) v6:ArrayExact = ArrayDup v4 - PatchPoint MethodRedefined(Array@0x1008, first@0x1010) - v11:BasicObject = SendWithoutBlockDirect v6, :first (0x1018) + PatchPoint MethodRedefined(Array@0x1008, first@0x1010, cme:0x1018) + v11:BasicObject = SendWithoutBlockDirect v6, :first (0x1040) Return v11 "#]]); } @@ -6399,8 +6402,8 @@ mod opt_tests { PatchPoint SingleRactorMode PatchPoint StableConstantNames(0x1000, M) v9:ModuleExact[VALUE(0x1008)] = Const Value(VALUE(0x1008)) - PatchPoint MethodRedefined(Module@0x1010, class@0x1018) - v11:BasicObject = SendWithoutBlockDirect v9, :class (0x1020) + PatchPoint MethodRedefined(Module@0x1010, class@0x1018, cme:0x1020) + v11:BasicObject = SendWithoutBlockDirect v9, :class (0x1048) Return v11 "#]]); } @@ -6499,8 +6502,8 @@ mod opt_tests { bb0(v0:BasicObject): v2:StringExact[VALUE(0x1000)] = Const Value(VALUE(0x1000)) v3:StringExact = StringCopy v2 - PatchPoint MethodRedefined(String@0x1008, bytesize@0x1010) - v8:Fixnum = CCall bytesize@0x1018, v3 + PatchPoint MethodRedefined(String@0x1008, bytesize@0x1010, cme:0x1018) + v8:Fixnum = CCall bytesize@0x1040, v3 Return v8 "#]]); } @@ -6623,8 +6626,8 @@ mod opt_tests { fn test@:2: bb0(v0:BasicObject, v1:BasicObject, v2:BasicObject): v5:ArrayExact = NewArray v1, v2 - PatchPoint MethodRedefined(Array@0x1000, length@0x1008) - v10:Fixnum = CCall length@0x1010, v5 + PatchPoint MethodRedefined(Array@0x1000, length@0x1008, cme:0x1010) + v10:Fixnum = CCall length@0x1038, v5 Return v10 "#]]); } @@ -6638,8 +6641,8 @@ mod opt_tests { fn test@:2: bb0(v0:BasicObject, v1:BasicObject, v2:BasicObject): v5:ArrayExact = NewArray v1, v2 - PatchPoint MethodRedefined(Array@0x1000, size@0x1008) - v10:Fixnum = CCall size@0x1010, v5 + PatchPoint MethodRedefined(Array@0x1000, size@0x1008, cme:0x1010) + v10:Fixnum = CCall size@0x1038, v5 Return v10 "#]]); } @@ -6973,8 +6976,8 @@ mod opt_tests { fn test@:3: bb0(v0:BasicObject): v3:Fixnum[1] = Const Value(1) - PatchPoint MethodRedefined(Integer@0x1000, itself@0x1008) - v15:BasicObject = CCall itself@0x1010, v3 + PatchPoint MethodRedefined(Integer@0x1000, itself@0x1008, cme:0x1010) + v15:BasicObject = CCall itself@0x1038, v3 Return v15 "#]]); } @@ -7079,8 +7082,8 @@ mod opt_tests { bb0(v0:BasicObject): v2:ArrayExact[VALUE(0x1000)] = Const Value(VALUE(0x1000)) v4:ArrayExact = ArrayDup v2 - PatchPoint MethodRedefined(Array@0x1008, max@0x1010) - v9:BasicObject = SendWithoutBlockDirect v4, :max (0x1018) + PatchPoint MethodRedefined(Array@0x1008, max@0x1010, cme:0x1018) + v9:BasicObject = SendWithoutBlockDirect v4, :max (0x1040) Return v9 "#]]); } @@ -7127,8 +7130,8 @@ mod opt_tests { fn test@:2: bb0(v0:BasicObject): v2:NilClassExact = Const Value(nil) - PatchPoint MethodRedefined(NilClass@0x1000, nil?@0x1008) - v7:TrueClassExact = CCall nil?@0x1010, v2 + PatchPoint MethodRedefined(NilClass@0x1000, nil?@0x1008, cme:0x1010) + v7:TrueClassExact = CCall nil?@0x1038, v2 Return v7 "#]]); } @@ -7145,7 +7148,7 @@ mod opt_tests { fn test@:3: bb0(v0:BasicObject): v2:NilClassExact = Const Value(nil) - PatchPoint MethodRedefined(NilClass@0x1000, nil?@0x1008) + PatchPoint MethodRedefined(NilClass@0x1000, nil?@0x1008, cme:0x1010) v5:Fixnum[1] = Const Value(1) Return v5 "#]]); @@ -7160,8 +7163,8 @@ mod opt_tests { fn test@:2: bb0(v0:BasicObject): v2:Fixnum[1] = Const Value(1) - PatchPoint MethodRedefined(Integer@0x1000, nil?@0x1008) - v7:FalseClassExact = CCall nil?@0x1010, v2 + PatchPoint MethodRedefined(Integer@0x1000, nil?@0x1008, cme:0x1010) + v7:FalseClassExact = CCall nil?@0x1038, v2 Return v7 "#]]); } @@ -7178,7 +7181,7 @@ mod opt_tests { fn test@:3: bb0(v0:BasicObject): v2:Fixnum[1] = Const Value(1) - PatchPoint MethodRedefined(Integer@0x1000, nil?@0x1008) + PatchPoint MethodRedefined(Integer@0x1000, nil?@0x1008, cme:0x1010) v5:Fixnum[2] = Const Value(2) Return v5 "#]]); @@ -7194,9 +7197,9 @@ mod opt_tests { assert_optimized_method_hir("test", expect![[r#" fn test@:2: bb0(v0:BasicObject, v1:BasicObject): - PatchPoint MethodRedefined(NilClass@0x1000, nil?@0x1008) + PatchPoint MethodRedefined(NilClass@0x1000, nil?@0x1008, cme:0x1010) v7:NilClassExact = GuardType v1, NilClassExact - v8:TrueClassExact = CCall nil?@0x1010, v7 + v8:TrueClassExact = CCall nil?@0x1038, v7 Return v8 "#]]); } @@ -7211,9 +7214,9 @@ mod opt_tests { assert_optimized_method_hir("test", expect![[r#" fn test@:2: bb0(v0:BasicObject, v1:BasicObject): - PatchPoint MethodRedefined(FalseClass@0x1000, nil?@0x1008) + PatchPoint MethodRedefined(FalseClass@0x1000, nil?@0x1008, cme:0x1010) v7:FalseClassExact = GuardType v1, FalseClassExact - v8:FalseClassExact = CCall nil?@0x1010, v7 + v8:FalseClassExact = CCall nil?@0x1038, v7 Return v8 "#]]); } @@ -7228,9 +7231,9 @@ mod opt_tests { assert_optimized_method_hir("test", expect![[r#" fn test@:2: bb0(v0:BasicObject, v1:BasicObject): - PatchPoint MethodRedefined(TrueClass@0x1000, nil?@0x1008) + PatchPoint MethodRedefined(TrueClass@0x1000, nil?@0x1008, cme:0x1010) v7:TrueClassExact = GuardType v1, TrueClassExact - v8:FalseClassExact = CCall nil?@0x1010, v7 + v8:FalseClassExact = CCall nil?@0x1038, v7 Return v8 "#]]); } @@ -7245,9 +7248,9 @@ mod opt_tests { assert_optimized_method_hir("test", expect![[r#" fn test@:2: bb0(v0:BasicObject, v1:BasicObject): - PatchPoint MethodRedefined(Symbol@0x1000, nil?@0x1008) + PatchPoint MethodRedefined(Symbol@0x1000, nil?@0x1008, cme:0x1010) v7:StaticSymbol = GuardType v1, StaticSymbol - v8:FalseClassExact = CCall nil?@0x1010, v7 + v8:FalseClassExact = CCall nil?@0x1038, v7 Return v8 "#]]); } @@ -7262,9 +7265,9 @@ mod opt_tests { assert_optimized_method_hir("test", expect![[r#" fn test@:2: bb0(v0:BasicObject, v1:BasicObject): - PatchPoint MethodRedefined(Integer@0x1000, nil?@0x1008) + PatchPoint MethodRedefined(Integer@0x1000, nil?@0x1008, cme:0x1010) v7:Fixnum = GuardType v1, Fixnum - v8:FalseClassExact = CCall nil?@0x1010, v7 + v8:FalseClassExact = CCall nil?@0x1038, v7 Return v8 "#]]); } @@ -7279,9 +7282,9 @@ mod opt_tests { assert_optimized_method_hir("test", expect![[r#" fn test@:2: bb0(v0:BasicObject, v1:BasicObject): - PatchPoint MethodRedefined(Float@0x1000, nil?@0x1008) + PatchPoint MethodRedefined(Float@0x1000, nil?@0x1008, cme:0x1010) v7:Flonum = GuardType v1, Flonum - v8:FalseClassExact = CCall nil?@0x1010, v7 + v8:FalseClassExact = CCall nil?@0x1038, v7 Return v8 "#]]); } @@ -7296,9 +7299,9 @@ mod opt_tests { assert_optimized_method_hir("test", expect![[r#" fn test@:2: bb0(v0:BasicObject, v1:BasicObject): - PatchPoint MethodRedefined(String@0x1000, nil?@0x1008) + PatchPoint MethodRedefined(String@0x1000, nil?@0x1008, cme:0x1010) v7:StringExact = GuardType v1, StringExact - v8:FalseClassExact = CCall nil?@0x1010, v7 + v8:FalseClassExact = CCall nil?@0x1038, v7 Return v8 "#]]); } @@ -7338,4 +7341,23 @@ mod opt_tests { Return v10 "#]]); } + + #[test] + fn test_method_redefinition_patch_point_on_top_level_method() { + eval(" + def foo; end + def test = foo + + test; test + "); + + assert_optimized_method_hir("test", expect![[r#" + fn test@:3: + bb0(v0:BasicObject): + PatchPoint MethodRedefined(Object@0x1000, foo@0x1008, cme:0x1010) + v6:BasicObject[VALUE(0x1038)] = GuardBitEquals v0, VALUE(0x1038) + v7:BasicObject = SendWithoutBlockDirect v6, :foo (0x1040) + Return v7 + "#]]); + } } diff --git a/zjit/src/invariants.rs b/zjit/src/invariants.rs index 62b3805485..6949da6a86 100644 --- a/zjit/src/invariants.rs +++ b/zjit/src/invariants.rs @@ -1,6 +1,6 @@ use std::{collections::{HashMap, HashSet}}; -use crate::{backend::lir::{asm_comment, Assembler}, cruby::{ruby_basic_operators, src_loc, with_vm_lock, IseqPtr, RedefinitionFlag}, hir::Invariant, options::debug, state::{zjit_enabled_p, ZJITState}, virtualmem::CodePtr}; +use crate::{backend::lir::{asm_comment, Assembler}, cruby::{rb_callable_method_entry_t, ruby_basic_operators, src_loc, with_vm_lock, IseqPtr, RedefinitionFlag}, hir::Invariant, options::debug, state::{zjit_enabled_p, ZJITState}, virtualmem::CodePtr}; #[derive(Debug, Eq, Hash, PartialEq)] struct Jump { @@ -20,6 +20,9 @@ pub struct Invariants { /// Map from a class and its associated basic operator to a set of patch points bop_patch_points: HashMap<(RedefinitionFlag, ruby_basic_operators), HashSet>, + + /// Map from CME to patch points that assume the method hasn't been redefined + cme_patch_points: HashMap<*const rb_callable_method_entry_t, HashSet>, } /// Called when a basic operator is redefined. Note that all the blocks assuming @@ -99,3 +102,45 @@ pub fn track_bop_assumption( to: side_exit_ptr, }); } + +/// Track a patch point for a callable method entry (CME). +pub fn track_cme_assumption( + cme: *const rb_callable_method_entry_t, + patch_point_ptr: CodePtr, + side_exit_ptr: CodePtr +) { + let invariants = ZJITState::get_invariants(); + invariants.cme_patch_points.entry(cme).or_default().insert(Jump { + from: patch_point_ptr, + to: side_exit_ptr, + }); +} + +/// Called when a method is redefined. Invalidates all JIT code that depends on the CME. +#[unsafe(no_mangle)] +pub extern "C" fn rb_zjit_cme_invalidate(cme: *const rb_callable_method_entry_t) { + // If ZJIT isn't enabled, do nothing + if !zjit_enabled_p() { + return; + } + + with_vm_lock(src_loc!(), || { + let invariants = ZJITState::get_invariants(); + // Get the CMD's jumps and remove the entry from the map as it has been invalidated + if let Some(jumps) = invariants.cme_patch_points.remove(&cme) { + let cb = ZJITState::get_code_block(); + debug!("CME is invalidated: {:?}", cme); + + // Invalidate all patch points for this CME + for jump in jumps { + cb.with_write_ptr(jump.from, |cb| { + let mut asm = Assembler::new(); + asm_comment!(asm, "CME is invalidated: {:?}", cme); + asm.jmp(jump.to.into()); + asm.compile(cb).expect("can write existing code"); + }); + } + cb.mark_all_executable(); + } + }); +}