Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Conversation

thurstond
Copy link
Contributor

@thurstond thurstond commented Sep 13, 2025

Adds a new handler, handleAVX512VectorGenericMaskedFP(), and applies it to AVX512/AVX10 rcp and rsqrt

Adds a new handler, handleAVX512VectorGenericMaskedFP(), and applies it
to rcp and rsqrt
@llvmbot
Copy link
Member

llvmbot commented Sep 13, 2025

@llvm/pr-subscribers-llvm-transforms

Author: Thurston Dang (thurstond)

Changes

Adds a new handler, handleAVX512VectorGenericMaskedFP(), and applies it to rcp and rsqrt


Patch is 63.57 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/158397.diff

4 Files Affected:

  • (modified) llvm/lib/Transforms/Instrumentation/MemorySanitizer.cpp (+165)
  • (modified) llvm/test/Instrumentation/MemorySanitizer/X86/avx512-intrinsics.ll (+12-26)
  • (modified) llvm/test/Instrumentation/MemorySanitizer/X86/avx512fp16-intrinsics.ll (+16-20)
  • (modified) llvm/test/Instrumentation/MemorySanitizer/X86/avx512vl-intrinsics.ll (+188-220)
diff --git a/llvm/lib/Transforms/Instrumentation/MemorySanitizer.cpp b/llvm/lib/Transforms/Instrumentation/MemorySanitizer.cpp
index 3ea790ad1839a..7933604b8ac25 100644
--- a/llvm/lib/Transforms/Instrumentation/MemorySanitizer.cpp
+++ b/llvm/lib/Transforms/Instrumentation/MemorySanitizer.cpp
@@ -4911,6 +4911,69 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
     setOriginForNaryOp(I);
   }
 
+  // Handle llvm.x86.avx512.* instructions that take a vector of floating-point
+  // values and perform an operation whose shadow propagation should be handled
+  // as all-or-nothing [*], with masking provided by a vector and a mask
+  // supplied as an integer.
+  //
+  // [*] if all bits of a vector element are initialized, the output is fully
+  //     initialized; otherwise, the output is fully uninitialized
+  //
+  // e.g., <16 x float> @llvm.x86.avx512.rsqrt14.ps.512
+  //                        (<16 x float>, <16 x float>, i16)
+  //                         A             WriteThru     Mask
+  //
+  //       <2 x double> @llvm.x86.avx512.rcp14.pd.128
+  //                        (<2 x double>, <2 x double>, i8)
+  //
+  // Dst[i]        = Mask[i] ? some_op(A[i]) : WriteThru[i]
+  // Dst_shadow[i] = Mask[i] ? all_or_nothing(A_shadow[i]) : WriteThru_shadow[i]
+  void handleAVX512VectorGenericMaskedFP(IntrinsicInst &I) {
+    IRBuilder<> IRB(&I);
+
+    assert(I.arg_size() == 3);
+    Value *A = I.getOperand(0);
+    Value *WriteThrough = I.getOperand(1);
+    Value *Mask = I.getOperand(2);
+
+    assert(isFixedFPVector(A));
+    assert(isFixedFPVector(WriteThrough));
+
+    [[maybe_unused]] unsigned ANumElements =
+        cast<FixedVectorType>(A->getType())->getNumElements();
+    unsigned OutputNumElements =
+        cast<FixedVectorType>(WriteThrough->getType())->getNumElements();
+    assert(ANumElements == OutputNumElements);
+
+    assert(Mask->getType()->isIntegerTy());
+    // Some bits of the mask might be unused, but check them all anyway
+    // (typically the mask is an integer constant).
+    insertCheckShadowOf(Mask, &I);
+
+    // The mask has 1 bit per element of A, but a minimum of 8 bits.
+    if (Mask->getType()->getScalarSizeInBits() == 8 && ANumElements < 8)
+      Mask = IRB.CreateTrunc(Mask, Type::getIntNTy(*MS.C, ANumElements));
+    assert(Mask->getType()->getScalarSizeInBits() == ANumElements);
+
+    assert(I.getType() == WriteThrough->getType());
+
+    Mask = IRB.CreateBitCast(
+        Mask, FixedVectorType::get(IRB.getInt1Ty(), OutputNumElements));
+
+    Value *AShadow = getShadow(A);
+
+    // All-or-nothing shadow
+    AShadow = IRB.CreateSExt(IRB.CreateICmpNE(AShadow, getCleanShadow(AShadow)),
+                             AShadow->getType());
+
+    Value *WriteThroughShadow = getShadow(WriteThrough);
+
+    Value *Shadow = IRB.CreateSelect(Mask, AShadow, WriteThroughShadow);
+    setShadow(&I, Shadow);
+
+    setOriginForNaryOp(I);
+  }
+
   // For sh.* compiler intrinsics:
   //   llvm.x86.avx512fp16.mask.{add/sub/mul/div/max/min}.sh.round
   //     (<8 x half>, <8 x half>, <8 x half>, i8,  i32)
@@ -6091,6 +6154,108 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
       break;
     }
 
