summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--test/ruby/test_zjit.rb18
-rw-r--r--zjit/src/asm/arm64/mod.rs4
-rw-r--r--zjit/src/asm/arm64/opnd.rs5
-rw-r--r--zjit/src/backend/arm64/mod.rs18
-rw-r--r--zjit/src/backend/lir.rs265
-rw-r--r--zjit/src/backend/x86_64/mod.rs28
-rw-r--r--zjit/src/codegen.rs53
-rw-r--r--zjit/src/hir.rs5
-rw-r--r--zjit/src/state.rs18
9 files changed, 302 insertions, 112 deletions
diff --git a/test/ruby/test_zjit.rb b/test/ruby/test_zjit.rb
index 796851a9bf..aba05ddebd 100644
--- a/test/ruby/test_zjit.rb
+++ b/test/ruby/test_zjit.rb
@@ -94,6 +94,24 @@ class TestZJIT < Test::Unit::TestCase
}, call_threshold: 2
end
+ def test_opt_plus_type_guard_exit
+ assert_compiles '[3, 3.0]', %q{
+ def test(a) = 1 + a
+ test(1) # profile opt_plus
+ [test(2), test(2.0)]
+ }, call_threshold: 2
+ end
+
+ def test_opt_plus_type_guard_nested_exit
+ omit 'rewind_caller_frames is not implemented yet'
+ assert_compiles '[3, 3.0]', %q{
+ def side_exit(n) = 1 + n
+ def jit_frame(n) = 1 + side_exit(n)
+ def entry(n) = jit_frame(n)
+ [entry(2), entry(2.0)]
+ }, call_threshold: 2
+ end
+
# Test argument ordering
def test_opt_minus
assert_compiles '2', %q{
diff --git a/zjit/src/asm/arm64/mod.rs b/zjit/src/asm/arm64/mod.rs
index a5d73d71a5..1e1b125eaa 100644
--- a/zjit/src/asm/arm64/mod.rs
+++ b/zjit/src/asm/arm64/mod.rs
@@ -644,7 +644,7 @@ pub fn mov(cb: &mut CodeBlock, rd: A64Opnd, rm: A64Opnd) {
LogicalImm::mov(rd.reg_no, bitmask_imm, rd.num_bits).into()
},
- _ => panic!("Invalid operand combination to mov instruction")
+ _ => panic!("Invalid operand combination to mov instruction: {rd:?}, {rm:?}")
};
cb.write_bytes(&bytes);
@@ -940,7 +940,7 @@ pub fn stur(cb: &mut CodeBlock, rt: A64Opnd, rn: A64Opnd) {
LoadStore::stur(rt.reg_no, rn.base_reg_no, rn.disp as i16, rn.num_bits).into()
},
- _ => panic!("Invalid operand combination to stur instruction.")
+ _ => panic!("Invalid operand combination to stur instruction: {rt:?}, {rn:?}")
};
cb.write_bytes(&bytes);
diff --git a/zjit/src/asm/arm64/opnd.rs b/zjit/src/asm/arm64/opnd.rs
index 6e31851504..28422b7476 100644
--- a/zjit/src/asm/arm64/opnd.rs
+++ b/zjit/src/asm/arm64/opnd.rs
@@ -119,6 +119,9 @@ pub const X20_REG: A64Reg = A64Reg { num_bits: 64, reg_no: 20 };
pub const X21_REG: A64Reg = A64Reg { num_bits: 64, reg_no: 21 };
pub const X22_REG: A64Reg = A64Reg { num_bits: 64, reg_no: 22 };
+// link register
+pub const X30_REG: A64Reg = A64Reg { num_bits: 64, reg_no: 30 };
+
// zero register
pub const XZR_REG: A64Reg = A64Reg { num_bits: 64, reg_no: 31 };
@@ -153,7 +156,7 @@ pub const X26: A64Opnd = A64Opnd::Reg(A64Reg { num_bits: 64, reg_no: 26 });
pub const X27: A64Opnd = A64Opnd::Reg(A64Reg { num_bits: 64, reg_no: 27 });
pub const X28: A64Opnd = A64Opnd::Reg(A64Reg { num_bits: 64, reg_no: 28 });
pub const X29: A64Opnd = A64Opnd::Reg(A64Reg { num_bits: 64, reg_no: 29 });
-pub const X30: A64Opnd = A64Opnd::Reg(A64Reg { num_bits: 64, reg_no: 30 });
+pub const X30: A64Opnd = A64Opnd::Reg(X30_REG);
pub const X31: A64Opnd = A64Opnd::Reg(XZR_REG);
// 32-bit registers
diff --git a/zjit/src/backend/arm64/mod.rs b/zjit/src/backend/arm64/mod.rs
index ffde567b69..832f3c1e1e 100644
--- a/zjit/src/backend/arm64/mod.rs
+++ b/zjit/src/backend/arm64/mod.rs
@@ -211,6 +211,11 @@ impl Assembler
vec![X1_REG, X9_REG, X10_REG, X11_REG, X12_REG, X13_REG, X14_REG, X15_REG]
}
+ /// Get the address that the current frame returns to
+ pub fn return_addr_opnd() -> Opnd {
+ Opnd::Reg(X30_REG)
+ }
+
/// Split platform-specific instructions
/// The transformations done here are meant to make our lives simpler in later
/// stages of the compilation pipeline.
@@ -757,7 +762,7 @@ impl Assembler
/// called when lowering any of the conditional jump instructions.
fn emit_conditional_jump<const CONDITION: u8>(cb: &mut CodeBlock, target: Target) {
match target {
- Target::CodePtr(dst_ptr) | Target::SideExitPtr(dst_ptr) => {
+ Target::CodePtr(dst_ptr) => {
let dst_addr = dst_ptr.as_offset();
let src_addr = cb.get_write_ptr().as_offset();
@@ -829,8 +834,10 @@ impl Assembler
}
/// Emit a CBZ or CBNZ which branches when a register is zero or non-zero
- fn emit_cmp_zero_jump(cb: &mut CodeBlock, reg: A64Opnd, branch_if_zero: bool, target: Target) {
- if let Target::SideExitPtr(dst_ptr) = target {
+ fn emit_cmp_zero_jump(_cb: &mut CodeBlock, _reg: A64Opnd, _branch_if_zero: bool, target: Target) {
+ if let Target::Label(_) = target {
+ unimplemented!("this should be re-implemented with Label for side exits");
+ /*
let dst_addr = dst_ptr.as_offset();
let src_addr = cb.get_write_ptr().as_offset();
@@ -862,6 +869,7 @@ impl Assembler
br(cb, Assembler::SCRATCH0);
}
+ */
} else {
unreachable!("We should only generate Joz/Jonz with side-exit targets");
}
@@ -1162,9 +1170,6 @@ impl Assembler
Target::CodePtr(dst_ptr) => {
emit_jmp_ptr(cb, dst_ptr, true);
},
- Target::SideExitPtr(dst_ptr) => {
- emit_jmp_ptr(cb, dst_ptr, false);
- },
Target::Label(label_idx) => {
// Here we're going to save enough space for
// ourselves and then come back and write the
@@ -1297,6 +1302,7 @@ impl Assembler
pub fn compile_with_regs(self, cb: &mut CodeBlock, regs: Vec<Reg>) -> Option<(CodePtr, Vec<u32>)> {
let asm = self.arm64_split();
let mut asm = asm.alloc_regs(regs);
+ asm.compile_side_exits()?;
// Create label instances in the code block
for (idx, name) in asm.label_names.iter().enumerate() {
diff --git a/zjit/src/backend/lir.rs b/zjit/src/backend/lir.rs
index 5bca786d13..3a85e3cfb5 100644
--- a/zjit/src/backend/lir.rs
+++ b/zjit/src/backend/lir.rs
@@ -1,6 +1,9 @@
+use std::collections::HashMap;
use std::fmt;
use std::mem::take;
-use crate::{cruby::VALUE, hir::FrameState};
+use crate::cruby::{Qundef, RUBY_OFFSET_CFP_PC, RUBY_OFFSET_CFP_SP, SIZEOF_VALUE_I32, VM_ENV_DATA_SIZE};
+use crate::state::ZJITState;
+use crate::{cruby::VALUE};
use crate::backend::current::*;
use crate::virtualmem::CodePtr;
use crate::asm::{CodeBlock, Label};
@@ -273,9 +276,7 @@ pub enum Target
/// Pointer to a piece of ZJIT-generated code
CodePtr(CodePtr),
// Side exit with a counter
- SideExit(FrameState),
- /// Pointer to a side exit code
- SideExitPtr(CodePtr),
+ SideExit { pc: *const VALUE, stack: Vec<Opnd>, locals: Vec<Opnd> },
/// A label within the generated code
Label(Label),
}
@@ -292,7 +293,6 @@ impl Target
pub fn unwrap_code_ptr(&self) -> CodePtr {
match self {
Target::CodePtr(ptr) => *ptr,
- Target::SideExitPtr(ptr) => *ptr,
_ => unreachable!("trying to unwrap {:?} into code ptr", self)
}
}
@@ -539,11 +539,11 @@ impl Insn {
Insn::Jne(target) |
Insn::Jnz(target) |
Insn::Jo(target) |
- Insn::Jz(target) |
- Insn::Label(target) |
Insn::JoMul(target) |
+ Insn::Jz(target) |
Insn::Joz(_, target) |
Insn::Jonz(_, target) |
+ Insn::Label(target) |
Insn::LeaJumpTarget { target, .. } => {
Some(target)
}
@@ -697,7 +697,11 @@ impl Insn {
Insn::Jne(target) |
Insn::Jnz(target) |
Insn::Jo(target) |
+ Insn::JoMul(target) |
Insn::Jz(target) |
+ Insn::Joz(_, target) |
+ Insn::Jonz(_, target) |
+ Insn::Label(target) |
Insn::LeaJumpTarget { target, .. } => Some(target),
_ => None
}
@@ -731,6 +735,63 @@ impl<'a> Iterator for InsnOpndIterator<'a> {
fn next(&mut self) -> Option<Self::Item> {
match self.insn {
+ Insn::Jbe(target) |
+ Insn::Jb(target) |
+ Insn::Je(target) |
+ Insn::Jl(target) |
+ Insn::Jg(target) |
+ Insn::Jge(target) |
+ Insn::Jmp(target) |
+ Insn::Jne(target) |
+ Insn::Jnz(target) |
+ Insn::Jo(target) |
+ Insn::JoMul(target) |
+ Insn::Jz(target) |
+ Insn::Label(target) |
+ Insn::LeaJumpTarget { target, .. } => {
+ if let Target::SideExit { stack, locals, .. } = target {
+ let stack_idx = self.idx;
+ if stack_idx < stack.len() {
+ let opnd = &stack[stack_idx];
+ self.idx += 1;
+ return Some(opnd);
+ }
+
+ let local_idx = self.idx - stack.len();
+ if local_idx < locals.len() {
+ let opnd = &locals[local_idx];
+ self.idx += 1;
+ return Some(opnd);
+ }
+ }
+ None
+ }
+
+ Insn::Joz(opnd, target) |
+ Insn::Jonz(opnd, target) => {
+ if self.idx == 0 {
+ self.idx += 1;
+ return Some(opnd);
+ }
+
+ if let Target::SideExit { stack, locals, .. } = target {
+ let stack_idx = self.idx - 1;
+ if stack_idx < stack.len() {
+ let opnd = &stack[stack_idx];
+ self.idx += 1;
+ return Some(opnd);
+ }
+
+ let local_idx = stack_idx - stack.len();
+ if local_idx < locals.len() {
+ let opnd = &locals[local_idx];
+ self.idx += 1;
+ return Some(opnd);
+ }
+ }
+ None
+ }
+
Insn::BakeString(_) |
Insn::Breakpoint |
Insn::Comment(_) |
@@ -739,20 +800,6 @@ impl<'a> Iterator for InsnOpndIterator<'a> {
Insn::CPushAll |
Insn::FrameSetup |
Insn::FrameTeardown |
- Insn::Jbe(_) |
- Insn::Jb(_) |
- Insn::Je(_) |
- Insn::Jl(_) |
- Insn::Jg(_) |
- Insn::Jge(_) |
- Insn::Jmp(_) |
- Insn::Jne(_) |
- Insn::Jnz(_) |
- Insn::Jo(_) |
- Insn::JoMul(_) |
- Insn::Jz(_) |
- Insn::Label(_) |
- Insn::LeaJumpTarget { .. } |
Insn::PadInvalPatch |
Insn::PosMarker(_) => None,
@@ -764,8 +811,6 @@ impl<'a> Iterator for InsnOpndIterator<'a> {
Insn::LiveReg { opnd, .. } |
Insn::Load { opnd, .. } |
Insn::LoadSExt { opnd, .. } |
- Insn::Joz(opnd, _) |
- Insn::Jonz(opnd, _) |
Insn::Not { opnd, .. } => {
match self.idx {
0 => {
@@ -845,6 +890,63 @@ impl<'a> InsnOpndMutIterator<'a> {
pub(super) fn next(&mut self) -> Option<&mut Opnd> {
match self.insn {
+ Insn::Jbe(target) |
+ Insn::Jb(target) |
+ Insn::Je(target) |
+ Insn::Jl(target) |
+ Insn::Jg(target) |
+ Insn::Jge(target) |
+ Insn::Jmp(target) |
+ Insn::Jne(target) |
+ Insn::Jnz(target) |
+ Insn::Jo(target) |
+ Insn::JoMul(target) |
+ Insn::Jz(target) |
+ Insn::Label(target) |
+ Insn::LeaJumpTarget { target, .. } => {
+ if let Target::SideExit { stack, locals, .. } = target {
+ let stack_idx = self.idx;
+ if stack_idx < stack.len() {
+ let opnd = &mut stack[stack_idx];
+ self.idx += 1;
+ return Some(opnd);
+ }
+
+ let local_idx = self.idx - stack.len();
+ if local_idx < locals.len() {
+ let opnd = &mut locals[local_idx];
+ self.idx += 1;
+ return Some(opnd);
+ }
+ }
+ None
+ }
+
+ Insn::Joz(opnd, target) |
+ Insn::Jonz(opnd, target) => {
+ if self.idx == 0 {
+ self.idx += 1;
+ return Some(opnd);
+ }
+
+ if let Target::SideExit { stack, locals, .. } = target {
+ let stack_idx = self.idx - 1;
+ if stack_idx < stack.len() {
+ let opnd = &mut stack[stack_idx];
+ self.idx += 1;
+ return Some(opnd);
+ }
+
+ let local_idx = stack_idx - stack.len();
+ if local_idx < locals.len() {
+ let opnd = &mut locals[local_idx];
+ self.idx += 1;
+ return Some(opnd);
+ }
+ }
+ None
+ }
+
Insn::BakeString(_) |
Insn::Breakpoint |
Insn::Comment(_) |
@@ -853,20 +955,6 @@ impl<'a> InsnOpndMutIterator<'a> {
Insn::CPushAll |
Insn::FrameSetup |
Insn::FrameTeardown |
- Insn::Jbe(_) |
- Insn::Jb(_) |
- Insn::Je(_) |
- Insn::Jl(_) |
- Insn::Jg(_) |
- Insn::Jge(_) |
- Insn::Jmp(_) |
- Insn::Jne(_) |
- Insn::Jnz(_) |
- Insn::Jo(_) |
- Insn::JoMul(_) |
- Insn::Jz(_) |
- Insn::Label(_) |
- Insn::LeaJumpTarget { .. } |
Insn::PadInvalPatch |
Insn::PosMarker(_) => None,
@@ -878,8 +966,6 @@ impl<'a> InsnOpndMutIterator<'a> {
Insn::LiveReg { opnd, .. } |
Insn::Load { opnd, .. } |
Insn::LoadSExt { opnd, .. } |
- Insn::Joz(opnd, _) |
- Insn::Jonz(opnd, _) |
Insn::Not { opnd, .. } => {
match self.idx {
0 => {
@@ -1649,10 +1735,8 @@ impl Assembler
/// Compile the instructions down to machine code.
/// Can fail due to lack of code memory and inopportune code placement, among other reasons.
#[must_use]
- pub fn compile(mut self, cb: &mut CodeBlock) -> Option<(CodePtr, Vec<u32>)>
+ pub fn compile(self, cb: &mut CodeBlock) -> Option<(CodePtr, Vec<u32>)>
{
- self.compile_side_exits(cb)?;
-
#[cfg(feature = "disasm")]
let start_addr = cb.get_write_ptr();
let alloc_regs = Self::get_alloc_regs();
@@ -1669,47 +1753,74 @@ impl Assembler
/// Compile Target::SideExit and convert it into Target::CodePtr for all instructions
#[must_use]
- pub fn compile_side_exits(&mut self, cb: &mut CodeBlock) -> Option<()> {
- for insn in self.insns.iter_mut() {
- if let Some(target) = insn.target_mut() {
- if let Target::SideExit(state) = target {
- let side_exit_ptr = cb.get_write_ptr();
- let mut asm = Assembler::new();
- asm_comment!(asm, "side exit: {state}");
- asm.ccall(Self::rb_zjit_side_exit as *const u8, vec![]);
- asm.compile(cb)?;
- *target = Target::SideExitPtr(side_exit_ptr);
- }
+ pub fn compile_side_exits(&mut self) -> Option<()> {
+ let mut targets = HashMap::new();
+ for (idx, insn) in self.insns.iter().enumerate() {
+ if let Some(target @ Target::SideExit { .. }) = insn.target() {
+ targets.insert(idx, target.clone());
}
}
- Some(())
- }
- #[unsafe(no_mangle)]
- extern "C" fn rb_zjit_side_exit() {
- unimplemented!("side exits are not implemented yet");
- }
+ for (idx, target) in targets {
+ // Compile a side exit. Note that this is past the split pass and alloc_regs(),
+ // so you can't use a VReg or an instruction that needs to be split.
+ if let Target::SideExit { pc, stack, locals } = target {
+ let side_exit_label = self.new_label("side_exit".into());
+ self.write_label(side_exit_label.clone());
+
+ // Load an operand that cannot be used as a source of Insn::Store
+ fn split_store_source(asm: &mut Assembler, opnd: Opnd) -> Opnd {
+ if matches!(opnd, Opnd::Mem(_) | Opnd::Value(_)) ||
+ (cfg!(target_arch = "aarch64") && matches!(opnd, Opnd::UImm(_))) {
+ asm.load_into(Opnd::Reg(Assembler::SCRATCH_REG), opnd);
+ Opnd::Reg(Assembler::SCRATCH_REG)
+ } else {
+ opnd
+ }
+ }
- /*
- /// Compile with a limited number of registers. Used only for unit tests.
- #[cfg(test)]
- pub fn compile_with_num_regs(self, cb: &mut CodeBlock, num_regs: usize) -> (CodePtr, Vec<u32>)
- {
- let mut alloc_regs = Self::get_alloc_regs();
- let alloc_regs = alloc_regs.drain(0..num_regs).collect();
- self.compile_with_regs(cb, None, alloc_regs).unwrap()
- }
+ asm_comment!(self, "write stack slots: {stack:?}");
+ for (idx, &opnd) in stack.iter().enumerate() {
+ let opnd = split_store_source(self, opnd);
+ self.store(Opnd::mem(64, SP, idx as i32 * SIZEOF_VALUE_I32), opnd);
+ }
+
+ asm_comment!(self, "write locals: {locals:?}");
+ for (idx, &opnd) in locals.iter().enumerate() {
+ let opnd = split_store_source(self, opnd);
+ self.store(Opnd::mem(64, SP, (-(VM_ENV_DATA_SIZE as i32) - locals.len() as i32 + idx as i32) * SIZEOF_VALUE_I32), opnd);
+ }
+
+ asm_comment!(self, "save cfp->pc");
+ self.load_into(Opnd::Reg(Assembler::SCRATCH_REG), Opnd::const_ptr(pc as *const u8));
+ self.store(Opnd::mem(64, CFP, RUBY_OFFSET_CFP_PC), Opnd::Reg(Assembler::SCRATCH_REG));
+
+ asm_comment!(self, "save cfp->sp");
+ self.lea_into(Opnd::Reg(Assembler::SCRATCH_REG), Opnd::mem(64, SP, stack.len() as i32 * SIZEOF_VALUE_I32));
+ let cfp_sp = Opnd::mem(64, CFP, RUBY_OFFSET_CFP_SP);
+ self.store(cfp_sp, Opnd::Reg(Assembler::SCRATCH_REG));
- /// Return true if the next ccall() is expected to be leaf.
- pub fn get_leaf_ccall(&mut self) -> bool {
- self.leaf_ccall
+ asm_comment!(self, "rewind caller frames");
+ self.mov(C_ARG_OPNDS[0], Assembler::return_addr_opnd());
+ self.ccall(Self::rewind_caller_frames as *const u8, vec![]);
+
+ asm_comment!(self, "exit to the interpreter");
+ self.frame_teardown();
+ self.mov(C_RET_OPND, Opnd::UImm(Qundef.as_u64()));
+ self.cret(C_RET_OPND);
+
+ *self.insns[idx].target_mut().unwrap() = side_exit_label;
+ }
+ }
+ Some(())
}
- /// Assert that the next ccall() is going to be leaf.
- pub fn expect_leaf_ccall(&mut self) {
- self.leaf_ccall = true;
+ #[unsafe(no_mangle)]
+ extern "C" fn rewind_caller_frames(addr: *const u8) {
+ if ZJITState::is_iseq_return_addr(addr) {
+ unimplemented!("Can't side-exit from JIT-JIT call: rewind_caller_frames is not implemented yet");
+ }
}
- */
}
impl fmt::Debug for Assembler {
@@ -1970,6 +2081,10 @@ impl Assembler {
out
}
+ pub fn lea_into(&mut self, out: Opnd, opnd: Opnd) {
+ self.push_insn(Insn::Lea { opnd, out });
+ }
+
#[must_use]
pub fn lea_jump_target(&mut self, target: Target) -> Opnd {
let out = self.new_vreg(Opnd::DEFAULT_NUM_BITS);
diff --git a/zjit/src/backend/x86_64/mod.rs b/zjit/src/backend/x86_64/mod.rs
index f11b07c1b7..cf62cdd7f5 100644
--- a/zjit/src/backend/x86_64/mod.rs
+++ b/zjit/src/backend/x86_64/mod.rs
@@ -109,6 +109,11 @@ impl Assembler
vec![RAX_REG, RCX_REG, RDX_REG, RSI_REG, RDI_REG, R8_REG, R9_REG, R10_REG, R11_REG]
}
+ /// Get the address that the current frame returns to
+ pub fn return_addr_opnd() -> Opnd {
+ Opnd::mem(64, Opnd::Reg(RSP_REG), 0)
+ }
+
// These are the callee-saved registers in the x86-64 SysV ABI
// RBX, RSP, RBP, and R12–R15
@@ -665,7 +670,7 @@ impl Assembler
// Conditional jump to a label
Insn::Jmp(target) => {
match *target {
- Target::CodePtr(code_ptr) | Target::SideExitPtr(code_ptr) => jmp_ptr(cb, code_ptr),
+ Target::CodePtr(code_ptr) => jmp_ptr(cb, code_ptr),
Target::Label(label) => jmp_label(cb, label),
Target::SideExit { .. } => unreachable!("Target::SideExit should have been compiled by compile_side_exits"),
}
@@ -673,7 +678,7 @@ impl Assembler
Insn::Je(target) => {
match *target {
- Target::CodePtr(code_ptr) | Target::SideExitPtr(code_ptr) => je_ptr(cb, code_ptr),
+ Target::CodePtr(code_ptr) => je_ptr(cb, code_ptr),
Target::Label(label) => je_label(cb, label),
Target::SideExit { .. } => unreachable!("Target::SideExit should have been compiled by compile_side_exits"),
}
@@ -681,7 +686,7 @@ impl Assembler
Insn::Jne(target) => {
match *target {
- Target::CodePtr(code_ptr) | Target::SideExitPtr(code_ptr) => jne_ptr(cb, code_ptr),
+ Target::CodePtr(code_ptr) => jne_ptr(cb, code_ptr),
Target::Label(label) => jne_label(cb, label),
Target::SideExit { .. } => unreachable!("Target::SideExit should have been compiled by compile_side_exits"),
}
@@ -689,7 +694,7 @@ impl Assembler
Insn::Jl(target) => {
match *target {
- Target::CodePtr(code_ptr) | Target::SideExitPtr(code_ptr) => jl_ptr(cb, code_ptr),
+ Target::CodePtr(code_ptr) => jl_ptr(cb, code_ptr),
Target::Label(label) => jl_label(cb, label),
Target::SideExit { .. } => unreachable!("Target::SideExit should have been compiled by compile_side_exits"),
}
@@ -697,7 +702,7 @@ impl Assembler
Insn::Jg(target) => {
match *target {
- Target::CodePtr(code_ptr) | Target::SideExitPtr(code_ptr) => jg_ptr(cb, code_ptr),
+ Target::CodePtr(code_ptr) => jg_ptr(cb, code_ptr),
Target::Label(label) => jg_label(cb, label),
Target::SideExit { .. } => unreachable!("Target::SideExit should have been compiled by compile_side_exits"),
}
@@ -705,7 +710,7 @@ impl Assembler
Insn::Jge(target) => {
match *target {
- Target::CodePtr(code_ptr) | Target::SideExitPtr(code_ptr) => jge_ptr(cb, code_ptr),
+ Target::CodePtr(code_ptr) => jge_ptr(cb, code_ptr),
Target::Label(label) => jge_label(cb, label),
Target::SideExit { .. } => unreachable!("Target::SideExit should have been compiled by compile_side_exits"),
}
@@ -713,7 +718,7 @@ impl Assembler
Insn::Jbe(target) => {
match *target {
- Target::CodePtr(code_ptr) | Target::SideExitPtr(code_ptr) => jbe_ptr(cb, code_ptr),
+ Target::CodePtr(code_ptr) => jbe_ptr(cb, code_ptr),
Target::Label(label) => jbe_label(cb, label),
Target::SideExit { .. } => unreachable!("Target::SideExit should have been compiled by compile_side_exits"),
}
@@ -721,7 +726,7 @@ impl Assembler
Insn::Jb(target) => {
match *target {
- Target::CodePtr(code_ptr) | Target::SideExitPtr(code_ptr) => jb_ptr(cb, code_ptr),
+ Target::CodePtr(code_ptr) => jb_ptr(cb, code_ptr),
Target::Label(label) => jb_label(cb, label),
Target::SideExit { .. } => unreachable!("Target::SideExit should have been compiled by compile_side_exits"),
}
@@ -729,7 +734,7 @@ impl Assembler
Insn::Jz(target) => {
match *target {
- Target::CodePtr(code_ptr) | Target::SideExitPtr(code_ptr) => jz_ptr(cb, code_ptr),
+ Target::CodePtr(code_ptr) => jz_ptr(cb, code_ptr),
Target::Label(label) => jz_label(cb, label),
Target::SideExit { .. } => unreachable!("Target::SideExit should have been compiled by compile_side_exits"),
}
@@ -737,7 +742,7 @@ impl Assembler
Insn::Jnz(target) => {
match *target {
- Target::CodePtr(code_ptr) | Target::SideExitPtr(code_ptr) => jnz_ptr(cb, code_ptr),
+ Target::CodePtr(code_ptr) => jnz_ptr(cb, code_ptr),
Target::Label(label) => jnz_label(cb, label),
Target::SideExit { .. } => unreachable!("Target::SideExit should have been compiled by compile_side_exits"),
}
@@ -746,7 +751,7 @@ impl Assembler
Insn::Jo(target) |
Insn::JoMul(target) => {
match *target {
- Target::CodePtr(code_ptr) | Target::SideExitPtr(code_ptr) => jo_ptr(cb, code_ptr),
+ Target::CodePtr(code_ptr) => jo_ptr(cb, code_ptr),
Target::Label(label) => jo_label(cb, label),
Target::SideExit { .. } => unreachable!("Target::SideExit should have been compiled by compile_side_exits"),
}
@@ -831,6 +836,7 @@ impl Assembler
pub fn compile_with_regs(self, cb: &mut CodeBlock, regs: Vec<Reg>) -> Option<(CodePtr, Vec<u32>)> {
let asm = self.x86_split();
let mut asm = asm.alloc_regs(regs);
+ asm.compile_side_exits()?;
// Create label instances in the code block
for (idx, name) in asm.label_names.iter().enumerate() {
diff --git a/zjit/src/codegen.rs b/zjit/src/codegen.rs
index d5202486f1..c1fad915f0 100644
--- a/zjit/src/codegen.rs
+++ b/zjit/src/codegen.rs
@@ -260,9 +260,9 @@ fn gen_insn(cb: &mut CodeBlock, jit: &mut JITState, asm: &mut Assembler, functio
Insn::SendWithoutBlock { call_info, cd, state, self_val, args, .. } => gen_send_without_block(jit, asm, call_info, *cd, &function.frame_state(*state), self_val, args)?,
Insn::SendWithoutBlockDirect { iseq, self_val, args, .. } => gen_send_without_block_direct(cb, jit, asm, *iseq, opnd!(self_val), args)?,
Insn::Return { val } => return Some(gen_return(asm, opnd!(val))?),
- Insn::FixnumAdd { left, right, state } => gen_fixnum_add(asm, opnd!(left), opnd!(right), &function.frame_state(*state))?,
- Insn::FixnumSub { left, right, state } => gen_fixnum_sub(asm, opnd!(left), opnd!(right), &function.frame_state(*state))?,
- Insn::FixnumMult { left, right, state } => gen_fixnum_mult(asm, opnd!(left), opnd!(right), &function.frame_state(*state))?,
+ Insn::FixnumAdd { left, right, state } => gen_fixnum_add(jit, asm, opnd!(left), opnd!(right), &function.frame_state(*state))?,
+ Insn::FixnumSub { left, right, state } => gen_fixnum_sub(jit, asm, opnd!(left), opnd!(right), &function.frame_state(*state))?,
+ Insn::FixnumMult { left, right, state } => gen_fixnum_mult(jit, asm, opnd!(left), opnd!(right), &function.frame_state(*state))?,
Insn::FixnumEq { left, right } => gen_fixnum_eq(asm, opnd!(left), opnd!(right))?,
Insn::FixnumNeq { left, right } => gen_fixnum_neq(asm, opnd!(left), opnd!(right))?,
Insn::FixnumLt { left, right } => gen_fixnum_lt(asm, opnd!(left), opnd!(right))?,
@@ -270,8 +270,8 @@ fn gen_insn(cb: &mut CodeBlock, jit: &mut JITState, asm: &mut Assembler, functio
Insn::FixnumGt { left, right } => gen_fixnum_gt(asm, opnd!(left), opnd!(right))?,
Insn::FixnumGe { left, right } => gen_fixnum_ge(asm, opnd!(left), opnd!(right))?,
Insn::Test { val } => gen_test(asm, opnd!(val))?,
- Insn::GuardType { val, guard_type, state } => gen_guard_type(asm, opnd!(val), *guard_type, &function.frame_state(*state))?,
- Insn::GuardBitEquals { val, expected, state } => gen_guard_bit_equals(asm, opnd!(val), *expected, &function.frame_state(*state))?,
+ Insn::GuardType { val, guard_type, state } => gen_guard_type(jit, asm, opnd!(val), *guard_type, &function.frame_state(*state))?,
+ Insn::GuardBitEquals { val, expected, state } => gen_guard_bit_equals(jit, asm, opnd!(val), *expected, &function.frame_state(*state))?,
Insn::PatchPoint(_) => return Some(()), // For now, rb_zjit_bop_redefined() panics. TODO: leave a patch point and fix rb_zjit_bop_redefined()
Insn::CCall { cfun, args, name: _, return_type: _, elidable: _ } => gen_ccall(jit, asm, *cfun, args)?,
_ => {
@@ -569,27 +569,27 @@ fn gen_return(asm: &mut Assembler, val: lir::Opnd) -> Option<()> {
}
/// Compile Fixnum + Fixnum
-fn gen_fixnum_add(asm: &mut Assembler, left: lir::Opnd, right: lir::Opnd, state: &FrameState) -> Option<lir::Opnd> {
+fn gen_fixnum_add(jit: &mut JITState, asm: &mut Assembler, left: lir::Opnd, right: lir::Opnd, state: &FrameState) -> Option<lir::Opnd> {
// Add left + right and test for overflow
let left_untag = asm.sub(left, Opnd::Imm(1));
let out_val = asm.add(left_untag, right);
- asm.jo(Target::SideExit(state.clone()));
+ asm.jo(side_exit(jit, state)?);
Some(out_val)
}
/// Compile Fixnum - Fixnum
-fn gen_fixnum_sub(asm: &mut Assembler, left: lir::Opnd, right: lir::Opnd, state: &FrameState) -> Option<lir::Opnd> {
+fn gen_fixnum_sub(jit: &mut JITState, asm: &mut Assembler, left: lir::Opnd, right: lir::Opnd, state: &FrameState) -> Option<lir::Opnd> {
// Subtract left - right and test for overflow
let val_untag = asm.sub(left, right);
- asm.jo(Target::SideExit(state.clone()));
+ asm.jo(side_exit(jit, state)?);
let out_val = asm.add(val_untag, Opnd::Imm(1));
Some(out_val)
}
/// Compile Fixnum * Fixnum
-fn gen_fixnum_mult(asm: &mut Assembler, left: lir::Opnd, right: lir::Opnd, state: &FrameState) -> Option<lir::Opnd> {
+fn gen_fixnum_mult(jit: &mut JITState, asm: &mut Assembler, left: lir::Opnd, right: lir::Opnd, state: &FrameState) -> Option<lir::Opnd> {
// Do some bitwise gymnastics to handle tag bits
// x * y is translated to (x >> 1) * (y - 1) + 1
let left_untag = asm.rshift(left, Opnd::UImm(1));
@@ -597,7 +597,7 @@ fn gen_fixnum_mult(asm: &mut Assembler, left: lir::Opnd, right: lir::Opnd, state
let out_val = asm.mul(left_untag, right_untag);
// Test for overflow
- asm.jo_mul(Target::SideExit(state.clone()));
+ asm.jo_mul(side_exit(jit, state)?);
let out_val = asm.add(out_val, Opnd::UImm(1));
Some(out_val)
@@ -651,11 +651,11 @@ fn gen_test(asm: &mut Assembler, val: lir::Opnd) -> Option<lir::Opnd> {
}
/// Compile a type check with a side exit
-fn gen_guard_type(asm: &mut Assembler, val: lir::Opnd, guard_type: Type, state: &FrameState) -> Option<lir::Opnd> {
+fn gen_guard_type(jit: &mut JITState, asm: &mut Assembler, val: lir::Opnd, guard_type: Type, state: &FrameState) -> Option<lir::Opnd> {
if guard_type.is_subtype(Fixnum) {
// Check if opnd is Fixnum
asm.test(val, Opnd::UImm(RUBY_FIXNUM_FLAG as u64));
- asm.jz(Target::SideExit(state.clone()));
+ asm.jz(side_exit(jit, state)?);
} else {
unimplemented!("unsupported type: {guard_type}");
}
@@ -663,9 +663,9 @@ fn gen_guard_type(asm: &mut Assembler, val: lir::Opnd, guard_type: Type, state:
}
/// Compile an identity check with a side exit
-fn gen_guard_bit_equals(asm: &mut Assembler, val: lir::Opnd, expected: VALUE, state: &FrameState) -> Option<lir::Opnd> {
+fn gen_guard_bit_equals(jit: &mut JITState, asm: &mut Assembler, val: lir::Opnd, expected: VALUE, state: &FrameState) -> Option<lir::Opnd> {
asm.cmp(val, Opnd::UImm(expected.into()));
- asm.jnz(Target::SideExit(state.clone()));
+ asm.jnz(side_exit(jit, state)?);
Some(val)
}
@@ -731,6 +731,26 @@ fn compile_iseq(iseq: IseqPtr) -> Option<Function> {
Some(function)
}
+/// Build a Target::SideExit out of a FrameState
+fn side_exit(jit: &mut JITState, state: &FrameState) -> Option<Target> {
+ let mut stack = Vec::new();
+ for &insn_id in state.stack() {
+ stack.push(jit.get_opnd(insn_id)?);
+ }
+
+ let mut locals = Vec::new();
+ for &insn_id in state.locals() {
+ locals.push(jit.get_opnd(insn_id)?);
+ }
+
+ let target = Target::SideExit {
+ pc: state.pc,
+ stack,
+ locals,
+ };
+ Some(target)
+}
+
impl Assembler {
/// Make a C call while marking the start and end positions of it
fn ccall_with_branch(&mut self, fptr: *const u8, opnds: Vec<Opnd>, branch: &Rc<Branch>) -> Opnd {
@@ -744,8 +764,9 @@ impl Assembler {
move |code_ptr, _| {
start_branch.start_addr.set(Some(code_ptr));
},
- move |code_ptr, _| {
+ move |code_ptr, cb| {
end_branch.end_addr.set(Some(code_ptr));
+ ZJITState::add_iseq_return_addr(code_ptr.raw_ptr(cb));
},
)
}
diff --git a/zjit/src/hir.rs b/zjit/src/hir.rs
index 1d163e5741..14a94cdff7 100644
--- a/zjit/src/hir.rs
+++ b/zjit/src/hir.rs
@@ -1719,6 +1719,11 @@ impl FrameState {
self.stack.iter()
}
+ /// Iterate over all local variables
+ pub fn locals(&self) -> Iter<InsnId> {
+ self.locals.iter()
+ }
+
/// Push a stack operand
fn stack_push(&mut self, opnd: InsnId) {
self.stack.push(opnd);
diff --git a/zjit/src/state.rs b/zjit/src/state.rs
index e846ee6f8d..e8c389a5f8 100644
--- a/zjit/src/state.rs
+++ b/zjit/src/state.rs
@@ -1,3 +1,5 @@
+use std::collections::HashSet;
+
use crate::cruby::{self, rb_bug_panic_hook, EcPtr, Qnil, VALUE};
use crate::cruby_methods;
use crate::invariants::Invariants;
@@ -29,6 +31,9 @@ pub struct ZJITState {
/// Properties of core library methods
method_annotations: cruby_methods::Annotations,
+
+ /// The address of the instruction that JIT-to-JIT calls return to
+ iseq_return_addrs: HashSet<*const u8>,
}
/// Private singleton instance of the codegen globals
@@ -82,7 +87,8 @@ impl ZJITState {
options,
invariants: Invariants::default(),
assert_compiles: false,
- method_annotations: cruby_methods::init()
+ method_annotations: cruby_methods::init(),
+ iseq_return_addrs: HashSet::new(),
};
unsafe { ZJIT_STATE = Some(zjit_state); }
}
@@ -126,6 +132,16 @@ impl ZJITState {
let instance = ZJITState::get_instance();
instance.assert_compiles = true;
}
+
+ /// Record an address that a JIT-to-JIT call returns to
+ pub fn add_iseq_return_addr(addr: *const u8) {
+ ZJITState::get_instance().iseq_return_addrs.insert(addr);
+ }
+
+ /// Returns true if a JIT-to-JIT call returns to a given address
+ pub fn is_iseq_return_addr(addr: *const u8) -> bool {
+ ZJITState::get_instance().iseq_return_addrs.contains(&addr)
+ }
}
/// Initialize ZJIT, given options allocated by rb_zjit_init_options()