diff --git a/zjit/src/cruby_methods.rs b/zjit/src/cruby_methods.rs index 8dc53302835d23..b6bbadd8012f62 100644 --- a/zjit/src/cruby_methods.rs +++ b/zjit/src/cruby_methods.rs @@ -221,7 +221,7 @@ pub fn init() -> Annotations { annotate!(rb_mKernel, "nil?", inline_kernel_nil_p); annotate!(rb_mKernel, "respond_to?", inline_kernel_respond_to_p); annotate!(rb_cBasicObject, "==", inline_basic_object_eq, types::BoolExact, no_gc, leaf, elidable); - annotate!(rb_cBasicObject, "!", types::BoolExact, no_gc, leaf, elidable); + annotate!(rb_cBasicObject, "!", inline_basic_object_not, types::BoolExact, no_gc, leaf, elidable); annotate!(rb_cBasicObject, "!=", inline_basic_object_neq, types::BoolExact); annotate!(rb_cBasicObject, "initialize", inline_basic_object_initialize); annotate!(rb_cInteger, "succ", inline_integer_succ); @@ -517,6 +517,19 @@ fn inline_basic_object_eq(fun: &mut hir::Function, block: hir::BlockId, recv: hi Some(result) } +fn inline_basic_object_not(fun: &mut hir::Function, block: hir::BlockId, recv: hir::InsnId, args: &[hir::InsnId], _state: hir::InsnId) -> Option { + let &[] = args else { return None; }; + if fun.type_of(recv).is_known_truthy() { + let result = fun.push_insn(block, hir::Insn::Const { val: hir::Const::Value(Qfalse) }); + return Some(result); + } + if fun.type_of(recv).is_known_falsy() { + let result = fun.push_insn(block, hir::Insn::Const { val: hir::Const::Value(Qtrue) }); + return Some(result); + } + None +} + fn inline_basic_object_neq(fun: &mut hir::Function, block: hir::BlockId, recv: hir::InsnId, args: &[hir::InsnId], state: hir::InsnId) -> Option { let &[other] = args else { return None; }; let result = try_inline_fixnum_op(fun, block, &|left, right| hir::Insn::FixnumNeq { left, right }, BOP_NEQ, recv, other, state); diff --git a/zjit/src/hir/opt_tests.rs b/zjit/src/hir/opt_tests.rs index 2c56120bf621be..1276131f3e5519 100644 --- a/zjit/src/hir/opt_tests.rs +++ b/zjit/src/hir/opt_tests.rs @@ -4630,7 +4630,7 @@ mod hir_opt_tests { } #[test] - fn test_specialize_basicobject_not_to_ccall() { + fn test_specialize_basicobject_not_truthy() { eval(" def test(a) = !a @@ -4650,10 +4650,104 @@ mod hir_opt_tests { PatchPoint MethodRedefined(Array@0x1000, !@0x1008, cme:0x1010) PatchPoint NoSingletonClass(Array@0x1000) v23:ArrayExact = GuardType v9, ArrayExact + v24:FalseClass = Const Value(false) IncrCounter inline_cfunc_optimized_send_count - v25:BoolExact = CCall !@0x1038, v23 CheckInterrupts - Return v25 + Return v24 + "); + } + + #[test] + fn test_specialize_basicobject_not_false() { + eval(" + def test(a) = !a + + test(false) + "); + assert_snapshot!(hir_string("test"), @r" + fn test@:2: + bb0(): + EntryPoint interpreter + v1:BasicObject = LoadSelf + v2:BasicObject = GetLocal l0, SP@4 + Jump bb2(v1, v2) + bb1(v5:BasicObject, v6:BasicObject): + EntryPoint JIT(0) + Jump bb2(v5, v6) + bb2(v8:BasicObject, v9:BasicObject): + PatchPoint MethodRedefined(FalseClass@0x1000, !@0x1008, cme:0x1010) + v22:FalseClass = GuardType v9, FalseClass + v23:TrueClass = Const Value(true) + IncrCounter inline_cfunc_optimized_send_count + CheckInterrupts + Return v23 + "); + } + + #[test] + fn test_specialize_basicobject_not_nil() { + eval(" + def test(a) = !a + + test(nil) + "); + assert_snapshot!(hir_string("test"), @r" + fn test@:2: + bb0(): + EntryPoint interpreter + v1:BasicObject = LoadSelf + v2:BasicObject = GetLocal l0, SP@4 + Jump bb2(v1, v2) + bb1(v5:BasicObject, v6:BasicObject): + EntryPoint JIT(0) + Jump bb2(v5, v6) + bb2(v8:BasicObject, v9:BasicObject): + PatchPoint MethodRedefined(NilClass@0x1000, !@0x1008, cme:0x1010) + v22:NilClass = GuardType v9, NilClass + v23:TrueClass = Const Value(true) + IncrCounter inline_cfunc_optimized_send_count + CheckInterrupts + Return v23 + "); + } + + #[test] + fn test_specialize_basicobject_not_falsy() { + eval(" + def test(a) = !(if a then false else nil end) + + # TODO(max): Make this not GuardType NilClass and instead just reason + # statically + test(false) + test(true) + "); + assert_snapshot!(hir_string("test"), @r" + fn test@:2: + bb0(): + EntryPoint interpreter + v1:BasicObject = LoadSelf + v2:BasicObject = GetLocal l0, SP@4 + Jump bb2(v1, v2) + bb1(v5:BasicObject, v6:BasicObject): + EntryPoint JIT(0) + Jump bb2(v5, v6) + bb2(v8:BasicObject, v9:BasicObject): + CheckInterrupts + v15:CBool = Test v9 + IfFalse v15, bb3(v8, v9) + v18:FalseClass = Const Value(false) + CheckInterrupts + Jump bb4(v8, v9, v18) + bb3(v22:BasicObject, v23:BasicObject): + v26:NilClass = Const Value(nil) + Jump bb4(v22, v23, v26) + bb4(v28:BasicObject, v29:BasicObject, v30:NilClass|FalseClass): + PatchPoint MethodRedefined(NilClass@0x1000, !@0x1008, cme:0x1010) + v41:NilClass = GuardType v30, NilClass + v42:TrueClass = Const Value(true) + IncrCounter inline_cfunc_optimized_send_count + CheckInterrupts + Return v42 "); }