+    // AVX512/AVX10 Reciprocal
+    //   <16 x float> @llvm.x86.avx512.rsqrt14.ps.512
+    //                    (<16 x float>, <16 x float>, i16)
+    //   <8 x float> @llvm.x86.avx512.rsqrt14.ps.256
+    //                    (<8 x float>, <8 x float>, i8)
+    //   <4 x float> @llvm.x86.avx512.rsqrt14.ps.128
+    //                    (<4 x float>, <4 x float>, i8)
+    //
+    //   <8 x double> @llvm.x86.avx512.rsqrt14.pd.512
+    //                    (<8 x double>, <8 x double>, i8)
+    //   <4 x double> @llvm.x86.avx512.rsqrt14.pd.256
+    //                    (<4 x double>, <4 x double>, i8)
+    //   <2 x double> @llvm.x86.avx512.rsqrt14.pd.128
+    //                    (<2 x double>, <2 x double>, i8)
+    //
+    //   <32 x bfloat> @llvm.x86.avx10.mask.rsqrt.bf16.512
+    //                    (<32 x bfloat>, <32 x bfloat>, i32)
+    //   <16 x bfloat> @llvm.x86.avx10.mask.rsqrt.bf16.256
+    //                    (<16 x bfloat>, <16 x bfloat>, i16)
+    //   <8 x bfloat> @llvm.x86.avx10.mask.rsqrt.bf16.128
+    //                    (<8 x bfloat>, <8 x bfloat>, i8)
+    //
+    //   <32 x half> @llvm.x86.avx512fp16.mask.rsqrt.ph.512
+    //                    (<32 x half>, <32 x half>, i32)
+    //   <16 x half> @llvm.x86.avx512fp16.mask.rsqrt.ph.256
+    //                    (<16 x half>, <16 x half>, i16)
+    //   <8 x half> @llvm.x86.avx512fp16.mask.rsqrt.ph.128
+    //                    (<8 x half>, <8 x half>, i8)
+    //
+    // TODO: 3-operand variants are not handled:
+    //   <2 x double> @llvm.x86.avx512.rsqrt14.sd
+    //                    (<2 x double>, <2 x double>, <2 x double>, i8)
+    //   <4 x float> @llvm.x86.avx512.rsqrt14.ss
+    //                    (<4 x float>, <4 x float>, <4 x float>, i8)
+    //   <8 x half> @llvm.x86.avx512fp16.mask.rsqrt.sh
+    //                    (<8 x half>, <8 x half>, <8 x half>, i8)
+    case Intrinsic::x86_avx512_rsqrt14_ps_512:
+    case Intrinsic::x86_avx512_rsqrt14_ps_256:
+    case Intrinsic::x86_avx512_rsqrt14_ps_128:
+    case Intrinsic::x86_avx512_rsqrt14_pd_512:
+    case Intrinsic::x86_avx512_rsqrt14_pd_256:
+    case Intrinsic::x86_avx512_rsqrt14_pd_128:
+    case Intrinsic::x86_avx10_mask_rsqrt_bf16_512:
+    case Intrinsic::x86_avx10_mask_rsqrt_bf16_256:
+    case Intrinsic::x86_avx10_mask_rsqrt_bf16_128:
+    case Intrinsic::x86_avx512fp16_mask_rsqrt_ph_512:
+    case Intrinsic::x86_avx512fp16_mask_rsqrt_ph_256:
+    case Intrinsic::x86_avx512fp16_mask_rsqrt_ph_128:
+      handleAVX512VectorGenericMaskedFP(I);
+      break;
+
+    // AVX512/AVX10 Reciprocal Square Root
+    //   <16 x float> @llvm.x86.avx512.rcp14.ps.512
+    //                    (<16 x float>, <16 x float>, i16)
+    //   <8 x float> @llvm.x86.avx512.rcp14.ps.256
+    //                    (<8 x float>, <8 x float>, i8)
+    //   <4 x float> @llvm.x86.avx512.rcp14.ps.128
+    //                    (<4 x float>, <4 x float>, i8)
+    //
+    //   <8 x double> @llvm.x86.avx512.rcp14.pd.512
+    //                    (<8 x double>, <8 x double>, i8)
+    //   <4 x double> @llvm.x86.avx512.rcp14.pd.256
+    //                    (<4 x double>, <4 x double>, i8)
+    //   <2 x double> @llvm.x86.avx512.rcp14.pd.128
+    //                    (<2 x double>, <2 x double>, i8)
+    //
+    //   <32 x bfloat> @llvm.x86.avx10.mask.rcp.bf16.512
+    //                    (<32 x bfloat>, <32 x bfloat>, i32)
+    //   <16 x bfloat> @llvm.x86.avx10.mask.rcp.bf16.256
+    //                    (<16 x bfloat>, <16 x bfloat>, i16)
+    //   <8 x bfloat> @llvm.x86.avx10.mask.rcp.bf16.128
+    //                    (<8 x bfloat>, <8 x bfloat>, i8)
+    //
+    //   <32 x half> @llvm.x86.avx512fp16.mask.rcp.ph.512
+    //                    (<32 x half>, <32 x half>, i32)
+    //   <16 x half> @llvm.x86.avx512fp16.mask.rcp.ph.256
+    //                    (<16 x half>, <16 x half>, i16)
+    //   <8 x half> @llvm.x86.avx512fp16.mask.rcp.ph.128
+    //                    (<8 x half>, <8 x half>, i8)
+    //
+    // TODO: 3-operand variants are not handled:
+    //   <2 x double> @llvm.x86.avx512.rcp14.sd
+    //                    (<2 x double>, <2 x double>, <2 x double>, i8)
+    //   <4 x float> @llvm.x86.avx512.rcp14.ss
+    //                    (<4 x float>, <4 x float>, <4 x float>, i8)
+    //   <8 x half> @llvm.x86.avx512fp16.mask.rcp.sh
+    //                    (<8 x half>, <8 x half>, <8 x half>, i8)
+    case Intrinsic::x86_avx512_rcp14_ps_512:
+    case Intrinsic::x86_avx512_rcp14_ps_256:
+    case Intrinsic::x86_avx512_rcp14_ps_128:
+    case Intrinsic::x86_avx512_rcp14_pd_512:
+    case Intrinsic::x86_avx512_rcp14_pd_256:
+    case Intrinsic::x86_avx512_rcp14_pd_128:
+    case Intrinsic::x86_avx10_mask_rcp_bf16_512:
+    case Intrinsic::x86_avx10_mask_rcp_bf16_256:
+    case Intrinsic::x86_avx10_mask_rcp_bf16_128:
+    case Intrinsic::x86_avx512fp16_mask_rcp_ph_512:
+    case Intrinsic::x86_avx512fp16_mask_rcp_ph_256:
+    case Intrinsic::x86_avx512fp16_mask_rcp_ph_128:
+      handleAVX512VectorGenericMaskedFP(I);
+      break;
+
     // AVX512 FP16 Arithmetic
     case Intrinsic::x86_avx512fp16_mask_add_sh_round:
     case Intrinsic::x86_avx512fp16_mask_sub_sh_round:
diff --git a/llvm/test/Instrumentation/MemorySanitizer/X86/avx512-intrinsics.ll b/llvm/test/Instrumentation/MemorySanitizer/X86/avx512-intrinsics.ll
index a2f1d65e7cd41..b2a4f0e582f9e 100644
--- a/llvm/test/Instrumentation/MemorySanitizer/X86/avx512-intrinsics.ll
+++ b/llvm/test/Instrumentation/MemorySanitizer/X86/avx512-intrinsics.ll
@@ -28,8 +28,6 @@
 ; - llvm.x86.avx512.mul.pd.512, llvm.x86.avx512.mul.ps.512
 ; - llvm.x86.avx512.permvar.df.512, llvm.x86.avx512.permvar.sf.512
 ; - llvm.x86.avx512.pternlog.d.512, llvm.x86.avx512.pternlog.q.512
-; - llvm.x86.avx512.rcp14.pd.512, llvm.x86.avx512.rcp14.ps.512
-; - llvm.x86.avx512.rsqrt14.ps.512
 ; - llvm.x86.avx512.sitofp.round.v16f32.v16i32
 ; - llvm.x86.avx512.sqrt.pd.512, llvm.x86.avx512.sqrt.ps.512
 ; - llvm.x86.avx512.sub.ps.512
