diff options
Diffstat (limited to 'lib/ruby_vm/mjit')
-rw-r--r-- | lib/ruby_vm/mjit/assembler.rb | 151 | ||||
-rw-r--r-- | lib/ruby_vm/mjit/compiler.rb | 15 | ||||
-rw-r--r-- | lib/ruby_vm/mjit/exit_compiler.rb | 2 | ||||
-rw-r--r-- | lib/ruby_vm/mjit/insn_compiler.rb | 5 |
4 files changed, 123 insertions, 50 deletions
diff --git a/lib/ruby_vm/mjit/assembler.rb b/lib/ruby_vm/mjit/assembler.rb index 4620430faf..bfe716e554 100644 --- a/lib/ruby_vm/mjit/assembler.rb +++ b/lib/ruby_vm/mjit/assembler.rb @@ -3,16 +3,29 @@ module RubyVM::MJIT # https://www.intel.com/content/dam/develop/public/us/en/documents/325383-sdm-vol-2abcd.pdf # Mostly an x86_64 assembler, but this also has some stuff that is useful for any architecture. class Assembler + # A thin Fiddle wrapper to write bytes to memory + ByteWriter = CType::Immediate.parse('char') + + # Used for rel8 jumps class Label < Data.define(:id, :name); end # rel32 is inserted as [Rel32, Rel32Pad..] and converted on #resolve_rel32 class Rel32 < Data.define(:addr); end Rel32Pad = Object.new - ByteWriter = CType::Immediate.parse('char') + # A set of ModR/M values encoded on #insn + class ModRM < Data.define(:mod, :reg, :rm) + def initialize(mod:, reg: nil, rm: nil) = super + end + Mod00 = 0b00 # Mod 00: [reg] + Mod01 = 0b01 # Mod 01: [reg]+disp8 + Mod10 = 0b10 # Mod 10: [reg]+disp16 + Mod11 = 0b11 # Mod 11: reg ### prefix ### # REX = 0100WR0B + REX_B = 0b01000001 + REX_R = 0b01000100 REX_W = 0b01001000 def initialize @@ -47,24 +60,24 @@ module RubyVM::MJIT def add(dst, src) case [dst, src] - # ADD r/m64, imm8 (Mod 11) + # ADD r/m64, imm8 (Mod 11: reg) in [Symbol => dst_reg, Integer => src_imm] if r64?(dst_reg) && imm8?(src_imm) # REX.W + 83 /0 ib # MI: Operand 1: ModRM:r/m (r, w), Operand 2: imm8/16/32 insn( prefix: REX_W, opcode: 0x83, - mod_rm: mod_rm(mod: 0b11, rm: reg_code(dst_reg)), + mod_rm: ModRM[mod: Mod11, rm: dst_reg], imm: imm8(src_imm), ) - # ADD r/m64, imm8 (Mod 00) + # ADD r/m64, imm8 (Mod 00: [reg]) in [[Symbol => dst_reg], Integer => src_imm] if r64?(dst_reg) && imm8?(src_imm) # REX.W + 83 /0 ib # MI: Operand 1: ModRM:r/m (r, w), Operand 2: imm8/16/32 insn( prefix: REX_W, opcode: 0x83, - mod_rm: mod_rm(mod: 0b00, rm: reg_code(dst_reg)), # Mod 00: [reg] + mod_rm: ModRM[mod: Mod00, rm: dst_reg], imm: imm8(src_imm), ) else @@ -98,42 +111,51 @@ module RubyVM::MJIT case dst in Symbol => dst_reg case src - # MOV r64, r/m64 (Mod 00) + # MOV r64, r/m64 (Mod 00: [reg]) in [Symbol => src_reg] if r64?(dst_reg) && r64?(src_reg) # REX.W + 8B /r # RM: Operand 1: ModRM:reg (w), Operand 2: ModRM:r/m (r) insn( prefix: REX_W, opcode: 0x8b, - mod_rm: mod_rm(mod: 0b00, reg: reg_code(dst_reg), rm: reg_code(src_reg)), # Mod 00: [reg] + mod_rm: ModRM[mod: Mod00, reg: dst_reg, rm: src_reg], ) - # MOV r32 r/m32 (Mod 01) - in [Symbol => src_reg, Integer => src_disp] if r32?(dst_reg) && imm8?(src_disp) - # 8B /r + # MOV r64, r/m64 (Mod 01: [reg]+disp8) + in [Symbol => src_reg, Integer => src_disp] if r64?(dst_reg) && r64?(src_reg) && imm8?(src_disp) + # REX.W + 8B /r # RM: Operand 1: ModRM:reg (w), Operand 2: ModRM:r/m (r) insn( + prefix: REX_W, opcode: 0x8b, - mod_rm: mod_rm(mod: 0b01, reg: reg_code(dst_reg), rm: reg_code(src_reg)), # Mod 01: [reg]+disp8 + mod_rm: ModRM[mod: Mod01, reg: dst_reg, rm: src_reg], disp: src_disp, ) - # MOV r64, r/m64 (Mod 01) - in [Symbol => src_reg, Integer => src_disp] if r64?(dst_reg) && r64?(src_reg) && imm8?(src_disp) + # MOV r64, r/m64 (Mod 11: reg) + in Symbol => src_reg if r64?(dst_reg) && r64?(src_reg) # REX.W + 8B /r # RM: Operand 1: ModRM:reg (w), Operand 2: ModRM:r/m (r) insn( prefix: REX_W, opcode: 0x8b, - mod_rm: mod_rm(mod: 0b01, reg: reg_code(dst_reg), rm: reg_code(src_reg)), # Mod 01: [reg]+disp8 + mod_rm: ModRM[mod: Mod11, reg: dst_reg, rm: src_reg], + ) + # MOV r32 r/m32 (Mod 01: [reg]+disp8) + in [Symbol => src_reg, Integer => src_disp] if r32?(dst_reg) && imm8?(src_disp) + # 8B /r + # RM: Operand 1: ModRM:reg (w), Operand 2: ModRM:r/m (r) + insn( + opcode: 0x8b, + mod_rm: ModRM[mod: Mod01, reg: dst_reg, rm: src_reg], disp: src_disp, ) - # MOV r/m64, imm32 (Mod 11) + # MOV r/m64, imm32 (Mod 11: reg) in Integer => src_imm if r64?(dst_reg) && imm32?(src_imm) # REX.W + C7 /0 id # MI: Operand 1: ModRM:r/m (w), Operand 2: imm8/16/32/64 insn( prefix: REX_W, opcode: 0xc7, - mod_rm: mod_rm(mod: 0b11, rm: reg_code(dst_reg)), # Mod 11: reg + mod_rm: ModRM[mod: Mod11, rm: dst_reg], imm: imm32(src_imm), ) # MOV r64, imm64 @@ -142,7 +164,8 @@ module RubyVM::MJIT # OI: Operand 1: opcode + rd (w), Operand 2: imm8/16/32/64 insn( prefix: REX_W, - opcode: 0xb8 + reg_code(dst_reg), + opcode: 0xb8, + rd: dst_reg, imm: imm64(src_imm), ) else @@ -150,24 +173,24 @@ module RubyVM::MJIT end in [Symbol => dst_reg] case src - # MOV r/m64, imm32 (Mod 00) + # MOV r/m64, imm32 (Mod 00: [reg]) in Integer => src_imm if r64?(dst_reg) && imm32?(src_imm) # REX.W + C7 /0 id # MI: Operand 1: ModRM:r/m (w), Operand 2: imm8/16/32/64 insn( prefix: REX_W, opcode: 0xc7, - mod_rm: mod_rm(mod: 0b00, rm: reg_code(dst_reg)), # Mod 00: [reg] + mod_rm: ModRM[mod: Mod00, rm: dst_reg], imm: imm32(src_imm), ) - # MOV r/m64, r64 (Mod 00) + # MOV r/m64, r64 (Mod 00: [reg]) in Symbol => src_reg if r64?(dst_reg) && r64?(src_reg) # REX.W + 89 /r # MR: Operand 1: ModRM:r/m (w), Operand 2: ModRM:reg (r) insn( prefix: REX_W, opcode: 0x89, - mod_rm: mod_rm(mod: 0b00, reg: reg_code(src_reg), rm: reg_code(dst_reg)), # Mod 00: [reg] + mod_rm: ModRM[mod: Mod00, reg: src_reg, rm: dst_reg], ) else raise NotImplementedError, "mov: not-implemented operands: #{dst.inspect}, #{src.inspect}" @@ -177,25 +200,25 @@ module RubyVM::MJIT return mov([dst_reg], src) if dst_disp == 0 case src - # MOV r/m64, imm32 (Mod 01) + # MOV r/m64, imm32 (Mod 01: [reg]+disp8) in Integer => src_imm if r64?(dst_reg) && imm8?(dst_disp) && imm32?(src_imm) # REX.W + C7 /0 id # MI: Operand 1: ModRM:r/m (w), Operand 2: imm8/16/32/64 insn( prefix: REX_W, opcode: 0xc7, - mod_rm: mod_rm(mod: 0b01, rm: reg_code(dst_reg)), # Mod 01: [reg]+disp8 + mod_rm: ModRM[mod: Mod01, rm: dst_reg], disp: dst_disp, imm: imm32(src_imm), ) - # MOV r/m64, r64 (Mod 01) + # MOV r/m64, r64 (Mod 01: [reg]+disp8) in Symbol => src_reg if r64?(dst_reg) && imm8?(dst_disp) && r64?(src_reg) # REX.W + 89 /r # MR: Operand 1: ModRM:r/m (w), Operand 2: ModRM:reg (r) insn( prefix: REX_W, opcode: 0x89, - mod_rm: mod_rm(mod: 0b01, reg: reg_code(src_reg), rm: reg_code(dst_reg)), # Mod 01: [reg]+disp8 + mod_rm: ModRM[mod: Mod01, reg: src_reg, rm: dst_reg], disp: dst_disp, ) else @@ -212,7 +235,7 @@ module RubyVM::MJIT in Symbol => src_reg if r64?(src_reg) # 50+rd # O: Operand 1: opcode + rd (r) - insn(opcode: 0x50 + reg_code(src_reg)) + insn(opcode: 0x50, rd: src_reg) else raise NotImplementedError, "push: not-implemented operands: #{src.inspect}" end @@ -224,7 +247,7 @@ module RubyVM::MJIT in Symbol => dst_reg if r64?(dst_reg) # 58+ rd # O: Operand 1: opcode + rd (r) - insn(opcode: 0x58 + reg_code(dst_reg)) + insn(opcode: 0x58, rd: dst_reg) else raise NotImplementedError, "pop: not-implemented operands: #{dst.inspect}" end @@ -238,13 +261,13 @@ module RubyVM::MJIT def test(left, right) case [left, right] - # TEST r/m32, r32 (Mod 11) + # TEST r/m32, r32 (Mod 11: reg) in [Symbol => left_reg, Symbol => right_reg] if r32?(left_reg) && r32?(right_reg) # 85 /r # MR: Operand 1: ModRM:r/m (r), Operand 2: ModRM:reg (r) insn( opcode: 0x85, - mod_rm: mod_rm(mod: 0b11, reg: reg_code(right_reg), rm: reg_code(left_reg)), # Mod 11: reg + mod_rm: ModRM[mod: Mod11, reg: right_reg, rm: left_reg], ) else raise NotImplementedError, "pop: not-implemented operands: #{dst.inspect}" @@ -284,13 +307,29 @@ module RubyVM::MJIT private - def insn(prefix: nil, opcode:, mod_rm: nil, disp: nil, imm: nil) - if prefix + def insn(prefix: 0, opcode:, rd: nil, mod_rm: nil, disp: nil, imm: nil) + # Determine prefix + if rd + prefix |= REX_B if extended_reg?(rd) + opcode += reg_code(rd) + end + if mod_rm + prefix |= REX_R if mod_rm.reg && extended_reg?(mod_rm.reg) + prefix |= REX_B if mod_rm.rm && extended_reg?(mod_rm.rm) + end + + # Encode insn + if prefix > 0 @bytes.push(prefix) end @bytes.push(*Array(opcode)) if mod_rm - @bytes.push(mod_rm) + mod_rm_byte = encode_mod_rm( + mod: mod_rm.mod, + reg: mod_rm.reg ? reg_code(mod_rm.reg) : 0, + rm: mod_rm.rm ? reg_code(mod_rm.rm) : 0, + ) + @bytes.push(mod_rm_byte) end if disp unless imm8?(disp) # TODO: support displacement in 2 or 4 bytes as well @@ -304,15 +343,33 @@ module RubyVM::MJIT end def reg_code(reg) + reg_code_extended(reg).first + end + + def extended_reg?(reg) + reg_code_extended(reg).last + end + + def reg_code_extended(reg) case reg - when :al, :ax, :eax, :rax then 0 - when :cl, :cx, :ecx, :rcx then 1 - when :dl, :dx, :edx, :rdx then 2 - when :bl, :bx, :ebx, :rbx then 3 - when :ah, :sp, :esp, :rsp then 4 - when :ch, :bp, :ebp, :rbp then 5 - when :dh, :si, :esi, :rsi then 6 - when :bh, :di, :edi, :rdi then 7 + # Not extended + when :al, :ax, :eax, :rax then [0, false] + when :cl, :cx, :ecx, :rcx then [1, false] + when :dl, :dx, :edx, :rdx then [2, false] + when :bl, :bx, :ebx, :rbx then [3, false] + when :ah, :sp, :esp, :rsp then [4, false] + when :ch, :bp, :ebp, :rbp then [5, false] + when :dh, :si, :esi, :rsi then [6, false] + when :bh, :di, :edi, :rdi then [7, false] + # Extended + when :r8b, :r8w, :r8d, :r8 then [0, true] + when :r9b, :r9w, :r9d, :r9 then [1, true] + when :r10b, :r10w, :r10d, :r10 then [2, true] + when :r11b, :r11w, :r11d, :r11 then [3, true] + when :r12b, :r12w, :r12d, :r12 then [4, true] + when :r13b, :r13w, :r13d, :r13 then [5, true] + when :r14b, :r14w, :r14d, :r14 then [6, true] + when :r15b, :r15w, :r15d, :r15 then [7, true] else raise ArgumentError, "unexpected reg: #{reg.inspect}" end end @@ -330,7 +387,7 @@ module RubyVM::MJIT # # /0: R/M is 0 (not used) # /r: R/M is a register - def mod_rm(mod:, reg: 0, rm: 0) + def encode_mod_rm(mod:, reg: 0, rm: 0) if mod > 0b11 raise ArgumentError, "too large Mod: #{mod}" end @@ -389,11 +446,19 @@ module RubyVM::MJIT end def r32?(reg) - reg.start_with?('e') + if extended_reg?(reg) + reg.end_with?('d') + else + reg.start_with?('e') + end end def r64?(reg) - reg.start_with?('r') + if extended_reg?(reg) + reg.match?(/\Ar\d+\z/) + else + reg.start_with?('r') + end end def rel32(addr) diff --git a/lib/ruby_vm/mjit/compiler.rb b/lib/ruby_vm/mjit/compiler.rb index 6203863218..ac094c7007 100644 --- a/lib/ruby_vm/mjit/compiler.rb +++ b/lib/ruby_vm/mjit/compiler.rb @@ -16,9 +16,10 @@ module RubyVM::MJIT Qnil = Fiddle::Qnil Qundef = Fiddle::Qundef - # Fixed registers - EC = :rdi # TODO: change this - CFP = :rsi # TODO: change this + # Callee-saved registers + # TODO: support using r12/r13 here + EC = :r14 + CFP = :r15 SP = :rbx class Compiler @@ -64,9 +65,15 @@ module RubyVM::MJIT asm.comment("MJIT entry") # Save callee-saved registers used by JITed code + asm.push(CFP) + asm.push(EC) asm.push(SP) - # Load sp to a register + # Move arguments EC and CFP to dedicated registers + asm.mov(EC, :rdi) + asm.mov(CFP, :rsi) + + # Load sp to a dedicated register asm.mov(SP, [CFP, C.rb_control_frame_t.offsetof(:sp)]) # rbx = cfp->sp end diff --git a/lib/ruby_vm/mjit/exit_compiler.rb b/lib/ruby_vm/mjit/exit_compiler.rb index 7f5219ea5b..c9f2190931 100644 --- a/lib/ruby_vm/mjit/exit_compiler.rb +++ b/lib/ruby_vm/mjit/exit_compiler.rb @@ -26,6 +26,8 @@ module RubyVM::MJIT # Restore callee-saved registers asm.pop(SP) + asm.pop(EC) + asm.pop(CFP) asm.mov(:rax, Qundef) asm.ret diff --git a/lib/ruby_vm/mjit/insn_compiler.rb b/lib/ruby_vm/mjit/insn_compiler.rb index 4c4c86c5f0..96bab0a225 100644 --- a/lib/ruby_vm/mjit/insn_compiler.rb +++ b/lib/ruby_vm/mjit/insn_compiler.rb @@ -1,7 +1,4 @@ module RubyVM::MJIT - # ec: rdi - # cfp: rsi - # sp: rbx # scratch regs: rax # # 4/101 @@ -123,6 +120,8 @@ module RubyVM::MJIT # Restore callee-saved registers asm.pop(SP) + asm.pop(EC) + asm.pop(CFP) asm.ret EndBlock |