author     Max Bernstein <[email protected]>    2025-03-27 16:38:31 -0400
committer  Takashi Kokubun <[email protected]>  2025-04-18 21:53:01 +0900
commit     308cd59bf8d9a3d6141fae5931bad397020e7bc8 (patch)
tree       c68faaf02a965e5b2b375cbd7379b87fbc630c1b
parent     cfc9234ccdb457934f4daeef599e303844869fc3 (diff)
Rewrite SendWithoutBlock to SendWithoutBlockDirect when possible
In calls to top-level functions, we assume that call targets will not get rewritten, so we can insert a PatchPoint and do the lookup at compile-time.
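For illustration, a sketch drawn from the opt tests added in this patch (pointer values are placeholders normalized by PtrPrintMap for snapshot testing): a top-level call such as

    def foo
    end
    def test
      foo
    end

initially lowers to a plain dynamic send,

    v1:BasicObject = PutSelf
    v3:BasicObject = SendWithoutBlock v1, :foo
    Return v3

and, when the compile-time lookup succeeds, is rewritten into a patch point on the method, a guard that self is the top-level object, and a direct call to the known ISEQ:

    v1:BasicObject = PutSelf
    PatchPoint MethodRedefined(Object@0x1000, foo@0x1008)
    v7:BasicObject[VALUE(0x1010)] = GuardBitEquals v1, VALUE(0x1010)
    v8:BasicObject = SendWithoutBlockDirect v7, :foo (0x1018)
    Return v8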
Notes: Merged: https://github.com/ruby/ruby/pull/13131
-rw-r--r--  zjit/src/codegen.rs        11
-rw-r--r--  zjit/src/hir.rs           262
-rw-r--r--  zjit/src/hir_type/mod.rs    8
3 files changed, 269 insertions, 12 deletions
diff --git a/zjit/src/codegen.rs b/zjit/src/codegen.rs
index 316f6d9b80..7fbe4f55dd 100644
--- a/zjit/src/codegen.rs
+++ b/zjit/src/codegen.rs
@@ -187,7 +187,8 @@ fn gen_insn(jit: &mut JITState, asm: &mut Assembler, function: &Function, insn_i
Insn::Jump(branch) => return gen_jump(jit, asm, branch),
Insn::IfTrue { val, target } => return gen_if_true(jit, asm, opnd!(val), target),
Insn::IfFalse { val, target } => return gen_if_false(jit, asm, opnd!(val), target),
- Insn::SendWithoutBlock { call_info, cd, state, .. } => gen_send_without_block(jit, asm, call_info, *cd, function.frame_state(*state))?,
+ Insn::SendWithoutBlock { call_info, cd, state, .. } | Insn::SendWithoutBlockDirect { call_info, cd, state, .. }
+ => gen_send_without_block(jit, asm, call_info, *cd, function.frame_state(*state))?,
Insn::Return { val } => return Some(gen_return(asm, opnd!(val))?),
Insn::FixnumAdd { left, right, state } => gen_fixnum_add(asm, opnd!(left), opnd!(right), function.frame_state(*state))?,
Insn::FixnumSub { left, right, state } => gen_fixnum_sub(asm, opnd!(left), opnd!(right), function.frame_state(*state))?,
@@ -200,6 +201,7 @@ fn gen_insn(jit: &mut JITState, asm: &mut Assembler, function: &Function, insn_i
Insn::FixnumGe { left, right } => gen_fixnum_ge(asm, opnd!(left), opnd!(right))?,
Insn::Test { val } => gen_test(asm, opnd!(val))?,
Insn::GuardType { val, guard_type, state } => gen_guard_type(asm, opnd!(val), *guard_type, function.frame_state(*state))?,
+ Insn::GuardBitEquals { val, expected, state } => gen_guard_bit_equals(asm, opnd!(val), *expected, function.frame_state(*state))?,
Insn::PatchPoint(_) => return Some(()), // For now, rb_zjit_bop_redefined() panics. TODO: leave a patch point and fix rb_zjit_bop_redefined()
_ => {
debug!("ZJIT: gen_function: unexpected insn {:?}", insn);
@@ -497,6 +499,13 @@ fn gen_guard_type(asm: &mut Assembler, val: lir::Opnd, guard_type: Type, state:
Some(val)
}
+/// Compile an identity check with a side exit
+fn gen_guard_bit_equals(asm: &mut Assembler, val: lir::Opnd, expected: VALUE, state: &FrameState) -> Option<lir::Opnd> {
+ asm.cmp(val, Opnd::UImm(expected.into()));
+ asm.jnz(Target::SideExit(state.clone()));
+ Some(val)
+}
+
/// Save the incremented PC on the CFP.
/// This is necessary when callees can raise or allocate.
fn gen_save_pc(asm: &mut Assembler, state: &FrameState) {
diff --git a/zjit/src/hir.rs b/zjit/src/hir.rs
index 98c492daa3..7b66d05bed 100644
--- a/zjit/src/hir.rs
+++ b/zjit/src/hir.rs
@@ -5,7 +5,7 @@ use crate::{
cruby::*, options::get_option, hir_type::types::Fixnum, options::DumpHIR, profile::get_or_create_iseq_payload
};
use std::{cell::RefCell, collections::{HashMap, HashSet}, ffi::c_void, mem::{align_of, size_of}, ptr, slice::Iter};
-use crate::hir_type::{Type, types};
+use crate::hir_type::{Type, types, get_class_name};
#[derive(Copy, Clone, Ord, PartialOrd, Eq, PartialEq, Hash, Debug)]
pub struct InsnId(pub usize);
@@ -95,7 +95,7 @@ pub struct CallInfo {
}
/// Invalidation reasons
-#[derive(Debug, Clone)]
+#[derive(Debug, Clone, Copy)]
pub enum Invariant {
/// Basic operation is redefined
BOPRedefined {
@@ -104,19 +104,37 @@ pub enum Invariant {
/// BOP_{bop}
bop: ruby_basic_operators,
},
+ MethodRedefined {
+ /// The class object whose method we want to assume unchanged
+ klass: VALUE,
+ /// The method ID of the method we want to assume unchanged
+ method: ID,
+ }
+}
+
+impl Invariant {
+ pub fn print(self, ptr_map: &PtrPrintMap) -> InvariantPrinter {
+ InvariantPrinter { inner: self, ptr_map }
+ }
+}
+
+/// Print adaptor for [`Invariant`]. See [`PtrPrintMap`].
+pub struct InvariantPrinter<'a> {
+ inner: Invariant,
+ ptr_map: &'a PtrPrintMap,
}
-impl std::fmt::Display for Invariant {
+impl<'a> std::fmt::Display for InvariantPrinter<'a> {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
- match self {
- Self::BOPRedefined { klass, bop } => {
+ match self.inner {
+ Invariant::BOPRedefined { klass, bop } => {
write!(f, "BOPRedefined(")?;
- match *klass {
+ match klass {
INTEGER_REDEFINED_OP_FLAG => write!(f, "INTEGER_REDEFINED_OP_FLAG")?,
_ => write!(f, "{klass}")?,
}
write!(f, ", ")?;
- match *bop {
+ match bop {
BOP_PLUS => write!(f, "BOP_PLUS")?,
BOP_MINUS => write!(f, "BOP_MINUS")?,
BOP_MULT => write!(f, "BOP_MULT")?,
@@ -132,6 +150,16 @@ impl std::fmt::Display for Invariant {
}
write!(f, ")")
}
+ Invariant::MethodRedefined { klass, method } => {
+ let class_name = get_class_name(Some(klass));
+ let method_name = unsafe {
+ cstr_to_rust_string(rb_id2name(method)).unwrap_or_else(|| "<unknown>".to_owned())
+ };
+ write!(f, "MethodRedefined({class_name}@{:p}, {method_name}@{:p})",
+ self.ptr_map.map_ptr(klass.as_ptr::<VALUE>()),
+ self.ptr_map.map_id(method)
+ )
+ }
}
}
}
@@ -238,6 +266,11 @@ impl PtrPrintMap {
}
}
}
+
+ /// Map a Ruby ID (index into intern table) for printing
+ fn map_id(&self, id: u64) -> *const c_void {
+ self.map_ptr(id as *const c_void)
+ }
}
#[derive(Debug, Clone)]
@@ -285,6 +318,7 @@ pub enum Insn {
// Ignoring keyword arguments etc for now
SendWithoutBlock { self_val: InsnId, call_info: CallInfo, cd: *const rb_call_data, args: Vec<InsnId>, state: FrameStateId },
Send { self_val: InsnId, call_info: CallInfo, cd: *const rb_call_data, blockiseq: IseqPtr, args: Vec<InsnId>, state: FrameStateId },
+ SendWithoutBlockDirect { self_val: InsnId, call_info: CallInfo, cd: *const rb_call_data, iseq: IseqPtr, args: Vec<InsnId>, state: FrameStateId },
// Control flow instructions
Return { val: InsnId },
@@ -302,8 +336,10 @@ pub enum Insn {
FixnumGt { left: InsnId, right: InsnId },
FixnumGe { left: InsnId, right: InsnId },
- /// Side-exist if val doesn't have the expected type.
+ /// Side-exit if val doesn't have the expected type.
GuardType { val: InsnId, guard_type: Type, state: FrameStateId },
+ /// Side-exit if val is not the expected VALUE.
+ GuardBitEquals { val: InsnId, expected: VALUE, state: FrameStateId },
/// Generate no code (or padding if necessary) and insert a patch point
/// that can be rewritten to a side exit when the Invariant is broken.
@@ -360,6 +396,13 @@ impl<'a> std::fmt::Display for InsnPrinter<'a> {
}
Ok(())
}
+ Insn::SendWithoutBlockDirect { self_val, call_info, iseq, args, .. } => {
+ write!(f, "SendWithoutBlockDirect {self_val}, :{} ({:?})", call_info.method_name, self.ptr_map.map_ptr(iseq))?;
+ for arg in args {
+ write!(f, ", {arg}")?;
+ }
+ Ok(())
+ }
Insn::Send { self_val, call_info, args, blockiseq, .. } => {
// For tests, we want to check HIR snippets textually. Addresses change
// between runs, making tests fail. Instead, pick an arbitrary hex value to
@@ -383,7 +426,8 @@ impl<'a> std::fmt::Display for InsnPrinter<'a> {
Insn::FixnumGt { left, right, .. } => { write!(f, "FixnumGt {left}, {right}") },
Insn::FixnumGe { left, right, .. } => { write!(f, "FixnumGe {left}, {right}") },
Insn::GuardType { val, guard_type, .. } => { write!(f, "GuardType {val}, {guard_type}") },
- Insn::PatchPoint(invariant) => { write!(f, "PatchPoint {invariant:}") },
+ Insn::GuardBitEquals { val, expected, .. } => { write!(f, "GuardBitEquals {val}, {}", expected.print(&self.ptr_map)) },
+ Insn::PatchPoint(invariant) => { write!(f, "PatchPoint {}", invariant.print(&self.ptr_map)) },
insn => { write!(f, "{insn:?}") }
}
}
@@ -573,6 +617,12 @@ impl Function {
id
}
+ /// Add an existing instruction (by id) to the end of an SSA block
+ fn push_insn_id(&mut self, block: BlockId, insn_id: InsnId) -> InsnId {
+ self.blocks[block.0].insns.push(insn_id);
+ insn_id
+ }
+
/// Return the number of instructions
pub fn num_insns(&self) -> usize {
self.insns.len()
@@ -641,6 +691,7 @@ impl Function {
IfTrue { val, target } => IfTrue { val: find!(*val), target: target.clone() },
IfFalse { val, target } => IfFalse { val: find!(*val), target: target.clone() },
GuardType { val, guard_type, state } => GuardType { val: find!(*val), guard_type: *guard_type, state: *state },
+ GuardBitEquals { val, expected, state } => GuardBitEquals { val: find!(*val), expected: *expected, state: *state },
FixnumAdd { left, right, state } => FixnumAdd { left: find!(*left), right: find!(*right), state: *state },
FixnumSub { left, right, state } => FixnumSub { left: find!(*left), right: find!(*right), state: *state },
FixnumMult { left, right, state } => FixnumMult { left: find!(*left), right: find!(*right), state: *state },
@@ -659,6 +710,14 @@ impl Function {
args: args.iter().map(|arg| find!(*arg)).collect(),
state: *state,
},
+ SendWithoutBlockDirect { self_val, call_info, cd, iseq, args, state } => SendWithoutBlockDirect {
+ self_val: find!(*self_val),
+ call_info: call_info.clone(),
+ cd: cd.clone(),
+ iseq: *iseq,
+ args: args.iter().map(|arg| find!(*arg)).collect(),
+ state: *state,
+ },
Send { self_val, call_info, cd, blockiseq, args, state } => Send {
self_val: find!(*self_val),
call_info: call_info.clone(),
@@ -719,6 +778,7 @@ impl Function {
Insn::ArrayDup { .. } => types::ArrayExact,
Insn::CCall { .. } => types::Any,
Insn::GuardType { val, guard_type, .. } => self.type_of(*val).intersection(*guard_type),
+ Insn::GuardBitEquals { val, expected, .. } => self.type_of(*val).intersection(Type::from_value(*expected)),
Insn::FixnumAdd { .. } => types::Fixnum,
Insn::FixnumSub { .. } => types::Fixnum,
Insn::FixnumMult { .. } => types::Fixnum,
@@ -731,6 +791,7 @@ impl Function {
Insn::FixnumGt { .. } => types::BoolExact,
Insn::FixnumGe { .. } => types::BoolExact,
Insn::SendWithoutBlock { .. } => types::BasicObject,
+ Insn::SendWithoutBlockDirect { .. } => types::BasicObject,
Insn::Send { .. } => types::BasicObject,
Insn::PutSelf => types::BasicObject,
Insn::Defined { .. } => types::BasicObject,
@@ -801,6 +862,50 @@ impl Function {
}
}
+ /// Rewrite SendWithoutBlock opcodes into SendWithoutBlockDirect opcodes if we know the target
+ /// ISEQ statically. This removes run-time method lookups and opens the door for inlining.
+ fn optimize_direct_sends(&mut self) {
+ let payload = get_or_create_iseq_payload(self.iseq);
+ for block in self.rpo() {
+ let old_insns = std::mem::take(&mut self.blocks[block.0].insns);
+ assert!(self.blocks[block.0].insns.is_empty());
+ for insn_id in old_insns {
+ match self.find(insn_id) {
+ Insn::SendWithoutBlock { self_val, call_info, cd, args, state } => {
+ let frame_state = &self.frame_states[state.0];
+ let self_type = match payload.get_operand_types(frame_state.insn_idx) {
+ Some([self_type, ..]) if self_type.is_top_self() => self_type,
+ _ => { self.push_insn_id(block, insn_id); continue; }
+ };
+ let top_self = self_type.ruby_object().unwrap();
+ let top_self_klass = top_self.class_of();
+ let ci = unsafe { get_call_data_ci(cd) }; // info about the call site
+ let mid = unsafe { vm_ci_mid(ci) };
+ // Do method lookup
+ let mut cme = unsafe { rb_callable_method_entry(top_self_klass, mid) };
+ if cme.is_null() {
+ self.push_insn_id(block, insn_id); continue;
+ }
+ // Load an overloaded cme if applicable. See vm_search_cc().
+ // It allows you to use a faster ISEQ if possible.
+ cme = unsafe { rb_check_overloaded_cme(cme, ci) };
+ let def_type = unsafe { get_cme_def_type(cme) };
+ if def_type != VM_METHOD_TYPE_ISEQ {
+ self.push_insn_id(block, insn_id); continue;
+ }
+ self.push_insn(block, Insn::PatchPoint(Invariant::MethodRedefined { klass: top_self_klass, method: mid }));
+ let iseq = unsafe { get_def_iseq_ptr((*cme).def) };
+ let self_val = self.push_insn(block, Insn::GuardBitEquals { val: self_val, expected: top_self, state });
+ let send_direct = self.push_insn(block, Insn::SendWithoutBlockDirect { self_val, call_info, cd, iseq, args, state });
+ self.make_equal_to(insn_id, send_direct);
+ }
+ _ => { self.push_insn_id(block, insn_id); }
+ }
+ }
+ }
+ self.infer_types();
+ }
+
/// Use type information left by `infer_types` to fold away operations that can be evaluated at compile-time.
///
/// It can fold fixnum math, truthiness tests, and branches with constant conditionals.
@@ -932,6 +1037,7 @@ impl Function {
/// Run all the optimization passes we have.
pub fn optimize(&mut self) {
// Function is assumed to have types inferred already
+ self.optimize_direct_sends();
self.fold_constants();
// Dump HIR after optimization
@@ -987,6 +1093,7 @@ impl<'a> std::fmt::Display for FunctionPrinter<'a> {
#[derive(Debug, Clone)]
pub struct FrameState {
iseq: IseqPtr,
+ insn_idx: usize,
// Ruby bytecode instruction pointer
pub pc: *const VALUE,
@@ -1032,7 +1139,7 @@ fn ep_offset_to_local_idx(iseq: IseqPtr, ep_offset: u32) -> usize {
impl FrameState {
fn new(iseq: IseqPtr) -> FrameState {
- FrameState { iseq, pc: 0 as *const VALUE, stack: vec![], locals: vec![] }
+ FrameState { iseq, pc: 0 as *const VALUE, insn_idx: 0, stack: vec![], locals: vec![] }
}
/// Get the number of stack operands
@@ -1204,6 +1311,7 @@ pub fn iseq_to_hir(iseq: *const rb_iseq_t) -> Result<Function, ParseError> {
result
};
while insn_idx < iseq_size {
+ state.insn_idx = insn_idx as usize;
// Get the current pc and opcode
let pc = unsafe { rb_iseq_pc_at_idx(iseq, insn_idx.into()) };
state.pc = pc;
@@ -2423,4 +2531,138 @@ mod opt_tests {
Return v8
"#]]);
}
+
+ #[test]
+ fn test_optimize_top_level_call_into_send_direct() {
+ eval("
+ def foo
+ end
+ def test
+ foo
+ end
+ test; test
+ ");
+ assert_optimized_method_hir("test", expect![[r#"
+ fn test:
+ bb0():
+ v1:BasicObject = PutSelf
+ PatchPoint MethodRedefined(Object@0x1000, foo@0x1008)
+ v7:BasicObject[VALUE(0x1010)] = GuardBitEquals v1, VALUE(0x1010)
+ v8:BasicObject = SendWithoutBlockDirect v7, :foo (0x1018)
+ Return v8
+ "#]]);
+ }
+
+ #[test]
+ fn test_optimize_nonexistent_top_level_call() {
+ eval("
+ def foo
+ end
+ def test
+ foo
+ end
+ test; test
+ undef :foo
+ ");
+ assert_optimized_method_hir("test", expect![[r#"
+ fn test:
+ bb0():
+ v1:BasicObject = PutSelf
+ v3:BasicObject = SendWithoutBlock v1, :foo
+ Return v3
+ "#]]);
+ }
+
+ #[test]
+ fn test_optimize_private_top_level_call() {
+ eval("
+ def foo
+ end
+ private :foo
+ def test
+ foo
+ end
+ test; test
+ ");
+ assert_optimized_method_hir("test", expect![[r#"
+ fn test:
+ bb0():
+ v1:BasicObject = PutSelf
+ PatchPoint MethodRedefined(Object@0x1000, foo@0x1008)
+ v7:BasicObject[VALUE(0x1010)] = GuardBitEquals v1, VALUE(0x1010)
+ v8:BasicObject = SendWithoutBlockDirect v7, :foo (0x1018)
+ Return v8
+ "#]]);
+ }
+
+ #[test]
+ fn test_optimize_top_level_call_with_overloaded_cme() {
+ eval("
+ def test
+ Integer(3)
+ end
+ test; test
+ ");
+ assert_optimized_method_hir("test", expect![[r#"
+ fn test:
+ bb0():
+ v1:BasicObject = PutSelf
+ v3:Fixnum[3] = Const Value(3)
+ PatchPoint MethodRedefined(Object@0x1000, Integer@0x1008)
+ v9:BasicObject[VALUE(0x1010)] = GuardBitEquals v1, VALUE(0x1010)
+ v10:BasicObject = SendWithoutBlockDirect v9, :Integer (0x1018), v3
+ Return v10
+ "#]]);
+ }
+
+ #[test]
+ fn test_optimize_top_level_call_with_args_into_send_direct() {
+ eval("
+ def foo a, b
+ end
+ def test
+ foo 1, 2
+ end
+ test; test
+ ");
+ assert_optimized_method_hir("test", expect![[r#"
+ fn test:
+ bb0():
+ v1:BasicObject = PutSelf
+ v3:Fixnum[1] = Const Value(1)
+ v5:Fixnum[2] = Const Value(2)
+ PatchPoint MethodRedefined(Object@0x1000, foo@0x1008)
+ v11:BasicObject[VALUE(0x1010)] = GuardBitEquals v1, VALUE(0x1010)
+ v12:BasicObject = SendWithoutBlockDirect v11, :foo (0x1018), v3, v5
+ Return v12
+ "#]]);
+ }
+
+ #[test]
+ fn test_optimize_top_level_sends_into_send_direct() {
+ eval("
+ def foo
+ end
+ def bar
+ end
+ def test
+ foo
+ bar
+ end
+ test; test
+ ");
+ assert_optimized_method_hir("test", expect![[r#"
+ fn test:
+ bb0():
+ v1:BasicObject = PutSelf
+ PatchPoint MethodRedefined(Object@0x1000, foo@0x1008)
+ v12:BasicObject[VALUE(0x1010)] = GuardBitEquals v1, VALUE(0x1010)
+ v13:BasicObject = SendWithoutBlockDirect v12, :foo (0x1018)
+ v6:BasicObject = PutSelf
+ PatchPoint MethodRedefined(Object@0x1000, bar@0x1020)
+ v15:BasicObject[VALUE(0x1010)] = GuardBitEquals v6, VALUE(0x1010)
+ v16:BasicObject = SendWithoutBlockDirect v15, :bar (0x1018)
+ Return v16
+ "#]]);
+ }
}
diff --git a/zjit/src/hir_type/mod.rs b/zjit/src/hir_type/mod.rs
index 2168bbda2b..5e71e26305 100644
--- a/zjit/src/hir_type/mod.rs
+++ b/zjit/src/hir_type/mod.rs
@@ -64,7 +64,7 @@ pub struct Type {
include!("hir_type.inc.rs");
/// Get class name from a class pointer.
-fn get_class_name(class: Option<VALUE>) -> String {
+pub fn get_class_name(class: Option<VALUE>) -> String {
use crate::cruby::{RB_TYPE_P, RUBY_T_MODULE, RUBY_T_CLASS};
use crate::cruby::{cstr_to_rust_string, rb_class2name};
class.filter(|&class| {
@@ -244,6 +244,12 @@ impl Type {
self.is_subtype(types::NilClassExact) || self.is_subtype(types::FalseClassExact)
}
+ /// Top self is the Ruby global object, where top-level method definitions go. Return true if
+ /// this Type has a Ruby object specialization that is the top-level self.
+ pub fn is_top_self(&self) -> bool {
+ self.ruby_object() == Some(unsafe { crate::cruby::rb_vm_top_self() })
+ }
+
/// Return the object specialization, if any.
pub fn ruby_object(&self) -> Option<VALUE> {
match self.spec {