diff options
Diffstat (limited to 'zjit/src')
-rw-r--r-- | zjit/src/hir.rs | 88 |
1 files changed, 66 insertions, 22 deletions
diff --git a/zjit/src/hir.rs b/zjit/src/hir.rs index 66cbac8d7c..ffd60c493e 100644 --- a/zjit/src/hir.rs +++ b/zjit/src/hir.rs @@ -1797,7 +1797,6 @@ pub enum CallType { #[derive(Debug, PartialEq)] pub enum ParseError { StackUnderflow(FrameState), - UnhandledCallType(CallType), } /// Return the number of locals in the current ISEQ (includes parameters) @@ -1806,19 +1805,19 @@ fn num_locals(iseq: *const rb_iseq_t) -> usize { } /// If we can't handle the type of send (yet), bail out. -fn filter_translatable_calls(flag: u32) -> Result<(), ParseError> { - if (flag & VM_CALL_KW_SPLAT_MUT) != 0 { return Err(ParseError::UnhandledCallType(CallType::KwSplatMut)); } - if (flag & VM_CALL_ARGS_SPLAT_MUT) != 0 { return Err(ParseError::UnhandledCallType(CallType::SplatMut)); } - if (flag & VM_CALL_ARGS_SPLAT) != 0 { return Err(ParseError::UnhandledCallType(CallType::Splat)); } - if (flag & VM_CALL_KW_SPLAT) != 0 { return Err(ParseError::UnhandledCallType(CallType::KwSplat)); } - if (flag & VM_CALL_ARGS_BLOCKARG) != 0 { return Err(ParseError::UnhandledCallType(CallType::BlockArg)); } - if (flag & VM_CALL_KWARG) != 0 { return Err(ParseError::UnhandledCallType(CallType::Kwarg)); } - if (flag & VM_CALL_TAILCALL) != 0 { return Err(ParseError::UnhandledCallType(CallType::Tailcall)); } - if (flag & VM_CALL_SUPER) != 0 { return Err(ParseError::UnhandledCallType(CallType::Super)); } - if (flag & VM_CALL_ZSUPER) != 0 { return Err(ParseError::UnhandledCallType(CallType::Zsuper)); } - if (flag & VM_CALL_OPT_SEND) != 0 { return Err(ParseError::UnhandledCallType(CallType::OptSend)); } - if (flag & VM_CALL_FORWARDING) != 0 { return Err(ParseError::UnhandledCallType(CallType::Forwarding)); } - Ok(()) +fn unknown_call_type(flag: u32) -> bool { + if (flag & VM_CALL_KW_SPLAT_MUT) != 0 { return true; } + if (flag & VM_CALL_ARGS_SPLAT_MUT) != 0 { return true; } + if (flag & VM_CALL_ARGS_SPLAT) != 0 { return true; } + if (flag & VM_CALL_KW_SPLAT) != 0 { return true; } + if (flag & VM_CALL_ARGS_BLOCKARG) != 0 { return true; } + if (flag & VM_CALL_KWARG) != 0 { return true; } + if (flag & VM_CALL_TAILCALL) != 0 { return true; } + if (flag & VM_CALL_SUPER) != 0 { return true; } + if (flag & VM_CALL_ZSUPER) != 0 { return true; } + if (flag & VM_CALL_OPT_SEND) != 0 { return true; } + if (flag & VM_CALL_FORWARDING) != 0 { return true; } + false } /// We have IseqPayload, which keeps track of HIR Types in the interpreter, but this is not useful @@ -2147,7 +2146,12 @@ pub fn iseq_to_hir(iseq: *const rb_iseq_t) -> Result<Function, ParseError> { // NB: opt_neq has two cd; get_arg(0) is for eq and get_arg(1) is for neq let cd: *const rb_call_data = get_arg(pc, 1).as_ptr(); let call_info = unsafe { rb_get_call_data_ci(cd) }; - filter_translatable_calls(unsafe { rb_vm_ci_flag(call_info) })?; + if unknown_call_type(unsafe { rb_vm_ci_flag(call_info) }) { + // Unknown call type; side-exit into the interpreter + let exit_id = fun.push_insn(block, Insn::Snapshot { state: exit_state }); + fun.push_insn(block, Insn::SideExit { state: exit_id }); + break; // End the block + } let argc = unsafe { vm_ci_argc((*cd).ci) }; @@ -2190,7 +2194,12 @@ pub fn iseq_to_hir(iseq: *const rb_iseq_t) -> Result<Function, ParseError> { YARVINSN_opt_send_without_block => { let cd: *const rb_call_data = get_arg(pc, 0).as_ptr(); let call_info = unsafe { rb_get_call_data_ci(cd) }; - filter_translatable_calls(unsafe { rb_vm_ci_flag(call_info) })?; + if unknown_call_type(unsafe { rb_vm_ci_flag(call_info) }) { + // Unknown call type; side-exit into the interpreter + let exit_id = fun.push_insn(block, Insn::Snapshot { state: exit_state }); + fun.push_insn(block, Insn::SideExit { state: exit_id }); + break; // End the block + } let argc = unsafe { vm_ci_argc((*cd).ci) }; @@ -2213,7 +2222,12 @@ pub fn iseq_to_hir(iseq: *const rb_iseq_t) -> Result<Function, ParseError> { let cd: *const rb_call_data = get_arg(pc, 0).as_ptr(); let blockiseq: IseqPtr = get_arg(pc, 1).as_iseq(); let call_info = unsafe { rb_get_call_data_ci(cd) }; - filter_translatable_calls(unsafe { rb_vm_ci_flag(call_info) })?; + if unknown_call_type(unsafe { rb_vm_ci_flag(call_info) }) { + // Unknown call type; side-exit into the interpreter + let exit_id = fun.push_insn(block, Insn::Snapshot { state: exit_state }); + fun.push_insn(block, Insn::SideExit { state: exit_id }); + break; // End the block + } let argc = unsafe { vm_ci_argc((*cd).ci) }; let method_name = unsafe { @@ -3077,7 +3091,13 @@ mod tests { eval(" def test(a) = foo(*a) "); - assert_compile_fails("test", ParseError::UnhandledCallType(CallType::Splat)) + assert_method_hir("test", expect![[r#" + fn test: + bb0(v0:BasicObject): + v2:BasicObject = PutSelf + v4:ArrayExact = ToArray v0 + SideExit + "#]]); } #[test] @@ -3085,7 +3105,12 @@ mod tests { eval(" def test(a) = foo(&a) "); - assert_compile_fails("test", ParseError::UnhandledCallType(CallType::BlockArg)) + assert_method_hir("test", expect![[r#" + fn test: + bb0(v0:BasicObject): + v2:BasicObject = PutSelf + SideExit + "#]]); } #[test] @@ -3093,7 +3118,13 @@ mod tests { eval(" def test(a) = foo(a: 1) "); - assert_compile_fails("test", ParseError::UnhandledCallType(CallType::Kwarg)) + assert_method_hir("test", expect![[r#" + fn test: + bb0(v0:BasicObject): + v2:BasicObject = PutSelf + v3:Fixnum[1] = Const Value(1) + SideExit + "#]]); } #[test] @@ -3101,7 +3132,12 @@ mod tests { eval(" def test(a) = foo(**a) "); - assert_compile_fails("test", ParseError::UnhandledCallType(CallType::KwSplat)) + assert_method_hir("test", expect![[r#" + fn test: + bb0(v0:BasicObject): + v2:BasicObject = PutSelf + SideExit + "#]]); } // TODO(max): Figure out how to generate a call with TAILCALL flag @@ -3165,7 +3201,15 @@ mod tests { eval(" def test(*) = foo *, 1 "); - assert_compile_fails("test", ParseError::UnhandledCallType(CallType::SplatMut)) + assert_method_hir("test", expect![[r#" + fn test: + bb0(v0:ArrayExact): + v2:BasicObject = PutSelf + v4:ArrayExact = ToNewArray v0 + v5:Fixnum[1] = Const Value(1) + ArrayPush v4, v5 + SideExit + "#]]); } #[test] |