From 424b50673f718d7d2cb4f80bcbc4376794bef138 Mon Sep 17 00:00:00 2001 From: eileencodes Date: Tue, 17 Jun 2025 13:19:19 -0400 Subject: [PATCH] ZJIT: Implement getspecial in ZJIT Adds support for the getspecial instruction in zjit. We split getspecial into two instructions, one for special symbols (`$&`, $'`, etc) and one for special backrefs (`$1`, `$2`, etc). Co-authored-by: Aaron Patterson --- test/ruby/test_zjit.rb | 100 +++++++++++++++++++++++++++++++++++++++++ zjit/src/codegen.rs | 35 ++++++++++++++- zjit/src/hir.rs | 57 +++++++++++++++++++++++ 3 files changed, 191 insertions(+), 1 deletion(-) diff --git a/test/ruby/test_zjit.rb b/test/ruby/test_zjit.rb index 58fc9ba639..0216ab654f 100644 --- a/test/ruby/test_zjit.rb +++ b/test/ruby/test_zjit.rb @@ -1213,6 +1213,106 @@ class TestZJIT < Test::Unit::TestCase }, insns: [:opt_nil_p] end + def test_getspecial_last_match + assert_compiles '"hello"', %q{ + def test(str) + str =~ /hello/ + $& + end + test("hello world") + }, insns: [:getspecial] + end + + def test_getspecial_match_pre + assert_compiles '"hello "', %q{ + def test(str) + str =~ /world/ + $` + end + test("hello world") + }, insns: [:getspecial] + end + + def test_getspecial_match_post + assert_compiles '" world"', %q{ + def test(str) + str =~ /hello/ + $' + end + test("hello world") + }, insns: [:getspecial] + end + + def test_getspecial_match_last_group + assert_compiles '"world"', %q{ + def test(str) + str =~ /(hello) (world)/ + $+ + end + test("hello world") + }, insns: [:getspecial] + end + + def test_getspecial_numbered_match_1 + assert_compiles '"hello"', %q{ + def test(str) + str =~ /(hello) (world)/ + $1 + end + test("hello world") + }, insns: [:getspecial] + end + + def test_getspecial_numbered_match_2 + assert_compiles '"world"', %q{ + def test(str) + str =~ /(hello) (world)/ + $2 + end + test("hello world") + }, insns: [:getspecial] + end + + def test_getspecial_numbered_match_nonexistent + assert_compiles 'nil', %q{ + def test(str) + str =~ /(hello)/ + $2 + end + test("hello world") + }, insns: [:getspecial] + end + + def test_getspecial_no_match + assert_compiles 'nil', %q{ + def test(str) + str =~ /xyz/ + $& + end + test("hello world") + }, insns: [:getspecial] + end + + def test_getspecial_complex_pattern + assert_compiles '"123"', %q{ + def test(str) + str =~ /(\d+)/ + $1 + end + test("abc123def") + }, insns: [:getspecial] + end + + def test_getspecial_multiple_groups + assert_compiles '"456"', %q{ + def test(str) + str =~ /(\d+)-(\d+)/ + $2 + end + test("123-456") + }, insns: [:getspecial] + end + # tool/ruby_vm/views/*.erb relies on the zjit instructions a) being contiguous and # b) being reliably ordered after all the other instructions. def test_instruction_order diff --git a/zjit/src/codegen.rs b/zjit/src/codegen.rs index 43fde7db7f..371564b7a7 100644 --- a/zjit/src/codegen.rs +++ b/zjit/src/codegen.rs @@ -10,7 +10,7 @@ use crate::state::ZJITState; use crate::stats::{counter_ptr, with_time_stat, Counter, Counter::compile_time_ns}; use crate::{asm::CodeBlock, cruby::*, options::debug, virtualmem::CodePtr}; use crate::backend::lir::{self, asm_comment, asm_ccall, Assembler, Opnd, Target, CFP, C_ARG_OPNDS, C_RET_OPND, EC, NATIVE_STACK_PTR, NATIVE_BASE_PTR, SP}; -use crate::hir::{iseq_to_hir, Block, BlockId, BranchEdge, Invariant, RangeType, SideExitReason, SideExitReason::*, SpecialObjectType, SELF_PARAM_IDX}; +use crate::hir::{iseq_to_hir, Block, BlockId, BranchEdge, Invariant, RangeType, SideExitReason, SideExitReason::*, SpecialObjectType, SpecialBackrefSymbol, SELF_PARAM_IDX}; use crate::hir::{Const, FrameState, Function, Insn, InsnId}; use crate::hir_type::{types, Type}; use crate::options::get_option; @@ -378,6 +378,8 @@ fn gen_insn(cb: &mut CodeBlock, jit: &mut JITState, asm: &mut Assembler, functio Insn::PutSpecialObject { value_type } => gen_putspecialobject(asm, *value_type), Insn::AnyToString { val, str, state } => gen_anytostring(asm, opnd!(val), opnd!(str), &function.frame_state(*state))?, Insn::Defined { op_type, obj, pushval, v, state } => gen_defined(jit, asm, *op_type, *obj, *pushval, opnd!(v), &function.frame_state(*state))?, + Insn::GetSpecialSymbol { symbol_type, state: _ } => gen_getspecial_symbol(asm, *symbol_type), + Insn::GetSpecialNumber { nth, state } => gen_getspecial_number(asm, *nth, &function.frame_state(*state)), &Insn::IncrCounter(counter) => return Some(gen_incr_counter(asm, counter)), Insn::ObjToString { val, cd, state, .. } => gen_objtostring(jit, asm, opnd!(val), *cd, &function.frame_state(*state))?, Insn::ArrayExtend { .. } @@ -640,6 +642,37 @@ fn gen_putspecialobject(asm: &mut Assembler, value_type: SpecialObjectType) -> O asm_ccall!(asm, rb_vm_get_special_object, ep_reg, Opnd::UImm(u64::from(value_type))) } +fn gen_getspecial_symbol(asm: &mut Assembler, symbol_type: SpecialBackrefSymbol) -> Opnd { + // Fetch a "special" backref based on the symbol type + + let backref = asm_ccall!(asm, rb_backref_get,); + + match symbol_type { + SpecialBackrefSymbol::LastMatch => { + asm_ccall!(asm, rb_reg_last_match, backref) + } + SpecialBackrefSymbol::PreMatch => { + asm_ccall!(asm, rb_reg_match_pre, backref) + } + SpecialBackrefSymbol::PostMatch => { + asm_ccall!(asm, rb_reg_match_post, backref) + } + SpecialBackrefSymbol::LastGroup => { + asm_ccall!(asm, rb_reg_match_last, backref) + } + } +} + +fn gen_getspecial_number(asm: &mut Assembler, nth: u64, state: &FrameState) -> Opnd { + // Fetch the N-th match from the last backref based on type shifted by 1 + + let backref = asm_ccall!(asm, rb_backref_get,); + + gen_prepare_call_with_gc(asm, state); + + asm_ccall!(asm, rb_reg_nth_match, Opnd::Imm((nth >> 1).try_into().unwrap()), backref) +} + /// Compile an interpreter entry block to be inserted into an ISEQ fn gen_entry_prologue(asm: &mut Assembler, iseq: IseqPtr) { asm_comment!(asm, "ZJIT entry point: {}", iseq_get_location(iseq, 0)); diff --git a/zjit/src/hir.rs b/zjit/src/hir.rs index c93a6858f1..c2329129d5 100644 --- a/zjit/src/hir.rs +++ b/zjit/src/hir.rs @@ -321,6 +321,29 @@ impl From for u32 { } } +/// Special regex backref symbol types +#[derive(Debug, Clone, Copy, PartialEq)] +pub enum SpecialBackrefSymbol { + LastMatch, // $& + PreMatch, // $` + PostMatch, // $' + LastGroup, // $+ +} + +impl TryFrom for SpecialBackrefSymbol { + type Error = String; + + fn try_from(value: u8) -> Result { + match value as char { + '&' => Ok(SpecialBackrefSymbol::LastMatch), + '`' => Ok(SpecialBackrefSymbol::PreMatch), + '\'' => Ok(SpecialBackrefSymbol::PostMatch), + '+' => Ok(SpecialBackrefSymbol::LastGroup), + c => Err(format!("invalid backref symbol: '{}'", c)), + } + } +} + /// Print adaptor for [`Const`]. See [`PtrPrintMap`]. struct ConstPrinter<'a> { inner: &'a Const, @@ -415,6 +438,7 @@ pub enum SideExitReason { PatchPoint(Invariant), CalleeSideExit, ObjToStringFallback, + UnknownSpecialVariable(u64), } impl std::fmt::Display for SideExitReason { @@ -494,6 +518,8 @@ pub enum Insn { GetLocal { level: u32, ep_offset: u32 }, /// Set a local variable in a higher scope or the heap SetLocal { level: u32, ep_offset: u32, val: InsnId }, + GetSpecialSymbol { symbol_type: SpecialBackrefSymbol, state: InsnId }, + GetSpecialNumber { nth: u64, state: InsnId }, /// Own a FrameState so that instructions can look up their dominating FrameState when /// generating deopt side-exits and frame reconstruction metadata. Does not directly generate @@ -774,6 +800,8 @@ impl<'a> std::fmt::Display for InsnPrinter<'a> { Insn::SetGlobal { id, val, .. } => write!(f, "SetGlobal :{}, {val}", id.contents_lossy()), Insn::GetLocal { level, ep_offset } => write!(f, "GetLocal l{level}, EP@{ep_offset}"), Insn::SetLocal { val, level, ep_offset } => write!(f, "SetLocal l{level}, EP@{ep_offset}, {val}"), + Insn::GetSpecialSymbol { symbol_type, .. } => write!(f, "GetSpecialSymbol {symbol_type:?}"), + Insn::GetSpecialNumber { nth, .. } => write!(f, "GetSpecialNumber {nth}"), Insn::ToArray { val, .. } => write!(f, "ToArray {val}"), Insn::ToNewArray { val, .. } => write!(f, "ToNewArray {val}"), Insn::ArrayExtend { left, right, .. } => write!(f, "ArrayExtend {left}, {right}"), @@ -1221,6 +1249,8 @@ impl Function { &GetIvar { self_val, id, state } => GetIvar { self_val: find!(self_val), id, state }, &SetIvar { self_val, id, val, state } => SetIvar { self_val: find!(self_val), id, val: find!(val), state }, &SetLocal { val, ep_offset, level } => SetLocal { val: find!(val), ep_offset, level }, + &GetSpecialSymbol { symbol_type, state } => GetSpecialSymbol { symbol_type, state }, + &GetSpecialNumber { nth, state } => GetSpecialNumber { nth, state }, &ToArray { val, state } => ToArray { val: find!(val), state }, &ToNewArray { val, state } => ToNewArray { val: find!(val), state }, &ArrayExtend { left, right, state } => ArrayExtend { left: find!(left), right: find!(right), state }, @@ -1306,6 +1336,8 @@ impl Function { Insn::ArrayMax { .. } => types::BasicObject, Insn::GetGlobal { .. } => types::BasicObject, Insn::GetIvar { .. } => types::BasicObject, + Insn::GetSpecialSymbol { .. } => types::BasicObject, + Insn::GetSpecialNumber { .. } => types::BasicObject, Insn::ToNewArray { .. } => types::ArrayExact, Insn::ToArray { .. } => types::ArrayExact, Insn::ObjToString { .. } => types::BasicObject, @@ -1995,6 +2027,8 @@ impl Function { worklist.push_back(state); } &Insn::GetGlobal { state, .. } | + &Insn::GetSpecialSymbol { state, .. } | + &Insn::GetSpecialNumber { state, .. } | &Insn::SideExit { state, .. } => worklist.push_back(state), } } @@ -3323,6 +3357,29 @@ pub fn iseq_to_hir(iseq: *const rb_iseq_t) -> Result { let anytostring = fun.push_insn(block, Insn::AnyToString { val, str, state: exit_id }); state.stack_push(anytostring); } + YARVINSN_getspecial => { + let key = get_arg(pc, 0).as_u64(); + let svar = get_arg(pc, 1).as_u64(); + + let exit_id = fun.push_insn(block, Insn::Snapshot { state: exit_state }); + + if svar == 0 { + // TODO: Handle non-backref + fun.push_insn(block, Insn::SideExit { state: exit_id, reason: SideExitReason::UnknownSpecialVariable(key) }); + // End the block + break; + } else if svar & 0x01 != 0 { + // Handle symbol backrefs like $&, $`, $', $+ + let shifted_svar: u8 = (svar >> 1).try_into().unwrap(); + let symbol_type = SpecialBackrefSymbol::try_from(shifted_svar).expect("invalid backref symbol"); + let result = fun.push_insn(block, Insn::GetSpecialSymbol { symbol_type, state: exit_id }); + state.stack_push(result); + } else { + // Handle number backrefs like $1, $2, $3 + let result = fun.push_insn(block, Insn::GetSpecialNumber { nth: svar, state: exit_id }); + state.stack_push(result); + } + } _ => { // Unknown opcode; side-exit into the interpreter let exit_id = fun.push_insn(block, Insn::Snapshot { state: exit_state });