165 changes: 165 additions & 0 deletions llvm/lib/Transforms/Instrumentation/MemorySanitizer.cpp
@@ -4911,6 +4911,69 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
setOriginForNaryOp(I);
}

// Handle llvm.x86.avx512.* instructions that take a vector of floating-point
// values and perform an operation whose shadow propagation should be handled
// as all-or-nothing [*], with per-element masking that selects between the
// operation's result and a writethrough vector, and with the mask supplied
// as an integer.
//
// [*] If all bits of a vector element are initialized, the corresponding
// output element is fully initialized; otherwise, it is fully
// uninitialized.
//
// e.g., <16 x float> @llvm.x86.avx512.rsqrt14.ps.512
// (<16 x float>, <16 x float>, i16)
// A WriteThru Mask
//
// <2 x double> @llvm.x86.avx512.rcp14.pd.128
// (<2 x double>, <2 x double>, i8)
//
// Dst[i] = Mask[i] ? some_op(A[i]) : WriteThru[i]
// Dst_shadow[i] = Mask[i] ? all_or_nothing(A_shadow[i]) : WriteThru_shadow[i]
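//
// For example (illustrative values only): with a 4-element vector and
// Mask = 0b0101, output elements 0 and 2 take the all-or-nothing shadow
// of the corresponding elements of A, while elements 1 and 3 inherit the
// shadow of WriteThru.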
void handleAVX512VectorGenericMaskedFP(IntrinsicInst &I) {
IRBuilder<> IRB(&I);

assert(I.arg_size() == 3);
Value *A = I.getOperand(0);
Value *WriteThrough = I.getOperand(1);
Value *Mask = I.getOperand(2);

assert(isFixedFPVector(A));
assert(isFixedFPVector(WriteThrough));

[[maybe_unused]] unsigned ANumElements =
cast<FixedVectorType>(A->getType())->getNumElements();
unsigned OutputNumElements =
cast<FixedVectorType>(WriteThrough->getType())->getNumElements();
assert(ANumElements == OutputNumElements);

assert(Mask->getType()->isIntegerTy());
// Some bits of the mask might be unused, but check them all anyway
// (typically the mask is an integer constant).
insertCheckShadowOf(Mask, &I);

// The mask has 1 bit per element of A, but a minimum of 8 bits.
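// e.g., <2 x double> @llvm.x86.avx512.rcp14.pd.128 takes an i8 mask of
// which only the low two bits select lanes; truncate it so the bitcast
// below yields exactly one i1 per element.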
if (Mask->getType()->getScalarSizeInBits() == 8 && ANumElements < 8)
Mask = IRB.CreateTrunc(Mask, Type::getIntNTy(*MS.C, ANumElements));
assert(Mask->getType()->getScalarSizeInBits() == ANumElements);

assert(I.getType() == WriteThrough->getType());

Mask = IRB.CreateBitCast(
Mask, FixedVectorType::get(IRB.getInt1Ty(), OutputNumElements));
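// The mask is now a vector of i1, one selection bit per output element,
// in the form that CreateSelect below expects.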

Value *AShadow = getShadow(A);

// All-or-nothing shadow
AShadow = IRB.CreateSExt(IRB.CreateICmpNE(AShadow, getCleanShadow(AShadow)),
AShadow->getType());
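// (ICmpNE against the clean (all-zero) shadow is true for any element
// with at least one uninitialized bit; SExt then widens that i1 back to
// the full element width, yielding all-ones or all-zeros per element.)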

Value *WriteThroughShadow = getShadow(WriteThrough);

Value *Shadow = IRB.CreateSelect(Mask, AShadow, WriteThroughShadow);
setShadow(&I, Shadow);

setOriginForNaryOp(I);
}

// For sh.* compiler intrinsics:
// llvm.x86.avx512fp16.mask.{add/sub/mul/div/max/min}.sh.round
// (<8 x half>, <8 x half>, <8 x half>, i8, i32)
@@ -6091,6 +6154,108 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
break;
}

// AVX512/AVX10 Reciprocal Square Root
// <16 x float> @llvm.x86.avx512.rsqrt14.ps.512
// (<16 x float>, <16 x float>, i16)
// <8 x float> @llvm.x86.avx512.rsqrt14.ps.256
// (<8 x float>, <8 x float>, i8)
// <4 x float> @llvm.x86.avx512.rsqrt14.ps.128
// (<4 x float>, <4 x float>, i8)
//
// <8 x double> @llvm.x86.avx512.rsqrt14.pd.512
// (<8 x double>, <8 x double>, i8)
// <4 x double> @llvm.x86.avx512.rsqrt14.pd.256
// (<4 x double>, <4 x double>, i8)
// <2 x double> @llvm.x86.avx512.rsqrt14.pd.128
// (<2 x double>, <2 x double>, i8)
//
// <32 x bfloat> @llvm.x86.avx10.mask.rsqrt.bf16.512
// (<32 x bfloat>, <32 x bfloat>, i32)
// <16 x bfloat> @llvm.x86.avx10.mask.rsqrt.bf16.256
// (<16 x bfloat>, <16 x bfloat>, i16)
// <8 x bfloat> @llvm.x86.avx10.mask.rsqrt.bf16.128
// (<8 x bfloat>, <8 x bfloat>, i8)
//
// <32 x half> @llvm.x86.avx512fp16.mask.rsqrt.ph.512
// (<32 x half>, <32 x half>, i32)
// <16 x half> @llvm.x86.avx512fp16.mask.rsqrt.ph.256
// (<16 x half>, <16 x half>, i16)
// <8 x half> @llvm.x86.avx512fp16.mask.rsqrt.ph.128
// (<8 x half>, <8 x half>, i8)
//
// TODO: 3-operand variants are not handled:
// <2 x double> @llvm.x86.avx512.rsqrt14.sd
// (<2 x double>, <2 x double>, <2 x double>, i8)
// <4 x float> @llvm.x86.avx512.rsqrt14.ss
// (<4 x float>, <4 x float>, <4 x float>, i8)
// <8 x half> @llvm.x86.avx512fp16.mask.rsqrt.sh
// (<8 x half>, <8 x half>, <8 x half>, i8)
case Intrinsic::x86_avx512_rsqrt14_ps_512:
case Intrinsic::x86_avx512_rsqrt14_ps_256:
case Intrinsic::x86_avx512_rsqrt14_ps_128:
case Intrinsic::x86_avx512_rsqrt14_pd_512:
case Intrinsic::x86_avx512_rsqrt14_pd_256:
case Intrinsic::x86_avx512_rsqrt14_pd_128:
case Intrinsic::x86_avx10_mask_rsqrt_bf16_512:
case Intrinsic::x86_avx10_mask_rsqrt_bf16_256:
case Intrinsic::x86_avx10_mask_rsqrt_bf16_128:
case Intrinsic::x86_avx512fp16_mask_rsqrt_ph_512:
case Intrinsic::x86_avx512fp16_mask_rsqrt_ph_256:
case Intrinsic::x86_avx512fp16_mask_rsqrt_ph_128:
handleAVX512VectorGenericMaskedFP(I);
break;

// AVX512/AVX10 Reciprocal
// <16 x float> @llvm.x86.avx512.rcp14.ps.512
// (<16 x float>, <16 x float>, i16)
// <8 x float> @llvm.x86.avx512.rcp14.ps.256
// (<8 x float>, <8 x float>, i8)
// <4 x float> @llvm.x86.avx512.rcp14.ps.128
// (<4 x float>, <4 x float>, i8)
//
// <8 x double> @llvm.x86.avx512.rcp14.pd.512
// (<8 x double>, <8 x double>, i8)
// <4 x double> @llvm.x86.avx512.rcp14.pd.256
// (<4 x double>, <4 x double>, i8)
// <2 x double> @llvm.x86.avx512.rcp14.pd.128
// (<2 x double>, <2 x double>, i8)
//
// <32 x bfloat> @llvm.x86.avx10.mask.rcp.bf16.512
// (<32 x bfloat>, <32 x bfloat>, i32)
// <16 x bfloat> @llvm.x86.avx10.mask.rcp.bf16.256
// (<16 x bfloat>, <16 x bfloat>, i16)
// <8 x bfloat> @llvm.x86.avx10.mask.rcp.bf16.128
// (<8 x bfloat>, <8 x bfloat>, i8)
//
// <32 x half> @llvm.x86.avx512fp16.mask.rcp.ph.512
// (<32 x half>, <32 x half>, i32)
// <16 x half> @llvm.x86.avx512fp16.mask.rcp.ph.256
// (<16 x half>, <16 x half>, i16)
// <8 x half> @llvm.x86.avx512fp16.mask.rcp.ph.128
// (<8 x half>, <8 x half>, i8)
//
// TODO: 3-operand variants are not handled:
// <2 x double> @llvm.x86.avx512.rcp14.sd
// (<2 x double>, <2 x double>, <2 x double>, i8)
// <4 x float> @llvm.x86.avx512.rcp14.ss
// (<4 x float>, <4 x float>, <4 x float>, i8)
// <8 x half> @llvm.x86.avx512fp16.mask.rcp.sh
// (<8 x half>, <8 x half>, <8 x half>, i8)
case Intrinsic::x86_avx512_rcp14_ps_512:
case Intrinsic::x86_avx512_rcp14_ps_256:
case Intrinsic::x86_avx512_rcp14_ps_128:
case Intrinsic::x86_avx512_rcp14_pd_512:
case Intrinsic::x86_avx512_rcp14_pd_256:
case Intrinsic::x86_avx512_rcp14_pd_128:
case Intrinsic::x86_avx10_mask_rcp_bf16_512:
case Intrinsic::x86_avx10_mask_rcp_bf16_256:
case Intrinsic::x86_avx10_mask_rcp_bf16_128:
case Intrinsic::x86_avx512fp16_mask_rcp_ph_512:
case Intrinsic::x86_avx512fp16_mask_rcp_ph_256:
case Intrinsic::x86_avx512fp16_mask_rcp_ph_128:
handleAVX512VectorGenericMaskedFP(I);
break;

// AVX512 FP16 Arithmetic
case Intrinsic::x86_avx512fp16_mask_add_sh_round:
case Intrinsic::x86_avx512fp16_mask_sub_sh_round:
38 changes: 12 additions & 26 deletions llvm/test/Instrumentation/MemorySanitizer/X86/avx512-intrinsics.ll
@@ -28,8 +28,6 @@
; - llvm.x86.avx512.mul.pd.512, llvm.x86.avx512.mul.ps.512
; - llvm.x86.avx512.permvar.df.512, llvm.x86.avx512.permvar.sf.512
; - llvm.x86.avx512.pternlog.d.512, llvm.x86.avx512.pternlog.q.512
; - llvm.x86.avx512.rcp14.pd.512, llvm.x86.avx512.rcp14.ps.512
; - llvm.x86.avx512.rsqrt14.ps.512
; - llvm.x86.avx512.sitofp.round.v16f32.v16i32
; - llvm.x86.avx512.sqrt.pd.512, llvm.x86.avx512.sqrt.ps.512
; - llvm.x86.avx512.sub.ps.512
@@ -682,15 +680,11 @@ define <16 x float> @test_rcp_ps_512(<16 x float> %a0) #0 {
; CHECK-LABEL: @test_rcp_ps_512(
; CHECK-NEXT: [[TMP1:%.*]] = load <16 x i32>, ptr @__msan_param_tls, align 8
; CHECK-NEXT: call void @llvm.donothing()
; CHECK-NEXT: [[TMP2:%.*]] = bitcast <16 x i32> [[TMP1]] to i512
; CHECK-NEXT: [[_MSCMP:%.*]] = icmp ne i512 [[TMP2]], 0
; CHECK-NEXT: br i1 [[_MSCMP]], label [[TMP3:%.*]], label [[TMP4:%.*]], !prof [[PROF1]]
; CHECK: 3:
; CHECK-NEXT: call void @__msan_warning_noreturn() #[[ATTR10]]
; CHECK-NEXT: unreachable
; CHECK: 4:
; CHECK-NEXT: [[TMP2:%.*]] = icmp ne <16 x i32> [[TMP1]], zeroinitializer
; CHECK-NEXT: [[TMP3:%.*]] = sext <16 x i1> [[TMP2]] to <16 x i32>
; CHECK-NEXT: [[TMP4:%.*]] = select <16 x i1> splat (i1 true), <16 x i32> [[TMP3]], <16 x i32> zeroinitializer
; CHECK-NEXT: [[RES:%.*]] = call <16 x float> @llvm.x86.avx512.rcp14.ps.512(<16 x float> [[A0:%.*]], <16 x float> zeroinitializer, i16 -1)
; CHECK-NEXT: store <16 x i32> zeroinitializer, ptr @__msan_retval_tls, align 8
; CHECK-NEXT: store <16 x i32> [[TMP4]], ptr @__msan_retval_tls, align 8
; CHECK-NEXT: ret <16 x float> [[RES]]
;
%res = call <16 x float> @llvm.x86.avx512.rcp14.ps.512(<16 x float> %a0, <16 x float> zeroinitializer, i16 -1) ; <<16 x float>> [#uses=1]
@@ -702,15 +696,11 @@ define <8 x double> @test_rcp_pd_512(<8 x double> %a0) #0 {
; CHECK-LABEL: @test_rcp_pd_512(
; CHECK-NEXT: [[TMP1:%.*]] = load <8 x i64>, ptr @__msan_param_tls, align 8
; CHECK-NEXT: call void @llvm.donothing()
; CHECK-NEXT: [[TMP2:%.*]] = bitcast <8 x i64> [[TMP1]] to i512
; CHECK-NEXT: [[_MSCMP:%.*]] = icmp ne i512 [[TMP2]], 0
; CHECK-NEXT: br i1 [[_MSCMP]], label [[TMP3:%.*]], label [[TMP4:%.*]], !prof [[PROF1]]
; CHECK: 3:
; CHECK-NEXT: call void @__msan_warning_noreturn() #[[ATTR10]]
; CHECK-NEXT: unreachable
; CHECK: 4:
; CHECK-NEXT: [[TMP2:%.*]] = icmp ne <8 x i64> [[TMP1]], zeroinitializer
; CHECK-NEXT: [[TMP3:%.*]] = sext <8 x i1> [[TMP2]] to <8 x i64>
; CHECK-NEXT: [[TMP4:%.*]] = select <8 x i1> splat (i1 true), <8 x i64> [[TMP3]], <8 x i64> zeroinitializer
; CHECK-NEXT: [[RES:%.*]] = call <8 x double> @llvm.x86.avx512.rcp14.pd.512(<8 x double> [[A0:%.*]], <8 x double> zeroinitializer, i8 -1)
; CHECK-NEXT: store <8 x i64> zeroinitializer, ptr @__msan_retval_tls, align 8
; CHECK-NEXT: store <8 x i64> [[TMP4]], ptr @__msan_retval_tls, align 8
; CHECK-NEXT: ret <8 x double> [[RES]]
;
%res = call <8 x double> @llvm.x86.avx512.rcp14.pd.512(<8 x double> %a0, <8 x double> zeroinitializer, i8 -1) ; <<8 x double>> [#uses=1]
@@ -1021,15 +1011,11 @@ define <16 x float> @test_rsqrt_ps_512(<16 x float> %a0) #0 {
; CHECK-LABEL: @test_rsqrt_ps_512(
; CHECK-NEXT: [[TMP1:%.*]] = load <16 x i32>, ptr @__msan_param_tls, align 8
; CHECK-NEXT: call void @llvm.donothing()
; CHECK-NEXT: [[TMP2:%.*]] = bitcast <16 x i32> [[TMP1]] to i512
; CHECK-NEXT: [[_MSCMP:%.*]] = icmp ne i512 [[TMP2]], 0
; CHECK-NEXT: br i1 [[_MSCMP]], label [[TMP3:%.*]], label [[TMP4:%.*]], !prof [[PROF1]]
; CHECK: 3:
; CHECK-NEXT: call void @__msan_warning_noreturn() #[[ATTR10]]
; CHECK-NEXT: unreachable
; CHECK: 4:
; CHECK-NEXT: [[TMP2:%.*]] = icmp ne <16 x i32> [[TMP1]], zeroinitializer
; CHECK-NEXT: [[TMP3:%.*]] = sext <16 x i1> [[TMP2]] to <16 x i32>
; CHECK-NEXT: [[TMP4:%.*]] = select <16 x i1> splat (i1 true), <16 x i32> [[TMP3]], <16 x i32> zeroinitializer
; CHECK-NEXT: [[RES:%.*]] = call <16 x float> @llvm.x86.avx512.rsqrt14.ps.512(<16 x float> [[A0:%.*]], <16 x float> zeroinitializer, i16 -1)
; CHECK-NEXT: store <16 x i32> zeroinitializer, ptr @__msan_retval_tls, align 8
; CHECK-NEXT: store <16 x i32> [[TMP4]], ptr @__msan_retval_tls, align 8
; CHECK-NEXT: ret <16 x float> [[RES]]
;
%res = call <16 x float> @llvm.x86.avx512.rsqrt14.ps.512(<16 x float> %a0, <16 x float> zeroinitializer, i16 -1) ; <<16 x float>> [#uses=1]
llvm/test/Instrumentation/MemorySanitizer/X86/avx512fp16-intrinsics.ll
@@ -19,7 +19,6 @@
; - llvm.x86.avx512fp16.mask.reduce.sh
; - llvm.x86.avx512fp16.mask.rndscale.ph.512
; - llvm.x86.avx512fp16.mask.rndscale.sh
; - llvm.x86.avx512fp16.mask.rsqrt.ph.512
; - llvm.x86.avx512fp16.mask.rsqrt.sh
; - llvm.x86.avx512fp16.mask.scalef.ph.512
; - llvm.x86.avx512fp16.mask.scalef.sh
@@ -442,15 +441,11 @@ define <32 x half> @test_rsqrt_ph_512(<32 x half> %a0) #0 {
; CHECK-SAME: <32 x half> [[A0:%.*]]) #[[ATTR1]] {
; CHECK-NEXT: [[TMP1:%.*]] = load <32 x i16>, ptr @__msan_param_tls, align 8
; CHECK-NEXT: call void @llvm.donothing()
; CHECK-NEXT: [[TMP2:%.*]] = bitcast <32 x i16> [[TMP1]] to i512
; CHECK-NEXT: [[_MSCMP:%.*]] = icmp ne i512 [[TMP2]], 0
; CHECK-NEXT: br i1 [[_MSCMP]], label %[[BB3:.*]], label %[[BB4:.*]], !prof [[PROF1]]
; CHECK: [[BB3]]:
; CHECK-NEXT: call void @__msan_warning_noreturn() #[[ATTR8]]
; CHECK-NEXT: unreachable
; CHECK: [[BB4]]:
; CHECK-NEXT: [[TMP2:%.*]] = icmp ne <32 x i16> [[TMP1]], zeroinitializer
; CHECK-NEXT: [[TMP3:%.*]] = sext <32 x i1> [[TMP2]] to <32 x i16>
; CHECK-NEXT: [[TMP4:%.*]] = select <32 x i1> splat (i1 true), <32 x i16> [[TMP3]], <32 x i16> zeroinitializer
; CHECK-NEXT: [[RES:%.*]] = call <32 x half> @llvm.x86.avx512fp16.mask.rsqrt.ph.512(<32 x half> [[A0]], <32 x half> zeroinitializer, i32 -1)
; CHECK-NEXT: store <32 x i16> zeroinitializer, ptr @__msan_retval_tls, align 8
; CHECK-NEXT: store <32 x i16> [[TMP4]], ptr @__msan_retval_tls, align 8
; CHECK-NEXT: ret <32 x half> [[RES]]
;
%res = call <32 x half> @llvm.x86.avx512fp16.mask.rsqrt.ph.512(<32 x half> %a0, <32 x half> zeroinitializer, i32 -1)
@@ -681,24 +676,22 @@ declare <32 x half> @llvm.x86.avx512fp16.mask.rcp.ph.512(<32 x half>, <32 x half>, i32)
define <32 x half> @test_rcp_ph_512(<32 x half> %a0, <32 x half> %a1, i32 %mask) #0 {
; CHECK-LABEL: define <32 x half> @test_rcp_ph_512(
; CHECK-SAME: <32 x half> [[A0:%.*]], <32 x half> [[A1:%.*]], i32 [[MASK:%.*]]) #[[ATTR1]] {
; CHECK-NEXT: [[TMP3:%.*]] = load i32, ptr inttoptr (i64 add (i64 ptrtoint (ptr @__msan_param_tls to i64), i64 128) to ptr), align 8
; CHECK-NEXT: [[TMP1:%.*]] = load <32 x i16>, ptr @__msan_param_tls, align 8
; CHECK-NEXT: [[TMP2:%.*]] = load <32 x i16>, ptr inttoptr (i64 add (i64 ptrtoint (ptr @__msan_param_tls to i64), i64 64) to ptr), align 8
; CHECK-NEXT: [[TMP3:%.*]] = load i32, ptr inttoptr (i64 add (i64 ptrtoint (ptr @__msan_param_tls to i64), i64 128) to ptr), align 8
; CHECK-NEXT: call void @llvm.donothing()
; CHECK-NEXT: [[TMP4:%.*]] = bitcast <32 x i16> [[TMP1]] to i512
; CHECK-NEXT: [[_MSCMP:%.*]] = icmp ne i512 [[TMP4]], 0
; CHECK-NEXT: [[TMP5:%.*]] = bitcast <32 x i16> [[TMP2]] to i512
; CHECK-NEXT: [[_MSCMP1:%.*]] = icmp ne i512 [[TMP5]], 0
; CHECK-NEXT: [[_MSOR:%.*]] = or i1 [[_MSCMP]], [[_MSCMP1]]
; CHECK-NEXT: [[TMP4:%.*]] = bitcast i32 [[MASK]] to <32 x i1>
; CHECK-NEXT: [[TMP5:%.*]] = icmp ne <32 x i16> [[TMP1]], zeroinitializer
; CHECK-NEXT: [[TMP6:%.*]] = sext <32 x i1> [[TMP5]] to <32 x i16>
; CHECK-NEXT: [[TMP7:%.*]] = select <32 x i1> [[TMP4]], <32 x i16> [[TMP6]], <32 x i16> [[TMP2]]
; CHECK-NEXT: [[_MSCMP2:%.*]] = icmp ne i32 [[TMP3]], 0
; CHECK-NEXT: [[_MSOR3:%.*]] = or i1 [[_MSOR]], [[_MSCMP2]]
; CHECK-NEXT: br i1 [[_MSOR3]], label %[[BB6:.*]], label %[[BB7:.*]], !prof [[PROF1]]
; CHECK: [[BB6]]:
; CHECK-NEXT: br i1 [[_MSCMP2]], label %[[BB8:.*]], label %[[BB9:.*]], !prof [[PROF1]]
; CHECK: [[BB8]]:
; CHECK-NEXT: call void @__msan_warning_noreturn() #[[ATTR8]]
; CHECK-NEXT: unreachable
; CHECK: [[BB7]]:
; CHECK: [[BB9]]:
; CHECK-NEXT: [[RES:%.*]] = call <32 x half> @llvm.x86.avx512fp16.mask.rcp.ph.512(<32 x half> [[A0]], <32 x half> [[A1]], i32 [[MASK]])
; CHECK-NEXT: store <32 x i16> zeroinitializer, ptr @__msan_retval_tls, align 8
; CHECK-NEXT: store <32 x i16> [[TMP7]], ptr @__msan_retval_tls, align 8
; CHECK-NEXT: ret <32 x half> [[RES]]
;
%res = call <32 x half> @llvm.x86.avx512fp16.mask.rcp.ph.512(<32 x half> %a0, <32 x half> %a1, i32 %mask)
@@ -3260,3 +3253,6 @@ define <32 x half> @test_mm512_castph256_ph512_freeze(<16 x half> %a0) nounwind
}

attributes #0 = { sanitize_memory }
;.
; CHECK: [[PROF1]] = !{!"branch_weights", i32 1, i32 1048575}
;.