[msan] Handle AVX512/AVX10 rcp and rsqrt #158397
Conversation
Adds a new handler, handleAVX512VectorGenericMaskedFP(), and applies it to rcp and rsqrt
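The shadow-propagation shape the handler produces can be read off the updated CHECK lines in the tests below. As a minimal sketch for the unmasked 512-bit rcp14 case (value names are illustrative; the mask is the all-ones constant i16 -1):

  ; <16 x float> @llvm.x86.avx512.rcp14.ps.512(<16 x float> %a0, <16 x float> zeroinitializer, i16 -1)
  %s_a0 = load <16 x i32>, ptr @__msan_param_tls, align 8    ; shadow of %a0
  %ne   = icmp ne <16 x i32> %s_a0, zeroinitializer          ; per element: any uninitialized bit?
  %aon  = sext <16 x i1> %ne to <16 x i32>                   ; all-or-nothing element shadow
  %rsh  = select <16 x i1> splat (i1 true), <16 x i32> %aon, <16 x i32> zeroinitializer
  store <16 x i32> %rsh, ptr @__msan_retval_tls, align 8     ; return shadow

Because the mask is all ones here, the select always picks the all-or-nothing shadow of the input; with a real mask, unselected lanes keep the write-through operand's shadow instead.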
@llvm/pr-subscribers-llvm-transforms

Author: Thurston Dang (thurstond)

Changes

Adds a new handler, handleAVX512VectorGenericMaskedFP(), and applies it to rcp and rsqrt.

Patch is 63.57 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/158397.diff

4 Files Affected:
diff --git a/llvm/lib/Transforms/Instrumentation/MemorySanitizer.cpp b/llvm/lib/Transforms/Instrumentation/MemorySanitizer.cpp
index 3ea790ad1839a..7933604b8ac25 100644
--- a/llvm/lib/Transforms/Instrumentation/MemorySanitizer.cpp
+++ b/llvm/lib/Transforms/Instrumentation/MemorySanitizer.cpp
@@ -4911,6 +4911,69 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
setOriginForNaryOp(I);
}
+ // Handle llvm.x86.avx512.* instructions that take a vector of floating-point
+ // values and perform an operation whose shadow propagation should be handled
+ // as all-or-nothing [*], with masking provided by a vector and a mask
+ // supplied as an integer.
+ //
+ // [*] if all bits of a vector element are initialized, the output is fully
+ // initialized; otherwise, the output is fully uninitialized
+ //
+ // e.g., <16 x float> @llvm.x86.avx512.rsqrt14.ps.512
+ // (<16 x float>, <16 x float>, i16)
+ // A WriteThru Mask
+ //
+ // <2 x double> @llvm.x86.avx512.rcp14.pd.128
+ // (<2 x double>, <2 x double>, i8)
+ //
+ // Dst[i] = Mask[i] ? some_op(A[i]) : WriteThru[i]
+ // Dst_shadow[i] = Mask[i] ? all_or_nothing(A_shadow[i]) : WriteThru_shadow[i]
+ void handleAVX512VectorGenericMaskedFP(IntrinsicInst &I) {
+ IRBuilder<> IRB(&I);
+
+ assert(I.arg_size() == 3);
+ Value *A = I.getOperand(0);
+ Value *WriteThrough = I.getOperand(1);
+ Value *Mask = I.getOperand(2);
+
+ assert(isFixedFPVector(A));
+ assert(isFixedFPVector(WriteThrough));
+
+ [[maybe_unused]] unsigned ANumElements =
+ cast<FixedVectorType>(A->getType())->getNumElements();
+ unsigned OutputNumElements =
+ cast<FixedVectorType>(WriteThrough->getType())->getNumElements();
+ assert(ANumElements == OutputNumElements);
+
+ assert(Mask->getType()->isIntegerTy());
+ // Some bits of the mask might be unused, but check them all anyway
+ // (typically the mask is an integer constant).
+ insertCheckShadowOf(Mask, &I);
+
+ // The mask has 1 bit per element of A, but a minimum of 8 bits.
+ if (Mask->getType()->getScalarSizeInBits() == 8 && ANumElements < 8)
+ Mask = IRB.CreateTrunc(Mask, Type::getIntNTy(*MS.C, ANumElements));
+ assert(Mask->getType()->getScalarSizeInBits() == ANumElements);
+
+ assert(I.getType() == WriteThrough->getType());
+
+ Mask = IRB.CreateBitCast(
+ Mask, FixedVectorType::get(IRB.getInt1Ty(), OutputNumElements));
+
+ Value *AShadow = getShadow(A);
+
+ // All-or-nothing shadow
+ AShadow = IRB.CreateSExt(IRB.CreateICmpNE(AShadow, getCleanShadow(AShadow)),
+ AShadow->getType());
+
+ Value *WriteThroughShadow = getShadow(WriteThrough);
+
+ Value *Shadow = IRB.CreateSelect(Mask, AShadow, WriteThroughShadow);
+ setShadow(&I, Shadow);
+
+ setOriginForNaryOp(I);
+ }
+
// For sh.* compiler intrinsics:
// llvm.x86.avx512fp16.mask.{add/sub/mul/div/max/min}.sh.round
// (<8 x half>, <8 x half>, <8 x half>, i8, i32)
@@ -6091,6 +6154,108 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
break;
}
+  // AVX512/AVX10 Reciprocal Square Root
+ // <16 x float> @llvm.x86.avx512.rsqrt14.ps.512
+ // (<16 x float>, <16 x float>, i16)
+ // <8 x float> @llvm.x86.avx512.rsqrt14.ps.256
+ // (<8 x float>, <8 x float>, i8)
+ // <4 x float> @llvm.x86.avx512.rsqrt14.ps.128
+ // (<4 x float>, <4 x float>, i8)
+ //
+ // <8 x double> @llvm.x86.avx512.rsqrt14.pd.512
+ // (<8 x double>, <8 x double>, i8)
+ // <4 x double> @llvm.x86.avx512.rsqrt14.pd.256
+ // (<4 x double>, <4 x double>, i8)
+ // <2 x double> @llvm.x86.avx512.rsqrt14.pd.128
+ // (<2 x double>, <2 x double>, i8)
+ //
+ // <32 x bfloat> @llvm.x86.avx10.mask.rsqrt.bf16.512
+ // (<32 x bfloat>, <32 x bfloat>, i32)
+ // <16 x bfloat> @llvm.x86.avx10.mask.rsqrt.bf16.256
+ // (<16 x bfloat>, <16 x bfloat>, i16)
+ // <8 x bfloat> @llvm.x86.avx10.mask.rsqrt.bf16.128
+ // (<8 x bfloat>, <8 x bfloat>, i8)
+ //
+ // <32 x half> @llvm.x86.avx512fp16.mask.rsqrt.ph.512
+ // (<32 x half>, <32 x half>, i32)
+ // <16 x half> @llvm.x86.avx512fp16.mask.rsqrt.ph.256
+ // (<16 x half>, <16 x half>, i16)
+ // <8 x half> @llvm.x86.avx512fp16.mask.rsqrt.ph.128
+ // (<8 x half>, <8 x half>, i8)
+ //
+ // TODO: 3-operand variants are not handled:
+ // <2 x double> @llvm.x86.avx512.rsqrt14.sd
+ // (<2 x double>, <2 x double>, <2 x double>, i8)
+ // <4 x float> @llvm.x86.avx512.rsqrt14.ss
+ // (<4 x float>, <4 x float>, <4 x float>, i8)
+ // <8 x half> @llvm.x86.avx512fp16.mask.rsqrt.sh
+ // (<8 x half>, <8 x half>, <8 x half>, i8)
+ case Intrinsic::x86_avx512_rsqrt14_ps_512:
+ case Intrinsic::x86_avx512_rsqrt14_ps_256:
+ case Intrinsic::x86_avx512_rsqrt14_ps_128:
+ case Intrinsic::x86_avx512_rsqrt14_pd_512:
+ case Intrinsic::x86_avx512_rsqrt14_pd_256:
+ case Intrinsic::x86_avx512_rsqrt14_pd_128:
+ case Intrinsic::x86_avx10_mask_rsqrt_bf16_512:
+ case Intrinsic::x86_avx10_mask_rsqrt_bf16_256:
+ case Intrinsic::x86_avx10_mask_rsqrt_bf16_128:
+ case Intrinsic::x86_avx512fp16_mask_rsqrt_ph_512:
+ case Intrinsic::x86_avx512fp16_mask_rsqrt_ph_256:
+ case Intrinsic::x86_avx512fp16_mask_rsqrt_ph_128:
+ handleAVX512VectorGenericMaskedFP(I);
+ break;
+
+  // AVX512/AVX10 Reciprocal
+ // <16 x float> @llvm.x86.avx512.rcp14.ps.512
+ // (<16 x float>, <16 x float>, i16)
+ // <8 x float> @llvm.x86.avx512.rcp14.ps.256
+ // (<8 x float>, <8 x float>, i8)
+ // <4 x float> @llvm.x86.avx512.rcp14.ps.128
+ // (<4 x float>, <4 x float>, i8)
+ //
+ // <8 x double> @llvm.x86.avx512.rcp14.pd.512
+ // (<8 x double>, <8 x double>, i8)
+ // <4 x double> @llvm.x86.avx512.rcp14.pd.256
+ // (<4 x double>, <4 x double>, i8)
+ // <2 x double> @llvm.x86.avx512.rcp14.pd.128
+ // (<2 x double>, <2 x double>, i8)
+ //
+ // <32 x bfloat> @llvm.x86.avx10.mask.rcp.bf16.512
+ // (<32 x bfloat>, <32 x bfloat>, i32)
+ // <16 x bfloat> @llvm.x86.avx10.mask.rcp.bf16.256
+ // (<16 x bfloat>, <16 x bfloat>, i16)
+ // <8 x bfloat> @llvm.x86.avx10.mask.rcp.bf16.128
+ // (<8 x bfloat>, <8 x bfloat>, i8)
+ //
+ // <32 x half> @llvm.x86.avx512fp16.mask.rcp.ph.512
+ // (<32 x half>, <32 x half>, i32)
+ // <16 x half> @llvm.x86.avx512fp16.mask.rcp.ph.256
+ // (<16 x half>, <16 x half>, i16)
+ // <8 x half> @llvm.x86.avx512fp16.mask.rcp.ph.128
+ // (<8 x half>, <8 x half>, i8)
+ //
+ // TODO: 3-operand variants are not handled:
+ // <2 x double> @llvm.x86.avx512.rcp14.sd
+ // (<2 x double>, <2 x double>, <2 x double>, i8)
+ // <4 x float> @llvm.x86.avx512.rcp14.ss
+ // (<4 x float>, <4 x float>, <4 x float>, i8)
+ // <8 x half> @llvm.x86.avx512fp16.mask.rcp.sh
+ // (<8 x half>, <8 x half>, <8 x half>, i8)
+ case Intrinsic::x86_avx512_rcp14_ps_512:
+ case Intrinsic::x86_avx512_rcp14_ps_256:
+ case Intrinsic::x86_avx512_rcp14_ps_128:
+ case Intrinsic::x86_avx512_rcp14_pd_512:
+ case Intrinsic::x86_avx512_rcp14_pd_256:
+ case Intrinsic::x86_avx512_rcp14_pd_128:
+ case Intrinsic::x86_avx10_mask_rcp_bf16_512:
+ case Intrinsic::x86_avx10_mask_rcp_bf16_256:
+ case Intrinsic::x86_avx10_mask_rcp_bf16_128:
+ case Intrinsic::x86_avx512fp16_mask_rcp_ph_512:
+ case Intrinsic::x86_avx512fp16_mask_rcp_ph_256:
+ case Intrinsic::x86_avx512fp16_mask_rcp_ph_128:
+ handleAVX512VectorGenericMaskedFP(I);
+ break;
+
// AVX512 FP16 Arithmetic
case Intrinsic::x86_avx512fp16_mask_add_sh_round:
case Intrinsic::x86_avx512fp16_mask_sub_sh_round:
diff --git a/llvm/test/Instrumentation/MemorySanitizer/X86/avx512-intrinsics.ll b/llvm/test/Instrumentation/MemorySanitizer/X86/avx512-intrinsics.ll
index a2f1d65e7cd41..b2a4f0e582f9e 100644
--- a/llvm/test/Instrumentation/MemorySanitizer/X86/avx512-intrinsics.ll
+++ b/llvm/test/Instrumentation/MemorySanitizer/X86/avx512-intrinsics.ll
@@ -28,8 +28,6 @@
; - llvm.x86.avx512.mul.pd.512, llvm.x86.avx512.mul.ps.512
; - llvm.x86.avx512.permvar.df.512, llvm.x86.avx512.permvar.sf.512
; - llvm.x86.avx512.pternlog.d.512, llvm.x86.avx512.pternlog.q.512
-; - llvm.x86.avx512.rcp14.pd.512, llvm.x86.avx512.rcp14.ps.512
-; - llvm.x86.avx512.rsqrt14.ps.512
; - llvm.x86.avx512.sitofp.round.v16f32.v16i32
; - llvm.x86.avx512.sqrt.pd.512, llvm.x86.avx512.sqrt.ps.512
; - llvm.x86.avx512.sub.ps.512
@@ -682,15 +680,11 @@ define <16 x float> @test_rcp_ps_512(<16 x float> %a0) #0 {
; CHECK-LABEL: @test_rcp_ps_512(
; CHECK-NEXT: [[TMP1:%.*]] = load <16 x i32>, ptr @__msan_param_tls, align 8
; CHECK-NEXT: call void @llvm.donothing()
-; CHECK-NEXT: [[TMP2:%.*]] = bitcast <16 x i32> [[TMP1]] to i512
-; CHECK-NEXT: [[_MSCMP:%.*]] = icmp ne i512 [[TMP2]], 0
-; CHECK-NEXT: br i1 [[_MSCMP]], label [[TMP3:%.*]], label [[TMP4:%.*]], !prof [[PROF1]]
-; CHECK: 3:
-; CHECK-NEXT: call void @__msan_warning_noreturn() #[[ATTR10]]
-; CHECK-NEXT: unreachable
-; CHECK: 4:
+; CHECK-NEXT: [[TMP2:%.*]] = icmp ne <16 x i32> [[TMP1]], zeroinitializer
+; CHECK-NEXT: [[TMP3:%.*]] = sext <16 x i1> [[TMP2]] to <16 x i32>
+; CHECK-NEXT: [[TMP4:%.*]] = select <16 x i1> splat (i1 true), <16 x i32> [[TMP3]], <16 x i32> zeroinitializer
; CHECK-NEXT: [[RES:%.*]] = call <16 x float> @llvm.x86.avx512.rcp14.ps.512(<16 x float> [[A0:%.*]], <16 x float> zeroinitializer, i16 -1)
-; CHECK-NEXT: store <16 x i32> zeroinitializer, ptr @__msan_retval_tls, align 8
+; CHECK-NEXT: store <16 x i32> [[TMP4]], ptr @__msan_retval_tls, align 8
; CHECK-NEXT: ret <16 x float> [[RES]]
;
%res = call <16 x float> @llvm.x86.avx512.rcp14.ps.512(<16 x float> %a0, <16 x float> zeroinitializer, i16 -1) ; <<16 x float>> [#uses=1]
@@ -702,15 +696,11 @@ define <8 x double> @test_rcp_pd_512(<8 x double> %a0) #0 {
; CHECK-LABEL: @test_rcp_pd_512(
; CHECK-NEXT: [[TMP1:%.*]] = load <8 x i64>, ptr @__msan_param_tls, align 8
; CHECK-NEXT: call void @llvm.donothing()
-; CHECK-NEXT: [[TMP2:%.*]] = bitcast <8 x i64> [[TMP1]] to i512
-; CHECK-NEXT: [[_MSCMP:%.*]] = icmp ne i512 [[TMP2]], 0
-; CHECK-NEXT: br i1 [[_MSCMP]], label [[TMP3:%.*]], label [[TMP4:%.*]], !prof [[PROF1]]
-; CHECK: 3:
-; CHECK-NEXT: call void @__msan_warning_noreturn() #[[ATTR10]]
-; CHECK-NEXT: unreachable
-; CHECK: 4:
+; CHECK-NEXT: [[TMP2:%.*]] = icmp ne <8 x i64> [[TMP1]], zeroinitializer
+; CHECK-NEXT: [[TMP3:%.*]] = sext <8 x i1> [[TMP2]] to <8 x i64>
+; CHECK-NEXT: [[TMP4:%.*]] = select <8 x i1> splat (i1 true), <8 x i64> [[TMP3]], <8 x i64> zeroinitializer
; CHECK-NEXT: [[RES:%.*]] = call <8 x double> @llvm.x86.avx512.rcp14.pd.512(<8 x double> [[A0:%.*]], <8 x double> zeroinitializer, i8 -1)
-; CHECK-NEXT: store <8 x i64> zeroinitializer, ptr @__msan_retval_tls, align 8
+; CHECK-NEXT: store <8 x i64> [[TMP4]], ptr @__msan_retval_tls, align 8
; CHECK-NEXT: ret <8 x double> [[RES]]
;
%res = call <8 x double> @llvm.x86.avx512.rcp14.pd.512(<8 x double> %a0, <8 x double> zeroinitializer, i8 -1) ; <<8 x double>> [#uses=1]
@@ -1021,15 +1011,11 @@ define <16 x float> @test_rsqrt_ps_512(<16 x float> %a0) #0 {
; CHECK-LABEL: @test_rsqrt_ps_512(
; CHECK-NEXT: [[TMP1:%.*]] = load <16 x i32>, ptr @__msan_param_tls, align 8
; CHECK-NEXT: call void @llvm.donothing()
-; CHECK-NEXT: [[TMP2:%.*]] = bitcast <16 x i32> [[TMP1]] to i512
-; CHECK-NEXT: [[_MSCMP:%.*]] = icmp ne i512 [[TMP2]], 0
-; CHECK-NEXT: br i1 [[_MSCMP]], label [[TMP3:%.*]], label [[TMP4:%.*]], !prof [[PROF1]]
-; CHECK: 3:
-; CHECK-NEXT: call void @__msan_warning_noreturn() #[[ATTR10]]
-; CHECK-NEXT: unreachable
-; CHECK: 4:
+; CHECK-NEXT: [[TMP2:%.*]] = icmp ne <16 x i32> [[TMP1]], zeroinitializer
+; CHECK-NEXT: [[TMP3:%.*]] = sext <16 x i1> [[TMP2]] to <16 x i32>
+; CHECK-NEXT: [[TMP4:%.*]] = select <16 x i1> splat (i1 true), <16 x i32> [[TMP3]], <16 x i32> zeroinitializer
; CHECK-NEXT: [[RES:%.*]] = call <16 x float> @llvm.x86.avx512.rsqrt14.ps.512(<16 x float> [[A0:%.*]], <16 x float> zeroinitializer, i16 -1)
-; CHECK-NEXT: store <16 x i32> zeroinitializer, ptr @__msan_retval_tls, align 8
+; CHECK-NEXT: store <16 x i32> [[TMP4]], ptr @__msan_retval_tls, align 8
; CHECK-NEXT: ret <16 x float> [[RES]]
;
%res = call <16 x float> @llvm.x86.avx512.rsqrt14.ps.512(<16 x float> %a0, <16 x float> zeroinitializer, i16 -1) ; <<16 x float>> [#uses=1]
diff --git a/llvm/test/Instrumentation/MemorySanitizer/X86/avx512fp16-intrinsics.ll b/llvm/test/Instrumentation/MemorySanitizer/X86/avx512fp16-intrinsics.ll
index c5d91adf64cb3..e5cbe8c132238 100644
--- a/llvm/test/Instrumentation/MemorySanitizer/X86/avx512fp16-intrinsics.ll
+++ b/llvm/test/Instrumentation/MemorySanitizer/X86/avx512fp16-intrinsics.ll
@@ -19,7 +19,6 @@
; - llvm.x86.avx512fp16.mask.reduce.sh
; - llvm.x86.avx512fp16.mask.rndscale.ph.512
; - llvm.x86.avx512fp16.mask.rndscale.sh
-; - llvm.x86.avx512fp16.mask.rsqrt.ph.512
; - llvm.x86.avx512fp16.mask.rsqrt.sh
; - llvm.x86.avx512fp16.mask.scalef.ph.512
; - llvm.x86.avx512fp16.mask.scalef.sh
@@ -442,15 +441,11 @@ define <32 x half> @test_rsqrt_ph_512(<32 x half> %a0) #0 {
; CHECK-SAME: <32 x half> [[A0:%.*]]) #[[ATTR1]] {
; CHECK-NEXT: [[TMP1:%.*]] = load <32 x i16>, ptr @__msan_param_tls, align 8
; CHECK-NEXT: call void @llvm.donothing()
-; CHECK-NEXT: [[TMP2:%.*]] = bitcast <32 x i16> [[TMP1]] to i512
-; CHECK-NEXT: [[_MSCMP:%.*]] = icmp ne i512 [[TMP2]], 0
-; CHECK-NEXT: br i1 [[_MSCMP]], label %[[BB3:.*]], label %[[BB4:.*]], !prof [[PROF1]]
-; CHECK: [[BB3]]:
-; CHECK-NEXT: call void @__msan_warning_noreturn() #[[ATTR8]]
-; CHECK-NEXT: unreachable
-; CHECK: [[BB4]]:
+; CHECK-NEXT: [[TMP2:%.*]] = icmp ne <32 x i16> [[TMP1]], zeroinitializer
+; CHECK-NEXT: [[TMP3:%.*]] = sext <32 x i1> [[TMP2]] to <32 x i16>
+; CHECK-NEXT: [[TMP4:%.*]] = select <32 x i1> splat (i1 true), <32 x i16> [[TMP3]], <32 x i16> zeroinitializer
; CHECK-NEXT: [[RES:%.*]] = call <32 x half> @llvm.x86.avx512fp16.mask.rsqrt.ph.512(<32 x half> [[A0]], <32 x half> zeroinitializer, i32 -1)
-; CHECK-NEXT: store <32 x i16> zeroinitializer, ptr @__msan_retval_tls, align 8
+; CHECK-NEXT: store <32 x i16> [[TMP4]], ptr @__msan_retval_tls, align 8
; CHECK-NEXT: ret <32 x half> [[RES]]
;
%res = call <32 x half> @llvm.x86.avx512fp16.mask.rsqrt.ph.512(<32 x half> %a0, <32 x half> zeroinitializer, i32 -1)
@@ -681,24 +676,22 @@ declare <32 x half> @llvm.x86.avx512fp16.mask.rcp.ph.512(<32 x half>, <32 x half
define <32 x half> @test_rcp_ph_512(<32 x half> %a0, <32 x half> %a1, i32 %mask) #0 {
; CHECK-LABEL: define <32 x half> @test_rcp_ph_512(
; CHECK-SAME: <32 x half> [[A0:%.*]], <32 x half> [[A1:%.*]], i32 [[MASK:%.*]]) #[[ATTR1]] {
+; CHECK-NEXT: [[TMP3:%.*]] = load i32, ptr inttoptr (i64 add (i64 ptrtoint (ptr @__msan_param_tls to i64), i64 128) to ptr), align 8
; CHECK-NEXT: [[TMP1:%.*]] = load <32 x i16>, ptr @__msan_param_tls, align 8
; CHECK-NEXT: [[TMP2:%.*]] = load <32 x i16>, ptr inttoptr (i64 add (i64 ptrtoint (ptr @__msan_param_tls to i64), i64 64) to ptr), align 8
-; CHECK-NEXT: [[TMP3:%.*]] = load i32, ptr inttoptr (i64 add (i64 ptrtoint (ptr @__msan_param_tls to i64), i64 128) to ptr), align 8
; CHECK-NEXT: call void @llvm.donothing()
-; CHECK-NEXT: [[TMP4:%.*]] = bitcast <32 x i16> [[TMP1]] to i512
-; CHECK-NEXT: [[_MSCMP:%.*]] = icmp ne i512 [[TMP4]], 0
-; CHECK-NEXT: [[TMP5:%.*]] = bitcast <32 x i16> [[TMP2]] to i512
-; CHECK-NEXT: [[_MSCMP1:%.*]] = icmp ne i512 [[TMP5]], 0
-; CHECK-NEXT: [[_MSOR:%.*]] = or i1 [[_MSCMP]], [[_MSCMP1]]
+; CHECK-NEXT: [[TMP4:%.*]] = bitcast i32 [[MASK]] to <32 x i1>
+; CHECK-NEXT: [[TMP5:%.*]] = icmp ne <32 x i16> [[TMP1]], zeroinitializer
+; CHECK-NEXT: [[TMP6:%.*]] = sext <32 x i1> [[TMP5]] to <32 x i16>
+; CHECK-NEXT: [[TMP7:%.*]] = select <32 x i1> [[TMP4]], <32 x i16> [[TMP6]], <32 x i16> [[TMP2]]
; CHECK-NEXT: [[_MSCMP2:%.*]] = icmp ne i32 [[TMP3]], 0
-; CHECK-NEXT: [[_MSOR3:%.*]] = or i1 [[_MSOR]], [[_MSCMP2]]
-; CHECK-NEXT: br i1 [[_MSOR3]], label %[[BB6:.*]], label %[[BB7:.*]], !prof [[PROF1]]
-; CHECK: [[BB6]]:
+; CHECK-NEXT: br i1 [[_MSCMP2]], label %[[BB8:.*]], label %[[BB9:.*]], !prof [[PROF1]]
+; CHECK: [[BB8]]:
; CHECK-NEXT: call void @__msan_warning_noreturn() #[[ATTR8]]
; CHECK-NEXT: unreachable
-; CHECK: [[BB7]]:
+; CHECK: [[BB9]]:
; CHECK-NEXT: [[RES:%.*]] = call <32 x half> @llvm.x86.avx512fp16.mask.rcp.ph.512(<32 x half> [[A0]], <32 x half> [[A1]], i32 [[MASK]])
-; CHECK-NEXT: store <32 x i16> zeroinitializer, ptr @__msan_retval_tls, align 8
+; CHECK-NEXT: store <32 x i16> [[TMP7]], ptr @__msan_retval_tls, align 8
; CHECK-NEXT: ret <32 x half> [[RES]]
;
%res = call <32 x half> @llvm.x86.avx512fp16.mask.rcp.ph.512(<32 x half> %a0, <32 x half> %a1, i32 %mask)
@@ -3260,3 +3253,6 @@ define <32 x half> @test_mm512_castph256_ph512_freeze(<16 x half> %a0) nounwind
}
attributes #0 = { sanitize_memory }
+;.
+; CHECK: [[PROF1]] = !{!"branch_weights", i32 1, i32 1048575}
+;.
diff --git a/llvm/test/Instrumentation/MemorySanitizer/X86/avx512vl-intrinsics.ll b/llvm/test/Instrumentation/MemorySanitizer/X86/avx512vl-intrinsics.ll
index e2dc8cbdca968..20114fe7d3151 100644
--- a/llvm/test/Instrumentation/MemorySanitizer/X86/avx512vl-intrinsics.ll
+++ b/llvm/test/Instrumentation/MemorySanitizer/X86/avx512vl-intrinsics.ll
@@ -63,10 +63,6 @@
; - llvm.x86.avx512.permvar.df.256
; - llvm.x86.avx512.pternlog.d.128, llvm.x86.avx512.pternlog.d.256
; - llvm.x86.avx512.pternlog.q.128, llvm.x86.avx512.pternlog.q.256
-; - llvm.x86.avx512.rcp14.pd.128, llvm.x86.avx512.rcp14.pd.256
-; - llvm.x86.avx512.rcp14.ps.128, llvm.x86.avx512.rcp14.ps.256
-; - llvm.x86.avx512.rsqrt14.pd.128, llvm.x86.avx512.rsqrt14.pd.256
-; - llvm.x86.avx512.rsqrt14.ps.128, llvm.x86.avx512.rsqrt14.ps.256
;
; Handled heuristically: (none)
@@ -8066,15 +8062,11 @@ define <8 x float> @test_rsqrt_ps_256_rr(<8 x float> %a0) #0 {
; CHECK-SAME: <8 x float> [[A0:%.*]]) #[[ATTR0]] {
; CHECK-NEXT: [[TMP1:%.*]] = load <8 x i32>, ptr @__msan_param_tls, align 8
; CHECK-NEXT: call void @llvm.donothing()
-; CHECK-NEXT: [[TMP2:%.*]] = bitcast <8 x i32> [[TMP1]] to i256
-; CHECK-NEXT: [[_MSCMP:%.*]] = icmp ne i256 [[TMP2]], 0
-; CHECK-NEXT: br i1 [[_MSCMP]], label %[[BB3:.*]], label %[[BB4:.*]], !prof [[PROF1]]
-; CHECK: [[BB3]]:
-; CHECK-NEXT: call void @__msan_warning_noreturn() #[[ATTR6]]
-; CHECK-NEXT: unreachable
-; CHECK: [[BB4]]:
+; CHECK-NEXT: [[TMP2:%.*]] = icmp ne <8 x i32> [[TMP1]], zeroinitializer
+; CHECK-NEXT: [[TMP3:%.*]] = sext <8 x i1> [[TMP2]] to <8 x i32>
+; CHECK-NEXT: [[TMP4:%.*]] = select <8 x i1> splat (i1 true), <8 x i32> [[TMP3]], <8 x i32> zeroinitializer
; CHECK-NEXT: [[RES:%.*]] = call <8 x float> @llvm.x86.avx512.rsqrt14.ps.256(<8 x float> [[A0]], <8 x float> zeroinitializer, i8 -1)
-; CHECK-NEXT: store <8 x i32> zeroinitializer, ptr @__msan_retval_tls, align 8
+; CHECK-NEXT: store <8 x i32> [[TMP4]], ptr @__msan_retval_tls, align 8
; CHECK-NEXT: ret <8 x float> [[RES]]
;
%res = call <8 x float> @llvm.x86.avx512.rsqrt14.ps.256(<8 x float> %a0, <8 x float> zeroinitializer, i8 -1)
@@ -8085,20 +8077,21 @@ define <8 x float> @test_rsqrt_ps_256_rrkz(<8 x float> %a0, i8 %mask) #0 {
;
; CHECK-LABEL: define <8 x float> @test_rsqrt_ps_256_rrkz(
; CHECK-SAME: <8 x float> [[A0:%.*]], i8 [[MASK:%.*]]) #[[ATTR0]] {
-; CHECK-NEXT: [[TMP1:%.*]] = load <8...
[truncated]
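Not visible in the truncated listing above: for the 128-bit double variants the mask operand is an i8 but the vector has only 2 elements, so the handler truncates the mask before bitcasting it to a vector of i1. A hedged sketch of the shadow select it would emit for @llvm.x86.avx512.rcp14.pd.128 (%s_a and %s_writethru stand for the operand shadows; names are illustrative, not taken from the patch's tests):

  %m2   = trunc i8 %mask to i2                      ; keep 1 mask bit per element
  %mvec = bitcast i2 %m2 to <2 x i1>
  %ne   = icmp ne <2 x i64> %s_a, zeroinitializer   ; per element: any uninitialized bit?
  %aon  = sext <2 x i1> %ne to <2 x i64>            ; all-or-nothing shadow of A
  %rsh  = select <2 x i1> %mvec, <2 x i64> %aon, <2 x i64> %s_writethru

The handler also checks the shadow of the whole i8 mask before truncating, since the unused high bits could themselves be uninitialized.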
@llvm/pr-subscribers-compiler-rt-sanitizer

Author: Thurston Dang (thurstond)

Changes

Adds a new handler, handleAVX512VectorGenericMaskedFP(), and applies it to rcp and rsqrt.

Patch is 63.57 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/158397.diff
Adds a new handler, handleAVX512VectorGenericMaskedFP(), and applies it to AVX512/AVX10 rcp and rsqrt
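For context, a minimal caller that exercises one of the newly handled intrinsics with a non-constant mask; this is an illustrative example built from the signatures listed above, not a test added by the patch. Under MSan, lanes selected by %k receive the all-or-nothing shadow of %a, while the remaining lanes keep the shadow of %passthru.

  define <8 x double> @example(<8 x double> %a, <8 x double> %passthru, i8 %k) sanitize_memory {
    ; masked reciprocal square root: Dst[i] = k[i] ? rsqrt14(a[i]) : passthru[i]
    %r = call <8 x double> @llvm.x86.avx512.rsqrt14.pd.512(<8 x double> %a, <8 x double> %passthru, i8 %k)
    ret <8 x double> %r
  }

  declare <8 x double> @llvm.x86.avx512.rsqrt14.pd.512(<8 x double>, <8 x double>, i8)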