summaryrefslogtreecommitdiff
path: root/lib/ruby_vm
diff options
context:
space:
mode:
authorTakashi Kokubun <[email protected]>2023-01-07 13:21:14 -0800
committerTakashi Kokubun <[email protected]>2023-03-05 22:11:20 -0800
commit62d36dd1277bdfeac609f89bc64589e8856421b8 (patch)
tree4c90de46486efd2c8912ef970449f4744630a5ba /lib/ruby_vm
parenteddec7bc209d721e99a8cd5ceaafd0f2ab270cc3 (diff)
Implement branch stub
Diffstat (limited to 'lib/ruby_vm')
-rw-r--r--lib/ruby_vm/mjit/assembler.rb18
-rw-r--r--lib/ruby_vm/mjit/block_stub.rb2
-rw-r--r--lib/ruby_vm/mjit/branch_stub.rb14
-rw-r--r--lib/ruby_vm/mjit/code_block.rb6
-rw-r--r--lib/ruby_vm/mjit/compiler.rb102
-rw-r--r--lib/ruby_vm/mjit/exit_compiler.rb37
-rw-r--r--lib/ruby_vm/mjit/insn_compiler.rb64
-rw-r--r--lib/ruby_vm/mjit/jit_state.rb4
8 files changed, 212 insertions, 35 deletions
diff --git a/lib/ruby_vm/mjit/assembler.rb b/lib/ruby_vm/mjit/assembler.rb
index f23696244d..32668ff3c8 100644
--- a/lib/ruby_vm/mjit/assembler.rb
+++ b/lib/ruby_vm/mjit/assembler.rb
@@ -169,6 +169,10 @@ module RubyVM::MJIT
in Label => dst_label
# 74 cb
insn(opcode: 0x74, imm: dst_label)
+ # JZ rel32
+ in Integer => dst_addr
+ # 0F 84 cd
+ insn(opcode: [0x0f, 0x84], imm: rel32(dst_addr))
else
raise NotImplementedError, "jz: not-implemented operands: #{dst.inspect}"
end
@@ -338,7 +342,7 @@ module RubyVM::MJIT
def test(left, right)
case [left, right]
# TEST r/m8*, imm8 (Mod 01: [reg]+disp8)
- in [[Symbol => left_reg, Integer => left_disp], Integer => right_imm] if imm8?(right_imm)
+ in [[Symbol => left_reg, Integer => left_disp], Integer => right_imm] if imm8?(right_imm) && right_imm >= 0
# REX + F6 /0 ib
# MI: Operand 1: ModRM:r/m (r), Operand 2: imm8/16/32
insn(
@@ -347,6 +351,17 @@ module RubyVM::MJIT
disp: left_disp,
imm: imm8(right_imm),
)
+ # TEST r/m64, imm32 (Mod 01: [reg]+disp8)
+ in [[Symbol => left_reg, Integer => left_disp], Integer => right_imm] if imm32?(right_imm)
+ # REX.W + F7 /0 id
+ # MI: Operand 1: ModRM:r/m (r), Operand 2: imm8/16/32
+ insn(
+ prefix: REX_W,
+ opcode: 0xf7,
+ mod_rm: ModRM[mod: Mod01, reg: 0, rm: left_reg],
+ disp: left_disp,
+ imm: imm32(right_imm),
+ )
# TEST r/m32, r32 (Mod 11: reg)
in [Symbol => left_reg, Symbol => right_reg] if r32?(left_reg) && r32?(right_reg)
# 85 /r
@@ -574,7 +589,6 @@ module RubyVM::MJIT
end
@stub_ends.fetch(index, []).each do |stub|
stub.end_addr = write_addr + index
- stub.freeze
end
end
end
diff --git a/lib/ruby_vm/mjit/block_stub.rb b/lib/ruby_vm/mjit/block_stub.rb
index e6a0a867b9..c44b66170b 100644
--- a/lib/ruby_vm/mjit/block_stub.rb
+++ b/lib/ruby_vm/mjit/block_stub.rb
@@ -1,7 +1,7 @@
class RubyVM::MJIT::BlockStub < Struct.new(
:iseq, # @param [RubyVM::MJIT::CPointer::Struct_rb_iseq_struct] Stub target ISEQ
- :pc, # @param [Integer] Stub target pc
:ctx, # @param [RubyVM::MJIT::Context] Stub target context
+ :pc, # @param [Integer] Stub target pc
:start_addr, # @param [Integer] Stub source start address to be re-generated
:end_addr, # @param [Integer] Stub source end address to be re-generated
)
diff --git a/lib/ruby_vm/mjit/branch_stub.rb b/lib/ruby_vm/mjit/branch_stub.rb
new file mode 100644
index 0000000000..27ea5b9515
--- /dev/null
+++ b/lib/ruby_vm/mjit/branch_stub.rb
@@ -0,0 +1,14 @@
+class RubyVM::MJIT::BranchStub < Struct.new(
+ :iseq, # @param [RubyVM::MJIT::CPointer::Struct_rb_iseq_struct] Branch target ISEQ
+ :ctx, # @param [RubyVM::MJIT::Context] Branch target context
+ :branch_target_pc, # @param [Integer] Branch target PC
+ :branch_target_addr, # @param [Integer] Branch target address
+ :branch_target_next, # @param [Proc] Compile branch target next
+ :fallthrough_pc, # @param [Integer] Fallthrough PC
+ :fallthrough_addr, # @param [Integer] Fallthrough address
+ :fallthrough_next, # @param [Proc] Compile fallthrough next
+ :neither_next, # @param [Proc] Compile neither branch target nor fallthrough next
+ :start_addr, # @param [Integer] Stub source start address to be re-generated
+ :end_addr, # @param [Integer] Stub source end address to be re-generated
+)
+end
diff --git a/lib/ruby_vm/mjit/code_block.rb b/lib/ruby_vm/mjit/code_block.rb
index 15589b91d0..21ae2386b7 100644
--- a/lib/ruby_vm/mjit/code_block.rb
+++ b/lib/ruby_vm/mjit/code_block.rb
@@ -44,7 +44,7 @@ module RubyVM::MJIT
def with_write_addr(addr)
old_write_pos = @write_pos
- set_addr(addr)
+ set_write_addr(addr)
yield
ensure
@write_pos = old_write_pos
@@ -54,6 +54,10 @@ module RubyVM::MJIT
@mem_block + @write_pos
end
+ def include?(addr)
+ (@mem_block...(@mem_block + @mem_size)).include?(addr)
+ end
+
private
def dump_disasm(from, to)
diff --git a/lib/ruby_vm/mjit/compiler.rb b/lib/ruby_vm/mjit/compiler.rb
index 0e018a31f2..6a461fd580 100644
--- a/lib/ruby_vm/mjit/compiler.rb
+++ b/lib/ruby_vm/mjit/compiler.rb
@@ -1,6 +1,7 @@
require 'ruby_vm/mjit/assembler'
require 'ruby_vm/mjit/block'
require 'ruby_vm/mjit/block_stub'
+require 'ruby_vm/mjit/branch_stub'
require 'ruby_vm/mjit/code_block'
require 'ruby_vm/mjit/context'
require 'ruby_vm/mjit/exit_compiler'
@@ -61,29 +62,29 @@ module RubyVM::MJIT
$stderr.puts e.full_message # TODO: check verbose
end
- # Continue compilation from a stub.
- # @param stub [RubyVM::MJIT::BlockStub]
+ # Continue compilation from a block stub.
+ # @param block_stub [RubyVM::MJIT::BlockStub]
# @param cfp `RubyVM::MJIT::CPointer::Struct_rb_control_frame_t`
- # @return [Integer] The starting address of a compiled stub
- def stub_hit(stub, cfp)
+ # @return [Integer] The starting address of the compiled block stub
+ def block_stub_hit(block_stub, cfp)
# Update cfp->pc for `jit.at_current_insn?`
- cfp.pc = stub.pc
+ cfp.pc = block_stub.pc
# Prepare the jump target
new_asm = Assembler.new.tap do |asm|
- jit = JITState.new(iseq: stub.iseq, cfp:)
- compile_block(asm, jit:, pc: stub.pc, ctx: stub.ctx)
+ jit = JITState.new(iseq: block_stub.iseq, cfp:)
+ compile_block(asm, jit:, pc: block_stub.pc, ctx: block_stub.ctx)
end
- # Rewrite the stub
- if @cb.write_addr == stub.end_addr
- # If the stub jump is the last code, overwrite the jump with the new code.
- @cb.set_write_addr(stub.start_addr)
+ # Rewrite the block stub
+ if @cb.write_addr == block_stub.end_addr
+ # If the block stub's jump is the last code, overwrite the jump with the new code.
+ @cb.set_write_addr(block_stub.start_addr)
@cb.write(new_asm)
else
- # If the stub jump is old code, change the jump target to the new code.
+ # If the block stub's jump is old code, change the jump target to the new code.
new_addr = @cb.write(new_asm)
- @cb.with_write_addr(stub.start_addr) do
+ @cb.with_write_addr(block_stub.start_addr) do
asm = Assembler.new
asm.comment('regenerate block stub')
asm.jmp(new_addr)
@@ -92,6 +93,74 @@ module RubyVM::MJIT
end
end
+ # Compile a branch stub.
+ # @param branch_stub [RubyVM::MJIT::BranchStub]
+ # @param cfp `RubyVM::MJIT::CPointer::Struct_rb_control_frame_t`
+ # @param branch_target_p [TrueClass,FalseClass]
+ # @return [Integer] The starting address of the compiled branch stub
+ def branch_stub_hit(branch_stub, cfp, branch_target_p)
+ # Update cfp->pc for `jit.at_current_insn?`
+ pc = branch_target_p ? branch_stub.branch_target_pc : branch_stub.fallthrough_pc
+ cfp.pc = pc
+
+ # Prepare the jump target
+ new_asm = Assembler.new.tap do |asm|
+ jit = JITState.new(iseq: branch_stub.iseq, cfp:)
+ compile_block(asm, jit:, pc:, ctx: branch_stub.ctx.dup)
+ end
+
+ # Rewrite the branch stub
+ if @cb.write_addr == branch_stub.end_addr
+ # If the branch stub's jump is the last code, overwrite the jump with the new code.
+ @cb.set_write_addr(branch_stub.start_addr)
+ Assembler.new.tap do |branch_asm|
+ if branch_target_p
+ branch_stub.branch_target_next.call(branch_asm)
+ else
+ branch_stub.fallthrough_next.call(branch_asm)
+ end
+ @cb.write(branch_asm)
+ end
+
+ # Compile a fallthrough over the jump
+ if branch_target_p
+ branch_stub.branch_target_addr = @cb.write(new_asm)
+ else
+ branch_stub.fallthrough_addr = @cb.write(new_asm)
+ end
+ else
+ # Otherwise, just prepare the new code somewhere
+ if branch_target_p
+ unless @cb.include?(branch_stub.branch_target_addr)
+ branch_stub.branch_target_addr = @cb.write(new_asm)
+ end
+ else
+ unless @cb.include?(branch_stub.fallthrough_addr)
+ branch_stub.fallthrough_addr = @cb.write(new_asm)
+ end
+ end
+
+ # Update jump destinations
+ branch_asm = Assembler.new
+ if branch_stub.end_addr == branch_stub.branch_target_addr # branch_target_next has been used
+ branch_stub.branch_target_next.call(branch_asm)
+ elsif branch_stub.end_addr == branch_stub.fallthrough_addr # fallthrough_next has been used
+ branch_stub.fallthrough_next.call(branch_asm)
+ else
+ branch_stub.neither_next.call(branch_asm)
+ end
+ @cb.with_write_addr(branch_stub.start_addr) do
+ @cb.write(branch_asm)
+ end
+ end
+
+ if branch_target_p
+ branch_stub.branch_target_addr
+ else
+ branch_stub.fallthrough_addr
+ end
+ end
+
private
# Callee-saved: rbx, rsp, rbp, r12, r13, r14, r15
@@ -127,15 +196,18 @@ module RubyVM::MJIT
insn = self.class.decode_insn(iseq.body.iseq_encoded[index])
jit.pc = (iseq.body.iseq_encoded + index).to_i
- case @insn_compiler.compile(jit, ctx, asm, insn)
+ case status = @insn_compiler.compile(jit, ctx, asm, insn)
+ when KeepCompiling
+ index += insn.len
when EndBlock
# TODO: pad nops if entry exit exists
break
when CantCompile
@exit_compiler.compile_side_exit(jit, ctx, asm)
break
+ else
+ raise "compiling #{insn.name} returned unexpected status: #{status.inspect}"
end
- index += insn.len
end
end
end
diff --git a/lib/ruby_vm/mjit/exit_compiler.rb b/lib/ruby_vm/mjit/exit_compiler.rb
index a1eca9fe23..3a0c12f525 100644
--- a/lib/ruby_vm/mjit/exit_compiler.rb
+++ b/lib/ruby_vm/mjit/exit_compiler.rb
@@ -47,19 +47,30 @@ module RubyVM::MJIT
# @param jit [RubyVM::MJIT::JITState]
# @param ctx [RubyVM::MJIT::Context]
# @param asm [RubyVM::MJIT::Assembler]
- # @param stub [RubyVM::MJIT::BlockStub]
- def compile_jump_stub(jit, ctx, asm, stub)
- case stub
- when BlockStub
- asm.comment("block stub hit: #{stub.iseq.body.location.label}@#{C.rb_iseq_path(stub.iseq)}:#{stub.iseq.body.location.first_lineno}")
- else
- raise "unexpected stub object: #{stub.inspect}"
- end
+ # @param block_stub [RubyVM::MJIT::BlockStub]
+ def compile_block_stub(jit, ctx, asm, block_stub)
+ # Call rb_mjit_block_stub_hit
+ asm.comment("block stub hit: #{block_stub.iseq.body.location.label}@#{C.rb_iseq_path(block_stub.iseq)}:#{iseq_lineno(block_stub.iseq, block_stub.pc)}")
+ asm.mov(:rdi, to_value(block_stub))
+ asm.mov(:esi, ctx.sp_offset)
+ asm.call(C.rb_mjit_block_stub_hit)
+
+ # Jump to the address returned by rb_mjit_stub_hit
+ asm.jmp(:rax)
+ end
- # Call rb_mjit_stub_hit
- asm.mov(:rdi, to_value(stub))
+ # @param jit [RubyVM::MJIT::JITState]
+ # @param ctx [RubyVM::MJIT::Context]
+ # @param asm [RubyVM::MJIT::Assembler]
+ # @param branch_stub [RubyVM::MJIT::BranchStub]
+ # @param branch_target_p [TrueClass,FalseClass]
+ def compile_branch_stub(jit, ctx, asm, branch_stub, branch_target_p)
+ # Call rb_mjit_branch_stub_hit
+ asm.comment("branch stub hit: #{branch_stub.iseq.body.location.label}@#{C.rb_iseq_path(branch_stub.iseq)}:#{iseq_lineno(branch_stub.iseq, branch_target_p ? branch_stub.branch_target_pc : branch_stub.fallthrough_pc)}")
+ asm.mov(:rdi, to_value(branch_stub))
asm.mov(:esi, ctx.sp_offset)
- asm.call(C.rb_mjit_stub_hit)
+ asm.mov(:edx, branch_target_p ? 1 : 0)
+ asm.call(C.rb_mjit_branch_stub_hit)
# Jump to the address returned by rb_mjit_stub_hit
asm.jmp(:rax)
@@ -103,5 +114,9 @@ module RubyVM::MJIT
@gc_refs << obj
C.to_value(obj)
end
+
+ def iseq_lineno(iseq, pc)
+ C.rb_iseq_line_no(iseq, (pc - iseq.body.iseq_encoded.to_i) / C.VALUE.size)
+ end
end
end
diff --git a/lib/ruby_vm/mjit/insn_compiler.rb b/lib/ruby_vm/mjit/insn_compiler.rb
index 3468be9598..19bb607637 100644
--- a/lib/ruby_vm/mjit/insn_compiler.rb
+++ b/lib/ruby_vm/mjit/insn_compiler.rb
@@ -17,7 +17,7 @@ module RubyVM::MJIT
asm.incr_counter(:mjit_insns_count)
asm.comment("Insn: #{insn.name}")
- # 5/101
+ # 6/101
case insn.name
# nop
# getlocal
@@ -83,7 +83,7 @@ module RubyVM::MJIT
# throw
# jump
# branchif
- # branchunless
+ when :branchunless then branchunless(jit, ctx, asm)
# branchnil
# once
# opt_case_dispatch
@@ -247,7 +247,61 @@ module RubyVM::MJIT
# throw
# jump
# branchif
- # branchunless
+
+ # @param jit [RubyVM::MJIT::JITState]
+ # @param ctx [RubyVM::MJIT::Context]
+ # @param asm [RubyVM::MJIT::Assembler]
+ def branchunless(jit, ctx, asm)
+ # TODO: check ints for backward branches
+ # TODO: skip check for known truthy
+
+ # This `test` sets ZF only for Qnil and Qfalse, which let jz jump.
+ asm.test([SP, C.VALUE.size * (ctx.stack_size - 1)], ~Qnil)
+ ctx.stack_pop(1)
+
+ # Set stubs
+ # TODO: reuse already-compiled blocks
+ branch_stub = BranchStub.new(
+ iseq: jit.iseq,
+ ctx: ctx.dup,
+ branch_target_pc: jit.pc + (jit.insn.len + jit.operand(0)) * C.VALUE.size,
+ fallthrough_pc: jit.pc + jit.insn.len * C.VALUE.size,
+ )
+ branch_stub.branch_target_addr = Assembler.new.then do |ocb_asm|
+ @exit_compiler.compile_branch_stub(jit, ctx, ocb_asm, branch_stub, true)
+ @ocb.write(ocb_asm)
+ end
+ branch_stub.fallthrough_addr = Assembler.new.then do |ocb_asm|
+ @exit_compiler.compile_branch_stub(jit, ctx, ocb_asm, branch_stub, false)
+ @ocb.write(ocb_asm)
+ end
+
+ # Prepare codegen for all cases
+ branch_stub.branch_target_next = proc do |branch_asm|
+ branch_asm.stub(branch_stub) do
+ branch_asm.comment('branch_target_next')
+ branch_asm.jnz(branch_stub.fallthrough_addr)
+ end
+ end
+ branch_stub.fallthrough_next = proc do |branch_asm|
+ branch_asm.stub(branch_stub) do
+ branch_asm.comment('fallthrough_next')
+ branch_asm.jz(branch_stub.branch_target_addr)
+ end
+ end
+ branch_stub.neither_next = proc do |branch_asm|
+ branch_asm.stub(branch_stub) do
+ branch_asm.comment('neither_next')
+ branch_asm.jz(branch_stub.branch_target_addr)
+ branch_asm.jmp(branch_stub.fallthrough_addr)
+ end
+ end
+
+ # Just jump to stubs
+ branch_stub.neither_next.call(asm)
+ EndBlock
+ end
+
# branchnil
# once
# opt_case_dispatch
@@ -370,12 +424,12 @@ module RubyVM::MJIT
# Make a stub to compile the current insn
block_stub = BlockStub.new(
iseq: jit.iseq,
- pc: jit.pc,
ctx: ctx.dup,
+ pc: jit.pc,
)
stub_hit = Assembler.new.then do |ocb_asm|
- @exit_compiler.compile_jump_stub(jit, ctx, ocb_asm, block_stub)
+ @exit_compiler.compile_block_stub(jit, ctx, ocb_asm, block_stub)
@ocb.write(ocb_asm)
end
diff --git a/lib/ruby_vm/mjit/jit_state.rb b/lib/ruby_vm/mjit/jit_state.rb
index 97331e8108..ff48c2c107 100644
--- a/lib/ruby_vm/mjit/jit_state.rb
+++ b/lib/ruby_vm/mjit/jit_state.rb
@@ -8,6 +8,10 @@ module RubyVM::MJIT
)
def initialize(side_exits: {}, **) = super
+ def insn
+ Compiler.decode_insn(C.VALUE.new(pc).*)
+ end
+
def operand(index)
C.VALUE.new(pc)[index + 1]
end