Implement branch stub

This commit is contained in:
Takashi Kokubun 2023-01-07 13:21:14 -08:00
parent eddec7bc20
commit 62d36dd127
10 changed files with 256 additions and 41 deletions

View file

@ -1,6 +1,7 @@
require 'ruby_vm/mjit/assembler'
require 'ruby_vm/mjit/block'
require 'ruby_vm/mjit/block_stub'
require 'ruby_vm/mjit/branch_stub'
require 'ruby_vm/mjit/code_block'
require 'ruby_vm/mjit/context'
require 'ruby_vm/mjit/exit_compiler'
@ -61,29 +62,29 @@ module RubyVM::MJIT
$stderr.puts e.full_message # TODO: check verbose
end
# Continue compilation from a stub.
# @param stub [RubyVM::MJIT::BlockStub]
# Continue compilation from a block stub.
# @param block_stub [RubyVM::MJIT::BlockStub]
# @param cfp `RubyVM::MJIT::CPointer::Struct_rb_control_frame_t`
# @return [Integer] The starting address of a compiled stub
def stub_hit(stub, cfp)
# @return [Integer] The starting address of the compiled block stub
def block_stub_hit(block_stub, cfp)
# Update cfp->pc for `jit.at_current_insn?`
cfp.pc = stub.pc
cfp.pc = block_stub.pc
# Prepare the jump target
new_asm = Assembler.new.tap do |asm|
jit = JITState.new(iseq: stub.iseq, cfp:)
compile_block(asm, jit:, pc: stub.pc, ctx: stub.ctx)
jit = JITState.new(iseq: block_stub.iseq, cfp:)
compile_block(asm, jit:, pc: block_stub.pc, ctx: block_stub.ctx)
end
# Rewrite the stub
if @cb.write_addr == stub.end_addr
# If the stub jump is the last code, overwrite the jump with the new code.
@cb.set_write_addr(stub.start_addr)
# Rewrite the block stub
if @cb.write_addr == block_stub.end_addr
# If the block stub's jump is the last code, overwrite the jump with the new code.
@cb.set_write_addr(block_stub.start_addr)
@cb.write(new_asm)
else
# If the stub jump is old code, change the jump target to the new code.
# If the block stub's jump is old code, change the jump target to the new code.
new_addr = @cb.write(new_asm)
@cb.with_write_addr(stub.start_addr) do
@cb.with_write_addr(block_stub.start_addr) do
asm = Assembler.new
asm.comment('regenerate block stub')
asm.jmp(new_addr)
@ -92,6 +93,74 @@ module RubyVM::MJIT
end
end
# Compile a branch stub.
# @param branch_stub [RubyVM::MJIT::BranchStub]
# @param cfp `RubyVM::MJIT::CPointer::Struct_rb_control_frame_t`
# @param branch_target_p [TrueClass,FalseClass]
# @return [Integer] The starting address of the compiled branch stub
def branch_stub_hit(branch_stub, cfp, branch_target_p)
# Update cfp->pc for `jit.at_current_insn?`
pc = branch_target_p ? branch_stub.branch_target_pc : branch_stub.fallthrough_pc
cfp.pc = pc
# Prepare the jump target
new_asm = Assembler.new.tap do |asm|
jit = JITState.new(iseq: branch_stub.iseq, cfp:)
compile_block(asm, jit:, pc:, ctx: branch_stub.ctx.dup)
end
# Rewrite the branch stub
if @cb.write_addr == branch_stub.end_addr
# If the branch stub's jump is the last code, overwrite the jump with the new code.
@cb.set_write_addr(branch_stub.start_addr)
Assembler.new.tap do |branch_asm|
if branch_target_p
branch_stub.branch_target_next.call(branch_asm)
else
branch_stub.fallthrough_next.call(branch_asm)
end
@cb.write(branch_asm)
end
# Compile a fallthrough over the jump
if branch_target_p
branch_stub.branch_target_addr = @cb.write(new_asm)
else
branch_stub.fallthrough_addr = @cb.write(new_asm)
end
else
# Otherwise, just prepare the new code somewhere
if branch_target_p
unless @cb.include?(branch_stub.branch_target_addr)
branch_stub.branch_target_addr = @cb.write(new_asm)
end
else
unless @cb.include?(branch_stub.fallthrough_addr)
branch_stub.fallthrough_addr = @cb.write(new_asm)
end
end
# Update jump destinations
branch_asm = Assembler.new
if branch_stub.end_addr == branch_stub.branch_target_addr # branch_target_next has been used
branch_stub.branch_target_next.call(branch_asm)
elsif branch_stub.end_addr == branch_stub.fallthrough_addr # fallthrough_next has been used
branch_stub.fallthrough_next.call(branch_asm)
else
branch_stub.neither_next.call(branch_asm)
end
@cb.with_write_addr(branch_stub.start_addr) do
@cb.write(branch_asm)
end
end
if branch_target_p
branch_stub.branch_target_addr
else
branch_stub.fallthrough_addr
end
end
private
# Callee-saved: rbx, rsp, rbp, r12, r13, r14, r15
@ -127,15 +196,18 @@ module RubyVM::MJIT
insn = self.class.decode_insn(iseq.body.iseq_encoded[index])
jit.pc = (iseq.body.iseq_encoded + index).to_i
case @insn_compiler.compile(jit, ctx, asm, insn)
case status = @insn_compiler.compile(jit, ctx, asm, insn)
when KeepCompiling
index += insn.len
when EndBlock
# TODO: pad nops if entry exit exists
break
when CantCompile
@exit_compiler.compile_side_exit(jit, ctx, asm)
break
else
raise "compiling #{insn.name} returned unexpected status: #{status.inspect}"
end
index += insn.len
end
end
end