summaryrefslogtreecommitdiff
path: root/zjit/src
diff options
context:
space:
mode:
authorMax Bernstein <[email protected]>2025-05-23 13:45:38 -0400
committerTakashi Kokubun <[email protected]>2025-05-23 13:32:49 -0700
commitd23fe287b647d342dbb26b5b714992823b068fe4 (patch)
tree3cf5df81e12112b755cc2fd880fd9a544bab2808 /zjit/src
parenta0df4cf6f16c68406aa32ad32047511e77bc0659 (diff)
ZJIT: Side-exit into the interpreter on unknown call types
Notes
Notes: Merged: https://github.com/ruby/ruby/pull/13430
Diffstat (limited to 'zjit/src')
-rw-r--r--zjit/src/hir.rs88
1 files changed, 66 insertions, 22 deletions
diff --git a/zjit/src/hir.rs b/zjit/src/hir.rs
index 66cbac8d7c..ffd60c493e 100644
--- a/zjit/src/hir.rs
+++ b/zjit/src/hir.rs
@@ -1797,7 +1797,6 @@ pub enum CallType {
#[derive(Debug, PartialEq)]
pub enum ParseError {
StackUnderflow(FrameState),
- UnhandledCallType(CallType),
}
/// Return the number of locals in the current ISEQ (includes parameters)
@@ -1806,19 +1805,19 @@ fn num_locals(iseq: *const rb_iseq_t) -> usize {
}
/// If we can't handle the type of send (yet), bail out.
-fn filter_translatable_calls(flag: u32) -> Result<(), ParseError> {
- if (flag & VM_CALL_KW_SPLAT_MUT) != 0 { return Err(ParseError::UnhandledCallType(CallType::KwSplatMut)); }
- if (flag & VM_CALL_ARGS_SPLAT_MUT) != 0 { return Err(ParseError::UnhandledCallType(CallType::SplatMut)); }
- if (flag & VM_CALL_ARGS_SPLAT) != 0 { return Err(ParseError::UnhandledCallType(CallType::Splat)); }
- if (flag & VM_CALL_KW_SPLAT) != 0 { return Err(ParseError::UnhandledCallType(CallType::KwSplat)); }
- if (flag & VM_CALL_ARGS_BLOCKARG) != 0 { return Err(ParseError::UnhandledCallType(CallType::BlockArg)); }
- if (flag & VM_CALL_KWARG) != 0 { return Err(ParseError::UnhandledCallType(CallType::Kwarg)); }
- if (flag & VM_CALL_TAILCALL) != 0 { return Err(ParseError::UnhandledCallType(CallType::Tailcall)); }
- if (flag & VM_CALL_SUPER) != 0 { return Err(ParseError::UnhandledCallType(CallType::Super)); }
- if (flag & VM_CALL_ZSUPER) != 0 { return Err(ParseError::UnhandledCallType(CallType::Zsuper)); }
- if (flag & VM_CALL_OPT_SEND) != 0 { return Err(ParseError::UnhandledCallType(CallType::OptSend)); }
- if (flag & VM_CALL_FORWARDING) != 0 { return Err(ParseError::UnhandledCallType(CallType::Forwarding)); }
- Ok(())
+fn unknown_call_type(flag: u32) -> bool {
+ if (flag & VM_CALL_KW_SPLAT_MUT) != 0 { return true; }
+ if (flag & VM_CALL_ARGS_SPLAT_MUT) != 0 { return true; }
+ if (flag & VM_CALL_ARGS_SPLAT) != 0 { return true; }
+ if (flag & VM_CALL_KW_SPLAT) != 0 { return true; }
+ if (flag & VM_CALL_ARGS_BLOCKARG) != 0 { return true; }
+ if (flag & VM_CALL_KWARG) != 0 { return true; }
+ if (flag & VM_CALL_TAILCALL) != 0 { return true; }
+ if (flag & VM_CALL_SUPER) != 0 { return true; }
+ if (flag & VM_CALL_ZSUPER) != 0 { return true; }
+ if (flag & VM_CALL_OPT_SEND) != 0 { return true; }
+ if (flag & VM_CALL_FORWARDING) != 0 { return true; }
+ false
}
/// We have IseqPayload, which keeps track of HIR Types in the interpreter, but this is not useful
@@ -2147,7 +2146,12 @@ pub fn iseq_to_hir(iseq: *const rb_iseq_t) -> Result<Function, ParseError> {
// NB: opt_neq has two cd; get_arg(0) is for eq and get_arg(1) is for neq
let cd: *const rb_call_data = get_arg(pc, 1).as_ptr();
let call_info = unsafe { rb_get_call_data_ci(cd) };
- filter_translatable_calls(unsafe { rb_vm_ci_flag(call_info) })?;
+ if unknown_call_type(unsafe { rb_vm_ci_flag(call_info) }) {
+ // Unknown call type; side-exit into the interpreter
+ let exit_id = fun.push_insn(block, Insn::Snapshot { state: exit_state });
+ fun.push_insn(block, Insn::SideExit { state: exit_id });
+ break; // End the block
+ }
let argc = unsafe { vm_ci_argc((*cd).ci) };
@@ -2190,7 +2194,12 @@ pub fn iseq_to_hir(iseq: *const rb_iseq_t) -> Result<Function, ParseError> {
YARVINSN_opt_send_without_block => {
let cd: *const rb_call_data = get_arg(pc, 0).as_ptr();
let call_info = unsafe { rb_get_call_data_ci(cd) };
- filter_translatable_calls(unsafe { rb_vm_ci_flag(call_info) })?;
+ if unknown_call_type(unsafe { rb_vm_ci_flag(call_info) }) {
+ // Unknown call type; side-exit into the interpreter
+ let exit_id = fun.push_insn(block, Insn::Snapshot { state: exit_state });
+ fun.push_insn(block, Insn::SideExit { state: exit_id });
+ break; // End the block
+ }
let argc = unsafe { vm_ci_argc((*cd).ci) };
@@ -2213,7 +2222,12 @@ pub fn iseq_to_hir(iseq: *const rb_iseq_t) -> Result<Function, ParseError> {
let cd: *const rb_call_data = get_arg(pc, 0).as_ptr();
let blockiseq: IseqPtr = get_arg(pc, 1).as_iseq();
let call_info = unsafe { rb_get_call_data_ci(cd) };
- filter_translatable_calls(unsafe { rb_vm_ci_flag(call_info) })?;
+ if unknown_call_type(unsafe { rb_vm_ci_flag(call_info) }) {
+ // Unknown call type; side-exit into the interpreter
+ let exit_id = fun.push_insn(block, Insn::Snapshot { state: exit_state });
+ fun.push_insn(block, Insn::SideExit { state: exit_id });
+ break; // End the block
+ }
let argc = unsafe { vm_ci_argc((*cd).ci) };
let method_name = unsafe {
@@ -3077,7 +3091,13 @@ mod tests {
eval("
def test(a) = foo(*a)
");
- assert_compile_fails("test", ParseError::UnhandledCallType(CallType::Splat))
+ assert_method_hir("test", expect![[r#"
+ fn test:
+ bb0(v0:BasicObject):
+ v2:BasicObject = PutSelf
+ v4:ArrayExact = ToArray v0
+ SideExit
+ "#]]);
}
#[test]
@@ -3085,7 +3105,12 @@ mod tests {
eval("
def test(a) = foo(&a)
");
- assert_compile_fails("test", ParseError::UnhandledCallType(CallType::BlockArg))
+ assert_method_hir("test", expect![[r#"
+ fn test:
+ bb0(v0:BasicObject):
+ v2:BasicObject = PutSelf
+ SideExit
+ "#]]);
}
#[test]
@@ -3093,7 +3118,13 @@ mod tests {
eval("
def test(a) = foo(a: 1)
");
- assert_compile_fails("test", ParseError::UnhandledCallType(CallType::Kwarg))
+ assert_method_hir("test", expect![[r#"
+ fn test:
+ bb0(v0:BasicObject):
+ v2:BasicObject = PutSelf
+ v3:Fixnum[1] = Const Value(1)
+ SideExit
+ "#]]);
}
#[test]
@@ -3101,7 +3132,12 @@ mod tests {
eval("
def test(a) = foo(**a)
");
- assert_compile_fails("test", ParseError::UnhandledCallType(CallType::KwSplat))
+ assert_method_hir("test", expect![[r#"
+ fn test:
+ bb0(v0:BasicObject):
+ v2:BasicObject = PutSelf
+ SideExit
+ "#]]);
}
// TODO(max): Figure out how to generate a call with TAILCALL flag
@@ -3165,7 +3201,15 @@ mod tests {
eval("
def test(*) = foo *, 1
");
- assert_compile_fails("test", ParseError::UnhandledCallType(CallType::SplatMut))
+ assert_method_hir("test", expect![[r#"
+ fn test:
+ bb0(v0:ArrayExact):
+ v2:BasicObject = PutSelf
+ v4:ArrayExact = ToNewArray v0
+ v5:Fixnum[1] = Const Value(1)
+ ArrayPush v4, v5
+ SideExit
+ "#]]);
}
#[test]