ZJIT: Implement concatstrings insn (#14154)

Co-authored-by: Alexander Momchilov <alexander.momchilov@shopify.com>
This commit is contained in:
Stan Lo 2025-08-11 23:07:26 +01:00 committed by GitHub
parent 4f34eddbd3
commit e29d333454
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 161 additions and 38 deletions

View file

@ -1506,6 +1506,22 @@ class TestZJIT < Test::Unit::TestCase
}, call_threshold: 2
end
def test_string_concat
assert_compiles '"123"', %q{
def test = "#{1}#{2}#{3}"
test
}, insns: [:concatstrings]
end
def test_string_concat_empty
assert_compiles '""', %q{
def test = "#{}"
test
}, insns: [:concatstrings]
end
private
# Assert that every method call in `test_script` can be compiled by ZJIT

View file

@ -2233,6 +2233,12 @@ impl Assembler {
out
}
pub fn sub_into(&mut self, left: Opnd, right: Opnd) -> Opnd {
let out = self.sub(left, right);
self.mov(left, out);
out
}
#[must_use]
pub fn mul(&mut self, left: Opnd, right: Opnd) -> Opnd {
let out = self.new_vreg(Opnd::match_num_bits(&[left, right]));

View file

@ -330,6 +330,7 @@ fn gen_insn(cb: &mut CodeBlock, jit: &mut JITState, asm: &mut Assembler, functio
Insn::NewRange { low, high, flag, state } => gen_new_range(asm, opnd!(low), opnd!(high), *flag, &function.frame_state(*state)),
Insn::ArrayDup { val, state } => gen_array_dup(asm, opnd!(val), &function.frame_state(*state)),
Insn::StringCopy { val, chilled, state } => gen_string_copy(asm, opnd!(val), *chilled, &function.frame_state(*state)),
Insn::StringConcat { strings, state } => gen_string_concat(jit, asm, opnds!(strings), &function.frame_state(*state))?,
Insn::Param { idx } => unreachable!("block.insns should not have Insn::Param({idx})"),
Insn::Snapshot { .. } => return Some(()), // we don't need to do anything for this instruction at the moment
Insn::Jump(branch) => return gen_jump(jit, asm, branch),
@ -1456,6 +1457,56 @@ pub fn gen_stub_exit(cb: &mut CodeBlock) -> Option<CodePtr> {
})
}
fn gen_string_concat(jit: &mut JITState, asm: &mut Assembler, strings: Vec<Opnd>, state: &FrameState) -> Option<Opnd> {
let n = strings.len();
// concatstrings shouldn't have 0 strings
// If it happens we abort the compilation for now
if n == 0 {
return None;
}
gen_prepare_non_leaf_call(jit, asm, state)?;
// Calculate the compile-time NATIVE_STACK_PTR offset from NATIVE_BASE_PTR
// At this point, frame_setup(&[], jit.c_stack_slots) has been called,
// which allocated aligned_stack_bytes(jit.c_stack_slots) on the stack
let frame_size = aligned_stack_bytes(jit.c_stack_slots);
let allocation_size = aligned_stack_bytes(n);
asm_comment!(asm, "allocate {} bytes on C stack for {} strings", allocation_size, n);
asm.sub_into(NATIVE_STACK_PTR, allocation_size.into());
// Calculate the total offset from NATIVE_BASE_PTR to our buffer
let total_offset_from_base = (frame_size + allocation_size) as i32;
for (idx, &string_opnd) in strings.iter().enumerate() {
let slot_offset = -total_offset_from_base + (idx as i32 * SIZEOF_VALUE_I32);
asm.mov(
Opnd::mem(VALUE_BITS, NATIVE_BASE_PTR, slot_offset),
string_opnd
);
}
let first_string_ptr = asm.lea(Opnd::mem(64, NATIVE_BASE_PTR, -total_offset_from_base));
let result = asm_ccall!(asm, rb_str_concat_literals, n.into(), first_string_ptr);
asm_comment!(asm, "restore C stack pointer");
asm.add_into(NATIVE_STACK_PTR, allocation_size.into());
Some(result)
}
/// Given the number of spill slots needed for a function, return the number of bytes
/// the function needs to allocate on the stack for the stack frame.
fn aligned_stack_bytes(num_slots: usize) -> usize {
// Both x86_64 and arm64 require the stack to be aligned to 16 bytes.
// Since SIZEOF_VALUE is 8 bytes, we need to round up the size to the nearest even number.
let num_slots = num_slots + (num_slots % 2);
num_slots * SIZEOF_VALUE
}
impl Assembler {
/// Make a C call while marking the start and end positions of it
fn ccall_with_branch(&mut self, fptr: *const u8, opnds: Vec<Opnd>, branch: &Rc<Branch>) -> Opnd {

View file

@ -445,6 +445,7 @@ pub enum Insn {
StringCopy { val: InsnId, chilled: bool, state: InsnId },
StringIntern { val: InsnId },
StringConcat { strings: Vec<InsnId>, state: InsnId },
/// Put special object (VMCORE, CBASE, etc.) based on value_type
PutSpecialObject { value_type: SpecialObjectType },
@ -675,6 +676,16 @@ impl<'a> std::fmt::Display for InsnPrinter<'a> {
Insn::ArrayDup { val, .. } => { write!(f, "ArrayDup {val}") }
Insn::HashDup { val, .. } => { write!(f, "HashDup {val}") }
Insn::StringCopy { val, .. } => { write!(f, "StringCopy {val}") }
Insn::StringConcat { strings, .. } => {
write!(f, "StringConcat")?;
let mut prefix = " ";
for string in strings {
write!(f, "{prefix}{string}")?;
prefix = ", ";
}
Ok(())
}
Insn::Test { val } => { write!(f, "Test {val}") }
Insn::IsNil { val } => { write!(f, "IsNil {val}") }
Insn::Jump(target) => { write!(f, "Jump {target}") }
@ -1135,6 +1146,7 @@ impl Function {
&Throw { throw_state, val } => Throw { throw_state, val: find!(val) },
&StringCopy { val, chilled, state } => StringCopy { val: find!(val), chilled, state },
&StringIntern { val } => StringIntern { val: find!(val) },
&StringConcat { ref strings, state } => StringConcat { strings: find_vec!(strings), state: find!(state) },
&Test { val } => Test { val: find!(val) },
&IsNil { val } => IsNil { val: find!(val) },
&Jump(ref target) => Jump(find_branch_edge!(target)),
@ -1258,6 +1270,7 @@ impl Function {
Insn::IsNil { .. } => types::CBool,
Insn::StringCopy { .. } => types::StringExact,
Insn::StringIntern { .. } => types::StringExact,
Insn::StringConcat { .. } => types::StringExact,
Insn::NewArray { .. } => types::ArrayExact,
Insn::ArrayDup { .. } => types::ArrayExact,
Insn::NewHash { .. } => types::HashExact,
@ -1887,6 +1900,10 @@ impl Function {
worklist.push_back(high);
worklist.push_back(state);
}
&Insn::StringConcat { ref strings, state, .. } => {
worklist.extend(strings);
worklist.push_back(state);
}
| &Insn::StringIntern { val }
| &Insn::Return { val }
| &Insn::Throw { val, .. }
@ -2469,6 +2486,16 @@ impl FrameState {
self.stack.pop().ok_or_else(|| ParseError::StackUnderflow(self.clone()))
}
fn stack_pop_n(&mut self, count: usize) -> Result<Vec<InsnId>, ParseError> {
// Check if we have enough values on the stack
let stack_len = self.stack.len();
if stack_len < count {
return Err(ParseError::StackUnderflow(self.clone()));
}
Ok(self.stack.split_off(stack_len - count))
}
/// Get a stack-top operand
fn stack_top(&self) -> Result<InsnId, ParseError> {
self.stack.last().ok_or_else(|| ParseError::StackUnderflow(self.clone())).copied()
@ -2789,24 +2816,23 @@ pub fn iseq_to_hir(iseq: *const rb_iseq_t) -> Result<Function, ParseError> {
let insn_id = fun.push_insn(block, Insn::StringIntern { val });
state.stack_push(insn_id);
}
YARVINSN_concatstrings => {
let count = get_arg(pc, 0).as_u32();
let exit_id = fun.push_insn(block, Insn::Snapshot { state: exit_state });
let strings = state.stack_pop_n(count as usize)?;
let insn_id = fun.push_insn(block, Insn::StringConcat { strings, state: exit_id });
state.stack_push(insn_id);
}
YARVINSN_newarray => {
let count = get_arg(pc, 0).as_usize();
let exit_id = fun.push_insn(block, Insn::Snapshot { state: exit_state });
let mut elements = vec![];
for _ in 0..count {
elements.push(state.stack_pop()?);
}
elements.reverse();
let elements = state.stack_pop_n(count)?;
state.stack_push(fun.push_insn(block, Insn::NewArray { elements, state: exit_id }));
}
YARVINSN_opt_newarray_send => {
let count = get_arg(pc, 0).as_usize();
let method = get_arg(pc, 1).as_u32();
let mut elements = vec![];
for _ in 0..count {
elements.push(state.stack_pop()?);
}
elements.reverse();
let elements = state.stack_pop_n(count)?;
let exit_id = fun.push_insn(block, Insn::Snapshot { state: exit_state });
let (bop, insn) = match method {
VM_OPT_NEWARRAY_SEND_MAX => (BOP_MAX, Insn::ArrayMax { elements, state: exit_id }),
@ -2871,13 +2897,10 @@ pub fn iseq_to_hir(iseq: *const rb_iseq_t) -> Result<Function, ParseError> {
}
YARVINSN_pushtoarray => {
let count = get_arg(pc, 0).as_usize();
let mut vals = vec![];
for _ in 0..count {
vals.push(state.stack_pop()?);
}
let vals = state.stack_pop_n(count)?;
let array = state.stack_pop()?;
let exit_id = fun.push_insn(block, Insn::Snapshot { state: exit_state });
for val in vals.into_iter().rev() {
for val in vals.into_iter() {
fun.push_insn(block, Insn::ArrayPush { array, val, state: exit_id });
}
state.stack_push(array);
@ -3079,12 +3102,7 @@ pub fn iseq_to_hir(iseq: *const rb_iseq_t) -> Result<Function, ParseError> {
}
let argc = unsafe { vm_ci_argc((*cd).ci) };
let mut args = vec![];
for _ in 0..argc {
args.push(state.stack_pop()?);
}
args.reverse();
let args = state.stack_pop_n(argc as usize)?;
let recv = state.stack_pop()?;
let exit_id = fun.push_insn(block, Insn::Snapshot { state: exit_state });
let send = fun.push_insn(block, Insn::SendWithoutBlock { self_val: recv, cd, args, state: exit_id });
@ -3160,12 +3178,7 @@ pub fn iseq_to_hir(iseq: *const rb_iseq_t) -> Result<Function, ParseError> {
}
let argc = unsafe { vm_ci_argc((*cd).ci) };
let mut args = vec![];
for _ in 0..argc {
args.push(state.stack_pop()?);
}
args.reverse();
let args = state.stack_pop_n(argc as usize)?;
let recv = state.stack_pop()?;
let exit_id = fun.push_insn(block, Insn::Snapshot { state: exit_state });
let send = fun.push_insn(block, Insn::SendWithoutBlock { self_val: recv, cd, args, state: exit_id });
@ -3183,12 +3196,7 @@ pub fn iseq_to_hir(iseq: *const rb_iseq_t) -> Result<Function, ParseError> {
}
let argc = unsafe { vm_ci_argc((*cd).ci) };
let mut args = vec![];
for _ in 0..argc {
args.push(state.stack_pop()?);
}
args.reverse();
let args = state.stack_pop_n(argc as usize)?;
let recv = state.stack_pop()?;
let exit_id = fun.push_insn(block, Insn::Snapshot { state: exit_state });
let send = fun.push_insn(block, Insn::Send { self_val: recv, cd, blockiseq, args, state: exit_id });
@ -5224,7 +5232,47 @@ mod tests {
v3:Fixnum[1] = Const Value(1)
v5:BasicObject = ObjToString v3
v7:String = AnyToString v3, str: v5
SideExit UnknownOpcode(concatstrings)
v9:StringExact = StringConcat v2, v7
Return v9
"#]]);
}
#[test]
fn test_string_concat() {
eval(r##"
def test = "#{1}#{2}#{3}"
"##);
assert_method_hir_with_opcode("test", YARVINSN_concatstrings, expect![[r#"
fn test@<compiled>:2:
bb0(v0:BasicObject):
v2:Fixnum[1] = Const Value(1)
v4:BasicObject = ObjToString v2
v6:String = AnyToString v2, str: v4
v7:Fixnum[2] = Const Value(2)
v9:BasicObject = ObjToString v7
v11:String = AnyToString v7, str: v9
v12:Fixnum[3] = Const Value(3)
v14:BasicObject = ObjToString v12
v16:String = AnyToString v12, str: v14
v18:StringExact = StringConcat v6, v11, v16
Return v18
"#]]);
}
#[test]
fn test_string_concat_empty() {
eval(r##"
def test = "#{}"
"##);
assert_method_hir_with_opcode("test", YARVINSN_concatstrings, expect![[r#"
fn test@<compiled>:2:
bb0(v0:BasicObject):
v2:StringExact[VALUE(0x1000)] = Const Value(VALUE(0x1000))
v3:NilClass = Const Value(nil)
v5:BasicObject = ObjToString v3
v7:String = AnyToString v3, str: v5
v9:StringExact = StringConcat v2, v7
Return v9
"#]]);
}
@ -7172,7 +7220,8 @@ mod opt_tests {
v2:StringExact[VALUE(0x1000)] = Const Value(VALUE(0x1000))
v3:StringExact[VALUE(0x1008)] = Const Value(VALUE(0x1008))
v5:StringExact = StringCopy v3
SideExit UnknownOpcode(concatstrings)
v11:StringExact = StringConcat v2, v5
Return v11
"#]]);
}
@ -7186,9 +7235,10 @@ mod opt_tests {
bb0(v0:BasicObject):
v2:StringExact[VALUE(0x1000)] = Const Value(VALUE(0x1000))
v3:Fixnum[1] = Const Value(1)
v10:BasicObject = SendWithoutBlock v3, :to_s
v7:String = AnyToString v3, str: v10
SideExit UnknownOpcode(concatstrings)
v11:BasicObject = SendWithoutBlock v3, :to_s
v7:String = AnyToString v3, str: v11
v9:StringExact = StringConcat v2, v7
Return v9
"#]]);
}