-rw-r--r-- | test/ruby/test_zjit.rb        |  18
-rw-r--r-- | zjit/src/asm/arm64/mod.rs     |   4
-rw-r--r-- | zjit/src/asm/arm64/opnd.rs    |   5
-rw-r--r-- | zjit/src/backend/arm64/mod.rs |  18
-rw-r--r-- | zjit/src/backend/lir.rs       | 265
-rw-r--r-- | zjit/src/backend/x86_64/mod.rs|  28
-rw-r--r-- | zjit/src/codegen.rs           |  53
-rw-r--r-- | zjit/src/hir.rs               |   5
-rw-r--r-- | zjit/src/state.rs             |  18
9 files changed, 302 insertions, 112 deletions
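Reviewer's note: the heart of this patch is that `Target::SideExit(FrameState)` and `Target::SideExitPtr(CodePtr)` are replaced by a single `Target::SideExit { pc, stack, locals }` variant whose stack and local operands are visited by the LIR operand iterators, so they participate in register allocation like any other operand. The sketch below is a simplified standalone model of that iteration order (stack slots first, then locals); `Opnd` and `Target` here are stand-ins, not the real zjit types.

```rust
// Simplified model of how a side-exit target exposes its operands.
// These types are illustrative stand-ins, not the actual zjit definitions.
#[derive(Debug, Clone, Copy)]
enum Opnd {
    VReg(usize),
}

enum Target {
    SideExit { pc: usize, stack: Vec<Opnd>, locals: Vec<Opnd> },
}

/// Yield every operand the side exit must keep alive: stack slots, then locals.
fn side_exit_opnds(target: &Target) -> impl Iterator<Item = &Opnd> {
    let Target::SideExit { stack, locals, .. } = target;
    stack.iter().chain(locals.iter())
}

fn main() {
    let exit = Target::SideExit {
        pc: 0x1000,
        stack: vec![Opnd::VReg(0), Opnd::VReg(1)],
        locals: vec![Opnd::VReg(2)],
    };
    for (i, opnd) in side_exit_opnds(&exit).enumerate() {
        println!("operand {i}: {opnd:?}"); // stack entries first, then locals
    }
}
```

The diff's `InsnOpndIterator` implements the same visiting order by hand (index into `stack`, then offset into `locals`), since it must also handle the extra leading operand of `Joz`/`Jonz`.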
diff --git a/test/ruby/test_zjit.rb b/test/ruby/test_zjit.rb
index 796851a9bf..aba05ddebd 100644
--- a/test/ruby/test_zjit.rb
+++ b/test/ruby/test_zjit.rb
@@ -94,6 +94,24 @@ class TestZJIT < Test::Unit::TestCase
     }, call_threshold: 2
   end
 
+  def test_opt_plus_type_guard_exit
+    assert_compiles '[3, 3.0]', %q{
+      def test(a) = 1 + a
+      test(1) # profile opt_plus
+      [test(2), test(2.0)]
+    }, call_threshold: 2
+  end
+
+  def test_opt_plus_type_guard_nested_exit
+    omit 'rewind_caller_frames is not implemented yet'
+    assert_compiles '[3, 3.0]', %q{
+      def side_exit(n) = 1 + n
+      def jit_frame(n) = 1 + side_exit(n)
+      def entry(n) = jit_frame(n)
+      [entry(2), entry(2.0)]
+    }, call_threshold: 2
+  end
+
   # Test argument ordering
   def test_opt_minus
     assert_compiles '2', %q{
diff --git a/zjit/src/asm/arm64/mod.rs b/zjit/src/asm/arm64/mod.rs
index a5d73d71a5..1e1b125eaa 100644
--- a/zjit/src/asm/arm64/mod.rs
+++ b/zjit/src/asm/arm64/mod.rs
@@ -644,7 +644,7 @@ pub fn mov(cb: &mut CodeBlock, rd: A64Opnd, rm: A64Opnd) {
             LogicalImm::mov(rd.reg_no, bitmask_imm, rd.num_bits).into()
         },
 
-        _ => panic!("Invalid operand combination to mov instruction")
+        _ => panic!("Invalid operand combination to mov instruction: {rd:?}, {rm:?}")
     };
 
     cb.write_bytes(&bytes);
@@ -940,7 +940,7 @@ pub fn stur(cb: &mut CodeBlock, rt: A64Opnd, rn: A64Opnd) {
             LoadStore::stur(rt.reg_no, rn.base_reg_no, rn.disp as i16, rn.num_bits).into()
         },
 
-        _ => panic!("Invalid operand combination to stur instruction.")
+        _ => panic!("Invalid operand combination to stur instruction: {rt:?}, {rn:?}")
     };
 
     cb.write_bytes(&bytes);
diff --git a/zjit/src/asm/arm64/opnd.rs b/zjit/src/asm/arm64/opnd.rs
index 6e31851504..28422b7476 100644
--- a/zjit/src/asm/arm64/opnd.rs
+++ b/zjit/src/asm/arm64/opnd.rs
@@ -119,6 +119,9 @@ pub const X20_REG: A64Reg = A64Reg { num_bits: 64, reg_no: 20 };
 pub const X21_REG: A64Reg = A64Reg { num_bits: 64, reg_no: 21 };
 pub const X22_REG: A64Reg = A64Reg { num_bits: 64, reg_no: 22 };
 
+// link register
+pub const X30_REG: A64Reg = A64Reg { num_bits: 64, reg_no: 30 };
+
 // zero register
 pub const XZR_REG: A64Reg = A64Reg { num_bits: 64, reg_no: 31 };
 
@@ -153,7 +156,7 @@ pub const X26: A64Opnd = A64Opnd::Reg(A64Reg { num_bits: 64, reg_no: 26 });
 pub const X27: A64Opnd = A64Opnd::Reg(A64Reg { num_bits: 64, reg_no: 27 });
 pub const X28: A64Opnd = A64Opnd::Reg(A64Reg { num_bits: 64, reg_no: 28 });
 pub const X29: A64Opnd = A64Opnd::Reg(A64Reg { num_bits: 64, reg_no: 29 });
-pub const X30: A64Opnd = A64Opnd::Reg(A64Reg { num_bits: 64, reg_no: 30 });
+pub const X30: A64Opnd = A64Opnd::Reg(X30_REG);
 pub const X31: A64Opnd = A64Opnd::Reg(XZR_REG);
 
 // 32-bit registers
diff --git a/zjit/src/backend/arm64/mod.rs b/zjit/src/backend/arm64/mod.rs
index ffde567b69..832f3c1e1e 100644
--- a/zjit/src/backend/arm64/mod.rs
+++ b/zjit/src/backend/arm64/mod.rs
@@ -211,6 +211,11 @@ impl Assembler
         vec![X1_REG, X9_REG, X10_REG, X11_REG, X12_REG, X13_REG, X14_REG, X15_REG]
     }
 
+    /// Get the address that the current frame returns to
+    pub fn return_addr_opnd() -> Opnd {
+        Opnd::Reg(X30_REG)
+    }
+
     /// Split platform-specific instructions
     /// The transformations done here are meant to make our lives simpler in later
     /// stages of the compilation pipeline.
@@ -757,7 +762,7 @@ impl Assembler
     /// called when lowering any of the conditional jump instructions.
     fn emit_conditional_jump<const CONDITION: u8>(cb: &mut CodeBlock, target: Target) {
         match target {
-            Target::CodePtr(dst_ptr) | Target::SideExitPtr(dst_ptr) => {
+            Target::CodePtr(dst_ptr) => {
                 let dst_addr = dst_ptr.as_offset();
                 let src_addr = cb.get_write_ptr().as_offset();
@@ -829,8 +834,10 @@ impl Assembler
     }
 
     /// Emit a CBZ or CBNZ which branches when a register is zero or non-zero
-    fn emit_cmp_zero_jump(cb: &mut CodeBlock, reg: A64Opnd, branch_if_zero: bool, target: Target) {
-        if let Target::SideExitPtr(dst_ptr) = target {
+    fn emit_cmp_zero_jump(_cb: &mut CodeBlock, _reg: A64Opnd, _branch_if_zero: bool, target: Target) {
+        if let Target::Label(_) = target {
+            unimplemented!("this should be re-implemented with Label for side exits");
+            /*
             let dst_addr = dst_ptr.as_offset();
             let src_addr = cb.get_write_ptr().as_offset();
@@ -862,6 +869,7 @@ impl Assembler
 
                 br(cb, Assembler::SCRATCH0);
             }
+            */
         } else {
             unreachable!("We should only generate Joz/Jonz with side-exit targets");
         }
@@ -1162,9 +1170,6 @@ impl Assembler
                     Target::CodePtr(dst_ptr) => {
                         emit_jmp_ptr(cb, dst_ptr, true);
                     },
-                    Target::SideExitPtr(dst_ptr) => {
-                        emit_jmp_ptr(cb, dst_ptr, false);
-                    },
                     Target::Label(label_idx) => {
                         // Here we're going to save enough space for
                         // ourselves and then come back and write the
@@ -1297,6 +1302,7 @@ impl Assembler
    pub fn compile_with_regs(self, cb: &mut CodeBlock, regs: Vec<Reg>) -> Option<(CodePtr, Vec<u32>)> {
        let asm = self.arm64_split();
        let mut asm = asm.alloc_regs(regs);
+       asm.compile_side_exits()?;
 
        // Create label instances in the code block
        for (idx, name) in asm.label_names.iter().enumerate() {
diff --git a/zjit/src/backend/lir.rs b/zjit/src/backend/lir.rs
index 5bca786d13..3a85e3cfb5 100644
--- a/zjit/src/backend/lir.rs
+++ b/zjit/src/backend/lir.rs
@@ -1,6 +1,9 @@
+use std::collections::HashMap;
 use std::fmt;
 use std::mem::take;
-use crate::{cruby::VALUE, hir::FrameState};
+use crate::cruby::{Qundef, RUBY_OFFSET_CFP_PC, RUBY_OFFSET_CFP_SP, SIZEOF_VALUE_I32, VM_ENV_DATA_SIZE};
+use crate::state::ZJITState;
+use crate::{cruby::VALUE};
 use crate::backend::current::*;
 use crate::virtualmem::CodePtr;
 use crate::asm::{CodeBlock, Label};
@@ -273,9 +276,7 @@ pub enum Target
    /// Pointer to a piece of ZJIT-generated code
    CodePtr(CodePtr),
    // Side exit with a counter
-   SideExit(FrameState),
-   /// Pointer to a side exit code
-   SideExitPtr(CodePtr),
+   SideExit { pc: *const VALUE, stack: Vec<Opnd>, locals: Vec<Opnd> },
    /// A label within the generated code
    Label(Label),
 }
@@ -292,7 +293,6 @@ impl Target
    pub fn unwrap_code_ptr(&self) -> CodePtr {
        match self {
            Target::CodePtr(ptr) => *ptr,
-           Target::SideExitPtr(ptr) => *ptr,
            _ => unreachable!("trying to unwrap {:?} into code ptr", self)
        }
    }
@@ -539,11 +539,11 @@ impl Insn {
            Insn::Jne(target) |
            Insn::Jnz(target) |
            Insn::Jo(target) |
-           Insn::Jz(target) |
-           Insn::Label(target) |
            Insn::JoMul(target) |
+           Insn::Jz(target) |
            Insn::Joz(_, target) |
            Insn::Jonz(_, target) |
+           Insn::Label(target) |
            Insn::LeaJumpTarget { target, .. } => {
                Some(target)
            }
@@ -697,7 +697,11 @@ impl Insn {
            Insn::Jne(target) |
            Insn::Jnz(target) |
            Insn::Jo(target) |
+           Insn::JoMul(target) |
            Insn::Jz(target) |
+           Insn::Joz(_, target) |
+           Insn::Jonz(_, target) |
+           Insn::Label(target) |
            Insn::LeaJumpTarget { target, .. } => Some(target),
            _ => None
        }
    }
@@ -731,6 +735,63 @@ impl<'a> Iterator for InsnOpndIterator<'a> {
 
    fn next(&mut self) -> Option<Self::Item> {
        match self.insn {
+           Insn::Jbe(target) |
+           Insn::Jb(target) |
+           Insn::Je(target) |
+           Insn::Jl(target) |
+           Insn::Jg(target) |
+           Insn::Jge(target) |
+           Insn::Jmp(target) |
+           Insn::Jne(target) |
+           Insn::Jnz(target) |
+           Insn::Jo(target) |
+           Insn::JoMul(target) |
+           Insn::Jz(target) |
+           Insn::Label(target) |
+           Insn::LeaJumpTarget { target, .. } => {
+               if let Target::SideExit { stack, locals, .. } = target {
+                   let stack_idx = self.idx;
+                   if stack_idx < stack.len() {
+                       let opnd = &stack[stack_idx];
+                       self.idx += 1;
+                       return Some(opnd);
+                   }
+
+                   let local_idx = self.idx - stack.len();
+                   if local_idx < locals.len() {
+                       let opnd = &locals[local_idx];
+                       self.idx += 1;
+                       return Some(opnd);
+                   }
+               }
+               None
+           }
+
+           Insn::Joz(opnd, target) |
+           Insn::Jonz(opnd, target) => {
+               if self.idx == 0 {
+                   self.idx += 1;
+                   return Some(opnd);
+               }
+
+               if let Target::SideExit { stack, locals, .. } = target {
+                   let stack_idx = self.idx - 1;
+                   if stack_idx < stack.len() {
+                       let opnd = &stack[stack_idx];
+                       self.idx += 1;
+                       return Some(opnd);
+                   }
+
+                   let local_idx = stack_idx - stack.len();
+                   if local_idx < locals.len() {
+                       let opnd = &locals[local_idx];
+                       self.idx += 1;
+                       return Some(opnd);
+                   }
+               }
+               None
+           }
+
            Insn::BakeString(_) |
            Insn::Breakpoint |
            Insn::Comment(_) |
@@ -739,20 +800,6 @@ impl<'a> Iterator for InsnOpndIterator<'a> {
            Insn::CPushAll |
            Insn::FrameSetup |
            Insn::FrameTeardown |
-           Insn::Jbe(_) |
-           Insn::Jb(_) |
-           Insn::Je(_) |
-           Insn::Jl(_) |
-           Insn::Jg(_) |
-           Insn::Jge(_) |
-           Insn::Jmp(_) |
-           Insn::Jne(_) |
-           Insn::Jnz(_) |
-           Insn::Jo(_) |
-           Insn::JoMul(_) |
-           Insn::Jz(_) |
-           Insn::Label(_) |
-           Insn::LeaJumpTarget { .. } |
            Insn::PadInvalPatch |
            Insn::PosMarker(_) => None,
@@ -764,8 +811,6 @@ impl<'a> Iterator for InsnOpndIterator<'a> {
            Insn::LiveReg { opnd, .. } |
            Insn::Load { opnd, .. } |
            Insn::LoadSExt { opnd, .. } |
-           Insn::Joz(opnd, _) |
-           Insn::Jonz(opnd, _) |
            Insn::Not { opnd, .. } => {
                match self.idx {
                    0 => {
@@ -845,6 +890,63 @@ impl<'a> InsnOpndMutIterator<'a> {
 
    pub(super) fn next(&mut self) -> Option<&mut Opnd> {
        match self.insn {
+           Insn::Jbe(target) |
+           Insn::Jb(target) |
+           Insn::Je(target) |
+           Insn::Jl(target) |
+           Insn::Jg(target) |
+           Insn::Jge(target) |
+           Insn::Jmp(target) |
+           Insn::Jne(target) |
+           Insn::Jnz(target) |
+           Insn::Jo(target) |
+           Insn::JoMul(target) |
+           Insn::Jz(target) |
+           Insn::Label(target) |
+           Insn::LeaJumpTarget { target, .. } => {
+               if let Target::SideExit { stack, locals, .. } = target {
+                   let stack_idx = self.idx;
+                   if stack_idx < stack.len() {
+                       let opnd = &mut stack[stack_idx];
+                       self.idx += 1;
+                       return Some(opnd);
+                   }
+
+                   let local_idx = self.idx - stack.len();
+                   if local_idx < locals.len() {
+                       let opnd = &mut locals[local_idx];
+                       self.idx += 1;
+                       return Some(opnd);
+                   }
+               }
+               None
+           }
+
+           Insn::Joz(opnd, target) |
+           Insn::Jonz(opnd, target) => {
+               if self.idx == 0 {
+                   self.idx += 1;
+                   return Some(opnd);
+               }
+
+               if let Target::SideExit { stack, locals, .. } = target {
+                   let stack_idx = self.idx - 1;
+                   if stack_idx < stack.len() {
+                       let opnd = &mut stack[stack_idx];
+                       self.idx += 1;
+                       return Some(opnd);
+                   }
+
+                   let local_idx = stack_idx - stack.len();
+                   if local_idx < locals.len() {
+                       let opnd = &mut locals[local_idx];
+                       self.idx += 1;
+                       return Some(opnd);
+                   }
+               }
+               None
+           }
+
            Insn::BakeString(_) |
            Insn::Breakpoint |
            Insn::Comment(_) |
@@ -853,20 +955,6 @@ impl<'a> InsnOpndMutIterator<'a> {
            Insn::CPushAll |
            Insn::FrameSetup |
            Insn::FrameTeardown |
-           Insn::Jbe(_) |
-           Insn::Jb(_) |
-           Insn::Je(_) |
-           Insn::Jl(_) |
-           Insn::Jg(_) |
-           Insn::Jge(_) |
-           Insn::Jmp(_) |
-           Insn::Jne(_) |
-           Insn::Jnz(_) |
-           Insn::Jo(_) |
-           Insn::JoMul(_) |
-           Insn::Jz(_) |
-           Insn::Label(_) |
-           Insn::LeaJumpTarget { .. } |
            Insn::PadInvalPatch |
            Insn::PosMarker(_) => None,
@@ -878,8 +966,6 @@ impl<'a> InsnOpndMutIterator<'a> {
            Insn::LiveReg { opnd, .. } |
            Insn::Load { opnd, .. } |
            Insn::LoadSExt { opnd, .. } |
-           Insn::Joz(opnd, _) |
-           Insn::Jonz(opnd, _) |
            Insn::Not { opnd, .. } => {
                match self.idx {
                    0 => {
@@ -1649,10 +1735,8 @@ impl Assembler
    /// Compile the instructions down to machine code.
    /// Can fail due to lack of code memory and inopportune code placement, among other reasons.
    #[must_use]
-   pub fn compile(mut self, cb: &mut CodeBlock) -> Option<(CodePtr, Vec<u32>)>
+   pub fn compile(self, cb: &mut CodeBlock) -> Option<(CodePtr, Vec<u32>)>
    {
-       self.compile_side_exits(cb)?;
-
        #[cfg(feature = "disasm")]
        let start_addr = cb.get_write_ptr();
        let alloc_regs = Self::get_alloc_regs();
@@ -1669,47 +1753,74 @@ impl Assembler
 
    /// Compile Target::SideExit and convert it into Target::CodePtr for all instructions
    #[must_use]
-   pub fn compile_side_exits(&mut self, cb: &mut CodeBlock) -> Option<()> {
-       for insn in self.insns.iter_mut() {
-           if let Some(target) = insn.target_mut() {
-               if let Target::SideExit(state) = target {
-                   let side_exit_ptr = cb.get_write_ptr();
-                   let mut asm = Assembler::new();
-                   asm_comment!(asm, "side exit: {state}");
-                   asm.ccall(Self::rb_zjit_side_exit as *const u8, vec![]);
-                   asm.compile(cb)?;
-                   *target = Target::SideExitPtr(side_exit_ptr);
-               }
+   pub fn compile_side_exits(&mut self) -> Option<()> {
+       let mut targets = HashMap::new();
+       for (idx, insn) in self.insns.iter().enumerate() {
+           if let Some(target @ Target::SideExit { .. }) = insn.target() {
+               targets.insert(idx, target.clone());
            }
        }
-       Some(())
-   }
 
-   #[unsafe(no_mangle)]
-   extern "C" fn rb_zjit_side_exit() {
-       unimplemented!("side exits are not implemented yet");
-   }
+       for (idx, target) in targets {
+           // Compile a side exit. Note that this is past the split pass and alloc_regs(),
+           // so you can't use a VReg or an instruction that needs to be split.
+           if let Target::SideExit { pc, stack, locals } = target {
+               let side_exit_label = self.new_label("side_exit".into());
+               self.write_label(side_exit_label.clone());
+
+               // Load an operand that cannot be used as a source of Insn::Store
+               fn split_store_source(asm: &mut Assembler, opnd: Opnd) -> Opnd {
+                   if matches!(opnd, Opnd::Mem(_) | Opnd::Value(_)) ||
+                           (cfg!(target_arch = "aarch64") && matches!(opnd, Opnd::UImm(_))) {
+                       asm.load_into(Opnd::Reg(Assembler::SCRATCH_REG), opnd);
+                       Opnd::Reg(Assembler::SCRATCH_REG)
+                   } else {
+                       opnd
+                   }
+               }
 
-   /*
-   /// Compile with a limited number of registers. Used only for unit tests.
-   #[cfg(test)]
-   pub fn compile_with_num_regs(self, cb: &mut CodeBlock, num_regs: usize) -> (CodePtr, Vec<u32>)
-   {
-       let mut alloc_regs = Self::get_alloc_regs();
-       let alloc_regs = alloc_regs.drain(0..num_regs).collect();
-       self.compile_with_regs(cb, None, alloc_regs).unwrap()
-   }
+               asm_comment!(self, "write stack slots: {stack:?}");
+               for (idx, &opnd) in stack.iter().enumerate() {
+                   let opnd = split_store_source(self, opnd);
+                   self.store(Opnd::mem(64, SP, idx as i32 * SIZEOF_VALUE_I32), opnd);
+               }
+
+               asm_comment!(self, "write locals: {locals:?}");
+               for (idx, &opnd) in locals.iter().enumerate() {
+                   let opnd = split_store_source(self, opnd);
+                   self.store(Opnd::mem(64, SP, (-(VM_ENV_DATA_SIZE as i32) - locals.len() as i32 + idx as i32) * SIZEOF_VALUE_I32), opnd);
+               }
+
+               asm_comment!(self, "save cfp->pc");
+               self.load_into(Opnd::Reg(Assembler::SCRATCH_REG), Opnd::const_ptr(pc as *const u8));
+               self.store(Opnd::mem(64, CFP, RUBY_OFFSET_CFP_PC), Opnd::Reg(Assembler::SCRATCH_REG));
+
+               asm_comment!(self, "save cfp->sp");
+               self.lea_into(Opnd::Reg(Assembler::SCRATCH_REG), Opnd::mem(64, SP, stack.len() as i32 * SIZEOF_VALUE_I32));
+               let cfp_sp = Opnd::mem(64, CFP, RUBY_OFFSET_CFP_SP);
+               self.store(cfp_sp, Opnd::Reg(Assembler::SCRATCH_REG));
 
-   /// Return true if the next ccall() is expected to be leaf.
-   pub fn get_leaf_ccall(&mut self) -> bool {
-       self.leaf_ccall
+               asm_comment!(self, "rewind caller frames");
+               self.mov(C_ARG_OPNDS[0], Assembler::return_addr_opnd());
+               self.ccall(Self::rewind_caller_frames as *const u8, vec![]);
+
+               asm_comment!(self, "exit to the interpreter");
+               self.frame_teardown();
+               self.mov(C_RET_OPND, Opnd::UImm(Qundef.as_u64()));
+               self.cret(C_RET_OPND);
+
+               *self.insns[idx].target_mut().unwrap() = side_exit_label;
+           }
+       }
+       Some(())
    }
 
-   /// Assert that the next ccall() is going to be leaf.
-   pub fn expect_leaf_ccall(&mut self) {
-       self.leaf_ccall = true;
+   #[unsafe(no_mangle)]
+   extern "C" fn rewind_caller_frames(addr: *const u8) {
+       if ZJITState::is_iseq_return_addr(addr) {
+           unimplemented!("Can't side-exit from JIT-JIT call: rewind_caller_frames is not implemented yet");
+       }
    }
-   */
 }
 
 impl fmt::Debug for Assembler {
@@ -1970,6 +2081,10 @@ impl Assembler {
        out
    }
 
+   pub fn lea_into(&mut self, out: Opnd, opnd: Opnd) {
+       self.push_insn(Insn::Lea { opnd, out });
+   }
+
    #[must_use]
    pub fn lea_jump_target(&mut self, target: Target) -> Opnd {
        let out = self.new_vreg(Opnd::DEFAULT_NUM_BITS);
diff --git a/zjit/src/backend/x86_64/mod.rs b/zjit/src/backend/x86_64/mod.rs
index f11b07c1b7..cf62cdd7f5 100644
--- a/zjit/src/backend/x86_64/mod.rs
+++ b/zjit/src/backend/x86_64/mod.rs
@@ -109,6 +109,11 @@ impl Assembler
        vec![RAX_REG, RCX_REG, RDX_REG, RSI_REG, RDI_REG, R8_REG, R9_REG, R10_REG, R11_REG]
    }
 
+   /// Get the address that the current frame returns to
+   pub fn return_addr_opnd() -> Opnd {
+       Opnd::mem(64, Opnd::Reg(RSP_REG), 0)
+   }
+
    // These are the callee-saved registers in the x86-64 SysV ABI
    // RBX, RSP, RBP, and R12–R15
@@ -665,7 +670,7 @@ impl Assembler
            // Conditional jump to a label
            Insn::Jmp(target) => {
                match *target {
-                   Target::CodePtr(code_ptr) | Target::SideExitPtr(code_ptr) => jmp_ptr(cb, code_ptr),
+                   Target::CodePtr(code_ptr) => jmp_ptr(cb, code_ptr),
                    Target::Label(label) => jmp_label(cb, label),
                    Target::SideExit { .. } => unreachable!("Target::SideExit should have been compiled by compile_side_exits"),
                }
@@ -673,7 +678,7 @@ impl Assembler
 
            Insn::Je(target) => {
                match *target {
-                   Target::CodePtr(code_ptr) | Target::SideExitPtr(code_ptr) => je_ptr(cb, code_ptr),
+                   Target::CodePtr(code_ptr) => je_ptr(cb, code_ptr),
                    Target::Label(label) => je_label(cb, label),
                    Target::SideExit { .. } => unreachable!("Target::SideExit should have been compiled by compile_side_exits"),
                }
@@ -681,7 +686,7 @@ impl Assembler
 
            Insn::Jne(target) => {
                match *target {
-                   Target::CodePtr(code_ptr) | Target::SideExitPtr(code_ptr) => jne_ptr(cb, code_ptr),
+                   Target::CodePtr(code_ptr) => jne_ptr(cb, code_ptr),
                    Target::Label(label) => jne_label(cb, label),
                    Target::SideExit { .. } => unreachable!("Target::SideExit should have been compiled by compile_side_exits"),
                }
@@ -689,7 +694,7 @@ impl Assembler
 
            Insn::Jl(target) => {
                match *target {
-                   Target::CodePtr(code_ptr) | Target::SideExitPtr(code_ptr) => jl_ptr(cb, code_ptr),
+                   Target::CodePtr(code_ptr) => jl_ptr(cb, code_ptr),
                    Target::Label(label) => jl_label(cb, label),
                    Target::SideExit { .. } => unreachable!("Target::SideExit should have been compiled by compile_side_exits"),
                }
@@ -697,7 +702,7 @@ impl Assembler
 
            Insn::Jg(target) => {
                match *target {
-                   Target::CodePtr(code_ptr) | Target::SideExitPtr(code_ptr) => jg_ptr(cb, code_ptr),
+                   Target::CodePtr(code_ptr) => jg_ptr(cb, code_ptr),
                    Target::Label(label) => jg_label(cb, label),
                    Target::SideExit { .. } => unreachable!("Target::SideExit should have been compiled by compile_side_exits"),
                }
@@ -705,7 +710,7 @@ impl Assembler
 
            Insn::Jge(target) => {
                match *target {
-                   Target::CodePtr(code_ptr) | Target::SideExitPtr(code_ptr) => jge_ptr(cb, code_ptr),
+                   Target::CodePtr(code_ptr) => jge_ptr(cb, code_ptr),
                    Target::Label(label) => jge_label(cb, label),
                    Target::SideExit { .. } => unreachable!("Target::SideExit should have been compiled by compile_side_exits"),
                }
@@ -713,7 +718,7 @@ impl Assembler
 
            Insn::Jbe(target) => {
                match *target {
-                   Target::CodePtr(code_ptr) | Target::SideExitPtr(code_ptr) => jbe_ptr(cb, code_ptr),
+                   Target::CodePtr(code_ptr) => jbe_ptr(cb, code_ptr),
                    Target::Label(label) => jbe_label(cb, label),
                    Target::SideExit { .. } => unreachable!("Target::SideExit should have been compiled by compile_side_exits"),
                }
@@ -721,7 +726,7 @@ impl Assembler
 
            Insn::Jb(target) => {
                match *target {
-                   Target::CodePtr(code_ptr) | Target::SideExitPtr(code_ptr) => jb_ptr(cb, code_ptr),
+                   Target::CodePtr(code_ptr) => jb_ptr(cb, code_ptr),
                    Target::Label(label) => jb_label(cb, label),
                    Target::SideExit { .. } => unreachable!("Target::SideExit should have been compiled by compile_side_exits"),
                }
@@ -729,7 +734,7 @@ impl Assembler
 
            Insn::Jz(target) => {
                match *target {
-                   Target::CodePtr(code_ptr) | Target::SideExitPtr(code_ptr) => jz_ptr(cb, code_ptr),
+                   Target::CodePtr(code_ptr) => jz_ptr(cb, code_ptr),
                    Target::Label(label) => jz_label(cb, label),
                    Target::SideExit { .. } => unreachable!("Target::SideExit should have been compiled by compile_side_exits"),
                }
@@ -737,7 +742,7 @@ impl Assembler
 
            Insn::Jnz(target) => {
                match *target {
-                   Target::CodePtr(code_ptr) | Target::SideExitPtr(code_ptr) => jnz_ptr(cb, code_ptr),
+                   Target::CodePtr(code_ptr) => jnz_ptr(cb, code_ptr),
                    Target::Label(label) => jnz_label(cb, label),
                    Target::SideExit { .. } => unreachable!("Target::SideExit should have been compiled by compile_side_exits"),
                }
@@ -746,7 +751,7 @@ impl Assembler
            Insn::Jo(target) |
            Insn::JoMul(target) => {
                match *target {
-                   Target::CodePtr(code_ptr) | Target::SideExitPtr(code_ptr) => jo_ptr(cb, code_ptr),
+                   Target::CodePtr(code_ptr) => jo_ptr(cb, code_ptr),
                    Target::Label(label) => jo_label(cb, label),
                    Target::SideExit { .. } => unreachable!("Target::SideExit should have been compiled by compile_side_exits"),
                }
@@ -831,6 +836,7 @@ impl Assembler
    pub fn compile_with_regs(self, cb: &mut CodeBlock, regs: Vec<Reg>) -> Option<(CodePtr, Vec<u32>)> {
        let asm = self.x86_split();
        let mut asm = asm.alloc_regs(regs);
+       asm.compile_side_exits()?;
 
        // Create label instances in the code block
        for (idx, name) in asm.label_names.iter().enumerate() {
diff --git a/zjit/src/codegen.rs b/zjit/src/codegen.rs
index d5202486f1..c1fad915f0 100644
--- a/zjit/src/codegen.rs
+++ b/zjit/src/codegen.rs
@@ -260,9 +260,9 @@ fn gen_insn(cb: &mut CodeBlock, jit: &mut JITState, asm: &mut Assembler, functio
        Insn::SendWithoutBlock { call_info, cd, state, self_val, args, .. } => gen_send_without_block(jit, asm, call_info, *cd, &function.frame_state(*state), self_val, args)?,
        Insn::SendWithoutBlockDirect { iseq, self_val, args, .. } => gen_send_without_block_direct(cb, jit, asm, *iseq, opnd!(self_val), args)?,
        Insn::Return { val } => return Some(gen_return(asm, opnd!(val))?),
-       Insn::FixnumAdd { left, right, state } => gen_fixnum_add(asm, opnd!(left), opnd!(right), &function.frame_state(*state))?,
-       Insn::FixnumSub { left, right, state } => gen_fixnum_sub(asm, opnd!(left), opnd!(right), &function.frame_state(*state))?,
-       Insn::FixnumMult { left, right, state } => gen_fixnum_mult(asm, opnd!(left), opnd!(right), &function.frame_state(*state))?,
+       Insn::FixnumAdd { left, right, state } => gen_fixnum_add(jit, asm, opnd!(left), opnd!(right), &function.frame_state(*state))?,
+       Insn::FixnumSub { left, right, state } => gen_fixnum_sub(jit, asm, opnd!(left), opnd!(right), &function.frame_state(*state))?,
+       Insn::FixnumMult { left, right, state } => gen_fixnum_mult(jit, asm, opnd!(left), opnd!(right), &function.frame_state(*state))?,
        Insn::FixnumEq { left, right } => gen_fixnum_eq(asm, opnd!(left), opnd!(right))?,
        Insn::FixnumNeq { left, right } => gen_fixnum_neq(asm, opnd!(left), opnd!(right))?,
        Insn::FixnumLt { left, right } => gen_fixnum_lt(asm, opnd!(left), opnd!(right))?,
@@ -270,8 +270,8 @@ fn gen_insn(cb: &mut CodeBlock, jit: &mut JITState, asm: &mut Assembler, functio
        Insn::FixnumGt { left, right } => gen_fixnum_gt(asm, opnd!(left), opnd!(right))?,
        Insn::FixnumGe { left, right } => gen_fixnum_ge(asm, opnd!(left), opnd!(right))?,
        Insn::Test { val } => gen_test(asm, opnd!(val))?,
-       Insn::GuardType { val, guard_type, state } => gen_guard_type(asm, opnd!(val), *guard_type, &function.frame_state(*state))?,
-       Insn::GuardBitEquals { val, expected, state } => gen_guard_bit_equals(asm, opnd!(val), *expected, &function.frame_state(*state))?,
+       Insn::GuardType { val, guard_type, state } => gen_guard_type(jit, asm, opnd!(val), *guard_type, &function.frame_state(*state))?,
+       Insn::GuardBitEquals { val, expected, state } => gen_guard_bit_equals(jit, asm, opnd!(val), *expected, &function.frame_state(*state))?,
        Insn::PatchPoint(_) => return Some(()), // For now, rb_zjit_bop_redefined() panics. TODO: leave a patch point and fix rb_zjit_bop_redefined()
        Insn::CCall { cfun, args, name: _, return_type: _, elidable: _ } => gen_ccall(jit, asm, *cfun, args)?,
        _ => {
@@ -569,27 +569,27 @@ fn gen_return(asm: &mut Assembler, val: lir::Opnd) -> Option<()> {
 }
 
 /// Compile Fixnum + Fixnum
-fn gen_fixnum_add(asm: &mut Assembler, left: lir::Opnd, right: lir::Opnd, state: &FrameState) -> Option<lir::Opnd> {
+fn gen_fixnum_add(jit: &mut JITState, asm: &mut Assembler, left: lir::Opnd, right: lir::Opnd, state: &FrameState) -> Option<lir::Opnd> {
    // Add left + right and test for overflow
    let left_untag = asm.sub(left, Opnd::Imm(1));
    let out_val = asm.add(left_untag, right);
-   asm.jo(Target::SideExit(state.clone()));
+   asm.jo(side_exit(jit, state)?);
 
    Some(out_val)
 }
 
 /// Compile Fixnum - Fixnum
-fn gen_fixnum_sub(asm: &mut Assembler, left: lir::Opnd, right: lir::Opnd, state: &FrameState) -> Option<lir::Opnd> {
+fn gen_fixnum_sub(jit: &mut JITState, asm: &mut Assembler, left: lir::Opnd, right: lir::Opnd, state: &FrameState) -> Option<lir::Opnd> {
    // Subtract left - right and test for overflow
    let val_untag = asm.sub(left, right);
-   asm.jo(Target::SideExit(state.clone()));
+   asm.jo(side_exit(jit, state)?);
    let out_val = asm.add(val_untag, Opnd::Imm(1));
 
    Some(out_val)
 }
 
 /// Compile Fixnum * Fixnum
-fn gen_fixnum_mult(asm: &mut Assembler, left: lir::Opnd, right: lir::Opnd, state: &FrameState) -> Option<lir::Opnd> {
+fn gen_fixnum_mult(jit: &mut JITState, asm: &mut Assembler, left: lir::Opnd, right: lir::Opnd, state: &FrameState) -> Option<lir::Opnd> {
    // Do some bitwise gymnastics to handle tag bits
    // x * y is translated to (x >> 1) * (y - 1) + 1
    let left_untag = asm.rshift(left, Opnd::UImm(1));
@@ -597,7 +597,7 @@ fn gen_fixnum_mult(asm: &mut Assembler, left: lir::Opnd, right: lir::Opnd, state
    let out_val = asm.mul(left_untag, right_untag);
 
    // Test for overflow
-   asm.jo_mul(Target::SideExit(state.clone()));
+   asm.jo_mul(side_exit(jit, state)?);
    let out_val = asm.add(out_val, Opnd::UImm(1));
 
    Some(out_val)
@@ -651,11 +651,11 @@ fn gen_test(asm: &mut Assembler, val: lir::Opnd) -> Option<lir::Opnd> {
 }
 
 /// Compile a type check with a side exit
-fn gen_guard_type(asm: &mut Assembler, val: lir::Opnd, guard_type: Type, state: &FrameState) -> Option<lir::Opnd> {
+fn gen_guard_type(jit: &mut JITState, asm: &mut Assembler, val: lir::Opnd, guard_type: Type, state: &FrameState) -> Option<lir::Opnd> {
    if guard_type.is_subtype(Fixnum) {
        // Check if opnd is Fixnum
        asm.test(val, Opnd::UImm(RUBY_FIXNUM_FLAG as u64));
-       asm.jz(Target::SideExit(state.clone()));
+       asm.jz(side_exit(jit, state)?);
    } else {
        unimplemented!("unsupported type: {guard_type}");
    }
@@ -663,9 +663,9 @@ fn gen_guard_type(asm: &mut Assembler, val: lir::Opnd, guard_type: Type, state:
 }
 
 /// Compile an identity check with a side exit
-fn gen_guard_bit_equals(asm: &mut Assembler, val: lir::Opnd, expected: VALUE, state: &FrameState) -> Option<lir::Opnd> {
+fn gen_guard_bit_equals(jit: &mut JITState, asm: &mut Assembler, val: lir::Opnd, expected: VALUE, state: &FrameState) -> Option<lir::Opnd> {
    asm.cmp(val, Opnd::UImm(expected.into()));
-   asm.jnz(Target::SideExit(state.clone()));
+   asm.jnz(side_exit(jit, state)?);
 
    Some(val)
 }
@@ -731,6 +731,26 @@ fn compile_iseq(iseq: IseqPtr) -> Option<Function> {
    Some(function)
 }
 
+/// Build a Target::SideExit out of a FrameState
+fn side_exit(jit: &mut JITState, state: &FrameState) -> Option<Target> {
+   let mut stack = Vec::new();
+   for &insn_id in state.stack() {
+       stack.push(jit.get_opnd(insn_id)?);
+   }
+
+   let mut locals = Vec::new();
+   for &insn_id in state.locals() {
+       locals.push(jit.get_opnd(insn_id)?);
+   }
+
+   let target = Target::SideExit {
+       pc: state.pc,
+       stack,
+       locals,
+   };
+   Some(target)
+}
+
 impl Assembler {
    /// Make a C call while marking the start and end positions of it
    fn ccall_with_branch(&mut self, fptr: *const u8, opnds: Vec<Opnd>, branch: &Rc<Branch>) -> Opnd {
@@ -744,8 +764,9 @@ impl Assembler {
        move |code_ptr, _| {
            start_branch.start_addr.set(Some(code_ptr));
        },
-       move |code_ptr, _| {
+       move |code_ptr, cb| {
            end_branch.end_addr.set(Some(code_ptr));
+           ZJITState::add_iseq_return_addr(code_ptr.raw_ptr(cb));
        },
    )
diff --git a/zjit/src/hir.rs b/zjit/src/hir.rs
index 1d163e5741..14a94cdff7 100644
--- a/zjit/src/hir.rs
+++ b/zjit/src/hir.rs
@@ -1719,6 +1719,11 @@ impl FrameState {
        self.stack.iter()
    }
 
+   /// Iterate over all local variables
+   pub fn locals(&self) -> Iter<InsnId> {
+       self.locals.iter()
+   }
+
    /// Push a stack operand
    fn stack_push(&mut self, opnd: InsnId) {
        self.stack.push(opnd);
diff --git a/zjit/src/state.rs b/zjit/src/state.rs
index e846ee6f8d..e8c389a5f8 100644
--- a/zjit/src/state.rs
+++ b/zjit/src/state.rs
@@ -1,3 +1,5 @@
+use std::collections::HashSet;
+
 use crate::cruby::{self, rb_bug_panic_hook, EcPtr, Qnil, VALUE};
 use crate::cruby_methods;
 use crate::invariants::Invariants;
@@ -29,6 +31,9 @@ pub struct ZJITState {
 
    /// Properties of core library methods
    method_annotations: cruby_methods::Annotations,
+
+   /// The address of the instruction that JIT-to-JIT calls return to
+   iseq_return_addrs: HashSet<*const u8>,
 }
 
 /// Private singleton instance of the codegen globals
@@ -82,7 +87,8 @@ impl ZJITState {
            options,
            invariants: Invariants::default(),
            assert_compiles: false,
-           method_annotations: cruby_methods::init()
+           method_annotations: cruby_methods::init(),
+           iseq_return_addrs: HashSet::new(),
        };
        unsafe { ZJIT_STATE = Some(zjit_state); }
    }
@@ -126,6 +132,16 @@ impl ZJITState {
        let instance = ZJITState::get_instance();
        instance.assert_compiles = true;
    }
+
+   /// Record an address that a JIT-to-JIT call returns to
+   pub fn add_iseq_return_addr(addr: *const u8) {
+       ZJITState::get_instance().iseq_return_addrs.insert(addr);
+   }
+
+   /// Returns true if a JIT-to-JIT call returns to a given address
+   pub fn is_iseq_return_addr(addr: *const u8) -> bool {
+       ZJITState::get_instance().iseq_return_addrs.contains(&addr)
+   }
 }
 
 /// Initialize ZJIT, given options allocated by rb_zjit_init_options()
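Reviewer's note: the codegen and state hunks are what make the nested-exit test above fail gracefully for now. Every JIT-to-JIT call records its return address via `add_iseq_return_addr`, and the side-exit trampoline calls `rewind_caller_frames`, which checks that set and hits `unimplemented!()` when the exit would need to unwind a JIT caller frame. Below is a self-contained sketch of just that bookkeeping, simplified to a plain struct holding the `HashSet` (the real code keeps it in the `ZJITState` singleton and takes raw code pointers from the code block):

```rust
use std::collections::HashSet;

/// Simplified stand-in for ZJITState's iseq_return_addrs bookkeeping.
struct ReturnAddrs {
    iseq_return_addrs: HashSet<*const u8>,
}

impl ReturnAddrs {
    fn new() -> Self {
        Self { iseq_return_addrs: HashSet::new() }
    }

    /// Record an address that a JIT-to-JIT call returns to.
    fn add_iseq_return_addr(&mut self, addr: *const u8) {
        self.iseq_return_addrs.insert(addr);
    }

    /// Returns true if a JIT-to-JIT call returns to the given address.
    fn is_iseq_return_addr(&self, addr: *const u8) -> bool {
        self.iseq_return_addrs.contains(&addr)
    }
}

fn main() {
    let mut addrs = ReturnAddrs::new();
    let ret = 0x4000_usize as *const u8; // pretend this is a recorded return site
    addrs.add_iseq_return_addr(ret);

    // A side exit whose return address is a recorded JIT-to-JIT return site
    // cannot be handled yet: the commit's rewind_caller_frames reaches
    // unimplemented!() in exactly this case, which is why the nested test
    // is omitted for now.
    assert!(addrs.is_iseq_return_addr(ret));
    assert!(!addrs.is_iseq_return_addr(std::ptr::null()));
}
```

The design choice to consult a set of return addresses, rather than walking control frames, works because the backend's `return_addr_opnd()` (LR on arm64, `[rsp]` on x86-64) gives the exit trampoline the exact address the current frame will return to, so membership in the set distinguishes "returning to the interpreter" from "returning into JIT code".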