@@ -682,15 +680,11 @@ define <16 x float> @test_rcp_ps_512(<16 x float> %a0) #0 {
 ; CHECK-LABEL: @test_rcp_ps_512(
 ; CHECK-NEXT:    [[TMP1:%.*]] = load <16 x i32>, ptr @__msan_param_tls, align 8
 ; CHECK-NEXT:    call void @llvm.donothing()
-; CHECK-NEXT:    [[TMP2:%.*]] = bitcast <16 x i32> [[TMP1]] to i512
-; CHECK-NEXT:    [[_MSCMP:%.*]] = icmp ne i512 [[TMP2]], 0
-; CHECK-NEXT:    br i1 [[_MSCMP]], label [[TMP3:%.*]], label [[TMP4:%.*]], !prof [[PROF1]]
-; CHECK:       3:
-; CHECK-NEXT:    call void @__msan_warning_noreturn() #[[ATTR10]]
-; CHECK-NEXT:    unreachable
-; CHECK:       4:
+; CHECK-NEXT:    [[TMP2:%.*]] = icmp ne <16 x i32> [[TMP1]], zeroinitializer
+; CHECK-NEXT:    [[TMP3:%.*]] = sext <16 x i1> [[TMP2]] to <16 x i32>
+; CHECK-NEXT:    [[TMP4:%.*]] = select <16 x i1> splat (i1 true), <16 x i32> [[TMP3]], <16 x i32> zeroinitializer
 ; CHECK-NEXT:    [[RES:%.*]] = call <16 x float> @llvm.x86.avx512.rcp14.ps.512(<16 x float> [[A0:%.*]], <16 x float> zeroinitializer, i16 -1)
-; CHECK-NEXT:    store <16 x i32> zeroinitializer, ptr @__msan_retval_tls, align 8
+; CHECK-NEXT:    store <16 x i32> [[TMP4]], ptr @__msan_retval_tls, align 8
 ; CHECK-NEXT:    ret <16 x float> [[RES]]
 ;
   %res = call <16 x float> @llvm.x86.avx512.rcp14.ps.512(<16 x float> %a0, <16 x float> zeroinitializer, i16 -1) ; <<16 x float>> [#uses=1]
@@ -702,15 +696,11 @@ define <8 x double> @test_rcp_pd_512(<8 x double> %a0) #0 {
 ; CHECK-LABEL: @test_rcp_pd_512(
 ; CHECK-NEXT:    [[TMP1:%.*]] = load <8 x i64>, ptr @__msan_param_tls, align 8
 ; CHECK-NEXT:    call void @llvm.donothing()
-; CHECK-NEXT:    [[TMP2:%.*]] = bitcast <8 x i64> [[TMP1]] to i512
-; CHECK-NEXT:    [[_MSCMP:%.*]] = icmp ne i512 [[TMP2]], 0
-; CHECK-NEXT:    br i1 [[_MSCMP]], label [[TMP3:%.*]], label [[TMP4:%.*]], !prof [[PROF1]]
-; CHECK:       3:
-; CHECK-NEXT:    call void @__msan_warning_noreturn() #[[ATTR10]]
-; CHECK-NEXT:    unreachable
-; CHECK:       4:
+; CHECK-NEXT:    [[TMP2:%.*]] = icmp ne <8 x i64> [[TMP1]], zeroinitializer
+; CHECK-NEXT:    [[TMP3:%.*]] = sext <8 x i1> [[TMP2]] to <8 x i64>
+; CHECK-NEXT:    [[TMP4:%.*]] = select <8 x i1> splat (i1 true), <8 x i64> [[TMP3]], <8 x i64> zeroinitializer
 ; CHECK-NEXT:    [[RES:%.*]] = call <8 x double> @llvm.x86.avx512.rcp14.pd.512(<8 x double> [[A0:%.*]], <8 x double> zeroinitializer, i8 -1)
-; CHECK-NEXT:    store <8 x i64> zeroinitializer, ptr @__msan_retval_tls, align 8
+; CHECK-NEXT:    store <8 x i64> [[TMP4]], ptr @__msan_retval_tls, align 8
 ; CHECK-NEXT:    ret <8 x double> [[RES]]
 ;
   %res = call <8 x double> @llvm.x86.avx512.rcp14.pd.512(<8 x double> %a0, <8 x double> zeroinitializer, i8 -1) ; <<8 x double>> [#uses=1]
@@ -1021,15 +1011,11 @@ define <16 x float> @test_rsqrt_ps_512(<16 x float> %a0) #0 {
 ; CHECK-LABEL: @test_rsqrt_ps_512(
 ; CHECK-NEXT:    [[TMP1:%.*]] = load <16 x i32>, ptr @__msan_param_tls, align 8
 ; CHECK-NEXT:    call void @llvm.donothing()
-; CHECK-NEXT:    [[TMP2:%.*]] = bitcast <16 x i32> [[TMP1]] to i512
-; CHECK-NEXT:    [[_MSCMP:%.*]] = icmp ne i512 [[TMP2]], 0
-; CHECK-NEXT:    br i1 [[_MSCMP]], label [[TMP3:%.*]], label [[TMP4:%.*]], !prof [[PROF1]]
-; CHECK:       3:
-; CHECK-NEXT:    call void @__msan_warning_noreturn() #[[ATTR10]]
-; CHECK-NEXT:    unreachable
-; CHECK:       4:
+; CHECK-NEXT:    [[TMP2:%.*]] = icmp ne <16 x i32> [[TMP1]], zeroinitializer
+; CHECK-NEXT:    [[TMP3:%.*]] = sext <16 x i1> [[TMP2]] to <16 x i32>
+; CHECK-NEXT:    [[TMP4:%.*]] = select <16 x i1> splat (i1 true), <16 x i32> [[TMP3]], <16 x i32> zeroinitializer
 ; CHECK-NEXT:    [[RES:%.*]] = call <16 x float> @llvm.x86.avx512.rsqrt14.ps.512(<16 x float> [[A0:%.*]], <16 x float> zeroinitializer, i16 -1)
-; CHECK-NEXT:    store <16 x i32> zeroinitializer, ptr @__msan_retval_tls, align 8
+; CHECK-NEXT:    store <16 x i32> [[TMP4]], ptr @__msan_retval_tls, align 8
 ; CHECK-NEXT:    ret <16 x float> [[RES]]
 ;
   %res = call <16 x float> @llvm.x86.avx512.rsqrt14.ps.512(<16 x float> %a0, <16 x float> zeroinitializer, i16 -1) ; <<16 x float>> [#uses=1]
diff --git a/llvm/test/Instrumentation/MemorySanitizer/X86/avx512fp16-intrinsics.ll b/llvm/test/Instrumentation/MemorySanitizer/X86/avx512fp16-intrinsics.ll
index c5d91adf64cb3..e5cbe8c132238 100644
--- a/llvm/test/Instrumentation/MemorySanitizer/X86/avx512fp16-intrinsics.ll
+++ b/llvm/test/Instrumentation/MemorySanitizer/X86/avx512fp16-intrinsics.ll
@@ -19,7 +19,6 @@
 ; - llvm.x86.avx512fp16.mask.reduce.sh
 ; - llvm.x86.avx512fp16.mask.rndscale.ph.512
 ; - llvm.x86.avx512fp16.mask.rndscale.sh
-; - llvm.x86.avx512fp16.mask.rsqrt.ph.512
 ; - llvm.x86.avx512fp16.mask.rsqrt.sh
 ; - llvm.x86.avx512fp16.mask.scalef.ph.512
 ; - llvm.x86.avx512fp16.mask.scalef.sh
@@ -442,15 +441,11 @@ define <32 x half> @test_rsqrt_ph_512(<32 x half> %a0) #0 {
 ; CHECK-SAME: <32 x half> [[A0:%.*]]) #[[ATTR1]] {
 ; CHECK-NEXT:    [[TMP1:%.*]] = load <32 x i16>, ptr @__msan_param_tls, align 8
 ; CHECK-NEXT:    call void @llvm.donothing()
-; CHECK-NEXT:    [[TMP2:%.*]] = bitcast <32 x i16> [[TMP1]] to i512
-; CHECK-NEXT:    [[_MSCMP:%.*]] = icmp ne i512 [[TMP2]], 0
-; CHECK-NEXT:    br i1 [[_MSCMP]], label %[[BB3:.*]], label %[[BB4:.*]], !prof [[PROF1]]
-; CHECK:       [[BB3]]:
-; CHECK-NEXT:    call void @__msan_warning_noreturn() #[[ATTR8]]
-; CHECK-NEXT:    unreachable
-; CHECK:       [[BB4]]:
+; CHECK-NEXT:    [[TMP2:%.*]] = icmp ne <32 x i16> [[TMP1]], zeroinitializer
+; CHECK-NEXT:    [[TMP3:%.*]] = sext <32 x i1> [[TMP2]] to <32 x i16>
+; CHECK-NEXT:    [[TMP4:%.*]] = select <32 x i1> splat (i1 true), <32 x i16> [[TMP3]], <32 x i16> zeroinitializer
 ; CHECK-NEXT:    [[RES:%.*]] = call <32 x half> @llvm.x86.avx512fp16.mask.rsqrt.ph.512(<32 x half> [[A0]], <32 x half> zeroinitializer, i32 -1)
-; CHECK-NEXT:    store <32 x i16> zeroinitializer, ptr @__msan_retval_tls, align 8
+; CHECK-NEXT:    store <32 x i16> [[TMP4]], ptr @__msan_retval_tls, align 8
 ; CHECK-NEXT:    ret <32 x half> [[RES]]
 ;
   %res = call <32 x half> @llvm.x86.avx512fp16.mask.rsqrt.ph.512(<32 x half> %a0, <32 x half> zeroinitializer, i32 -1)
@@ -681,24 +676,22 @@ declare <32 x half> @llvm.x86.avx512fp16.mask.rcp.ph.512(<32 x half>, <32 x half
 define <32 x half> @test_rcp_ph_512(<32 x half> %a0, <32 x half> %a1, i32 %mask) #0 {
 ; CHECK-LABEL: define <32 x half> @test_rcp_ph_512(
 ; CHECK-SAME: <32 x half> [[A0:%.*]], <32 x half> [[A1:%.*]], i32 [[MASK:%.*]]) #[[ATTR1]] {
+; CHECK-NEXT:    [[TMP3:%.*]] = load i32, ptr inttoptr (i64 add (i64 ptrtoint (ptr @__msan_param_tls to i64), i64 128) to ptr), align 8
 ; CHECK-NEXT:    [[TMP1:%.*]] = load <32 x i16>, ptr @__msan_param_tls, align 8
 ; CHECK-NEXT:    [[TMP2:%.*]] = load <32 x i16>, ptr inttoptr (i64 add (i64 ptrtoint (ptr @__msan_param_tls to i64), i64 64) to ptr), align 8
-; CHECK-NEXT:    [[TMP3:%.*]] = load i32, ptr inttoptr (i64 add (i64 ptrtoint (ptr @__msan_param_tls to i64), i64 128) to ptr), align 8
 ; CHECK-NEXT:    call void @llvm.donothing()
-; CHECK-NEXT:    [[TMP4:%.*]] = bitcast <32 x i16> [[TMP1]] to i512
-; CHECK-NEXT:    [[_MSCMP:%.*]] = icmp ne i512 [[TMP4]], 0
-; CHECK-NEXT:    [[TMP5:%.*]] = bitcast <32 x i16> [[TMP2]] to i512
-; CHECK-NEXT:    [[_MSCMP1:%.*]] = icmp ne i512 [[TMP5]], 0
-; CHECK-NEXT:    [[_MSOR:%.*]] = or i1 [[_MSCMP]], [[_MSCMP1]]
+; CHECK-NEXT:    [[TMP4:%.*]] = bitcast i32 [[MASK]] to <32 x i1>
+; CHECK-NEXT:    [[TMP5:%.*]] = icmp ne <32 x i16> [[TMP1]], zeroinitializer
+; CHECK-NEXT:    [[TMP6:%.*]] = sext <32 x i1> [[TMP5]] to <32 x i16>
+; CHECK-NEXT:    [[TMP7:%.*]] = select <32 x i1> [[TMP4]], <32 x i16> [[TMP6]], <32 x i16> [[TMP2]]
 ; CHECK-NEXT:    [[_MSCMP2:%.*]] = icmp ne i32 [[TMP3]], 0
-; CHECK-NEXT:    [[_MSOR3:%.*]] = or i1 [[_MSOR]], [[_MSCMP2]]
-; CHECK-NEXT:    br i1 [[_MSOR3]], label %[[BB6:.*]], label %[[BB7:.*]], !prof [[PROF1]]
-; CHECK:       [[BB6]]:
+; CHECK-NEXT:    br i1 [[_MSCMP2]], label %[[BB8:.*]], label %[[BB9:.*]], !prof [[PROF1]]
+; CHECK:       [[BB8]]:
 ; CHECK-NEXT:    call void @__msan_warning_noreturn() #[[ATTR8]]
 ; CHECK-NEXT:    unreachable
-; CHECK:       [[BB7]]:
+; CHECK:       [[BB9]]:
 ; CHECK-NEXT:    [[RES:%.*]] = call <32 x half> @llvm.x86.avx512fp16.mask.rcp.ph.512(<32 x half> [[A0]], <32 x half> [[A1]], i32 [[MASK]])
-; CHECK-NEXT:    store <32 x i16> zeroinitializer, ptr @__msan_retval_tls, align 8
+; CHECK-NEXT:    store <32 x i16> [[TMP7]], ptr @__msan_retval_tls, align 8
 ; CHECK-NEXT:    ret <32 x half> [[RES]]
 ;
   %res = call <32 x half> @llvm.x86.avx512fp16.mask.rcp.ph.512(<32 x half> %a0, <32 x half> %a1, i32 %mask)
@@ -3260,3 +3253,6 @@ define <32 x half> @test_mm512_castph256_ph512_freeze(<16 x half> %a0) nounwind
 }
 
 attributes #0 = { sanitize_memory }
+;.
+; CHECK: [[PROF1]] = !{!"branch_weights", i32 1, i32 1048575}
+;.
diff --git a/llvm/test/Instrumentation/MemorySanitizer/X86/avx512vl-intrinsics.ll b/llvm/test/Instrumentation/MemorySanitizer/X86/avx512vl-intrinsics.ll
index e2dc8cbdca968..20114fe7d3151 100644
--- a/llvm/test/Instrumentation/MemorySanitizer/X86/avx512vl-intrinsics.ll
+++ b/llvm/test/Instrumentation/MemorySanitizer/X86/avx512vl-intrinsics.ll
@@ -63,10 +63,6 @@
 ; - llvm.x86.avx512.permvar.df.256
 ; - llvm.x86.avx512.pternlog.d.128, llvm.x86.avx512.pternlog.d.256
 ; - llvm.x86.avx512.pternlog.q.128, llvm.x86.avx512.pternlog.q.256
-; - llvm.x86.avx512.rcp14.pd.128, llvm.x86.avx512.rcp14.pd.256
-; - llvm.x86.avx512.rcp14.ps.128, llvm.x86.avx512.rcp14.ps.256
-; - llvm.x86.avx512.rsqrt14.pd.128, llvm.x86.avx512.rsqrt14.pd.256
-; - llvm.x86.avx512.rsqrt14.ps.128, llvm.x86.avx512.rsqrt14.ps.256
 ;
 ; Handled heuristically: (none)
 
@@ -8066,15 +8062,11 @@ define <8 x float> @test_rsqrt_ps_256_rr(<8 x float> %a0) #0 {
 ; CHECK-SAME: <8 x float> [[A0:%.*]]) #[[ATTR0]] {
 ; CHECK-NEXT:    [[TMP1:%.*]] = load <8 x i32>, ptr @__msan_param_tls, align 8
 ; CHECK-NEXT:    call void @llvm.donothing()
-; CHECK-NEXT:    [[TMP2:%.*]] = bitcast <8 x i32> [[TMP1]] to i256
-; CHECK-NEXT:    [[_MSCMP:%.*]] = icmp ne i256 [[TMP2]], 0
-; CHECK-NEXT:    br i1 [[_MSCMP]], label %[[BB3:.*]], label %[[BB4:.*]], !prof [[PROF1]]
-; CHECK:       [[BB3]]:
-; CHECK-NEXT:    call void @__msan_warning_noreturn() #[[ATTR6]]
-; CHECK-NEXT:    unreachable
-; CHECK:       [[BB4]]:
+; CHECK-NEXT:    [[TMP2:%.*]] = icmp ne <8 x i32> [[TMP1]], zeroinitializer
+; CHECK-NEXT:    [[TMP3:%.*]] = sext <8 x i1> [[TMP2]] to <8 x i32>
+; CHECK-NEXT:    [[TMP4:%.*]] = select <8 x i1> splat (i1 true), <8 x i32> [[TMP3]], <8 x i32> zeroinitializer
 ; CHECK-NEXT:    [[RES:%.*]] = call <8 x float> @llvm.x86.avx512.rsqrt14.ps.256(<8 x float> [[A0]], <8 x float> zeroinitializer, i8 -1)
-; CHECK-NEXT:    store <8 x i32> zeroinitializer, ptr @__msan_retval_tls, align 8
+; CHECK-NEXT:    store <8 x i32> [[TMP4]], ptr @__msan_retval_tls, align 8
 ; CHECK-NEXT:    ret <8 x float> [[RES]]
 ;
   %res = call <8 x float> @llvm.x86.avx512.rsqrt14.ps.256(<8 x float> %a0, <8 x float> zeroinitializer, i8 -1)
@@ -8085,20 +8077,21 @@ define <8 x float> @test_rsqrt_ps_256_rrkz(<8 x float> %a0, i8 %mask) #0 {
 ;
 ; CHECK-LABEL: define <8 x float> @test_rsqrt_ps_256_rrkz(
 ; CHECK-SAME: <8 x float> [[A0:%.*]], i8 [[MASK:%.*]]) #[[ATTR0]] {
-; CHECK-NEXT:    [[TMP1:%.*]] = load <8...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented Sep 13, 2025

@llvm/pr-subscribers-compiler-rt-sanitizer

Author: Thurston Dang (thurstond)

Changes

Adds a new handler, handleAVX512VectorGenericMaskedFP(), and applies it to rcp and rsqrt


Patch is 63.57 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/158397.diff

4 Files Affected:

  • (modified) llvm/lib/Transforms/Instrumentation/MemorySanitizer.cpp (+165)
  • (modified) llvm/test/Instrumentation/MemorySanitizer/X86/avx512-intrinsics.ll (+12-26)
  • (modified) llvm/test/Instrumentation/MemorySanitizer/X86/avx512fp16-intrinsics.ll (+16-20)
  • (modified) llvm/test/Instrumentation/MemorySanitizer/X86/avx512vl-intrinsics.ll (+188-220)
diff --git a/llvm/lib/Transforms/Instrumentation/MemorySanitizer.cpp b/llvm/lib/Transforms/Instrumentation/MemorySanitizer.cpp
index 3ea790ad1839a..7933604b8ac25 100644
--- a/llvm/lib/Transforms/Instrumentation/MemorySanitizer.cpp
+++ b/llvm/lib/Transforms/Instrumentation/MemorySanitizer.cpp
@@ -4911,6 +4911,69 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
     setOriginForNaryOp(I);
   }
 
+  // Handle llvm.x86.avx512.* instructions that take a vector of floating-point
+  // values and perform an operation whose shadow propagation should be handled
+  // as all-or-nothing [*], with masking provided by a vector and a mask
+  // supplied as an integer.
+  //
+  // [*] if all bits of a vector element are initialized, the output is fully
+  //     initialized; otherwise, the output is fully uninitialized
+  //
+  // e.g., <16 x float> @llvm.x86.avx512.rsqrt14.ps.512
+  //                        (<16 x float>, <16 x float>, i16)
+  //                         A             WriteThru     Mask
+  //
+  //       <2 x double> @llvm.x86.avx512.rcp14.pd.128
+  //                        (<2 x double>, <2 x double>, i8)
+  //
+  // Dst[i]        = Mask[i] ? some_op(A[i]) : WriteThru[i]
+  // Dst_shadow[i] = Mask[i] ? all_or_nothing(A_shadow[i]) : WriteThru_shadow[i]
+  void handleAVX512VectorGenericMaskedFP(IntrinsicInst &I) {
+    IRBuilder<> IRB(&I);
+
+    assert(I.arg_size() == 3);
+    Value *A = I.getOperand(0);
+    Value *WriteThrough = I.getOperand(1);
+    Value *Mask = I.getOperand(2);
+
+    assert(isFixedFPVector(A));
+    assert(isFixedFPVector(WriteThrough));
+
+    [[maybe_unused]] unsigned ANumElements =
+        cast<FixedVectorType>(A->getType())->getNumElements();
+    unsigned OutputNumElements =
+        cast<FixedVectorType>(WriteThrough->getType())->getNumElements();
+    assert(ANumElements == OutputNumElements);
+
+    assert(Mask->getType()->isIntegerTy());
+    // Some bits of the mask might be unused, but check them all anyway
+    // (typically the mask is an integer constant).
+    insertCheckShadowOf(Mask, &I);
+
+    // The mask has 1 bit per element of A, but a minimum of 8 bits.
+    if (Mask->getType()->getScalarSizeInBits() == 8 && ANumElements < 8)
+      Mask = IRB.CreateTrunc(Mask, Type::getIntNTy(*MS.C, ANumElements));
+    assert(Mask->getType()->getScalarSizeInBits() == ANumElements);
+
+    assert(I.getType() == WriteThrough->getType());
+
+    Mask = IRB.CreateBitCast(
+        Mask, FixedVectorType::get(IRB.getInt1Ty(), OutputNumElements));
+
+    Value *AShadow = getShadow(A);
+
+    // All-or-nothing shadow
+    AShadow = IRB.CreateSExt(IRB.CreateICmpNE(AShadow, getCleanShadow(AShadow)),
+                             AShadow->getType());
+
+    Value *WriteThroughShadow = getShadow(WriteThrough);
+
+    Value *Shadow = IRB.CreateSelect(Mask, AShadow, WriteThroughShadow);
+    setShadow(&I, Shadow);
+
+    setOriginForNaryOp(I);
+  }
+
   // For sh.* compiler intrinsics:
   //   llvm.x86.avx512fp16.mask.{add/sub/mul/div/max/min}.sh.round
   //     (<8 x half>, <8 x half>, <8 x half>, i8,  i32)
@@ -6091,6 +6154,108 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
       break;
     }
 
+    // AVX512/AVX10 Reciprocal
+    //   <16 x float> @llvm.x86.avx512.rsqrt14.ps.512
+    //                    (<16 x float>, <16 x float>, i16)
+    //   <8 x float> @llvm.x86.avx512.rsqrt14.ps.256
+    //                    (<8 x float>, <8 x float>, i8)
+    //   <4 x float> @llvm.x86.avx512.rsqrt14.ps.128
+    //                    (<4 x float>, <4 x float>, i8)
+    //
+    //   <8 x double> @llvm.x86.avx512.rsqrt14.pd.512
+    //                    (<8 x double>, <8 x double>, i8)
+    //   <4 x double> @llvm.x86.avx512.rsqrt14.pd.256
+    //                    (<4 x double>, <4 x double>, i8)
+    //   <2 x double> @llvm.x86.avx512.rsqrt14.pd.128
+    //                    (<2 x double>, <2 x double>, i8)
+    //
+    //   <32 x bfloat> @llvm.x86.avx10.mask.rsqrt.bf16.512
+    //                    (<32 x bfloat>, <32 x bfloat>, i32)
+    //   <16 x bfloat> @llvm.x86.avx10.mask.rsqrt.bf16.256
+    //                    (<16 x bfloat>, <16 x bfloat>, i16)
+    //   <8 x bfloat> @llvm.x86.avx10.mask.rsqrt.bf16.128
+    //                    (<8 x bfloat>, <8 x bfloat>, i8)
+    //
+    //   <32 x half> @llvm.x86.avx512fp16.mask.rsqrt.ph.512
+    //                    (<32 x half>, <32 x half>, i32)
+    //   <16 x half> @llvm.x86.avx512fp16.mask.rsqrt.ph.256
+    //                    (<16 x half>, <16 x half>, i16)
+    //   <8 x half> @llvm.x86.avx512fp16.mask.rsqrt.ph.128
+    //                    (<8 x half>, <8 x half>, i8)
+    //
+    // TODO: 3-operand variants are not handled:
+    //   <2 x double> @llvm.x86.avx512.rsqrt14.sd
+    //                    (<2 x double>, <2 x double>, <2 x double>, i8)
+    //   <4 x float> @llvm.x86.avx512.rsqrt14.ss
+    //                    (<4 x float>, <4 x float>, <4 x float>, i8)
+    //   <8 x half> @llvm.x86.avx512fp16.mask.rsqrt.sh
+    //                    (<8 x half>, <8 x half>, <8 x half>, i8)
+    case Intrinsic::x86_avx512_rsqrt14_ps_512:
+    case Intrinsic::x86_avx512_rsqrt14_ps_256:
+    case Intrinsic::x86_avx512_rsqrt14_ps_128:
+    case Intrinsic::x86_avx512_rsqrt14_pd_512:
+    case Intrinsic::x86_avx512_rsqrt14_pd_256:
+    case Intrinsic::x86_avx512_rsqrt14_pd_128:
+    case Intrinsic::x86_avx10_mask_rsqrt_bf16_512:
+    case Intrinsic::x86_avx10_mask_rsqrt_bf16_256:
+    case Intrinsic::x86_avx10_mask_rsqrt_bf16_128:
+    case Intrinsic::x86_avx512fp16_mask_rsqrt_ph_512:
+    case Intrinsic::x86_avx512fp16_mask_rsqrt_ph_256:
+    case Intrinsic::x86_avx512fp16_mask_rsqrt_ph_128:
+      handleAVX512VectorGenericMaskedFP(I);
+      break;
+
+    // AVX512/AVX10 Reciprocal Square Root
+    //   <16 x float> @llvm.x86.avx512.rcp14.ps.512
+    //                    (<16 x float>, <16 x float>, i16)
+    //   <8 x float> @llvm.x86.avx512.rcp14.ps.256
+    //                    (<8 x float>, <8 x float>, i8)
+    //   <4 x float> @llvm.x86.avx512.rcp14.ps.128
+    //                    (<4 x float>, <4 x float>, i8)
+    //
+    //   <8 x double> @llvm.x86.avx512.rcp14.pd.512
+    //                    (<8 x double>, <8 x double>, i8)
+    //   <4 x double> @llvm.x86.avx512.rcp14.pd.256
+    //                    (<4 x double>, <4 x double>, i8)
+    //   <2 x double> @llvm.x86.avx512.rcp14.pd.128
+    //                    (<2 x double>, <2 x double>, i8)
+    //
+    //   <32 x bfloat> @llvm.x86.avx10.mask.rcp.bf16.512
+    //                    (<32 x bfloat>, <32 x bfloat>, i32)
+    //   <16 x bfloat> @llvm.x86.avx10.mask.rcp.bf16.256
+    //                    (<16 x bfloat>, <16 x bfloat>, i16)
+    //   <8 x bfloat> @llvm.x86.avx10.mask.rcp.bf16.128
+    //                    (<8 x bfloat>, <8 x bfloat>, i8)
+    //
+    //   <32 x half> @llvm.x86.avx512fp16.mask.rcp.ph.512
+    //                    (<32 x half>, <32 x half>, i32)
+    //   <16 x half> @llvm.x86.avx512fp16.mask.rcp.ph.256
+    //                    (<16 x half>, <16 x half>, i16)
+    //   <8 x half> @llvm.x86.avx512fp16.mask.rcp.ph.128
+    //                    (<8 x half>, <8 x half>, i8)
+    //
+    // TODO: 3-operand variants are not handled:
+    //   <2 x double> @llvm.x86.avx512.rcp14.sd
+    //                    (<2 x double>, <2 x double>, <2 x double>, i8)
+    //   <4 x float> @llvm.x86.avx512.rcp14.ss
+    //                    (<4 x float>, <4 x float>, <4 x float>, i8)
+    //   <8 x half> @llvm.x86.avx512fp16.mask.rcp.sh
+    //                    (<8 x half>, <8 x half>, <8 x half>, i8)
+    case Intrinsic::x86_avx512_rcp14_ps_512:
+    case Intrinsic::x86_avx512_rcp14_ps_256:
+    case Intrinsic::x86_avx512_rcp14_ps_128:
+    case Intrinsic::x86_avx512_rcp14_pd_512:
+    case Intrinsic::x86_avx512_rcp14_pd_256:
+    case Intrinsic::x86_avx512_rcp14_pd_128:
+    case Intrinsic::x86_avx10_mask_rcp_bf16_512:
+    case Intrinsic::x86_avx10_mask_rcp_bf16_256:
+    case Intrinsic::x86_avx10_mask_rcp_bf16_128:
+    case Intrinsic::x86_avx512fp16_mask_rcp_ph_512:
+    case Intrinsic::x86_avx512fp16_mask_rcp_ph_256:
+    case Intrinsic::x86_avx512fp16_mask_rcp_ph_128:
+      handleAVX512VectorGenericMaskedFP(I);
+      break;
+
     // AVX512 FP16 Arithmetic
     case Intrinsic::x86_avx512fp16_mask_add_sh_round:
     case Intrinsic::x86_avx512fp16_mask_sub_sh_round:
diff --git a/llvm/test/Instrumentation/MemorySanitizer/X86/avx512-intrinsics.ll b/llvm/test/Instrumentation/MemorySanitizer/X86/avx512-intrinsics.ll
index a2f1d65e7cd41..b2a4f0e582f9e 100644
--- a/llvm/test/Instrumentation/MemorySanitizer/X86/avx512-intrinsics.ll
+++ b/llvm/test/Instrumentation/MemorySanitizer/X86/avx512-intrinsics.ll
@@ -28,8 +28,6 @@
 ; - llvm.x86.avx512.mul.pd.512, llvm.x86.avx512.mul.ps.512
 ; - llvm.x86.avx512.permvar.df.512, llvm.x86.avx512.permvar.sf.512
 ; - llvm.x86.avx512.pternlog.d.512, llvm.x86.avx512.pternlog.q.512
-; - llvm.x86.avx512.rcp14.pd.512, llvm.x86.avx512.rcp14.ps.512
-; - llvm.x86.avx512.rsqrt14.ps.512
 ; - llvm.x86.avx512.sitofp.round.v16f32.v16i32
 ; - llvm.x86.avx512.sqrt.pd.512, llvm.x86.avx512.sqrt.ps.512
 ; - llvm.x86.avx512.sub.ps.512
@@ -682,15 +680,11 @@ define <16 x float> @test_rcp_ps_512(<16 x float> %a0) #0 {
 ; CHECK-LABEL: @test_rcp_ps_512(
 ; CHECK-NEXT:    [[TMP1:%.*]] = load <16 x i32>, ptr @__msan_param_tls, align 8
 ; CHECK-NEXT:    call void @llvm.donothing()
-; CHECK-NEXT:    [[TMP2:%.*]] = bitcast <16 x i32> [[TMP1]] to i512
-; CHECK-NEXT:    [[_MSCMP:%.*]] = icmp ne i512 [[TMP2]], 0
-; CHECK-NEXT:    br i1 [[_MSCMP]], label [[TMP3:%.*]], label [[TMP4:%.*]], !prof [[PROF1]]
-; CHECK:       3:
-; CHECK-NEXT:    call void @__msan_warning_noreturn() #[[ATTR10]]
-; CHECK-NEXT:    unreachable
-; CHECK:       4:
+; CHECK-NEXT:    [[TMP2:%.*]] = icmp ne <16 x i32> [[TMP1]], zeroinitializer
+; CHECK-NEXT:    [[TMP3:%.*]] = sext <16 x i1> [[TMP2]] to <16 x i32>
+; CHECK-NEXT:    [[TMP4:%.*]] = select <16 x i1> splat (i1 true), <16 x i32> [[TMP3]], <16 x i32> zeroinitializer
 ; CHECK-NEXT:    [[RES:%.*]] = call <16 x float> @llvm.x86.avx512.rcp14.ps.512(<16 x float> [[A0:%.*]], <16 x float> zeroinitializer, i16 -1)
-; CHECK-NEXT:    store <16 x i32> zeroinitializer, ptr @__msan_retval_tls, align 8
+; CHECK-NEXT:    store <16 x i32> [[TMP4]], ptr @__msan_retval_tls, align 8
 ; CHECK-NEXT:    ret <16 x float> [[RES]]
 ;
   %res = call <16 x float> @llvm.x86.avx512.rcp14.ps.512(<16 x float> %a0, <16 x float> zeroinitializer, i16 -1) ; <<16 x float>> [#uses=1]
@@ -702,15 +696,11 @@ define <8 x double> @test_rcp_pd_512(<8 x double> %a0) #0 {
 ; CHECK-LABEL: @test_rcp_pd_512(
 ; CHECK-NEXT:    [[TMP1:%.*]] = load <8 x i64>, ptr @__msan_param_tls, align 8
 ; CHECK-NEXT:    call void @llvm.donothing()
-; CHECK-NEXT:    [[TMP2:%.*]] = bitcast <8 x i64> [[TMP1]] to i512
-; CHECK-NEXT:    [[_MSCMP:%.*]] = icmp ne i512 [[TMP2]], 0
-; CHECK-NEXT:    br i1 [[_MSCMP]], label [[TMP3:%.*]], label [[TMP4:%.*]], !prof [[PROF1]]
-; CHECK:       3:
-; CHECK-NEXT:    call void @__msan_warning_noreturn() #[[ATTR10]]
-; CHECK-NEXT:    unreachable
-; CHECK:       4:
+; CHECK-NEXT:    [[TMP2:%.*]] = icmp ne <8 x i64> [[TMP1]], zeroinitializer
+; CHECK-NEXT:    [[TMP3:%.*]] = sext <8 x i1> [[TMP2]] to <8 x i64>
+; CHECK-NEXT:    [[TMP4:%.*]] = select <8 x i1> splat (i1 true), <8 x i64> [[TMP3]], <8 x i64> zeroinitializer
 ; CHECK-NEXT:    [[RES:%.*]] = call <8 x double> @llvm.x86.avx512.rcp14.pd.512(<8 x double> [[A0:%.*]], <8 x double> zeroinitializer, i8 -1)
-; CHECK-NEXT:    store <8 x i64> zeroinitializer, ptr @__msan_retval_tls, align 8
+; CHECK-NEXT:    store <8 x i64> [[TMP4]], ptr @__msan_retval_tls, align 8
 ; CHECK-NEXT:    ret <8 x double> [[RES]]
 ;
   %res = call <8 x double> @llvm.x86.avx512.rcp14.pd.512(<8 x double> %a0, <8 x double> zeroinitializer, i8 -1) ; <<8 x double>> [#uses=1]
@@ -1021,15 +1011,11 @@ define <16 x float> @test_rsqrt_ps_512(<16 x float> %a0) #0 {
 ; CHECK-LABEL: @test_rsqrt_ps_512(
 ; CHECK-NEXT:    [[TMP1:%.*]] = load <16 x i32>, ptr @__msan_param_tls, align 8
 ; CHECK-NEXT:    call void @llvm.donothing()
-; CHECK-NEXT:    [[TMP2:%.*]] = bitcast <16 x i32> [[TMP1]] to i512
-; CHECK-NEXT:    [[_MSCMP:%.*]] = icmp ne i512 [[TMP2]], 0
-; CHECK-NEXT:    br i1 [[_MSCMP]], label [[TMP3:%.*]], label [[TMP4:%.*]], !prof [[PROF1]]
-; CHECK:       3:
-; CHECK-NEXT:    call void @__msan_warning_noreturn() #[[ATTR10]]
-; CHECK-NEXT:    unreachable
-; CHECK:       4:
+; CHECK-NEXT:    [[TMP2:%.*]] = icmp ne <16 x i32> [[TMP1]], zeroinitializer
+; CHECK-NEXT:    [[TMP3:%.*]] = sext <16 x i1> [[TMP2]] to <16 x i32>
+; CHECK-NEXT:    [[TMP4:%.*]] = select <16 x i1> splat (i1 true), <16 x i32> [[TMP3]], <16 x i32> zeroinitializer
 ; CHECK-NEXT:    [[RES:%.*]] = call <16 x float> @llvm.x86.avx512.rsqrt14.ps.512(<16 x float> [[A0:%.*]], <16 x float> zeroinitializer, i16 -1)
-; CHECK-NEXT:    store <16 x i32> zeroinitializer, ptr @__msan_retval_tls, align 8
+; CHECK-NEXT:    store <16 x i32> [[TMP4]], ptr @__msan_retval_tls, align 8
 ; CHECK-NEXT:    ret <16 x float> [[RES]]
 ;
   %res = call <16 x float> @llvm.x86.avx512.rsqrt14.ps.512(<16 x float> %a0, <16 x float> zeroinitializer, i16 -1) ; <<16 x float>> [#uses=1]
diff --git a/llvm/test/Instrumentation/MemorySanitizer/X86/avx512fp16-intrinsics.ll b/llvm/test/Instrumentation/MemorySanitizer/X86/avx512fp16-intrinsics.ll
index c5d91adf64cb3..e5cbe8c132238 100644
--- a/llvm/test/Instrumentation/MemorySanitizer/X86/avx512fp16-intrinsics.ll
+++ b/llvm/test/Instrumentation/MemorySanitizer/X86/avx512fp16-intrinsics.ll
@@ -19,7 +19,6 @@
 ; - llvm.x86.avx512fp16.mask.reduce.sh
 ; - llvm.x86.avx512fp16.mask.rndscale.ph.512
 ; - llvm.x86.avx512fp16.mask.rndscale.sh
-; - llvm.x86.avx512fp16.mask.rsqrt.ph.512
 ; - llvm.x86.avx512fp16.mask.rsqrt.sh
 ; - llvm.x86.avx512fp16.mask.scalef.ph.512
 ; - llvm.x86.avx512fp16.mask.scalef.sh
@@ -442,15 +441,11 @@ define <32 x half> @test_rsqrt_ph_512(<32 x half> %a0) #0 {
 ; CHECK-SAME: <32 x half> [[A0:%.*]]) #[[ATTR1]] {
 ; CHECK-NEXT:    [[TMP1:%.*]] = load <32 x i16>, ptr @__msan_param_tls, align 8
 ; CHECK-NEXT:    call void @llvm.donothing()
-; CHECK-NEXT:    [[TMP2:%.*]] = bitcast <32 x i16> [[TMP1]] to i512
-; CHECK-NEXT:    [[_MSCMP:%.*]] = icmp ne i512 [[TMP2]], 0
-; CHECK-NEXT:    br i1 [[_MSCMP]], label %[[BB3:.*]], label %[[BB4:.*]], !prof [[PROF1]]
-; CHECK:       [[BB3]]:
-; CHECK-NEXT:    call void @__msan_warning_noreturn() #[[ATTR8]]
-; CHECK-NEXT:    unreachable
-; CHECK:       [[BB4]]:
+; CHECK-NEXT:    [[TMP2:%.*]] = icmp ne <32 x i16> [[TMP1]], zeroinitializer
+; CHECK-NEXT:    [[TMP3:%.*]] = sext <32 x i1> [[TMP2]] to <32 x i16>
+; CHECK-NEXT:    [[TMP4:%.*]] = select <32 x i1> splat (i1 true), <32 x i16> [[TMP3]], <32 x i16> zeroinitializer
 ; CHECK-NEXT:    [[RES:%.*]] = call <32 x half> @llvm.x86.avx512fp16.mask.rsqrt.ph.512(<32 x half> [[A0]], <32 x half> zeroinitializer, i32 -1)
-; CHECK-NEXT:    store <32 x i16> zeroinitializer, ptr @__msan_retval_tls, align 8
+; CHECK-NEXT:    store <32 x i16> [[TMP4]], ptr @__msan_retval_tls, align 8
 ; CHECK-NEXT:    ret <32 x half> [[RES]]
 ;
   %res = call <32 x half> @llvm.x86.avx512fp16.mask.rsqrt.ph.512(<32 x half> %a0, <32 x half> zeroinitializer, i32 -1)
@@ -681,24 +676,22 @@ declare <32 x half> @llvm.x86.avx512fp16.mask.rcp.ph.512(<32 x half>, <32 x half
 define <32 x half> @test_rcp_ph_512(<32 x half> %a0, <32 x half> %a1, i32 %mask) #0 {
 ; CHECK-LABEL: define <32 x half> @test_rcp_ph_512(
 ; CHECK-SAME: <32 x half> [[A0:%.*]], <32 x half> [[A1:%.*]], i32 [[MASK:%.*]]) #[[ATTR1]] {
+; CHECK-NEXT:    [[TMP3:%.*]] = load i32, ptr inttoptr (i64 add (i64 ptrtoint (ptr @__msan_param_tls to i64), i64 128) to ptr), align 8
 ; CHECK-NEXT:    [[TMP1:%.*]] = load <32 x i16>, ptr @__msan_param_tls, align 8
 ; CHECK-NEXT:    [[TMP2:%.*]] = load <32 x i16>, ptr inttoptr (i64 add (i64 ptrtoint (ptr @__msan_param_tls to i64), i64 64) to ptr), align 8
-; CHECK-NEXT:    [[TMP3:%.*]] = load i32, ptr inttoptr (i64 add (i64 ptrtoint (ptr @__msan_param_tls to i64), i64 128) to ptr), align 8
 ; CHECK-NEXT:    call void @llvm.donothing()
-; CHECK-NEXT:    [[TMP4:%.*]] = bitcast <32 x i16> [[TMP1]] to i512
-; CHECK-NEXT:    [[_MSCMP:%.*]] = icmp ne i512 [[TMP4]], 0
-; CHECK-NEXT:    [[TMP5:%.*]] = bitcast <32 x i16> [[TMP2]] to i512
-; CHECK-NEXT:    [[_MSCMP1:%.*]] = icmp ne i512 [[TMP5]], 0
-; CHECK-NEXT:    [[_MSOR:%.*]] = or i1 [[_MSCMP]], [[_MSCMP1]]
+; CHECK-NEXT:    [[TMP4:%.*]] = bitcast i32 [[MASK]] to <32 x i1>
+; CHECK-NEXT:    [[TMP5:%.*]] = icmp ne <32 x i16> [[TMP1]], zeroinitializer
+; CHECK-NEXT:    [[TMP6:%.*]] = sext <32 x i1> [[TMP5]] to <32 x i16>
+; CHECK-NEXT:    [[TMP7:%.*]] = select <32 x i1> [[TMP4]], <32 x i16> [[TMP6]], <32 x i16> [[TMP2]]
 ; CHECK-NEXT:    [[_MSCMP2:%.*]] = icmp ne i32 [[TMP3]], 0
-; CHECK-NEXT:    [[_MSOR3:%.*]] = or i1 [[_MSOR]], [[_MSCMP2]]
-; CHECK-NEXT:    br i1 [[_MSOR3]], label %[[BB6:.*]], label %[[BB7:.*]], !prof [[PROF1]]
-; CHECK:       [[BB6]]:
+; CHECK-NEXT:    br i1 [[_MSCMP2]], label %[[BB8:.*]], label %[[BB9:.*]], !prof [[PROF1]]
+; CHECK:       [[BB8]]:
 ; CHECK-NEXT:    call void @__msan_warning_noreturn() #[[ATTR8]]
 ; CHECK-NEXT:    unreachable
-; CHECK:       [[BB7]]:
+; CHECK:       [[BB9]]:
 ; CHECK-NEXT:    [[RES:%.*]] = call <32 x half> @llvm.x86.avx512fp16.mask.rcp.ph.512(<32 x half> [[A0]], <32 x half> [[A1]], i32 [[MASK]])
-; CHECK-NEXT:    store <32 x i16> zeroinitializer, ptr @__msan_retval_tls, align 8
+; CHECK-NEXT:    store <32 x i16> [[TMP7]], ptr @__msan_retval_tls, align 8
 ; CHECK-NEXT:    ret <32 x half> [[RES]]
 ;
   %res = call <32 x half> @llvm.x86.avx512fp16.mask.rcp.ph.512(<32 x half> %a0, <32 x half> %a1, i32 %mask)
@@ -3260,3 +3253,6 @@ define <32 x half> @test_mm512_castph256_ph512_freeze(<16 x half> %a0) nounwind
 }
 
 attributes #0 = { sanitize_memory }
+;.
+; CHECK: [[PROF1]] = !{!"branch_weights", i32 1, i32 1048575}
+;.
diff --git a/llvm/test/Instrumentation/MemorySanitizer/X86/avx512vl-intrinsics.ll b/llvm/test/Instrumentation/MemorySanitizer/X86/avx512vl-intrinsics.ll
index e2dc8cbdca968..20114fe7d3151 100644
--- a/llvm/test/Instrumentation/MemorySanitizer/X86/avx512vl-intrinsics.ll
+++ b/llvm/test/Instrumentation/MemorySanitizer/X86/avx512vl-intrinsics.ll
@@ -63,10 +63,6 @@
 ; - llvm.x86.avx512.permvar.df.256
 ; - llvm.x86.avx512.pternlog.d.128, llvm.x86.avx512.pternlog.d.256
 ; - llvm.x86.avx512.pternlog.q.128, llvm.x86.avx512.pternlog.q.256
-; - llvm.x86.avx512.rcp14.pd.128, llvm.x86.avx512.rcp14.pd.256
-; - llvm.x86.avx512.rcp14.ps.128, llvm.x86.avx512.rcp14.ps.256
-; - llvm.x86.avx512.rsqrt14.pd.128, llvm.x86.avx512.rsqrt14.pd.256
-; - llvm.x86.avx512.rsqrt14.ps.128, llvm.x86.avx512.rsqrt14.ps.256
 ;
 ; Handled heuristically: (none)
 
@@ -8066,15 +8062,11 @@ define <8 x float> @test_rsqrt_ps_256_rr(<8 x float> %a0) #0 {
 ; CHECK-SAME: <8 x float> [[A0:%.*]]) #[[ATTR0]] {
 ; CHECK-NEXT:    [[TMP1:%.*]] = load <8 x i32>, ptr @__msan_param_tls, align 8
 ; CHECK-NEXT:    call void @llvm.donothing()
-; CHECK-NEXT:    [[TMP2:%.*]] = bitcast <8 x i32> [[TMP1]] to i256
-; CHECK-NEXT:    [[_MSCMP:%.*]] = icmp ne i256 [[TMP2]], 0
-; CHECK-NEXT:    br i1 [[_MSCMP]], label %[[BB3:.*]], label %[[BB4:.*]], !prof [[PROF1]]
-; CHECK:       [[BB3]]:
-; CHECK-NEXT:    call void @__msan_warning_noreturn() #[[ATTR6]]
-; CHECK-NEXT:    unreachable
-; CHECK:       [[BB4]]:
+; CHECK-NEXT:    [[TMP2:%.*]] = icmp ne <8 x i32> [[TMP1]], zeroinitializer
+; CHECK-NEXT:    [[TMP3:%.*]] = sext <8 x i1> [[TMP2]] to <8 x i32>
+; CHECK-NEXT:    [[TMP4:%.*]] = select <8 x i1> splat (i1 true), <8 x i32> [[TMP3]], <8 x i32> zeroinitializer
 ; CHECK-NEXT:    [[RES:%.*]] = call <8 x float> @llvm.x86.avx512.rsqrt14.ps.256(<8 x float> [[A0]], <8 x float> zeroinitializer, i8 -1)
-; CHECK-NEXT:    store <8 x i32> zeroinitializer, ptr @__msan_retval_tls, align 8
+; CHECK-NEXT:    store <8 x i32> [[TMP4]], ptr @__msan_retval_tls, align 8
 ; CHECK-NEXT:    ret <8 x float> [[RES]]
 ;
   %res = call <8 x float> @llvm.x86.avx512.rsqrt14.ps.256(<8 x float> %a0, <8 x float> zeroinitializer, i8 -1)
@@ -8085,20 +8077,21 @@ define <8 x float> @test_rsqrt_ps_256_rrkz(<8 x float> %a0, i8 %mask) #0 {
 ;
 ; CHECK-LABEL: define <8 x float> @test_rsqrt_ps_256_rrkz(
 ; CHECK-SAME: <8 x float> [[A0:%.*]], i8 [[MASK:%.*]]) #[[ATTR0]] {
-; CHECK-NEXT:    [[TMP1:%.*]] = load <8...
[truncated]

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants