ZJIT: Remove the need for unwrap() on with_num_bits() (#14144)

* ZJIT: Remove the need for unwrap() on with_num_bits()

* Fix arm64 tests

* Track the caller of with_num_bits

Co-authored-by: Alan Wu <XrXr@users.noreply.github.com>

---------

Co-authored-by: Alan Wu <XrXr@users.noreply.github.com>
This commit is contained in:
Takashi Kokubun 2025-08-07 16:56:27 -07:00 committed by GitHub
parent 2edc944702
commit 4fef87588a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 26 additions and 13 deletions

View file

@ -256,7 +256,7 @@ impl Assembler
// Many Arm insns support only 32-bit or 64-bit operands. asm.load with fewer
// bits zero-extends the value, so it's safe to recognize it as a 32-bit value.
if out_opnd.rm_num_bits() < 32 {
out_opnd.with_num_bits(32).unwrap()
out_opnd.with_num_bits(32)
} else {
out_opnd
}
@ -282,7 +282,7 @@ impl Assembler
BitmaskImmediate::new_32b_reg(imm as u32).is_ok()) {
Opnd::UImm(imm as u64)
} else {
asm.load(opnd).with_num_bits(dest_num_bits).unwrap()
asm.load(opnd).with_num_bits(dest_num_bits)
}
},
Opnd::UImm(uimm) => {
@ -292,7 +292,7 @@ impl Assembler
BitmaskImmediate::new_32b_reg(uimm as u32).is_ok()) {
opnd
} else {
asm.load(opnd).with_num_bits(dest_num_bits).unwrap()
asm.load(opnd).with_num_bits(dest_num_bits)
}
},
Opnd::None | Opnd::Value(_) => unreachable!()
@ -360,8 +360,8 @@ impl Assembler
match opnd0 {
Opnd::Reg(_) | Opnd::VReg { .. } => {
match opnd0.rm_num_bits() {
8 => asm.and(opnd0.with_num_bits(64).unwrap(), Opnd::UImm(0xff)),
16 => asm.and(opnd0.with_num_bits(64).unwrap(), Opnd::UImm(0xffff)),
8 => asm.and(opnd0.with_num_bits(64), Opnd::UImm(0xff)),
16 => asm.and(opnd0.with_num_bits(64), Opnd::UImm(0xffff)),
32 | 64 => opnd0,
bits => unreachable!("Invalid number of bits. {}", bits)
}
@ -505,7 +505,7 @@ impl Assembler
let split_right = split_shifted_immediate(asm, *right);
let opnd1 = match split_right {
Opnd::VReg { .. } if opnd0.num_bits() != split_right.num_bits() => {
split_right.with_num_bits(opnd0.num_bits().unwrap()).unwrap()
split_right.with_num_bits(opnd0.num_bits().unwrap())
},
_ => split_right
};
@ -1823,7 +1823,7 @@ mod tests {
#[test]
fn test_emit_test_32b_reg_not_bitmask_imm() {
let (mut asm, mut cb) = setup_asm();
let w0 = Opnd::Reg(X0_REG).with_num_bits(32).unwrap();
let w0 = Opnd::Reg(X0_REG).with_num_bits(32);
asm.test(w0, Opnd::UImm(u32::MAX.into()));
// All ones is not encodable with a bitmask immediate,
// so this needs one register
@ -1833,7 +1833,7 @@ mod tests {
#[test]
fn test_emit_test_32b_reg_bitmask_imm() {
let (mut asm, mut cb) = setup_asm();
let w0 = Opnd::Reg(X0_REG).with_num_bits(32).unwrap();
let w0 = Opnd::Reg(X0_REG).with_num_bits(32);
asm.test(w0, Opnd::UImm(0x80000001));
asm.compile_with_num_regs(&mut cb, 0);
}

View file

@ -146,17 +146,29 @@ impl Opnd
}
}
pub fn with_num_bits(&self, num_bits: u8) -> Option<Opnd> {
/// Return Some(Opnd) with a given num_bits if self has num_bits.
/// None if self doesn't have a num_bits field.
pub fn try_num_bits(&self, num_bits: u8) -> Option<Opnd> {
assert!(num_bits == 8 || num_bits == 16 || num_bits == 32 || num_bits == 64);
match *self {
Opnd::Reg(reg) => Some(Opnd::Reg(reg.with_num_bits(num_bits))),
Opnd::Mem(Mem { base, disp, .. }) => Some(Opnd::Mem(Mem { base, disp, num_bits })),
Opnd::VReg { idx, .. } => Some(Opnd::VReg { idx, num_bits }),
//Opnd::Stack { idx, stack_size, num_locals, sp_offset, reg_mapping, .. } => Some(Opnd::Stack { idx, num_bits, stack_size, num_locals, sp_offset, reg_mapping }),
_ => None,
}
}
/// Return Opnd with a given num_bits if self has num_bits.
/// Panic otherwise. This should be used only when you know which Opnd self is.
#[track_caller]
pub fn with_num_bits(&self, num_bits: u8) -> Opnd {
if let Some(opnd) = self.try_num_bits(num_bits) {
opnd
} else {
unreachable!("with_num_bits should not be used on: {self:?}");
}
}
/// Get the size in bits for register/memory operands.
pub fn rm_num_bits(&self) -> u8 {
self.num_bits().unwrap()
@ -1720,7 +1732,7 @@ impl Assembler
while let Some(opnd) = opnd_iter.next() {
match *opnd {
Opnd::VReg { idx, num_bits } => {
*opnd = Opnd::Reg(reg_mapping[idx].unwrap()).with_num_bits(num_bits).unwrap();
*opnd = Opnd::Reg(reg_mapping[idx].unwrap()).with_num_bits(num_bits);
},
Opnd::Mem(Mem { base: MemBase::VReg(idx), disp, num_bits }) => {
let base = MemBase::Reg(reg_mapping[idx].unwrap().reg_no);

View file

@ -463,7 +463,7 @@ fn gen_defined(jit: &JITState, asm: &mut Assembler, op_type: usize, obj: VALUE,
// Call vm_defined(ec, reg_cfp, op_type, obj, v)
let def_result = asm_ccall!(asm, rb_vm_defined, EC, CFP, op_type.into(), obj.into(), tested_value);
asm.cmp(def_result.with_num_bits(8).unwrap(), 0.into());
asm.cmp(def_result.with_num_bits(8), 0.into());
Some(asm.csel_ne(pushval.into(), Qnil.into()))
}
}
@ -1070,7 +1070,8 @@ fn gen_guard_type(jit: &mut JITState, asm: &mut Assembler, val: lir::Opnd, guard
} else if guard_type.is_subtype(types::StaticSymbol) {
// Static symbols have (val & 0xff) == RUBY_SYMBOL_FLAG
// Use 8-bit comparison like YJIT does
asm.cmp(val.with_num_bits(8).unwrap(), Opnd::UImm(RUBY_SYMBOL_FLAG as u64));
debug_assert!(val.try_num_bits(8).is_some(), "GuardType should not be used for a known constant, but val was: {val:?}");
asm.cmp(val.try_num_bits(8)?, Opnd::UImm(RUBY_SYMBOL_FLAG as u64));
asm.jne(side_exit(jit, state, GuardType(guard_type))?);
} else if guard_type.is_subtype(types::NilClass) {
asm.cmp(val, Qnil.into());