author     Max Bernstein <[email protected]>      2025-03-27 16:38:31 -0400
committer  Takashi Kokubun <[email protected]>   2025-04-18 21:53:01 +0900
commit     308cd59bf8d9a3d6141fae5931bad397020e7bc8 (patch)
tree       c68faaf02a965e5b2b375cbd7379b87fbc630c1b
parent     cfc9234ccdb457934f4daeef599e303844869fc3 (diff)
Rewrite SendWithoutBlock to SendWithoutBlockDirect when possible
In calls to top-level functions, we assume that call targets will not
get rewritten, so we can insert a PatchPoint and do the lookup at
compile-time.
Notes:
    Merged: https://github.com/ruby/ruby/pull/13131
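As the commit message describes, the pass resolves the callee of a top-level call at compile time and guards that assumption with a PatchPoint. The Ruby snippet below is an illustrative sketch (editor-added, not part of the commit) of the kind of call site this targets and of what the MethodRedefined invariant protects against:

    # Illustrative only: `foo` lives on the top-level self, so a profiled call
    # site inside `test` can have its target ISEQ resolved at compile time and
    # be rewritten to PatchPoint + GuardBitEquals + SendWithoutBlockDirect.
    def foo
      1
    end

    def test
      foo
    end

    test; test   # warm up the profile, then the JIT compiles `test`

    # Redefining `foo` afterwards breaks the MethodRedefined(Object, :foo)
    # invariant recorded by the PatchPoint, so compiled code for `test` must
    # side-exit instead of calling the stale ISEQ.
    def foo
      2
    end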
-rw-r--r--  zjit/src/codegen.rs        11
-rw-r--r--  zjit/src/hir.rs           262
-rw-r--r--  zjit/src/hir_type/mod.rs    8
3 files changed, 269 insertions, 12 deletions
diff --git a/zjit/src/codegen.rs b/zjit/src/codegen.rs
index 316f6d9b80..7fbe4f55dd 100644
--- a/zjit/src/codegen.rs
+++ b/zjit/src/codegen.rs
@@ -187,7 +187,8 @@ fn gen_insn(jit: &mut JITState, asm: &mut Assembler, function: &Function, insn_i
         Insn::Jump(branch) => return gen_jump(jit, asm, branch),
         Insn::IfTrue { val, target } => return gen_if_true(jit, asm, opnd!(val), target),
         Insn::IfFalse { val, target } => return gen_if_false(jit, asm, opnd!(val), target),
-        Insn::SendWithoutBlock { call_info, cd, state, .. } => gen_send_without_block(jit, asm, call_info, *cd, function.frame_state(*state))?,
+        Insn::SendWithoutBlock { call_info, cd, state, .. } | Insn::SendWithoutBlockDirect { call_info, cd, state, .. }
+            => gen_send_without_block(jit, asm, call_info, *cd, function.frame_state(*state))?,
         Insn::Return { val } => return Some(gen_return(asm, opnd!(val))?),
         Insn::FixnumAdd { left, right, state } => gen_fixnum_add(asm, opnd!(left), opnd!(right), function.frame_state(*state))?,
         Insn::FixnumSub { left, right, state } => gen_fixnum_sub(asm, opnd!(left), opnd!(right), function.frame_state(*state))?,
@@ -200,6 +201,7 @@ fn gen_insn(jit: &mut JITState, asm: &mut Assembler, function: &Function, insn_i
         Insn::FixnumGe { left, right } => gen_fixnum_ge(asm, opnd!(left), opnd!(right))?,
         Insn::Test { val } => gen_test(asm, opnd!(val))?,
         Insn::GuardType { val, guard_type, state } => gen_guard_type(asm, opnd!(val), *guard_type, function.frame_state(*state))?,
+        Insn::GuardBitEquals { val, expected, state } => gen_guard_bit_equals(asm, opnd!(val), *expected, function.frame_state(*state))?,
         Insn::PatchPoint(_) => return Some(()), // For now, rb_zjit_bop_redefined() panics. TODO: leave a patch point and fix rb_zjit_bop_redefined()
         _ => {
             debug!("ZJIT: gen_function: unexpected insn {:?}", insn);
@@ -497,6 +499,13 @@ fn gen_guard_type(asm: &mut Assembler, val: lir::Opnd, guard_type: Type, state:
     Some(val)
 }
 
+/// Compile an identity check with a side exit
+fn gen_guard_bit_equals(asm: &mut Assembler, val: lir::Opnd, expected: VALUE, state: &FrameState) -> Option<lir::Opnd> {
+    asm.cmp(val, Opnd::UImm(expected.into()));
+    asm.jnz(Target::SideExit(state.clone()));
+    Some(val)
+}
+
 /// Save the incremented PC on the CFP.
 /// This is necessary when callees can raise or allocate.
 fn gen_save_pc(asm: &mut Assembler, state: &FrameState) {
diff --git a/zjit/src/hir.rs b/zjit/src/hir.rs
index 98c492daa3..7b66d05bed 100644
--- a/zjit/src/hir.rs
+++ b/zjit/src/hir.rs
@@ -5,7 +5,7 @@ use crate::{
     cruby::*, options::get_option, hir_type::types::Fixnum, options::DumpHIR, profile::get_or_create_iseq_payload
 };
 use std::{cell::RefCell, collections::{HashMap, HashSet}, ffi::c_void, mem::{align_of, size_of}, ptr, slice::Iter};
-use crate::hir_type::{Type, types};
+use crate::hir_type::{Type, types, get_class_name};
 
 #[derive(Copy, Clone, Ord, PartialOrd, Eq, PartialEq, Hash, Debug)]
 pub struct InsnId(pub usize);
@@ -95,7 +95,7 @@ pub struct CallInfo {
 }
 
 /// Invalidation reasons
-#[derive(Debug, Clone)]
+#[derive(Debug, Clone, Copy)]
 pub enum Invariant {
     /// Basic operation is redefined
     BOPRedefined {
@@ -104,19 +104,37 @@ pub enum Invariant {
         /// BOP_{bop}
         bop: ruby_basic_operators,
     },
+    MethodRedefined {
+        /// The class object whose method we want to assume unchanged
+        klass: VALUE,
+        /// The method ID of the method we want to assume unchanged
+        method: ID,
+    }
+}
+
+impl Invariant {
+    pub fn print(self, ptr_map: &PtrPrintMap) -> InvariantPrinter {
+        InvariantPrinter { inner: self, ptr_map }
+    }
+}
+
+/// Print adaptor for [`Invariant`]. See [`PtrPrintMap`].
+pub struct InvariantPrinter<'a> {
+    inner: Invariant,
+    ptr_map: &'a PtrPrintMap,
 }
 
-impl std::fmt::Display for Invariant {
+impl<'a> std::fmt::Display for InvariantPrinter<'a> {
     fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
-        match self {
-            Self::BOPRedefined { klass, bop } => {
+        match self.inner {
+            Invariant::BOPRedefined { klass, bop } => {
                 write!(f, "BOPRedefined(")?;
-                match *klass {
+                match klass {
                     INTEGER_REDEFINED_OP_FLAG => write!(f, "INTEGER_REDEFINED_OP_FLAG")?,
                     _ => write!(f, "{klass}")?,
                 }
                 write!(f, ", ")?;
-                match *bop {
+                match bop {
                     BOP_PLUS => write!(f, "BOP_PLUS")?,
                     BOP_MINUS => write!(f, "BOP_MINUS")?,
                     BOP_MULT => write!(f, "BOP_MULT")?,
@@ -132,6 +150,16 @@ impl std::fmt::Display for Invariant {
                 }
                 write!(f, ")")
             }
+            Invariant::MethodRedefined { klass, method } => {
+                let class_name = get_class_name(Some(klass));
+                let method_name = unsafe {
+                    cstr_to_rust_string(rb_id2name(method)).unwrap_or_else(|| "<unknown>".to_owned())
+                };
+                write!(f, "MethodRedefined({class_name}@{:p}, {method_name}@{:p})",
+                    self.ptr_map.map_ptr(klass.as_ptr::<VALUE>()),
+                    self.ptr_map.map_id(method)
+                )
+            }
         }
     }
 }
@@ -238,6 +266,11 @@ impl PtrPrintMap {
             }
         }
     }
+
+    /// Map a Ruby ID (index into intern table) for printing
+    fn map_id(&self, id: u64) -> *const c_void {
+        self.map_ptr(id as *const c_void)
+    }
 }
 
 #[derive(Debug, Clone)]
@@ -285,6 +318,7 @@ pub enum Insn {
     // Ignoring keyword arguments etc for now
     SendWithoutBlock { self_val: InsnId, call_info: CallInfo, cd: *const rb_call_data, args: Vec<InsnId>, state: FrameStateId },
     Send { self_val: InsnId, call_info: CallInfo, cd: *const rb_call_data, blockiseq: IseqPtr, args: Vec<InsnId>, state: FrameStateId },
+    SendWithoutBlockDirect { self_val: InsnId, call_info: CallInfo, cd: *const rb_call_data, iseq: IseqPtr, args: Vec<InsnId>, state: FrameStateId },
 
     // Control flow instructions
     Return { val: InsnId },
@@ -302,8 +336,10 @@ pub enum Insn {
     FixnumGt { left: InsnId, right: InsnId },
     FixnumGe { left: InsnId, right: InsnId },
 
-    /// Side-exist if val doesn't have the expected type.
+    /// Side-exit if val doesn't have the expected type.
     GuardType { val: InsnId, guard_type: Type, state: FrameStateId },
+    /// Side-exit if val is not the expected VALUE.
+    GuardBitEquals { val: InsnId, expected: VALUE, state: FrameStateId },
 
     /// Generate no code (or padding if necessary) and insert a patch point
     /// that can be rewritten to a side exit when the Invariant is broken.
@@ -360,6 +396,13 @@ impl<'a> std::fmt::Display for InsnPrinter<'a> {
                 }
                 Ok(())
             }
+            Insn::SendWithoutBlockDirect { self_val, call_info, iseq, args, .. } => {
+                write!(f, "SendWithoutBlockDirect {self_val}, :{} ({:?})", call_info.method_name, self.ptr_map.map_ptr(iseq))?;
+                for arg in args {
+                    write!(f, ", {arg}")?;
+                }
+                Ok(())
+            }
             Insn::Send { self_val, call_info, args, blockiseq, .. } => {
                 // For tests, we want to check HIR snippets textually. Addresses change
                 // between runs, making tests fail. Instead, pick an arbitrary hex value to
@@ -383,7 +426,8 @@ impl<'a> std::fmt::Display for InsnPrinter<'a> {
             Insn::FixnumGt { left, right, .. } => { write!(f, "FixnumGt {left}, {right}") },
             Insn::FixnumGe { left, right, .. } => { write!(f, "FixnumGe {left}, {right}") },
             Insn::GuardType { val, guard_type, .. } => { write!(f, "GuardType {val}, {guard_type}") },
-            Insn::PatchPoint(invariant) => { write!(f, "PatchPoint {invariant:}") },
+            Insn::GuardBitEquals { val, expected, .. } => { write!(f, "GuardBitEquals {val}, {}", expected.print(&self.ptr_map)) },
+            Insn::PatchPoint(invariant) => { write!(f, "PatchPoint {}", invariant.print(&self.ptr_map)) },
             insn => { write!(f, "{insn:?}") }
         }
     }
@@ -573,6 +617,12 @@ impl Function {
         id
     }
+    // Add an instruction to an SSA block
+    fn push_insn_id(&mut self, block: BlockId, insn_id: InsnId) -> InsnId {
+        self.blocks[block.0].insns.push(insn_id);
+        insn_id
+    }
+
     /// Return the number of instructions
     pub fn num_insns(&self) -> usize {
         self.insns.len()
     }
@@ -641,6 +691,7 @@ impl Function {
             IfTrue { val, target } => IfTrue { val: find!(*val), target: target.clone() },
             IfFalse { val, target } => IfFalse { val: find!(*val), target: target.clone() },
             GuardType { val, guard_type, state } => GuardType { val: find!(*val), guard_type: *guard_type, state: *state },
+            GuardBitEquals { val, expected, state } => GuardBitEquals { val: find!(*val), expected: *expected, state: *state },
            FixnumAdd { left, right, state } => FixnumAdd { left: find!(*left), right: find!(*right), state: *state },
             FixnumSub { left, right, state } => FixnumSub { left: find!(*left), right: find!(*right), state: *state },
             FixnumMult { left, right, state } => FixnumMult { left: find!(*left), right: find!(*right), state: *state },
@@ -659,6 +710,14 @@ impl Function {
                 args: args.iter().map(|arg| find!(*arg)).collect(),
                 state: *state,
             },
+            SendWithoutBlockDirect { self_val, call_info, cd, iseq, args, state } => SendWithoutBlockDirect {
+                self_val: find!(*self_val),
+                call_info: call_info.clone(),
+                cd: cd.clone(),
+                iseq: *iseq,
+                args: args.iter().map(|arg| find!(*arg)).collect(),
+                state: *state,
+            },
             Send { self_val, call_info, cd, blockiseq, args, state } => Send {
                 self_val: find!(*self_val),
                 call_info: call_info.clone(),
@@ -719,6 +778,7 @@ impl Function {
             Insn::ArrayDup { .. } => types::ArrayExact,
             Insn::CCall { .. } => types::Any,
             Insn::GuardType { val, guard_type, .. } => self.type_of(*val).intersection(*guard_type),
+            Insn::GuardBitEquals { val, expected, .. } => self.type_of(*val).intersection(Type::from_value(*expected)),
             Insn::FixnumAdd { .. } => types::Fixnum,
             Insn::FixnumSub { .. } => types::Fixnum,
             Insn::FixnumMult { .. } => types::Fixnum,
@@ -731,6 +791,7 @@ impl Function {
             Insn::FixnumGt { .. } => types::BoolExact,
             Insn::FixnumGe { .. } => types::BoolExact,
             Insn::SendWithoutBlock { .. } => types::BasicObject,
+            Insn::SendWithoutBlockDirect { .. } => types::BasicObject,
             Insn::Send { .. } => types::BasicObject,
             Insn::PutSelf => types::BasicObject,
             Insn::Defined { .. } => types::BasicObject,
@@ -801,6 +862,50 @@ impl Function {
         }
     }
 
+    /// Rewrite SendWithoutBlock opcodes into SendWithoutBlockDirect opcodes if we know the target
+    /// ISEQ statically. This removes run-time method lookups and opens the door for inlining.
+    fn optimize_direct_sends(&mut self) {
+        let payload = get_or_create_iseq_payload(self.iseq);
+        for block in self.rpo() {
+            let old_insns = std::mem::take(&mut self.blocks[block.0].insns);
+            assert!(self.blocks[block.0].insns.is_empty());
+            for insn_id in old_insns {
+                match self.find(insn_id) {
+                    Insn::SendWithoutBlock { self_val, call_info, cd, args, state } => {
+                        let frame_state = &self.frame_states[state.0];
+                        let self_type = match payload.get_operand_types(frame_state.insn_idx) {
+                            Some([self_type, ..]) if self_type.is_top_self() => self_type,
+                            _ => { self.push_insn_id(block, insn_id); continue; }
+                        };
+                        let top_self = self_type.ruby_object().unwrap();
+                        let top_self_klass = top_self.class_of();
+                        let ci = unsafe { get_call_data_ci(cd) }; // info about the call site
+                        let mid = unsafe { vm_ci_mid(ci) };
+                        // Do method lookup
+                        let mut cme = unsafe { rb_callable_method_entry(top_self_klass, mid) };
+                        if cme.is_null() {
+                            self.push_insn_id(block, insn_id); continue;
+                        }
+                        // Load an overloaded cme if applicable. See vm_search_cc().
+                        // It allows you to use a faster ISEQ if possible.
+                        cme = unsafe { rb_check_overloaded_cme(cme, ci) };
+                        let def_type = unsafe { get_cme_def_type(cme) };
+                        if def_type != VM_METHOD_TYPE_ISEQ {
+                            self.push_insn_id(block, insn_id); continue;
+                        }
+                        self.push_insn(block, Insn::PatchPoint(Invariant::MethodRedefined { klass: top_self_klass, method: mid }));
+                        let iseq = unsafe { get_def_iseq_ptr((*cme).def) };
+                        let self_val = self.push_insn(block, Insn::GuardBitEquals { val: self_val, expected: top_self, state });
+                        let send_direct = self.push_insn(block, Insn::SendWithoutBlockDirect { self_val, call_info, cd, iseq, args, state });
+                        self.make_equal_to(insn_id, send_direct);
+                    }
+                    _ => { self.push_insn_id(block, insn_id); }
+                }
+            }
+        }
+        self.infer_types();
+    }
+
     /// Use type information left by `infer_types` to fold away operations that can be evaluated at compile-time.
     ///
     /// It can fold fixnum math, truthiness tests, and branches with constant conditionals.
@@ -932,6 +1037,7 @@ impl Function {
     /// Run all the optimization passes we have.
     pub fn optimize(&mut self) {
         // Function is assumed to have types inferred already
+        self.optimize_direct_sends();
         self.fold_constants();
 
         // Dump HIR after optimization
@@ -987,6 +1093,7 @@ impl<'a> std::fmt::Display for FunctionPrinter<'a> {
 #[derive(Debug, Clone)]
 pub struct FrameState {
     iseq: IseqPtr,
+    insn_idx: usize,
 
     // Ruby bytecode instruction pointer
     pub pc: *const VALUE,
@@ -1032,7 +1139,7 @@ fn ep_offset_to_local_idx(iseq: IseqPtr, ep_offset: u32) -> usize {
 
 impl FrameState {
     fn new(iseq: IseqPtr) -> FrameState {
-        FrameState { iseq, pc: 0 as *const VALUE, stack: vec![], locals: vec![] }
+        FrameState { iseq, pc: 0 as *const VALUE, insn_idx: 0, stack: vec![], locals: vec![] }
     }
 
     /// Get the number of stack operands
@@ -1204,6 +1311,7 @@ pub fn iseq_to_hir(iseq: *const rb_iseq_t) -> Result<Function, ParseError> {
         result
     };
     while insn_idx < iseq_size {
+        state.insn_idx = insn_idx as usize;
         // Get the current pc and opcode
        let pc = unsafe { rb_iseq_pc_at_idx(iseq, insn_idx.into()) };
         state.pc = pc;
@@ -2423,4 +2531,138 @@ mod opt_tests {
             Return v8
         "#]]);
     }
+
+    #[test]
+    fn test_optimize_top_level_call_into_send_direct() {
+        eval("
+            def foo
+            end
+            def test
+              foo
+            end
+            test; test
+        ");
+        assert_optimized_method_hir("test", expect![[r#"
+            fn test:
+            bb0():
+              v1:BasicObject = PutSelf
+              PatchPoint MethodRedefined(Object@0x1000, foo@0x1008)
+              v7:BasicObject[VALUE(0x1010)] = GuardBitEquals v1, VALUE(0x1010)
+              v8:BasicObject = SendWithoutBlockDirect v7, :foo (0x1018)
+              Return v8
+        "#]]);
+    }
+
+    #[test]
+    fn test_optimize_nonexistent_top_level_call() {
+        eval("
+            def foo
+            end
+            def test
+              foo
+            end
+            test; test
+            undef :foo
+        ");
+        assert_optimized_method_hir("test", expect![[r#"
+            fn test:
+            bb0():
+              v1:BasicObject = PutSelf
+              v3:BasicObject = SendWithoutBlock v1, :foo
+              Return v3
+        "#]]);
+    }
+
+    #[test]
+    fn test_optimize_private_top_level_call() {
+        eval("
+            def foo
+            end
+            private :foo
+            def test
+              foo
+            end
+            test; test
+        ");
+        assert_optimized_method_hir("test", expect![[r#"
+            fn test:
+            bb0():
+              v1:BasicObject = PutSelf
+              PatchPoint MethodRedefined(Object@0x1000, foo@0x1008)
+              v7:BasicObject[VALUE(0x1010)] = GuardBitEquals v1, VALUE(0x1010)
+              v8:BasicObject = SendWithoutBlockDirect v7, :foo (0x1018)
+              Return v8
+        "#]]);
+    }
+
+    #[test]
+    fn test_optimize_top_level_call_with_overloaded_cme() {
+        eval("
+            def test
+              Integer(3)
+            end
+            test; test
+        ");
+        assert_optimized_method_hir("test", expect![[r#"
+            fn test:
+            bb0():
+              v1:BasicObject = PutSelf
+              v3:Fixnum[3] = Const Value(3)
+              PatchPoint MethodRedefined(Object@0x1000, Integer@0x1008)
+              v9:BasicObject[VALUE(0x1010)] = GuardBitEquals v1, VALUE(0x1010)
+              v10:BasicObject = SendWithoutBlockDirect v9, :Integer (0x1018), v3
+              Return v10
+        "#]]);
+    }
+
+    #[test]
+    fn test_optimize_top_level_call_with_args_into_send_direct() {
+        eval("
+            def foo a, b
+            end
+            def test
+              foo 1, 2
+            end
+            test; test
+        ");
+        assert_optimized_method_hir("test", expect![[r#"
+            fn test:
+            bb0():
+              v1:BasicObject = PutSelf
+              v3:Fixnum[1] = Const Value(1)
+              v5:Fixnum[2] = Const Value(2)
+              PatchPoint MethodRedefined(Object@0x1000, foo@0x1008)
+              v11:BasicObject[VALUE(0x1010)] = GuardBitEquals v1, VALUE(0x1010)
+              v12:BasicObject = SendWithoutBlockDirect v11, :foo (0x1018), v3, v5
+              Return v12
+        "#]]);
+    }
+
+    #[test]
+    fn test_optimize_top_level_sends_into_send_direct() {
+        eval("
+            def foo
+            end
+            def bar
+            end
+            def test
+              foo
+              bar
+            end
+            test; test
+        ");
+        assert_optimized_method_hir("test", expect![[r#"
+            fn test:
+            bb0():
+              v1:BasicObject = PutSelf
+              PatchPoint MethodRedefined(Object@0x1000, foo@0x1008)
+              v12:BasicObject[VALUE(0x1010)] = GuardBitEquals v1, VALUE(0x1010)
+              v13:BasicObject = SendWithoutBlockDirect v12, :foo (0x1018)
+              v6:BasicObject = PutSelf
+              PatchPoint MethodRedefined(Object@0x1000, bar@0x1020)
+              v15:BasicObject[VALUE(0x1010)] = GuardBitEquals v6, VALUE(0x1010)
+              v16:BasicObject = SendWithoutBlockDirect v15, :bar (0x1018)
+              Return v16
+        "#]]);
+    }
 }
diff --git a/zjit/src/hir_type/mod.rs b/zjit/src/hir_type/mod.rs
index 2168bbda2b..5e71e26305 100644
--- a/zjit/src/hir_type/mod.rs
+++ b/zjit/src/hir_type/mod.rs
@@ -64,7 +64,7 @@ pub struct Type {
 include!("hir_type.inc.rs");
 
 /// Get class name from a class pointer.
-fn get_class_name(class: Option<VALUE>) -> String {
+pub fn get_class_name(class: Option<VALUE>) -> String {
     use crate::cruby::{RB_TYPE_P, RUBY_T_MODULE, RUBY_T_CLASS};
     use crate::cruby::{cstr_to_rust_string, rb_class2name};
     class.filter(|&class| {
@@ -244,6 +244,12 @@ impl Type {
         self.is_subtype(types::NilClassExact) || self.is_subtype(types::FalseClassExact)
     }
 
+    /// Top self is the Ruby global object, where top-level method definitions go. Return true if
+    /// this Type has a Ruby object specialization that is the top-level self.
+    pub fn is_top_self(&self) -> bool {
+        self.ruby_object() == Some(unsafe { crate::cruby::rb_vm_top_self() })
+    }
+
     /// Return the object specialization, if any.
     pub fn ruby_object(&self) -> Option<VALUE> {
         match self.spec {