diff --git a/test/ruby/test_zjit.rb b/test/ruby/test_zjit.rb index bf43fd1324..4dc0919b6b 100644 --- a/test/ruby/test_zjit.rb +++ b/test/ruby/test_zjit.rb @@ -1506,6 +1506,22 @@ class TestZJIT < Test::Unit::TestCase }, call_threshold: 2 end + def test_string_concat + assert_compiles '"123"', %q{ + def test = "#{1}#{2}#{3}" + + test + }, insns: [:concatstrings] + end + + def test_string_concat_empty + assert_compiles '""', %q{ + def test = "#{}" + + test + }, insns: [:concatstrings] + end + private # Assert that every method call in `test_script` can be compiled by ZJIT diff --git a/zjit/src/backend/lir.rs b/zjit/src/backend/lir.rs index 0902d347c7..86bea62fcd 100644 --- a/zjit/src/backend/lir.rs +++ b/zjit/src/backend/lir.rs @@ -2233,6 +2233,12 @@ impl Assembler { out } + pub fn sub_into(&mut self, left: Opnd, right: Opnd) -> Opnd { + let out = self.sub(left, right); + self.mov(left, out); + out + } + #[must_use] pub fn mul(&mut self, left: Opnd, right: Opnd) -> Opnd { let out = self.new_vreg(Opnd::match_num_bits(&[left, right])); diff --git a/zjit/src/codegen.rs b/zjit/src/codegen.rs index f9532dfe03..b1b43abbe6 100644 --- a/zjit/src/codegen.rs +++ b/zjit/src/codegen.rs @@ -330,6 +330,7 @@ fn gen_insn(cb: &mut CodeBlock, jit: &mut JITState, asm: &mut Assembler, functio Insn::NewRange { low, high, flag, state } => gen_new_range(asm, opnd!(low), opnd!(high), *flag, &function.frame_state(*state)), Insn::ArrayDup { val, state } => gen_array_dup(asm, opnd!(val), &function.frame_state(*state)), Insn::StringCopy { val, chilled, state } => gen_string_copy(asm, opnd!(val), *chilled, &function.frame_state(*state)), + Insn::StringConcat { strings, state } => gen_string_concat(jit, asm, opnds!(strings), &function.frame_state(*state))?, Insn::Param { idx } => unreachable!("block.insns should not have Insn::Param({idx})"), Insn::Snapshot { .. } => return Some(()), // we don't need to do anything for this instruction at the moment Insn::Jump(branch) => return gen_jump(jit, asm, branch), @@ -1456,6 +1457,56 @@ pub fn gen_stub_exit(cb: &mut CodeBlock) -> Option { }) } +fn gen_string_concat(jit: &mut JITState, asm: &mut Assembler, strings: Vec, state: &FrameState) -> Option { + let n = strings.len(); + + // concatstrings shouldn't have 0 strings + // If it happens we abort the compilation for now + if n == 0 { + return None; + } + + gen_prepare_non_leaf_call(jit, asm, state)?; + + // Calculate the compile-time NATIVE_STACK_PTR offset from NATIVE_BASE_PTR + // At this point, frame_setup(&[], jit.c_stack_slots) has been called, + // which allocated aligned_stack_bytes(jit.c_stack_slots) on the stack + let frame_size = aligned_stack_bytes(jit.c_stack_slots); + let allocation_size = aligned_stack_bytes(n); + + asm_comment!(asm, "allocate {} bytes on C stack for {} strings", allocation_size, n); + asm.sub_into(NATIVE_STACK_PTR, allocation_size.into()); + + // Calculate the total offset from NATIVE_BASE_PTR to our buffer + let total_offset_from_base = (frame_size + allocation_size) as i32; + + for (idx, &string_opnd) in strings.iter().enumerate() { + let slot_offset = -total_offset_from_base + (idx as i32 * SIZEOF_VALUE_I32); + asm.mov( + Opnd::mem(VALUE_BITS, NATIVE_BASE_PTR, slot_offset), + string_opnd + ); + } + + let first_string_ptr = asm.lea(Opnd::mem(64, NATIVE_BASE_PTR, -total_offset_from_base)); + + let result = asm_ccall!(asm, rb_str_concat_literals, n.into(), first_string_ptr); + + asm_comment!(asm, "restore C stack pointer"); + asm.add_into(NATIVE_STACK_PTR, allocation_size.into()); + + Some(result) +} + +/// Given the number of spill slots needed for a function, return the number of bytes +/// the function needs to allocate on the stack for the stack frame. +fn aligned_stack_bytes(num_slots: usize) -> usize { + // Both x86_64 and arm64 require the stack to be aligned to 16 bytes. + // Since SIZEOF_VALUE is 8 bytes, we need to round up the size to the nearest even number. + let num_slots = num_slots + (num_slots % 2); + num_slots * SIZEOF_VALUE +} + impl Assembler { /// Make a C call while marking the start and end positions of it fn ccall_with_branch(&mut self, fptr: *const u8, opnds: Vec, branch: &Rc) -> Opnd { diff --git a/zjit/src/hir.rs b/zjit/src/hir.rs index bff0fcd757..5111ab30f9 100644 --- a/zjit/src/hir.rs +++ b/zjit/src/hir.rs @@ -445,6 +445,7 @@ pub enum Insn { StringCopy { val: InsnId, chilled: bool, state: InsnId }, StringIntern { val: InsnId }, + StringConcat { strings: Vec, state: InsnId }, /// Put special object (VMCORE, CBASE, etc.) based on value_type PutSpecialObject { value_type: SpecialObjectType }, @@ -675,6 +676,16 @@ impl<'a> std::fmt::Display for InsnPrinter<'a> { Insn::ArrayDup { val, .. } => { write!(f, "ArrayDup {val}") } Insn::HashDup { val, .. } => { write!(f, "HashDup {val}") } Insn::StringCopy { val, .. } => { write!(f, "StringCopy {val}") } + Insn::StringConcat { strings, .. } => { + write!(f, "StringConcat")?; + let mut prefix = " "; + for string in strings { + write!(f, "{prefix}{string}")?; + prefix = ", "; + } + + Ok(()) + } Insn::Test { val } => { write!(f, "Test {val}") } Insn::IsNil { val } => { write!(f, "IsNil {val}") } Insn::Jump(target) => { write!(f, "Jump {target}") } @@ -1135,6 +1146,7 @@ impl Function { &Throw { throw_state, val } => Throw { throw_state, val: find!(val) }, &StringCopy { val, chilled, state } => StringCopy { val: find!(val), chilled, state }, &StringIntern { val } => StringIntern { val: find!(val) }, + &StringConcat { ref strings, state } => StringConcat { strings: find_vec!(strings), state: find!(state) }, &Test { val } => Test { val: find!(val) }, &IsNil { val } => IsNil { val: find!(val) }, &Jump(ref target) => Jump(find_branch_edge!(target)), @@ -1258,6 +1270,7 @@ impl Function { Insn::IsNil { .. } => types::CBool, Insn::StringCopy { .. } => types::StringExact, Insn::StringIntern { .. } => types::StringExact, + Insn::StringConcat { .. } => types::StringExact, Insn::NewArray { .. } => types::ArrayExact, Insn::ArrayDup { .. } => types::ArrayExact, Insn::NewHash { .. } => types::HashExact, @@ -1887,6 +1900,10 @@ impl Function { worklist.push_back(high); worklist.push_back(state); } + &Insn::StringConcat { ref strings, state, .. } => { + worklist.extend(strings); + worklist.push_back(state); + } | &Insn::StringIntern { val } | &Insn::Return { val } | &Insn::Throw { val, .. } @@ -2469,6 +2486,16 @@ impl FrameState { self.stack.pop().ok_or_else(|| ParseError::StackUnderflow(self.clone())) } + fn stack_pop_n(&mut self, count: usize) -> Result, ParseError> { + // Check if we have enough values on the stack + let stack_len = self.stack.len(); + if stack_len < count { + return Err(ParseError::StackUnderflow(self.clone())); + } + + Ok(self.stack.split_off(stack_len - count)) + } + /// Get a stack-top operand fn stack_top(&self) -> Result { self.stack.last().ok_or_else(|| ParseError::StackUnderflow(self.clone())).copied() @@ -2789,24 +2816,23 @@ pub fn iseq_to_hir(iseq: *const rb_iseq_t) -> Result { let insn_id = fun.push_insn(block, Insn::StringIntern { val }); state.stack_push(insn_id); } + YARVINSN_concatstrings => { + let count = get_arg(pc, 0).as_u32(); + let exit_id = fun.push_insn(block, Insn::Snapshot { state: exit_state }); + let strings = state.stack_pop_n(count as usize)?; + let insn_id = fun.push_insn(block, Insn::StringConcat { strings, state: exit_id }); + state.stack_push(insn_id); + } YARVINSN_newarray => { let count = get_arg(pc, 0).as_usize(); let exit_id = fun.push_insn(block, Insn::Snapshot { state: exit_state }); - let mut elements = vec![]; - for _ in 0..count { - elements.push(state.stack_pop()?); - } - elements.reverse(); + let elements = state.stack_pop_n(count)?; state.stack_push(fun.push_insn(block, Insn::NewArray { elements, state: exit_id })); } YARVINSN_opt_newarray_send => { let count = get_arg(pc, 0).as_usize(); let method = get_arg(pc, 1).as_u32(); - let mut elements = vec![]; - for _ in 0..count { - elements.push(state.stack_pop()?); - } - elements.reverse(); + let elements = state.stack_pop_n(count)?; let exit_id = fun.push_insn(block, Insn::Snapshot { state: exit_state }); let (bop, insn) = match method { VM_OPT_NEWARRAY_SEND_MAX => (BOP_MAX, Insn::ArrayMax { elements, state: exit_id }), @@ -2871,13 +2897,10 @@ pub fn iseq_to_hir(iseq: *const rb_iseq_t) -> Result { } YARVINSN_pushtoarray => { let count = get_arg(pc, 0).as_usize(); - let mut vals = vec![]; - for _ in 0..count { - vals.push(state.stack_pop()?); - } + let vals = state.stack_pop_n(count)?; let array = state.stack_pop()?; let exit_id = fun.push_insn(block, Insn::Snapshot { state: exit_state }); - for val in vals.into_iter().rev() { + for val in vals.into_iter() { fun.push_insn(block, Insn::ArrayPush { array, val, state: exit_id }); } state.stack_push(array); @@ -3079,12 +3102,7 @@ pub fn iseq_to_hir(iseq: *const rb_iseq_t) -> Result { } let argc = unsafe { vm_ci_argc((*cd).ci) }; - let mut args = vec![]; - for _ in 0..argc { - args.push(state.stack_pop()?); - } - args.reverse(); - + let args = state.stack_pop_n(argc as usize)?; let recv = state.stack_pop()?; let exit_id = fun.push_insn(block, Insn::Snapshot { state: exit_state }); let send = fun.push_insn(block, Insn::SendWithoutBlock { self_val: recv, cd, args, state: exit_id }); @@ -3160,12 +3178,7 @@ pub fn iseq_to_hir(iseq: *const rb_iseq_t) -> Result { } let argc = unsafe { vm_ci_argc((*cd).ci) }; - let mut args = vec![]; - for _ in 0..argc { - args.push(state.stack_pop()?); - } - args.reverse(); - + let args = state.stack_pop_n(argc as usize)?; let recv = state.stack_pop()?; let exit_id = fun.push_insn(block, Insn::Snapshot { state: exit_state }); let send = fun.push_insn(block, Insn::SendWithoutBlock { self_val: recv, cd, args, state: exit_id }); @@ -3183,12 +3196,7 @@ pub fn iseq_to_hir(iseq: *const rb_iseq_t) -> Result { } let argc = unsafe { vm_ci_argc((*cd).ci) }; - let mut args = vec![]; - for _ in 0..argc { - args.push(state.stack_pop()?); - } - args.reverse(); - + let args = state.stack_pop_n(argc as usize)?; let recv = state.stack_pop()?; let exit_id = fun.push_insn(block, Insn::Snapshot { state: exit_state }); let send = fun.push_insn(block, Insn::Send { self_val: recv, cd, blockiseq, args, state: exit_id }); @@ -5224,7 +5232,47 @@ mod tests { v3:Fixnum[1] = Const Value(1) v5:BasicObject = ObjToString v3 v7:String = AnyToString v3, str: v5 - SideExit UnknownOpcode(concatstrings) + v9:StringExact = StringConcat v2, v7 + Return v9 + "#]]); + } + + #[test] + fn test_string_concat() { + eval(r##" + def test = "#{1}#{2}#{3}" + "##); + assert_method_hir_with_opcode("test", YARVINSN_concatstrings, expect![[r#" + fn test@:2: + bb0(v0:BasicObject): + v2:Fixnum[1] = Const Value(1) + v4:BasicObject = ObjToString v2 + v6:String = AnyToString v2, str: v4 + v7:Fixnum[2] = Const Value(2) + v9:BasicObject = ObjToString v7 + v11:String = AnyToString v7, str: v9 + v12:Fixnum[3] = Const Value(3) + v14:BasicObject = ObjToString v12 + v16:String = AnyToString v12, str: v14 + v18:StringExact = StringConcat v6, v11, v16 + Return v18 + "#]]); + } + + #[test] + fn test_string_concat_empty() { + eval(r##" + def test = "#{}" + "##); + assert_method_hir_with_opcode("test", YARVINSN_concatstrings, expect![[r#" + fn test@:2: + bb0(v0:BasicObject): + v2:StringExact[VALUE(0x1000)] = Const Value(VALUE(0x1000)) + v3:NilClass = Const Value(nil) + v5:BasicObject = ObjToString v3 + v7:String = AnyToString v3, str: v5 + v9:StringExact = StringConcat v2, v7 + Return v9 "#]]); } @@ -7172,7 +7220,8 @@ mod opt_tests { v2:StringExact[VALUE(0x1000)] = Const Value(VALUE(0x1000)) v3:StringExact[VALUE(0x1008)] = Const Value(VALUE(0x1008)) v5:StringExact = StringCopy v3 - SideExit UnknownOpcode(concatstrings) + v11:StringExact = StringConcat v2, v5 + Return v11 "#]]); } @@ -7186,9 +7235,10 @@ mod opt_tests { bb0(v0:BasicObject): v2:StringExact[VALUE(0x1000)] = Const Value(VALUE(0x1000)) v3:Fixnum[1] = Const Value(1) - v10:BasicObject = SendWithoutBlock v3, :to_s - v7:String = AnyToString v3, str: v10 - SideExit UnknownOpcode(concatstrings) + v11:BasicObject = SendWithoutBlock v3, :to_s + v7:String = AnyToString v3, str: v11 + v9:StringExact = StringConcat v2, v7 + Return v9 "#]]); }